diff --git a/README.md b/README.md index 8e470266d..1046dbd0a 100644 --- a/README.md +++ b/README.md @@ -57,20 +57,16 @@ You can install this package and dependencies in a [Python virtual environment]( We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments. Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document. -*TL;DR to install the Jax version for GPU run:* +*TL;DR to install the Jax version for GPU and all workload dependencies run:* ```bash -pip3 install -e '.[pytorch_cpu]' -pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' -pip3 install -e '.[full]' +pip3 install -e '.[pytorch_cpu,jax_gpu,full]' --extra-index-url https://download.pytorch.org/whl/cpu ``` -*TL;DR to install the PyTorch version for GPU run:* +*TL;DR to install the PyTorch version for GPU and all workload dependencies run:* ```bash -pip3 install -e '.[jax_cpu]' -pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121' -pip3 install -e '.[full]' +pip3 install -e '.[jax_cpu,pytorch_gpu,full]' ``` ## Getting Started diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py index f4cb6c2db..75d8d59ea 100644 --- a/algoperf/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -11,7 +11,6 @@ from flax import jax_utils from flax.training import checkpoints as flax_checkpoints from flax.training.checkpoints import latest_checkpoint -import jax import numpy as np from tensorflow.io import gfile # pytype: disable=import-error import torch @@ -193,10 +192,7 @@ def save_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - model_params = jax.device_get(jax_utils.unreplicate(model_params)) opt_state, _ = optimizer_state - opt_state = jax.device_get(jax_utils.unreplicate(opt_state)) - model_state = jax.device_get(jax_utils.unreplicate(model_state)) else: if isinstance( model_params, diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py index 37d1bd20f..9a7b91b15 100644 --- a/algoperf/data_utils.py +++ b/algoperf/data_utils.py @@ -11,6 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler +from algoperf import jax_sharding_utils from algoperf import spec @@ -60,10 +61,7 @@ def _prepare(x): if remainder_size != 0 or pad_to_global_batch_size: x = pad(x, pad_size, padding_value=padding_value) - # Reshape (global_batch_size, ...) to - # (local_device_count, per_device_batch_size, ...). - # Assumes that `global_batch_size % local_device_count == 0`. - return x.reshape((local_device_count, -1, *x.shape[1:])) + return jax.device_put(x, jax_sharding_utils.get_batch_dim_sharding()) return jax.tree.map(_prepare, batch) diff --git a/algoperf/jax_sharding_utils.py b/algoperf/jax_sharding_utils.py new file mode 100644 index 000000000..248106e6d --- /dev/null +++ b/algoperf/jax_sharding_utils.py @@ -0,0 +1,36 @@ +"""Utilities for dealing with sharding in JAX.""" + +import jax +from jax.sharding import NamedSharding, PartitionSpec as P + + +def get_replicate_sharding(): + """Returns a sharding spec that replicates data across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return NamedSharding(mesh, P()) + + +def get_batch_dim_sharding(): + """Returns a sharding spec that shards data along the first axis.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return NamedSharding(mesh, P('batch')) + + +def shard_along_batch_dim(x): + """Shards a tensor across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return jax.tree.map( + lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x) + + +def replicate(x): + """Replicates tensor across all devices.""" + mesh = jax.sharding.Mesh(jax.devices(), ('batch',)) + return jax.tree.map(lambda x: jax.device_put(x, NamedSharding(mesh, P())), x) + + +def display_shard_info(x: jax.Array): + """Displays shard info of a jax array.""" + for shard in x.addressable_shards: + print(f"shard.device: {shard.device}, index: {shard.index}, replica_id:" + f" {shard.replica_id}.\n") diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 728d05f29..8eec88f28 100644 --- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -171,5 +171,5 @@ def create_input_iter( functools.partial( shard_and_maybe_pad_np, global_batch_size=global_batch_size), ds) - it = jax_utils.prefetch_to_device(it, 2) + return it diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index ad43bc62f..b830fe753 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -3,7 +3,6 @@ import functools from typing import Any, Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax @@ -13,6 +12,7 @@ import tensorflow_datasets as tfds from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter @@ -31,6 +31,7 @@ def _build_cifar_dataset( repeat_final_dataset: Optional[bool] = None ) -> Iterator[Dict[str, spec.Tensor]]: ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir) + ds_builder.download_and_prepare() train = split == 'train' assert self.num_train_examples + self.num_validation_examples == 50000 if split in ['train', 'eval_train']: @@ -96,8 +97,8 @@ def init_model_fn( model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = jax_sharding_utils.replicate(params) + params = jax_sharding_utils.replicate(params) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -175,14 +176,8 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') return metrics - @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -190,20 +185,39 @@ def _eval_model( model_state: spec.ModelAuxiliaryState, rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: """Return the mean accuracy and loss as a dict.""" - logits, _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - rng, - update_batch_norm=False) - weights = batch.get('weights') - if weights is None: - weights = jnp.ones(len(logits)) - return self._compute_metrics(logits, batch['targets'], weights) + + @functools.partial( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng + ), + ) + def _eval_model_jitted( + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]: + """Return the mean accuracy and loss as a dict.""" + logits, _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + rng, + update_batch_norm=False) + weights = batch.get('weights') + if weights is None: + weights = jnp.ones(len(logits)) + return self._compute_metrics(logits, batch['targets'], weights) + + metrics = _eval_model_jitted(params, batch, model_state, rng) + return jax.tree.map(lambda x: x.item(), metrics) def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree_map(lambda x: x / num_examples, total_metrics) diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 91761e458..723326120 100644 --- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -11,6 +11,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf import jax_sharding_utils from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload @@ -105,7 +106,7 @@ def init_model_fn( initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return jax_sharding_utils.replicate(initial_params), None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_7' @@ -129,13 +130,16 @@ def model_fn( return logits_batch, None @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0), - static_broadcasted_argnums=(0,)) - def _eval_batch_pmapped(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> spec.Tensor: + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + ), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding()) + def _eval_batch_jitted(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> spec.Tensor: logits, _ = self.model_fn( params, batch, @@ -156,8 +160,7 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return np.array( - self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) + return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 1156cf30a..1349cef64 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -10,6 +10,7 @@ from algoperf import param_utils from algoperf import spec +from algoperf import jax_sharding_utils import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim @@ -39,7 +40,7 @@ def init_model_fn( params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - params = jax_utils.replicate(params) + params = jax_sharding_utils.replicate(params) return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -94,10 +95,12 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding()), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_model(self, params: spec.Tensor, batch: Dict[str, spec.Tensor], @@ -126,7 +129,6 @@ def _eval_model(self, 'ssim': ssim_sum, 'loss': summed_loss, } - metrics = jax.lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -154,13 +156,12 @@ def _eval_model_on_split(self, num_batches=num_batches) total_metrics = {'ssim': 0., 'loss': 0.} - eval_rngs = prng.split(model_rng, jax.local_device_count()) for _ in range(num_batches): batch = next(self._eval_iters[split]) # We already sum these metrics across devices inside _eval_model. - synced_metrics = self._eval_model(params, batch, eval_rngs) + synced_metrics = self._eval_model(params, batch, model_rng) total_metrics = { - k: v + synced_metrics[k][0] for k, v in total_metrics.items() + k: v + synced_metrics[k] for k, v in total_metrics.items() } return {k: float(v.item() / num_examples) for k, v in total_metrics.items()} diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 66105335b..35bc3635c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -399,6 +399,7 @@ def create_input_iter(split: str, ds) # Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%. - it = jax_utils.prefetch_to_device(it, 2) + # TODO (kasimbeg): put on device + # it = jax_utils.prefetch_to_device(it, 2) return iter(it) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 4ec3937b8..1fbc40e01 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -20,6 +20,7 @@ from algoperf import param_utils from algoperf import random_utils as prng +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline @@ -71,17 +72,6 @@ def _build_dataset( use_randaug=use_randaug) return ds - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - """Sync the batch statistics across replicas.""" - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() # Create a shallow copy - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - def init_model_fn( self, rng: spec.RandomState, @@ -113,18 +103,29 @@ def init_model_fn( model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = jax.tree.map( + lambda x: jax.device_put(x, jax_sharding_utils.get_replicate_sharding() + ), + params) + model_state = jax.tree.map( + lambda x: jax.device_put(x, jax_sharding_utils.get_replicate_sharding() + ), + model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, 0), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_replicate_sharding(), # rng + ), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding()) def _eval_model(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -218,7 +219,6 @@ def _compute_metrics(self, 'loss': summed_loss, 'accuracy': accuracy, } - metrics = lax.psum(metrics, axis_name='batch') return metrics def _eval_model_on_split(self, @@ -231,9 +231,6 @@ def _eval_model_on_split(self, data_dir: str, global_step: int = 0) -> Dict[str, float]: del global_step - if model_state is not None: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) data_rng, eval_rng = prng.split(rng, 2) # We already repeat the dataset indefinitely in tf.data. @@ -250,20 +247,14 @@ def _eval_model_on_split(self, eval_metrics = {} for bi in range(num_batches): eval_rng = prng.fold_in(eval_rng, bi) - step_eval_rngs = prng.split(eval_rng, jax.local_device_count()) batch = next(self._eval_iters[split]) - # We already average these metrics across devices inside _compute_metrics. - synced_metrics = self._eval_model(params, - batch, - model_state, - step_eval_rngs) + synced_metrics = self._eval_model(params, batch, model_state, eval_rng) for metric_name, metric_value in synced_metrics.items(): if metric_name not in eval_metrics: eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), - eval_metrics) + eval_metrics = jax.tree.map(lambda x: x / num_examples, eval_metrics) return eval_metrics diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 35a6c46be..c4d823319 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -2,13 +2,13 @@ from typing import Dict, Optional, Tuple -from flax import jax_utils from flax import linen as nn from flax.core import pop import jax import jax.numpy as jnp from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload @@ -46,8 +46,8 @@ def init_model_fn( params, model_state = self.initialized(rng, self._model) self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + params = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 39012a20d..958e927a4 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -2,11 +2,9 @@ import math from typing import Dict, Iterator, Optional, Tuple -from flax import jax_utils from flax.core import pop import flax.linen as nn import jax -from jax import lax import jax.numpy as jnp import numpy as np import optax @@ -14,6 +12,7 @@ from algoperf import data_utils from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.librispeech_conformer import metrics from algoperf.workloads.librispeech_conformer import workload @@ -93,8 +92,11 @@ def init_model_fn( self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + + # Add sharding + params = jax_sharding_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) + return params, model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: @@ -180,6 +182,7 @@ def _build_input_queue( 'targets': (targets.numpy(), target_paddings.numpy()), } + # Use data_utils.shard_and_maybe_pad_np to handle sharding padded_batch = data_utils.shard_and_maybe_pad_np( numpy_batch, padding_value=1.0) yield padded_batch @@ -305,11 +308,16 @@ def greedy_decode( return hyp, hyp_paddings @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) - def eval_step_pmapped( + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_replicate_sharding(), # rng + ), + out_shardings=jax_sharding_utils.get_batch_dim_sharding(), + static_argnums=(0,)) + def _eval_step( self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor], @@ -325,15 +333,45 @@ def eval_step_pmapped( decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings) loss = self.loss_fn(batch['targets'], (logits, logit_paddings)) - targets, target_paddings = batch['targets'] - return self.metrics_bundle.gather_from_model_output( - loss_dict=loss, - decoded=decoded, - decoded_paddings=decoded_paddings, - targets=targets, - target_paddings=target_paddings, - axis_name='batch') + # Convert metrics bundle to dictionary + metrics_dict = { + 'loss_per_example': + loss['per_example'], + 'decoded': + decoded, + 'decoded_paddings': + decoded_paddings, + 'targets': + targets, + 'target_paddings': + target_paddings, + 'n_valid_examples': + jnp.zeros((len(jax.devices()), 1)) + loss['n_valid_examples'] + } + return metrics_dict + + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState): + """Evaluates the model and returns a metrics bundle.""" + metrics_dict = self._eval_step(params, batch, model_state, rng) + + # Convert dictionary back to metrics bundle + metrics_bundle = self.metrics_bundle.single_from_model_output( + loss_dict={ + 'summed': metrics_dict['loss_per_example'].sum(), + 'per_example': metrics_dict['loss_per_example'], + 'n_valid_examples': metrics_dict['n_valid_examples'].sum() + }, + decoded=metrics_dict['decoded'], + decoded_paddings=metrics_dict['decoded_paddings'], + targets=metrics_dict['targets'], + target_paddings=metrics_dict['target_paddings']) + + return metrics_bundle def _eval_model_on_split(self, split: str, @@ -346,9 +384,6 @@ def _eval_model_on_split(self, global_step: int = 0) -> Dict[str, float]: """Run a full evaluation of the model.""" del global_step - if model_state is not None and len(model_state) > 0: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) num_batches = int(math.ceil(num_examples / global_batch_size)) if split not in self._eval_iters: @@ -358,10 +393,7 @@ def _eval_model_on_split(self, metrics_report = None for _ in range(num_batches): eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() + computed_metrics = self.eval_step(params, eval_batch, model_state, rng) if metrics_report is None: metrics_report = computed_metrics @@ -373,16 +405,6 @@ def _eval_model_on_split(self, return computed_metrics - def sync_batch_stats( - self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState: - # An axis_name is passed to pmap which can then be used by pmean. - # In this case each device has its own version of the batch statistics and - # we average them. - avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() - new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) - return new_model_state - class LibriSpeechConformerAttentionTemperatureWorkload( LibriSpeechConformerWorkload): diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 94f01dd97..42c5bcf4a 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -88,4 +88,6 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 76_000 + # TODO(kasimbeg):rever tot 76000 + # return 76_000 + return 80_000 diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index b116f44cd..2c7011445 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -8,13 +8,20 @@ # webpage : https://bastings.github.io/ """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Type, Mapping, Sequence +from absl import logging +import numpy as np + +import functools +import flax from flax import linen as nn from flax import struct import jax from jax.experimental import rnn import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor @@ -310,16 +317,12 @@ def __call__(self, inputs, input_paddings=None, train=False): count_v = jnp.sum( jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) - sum_v = jax.lax.psum(sum_v, axis_name='batch') - count_v = jax.lax.psum(count_v, axis_name='batch') - count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v variance = (inputs - mean) * (inputs - mean) * mask sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True) - sum_vv = jax.lax.psum(sum_vv, axis_name='batch') var = sum_vv / count_v self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index d3b616f43..9b43afe16 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -4,10 +4,13 @@ from flax import jax_utils import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P import numpy as np from algoperf import param_utils from algoperf import spec +from algoperf import jax_sharding_utils from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -41,21 +44,22 @@ def init_model_fn( fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) + # model_init_fn = functools.partial(self._model.init, train=False) params_rng, dropout_rng = jax.random.split(rng, 2) variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state = variables[ - 'batch_stats'] if not self.layernorm_everywhere else {} + model_state = {'batch_stats': variables['batch_stats'] + } if not self.layernorm_everywhere else {} params = variables['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - model_state = jax_utils.replicate(model_state) - params = jax_utils.replicate(params) + model_state = jax_sharding_utils.replicate(model_state) + params = jax_sharding_utils.replicate(params) return params, model_state - def model_fn( + def model_fn_ref( self, params: spec.ParameterContainer, augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], @@ -76,6 +80,8 @@ def model_fn( train=True, rngs={'dropout' : rng}, mutable=['batch_stats']) + if 'batch_stats' in new_model_state and new_model_state['batch_stats']: + new_model_state = jax.lax.pmean(new_model_state, 'batch') return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -86,6 +92,36 @@ def model_fn( mutable=False) return (logits, logit_paddings), model_state + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + + model_fn_partial = jax.tree_util.Partial( + self.model_fn_ref, + mode=mode, + rng=rng, + update_batch_norm=update_batch_norm, + use_running_average_bn=use_running_average_bn) + + model_fn_sharded = shard_map( + model_fn_partial, + jax.sharding.Mesh(jax.devices(), ('batch')), + in_specs=(P(), P('batch'), P(None)), + out_specs=(P('batch'), P(None)), + ) + return model_fn_sharded( + params, + augmented_and_preprocessed_input_batch, + model_state, + ) + def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' @@ -100,7 +136,9 @@ def test_target_value(self) -> float: @property def step_hint(self) -> int: """Approx. steps the baseline can do in the allowed runtime budget.""" - return 38_400 + # TODO(kasimbeg): revert old version + # return 38_400 + return 48_000 @property def max_allowed_runtime_sec(self) -> int: diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py index 5a4382da1..ad2d7fc8a 100644 --- a/algoperf/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -6,11 +6,11 @@ from flax import jax_utils from flax import linen as nn import jax -from jax import lax import jax.numpy as jnp import optax from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.mnist.workload import BaseMnistWorkload @@ -46,7 +46,7 @@ def init_model_fn( train=True)['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_1' @@ -101,10 +101,14 @@ def loss_fn( } @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + jax_sharding_utils.get_replicate_sharding(), # model_state + jax_sharding_utils.get_batch_dim_sharding(), # rng + ), + static_argnums=(0,)) def _eval_model( self, params: spec.ParameterContainer, @@ -125,11 +129,10 @@ def _eval_model( (jnp.argmax(logits, axis=-1) == batch['targets']) * weights) summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed'] metrics = {'accuracy': accuracy, 'loss': summed_loss} - metrics = lax.psum(metrics, axis_name='batch') return metrics def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x.item() / num_examples), total_metrics) diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py index 3cb6f51de..01aa19c9b 100644 --- a/algoperf/workloads/ogbg/input_pipeline.py +++ b/algoperf/workloads/ogbg/input_pipeline.py @@ -148,17 +148,13 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): weights_shards.append(weights) if count == num_shards: - - def f(x): - return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) - - graphs_shards = f(graphs_shards) - labels_shards = f(labels_shards) - weights_shards = f(weights_shards) + # jraph.batch has a memory leak and OOMs + # It is possible with jraph.batch_np we may have transferred the leak + # to the cpu. yield { - 'inputs': graphs_shards, - 'targets': labels_shards, - 'weights': weights_shards, + 'inputs': jraph.batch_np(graphs_shards), + 'targets': np.vstack(labels_shards), + 'weights': np.vstack(weights_shards) } count = 0 diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py index 0e66d2ab8..7d7de1ecb 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/models.py +++ b/algoperf/workloads/ogbg/ogbg_jax/models.py @@ -2,6 +2,7 @@ # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. from typing import Optional, Tuple +import jax from flax import linen as nn import jax.numpy as jnp import jraph diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py index e895d15a7..9d4a0a404 100644 --- a/algoperf/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,6 +8,7 @@ import jraph import optax +from algoperf import jax_sharding_utils from algoperf import param_utils from algoperf import spec from algoperf.workloads.ogbg import metrics @@ -45,7 +46,8 @@ def init_model_fn( params = params['params'] self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(params), None + params = jax_sharding_utils.replicate(params) + return params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_17' @@ -107,10 +109,14 @@ def _eval_metric(self, labels, logits, masks): loss=loss['per_example'], logits=logits, labels=labels, mask=masks) @functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, 0, 0, 0, None), - static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=(jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_replicate_sharding()), + static_argnums=(0,), + out_shardings=jax_sharding_utils.get_replicate_sharding(), + ) def _eval_batch(self, params, batch, model_state, rng): return super()._eval_batch(params, batch, model_state, rng) @@ -119,7 +125,6 @@ def _normalize_eval_metrics( Any]) -> Dict[str, float]: """Normalize eval metrics.""" del num_examples - total_metrics = total_metrics.reduce() return {k: float(v) for k, v in total_metrics.compute().items()} diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 971e7f0f6..45ea778fd 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -161,6 +161,7 @@ def _eval_batch(self, spec.ForwardPassMode.EVAL, rng, update_batch_norm=False) + # jax.debug.print(str(logits)) return self._eval_metric(batch['targets'], logits, batch['weights']) def _eval_model_on_split(self, diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py index ad314a7d3..5e175320a 100644 --- a/algoperf/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -283,8 +283,7 @@ def ref_stats(output, refs): closest_diff = diff closest_len = reflen elif diff == closest_diff: - if reflen < closest_len: - closest_len = reflen + closest_len = min(reflen, closest_len) ngrams_ref = extract_ngrams(ref) for ngram in ngrams_ref: diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py index cdfcb91df..36f5b8606 100644 --- a/algoperf/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -14,6 +14,7 @@ import optax from algoperf import param_utils +from algoperf import jax_sharding_utils from algoperf import spec from algoperf.workloads.wmt import bleu from algoperf.workloads.wmt.wmt_jax import decode @@ -69,10 +70,16 @@ def compute_weighted_cross_entropy( } @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) - def eval_step_pmapped( - self, params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: + jax.jit, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_batch_dim_sharding(), # batch + ), + static_argnums=(0,), # self + ) + def eval_step(self, + params: spec.ParameterContainer, + batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] targets = batch['targets'] @@ -90,29 +97,29 @@ def eval_step_pmapped( 'denominator': weight_sum, } - def eval_step(self, - params: spec.ParameterContainer, - batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: - replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) - @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) + jax.jit, + in_shardings=( + jax_sharding_utils.get_batch_dim_sharding(), # inputs + ), + static_argnums=( + 0, + 2, + )) def initialize_cache(self, inputs: spec.Tensor, max_decode_len: int = 256) -> Dict[str, spec.Tensor]: """Initialize a cache for a given input shape and max decode length.""" config = models.TransformerConfig(deterministic=True, decode=True) target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] - initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), - jnp.ones(inputs.shape, jnp.float32), + dummy_inputs = jax_sharding_utils.shard_along_batch_dim( + jnp.ones(inputs.shape, jnp.float32)) + dummy_targets = jax_sharding_utils.shard_along_batch_dim( jnp.ones(target_shape, jnp.float32)) + initial_variables = models.Transformer(config).init( + jax.random.PRNGKey(0), dummy_inputs, dummy_targets) return initial_variables['cache'] - # eos_id, max_decode_len are constant. - @functools.partial( - jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)) def predict_step(self, inputs: spec.Tensor, params: spec.ParameterContainer, @@ -180,20 +187,35 @@ def translate_and_calculate_bleu(self, """Translates the `predict_ds` and calculates the BLEU score.""" logging.info('Translating evaluation dataset.') references, predictions = [], [] + jitted_predict_step = None for _ in range(num_batches): pred_batch = next(ds_iter) cache = self.initialize_cache(pred_batch['inputs']) - predicted = self.predict_step(pred_batch['inputs'], - params, - cache, - decode.EOS_ID, - max_predict_length) - predicted = _to_host(predicted) - targets = _to_host(pred_batch['targets']) + if jitted_predict_step is None: + jitted_predict_step = jax.jit( + self.predict_step, + in_shardings=( + jax_sharding_utils.get_batch_dim_sharding(), # inputs + jax_sharding_utils.get_replicate_sharding(), # params + jax_sharding_utils.get_replicate_sharding(), # cache + ), + static_argnums=( + 3, # eos_id + 4, # max_decode_len, + 5, # beam_size + )) + predicted = jitted_predict_step(pred_batch['inputs'], + params, + cache, + decode.EOS_ID, + max_predict_length) + # predicted = _to_host(predicted) + # targets = _to_host(pred_batch['targets']) + targets = pred_batch['targets'] # Find actual batch size, ignoring the potential padding. weights = pred_batch.get('weights') if weights is not None: - weights = _to_host(weights) + # weights = _to_host(weights) actual_batch_size = int(weights.sum(0)[0].item()) else: actual_batch_size = len(predicted) @@ -213,7 +235,7 @@ def init_model_fn( aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState: """aux_dropout_rate is used as attention_dropout_rate.""" - init_fake_batch_size = 2 + init_fake_batch_size = 8 input_shape = (init_fake_batch_size, 256) target_shape = (init_fake_batch_size, 256) @@ -235,15 +257,21 @@ def init_model_fn( eval_config = replace(model_config, deterministic=True) self._eval_model = models.Transformer(eval_config) params_rng, dropout_rng = jax.random.split(rng) + inputs = jnp.ones(input_shape, jnp.float32) + targets = jnp.ones(target_shape, jnp.float32) + sharded_inputs = jax_sharding_utils.shard_along_batch_dim(inputs) + sharded_targets = jax_sharding_utils.shard_along_batch_dim(targets) + initial_variables = jax.jit( self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng}, - jnp.ones(input_shape, jnp.float32), - jnp.ones(target_shape, jnp.float32)) + sharded_inputs, + sharded_targets) initial_params = initial_variables['params'] self._param_shapes = param_utils.jax_param_shapes(initial_params) self._param_types = param_utils.jax_param_types(self._param_shapes) - return jax_utils.replicate(initial_params), None + params = jax_sharding_utils.shard_along_batch_dim(initial_params) + return initial_params, None def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'shared_embedding' diff --git a/docker/Dockerfile b/docker/Dockerfile index 76bc5cfe0..4879d9612 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docker build -t --build-arg framework=pytorch # To build Docker image -FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04 # Installing machine packages RUN echo "Setting up machine" @@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libffi-dev \ curl \ - libbz2-dev \ liblzma-dev \ + libbz2-dev \ vim # Download and install Python 3.11 @@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ -RUN pip install --upgrade pip - # Install Algorithmic efficiency repo RUN pip install --upgrade pip @@ -71,25 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ + && pip install -e '.[pytorch_cpu, full]' --extra-index-url https://download.pytorch.org/whl/cpu \ + # Todo: remove temporary nightly install + && pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ && cd /algorithmic-efficiency \ - && pip install -e '.[jax_cpu]' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ - elif [ "$framework" = "both" ] ; then \ - echo "Installing Jax GPU and Pytorch GPU" \ - && cd /algorithmic-efficiency \ - && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ - && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \ + && pip install -e '.[pytorch_gpu, jax_cpu, full]'; \ else \ - echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \ + echo "Invalid build-arg $framework: framework should be either jax or pytorch." >&2 \ && exit 1 ; \ fi -RUN cd /algorithmic-efficiency && pip install -e '.[full]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh index 645b81955..6b5e67ceb 100644 --- a/docker/build_docker_images.sh +++ b/docker/build_docker_images.sh @@ -1,27 +1,40 @@ #!/bin/bash # Bash script to build and push dev docker images to artifact repo # Usage: -# bash build_docker_images.sh -b +# bash build_docker_images.sh -b -f # Make program exit with non-zero exit code if any command fails. set -e -while getopts b: flag +while getopts "b:p:f:" flag; do case "${flag}" in b) GIT_BRANCH=${OPTARG};; + p) PROJECT=${OPTARG};; + f) FRAMEWORK=${OPTARG};; esac done # Artifact repostiory -ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +if [ "$PROJECT" = "mlcommons-algoperf" ]; then + ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo" +else + ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo" +fi -if [[ -z ${GIT_BRANCH+x} ]] +if [[ -z ${GIT_BRANCH+x} ]]; then GIT_BRANCH='main' # Set default argument fi -for FRAMEWORK in "jax" "pytorch" "both" +FRAMEWORKS=( "jax" "pythorch" "both" ) + +if [[ -n "$FRAMEWORK" ]]; +then + FRAMEWORKS=("$FRAMEWORK") +fi + +for FRAMEWORK in "${FRAMEWORKS[@]}"; do IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH" diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 7fec01542..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -813.182501077652,0.0,20.942517280578613,1,0,20.942517280578613,0.457260122944079,95000000,834.1250734329224,0.4569449610305282,0.4571109073102294,83274637 -1456.6285498142242,0.0256640911102294,1221.314148426056,1521,0,1221.314148426056,0.1282897442639802,95000000,2678.023252725601,0.1242542823999183,0.1258836863985669,83274637 -2016.864410877228,0.0503592491149902,2421.367050409317,3037,0,2421.367050409317,0.1271590245271381,95000000,4438.388706684113,0.1243784172261286,0.1247065146412016,83274637 -2558.6020953655243,0.0732090473175048,3621.3504645824432,4555,0,3621.3504645824432,0.1268894827919408,95000000,6180.184489965439,0.1230034951893788,0.1245234417083079,83274637 -3071.354640007019,0.0977675914764404,4821.317496538162,6087,0,4821.317496538162,0.1267936776624177,95000000,7892.981649637222,0.1233932793890155,0.1243785472760451,83274637 -3527.707134723664,0.125211477279663,6021.926441669464,7600,0,6021.926441669464,0.1266444441200657,95000000,9550.021806240082,0.1227233855979247,0.1242731458953252,83274637 -3858.912323474884,0.1494441032409668,7222.050496816635,9109,0,7222.050496816635,0.12616021638569078,95000000,11081.426317453384,0.12224845297681461,0.12382877749552663,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 57c6a327a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.595053,0.45791024,,,,,,,,,,, -1,,,0.4569449610305282,0.4571109073102294,83274637.0,0.457260122944079,95000000.0,20.942517280578613,834.1250734329224,20.942517280578613,813.182501077652,0.0,0.0 -100,0.052369263,0.14439134,,,,,,,,,,, -200,0.011274185,0.12775373,,,,,,,,,,, -300,0.0145961,0.12785143,,,,,,,,,,, -400,0.016431093,0.12760307,,,,,,,,,,, -500,0.011657697,0.13277733,,,,,,,,,,, -600,0.008272649,0.124659255,,,,,,,,,,, -700,0.013119716,0.118887946,,,,,,,,,,, -800,0.02837587,0.12938836,,,,,,,,,,, -900,0.039843943,0.12499282,,,,,,,,,,, -1000,0.013996667,0.13850704,,,,,,,,,,, -1100,0.008121838,0.12663904,,,,,,,,,,, -1200,0.032506075,0.12558989,,,,,,,,,,, -1300,0.016993973,0.12045493,,,,,,,,,,, -1400,0.018361678,0.12555787,,,,,,,,,,, -1500,0.017285459,0.12031912,,,,,,,,,,, -1521,,,0.1242542823999183,0.1258836863985669,83274637.0,0.1282897442639802,95000000.0,1221.314148426056,2678.023252725601,1221.314148426056,1456.6285498142242,0.0256640911102294,0.0 -1600,0.0080781495,0.1332495,,,,,,,,,,, -1700,0.006543472,0.1284553,,,,,,,,,,, -1800,0.015620338,0.12344419,,,,,,,,,,, -1900,0.006610907,0.1305861,,,,,,,,,,, -2000,0.0061059073,0.11774591,,,,,,,,,,, -2100,0.0095490115,0.12197934,,,,,,,,,,, -2200,0.0058529815,0.13005584,,,,,,,,,,, -2300,0.013598117,0.122508906,,,,,,,,,,, -2400,0.009501394,0.11922068,,,,,,,,,,, -2500,0.008959282,0.11970478,,,,,,,,,,, -2600,0.012501967,0.12650742,,,,,,,,,,, -2700,0.009560492,0.12145102,,,,,,,,,,, -2800,0.018516129,0.13355987,,,,,,,,,,, -2900,0.008552481,0.11968313,,,,,,,,,,, -3000,0.009244999,0.12620577,,,,,,,,,,, -3037,,,0.1243784172261286,0.1247065146412016,83274637.0,0.1271590245271381,95000000.0,2421.367050409317,4438.388706684113,2421.367050409317,2016.864410877228,0.0503592491149902,0.0 -3100,0.016859975,0.121443875,,,,,,,,,,, -3200,0.011390192,0.12734412,,,,,,,,,,, -3300,0.01578223,0.12218773,,,,,,,,,,, -3400,0.012737954,0.12198927,,,,,,,,,,, -3500,0.008529306,0.11594345,,,,,,,,,,, -3600,0.0057221455,0.11723854,,,,,,,,,,, -3700,0.012278644,0.12083666,,,,,,,,,,, -3800,0.0071947486,0.1215835,,,,,,,,,,, -3900,0.01672526,0.12385393,,,,,,,,,,, -4000,0.009465944,0.12336844,,,,,,,,,,, -4100,0.008182347,0.121824875,,,,,,,,,,, -4200,0.011010937,0.12693067,,,,,,,,,,, -4300,0.006021543,0.12776452,,,,,,,,,,, -4400,0.023563234,0.12802522,,,,,,,,,,, -4500,0.0065156184,0.121047646,,,,,,,,,,, -4555,,,0.1230034951893788,0.1245234417083079,83274637.0,0.1268894827919408,95000000.0,3621.3504645824432,6180.184489965439,3621.3504645824432,2558.6020953655243,0.0732090473175048,0.0 -4600,0.009190459,0.13318563,,,,,,,,,,, -4700,0.0081238495,0.121328145,,,,,,,,,,, -4800,0.009070785,0.12200863,,,,,,,,,,, -4900,0.0075785993,0.13092919,,,,,,,,,,, -5000,0.010724381,0.12096827,,,,,,,,,,, -5100,0.010005327,0.13379729,,,,,,,,,,, -5200,0.010590516,0.12410814,,,,,,,,,,, -5300,0.005807965,0.12470031,,,,,,,,,,, -5400,0.016870208,0.12171663,,,,,,,,,,, -5500,0.008763265,0.13067703,,,,,,,,,,, -5600,0.00906206,0.119758196,,,,,,,,,,, -5700,0.00591919,0.11863424,,,,,,,,,,, -5800,0.012273158,0.12229582,,,,,,,,,,, -5900,0.0059149456,0.11517441,,,,,,,,,,, -6000,0.018563222,0.12411391,,,,,,,,,,, -6087,,,0.1233932793890155,0.1243785472760451,83274637.0,0.1267936776624177,95000000.0,4821.317496538162,7892.981649637222,4821.317496538162,3071.354640007019,0.0977675914764404,0.0 -6100,0.013684579,0.12588407,,,,,,,,,,, -6200,0.013902485,0.13925087,,,,,,,,,,, -6300,0.01532411,0.12060828,,,,,,,,,,, -6400,0.009698195,0.1268852,,,,,,,,,,, -6500,0.011197286,0.1266221,,,,,,,,,,, -6600,0.006905215,0.1351961,,,,,,,,,,, -6700,0.0076057217,0.12408463,,,,,,,,,,, -6800,0.007856281,0.12371522,,,,,,,,,,, -6900,0.010840494,0.12143174,,,,,,,,,,, -7000,0.008237116,0.11565082,,,,,,,,,,, -7100,0.007685555,0.11505994,,,,,,,,,,, -7200,0.009968832,0.122804135,,,,,,,,,,, -7300,0.017238097,0.119117185,,,,,,,,,,, -7400,0.0067587667,0.12082305,,,,,,,,,,, -7500,0.0072864136,0.11846155,,,,,,,,,,, -7600,,,0.1227233855979247,0.1242731458953252,83274637.0,0.1266444441200657,95000000.0,6021.926441669464,9550.021806240082,6021.926441669464,3527.707134723664,0.125211477279663,0.0 -7600,0.009987882,0.13021319,,,,,,,,,,, -7700,0.008697033,0.11773735,,,,,,,,,,, -7800,0.010010784,0.1215125,,,,,,,,,,, -7900,0.009225212,0.12194373,,,,,,,,,,, -8000,0.007961596,0.12818006,,,,,,,,,,, -8100,0.007398393,0.116453744,,,,,,,,,,, -8200,0.01821585,0.116336524,,,,,,,,,,, -8300,0.0115669975,0.11885777,,,,,,,,,,, -8400,0.009494203,0.12733606,,,,,,,,,,, -8500,0.007484815,0.12146135,,,,,,,,,,, -8600,0.016212368,0.12723136,,,,,,,,,,, -8700,0.016069423,0.11877815,,,,,,,,,,, -8800,0.021099621,0.122659475,,,,,,,,,,, -8900,0.009202063,0.13165034,,,,,,,,,,, -9000,0.009431177,0.123754546,,,,,,,,,,, -9100,0.011813518,0.12514208,,,,,,,,,,, -9109,,,0.1222484529768146,0.1238287774955266,83274637.0,0.1261602163856907,95000000.0,7222.050496816635,11081.426317453384,7222.050496816635,3858.912323474884,0.1494441032409668,0.0 -9200,0.011132723,0.12894654,,,,,,,,,,, -9300,0.010056959,0.120828405,,,,,,,,,,, -9400,0.008353129,0.12036147,,,,,,,,,,, -9500,0.0070782807,0.120868035,,,,,,,,,,, -9600,0.017878985,0.115929335,,,,,,,,,,, -9700,0.020996006,0.13593015,,,,,,,,,,, -9737,,,,,,,,7703.6179575920105,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 6c9a3ffa2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -155.3625431060791,0.0,5.699930191040039,1,0,5.699930191040039,0.457260122944079,95000000,161.06252551078796,0.4567256082528792,0.4571109073102294,83274637 -178.03247785568237,0.0186891555786132,1206.3205771446228,1333,0,1206.3205771446228,0.1282896378392269,95000000,1384.4192180633545,0.1243324928782271,0.1258608463579222,83274637 -200.46779704093933,0.0492315292358398,2406.714089870453,2718,0,2406.714089870453,0.1274692210834704,95000000,2607.327826023102,0.1239407410692868,0.1250111534184772,83274637 -223.11623096466064,0.0789625644683837,3606.766949415207,4097,0,3606.766949415207,0.1269236832648026,95000000,3830.107353925705,0.122401347793873,0.1246056620738406,83274637 -245.94660782814023,0.1041405200958252,4807.293630361557,5459,0,4807.293630361557,0.1267341659333881,95000000,5053.537276268005,0.1230073884223242,0.1242982392105774,83274637 -268.8433749675751,0.13230562210083,6008.01816034317,6819,0,6008.01816034317,0.126314287109375,95000000,6277.2343101501465,0.1250520946001106,0.1240117562370761,83274637 -291.9232795238495,0.15978431701660156,7208.564959049225,8194,0,7208.564959049225,0.1262119042763158,95000000,7500.937597751617,0.11901546869450395,0.12387405083720149,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/measurements.csv deleted file mode 100644 index bcfadea40..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/measurements.csv +++ /dev/null @@ -1,96 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.4390554,0.45606273,,,,,,,,,,, -1,,,0.4567256082528792,0.4571109073102294,83274637.0,0.457260122944079,95000000.0,5.699930191040039,161.06252551078796,5.699930191040039,155.3625431060791,0.0,0.0 -100,0.060920443,0.13237274,,,,,,,,,,, -200,0.14153314,0.12400838,,,,,,,,,,, -300,0.04789811,0.12201474,,,,,,,,,,, -400,0.06332899,0.13071565,,,,,,,,,,, -500,0.06602588,0.12573342,,,,,,,,,,, -600,0.023347335,0.13069947,,,,,,,,,,, -700,0.018870912,0.11814365,,,,,,,,,,, -800,0.015525951,0.11676782,,,,,,,,,,, -900,0.04949253,0.124277994,,,,,,,,,,, -1000,0.030646183,0.12361217,,,,,,,,,,, -1100,0.02375579,0.119545214,,,,,,,,,,, -1200,0.03587274,0.12959868,,,,,,,,,,, -1300,0.013233145,0.12115848,,,,,,,,,,, -1333,,,0.1243324928782271,0.1258608463579222,83274637.0,0.1282896378392269,95000000.0,1206.3205771446228,1384.4192180633545,1206.3205771446228,178.03247785568237,0.0186891555786132,0.0 -1400,0.047744792,0.12076018,,,,,,,,,,, -1500,0.007142925,0.12050425,,,,,,,,,,, -1600,0.0554247,0.12011003,,,,,,,,,,, -1700,0.017859716,0.12459074,,,,,,,,,,, -1800,0.022750404,0.12939805,,,,,,,,,,, -1900,0.021802584,0.11915764,,,,,,,,,,, -2000,0.027529169,0.12505569,,,,,,,,,,, -2100,0.053865355,0.12188184,,,,,,,,,,, -2200,0.04906765,0.1227427,,,,,,,,,,, -2300,0.025565568,0.11876911,,,,,,,,,,, -2400,0.015909987,0.11598399,,,,,,,,,,, -2500,0.019850688,0.12243564,,,,,,,,,,, -2600,0.006253544,0.12149963,,,,,,,,,,, -2700,0.007893926,0.12330208,,,,,,,,,,, -2718,,,0.1239407410692868,0.1250111534184772,83274637.0,0.1274692210834704,95000000.0,2406.714089870453,2607.327826023102,2406.714089870453,200.46779704093933,0.0492315292358398,0.0 -2800,0.031312265,0.12353721,,,,,,,,,,, -2900,0.03204437,0.12714541,,,,,,,,,,, -3000,0.018536927,0.1200861,,,,,,,,,,, -3100,0.010753124,0.12544902,,,,,,,,,,, -3200,0.024220303,0.12500334,,,,,,,,,,, -3300,0.01116332,0.12095913,,,,,,,,,,, -3400,0.012960139,0.11797331,,,,,,,,,,, -3500,0.009452975,0.12883525,,,,,,,,,,, -3600,0.0058812094,0.11662953,,,,,,,,,,, -3700,0.04272561,0.13020732,,,,,,,,,,, -3800,0.0064390437,0.12804143,,,,,,,,,,, -3900,0.0073450734,0.11895685,,,,,,,,,,, -4000,0.018193113,0.12369791,,,,,,,,,,, -4097,,,0.122401347793873,0.1246056620738406,83274637.0,0.1269236832648026,95000000.0,3606.766949415207,3830.107353925705,3606.766949415207,223.11623096466064,0.0789625644683837,0.0 -4100,0.011242135,0.122081175,,,,,,,,,,, -4200,0.0065874993,0.12100526,,,,,,,,,,, -4300,0.024711892,0.12976725,,,,,,,,,,, -4400,0.008941987,0.11569895,,,,,,,,,,, -4500,0.006363438,0.12152226,,,,,,,,,,, -4600,0.010843002,0.12032419,,,,,,,,,,, -4700,0.014552275,0.12052624,,,,,,,,,,, -4800,0.0063895113,0.12140711,,,,,,,,,,, -4900,0.0066234646,0.13600604,,,,,,,,,,, -5000,0.009849289,0.12368231,,,,,,,,,,, -5100,0.006253873,0.12126851,,,,,,,,,,, -5200,0.026139744,0.12531343,,,,,,,,,,, -5300,0.006885868,0.1177107,,,,,,,,,,, -5400,0.011416498,0.12058384,,,,,,,,,,, -5459,,,0.1230073884223242,0.1242982392105774,83274637.0,0.1267341659333881,95000000.0,4807.293630361557,5053.537276268005,4807.293630361557,245.94660782814023,0.1041405200958252,0.0 -5500,0.012883707,0.12942463,,,,,,,,,,, -5600,0.023014875,0.1214724,,,,,,,,,,, -5700,0.006456543,0.12836507,,,,,,,,,,, -5800,0.009441953,0.12329052,,,,,,,,,,, -5900,0.013367521,0.123413086,,,,,,,,,,, -6000,0.018386548,0.11985607,,,,,,,,,,, -6100,0.011404792,0.12755816,,,,,,,,,,, -6200,0.015975421,0.12877865,,,,,,,,,,, -6300,0.011644255,0.12303987,,,,,,,,,,, -6400,0.014322759,0.12339946,,,,,,,,,,, -6500,0.016656501,0.12394816,,,,,,,,,,, -6600,0.011252226,0.12001917,,,,,,,,,,, -6700,0.015471942,0.12322952,,,,,,,,,,, -6800,0.007275899,0.12993953,,,,,,,,,,, -6819,,,0.1250520946001106,0.1240117562370761,83274637.0,0.126314287109375,95000000.0,6008.01816034317,6277.2343101501465,6008.01816034317,268.8433749675751,0.13230562210083,0.0 -6900,0.018681353,0.12591648,,,,,,,,,,, -7000,0.007765912,0.11548301,,,,,,,,,,, -7100,0.0061852993,0.12502441,,,,,,,,,,, -7200,0.005432842,0.12556644,,,,,,,,,,, -7300,0.011321146,0.11464477,,,,,,,,,,, -7400,0.0071827606,0.119832076,,,,,,,,,,, -7500,0.016439138,0.12669729,,,,,,,,,,, -7600,0.007486849,0.12097673,,,,,,,,,,, -7700,0.0079718195,0.123005554,,,,,,,,,,, -7800,0.0071740304,0.11992081,,,,,,,,,,, -7900,0.0075010406,0.11670616,,,,,,,,,,, -8000,0.006271487,0.13278887,,,,,,,,,,, -8100,0.006883387,0.123326436,,,,,,,,,,, -8194,,,0.1190154686945039,0.1238740508372014,83274637.0,0.1262119042763158,95000000.0,7208.564959049225,7500.937597751617,7208.564959049225,291.9232795238495,0.1597843170166015,0.0 -8200,0.008841047,0.111841425,,,,,,,,,,, -8300,0.0068999263,0.13395442,,,,,,,,,,, -8400,0.0075857877,0.13019323,,,,,,,,,,, -8500,0.0069563505,0.117008954,,,,,,,,,,, -8600,0.0061066835,0.121428005,,,,,,,,,,, -8682,,,,,,,,7703.565836429596,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 2a8261a5d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.160903692245483,0.0,5.787775278091431,1,0,5.787775278091431,0.457260122944079,95000000,27.94872999191284,0.456549252166688,0.4571109073102294,83274637 -44.573280334472656,0.0164179801940917,1205.7778165340424,1393,0,1205.7778165340424,0.1283989195415296,95000000,1250.4170544147491,0.1248274825560221,0.1259217729594396,83274637 -66.96710991859436,0.0449867248535156,2406.32721567154,2747,0,2406.32721567154,0.1274786297286184,95000000,2473.437765598297,0.1246258874171934,0.125082685552565,83274637 -89.32167363166809,0.0709710121154785,3606.89742398262,4128,0,3606.89742398262,0.1268704804070723,95000000,3696.4365243911743,0.1226278332606801,0.124549532851386,83274637 -111.73241209983826,0.0971765518188476,4807.161467552185,5527,0,4807.161467552185,0.126873588671875,95000000,4919.186618089676,0.1210244252059444,0.1243756420265902,83274637 -134.3200616836548,0.1260366439819336,6007.878246545792,6912,0,6007.878246545792,0.1262866741981907,95000000,6142.56973862648,0.1247343713382505,0.1239951419253109,83274637 -156.87633275985718,0.14765334129333496,7208.5314157009125,8276,0,7208.5314157009125,0.12630719304070723,95000000,7365.848479747772,0.12290918686479893,0.12398521116489526,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/measurements.csv deleted file mode 100644 index b90434760..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/measurements.csv +++ /dev/null @@ -1,97 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.687231,0.457249,,,,,,,,,,, -1,,,0.456549252166688,0.4571109073102294,83274637.0,0.457260122944079,95000000.0,5.787775278091431,27.94872999191284,5.787775278091431,22.160903692245483,0.0,0.0 -100,0.049194332,0.13630238,,,,,,,,,,, -200,0.011109251,0.13088103,,,,,,,,,,, -300,0.011225884,0.12831932,,,,,,,,,,, -400,0.071036965,0.13105682,,,,,,,,,,, -500,0.038935665,0.12977932,,,,,,,,,,, -600,0.021225862,0.13282129,,,,,,,,,,, -700,0.03885013,0.12751718,,,,,,,,,,, -800,0.03192323,0.11805739,,,,,,,,,,, -900,0.010180866,0.11626887,,,,,,,,,,, -1000,0.0073976004,0.12743299,,,,,,,,,,, -1100,0.016663283,0.1301519,,,,,,,,,,, -1200,0.026177555,0.12371484,,,,,,,,,,, -1300,0.023547215,0.13517646,,,,,,,,,,, -1393,,,0.1248274825560221,0.1259217729594396,83274637.0,0.1283989195415296,95000000.0,1205.7778165340424,1250.4170544147491,1205.7778165340424,44.573280334472656,0.0164179801940917,0.0 -1400,0.010493084,0.12267895,,,,,,,,,,, -1500,0.017739981,0.12523377,,,,,,,,,,, -1600,0.0058661215,0.118518725,,,,,,,,,,, -1700,0.007959809,0.13688208,,,,,,,,,,, -1800,0.014911896,0.12036344,,,,,,,,,,, -1900,0.014060439,0.11627702,,,,,,,,,,, -2000,0.010255026,0.12652552,,,,,,,,,,, -2100,0.0062374636,0.124340564,,,,,,,,,,, -2200,0.007551402,0.1173951,,,,,,,,,,, -2300,0.008985987,0.13099723,,,,,,,,,,, -2400,0.0126199005,0.12539656,,,,,,,,,,, -2500,0.0068777744,0.12528047,,,,,,,,,,, -2600,0.0056410762,0.12496973,,,,,,,,,,, -2700,0.00619455,0.121495135,,,,,,,,,,, -2747,,,0.1246258874171934,0.125082685552565,83274637.0,0.1274786297286184,95000000.0,2406.32721567154,2473.437765598297,2406.32721567154,66.96710991859436,0.0449867248535156,0.0 -2800,0.015712002,0.11985657,,,,,,,,,,, -2900,0.015613381,0.120680206,,,,,,,,,,, -3000,0.005411323,0.12121217,,,,,,,,,,, -3100,0.008502491,0.12177272,,,,,,,,,,, -3200,0.017229704,0.11946741,,,,,,,,,,, -3300,0.011097713,0.12453542,,,,,,,,,,, -3400,0.006934851,0.11791854,,,,,,,,,,, -3500,0.01837584,0.120984875,,,,,,,,,,, -3600,0.008990065,0.11736627,,,,,,,,,,, -3700,0.009730105,0.1291349,,,,,,,,,,, -3800,0.011153789,0.134777,,,,,,,,,,, -3900,0.011907779,0.12029965,,,,,,,,,,, -4000,0.01573217,0.11838029,,,,,,,,,,, -4100,0.0073327953,0.12861279,,,,,,,,,,, -4128,,,0.1226278332606801,0.124549532851386,83274637.0,0.1268704804070723,95000000.0,3606.89742398262,3696.4365243911743,3606.89742398262,89.32167363166809,0.0709710121154785,0.0 -4200,0.0051200604,0.12062572,,,,,,,,,,, -4300,0.010489875,0.123833604,,,,,,,,,,, -4400,0.0074130907,0.12430734,,,,,,,,,,, -4500,0.008221291,0.12108123,,,,,,,,,,, -4600,0.012701847,0.119589895,,,,,,,,,,, -4700,0.010514027,0.11881789,,,,,,,,,,, -4800,0.011746646,0.12235604,,,,,,,,,,, -4900,0.0064696684,0.11949152,,,,,,,,,,, -5000,0.008648386,0.122534804,,,,,,,,,,, -5100,0.008401843,0.12088668,,,,,,,,,,, -5200,0.022047075,0.12308455,,,,,,,,,,, -5300,0.007871592,0.12474726,,,,,,,,,,, -5400,0.005162374,0.119656816,,,,,,,,,,, -5500,0.013352562,0.12549068,,,,,,,,,,, -5527,,,0.1210244252059444,0.1243756420265902,83274637.0,0.126873588671875,95000000.0,4807.161467552185,4919.186618089676,4807.161467552185,111.73241209983826,0.0971765518188476,0.0 -5600,0.024482181,0.12882411,,,,,,,,,,, -5700,0.008767436,0.11842483,,,,,,,,,,, -5800,0.012130027,0.12279017,,,,,,,,,,, -5900,0.011982712,0.1258151,,,,,,,,,,, -6000,0.012427459,0.13166359,,,,,,,,,,, -6100,0.011969856,0.12011531,,,,,,,,,,, -6200,0.012575729,0.124036774,,,,,,,,,,, -6300,0.009841854,0.118779674,,,,,,,,,,, -6400,0.0062593,0.13030538,,,,,,,,,,, -6500,0.0100255245,0.13055381,,,,,,,,,,, -6600,0.008639212,0.12686765,,,,,,,,,,, -6700,0.014000217,0.123771764,,,,,,,,,,, -6800,0.02804086,0.119958386,,,,,,,,,,, -6900,0.0062979404,0.122514434,,,,,,,,,,, -6912,,,0.1247343713382505,0.1239951419253109,83274637.0,0.1262866741981907,95000000.0,6007.878246545792,6142.56973862648,6007.878246545792,134.3200616836548,0.1260366439819336,0.0 -7000,0.0066435817,0.12448236,,,,,,,,,,, -7100,0.012332188,0.119412504,,,,,,,,,,, -7200,0.015570049,0.11958886,,,,,,,,,,, -7300,0.0075572105,0.1200089,,,,,,,,,,, -7400,0.008578128,0.124491304,,,,,,,,,,, -7500,0.010891134,0.12892178,,,,,,,,,,, -7600,0.009862055,0.1255548,,,,,,,,,,, -7700,0.0075048665,0.13132259,,,,,,,,,,, -7800,0.012398823,0.13298084,,,,,,,,,,, -7900,0.008378814,0.12689476,,,,,,,,,,, -8000,0.0076870546,0.13185808,,,,,,,,,,, -8100,0.0072816554,0.11613306,,,,,,,,,,, -8200,0.013485181,0.11774512,,,,,,,,,,, -8276,,,0.1229091868647989,0.1239852111648952,83274637.0,0.1263071930407072,95000000.0,7208.531415700912,7365.848479747772,7208.531415700912,156.87633275985718,0.1476533412933349,0.0 -8300,0.008302952,0.12185592,,,,,,,,,,, -8400,0.009927234,0.11974198,,,,,,,,,,, -8500,0.008286511,0.12271012,,,,,,,,,,, -8600,0.0102838725,0.12755881,,,,,,,,,,, -8700,0.008119846,0.122148424,,,,,,,,,,, -8768,,,,,,,,7703.806035518646,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/eval_measurements.csv deleted file mode 100644 index d46840457..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.35922241210937,0.0,8.699556589126587,1,0,8.699556589126587,0.457260122944079,95000000,31.05883479118347,0.4564077383692159,0.4571109073102294,83274637 -43.892783403396606,0.0172150135040283,1208.6605892181396,1393,0,1208.6605892181396,0.1294493597347862,95000000,1252.619071483612,0.1236191887713078,0.1269140963615068,83274637 -65.52731323242188,0.0454833507537841,2408.7519080638885,2789,0,2408.7519080638885,0.1287058117290296,95000000,2474.4223458766937,0.1246785134451944,0.1262659647367331,83274637 -87.28004765510559,0.0741846561431884,3609.058803796768,4163,0,3609.058803796768,0.128459453176398,95000000,3696.559238433838,0.123837466884709,0.1260347184195021,83274637 -109.08179664611816,0.0967528820037841,4809.377663612366,5548,0,4809.377663612366,0.128173310598273,95000000,4918.750566482544,0.1234656887234381,0.1258077316806556,83274637 -130.91276717185974,0.1190395355224609,6009.324901580811,6936,0,6009.324901580811,0.1277705654810855,95000000,6140.600414991379,0.1253732571718078,0.1254328360112665,83274637 -152.78353452682495,0.1488809585571289,7209.504168272018,8319,0,7209.504168272018,0.12781594426398027,95000000,7362.730304002762,0.12423908162229466,0.12550816220835043,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/measurements.csv deleted file mode 100644 index 800ba9862..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/measurements.csv +++ /dev/null @@ -1,97 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.7819524,0.45796323,,,,,,,,,,, -1,,,0.4564077383692159,0.4571109073102294,83274637.0,0.457260122944079,95000000.0,8.699556589126587,31.05883479118347,8.699556589126587,22.35922241210937,0.0,0.0 -100,0.06331506,0.12716652,,,,,,,,,,, -200,0.03560064,0.12326371,,,,,,,,,,, -300,0.0698723,0.13408089,,,,,,,,,,, -400,0.02998232,0.12932953,,,,,,,,,,, -500,0.05001278,0.12180904,,,,,,,,,,, -600,0.04046634,0.12962519,,,,,,,,,,, -700,0.058634173,0.13164943,,,,,,,,,,, -800,0.04343686,0.11900744,,,,,,,,,,, -900,0.028470393,0.12528557,,,,,,,,,,, -1000,0.007367834,0.12158807,,,,,,,,,,, -1100,0.05701954,0.120786525,,,,,,,,,,, -1200,0.02632552,0.12112973,,,,,,,,,,, -1300,0.052368835,0.12938438,,,,,,,,,,, -1393,,,0.1236191887713078,0.1269140963615068,83274637.0,0.1294493597347862,95000000.0,1208.6605892181396,1252.619071483612,1208.6605892181396,43.892783403396606,0.0172150135040283,0.0 -1400,0.05026665,0.12736718,,,,,,,,,,, -1500,0.02061756,0.1417707,,,,,,,,,,, -1600,0.046593484,0.12777026,,,,,,,,,,, -1700,0.030087015,0.12753886,,,,,,,,,,, -1800,0.032481834,0.118466415,,,,,,,,,,, -1900,0.020286223,0.12315376,,,,,,,,,,, -2000,0.042394057,0.1269775,,,,,,,,,,, -2100,0.052086614,0.12434018,,,,,,,,,,, -2200,0.02383262,0.12283477,,,,,,,,,,, -2300,0.016387109,0.1320942,,,,,,,,,,, -2400,0.032715857,0.13365752,,,,,,,,,,, -2500,0.025387274,0.12174909,,,,,,,,,,, -2600,0.005835963,0.12988254,,,,,,,,,,, -2700,0.032749772,0.13267839,,,,,,,,,,, -2789,,,0.1246785134451944,0.1262659647367331,83274637.0,0.1287058117290296,95000000.0,2408.7519080638885,2474.4223458766937,2408.7519080638885,65.52731323242188,0.0454833507537841,0.0 -2800,0.01310616,0.12589882,,,,,,,,,,, -2900,0.041326318,0.1261717,,,,,,,,,,, -3000,0.060045946,0.123586774,,,,,,,,,,, -3100,0.008646914,0.11926082,,,,,,,,,,, -3200,0.040187474,0.13348317,,,,,,,,,,, -3300,0.008639634,0.12660338,,,,,,,,,,, -3400,0.0064950986,0.121340916,,,,,,,,,,, -3500,0.039587494,0.12681636,,,,,,,,,,, -3600,0.0065355217,0.12144455,,,,,,,,,,, -3700,0.05452722,0.12541942,,,,,,,,,,, -3800,0.053787097,0.13360973,,,,,,,,,,, -3900,0.031417582,0.12509248,,,,,,,,,,, -4000,0.0120343575,0.1233576,,,,,,,,,,, -4100,0.03110819,0.123297,,,,,,,,,,, -4163,,,0.123837466884709,0.1260347184195021,83274637.0,0.128459453176398,95000000.0,3609.058803796768,3696.559238433838,3609.058803796768,87.28004765510559,0.0741846561431884,0.0 -4200,0.033037093,0.12316088,,,,,,,,,,, -4300,0.011160579,0.13783854,,,,,,,,,,, -4400,0.042292364,0.11910769,,,,,,,,,,, -4500,0.049988057,0.13770814,,,,,,,,,,, -4600,0.015362906,0.12259429,,,,,,,,,,, -4700,0.031086909,0.12591228,,,,,,,,,,, -4800,0.03123642,0.12013433,,,,,,,,,,, -4900,0.052385774,0.12442732,,,,,,,,,,, -5000,0.06859437,0.120943375,,,,,,,,,,, -5100,0.0061666616,0.13056411,,,,,,,,,,, -5200,0.005426396,0.12027189,,,,,,,,,,, -5300,0.021181542,0.114510305,,,,,,,,,,, -5400,0.0067052096,0.13104758,,,,,,,,,,, -5500,0.03900398,0.12268844,,,,,,,,,,, -5548,,,0.1234656887234381,0.1258077316806556,83274637.0,0.128173310598273,95000000.0,4809.377663612366,4918.750566482544,4809.377663612366,109.08179664611816,0.0967528820037841,0.0 -5600,0.02037416,0.119444944,,,,,,,,,,, -5700,0.008663078,0.12663469,,,,,,,,,,, -5800,0.031267837,0.12807344,,,,,,,,,,, -5900,0.05521535,0.122318126,,,,,,,,,,, -6000,0.053998016,0.12246877,,,,,,,,,,, -6100,0.015390038,0.12304944,,,,,,,,,,, -6200,0.008874157,0.12396873,,,,,,,,,,, -6300,0.0067269574,0.12610354,,,,,,,,,,, -6400,0.02422354,0.12008336,,,,,,,,,,, -6500,0.0063846903,0.11997382,,,,,,,,,,, -6600,0.02707475,0.12141006,,,,,,,,,,, -6700,0.0124049345,0.12309455,,,,,,,,,,, -6800,0.0062093446,0.11964636,,,,,,,,,,, -6900,0.015756875,0.12433689,,,,,,,,,,, -6936,,,0.1253732571718078,0.1254328360112665,83274637.0,0.1277705654810855,95000000.0,6009.324901580811,6140.600414991379,6009.324901580811,130.91276717185974,0.1190395355224609,0.0 -7000,0.016693667,0.12912478,,,,,,,,,,, -7100,0.018747717,0.13406575,,,,,,,,,,, -7200,0.024548732,0.12431413,,,,,,,,,,, -7300,0.006674419,0.12133458,,,,,,,,,,, -7400,0.027270459,0.12388923,,,,,,,,,,, -7500,0.008662971,0.12885009,,,,,,,,,,, -7600,0.014586609,0.121346585,,,,,,,,,,, -7700,0.0075929207,0.12155636,,,,,,,,,,, -7800,0.01269517,0.119886845,,,,,,,,,,, -7900,0.0069554164,0.12465231,,,,,,,,,,, -8000,0.010119042,0.12078824,,,,,,,,,,, -8100,0.014291343,0.12630221,,,,,,,,,,, -8200,0.015220969,0.13608687,,,,,,,,,,, -8300,0.0145765,0.11362372,,,,,,,,,,, -8319,,,0.1242390816222946,0.1255081622083504,83274637.0,0.1278159442639802,95000000.0,7209.504168272018,7362.730304002762,7209.504168272018,152.78353452682495,0.1488809585571289,0.0 -8400,0.01224702,0.11995322,,,,,,,,,,, -8500,0.014964441,0.12538046,,,,,,,,,,, -8600,0.009948392,0.12699479,,,,,,,,,,, -8700,0.018328443,0.12297777,,,,,,,,,,, -8800,,,,,,,,7703.05911898613,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 16817365b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -21.71107530593872,0.0,5.745176553726196,1,0,5.745176553726196,0.457260122944079,95000000,27.456290245056152,0.4569042233176201,0.4571109073102294,83274637 -43.60981583595276,0.0187180042266845,1206.3538706302645,1360,0,1206.3538706302645,0.1278922480674342,95000000,1250.0298562049866,0.1240272384207203,0.1257054161453444,83274637 -65.35720086097717,0.0409667491912841,2406.909574985504,2723,0,2406.909574985504,0.1277309679070723,95000000,2472.4036326408386,0.1244316359267294,0.125226907604983,83274637 -87.16394591331482,0.0645015239715576,3607.4408810138702,4077,0,3607.4408810138702,0.1268861670435855,95000000,3694.8120152950287,0.1217468752519889,0.1245135632034487,83274637 -109.07357454299928,0.0923397541046142,4807.448707580566,5444,0,4807.448707580566,0.1265907369449013,95000000,4916.806090593338,0.1211082811156908,0.1241941247988418,83274637 -130.9825460910797,0.1210646629333496,6008.207366943359,6783,0,6008.207366943359,0.1263803292557565,95000000,6139.55067896843,0.1223790561253169,0.1240410344531132,83274637 -152.8712134361267,0.15064787864685059,7208.455326795578,8125,0,7208.455326795578,0.12615202905016448,95000000,7361.763083457947,0.12305975194622136,0.12380716147327667,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/measurements.csv deleted file mode 100644 index 204b4b90b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/measurements.csv +++ /dev/null @@ -1,95 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.665436,0.4574473,,,,,,,,,,, -1,,,0.4569042233176201,0.4571109073102294,83274637.0,0.457260122944079,95000000.0,5.745176553726196,27.456290245056152,5.745176553726196,21.71107530593872,0.0,0.0 -100,0.037206307,0.1323281,,,,,,,,,,, -200,0.01659115,0.13611038,,,,,,,,,,, -300,0.031075584,0.12899712,,,,,,,,,,, -400,0.106370844,0.13358968,,,,,,,,,,, -500,0.0316573,0.123821855,,,,,,,,,,, -600,0.052489094,0.12099892,,,,,,,,,,, -700,0.0067842673,0.12458764,,,,,,,,,,, -800,0.024044238,0.11498248,,,,,,,,,,, -900,0.061151437,0.11730434,,,,,,,,,,, -1000,0.04845394,0.127595,,,,,,,,,,, -1100,0.03237439,0.124865144,,,,,,,,,,, -1200,0.038442966,0.12283677,,,,,,,,,,, -1300,0.020729208,0.12131666,,,,,,,,,,, -1360,,,0.1240272384207203,0.1257054161453444,83274637.0,0.1278922480674342,95000000.0,1206.3538706302645,1250.0298562049866,1206.3538706302645,43.60981583595276,0.0187180042266845,0.0 -1400,0.08739969,0.13164932,,,,,,,,,,, -1500,0.00813635,0.12382829,,,,,,,,,,, -1600,0.007152845,0.12904705,,,,,,,,,,, -1700,0.044401664,0.13460922,,,,,,,,,,, -1800,0.016796775,0.11955771,,,,,,,,,,, -1900,0.015296006,0.122493014,,,,,,,,,,, -2000,0.009469636,0.12503716,,,,,,,,,,, -2100,0.013297623,0.12062661,,,,,,,,,,, -2200,0.0074270414,0.12882796,,,,,,,,,,, -2300,0.014289669,0.12024935,,,,,,,,,,, -2400,0.01780693,0.13092907,,,,,,,,,,, -2500,0.0057266555,0.1230165,,,,,,,,,,, -2600,0.017523276,0.11979821,,,,,,,,,,, -2700,0.023615276,0.12231973,,,,,,,,,,, -2723,,,0.1244316359267294,0.125226907604983,83274637.0,0.1277309679070723,95000000.0,2406.909574985504,2472.4036326408386,2406.909574985504,65.35720086097717,0.0409667491912841,0.0 -2800,0.023355091,0.118883304,,,,,,,,,,, -2900,0.0066008363,0.12946582,,,,,,,,,,, -3000,0.019862106,0.12911841,,,,,,,,,,, -3100,0.004153492,0.12697709,,,,,,,,,,, -3200,0.006077517,0.1194224,,,,,,,,,,, -3300,0.01972615,0.12274652,,,,,,,,,,, -3400,0.00967501,0.12317923,,,,,,,,,,, -3500,0.011743161,0.12117226,,,,,,,,,,, -3600,0.01976444,0.11844093,,,,,,,,,,, -3700,0.016731234,0.121035725,,,,,,,,,,, -3800,0.005159445,0.13222486,,,,,,,,,,, -3900,0.007485274,0.1204434,,,,,,,,,,, -4000,0.004293275,0.12094881,,,,,,,,,,, -4077,,,0.1217468752519889,0.1245135632034487,83274637.0,0.1268861670435855,95000000.0,3607.4408810138702,3694.8120152950287,3607.4408810138702,87.16394591331482,0.0645015239715576,0.0 -4100,0.014663411,0.12457059,,,,,,,,,,, -4200,0.006887955,0.12597087,,,,,,,,,,, -4300,0.0070875357,0.12335865,,,,,,,,,,, -4400,0.010162645,0.12512705,,,,,,,,,,, -4500,0.0066605485,0.12860852,,,,,,,,,,, -4600,0.013053073,0.11462897,,,,,,,,,,, -4700,0.01334553,0.1250526,,,,,,,,,,, -4800,0.007974995,0.12879854,,,,,,,,,,, -4900,0.00830556,0.124304995,,,,,,,,,,, -5000,0.01406658,0.12025405,,,,,,,,,,, -5100,0.008377488,0.11460413,,,,,,,,,,, -5200,0.01458564,0.11756226,,,,,,,,,,, -5300,0.007542659,0.11939814,,,,,,,,,,, -5400,0.006377363,0.11517144,,,,,,,,,,, -5444,,,0.1211082811156908,0.1241941247988418,83274637.0,0.1265907369449013,95000000.0,4807.448707580566,4916.806090593338,4807.448707580566,109.07357454299928,0.0923397541046142,0.0 -5500,0.0056394343,0.118710645,,,,,,,,,,, -5600,0.0068958765,0.119180344,,,,,,,,,,, -5700,0.006765532,0.12185587,,,,,,,,,,, -5800,0.0064191064,0.13127127,,,,,,,,,,, -5900,0.012277657,0.12029279,,,,,,,,,,, -6000,0.004620037,0.12229174,,,,,,,,,,, -6100,0.008441295,0.12520221,,,,,,,,,,, -6200,0.013675727,0.12042624,,,,,,,,,,, -6300,0.017355114,0.1220164,,,,,,,,,,, -6400,0.0051176525,0.12290001,,,,,,,,,,, -6500,0.013266354,0.123328686,,,,,,,,,,, -6600,0.0067320853,0.13643107,,,,,,,,,,, -6700,0.007216318,0.1163428,,,,,,,,,,, -6783,,,0.1223790561253169,0.1240410344531132,83274637.0,0.1263803292557565,95000000.0,6008.207366943359,6139.55067896843,6008.207366943359,130.9825460910797,0.1210646629333496,0.0 -6800,0.018188873,0.11720708,,,,,,,,,,, -6900,0.008142134,0.1230987,,,,,,,,,,, -7000,0.008385962,0.12574618,,,,,,,,,,, -7100,0.0056729857,0.12814327,,,,,,,,,,, -7200,0.0061631273,0.1250006,,,,,,,,,,, -7300,0.016039243,0.11685966,,,,,,,,,,, -7400,0.0062819053,0.116325796,,,,,,,,,,, -7500,0.0069982377,0.12085643,,,,,,,,,,, -7600,0.011229546,0.1230119,,,,,,,,,,, -7700,0.00871541,0.12056435,,,,,,,,,,, -7800,0.00816701,0.12475424,,,,,,,,,,, -7900,0.0086165685,0.11782627,,,,,,,,,,, -8000,0.008198645,0.124468595,,,,,,,,,,, -8100,0.00731458,0.12171047,,,,,,,,,,, -8125,,,0.1230597519462213,0.1238071614732766,83274637.0,0.1261520290501644,95000000.0,7208.455326795578,7361.763083457947,7208.455326795578,152.8712134361267,0.1506478786468505,0.0 -8200,0.007871122,0.1167813,,,,,,,,,,, -8300,0.0072216257,0.13245751,,,,,,,,,,, -8400,0.00619979,0.12418568,,,,,,,,,,, -8500,0.0086008,0.12469392,,,,,,,,,,, -8596,,,,,,,,7703.303838253021,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 842c36bfc..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -199.7080636024475,0.0,55.57856774330139,1,0,55.57856774330139,0.912513662602974,3581,0.2829431162570249,255.28697657585144,0.9032635007585798,0.2670762368610927,0.915261701713738,3554,0.2587972028106535 -204.1460680961609,0.0274045467376709,135.91773223876953,340,0,135.91773223876953,0.3440740777606639,3581,0.6827848339107442,340.1038362979889,0.3144971302577427,0.6899064608982631,0.3430679312702236,3554,0.663529567256964 -208.18507623672485,0.0634591579437255,215.9855597019196,584,0,215.9855597019196,0.3218626327535954,3581,0.7052362266781276,424.25503969192505,0.2929373128073556,0.712254456111363,0.3199036984076041,3554,0.6871440245585959 -212.22860860824585,0.0978899002075195,296.20644450187683,831,0,296.20644450187683,0.3117575561557525,3581,0.7169947918484711,508.5625383853912,0.2826029913766043,0.7251592363630023,0.3094001559642304,3554,0.699622742970069 -216.2720394134521,0.1372833251953125,376.2559127807617,1123,0,376.2559127807617,0.3051949388918947,3581,0.7230841949045309,592.7049720287323,0.2763338940484183,0.7318246705191476,0.3028629369526414,3554,0.7060530388294879 -220.3106219768524,0.1638650894165039,456.3312237262726,1468,0,456.3312237262726,0.300658122949246,3581,0.7285530539871893,676.8577456474304,0.2721036502293178,0.7371740341186523,0.298528376092519,3554,0.7114779891495498 -224.34614849090576,0.1881933212280273,536.3904583454132,1814,0,536.3904583454132,0.2976562704529985,3581,0.7304791128525552,760.9892885684967,0.2694324425288609,0.739152022770473,0.2957502288904227,3554,0.7133444902530599 -228.38849234581,0.2128489017486572,616.4971573352814,2158,0,616.4971573352814,0.2959929985295658,3581,0.7317220415692195,845.1751751899719,0.2681713615145002,0.7400693893432617,0.2941182168199563,3554,0.7146711891530669 -232.4296793937683,0.2388749122619629,696.509119272232,2503,0,696.509119272232,0.294625647405578,3581,0.7332659702466141,929.2663757801056,0.2666256427764892,0.7419065747942243,0.2928775579122995,3554,0.7160875345121693 -236.4725306034088,0.2647807598114013,776.5994248390198,2848,0,776.5994248390198,0.294470920472197,3581,0.7336256703129364,1013.4373586177826,0.2665649993079049,0.7424797330583844,0.2928403254365679,3554,0.7164755903339547 -240.51452016830444,0.2896666526794433,856.7893908023834,3190,0,856.7893908023834,0.2926284121055571,3581,0.7353150879860724,1097.7059400081637,0.2646067312785557,0.744572503226144,0.2909685691958005,3554,0.7181011794588844 -244.56147742271423,0.3177082538604736,936.8239948749542,3538,0,936.8239948749542,0.2918307451654566,3581,0.7366975061522619,1181.8276615142822,0.2635725566319057,0.7460277421133858,0.2902894199867227,3554,0.7194483493510833 -248.6039776802063,0.3449945449829101,1016.8996806144714,3884,0,1016.8996806144714,0.2911783626911128,3581,0.7373969305230732,1265.9849543571472,0.2633559022630964,0.7463890484401158,0.2895800451515897,3554,0.7203027728439786 -252.64858031272888,0.3698835372924804,1096.9104199409485,4230,0,1096.9104199409485,0.2914835555165282,3581,0.7372028315676487,1350.0770933628082,0.2637581825256347,0.745875494820731,0.2899739743620744,3554,0.7200002418050084 -256.6939928531647,0.3946630954742431,1176.9021356105804,4576,0,1176.9021356105804,0.2909181323739877,3581,0.7384671677560388,1434.1509294509888,0.2624566214425223,0.7481837953839984,0.2893379997120322,3554,0.7213690504976786 -260.7339150905609,0.4196033477783203,1256.8783206939695,4922,0,1256.8783206939695,0.2905407063756632,3581,0.7384208758028483,1518.2038979530334,0.2626213005610874,0.7474816186087472,0.2890369799543648,3554,0.7212315925937324 -264.77026534080505,0.4451534748077392,1336.9380061626434,5268,0,1336.9380061626434,0.2901217948787873,3581,0.7387660542402611,1602.3376910686493,0.262041585786002,0.7481545039585659,0.2886578200838843,3554,0.7215776073174592 -268.809974193573,0.4723153114318847,1416.9430875778198,5615,0,1416.9430875778198,0.2896329341271642,3581,0.7386184517680118,1686.421659231186,0.2613209826605661,0.7481823648725238,0.2880889772362567,3554,0.7214367146832794 -272.85027623176575,0.4980456829071045,1496.9353461265564,5961,0,1496.9353461265564,0.2895029894102555,3581,0.7390884616727171,1770.4920568466189,0.2613769769668579,0.7485959189278739,0.2880452015994302,3554,0.7219383226865855 -276.8902978897095,0.5235202312469482,1576.922093629837,6303,0,1576.922093629837,0.2894305517073617,3581,0.7398391548930118,1854.556218624115,0.2616139990942819,0.7485676492963519,0.288008621722443,3554,0.7227368287712789 -280.935923576355,0.5493738651275635,1656.9838755130768,6646,0,1656.9838755130768,0.2890858164182491,3581,0.7399165354038676,1938.7012684345243,0.260679977280753,0.7496201651436942,0.2876082523927705,3554,0.7227002145469893 -284.97637605667114,0.5766260623931885,1737.0378777980804,6991,0,1737.0378777980804,0.288878116218846,3581,0.7394819091865051,2022.8351304531093,0.2608584335872105,0.7485895838056292,0.2874398990903735,3554,0.7222806279016601 -289.0156946182251,0.6023619174957275,1817.0663664340973,7338,0,1817.0663664340973,0.2887308887182351,3581,0.7403037106647934,2106.940655231476,0.2608873333249773,0.7492905344281878,0.2873187905023389,3554,0.7231078483311058 -293.05683064460754,0.6286115646362305,1897.189120054245,7683,0,1897.189120054245,0.2893348316745497,3581,0.7404244515323932,2191.1428146362305,0.2609071390969412,0.7503745215279716,0.2878370054263242,3554,0.7233539810996412 -297.0999720096588,0.6573379039764404,1977.1989560127256,8027,0,1977.1989560127256,0.2891022129053511,3581,0.740137905023911,2275.2364633083344,0.2611512967518398,0.7490619250706264,0.2876697168902997,3554,0.7230151106148002 -301.1387703418732,0.68442702293396,2057.3637301921844,8373,0,2057.3637301921844,0.2883704387086359,3581,0.7401876739868403,2359.4793276786804,0.2605456454413278,0.7492820194789341,0.2869772581015669,3554,0.722934188370498 -305.17865443229675,0.7108016014099121,2137.3297839164734,8715,0,2137.3297839164734,0.2884313886440589,3581,0.7403200048869031,2443.523314476013,0.2602596623556955,0.7499177796500069,0.2870410238683525,3554,0.723055022180114 -309.2239181995392,0.7377052307128906,2217.491429805756,9061,0,2217.491429805756,0.2883768132264905,3581,0.7413580627356186,2527.769276380539,0.2600905724934169,0.750847407749721,0.2869474618167909,3554,0.7241669817459201 -313.2646355628968,0.7639820575714111,2297.5060591697693,9408,0,2297.5060591697693,0.2888754232407148,3581,0.7408017411773946,2611.8630299568176,0.2610064574650356,0.7499944823128837,0.2875855488259268,3554,0.7235614388057471 -317.30930352211,0.7913625240325928,2377.5569076538086,9751,0,2377.5569076538086,0.2882359943320651,3581,0.7413401322736317,2695.998062610626,0.2601182460784912,0.7508278574262347,0.2867819078195343,3554,0.7241959708690912 -321.34729051589966,0.818218469619751,2457.756463766098,10099,0,2457.756463766098,0.2881764761065344,3581,0.7411026729614633,2780.274495124817,0.2599021536963327,0.7506995882306781,0.2867163731666784,3554,0.723881967831141 -325.3895990848541,0.8493859767913818,2537.771735906601,10445,0,2537.771735906601,0.2879779797564402,3581,0.7414004686191008,2864.37513422966,0.25971748147692,0.7509869847978864,0.2866183803131155,3554,0.7240853038609665 -329.43454337120056,0.8776884078979492,2617.872266769409,10790,0,2617.872266769409,0.28772633969867,3581,0.741495575061959,2948.560939311981,0.2595784323556082,0.7511575562613351,0.286358319713483,3554,0.7242304555606359 -333.47989439964294,0.9055733680725098,2697.8649265766144,11134,0,2697.8649265766144,0.2877778812547996,3581,0.7416386778745462,3032.638692140579,0.2592880725860595,0.7513892310006278,0.2863875320941193,3554,0.7243891400974254 -337.51908659935,0.9325578212738036,2777.8663704395294,11480,0,2777.8663704395294,0.2875935997386379,3581,0.74135724461568,3116.7184529304504,0.2594648599624634,0.7506914819989886,0.2862425177836593,3554,0.7240741066404052 -341.55452609062195,0.9591963291168212,2857.9913415908813,11824,0,2857.9913415908813,0.288328714591769,3581,0.7408781672149888,3200.917214870453,0.2600880861282348,0.7505995886666434,0.2868656121953257,3554,0.723615913627251 -345.5973918437958,0.986703395843506,2938.1020641326904,12169,0,2938.1020641326904,0.2876703666595399,3581,0.7413234289915527,3285.1103279590607,0.2590014253343854,0.7514924321855817,0.2862641565841305,3554,0.7240672371799382 -349.6387412548065,1.0136513710021973,3018.1377742290497,12514,0,3018.1377742290497,0.2878473532729161,3581,0.7406686603340548,3369.2264947891235,0.2597200189317976,0.7499089922223773,0.286492978312289,3554,0.7233776807382527 -353.6803197860718,1.04435133934021,3098.17995595932,12859,0,3098.17995595932,0.2875071517317963,3581,0.7423111724640463,3453.3531222343445,0.2592545918055943,0.7517422948564801,0.2861169783936234,3554,0.7250731322761326 -357.6665780544281,1.0757217407226562,3178.1596059799194,13201,0,3178.1596059799194,0.2878732944926347,3581,0.7413691073547891,3537.362501144409,0.2592260496956961,0.751594066619873,0.2864757016192142,3554,0.7241303675216305 -361.7101609706879,1.1067495346069336,3258.367030143738,13548,0,3258.367030143738,0.2880651095298974,3581,0.7420486923170903,3621.656540393829,0.259649498122079,0.7518584387642997,0.2866530882621254,3554,0.7249348500369303 -365.7559454441071,1.1348917484283447,3338.5711925029755,13893,0,3338.5711925029755,0.287589679580599,3581,0.7411973703443522,3705.94663143158,0.2593644346509661,0.7509762900216239,0.2861734797059651,3554,0.7240757553109173 -369.80122470855713,1.1632394790649414,3418.6912846565247,14237,0,3418.6912846565247,0.2876861836450188,3581,0.7414800989597877,3790.152283668518,0.2589819261005947,0.7518370492117745,0.2863299316681028,3554,0.7242094350116066 -373.8409984111786,1.191213607788086,3498.888443470001,14583,0,3498.888443470001,0.2875336383648073,3581,0.7419205201933817,3874.429375886917,0.2590833391462053,0.7517756053379604,0.2862203122526994,3554,0.7246804739158342 -377.8848206996918,1.2186510562896729,3578.980725288391,14930,0,3578.980725288391,0.287559579584526,3581,0.7419288377460905,3958.60556936264,0.2592453275408063,0.7516927719116211,0.2861635533355902,3554,0.7247757533325127 -381.9346296787262,1.2474379539489746,3658.9713571071625,15274,0,3658.9713571071625,0.2883185562691985,3581,0.7423833715486247,4042.6870880126953,0.2594645874840872,0.7525625228881836,0.2869511369781408,3554,0.7252395106086452 -385.9861137866974,1.276132583618164,3739.1749410629272,15618,0,3739.1749410629272,0.287680593158772,3581,0.741825345573862,4126.983177185059,0.2591737338474819,0.7516856874738421,0.2863493894148758,3554,0.7246266173457724 -390.0353882312775,1.3046724796295166,3819.206693887711,15962,0,3819.206693887711,0.2871475539304663,3581,0.7421772735007679,4211.104909896851,0.2588274819510324,0.7520236968994141,0.2858968808802581,3554,0.7249162337990644 -394.0746328830719,1.333157300949097,3899.3228511810303,16305,0,3899.3228511810303,0.2877406908859257,3581,0.7417084907759705,4295.300581455231,0.2592099905014038,0.7516825539725167,0.2863791170050471,3554,0.72449877668648 -398.1184675693512,1.3620598316192627,3979.487622022629,16653,0,3979.487622022629,0.2875004022423031,3581,0.7420754175684167,4379.550000429153,0.2589167526790074,0.7520562580653599,0.2861561686655881,3554,0.7248071467668472 -402.1616246700287,1.389885425567627,4059.464731216431,16997,0,4059.464731216431,0.287021767989912,3581,0.7424610929427883,4463.610082149506,0.2585340908595493,0.7524524416242327,0.2857100659028559,3554,0.7252197952571047 -406.2029480934143,1.4190490245819092,4139.484055280685,17340,0,4139.484055280685,0.2885778662013927,3581,0.7420885756640953,4547.71174955368,0.2602828230176653,0.7515924998692104,0.2872200420081246,3554,0.7249902178882949 -410.2422773838043,1.4473683834075928,4219.586560964584,17687,0,4219.586560964584,0.287437952420326,3581,0.7425334965573513,4631.894050359726,0.2587361506053379,0.7524777821132115,0.2860411052027645,3554,0.725357871412493 -414.2865955829621,1.4766151905059814,4299.56721663475,18031,0,4299.56721663475,0.2872522391942718,3581,0.743022118690659,4715.960064649582,0.2585129737854004,0.7531301634652274,0.2859081983163776,3554,0.7258307650710467 -418.3355646133423,1.5043630599975586,4379.592621326447,18376,0,4379.592621326447,0.2869516823818416,3581,0.7419246789697361,4800.07418346405,0.2584585973194667,0.7519820758274623,0.2856576690931433,3554,0.7246246938968416 -422.3765048980713,1.5328974723815918,4459.631384849548,18722,0,4459.631384849548,0.2871849488293249,3581,0.7424981128700083,4884.19443488121,0.2582253387996128,0.7529837744576591,0.2858288045270294,3554,0.7253613061427265 -426.4215528964996,1.5613839626312256,4539.849860191345,19067,0,4539.849860191345,0.2868095340425161,3581,0.7424585704063111,4968.498574972153,0.2581585986273629,0.7525200843811035,0.2855097696092871,3554,0.7251618857053672 -430.4662718772888,1.592024326324463,4619.828453540802,19410,0,4619.828453540802,0.2868313164858803,3581,0.7419808565388509,5052.564338922501,0.2583410739898681,0.7519347327096122,0.2855432067081105,3554,0.7246815043349043 -434.50729513168335,1.6206250190734863,4700.001065015793,19757,0,4700.001065015793,0.2868872895250104,3581,0.7435532148841106,5136.818740844727,0.2578264985765729,0.7540554319109235,0.2855517076654386,3554,0.7263243358056064 -438.54583048820496,1.6494874954223633,4780.03605055809,20104,0,4780.03605055809,0.286899697677412,3581,0.7426360342563181,5220.933315753937,0.2582331725529262,0.752790996006557,0.2855855912791925,3554,0.7253982638400394 -442.5905215740204,1.6790635585784912,4860.2350742816925,20447,0,4860.2350742816925,0.2867745594151424,3581,0.7431528815275062,5305.218400716782,0.2581703151975359,0.7531872476850238,0.2854585062605514,3554,0.7259412259953574 -446.64170718193054,1.7088754177093506,4940.427654981613,20791,0,4940.427654981613,0.2869013339172891,3581,0.7424073697334892,5389.503933191299,0.2577710662569318,0.753039973122733,0.2855243843864308,3554,0.7251662821600662 -450.68442273139954,1.743255853652954,5020.432681083679,21138,0,5020.432681083679,0.2867768433333042,3581,0.7429243533580006,5473.598413228989,0.2579892703465053,0.7532409940447126,0.285431011245032,3554,0.7257597348498172 -454.7285809516907,1.772273302078247,5100.391660451889,21483,0,5100.391660451889,0.286762423969387,3581,0.742838178057805,5557.642760276794,0.2580551760537283,0.7530279840741839,0.2854780842228827,3554,0.7256322376635481 -458.7747828960418,1.8035430908203125,5180.50639796257,21825,0,5180.50639796257,0.2871336799798415,3581,0.7424775916948827,5641.847291946411,0.257967335837228,0.753004619053432,0.2857030934004818,3554,0.7253357517497889 -462.8213136196137,1.834075689315796,5260.673074483872,22173,0,5260.673074483872,0.286666090346621,3581,0.7429429655866029,5726.103197097778,0.2577966281345912,0.7532924243382045,0.2853456753723797,3554,0.7257460646234877 -466.86464047431946,1.8650314807891848,5340.706670284271,22519,0,5340.706670284271,0.2868627118385227,3581,0.743415566204447,5810.223093986511,0.2580252374921526,0.7536420822143555,0.2855393769839002,3554,0.7262866224676421 -470.9098572731018,1.90094256401062,5420.829624176025,22860,0,5420.829624176025,0.2871538602716594,3581,0.7426878485190939,5894.439101219177,0.2581378732408796,0.7531393596104213,0.2857794474535734,3554,0.7255267914453785 -474.96408867836,1.930015563964844,5500.831784963608,23207,0,5500.831784963608,0.2866646586367286,3581,0.7432383750610863,5978.536752939224,0.2577291897365025,0.7534748486110142,0.2853734451663179,3554,0.7260389784178038 -479.0078208446503,1.9629197120666504,5581.046790122986,23552,0,5581.046790122986,0.2868668706148771,3581,0.7430817050928512,6062.840421676636,0.2580478872571672,0.7532529830932617,0.2855493205279262,3554,0.7259530414673607 -483.0488519668579,2.0170490741729736,5661.185736894608,23896,0,5661.185736894608,0.2872241163213836,3581,0.7421021428197431,6147.086599349976,0.2584163461412702,0.7521711758204869,0.285952265905274,3554,0.7247517102208779 -487.0972065925598,2.053084135055542,5741.163243055344,24243,0,5741.163243055344,0.2865261918371265,3581,0.7430396400926766,6231.160847902298,0.2573205062321254,0.7536928313119071,0.2851954059246623,3554,0.7258101566896454 -491.1406552791596,2.083099126815796,5821.219519615173,24588,0,5821.219519615173,0.2865782788065484,3581,0.7431695848095853,6315.30278635025,0.2576057570321219,0.7535805021013532,0.2853072407410664,3554,0.7259259070985158 -495.1886050701141,2.1133792400360107,5901.379008293152,24932,0,5901.379008293152,0.2873815021445651,3581,0.742053396506737,6399.552320480347,0.2583734137671334,0.7523881367274693,0.2859432840857133,3554,0.7248127110298256 -499.2294337749481,2.143554449081421,5981.531459093094,25277,0,5981.531459093094,0.2865642685026005,3581,0.743282826244415,6483.78791642189,0.2572070360183716,0.7541452135358538,0.2851946331103598,3554,0.7261007348674029 -503.27667450904846,2.174089431762696,6061.55241727829,25624,0,6061.55241727829,0.2865010687373464,3581,0.7433787508072117,6567.898736476898,0.2573329380580357,0.7540187154497419,0.2851841400094963,3554,0.7261812449440771 -507.32264375686646,2.204224109649658,6141.601917743683,25968,0,6141.601917743683,0.286626206999616,3581,0.743403226228707,6652.036284208298,0.2575019257409232,0.7539004598345075,0.2853177166682787,3554,0.7261755432918894 -511.3682265281677,2.2362966537475586,6221.579905986786,26314,0,6221.579905986786,0.2864493908278937,3581,0.7430313907166294,6736.104068756104,0.256960460117885,0.7539973258972168,0.2851275356552476,3554,0.7257981351338281 -515.4121625423431,2.269528865814209,6301.547380447388,26656,0,6301.547380447388,0.286397712918441,3581,0.7435331709456158,6820.160589933395,0.2571321044649396,0.7542306355067662,0.2851001436816351,3554,0.7263142376987197 -519.3987793922424,2.301710605621338,6381.7876925468445,27001,0,6381.7876925468445,0.2864608104187028,3581,0.7436299136283511,6904.431702852249,0.2572677816663469,0.7542601994105748,0.2851712769447717,3554,0.726429163772334 -523.4427843093872,2.331604719161988,6461.864283800125,27345,0,6461.864283800125,0.2863423875575956,3581,0.7432553510498116,6988.59413766861,0.2567661660058157,0.7543474606105259,0.285035107064663,3554,0.7260194691500774 -527.4829468727112,2.3632917404174805,6542.003179073334,27691,0,6542.003179073334,0.2863778394216001,3581,0.743281462711184,7072.816970825195,0.257008637700762,0.7541309084211077,0.2850820254796532,3554,0.7260204995691475 -531.5244166851044,2.395188331604004,6621.994435787201,28038,0,6621.994435787201,0.2865027390655543,3581,0.7429226489414619,7156.893805742264,0.257212621825082,0.7536019597734723,0.2851892577575443,3554,0.7257375464925084 -535.5664856433868,2.4301555156707764,6702.170172929764,28384,0,6702.170172929764,0.2864661281983035,3581,0.7429875531232547,7241.158483028412,0.256799612726484,0.7541613578796387,0.2850896334071205,3554,0.7257597348498172 -539.6133246421814,2.461496353149414,6782.152277231216,28730,0,6782.152277231216,0.2863912020472633,3581,0.7433747965608419,7325.230720281601,0.2568827697208949,0.7544137409755162,0.2850468538420617,3554,0.7262156609410172 -543.6638793945312,2.492823839187622,6862.30401468277,29077,0,6862.30401468277,0.286466060021642,3581,0.7436050291468863,7409.476336956024,0.2570457458496094,0.7544715063912528,0.2851410513187166,3554,0.7264549242490855 -547.7039232254028,2.525808572769165,6942.3070504665375,29420,0,6942.3070504665375,0.2863644767959369,3581,0.7431189977267174,7493.5642166137695,0.2565846443176269,0.7544336318969727,0.2850145330305642,3554,0.7259691846994584 -551.7444829940796,2.556824445724488,7022.33305644989,29767,0,7022.33305644989,0.2863429670592188,3581,0.7435238307429838,7577.673763036728,0.2567110913140433,0.7546213694981166,0.2850071827078644,3554,0.7263272209790025 -555.7933006286621,2.587928295135498,7102.540855884552,30113,0,7102.540855884552,0.2864089279792656,3581,0.7436164828260262,7661.973260879517,0.2568695545196533,0.7546024322509766,0.2850905779579347,3554,0.7264716857326252 -559.842346906662,2.6249148845672607,7182.582946300507,30456,0,7182.582946300507,0.2863609656978672,3581,0.7431831519652332,7746.113355398178,0.2566305569240025,0.7543889454432896,0.2850126267552845,3554,0.7259818245067178 -563.8314032554626,2.66121768951416,7262.79195356369,30801,0,7262.79195356369,0.2863411944660186,3581,0.7434223156939402,7830.360119819641,0.256553258214678,0.7546736172267369,0.2849770601237162,3554,0.7262378492983258 -567.8771078586578,2.698338508605957,7342.801480770111,31148,0,7342.801480770111,0.2863343086232023,3581,0.743772607380969,7914.464708566666,0.2566737788064139,0.7548414639064244,0.2850077322647017,3554,0.7265923134584271 -571.9233648777008,2.7299087047576904,7422.888996124268,31490,0,7422.888996124268,0.286525510070511,3581,0.7431318149390882,7998.641962766647,0.256832412311009,0.7542413302830288,0.2851721356273301,3554,0.7259741307109947 -575.9671130180359,2.762040376663208,7502.898470878601,31837,0,7502.898470878601,0.286277449287472,3581,0.7436150511161338,8082.73925280571,0.2564304385866437,0.7548629896981376,0.2849297982357027,3554,0.7264253168744724 -580.0088977813721,2.793639659881592,7582.90892624855,32181,0,7582.90892624855,0.2863116057949071,3581,0.7439609794968235,8166.83486032486,0.2565772703715733,0.7551521573747907,0.2849779703272281,3554,0.7268253255574705 -584.0519845485687,2.826245784759521,7663.136118888855,32523,0,7663.136118888855,0.2863976106534487,3581,0.7437946284426487,8251.14967083931,0.2566173417227609,0.754964964730399,0.2850120943720983,3554,0.7266737852595667 -588.0979740619659,2.8580024242401123,7743.118954181671,32867,0,7743.118954181671,0.2862747563093409,3581,0.7437910150795867,8335.222202539444,0.2563023567199707,0.7552356038774762,0.2849172786440014,3554,0.726623226030529 -592.1455419063568,2.890659809112549,7823.141031265259,33211,0,7823.141031265259,0.2862539624275691,3581,0.7439368449586359,8419.336717128754,0.2563957146235874,0.7552266120910645,0.2849077300939522,3554,0.726784520962296 -596.1902956962585,3.1312170028686523,7902.981092453003,33554,0,7902.981092453003,0.2862377022937901,3581,0.743793673969387,8503.473998785019,0.2564103603363037,0.7550900323050362,0.2848932183587155,3554,0.7266263859823439 -600.2369115352631,3.163629531860352,7983.175691604614,33899,0,7983.175691604614,0.2861828541695755,3581,0.7439810916119799,8587.75936126709,0.2561522551945278,0.7554918016706195,0.2848349138130012,3554,0.7268479947770118 -604.2770035266876,3.2029614448547363,8063.2761261463165,34245,0,8063.2761261463165,0.2861913080756074,3581,0.7437081122591455,8671.951352119446,0.256242104939052,0.7550997052873883,0.2848489103387028,3554,0.7265252675242684 -608.3235597610474,3.236752986907959,8143.4895124435425,34591,0,8143.4895124435425,0.2861426640275935,3581,0.744058131239528,8756.257097244263,0.2562501089913504,0.7553950037275042,0.2848210890238112,3554,0.7269252075126618 -612.3583030700684,3.270219087600708,8223.44726896286,34936,0,8223.44726896286,0.2860856342502094,3581,0.7441798265803895,8840.294992685318,0.2561021532331194,0.7556241580418178,0.2847381918096247,3554,0.7270647262547482 -616.4016833305359,3.3031046390533447,8303.64076423645,35281,0,8303.64076423645,0.2860812709438704,3581,0.7438649185807037,8924.576595306396,0.2561110258102417,0.755319322858538,0.2847387585401132,3554,0.7267110177352982 -620.4445323944092,3.3424787521362305,8383.82935500145,35625,0,8383.82935500145,0.2861023716206192,3581,0.7439365722519896,9008.859397649763,0.2561379841395786,0.755361693246024,0.2847745484291467,3554,0.7268043050084412 -624.4805216789246,3.3811798095703125,8463.80910038948,35966,0,8463.80910038948,0.2860723738895385,3581,0.743905824577632,9092.92571401596,0.2561021191733224,0.7553381238664899,0.2847349975105075,3554,0.7267506545221933 -628.5201015472412,3.4160211086273193,8514.113553285599,36189,0,8514.113553285599,0.2860763281359083,3581,0.744003658086952,9147.31237244606,0.25610746656145367,0.755429880959647,0.28473779681564787,3554,0.7268578181054798 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 64173d399..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.3851867,0.90665424,,,,,,,,,,,,,, -1,,,0.2670762368610927,0.9032635007585798,0.2587972028106535,0.915261701713738,3554.0,0.2829431162570249,0.912513662602974,3581.0,55.57856774330139,255.28697657585144,55.57856774330139,199.7080636024475,0.0,0.0 -100,0.73468214,0.29955631,,,,,,,,,,,,,, -200,0.17182991,0.34124586,,,,,,,,,,,,,, -300,0.12168355,0.3615769,,,,,,,,,,,,,, -340,,,0.6899064608982631,0.3144971302577427,0.663529567256964,0.3430679312702236,3554.0,0.6827848339107442,0.3440740777606639,3581.0,135.91773223876953,340.1038362979889,135.91773223876953,204.1460680961609,0.0274045467376709,0.0 -400,0.11726605,0.3498485,,,,,,,,,,,,,, -500,0.1801889,0.28639138,,,,,,,,,,,,,, -584,,,0.712254456111363,0.2929373128073556,0.6871440245585959,0.3199036984076041,3554.0,0.7052362266781276,0.3218626327535954,3581.0,215.9855597019196,424.25503969192505,215.9855597019196,208.18507623672485,0.0634591579437255,0.0 -600,0.14617684,0.27956447,,,,,,,,,,,,,, -700,0.10495526,0.27740705,,,,,,,,,,,,,, -800,0.06420345,0.2311353,,,,,,,,,,,,,, -831,,,0.7251592363630023,0.2826029913766043,0.699622742970069,0.3094001559642304,3554.0,0.7169947918484711,0.3117575561557525,3581.0,296.20644450187683,508.5625383853912,296.20644450187683,212.22860860824585,0.0978899002075195,0.0 -900,0.20888284,0.33970854,,,,,,,,,,,,,, -1000,0.10314046,0.23391956,,,,,,,,,,,,,, -1100,0.16491258,0.35843664,,,,,,,,,,,,,, -1123,,,0.7318246705191476,0.2763338940484183,0.7060530388294879,0.3028629369526414,3554.0,0.7230841949045309,0.3051949388918947,3581.0,376.2559127807617,592.7049720287323,376.2559127807617,216.2720394134521,0.1372833251953125,0.0 -1200,0.11958895,0.22728217,,,,,,,,,,,,,, -1300,0.11732015,0.3152772,,,,,,,,,,,,,, -1400,0.29979905,0.29784954,,,,,,,,,,,,,, -1468,,,0.7371740341186523,0.2721036502293178,0.7114779891495498,0.298528376092519,3554.0,0.7285530539871893,0.300658122949246,3581.0,456.3312237262726,676.8577456474304,456.3312237262726,220.3106219768524,0.1638650894165039,0.0 -1500,0.18274598,0.4024582,,,,,,,,,,,,,, -1600,0.110896334,0.2294926,,,,,,,,,,,,,, -1700,0.05612098,0.33073884,,,,,,,,,,,,,, -1800,0.34747285,0.26443142,,,,,,,,,,,,,, -1814,,,0.739152022770473,0.2694324425288609,0.7133444902530599,0.2957502288904227,3554.0,0.7304791128525552,0.2976562704529985,3581.0,536.3904583454132,760.9892885684967,536.3904583454132,224.34614849090576,0.1881933212280273,0.0 -1900,0.07635919,0.26046976,,,,,,,,,,,,,, -2000,0.18614803,0.24523604,,,,,,,,,,,,,, -2100,0.054607306,0.3311863,,,,,,,,,,,,,, -2158,,,0.7400693893432617,0.2681713615145002,0.7146711891530669,0.2941182168199563,3554.0,0.7317220415692195,0.2959929985295658,3581.0,616.4971573352814,845.1751751899719,616.4971573352814,228.38849234581,0.2128489017486572,0.0 -2200,0.24397665,0.2546046,,,,,,,,,,,,,, -2300,0.07635026,0.24986866,,,,,,,,,,,,,, -2400,0.08576569,0.35390455,,,,,,,,,,,,,, -2500,0.1886126,0.2722411,,,,,,,,,,,,,, -2503,,,0.7419065747942243,0.2666256427764892,0.7160875345121693,0.2928775579122995,3554.0,0.7332659702466141,0.294625647405578,3581.0,696.509119272232,929.2663757801056,696.509119272232,232.4296793937683,0.2388749122619629,0.0 -2600,0.13163446,0.27626276,,,,,,,,,,,,,, -2700,0.07579369,0.26048797,,,,,,,,,,,,,, -2800,0.055228174,0.31255051,,,,,,,,,,,,,, -2848,,,0.7424797330583844,0.2665649993079049,0.7164755903339547,0.2928403254365679,3554.0,0.7336256703129364,0.294470920472197,3581.0,776.5994248390198,1013.4373586177826,776.5994248390198,236.4725306034088,0.2647807598114013,0.0 -2900,0.09364072,0.23615637,,,,,,,,,,,,,, -3000,0.07941757,0.26710132,,,,,,,,,,,,,, -3100,0.29097772,0.2578977,,,,,,,,,,,,,, -3190,,,0.744572503226144,0.2646067312785557,0.7181011794588844,0.2909685691958005,3554.0,0.7353150879860724,0.2926284121055571,3581.0,856.7893908023834,1097.7059400081637,856.7893908023834,240.51452016830444,0.2896666526794433,0.0 -3200,0.22939149,0.3695374,,,,,,,,,,,,,, -3300,0.07078976,0.30555162,,,,,,,,,,,,,, -3400,0.12837376,0.28601834,,,,,,,,,,,,,, -3500,0.12353232,0.2607426,,,,,,,,,,,,,, -3538,,,0.7460277421133858,0.2635725566319057,0.7194483493510833,0.2902894199867227,3554.0,0.7366975061522619,0.2918307451654566,3581.0,936.8239948749542,1181.8276615142822,936.8239948749542,244.56147742271423,0.3177082538604736,0.0 -3600,0.09336255,0.29096773,,,,,,,,,,,,,, -3700,0.33930972,0.28274232,,,,,,,,,,,,,, -3800,0.16829725,0.2712628,,,,,,,,,,,,,, -3884,,,0.7463890484401158,0.2633559022630964,0.7203027728439786,0.2895800451515897,3554.0,0.7373969305230732,0.2911783626911128,3581.0,1016.8996806144714,1265.9849543571472,1016.8996806144714,248.6039776802063,0.3449945449829101,0.0 -3900,0.17287137,0.26688245,,,,,,,,,,,,,, -4000,0.10843055,0.29593465,,,,,,,,,,,,,, -4100,0.09796631,0.2685127,,,,,,,,,,,,,, -4200,0.1162129,0.24354164,,,,,,,,,,,,,, -4230,,,0.745875494820731,0.2637581825256347,0.7200002418050084,0.2899739743620744,3554.0,0.7372028315676487,0.2914835555165282,3581.0,1096.9104199409485,1350.0770933628082,1096.9104199409485,252.64858031272888,0.3698835372924804,0.0 -4300,0.18100338,0.2866462,,,,,,,,,,,,,, -4400,0.09352117,0.29392272,,,,,,,,,,,,,, -4500,0.25510487,0.20209834,,,,,,,,,,,,,, -4576,,,0.7481837953839984,0.2624566214425223,0.7213690504976786,0.2893379997120322,3554.0,0.7384671677560388,0.2909181323739877,3581.0,1176.9021356105804,1434.1509294509888,1176.9021356105804,256.6939928531647,0.3946630954742431,0.0 -4600,0.14847684,0.20667124,,,,,,,,,,,,,, -4700,0.29796523,0.22045018,,,,,,,,,,,,,, -4800,0.16013895,0.31379673,,,,,,,,,,,,,, -4900,0.041746985,0.28914714,,,,,,,,,,,,,, -4922,,,0.7474816186087472,0.2626213005610874,0.7212315925937324,0.2890369799543648,3554.0,0.7384208758028483,0.2905407063756632,3581.0,1256.8783206939695,1518.2038979530334,1256.8783206939695,260.7339150905609,0.4196033477783203,0.0 -5000,0.43928635,0.37101153,,,,,,,,,,,,,, -5100,0.20753023,0.3288706,,,,,,,,,,,,,, -5200,0.08446183,0.23713337,,,,,,,,,,,,,, -5268,,,0.7481545039585659,0.262041585786002,0.7215776073174592,0.2886578200838843,3554.0,0.7387660542402611,0.2901217948787873,3581.0,1336.9380061626434,1602.3376910686493,1336.9380061626434,264.77026534080505,0.4451534748077392,0.0 -5300,0.33795708,0.2207434,,,,,,,,,,,,,, -5400,0.2965157,0.33151883,,,,,,,,,,,,,, -5500,0.046805758,0.2793099,,,,,,,,,,,,,, -5600,0.14162806,0.2553279,,,,,,,,,,,,,, -5615,,,0.7481823648725238,0.2613209826605661,0.7214367146832794,0.2880889772362567,3554.0,0.7386184517680118,0.2896329341271642,3581.0,1416.9430875778198,1686.421659231186,1416.9430875778198,268.809974193573,0.4723153114318847,0.0 -5700,0.1823292,0.29580355,,,,,,,,,,,,,, -5800,0.07097769,0.29361573,,,,,,,,,,,,,, -5900,0.22834961,0.3000503,,,,,,,,,,,,,, -5961,,,0.7485959189278739,0.2613769769668579,0.7219383226865855,0.2880452015994302,3554.0,0.7390884616727171,0.2895029894102555,3581.0,1496.9353461265564,1770.4920568466189,1496.9353461265564,272.85027623176575,0.4980456829071045,0.0 -6000,0.124122776,0.21122527,,,,,,,,,,,,,, -6100,0.074481785,0.32171154,,,,,,,,,,,,,, -6200,0.17105573,0.29407158,,,,,,,,,,,,,, -6300,0.15407696,0.23814552,,,,,,,,,,,,,, -6303,,,0.7485676492963519,0.2616139990942819,0.7227368287712789,0.288008621722443,3554.0,0.7398391548930118,0.2894305517073617,3581.0,1576.922093629837,1854.556218624115,1576.922093629837,276.8902978897095,0.5235202312469482,0.0 -6400,0.06664382,0.23205727,,,,,,,,,,,,,, -6500,0.14899835,0.25655356,,,,,,,,,,,,,, -6600,0.1786906,0.30148926,,,,,,,,,,,,,, -6646,,,0.7496201651436942,0.260679977280753,0.7227002145469893,0.2876082523927705,3554.0,0.7399165354038676,0.2890858164182491,3581.0,1656.9838755130768,1938.7012684345243,1656.9838755130768,280.935923576355,0.5493738651275635,0.0 -6700,0.13079411,0.28941074,,,,,,,,,,,,,, -6800,0.12745291,0.3078544,,,,,,,,,,,,,, -6900,0.18560813,0.3515516,,,,,,,,,,,,,, -6991,,,0.7485895838056292,0.2608584335872105,0.7222806279016601,0.2874398990903735,3554.0,0.7394819091865051,0.288878116218846,3581.0,1737.0378777980804,2022.8351304531093,1737.0378777980804,284.97637605667114,0.5766260623931885,0.0 -7000,0.1631187,0.24313614,,,,,,,,,,,,,, -7100,0.22010112,0.24662274,,,,,,,,,,,,,, -7200,0.19511057,0.27545953,,,,,,,,,,,,,, -7300,0.09728065,0.27659577,,,,,,,,,,,,,, -7338,,,0.7492905344281878,0.2608873333249773,0.7231078483311058,0.2873187905023389,3554.0,0.7403037106647934,0.2887308887182351,3581.0,1817.0663664340973,2106.940655231476,1817.0663664340973,289.0156946182251,0.6023619174957275,0.0 -7400,0.24500206,0.25878775,,,,,,,,,,,,,, -7500,0.14595772,0.20403783,,,,,,,,,,,,,, -7600,0.2435808,0.2626021,,,,,,,,,,,,,, -7683,,,0.7503745215279716,0.2609071390969412,0.7233539810996412,0.2878370054263242,3554.0,0.7404244515323932,0.2893348316745497,3581.0,1897.189120054245,2191.1428146362305,1897.189120054245,293.05683064460754,0.6286115646362305,0.0 -7700,0.12795198,0.26486892,,,,,,,,,,,,,, -7800,0.053755872,0.27992198,,,,,,,,,,,,,, -7900,0.6971505,0.27165526,,,,,,,,,,,,,, -8000,0.07277042,0.2993999,,,,,,,,,,,,,, -8027,,,0.7490619250706264,0.2611512967518398,0.7230151106148002,0.2876697168902997,3554.0,0.740137905023911,0.2891022129053511,3581.0,1977.1989560127256,2275.2364633083344,1977.1989560127256,297.0999720096588,0.6573379039764404,0.0 -8100,0.26323292,0.21838634,,,,,,,,,,,,,, -8200,0.05507832,0.27487767,,,,,,,,,,,,,, -8300,0.24984208,0.2483797,,,,,,,,,,,,,, -8373,,,0.7492820194789341,0.2605456454413278,0.722934188370498,0.2869772581015669,3554.0,0.7401876739868403,0.2883704387086359,3581.0,2057.3637301921844,2359.4793276786804,2057.3637301921844,301.1387703418732,0.68442702293396,0.0 -8400,0.11039118,0.2680432,,,,,,,,,,,,,, -8500,0.3050151,0.35212588,,,,,,,,,,,,,, -8600,0.07547357,0.34318566,,,,,,,,,,,,,, -8700,0.47625,0.20646052,,,,,,,,,,,,,, -8715,,,0.7499177796500069,0.2602596623556955,0.723055022180114,0.2870410238683525,3554.0,0.7403200048869031,0.2884313886440589,3581.0,2137.3297839164734,2443.523314476013,2137.3297839164734,305.17865443229675,0.7108016014099121,0.0 -8800,0.09579514,0.20313455,,,,,,,,,,,,,, -8900,0.09248939,0.23289236,,,,,,,,,,,,,, -9000,0.11359811,0.19820291,,,,,,,,,,,,,, -9061,,,0.750847407749721,0.2600905724934169,0.7241669817459201,0.2869474618167909,3554.0,0.7413580627356186,0.2883768132264905,3581.0,2217.491429805756,2527.769276380539,2217.491429805756,309.2239181995392,0.7377052307128906,0.0 -9100,0.20493677,0.30338,,,,,,,,,,,,,, -9200,0.35431984,0.21077135,,,,,,,,,,,,,, -9300,0.05327002,0.31838977,,,,,,,,,,,,,, -9400,0.13480924,0.23087633,,,,,,,,,,,,,, -9408,,,0.7499944823128837,0.2610064574650356,0.7235614388057471,0.2875855488259268,3554.0,0.7408017411773946,0.2888754232407148,3581.0,2297.5060591697693,2611.8630299568176,2297.5060591697693,313.2646355628968,0.7639820575714111,0.0 -9500,0.15714222,0.2692982,,,,,,,,,,,,,, -9600,0.093498886,0.2703902,,,,,,,,,,,,,, -9700,0.061632145,0.24741632,,,,,,,,,,,,,, -9751,,,0.7508278574262347,0.2601182460784912,0.7241959708690912,0.2867819078195343,3554.0,0.7413401322736317,0.2882359943320651,3581.0,2377.5569076538086,2695.998062610626,2377.5569076538086,317.30930352211,0.7913625240325928,0.0 -9800,0.39437282,0.26242885,,,,,,,,,,,,,, -9900,0.18079631,0.25315556,,,,,,,,,,,,,, -10000,0.09160895,0.24943826,,,,,,,,,,,,,, -10099,,,0.7506995882306781,0.2599021536963327,0.723881967831141,0.2867163731666784,3554.0,0.7411026729614633,0.2881764761065344,3581.0,2457.756463766098,2780.274495124817,2457.756463766098,321.34729051589966,0.818218469619751,0.0 -10100,0.129312,0.28637093,,,,,,,,,,,,,, -10200,0.05847849,0.3516219,,,,,,,,,,,,,, -10300,0.08324807,0.24454328,,,,,,,,,,,,,, -10400,0.1729954,0.30735058,,,,,,,,,,,,,, -10445,,,0.7509869847978864,0.25971748147692,0.7240853038609665,0.2866183803131155,3554.0,0.7414004686191008,0.2879779797564402,3581.0,2537.771735906601,2864.37513422966,2537.771735906601,325.3895990848541,0.8493859767913818,0.0 -10500,0.13516866,0.21487586,,,,,,,,,,,,,, -10600,0.10961781,0.25450987,,,,,,,,,,,,,, -10700,0.18089707,0.31585383,,,,,,,,,,,,,, -10790,,,0.7511575562613351,0.2595784323556082,0.7242304555606359,0.286358319713483,3554.0,0.741495575061959,0.28772633969867,3581.0,2617.872266769409,2948.560939311981,2617.872266769409,329.43454337120056,0.8776884078979492,0.0 -10800,0.15158336,0.19541636,,,,,,,,,,,,,, -10900,0.059843868,0.2566147,,,,,,,,,,,,,, -11000,0.09941763,0.33257157,,,,,,,,,,,,,, -11100,0.31505743,0.25035214,,,,,,,,,,,,,, -11134,,,0.7513892310006278,0.2592880725860595,0.7243891400974254,0.2863875320941193,3554.0,0.7416386778745462,0.2877778812547996,3581.0,2697.8649265766144,3032.638692140579,2697.8649265766144,333.47989439964294,0.9055733680725098,0.0 -11200,0.21840262,0.24704854,,,,,,,,,,,,,, -11300,0.2833333,0.30052495,,,,,,,,,,,,,, -11400,0.08247258,0.25921878,,,,,,,,,,,,,, -11480,,,0.7506914819989886,0.2594648599624634,0.7240741066404052,0.2862425177836593,3554.0,0.74135724461568,0.2875935997386379,3581.0,2777.8663704395294,3116.7184529304504,2777.8663704395294,337.51908659935,0.9325578212738036,0.0 -11500,0.22460474,0.2904142,,,,,,,,,,,,,, -11600,0.27508056,0.24502088,,,,,,,,,,,,,, -11700,0.06989758,0.24148102,,,,,,,,,,,,,, -11800,0.22455092,0.27851143,,,,,,,,,,,,,, -11824,,,0.7505995886666434,0.2600880861282348,0.723615913627251,0.2868656121953257,3554.0,0.7408781672149888,0.288328714591769,3581.0,2857.9913415908813,3200.917214870453,2857.9913415908813,341.55452609062195,0.9591963291168212,0.0 -11900,0.11766945,0.3384166,,,,,,,,,,,,,, -12000,0.11247433,0.32442194,,,,,,,,,,,,,, -12100,0.28936136,0.21942846,,,,,,,,,,,,,, -12169,,,0.7514924321855817,0.2590014253343854,0.7240672371799382,0.2862641565841305,3554.0,0.7413234289915527,0.2876703666595399,3581.0,2938.1020641326904,3285.1103279590607,2938.1020641326904,345.5973918437958,0.986703395843506,0.0 -12200,0.1970488,0.2709174,,,,,,,,,,,,,, -12300,0.0639771,0.27960736,,,,,,,,,,,,,, -12400,0.15001029,0.25245696,,,,,,,,,,,,,, -12500,0.27270347,0.33286446,,,,,,,,,,,,,, -12514,,,0.7499089922223773,0.2597200189317976,0.7233776807382527,0.286492978312289,3554.0,0.7406686603340548,0.2878473532729161,3581.0,3018.1377742290497,3369.2264947891235,3018.1377742290497,349.6387412548065,1.0136513710021973,0.0 -12600,0.11049478,0.37116307,,,,,,,,,,,,,, -12700,0.10157314,0.289175,,,,,,,,,,,,,, -12800,0.15309991,0.2490702,,,,,,,,,,,,,, -12859,,,0.7517422948564801,0.2592545918055943,0.7250731322761326,0.2861169783936234,3554.0,0.7423111724640463,0.2875071517317963,3581.0,3098.17995595932,3453.3531222343445,3098.17995595932,353.6803197860718,1.04435133934021,0.0 -12900,0.102836475,0.21794365,,,,,,,,,,,,,, -13000,0.10299307,0.28359,,,,,,,,,,,,,, -13100,0.18733728,0.24120566,,,,,,,,,,,,,, -13200,0.20929636,0.19126162,,,,,,,,,,,,,, -13201,,,0.751594066619873,0.2592260496956961,0.7241303675216305,0.2864757016192142,3554.0,0.7413691073547891,0.2878732944926347,3581.0,3178.1596059799194,3537.362501144409,3178.1596059799194,357.6665780544281,1.0757217407226562,0.0 -13300,0.22648527,0.25519943,,,,,,,,,,,,,, -13400,0.074029565,0.28664503,,,,,,,,,,,,,, -13500,0.10528814,0.2690397,,,,,,,,,,,,,, -13548,,,0.7518584387642997,0.259649498122079,0.7249348500369303,0.2866530882621254,3554.0,0.7420486923170903,0.2880651095298974,3581.0,3258.367030143738,3621.656540393829,3258.367030143738,361.7101609706879,1.1067495346069336,0.0 -13600,0.08396407,0.2162522,,,,,,,,,,,,,, -13700,0.14567761,0.32716933,,,,,,,,,,,,,, -13800,0.16201964,0.2158288,,,,,,,,,,,,,, -13893,,,0.7509762900216239,0.2593644346509661,0.7240757553109173,0.2861734797059651,3554.0,0.7411973703443522,0.287589679580599,3581.0,3338.5711925029755,3705.94663143158,3338.5711925029755,365.7559454441071,1.1348917484283447,0.0 -13900,0.3760531,0.23499689,,,,,,,,,,,,,, -14000,0.18251684,0.3304931,,,,,,,,,,,,,, -14100,0.1856042,0.27743053,,,,,,,,,,,,,, -14200,0.21710464,0.23707059,,,,,,,,,,,,,, -14237,,,0.7518370492117745,0.2589819261005947,0.7242094350116066,0.2863299316681028,3554.0,0.7414800989597877,0.2876861836450188,3581.0,3418.6912846565247,3790.152283668518,3418.6912846565247,369.80122470855713,1.1632394790649414,0.0 -14300,0.097725675,0.28788412,,,,,,,,,,,,,, -14400,0.09812094,0.22821085,,,,,,,,,,,,,, -14500,0.07476929,0.36469144,,,,,,,,,,,,,, -14583,,,0.7517756053379604,0.2590833391462053,0.7246804739158342,0.2862203122526994,3554.0,0.7419205201933817,0.2875336383648073,3581.0,3498.888443470001,3874.429375886917,3498.888443470001,373.8409984111786,1.191213607788086,0.0 -14600,0.36820284,0.20449643,,,,,,,,,,,,,, -14700,0.093552485,0.22360075,,,,,,,,,,,,,, -14800,0.11135479,0.26619858,,,,,,,,,,,,,, -14900,0.14258878,0.23459132,,,,,,,,,,,,,, -14930,,,0.7516927719116211,0.2592453275408063,0.7247757533325127,0.2861635533355902,3554.0,0.7419288377460905,0.287559579584526,3581.0,3578.980725288391,3958.60556936264,3578.980725288391,377.8848206996918,1.2186510562896729,0.0 -15000,0.12809536,0.22765319,,,,,,,,,,,,,, -15100,0.12913342,0.2076443,,,,,,,,,,,,,, -15200,0.10618794,0.24450946,,,,,,,,,,,,,, -15274,,,0.7525625228881836,0.2594645874840872,0.7252395106086452,0.2869511369781408,3554.0,0.7423833715486247,0.2883185562691985,3581.0,3658.9713571071625,4042.6870880126953,3658.9713571071625,381.9346296787262,1.2474379539489746,0.0 -15300,0.16586496,0.2418493,,,,,,,,,,,,,, -15400,0.19719888,0.23460549,,,,,,,,,,,,,, -15500,0.18847492,0.34656495,,,,,,,,,,,,,, -15600,0.11462212,0.303904,,,,,,,,,,,,,, -15618,,,0.7516856874738421,0.2591737338474819,0.7246266173457724,0.2863493894148758,3554.0,0.741825345573862,0.287680593158772,3581.0,3739.1749410629272,4126.983177185059,3739.1749410629272,385.9861137866974,1.276132583618164,0.0 -15700,0.1499777,0.24580823,,,,,,,,,,,,,, -15800,0.10314341,0.3034991,,,,,,,,,,,,,, -15900,0.2231847,0.29786804,,,,,,,,,,,,,, -15962,,,0.7520236968994141,0.2588274819510324,0.7249162337990644,0.2858968808802581,3554.0,0.7421772735007679,0.2871475539304663,3581.0,3819.206693887711,4211.104909896851,3819.206693887711,390.0353882312775,1.3046724796295166,0.0 -16000,0.21258788,0.22382456,,,,,,,,,,,,,, -16100,0.13961518,0.2765376,,,,,,,,,,,,,, -16200,0.12138054,0.25844625,,,,,,,,,,,,,, -16300,0.3045817,0.22189012,,,,,,,,,,,,,, -16305,,,0.7516825539725167,0.2592099905014038,0.72449877668648,0.2863791170050471,3554.0,0.7417084907759705,0.2877406908859257,3581.0,3899.3228511810303,4295.300581455231,3899.3228511810303,394.0746328830719,1.333157300949097,0.0 -16400,0.45000753,0.24252369,,,,,,,,,,,,,, -16500,0.1283802,0.20802622,,,,,,,,,,,,,, -16600,0.21806805,0.18881759,,,,,,,,,,,,,, -16653,,,0.7520562580653599,0.2589167526790074,0.7248071467668472,0.2861561686655881,3554.0,0.7420754175684167,0.2875004022423031,3581.0,3979.487622022629,4379.550000429153,3979.487622022629,398.1184675693512,1.3620598316192627,0.0 -16700,0.12910508,0.26587784,,,,,,,,,,,,,, -16800,0.19101472,0.23761316,,,,,,,,,,,,,, -16900,0.070066005,0.354671,,,,,,,,,,,,,, -16997,,,0.7524524416242327,0.2585340908595493,0.7252197952571047,0.2857100659028559,3554.0,0.7424610929427883,0.287021767989912,3581.0,4059.464731216431,4463.610082149506,4059.464731216431,402.1616246700287,1.389885425567627,0.0 -17000,0.11027181,0.23983307,,,,,,,,,,,,,, -17100,0.30282557,0.27528885,,,,,,,,,,,,,, -17200,0.3928251,0.34445632,,,,,,,,,,,,,, -17300,0.108167075,0.30722618,,,,,,,,,,,,,, -17340,,,0.7515924998692104,0.2602828230176653,0.7249902178882949,0.2872200420081246,3554.0,0.7420885756640953,0.2885778662013927,3581.0,4139.484055280685,4547.71174955368,4139.484055280685,406.2029480934143,1.4190490245819092,0.0 -17400,0.16549331,0.25779003,,,,,,,,,,,,,, -17500,0.12679282,0.30195007,,,,,,,,,,,,,, -17600,0.14878814,0.2571479,,,,,,,,,,,,,, -17687,,,0.7524777821132115,0.2587361506053379,0.725357871412493,0.2860411052027645,3554.0,0.7425334965573513,0.287437952420326,3581.0,4219.586560964584,4631.894050359726,4219.586560964584,410.2422773838043,1.4473683834075928,0.0 -17700,0.11942134,0.2858179,,,,,,,,,,,,,, -17800,0.25653195,0.361935,,,,,,,,,,,,,, -17900,0.08221142,0.31996715,,,,,,,,,,,,,, -18000,0.10558866,0.3675033,,,,,,,,,,,,,, -18031,,,0.7531301634652274,0.2585129737854004,0.7258307650710467,0.2859081983163776,3554.0,0.743022118690659,0.2872522391942718,3581.0,4299.56721663475,4715.960064649582,4299.56721663475,414.2865955829621,1.4766151905059814,0.0 -18100,0.243589,0.22983703,,,,,,,,,,,,,, -18200,0.10206398,0.35832042,,,,,,,,,,,,,, -18300,0.5786307,0.20253217,,,,,,,,,,,,,, -18376,,,0.7519820758274623,0.2584585973194667,0.7246246938968416,0.2856576690931433,3554.0,0.7419246789697361,0.2869516823818416,3581.0,4379.592621326447,4800.07418346405,4379.592621326447,418.3355646133423,1.5043630599975586,0.0 -18400,0.12613612,0.2855622,,,,,,,,,,,,,, -18500,0.10520677,0.23058304,,,,,,,,,,,,,, -18600,0.06439784,0.29419976,,,,,,,,,,,,,, -18700,0.19369753,0.30332237,,,,,,,,,,,,,, -18722,,,0.7529837744576591,0.2582253387996128,0.7253613061427265,0.2858288045270294,3554.0,0.7424981128700083,0.2871849488293249,3581.0,4459.631384849548,4884.19443488121,4459.631384849548,422.3765048980713,1.5328974723815918,0.0 -18800,0.113941096,0.25981948,,,,,,,,,,,,,, -18900,0.105024554,0.27861845,,,,,,,,,,,,,, -19000,0.15500772,0.29275998,,,,,,,,,,,,,, -19067,,,0.7525200843811035,0.2581585986273629,0.7251618857053672,0.2855097696092871,3554.0,0.7424585704063111,0.2868095340425161,3581.0,4539.849860191345,4968.498574972153,4539.849860191345,426.4215528964996,1.5613839626312256,0.0 -19100,0.12565704,0.36349165,,,,,,,,,,,,,, -19200,0.09559684,0.35168776,,,,,,,,,,,,,, -19300,0.10302569,0.31618738,,,,,,,,,,,,,, -19400,0.12211125,0.20231909,,,,,,,,,,,,,, -19410,,,0.7519347327096122,0.2583410739898681,0.7246815043349043,0.2855432067081105,3554.0,0.7419808565388509,0.2868313164858803,3581.0,4619.828453540802,5052.564338922501,4619.828453540802,430.4662718772888,1.592024326324463,0.0 -19500,0.25082755,0.3108256,,,,,,,,,,,,,, -19600,0.22047491,0.25866884,,,,,,,,,,,,,, -19700,0.22389549,0.24791996,,,,,,,,,,,,,, -19757,,,0.7540554319109235,0.2578264985765729,0.7263243358056064,0.2855517076654386,3554.0,0.7435532148841106,0.2868872895250104,3581.0,4700.001065015793,5136.818740844727,4700.001065015793,434.50729513168335,1.6206250190734863,0.0 -19800,0.14260581,0.28456512,,,,,,,,,,,,,, -19900,0.09463491,0.27579337,,,,,,,,,,,,,, -20000,0.090096235,0.34847635,,,,,,,,,,,,,, -20100,0.12947094,0.41641796,,,,,,,,,,,,,, -20104,,,0.752790996006557,0.2582331725529262,0.7253982638400394,0.2855855912791925,3554.0,0.7426360342563181,0.286899697677412,3581.0,4780.03605055809,5220.933315753937,4780.03605055809,438.54583048820496,1.6494874954223633,0.0 -20200,0.24575354,0.280332,,,,,,,,,,,,,, -20300,0.18891914,0.24186087,,,,,,,,,,,,,, -20400,0.08974228,0.2369006,,,,,,,,,,,,,, -20447,,,0.7531872476850238,0.2581703151975359,0.7259412259953574,0.2854585062605514,3554.0,0.7431528815275062,0.2867745594151424,3581.0,4860.2350742816925,5305.218400716782,4860.2350742816925,442.5905215740204,1.6790635585784912,0.0 -20500,0.48243445,0.29702598,,,,,,,,,,,,,, -20600,0.19591707,0.24106528,,,,,,,,,,,,,, -20700,0.08420815,0.23039362,,,,,,,,,,,,,, -20791,,,0.753039973122733,0.2577710662569318,0.7251662821600662,0.2855243843864308,3554.0,0.7424073697334892,0.2869013339172891,3581.0,4940.427654981613,5389.503933191299,4940.427654981613,446.64170718193054,1.7088754177093506,0.0 -20800,0.12115955,0.27095538,,,,,,,,,,,,,, -20900,0.14437068,0.2853269,,,,,,,,,,,,,, -21000,0.26146868,0.3906189,,,,,,,,,,,,,, -21100,0.12693915,0.28332424,,,,,,,,,,,,,, -21138,,,0.7532409940447126,0.2579892703465053,0.7257597348498172,0.285431011245032,3554.0,0.7429243533580006,0.2867768433333042,3581.0,5020.432681083679,5473.598413228989,5020.432681083679,450.68442273139954,1.743255853652954,0.0 -21200,0.15074237,0.22081605,,,,,,,,,,,,,, -21300,0.17947587,0.30308986,,,,,,,,,,,,,, -21400,0.10287088,0.25491405,,,,,,,,,,,,,, -21483,,,0.7530279840741839,0.2580551760537283,0.7256322376635481,0.2854780842228827,3554.0,0.742838178057805,0.286762423969387,3581.0,5100.391660451889,5557.642760276794,5100.391660451889,454.7285809516907,1.772273302078247,0.0 -21500,0.16017531,0.25288445,,,,,,,,,,,,,, -21600,0.14180772,0.23529074,,,,,,,,,,,,,, -21700,0.1454475,0.39734444,,,,,,,,,,,,,, -21800,0.31464893,0.18065652,,,,,,,,,,,,,, -21825,,,0.753004619053432,0.257967335837228,0.7253357517497889,0.2857030934004818,3554.0,0.7424775916948827,0.2871336799798415,3581.0,5180.50639796257,5641.847291946411,5180.50639796257,458.7747828960418,1.8035430908203125,0.0 -21900,0.21746941,0.27730137,,,,,,,,,,,,,, -22000,0.1282106,0.354406,,,,,,,,,,,,,, -22100,0.08600262,0.29195967,,,,,,,,,,,,,, -22173,,,0.7532924243382045,0.2577966281345912,0.7257460646234877,0.2853456753723797,3554.0,0.7429429655866029,0.286666090346621,3581.0,5260.673074483872,5726.103197097778,5260.673074483872,462.8213136196137,1.834075689315796,0.0 -22200,0.17943259,0.2728532,,,,,,,,,,,,,, -22300,0.21016519,0.2512897,,,,,,,,,,,,,, -22400,0.13268542,0.289285,,,,,,,,,,,,,, -22500,0.41586828,0.24218954,,,,,,,,,,,,,, -22519,,,0.7536420822143555,0.2580252374921526,0.7262866224676421,0.2855393769839002,3554.0,0.743415566204447,0.2868627118385227,3581.0,5340.706670284271,5810.223093986511,5340.706670284271,466.86464047431946,1.8650314807891848,0.0 -22600,0.19271588,0.23839043,,,,,,,,,,,,,, -22700,0.16167091,0.19511451,,,,,,,,,,,,,, -22800,0.29338622,0.23501575,,,,,,,,,,,,,, -22860,,,0.7531393596104213,0.2581378732408796,0.7255267914453785,0.2857794474535734,3554.0,0.7426878485190939,0.2871538602716594,3581.0,5420.829624176025,5894.439101219177,5420.829624176025,470.9098572731018,1.90094256401062,0.0 -22900,0.43334258,0.1957809,,,,,,,,,,,,,, -23000,0.17954229,0.2695722,,,,,,,,,,,,,, -23100,0.23342276,0.30205572,,,,,,,,,,,,,, -23200,0.106475316,0.32791346,,,,,,,,,,,,,, -23207,,,0.7534748486110142,0.2577291897365025,0.7260389784178038,0.2853734451663179,3554.0,0.7432383750610863,0.2866646586367286,3581.0,5500.831784963608,5978.536752939224,5500.831784963608,474.96408867836,1.930015563964844,0.0 -23300,0.22014646,0.25032824,,,,,,,,,,,,,, -23400,0.091511905,0.30811056,,,,,,,,,,,,,, -23500,0.30006006,0.23208039,,,,,,,,,,,,,, -23552,,,0.7532529830932617,0.2580478872571672,0.7259530414673607,0.2855493205279262,3554.0,0.7430817050928512,0.2868668706148771,3581.0,5581.046790122986,6062.840421676636,5581.046790122986,479.0078208446503,1.9629197120666504,0.0 -23600,0.1193706,0.23517127,,,,,,,,,,,,,, -23700,0.17416555,0.2956705,,,,,,,,,,,,,, -23800,0.11320631,0.2384201,,,,,,,,,,,,,, -23896,,,0.7521711758204869,0.2584163461412702,0.7247517102208779,0.285952265905274,3554.0,0.7421021428197431,0.2872241163213836,3581.0,5661.185736894608,6147.086599349976,5661.185736894608,483.0488519668579,2.0170490741729736,0.0 -23900,0.14839453,0.22177832,,,,,,,,,,,,,, -24000,0.08863408,0.23812158,,,,,,,,,,,,,, -24100,0.13158958,0.3270059,,,,,,,,,,,,,, -24200,0.17631957,0.25998515,,,,,,,,,,,,,, -24243,,,0.7536928313119071,0.2573205062321254,0.7258101566896454,0.2851954059246623,3554.0,0.7430396400926766,0.2865261918371265,3581.0,5741.163243055344,6231.160847902298,5741.163243055344,487.0972065925598,2.053084135055542,0.0 -24300,0.19350645,0.23760824,,,,,,,,,,,,,, -24400,0.35954675,0.1892125,,,,,,,,,,,,,, -24500,0.072285615,0.29492268,,,,,,,,,,,,,, -24588,,,0.7535805021013532,0.2576057570321219,0.7259259070985158,0.2853072407410664,3554.0,0.7431695848095853,0.2865782788065484,3581.0,5821.219519615173,6315.30278635025,5821.219519615173,491.1406552791596,2.083099126815796,0.0 -24600,0.14132994,0.28460667,,,,,,,,,,,,,, -24700,0.15099986,0.2529125,,,,,,,,,,,,,, -24800,0.07193513,0.2722249,,,,,,,,,,,,,, -24900,0.08602809,0.2599999,,,,,,,,,,,,,, -24932,,,0.7523881367274693,0.2583734137671334,0.7248127110298256,0.2859432840857133,3554.0,0.742053396506737,0.2873815021445651,3581.0,5901.379008293152,6399.552320480347,5901.379008293152,495.1886050701141,2.1133792400360107,0.0 -25000,0.0860864,0.24882519,,,,,,,,,,,,,, -25100,0.20513387,0.2822281,,,,,,,,,,,,,, -25200,0.13769385,0.22833246,,,,,,,,,,,,,, -25277,,,0.7541452135358538,0.2572070360183716,0.7261007348674029,0.2851946331103598,3554.0,0.743282826244415,0.2865642685026005,3581.0,5981.531459093094,6483.78791642189,5981.531459093094,499.2294337749481,2.143554449081421,0.0 -25300,0.15304376,0.20181504,,,,,,,,,,,,,, -25400,0.12772866,0.21845151,,,,,,,,,,,,,, -25500,0.10910943,0.25347048,,,,,,,,,,,,,, -25600,0.065084524,0.3319681,,,,,,,,,,,,,, -25624,,,0.7540187154497419,0.2573329380580357,0.7261812449440771,0.2851841400094963,3554.0,0.7433787508072117,0.2865010687373464,3581.0,6061.55241727829,6567.898736476898,6061.55241727829,503.27667450904846,2.174089431762696,0.0 -25700,0.11334096,0.41167146,,,,,,,,,,,,,, -25800,0.07648609,0.27728194,,,,,,,,,,,,,, -25900,0.17522839,0.24489908,,,,,,,,,,,,,, -25968,,,0.7539004598345075,0.2575019257409232,0.7261755432918894,0.2853177166682787,3554.0,0.743403226228707,0.286626206999616,3581.0,6141.601917743683,6652.036284208298,6141.601917743683,507.32264375686646,2.204224109649658,0.0 -26000,0.11885685,0.2936497,,,,,,,,,,,,,, -26100,0.14171857,0.27395445,,,,,,,,,,,,,, -26200,0.08624898,0.29246548,,,,,,,,,,,,,, -26300,0.08492407,0.36173293,,,,,,,,,,,,,, -26314,,,0.7539973258972168,0.256960460117885,0.7257981351338281,0.2851275356552476,3554.0,0.7430313907166294,0.2864493908278937,3581.0,6221.579905986786,6736.104068756104,6221.579905986786,511.3682265281677,2.2362966537475586,0.0 -26400,0.09218796,0.24907738,,,,,,,,,,,,,, -26500,0.094596006,0.25124174,,,,,,,,,,,,,, -26600,0.14953728,0.2643337,,,,,,,,,,,,,, -26656,,,0.7542306355067662,0.2571321044649396,0.7263142376987197,0.2851001436816351,3554.0,0.7435331709456158,0.286397712918441,3581.0,6301.547380447388,6820.160589933395,6301.547380447388,515.4121625423431,2.269528865814209,0.0 -26700,0.12166,0.34156352,,,,,,,,,,,,,, -26800,0.15525192,0.17803563,,,,,,,,,,,,,, -26900,0.113884315,0.21280733,,,,,,,,,,,,,, -27000,0.13725585,0.2120751,,,,,,,,,,,,,, -27001,,,0.7542601994105748,0.2572677816663469,0.726429163772334,0.2851712769447717,3554.0,0.7436299136283511,0.2864608104187028,3581.0,6381.7876925468445,6904.431702852249,6381.7876925468445,519.3987793922424,2.301710605621338,0.0 -27100,0.11043993,0.24629211,,,,,,,,,,,,,, -27200,0.12506783,0.3105281,,,,,,,,,,,,,, -27300,0.052206993,0.2406686,,,,,,,,,,,,,, -27345,,,0.7543474606105259,0.2567661660058157,0.7260194691500774,0.285035107064663,3554.0,0.7432553510498116,0.2863423875575956,3581.0,6461.864283800125,6988.59413766861,6461.864283800125,523.4427843093872,2.331604719161988,0.0 -27400,0.27675048,0.30482924,,,,,,,,,,,,,, -27500,0.07872709,0.23561293,,,,,,,,,,,,,, -27600,0.10696349,0.24451819,,,,,,,,,,,,,, -27691,,,0.7541309084211077,0.257008637700762,0.7260204995691475,0.2850820254796532,3554.0,0.743281462711184,0.2863778394216001,3581.0,6542.003179073334,7072.816970825195,6542.003179073334,527.4829468727112,2.3632917404174805,0.0 -27700,0.10649671,0.29070204,,,,,,,,,,,,,, -27800,0.0698732,0.31194982,,,,,,,,,,,,,, -27900,0.1467746,0.26913556,,,,,,,,,,,,,, -28000,0.081852406,0.23416588,,,,,,,,,,,,,, -28038,,,0.7536019597734723,0.257212621825082,0.7257375464925084,0.2851892577575443,3554.0,0.7429226489414619,0.2865027390655543,3581.0,6621.994435787201,7156.893805742264,6621.994435787201,531.5244166851044,2.395188331604004,0.0 -28100,0.2085831,0.27330726,,,,,,,,,,,,,, -28200,0.19302095,0.256168,,,,,,,,,,,,,, -28300,0.1719384,0.23951305,,,,,,,,,,,,,, -28384,,,0.7541613578796387,0.256799612726484,0.7257597348498172,0.2850896334071205,3554.0,0.7429875531232547,0.2864661281983035,3581.0,6702.170172929764,7241.158483028412,6702.170172929764,535.5664856433868,2.4301555156707764,0.0 -28400,0.16578433,0.23648435,,,,,,,,,,,,,, -28500,0.14086768,0.2944317,,,,,,,,,,,,,, -28600,0.12605977,0.36066946,,,,,,,,,,,,,, -28700,0.09150867,0.26582205,,,,,,,,,,,,,, -28730,,,0.7544137409755162,0.2568827697208949,0.7262156609410172,0.2850468538420617,3554.0,0.7433747965608419,0.2863912020472633,3581.0,6782.152277231216,7325.230720281601,6782.152277231216,539.6133246421814,2.461496353149414,0.0 -28800,0.14172032,0.20619416,,,,,,,,,,,,,, -28900,0.07263442,0.33960056,,,,,,,,,,,,,, -29000,0.06664154,0.2609557,,,,,,,,,,,,,, -29077,,,0.7544715063912528,0.2570457458496094,0.7264549242490855,0.2851410513187166,3554.0,0.7436050291468863,0.286466060021642,3581.0,6862.30401468277,7409.476336956024,6862.30401468277,543.6638793945312,2.492823839187622,0.0 -29100,0.14305887,0.25302976,,,,,,,,,,,,,, -29200,0.08042326,0.33078417,,,,,,,,,,,,,, -29300,0.4099622,0.34245685,,,,,,,,,,,,,, -29400,0.199883,0.18788487,,,,,,,,,,,,,, -29420,,,0.7544336318969727,0.2565846443176269,0.7259691846994584,0.2850145330305642,3554.0,0.7431189977267174,0.2863644767959369,3581.0,6942.3070504665375,7493.5642166137695,6942.3070504665375,547.7039232254028,2.525808572769165,0.0 -29500,0.13401717,0.22390395,,,,,,,,,,,,,, -29600,0.11233709,0.20476514,,,,,,,,,,,,,, -29700,0.06581909,0.30867925,,,,,,,,,,,,,, -29767,,,0.7546213694981166,0.2567110913140433,0.7263272209790025,0.2850071827078644,3554.0,0.7435238307429838,0.2863429670592188,3581.0,7022.33305644989,7577.673763036728,7022.33305644989,551.7444829940796,2.556824445724488,0.0 -29800,0.1020668,0.35652244,,,,,,,,,,,,,, -29900,0.09681076,0.2629989,,,,,,,,,,,,,, -30000,0.066508204,0.3215977,,,,,,,,,,,,,, -30100,0.10933823,0.22511058,,,,,,,,,,,,,, -30113,,,0.7546024322509766,0.2568695545196533,0.7264716857326252,0.2850905779579347,3554.0,0.7436164828260262,0.2864089279792656,3581.0,7102.540855884552,7661.973260879517,7102.540855884552,555.7933006286621,2.587928295135498,0.0 -30200,0.10458977,0.27991882,,,,,,,,,,,,,, -30300,0.11065303,0.2969715,,,,,,,,,,,,,, -30400,0.1668782,0.2842713,,,,,,,,,,,,,, -30456,,,0.7543889454432896,0.2566305569240025,0.7259818245067178,0.2850126267552845,3554.0,0.7431831519652332,0.2863609656978672,3581.0,7182.582946300507,7746.113355398178,7182.582946300507,559.842346906662,2.6249148845672607,0.0 -30500,0.0946144,0.22899406,,,,,,,,,,,,,, -30600,0.11265291,0.28087938,,,,,,,,,,,,,, -30700,0.08563645,0.2906662,,,,,,,,,,,,,, -30800,0.05865326,0.28265554,,,,,,,,,,,,,, -30801,,,0.7546736172267369,0.256553258214678,0.7262378492983258,0.2849770601237162,3554.0,0.7434223156939402,0.2863411944660186,3581.0,7262.79195356369,7830.360119819641,7262.79195356369,563.8314032554626,2.66121768951416,0.0 -30900,0.1192404,0.2161484,,,,,,,,,,,,,, -31000,0.065365605,0.25766385,,,,,,,,,,,,,, -31100,0.08152548,0.21233658,,,,,,,,,,,,,, -31148,,,0.7548414639064244,0.2566737788064139,0.7265923134584271,0.2850077322647017,3554.0,0.743772607380969,0.2863343086232023,3581.0,7342.801480770111,7914.464708566666,7342.801480770111,567.8771078586578,2.698338508605957,0.0 -31200,0.09931093,0.23862386,,,,,,,,,,,,,, -31300,0.09747937,0.25732708,,,,,,,,,,,,,, -31400,0.10552093,0.29852733,,,,,,,,,,,,,, -31490,,,0.7542413302830288,0.256832412311009,0.7259741307109947,0.2851721356273301,3554.0,0.7431318149390882,0.286525510070511,3581.0,7422.888996124268,7998.641962766647,7422.888996124268,571.9233648777008,2.7299087047576904,0.0 -31500,0.064675,0.31020826,,,,,,,,,,,,,, -31600,0.07067276,0.22673032,,,,,,,,,,,,,, -31700,0.07043011,0.25904524,,,,,,,,,,,,,, -31800,0.16026795,0.27689862,,,,,,,,,,,,,, -31837,,,0.7548629896981376,0.2564304385866437,0.7264253168744724,0.2849297982357027,3554.0,0.7436150511161338,0.286277449287472,3581.0,7502.898470878601,8082.73925280571,7502.898470878601,575.9671130180359,2.762040376663208,0.0 -31900,0.11526376,0.34036958,,,,,,,,,,,,,, -32000,0.18143746,0.17100948,,,,,,,,,,,,,, -32100,0.05061127,0.26947537,,,,,,,,,,,,,, -32181,,,0.7551521573747907,0.2565772703715733,0.7268253255574705,0.2849779703272281,3554.0,0.7439609794968235,0.2863116057949071,3581.0,7582.90892624855,8166.83486032486,7582.90892624855,580.0088977813721,2.793639659881592,0.0 -32200,0.07111147,0.26606625,,,,,,,,,,,,,, -32300,0.15662265,0.27205965,,,,,,,,,,,,,, -32400,0.11411826,0.24548386,,,,,,,,,,,,,, -32500,0.09221708,0.2288872,,,,,,,,,,,,,, -32523,,,0.754964964730399,0.2566173417227609,0.7266737852595667,0.2850120943720983,3554.0,0.7437946284426487,0.2863976106534487,3581.0,7663.136118888855,8251.14967083931,7663.136118888855,584.0519845485687,2.826245784759521,0.0 -32600,0.077103935,0.33464563,,,,,,,,,,,,,, -32700,0.069046736,0.2753192,,,,,,,,,,,,,, -32800,0.084066994,0.255092,,,,,,,,,,,,,, -32867,,,0.7552356038774762,0.2563023567199707,0.726623226030529,0.2849172786440014,3554.0,0.7437910150795867,0.2862747563093409,3581.0,7743.118954181671,8335.222202539444,7743.118954181671,588.0979740619659,2.8580024242401123,0.0 -32900,0.083087675,0.21017548,,,,,,,,,,,,,, -33000,0.061800323,0.33009988,,,,,,,,,,,,,, -33100,0.06706556,0.22376315,,,,,,,,,,,,,, -33200,0.09706133,0.3523184,,,,,,,,,,,,,, -33211,,,0.7552266120910645,0.2563957146235874,0.726784520962296,0.2849077300939522,3554.0,0.7439368449586359,0.2862539624275691,3581.0,7823.141031265259,8419.336717128754,7823.141031265259,592.1455419063568,2.890659809112549,0.0 -33300,0.06707582,0.31480083,,,,,,,,,,,,,, -33400,0.06754277,0.2535097,,,,,,,,,,,,,, -33500,0.05886183,0.26412305,,,,,,,,,,,,,, -33554,,,0.7550900323050362,0.2564103603363037,0.7266263859823439,0.2848932183587155,3554.0,0.743793673969387,0.2862377022937901,3581.0,7902.981092453003,8503.473998785019,7902.981092453003,596.1902956962585,3.1312170028686523,0.0 -33600,0.07655911,0.29456437,,,,,,,,,,,,,, -33700,0.09205262,0.25576594,,,,,,,,,,,,,, -33800,0.053706277,0.24027191,,,,,,,,,,,,,, -33899,,,0.7554918016706195,0.2561522551945278,0.7268479947770118,0.2848349138130012,3554.0,0.7439810916119799,0.2861828541695755,3581.0,7983.175691604614,8587.75936126709,7983.175691604614,600.2369115352631,3.163629531860352,0.0 -33900,0.051220026,0.28374645,,,,,,,,,,,,,, -34000,0.053986132,0.27799183,,,,,,,,,,,,,, -34100,0.06820535,0.2571627,,,,,,,,,,,,,, -34200,0.07040699,0.26619777,,,,,,,,,,,,,, -34245,,,0.7550997052873883,0.256242104939052,0.7265252675242684,0.2848489103387028,3554.0,0.7437081122591455,0.2861913080756074,3581.0,8063.2761261463165,8671.951352119446,8063.2761261463165,604.2770035266876,3.2029614448547363,0.0 -34300,0.07531189,0.22458859,,,,,,,,,,,,,, -34400,0.040418398,0.33199945,,,,,,,,,,,,,, -34500,0.083783284,0.22029021,,,,,,,,,,,,,, -34591,,,0.7553950037275042,0.2562501089913504,0.7269252075126618,0.2848210890238112,3554.0,0.744058131239528,0.2861426640275935,3581.0,8143.4895124435425,8756.257097244263,8143.4895124435425,608.3235597610474,3.236752986907959,0.0 -34600,0.0856123,0.256866,,,,,,,,,,,,,, -34700,0.078385495,0.24376367,,,,,,,,,,,,,, -34800,0.07253792,0.22741176,,,,,,,,,,,,,, -34900,0.07856893,0.29759675,,,,,,,,,,,,,, -34936,,,0.7556241580418178,0.2561021532331194,0.7270647262547482,0.2847381918096247,3554.0,0.7441798265803895,0.2860856342502094,3581.0,8223.44726896286,8840.294992685318,8223.44726896286,612.3583030700684,3.270219087600708,0.0 -35000,0.08954601,0.2485624,,,,,,,,,,,,,, -35100,0.07124384,0.30074644,,,,,,,,,,,,,, -35200,0.046095394,0.23112245,,,,,,,,,,,,,, -35281,,,0.755319322858538,0.2561110258102417,0.7267110177352982,0.2847387585401132,3554.0,0.7438649185807037,0.2860812709438704,3581.0,8303.64076423645,8924.576595306396,8303.64076423645,616.4016833305359,3.3031046390533447,0.0 -35300,0.04768852,0.27838892,,,,,,,,,,,,,, -35400,0.06203095,0.32130173,,,,,,,,,,,,,, -35500,0.05289488,0.31514633,,,,,,,,,,,,,, -35600,0.093244605,0.21089026,,,,,,,,,,,,,, -35625,,,0.755361693246024,0.2561379841395786,0.7268043050084412,0.2847745484291467,3554.0,0.7439365722519896,0.2861023716206192,3581.0,8383.82935500145,9008.859397649763,8383.82935500145,620.4445323944092,3.3424787521362305,0.0 -35700,0.076943025,0.21043016,,,,,,,,,,,,,, -35800,0.05210843,0.24361129,,,,,,,,,,,,,, -35900,0.06472887,0.21994317,,,,,,,,,,,,,, -35966,,,0.7553381238664899,0.2561021191733224,0.7267506545221933,0.2847349975105075,3554.0,0.743905824577632,0.2860723738895385,3581.0,8463.80910038948,9092.92571401596,8463.80910038948,624.4805216789246,3.3811798095703125,0.0 -36000,0.05174382,0.23559472,,,,,,,,,,,,,, -36100,0.048497707,0.37203982,,,,,,,,,,,,,, -36189,,,0.755429880959647,0.2561074665614536,0.7268578181054798,0.2847377968156478,3554.0,0.744003658086952,0.2860763281359083,3581.0,8514.113553285599,9147.31237244606,8514.113553285599,628.5201015472412,3.41602110862732,0.0 -36189,,,,,,,,,,,8514.113553285599,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/eval_measurements.csv deleted file mode 100644 index a9990ba28..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.983088254928589,0.0,29.42645001411438,1,0,29.42645001411438,0.912513662602974,3581,0.2829431162570249,33.40965127944946,0.9032635007585798,0.2670762368610927,0.915261701713738,3554,0.2587972028106535 -8.024477005004883,0.0183889865875244,109.55331134796144,342,0,109.55331134796144,0.332435537602974,3581,0.6928078258080843,117.60923790931702,0.3026022570473807,0.700493403843471,0.3304802975630451,3554,0.674353982198755 -12.066364765167236,0.0422174930572509,189.61023998260487,681,0,189.61023998260487,0.3170286688315938,3581,0.7106981999179698,201.7445123195648,0.2885269437517438,0.7177139009748187,0.3146172364105585,3554,0.6936591582547833 -16.104042291641235,0.0660929679870605,269.74872303009033,1020,0,269.74872303009033,0.3058018134446558,3581,0.7232667720041539,285.9569094181061,0.2770050082887922,0.7317164966038295,0.3035240881752954,3554,0.7062299274365152 -20.1486337184906,0.0897228717803955,349.75937604904175,1365,0,349.75937604904175,0.3015255005257784,3581,0.7263307675274016,370.048823595047,0.273283737046378,0.7347615105765206,0.2994085944094682,3554,0.709127740639948 -24.189027786254883,0.1182146072387695,429.8190410137177,1709,0,429.8190410137177,0.3000804961842886,3581,0.7256882024923206,454.1902050971985,0.2724383217947824,0.7332886287144252,0.2982055114505663,3554,0.7087479968653277 -28.22706866264344,0.1440095901489257,509.8124516010285,2053,0,509.8124516010285,0.2967952673943207,3581,0.7322903622198758,538.2597918510437,0.2687742710113525,0.7409915242876325,0.2948922676253869,3554,0.7151846126283765 -32.26940202713013,0.1705169677734375,589.8710687160492,2396,0,589.8710687160492,0.2949343172407323,3581,0.7335974451750559,622.399489402771,0.26690559727805,0.7421709469386509,0.2932297551504467,3554,0.7165991719277575 -36.31421685218811,0.194258451461792,670.0183191299438,2739,0,670.0183191299438,0.2943898243332868,3581,0.7333299199551452,706.627876996994,0.2665479012898036,0.7418599809919085,0.2926901246834553,3554,0.7161163862461312 -40.35602593421936,0.2186872959136963,750.0226106643677,3083,0,750.0226106643677,0.2937812112756562,3581,0.7344748106325049,790.7108700275421,0.2664205006190708,0.7423411096845355,0.292220871838949,3554,0.7173274034318725 -44.39246344566345,0.2435402870178222,830.2108571529388,3425,0,830.2108571529388,0.293075310121998,3581,0.7348356015254119,874.9735074043274,0.2648028305598667,0.7441543170383998,0.2914524883384039,3554,0.7175930454681345 -48.43196511268616,0.2675139904022217,910.250020980835,3770,0,910.250020980835,0.2938836467096307,3581,0.7350305867774365,959.0884766578674,0.2662832736968994,0.7439109938485282,0.2921615883951182,3554,0.7177881381453995 -52.47136425971985,0.2925031185150146,990.3172655105592,4112,0,990.3172655105592,0.2923847487171879,3581,0.7360385105417481,1043.2327580451963,0.2645629644393921,0.7449028151375907,0.2907478877782956,3554,0.7190284192327308 -56.51176953315735,0.3179206848144531,1070.4398555755615,4455,0,1070.4398555755615,0.2916109095202981,3581,0.7361167773492041,1127.433172941208,0.2632608924593244,0.745528153010777,0.2899452256700197,3554,0.7188410203511888 -60.554898500442505,0.3427619934082031,1150.5182836055756,4800,0,1150.5182836055756,0.2917902141401668,3581,0.7369595772392488,1211.5923788547516,0.2641662529536656,0.7455744062151227,0.2902841648494654,3554,0.7195378584209693 -64.59666466712952,0.3678746223449707,1230.709332704544,5145,0,1230.709332704544,0.2903629016423485,3581,0.7380477449342013,1295.8631434440613,0.2624162094933646,0.7470813478742327,0.288812829459324,3554,0.7207821237953714 -68.63917136192322,0.395914077758789,1310.8590772151947,5487,0,1310.8590772151947,0.290981604845888,3581,0.737104043585067,1380.0955708026886,0.2628247056688581,0.7463224955967495,0.2894811249208638,3554,0.7197903797877392 -72.6823296546936,0.4211378097534179,1390.823716878891,5832,0,1390.823716878891,0.2907189883456088,3581,0.7379295266030788,1464.141278505325,0.2629751477922712,0.746668951851981,0.2891875585278032,3554,0.7207214664594471 -76.7279725074768,0.4464681148529053,1470.9608535766602,6176,0,1470.9608535766602,0.290686297636397,3581,0.7384506690039444,1548.361647605896,0.2630742447716849,0.7470532144818988,0.2891715870322172,3554,0.7212173728105655 -80.76854872703552,0.4717001914978027,1551.1067745685575,6517,0,1551.1067745685575,0.2903405056090303,3581,0.739124458950014,1632.5861072540283,0.2622644390378679,0.7485293660845075,0.2888351208585397,3554,0.7220091468240011 -84.80290365219116,0.4968986511230469,1631.1339271068573,6862,0,1631.1339271068573,0.2895691207719562,3581,0.7395610622905613,1716.6855008602142,0.2613810130528041,0.7489062036786761,0.2880435529289181,3554,0.7224099798422552 -88.84633111953735,0.5218157768249512,1711.1187720298767,7206,0,1711.1187720298767,0.289598198118106,3581,0.7398471315624128,1800.75151348114,0.26164140020098,0.7486821583339146,0.2881173137606834,3554,0.7226811174468908 -92.88967680931091,0.5519485473632812,1791.1784365177157,7549,0,1791.1784365177157,0.289402974247766,3581,0.7407856514852694,1884.8976328372955,0.2611466135297502,0.750164372580392,0.2879235778018605,3554,0.7235968165271525 -96.92876887321472,0.5830910205841064,1871.205404281616,7892,0,1871.205404281616,0.2894897290495846,3581,0.7394183003612818,1969.0074808597565,0.2615155833108084,0.748753547668457,0.2880484130721986,3554,0.7221031210431907 -100.9730007648468,0.609849214553833,1951.2375724315643,8237,0,1951.2375724315643,0.288724582377042,3581,0.740555282545902,2053.123046398163,0.2605032920837402,0.7499610355922154,0.287204929195097,3554,0.7234105167592854 -105.01745820045473,0.6357917785644531,2031.3805038928983,8578,0,2031.3805038928983,0.288887285979824,3581,0.7405435561601159,2137.3491065502167,0.2607489483697073,0.7500072887965611,0.2874911624391091,3554,0.7232730588553391 -109.05753564834596,0.6633491516113281,2111.464605331421,8924,0,2111.464605331421,0.2895903237136973,3581,0.7398159748280857,2221.5131690502167,0.2605907576424734,0.7499869891575405,0.2881504417337859,3554,0.7225968978615644 -113.09919476509094,0.690230131149292,2191.5632648468018,9268,0,2191.5632648468018,0.289247326929454,3581,0.739492612922368,2305.6930689811707,0.2612011773245675,0.7488860402788434,0.2878062989380364,3554,0.7222374876899268 -117.13886761665344,0.7173893451690674,2271.7214958667755,9612,0,2271.7214958667755,0.2886738248525202,3581,0.7413944008962231,2389.9308273792267,0.2603103603635515,0.7509616443089077,0.2872652087106957,3554,0.7241491898433103 -121.1830849647522,0.7429046630859375,2352.0011718273163,9956,0,2352.0011718273163,0.288429411520874,3581,0.741594158514556,2474.2932760715485,0.2596092053822109,0.7515387535095215,0.2869361443806714,3554,0.7244041155212436 -125.2218165397644,0.7710072994232178,2432.098387002945,10299,0,2432.098387002945,0.2881840437159662,3581,0.7400625698129014,2558.470093727112,0.2598560537610735,0.7496851512363979,0.2867206494058191,3554,0.7228082711601365 -129.26106452941897,0.798546552658081,2512.111428976059,10645,0,2512.111428976059,0.288655042182264,3581,0.7421747509642908,2642.562907934189,0.2603922741753714,0.7517651149204799,0.2872193894093803,3554,0.7249827301763858 -133.30129551887512,0.8306035995483398,2592.179505586624,10985,0,2592.179505586624,0.2879923309436959,3581,0.7408287391353672,2726.715437889099,0.2591175522123064,0.7508465221949986,0.2865561430012838,3554,0.7235672091525394 -137.34243512153623,0.8594722747802734,2672.226739883423,11330,0,2672.226739883423,0.2880259761261693,3581,0.7414312844701201,2810.844916820526,0.2594335079193115,0.7511290822710309,0.2865808902326164,3554,0.7241735077333639 -141.38680863380432,0.8880980014801025,2752.375296831131,11674,0,2752.375296831131,0.2882127801788083,3581,0.7415087331576375,2895.079016685486,0.2600187574114118,0.750849860055106,0.2867924696150025,3554,0.7243757446495146 -145.4340295791626,0.9185914993286132,2832.403507232666,12016,0,2832.403507232666,0.288554652048136,3581,0.740762539597005,2979.1975667476654,0.2597970792225429,0.7508232252938407,0.2870708029794773,3554,0.7236260117341375 -149.47599267959595,0.9461815357208252,2912.47318983078,12361,0,2912.47318983078,0.2883503606818102,3581,0.7418716375270525,3063.34952712059,0.2597310883658273,0.7515483583722796,0.2868840910439821,3554,0.7247749976918613 -153.514493227005,0.9745748043060304,2992.63488984108,12706,0,2992.63488984108,0.2879713666202702,3581,0.7418361174863864,3147.590842962265,0.2595241921288626,0.75163391658238,0.2866389199999121,3554,0.7246345172253095 -157.55807256698608,1.0025804042816162,3072.810255050659,13047,0,3072.810255050659,0.2905583982193346,3581,0.7402398291329237,3231.8504090309143,0.2622406823294503,0.7500644411359515,0.289036189966411,3554,0.7231141682347355 -161.60008311271667,1.0305194854736328,3152.822122335434,13393,0,3152.822122335434,0.2877308734466629,3581,0.7413110208391511,3315.9444572925568,0.2590423311505999,0.7512751306806292,0.2862096474153243,3554,0.7241309857730726 -165.63811612129211,1.0580720901489258,3232.913407802582,13736,0,3232.913407802582,0.2882589698670064,3581,0.7411356022889906,3400.114282131195,0.2599429743630545,0.7505890301295689,0.2868714684103738,3554,0.7239430373346933 -169.68113112449646,1.0890719890594482,3313.1099050045013,14081,0,3313.1099050045013,0.2879141664012322,3581,0.7414687134573094,3484.3979263305664,0.2593295063291277,0.7513887541634696,0.2864803213313783,3554,0.7242169914181205 -173.728289604187,1.1212973594665527,3393.1668939590454,14426,0,3393.1668939590454,0.2881177760009424,3581,0.7418051652820441,3568.5471363067627,0.2591852971485683,0.7521964481898716,0.2866273964799785,3554,0.7246845268975098 -177.77133631706238,1.1494691371917725,3473.188450574875,14769,0,3473.188450574875,0.2876715938394478,3581,0.7413695164147585,3652.652163267136,0.2590839522225516,0.7512100083487374,0.2862400447778911,3554,0.7241107208646947 -181.81732654571533,1.1774797439575195,3553.326792240143,15112,0,3553.326792240143,0.2875662608973576,3581,0.7413103390725356,3736.877160310745,0.2589912414550781,0.751319340297154,0.2862024860027873,3554,0.7240003286349888 -185.85540199279785,1.2082180976867676,3633.5277137756334,15457,0,3633.5277137756334,0.2873466638705145,3581,0.7416795156948129,3821.159016132354,0.2582509858267648,0.7521231515066964,0.2859320696915007,3554,0.7244947237048045 -189.84543323516849,1.2363996505737305,3713.646933078766,15801,0,3713.646933078766,0.2873067805235095,3581,0.7415334131091176,3905.309055566788,0.2586957727159772,0.7514749254499163,0.2859238778598937,3554,0.7243029283685636 -193.8857979774475,1.270770788192749,3793.65927529335,16143,0,3793.65927529335,0.2874490311278274,3581,0.74123561745148,3989.408354043961,0.2588280779974801,0.7512297630310059,0.2860556512853035,3554,0.7240081598199212 -197.9326922893524,1.299825668334961,3873.663761138916,16488,0,3873.663761138916,0.2872868047616762,3581,0.7423863713217328,4073.501513719559,0.2580648149762835,0.7530438559395927,0.2858771827023688,3554,0.7251870966252814 -201.97663283348083,1.3333888053894043,3953.843817949295,16833,0,3953.843817949295,0.2872385356853009,3581,0.7419785385323583,4157.771581888199,0.258419258253915,0.7522469248090472,0.2858986154190261,3554,0.7247790506735369 -206.0151960849762,1.3643596172332764,4033.999075651169,17177,0,4033.999075651169,0.2873956828901668,3581,0.7413751069010053,4242.009331464768,0.2587161575044904,0.7513459750584194,0.2860677587093768,3554,0.724141770826006 -210.06202054023743,1.3937888145446775,4114.166547060013,17522,0,4114.166547060013,0.2876077123075782,3581,0.742186750056723,4326.265980958939,0.2583299194063459,0.7527370452880859,0.2861687741255451,3554,0.7250032698631823 -214.1040050983429,1.4217588901519775,4194.325728654861,17869,0,4194.325728654861,0.2874424179916574,3581,0.7423210580799707,4410.5075969696045,0.2584952797208513,0.752519062587193,0.2860077539721968,3554,0.725203858108821 -218.1460461616516,1.4510364532470703,4274.364299535751,18213,0,4274.364299535751,0.2874423498149958,3581,0.7415589111805362,4494.629780292511,0.2586789812360491,0.7517549651009696,0.2860423760529509,3554,0.7244219074238534 -222.18551421165463,4.549318790435791,4351.388922452927,18541,0,4351.388922452927,0.290273419774068,3581,0.7355785226062902,4578.803955078125,0.2617180688040597,0.7444351060049874,0.2888652090953855,3554,0.7190021091991418 -226.23116660118103,4.578796863555908,4431.485432386398,18885,0,4431.485432386398,0.2871635072692683,3581,0.7424356630480313,4662.9877717494965,0.2581279448100498,0.7527740342276437,0.2857613292515915,3554,0.7252524251943233 -230.27187418937683,4.607800006866455,4511.616227388382,19230,0,4511.616227388382,0.2871925846154182,3581,0.7422182476743577,4747.201381921768,0.2582859822681972,0.752542359488351,0.2858366013646595,3554,0.7250183826762099 -234.3103144168853,4.636163949966431,4591.753699779511,19570,0,4591.753699779511,0.2871527694450747,3581,0.742229496823513,4831.418623209,0.257946355002267,0.7528716496058873,0.2857356202957934,3554,0.7250105514912775 -238.3534197807312,4.670209646224976,4671.819234609604,19916,0,4671.819234609604,0.2893246733519792,3581,0.7400381625680675,4915.573646783829,0.2603518622262137,0.7501134191240583,0.2876773763387205,3554,0.723229231697559 -242.39756178855896,4.698517560958862,4751.896664619446,20260,0,4751.896664619446,0.2871681432822535,3581,0.7419907421547752,4999.735841989517,0.2582253898893084,0.7525191988263812,0.2858264689104706,3554,0.7247402382218978 -246.3901126384735,4.727694749832153,4831.882543325424,20601,0,4831.882543325424,0.2874492697461428,3581,0.742039692997766,5083.755587100983,0.2585275002888271,0.7524471964154925,0.2860808793788689,3554,0.7248305029324353 -250.43258547782887,4.757020711898804,4911.99870967865,20944,0,4911.99870967865,0.2872574887972109,3581,0.7425531314358769,5167.955820322037,0.2577303307397025,0.7533929007393974,0.2857709636698965,3554,0.7254286955499085 -254.47485995292664,4.789958953857422,4992.026699781418,21288,0,4992.026699781418,0.2872337633189926,3581,0.7425597445720469,5252.0716252326965,0.2579449755804879,0.7532256671360561,0.2858033360023477,3554,0.7253845249191052 -258.51404643058777,4.820429563522339,5071.992090940476,21630,0,5071.992090940476,0.2870563335573164,3581,0.7420083999101159,5336.118559122086,0.25794517993927,0.7525844573974609,0.2856705149842167,3554,0.7247651056687887 -262.5574164390564,4.8497161865234375,5151.954606056213,21973,0,5151.954606056213,0.2870326762557595,3581,0.7425326102607512,5420.165975809097,0.2574006829942976,0.7536664690290179,0.2856152501747591,3554,0.7253517575926772 -266.60148763656616,4.884214401245117,5231.919135332108,22316,0,5231.919135332108,0.2870808430671425,3581,0.7421553206157497,5504.221770524979,0.2576997109821864,0.7529209681919643,0.2856670974276343,3554,0.724979776308385 -270.6449613571167,4.919068574905396,5311.92449259758,22658,0,5311.92449259758,0.2869420012959019,3581,0.7423250805030019,5588.317686080933,0.2577470200402396,0.7528862953186035,0.2855707704182347,3554,0.7251463607247116 -274.68395590782166,4.949132919311523,5391.978044509888,23000,0,5391.978044509888,0.2871562123664828,3581,0.7428951737468584,5672.452279090881,0.2572645630155291,0.7542593819754464,0.2856827769711504,3554,0.7257499115213492 -278.72428584098816,4.980362415313721,5472.120784044266,23344,0,5472.120784044266,0.2870301878076131,3581,0.7426551918982128,5756.679520845413,0.2575033562523978,0.7536674227033343,0.2856114376241998,3554,0.725452189104706 -282.7625916004181,5.010173559188843,5552.164082050324,23689,0,5552.164082050324,0.2871937095303337,3581,0.7427696605129503,5840.803641319275,0.2578081062861851,0.7536218506949288,0.2858345233528682,3554,0.7255891661464196 -286.80222272872925,5.042935132980347,5632.201220989227,24030,0,5632.201220989227,0.2870800931238655,3581,0.7426042639320372,5924.926100969315,0.2571692637034825,0.7539426939828056,0.285618015132597,3554,0.7254523264939153 -290.84347701072693,5.074441194534302,5712.229027509689,24375,0,5712.229027509689,0.2871609165561296,3581,0.7425559266790003,6009.039095878601,0.2576377051217215,0.7535826819283622,0.285717347530951,3554,0.7253617870049592 -294.8886339664459,5.111849069595337,5792.354554653168,24718,0,5792.354554653168,0.2869510687918877,3581,0.7423800308922088,6093.25949716568,0.2575099127633231,0.7533001899719238,0.2855637120476048,3554,0.7252187648380346 -298.93243408203125,5.142282724380493,5872.473770856857,25057,0,5872.473770856857,0.2870945806644442,3581,0.7427837730818906,6177.465390920639,0.2570937190737043,0.7541554995945522,0.285622823754924,3554,0.7256698136123031 -302.9748706817627,5.173599481582642,5952.507550954819,25403,0,5952.507550954819,0.2869819187312378,3581,0.7426374659662106,6261.585339069367,0.257313711302621,0.7537097930908203,0.2855302062541766,3554,0.7254681949475943 -307.0195956230164,5.20525050163269,6032.525489091873,25746,0,6032.525489091873,0.2870435163449455,3581,0.7428361327579587,6345.692269086838,0.257412246295384,0.7538211005074638,0.2856286112753675,3554,0.7256960549512873 -311.0550262928009,5.23749852180481,6112.56943154335,26088,0,6112.56943154335,0.2870197908667272,3581,0.7425881060632505,6429.815916538239,0.2571394102913992,0.7538996423993792,0.285577863136167,3554,0.725426497322559 -315.09079217910767,5.267976999282837,6192.580571889877,26433,0,6192.580571889877,0.2871216467990785,3581,0.7429105134957065,6513.905717372894,0.2572294303349086,0.7542871747698102,0.2856593177636554,3554,0.7257701077351224 -319.1351087093353,5.299068212509155,6272.620399475098,26776,0,6272.620399475098,0.287038334918668,3581,0.7429086727258447,6598.033495903015,0.2573316778455461,0.7540904453822544,0.2856416460766038,3554,0.7257306770320414 -323.1779954433441,5.3297858238220215,6352.729802370071,27120,0,6352.729802370071,0.2878444898531311,3581,0.7427011429680955,6682.2287175655365,0.2578939199447632,0.7539983476911273,0.2862922011564874,3554,0.7256790186893289 -327.2229354381561,5.361184120178223,6432.850472688675,27465,0,6432.850472688675,0.2871871645708252,3581,0.7416571537498254,6766.438841342926,0.2573528460093907,0.7527961730957031,0.285726260655907,3554,0.7244910828907569 -331.2638852596283,5.393904447555542,6513.029781341553,27807,0,6513.029781341553,0.2869928951737469,3581,0.7428482000270525,6850.704304218292,0.2571893760136196,0.754103592463902,0.285555331305835,3554,0.7257026496333356 -335.303875207901,5.425286054611206,6593.055602550507,28147,0,6593.055602550507,0.2870455957331227,3581,0.7427993855373848,6934.813687324524,0.257232666015625,0.7540349960327148,0.2855900736021472,3554,0.7256720118396525 -339.34671092033386,5.456988573074341,6673.217515468597,28493,0,6673.217515468597,0.2869222641523841,3581,0.7428325193948967,7019.063020467758,0.2567970412118094,0.7544220515659877,0.2854602236256682,3554,0.7256915898019837 -343.3887577056885,5.488369941711426,6753.32670044899,28837,0,6753.32670044899,0.2871064434035534,3581,0.7426146949612539,7103.257844209671,0.2571344716208322,0.7540216445922852,0.2856393791546497,3554,0.7254533569129854 -347.432626247406,5.521002531051636,6833.4487562179565,29179,0,6833.4487562179565,0.287044641259861,3581,0.742586060763404,7187.468809604645,0.2571410962513515,0.7539787292480469,0.2856076078999894,3554,0.7254288329391179 -351.4746322631836,5.552271842956543,6913.447591543198,29523,0,6913.447591543198,0.2870695257413257,3581,0.7428071576768012,7271.553107976913,0.256743107523237,0.7546099935259137,0.2855766266332829,3554,0.7256827968925859 -355.5142960548401,5.584680080413818,6993.489233732224,29867,0,6993.489233732224,0.2870715710411721,3581,0.7425986734457903,7355.679824829102,0.257029788834708,0.7541588374546596,0.2856062511815472,3554,0.7254314433340954 -359.5573136806488,5.617060422897339,7073.679226160049,30211,0,7073.679226160049,0.2870932512195441,3581,0.7426834170360933,7439.95772767067,0.2571203368050711,0.7540692601885114,0.2856606573084465,3554,0.7255136707758864 -363.5970644950866,5.648396730422974,7153.867946386337,30552,0,7153.867946386337,0.2871223285656939,3581,0.7429972342091944,7524.22979927063,0.2567461047853742,0.7548450742449079,0.2856448747230233,3554,0.7258831103598059 -367.6330301761627,5.680180311203003,7233.912236690521,30896,0,7233.912236690521,0.2870102120457798,3581,0.7427162100102974,7608.354736804962,0.25690986428942,0.7542660576956612,0.2855315114516654,3554,0.7255811975722777 -371.6770513057709,5.712636709213257,7314.072935581207,31241,0,7314.072935581207,0.2871228058023247,3581,0.742731617935807,7692.604668617248,0.2570479597364153,0.754291330065046,0.2856827082765458,3554,0.7256142396771244 -375.7203893661499,5.745257377624512,7394.21529507637,31581,0,7394.21529507637,0.2870380963003525,3581,0.742917058455215,7776.835196018219,0.2566973481859479,0.7546959604535785,0.2855513126714617,3554,0.725824445167417 -379.7598538398743,5.78111457824707,7474.3756783008575,31928,0,7474.3756783008575,0.2869768054816217,3581,0.7428276788519268,7861.083312034607,0.2567943504878452,0.7544897624424526,0.2854953780896085,3554,0.7257194111168753 -383.8014781475067,5.81369686126709,7554.379847288132,32273,0,7554.379847288132,0.287051424837685,3581,0.7427074833976194,7945.174165248871,0.2569260426930019,0.754249095916748,0.285591790967264,3554,0.7255921887090251 -387.8406648635864,5.852944374084473,7634.347116231918,32612,0,7634.347116231918,0.2869986561016476,3581,0.7427035973279112,8029.232703924179,0.2567449126924787,0.7544171469552177,0.2855080694178215,3554,0.7255723359682752 -391.8827121257782,5.8849639892578125,7714.379668951035,32956,0,7714.379668951035,0.2869958608585241,3581,0.7428032034304315,8113.352165699005,0.2567299774714878,0.7545514106750488,0.2855084128908448,3554,0.7256814916950971 -395.9238955974579,5.918242931365967,7794.549304008484,33299,0,7794.549304008484,0.2869860775075921,3581,0.7429174675151843,8197.608664512634,0.256818277495248,0.754565920148577,0.285514681273521,3554,0.7257891361406162 -399.9638068675995,5.955785274505615,7874.795213460922,33641,0,7874.795213460922,0.2871052844003072,3581,0.7427757964124895,8281.944691896439,0.2568733351571219,0.7544005257742745,0.2856053753253376,3554,0.725658547697137 -404.01122069358826,5.987571239471436,7954.826099395752,33984,0,7954.826099395752,0.2870194158950886,3581,0.7429784856272689,8366.067419290543,0.2566869599478585,0.7547591754368373,0.2855236115721282,3554,0.7258627767568233 -408.0551481246948,6.024341106414795,8034.850677967071,34329,0,8034.850677967071,0.2869331042415701,3581,0.7429921891362399,8450.185488700867,0.2567040579659598,0.7546742303030831,0.2854583688713421,3554,0.7258765843723621 -412.0980794429779,6.058527708053589,8114.963846206665,34672,0,8114.963846206665,0.2869432284758098,3581,0.7429690772479755,8534.388310909271,0.2567169666290283,0.7546633311680385,0.2854794924622784,3554,0.7258422370700267 -416.14691257476807,6.092074155807495,8195.090363264084,35015,0,8195.090363264084,0.2869227413890149,3581,0.7429394204002024,8618.609902381897,0.2566088608333042,0.7547167369297573,0.2854470342615715,3554,0.7258234147483469 -420.1879172325134,6.129887580871582,8275.109664440155,35360,0,8275.109664440155,0.28690225430222,3581,0.7429039003595365,8702.720784902573,0.2566218546458653,0.7546516145978656,0.2854279715087753,3554,0.725779999758195 -424.1739959716797,6.164317607879639,8355.086299180984,35701,0,8355.086299180984,0.2869127194197675,3581,0.7430810915028973,8786.73073220253,0.2566397019795009,0.7548180307660785,0.2854459866688502,3554,0.7259626587120146 -428.2131233215332,6.202942609786987,8435.13413977623,36044,0,8435.13413977623,0.2869018793305815,3581,0.742981485400377,8870.86947607994,0.2566244431904384,0.7547200066702706,0.2854308395085203,3554,0.7258568690208216 -432.25509095191956,6.2368481159210205,8467.644447088242,36189,0,8467.644447088242,0.28690365192378176,3581,0.7429876212999162,8907.461469173431,0.25662572043282644,0.7547260693141392,0.2854322477479161,3554,0.7258640132597074 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/measurements.csv deleted file mode 100644 index a056de353..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.3851867,0.90665424,,,,,,,,,,,,,, -1,,,0.2670762368610927,0.9032635007585798,0.2587972028106535,0.915261701713738,3554.0,0.2829431162570249,0.912513662602974,3581.0,29.42645001411438,33.40965127944946,29.42645001411438,3.983088254928589,0.0,0.0 -100,0.42672008,0.26901054,,,,,,,,,,,,,, -200,0.1256191,0.33167428,,,,,,,,,,,,,, -300,0.14189221,0.35022697,,,,,,,,,,,,,, -342,,,0.700493403843471,0.3026022570473807,0.674353982198755,0.3304802975630451,3554.0,0.6928078258080843,0.332435537602974,3581.0,109.55331134796144,117.60923790931702,109.55331134796144,8.024477005004883,0.0183889865875244,0.0 -400,0.1247513,0.33979893,,,,,,,,,,,,,, -500,0.2162395,0.27751946,,,,,,,,,,,,,, -600,0.15152188,0.2754808,,,,,,,,,,,,,, -681,,,0.7177139009748187,0.2885269437517438,0.6936591582547833,0.3146172364105585,3554.0,0.7106981999179698,0.3170286688315938,3581.0,189.61023998260487,201.7445123195648,189.61023998260487,12.066364765167236,0.0422174930572509,0.0 -700,0.25316006,0.2746179,,,,,,,,,,,,,, -800,0.1336233,0.2290262,,,,,,,,,,,,,, -900,0.15606727,0.33747345,,,,,,,,,,,,,, -1000,0.101702146,0.23241895,,,,,,,,,,,,,, -1020,,,0.7317164966038295,0.2770050082887922,0.7062299274365152,0.3035240881752954,3554.0,0.7232667720041539,0.3058018134446558,3581.0,269.74872303009033,285.9569094181061,269.74872303009033,16.104042291641235,0.0660929679870605,0.0 -1100,0.4482997,0.36047277,,,,,,,,,,,,,, -1200,0.17211345,0.22778618,,,,,,,,,,,,,, -1300,0.3435784,0.31750372,,,,,,,,,,,,,, -1365,,,0.7347615105765206,0.273283737046378,0.709127740639948,0.2994085944094682,3554.0,0.7263307675274016,0.3015255005257784,3581.0,349.75937604904175,370.048823595047,349.75937604904175,20.1486337184906,0.0897228717803955,0.0 -1400,0.47753602,0.30047232,,,,,,,,,,,,,, -1500,0.44040805,0.4055301,,,,,,,,,,,,,, -1600,0.37284696,0.2334345,,,,,,,,,,,,,, -1700,0.14307779,0.33132493,,,,,,,,,,,,,, -1709,,,0.7332886287144252,0.2724383217947824,0.7087479968653277,0.2982055114505663,3554.0,0.7256882024923206,0.3000804961842886,3581.0,429.8190410137177,454.1902050971985,429.8190410137177,24.189027786254883,0.1182146072387695,0.0 -1800,0.26820272,0.2675001,,,,,,,,,,,,,, -1900,0.09804658,0.2604441,,,,,,,,,,,,,, -2000,0.12076444,0.24474087,,,,,,,,,,,,,, -2053,,,0.7409915242876325,0.2687742710113525,0.7151846126283765,0.2948922676253869,3554.0,0.7322903622198758,0.2967952673943207,3581.0,509.8124516010285,538.2597918510437,509.8124516010285,28.22706866264344,0.1440095901489257,0.0 -2100,0.090840556,0.33141315,,,,,,,,,,,,,, -2200,0.09831769,0.254164,,,,,,,,,,,,,, -2300,0.060459018,0.24946645,,,,,,,,,,,,,, -2396,,,0.7421709469386509,0.26690559727805,0.7165991719277575,0.2932297551504467,3554.0,0.7335974451750559,0.2949343172407323,3581.0,589.8710687160492,622.399489402771,589.8710687160492,32.26940202713013,0.1705169677734375,0.0 -2400,0.1042413,0.35495022,,,,,,,,,,,,,, -2500,0.16816491,0.2737429,,,,,,,,,,,,,, -2600,0.14682467,0.2788069,,,,,,,,,,,,,, -2700,0.07608004,0.26123884,,,,,,,,,,,,,, -2739,,,0.7418599809919085,0.2665479012898036,0.7161163862461312,0.2926901246834553,3554.0,0.7333299199551452,0.2943898243332868,3581.0,670.0183191299438,706.627876996994,670.0183191299438,36.31421685218811,0.194258451461792,0.0 -2800,0.09569044,0.31437445,,,,,,,,,,,,,, -2900,0.103505686,0.23616901,,,,,,,,,,,,,, -3000,0.07655918,0.2672699,,,,,,,,,,,,,, -3083,,,0.7423411096845355,0.2664205006190708,0.7173274034318725,0.292220871838949,3554.0,0.7344748106325049,0.2937812112756562,3581.0,750.0226106643677,790.7108700275421,750.0226106643677,40.35602593421936,0.2186872959136963,0.0 -3100,0.317532,0.26104477,,,,,,,,,,,,,, -3200,0.2468407,0.3743924,,,,,,,,,,,,,, -3300,0.07888398,0.30617166,,,,,,,,,,,,,, -3400,0.13401543,0.28603214,,,,,,,,,,,,,, -3425,,,0.7441543170383998,0.2648028305598667,0.7175930454681345,0.2914524883384039,3554.0,0.7348356015254119,0.293075310121998,3581.0,830.2108571529388,874.9735074043274,830.2108571529388,44.39246344566345,0.2435402870178222,0.0 -3500,0.06589158,0.26100755,,,,,,,,,,,,,, -3600,0.13384627,0.29226673,,,,,,,,,,,,,, -3700,0.37005246,0.28552037,,,,,,,,,,,,,, -3770,,,0.7439109938485282,0.2662832736968994,0.7177881381453995,0.2921615883951182,3554.0,0.7350305867774365,0.2938836467096307,3581.0,910.250020980835,959.0884766578674,910.250020980835,48.43196511268616,0.2675139904022217,0.0 -3800,0.08493081,0.27160507,,,,,,,,,,,,,, -3900,0.15059938,0.26701814,,,,,,,,,,,,,, -4000,0.30463699,0.29844332,,,,,,,,,,,,,, -4100,0.08640484,0.26895082,,,,,,,,,,,,,, -4112,,,0.7449028151375907,0.2645629644393921,0.7190284192327308,0.2907478877782956,3554.0,0.7360385105417481,0.2923847487171879,3581.0,990.3172655105592,1043.2327580451963,990.3172655105592,52.47136425971985,0.2925031185150146,0.0 -4200,0.19368485,0.24502501,,,,,,,,,,,,,, -4300,0.16803773,0.28771877,,,,,,,,,,,,,, -4400,0.069387354,0.2941994,,,,,,,,,,,,,, -4455,,,0.745528153010777,0.2632608924593244,0.7188410203511888,0.2899452256700197,3554.0,0.7361167773492041,0.2916109095202981,3581.0,1070.4398555755615,1127.433172941208,1070.4398555755615,56.51176953315735,0.3179206848144531,0.0 -4500,0.072409935,0.20194462,,,,,,,,,,,,,, -4600,0.14151536,0.20685047,,,,,,,,,,,,,, -4700,0.15375392,0.22029787,,,,,,,,,,,,,, -4800,,,0.7455744062151227,0.2641662529536656,0.7195378584209693,0.2902841648494654,3554.0,0.7369595772392488,0.2917902141401668,3581.0,1150.5182836055756,1211.5923788547516,1150.5182836055756,60.554898500442505,0.3427619934082031,0.0 -4800,0.23907174,0.31470734,,,,,,,,,,,,,, -4900,0.050423756,0.28902793,,,,,,,,,,,,,, -5000,0.2385516,0.36911434,,,,,,,,,,,,,, -5100,0.08031695,0.32920024,,,,,,,,,,,,,, -5145,,,0.7470813478742327,0.2624162094933646,0.7207821237953714,0.288812829459324,3554.0,0.7380477449342013,0.2903629016423485,3581.0,1230.709332704544,1295.8631434440613,1230.709332704544,64.59666466712952,0.3678746223449707,0.0 -5200,0.10497945,0.23799941,,,,,,,,,,,,,, -5300,0.18251023,0.22133617,,,,,,,,,,,,,, -5400,0.09163169,0.342787,,,,,,,,,,,,,, -5487,,,0.7463224955967495,0.2628247056688581,0.7197903797877392,0.2894811249208638,3554.0,0.737104043585067,0.290981604845888,3581.0,1310.8590772151947,1380.0955708026886,1310.8590772151947,68.63917136192322,0.395914077758789,0.0 -5500,0.13105299,0.28052956,,,,,,,,,,,,,, -5600,0.036234405,0.25570554,,,,,,,,,,,,,, -5700,0.10873247,0.2960818,,,,,,,,,,,,,, -5800,0.08718821,0.29396152,,,,,,,,,,,,,, -5832,,,0.746668951851981,0.2629751477922712,0.7207214664594471,0.2891875585278032,3554.0,0.7379295266030788,0.2907189883456088,3581.0,1390.823716878891,1464.141278505325,1390.823716878891,72.6823296546936,0.4211378097534179,0.0 -5900,0.19880421,0.30101907,,,,,,,,,,,,,, -6000,0.1011383,0.21136618,,,,,,,,,,,,,, -6100,0.11642525,0.32202733,,,,,,,,,,,,,, -6176,,,0.7470532144818988,0.2630742447716849,0.7212173728105655,0.2891715870322172,3554.0,0.7384506690039444,0.290686297636397,3581.0,1470.9608535766602,1548.361647605896,1470.9608535766602,76.7279725074768,0.4464681148529053,0.0 -6200,0.10742293,0.29438654,,,,,,,,,,,,,, -6300,0.20864017,0.23873037,,,,,,,,,,,,,, -6400,0.10908719,0.23276374,,,,,,,,,,,,,, -6500,0.11626022,0.25663674,,,,,,,,,,,,,, -6517,,,0.7485293660845075,0.2622644390378679,0.7220091468240011,0.2888351208585397,3554.0,0.739124458950014,0.2903405056090303,3581.0,1551.1067745685575,1632.5861072540283,1551.1067745685575,80.76854872703552,0.4717001914978027,0.0 -6600,0.13942082,0.30169484,,,,,,,,,,,,,, -6700,0.044536926,0.2889076,,,,,,,,,,,,,, -6800,0.10118046,0.30786973,,,,,,,,,,,,,, -6862,,,0.7489062036786761,0.2613810130528041,0.7224099798422552,0.2880435529289181,3554.0,0.7395610622905613,0.2895691207719562,3581.0,1631.1339271068573,1716.6855008602142,1631.1339271068573,84.80290365219116,0.4968986511230469,0.0 -6900,0.03252965,0.35139358,,,,,,,,,,,,,, -7000,0.2205683,0.24419218,,,,,,,,,,,,,, -7100,0.0852381,0.24613202,,,,,,,,,,,,,, -7200,0.116363175,0.27529636,,,,,,,,,,,,,, -7206,,,0.7486821583339146,0.26164140020098,0.7226811174468908,0.2881173137606834,3554.0,0.7398471315624128,0.289598198118106,3581.0,1711.1187720298767,1800.75151348114,1711.1187720298767,88.84633111953735,0.5218157768249512,0.0 -7300,0.11946407,0.27699828,,,,,,,,,,,,,, -7400,0.19457991,0.26036376,,,,,,,,,,,,,, -7500,0.067360304,0.20437291,,,,,,,,,,,,,, -7549,,,0.750164372580392,0.2611466135297502,0.7235968165271525,0.2879235778018605,3554.0,0.7407856514852694,0.289402974247766,3581.0,1791.1784365177157,1884.8976328372955,1791.1784365177157,92.88967680931091,0.5519485473632812,0.0 -7600,0.08609896,0.2630054,,,,,,,,,,,,,, -7700,0.37145895,0.26786014,,,,,,,,,,,,,, -7800,0.055776007,0.2800464,,,,,,,,,,,,,, -7892,,,0.748753547668457,0.2615155833108084,0.7221031210431907,0.2880484130721986,3554.0,0.7394183003612818,0.2894897290495846,3581.0,1871.205404281616,1969.0074808597565,1871.205404281616,96.92876887321472,0.5830910205841064,0.0 -7900,0.3748164,0.26894435,,,,,,,,,,,,,, -8000,0.07799697,0.29956,,,,,,,,,,,,,, -8100,0.23378001,0.22035676,,,,,,,,,,,,,, -8200,0.02860266,0.27461645,,,,,,,,,,,,,, -8237,,,0.7499610355922154,0.2605032920837402,0.7234105167592854,0.287204929195097,3554.0,0.740555282545902,0.288724582377042,3581.0,1951.2375724315643,2053.123046398163,1951.2375724315643,100.9730007648468,0.609849214553833,0.0 -8300,0.12162626,0.2485451,,,,,,,,,,,,,, -8400,0.05929264,0.26824602,,,,,,,,,,,,,, -8500,0.1378193,0.35173002,,,,,,,,,,,,,, -8578,,,0.7500072887965611,0.2607489483697073,0.7232730588553391,0.2874911624391091,3554.0,0.7405435561601159,0.288887285979824,3581.0,2031.3805038928983,2137.3491065502167,2031.3805038928983,105.01745820045473,0.6357917785644531,0.0 -8600,0.069255434,0.34389704,,,,,,,,,,,,,, -8700,0.39858446,0.20751132,,,,,,,,,,,,,, -8800,0.05965144,0.20334047,,,,,,,,,,,,,, -8900,0.06179647,0.23277323,,,,,,,,,,,,,, -8924,,,0.7499869891575405,0.2605907576424734,0.7225968978615644,0.2881504417337859,3554.0,0.7398159748280857,0.2895903237136973,3581.0,2111.464605331421,2221.5131690502167,2111.464605331421,109.05753564834596,0.6633491516113281,0.0 -9000,0.10013373,0.19800866,,,,,,,,,,,,,, -9100,0.10209796,0.3030777,,,,,,,,,,,,,, -9200,0.31683806,0.21172354,,,,,,,,,,,,,, -9268,,,0.7488860402788434,0.2612011773245675,0.7222374876899268,0.2878062989380364,3554.0,0.739492612922368,0.289247326929454,3581.0,2191.5632648468018,2305.6930689811707,2191.5632648468018,113.09919476509094,0.690230131149292,0.0 -9300,0.025633462,0.31837863,,,,,,,,,,,,,, -9400,0.088182196,0.23064119,,,,,,,,,,,,,, -9500,0.0897606,0.26884016,,,,,,,,,,,,,, -9600,0.09734171,0.27047884,,,,,,,,,,,,,, -9612,,,0.7509616443089077,0.2603103603635515,0.7241491898433103,0.2872652087106957,3554.0,0.7413944008962231,0.2886738248525202,3581.0,2271.7214958667755,2389.9308273792267,2271.7214958667755,117.13886761665344,0.7173893451690674,0.0 -9700,0.052722972,0.24734215,,,,,,,,,,,,,, -9800,0.2876609,0.26385552,,,,,,,,,,,,,, -9900,0.056648105,0.25250542,,,,,,,,,,,,,, -9956,,,0.7515387535095215,0.2596092053822109,0.7244041155212436,0.2869361443806714,3554.0,0.741594158514556,0.288429411520874,3581.0,2352.0011718273163,2474.2932760715485,2352.0011718273163,121.1830849647522,0.7429046630859375,0.0 -10000,0.063423805,0.2491127,,,,,,,,,,,,,, -10100,0.09591902,0.2862453,,,,,,,,,,,,,, -10200,0.07503651,0.3518268,,,,,,,,,,,,,, -10299,,,0.7496851512363979,0.2598560537610735,0.7228082711601365,0.2867206494058191,3554.0,0.7400625698129014,0.2881840437159662,3581.0,2432.098387002945,2558.470093727112,2432.098387002945,125.2218165397644,0.7710072994232178,0.0 -10300,0.07012811,0.24426657,,,,,,,,,,,,,, -10400,0.10046528,0.30725747,,,,,,,,,,,,,, -10500,0.091835685,0.2147493,,,,,,,,,,,,,, -10600,0.04553039,0.2543349,,,,,,,,,,,,,, -10645,,,0.7517651149204799,0.2603922741753714,0.7249827301763858,0.2872193894093803,3554.0,0.7421747509642908,0.288655042182264,3581.0,2512.111428976059,2642.562907934189,2512.111428976059,129.26106452941897,0.798546552658081,0.0 -10700,0.08314729,0.3158393,,,,,,,,,,,,,, -10800,0.05942162,0.19556427,,,,,,,,,,,,,, -10900,0.085142426,0.2564674,,,,,,,,,,,,,, -10985,,,0.7508465221949986,0.2591175522123064,0.7235672091525394,0.2865561430012838,3554.0,0.7408287391353672,0.2879923309436959,3581.0,2592.179505586624,2726.715437889099,2592.179505586624,133.30129551887512,0.8306035995483398,0.0 -11000,0.0640348,0.33205807,,,,,,,,,,,,,, -11100,0.22913156,0.2510247,,,,,,,,,,,,,, -11200,0.05139161,0.24601564,,,,,,,,,,,,,, -11300,0.18360919,0.30085632,,,,,,,,,,,,,, -11330,,,0.7511290822710309,0.2594335079193115,0.7241735077333639,0.2865808902326164,3554.0,0.7414312844701201,0.2880259761261693,3581.0,2672.226739883423,2810.844916820526,2672.226739883423,137.34243512153623,0.8594722747802734,0.0 -11400,0.041963313,0.25915843,,,,,,,,,,,,,, -11500,0.0358891,0.289709,,,,,,,,,,,,,, -11600,0.3348896,0.24704856,,,,,,,,,,,,,, -11674,,,0.750849860055106,0.2600187574114118,0.7243757446495146,0.2867924696150025,3554.0,0.7415087331576375,0.2882127801788083,3581.0,2752.375296831131,2895.079016685486,2752.375296831131,141.38680863380432,0.8880980014801025,0.0 -11700,0.049481772,0.2410556,,,,,,,,,,,,,, -11800,0.08552632,0.2785782,,,,,,,,,,,,,, -11900,0.13968958,0.34959757,,,,,,,,,,,,,, -12000,0.05328108,0.3241593,,,,,,,,,,,,,, -12016,,,0.7508232252938407,0.2597970792225429,0.7236260117341375,0.2870708029794773,3554.0,0.740762539597005,0.288554652048136,3581.0,2832.403507232666,2979.1975667476654,2832.403507232666,145.4340295791626,0.9185914993286132,0.0 -12100,0.117972195,0.21890107,,,,,,,,,,,,,, -12200,0.12834366,0.27058175,,,,,,,,,,,,,, -12300,0.060977746,0.27931237,,,,,,,,,,,,,, -12361,,,0.7515483583722796,0.2597310883658273,0.7247749976918613,0.2868840910439821,3554.0,0.7418716375270525,0.2883503606818102,3581.0,2912.47318983078,3063.34952712059,2912.47318983078,149.47599267959595,0.9461815357208252,0.0 -12400,0.08236118,0.2512736,,,,,,,,,,,,,, -12500,0.1681309,0.3324492,,,,,,,,,,,,,, -12600,0.06874121,0.37100863,,,,,,,,,,,,,, -12700,0.053779762,0.2892405,,,,,,,,,,,,,, -12706,,,0.75163391658238,0.2595241921288626,0.7246345172253095,0.2866389199999121,3554.0,0.7418361174863864,0.2879713666202702,3581.0,2992.63488984108,3147.590842962265,2992.63488984108,153.514493227005,0.9745748043060304,0.0 -12800,0.11054124,0.24897039,,,,,,,,,,,,,, -12900,0.026005361,0.21756521,,,,,,,,,,,,,, -13000,0.056527648,0.28317064,,,,,,,,,,,,,, -13047,,,0.7500644411359515,0.2622406823294503,0.7231141682347355,0.289036189966411,3554.0,0.7402398291329237,0.2905583982193346,3581.0,3072.810255050659,3231.8504090309143,3072.810255050659,157.55807256698608,1.0025804042816162,0.0 -13100,0.026719071,0.23940879,,,,,,,,,,,,,, -13200,0.08534795,0.19020677,,,,,,,,,,,,,, -13300,0.080970444,0.25442296,,,,,,,,,,,,,, -13393,,,0.7512751306806292,0.2590423311505999,0.7241309857730726,0.2862096474153243,3554.0,0.7413110208391511,0.2877308734466629,3581.0,3152.822122335434,3315.9444572925568,3152.822122335434,161.60008311271667,1.0305194854736328,0.0 -13400,0.090642765,0.286487,,,,,,,,,,,,,, -13500,0.0806377,0.26833948,,,,,,,,,,,,,, -13600,0.109153345,0.21626392,,,,,,,,,,,,,, -13700,0.08190987,0.32675007,,,,,,,,,,,,,, -13736,,,0.7505890301295689,0.2599429743630545,0.7239430373346933,0.2868714684103738,3554.0,0.7411356022889906,0.2882589698670064,3581.0,3232.913407802582,3400.114282131195,3232.913407802582,165.63811612129211,1.0580720901489258,0.0 -13800,0.101105876,0.21497017,,,,,,,,,,,,,, -13900,0.2756818,0.2353599,,,,,,,,,,,,,, -14000,0.12361287,0.3302625,,,,,,,,,,,,,, -14081,,,0.7513887541634696,0.2593295063291277,0.7242169914181205,0.2864803213313783,3554.0,0.7414687134573094,0.2879141664012322,3581.0,3313.1099050045013,3484.3979263305664,3313.1099050045013,169.68113112449646,1.0890719890594482,0.0 -14100,0.07388025,0.2772258,,,,,,,,,,,,,, -14200,0.11826891,0.23659529,,,,,,,,,,,,,, -14300,0.109988526,0.28732154,,,,,,,,,,,,,, -14400,0.05334013,0.22761013,,,,,,,,,,,,,, -14426,,,0.7521964481898716,0.2591852971485683,0.7246845268975098,0.2866273964799785,3554.0,0.7418051652820441,0.2881177760009424,3581.0,3393.1668939590454,3568.5471363067627,3393.1668939590454,173.728289604187,1.1212973594665527,0.0 -14500,0.031210454,0.3648915,,,,,,,,,,,,,, -14600,0.2002634,0.20383714,,,,,,,,,,,,,, -14700,0.06328973,0.22303867,,,,,,,,,,,,,, -14769,,,0.7512100083487374,0.2590839522225516,0.7241107208646947,0.2862400447778911,3554.0,0.7413695164147585,0.2876715938394478,3581.0,3473.188450574875,3652.652163267136,3473.188450574875,177.77133631706238,1.1494691371917725,0.0 -14800,0.07126531,0.26595527,,,,,,,,,,,,,, -14900,0.14993416,0.23491229,,,,,,,,,,,,,, -15000,0.0720836,0.22661859,,,,,,,,,,,,,, -15100,0.061820906,0.20694126,,,,,,,,,,,,,, -15112,,,0.751319340297154,0.2589912414550781,0.7240003286349888,0.2862024860027873,3554.0,0.7413103390725356,0.2875662608973576,3581.0,3553.326792240143,3736.877160310745,3553.326792240143,181.81732654571533,1.1774797439575195,0.0 -15200,0.0724548,0.24388513,,,,,,,,,,,,,, -15300,0.11632163,0.24117723,,,,,,,,,,,,,, -15400,0.06081594,0.23404315,,,,,,,,,,,,,, -15457,,,0.7521231515066964,0.2582509858267648,0.7244947237048045,0.2859320696915007,3554.0,0.7416795156948129,0.2873466638705145,3581.0,3633.5277137756334,3821.159016132354,3633.5277137756334,185.85540199279785,1.2082180976867676,0.0 -15500,0.10914289,0.34608012,,,,,,,,,,,,,, -15600,0.07634871,0.30374002,,,,,,,,,,,,,, -15700,0.10948239,0.24520017,,,,,,,,,,,,,, -15800,0.049033076,0.30274823,,,,,,,,,,,,,, -15801,,,0.7514749254499163,0.2586957727159772,0.7243029283685636,0.2859238778598937,3554.0,0.7415334131091176,0.2873067805235095,3581.0,3713.646933078766,3905.309055566788,3713.646933078766,189.84543323516849,1.2363996505737305,0.0 -15900,0.12692386,0.29775214,,,,,,,,,,,,,, -16000,0.12192635,0.22394729,,,,,,,,,,,,,, -16100,0.051522624,0.2758298,,,,,,,,,,,,,, -16143,,,0.7512297630310059,0.2588280779974801,0.7240081598199212,0.2860556512853035,3554.0,0.74123561745148,0.2874490311278274,3581.0,3793.65927529335,3989.408354043961,3793.65927529335,193.8857979774475,1.270770788192749,0.0 -16200,0.025270145,0.25768557,,,,,,,,,,,,,, -16300,0.12536547,0.22129267,,,,,,,,,,,,,, -16400,0.09100999,0.24057822,,,,,,,,,,,,,, -16488,,,0.7530438559395927,0.2580648149762835,0.7251870966252814,0.2858771827023688,3554.0,0.7423863713217328,0.2872868047616762,3581.0,3873.663761138916,4073.501513719559,3873.663761138916,197.9326922893524,1.299825668334961,0.0 -16500,0.059969503,0.20746806,,,,,,,,,,,,,, -16600,0.11110775,0.18800944,,,,,,,,,,,,,, -16700,0.19850855,0.26574707,,,,,,,,,,,,,, -16800,0.14920956,0.23685996,,,,,,,,,,,,,, -16833,,,0.7522469248090472,0.258419258253915,0.7247790506735369,0.2858986154190261,3554.0,0.7419785385323583,0.2872385356853009,3581.0,3953.843817949295,4157.771581888199,3953.843817949295,201.97663283348083,1.3333888053894043,0.0 -16900,0.05018938,0.35385054,,,,,,,,,,,,,, -17000,0.043724388,0.23885357,,,,,,,,,,,,,, -17100,0.15510958,0.27472854,,,,,,,,,,,,,, -17177,,,0.7513459750584194,0.2587161575044904,0.724141770826006,0.2860677587093768,3554.0,0.7413751069010053,0.2873956828901668,3581.0,4033.999075651169,4242.009331464768,4033.999075651169,206.0151960849762,1.3643596172332764,0.0 -17200,0.13549747,0.34367365,,,,,,,,,,,,,, -17300,0.044596694,0.30665213,,,,,,,,,,,,,, -17400,0.04917891,0.25829226,,,,,,,,,,,,,, -17500,0.043967474,0.30147147,,,,,,,,,,,,,, -17522,,,0.7527370452880859,0.2583299194063459,0.7250032698631823,0.2861687741255451,3554.0,0.742186750056723,0.2876077123075782,3581.0,4114.166547060013,4326.265980958939,4114.166547060013,210.06202054023743,1.3937888145446775,0.0 -17600,0.10801031,0.2572277,,,,,,,,,,,,,, -17700,0.0810771,0.28532636,,,,,,,,,,,,,, -17800,0.10232324,0.36111188,,,,,,,,,,,,,, -17869,,,0.752519062587193,0.2584952797208513,0.725203858108821,0.2860077539721968,3554.0,0.7423210580799707,0.2874424179916574,3581.0,4194.325728654861,4410.5075969696045,4194.325728654861,214.1040050983429,1.4217588901519775,0.0 -17900,0.08392407,0.31965423,,,,,,,,,,,,,, -18000,0.053652406,0.36731106,,,,,,,,,,,,,, -18100,0.060542952,0.22855508,,,,,,,,,,,,,, -18200,0.049288463,0.3577454,,,,,,,,,,,,,, -18213,,,0.7517549651009696,0.2586789812360491,0.7244219074238534,0.2860423760529509,3554.0,0.7415589111805362,0.2874423498149958,3581.0,4274.364299535751,4494.629780292511,4274.364299535751,218.1460461616516,1.4510364532470703,0.0 -18300,0.28889155,0.20191109,,,,,,,,,,,,,, -18400,0.03665171,0.28519237,,,,,,,,,,,,,, -18500,0.034504645,0.22991093,,,,,,,,,,,,,, -18541,,,0.7444351060049874,0.2617180688040597,0.7190021091991418,0.2888652090953855,3554.0,0.7355785226062902,0.290273419774068,3581.0,4351.388922452927,4578.803955078125,4351.388922452927,222.18551421165463,4.549318790435791,0.0 -18600,0.038323633,0.29314026,,,,,,,,,,,,,, -18700,0.093038276,0.30287358,,,,,,,,,,,,,, -18800,0.028866207,0.25899684,,,,,,,,,,,,,, -18885,,,0.7527740342276437,0.2581279448100498,0.7252524251943233,0.2857613292515915,3554.0,0.7424356630480313,0.2871635072692683,3581.0,4431.485432386398,4662.9877717494965,4431.485432386398,226.23116660118103,4.578796863555908,0.0 -18900,0.07157999,0.27785662,,,,,,,,,,,,,, -19000,0.05996915,0.29240113,,,,,,,,,,,,,, -19100,0.05381538,0.36232,,,,,,,,,,,,,, -19200,0.091109775,0.3508817,,,,,,,,,,,,,, -19230,,,0.752542359488351,0.2582859822681972,0.7250183826762099,0.2858366013646595,3554.0,0.7422182476743577,0.2871925846154182,3581.0,4511.616227388382,4747.201381921768,4511.616227388382,230.27187418937683,4.607800006866455,0.0 -19300,0.05092836,0.3153988,,,,,,,,,,,,,, -19400,0.11610616,0.20194657,,,,,,,,,,,,,, -19500,0.051887468,0.31007063,,,,,,,,,,,,,, -19570,,,0.7528716496058873,0.257946355002267,0.7250105514912775,0.2857356202957934,3554.0,0.742229496823513,0.2871527694450747,3581.0,4591.753699779511,4831.418623209,4591.753699779511,234.3103144168853,4.636163949966431,0.0 -19600,0.07573362,0.25779226,,,,,,,,,,,,,, -19700,0.10706302,0.24725993,,,,,,,,,,,,,, -19800,0.08864396,0.28404817,,,,,,,,,,,,,, -19900,0.029366387,0.275176,,,,,,,,,,,,,, -19916,,,0.7501134191240583,0.2603518622262137,0.723229231697559,0.2876773763387205,3554.0,0.7400381625680675,0.2893246733519792,3581.0,4671.819234609604,4915.573646783829,4671.819234609604,238.3534197807312,4.670209646224976,0.0 -20000,0.105956964,0.34773707,,,,,,,,,,,,,, -20100,0.050828423,0.4155305,,,,,,,,,,,,,, -20200,0.108425185,0.27979606,,,,,,,,,,,,,, -20260,,,0.7525191988263812,0.2582253898893084,0.7247402382218978,0.2858264689104706,3554.0,0.7419907421547752,0.2871681432822535,3581.0,4751.896664619446,4999.735841989517,4751.896664619446,242.39756178855896,4.698517560958862,0.0 -20300,0.10933445,0.24167202,,,,,,,,,,,,,, -20400,0.106405415,0.23643005,,,,,,,,,,,,,, -20500,0.323259,0.296829,,,,,,,,,,,,,, -20600,0.08019053,0.2402622,,,,,,,,,,,,,, -20601,,,0.7524471964154925,0.2585275002888271,0.7248305029324353,0.2860808793788689,3554.0,0.742039692997766,0.2874492697461428,3581.0,4831.882543325424,5083.755587100983,4831.882543325424,246.3901126384735,4.727694749832153,0.0 -20700,0.04864291,0.22974649,,,,,,,,,,,,,, -20800,0.06716107,0.2698938,,,,,,,,,,,,,, -20900,0.056477346,0.28434417,,,,,,,,,,,,,, -20944,,,0.7533929007393974,0.2577303307397025,0.7254286955499085,0.2857709636698965,3554.0,0.7425531314358769,0.2872574887972109,3581.0,4911.99870967865,5167.955820322037,4911.99870967865,250.43258547782887,4.757020711898804,0.0 -21000,0.08552789,0.390302,,,,,,,,,,,,,, -21100,0.02657888,0.28263703,,,,,,,,,,,,,, -21200,0.052235592,0.22000214,,,,,,,,,,,,,, -21288,,,0.7532256671360561,0.2579449755804879,0.7253845249191052,0.2858033360023477,3554.0,0.7425597445720469,0.2872337633189926,3581.0,4992.026699781418,5252.0716252326965,4992.026699781418,254.47485995292664,4.789958953857422,0.0 -21300,0.062300947,0.3017634,,,,,,,,,,,,,, -21400,0.11461783,0.254043,,,,,,,,,,,,,, -21500,0.04102009,0.25208586,,,,,,,,,,,,,, -21600,0.07486921,0.23506692,,,,,,,,,,,,,, -21630,,,0.7525844573974609,0.25794517993927,0.7247651056687887,0.2856705149842167,3554.0,0.7420083999101159,0.2870563335573164,3581.0,5071.992090940476,5336.118559122086,5071.992090940476,258.51404643058777,4.820429563522339,0.0 -21700,0.032389667,0.39589635,,,,,,,,,,,,,, -21800,0.08975128,0.17993009,,,,,,,,,,,,,, -21900,0.085108705,0.27587357,,,,,,,,,,,,,, -21973,,,0.7536664690290179,0.2574006829942976,0.7253517575926772,0.2856152501747591,3554.0,0.7425326102607512,0.2870326762557595,3581.0,5151.954606056213,5420.165975809097,5151.954606056213,262.5574164390564,4.8497161865234375,0.0 -22000,0.04532727,0.3537173,,,,,,,,,,,,,, -22100,0.018996581,0.2908321,,,,,,,,,,,,,, -22200,0.08681335,0.27161545,,,,,,,,,,,,,, -22300,0.121869184,0.25076059,,,,,,,,,,,,,, -22316,,,0.7529209681919643,0.2576997109821864,0.724979776308385,0.2856670974276343,3554.0,0.7421553206157497,0.2870808430671425,3581.0,5231.919135332108,5504.221770524979,5231.919135332108,266.60148763656616,4.884214401245117,0.0 -22400,0.034374885,0.28882065,,,,,,,,,,,,,, -22500,0.14552414,0.2412823,,,,,,,,,,,,,, -22600,0.1580673,0.23827678,,,,,,,,,,,,,, -22658,,,0.7528862953186035,0.2577470200402396,0.7251463607247116,0.2855707704182347,3554.0,0.7423250805030019,0.2869420012959019,3581.0,5311.92449259758,5588.317686080933,5311.92449259758,270.6449613571167,4.919068574905396,0.0 -22700,0.100804076,0.19387645,,,,,,,,,,,,,, -22800,0.047238376,0.23352584,,,,,,,,,,,,,, -22900,0.08122501,0.19374308,,,,,,,,,,,,,, -23000,,,0.7542593819754464,0.2572645630155291,0.7257499115213492,0.2856827769711504,3554.0,0.7428951737468584,0.2871562123664828,3581.0,5391.978044509888,5672.452279090881,5391.978044509888,274.68395590782166,4.949132919311523,0.0 -23000,0.038111404,0.26869717,,,,,,,,,,,,,, -23100,0.119921595,0.30147594,,,,,,,,,,,,,, -23200,0.04614463,0.32689768,,,,,,,,,,,,,, -23300,0.20606954,0.25028834,,,,,,,,,,,,,, -23344,,,0.7536674227033343,0.2575033562523978,0.725452189104706,0.2856114376241998,3554.0,0.7426551918982128,0.2870301878076131,3581.0,5472.120784044266,5756.679520845413,5472.120784044266,278.72428584098816,4.980362415313721,0.0 -23400,0.024539445,0.3071632,,,,,,,,,,,,,, -23500,0.06233412,0.23100612,,,,,,,,,,,,,, -23600,0.059316833,0.23448595,,,,,,,,,,,,,, -23689,,,0.7536218506949288,0.2578081062861851,0.7255891661464196,0.2858345233528682,3554.0,0.7427696605129503,0.2871937095303337,3581.0,5552.164082050324,5840.803641319275,5552.164082050324,282.7625916004181,5.010173559188843,0.0 -23700,0.030784998,0.2943529,,,,,,,,,,,,,, -23800,0.052480184,0.23775743,,,,,,,,,,,,,, -23900,0.08173138,0.22073294,,,,,,,,,,,,,, -24000,0.028877808,0.23691122,,,,,,,,,,,,,, -24030,,,0.7539426939828056,0.2571692637034825,0.7254523264939153,0.285618015132597,3554.0,0.7426042639320372,0.2870800931238655,3581.0,5632.201220989227,5924.926100969315,5632.201220989227,286.80222272872925,5.042935132980347,0.0 -24100,0.072903395,0.3263024,,,,,,,,,,,,,, -24200,0.12322182,0.2594242,,,,,,,,,,,,,, -24300,0.06768202,0.23642904,,,,,,,,,,,,,, -24375,,,0.7535826819283622,0.2576377051217215,0.7253617870049592,0.285717347530951,3554.0,0.7425559266790003,0.2871609165561296,3581.0,5712.229027509689,6009.039095878601,5712.229027509689,290.84347701072693,5.074441194534302,0.0 -24400,0.11790544,0.18850996,,,,,,,,,,,,,, -24500,0.019142566,0.29418075,,,,,,,,,,,,,, -24600,0.07015593,0.28378665,,,,,,,,,,,,,, -24700,0.053086177,0.25206053,,,,,,,,,,,,,, -24718,,,0.7533001899719238,0.2575099127633231,0.7252187648380346,0.2855637120476048,3554.0,0.7423800308922088,0.2869510687918877,3581.0,5792.354554653168,6093.25949716568,5792.354554653168,294.8886339664459,5.111849069595337,0.0 -24800,0.027102875,0.27186158,,,,,,,,,,,,,, -24900,0.040290788,0.259417,,,,,,,,,,,,,, -25000,0.043787464,0.24807525,,,,,,,,,,,,,, -25057,,,0.7541554995945522,0.2570937190737043,0.7256698136123031,0.285622823754924,3554.0,0.7427837730818906,0.2870945806644442,3581.0,5872.473770856857,6177.465390920639,5872.473770856857,298.93243408203125,5.142282724380493,0.0 -25100,0.050679944,0.28107712,,,,,,,,,,,,,, -25200,0.042758904,0.22778164,,,,,,,,,,,,,, -25300,0.06670551,0.20118299,,,,,,,,,,,,,, -25400,0.036435578,0.21753,,,,,,,,,,,,,, -25403,,,0.7537097930908203,0.257313711302621,0.7254681949475943,0.2855302062541766,3554.0,0.7426374659662106,0.2869819187312378,3581.0,5952.507550954819,6261.585339069367,5952.507550954819,302.9748706817627,5.173599481582642,0.0 -25500,0.041717824,0.2525762,,,,,,,,,,,,,, -25600,0.030352866,0.3311292,,,,,,,,,,,,,, -25700,0.03998619,0.41094276,,,,,,,,,,,,,, -25746,,,0.7538211005074638,0.257412246295384,0.7256960549512873,0.2856286112753675,3554.0,0.7428361327579587,0.2870435163449455,3581.0,6032.525489091873,6345.692269086838,6032.525489091873,307.0195956230164,5.20525050163269,0.0 -25800,0.02620307,0.2766865,,,,,,,,,,,,,, -25900,0.051166877,0.24456151,,,,,,,,,,,,,, -26000,0.022095593,0.29315588,,,,,,,,,,,,,, -26088,,,0.7538996423993792,0.2571394102913992,0.725426497322559,0.285577863136167,3554.0,0.7425881060632505,0.2870197908667272,3581.0,6112.56943154335,6429.815916538239,6112.56943154335,311.0550262928009,5.23749852180481,0.0 -26100,0.054678068,0.27253172,,,,,,,,,,,,,, -26200,0.030088516,0.29186815,,,,,,,,,,,,,, -26300,0.030793447,0.36123115,,,,,,,,,,,,,, -26400,0.028835678,0.2484959,,,,,,,,,,,,,, -26433,,,0.7542871747698102,0.2572294303349086,0.7257701077351224,0.2856593177636554,3554.0,0.7429105134957065,0.2871216467990785,3581.0,6192.580571889877,6513.905717372894,6192.580571889877,315.09079217910767,5.267976999282837,0.0 -26500,0.067383185,0.25064686,,,,,,,,,,,,,, -26600,0.045998834,0.26333496,,,,,,,,,,,,,, -26700,0.056055024,0.3408944,,,,,,,,,,,,,, -26776,,,0.7540904453822544,0.2573316778455461,0.7257306770320414,0.2856416460766038,3554.0,0.7429086727258447,0.287038334918668,3581.0,6272.620399475098,6598.033495903015,6272.620399475098,319.1351087093353,5.299068212509155,0.0 -26800,0.04246363,0.17741223,,,,,,,,,,,,,, -26900,0.034942485,0.21187045,,,,,,,,,,,,,, -27000,0.055931944,0.2115097,,,,,,,,,,,,,, -27100,0.06662346,0.24584138,,,,,,,,,,,,,, -27120,,,0.7539983476911273,0.2578939199447632,0.7256790186893289,0.2862922011564874,3554.0,0.7427011429680955,0.2878444898531311,3581.0,6352.729802370071,6682.2287175655365,6352.729802370071,323.1779954433441,5.3297858238220215,0.0 -27200,0.028577762,0.30970848,,,,,,,,,,,,,, -27300,0.03847068,0.2399621,,,,,,,,,,,,,, -27400,0.08558118,0.3044198,,,,,,,,,,,,,, -27465,,,0.7527961730957031,0.2573528460093907,0.7244910828907569,0.285726260655907,3554.0,0.7416571537498254,0.2871871645708252,3581.0,6432.850472688675,6766.438841342926,6432.850472688675,327.2229354381561,5.361184120178223,0.0 -27500,0.03942808,0.23515104,,,,,,,,,,,,,, -27600,0.019579798,0.24391903,,,,,,,,,,,,,, -27700,0.04571179,0.29007685,,,,,,,,,,,,,, -27800,0.044139408,0.3115634,,,,,,,,,,,,,, -27807,,,0.754103592463902,0.2571893760136196,0.7257026496333356,0.285555331305835,3554.0,0.7428482000270525,0.2869928951737469,3581.0,6513.029781341553,6850.704304218292,6513.029781341553,331.2638852596283,5.393904447555542,0.0 -27900,0.05079529,0.268633,,,,,,,,,,,,,, -28000,0.038412515,0.23380974,,,,,,,,,,,,,, -28100,0.03776551,0.27264836,,,,,,,,,,,,,, -28147,,,0.7540349960327148,0.257232666015625,0.7256720118396525,0.2855900736021472,3554.0,0.7427993855373848,0.2870455957331227,3581.0,6593.055602550507,6934.813687324524,6593.055602550507,335.303875207901,5.425286054611206,0.0 -28200,0.0615703,0.25451246,,,,,,,,,,,,,, -28300,0.03830544,0.23822193,,,,,,,,,,,,,, -28400,0.03620255,0.23550604,,,,,,,,,,,,,, -28493,,,0.7544220515659877,0.2567970412118094,0.7256915898019837,0.2854602236256682,3554.0,0.7428325193948967,0.2869222641523841,3581.0,6673.217515468597,7019.063020467758,6673.217515468597,339.34671092033386,5.456988573074341,0.0 -28500,0.041520152,0.29372942,,,,,,,,,,,,,, -28600,0.036227874,0.36029148,,,,,,,,,,,,,, -28700,0.033783555,0.2652331,,,,,,,,,,,,,, -28800,0.041134637,0.20532456,,,,,,,,,,,,,, -28837,,,0.7540216445922852,0.2571344716208322,0.7254533569129854,0.2856393791546497,3554.0,0.7426146949612539,0.2871064434035534,3581.0,6753.32670044899,7103.257844209671,6753.32670044899,343.3887577056885,5.488369941711426,0.0 -28900,0.028204154,0.33933035,,,,,,,,,,,,,, -29000,0.038404766,0.2604258,,,,,,,,,,,,,, -29100,0.028927786,0.25242823,,,,,,,,,,,,,, -29179,,,0.7539787292480469,0.2571410962513515,0.7254288329391179,0.2856076078999894,3554.0,0.742586060763404,0.287044641259861,3581.0,6833.4487562179565,7187.468809604645,6833.4487562179565,347.432626247406,5.521002531051636,0.0 -29200,0.025623133,0.33047625,,,,,,,,,,,,,, -29300,0.08293215,0.34299463,,,,,,,,,,,,,, -29400,0.07978841,0.18733114,,,,,,,,,,,,,, -29500,0.021839347,0.2231496,,,,,,,,,,,,,, -29523,,,0.7546099935259137,0.256743107523237,0.7256827968925859,0.2855766266332829,3554.0,0.7428071576768012,0.2870695257413257,3581.0,6913.447591543198,7271.553107976913,6913.447591543198,351.4746322631836,5.552271842956543,0.0 -29600,0.055594657,0.20449649,,,,,,,,,,,,,, -29700,0.020230269,0.30828816,,,,,,,,,,,,,, -29800,0.025719747,0.35622448,,,,,,,,,,,,,, -29867,,,0.7541588374546596,0.257029788834708,0.7254314433340954,0.2856062511815472,3554.0,0.7425986734457903,0.2870715710411721,3581.0,6993.489233732224,7355.679824829102,6993.489233732224,355.5142960548401,5.584680080413818,0.0 -29900,0.031854324,0.26255915,,,,,,,,,,,,,, -30000,0.015823605,0.3210859,,,,,,,,,,,,,, -30100,0.022341724,0.22492275,,,,,,,,,,,,,, -30200,0.03178259,0.27936694,,,,,,,,,,,,,, -30211,,,0.7540692601885114,0.2571203368050711,0.7255136707758864,0.2856606573084465,3554.0,0.7426834170360933,0.2870932512195441,3581.0,7073.679226160049,7439.95772767067,7073.679226160049,359.5573136806488,5.617060422897339,0.0 -30300,0.027623018,0.29668722,,,,,,,,,,,,,, -30400,0.03029895,0.28349894,,,,,,,,,,,,,, -30500,0.03162226,0.22823831,,,,,,,,,,,,,, -30552,,,0.7548450742449079,0.2567461047853742,0.7258831103598059,0.2856448747230233,3554.0,0.7429972342091944,0.2871223285656939,3581.0,7153.867946386337,7524.22979927063,7153.867946386337,363.5970644950866,5.648396730422974,0.0 -30600,0.027417073,0.28040197,,,,,,,,,,,,,, -30700,0.025858002,0.29048347,,,,,,,,,,,,,, -30800,0.0200904,0.28194657,,,,,,,,,,,,,, -30896,,,0.7542660576956612,0.25690986428942,0.7255811975722777,0.2855315114516654,3554.0,0.7427162100102974,0.2870102120457798,3581.0,7233.912236690521,7608.354736804962,7233.912236690521,367.6330301761627,5.680180311203003,0.0 -30900,0.05272328,0.21569799,,,,,,,,,,,,,, -31000,0.023155741,0.2573167,,,,,,,,,,,,,, -31100,0.018987477,0.21193534,,,,,,,,,,,,,, -31200,0.024071626,0.23830757,,,,,,,,,,,,,, -31241,,,0.754291330065046,0.2570479597364153,0.7256142396771244,0.2856827082765458,3554.0,0.742731617935807,0.2871228058023247,3581.0,7314.072935581207,7692.604668617248,7314.072935581207,371.6770513057709,5.712636709213257,0.0 -31300,0.03130703,0.2569682,,,,,,,,,,,,,, -31400,0.028919531,0.29839277,,,,,,,,,,,,,, -31500,0.04221724,0.3099343,,,,,,,,,,,,,, -31581,,,0.7546959604535785,0.2566973481859479,0.725824445167417,0.2855513126714617,3554.0,0.742917058455215,0.2870380963003525,3581.0,7394.21529507637,7776.835196018219,7394.21529507637,375.7203893661499,5.745257377624512,0.0 -31600,0.024257861,0.22664234,,,,,,,,,,,,,, -31700,0.017768634,0.25878963,,,,,,,,,,,,,, -31800,0.04941305,0.27671263,,,,,,,,,,,,,, -31900,0.017751679,0.34025937,,,,,,,,,,,,,, -31928,,,0.7544897624424526,0.2567943504878452,0.7257194111168753,0.2854953780896085,3554.0,0.7428276788519268,0.2869768054816217,3581.0,7474.3756783008575,7861.083312034607,7474.3756783008575,379.7598538398743,5.78111457824707,0.0 -32000,0.046931684,0.17072129,,,,,,,,,,,,,, -32100,0.016861998,0.26930588,,,,,,,,,,,,,, -32200,0.027794292,0.26593524,,,,,,,,,,,,,, -32273,,,0.754249095916748,0.2569260426930019,0.7255921887090251,0.285591790967264,3554.0,0.7427074833976194,0.287051424837685,3581.0,7554.379847288132,7945.174165248871,7554.379847288132,383.8014781475067,5.81369686126709,0.0 -32300,0.025723549,0.2719513,,,,,,,,,,,,,, -32400,0.032249715,0.24529038,,,,,,,,,,,,,, -32500,0.026802544,0.22900854,,,,,,,,,,,,,, -32600,0.02334973,0.33461106,,,,,,,,,,,,,, -32612,,,0.7544171469552177,0.2567449126924787,0.7255723359682752,0.2855080694178215,3554.0,0.7427035973279112,0.2869986561016476,3581.0,7634.347116231918,8029.232703924179,7634.347116231918,387.8406648635864,5.852944374084473,0.0 -32700,0.027490985,0.27506918,,,,,,,,,,,,,, -32800,0.022350399,0.25486216,,,,,,,,,,,,,, -32900,0.015398521,0.2101331,,,,,,,,,,,,,, -32956,,,0.7545514106750488,0.2567299774714878,0.7256814916950971,0.2855084128908448,3554.0,0.7428032034304315,0.2869958608585241,3581.0,7714.379668951035,8113.352165699005,7714.379668951035,391.8827121257782,5.8849639892578125,0.0 -33000,0.020909488,0.33004066,,,,,,,,,,,,,, -33100,0.020687023,0.22367285,,,,,,,,,,,,,, -33200,0.020802816,0.35240075,,,,,,,,,,,,,, -33299,,,0.754565920148577,0.256818277495248,0.7257891361406162,0.285514681273521,3554.0,0.7429174675151843,0.2869860775075921,3581.0,7794.549304008484,8197.608664512634,7794.549304008484,395.9238955974579,5.918242931365967,0.0 -33300,0.025737986,0.31474194,,,,,,,,,,,,,, -33400,0.017450865,0.2534073,,,,,,,,,,,,,, -33500,0.017463323,0.26415935,,,,,,,,,,,,,, -33600,0.02275479,0.29457855,,,,,,,,,,,,,, -33641,,,0.7544005257742745,0.2568733351571219,0.725658547697137,0.2856053753253376,3554.0,0.7427757964124895,0.2871052844003072,3581.0,7874.795213460922,8281.944691896439,7874.795213460922,399.9638068675995,5.955785274505615,0.0 -33700,0.020025901,0.25544578,,,,,,,,,,,,,, -33800,0.020289514,0.24009216,,,,,,,,,,,,,, -33900,0.018566798,0.28361318,,,,,,,,,,,,,, -33984,,,0.7547591754368373,0.2566869599478585,0.7258627767568233,0.2855236115721282,3554.0,0.7429784856272689,0.2870194158950886,3581.0,7954.826099395752,8366.067419290543,7954.826099395752,404.01122069358826,5.987571239471436,0.0 -34000,0.023955168,0.27786362,,,,,,,,,,,,,, -34100,0.017967004,0.25710592,,,,,,,,,,,,,, -34200,0.021846116,0.26612383,,,,,,,,,,,,,, -34300,0.020194737,0.22447906,,,,,,,,,,,,,, -34329,,,0.7546742303030831,0.2567040579659598,0.7258765843723621,0.2854583688713421,3554.0,0.7429921891362399,0.2869331042415701,3581.0,8034.850677967071,8450.185488700867,8034.850677967071,408.0551481246948,6.024341106414795,0.0 -34400,0.017800825,0.33194065,,,,,,,,,,,,,, -34500,0.022690924,0.22038904,,,,,,,,,,,,,, -34600,0.028577441,0.25705612,,,,,,,,,,,,,, -34672,,,0.7546633311680385,0.2567169666290283,0.7258422370700267,0.2854794924622784,3554.0,0.7429690772479755,0.2869432284758098,3581.0,8114.963846206665,8534.388310909271,8114.963846206665,412.0980794429779,6.058527708053589,0.0 -34700,0.025334891,0.24386647,,,,,,,,,,,,,, -34800,0.018914478,0.22742271,,,,,,,,,,,,,, -34900,0.017052287,0.29781935,,,,,,,,,,,,,, -35000,0.021086618,0.24869317,,,,,,,,,,,,,, -35015,,,0.7547167369297573,0.2566088608333042,0.7258234147483469,0.2854470342615715,3554.0,0.7429394204002024,0.2869227413890149,3581.0,8195.090363264084,8618.609902381897,8195.090363264084,416.14691257476807,6.092074155807495,0.0 -35100,0.021406502,0.3008088,,,,,,,,,,,,,, -35200,0.015460953,0.23113263,,,,,,,,,,,,,, -35300,0.016510662,0.27850538,,,,,,,,,,,,,, -35360,,,0.7546516145978656,0.2566218546458653,0.725779999758195,0.2854279715087753,3554.0,0.7429039003595365,0.28690225430222,3581.0,8275.109664440155,8702.720784902573,8275.109664440155,420.1879172325134,6.129887580871582,0.0 -35400,0.018342985,0.32124346,,,,,,,,,,,,,, -35500,0.026686257,0.31532988,,,,,,,,,,,,,, -35600,0.017462851,0.21078788,,,,,,,,,,,,,, -35700,0.019896487,0.21061116,,,,,,,,,,,,,, -35701,,,0.7548180307660785,0.2566397019795009,0.7259626587120146,0.2854459866688502,3554.0,0.7430810915028973,0.2869127194197675,3581.0,8355.086299180984,8786.73073220253,8355.086299180984,424.1739959716797,6.164317607879639,0.0 -35800,0.020069486,0.24375547,,,,,,,,,,,,,, -35900,0.016380938,0.21986884,,,,,,,,,,,,,, -36000,0.016594274,0.23557231,,,,,,,,,,,,,, -36044,,,0.7547200066702706,0.2566244431904384,0.7258568690208216,0.2854308395085203,3554.0,0.742981485400377,0.2869018793305815,3581.0,8435.13413977623,8870.86947607994,8435.13413977623,428.2131233215332,6.202942609786987,0.0 -36100,0.01656713,0.37203568,,,,,,,,,,,,,, -36189,,,0.7547260693141392,0.2566257204328264,0.7258640132597074,0.2854322477479161,3554.0,0.7429876212999162,0.2869036519237817,3581.0,8467.644447088242,8907.461469173431,8467.644447088242,432.2550909519196,6.2368481159210205,0.0 -36189,,,,,,,,,,,8467.644447088242,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 7c59f20b4..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.984562873840332,0.0,33.548996925354004,1,0,33.548996925354004,0.912513662602974,3581,0.2829431162570249,37.53365206718445,0.9032635007585798,0.2670762368610927,0.915261701713738,3554,0.2587972028106535 -8.02499008178711,0.0185849666595459,113.72720742225648,343,0,113.72720742225648,0.3437465229902611,3581,0.6831162406625244,121.78289651870728,0.3141788755144392,0.6902498517717633,0.3427278929771032,3554,0.6638744141724113 -12.062600135803224,0.0420329570770263,193.7486245632172,683,0,193.7486245632172,0.3172561061745148,3581,0.7118955866247207,205.8772604465485,0.2880679539271763,0.7197771753583636,0.3150176572611846,3554,0.6944376055149127 -16.101449489593506,0.0656051635742187,273.89558959007263,1023,0,273.89558959007263,0.3091897844635751,3581,0.7172653850181514,290.09849548339844,0.2809505462646484,0.7243311745779855,0.307128803208093,3554,0.6999504162343486 -20.14368963241577,0.0890505313873291,354.088178396225,1370,0,354.088178396225,0.3018685313983524,3581,0.7265489328443522,374.368869304657,0.2731507675988333,0.7353413445608956,0.2996465525200478,3554,0.7095289171312253 -24.17908215522766,0.1127169132232666,434.0783689022064,1713,0,434.0783689022064,0.2983961235841071,3581,0.7295050728890324,458.4306387901306,0.2702325752803257,0.7381311144147601,0.296381807085766,3554,0.7123925204215321 -28.21941351890564,0.1388208866119384,514.2202973365784,2055,0,514.2202973365784,0.29680184644216,3581,0.7316638187002583,542.6508350372314,0.2689672878810337,0.7401712962559291,0.2949395638607027,3554,0.7145240453098621 -32.20990490913391,0.1659994125366211,594.2618770599365,2401,0,594.2618770599365,0.2944348209299078,3581,0.7333096714866657,626.7217772006989,0.2663253034864153,0.742351940699986,0.2926883042764315,3554,0.7161404293577659 -36.25227451324463,0.1907820701599121,674.29638504982,2746,0,674.29638504982,0.2963557324573269,3581,0.7309620763229545,710.8356049060822,0.268444401877267,0.7398902348109654,0.2946146383806098,3554,0.7138087970860298 -40.29451560974121,0.2158269882202148,754.3580448627472,3090,0,754.3580448627472,0.2928027739174637,3581,0.7354993013255725,794.9765889644623,0.2649519273212978,0.7444577898297992,0.2911634557892515,3554,0.7183651041300295 -44.33660364151001,0.2407479286193847,834.3319246768951,3433,0,834.3319246768951,0.2922338396768535,3581,0.7354296247774714,879.0295767784119,0.2641278845923288,0.7445534297398159,0.2906195319094682,3554,0.7181353893720104 -48.37433314323425,0.2658648490905761,914.4203317165376,3768,0,914.4203317165376,0.2921161326706925,3581,0.7355386392592851,963.192458152771,0.2646568502698626,0.7440680095127651,0.2905039188898072,3554,0.7183314437737408 -52.41978621482849,0.2896695137023926,994.5614778995514,4112,0,994.5614778995514,0.2918586294200293,3581,0.7359028389852694,1047.414608478546,0.2639062404632568,0.7450623512268066,0.2901824624872502,3554,0.7189520995269415 -56.46136999130249,0.3157362937927246,1074.5751271247864,4454,0,1074.5751271247864,0.2909781960128106,3581,0.7382009378926976,1131.507836341858,0.2626642840249197,0.747828483581543,0.2893484756392445,3554,0.7210435754607485 -60.50531268119812,0.3406782150268554,1154.6210136413574,4800,0,1154.6210136413574,0.2904206131863481,3581,0.7379298674863864,1215.6344573497772,0.2625446149281093,0.7469865253993443,0.2888648656223621,3554,0.7206881495761818 -64.54547190666199,0.3658421039581299,1234.7858142852783,5147,0,1234.7858142852783,0.2899860551456472,3581,0.738178439594387,1299.8765351772308,0.2621636731284005,0.747349670955113,0.2884935369368142,3554,0.7209261076867614 -68.58561253547668,0.3900299072265625,1314.8512017726898,5490,0,1314.8512017726898,0.2906221093095504,3581,0.7386932415657288,1384.0182185173037,0.2623655114855085,0.7481654030936105,0.289046837630135,3554,0.7215267733100028 -72.6303391456604,0.4150753021240234,1395.0299680233002,5835,0,1395.0299680233002,0.2937657010851542,3581,0.7345956878534278,1468.2786684036255,0.2656274012156895,0.7441043172563825,0.2921757738309827,3554,0.7176361169852631 -76.67470240592957,0.4403700828552246,1475.153118133545,6180,0,1475.153118133545,0.2896401949416189,3581,0.7403823183555571,1552.4831855297089,0.2616130965096609,0.7495601517813546,0.2881927576102631,3554,0.7231619796795864 -80.7171881198883,0.4680335521697998,1555.2574756145475,6522,0,1555.2574756145475,0.2900773777837894,3581,0.7387238528867635,1636.669602394104,0.2620629412787301,0.7479400634765625,0.2886338800141566,3554,0.7214551935319359 -84.75681829452515,0.4930555820465088,1635.29860162735,6869,0,1635.29860162735,0.2890012091812692,3581,0.7399360339290701,1720.7873928546906,0.260955878666469,0.749246529170445,0.2875545675592202,3554,0.7226707445615855 -88.79988646507263,0.5191292762756348,1715.3318283557892,7216,0,1715.3318283557892,0.2902471376710416,3581,0.7392699479457554,1804.9016375541687,0.261763368334089,0.7487942150660923,0.2888460433006823,3554,0.7220146423923748 -92.84628129005432,0.5441994667053223,1795.4933722019196,7561,0,1795.4933722019196,0.2895986071780753,3581,0.7396165580930606,1889.1465952396395,0.2618751525878906,0.7482390403747559,0.2880615852626442,3554,0.7224570356464547 -96.89130020141602,0.5703647136688232,1875.5250089168549,7906,0,1875.5250089168549,0.2884128445921181,3581,0.7408963022069603,1973.261443376541,0.2602382046835763,0.7501950945172992,0.2870323168272105,3554,0.7236638624613112 -100.94035053253174,0.6000580787658691,1955.5428388118744,8250,0,1955.5428388118744,0.289103440085259,3581,0.739940397235409,2057.3702251911163,0.2608766555786133,0.7495725495474679,0.2876780117638136,3554,0.7226769957706106 -104.98532390594482,0.6310033798217773,2035.5123291015625,8593,0,2035.5123291015625,0.2886298168174916,3581,0.7411272165596202,2141.427557706833,0.2605338437216623,0.750629220690046,0.2872329394201516,3554,0.723893233746307 -109.029296875,0.6593132019042969,2115.628773212433,8936,0,2115.628773212433,0.2885630036891755,3581,0.741647745370532,2225.628358125686,0.2599715845925467,0.751429694039481,0.2870964088933684,3554,0.7244846255979178 -113.07417154312134,0.6917684078216553,2195.7629470825195,9282,0,2195.7629470825195,0.2881314113332519,3581,0.7411813488288885,2309.8518760204315,0.2599564790725708,0.7505640983581543,0.2866972417192776,3554,0.7239608979319077 -117.11841011047365,0.7217345237731934,2275.9334716796875,9625,0,2275.9334716796875,0.2894506638225181,3581,0.7387132173275621,2394.1086626052856,0.2619223764964512,0.746929236820766,0.2881119899288214,3554,0.7217569689302546 -121.1582088470459,0.7490215301513672,2356.0825872421265,9971,0,2356.0825872421265,0.2880375320703016,3581,0.740755926460835,2478.336655139923,0.2593847853796823,0.7508612360273089,0.2865976001952026,3554,0.7234695941193022 -125.19908928871156,0.7764010429382324,2436.082048892975,10318,0,2436.082048892975,0.2879811158828714,3581,0.740412725146607,2562.416161775589,0.2598751102175031,0.7500311306544712,0.2866478846458216,3554,0.7230898503446821 -129.23694705963135,0.8038809299468994,2516.108592271805,10663,0,2516.108592271805,0.2877796879363306,3581,0.7412570930998673,2646.5201659202576,0.2597089494977678,0.7506626674107143,0.2864253484739906,3554,0.724038454140581 -133.27631831169128,0.8316900730133057,2596.211658477783,11009,0,2596.211658477783,0.2879982623132505,3581,0.7412414806443731,2730.702555894852,0.2593451397759573,0.7513034003121513,0.2865712901616137,3554,0.7240849603879431 -137.32483291625977,0.8595569133758545,2676.3213658332825,11355,0,2676.3213658332825,0.2877812559995462,3581,0.7413318828975844,2814.900425195694,0.2594878673553467,0.7510937282017299,0.286352789797807,3554,0.72414403774796 -141.3688714504242,0.8870675563812256,2756.2938911914825,11698,0,2756.2938911914825,0.2875438307757086,3581,0.7414494876387532,2898.9560811519623,0.2594446284430368,0.7509194101606097,0.2861528713245638,3554,0.7242609559651098 -145.4126329421997,0.9173834323883056,2836.426928758621,12040,0,2836.426928758621,0.2896840325349937,3581,0.7397514797062622,2983.175198316574,0.2613473790032523,0.7494462558201381,0.2880535823412,3554,0.7227448660400253 -149.45334696769714,0.9447588920593262,2916.431208133697,12386,0,2916.431208133697,0.2885629014241831,3581,0.7402833940196524,3067.2598235607147,0.2600323813302176,0.750204358782087,0.2871197822326076,3554,0.7231275636826463 -153.49388575553894,0.9722623825073242,2996.577868938446,12733,0,2996.577868938446,0.2877028869270979,3581,0.7415381172987643,3151.4868512153625,0.2594108922140939,0.7511591911315918,0.2863031064249789,3554,0.7242919372318163 -157.53622794151306,1.0001962184906006,3076.590484857559,13076,0,3076.590484857559,0.2877901871422089,3581,0.7408127857965652,3235.581696033478,0.2592736312321254,0.7509640284946987,0.2864076080923343,3554,0.7234867677704699 -161.58116555213928,1.027517318725586,3156.599609851837,13423,0,3156.599609851837,0.2874563601189437,3581,0.7419136343505655,3319.675098657608,0.2590251309531076,0.7516891615731376,0.2861069318076902,3554,0.7246937319745357 -165.62573671340942,1.0552878379821775,3236.6604750156403,13767,0,3236.6604750156403,0.2875207529757749,3581,0.7428037488437238,3403.820134162903,0.2592755045209612,0.7525512150355748,0.2862711806074581,3554,0.725603042456563 -169.66971588134766,1.0825152397155762,3316.681303024292,14109,0,3316.681303024292,0.2877811878228846,3581,0.7413269060012916,3487.92396235466,0.2595189298902239,0.751023496900286,0.2864473479211364,3554,0.7241041261826463 -173.71239233016968,1.1100866794586182,3396.6965384483337,14455,0,3396.6965384483337,0.2875540231866099,3581,0.741536821942195,3572.0214359760284,0.2590499605451311,0.7515461785452706,0.2861744070831282,3554,0.7243747142304445 -177.75773978233337,1.1371407508850098,3476.7126338481903,14798,0,3476.7126338481903,0.2871954139468723,3581,0.7420895983140184,3656.121811389923,0.2589538437979562,0.7518469265529087,0.2859152223397053,3554,0.7248509052300225 -181.7992980480194,1.1644668579101562,3556.6941096782684,15139,0,3556.6941096782684,0.2873668100740016,3581,0.7410505178153798,3740.1839442253113,0.2592305285590036,0.7508306503295898,0.2860429084361371,3554,0.723704598361881 -185.84392023086548,1.192673683166504,3636.6995763778687,15484,0,3636.6995763778687,0.2872118104339744,3581,0.7418912042289165,3824.274204969406,0.258493321282523,0.7520397731236049,0.2858670845954822,3554,0.7246311511896807 -189.8890540599823,1.2215955257415771,3716.708996295929,15830,0,3716.708996295929,0.2872363199438006,3581,0.7425571538589081,3908.369560718536,0.2586504561560495,0.7524919509887695,0.2858471803337788,3554,0.7252742700786086 -193.92577123641968,1.2505333423614502,3796.776345968248,16172,0,3796.776345968248,0.2872484213012252,3581,0.7425735162576794,3992.514327287674,0.2588962146214076,0.752331052507673,0.2859032351311902,3554,0.7253598635560284 -197.97216725349423,1.2790398597717283,3876.9645669460297,16518,0,3876.9645669460297,0.2870318922241517,3581,0.7430086197116728,4076.7896065711975,0.2581365278788975,0.7533513477870396,0.2856688663137046,3554,0.7257867318294527 -202.0219411849976,1.3082997798919678,3957.033786058426,16865,0,3957.033786058426,0.2890332522121963,3581,0.7400654332326864,4160.949885845184,0.2604859556470598,0.7501613753182548,0.2876615250586927,3554,0.7228891934044387 -206.0645265579224,1.340169906616211,4037.09632229805,17208,0,4037.09632229805,0.2873627876509704,3581,0.7430537526616169,4245.098760128021,0.2587417704718454,0.7530238287789481,0.2859616255451603,3554,0.7258637384812887 -210.106116771698,1.3692708015441897,4117.163950681686,17550,0,4117.163950681686,0.2871518490601438,3581,0.7427513209909942,4329.249045372009,0.258039082799639,0.7533447401864188,0.2857619303293824,3554,0.7255524832275253 -214.1493580341339,1.3973166942596436,4197.277628183365,17896,0,4197.277628183365,0.2872524778125873,3581,0.7418576613114354,4413.445830821991,0.2586672987256731,0.7520776476178851,0.2858569006203397,3554,0.7246424857994513 -218.1914882659912,1.4288592338562012,4277.350579023361,18240,0,4277.350579023361,0.2870502317461079,3581,0.7430021429288257,4497.604445934296,0.2585871389933994,0.7529400416782924,0.285744481899796,3554,0.7257781450038688 -222.23729467391968,1.456954002380371,4357.3684694767,18583,0,4357.3684694767,0.287416579036931,3581,0.7419922420413292,4581.708220720291,0.2583138602120535,0.7526175635201591,0.2860380311192054,3554,0.724809276299592 -226.27573919296265,1.4869132041931152,4437.445201873779,18928,0,4437.445201873779,0.2870654692299637,3581,0.742520133931688,4665.86545753479,0.2583979368209839,0.7526444707598005,0.2857613464252427,3554,0.7252818264851224 -230.3186469078064,1.516965389251709,4517.462399482727,19275,0,4517.462399482727,0.2868797560039095,3581,0.7416877650708601,4749.96740436554,0.2583988564355032,0.7517557144165039,0.2855900220811937,3554,0.7244192283342712 -234.3654458522797,1.5466229915618896,4597.49485373497,19617,0,4597.49485373497,0.2880329642339779,3581,0.7436792735313111,4834.088021278381,0.2587548664637974,0.7539084298270089,0.2865791213465462,3554,0.7266995457363182 -238.41141080856323,1.579699993133545,4677.679318904877,19961,0,4677.679318904877,0.2874829149286163,3581,0.7423036048546147,4918.363613605499,0.2587899650846209,0.7524251937866211,0.2861384798048853,3554,0.7250721705516672 -242.4510452747345,1.6081418991088867,4757.7959949970245,20307,0,4757.7959949970245,0.2868783583823478,3581,0.7423865076750559,5002.560688257217,0.2583001852035522,0.7524850709097726,0.2855643646463492,3554,0.7251725333690912 -246.49749064445496,1.6389946937561035,4837.757677555084,20651,0,4837.757677555084,0.2873984440449595,3581,0.7424755463950363,5086.611615896225,0.2586125305720738,0.7529078211103167,0.2861139043100643,3554,0.7252262525499438 -250.5354623794556,1.667454719543457,4917.874920606613,20998,0,4917.874920606613,0.2869485803437412,3581,0.743132905765673,5170.8072781562805,0.2580517871039254,0.7535805702209473,0.2855590923354407,3554,0.72609613232889 -254.5735669136048,1.696807622909546,4998.068671941757,21342,0,4998.068671941757,0.286775820683381,3581,0.7431851972650796,5255.080285787582,0.2580889463424682,0.7533642905099052,0.2854546593626899,3554,0.7260072415104459 -258.6179361343384,1.7260563373565674,5078.216583013535,21685,0,5078.216583013535,0.28759192941043,3581,0.7425711300745252,5339.313606500626,0.2587615592139108,0.7526271683829171,0.2861455725228176,3554,0.7254229252031162 -262.6589164733887,1.7574434280395508,5158.255975008011,22031,0,5158.255975008011,0.2871274759036407,3581,0.7418408898526948,5423.437246322632,0.2580763101577759,0.7523194040570941,0.2857501663783325,3554,0.7246382954285664 -266.70121240615845,1.7868409156799316,5238.368992090225,22376,0,5238.368992090225,0.2867114619148806,3581,0.743134405652227,5507.633875846863,0.2578743355614798,0.7534683772495815,0.2853841786982977,3554,0.7259059169685566 -270.7464687824249,1.816502809524536,5318.356097221375,22719,0,5318.356097221375,0.2866599544470818,3581,0.7429689408946524,5591.70784330368,0.2578210149492536,0.753256116594587,0.285283953270083,3554,0.7257600096282358 -274.7826302051544,1.8535089492797847,5398.403444766998,23065,0,5398.403444766998,0.2866411376884948,3581,0.7428008854239389,5675.84022974968,0.2573808772223336,0.7534429005214146,0.2852384774417909,3554,0.7255944556309791 -278.8257200717926,1.882972240447998,5478.536856174469,23411,0,5478.536856174469,0.2866344222873324,3581,0.7427481166879014,5760.058264970779,0.2576410429818289,0.7532764162336077,0.2853181631832091,3554,0.7255148385841658 -282.8697578907013,1.914400339126587,5558.539285182953,23752,0,5558.539285182953,0.286564643474239,3581,0.7431360418921041,5844.148059844971,0.2576276915413992,0.7536858149937221,0.2852405726272334,3554,0.7259083899743247 -286.9048192501068,1.944274663925171,5638.572496652603,24097,0,5638.572496652603,0.286486819815083,3581,0.7432274667952388,5928.258192777634,0.2570815426962716,0.754117625100272,0.2851191377398266,3554,0.7259977616550014 -290.94787549972534,1.973649024963379,5718.696158885956,24442,0,5718.696158885956,0.2867236655372975,3581,0.7431903105146956,6012.466329336166,0.2575973272323608,0.7537539345877511,0.2854220294254713,3554,0.7259774280520188 -294.99489879608154,2.0035831928253174,5798.691824436188,24786,0,5798.691824436188,0.2866072879760367,3581,0.7434939693652262,6096.550888776779,0.2577033042907715,0.753805433000837,0.285351960928707,3554,0.7263403416484947 -299.0336084365845,2.036623239517212,5878.656435251236,25128,0,5878.656435251236,0.2866430807233489,3581,0.7430844321593131,6180.59944486618,0.2571213245391845,0.7541419437953404,0.2853163084288829,3554,0.725858655080543 -303.0692329406738,2.072427272796631,5958.823487281799,25473,0,5958.823487281799,0.2867858085642977,3581,0.7428558358131457,6264.849833726883,0.257537943976266,0.7536706243242536,0.2854699439122292,3554,0.7256218647782429 -307.111353635788,2.1028239727020264,6038.98579120636,25820,0,6038.98579120636,0.2864411755401773,3581,0.7430077334150726,6349.096369981766,0.2573683943067278,0.7535369055611747,0.28513771963039,3554,0.7258339937174663 -311.15779757499695,2.13295578956604,6118.951326370239,26159,0,6118.951326370239,0.2868199650717327,3581,0.7428385871177744,6433.150407791138,0.2571835517883301,0.7538965770176479,0.2853921300987883,3554,0.7256854072875633 -315.20138478279114,2.1633517742156982,6199.053456783295,26505,0,6199.053456783295,0.286448743149609,3581,0.7432574645263195,6517.338342905045,0.2570844377790178,0.7541265487670898,0.2851045401363341,3554,0.7260599989668332 -319.24842977523804,2.198110342025757,6279.126361370087,26851,0,6279.126361370087,0.2866079356543214,3581,0.7429794401005306,6601.50510263443,0.257405264036996,0.7537615639822823,0.2853107098186023,3554,0.7258054167619232 -323.28483629226685,2.2321996688842773,6359.18699669838,27195,0,6359.18699669838,0.2866400127735793,3581,0.7434335648430955,6685.648197650909,0.2570823771612985,0.7544635363987514,0.2852458105908395,3554,0.7263280453142585 -327.3263280391693,2.268504858016968,6439.320326805115,27543,0,6439.320326805115,0.2864332329591071,3581,0.7433979084491064,6769.8714554309845,0.2569711378642491,0.7543215751647949,0.2850779381506753,3554,0.7262150426895752 -331.366498708725,2.300798416137696,6519.368485689163,27886,0,6519.368485689163,0.2864977621692613,3581,0.743686363904112,6854.003954172134,0.2571393081120082,0.754523481641497,0.2852333596937429,3554,0.726480478642023 -335.408899307251,2.330955982208252,6599.540381908417,28229,0,6599.540381908417,0.2869981788650168,3581,0.7423255577396328,6938.260453462601,0.257628389767238,0.7529176303318569,0.2855669750413266,3554,0.7251205315533554 -339.4527425765991,2.362444400787353,6679.681492090225,28572,0,6679.681492090225,0.2864325852808224,3581,0.7430802733829587,7022.488809347153,0.2567955596106393,0.7542286600385394,0.2850730780073948,3554,0.7258950632210186 -343.49573826789856,2.398847341537476,6759.659672021866,28918,0,6759.659672021866,0.2864167682953434,3581,0.7438062866517733,7106.558267116547,0.2569642066955566,0.7547598566327777,0.2851225209491066,3554,0.7266376518975098 -347.5355215072632,2.4366612434387207,6839.815808057785,29264,0,6839.815808057785,0.2865796082514486,3581,0.7426680091105836,7190.804083108902,0.2571125200816563,0.7535268919808524,0.2852112572046901,3554,0.7255190289550506 -351.5742256641388,2.4678423404693604,6920.005371332169,29609,0,6920.005371332169,0.2864999097341001,3581,0.74362766379852,7275.075638055801,0.2566295181001936,0.7549184390476772,0.2850916598979582,3554,0.7264830890370005 -355.6179938316345,2.49917984008789,7000.064207315445,29956,0,7000.064207315445,0.2863981219784103,3581,0.743571213522759,7359.221385478973,0.2568075145993914,0.7546397617885044,0.2850718758518131,3554,0.7264039528524198 -359.66366052627563,2.530510187149048,7080.181899785996,30300,0,7080.181899785996,0.2865445995357442,3581,0.7436359131745671,7443.427941322327,0.256915875843593,0.7547033173697335,0.285162192083304,3554,0.7264863863780248 -363.7105996608734,2.5629425048828125,7160.268867731094,30646,0,7160.268867731094,0.2863884749808014,3581,0.7435238307429838,7527.606162309647,0.2563994271414621,0.755000250680106,0.2850013264928162,3554,0.7263446694085889 -367.7547652721405,2.593646049499512,7240.273091554642,30990,0,7240.273091554642,0.2863450123590652,3581,0.7430925451820372,7611.697166442871,0.2566179037094116,0.7543023654392788,0.284997616984164,3554,0.7258775460968275 -371.80034589767456,2.63110613822937,7320.299663305283,31334,0,7320.299663305283,0.2864117913990505,3581,0.7438424202823932,7695.818750143051,0.2567619255610874,0.7549773624965123,0.2850739882109067,3554,0.7266886919887803 -375.8334729671478,2.669584989547729,7400.453884601593,31681,0,7400.453884601593,0.2862987204058747,3581,0.7438221718139137,7780.056601762772,0.256252782685416,0.7553375107901437,0.2849354655405881,3554,0.7267150020223692 -379.8758656978607,2.7011075019836426,7480.44393825531,32026,0,7480.44393825531,0.2863594317229824,3581,0.7437481319594736,7864.1325850486755,0.256514344896589,0.7550763402666364,0.2849802887701357,3554,0.726590046536473 -383.9117209911346,2.7331085205078125,7560.630806207657,32370,0,7560.630806207657,0.2863833958195162,3581,0.7435609870235269,7948.39928150177,0.256595424243382,0.7548025676182338,0.2850220894370779,3554,0.7264196839168894 -387.9579894542694,2.765145778656006,7640.62170624733,32712,0,7640.62170624733,0.2862728473628176,3581,0.7437498363760123,8032.480271816254,0.2561919007982526,0.7553097861153739,0.2848949013765299,3554,0.7266341484726716 -391.9960012435913,2.797520399093628,7720.608385562897,33057,0,7720.608385562897,0.2862628253935702,3581,0.743716770695162,8116.549202203751,0.2563128471374511,0.7551392146519252,0.2848909170894591,3554,0.7265522645039041 -395.9805171489716,2.829303979873657,7800.907135248184,33401,0,7800.907135248184,0.2863089469051068,3581,0.7435593507836498,8200.876027584076,0.2564116716384887,0.7549631936209542,0.2849628746878517,3554,0.7264295072453574 -400.01979994773865,2.866029500961304,7881.002731561661,33746,0,7881.002731561661,0.2862026254014241,3581,0.7438874168790143,8285.059644460678,0.2561759267534528,0.7553563117980957,0.2848300021487672,3554,0.7267946190691826 -404.0609667301178,2.9068877696990967,7961.019174337387,34090,0,7961.019174337387,0.2861697301622277,3581,0.7438356707929,8369.170066595078,0.256163546017238,0.7552986826215472,0.2847965478762925,3554,0.7266865624560355 -408.103973865509,2.939914464950561,8041.14001584053,34437,0,8041.14001584053,0.2862102270991867,3581,0.7438358753228846,8453.379067897797,0.2562346458435058,0.7552531787327358,0.2848477081831211,3554,0.7266829216419879 -412.1444194316864,2.974390745162964,8121.161528110504,34779,0,8121.161528110504,0.2861768205350286,3581,0.7437871290098785,8537.487291812897,0.2561770507267543,0.7552689143589565,0.2848144599944605,3554,0.7266373771190912 -416.184344291687,3.0172266960144043,8201.258660793304,35125,0,8201.258660793304,0.2861144729780438,3581,0.7439305727057736,8621.679370641708,0.2560735600335257,0.7554235458374023,0.2847423134859049,3554,0.7267863757166221 -420.22679924964905,3.0492374897003174,8281.31056690216,35471,0,8281.31056690216,0.2861540836184026,3581,0.744045791263788,8705.81771326065,0.2561214140483311,0.7555623735700335,0.2847858830389174,3554,0.7269094764481921 -424.2720458507538,3.081813335418701,8361.536612510681,35814,0,8361.536612510681,0.2861507770503176,3581,0.7439735240025481,8790.133450984955,0.2561060871396746,0.7554898943219867,0.2847845778414287,3554,0.7268266307549592 -428.31070947647095,3.1168370246887207,8441.574692964554,36159,0,8441.574692964554,0.2861035647121963,3581,0.7439485713444219,8874.257183551788,0.256074139050075,0.7554442541939872,0.2847391707077413,3554,0.7268051980383019 -432.35393595695496,3.1495649814605713,8446.291797876358,36189,0,8446.291797876358,0.28610349653553474,3581,0.7439481622844527,8883.051620006561,0.2560740368706839,0.7554440498352051,0.2847391020131366,3554,0.7268047858706739 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/measurements.csv deleted file mode 100644 index 060345290..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.3851867,0.90665424,,,,,,,,,,,,,, -1,,,0.2670762368610927,0.9032635007585798,0.2587972028106535,0.915261701713738,3554.0,0.2829431162570249,0.912513662602974,3581.0,33.548996925354004,37.53365206718445,33.548996925354004,3.984562873840332,0.0,0.0 -100,0.73464495,0.29955524,,,,,,,,,,,,,, -200,0.17185001,0.3412455,,,,,,,,,,,,,, -300,0.12169004,0.36157715,,,,,,,,,,,,,, -343,,,0.6902498517717633,0.3141788755144392,0.6638744141724113,0.3427278929771032,3554.0,0.6831162406625244,0.3437465229902611,3581.0,113.72720742225648,121.78289651870728,113.72720742225648,8.02499008178711,0.0185849666595459,0.0 -400,0.1172807,0.34984744,,,,,,,,,,,,,, -500,0.18207213,0.28638777,,,,,,,,,,,,,, -600,0.13630565,0.27945775,,,,,,,,,,,,,, -683,,,0.7197771753583636,0.2880679539271763,0.6944376055149127,0.3150176572611846,3554.0,0.7118955866247207,0.3172561061745148,3581.0,193.7486245632172,205.8772604465485,193.7486245632172,12.062600135803224,0.0420329570770263,0.0 -700,0.110142305,0.27766237,,,,,,,,,,,,,, -800,0.06203968,0.23101074,,,,,,,,,,,,,, -900,0.23563707,0.33998543,,,,,,,,,,,,,, -1000,0.09870361,0.23403202,,,,,,,,,,,,,, -1023,,,0.7243311745779855,0.2809505462646484,0.6999504162343486,0.307128803208093,3554.0,0.7172653850181514,0.3091897844635751,3581.0,273.89558959007263,290.09849548339844,273.89558959007263,16.101449489593506,0.0656051635742187,0.0 -1100,0.13454455,0.3583073,,,,,,,,,,,,,, -1200,0.090452775,0.2270513,,,,,,,,,,,,,, -1300,0.09703547,0.31518123,,,,,,,,,,,,,, -1370,,,0.7353413445608956,0.2731507675988333,0.7095289171312253,0.2996465525200478,3554.0,0.7265489328443522,0.3018685313983524,3581.0,354.088178396225,374.368869304657,354.088178396225,20.14368963241577,0.0890505313873291,0.0 -1400,0.4017212,0.29849243,,,,,,,,,,,,,, -1500,0.098531626,0.40212965,,,,,,,,,,,,,, -1600,0.115611486,0.2293644,,,,,,,,,,,,,, -1700,0.05383173,0.33089426,,,,,,,,,,,,,, -1713,,,0.7381311144147601,0.2702325752803257,0.7123925204215321,0.296381807085766,3554.0,0.7295050728890324,0.2983961235841071,3581.0,434.0783689022064,458.4306387901306,434.0783689022064,24.17908215522766,0.1127169132232666,0.0 -1800,0.3251291,0.26426527,,,,,,,,,,,,,, -1900,0.06125223,0.26045698,,,,,,,,,,,,,, -2000,0.20731725,0.24566732,,,,,,,,,,,,,, -2055,,,0.7401712962559291,0.2689672878810337,0.7145240453098621,0.2949395638607027,3554.0,0.7316638187002583,0.29680184644216,3581.0,514.2202973365784,542.6508350372314,514.2202973365784,28.21941351890564,0.1388208866119384,0.0 -2100,0.065006614,0.33147463,,,,,,,,,,,,,, -2200,0.2969419,0.25515926,,,,,,,,,,,,,, -2300,0.07403956,0.24973074,,,,,,,,,,,,,, -2400,0.13143134,0.3539455,,,,,,,,,,,,,, -2401,,,0.742351940699986,0.2663253034864153,0.7161404293577659,0.2926883042764315,3554.0,0.7333096714866657,0.2944348209299078,3581.0,594.2618770599365,626.7217772006989,594.2618770599365,32.20990490913391,0.1659994125366211,0.0 -2500,0.20455775,0.27262956,,,,,,,,,,,,,, -2600,0.13322222,0.27642897,,,,,,,,,,,,,, -2700,0.08473213,0.26037446,,,,,,,,,,,,,, -2746,,,0.7398902348109654,0.268444401877267,0.7138087970860298,0.2946146383806098,3554.0,0.7309620763229545,0.2963557324573269,3581.0,674.29638504982,710.8356049060822,674.29638504982,36.25227451324463,0.1907820701599121,0.0 -2800,0.08253322,0.3125131,,,,,,,,,,,,,, -2900,0.18096769,0.23595259,,,,,,,,,,,,,, -3000,0.114913896,0.26709944,,,,,,,,,,,,,, -3090,,,0.7444577898297992,0.2649519273212978,0.7183651041300295,0.2911634557892515,3554.0,0.7354993013255725,0.2928027739174637,3581.0,754.3580448627472,794.9765889644623,754.3580448627472,40.29451560974121,0.2158269882202148,0.0 -3100,0.33233148,0.25826767,,,,,,,,,,,,,, -3200,0.24717034,0.36920714,,,,,,,,,,,,,, -3300,0.054476622,0.30659473,,,,,,,,,,,,,, -3400,0.15108189,0.2860933,,,,,,,,,,,,,, -3433,,,0.7445534297398159,0.2641278845923288,0.7181353893720104,0.2906195319094682,3554.0,0.7354296247774714,0.2922338396768535,3581.0,834.3319246768951,879.0295767784119,834.3319246768951,44.33660364151001,0.2407479286193847,0.0 -3500,0.13191298,0.26078612,,,,,,,,,,,,,, -3600,0.07065319,0.29091644,,,,,,,,,,,,,, -3700,0.21597683,0.28205517,,,,,,,,,,,,,, -3768,,,0.7440680095127651,0.2646568502698626,0.7183314437737408,0.2905039188898072,3554.0,0.7355386392592851,0.2921161326706925,3581.0,914.4203317165376,963.192458152771,914.4203317165376,48.37433314323425,0.2658648490905761,0.0 -3800,0.14812645,0.27104437,,,,,,,,,,,,,, -3900,0.27396396,0.26730162,,,,,,,,,,,,,, -4000,0.08955052,0.29607725,,,,,,,,,,,,,, -4100,0.11301871,0.268641,,,,,,,,,,,,,, -4112,,,0.7450623512268066,0.2639062404632568,0.7189520995269415,0.2901824624872502,3554.0,0.7359028389852694,0.2918586294200293,3581.0,994.5614778995514,1047.414608478546,994.5614778995514,52.41978621482849,0.2896695137023926,0.0 -4200,0.33595023,0.24440192,,,,,,,,,,,,,, -4300,0.1427521,0.2869095,,,,,,,,,,,,,, -4400,0.09008625,0.29391235,,,,,,,,,,,,,, -4454,,,0.747828483581543,0.2626642840249197,0.7210435754607485,0.2893484756392445,3554.0,0.7382009378926976,0.2909781960128106,3581.0,1074.5751271247864,1131.507836341858,1074.5751271247864,56.46136999130249,0.3157362937927246,0.0 -4500,0.19494565,0.20190588,,,,,,,,,,,,,, -4600,0.14666744,0.20672828,,,,,,,,,,,,,, -4700,0.13488436,0.22002858,,,,,,,,,,,,,, -4800,,,0.7469865253993443,0.2625446149281093,0.7206881495761818,0.2888648656223621,3554.0,0.7379298674863864,0.2904206131863481,3581.0,1154.6210136413574,1215.6344573497772,1154.6210136413574,60.50531268119812,0.3406782150268554,0.0 -4800,0.121489294,0.31324002,,,,,,,,,,,,,, -4900,0.0457071,0.28894824,,,,,,,,,,,,,, -5000,0.3968094,0.37030262,,,,,,,,,,,,,, -5100,0.14313945,0.32857782,,,,,,,,,,,,,, -5147,,,0.747349670955113,0.2621636731284005,0.7209261076867614,0.2884935369368142,3554.0,0.738178439594387,0.2899860551456472,3581.0,1234.7858142852783,1299.8765351772308,1234.7858142852783,64.54547190666199,0.3658421039581299,0.0 -5200,0.07771009,0.2370094,,,,,,,,,,,,,, -5300,0.32967773,0.22128561,,,,,,,,,,,,,, -5400,0.45909643,0.33424562,,,,,,,,,,,,,, -5490,,,0.7481654030936105,0.2623655114855085,0.7215267733100028,0.289046837630135,3554.0,0.7386932415657288,0.2906221093095504,3581.0,1314.8512017726898,1384.0182185173037,1314.8512017726898,68.58561253547668,0.3900299072265625,0.0 -5500,0.038014885,0.27983546,,,,,,,,,,,,,, -5600,0.15527634,0.25551006,,,,,,,,,,,,,, -5700,0.1435485,0.29555768,,,,,,,,,,,,,, -5800,0.0541981,0.2943288,,,,,,,,,,,,,, -5835,,,0.7441043172563825,0.2656274012156895,0.7176361169852631,0.2921757738309827,3554.0,0.7345956878534278,0.2937657010851542,3581.0,1395.0299680233002,1468.2786684036255,1395.0299680233002,72.6303391456604,0.4150753021240234,0.0 -5900,0.2655278,0.30005556,,,,,,,,,,,,,, -6000,0.14066833,0.2113004,,,,,,,,,,,,,, -6100,0.08896887,0.32168892,,,,,,,,,,,,,, -6180,,,0.7495601517813546,0.2616130965096609,0.7231619796795864,0.2881927576102631,3554.0,0.7403823183555571,0.2896401949416189,3581.0,1475.153118133545,1552.4831855297089,1475.153118133545,76.67470240592957,0.4403700828552246,0.0 -6200,0.17510897,0.29394627,,,,,,,,,,,,,, -6300,0.20604527,0.23797734,,,,,,,,,,,,,, -6400,0.29496422,0.23315926,,,,,,,,,,,,,, -6500,0.20745455,0.25706816,,,,,,,,,,,,,, -6522,,,0.7479400634765625,0.2620629412787301,0.7214551935319359,0.2886338800141566,3554.0,0.7387238528867635,0.2900773777837894,3581.0,1555.2574756145475,1636.669602394104,1555.2574756145475,80.7171881198883,0.4680335521697998,0.0 -6600,0.16653976,0.3016951,,,,,,,,,,,,,, -6700,0.1004335,0.28865463,,,,,,,,,,,,,, -6800,0.112788565,0.3078722,,,,,,,,,,,,,, -6869,,,0.749246529170445,0.260955878666469,0.7226707445615855,0.2875545675592202,3554.0,0.7399360339290701,0.2890012091812692,3581.0,1635.29860162735,1720.7873928546906,1635.29860162735,84.75681829452515,0.4930555820465088,0.0 -6900,0.092967406,0.35122767,,,,,,,,,,,,,, -7000,0.17688708,0.2436305,,,,,,,,,,,,,, -7100,0.22380938,0.2464972,,,,,,,,,,,,,, -7200,0.15091902,0.2750969,,,,,,,,,,,,,, -7216,,,0.7487942150660923,0.261763368334089,0.7220146423923748,0.2888460433006823,3554.0,0.7392699479457554,0.2902471376710416,3581.0,1715.3318283557892,1804.9016375541687,1715.3318283557892,88.79988646507263,0.5191292762756348,0.0 -7300,0.09910377,0.27641734,,,,,,,,,,,,,, -7400,0.2629098,0.258754,,,,,,,,,,,,,, -7500,0.11242308,0.203661,,,,,,,,,,,,,, -7561,,,0.7482390403747559,0.2618751525878906,0.7224570356464547,0.2880615852626442,3554.0,0.7396165580930606,0.2895986071780753,3581.0,1795.4933722019196,1889.1465952396395,1795.4933722019196,92.84628129005432,0.5441994667053223,0.0 -7600,0.12963706,0.26161423,,,,,,,,,,,,,, -7700,0.140382,0.26491275,,,,,,,,,,,,,, -7800,0.08476832,0.280011,,,,,,,,,,,,,, -7900,0.4567794,0.26765874,,,,,,,,,,,,,, -7906,,,0.7501950945172992,0.2602382046835763,0.7236638624613112,0.2870323168272105,3554.0,0.7408963022069603,0.2884128445921181,3581.0,1875.5250089168549,1973.261443376541,1875.5250089168549,96.89130020141602,0.5703647136688232,0.0 -8000,0.07655296,0.29895964,,,,,,,,,,,,,, -8100,0.288826,0.2185068,,,,,,,,,,,,,, -8200,0.039417554,0.27500412,,,,,,,,,,,,,, -8250,,,0.7495725495474679,0.2608766555786133,0.7226769957706106,0.2876780117638136,3554.0,0.739940397235409,0.289103440085259,3581.0,1955.5428388118744,2057.3702251911163,1955.5428388118744,100.94035053253174,0.6000580787658691,0.0 -8300,0.28032222,0.24848399,,,,,,,,,,,,,, -8400,0.051985394,0.2678765,,,,,,,,,,,,,, -8500,0.26837173,0.35159436,,,,,,,,,,,,,, -8593,,,0.750629220690046,0.2605338437216623,0.723893233746307,0.2872329394201516,3554.0,0.7411272165596202,0.2886298168174916,3581.0,2035.5123291015625,2141.427557706833,2035.5123291015625,104.98532390594482,0.6310033798217773,0.0 -8600,0.14119965,0.34335208,,,,,,,,,,,,,, -8700,0.46431017,0.20633647,,,,,,,,,,,,,, -8800,0.102895595,0.20329988,,,,,,,,,,,,,, -8900,0.06715999,0.23275492,,,,,,,,,,,,,, -8936,,,0.751429694039481,0.2599715845925467,0.7244846255979178,0.2870964088933684,3554.0,0.741647745370532,0.2885630036891755,3581.0,2115.628773212433,2225.628358125686,2115.628773212433,109.029296875,0.6593132019042969,0.0 -9000,0.57068855,0.19994634,,,,,,,,,,,,,, -9100,0.2604382,0.30396068,,,,,,,,,,,,,, -9200,0.292354,0.21030118,,,,,,,,,,,,,, -9282,,,0.7505640983581543,0.2599564790725708,0.7239608979319077,0.2866972417192776,3554.0,0.7411813488288885,0.2881314113332519,3581.0,2195.7629470825195,2309.8518760204315,2195.7629470825195,113.07417154312134,0.6917684078216553,0.0 -9300,0.039064143,0.31832618,,,,,,,,,,,,,, -9400,0.10915758,0.23039192,,,,,,,,,,,,,, -9500,0.16567649,0.26924303,,,,,,,,,,,,,, -9600,0.07630023,0.27036497,,,,,,,,,,,,,, -9625,,,0.746929236820766,0.2619223764964512,0.7217569689302546,0.2881119899288214,3554.0,0.7387132173275621,0.2894506638225181,3581.0,2275.9334716796875,2394.1086626052856,2275.9334716796875,117.11841011047365,0.7217345237731934,0.0 -9700,0.06435208,0.24756287,,,,,,,,,,,,,, -9800,0.3578377,0.26207405,,,,,,,,,,,,,, -9900,0.12461722,0.25294378,,,,,,,,,,,,,, -9971,,,0.7508612360273089,0.2593847853796823,0.7234695941193022,0.2865976001952026,3554.0,0.740755926460835,0.2880375320703016,3581.0,2356.0825872421265,2478.336655139923,2356.0825872421265,121.1582088470459,0.7490215301513672,0.0 -10000,0.10553235,0.24950096,,,,,,,,,,,,,, -10100,0.07191474,0.28603852,,,,,,,,,,,,,, -10200,0.07080932,0.3514356,,,,,,,,,,,,,, -10300,0.062999934,0.24443142,,,,,,,,,,,,,, -10318,,,0.7500311306544712,0.2598751102175031,0.7230898503446821,0.2866478846458216,3554.0,0.740412725146607,0.2879811158828714,3581.0,2436.082048892975,2562.416161775589,2436.082048892975,125.19908928871156,0.7764010429382324,0.0 -10400,0.24823613,0.30747116,,,,,,,,,,,,,, -10500,0.18094191,0.21511579,,,,,,,,,,,,,, -10600,0.116886094,0.25442445,,,,,,,,,,,,,, -10663,,,0.7506626674107143,0.2597089494977678,0.724038454140581,0.2864253484739906,3554.0,0.7412570930998673,0.2877796879363306,3581.0,2516.108592271805,2646.5201659202576,2516.108592271805,129.23694705963135,0.8038809299468994,0.0 -10700,0.17714216,0.31591594,,,,,,,,,,,,,, -10800,0.15891258,0.19536725,,,,,,,,,,,,,, -10900,0.15083407,0.2567985,,,,,,,,,,,,,, -11000,0.20231277,0.33263537,,,,,,,,,,,,,, -11009,,,0.7513034003121513,0.2593451397759573,0.7240849603879431,0.2865712901616137,3554.0,0.7412414806443731,0.2879982623132505,3581.0,2596.211658477783,2730.702555894852,2596.211658477783,133.27631831169128,0.8316900730133057,0.0 -11100,0.30856264,0.25031418,,,,,,,,,,,,,, -11200,0.16355087,0.24660762,,,,,,,,,,,,,, -11300,0.09664441,0.30012178,,,,,,,,,,,,,, -11355,,,0.7510937282017299,0.2594878673553467,0.72414403774796,0.286352789797807,3554.0,0.7413318828975844,0.2877812559995462,3581.0,2676.3213658332825,2814.900425195694,2676.3213658332825,137.32483291625977,0.8595569133758545,0.0 -11400,0.118315406,0.25939614,,,,,,,,,,,,,, -11500,0.17695808,0.29037791,,,,,,,,,,,,,, -11600,0.5534125,0.2459431,,,,,,,,,,,,,, -11698,,,0.7509194101606097,0.2594446284430368,0.7242609559651098,0.2861528713245638,3554.0,0.7414494876387532,0.2875438307757086,3581.0,2756.2938911914825,2898.9560811519623,2756.2938911914825,141.3688714504242,0.8870675563812256,0.0 -11700,0.08339172,0.24153621,,,,,,,,,,,,,, -11800,0.107461765,0.27820557,,,,,,,,,,,,,, -11900,0.14062808,0.33839548,,,,,,,,,,,,,, -12000,0.2493072,0.32520667,,,,,,,,,,,,,, -12040,,,0.7494462558201381,0.2613473790032523,0.7227448660400253,0.2880535823412,3554.0,0.7397514797062622,0.2896840325349937,3581.0,2836.426928758621,2983.175198316574,2836.426928758621,145.4126329421997,0.9173834323883056,0.0 -12100,0.47914824,0.22050239,,,,,,,,,,,,,, -12200,0.20570499,0.271333,,,,,,,,,,,,,, -12300,0.06469487,0.27972612,,,,,,,,,,,,,, -12386,,,0.750204358782087,0.2600323813302176,0.7231275636826463,0.2871197822326076,3554.0,0.7402833940196524,0.2885629014241831,3581.0,2916.431208133697,3067.2598235607147,2916.431208133697,149.45334696769714,0.9447588920593262,0.0 -12400,0.18605609,0.25182456,,,,,,,,,,,,,, -12500,0.22032598,0.33237392,,,,,,,,,,,,,, -12600,0.09342811,0.37123752,,,,,,,,,,,,,, -12700,0.064038314,0.28961712,,,,,,,,,,,,,, -12733,,,0.7511591911315918,0.2594108922140939,0.7242919372318163,0.2863031064249789,3554.0,0.7415381172987643,0.2877028869270979,3581.0,2996.577868938446,3151.4868512153625,2996.577868938446,153.49388575553894,0.9722623825073242,0.0 -12800,0.1541102,0.2492687,,,,,,,,,,,,,, -12900,0.16421591,0.21803002,,,,,,,,,,,,,, -13000,0.07219181,0.28327546,,,,,,,,,,,,,, -13076,,,0.7509640284946987,0.2592736312321254,0.7234867677704699,0.2864076080923343,3554.0,0.7408127857965652,0.2877901871422089,3581.0,3076.590484857559,3235.581696033478,3076.590484857559,157.53622794151306,1.0001962184906006,0.0 -13100,0.21390577,0.24020253,,,,,,,,,,,,,, -13200,0.25721717,0.19093706,,,,,,,,,,,,,, -13300,0.14545168,0.25479645,,,,,,,,,,,,,, -13400,0.12566991,0.28638873,,,,,,,,,,,,,, -13423,,,0.7516891615731376,0.2590251309531076,0.7246937319745357,0.2861069318076902,3554.0,0.7419136343505655,0.2874563601189437,3581.0,3156.599609851837,3319.675098657608,3156.599609851837,161.58116555213928,1.027517318725586,0.0 -13500,0.12567349,0.26871464,,,,,,,,,,,,,, -13600,0.10883336,0.21630716,,,,,,,,,,,,,, -13700,0.12799133,0.32717803,,,,,,,,,,,,,, -13767,,,0.7525512150355748,0.2592755045209612,0.725603042456563,0.2862711806074581,3554.0,0.7428037488437238,0.2875207529757749,3581.0,3236.6604750156403,3403.820134162903,3236.6604750156403,165.62573671340942,1.0552878379821775,0.0 -13800,0.096021146,0.21552621,,,,,,,,,,,,,, -13900,0.37767977,0.23470062,,,,,,,,,,,,,, -14000,0.26120824,0.3308072,,,,,,,,,,,,,, -14100,0.19437471,0.27711505,,,,,,,,,,,,,, -14109,,,0.751023496900286,0.2595189298902239,0.7241041261826463,0.2864473479211364,3554.0,0.7413269060012916,0.2877811878228846,3581.0,3316.681303024292,3487.92396235466,3316.681303024292,169.66971588134766,1.0825152397155762,0.0 -14200,0.20678046,0.23693478,,,,,,,,,,,,,, -14300,0.14970498,0.287993,,,,,,,,,,,,,, -14400,0.124735035,0.22830296,,,,,,,,,,,,,, -14455,,,0.7515461785452706,0.2590499605451311,0.7243747142304445,0.2861744070831282,3554.0,0.741536821942195,0.2875540231866099,3581.0,3396.6965384483337,3572.0214359760284,3396.6965384483337,173.71239233016968,1.1100866794586182,0.0 -14500,0.12703212,0.3648455,,,,,,,,,,,,,, -14600,0.24106823,0.2040048,,,,,,,,,,,,,, -14700,0.09541106,0.22359142,,,,,,,,,,,,,, -14798,,,0.7518469265529087,0.2589538437979562,0.7248509052300225,0.2859152223397053,3554.0,0.7420895983140184,0.2871954139468723,3581.0,3476.7126338481903,3656.121811389923,3476.7126338481903,177.75773978233337,1.1371407508850098,0.0 -14800,0.07158663,0.2661956,,,,,,,,,,,,,, -14900,0.107699044,0.23443367,,,,,,,,,,,,,, -15000,0.21611871,0.2277636,,,,,,,,,,,,,, -15100,0.14223589,0.20763916,,,,,,,,,,,,,, -15139,,,0.7508306503295898,0.2592305285590036,0.723704598361881,0.2860429084361371,3554.0,0.7410505178153798,0.2873668100740016,3581.0,3556.6941096782684,3740.1839442253113,3556.6941096782684,181.7992980480194,1.1644668579101562,0.0 -15200,0.110962294,0.24423617,,,,,,,,,,,,,, -15300,0.123915985,0.24155363,,,,,,,,,,,,,, -15400,0.14898331,0.23439194,,,,,,,,,,,,,, -15484,,,0.7520397731236049,0.258493321282523,0.7246311511896807,0.2858670845954822,3554.0,0.7418912042289165,0.2872118104339744,3581.0,3636.6995763778687,3824.274204969406,3636.6995763778687,185.84392023086548,1.192673683166504,0.0 -15500,0.25242892,0.34679797,,,,,,,,,,,,,, -15600,0.09117434,0.30396318,,,,,,,,,,,,,, -15700,0.17995648,0.24575032,,,,,,,,,,,,,, -15800,0.08936985,0.30331013,,,,,,,,,,,,,, -15830,,,0.7524919509887695,0.2586504561560495,0.7252742700786086,0.2858471803337788,3554.0,0.7425571538589081,0.2872363199438006,3581.0,3716.708996295929,3908.369560718536,3716.708996295929,189.8890540599823,1.2215955257415771,0.0 -15900,0.15948673,0.29793197,,,,,,,,,,,,,, -16000,0.23896563,0.22410852,,,,,,,,,,,,,, -16100,0.1424174,0.27653903,,,,,,,,,,,,,, -16172,,,0.752331052507673,0.2588962146214076,0.7253598635560284,0.2859032351311902,3554.0,0.7425735162576794,0.2872484213012252,3581.0,3796.776345968248,3992.514327287674,3796.776345968248,193.92577123641968,1.2505333423614502,0.0 -16200,0.10758954,0.25830063,,,,,,,,,,,,,, -16300,0.28752622,0.22188398,,,,,,,,,,,,,, -16400,0.28003463,0.24119082,,,,,,,,,,,,,, -16500,0.12647815,0.20794418,,,,,,,,,,,,,, -16518,,,0.7533513477870396,0.2581365278788975,0.7257867318294527,0.2856688663137046,3554.0,0.7430086197116728,0.2870318922241517,3581.0,3876.9645669460297,4076.7896065711975,3876.9645669460297,197.97216725349423,1.2790398597717283,0.0 -16600,0.2541649,0.18869112,,,,,,,,,,,,,, -16700,0.29615793,0.26604378,,,,,,,,,,,,,, -16800,0.14609107,0.2372062,,,,,,,,,,,,,, -16865,,,0.7501613753182548,0.2604859556470598,0.7228891934044387,0.2876615250586927,3554.0,0.7400654332326864,0.2890332522121963,3581.0,3957.033786058426,4160.949885845184,3957.033786058426,202.0219411849976,1.3082997798919678,0.0 -16900,0.07571279,0.3546036,,,,,,,,,,,,,, -17000,0.08219219,0.23977993,,,,,,,,,,,,,, -17100,0.35983,0.27584007,,,,,,,,,,,,,, -17200,0.27450746,0.34417835,,,,,,,,,,,,,, -17208,,,0.7530238287789481,0.2587417704718454,0.7258637384812887,0.2859616255451603,3554.0,0.7430537526616169,0.2873627876509704,3581.0,4037.09632229805,4245.098760128021,4037.09632229805,206.0645265579224,1.340169906616211,0.0 -17300,0.11540617,0.30723608,,,,,,,,,,,,,, -17400,0.2135419,0.25843352,,,,,,,,,,,,,, -17500,0.13720022,0.301931,,,,,,,,,,,,,, -17550,,,0.7533447401864188,0.258039082799639,0.7255524832275253,0.2857619303293824,3554.0,0.7427513209909942,0.2871518490601438,3581.0,4117.163950681686,4329.249045372009,4117.163950681686,210.106116771698,1.3692708015441897,0.0 -17600,0.17844251,0.25719154,,,,,,,,,,,,,, -17700,0.15155593,0.285876,,,,,,,,,,,,,, -17800,0.19251384,0.36166534,,,,,,,,,,,,,, -17896,,,0.7520776476178851,0.2586672987256731,0.7246424857994513,0.2858569006203397,3554.0,0.7418576613114354,0.2872524778125873,3581.0,4197.277628183365,4413.445830821991,4197.277628183365,214.1493580341339,1.3973166942596436,0.0 -17900,0.07467258,0.31996608,,,,,,,,,,,,,, -18000,0.06804511,0.367459,,,,,,,,,,,,,, -18100,0.17531532,0.22976612,,,,,,,,,,,,,, -18200,0.101195365,0.35830587,,,,,,,,,,,,,, -18240,,,0.7529400416782924,0.2585871389933994,0.7257781450038688,0.285744481899796,3554.0,0.7430021429288257,0.2870502317461079,3581.0,4277.350579023361,4497.604445934296,4277.350579023361,218.1914882659912,1.4288592338562012,0.0 -18300,0.39243627,0.20115791,,,,,,,,,,,,,, -18400,0.10797999,0.28554794,,,,,,,,,,,,,, -18500,0.22513168,0.23100676,,,,,,,,,,,,,, -18583,,,0.7526175635201591,0.2583138602120535,0.724809276299592,0.2860380311192054,3554.0,0.7419922420413292,0.287416579036931,3581.0,4357.3684694767,4581.708220720291,4357.3684694767,222.23729467391968,1.456954002380371,0.0 -18600,0.098066896,0.29429746,,,,,,,,,,,,,, -18700,0.18332487,0.30317798,,,,,,,,,,,,,, -18800,0.10352558,0.25969285,,,,,,,,,,,,,, -18900,0.1504757,0.27855718,,,,,,,,,,,,,, -18928,,,0.7526444707598005,0.2583979368209839,0.7252818264851224,0.2857613464252427,3554.0,0.742520133931688,0.2870654692299637,3581.0,4437.445201873779,4665.86545753479,4437.445201873779,226.27573919296265,1.4869132041931152,0.0 -19000,0.14218903,0.29266024,,,,,,,,,,,,,, -19100,0.12513508,0.36329833,,,,,,,,,,,,,, -19200,0.19503781,0.3516592,,,,,,,,,,,,,, -19275,,,0.7517557144165039,0.2583988564355032,0.7244192283342712,0.2855900220811937,3554.0,0.7416877650708601,0.2868797560039095,3581.0,4517.462399482727,4749.96740436554,4517.462399482727,230.3186469078064,1.516965389251709,0.0 -19300,0.10850697,0.3162285,,,,,,,,,,,,,, -19400,0.15476114,0.20236573,,,,,,,,,,,,,, -19500,0.24238557,0.31090787,,,,,,,,,,,,,, -19600,0.14040995,0.25828958,,,,,,,,,,,,,, -19617,,,0.7539084298270089,0.2587548664637974,0.7266995457363182,0.2865791213465462,3554.0,0.7436792735313111,0.2880329642339779,3581.0,4597.49485373497,4834.088021278381,4597.49485373497,234.3654458522797,1.5466229915618896,0.0 -19700,0.2889953,0.2482054,,,,,,,,,,,,,, -19800,0.083225645,0.28445196,,,,,,,,,,,,,, -19900,0.10388808,0.27576,,,,,,,,,,,,,, -19961,,,0.7524251937866211,0.2587899650846209,0.7250721705516672,0.2861384798048853,3554.0,0.7423036048546147,0.2874829149286163,3581.0,4677.679318904877,4918.363613605499,4677.679318904877,238.41141080856323,1.579699993133545,0.0 -20000,0.19897643,0.3486297,,,,,,,,,,,,,, -20100,0.15512659,0.4156347,,,,,,,,,,,,,, -20200,0.3441517,0.28034192,,,,,,,,,,,,,, -20300,0.102691144,0.24206825,,,,,,,,,,,,,, -20307,,,0.7524850709097726,0.2583001852035522,0.7251725333690912,0.2855643646463492,3554.0,0.7423865076750559,0.2868783583823478,3581.0,4757.7959949970245,5002.560688257217,4757.7959949970245,242.4510452747345,1.6081418991088867,0.0 -20400,0.09828502,0.2368963,,,,,,,,,,,,,, -20500,0.47793162,0.2970239,,,,,,,,,,,,,, -20600,0.14772978,0.24054305,,,,,,,,,,,,,, -20651,,,0.7529078211103167,0.2586125305720738,0.7252262525499438,0.2861139043100643,3554.0,0.7424755463950363,0.2873984440449595,3581.0,4837.757677555084,5086.611615896225,4837.757677555084,246.49749064445496,1.6389946937561035,0.0 -20700,0.10625385,0.23040198,,,,,,,,,,,,,, -20800,0.084608346,0.27076355,,,,,,,,,,,,,, -20900,0.1887727,0.28524068,,,,,,,,,,,,,, -20998,,,0.7535805702209473,0.2580517871039254,0.72609613232889,0.2855590923354407,3554.0,0.743132905765673,0.2869485803437412,3581.0,4917.874920606613,5170.8072781562805,4917.874920606613,250.5354623794556,1.667454719543457,0.0 -21000,0.22462645,0.3905931,,,,,,,,,,,,,, -21100,0.14898992,0.28342694,,,,,,,,,,,,,, -21200,0.13445958,0.22068666,,,,,,,,,,,,,, -21300,0.14676233,0.30273375,,,,,,,,,,,,,, -21342,,,0.7533642905099052,0.2580889463424682,0.7260072415104459,0.2854546593626899,3554.0,0.7431851972650796,0.286775820683381,3581.0,4998.068671941757,5255.080285787582,4998.068671941757,254.5735669136048,1.696807622909546,0.0 -21400,0.16747281,0.25510475,,,,,,,,,,,,,, -21500,0.11639338,0.25310707,,,,,,,,,,,,,, -21600,0.13293608,0.23537438,,,,,,,,,,,,,, -21685,,,0.7526271683829171,0.2587615592139108,0.7254229252031162,0.2861455725228176,3554.0,0.7425711300745252,0.28759192941043,3581.0,5078.216583013535,5339.313606500626,5078.216583013535,258.6179361343384,1.7260563373565674,0.0 -21700,0.068308614,0.39718905,,,,,,,,,,,,,, -21800,0.31249416,0.18065137,,,,,,,,,,,,,, -21900,0.2097054,0.27714196,,,,,,,,,,,,,, -22000,0.23467366,0.3548164,,,,,,,,,,,,,, -22031,,,0.7523194040570941,0.2580763101577759,0.7246382954285664,0.2857501663783325,3554.0,0.7418408898526948,0.2871274759036407,3581.0,5158.255975008011,5423.437246322632,5158.255975008011,262.6589164733887,1.7574434280395508,0.0 -22100,0.10015155,0.29218322,,,,,,,,,,,,,, -22200,0.11988747,0.27281326,,,,,,,,,,,,,, -22300,0.28568786,0.25152317,,,,,,,,,,,,,, -22376,,,0.7534683772495815,0.2578743355614798,0.7259059169685566,0.2853841786982977,3554.0,0.743134405652227,0.2867114619148806,3581.0,5238.368992090225,5507.633875846863,5238.368992090225,266.70121240615845,1.7868409156799316,0.0 -22400,0.17764528,0.28936246,,,,,,,,,,,,,, -22500,0.32303214,0.24199763,,,,,,,,,,,,,, -22600,0.1803266,0.23843387,,,,,,,,,,,,,, -22700,0.1810165,0.19525628,,,,,,,,,,,,,, -22719,,,0.753256116594587,0.2578210149492536,0.7257600096282358,0.285283953270083,3554.0,0.7429689408946524,0.2866599544470818,3581.0,5318.356097221375,5591.70784330368,5318.356097221375,270.7464687824249,1.816502809524536,0.0 -22800,0.19999133,0.23476008,,,,,,,,,,,,,, -22900,0.4409805,0.19523372,,,,,,,,,,,,,, -23000,0.17151299,0.26977324,,,,,,,,,,,,,, -23065,,,0.7534429005214146,0.2573808772223336,0.7255944556309791,0.2852384774417909,3554.0,0.7428008854239389,0.2866411376884948,3581.0,5398.403444766998,5675.84022974968,5398.403444766998,274.7826302051544,1.8535089492797847,0.0 -23100,0.20393167,0.30176896,,,,,,,,,,,,,, -23200,0.14429297,0.3280239,,,,,,,,,,,,,, -23300,0.25026986,0.25022802,,,,,,,,,,,,,, -23400,0.14571853,0.30818635,,,,,,,,,,,,,, -23411,,,0.7532764162336077,0.2576410429818289,0.7255148385841658,0.2853181631832091,3554.0,0.7427481166879014,0.2866344222873324,3581.0,5478.536856174469,5760.058264970779,5478.536856174469,278.8257200717926,1.882972240447998,0.0 -23500,0.30623558,0.23183367,,,,,,,,,,,,,, -23600,0.094347455,0.23493871,,,,,,,,,,,,,, -23700,0.15310873,0.29549754,,,,,,,,,,,,,, -23752,,,0.7536858149937221,0.2576276915413992,0.7259083899743247,0.2852405726272334,3554.0,0.7431360418921041,0.286564643474239,3581.0,5558.539285182953,5844.148059844971,5558.539285182953,282.8697578907013,1.914400339126587,0.0 -23800,0.19841258,0.2384895,,,,,,,,,,,,,, -23900,0.21004851,0.22140767,,,,,,,,,,,,,, -24000,0.14942949,0.23774916,,,,,,,,,,,,,, -24097,,,0.754117625100272,0.2570815426962716,0.7259977616550014,0.2851191377398266,3554.0,0.7432274667952388,0.286486819815083,3581.0,5638.572496652603,5928.258192777634,5638.572496652603,286.9048192501068,1.944274663925171,0.0 -24100,0.15182284,0.32699722,,,,,,,,,,,,,, -24200,0.14513,0.25998747,,,,,,,,,,,,,, -24300,0.11991868,0.23740712,,,,,,,,,,,,,, -24400,0.38214573,0.18918142,,,,,,,,,,,,,, -24442,,,0.7537539345877511,0.2575973272323608,0.7259774280520188,0.2854220294254713,3554.0,0.7431903105146956,0.2867236655372975,3581.0,5718.696158885956,6012.466329336166,5718.696158885956,290.94787549972534,1.973649024963379,0.0 -24500,0.07917036,0.29471952,,,,,,,,,,,,,, -24600,0.12152273,0.28447416,,,,,,,,,,,,,, -24700,0.1864065,0.25294074,,,,,,,,,,,,,, -24786,,,0.753805433000837,0.2577033042907715,0.7263403416484947,0.285351960928707,3554.0,0.7434939693652262,0.2866072879760367,3581.0,5798.691824436188,6096.550888776779,5798.691824436188,294.99489879608154,2.0035831928253174,0.0 -24800,0.14307883,0.27247557,,,,,,,,,,,,,, -24900,0.06076226,0.26005733,,,,,,,,,,,,,, -25000,0.0903772,0.24861534,,,,,,,,,,,,,, -25100,0.13222212,0.28224015,,,,,,,,,,,,,, -25128,,,0.7541419437953404,0.2571213245391845,0.725858655080543,0.2853163084288829,3554.0,0.7430844321593131,0.2866430807233489,3581.0,5878.656435251236,6180.59944486618,5878.656435251236,299.0336084365845,2.036623239517212,0.0 -25200,0.1611751,0.22840424,,,,,,,,,,,,,, -25300,0.10066944,0.20161478,,,,,,,,,,,,,, -25400,0.16409597,0.21837175,,,,,,,,,,,,,, -25473,,,0.7536706243242536,0.257537943976266,0.7256218647782429,0.2854699439122292,3554.0,0.7428558358131457,0.2867858085642977,3581.0,5958.823487281799,6264.849833726883,5958.823487281799,303.0692329406738,2.072427272796631,0.0 -25500,0.13968694,0.25340945,,,,,,,,,,,,,, -25600,0.05877939,0.33187923,,,,,,,,,,,,,, -25700,0.09589247,0.41174385,,,,,,,,,,,,,, -25800,0.06979055,0.27733245,,,,,,,,,,,,,, -25820,,,0.7535369055611747,0.2573683943067278,0.7258339937174663,0.28513771963039,3554.0,0.7430077334150726,0.2864411755401773,3581.0,6038.98579120636,6349.096369981766,6038.98579120636,307.111353635788,2.1028239727020264,0.0 -25900,0.15770094,0.24492013,,,,,,,,,,,,,, -26000,0.08302007,0.29351348,,,,,,,,,,,,,, -26100,0.09217711,0.2735526,,,,,,,,,,,,,, -26159,,,0.7538965770176479,0.2571835517883301,0.7256854072875633,0.2853921300987883,3554.0,0.7428385871177744,0.2868199650717327,3581.0,6118.951326370239,6433.150407791138,6118.951326370239,311.15779757499695,2.13295578956604,0.0 -26200,0.115613915,0.2924553,,,,,,,,,,,,,, -26300,0.09369123,0.36164504,,,,,,,,,,,,,, -26400,0.10757222,0.24899645,,,,,,,,,,,,,, -26500,0.108272016,0.25109416,,,,,,,,,,,,,, -26505,,,0.7541265487670898,0.2570844377790178,0.7260599989668332,0.2851045401363341,3554.0,0.7432574645263195,0.286448743149609,3581.0,6199.053456783295,6517.338342905045,6199.053456783295,315.20138478279114,2.1633517742156982,0.0 -26600,0.11118012,0.26406845,,,,,,,,,,,,,, -26700,0.16922528,0.34152937,,,,,,,,,,,,,, -26800,0.14586028,0.17785773,,,,,,,,,,,,,, -26851,,,0.7537615639822823,0.257405264036996,0.7258054167619232,0.2853107098186023,3554.0,0.7429794401005306,0.2866079356543214,3581.0,6279.126361370087,6601.50510263443,6279.126361370087,319.24842977523804,2.198110342025757,0.0 -26900,0.15678935,0.21263221,,,,,,,,,,,,,, -27000,0.17235585,0.21212044,,,,,,,,,,,,,, -27100,0.14360353,0.24615118,,,,,,,,,,,,,, -27195,,,0.7544635363987514,0.2570823771612985,0.7263280453142585,0.2852458105908395,3554.0,0.7434335648430955,0.2866400127735793,3581.0,6359.18699669838,6685.648197650909,6359.18699669838,323.28483629226685,2.2321996688842773,0.0 -27200,0.09515771,0.31031144,,,,,,,,,,,,,, -27300,0.16337053,0.2407366,,,,,,,,,,,,,, -27400,0.21345992,0.3047662,,,,,,,,,,,,,, -27500,0.081859656,0.235601,,,,,,,,,,,,,, -27543,,,0.7543215751647949,0.2569711378642491,0.7262150426895752,0.2850779381506753,3554.0,0.7433979084491064,0.2864332329591071,3581.0,6439.320326805115,6769.8714554309845,6439.320326805115,327.3263280391693,2.268504858016968,0.0 -27600,0.10843402,0.2444527,,,,,,,,,,,,,, -27700,0.10543946,0.2905593,,,,,,,,,,,,,, -27800,0.07286729,0.31186104,,,,,,,,,,,,,, -27886,,,0.754523481641497,0.2571393081120082,0.726480478642023,0.2852333596937429,3554.0,0.743686363904112,0.2864977621692613,3581.0,6519.368485689163,6854.003954172134,6519.368485689163,331.366498708725,2.300798416137696,0.0 -27900,0.1580482,0.2689758,,,,,,,,,,,,,, -28000,0.09770686,0.23415923,,,,,,,,,,,,,, -28100,0.13382371,0.27315414,,,,,,,,,,,,,, -28200,0.20735565,0.25606775,,,,,,,,,,,,,, -28229,,,0.7529176303318569,0.257628389767238,0.7251205315533554,0.2855669750413266,3554.0,0.7423255577396328,0.2869981788650168,3581.0,6599.540381908417,6938.260453462601,6599.540381908417,335.408899307251,2.330955982208252,0.0 -28300,0.11494684,0.23906752,,,,,,,,,,,,,, -28400,0.12312525,0.23623702,,,,,,,,,,,,,, -28500,0.10937046,0.29437116,,,,,,,,,,,,,, -28572,,,0.7542286600385394,0.2567955596106393,0.7258950632210186,0.2850730780073948,3554.0,0.7430802733829587,0.2864325852808224,3581.0,6679.681492090225,7022.488809347153,6679.681492090225,339.4527425765991,2.362444400787353,0.0 -28600,0.084739745,0.36050114,,,,,,,,,,,,,, -28700,0.10564068,0.26583758,,,,,,,,,,,,,, -28800,0.17897119,0.20598152,,,,,,,,,,,,,, -28900,0.08121173,0.33963323,,,,,,,,,,,,,, -28918,,,0.7547598566327777,0.2569642066955566,0.7266376518975098,0.2851225209491066,3554.0,0.7438062866517733,0.2864167682953434,3581.0,6759.659672021866,7106.558267116547,6759.659672021866,343.49573826789856,2.398847341537476,0.0 -29000,0.09406357,0.260836,,,,,,,,,,,,,, -29100,0.10185066,0.2529637,,,,,,,,,,,,,, -29200,0.08189361,0.33090353,,,,,,,,,,,,,, -29264,,,0.7535268919808524,0.2571125200816563,0.7255190289550506,0.2852112572046901,3554.0,0.7426680091105836,0.2865796082514486,3581.0,6839.815808057785,7190.804083108902,6839.815808057785,347.5355215072632,2.4366612434387207,0.0 -29300,0.4411189,0.34249547,,,,,,,,,,,,,, -29400,0.19783561,0.18784252,,,,,,,,,,,,,, -29500,0.12305143,0.2237291,,,,,,,,,,,,,, -29600,0.10958361,0.20454115,,,,,,,,,,,,,, -29609,,,0.7549184390476772,0.2566295181001936,0.7264830890370005,0.2850916598979582,3554.0,0.74362766379852,0.2864999097341001,3581.0,6920.005371332169,7275.075638055801,6920.005371332169,351.5742256641388,2.4678423404693604,0.0 -29700,0.0611734,0.30877668,,,,,,,,,,,,,, -29800,0.09537228,0.35654464,,,,,,,,,,,,,, -29900,0.08609422,0.26290134,,,,,,,,,,,,,, -29956,,,0.7546397617885044,0.2568075145993914,0.7264039528524198,0.2850718758518131,3554.0,0.743571213522759,0.2863981219784103,3581.0,7000.064207315445,7359.221385478973,7000.064207315445,355.6179938316345,2.49917984008789,0.0 -30000,0.104081824,0.32150313,,,,,,,,,,,,,, -30100,0.13392067,0.22496986,,,,,,,,,,,,,, -30200,0.09865943,0.27998197,,,,,,,,,,,,,, -30300,,,0.7547033173697335,0.256915875843593,0.7264863863780248,0.285162192083304,3554.0,0.7436359131745671,0.2865445995357442,3581.0,7080.181899785996,7443.427941322327,7080.181899785996,359.66366052627563,2.530510187149048,0.0 -30300,0.10068883,0.29683793,,,,,,,,,,,,,, -30400,0.13166955,0.28390822,,,,,,,,,,,,,, -30500,0.1273757,0.22865266,,,,,,,,,,,,,, -30600,0.08856771,0.28080207,,,,,,,,,,,,,, -30646,,,0.755000250680106,0.2563994271414621,0.7263446694085889,0.2850013264928162,3554.0,0.7435238307429838,0.2863884749808014,3581.0,7160.268867731094,7527.606162309647,7160.268867731094,363.7105996608734,2.5629425048828125,0.0 -30700,0.079091795,0.29063815,,,,,,,,,,,,,, -30800,0.0744018,0.28267506,,,,,,,,,,,,,, -30900,0.13258861,0.21609576,,,,,,,,,,,,,, -30990,,,0.7543023654392788,0.2566179037094116,0.7258775460968275,0.284997616984164,3554.0,0.7430925451820372,0.2863450123590652,3581.0,7240.273091554642,7611.697166442871,7240.273091554642,367.7547652721405,2.593646049499512,0.0 -31000,0.07959849,0.25760204,,,,,,,,,,,,,, -31100,0.103565045,0.2121905,,,,,,,,,,,,,, -31200,0.07839266,0.2386275,,,,,,,,,,,,,, -31300,0.112920396,0.25731015,,,,,,,,,,,,,, -31334,,,0.7549773624965123,0.2567619255610874,0.7266886919887803,0.2850739882109067,3554.0,0.7438424202823932,0.2864117913990505,3581.0,7320.299663305283,7695.818750143051,7320.299663305283,371.80034589767456,2.63110613822937,0.0 -31400,0.098589785,0.29853335,,,,,,,,,,,,,, -31500,0.058052618,0.31009653,,,,,,,,,,,,,, -31600,0.10898857,0.22654098,,,,,,,,,,,,,, -31681,,,0.7553375107901437,0.256252782685416,0.7267150020223692,0.2849354655405881,3554.0,0.7438221718139137,0.2862987204058747,3581.0,7400.453884601593,7780.056601762772,7400.453884601593,375.8334729671478,2.669584989547729,0.0 -31700,0.06753419,0.25895774,,,,,,,,,,,,,, -31800,0.18913482,0.27683675,,,,,,,,,,,,,, -31900,0.08526476,0.34034133,,,,,,,,,,,,,, -32000,0.16004346,0.17074454,,,,,,,,,,,,,, -32026,,,0.7550763402666364,0.256514344896589,0.726590046536473,0.2849802887701357,3554.0,0.7437481319594736,0.2863594317229824,3581.0,7480.44393825531,7864.1325850486755,7480.44393825531,379.8758656978607,2.7011075019836426,0.0 -32100,0.052495025,0.2693752,,,,,,,,,,,,,, -32200,0.06805172,0.2660147,,,,,,,,,,,,,, -32300,0.10734348,0.2720044,,,,,,,,,,,,,, -32370,,,0.7548025676182338,0.256595424243382,0.7264196839168894,0.2850220894370779,3554.0,0.7435609870235269,0.2863833958195162,3581.0,7560.630806207657,7948.39928150177,7560.630806207657,383.9117209911346,2.7331085205078125,0.0 -32400,0.13428275,0.24553046,,,,,,,,,,,,,, -32500,0.09860569,0.22892563,,,,,,,,,,,,,, -32600,0.067692444,0.33454123,,,,,,,,,,,,,, -32700,0.08711591,0.27523378,,,,,,,,,,,,,, -32712,,,0.7553097861153739,0.2561919007982526,0.7266341484726716,0.2848949013765299,3554.0,0.7437498363760123,0.2862728473628176,3581.0,7640.62170624733,8032.480271816254,7640.62170624733,387.9579894542694,2.765145778656006,0.0 -32800,0.060814273,0.25493175,,,,,,,,,,,,,, -32900,0.08439705,0.21009435,,,,,,,,,,,,,, -33000,0.06582089,0.33002582,,,,,,,,,,,,,, -33057,,,0.7551392146519252,0.2563128471374511,0.7265522645039041,0.2848909170894591,3554.0,0.743716770695162,0.2862628253935702,3581.0,7720.608385562897,8116.549202203751,7720.608385562897,391.9960012435913,2.797520399093628,0.0 -33100,0.063758396,0.2235982,,,,,,,,,,,,,, -33200,0.0971974,0.35223725,,,,,,,,,,,,,, -33300,0.085408784,0.3147399,,,,,,,,,,,,,, -33400,0.07157751,0.2534985,,,,,,,,,,,,,, -33401,,,0.7549631936209542,0.2564116716384887,0.7264295072453574,0.2849628746878517,3554.0,0.7435593507836498,0.2863089469051068,3581.0,7800.907135248184,8200.876027584076,7800.907135248184,395.9805171489716,2.829303979873657,0.0 -33500,0.06527274,0.264128,,,,,,,,,,,,,, -33600,0.0673796,0.29456863,,,,,,,,,,,,,, -33700,0.08998507,0.2555712,,,,,,,,,,,,,, -33746,,,0.7553563117980957,0.2561759267534528,0.7267946190691826,0.2848300021487672,3554.0,0.7438874168790143,0.2862026254014241,3581.0,7881.002731561661,8285.059644460678,7881.002731561661,400.01979994773865,2.866029500961304,0.0 -33800,0.05184966,0.24019472,,,,,,,,,,,,,, -33900,0.05967299,0.2837141,,,,,,,,,,,,,, -34000,0.05667276,0.27776182,,,,,,,,,,,,,, -34090,,,0.7552986826215472,0.256163546017238,0.7266865624560355,0.2847965478762925,3554.0,0.7438356707929,0.2861697301622277,3581.0,7961.019174337387,8369.170066595078,7961.019174337387,404.0609667301178,2.9068877696990967,0.0 -34100,0.06867166,0.2571262,,,,,,,,,,,,,, -34200,0.057291843,0.26612327,,,,,,,,,,,,,, -34300,0.07250889,0.22440466,,,,,,,,,,,,,, -34400,0.06780924,0.33200207,,,,,,,,,,,,,, -34437,,,0.7552531787327358,0.2562346458435058,0.7266829216419879,0.2848477081831211,3554.0,0.7438358753228846,0.2862102270991867,3581.0,8041.14001584053,8453.379067897797,8041.14001584053,408.103973865509,2.939914464950561,0.0 -34500,0.06440422,0.2202433,,,,,,,,,,,,,, -34600,0.08208425,0.25700715,,,,,,,,,,,,,, -34700,0.08039274,0.2437946,,,,,,,,,,,,,, -34779,,,0.7552689143589565,0.2561770507267543,0.7266373771190912,0.2848144599944605,3554.0,0.7437871290098785,0.2861768205350286,3581.0,8121.161528110504,8537.487291812897,8121.161528110504,412.1444194316864,2.974390745162964,0.0 -34800,0.0794943,0.22730868,,,,,,,,,,,,,, -34900,0.06547886,0.29758298,,,,,,,,,,,,,, -35000,0.078990616,0.24848786,,,,,,,,,,,,,, -35100,0.070874445,0.30057853,,,,,,,,,,,,,, -35125,,,0.7554235458374023,0.2560735600335257,0.7267863757166221,0.2847423134859049,3554.0,0.7439305727057736,0.2861144729780438,3581.0,8201.258660793304,8621.679370641708,8201.258660793304,416.184344291687,3.0172266960144043,0.0 -35200,0.046961226,0.23105478,,,,,,,,,,,,,, -35300,0.0570101,0.278343,,,,,,,,,,,,,, -35400,0.040614333,0.32121426,,,,,,,,,,,,,, -35471,,,0.7555623735700335,0.2561214140483311,0.7269094764481921,0.2847858830389174,3554.0,0.744045791263788,0.2861540836184026,3581.0,8281.31056690216,8705.81771326065,8281.31056690216,420.22679924964905,3.0492374897003174,0.0 -35500,0.05533165,0.31518233,,,,,,,,,,,,,, -35600,0.08132794,0.21083777,,,,,,,,,,,,,, -35700,0.062469363,0.2104128,,,,,,,,,,,,,, -35800,0.07286748,0.24353677,,,,,,,,,,,,,, -35814,,,0.7554898943219867,0.2561060871396746,0.7268266307549592,0.2847845778414287,3554.0,0.7439735240025481,0.2861507770503176,3581.0,8361.536612510681,8790.133450984955,8361.536612510681,424.2720458507538,3.081813335418701,0.0 -35900,0.076319024,0.21980296,,,,,,,,,,,,,, -36000,0.04989929,0.23552766,,,,,,,,,,,,,, -36100,0.050236102,0.37189728,,,,,,,,,,,,,, -36159,,,0.7554442541939872,0.256074139050075,0.7268051980383019,0.2847391707077413,3554.0,0.7439485713444219,0.2861035647121963,3581.0,8441.574692964554,8874.257183551788,8441.574692964554,428.31070947647095,3.1168370246887207,0.0 -36189,,,0.7554440498352051,0.2560740368706839,0.7268047858706739,0.2847391020131366,3554.0,0.7439481622844527,0.2861034965355347,3581.0,8446.291797876358,8883.051620006561,8446.291797876358,432.35393595695496,3.1495649814605717,0.0 -36189,,,,,,,,,,,8446.291797876358,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/eval_measurements.csv deleted file mode 100644 index d37310f3d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.9805192947387695,0.0,30.82212352752685,1,0,30.82212352752685,0.912513662602974,3581,0.2829431162570249,34.802770137786865,0.9032635007585798,0.2670762368610927,0.915261701713738,3554,0.2587972028106535 -8.01711654663086,0.0200138092041015,110.83161187171936,344,0,110.83161187171936,0.3137289525010472,3581,0.7153852090460067,118.88110089302064,0.285450986453465,0.7226132665361676,0.3115307191280951,3554,0.6983893313432048 -12.06654405593872,0.0441045761108398,191.00679397583008,684,0,191.00679397583008,0.3042089339242704,3581,0.7213692791599763,203.1422142982483,0.2770511763436453,0.7282571792602539,0.3023097736485298,3554,0.7044029944252954 -16.10910677909851,0.0733783245086669,271.0234563350677,1020,0,271.0234563350677,0.2980568765162489,3581,0.728832305592886,287.2428436279297,0.2704525334494455,0.7370714460100446,0.296143402460256,3554,0.711497979279509 -20.15527033805847,0.0985920429229736,351.08174562454224,1366,0,351.08174562454224,0.2986054941117181,3581,0.7230863083810388,371.3848412036896,0.2714135476521083,0.7303394590105329,0.2970385618537212,3554,0.7056676620972847 -24.19566798210144,0.1235325336456298,431.0522320270538,1711,0,431.0522320270538,0.2959236969531031,3581,0.7334914986430118,455.4326248168945,0.2683972801480974,0.7415225846426827,0.2941325739923325,3554,0.7163652667988534 -28.23944354057312,0.1476261615753173,511.1865234375,2055,0,511.1865234375,0.2959743181243018,3581,0.7336164664636274,539.6468026638031,0.2685650076184954,0.741027695792062,0.2943257432206668,3554,0.71662994711065 -32.28887057304382,0.1761145591735839,591.3152222633362,2400,0,591.3152222633362,0.2951105879991622,3581,0.735299339177255,623.8655636310577,0.2670582021985735,0.7440621512276786,0.2936223104688379,3554,0.7178861653462648 -36.32929086685181,0.205007791519165,671.4278931617737,2747,0,671.4278931617737,0.2937197500152715,3581,0.7375439875820302,708.0597774982452,0.2664312124252319,0.7456315585545131,0.2921799985491699,3554,0.7203574537492966 -40.37436270713806,0.2290525436401367,751.4628582000732,3093,0,751.4628582000732,0.292780752855784,3581,0.735590589875384,792.1761124134064,0.2652908563613891,0.7441027505057198,0.2911974939658659,3554,0.7181291381629854 -44.415067195892334,0.2549364566802978,831.4726746082306,3438,0,831.4726746082306,0.2967027175762706,3581,0.7294369644041468,876.2643911838531,0.268498352595738,0.7385843821934291,0.295027252523565,3554,0.7121779871711452 -48.46218538284302,0.2796492576599121,911.6228678226472,3783,0,911.6228678226472,0.2926715338439856,3581,0.7389904236334125,960.4981019496918,0.2647112948553903,0.7476419040134975,0.2910621999419668,3554,0.7219686170072454 -52.50946092605591,0.3047735691070556,991.5995440483092,4130,0,991.5995440483092,0.2926502286372521,3581,0.7354596225085521,1044.5592665672302,0.2648463419505528,0.744476454598563,0.29101665541907,3554,0.7180813954127392 -56.5556275844574,0.3315975666046142,1071.6571695804596,4472,0,1071.6571695804596,0.2937538042577143,3581,0.7306974145228288,1128.7016820907593,0.2657135725021362,0.7401151657104492,0.2923074957354389,3554,0.7127527549284257 -60.60241341590881,0.3582265377044678,1151.627099752426,4819,0,1151.627099752426,0.2922750183804279,3581,0.7363732579499441,1212.7570588588717,0.2647957120622907,0.7446002279009137,0.290804663869056,3554,0.7191272020742473 -64.64353156089783,0.3878440856933594,1231.6058168411255,5162,0,1231.6058168411255,0.2931434867835451,3581,0.7402641682010961,1296.8183772563934,0.2648267064775739,0.7490038871765137,0.2915698187231816,3554,0.7233019105893008 -68.6874794960022,0.4124460220336914,1311.6497621536255,5505,0,1311.6497621536255,0.2934891765459194,3581,0.7366458964194709,1380.9428961277008,0.2651594366346086,0.7461484500340053,0.2920032816786543,3554,0.7193948362540448 -72.73325824737549,0.4397716522216797,1391.8070256710052,5851,0,1391.8070256710052,0.2933677880000349,3581,0.7399766672193522,1465.185186624527,0.264972414289202,0.7486353601728167,0.2916949459455895,3554,0.72320086082583 -76.77447819709778,0.467041015625,1471.9566292762756,6197,0,1471.9566292762756,0.2929133905508237,3581,0.7316704318364283,1549.4155497550964,0.2653241668428693,0.7402887344360352,0.2915773751296954,3554,0.7139138311365715 -80.81879162788391,0.4929647445678711,1551.9505987167358,6540,0,1551.9505987167358,0.2928044101573408,3581,0.7359811057927255,1633.491622209549,0.2647097281047276,0.7448414393833706,0.2913872284639666,3554,0.7187118744944077 -84.86484432220459,0.5194802284240723,1632.0223336219788,6885,0,1632.0223336219788,0.2944325029234153,3581,0.7342648946916015,1717.647849559784,0.2659179653440203,0.7436417170933315,0.2928763901040201,3554,0.716956727345069 -88.90904879570007,0.5449492931365967,1712.0917258262634,7230,0,1712.0917258262634,0.2919725526214744,3581,0.7341480398937098,1801.7990944385529,0.2644589628492083,0.7429020064217704,0.2905732660932224,3554,0.7165945693892445 -92.9541277885437,0.572021484375,1792.1360602378843,7573,0,1792.1360602378843,0.2919843130955913,3581,0.7334203222083566,1885.9275164604187,0.2645441464015415,0.7418189729963031,0.2905525890172165,3554,0.7159344142383582 -96.99868774414062,0.5981225967407227,1872.2391254901888,7918,0,1872.2391254901888,0.2923499104431374,3581,0.7403620698870776,1970.113526582718,0.2643026624407087,0.7490203039986747,0.290826577447946,3554,0.723283088267621 -101.04316473007202,0.6242008209228516,1952.3275845050807,8265,0,1952.3275845050807,0.2923485809982372,3581,0.7392333370785046,2054.2847859859467,0.2642133235931396,0.7482437406267438,0.2909175291045301,3554,0.7220145736977701 -105.08436179161072,0.6500787734985352,2032.51903796196,8607,0,2032.51903796196,0.2919754501295902,3581,0.7369421921905543,2138.555235147476,0.2638657093048095,0.7464250155857631,0.2906534326968732,3554,0.7195118918604038 -109.13049507141112,0.6793222427368164,2112.5709660053253,8951,0,2112.5709660053253,0.2910422138980033,3581,0.736424458622766,2222.6944692134857,0.2632397242954799,0.7454898016793388,0.2896331804283026,3554,0.7187876446433596 -113.17689657211304,0.7051200866699219,2192.5693085193634,9293,0,2192.5693085193634,0.2910244879660011,3581,0.7345339197980661,2306.7771003246307,0.2630248410361154,0.7439816338675362,0.289561634997538,3554,0.7169932728747538 -117.2190728187561,0.7311830520629883,2272.6771895885468,9636,0,2272.6771895885468,0.2914541713754014,3581,0.7388857042812762,2390.965298175812,0.2636504684175764,0.7478782790047782,0.2900727915508934,3554,0.7215138587243247 -121.26503705978394,0.7569248676300049,2352.732777118683,9980,0,2352.732777118683,0.2924467894791958,3581,0.7312535997277296,2475.1045954227448,0.2646668297903878,0.7397242954799107,0.2910097516113006,3554,0.7141337225661227 -125.30674004554749,0.7888193130493164,2432.74764752388,10325,0,2432.74764752388,0.2908956340756772,3581,0.7399639181836428,2559.2052206993103,0.2629481383732387,0.749014036996024,0.2893752321877638,3554,0.7227236394071821 -129.35560011863708,0.816091775894165,2512.824273824692,10669,0,2512.824273824692,0.292158947614144,3581,0.7358866811164828,2643.370032787323,0.2640965155192784,0.7450334003993443,0.290682490514649,3554,0.7185316885463562 -133.3985664844513,0.8439867496490479,2592.9127190113068,11014,0,2592.9127190113068,0.2901744272615016,3581,0.7411439880183608,2727.5416226387024,0.26199814251491,0.7504299027579171,0.2886156072493142,3554,0.7241798963315982 -137.44315481185913,0.8700685501098633,2672.970718383789,11358,0,2672.970718383789,0.2919174999672752,3581,0.7365518126265359,2811.682421684265,0.2641546896525791,0.7455850328717913,0.2902342925664743,3554,0.7196553261949564 -141.48319029808044,0.8981711864471436,2752.944223165512,11702,0,2752.944223165512,0.2920892028893814,3581,0.7364029147977171,2895.736180782318,0.2641615867614746,0.7454074450901577,0.2905541689931239,3554,0.7192417846748382 -145.52828335762024,0.9264392852783204,2832.9485371112823,12044,0,2832.9485371112823,0.29123552882182,3581,0.7357157622259843,2979.825961828232,0.2633335930960519,0.7435768672398159,0.2896535483785875,3554,0.7189717461838773 -149.57113814353943,0.9530577659606934,2913.143745660782,12389,0,2913.143745660782,0.2899312751980941,3581,0.7393814167873848,3064.1029093265533,0.2619869198117937,0.7485134260995048,0.2885388066812922,3554,0.7220078416265123 -153.61311268806458,0.9799468517303468,2993.192461490631,12736,0,2993.192461490631,0.2895676208854021,3581,0.7409452530499512,3148.232587814331,0.2616373130253383,0.7500293595450265,0.2881183270061023,3554,0.723628003877673 -157.65242910385132,1.007526397705078,3073.1728444099426,13078,0,3073.1728444099426,0.2893386836559271,3581,0.742036215988027,3232.291928291321,0.2612794807979038,0.7514548982892718,0.2879144757667417,3554,0.724869796246307 -161.69960498809814,1.0354080200195312,3153.252564430237,13424,0,3153.252564430237,0.2907379073691881,3581,0.7365478583801661,3316.4589569568634,0.2630315848759242,0.7457209995814732,0.2893385149215672,3554,0.7189907058947664 -165.74770259857178,1.0637977123260498,3233.4567432403564,13770,0,3233.4567432403564,0.2902145492268221,3581,0.7370009604728078,3400.75181722641,0.2622240781784057,0.7465126173836845,0.2889956601496553,3554,0.7193183791590462 -169.79286742210388,1.093024969100952,3313.4611835479736,14111,0,3313.4611835479736,0.2901376118642663,3581,0.7393679178083985,3484.84254193306,0.2619667734418596,0.748753275190081,0.28858569074898,3554,0.7221636409899057 -173.83345460891724,1.1216280460357666,3393.474872589112,14457,0,3393.474872589112,0.2902302639473087,3581,0.7417963704927045,3568.93760895729,0.2622332062040056,0.7506982939583915,0.2887858324796883,3554,0.7246852825381612 -177.88046073913574,1.1491503715515137,3473.484473705292,14804,0,3473.484473705292,0.2902830667716769,3581,0.739383939323862,3653.033741474152,0.262541617665972,0.748584338596889,0.2888958125417663,3554,0.7220001478307893 -181.9265389442444,1.1808226108551023,3553.4609277248383,15148,0,3553.4609277248383,0.2894467436644792,3581,0.7373965896397654,3737.0999524593353,0.2616633006504604,0.746781485421317,0.2881038839654702,3554,0.7198642264877603 -185.9671552181244,4.293005704879761,3630.503707885742,15479,0,3630.503707885742,0.2894908880528309,3581,0.7400738189620567,3821.306978702545,0.2613825287137712,0.7495412826538086,0.2880535479938977,3554,0.7229363865978475 -190.0122487545013,4.322009086608887,3710.492092132568,15827,0,3710.492092132568,0.2894689692561435,3581,0.7377974002330006,3905.3816516399374,0.2616352353777204,0.7468673161097935,0.2881785721743986,3554,0.7202307808982836 -194.06110763549805,4.35016131401062,3790.4588992595673,16170,0,3790.4588992595673,0.2889514061300091,3581,0.7401714479413921,3989.43754029274,0.261055520602635,0.7493359701974052,0.287572565545644,3554,0.7227794194261747 -198.1052176952362,4.382714986801148,3870.481683015824,16513,0,3870.481683015824,0.2893771352930396,3581,0.7395680844867006,4073.548889398575,0.2611632687704904,0.7490717342921666,0.2879586979184985,3554,0.7222987632772931 -202.15042400360107,4.412107944488525,3950.5336406230927,16857,0,3950.5336406230927,0.2901986981530124,3581,0.7370565244519687,4157.687446594238,0.2624172823769705,0.7464184079851423,0.2888134820580684,3554,0.719400606600837 -206.19731330871585,4.441436052322388,4030.594305992128,17202,0,4030.594305992128,0.2892383616984606,3581,0.7391657740069115,4241.836377620697,0.2613915715898786,0.7484077726091657,0.2878314926842993,3554,0.7219015023784819 -210.2408187389373,4.469689130783081,4110.767489433289,17547,0,4110.767489433289,0.2893575345028448,3581,0.7394163914147585,4326.093504428864,0.2607495614460536,0.7491188049316406,0.287938776483144,3554,0.7221773799108399 -214.283281326294,4.499216556549072,4190.812663078308,17893,0,4190.812663078308,0.2890635567382539,3581,0.7397306176478288,4410.222549676895,0.2610698427472795,0.7488366535731724,0.2876933821816087,3554,0.7225361031364308 -218.32737565040588,4.531201362609863,4270.923578500748,18241,0,4270.923578500748,0.2890376155185353,3581,0.7400820683381039,4494.421721696854,0.2612367698124477,0.7490247998918805,0.2877226804305008,3554,0.7226962302599184 -222.3760340213776,4.564069509506226,4350.883985042572,18584,0,4350.883985042572,0.2889587010327946,3581,0.7397645014486177,4578.475754737854,0.2604962587356567,0.7490253448486328,0.2874368078331633,3554,0.7227430799803038 -226.41681051254272,4.5984275341033936,4430.957304239273,18930,0,4430.957304239273,0.289137153444394,3581,0.7375610317474169,4662.636361837387,0.2611930540629795,0.746837956564767,0.2878367134742544,3554,0.7200128129176632 -230.45980858802795,4.627838373184204,4511.09930229187,19276,0,4511.09930229187,0.2894971603056932,3581,0.742224383573897,4746.862930059433,0.2613452843257359,0.7515696798052106,0.2881080056417505,3554,0.7252221995682682 -234.5062789916992,4.661612987518311,4591.108497858048,19620,0,4591.108497858048,0.28975834509608,3581,0.7407267468496929,4830.964126110077,0.2607353925704956,0.7509069442749023,0.2882456868031619,3554,0.7235287401739238 -238.5473177433014,4.692481756210327,4671.207195281982,19969,0,4671.207195281982,0.2884193213749651,3581,0.7388614333897654,4915.146928071976,0.2605396509170532,0.747917720249721,0.2870852460201094,3554,0.7213956353096863 -242.59148383140564,4.722586393356323,4751.304104089737,20315,0,4751.304104089737,0.2887872708173345,3581,0.7419790157689891,4999.330250740051,0.2608891555241176,0.7510645730154855,0.2874345580848603,3554,0.7249340257016742 -246.63685989379883,4.753290891647339,4831.424298048019,20657,0,4831.424298048019,0.2885486525019198,3581,0.7399350112791468,5083.538475513458,0.2601903336388724,0.750018664768764,0.287154438660664,3554,0.722571961720069 -250.6783995628357,4.78401255607605,4911.5534324646,21003,0,4911.5534324646,0.2883246921687378,3581,0.7402386701296775,5167.75226020813,0.2597711597170148,0.750136239188058,0.2868698712608152,3554,0.7229999978017726 -254.72396564483645,4.8128838539123535,4991.542078495026,21347,0,4991.542078495026,0.2881120491613725,3581,0.7412270953687866,5251.82776093483,0.2598606007439749,0.7508693422589984,0.2868814291280511,3554,0.7238796322145822 -258.7666749954224,4.8427369594573975,5071.724465370178,21690,0,5071.724465370178,0.2877457700472109,3581,0.7412582521031137,5336.094830274582,0.2596160514014108,0.7507828984941755,0.2863647598326709,3554,0.7240258830279263 -262.81214714050293,4.872639179229736,5151.788769721985,22037,0,5151.788769721985,0.2878893500964291,3581,0.7419090665142418,5420.246787071228,0.2595355340412685,0.7516030584062848,0.2865019257845473,3554,0.7247510232748312 -266.8560085296631,4.907236814498901,5232.000771999359,22382,0,5232.000771999359,0.2878808280137357,3581,0.742977122094038,5504.5493080616,0.2594784498214721,0.7525712421962193,0.2865110793406197,3554,0.7257786945607062 -270.8928482532501,4.937005043029785,5312.106199026108,22727,0,5312.106199026108,0.2878331725273143,3581,0.7421773416774294,5588.733791828156,0.2596061059406825,0.7518329620361328,0.2864908487795441,3554,0.7250375484709131 -274.9365813732147,4.966019153594971,5392.222705602646,23072,0,5392.222705602646,0.2875140375746125,3581,0.741036337069778,5672.93510222435,0.2589318581989833,0.7509150505065918,0.2861938991772035,3554,0.7236573364738674 -278.97792768478394,4.996206760406494,5472.2902681827545,23418,0,5472.2902681827545,0.2878223324381283,3581,0.7411057409112329,5757.086367607117,0.259568520954677,0.7504358291625977,0.2864663935002813,3554,0.723816433178285 -283.0185122489929,5.02751088142395,5552.431335449219,23760,0,5552.431335449219,0.288040020518448,3581,0.7410613579045657,5841.311386823654,0.259830219405038,0.7507174355643136,0.2867341478956369,3554,0.7237039801104389 -287.06051874160767,5.057190895080566,5632.593497753143,24105,0,5632.593497753143,0.2871316346799951,3581,0.7427400036651773,5925.557373762131,0.2583275181906564,0.7529386111668178,0.2857426271454699,3554,0.7255741907226013 -291.10566115379333,5.088764429092407,5712.731587409973,24450,0,5712.731587409973,0.2877290667651319,3581,0.7411732358061645,6009.784536600113,0.2591849395206996,0.7510510172162738,0.286399175829611,3554,0.7238414380143852 -295.145968914032,5.1201136112213135,5792.829498052597,24793,0,5792.829498052597,0.2878063109226647,3581,0.7438580327378874,6093.966063499451,0.2592285360608782,0.7537099293300084,0.2864781231040289,3554,0.7267622639103827 -299.1887602806092,5.155481338500977,5872.8541264534,25138,0,5872.8541264534,0.2870566403522933,3581,0.7422174977310807,6178.080864906311,0.2580028772354126,0.7527525765555245,0.2856826739292434,3554,0.7249790893623382 -303.2299098968506,5.186309576034546,5952.869534730911,25484,0,5952.869534730911,0.2874117044056304,3581,0.7413706754180047,6262.180478811264,0.2586135864257812,0.751709120614188,0.286100251257386,3554,0.7240188761782499 -307.27237248420715,5.216569900512695,6032.928485393524,25830,0,6032.928485393524,0.2870886492948897,3581,0.7417736676644093,6346.324097394943,0.2585456371307373,0.7517774445669991,0.2857211257342079,3554,0.7246146644845597 -311.31592655181885,5.249204397201538,6113.055310726166,26173,0,6113.055310726166,0.2873534133600077,3581,0.7418254137505236,6430.539331912994,0.2580876691000802,0.7526977402823312,0.2860090763433367,3554,0.7245964604143219 -315.35939478874207,5.28049635887146,6193.021834373474,26515,0,6193.021834373474,0.2869383197561784,3581,0.7428827655944569,6514.592348814011,0.2581151042665754,0.753077507019043,0.2855437390912967,3554,0.7256595094216024 -319.4024076461792,5.313000917434692,6273.217540979385,26861,0,6273.217540979385,0.2866473417646956,3581,0.743087909169052,6598.875570058823,0.2580545970371791,0.7530824797494071,0.2853332416489343,3554,0.7258945136641812 -323.43775701522827,5.34488582611084,6353.390656471252,27205,0,6353.390656471252,0.286867041056531,3581,0.7432430110740715,6683.127971410751,0.2577869551522391,0.7533891541617257,0.2854679345950425,3554,0.7262223243176702 -327.4859962463379,5.375718593597412,6433.46320271492,27551,0,6433.46320271492,0.286973601178529,3581,0.7436233686688425,6767.291547298431,0.2581208944320678,0.7539660590035575,0.285553390683253,3554,0.726424355150007 -331.52803587913513,5.40687370300293,6513.643239974976,27897,0,6513.643239974976,0.2867495385803546,3581,0.7438512832483943,6851.556659221649,0.2578811134610857,0.7541084289550781,0.2854553978296901,3554,0.7266418422683948 -335.5698335170746,5.439523696899414,6593.765914440155,28240,0,6593.765914440155,0.2873368123429209,3581,0.7422999233148911,6935.765740871429,0.2581344842910766,0.7530191285269601,0.286052113513163,3554,0.7249559392805641 -339.61222982406616,5.471314668655396,6673.728529691696,28587,0,6673.728529691696,0.2864703892396502,3581,0.7427766827090896,7019.814661979675,0.2573441437312534,0.753343037196568,0.2851177810213843,3554,0.725530020091798 -343.6575689315796,5.5031914710998535,6753.82816696167,28932,0,6753.82816696167,0.2864604013587336,3581,0.7432078319167132,7104.003720521927,0.2575488771711077,0.7535497801644462,0.2851559752215813,3554,0.7259476832881964 -347.6977567672729,5.534913063049316,6833.869222640991,29273,0,6833.869222640991,0.2868101817208007,3581,0.7424719330319743,7188.128715515137,0.2578637940543039,0.7527131353105817,0.285449541614642,3554,0.7251399034318725 -351.7384581565857,5.566555500030518,6913.852163553238,29619,0,6913.852163553238,0.2864549813141406,3581,0.742045556190659,7272.19589304924,0.2572024720055716,0.7528404508318219,0.2851512181202079,3554,0.7247044483328644 -355.77624320983887,5.598665475845337,6993.946165800095,29964,0,6993.946165800095,0.2864696052080424,3581,0.7436717740985409,7356.371992826462,0.2572521822793143,0.7542448725019183,0.2851106367824986,3554,0.7264857681265827 -359.8158543109894,5.6307532787323,7074.071465730667,30308,0,7074.071465730667,0.2862640184851473,3581,0.7434445412856046,7440.581089496613,0.2571421350751604,0.7539666720799038,0.2849134832670934,3554,0.726212775767621 -363.8551321029663,5.662577629089356,7154.043691635132,30653,0,7154.043691635132,0.2862637798668319,3581,0.7432008778972354,7524.636759996414,0.2566960198538644,0.7541763441903251,0.2849078674831616,3554,0.7259604604846651 -367.8951904773712,5.695420980453491,7234.144478082657,30998,0,7234.144478082657,0.2861960122652541,3581,0.7430420262758308,7608.822690963745,0.2569471086774553,0.7537750516619001,0.2848584417151009,3554,0.7257893422244303 -371.93916368484497,5.727295398712158,7314.175801992416,31341,0,7314.175801992416,0.2861467205389556,3581,0.7435245806862608,7692.941826343536,0.2569677489144461,0.7542284556797573,0.2848178260300893,3554,0.7263002926939716 -375.9801406860352,5.7609148025512695,7394.33158826828,31686,0,7394.33158826828,0.2861874220058992,3581,0.743722633888055,7777.184247970581,0.2564511299133301,0.754953111921038,0.2848102696235755,3554,0.7265166806986846 -380.0295407772064,5.798375606536865,7474.308657169342,32034,0,7474.308657169342,0.2861108596149818,3581,0.743590030281346,7861.260461330414,0.2566853761672973,0.7544609478541783,0.2847692245972847,3554,0.7263270148951885 -384.07799434661865,5.836031198501587,7554.259472131729,32378,0,7554.259472131729,0.2861430049109013,3581,0.7435266259861072,7945.309270620346,0.2567593199866159,0.7543185779026577,0.284801098893852,3554,0.7262754252470808 -388.12148785591125,5.869813680648804,7634.3465123176575,32721,0,7634.3465123176575,0.2862320777192125,3581,0.7438551011414409,8029.485631227493,0.2563129663467407,0.7552104677472796,0.2848670972352894,3554,0.7266483682558385 -392.1617946624756,5.903012752532959,7714.412069559097,33067,0,7714.412069559097,0.2860880545216943,3581,0.7437527679724588,8113.636897563934,0.2565027305058071,0.7548089027404785,0.284718768410154,3554,0.72650617042417 -396.2066998481751,5.935891389846802,7794.469235181808,33414,0,7794.469235181808,0.2861565720665491,3581,0.743361842995148,8197.783804655075,0.2566244602203369,0.7544290678841727,0.2848105444019942,3554,0.7261439437737408 -400.2509334087372,5.968505382537842,7874.45870757103,33758,0,7874.45870757103,0.2861514929052638,3581,0.7440031808503211,8281.86214709282,0.2563022545405796,0.7553148950849261,0.2847635572923994,3554,0.7268665423202729 -404.2913186550141,6.002583742141724,7954.426274776459,34102,0,7954.426274776459,0.2860142191972389,3581,0.7439999083705668,8365.916305065155,0.2563475711005075,0.7551569257463727,0.2846378805131542,3554,0.7268275924794246 -408.3378036022186,6.035517454147339,8034.5121150016785,34447,0,8034.5121150016785,0.2859932889621439,3581,0.7439937724710276,8450.093579769135,0.2563768284661429,0.7550387382507324,0.2846436508599465,3554,0.7268162578696539 -412.38020157814026,6.069193363189697,8114.659774541855,34791,0,8114.659774541855,0.286050932329482,3581,0.7435064456942893,8534.329248189926,0.2562845945358276,0.7547637394496373,0.2846722278154895,3554,0.726274326133406 -416.4238030910492,6.104076862335205,8194.64580988884,35135,0,8194.64580988884,0.2860049130829377,3581,0.7439281865226194,8618.405655145645,0.2562252964292253,0.7552019527980259,0.2846202431734049,3554,0.726717131555114 -420.465188741684,6.143252611160278,8274.605695009232,35479,0,8274.605695009232,0.2859727677870183,3581,0.7441111726822117,8702.45828127861,0.2562356335776193,0.755331311907087,0.2846041171149585,3554,0.726912018148565 -424.51155495643616,6.178053855895996,8354.700634717941,35821,0,8354.700634717941,0.2859783923615959,3581,0.7439823869685492,8786.64630651474,0.2562278509140014,0.7551850591387067,0.2846070022883546,3554,0.7267806053698298 -428.550342798233,6.211203336715698,8434.776199579239,36165,0,8434.776199579239,0.2859633934960556,3581,0.7439969767741204,8870.805983543396,0.256206955228533,0.7552205494471959,0.2845921470800946,3554,0.7267984659670442 -432.59374141693115,6.244928598403931,8438.036951065063,36189,0,8438.036951065063,0.28596335940772477,3581,0.7439967040674742,8878.14504647255,0.256206887108939,0.7552202088492257,0.28459212990644345,3554,0.7267981224940209 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/measurements.csv deleted file mode 100644 index 2cdf6cc49..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.3851867,0.90665424,,,,,,,,,,,,,, -1,,,0.2670762368610927,0.9032635007585798,0.2587972028106535,0.915261701713738,3554.0,0.2829431162570249,0.912513662602974,3581.0,30.82212352752685,34.802770137786865,30.82212352752685,3.9805192947387695,0.0,0.0 -100,0.5183088,0.208827,,,,,,,,,,,,,, -200,0.5754141,0.30864823,,,,,,,,,,,,,, -300,0.2483792,0.3257444,,,,,,,,,,,,,, -344,,,0.7226132665361676,0.285450986453465,0.6983893313432048,0.3115307191280951,3554.0,0.7153852090460067,0.3137289525010472,3581.0,110.83161187171936,118.88110089302064,110.83161187171936,8.01711654663086,0.0200138092041015,0.0 -400,0.22996186,0.32552707,,,,,,,,,,,,,, -500,0.25526947,0.26604897,,,,,,,,,,,,,, -600,0.2343773,0.2690425,,,,,,,,,,,,,, -684,,,0.7282571792602539,0.2770511763436453,0.7044029944252954,0.3023097736485298,3554.0,0.7213692791599763,0.3042089339242704,3581.0,191.00679397583008,203.1422142982483,191.00679397583008,12.06654405593872,0.0441045761108398,0.0 -700,0.28151542,0.26842868,,,,,,,,,,,,,, -800,0.20498557,0.2246263,,,,,,,,,,,,,, -900,0.17692119,0.32950816,,,,,,,,,,,,,, -1000,0.16468406,0.22835997,,,,,,,,,,,,,, -1020,,,0.7370714460100446,0.2704525334494455,0.711497979279509,0.296143402460256,3554.0,0.728832305592886,0.2980568765162489,3581.0,271.0234563350677,287.2428436279297,271.0234563350677,16.10910677909851,0.0733783245086669,0.0 -1100,0.18544787,0.3486896,,,,,,,,,,,,,, -1200,0.28473872,0.22882536,,,,,,,,,,,,,, -1300,0.22943889,0.31123042,,,,,,,,,,,,,, -1366,,,0.7303394590105329,0.2714135476521083,0.7056676620972847,0.2970385618537212,3554.0,0.7230863083810388,0.2986054941117181,3581.0,351.08174562454224,371.3848412036896,351.08174562454224,20.15527033805847,0.0985920429229736,0.0 -1400,0.3433861,0.29203814,,,,,,,,,,,,,, -1500,0.34107548,0.3960622,,,,,,,,,,,,,, -1600,0.271048,0.22807728,,,,,,,,,,,,,, -1700,0.1633269,0.3305081,,,,,,,,,,,,,, -1711,,,0.7415225846426827,0.2683972801480974,0.7163652667988534,0.2941325739923325,3554.0,0.7334914986430118,0.2959236969531031,3581.0,431.0522320270538,455.4326248168945,431.0522320270538,24.19566798210144,0.1235325336456298,0.0 -1800,0.32148638,0.26443237,,,,,,,,,,,,,, -1900,0.13831688,0.25860912,,,,,,,,,,,,,, -2000,0.4394917,0.2472544,,,,,,,,,,,,,, -2055,,,0.741027695792062,0.2685650076184954,0.71662994711065,0.2943257432206668,3554.0,0.7336164664636274,0.2959743181243018,3581.0,511.1865234375,539.6468026638031,511.1865234375,28.23944354057312,0.1476261615753173,0.0 -2100,0.296484,0.3316043,,,,,,,,,,,,,, -2200,0.16740426,0.25318235,,,,,,,,,,,,,, -2300,0.2495351,0.2510645,,,,,,,,,,,,,, -2400,,,0.7440621512276786,0.2670582021985735,0.7178861653462648,0.2936223104688379,3554.0,0.735299339177255,0.2951105879991622,3581.0,591.3152222633362,623.8655636310577,591.3152222633362,32.28887057304382,0.1761145591735839,0.0 -2400,0.16311572,0.35412347,,,,,,,,,,,,,, -2500,0.09003994,0.27034175,,,,,,,,,,,,,, -2600,0.18089388,0.27625978,,,,,,,,,,,,,, -2700,0.09014642,0.2625927,,,,,,,,,,,,,, -2747,,,0.7456315585545131,0.2664312124252319,0.7203574537492966,0.2921799985491699,3554.0,0.7375439875820302,0.2937197500152715,3581.0,671.4278931617737,708.0597774982452,671.4278931617737,36.32929086685181,0.205007791519165,0.0 -2800,0.26694587,0.3159148,,,,,,,,,,,,,, -2900,0.28707966,0.23585597,,,,,,,,,,,,,, -3000,0.27451178,0.26859123,,,,,,,,,,,,,, -3093,,,0.7441027505057198,0.2652908563613891,0.7181291381629854,0.2911974939658659,3554.0,0.735590589875384,0.292780752855784,3581.0,751.4628582000732,792.1761124134064,751.4628582000732,40.37436270713806,0.2290525436401367,0.0 -3100,0.18120338,0.2608294,,,,,,,,,,,,,, -3200,0.2386161,0.37280262,,,,,,,,,,,,,, -3300,0.20327163,0.30659896,,,,,,,,,,,,,, -3400,0.21754564,0.28718275,,,,,,,,,,,,,, -3438,,,0.7385843821934291,0.268498352595738,0.7121779871711452,0.295027252523565,3554.0,0.7294369644041468,0.2967027175762706,3581.0,831.4726746082306,876.2643911838531,831.4726746082306,44.415067195892334,0.2549364566802978,0.0 -3500,0.26038212,0.26277167,,,,,,,,,,,,,, -3600,0.19419108,0.29344997,,,,,,,,,,,,,, -3700,0.27145982,0.2837087,,,,,,,,,,,,,, -3783,,,0.7476419040134975,0.2647112948553903,0.7219686170072454,0.2910621999419668,3554.0,0.7389904236334125,0.2926715338439856,3581.0,911.6228678226472,960.4981019496918,911.6228678226472,48.46218538284302,0.2796492576599121,0.0 -3800,0.22651121,0.27446535,,,,,,,,,,,,,, -3900,0.07120605,0.26711994,,,,,,,,,,,,,, -4000,0.1765104,0.2989771,,,,,,,,,,,,,, -4100,0.071633354,0.26921743,,,,,,,,,,,,,, -4130,,,0.744476454598563,0.2648463419505528,0.7180813954127392,0.29101665541907,3554.0,0.7354596225085521,0.2926502286372521,3581.0,991.5995440483092,1044.5592665672302,991.5995440483092,52.50946092605591,0.3047735691070556,0.0 -4200,0.16627912,0.24641675,,,,,,,,,,,,,, -4300,0.42646825,0.29409963,,,,,,,,,,,,,, -4400,0.113737725,0.2947785,,,,,,,,,,,,,, -4472,,,0.7401151657104492,0.2657135725021362,0.7127527549284257,0.2923074957354389,3554.0,0.7306974145228288,0.2937538042577143,3581.0,1071.6571695804596,1128.7016820907593,1071.6571695804596,56.5556275844574,0.3315975666046142,0.0 -4500,0.07409469,0.20256023,,,,,,,,,,,,,, -4600,0.16494119,0.20805262,,,,,,,,,,,,,, -4700,0.106141515,0.22216752,,,,,,,,,,,,,, -4800,0.15669467,0.3160457,,,,,,,,,,,,,, -4819,,,0.7446002279009137,0.2647957120622907,0.7191272020742473,0.290804663869056,3554.0,0.7363732579499441,0.2922750183804279,3581.0,1151.627099752426,1212.7570588588717,1151.627099752426,60.60241341590881,0.3582265377044678,0.0 -4900,0.13501924,0.29015312,,,,,,,,,,,,,, -5000,0.15238026,0.37062624,,,,,,,,,,,,,, -5100,0.09732206,0.3306864,,,,,,,,,,,,,, -5162,,,0.7490038871765137,0.2648267064775739,0.7233019105893008,0.2915698187231816,3554.0,0.7402641682010961,0.2931434867835451,3581.0,1231.6058168411255,1296.8183772563934,1231.6058168411255,64.64353156089783,0.3878440856933594,0.0 -5200,0.19204782,0.24084489,,,,,,,,,,,,,, -5300,0.33186245,0.22375664,,,,,,,,,,,,,, -5400,0.09195747,0.33212253,,,,,,,,,,,,,, -5500,0.18619043,0.28243852,,,,,,,,,,,,,, -5505,,,0.7461484500340053,0.2651594366346086,0.7193948362540448,0.2920032816786543,3554.0,0.7366458964194709,0.2934891765459194,3581.0,1311.6497621536255,1380.9428961277008,1311.6497621536255,68.6874794960022,0.4124460220336914,0.0 -5600,0.20270714,0.25750613,,,,,,,,,,,,,, -5700,0.15770197,0.2987288,,,,,,,,,,,,,, -5800,0.27156976,0.29810557,,,,,,,,,,,,,, -5851,,,0.7486353601728167,0.264972414289202,0.72320086082583,0.2916949459455895,3554.0,0.7399766672193522,0.2933677880000349,3581.0,1391.8070256710052,1465.185186624527,1391.8070256710052,72.73325824737549,0.4397716522216797,0.0 -5900,0.21833202,0.3032564,,,,,,,,,,,,,, -6000,0.23338641,0.21389224,,,,,,,,,,,,,, -6100,0.19286005,0.32426032,,,,,,,,,,,,,, -6197,,,0.7402887344360352,0.2653241668428693,0.7139138311365715,0.2915773751296954,3554.0,0.7316704318364283,0.2929133905508237,3581.0,1471.9566292762756,1549.4155497550964,1471.9566292762756,76.77447819709778,0.467041015625,0.0 -6200,0.09763719,0.29494673,,,,,,,,,,,,,, -6300,0.119579256,0.2402495,,,,,,,,,,,,,, -6400,0.24177422,0.23509318,,,,,,,,,,,,,, -6500,0.1371438,0.25952893,,,,,,,,,,,,,, -6540,,,0.7448414393833706,0.2647097281047276,0.7187118744944077,0.2913872284639666,3554.0,0.7359811057927255,0.2928044101573408,3581.0,1551.9505987167358,1633.491622209549,1551.9505987167358,80.81879162788391,0.4929647445678711,0.0 -6600,0.19922887,0.30281442,,,,,,,,,,,,,, -6700,0.16811731,0.2924772,,,,,,,,,,,,,, -6800,0.12133973,0.30970863,,,,,,,,,,,,,, -6885,,,0.7436417170933315,0.2659179653440203,0.716956727345069,0.2928763901040201,3554.0,0.7342648946916015,0.2944325029234153,3581.0,1632.0223336219788,1717.647849559784,1632.0223336219788,84.86484432220459,0.5194802284240723,0.0 -6900,0.13469622,0.35355514,,,,,,,,,,,,,, -7000,0.38056445,0.24839172,,,,,,,,,,,,,, -7100,0.1345351,0.24820581,,,,,,,,,,,,,, -7200,0.19370465,0.27727643,,,,,,,,,,,,,, -7230,,,0.7429020064217704,0.2644589628492083,0.7165945693892445,0.2905732660932224,3554.0,0.7341480398937098,0.2919725526214744,3581.0,1712.0917258262634,1801.7990944385529,1712.0917258262634,88.90904879570007,0.5449492931365967,0.0 -7300,0.09663754,0.27971148,,,,,,,,,,,,,, -7400,0.24458408,0.26218048,,,,,,,,,,,,,, -7500,0.22025445,0.20724557,,,,,,,,,,,,,, -7573,,,0.7418189729963031,0.2645441464015415,0.7159344142383582,0.2905525890172165,3554.0,0.7334203222083566,0.2919843130955913,3581.0,1792.1360602378843,1885.9275164604187,1792.1360602378843,92.9541277885437,0.572021484375,0.0 -7600,0.30425784,0.26711294,,,,,,,,,,,,,, -7700,0.33780384,0.27137464,,,,,,,,,,,,,, -7800,0.12768355,0.2823058,,,,,,,,,,,,,, -7900,0.16332684,0.2686871,,,,,,,,,,,,,, -7918,,,0.7490203039986747,0.2643026624407087,0.723283088267621,0.290826577447946,3554.0,0.7403620698870776,0.2923499104431374,3581.0,1872.2391254901888,1970.113526582718,1872.2391254901888,96.99868774414062,0.5981225967407227,0.0 -8000,0.13018084,0.3024965,,,,,,,,,,,,,, -8100,0.23168693,0.22293304,,,,,,,,,,,,,, -8200,0.16008642,0.27666566,,,,,,,,,,,,,, -8265,,,0.7482437406267438,0.2642133235931396,0.7220145736977701,0.2909175291045301,3554.0,0.7392333370785046,0.2923485809982372,3581.0,1952.3275845050807,2054.2847859859467,1952.3275845050807,101.04316473007202,0.6242008209228516,0.0 -8300,0.20775828,0.2517766,,,,,,,,,,,,,, -8400,0.19357555,0.2709783,,,,,,,,,,,,,, -8500,0.16942918,0.3550471,,,,,,,,,,,,,, -8600,0.09712767,0.34925318,,,,,,,,,,,,,, -8607,,,0.7464250155857631,0.2638657093048095,0.7195118918604038,0.2906534326968732,3554.0,0.7369421921905543,0.2919754501295902,3581.0,2032.51903796196,2138.555235147476,2032.51903796196,105.08436179161072,0.6500787734985352,0.0 -8700,0.26023486,0.2090784,,,,,,,,,,,,,, -8800,0.18512581,0.20743677,,,,,,,,,,,,,, -8900,0.14496173,0.23596428,,,,,,,,,,,,,, -8951,,,0.7454898016793388,0.2632397242954799,0.7187876446433596,0.2896331804283026,3554.0,0.736424458622766,0.2910422138980033,3581.0,2112.5709660053253,2222.6944692134857,2112.5709660053253,109.13049507141112,0.6793222427368164,0.0 -9000,0.15062565,0.20098852,,,,,,,,,,,,,, -9100,0.19077952,0.30668357,,,,,,,,,,,,,, -9200,0.2737041,0.21293484,,,,,,,,,,,,,, -9293,,,0.7439816338675362,0.2630248410361154,0.7169932728747538,0.289561634997538,3554.0,0.7345339197980661,0.2910244879660011,3581.0,2192.5693085193634,2306.7771003246307,2192.5693085193634,113.17689657211304,0.7051200866699219,0.0 -9300,0.13952692,0.32031444,,,,,,,,,,,,,, -9400,0.22150995,0.23315822,,,,,,,,,,,,,, -9500,0.24534127,0.27206093,,,,,,,,,,,,,, -9600,0.26469445,0.2741169,,,,,,,,,,,,,, -9636,,,0.7478782790047782,0.2636504684175764,0.7215138587243247,0.2900727915508934,3554.0,0.7388857042812762,0.2914541713754014,3581.0,2272.6771895885468,2390.965298175812,2272.6771895885468,117.2190728187561,0.7311830520629883,0.0 -9700,0.1794584,0.250325,,,,,,,,,,,,,, -9800,0.17106432,0.2641202,,,,,,,,,,,,,, -9900,0.29929262,0.25734794,,,,,,,,,,,,,, -9980,,,0.7397242954799107,0.2646668297903878,0.7141337225661227,0.2910097516113006,3554.0,0.7312535997277296,0.2924467894791958,3581.0,2352.732777118683,2475.1045954227448,2352.732777118683,121.26503705978394,0.7569248676300049,0.0 -10000,0.18900107,0.25272897,,,,,,,,,,,,,, -10100,0.07370399,0.2883696,,,,,,,,,,,,,, -10200,0.23822351,0.35535076,,,,,,,,,,,,,, -10300,0.14448418,0.24834761,,,,,,,,,,,,,, -10325,,,0.749014036996024,0.2629481383732387,0.7227236394071821,0.2893752321877638,3554.0,0.7399639181836428,0.2908956340756772,3581.0,2432.74764752388,2559.2052206993103,2432.74764752388,125.30674004554749,0.7888193130493164,0.0 -10400,0.08347887,0.31063288,,,,,,,,,,,,,, -10500,0.060370788,0.2170692,,,,,,,,,,,,,, -10600,0.30217105,0.2600026,,,,,,,,,,,,,, -10669,,,0.7450334003993443,0.2640965155192784,0.7185316885463562,0.290682490514649,3554.0,0.7358866811164828,0.292158947614144,3581.0,2512.824273824692,2643.370032787323,2512.824273824692,129.35560011863708,0.816091775894165,0.0 -10700,0.24106432,0.32318383,,,,,,,,,,,,,, -10800,0.12217338,0.19851466,,,,,,,,,,,,,, -10900,0.18348692,0.2587854,,,,,,,,,,,,,, -11000,0.1502322,0.3365804,,,,,,,,,,,,,, -11014,,,0.7504299027579171,0.26199814251491,0.7241798963315982,0.2886156072493142,3554.0,0.7411439880183608,0.2901744272615016,3581.0,2592.9127190113068,2727.5416226387024,2592.9127190113068,133.3985664844513,0.8439867496490479,0.0 -11100,0.13380194,0.25422722,,,,,,,,,,,,,, -11200,0.091143064,0.2489095,,,,,,,,,,,,,, -11300,0.14448068,0.30363932,,,,,,,,,,,,,, -11358,,,0.7455850328717913,0.2641546896525791,0.7196553261949564,0.2902342925664743,3554.0,0.7365518126265359,0.2919174999672752,3581.0,2672.970718383789,2811.682421684265,2672.970718383789,137.44315481185913,0.8700685501098633,0.0 -11400,0.15820023,0.26270553,,,,,,,,,,,,,, -11500,0.12555245,0.29192215,,,,,,,,,,,,,, -11600,0.1936865,0.2472802,,,,,,,,,,,,,, -11700,0.13515791,0.24457455,,,,,,,,,,,,,, -11702,,,0.7454074450901577,0.2641615867614746,0.7192417846748382,0.2905541689931239,3554.0,0.7364029147977171,0.2920892028893814,3581.0,2752.944223165512,2895.736180782318,2752.944223165512,141.48319029808044,0.8981711864471436,0.0 -11800,0.13286597,0.28221217,,,,,,,,,,,,,, -11900,0.16925399,0.34361133,,,,,,,,,,,,,, -12000,0.14032637,0.32747313,,,,,,,,,,,,,, -12044,,,0.7435768672398159,0.2633335930960519,0.7189717461838773,0.2896535483785875,3554.0,0.7357157622259843,0.29123552882182,3581.0,2832.9485371112823,2979.825961828232,2832.9485371112823,145.52828335762024,0.9264392852783204,0.0 -12100,0.20275235,0.22279336,,,,,,,,,,,,,, -12200,0.0946609,0.27422038,,,,,,,,,,,,,, -12300,0.16818896,0.28270274,,,,,,,,,,,,,, -12389,,,0.7485134260995048,0.2619869198117937,0.7220078416265123,0.2885388066812922,3554.0,0.7393814167873848,0.2899312751980941,3581.0,2913.143745660782,3064.1029093265533,2913.143745660782,149.57113814353943,0.9530577659606934,0.0 -12400,0.18972947,0.2552971,,,,,,,,,,,,,, -12500,0.2128883,0.3360421,,,,,,,,,,,,,, -12600,0.16340786,0.37431136,,,,,,,,,,,,,, -12700,0.10239174,0.2915533,,,,,,,,,,,,,, -12736,,,0.7500293595450265,0.2616373130253383,0.723628003877673,0.2881183270061023,3554.0,0.7409452530499512,0.2895676208854021,3581.0,2993.192461490631,3148.232587814331,2993.192461490631,153.61311268806458,0.9799468517303468,0.0 -12800,0.24272555,0.253842,,,,,,,,,,,,,, -12900,0.16167943,0.22039531,,,,,,,,,,,,,, -13000,0.1615324,0.28624508,,,,,,,,,,,,,, -13078,,,0.7514548982892718,0.2612794807979038,0.724869796246307,0.2879144757667417,3554.0,0.742036215988027,0.2893386836559271,3581.0,3073.1728444099426,3232.291928291321,3073.1728444099426,157.65242910385132,1.007526397705078,0.0 -13100,0.17160295,0.2431315,,,,,,,,,,,,,, -13200,0.20177913,0.19339809,,,,,,,,,,,,,, -13300,0.16829953,0.25871336,,,,,,,,,,,,,, -13400,0.17882638,0.2901453,,,,,,,,,,,,,, -13424,,,0.7457209995814732,0.2630315848759242,0.7189907058947664,0.2893385149215672,3554.0,0.7365478583801661,0.2907379073691881,3581.0,3153.252564430237,3316.4589569568634,3153.252564430237,161.69960498809814,1.0354080200195312,0.0 -13500,0.21115322,0.27255464,,,,,,,,,,,,,, -13600,0.27741858,0.22020116,,,,,,,,,,,,,, -13700,0.1284811,0.330591,,,,,,,,,,,,,, -13770,,,0.7465126173836845,0.2622240781784057,0.7193183791590462,0.2889956601496553,3554.0,0.7370009604728078,0.2902145492268221,3581.0,3233.4567432403564,3400.75181722641,3233.4567432403564,165.74770259857178,1.0637977123260498,0.0 -13800,0.1563591,0.21920006,,,,,,,,,,,,,, -13900,0.22180618,0.23742616,,,,,,,,,,,,,, -14000,0.08059559,0.3352333,,,,,,,,,,,,,, -14100,0.15744874,0.28217238,,,,,,,,,,,,,, -14111,,,0.748753275190081,0.2619667734418596,0.7221636409899057,0.28858569074898,3554.0,0.7393679178083985,0.2901376118642663,3581.0,3313.4611835479736,3484.84254193306,3313.4611835479736,169.79286742210388,1.093024969100952,0.0 -14200,0.05735554,0.23878421,,,,,,,,,,,,,, -14300,0.12530851,0.2921019,,,,,,,,,,,,,, -14400,0.1610376,0.23132621,,,,,,,,,,,,,, -14457,,,0.7506982939583915,0.2622332062040056,0.7246852825381612,0.2887858324796883,3554.0,0.7417963704927045,0.2902302639473087,3581.0,3393.474872589112,3568.93760895729,3393.474872589112,173.83345460891724,1.1216280460357666,0.0 -14500,0.15076856,0.3692115,,,,,,,,,,,,,, -14600,0.08611508,0.20640545,,,,,,,,,,,,,, -14700,0.05546996,0.2254359,,,,,,,,,,,,,, -14800,0.14665343,0.2694381,,,,,,,,,,,,,, -14804,,,0.748584338596889,0.262541617665972,0.7220001478307893,0.2888958125417663,3554.0,0.739383939323862,0.2902830667716769,3581.0,3473.484473705292,3653.033741474152,3473.484473705292,177.88046073913574,1.1491503715515137,0.0 -14900,0.08624917,0.23731625,,,,,,,,,,,,,, -15000,0.16576505,0.23146577,,,,,,,,,,,,,, -15100,0.124386154,0.21066482,,,,,,,,,,,,,, -15148,,,0.746781485421317,0.2616633006504604,0.7198642264877603,0.2881038839654702,3554.0,0.7373965896397654,0.2894467436644792,3581.0,3553.4609277248383,3737.0999524593353,3553.4609277248383,181.9265389442444,1.1808226108551023,0.0 -15200,0.095376134,0.24700284,,,,,,,,,,,,,, -15300,0.119614735,0.24522497,,,,,,,,,,,,,, -15400,0.24416514,0.23749198,,,,,,,,,,,,,, -15479,,,0.7495412826538086,0.2613825287137712,0.7229363865978475,0.2880535479938977,3554.0,0.7400738189620567,0.2894908880528309,3581.0,3630.503707885742,3821.306978702545,3630.503707885742,185.9671552181244,4.293005704879761,0.0 -15500,0.08378432,0.35040542,,,,,,,,,,,,,, -15600,0.20349719,0.30746374,,,,,,,,,,,,,, -15700,0.13711254,0.24839064,,,,,,,,,,,,,, -15800,0.115993306,0.30637607,,,,,,,,,,,,,, -15827,,,0.7468673161097935,0.2616352353777204,0.7202307808982836,0.2881785721743986,3554.0,0.7377974002330006,0.2894689692561435,3581.0,3710.492092132568,3905.3816516399374,3710.492092132568,190.0122487545013,4.322009086608887,0.0 -15900,0.13345405,0.3013661,,,,,,,,,,,,,, -16000,0.10176493,0.22638221,,,,,,,,,,,,,, -16100,0.16402446,0.2797056,,,,,,,,,,,,,, -16170,,,0.7493359701974052,0.261055520602635,0.7227794194261747,0.287572565545644,3554.0,0.7401714479413921,0.2889514061300091,3581.0,3790.4588992595673,3989.43754029274,3790.4588992595673,194.06110763549805,4.35016131401062,0.0 -16200,0.119231544,0.26071087,,,,,,,,,,,,,, -16300,0.17584278,0.22421055,,,,,,,,,,,,,, -16400,0.2472691,0.24630262,,,,,,,,,,,,,, -16500,0.10160729,0.21084836,,,,,,,,,,,,,, -16513,,,0.7490717342921666,0.2611632687704904,0.7222987632772931,0.2879586979184985,3554.0,0.7395680844867006,0.2893771352930396,3581.0,3870.481683015824,4073.548889398575,3870.481683015824,198.1052176952362,4.382714986801148,0.0 -16600,0.14352055,0.19101526,,,,,,,,,,,,,, -16700,0.09651153,0.26779178,,,,,,,,,,,,,, -16800,0.08135974,0.23929636,,,,,,,,,,,,,, -16857,,,0.7464184079851423,0.2624172823769705,0.719400606600837,0.2888134820580684,3554.0,0.7370565244519687,0.2901986981530124,3581.0,3950.5336406230927,4157.687446594238,3950.5336406230927,202.15042400360107,4.412107944488525,0.0 -16900,0.107081525,0.35825324,,,,,,,,,,,,,, -17000,0.12609708,0.24272522,,,,,,,,,,,,,, -17100,0.12747034,0.27766857,,,,,,,,,,,,,, -17200,0.28231156,0.34865534,,,,,,,,,,,,,, -17202,,,0.7484077726091657,0.2613915715898786,0.7219015023784819,0.2878314926842993,3554.0,0.7391657740069115,0.2892383616984606,3581.0,4030.594305992128,4241.836377620697,4030.594305992128,206.19731330871585,4.441436052322388,0.0 -17300,0.12618013,0.3105283,,,,,,,,,,,,,, -17400,0.17054234,0.25955707,,,,,,,,,,,,,, -17500,0.06510934,0.30418426,,,,,,,,,,,,,, -17547,,,0.7491188049316406,0.2607495614460536,0.7221773799108399,0.287938776483144,3554.0,0.7394163914147585,0.2893575345028448,3581.0,4110.767489433289,4326.093504428864,4110.767489433289,210.2408187389373,4.469689130783081,0.0 -17600,0.117503785,0.26013345,,,,,,,,,,,,,, -17700,0.13936867,0.29050127,,,,,,,,,,,,,, -17800,0.14498943,0.36515322,,,,,,,,,,,,,, -17893,,,0.7488366535731724,0.2610698427472795,0.7225361031364308,0.2876933821816087,3554.0,0.7397306176478288,0.2890635567382539,3581.0,4190.812663078308,4410.222549676895,4190.812663078308,214.283281326294,4.499216556549072,0.0 -17900,0.13724591,0.32324302,,,,,,,,,,,,,, -18000,0.05880072,0.37086558,,,,,,,,,,,,,, -18100,0.1610164,0.23308532,,,,,,,,,,,,,, -18200,0.12037148,0.36249438,,,,,,,,,,,,,, -18241,,,0.7490247998918805,0.2612367698124477,0.7226962302599184,0.2877226804305008,3554.0,0.7400820683381039,0.2890376155185353,3581.0,4270.923578500748,4494.421721696854,4270.923578500748,218.32737565040588,4.531201362609863,0.0 -18300,0.16257195,0.2031862,,,,,,,,,,,,,, -18400,0.24049266,0.28979194,,,,,,,,,,,,,, -18500,0.14359076,0.23323444,,,,,,,,,,,,,, -18584,,,0.7490253448486328,0.2604962587356567,0.7227430799803038,0.2874368078331633,3554.0,0.7397645014486177,0.2889587010327946,3581.0,4350.883985042572,4578.475754737854,4350.883985042572,222.3760340213776,4.564069509506226,0.0 -18600,0.06597962,0.29624504,,,,,,,,,,,,,, -18700,0.17905492,0.30625722,,,,,,,,,,,,,, -18800,0.11952822,0.2622695,,,,,,,,,,,,,, -18900,0.04473741,0.28187323,,,,,,,,,,,,,, -18930,,,0.746837956564767,0.2611930540629795,0.7200128129176632,0.2878367134742544,3554.0,0.7375610317474169,0.289137153444394,3581.0,4430.957304239273,4662.636361837387,4430.957304239273,226.41681051254272,4.5984275341033936,0.0 -19000,0.1763647,0.29579127,,,,,,,,,,,,,, -19100,0.11498627,0.36831677,,,,,,,,,,,,,, -19200,0.1172696,0.35542455,,,,,,,,,,,,,, -19276,,,0.7515696798052106,0.2613452843257359,0.7252221995682682,0.2881080056417505,3554.0,0.742224383573897,0.2894971603056932,3581.0,4511.09930229187,4746.862930059433,4511.09930229187,230.45980858802795,4.627838373184204,0.0 -19300,0.05982042,0.31968874,,,,,,,,,,,,,, -19400,0.16234286,0.20483413,,,,,,,,,,,,,, -19500,0.119435206,0.3120688,,,,,,,,,,,,,, -19600,0.108982,0.2605657,,,,,,,,,,,,,, -19620,,,0.7509069442749023,0.2607353925704956,0.7235287401739238,0.2882456868031619,3554.0,0.7407267468496929,0.28975834509608,3581.0,4591.108497858048,4830.964126110077,4591.108497858048,234.5062789916992,4.661612987518311,0.0 -19700,0.14003506,0.250587,,,,,,,,,,,,,, -19800,0.19401318,0.2869247,,,,,,,,,,,,,, -19900,0.07264986,0.27775952,,,,,,,,,,,,,, -19969,,,0.747917720249721,0.2605396509170532,0.7213956353096863,0.2870852460201094,3554.0,0.7388614333897654,0.2884193213749651,3581.0,4671.207195281982,4915.146928071976,4671.207195281982,238.5473177433014,4.692481756210327,0.0 -20000,0.08106866,0.3505778,,,,,,,,,,,,,, -20100,0.09809739,0.41858593,,,,,,,,,,,,,, -20200,0.14866519,0.28266,,,,,,,,,,,,,, -20300,0.10881093,0.24441303,,,,,,,,,,,,,, -20315,,,0.7510645730154855,0.2608891555241176,0.7249340257016742,0.2874345580848603,3554.0,0.7419790157689891,0.2887872708173345,3581.0,4751.304104089737,4999.330250740051,4751.304104089737,242.59148383140564,4.722586393356323,0.0 -20400,0.11993205,0.23925723,,,,,,,,,,,,,, -20500,0.10125777,0.2985533,,,,,,,,,,,,,, -20600,0.07604974,0.24231613,,,,,,,,,,,,,, -20657,,,0.750018664768764,0.2601903336388724,0.722571961720069,0.287154438660664,3554.0,0.7399350112791468,0.2885486525019198,3581.0,4831.424298048019,5083.538475513458,4831.424298048019,246.63685989379883,4.753290891647339,0.0 -20700,0.18210751,0.23286317,,,,,,,,,,,,,, -20800,0.07631735,0.27443796,,,,,,,,,,,,,, -20900,0.14098129,0.28915653,,,,,,,,,,,,,, -21000,0.18335024,0.394011,,,,,,,,,,,,,, -21003,,,0.750136239188058,0.2597711597170148,0.7229999978017726,0.2868698712608152,3554.0,0.7402386701296775,0.2883246921687378,3581.0,4911.5534324646,5167.75226020813,4911.5534324646,250.6783995628357,4.78401255607605,0.0 -21100,0.09016558,0.28560284,,,,,,,,,,,,,, -21200,0.20691033,0.22411829,,,,,,,,,,,,,, -21300,0.06756322,0.30493987,,,,,,,,,,,,,, -21347,,,0.7508693422589984,0.2598606007439749,0.7238796322145822,0.2868814291280511,3554.0,0.7412270953687866,0.2881120491613725,3581.0,4991.542078495026,5251.82776093483,4991.542078495026,254.72396564483645,4.8128838539123535,0.0 -21400,0.11996557,0.25764143,,,,,,,,,,,,,, -21500,0.10567263,0.25577083,,,,,,,,,,,,,, -21600,0.05958251,0.23739949,,,,,,,,,,,,,, -21690,,,0.7507828984941755,0.2596160514014108,0.7240258830279263,0.2863647598326709,3554.0,0.7412582521031137,0.2877457700472109,3581.0,5071.724465370178,5336.094830274582,5071.724465370178,258.7666749954224,4.8427369594573975,0.0 -21700,0.08264526,0.399659,,,,,,,,,,,,,, -21800,0.07177192,0.18243286,,,,,,,,,,,,,, -21900,0.109522074,0.281831,,,,,,,,,,,,,, -22000,0.07240308,0.35709724,,,,,,,,,,,,,, -22037,,,0.7516030584062848,0.2595355340412685,0.7247510232748312,0.2865019257845473,3554.0,0.7419090665142418,0.2878893500964291,3581.0,5151.788769721985,5420.246787071228,5151.788769721985,262.81214714050293,4.872639179229736,0.0 -22100,0.07683311,0.29377624,,,,,,,,,,,,,, -22200,0.14726214,0.2747944,,,,,,,,,,,,,, -22300,0.06417939,0.25268018,,,,,,,,,,,,,, -22382,,,0.7525712421962193,0.2594784498214721,0.7257786945607062,0.2865110793406197,3554.0,0.742977122094038,0.2878808280137357,3581.0,5232.000771999359,5504.5493080616,5232.000771999359,266.8560085296631,4.907236814498901,0.0 -22400,0.07733939,0.29144257,,,,,,,,,,,,,, -22500,0.19210954,0.24432123,,,,,,,,,,,,,, -22600,0.11748409,0.24149518,,,,,,,,,,,,,, -22700,0.08395966,0.1975832,,,,,,,,,,,,,, -22727,,,0.7518329620361328,0.2596061059406825,0.7250375484709131,0.2864908487795441,3554.0,0.7421773416774294,0.2878331725273143,3581.0,5312.106199026108,5588.733791828156,5312.106199026108,270.8928482532501,4.937005043029785,0.0 -22800,0.12380234,0.2373031,,,,,,,,,,,,,, -22900,0.11657933,0.19616768,,,,,,,,,,,,,, -23000,0.1132232,0.27179477,,,,,,,,,,,,,, -23072,,,0.7509150505065918,0.2589318581989833,0.7236573364738674,0.2861938991772035,3554.0,0.741036337069778,0.2875140375746125,3581.0,5392.222705602646,5672.93510222435,5392.222705602646,274.9365813732147,4.966019153594971,0.0 -23100,0.12378489,0.30495605,,,,,,,,,,,,,, -23200,0.047432505,0.329731,,,,,,,,,,,,,, -23300,0.0808258,0.25206468,,,,,,,,,,,,,, -23400,0.121604905,0.3111355,,,,,,,,,,,,,, -23418,,,0.7504358291625977,0.259568520954677,0.723816433178285,0.2864663935002813,3554.0,0.7411057409112329,0.2878223324381283,3581.0,5472.2902681827545,5757.086367607117,5472.2902681827545,278.97792768478394,4.996206760406494,0.0 -23500,0.09840868,0.2345154,,,,,,,,,,,,,, -23600,0.056473345,0.23682608,,,,,,,,,,,,,, -23700,0.09794156,0.29895544,,,,,,,,,,,,,, -23760,,,0.7507174355643136,0.259830219405038,0.7237039801104389,0.2867341478956369,3554.0,0.7410613579045657,0.288040020518448,3581.0,5552.431335449219,5841.311386823654,5552.431335449219,283.0185122489929,5.02751088142395,0.0 -23800,0.087722115,0.24101743,,,,,,,,,,,,,, -23900,0.09428323,0.22343785,,,,,,,,,,,,,, -24000,0.15611368,0.24000475,,,,,,,,,,,,,, -24100,0.05551834,0.32853368,,,,,,,,,,,,,, -24105,,,0.7529386111668178,0.2583275181906564,0.7255741907226013,0.2857426271454699,3554.0,0.7427400036651773,0.2871316346799951,3581.0,5632.593497753143,5925.557373762131,5632.593497753143,287.06051874160767,5.057190895080566,0.0 -24200,0.08542114,0.26240766,,,,,,,,,,,,,, -24300,0.15918918,0.2399633,,,,,,,,,,,,,, -24400,0.069428645,0.19025083,,,,,,,,,,,,,, -24450,,,0.7510510172162738,0.2591849395206996,0.7238414380143852,0.286399175829611,3554.0,0.7411732358061645,0.2877290667651319,3581.0,5712.731587409973,6009.784536600113,5712.731587409973,291.10566115379333,5.088764429092407,0.0 -24500,0.061155476,0.29643485,,,,,,,,,,,,,, -24600,0.062386617,0.2871765,,,,,,,,,,,,,, -24700,0.03193584,0.2548094,,,,,,,,,,,,,, -24793,,,0.7537099293300084,0.2592285360608782,0.7267622639103827,0.2864781231040289,3554.0,0.7438580327378874,0.2878063109226647,3581.0,5792.829498052597,6093.966063499451,5792.829498052597,295.145968914032,5.1201136112213135,0.0 -24800,0.090435624,0.27385616,,,,,,,,,,,,,, -24900,0.10938052,0.2622201,,,,,,,,,,,,,, -25000,0.15476693,0.25065622,,,,,,,,,,,,,, -25100,0.098388635,0.28382653,,,,,,,,,,,,,, -25138,,,0.7527525765555245,0.2580028772354126,0.7249790893623382,0.2856826739292434,3554.0,0.7422174977310807,0.2870566403522933,3581.0,5872.8541264534,6178.080864906311,5872.8541264534,299.1887602806092,5.155481338500977,0.0 -25200,0.06178034,0.23001231,,,,,,,,,,,,,, -25300,0.18660921,0.20367125,,,,,,,,,,,,,, -25400,0.063301705,0.21989064,,,,,,,,,,,,,, -25484,,,0.751709120614188,0.2586135864257812,0.7240188761782499,0.286100251257386,3554.0,0.7413706754180047,0.2874117044056304,3581.0,5952.869534730911,6262.180478811264,5952.869534730911,303.2299098968506,5.186309576034546,0.0 -25500,0.0656905,0.25516045,,,,,,,,,,,,,, -25600,0.08444333,0.3346514,,,,,,,,,,,,,, -25700,0.088734485,0.41411215,,,,,,,,,,,,,, -25800,0.1120488,0.2791049,,,,,,,,,,,,,, -25830,,,0.7517774445669991,0.2585456371307373,0.7246146644845597,0.2857211257342079,3554.0,0.7417736676644093,0.2870886492948897,3581.0,6032.928485393524,6346.324097394943,6032.928485393524,307.27237248420715,5.216569900512695,0.0 -25900,0.08650525,0.24740958,,,,,,,,,,,,,, -26000,0.087697685,0.29529873,,,,,,,,,,,,,, -26100,0.06197897,0.27475035,,,,,,,,,,,,,, -26173,,,0.7526977402823312,0.2580876691000802,0.7245964604143219,0.2860090763433367,3554.0,0.7418254137505236,0.2873534133600077,3581.0,6113.055310726166,6430.539331912994,6113.055310726166,311.31592655181885,5.249204397201538,0.0 -26200,0.10443065,0.29374397,,,,,,,,,,,,,, -26300,0.051424507,0.36300805,,,,,,,,,,,,,, -26400,0.11359692,0.25061834,,,,,,,,,,,,,, -26500,0.09029543,0.25253218,,,,,,,,,,,,,, -26515,,,0.753077507019043,0.2581151042665754,0.7256595094216024,0.2855437390912967,3554.0,0.7428827655944569,0.2869383197561784,3581.0,6193.021834373474,6514.592348814011,6193.021834373474,315.35939478874207,5.28049635887146,0.0 -26600,0.06047606,0.26706,,,,,,,,,,,,,, -26700,0.07467899,0.34342226,,,,,,,,,,,,,, -26800,0.0513894,0.179376,,,,,,,,,,,,,, -26861,,,0.7530824797494071,0.2580545970371791,0.7258945136641812,0.2853332416489343,3554.0,0.743087909169052,0.2866473417646956,3581.0,6273.217540979385,6598.875570058823,6273.217540979385,319.4024076461792,5.313000917434692,0.0 -26900,0.105223075,0.21472552,,,,,,,,,,,,,, -27000,0.06586276,0.213957,,,,,,,,,,,,,, -27100,0.07370328,0.24782042,,,,,,,,,,,,,, -27200,0.0393859,0.31174302,,,,,,,,,,,,,, -27205,,,0.7533891541617257,0.2577869551522391,0.7262223243176702,0.2854679345950425,3554.0,0.7432430110740715,0.286867041056531,3581.0,6353.390656471252,6683.127971410751,6353.390656471252,323.43775701522827,5.34488582611084,0.0 -27300,0.07222878,0.24213114,,,,,,,,,,,,,, -27400,0.06862223,0.3060838,,,,,,,,,,,,,, -27500,0.10555972,0.2374031,,,,,,,,,,,,,, -27551,,,0.7539660590035575,0.2581208944320678,0.726424355150007,0.285553390683253,3554.0,0.7436233686688425,0.286973601178529,3581.0,6433.46320271492,6767.291547298431,6433.46320271492,327.4859962463379,5.375718593597412,0.0 -27600,0.056841128,0.24599425,,,,,,,,,,,,,, -27700,0.14731474,0.29246747,,,,,,,,,,,,,, -27800,0.041295975,0.3135308,,,,,,,,,,,,,, -27897,,,0.7541084289550781,0.2578811134610857,0.7266418422683948,0.2854553978296901,3554.0,0.7438512832483943,0.2867495385803546,3581.0,6513.643239974976,6851.556659221649,6513.643239974976,331.52803587913513,5.40687370300293,0.0 -27900,0.071668915,0.2708317,,,,,,,,,,,,,, -28000,0.08042549,0.23571128,,,,,,,,,,,,,, -28100,0.07732325,0.27516642,,,,,,,,,,,,,, -28200,0.07069039,0.25640264,,,,,,,,,,,,,, -28240,,,0.7530191285269601,0.2581344842910766,0.7249559392805641,0.286052113513163,3554.0,0.7422999233148911,0.2873368123429209,3581.0,6593.765914440155,6935.765740871429,6593.765914440155,335.5698335170746,5.439523696899414,0.0 -28300,0.065212175,0.24016947,,,,,,,,,,,,,, -28400,0.043036822,0.2376306,,,,,,,,,,,,,, -28500,0.076747745,0.29576528,,,,,,,,,,,,,, -28587,,,0.753343037196568,0.2573441437312534,0.725530020091798,0.2851177810213843,3554.0,0.7427766827090896,0.2864703892396502,3581.0,6673.728529691696,7019.814661979675,6673.728529691696,339.61222982406616,5.471314668655396,0.0 -28600,0.07538031,0.36182773,,,,,,,,,,,,,, -28700,0.0540183,0.26761103,,,,,,,,,,,,,, -28800,0.05373318,0.20790797,,,,,,,,,,,,,, -28900,0.056797463,0.34109384,,,,,,,,,,,,,, -28932,,,0.7535497801644462,0.2575488771711077,0.7259476832881964,0.2851559752215813,3554.0,0.7432078319167132,0.2864604013587336,3581.0,6753.82816696167,7104.003720521927,6753.82816696167,343.6575689315796,5.5031914710998535,0.0 -29000,0.046456683,0.26203832,,,,,,,,,,,,,, -29100,0.057277214,0.2550917,,,,,,,,,,,,,, -29200,0.069108136,0.33265933,,,,,,,,,,,,,, -29273,,,0.7527131353105817,0.2578637940543039,0.7251399034318725,0.285449541614642,3554.0,0.7424719330319743,0.2868101817208007,3581.0,6833.869222640991,7188.128715515137,6833.869222640991,347.6977567672729,5.534913063049316,0.0 -29300,0.1595066,0.35417902,,,,,,,,,,,,,, -29400,0.11258298,0.18901093,,,,,,,,,,,,,, -29500,0.06280786,0.22508073,,,,,,,,,,,,,, -29600,0.071511626,0.20614806,,,,,,,,,,,,,, -29619,,,0.7528404508318219,0.2572024720055716,0.7247044483328644,0.2851512181202079,3554.0,0.742045556190659,0.2864549813141406,3581.0,6913.852163553238,7272.19589304924,6913.852163553238,351.7384581565857,5.566555500030518,0.0 -29700,0.03593228,0.3104535,,,,,,,,,,,,,, -29800,0.049764793,0.35871723,,,,,,,,,,,,,, -29900,0.06219408,0.2641418,,,,,,,,,,,,,, -29964,,,0.7542448725019183,0.2572521822793143,0.7264857681265827,0.2851106367824986,3554.0,0.7436717740985409,0.2864696052080424,3581.0,6993.946165800095,7356.371992826462,6993.946165800095,355.77624320983887,5.598665475845337,0.0 -30000,0.04375236,0.32270718,,,,,,,,,,,,,, -30100,0.059315734,0.22626856,,,,,,,,,,,,,, -30200,0.074118674,0.28204364,,,,,,,,,,,,,, -30300,0.042651877,0.29807457,,,,,,,,,,,,,, -30308,,,0.7539666720799038,0.2571421350751604,0.726212775767621,0.2849134832670934,3554.0,0.7434445412856046,0.2862640184851473,3581.0,7074.071465730667,7440.581089496613,7074.071465730667,359.8158543109894,5.6307532787323,0.0 -30400,0.09918418,0.28508162,,,,,,,,,,,,,, -30500,0.06329855,0.22965018,,,,,,,,,,,,,, -30600,0.037580874,0.28154296,,,,,,,,,,,,,, -30653,,,0.7541763441903251,0.2566960198538644,0.7259604604846651,0.2849078674831616,3554.0,0.7432008778972354,0.2862637798668319,3581.0,7154.043691635132,7524.636759996414,7154.043691635132,363.8551321029663,5.662577629089356,0.0 -30700,0.036762685,0.29156208,,,,,,,,,,,,,, -30800,0.034583647,0.2837308,,,,,,,,,,,,,, -30900,0.055349965,0.2173048,,,,,,,,,,,,,, -30998,,,0.7537750516619001,0.2569471086774553,0.7257893422244303,0.2848584417151009,3554.0,0.7430420262758308,0.2861960122652541,3581.0,7234.144478082657,7608.822690963745,7234.144478082657,367.8951904773712,5.695420980453491,0.0 -31000,0.03473568,0.25880724,,,,,,,,,,,,,, -31100,0.036318038,0.21340586,,,,,,,,,,,,,, -31200,0.039108023,0.2394382,,,,,,,,,,,,,, -31300,0.027126234,0.258321,,,,,,,,,,,,,, -31341,,,0.7542284556797573,0.2569677489144461,0.7263002926939716,0.2848178260300893,3554.0,0.7435245806862608,0.2861467205389556,3581.0,7314.175801992416,7692.941826343536,7314.175801992416,371.93916368484497,5.727295398712158,0.0 -31400,0.03875993,0.29959437,,,,,,,,,,,,,, -31500,0.029610658,0.31057644,,,,,,,,,,,,,, -31600,0.02389356,0.2274292,,,,,,,,,,,,,, -31686,,,0.754953111921038,0.2564511299133301,0.7265166806986846,0.2848102696235755,3554.0,0.743722633888055,0.2861874220058992,3581.0,7394.33158826828,7777.184247970581,7394.33158826828,375.9801406860352,5.7609148025512695,0.0 -31700,0.03280433,0.25988537,,,,,,,,,,,,,, -31800,0.06924339,0.27758378,,,,,,,,,,,,,, -31900,0.03911881,0.34083012,,,,,,,,,,,,,, -32000,0.024830539,0.17149603,,,,,,,,,,,,,, -32034,,,0.7544609478541783,0.2566853761672973,0.7263270148951885,0.2847692245972847,3554.0,0.743590030281346,0.2861108596149818,3581.0,7474.308657169342,7861.260461330414,7474.308657169342,380.0295407772064,5.798375606536865,0.0 -32100,0.020396229,0.27013794,,,,,,,,,,,,,, -32200,0.045133084,0.2671113,,,,,,,,,,,,,, -32300,0.018596718,0.27283978,,,,,,,,,,,,,, -32378,,,0.7543185779026577,0.2567593199866159,0.7262754252470808,0.284801098893852,3554.0,0.7435266259861072,0.2861430049109013,3581.0,7554.259472131729,7945.309270620346,7554.259472131729,384.07799434661865,5.836031198501587,0.0 -32400,0.044000976,0.24630001,,,,,,,,,,,,,, -32500,0.020057926,0.22943583,,,,,,,,,,,,,, -32600,0.03285827,0.33486226,,,,,,,,,,,,,, -32700,0.031030301,0.27565792,,,,,,,,,,,,,, -32721,,,0.7552104677472796,0.2563129663467407,0.7266483682558385,0.2848670972352894,3554.0,0.7438551011414409,0.2862320777192125,3581.0,7634.3465123176575,8029.485631227493,7634.3465123176575,388.12148785591125,5.869813680648804,0.0 -32800,0.048695482,0.2554034,,,,,,,,,,,,,, -32900,0.029352399,0.21070793,,,,,,,,,,,,,, -33000,0.079093195,0.33061075,,,,,,,,,,,,,, -33067,,,0.7548089027404785,0.2565027305058071,0.72650617042417,0.284718768410154,3554.0,0.7437527679724588,0.2860880545216943,3581.0,7714.412069559097,8113.636897563934,7714.412069559097,392.1617946624756,5.903012752532959,0.0 -33100,0.022363076,0.22419053,,,,,,,,,,,,,, -33200,0.0242529,0.35298252,,,,,,,,,,,,,, -33300,0.05258453,0.31535777,,,,,,,,,,,,,, -33400,0.025435295,0.2542861,,,,,,,,,,,,,, -33414,,,0.7544290678841727,0.2566244602203369,0.7261439437737408,0.2848105444019942,3554.0,0.743361842995148,0.2861565720665491,3581.0,7794.469235181808,8197.783804655075,7794.469235181808,396.2066998481751,5.935891389846802,0.0 -33500,0.026649047,0.26451033,,,,,,,,,,,,,, -33600,0.02425082,0.29537523,,,,,,,,,,,,,, -33700,0.019225515,0.25587288,,,,,,,,,,,,,, -33758,,,0.7553148950849261,0.2563022545405796,0.7268665423202729,0.2847635572923994,3554.0,0.7440031808503211,0.2861514929052638,3581.0,7874.45870757103,8281.86214709282,7874.45870757103,400.2509334087372,5.968505382537842,0.0 -33800,0.019913724,0.24071324,,,,,,,,,,,,,, -33900,0.018066565,0.28395107,,,,,,,,,,,,,, -34000,0.015476938,0.27860287,,,,,,,,,,,,,, -34100,0.03885747,0.25777304,,,,,,,,,,,,,, -34102,,,0.7551569257463727,0.2563475711005075,0.7268275924794246,0.2846378805131542,3554.0,0.7439999083705668,0.2860142191972389,3581.0,7954.426274776459,8365.916305065155,7954.426274776459,404.2913186550141,6.002583742141724,0.0 -34200,0.021147981,0.2667166,,,,,,,,,,,,,, -34300,0.027396759,0.22487259,,,,,,,,,,,,,, -34400,0.024769353,0.3325082,,,,,,,,,,,,,, -34447,,,0.7550387382507324,0.2563768284661429,0.7268162578696539,0.2846436508599465,3554.0,0.7439937724710276,0.2859932889621439,3581.0,8034.5121150016785,8450.093579769135,8034.5121150016785,408.3378036022186,6.035517454147339,0.0 -34500,0.019667452,0.22095767,,,,,,,,,,,,,, -34600,0.01816976,0.25742483,,,,,,,,,,,,,, -34700,0.026467936,0.24448124,,,,,,,,,,,,,, -34791,,,0.7547637394496373,0.2562845945358276,0.726274326133406,0.2846722278154895,3554.0,0.7435064456942893,0.286050932329482,3581.0,8114.659774541855,8534.329248189926,8114.659774541855,412.38020157814026,6.069193363189697,0.0 -34800,0.01771603,0.227658,,,,,,,,,,,,,, -34900,0.018547913,0.29812324,,,,,,,,,,,,,, -35000,0.022560304,0.24880792,,,,,,,,,,,,,, -35100,0.017197585,0.3013339,,,,,,,,,,,,,, -35135,,,0.7552019527980259,0.2562252964292253,0.726717131555114,0.2846202431734049,3554.0,0.7439281865226194,0.2860049130829377,3581.0,8194.64580988884,8618.405655145645,8194.64580988884,416.4238030910492,6.104076862335205,0.0 -35200,0.023494443,0.23139125,,,,,,,,,,,,,, -35300,0.012260757,0.2790005,,,,,,,,,,,,,, -35400,0.015578702,0.32163388,,,,,,,,,,,,,, -35479,,,0.755331311907087,0.2562356335776193,0.726912018148565,0.2846041171149585,3554.0,0.7441111726822117,0.2859727677870183,3581.0,8274.605695009232,8702.45828127861,8274.605695009232,420.465188741684,6.143252611160278,0.0 -35500,0.022437386,0.31559882,,,,,,,,,,,,,, -35600,0.025732191,0.21167338,,,,,,,,,,,,,, -35700,0.017396038,0.21115135,,,,,,,,,,,,,, -35800,0.025833188,0.24394003,,,,,,,,,,,,,, -35821,,,0.7551850591387067,0.2562278509140014,0.7267806053698298,0.2846070022883546,3554.0,0.7439823869685492,0.2859783923615959,3581.0,8354.700634717941,8786.64630651474,8354.700634717941,424.51155495643616,6.178053855895996,0.0 -35900,0.018384526,0.22018892,,,,,,,,,,,,,, -36000,0.018143317,0.23600495,,,,,,,,,,,,,, -36100,0.018603671,0.37263677,,,,,,,,,,,,,, -36165,,,0.7552205494471959,0.256206955228533,0.7267984659670442,0.2845921470800946,3554.0,0.7439969767741204,0.2859633934960556,3581.0,8434.776199579239,8870.805983543396,8434.776199579239,428.550342798233,6.211203336715698,0.0 -36189,,,0.7552202088492257,0.256206887108939,0.7267981224940209,0.2845921299064434,3554.0,0.7439967040674742,0.2859633594077247,3581.0,8438.036951065063,8878.14504647255,8438.036951065063,432.5937414169312,6.244928598403931,0.0 -36189,,,,,,,,,,,8438.036951065063,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/eval_measurements.csv deleted file mode 100644 index abc9f123b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -201.9382359981537,0.0,59.85227084159851,1,0,59.85227084159851,1.0479628268116448,3581,0.2770697139089116,261.79093384742737,1.0446930612836565,0.2627705335617065,1.0516581366286226,3554,0.2540589237588808 -206.35710859298703,0.0266585350036621,139.9840567111969,335,0,139.9840567111969,0.3303551267758657,3581,0.6987030617320581,346.37930965423584,0.3110035146985735,0.701645987374442,0.3285286494970456,3554,0.6807306962665307 -210.4021155834198,0.0584521293640136,220.16344499588013,575,0,220.16344499588013,0.3153754188774085,3581,0.7129093054052639,430.6433699131012,0.2962665898459298,0.7159358433314732,0.3132285062825338,3554,0.6952521174424944 -214.4450180530548,0.0861587524414062,300.4242124557495,819,0,300.4242124557495,0.3100092679353707,3581,0.7195823006841664,514.9826629161835,0.2905894688197544,0.7233612196786063,0.3076199009368845,3554,0.7024809193866066 -218.42980575561523,0.1170845031738281,380.4962332248688,1101,0,380.4962332248688,0.3074294630624302,3581,0.7217647037969491,599.080043554306,0.288079023361206,0.7253414562770298,0.3050946529218838,3554,0.7046824440770962 -222.472510099411,0.139963150024414,460.5888922214508,1448,0,460.5888922214508,0.3039168651062028,3581,0.7258353959046007,683.2501714229584,0.2847149882997785,0.7296791076660156,0.3015287503407252,3554,0.7086637085853967 -226.5124213695526,0.1640286445617675,540.682363986969,1795,0,540.682363986969,0.3018747695628839,3581,0.72690447413432,767.4192833900452,0.2828859601702009,0.7306084632873535,0.2997184757711381,3554,0.7098449123127111 -230.55322122573853,0.1883833408355713,620.6480689048767,2138,0,620.6480689048767,0.2987084749589849,3581,0.7305610611997347,851.461841583252,0.2797163895198277,0.7342242513384137,0.2965493532265581,3554,0.7135088764420371 -234.595507144928,0.2120945453643798,700.6400139331818,2483,0,700.6400139331818,0.2980554107180257,3581,0.7298840669505725,935.5314061641692,0.2790699345724923,0.7338643074035645,0.2963122537985369,3554,0.7125804001653067 -238.6415295600891,0.2383944988250732,780.7869019508362,2827,0,780.7869019508362,0.2993399953748953,3581,0.7299614474614283,1019.762139558792,0.2799262319292341,0.7344011579241071,0.2976572941579909,3554,0.7127563270478686 -242.68945455551147,0.2622091770172119,860.8440716266632,3172,0,860.8440716266632,0.2957288821427324,3581,0.7324209205267384,1103.9026482105255,0.2764573097229004,0.736753123147147,0.2941534571521525,3554,0.7148686174468908 -246.73223423957825,0.2885496616363525,940.8284964561462,3518,0,940.8284964561462,0.2952064784736282,3581,0.7318410780202806,1187.9676916599274,0.2760745627539498,0.7362453596932548,0.2935877227353862,3554,0.7144251250791361 -250.771879196167,0.3149302005767822,1020.884655714035,3864,0,1020.884655714035,0.2939418354902611,3581,0.7339011040255864,1272.1013383865356,0.2750032118388584,0.7379326139177594,0.292320410321117,3554,0.7164150703872397 -254.8140749931336,0.3396608829498291,1100.9783499240875,4210,0,1100.9783499240875,0.2939124513491343,3581,0.7353148834560876,1356.273472070694,0.2749361651284354,0.7395276342119489,0.2922343016341622,3554,0.7178736629282146 -258.85726594924927,0.3627669811248779,1181.0696263313291,4557,0,1181.0696263313291,0.2929939071881108,3581,0.7339554408248394,1440.442661523819,0.2739115953445434,0.7383965764726911,0.2914530378952413,3554,0.7164452273186902 -262.89908838272095,0.3872425556182861,1261.114327430725,4903,0,1261.114327430725,0.2929291734479719,3581,0.7338327228340548,1524.5650746822355,0.2737909896033151,0.7381105422973633,0.2913709134953573,3554,0.716312990204699 -266.94445538520813,0.4113929271697998,1341.1720926761627,5250,0,1341.1720926761627,0.2939848890520281,3581,0.7355187316741134,1608.7040948867798,0.2748546430042812,0.7396706853594098,0.2925264941351294,3554,0.7181709044826252 -270.99406599998474,0.4351603984832763,1421.244796037674,5595,0,1421.244796037674,0.2929919982415875,3581,0.734532283558189,1692.861694574356,0.2739379235676357,0.7387257984706334,0.2916653729182787,3554,0.7167135484445343 -275.0368390083313,0.4607465267181396,1501.4044897556305,5942,0,1501.4044897556305,0.2930430284727555,3581,0.7336387602319534,1777.1012864112854,0.2742792197636196,0.7375413349696568,0.2914842252457618,3554,0.7160116269740081 -279.081964969635,0.4848289489746094,1581.499148607254,6286,0,1581.499148607254,0.292247918157463,3581,0.7351854841524714,1861.2766053676603,0.2733951125826154,0.7392627171107701,0.2906868182747432,3554,0.7179266264684159 -283.1269326210022,0.5097403526306152,1661.6680040359497,6631,0,1661.6680040359497,0.2924636632029286,3581,0.7369198984222284,1945.527051448822,0.2734493698392595,0.7411042622157505,0.2909812433503623,3554,0.719641655968627 -287.1779320240021,0.5349032878875732,1741.8491139411926,6980,0,1741.8491139411926,0.2917036979566636,3581,0.7364626375532324,2029.796061038971,0.2729030166353498,0.7404102597917829,0.2903095475058913,3554,0.7189428257553109 -291.2229428291321,0.5588953495025635,1821.821462869644,7326,0,1821.821462869644,0.2917647160687482,3581,0.7352209360164759,2113.848841905594,0.2726343018668039,0.7397974559238979,0.2902227518728897,3554,0.717868167359841 -295.2652988433838,0.584242582321167,1901.7912282943728,7669,0,1901.7912282943728,0.2914618753381562,3581,0.7363083537681514,2197.8975853919983,0.272307276725769,0.7407921382359096,0.2899680322787704,3554,0.7189505882456387 -299.3083698749542,0.6124632358551025,1981.8844645023344,8018,0,1981.8844645023344,0.2915348925426731,3581,0.7379339580860793,2282.073616743088,0.2724392924989973,0.7422355924333844,0.2900204806094365,3554,0.7206700142005487 -303.35083651542664,0.6374008655548096,2062.057511568069,8365,0,2062.057511568069,0.2910017510493752,3581,0.7352480021511101,2366.3256754875183,0.2721735579626901,0.7394072668892997,0.2894907421655177,3554,0.7176794632808103 -307.3949348926544,0.6621432304382324,2142.0391342639923,8708,0,2142.0391342639923,0.2914172196248429,3581,0.7353687430187098,2450.387734413147,0.272202934537615,0.7399611473083496,0.2900589839353545,3554,0.7177564699326463 -311.43977880477905,0.6873118877410889,2222.165897130966,9057,0,2222.165897130966,0.2920888620060737,3581,0.7353536759765079,2534.596352338791,0.2729321547916957,0.7399208886282784,0.2904726971919844,3554,0.7181781861107203 -315.4843747615814,0.712233304977417,2302.3195946216583,9404,0,2302.3195946216583,0.2915744009180396,3581,0.738042086271293,2618.8310911655426,0.2726039886474609,0.7424005780901227,0.2899934492824986,3554,0.7208437428557611 -319.5306646823883,0.7383487224578857,2382.343332052231,9749,0,2382.343332052231,0.2910048530874756,3581,0.7373286856848645,2702.9389379024506,0.2721874713897705,0.7411816460745675,0.2894675577364413,3554,0.7201497212647721 -323.5765333175659,0.7633152008056641,2462.336722135544,10096,0,2462.336722135544,0.2911594436675335,3581,0.7363095809480592,2787.014926671982,0.2717173780713762,0.7412514005388532,0.2895712522421919,3554,0.7188823058085959 -327.6230306625366,0.7884683609008789,2542.3553891181946,10445,0,2542.3553891181946,0.291044531904496,3581,0.7377865601438146,2871.116565465927,0.2717570236751011,0.7422726494925362,0.2895132396485474,3554,0.7204666094761185 -331.6721098423004,0.8161625862121582,2622.469162940979,10792,0,2622.469162940979,0.2905383201925091,3581,0.7382145732250069,2955.3185493946075,0.2712354319436209,0.7428265299115863,0.2889475395790834,3554,0.7209389535778349 -335.7248303890228,0.8413472175598145,2702.5720574855804,11137,0,2702.5720574855804,0.2905359680976857,3581,0.7394820455398282,3039.51064157486,0.2709963662283761,0.7440916470118931,0.2889027163495357,3554,0.722346574722144 -339.7678611278534,0.8715095520019531,2782.628226995468,11486,0,2782.628226995468,0.2905799079560527,3581,0.7376785001352625,3123.651228427887,0.2712761334010533,0.7422352518354144,0.2890337856552476,3554,0.72035882764139 -343.8061776161194,0.8987765312194824,2862.6888043880463,11834,0,2862.6888043880463,0.2906691171176871,3581,0.7368190651398003,3207.7888326644897,0.2715799127306257,0.7413741520472935,0.2892523375400077,3554,0.7194349539031725 -347.8495783805847,0.925361156463623,2942.7759687900543,12182,0,2942.7759687900543,0.2910813473017313,3581,0.7373647511388229,3291.9573850631714,0.2713607719966343,0.7423573902675084,0.289545869585766,3554,0.7200796527680079 -351.8961908817291,0.95231032371521,3022.7916910648346,12532,0,3022.7916910648346,0.2916355894717781,3581,0.7378158761082798,3376.0582070350647,0.272442204611642,0.7423182215009417,0.2901950679472074,3554,0.7204868056898917 -355.9431412220001,0.9793925285339355,3102.942088365555,12878,0,3102.942088365555,0.2911252189834369,3581,0.7393421470303337,3460.2937593460083,0.2720416443688528,0.7436973026820591,0.2895514681960467,3554,0.722205544698755 -359.986421585083,1.0064899921417236,3183.06791138649,13224,0,3183.06791138649,0.2903072353981953,3581,0.7379479343016965,3544.5014159679413,0.2707764932087489,0.7428876331874302,0.2887189582820413,3554,0.7205708878860088 -364.0342311859131,1.0327346324920654,3263.227847099304,13571,0,3263.227847099304,0.2902935318892243,3581,0.737011936915317,3628.746710538864,0.2709683350154331,0.7417325292314801,0.2888601943892445,3554,0.7195579172455332 -368.07639956474304,1.0592927932739258,3343.3426282405853,13922,0,3343.3426282405853,0.2899280368066706,3581,0.7388716598889975,3712.941677093506,0.2707513741084507,0.7432278905596051,0.288457987478897,3554,0.7214930442591094 -372.120733499527,1.0885093212127686,3423.356700658798,14270,0,3423.356700658798,0.2913003648269512,3581,0.7358189816915666,3797.0408568382254,0.2719397204262869,0.7404216357639858,0.2898076647241664,3554,0.7183479304788618 -376.1667695045471,1.1156907081604004,3503.464207649231,14619,0,3503.464207649231,0.2898336121304279,3581,0.7387854164121405,3881.233004808426,0.2705361843109131,0.7435727800641742,0.2882630665381436,3554,0.7214944868458075 -380.2160503864288,1.1415808200836182,3583.558312177658,14967,0,3583.558312177658,0.2905270710433538,3581,0.7385827953740226,3965.413384199144,0.2712163243974958,0.743222849709647,0.2889193060965637,3554,0.721397970926245 -384.2569673061371,1.1672825813293457,3663.6104049682617,15312,0,3663.6104049682617,0.2898316009189123,3581,0.7381061041564856,4049.543503046036,0.2704992464610508,0.7425961494445801,0.2884882817995568,3554,0.7206020065419246 -388.302414894104,1.1935670375823977,3743.7452001571655,15660,0,3743.7452001571655,0.2895370095643675,3581,0.7381669859152471,4133.761077642441,0.2702544757298061,0.7428229195731026,0.2880630621966446,3554,0.7206478945378447 -392.3407049179077,1.2211003303527832,3823.7247598171234,16009,0,3823.7247598171234,0.2894750028906904,3581,0.7380999000802848,4217.8178243637085,0.2702045440673828,0.7428116798400879,0.288014769889561,3554,0.7207355488534046 -396.38701009750366,1.2483484745025637,3903.7739090919495,16355,0,3903.7739090919495,0.2903500503416469,3581,0.739279492678372,4301.95210313797,0.270837596484593,0.7441465514046806,0.288835155205842,3554,0.7220722084710889 -400.4280893802643,1.2757651805877686,3983.775414466858,16704,0,3983.775414466858,0.289581324394373,3581,0.738717307927255,4386.033557653427,0.2703778062547956,0.7429208755493164,0.2881776104499332,3554,0.7212277456958709 -404.47349190711975,1.3030242919921875,4063.8064637184143,17053,0,4063.8064637184143,0.2893098108397619,3581,0.7390477602057736,4470.148884057999,0.2700082744870867,0.7437002318246024,0.2878565318677019,3554,0.7216908160259566 -408.5127048492432,1.3316102027893066,4143.867969751358,17396,0,4143.867969751358,0.2902482966742879,3581,0.7386054982023178,4554.289497852325,0.2703984975814819,0.7439383098057338,0.2886404403489026,3554,0.721435546875 -412.55626583099365,1.3584327697753906,4223.834111213684,17744,0,4223.834111213684,0.2896581935802674,3581,0.7370970213889276,4638.337499380112,0.2704047645841326,0.7415142740522113,0.2882171785422235,3554,0.7196797814742192 -416.6016986370087,1.3848366737365725,4303.8344876766205,18092,0,4303.8344876766205,0.2898229424828958,3581,0.7402336932333845,4722.421092510223,0.2705769879477365,0.7446156229291644,0.2883928306463668,3554,0.7228888499314153 -420.6447505950928,1.4133193492889404,4383.909608125687,18436,0,4383.909608125687,0.2899425925239109,3581,0.7396001956942893,4806.579029083252,0.2705787760870797,0.7440378325326102,0.2884532475511747,3554,0.7222482727428602 -424.69260358810425,1.4410135746002195,4463.937946796417,18784,0,4463.937946796417,0.2910944372207484,3581,0.7384926658274574,4890.694290399551,0.2712665796279907,0.7434712137494769,0.2895970470662458,3554,0.7211822698675788 -428.74211287498474,1.4691977500915527,4544.043174743652,19133,0,4544.043174743652,0.2894481753743717,3581,0.7413472908230941,4974.888587236404,0.2698937143598284,0.7459725652422223,0.28780652219550157,3554,0.7241422516882386 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/measurements.csv deleted file mode 100644 index d69feee60..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/measurements.csv +++ /dev/null @@ -1,251 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.1439998,1.060551,,,,,,,,,,,,,, -1,,,0.2627705335617065,1.0446930612836565,0.2540589237588808,1.0516581366286226,3554.0,0.2770697139089116,1.0479628268116448,3581.0,59.85227084159851,261.79093384742737,59.85227084159851,201.9382359981537,0.0,0.0 -100,0.23321553,0.30819118,,,,,,,,,,,,,, -200,0.11195041,0.3158505,,,,,,,,,,,,,, -300,0.13359508,0.37293983,,,,,,,,,,,,,, -335,,,0.701645987374442,0.3110035146985735,0.6807306962665307,0.3285286494970456,3554.0,0.6987030617320581,0.3303551267758657,3581.0,139.9840567111969,346.37930965423584,139.9840567111969,206.35710859298703,0.0266585350036621,0.0 -400,0.16403794,0.3417199,,,,,,,,,,,,,, -500,0.35904095,0.2436567,,,,,,,,,,,,,, -575,,,0.7159358433314732,0.2962665898459298,0.6952521174424944,0.3132285062825338,3554.0,0.7129093054052639,0.3153754188774085,3581.0,220.16344499588013,430.6433699131012,220.16344499588013,210.4021155834198,0.0584521293640136,0.0 -600,0.22918318,0.4484424,,,,,,,,,,,,,, -700,0.15962915,0.29093578,,,,,,,,,,,,,, -800,0.15682052,0.25154912,,,,,,,,,,,,,, -819,,,0.7233612196786063,0.2905894688197544,0.7024809193866066,0.3076199009368845,3554.0,0.7195823006841664,0.3100092679353707,3581.0,300.4242124557495,514.9826629161835,300.4242124557495,214.4450180530548,0.0861587524414062,0.0 -900,0.18187135,0.30439633,,,,,,,,,,,,,, -1000,0.2045504,0.30810946,,,,,,,,,,,,,, -1100,0.11339661,0.22276036,,,,,,,,,,,,,, -1101,,,0.7253414562770298,0.288079023361206,0.7046824440770962,0.3050946529218838,3554.0,0.7217647037969491,0.3074294630624302,3581.0,380.4962332248688,599.080043554306,380.4962332248688,218.42980575561523,0.1170845031738281,0.0 -1200,0.10154707,0.22867537,,,,,,,,,,,,,, -1300,0.1795263,0.28138018,,,,,,,,,,,,,, -1400,0.10051033,0.325538,,,,,,,,,,,,,, -1448,,,0.7296791076660156,0.2847149882997785,0.7086637085853967,0.3015287503407252,3554.0,0.7258353959046007,0.3039168651062028,3581.0,460.5888922214508,683.2501714229584,460.5888922214508,222.472510099411,0.139963150024414,0.0 -1500,0.12799148,0.3558095,,,,,,,,,,,,,, -1600,0.4833647,0.26929894,,,,,,,,,,,,,, -1700,0.11267589,0.3438891,,,,,,,,,,,,,, -1795,,,0.7306084632873535,0.2828859601702009,0.7098449123127111,0.2997184757711381,3554.0,0.72690447413432,0.3018747695628839,3581.0,540.682363986969,767.4192833900452,540.682363986969,226.5124213695526,0.1640286445617675,0.0 -1800,0.07430064,0.31205225,,,,,,,,,,,,,, -1900,0.11098793,0.28589976,,,,,,,,,,,,,, -2000,0.1468534,0.32428974,,,,,,,,,,,,,, -2100,0.16765969,0.2711564,,,,,,,,,,,,,, -2138,,,0.7342242513384137,0.2797163895198277,0.7135088764420371,0.2965493532265581,3554.0,0.7305610611997347,0.2987084749589849,3581.0,620.6480689048767,851.461841583252,620.6480689048767,230.55322122573853,0.1883833408355713,0.0 -2200,0.3273626,0.27088073,,,,,,,,,,,,,, -2300,0.1850665,0.26120925,,,,,,,,,,,,,, -2400,0.10146901,0.28668258,,,,,,,,,,,,,, -2483,,,0.7338643074035645,0.2790699345724923,0.7125804001653067,0.2963122537985369,3554.0,0.7298840669505725,0.2980554107180257,3581.0,700.6400139331818,935.5314061641692,700.6400139331818,234.595507144928,0.2120945453643798,0.0 -2500,0.45730257,0.25804973,,,,,,,,,,,,,, -2600,0.26106492,0.29555973,,,,,,,,,,,,,, -2700,0.16448455,0.28302655,,,,,,,,,,,,,, -2800,0.18751283,0.36252093,,,,,,,,,,,,,, -2827,,,0.7344011579241071,0.2799262319292341,0.7127563270478686,0.2976572941579909,3554.0,0.7299614474614283,0.2993399953748953,3581.0,780.7869019508362,1019.762139558792,780.7869019508362,238.6415295600891,0.2383944988250732,0.0 -2900,0.27136546,0.29439998,,,,,,,,,,,,,, -3000,0.11570551,0.29549962,,,,,,,,,,,,,, -3100,0.116184644,0.25374565,,,,,,,,,,,,,, -3172,,,0.736753123147147,0.2764573097229004,0.7148686174468908,0.2941534571521525,3554.0,0.7324209205267384,0.2957288821427324,3581.0,860.8440716266632,1103.9026482105255,860.8440716266632,242.68945455551147,0.2622091770172119,0.0 -3200,0.06710736,0.23836656,,,,,,,,,,,,,, -3300,0.14380574,0.24846716,,,,,,,,,,,,,, -3400,0.08686723,0.30059132,,,,,,,,,,,,,, -3500,0.0847157,0.37954518,,,,,,,,,,,,,, -3518,,,0.7362453596932548,0.2760745627539498,0.7144251250791361,0.2935877227353862,3554.0,0.7318410780202806,0.2952064784736282,3581.0,940.8284964561462,1187.9676916599274,940.8284964561462,246.73223423957825,0.2885496616363525,0.0 -3600,0.21599619,0.31775892,,,,,,,,,,,,,, -3700,0.12707807,0.2818855,,,,,,,,,,,,,, -3800,0.10849389,0.312777,,,,,,,,,,,,,, -3864,,,0.7379326139177594,0.2750032118388584,0.7164150703872397,0.292320410321117,3554.0,0.7339011040255864,0.2939418354902611,3581.0,1020.884655714035,1272.1013383865356,1020.884655714035,250.771879196167,0.3149302005767822,0.0 -3900,0.134618,0.32027742,,,,,,,,,,,,,, -4000,0.34981877,0.32063198,,,,,,,,,,,,,, -4100,0.064835116,0.24862656,,,,,,,,,,,,,, -4200,0.2520938,0.3136771,,,,,,,,,,,,,, -4210,,,0.7395276342119489,0.2749361651284354,0.7178736629282146,0.2922343016341622,3554.0,0.7353148834560876,0.2939124513491343,3581.0,1100.9783499240875,1356.273472070694,1100.9783499240875,254.8140749931336,0.3396608829498291,0.0 -4300,0.33932775,0.28513783,,,,,,,,,,,,,, -4400,0.29674837,0.25459912,,,,,,,,,,,,,, -4500,0.27361515,0.25448486,,,,,,,,,,,,,, -4557,,,0.7383965764726911,0.2739115953445434,0.7164452273186902,0.2914530378952413,3554.0,0.7339554408248394,0.2929939071881108,3581.0,1181.0696263313291,1440.442661523819,1181.0696263313291,258.85726594924927,0.3627669811248779,0.0 -4600,0.09162731,0.2601068,,,,,,,,,,,,,, -4700,0.19313976,0.2633643,,,,,,,,,,,,,, -4800,0.10283897,0.24066274,,,,,,,,,,,,,, -4900,0.1425236,0.36630383,,,,,,,,,,,,,, -4903,,,0.7381105422973633,0.2737909896033151,0.716312990204699,0.2913709134953573,3554.0,0.7338327228340548,0.2929291734479719,3581.0,1261.114327430725,1524.5650746822355,1261.114327430725,262.89908838272095,0.3872425556182861,0.0 -5000,0.091771156,0.29311067,,,,,,,,,,,,,, -5100,0.10604722,0.29929593,,,,,,,,,,,,,, -5200,0.11475777,0.31058818,,,,,,,,,,,,,, -5250,,,0.7396706853594098,0.2748546430042812,0.7181709044826252,0.2925264941351294,3554.0,0.7355187316741134,0.2939848890520281,3581.0,1341.1720926761627,1608.7040948867798,1341.1720926761627,266.94445538520813,0.4113929271697998,0.0 -5300,0.15027986,0.2733723,,,,,,,,,,,,,, -5400,0.18772124,0.29706645,,,,,,,,,,,,,, -5500,0.3862569,0.29449475,,,,,,,,,,,,,, -5595,,,0.7387257984706334,0.2739379235676357,0.7167135484445343,0.2916653729182787,3554.0,0.734532283558189,0.2929919982415875,3581.0,1421.244796037674,1692.861694574356,1421.244796037674,270.99406599998474,0.4351603984832763,0.0 -5600,0.13194713,0.23527743,,,,,,,,,,,,,, -5700,0.06332706,0.43102854,,,,,,,,,,,,,, -5800,0.1874859,0.2966962,,,,,,,,,,,,,, -5900,0.18979852,0.26670957,,,,,,,,,,,,,, -5942,,,0.7375413349696568,0.2742792197636196,0.7160116269740081,0.2914842252457618,3554.0,0.7336387602319534,0.2930430284727555,3581.0,1501.4044897556305,1777.1012864112854,1501.4044897556305,275.0368390083313,0.4607465267181396,0.0 -6000,0.43335584,0.28619003,,,,,,,,,,,,,, -6100,0.19488063,0.27641276,,,,,,,,,,,,,, -6200,0.18805186,0.29880327,,,,,,,,,,,,,, -6286,,,0.7392627171107701,0.2733951125826154,0.7179266264684159,0.2906868182747432,3554.0,0.7351854841524714,0.292247918157463,3581.0,1581.499148607254,1861.2766053676603,1581.499148607254,279.081964969635,0.4848289489746094,0.0 -6300,0.22040027,0.28667253,,,,,,,,,,,,,, -6400,0.16853455,0.2887952,,,,,,,,,,,,,, -6500,0.14094588,0.35136518,,,,,,,,,,,,,, -6600,0.26334694,0.24784112,,,,,,,,,,,,,, -6631,,,0.7411042622157505,0.2734493698392595,0.719641655968627,0.2909812433503623,3554.0,0.7369198984222284,0.2924636632029286,3581.0,1661.6680040359497,1945.527051448822,1661.6680040359497,283.1269326210022,0.5097403526306152,0.0 -6700,0.14318213,0.27322108,,,,,,,,,,,,,, -6800,0.18376248,0.23325312,,,,,,,,,,,,,, -6900,0.51076305,0.28838676,,,,,,,,,,,,,, -6980,,,0.7404102597917829,0.2729030166353498,0.7189428257553109,0.2903095475058913,3554.0,0.7364626375532324,0.2917036979566636,3581.0,1741.8491139411926,2029.796061038971,1741.8491139411926,287.1779320240021,0.5349032878875732,0.0 -7000,0.2611605,0.28916746,,,,,,,,,,,,,, -7100,0.117949866,0.32961535,,,,,,,,,,,,,, -7200,0.09010268,0.33101368,,,,,,,,,,,,,, -7300,0.24315032,0.25645098,,,,,,,,,,,,,, -7326,,,0.7397974559238979,0.2726343018668039,0.717868167359841,0.2902227518728897,3554.0,0.7352209360164759,0.2917647160687482,3581.0,1821.821462869644,2113.848841905594,1821.821462869644,291.2229428291321,0.5588953495025635,0.0 -7400,0.1490071,0.2658307,,,,,,,,,,,,,, -7500,0.11833826,0.22705126,,,,,,,,,,,,,, -7600,0.10518404,0.25979853,,,,,,,,,,,,,, -7669,,,0.7407921382359096,0.272307276725769,0.7189505882456387,0.2899680322787704,3554.0,0.7363083537681514,0.2914618753381562,3581.0,1901.7912282943728,2197.8975853919983,1901.7912282943728,295.2652988433838,0.584242582321167,0.0 -7700,0.27575225,0.2021697,,,,,,,,,,,,,, -7800,0.074518524,0.30229148,,,,,,,,,,,,,, -7900,0.13443139,0.32750237,,,,,,,,,,,,,, -8000,0.107581586,0.27920735,,,,,,,,,,,,,, -8018,,,0.7422355924333844,0.2724392924989973,0.7206700142005487,0.2900204806094365,3554.0,0.7379339580860793,0.2915348925426731,3581.0,1981.8844645023344,2282.073616743088,1981.8844645023344,299.3083698749542,0.6124632358551025,0.0 -8100,0.2309425,0.21707955,,,,,,,,,,,,,, -8200,0.07589322,0.27940446,,,,,,,,,,,,,, -8300,0.17318653,0.20808116,,,,,,,,,,,,,, -8365,,,0.7394072668892997,0.2721735579626901,0.7176794632808103,0.2894907421655177,3554.0,0.7352480021511101,0.2910017510493752,3581.0,2062.057511568069,2366.3256754875183,2062.057511568069,303.35083651542664,0.6374008655548096,0.0 -8400,0.17060995,0.32332763,,,,,,,,,,,,,, -8500,0.15793419,0.3354493,,,,,,,,,,,,,, -8600,0.095741585,0.22838096,,,,,,,,,,,,,, -8700,0.2168528,0.27177748,,,,,,,,,,,,,, -8708,,,0.7399611473083496,0.272202934537615,0.7177564699326463,0.2900589839353545,3554.0,0.7353687430187098,0.2914172196248429,3581.0,2142.0391342639923,2450.387734413147,2142.0391342639923,307.3949348926544,0.6621432304382324,0.0 -8800,0.33591,0.27107045,,,,,,,,,,,,,, -8900,0.17825066,0.23099373,,,,,,,,,,,,,, -9000,0.1702957,0.33005223,,,,,,,,,,,,,, -9057,,,0.7399208886282784,0.2729321547916957,0.7181781861107203,0.2904726971919844,3554.0,0.7353536759765079,0.2920888620060737,3581.0,2222.165897130966,2534.596352338791,2222.165897130966,311.43977880477905,0.6873118877410889,0.0 -9100,0.17653413,0.28463253,,,,,,,,,,,,,, -9200,0.2786405,0.19728918,,,,,,,,,,,,,, -9300,0.06737895,0.41065827,,,,,,,,,,,,,, -9400,0.06989933,0.22325742,,,,,,,,,,,,,, -9404,,,0.7424005780901227,0.2726039886474609,0.7208437428557611,0.2899934492824986,3554.0,0.738042086271293,0.2915744009180396,3581.0,2302.3195946216583,2618.8310911655426,2302.3195946216583,315.4843747615814,0.712233304977417,0.0 -9500,0.2411969,0.2999858,,,,,,,,,,,,,, -9600,0.12542312,0.2802693,,,,,,,,,,,,,, -9700,0.08432868,0.31673744,,,,,,,,,,,,,, -9749,,,0.7411816460745675,0.2721874713897705,0.7201497212647721,0.2894675577364413,3554.0,0.7373286856848645,0.2910048530874756,3581.0,2382.343332052231,2702.9389379024506,2382.343332052231,319.5306646823883,0.7383487224578857,0.0 -9800,0.34757957,0.26702675,,,,,,,,,,,,,, -9900,0.1964202,0.3308289,,,,,,,,,,,,,, -10000,0.32308653,0.23118782,,,,,,,,,,,,,, -10096,,,0.7412514005388532,0.2717173780713762,0.7188823058085959,0.2895712522421919,3554.0,0.7363095809480592,0.2911594436675335,3581.0,2462.336722135544,2787.014926671982,2462.336722135544,323.5765333175659,0.7633152008056641,0.0 -10100,0.22246253,0.27440256,,,,,,,,,,,,,, -10200,0.42205352,0.3069168,,,,,,,,,,,,,, -10300,0.12566495,0.2671039,,,,,,,,,,,,,, -10400,0.10790673,0.24734321,,,,,,,,,,,,,, -10445,,,0.7422726494925362,0.2717570236751011,0.7204666094761185,0.2895132396485474,3554.0,0.7377865601438146,0.291044531904496,3581.0,2542.3553891181946,2871.116565465927,2542.3553891181946,327.6230306625366,0.7884683609008789,0.0 -10500,0.14519593,0.33106044,,,,,,,,,,,,,, -10600,0.18199922,0.2523072,,,,,,,,,,,,,, -10700,0.24448532,0.28946447,,,,,,,,,,,,,, -10792,,,0.7428265299115863,0.2712354319436209,0.7209389535778349,0.2889475395790834,3554.0,0.7382145732250069,0.2905383201925091,3581.0,2622.469162940979,2955.3185493946075,2622.469162940979,331.6721098423004,0.8161625862121582,0.0 -10800,0.13072406,0.31524977,,,,,,,,,,,,,, -10900,0.25992328,0.2743791,,,,,,,,,,,,,, -11000,0.07441684,0.3528436,,,,,,,,,,,,,, -11100,0.21099749,0.32426214,,,,,,,,,,,,,, -11137,,,0.7440916470118931,0.2709963662283761,0.722346574722144,0.2889027163495357,3554.0,0.7394820455398282,0.2905359680976857,3581.0,2702.5720574855804,3039.51064157486,2702.5720574855804,335.7248303890228,0.8413472175598145,0.0 -11200,0.12412653,0.29378784,,,,,,,,,,,,,, -11300,0.2571939,0.276709,,,,,,,,,,,,,, -11400,0.1036563,0.31800684,,,,,,,,,,,,,, -11486,,,0.7422352518354144,0.2712761334010533,0.72035882764139,0.2890337856552476,3554.0,0.7376785001352625,0.2905799079560527,3581.0,2782.628226995468,3123.651228427887,2782.628226995468,339.7678611278534,0.8715095520019531,0.0 -11500,0.116718374,0.2573774,,,,,,,,,,,,,, -11600,0.12649311,0.2533922,,,,,,,,,,,,,, -11700,0.20781037,0.28410944,,,,,,,,,,,,,, -11800,0.18608777,0.26539192,,,,,,,,,,,,,, -11834,,,0.7413741520472935,0.2715799127306257,0.7194349539031725,0.2892523375400077,3554.0,0.7368190651398003,0.2906691171176871,3581.0,2862.6888043880463,3207.7888326644897,2862.6888043880463,343.8061776161194,0.8987765312194824,0.0 -11900,0.26032704,0.2567512,,,,,,,,,,,,,, -12000,0.12883398,0.3322627,,,,,,,,,,,,,, -12100,0.22366206,0.24778935,,,,,,,,,,,,,, -12182,,,0.7423573902675084,0.2713607719966343,0.7200796527680079,0.289545869585766,3554.0,0.7373647511388229,0.2910813473017313,3581.0,2942.7759687900543,3291.9573850631714,2942.7759687900543,347.8495783805847,0.925361156463623,0.0 -12200,0.09738826,0.2316501,,,,,,,,,,,,,, -12300,0.12211357,0.24737425,,,,,,,,,,,,,, -12400,0.30541614,0.2671007,,,,,,,,,,,,,, -12500,0.3084966,0.3939178,,,,,,,,,,,,,, -12532,,,0.7423182215009417,0.272442204611642,0.7204868056898917,0.2901950679472074,3554.0,0.7378158761082798,0.2916355894717781,3581.0,3022.7916910648346,3376.0582070350647,3022.7916910648346,351.8961908817291,0.95231032371521,0.0 -12600,0.26467246,0.357748,,,,,,,,,,,,,, -12700,0.20067406,0.20428655,,,,,,,,,,,,,, -12800,0.13990489,0.23171893,,,,,,,,,,,,,, -12878,,,0.7436973026820591,0.2720416443688528,0.722205544698755,0.2895514681960467,3554.0,0.7393421470303337,0.2911252189834369,3581.0,3102.942088365555,3460.2937593460083,3102.942088365555,355.9431412220001,0.9793925285339355,0.0 -12900,0.21255812,0.22317494,,,,,,,,,,,,,, -13000,0.1403486,0.27985364,,,,,,,,,,,,,, -13100,0.1491596,0.23885538,,,,,,,,,,,,,, -13200,0.106344275,0.27282557,,,,,,,,,,,,,, -13224,,,0.7428876331874302,0.2707764932087489,0.7205708878860088,0.2887189582820413,3554.0,0.7379479343016965,0.2903072353981953,3581.0,3183.06791138649,3544.5014159679413,3183.06791138649,359.986421585083,1.0064899921417236,0.0 -13300,0.09279665,0.2906185,,,,,,,,,,,,,, -13400,0.12734537,0.33391953,,,,,,,,,,,,,, -13500,0.09985397,0.27256724,,,,,,,,,,,,,, -13571,,,0.7417325292314801,0.2709683350154331,0.7195579172455332,0.2888601943892445,3554.0,0.737011936915317,0.2902935318892243,3581.0,3263.227847099304,3628.746710538864,3263.227847099304,364.0342311859131,1.0327346324920654,0.0 -13600,0.21292734,0.21417275,,,,,,,,,,,,,, -13700,0.19192417,0.2852882,,,,,,,,,,,,,, -13800,0.114730194,0.2759044,,,,,,,,,,,,,, -13900,0.12669604,0.32097036,,,,,,,,,,,,,, -13922,,,0.7432278905596051,0.2707513741084507,0.7214930442591094,0.288457987478897,3554.0,0.7388716598889975,0.2899280368066706,3581.0,3343.3426282405853,3712.941677093506,3343.3426282405853,368.07639956474304,1.0592927932739258,0.0 -14000,0.18202826,0.2188602,,,,,,,,,,,,,, -14100,0.083934136,0.40111113,,,,,,,,,,,,,, -14200,0.24594899,0.24062112,,,,,,,,,,,,,, -14270,,,0.7404216357639858,0.2719397204262869,0.7183479304788618,0.2898076647241664,3554.0,0.7358189816915666,0.2913003648269512,3581.0,3423.356700658798,3797.0408568382254,3423.356700658798,372.120733499527,1.0885093212127686,0.0 -14300,0.16820273,0.2920773,,,,,,,,,,,,,, -14400,0.16260266,0.24064115,,,,,,,,,,,,,, -14500,0.33275694,0.28683177,,,,,,,,,,,,,, -14600,0.055555265,0.33839512,,,,,,,,,,,,,, -14619,,,0.7435727800641742,0.2705361843109131,0.7214944868458075,0.2882630665381436,3554.0,0.7387854164121405,0.2898336121304279,3581.0,3503.464207649231,3881.233004808426,3503.464207649231,376.1667695045471,1.1156907081604004,0.0 -14700,0.09647025,0.34789354,,,,,,,,,,,,,, -14800,0.11393394,0.37827387,,,,,,,,,,,,,, -14900,0.0968937,0.23897478,,,,,,,,,,,,,, -14967,,,0.743222849709647,0.2712163243974958,0.721397970926245,0.2889193060965637,3554.0,0.7385827953740226,0.2905270710433538,3581.0,3583.558312177658,3965.413384199144,3583.558312177658,380.2160503864288,1.1415808200836182,0.0 -15000,0.20400092,0.23943421,,,,,,,,,,,,,, -15100,0.258098,0.24310711,,,,,,,,,,,,,, -15200,0.17027007,0.31174922,,,,,,,,,,,,,, -15300,0.16301365,0.21156137,,,,,,,,,,,,,, -15312,,,0.7425961494445801,0.2704992464610508,0.7206020065419246,0.2884882817995568,3554.0,0.7381061041564856,0.2898316009189123,3581.0,3663.6104049682617,4049.543503046036,3663.6104049682617,384.2569673061371,1.1672825813293457,0.0 -15400,0.2477064,0.27179295,,,,,,,,,,,,,, -15500,0.28314736,0.2653592,,,,,,,,,,,,,, -15600,0.33519346,0.29489166,,,,,,,,,,,,,, -15660,,,0.7428229195731026,0.2702544757298061,0.7206478945378447,0.2880630621966446,3554.0,0.7381669859152471,0.2895370095643675,3581.0,3743.7452001571655,4133.761077642441,3743.7452001571655,388.302414894104,1.1935670375823977,0.0 -15700,0.1330252,0.22957152,,,,,,,,,,,,,, -15800,0.37788087,0.28512147,,,,,,,,,,,,,, -15900,0.16519557,0.2698673,,,,,,,,,,,,,, -16000,0.4019501,0.30098924,,,,,,,,,,,,,, -16009,,,0.7428116798400879,0.2702045440673828,0.7207355488534046,0.288014769889561,3554.0,0.7380999000802848,0.2894750028906904,3581.0,3823.7247598171234,4217.8178243637085,3823.7247598171234,392.3407049179077,1.2211003303527832,0.0 -16100,0.16716124,0.33790338,,,,,,,,,,,,,, -16200,0.13192566,0.2709756,,,,,,,,,,,,,, -16300,0.11477423,0.2618487,,,,,,,,,,,,,, -16355,,,0.7441465514046806,0.270837596484593,0.7220722084710889,0.288835155205842,3554.0,0.739279492678372,0.2903500503416469,3581.0,3903.7739090919495,4301.95210313797,3903.7739090919495,396.38701009750366,1.2483484745025637,0.0 -16400,0.07725756,0.25477254,,,,,,,,,,,,,, -16500,0.1290388,0.24905026,,,,,,,,,,,,,, -16600,0.24450585,0.30756786,,,,,,,,,,,,,, -16700,0.23748122,0.26821423,,,,,,,,,,,,,, -16704,,,0.7429208755493164,0.2703778062547956,0.7212277456958709,0.2881776104499332,3554.0,0.738717307927255,0.289581324394373,3581.0,3983.775414466858,4386.033557653427,3983.775414466858,400.4280893802643,1.2757651805877686,0.0 -16800,0.24697572,0.22716975,,,,,,,,,,,,,, -16900,0.14218567,0.3770893,,,,,,,,,,,,,, -17000,0.054899372,0.22830474,,,,,,,,,,,,,, -17053,,,0.7437002318246024,0.2700082744870867,0.7216908160259566,0.2878565318677019,3554.0,0.7390477602057736,0.2893098108397619,3581.0,4063.8064637184143,4470.148884057999,4063.8064637184143,404.47349190711975,1.3030242919921875,0.0 -17100,0.2118919,0.33107314,,,,,,,,,,,,,, -17200,0.2845879,0.27439722,,,,,,,,,,,,,, -17300,0.1590136,0.26192993,,,,,,,,,,,,,, -17396,,,0.7439383098057338,0.2703984975814819,0.721435546875,0.2886404403489026,3554.0,0.7386054982023178,0.2902482966742879,3581.0,4143.867969751358,4554.289497852325,4143.867969751358,408.5127048492432,1.3316102027893066,0.0 -17400,0.23000272,0.24867386,,,,,,,,,,,,,, -17500,0.16290863,0.33393323,,,,,,,,,,,,,, -17600,0.15127419,0.29274514,,,,,,,,,,,,,, -17700,0.1598309,0.2652308,,,,,,,,,,,,,, -17744,,,0.7415142740522113,0.2704047645841326,0.7196797814742192,0.2882171785422235,3554.0,0.7370970213889276,0.2896581935802674,3581.0,4223.834111213684,4638.337499380112,4223.834111213684,412.55626583099365,1.3584327697753906,0.0 -17800,0.1426722,0.2723636,,,,,,,,,,,,,, -17900,0.12516066,0.2914284,,,,,,,,,,,,,, -18000,0.12114365,0.26136124,,,,,,,,,,,,,, -18092,,,0.7446156229291644,0.2705769879477365,0.7228888499314153,0.2883928306463668,3554.0,0.7402336932333845,0.2898229424828958,3581.0,4303.8344876766205,4722.421092510223,4303.8344876766205,416.6016986370087,1.3848366737365725,0.0 -18100,0.14782114,0.30530822,,,,,,,,,,,,,, -18200,0.13401692,0.32029945,,,,,,,,,,,,,, -18300,0.15461011,0.2692095,,,,,,,,,,,,,, -18400,0.2399299,0.23631713,,,,,,,,,,,,,, -18436,,,0.7440378325326102,0.2705787760870797,0.7222482727428602,0.2884532475511747,3554.0,0.7396001956942893,0.2899425925239109,3581.0,4383.909608125687,4806.579029083252,4383.909608125687,420.6447505950928,1.4133193492889404,0.0 -18500,0.19407348,0.28969058,,,,,,,,,,,,,, -18600,0.13763873,0.3866493,,,,,,,,,,,,,, -18700,0.15419795,0.25125557,,,,,,,,,,,,,, -18784,,,0.7434712137494769,0.2712665796279907,0.7211822698675788,0.2895970470662458,3554.0,0.7384926658274574,0.2910944372207484,3581.0,4463.937946796417,4890.694290399551,4463.937946796417,424.69260358810425,1.4410135746002195,0.0 -18800,0.07853955,0.29916212,,,,,,,,,,,,,, -18900,0.11856777,0.2874158,,,,,,,,,,,,,, -19000,0.16401742,0.31017295,,,,,,,,,,,,,, -19100,0.26055232,0.25488353,,,,,,,,,,,,,, -19133,,,0.7459725652422223,0.2698937143598284,0.7241422516882386,0.2878065221955015,3554.0,0.7413472908230941,0.2894481753743717,3581.0,4544.043174743652,4974.888587236404,4544.043174743652,428.7421128749848,1.4691977500915527,0.0 -19133,,,,,,,,,,,4544.043174743652,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 706b66da6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -37.360684394836426,0.0,55.74225401878357,1,0,55.74225401878357,0.0013000001199543,6.9117279052734375,10000,93.10303163528442,0.0007573341717943,6.911828994750977,0.0006799999973736,6.912051200866699,50000 -55.27424716949463,0.0264043807983398,565.8360676765442,1500,0,565.8360676765442,0.0481000021100044,5.63258171081543,10000,621.1860675811768,0.0748166441917419,5.330939769744873,0.0686400011181831,5.407910346984863,50000 -73.62649869918823,0.0525565147399902,1075.7997715473175,2997,0,1075.7997715473175,0.1221000030636787,4.7546820640563965,10000,1149.57959151268,0.1820392161607742,4.255775451660156,0.1641199886798858,4.366084098815918,50000 -91.51947140693665,0.0799648761749267,1585.9234120845797,4495,0,1585.9234120845797,0.1856000125408172,4.202426910400391,10000,1677.6729209423063,0.2831034660339355,3.511129856109619,0.2563399970531463,3.65845799446106,50000 -109.42569613456726,0.1053483486175537,2095.983659267425,5993,0,2095.983659267425,0.2578000128269195,3.696389198303223,10000,2205.7152502536774,0.3771125674247741,2.901236057281494,0.3469599783420563,3.06270432472229,50000 -127.3788537979126,0.1332588195800781,2605.98605966568,7491,0,2605.98605966568,0.3150000274181366,3.4180727005004883,10000,2733.749913215637,0.437220960855484,2.591935873031616,0.4038199782371521,2.763052701950073,50000 -145.59150886535645,0.1610162258148193,3116.0023124217987,8990,0,3116.0023124217987,0.3545000255107879,3.1170003414154053,10000,3262.0570573806763,0.4954759180545807,2.285950183868408,0.4610599875450134,2.458083391189575,50000 -167.5162229537964,0.1880600452423095,3626.022510528565,10489,0,3626.022510528565,0.3785000145435333,2.9464919567108154,10000,3794.081609249115,0.5600087642669678,1.9418503046035769,0.4983799755573272,2.250006437301636,50000 -185.5975196361541,0.2156145572662353,4136.248651504517,11990,0,4136.248651504517,0.406900018453598,2.8000824451446533,10000,4322.467486619949,0.5775271058082581,1.87908947467804,0.5301399827003479,2.1053919792175293,50000 -203.5793845653534,0.2452063560485839,4646.346125364304,13491,0,4646.346125364304,0.424200028181076,2.7253708839416504,10000,4850.625118970871,0.5838648080825806,1.834944128990173,0.5408999919891357,2.040344953536988,50000 -221.64778184890747,0.2781078815460205,5156.277529478073,14992,0,5156.277529478073,0.4354000091552734,2.653869152069092,10000,5378.708662033081,0.5969586968421936,1.7865535020828247,0.5550999641418457,1.9808624982833865,50000 -239.83745956420896,0.3100862503051758,5666.510581970215,16495,0,5666.510581970215,0.4358000159263611,2.657799005508423,10000,5907.21281003952,0.5969586968421936,1.7959184646606443,0.5592600107192993,1.978746771812439,50000 -259.0123672485352,0.342235803604126,6176.487642765045,17998,0,6176.487642765045,0.4360000193119049,2.6286208629608154,10000,6436.447620630264,0.5971181392669678,1.7574492692947388,0.5580799579620361,1.9484416246414185,50000 -277.65295243263245,0.3832590579986572,6686.56877040863,19501,0,6686.56877040863,0.4470000267028808,2.5772900581359863,10000,6965.261632204056,0.6540975570678711,1.5059301853179932,0.5709599852561951,1.882540702819824,50000 -299.892765045166,0.415421724319458,7196.738107442856,21004,0,7196.738107442856,0.4496000111103058,2.5562584400177,10000,7497.752900362015,0.6272919178009033,1.6106610298156738,0.5689799785614014,1.8841341733932493,50000 -321.13594365119934,0.4556243419647217,7706.86282658577,22508,0,7706.86282658577,0.4611000120639801,2.4872632026672363,10000,8029.213074207306,0.6285873651504517,1.606850266456604,0.5818799734115601,1.833842158317566,50000 -343.1451985836029,0.5067691802978516,8216.786780834198,24012,0,8216.786780834198,0.4627000093460083,2.5038511753082275,10000,8561.249808549881,0.6264548897743225,1.6274893283843994,0.5815799832344055,1.8446393013000488,50000 -364.8770282268524,0.5437021255493164,8726.874703884125,25516,0,8726.874703884125,0.4686000347137451,2.4981985092163086,10000,9093.15644454956,0.631277859210968,1.6501520872116089,0.5895199775695801,1.8403079509735107,50000 -386.8250911235809,0.5786194801330566,9237.083882570269,27021,0,9237.083882570269,0.4698000252246856,2.4951579570770264,10000,9625.399262428284,0.6287069320678711,1.6057114601135254,0.5907599925994873,1.7863342761993408,50000 -409.0597383975983,1.13085675239563,9746.664439201357,28524,0,9746.664439201357,0.473000019788742,2.4163501262664795,10000,10157.818138360975,0.683992326259613,1.3645371198654177,0.597000002861023,1.757430911064148,50000 -432.0158112049103,1.1649293899536133,10256.616863250732,30028,0,10256.616863250732,0.4752000272274017,2.4623258113861084,10000,10690.81227874756,0.6511877775192261,1.5284302234649658,0.5967199802398682,1.7889021635055542,50000 -457.0925896167755,1.210730791091919,10766.595863342283,31532,0,10766.595863342283,0.4620000123977661,2.532831907272339,10000,11225.963517665865,0.6303810477256775,1.6111574172973633,0.5847600102424622,1.82930588722229,50000 -480.23394536972046,1.2460718154907229,11276.758833408356,33037,0,11276.758833408356,0.4825000166893005,2.373618125915528,10000,11759.353006362917,0.649832546710968,1.484278440475464,0.6013000011444092,1.7228082418441772,50000 -503.3027272224426,1.2807552814483645,11786.717982053757,34542,0,11786.717982053757,0.4793000221252441,2.3868050575256348,10000,12292.46637749672,0.6471221446990967,1.524101734161377,0.602180004119873,1.7275532484054563,50000 -526.7157201766968,1.3211784362792969,12296.627568244934,36047,0,12296.627568244934,0.4851000308990478,2.3932998180389404,10000,12825.880003213882,0.6511877775192261,1.5404627323150637,0.6087799668312073,1.7445299625396729,50000 -548.5404841899872,1.3562448024749756,12806.58596277237,37552,0,12806.58596277237,0.4791000187397003,2.4134533405303955,10000,13357.748576164246,0.6801658272743225,1.3729093074798584,0.5968199968338013,1.7451149225234983,50000 -569.4229211807251,1.388034105300903,13316.641247034073,39057,0,13316.641247034073,0.4818000197410583,2.383208036422729,10000,13888.769673347471,0.6672711968421936,1.4364635944366455,0.604640007019043,1.7295138835906982,50000 -589.8896474838257,1.4172828197479248,13826.662028312683,40562,0,13826.662028312683,0.490200012922287,2.349595308303833,10000,14419.336757183077,0.6653180718421936,1.4576128721237185,0.6128999590873718,1.690018653869629,50000 -610.3627202510834,1.4521598815917969,14336.611745595932,42067,0,14336.611745595932,0.4836000204086303,2.381596803665161,10000,14949.84656405449,0.6569873690605164,1.500503659248352,0.608460009098053,1.7169100046157837,50000 -632.3841454982758,1.4866185188293457,14846.80584168434,43573,0,14846.80584168434,0.4910000264644623,2.3547487258911133,10000,15482.147728919985,0.6497528553009033,1.5108873844146729,0.6085599660873413,1.701366662979126,50000 -651.5147202014923,1.5189547538757324,15357.030562639236,45079,0,15357.030562639236,0.4716000258922577,2.438140869140625,10000,16011.586278438568,0.6457469463348389,1.5189896821975708,0.6038999557495117,1.721349000930786,50000 -671.9697597026825,1.5562007427215576,15867.12254691124,46585,0,15867.12254691124,0.4827000200748443,2.392597913742065,10000,16542.223130464554,0.6690250039100647,1.4443752765655518,0.603119969367981,1.738900899887085,50000 -691.9411878585815,1.5927386283874512,16377.347776174543,48091,0,16377.347776174543,0.4932000339031219,2.372044324874878,10000,17072.506830453873,0.6783920526504517,1.3911480903625488,0.6177600026130676,1.6825976371765137,50000 -714.2665731906891,1.6319713592529297,16887.48326063156,49597,0,16887.48326063156,0.5010000467300415,2.282940149307251,10000,17605.0581843853,0.6721141338348389,1.4001306295394895,0.6206600069999695,1.6346431970596311,50000 -735.2593734264374,1.6691064834594729,17397.699914216995,51104,0,17397.699914216995,0.4819000363349914,2.446437358856201,10000,18136.358857870106,0.6506895422935486,1.5444890260696411,0.6065199971199036,1.7607401609420776,50000 -753.8350718021393,1.702005386352539,17907.79274368286,52610,0,17907.79274368286,0.4906000196933746,2.3692007064819336,10000,18665.110434770584,0.6633649468421936,1.458517074584961,0.6189799904823303,1.6728732585906982,50000 -771.2206614017487,1.7383487224578855,18417.99221158028,54117,0,18417.99221158028,0.4872000217437744,2.394115447998047,10000,19192.78498077393,0.652363657951355,1.5211600065231323,0.6086999773979187,1.7170997858047483,50000 -788.768620967865,1.7768800258636477,18927.954924583435,55623,0,18927.954924583435,0.4896000325679779,2.345250606536865,10000,19720.38703656197,0.6726123690605164,1.4373608827590942,0.6170600056648254,1.682947874069214,50000 -805.8777091503143,1.8174619674682613,19438.130825281143,57129,0,19438.130825281143,0.4963000118732452,2.3328607082366943,10000,20247.76454782486,0.6888552308082581,1.3610378503799438,0.6221599578857422,1.6592100858688354,50000 -823.1925563812256,1.8636653423309328,19948.24530482292,58636,0,19948.24530482292,0.4946000277996063,2.325051784515381,10000,20775.29114437104,0.6800262928009033,1.3760676383972168,0.620959997177124,1.6456332206726074,50000 -840.5147912502289,1.910284280776977,20458.460973262787,60143,0,20458.460973262787,0.503600001335144,2.290309190750122,10000,21302.928270578384,0.6727718114852905,1.408918380737305,0.6265400052070618,1.6337388753890991,50000 -858.6574947834015,1.9475953578948968,20968.66392183304,61651,0,20968.66392183304,0.5010000467300415,2.281240224838257,10000,21831.363361120224,0.6753029227256775,1.3835617303848269,0.6267799735069275,1.610917329788208,50000 -876.3508095741272,1.9852290153503416,21478.79666543007,63158,0,21478.79666543007,0.5006000399589539,2.285511016845703,10000,22359.277554273605,0.6793287396430969,1.402336597442627,0.631060004234314,1.6216193437576294,50000 -894.2573552131653,2.025351047515869,21988.747662067413,64665,0,21988.747662067413,0.5010000467300415,2.2789101600646973,10000,22887.225957870483,0.6847297549247742,1.3605037927627563,0.6293999552726746,1.600814700126648,50000 -911.7745883464812,2.06522798538208,22498.691545009613,66172,0,22498.691545009613,0.5067000389099121,2.267800807952881,10000,23414.77932405472,0.7006337642669678,1.267736554145813,0.6306799650192261,1.588563084602356,50000 -929.0643086433412,2.106299638748169,23008.63331103325,67679,0,23008.63331103325,0.5098000168800354,2.2622320652008057,10000,23942.103929281235,0.6891342401504517,1.355064868927002,0.6316999793052673,1.623766541481018,50000 -946.126077890396,2.1450648307800293,23518.742659330368,69186,0,23518.742659330368,0.5093000531196594,2.243839263916016,10000,24469.36714053154,0.6904296875,1.3263139724731443,0.6378600001335144,1.573560118675232,50000 -964.0205476284028,2.1862633228302,24028.91832447052,70693,0,24028.91832447052,0.5091000199317932,2.269750595092773,10000,24997.529990196228,0.6878587007522583,1.3525885343551636,0.6403399705886841,1.573731780052185,50000 -981.154138803482,2.2252821922302246,24538.978356838223,72200,0,24538.978356838223,0.5123000144958496,2.238888502120972,10000,25524.813975811005,0.6857063174247742,1.3495975732803345,0.6386199593544006,1.5710564851760864,50000 -998.1271076202391,2.264722347259521,25049.043762922287,73707,0,25049.043762922287,0.5172000527381897,2.202910900115967,10000,26051.94483280182,0.6907086968421936,1.3185800313949585,0.6348400115966797,1.5660574436187744,50000 -1016.093424797058,2.3037054538726807,25559.250710248947,75215,0,25559.250710248947,0.5069000124931335,2.2694790363311768,10000,26580.20965051651,0.7098612785339355,1.2410625219345093,0.6350199580192566,1.5747349262237549,50000 -1033.3393914699554,2.347717046737671,26069.155710458755,76722,0,26069.155710458755,0.5156000256538391,2.2022013664245605,10000,27107.455953359604,0.7016701102256775,1.2833776473999023,0.6454399824142456,1.553180456161499,50000 -1050.3499314785004,2.3884060382843018,26579.27547645569,78229,0,26579.27547645569,0.5181000232696533,2.194598197937012,10000,27634.67905735969,0.7018494606018066,1.2792024612426758,0.6469199657440186,1.5287810564041138,50000 -1067.4200563430786,2.4313290119171143,27089.258211374283,79736,0,27089.258211374283,0.5218000411987305,2.1807165145874023,10000,28161.826422214508,0.702168345451355,1.2795171737670898,0.6496399641036987,1.508771896362305,50000 -1084.517686367035,2.4775948524475098,27599.380972385406,81243,0,27599.380972385406,0.5243000388145447,2.222846031188965,10000,28689.146060228348,0.6947743892669678,1.3432766199111938,0.6473199725151062,1.556182026863098,50000 -1101.7360954284668,2.5183119773864746,28109.40303182602,82750,0,28109.40303182602,0.5206000208854675,2.15671443939209,10000,29216.481212615967,0.6988998651504517,1.2671300172805786,0.6509999632835388,1.4982835054397583,50000 -1119.0488619804382,2.564819812774658,28619.604907035828,84258,0,28619.604907035828,0.5231000185012817,2.192289113998413,10000,29744.09475159645,0.7250677347183228,1.181086778640747,0.653499960899353,1.501522421836853,50000 -1136.2286508083344,2.6103806495666504,29129.81632399559,85765,0,29129.81632399559,0.5260000228881836,2.19507384300232,10000,30271.58340406418,0.7072703838348389,1.2629165649414062,0.644540011882782,1.5403581857681274,50000 -1153.7549715042114,2.653380870819092,29639.82198214531,87272,0,29639.82198214531,0.5247000455856323,2.167389154434204,10000,30799.211097955704,0.7074697017669678,1.2444090843200684,0.6504600048065186,1.5148764848709106,50000 -1170.8653423786163,2.6973180770874023,30149.93711233139,88780,0,30149.93711233139,0.5261000394821167,2.174659013748169,10000,31326.53332161904,0.702566921710968,1.2763031721115112,0.6502799987792969,1.5065932273864746,50000 -1187.9223272800446,2.740720748901367,30659.994156122208,90287,0,30659.994156122208,0.5333000421524048,2.1418609619140625,10000,31853.74282312393,0.7116549611091614,1.246743083000183,0.6574999690055847,1.4754823446273804,50000 -1204.966181755066,2.859598398208618,31169.86101269722,91793,0,31169.86101269722,0.5433000326156616,2.08927321434021,10000,32380.824385404587,0.7194674611091614,1.1982098817825315,0.6654799580574036,1.4386276006698608,50000 -1222.0526728630066,2.903446912765503,31679.799534082413,93300,0,31679.799534082413,0.5268000364303589,2.1804299354553223,10000,32907.94575691223,0.722676157951355,1.201149821281433,0.6536999940872192,1.525767803192139,50000 -1239.2692143917084,2.946500778198242,32189.77259206772,94807,0,32189.77259206772,0.5247000455856323,2.177957057952881,10000,33435.2316634655,0.7135283946990967,1.234078884124756,0.650439977645874,1.527204155921936,50000 -1256.4928257465365,2.9893760681152344,32699.811596155167,96315,0,32699.811596155167,0.532800018787384,2.1280345916748047,10000,33962.59027004242,0.7221978306770325,1.189970850944519,0.6604399681091309,1.467153549194336,50000 -1273.975423336029,3.037425756454468,33209.789820194244,97822,0,33209.789820194244,0.5423000454902649,2.084092617034912,10000,34490.15003442764,0.725027859210968,1.1888576745986938,0.668999969959259,1.4445313215255735,50000 -1291.060376405716,3.081228733062744,33719.99088358879,99329,0,33719.99088358879,0.5276000499725342,2.150489568710327,10000,35017.532859802246,0.7132493257522583,1.2412854433059692,0.6611199975013733,1.480249047279358,50000 -1308.0780620574951,3.127570152282715,34229.93264293671,100836,0,34229.93264293671,0.5448000431060791,2.0938448905944824,10000,35544.59036445618,0.7284757494926453,1.1622774600982666,0.6675999760627747,1.4319740533828735,50000 -1325.09671998024,3.1725218296051025,34739.89816379547,102343,0,34739.89816379547,0.5517000555992126,2.0450947284698486,10000,36071.67170572281,0.7571149468421936,1.045891046524048,0.675879955291748,1.4058371782302856,50000 -1342.2405395507812,3.224278211593628,35250.05817198753,103850,0,35250.05817198753,0.5487000346183777,2.0953757762908936,10000,36599.079026699066,0.7410514950752258,1.109355092048645,0.6710000038146973,1.4239152669906616,50000 -1359.4712007045746,3.270002841949463,35759.976840257645,105356,0,35759.976840257645,0.5460000038146973,2.064647674560547,10000,37126.32803606987,0.7430245280265808,1.116188883781433,0.6763399839401245,1.4128817319869995,50000 -1376.36949300766,3.318311929702759,36269.989028692245,106863,0,36269.989028692245,0.5570999979972839,2.026524066925049,10000,37653.3395678997,0.7424864172935486,1.115410089492798,0.6801799535751343,1.388283610343933,50000 -1393.7109084129331,3.367356061935425,36780.031766176224,108370,0,36780.031766176224,0.5527999997138977,2.0306382179260254,10000,38180.82426738739,0.7401745915412903,1.10150146484375,0.6816999912261963,1.3689839839935305,50000 -1410.958943605423,3.413281202316284,37290.11735486984,109877,0,37290.11735486984,0.5592000484466553,2.0068023204803467,10000,38708.2563290596,0.7429846525192261,1.0977813005447388,0.6827200055122375,1.3649893999099731,50000 -1428.169054031372,3.4625115394592285,37800.23961639404,111385,0,37800.23961639404,0.551300048828125,2.0405306816101074,10000,39235.68991231918,0.7664620280265808,1.0001466274261477,0.6792399883270264,1.3854286670684814,50000 -1445.4499547481537,3.506695508956909,38310.37893438339,112892,0,38310.37893438339,0.5605000257492065,2.010292291641236,10000,39763.20695757866,0.7581512928009033,1.031977891921997,0.6830799579620361,1.3659889698028564,50000 -1462.527009487152,3.55802059173584,38820.30667281151,114399,0,38820.30667281151,0.5693000555038452,1.9787520170211792,10000,40290.31508851051,0.7623365521430969,1.0210669040679932,0.6910600066184998,1.3425350189208984,50000 -1479.7441306114197,3.622938871383667,39330.32674264908,115906,0,39330.32674264908,0.5593000054359436,2.0154571533203125,10000,40817.67096376419,0.7527702450752258,1.0518150329589844,0.6881200075149536,1.3431140184402466,50000 -1496.8971843719482,3.6775968074798575,39840.43624377251,117413,0,39840.43624377251,0.5701000094413757,1.969575047492981,10000,41345.04013109207,0.7529894709587097,1.057737946510315,0.6910399794578552,1.3415244817733765,50000 -1514.122484207153,3.725899457931519,40350.61412215233,118920,0,40350.61412215233,0.5654000043869019,1.998255014419556,10000,41872.54306650162,0.7530691623687744,1.0459004640579224,0.6911999583244324,1.3308132886886597,50000 -1531.4665586948397,3.774169683456421,40860.52008938789,120427,0,40860.52008938789,0.5664000511169434,1.9804224967956543,10000,42399.89472198486,0.7836814522743225,0.931477665901184,0.693120002746582,1.3276666402816772,50000 -1548.710284948349,3.828591108322144,41370.45637798309,121934,0,41370.45637798309,0.5767000317573547,1.9412811994552608,10000,42927.18236398697,0.7770846486091614,0.960997998714447,0.7005199790000916,1.2985070943832395,50000 -1565.789316654205,3.87968111038208,41880.4911146164,123441,0,41880.4911146164,0.5712000131607056,1.9497255086898804,10000,43454.39989852905,0.7704280614852905,0.9899856448173524,0.6974999904632568,1.3206807374954224,50000 -1582.8079690933228,3.928217649459839,42390.486533641815,124948,0,42390.486533641815,0.5805000066757202,1.8990615606307983,10000,43981.515066862106,0.7798947691917419,0.9331660270690918,0.7085399627685547,1.2580126523971558,50000 -1599.7522914409635,3.979759454727173,42900.42901682854,126455,0,42900.42901682854,0.5834000110626221,1.9199188947677608,10000,44508.50655436516,0.7784597873687744,0.9620999693870544,0.7042999863624573,1.283419847488403,50000 -1616.7470650672913,4.031066417694092,43410.41619229317,127961,0,43410.41619229317,0.5796000361442566,1.9083911180496216,10000,45035.59116268158,0.7736966013908386,0.9492216110229492,0.7021399736404419,1.265101194381714,50000 -1634.0042452812197,4.109456300735474,43920.58583164215,129469,0,43920.58583164215,0.5788000226020813,1.9307630062103271,10000,45563.149040699005,0.805683970451355,0.8402982354164124,0.7069199681282043,1.2687928676605225,50000 -1651.078492641449,4.160379648208618,44430.66488528252,130977,0,44430.66488528252,0.5836000442504883,1.9029459953308103,10000,46090.4052464962,0.7959781289100647,0.8902884125709534,0.7091799974441528,1.2596995830535889,50000 -1668.14599943161,4.21102237701416,44940.60585641861,132484,0,44940.60585641861,0.5871000289916992,1.8773962259292605,10000,46617.5152451992,0.7964564561843872,0.8781556487083435,0.7150799632072449,1.2385895252227783,50000 -1685.284812450409,4.261689901351929,45450.668227911,133991,0,45450.668227911,0.5932000279426575,1.840102195739746,10000,47144.82175087929,0.8026745915412903,0.8502547144889832,0.7177799940109253,1.2106214761734009,50000 -1702.4054865837095,4.315122365951538,45960.74799871445,135498,0,45960.74799871445,0.5916000008583069,1.859628438949585,10000,47672.127898454666,0.7995057106018066,0.8691868782043457,0.7166799902915955,1.2189016342163086,50000 -1719.3394558429718,4.369931697845459,46470.69027876854,137005,0,46470.69027876854,0.5957000255584717,1.8467576503753664,10000,48199.111078739166,0.7967952489852905,0.862323522567749,0.7177000045776367,1.208281636238098,50000 -1736.362550497055,4.420541524887085,46980.68811249733,138512,0,46980.68811249733,0.5956000089645386,1.8367879390716555,10000,48726.23480153084,0.8315330147743225,0.7364321947097778,0.7225199937820435,1.1886651515960691,50000 -1753.3998112678528,4.472010850906372,47490.778540849686,140019,0,47490.778540849686,0.6030000448226929,1.794895052909851,10000,49253.46666812897,0.8246771097183228,0.7462592720985413,0.729919970035553,1.1638081073760986,50000 -1770.5072956085205,4.523514747619629,48000.73310112953,141526,0,48000.73310112953,0.6044000387191772,1.7896476984024048,10000,49780.63418030739,0.8252949714660645,0.7593898177146912,0.7310400009155273,1.1624298095703125,50000 -1787.6722025871277,4.576416730880737,48510.89326953888,143033,0,48510.89326953888,0.6105000376701355,1.773678183555603,10000,50308.06536364555,0.8251753449440002,0.7437108159065247,0.7314800024032593,1.141937017440796,50000 -1804.7232937812803,4.630538702011108,49020.87477731705,144539,0,49020.87477731705,0.6085000038146973,1.8027182817459104,10000,50835.20391178131,0.8236008882522583,0.7736658453941345,0.7317599654197693,1.1646684408187866,50000 -1821.6753494739528,4.683336019515991,49530.97762846947,146046,0,49530.97762846947,0.615600049495697,1.7666029930114746,10000,51362.36286592484,0.8278858065605164,0.7407735586166382,0.7351599931716919,1.1394431591033936,50000 -1838.5789158344269,4.734616041183472,50040.92667126656,147552,0,50040.92667126656,0.6152999997138977,1.7525743246078491,10000,51889.318643569946,0.85843825340271,0.6291110515594482,0.7379999756813049,1.130968689918518,50000 -1855.714731216431,4.789153575897217,50551.148307561874,149059,0,50551.148307561874,0.6147000193595886,1.7563661336898804,10000,52416.78352403641,0.8477559089660645,0.667382538318634,0.7384999990463257,1.130212664604187,50000 -1873.110155582428,4.844248533248901,51061.33330345154,150566,0,51061.33330345154,0.6205000281333923,1.735827088356018,10000,52944.4701063633,0.8537547588348389,0.6530286073684692,0.7407999634742737,1.108092188835144,50000 -1890.9894676208496,4.900099754333496,51571.34389591217,152073,0,51571.34389591217,0.6213000416755676,1.74584698677063,10000,53472.46926546097,0.8474569320678711,0.6855957508087158,0.743939995765686,1.1150567531585691,50000 -1908.1210148334503,4.9520440101623535,52081.42667245865,153580,0,52081.42667245865,0.622700035572052,1.7106108665466309,10000,53999.78741884232,0.8515625,0.6506571769714355,0.7438799738883972,1.100218653678894,50000 -1925.201922893524,5.0060014724731445,52591.56684565544,155087,0,52591.56684565544,0.6258000135421753,1.720607876777649,10000,54527.11508727074,0.8556082248687744,0.6351298093795776,0.7490400075912476,1.0883169174194336,50000 -1942.1512784957888,5.063514232635498,53101.60917067528,156594,0,53101.60917067528,0.6260000467300415,1.7072957754135132,10000,55054.2172267437,0.8785076141357422,0.5570769309997559,0.7483800053596497,1.0872067213058472,50000 -1959.310719013214,5.121617794036865,53611.69621825218,158101,0,53611.69621825218,0.628600001335144,1.6892168521881104,10000,55581.5731716156,0.8756377100944519,0.5695292949676514,0.7520399689674377,1.07098388671875,50000 -1976.359569311142,5.178009748458862,54121.8637149334,159608,0,54121.8637149334,0.6337000131607056,1.6755914688110352,10000,56108.89868927002,0.8757772445678711,0.5535916686058044,0.754539966583252,1.05316162109375,50000 -1993.5300033092497,5.232148885726929,54631.85998415947,161115,0,54631.85998415947,0.6381000280380249,1.6685220003128052,10000,56636.17068815231,0.8797831535339355,0.5523720979690552,0.7566799521446228,1.0592304468154907,50000 -2010.63884305954,5.307616949081421,55141.8918106556,162622,0,55141.8918106556,0.6375000476837158,1.665583252906799,10000,57163.43918633461,0.8821747303009033,0.5450007915496826,0.758080005645752,1.0532749891281128,50000 -2027.6138999462128,5.36580753326416,55651.94191503525,164129,0,55651.94191503525,0.6361000537872314,1.6553908586502075,10000,57690.5752518177,0.8889508843421936,0.5121031999588013,0.761139988899231,1.0397335290908811,50000 -2044.88449382782,5.420394659042358,56162.06087732315,165636,0,56162.06087732315,0.6397000551223755,1.6543211936950684,10000,58218.07364296913,0.9054129123687744,0.462770015001297,0.7623400092124939,1.0329078435897827,50000 -2062.039719581604,5.4839723110198975,56672.09117293358,167143,0,56672.09117293358,0.6407000422477722,1.6520830392837524,10000,58745.37574315071,0.9016661047935486,0.473734438419342,0.7623800039291382,1.0327937602996826,50000 -2079.302550792694,5.541658639907837,57182.03576087952,168650,0,57182.03576087952,0.6403000354766846,1.6334751844406128,10000,59272.69396233559,0.9009087681770324,0.4697478115558624,0.766319990158081,1.021145582199097,50000 -2096.389819145202,5.597667694091797,57692.05851483345,170157,0,57692.05851483345,0.6470000147819519,1.6323281526565552,10000,59799.91110467911,0.9041772484779358,0.4596326351165771,0.7662799954414368,1.0119706392288208,50000 -2113.3370258808136,5.657930612564087,58202.10117435455,171664,0,58202.10117435455,0.6431000232696533,1.6356762647628784,10000,60327.013768196106,0.9072065949440002,0.4597722291946411,0.7687399983406067,1.0186874866485596,50000 -2130.31436419487,5.717501878738403,58712.05089068413,173171,0,58712.05089068413,0.6465000510215759,1.628620743751526,10000,60854.05264925957,0.9073860049247742,0.4471368789672851,0.7688800096511841,1.0100610256195068,50000 -2147.503813028336,5.7750890254974365,59221.97033691406,174678,0,59221.97033691406,0.6478000283241272,1.612573504447937,10000,61381.27100849152,0.917191445827484,0.4137333035469055,0.7696399688720703,1.0041489601135254,50000 -2164.4502770900726,5.829644203186035,59731.89277672768,176184,0,59731.89277672768,0.6492000222206116,1.610582947731018,10000,61908.24905347824,0.9174505472183228,0.415025532245636,0.7699399590492249,1.001592040061951,50000 -2181.530555009842,5.88846755027771,60241.88081359863,177690,0,60241.88081359863,0.6515000462532043,1.609252691268921,10000,62435.42911171913,0.9185068607330322,0.4077030122280121,0.7705599665641785,0.9990577697753906,50000 -2198.5527551174164,5.948404788970947,60751.85917687416,179196,0,60751.85917687416,0.6491000056266785,1.6055240631103516,10000,62962.54237771034,0.9178292155265808,0.4130044281482696,0.7719599604606628,0.9963983297348022,50000 -2215.9111063480377,6.012499570846558,61261.88948059082,180703,0,61261.88948059082,0.6509000062942505,1.6056536436080933,10000,63490.04742026329,0.9183274507522584,0.4075784087181091,0.7716599702835083,0.9949711561203004,50000 -2232.815001964569,6.070794105529785,61772.02049660683,182210,0,61772.02049660683,0.6494000554084778,1.6046305894851685,10000,64017.193064928055,0.919702649116516,0.4040387272834778,0.772159993648529,0.9936366081237792,50000 -2249.791279554367,6.127057075500488,62282.11263537407,183717,0,62282.11263537407,0.6511000394821167,1.6006096601486206,10000,64544.36849927902,0.9208585619926452,0.3966313898563385,0.7724399566650391,0.9899981021881104,50000 -2266.940537214279,6.183432102203369,62792.30425167084,185224,0,62792.30425167084,0.6509000062942505,1.6002691984176636,10000,65071.81726360321,0.9212771058082581,0.39945515990257263,0.7727199792861938,0.9911830425262451,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index 633879786..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1985 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.5947058,6.91732,,,,,,,,,,,,,, -1,,,0.0007573341717943,6.911828994750977,0.0006799999973736,6.912051200866699,50000.0,0.0013000001199543,6.9117279052734375,10000.0,55.74225401878357,93.10303163528442,55.74225401878357,37.360684394836426,0.0,0.0 -100,0.58977854,6.904665,,,,,,,,,,,,,, -200,0.5946515,6.8630304,,,,,,,,,,,,,, -300,0.6285222,6.7927594,,,,,,,,,,,,,, -400,0.69124126,6.6924663,,,,,,,,,,,,,, -500,0.7275424,6.6035347,,,,,,,,,,,,,, -600,0.77145004,6.517273,,,,,,,,,,,,,, -700,0.780751,6.4727135,,,,,,,,,,,,,, -800,0.8867547,6.3975725,,,,,,,,,,,,,, -900,2.126987,6.2626977,,,,,,,,,,,,,, -1000,1.6314394,6.1568966,,,,,,,,,,,,,, -1100,1.4163829,6.125343,,,,,,,,,,,,,, -1200,1.5352036,5.9941235,,,,,,,,,,,,,, -1300,2.484358,5.945636,,,,,,,,,,,,,, -1400,2.161366,5.8677893,,,,,,,,,,,,,, -1500,,,0.0748166441917419,5.330939769744873,0.0686400011181831,5.407910346984863,50000.0,0.0481000021100044,5.63258171081543,10000.0,565.8360676765442,621.1860675811768,565.8360676765442,55.27424716949463,0.0264043807983398,0.0 -1500,2.0523098,5.844819,,,,,,,,,,,,,, -1600,2.986333,5.683003,,,,,,,,,,,,,, -1700,2.4224713,5.621425,,,,,,,,,,,,,, -1800,2.1792803,5.617962,,,,,,,,,,,,,, -1900,3.7523518,5.605742,,,,,,,,,,,,,, -2000,2.9233224,5.4698663,,,,,,,,,,,,,, -2100,3.2221377,5.392377,,,,,,,,,,,,,, -2200,3.0470195,5.38864,,,,,,,,,,,,,, -2300,2.2659914,5.349083,,,,,,,,,,,,,, -2400,3.0803778,5.285293,,,,,,,,,,,,,, -2500,3.0570996,5.246422,,,,,,,,,,,,,, -2600,4.283741,5.1367836,,,,,,,,,,,,,, -2700,3.3854718,5.2742124,,,,,,,,,,,,,, -2800,7.628649,5.1225457,,,,,,,,,,,,,, -2900,3.6394553,5.0128913,,,,,,,,,,,,,, -2997,,,0.1820392161607742,4.255775451660156,0.1641199886798858,4.366084098815918,50000.0,0.1221000030636787,4.7546820640563965,10000.0,1075.7997715473175,1149.57959151268,1075.7997715473175,73.62649869918823,0.0525565147399902,0.0 -3000,5.5909386,5.02162,,,,,,,,,,,,,, -3100,3.700306,4.976178,,,,,,,,,,,,,, -3200,3.254268,4.9345474,,,,,,,,,,,,,, -3300,3.630817,4.990374,,,,,,,,,,,,,, -3400,4.804769,4.8660145,,,,,,,,,,,,,, -3500,4.3506594,4.8328443,,,,,,,,,,,,,, -3600,7.8702106,4.796962,,,,,,,,,,,,,, -3700,5.033083,4.672513,,,,,,,,,,,,,, -3800,6.208363,4.695699,,,,,,,,,,,,,, -3900,3.6737309,4.603224,,,,,,,,,,,,,, -4000,4.466648,4.6066256,,,,,,,,,,,,,, -4100,6.354639,4.523751,,,,,,,,,,,,,, -4200,8.9622965,4.394169,,,,,,,,,,,,,, -4300,3.1492898,4.4904265,,,,,,,,,,,,,, -4400,4.5975266,4.448415,,,,,,,,,,,,,, -4495,,,0.2831034660339355,3.511129856109619,0.2563399970531463,3.65845799446106,50000.0,0.1856000125408172,4.202426910400391,10000.0,1585.9234120845797,1677.6729209423063,1585.9234120845797,91.51947140693665,0.0799648761749267,0.0 -4500,3.8912964,4.474342,,,,,,,,,,,,,, -4600,4.899651,4.376626,,,,,,,,,,,,,, -4700,4.573537,4.3931084,,,,,,,,,,,,,, -4800,5.238992,4.4457855,,,,,,,,,,,,,, -4900,4.68575,4.2849693,,,,,,,,,,,,,, -5000,5.1915555,4.3768926,,,,,,,,,,,,,, -5100,5.9542756,4.1997614,,,,,,,,,,,,,, -5200,4.917846,4.318222,,,,,,,,,,,,,, -5300,5.4439483,4.0484943,,,,,,,,,,,,,, -5400,5.187271,4.1956367,,,,,,,,,,,,,, -5500,7.430348,4.020742,,,,,,,,,,,,,, -5600,7.049895,4.044888,,,,,,,,,,,,,, -5700,6.8233852,3.9578295,,,,,,,,,,,,,, -5800,4.91487,3.9352098,,,,,,,,,,,,,, -5900,5.3297853,4.0581717,,,,,,,,,,,,,, -5993,,,0.3771125674247741,2.901236057281494,0.3469599783420563,3.06270432472229,50000.0,0.2578000128269195,3.696389198303223,10000.0,2095.983659267425,2205.7152502536774,2095.983659267425,109.42569613456726,0.1053483486175537,0.0 -6000,6.6367254,3.9067762,,,,,,,,,,,,,, -6100,5.9900694,3.95218,,,,,,,,,,,,,, -6200,5.7387576,4.0093684,,,,,,,,,,,,,, -6300,6.8290114,3.9513187,,,,,,,,,,,,,, -6400,4.5813923,3.8905196,,,,,,,,,,,,,, -6500,4.1918626,3.8161755,,,,,,,,,,,,,, -6600,4.686833,3.7463613,,,,,,,,,,,,,, -6700,4.748834,3.81692,,,,,,,,,,,,,, -6800,5.0226345,3.7394555,,,,,,,,,,,,,, -6900,4.4025073,3.744476,,,,,,,,,,,,,, -7000,5.854138,3.817212,,,,,,,,,,,,,, -7100,5.966042,3.6761513,,,,,,,,,,,,,, -7200,7.6808143,3.655022,,,,,,,,,,,,,, -7300,5.3452816,3.6928868,,,,,,,,,,,,,, -7400,4.2427015,3.5615838,,,,,,,,,,,,,, -7491,,,0.437220960855484,2.591935873031616,0.4038199782371521,2.763052701950073,50000.0,0.3150000274181366,3.4180727005004883,10000.0,2605.98605966568,2733.749913215637,2605.98605966568,127.3788537979126,0.1332588195800781,0.0 -7500,5.3935933,3.717878,,,,,,,,,,,,,, -7600,4.7973495,3.5542169,,,,,,,,,,,,,, -7700,4.926175,3.6026375,,,,,,,,,,,,,, -7800,3.9583874,3.5876102,,,,,,,,,,,,,, -7900,5.6795497,3.6172276,,,,,,,,,,,,,, -8000,6.0452037,3.5798118,,,,,,,,,,,,,, -8100,6.405931,3.5874014,,,,,,,,,,,,,, -8200,4.0081406,3.4499981,,,,,,,,,,,,,, -8300,5.4633727,3.4598165,,,,,,,,,,,,,, -8400,4.5367274,3.4965422,,,,,,,,,,,,,, -8500,4.6184726,3.587187,,,,,,,,,,,,,, -8600,4.7653813,3.5143485,,,,,,,,,,,,,, -8700,4.6300044,3.4724007,,,,,,,,,,,,,, -8800,5.222556,3.5324595,,,,,,,,,,,,,, -8900,3.8463595,3.3737469,,,,,,,,,,,,,, -8990,,,0.4954759180545807,2.285950183868408,0.4610599875450134,2.458083391189575,50000.0,0.3545000255107879,3.1170003414154053,10000.0,3116.0023124217987,3262.0570573806763,3116.0023124217987,145.59150886535645,0.1610162258148193,0.0 -9000,6.7542505,3.5909226,,,,,,,,,,,,,, -9100,6.751666,3.380505,,,,,,,,,,,,,, -9200,7.7296114,3.4928126,,,,,,,,,,,,,, -9300,4.3553143,3.3975487,,,,,,,,,,,,,, -9400,5.0628295,3.33076,,,,,,,,,,,,,, -9500,3.5728605,3.370016,,,,,,,,,,,,,, -9600,4.4201236,3.3045752,,,,,,,,,,,,,, -9700,4.9362965,3.327145,,,,,,,,,,,,,, -9800,4.9364123,3.318768,,,,,,,,,,,,,, -9900,6.2832375,3.328612,,,,,,,,,,,,,, -10000,5.1094275,3.3304152,,,,,,,,,,,,,, -10100,5.1020675,3.3769653,,,,,,,,,,,,,, -10200,4.1643806,3.2862182,,,,,,,,,,,,,, -10300,4.9842787,3.3384924,,,,,,,,,,,,,, -10400,3.1255376,3.2441988,,,,,,,,,,,,,, -10489,,,0.5600087642669678,1.9418503046035769,0.4983799755573272,2.250006437301636,50000.0,0.3785000145435333,2.9464919567108154,10000.0,3626.022510528565,3794.081609249115,3626.022510528565,167.5162229537964,0.1880600452423095,0.0 -10500,8.269349,3.3153965,,,,,,,,,,,,,, -10600,5.054336,3.2413354,,,,,,,,,,,,,, -10700,4.125519,3.2327852,,,,,,,,,,,,,, -10800,10.041556,3.2572556,,,,,,,,,,,,,, -10900,5.5357656,3.3797948,,,,,,,,,,,,,, -11000,6.0640607,3.3256128,,,,,,,,,,,,,, -11100,4.2071404,3.1474967,,,,,,,,,,,,,, -11200,4.337064,3.208059,,,,,,,,,,,,,, -11300,5.387247,3.1975517,,,,,,,,,,,,,, -11400,5.4184794,3.231736,,,,,,,,,,,,,, -11500,7.789033,3.293092,,,,,,,,,,,,,, -11600,4.003442,3.246879,,,,,,,,,,,,,, -11700,3.3639576,3.21915,,,,,,,,,,,,,, -11800,6.1011453,3.2638054,,,,,,,,,,,,,, -11900,4.529276,3.2411873,,,,,,,,,,,,,, -11990,,,0.5775271058082581,1.87908947467804,0.5301399827003479,2.1053919792175293,50000.0,0.406900018453598,2.8000824451446533,10000.0,4136.248651504517,4322.467486619949,4136.248651504517,185.5975196361541,0.2156145572662353,0.0 -12000,5.2307444,3.1757665,,,,,,,,,,,,,, -12100,3.9678001,3.0888777,,,,,,,,,,,,,, -12200,4.156247,3.1592574,,,,,,,,,,,,,, -12300,5.7392693,3.1920257,,,,,,,,,,,,,, -12400,4.347346,3.1181715,,,,,,,,,,,,,, -12500,8.349034,3.1651964,,,,,,,,,,,,,, -12600,5.5012145,3.2390628,,,,,,,,,,,,,, -12700,4.464661,3.1401784,,,,,,,,,,,,,, -12800,5.1719337,3.1055408,,,,,,,,,,,,,, -12900,3.7153447,3.0984492,,,,,,,,,,,,,, -13000,5.4226537,3.0858302,,,,,,,,,,,,,, -13100,4.283851,3.1196523,,,,,,,,,,,,,, -13200,6.7795944,3.1395285,,,,,,,,,,,,,, -13300,5.451036,3.1855538,,,,,,,,,,,,,, -13400,5.245478,3.152219,,,,,,,,,,,,,, -13491,,,0.5838648080825806,1.834944128990173,0.5408999919891357,2.040344953536988,50000.0,0.424200028181076,2.7253708839416504,10000.0,4646.346125364304,4850.625118970871,4646.346125364304,203.5793845653534,0.2452063560485839,0.0 -13500,5.219427,3.1473043,,,,,,,,,,,,,, -13600,4.931342,3.0503547,,,,,,,,,,,,,, -13700,3.5744646,3.003294,,,,,,,,,,,,,, -13800,5.1450543,3.124593,,,,,,,,,,,,,, -13900,4.5329943,3.0641482,,,,,,,,,,,,,, -14000,5.5560203,3.0579343,,,,,,,,,,,,,, -14100,5.722156,3.073614,,,,,,,,,,,,,, -14200,5.3661447,3.0542681,,,,,,,,,,,,,, -14300,5.7576284,3.1413329,,,,,,,,,,,,,, -14400,6.116804,3.0402446,,,,,,,,,,,,,, -14500,4.5929837,3.0694804,,,,,,,,,,,,,, -14600,5.250259,3.013173,,,,,,,,,,,,,, -14700,5.3250422,2.9638948,,,,,,,,,,,,,, -14800,4.325378,3.0798094,,,,,,,,,,,,,, -14900,6.493981,3.0476575,,,,,,,,,,,,,, -14992,,,0.5969586968421936,1.7865535020828247,0.5550999641418457,1.9808624982833865,50000.0,0.4354000091552734,2.653869152069092,10000.0,5156.277529478073,5378.708662033081,5156.277529478073,221.64778184890747,0.2781078815460205,0.0 -15000,6.1884203,3.0100176,,,,,,,,,,,,,, -15100,3.8804014,3.1494937,,,,,,,,,,,,,, -15200,3.1632464,3.1080017,,,,,,,,,,,,,, -15300,3.432972,3.0921519,,,,,,,,,,,,,, -15400,8.397527,3.116691,,,,,,,,,,,,,, -15500,3.3731635,3.063734,,,,,,,,,,,,,, -15600,4.5442576,3.0752912,,,,,,,,,,,,,, -15700,5.4307175,2.9739249,,,,,,,,,,,,,, -15800,3.1713834,2.9699779,,,,,,,,,,,,,, -15900,6.88895,3.0347629,,,,,,,,,,,,,, -16000,3.6542752,3.0912697,,,,,,,,,,,,,, -16100,4.2554383,3.017363,,,,,,,,,,,,,, -16200,5.82598,3.0692334,,,,,,,,,,,,,, -16300,4.1791124,3.0000799,,,,,,,,,,,,,, -16400,6.248292,3.146814,,,,,,,,,,,,,, -16495,,,0.5969586968421936,1.7959184646606443,0.5592600107192993,1.978746771812439,50000.0,0.4358000159263611,2.657799005508423,10000.0,5666.510581970215,5907.21281003952,5666.510581970215,239.83745956420896,0.3100862503051758,0.0 -16500,5.795391,3.0095716,,,,,,,,,,,,,, -16600,7.063298,3.0591846,,,,,,,,,,,,,, -16700,4.8310785,2.9772503,,,,,,,,,,,,,, -16800,5.165568,2.9586127,,,,,,,,,,,,,, -16900,9.23054,3.1029673,,,,,,,,,,,,,, -17000,3.4611833,2.9971447,,,,,,,,,,,,,, -17100,5.840898,2.9403558,,,,,,,,,,,,,, -17200,3.9219656,3.0277605,,,,,,,,,,,,,, -17300,3.5136216,2.9868298,,,,,,,,,,,,,, -17400,4.3068705,2.9916196,,,,,,,,,,,,,, -17500,4.217706,3.0172803,,,,,,,,,,,,,, -17600,3.7080498,3.0434382,,,,,,,,,,,,,, -17700,3.8986638,3.0322378,,,,,,,,,,,,,, -17800,4.3481975,3.113626,,,,,,,,,,,,,, -17900,3.4350026,2.8943949,,,,,,,,,,,,,, -17998,,,0.5971181392669678,1.7574492692947388,0.5580799579620361,1.9484416246414185,50000.0,0.4360000193119049,2.6286208629608154,10000.0,6176.487642765045,6436.447620630264,6176.487642765045,259.0123672485352,0.342235803604126,0.0 -18000,3.51195,2.992804,,,,,,,,,,,,,, -18100,3.976423,3.0350118,,,,,,,,,,,,,, -18200,4.975228,3.0568662,,,,,,,,,,,,,, -18300,6.4552445,3.0256205,,,,,,,,,,,,,, -18400,3.662536,2.9153576,,,,,,,,,,,,,, -18500,5.3109307,3.0346644,,,,,,,,,,,,,, -18600,6.2302527,3.0500517,,,,,,,,,,,,,, -18700,4.591785,2.994905,,,,,,,,,,,,,, -18800,3.174846,3.0093892,,,,,,,,,,,,,, -18900,3.1101239,2.960435,,,,,,,,,,,,,, -19000,3.5039408,2.9595263,,,,,,,,,,,,,, -19100,3.9968781,2.993183,,,,,,,,,,,,,, -19200,4.4086986,3.0612411,,,,,,,,,,,,,, -19300,4.236629,3.0351386,,,,,,,,,,,,,, -19400,3.0577407,2.9862416,,,,,,,,,,,,,, -19500,3.3918722,2.915894,,,,,,,,,,,,,, -19501,,,0.6540975570678711,1.5059301853179932,0.5709599852561951,1.882540702819824,50000.0,0.4470000267028808,2.5772900581359863,10000.0,6686.56877040863,6965.261632204056,6686.56877040863,277.65295243263245,0.3832590579986572,0.0 -19600,3.1786337,2.9430852,,,,,,,,,,,,,, -19700,4.2428026,2.9310389,,,,,,,,,,,,,, -19800,4.147324,2.9615848,,,,,,,,,,,,,, -19900,3.1059368,3.1272788,,,,,,,,,,,,,, -20000,3.304784,2.9267197,,,,,,,,,,,,,, -20100,3.8278248,2.9717312,,,,,,,,,,,,,, -20200,4.78651,2.9898953,,,,,,,,,,,,,, -20300,3.7883599,3.0087452,,,,,,,,,,,,,, -20400,3.2107608,3.071605,,,,,,,,,,,,,, -20500,2.9802377,2.9920526,,,,,,,,,,,,,, -20600,3.247715,3.0232725,,,,,,,,,,,,,, -20700,4.445465,2.9217176,,,,,,,,,,,,,, -20800,3.424194,2.9553142,,,,,,,,,,,,,, -20900,4.12482,2.9457078,,,,,,,,,,,,,, -21000,3.9748573,2.9837625,,,,,,,,,,,,,, -21004,,,0.6272919178009033,1.6106610298156738,0.5689799785614014,1.8841341733932493,50000.0,0.4496000111103058,2.5562584400177,10000.0,7196.738107442856,7497.752900362015,7196.738107442856,299.892765045166,0.415421724319458,0.0 -21100,3.2318387,2.9482195,,,,,,,,,,,,,, -21200,4.115873,2.9153666,,,,,,,,,,,,,, -21300,4.0700088,2.8843603,,,,,,,,,,,,,, -21400,3.0058126,2.966806,,,,,,,,,,,,,, -21500,4.125378,2.9812288,,,,,,,,,,,,,, -21600,3.2393928,2.9482033,,,,,,,,,,,,,, -21700,4.925283,2.8631625,,,,,,,,,,,,,, -21800,3.9307768,2.913217,,,,,,,,,,,,,, -21900,3.759375,2.9036725,,,,,,,,,,,,,, -22000,4.0400143,2.9529264,,,,,,,,,,,,,, -22100,2.9850526,2.996399,,,,,,,,,,,,,, -22200,2.7430134,2.9072196,,,,,,,,,,,,,, -22300,3.3009465,2.8619497,,,,,,,,,,,,,, -22400,3.0347168,2.9557145,,,,,,,,,,,,,, -22500,3.619329,2.8716109,,,,,,,,,,,,,, -22508,,,0.6285873651504517,1.606850266456604,0.5818799734115601,1.833842158317566,50000.0,0.4611000120639801,2.4872632026672363,10000.0,7706.86282658577,8029.213074207306,7706.86282658577,321.13594365119934,0.4556243419647217,0.0 -22600,3.375378,2.893262,,,,,,,,,,,,,, -22700,3.0974483,2.907404,,,,,,,,,,,,,, -22800,3.2981484,2.885503,,,,,,,,,,,,,, -22900,4.1202216,2.8860931,,,,,,,,,,,,,, -23000,3.2609866,2.879724,,,,,,,,,,,,,, -23100,4.2564607,2.9046679,,,,,,,,,,,,,, -23200,3.2517922,2.9011812,,,,,,,,,,,,,, -23300,3.465652,3.0206375,,,,,,,,,,,,,, -23400,2.8529804,2.952785,,,,,,,,,,,,,, -23500,3.9439645,3.0064244,,,,,,,,,,,,,, -23600,3.884456,2.9431167,,,,,,,,,,,,,, -23700,3.3093665,2.916021,,,,,,,,,,,,,, -23800,3.1850278,2.9947357,,,,,,,,,,,,,, -23900,3.0097837,2.877043,,,,,,,,,,,,,, -24000,2.9352393,2.938858,,,,,,,,,,,,,, -24012,,,0.6264548897743225,1.6274893283843994,0.5815799832344055,1.8446393013000488,50000.0,0.4627000093460083,2.5038511753082275,10000.0,8216.786780834198,8561.249808549881,8216.786780834198,343.1451985836029,0.5067691802978516,0.0 -24100,3.2391338,2.9008124,,,,,,,,,,,,,, -24200,3.0210462,2.8923087,,,,,,,,,,,,,, -24300,3.5271711,2.832199,,,,,,,,,,,,,, -24400,3.2007055,2.91564,,,,,,,,,,,,,, -24500,3.6335151,2.9414554,,,,,,,,,,,,,, -24600,3.445041,2.9094949,,,,,,,,,,,,,, -24700,3.9962652,2.8492758,,,,,,,,,,,,,, -24800,3.6986437,2.8491838,,,,,,,,,,,,,, -24900,3.358724,2.9092607,,,,,,,,,,,,,, -25000,3.5937927,2.8835866,,,,,,,,,,,,,, -25100,3.6376789,2.843048,,,,,,,,,,,,,, -25200,2.8856735,2.8316393,,,,,,,,,,,,,, -25300,3.9935257,2.8364763,,,,,,,,,,,,,, -25400,3.0333855,2.8771217,,,,,,,,,,,,,, -25500,2.992243,2.8692436,,,,,,,,,,,,,, -25516,,,0.631277859210968,1.6501520872116089,0.5895199775695801,1.8403079509735107,50000.0,0.4686000347137451,2.4981985092163086,10000.0,8726.874703884125,9093.15644454956,8726.874703884125,364.8770282268524,0.5437021255493164,0.0 -25600,4.066325,2.8273964,,,,,,,,,,,,,, -25700,3.4843206,2.8709533,,,,,,,,,,,,,, -25800,3.0223486,2.850499,,,,,,,,,,,,,, -25900,3.6563199,2.9467287,,,,,,,,,,,,,, -26000,3.5121033,2.8395977,,,,,,,,,,,,,, -26100,2.8082542,2.8270187,,,,,,,,,,,,,, -26200,3.0156286,2.9027314,,,,,,,,,,,,,, -26300,2.9198422,2.811417,,,,,,,,,,,,,, -26400,2.9762723,2.8081565,,,,,,,,,,,,,, -26500,3.3441308,2.8342693,,,,,,,,,,,,,, -26600,3.4426544,2.7839098,,,,,,,,,,,,,, -26700,3.6893928,2.896978,,,,,,,,,,,,,, -26800,2.9897327,2.9220505,,,,,,,,,,,,,, -26900,2.932816,2.8973877,,,,,,,,,,,,,, -27000,3.6023276,2.8469727,,,,,,,,,,,,,, -27021,,,0.6287069320678711,1.6057114601135254,0.5907599925994873,1.7863342761993408,50000.0,0.4698000252246856,2.4951579570770264,10000.0,9237.083882570269,9625.399262428284,9237.083882570269,386.8250911235809,0.5786194801330566,0.0 -27100,3.2312288,2.802253,,,,,,,,,,,,,, -27200,2.663928,2.8043456,,,,,,,,,,,,,, -27300,3.3744233,2.7199318,,,,,,,,,,,,,, -27400,5.672577,2.9054232,,,,,,,,,,,,,, -27500,2.752677,2.8442163,,,,,,,,,,,,,, -27600,2.5813262,2.8487897,,,,,,,,,,,,,, -27700,3.4026656,2.8462262,,,,,,,,,,,,,, -27800,2.8925688,2.9003637,,,,,,,,,,,,,, -27900,2.8204055,2.7987442,,,,,,,,,,,,,, -28000,3.1684124,2.8287518,,,,,,,,,,,,,, -28100,2.743618,2.869482,,,,,,,,,,,,,, -28200,3.0900838,2.8609166,,,,,,,,,,,,,, -28300,2.8263373,2.835995,,,,,,,,,,,,,, -28400,4.048345,2.8210154,,,,,,,,,,,,,, -28500,3.180586,2.7814212,,,,,,,,,,,,,, -28524,,,0.683992326259613,1.3645371198654177,0.597000002861023,1.757430911064148,50000.0,0.473000019788742,2.4163501262664795,10000.0,9746.664439201357,10157.818138360975,9746.664439201357,409.0597383975983,1.13085675239563,0.0 -28600,3.3425956,2.879921,,,,,,,,,,,,,, -28700,3.4789507,2.8355265,,,,,,,,,,,,,, -28800,3.2795541,2.8443787,,,,,,,,,,,,,, -28900,3.3601718,2.7620041,,,,,,,,,,,,,, -29000,3.354407,2.833785,,,,,,,,,,,,,, -29100,3.4684324,2.8936253,,,,,,,,,,,,,, -29200,3.8011637,2.9048705,,,,,,,,,,,,,, -29300,3.7360606,2.960565,,,,,,,,,,,,,, -29400,2.8231132,2.9023328,,,,,,,,,,,,,, -29500,3.581827,2.814722,,,,,,,,,,,,,, -29600,3.1338146,2.9398274,,,,,,,,,,,,,, -29700,2.8421197,2.8183317,,,,,,,,,,,,,, -29800,3.5378556,2.9369485,,,,,,,,,,,,,, -29900,3.17805,2.7990766,,,,,,,,,,,,,, -30000,3.1682272,2.8548017,,,,,,,,,,,,,, -30028,,,0.6511877775192261,1.5284302234649658,0.5967199802398682,1.7889021635055542,50000.0,0.4752000272274017,2.4623258113861084,10000.0,10256.616863250732,10690.81227874756,10256.616863250732,432.0158112049103,1.1649293899536133,0.0 -30100,3.4696634,2.7864678,,,,,,,,,,,,,, -30200,4.0415974,2.8311555,,,,,,,,,,,,,, -30300,3.083816,2.8921432,,,,,,,,,,,,,, -30400,2.9443395,2.8860292,,,,,,,,,,,,,, -30500,3.095975,2.9705677,,,,,,,,,,,,,, -30600,3.5720935,2.8000073,,,,,,,,,,,,,, -30700,3.4148407,2.796964,,,,,,,,,,,,,, -30800,2.9598207,2.8898573,,,,,,,,,,,,,, -30900,3.7788708,2.8770006,,,,,,,,,,,,,, -31000,3.1594622,2.8005762,,,,,,,,,,,,,, -31100,3.4404266,2.7901163,,,,,,,,,,,,,, -31200,3.4841003,2.8438735,,,,,,,,,,,,,, -31300,3.7896185,2.8322551,,,,,,,,,,,,,, -31400,3.3610754,2.8295193,,,,,,,,,,,,,, -31500,2.8408802,2.7691627,,,,,,,,,,,,,, -31532,,,0.6303810477256775,1.6111574172973633,0.5847600102424622,1.82930588722229,50000.0,0.4620000123977661,2.532831907272339,10000.0,10766.595863342283,11225.963517665865,10766.595863342283,457.0925896167755,1.210730791091919,0.0 -31600,2.8902633,2.7673237,,,,,,,,,,,,,, -31700,3.1583886,2.7803361,,,,,,,,,,,,,, -31800,3.4384155,2.8739953,,,,,,,,,,,,,, -31900,3.3530324,2.791605,,,,,,,,,,,,,, -32000,3.493212,2.8718116,,,,,,,,,,,,,, -32100,3.4699678,2.8192413,,,,,,,,,,,,,, -32200,3.0574183,2.738561,,,,,,,,,,,,,, -32300,2.6888816,2.7400725,,,,,,,,,,,,,, -32400,2.6125755,2.8221865,,,,,,,,,,,,,, -32500,2.9233177,2.8761308,,,,,,,,,,,,,, -32600,3.1957152,2.8479977,,,,,,,,,,,,,, -32700,3.1279287,2.7823548,,,,,,,,,,,,,, -32800,3.318465,2.8720121,,,,,,,,,,,,,, -32900,3.1264071,2.7943928,,,,,,,,,,,,,, -33000,3.0162473,2.8026805,,,,,,,,,,,,,, -33037,,,0.649832546710968,1.484278440475464,0.6013000011444092,1.7228082418441772,50000.0,0.4825000166893005,2.373618125915528,10000.0,11276.758833408356,11759.353006362917,11276.758833408356,480.23394536972046,1.2460718154907229,0.0 -33100,3.0423567,2.7850885,,,,,,,,,,,,,, -33200,3.1011927,2.6978593,,,,,,,,,,,,,, -33300,3.4322815,2.8079927,,,,,,,,,,,,,, -33400,3.101426,2.8581848,,,,,,,,,,,,,, -33500,3.3832126,2.802128,,,,,,,,,,,,,, -33600,3.406823,2.805348,,,,,,,,,,,,,, -33700,3.1041148,2.816572,,,,,,,,,,,,,, -33800,3.1273699,2.9573395,,,,,,,,,,,,,, -33900,3.0299633,2.7710788,,,,,,,,,,,,,, -34000,3.0990407,2.8107967,,,,,,,,,,,,,, -34100,3.3146107,2.7799854,,,,,,,,,,,,,, -34200,3.3363757,2.8259602,,,,,,,,,,,,,, -34300,3.2661004,2.7967525,,,,,,,,,,,,,, -34400,3.7155735,2.796628,,,,,,,,,,,,,, -34500,4.643645,2.819765,,,,,,,,,,,,,, -34542,,,0.6471221446990967,1.524101734161377,0.602180004119873,1.7275532484054563,50000.0,0.4793000221252441,2.3868050575256348,10000.0,11786.717982053757,12292.46637749672,11786.717982053757,503.3027272224426,1.2807552814483645,0.0 -34600,3.0183144,2.745928,,,,,,,,,,,,,, -34700,3.1578217,2.8098185,,,,,,,,,,,,,, -34800,2.9423952,2.7810216,,,,,,,,,,,,,, -34900,2.8459804,2.7964697,,,,,,,,,,,,,, -35000,3.9080472,2.8085241,,,,,,,,,,,,,, -35100,4.754564,2.6060178,,,,,,,,,,,,,, -35200,3.1341417,2.7583523,,,,,,,,,,,,,, -35300,3.1955297,2.8031106,,,,,,,,,,,,,, -35400,3.4235191,2.7658062,,,,,,,,,,,,,, -35500,2.9749937,2.829652,,,,,,,,,,,,,, -35600,2.8791146,2.7600946,,,,,,,,,,,,,, -35700,3.0011823,2.7764683,,,,,,,,,,,,,, -35800,3.4330661,2.7901244,,,,,,,,,,,,,, -35900,3.1522539,2.7551723,,,,,,,,,,,,,, -36000,3.569598,2.7179031,,,,,,,,,,,,,, -36047,,,0.6511877775192261,1.5404627323150637,0.6087799668312073,1.7445299625396729,50000.0,0.4851000308990478,2.3932998180389404,10000.0,12296.627568244934,12825.880003213882,12296.627568244934,526.7157201766968,1.3211784362792969,0.0 -36100,2.9600766,2.8549047,,,,,,,,,,,,,, -36200,3.2378125,2.7352877,,,,,,,,,,,,,, -36300,2.986042,2.803756,,,,,,,,,,,,,, -36400,3.5902,2.8033867,,,,,,,,,,,,,, -36500,3.5558107,2.8130922,,,,,,,,,,,,,, -36600,3.9430726,2.7726495,,,,,,,,,,,,,, -36700,3.2345524,2.741488,,,,,,,,,,,,,, -36800,3.2150605,2.8179765,,,,,,,,,,,,,, -36900,3.0903404,2.7568343,,,,,,,,,,,,,, -37000,3.271841,2.88902,,,,,,,,,,,,,, -37100,3.802726,2.8329718,,,,,,,,,,,,,, -37200,3.4248657,2.791407,,,,,,,,,,,,,, -37300,3.0704477,2.8223116,,,,,,,,,,,,,, -37400,3.0390165,2.824061,,,,,,,,,,,,,, -37500,2.9949172,2.820707,,,,,,,,,,,,,, -37552,,,0.6801658272743225,1.3729093074798584,0.5968199968338013,1.7451149225234983,50000.0,0.4791000187397003,2.4134533405303955,10000.0,12806.58596277237,13357.748576164246,12806.58596277237,548.5404841899872,1.3562448024749756,0.0 -37600,2.6902578,2.8064027,,,,,,,,,,,,,, -37700,3.5796988,2.785419,,,,,,,,,,,,,, -37800,2.829611,2.8093004,,,,,,,,,,,,,, -37900,3.50845,2.7903821,,,,,,,,,,,,,, -38000,2.8055387,2.7904534,,,,,,,,,,,,,, -38100,3.0461545,2.7834535,,,,,,,,,,,,,, -38200,3.616104,2.770663,,,,,,,,,,,,,, -38300,2.8959146,2.6969259,,,,,,,,,,,,,, -38400,3.5657485,2.7667818,,,,,,,,,,,,,, -38500,3.1001794,2.7513175,,,,,,,,,,,,,, -38600,3.706677,2.7713592,,,,,,,,,,,,,, -38700,3.6184661,2.7687478,,,,,,,,,,,,,, -38800,3.0611134,2.8277345,,,,,,,,,,,,,, -38900,2.9526803,2.7208202,,,,,,,,,,,,,, -39000,3.09327,2.8772814,,,,,,,,,,,,,, -39057,,,0.6672711968421936,1.4364635944366455,0.604640007019043,1.7295138835906982,50000.0,0.4818000197410583,2.383208036422729,10000.0,13316.641247034073,13888.769673347471,13316.641247034073,569.4229211807251,1.388034105300903,0.0 -39100,3.7969193,2.822278,,,,,,,,,,,,,, -39200,2.823403,2.7256937,,,,,,,,,,,,,, -39300,3.0798354,2.8387887,,,,,,,,,,,,,, -39400,3.07839,2.701367,,,,,,,,,,,,,, -39500,3.5749247,2.8278322,,,,,,,,,,,,,, -39600,2.921279,2.720987,,,,,,,,,,,,,, -39700,3.475524,2.7859473,,,,,,,,,,,,,, -39800,2.845171,2.7471204,,,,,,,,,,,,,, -39900,3.2869928,2.7427018,,,,,,,,,,,,,, -40000,2.9858603,2.6943626,,,,,,,,,,,,,, -40100,3.5626924,2.7531152,,,,,,,,,,,,,, -40200,2.8724897,2.7669923,,,,,,,,,,,,,, -40300,2.8510869,2.7542522,,,,,,,,,,,,,, -40400,3.351783,2.8205016,,,,,,,,,,,,,, -40500,3.224463,2.6936436,,,,,,,,,,,,,, -40562,,,0.6653180718421936,1.4576128721237185,0.6128999590873718,1.690018653869629,50000.0,0.490200012922287,2.349595308303833,10000.0,13826.662028312683,14419.336757183077,13826.662028312683,589.8896474838257,1.4172828197479248,0.0 -40600,2.9057813,2.7494316,,,,,,,,,,,,,, -40700,3.2706406,2.8139262,,,,,,,,,,,,,, -40800,3.4896483,2.8289244,,,,,,,,,,,,,, -40900,3.2969809,2.8624315,,,,,,,,,,,,,, -41000,3.4153326,2.7160506,,,,,,,,,,,,,, -41100,3.0606034,2.8466687,,,,,,,,,,,,,, -41200,3.1203532,2.8016589,,,,,,,,,,,,,, -41300,3.2352228,2.8121862,,,,,,,,,,,,,, -41400,3.918148,2.8785517,,,,,,,,,,,,,, -41500,2.9979799,2.6898375,,,,,,,,,,,,,, -41600,2.814266,2.7918074,,,,,,,,,,,,,, -41700,2.882938,2.8073967,,,,,,,,,,,,,, -41800,3.294826,2.711154,,,,,,,,,,,,,, -41900,2.7506216,2.855495,,,,,,,,,,,,,, -42000,3.4303036,2.7707076,,,,,,,,,,,,,, -42067,,,0.6569873690605164,1.500503659248352,0.608460009098053,1.7169100046157837,50000.0,0.4836000204086303,2.381596803665161,10000.0,14336.611745595932,14949.84656405449,14336.611745595932,610.3627202510834,1.4521598815917969,0.0 -42100,3.0578058,2.81783,,,,,,,,,,,,,, -42200,3.780043,2.8753014,,,,,,,,,,,,,, -42300,2.8371043,2.729206,,,,,,,,,,,,,, -42400,3.0203717,2.8042617,,,,,,,,,,,,,, -42500,3.2604115,2.7630653,,,,,,,,,,,,,, -42600,4.0821605,2.778904,,,,,,,,,,,,,, -42700,3.5586357,2.8010998,,,,,,,,,,,,,, -42800,2.9240654,2.787487,,,,,,,,,,,,,, -42900,3.0570488,2.7780092,,,,,,,,,,,,,, -43000,3.1334426,2.6995537,,,,,,,,,,,,,, -43100,2.6922004,2.7363977,,,,,,,,,,,,,, -43200,3.057535,2.7088237,,,,,,,,,,,,,, -43300,3.1225965,2.6819682,,,,,,,,,,,,,, -43400,3.647727,2.8473039,,,,,,,,,,,,,, -43500,3.374798,2.8002083,,,,,,,,,,,,,, -43573,,,0.6497528553009033,1.5108873844146729,0.6085599660873413,1.701366662979126,50000.0,0.4910000264644623,2.3547487258911133,10000.0,14846.80584168434,15482.147728919985,14846.80584168434,632.3841454982758,1.4866185188293457,0.0 -43600,3.0432482,2.6721497,,,,,,,,,,,,,, -43700,3.1686876,2.746869,,,,,,,,,,,,,, -43800,3.8799665,2.8466191,,,,,,,,,,,,,, -43900,3.1845095,2.825036,,,,,,,,,,,,,, -44000,3.024974,2.7648816,,,,,,,,,,,,,, -44100,3.2057316,2.8151054,,,,,,,,,,,,,, -44200,3.4930038,2.7705195,,,,,,,,,,,,,, -44300,3.4616709,2.7751808,,,,,,,,,,,,,, -44400,2.7811525,2.774958,,,,,,,,,,,,,, -44500,3.31809,2.7424753,,,,,,,,,,,,,, -44600,3.6968052,2.7859178,,,,,,,,,,,,,, -44700,2.6519117,2.6236396,,,,,,,,,,,,,, -44800,3.8330684,2.755051,,,,,,,,,,,,,, -44900,2.9545338,2.7153437,,,,,,,,,,,,,, -45000,3.5698152,2.7833176,,,,,,,,,,,,,, -45079,,,0.6457469463348389,1.5189896821975708,0.6038999557495117,1.721349000930786,50000.0,0.4716000258922577,2.438140869140625,10000.0,15357.030562639236,16011.586278438568,15357.030562639236,651.5147202014923,1.5189547538757324,0.0 -45100,2.875826,2.818335,,,,,,,,,,,,,, -45200,3.7849333,2.8432126,,,,,,,,,,,,,, -45300,3.6620903,2.755006,,,,,,,,,,,,,, -45400,3.1803775,2.8335073,,,,,,,,,,,,,, -45500,3.198368,2.7842498,,,,,,,,,,,,,, -45600,3.3365638,2.7480235,,,,,,,,,,,,,, -45700,3.5273714,2.7197773,,,,,,,,,,,,,, -45800,3.0030305,2.7615957,,,,,,,,,,,,,, -45900,3.1207438,2.6469064,,,,,,,,,,,,,, -46000,3.654905,2.76218,,,,,,,,,,,,,, -46100,3.0336068,2.761457,,,,,,,,,,,,,, -46200,3.0509079,2.8799336,,,,,,,,,,,,,, -46300,4.233463,2.7998588,,,,,,,,,,,,,, -46400,3.114965,2.6970448,,,,,,,,,,,,,, -46500,3.4358823,2.7703397,,,,,,,,,,,,,, -46585,,,0.6690250039100647,1.4443752765655518,0.603119969367981,1.738900899887085,50000.0,0.4827000200748443,2.392597913742065,10000.0,15867.12254691124,16542.223130464554,15867.12254691124,671.9697597026825,1.5562007427215576,0.0 -46600,3.1504123,2.8092747,,,,,,,,,,,,,, -46700,2.9297473,2.7675602,,,,,,,,,,,,,, -46800,2.9922569,2.755706,,,,,,,,,,,,,, -46900,3.3460126,2.7387135,,,,,,,,,,,,,, -47000,2.951792,2.7628145,,,,,,,,,,,,,, -47100,4.154131,2.7946208,,,,,,,,,,,,,, -47200,3.0940707,2.7198946,,,,,,,,,,,,,, -47300,3.8568413,2.7700415,,,,,,,,,,,,,, -47400,2.8910332,2.7686005,,,,,,,,,,,,,, -47500,2.7734787,2.6734428,,,,,,,,,,,,,, -47600,3.4224474,2.7867827,,,,,,,,,,,,,, -47700,3.7067504,2.7960687,,,,,,,,,,,,,, -47800,3.237758,2.8630412,,,,,,,,,,,,,, -47900,2.9942253,2.774319,,,,,,,,,,,,,, -48000,3.1306689,2.7559614,,,,,,,,,,,,,, -48091,,,0.6783920526504517,1.3911480903625488,0.6177600026130676,1.6825976371765137,50000.0,0.4932000339031219,2.372044324874878,10000.0,16377.347776174543,17072.506830453873,16377.347776174543,691.9411878585815,1.5927386283874512,0.0 -48100,3.7944171,2.8641798,,,,,,,,,,,,,, -48200,3.1075456,2.614841,,,,,,,,,,,,,, -48300,3.3892474,2.7523732,,,,,,,,,,,,,, -48400,3.309797,2.7979271,,,,,,,,,,,,,, -48500,2.9260318,2.6241784,,,,,,,,,,,,,, -48600,3.8419085,2.8042495,,,,,,,,,,,,,, -48700,2.948865,2.630016,,,,,,,,,,,,,, -48800,3.1081693,2.7716942,,,,,,,,,,,,,, -48900,3.0565405,2.6577382,,,,,,,,,,,,,, -49000,3.387313,2.7870975,,,,,,,,,,,,,, -49100,2.9539845,2.8099592,,,,,,,,,,,,,, -49200,3.1656218,2.6622105,,,,,,,,,,,,,, -49300,3.8382778,2.739572,,,,,,,,,,,,,, -49400,3.137277,2.6416404,,,,,,,,,,,,,, -49500,4.658723,2.855185,,,,,,,,,,,,,, -49597,,,0.6721141338348389,1.4001306295394895,0.6206600069999695,1.6346431970596311,50000.0,0.5010000467300415,2.282940149307251,10000.0,16887.48326063156,17605.0581843853,16887.48326063156,714.2665731906891,1.6319713592529297,0.0 -49600,3.2451777,2.7198513,,,,,,,,,,,,,, -49700,3.4623322,2.8064566,,,,,,,,,,,,,, -49800,3.0727744,2.665304,,,,,,,,,,,,,, -49900,3.6960435,2.749094,,,,,,,,,,,,,, -50000,3.1937554,2.7230794,,,,,,,,,,,,,, -50100,2.9571743,2.7327547,,,,,,,,,,,,,, -50200,3.3370132,2.8275518,,,,,,,,,,,,,, -50300,3.3871183,2.7422013,,,,,,,,,,,,,, -50400,3.1873872,2.7527056,,,,,,,,,,,,,, -50500,3.2189646,2.709847,,,,,,,,,,,,,, -50600,3.3243208,2.6439483,,,,,,,,,,,,,, -50700,3.6520922,2.7383955,,,,,,,,,,,,,, -50800,2.9211,2.6929364,,,,,,,,,,,,,, -50900,3.967906,2.6215754,,,,,,,,,,,,,, -51000,2.904762,2.7006347,,,,,,,,,,,,,, -51100,3.0718226,2.6831203,,,,,,,,,,,,,, -51104,,,0.6506895422935486,1.5444890260696411,0.6065199971199036,1.7607401609420776,50000.0,0.4819000363349914,2.446437358856201,10000.0,17397.699914216995,18136.358857870106,17397.699914216995,735.2593734264374,1.6691064834594729,0.0 -51200,3.395954,2.7096958,,,,,,,,,,,,,, -51300,3.041735,2.8489208,,,,,,,,,,,,,, -51400,2.9059434,2.7694077,,,,,,,,,,,,,, -51500,3.5145013,2.9454532,,,,,,,,,,,,,, -51600,3.2899332,2.754456,,,,,,,,,,,,,, -51700,3.2821176,2.7523608,,,,,,,,,,,,,, -51800,3.4769278,2.733976,,,,,,,,,,,,,, -51900,3.1007018,2.7479675,,,,,,,,,,,,,, -52000,3.236247,2.7380173,,,,,,,,,,,,,, -52100,4.0109043,2.6980834,,,,,,,,,,,,,, -52200,3.4814003,2.6417685,,,,,,,,,,,,,, -52300,3.7840595,2.632037,,,,,,,,,,,,,, -52400,3.1392248,2.698513,,,,,,,,,,,,,, -52500,3.3598077,2.7259634,,,,,,,,,,,,,, -52600,2.7488515,2.6996682,,,,,,,,,,,,,, -52610,,,0.6633649468421936,1.458517074584961,0.6189799904823303,1.6728732585906982,50000.0,0.4906000196933746,2.3692007064819336,10000.0,17907.79274368286,18665.110434770584,17907.79274368286,753.8350718021393,1.702005386352539,0.0 -52700,3.0349789,2.7878814,,,,,,,,,,,,,, -52800,3.0588088,2.7316864,,,,,,,,,,,,,, -52900,3.5170836,2.6836305,,,,,,,,,,,,,, -53000,2.8480942,2.6898627,,,,,,,,,,,,,, -53100,3.4133675,2.6857882,,,,,,,,,,,,,, -53200,3.398709,2.738308,,,,,,,,,,,,,, -53300,3.2282584,2.7716508,,,,,,,,,,,,,, -53400,3.050051,2.6858273,,,,,,,,,,,,,, -53500,2.974879,2.7915158,,,,,,,,,,,,,, -53600,2.880388,2.7849476,,,,,,,,,,,,,, -53700,3.2667658,2.7893262,,,,,,,,,,,,,, -53800,3.870397,2.7993803,,,,,,,,,,,,,, -53900,3.6592627,2.7437418,,,,,,,,,,,,,, -54000,3.620733,2.67832,,,,,,,,,,,,,, -54100,3.5459175,2.7098813,,,,,,,,,,,,,, -54117,,,0.652363657951355,1.5211600065231323,0.6086999773979187,1.7170997858047483,50000.0,0.4872000217437744,2.394115447998047,10000.0,18417.99221158028,19192.78498077393,18417.99221158028,771.2206614017487,1.7383487224578855,0.0 -54200,2.8490896,2.7151954,,,,,,,,,,,,,, -54300,3.1812315,2.7147923,,,,,,,,,,,,,, -54400,3.2425096,2.7108567,,,,,,,,,,,,,, -54500,3.1241066,2.70926,,,,,,,,,,,,,, -54600,3.3490837,2.6897283,,,,,,,,,,,,,, -54700,3.4844728,2.7058632,,,,,,,,,,,,,, -54800,3.7068152,2.6310923,,,,,,,,,,,,,, -54900,3.2230926,2.6691127,,,,,,,,,,,,,, -55000,3.01194,2.6608548,,,,,,,,,,,,,, -55100,3.7350228,2.7172122,,,,,,,,,,,,,, -55200,3.9054902,2.6573226,,,,,,,,,,,,,, -55300,3.3614695,2.7525764,,,,,,,,,,,,,, -55400,3.3342028,2.6731126,,,,,,,,,,,,,, -55500,3.233761,2.648477,,,,,,,,,,,,,, -55600,3.6005962,2.6921635,,,,,,,,,,,,,, -55623,,,0.6726123690605164,1.4373608827590942,0.6170600056648254,1.682947874069214,50000.0,0.4896000325679779,2.345250606536865,10000.0,18927.954924583435,19720.38703656197,18927.954924583435,788.768620967865,1.7768800258636477,0.0 -55700,3.172902,2.6418438,,,,,,,,,,,,,, -55800,3.8547,2.7859993,,,,,,,,,,,,,, -55900,3.1061068,2.7366695,,,,,,,,,,,,,, -56000,3.6741676,2.7254663,,,,,,,,,,,,,, -56100,3.038886,2.7820485,,,,,,,,,,,,,, -56200,2.9837441,2.6404123,,,,,,,,,,,,,, -56300,4.0448728,2.6696534,,,,,,,,,,,,,, -56400,3.269438,2.6588385,,,,,,,,,,,,,, -56500,3.1728585,2.7431278,,,,,,,,,,,,,, -56600,3.2490387,2.6803436,,,,,,,,,,,,,, -56700,4.037985,2.7655587,,,,,,,,,,,,,, -56800,3.0764477,2.621006,,,,,,,,,,,,,, -56900,3.5856092,2.707679,,,,,,,,,,,,,, -57000,3.739978,2.718831,,,,,,,,,,,,,, -57100,3.4373763,2.7154386,,,,,,,,,,,,,, -57129,,,0.6888552308082581,1.3610378503799438,0.6221599578857422,1.6592100858688354,50000.0,0.4963000118732452,2.3328607082366943,10000.0,19438.130825281143,20247.76454782486,19438.130825281143,805.8777091503143,1.8174619674682613,0.0 -57200,3.5113773,2.692472,,,,,,,,,,,,,, -57300,3.05019,2.6549582,,,,,,,,,,,,,, -57400,3.1526918,2.7171736,,,,,,,,,,,,,, -57500,2.9426324,2.6846561,,,,,,,,,,,,,, -57600,3.464782,2.7048268,,,,,,,,,,,,,, -57700,3.453772,2.6951144,,,,,,,,,,,,,, -57800,3.0926738,2.7305937,,,,,,,,,,,,,, -57900,3.2168083,2.6435518,,,,,,,,,,,,,, -58000,3.3810487,2.7848387,,,,,,,,,,,,,, -58100,3.0433338,2.7666306,,,,,,,,,,,,,, -58200,3.1476128,2.738492,,,,,,,,,,,,,, -58300,3.8450541,2.6204338,,,,,,,,,,,,,, -58400,3.2744267,2.7192843,,,,,,,,,,,,,, -58500,2.899253,2.715478,,,,,,,,,,,,,, -58600,3.3675241,2.6035738,,,,,,,,,,,,,, -58636,,,0.6800262928009033,1.3760676383972168,0.620959997177124,1.6456332206726074,50000.0,0.4946000277996063,2.325051784515381,10000.0,19948.24530482292,20775.29114437104,19948.24530482292,823.1925563812256,1.8636653423309328,0.0 -58700,3.2294285,2.7028592,,,,,,,,,,,,,, -58800,3.3149905,2.6490202,,,,,,,,,,,,,, -58900,3.633787,2.7435179,,,,,,,,,,,,,, -59000,4.5237684,2.6478748,,,,,,,,,,,,,, -59100,2.9786773,2.6858907,,,,,,,,,,,,,, -59200,2.9615066,2.6887383,,,,,,,,,,,,,, -59300,2.9488368,2.6878383,,,,,,,,,,,,,, -59400,3.9499,2.7432487,,,,,,,,,,,,,, -59500,3.3376195,2.694567,,,,,,,,,,,,,, -59600,3.6886024,2.7924757,,,,,,,,,,,,,, -59700,3.4992611,2.6872108,,,,,,,,,,,,,, -59800,3.2225556,2.6473231,,,,,,,,,,,,,, -59900,3.3176937,2.6984167,,,,,,,,,,,,,, -60000,3.239935,2.6569924,,,,,,,,,,,,,, -60100,3.130508,2.6559165,,,,,,,,,,,,,, -60143,,,0.6727718114852905,1.408918380737305,0.6265400052070618,1.6337388753890991,50000.0,0.503600001335144,2.290309190750122,10000.0,20458.460973262787,21302.928270578384,20458.460973262787,840.5147912502289,1.910284280776977,0.0 -60200,2.9067893,2.694254,,,,,,,,,,,,,, -60300,3.4665902,2.757433,,,,,,,,,,,,,, -60400,3.044281,2.724019,,,,,,,,,,,,,, -60500,3.1165974,2.5975337,,,,,,,,,,,,,, -60600,3.2383554,2.774008,,,,,,,,,,,,,, -60700,3.509822,2.646195,,,,,,,,,,,,,, -60800,2.966564,2.6696246,,,,,,,,,,,,,, -60900,3.7344828,2.8398604,,,,,,,,,,,,,, -61000,3.2647946,2.8361633,,,,,,,,,,,,,, -61100,3.567203,2.7920866,,,,,,,,,,,,,, -61200,3.4945447,2.7051635,,,,,,,,,,,,,, -61300,3.634271,2.7520144,,,,,,,,,,,,,, -61400,3.0384586,2.6570323,,,,,,,,,,,,,, -61500,3.25031,2.675243,,,,,,,,,,,,,, -61600,3.224388,2.6537194,,,,,,,,,,,,,, -61651,,,0.6753029227256775,1.3835617303848269,0.6267799735069275,1.610917329788208,50000.0,0.5010000467300415,2.281240224838257,10000.0,20968.66392183304,21831.363361120224,20968.66392183304,858.6574947834015,1.9475953578948968,0.0 -61700,3.3318105,2.614158,,,,,,,,,,,,,, -61800,3.0319602,2.6719089,,,,,,,,,,,,,, -61900,3.1354854,2.7014217,,,,,,,,,,,,,, -62000,3.9000416,2.590985,,,,,,,,,,,,,, -62100,3.2864368,2.6681323,,,,,,,,,,,,,, -62200,3.1454294,2.6196632,,,,,,,,,,,,,, -62300,3.3621461,2.7025115,,,,,,,,,,,,,, -62400,3.1529472,2.6786768,,,,,,,,,,,,,, -62500,3.3622525,2.7216356,,,,,,,,,,,,,, -62600,2.8717935,2.6539998,,,,,,,,,,,,,, -62700,3.0270925,2.66241,,,,,,,,,,,,,, -62800,3.5200956,2.7024565,,,,,,,,,,,,,, -62900,3.3638566,2.7491746,,,,,,,,,,,,,, -63000,3.2560172,2.593718,,,,,,,,,,,,,, -63100,3.9865084,2.8032696,,,,,,,,,,,,,, -63158,,,0.6793287396430969,1.402336597442627,0.631060004234314,1.6216193437576294,50000.0,0.5006000399589539,2.285511016845703,10000.0,21478.79666543007,22359.277554273605,21478.79666543007,876.3508095741272,1.9852290153503416,0.0 -63200,3.2885761,2.684495,,,,,,,,,,,,,, -63300,2.8527865,2.6295238,,,,,,,,,,,,,, -63400,3.0366511,2.6222615,,,,,,,,,,,,,, -63500,3.0641475,2.7067409,,,,,,,,,,,,,, -63600,3.5929058,2.763246,,,,,,,,,,,,,, -63700,3.1991005,2.7398674,,,,,,,,,,,,,, -63800,3.6455715,2.6291525,,,,,,,,,,,,,, -63900,3.4684522,2.6774998,,,,,,,,,,,,,, -64000,3.3439212,2.598838,,,,,,,,,,,,,, -64100,3.4497466,2.6800365,,,,,,,,,,,,,, -64200,3.3651764,2.5334947,,,,,,,,,,,,,, -64300,3.3806696,2.642962,,,,,,,,,,,,,, -64400,3.3325763,2.6453261,,,,,,,,,,,,,, -64500,3.622828,2.6974087,,,,,,,,,,,,,, -64600,2.6816683,2.6658082,,,,,,,,,,,,,, -64665,,,0.6847297549247742,1.3605037927627563,0.6293999552726746,1.600814700126648,50000.0,0.5010000467300415,2.2789101600646973,10000.0,21988.747662067413,22887.225957870483,21988.747662067413,894.2573552131653,2.025351047515869,0.0 -64700,3.3385599,2.6673925,,,,,,,,,,,,,, -64800,3.5947106,2.628549,,,,,,,,,,,,,, -64900,3.5088532,2.6607347,,,,,,,,,,,,,, -65000,2.948618,2.6190305,,,,,,,,,,,,,, -65100,3.2147079,2.7530916,,,,,,,,,,,,,, -65200,3.4332664,2.6726508,,,,,,,,,,,,,, -65300,2.847737,2.8172534,,,,,,,,,,,,,, -65400,3.3386645,2.6787338,,,,,,,,,,,,,, -65500,3.63268,2.7208977,,,,,,,,,,,,,, -65600,3.1382508,2.6446946,,,,,,,,,,,,,, -65700,4.025921,2.633004,,,,,,,,,,,,,, -65800,3.251178,2.6344247,,,,,,,,,,,,,, -65900,3.1827846,2.6675735,,,,,,,,,,,,,, -66000,3.821051,2.6901627,,,,,,,,,,,,,, -66100,3.3779757,2.578622,,,,,,,,,,,,,, -66172,,,0.7006337642669678,1.267736554145813,0.6306799650192261,1.588563084602356,50000.0,0.5067000389099121,2.267800807952881,10000.0,22498.691545009613,23414.77932405472,22498.691545009613,911.7745883464812,2.06522798538208,0.0 -66200,3.6847134,2.6804454,,,,,,,,,,,,,, -66300,3.624288,2.7056692,,,,,,,,,,,,,, -66400,3.5550365,2.6272442,,,,,,,,,,,,,, -66500,3.532316,2.7115102,,,,,,,,,,,,,, -66600,3.3486464,2.638939,,,,,,,,,,,,,, -66700,3.6454954,2.7308552,,,,,,,,,,,,,, -66800,3.2764082,2.6573524,,,,,,,,,,,,,, -66900,3.6054528,2.710348,,,,,,,,,,,,,, -67000,3.389985,2.6790266,,,,,,,,,,,,,, -67100,3.2882364,2.6924832,,,,,,,,,,,,,, -67200,3.483987,2.6293879,,,,,,,,,,,,,, -67300,3.1757324,2.6319141,,,,,,,,,,,,,, -67400,4.0074744,2.5985613,,,,,,,,,,,,,, -67500,3.4790266,2.6260908,,,,,,,,,,,,,, -67600,3.3999572,2.7207737,,,,,,,,,,,,,, -67679,,,0.6891342401504517,1.355064868927002,0.6316999793052673,1.623766541481018,50000.0,0.5098000168800354,2.2622320652008057,10000.0,23008.63331103325,23942.103929281235,23008.63331103325,929.0643086433412,2.106299638748169,0.0 -67700,3.291906,2.6421096,,,,,,,,,,,,,, -67800,3.8240006,2.7138693,,,,,,,,,,,,,, -67900,3.0678387,2.631813,,,,,,,,,,,,,, -68000,3.3581905,2.6942956,,,,,,,,,,,,,, -68100,3.252508,2.6748857,,,,,,,,,,,,,, -68200,3.1091826,2.633706,,,,,,,,,,,,,, -68300,3.2150025,2.5913796,,,,,,,,,,,,,, -68400,3.681739,2.6269748,,,,,,,,,,,,,, -68500,3.5839765,2.6960216,,,,,,,,,,,,,, -68600,3.5572174,2.5864742,,,,,,,,,,,,,, -68700,3.4708908,2.647257,,,,,,,,,,,,,, -68800,3.686768,2.6592207,,,,,,,,,,,,,, -68900,3.0572677,2.721355,,,,,,,,,,,,,, -69000,3.8540215,2.5608215,,,,,,,,,,,,,, -69100,3.3808084,2.7345395,,,,,,,,,,,,,, -69186,,,0.6904296875,1.3263139724731443,0.6378600001335144,1.573560118675232,50000.0,0.5093000531196594,2.243839263916016,10000.0,23518.742659330368,24469.36714053154,23518.742659330368,946.126077890396,2.1450648307800293,0.0 -69200,3.1846995,2.614081,,,,,,,,,,,,,, -69300,3.602541,2.571263,,,,,,,,,,,,,, -69400,3.5224934,2.5485353,,,,,,,,,,,,,, -69500,3.5327258,2.703868,,,,,,,,,,,,,, -69600,3.0494454,2.5574698,,,,,,,,,,,,,, -69700,3.3558009,2.5950584,,,,,,,,,,,,,, -69800,3.3570259,2.611837,,,,,,,,,,,,,, -69900,4.1770387,2.6329348,,,,,,,,,,,,,, -70000,3.5880103,2.6902094,,,,,,,,,,,,,, -70100,3.780727,2.664992,,,,,,,,,,,,,, -70200,3.0656524,2.6016147,,,,,,,,,,,,,, -70300,3.45166,2.6127996,,,,,,,,,,,,,, -70400,3.5324771,2.6472611,,,,,,,,,,,,,, -70500,3.5514045,2.6497622,,,,,,,,,,,,,, -70600,3.4395094,2.676895,,,,,,,,,,,,,, -70693,,,0.6878587007522583,1.3525885343551636,0.6403399705886841,1.573731780052185,50000.0,0.5091000199317932,2.269750595092773,10000.0,24028.91832447052,24997.529990196228,24028.91832447052,964.0205476284028,2.1862633228302,0.0 -70700,3.5223567,2.5808518,,,,,,,,,,,,,, -70800,3.092473,2.625537,,,,,,,,,,,,,, -70900,3.1912172,2.7160883,,,,,,,,,,,,,, -71000,3.0152807,2.7043116,,,,,,,,,,,,,, -71100,3.1782854,2.7916272,,,,,,,,,,,,,, -71200,3.8802319,2.722777,,,,,,,,,,,,,, -71300,3.398731,2.6280506,,,,,,,,,,,,,, -71400,3.3075655,2.601132,,,,,,,,,,,,,, -71500,4.028394,2.664899,,,,,,,,,,,,,, -71600,3.20743,2.6819434,,,,,,,,,,,,,, -71700,3.0753174,2.6216824,,,,,,,,,,,,,, -71800,3.49423,2.6941803,,,,,,,,,,,,,, -71900,3.9030585,2.654374,,,,,,,,,,,,,, -72000,4.050318,2.698744,,,,,,,,,,,,,, -72100,3.1197772,2.6731584,,,,,,,,,,,,,, -72200,,,0.6857063174247742,1.3495975732803345,0.6386199593544006,1.5710564851760864,50000.0,0.5123000144958496,2.238888502120972,10000.0,24538.978356838223,25524.813975811005,24538.978356838223,981.154138803482,2.2252821922302246,0.0 -72200,3.8868647,2.5906332,,,,,,,,,,,,,, -72300,3.541343,2.6706238,,,,,,,,,,,,,, -72400,3.767845,2.684954,,,,,,,,,,,,,, -72500,3.3383129,2.7302341,,,,,,,,,,,,,, -72600,3.6461384,2.7602274,,,,,,,,,,,,,, -72700,3.6988368,2.6174862,,,,,,,,,,,,,, -72800,3.1682665,2.6479127,,,,,,,,,,,,,, -72900,3.758609,2.6089149,,,,,,,,,,,,,, -73000,3.4763696,2.6552796,,,,,,,,,,,,,, -73100,3.4646745,2.6229157,,,,,,,,,,,,,, -73200,4.380325,2.5959032,,,,,,,,,,,,,, -73300,3.9081447,2.5481477,,,,,,,,,,,,,, -73400,3.676307,2.6993194,,,,,,,,,,,,,, -73500,3.3990915,2.6303978,,,,,,,,,,,,,, -73600,5.1472235,2.647428,,,,,,,,,,,,,, -73700,3.3582206,2.6637747,,,,,,,,,,,,,, -73707,,,0.6907086968421936,1.3185800313949585,0.6348400115966797,1.5660574436187744,50000.0,0.5172000527381897,2.202910900115967,10000.0,25049.043762922287,26051.94483280182,25049.043762922287,998.1271076202391,2.264722347259521,0.0 -73800,3.6482227,2.719621,,,,,,,,,,,,,, -73900,3.2165182,2.6588664,,,,,,,,,,,,,, -74000,3.4732232,2.6038303,,,,,,,,,,,,,, -74100,3.6148248,2.618491,,,,,,,,,,,,,, -74200,3.3379982,2.5873804,,,,,,,,,,,,,, -74300,4.0904446,2.6738017,,,,,,,,,,,,,, -74400,3.5249817,2.694533,,,,,,,,,,,,,, -74500,4.28405,2.6141047,,,,,,,,,,,,,, -74600,3.4842203,2.64717,,,,,,,,,,,,,, -74700,2.9062016,2.6514301,,,,,,,,,,,,,, -74800,3.6596756,2.5112739,,,,,,,,,,,,,, -74900,3.5857265,2.6891599,,,,,,,,,,,,,, -75000,3.3170402,2.705982,,,,,,,,,,,,,, -75100,3.2258391,2.605612,,,,,,,,,,,,,, -75200,3.0127776,2.6133287,,,,,,,,,,,,,, -75215,,,0.7098612785339355,1.2410625219345093,0.6350199580192566,1.5747349262237549,50000.0,0.5069000124931335,2.2694790363311768,10000.0,25559.250710248947,26580.20965051651,25559.250710248947,1016.093424797058,2.3037054538726807,0.0 -75300,3.4417827,2.556425,,,,,,,,,,,,,, -75400,3.7157588,2.5177612,,,,,,,,,,,,,, -75500,4.0704875,2.6254838,,,,,,,,,,,,,, -75600,3.873163,2.656343,,,,,,,,,,,,,, -75700,3.9343507,2.6623774,,,,,,,,,,,,,, -75800,3.2048612,2.6365674,,,,,,,,,,,,,, -75900,3.380592,2.675864,,,,,,,,,,,,,, -76000,3.4913917,2.5453286,,,,,,,,,,,,,, -76100,3.4550674,2.6125276,,,,,,,,,,,,,, -76200,3.3082476,2.4830725,,,,,,,,,,,,,, -76300,3.706917,2.588419,,,,,,,,,,,,,, -76400,3.3544197,2.6664507,,,,,,,,,,,,,, -76500,3.515668,2.593944,,,,,,,,,,,,,, -76600,3.7062342,2.6350734,,,,,,,,,,,,,, -76700,3.48737,2.7049496,,,,,,,,,,,,,, -76722,,,0.7016701102256775,1.2833776473999023,0.6454399824142456,1.553180456161499,50000.0,0.5156000256538391,2.2022013664245605,10000.0,26069.155710458755,27107.455953359604,26069.155710458755,1033.3393914699554,2.347717046737671,0.0 -76800,3.2957015,2.620296,,,,,,,,,,,,,, -76900,3.9594371,2.5601218,,,,,,,,,,,,,, -77000,3.1007135,2.5580335,,,,,,,,,,,,,, -77100,3.4985037,2.6562696,,,,,,,,,,,,,, -77200,3.28029,2.5810199,,,,,,,,,,,,,, -77300,3.5010772,2.649449,,,,,,,,,,,,,, -77400,3.6299272,2.6559558,,,,,,,,,,,,,, -77500,3.2040894,2.6144414,,,,,,,,,,,,,, -77600,3.7006211,2.60777,,,,,,,,,,,,,, -77700,3.421905,2.5789142,,,,,,,,,,,,,, -77800,3.9065747,2.6404555,,,,,,,,,,,,,, -77900,3.0734637,2.5436914,,,,,,,,,,,,,, -78000,3.6588936,2.553893,,,,,,,,,,,,,, -78100,4.111753,2.5586648,,,,,,,,,,,,,, -78200,3.543447,2.6024487,,,,,,,,,,,,,, -78229,,,0.7018494606018066,1.2792024612426758,0.6469199657440186,1.5287810564041138,50000.0,0.5181000232696533,2.194598197937012,10000.0,26579.27547645569,27634.67905735969,26579.27547645569,1050.3499314785004,2.3884060382843018,0.0 -78300,3.2073092,2.5773392,,,,,,,,,,,,,, -78400,3.3469887,2.6201897,,,,,,,,,,,,,, -78500,3.6474612,2.4999638,,,,,,,,,,,,,, -78600,3.3820398,2.620543,,,,,,,,,,,,,, -78700,3.2992816,2.6247082,,,,,,,,,,,,,, -78800,3.1578999,2.465015,,,,,,,,,,,,,, -78900,3.199951,2.5891132,,,,,,,,,,,,,, -79000,3.537955,2.6333077,,,,,,,,,,,,,, -79100,3.6781228,2.627079,,,,,,,,,,,,,, -79200,3.4671996,2.7006683,,,,,,,,,,,,,, -79300,3.3358202,2.5422637,,,,,,,,,,,,,, -79400,3.5534635,2.6223001,,,,,,,,,,,,,, -79500,3.2630768,2.7587144,,,,,,,,,,,,,, -79600,3.8752224,2.6240778,,,,,,,,,,,,,, -79700,3.3130476,2.5578146,,,,,,,,,,,,,, -79736,,,0.702168345451355,1.2795171737670898,0.6496399641036987,1.508771896362305,50000.0,0.5218000411987305,2.1807165145874023,10000.0,27089.258211374283,28161.826422214508,27089.258211374283,1067.4200563430786,2.4313290119171143,0.0 -79800,3.4544203,2.5748222,,,,,,,,,,,,,, -79900,3.422591,2.6635945,,,,,,,,,,,,,, -80000,3.5759149,2.5628295,,,,,,,,,,,,,, -80100,3.5291674,2.5961533,,,,,,,,,,,,,, -80200,3.823906,2.6396945,,,,,,,,,,,,,, -80300,3.5231836,2.6483364,,,,,,,,,,,,,, -80400,3.8228204,2.6099234,,,,,,,,,,,,,, -80500,3.8530624,2.6236603,,,,,,,,,,,,,, -80600,3.4447556,2.5596406,,,,,,,,,,,,,, -80700,3.561265,2.6496868,,,,,,,,,,,,,, -80800,3.2954943,2.609534,,,,,,,,,,,,,, -80900,3.0571792,2.5157557,,,,,,,,,,,,,, -81000,3.5167282,2.5978382,,,,,,,,,,,,,, -81100,3.740578,2.6651077,,,,,,,,,,,,,, -81200,4.1968803,2.5837724,,,,,,,,,,,,,, -81243,,,0.6947743892669678,1.3432766199111938,0.6473199725151062,1.556182026863098,50000.0,0.5243000388145447,2.222846031188965,10000.0,27599.380972385406,28689.146060228348,27599.380972385406,1084.517686367035,2.4775948524475098,0.0 -81300,3.5275388,2.5721798,,,,,,,,,,,,,, -81400,3.3307738,2.5774226,,,,,,,,,,,,,, -81500,3.823169,2.697997,,,,,,,,,,,,,, -81600,3.9252467,2.6094527,,,,,,,,,,,,,, -81700,3.6636631,2.626408,,,,,,,,,,,,,, -81800,3.215357,2.5820181,,,,,,,,,,,,,, -81900,4.3689117,2.5412374,,,,,,,,,,,,,, -82000,3.6163418,2.5831323,,,,,,,,,,,,,, -82100,3.807524,2.5110855,,,,,,,,,,,,,, -82200,4.627461,2.5497794,,,,,,,,,,,,,, -82300,4.4320354,2.5788455,,,,,,,,,,,,,, -82400,3.2748518,2.5465555,,,,,,,,,,,,,, -82500,3.1058576,2.5613317,,,,,,,,,,,,,, -82600,3.61891,2.6256084,,,,,,,,,,,,,, -82700,3.6480184,2.6613047,,,,,,,,,,,,,, -82750,,,0.6988998651504517,1.2671300172805786,0.6509999632835388,1.4982835054397583,50000.0,0.5206000208854675,2.15671443939209,10000.0,28109.40303182602,29216.481212615967,28109.40303182602,1101.7360954284668,2.5183119773864746,0.0 -82800,4.3569207,2.5810838,,,,,,,,,,,,,, -82900,3.5373821,2.6213653,,,,,,,,,,,,,, -83000,3.3514736,2.539936,,,,,,,,,,,,,, -83100,3.5143995,2.6139052,,,,,,,,,,,,,, -83200,3.473104,2.4638345,,,,,,,,,,,,,, -83300,3.8050604,2.6418626,,,,,,,,,,,,,, -83400,3.6337328,2.5722322,,,,,,,,,,,,,, -83500,4.2592435,2.5707595,,,,,,,,,,,,,, -83600,3.4779348,2.614814,,,,,,,,,,,,,, -83700,4.0404935,2.6053329,,,,,,,,,,,,,, -83800,3.810152,2.6664267,,,,,,,,,,,,,, -83900,4.3991246,2.6575823,,,,,,,,,,,,,, -84000,3.4149103,2.6203368,,,,,,,,,,,,,, -84100,3.481344,2.6082175,,,,,,,,,,,,,, -84200,4.24729,2.6389084,,,,,,,,,,,,,, -84258,,,0.7250677347183228,1.181086778640747,0.653499960899353,1.501522421836853,50000.0,0.5231000185012817,2.192289113998413,10000.0,28619.604907035828,29744.09475159645,28619.604907035828,1119.0488619804382,2.564819812774658,0.0 -84300,3.8184857,2.5699656,,,,,,,,,,,,,, -84400,3.6057856,2.5889666,,,,,,,,,,,,,, -84500,3.484225,2.5197904,,,,,,,,,,,,,, -84600,3.8168335,2.5968497,,,,,,,,,,,,,, -84700,3.5380867,2.502923,,,,,,,,,,,,,, -84800,3.6239598,2.4996843,,,,,,,,,,,,,, -84900,4.0304017,2.7101104,,,,,,,,,,,,,, -85000,3.6386507,2.5995092,,,,,,,,,,,,,, -85100,3.6356695,2.6408796,,,,,,,,,,,,,, -85200,3.3430402,2.590887,,,,,,,,,,,,,, -85300,3.929299,2.531669,,,,,,,,,,,,,, -85400,3.8321295,2.5871687,,,,,,,,,,,,,, -85500,3.132037,2.552526,,,,,,,,,,,,,, -85600,3.2930646,2.5196192,,,,,,,,,,,,,, -85700,3.8217793,2.667921,,,,,,,,,,,,,, -85765,,,0.7072703838348389,1.2629165649414062,0.644540011882782,1.5403581857681274,50000.0,0.5260000228881836,2.19507384300232,10000.0,29129.81632399559,30271.58340406418,29129.81632399559,1136.2286508083344,2.6103806495666504,0.0 -85800,3.1756034,2.5860915,,,,,,,,,,,,,, -85900,3.6558027,2.6027493,,,,,,,,,,,,,, -86000,3.8566234,2.6236384,,,,,,,,,,,,,, -86100,3.511294,2.5535812,,,,,,,,,,,,,, -86200,3.4670455,2.503463,,,,,,,,,,,,,, -86300,3.8512495,2.6179204,,,,,,,,,,,,,, -86400,3.704464,2.515919,,,,,,,,,,,,,, -86500,3.7978706,2.5386345,,,,,,,,,,,,,, -86600,3.5695174,2.5934644,,,,,,,,,,,,,, -86700,3.6801867,2.6195483,,,,,,,,,,,,,, -86800,4.979945,2.64673,,,,,,,,,,,,,, -86900,3.5554574,2.5718906,,,,,,,,,,,,,, -87000,4.266795,2.5361004,,,,,,,,,,,,,, -87100,3.6189282,2.605863,,,,,,,,,,,,,, -87200,3.4143014,2.6135082,,,,,,,,,,,,,, -87272,,,0.7074697017669678,1.2444090843200684,0.6504600048065186,1.5148764848709106,50000.0,0.5247000455856323,2.167389154434204,10000.0,29639.82198214531,30799.211097955704,29639.82198214531,1153.7549715042114,2.653380870819092,0.0 -87300,3.5377111,2.5163422,,,,,,,,,,,,,, -87400,4.173608,2.575969,,,,,,,,,,,,,, -87500,3.5846138,2.6142147,,,,,,,,,,,,,, -87600,4.273701,2.5761733,,,,,,,,,,,,,, -87700,3.7799568,2.5639455,,,,,,,,,,,,,, -87800,4.120995,2.5577316,,,,,,,,,,,,,, -87900,4.251648,2.5560977,,,,,,,,,,,,,, -88000,3.5220935,2.5511496,,,,,,,,,,,,,, -88100,3.415333,2.6604064,,,,,,,,,,,,,, -88200,3.8202772,2.4863925,,,,,,,,,,,,,, -88300,4.0705605,2.571889,,,,,,,,,,,,,, -88400,4.1948423,2.5625339,,,,,,,,,,,,,, -88500,3.8000772,2.498811,,,,,,,,,,,,,, -88600,3.8147292,2.5491629,,,,,,,,,,,,,, -88700,4.0542164,2.5082912,,,,,,,,,,,,,, -88780,,,0.702566921710968,1.2763031721115112,0.6502799987792969,1.5065932273864746,50000.0,0.5261000394821167,2.174659013748169,10000.0,30149.93711233139,31326.53332161904,30149.93711233139,1170.8653423786163,2.6973180770874023,0.0 -88800,3.127415,2.5386064,,,,,,,,,,,,,, -88900,3.807929,2.551885,,,,,,,,,,,,,, -89000,3.640165,2.5782862,,,,,,,,,,,,,, -89100,3.7998455,2.5627916,,,,,,,,,,,,,, -89200,3.6049259,2.4467509,,,,,,,,,,,,,, -89300,4.0671043,2.5621853,,,,,,,,,,,,,, -89400,3.6294389,2.49207,,,,,,,,,,,,,, -89500,3.457746,2.5143197,,,,,,,,,,,,,, -89600,3.6582074,2.5816243,,,,,,,,,,,,,, -89700,3.6140544,2.4494357,,,,,,,,,,,,,, -89800,3.4966426,2.6445303,,,,,,,,,,,,,, -89900,3.8655317,2.6744196,,,,,,,,,,,,,, -90000,3.4882672,2.5680406,,,,,,,,,,,,,, -90100,3.593982,2.5350063,,,,,,,,,,,,,, -90200,3.9349334,2.5552235,,,,,,,,,,,,,, -90287,,,0.7116549611091614,1.246743083000183,0.6574999690055847,1.4754823446273804,50000.0,0.5333000421524048,2.1418609619140625,10000.0,30659.994156122208,31853.74282312393,30659.994156122208,1187.9223272800446,2.740720748901367,0.0 -90300,3.7239797,2.4964728,,,,,,,,,,,,,, -90400,3.3827884,2.6535845,,,,,,,,,,,,,, -90500,4.58574,2.5577593,,,,,,,,,,,,,, -90600,3.6176925,2.6151175,,,,,,,,,,,,,, -90700,4.4401455,2.4956431,,,,,,,,,,,,,, -90800,3.6717727,2.4958947,,,,,,,,,,,,,, -90900,3.7643511,2.4919767,,,,,,,,,,,,,, -91000,3.5957375,2.4919195,,,,,,,,,,,,,, -91100,3.4895344,2.495077,,,,,,,,,,,,,, -91200,3.8403041,2.539857,,,,,,,,,,,,,, -91300,3.5537,2.5019913,,,,,,,,,,,,,, -91400,4.8186064,2.4760153,,,,,,,,,,,,,, -91500,4.0116644,2.5367603,,,,,,,,,,,,,, -91600,4.102978,2.5782108,,,,,,,,,,,,,, -91700,3.7328093,2.6548629,,,,,,,,,,,,,, -91793,,,0.7194674611091614,1.1982098817825315,0.6654799580574036,1.4386276006698608,50000.0,0.5433000326156616,2.08927321434021,10000.0,31169.86101269722,32380.824385404587,31169.86101269722,1204.966181755066,2.859598398208618,0.0 -91800,4.610189,2.5852594,,,,,,,,,,,,,, -91900,3.713936,2.5164518,,,,,,,,,,,,,, -92000,4.0562305,2.6030207,,,,,,,,,,,,,, -92100,3.4235919,2.5082412,,,,,,,,,,,,,, -92200,3.304976,2.4640646,,,,,,,,,,,,,, -92300,3.9729757,2.4834554,,,,,,,,,,,,,, -92400,3.9980843,2.5889575,,,,,,,,,,,,,, -92500,3.395131,2.6064885,,,,,,,,,,,,,, -92600,3.6187057,2.5125875,,,,,,,,,,,,,, -92700,4.2748966,2.5080063,,,,,,,,,,,,,, -92800,3.8773787,2.674811,,,,,,,,,,,,,, -92900,3.856363,2.5437677,,,,,,,,,,,,,, -93000,4.3100014,2.5638874,,,,,,,,,,,,,, -93100,4.3460636,2.5211453,,,,,,,,,,,,,, -93200,3.7843702,2.567917,,,,,,,,,,,,,, -93300,,,0.722676157951355,1.201149821281433,0.6536999940872192,1.525767803192139,50000.0,0.5268000364303589,2.1804299354553223,10000.0,31679.799534082413,32907.94575691223,31679.799534082413,1222.0526728630066,2.903446912765503,0.0 -93300,4.307817,2.5793962,,,,,,,,,,,,,, -93400,4.2858567,2.5154886,,,,,,,,,,,,,, -93500,3.9109447,2.5093923,,,,,,,,,,,,,, -93600,3.7775815,2.4831057,,,,,,,,,,,,,, -93700,4.313739,2.6455116,,,,,,,,,,,,,, -93800,4.6403217,2.5603287,,,,,,,,,,,,,, -93900,3.6891425,2.55624,,,,,,,,,,,,,, -94000,4.134584,2.5823622,,,,,,,,,,,,,, -94100,3.947318,2.449223,,,,,,,,,,,,,, -94200,3.46941,2.507267,,,,,,,,,,,,,, -94300,4.236839,2.6835108,,,,,,,,,,,,,, -94400,4.1588116,2.4808867,,,,,,,,,,,,,, -94500,3.6678796,2.6628523,,,,,,,,,,,,,, -94600,3.6723547,2.4598641,,,,,,,,,,,,,, -94700,4.5223975,2.5619342,,,,,,,,,,,,,, -94800,3.7824118,2.582232,,,,,,,,,,,,,, -94807,,,0.7135283946990967,1.234078884124756,0.650439977645874,1.527204155921936,50000.0,0.5247000455856323,2.177957057952881,10000.0,32189.77259206772,33435.2316634655,32189.77259206772,1239.2692143917084,2.946500778198242,0.0 -94900,3.663495,2.6025932,,,,,,,,,,,,,, -95000,4.0240545,2.568561,,,,,,,,,,,,,, -95100,3.9076955,2.4657953,,,,,,,,,,,,,, -95200,3.804202,2.4642558,,,,,,,,,,,,,, -95300,3.9367867,2.5147142,,,,,,,,,,,,,, -95400,3.9005215,2.5484545,,,,,,,,,,,,,, -95500,3.4467905,2.501329,,,,,,,,,,,,,, -95600,4.1978383,2.481057,,,,,,,,,,,,,, -95700,4.1582475,2.5064583,,,,,,,,,,,,,, -95800,3.9367511,2.5588374,,,,,,,,,,,,,, -95900,4.358774,2.4686887,,,,,,,,,,,,,, -96000,3.8611279,2.5093415,,,,,,,,,,,,,, -96100,4.7183824,2.558148,,,,,,,,,,,,,, -96200,4.4106846,2.5333514,,,,,,,,,,,,,, -96300,3.659285,2.459774,,,,,,,,,,,,,, -96315,,,0.7221978306770325,1.189970850944519,0.6604399681091309,1.467153549194336,50000.0,0.532800018787384,2.1280345916748047,10000.0,32699.811596155167,33962.59027004242,32699.811596155167,1256.4928257465365,2.9893760681152344,0.0 -96400,4.093901,2.608292,,,,,,,,,,,,,, -96500,4.301923,2.5779738,,,,,,,,,,,,,, -96600,3.9993827,2.5989804,,,,,,,,,,,,,, -96700,3.8571754,2.4815097,,,,,,,,,,,,,, -96800,4.3915906,2.5705059,,,,,,,,,,,,,, -96900,3.8123865,2.5143085,,,,,,,,,,,,,, -97000,3.9973166,2.4192502,,,,,,,,,,,,,, -97100,3.6552017,2.5622392,,,,,,,,,,,,,, -97200,3.6577497,2.4935849,,,,,,,,,,,,,, -97300,4.8597527,2.538318,,,,,,,,,,,,,, -97400,3.885061,2.511459,,,,,,,,,,,,,, -97500,3.7596252,2.4905107,,,,,,,,,,,,,, -97600,3.9476202,2.5692973,,,,,,,,,,,,,, -97700,4.1247997,2.4726505,,,,,,,,,,,,,, -97800,4.5368476,2.6217957,,,,,,,,,,,,,, -97822,,,0.725027859210968,1.1888576745986938,0.668999969959259,1.4445313215255735,50000.0,0.5423000454902649,2.084092617034912,10000.0,33209.789820194244,34490.15003442764,33209.789820194244,1273.975423336029,3.037425756454468,0.0 -97900,3.612415,2.4731264,,,,,,,,,,,,,, -98000,4.057306,2.5276668,,,,,,,,,,,,,, -98100,4.113021,2.5510788,,,,,,,,,,,,,, -98200,3.9628158,2.4187856,,,,,,,,,,,,,, -98300,4.4519563,2.5056794,,,,,,,,,,,,,, -98400,3.519541,2.4663267,,,,,,,,,,,,,, -98500,4.046599,2.4770246,,,,,,,,,,,,,, -98600,3.8298385,2.50927,,,,,,,,,,,,,, -98700,3.9763432,2.497029,,,,,,,,,,,,,, -98800,4.469265,2.5752096,,,,,,,,,,,,,, -98900,4.694951,2.444672,,,,,,,,,,,,,, -99000,3.6160755,2.5742748,,,,,,,,,,,,,, -99100,4.615631,2.4928775,,,,,,,,,,,,,, -99200,4.10104,2.494493,,,,,,,,,,,,,, -99300,4.0270715,2.4992225,,,,,,,,,,,,,, -99329,,,0.7132493257522583,1.2412854433059692,0.6611199975013733,1.480249047279358,50000.0,0.5276000499725342,2.150489568710327,10000.0,33719.99088358879,35017.532859802246,33719.99088358879,1291.060376405716,3.081228733062744,0.0 -99400,3.4775302,2.4771974,,,,,,,,,,,,,, -99500,3.9668984,2.5184364,,,,,,,,,,,,,, -99600,3.9668238,2.5327306,,,,,,,,,,,,,, -99700,3.7898922,2.5154095,,,,,,,,,,,,,, -99800,3.717359,2.4210207,,,,,,,,,,,,,, -99900,4.2784767,2.5391016,,,,,,,,,,,,,, -100000,3.9383307,2.39476,,,,,,,,,,,,,, -100100,3.6978104,2.4799247,,,,,,,,,,,,,, -100200,4.064404,2.5050554,,,,,,,,,,,,,, -100300,3.667033,2.473528,,,,,,,,,,,,,, -100400,4.011183,2.4941175,,,,,,,,,,,,,, -100500,3.8430898,2.413337,,,,,,,,,,,,,, -100600,4.271947,2.4213345,,,,,,,,,,,,,, -100700,3.8636749,2.4451249,,,,,,,,,,,,,, -100800,3.696046,2.4257324,,,,,,,,,,,,,, -100836,,,0.7284757494926453,1.1622774600982666,0.6675999760627747,1.4319740533828735,50000.0,0.5448000431060791,2.0938448905944824,10000.0,34229.93264293671,35544.59036445618,34229.93264293671,1308.0780620574951,3.127570152282715,0.0 -100900,4.3087816,2.5770407,,,,,,,,,,,,,, -101000,4.270846,2.4948275,,,,,,,,,,,,,, -101100,3.6277328,2.5470688,,,,,,,,,,,,,, -101200,4.326176,2.4345043,,,,,,,,,,,,,, -101300,4.297858,2.4825997,,,,,,,,,,,,,, -101400,3.8601148,2.435741,,,,,,,,,,,,,, -101500,3.9465923,2.522184,,,,,,,,,,,,,, -101600,3.5638561,2.462099,,,,,,,,,,,,,, -101700,3.7324677,2.5600858,,,,,,,,,,,,,, -101800,4.3758636,2.3933382,,,,,,,,,,,,,, -101900,3.9711075,2.5020933,,,,,,,,,,,,,, -102000,4.0032234,2.4438353,,,,,,,,,,,,,, -102100,3.6485162,2.4969282,,,,,,,,,,,,,, -102200,3.6477363,2.4898398,,,,,,,,,,,,,, -102300,3.7305672,2.4655418,,,,,,,,,,,,,, -102343,,,0.7571149468421936,1.045891046524048,0.675879955291748,1.4058371782302856,50000.0,0.5517000555992126,2.0450947284698486,10000.0,34739.89816379547,36071.67170572281,34739.89816379547,1325.09671998024,3.1725218296051025,0.0 -102400,4.130425,2.4310026,,,,,,,,,,,,,, -102500,4.244378,2.4094088,,,,,,,,,,,,,, -102600,4.058507,2.4345827,,,,,,,,,,,,,, -102700,3.5979583,2.4578412,,,,,,,,,,,,,, -102800,4.291101,2.5214849,,,,,,,,,,,,,, -102900,3.6123478,2.4301805,,,,,,,,,,,,,, -103000,4.4983377,2.4799752,,,,,,,,,,,,,, -103100,3.5528438,2.420466,,,,,,,,,,,,,, -103200,4.1594086,2.46181,,,,,,,,,,,,,, -103300,4.56728,2.516208,,,,,,,,,,,,,, -103400,4.302914,2.47488,,,,,,,,,,,,,, -103500,3.927101,2.4379768,,,,,,,,,,,,,, -103600,4.216721,2.5087738,,,,,,,,,,,,,, -103700,4.1387963,2.4304466,,,,,,,,,,,,,, -103800,4.2055807,2.3983984,,,,,,,,,,,,,, -103850,,,0.7410514950752258,1.109355092048645,0.6710000038146973,1.4239152669906616,50000.0,0.5487000346183777,2.0953757762908936,10000.0,35250.05817198753,36599.079026699066,35250.05817198753,1342.2405395507812,3.224278211593628,0.0 -103900,3.8216166,2.4512305,,,,,,,,,,,,,, -104000,3.7989657,2.407841,,,,,,,,,,,,,, -104100,4.199044,2.4891913,,,,,,,,,,,,,, -104200,4.538221,2.4833302,,,,,,,,,,,,,, -104300,4.058106,2.4807525,,,,,,,,,,,,,, -104400,3.8757267,2.4844322,,,,,,,,,,,,,, -104500,3.9925578,2.4610236,,,,,,,,,,,,,, -104600,4.2853274,2.52176,,,,,,,,,,,,,, -104700,3.9023235,2.4792695,,,,,,,,,,,,,, -104800,4.455448,2.4681156,,,,,,,,,,,,,, -104900,6.432749,2.5409093,,,,,,,,,,,,,, -105000,3.6148217,2.3860962,,,,,,,,,,,,,, -105100,3.9457643,2.4586103,,,,,,,,,,,,,, -105200,4.176882,2.5318112,,,,,,,,,,,,,, -105300,4.490621,2.4180717,,,,,,,,,,,,,, -105356,,,0.7430245280265808,1.116188883781433,0.6763399839401245,1.4128817319869995,50000.0,0.5460000038146973,2.064647674560547,10000.0,35759.976840257645,37126.32803606987,35759.976840257645,1359.4712007045746,3.270002841949463,0.0 -105400,4.2146287,2.5404744,,,,,,,,,,,,,, -105500,4.6772904,2.5113926,,,,,,,,,,,,,, -105600,3.839474,2.4478195,,,,,,,,,,,,,, -105700,4.528245,2.4868853,,,,,,,,,,,,,, -105800,4.112986,2.432323,,,,,,,,,,,,,, -105900,3.9827116,2.3807068,,,,,,,,,,,,,, -106000,3.6533306,2.4935918,,,,,,,,,,,,,, -106100,4.119759,2.4475756,,,,,,,,,,,,,, -106200,4.023516,2.3651252,,,,,,,,,,,,,, -106300,3.6215212,2.3601398,,,,,,,,,,,,,, -106400,4.230387,2.50191,,,,,,,,,,,,,, -106500,4.1410317,2.538101,,,,,,,,,,,,,, -106600,4.098635,2.4815276,,,,,,,,,,,,,, -106700,4.2667747,2.5011804,,,,,,,,,,,,,, -106800,4.242419,2.417226,,,,,,,,,,,,,, -106863,,,0.7424864172935486,1.115410089492798,0.6801799535751343,1.388283610343933,50000.0,0.5570999979972839,2.026524066925049,10000.0,36269.989028692245,37653.3395678997,36269.989028692245,1376.36949300766,3.318311929702759,0.0 -106900,4.852457,2.469089,,,,,,,,,,,,,, -107000,3.6218214,2.3945632,,,,,,,,,,,,,, -107100,4.154459,2.4612007,,,,,,,,,,,,,, -107200,4.0862694,2.402587,,,,,,,,,,,,,, -107300,4.1156845,2.3681595,,,,,,,,,,,,,, -107400,3.9540386,2.4198027,,,,,,,,,,,,,, -107500,4.017131,2.3966613,,,,,,,,,,,,,, -107600,4.002015,2.415003,,,,,,,,,,,,,, -107700,4.6839504,2.5199614,,,,,,,,,,,,,, -107800,4.477203,2.4761338,,,,,,,,,,,,,, -107900,4.066377,2.4472356,,,,,,,,,,,,,, -108000,3.8870792,2.4944882,,,,,,,,,,,,,, -108100,3.9864483,2.4766245,,,,,,,,,,,,,, -108200,3.8827977,2.410283,,,,,,,,,,,,,, -108300,4.2258854,2.3952143,,,,,,,,,,,,,, -108370,,,0.7401745915412903,1.10150146484375,0.6816999912261963,1.3689839839935305,50000.0,0.5527999997138977,2.0306382179260254,10000.0,36780.031766176224,38180.82426738739,36780.031766176224,1393.7109084129331,3.367356061935425,0.0 -108400,4.56156,2.5073519,,,,,,,,,,,,,, -108500,4.429436,2.3753624,,,,,,,,,,,,,, -108600,4.2289796,2.4695706,,,,,,,,,,,,,, -108700,4.4385366,2.4456115,,,,,,,,,,,,,, -108800,4.4637165,2.4551523,,,,,,,,,,,,,, -108900,4.7107377,2.4105806,,,,,,,,,,,,,, -109000,4.215518,2.4662642,,,,,,,,,,,,,, -109100,4.3094563,2.5215306,,,,,,,,,,,,,, -109200,3.9427223,2.449462,,,,,,,,,,,,,, -109300,5.0736184,2.4485245,,,,,,,,,,,,,, -109400,4.1930895,2.4057431,,,,,,,,,,,,,, -109500,4.002015,2.420375,,,,,,,,,,,,,, -109600,4.519369,2.5322256,,,,,,,,,,,,,, -109700,4.1474776,2.446612,,,,,,,,,,,,,, -109800,3.9554877,2.4825954,,,,,,,,,,,,,, -109877,,,0.7429846525192261,1.0977813005447388,0.6827200055122375,1.3649893999099731,50000.0,0.5592000484466553,2.0068023204803467,10000.0,37290.11735486984,38708.2563290596,37290.11735486984,1410.958943605423,3.413281202316284,0.0 -109900,4.0420365,2.4157324,,,,,,,,,,,,,, -110000,4.2216363,2.4733603,,,,,,,,,,,,,, -110100,3.85207,2.4107218,,,,,,,,,,,,,, -110200,4.369595,2.433619,,,,,,,,,,,,,, -110300,4.189379,2.3327103,,,,,,,,,,,,,, -110400,3.8405497,2.4656842,,,,,,,,,,,,,, -110500,4.02547,2.442381,,,,,,,,,,,,,, -110600,3.9430623,2.3826537,,,,,,,,,,,,,, -110700,4.1018534,2.4563894,,,,,,,,,,,,,, -110800,4.1593995,2.3889298,,,,,,,,,,,,,, -110900,4.836023,2.4824617,,,,,,,,,,,,,, -111000,4.4835305,2.3469448,,,,,,,,,,,,,, -111100,3.8269246,2.3150997,,,,,,,,,,,,,, -111200,4.4810987,2.4637222,,,,,,,,,,,,,, -111300,5.6933475,2.3789334,,,,,,,,,,,,,, -111385,,,0.7664620280265808,1.0001466274261477,0.6792399883270264,1.3854286670684814,50000.0,0.551300048828125,2.0405306816101074,10000.0,37800.23961639404,39235.68991231918,37800.23961639404,1428.169054031372,3.4625115394592285,0.0 -111400,4.168498,2.3886104,,,,,,,,,,,,,, -111500,4.2533464,2.3795352,,,,,,,,,,,,,, -111600,4.546812,2.3859487,,,,,,,,,,,,,, -111700,4.3554773,2.3838792,,,,,,,,,,,,,, -111800,3.9746222,2.4893618,,,,,,,,,,,,,, -111900,4.3577685,2.3792176,,,,,,,,,,,,,, -112000,3.9235678,2.4506257,,,,,,,,,,,,,, -112100,4.182967,2.3817232,,,,,,,,,,,,,, -112200,4.5282135,2.4635036,,,,,,,,,,,,,, -112300,4.9227815,2.4121447,,,,,,,,,,,,,, -112400,4.4698715,2.349537,,,,,,,,,,,,,, -112500,4.178269,2.3557549,,,,,,,,,,,,,, -112600,3.996601,2.4576468,,,,,,,,,,,,,, -112700,5.108054,2.50825,,,,,,,,,,,,,, -112800,4.00102,2.3641534,,,,,,,,,,,,,, -112892,,,0.7581512928009033,1.031977891921997,0.6830799579620361,1.3659889698028564,50000.0,0.5605000257492065,2.010292291641236,10000.0,38310.37893438339,39763.20695757866,38310.37893438339,1445.4499547481537,3.506695508956909,0.0 -112900,5.0645185,2.41852,,,,,,,,,,,,,, -113000,4.7617564,2.4443638,,,,,,,,,,,,,, -113100,4.4970217,2.3365996,,,,,,,,,,,,,, -113200,4.2657757,2.4134228,,,,,,,,,,,,,, -113300,4.647787,2.4486647,,,,,,,,,,,,,, -113400,4.131575,2.3106687,,,,,,,,,,,,,, -113500,4.044937,2.3580537,,,,,,,,,,,,,, -113600,4.3572187,2.3928547,,,,,,,,,,,,,, -113700,4.1441517,2.3368404,,,,,,,,,,,,,, -113800,5.7321854,2.4274917,,,,,,,,,,,,,, -113900,4.334296,2.3789368,,,,,,,,,,,,,, -114000,4.520622,2.3857114,,,,,,,,,,,,,, -114100,4.891865,2.467017,,,,,,,,,,,,,, -114200,4.892727,2.4475875,,,,,,,,,,,,,, -114300,4.321483,2.4234853,,,,,,,,,,,,,, -114399,,,0.7623365521430969,1.0210669040679932,0.6910600066184998,1.3425350189208984,50000.0,0.5693000555038452,1.9787520170211792,10000.0,38820.30667281151,40290.31508851051,38820.30667281151,1462.527009487152,3.55802059173584,0.0 -114400,4.2578826,2.3989115,,,,,,,,,,,,,, -114500,4.465198,2.3136737,,,,,,,,,,,,,, -114600,4.893187,2.480427,,,,,,,,,,,,,, -114700,3.8084176,2.2927942,,,,,,,,,,,,,, -114800,4.1587896,2.3384295,,,,,,,,,,,,,, -114900,4.409585,2.3530097,,,,,,,,,,,,,, -115000,5.162622,2.38631,,,,,,,,,,,,,, -115100,4.397873,2.364143,,,,,,,,,,,,,, -115200,4.5773716,2.4127684,,,,,,,,,,,,,, -115300,4.473305,2.286802,,,,,,,,,,,,,, -115400,4.618923,2.460541,,,,,,,,,,,,,, -115500,4.462703,2.3598201,,,,,,,,,,,,,, -115600,4.0365605,2.3561015,,,,,,,,,,,,,, -115700,4.914153,2.4791446,,,,,,,,,,,,,, -115800,4.389757,2.3653584,,,,,,,,,,,,,, -115900,4.197872,2.446908,,,,,,,,,,,,,, -115906,,,0.7527702450752258,1.0518150329589844,0.6881200075149536,1.3431140184402466,50000.0,0.5593000054359436,2.0154571533203125,10000.0,39330.32674264908,40817.67096376419,39330.32674264908,1479.7441306114197,3.622938871383667,0.0 -116000,3.895187,2.3021684,,,,,,,,,,,,,, -116100,4.8129883,2.4512465,,,,,,,,,,,,,, -116200,4.524418,2.3940592,,,,,,,,,,,,,, -116300,4.4382453,2.3426514,,,,,,,,,,,,,, -116400,4.63963,2.3900814,,,,,,,,,,,,,, -116500,4.438001,2.333404,,,,,,,,,,,,,, -116600,4.2193193,2.3977191,,,,,,,,,,,,,, -116700,4.693877,2.4303114,,,,,,,,,,,,,, -116800,4.341075,2.3627515,,,,,,,,,,,,,, -116900,4.5606093,2.3975673,,,,,,,,,,,,,, -117000,4.3841476,2.2314904,,,,,,,,,,,,,, -117100,5.107866,2.3334217,,,,,,,,,,,,,, -117200,4.947049,2.3371568,,,,,,,,,,,,,, -117300,4.9197574,2.419439,,,,,,,,,,,,,, -117400,4.7807593,2.3969252,,,,,,,,,,,,,, -117413,,,0.7529894709587097,1.057737946510315,0.6910399794578552,1.3415244817733765,50000.0,0.5701000094413757,1.969575047492981,10000.0,39840.43624377251,41345.04013109207,39840.43624377251,1496.8971843719482,3.6775968074798575,0.0 -117500,4.893809,2.406622,,,,,,,,,,,,,, -117600,5.0591717,2.432578,,,,,,,,,,,,,, -117700,4.5427504,2.3625185,,,,,,,,,,,,,, -117800,4.0074825,2.4585042,,,,,,,,,,,,,, -117900,4.1811123,2.3871884,,,,,,,,,,,,,, -118000,4.457277,2.427627,,,,,,,,,,,,,, -118100,4.4260798,2.4519184,,,,,,,,,,,,,, -118200,4.6454816,2.394082,,,,,,,,,,,,,, -118300,4.4202914,2.3961658,,,,,,,,,,,,,, -118400,4.1982074,2.2526977,,,,,,,,,,,,,, -118500,4.823434,2.3672578,,,,,,,,,,,,,, -118600,4.8833213,2.3489037,,,,,,,,,,,,,, -118700,4.544959,2.3367474,,,,,,,,,,,,,, -118800,4.7118316,2.3309846,,,,,,,,,,,,,, -118900,5.267985,2.369894,,,,,,,,,,,,,, -118920,,,0.7530691623687744,1.0459004640579224,0.6911999583244324,1.3308132886886597,50000.0,0.5654000043869019,1.998255014419556,10000.0,40350.61412215233,41872.54306650162,40350.61412215233,1514.122484207153,3.725899457931519,0.0 -119000,5.4916778,2.3909392,,,,,,,,,,,,,, -119100,5.183334,2.3862598,,,,,,,,,,,,,, -119200,4.5210514,2.3073437,,,,,,,,,,,,,, -119300,4.192123,2.347826,,,,,,,,,,,,,, -119400,4.9550753,2.3637505,,,,,,,,,,,,,, -119500,4.907557,2.383763,,,,,,,,,,,,,, -119600,4.737233,2.3904767,,,,,,,,,,,,,, -119700,5.3482084,2.305001,,,,,,,,,,,,,, -119800,4.635854,2.4443088,,,,,,,,,,,,,, -119900,4.4505587,2.3745008,,,,,,,,,,,,,, -120000,4.4716816,2.3301413,,,,,,,,,,,,,, -120100,4.332887,2.3200917,,,,,,,,,,,,,, -120200,4.3208146,2.352587,,,,,,,,,,,,,, -120300,4.3985167,2.3331993,,,,,,,,,,,,,, -120400,4.3357825,2.2784371,,,,,,,,,,,,,, -120427,,,0.7836814522743225,0.931477665901184,0.693120002746582,1.3276666402816772,50000.0,0.5664000511169434,1.9804224967956543,10000.0,40860.52008938789,42399.89472198486,40860.52008938789,1531.4665586948397,3.774169683456421,0.0 -120500,5.604633,2.331043,,,,,,,,,,,,,, -120600,4.7164316,2.3478408,,,,,,,,,,,,,, -120700,4.8710136,2.3237696,,,,,,,,,,,,,, -120800,4.4601703,2.343319,,,,,,,,,,,,,, -120900,4.6217585,2.2591257,,,,,,,,,,,,,, -121000,4.9189725,2.4131753,,,,,,,,,,,,,, -121100,4.220987,2.2801495,,,,,,,,,,,,,, -121200,4.8667865,2.2373898,,,,,,,,,,,,,, -121300,4.4620695,2.3048458,,,,,,,,,,,,,, -121400,6.074312,2.3229685,,,,,,,,,,,,,, -121500,4.528856,2.3310945,,,,,,,,,,,,,, -121600,4.4927936,2.3622978,,,,,,,,,,,,,, -121700,4.603373,2.31585,,,,,,,,,,,,,, -121800,5.5339613,2.3882248,,,,,,,,,,,,,, -121900,4.401476,2.2838075,,,,,,,,,,,,,, -121934,,,0.7770846486091614,0.960997998714447,0.7005199790000916,1.2985070943832395,50000.0,0.5767000317573547,1.9412811994552608,10000.0,41370.45637798309,42927.18236398697,41370.45637798309,1548.710284948349,3.828591108322144,0.0 -122000,4.8239145,2.3190882,,,,,,,,,,,,,, -122100,4.9543686,2.3779213,,,,,,,,,,,,,, -122200,4.7616453,2.292246,,,,,,,,,,,,,, -122300,5.591048,2.3480356,,,,,,,,,,,,,, -122400,5.051527,2.2960966,,,,,,,,,,,,,, -122500,4.600073,2.3324828,,,,,,,,,,,,,, -122600,4.812874,2.2470586,,,,,,,,,,,,,, -122700,4.496955,2.3505104,,,,,,,,,,,,,, -122800,4.678561,2.2981608,,,,,,,,,,,,,, -122900,4.71578,2.358904,,,,,,,,,,,,,, -123000,4.977549,2.3286934,,,,,,,,,,,,,, -123100,4.6077986,2.2138655,,,,,,,,,,,,,, -123200,4.730993,2.311769,,,,,,,,,,,,,, -123300,5.216346,2.362445,,,,,,,,,,,,,, -123400,4.6964,2.3873699,,,,,,,,,,,,,, -123441,,,0.7704280614852905,0.9899856448173524,0.6974999904632568,1.3206807374954224,50000.0,0.5712000131607056,1.9497255086898804,10000.0,41880.4911146164,43454.39989852905,41880.4911146164,1565.789316654205,3.87968111038208,0.0 -123500,4.462153,2.3132737,,,,,,,,,,,,,, -123600,4.984945,2.2327533,,,,,,,,,,,,,, -123700,4.9439425,2.3376179,,,,,,,,,,,,,, -123800,5.4105268,2.235517,,,,,,,,,,,,,, -123900,5.4650736,2.3661048,,,,,,,,,,,,,, -124000,5.2003284,2.3452687,,,,,,,,,,,,,, -124100,4.652879,2.3417459,,,,,,,,,,,,,, -124200,4.9013186,2.3681612,,,,,,,,,,,,,, -124300,5.1490226,2.1800284,,,,,,,,,,,,,, -124400,4.975851,2.401412,,,,,,,,,,,,,, -124500,4.60688,2.3871648,,,,,,,,,,,,,, -124600,5.5366573,2.3847582,,,,,,,,,,,,,, -124700,5.1300616,2.2859423,,,,,,,,,,,,,, -124800,5.047518,2.310514,,,,,,,,,,,,,, -124900,4.9503574,2.3689647,,,,,,,,,,,,,, -124948,,,0.7798947691917419,0.9331660270690918,0.7085399627685547,1.2580126523971558,50000.0,0.5805000066757202,1.8990615606307983,10000.0,42390.486533641815,43981.515066862106,42390.486533641815,1582.8079690933228,3.928217649459839,0.0 -125000,4.9195504,2.3631876,,,,,,,,,,,,,, -125100,4.7184677,2.3075688,,,,,,,,,,,,,, -125200,4.9963083,2.2635732,,,,,,,,,,,,,, -125300,4.8438373,2.3487842,,,,,,,,,,,,,, -125400,4.628003,2.3041856,,,,,,,,,,,,,, -125500,5.4579706,2.3524919,,,,,,,,,,,,,, -125600,6.454068,2.3307974,,,,,,,,,,,,,, -125700,4.852807,2.298576,,,,,,,,,,,,,, -125800,5.3302183,2.373017,,,,,,,,,,,,,, -125900,4.5401573,2.183684,,,,,,,,,,,,,, -126000,5.34927,2.3265896,,,,,,,,,,,,,, -126100,5.2430716,2.1860785,,,,,,,,,,,,,, -126200,4.442413,2.279428,,,,,,,,,,,,,, -126300,6.1426334,2.3596551,,,,,,,,,,,,,, -126400,4.5876966,2.2364645,,,,,,,,,,,,,, -126455,,,0.7784597873687744,0.9620999693870544,0.7042999863624573,1.283419847488403,50000.0,0.5834000110626221,1.9199188947677608,10000.0,42900.42901682854,44508.50655436516,42900.42901682854,1599.7522914409635,3.979759454727173,0.0 -126500,4.7762628,2.305491,,,,,,,,,,,,,, -126600,5.0707126,2.265745,,,,,,,,,,,,,, -126700,4.81711,2.3003438,,,,,,,,,,,,,, -126800,4.7638726,2.3436944,,,,,,,,,,,,,, -126900,5.048692,2.2367043,,,,,,,,,,,,,, -127000,4.889726,2.2838871,,,,,,,,,,,,,, -127100,5.172703,2.3118992,,,,,,,,,,,,,, -127200,5.1452136,2.262961,,,,,,,,,,,,,, -127300,5.1383295,2.2866728,,,,,,,,,,,,,, -127400,4.80139,2.3799963,,,,,,,,,,,,,, -127500,6.332267,2.3441596,,,,,,,,,,,,,, -127600,4.81425,2.2919953,,,,,,,,,,,,,, -127700,5.49474,2.2760522,,,,,,,,,,,,,, -127800,4.5251775,2.2365458,,,,,,,,,,,,,, -127900,5.4869337,2.3445709,,,,,,,,,,,,,, -127961,,,0.7736966013908386,0.9492216110229492,0.7021399736404419,1.265101194381714,50000.0,0.5796000361442566,1.9083911180496216,10000.0,43410.41619229317,45035.59116268158,43410.41619229317,1616.7470650672913,4.031066417694092,0.0 -128000,5.4409738,2.310614,,,,,,,,,,,,,, -128100,5.1405606,2.3114746,,,,,,,,,,,,,, -128200,5.16418,2.246727,,,,,,,,,,,,,, -128300,4.7115746,2.235321,,,,,,,,,,,,,, -128400,5.7641177,2.28097,,,,,,,,,,,,,, -128500,5.527943,2.3046584,,,,,,,,,,,,,, -128600,4.818834,2.2530515,,,,,,,,,,,,,, -128700,4.835757,2.1978602,,,,,,,,,,,,,, -128800,5.252675,2.2789035,,,,,,,,,,,,,, -128900,4.9716268,2.2150018,,,,,,,,,,,,,, -129000,5.104177,2.202661,,,,,,,,,,,,,, -129100,5.0488105,2.2081394,,,,,,,,,,,,,, -129200,4.8050528,2.2876327,,,,,,,,,,,,,, -129300,5.215319,2.3052866,,,,,,,,,,,,,, -129400,5.35267,2.1547887,,,,,,,,,,,,,, -129469,,,0.805683970451355,0.8402982354164124,0.7069199681282043,1.2687928676605225,50000.0,0.5788000226020813,1.9307630062103271,10000.0,43920.58583164215,45563.149040699005,43920.58583164215,1634.0042452812197,4.109456300735474,0.0 -129500,5.0982614,2.2951484,,,,,,,,,,,,,, -129600,5.0068126,2.3286977,,,,,,,,,,,,,, -129700,5.0670753,2.257977,,,,,,,,,,,,,, -129800,4.830924,2.16096,,,,,,,,,,,,,, -129900,4.872844,2.25008,,,,,,,,,,,,,, -130000,5.285111,2.2560396,,,,,,,,,,,,,, -130100,4.7346063,2.2058945,,,,,,,,,,,,,, -130200,4.7620955,2.1493862,,,,,,,,,,,,,, -130300,5.058635,2.2259734,,,,,,,,,,,,,, -130400,4.893857,2.1911066,,,,,,,,,,,,,, -130500,5.06999,2.1778822,,,,,,,,,,,,,, -130600,5.717137,2.2636561,,,,,,,,,,,,,, -130700,4.9012027,2.2264838,,,,,,,,,,,,,, -130800,5.740614,2.3599415,,,,,,,,,,,,,, -130900,5.1519003,2.3146732,,,,,,,,,,,,,, -130977,,,0.7959781289100647,0.8902884125709534,0.7091799974441528,1.2596995830535889,50000.0,0.5836000442504883,1.9029459953308103,10000.0,44430.66488528252,46090.4052464962,44430.66488528252,1651.078492641449,4.160379648208618,0.0 -131000,5.3542485,2.2370534,,,,,,,,,,,,,, -131100,4.991638,2.2039025,,,,,,,,,,,,,, -131200,5.065652,2.1966908,,,,,,,,,,,,,, -131300,5.127347,2.2513003,,,,,,,,,,,,,, -131400,5.3700285,2.3200321,,,,,,,,,,,,,, -131500,4.9301233,2.1899934,,,,,,,,,,,,,, -131600,5.1581426,2.2488153,,,,,,,,,,,,,, -131700,4.868756,2.2387516,,,,,,,,,,,,,, -131800,5.0661616,2.1661363,,,,,,,,,,,,,, -131900,5.4300075,2.1451561,,,,,,,,,,,,,, -132000,4.8075676,2.2018828,,,,,,,,,,,,,, -132100,5.3006315,2.1785698,,,,,,,,,,,,,, -132200,5.2095804,2.2262762,,,,,,,,,,,,,, -132300,6.411537,2.2780209,,,,,,,,,,,,,, -132400,4.9756575,2.201717,,,,,,,,,,,,,, -132484,,,0.7964564561843872,0.8781556487083435,0.7150799632072449,1.2385895252227783,50000.0,0.5871000289916992,1.8773962259292605,10000.0,44940.60585641861,46617.5152451992,44940.60585641861,1668.14599943161,4.21102237701416,0.0 -132500,4.938317,2.2510662,,,,,,,,,,,,,, -132600,5.3357935,2.235773,,,,,,,,,,,,,, -132700,4.8458495,2.2611806,,,,,,,,,,,,,, -132800,5.398421,2.270404,,,,,,,,,,,,,, -132900,4.866294,2.2236218,,,,,,,,,,,,,, -133000,4.811748,2.2287304,,,,,,,,,,,,,, -133100,4.9124656,2.246306,,,,,,,,,,,,,, -133200,5.3505125,2.1361732,,,,,,,,,,,,,, -133300,5.5813136,2.2937794,,,,,,,,,,,,,, -133400,5.710444,2.198236,,,,,,,,,,,,,, -133500,5.2095523,2.2639875,,,,,,,,,,,,,, -133600,4.902765,2.1191013,,,,,,,,,,,,,, -133700,5.4554367,2.2410126,,,,,,,,,,,,,, -133800,4.903393,2.186776,,,,,,,,,,,,,, -133900,5.5571218,2.1819859,,,,,,,,,,,,,, -133991,,,0.8026745915412903,0.8502547144889832,0.7177799940109253,1.2106214761734009,50000.0,0.5932000279426575,1.840102195739746,10000.0,45450.668227911,47144.82175087929,45450.668227911,1685.284812450409,4.261689901351929,0.0 -134000,5.4296484,2.2066228,,,,,,,,,,,,,, -134100,5.123225,2.245906,,,,,,,,,,,,,, -134200,5.268555,2.1990392,,,,,,,,,,,,,, -134300,5.1016555,2.2073157,,,,,,,,,,,,,, -134400,5.7784452,2.2377098,,,,,,,,,,,,,, -134500,4.950174,2.1935477,,,,,,,,,,,,,, -134600,4.9126215,2.2362323,,,,,,,,,,,,,, -134700,5.015354,2.244369,,,,,,,,,,,,,, -134800,5.3861866,2.26081,,,,,,,,,,,,,, -134900,5.287631,2.2231107,,,,,,,,,,,,,, -135000,5.144796,2.2577844,,,,,,,,,,,,,, -135100,6.1201773,2.2392282,,,,,,,,,,,,,, -135200,5.7638226,2.2255905,,,,,,,,,,,,,, -135300,5.51778,2.1614485,,,,,,,,,,,,,, -135400,5.3585935,2.1826382,,,,,,,,,,,,,, -135498,,,0.7995057106018066,0.8691868782043457,0.7166799902915955,1.2189016342163086,50000.0,0.5916000008583069,1.859628438949585,10000.0,45960.74799871445,47672.127898454666,45960.74799871445,1702.4054865837095,4.315122365951538,0.0 -135500,4.971583,2.1874092,,,,,,,,,,,,,, -135600,5.387642,2.1846712,,,,,,,,,,,,,, -135700,5.651017,2.1827216,,,,,,,,,,,,,, -135800,5.2697954,2.193937,,,,,,,,,,,,,, -135900,6.7870073,2.260755,,,,,,,,,,,,,, -136000,5.6660666,2.2583187,,,,,,,,,,,,,, -136100,5.436568,2.2046986,,,,,,,,,,,,,, -136200,5.269227,2.1608844,,,,,,,,,,,,,, -136300,5.4578023,2.2189934,,,,,,,,,,,,,, -136400,5.5866985,2.225031,,,,,,,,,,,,,, -136500,5.6024404,2.2935627,,,,,,,,,,,,,, -136600,5.4282417,2.2788322,,,,,,,,,,,,,, -136700,5.94898,2.2080002,,,,,,,,,,,,,, -136800,5.456833,2.267587,,,,,,,,,,,,,, -136900,5.7346344,2.2016437,,,,,,,,,,,,,, -137000,5.9844513,2.1230493,,,,,,,,,,,,,, -137005,,,0.7967952489852905,0.862323522567749,0.7177000045776367,1.208281636238098,50000.0,0.5957000255584717,1.8467576503753664,10000.0,46470.69027876854,48199.111078739166,46470.69027876854,1719.3394558429718,4.369931697845459,0.0 -137100,5.791256,2.1673234,,,,,,,,,,,,,, -137200,5.0710535,2.1048524,,,,,,,,,,,,,, -137300,5.4826283,2.2055311,,,,,,,,,,,,,, -137400,5.57226,2.2142544,,,,,,,,,,,,,, -137500,6.693311,2.1754854,,,,,,,,,,,,,, -137600,6.0440803,2.2620957,,,,,,,,,,,,,, -137700,5.754511,2.1970077,,,,,,,,,,,,,, -137800,5.4394975,2.114153,,,,,,,,,,,,,, -137900,5.9070067,2.2169614,,,,,,,,,,,,,, -138000,5.9119515,2.228637,,,,,,,,,,,,,, -138100,5.5080237,2.1607413,,,,,,,,,,,,,, -138200,5.8827305,2.2728286,,,,,,,,,,,,,, -138300,4.8973536,2.1191688,,,,,,,,,,,,,, -138400,6.1587796,2.2584114,,,,,,,,,,,,,, -138500,5.483452,2.1861515,,,,,,,,,,,,,, -138512,,,0.8315330147743225,0.7364321947097778,0.7225199937820435,1.1886651515960691,50000.0,0.5956000089645386,1.8367879390716555,10000.0,46980.68811249733,48726.23480153084,46980.68811249733,1736.362550497055,4.420541524887085,0.0 -138600,6.1963983,2.1022007,,,,,,,,,,,,,, -138700,6.020135,2.2058468,,,,,,,,,,,,,, -138800,5.473957,2.2042966,,,,,,,,,,,,,, -138900,5.847062,2.1422002,,,,,,,,,,,,,, -139000,5.746412,2.2100492,,,,,,,,,,,,,, -139100,5.863413,2.1501365,,,,,,,,,,,,,, -139200,5.559093,2.1336815,,,,,,,,,,,,,, -139300,5.4288807,2.0929518,,,,,,,,,,,,,, -139400,5.5840926,2.1985316,,,,,,,,,,,,,, -139500,5.543734,2.1279542,,,,,,,,,,,,,, -139600,5.5555487,2.1788683,,,,,,,,,,,,,, -139700,5.4395475,2.1533394,,,,,,,,,,,,,, -139800,5.982488,2.2020879,,,,,,,,,,,,,, -139900,5.605173,2.1558425,,,,,,,,,,,,,, -140000,5.3608966,2.0836058,,,,,,,,,,,,,, -140019,,,0.8246771097183228,0.7462592720985413,0.729919970035553,1.1638081073760986,50000.0,0.6030000448226929,1.794895052909851,10000.0,47490.778540849686,49253.46666812897,47490.778540849686,1753.3998112678528,4.472010850906372,0.0 -140100,6.314786,2.1210515,,,,,,,,,,,,,, -140200,6.267622,2.1197994,,,,,,,,,,,,,, -140300,5.7739425,2.1494777,,,,,,,,,,,,,, -140400,5.842086,2.228551,,,,,,,,,,,,,, -140500,5.8778524,2.1764007,,,,,,,,,,,,,, -140600,6.366139,2.1865172,,,,,,,,,,,,,, -140700,5.49436,2.0850554,,,,,,,,,,,,,, -140800,5.897281,2.1946583,,,,,,,,,,,,,, -140900,5.473011,2.128488,,,,,,,,,,,,,, -141000,5.771557,2.1106322,,,,,,,,,,,,,, -141100,5.750139,2.2166693,,,,,,,,,,,,,, -141200,5.778118,2.113159,,,,,,,,,,,,,, -141300,5.6212907,2.14645,,,,,,,,,,,,,, -141400,5.6800623,2.103642,,,,,,,,,,,,,, -141500,6.0181417,2.199216,,,,,,,,,,,,,, -141526,,,0.8252949714660645,0.7593898177146912,0.7310400009155273,1.1624298095703125,50000.0,0.6044000387191772,1.7896476984024048,10000.0,48000.73310112953,49780.63418030739,48000.73310112953,1770.5072956085205,4.523514747619629,0.0 -141600,5.4490495,2.178797,,,,,,,,,,,,,, -141700,6.3503823,2.1429098,,,,,,,,,,,,,, -141800,5.6461253,2.173472,,,,,,,,,,,,,, -141900,6.2115164,2.1429377,,,,,,,,,,,,,, -142000,5.732097,2.169898,,,,,,,,,,,,,, -142100,5.665095,2.1009986,,,,,,,,,,,,,, -142200,6.0146213,2.1013572,,,,,,,,,,,,,, -142300,5.684213,2.1045256,,,,,,,,,,,,,, -142400,6.9456162,2.1435323,,,,,,,,,,,,,, -142500,5.7855005,2.0813382,,,,,,,,,,,,,, -142600,6.2145452,2.150844,,,,,,,,,,,,,, -142700,6.199267,2.1872146,,,,,,,,,,,,,, -142800,6.0112133,2.2100463,,,,,,,,,,,,,, -142900,5.797773,2.1729758,,,,,,,,,,,,,, -143000,6.5019917,2.1351972,,,,,,,,,,,,,, -143033,,,0.8251753449440002,0.7437108159065247,0.7314800024032593,1.141937017440796,50000.0,0.6105000376701355,1.773678183555603,10000.0,48510.89326953888,50308.06536364555,48510.89326953888,1787.6722025871277,4.576416730880737,0.0 -143100,5.293081,2.112245,,,,,,,,,,,,,, -143200,5.583184,2.1199548,,,,,,,,,,,,,, -143300,6.1032834,2.1380012,,,,,,,,,,,,,, -143400,5.799103,2.1612842,,,,,,,,,,,,,, -143500,6.1716104,2.1072435,,,,,,,,,,,,,, -143600,6.392242,2.153357,,,,,,,,,,,,,, -143700,5.9446473,2.1584268,,,,,,,,,,,,,, -143800,6.135903,2.1269782,,,,,,,,,,,,,, -143900,5.4318314,2.0594652,,,,,,,,,,,,,, -144000,6.403013,2.1145644,,,,,,,,,,,,,, -144100,5.8773527,2.1635432,,,,,,,,,,,,,, -144200,5.976665,2.0775578,,,,,,,,,,,,,, -144300,6.334276,2.0955667,,,,,,,,,,,,,, -144400,6.1236453,2.2260036,,,,,,,,,,,,,, -144500,6.4361587,2.1072419,,,,,,,,,,,,,, -144539,,,0.8236008882522583,0.7736658453941345,0.7317599654197693,1.1646684408187866,50000.0,0.6085000038146973,1.8027182817459104,10000.0,49020.87477731705,50835.20391178131,49020.87477731705,1804.7232937812803,4.630538702011108,0.0 -144600,6.2217803,2.112195,,,,,,,,,,,,,, -144700,6.567285,2.162012,,,,,,,,,,,,,, -144800,5.9905157,2.0747428,,,,,,,,,,,,,, -144900,5.704228,2.1703901,,,,,,,,,,,,,, -145000,6.726573,2.1266713,,,,,,,,,,,,,, -145100,5.9708195,2.1196585,,,,,,,,,,,,,, -145200,5.8218203,2.1439834,,,,,,,,,,,,,, -145300,6.23648,2.193316,,,,,,,,,,,,,, -145400,6.1024513,2.0939763,,,,,,,,,,,,,, -145500,5.6621,2.0183392,,,,,,,,,,,,,, -145600,5.9633455,2.0952597,,,,,,,,,,,,,, -145700,5.4954243,2.068313,,,,,,,,,,,,,, -145800,6.1030254,2.0688095,,,,,,,,,,,,,, -145900,6.1122212,2.1187334,,,,,,,,,,,,,, -146000,6.8410287,2.149062,,,,,,,,,,,,,, -146046,,,0.8278858065605164,0.7407735586166382,0.7351599931716919,1.1394431591033936,50000.0,0.615600049495697,1.7666029930114746,10000.0,49530.97762846947,51362.36286592484,49530.97762846947,1821.6753494739528,4.683336019515991,0.0 -146100,5.6981044,2.0310824,,,,,,,,,,,,,, -146200,5.9955077,2.1500413,,,,,,,,,,,,,, -146300,6.378954,2.1258874,,,,,,,,,,,,,, -146400,5.668297,2.1022248,,,,,,,,,,,,,, -146500,5.966694,2.0507817,,,,,,,,,,,,,, -146600,6.0855923,2.125142,,,,,,,,,,,,,, -146700,5.662941,2.0608082,,,,,,,,,,,,,, -146800,6.184531,1.9939342,,,,,,,,,,,,,, -146900,6.2527847,2.121588,,,,,,,,,,,,,, -147000,6.228484,2.0809903,,,,,,,,,,,,,, -147100,6.1779695,2.0752957,,,,,,,,,,,,,, -147200,6.118741,2.0618963,,,,,,,,,,,,,, -147300,6.116471,2.1566315,,,,,,,,,,,,,, -147400,5.7895236,2.0714893,,,,,,,,,,,,,, -147500,6.266503,2.086889,,,,,,,,,,,,,, -147552,,,0.85843825340271,0.6291110515594482,0.7379999756813049,1.130968689918518,50000.0,0.6152999997138977,1.7525743246078491,10000.0,50040.92667126656,51889.318643569946,50040.92667126656,1838.5789158344269,4.734616041183472,0.0 -147600,5.735433,1.9850533,,,,,,,,,,,,,, -147700,5.9122143,2.0499165,,,,,,,,,,,,,, -147800,6.1514244,2.0894122,,,,,,,,,,,,,, -147900,5.7946234,2.028346,,,,,,,,,,,,,, -148000,6.005404,2.0899649,,,,,,,,,,,,,, -148100,6.26198,2.042552,,,,,,,,,,,,,, -148200,6.357577,2.1612418,,,,,,,,,,,,,, -148300,5.745316,2.036817,,,,,,,,,,,,,, -148400,6.1308446,2.064022,,,,,,,,,,,,,, -148500,6.881778,2.1182172,,,,,,,,,,,,,, -148600,6.4316773,2.1515787,,,,,,,,,,,,,, -148700,7.051863,2.0289447,,,,,,,,,,,,,, -148800,7.2255116,2.1103892,,,,,,,,,,,,,, -148900,6.0535426,2.066267,,,,,,,,,,,,,, -149000,6.265662,2.0416842,,,,,,,,,,,,,, -149059,,,0.8477559089660645,0.667382538318634,0.7384999990463257,1.130212664604187,50000.0,0.6147000193595886,1.7563661336898804,10000.0,50551.148307561874,52416.78352403641,50551.148307561874,1855.714731216431,4.789153575897217,0.0 -149100,6.6071334,2.0826654,,,,,,,,,,,,,, -149200,6.119168,2.1199312,,,,,,,,,,,,,, -149300,6.440091,2.0691552,,,,,,,,,,,,,, -149400,6.5599165,2.1443086,,,,,,,,,,,,,, -149500,6.464209,2.0998027,,,,,,,,,,,,,, -149600,6.3494234,2.086997,,,,,,,,,,,,,, -149700,6.684592,2.040517,,,,,,,,,,,,,, -149800,6.476627,2.155502,,,,,,,,,,,,,, -149900,5.8745923,2.0524254,,,,,,,,,,,,,, -150000,6.226512,2.0411916,,,,,,,,,,,,,, -150100,6.502059,2.0561626,,,,,,,,,,,,,, -150200,6.046688,1.9741411,,,,,,,,,,,,,, -150300,6.3022485,2.1099954,,,,,,,,,,,,,, -150400,7.2133603,2.084544,,,,,,,,,,,,,, -150500,6.448851,2.0471811,,,,,,,,,,,,,, -150566,,,0.8537547588348389,0.6530286073684692,0.7407999634742737,1.108092188835144,50000.0,0.6205000281333923,1.735827088356018,10000.0,51061.33330345154,52944.4701063633,51061.33330345154,1873.110155582428,4.844248533248901,0.0 -150600,6.386923,1.9952351,,,,,,,,,,,,,, -150700,5.789716,1.9962251,,,,,,,,,,,,,, -150800,5.9993916,2.138988,,,,,,,,,,,,,, -150900,6.2240443,2.0359006,,,,,,,,,,,,,, -151000,6.313229,2.0494227,,,,,,,,,,,,,, -151100,6.097987,2.0066772,,,,,,,,,,,,,, -151200,5.8700604,1.9892045,,,,,,,,,,,,,, -151300,6.8569126,2.0528686,,,,,,,,,,,,,, -151400,5.594513,1.9484878,,,,,,,,,,,,,, -151500,6.723107,2.0058513,,,,,,,,,,,,,, -151600,6.1654124,2.0078363,,,,,,,,,,,,,, -151700,6.1174445,2.065864,,,,,,,,,,,,,, -151800,6.5227437,1.9865735,,,,,,,,,,,,,, -151900,6.4839425,2.0426161,,,,,,,,,,,,,, -152000,6.172814,1.9679583,,,,,,,,,,,,,, -152073,,,0.8474569320678711,0.6855957508087158,0.743939995765686,1.1150567531585691,50000.0,0.6213000416755676,1.74584698677063,10000.0,51571.34389591217,53472.46926546097,51571.34389591217,1890.9894676208496,4.900099754333496,0.0 -152100,6.3554215,1.9921018,,,,,,,,,,,,,, -152200,6.83929,1.9962554,,,,,,,,,,,,,, -152300,6.769699,2.0127256,,,,,,,,,,,,,, -152400,6.5573764,1.988289,,,,,,,,,,,,,, -152500,6.370758,2.0793722,,,,,,,,,,,,,, -152600,6.0987773,2.061552,,,,,,,,,,,,,, -152700,6.8477316,2.11207,,,,,,,,,,,,,, -152800,6.4508066,2.0766134,,,,,,,,,,,,,, -152900,6.841739,2.0187683,,,,,,,,,,,,,, -153000,6.7182956,2.0023282,,,,,,,,,,,,,, -153100,6.6903234,1.9583254,,,,,,,,,,,,,, -153200,6.252839,2.0030842,,,,,,,,,,,,,, -153300,5.855104,1.9404049,,,,,,,,,,,,,, -153400,6.8186884,2.0010774,,,,,,,,,,,,,, -153500,6.5824428,2.0825357,,,,,,,,,,,,,, -153580,,,0.8515625,0.6506571769714355,0.7438799738883972,1.100218653678894,50000.0,0.622700035572052,1.7106108665466309,10000.0,52081.42667245865,53999.78741884232,52081.42667245865,1908.1210148334503,4.9520440101623535,0.0 -153600,6.154376,2.0107698,,,,,,,,,,,,,, -153700,6.656628,1.9686432,,,,,,,,,,,,,, -153800,6.6621733,1.9726998,,,,,,,,,,,,,, -153900,6.208014,1.9836497,,,,,,,,,,,,,, -154000,6.8230495,2.042388,,,,,,,,,,,,,, -154100,6.639668,1.9425039,,,,,,,,,,,,,, -154200,6.8306413,2.0745938,,,,,,,,,,,,,, -154300,6.285861,1.9205722,,,,,,,,,,,,,, -154400,7.1149607,1.9607357,,,,,,,,,,,,,, -154500,6.613122,1.9899144,,,,,,,,,,,,,, -154600,7.0053296,2.0049288,,,,,,,,,,,,,, -154700,6.9968133,1.9704547,,,,,,,,,,,,,, -154800,7.3531895,1.9920728,,,,,,,,,,,,,, -154900,7.200627,2.0184462,,,,,,,,,,,,,, -155000,6.9090714,2.059928,,,,,,,,,,,,,, -155087,,,0.8556082248687744,0.6351298093795776,0.7490400075912476,1.0883169174194336,50000.0,0.6258000135421753,1.720607876777649,10000.0,52591.56684565544,54527.11508727074,52591.56684565544,1925.201922893524,5.0060014724731445,0.0 -155100,6.3092155,2.0113134,,,,,,,,,,,,,, -155200,7.0517955,2.0568647,,,,,,,,,,,,,, -155300,7.437629,2.0042207,,,,,,,,,,,,,, -155400,6.9545283,2.0504584,,,,,,,,,,,,,, -155500,6.3627167,1.9475436,,,,,,,,,,,,,, -155600,7.1636734,2.0173624,,,,,,,,,,,,,, -155700,7.296175,2.0339205,,,,,,,,,,,,,, -155800,6.8856516,2.0407789,,,,,,,,,,,,,, -155900,7.167302,2.0618694,,,,,,,,,,,,,, -156000,6.067742,1.9295397,,,,,,,,,,,,,, -156100,6.844066,1.9594549,,,,,,,,,,,,,, -156200,6.3982944,2.0870118,,,,,,,,,,,,,, -156300,6.6592402,1.9205172,,,,,,,,,,,,,, -156400,6.4798965,1.9554358,,,,,,,,,,,,,, -156500,6.331071,1.9882059,,,,,,,,,,,,,, -156594,,,0.8785076141357422,0.5570769309997559,0.7483800053596497,1.0872067213058472,50000.0,0.6260000467300415,1.7072957754135132,10000.0,53101.60917067528,55054.2172267437,53101.60917067528,1942.1512784957888,5.063514232635498,0.0 -156600,6.502505,1.8993268,,,,,,,,,,,,,, -156700,6.475976,1.9355483,,,,,,,,,,,,,, -156800,6.4574113,1.9583377,,,,,,,,,,,,,, -156900,6.3676844,1.8932284,,,,,,,,,,,,,, -157000,6.9397936,1.9134103,,,,,,,,,,,,,, -157100,7.0572224,1.9311073,,,,,,,,,,,,,, -157200,6.7309856,1.9997517,,,,,,,,,,,,,, -157300,7.0260763,1.9744297,,,,,,,,,,,,,, -157400,7.046474,2.02372,,,,,,,,,,,,,, -157500,7.0999856,1.9477154,,,,,,,,,,,,,, -157600,7.8317637,2.002726,,,,,,,,,,,,,, -157700,7.329295,1.9924382,,,,,,,,,,,,,, -157800,7.007186,2.0325181,,,,,,,,,,,,,, -157900,7.7575564,1.992578,,,,,,,,,,,,,, -158000,6.898694,1.9654768,,,,,,,,,,,,,, -158100,7.4101505,1.9523776,,,,,,,,,,,,,, -158101,,,0.8756377100944519,0.5695292949676514,0.7520399689674377,1.07098388671875,50000.0,0.628600001335144,1.6892168521881104,10000.0,53611.69621825218,55581.5731716156,53611.69621825218,1959.310719013214,5.121617794036865,0.0 -158200,7.3506184,1.9219894,,,,,,,,,,,,,, -158300,7.044776,1.9223697,,,,,,,,,,,,,, -158400,6.3752565,1.9006474,,,,,,,,,,,,,, -158500,6.7838387,2.0123892,,,,,,,,,,,,,, -158600,7.019328,1.9365189,,,,,,,,,,,,,, -158700,6.841166,1.9759078,,,,,,,,,,,,,, -158800,6.6711183,1.9390187,,,,,,,,,,,,,, -158900,6.688522,1.961014,,,,,,,,,,,,,, -159000,6.435457,1.9024894,,,,,,,,,,,,,, -159100,7.015312,1.9544897,,,,,,,,,,,,,, -159200,6.6855803,1.9309378,,,,,,,,,,,,,, -159300,7.699299,1.9274195,,,,,,,,,,,,,, -159400,6.7264376,1.9475738,,,,,,,,,,,,,, -159500,6.971914,1.8760443,,,,,,,,,,,,,, -159600,7.2844386,2.0040252,,,,,,,,,,,,,, -159608,,,0.8757772445678711,0.5535916686058044,0.754539966583252,1.05316162109375,50000.0,0.6337000131607056,1.6755914688110352,10000.0,54121.8637149334,56108.89868927002,54121.8637149334,1976.359569311142,5.178009748458862,0.0 -159700,6.785532,1.9088469,,,,,,,,,,,,,, -159800,7.235732,1.9254674,,,,,,,,,,,,,, -159900,6.9325914,1.9620829,,,,,,,,,,,,,, -160000,6.5504727,1.8474852,,,,,,,,,,,,,, -160100,6.7923403,1.8809001,,,,,,,,,,,,,, -160200,7.0622954,1.9379299,,,,,,,,,,,,,, -160300,7.411059,1.8940398,,,,,,,,,,,,,, -160400,7.636548,1.9376063,,,,,,,,,,,,,, -160500,7.1587243,1.9410145,,,,,,,,,,,,,, -160600,6.958854,1.852394,,,,,,,,,,,,,, -160700,7.259177,2.005073,,,,,,,,,,,,,, -160800,7.360055,1.9739213,,,,,,,,,,,,,, -160900,7.3284783,1.93723,,,,,,,,,,,,,, -161000,6.841175,1.9393432,,,,,,,,,,,,,, -161100,7.0125,2.005286,,,,,,,,,,,,,, -161115,,,0.8797831535339355,0.5523720979690552,0.7566799521446228,1.0592304468154907,50000.0,0.6381000280380249,1.6685220003128052,10000.0,54631.85998415947,56636.17068815231,54631.85998415947,1993.5300033092497,5.232148885726929,0.0 -161200,6.898536,1.9211578,,,,,,,,,,,,,, -161300,6.3364344,1.9455643,,,,,,,,,,,,,, -161400,7.1234136,1.9134647,,,,,,,,,,,,,, -161500,7.2791886,1.9136188,,,,,,,,,,,,,, -161600,7.3339067,1.9423912,,,,,,,,,,,,,, -161700,7.8670955,1.8613682,,,,,,,,,,,,,, -161800,7.162848,1.9146973,,,,,,,,,,,,,, -161900,7.0986075,1.9047089,,,,,,,,,,,,,, -162000,6.739705,1.9389015,,,,,,,,,,,,,, -162100,7.3869157,1.953538,,,,,,,,,,,,,, -162200,6.781286,1.9176321,,,,,,,,,,,,,, -162300,7.455319,2.01764,,,,,,,,,,,,,, -162400,7.1840606,1.9954451,,,,,,,,,,,,,, -162500,6.9813056,1.9154909,,,,,,,,,,,,,, -162600,7.5659,1.8838307,,,,,,,,,,,,,, -162622,,,0.8821747303009033,0.5450007915496826,0.758080005645752,1.0532749891281128,50000.0,0.6375000476837158,1.665583252906799,10000.0,55141.8918106556,57163.43918633461,55141.8918106556,2010.63884305954,5.307616949081421,0.0 -162700,7.379132,1.9924814,,,,,,,,,,,,,, -162800,7.4510036,1.86484,,,,,,,,,,,,,, -162900,7.4662805,1.9419075,,,,,,,,,,,,,, -163000,6.835957,1.9024283,,,,,,,,,,,,,, -163100,7.079223,1.8683524,,,,,,,,,,,,,, -163200,7.030488,1.9088093,,,,,,,,,,,,,, -163300,7.684036,1.87888,,,,,,,,,,,,,, -163400,7.0922503,1.918963,,,,,,,,,,,,,, -163500,7.466949,1.9528632,,,,,,,,,,,,,, -163600,7.4655747,1.917609,,,,,,,,,,,,,, -163700,7.5047655,1.8961624,,,,,,,,,,,,,, -163800,7.0790997,1.8812983,,,,,,,,,,,,,, -163900,7.314992,1.9228486,,,,,,,,,,,,,, -164000,7.364346,1.8813608,,,,,,,,,,,,,, -164100,7.782863,1.8811818,,,,,,,,,,,,,, -164129,,,0.8889508843421936,0.5121031999588013,0.761139988899231,1.0397335290908811,50000.0,0.6361000537872314,1.6553908586502075,10000.0,55651.94191503525,57690.5752518177,55651.94191503525,2027.6138999462128,5.36580753326416,0.0 -164200,6.9130025,1.8561913,,,,,,,,,,,,,, -164300,7.39548,1.8738283,,,,,,,,,,,,,, -164400,7.530449,1.9170454,,,,,,,,,,,,,, -164500,7.6800504,1.9187871,,,,,,,,,,,,,, -164600,7.7778835,1.8651477,,,,,,,,,,,,,, -164700,6.709955,1.8610876,,,,,,,,,,,,,, -164800,7.107539,1.8869572,,,,,,,,,,,,,, -164900,7.1012173,1.8672285,,,,,,,,,,,,,, -165000,6.8086066,1.878661,,,,,,,,,,,,,, -165100,7.3156824,1.9078461,,,,,,,,,,,,,, -165200,6.904261,1.8317049,,,,,,,,,,,,,, -165300,7.5481987,1.8949337,,,,,,,,,,,,,, -165400,7.8479204,1.907632,,,,,,,,,,,,,, -165500,7.772735,1.8659496,,,,,,,,,,,,,, -165600,6.809788,1.8568547,,,,,,,,,,,,,, -165636,,,0.9054129123687744,0.462770015001297,0.7623400092124939,1.0329078435897827,50000.0,0.6397000551223755,1.6543211936950684,10000.0,56162.06087732315,58218.07364296913,56162.06087732315,2044.88449382782,5.420394659042358,0.0 -165700,7.1761594,1.8927873,,,,,,,,,,,,,, -165800,7.172045,1.8958539,,,,,,,,,,,,,, -165900,6.4016137,1.8191053,,,,,,,,,,,,,, -166000,7.984753,1.8309114,,,,,,,,,,,,,, -166100,7.2767944,1.811308,,,,,,,,,,,,,, -166200,7.72024,1.8760546,,,,,,,,,,,,,, -166300,7.264915,1.8684129,,,,,,,,,,,,,, -166400,8.369588,1.8956043,,,,,,,,,,,,,, -166500,6.8629627,1.8286383,,,,,,,,,,,,,, -166600,7.8419566,1.8635671,,,,,,,,,,,,,, -166700,7.9419603,1.8688706,,,,,,,,,,,,,, -166800,7.9309816,1.8279223,,,,,,,,,,,,,, -166900,7.017561,1.8120892,,,,,,,,,,,,,, -167000,7.20773,1.8463664,,,,,,,,,,,,,, -167100,7.470432,1.8819015,,,,,,,,,,,,,, -167143,,,0.9016661047935486,0.473734438419342,0.7623800039291382,1.0327937602996826,50000.0,0.6407000422477722,1.6520830392837524,10000.0,56672.09117293358,58745.37574315071,56672.09117293358,2062.039719581604,5.4839723110198975,0.0 -167200,7.2495804,1.8568909,,,,,,,,,,,,,, -167300,7.3447113,1.8486588,,,,,,,,,,,,,, -167400,7.686576,1.8432026,,,,,,,,,,,,,, -167500,7.403619,1.8800738,,,,,,,,,,,,,, -167600,7.3420725,1.8153214,,,,,,,,,,,,,, -167700,8.21832,1.8838103,,,,,,,,,,,,,, -167800,7.517275,1.8378923,,,,,,,,,,,,,, -167900,7.4337378,1.8263174,,,,,,,,,,,,,, -168000,7.1478224,1.8268284,,,,,,,,,,,,,, -168100,8.119915,1.8275346,,,,,,,,,,,,,, -168200,7.589675,1.835708,,,,,,,,,,,,,, -168300,6.924834,1.7753384,,,,,,,,,,,,,, -168400,7.192931,1.8519402,,,,,,,,,,,,,, -168500,7.294121,1.8183417,,,,,,,,,,,,,, -168600,8.121896,1.8852197,,,,,,,,,,,,,, -168650,,,0.9009087681770324,0.4697478115558624,0.766319990158081,1.021145582199097,50000.0,0.6403000354766846,1.6334751844406128,10000.0,57182.03576087952,59272.69396233559,57182.03576087952,2079.302550792694,5.541658639907837,0.0 -168700,8.3133,1.8678834,,,,,,,,,,,,,, -168800,8.327079,1.9191967,,,,,,,,,,,,,, -168900,7.238955,1.835419,,,,,,,,,,,,,, -169000,7.6316714,1.8567681,,,,,,,,,,,,,, -169100,7.6269794,1.7660673,,,,,,,,,,,,,, -169200,7.673969,1.7971519,,,,,,,,,,,,,, -169300,7.4499803,1.8551862,,,,,,,,,,,,,, -169400,7.760926,1.9304023,,,,,,,,,,,,,, -169500,7.820918,1.8388492,,,,,,,,,,,,,, -169600,7.506037,1.8640292,,,,,,,,,,,,,, -169700,7.615835,1.8640045,,,,,,,,,,,,,, -169800,8.216329,1.7838595,,,,,,,,,,,,,, -169900,8.06643,1.9027874,,,,,,,,,,,,,, -170000,7.5756373,1.8803372,,,,,,,,,,,,,, -170100,7.289004,1.8102208,,,,,,,,,,,,,, -170157,,,0.9041772484779358,0.4596326351165771,0.7662799954414368,1.0119706392288208,50000.0,0.6470000147819519,1.6323281526565552,10000.0,57692.05851483345,59799.91110467911,57692.05851483345,2096.389819145202,5.597667694091797,0.0 -170200,7.6515136,1.8329335,,,,,,,,,,,,,, -170300,8.127008,1.865504,,,,,,,,,,,,,, -170400,7.675755,1.8029811,,,,,,,,,,,,,, -170500,7.9525785,1.8407567,,,,,,,,,,,,,, -170600,7.8918133,1.8938845,,,,,,,,,,,,,, -170700,7.5253997,1.8127965,,,,,,,,,,,,,, -170800,7.816531,1.8071544,,,,,,,,,,,,,, -170900,8.495542,1.808557,,,,,,,,,,,,,, -171000,6.8974724,1.7861407,,,,,,,,,,,,,, -171100,8.205576,1.8499092,,,,,,,,,,,,,, -171200,7.5070887,1.7771331,,,,,,,,,,,,,, -171300,7.614012,1.8472366,,,,,,,,,,,,,, -171400,7.827949,1.7375591,,,,,,,,,,,,,, -171500,7.5929756,1.812459,,,,,,,,,,,,,, -171600,8.146112,1.7965065,,,,,,,,,,,,,, -171664,,,0.9072065949440002,0.4597722291946411,0.7687399983406067,1.0186874866485596,50000.0,0.6431000232696533,1.6356762647628784,10000.0,58202.10117435455,60327.013768196106,58202.10117435455,2113.3370258808136,5.657930612564087,0.0 -171700,7.4247265,1.8466046,,,,,,,,,,,,,, -171800,7.1113014,1.75995,,,,,,,,,,,,,, -171900,7.632604,1.8368232,,,,,,,,,,,,,, -172000,6.80044,1.8060327,,,,,,,,,,,,,, -172100,8.080137,1.8269541,,,,,,,,,,,,,, -172200,8.439559,1.8634291,,,,,,,,,,,,,, -172300,7.5067663,1.7996125,,,,,,,,,,,,,, -172400,8.354585,1.7513752,,,,,,,,,,,,,, -172500,8.0512,1.8438536,,,,,,,,,,,,,, -172600,7.48041,1.7824206,,,,,,,,,,,,,, -172700,6.9944787,1.7950764,,,,,,,,,,,,,, -172800,8.283865,1.7833347,,,,,,,,,,,,,, -172900,7.5334573,1.8574405,,,,,,,,,,,,,, -173000,6.917834,1.7512515,,,,,,,,,,,,,, -173100,7.478061,1.8311397,,,,,,,,,,,,,, -173171,,,0.9073860049247742,0.4471368789672851,0.7688800096511841,1.0100610256195068,50000.0,0.6465000510215759,1.628620743751526,10000.0,58712.05089068413,60854.05264925957,58712.05089068413,2130.31436419487,5.717501878738403,0.0 -173200,8.101678,1.8130031,,,,,,,,,,,,,, -173300,7.2556562,1.8809682,,,,,,,,,,,,,, -173400,7.4366846,1.7258894,,,,,,,,,,,,,, -173500,7.6908545,1.8190682,,,,,,,,,,,,,, -173600,8.852355,1.8435743,,,,,,,,,,,,,, -173700,7.550541,1.8269796,,,,,,,,,,,,,, -173800,7.4427176,1.8708203,,,,,,,,,,,,,, -173900,7.8261456,1.8282282,,,,,,,,,,,,,, -174000,7.2568088,1.7817034,,,,,,,,,,,,,, -174100,7.885053,1.8150296,,,,,,,,,,,,,, -174200,7.9476867,1.8437974,,,,,,,,,,,,,, -174300,7.13291,1.8317893,,,,,,,,,,,,,, -174400,7.269232,1.8068602,,,,,,,,,,,,,, -174500,8.005044,1.8348453,,,,,,,,,,,,,, -174600,7.8334002,1.8460456,,,,,,,,,,,,,, -174678,,,0.917191445827484,0.4137333035469055,0.7696399688720703,1.0041489601135254,50000.0,0.6478000283241272,1.612573504447937,10000.0,59221.97033691406,61381.27100849152,59221.97033691406,2147.503813028336,5.7750890254974365,0.0 -174700,7.7916737,1.8581697,,,,,,,,,,,,,, -174800,7.5824566,1.781027,,,,,,,,,,,,,, -174900,8.112665,1.830004,,,,,,,,,,,,,, -175000,8.5348425,1.8424828,,,,,,,,,,,,,, -175100,7.324861,1.8028138,,,,,,,,,,,,,, -175200,8.666724,1.82975,,,,,,,,,,,,,, -175300,7.8286176,1.8457842,,,,,,,,,,,,,, -175400,7.8456044,1.7755715,,,,,,,,,,,,,, -175500,8.0929365,1.840983,,,,,,,,,,,,,, -175600,8.418507,1.779478,,,,,,,,,,,,,, -175700,7.8778806,1.819807,,,,,,,,,,,,,, -175800,8.052079,1.8083642,,,,,,,,,,,,,, -175900,8.402333,1.9237418,,,,,,,,,,,,,, -176000,7.895627,1.8361146,,,,,,,,,,,,,, -176100,7.3444347,1.7869799,,,,,,,,,,,,,, -176184,,,0.9174505472183228,0.415025532245636,0.7699399590492249,1.001592040061951,50000.0,0.6492000222206116,1.610582947731018,10000.0,59731.89277672768,61908.24905347824,59731.89277672768,2164.4502770900726,5.829644203186035,0.0 -176200,7.7825804,1.8174915,,,,,,,,,,,,,, -176300,7.6014204,1.829687,,,,,,,,,,,,,, -176400,7.7531404,1.8203728,,,,,,,,,,,,,, -176500,8.78813,1.8118114,,,,,,,,,,,,,, -176600,7.793695,1.7386514,,,,,,,,,,,,,, -176700,7.6722426,1.8434701,,,,,,,,,,,,,, -176800,8.376173,1.8209809,,,,,,,,,,,,,, -176900,8.341741,1.7891704,,,,,,,,,,,,,, -177000,8.478952,1.811474,,,,,,,,,,,,,, -177100,8.575404,1.8752966,,,,,,,,,,,,,, -177200,9.676201,1.7738132,,,,,,,,,,,,,, -177300,7.5537786,1.7851734,,,,,,,,,,,,,, -177400,7.6398683,1.7607099,,,,,,,,,,,,,, -177500,8.119231,1.8625273,,,,,,,,,,,,,, -177600,8.05742,1.7767512,,,,,,,,,,,,,, -177690,,,0.9185068607330322,0.4077030122280121,0.7705599665641785,0.9990577697753906,50000.0,0.6515000462532043,1.609252691268921,10000.0,60241.88081359863,62435.42911171913,60241.88081359863,2181.530555009842,5.88846755027771,0.0 -177700,8.778807,1.8202233,,,,,,,,,,,,,, -177800,8.2920685,1.830613,,,,,,,,,,,,,, -177900,8.071973,1.8551636,,,,,,,,,,,,,, -178000,7.378826,1.7870616,,,,,,,,,,,,,, -178100,7.8513865,1.7602571,,,,,,,,,,,,,, -178200,7.80471,1.7863355,,,,,,,,,,,,,, -178300,7.990285,1.758403,,,,,,,,,,,,,, -178400,6.9059095,1.740282,,,,,,,,,,,,,, -178500,7.8524623,1.7653401,,,,,,,,,,,,,, -178600,8.61601,1.8699492,,,,,,,,,,,,,, -178700,7.3050976,1.7650533,,,,,,,,,,,,,, -178800,8.027289,1.7719151,,,,,,,,,,,,,, -178900,8.857302,1.7904191,,,,,,,,,,,,,, -179000,7.315598,1.766966,,,,,,,,,,,,,, -179100,7.2007804,1.8125384,,,,,,,,,,,,,, -179196,,,0.9178292155265808,0.4130044281482696,0.7719599604606628,0.9963983297348022,50000.0,0.6491000056266785,1.6055240631103516,10000.0,60751.85917687416,62962.54237771034,60751.85917687416,2198.5527551174164,5.948404788970947,0.0 -179200,8.257444,1.7480112,,,,,,,,,,,,,, -179300,8.120436,1.7776461,,,,,,,,,,,,,, -179400,7.788121,1.7324378,,,,,,,,,,,,,, -179500,8.455129,1.7759178,,,,,,,,,,,,,, -179600,7.6867156,1.7645097,,,,,,,,,,,,,, -179700,7.7413588,1.7664721,,,,,,,,,,,,,, -179800,7.437921,1.7665777,,,,,,,,,,,,,, -179900,7.783893,1.7359067,,,,,,,,,,,,,, -180000,7.513466,1.7134731,,,,,,,,,,,,,, -180100,8.485822,1.8093674,,,,,,,,,,,,,, -180200,8.88779,1.7337962,,,,,,,,,,,,,, -180300,8.285094,1.7570083,,,,,,,,,,,,,, -180400,8.082336,1.7969278,,,,,,,,,,,,,, -180500,7.793805,1.7592578,,,,,,,,,,,,,, -180600,7.150359,1.7280375,,,,,,,,,,,,,, -180700,8.383072,1.8417013,,,,,,,,,,,,,, -180703,,,0.9183274507522584,0.4075784087181091,0.7716599702835083,0.9949711561203004,50000.0,0.6509000062942505,1.6056536436080933,10000.0,61261.88948059082,63490.04742026329,61261.88948059082,2215.9111063480377,6.012499570846558,0.0 -180800,8.681708,1.8037472,,,,,,,,,,,,,, -180900,7.5375743,1.7447848,,,,,,,,,,,,,, -181000,6.9070096,1.717975,,,,,,,,,,,,,, -181100,7.541904,1.7933351,,,,,,,,,,,,,, -181200,7.8913865,1.7492995,,,,,,,,,,,,,, -181300,7.3599267,1.7218057,,,,,,,,,,,,,, -181400,8.39444,1.7538954,,,,,,,,,,,,,, -181500,7.8906474,1.7111512,,,,,,,,,,,,,, -181600,8.594673,1.8136102,,,,,,,,,,,,,, -181700,7.420718,1.7932487,,,,,,,,,,,,,, -181800,8.051808,1.8391594,,,,,,,,,,,,,, -181900,7.8708158,1.8074979,,,,,,,,,,,,,, -182000,8.233057,1.7866788,,,,,,,,,,,,,, -182100,8.36535,1.7975368,,,,,,,,,,,,,, -182200,8.534714,1.825519,,,,,,,,,,,,,, -182210,,,0.919702649116516,0.4040387272834778,0.772159993648529,0.9936366081237792,50000.0,0.6494000554084778,1.6046305894851685,10000.0,61772.02049660683,64017.193064928055,61772.02049660683,2232.815001964569,6.070794105529785,0.0 -182300,7.480448,1.7724894,,,,,,,,,,,,,, -182400,7.9195247,1.8243495,,,,,,,,,,,,,, -182500,7.731333,1.7773509,,,,,,,,,,,,,, -182600,7.6433597,1.7948537,,,,,,,,,,,,,, -182700,7.625134,1.7997644,,,,,,,,,,,,,, -182800,8.564246,1.767006,,,,,,,,,,,,,, -182900,7.5774794,1.8074948,,,,,,,,,,,,,, -183000,8.065646,1.8212223,,,,,,,,,,,,,, -183100,7.3877707,1.7865044,,,,,,,,,,,,,, -183200,7.284476,1.7077891,,,,,,,,,,,,,, -183300,8.151302,1.8459454,,,,,,,,,,,,,, -183400,8.102354,1.7896106,,,,,,,,,,,,,, -183500,8.019737,1.8002726,,,,,,,,,,,,,, -183600,8.317075,1.8920836,,,,,,,,,,,,,, -183700,7.503269,1.713695,,,,,,,,,,,,,, -183717,,,0.9208585619926452,0.3966313898563385,0.7724399566650391,0.9899981021881104,50000.0,0.6511000394821167,1.6006096601486206,10000.0,62282.11263537407,64544.36849927902,62282.11263537407,2249.791279554367,6.127057075500488,0.0 -183800,8.2176075,1.755097,,,,,,,,,,,,,, -183900,7.7194076,1.8332044,,,,,,,,,,,,,, -184000,7.594435,1.8306395,,,,,,,,,,,,,, -184100,7.7985425,1.8137463,,,,,,,,,,,,,, -184200,8.577703,1.8193433,,,,,,,,,,,,,, -184300,7.5315204,1.7821221,,,,,,,,,,,,,, -184400,7.4951468,1.7679545,,,,,,,,,,,,,, -184500,7.863758,1.8295627,,,,,,,,,,,,,, -184600,7.4417596,1.7203369,,,,,,,,,,,,,, -184700,7.894877,1.787321,,,,,,,,,,,,,, -184800,7.1315303,1.7462938,,,,,,,,,,,,,, -184900,7.615345,1.7752093,,,,,,,,,,,,,, -185000,8.115127,1.780961,,,,,,,,,,,,,, -185100,8.081979,1.7970011,,,,,,,,,,,,,, -185200,8.113615,1.7922921,,,,,,,,,,,,,, -185224,,,0.921277105808258,0.3994551599025726,0.7727199792861938,0.9911830425262452,50000.0,0.6509000062942505,1.6002691984176636,10000.0,62792.30425167084,65071.81726360321,62792.30425167084,2266.940537214279,6.183432102203369,0.0 -185300,7.2839637,1.7864475,,,,,,,,,,,,,, -185400,7.767775,1.7679495,,,,,,,,,,,,,, -185500,8.151037,1.757922,,,,,,,,,,,,,, -185600,8.17657,1.8417695,,,,,,,,,,,,,, -185700,7.6393085,1.7608352,,,,,,,,,,,,,, -185800,7.541031,1.7681903,,,,,,,,,,,,,, -185862,,,,,,,,,,,63008.31415557861,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 02e83b9da..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.747056245803833,0.0,33.4391393661499,1,0,33.4391393661499,0.0013000001199543,6.9117279052734375,10000,51.186286211013794,0.0006576849264092,6.912638187408447,0.0006799999973736,6.912051200866699,50000 -35.347055435180664,0.0192117691040039,543.6612985134125,1500,0,543.6612985134125,0.0571000017225742,5.528391361236572,10000,579.0811305046082,0.0859375,5.226979732513428,0.0790399983525276,5.291569709777832,50000 -53.0393807888031,0.0525212287902832,1053.6502876281738,2998,0,1053.6502876281738,0.1293999999761581,4.724494457244873,10000,1106.8475799560547,0.1965680718421936,4.198776721954346,0.1792199909687042,4.319735050201416,50000 -71.19184017181396,0.079308271408081,1563.8795804977417,4496,0,1563.8795804977417,0.2076000124216079,4.096458911895752,10000,1635.308201789856,0.3117426633834839,3.379237413406372,0.2842199802398681,3.520533323287964,50000 -88.74252462387085,0.1061484813690185,2074.014079093933,5994,0,2074.014079093933,0.2656000256538391,3.724989652633667,10000,2163.0716738700867,0.3806600570678711,2.963285207748413,0.3558799922466278,3.0989913940429688,50000 -106.2141387462616,0.1336226463317871,2584.21363902092,7493,0,2584.21363902092,0.3315000236034393,3.3746931552886963,10000,2690.823953151703,0.4618144035339355,2.6370625495910645,0.4315599799156189,2.7958452701568604,50000 -123.49534726142883,0.1627497673034668,3094.3763308525085,8993,0,3094.3763308525085,0.3613000214099884,3.1709978580474854,10000,3218.3501625061035,0.5307716727256775,2.255420207977295,0.4728399813175201,2.547547578811645,50000 -141.10048961639404,0.1944019794464111,3604.335671663284,10493,0,3604.335671663284,0.3886000216007232,3.049933433532715,10000,3745.998475790024,0.5577766299247742,2.1650500297546387,0.5104199647903442,2.3915514945983887,50000 -158.77054691314697,0.2259001731872558,4114.402346134186,11993,0,4114.402346134186,0.4271000325679779,2.8042361736297607,10000,4273.819544792175,0.6017418503761292,1.94335687160492,0.5548200011253357,2.156901597976685,50000 -176.094083070755,0.2551212310791015,4624.474926710129,13494,0,4624.474926710129,0.4510000348091125,2.655189275741577,10000,4801.298516511917,0.6224888563156128,1.774627685546875,0.5730400085449219,2.0085930824279785,50000 -193.91254234313965,0.2859010696411133,5134.568528413773,14995,0,5134.568528413773,0.456900030374527,2.6749420166015625,10000,5329.296462774277,0.6300023794174194,1.7994438409805298,0.5817999839782715,2.027262687683105,50000 -211.41987991333008,0.31459641456604,5644.774896144867,16498,0,5644.774896144867,0.4735000133514404,2.5429165363311768,10000,5857.091723680496,0.6436144709587097,1.6748861074447632,0.5989199876785278,1.89004135131836,50000 -228.8724648952484,0.3461291790008545,6155.2353079319,18001,0,6155.2353079319,0.4797000288963318,2.4914474487304688,10000,6385.087836503983,0.7017298936843872,1.4178893566131592,0.6070199608802795,1.8295985460281368,50000 -246.26633167266849,0.3774294853210449,6665.291936635971,19503,0,6665.291936635971,0.496800035238266,2.404106855392456,10000,6912.620675325394,0.6970065236091614,1.4243632555007937,0.6238399744033813,1.7519325017929075,50000 -264.5082576274872,0.4087963104248047,7175.285215139389,21005,0,7175.285215139389,0.49590003490448,2.432694673538208,10000,7440.939846038818,0.6917251348495483,1.4819576740264893,0.6297399997711182,1.761543035507202,50000 -281.779283285141,0.4405534267425537,7685.427667379379,22507,0,7685.427667379379,0.5031999945640564,2.40031361579895,10000,7968.438979625702,0.6911072731018066,1.4709314107894895,0.629539966583252,1.7492947578430176,50000 -299.5811333656311,0.4718921184539795,8195.398663759232,24009,0,8195.398663759232,0.5029000043869019,2.397300720214844,10000,8496.294484138489,0.6873604655265808,1.4837753772735596,0.6243799924850464,1.7636795043945312,50000 -317.2381706237793,0.5033385753631592,8705.627409219742,25512,0,8705.627409219742,0.5123000144958496,2.378972291946411,10000,9024.263773679731,0.6941366195678711,1.4621155261993408,0.6352399587631226,1.7313721179962158,50000 -334.6807444095612,0.5375471115112305,9215.783554315569,27015,0,9215.783554315569,0.5028000473976135,2.38372540473938,10000,9551.9481112957,0.7335180044174194,1.286011815071106,0.6350799798965454,1.7067891359329224,50000 -352.87031412124634,0.5690395832061768,9725.861224412918,28518,0,9725.861224412918,0.5126000046730042,2.341330051422119,10000,10080.300015211104,0.7210817933082581,1.3177582025527954,0.6410999894142151,1.6783519983291626,50000 -370.6090452671051,0.5975942611694336,10235.83157801628,30021,0,10235.83157801628,0.5089000463485718,2.397118330001831,10000,10608.090680122375,0.7088049650192261,1.3988709449768066,0.6373199820518494,1.7182626724243164,50000 -388.2906057834625,0.6299974918365479,10745.978585481644,31525,0,10745.978585481644,0.511900007724762,2.325451374053955,10000,11136.00351715088,0.7114556431770325,1.361906886100769,0.646340012550354,1.6596877574920654,50000 -406.0712497234345,0.669827938079834,11256.05967617035,33027,0,11256.05967617035,0.5130000114440918,2.3377459049224854,10000,11663.95685338974,0.7103396058082581,1.374114990234375,0.6455199718475342,1.6690250635147097,50000 -423.49549770355225,0.7041542530059814,11766.244701385498,34532,0,11766.244701385498,0.5246000289916992,2.271498680114746,10000,12191.65503835678,0.7158003449440002,1.3178354501724243,0.6526600122451782,1.6064244508743286,50000 -441.2348201274872,0.7414686679840088,12276.337631940842,36036,0,12276.337631940842,0.5085000395774841,2.3779425621032715,10000,12719.575999498367,0.7166773080825806,1.3349188566207886,0.6400600075721741,1.6809086799621582,50000 -458.72729420661926,0.7722766399383545,12786.48780798912,37540,0,12786.48780798912,0.5275000333786011,2.2785587310791016,10000,13247.302038192747,0.7401347160339355,1.2507890462875366,0.6541999578475952,1.622602939605713,50000 -476.3352725505829,0.80930495262146,13296.705997228622,39044,0,13296.705997228622,0.5231000185012817,2.2626054286956787,10000,13775.218259096146,0.7385004758834839,1.2336959838867188,0.6571800112724304,1.5852231979370115,50000 -493.83841919898987,0.8440403938293457,13806.621383428574,40547,0,13806.621383428574,0.5362000465393066,2.2442195415496826,10000,14302.725590705872,0.7364476919174194,1.2696105241775513,0.6611799597740173,1.5964609384536743,50000 -512.0555701255798,0.8885290622711182,14316.647776842115,42051,0,14316.647776842115,0.5211000442504883,2.314053058624268,10000,14831.06675171852,0.7135283946990967,1.3360618352890017,0.6412599682807922,1.6634626388549805,50000 -529.3009285926819,0.9220795631408693,14826.77017712593,43555,0,14826.77017712593,0.5362000465393066,2.240588426589966,10000,15358.520558834076,0.7233737111091614,1.2983133792877195,0.6625799536705017,1.5901869535446167,50000 -547.1348049640656,0.9746749401092528,15336.741206169128,45058,0,15336.741206169128,0.5277000069618225,2.281455516815185,10000,15886.432034015656,0.7239118218421936,1.328299045562744,0.6556199789047241,1.6299936771392822,50000 -564.7328906059265,1.0138413906097412,15846.720610380173,46562,0,15846.720610380173,0.5383000373840332,2.2207345962524414,10000,16414.10246682167,0.7540656924247742,1.198965311050415,0.6650199890136719,1.589596390724182,50000 -582.2785995006561,1.0533804893493652,16356.965524196625,48066,0,16356.965524196625,0.5409000515937805,2.236502170562744,10000,16941.98569583893,0.7438217401504517,1.2626017332077026,0.6648600101470947,1.6124789714813232,50000 -599.9203262329102,1.0888376235961914,16867.10874891281,49571,0,16867.10874891281,0.5455999970436096,2.187873601913452,10000,17469.85976576805,0.7395169138908386,1.25126314163208,0.6665399670600891,1.5785037279129028,50000 -617.5973706245422,1.126636266708374,17377.16103053093,51075,0,17377.16103053093,0.5387000441551208,2.2007992267608643,10000,17997.680426597595,0.7398756146430969,1.2330973148345947,0.666700005531311,1.5564137697219849,50000 -635.0826358795166,1.1614611148834229,17887.110765457153,52578,0,17887.110765457153,0.5347000360488892,2.245861291885376,10000,18525.204036712646,0.7335976958274841,1.282623529434204,0.6645999550819397,1.590146541595459,50000 -652.7812848091125,1.1971261501312256,18397.201429128647,54082,0,18397.201429128647,0.5454000234603882,2.1653194427490234,10000,19053.081587553024,0.7420878410339355,1.2159703969955444,0.6678999662399292,1.5432895421981812,50000 -670.0841941833496,1.236152410507202,18907.42284011841,55586,0,18907.42284011841,0.5271000266075134,2.352210998535156,10000,19580.69886445999,0.7388392686843872,1.2988405227661133,0.6522200107574463,1.6849249601364136,50000 -687.3933174610138,1.2751078605651855,19417.531020641327,57090,0,19417.531020641327,0.5313000082969666,2.2467947006225586,10000,20108.20799922943,0.7461535334587097,1.222550630569458,0.6632999777793884,1.5833584070205688,50000 -704.799777507782,1.3132286071777344,19927.526131629944,58594,0,19927.526131629944,0.5463000535964966,2.2097489833831787,10000,20635.700806617737,0.750418484210968,1.2128936052322388,0.6708599925041199,1.5625097751617432,50000 -722.1753432750702,1.3547327518463137,20437.654178380966,60098,0,20437.654178380966,0.5517000555992126,2.1742067337036133,10000,21163.29753088951,0.7451171875,1.1998909711837769,0.6719799637794495,1.529384970664978,50000 -739.661140203476,1.3968374729156494,20947.57455515861,61601,0,20947.57455515861,0.5524000525474548,2.158252954483032,10000,21690.79941987992,0.7535673975944519,1.186854362487793,0.6801599860191345,1.5134427547454834,50000 -757.0620410442352,1.448808193206787,21457.5686917305,63105,0,21457.5686917305,0.541100025177002,2.2219254970550537,10000,22218.297739744183,0.7455755472183228,1.2469745874404907,0.6726199984550476,1.565129637718201,50000 -774.6855986118317,1.487745761871338,21967.481950998303,64609,0,21967.481950998303,0.5548000335693359,2.164292335510254,10000,22745.924226760864,0.7794363498687744,1.0946069955825806,0.6756399869918823,1.5386929512023926,50000 -792.5888900756836,1.5288665294647217,22477.70741415024,66114,0,22477.70741415024,0.5603000521659851,2.147666931152344,10000,23274.14600086212,0.7683154940605164,1.1131125688552856,0.6814999580383301,1.4990878105163574,50000 -810.3098471164703,1.5757479667663574,22987.609982013702,67618,0,22987.609982013702,0.5412000417709351,2.1961100101470947,10000,23801.867978334427,0.7588488459587097,1.1663302183151243,0.6769199967384338,1.5283010005950928,50000 -827.7932825088501,1.6187632083892822,23497.5886631012,69122,0,23497.5886631012,0.5520000457763672,2.1934776306152344,10000,24329.425166130062,0.7570351958274841,1.1902135610580444,0.6778199672698975,1.5349388122558594,50000 -845.2944579124451,1.6605701446533203,24007.770708084103,70627,0,24007.770708084103,0.5467000007629395,2.213312864303589,10000,24857.20196413994,0.7458944320678711,1.2468349933624268,0.668999969959259,1.5783573389053345,50000 -862.8972687721252,1.7031033039093018,24517.84316849709,72131,0,24517.84316849709,0.5525000095367432,2.200495719909668,10000,25384.97437238693,0.7466716766357422,1.2096285820007324,0.6700999736785889,1.548133373260498,50000 -880.508659362793,1.747278928756714,25028.014729499817,73636,0,25028.014729499817,0.556600034236908,2.1635515689849854,10000,25912.85485696793,0.7932278513908386,1.0378259420394895,0.6830999851226807,1.5034915208816528,50000 -898.0165367126465,1.791778802871704,25538.20123958588,75141,0,25538.20123958588,0.5615000128746033,2.1242897510528564,10000,26440.64587116241,0.7795758843421936,1.0734102725982666,0.6831799745559692,1.4843811988830566,50000 -915.6695799827576,1.832149028778076,26048.3084192276,76646,0,26048.3084192276,0.5569000244140625,2.119495391845703,10000,26968.49898505211,0.7704480290412903,1.1027264595031738,0.6850999593734741,1.478226661682129,50000 -932.8865323066713,1.8738563060760496,26558.43163084984,78151,0,26558.43163084984,0.5664000511169434,2.1145002841949463,10000,27495.934143304825,0.7723413705825806,1.126224398612976,0.6868199706077576,1.48687481880188,50000 -950.5808329582214,1.9172337055206297,27068.563962221146,79656,0,27068.563962221146,0.5633000135421753,2.1523826122283936,10000,28023.855808973312,0.7729591727256775,1.1393262147903442,0.6887999773025513,1.5029897689819336,50000 -968.9772000312804,1.9648942947387693,27578.76552867889,81161,0,27578.76552867889,0.5533000230789185,2.143951177597046,10000,28552.553351163864,0.7640505433082581,1.1171501874923706,0.682379961013794,1.4879083633422852,50000 -986.45987200737,2.0094380378723145,28088.932448387142,82665,0,28088.932448387142,0.5690000057220459,2.0855112075805664,10000,29080.30035591125,0.8116629123687744,0.9491603970527648,0.6913999915122986,1.452277898788452,50000 -1003.9821727275848,2.056580305099488,28599.081391334534,84170,0,28599.081391334534,0.563800036907196,2.106266975402832,10000,29608.071023464203,0.786152720451355,1.03113853931427,0.6920199990272522,1.438202738761902,50000 -1021.2940018177032,2.099645137786865,29109.275886297222,85675,0,29109.275886297222,0.5645000338554382,2.099759101867676,10000,30135.671944856644,0.7844387888908386,1.050876259803772,0.6937800049781799,1.4405685663223269,50000 -1038.789042711258,2.1402103900909424,29619.42868781089,87180,0,29619.42868781089,0.5717000365257263,2.0682175159454346,10000,30663.413135290142,0.7828842401504517,1.0602805614471436,0.6989799737930298,1.430122971534729,50000 -1056.377511024475,2.181400775909424,30129.41860818863,88684,0,30129.41860818863,0.565500020980835,2.0685789585113525,10000,31191.086156368256,0.7784797549247742,1.0503365993499756,0.6977999806404114,1.412257194519043,50000 -1073.8877835273745,2.227370262145996,30639.66088938713,90189,0,30639.66088938713,0.567300021648407,2.074946641921997,10000,31718.93631315232,0.7796755433082581,1.052192211151123,0.6933599710464478,1.4358712434768677,50000 -1091.269455909729,2.272991180419922,31149.84545230865,91694,0,31149.84545230865,0.5756000280380249,2.055137872695923,10000,32246.600895643234,0.8287826776504517,0.869025707244873,0.7048199772834778,1.3970078229904177,50000 -1108.6878995895386,2.321443557739258,31659.772423505783,93198,0,31659.772423505783,0.5735000371932983,2.089207172393799,10000,32774.04626703262,0.7990872263908386,0.9855494499206544,0.6990999579429626,1.429892659187317,50000 -1126.209413766861,2.367856502532959,32169.7107861042,94701,0,32169.7107861042,0.5597000122070312,2.086948871612549,10000,33301.60674357414,0.7864317297935486,1.0102328062057495,0.6943399906158447,1.414360761642456,50000 -1143.754875421524,2.413187265396118,32679.70833396912,96205,0,32679.70833396912,0.5768000483512878,2.0695486068725586,10000,33829.24775338173,0.8004623651504517,1.0195797681808472,0.7038599848747253,1.428849697113037,50000 -1161.3141412734983,2.4573397636413574,33189.67002224922,97709,0,33189.67002224922,0.5755000114440918,2.061427354812622,10000,34356.86605596542,0.7979910373687744,1.0076885223388672,0.7076399922370911,1.4024112224578855,50000 -1178.4537107944489,2.500805854797364,33699.635172605515,99213,0,33699.635172605515,0.5746999979019165,2.0553035736083984,10000,34884.06643486023,0.7958186864852905,1.0074490308761597,0.7049799561500549,1.3988691568374634,50000 -1196.1693103313446,2.546966552734375,34209.61276984215,100717,0,34209.61276984215,0.579800009727478,2.0282983779907227,10000,35411.85849046707,0.8254942297935486,0.8944934010505676,0.7102400064468384,1.3767321109771729,50000 -1213.4867506027222,2.591370344161988,34719.75500845909,102222,0,34719.75500845909,0.5827000141143799,2.037583112716675,10000,35939.416763305664,0.8199936151504517,0.9147710800170898,0.7091999650001526,1.3878830671310425,50000 -1230.9158039093018,2.652010679244995,35229.8086771965,103726,0,35229.8086771965,0.5809000134468079,2.029460430145264,10000,36467.01249408722,0.8092314600944519,0.9412323236465454,0.7082799673080444,1.3787775039672852,50000 -1248.351620197296,2.697594404220581,35739.95236110687,105231,0,35739.95236110687,0.5879000425338745,2.013390302658081,10000,36994.68914103508,0.810546875,0.9459798336029052,0.7136399745941162,1.372373104095459,50000 -1265.6972844600675,2.745483875274658,36250.164659023285,106736,0,36250.164659023285,0.5924000144004822,1.9973350763320925,10000,37522.34675478935,0.8152502775192261,0.9214245080947876,0.7125799655914307,1.3635106086730957,50000 -1283.1132173538208,2.795067071914673,36760.11304235458,108240,0,36760.11304235458,0.5960000157356262,2.006366014480591,10000,38049.81420207024,0.8124202489852905,0.9675102829933168,0.7174199819564819,1.3825063705444336,50000 -1300.696792602539,2.8408803939819336,37270.30475926399,109745,0,37270.30475926399,0.5906000137329102,1.9520862102508545,10000,38577.689265728,0.8219866156578064,0.869674026966095,0.7159799933433533,1.3264567852020264,50000 -1318.3463683128357,2.886791706085205,37780.53218817711,111249,0,37780.53218817711,0.589900016784668,2.0042407512664795,10000,39105.6652610302,0.8328284025192261,0.8718942999839783,0.7133600115776062,1.374558448791504,50000 -1335.973201751709,2.9371728897094727,38290.7551791668,112754,0,38290.7551791668,0.5905000567436218,2.003549337387085,10000,39633.61887669563,0.829121470451355,0.8686312437057495,0.7148799896240234,1.34848952293396,50000 -1353.783153772354,2.9890265464782715,38800.8717019558,114259,0,38800.8717019558,0.597100019454956,1.997839331626892,10000,40161.65034651756,0.8265106678009033,0.8988329768180847,0.7185800075531006,1.36379075050354,50000 -1371.0964777469635,3.0382204055786133,39310.83686709404,115764,0,39310.83686709404,0.5939000248908997,2.0048916339874268,10000,40689.02972865105,0.8270089030265808,0.9045054316520692,0.722819983959198,1.3632978200912476,50000 -1388.6389904022217,3.0894229412078857,39820.8876388073,117269,0,39820.8876388073,0.6014000177383423,1.9583295583724976,10000,41216.72868323326,0.8316724896430969,0.8688330054283142,0.7247799634933472,1.3182138204574585,50000 -1406.6010098457336,3.143878936767578,40331.05041480064,118774,0,40331.05041480064,0.6079000234603882,1.9448498487472528,10000,41744.96152448654,0.8371332883834839,0.8562281131744385,0.725659966468811,1.3277863264083862,50000 -1424.0867023468018,3.195374011993408,40841.25688076019,120279,0,40841.25688076019,0.5987000465393066,1.9820739030838013,10000,42272.75825166702,0.8495694994926453,0.8090940117835999,0.720579981803894,1.3468101024627686,50000 -1441.527908563614,3.246919870376587,41351.40269494057,121783,0,41351.40269494057,0.6050000190734863,1.9310460090637207,10000,42800.45030093193,0.8489118218421936,0.7738045454025269,0.7283599972724915,1.285322904586792,50000 -1459.0470685958862,3.294360637664795,41861.53848028183,123288,0,41861.53848028183,0.6044000387191772,1.924842476844788,10000,43328.20562505722,0.8492506146430969,0.7951707243919373,0.7295199632644653,1.2949326038360596,50000 -1476.5027883052826,3.3449392318725586,42371.742612838745,124793,0,42371.742612838745,0.6044000387191772,1.9222787618637085,10000,43855.968982219696,0.8462810516357422,0.8076683878898621,0.7290399670600891,1.3021578788757324,50000 -1494.1439380645752,3.396597146987915,42881.6758556366,126297,0,42881.6758556366,0.596500039100647,1.9440224170684808,10000,44383.64790058136,0.8422752022743225,0.7997431755065918,0.7246800065040588,1.2923554182052612,50000 -1511.4806642532349,3.4473884105682373,43391.6541454792,127801,0,43391.6541454792,0.613800048828125,1.8942350149154663,10000,44911.067249298096,0.8557876348495483,0.7706797122955322,0.7314199805259705,1.2801084518432615,50000 -1528.9485657215118,3.4982686042785645,43901.69043469429,129305,0,43901.69043469429,0.6118000149726868,1.882664918899536,10000,45438.67550635338,0.8703563213348389,0.682815670967102,0.7312600016593933,1.259967565536499,50000 -1546.2606403827667,3.5511486530303955,44411.83263254166,130810,0,44411.83263254166,0.6105000376701355,1.8829753398895264,10000,45966.23621058464,0.869559109210968,0.7021733522415161,0.7360599637031555,1.262072205543518,50000 -1563.5348060131073,3.6022284030914307,44921.73423457146,132313,0,44921.73423457146,0.61080002784729,1.890793800354004,10000,46493.51533031464,0.8615074753761292,0.731268048286438,0.7324999570846558,1.2648930549621582,50000 -1581.1292176246643,3.6521716117858887,45431.71617555618,133817,0,45431.71617555618,0.6159000396728516,1.9176183938980105,10000,47021.19475483894,0.8706154227256775,0.7510040998458862,0.7382999658584595,1.281070113182068,50000 -1598.651022195816,3.70396900177002,45941.8173930645,135322,0,45941.8173930645,0.6148000359535217,1.8706231117248533,10000,47548.92354273796,0.8668088316917419,0.6989973187446594,0.7389199733734131,1.2317496538162231,50000 -1616.0231895446775,3.759093761444092,46451.72684264183,136826,0,46451.72684264183,0.6172000169754028,1.886087417602539,10000,48076.31215286255,0.8732461333274841,0.7015025019645691,0.7404599785804749,1.2484326362609863,50000 -1633.4662518501282,3.812281608581543,46961.66208958626,138331,0,46961.66208958626,0.6165000200271606,1.88319993019104,10000,48603.79757499695,0.8901665806770325,0.6423742771148682,0.7400799989700317,1.2568217515945437,50000 -1650.7227218151093,3.866595983505249,47471.80485224724,139836,0,47471.80485224724,0.6229000091552734,1.850735783576965,10000,49131.303370952606,0.8897680044174194,0.627855658531189,0.74235999584198,1.2312299013137815,50000 -1668.2225155830383,3.922005653381348,47981.99515795708,141341,0,47981.99515795708,0.6237000226974487,1.8512567281723025,10000,49659.10092759133,0.886738657951355,0.6396247744560242,0.745639979839325,1.2290503978729248,50000 -1685.4138662815094,3.981988430023194,48492.10888576508,142846,0,48492.10888576508,0.6234000325202942,1.851295828819275,10000,50186.51897478104,0.8891502022743225,0.6449418067932129,0.7464199662208557,1.2233476638793943,50000 -1702.677888393402,4.033036470413208,49002.30559277535,144351,0,49002.30559277535,0.6279000043869019,1.849866271018982,10000,50714.08239650726,0.8878347873687744,0.636131763458252,0.7448999881744385,1.2222989797592163,50000 -1720.0401091575625,4.083997964859009,49512.39498496056,145856,0,49512.39498496056,0.6261000037193298,1.8473961353302,10000,51241.637093544006,0.8903858065605164,0.6295886039733887,0.7457000017166138,1.2230639457702637,50000 -1737.5025906562803,4.140517711639404,50022.29506134987,147360,0,50022.29506134987,0.6310000419616699,1.8346625566482544,10000,51769.10906100273,0.9123684167861938,0.5474082827568054,0.7480799555778503,1.2101439237594604,50000 -1754.9088788032532,4.21375036239624,50532.36496210098,148864,0,50532.36496210098,0.6290000081062317,1.848047018051148,10000,52296.71101593971,0.9058912396430968,0.5808451175689697,0.7503199577331543,1.2156970500946045,50000 -1772.8417789936066,4.267144680023193,51042.27467918396,150368,0,51042.27467918396,0.628600001335144,1.820862054824829,10000,52824.65858411789,0.9048748016357422,0.5630061030387878,0.7488600015640259,1.1980425119400024,50000 -1790.2585053443909,4.323015451431274,51552.41488814354,151873,0,51552.41488814354,0.6339000463485718,1.816255807876587,10000,53352.32368469238,0.9094387292861938,0.5636129379272461,0.7510600090026855,1.2021054029464722,50000 -1807.7640812397003,4.375983238220215,52062.60071802139,153378,0,52062.60071802139,0.6305000185966492,1.829545259475708,10000,53880.12092804909,0.9103156924247742,0.563623309135437,0.7524799704551697,1.2048771381378174,50000 -1825.2644836902616,4.426474571228027,52572.8312189579,154883,0,52572.8312189579,0.6341000199317932,1.829664707183838,10000,54407.95564079285,0.911730706691742,0.5565671324729919,0.752020001411438,1.2094497680664062,50000 -1843.028647899628,4.481486082077026,53082.78435873985,156387,0,53082.78435873985,0.6367000341415405,1.8163809776306152,10000,54935.78036189079,0.9270368218421936,0.5088555812835693,0.7543599605560303,1.2003954648971558,50000 -1861.3870635032647,4.539382696151733,53592.87838935852,157892,0,53592.87838935852,0.6326000094413757,1.8125512599945068,10000,55464.342358112335,0.922632336616516,0.5026780366897583,0.7541999816894531,1.184553146362305,50000 -1878.803019285202,4.597425699234009,54102.98790359497,159396,0,54102.98790359497,0.64000004529953,1.815568685531616,10000,55991.9786362648,0.9216158986091614,0.5223816633224487,0.7558599710464478,1.198951244354248,50000 -1896.3183450698853,4.652040481567383,54612.884423971176,160900,0,54612.884423971176,0.6383000016212463,1.8186490535736084,10000,56519.49779844284,0.9227519035339355,0.5238286256790161,0.7569599747657776,1.1966742277145386,50000 -1913.817587852478,4.720116376876831,55122.97531867027,162405,0,55122.97531867027,0.6376000046730042,1.815171360969544,10000,57047.208235025406,0.9255221486091614,0.5173733234405518,0.7567399740219116,1.1949071884155271,50000 -1931.449508666992,4.778955459594727,55633.16078686714,163910,0,55633.16078686714,0.6381000280380249,1.8144419193267824,10000,57575.13672566414,0.9263990521430968,0.5073674917221069,0.7562999725341797,1.1926721334457395,50000 -1948.845008611679,4.834153413772583,56143.32125544548,165415,0,56143.32125544548,0.6370000243186951,1.8039889335632324,10000,58102.80092835426,0.93558669090271,0.4720862805843353,0.7570399641990662,1.1844284534454346,50000 -1966.1763620376587,4.896468162536621,56653.3142747879,166919,0,56653.3142747879,0.640500009059906,1.7958463430404663,10000,58630.24149942398,0.9346898794174194,0.4757241606712341,0.759880006313324,1.1815327405929563,50000 -1983.467565536499,4.954056978225708,57163.46930789948,168424,0,57163.46930789948,0.6402000188827515,1.7989342212677002,10000,59157.79861664772,0.9356863498687744,0.4620772004127502,0.7594199776649475,1.1747729778289795,50000 -2000.991406917572,5.014198780059815,57673.66196155548,169929,0,57673.66196155548,0.6430000066757202,1.7972735166549685,10000,59685.62716174126,0.9359853267669678,0.4732051193714142,0.7616399526596069,1.178303360939026,50000 -2018.402989387512,5.076361656188965,58183.85387516022,171434,0,58183.85387516022,0.6459000110626221,1.798004150390625,10000,60213.34439706802,0.9338527917861938,0.4760339260101318,0.7612999677658081,1.17757248878479,50000 -2035.8027634620669,5.131362676620483,58693.92332792282,172938,0,58693.92332792282,0.6421000361442566,1.8015018701553345,10000,60740.91999530792,0.9358657598495485,0.4723958671092987,0.7612400054931641,1.1811754703521729,50000 -2053.356356859207,5.18861722946167,59204.02347588539,174442,0,59204.02347588539,0.6446000337600708,1.7955400943756104,10000,61268.68344020844,0.938875138759613,0.4575260877609253,0.7615000009536743,1.1745693683624268,50000 -2070.9494745731354,5.2475292682647705,59714.15299129486,175947,0,59714.15299129486,0.6451000571250916,1.7953126430511477,10000,61796.51736474037,0.9404296875,0.4535814821720123,0.7617799639701843,1.1750295162200928,50000 -2088.686345100403,5.303653001785278,60224.07753944397,177451,0,60224.07753944397,0.647100031375885,1.7913333177566528,10000,62324.28744530678,0.9409877061843872,0.457239419221878,0.762179970741272,1.175032138824463,50000 -2106.0333411693573,5.364094734191895,60734.13635802269,178955,0,60734.13635802269,0.6452000141143799,1.7911723852157593,10000,62851.80583786965,0.9386160373687744,0.4601021409034729,0.7621399760246277,1.1740968227386477,50000 -2123.3806269168854,5.421769380569458,61244.19420695305,180459,0,61244.19420695305,0.6443000435829163,1.7931833267211914,10000,63379.31936812401,0.9403499364852904,0.4558300673961639,0.7623599767684937,1.1769626140594482,50000 -2140.9366660118103,5.483500957489014,61754.27343964577,181963,0,61754.27343964577,0.6443000435829163,1.7920598983764648,10000,63907.06888461113,0.9390943646430968,0.4599470198154449,0.7626799941062927,1.1733717918395996,50000 -2158.5463812351227,6.609484672546387,62263.354476451874,183465,0,62263.354476451874,0.6443000435829163,1.7937980890274048,10000,64434.93941235542,0.9401307106018066,0.4537459015846252,0.7622999548912048,1.1755332946777344,50000 -2176.019249677658,6.671429634094238,62773.463541030884,184970,0,62773.463541030884,0.6452000141143799,1.7898361682891846,10000,64962.63630104065,0.9401506781578064,0.45149460434913635,0.7625799775123596,1.1717240810394287,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/measurements.csv deleted file mode 100644 index 5913317dd..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1983 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.53007793,6.9183903,,,,,,,,,,,,,, -1,,,0.0006576849264092,6.912638187408447,0.0006799999973736,6.912051200866699,50000.0,0.0013000001199543,6.9117279052734375,10000.0,33.4391393661499,51.186286211013794,33.4391393661499,17.747056245803833,0.0,0.0 -100,0.5243191,6.9022217,,,,,,,,,,,,,, -200,0.5339566,6.85553,,,,,,,,,,,,,, -300,0.5776818,6.7704396,,,,,,,,,,,,,, -400,0.63670266,6.671179,,,,,,,,,,,,,, -500,0.66286606,6.595006,,,,,,,,,,,,,, -600,0.70742077,6.5122275,,,,,,,,,,,,,, -700,0.8528745,6.4691086,,,,,,,,,,,,,, -800,1.2188753,6.388007,,,,,,,,,,,,,, -900,1.846686,6.2864037,,,,,,,,,,,,,, -1000,2.8172822,6.2032948,,,,,,,,,,,,,, -1100,1.8448063,6.1830783,,,,,,,,,,,,,, -1200,1.5183877,6.071398,,,,,,,,,,,,,, -1300,1.4006057,6.038412,,,,,,,,,,,,,, -1400,1.6544429,5.952715,,,,,,,,,,,,,, -1500,,,0.0859375,5.226979732513428,0.0790399983525276,5.291569709777832,50000.0,0.0571000017225742,5.528391361236572,10000.0,543.6612985134125,579.0811305046082,543.6612985134125,35.347055435180664,0.0192117691040039,0.0 -1500,1.9596616,5.9481354,,,,,,,,,,,,,, -1600,3.331755,5.8162546,,,,,,,,,,,,,, -1700,1.8907006,5.765367,,,,,,,,,,,,,, -1800,2.8380408,5.765199,,,,,,,,,,,,,, -1900,3.6439667,5.7672343,,,,,,,,,,,,,, -2000,2.350937,5.6506815,,,,,,,,,,,,,, -2100,2.6122978,5.580435,,,,,,,,,,,,,, -2200,3.0132005,5.589504,,,,,,,,,,,,,, -2300,3.0934865,5.533934,,,,,,,,,,,,,, -2400,2.983362,5.4864287,,,,,,,,,,,,,, -2500,2.707905,5.440854,,,,,,,,,,,,,, -2600,3.6773832,5.3576007,,,,,,,,,,,,,, -2700,3.1353292,5.476649,,,,,,,,,,,,,, -2800,3.691067,5.3183465,,,,,,,,,,,,,, -2900,5.265792,5.285225,,,,,,,,,,,,,, -2998,,,0.1965680718421936,4.198776721954346,0.1792199909687042,4.319735050201416,50000.0,0.1293999999761581,4.724494457244873,10000.0,1053.6502876281738,1106.8475799560547,1053.6502876281738,53.0393807888031,0.0525212287902832,0.0 -3000,3.15797,5.2265916,,,,,,,,,,,,,, -3100,3.9605951,5.1831255,,,,,,,,,,,,,, -3200,3.979344,5.167163,,,,,,,,,,,,,, -3300,2.505346,5.210641,,,,,,,,,,,,,, -3400,3.9335866,5.1316915,,,,,,,,,,,,,, -3500,3.5064034,5.0963173,,,,,,,,,,,,,, -3600,8.030725,5.0855346,,,,,,,,,,,,,, -3700,3.6281767,4.940949,,,,,,,,,,,,,, -3800,3.0432794,4.9693666,,,,,,,,,,,,,, -3900,2.5225282,4.908684,,,,,,,,,,,,,, -4000,4.2952485,4.9061594,,,,,,,,,,,,,, -4100,2.8743186,4.830805,,,,,,,,,,,,,, -4200,3.828085,4.707897,,,,,,,,,,,,,, -4300,2.3892832,4.7825036,,,,,,,,,,,,,, -4400,3.4087408,4.8064375,,,,,,,,,,,,,, -4496,,,0.3117426633834839,3.379237413406372,0.2842199802398681,3.520533323287964,50000.0,0.2076000124216079,4.096458911895752,10000.0,1563.8795804977417,1635.308201789856,1563.8795804977417,71.19184017181396,0.079308271408081,0.0 -4500,3.2733796,4.7986283,,,,,,,,,,,,,, -4600,3.047998,4.673756,,,,,,,,,,,,,, -4700,2.9840202,4.7027316,,,,,,,,,,,,,, -4800,4.390091,4.7362742,,,,,,,,,,,,,, -4900,4.9035997,4.6036944,,,,,,,,,,,,,, -5000,3.7176137,4.7018375,,,,,,,,,,,,,, -5100,3.805436,4.548787,,,,,,,,,,,,,, -5200,4.031943,4.680272,,,,,,,,,,,,,, -5300,3.3736057,4.451186,,,,,,,,,,,,,, -5400,3.9522152,4.5507035,,,,,,,,,,,,,, -5500,3.5385787,4.3873587,,,,,,,,,,,,,, -5600,4.128977,4.3881702,,,,,,,,,,,,,, -5700,4.8831973,4.3631744,,,,,,,,,,,,,, -5800,2.977499,4.348094,,,,,,,,,,,,,, -5900,2.9817324,4.430764,,,,,,,,,,,,,, -5994,,,0.3806600570678711,2.963285207748413,0.3558799922466278,3.0989913940429688,50000.0,0.2656000256538391,3.724989652633667,10000.0,2074.014079093933,2163.0716738700867,2074.014079093933,88.74252462387085,0.1061484813690185,0.0 -6000,3.7833107,4.277339,,,,,,,,,,,,,, -6100,3.538883,4.314688,,,,,,,,,,,,,, -6200,3.090262,4.4028215,,,,,,,,,,,,,, -6300,2.751867,4.321579,,,,,,,,,,,,,, -6400,2.1915212,4.3054733,,,,,,,,,,,,,, -6500,3.410493,4.255901,,,,,,,,,,,,,, -6600,2.7659028,4.1653085,,,,,,,,,,,,,, -6700,2.812657,4.1997414,,,,,,,,,,,,,, -6800,3.1375966,4.189416,,,,,,,,,,,,,, -6900,2.1851115,4.1953425,,,,,,,,,,,,,, -7000,2.7217238,4.2361393,,,,,,,,,,,,,, -7100,2.6995785,4.1401052,,,,,,,,,,,,,, -7200,2.9650846,4.1034184,,,,,,,,,,,,,, -7300,2.5640435,4.1314664,,,,,,,,,,,,,, -7400,2.9652944,4.061359,,,,,,,,,,,,,, -7493,,,0.4618144035339355,2.6370625495910645,0.4315599799156189,2.7958452701568604,50000.0,0.3315000236034393,3.3746931552886963,10000.0,2584.21363902092,2690.823953151703,2584.21363902092,106.2141387462616,0.1336226463317871,0.0 -7500,1.8715445,4.1567726,,,,,,,,,,,,,, -7600,1.9733274,4.0305634,,,,,,,,,,,,,, -7700,2.2033336,4.0523815,,,,,,,,,,,,,, -7800,2.243729,4.0494733,,,,,,,,,,,,,, -7900,2.0605237,4.087781,,,,,,,,,,,,,, -8000,2.6862922,4.0315027,,,,,,,,,,,,,, -8100,1.9753596,4.0690546,,,,,,,,,,,,,, -8200,3.169049,3.9696956,,,,,,,,,,,,,, -8300,2.51048,3.9861314,,,,,,,,,,,,,, -8400,3.3460593,4.0094166,,,,,,,,,,,,,, -8500,3.099159,4.09377,,,,,,,,,,,,,, -8600,1.8206763,3.9934957,,,,,,,,,,,,,, -8700,2.623005,3.9717937,,,,,,,,,,,,,, -8800,1.7539487,3.995576,,,,,,,,,,,,,, -8900,1.9644407,3.860087,,,,,,,,,,,,,, -8993,,,0.5307716727256775,2.255420207977295,0.4728399813175201,2.547547578811645,50000.0,0.3613000214099884,3.1709978580474854,10000.0,3094.3763308525085,3218.3501625061035,3094.3763308525085,123.49534726142883,0.1627497673034668,0.0 -9000,2.251333,4.0932775,,,,,,,,,,,,,, -9100,1.9824795,3.8803468,,,,,,,,,,,,,, -9200,2.4044163,3.9778535,,,,,,,,,,,,,, -9300,2.3636405,3.9205542,,,,,,,,,,,,,, -9400,2.8150756,3.8734777,,,,,,,,,,,,,, -9500,2.4761908,3.903776,,,,,,,,,,,,,, -9600,2.3838408,3.8225424,,,,,,,,,,,,,, -9700,1.8742512,3.8792894,,,,,,,,,,,,,, -9800,1.751355,3.8089123,,,,,,,,,,,,,, -9900,2.1251032,3.844111,,,,,,,,,,,,,, -10000,1.886425,3.801008,,,,,,,,,,,,,, -10100,2.2774081,3.886293,,,,,,,,,,,,,, -10200,1.815929,3.8007088,,,,,,,,,,,,,, -10300,1.6595322,3.856038,,,,,,,,,,,,,, -10400,2.4554482,3.7768238,,,,,,,,,,,,,, -10493,,,0.5577766299247742,2.1650500297546387,0.5104199647903442,2.3915514945983887,50000.0,0.3886000216007232,3.049933433532715,10000.0,3604.335671663284,3745.998475790024,3604.335671663284,141.10048961639404,0.1944019794464111,0.0 -10500,2.8462524,3.8330512,,,,,,,,,,,,,, -10600,1.605692,3.7426069,,,,,,,,,,,,,, -10700,2.6848807,3.7464972,,,,,,,,,,,,,, -10800,3.013575,3.7322786,,,,,,,,,,,,,, -10900,1.730739,3.831715,,,,,,,,,,,,,, -11000,1.9209075,3.8175876,,,,,,,,,,,,,, -11100,1.7070206,3.6735928,,,,,,,,,,,,,, -11200,1.4434593,3.7231672,,,,,,,,,,,,,, -11300,1.7541217,3.6885295,,,,,,,,,,,,,, -11400,1.741867,3.7071733,,,,,,,,,,,,,, -11500,2.3299215,3.748128,,,,,,,,,,,,,, -11600,1.8025699,3.7732944,,,,,,,,,,,,,, -11700,1.7776257,3.7033806,,,,,,,,,,,,,, -11800,1.462137,3.7810757,,,,,,,,,,,,,, -11900,1.3874133,3.7407453,,,,,,,,,,,,,, -11993,,,0.6017418503761292,1.94335687160492,0.5548200011253357,2.156901597976685,50000.0,0.4271000325679779,2.8042361736297607,10000.0,4114.402346134186,4273.819544792175,4114.402346134186,158.77054691314697,0.2259001731872558,0.0 -12000,1.1892796,3.6565201,,,,,,,,,,,,,, -12100,1.8448852,3.6123018,,,,,,,,,,,,,, -12200,1.4926373,3.6459737,,,,,,,,,,,,,, -12300,1.6965644,3.6821895,,,,,,,,,,,,,, -12400,1.6607454,3.6275344,,,,,,,,,,,,,, -12500,2.4953234,3.6368222,,,,,,,,,,,,,, -12600,1.7321548,3.6873112,,,,,,,,,,,,,, -12700,1.7656084,3.6237068,,,,,,,,,,,,,, -12800,1.8037772,3.6265247,,,,,,,,,,,,,, -12900,1.4403155,3.5888903,,,,,,,,,,,,,, -13000,1.5297611,3.568368,,,,,,,,,,,,,, -13100,1.4795827,3.6072457,,,,,,,,,,,,,, -13200,1.6811414,3.640893,,,,,,,,,,,,,, -13300,2.1070995,3.6464906,,,,,,,,,,,,,, -13400,1.900721,3.631728,,,,,,,,,,,,,, -13494,,,0.6224888563156128,1.774627685546875,0.5730400085449219,2.0085930824279785,50000.0,0.4510000348091125,2.655189275741577,10000.0,4624.474926710129,4801.298516511917,4624.474926710129,176.094083070755,0.2551212310791015,0.0 -13500,1.9262061,3.6114485,,,,,,,,,,,,,, -13600,1.5360682,3.5622444,,,,,,,,,,,,,, -13700,1.5542122,3.513699,,,,,,,,,,,,,, -13800,1.8319225,3.5766444,,,,,,,,,,,,,, -13900,1.4362284,3.5468287,,,,,,,,,,,,,, -14000,1.6659569,3.5295348,,,,,,,,,,,,,, -14100,1.7475212,3.5691724,,,,,,,,,,,,,, -14200,2.0659184,3.5170252,,,,,,,,,,,,,, -14300,1.4880906,3.6058643,,,,,,,,,,,,,, -14400,1.4771082,3.5463548,,,,,,,,,,,,,, -14500,1.4894922,3.5204115,,,,,,,,,,,,,, -14600,1.5999901,3.5200224,,,,,,,,,,,,,, -14700,1.3843801,3.4639592,,,,,,,,,,,,,, -14800,1.9912605,3.5738895,,,,,,,,,,,,,, -14900,1.433263,3.5040553,,,,,,,,,,,,,, -14995,,,0.6300023794174194,1.7994438409805298,0.5817999839782715,2.027262687683105,50000.0,0.456900030374527,2.6749420166015625,10000.0,5134.568528413773,5329.296462774277,5134.568528413773,193.91254234313965,0.2859010696411133,0.0 -15000,1.4640278,3.488144,,,,,,,,,,,,,, -15100,1.7255365,3.6187973,,,,,,,,,,,,,, -15200,1.6509266,3.5663419,,,,,,,,,,,,,, -15300,1.1351008,3.5531693,,,,,,,,,,,,,, -15400,2.010826,3.5645375,,,,,,,,,,,,,, -15500,1.2679302,3.5615046,,,,,,,,,,,,,, -15600,1.6017749,3.530664,,,,,,,,,,,,,, -15700,1.4391023,3.4514635,,,,,,,,,,,,,, -15800,1.7346088,3.4651895,,,,,,,,,,,,,, -15900,1.2241194,3.4821324,,,,,,,,,,,,,, -16000,1.4430419,3.5015874,,,,,,,,,,,,,, -16100,1.449448,3.4880166,,,,,,,,,,,,,, -16200,1.7944751,3.5127988,,,,,,,,,,,,,, -16300,1.7955409,3.4748662,,,,,,,,,,,,,, -16400,1.5893719,3.5888276,,,,,,,,,,,,,, -16498,,,0.6436144709587097,1.6748861074447632,0.5989199876785278,1.89004135131836,50000.0,0.4735000133514404,2.5429165363311768,10000.0,5644.774896144867,5857.091723680496,5644.774896144867,211.41987991333008,0.31459641456604,0.0 -16500,1.3203968,3.431067,,,,,,,,,,,,,, -16600,2.1914167,3.5363302,,,,,,,,,,,,,, -16700,1.654979,3.460074,,,,,,,,,,,,,, -16800,2.031797,3.4044356,,,,,,,,,,,,,, -16900,1.4981098,3.507496,,,,,,,,,,,,,, -17000,1.6441832,3.4541652,,,,,,,,,,,,,, -17100,1.3473821,3.402613,,,,,,,,,,,,,, -17200,1.6363906,3.4999676,,,,,,,,,,,,,, -17300,1.6395018,3.428873,,,,,,,,,,,,,, -17400,1.4629667,3.44824,,,,,,,,,,,,,, -17500,1.706172,3.464917,,,,,,,,,,,,,, -17600,1.5185934,3.4982076,,,,,,,,,,,,,, -17700,1.4969205,3.4235995,,,,,,,,,,,,,, -17800,1.4762022,3.5413141,,,,,,,,,,,,,, -17900,1.2350925,3.3671336,,,,,,,,,,,,,, -18000,1.6328473,3.4794757,,,,,,,,,,,,,, -18001,,,0.7017298936843872,1.4178893566131592,0.6070199608802795,1.8295985460281368,50000.0,0.4797000288963318,2.4914474487304688,10000.0,6155.2353079319,6385.087836503983,6155.2353079319,228.8724648952484,0.3461291790008545,0.0 -18100,1.4078934,3.4484687,,,,,,,,,,,,,, -18200,1.8181356,3.437283,,,,,,,,,,,,,, -18300,1.8323483,3.480203,,,,,,,,,,,,,, -18400,1.7247887,3.3595111,,,,,,,,,,,,,, -18500,1.175362,3.4382098,,,,,,,,,,,,,, -18600,1.4335535,3.4115539,,,,,,,,,,,,,, -18700,1.0753274,3.4229686,,,,,,,,,,,,,, -18800,1.6997528,3.4389684,,,,,,,,,,,,,, -18900,1.5908545,3.3985238,,,,,,,,,,,,,, -19000,1.426914,3.4108381,,,,,,,,,,,,,, -19100,1.5875851,3.4054341,,,,,,,,,,,,,, -19200,1.7775002,3.4415202,,,,,,,,,,,,,, -19300,1.895856,3.4490998,,,,,,,,,,,,,, -19400,1.3244205,3.407294,,,,,,,,,,,,,, -19500,1.2091135,3.340525,,,,,,,,,,,,,, -19503,,,0.6970065236091614,1.4243632555007937,0.6238399744033813,1.7519325017929075,50000.0,0.496800035238266,2.404106855392456,10000.0,6665.291936635971,6912.620675325394,6665.291936635971,246.26633167266849,0.3774294853210449,0.0 -19600,1.4085213,3.3609827,,,,,,,,,,,,,, -19700,1.2671186,3.3835146,,,,,,,,,,,,,, -19800,1.7346365,3.340704,,,,,,,,,,,,,, -19900,1.1929768,3.4916182,,,,,,,,,,,,,, -20000,1.5206851,3.4005039,,,,,,,,,,,,,, -20100,1.4611347,3.3706427,,,,,,,,,,,,,, -20200,1.1820093,3.398648,,,,,,,,,,,,,, -20300,1.7781208,3.4212258,,,,,,,,,,,,,, -20400,1.4559796,3.486506,,,,,,,,,,,,,, -20500,1.3851607,3.444614,,,,,,,,,,,,,, -20600,1.8683459,3.455677,,,,,,,,,,,,,, -20700,1.5009791,3.3905625,,,,,,,,,,,,,, -20800,1.3624184,3.3966808,,,,,,,,,,,,,, -20900,1.5869062,3.366357,,,,,,,,,,,,,, -21000,1.7370398,3.3986683,,,,,,,,,,,,,, -21005,,,0.6917251348495483,1.4819576740264893,0.6297399997711182,1.761543035507202,50000.0,0.49590003490448,2.432694673538208,10000.0,7175.285215139389,7440.939846038818,7175.285215139389,264.5082576274872,0.4087963104248047,0.0 -21100,1.3548683,3.3756437,,,,,,,,,,,,,, -21200,1.4068127,3.3719432,,,,,,,,,,,,,, -21300,2.0568266,3.3534937,,,,,,,,,,,,,, -21400,1.6169343,3.3814788,,,,,,,,,,,,,, -21500,1.5043525,3.401853,,,,,,,,,,,,,, -21600,1.4274921,3.3700318,,,,,,,,,,,,,, -21700,1.3893316,3.2932951,,,,,,,,,,,,,, -21800,1.1395742,3.3227203,,,,,,,,,,,,,, -21900,1.3824087,3.3141432,,,,,,,,,,,,,, -22000,1.3744642,3.3616705,,,,,,,,,,,,,, -22100,1.3448173,3.4306822,,,,,,,,,,,,,, -22200,1.4297646,3.3580372,,,,,,,,,,,,,, -22300,1.4749954,3.2867665,,,,,,,,,,,,,, -22400,1.3365388,3.359713,,,,,,,,,,,,,, -22500,1.529518,3.3090627,,,,,,,,,,,,,, -22507,,,0.6911072731018066,1.4709314107894895,0.629539966583252,1.7492947578430176,50000.0,0.5031999945640564,2.40031361579895,10000.0,7685.427667379379,7968.438979625702,7685.427667379379,281.779283285141,0.4405534267425537,0.0 -22600,1.4319353,3.3461947,,,,,,,,,,,,,, -22700,1.1648679,3.3246248,,,,,,,,,,,,,, -22800,1.3235464,3.3462842,,,,,,,,,,,,,, -22900,1.0470531,3.3193877,,,,,,,,,,,,,, -23000,1.3472718,3.3014054,,,,,,,,,,,,,, -23100,1.2961215,3.3354487,,,,,,,,,,,,,, -23200,1.4982378,3.3354638,,,,,,,,,,,,,, -23300,1.1521528,3.4045594,,,,,,,,,,,,,, -23400,1.1275434,3.3459516,,,,,,,,,,,,,, -23500,1.3855536,3.4105358,,,,,,,,,,,,,, -23600,1.1903524,3.34189,,,,,,,,,,,,,, -23700,1.2109791,3.3499384,,,,,,,,,,,,,, -23800,1.8472505,3.3843448,,,,,,,,,,,,,, -23900,1.2917914,3.2984228,,,,,,,,,,,,,, -24000,1.2702787,3.3305953,,,,,,,,,,,,,, -24009,,,0.6873604655265808,1.4837753772735596,0.6243799924850464,1.7636795043945312,50000.0,0.5029000043869019,2.397300720214844,10000.0,8195.398663759232,8496.294484138489,8195.398663759232,299.5811333656311,0.4718921184539795,0.0 -24100,1.2397985,3.3142853,,,,,,,,,,,,,, -24200,1.1655052,3.276065,,,,,,,,,,,,,, -24300,1.4394324,3.2545893,,,,,,,,,,,,,, -24400,1.2630243,3.3107378,,,,,,,,,,,,,, -24500,1.3213918,3.3337219,,,,,,,,,,,,,, -24600,1.3998308,3.3724217,,,,,,,,,,,,,, -24700,1.6036626,3.2930255,,,,,,,,,,,,,, -24800,1.336687,3.300675,,,,,,,,,,,,,, -24900,1.7019025,3.3317013,,,,,,,,,,,,,, -25000,1.3341504,3.284715,,,,,,,,,,,,,, -25100,1.4722469,3.2425942,,,,,,,,,,,,,, -25200,1.4198943,3.3014526,,,,,,,,,,,,,, -25300,1.2722653,3.2953444,,,,,,,,,,,,,, -25400,1.5619912,3.356591,,,,,,,,,,,,,, -25500,1.1255035,3.280814,,,,,,,,,,,,,, -25512,,,0.6941366195678711,1.4621155261993408,0.6352399587631226,1.7313721179962158,50000.0,0.5123000144958496,2.378972291946411,10000.0,8705.627409219742,9024.263773679731,8705.627409219742,317.2381706237793,0.5033385753631592,0.0 -25600,1.5762684,3.239552,,,,,,,,,,,,,, -25700,1.2346163,3.3225706,,,,,,,,,,,,,, -25800,1.2743665,3.290781,,,,,,,,,,,,,, -25900,1.2896878,3.3799214,,,,,,,,,,,,,, -26000,1.4015872,3.2701938,,,,,,,,,,,,,, -26100,1.244658,3.298193,,,,,,,,,,,,,, -26200,1.5849842,3.3184144,,,,,,,,,,,,,, -26300,1.6097771,3.2681246,,,,,,,,,,,,,, -26400,1.3154229,3.2233372,,,,,,,,,,,,,, -26500,1.3382287,3.2646773,,,,,,,,,,,,,, -26600,1.414936,3.2426856,,,,,,,,,,,,,, -26700,1.4780828,3.3040934,,,,,,,,,,,,,, -26800,1.4593432,3.3555984,,,,,,,,,,,,,, -26900,1.3389374,3.2836418,,,,,,,,,,,,,, -27000,1.4291539,3.2741265,,,,,,,,,,,,,, -27015,,,0.7335180044174194,1.286011815071106,0.6350799798965454,1.7067891359329224,50000.0,0.5028000473976135,2.38372540473938,10000.0,9215.783554315569,9551.9481112957,9215.783554315569,334.6807444095612,0.5375471115112305,0.0 -27100,1.2697581,3.2243347,,,,,,,,,,,,,, -27200,1.3099526,3.2568333,,,,,,,,,,,,,, -27300,1.4614995,3.2151213,,,,,,,,,,,,,, -27400,1.5654151,3.3103065,,,,,,,,,,,,,, -27500,1.3119626,3.2527037,,,,,,,,,,,,,, -27600,1.3055055,3.278419,,,,,,,,,,,,,, -27700,1.5011414,3.283169,,,,,,,,,,,,,, -27800,1.350218,3.2642486,,,,,,,,,,,,,, -27900,1.362067,3.256581,,,,,,,,,,,,,, -28000,1.3280009,3.2863188,,,,,,,,,,,,,, -28100,1.3788705,3.294234,,,,,,,,,,,,,, -28200,1.6018621,3.2654755,,,,,,,,,,,,,, -28300,1.5069777,3.2601805,,,,,,,,,,,,,, -28400,1.4372634,3.2803984,,,,,,,,,,,,,, -28500,1.422018,3.2035859,,,,,,,,,,,,,, -28518,,,0.7210817933082581,1.3177582025527954,0.6410999894142151,1.6783519983291626,50000.0,0.5126000046730042,2.341330051422119,10000.0,9725.861224412918,10080.300015211104,9725.861224412918,352.87031412124634,0.5690395832061768,0.0 -28600,1.3775618,3.2707548,,,,,,,,,,,,,, -28700,1.334501,3.2917924,,,,,,,,,,,,,, -28800,1.4579092,3.2440288,,,,,,,,,,,,,, -28900,1.6591148,3.2017941,,,,,,,,,,,,,, -29000,1.4353667,3.2680206,,,,,,,,,,,,,, -29100,1.2308278,3.3041377,,,,,,,,,,,,,, -29200,1.7787689,3.3037887,,,,,,,,,,,,,, -29300,1.4959425,3.333031,,,,,,,,,,,,,, -29400,1.5580875,3.3059757,,,,,,,,,,,,,, -29500,1.4095824,3.271861,,,,,,,,,,,,,, -29600,1.4364283,3.3407266,,,,,,,,,,,,,, -29700,1.2185497,3.2600543,,,,,,,,,,,,,, -29800,1.599134,3.33781,,,,,,,,,,,,,, -29900,1.570122,3.2447217,,,,,,,,,,,,,, -30000,2.1381328,3.269824,,,,,,,,,,,,,, -30021,,,0.7088049650192261,1.3988709449768066,0.6373199820518494,1.7182626724243164,50000.0,0.5089000463485718,2.397118330001831,10000.0,10235.83157801628,10608.090680122375,10235.83157801628,370.6090452671051,0.5975942611694336,0.0 -30100,1.4208239,3.2664702,,,,,,,,,,,,,, -30200,1.6391689,3.2994046,,,,,,,,,,,,,, -30300,1.4961535,3.2933269,,,,,,,,,,,,,, -30400,1.4034054,3.2953103,,,,,,,,,,,,,, -30500,1.3829782,3.3422303,,,,,,,,,,,,,, -30600,1.3879594,3.233873,,,,,,,,,,,,,, -30700,1.4291579,3.2509198,,,,,,,,,,,,,, -30800,1.4822986,3.318923,,,,,,,,,,,,,, -30900,1.1925429,3.2779474,,,,,,,,,,,,,, -31000,1.3248423,3.243764,,,,,,,,,,,,,, -31100,1.4542747,3.24617,,,,,,,,,,,,,, -31200,1.2760524,3.254262,,,,,,,,,,,,,, -31300,1.3281842,3.2870343,,,,,,,,,,,,,, -31400,1.5319172,3.2848673,,,,,,,,,,,,,, -31500,1.2634355,3.2193763,,,,,,,,,,,,,, -31525,,,0.7114556431770325,1.361906886100769,0.646340012550354,1.6596877574920654,50000.0,0.511900007724762,2.325451374053955,10000.0,10745.978585481644,11136.00351715088,10745.978585481644,388.2906057834625,0.6299974918365479,0.0 -31600,1.3354604,3.2323472,,,,,,,,,,,,,, -31700,1.3086684,3.2192535,,,,,,,,,,,,,, -31800,1.3729014,3.2653918,,,,,,,,,,,,,, -31900,1.3312212,3.2084422,,,,,,,,,,,,,, -32000,1.4394549,3.292316,,,,,,,,,,,,,, -32100,1.3842533,3.195475,,,,,,,,,,,,,, -32200,1.3931243,3.2077844,,,,,,,,,,,,,, -32300,1.5431684,3.1610231,,,,,,,,,,,,,, -32400,1.2569374,3.2349925,,,,,,,,,,,,,, -32500,1.3981299,3.285206,,,,,,,,,,,,,, -32600,1.435171,3.2611108,,,,,,,,,,,,,, -32700,1.6148827,3.1894882,,,,,,,,,,,,,, -32800,1.7862707,3.2798455,,,,,,,,,,,,,, -32900,1.4285249,3.1979244,,,,,,,,,,,,,, -33000,1.4637061,3.2077017,,,,,,,,,,,,,, -33027,,,0.7103396058082581,1.374114990234375,0.6455199718475342,1.6690250635147097,50000.0,0.5130000114440918,2.3377459049224854,10000.0,11256.05967617035,11663.95685338974,11256.05967617035,406.0712497234345,0.669827938079834,0.0 -33100,1.5069801,3.2034674,,,,,,,,,,,,,, -33200,1.5368607,3.156771,,,,,,,,,,,,,, -33300,1.4415009,3.2535944,,,,,,,,,,,,,, -33400,1.3709481,3.247339,,,,,,,,,,,,,, -33500,1.2888385,3.22169,,,,,,,,,,,,,, -33600,1.2863564,3.2087777,,,,,,,,,,,,,, -33700,1.3264687,3.2195666,,,,,,,,,,,,,, -33800,1.4910179,3.3578207,,,,,,,,,,,,,, -33900,1.5182465,3.2355096,,,,,,,,,,,,,, -34000,1.3837982,3.2808955,,,,,,,,,,,,,, -34100,1.4963186,3.2541249,,,,,,,,,,,,,, -34200,1.522659,3.2733753,,,,,,,,,,,,,, -34300,1.4298301,3.2018077,,,,,,,,,,,,,, -34400,1.4915385,3.201373,,,,,,,,,,,,,, -34500,1.3084149,3.2005804,,,,,,,,,,,,,, -34532,,,0.7158003449440002,1.3178354501724243,0.6526600122451782,1.6064244508743286,50000.0,0.5246000289916992,2.271498680114746,10000.0,11766.244701385498,12191.65503835678,11766.244701385498,423.49549770355225,0.7041542530059814,0.0 -34600,1.4249605,3.2011514,,,,,,,,,,,,,, -34700,1.5314687,3.221914,,,,,,,,,,,,,, -34800,1.4341648,3.1642191,,,,,,,,,,,,,, -34900,1.3569456,3.1697664,,,,,,,,,,,,,, -35000,1.4893944,3.2198663,,,,,,,,,,,,,, -35100,1.3456097,3.042161,,,,,,,,,,,,,, -35200,1.2822633,3.1713197,,,,,,,,,,,,,, -35300,1.578492,3.2457626,,,,,,,,,,,,,, -35400,1.4307017,3.1882598,,,,,,,,,,,,,, -35500,1.5978312,3.242754,,,,,,,,,,,,,, -35600,1.2478129,3.1976175,,,,,,,,,,,,,, -35700,1.4438119,3.2109847,,,,,,,,,,,,,, -35800,1.4394194,3.2475576,,,,,,,,,,,,,, -35900,1.4600829,3.2066498,,,,,,,,,,,,,, -36000,1.6038299,3.204903,,,,,,,,,,,,,, -36036,,,0.7166773080825806,1.3349188566207886,0.6400600075721741,1.6809086799621582,50000.0,0.5085000395774841,2.3779425621032715,10000.0,12276.337631940842,12719.575999498367,12276.337631940842,441.2348201274872,0.7414686679840088,0.0 -36100,1.7120266,3.2763557,,,,,,,,,,,,,, -36200,1.4178475,3.1872017,,,,,,,,,,,,,, -36300,1.2975477,3.1882277,,,,,,,,,,,,,, -36400,1.5324515,3.2338934,,,,,,,,,,,,,, -36500,1.3539658,3.1903145,,,,,,,,,,,,,, -36600,1.5291836,3.2046978,,,,,,,,,,,,,, -36700,1.5024761,3.1835408,,,,,,,,,,,,,, -36800,1.507235,3.2299619,,,,,,,,,,,,,, -36900,1.4377294,3.1730628,,,,,,,,,,,,,, -37000,1.3972524,3.2970672,,,,,,,,,,,,,, -37100,1.5000968,3.2496402,,,,,,,,,,,,,, -37200,1.5373687,3.2072554,,,,,,,,,,,,,, -37300,1.7431806,3.2362142,,,,,,,,,,,,,, -37400,1.645465,3.2437778,,,,,,,,,,,,,, -37500,1.456754,3.227751,,,,,,,,,,,,,, -37540,,,0.7401347160339355,1.2507890462875366,0.6541999578475952,1.622602939605713,50000.0,0.5275000333786011,2.2785587310791016,10000.0,12786.48780798912,13247.302038192747,12786.48780798912,458.72729420661926,0.7722766399383545,0.0 -37600,1.4472487,3.201582,,,,,,,,,,,,,, -37700,1.6289276,3.2077854,,,,,,,,,,,,,, -37800,1.4374528,3.2374644,,,,,,,,,,,,,, -37900,1.5675256,3.2075124,,,,,,,,,,,,,, -38000,1.3738825,3.173668,,,,,,,,,,,,,, -38100,1.4356217,3.1826644,,,,,,,,,,,,,, -38200,1.6434376,3.1902246,,,,,,,,,,,,,, -38300,1.5172993,3.1234555,,,,,,,,,,,,,, -38400,1.4642242,3.1617055,,,,,,,,,,,,,, -38500,1.7671729,3.1693282,,,,,,,,,,,,,, -38600,1.4848088,3.1854534,,,,,,,,,,,,,, -38700,1.5390422,3.2401633,,,,,,,,,,,,,, -38800,1.7569461,3.2718704,,,,,,,,,,,,,, -38900,1.3761839,3.1431093,,,,,,,,,,,,,, -39000,1.502638,3.274508,,,,,,,,,,,,,, -39044,,,0.7385004758834839,1.2336959838867188,0.6571800112724304,1.5852231979370115,50000.0,0.5231000185012817,2.2626054286956787,10000.0,13296.705997228622,13775.218259096146,13296.705997228622,476.3352725505829,0.80930495262146,0.0 -39100,1.5130789,3.2307377,,,,,,,,,,,,,, -39200,1.3217574,3.1560895,,,,,,,,,,,,,, -39300,1.6580029,3.247662,,,,,,,,,,,,,, -39400,1.471699,3.122148,,,,,,,,,,,,,, -39500,1.5815115,3.2018678,,,,,,,,,,,,,, -39600,1.7200414,3.149337,,,,,,,,,,,,,, -39700,1.7136122,3.204105,,,,,,,,,,,,,, -39800,1.831527,3.185395,,,,,,,,,,,,,, -39900,1.5478166,3.1885033,,,,,,,,,,,,,, -40000,1.5023401,3.165664,,,,,,,,,,,,,, -40100,1.4656934,3.1573133,,,,,,,,,,,,,, -40200,1.5731422,3.1881518,,,,,,,,,,,,,, -40300,1.4625548,3.1651027,,,,,,,,,,,,,, -40400,1.655826,3.2464929,,,,,,,,,,,,,, -40500,1.6854959,3.1645038,,,,,,,,,,,,,, -40547,,,0.7364476919174194,1.2696105241775513,0.6611799597740173,1.5964609384536743,50000.0,0.5362000465393066,2.2442195415496826,10000.0,13806.621383428574,14302.725590705872,13806.621383428574,493.83841919898987,0.8440403938293457,0.0 -40600,1.6458942,3.1808176,,,,,,,,,,,,,, -40700,1.4265375,3.2227244,,,,,,,,,,,,,, -40800,1.6140938,3.223607,,,,,,,,,,,,,, -40900,1.7026433,3.2256176,,,,,,,,,,,,,, -41000,1.53128,3.129157,,,,,,,,,,,,,, -41100,1.5352358,3.2213035,,,,,,,,,,,,,, -41200,1.5081675,3.1650393,,,,,,,,,,,,,, -41300,1.4570181,3.183837,,,,,,,,,,,,,, -41400,1.734356,3.2877371,,,,,,,,,,,,,, -41500,1.5432684,3.143184,,,,,,,,,,,,,, -41600,1.6275268,3.2189283,,,,,,,,,,,,,, -41700,1.6779078,3.2281384,,,,,,,,,,,,,, -41800,1.6347115,3.131571,,,,,,,,,,,,,, -41900,1.6835626,3.2651222,,,,,,,,,,,,,, -42000,1.4661881,3.211886,,,,,,,,,,,,,, -42051,,,0.7135283946990967,1.3360618352890017,0.6412599682807922,1.6634626388549805,50000.0,0.5211000442504883,2.314053058624268,10000.0,14316.647776842115,14831.06675171852,14316.647776842115,512.0555701255798,0.8885290622711182,0.0 -42100,1.6753213,3.1790392,,,,,,,,,,,,,, -42200,1.7480378,3.2853618,,,,,,,,,,,,,, -42300,1.5817802,3.144611,,,,,,,,,,,,,, -42400,1.6197528,3.2000284,,,,,,,,,,,,,, -42500,1.5078207,3.1927137,,,,,,,,,,,,,, -42600,1.6526372,3.178208,,,,,,,,,,,,,, -42700,1.932925,3.1814198,,,,,,,,,,,,,, -42800,1.5970707,3.1768126,,,,,,,,,,,,,, -42900,1.6290325,3.2151408,,,,,,,,,,,,,, -43000,1.7762119,3.1400805,,,,,,,,,,,,,, -43100,1.5690162,3.1684074,,,,,,,,,,,,,, -43200,1.7204392,3.1170578,,,,,,,,,,,,,, -43300,1.5394099,3.1324973,,,,,,,,,,,,,, -43400,1.6927124,3.267759,,,,,,,,,,,,,, -43500,1.6081859,3.1912384,,,,,,,,,,,,,, -43555,,,0.7233737111091614,1.2983133792877195,0.6625799536705017,1.5901869535446167,50000.0,0.5362000465393066,2.240588426589966,10000.0,14826.77017712593,15358.520558834076,14826.77017712593,529.3009285926819,0.9220795631408693,0.0 -43600,1.465402,3.1149957,,,,,,,,,,,,,, -43700,1.7951503,3.1809978,,,,,,,,,,,,,, -43800,1.7655274,3.2597432,,,,,,,,,,,,,, -43900,1.7904752,3.247763,,,,,,,,,,,,,, -44000,1.5255514,3.2299874,,,,,,,,,,,,,, -44100,1.4661361,3.2439427,,,,,,,,,,,,,, -44200,1.5734178,3.1591926,,,,,,,,,,,,,, -44300,1.5790728,3.1735246,,,,,,,,,,,,,, -44400,1.6595863,3.1862917,,,,,,,,,,,,,, -44500,1.6722254,3.1917899,,,,,,,,,,,,,, -44600,1.8871393,3.1915157,,,,,,,,,,,,,, -44700,1.7613738,3.1322756,,,,,,,,,,,,,, -44800,1.5751312,3.169315,,,,,,,,,,,,,, -44900,1.5666215,3.1544013,,,,,,,,,,,,,, -45000,1.7521454,3.200597,,,,,,,,,,,,,, -45058,,,0.7239118218421936,1.328299045562744,0.6556199789047241,1.6299936771392822,50000.0,0.5277000069618225,2.281455516815185,10000.0,15336.741206169128,15886.432034015656,15336.741206169128,547.1348049640656,0.9746749401092528,0.0 -45100,1.7136242,3.2379305,,,,,,,,,,,,,, -45200,1.6864524,3.189129,,,,,,,,,,,,,, -45300,1.5897592,3.1970637,,,,,,,,,,,,,, -45400,1.7747215,3.2300935,,,,,,,,,,,,,, -45500,1.6106175,3.187168,,,,,,,,,,,,,, -45600,1.5646585,3.1526139,,,,,,,,,,,,,, -45700,1.6464369,3.1468847,,,,,,,,,,,,,, -45800,1.6850324,3.2230833,,,,,,,,,,,,,, -45900,1.6512716,3.096322,,,,,,,,,,,,,, -46000,1.7587023,3.2162309,,,,,,,,,,,,,, -46100,1.7140054,3.153827,,,,,,,,,,,,,, -46200,1.6989475,3.3116615,,,,,,,,,,,,,, -46300,1.7669035,3.1951888,,,,,,,,,,,,,, -46400,1.7012322,3.1252558,,,,,,,,,,,,,, -46500,1.6718934,3.224012,,,,,,,,,,,,,, -46562,,,0.7540656924247742,1.198965311050415,0.6650199890136719,1.589596390724182,50000.0,0.5383000373840332,2.2207345962524414,10000.0,15846.720610380173,16414.10246682167,15846.720610380173,564.7328906059265,1.0138413906097412,0.0 -46600,1.7492558,3.199145,,,,,,,,,,,,,, -46700,1.6257465,3.1811724,,,,,,,,,,,,,, -46800,1.6999733,3.1490362,,,,,,,,,,,,,, -46900,1.928061,3.171096,,,,,,,,,,,,,, -47000,1.7083813,3.190172,,,,,,,,,,,,,, -47100,1.7818835,3.2028058,,,,,,,,,,,,,, -47200,1.7755492,3.15505,,,,,,,,,,,,,, -47300,1.8898135,3.155879,,,,,,,,,,,,,, -47400,1.6490796,3.1719072,,,,,,,,,,,,,, -47500,1.6798505,3.1069503,,,,,,,,,,,,,, -47600,1.8410444,3.1709223,,,,,,,,,,,,,, -47700,1.9513547,3.2189178,,,,,,,,,,,,,, -47800,1.7837394,3.2544851,,,,,,,,,,,,,, -47900,1.715235,3.207501,,,,,,,,,,,,,, -48000,1.6439432,3.1620846,,,,,,,,,,,,,, -48066,,,0.7438217401504517,1.2626017332077026,0.6648600101470947,1.6124789714813232,50000.0,0.5409000515937805,2.236502170562744,10000.0,16356.965524196625,16941.98569583893,16356.965524196625,582.2785995006561,1.0533804893493652,0.0 -48100,1.6813242,3.2429152,,,,,,,,,,,,,, -48200,1.6412058,3.0921762,,,,,,,,,,,,,, -48300,1.7460654,3.1750958,,,,,,,,,,,,,, -48400,1.9319532,3.1762245,,,,,,,,,,,,,, -48500,1.7750548,3.068842,,,,,,,,,,,,,, -48600,1.6979823,3.1884418,,,,,,,,,,,,,, -48700,1.7065884,3.0564656,,,,,,,,,,,,,, -48800,1.645786,3.1418688,,,,,,,,,,,,,, -48900,1.7780573,3.1018302,,,,,,,,,,,,,, -49000,1.8146183,3.2235413,,,,,,,,,,,,,, -49100,1.743587,3.2185557,,,,,,,,,,,,,, -49200,1.9378768,3.0875843,,,,,,,,,,,,,, -49300,1.7991428,3.1760144,,,,,,,,,,,,,, -49400,1.7370248,3.0865495,,,,,,,,,,,,,, -49500,1.8367809,3.262498,,,,,,,,,,,,,, -49571,,,0.7395169138908386,1.25126314163208,0.6665399670600891,1.5785037279129028,50000.0,0.5455999970436096,2.187873601913452,10000.0,16867.10874891281,17469.85976576805,16867.10874891281,599.9203262329102,1.0888376235961914,0.0 -49600,1.7275095,3.1126528,,,,,,,,,,,,,, -49700,1.7628194,3.1819618,,,,,,,,,,,,,, -49800,1.9590877,3.1242096,,,,,,,,,,,,,, -49900,2.0289912,3.1892614,,,,,,,,,,,,,, -50000,1.748004,3.1168697,,,,,,,,,,,,,, -50100,1.5898159,3.1555886,,,,,,,,,,,,,, -50200,1.8379948,3.2392323,,,,,,,,,,,,,, -50300,1.7453613,3.1650164,,,,,,,,,,,,,, -50400,1.899602,3.1319752,,,,,,,,,,,,,, -50500,1.7576168,3.168765,,,,,,,,,,,,,, -50600,1.6922911,3.1220384,,,,,,,,,,,,,, -50700,1.8884671,3.1544907,,,,,,,,,,,,,, -50800,1.7787203,3.1631913,,,,,,,,,,,,,, -50900,1.8903214,3.0519342,,,,,,,,,,,,,, -51000,1.7991625,3.1316712,,,,,,,,,,,,,, -51075,,,0.7398756146430969,1.2330973148345947,0.666700005531311,1.5564137697219849,50000.0,0.5387000441551208,2.2007992267608643,10000.0,17377.16103053093,17997.680426597595,17377.16103053093,617.5973706245422,1.126636266708374,0.0 -51100,1.965547,3.1378474,,,,,,,,,,,,,, -51200,1.7107154,3.1010022,,,,,,,,,,,,,, -51300,1.88318,3.2452717,,,,,,,,,,,,,, -51400,1.7506795,3.1751273,,,,,,,,,,,,,, -51500,1.8584131,3.2558951,,,,,,,,,,,,,, -51600,1.9693428,3.1858194,,,,,,,,,,,,,, -51700,1.8702724,3.1223047,,,,,,,,,,,,,, -51800,1.8935623,3.1206248,,,,,,,,,,,,,, -51900,1.6538005,3.1530356,,,,,,,,,,,,,, -52000,1.7501881,3.1384423,,,,,,,,,,,,,, -52100,2.2234857,3.1401649,,,,,,,,,,,,,, -52200,1.7482282,3.0578494,,,,,,,,,,,,,, -52300,1.7650298,3.0800006,,,,,,,,,,,,,, -52400,1.7024335,3.1114187,,,,,,,,,,,,,, -52500,1.7700175,3.1313033,,,,,,,,,,,,,, -52578,,,0.7335976958274841,1.282623529434204,0.6645999550819397,1.590146541595459,50000.0,0.5347000360488892,2.245861291885376,10000.0,17887.110765457153,18525.204036712646,17887.110765457153,635.0826358795166,1.1614611148834229,0.0 -52600,1.7515205,3.1285694,,,,,,,,,,,,,, -52700,1.7463379,3.1560838,,,,,,,,,,,,,, -52800,1.871275,3.1448536,,,,,,,,,,,,,, -52900,1.9680822,3.113621,,,,,,,,,,,,,, -53000,1.8762901,3.1539638,,,,,,,,,,,,,, -53100,1.8397567,3.1038935,,,,,,,,,,,,,, -53200,1.7467872,3.1949754,,,,,,,,,,,,,, -53300,1.9530933,3.1843352,,,,,,,,,,,,,, -53400,1.8399214,3.1134896,,,,,,,,,,,,,, -53500,1.8663841,3.1937888,,,,,,,,,,,,,, -53600,1.7919617,3.1725383,,,,,,,,,,,,,, -53700,1.7981012,3.1615207,,,,,,,,,,,,,, -53800,1.9112482,3.1804013,,,,,,,,,,,,,, -53900,1.744596,3.1234224,,,,,,,,,,,,,, -54000,1.894135,3.088557,,,,,,,,,,,,,, -54082,,,0.7420878410339355,1.2159703969955444,0.6678999662399292,1.5432895421981812,50000.0,0.5454000234603882,2.1653194427490234,10000.0,18397.201429128647,19053.081587553024,18397.201429128647,652.7812848091125,1.1971261501312256,0.0 -54100,1.7435359,3.116668,,,,,,,,,,,,,, -54200,1.7013547,3.1364648,,,,,,,,,,,,,, -54300,1.6899413,3.1418986,,,,,,,,,,,,,, -54400,1.9753556,3.1144288,,,,,,,,,,,,,, -54500,1.9609958,3.0936322,,,,,,,,,,,,,, -54600,1.8253229,3.1138997,,,,,,,,,,,,,, -54700,1.8837843,3.0955303,,,,,,,,,,,,,, -54800,1.8560052,3.0625892,,,,,,,,,,,,,, -54900,1.9274067,3.1197095,,,,,,,,,,,,,, -55000,1.6460568,3.0687764,,,,,,,,,,,,,, -55100,1.804287,3.1135647,,,,,,,,,,,,,, -55200,2.014499,3.116086,,,,,,,,,,,,,, -55300,1.9341047,3.1766715,,,,,,,,,,,,,, -55400,1.779709,3.1017873,,,,,,,,,,,,,, -55500,1.8993453,3.062014,,,,,,,,,,,,,, -55586,,,0.7388392686843872,1.2988405227661133,0.6522200107574463,1.6849249601364136,50000.0,0.5271000266075134,2.352210998535156,10000.0,18907.42284011841,19580.69886445999,18907.42284011841,670.0841941833496,1.236152410507202,0.0 -55600,2.1471336,3.1386867,,,,,,,,,,,,,, -55700,1.7925489,3.0982447,,,,,,,,,,,,,, -55800,1.9649891,3.1346457,,,,,,,,,,,,,, -55900,1.8767128,3.1503797,,,,,,,,,,,,,, -56000,1.7620538,3.1110325,,,,,,,,,,,,,, -56100,1.8028753,3.1630716,,,,,,,,,,,,,, -56200,1.8332038,3.108925,,,,,,,,,,,,,, -56300,1.9530083,3.131515,,,,,,,,,,,,,, -56400,1.9623703,3.08124,,,,,,,,,,,,,, -56500,1.9214283,3.1086104,,,,,,,,,,,,,, -56600,1.8918403,3.1132627,,,,,,,,,,,,,, -56700,1.9939821,3.1594715,,,,,,,,,,,,,, -56800,1.7475258,3.090219,,,,,,,,,,,,,, -56900,1.8989062,3.0826788,,,,,,,,,,,,,, -57000,1.8550473,3.1237369,,,,,,,,,,,,,, -57090,,,0.7461535334587097,1.222550630569458,0.6632999777793884,1.5833584070205688,50000.0,0.5313000082969666,2.2467947006225586,10000.0,19417.531020641327,20108.20799922943,19417.531020641327,687.3933174610138,1.2751078605651855,0.0 -57100,1.9649022,3.1657915,,,,,,,,,,,,,, -57200,2.0794134,3.1454482,,,,,,,,,,,,,, -57300,1.8416722,3.0800571,,,,,,,,,,,,,, -57400,1.8856634,3.1309562,,,,,,,,,,,,,, -57500,1.9540888,3.0940862,,,,,,,,,,,,,, -57600,1.8799798,3.1406176,,,,,,,,,,,,,, -57700,1.8261038,3.0929663,,,,,,,,,,,,,, -57800,1.8307786,3.1017904,,,,,,,,,,,,,, -57900,1.9204654,3.0333927,,,,,,,,,,,,,, -58000,1.8902241,3.160543,,,,,,,,,,,,,, -58100,1.8615681,3.1829455,,,,,,,,,,,,,, -58200,1.9823505,3.158408,,,,,,,,,,,,,, -58300,1.9650855,3.1120975,,,,,,,,,,,,,, -58400,1.9013195,3.1234112,,,,,,,,,,,,,, -58500,1.9085454,3.1225913,,,,,,,,,,,,,, -58594,,,0.750418484210968,1.2128936052322388,0.6708599925041199,1.5625097751617432,50000.0,0.5463000535964966,2.2097489833831787,10000.0,19927.526131629944,20635.700806617737,19927.526131629944,704.799777507782,1.3132286071777344,0.0 -58600,1.9397743,3.013967,,,,,,,,,,,,,, -58700,2.0959709,3.1051528,,,,,,,,,,,,,, -58800,1.8953475,3.0746646,,,,,,,,,,,,,, -58900,2.0480835,3.1175377,,,,,,,,,,,,,, -59000,1.7852023,3.072875,,,,,,,,,,,,,, -59100,1.8928001,3.0967448,,,,,,,,,,,,,, -59200,1.8685179,3.0869505,,,,,,,,,,,,,, -59300,1.9076397,3.0847785,,,,,,,,,,,,,, -59400,1.9407052,3.1488874,,,,,,,,,,,,,, -59500,2.0113695,3.0823643,,,,,,,,,,,,,, -59600,2.0835118,3.168266,,,,,,,,,,,,,, -59700,2.1262958,3.0800314,,,,,,,,,,,,,, -59800,1.7949218,3.0765996,,,,,,,,,,,,,, -59900,1.8839464,3.0844774,,,,,,,,,,,,,, -60000,1.9391882,3.0741897,,,,,,,,,,,,,, -60098,,,0.7451171875,1.1998909711837769,0.6719799637794495,1.529384970664978,50000.0,0.5517000555992126,2.1742067337036133,10000.0,20437.654178380966,21163.29753088951,20437.654178380966,722.1753432750702,1.3547327518463137,0.0 -60100,2.0220819,3.0725598,,,,,,,,,,,,,, -60200,1.9027531,3.1162093,,,,,,,,,,,,,, -60300,2.0468554,3.1645641,,,,,,,,,,,,,, -60400,1.9054642,3.1548917,,,,,,,,,,,,,, -60500,1.8120677,3.0719452,,,,,,,,,,,,,, -60600,1.94832,3.1892488,,,,,,,,,,,,,, -60700,1.9829792,3.0583282,,,,,,,,,,,,,, -60800,1.9185995,3.0985487,,,,,,,,,,,,,, -60900,2.0303762,3.2624974,,,,,,,,,,,,,, -61000,1.9625535,3.2118654,,,,,,,,,,,,,, -61100,1.9375025,3.1825163,,,,,,,,,,,,,, -61200,2.0063734,3.1250477,,,,,,,,,,,,,, -61300,1.9132968,3.129272,,,,,,,,,,,,,, -61400,1.9837404,3.0856063,,,,,,,,,,,,,, -61500,1.9384618,3.0793717,,,,,,,,,,,,,, -61600,1.9196393,3.0820913,,,,,,,,,,,,,, -61601,,,0.7535673975944519,1.186854362487793,0.6801599860191345,1.5134427547454834,50000.0,0.5524000525474548,2.158252954483032,10000.0,20947.57455515861,21690.79941987992,20947.57455515861,739.661140203476,1.3968374729156494,0.0 -61700,1.83734,3.0442991,,,,,,,,,,,,,, -61800,1.9487894,3.124782,,,,,,,,,,,,,, -61900,1.9062911,3.1169822,,,,,,,,,,,,,, -62000,1.9651822,3.0235326,,,,,,,,,,,,,, -62100,1.9135768,3.0967507,,,,,,,,,,,,,, -62200,2.0420067,3.0573585,,,,,,,,,,,,,, -62300,1.8405244,3.1203048,,,,,,,,,,,,,, -62400,2.0654635,3.1093304,,,,,,,,,,,,,, -62500,1.985361,3.085602,,,,,,,,,,,,,, -62600,2.018806,3.1037848,,,,,,,,,,,,,, -62700,1.8731506,3.082269,,,,,,,,,,,,,, -62800,1.9943651,3.131451,,,,,,,,,,,,,, -62900,2.124494,3.158905,,,,,,,,,,,,,, -63000,1.8528535,3.0380788,,,,,,,,,,,,,, -63100,2.160086,3.2215366,,,,,,,,,,,,,, -63105,,,0.7455755472183228,1.2469745874404907,0.6726199984550476,1.565129637718201,50000.0,0.541100025177002,2.2219254970550537,10000.0,21457.5686917305,22218.297739744183,21457.5686917305,757.0620410442352,1.448808193206787,0.0 -63200,1.9435102,3.150928,,,,,,,,,,,,,, -63300,1.792353,3.0552566,,,,,,,,,,,,,, -63400,1.9950804,3.0792758,,,,,,,,,,,,,, -63500,2.0267217,3.1162403,,,,,,,,,,,,,, -63600,2.0847023,3.1470165,,,,,,,,,,,,,, -63700,2.0046318,3.1392212,,,,,,,,,,,,,, -63800,2.0908074,3.0657463,,,,,,,,,,,,,, -63900,1.9818636,3.1115742,,,,,,,,,,,,,, -64000,1.8907102,3.0193307,,,,,,,,,,,,,, -64100,2.041941,3.0880873,,,,,,,,,,,,,, -64200,1.8982966,3.0039806,,,,,,,,,,,,,, -64300,1.8322345,3.0540004,,,,,,,,,,,,,, -64400,1.8157994,3.0526764,,,,,,,,,,,,,, -64500,2.0006435,3.133564,,,,,,,,,,,,,, -64600,1.9820474,3.0693746,,,,,,,,,,,,,, -64609,,,0.7794363498687744,1.0946069955825806,0.6756399869918823,1.5386929512023926,50000.0,0.5548000335693359,2.164292335510254,10000.0,21967.481950998303,22745.924226760864,21967.481950998303,774.6855986118317,1.487745761871338,0.0 -64700,1.9853569,3.0917034,,,,,,,,,,,,,, -64800,2.1664457,3.0616856,,,,,,,,,,,,,, -64900,2.0676508,3.0864103,,,,,,,,,,,,,, -65000,2.0914876,3.0803208,,,,,,,,,,,,,, -65100,2.2016575,3.1494684,,,,,,,,,,,,,, -65200,2.0547564,3.0729656,,,,,,,,,,,,,, -65300,2.1815932,3.1856046,,,,,,,,,,,,,, -65400,2.0700889,3.1460536,,,,,,,,,,,,,, -65500,2.065904,3.1281073,,,,,,,,,,,,,, -65600,1.9188232,3.0735254,,,,,,,,,,,,,, -65700,1.9435838,3.089527,,,,,,,,,,,,,, -65800,2.0452108,3.0542672,,,,,,,,,,,,,, -65900,1.9961122,3.068892,,,,,,,,,,,,,, -66000,1.9596443,3.0880315,,,,,,,,,,,,,, -66100,1.9177425,3.016517,,,,,,,,,,,,,, -66114,,,0.7683154940605164,1.1131125688552856,0.6814999580383301,1.4990878105163574,50000.0,0.5603000521659851,2.147666931152344,10000.0,22477.70741415024,23274.14600086212,22477.70741415024,792.5888900756836,1.5288665294647217,0.0 -66200,1.9951521,3.1068583,,,,,,,,,,,,,, -66300,2.1583052,3.0996664,,,,,,,,,,,,,, -66400,1.9917324,3.0850465,,,,,,,,,,,,,, -66500,1.9407436,3.0920792,,,,,,,,,,,,,, -66600,2.015167,3.0600567,,,,,,,,,,,,,, -66700,1.9602686,3.147223,,,,,,,,,,,,,, -66800,2.2062316,3.0809069,,,,,,,,,,,,,, -66900,2.129836,3.1163943,,,,,,,,,,,,,, -67000,2.1467807,3.0620632,,,,,,,,,,,,,, -67100,2.0316463,3.1278648,,,,,,,,,,,,,, -67200,2.0320208,3.0575836,,,,,,,,,,,,,, -67300,2.1948717,3.0070581,,,,,,,,,,,,,, -67400,2.0493312,3.0374768,,,,,,,,,,,,,, -67500,2.040073,3.0400252,,,,,,,,,,,,,, -67600,2.0169675,3.1040251,,,,,,,,,,,,,, -67618,,,0.7588488459587097,1.1663302183151243,0.6769199967384338,1.5283010005950928,50000.0,0.5412000417709351,2.1961100101470947,10000.0,22987.609982013702,23801.867978334427,22987.609982013702,810.3098471164703,1.5757479667663574,0.0 -67700,2.1459851,3.0866191,,,,,,,,,,,,,, -67800,2.06643,3.0725527,,,,,,,,,,,,,, -67900,2.0803516,3.0573764,,,,,,,,,,,,,, -68000,2.1253207,3.1200185,,,,,,,,,,,,,, -68100,2.0726385,3.0805845,,,,,,,,,,,,,, -68200,1.9811503,3.0560422,,,,,,,,,,,,,, -68300,2.0575309,3.0340343,,,,,,,,,,,,,, -68400,2.221362,3.0734782,,,,,,,,,,,,,, -68500,2.292863,3.0760741,,,,,,,,,,,,,, -68600,2.290384,3.0079033,,,,,,,,,,,,,, -68700,2.1153793,3.0629044,,,,,,,,,,,,,, -68800,2.0611749,3.0313258,,,,,,,,,,,,,, -68900,1.9760334,3.1460786,,,,,,,,,,,,,, -69000,1.9001966,3.023491,,,,,,,,,,,,,, -69100,2.2492235,3.1586435,,,,,,,,,,,,,, -69122,,,0.7570351958274841,1.1902135610580444,0.6778199672698975,1.5349388122558594,50000.0,0.5520000457763672,2.1934776306152344,10000.0,23497.5886631012,24329.425166130062,23497.5886631012,827.7932825088501,1.6187632083892822,0.0 -69200,2.073429,3.0345922,,,,,,,,,,,,,, -69300,2.146159,3.0155504,,,,,,,,,,,,,, -69400,1.9788114,3.0111113,,,,,,,,,,,,,, -69500,2.1866589,3.0904176,,,,,,,,,,,,,, -69600,2.2005537,2.9804547,,,,,,,,,,,,,, -69700,2.0289204,3.053069,,,,,,,,,,,,,, -69800,2.4030979,3.0215278,,,,,,,,,,,,,, -69900,2.087346,3.0672748,,,,,,,,,,,,,, -70000,2.0411105,3.038857,,,,,,,,,,,,,, -70100,2.055395,3.066887,,,,,,,,,,,,,, -70200,1.9127394,3.0057242,,,,,,,,,,,,,, -70300,2.1699555,3.030295,,,,,,,,,,,,,, -70400,2.1681864,3.0508957,,,,,,,,,,,,,, -70500,2.0549455,3.025532,,,,,,,,,,,,,, -70600,2.0897691,3.0953362,,,,,,,,,,,,,, -70627,,,0.7458944320678711,1.2468349933624268,0.668999969959259,1.5783573389053345,50000.0,0.5467000007629395,2.213312864303589,10000.0,24007.770708084103,24857.20196413994,24007.770708084103,845.2944579124451,1.6605701446533203,0.0 -70700,1.944358,3.0275435,,,,,,,,,,,,,, -70800,2.0305269,3.0251029,,,,,,,,,,,,,, -70900,2.3061187,3.0854864,,,,,,,,,,,,,, -71000,2.0350597,3.071628,,,,,,,,,,,,,, -71100,2.2436955,3.1855102,,,,,,,,,,,,,, -71200,1.9981611,3.13461,,,,,,,,,,,,,, -71300,2.2338119,3.04303,,,,,,,,,,,,,, -71400,1.9356581,3.0328555,,,,,,,,,,,,,, -71500,2.1195903,3.0690422,,,,,,,,,,,,,, -71600,2.147386,3.0939116,,,,,,,,,,,,,, -71700,2.0178578,3.0455732,,,,,,,,,,,,,, -71800,2.1594944,3.117781,,,,,,,,,,,,,, -71900,2.1580896,3.0514822,,,,,,,,,,,,,, -72000,2.235657,3.1405444,,,,,,,,,,,,,, -72100,2.0511682,3.0632436,,,,,,,,,,,,,, -72131,,,0.7466716766357422,1.2096285820007324,0.6700999736785889,1.548133373260498,50000.0,0.5525000095367432,2.200495719909668,10000.0,24517.84316849709,25384.97437238693,24517.84316849709,862.8972687721252,1.7031033039093018,0.0 -72200,2.1271641,3.0024571,,,,,,,,,,,,,, -72300,2.0374753,3.085116,,,,,,,,,,,,,, -72400,2.156031,3.0812342,,,,,,,,,,,,,, -72500,2.2094355,3.1416535,,,,,,,,,,,,,, -72600,2.0231001,3.1288352,,,,,,,,,,,,,, -72700,2.281421,3.0288436,,,,,,,,,,,,,, -72800,2.137954,3.0721097,,,,,,,,,,,,,, -72900,2.6444066,3.0431979,,,,,,,,,,,,,, -73000,2.0727763,3.0634968,,,,,,,,,,,,,, -73100,2.001988,3.0798662,,,,,,,,,,,,,, -73200,2.0707645,3.053904,,,,,,,,,,,,,, -73300,2.1116939,2.9963396,,,,,,,,,,,,,, -73400,2.173549,3.09875,,,,,,,,,,,,,, -73500,2.1099842,3.0807247,,,,,,,,,,,,,, -73600,2.1756742,3.0075297,,,,,,,,,,,,,, -73636,,,0.7932278513908386,1.0378259420394895,0.6830999851226807,1.5034915208816528,50000.0,0.556600034236908,2.1635515689849854,10000.0,25028.014729499817,25912.85485696793,25028.014729499817,880.508659362793,1.747278928756714,0.0 -73700,2.0921526,3.054298,,,,,,,,,,,,,, -73800,2.1205406,3.1128368,,,,,,,,,,,,,, -73900,2.0719097,3.0662198,,,,,,,,,,,,,, -74000,1.9951067,2.986141,,,,,,,,,,,,,, -74100,2.046683,3.0528662,,,,,,,,,,,,,, -74200,2.1925786,3.0161185,,,,,,,,,,,,,, -74300,2.0902932,3.082315,,,,,,,,,,,,,, -74400,2.126319,3.1162014,,,,,,,,,,,,,, -74500,2.1680458,3.0593843,,,,,,,,,,,,,, -74600,2.0606632,3.0640335,,,,,,,,,,,,,, -74700,2.1881614,3.0774245,,,,,,,,,,,,,, -74800,1.9843775,2.9952679,,,,,,,,,,,,,, -74900,2.1158695,3.0706964,,,,,,,,,,,,,, -75000,2.3440614,3.1228063,,,,,,,,,,,,,, -75100,2.1984522,3.0597842,,,,,,,,,,,,,, -75141,,,0.7795758843421936,1.0734102725982666,0.6831799745559692,1.4843811988830566,50000.0,0.5615000128746033,2.1242897510528564,10000.0,25538.20123958588,26440.64587116241,25538.20123958588,898.0165367126465,1.791778802871704,0.0 -75200,2.11701,3.02176,,,,,,,,,,,,,, -75300,2.1439579,3.0019913,,,,,,,,,,,,,, -75400,2.088455,2.971347,,,,,,,,,,,,,, -75500,2.1957066,3.068781,,,,,,,,,,,,,, -75600,2.2132924,3.0432136,,,,,,,,,,,,,, -75700,2.1831677,3.0492094,,,,,,,,,,,,,, -75800,2.1646998,3.052165,,,,,,,,,,,,,, -75900,2.1329932,3.0941114,,,,,,,,,,,,,, -76000,2.3377964,3.014951,,,,,,,,,,,,,, -76100,2.2946908,3.0145442,,,,,,,,,,,,,, -76200,2.1600778,2.977688,,,,,,,,,,,,,, -76300,2.1421149,2.9689877,,,,,,,,,,,,,, -76400,2.1541765,3.084531,,,,,,,,,,,,,, -76500,2.290338,3.0241623,,,,,,,,,,,,,, -76600,2.3161678,3.0819051,,,,,,,,,,,,,, -76646,,,0.7704480290412903,1.1027264595031738,0.6850999593734741,1.478226661682129,50000.0,0.5569000244140625,2.119495391845703,10000.0,26048.3084192276,26968.49898505211,26048.3084192276,915.6695799827576,1.832149028778076,0.0 -76700,2.0641036,3.1406987,,,,,,,,,,,,,, -76800,2.2258987,3.0344615,,,,,,,,,,,,,, -76900,2.2176008,3.0313249,,,,,,,,,,,,,, -77000,2.1838553,3.0034606,,,,,,,,,,,,,, -77100,2.268557,3.045226,,,,,,,,,,,,,, -77200,2.2050164,3.0082722,,,,,,,,,,,,,, -77300,2.0481007,3.064725,,,,,,,,,,,,,, -77400,2.2010753,3.0754032,,,,,,,,,,,,,, -77500,2.2238104,3.060997,,,,,,,,,,,,,, -77600,2.0651352,3.0350971,,,,,,,,,,,,,, -77700,2.248104,3.0281706,,,,,,,,,,,,,, -77800,2.3637564,3.0505772,,,,,,,,,,,,,, -77900,2.0756657,2.9799495,,,,,,,,,,,,,, -78000,2.096482,2.9991202,,,,,,,,,,,,,, -78100,2.33225,2.9925168,,,,,,,,,,,,,, -78151,,,0.7723413705825806,1.126224398612976,0.6868199706077576,1.48687481880188,50000.0,0.5664000511169434,2.1145002841949463,10000.0,26558.43163084984,27495.934143304825,26558.43163084984,932.8865323066713,1.8738563060760496,0.0 -78200,2.1090093,3.0482092,,,,,,,,,,,,,, -78300,2.105405,3.0228102,,,,,,,,,,,,,, -78400,2.2516031,3.0363293,,,,,,,,,,,,,, -78500,2.1031828,2.9918308,,,,,,,,,,,,,, -78600,2.4150224,3.0549445,,,,,,,,,,,,,, -78700,2.1630692,3.054535,,,,,,,,,,,,,, -78800,2.19486,2.956991,,,,,,,,,,,,,, -78900,2.11741,2.9976015,,,,,,,,,,,,,, -79000,2.157772,3.0656137,,,,,,,,,,,,,, -79100,2.1751394,3.0394459,,,,,,,,,,,,,, -79200,2.2509358,3.0751112,,,,,,,,,,,,,, -79300,2.3074222,3.000128,,,,,,,,,,,,,, -79400,2.2336047,3.049631,,,,,,,,,,,,,, -79500,2.2103863,3.127474,,,,,,,,,,,,,, -79600,2.22659,3.031988,,,,,,,,,,,,,, -79656,,,0.7729591727256775,1.1393262147903442,0.6887999773025513,1.5029897689819336,50000.0,0.5633000135421753,2.1523826122283936,10000.0,27068.563962221146,28023.855808973312,27068.563962221146,950.5808329582214,1.9172337055206297,0.0 -79700,2.1982658,3.0000527,,,,,,,,,,,,,, -79800,2.1478825,2.980788,,,,,,,,,,,,,, -79900,2.3509057,3.0714014,,,,,,,,,,,,,, -80000,2.2831588,3.004742,,,,,,,,,,,,,, -80100,2.2403896,3.0570788,,,,,,,,,,,,,, -80200,2.1576943,3.0291262,,,,,,,,,,,,,, -80300,2.2759788,3.0850146,,,,,,,,,,,,,, -80400,2.3367536,3.0130172,,,,,,,,,,,,,, -80500,2.268102,3.0113325,,,,,,,,,,,,,, -80600,2.125296,3.0272284,,,,,,,,,,,,,, -80700,2.1676826,3.0289185,,,,,,,,,,,,,, -80800,2.3159122,3.034614,,,,,,,,,,,,,, -80900,2.1470952,2.941274,,,,,,,,,,,,,, -81000,2.3140683,3.0015416,,,,,,,,,,,,,, -81100,2.1428823,3.0099602,,,,,,,,,,,,,, -81161,,,0.7640505433082581,1.1171501874923706,0.682379961013794,1.4879083633422852,50000.0,0.5533000230789185,2.143951177597046,10000.0,27578.76552867889,28552.553351163864,27578.76552867889,968.9772000312804,1.9648942947387693,0.0 -81200,2.1034884,2.9589608,,,,,,,,,,,,,, -81300,2.3561575,3.0200312,,,,,,,,,,,,,, -81400,2.2617242,3.0040412,,,,,,,,,,,,,, -81500,2.4943984,3.1413906,,,,,,,,,,,,,, -81600,2.2348864,3.0119188,,,,,,,,,,,,,, -81700,2.1008518,3.0488677,,,,,,,,,,,,,, -81800,2.2695098,3.0275254,,,,,,,,,,,,,, -81900,2.225956,2.9791868,,,,,,,,,,,,,, -82000,2.2164059,3.0160031,,,,,,,,,,,,,, -82100,2.3042576,2.9834719,,,,,,,,,,,,,, -82200,2.3948972,2.9970746,,,,,,,,,,,,,, -82300,2.2157855,3.0229526,,,,,,,,,,,,,, -82400,2.102677,3.004221,,,,,,,,,,,,,, -82500,2.1330476,3.0245771,,,,,,,,,,,,,, -82600,2.3737397,3.0843098,,,,,,,,,,,,,, -82665,,,0.8116629123687744,0.9491603970527648,0.6913999915122986,1.452277898788452,50000.0,0.5690000057220459,2.0855112075805664,10000.0,28088.932448387142,29080.30035591125,28088.932448387142,986.45987200737,2.0094380378723145,0.0 -82700,2.4438891,3.0641143,,,,,,,,,,,,,, -82800,2.162452,2.9694655,,,,,,,,,,,,,, -82900,2.3127546,3.0524788,,,,,,,,,,,,,, -83000,2.2084897,2.9570198,,,,,,,,,,,,,, -83100,2.3725994,3.0444705,,,,,,,,,,,,,, -83200,2.089625,2.926397,,,,,,,,,,,,,, -83300,2.163891,3.025762,,,,,,,,,,,,,, -83400,2.2336318,3.02771,,,,,,,,,,,,,, -83500,2.2261877,3.0085268,,,,,,,,,,,,,, -83600,2.1962965,3.029448,,,,,,,,,,,,,, -83700,2.4637501,3.0599205,,,,,,,,,,,,,, -83800,2.154149,3.0775988,,,,,,,,,,,,,, -83900,2.4926176,3.0517342,,,,,,,,,,,,,, -84000,2.296805,3.0690105,,,,,,,,,,,,,, -84100,2.6137762,3.032179,,,,,,,,,,,,,, -84170,,,0.786152720451355,1.03113853931427,0.6920199990272522,1.438202738761902,50000.0,0.563800036907196,2.106266975402832,10000.0,28599.081391334534,29608.071023464203,28599.081391334534,1003.9821727275848,2.056580305099488,0.0 -84200,2.3318064,3.0699468,,,,,,,,,,,,,, -84300,2.1304533,2.981826,,,,,,,,,,,,,, -84400,2.3046753,3.0234747,,,,,,,,,,,,,, -84500,2.1643791,2.996109,,,,,,,,,,,,,, -84600,2.3069136,3.0322633,,,,,,,,,,,,,, -84700,2.231931,2.945047,,,,,,,,,,,,,, -84800,2.2751663,2.9327574,,,,,,,,,,,,,, -84900,2.505834,3.1076694,,,,,,,,,,,,,, -85000,2.4851866,3.022822,,,,,,,,,,,,,, -85100,2.3317993,3.067535,,,,,,,,,,,,,, -85200,2.2224038,3.056327,,,,,,,,,,,,,, -85300,2.2487605,2.9738455,,,,,,,,,,,,,, -85400,2.531868,3.0360177,,,,,,,,,,,,,, -85500,2.156621,2.9820952,,,,,,,,,,,,,, -85600,2.143571,3.012894,,,,,,,,,,,,,, -85675,,,0.7844387888908386,1.050876259803772,0.6937800049781799,1.4405685663223269,50000.0,0.5645000338554382,2.099759101867676,10000.0,29109.275886297222,30135.671944856644,29109.275886297222,1021.2940018177032,2.099645137786865,0.0 -85700,2.3199956,3.075603,,,,,,,,,,,,,, -85800,2.3921962,3.0360432,,,,,,,,,,,,,, -85900,2.451453,3.045447,,,,,,,,,,,,,, -86000,2.4174342,3.0686822,,,,,,,,,,,,,, -86100,2.3284245,2.9883227,,,,,,,,,,,,,, -86200,2.3059301,2.963283,,,,,,,,,,,,,, -86300,2.4877129,3.0806172,,,,,,,,,,,,,, -86400,2.2738285,2.9453676,,,,,,,,,,,,,, -86500,2.1942472,2.9339914,,,,,,,,,,,,,, -86600,2.2277215,3.0111969,,,,,,,,,,,,,, -86700,2.3681743,3.0571334,,,,,,,,,,,,,, -86800,2.4466631,3.0505874,,,,,,,,,,,,,, -86900,2.1794355,2.991807,,,,,,,,,,,,,, -87000,2.5054133,2.9658494,,,,,,,,,,,,,, -87100,2.3320389,3.0029662,,,,,,,,,,,,,, -87180,,,0.7828842401504517,1.0602805614471436,0.6989799737930298,1.430122971534729,50000.0,0.5717000365257263,2.0682175159454346,10000.0,29619.42868781089,30663.413135290142,29619.42868781089,1038.789042711258,2.1402103900909424,0.0 -87200,2.3925502,3.0323548,,,,,,,,,,,,,, -87300,2.232522,2.9303944,,,,,,,,,,,,,, -87400,2.2829728,2.9800136,,,,,,,,,,,,,, -87500,2.404741,3.0130126,,,,,,,,,,,,,, -87600,2.3697462,2.969211,,,,,,,,,,,,,, -87700,2.338422,2.9728396,,,,,,,,,,,,,, -87800,2.4753945,3.00343,,,,,,,,,,,,,, -87900,2.3431838,2.9744544,,,,,,,,,,,,,, -88000,2.3651826,2.9433637,,,,,,,,,,,,,, -88100,2.3904054,3.022264,,,,,,,,,,,,,, -88200,2.3058324,2.9552903,,,,,,,,,,,,,, -88300,2.3302422,2.9883823,,,,,,,,,,,,,, -88400,2.5542963,2.9989526,,,,,,,,,,,,,, -88500,2.5518882,2.9569879,,,,,,,,,,,,,, -88600,2.3958695,2.9986658,,,,,,,,,,,,,, -88684,,,0.7784797549247742,1.0503365993499756,0.6977999806404114,1.412257194519043,50000.0,0.565500020980835,2.0685789585113525,10000.0,30129.41860818863,31191.086156368256,30129.41860818863,1056.377511024475,2.181400775909424,0.0 -88700,2.3200574,2.9124057,,,,,,,,,,,,,, -88800,2.2630284,2.961094,,,,,,,,,,,,,, -88900,2.3173459,2.9980025,,,,,,,,,,,,,, -89000,2.4865174,2.9792788,,,,,,,,,,,,,, -89100,2.338781,2.9549973,,,,,,,,,,,,,, -89200,2.3794842,2.8825855,,,,,,,,,,,,,, -89300,2.246312,2.986405,,,,,,,,,,,,,, -89400,2.5249002,2.9657896,,,,,,,,,,,,,, -89500,2.2281938,2.9637861,,,,,,,,,,,,,, -89600,2.253194,3.0081275,,,,,,,,,,,,,, -89700,2.2565706,2.8922696,,,,,,,,,,,,,, -89800,2.568332,3.0533023,,,,,,,,,,,,,, -89900,2.42439,3.0641985,,,,,,,,,,,,,, -90000,2.5344567,2.9960842,,,,,,,,,,,,,, -90100,2.4473894,2.9734602,,,,,,,,,,,,,, -90189,,,0.7796755433082581,1.052192211151123,0.6933599710464478,1.4358712434768677,50000.0,0.567300021648407,2.074946641921997,10000.0,30639.66088938713,31718.93631315232,30639.66088938713,1073.8877835273745,2.227370262145996,0.0 -90200,2.4679189,2.9698176,,,,,,,,,,,,,, -90300,2.3276386,3.0116544,,,,,,,,,,,,,, -90400,2.3383987,3.0449035,,,,,,,,,,,,,, -90500,2.343195,2.9754333,,,,,,,,,,,,,, -90600,2.4786577,3.0431278,,,,,,,,,,,,,, -90700,2.3731282,2.9164188,,,,,,,,,,,,,, -90800,2.4877334,2.9552958,,,,,,,,,,,,,, -90900,2.2386422,2.9740155,,,,,,,,,,,,,, -91000,2.401211,2.9577854,,,,,,,,,,,,,, -91100,2.4156659,2.945679,,,,,,,,,,,,,, -91200,2.3361182,2.9577281,,,,,,,,,,,,,, -91300,2.3964748,2.9708571,,,,,,,,,,,,,, -91400,2.4211347,2.908489,,,,,,,,,,,,,, -91500,2.4534576,2.9660623,,,,,,,,,,,,,, -91600,2.5308495,3.0031466,,,,,,,,,,,,,, -91694,,,0.8287826776504517,0.869025707244873,0.7048199772834778,1.3970078229904177,50000.0,0.5756000280380249,2.055137872695923,10000.0,31149.84545230865,32246.600895643234,31149.84545230865,1091.269455909729,2.272991180419922,0.0 -91700,2.5038743,3.0061278,,,,,,,,,,,,,, -91800,2.5364182,3.0127354,,,,,,,,,,,,,, -91900,2.64808,2.9314182,,,,,,,,,,,,,, -92000,2.310822,2.987047,,,,,,,,,,,,,, -92100,2.3596559,2.964588,,,,,,,,,,,,,, -92200,2.4073637,2.9591784,,,,,,,,,,,,,, -92300,2.344278,2.9492345,,,,,,,,,,,,,, -92400,2.5113199,2.9642382,,,,,,,,,,,,,, -92500,2.420142,3.022369,,,,,,,,,,,,,, -92600,2.4856143,2.9427235,,,,,,,,,,,,,, -92700,2.6552064,2.9693637,,,,,,,,,,,,,, -92800,2.5731726,3.0790796,,,,,,,,,,,,,, -92900,2.4893372,2.9913306,,,,,,,,,,,,,, -93000,2.4734435,2.9706576,,,,,,,,,,,,,, -93100,2.477915,2.9617872,,,,,,,,,,,,,, -93198,,,0.7990872263908386,0.9855494499206544,0.6990999579429626,1.429892659187317,50000.0,0.5735000371932983,2.089207172393799,10000.0,31659.772423505783,32774.04626703262,31659.772423505783,1108.6878995895386,2.321443557739258,0.0 -93200,2.4862158,3.0098674,,,,,,,,,,,,,, -93300,2.5845258,2.9861832,,,,,,,,,,,,,, -93400,2.5079758,2.9430332,,,,,,,,,,,,,, -93500,2.4863267,2.9738345,,,,,,,,,,,,,, -93600,2.374909,2.948356,,,,,,,,,,,,,, -93700,2.605468,3.112435,,,,,,,,,,,,,, -93800,2.517194,2.9851022,,,,,,,,,,,,,, -93900,2.3984687,2.9518898,,,,,,,,,,,,,, -94000,2.5255823,3.001842,,,,,,,,,,,,,, -94100,2.4435809,2.9350934,,,,,,,,,,,,,, -94200,2.4169633,2.987322,,,,,,,,,,,,,, -94300,2.4722128,3.1053073,,,,,,,,,,,,,, -94400,2.383043,2.9270709,,,,,,,,,,,,,, -94500,2.2828524,3.0410688,,,,,,,,,,,,,, -94600,2.3366086,2.9292746,,,,,,,,,,,,,, -94700,2.7816515,3.0142407,,,,,,,,,,,,,, -94701,,,0.7864317297935486,1.0102328062057495,0.6943399906158447,1.414360761642456,50000.0,0.5597000122070312,2.086948871612549,10000.0,32169.7107861042,33301.60674357414,32169.7107861042,1126.209413766861,2.367856502532959,0.0 -94800,2.482651,3.0132694,,,,,,,,,,,,,, -94900,2.4417164,3.037182,,,,,,,,,,,,,, -95000,2.629804,2.9839156,,,,,,,,,,,,,, -95100,2.4975824,2.89446,,,,,,,,,,,,,, -95200,2.5317602,2.9290495,,,,,,,,,,,,,, -95300,2.4044862,2.940696,,,,,,,,,,,,,, -95400,2.4087331,2.9479556,,,,,,,,,,,,,, -95500,2.4161794,2.9601998,,,,,,,,,,,,,, -95600,2.4481604,2.9026406,,,,,,,,,,,,,, -95700,2.5620165,2.9668698,,,,,,,,,,,,,, -95800,2.5465856,3.0440063,,,,,,,,,,,,,, -95900,2.570163,2.933948,,,,,,,,,,,,,, -96000,2.4627407,2.9495287,,,,,,,,,,,,,, -96100,2.6123357,2.9388182,,,,,,,,,,,,,, -96200,2.8691278,2.9444296,,,,,,,,,,,,,, -96205,,,0.8004623651504517,1.0195797681808472,0.7038599848747253,1.428849697113037,50000.0,0.5768000483512878,2.0695486068725586,10000.0,32679.70833396912,33829.24775338173,32679.70833396912,1143.754875421524,2.413187265396118,0.0 -96300,2.4528377,2.9632757,,,,,,,,,,,,,, -96400,2.5026798,2.9933884,,,,,,,,,,,,,, -96500,2.407389,2.980528,,,,,,,,,,,,,, -96600,2.6607294,2.9895244,,,,,,,,,,,,,, -96700,2.4181125,2.9246905,,,,,,,,,,,,,, -96800,2.415588,2.9781628,,,,,,,,,,,,,, -96900,2.3395588,2.9252157,,,,,,,,,,,,,, -97000,2.5391111,2.9026165,,,,,,,,,,,,,, -97100,2.3675537,2.9855728,,,,,,,,,,,,,, -97200,2.5257835,2.9380307,,,,,,,,,,,,,, -97300,2.4703372,2.9718506,,,,,,,,,,,,,, -97400,2.6631386,2.954749,,,,,,,,,,,,,, -97500,2.5359206,2.9251904,,,,,,,,,,,,,, -97600,2.7084727,2.9980347,,,,,,,,,,,,,, -97700,2.5286446,2.905726,,,,,,,,,,,,,, -97709,,,0.7979910373687744,1.0076885223388672,0.7076399922370911,1.4024112224578855,50000.0,0.5755000114440918,2.061427354812622,10000.0,33189.67002224922,34356.86605596542,33189.67002224922,1161.3141412734983,2.4573397636413574,0.0 -97800,2.6512947,3.0081,,,,,,,,,,,,,, -97900,2.6104414,2.923501,,,,,,,,,,,,,, -98000,2.5080297,2.9310598,,,,,,,,,,,,,, -98100,2.6617978,3.0279005,,,,,,,,,,,,,, -98200,2.47814,2.8890707,,,,,,,,,,,,,, -98300,2.507719,2.9263394,,,,,,,,,,,,,, -98400,2.5176916,2.96375,,,,,,,,,,,,,, -98500,2.541802,2.9215107,,,,,,,,,,,,,, -98600,2.5125287,2.9818263,,,,,,,,,,,,,, -98700,2.493746,2.95081,,,,,,,,,,,,,, -98800,2.6432037,2.9850054,,,,,,,,,,,,,, -98900,2.7080941,2.9218023,,,,,,,,,,,,,, -99000,2.5553408,2.9982624,,,,,,,,,,,,,, -99100,2.6809018,2.942384,,,,,,,,,,,,,, -99200,2.477723,2.9510596,,,,,,,,,,,,,, -99213,,,0.7958186864852905,1.0074490308761597,0.7049799561500549,1.3988691568374634,50000.0,0.5746999979019165,2.0553035736083984,10000.0,33699.635172605515,34884.06643486023,33699.635172605515,1178.4537107944489,2.500805854797364,0.0 -99300,2.6289113,2.9352214,,,,,,,,,,,,,, -99400,2.4003997,2.955459,,,,,,,,,,,,,, -99500,2.4200091,2.949559,,,,,,,,,,,,,, -99600,2.6513462,2.9633813,,,,,,,,,,,,,, -99700,2.5751297,2.9438906,,,,,,,,,,,,,, -99800,2.351347,2.8660831,,,,,,,,,,,,,, -99900,2.4476306,2.952801,,,,,,,,,,,,,, -100000,2.4349988,2.8359694,,,,,,,,,,,,,, -100100,2.718693,2.935854,,,,,,,,,,,,,, -100200,2.660993,2.9618545,,,,,,,,,,,,,, -100300,2.4464505,2.9362273,,,,,,,,,,,,,, -100400,2.766764,2.9529736,,,,,,,,,,,,,, -100500,2.6632373,2.8862147,,,,,,,,,,,,,, -100600,2.3864124,2.8674552,,,,,,,,,,,,,, -100700,2.487502,2.9030871,,,,,,,,,,,,,, -100717,,,0.8254942297935486,0.8944934010505676,0.7102400064468384,1.3767321109771729,50000.0,0.579800009727478,2.0282983779907227,10000.0,34209.61276984215,35411.85849046707,34209.61276984215,1196.1693103313446,2.546966552734375,0.0 -100800,2.5752134,2.8946714,,,,,,,,,,,,,, -100900,2.5181642,2.9595993,,,,,,,,,,,,,, -101000,2.5567698,2.9168687,,,,,,,,,,,,,, -101100,2.7103755,2.95271,,,,,,,,,,,,,, -101200,2.632173,2.8789673,,,,,,,,,,,,,, -101300,2.6864192,2.9354134,,,,,,,,,,,,,, -101400,2.7839937,2.8921702,,,,,,,,,,,,,, -101500,2.5839713,2.952801,,,,,,,,,,,,,, -101600,2.489074,2.9162421,,,,,,,,,,,,,, -101700,2.7129261,3.0020418,,,,,,,,,,,,,, -101800,2.5241907,2.8492072,,,,,,,,,,,,,, -101900,2.7833574,2.9434474,,,,,,,,,,,,,, -102000,2.7331314,2.9258747,,,,,,,,,,,,,, -102100,2.5744188,2.94238,,,,,,,,,,,,,, -102200,2.7000585,2.9411786,,,,,,,,,,,,,, -102222,,,0.8199936151504517,0.9147710800170898,0.7091999650001526,1.3878830671310425,50000.0,0.5827000141143799,2.037583112716675,10000.0,34719.75500845909,35939.416763305664,34719.75500845909,1213.4867506027222,2.591370344161988,0.0 -102300,2.4959247,2.8971426,,,,,,,,,,,,,, -102400,2.5016925,2.8770852,,,,,,,,,,,,,, -102500,2.4630516,2.8743968,,,,,,,,,,,,,, -102600,2.5787683,2.8911943,,,,,,,,,,,,,, -102700,2.6358347,2.901556,,,,,,,,,,,,,, -102800,2.7420425,2.91427,,,,,,,,,,,,,, -102900,2.5269654,2.8794363,,,,,,,,,,,,,, -103000,2.7447162,2.9185781,,,,,,,,,,,,,, -103100,2.4499283,2.868849,,,,,,,,,,,,,, -103200,2.5762467,2.8918707,,,,,,,,,,,,,, -103300,2.9547255,2.9442139,,,,,,,,,,,,,, -103400,2.655016,2.945559,,,,,,,,,,,,,, -103500,2.5959213,2.9168372,,,,,,,,,,,,,, -103600,2.6006896,2.9483798,,,,,,,,,,,,,, -103700,2.5893269,2.8538804,,,,,,,,,,,,,, -103726,,,0.8092314600944519,0.9412323236465454,0.7082799673080444,1.3787775039672852,50000.0,0.5809000134468079,2.029460430145264,10000.0,35229.8086771965,36467.01249408722,35229.8086771965,1230.9158039093018,2.652010679244995,0.0 -103800,2.444424,2.8662415,,,,,,,,,,,,,, -103900,2.5947666,2.8802204,,,,,,,,,,,,,, -104000,2.669465,2.8666344,,,,,,,,,,,,,, -104100,2.5961814,2.9007013,,,,,,,,,,,,,, -104200,2.8099506,2.9343019,,,,,,,,,,,,,, -104300,2.5230856,2.928194,,,,,,,,,,,,,, -104400,2.7652931,2.9361446,,,,,,,,,,,,,, -104500,2.7788634,2.9133542,,,,,,,,,,,,,, -104600,2.7063744,2.9490364,,,,,,,,,,,,,, -104700,2.5582142,2.9040785,,,,,,,,,,,,,, -104800,2.594486,2.913768,,,,,,,,,,,,,, -104900,2.7303383,2.9512024,,,,,,,,,,,,,, -105000,2.573522,2.8366497,,,,,,,,,,,,,, -105100,2.544219,2.854413,,,,,,,,,,,,,, -105200,2.7914357,2.9792511,,,,,,,,,,,,,, -105231,,,0.810546875,0.9459798336029052,0.7136399745941162,1.372373104095459,50000.0,0.5879000425338745,2.013390302658081,10000.0,35739.95236110687,36994.68914103508,35739.95236110687,1248.351620197296,2.697594404220581,0.0 -105300,2.4769666,2.8861666,,,,,,,,,,,,,, -105400,2.6795034,2.9198298,,,,,,,,,,,,,, -105500,2.6598992,2.9251838,,,,,,,,,,,,,, -105600,2.8165934,2.91612,,,,,,,,,,,,,, -105700,2.590102,2.9215884,,,,,,,,,,,,,, -105800,2.5987508,2.9030166,,,,,,,,,,,,,, -105900,2.4557018,2.8549063,,,,,,,,,,,,,, -106000,2.6593304,2.9358678,,,,,,,,,,,,,, -106100,2.543328,2.8874798,,,,,,,,,,,,,, -106200,2.7993436,2.853033,,,,,,,,,,,,,, -106300,2.6592844,2.8582737,,,,,,,,,,,,,, -106400,2.8293407,2.9592357,,,,,,,,,,,,,, -106500,2.6065862,2.9480584,,,,,,,,,,,,,, -106600,2.819379,2.9203084,,,,,,,,,,,,,, -106700,2.7407503,2.9820206,,,,,,,,,,,,,, -106736,,,0.8152502775192261,0.9214245080947876,0.7125799655914307,1.3635106086730957,50000.0,0.5924000144004822,1.9973350763320925,10000.0,36250.164659023285,37522.34675478935,36250.164659023285,1265.6972844600675,2.745483875274658,0.0 -106800,2.6644495,2.8668494,,,,,,,,,,,,,, -106900,2.6686523,2.9022272,,,,,,,,,,,,,, -107000,2.695741,2.889771,,,,,,,,,,,,,, -107100,2.797487,2.895891,,,,,,,,,,,,,, -107200,2.6258247,2.8802142,,,,,,,,,,,,,, -107300,2.603643,2.829064,,,,,,,,,,,,,, -107400,2.717313,2.911846,,,,,,,,,,,,,, -107500,2.7072237,2.8256254,,,,,,,,,,,,,, -107600,2.6650674,2.8648524,,,,,,,,,,,,,, -107700,2.8207536,2.9575565,,,,,,,,,,,,,, -107800,2.703359,2.9150386,,,,,,,,,,,,,, -107900,2.8115592,2.927095,,,,,,,,,,,,,, -108000,2.7678175,2.9351645,,,,,,,,,,,,,, -108100,2.6782079,2.8821673,,,,,,,,,,,,,, -108200,2.7380393,2.900877,,,,,,,,,,,,,, -108240,,,0.8124202489852905,0.9675102829933168,0.7174199819564819,1.3825063705444336,50000.0,0.5960000157356262,2.006366014480591,10000.0,36760.11304235458,38049.81420207024,36760.11304235458,1283.1132173538208,2.795067071914673,0.0 -108300,2.8192537,2.8757586,,,,,,,,,,,,,, -108400,2.706983,2.9245489,,,,,,,,,,,,,, -108500,2.5778682,2.8839638,,,,,,,,,,,,,, -108600,2.697094,2.9174232,,,,,,,,,,,,,, -108700,2.6630623,2.8959196,,,,,,,,,,,,,, -108800,2.7300272,2.8914244,,,,,,,,,,,,,, -108900,2.6634939,2.8625581,,,,,,,,,,,,,, -109000,2.793648,2.909802,,,,,,,,,,,,,, -109100,2.8467784,2.9470325,,,,,,,,,,,,,, -109200,2.7966378,2.874468,,,,,,,,,,,,,, -109300,2.7484455,2.9219964,,,,,,,,,,,,,, -109400,2.7176414,2.8664794,,,,,,,,,,,,,, -109500,2.54407,2.8862414,,,,,,,,,,,,,, -109600,2.818293,2.9513762,,,,,,,,,,,,,, -109700,2.809513,2.9132543,,,,,,,,,,,,,, -109745,,,0.8219866156578064,0.869674026966095,0.7159799933433533,1.3264567852020264,50000.0,0.5906000137329102,1.9520862102508545,10000.0,37270.30475926399,38577.689265728,37270.30475926399,1300.696792602539,2.8408803939819336,0.0 -109800,2.8235595,2.899324,,,,,,,,,,,,,, -109900,2.8656142,2.8888154,,,,,,,,,,,,,, -110000,3.1827803,2.9023035,,,,,,,,,,,,,, -110100,2.8817015,2.888205,,,,,,,,,,,,,, -110200,2.778977,2.8864183,,,,,,,,,,,,,, -110300,2.6990545,2.7893105,,,,,,,,,,,,,, -110400,2.7860959,2.9009426,,,,,,,,,,,,,, -110500,3.0477748,2.8912678,,,,,,,,,,,,,, -110600,2.785498,2.839678,,,,,,,,,,,,,, -110700,2.7326465,2.9559433,,,,,,,,,,,,,, -110800,2.7710302,2.8776228,,,,,,,,,,,,,, -110900,2.8745246,2.936136,,,,,,,,,,,,,, -111000,2.6883316,2.8297136,,,,,,,,,,,,,, -111100,2.565682,2.8112378,,,,,,,,,,,,,, -111200,2.8052318,2.8885503,,,,,,,,,,,,,, -111249,,,0.8328284025192261,0.8718942999839783,0.7133600115776062,1.374558448791504,50000.0,0.589900016784668,2.0042407512664795,10000.0,37780.53218817711,39105.6652610302,37780.53218817711,1318.3463683128357,2.886791706085205,0.0 -111300,2.84169,2.8326578,,,,,,,,,,,,,, -111400,2.8346515,2.8858047,,,,,,,,,,,,,, -111500,2.9543927,2.8708398,,,,,,,,,,,,,, -111600,2.7624733,2.8652477,,,,,,,,,,,,,, -111700,2.835926,2.8350165,,,,,,,,,,,,,, -111800,2.8437643,2.9470403,,,,,,,,,,,,,, -111900,2.8589265,2.8640027,,,,,,,,,,,,,, -112000,2.9405668,2.8850086,,,,,,,,,,,,,, -112100,2.780267,2.846549,,,,,,,,,,,,,, -112200,2.792276,2.9118485,,,,,,,,,,,,,, -112300,2.7408113,2.882784,,,,,,,,,,,,,, -112400,2.8462405,2.850126,,,,,,,,,,,,,, -112500,2.589287,2.8279598,,,,,,,,,,,,,, -112600,2.8899703,2.888134,,,,,,,,,,,,,, -112700,2.8382668,2.878397,,,,,,,,,,,,,, -112754,,,0.829121470451355,0.8686312437057495,0.7148799896240234,1.34848952293396,50000.0,0.5905000567436218,2.003549337387085,10000.0,38290.7551791668,39633.61887669563,38290.7551791668,1335.973201751709,2.9371728897094727,0.0 -112800,2.7620354,2.8390276,,,,,,,,,,,,,, -112900,2.852819,2.8925495,,,,,,,,,,,,,, -113000,3.03281,2.9158645,,,,,,,,,,,,,, -113100,2.8306806,2.8148487,,,,,,,,,,,,,, -113200,3.0403678,2.8597465,,,,,,,,,,,,,, -113300,2.7875073,2.8919878,,,,,,,,,,,,,, -113400,2.704165,2.7915719,,,,,,,,,,,,,, -113500,2.830822,2.8626163,,,,,,,,,,,,,, -113600,2.800252,2.8600183,,,,,,,,,,,,,, -113700,2.8265822,2.8171487,,,,,,,,,,,,,, -113800,2.9343343,2.8599396,,,,,,,,,,,,,, -113900,2.815669,2.8343456,,,,,,,,,,,,,, -114000,2.734402,2.8252773,,,,,,,,,,,,,, -114100,2.9281764,2.8699362,,,,,,,,,,,,,, -114200,2.7944515,2.9245434,,,,,,,,,,,,,, -114259,,,0.8265106678009033,0.8988329768180847,0.7185800075531006,1.36379075050354,50000.0,0.597100019454956,1.997839331626892,10000.0,38800.8717019558,40161.65034651756,38800.8717019558,1353.783153772354,2.9890265464782715,0.0 -114300,2.7201722,2.8997116,,,,,,,,,,,,,, -114400,2.8990078,2.8321536,,,,,,,,,,,,,, -114500,2.7840817,2.8040652,,,,,,,,,,,,,, -114600,2.901351,2.9175205,,,,,,,,,,,,,, -114700,2.7383606,2.8418293,,,,,,,,,,,,,, -114800,2.5870345,2.8095593,,,,,,,,,,,,,, -114900,2.727174,2.7975433,,,,,,,,,,,,,, -115000,2.983519,2.8570068,,,,,,,,,,,,,, -115100,2.8254073,2.8393805,,,,,,,,,,,,,, -115200,2.8758628,2.8605556,,,,,,,,,,,,,, -115300,2.7029154,2.7546165,,,,,,,,,,,,,, -115400,3.0867875,2.9090724,,,,,,,,,,,,,, -115500,2.8969762,2.8403714,,,,,,,,,,,,,, -115600,2.7274446,2.8401086,,,,,,,,,,,,,, -115700,2.9698706,2.924065,,,,,,,,,,,,,, -115764,,,0.8270089030265808,0.9045054316520692,0.722819983959198,1.3632978200912476,50000.0,0.5939000248908997,2.0048916339874268,10000.0,39310.83686709404,40689.02972865105,39310.83686709404,1371.0964777469635,3.0382204055786133,0.0 -115800,2.8822591,2.8151712,,,,,,,,,,,,,, -115900,2.9169195,2.9046087,,,,,,,,,,,,,, -116000,2.8105102,2.810515,,,,,,,,,,,,,, -116100,2.9014862,2.921152,,,,,,,,,,,,,, -116200,2.817593,2.8754768,,,,,,,,,,,,,, -116300,2.9115958,2.8180459,,,,,,,,,,,,,, -116400,3.0520546,2.850071,,,,,,,,,,,,,, -116500,3.0388954,2.7968209,,,,,,,,,,,,,, -116600,2.7903047,2.8419127,,,,,,,,,,,,,, -116700,2.9419007,2.8914747,,,,,,,,,,,,,, -116800,2.773461,2.8197322,,,,,,,,,,,,,, -116900,2.820128,2.860991,,,,,,,,,,,,,, -117000,2.7211604,2.753239,,,,,,,,,,,,,, -117100,2.9917989,2.8325725,,,,,,,,,,,,,, -117200,2.7391615,2.8165185,,,,,,,,,,,,,, -117269,,,0.8316724896430969,0.8688330054283142,0.7247799634933472,1.3182138204574585,50000.0,0.6014000177383423,1.9583295583724976,10000.0,39820.8876388073,41216.72868323326,39820.8876388073,1388.6389904022217,3.0894229412078857,0.0 -117300,3.0398712,2.872742,,,,,,,,,,,,,, -117400,3.036429,2.8467627,,,,,,,,,,,,,, -117500,3.0106525,2.8427775,,,,,,,,,,,,,, -117600,2.8328795,2.8572402,,,,,,,,,,,,,, -117700,3.1707916,2.8233147,,,,,,,,,,,,,, -117800,2.9268687,2.9443712,,,,,,,,,,,,,, -117900,2.9215546,2.8775892,,,,,,,,,,,,,, -118000,3.0409565,2.8844461,,,,,,,,,,,,,, -118100,2.9034338,2.8831577,,,,,,,,,,,,,, -118200,2.8759491,2.8722486,,,,,,,,,,,,,, -118300,3.009857,2.863636,,,,,,,,,,,,,, -118400,2.9096644,2.7906878,,,,,,,,,,,,,, -118500,2.850352,2.8533206,,,,,,,,,,,,,, -118600,3.1624057,2.8185275,,,,,,,,,,,,,, -118700,2.9220123,2.8418076,,,,,,,,,,,,,, -118774,,,0.8371332883834839,0.8562281131744385,0.725659966468811,1.3277863264083862,50000.0,0.6079000234603882,1.9448498487472528,10000.0,40331.05041480064,41744.96152448654,40331.05041480064,1406.6010098457336,3.143878936767578,0.0 -118800,3.0198753,2.7980566,,,,,,,,,,,,,, -118900,2.9247296,2.8323197,,,,,,,,,,,,,, -119000,3.19019,2.8383405,,,,,,,,,,,,,, -119100,2.7870512,2.8537436,,,,,,,,,,,,,, -119200,2.9127965,2.7832177,,,,,,,,,,,,,, -119300,2.8037329,2.8452706,,,,,,,,,,,,,, -119400,2.8610349,2.8393385,,,,,,,,,,,,,, -119500,2.9753468,2.8663638,,,,,,,,,,,,,, -119600,3.038036,2.8491828,,,,,,,,,,,,,, -119700,3.07335,2.8216884,,,,,,,,,,,,,, -119800,2.8942444,2.9062161,,,,,,,,,,,,,, -119900,2.9892838,2.8197925,,,,,,,,,,,,,, -120000,2.8281987,2.8071551,,,,,,,,,,,,,, -120100,3.1979022,2.8064733,,,,,,,,,,,,,, -120200,2.9429166,2.8526747,,,,,,,,,,,,,, -120279,,,0.8495694994926453,0.8090940117835999,0.720579981803894,1.3468101024627686,50000.0,0.5987000465393066,1.9820739030838013,10000.0,40841.25688076019,42272.75825166702,40841.25688076019,1424.0867023468018,3.195374011993408,0.0 -120300,3.1398726,2.8358502,,,,,,,,,,,,,, -120400,2.9486432,2.7878494,,,,,,,,,,,,,, -120500,3.054146,2.8175755,,,,,,,,,,,,,, -120600,2.985218,2.8555279,,,,,,,,,,,,,, -120700,2.8644545,2.788704,,,,,,,,,,,,,, -120800,3.0061996,2.7947974,,,,,,,,,,,,,, -120900,3.0296817,2.7359216,,,,,,,,,,,,,, -121000,3.066421,2.8525834,,,,,,,,,,,,,, -121100,2.8446763,2.7926247,,,,,,,,,,,,,, -121200,2.901588,2.7206125,,,,,,,,,,,,,, -121300,2.8005645,2.8210711,,,,,,,,,,,,,, -121400,3.1880338,2.7811425,,,,,,,,,,,,,, -121500,2.89995,2.808314,,,,,,,,,,,,,, -121600,2.9262471,2.8142495,,,,,,,,,,,,,, -121700,2.9635892,2.7846599,,,,,,,,,,,,,, -121783,,,0.8489118218421936,0.7738045454025269,0.7283599972724915,1.285322904586792,50000.0,0.6050000190734863,1.9310460090637207,10000.0,41351.40269494057,42800.45030093193,41351.40269494057,1441.527908563614,3.246919870376587,0.0 -121800,3.0907836,2.8306122,,,,,,,,,,,,,, -121900,2.697613,2.7371125,,,,,,,,,,,,,, -122000,2.9563067,2.8146577,,,,,,,,,,,,,, -122100,3.0598576,2.8410487,,,,,,,,,,,,,, -122200,3.035853,2.8097239,,,,,,,,,,,,,, -122300,3.110721,2.8340712,,,,,,,,,,,,,, -122400,2.949498,2.7763455,,,,,,,,,,,,,, -122500,3.053337,2.8325958,,,,,,,,,,,,,, -122600,3.1261013,2.7964897,,,,,,,,,,,,,, -122700,3.0516098,2.8127131,,,,,,,,,,,,,, -122800,2.9968488,2.7985535,,,,,,,,,,,,,, -122900,2.919682,2.8259997,,,,,,,,,,,,,, -123000,2.979029,2.8020127,,,,,,,,,,,,,, -123100,2.9204676,2.7318735,,,,,,,,,,,,,, -123200,3.1585886,2.819181,,,,,,,,,,,,,, -123288,,,0.8492506146430969,0.7951707243919373,0.7295199632644653,1.2949326038360596,50000.0,0.6044000387191772,1.924842476844788,10000.0,41861.53848028183,43328.20562505722,41861.53848028183,1459.0470685958862,3.294360637664795,0.0 -123300,3.126494,2.7953,,,,,,,,,,,,,, -123400,3.2231724,2.8715854,,,,,,,,,,,,,, -123500,3.1600823,2.8003669,,,,,,,,,,,,,, -123600,3.1657207,2.7514062,,,,,,,,,,,,,, -123700,2.8931327,2.8167934,,,,,,,,,,,,,, -123800,3.0182517,2.7743068,,,,,,,,,,,,,, -123900,3.2741883,2.8480134,,,,,,,,,,,,,, -124000,2.9531803,2.7872343,,,,,,,,,,,,,, -124100,3.0829952,2.8152776,,,,,,,,,,,,,, -124200,3.0898075,2.7972245,,,,,,,,,,,,,, -124300,2.9126706,2.6804457,,,,,,,,,,,,,, -124400,3.1043339,2.840795,,,,,,,,,,,,,, -124500,3.108842,2.8475318,,,,,,,,,,,,,, -124600,3.1526127,2.8881903,,,,,,,,,,,,,, -124700,3.067476,2.7764807,,,,,,,,,,,,,, -124793,,,0.8462810516357422,0.8076683878898621,0.7290399670600891,1.3021578788757324,50000.0,0.6044000387191772,1.9222787618637085,10000.0,42371.742612838745,43855.968982219696,42371.742612838745,1476.5027883052826,3.3449392318725586,0.0 -124800,3.3465521,2.7728913,,,,,,,,,,,,,, -124900,3.166178,2.8394575,,,,,,,,,,,,,, -125000,3.0841522,2.8134663,,,,,,,,,,,,,, -125100,3.2554271,2.848654,,,,,,,,,,,,,, -125200,3.2209094,2.7864776,,,,,,,,,,,,,, -125300,3.465305,2.804283,,,,,,,,,,,,,, -125400,3.0330741,2.7598584,,,,,,,,,,,,,, -125500,3.4193084,2.8224196,,,,,,,,,,,,,, -125600,3.4121096,2.820875,,,,,,,,,,,,,, -125700,3.1798713,2.7818558,,,,,,,,,,,,,, -125800,3.2880647,2.853729,,,,,,,,,,,,,, -125900,2.9301867,2.7014215,,,,,,,,,,,,,, -126000,3.3770792,2.7877822,,,,,,,,,,,,,, -126100,2.8294237,2.6903927,,,,,,,,,,,,,, -126200,3.0839288,2.7803888,,,,,,,,,,,,,, -126297,,,0.8422752022743225,0.7997431755065918,0.7246800065040588,1.2923554182052612,50000.0,0.596500039100647,1.9440224170684808,10000.0,42881.6758556366,44383.64790058136,42881.6758556366,1494.1439380645752,3.396597146987915,0.0 -126300,3.1186533,2.8248408,,,,,,,,,,,,,, -126400,3.0422697,2.7328537,,,,,,,,,,,,,, -126500,3.2484193,2.8131728,,,,,,,,,,,,,, -126600,3.1302905,2.7807298,,,,,,,,,,,,,, -126700,3.1720047,2.7699637,,,,,,,,,,,,,, -126800,3.1722353,2.81322,,,,,,,,,,,,,, -126900,3.1426344,2.737259,,,,,,,,,,,,,, -127000,3.371689,2.7561035,,,,,,,,,,,,,, -127100,3.3679152,2.7860081,,,,,,,,,,,,,, -127200,3.1975288,2.773405,,,,,,,,,,,,,, -127300,3.3973773,2.755294,,,,,,,,,,,,,, -127400,3.2867105,2.838849,,,,,,,,,,,,,, -127500,3.2850504,2.836592,,,,,,,,,,,,,, -127600,3.2233422,2.7719374,,,,,,,,,,,,,, -127700,3.1041274,2.789545,,,,,,,,,,,,,, -127800,3.1286967,2.7526014,,,,,,,,,,,,,, -127801,,,0.8557876348495483,0.7706797122955322,0.7314199805259705,1.2801084518432615,50000.0,0.613800048828125,1.8942350149154663,10000.0,43391.6541454792,44911.067249298096,43391.6541454792,1511.4806642532349,3.4473884105682373,0.0 -127900,3.2801435,2.7860787,,,,,,,,,,,,,, -128000,3.4059513,2.7949529,,,,,,,,,,,,,, -128100,3.2487116,2.8304272,,,,,,,,,,,,,, -128200,3.1562061,2.7445436,,,,,,,,,,,,,, -128300,3.0346658,2.7509487,,,,,,,,,,,,,, -128400,3.2133226,2.7545047,,,,,,,,,,,,,, -128500,3.2273457,2.7865334,,,,,,,,,,,,,, -128600,3.2346008,2.7406347,,,,,,,,,,,,,, -128700,3.1138484,2.7462595,,,,,,,,,,,,,, -128800,3.2667515,2.7653756,,,,,,,,,,,,,, -128900,3.139632,2.7482371,,,,,,,,,,,,,, -129000,3.0821028,2.7136145,,,,,,,,,,,,,, -129100,3.2359848,2.7337515,,,,,,,,,,,,,, -129200,3.2757752,2.7800615,,,,,,,,,,,,,, -129300,3.3389933,2.8082268,,,,,,,,,,,,,, -129305,,,0.8703563213348389,0.682815670967102,0.7312600016593933,1.259967565536499,50000.0,0.6118000149726868,1.882664918899536,10000.0,43901.69043469429,45438.67550635338,43901.69043469429,1528.9485657215118,3.4982686042785645,0.0 -129400,3.2687595,2.6797864,,,,,,,,,,,,,, -129500,3.3626325,2.7874207,,,,,,,,,,,,,, -129600,3.1360395,2.8263984,,,,,,,,,,,,,, -129700,3.4243758,2.7830324,,,,,,,,,,,,,, -129800,3.518726,2.6952596,,,,,,,,,,,,,, -129900,3.146327,2.7438478,,,,,,,,,,,,,, -130000,3.3005188,2.781247,,,,,,,,,,,,,, -130100,3.4597466,2.7422273,,,,,,,,,,,,,, -130200,3.1987114,2.696333,,,,,,,,,,,,,, -130300,3.068093,2.7523103,,,,,,,,,,,,,, -130400,3.1298263,2.703416,,,,,,,,,,,,,, -130500,3.251922,2.6926036,,,,,,,,,,,,,, -130600,3.1409285,2.7538657,,,,,,,,,,,,,, -130700,3.2339232,2.7501674,,,,,,,,,,,,,, -130800,3.4241502,2.855167,,,,,,,,,,,,,, -130810,,,0.869559109210968,0.7021733522415161,0.7360599637031555,1.262072205543518,50000.0,0.6105000376701355,1.8829753398895264,10000.0,44411.83263254166,45966.23621058464,44411.83263254166,1546.2606403827667,3.5511486530303955,0.0 -130900,3.3347034,2.7737918,,,,,,,,,,,,,, -131000,3.2855105,2.7219028,,,,,,,,,,,,,, -131100,3.3459895,2.7427568,,,,,,,,,,,,,, -131200,3.0586972,2.7139215,,,,,,,,,,,,,, -131300,3.4121306,2.7570803,,,,,,,,,,,,,, -131400,3.3377895,2.7681441,,,,,,,,,,,,,, -131500,3.2936978,2.71996,,,,,,,,,,,,,, -131600,3.20392,2.7923045,,,,,,,,,,,,,, -131700,3.1907,2.747877,,,,,,,,,,,,,, -131800,3.0781426,2.6583807,,,,,,,,,,,,,, -131900,3.0631073,2.6715055,,,,,,,,,,,,,, -132000,3.342962,2.7106466,,,,,,,,,,,,,, -132100,3.2598534,2.6948237,,,,,,,,,,,,,, -132200,3.2501714,2.7640765,,,,,,,,,,,,,, -132300,3.2938848,2.7645812,,,,,,,,,,,,,, -132313,,,0.8615074753761292,0.731268048286438,0.7324999570846558,1.2648930549621582,50000.0,0.61080002784729,1.890793800354004,10000.0,44921.73423457146,46493.51533031464,44921.73423457146,1563.5348060131073,3.6022284030914307,0.0 -132400,3.2557344,2.737451,,,,,,,,,,,,,, -132500,3.5505583,2.7628183,,,,,,,,,,,,,, -132600,3.3740926,2.7490704,,,,,,,,,,,,,, -132700,3.3285642,2.7330923,,,,,,,,,,,,,, -132800,3.2321434,2.7445388,,,,,,,,,,,,,, -132900,3.3397706,2.7426524,,,,,,,,,,,,,, -133000,3.1135044,2.7364805,,,,,,,,,,,,,, -133100,3.202758,2.7527564,,,,,,,,,,,,,, -133200,3.1437485,2.692923,,,,,,,,,,,,,, -133300,3.5484776,2.798914,,,,,,,,,,,,,, -133400,3.4450438,2.7297533,,,,,,,,,,,,,, -133500,3.308249,2.7590117,,,,,,,,,,,,,, -133600,3.203771,2.6579218,,,,,,,,,,,,,, -133700,3.3119016,2.759354,,,,,,,,,,,,,, -133800,3.3108795,2.6709583,,,,,,,,,,,,,, -133817,,,0.8706154227256775,0.7510040998458862,0.7382999658584595,1.281070113182068,50000.0,0.6159000396728516,1.9176183938980105,10000.0,45431.71617555618,47021.19475483894,45431.71617555618,1581.1292176246643,3.6521716117858887,0.0 -133900,3.1678667,2.6973617,,,,,,,,,,,,,, -134000,3.4094877,2.7271833,,,,,,,,,,,,,, -134100,3.3829255,2.7562187,,,,,,,,,,,,,, -134200,3.5871441,2.7232583,,,,,,,,,,,,,, -134300,3.0651155,2.712531,,,,,,,,,,,,,, -134400,3.2389045,2.735529,,,,,,,,,,,,,, -134500,3.1030216,2.7112195,,,,,,,,,,,,,, -134600,3.386626,2.675036,,,,,,,,,,,,,, -134700,3.347795,2.7810972,,,,,,,,,,,,,, -134800,3.5738418,2.7668664,,,,,,,,,,,,,, -134900,3.465535,2.7405972,,,,,,,,,,,,,, -135000,3.35051,2.7655427,,,,,,,,,,,,,, -135100,3.3707645,2.758846,,,,,,,,,,,,,, -135200,3.3055446,2.7238255,,,,,,,,,,,,,, -135300,3.2936718,2.6702068,,,,,,,,,,,,,, -135322,,,0.8668088316917419,0.6989973187446594,0.7389199733734131,1.2317496538162231,50000.0,0.6148000359535217,1.8706231117248533,10000.0,45941.8173930645,47548.92354273796,45941.8173930645,1598.651022195816,3.70396900177002,0.0 -135400,3.4340312,2.6807935,,,,,,,,,,,,,, -135500,3.2964842,2.6992397,,,,,,,,,,,,,, -135600,3.0637624,2.6806517,,,,,,,,,,,,,, -135700,3.6933749,2.6865687,,,,,,,,,,,,,, -135800,3.2868063,2.7172024,,,,,,,,,,,,,, -135900,3.5793133,2.7672348,,,,,,,,,,,,,, -136000,3.4692795,2.7501397,,,,,,,,,,,,,, -136100,3.192411,2.7300916,,,,,,,,,,,,,, -136200,3.3922026,2.7125885,,,,,,,,,,,,,, -136300,3.8918052,2.7591934,,,,,,,,,,,,,, -136400,3.3055425,2.7254863,,,,,,,,,,,,,, -136500,3.672894,2.8093874,,,,,,,,,,,,,, -136600,3.304266,2.7835367,,,,,,,,,,,,,, -136700,3.3584504,2.7243402,,,,,,,,,,,,,, -136800,3.455774,2.7559755,,,,,,,,,,,,,, -136826,,,0.8732461333274841,0.7015025019645691,0.7404599785804749,1.2484326362609863,50000.0,0.6172000169754028,1.886087417602539,10000.0,46451.72684264183,48076.31215286255,46451.72684264183,1616.0231895446775,3.759093761444092,0.0 -136900,3.4286335,2.730536,,,,,,,,,,,,,, -137000,3.2809627,2.6348028,,,,,,,,,,,,,, -137100,3.2340062,2.672966,,,,,,,,,,,,,, -137200,3.2418487,2.6562045,,,,,,,,,,,,,, -137300,3.333226,2.6991963,,,,,,,,,,,,,, -137400,3.5999177,2.7350502,,,,,,,,,,,,,, -137500,3.4163501,2.6973093,,,,,,,,,,,,,, -137600,3.4796824,2.7789242,,,,,,,,,,,,,, -137700,3.4646282,2.713566,,,,,,,,,,,,,, -137800,3.5014968,2.674686,,,,,,,,,,,,,, -137900,3.5118225,2.765867,,,,,,,,,,,,,, -138000,3.5785,2.7731085,,,,,,,,,,,,,, -138100,3.525991,2.7041621,,,,,,,,,,,,,, -138200,3.4367816,2.7456152,,,,,,,,,,,,,, -138300,3.2982032,2.6564672,,,,,,,,,,,,,, -138331,,,0.8901665806770325,0.6423742771148682,0.7400799989700317,1.2568217515945437,50000.0,0.6165000200271606,1.88319993019104,10000.0,46961.66208958626,48603.79757499695,46961.66208958626,1633.4662518501282,3.812281608581543,0.0 -138400,3.622329,2.7476547,,,,,,,,,,,,,, -138500,3.6463525,2.6995034,,,,,,,,,,,,,, -138600,3.594189,2.6822839,,,,,,,,,,,,,, -138700,3.3587785,2.7350993,,,,,,,,,,,,,, -138800,3.5835533,2.7253954,,,,,,,,,,,,,, -138900,3.296243,2.6744437,,,,,,,,,,,,,, -139000,3.4930756,2.7135215,,,,,,,,,,,,,, -139100,3.580017,2.7032583,,,,,,,,,,,,,, -139200,3.268938,2.692288,,,,,,,,,,,,,, -139300,3.3996522,2.673001,,,,,,,,,,,,,, -139400,3.4864395,2.7624812,,,,,,,,,,,,,, -139500,3.4236438,2.6669364,,,,,,,,,,,,,, -139600,3.5038507,2.690896,,,,,,,,,,,,,, -139700,3.580824,2.7268634,,,,,,,,,,,,,, -139800,3.4218926,2.73456,,,,,,,,,,,,,, -139836,,,0.8897680044174194,0.627855658531189,0.74235999584198,1.2312299013137815,50000.0,0.6229000091552734,1.850735783576965,10000.0,47471.80485224724,49131.303370952606,47471.80485224724,1650.7227218151093,3.866595983505249,0.0 -139900,3.2961714,2.6822927,,,,,,,,,,,,,, -140000,3.5012612,2.631825,,,,,,,,,,,,,, -140100,3.5170248,2.6746836,,,,,,,,,,,,,, -140200,3.6257885,2.6824396,,,,,,,,,,,,,, -140300,3.3745039,2.6861305,,,,,,,,,,,,,, -140400,3.371823,2.7099493,,,,,,,,,,,,,, -140500,3.6466484,2.7104204,,,,,,,,,,,,,, -140600,3.5682054,2.7115765,,,,,,,,,,,,,, -140700,3.4824586,2.6348705,,,,,,,,,,,,,, -140800,3.4488342,2.7130597,,,,,,,,,,,,,, -140900,3.5574071,2.6703706,,,,,,,,,,,,,, -141000,3.80653,2.6554158,,,,,,,,,,,,,, -141100,3.714803,2.7412348,,,,,,,,,,,,,, -141200,3.2926123,2.6365564,,,,,,,,,,,,,, -141300,3.292797,2.674542,,,,,,,,,,,,,, -141341,,,0.886738657951355,0.6396247744560242,0.745639979839325,1.2290503978729248,50000.0,0.6237000226974487,1.8512567281723025,10000.0,47981.99515795708,49659.10092759133,47981.99515795708,1668.2225155830383,3.922005653381348,0.0 -141400,3.503982,2.6912515,,,,,,,,,,,,,, -141500,3.3752403,2.739533,,,,,,,,,,,,,, -141600,3.3369377,2.7390811,,,,,,,,,,,,,, -141700,3.6337428,2.6794286,,,,,,,,,,,,,, -141800,3.5840096,2.6923964,,,,,,,,,,,,,, -141900,3.5120897,2.711863,,,,,,,,,,,,,, -142000,3.719366,2.7083206,,,,,,,,,,,,,, -142100,3.4336965,2.6467066,,,,,,,,,,,,,, -142200,3.4985757,2.6689568,,,,,,,,,,,,,, -142300,3.3764749,2.6234667,,,,,,,,,,,,,, -142400,3.471684,2.6799946,,,,,,,,,,,,,, -142500,3.660801,2.6261935,,,,,,,,,,,,,, -142600,3.6135476,2.6799955,,,,,,,,,,,,,, -142700,3.7975118,2.7071724,,,,,,,,,,,,,, -142800,3.6851156,2.7445912,,,,,,,,,,,,,, -142846,,,0.8891502022743225,0.6449418067932129,0.7464199662208557,1.2233476638793943,50000.0,0.6234000325202942,1.851295828819275,10000.0,48492.10888576508,50186.51897478104,48492.10888576508,1685.4138662815094,3.981988430023194,0.0 -142900,3.5466442,2.7131264,,,,,,,,,,,,,, -143000,3.6450698,2.6879268,,,,,,,,,,,,,, -143100,3.5664372,2.6788926,,,,,,,,,,,,,, -143200,3.631175,2.6896937,,,,,,,,,,,,,, -143300,3.6289043,2.6936104,,,,,,,,,,,,,, -143400,3.5916426,2.7273445,,,,,,,,,,,,,, -143500,3.7269912,2.6611443,,,,,,,,,,,,,, -143600,3.9381983,2.7417257,,,,,,,,,,,,,, -143700,3.5287368,2.6968384,,,,,,,,,,,,,, -143800,3.5273156,2.6489525,,,,,,,,,,,,,, -143900,3.3764772,2.6101036,,,,,,,,,,,,,, -144000,3.7549386,2.673706,,,,,,,,,,,,,, -144100,3.7972727,2.7270923,,,,,,,,,,,,,, -144200,3.3461847,2.6136127,,,,,,,,,,,,,, -144300,3.6552193,2.6612341,,,,,,,,,,,,,, -144351,,,0.8878347873687744,0.636131763458252,0.7448999881744385,1.2222989797592163,50000.0,0.6279000043869019,1.849866271018982,10000.0,49002.30559277535,50714.08239650726,49002.30559277535,1702.677888393402,4.033036470413208,0.0 -144400,4.0233216,2.7689915,,,,,,,,,,,,,, -144500,3.7053363,2.6606543,,,,,,,,,,,,,, -144600,3.9288836,2.6623242,,,,,,,,,,,,,, -144700,3.6488981,2.6743314,,,,,,,,,,,,,, -144800,3.6300032,2.6598053,,,,,,,,,,,,,, -144900,3.868442,2.7058647,,,,,,,,,,,,,, -145000,3.6034443,2.6690936,,,,,,,,,,,,,, -145100,3.5612485,2.6474652,,,,,,,,,,,,,, -145200,3.6059546,2.6691072,,,,,,,,,,,,,, -145300,3.6514173,2.742557,,,,,,,,,,,,,, -145400,3.453324,2.6733482,,,,,,,,,,,,,, -145500,3.3722925,2.6228685,,,,,,,,,,,,,, -145600,3.622948,2.694575,,,,,,,,,,,,,, -145700,3.5734432,2.6580226,,,,,,,,,,,,,, -145800,3.6694353,2.6257172,,,,,,,,,,,,,, -145856,,,0.8903858065605164,0.6295886039733887,0.7457000017166138,1.2230639457702637,50000.0,0.6261000037193298,1.8473961353302,10000.0,49512.39498496056,51241.637093544006,49512.39498496056,1720.0401091575625,4.083997964859009,0.0 -145900,3.5482004,2.6627297,,,,,,,,,,,,,, -146000,3.8948627,2.674581,,,,,,,,,,,,,, -146100,3.8235674,2.6524463,,,,,,,,,,,,,, -146200,3.6521573,2.6651478,,,,,,,,,,,,,, -146300,3.5486352,2.6657448,,,,,,,,,,,,,, -146400,3.5724947,2.6622505,,,,,,,,,,,,,, -146500,3.6904466,2.634401,,,,,,,,,,,,,, -146600,3.5593457,2.6647608,,,,,,,,,,,,,, -146700,3.6684968,2.6407657,,,,,,,,,,,,,, -146800,3.422338,2.6227865,,,,,,,,,,,,,, -146900,3.8558874,2.6745744,,,,,,,,,,,,,, -147000,3.850894,2.6373851,,,,,,,,,,,,,, -147100,3.6932611,2.6304576,,,,,,,,,,,,,, -147200,3.7615955,2.6119523,,,,,,,,,,,,,, -147300,3.6763916,2.7210698,,,,,,,,,,,,,, -147360,,,0.9123684167861938,0.5474082827568054,0.7480799555778503,1.2101439237594604,50000.0,0.6310000419616699,1.8346625566482544,10000.0,50022.29506134987,51769.10906100273,50022.29506134987,1737.5025906562803,4.140517711639404,0.0 -147400,3.408171,2.6427796,,,,,,,,,,,,,, -147500,3.5476103,2.6568873,,,,,,,,,,,,,, -147600,3.594273,2.5566056,,,,,,,,,,,,,, -147700,3.8035011,2.6296465,,,,,,,,,,,,,, -147800,3.568759,2.6436167,,,,,,,,,,,,,, -147900,3.6583903,2.6044803,,,,,,,,,,,,,, -148000,3.666162,2.6344483,,,,,,,,,,,,,, -148100,3.467684,2.6036887,,,,,,,,,,,,,, -148200,3.815104,2.7164314,,,,,,,,,,,,,, -148300,3.7654793,2.629712,,,,,,,,,,,,,, -148400,3.8005142,2.6501918,,,,,,,,,,,,,, -148500,3.8703449,2.6642942,,,,,,,,,,,,,, -148600,3.9446354,2.7262037,,,,,,,,,,,,,, -148700,3.6731217,2.635556,,,,,,,,,,,,,, -148800,3.65839,2.6762977,,,,,,,,,,,,,, -148864,,,0.9058912396430968,0.5808451175689697,0.7503199577331543,1.2156970500946045,50000.0,0.6290000081062317,1.848047018051148,10000.0,50532.36496210098,52296.71101593971,50532.36496210098,1754.9088788032532,4.21375036239624,0.0 -148900,3.8655372,2.6340146,,,,,,,,,,,,,, -149000,3.5157683,2.6047359,,,,,,,,,,,,,, -149100,3.638381,2.6163993,,,,,,,,,,,,,, -149200,4.077223,2.699957,,,,,,,,,,,,,, -149300,3.6519384,2.6461897,,,,,,,,,,,,,, -149400,3.6064153,2.678061,,,,,,,,,,,,,, -149500,3.81895,2.65044,,,,,,,,,,,,,, -149600,3.8638747,2.6940932,,,,,,,,,,,,,, -149700,3.8692968,2.6331563,,,,,,,,,,,,,, -149800,3.8476772,2.7000558,,,,,,,,,,,,,, -149900,3.57054,2.6405888,,,,,,,,,,,,,, -150000,3.610127,2.6096377,,,,,,,,,,,,,, -150100,3.8934402,2.6532893,,,,,,,,,,,,,, -150200,3.6150386,2.5815513,,,,,,,,,,,,,, -150300,3.7588189,2.6501834,,,,,,,,,,,,,, -150368,,,0.9048748016357422,0.5630061030387878,0.7488600015640259,1.1980425119400024,50000.0,0.628600001335144,1.820862054824829,10000.0,51042.27467918396,52824.65858411789,51042.27467918396,1772.8417789936066,4.267144680023193,0.0 -150400,3.9073932,2.63863,,,,,,,,,,,,,, -150500,3.774983,2.6462889,,,,,,,,,,,,,, -150600,3.6969016,2.5829592,,,,,,,,,,,,,, -150700,3.7539485,2.6057136,,,,,,,,,,,,,, -150800,3.93554,2.7059033,,,,,,,,,,,,,, -150900,3.810222,2.6178155,,,,,,,,,,,,,, -151000,3.770671,2.6421566,,,,,,,,,,,,,, -151100,3.7955663,2.5909474,,,,,,,,,,,,,, -151200,3.740456,2.6152244,,,,,,,,,,,,,, -151300,3.9423864,2.6463559,,,,,,,,,,,,,, -151400,3.591011,2.5376344,,,,,,,,,,,,,, -151500,3.8681948,2.5911527,,,,,,,,,,,,,, -151600,3.684123,2.596656,,,,,,,,,,,,,, -151700,3.7350497,2.6728075,,,,,,,,,,,,,, -151800,3.7353258,2.6014628,,,,,,,,,,,,,, -151873,,,0.9094387292861938,0.5636129379272461,0.7510600090026855,1.2021054029464722,50000.0,0.6339000463485718,1.816255807876587,10000.0,51552.41488814354,53352.32368469238,51552.41488814354,1790.2585053443909,4.323015451431274,0.0 -151900,3.6354468,2.624681,,,,,,,,,,,,,, -152000,3.9214785,2.5795968,,,,,,,,,,,,,, -152100,3.569412,2.5721586,,,,,,,,,,,,,, -152200,3.5566413,2.58869,,,,,,,,,,,,,, -152300,4.009824,2.6174104,,,,,,,,,,,,,, -152400,3.8171072,2.5695622,,,,,,,,,,,,,, -152500,3.8826988,2.649636,,,,,,,,,,,,,, -152600,4.053456,2.652008,,,,,,,,,,,,,, -152700,3.8732715,2.685235,,,,,,,,,,,,,, -152800,3.7822025,2.644241,,,,,,,,,,,,,, -152900,3.9807901,2.6143308,,,,,,,,,,,,,, -153000,3.9770064,2.5752232,,,,,,,,,,,,,, -153100,3.5747743,2.5766115,,,,,,,,,,,,,, -153200,3.8045368,2.6060386,,,,,,,,,,,,,, -153300,3.7258854,2.5712783,,,,,,,,,,,,,, -153378,,,0.9103156924247742,0.563623309135437,0.7524799704551697,1.2048771381378174,50000.0,0.6305000185966492,1.829545259475708,10000.0,52062.60071802139,53880.12092804909,52062.60071802139,1807.7640812397003,4.375983238220215,0.0 -153400,4.050809,2.6427355,,,,,,,,,,,,,, -153500,3.9976926,2.642027,,,,,,,,,,,,,, -153600,4.012125,2.6324456,,,,,,,,,,,,,, -153700,3.784061,2.5867732,,,,,,,,,,,,,, -153800,3.726723,2.5752797,,,,,,,,,,,,,, -153900,3.808265,2.583963,,,,,,,,,,,,,, -154000,3.9203804,2.620985,,,,,,,,,,,,,, -154100,3.8712234,2.5686939,,,,,,,,,,,,,, -154200,3.9470375,2.6700692,,,,,,,,,,,,,, -154300,3.6641376,2.5572672,,,,,,,,,,,,,, -154400,3.7460818,2.5853,,,,,,,,,,,,,, -154500,3.9148104,2.5963225,,,,,,,,,,,,,, -154600,3.6911361,2.61648,,,,,,,,,,,,,, -154700,3.9261901,2.5853426,,,,,,,,,,,,,, -154800,3.8668458,2.5892575,,,,,,,,,,,,,, -154883,,,0.911730706691742,0.5565671324729919,0.752020001411438,1.2094497680664062,50000.0,0.6341000199317932,1.829664707183838,10000.0,52572.8312189579,54407.95564079285,52572.8312189579,1825.2644836902616,4.426474571228027,0.0 -154900,3.735096,2.5935004,,,,,,,,,,,,,, -155000,3.833467,2.644703,,,,,,,,,,,,,, -155100,3.7827482,2.6360273,,,,,,,,,,,,,, -155200,4.4088464,2.649295,,,,,,,,,,,,,, -155300,3.5963426,2.607907,,,,,,,,,,,,,, -155400,4.067263,2.6638355,,,,,,,,,,,,,, -155500,3.7019224,2.5924845,,,,,,,,,,,,,, -155600,3.9602423,2.6149669,,,,,,,,,,,,,, -155700,3.8217814,2.6363328,,,,,,,,,,,,,, -155800,3.7402363,2.6048005,,,,,,,,,,,,,, -155900,3.9223108,2.6516852,,,,,,,,,,,,,, -156000,3.8071325,2.555143,,,,,,,,,,,,,, -156100,3.592666,2.6013217,,,,,,,,,,,,,, -156200,3.7744958,2.6599483,,,,,,,,,,,,,, -156300,3.7549562,2.585853,,,,,,,,,,,,,, -156387,,,0.9270368218421936,0.5088555812835693,0.7543599605560303,1.2003954648971558,50000.0,0.6367000341415405,1.8163809776306152,10000.0,53082.78435873985,54935.78036189079,53082.78435873985,1843.028647899628,4.481486082077026,0.0 -156400,3.7139788,2.5695431,,,,,,,,,,,,,, -156500,3.7588246,2.6242287,,,,,,,,,,,,,, -156600,3.6865587,2.5362837,,,,,,,,,,,,,, -156700,3.963036,2.590965,,,,,,,,,,,,,, -156800,3.9153814,2.5971148,,,,,,,,,,,,,, -156900,3.7462912,2.543226,,,,,,,,,,,,,, -157000,3.8581848,2.5343394,,,,,,,,,,,,,, -157100,3.7880456,2.557486,,,,,,,,,,,,,, -157200,4.2148476,2.6498206,,,,,,,,,,,,,, -157300,3.7175488,2.5800624,,,,,,,,,,,,,, -157400,4.0481668,2.6204848,,,,,,,,,,,,,, -157500,4.0116434,2.5847535,,,,,,,,,,,,,, -157600,4.315316,2.6094108,,,,,,,,,,,,,, -157700,3.8042572,2.5709636,,,,,,,,,,,,,, -157800,3.9286585,2.630372,,,,,,,,,,,,,, -157892,,,0.922632336616516,0.5026780366897583,0.7541999816894531,1.184553146362305,50000.0,0.6326000094413757,1.8125512599945068,10000.0,53592.87838935852,55464.342358112335,53592.87838935852,1861.3870635032647,4.539382696151733,0.0 -157900,4.261435,2.5864744,,,,,,,,,,,,,, -158000,3.9821494,2.5980346,,,,,,,,,,,,,, -158100,4.1213455,2.5849507,,,,,,,,,,,,,, -158200,3.9579554,2.5744097,,,,,,,,,,,,,, -158300,3.683446,2.5639737,,,,,,,,,,,,,, -158400,3.8816957,2.5408087,,,,,,,,,,,,,, -158500,4.026844,2.6321235,,,,,,,,,,,,,, -158600,4.16465,2.576372,,,,,,,,,,,,,, -158700,3.954105,2.6245432,,,,,,,,,,,,,, -158800,3.8287957,2.5580187,,,,,,,,,,,,,, -158900,4.0502415,2.604249,,,,,,,,,,,,,, -159000,3.8910313,2.5605795,,,,,,,,,,,,,, -159100,3.924379,2.5811443,,,,,,,,,,,,,, -159200,3.987043,2.5557678,,,,,,,,,,,,,, -159300,4.0228577,2.544048,,,,,,,,,,,,,, -159396,,,0.9216158986091614,0.5223816633224487,0.7558599710464478,1.198951244354248,50000.0,0.64000004529953,1.815568685531616,10000.0,54102.98790359497,55991.9786362648,54102.98790359497,1878.803019285202,4.597425699234009,0.0 -159400,3.9183004,2.56438,,,,,,,,,,,,,, -159500,3.9738526,2.5227108,,,,,,,,,,,,,, -159600,3.7670338,2.605516,,,,,,,,,,,,,, -159700,4.447871,2.5701683,,,,,,,,,,,,,, -159800,3.7266827,2.5824625,,,,,,,,,,,,,, -159900,3.8507166,2.5646014,,,,,,,,,,,,,, -160000,4.0183935,2.5334525,,,,,,,,,,,,,, -160100,3.9337204,2.5435705,,,,,,,,,,,,,, -160200,3.6547685,2.5538208,,,,,,,,,,,,,, -160300,3.699889,2.5247924,,,,,,,,,,,,,, -160400,3.797158,2.5318782,,,,,,,,,,,,,, -160500,4.086463,2.5890696,,,,,,,,,,,,,, -160600,3.6749115,2.509079,,,,,,,,,,,,,, -160700,3.929345,2.6422,,,,,,,,,,,,,, -160800,4.04679,2.6363015,,,,,,,,,,,,,, -160900,,,0.9227519035339355,0.5238286256790161,0.7569599747657776,1.1966742277145386,50000.0,0.6383000016212463,1.8186490535736084,10000.0,54612.884423971176,56519.49779844284,54612.884423971176,1896.3183450698853,4.652040481567383,0.0 -160900,3.86068,2.5701287,,,,,,,,,,,,,, -161000,3.9465692,2.5814645,,,,,,,,,,,,,, -161100,4.310881,2.6615222,,,,,,,,,,,,,, -161200,3.9194176,2.5676146,,,,,,,,,,,,,, -161300,3.8448668,2.5761034,,,,,,,,,,,,,, -161400,3.8915253,2.556087,,,,,,,,,,,,,, -161500,3.987859,2.608315,,,,,,,,,,,,,, -161600,4.0802674,2.599439,,,,,,,,,,,,,, -161700,3.7898366,2.535933,,,,,,,,,,,,,, -161800,4.179438,2.5573032,,,,,,,,,,,,,, -161900,4.055577,2.567594,,,,,,,,,,,,,, -162000,3.8848763,2.5676675,,,,,,,,,,,,,, -162100,3.9237053,2.598948,,,,,,,,,,,,,, -162200,3.7251008,2.540055,,,,,,,,,,,,,, -162300,4.1734753,2.6632533,,,,,,,,,,,,,, -162400,3.9835925,2.6280403,,,,,,,,,,,,,, -162405,,,0.9255221486091614,0.5173733234405518,0.7567399740219116,1.1949071884155271,50000.0,0.6376000046730042,1.815171360969544,10000.0,55122.97531867027,57047.208235025406,55122.97531867027,1913.817587852478,4.720116376876831,0.0 -162500,3.9181519,2.5766375,,,,,,,,,,,,,, -162600,4.2013526,2.5765057,,,,,,,,,,,,,, -162700,4.0110564,2.6295543,,,,,,,,,,,,,, -162800,4.5785604,2.5300443,,,,,,,,,,,,,, -162900,4.1600842,2.5617595,,,,,,,,,,,,,, -163000,4.020743,2.568386,,,,,,,,,,,,,, -163100,3.9978461,2.536802,,,,,,,,,,,,,, -163200,4.090827,2.56711,,,,,,,,,,,,,, -163300,3.821525,2.5544634,,,,,,,,,,,,,, -163400,3.767614,2.572135,,,,,,,,,,,,,, -163500,4.025474,2.6181939,,,,,,,,,,,,,, -163600,4.190962,2.5965824,,,,,,,,,,,,,, -163700,3.995453,2.5492058,,,,,,,,,,,,,, -163800,3.9405413,2.5582864,,,,,,,,,,,,,, -163900,3.8144667,2.5696416,,,,,,,,,,,,,, -163910,,,0.9263990521430968,0.5073674917221069,0.7562999725341797,1.1926721334457395,50000.0,0.6381000280380249,1.8144419193267824,10000.0,55633.16078686714,57575.13672566414,55633.16078686714,1931.449508666992,4.778955459594727,0.0 -164000,3.8382342,2.5650544,,,,,,,,,,,,,, -164100,3.9829576,2.5623407,,,,,,,,,,,,,, -164200,3.8959758,2.5504937,,,,,,,,,,,,,, -164300,3.9618075,2.570464,,,,,,,,,,,,,, -164400,3.9454527,2.59596,,,,,,,,,,,,,, -164500,3.7263808,2.5621345,,,,,,,,,,,,,, -164600,4.0230093,2.549574,,,,,,,,,,,,,, -164700,4.0147786,2.53796,,,,,,,,,,,,,, -164800,4.304226,2.5684226,,,,,,,,,,,,,, -164900,3.8464918,2.5458217,,,,,,,,,,,,,, -165000,4.494838,2.5447254,,,,,,,,,,,,,, -165100,4.009727,2.5812478,,,,,,,,,,,,,, -165200,4.04685,2.5235221,,,,,,,,,,,,,, -165300,3.9827693,2.5528486,,,,,,,,,,,,,, -165400,4.017972,2.5653841,,,,,,,,,,,,,, -165415,,,0.93558669090271,0.4720862805843353,0.7570399641990662,1.1844284534454346,50000.0,0.6370000243186951,1.8039889335632324,10000.0,56143.32125544548,58102.80092835426,56143.32125544548,1948.845008611679,4.834153413772583,0.0 -165500,3.969575,2.5579867,,,,,,,,,,,,,, -165600,3.979847,2.5482738,,,,,,,,,,,,,, -165700,3.9073184,2.5524638,,,,,,,,,,,,,, -165800,3.7839313,2.5442693,,,,,,,,,,,,,, -165900,3.961074,2.5044765,,,,,,,,,,,,,, -166000,3.8898005,2.488764,,,,,,,,,,,,,, -166100,3.9535308,2.4856782,,,,,,,,,,,,,, -166200,3.929611,2.5566132,,,,,,,,,,,,,, -166300,4.1517262,2.5568643,,,,,,,,,,,,,, -166400,4.1562943,2.5484874,,,,,,,,,,,,,, -166500,3.7476013,2.5085285,,,,,,,,,,,,,, -166600,4.0619335,2.5416133,,,,,,,,,,,,,, -166700,4.0910707,2.5764444,,,,,,,,,,,,,, -166800,4.0423827,2.521751,,,,,,,,,,,,,, -166900,4.1247053,2.5148783,,,,,,,,,,,,,, -166919,,,0.9346898794174194,0.4757241606712341,0.759880006313324,1.1815327405929563,50000.0,0.640500009059906,1.7958463430404663,10000.0,56653.3142747879,58630.24149942398,56653.3142747879,1966.1763620376587,4.896468162536621,0.0 -167000,3.9634533,2.5110815,,,,,,,,,,,,,, -167100,4.0107164,2.5511537,,,,,,,,,,,,,, -167200,4.452412,2.5807028,,,,,,,,,,,,,, -167300,3.6763823,2.5484045,,,,,,,,,,,,,, -167400,3.970743,2.5278277,,,,,,,,,,,,,, -167500,3.799414,2.534194,,,,,,,,,,,,,, -167600,3.718434,2.5143843,,,,,,,,,,,,,, -167700,4.347668,2.5669825,,,,,,,,,,,,,, -167800,3.9396331,2.5269046,,,,,,,,,,,,,, -167900,4.1798573,2.5347326,,,,,,,,,,,,,, -168000,4.0942354,2.5105443,,,,,,,,,,,,,, -168100,4.0918765,2.5379767,,,,,,,,,,,,,, -168200,3.975433,2.519531,,,,,,,,,,,,,, -168300,3.6753147,2.4892893,,,,,,,,,,,,,, -168400,3.7046814,2.5696924,,,,,,,,,,,,,, -168424,,,0.9356863498687744,0.4620772004127502,0.7594199776649475,1.1747729778289795,50000.0,0.6402000188827515,1.7989342212677002,10000.0,57163.46930789948,59157.79861664772,57163.46930789948,1983.467565536499,4.954056978225708,0.0 -168500,3.8126144,2.5197158,,,,,,,,,,,,,, -168600,4.066557,2.5557783,,,,,,,,,,,,,, -168700,4.083719,2.5474324,,,,,,,,,,,,,, -168800,4.302935,2.5794997,,,,,,,,,,,,,, -168900,3.8376622,2.5190814,,,,,,,,,,,,,, -169000,3.7958865,2.5444193,,,,,,,,,,,,,, -169100,3.7959433,2.468561,,,,,,,,,,,,,, -169200,3.936374,2.5177202,,,,,,,,,,,,,, -169300,3.896953,2.5321012,,,,,,,,,,,,,, -169400,4.34288,2.6237326,,,,,,,,,,,,,, -169500,4.1503882,2.5228612,,,,,,,,,,,,,, -169600,3.987059,2.5499122,,,,,,,,,,,,,, -169700,4.121192,2.5611677,,,,,,,,,,,,,, -169800,3.7348263,2.503624,,,,,,,,,,,,,, -169900,4.2433195,2.6070726,,,,,,,,,,,,,, -169929,,,0.9359853267669678,0.4732051193714142,0.7616399526596069,1.178303360939026,50000.0,0.6430000066757202,1.7972735166549685,10000.0,57673.66196155548,59685.62716174126,57673.66196155548,2000.991406917572,5.014198780059815,0.0 -170000,4.024862,2.5617428,,,,,,,,,,,,,, -170100,4.383558,2.505525,,,,,,,,,,,,,, -170200,4.197388,2.5313048,,,,,,,,,,,,,, -170300,4.342896,2.568338,,,,,,,,,,,,,, -170400,3.9488564,2.5360394,,,,,,,,,,,,,, -170500,3.978308,2.5481923,,,,,,,,,,,,,, -170600,4.10341,2.566918,,,,,,,,,,,,,, -170700,4.043912,2.520566,,,,,,,,,,,,,, -170800,4.094449,2.5182815,,,,,,,,,,,,,, -170900,3.8987086,2.5195591,,,,,,,,,,,,,, -171000,4.190369,2.4992945,,,,,,,,,,,,,, -171100,4.1570525,2.5360506,,,,,,,,,,,,,, -171200,3.9915423,2.5024562,,,,,,,,,,,,,, -171300,4.1405244,2.564275,,,,,,,,,,,,,, -171400,4.070677,2.4707544,,,,,,,,,,,,,, -171434,,,0.9338527917861938,0.4760339260101318,0.7612999677658081,1.17757248878479,50000.0,0.6459000110626221,1.798004150390625,10000.0,58183.85387516022,60213.34439706802,58183.85387516022,2018.402989387512,5.076361656188965,0.0 -171500,4.0856366,2.5207734,,,,,,,,,,,,,, -171600,3.9778752,2.5014944,,,,,,,,,,,,,, -171700,4.110851,2.5414877,,,,,,,,,,,,,, -171800,3.763366,2.4627059,,,,,,,,,,,,,, -171900,4.3422723,2.5333047,,,,,,,,,,,,,, -172000,3.9356687,2.534456,,,,,,,,,,,,,, -172100,3.8288186,2.5339818,,,,,,,,,,,,,, -172200,4.388145,2.5747716,,,,,,,,,,,,,, -172300,3.8978045,2.52327,,,,,,,,,,,,,, -172400,4.365691,2.5141006,,,,,,,,,,,,,, -172500,4.1709566,2.5381348,,,,,,,,,,,,,, -172600,4.06979,2.524266,,,,,,,,,,,,,, -172700,4.1518216,2.5083315,,,,,,,,,,,,,, -172800,4.1568785,2.4949174,,,,,,,,,,,,,, -172900,3.9694571,2.5371227,,,,,,,,,,,,,, -172938,,,0.9358657598495485,0.4723958671092987,0.7612400054931641,1.1811754703521729,50000.0,0.6421000361442566,1.8015018701553345,10000.0,58693.92332792282,60740.91999530792,58693.92332792282,2035.8027634620669,5.131362676620483,0.0 -173000,3.9131644,2.4703526,,,,,,,,,,,,,, -173100,4.0641494,2.5295677,,,,,,,,,,,,,, -173200,3.9565167,2.5371227,,,,,,,,,,,,,, -173300,4.2106576,2.5728815,,,,,,,,,,,,,, -173400,3.954374,2.4504519,,,,,,,,,,,,,, -173500,3.9851475,2.5120208,,,,,,,,,,,,,, -173600,4.152791,2.543497,,,,,,,,,,,,,, -173700,4.0019975,2.538191,,,,,,,,,,,,,, -173800,4.071163,2.5711646,,,,,,,,,,,,,, -173900,4.1250615,2.5679436,,,,,,,,,,,,,, -174000,4.3639364,2.520391,,,,,,,,,,,,,, -174100,4.1214767,2.54894,,,,,,,,,,,,,, -174200,4.1606245,2.541076,,,,,,,,,,,,,, -174300,3.802029,2.5391436,,,,,,,,,,,,,, -174400,3.619999,2.5010035,,,,,,,,,,,,,, -174442,,,0.938875138759613,0.4575260877609253,0.7615000009536743,1.1745693683624268,50000.0,0.6446000337600708,1.7955400943756104,10000.0,59204.02347588539,61268.68344020844,59204.02347588539,2053.356356859207,5.18861722946167,0.0 -174500,4.058336,2.5205746,,,,,,,,,,,,,, -174600,4.142594,2.575835,,,,,,,,,,,,,, -174700,4.246927,2.5705395,,,,,,,,,,,,,, -174800,3.8717766,2.5085762,,,,,,,,,,,,,, -174900,4.2293735,2.563447,,,,,,,,,,,,,, -175000,4.06993,2.559253,,,,,,,,,,,,,, -175100,3.855874,2.545384,,,,,,,,,,,,,, -175200,4.1283865,2.5403984,,,,,,,,,,,,,, -175300,4.105181,2.5708241,,,,,,,,,,,,,, -175400,3.8965428,2.472833,,,,,,,,,,,,,, -175500,4.0140233,2.5555787,,,,,,,,,,,,,, -175600,3.987724,2.4977665,,,,,,,,,,,,,, -175700,3.9535425,2.5297382,,,,,,,,,,,,,, -175800,4.101997,2.5244005,,,,,,,,,,,,,, -175900,4.0313807,2.5831757,,,,,,,,,,,,,, -175947,,,0.9404296875,0.4535814821720123,0.7617799639701843,1.1750295162200928,50000.0,0.6451000571250916,1.7953126430511477,10000.0,59714.15299129486,61796.51736474037,59714.15299129486,2070.9494745731354,5.2475292682647705,0.0 -176000,4.122123,2.5382051,,,,,,,,,,,,,, -176100,4.2343473,2.533432,,,,,,,,,,,,,, -176200,4.106603,2.5501158,,,,,,,,,,,,,, -176300,4.0047274,2.541145,,,,,,,,,,,,,, -176400,4.1924515,2.5654387,,,,,,,,,,,,,, -176500,3.9667432,2.5187037,,,,,,,,,,,,,, -176600,4.2722855,2.511735,,,,,,,,,,,,,, -176700,3.998647,2.560698,,,,,,,,,,,,,, -176800,4.266446,2.5422537,,,,,,,,,,,,,, -176900,3.9298086,2.5463514,,,,,,,,,,,,,, -177000,4.183681,2.5308366,,,,,,,,,,,,,, -177100,4.409519,2.5587196,,,,,,,,,,,,,, -177200,4.0481753,2.4821086,,,,,,,,,,,,,, -177300,4.2492065,2.532807,,,,,,,,,,,,,, -177400,3.9740322,2.4939876,,,,,,,,,,,,,, -177451,,,0.9409877061843872,0.457239419221878,0.762179970741272,1.175032138824463,50000.0,0.647100031375885,1.7913333177566528,10000.0,60224.07753944397,62324.28744530678,60224.07753944397,2088.686345100403,5.303653001785278,0.0 -177500,4.312572,2.5590043,,,,,,,,,,,,,, -177600,4.555723,2.5119853,,,,,,,,,,,,,, -177700,4.4026933,2.5617151,,,,,,,,,,,,,, -177800,4.2766523,2.5557425,,,,,,,,,,,,,, -177900,4.1818223,2.5914838,,,,,,,,,,,,,, -178000,4.005548,2.5163414,,,,,,,,,,,,,, -178100,3.9769301,2.533205,,,,,,,,,,,,,, -178200,4.1849957,2.5281832,,,,,,,,,,,,,, -178300,4.0251365,2.5050347,,,,,,,,,,,,,, -178400,3.7961392,2.490074,,,,,,,,,,,,,, -178500,3.9696097,2.5162852,,,,,,,,,,,,,, -178600,4.0742373,2.5951028,,,,,,,,,,,,,, -178700,3.977338,2.5254953,,,,,,,,,,,,,, -178800,4.0331755,2.4989781,,,,,,,,,,,,,, -178900,4.3133874,2.517861,,,,,,,,,,,,,, -178955,,,0.9386160373687744,0.4601021409034729,0.7621399760246277,1.1740968227386477,50000.0,0.6452000141143799,1.7911723852157593,10000.0,60734.13635802269,62851.80583786965,60734.13635802269,2106.0333411693573,5.364094734191895,0.0 -179000,4.032121,2.4868922,,,,,,,,,,,,,, -179100,4.0008407,2.5605223,,,,,,,,,,,,,, -179200,4.083178,2.4838328,,,,,,,,,,,,,, -179300,4.030911,2.494258,,,,,,,,,,,,,, -179400,3.8014681,2.4818559,,,,,,,,,,,,,, -179500,3.929987,2.5200813,,,,,,,,,,,,,, -179600,4.2871995,2.5074801,,,,,,,,,,,,,, -179700,3.8387723,2.478395,,,,,,,,,,,,,, -179800,4.104311,2.513384,,,,,,,,,,,,,, -179900,4.2381306,2.4892075,,,,,,,,,,,,,, -180000,3.6660085,2.4804618,,,,,,,,,,,,,, -180100,4.4550533,2.5360098,,,,,,,,,,,,,, -180200,3.9984055,2.4747288,,,,,,,,,,,,,, -180300,4.249075,2.5019624,,,,,,,,,,,,,, -180400,3.8668268,2.5233774,,,,,,,,,,,,,, -180459,,,0.9403499364852904,0.4558300673961639,0.7623599767684937,1.1769626140594482,50000.0,0.6443000435829163,1.7931833267211914,10000.0,61244.19420695305,63379.31936812401,61244.19420695305,2123.3806269168854,5.421769380569458,0.0 -180500,4.208906,2.5008166,,,,,,,,,,,,,, -180600,3.9777875,2.4915261,,,,,,,,,,,,,, -180700,4.4048977,2.5566754,,,,,,,,,,,,,, -180800,4.079458,2.5348003,,,,,,,,,,,,,, -180900,3.8503344,2.466959,,,,,,,,,,,,,, -181000,4.0342145,2.500301,,,,,,,,,,,,,, -181100,3.8786979,2.5469801,,,,,,,,,,,,,, -181200,4.310166,2.4980736,,,,,,,,,,,,,, -181300,4.032121,2.504634,,,,,,,,,,,,,, -181400,4.385073,2.5168808,,,,,,,,,,,,,, -181500,3.79167,2.4528174,,,,,,,,,,,,,, -181600,4.3573475,2.5498912,,,,,,,,,,,,,, -181700,4.138976,2.5298653,,,,,,,,,,,,,, -181800,4.373277,2.560064,,,,,,,,,,,,,, -181900,4.0464354,2.5344274,,,,,,,,,,,,,, -181963,,,0.9390943646430968,0.4599470198154449,0.7626799941062927,1.1733717918395996,50000.0,0.6443000435829163,1.7920598983764648,10000.0,61754.27343964577,63907.06888461113,61754.27343964577,2140.9366660118103,5.483500957489014,0.0 -182000,4.234528,2.5159562,,,,,,,,,,,,,, -182100,4.1893425,2.5499315,,,,,,,,,,,,,, -182200,4.5352883,2.5506594,,,,,,,,,,,,,, -182300,4.0198936,2.5229702,,,,,,,,,,,,,, -182400,4.449576,2.524324,,,,,,,,,,,,,, -182500,4.184834,2.5618966,,,,,,,,,,,,,, -182600,4.092659,2.555935,,,,,,,,,,,,,, -182700,4.1695623,2.5391197,,,,,,,,,,,,,, -182800,4.4061613,2.5180032,,,,,,,,,,,,,, -182900,4.1080503,2.5437121,,,,,,,,,,,,,, -183000,4.2100058,2.5471523,,,,,,,,,,,,,, -183100,4.2720895,2.5596414,,,,,,,,,,,,,, -183200,4.049361,2.493175,,,,,,,,,,,,,, -183300,4.2716427,2.5626447,,,,,,,,,,,,,, -183400,4.052628,2.571228,,,,,,,,,,,,,, -183465,,,0.9401307106018066,0.4537459015846252,0.7622999548912048,1.1755332946777344,50000.0,0.6443000435829163,1.7937980890274048,10000.0,62263.354476451874,64434.93941235542,62263.354476451874,2158.5463812351227,6.609484672546387,0.0 -183500,4.1107116,2.5423274,,,,,,,,,,,,,, -183600,4.2506857,2.6052558,,,,,,,,,,,,,, -183700,3.9571373,2.4421964,,,,,,,,,,,,,, -183800,4.2124705,2.518375,,,,,,,,,,,,,, -183900,4.4010634,2.5518446,,,,,,,,,,,,,, -184000,4.204643,2.555203,,,,,,,,,,,,,, -184100,4.08562,2.5593975,,,,,,,,,,,,,, -184200,4.412939,2.5448315,,,,,,,,,,,,,, -184300,4.005271,2.5161586,,,,,,,,,,,,,, -184400,4.0290675,2.522457,,,,,,,,,,,,,, -184500,4.060093,2.553502,,,,,,,,,,,,,, -184600,3.9174025,2.4788268,,,,,,,,,,,,,, -184700,3.926142,2.5234034,,,,,,,,,,,,,, -184800,3.8175156,2.4797435,,,,,,,,,,,,,, -184900,4.1877413,2.512227,,,,,,,,,,,,,, -184970,,,0.9401506781578064,0.4514946043491363,0.7625799775123596,1.1717240810394287,50000.0,0.6452000141143799,1.7898361682891846,10000.0,62773.46354103088,64962.63630104065,62773.46354103088,2176.019249677658,6.671429634094238,0.0 -185000,3.9803984,2.5123444,,,,,,,,,,,,,, -185100,3.9931023,2.5715928,,,,,,,,,,,,,, -185200,4.3875933,2.5233178,,,,,,,,,,,,,, -185300,4.015665,2.5206876,,,,,,,,,,,,,, -185400,3.643542,2.4754303,,,,,,,,,,,,,, -185500,3.9669895,2.5144897,,,,,,,,,,,,,, -185600,4.158122,2.5735288,,,,,,,,,,,,,, -185663,,,,,,,,,,,63008.241268634796,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/eval_measurements.csv deleted file mode 100644 index f329235ff..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.785524606704712,0.0,36.148282289505005,1,0,36.148282289505005,0.0013000001199543,6.9117279052734375,10000,53.93392467498779,0.0006776147638447,6.912829399108887,0.0006799999973736,6.912051200866699,50000 -35.466519832611084,0.0201263427734375,546.2268629074097,1500,0,546.2268629074097,0.0485000014305114,5.6236677169799805,10000,581.7664752006531,0.0702925696969032,5.318418025970459,0.0672199949622154,5.3727641105651855,50000 -53.3863730430603,0.0508553981781005,1056.4533824920654,2999,0,1056.4533824920654,0.1115000024437904,4.811929702758789,10000,1109.996464729309,0.1759805381298065,4.236908912658691,0.1589599996805191,4.364312648773193,50000 -71.34168648719788,0.0803256034851074,1566.5691194534302,4427,0,1566.5691194534302,0.1784000098705291,4.235896110534668,10000,1638.1481268405914,0.2703284323215484,3.506873846054077,0.2503399848937988,3.643269062042236,50000 -89.21532106399536,0.1138412952423095,2076.4831607341766,5925,0,2076.4831607341766,0.2468000054359436,3.741469383239746,10000,2166.0223004817963,0.3623844087123871,2.9141533374786377,0.3374599814414978,3.0710597038269043,50000 -107.07217979431152,0.1446743011474609,2586.6451222896576,7424,0,2586.6451222896576,0.289000004529953,3.5003814697265625,10000,2694.126036405564,0.4332748651504516,2.47650408744812,0.3885799944400787,2.764934062957764,50000 -125.8609426021576,0.1740093231201172,3096.7352623939514,8923,0,3096.7352623939514,0.3427000045776367,3.144484043121338,10000,3223.0867116451263,0.5002790093421936,2.140385389328003,0.4495999813079834,2.4287495613098145,50000 -143.74170780181885,0.2023537158966064,3606.813559293747,10423,0,3606.813559293747,0.3750000298023224,2.9806602001190186,10000,3751.127242565155,0.5160634517669678,2.058394432067871,0.4732999801635742,2.3019864559173584,50000 -161.58436393737793,0.2334496974945068,4116.920515537262,11923,0,4116.920515537262,0.3921000063419342,2.822295665740967,10000,4279.158988952637,0.5541493892669678,1.873650074005127,0.510159969329834,2.0996718406677246,50000 -179.24028539657593,0.2617185115814209,4627.143044948578,13425,0,4627.143044948578,0.4071000218391418,2.80188512802124,10000,4807.1185483932495,0.5691764950752258,1.8064440488815308,0.527899980545044,2.027393579483032,50000 -196.7713873386383,0.2908220291137695,5137.219739675522,14927,0,5137.219739675522,0.418500006198883,2.6923470497131348,10000,5334.8082802295685,0.5782844424247742,1.7499090433120728,0.5433799624443054,1.9397135972976685,50000 -214.36780762672424,0.3212883472442627,5647.431823968887,16430,0,5647.431823968887,0.4326000213623047,2.6046390533447266,10000,5862.699553728104,0.5916174650192261,1.6916300058364868,0.5521799921989441,1.8945378065109253,50000 -232.15265464782715,0.3548550605773926,6157.574437856674,17934,0,6157.574437856674,0.4327000081539154,2.6345584392547607,10000,6390.7123556137085,0.6174266338348389,1.5416178703308103,0.5548799633979797,1.8838409185409544,50000 -250.0759198665619,0.387291669845581,6667.758923768997,19438,0,6667.758923768997,0.4364000260829925,2.5843722820281982,10000,6918.904389619827,0.6119060516357422,1.5986028909683228,0.5607999563217163,1.8535418510437007,50000 -267.61257815361023,0.4182553291320801,7177.744247436523,20942,0,7177.744247436523,0.4525000154972076,2.538072109222412,10000,7446.508990764618,0.6103515625,1.5754483938217163,0.5648599863052368,1.8282480239868164,50000 -285.21574664115906,0.4493522644042969,7687.6893701553345,22446,0,7687.6893701553345,0.4520000219345093,2.550936698913574,10000,7974.141838550568,0.6111288070678711,1.5875287055969238,0.572219967842102,1.8043186664581297,50000 -302.8937132358551,0.4838879108428955,8197.739970445633,23950,0,8197.739970445633,0.4563000202178955,2.495838165283203,10000,8501.958456516266,0.6212531924247742,1.5448447465896606,0.5806999802589417,1.7550562620162964,50000 -320.63299918174744,0.5171177387237549,8707.943544864655,25455,0,8707.943544864655,0.4638000130653381,2.4377589225769043,10000,9029.98591208458,0.6300023794174194,1.4953486919403076,0.5855000019073486,1.7113131284713743,50000 -338.6408226490021,0.5527477264404297,9218.090360164642,26960,0,9218.090360164642,0.4606000185012817,2.4801058769226074,10000,9558.22905421257,0.6515864133834839,1.3849291801452637,0.5827599763870239,1.7557555437088013,50000 -356.16178250312805,0.5854485034942627,9728.102895259855,28465,0,9728.102895259855,0.4606000185012817,2.44994854927063,10000,10085.848005533218,0.6382134556770325,1.4549899101257324,0.5848399996757507,1.734142184257507,50000 -373.8170976638794,0.6196649074554443,10238.321905136108,29971,0,10238.321905136108,0.4678000211715698,2.4534895420074463,10000,10613.808283567429,0.6387914419174194,1.470683455467224,0.5878399610519409,1.7269160747528076,50000 -391.4266884326935,0.6621429920196533,10748.38038611412,31476,0,10748.38038611412,0.4696000218391418,2.4204015731811523,10000,11141.573992013931,0.6308194994926453,1.480445384979248,0.5906999707221985,1.7181880474090576,50000 -409.0134968757629,0.6950950622558594,11258.525603055954,32982,0,11258.525603055954,0.4759000241756439,2.3977839946746826,10000,11669.392268419266,0.6412228941917419,1.4475443363189695,0.598639965057373,1.678407907485962,50000 -426.486044883728,0.7312412261962891,11768.464093208311,34487,0,11768.464093208311,0.4757000207901001,2.399554967880249,10000,12196.891924381256,0.6408242583274841,1.449402093887329,0.5999000072479248,1.667544960975647,50000 -444.545640707016,0.7679879665374756,12278.674576044084,35993,0,12278.674576044084,0.4733000099658966,2.408229112625122,10000,12725.251027584076,0.668965220451355,1.3015565872192385,0.5925999879837036,1.6919569969177246,50000 -462.1173481941223,0.802004337310791,12788.770513534546,37498,0,12788.770513534546,0.480400025844574,2.377382278442383,10000,13253.0048995018,0.6574656963348389,1.356908917427063,0.6019399762153625,1.649037480354309,50000 -479.8211030960083,0.8365299701690674,13298.82503771782,39004,0,13298.82503771782,0.4722000360488891,2.388904571533203,10000,13780.851271390917,0.6500119566917419,1.40574049949646,0.5988199710845947,1.6665188074111938,50000 -497.8423953056336,0.8696949481964111,13808.78684091568,40510,0,13808.78684091568,0.4780000150203705,2.3571691513061523,10000,14308.921383857729,0.6475805044174194,1.4063873291015625,0.6014999747276306,1.6467198133468628,50000 -515.2958283424377,0.9089441299438475,14318.928505182266,42017,0,14318.928505182266,0.4809000194072723,2.4058632850646973,10000,14836.609621286392,0.6520049571990967,1.406828999519348,0.6034199595451355,1.6507673263549805,50000 -532.904883146286,0.9475512504577636,14829.094542264938,43524,0,14829.094542264938,0.4820000231266022,2.3409223556518555,10000,15364.478039741516,0.6497727632522583,1.3987377882003784,0.6049000024795532,1.6419485807418823,50000 -550.5144395828247,0.9823529720306396,15339.251924276352,45031,0,15339.251924276352,0.4817000329494476,2.35010290145874,10000,15892.332380533218,0.6812220811843872,1.246896505355835,0.6029199957847595,1.6404823064804075,50000 -568.1608099937439,1.0166702270507812,15849.248585939407,46537,0,15849.248585939407,0.4793000221252441,2.3764281272888184,10000,16420.061936855316,0.6583425998687744,1.357541561126709,0.600600004196167,1.658661723136902,50000 -586.4911158084869,1.051117181777954,16359.366846084597,48044,0,16359.366846084597,0.4882000088691711,2.325078010559082,10000,16948.598692417145,0.6581233739852905,1.347723126411438,0.6038599610328674,1.6352239847183228,50000 -603.9879791736603,1.0860350131988523,16869.45350933075,49551,0,16869.45350933075,0.4693000316619873,2.4408493041992188,10000,17476.268760681152,0.6349250674247742,1.4702372550964355,0.5914999842643738,1.712287425994873,50000 -621.3797891139984,1.1208219528198242,17379.67597913742,51058,0,17379.67597913742,0.4838000237941742,2.329887628555298,10000,18003.970739126205,0.6541573405265808,1.3928905725479126,0.6054399609565735,1.634318232536316,50000 -639.0722250938416,1.1622974872589111,17889.827553749084,52566,0,17889.827553749084,0.4845000207424164,2.343546152114868,10000,18531.911148548126,0.6519849896430969,1.396129488945007,0.6064199805259705,1.6361829042434692,50000 -656.536458492279,1.201387643814087,18400.011061429977,54074,0,18400.011061429977,0.4937000274658203,2.290510654449463,10000,19059.650450468063,0.6962292790412903,1.2013616561889648,0.6180799603462219,1.5881798267364502,50000 -674.1192197799683,1.2447161674499512,18910.092315673828,55581,0,18910.092315673828,0.5017000436782837,2.253861904144287,10000,19587.41103410721,0.680086076259613,1.259409785270691,0.6174600124359131,1.571677565574646,50000 -692.030259847641,1.281590461730957,19420.2594435215,57089,0,19420.2594435215,0.4967000186443329,2.2936458587646484,10000,20115.578449726105,0.6758211255073547,1.280925989151001,0.6241399645805359,1.5517821311950684,50000 -709.8440093994141,1.3187220096588137,19930.20670747757,58596,0,19930.20670747757,0.49590003490448,2.289567470550537,10000,20643.429992198944,0.6629464030265808,1.3316421508789062,0.6131399869918823,1.595544934272766,50000 -727.8984732627869,1.35973858833313,20440.350209712986,60104,0,20440.350209712986,0.4999000132083893,2.261572360992432,10000,21171.72056317329,0.6681082248687744,1.3214025497436523,0.6195399761199951,1.5713294744491575,50000 -745.5649440288544,1.3957176208496094,20950.476552248,61612,0,20950.476552248,0.4922000169754028,2.2768988609313965,10000,21699.60069823265,0.6642019748687744,1.337678074836731,0.6198999881744385,1.577870488166809,50000 -763.08553647995,1.4341421127319336,21460.725561141968,63120,0,21460.725561141968,0.4999000132083893,2.285916805267334,10000,22227.46119689941,0.6990194320678711,1.1689326763153076,0.6167399883270264,1.5819714069366455,50000 -780.9508357048035,1.477266550064087,21970.70988535881,64628,0,21970.70988535881,0.4997000098228454,2.266094207763672,10000,22755.408446788788,0.6875,1.2308303117752075,0.6247199773788452,1.5516197681427002,50000 -798.4543952941895,1.5191307067871094,22480.74132156372,66136,0,22480.74132156372,0.5041000247001648,2.2609455585479736,10000,23283.03837132454,0.6802455186843872,1.267142415046692,0.6260600090026855,1.5435158014297483,50000 -815.9394180774689,1.5590641498565674,22990.79816222191,67644,0,22990.79816222191,0.4953000247478485,2.2861030101776123,10000,23810.67390537262,0.6720344424247742,1.2921645641326904,0.624779999256134,1.5412269830703735,50000 -833.3083035945892,1.5970518589019775,23500.92076444626,69152,0,23500.92076444626,0.4996000230312347,2.2561073303222656,10000,24338.25503468513,0.6754822731018066,1.2919354438781738,0.626800000667572,1.539273738861084,50000 -851.2584345340729,1.6391994953155518,24011.123149871823,70660,0,24011.123149871823,0.5056000351905823,2.218758344650269,10000,24866.50309228897,0.6835139989852905,1.258325219154358,0.6360599994659424,1.496769666671753,50000 -868.8834428787231,1.677199125289917,24521.15534877777,72168,0,24521.15534877777,0.5113000273704529,2.2167670726776123,10000,25394.25145077705,0.7145846486091614,1.0963265895843506,0.6308599710464478,1.5195053815841677,50000 -886.6075978279114,1.71752667427063,25031.063472747803,73676,0,25031.063472747803,0.5099000334739685,2.2187116146087646,10000,25921.97821545601,0.698660671710968,1.1710692644119265,0.6325399875640869,1.5120478868484497,50000 -904.4073250293732,1.7588274478912354,25541.049762249,75184,0,25541.049762249,0.5022000074386597,2.28255581855774,10000,26449.85903573036,0.6887555718421936,1.2242296934127808,0.6289199590682983,1.5229800939559937,50000 -922.2355952262878,1.813603639602661,26051.24068403244,76692,0,26051.24068403244,0.5051000118255615,2.2224087715148926,10000,26977.98685336113,0.6845503449440002,1.2397161722183228,0.6326799988746643,1.4972995519638062,50000 -939.9247233867644,1.8545491695404053,26561.368300914764,78200,0,26561.368300914764,0.5103999972343445,2.199228525161743,10000,27505.897212982178,0.6881178021430969,1.235331654548645,0.6344999670982361,1.5009933710098269,50000 -957.6483449935912,1.8976809978485107,27071.600769996643,79708,0,27071.600769996643,0.5179000496864319,2.2048068046569824,10000,28033.949457883835,0.6924824714660645,1.204252004623413,0.6411600112915039,1.4582104682922363,50000 -975.4938888549804,1.9428644180297847,27581.69038462639,81216,0,27581.69038462639,0.522100031375885,2.1594133377075195,10000,28561.98272919655,0.7299705147743225,1.0363904237747192,0.6387400031089783,1.4843218326568604,50000 -993.032985687256,1.9836251735687256,28091.66532254219,82723,0,28091.66532254219,0.513700008392334,2.1748580932617188,10000,29089.590743780136,0.7124919891357422,1.1045640707015991,0.6467999815940857,1.4450737237930298,50000 -1010.554886817932,2.03149938583374,28601.680331230164,84231,0,28601.680331230164,0.524399995803833,2.1410605907440186,10000,29617.22701382637,0.708426296710968,1.1359773874282837,0.6473000049591064,1.4457052946090698,50000 -1028.6400225162506,2.074836492538452,29111.69335460663,85739,0,29111.69335460663,0.5109000205993652,2.174423217773437,10000,30145.42070817948,0.697684109210968,1.191365122795105,0.64028000831604,1.475667953491211,50000 -1046.26478099823,2.1219582557678223,29621.69142627716,87245,0,29621.69142627716,0.5249000191688538,2.1454508304595947,10000,30673.143434762955,0.7034239172935486,1.1476686000823977,0.6473599672317505,1.432129144668579,50000 -1063.9508562088013,2.1666407585144043,30131.69265937805,88753,0,30131.69265937805,0.5250000357627869,2.169255256652832,10000,31200.92763876915,0.69921875,1.1784825325012207,0.6462599635124207,1.4546048641204834,50000 -1081.7278769016266,2.2114133834838867,30641.757912397385,90261,0,30641.757912397385,0.5344000458717346,2.106358766555786,10000,31728.86659193039,0.7489835619926453,0.97758811712265,0.6581000089645386,1.4023665189743042,50000 -1099.4261264801023,2.254272222518921,31151.930895090103,91769,0,31151.930895090103,0.5313000082969666,2.1127941608428955,10000,32256.83377289772,0.720703125,1.0760672092437744,0.6516799926757812,1.4240094423294067,50000 -1117.2352879047394,2.3011181354522705,31662.01350545883,93277,0,31662.01350545883,0.5184000134468079,2.1714203357696533,10000,32784.82475876808,0.7053571343421936,1.1407818794250488,0.6459000110626221,1.4545795917510986,50000 -1134.9675867557526,2.345402479171753,32172.043387413025,94785,0,32172.043387413025,0.5238000154495239,2.137367010116577,10000,33312.68545603752,0.7093630433082581,1.1220866441726685,0.6515799760818481,1.4239590167999268,50000 -1152.641860485077,2.3891897201538086,32682.25458574295,96294,0,32682.25458574295,0.5388000011444092,2.063443183898926,10000,33840.66708111763,0.7216597199440002,1.0726561546325684,0.666979968547821,1.358321189880371,50000 -1170.453558921814,2.43384313583374,33192.23504304886,97802,0,33192.23504304886,0.5335000157356262,2.085366487503052,10000,34368.55675005913,0.7195671200752258,1.0867538452148438,0.6625799536705017,1.372006058692932,50000 -1188.0544037818909,2.485387325286865,33702.44986701012,99311,0,33702.44986701012,0.525600016117096,2.124274492263794,10000,34896.47751951218,0.7441206574440002,0.9868006110191344,0.6504799723625183,1.4178804159164429,50000 -1205.8191084861755,2.532451629638672,34212.6481218338,100820,0,34212.6481218338,0.5522000193595886,2.032019853591919,10000,35424.54218220711,0.7390784025192261,0.991266429424286,0.6676999926567078,1.349393606185913,50000 -1223.5154864788055,2.5815184116363525,34722.8763525486,102329,0,34722.8763525486,0.5373000502586365,2.074155807495117,10000,35952.56840515137,0.7271803021430969,1.0407520532608032,0.6602399945259094,1.3748583793640137,50000 -1241.2360637187958,2.6300241947174072,35232.8334004879,103837,0,35232.8334004879,0.5454000234603882,2.035710573196411,10000,36480.34671187401,0.7323620915412903,1.032604098320007,0.6691199541091919,1.3446305990219116,50000 -1258.6924047470093,2.6776812076568604,35742.87636375427,105345,0,35742.87636375427,0.5454000234603882,2.032726526260376,10000,37007.94652104378,0.7277383208274841,1.057060718536377,0.6671000123023987,1.3545143604278564,50000 -1276.0966138839722,2.722939729690552,36252.81159281731,106853,0,36252.81159281731,0.5489000082015991,2.024749994277954,10000,37535.38425660133,0.7320830821990967,1.035598635673523,0.6726599931716919,1.330732822418213,50000 -1293.7430353164673,2.769303321838379,36762.72781395912,108361,0,36762.72781395912,0.5547000169754028,1.9950766563415527,10000,38063.045838832855,0.7787786722183228,0.8223951458930969,0.6761599779129028,1.301857590675354,50000 -1311.6061923503876,2.820514678955078,37272.70946359634,109869,0,37272.70946359634,0.539900004863739,2.03249454498291,10000,38590.995624780655,0.7466916441917419,0.96091890335083,0.6674599647521973,1.3489824533462524,50000 -1329.1077308654783,2.870955228805542,37782.68553614616,111377,0,37782.68553614616,0.5509000420570374,1.9932727813720703,10000,39118.579266786575,0.7472297549247742,0.9534629583358764,0.6775999665260315,1.3055094480514526,50000 -1346.5370290279388,2.9173662662506104,38292.7685611248,112885,0,38292.7685611248,0.5595000386238098,1.9724422693252563,10000,39646.19147348404,0.7496611475944519,0.9466625452041626,0.6813600063323975,1.2942183017730713,50000 -1364.0172460079193,2.962613582611084,38802.89866161346,114393,0,38802.89866161346,0.5559000372886658,1.9868924617767327,10000,40173.89938187599,0.7482461333274841,0.957245111465454,0.682699978351593,1.284542202949524,50000 -1381.7219486236572,3.011996269226074,39313.07150053978,115901,0,39313.07150053978,0.5552000403404236,1.9634768962860107,10000,40701.87993502617,0.7508171200752258,0.9416345953941344,0.6827799677848816,1.2685418128967283,50000 -1399.292273521423,3.0626745223999023,39823.0315990448,117409,0,39823.0315990448,0.5624000430107117,1.9720399379730225,10000,41229.51292562485,0.7893415093421936,0.7836189866065979,0.6830199956893921,1.2751917839050293,50000 -1417.0705609321594,3.109504461288452,40333.14848613739,118917,0,40333.14848613739,0.5601000189781189,1.9834524393081665,10000,41757.50674414635,0.7732780575752258,0.8545231819152832,0.6861199736595154,1.2665618658065796,50000 -1434.6702933311462,3.158001184463501,40843.37981677055,120426,0,40843.37981677055,0.5587000250816345,1.966181397438049,10000,42285.4398317337,0.7666015625,0.8653361201286316,0.6888200044631958,1.2589260339736938,50000 -1452.312379360199,3.207282543182373,41353.42318153381,121934,0,41353.42318153381,0.5608000159263611,1.9467543363571167,10000,42813.22805428505,0.76664137840271,0.8641605377197266,0.6898999810218811,1.2518686056137085,50000 -1469.9726836681366,3.264508008956909,41863.442735910416,123442,0,41863.442735910416,0.5678000450134277,1.9441684484481807,10000,43341.01841783524,0.7724210619926453,0.8580043315887451,0.6927599906921387,1.2386715412139893,50000 -1488.4095180034635,3.31595516204834,42373.48074388504,124950,0,42373.48074388504,0.5697000026702881,1.9078247547149656,10000,43869.59917449951,0.7725605964660645,0.8497375845909119,0.6963599920272827,1.2183761596679688,50000 -1506.1870160102844,3.367685079574585,42883.66209101677,126458,0,42883.66209101677,0.5722000002861023,1.912022590637207,10000,44397.662073135376,0.8109255433082581,0.6936679482460022,0.7005800008773804,1.2135677337646484,50000 -1523.7546339035034,3.41774845123291,43393.808292627335,127966,0,43393.808292627335,0.5768000483512878,1.9222203493118288,10000,44925.48057103157,0.7948222160339355,0.7559593319892883,0.7019599676132202,1.1997932195663452,50000 -1541.539691209793,3.468662738800049,43903.809019088745,129474,0,43903.809019088745,0.5685999989509583,1.900969624519348,10000,45453.370572805405,0.7926099896430969,0.7674739360809326,0.70305997133255,1.1998165845870972,50000 -1559.1616296768188,3.518587350845337,44413.95647978783,130983,0,44413.95647978783,0.5809000134468079,1.8419197797775269,10000,45981.24368643761,0.7925103306770325,0.7734469175338745,0.7093999981880188,1.1662369966506958,50000 -1576.9079988002777,3.573139905929565,44924.015117406845,132491,0,44924.015117406845,0.5781000256538391,1.8631477355957031,10000,46509.1582107544,0.7926897406578064,0.7695133090019226,0.7079199552536011,1.17943274974823,50000 -1594.682165145874,3.6277785301208496,45434.11597537994,133999,0,45434.11597537994,0.5849000215530396,1.8598921298980715,10000,47037.14133429527,0.7933274507522583,0.7550987601280212,0.711899995803833,1.1559367179870603,50000 -1612.187058925629,3.67779278755188,45944.17465925217,135507,0,45944.17465925217,0.591200053691864,1.8425389528274536,10000,47564.80839204788,0.8343032598495483,0.6008027791976929,0.7145000100135803,1.1467500925064087,50000 -1629.722366809845,3.7293410301208496,46454.21942257881,137013,0,46454.21942257881,0.5916000008583069,1.821941494941712,10000,48092.49197125435,0.8226243257522583,0.6424916386604309,0.7174199819564819,1.132103443145752,50000 -1647.4365646839142,3.780620813369751,46964.32555747032,138521,0,46964.32555747032,0.5888000130653381,1.8496946096420288,10000,48620.41685676575,0.8116230964660645,0.6831390261650085,0.7134799957275391,1.1457061767578125,50000 -1664.7287590503693,3.834901571273804,47474.33515691757,140029,0,47474.33515691757,0.5944000482559204,1.816175103187561,10000,49147.82697439194,0.8157086968421936,0.6734943389892578,0.7203199863433838,1.1174829006195068,50000 -1682.08682847023,3.888970851898194,47984.39821815491,141537,0,47984.39821815491,0.597100019454956,1.8183460235595703,10000,49675.356055021286,0.8167450428009033,0.6674581170082092,0.7184199690818787,1.126118779182434,50000 -1699.8958258628843,3.943979978561402,48494.36148881912,143045,0,48494.36148881912,0.5925000309944153,1.8193696737289429,10000,50203.234813690186,0.8122807741165161,0.6797454953193665,0.7173999547958374,1.137130618095398,50000 -1717.4889857769012,3.996210813522339,49004.41987133026,144553,0,49004.41987133026,0.6016000509262085,1.7970024347305298,10000,50730.99198412895,0.8585180044174194,0.5102043747901917,0.7265799641609192,1.0943148136138916,50000 -1735.1880152225494,4.048548221588135,49514.39406490326,146060,0,49514.39406490326,0.6035000085830688,1.7772631645202637,10000,51258.76909947395,0.8483139276504517,0.5484762787818909,0.729420006275177,1.085327386856079,50000 -1752.9253692626953,4.101878881454468,50024.444157123566,147568,0,50024.444157123566,0.603600025177002,1.786201238632202,10000,51786.662527799606,0.8409398794174194,0.5652621984481812,0.7280199527740479,1.085485339164734,50000 -1770.3562195301056,4.158680438995361,50534.34899163246,149075,0,50534.34899163246,0.610200047492981,1.7584848403930664,10000,52314.1056895256,0.8464604616165161,0.5500763654708862,0.7320799827575684,1.0699559450149536,50000 -1787.9360961914062,4.216421842575073,51044.57275414467,150584,0,51044.57275414467,0.612500011920929,1.7552403211593628,10000,52842.02047371864,0.8461614847183228,0.5413140058517456,0.7317999601364136,1.075062870979309,50000 -1805.3098711967468,4.273014545440674,51554.79049611092,152093,0,51554.79049611092,0.6089000105857849,1.747403621673584,10000,53369.72068047524,0.8462013602256775,0.5413247346878052,0.737339973449707,1.056695580482483,50000 -1822.9050323963163,4.3264100551605225,52065.01314878464,153601,0,52065.01314878464,0.612000048160553,1.7551151514053345,10000,53897.64444494248,0.8833306431770325,0.4207604527473449,0.7371399998664856,1.053296685218811,50000 -1840.645209312439,4.384929418563843,52575.09206676483,155109,0,52575.09206676483,0.6145000457763672,1.7276402711868286,10000,54425.57376766205,0.8740832209587097,0.4415238797664642,0.7404199838638306,1.0409177541732788,50000 -1858.6266074180603,4.439204216003418,53084.98903775215,156617,0,53084.98903775215,0.6178000569343567,1.732064127922058,10000,54953.56215620041,0.8761957883834839,0.4427817463874817,0.7436800003051758,1.0364357233047483,50000 -1876.2782986164093,4.496425151824951,53594.90204691887,158125,0,53594.90204691887,0.6204000115394592,1.7091346979141235,10000,55481.23839688301,0.8801418542861938,0.4310168325901031,0.7454400062561035,1.02325177192688,50000 -1893.973302364349,4.555343627929688,54104.91116786003,159633,0,54104.91116786003,0.6222000122070312,1.717750072479248,10000,56009.05380535126,0.8810586333274841,0.4157596826553345,0.7454599738121033,1.0245261192321775,50000 -1911.866572141648,4.608543395996094,54615.10005736351,161141,0,54615.10005736351,0.6239000558853149,1.703717827796936,10000,56537.24241113663,0.8859813213348389,0.4051311314105987,0.7473999857902527,1.0201231241226196,50000 -1929.71240234375,4.6652233600616455,55125.14937853813,162649,0,55125.14937853813,0.625700056552887,1.6878162622451782,10000,57065.24767613411,0.9079041481018066,0.3245623707771301,0.7503199577331543,1.0081355571746826,50000 -1947.10103392601,4.725534439086914,55635.22860836983,164157,0,55635.22860836983,0.6246000528335571,1.704638957977295,10000,57592.82874393463,0.904496133327484,0.3379042744636535,0.7519399523735046,1.0083187818527222,50000 -1964.745409488678,4.782070636749268,56145.23153233528,165665,0,56145.23153233528,0.628600001335144,1.6860158443450928,10000,58120.586990594864,0.9044164419174194,0.3358021378517151,0.7523199915885925,0.9994009137153624,50000 -1982.4217264652248,4.839404821395874,56655.2565972805,167173,0,56655.2565972805,0.6303000450134277,1.696708917617798,10000,58648.39890432358,0.9084821343421936,0.3253271281719208,0.7557199597358704,0.9877996444702148,50000 -2000.068875551224,4.895292043685913,57165.24439907074,168680,0,57165.24439907074,0.6302000284194946,1.6696397066116333,10000,59176.14205694199,0.9084821343421936,0.3189391791820526,0.7572000026702881,0.9814531207084656,50000 -2017.519765138626,4.954912900924683,57675.29413366318,170187,0,57675.29413366318,0.6338000297546387,1.6687065362930298,10000,59703.755058288574,0.9106544852256776,0.3087378144264221,0.7574599981307983,0.9799865484237672,50000 -2035.09224319458,5.01508355140686,58185.4462184906,171695,0,58185.4462184906,0.6383000016212463,1.6713703870773315,10000,60231.59297156334,0.924226701259613,0.2707524597644806,0.757099986076355,0.9797187447547911,50000 -2052.647294998169,5.063016414642334,58695.36663389206,173202,0,58695.36663389206,0.6396000385284424,1.6715961694717407,10000,60759.16719126701,0.9278340339660645,0.2625244557857513,0.758899986743927,0.9757165312767028,50000 -2070.1429677009583,5.118837833404541,59205.45822405815,174710,0,59205.45822405815,0.6392000317573547,1.6742520332336426,10000,61286.8631067276,0.9278539419174194,0.2555326819419861,0.7615399956703186,0.9703835844993592,50000 -2087.3609421253204,5.175340414047241,59715.56953692436,176218,0,59715.56953692436,0.6402000188827515,1.6667834520339966,10000,61814.30304598808,0.926418960094452,0.2620689570903778,0.7624199986457825,0.9670127034187316,50000 -2105.146994113922,5.229976654052734,60225.75075531006,177726,0,60225.75075531006,0.6399000287055969,1.6655884981155396,10000,62342.37773346901,0.927156388759613,0.2595357000827789,0.7625799775123596,0.9643649458885192,50000 -2122.712076902389,5.291700601577759,60735.92390227318,179234,0,60735.92390227318,0.6411000490188599,1.6631746292114258,10000,62870.22953939438,0.9292888641357422,0.2532636523246765,0.7625399827957153,0.962642788887024,50000 -2140.351333856582,5.360998392105103,61246.10619521141,180742,0,61246.10619521141,0.6431000232696533,1.6595665216445925,10000,63398.17158651352,0.9323381781578064,0.2413637936115265,0.763480007648468,0.9599735736846924,50000 -2157.9516232013702,5.419076681137085,61756.06179380417,182249,0,61756.06179380417,0.6432000398635864,1.664105772972107,10000,63925.83683180809,0.9336535334587096,0.2402791529893875,0.7635599970817566,0.9584994316101074,50000 -2175.6238107681274,5.481184005737305,62266.20957398415,183757,0,62266.20957398415,0.6440000534057617,1.6585973501205444,10000,64453.770429611206,0.935566782951355,0.2406303137540817,0.7646999955177307,0.9566633105278016,50000 -2192.9691956043243,5.543259382247925,62776.36577916145,185265,0,62776.36577916145,0.6434000134468079,1.6598241329193115,10000,64981.38829827309,0.9326968789100647,0.24263811111450195,0.7636399865150452,0.9573678374290466,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/measurements.csv deleted file mode 100644 index 33ba0c514..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1986 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.65951836,6.916249,,,,,,,,,,,,,, -1,,,0.0006776147638447,6.912829399108887,0.0006799999973736,6.912051200866699,50000.0,0.0013000001199543,6.9117279052734375,10000.0,36.148282289505005,53.93392467498779,36.148282289505005,17.785524606704712,0.0,0.0 -100,0.6548986,6.903148,,,,,,,,,,,,,, -200,0.6614154,6.8570094,,,,,,,,,,,,,, -300,0.6998633,6.7762957,,,,,,,,,,,,,, -400,0.7719193,6.658673,,,,,,,,,,,,,, -500,0.81616306,6.548357,,,,,,,,,,,,,, -600,0.86838275,6.4411182,,,,,,,,,,,,,, -700,0.883894,6.3774962,,,,,,,,,,,,,, -800,1.0319282,6.282172,,,,,,,,,,,,,, -900,2.6888688,6.1094255,,,,,,,,,,,,,, -1000,1.9430872,5.976159,,,,,,,,,,,,,, -1100,1.7313029,5.9314017,,,,,,,,,,,,,, -1200,1.9765059,5.7721643,,,,,,,,,,,,,, -1300,3.3373702,5.7086983,,,,,,,,,,,,,, -1400,2.5516007,5.609275,,,,,,,,,,,,,, -1500,,,0.0702925696969032,5.318418025970459,0.0672199949622154,5.3727641105651855,50000.0,0.0485000014305114,5.6236677169799805,10000.0,546.2268629074097,581.7664752006531,546.2268629074097,35.466519832611084,0.0201263427734375,0.0 -1500,2.5904968,5.578326,,,,,,,,,,,,,, -1600,3.7318857,5.376524,,,,,,,,,,,,,, -1700,3.2642329,5.2952647,,,,,,,,,,,,,, -1800,2.7201927,5.2965074,,,,,,,,,,,,,, -1900,4.859125,5.2836676,,,,,,,,,,,,,, -2000,3.75284,5.1207323,,,,,,,,,,,,,, -2100,4.181734,5.022377,,,,,,,,,,,,,, -2200,3.8821461,5.010536,,,,,,,,,,,,,, -2300,2.9569414,4.9684243,,,,,,,,,,,,,, -2400,4.4356623,4.9033422,,,,,,,,,,,,,, -2500,4.426328,4.8543367,,,,,,,,,,,,,, -2600,5.7448525,4.71231,,,,,,,,,,,,,, -2700,4.802875,4.8928633,,,,,,,,,,,,,, -2800,9.131086,4.715358,,,,,,,,,,,,,, -2900,5.388628,4.566572,,,,,,,,,,,,,, -2999,,,0.1759805381298065,4.236908912658691,0.1589599996805191,4.364312648773193,50000.0,0.1115000024437904,4.811929702758789,10000.0,1056.4533824920654,1109.996464729309,1056.4533824920654,53.3863730430603,0.0508553981781005,0.0 -3000,7.119295,4.587161,,,,,,,,,,,,,, -3100,4.33375,4.5143204,,,,,,,,,,,,,, -3200,4.750542,4.4868374,,,,,,,,,,,,,, -3300,5.151605,4.570735,,,,,,,,,,,,,, -3400,6.873904,4.4121447,,,,,,,,,,,,,, -3500,5.7714615,4.372432,,,,,,,,,,,,,, -3600,9.803895,4.3299932,,,,,,,,,,,,,, -3700,5.693832,4.1730533,,,,,,,,,,,,,, -3800,6.9544067,4.1887774,,,,,,,,,,,,,, -3900,5.0165176,4.117465,,,,,,,,,,,,,, -4000,5.111242,4.0888877,,,,,,,,,,,,,, -4100,7.1562076,3.9885607,,,,,,,,,,,,,, -4200,11.54062,3.8602612,,,,,,,,,,,,,, -4300,4.7630277,3.9670868,,,,,,,,,,,,,, -4400,6.537289,3.9209318,,,,,,,,,,,,,, -4427,,,0.2703284323215484,3.506873846054077,0.2503399848937988,3.643269062042236,50000.0,0.1784000098705291,4.235896110534668,10000.0,1566.5691194534302,1638.1481268405914,1566.5691194534302,71.34168648719788,0.0803256034851074,0.0 -4500,4.342236,3.9652271,,,,,,,,,,,,,, -4600,5.8825006,3.841396,,,,,,,,,,,,,, -4700,4.8642573,3.8230472,,,,,,,,,,,,,, -4800,6.4669075,3.917869,,,,,,,,,,,,,, -4900,6.329928,3.7305017,,,,,,,,,,,,,, -5000,7.271364,3.8422189,,,,,,,,,,,,,, -5100,7.0585103,3.6060243,,,,,,,,,,,,,, -5200,6.9806147,3.753502,,,,,,,,,,,,,, -5300,5.928885,3.4232886,,,,,,,,,,,,,, -5400,7.8149457,3.6172745,,,,,,,,,,,,,, -5500,9.5067625,3.389955,,,,,,,,,,,,,, -5600,11.508974,3.438755,,,,,,,,,,,,,, -5700,6.7628703,3.3516672,,,,,,,,,,,,,, -5800,7.7831244,3.3008428,,,,,,,,,,,,,, -5900,8.009226,3.4405007,,,,,,,,,,,,,, -5925,,,0.3623844087123871,2.9141533374786377,0.3374599814414978,3.0710597038269043,50000.0,0.2468000054359436,3.741469383239746,10000.0,2076.4831607341766,2166.0223004817963,2076.4831607341766,89.21532106399536,0.1138412952423095,0.0 -6000,7.125751,3.2389119,,,,,,,,,,,,,, -6100,10.360274,3.3307493,,,,,,,,,,,,,, -6200,7.9127192,3.3801708,,,,,,,,,,,,,, -6300,8.742774,3.315382,,,,,,,,,,,,,, -6400,8.393474,3.2659957,,,,,,,,,,,,,, -6500,5.6103764,3.207419,,,,,,,,,,,,,, -6600,6.415726,3.0972912,,,,,,,,,,,,,, -6700,5.999461,3.1670587,,,,,,,,,,,,,, -6800,4.317423,3.0758986,,,,,,,,,,,,,, -6900,5.4294014,3.0836406,,,,,,,,,,,,,, -7000,8.147172,3.19645,,,,,,,,,,,,,, -7100,8.302087,3.0017443,,,,,,,,,,,,,, -7200,10.35302,2.9579523,,,,,,,,,,,,,, -7300,6.53902,3.0261555,,,,,,,,,,,,,, -7400,7.5013328,2.906904,,,,,,,,,,,,,, -7424,,,0.4332748651504516,2.47650408744812,0.3885799944400787,2.764934062957764,50000.0,0.289000004529953,3.5003814697265625,10000.0,2586.6451222896576,2694.126036405564,2586.6451222896576,107.07217979431152,0.1446743011474609,0.0 -7500,5.5728354,3.0601234,,,,,,,,,,,,,, -7600,5.9529567,2.8641694,,,,,,,,,,,,,, -7700,6.782787,2.9065785,,,,,,,,,,,,,, -7800,4.274564,2.8754058,,,,,,,,,,,,,, -7900,5.622924,2.9032707,,,,,,,,,,,,,, -8000,5.9269056,2.891973,,,,,,,,,,,,,, -8100,7.5763655,2.847496,,,,,,,,,,,,,, -8200,5.2753716,2.7782137,,,,,,,,,,,,,, -8300,3.7982109,2.7682428,,,,,,,,,,,,,, -8400,5.1369567,2.7560909,,,,,,,,,,,,,, -8500,6.685697,2.9410036,,,,,,,,,,,,,, -8600,5.7399435,2.8254652,,,,,,,,,,,,,, -8700,4.6366835,2.7127516,,,,,,,,,,,,,, -8800,5.000506,2.8503413,,,,,,,,,,,,,, -8900,4.917245,2.6066375,,,,,,,,,,,,,, -8923,,,0.5002790093421936,2.140385389328003,0.4495999813079834,2.4287495613098145,50000.0,0.3427000045776367,3.144484043121338,10000.0,3096.7352623939514,3223.0867116451263,3096.7352623939514,125.8609426021576,0.1740093231201172,0.0 -9000,7.681655,2.868965,,,,,,,,,,,,,, -9100,6.5871277,2.6505003,,,,,,,,,,,,,, -9200,7.9008856,2.7635384,,,,,,,,,,,,,, -9300,8.045018,2.6970239,,,,,,,,,,,,,, -9400,7.388262,2.5832934,,,,,,,,,,,,,, -9500,8.541205,2.641162,,,,,,,,,,,,,, -9600,4.927603,2.5529938,,,,,,,,,,,,,, -9700,9.054901,2.5911534,,,,,,,,,,,,,, -9800,7.0572143,2.6020634,,,,,,,,,,,,,, -9900,6.0955486,2.588953,,,,,,,,,,,,,, -10000,5.648243,2.5878692,,,,,,,,,,,,,, -10100,8.821382,2.678521,,,,,,,,,,,,,, -10200,6.0498605,2.5654695,,,,,,,,,,,,,, -10300,5.618969,2.5715141,,,,,,,,,,,,,, -10400,7.180645,2.5365276,,,,,,,,,,,,,, -10423,,,0.5160634517669678,2.058394432067871,0.4732999801635742,2.3019864559173584,50000.0,0.3750000298023224,2.9806602001190186,10000.0,3606.813559293747,3751.127242565155,3606.813559293747,143.74170780181885,0.2023537158966064,0.0 -10500,8.183628,2.614114,,,,,,,,,,,,,, -10600,6.7946744,2.4828048,,,,,,,,,,,,,, -10700,7.165433,2.4468274,,,,,,,,,,,,,, -10800,9.237992,2.5008845,,,,,,,,,,,,,, -10900,5.930416,2.6714544,,,,,,,,,,,,,, -11000,5.531698,2.6232572,,,,,,,,,,,,,, -11100,4.760448,2.3223255,,,,,,,,,,,,,, -11200,4.7872796,2.4358227,,,,,,,,,,,,,, -11300,5.446193,2.4062226,,,,,,,,,,,,,, -11400,5.9794374,2.454314,,,,,,,,,,,,,, -11500,7.9794555,2.5618672,,,,,,,,,,,,,, -11600,6.183797,2.4727578,,,,,,,,,,,,,, -11700,7.243012,2.4017959,,,,,,,,,,,,,, -11800,4.7954774,2.5433655,,,,,,,,,,,,,, -11900,6.7866673,2.4872463,,,,,,,,,,,,,, -11923,,,0.5541493892669678,1.873650074005127,0.510159969329834,2.0996718406677246,50000.0,0.3921000063419342,2.822295665740967,10000.0,4116.920515537262,4279.158988952637,4116.920515537262,161.58436393737793,0.2334496974945068,0.0 -12000,4.1124034,2.3994784,,,,,,,,,,,,,, -12100,5.3885226,2.301789,,,,,,,,,,,,,, -12200,7.4762673,2.3997676,,,,,,,,,,,,,, -12300,5.173281,2.4215817,,,,,,,,,,,,,, -12400,5.5474644,2.3921688,,,,,,,,,,,,,, -12500,5.812075,2.4094326,,,,,,,,,,,,,, -12600,5.0257277,2.4804842,,,,,,,,,,,,,, -12700,7.898566,2.336603,,,,,,,,,,,,,, -12800,6.2458634,2.339815,,,,,,,,,,,,,, -12900,8.044458,2.3071418,,,,,,,,,,,,,, -13000,7.4524417,2.2945402,,,,,,,,,,,,,, -13100,5.638094,2.3203106,,,,,,,,,,,,,, -13200,5.9166694,2.3888416,,,,,,,,,,,,,, -13300,5.780845,2.4150133,,,,,,,,,,,,,, -13400,8.011519,2.3650997,,,,,,,,,,,,,, -13425,,,0.5691764950752258,1.8064440488815308,0.527899980545044,2.027393579483032,50000.0,0.4071000218391418,2.80188512802124,10000.0,4627.143044948578,4807.1185483932495,4627.143044948578,179.24028539657593,0.2617185115814209,0.0 -13500,8.489191,2.317394,,,,,,,,,,,,,, -13600,4.323392,2.2612011,,,,,,,,,,,,,, -13700,6.706024,2.1788912,,,,,,,,,,,,,, -13800,9.976341,2.3371181,,,,,,,,,,,,,, -13900,7.8259177,2.2776575,,,,,,,,,,,,,, -14000,8.099494,2.2583282,,,,,,,,,,,,,, -14100,6.3407836,2.2893565,,,,,,,,,,,,,, -14200,6.281924,2.2394793,,,,,,,,,,,,,, -14300,4.6461887,2.2931848,,,,,,,,,,,,,, -14400,9.67504,2.2588596,,,,,,,,,,,,,, -14500,6.9368954,2.2542715,,,,,,,,,,,,,, -14600,4.55069,2.1929226,,,,,,,,,,,,,, -14700,7.975143,2.1693258,,,,,,,,,,,,,, -14800,7.503401,2.326421,,,,,,,,,,,,,, -14900,5.4917583,2.1844137,,,,,,,,,,,,,, -14927,,,0.5782844424247742,1.7499090433120728,0.5433799624443054,1.9397135972976685,50000.0,0.418500006198883,2.6923470497131348,10000.0,5137.219739675522,5334.8082802295685,5137.219739675522,196.7713873386383,0.2908220291137695,0.0 -15000,5.1801376,2.1826122,,,,,,,,,,,,,, -15100,4.4912047,2.3190262,,,,,,,,,,,,,, -15200,4.360242,2.2971659,,,,,,,,,,,,,, -15300,7.5045877,2.2944188,,,,,,,,,,,,,, -15400,5.2173233,2.2937286,,,,,,,,,,,,,, -15500,5.380271,2.3002698,,,,,,,,,,,,,, -15600,4.4391127,2.2815027,,,,,,,,,,,,,, -15700,5.059688,2.1523232,,,,,,,,,,,,,, -15800,6.5801644,2.1815338,,,,,,,,,,,,,, -15900,7.5585885,2.239466,,,,,,,,,,,,,, -16000,3.9360287,2.2785769,,,,,,,,,,,,,, -16100,7.2075677,2.187315,,,,,,,,,,,,,, -16200,5.8723116,2.257892,,,,,,,,,,,,,, -16300,6.436123,2.1905527,,,,,,,,,,,,,, -16400,4.4448028,2.3078074,,,,,,,,,,,,,, -16430,,,0.5916174650192261,1.6916300058364868,0.5521799921989441,1.8945378065109253,50000.0,0.4326000213623047,2.6046390533447266,10000.0,5647.431823968887,5862.699553728104,5647.431823968887,214.36780762672424,0.3212883472442627,0.0 -16500,6.50328,2.2008374,,,,,,,,,,,,,, -16600,6.200751,2.2593572,,,,,,,,,,,,,, -16700,5.656588,2.158751,,,,,,,,,,,,,, -16800,4.1126313,2.1324112,,,,,,,,,,,,,, -16900,6.2925196,2.270992,,,,,,,,,,,,,, -17000,6.499133,2.160109,,,,,,,,,,,,,, -17100,7.526878,2.1112895,,,,,,,,,,,,,, -17200,6.0617576,2.1891327,,,,,,,,,,,,,, -17300,8.609404,2.1873326,,,,,,,,,,,,,, -17400,6.1900887,2.152923,,,,,,,,,,,,,, -17500,4.6248407,2.1846378,,,,,,,,,,,,,, -17600,6.73875,2.275035,,,,,,,,,,,,,, -17700,5.5718393,2.1881075,,,,,,,,,,,,,, -17800,6.0923,2.3574753,,,,,,,,,,,,,, -17900,4.389043,2.0857315,,,,,,,,,,,,,, -17934,,,0.6174266338348389,1.5416178703308103,0.5548799633979797,1.8838409185409544,50000.0,0.4327000081539154,2.6345584392547607,10000.0,6157.574437856674,6390.7123556137085,6157.574437856674,232.15265464782715,0.3548550605773926,0.0 -18000,5.292964,2.2065172,,,,,,,,,,,,,, -18100,5.386982,2.2091572,,,,,,,,,,,,,, -18200,5.592999,2.2625525,,,,,,,,,,,,,, -18300,5.5046372,2.163637,,,,,,,,,,,,,, -18400,5.898975,2.122286,,,,,,,,,,,,,, -18500,4.7474136,2.1999454,,,,,,,,,,,,,, -18600,4.2276263,2.2105043,,,,,,,,,,,,,, -18700,4.608778,2.183138,,,,,,,,,,,,,, -18800,3.419082,2.1652007,,,,,,,,,,,,,, -18900,3.1958628,2.1577818,,,,,,,,,,,,,, -19000,5.698378,2.1451304,,,,,,,,,,,,,, -19100,5.689672,2.1258998,,,,,,,,,,,,,, -19200,6.309278,2.2131062,,,,,,,,,,,,,, -19300,5.159098,2.1867511,,,,,,,,,,,,,, -19400,5.315073,2.1645253,,,,,,,,,,,,,, -19438,,,0.6119060516357422,1.5986028909683228,0.5607999563217163,1.8535418510437007,50000.0,0.4364000260829925,2.5843722820281982,10000.0,6667.758923768997,6918.904389619827,6667.758923768997,250.0759198665619,0.387291669845581,0.0 -19500,6.2870703,2.078244,,,,,,,,,,,,,, -19600,4.574747,2.1237497,,,,,,,,,,,,,, -19700,4.246491,2.108831,,,,,,,,,,,,,, -19800,4.730743,2.0857844,,,,,,,,,,,,,, -19900,4.7512674,2.3354807,,,,,,,,,,,,,, -20000,5.4266005,2.150703,,,,,,,,,,,,,, -20100,4.655987,2.1449077,,,,,,,,,,,,,, -20200,3.1415677,2.1388001,,,,,,,,,,,,,, -20300,3.9928591,2.1609612,,,,,,,,,,,,,, -20400,3.8356028,2.3013475,,,,,,,,,,,,,, -20500,4.8298078,2.192296,,,,,,,,,,,,,, -20600,4.6165624,2.2551425,,,,,,,,,,,,,, -20700,4.0617747,2.0866356,,,,,,,,,,,,,, -20800,3.5555391,2.149085,,,,,,,,,,,,,, -20900,4.715153,2.0427108,,,,,,,,,,,,,, -20942,,,0.6103515625,1.5754483938217163,0.5648599863052368,1.8282480239868164,50000.0,0.4525000154972076,2.538072109222412,10000.0,7177.744247436523,7446.508990764618,7177.744247436523,267.61257815361023,0.4182553291320801,0.0 -21000,4.7456517,2.1141994,,,,,,,,,,,,,, -21100,4.243678,2.1144066,,,,,,,,,,,,,, -21200,3.5396152,2.043796,,,,,,,,,,,,,, -21300,4.96468,2.0533183,,,,,,,,,,,,,, -21400,3.7903602,2.1250288,,,,,,,,,,,,,, -21500,3.2464867,2.1744401,,,,,,,,,,,,,, -21600,4.8575983,2.1489022,,,,,,,,,,,,,, -21700,4.224972,2.023692,,,,,,,,,,,,,, -21800,3.5314906,2.0520623,,,,,,,,,,,,,, -21900,4.2190437,2.0211363,,,,,,,,,,,,,, -22000,5.674565,2.1093857,,,,,,,,,,,,,, -22100,4.0838857,2.2248566,,,,,,,,,,,,,, -22200,4.0369887,2.0264173,,,,,,,,,,,,,, -22300,3.8193326,2.0176256,,,,,,,,,,,,,, -22400,3.0229464,2.1079667,,,,,,,,,,,,,, -22446,,,0.6111288070678711,1.5875287055969238,0.572219967842102,1.8043186664581297,50000.0,0.4520000219345093,2.550936698913574,10000.0,7687.6893701553345,7974.141838550568,7687.6893701553345,285.21574664115906,0.4493522644042969,0.0 -22500,4.7640476,2.027661,,,,,,,,,,,,,, -22600,4.901348,2.091713,,,,,,,,,,,,,, -22700,3.6786962,2.0692816,,,,,,,,,,,,,, -22800,3.0556374,2.0725398,,,,,,,,,,,,,, -22900,4.7092834,2.036696,,,,,,,,,,,,,, -23000,4.2319155,2.0071728,,,,,,,,,,,,,, -23100,3.5155287,2.0695448,,,,,,,,,,,,,, -23200,4.402894,2.0461411,,,,,,,,,,,,,, -23300,4.030592,2.1982672,,,,,,,,,,,,,, -23400,3.5431893,2.1394503,,,,,,,,,,,,,, -23500,4.2888894,2.1753273,,,,,,,,,,,,,, -23600,4.1767197,2.0962663,,,,,,,,,,,,,, -23700,4.0544467,2.0939462,,,,,,,,,,,,,, -23800,4.4796076,2.156255,,,,,,,,,,,,,, -23900,4.190848,2.033574,,,,,,,,,,,,,, -23950,,,0.6212531924247742,1.5448447465896606,0.5806999802589417,1.7550562620162964,50000.0,0.4563000202178955,2.495838165283203,10000.0,8197.739970445633,8501.958456516266,8197.739970445633,302.8937132358551,0.4838879108428955,0.0 -24000,3.781546,2.0759332,,,,,,,,,,,,,, -24100,4.0471706,2.0364041,,,,,,,,,,,,,, -24200,4.0731,2.0264168,,,,,,,,,,,,,, -24300,4.1848645,2.0146055,,,,,,,,,,,,,, -24400,3.526845,2.0394642,,,,,,,,,,,,,, -24500,3.4402914,2.0663428,,,,,,,,,,,,,, -24600,4.357879,2.092758,,,,,,,,,,,,,, -24700,4.2682614,1.991812,,,,,,,,,,,,,, -24800,3.6223388,1.9929996,,,,,,,,,,,,,, -24900,3.613389,2.0734265,,,,,,,,,,,,,, -25000,4.504308,2.031678,,,,,,,,,,,,,, -25100,3.8604982,1.9932401,,,,,,,,,,,,,, -25200,3.9006045,2.0366626,,,,,,,,,,,,,, -25300,4.547307,1.993502,,,,,,,,,,,,,, -25400,4.3587112,2.0540178,,,,,,,,,,,,,, -25455,,,0.6300023794174194,1.4953486919403076,0.5855000019073486,1.7113131284713743,50000.0,0.4638000130653381,2.4377589225769043,10000.0,8707.943544864655,9029.98591208458,8707.943544864655,320.63299918174744,0.5171177387237549,0.0 -25500,4.0521126,1.965327,,,,,,,,,,,,,, -25600,4.0571704,1.9464912,,,,,,,,,,,,,, -25700,3.7653298,2.0489993,,,,,,,,,,,,,, -25800,3.1621013,2.0094817,,,,,,,,,,,,,, -25900,3.6954777,2.1484466,,,,,,,,,,,,,, -26000,3.9476497,2.0103521,,,,,,,,,,,,,, -26100,3.4685776,1.9668487,,,,,,,,,,,,,, -26200,4.0942826,2.049332,,,,,,,,,,,,,, -26300,4.015271,1.9517951,,,,,,,,,,,,,, -26400,4.220454,1.936522,,,,,,,,,,,,,, -26500,3.804019,1.9595397,,,,,,,,,,,,,, -26600,3.0696976,1.9154084,,,,,,,,,,,,,, -26700,3.8088436,1.9977661,,,,,,,,,,,,,, -26800,4.3321543,2.095386,,,,,,,,,,,,,, -26900,3.6781368,2.0554223,,,,,,,,,,,,,, -26960,,,0.6515864133834839,1.3849291801452637,0.5827599763870239,1.7557555437088013,50000.0,0.4606000185012817,2.4801058769226074,10000.0,9218.090360164642,9558.22905421257,9218.090360164642,338.6408226490021,0.5527477264404297,0.0 -27000,4.573027,2.0251882,,,,,,,,,,,,,, -27100,3.2718143,1.8983016,,,,,,,,,,,,,, -27200,3.2337627,1.9148527,,,,,,,,,,,,,, -27300,5.2602353,1.8678337,,,,,,,,,,,,,, -27400,3.8791704,2.0973039,,,,,,,,,,,,,, -27500,3.739907,1.9694095,,,,,,,,,,,,,, -27600,4.0345078,2.0192096,,,,,,,,,,,,,, -27700,4.164549,1.9361947,,,,,,,,,,,,,, -27800,4.3030853,1.9778496,,,,,,,,,,,,,, -27900,3.734059,1.9433962,,,,,,,,,,,,,, -28000,3.5264838,1.9934313,,,,,,,,,,,,,, -28100,3.4488287,2.0437303,,,,,,,,,,,,,, -28200,3.8065753,2.0474913,,,,,,,,,,,,,, -28300,3.9881825,1.982349,,,,,,,,,,,,,, -28400,3.5263846,1.9634067,,,,,,,,,,,,,, -28465,,,0.6382134556770325,1.4549899101257324,0.5848399996757507,1.734142184257507,50000.0,0.4606000185012817,2.44994854927063,10000.0,9728.102895259855,10085.848005533218,9728.102895259855,356.16178250312805,0.5854485034942627,0.0 -28500,3.3604555,1.8371266,,,,,,,,,,,,,, -28600,3.7370238,2.0253599,,,,,,,,,,,,,, -28700,3.5889878,2.0153906,,,,,,,,,,,,,, -28800,3.776688,1.9835131,,,,,,,,,,,,,, -28900,3.3561618,1.893008,,,,,,,,,,,,,, -29000,4.8163366,1.975014,,,,,,,,,,,,,, -29100,3.39612,2.0613232,,,,,,,,,,,,,, -29200,4.7827034,2.0518713,,,,,,,,,,,,,, -29300,3.3844833,2.1497579,,,,,,,,,,,,,, -29400,3.7302246,2.0949104,,,,,,,,,,,,,, -29500,3.597992,1.9655366,,,,,,,,,,,,,, -29600,3.9359593,2.0995238,,,,,,,,,,,,,, -29700,3.5247605,1.9698393,,,,,,,,,,,,,, -29800,4.243985,2.0586677,,,,,,,,,,,,,, -29900,4.456963,1.9526223,,,,,,,,,,,,,, -29971,,,0.6387914419174194,1.470683455467224,0.5878399610519409,1.7269160747528076,50000.0,0.4678000211715698,2.4534895420074463,10000.0,10238.321905136108,10613.808283567429,10238.321905136108,373.8170976638794,0.6196649074554443,0.0 -30000,4.33733,1.9826854,,,,,,,,,,,,,, -30100,4.672706,1.9564342,,,,,,,,,,,,,, -30200,3.7717624,1.9712355,,,,,,,,,,,,,, -30300,3.289637,2.002493,,,,,,,,,,,,,, -30400,3.7317183,2.0432916,,,,,,,,,,,,,, -30500,4.4108644,2.1525736,,,,,,,,,,,,,, -30600,4.1451707,1.9469814,,,,,,,,,,,,,, -30700,3.902689,1.9282762,,,,,,,,,,,,,, -30800,4.261112,2.0559103,,,,,,,,,,,,,, -30900,3.8490758,2.0522304,,,,,,,,,,,,,, -31000,3.383793,2.0206447,,,,,,,,,,,,,, -31100,4.1307983,1.9711193,,,,,,,,,,,,,, -31200,4.173387,1.9863721,,,,,,,,,,,,,, -31300,3.8484526,2.049412,,,,,,,,,,,,,, -31400,3.4371817,1.9147666,,,,,,,,,,,,,, -31476,,,0.6308194994926453,1.480445384979248,0.5906999707221985,1.7181880474090576,50000.0,0.4696000218391418,2.4204015731811523,10000.0,10748.38038611412,11141.573992013931,10748.38038611412,391.4266884326935,0.6621429920196533,0.0 -31500,3.420785,1.9552139,,,,,,,,,,,,,, -31600,3.4911792,1.9251251,,,,,,,,,,,,,, -31700,3.3879488,1.9154049,,,,,,,,,,,,,, -31800,3.8097613,2.023099,,,,,,,,,,,,,, -31900,3.727613,1.9322665,,,,,,,,,,,,,, -32000,3.7287138,2.0151863,,,,,,,,,,,,,, -32100,3.274819,1.927003,,,,,,,,,,,,,, -32200,4.033281,1.8896058,,,,,,,,,,,,,, -32300,3.2601933,1.8812041,,,,,,,,,,,,,, -32400,3.6340332,1.9543606,,,,,,,,,,,,,, -32500,4.167731,2.0250416,,,,,,,,,,,,,, -32600,3.0759692,1.9770054,,,,,,,,,,,,,, -32700,4.084574,1.9274299,,,,,,,,,,,,,, -32800,3.9599128,2.076729,,,,,,,,,,,,,, -32900,3.538557,1.9489621,,,,,,,,,,,,,, -32982,,,0.6412228941917419,1.4475443363189695,0.598639965057373,1.678407907485962,50000.0,0.4759000241756439,2.3977839946746826,10000.0,11258.525603055954,11669.392268419266,11258.525603055954,409.0134968757629,0.6950950622558594,0.0 -33000,3.6875484,1.9356444,,,,,,,,,,,,,, -33100,3.7654936,1.89949,,,,,,,,,,,,,, -33200,3.4757533,1.8318464,,,,,,,,,,,,,, -33300,3.3422613,1.9272305,,,,,,,,,,,,,, -33400,3.3763392,1.9381558,,,,,,,,,,,,,, -33500,3.8663602,1.9690841,,,,,,,,,,,,,, -33600,3.5205846,1.9801959,,,,,,,,,,,,,, -33700,3.8390543,1.9030585,,,,,,,,,,,,,, -33800,3.9724703,2.1068583,,,,,,,,,,,,,, -33900,3.503715,1.9443197,,,,,,,,,,,,,, -34000,3.4737318,2.0112808,,,,,,,,,,,,,, -34100,4.5075316,1.9730545,,,,,,,,,,,,,, -34200,3.769554,1.9735109,,,,,,,,,,,,,, -34300,3.892922,1.9741397,,,,,,,,,,,,,, -34400,4.522958,1.9507601,,,,,,,,,,,,,, -34487,,,0.6408242583274841,1.449402093887329,0.5999000072479248,1.667544960975647,50000.0,0.4757000207901001,2.399554967880249,10000.0,11768.464093208311,12196.891924381256,11768.464093208311,426.486044883728,0.7312412261962891,0.0 -34500,3.2359438,1.9566075,,,,,,,,,,,,,, -34600,3.400955,1.8547888,,,,,,,,,,,,,, -34700,3.7180629,1.9631397,,,,,,,,,,,,,, -34800,3.369701,1.9033681,,,,,,,,,,,,,, -34900,3.3006766,1.9423223,,,,,,,,,,,,,, -35000,3.3985124,1.9786943,,,,,,,,,,,,,, -35100,3.4038432,1.6617475,,,,,,,,,,,,,, -35200,4.4728208,1.9010723,,,,,,,,,,,,,, -35300,3.595992,1.9371204,,,,,,,,,,,,,, -35400,3.6179762,1.8314457,,,,,,,,,,,,,, -35500,3.915561,1.9508595,,,,,,,,,,,,,, -35600,3.4745154,1.855283,,,,,,,,,,,,,, -35700,3.1627624,1.9133154,,,,,,,,,,,,,, -35800,3.3398743,1.9512734,,,,,,,,,,,,,, -35900,3.6318078,1.9317006,,,,,,,,,,,,,, -35993,,,0.668965220451355,1.3015565872192385,0.5925999879837036,1.6919569969177246,50000.0,0.4733000099658966,2.408229112625122,10000.0,12278.674576044084,12725.251027584076,12278.674576044084,444.545640707016,0.7679879665374756,0.0 -36000,5.103674,1.8966502,,,,,,,,,,,,,, -36100,3.8497536,2.0284154,,,,,,,,,,,,,, -36200,3.1838152,1.8974386,,,,,,,,,,,,,, -36300,3.8263314,1.9594088,,,,,,,,,,,,,, -36400,4.3586326,1.9341578,,,,,,,,,,,,,, -36500,4.365921,1.9384336,,,,,,,,,,,,,, -36600,4.0318694,1.9161526,,,,,,,,,,,,,, -36700,3.6971729,1.8553075,,,,,,,,,,,,,, -36800,3.4277232,1.9759631,,,,,,,,,,,,,, -36900,3.2859886,1.9153402,,,,,,,,,,,,,, -37000,3.5172946,2.0612671,,,,,,,,,,,,,, -37100,3.2024825,1.9963317,,,,,,,,,,,,,, -37200,3.6149704,1.946553,,,,,,,,,,,,,, -37300,3.6203554,2.0267184,,,,,,,,,,,,,, -37400,3.9505506,1.9796504,,,,,,,,,,,,,, -37498,,,0.6574656963348389,1.356908917427063,0.6019399762153625,1.649037480354309,50000.0,0.480400025844574,2.377382278442383,10000.0,12788.770513534546,13253.0048995018,12788.770513534546,462.1173481941223,0.802004337310791,0.0 -37500,4.0347395,1.9291161,,,,,,,,,,,,,, -37600,3.7154837,1.9535624,,,,,,,,,,,,,, -37700,3.9076676,1.9279389,,,,,,,,,,,,,, -37800,3.4673555,1.99596,,,,,,,,,,,,,, -37900,3.6776063,1.9167862,,,,,,,,,,,,,, -38000,3.605169,1.8740274,,,,,,,,,,,,,, -38100,4.86987,1.974938,,,,,,,,,,,,,, -38200,3.603687,1.8467927,,,,,,,,,,,,,, -38300,3.121402,1.8176706,,,,,,,,,,,,,, -38400,3.4097047,1.8510377,,,,,,,,,,,,,, -38500,3.842001,1.8729596,,,,,,,,,,,,,, -38600,3.5828273,1.8529702,,,,,,,,,,,,,, -38700,3.5097954,1.9100432,,,,,,,,,,,,,, -38800,3.646894,2.017982,,,,,,,,,,,,,, -38900,3.2651722,1.8505917,,,,,,,,,,,,,, -39000,3.378488,2.0098085,,,,,,,,,,,,,, -39004,,,0.6500119566917419,1.40574049949646,0.5988199710845947,1.6665188074111938,50000.0,0.4722000360488891,2.388904571533203,10000.0,13298.82503771782,13780.851271390917,13298.82503771782,479.8211030960083,0.8365299701690674,0.0 -39100,3.704543,1.9471531,,,,,,,,,,,,,, -39200,3.7630334,1.8698426,,,,,,,,,,,,,, -39300,3.6789355,1.9832858,,,,,,,,,,,,,, -39400,3.4628115,1.8228112,,,,,,,,,,,,,, -39500,3.375383,1.9294084,,,,,,,,,,,,,, -39600,3.4042335,1.8430895,,,,,,,,,,,,,, -39700,3.6880531,1.939328,,,,,,,,,,,,,, -39800,4.6047573,1.9180725,,,,,,,,,,,,,, -39900,4.6284194,1.8678334,,,,,,,,,,,,,, -40000,4.0872893,1.8531362,,,,,,,,,,,,,, -40100,3.7682202,1.880717,,,,,,,,,,,,,, -40200,3.355747,1.8806205,,,,,,,,,,,,,, -40300,3.3630722,1.8797337,,,,,,,,,,,,,, -40400,3.8646886,1.989495,,,,,,,,,,,,,, -40500,3.9764726,1.9030557,,,,,,,,,,,,,, -40510,,,0.6475805044174194,1.4063873291015625,0.6014999747276306,1.6467198133468628,50000.0,0.4780000150203705,2.3571691513061523,10000.0,13808.78684091568,14308.921383857729,13808.78684091568,497.8423953056336,0.8696949481964111,0.0 -40600,3.8262842,1.9143292,,,,,,,,,,,,,, -40700,3.873002,1.9412291,,,,,,,,,,,,,, -40800,3.5839577,1.9672307,,,,,,,,,,,,,, -40900,3.413704,2.0039783,,,,,,,,,,,,,, -41000,5.1481266,1.8257415,,,,,,,,,,,,,, -41100,3.9801297,2.030334,,,,,,,,,,,,,, -41200,3.5426164,1.9380199,,,,,,,,,,,,,, -41300,4.219268,1.9662043,,,,,,,,,,,,,, -41400,4.161106,2.0268683,,,,,,,,,,,,,, -41500,3.3666306,1.836204,,,,,,,,,,,,,, -41600,3.3731153,1.9498245,,,,,,,,,,,,,, -41700,3.680145,1.984288,,,,,,,,,,,,,, -41800,3.8112662,1.8984784,,,,,,,,,,,,,, -41900,3.363484,1.9935826,,,,,,,,,,,,,, -42000,3.5160584,1.9243727,,,,,,,,,,,,,, -42017,,,0.6520049571990967,1.406828999519348,0.6034199595451355,1.6507673263549805,50000.0,0.4809000194072723,2.4058632850646973,10000.0,14318.928505182266,14836.609621286392,14318.928505182266,515.2958283424377,0.9089441299438475,0.0 -42100,3.5262258,1.9847499,,,,,,,,,,,,,, -42200,4.0480003,2.019394,,,,,,,,,,,,,, -42300,3.5340443,1.853649,,,,,,,,,,,,,, -42400,4.853725,1.9893451,,,,,,,,,,,,,, -42500,3.8772047,1.8633076,,,,,,,,,,,,,, -42600,3.7732782,1.8798873,,,,,,,,,,,,,, -42700,4.358775,1.9677702,,,,,,,,,,,,,, -42800,3.8514411,1.9186814,,,,,,,,,,,,,, -42900,3.7851539,1.9292438,,,,,,,,,,,,,, -43000,3.7369957,1.8316785,,,,,,,,,,,,,, -43100,4.1382914,1.8749,,,,,,,,,,,,,, -43200,4.1716237,1.8102335,,,,,,,,,,,,,, -43300,4.07108,1.7966521,,,,,,,,,,,,,, -43400,3.6403248,2.0224319,,,,,,,,,,,,,, -43500,3.631296,1.9738672,,,,,,,,,,,,,, -43524,,,0.6497727632522583,1.3987377882003784,0.6049000024795532,1.6419485807418823,50000.0,0.4820000231266022,2.3409223556518555,10000.0,14829.094542264938,15364.478039741516,14829.094542264938,532.904883146286,0.9475512504577636,0.0 -43600,3.2011883,1.8097069,,,,,,,,,,,,,, -43700,3.1584573,1.9292089,,,,,,,,,,,,,, -43800,3.2669432,2.0169454,,,,,,,,,,,,,, -43900,3.611048,2.0333521,,,,,,,,,,,,,, -44000,3.2799675,1.9621667,,,,,,,,,,,,,, -44100,3.6883812,2.0020885,,,,,,,,,,,,,, -44200,4.357527,1.9521272,,,,,,,,,,,,,, -44300,3.6753073,1.88969,,,,,,,,,,,,,, -44400,3.8424418,1.981469,,,,,,,,,,,,,, -44500,3.6176162,1.8663926,,,,,,,,,,,,,, -44600,4.7133675,1.8938394,,,,,,,,,,,,,, -44700,3.6071932,1.8207333,,,,,,,,,,,,,, -44800,4.340813,1.8974878,,,,,,,,,,,,,, -44900,4.090591,1.8830391,,,,,,,,,,,,,, -45000,4.114465,1.9396517,,,,,,,,,,,,,, -45031,,,0.6812220811843872,1.246896505355835,0.6029199957847595,1.6404823064804075,50000.0,0.4817000329494476,2.35010290145874,10000.0,15339.251924276352,15892.332380533218,15339.251924276352,550.5144395828247,0.9823529720306396,0.0 -45100,4.2628326,1.948144,,,,,,,,,,,,,, -45200,4.534143,1.9901977,,,,,,,,,,,,,, -45300,4.209663,1.8883135,,,,,,,,,,,,,, -45400,5.2720356,1.9986577,,,,,,,,,,,,,, -45500,3.9591014,1.925394,,,,,,,,,,,,,, -45600,3.8939078,1.8889171,,,,,,,,,,,,,, -45700,3.338023,1.8685595,,,,,,,,,,,,,, -45800,4.1770782,1.9153782,,,,,,,,,,,,,, -45900,4.149512,1.7473068,,,,,,,,,,,,,, -46000,4.2704363,1.9286243,,,,,,,,,,,,,, -46100,3.912316,1.8969895,,,,,,,,,,,,,, -46200,3.9695456,2.1063097,,,,,,,,,,,,,, -46300,4.482665,1.9759026,,,,,,,,,,,,,, -46400,3.6261904,1.8604778,,,,,,,,,,,,,, -46500,3.871344,1.948381,,,,,,,,,,,,,, -46537,,,0.6583425998687744,1.357541561126709,0.600600004196167,1.658661723136902,50000.0,0.4793000221252441,2.3764281272888184,10000.0,15849.248585939407,16420.061936855316,15849.248585939407,568.1608099937439,1.0166702270507812,0.0 -46600,3.7186108,1.9206331,,,,,,,,,,,,,, -46700,4.0428863,1.9216177,,,,,,,,,,,,,, -46800,3.5582736,1.8092024,,,,,,,,,,,,,, -46900,4.0791755,1.9294422,,,,,,,,,,,,,, -47000,3.5580366,1.88474,,,,,,,,,,,,,, -47100,4.3057427,2.0210438,,,,,,,,,,,,,, -47200,3.957796,1.839609,,,,,,,,,,,,,, -47300,4.077636,1.9226439,,,,,,,,,,,,,, -47400,3.7719104,1.9447484,,,,,,,,,,,,,, -47500,3.423679,1.8609571,,,,,,,,,,,,,, -47600,3.482052,1.8913765,,,,,,,,,,,,,, -47700,4.5113487,1.9564047,,,,,,,,,,,,,, -47800,4.222157,2.0229096,,,,,,,,,,,,,, -47900,4.333299,1.9686867,,,,,,,,,,,,,, -48000,3.724629,1.9150733,,,,,,,,,,,,,, -48044,,,0.6581233739852905,1.347723126411438,0.6038599610328674,1.6352239847183228,50000.0,0.4882000088691711,2.325078010559082,10000.0,16359.366846084597,16948.598692417145,16359.366846084597,586.4911158084869,1.051117181777954,0.0 -48100,3.9804747,2.049993,,,,,,,,,,,,,, -48200,3.6523213,1.7266058,,,,,,,,,,,,,, -48300,3.6251302,1.9395362,,,,,,,,,,,,,, -48400,3.991004,1.9187541,,,,,,,,,,,,,, -48500,3.9168143,1.7304556,,,,,,,,,,,,,, -48600,4.0607786,1.9350543,,,,,,,,,,,,,, -48700,3.8449926,1.7432064,,,,,,,,,,,,,, -48800,3.8259947,1.9003503,,,,,,,,,,,,,, -48900,3.4710164,1.7985531,,,,,,,,,,,,,, -49000,4.0556993,1.9851327,,,,,,,,,,,,,, -49100,3.6333678,2.0098372,,,,,,,,,,,,,, -49200,4.6800303,1.7835777,,,,,,,,,,,,,, -49300,3.3229187,1.8731242,,,,,,,,,,,,,, -49400,3.8747702,1.7916169,,,,,,,,,,,,,, -49500,3.7281485,2.0374196,,,,,,,,,,,,,, -49551,,,0.6349250674247742,1.4702372550964355,0.5914999842643738,1.712287425994873,50000.0,0.4693000316619873,2.4408493041992188,10000.0,16869.45350933075,17476.268760681152,16869.45350933075,603.9879791736603,1.0860350131988523,0.0 -49600,5.078233,1.8984584,,,,,,,,,,,,,, -49700,4.140148,1.9064949,,,,,,,,,,,,,, -49800,3.5871108,1.7688228,,,,,,,,,,,,,, -49900,4.0479245,1.9402326,,,,,,,,,,,,,, -50000,3.341057,1.8547561,,,,,,,,,,,,,, -50100,3.242098,1.8653519,,,,,,,,,,,,,, -50200,3.7535698,1.9679734,,,,,,,,,,,,,, -50300,4.689464,1.8981049,,,,,,,,,,,,,, -50400,3.5488758,1.8778622,,,,,,,,,,,,,, -50500,4.1456614,1.8657382,,,,,,,,,,,,,, -50600,4.2999234,1.8008882,,,,,,,,,,,,,, -50700,4.0710807,1.8832377,,,,,,,,,,,,,, -50800,3.343669,1.8842219,,,,,,,,,,,,,, -50900,3.6036694,1.6847403,,,,,,,,,,,,,, -51000,3.9404473,1.8244942,,,,,,,,,,,,,, -51058,,,0.6541573405265808,1.3928905725479126,0.6054399609565735,1.634318232536316,50000.0,0.4838000237941742,2.329887628555298,10000.0,17379.67597913742,18003.970739126205,17379.67597913742,621.3797891139984,1.1208219528198242,0.0 -51100,4.5992537,1.8122648,,,,,,,,,,,,,, -51200,4.3970623,1.8849667,,,,,,,,,,,,,, -51300,3.482164,2.0097163,,,,,,,,,,,,,, -51400,3.0963278,1.9185841,,,,,,,,,,,,,, -51500,4.72976,2.0867853,,,,,,,,,,,,,, -51600,3.908988,1.8570337,,,,,,,,,,,,,, -51700,3.8122845,1.8708873,,,,,,,,,,,,,, -51800,3.807851,1.8510802,,,,,,,,,,,,,, -51900,3.478249,1.8838128,,,,,,,,,,,,,, -52000,3.9791558,1.847489,,,,,,,,,,,,,, -52100,4.198033,1.84461,,,,,,,,,,,,,, -52200,3.2269006,1.790124,,,,,,,,,,,,,, -52300,5.099823,1.8224125,,,,,,,,,,,,,, -52400,4.1520257,1.8228157,,,,,,,,,,,,,, -52500,3.9473805,1.8849285,,,,,,,,,,,,,, -52566,,,0.6519849896430969,1.396129488945007,0.6064199805259705,1.6361829042434692,50000.0,0.4845000207424164,2.343546152114868,10000.0,17889.827553749084,18531.911148548126,17889.827553749084,639.0722250938416,1.1622974872589111,0.0 -52600,4.100101,1.8236924,,,,,,,,,,,,,, -52700,3.7236798,1.9317925,,,,,,,,,,,,,, -52800,4.676104,1.8525763,,,,,,,,,,,,,, -52900,4.5332775,1.8569413,,,,,,,,,,,,,, -53000,3.6532583,1.860101,,,,,,,,,,,,,, -53100,4.2403173,1.8310931,,,,,,,,,,,,,, -53200,3.3876717,1.9030443,,,,,,,,,,,,,, -53300,4.167019,1.9376378,,,,,,,,,,,,,, -53400,3.5086544,1.871673,,,,,,,,,,,,,, -53500,4.763943,1.9518213,,,,,,,,,,,,,, -53600,3.2888978,1.9552066,,,,,,,,,,,,,, -53700,3.5896013,1.9049454,,,,,,,,,,,,,, -53800,3.3882887,1.9456526,,,,,,,,,,,,,, -53900,3.576158,1.8850416,,,,,,,,,,,,,, -54000,3.9606104,1.8145359,,,,,,,,,,,,,, -54074,,,0.6962292790412903,1.2013616561889648,0.6180799603462219,1.5881798267364502,50000.0,0.4937000274658203,2.290510654449463,10000.0,18400.011061429977,19059.650450468063,18400.011061429977,656.536458492279,1.201387643814087,0.0 -54100,3.4159963,1.8383527,,,,,,,,,,,,,, -54200,4.2583313,1.8710855,,,,,,,,,,,,,, -54300,3.5325418,1.9093535,,,,,,,,,,,,,, -54400,4.1631613,1.7939106,,,,,,,,,,,,,, -54500,3.7576244,1.8584942,,,,,,,,,,,,,, -54600,4.0719843,1.8307201,,,,,,,,,,,,,, -54700,3.7197704,1.8518447,,,,,,,,,,,,,, -54800,3.6201158,1.6997719,,,,,,,,,,,,,, -54900,4.090229,1.8512009,,,,,,,,,,,,,, -55000,4.183418,1.8003546,,,,,,,,,,,,,, -55100,3.35198,1.810655,,,,,,,,,,,,,, -55200,3.779303,1.8482698,,,,,,,,,,,,,, -55300,3.6184676,1.9181385,,,,,,,,,,,,,, -55400,4.2451086,1.8065,,,,,,,,,,,,,, -55500,3.3654318,1.7471915,,,,,,,,,,,,,, -55581,,,0.680086076259613,1.259409785270691,0.6174600124359131,1.571677565574646,50000.0,0.5017000436782837,2.253861904144287,10000.0,18910.092315673828,19587.41103410721,18910.092315673828,674.1192197799683,1.2447161674499512,0.0 -55600,4.039318,1.8075455,,,,,,,,,,,,,, -55700,3.823744,1.7656617,,,,,,,,,,,,,, -55800,3.6353002,1.8911221,,,,,,,,,,,,,, -55900,3.7399259,1.9111063,,,,,,,,,,,,,, -56000,3.8697908,1.8566241,,,,,,,,,,,,,, -56100,3.9397345,1.9096655,,,,,,,,,,,,,, -56200,3.83084,1.8184199,,,,,,,,,,,,,, -56300,4.357963,1.8301402,,,,,,,,,,,,,, -56400,4.6726885,1.7718401,,,,,,,,,,,,,, -56500,3.7714448,1.8857633,,,,,,,,,,,,,, -56600,3.6396317,1.8308554,,,,,,,,,,,,,, -56700,3.4489672,1.8794664,,,,,,,,,,,,,, -56800,4.5014567,1.7715724,,,,,,,,,,,,,, -56900,3.9361532,1.8188599,,,,,,,,,,,,,, -57000,4.0524035,1.8758833,,,,,,,,,,,,,, -57089,,,0.6758211255073547,1.280925989151001,0.6241399645805359,1.5517821311950684,50000.0,0.4967000186443329,2.2936458587646484,10000.0,19420.2594435215,20115.578449726105,19420.2594435215,692.030259847641,1.281590461730957,0.0 -57100,3.7800546,1.8971844,,,,,,,,,,,,,, -57200,4.037833,1.9041,,,,,,,,,,,,,, -57300,3.4779809,1.7788101,,,,,,,,,,,,,, -57400,3.892965,1.9057614,,,,,,,,,,,,,, -57500,4.1386056,1.8458889,,,,,,,,,,,,,, -57600,4.099463,1.851531,,,,,,,,,,,,,, -57700,3.7457373,1.8307168,,,,,,,,,,,,,, -57800,3.6991725,1.8707818,,,,,,,,,,,,,, -57900,4.012762,1.6870741,,,,,,,,,,,,,, -58000,3.5965695,1.8971416,,,,,,,,,,,,,, -58100,3.6064382,1.8958248,,,,,,,,,,,,,, -58200,3.8533187,1.8325372,,,,,,,,,,,,,, -58300,3.9721625,1.8034477,,,,,,,,,,,,,, -58400,4.1132374,1.8589545,,,,,,,,,,,,,, -58500,4.079478,1.8619653,,,,,,,,,,,,,, -58596,,,0.6629464030265808,1.3316421508789062,0.6131399869918823,1.595544934272766,50000.0,0.49590003490448,2.289567470550537,10000.0,19930.20670747757,20643.429992198944,19930.20670747757,709.8440093994141,1.3187220096588137,0.0 -58600,3.9711473,1.7569642,,,,,,,,,,,,,, -58700,4.7646523,1.8331771,,,,,,,,,,,,,, -58800,3.5244255,1.8329862,,,,,,,,,,,,,, -58900,4.210716,1.8507333,,,,,,,,,,,,,, -59000,4.2570615,1.8004644,,,,,,,,,,,,,, -59100,4.3104863,1.8216587,,,,,,,,,,,,,, -59200,4.5856285,1.8356957,,,,,,,,,,,,,, -59300,4.2198753,1.7883577,,,,,,,,,,,,,, -59400,4.0593657,1.941771,,,,,,,,,,,,,, -59500,4.406976,1.7907118,,,,,,,,,,,,,, -59600,4.178593,1.9230957,,,,,,,,,,,,,, -59700,3.7103398,1.7940608,,,,,,,,,,,,,, -59800,3.7948608,1.7519615,,,,,,,,,,,,,, -59900,4.0073776,1.8272254,,,,,,,,,,,,,, -60000,4.0840044,1.8391849,,,,,,,,,,,,,, -60100,3.7662015,1.7233318,,,,,,,,,,,,,, -60104,,,0.6681082248687744,1.3214025497436523,0.6195399761199951,1.5713294744491575,50000.0,0.4999000132083893,2.261572360992432,10000.0,20440.350209712986,21171.72056317329,20440.350209712986,727.8984732627869,1.35973858833313,0.0 -60200,4.165618,1.8726376,,,,,,,,,,,,,, -60300,4.500872,1.9195852,,,,,,,,,,,,,, -60400,3.67619,1.9287126,,,,,,,,,,,,,, -60500,3.8481584,1.7535356,,,,,,,,,,,,,, -60600,3.614973,1.9379938,,,,,,,,,,,,,, -60700,4.179482,1.7906853,,,,,,,,,,,,,, -60800,3.8103385,1.772399,,,,,,,,,,,,,, -60900,4.2054086,2.0451002,,,,,,,,,,,,,, -61000,3.8893328,1.9387674,,,,,,,,,,,,,, -61100,3.4539702,1.9837818,,,,,,,,,,,,,, -61200,3.8700147,1.8554476,,,,,,,,,,,,,, -61300,4.125886,1.8738459,,,,,,,,,,,,,, -61400,4.811831,1.8195964,,,,,,,,,,,,,, -61500,3.4082794,1.8581703,,,,,,,,,,,,,, -61600,3.712785,1.7951447,,,,,,,,,,,,,, -61612,,,0.6642019748687744,1.337678074836731,0.6198999881744385,1.577870488166809,50000.0,0.4922000169754028,2.2768988609313965,10000.0,20950.476552248,21699.60069823265,20950.476552248,745.5649440288544,1.3957176208496094,0.0 -61700,3.5025032,1.7560368,,,,,,,,,,,,,, -61800,3.4680712,1.815761,,,,,,,,,,,,,, -61900,3.5457764,1.8117268,,,,,,,,,,,,,, -62000,3.6263063,1.7334534,,,,,,,,,,,,,, -62100,3.9077551,1.823132,,,,,,,,,,,,,, -62200,3.6977339,1.7538688,,,,,,,,,,,,,, -62300,3.4872313,1.8192823,,,,,,,,,,,,,, -62400,3.9440463,1.9341726,,,,,,,,,,,,,, -62500,4.597421,1.8704631,,,,,,,,,,,,,, -62600,4.363708,1.8394251,,,,,,,,,,,,,, -62700,3.8108912,1.8439287,,,,,,,,,,,,,, -62800,4.0131865,1.883743,,,,,,,,,,,,,, -62900,3.5775256,1.8921465,,,,,,,,,,,,,, -63000,3.7160268,1.681649,,,,,,,,,,,,,, -63100,5.139548,1.9736557,,,,,,,,,,,,,, -63120,,,0.6990194320678711,1.1689326763153076,0.6167399883270264,1.5819714069366455,50000.0,0.4999000132083893,2.285916805267334,10000.0,21460.725561141968,22227.46119689941,21460.725561141968,763.08553647995,1.4341421127319336,0.0 -63200,3.775115,1.8431333,,,,,,,,,,,,,, -63300,4.77426,1.7357802,,,,,,,,,,,,,, -63400,3.5437615,1.7766457,,,,,,,,,,,,,, -63500,4.317065,1.8404081,,,,,,,,,,,,,, -63600,3.7657788,1.8920768,,,,,,,,,,,,,, -63700,3.964555,1.8634483,,,,,,,,,,,,,, -63800,3.9358478,1.7761108,,,,,,,,,,,,,, -63900,3.6468332,1.7979658,,,,,,,,,,,,,, -64000,3.6330893,1.7248658,,,,,,,,,,,,,, -64100,4.522793,1.7763464,,,,,,,,,,,,,, -64200,3.903381,1.6451797,,,,,,,,,,,,,, -64300,4.0315695,1.7544253,,,,,,,,,,,,,, -64400,3.5300868,1.7391512,,,,,,,,,,,,,, -64500,3.5828602,1.8204279,,,,,,,,,,,,,, -64600,3.4577444,1.7717335,,,,,,,,,,,,,, -64628,,,0.6875,1.2308303117752075,0.6247199773788452,1.5516197681427002,50000.0,0.4997000098228454,2.266094207763672,10000.0,21970.70988535881,22755.408446788788,21970.70988535881,780.9508357048035,1.477266550064087,0.0 -64700,4.026586,1.7837776,,,,,,,,,,,,,, -64800,3.9943664,1.7759898,,,,,,,,,,,,,, -64900,4.3224673,1.7941158,,,,,,,,,,,,,, -65000,4.9145784,1.8036321,,,,,,,,,,,,,, -65100,3.5600684,1.8735968,,,,,,,,,,,,,, -65200,4.4335036,1.8138922,,,,,,,,,,,,,, -65300,3.9659488,1.9094703,,,,,,,,,,,,,, -65400,3.5504596,1.8435032,,,,,,,,,,,,,, -65500,4.2003303,1.8397052,,,,,,,,,,,,,, -65600,4.0413823,1.7715291,,,,,,,,,,,,,, -65700,3.6492538,1.7246616,,,,,,,,,,,,,, -65800,3.4848006,1.7696283,,,,,,,,,,,,,, -65900,3.6390486,1.7781588,,,,,,,,,,,,,, -66000,4.058332,1.8813393,,,,,,,,,,,,,, -66100,3.9919238,1.7435771,,,,,,,,,,,,,, -66136,,,0.6802455186843872,1.267142415046692,0.6260600090026855,1.5435158014297483,50000.0,0.5041000247001648,2.2609455585479736,10000.0,22480.74132156372,23283.03837132454,22480.74132156372,798.4543952941895,1.5191307067871094,0.0 -66200,4.3972445,1.8984635,,,,,,,,,,,,,, -66300,3.7475197,1.8593307,,,,,,,,,,,,,, -66400,3.8412902,1.8062181,,,,,,,,,,,,,, -66500,3.9826126,1.7486826,,,,,,,,,,,,,, -66600,3.5983071,1.8072474,,,,,,,,,,,,,, -66700,4.513105,1.8800124,,,,,,,,,,,,,, -66800,4.581377,1.8086336,,,,,,,,,,,,,, -66900,3.9777734,1.8555248,,,,,,,,,,,,,, -67000,4.167978,1.8224646,,,,,,,,,,,,,, -67100,3.6290138,1.8076648,,,,,,,,,,,,,, -67200,4.4729958,1.7783828,,,,,,,,,,,,,, -67300,4.319245,1.7418956,,,,,,,,,,,,,, -67400,3.6694167,1.7289093,,,,,,,,,,,,,, -67500,4.7683144,1.7749465,,,,,,,,,,,,,, -67600,3.5928042,1.8273938,,,,,,,,,,,,,, -67644,,,0.6720344424247742,1.2921645641326904,0.624779999256134,1.5412269830703735,50000.0,0.4953000247478485,2.2861030101776123,10000.0,22990.79816222191,23810.67390537262,22990.79816222191,815.9394180774689,1.5590641498565674,0.0 -67700,3.7506707,1.7134442,,,,,,,,,,,,,, -67800,4.434529,1.8260705,,,,,,,,,,,,,, -67900,4.229591,1.769142,,,,,,,,,,,,,, -68000,3.7799807,1.8301582,,,,,,,,,,,,,, -68100,4.228023,1.7698435,,,,,,,,,,,,,, -68200,3.880984,1.7818632,,,,,,,,,,,,,, -68300,4.171462,1.7450911,,,,,,,,,,,,,, -68400,4.493407,1.7896631,,,,,,,,,,,,,, -68500,4.588711,1.8740704,,,,,,,,,,,,,, -68600,4.716364,1.725522,,,,,,,,,,,,,, -68700,4.1535745,1.7948635,,,,,,,,,,,,,, -68800,4.106507,1.7846553,,,,,,,,,,,,,, -68900,3.6293755,1.8366683,,,,,,,,,,,,,, -69000,3.6648092,1.7190156,,,,,,,,,,,,,, -69100,4.458842,1.9136685,,,,,,,,,,,,,, -69152,,,0.6754822731018066,1.2919354438781738,0.626800000667572,1.539273738861084,50000.0,0.4996000230312347,2.2561073303222656,10000.0,23500.92076444626,24338.25503468513,23500.92076444626,833.3083035945892,1.5970518589019775,0.0 -69200,3.6135473,1.7077749,,,,,,,,,,,,,, -69300,4.625569,1.7251794,,,,,,,,,,,,,, -69400,5.3598537,1.6783098,,,,,,,,,,,,,, -69500,4.0360446,1.8295066,,,,,,,,,,,,,, -69600,3.5053759,1.6478133,,,,,,,,,,,,,, -69700,3.8355742,1.7400637,,,,,,,,,,,,,, -69800,4.719228,1.8207408,,,,,,,,,,,,,, -69900,4.7756586,1.8084278,,,,,,,,,,,,,, -70000,4.2926183,1.7976996,,,,,,,,,,,,,, -70100,4.164242,1.7257984,,,,,,,,,,,,,, -70200,3.5309763,1.7065375,,,,,,,,,,,,,, -70300,4.4216323,1.7150406,,,,,,,,,,,,,, -70400,4.108333,1.792548,,,,,,,,,,,,,, -70500,4.004658,1.743326,,,,,,,,,,,,,, -70600,3.8744295,1.8027425,,,,,,,,,,,,,, -70660,,,0.6835139989852905,1.258325219154358,0.6360599994659424,1.496769666671753,50000.0,0.5056000351905823,2.218758344650269,10000.0,24011.123149871823,24866.50309228897,24011.123149871823,851.2584345340729,1.6391994953155518,0.0 -70700,3.620082,1.7048607,,,,,,,,,,,,,, -70800,3.7314615,1.7457944,,,,,,,,,,,,,, -70900,4.451253,1.8354621,,,,,,,,,,,,,, -71000,4.044697,1.849945,,,,,,,,,,,,,, -71100,4.1405253,1.9149451,,,,,,,,,,,,,, -71200,3.6995075,1.8659583,,,,,,,,,,,,,, -71300,3.839236,1.6889275,,,,,,,,,,,,,, -71400,3.7376094,1.7601181,,,,,,,,,,,,,, -71500,3.8492832,1.8329977,,,,,,,,,,,,,, -71600,3.6808577,1.889401,,,,,,,,,,,,,, -71700,4.056039,1.7318723,,,,,,,,,,,,,, -71800,4.273508,1.7946438,,,,,,,,,,,,,, -71900,4.651276,1.786656,,,,,,,,,,,,,, -72000,4.30345,1.8748734,,,,,,,,,,,,,, -72100,3.4624193,1.7891569,,,,,,,,,,,,,, -72168,,,0.7145846486091614,1.0963265895843506,0.6308599710464478,1.5195053815841677,50000.0,0.5113000273704529,2.2167670726776123,10000.0,24521.15534877777,25394.25145077705,24521.15534877777,868.8834428787231,1.677199125289917,0.0 -72200,3.518998,1.7032962,,,,,,,,,,,,,, -72300,4.4538445,1.8177392,,,,,,,,,,,,,, -72400,4.2271037,1.8264163,,,,,,,,,,,,,, -72500,4.075352,1.9037656,,,,,,,,,,,,,, -72600,4.071859,1.9247054,,,,,,,,,,,,,, -72700,4.150609,1.7675949,,,,,,,,,,,,,, -72800,4.093104,1.7539643,,,,,,,,,,,,,, -72900,3.9929292,1.7344661,,,,,,,,,,,,,, -73000,3.6992772,1.7481921,,,,,,,,,,,,,, -73100,3.9850128,1.7822753,,,,,,,,,,,,,, -73200,4.038803,1.7414314,,,,,,,,,,,,,, -73300,4.2232966,1.7050319,,,,,,,,,,,,,, -73400,4.151448,1.8086662,,,,,,,,,,,,,, -73500,4.4891453,1.8008316,,,,,,,,,,,,,, -73600,5.0922174,1.768318,,,,,,,,,,,,,, -73676,,,0.698660671710968,1.1710692644119265,0.6325399875640869,1.5120478868484497,50000.0,0.5099000334739685,2.2187116146087646,10000.0,25031.063472747803,25921.97821545601,25031.063472747803,886.6075978279114,1.71752667427063,0.0 -73700,3.8175123,1.7996519,,,,,,,,,,,,,, -73800,3.7102206,1.8668296,,,,,,,,,,,,,, -73900,4.278139,1.7688135,,,,,,,,,,,,,, -74000,4.233358,1.7157843,,,,,,,,,,,,,, -74100,3.8374743,1.738197,,,,,,,,,,,,,, -74200,3.6067393,1.7441028,,,,,,,,,,,,,, -74300,4.8460703,1.8107188,,,,,,,,,,,,,, -74400,4.065711,1.8827153,,,,,,,,,,,,,, -74500,3.9591238,1.7825824,,,,,,,,,,,,,, -74600,3.8996127,1.8351936,,,,,,,,,,,,,, -74700,4.2121835,1.8211817,,,,,,,,,,,,,, -74800,4.366143,1.6565015,,,,,,,,,,,,,, -74900,4.0304284,1.8323331,,,,,,,,,,,,,, -75000,4.5750103,1.8620507,,,,,,,,,,,,,, -75100,3.7370002,1.7304736,,,,,,,,,,,,,, -75184,,,0.6887555718421936,1.2242296934127808,0.6289199590682983,1.5229800939559937,50000.0,0.5022000074386597,2.28255581855774,10000.0,25541.049762249,26449.85903573036,25541.049762249,904.4073250293732,1.7588274478912354,0.0 -75200,3.6969583,1.7411438,,,,,,,,,,,,,, -75300,4.063019,1.7027574,,,,,,,,,,,,,, -75400,3.4957347,1.6258868,,,,,,,,,,,,,, -75500,4.289809,1.7518226,,,,,,,,,,,,,, -75600,4.540641,1.7905822,,,,,,,,,,,,,, -75700,4.326928,1.8008766,,,,,,,,,,,,,, -75800,4.639347,1.7918545,,,,,,,,,,,,,, -75900,4.5268545,1.806392,,,,,,,,,,,,,, -76000,3.9176862,1.6943197,,,,,,,,,,,,,, -76100,4.0417156,1.7191675,,,,,,,,,,,,,, -76200,4.476136,1.5900033,,,,,,,,,,,,,, -76300,4.194317,1.6626012,,,,,,,,,,,,,, -76400,4.357486,1.7815101,,,,,,,,,,,,,, -76500,3.9042463,1.695591,,,,,,,,,,,,,, -76600,4.2776346,1.8146858,,,,,,,,,,,,,, -76692,,,0.6845503449440002,1.2397161722183228,0.6326799988746643,1.4972995519638062,50000.0,0.5051000118255615,2.2224087715148926,10000.0,26051.24068403244,26977.98685336113,26051.24068403244,922.2355952262878,1.813603639602661,0.0 -76700,3.920456,1.8675185,,,,,,,,,,,,,, -76800,4.032026,1.739934,,,,,,,,,,,,,, -76900,3.992158,1.6965442,,,,,,,,,,,,,, -77000,4.3512635,1.6549609,,,,,,,,,,,,,, -77100,4.0698943,1.7722805,,,,,,,,,,,,,, -77200,3.7642694,1.6557436,,,,,,,,,,,,,, -77300,3.917099,1.7591796,,,,,,,,,,,,,, -77400,5.532931,1.78704,,,,,,,,,,,,,, -77500,5.105005,1.7764275,,,,,,,,,,,,,, -77600,4.1306295,1.7226657,,,,,,,,,,,,,, -77700,4.2963896,1.7170236,,,,,,,,,,,,,, -77800,4.051342,1.7934885,,,,,,,,,,,,,, -77900,3.9599662,1.6703693,,,,,,,,,,,,,, -78000,4.4526815,1.6908292,,,,,,,,,,,,,, -78100,4.34778,1.6189586,,,,,,,,,,,,,, -78200,,,0.6881178021430969,1.235331654548645,0.6344999670982361,1.5009933710098269,50000.0,0.5103999972343445,2.199228525161743,10000.0,26561.368300914764,27505.897212982178,26561.368300914764,939.9247233867644,1.8545491695404053,0.0 -78200,3.827973,1.7198726,,,,,,,,,,,,,, -78300,4.6858406,1.7531731,,,,,,,,,,,,,, -78400,3.9912236,1.6718788,,,,,,,,,,,,,, -78500,4.3235893,1.619461,,,,,,,,,,,,,, -78600,4.3511877,1.7838672,,,,,,,,,,,,,, -78700,4.205593,1.7454135,,,,,,,,,,,,,, -78800,3.926227,1.6025794,,,,,,,,,,,,,, -78900,3.7575758,1.6720326,,,,,,,,,,,,,, -79000,3.8882985,1.7285665,,,,,,,,,,,,,, -79100,4.203562,1.7273777,,,,,,,,,,,,,, -79200,4.2988,1.9022999,,,,,,,,,,,,,, -79300,4.310811,1.700346,,,,,,,,,,,,,, -79400,4.4171023,1.7742538,,,,,,,,,,,,,, -79500,4.290045,1.9037071,,,,,,,,,,,,,, -79600,4.7863717,1.7258389,,,,,,,,,,,,,, -79700,4.3070726,1.676297,,,,,,,,,,,,,, -79708,,,0.6924824714660645,1.204252004623413,0.6411600112915039,1.4582104682922363,50000.0,0.5179000496864319,2.2048068046569824,10000.0,27071.600769996643,28033.949457883835,27071.600769996643,957.6483449935912,1.8976809978485107,0.0 -79800,4.1181626,1.6951474,,,,,,,,,,,,,, -79900,5.1166067,1.8747128,,,,,,,,,,,,,, -80000,4.358524,1.6795309,,,,,,,,,,,,,, -80100,4.5361543,1.7433963,,,,,,,,,,,,,, -80200,4.1384587,1.7252462,,,,,,,,,,,,,, -80300,4.4150043,1.7850034,,,,,,,,,,,,,, -80400,3.6470354,1.7671657,,,,,,,,,,,,,, -80500,4.32927,1.7370441,,,,,,,,,,,,,, -80600,3.8505573,1.6925921,,,,,,,,,,,,,, -80700,3.7242796,1.7839855,,,,,,,,,,,,,, -80800,3.8986762,1.7597395,,,,,,,,,,,,,, -80900,4.0836005,1.5981271,,,,,,,,,,,,,, -81000,3.921169,1.705497,,,,,,,,,,,,,, -81100,3.9904253,1.7994336,,,,,,,,,,,,,, -81200,4.394548,1.6666361,,,,,,,,,,,,,, -81216,,,0.7299705147743225,1.0363904237747192,0.6387400031089783,1.4843218326568604,50000.0,0.522100031375885,2.1594133377075195,10000.0,27581.69038462639,28561.98272919655,27581.69038462639,975.4938888549804,1.9428644180297847,0.0 -81300,4.1891294,1.6936016,,,,,,,,,,,,,, -81400,4.065236,1.7064512,,,,,,,,,,,,,, -81500,4.0755086,1.8439492,,,,,,,,,,,,,, -81600,3.966009,1.7051191,,,,,,,,,,,,,, -81700,4.1712775,1.7659229,,,,,,,,,,,,,, -81800,4.2456093,1.7247713,,,,,,,,,,,,,, -81900,4.4293385,1.649469,,,,,,,,,,,,,, -82000,4.912668,1.7192607,,,,,,,,,,,,,, -82100,4.300866,1.6441127,,,,,,,,,,,,,, -82200,4.7335176,1.6398063,,,,,,,,,,,,,, -82300,4.984018,1.6649203,,,,,,,,,,,,,, -82400,3.9404697,1.6756232,,,,,,,,,,,,,, -82500,4.2205057,1.664832,,,,,,,,,,,,,, -82600,3.79195,1.7726148,,,,,,,,,,,,,, -82700,4.428921,1.7790847,,,,,,,,,,,,,, -82723,,,0.7124919891357422,1.1045640707015991,0.6467999815940857,1.4450737237930298,50000.0,0.513700008392334,2.1748580932617188,10000.0,28091.66532254219,29089.590743780136,28091.66532254219,993.032985687256,1.9836251735687256,0.0 -82800,4.087693,1.6492643,,,,,,,,,,,,,, -82900,4.099389,1.747632,,,,,,,,,,,,,, -83000,4.6755166,1.6891239,,,,,,,,,,,,,, -83100,4.3177767,1.7760247,,,,,,,,,,,,,, -83200,3.4927442,1.5276716,,,,,,,,,,,,,, -83300,4.495047,1.7494159,,,,,,,,,,,,,, -83400,3.705497,1.7204244,,,,,,,,,,,,,, -83500,4.3945966,1.7028959,,,,,,,,,,,,,, -83600,4.09145,1.6990309,,,,,,,,,,,,,, -83700,4.6769657,1.7085843,,,,,,,,,,,,,, -83800,4.895309,1.8057092,,,,,,,,,,,,,, -83900,5.1551156,1.7902749,,,,,,,,,,,,,, -84000,4.7395115,1.7290456,,,,,,,,,,,,,, -84100,4.194326,1.734834,,,,,,,,,,,,,, -84200,4.3423915,1.7825885,,,,,,,,,,,,,, -84231,,,0.708426296710968,1.1359773874282837,0.6473000049591064,1.4457052946090698,50000.0,0.524399995803833,2.1410605907440186,10000.0,28601.680331230164,29617.22701382637,28601.680331230164,1010.554886817932,2.03149938583374,0.0 -84300,4.2024493,1.6767933,,,,,,,,,,,,,, -84400,3.7719622,1.6785668,,,,,,,,,,,,,, -84500,4.0115395,1.6488329,,,,,,,,,,,,,, -84600,4.6846333,1.6944228,,,,,,,,,,,,,, -84700,4.401801,1.5967524,,,,,,,,,,,,,, -84800,4.0109124,1.5990847,,,,,,,,,,,,,, -84900,4.7911997,1.8638812,,,,,,,,,,,,,, -85000,4.299272,1.7110498,,,,,,,,,,,,,, -85100,3.8453069,1.726542,,,,,,,,,,,,,, -85200,3.8771968,1.734845,,,,,,,,,,,,,, -85300,4.0334425,1.6278828,,,,,,,,,,,,,, -85400,4.5480385,1.7092663,,,,,,,,,,,,,, -85500,4.8478436,1.7026699,,,,,,,,,,,,,, -85600,3.9277914,1.6647213,,,,,,,,,,,,,, -85700,4.8441753,1.8430371,,,,,,,,,,,,,, -85739,,,0.697684109210968,1.191365122795105,0.64028000831604,1.475667953491211,50000.0,0.5109000205993652,2.174423217773437,10000.0,29111.69335460663,30145.42070817948,29111.69335460663,1028.6400225162506,2.074836492538452,0.0 -85800,3.7147198,1.6847527,,,,,,,,,,,,,, -85900,4.4262233,1.725267,,,,,,,,,,,,,, -86000,4.2281346,1.7895634,,,,,,,,,,,,,, -86100,4.8228316,1.6638589,,,,,,,,,,,,,, -86200,4.24356,1.5993655,,,,,,,,,,,,,, -86300,4.1328616,1.7539632,,,,,,,,,,,,,, -86400,4.099311,1.6109881,,,,,,,,,,,,,, -86500,4.723155,1.665096,,,,,,,,,,,,,, -86600,4.130449,1.6734978,,,,,,,,,,,,,, -86700,3.994406,1.759144,,,,,,,,,,,,,, -86800,5.006904,1.7957158,,,,,,,,,,,,,, -86900,4.358703,1.7250051,,,,,,,,,,,,,, -87000,4.1209736,1.6256533,,,,,,,,,,,,,, -87100,4.541003,1.7073346,,,,,,,,,,,,,, -87200,4.543406,1.7333829,,,,,,,,,,,,,, -87245,,,0.7034239172935486,1.1476686000823977,0.6473599672317505,1.432129144668579,50000.0,0.5249000191688538,2.1454508304595947,10000.0,29621.69142627716,30673.143434762955,29621.69142627716,1046.26478099823,2.1219582557678223,0.0 -87300,4.0462174,1.6236173,,,,,,,,,,,,,, -87400,4.444753,1.6817541,,,,,,,,,,,,,, -87500,4.441548,1.7381082,,,,,,,,,,,,,, -87600,4.5837407,1.7016412,,,,,,,,,,,,,, -87700,4.517736,1.6468296,,,,,,,,,,,,,, -87800,4.3391705,1.6828899,,,,,,,,,,,,,, -87900,4.974826,1.6253918,,,,,,,,,,,,,, -88000,4.857052,1.6512034,,,,,,,,,,,,,, -88100,4.7486777,1.8034939,,,,,,,,,,,,,, -88200,4.1931553,1.6208016,,,,,,,,,,,,,, -88300,4.2234187,1.7291131,,,,,,,,,,,,,, -88400,4.3222446,1.670763,,,,,,,,,,,,,, -88500,4.003148,1.6226978,,,,,,,,,,,,,, -88600,4.885564,1.7170649,,,,,,,,,,,,,, -88700,4.394938,1.6263974,,,,,,,,,,,,,, -88753,,,0.69921875,1.1784825325012207,0.6462599635124207,1.4546048641204834,50000.0,0.5250000357627869,2.169255256652832,10000.0,30131.69265937805,31200.92763876915,30131.69265937805,1063.9508562088013,2.1666407585144043,0.0 -88800,4.067823,1.6567308,,,,,,,,,,,,,, -88900,4.4020395,1.6689317,,,,,,,,,,,,,, -89000,4.0105076,1.6928194,,,,,,,,,,,,,, -89100,4.0216284,1.6716037,,,,,,,,,,,,,, -89200,4.2097235,1.4925842,,,,,,,,,,,,,, -89300,4.4590726,1.6841428,,,,,,,,,,,,,, -89400,4.368488,1.6286792,,,,,,,,,,,,,, -89500,3.9196935,1.6307118,,,,,,,,,,,,,, -89600,4.629742,1.7082915,,,,,,,,,,,,,, -89700,4.4855785,1.5544258,,,,,,,,,,,,,, -89800,4.3106284,1.8031228,,,,,,,,,,,,,, -89900,4.66264,1.8061287,,,,,,,,,,,,,, -90000,3.8977327,1.6444062,,,,,,,,,,,,,, -90100,4.66726,1.6264144,,,,,,,,,,,,,, -90200,4.1370053,1.6445476,,,,,,,,,,,,,, -90261,,,0.7489835619926453,0.97758811712265,0.6581000089645386,1.4023665189743042,50000.0,0.5344000458717346,2.106358766555786,10000.0,30641.757912397385,31728.86659193039,30641.757912397385,1081.7278769016266,2.2114133834838867,0.0 -90300,4.7075596,1.6167698,,,,,,,,,,,,,, -90400,4.83625,1.7679236,,,,,,,,,,,,,, -90500,4.5051064,1.6688757,,,,,,,,,,,,,, -90600,4.3251677,1.7512957,,,,,,,,,,,,,, -90700,4.201393,1.5718809,,,,,,,,,,,,,, -90800,4.5976877,1.6134994,,,,,,,,,,,,,, -90900,4.050402,1.60633,,,,,,,,,,,,,, -91000,4.669969,1.6310513,,,,,,,,,,,,,, -91100,3.7295425,1.6032631,,,,,,,,,,,,,, -91200,4.16385,1.6441066,,,,,,,,,,,,,, -91300,4.1033373,1.6020046,,,,,,,,,,,,,, -91400,3.9140153,1.5215044,,,,,,,,,,,,,, -91500,4.23732,1.6705105,,,,,,,,,,,,,, -91600,4.3449583,1.6933637,,,,,,,,,,,,,, -91700,4.7942586,1.7739276,,,,,,,,,,,,,, -91769,,,0.720703125,1.0760672092437744,0.6516799926757812,1.4240094423294067,50000.0,0.5313000082969666,2.1127941608428955,10000.0,31151.930895090103,32256.83377289772,31151.930895090103,1099.4261264801023,2.254272222518921,0.0 -91800,5.2195883,1.7004693,,,,,,,,,,,,,, -91900,4.4095592,1.5813553,,,,,,,,,,,,,, -92000,4.485383,1.7354491,,,,,,,,,,,,,, -92100,4.2462006,1.6712284,,,,,,,,,,,,,, -92200,4.5095854,1.5755785,,,,,,,,,,,,,, -92300,4.2625937,1.572634,,,,,,,,,,,,,, -92400,4.9766273,1.659039,,,,,,,,,,,,,, -92500,4.1445312,1.7117827,,,,,,,,,,,,,, -92600,4.2069116,1.6777073,,,,,,,,,,,,,, -92700,4.3015466,1.6310632,,,,,,,,,,,,,, -92800,4.5047846,1.8163896,,,,,,,,,,,,,, -92900,4.3314714,1.6551467,,,,,,,,,,,,,, -93000,4.7504487,1.6731595,,,,,,,,,,,,,, -93100,4.8195243,1.65209,,,,,,,,,,,,,, -93200,4.502681,1.6943893,,,,,,,,,,,,,, -93277,,,0.7053571343421936,1.1407818794250488,0.6459000110626221,1.4545795917510986,50000.0,0.5184000134468079,2.1714203357696533,10000.0,31662.01350545883,32784.82475876808,31662.01350545883,1117.2352879047394,2.3011181354522705,0.0 -93300,5.717562,1.7090292,,,,,,,,,,,,,, -93400,4.247355,1.5708567,,,,,,,,,,,,,, -93500,4.299988,1.6276591,,,,,,,,,,,,,, -93600,5.426577,1.5870534,,,,,,,,,,,,,, -93700,4.9112725,1.8131964,,,,,,,,,,,,,, -93800,5.016949,1.6255033,,,,,,,,,,,,,, -93900,4.580008,1.6749202,,,,,,,,,,,,,, -94000,4.348436,1.6407372,,,,,,,,,,,,,, -94100,4.389022,1.5609004,,,,,,,,,,,,,, -94200,4.09868,1.6757016,,,,,,,,,,,,,, -94300,4.4447346,1.8135054,,,,,,,,,,,,,, -94400,4.2432466,1.5098292,,,,,,,,,,,,,, -94500,4.6275687,1.8245767,,,,,,,,,,,,,, -94600,4.744018,1.5710776,,,,,,,,,,,,,, -94700,4.728121,1.7176479,,,,,,,,,,,,,, -94785,,,0.7093630433082581,1.1220866441726685,0.6515799760818481,1.4239590167999268,50000.0,0.5238000154495239,2.137367010116577,10000.0,32172.043387413025,33312.68545603752,32172.043387413025,1134.9675867557526,2.345402479171753,0.0 -94800,4.98833,1.7408979,,,,,,,,,,,,,, -94900,6.214643,1.7156171,,,,,,,,,,,,,, -95000,4.8760986,1.6369331,,,,,,,,,,,,,, -95100,4.247954,1.5324651,,,,,,,,,,,,,, -95200,4.3808784,1.567647,,,,,,,,,,,,,, -95300,4.039812,1.5982902,,,,,,,,,,,,,, -95400,4.5698595,1.6593649,,,,,,,,,,,,,, -95500,4.901078,1.650307,,,,,,,,,,,,,, -95600,5.476459,1.6118908,,,,,,,,,,,,,, -95700,5.0540905,1.6334672,,,,,,,,,,,,,, -95800,4.4762416,1.7074263,,,,,,,,,,,,,, -95900,4.4737825,1.5454671,,,,,,,,,,,,,, -96000,4.244854,1.67191,,,,,,,,,,,,,, -96100,4.188175,1.6091944,,,,,,,,,,,,,, -96200,5.3227267,1.6161798,,,,,,,,,,,,,, -96294,,,0.7216597199440002,1.0726561546325684,0.666979968547821,1.358321189880371,50000.0,0.5388000011444092,2.063443183898926,10000.0,32682.25458574295,33840.66708111763,32682.25458574295,1152.641860485077,2.3891897201538086,0.0 -96300,5.220495,1.5740398,,,,,,,,,,,,,, -96400,4.9274592,1.6861371,,,,,,,,,,,,,, -96500,4.288635,1.64966,,,,,,,,,,,,,, -96600,4.6298223,1.6974001,,,,,,,,,,,,,, -96700,4.8980002,1.5783219,,,,,,,,,,,,,, -96800,5.3483105,1.6628631,,,,,,,,,,,,,, -96900,4.5488114,1.6224262,,,,,,,,,,,,,, -97000,4.2514486,1.5060687,,,,,,,,,,,,,, -97100,4.566001,1.6200409,,,,,,,,,,,,,, -97200,4.710562,1.6474773,,,,,,,,,,,,,, -97300,4.6109776,1.6655957,,,,,,,,,,,,,, -97400,4.3574777,1.5550237,,,,,,,,,,,,,, -97500,3.9730873,1.5526426,,,,,,,,,,,,,, -97600,5.053328,1.6690522,,,,,,,,,,,,,, -97700,4.316213,1.5756983,,,,,,,,,,,,,, -97800,5.090575,1.7385358,,,,,,,,,,,,,, -97802,,,0.7195671200752258,1.0867538452148438,0.6625799536705017,1.372006058692932,50000.0,0.5335000157356262,2.085366487503052,10000.0,33192.23504304886,34368.55675005913,33192.23504304886,1170.453558921814,2.43384313583374,0.0 -97900,4.6676745,1.511306,,,,,,,,,,,,,, -98000,4.7412453,1.6193603,,,,,,,,,,,,,, -98100,4.649354,1.686005,,,,,,,,,,,,,, -98200,4.797965,1.5587039,,,,,,,,,,,,,, -98300,5.087027,1.6197231,,,,,,,,,,,,,, -98400,4.2059474,1.5743282,,,,,,,,,,,,,, -98500,4.344348,1.5958477,,,,,,,,,,,,,, -98600,4.1293178,1.6333091,,,,,,,,,,,,,, -98700,5.00225,1.5966839,,,,,,,,,,,,,, -98800,4.944378,1.7081254,,,,,,,,,,,,,, -98900,5.981155,1.5886736,,,,,,,,,,,,,, -99000,4.0750723,1.6974633,,,,,,,,,,,,,, -99100,4.28595,1.574734,,,,,,,,,,,,,, -99200,4.8516526,1.6076102,,,,,,,,,,,,,, -99300,5.1042614,1.6086168,,,,,,,,,,,,,, -99311,,,0.7441206574440002,0.9868006110191344,0.6504799723625183,1.4178804159164429,50000.0,0.525600016117096,2.124274492263794,10000.0,33702.44986701012,34896.47751951218,33702.44986701012,1188.0544037818909,2.485387325286865,0.0 -99400,4.713755,1.6280086,,,,,,,,,,,,,, -99500,4.2332716,1.6125515,,,,,,,,,,,,,, -99600,4.497134,1.6377958,,,,,,,,,,,,,, -99700,4.1180115,1.5939411,,,,,,,,,,,,,, -99800,3.8266242,1.4693929,,,,,,,,,,,,,, -99900,4.662248,1.6749961,,,,,,,,,,,,,, -100000,4.5978427,1.4820025,,,,,,,,,,,,,, -100100,4.1769886,1.5814799,,,,,,,,,,,,,, -100200,4.932039,1.6002984,,,,,,,,,,,,,, -100300,4.84017,1.5256283,,,,,,,,,,,,,, -100400,5.2114477,1.6245973,,,,,,,,,,,,,, -100500,5.4473906,1.545998,,,,,,,,,,,,,, -100600,5.1360526,1.4872746,,,,,,,,,,,,,, -100700,4.2489204,1.551677,,,,,,,,,,,,,, -100800,4.1927533,1.5095075,,,,,,,,,,,,,, -100820,,,0.7390784025192261,0.991266429424286,0.6676999926567078,1.349393606185913,50000.0,0.5522000193595886,2.032019853591919,10000.0,34212.6481218338,35424.54218220711,34212.6481218338,1205.8191084861755,2.532451629638672,0.0 -100900,4.5375,1.6585947,,,,,,,,,,,,,, -101000,5.169953,1.6224227,,,,,,,,,,,,,, -101100,4.5810394,1.6779815,,,,,,,,,,,,,, -101200,5.337602,1.6110301,,,,,,,,,,,,,, -101300,4.5753994,1.5711627,,,,,,,,,,,,,, -101400,4.5659385,1.5303084,,,,,,,,,,,,,, -101500,4.8325434,1.6104636,,,,,,,,,,,,,, -101600,4.6270313,1.5628536,,,,,,,,,,,,,, -101700,5.1130366,1.6740162,,,,,,,,,,,,,, -101800,4.3636703,1.5396478,,,,,,,,,,,,,, -101900,5.1200085,1.5687523,,,,,,,,,,,,,, -102000,4.454467,1.5761542,,,,,,,,,,,,,, -102100,4.63741,1.5903171,,,,,,,,,,,,,, -102200,5.2135296,1.6377362,,,,,,,,,,,,,, -102300,6.2464457,1.6139864,,,,,,,,,,,,,, -102329,,,0.7271803021430969,1.0407520532608032,0.6602399945259094,1.3748583793640137,50000.0,0.5373000502586365,2.074155807495117,10000.0,34722.8763525486,35952.56840515137,34722.8763525486,1223.5154864788055,2.5815184116363525,0.0 -102400,4.31352,1.5401419,,,,,,,,,,,,,, -102500,5.0019035,1.5035487,,,,,,,,,,,,,, -102600,4.9398737,1.5335842,,,,,,,,,,,,,, -102700,4.7371545,1.5118474,,,,,,,,,,,,,, -102800,5.0578933,1.6182538,,,,,,,,,,,,,, -102900,4.6829934,1.5257525,,,,,,,,,,,,,, -103000,5.217938,1.546003,,,,,,,,,,,,,, -103100,4.5591865,1.502969,,,,,,,,,,,,,, -103200,4.461494,1.5314915,,,,,,,,,,,,,, -103300,4.541472,1.6160114,,,,,,,,,,,,,, -103400,5.1781607,1.5875096,,,,,,,,,,,,,, -103500,4.6674304,1.55746,,,,,,,,,,,,,, -103600,4.858612,1.5905852,,,,,,,,,,,,,, -103700,4.756847,1.5263966,,,,,,,,,,,,,, -103800,4.601532,1.4822729,,,,,,,,,,,,,, -103837,,,0.7323620915412903,1.032604098320007,0.6691199541091919,1.3446305990219116,50000.0,0.5454000234603882,2.035710573196411,10000.0,35232.8334004879,36480.34671187401,35232.8334004879,1241.2360637187958,2.6300241947174072,0.0 -103900,4.578151,1.5091193,,,,,,,,,,,,,, -104000,4.5745773,1.4791019,,,,,,,,,,,,,, -104100,4.5699267,1.5855322,,,,,,,,,,,,,, -104200,5.749112,1.5758839,,,,,,,,,,,,,, -104300,4.5936093,1.5250636,,,,,,,,,,,,,, -104400,4.461791,1.6008546,,,,,,,,,,,,,, -104500,4.9582005,1.6245116,,,,,,,,,,,,,, -104600,4.8807354,1.6242236,,,,,,,,,,,,,, -104700,5.6290207,1.611186,,,,,,,,,,,,,, -104800,4.6761694,1.5915917,,,,,,,,,,,,,, -104900,4.9649777,1.6237075,,,,,,,,,,,,,, -105000,4.6234508,1.4926906,,,,,,,,,,,,,, -105100,4.939338,1.5203075,,,,,,,,,,,,,, -105200,5.289798,1.6989179,,,,,,,,,,,,,, -105300,5.4546666,1.4996572,,,,,,,,,,,,,, -105345,,,0.7277383208274841,1.057060718536377,0.6671000123023987,1.3545143604278564,50000.0,0.5454000234603882,2.032726526260376,10000.0,35742.87636375427,37007.94652104378,35742.87636375427,1258.6924047470093,2.6776812076568604,0.0 -105400,5.433167,1.6739726,,,,,,,,,,,,,, -105500,5.461731,1.619292,,,,,,,,,,,,,, -105600,4.4951587,1.5831392,,,,,,,,,,,,,, -105700,5.341715,1.5952024,,,,,,,,,,,,,, -105800,4.7838063,1.5522894,,,,,,,,,,,,,, -105900,4.698089,1.5006772,,,,,,,,,,,,,, -106000,5.1912785,1.5976186,,,,,,,,,,,,,, -106100,5.0497932,1.5129629,,,,,,,,,,,,,, -106200,5.3289437,1.4792758,,,,,,,,,,,,,, -106300,4.650686,1.4536102,,,,,,,,,,,,,, -106400,5.116538,1.6289002,,,,,,,,,,,,,, -106500,4.4601645,1.6402562,,,,,,,,,,,,,, -106600,4.983405,1.5627234,,,,,,,,,,,,,, -106700,4.759819,1.6096468,,,,,,,,,,,,,, -106800,4.622663,1.4994502,,,,,,,,,,,,,, -106853,,,0.7320830821990967,1.035598635673523,0.6726599931716919,1.330732822418213,50000.0,0.5489000082015991,2.024749994277954,10000.0,36252.81159281731,37535.38425660133,36252.81159281731,1276.0966138839722,2.722939729690552,0.0 -106900,4.538478,1.5129049,,,,,,,,,,,,,, -107000,4.226217,1.4751952,,,,,,,,,,,,,, -107100,5.222317,1.5623404,,,,,,,,,,,,,, -107200,4.982366,1.5303638,,,,,,,,,,,,,, -107300,5.0171714,1.4714327,,,,,,,,,,,,,, -107400,4.713055,1.5065631,,,,,,,,,,,,,, -107500,5.386957,1.4581474,,,,,,,,,,,,,, -107600,4.8855057,1.5256977,,,,,,,,,,,,,, -107700,5.0727997,1.637803,,,,,,,,,,,,,, -107800,4.773296,1.5397182,,,,,,,,,,,,,, -107900,5.232627,1.6099902,,,,,,,,,,,,,, -108000,5.250172,1.5812223,,,,,,,,,,,,,, -108100,5.093143,1.5229155,,,,,,,,,,,,,, -108200,5.4131603,1.5060978,,,,,,,,,,,,,, -108300,4.7578535,1.5215955,,,,,,,,,,,,,, -108361,,,0.7787786722183228,0.8223951458930969,0.6761599779129028,1.301857590675354,50000.0,0.5547000169754028,1.9950766563415527,10000.0,36762.72781395912,38063.045838832855,36762.72781395912,1293.7430353164673,2.769303321838379,0.0 -108400,5.6555085,1.5744226,,,,,,,,,,,,,, -108500,4.6355615,1.4080997,,,,,,,,,,,,,, -108600,5.49059,1.5981896,,,,,,,,,,,,,, -108700,4.761775,1.5767882,,,,,,,,,,,,,, -108800,4.79682,1.5294824,,,,,,,,,,,,,, -108900,4.6796436,1.462482,,,,,,,,,,,,,, -109000,5.3640575,1.5891647,,,,,,,,,,,,,, -109100,5.1919856,1.6161186,,,,,,,,,,,,,, -109200,4.5967345,1.5110676,,,,,,,,,,,,,, -109300,5.033236,1.5617537,,,,,,,,,,,,,, -109400,5.368706,1.481468,,,,,,,,,,,,,, -109500,4.552589,1.5102595,,,,,,,,,,,,,, -109600,4.848721,1.6432596,,,,,,,,,,,,,, -109700,4.4871087,1.5709006,,,,,,,,,,,,,, -109800,4.9707875,1.5235445,,,,,,,,,,,,,, -109869,,,0.7466916441917419,0.96091890335083,0.6674599647521973,1.3489824533462524,50000.0,0.539900004863739,2.03249454498291,10000.0,37272.70946359634,38590.995624780655,37272.70946359634,1311.6061923503876,2.820514678955078,0.0 -109900,4.699529,1.5172217,,,,,,,,,,,,,, -110000,6.423997,1.5657978,,,,,,,,,,,,,, -110100,4.788745,1.5252255,,,,,,,,,,,,,, -110200,5.108454,1.5280646,,,,,,,,,,,,,, -110300,4.990038,1.365809,,,,,,,,,,,,,, -110400,4.6438675,1.5565085,,,,,,,,,,,,,, -110500,5.171328,1.5312275,,,,,,,,,,,,,, -110600,5.5647964,1.4890625,,,,,,,,,,,,,, -110700,4.89485,1.5736135,,,,,,,,,,,,,, -110800,4.964599,1.486619,,,,,,,,,,,,,, -110900,5.066575,1.5987788,,,,,,,,,,,,,, -111000,5.2488904,1.4389017,,,,,,,,,,,,,, -111100,5.407216,1.3941623,,,,,,,,,,,,,, -111200,4.715988,1.5535889,,,,,,,,,,,,,, -111300,6.022503,1.4898823,,,,,,,,,,,,,, -111377,,,0.7472297549247742,0.9534629583358764,0.6775999665260315,1.3055094480514526,50000.0,0.5509000420570374,1.9932727813720703,10000.0,37782.68553614616,39118.579266786575,37782.68553614616,1329.1077308654783,2.870955228805542,0.0 -111400,5.219486,1.5234708,,,,,,,,,,,,,, -111500,5.804259,1.4874959,,,,,,,,,,,,,, -111600,4.9243593,1.530648,,,,,,,,,,,,,, -111700,5.312122,1.4592248,,,,,,,,,,,,,, -111800,5.5408926,1.5902687,,,,,,,,,,,,,, -111900,5.2765517,1.4939299,,,,,,,,,,,,,, -112000,4.7104897,1.5457519,,,,,,,,,,,,,, -112100,5.7696157,1.4628625,,,,,,,,,,,,,, -112200,4.621835,1.5307491,,,,,,,,,,,,,, -112300,4.950535,1.5351158,,,,,,,,,,,,,, -112400,5.551716,1.4569451,,,,,,,,,,,,,, -112500,5.6879377,1.4995924,,,,,,,,,,,,,, -112600,5.192165,1.5684996,,,,,,,,,,,,,, -112700,6.094157,1.5477669,,,,,,,,,,,,,, -112800,4.7277336,1.4745944,,,,,,,,,,,,,, -112885,,,0.7496611475944519,0.9466625452041626,0.6813600063323975,1.2942183017730713,50000.0,0.5595000386238098,1.9724422693252563,10000.0,38292.7685611248,39646.19147348404,38292.7685611248,1346.5370290279388,2.9173662662506104,0.0 -112900,5.3964853,1.5264412,,,,,,,,,,,,,, -113000,4.877382,1.5805871,,,,,,,,,,,,,, -113100,4.9360366,1.3702208,,,,,,,,,,,,,, -113200,5.999122,1.4695683,,,,,,,,,,,,,, -113300,4.771484,1.5379877,,,,,,,,,,,,,, -113400,4.772059,1.3335769,,,,,,,,,,,,,, -113500,5.4427876,1.3991588,,,,,,,,,,,,,, -113600,4.7971907,1.4836164,,,,,,,,,,,,,, -113700,5.484535,1.4597232,,,,,,,,,,,,,, -113800,5.689035,1.5249779,,,,,,,,,,,,,, -113900,5.2400594,1.4553365,,,,,,,,,,,,,, -114000,5.1912894,1.435692,,,,,,,,,,,,,, -114100,5.3820853,1.5096278,,,,,,,,,,,,,, -114200,5.0696673,1.5683131,,,,,,,,,,,,,, -114300,4.915011,1.5170274,,,,,,,,,,,,,, -114393,,,0.7482461333274841,0.957245111465454,0.682699978351593,1.284542202949524,50000.0,0.5559000372886658,1.9868924617767327,10000.0,38802.89866161346,40173.89938187599,38802.89866161346,1364.0172460079193,2.962613582611084,0.0 -114400,4.840806,1.4686887,,,,,,,,,,,,,, -114500,5.176027,1.3791444,,,,,,,,,,,,,, -114600,6.245968,1.623748,,,,,,,,,,,,,, -114700,5.29239,1.4559021,,,,,,,,,,,,,, -114800,5.134018,1.4165776,,,,,,,,,,,,,, -114900,5.4844847,1.4334838,,,,,,,,,,,,,, -115000,6.6211123,1.3868824,,,,,,,,,,,,,, -115100,4.8159294,1.4379504,,,,,,,,,,,,,, -115200,4.517477,1.4753401,,,,,,,,,,,,,, -115300,4.922257,1.319308,,,,,,,,,,,,,, -115400,5.432362,1.5506159,,,,,,,,,,,,,, -115500,5.655003,1.4581745,,,,,,,,,,,,,, -115600,4.5596814,1.4558208,,,,,,,,,,,,,, -115700,5.6754174,1.609835,,,,,,,,,,,,,, -115800,5.4908733,1.4431322,,,,,,,,,,,,,, -115900,5.4565525,1.581604,,,,,,,,,,,,,, -115901,,,0.7508171200752258,0.9416345953941344,0.6827799677848816,1.2685418128967283,50000.0,0.5552000403404236,1.9634768962860107,10000.0,39313.07150053978,40701.87993502617,39313.07150053978,1381.7219486236572,3.011996269226074,0.0 -116000,4.719518,1.3531549,,,,,,,,,,,,,, -116100,5.213813,1.5959268,,,,,,,,,,,,,, -116200,5.0731153,1.4334862,,,,,,,,,,,,,, -116300,5.2632704,1.4359814,,,,,,,,,,,,,, -116400,5.4823527,1.4778554,,,,,,,,,,,,,, -116500,4.8438396,1.3847402,,,,,,,,,,,,,, -116600,5.516899,1.5142069,,,,,,,,,,,,,, -116700,5.3958697,1.5192808,,,,,,,,,,,,,, -116800,5.8623686,1.4686855,,,,,,,,,,,,,, -116900,5.327192,1.5005963,,,,,,,,,,,,,, -117000,5.0103927,1.3378872,,,,,,,,,,,,,, -117100,5.19044,1.3662645,,,,,,,,,,,,,, -117200,5.037278,1.4001372,,,,,,,,,,,,,, -117300,5.428599,1.5042918,,,,,,,,,,,,,, -117400,5.5779295,1.4785709,,,,,,,,,,,,,, -117409,,,0.7893415093421936,0.7836189866065979,0.6830199956893921,1.2751917839050293,50000.0,0.5624000430107117,1.9720399379730225,10000.0,39823.0315990448,41229.51292562485,39823.0315990448,1399.292273521423,3.0626745223999023,0.0 -117500,5.3632717,1.4336754,,,,,,,,,,,,,, -117600,5.505517,1.4887198,,,,,,,,,,,,,, -117700,6.141665,1.4686198,,,,,,,,,,,,,, -117800,5.325035,1.5599861,,,,,,,,,,,,,, -117900,4.759353,1.435733,,,,,,,,,,,,,, -118000,5.3424506,1.5323287,,,,,,,,,,,,,, -118100,5.5831995,1.5577791,,,,,,,,,,,,,, -118200,5.3342314,1.5086865,,,,,,,,,,,,,, -118300,5.6258116,1.5014344,,,,,,,,,,,,,, -118400,5.208555,1.3545069,,,,,,,,,,,,,, -118500,5.207556,1.4449906,,,,,,,,,,,,,, -118600,5.671587,1.4207721,,,,,,,,,,,,,, -118700,4.9110227,1.395318,,,,,,,,,,,,,, -118800,5.7702775,1.3967559,,,,,,,,,,,,,, -118900,5.814363,1.445947,,,,,,,,,,,,,, -118917,,,0.7732780575752258,0.8545231819152832,0.6861199736595154,1.2665618658065796,50000.0,0.5601000189781189,1.9834524393081665,10000.0,40333.14848613739,41757.50674414635,40333.14848613739,1417.0705609321594,3.109504461288452,0.0 -119000,6.211192,1.5239853,,,,,,,,,,,,,, -119100,5.85496,1.4850633,,,,,,,,,,,,,, -119200,5.665322,1.3715401,,,,,,,,,,,,,, -119300,5.0379252,1.3905312,,,,,,,,,,,,,, -119400,5.3437266,1.4389081,,,,,,,,,,,,,, -119500,5.227824,1.4444195,,,,,,,,,,,,,, -119600,5.3204308,1.4284637,,,,,,,,,,,,,, -119700,5.4442644,1.4310033,,,,,,,,,,,,,, -119800,5.5466847,1.5523145,,,,,,,,,,,,,, -119900,5.329414,1.4229486,,,,,,,,,,,,,, -120000,5.418873,1.4120481,,,,,,,,,,,,,, -120100,5.635773,1.4394948,,,,,,,,,,,,,, -120200,5.3810296,1.408449,,,,,,,,,,,,,, -120300,5.192854,1.4700481,,,,,,,,,,,,,, -120400,5.9760633,1.3578053,,,,,,,,,,,,,, -120426,,,0.7666015625,0.8653361201286316,0.6888200044631958,1.2589260339736938,50000.0,0.5587000250816345,1.966181397438049,10000.0,40843.37981677055,42285.4398317337,40843.37981677055,1434.6702933311462,3.158001184463501,0.0 -120500,5.5604086,1.386399,,,,,,,,,,,,,, -120600,5.484844,1.4383118,,,,,,,,,,,,,, -120700,5.3627224,1.4147192,,,,,,,,,,,,,, -120800,4.928829,1.3789934,,,,,,,,,,,,,, -120900,5.996537,1.304023,,,,,,,,,,,,,, -121000,5.0729394,1.5093325,,,,,,,,,,,,,, -121100,5.0259075,1.3741148,,,,,,,,,,,,,, -121200,5.568916,1.2474624,,,,,,,,,,,,,, -121300,5.3898225,1.4047214,,,,,,,,,,,,,, -121400,5.249091,1.3871219,,,,,,,,,,,,,, -121500,5.484819,1.4052196,,,,,,,,,,,,,, -121600,5.8765097,1.4692566,,,,,,,,,,,,,, -121700,5.557895,1.3755246,,,,,,,,,,,,,, -121800,5.2690516,1.4496555,,,,,,,,,,,,,, -121900,5.478846,1.3920172,,,,,,,,,,,,,, -121934,,,0.76664137840271,0.8641605377197266,0.6898999810218811,1.2518686056137085,50000.0,0.5608000159263611,1.9467543363571167,10000.0,41353.42318153381,42813.22805428505,41353.42318153381,1452.312379360199,3.207282543182373,0.0 -122000,5.0461655,1.3754821,,,,,,,,,,,,,, -122100,4.934055,1.4450054,,,,,,,,,,,,,, -122200,5.1428723,1.3934742,,,,,,,,,,,,,, -122300,6.0510545,1.3666253,,,,,,,,,,,,,, -122400,5.3946,1.3967032,,,,,,,,,,,,,, -122500,4.904729,1.3620191,,,,,,,,,,,,,, -122600,6.525439,1.2853171,,,,,,,,,,,,,, -122700,6.2322645,1.4270742,,,,,,,,,,,,,, -122800,5.6790347,1.3718631,,,,,,,,,,,,,, -122900,5.8179436,1.4203535,,,,,,,,,,,,,, -123000,6.0130434,1.4068123,,,,,,,,,,,,,, -123100,5.0044966,1.2388083,,,,,,,,,,,,,, -123200,5.3108954,1.3954587,,,,,,,,,,,,,, -123300,5.5900993,1.4417113,,,,,,,,,,,,,, -123400,6.299903,1.4650425,,,,,,,,,,,,,, -123442,,,0.7724210619926453,0.8580043315887451,0.6927599906921387,1.2386715412139893,50000.0,0.5678000450134277,1.9441684484481807,10000.0,41863.442735910416,43341.01841783524,41863.442735910416,1469.9726836681366,3.264508008956909,0.0 -123500,5.2398777,1.3783612,,,,,,,,,,,,,, -123600,5.495849,1.2939537,,,,,,,,,,,,,, -123700,5.572708,1.407774,,,,,,,,,,,,,, -123800,6.1531916,1.3172402,,,,,,,,,,,,,, -123900,6.9355407,1.3965749,,,,,,,,,,,,,, -124000,5.189268,1.3737924,,,,,,,,,,,,,, -124100,5.4833183,1.435761,,,,,,,,,,,,,, -124200,5.6511197,1.4034547,,,,,,,,,,,,,, -124300,5.478804,1.2587049,,,,,,,,,,,,,, -124400,6.1244226,1.4830824,,,,,,,,,,,,,, -124500,6.751875,1.4697766,,,,,,,,,,,,,, -124600,5.50632,1.4796022,,,,,,,,,,,,,, -124700,5.1362796,1.3482931,,,,,,,,,,,,,, -124800,5.523687,1.3147557,,,,,,,,,,,,,, -124900,6.474144,1.480881,,,,,,,,,,,,,, -124950,,,0.7725605964660645,0.8497375845909119,0.6963599920272827,1.2183761596679688,50000.0,0.5697000026702881,1.9078247547149656,10000.0,42373.48074388504,43869.59917449951,42373.48074388504,1488.4095180034635,3.31595516204834,0.0 -125000,5.872887,1.4425893,,,,,,,,,,,,,, -125100,5.5271764,1.3864969,,,,,,,,,,,,,, -125200,5.8155413,1.3126996,,,,,,,,,,,,,, -125300,5.8927197,1.4471716,,,,,,,,,,,,,, -125400,5.416307,1.3121754,,,,,,,,,,,,,, -125500,5.493122,1.4047577,,,,,,,,,,,,,, -125600,5.632506,1.3597145,,,,,,,,,,,,,, -125700,5.828019,1.3312254,,,,,,,,,,,,,, -125800,6.21041,1.4279177,,,,,,,,,,,,,, -125900,5.9549465,1.242264,,,,,,,,,,,,,, -126000,5.8653026,1.40951,,,,,,,,,,,,,, -126100,7.309144,1.2922227,,,,,,,,,,,,,, -126200,5.1668434,1.3149128,,,,,,,,,,,,,, -126300,5.798913,1.4148377,,,,,,,,,,,,,, -126400,6.185355,1.3234301,,,,,,,,,,,,,, -126458,,,0.8109255433082581,0.6936679482460022,0.7005800008773804,1.2135677337646484,50000.0,0.5722000002861023,1.912022590637207,10000.0,42883.66209101677,44397.662073135376,42883.66209101677,1506.1870160102844,3.367685079574585,0.0 -126500,5.6862845,1.3857554,,,,,,,,,,,,,, -126600,5.4817224,1.348254,,,,,,,,,,,,,, -126700,5.8168616,1.3394537,,,,,,,,,,,,,, -126800,6.058051,1.39,,,,,,,,,,,,,, -126900,5.447992,1.2760816,,,,,,,,,,,,,, -127000,5.4161015,1.324898,,,,,,,,,,,,,, -127100,6.480232,1.3928789,,,,,,,,,,,,,, -127200,5.7947655,1.3385777,,,,,,,,,,,,,, -127300,5.5889096,1.3434026,,,,,,,,,,,,,, -127400,5.904726,1.4307053,,,,,,,,,,,,,, -127500,5.727079,1.4387997,,,,,,,,,,,,,, -127600,5.953841,1.3289005,,,,,,,,,,,,,, -127700,5.262668,1.3884624,,,,,,,,,,,,,, -127800,5.3393054,1.2782391,,,,,,,,,,,,,, -127900,6.6980433,1.3892717,,,,,,,,,,,,,, -127966,,,0.7948222160339355,0.7559593319892883,0.7019599676132202,1.1997932195663452,50000.0,0.5768000483512878,1.9222203493118288,10000.0,43393.808292627335,44925.48057103157,43393.808292627335,1523.7546339035034,3.41774845123291,0.0 -128000,5.961412,1.4623152,,,,,,,,,,,,,, -128100,5.4580965,1.373515,,,,,,,,,,,,,, -128200,6.754474,1.3386068,,,,,,,,,,,,,, -128300,6.276997,1.3063695,,,,,,,,,,,,,, -128400,6.2094955,1.3187921,,,,,,,,,,,,,, -128500,5.9468303,1.3250378,,,,,,,,,,,,,, -128600,5.671252,1.3095526,,,,,,,,,,,,,, -128700,6.0194516,1.2426769,,,,,,,,,,,,,, -128800,6.06209,1.3414845,,,,,,,,,,,,,, -128900,5.201402,1.2825307,,,,,,,,,,,,,, -129000,5.745407,1.3240634,,,,,,,,,,,,,, -129100,5.653391,1.2629118,,,,,,,,,,,,,, -129200,6.0665426,1.3490667,,,,,,,,,,,,,, -129300,6.4553246,1.3837761,,,,,,,,,,,,,, -129400,5.5974684,1.1928719,,,,,,,,,,,,,, -129474,,,0.7926099896430969,0.7674739360809326,0.70305997133255,1.1998165845870972,50000.0,0.5685999989509583,1.900969624519348,10000.0,43903.809019088745,45453.370572805405,43903.809019088745,1541.539691209793,3.468662738800049,0.0 -129500,5.8616085,1.369401,,,,,,,,,,,,,, -129600,6.2004313,1.4598535,,,,,,,,,,,,,, -129700,6.113078,1.3253174,,,,,,,,,,,,,, -129800,5.447851,1.232458,,,,,,,,,,,,,, -129900,6.0698214,1.2805393,,,,,,,,,,,,,, -130000,5.867323,1.3254128,,,,,,,,,,,,,, -130100,5.7699447,1.271193,,,,,,,,,,,,,, -130200,6.189677,1.2180243,,,,,,,,,,,,,, -130300,8.337022,1.3091193,,,,,,,,,,,,,, -130400,5.565905,1.2572182,,,,,,,,,,,,,, -130500,6.4584165,1.2437109,,,,,,,,,,,,,, -130600,5.8619723,1.3177959,,,,,,,,,,,,,, -130700,5.910877,1.2925267,,,,,,,,,,,,,, -130800,6.436888,1.4372395,,,,,,,,,,,,,, -130900,6.085354,1.4282296,,,,,,,,,,,,,, -130983,,,0.7925103306770325,0.7734469175338745,0.7093999981880188,1.1662369966506958,50000.0,0.5809000134468079,1.8419197797775269,10000.0,44413.95647978783,45981.24368643761,44413.95647978783,1559.1616296768188,3.518587350845337,0.0 -131000,5.642071,1.2523646,,,,,,,,,,,,,, -131100,5.8160095,1.3061464,,,,,,,,,,,,,, -131200,6.5269403,1.2631949,,,,,,,,,,,,,, -131300,6.236053,1.2788188,,,,,,,,,,,,,, -131400,6.799521,1.4515808,,,,,,,,,,,,,, -131500,7.158256,1.3151872,,,,,,,,,,,,,, -131600,6.464533,1.3621037,,,,,,,,,,,,,, -131700,5.8027205,1.3135507,,,,,,,,,,,,,, -131800,5.671999,1.1826365,,,,,,,,,,,,,, -131900,5.8558087,1.1614045,,,,,,,,,,,,,, -132000,5.476552,1.2598021,,,,,,,,,,,,,, -132100,6.254725,1.2335229,,,,,,,,,,,,,, -132200,6.01963,1.3112373,,,,,,,,,,,,,, -132300,6.50314,1.3723797,,,,,,,,,,,,,, -132400,6.3512583,1.2699287,,,,,,,,,,,,,, -132491,,,0.7926897406578064,0.7695133090019226,0.7079199552536011,1.17943274974823,50000.0,0.5781000256538391,1.8631477355957031,10000.0,44924.015117406845,46509.1582107544,44924.015117406845,1576.9079988002777,3.573139905929565,0.0 -132500,6.0958343,1.2866294,,,,,,,,,,,,,, -132600,5.3881564,1.2870514,,,,,,,,,,,,,, -132700,6.8433967,1.2832725,,,,,,,,,,,,,, -132800,6.384542,1.3225893,,,,,,,,,,,,,, -132900,6.4587116,1.3141208,,,,,,,,,,,,,, -133000,5.7795825,1.2610663,,,,,,,,,,,,,, -133100,6.1811347,1.3110178,,,,,,,,,,,,,, -133200,5.5246553,1.1707554,,,,,,,,,,,,,, -133300,7.097017,1.3445542,,,,,,,,,,,,,, -133400,6.2002535,1.2845277,,,,,,,,,,,,,, -133500,5.715569,1.3168665,,,,,,,,,,,,,, -133600,5.94286,1.154869,,,,,,,,,,,,,, -133700,5.748987,1.2685988,,,,,,,,,,,,,, -133800,5.639806,1.1692338,,,,,,,,,,,,,, -133900,5.9898796,1.230584,,,,,,,,,,,,,, -133999,,,0.7933274507522583,0.7550987601280212,0.711899995803833,1.1559367179870603,50000.0,0.5849000215530396,1.8598921298980715,10000.0,45434.11597537994,47037.14133429527,45434.11597537994,1594.682165145874,3.6277785301208496,0.0 -134000,6.336476,1.27499,,,,,,,,,,,,,, -134100,6.2355886,1.3194883,,,,,,,,,,,,,, -134200,6.341895,1.2483875,,,,,,,,,,,,,, -134300,6.2857957,1.2535472,,,,,,,,,,,,,, -134400,5.715663,1.2653427,,,,,,,,,,,,,, -134500,6.7826357,1.272443,,,,,,,,,,,,,, -134600,6.5676346,1.2711351,,,,,,,,,,,,,, -134700,6.448416,1.2576036,,,,,,,,,,,,,, -134800,6.486983,1.2847463,,,,,,,,,,,,,, -134900,6.1931453,1.2792617,,,,,,,,,,,,,, -135000,5.7975507,1.3139266,,,,,,,,,,,,,, -135100,6.717225,1.2687347,,,,,,,,,,,,,, -135200,6.6209846,1.2764788,,,,,,,,,,,,,, -135300,6.22075,1.1817855,,,,,,,,,,,,,, -135400,6.1038804,1.1971133,,,,,,,,,,,,,, -135500,5.942823,1.232592,,,,,,,,,,,,,, -135507,,,0.8343032598495483,0.6008027791976929,0.7145000100135803,1.1467500925064087,50000.0,0.591200053691864,1.8425389528274536,10000.0,45944.17465925217,47564.80839204788,45944.17465925217,1612.187058925629,3.67779278755188,0.0 -135600,6.4374623,1.233988,,,,,,,,,,,,,, -135700,6.529048,1.2289491,,,,,,,,,,,,,, -135800,5.6525025,1.2733133,,,,,,,,,,,,,, -135900,6.829076,1.2892946,,,,,,,,,,,,,, -136000,6.6437106,1.3140947,,,,,,,,,,,,,, -136100,5.9193683,1.242914,,,,,,,,,,,,,, -136200,5.9689093,1.2016814,,,,,,,,,,,,,, -136300,6.9256043,1.3215902,,,,,,,,,,,,,, -136400,6.445261,1.269339,,,,,,,,,,,,,, -136500,6.558603,1.3428373,,,,,,,,,,,,,, -136600,6.5366793,1.344963,,,,,,,,,,,,,, -136700,6.600472,1.2842062,,,,,,,,,,,,,, -136800,7.012643,1.3235995,,,,,,,,,,,,,, -136900,6.455179,1.2064149,,,,,,,,,,,,,, -137000,5.8318453,1.1525574,,,,,,,,,,,,,, -137013,,,0.8226243257522583,0.6424916386604309,0.7174199819564819,1.132103443145752,50000.0,0.5916000008583069,1.821941494941712,10000.0,46454.21942257881,48092.49197125435,46454.21942257881,1629.722366809845,3.7293410301208496,0.0 -137100,6.7947283,1.2576952,,,,,,,,,,,,,, -137200,5.880582,1.1532952,,,,,,,,,,,,,, -137300,6.2847404,1.2415231,,,,,,,,,,,,,, -137400,6.319782,1.2371402,,,,,,,,,,,,,, -137500,6.0126853,1.2057283,,,,,,,,,,,,,, -137600,7.405254,1.3069403,,,,,,,,,,,,,, -137700,6.7840977,1.2647898,,,,,,,,,,,,,, -137800,6.165741,1.1365422,,,,,,,,,,,,,, -137900,6.769141,1.3064271,,,,,,,,,,,,,, -138000,7.039729,1.282558,,,,,,,,,,,,,, -138100,6.796441,1.2362498,,,,,,,,,,,,,, -138200,7.3363695,1.3606392,,,,,,,,,,,,,, -138300,6.301425,1.1881977,,,,,,,,,,,,,, -138400,6.4872317,1.3350928,,,,,,,,,,,,,, -138500,7.1537495,1.1769599,,,,,,,,,,,,,, -138521,,,0.8116230964660645,0.6831390261650085,0.7134799957275391,1.1457061767578125,50000.0,0.5888000130653381,1.8496946096420288,10000.0,46964.32555747032,48620.41685676575,46964.32555747032,1647.4365646839142,3.780620813369751,0.0 -138600,6.3435507,1.1370721,,,,,,,,,,,,,, -138700,6.084204,1.2580073,,,,,,,,,,,,,, -138800,7.5528083,1.2461718,,,,,,,,,,,,,, -138900,7.259512,1.2131473,,,,,,,,,,,,,, -139000,7.051343,1.3127465,,,,,,,,,,,,,, -139100,6.564562,1.1989579,,,,,,,,,,,,,, -139200,6.0761213,1.1596098,,,,,,,,,,,,,, -139300,6.3965487,1.1390244,,,,,,,,,,,,,, -139400,7.0550265,1.2773076,,,,,,,,,,,,,, -139500,5.960484,1.132317,,,,,,,,,,,,,, -139600,6.4687185,1.1949561,,,,,,,,,,,,,, -139700,6.490782,1.2096747,,,,,,,,,,,,,, -139800,7.4892054,1.2669915,,,,,,,,,,,,,, -139900,6.15229,1.1581743,,,,,,,,,,,,,, -140000,6.8438296,1.1059737,,,,,,,,,,,,,, -140029,,,0.8157086968421936,0.6734943389892578,0.7203199863433838,1.1174829006195068,50000.0,0.5944000482559204,1.816175103187561,10000.0,47474.33515691757,49147.82697439194,47474.33515691757,1664.7287590503693,3.834901571273804,0.0 -140100,6.765304,1.178851,,,,,,,,,,,,,, -140200,6.9281135,1.1872119,,,,,,,,,,,,,, -140300,6.370327,1.1921577,,,,,,,,,,,,,, -140400,6.7205453,1.2481073,,,,,,,,,,,,,, -140500,6.4419985,1.1974168,,,,,,,,,,,,,, -140600,7.26025,1.2322961,,,,,,,,,,,,,, -140700,6.3667483,1.1399988,,,,,,,,,,,,,, -140800,6.4384828,1.2000864,,,,,,,,,,,,,, -140900,6.8364,1.1294698,,,,,,,,,,,,,, -141000,7.029639,1.1787935,,,,,,,,,,,,,, -141100,6.7860756,1.2346245,,,,,,,,,,,,,, -141200,6.6795077,1.1438259,,,,,,,,,,,,,, -141300,6.508142,1.1806476,,,,,,,,,,,,,, -141400,7.9570875,1.1535981,,,,,,,,,,,,,, -141500,6.6602564,1.2340022,,,,,,,,,,,,,, -141537,,,0.8167450428009033,0.6674581170082092,0.7184199690818787,1.126118779182434,50000.0,0.597100019454956,1.8183460235595703,10000.0,47984.39821815491,49675.356055021286,47984.39821815491,1682.08682847023,3.888970851898194,0.0 -141600,6.9043856,1.2185943,,,,,,,,,,,,,, -141700,6.6265874,1.150207,,,,,,,,,,,,,, -141800,7.0973983,1.1961697,,,,,,,,,,,,,, -141900,6.439253,1.1911641,,,,,,,,,,,,,, -142000,6.362183,1.1830764,,,,,,,,,,,,,, -142100,7.633022,1.1307472,,,,,,,,,,,,,, -142200,6.3550677,1.1514678,,,,,,,,,,,,,, -142300,6.5214486,1.1071919,,,,,,,,,,,,,, -142400,6.887403,1.1922398,,,,,,,,,,,,,, -142500,6.3676887,1.1252508,,,,,,,,,,,,,, -142600,7.062135,1.2160553,,,,,,,,,,,,,, -142700,7.3447986,1.225691,,,,,,,,,,,,,, -142800,7.2070436,1.2274294,,,,,,,,,,,,,, -142900,6.843334,1.2503556,,,,,,,,,,,,,, -143000,7.179265,1.1569041,,,,,,,,,,,,,, -143045,,,0.8122807741165161,0.6797454953193665,0.7173999547958374,1.137130618095398,50000.0,0.5925000309944153,1.8193696737289429,10000.0,48494.36148881912,50203.234813690186,48494.36148881912,1699.8958258628843,3.943979978561402,0.0 -143100,7.2940173,1.1518118,,,,,,,,,,,,,, -143200,6.943269,1.1066473,,,,,,,,,,,,,, -143300,7.1201277,1.1786962,,,,,,,,,,,,,, -143400,7.98312,1.2446206,,,,,,,,,,,,,, -143500,8.99149,1.1804571,,,,,,,,,,,,,, -143600,6.963779,1.2198474,,,,,,,,,,,,,, -143700,7.1547327,1.1962425,,,,,,,,,,,,,, -143800,6.914625,1.145501,,,,,,,,,,,,,, -143900,6.8428655,1.0665972,,,,,,,,,,,,,, -144000,7.3364396,1.1632257,,,,,,,,,,,,,, -144100,7.861058,1.176566,,,,,,,,,,,,,, -144200,6.4863935,1.0564029,,,,,,,,,,,,,, -144300,7.2398076,1.0977951,,,,,,,,,,,,,, -144400,7.506397,1.2557223,,,,,,,,,,,,,, -144500,7.046339,1.1297883,,,,,,,,,,,,,, -144553,,,0.8585180044174194,0.5102043747901917,0.7265799641609192,1.0943148136138916,50000.0,0.6016000509262085,1.7970024347305298,10000.0,49004.41987133026,50730.99198412895,49004.41987133026,1717.4889857769012,3.996210813522339,0.0 -144600,7.633311,1.1401478,,,,,,,,,,,,,, -144700,7.5196743,1.2197261,,,,,,,,,,,,,, -144800,7.3024387,1.1068077,,,,,,,,,,,,,, -144900,7.288255,1.2379968,,,,,,,,,,,,,, -145000,7.291068,1.1575009,,,,,,,,,,,,,, -145100,6.7697587,1.1260178,,,,,,,,,,,,,, -145200,7.1251316,1.1362476,,,,,,,,,,,,,, -145300,6.8331966,1.2422335,,,,,,,,,,,,,, -145400,6.463786,1.1504447,,,,,,,,,,,,,, -145500,7.038015,1.021421,,,,,,,,,,,,,, -145600,7.5960264,1.1598022,,,,,,,,,,,,,, -145700,7.0969095,1.113445,,,,,,,,,,,,,, -145800,7.32755,1.0969548,,,,,,,,,,,,,, -145900,7.565307,1.1405098,,,,,,,,,,,,,, -146000,7.2725835,1.1777062,,,,,,,,,,,,,, -146060,,,0.8483139276504517,0.5484762787818909,0.729420006275177,1.085327386856079,50000.0,0.6035000085830688,1.7772631645202637,10000.0,49514.39406490326,51258.76909947395,49514.39406490326,1735.1880152225494,4.048548221588135,0.0 -146100,6.9537005,1.0809265,,,,,,,,,,,,,, -146200,6.7585816,1.1196258,,,,,,,,,,,,,, -146300,6.7589626,1.126696,,,,,,,,,,,,,, -146400,7.4357805,1.1207006,,,,,,,,,,,,,, -146500,6.9478846,1.0994457,,,,,,,,,,,,,, -146600,7.6009417,1.176624,,,,,,,,,,,,,, -146700,7.296793,1.0855758,,,,,,,,,,,,,, -146800,8.053946,1.045233,,,,,,,,,,,,,, -146900,7.1204686,1.1527528,,,,,,,,,,,,,, -147000,6.8712277,1.1096761,,,,,,,,,,,,,, -147100,6.7532153,1.1097816,,,,,,,,,,,,,, -147200,8.361964,1.1049886,,,,,,,,,,,,,, -147300,7.053365,1.2074263,,,,,,,,,,,,,, -147400,8.481319,1.0951233,,,,,,,,,,,,,, -147500,7.1523895,1.1139233,,,,,,,,,,,,,, -147568,,,0.8409398794174194,0.5652621984481812,0.7280199527740479,1.085485339164734,50000.0,0.603600025177002,1.786201238632202,10000.0,50024.444157123566,51786.662527799606,50024.444157123566,1752.9253692626953,4.101878881454468,0.0 -147600,7.1521683,1.0010135,,,,,,,,,,,,,, -147700,7.742622,1.0778829,,,,,,,,,,,,,, -147800,7.1842833,1.0990285,,,,,,,,,,,,,, -147900,7.02501,1.0424248,,,,,,,,,,,,,, -148000,7.684538,1.093638,,,,,,,,,,,,,, -148100,7.6121073,1.0409188,,,,,,,,,,,,,, -148200,7.6669817,1.2155317,,,,,,,,,,,,,, -148300,7.531575,1.1013836,,,,,,,,,,,,,, -148400,6.580113,1.1271703,,,,,,,,,,,,,, -148500,7.917426,1.1052727,,,,,,,,,,,,,, -148600,7.4026723,1.2295117,,,,,,,,,,,,,, -148700,8.061416,1.0605496,,,,,,,,,,,,,, -148800,8.897592,1.1380177,,,,,,,,,,,,,, -148900,7.3109035,1.0944005,,,,,,,,,,,,,, -149000,7.1880794,1.0579604,,,,,,,,,,,,,, -149075,,,0.8464604616165161,0.5500763654708862,0.7320799827575684,1.0699559450149536,50000.0,0.610200047492981,1.7584848403930664,10000.0,50534.34899163246,52314.1056895256,50534.34899163246,1770.3562195301056,4.158680438995361,0.0 -149100,7.726055,1.106259,,,,,,,,,,,,,, -149200,7.465136,1.2025067,,,,,,,,,,,,,, -149300,7.7654567,1.1267451,,,,,,,,,,,,,, -149400,7.4731116,1.1320217,,,,,,,,,,,,,, -149500,8.372238,1.1124343,,,,,,,,,,,,,, -149600,7.4600353,1.136914,,,,,,,,,,,,,, -149700,7.749041,1.0802453,,,,,,,,,,,,,, -149800,7.3328753,1.1853195,,,,,,,,,,,,,, -149900,7.2798615,1.1113935,,,,,,,,,,,,,, -150000,7.6099205,1.0182106,,,,,,,,,,,,,, -150100,7.3559937,1.0814769,,,,,,,,,,,,,, -150200,6.9421883,0.98553896,,,,,,,,,,,,,, -150300,7.9013085,1.1021919,,,,,,,,,,,,,, -150400,7.2728753,1.0946038,,,,,,,,,,,,,, -150500,8.044041,1.0564814,,,,,,,,,,,,,, -150584,,,0.8461614847183228,0.5413140058517456,0.7317999601364136,1.075062870979309,50000.0,0.612500011920929,1.7552403211593628,10000.0,51044.57275414467,52842.02047371864,51044.57275414467,1787.9360961914062,4.216421842575073,0.0 -150600,7.2758474,0.9999207,,,,,,,,,,,,,, -150700,7.90305,1.0193442,,,,,,,,,,,,,, -150800,7.447611,1.1685525,,,,,,,,,,,,,, -150900,7.3496675,1.0464604,,,,,,,,,,,,,, -151000,8.014072,1.058711,,,,,,,,,,,,,, -151100,7.2486267,1.0144755,,,,,,,,,,,,,, -151200,7.6618876,1.0487143,,,,,,,,,,,,,, -151300,7.257572,1.0419902,,,,,,,,,,,,,, -151400,6.975829,0.93494993,,,,,,,,,,,,,, -151500,7.9258842,0.99519104,,,,,,,,,,,,,, -151600,7.374549,1.0162851,,,,,,,,,,,,,, -151700,7.439957,1.110991,,,,,,,,,,,,,, -151800,7.7928843,0.99659896,,,,,,,,,,,,,, -151900,6.9374375,1.0500077,,,,,,,,,,,,,, -152000,8.082174,0.99496144,,,,,,,,,,,,,, -152093,,,0.8462013602256775,0.5413247346878052,0.737339973449707,1.056695580482483,50000.0,0.6089000105857849,1.747403621673584,10000.0,51554.79049611092,53369.72068047524,51554.79049611092,1805.3098711967468,4.273014545440674,0.0 -152100,7.730215,0.99937224,,,,,,,,,,,,,, -152200,7.5775275,0.98824835,,,,,,,,,,,,,, -152300,7.7585096,1.0070716,,,,,,,,,,,,,, -152400,7.7372184,0.9678758,,,,,,,,,,,,,, -152500,8.187795,1.0731289,,,,,,,,,,,,,, -152600,8.147376,1.0528729,,,,,,,,,,,,,, -152700,7.962455,1.1722544,,,,,,,,,,,,,, -152800,8.157742,1.0837559,,,,,,,,,,,,,, -152900,7.7121205,1.0341591,,,,,,,,,,,,,, -153000,8.398786,1.0107616,,,,,,,,,,,,,, -153100,7.5030985,0.9838613,,,,,,,,,,,,,, -153200,7.5157075,1.0026901,,,,,,,,,,,,,, -153300,7.077675,0.9863365,,,,,,,,,,,,,, -153400,8.046303,1.0077548,,,,,,,,,,,,,, -153500,7.8590603,1.0578053,,,,,,,,,,,,,, -153600,7.8194857,1.064039,,,,,,,,,,,,,, -153601,,,0.8833306431770325,0.4207604527473449,0.7371399998664856,1.053296685218811,50000.0,0.612000048160553,1.7551151514053345,10000.0,52065.01314878464,53897.64444494248,52065.01314878464,1822.9050323963163,4.3264100551605225,0.0 -153700,7.4557147,0.9551532,,,,,,,,,,,,,, -153800,8.10491,1.0044221,,,,,,,,,,,,,, -153900,8.357875,1.0131855,,,,,,,,,,,,,, -154000,8.448686,1.0655047,,,,,,,,,,,,,, -154100,7.738092,0.9873272,,,,,,,,,,,,,, -154200,8.116047,1.0828667,,,,,,,,,,,,,, -154300,7.3111873,0.94825965,,,,,,,,,,,,,, -154400,7.932472,0.96639884,,,,,,,,,,,,,, -154500,7.422262,1.04627,,,,,,,,,,,,,, -154600,7.451042,0.9956161,,,,,,,,,,,,,, -154700,7.5868244,0.97265136,,,,,,,,,,,,,, -154800,7.7069263,0.9898337,,,,,,,,,,,,,, -154900,7.812803,1.0021656,,,,,,,,,,,,,, -155000,7.7344923,1.0676185,,,,,,,,,,,,,, -155100,7.3204913,1.024667,,,,,,,,,,,,,, -155109,,,0.8740832209587097,0.4415238797664642,0.7404199838638306,1.0409177541732788,50000.0,0.6145000457763672,1.7276402711868286,10000.0,52575.09206676483,54425.57376766205,52575.09206676483,1840.645209312439,4.384929418563843,0.0 -155200,8.193318,1.0705608,,,,,,,,,,,,,, -155300,7.6333356,1.0390016,,,,,,,,,,,,,, -155400,8.647014,1.0524893,,,,,,,,,,,,,, -155500,8.123226,0.97583616,,,,,,,,,,,,,, -155600,7.599722,1.0406319,,,,,,,,,,,,,, -155700,8.817473,1.0571178,,,,,,,,,,,,,, -155800,7.2297444,0.9839936,,,,,,,,,,,,,, -155900,8.492619,1.0550914,,,,,,,,,,,,,, -156000,8.191564,0.9461438,,,,,,,,,,,,,, -156100,8.0589285,0.9873298,,,,,,,,,,,,,, -156200,7.8125,1.0941927,,,,,,,,,,,,,, -156300,7.986753,0.92324984,,,,,,,,,,,,,, -156400,7.502875,0.9620941,,,,,,,,,,,,,, -156500,7.3202252,1.0076197,,,,,,,,,,,,,, -156600,7.474123,0.8678957,,,,,,,,,,,,,, -156617,,,0.8761957883834839,0.4427817463874817,0.7436800003051758,1.0364357233047483,50000.0,0.6178000569343567,1.732064127922058,10000.0,53084.98903775215,54953.56215620041,53084.98903775215,1858.6266074180603,4.439204216003418,0.0 -156700,8.140549,0.9808198,,,,,,,,,,,,,, -156800,7.9542675,0.99569035,,,,,,,,,,,,,, -156900,7.539258,0.90215224,,,,,,,,,,,,,, -157000,8.781932,0.8766464,,,,,,,,,,,,,, -157100,7.481469,0.93624306,,,,,,,,,,,,,, -157200,7.767143,1.0110254,,,,,,,,,,,,,, -157300,7.575869,0.94899356,,,,,,,,,,,,,, -157400,8.544243,1.0366392,,,,,,,,,,,,,, -157500,8.257792,0.92440844,,,,,,,,,,,,,, -157600,7.9148684,0.98106116,,,,,,,,,,,,,, -157700,7.7361197,0.959285,,,,,,,,,,,,,, -157800,8.316285,1.0521762,,,,,,,,,,,,,, -157900,8.378285,0.9608177,,,,,,,,,,,,,, -158000,9.718365,0.9186446,,,,,,,,,,,,,, -158100,7.6491623,0.95292884,,,,,,,,,,,,,, -158125,,,0.8801418542861938,0.4310168325901031,0.7454400062561035,1.02325177192688,50000.0,0.6204000115394592,1.7091346979141235,10000.0,53594.90204691887,55481.23839688301,53594.90204691887,1876.2782986164093,4.496425151824951,0.0 -158200,8.30259,0.8982784,,,,,,,,,,,,,, -158300,7.5835943,0.91186666,,,,,,,,,,,,,, -158400,7.7859716,0.86449903,,,,,,,,,,,,,, -158500,8.029497,1.0279607,,,,,,,,,,,,,, -158600,9.159482,1.0235782,,,,,,,,,,,,,, -158700,7.8489046,0.9787717,,,,,,,,,,,,,, -158800,7.2684455,0.8795154,,,,,,,,,,,,,, -158900,7.491494,0.94689286,,,,,,,,,,,,,, -159000,8.002368,0.91301835,,,,,,,,,,,,,, -159100,7.801846,0.9129372,,,,,,,,,,,,,, -159200,7.881913,0.9514958,,,,,,,,,,,,,, -159300,8.672527,0.9556307,,,,,,,,,,,,,, -159400,8.865687,0.97535455,,,,,,,,,,,,,, -159500,7.6203775,0.85354245,,,,,,,,,,,,,, -159600,8.979526,1.0186379,,,,,,,,,,,,,, -159633,,,0.8810586333274841,0.4157596826553345,0.7454599738121033,1.0245261192321775,50000.0,0.6222000122070312,1.717750072479248,10000.0,54104.91116786003,56009.05380535126,54104.91116786003,1893.973302364349,4.555343627929688,0.0 -159700,9.270966,0.92901164,,,,,,,,,,,,,, -159800,8.26221,0.89824426,,,,,,,,,,,,,, -159900,8.174073,0.9467077,,,,,,,,,,,,,, -160000,8.495528,0.8437643,,,,,,,,,,,,,, -160100,8.589655,0.89595294,,,,,,,,,,,,,, -160200,8.804291,0.90174234,,,,,,,,,,,,,, -160300,7.9943814,0.9139918,,,,,,,,,,,,,, -160400,9.422521,0.9308269,,,,,,,,,,,,,, -160500,8.934356,0.9466989,,,,,,,,,,,,,, -160600,8.662055,0.85432816,,,,,,,,,,,,,, -160700,8.103726,0.97026014,,,,,,,,,,,,,, -160800,8.406007,0.96764755,,,,,,,,,,,,,, -160900,8.990685,0.8900076,,,,,,,,,,,,,, -161000,8.308975,0.93557584,,,,,,,,,,,,,, -161100,9.337422,1.0440726,,,,,,,,,,,,,, -161141,,,0.8859813213348389,0.4051311314105987,0.7473999857902527,1.0201231241226196,50000.0,0.6239000558853149,1.703717827796936,10000.0,54615.10005736351,56537.24241113663,54615.10005736351,1911.866572141648,4.608543395996094,0.0 -161200,7.72065,0.89254016,,,,,,,,,,,,,, -161300,8.312663,0.90204704,,,,,,,,,,,,,, -161400,8.07931,0.908902,,,,,,,,,,,,,, -161500,8.206174,0.88531286,,,,,,,,,,,,,, -161600,8.534459,0.9705941,,,,,,,,,,,,,, -161700,7.7171206,0.86436355,,,,,,,,,,,,,, -161800,8.934224,0.8659434,,,,,,,,,,,,,, -161900,9.432147,0.91349185,,,,,,,,,,,,,, -162000,7.7148314,0.9090629,,,,,,,,,,,,,, -162100,8.491842,0.9392345,,,,,,,,,,,,,, -162200,8.665949,0.8895229,,,,,,,,,,,,,, -162300,9.186754,1.0195785,,,,,,,,,,,,,, -162400,9.832728,0.9849765,,,,,,,,,,,,,, -162500,8.653519,0.89956194,,,,,,,,,,,,,, -162600,8.573844,0.90364206,,,,,,,,,,,,,, -162649,,,0.9079041481018066,0.3245623707771301,0.7503199577331543,1.0081355571746826,50000.0,0.625700056552887,1.6878162622451782,10000.0,55125.14937853813,57065.24767613411,55125.14937853813,1929.71240234375,4.6652233600616455,0.0 -162700,8.878698,1.0118823,,,,,,,,,,,,,, -162800,8.381388,0.82707196,,,,,,,,,,,,,, -162900,8.49692,0.9379967,,,,,,,,,,,,,, -163000,8.715215,0.89789367,,,,,,,,,,,,,, -163100,7.9816628,0.82075596,,,,,,,,,,,,,, -163200,8.946406,0.9173849,,,,,,,,,,,,,, -163300,8.670243,0.8824347,,,,,,,,,,,,,, -163400,9.08826,0.8704233,,,,,,,,,,,,,, -163500,9.279474,0.94179976,,,,,,,,,,,,,, -163600,9.01095,0.9367347,,,,,,,,,,,,,, -163700,8.546757,0.8630162,,,,,,,,,,,,,, -163800,8.122144,0.8267287,,,,,,,,,,,,,, -163900,7.9895372,0.94235814,,,,,,,,,,,,,, -164000,9.509092,0.8219323,,,,,,,,,,,,,, -164100,8.821361,0.86759853,,,,,,,,,,,,,, -164157,,,0.904496133327484,0.3379042744636535,0.7519399523735046,1.0083187818527222,50000.0,0.6246000528335571,1.704638957977295,10000.0,55635.22860836983,57592.82874393463,55635.22860836983,1947.10103392601,4.725534439086914,0.0 -164200,9.649421,0.84992075,,,,,,,,,,,,,, -164300,9.352214,0.89643914,,,,,,,,,,,,,, -164400,8.8216305,0.9150908,,,,,,,,,,,,,, -164500,8.347014,0.88012147,,,,,,,,,,,,,, -164600,8.214928,0.9059078,,,,,,,,,,,,,, -164700,7.761392,0.8440553,,,,,,,,,,,,,, -164800,8.42916,0.8821788,,,,,,,,,,,,,, -164900,8.523624,0.8637755,,,,,,,,,,,,,, -165000,8.3353615,0.8474949,,,,,,,,,,,,,, -165100,8.601655,0.90844214,,,,,,,,,,,,,, -165200,8.270277,0.8472856,,,,,,,,,,,,,, -165300,8.159276,0.89843804,,,,,,,,,,,,,, -165400,8.991665,0.8486089,,,,,,,,,,,,,, -165500,8.868814,0.84064317,,,,,,,,,,,,,, -165600,8.555185,0.8238571,,,,,,,,,,,,,, -165665,,,0.9044164419174194,0.3358021378517151,0.7523199915885925,0.9994009137153624,50000.0,0.628600001335144,1.6860158443450928,10000.0,56145.23153233528,58120.586990594864,56145.23153233528,1964.745409488678,4.782070636749268,0.0 -165700,8.855747,0.8293226,,,,,,,,,,,,,, -165800,8.480935,0.8590057,,,,,,,,,,,,,, -165900,8.457761,0.80055875,,,,,,,,,,,,,, -166000,10.226591,0.79155344,,,,,,,,,,,,,, -166100,7.734864,0.76642966,,,,,,,,,,,,,, -166200,8.253902,0.88281125,,,,,,,,,,,,,, -166300,10.187464,0.8598737,,,,,,,,,,,,,, -166400,9.514905,0.85337687,,,,,,,,,,,,,, -166500,8.240789,0.7843699,,,,,,,,,,,,,, -166600,8.455405,0.83665633,,,,,,,,,,,,,, -166700,9.057286,0.88494384,,,,,,,,,,,,,, -166800,8.515297,0.8310803,,,,,,,,,,,,,, -166900,9.878424,0.764642,,,,,,,,,,,,,, -167000,9.359017,0.8356634,,,,,,,,,,,,,, -167100,8.289096,0.85374874,,,,,,,,,,,,,, -167173,,,0.9084821343421936,0.3253271281719208,0.7557199597358704,0.9877996444702148,50000.0,0.6303000450134277,1.696708917617798,10000.0,56655.2565972805,58648.39890432358,56655.2565972805,1982.4217264652248,4.839404821395874,0.0 -167200,8.843774,0.82105935,,,,,,,,,,,,,, -167300,8.811021,0.82828474,,,,,,,,,,,,,, -167400,8.659586,0.8236524,,,,,,,,,,,,,, -167500,8.144775,0.8139154,,,,,,,,,,,,,, -167600,8.864811,0.77789485,,,,,,,,,,,,,, -167700,10.238482,0.8699614,,,,,,,,,,,,,, -167800,9.138619,0.8355366,,,,,,,,,,,,,, -167900,9.4013195,0.82004905,,,,,,,,,,,,,, -168000,9.215381,0.8160635,,,,,,,,,,,,,, -168100,10.007153,0.8217164,,,,,,,,,,,,,, -168200,8.477957,0.7775792,,,,,,,,,,,,,, -168300,8.685454,0.784358,,,,,,,,,,,,,, -168400,8.793345,0.86034936,,,,,,,,,,,,,, -168500,7.9889483,0.8185907,,,,,,,,,,,,,, -168600,9.471287,0.87239814,,,,,,,,,,,,,, -168680,,,0.9084821343421936,0.3189391791820526,0.7572000026702881,0.9814531207084656,50000.0,0.6302000284194946,1.6696397066116333,10000.0,57165.24439907074,59176.14205694199,57165.24439907074,2000.068875551224,4.895292043685913,0.0 -168700,8.508657,0.85864866,,,,,,,,,,,,,, -168800,10.331773,0.8683077,,,,,,,,,,,,,, -168900,8.442363,0.7950059,,,,,,,,,,,,,, -169000,8.598603,0.8256201,,,,,,,,,,,,,, -169100,7.9134207,0.7180321,,,,,,,,,,,,,, -169200,8.654848,0.78982115,,,,,,,,,,,,,, -169300,8.475311,0.8046667,,,,,,,,,,,,,, -169400,10.066752,0.90738344,,,,,,,,,,,,,, -169500,8.195788,0.7965511,,,,,,,,,,,,,, -169600,8.952694,0.846738,,,,,,,,,,,,,, -169700,10.093283,0.83736885,,,,,,,,,,,,,, -169800,7.701286,0.73530096,,,,,,,,,,,,,, -169900,9.479924,0.89405817,,,,,,,,,,,,,, -170000,8.493184,0.8415154,,,,,,,,,,,,,, -170100,9.761829,0.77224493,,,,,,,,,,,,,, -170187,,,0.9106544852256776,0.3087378144264221,0.7574599981307983,0.9799865484237672,50000.0,0.6338000297546387,1.6687065362930298,10000.0,57675.29413366318,59703.755058288574,57675.29413366318,2017.519765138626,4.954912900924683,0.0 -170200,8.371581,0.7962718,,,,,,,,,,,,,, -170300,9.856999,0.8392453,,,,,,,,,,,,,, -170400,9.112534,0.77196157,,,,,,,,,,,,,, -170500,9.184378,0.7771572,,,,,,,,,,,,,, -170600,9.725001,0.89025885,,,,,,,,,,,,,, -170700,8.953537,0.74108136,,,,,,,,,,,,,, -170800,9.315863,0.7674526,,,,,,,,,,,,,, -170900,9.185582,0.7591305,,,,,,,,,,,,,, -171000,9.197354,0.71728534,,,,,,,,,,,,,, -171100,9.521381,0.8279508,,,,,,,,,,,,,, -171200,9.272113,0.730859,,,,,,,,,,,,,, -171300,8.657754,0.8223077,,,,,,,,,,,,,, -171400,8.787059,0.7009287,,,,,,,,,,,,,, -171500,8.470106,0.76486427,,,,,,,,,,,,,, -171600,8.377724,0.74643123,,,,,,,,,,,,,, -171695,,,0.924226701259613,0.2707524597644806,0.757099986076355,0.9797187447547911,50000.0,0.6383000016212463,1.6713703870773315,10000.0,58185.4462184906,60231.59297156334,58185.4462184906,2035.09224319458,5.01508355140686,0.0 -171700,9.544834,0.86919504,,,,,,,,,,,,,, -171800,8.847729,0.74465877,,,,,,,,,,,,,, -171900,9.054957,0.77650464,,,,,,,,,,,,,, -172000,9.095485,0.764777,,,,,,,,,,,,,, -172100,9.215189,0.82682735,,,,,,,,,,,,,, -172200,9.731484,0.82544297,,,,,,,,,,,,,, -172300,8.558599,0.7445369,,,,,,,,,,,,,, -172400,8.399072,0.7462869,,,,,,,,,,,,,, -172500,9.8952875,0.8028975,,,,,,,,,,,,,, -172600,9.327245,0.74708724,,,,,,,,,,,,,, -172700,9.282231,0.7740226,,,,,,,,,,,,,, -172800,10.086449,0.7353619,,,,,,,,,,,,,, -172900,10.435882,0.81860316,,,,,,,,,,,,,, -173000,8.729208,0.7310967,,,,,,,,,,,,,, -173100,9.080298,0.82967144,,,,,,,,,,,,,, -173200,9.98905,0.7846443,,,,,,,,,,,,,, -173202,,,0.9278340339660645,0.2625244557857513,0.758899986743927,0.9757165312767028,50000.0,0.6396000385284424,1.6715961694717407,10000.0,58695.36663389206,60759.16719126701,58695.36663389206,2052.647294998169,5.063016414642334,0.0 -173300,8.854166,0.8203432,,,,,,,,,,,,,, -173400,9.611991,0.68728673,,,,,,,,,,,,,, -173500,8.691309,0.76725435,,,,,,,,,,,,,, -173600,9.39477,0.80470806,,,,,,,,,,,,,, -173700,10.401324,0.8028625,,,,,,,,,,,,,, -173800,9.92518,0.8395438,,,,,,,,,,,,,, -173900,9.381167,0.8204537,,,,,,,,,,,,,, -174000,10.064153,0.78955257,,,,,,,,,,,,,, -174100,9.071093,0.7926024,,,,,,,,,,,,,, -174200,9.859471,0.7862271,,,,,,,,,,,,,, -174300,8.91721,0.78238034,,,,,,,,,,,,,, -174400,10.126909,0.76956266,,,,,,,,,,,,,, -174500,8.615598,0.80781996,,,,,,,,,,,,,, -174600,9.049289,0.8506655,,,,,,,,,,,,,, -174700,9.075789,0.8227005,,,,,,,,,,,,,, -174710,,,0.9278539419174194,0.2555326819419861,0.7615399956703186,0.9703835844993592,50000.0,0.6392000317573547,1.6742520332336426,10000.0,59205.45822405815,61286.8631067276,59205.45822405815,2070.1429677009583,5.118837833404541,0.0 -174800,9.366888,0.7452716,,,,,,,,,,,,,, -174900,9.0643835,0.7776153,,,,,,,,,,,,,, -175000,8.548725,0.79035556,,,,,,,,,,,,,, -175100,9.889936,0.78352046,,,,,,,,,,,,,, -175200,9.544049,0.7853558,,,,,,,,,,,,,, -175300,9.813369,0.83786285,,,,,,,,,,,,,, -175400,9.621998,0.703187,,,,,,,,,,,,,, -175500,9.349079,0.84231365,,,,,,,,,,,,,, -175600,8.680082,0.73962617,,,,,,,,,,,,,, -175700,8.955928,0.7982542,,,,,,,,,,,,,, -175800,9.1938925,0.7644135,,,,,,,,,,,,,, -175900,9.617907,0.8624522,,,,,,,,,,,,,, -176000,9.666374,0.7938887,,,,,,,,,,,,,, -176100,8.290899,0.72400004,,,,,,,,,,,,,, -176200,9.638471,0.80555856,,,,,,,,,,,,,, -176218,,,0.926418960094452,0.2620689570903778,0.7624199986457825,0.9670127034187316,50000.0,0.6402000188827515,1.6667834520339966,10000.0,59715.56953692436,61814.30304598808,59715.56953692436,2087.3609421253204,5.175340414047241,0.0 -176300,9.706517,0.78875923,,,,,,,,,,,,,, -176400,9.371303,0.8009281,,,,,,,,,,,,,, -176500,9.376794,0.7876346,,,,,,,,,,,,,, -176600,10.129209,0.7197939,,,,,,,,,,,,,, -176700,9.140056,0.82052606,,,,,,,,,,,,,, -176800,10.281442,0.78127694,,,,,,,,,,,,,, -176900,10.253676,0.78658694,,,,,,,,,,,,,, -177000,10.034853,0.7964683,,,,,,,,,,,,,, -177100,9.255141,0.8384554,,,,,,,,,,,,,, -177200,8.99334,0.71582353,,,,,,,,,,,,,, -177300,10.013929,0.7475413,,,,,,,,,,,,,, -177400,8.3378315,0.70633906,,,,,,,,,,,,,, -177500,10.473986,0.81090015,,,,,,,,,,,,,, -177600,9.311041,0.7104384,,,,,,,,,,,,,, -177700,10.064758,0.76673526,,,,,,,,,,,,,, -177726,,,0.927156388759613,0.2595357000827789,0.7625799775123596,0.9643649458885192,50000.0,0.6399000287055969,1.6655884981155396,10000.0,60225.75075531006,62342.37773346901,60225.75075531006,2105.146994113922,5.229976654052734,0.0 -177800,8.890577,0.7921909,,,,,,,,,,,,,, -177900,9.470848,0.83069164,,,,,,,,,,,,,, -178000,9.75574,0.7569425,,,,,,,,,,,,,, -178100,9.075534,0.7561671,,,,,,,,,,,,,, -178200,9.26366,0.742368,,,,,,,,,,,,,, -178300,9.373938,0.760157,,,,,,,,,,,,,, -178400,9.003464,0.7115295,,,,,,,,,,,,,, -178500,9.169102,0.73539793,,,,,,,,,,,,,, -178600,9.310589,0.8582133,,,,,,,,,,,,,, -178700,9.106975,0.7552422,,,,,,,,,,,,,, -178800,10.328778,0.7558171,,,,,,,,,,,,,, -178900,9.725902,0.7696839,,,,,,,,,,,,,, -179000,8.26182,0.74324065,,,,,,,,,,,,,, -179100,9.593958,0.8052176,,,,,,,,,,,,,, -179200,9.086088,0.71143544,,,,,,,,,,,,,, -179234,,,0.9292888641357422,0.2532636523246765,0.7625399827957153,0.962642788887024,50000.0,0.6411000490188599,1.6631746292114258,10000.0,60735.92390227318,62870.22953939438,60735.92390227318,2122.712076902389,5.291700601577759,0.0 -179300,9.29757,0.7393747,,,,,,,,,,,,,, -179400,8.921832,0.7113048,,,,,,,,,,,,,, -179500,9.126083,0.75660837,,,,,,,,,,,,,, -179600,8.977776,0.7159683,,,,,,,,,,,,,, -179700,8.12385,0.7000072,,,,,,,,,,,,,, -179800,9.175866,0.7199036,,,,,,,,,,,,,, -179900,8.641133,0.6904889,,,,,,,,,,,,,, -180000,8.731987,0.6965074,,,,,,,,,,,,,, -180100,10.522349,0.7849077,,,,,,,,,,,,,, -180200,9.596565,0.6941109,,,,,,,,,,,,,, -180300,9.186325,0.72768974,,,,,,,,,,,,,, -180400,9.392252,0.7778475,,,,,,,,,,,,,, -180500,9.686813,0.73684925,,,,,,,,,,,,,, -180600,10.348262,0.7079736,,,,,,,,,,,,,, -180700,10.419376,0.7901297,,,,,,,,,,,,,, -180742,,,0.9323381781578064,0.2413637936115265,0.763480007648468,0.9599735736846924,50000.0,0.6431000232696533,1.6595665216445925,10000.0,61246.10619521141,63398.17158651352,61246.10619521141,2140.351333856582,5.360998392105103,0.0 -180800,9.593069,0.757032,,,,,,,,,,,,,, -180900,9.664256,0.75639385,,,,,,,,,,,,,, -181000,9.246564,0.7217957,,,,,,,,,,,,,, -181100,9.737207,0.8004011,,,,,,,,,,,,,, -181200,8.600393,0.7344938,,,,,,,,,,,,,, -181300,8.744046,0.6636405,,,,,,,,,,,,,, -181400,8.919175,0.7773916,,,,,,,,,,,,,, -181500,8.330708,0.6632563,,,,,,,,,,,,,, -181600,10.105836,0.7571882,,,,,,,,,,,,,, -181700,8.90893,0.7457719,,,,,,,,,,,,,, -181800,9.566295,0.80583775,,,,,,,,,,,,,, -181900,8.955478,0.7698382,,,,,,,,,,,,,, -182000,9.090498,0.74251974,,,,,,,,,,,,,, -182100,9.354804,0.74568856,,,,,,,,,,,,,, -182200,9.361649,0.7770352,,,,,,,,,,,,,, -182249,,,0.9336535334587096,0.2402791529893875,0.7635599970817566,0.9584994316101074,50000.0,0.6432000398635864,1.664105772972107,10000.0,61756.06179380417,63925.83683180809,61756.06179380417,2157.9516232013702,5.419076681137085,0.0 -182300,8.87693,0.7389826,,,,,,,,,,,,,, -182400,10.202564,0.7616934,,,,,,,,,,,,,, -182500,8.573059,0.7132356,,,,,,,,,,,,,, -182600,8.452941,0.74010944,,,,,,,,,,,,,, -182700,10.234291,0.76664066,,,,,,,,,,,,,, -182800,10.032936,0.74743307,,,,,,,,,,,,,, -182900,10.340072,0.7729756,,,,,,,,,,,,,, -183000,9.866049,0.75956273,,,,,,,,,,,,,, -183100,9.8069105,0.76343477,,,,,,,,,,,,,, -183200,8.674625,0.70990777,,,,,,,,,,,,,, -183300,8.93265,0.8252755,,,,,,,,,,,,,, -183400,9.247288,0.758861,,,,,,,,,,,,,, -183500,9.279394,0.759166,,,,,,,,,,,,,, -183600,9.930377,0.8403345,,,,,,,,,,,,,, -183700,8.618325,0.65445393,,,,,,,,,,,,,, -183757,,,0.935566782951355,0.2406303137540817,0.7646999955177307,0.9566633105278016,50000.0,0.6440000534057617,1.6585973501205444,10000.0,62266.20957398415,64453.770429611206,62266.20957398415,2175.6238107681274,5.481184005737305,0.0 -183800,9.41692,0.73467064,,,,,,,,,,,,,, -183900,9.040458,0.7858412,,,,,,,,,,,,,, -184000,9.886035,0.79053575,,,,,,,,,,,,,, -184100,8.774415,0.7772209,,,,,,,,,,,,,, -184200,9.38826,0.7800963,,,,,,,,,,,,,, -184300,8.993067,0.71016604,,,,,,,,,,,,,, -184400,8.690764,0.7560521,,,,,,,,,,,,,, -184500,9.309555,0.8104429,,,,,,,,,,,,,, -184600,8.490863,0.708276,,,,,,,,,,,,,, -184700,9.124922,0.74100995,,,,,,,,,,,,,, -184800,10.091351,0.69761056,,,,,,,,,,,,,, -184900,9.107313,0.74487543,,,,,,,,,,,,,, -185000,9.335832,0.7238143,,,,,,,,,,,,,, -185100,9.469693,0.7665613,,,,,,,,,,,,,, -185200,10.078446,0.7091474,,,,,,,,,,,,,, -185265,,,0.9326968789100648,0.2426381111145019,0.7636399865150452,0.9573678374290466,50000.0,0.6434000134468079,1.6598241329193115,10000.0,62776.36577916145,64981.38829827309,62776.36577916145,2192.9691956043243,5.543259382247925,0.0 -185300,9.857883,0.6966859,,,,,,,,,,,,,, -185400,9.264001,0.7182425,,,,,,,,,,,,,, -185500,9.430655,0.71097195,,,,,,,,,,,,,, -185600,9.443064,0.79917115,,,,,,,,,,,,,, -185700,8.747757,0.7139369,,,,,,,,,,,,,, -185800,9.12613,0.7208414,,,,,,,,,,,,,, -185900,9.11973,0.7582873,,,,,,,,,,,,,, -185951,,,,,,,,,,,63008.286588430405,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/eval_measurements.csv deleted file mode 100644 index c52443632..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,126 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.589736700057983,0.0,33.78456735610962,1,0,33.78456735610962,0.0013000001199543,6.9117279052734375,10000,51.374412059783936,0.0007772640092298,6.911434173583984,0.0006799999973736,6.912051200866699,50000 -35.278390645980835,0.01981782913208,543.8763875961304,1500,0,543.8763875961304,0.1568000018596649,4.447728633880615,10000,579.2266387939453,0.2221181392669677,3.8356406688690186,0.2069999873638153,3.960334539413452,50000 -53.212013721466064,0.0511689186096191,1054.0921666622162,3001,0,1054.0921666622162,0.2457000166177749,3.8936009407043457,10000,1107.459528684616,0.3391063511371612,3.0905704498291016,0.3193399906158447,3.2351529598236084,50000 -70.99523687362671,0.0811347961425781,1564.1626377105713,4504,0,1564.1626377105713,0.2432000041007995,4.029863357543945,10000,1635.3940970897677,0.3458625674247741,3.1005280017852783,0.3310000002384186,3.232429265975952,50000 -88.70134115219116,0.1119053363800048,2074.3404870033264,6008,0,2074.3404870033264,0.2597000002861023,3.796529531478882,10000,2163.361767053604,0.3782485425472259,2.895301818847656,0.3356399834156036,3.18208122253418,50000 -106.30011343955994,0.1416637897491455,2584.437113523484,7513,0,2584.437113523484,0.1861000061035156,4.576222896575928,10000,2691.138946533203,0.2871890962123871,3.638346910476685,0.2617000043392181,3.8255012035369873,50000 -124.0553548336029,0.1691560745239257,3094.679023504257,9020,0,3094.679023504257,0.2747000157833099,3.7215497493743896,10000,3219.2161860466003,0.3981983363628387,2.802988052368164,0.3658399879932403,3.020524501800537,50000 -142.4556610584259,0.1979579925537109,3604.599910259247,10527,0,3604.599910259247,0.2690000236034393,3.798298120498657,10000,3747.6188309192657,0.3869379758834839,2.846888542175293,0.3642399907112121,3.010152101516724,50000 -160.00274682044983,0.2236948013305664,4114.688268184662,12035,0,4114.688268184662,0.2058000117540359,4.551807880401611,10000,4275.333572149277,0.302156388759613,3.548335313796997,0.2874999940395355,3.705745458602905,50000 -177.63978910446167,0.2553791999816894,4624.764115333557,13543,0,4624.764115333557,0.1577000021934509,5.386991500854492,10000,4803.130608320236,0.2102798074483871,4.674249649047852,0.1959799975156784,4.838823795318604,50000 -195.9572970867157,0.2865011692047119,5134.7330057621,15052,0,5134.7330057621,0.1561000049114227,5.416417598724365,10000,5331.500044822693,0.2354113459587097,4.341064929962158,0.2106999903917312,4.6160736083984375,50000 -213.5535204410553,0.3168478012084961,5644.699779033661,16562,0,5644.699779033661,0.1513000130653381,4.930394649505615,10000,5859.146992444992,0.2307876199483871,4.050477504730225,0.2169799953699112,4.200684070587158,50000 -231.6983737945557,0.3491621017456054,6154.821240901947,18073,0,6154.821240901947,0.2099000066518783,4.276243209838867,10000,6387.498101949692,0.3052654564380646,3.361721992492676,0.2855399847030639,3.5417702198028564,50000 -249.51444363594047,0.3746833801269531,6664.954342365265,19584,0,6664.954342365265,0.1702000051736831,4.693854331970215,10000,6915.529188632965,0.2442004084587097,3.913287878036499,0.230119988322258,4.02189826965332,50000 -267.2611298561096,0.4081254005432129,7175.1495196819305,21096,0,7175.1495196819305,0.1699000149965286,4.684741973876953,10000,7443.557811498642,0.242267221212387,3.892278432846069,0.2260199934244156,4.00605058670044,50000 -284.9658041000366,0.439424991607666,7685.252676725388,22609,0,7685.252676725388,0.1548000127077102,5.273361682891846,10000,7971.450614213943,0.2122528702020645,4.563592433929443,0.2013600021600723,4.659169673919678,50000 -302.4214491844177,0.4703423976898193,8195.383947610855,24121,0,8195.383947610855,0.1574000120162964,5.048336029052734,10000,8499.119955062866,0.2415298074483871,4.111358165740967,0.2184399962425232,4.350162506103516,50000 -320.2576837539673,0.502802848815918,8705.445008277893,25634,0,8705.445008277893,0.1791000068187713,4.636255741119385,10000,9027.102459907532,0.2726801633834839,3.768687009811402,0.2542800009250641,3.919223308563232,50000 -337.85802817344666,0.5298073291778564,9215.368250131609,27146,0,9215.368250131609,0.1522000133991241,5.444266319274902,10000,9554.7049472332,0.2140664756298065,4.525050640106201,0.2014999985694885,4.717775344848633,50000 -355.6845571994781,0.5628311634063721,9725.415110111237,28659,0,9725.415110111237,0.1374000012874603,5.371987819671631,10000,10082.664724826813,0.1899513602256775,4.6598968505859375,0.1812800019979477,4.777307510375977,50000 -373.0587408542633,0.594775915145874,10235.613961458206,30172,0,10235.613961458206,0.1344000101089477,5.338951110839844,10000,10610.321287631989,0.1981624662876129,4.559650897979736,0.1924600005149841,4.613406181335449,50000 -390.7978858947754,0.6274514198303223,10745.81271648407,31686,0,10745.81271648407,0.228300005197525,4.332971572875977,10000,11138.345761299131,0.3178212642669678,3.394036054611206,0.3035599887371063,3.541114330291748,50000 -408.22711849212646,0.6634583473205566,11255.900620937347,33200,0,11255.900620937347,0.1111000031232833,5.935586452484131,10000,11665.95344877243,0.1719148606061935,5.012552261352539,0.1557399928569793,5.170931339263916,50000 -425.9146020412445,0.7024168968200684,11766.076352834702,34714,0,11766.076352834702,0.2261000126600265,4.192788124084473,10000,12193.907732248306,0.3223851919174194,3.268257141113281,0.3014200031757355,3.499361515045166,50000 -443.42841243743896,0.7372820377349854,12276.239780902864,36228,0,12276.239780902864,0.1833000034093856,4.704871654510498,10000,12721.67400622368,0.2633928656578064,3.819126605987549,0.2487399876117706,3.92941689491272,50000 -461.1026515960693,0.7711968421936035,12786.358525514604,37742,0,12786.358525514604,0.2308000177145004,4.0718674659729,10000,13249.55362534523,0.3377511203289032,3.210239887237549,0.3193999826908111,3.303650140762329,50000 -478.5720095634461,0.8064799308776855,13296.550779342651,39256,0,13296.550779342651,0.1162000074982643,5.706114768981934,10000,13777.302340745926,0.1889150142669677,4.7008209228515625,0.1738400012254715,4.816714763641357,50000 -495.9221394062042,0.8445472717285156,13806.753144979475,40770,0,13806.753144979475,0.064300000667572,7.548539638519287,10000,14304.94440293312,0.0982541441917419,6.791675090789795,0.092299997806549,6.967357158660889,50000 -513.688027381897,0.8791120052337646,14316.756876945496,42282,0,14316.756876945496,0.2132000029087066,4.345427513122559,10000,14832.802520275116,0.3134167790412903,3.368189573287964,0.2895599901676178,3.5763092041015625,50000 -531.8394250869751,0.916710615158081,14826.869350671768,43796,0,14826.869350671768,0.1721000075340271,4.735334873199463,10000,15361.157025814056,0.2523716390132904,3.879667043685913,0.237879991531372,4.018867015838623,50000 -549.6033554077148,0.9554266929626464,15336.784964323044,45310,0,15336.784964323044,0.0870000049471855,6.223611831665039,10000,15888.927158594131,0.1259765625,5.498228073120117,0.1167399957776069,5.669344425201416,50000 -567.3659672737122,0.9953644275665284,15846.803076267242,46824,0,15846.803076267242,0.2154000103473663,4.327939033508301,10000,16416.799332618713,0.3132772445678711,3.4213531017303467,0.2969000041484833,3.581291675567627,50000 -584.9171187877655,1.0334186553955078,16356.954234361649,48338,0,16356.954234361649,0.2399000078439712,4.064082145690918,10000,16944.592866659164,0.3341836631298065,3.1928305625915527,0.3179799914360046,3.339296817779541,50000 -602.6156764030457,1.071347713470459,16867.11060857773,49852,0,16867.11060857773,0.1909000128507614,4.771968841552734,10000,17472.53868341446,0.2771245241165161,3.76228141784668,0.2566399872303009,3.951295852661133,50000 -620.5363309383392,1.1086018085479736,17377.086690425873,51366,0,17377.086690425873,0.0683000013232231,6.9438652992248535,10000,18000.52445960045,0.1058673411607742,6.146651268005371,0.093079999089241,6.373527526855469,50000 -638.2453672885895,1.144514799118042,17887.21930384636,52880,0,17887.21930384636,0.3041000068187713,3.449614286422729,10000,18528.455255270004,0.4368024468421936,2.536063432693481,0.4030399918556213,2.7369699478149414,50000 -655.88427901268,1.1814467906951904,18397.130579948425,54394,0,18397.130579948425,0.19930000603199,4.530226230621338,10000,19056.094877958298,0.2947225570678711,3.631112098693848,0.2807799875736236,3.7400660514831534,50000 -673.4222629070282,1.218961477279663,18907.1184194088,55908,0,18907.1184194088,0.1299000084400177,5.389272212982178,10000,19583.71437358856,0.1893734037876129,4.625549793243408,0.1741199940443039,4.754756927490234,50000 -690.9619073867798,1.2572824954986572,19417.15093255043,57422,0,19417.15093255043,0.1934000104665756,4.597098350524902,10000,20111.37770724297,0.2808115482330322,3.709074258804321,0.2637200057506561,3.829373598098755,50000 -708.6484682559967,1.293562412261963,19927.386483430862,58937,0,19927.386483430862,0.1551000028848648,5.24745512008667,10000,20639.389159202576,0.2262236922979354,4.323338985443115,0.2068399935960769,4.579726219177246,50000 -726.3159120082855,1.3322150707244873,20437.37043976784,60451,0,20437.37043976784,0.2665000259876251,3.855259656906128,10000,21167.13238477707,0.388093888759613,2.8707337379455566,0.3535799980163574,3.126149177551269,50000 -743.8167865276337,1.3707172870635986,20947.30847287178,61965,0,20947.30847287178,0.1634000092744827,4.972453594207764,10000,21694.66338658333,0.2292729616165161,4.253842830657959,0.220100000500679,4.375511646270752,50000 -761.1923484802246,1.4095752239227295,21457.2831799984,63479,0,21457.2831799984,0.1897000074386596,4.5615234375,10000,22222.106975317,0.2836615145206451,3.637618541717529,0.2704199850559234,3.731812715530396,50000 -778.6932566165924,1.4470484256744385,21967.26203918457,64993,0,21967.26203918457,0.1979000121355056,4.845719337463379,10000,22749.678109169006,0.2693319320678711,3.968483686447144,0.2618399858474731,4.07616662979126,50000 -796.4663238525391,1.4883835315704346,22477.223001003265,66507,0,22477.223001003265,0.2486000061035156,4.025635242462158,10000,23277.50735592842,0.3372528553009033,3.1919519901275635,0.3183600008487701,3.329773664474488,50000 -814.1657972335815,1.5314984321594238,22987.152262687683,68021,0,22987.152262687683,0.2199000120162964,4.267356395721436,10000,23805.23312044144,0.3380899131298065,3.256957054138184,0.293720006942749,3.590768575668335,50000 -831.8838548660278,1.574218988418579,23497.329399108887,69535,0,23497.329399108887,0.2547000050544739,4.050992488861084,10000,24333.222053050995,0.3651546537876129,3.035815715789795,0.3411799967288971,3.226549863815308,50000 -849.4705073833466,1.6176283359527588,24007.33136749268,71049,0,24007.33136749268,0.265500009059906,4.025924205780029,10000,24860.906602859497,0.3894889950752258,2.9351134300231934,0.3605599999427795,3.162463903427124,50000 -867.0384802818298,1.6555404663085938,24517.32631087303,72563,0,24517.32631087303,0.1791000068187713,4.710382461547852,10000,25388.561143636703,0.2634526491165161,3.807114124298096,0.2535600066184997,3.922708034515381,50000 -884.5662899017334,1.6954331398010254,25027.415630340576,74077,0,25027.415630340576,0.2621999979019165,4.026857376098633,10000,25916.27106308937,0.3816764950752258,2.9451732635498047,0.3625199794769287,3.1048009395599365,50000 -901.976181268692,1.7384414672851562,25537.43885302544,75591,0,25537.43885302544,0.0870000049471855,6.647572040557861,10000,26443.80044603348,0.1342075914144516,5.668425559997559,0.1243399977684021,5.764382839202881,50000 -919.41881108284,1.7832458019256592,26047.63839364052,77106,0,26047.63839364052,0.2205000072717666,4.244337558746338,10000,26971.54065322876,0.3316725194454193,3.21846866607666,0.2920999825000763,3.503586769104004,50000 -936.843403339386,1.825603008270264,26557.77646183968,78620,0,26557.77646183968,0.2669000029563904,4.146973609924316,10000,27499.20015025139,0.3699776828289032,3.103986978530884,0.3526600003242492,3.277158737182617,50000 -954.2851715087892,1.868180513381958,27067.766325950623,80134,0,27067.766325950623,0.2421000152826309,4.21008825302124,10000,28026.72852373123,0.3420161008834839,3.218864917755127,0.3225799798965454,3.411599636077881,50000 -971.8416512012482,1.9092822074890137,27577.878240585327,81648,0,27577.878240585327,0.281900018453598,3.800844669342041,10000,28554.49188780785,0.3948501050472259,2.805603265762329,0.3705599904060364,2.987741470336914,50000 -989.644821882248,1.954078197479248,28087.81955480576,83162,0,28087.81955480576,0.1921000033617019,5.06205940246582,10000,29082.3350391388,0.2846579849720001,3.92525863647461,0.2743200063705444,4.041182041168213,50000 -1007.0220937728882,1.9960856437683103,28598.049216508865,84677,0,28598.049216508865,0.331900030374527,3.2385549545288086,10000,29610.035831689835,0.4646643698215484,2.382607698440552,0.4369199872016907,2.5394155979156494,50000 -1024.5876967906952,2.03897762298584,29108.04040503502,86191,0,29108.04040503502,0.3320000171661377,3.385972738265991,10000,30137.687334775925,0.482122927904129,2.2993392944335938,0.4251599907875061,2.6702864170074463,50000 -1042.0650265216827,2.081550359725952,29617.99115371704,87705,0,29617.99115371704,0.2369000166654586,4.14555025100708,10000,30665.211062908173,0.3338049948215484,3.208051443099976,0.304419994354248,3.446050882339477,50000 -1059.4081492424011,2.1293301582336426,30127.92679142952,89219,0,30127.92679142952,0.2253000140190124,4.554732322692871,10000,31192.59133434296,0.3150709569454193,3.5599539279937744,0.2946600019931793,3.770634651184082,50000 -1076.9277966022491,2.174423694610596,30637.95221996308,90733,0,30637.95221996308,0.2732000052928924,3.744641542434693,10000,31720.2343685627,0.3970822691917419,2.781815528869629,0.3685599863529205,2.9695539474487305,50000 -1094.9929592609406,2.2205803394317627,31148.09440755844,92246,0,31148.09440755844,0.320000022649765,3.2951889038085938,10000,32248.54153752327,0.4543407261371612,2.392561912536621,0.4245999753475189,2.587425708770752,50000 -1112.41717171669,2.264895439147949,31658.02108550072,93760,0,31658.02108550072,0.3368000090122223,3.445615291595459,10000,32775.989436626434,0.458685427904129,2.434083938598633,0.4264200031757355,2.6672022342681885,50000 -1129.7756674289703,2.3125619888305664,32168.247616052628,95275,0,32168.247616052628,0.2595000267028808,4.024302959442139,10000,33303.67395091057,0.3795639276504516,2.966967821121216,0.340179979801178,3.3062493801116943,50000 -1147.1902074813845,2.362378597259521,32678.433045387268,96790,0,32678.433045387268,0.2849000096321106,3.783292055130005,10000,33831.3767824173,0.4096579849720001,2.806935787200928,0.3768999874591827,3.023160934448242,50000 -1164.826674938202,2.409458637237549,33188.55919909477,98304,0,33188.55919909477,0.2519000172615051,4.132585525512695,10000,34359.24005818367,0.351283460855484,3.222519874572754,0.3321599960327148,3.3702187538146973,50000 -1182.862447977066,2.4572508335113525,33698.464007377625,99818,0,33698.464007377625,0.2345000058412552,4.5583319664001465,10000,34887.28036427498,0.328125,3.5659189224243164,0.3105199933052063,3.713387966156006,50000 -1200.3929872512815,2.495596408843994,34208.510954380035,101332,0,34208.510954380035,0.3333000242710113,3.3893141746521,10000,35414.95019340515,0.4641461968421936,2.4268994331359863,0.4397799968719482,2.5999388694763184,50000 -1218.1181762218475,2.540222406387329,34718.491564273834,102846,0,34718.491564273834,0.3058000206947326,3.579463481903076,10000,35942.75317645073,0.4216557741165161,2.711665391921997,0.3911399841308594,2.9025368690490723,50000 -1235.8320398330688,2.584373712539673,35228.44272494316,104360,0,35228.44272494316,0.3788000047206878,3.055749654769897,10000,36470.514008522034,0.5540298223495483,1.874036312103272,0.4943999946117401,2.240631580352783,50000 -1253.241216421127,2.6315665245056152,35738.398156404495,105874,0,35738.398156404495,0.3613000214099884,3.1027300357818604,10000,36997.97931480408,0.5051219463348389,2.1285979747772217,0.4580999910831451,2.40424919128418,50000 -1271.1150813102722,2.683964490890503,36248.40759110451,107388,0,36248.40759110451,0.3841000199317932,3.025560140609741,10000,37525.970309495926,0.5316087007522583,2.0331740379333496,0.4949599802494049,2.2442522048950195,50000 -1288.5332021713257,2.757606744766236,36758.39783191681,108902,0,36758.39783191681,0.3831000328063965,3.019572496414185,10000,38053.5044836998,0.5332629084587097,2.0092248916625977,0.4962999820709228,2.2380595207214355,50000 -1306.213036775589,2.8046953678131104,37268.32622623444,110416,0,37268.32622623444,0.3677000105381012,3.080395698547364,10000,38581.21153998375,0.5140505433082581,2.1145224571228027,0.4885999858379364,2.261136531829834,50000 -1323.868717432022,2.851999282836914,37778.27315187454,111930,0,37778.27315187454,0.369700014591217,3.0804600715637207,10000,39108.91320705414,0.5302335619926453,2.0475571155548096,0.4895599782466888,2.2830638885498047,50000 -1341.6378679275513,2.903039693832397,38288.5008263588,113445,0,38288.5008263588,0.3817000091075897,2.9632489681243896,10000,39637.01444840431,0.5375478267669678,1.958598017692566,0.4891199767589569,2.234560489654541,50000 -1358.971853017807,2.952772617340088,38798.68224453926,114960,0,38798.68224453926,0.379800021648407,3.0490152835845947,10000,40164.63128519058,0.5321866869926453,1.9915037155151367,0.4983199834823608,2.20292329788208,50000 -1376.4977297782898,2.999505519866944,39308.85549354553,116475,0,39308.85549354553,0.4178000092506408,2.817715644836426,10000,40692.430067777634,0.5823102593421936,1.767098069190979,0.5386199951171875,2.0045840740203857,50000 -1394.1560413837433,3.048726320266724,39818.84736561775,117989,0,39818.84736561775,0.2914000153541565,3.691535711288452,10000,41220.18163561821,0.4283721148967743,2.6851370334625244,0.3919200003147125,2.899744987487793,50000 -1411.35599732399,3.1032016277313232,40328.843329668045,119503,0,40328.843329668045,0.3220000267028808,3.63352370262146,10000,41747.48498153687,0.447644293308258,2.602922201156616,0.4219799935817718,2.8036701679229736,50000 -1428.919724702835,3.1509153842926025,40838.74562501907,121017,0,40838.74562501907,0.3559000194072723,3.172493934631348,10000,42275.05197453499,0.5178372263908386,2.0767385959625244,0.4694999754428863,2.3726882934570312,50000 -1446.3481650352478,3.2026822566986084,41348.64927601814,122531,0,41348.64927601814,0.4456000328063965,2.6093904972076416,10000,42802.48861408234,0.6288264989852905,1.5047695636749268,0.5694999694824219,1.8335968255996704,50000 -1464.131145477295,3.255311965942383,41858.61288642883,124045,0,41858.61288642883,0.4452000260353088,2.5815134048461914,10000,43330.34150052071,0.6246811151504517,1.5214416980743408,0.5735200047492981,1.811548113822937,50000 -1481.754063129425,3.308755159378052,42368.84647965431,125560,0,42368.84647965431,0.4448000192642212,2.698650360107422,10000,43858.30543160439,0.6104312539100647,1.618659257888794,0.5592799782752991,1.915094375610352,50000 -1499.2517862319946,3.358278512954712,42878.92653274536,127074,0,42878.92653274536,0.4290000200271606,2.7601938247680664,10000,44385.98586678505,0.5913584232330322,1.713505506515503,0.5462999939918518,2.002792835235596,50000 -1516.6094889640808,3.411818504333496,43388.88260865212,128588,0,43388.88260865212,0.4603000283241272,2.5546529293060303,10000,44913.40585780144,0.6278898119926453,1.5220273733139038,0.5804799795150757,1.792826771736145,50000 -1534.450347661972,3.4636495113372803,43898.80344581604,130102,0,43898.80344581604,0.4164000153541565,2.8192737102508545,10000,45441.27255749703,0.6101123690605164,1.616065502166748,0.5356799960136414,2.0169262886047363,50000 -1552.0316081047058,3.529684543609619,44408.81854104996,131616,0,44408.81854104996,0.4419000148773193,2.6380434036254883,10000,45968.98812127113,0.6253786683082581,1.5129410028457642,0.5648399591445923,1.8565038442611688,50000 -1569.5413410663605,3.5810797214508057,44918.97057008743,133131,0,44918.97057008743,0.4800000190734863,2.4350662231445312,10000,46496.75433373451,0.6495934128761292,1.4116238355636597,0.5912799835205078,1.7342907190322876,50000 -1587.159719467163,3.6312103271484375,45428.92800879479,134645,0,45428.92800879479,0.4717000126838684,2.4368271827697754,10000,47024.43297600746,0.6475008130073547,1.4163345098495483,0.5905199646949768,1.7114334106445312,50000 -1604.752497673035,3.681864023208618,45938.94582152367,136159,0,45938.94582152367,0.4859000146389007,2.411069869995117,10000,47552.14718770981,0.6627072691917419,1.3503409624099731,0.6053799986839294,1.6583715677261353,50000 -1622.6106128692627,3.7330148220062256,46449.14306783676,137674,0,46449.14306783676,0.5019000172615051,2.331106662750244,10000,48080.309012174606,0.6785116195678711,1.28064227104187,0.6271799802780151,1.5513432025909424,50000 -1640.0028715133667,3.786928653717041,46959.10416579247,139188,0,46959.10416579247,0.467600017786026,2.5262668132781982,10000,48607.7670943737,0.667410671710968,1.3150213956832886,0.5842999815940857,1.7889710664749146,50000 -1657.6408331394196,3.842206716537476,47469.02995443344,140701,0,47469.02995443344,0.5172000527381897,2.2077748775482178,10000,49135.43949460983,0.7164978981018066,1.0897971391677856,0.6477800011634827,1.4555057287216189,50000 -1675.0142815113068,3.91502857208252,47979.00236058235,142214,0,47979.00236058235,0.5021000504493713,2.242482900619507,10000,49662.91038489342,0.7066326141357422,1.145608901977539,0.6356599926948547,1.4984140396118164,50000 -1692.9932827949524,3.9672343730926514,48489.20419001579,143729,0,48489.20419001579,0.508400022983551,2.2351131439208984,10000,50191.19554066658,0.6968669891357422,1.1911671161651611,0.6352399587631226,1.5196411609649658,50000 -1710.424084186554,4.019224882125855,48999.42097496986,145244,0,48999.42097496986,0.5085000395774841,2.3389248847961426,10000,50718.9482088089,0.6825175285339355,1.2460888624191284,0.616379976272583,1.6117098331451416,50000 -1728.0896308422089,4.0746119022369385,49509.32262468338,146758,0,49509.32262468338,0.536300003528595,2.1201376914978027,10000,51246.62455034256,0.7297114133834839,1.0454022884368896,0.6575199961662292,1.3876160383224487,50000 -1745.3887765407562,4.133601903915405,50019.25556206703,148272,0,50019.25556206703,0.5314000248908997,2.1541640758514404,10000,51773.96777963638,0.7495615482330322,0.9471167325973512,0.6625799536705017,1.3929927349090576,50000 -1762.67853140831,4.18830156326294,50529.4074280262,149787,0,50529.4074280262,0.534500002861023,2.082439422607422,10000,52301.51612615585,0.7493821382522583,0.9529452323913574,0.6685400009155273,1.3474432229995728,50000 -1780.5153470039368,4.245532751083374,51039.44907045365,151301,0,51039.44907045365,0.5250000357627869,2.173976182937622,10000,52829.5035905838,0.7281169891357422,1.0500671863555908,0.6520599722862244,1.4359551668167114,50000 -1797.9433093070984,4.3016557693481445,51549.42072844505,152815,0,51549.42072844505,0.5455000400543213,2.086134195327759,10000,53357.01241946221,0.7424864172935486,0.9894612431526184,0.6647999882698059,1.3785579204559326,50000 -1815.2350759506223,4.359313249588013,52059.3456428051,154329,0,52059.3456428051,0.5574000477790833,2.01480484008789,10000,53884.340841293335,0.7708266973495483,0.8738231658935547,0.6909799575805664,1.2597182989120483,50000 -1832.526507616043,4.413227081298828,52569.56672739983,155844,0,52569.56672739983,0.5478000044822693,2.047738790512085,10000,54411.96320319176,0.7562380433082581,0.9150132536888124,0.6816799640655518,1.297837734222412,50000 -1850.3109276294708,4.46701979637146,53079.48914647102,157358,0,53079.48914647102,0.5590000152587891,2.0179154872894287,10000,54939.77777934074,0.7893216013908386,0.7879802584648132,0.6877599954605103,1.278788447380066,50000 -1867.7333896160128,4.52330470085144,53589.66965150833,158873,0,53589.66965150833,0.5772000551223755,1.9140307903289795,10000,55467.48884010315,0.7989277839660645,0.7345190644264221,0.7048599720001221,1.2018253803253174,50000 -1885.144741296768,4.571201324462891,54099.629590034485,160387,0,54099.629590034485,0.5842000246047974,1.8834655284881592,10000,55994.96142053604,0.7996053695678711,0.7400349974632263,0.7064200043678284,1.1957767009735107,50000 -1902.832640647888,4.625036954879761,54609.69239163399,161901,0,54609.69239163399,0.5866000056266785,1.8886688947677608,10000,56522.81906700134,0.8086734414100647,0.7023515105247498,0.7143399715423584,1.1650846004486084,50000 -1920.478876829148,4.685438394546509,55119.59383225441,163415,0,55119.59383225441,0.5954000353813171,1.8337507247924805,10000,57050.48101377487,0.8216477632522583,0.6528841257095337,0.7217999696731567,1.1238961219787598,50000 -1937.9243867397308,4.741180896759033,55629.537470817566,164929,0,55629.537470817566,0.5958000421524048,1.849045395851136,10000,57577.97980308533,0.8226243257522583,0.6398840546607971,0.7210400104522705,1.1339006423950195,50000 -1955.0879509449005,4.795497179031372,56139.56280255318,166443,0,56139.56280255318,0.5837000012397766,1.897260069847107,10000,58105.27678442001,0.8295599222183228,0.6190734505653381,0.7127199769020081,1.169293999671936,50000 -1972.470759391785,4.853919506072998,56649.58001804352,167957,0,56649.58001804352,0.6022000312805176,1.799255132675171,10000,58632.78707766533,0.8482740521430969,0.5581567883491516,0.7278199791908264,1.1075727939605713,50000 -1990.327763795853,4.91126561164856,57159.64442586899,169471,0,57159.64442586899,0.6065000295639038,1.802993655204773,10000,59160.81916928291,0.8517617583274841,0.5366014242172241,0.7337799668312073,1.0894720554351809,50000 -2008.0955998897552,4.968322038650513,57669.60288262367,170985,0,57669.60288262367,0.6118000149726868,1.7798755168914795,10000,59688.65509557724,0.8520607352256775,0.5264204144477844,0.7373600006103516,1.0713787078857422,50000 -2025.8408544063568,5.029789209365845,58179.54089784622,172499,0,58179.54089784622,0.6099000573158264,1.7793102264404297,10000,60216.452474832535,0.8572624325752258,0.5081093311309814,0.7387199997901917,1.0654412508010864,50000 -2043.2813086509705,5.085204124450684,58689.59416794777,174013,0,58689.59416794777,0.6139000058174133,1.7508876323699951,10000,60744.05333399773,0.8667888641357422,0.4688793122768402,0.7406599521636963,1.0524687767028809,50000 -2061.011162519455,5.145601987838745,59199.493911504745,175527,0,59199.493911504745,0.6184000372886658,1.7471065521240234,10000,61271.79503440857,0.8767737150192261,0.4423072636127472,0.7445200085639954,1.0392271280288696,50000 -2078.988668680191,5.2041075229644775,59709.4871468544,177041,0,59709.4871468544,0.616100013256073,1.7452729940414429,10000,61799.87656879425,0.8763552308082581,0.4342447221279144,0.7455599904060364,1.0426863431930542,50000 -2096.6750016212463,5.264262676239014,60219.56933808327,178556,0,60219.56933808327,0.6222000122070312,1.731208324432373,10000,62327.75828695297,0.8790258169174194,0.4302779138088226,0.7473999857902527,1.0301430225372314,50000 -2114.281795501709,5.3266565799713135,60729.72563242912,180071,0,60729.72563242912,0.6249000430107117,1.7289284467697144,10000,62855.63610816002,0.878348171710968,0.4257912039756775,0.7497199773788452,1.0273454189300537,50000 -2131.816429138184,5.389420032501221,61239.78214740753,181585,0,61239.78214740753,0.6249000430107117,1.7295113801956177,10000,63383.342509269714,0.8835299611091614,0.412623792886734,0.750220000743866,1.02522611618042,50000 -2149.390313625336,5.450626850128174,61749.71807217598,183099,0,61749.71807217598,0.6252000331878662,1.7207680940628052,10000,63910.96601963043,0.8856026530265808,0.4069898426532745,0.7508999705314636,1.0227185487747192,50000 -2167.0347170829773,5.507994651794434,62259.6433134079,184613,0,62259.6433134079,0.625,1.724463701248169,10000,64438.64605140686,0.8846459984779358,0.4113345444202423,0.7509999871253967,1.0213669538497925,50000 -2184.528416633606,5.566797494888306,62769.76277685165,186127,0,62769.76277685165,0.6243000030517578,1.7228504419326782,10000,64966.37164711952,0.8859614133834839,0.4078052341938019,0.7506799697875977,1.0217421054840088,50000 -2202.380264520645,5.6263251304626465,62951.208920001984,186666,0,62951.208920001984,0.6237000226974487,1.722780704498291,10000,65165.74819970131,0.8869180083274841,0.40388861298561096,0.7511199712753296,1.0213629007339478,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/measurements.csv deleted file mode 100644 index 0f505a19f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1994 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.65951836,6.916249,,,,,,,,,,,,,, -1,,,0.0007772640092298,6.911434173583984,0.0006799999973736,6.912051200866699,50000.0,0.0013000001199543,6.9117279052734375,10000.0,33.78456735610962,51.374412059783936,33.78456735610962,17.589736700057983,0.0,0.0 -100,0.7728392,6.6574306,,,,,,,,,,,,,, -200,0.9639778,6.270466,,,,,,,,,,,,,, -300,4.097598,5.9491305,,,,,,,,,,,,,, -400,3.0694587,5.611302,,,,,,,,,,,,,, -500,5.1520123,5.5068665,,,,,,,,,,,,,, -600,3.9026535,5.286533,,,,,,,,,,,,,, -700,3.4540303,5.2254705,,,,,,,,,,,,,, -800,5.173514,5.1939564,,,,,,,,,,,,,, -900,4.6619444,4.8933096,,,,,,,,,,,,,, -1000,4.6439457,4.715914,,,,,,,,,,,,,, -1100,4.9989495,4.6600795,,,,,,,,,,,,,, -1200,4.2266774,4.3852096,,,,,,,,,,,,,, -1300,3.266809,4.387208,,,,,,,,,,,,,, -1400,3.483506,4.1373997,,,,,,,,,,,,,, -1500,,,0.2221181392669677,3.8356406688690186,0.2069999873638153,3.960334539413452,50000.0,0.1568000018596649,4.447728633880615,10000.0,543.8763875961304,579.2266387939453,543.8763875961304,35.278390645980835,0.01981782913208,0.0 -1500,2.456178,4.1563473,,,,,,,,,,,,,, -1600,2.6248307,3.9620183,,,,,,,,,,,,,, -1700,1.8491353,3.9108884,,,,,,,,,,,,,, -1800,2.7104266,3.8140626,,,,,,,,,,,,,, -1900,2.365437,3.865006,,,,,,,,,,,,,, -2000,1.6789199,3.613669,,,,,,,,,,,,,, -2100,2.5820014,3.3880267,,,,,,,,,,,,,, -2200,2.9339688,3.4941092,,,,,,,,,,,,,, -2300,1.60665,3.417815,,,,,,,,,,,,,, -2400,1.7239844,3.364913,,,,,,,,,,,,,, -2500,2.1490011,3.2485292,,,,,,,,,,,,,, -2600,2.3161435,3.2078762,,,,,,,,,,,,,, -2700,1.2875129,3.3777041,,,,,,,,,,,,,, -2800,1.2037095,3.140713,,,,,,,,,,,,,, -2900,1.1408294,3.08986,,,,,,,,,,,,,, -3000,1.3762243,3.138618,,,,,,,,,,,,,, -3001,,,0.3391063511371612,3.0905704498291016,0.3193399906158447,3.2351529598236084,50000.0,0.2457000166177749,3.8936009407043457,10000.0,1054.0921666622162,1107.459528684616,1054.0921666622162,53.212013721466064,0.0511689186096191,0.0 -3100,1.1789217,3.0518715,,,,,,,,,,,,,, -3200,1.1696014,3.0287242,,,,,,,,,,,,,, -3300,1.1884649,3.2471995,,,,,,,,,,,,,, -3400,1.1070592,3.1234522,,,,,,,,,,,,,, -3500,0.96103746,3.122366,,,,,,,,,,,,,, -3600,1.2151154,3.0240934,,,,,,,,,,,,,, -3700,0.91467285,2.8584194,,,,,,,,,,,,,, -3800,0.93897307,2.957863,,,,,,,,,,,,,, -3900,0.95030224,2.883951,,,,,,,,,,,,,, -4000,0.9553036,2.950785,,,,,,,,,,,,,, -4100,0.99095416,2.8193822,,,,,,,,,,,,,, -4200,1.0772164,2.740959,,,,,,,,,,,,,, -4300,1.0360683,2.879907,,,,,,,,,,,,,, -4400,0.9693872,2.893528,,,,,,,,,,,,,, -4500,1.0536494,2.9502661,,,,,,,,,,,,,, -4504,,,0.3458625674247741,3.1005280017852783,0.3310000002384186,3.232429265975952,50000.0,0.2432000041007995,4.029863357543945,10000.0,1564.1626377105713,1635.3940970897677,1564.1626377105713,70.99523687362671,0.0811347961425781,0.0 -4600,0.85145074,2.7478068,,,,,,,,,,,,,, -4700,1.1343073,2.7844887,,,,,,,,,,,,,, -4800,0.89515275,2.9111905,,,,,,,,,,,,,, -4900,1.2163771,2.750597,,,,,,,,,,,,,, -5000,0.83723587,2.8880014,,,,,,,,,,,,,, -5100,1.0370644,2.7484503,,,,,,,,,,,,,, -5200,0.95404965,2.831106,,,,,,,,,,,,,, -5300,0.8218291,2.6214666,,,,,,,,,,,,,, -5400,0.87717736,2.6413548,,,,,,,,,,,,,, -5500,0.9930184,2.610421,,,,,,,,,,,,,, -5600,1.0381688,2.6160302,,,,,,,,,,,,,, -5700,0.9840512,2.5745316,,,,,,,,,,,,,, -5800,0.93285733,2.545204,,,,,,,,,,,,,, -5900,0.88262576,2.7341294,,,,,,,,,,,,,, -6000,1.232755,2.545416,,,,,,,,,,,,,, -6008,,,0.3782485425472259,2.895301818847656,0.3356399834156036,3.18208122253418,50000.0,0.2597000002861023,3.796529531478882,10000.0,2074.3404870033264,2163.361767053604,2074.3404870033264,88.70134115219116,0.1119053363800048,0.0 -6100,1.0347154,2.6073363,,,,,,,,,,,,,, -6200,0.88097227,2.610224,,,,,,,,,,,,,, -6300,1.0261714,2.65306,,,,,,,,,,,,,, -6400,1.0912153,2.5933535,,,,,,,,,,,,,, -6500,0.9150018,2.6049557,,,,,,,,,,,,,, -6600,1.0242797,2.452241,,,,,,,,,,,,,, -6700,1.1925135,2.5709844,,,,,,,,,,,,,, -6800,0.90135926,2.5328531,,,,,,,,,,,,,, -6900,0.97081983,2.5434117,,,,,,,,,,,,,, -7000,1.0646713,2.6335156,,,,,,,,,,,,,, -7100,0.9274951,2.5366309,,,,,,,,,,,,,, -7200,0.90614694,2.4653764,,,,,,,,,,,,,, -7300,0.92034024,2.5138743,,,,,,,,,,,,,, -7400,1.0210297,2.3879704,,,,,,,,,,,,,, -7500,0.9652711,2.6446805,,,,,,,,,,,,,, -7513,,,0.2871890962123871,3.638346910476685,0.2617000043392181,3.8255012035369873,50000.0,0.1861000061035156,4.576222896575928,10000.0,2584.437113523484,2691.138946533203,2584.437113523484,106.30011343955994,0.1416637897491455,0.0 -7600,0.9698968,2.4095082,,,,,,,,,,,,,, -7700,0.9485243,2.5173812,,,,,,,,,,,,,, -7800,1.1211898,2.4439209,,,,,,,,,,,,,, -7900,0.9868668,2.5762496,,,,,,,,,,,,,, -8000,0.96296394,2.4660692,,,,,,,,,,,,,, -8100,0.94543236,2.523344,,,,,,,,,,,,,, -8200,1.1646193,2.4181979,,,,,,,,,,,,,, -8300,0.90070105,2.4328485,,,,,,,,,,,,,, -8400,0.9961277,2.551776,,,,,,,,,,,,,, -8500,0.9186828,2.630436,,,,,,,,,,,,,, -8600,1.0590107,2.509601,,,,,,,,,,,,,, -8700,0.95449984,2.461715,,,,,,,,,,,,,, -8800,1.0856212,2.5577257,,,,,,,,,,,,,, -8900,1.080554,2.3320653,,,,,,,,,,,,,, -9000,0.956435,2.6404936,,,,,,,,,,,,,, -9020,,,0.3981983363628387,2.802988052368164,0.3658399879932403,3.020524501800537,50000.0,0.2747000157833099,3.7215497493743896,10000.0,3094.679023504257,3219.2161860466003,3094.679023504257,124.0553548336029,0.1691560745239257,0.0 -9100,0.92093253,2.4154701,,,,,,,,,,,,,, -9200,0.9360784,2.5615482,,,,,,,,,,,,,, -9300,1.0616139,2.4612176,,,,,,,,,,,,,, -9400,0.99554116,2.3766286,,,,,,,,,,,,,, -9500,0.98467124,2.4627323,,,,,,,,,,,,,, -9600,1.0158783,2.3818226,,,,,,,,,,,,,, -9700,1.0798844,2.4643273,,,,,,,,,,,,,, -9800,0.9742587,2.4170089,,,,,,,,,,,,,, -9900,0.995228,2.4237204,,,,,,,,,,,,,, -10000,0.91524035,2.447634,,,,,,,,,,,,,, -10100,1.111251,2.6036735,,,,,,,,,,,,,, -10200,0.91011417,2.3792887,,,,,,,,,,,,,, -10300,1.057201,2.4976118,,,,,,,,,,,,,, -10400,0.9937259,2.4734676,,,,,,,,,,,,,, -10500,1.0990841,2.48347,,,,,,,,,,,,,, -10527,,,0.3869379758834839,2.846888542175293,0.3642399907112121,3.010152101516724,50000.0,0.2690000236034393,3.798298120498657,10000.0,3604.599910259247,3747.6188309192657,3604.599910259247,142.4556610584259,0.1979579925537109,0.0 -10600,1.0046823,2.327423,,,,,,,,,,,,,, -10700,0.9867146,2.386248,,,,,,,,,,,,,, -10800,1.0308454,2.3938031,,,,,,,,,,,,,, -10900,0.98191535,2.595851,,,,,,,,,,,,,, -11000,1.0549929,2.5412133,,,,,,,,,,,,,, -11100,0.9619333,2.2966821,,,,,,,,,,,,,, -11200,0.92965585,2.3443363,,,,,,,,,,,,,, -11300,0.98814094,2.367242,,,,,,,,,,,,,, -11400,1.0451972,2.3891807,,,,,,,,,,,,,, -11500,1.0638901,2.493826,,,,,,,,,,,,,, -11600,0.98820275,2.5291204,,,,,,,,,,,,,, -11700,1.0123264,2.377146,,,,,,,,,,,,,, -11800,1.0110762,2.5462604,,,,,,,,,,,,,, -11900,0.8832754,2.4609938,,,,,,,,,,,,,, -12000,1.0685382,2.4670115,,,,,,,,,,,,,, -12035,,,0.302156388759613,3.548335313796997,0.2874999940395355,3.705745458602905,50000.0,0.2058000117540359,4.551807880401611,10000.0,4114.688268184662,4275.333572149277,4114.688268184662,160.00274682044983,0.2236948013305664,0.0 -12100,1.0601903,2.3486238,,,,,,,,,,,,,, -12200,0.9191566,2.4315875,,,,,,,,,,,,,, -12300,1.1257428,2.4686842,,,,,,,,,,,,,, -12400,0.9727146,2.371749,,,,,,,,,,,,,, -12500,1.1626514,2.4539096,,,,,,,,,,,,,, -12600,1.060438,2.6079755,,,,,,,,,,,,,, -12700,0.9175157,2.359859,,,,,,,,,,,,,, -12800,1.0163589,2.3872738,,,,,,,,,,,,,, -12900,0.86893827,2.3563924,,,,,,,,,,,,,, -13000,1.0163597,2.4150767,,,,,,,,,,,,,, -13100,0.9633173,2.4211836,,,,,,,,,,,,,, -13200,0.92741865,2.3631256,,,,,,,,,,,,,, -13300,1.021104,2.4865792,,,,,,,,,,,,,, -13400,0.9371632,2.4133408,,,,,,,,,,,,,, -13500,1.099271,2.4570992,,,,,,,,,,,,,, -13543,,,0.2102798074483871,4.674249649047852,0.1959799975156784,4.838823795318604,50000.0,0.1577000021934509,5.386991500854492,10000.0,4624.764115333557,4803.130608320236,4624.764115333557,177.63978910446167,0.2553791999816894,0.0 -13600,0.98047704,2.3034189,,,,,,,,,,,,,, -13700,1.024843,2.2748456,,,,,,,,,,,,,, -13800,0.92884475,2.3915138,,,,,,,,,,,,,, -13900,1.0239531,2.3539329,,,,,,,,,,,,,, -14000,0.9920064,2.3166294,,,,,,,,,,,,,, -14100,1.020118,2.43301,,,,,,,,,,,,,, -14200,1.345252,2.377965,,,,,,,,,,,,,, -14300,1.0255854,2.4723537,,,,,,,,,,,,,, -14400,0.9959925,2.3681695,,,,,,,,,,,,,, -14500,1.0465633,2.385435,,,,,,,,,,,,,, -14600,1.1592416,2.3371005,,,,,,,,,,,,,, -14700,1.0450859,2.2923975,,,,,,,,,,,,,, -14800,1.0563132,2.4596043,,,,,,,,,,,,,, -14900,0.99441946,2.3691437,,,,,,,,,,,,,, -15000,1.0725455,2.3685668,,,,,,,,,,,,,, -15052,,,0.2354113459587097,4.341064929962158,0.2106999903917312,4.6160736083984375,50000.0,0.1561000049114227,5.416417598724365,10000.0,5134.7330057621,5331.500044822693,5134.7330057621,195.9572970867157,0.2865011692047119,0.0 -15100,0.9424485,2.519311,,,,,,,,,,,,,, -15200,0.95751655,2.469267,,,,,,,,,,,,,, -15300,1.0443485,2.4427302,,,,,,,,,,,,,, -15400,0.9739824,2.4537077,,,,,,,,,,,,,, -15500,1.0596116,2.4566483,,,,,,,,,,,,,, -15600,1.118674,2.435189,,,,,,,,,,,,,, -15700,0.9680992,2.2629879,,,,,,,,,,,,,, -15800,1.0673224,2.3609493,,,,,,,,,,,,,, -15900,1.0071281,2.4001021,,,,,,,,,,,,,, -16000,0.9518868,2.3897943,,,,,,,,,,,,,, -16100,1.0450004,2.3518612,,,,,,,,,,,,,, -16200,0.9687661,2.4279633,,,,,,,,,,,,,, -16300,1.2408341,2.3334296,,,,,,,,,,,,,, -16400,1.0638642,2.5513368,,,,,,,,,,,,,, -16500,0.9485368,2.3650901,,,,,,,,,,,,,, -16562,,,0.2307876199483871,4.050477504730225,0.2169799953699112,4.200684070587158,50000.0,0.1513000130653381,4.930394649505615,10000.0,5644.699779033661,5859.146992444992,5644.699779033661,213.5535204410553,0.3168478012084961,0.0 -16600,1.1325016,2.5022523,,,,,,,,,,,,,, -16700,0.95569336,2.255217,,,,,,,,,,,,,, -16800,1.0292634,2.2312095,,,,,,,,,,,,,, -16900,1.0298088,2.420283,,,,,,,,,,,,,, -17000,1.2443007,2.3539054,,,,,,,,,,,,,, -17100,1.0414429,2.2813218,,,,,,,,,,,,,, -17200,0.9844167,2.3918772,,,,,,,,,,,,,, -17300,1.1215005,2.388493,,,,,,,,,,,,,, -17400,1.1286757,2.3645086,,,,,,,,,,,,,, -17500,0.9816696,2.402774,,,,,,,,,,,,,, -17600,1.0260639,2.4159207,,,,,,,,,,,,,, -17700,1.2580549,2.3592546,,,,,,,,,,,,,, -17800,1.1112701,2.5416849,,,,,,,,,,,,,, -17900,1.0067971,2.286498,,,,,,,,,,,,,, -18000,1.0394522,2.446114,,,,,,,,,,,,,, -18073,,,0.3052654564380646,3.361721992492676,0.2855399847030639,3.5417702198028564,50000.0,0.2099000066518783,4.276243209838867,10000.0,6154.821240901947,6387.498101949692,6154.821240901947,231.6983737945557,0.3491621017456054,0.0 -18100,0.9813001,2.3139024,,,,,,,,,,,,,, -18200,1.0975819,2.4411147,,,,,,,,,,,,,, -18300,1.1598215,2.463871,,,,,,,,,,,,,, -18400,1.0226197,2.2822993,,,,,,,,,,,,,, -18500,1.100963,2.380516,,,,,,,,,,,,,, -18600,0.9893774,2.319114,,,,,,,,,,,,,, -18700,0.9938384,2.3512025,,,,,,,,,,,,,, -18800,1.0817897,2.367224,,,,,,,,,,,,,, -18900,1.1257659,2.3290274,,,,,,,,,,,,,, -19000,1.0315208,2.3072495,,,,,,,,,,,,,, -19100,1.2046396,2.329269,,,,,,,,,,,,,, -19200,1.2508307,2.4608278,,,,,,,,,,,,,, -19300,1.2050682,2.401599,,,,,,,,,,,,,, -19400,1.074214,2.3350868,,,,,,,,,,,,,, -19500,0.9962083,2.2141073,,,,,,,,,,,,,, -19584,,,0.2442004084587097,3.913287878036499,0.230119988322258,4.02189826965332,50000.0,0.1702000051736831,4.693854331970215,10000.0,6664.954342365265,6915.529188632965,6664.954342365265,249.51444363594047,0.3746833801269531,0.0 -19600,1.0650318,2.3213742,,,,,,,,,,,,,, -19700,1.003315,2.3039184,,,,,,,,,,,,,, -19800,1.1572368,2.36124,,,,,,,,,,,,,, -19900,1.093947,2.5463018,,,,,,,,,,,,,, -20000,1.1674484,2.2774997,,,,,,,,,,,,,, -20100,0.9990647,2.2979815,,,,,,,,,,,,,, -20200,0.9949296,2.3841667,,,,,,,,,,,,,, -20300,0.96187663,2.3688922,,,,,,,,,,,,,, -20400,1.0154812,2.4733825,,,,,,,,,,,,,, -20500,1.104438,2.4285567,,,,,,,,,,,,,, -20600,1.1097035,2.4683995,,,,,,,,,,,,,, -20700,0.9605069,2.3030102,,,,,,,,,,,,,, -20800,1.0420744,2.3550677,,,,,,,,,,,,,, -20900,1.0471733,2.359397,,,,,,,,,,,,,, -21000,0.96367806,2.364928,,,,,,,,,,,,,, -21096,,,0.242267221212387,3.892278432846069,0.2260199934244156,4.00605058670044,50000.0,0.1699000149965286,4.684741973876953,10000.0,7175.1495196819305,7443.557811498642,7175.1495196819305,267.2611298561096,0.4081254005432129,0.0 -21100,1.0275878,2.3693004,,,,,,,,,,,,,, -21200,1.1381854,2.3532808,,,,,,,,,,,,,, -21300,1.1550744,2.2992814,,,,,,,,,,,,,, -21400,0.98160994,2.3911138,,,,,,,,,,,,,, -21500,1.1599206,2.4002142,,,,,,,,,,,,,, -21600,1.093686,2.3795204,,,,,,,,,,,,,, -21700,1.0780344,2.2363846,,,,,,,,,,,,,, -21800,1.1150916,2.2880492,,,,,,,,,,,,,, -21900,1.1764863,2.2751296,,,,,,,,,,,,,, -22000,1.1890153,2.3116567,,,,,,,,,,,,,, -22100,1.0813745,2.426185,,,,,,,,,,,,,, -22200,1.2870286,2.3371577,,,,,,,,,,,,,, -22300,0.9883559,2.252191,,,,,,,,,,,,,, -22400,1.136135,2.380939,,,,,,,,,,,,,, -22500,1.1454005,2.3209069,,,,,,,,,,,,,, -22600,1.0572033,2.3408372,,,,,,,,,,,,,, -22609,,,0.2122528702020645,4.563592433929443,0.2013600021600723,4.659169673919678,50000.0,0.1548000127077102,5.273361682891846,10000.0,7685.252676725388,7971.450614213943,7685.252676725388,284.9658041000366,0.439424991607666,0.0 -22700,1.1044136,2.336973,,,,,,,,,,,,,, -22800,1.0085125,2.360465,,,,,,,,,,,,,, -22900,1.0676048,2.2959042,,,,,,,,,,,,,, -23000,1.2011665,2.2809677,,,,,,,,,,,,,, -23100,1.0213382,2.3286242,,,,,,,,,,,,,, -23200,1.0473623,2.3202596,,,,,,,,,,,,,, -23300,1.0102903,2.4605346,,,,,,,,,,,,,, -23400,1.1466417,2.3705242,,,,,,,,,,,,,, -23500,1.1979718,2.466289,,,,,,,,,,,,,, -23600,1.0486604,2.3328419,,,,,,,,,,,,,, -23700,1.1867604,2.374302,,,,,,,,,,,,,, -23800,1.1610421,2.4321494,,,,,,,,,,,,,, -23900,1.1056864,2.3509982,,,,,,,,,,,,,, -24000,1.1574485,2.375829,,,,,,,,,,,,,, -24100,1.0772243,2.3076124,,,,,,,,,,,,,, -24121,,,0.2415298074483871,4.111358165740967,0.2184399962425232,4.350162506103516,50000.0,0.1574000120162964,5.048336029052734,10000.0,8195.383947610855,8499.119955062866,8195.383947610855,302.4214491844177,0.4703423976898193,0.0 -24200,1.1648579,2.3460877,,,,,,,,,,,,,, -24300,1.1346929,2.2914114,,,,,,,,,,,,,, -24400,1.0062653,2.3796709,,,,,,,,,,,,,, -24500,1.072738,2.3258114,,,,,,,,,,,,,, -24600,1.1247334,2.34713,,,,,,,,,,,,,, -24700,1.0818347,2.2335958,,,,,,,,,,,,,, -24800,1.0516248,2.3196986,,,,,,,,,,,,,, -24900,1.0873038,2.3462963,,,,,,,,,,,,,, -25000,1.0645747,2.2624295,,,,,,,,,,,,,, -25100,1.0707465,2.2560782,,,,,,,,,,,,,, -25200,1.130396,2.303745,,,,,,,,,,,,,, -25300,1.3076025,2.3278284,,,,,,,,,,,,,, -25400,1.0048507,2.2841256,,,,,,,,,,,,,, -25500,1.1630758,2.2659388,,,,,,,,,,,,,, -25600,1.1152242,2.267364,,,,,,,,,,,,,, -25634,,,0.2726801633834839,3.768687009811402,0.2542800009250641,3.919223308563232,50000.0,0.1791000068187713,4.636255741119385,10000.0,8705.445008277893,9027.102459907532,8705.445008277893,320.2576837539673,0.502802848815918,0.0 -25700,1.1057074,2.3675656,,,,,,,,,,,,,, -25800,1.2199332,2.3916192,,,,,,,,,,,,,, -25900,1.0857099,2.4346464,,,,,,,,,,,,,, -26000,1.0738765,2.316956,,,,,,,,,,,,,, -26100,0.9705473,2.2840436,,,,,,,,,,,,,, -26200,1.0278786,2.3584647,,,,,,,,,,,,,, -26300,1.1217387,2.267205,,,,,,,,,,,,,, -26400,1.0352938,2.2177544,,,,,,,,,,,,,, -26500,1.1161565,2.258375,,,,,,,,,,,,,, -26600,0.99057513,2.2204523,,,,,,,,,,,,,, -26700,1.1803569,2.3396366,,,,,,,,,,,,,, -26800,1.067646,2.4289248,,,,,,,,,,,,,, -26900,1.1539186,2.353163,,,,,,,,,,,,,, -27000,1.1132377,2.294073,,,,,,,,,,,,,, -27100,1.1007648,2.2347815,,,,,,,,,,,,,, -27146,,,0.2140664756298065,4.525050640106201,0.2014999985694885,4.717775344848633,50000.0,0.1522000133991241,5.444266319274902,10000.0,9215.368250131609,9554.7049472332,9215.368250131609,337.85802817344666,0.5298073291778564,0.0 -27200,0.97782326,2.1920104,,,,,,,,,,,,,, -27300,0.937758,2.1197147,,,,,,,,,,,,,, -27400,1.1264703,2.4024339,,,,,,,,,,,,,, -27500,0.9992931,2.2197447,,,,,,,,,,,,,, -27600,1.0022532,2.2675889,,,,,,,,,,,,,, -27700,1.10997,2.2758746,,,,,,,,,,,,,, -27800,1.25115,2.3807793,,,,,,,,,,,,,, -27900,1.0868726,2.248118,,,,,,,,,,,,,, -28000,1.1235538,2.3222895,,,,,,,,,,,,,, -28100,1.1279935,2.340491,,,,,,,,,,,,,, -28200,1.0815952,2.2741845,,,,,,,,,,,,,, -28300,1.0378317,2.3529413,,,,,,,,,,,,,, -28400,1.068742,2.3173366,,,,,,,,,,,,,, -28500,1.2002747,2.251476,,,,,,,,,,,,,, -28600,1.2022028,2.302367,,,,,,,,,,,,,, -28659,,,0.1899513602256775,4.6598968505859375,0.1812800019979477,4.777307510375977,50000.0,0.1374000012874603,5.371987819671631,10000.0,9725.415110111237,10082.664724826813,9725.415110111237,355.6845571994781,0.5628311634063721,0.0 -28700,1.103845,2.3209867,,,,,,,,,,,,,, -28800,1.0811981,2.362648,,,,,,,,,,,,,, -28900,1.0064145,2.2161384,,,,,,,,,,,,,, -29000,1.125113,2.288208,,,,,,,,,,,,,, -29100,1.1232724,2.3475223,,,,,,,,,,,,,, -29200,1.2064928,2.3796816,,,,,,,,,,,,,, -29300,1.1565448,2.5029805,,,,,,,,,,,,,, -29400,1.1416476,2.3676,,,,,,,,,,,,,, -29500,1.0588696,2.2720265,,,,,,,,,,,,,, -29600,1.1245964,2.4292512,,,,,,,,,,,,,, -29700,1.0778719,2.2442322,,,,,,,,,,,,,, -29800,1.3042147,2.3650675,,,,,,,,,,,,,, -29900,1.0985601,2.2958002,,,,,,,,,,,,,, -30000,1.2009592,2.3765738,,,,,,,,,,,,,, -30100,0.9835156,2.274851,,,,,,,,,,,,,, -30172,,,0.1981624662876129,4.559650897979736,0.1924600005149841,4.613406181335449,50000.0,0.1344000101089477,5.338951110839844,10000.0,10235.613961458206,10610.321287631989,10235.613961458206,373.0587408542633,0.594775915145874,0.0 -30200,1.1469787,2.342173,,,,,,,,,,,,,, -30300,1.0435753,2.3082297,,,,,,,,,,,,,, -30400,1.1230443,2.3496737,,,,,,,,,,,,,, -30500,1.0962055,2.3891613,,,,,,,,,,,,,, -30600,1.200143,2.3011205,,,,,,,,,,,,,, -30700,1.070116,2.2338443,,,,,,,,,,,,,, -30800,1.0583935,2.3693702,,,,,,,,,,,,,, -30900,1.0646985,2.4257436,,,,,,,,,,,,,, -31000,1.1190671,2.2275038,,,,,,,,,,,,,, -31100,1.0646604,2.2944324,,,,,,,,,,,,,, -31200,1.0816989,2.2566216,,,,,,,,,,,,,, -31300,1.0837171,2.3457584,,,,,,,,,,,,,, -31400,1.1014111,2.2920876,,,,,,,,,,,,,, -31500,1.1353257,2.2585046,,,,,,,,,,,,,, -31600,1.0707762,2.25113,,,,,,,,,,,,,, -31686,,,0.3178212642669678,3.394036054611206,0.3035599887371063,3.541114330291748,50000.0,0.228300005197525,4.332971572875977,10000.0,10745.81271648407,11138.345761299131,10745.81271648407,390.7978858947754,0.6274514198303223,0.0 -31700,0.9856458,2.207655,,,,,,,,,,,,,, -31800,1.1490744,2.3211493,,,,,,,,,,,,,, -31900,1.0941694,2.2272074,,,,,,,,,,,,,, -32000,1.0730531,2.3591385,,,,,,,,,,,,,, -32100,1.067879,2.2136304,,,,,,,,,,,,,, -32200,1.154081,2.2282627,,,,,,,,,,,,,, -32300,1.1535565,2.1524224,,,,,,,,,,,,,, -32400,1.1655209,2.302691,,,,,,,,,,,,,, -32500,1.1667265,2.3641925,,,,,,,,,,,,,, -32600,1.0147038,2.3402784,,,,,,,,,,,,,, -32700,1.0747877,2.23188,,,,,,,,,,,,,, -32800,1.2167555,2.4032617,,,,,,,,,,,,,, -32900,1.0516013,2.3023705,,,,,,,,,,,,,, -33000,1.0781804,2.2372892,,,,,,,,,,,,,, -33100,1.1735224,2.165281,,,,,,,,,,,,,, -33200,,,0.1719148606061935,5.012552261352539,0.1557399928569793,5.170931339263916,50000.0,0.1111000031232833,5.935586452484131,10000.0,11255.900620937347,11665.95344877243,11255.900620937347,408.22711849212646,0.6634583473205566,0.0 -33200,1.1876693,2.168255,,,,,,,,,,,,,, -33300,1.1295205,2.2728388,,,,,,,,,,,,,, -33400,1.095036,2.3398554,,,,,,,,,,,,,, -33500,1.1524401,2.3025928,,,,,,,,,,,,,, -33600,1.1130202,2.2785134,,,,,,,,,,,,,, -33700,1.0399889,2.2511525,,,,,,,,,,,,,, -33800,1.1488023,2.4767358,,,,,,,,,,,,,, -33900,1.2149128,2.2282913,,,,,,,,,,,,,, -34000,1.2285622,2.3598268,,,,,,,,,,,,,, -34100,1.3289814,2.31724,,,,,,,,,,,,,, -34200,1.0703323,2.2515974,,,,,,,,,,,,,, -34300,1.0922608,2.334629,,,,,,,,,,,,,, -34400,1.123366,2.3045797,,,,,,,,,,,,,, -34500,1.1299134,2.2737637,,,,,,,,,,,,,, -34600,1.0361683,2.2032037,,,,,,,,,,,,,, -34700,1.1715169,2.2941704,,,,,,,,,,,,,, -34714,,,0.3223851919174194,3.268257141113281,0.3014200031757355,3.499361515045166,50000.0,0.2261000126600265,4.192788124084473,10000.0,11766.076352834702,12193.907732248306,11766.076352834702,425.9146020412445,0.7024168968200684,0.0 -34800,1.2389377,2.2657547,,,,,,,,,,,,,, -34900,1.2075444,2.2925565,,,,,,,,,,,,,, -35000,1.2078097,2.3442528,,,,,,,,,,,,,, -35100,1.133942,2.012296,,,,,,,,,,,,,, -35200,1.2791274,2.2806246,,,,,,,,,,,,,, -35300,1.0558753,2.2716742,,,,,,,,,,,,,, -35400,1.0901389,2.1635911,,,,,,,,,,,,,, -35500,1.0537763,2.3075724,,,,,,,,,,,,,, -35600,1.072615,2.2482576,,,,,,,,,,,,,, -35700,1.0993388,2.236747,,,,,,,,,,,,,, -35800,1.1721569,2.2823143,,,,,,,,,,,,,, -35900,1.2511071,2.2992902,,,,,,,,,,,,,, -36000,1.069123,2.219465,,,,,,,,,,,,,, -36100,1.1254148,2.3224938,,,,,,,,,,,,,, -36200,1.1003782,2.2112043,,,,,,,,,,,,,, -36228,,,0.2633928656578064,3.819126605987549,0.2487399876117706,3.92941689491272,50000.0,0.1833000034093856,4.704871654510498,10000.0,12276.239780902864,12721.67400622368,12276.239780902864,443.42841243743896,0.7372820377349854,0.0 -36300,1.1179578,2.3035333,,,,,,,,,,,,,, -36400,1.0889158,2.237514,,,,,,,,,,,,,, -36500,1.1472334,2.190631,,,,,,,,,,,,,, -36600,1.3216649,2.2589254,,,,,,,,,,,,,, -36700,1.019277,2.1993477,,,,,,,,,,,,,, -36800,1.115187,2.3500547,,,,,,,,,,,,,, -36900,1.1088648,2.285032,,,,,,,,,,,,,, -37000,1.1323612,2.4171572,,,,,,,,,,,,,, -37100,1.0964054,2.3962047,,,,,,,,,,,,,, -37200,1.1225587,2.2972205,,,,,,,,,,,,,, -37300,1.1145657,2.3453705,,,,,,,,,,,,,, -37400,1.3450546,2.3675745,,,,,,,,,,,,,, -37500,1.111445,2.3409522,,,,,,,,,,,,,, -37600,1.0182239,2.2230206,,,,,,,,,,,,,, -37700,1.2579571,2.2387304,,,,,,,,,,,,,, -37742,,,0.3377511203289032,3.210239887237549,0.3193999826908111,3.303650140762329,50000.0,0.2308000177145004,4.0718674659729,10000.0,12786.358525514604,13249.55362534523,12786.358525514604,461.1026515960693,0.7711968421936035,0.0 -37800,1.2145718,2.2149653,,,,,,,,,,,,,, -37900,1.2598721,2.2120242,,,,,,,,,,,,,, -38000,1.0841508,2.2313087,,,,,,,,,,,,,, -38100,1.1255763,2.2340477,,,,,,,,,,,,,, -38200,1.1469951,2.1936116,,,,,,,,,,,,,, -38300,1.1524552,2.1124268,,,,,,,,,,,,,, -38400,1.1640941,2.2935088,,,,,,,,,,,,,, -38500,1.116126,2.2049901,,,,,,,,,,,,,, -38600,1.1309494,2.26349,,,,,,,,,,,,,, -38700,1.1209184,2.2122574,,,,,,,,,,,,,, -38800,1.1414548,2.3589582,,,,,,,,,,,,,, -38900,1.03896,2.1958785,,,,,,,,,,,,,, -39000,1.1865501,2.3442333,,,,,,,,,,,,,, -39100,1.2474709,2.3157945,,,,,,,,,,,,,, -39200,1.1165444,2.2412121,,,,,,,,,,,,,, -39256,,,0.1889150142669677,4.7008209228515625,0.1738400012254715,4.816714763641357,50000.0,0.1162000074982643,5.706114768981934,10000.0,13296.550779342651,13777.302340745926,13296.550779342651,478.5720095634461,0.8064799308776855,0.0 -39300,1.1246595,2.3137712,,,,,,,,,,,,,, -39400,1.0563282,2.1303532,,,,,,,,,,,,,, -39500,1.4132928,2.3432996,,,,,,,,,,,,,, -39600,1.236006,2.1776721,,,,,,,,,,,,,, -39700,1.1966473,2.246893,,,,,,,,,,,,,, -39800,1.064759,2.2178946,,,,,,,,,,,,,, -39900,1.1261389,2.1743083,,,,,,,,,,,,,, -40000,1.0600312,2.1437507,,,,,,,,,,,,,, -40100,1.267958,2.2205386,,,,,,,,,,,,,, -40200,1.1347489,2.2452579,,,,,,,,,,,,,, -40300,1.1091435,2.2553515,,,,,,,,,,,,,, -40400,1.3288182,2.3660102,,,,,,,,,,,,,, -40500,1.1567248,2.2071888,,,,,,,,,,,,,, -40600,1.167929,2.366427,,,,,,,,,,,,,, -40700,1.0862759,2.3121095,,,,,,,,,,,,,, -40770,,,0.0982541441917419,6.791675090789795,0.092299997806549,6.967357158660889,50000.0,0.064300000667572,7.548539638519287,10000.0,13806.753144979475,14304.94440293312,13806.753144979475,495.9221394062042,0.8445472717285156,0.0 -40800,1.1750488,2.3764548,,,,,,,,,,,,,, -40900,1.2447118,2.352669,,,,,,,,,,,,,, -41000,1.1413194,2.1635702,,,,,,,,,,,,,, -41100,1.128414,2.3548179,,,,,,,,,,,,,, -41200,1.2543864,2.362692,,,,,,,,,,,,,, -41300,1.0767994,2.2757356,,,,,,,,,,,,,, -41400,1.316634,2.4402316,,,,,,,,,,,,,, -41500,1.209431,2.1877685,,,,,,,,,,,,,, -41600,1.1015856,2.3149564,,,,,,,,,,,,,, -41700,1.1102978,2.2882252,,,,,,,,,,,,,, -41800,1.0744822,2.2506115,,,,,,,,,,,,,, -41900,1.1966687,2.415341,,,,,,,,,,,,,, -42000,1.0879009,2.3095105,,,,,,,,,,,,,, -42100,1.2798162,2.32451,,,,,,,,,,,,,, -42200,1.22953,2.4258907,,,,,,,,,,,,,, -42282,,,0.3134167790412903,3.368189573287964,0.2895599901676178,3.5763092041015625,50000.0,0.2132000029087066,4.345427513122559,10000.0,14316.756876945496,14832.802520275116,14316.756876945496,513.688027381897,0.8791120052337646,0.0 -42300,1.1667558,2.2457314,,,,,,,,,,,,,, -42400,1.1800143,2.359816,,,,,,,,,,,,,, -42500,1.1508856,2.2113261,,,,,,,,,,,,,, -42600,1.1422846,2.2461982,,,,,,,,,,,,,, -42700,1.2128242,2.2760742,,,,,,,,,,,,,, -42800,1.1439296,2.2586632,,,,,,,,,,,,,, -42900,1.1835282,2.2585917,,,,,,,,,,,,,, -43000,1.0976604,2.1814625,,,,,,,,,,,,,, -43100,1.2257136,2.2438323,,,,,,,,,,,,,, -43200,1.1432779,2.1170382,,,,,,,,,,,,,, -43300,1.0743879,2.1658814,,,,,,,,,,,,,, -43400,1.1766887,2.3744693,,,,,,,,,,,,,, -43500,1.1236318,2.2979975,,,,,,,,,,,,,, -43600,1.162423,2.1512492,,,,,,,,,,,,,, -43700,1.1464022,2.2648947,,,,,,,,,,,,,, -43796,,,0.2523716390132904,3.879667043685913,0.237879991531372,4.018867015838623,50000.0,0.1721000075340271,4.735334873199463,10000.0,14826.869350671768,15361.157025814056,14826.869350671768,531.8394250869751,0.916710615158081,0.0 -43800,1.1569287,2.3215983,,,,,,,,,,,,,, -43900,1.264621,2.3960786,,,,,,,,,,,,,, -44000,1.1775596,2.3643749,,,,,,,,,,,,,, -44100,1.1541388,2.3387115,,,,,,,,,,,,,, -44200,1.1181439,2.2902029,,,,,,,,,,,,,, -44300,1.1535043,2.2840106,,,,,,,,,,,,,, -44400,1.3620772,2.3565197,,,,,,,,,,,,,, -44500,1.0955849,2.213028,,,,,,,,,,,,,, -44600,1.3635706,2.2687805,,,,,,,,,,,,,, -44700,1.2990338,2.1863787,,,,,,,,,,,,,, -44800,1.5911804,2.315155,,,,,,,,,,,,,, -44900,1.1512913,2.2528791,,,,,,,,,,,,,, -45000,1.1927491,2.2241635,,,,,,,,,,,,,, -45100,1.210661,2.3368416,,,,,,,,,,,,,, -45200,1.3630208,2.306807,,,,,,,,,,,,,, -45300,1.3552077,2.262048,,,,,,,,,,,,,, -45310,,,0.1259765625,5.498228073120117,0.1167399957776069,5.669344425201416,50000.0,0.0870000049471855,6.223611831665039,10000.0,15336.784964323044,15888.927158594131,15336.784964323044,549.6033554077148,0.9554266929626464,0.0 -45400,1.1614666,2.3676858,,,,,,,,,,,,,, -45500,1.200435,2.2727013,,,,,,,,,,,,,, -45600,1.1415652,2.2589192,,,,,,,,,,,,,, -45700,1.1008214,2.2122865,,,,,,,,,,,,,, -45800,1.124559,2.248756,,,,,,,,,,,,,, -45900,1.1507068,2.1731088,,,,,,,,,,,,,, -46000,1.2574033,2.2757013,,,,,,,,,,,,,, -46100,1.1091896,2.2398632,,,,,,,,,,,,,, -46200,1.1439874,2.5118332,,,,,,,,,,,,,, -46300,1.2979918,2.306268,,,,,,,,,,,,,, -46400,1.2292764,2.2964582,,,,,,,,,,,,,, -46500,1.3665025,2.3303277,,,,,,,,,,,,,, -46600,1.216346,2.3121393,,,,,,,,,,,,,, -46700,1.2500788,2.3242772,,,,,,,,,,,,,, -46800,1.2665656,2.2045772,,,,,,,,,,,,,, -46824,,,0.3132772445678711,3.4213531017303467,0.2969000041484833,3.581291675567627,50000.0,0.2154000103473663,4.327939033508301,10000.0,15846.803076267242,16416.799332618713,15846.803076267242,567.3659672737122,0.9953644275665284,0.0 -46900,1.140518,2.206488,,,,,,,,,,,,,, -47000,1.1721532,2.285774,,,,,,,,,,,,,, -47100,1.3717616,2.315093,,,,,,,,,,,,,, -47200,1.2511265,2.2370858,,,,,,,,,,,,,, -47300,1.2443974,2.2576606,,,,,,,,,,,,,, -47400,1.1905137,2.2996132,,,,,,,,,,,,,, -47500,1.1580131,2.1735618,,,,,,,,,,,,,, -47600,1.3122914,2.299661,,,,,,,,,,,,,, -47700,1.2280288,2.3087373,,,,,,,,,,,,,, -47800,1.2440302,2.3534787,,,,,,,,,,,,,, -47900,1.250937,2.290491,,,,,,,,,,,,,, -48000,1.0969503,2.233786,,,,,,,,,,,,,, -48100,1.2745465,2.3813534,,,,,,,,,,,,,, -48200,1.2325646,2.1080475,,,,,,,,,,,,,, -48300,1.257329,2.271262,,,,,,,,,,,,,, -48338,,,0.3341836631298065,3.1928305625915527,0.3179799914360046,3.339296817779541,50000.0,0.2399000078439712,4.064082145690918,10000.0,16356.954234361649,16944.592866659164,16356.954234361649,584.9171187877655,1.0334186553955078,0.0 -48400,1.1799651,2.2607968,,,,,,,,,,,,,, -48500,1.1597292,2.0130553,,,,,,,,,,,,,, -48600,1.2249328,2.2842572,,,,,,,,,,,,,, -48700,1.1603365,2.1112173,,,,,,,,,,,,,, -48800,1.1977514,2.3013554,,,,,,,,,,,,,, -48900,1.316633,2.2335567,,,,,,,,,,,,,, -49000,1.2218243,2.3335187,,,,,,,,,,,,,, -49100,1.206085,2.4105988,,,,,,,,,,,,,, -49200,1.3179072,2.1110425,,,,,,,,,,,,,, -49300,1.2050921,2.2604706,,,,,,,,,,,,,, -49400,1.1982555,2.1615362,,,,,,,,,,,,,, -49500,1.1068805,2.3974142,,,,,,,,,,,,,, -49600,1.190547,2.2244577,,,,,,,,,,,,,, -49700,1.1738436,2.270943,,,,,,,,,,,,,, -49800,1.1866692,2.1438801,,,,,,,,,,,,,, -49852,,,0.2771245241165161,3.76228141784668,0.2566399872303009,3.951295852661133,50000.0,0.1909000128507614,4.771968841552734,10000.0,16867.11060857773,17472.53868341446,16867.11060857773,602.6156764030457,1.071347713470459,0.0 -49900,1.3797152,2.2996469,,,,,,,,,,,,,, -50000,1.2377634,2.20423,,,,,,,,,,,,,, -50100,1.1856276,2.2385893,,,,,,,,,,,,,, -50200,1.2125096,2.3258376,,,,,,,,,,,,,, -50300,1.08017,2.2226007,,,,,,,,,,,,,, -50400,1.296157,2.2097075,,,,,,,,,,,,,, -50500,1.2280899,2.2685192,,,,,,,,,,,,,, -50600,1.1728249,2.1025984,,,,,,,,,,,,,, -50700,1.3237396,2.239637,,,,,,,,,,,,,, -50800,1.1739078,2.216575,,,,,,,,,,,,,, -50900,1.1921116,2.088761,,,,,,,,,,,,,, -51000,1.2109782,2.244447,,,,,,,,,,,,,, -51100,1.2201576,2.1826172,,,,,,,,,,,,,, -51200,1.2632955,2.2358222,,,,,,,,,,,,,, -51300,1.1330638,2.3550518,,,,,,,,,,,,,, -51366,,,0.1058673411607742,6.146651268005371,0.093079999089241,6.373527526855469,50000.0,0.0683000013232231,6.9438652992248535,10000.0,17377.086690425873,18000.52445960045,17377.086690425873,620.5363309383392,1.1086018085479736,0.0 -51400,1.1399019,2.2559273,,,,,,,,,,,,,, -51500,1.2297796,2.388003,,,,,,,,,,,,,, -51600,1.2829788,2.2304184,,,,,,,,,,,,,, -51700,1.1781777,2.205522,,,,,,,,,,,,,, -51800,1.3737543,2.2059486,,,,,,,,,,,,,, -51900,1.2030586,2.2166178,,,,,,,,,,,,,, -52000,1.3766677,2.2054133,,,,,,,,,,,,,, -52100,1.248078,2.255341,,,,,,,,,,,,,, -52200,1.1888124,2.118781,,,,,,,,,,,,,, -52300,1.1683899,2.1176236,,,,,,,,,,,,,, -52400,1.3316635,2.21713,,,,,,,,,,,,,, -52500,1.3386775,2.25419,,,,,,,,,,,,,, -52600,1.5019848,2.2228065,,,,,,,,,,,,,, -52700,1.321657,2.323178,,,,,,,,,,,,,, -52800,1.0608077,2.1624732,,,,,,,,,,,,,, -52880,,,0.4368024468421936,2.536063432693481,0.4030399918556213,2.7369699478149414,50000.0,0.3041000068187713,3.449614286422729,10000.0,17887.21930384636,18528.455255270004,17887.21930384636,638.2453672885895,1.144514799118042,0.0 -52900,1.1767313,2.1006196,,,,,,,,,,,,,, -53000,1.2135054,2.1939595,,,,,,,,,,,,,, -53100,1.2071842,2.1960008,,,,,,,,,,,,,, -53200,1.3547229,2.2917624,,,,,,,,,,,,,, -53300,1.2631602,2.3364851,,,,,,,,,,,,,, -53400,1.2879051,2.1985288,,,,,,,,,,,,,, -53500,1.1767204,2.2536495,,,,,,,,,,,,,, -53600,1.3166856,2.30677,,,,,,,,,,,,,, -53700,1.2340901,2.3343673,,,,,,,,,,,,,, -53800,1.1520158,2.3226795,,,,,,,,,,,,,, -53900,1.1568131,2.2087553,,,,,,,,,,,,,, -54000,1.26177,2.2120852,,,,,,,,,,,,,, -54100,1.1588694,2.213154,,,,,,,,,,,,,, -54200,1.103595,2.188462,,,,,,,,,,,,,, -54300,1.1783333,2.221387,,,,,,,,,,,,,, -54394,,,0.2947225570678711,3.631112098693848,0.2807799875736236,3.7400660514831534,50000.0,0.19930000603199,4.530226230621338,10000.0,18397.130579948425,19056.094877958298,18397.130579948425,655.88427901268,1.1814467906951904,0.0 -54400,1.3057133,2.1920214,,,,,,,,,,,,,, -54500,1.381769,2.2148821,,,,,,,,,,,,,, -54600,1.2427475,2.1353073,,,,,,,,,,,,,, -54700,1.3482722,2.2143264,,,,,,,,,,,,,, -54800,1.341963,2.023216,,,,,,,,,,,,,, -54900,1.2806349,2.219344,,,,,,,,,,,,,, -55000,1.2959654,2.1004622,,,,,,,,,,,,,, -55100,1.1622199,2.196525,,,,,,,,,,,,,, -55200,1.4199102,2.176832,,,,,,,,,,,,,, -55300,1.2137122,2.284781,,,,,,,,,,,,,, -55400,1.2340248,2.166021,,,,,,,,,,,,,, -55500,1.2686402,2.0458987,,,,,,,,,,,,,, -55600,1.238629,2.2106917,,,,,,,,,,,,,, -55700,1.1930895,2.1060672,,,,,,,,,,,,,, -55800,1.2502046,2.193635,,,,,,,,,,,,,, -55900,1.4165231,2.2295904,,,,,,,,,,,,,, -55908,,,0.1893734037876129,4.625549793243408,0.1741199940443039,4.754756927490234,50000.0,0.1299000084400177,5.389272212982178,10000.0,18907.1184194088,19583.71437358856,18907.1184194088,673.4222629070282,1.218961477279663,0.0 -56000,1.0892092,2.183472,,,,,,,,,,,,,, -56100,1.4644251,2.3215108,,,,,,,,,,,,,, -56200,1.2051244,2.2306077,,,,,,,,,,,,,, -56300,1.232005,2.2186387,,,,,,,,,,,,,, -56400,1.2194176,2.1134202,,,,,,,,,,,,,, -56500,1.2030079,2.2718403,,,,,,,,,,,,,, -56600,1.2854512,2.1619277,,,,,,,,,,,,,, -56700,1.3931315,2.3314264,,,,,,,,,,,,,, -56800,1.3683523,2.1263335,,,,,,,,,,,,,, -56900,1.2692988,2.1753685,,,,,,,,,,,,,, -57000,1.2803444,2.2539349,,,,,,,,,,,,,, -57100,1.2977084,2.2252355,,,,,,,,,,,,,, -57200,1.5251548,2.2399452,,,,,,,,,,,,,, -57300,1.3310452,2.1151056,,,,,,,,,,,,,, -57400,1.2670612,2.245761,,,,,,,,,,,,,, -57422,,,0.2808115482330322,3.709074258804321,0.2637200057506561,3.829373598098755,50000.0,0.1934000104665756,4.597098350524902,10000.0,19417.15093255043,20111.37770724297,19417.15093255043,690.9619073867798,1.2572824954986572,0.0 -57500,1.3065569,2.1896796,,,,,,,,,,,,,, -57600,1.456563,2.2303777,,,,,,,,,,,,,, -57700,1.4430957,2.1604745,,,,,,,,,,,,,, -57800,1.1782796,2.224524,,,,,,,,,,,,,, -57900,1.1924481,2.058941,,,,,,,,,,,,,, -58000,1.1454432,2.259175,,,,,,,,,,,,,, -58100,1.2858337,2.28892,,,,,,,,,,,,,, -58200,1.2529517,2.3028917,,,,,,,,,,,,,, -58300,1.3129194,2.1256633,,,,,,,,,,,,,, -58400,1.1179829,2.2151694,,,,,,,,,,,,,, -58500,1.213912,2.232587,,,,,,,,,,,,,, -58600,1.2841822,2.1269279,,,,,,,,,,,,,, -58700,1.1173663,2.1955118,,,,,,,,,,,,,, -58800,1.257676,2.154149,,,,,,,,,,,,,, -58900,1.2066857,2.2733154,,,,,,,,,,,,,, -58937,,,0.2262236922979354,4.323338985443115,0.2068399935960769,4.579726219177246,50000.0,0.1551000028848648,5.24745512008667,10000.0,19927.386483430862,20639.389159202576,19927.386483430862,708.6484682559967,1.293562412261963,0.0 -59000,1.2292678,2.1812196,,,,,,,,,,,,,, -59100,1.4686711,2.2357306,,,,,,,,,,,,,, -59200,1.2986819,2.1709688,,,,,,,,,,,,,, -59300,1.2197171,2.109929,,,,,,,,,,,,,, -59400,1.2807152,2.1976938,,,,,,,,,,,,,, -59500,1.3305799,2.1726136,,,,,,,,,,,,,, -59600,1.3238409,2.312096,,,,,,,,,,,,,, -59700,1.2209407,2.121656,,,,,,,,,,,,,, -59800,1.3241882,2.1426206,,,,,,,,,,,,,, -59900,1.1670988,2.14026,,,,,,,,,,,,,, -60000,1.4105798,2.1720433,,,,,,,,,,,,,, -60100,1.3086932,2.1597285,,,,,,,,,,,,,, -60200,1.2140248,2.225431,,,,,,,,,,,,,, -60300,1.1846304,2.2485924,,,,,,,,,,,,,, -60400,1.2299348,2.2491236,,,,,,,,,,,,,, -60451,,,0.388093888759613,2.8707337379455566,0.3535799980163574,3.126149177551269,50000.0,0.2665000259876251,3.855259656906128,10000.0,20437.37043976784,21167.13238477707,20437.37043976784,726.3159120082855,1.3322150707244873,0.0 -60500,1.3066353,2.0782688,,,,,,,,,,,,,, -60600,1.2887052,2.3158605,,,,,,,,,,,,,, -60700,1.3384222,2.1659975,,,,,,,,,,,,,, -60800,1.3244321,2.172333,,,,,,,,,,,,,, -60900,1.3505087,2.3919592,,,,,,,,,,,,,, -61000,1.2528713,2.3570282,,,,,,,,,,,,,, -61100,1.2787162,2.2950346,,,,,,,,,,,,,, -61200,1.2932931,2.2117877,,,,,,,,,,,,,, -61300,1.407428,2.3045835,,,,,,,,,,,,,, -61400,1.2040551,2.109173,,,,,,,,,,,,,, -61500,1.2820108,2.168743,,,,,,,,,,,,,, -61600,1.3397135,2.1493692,,,,,,,,,,,,,, -61700,1.3805358,2.1063223,,,,,,,,,,,,,, -61800,1.4688022,2.1816657,,,,,,,,,,,,,, -61900,1.3740543,2.1412945,,,,,,,,,,,,,, -61965,,,0.2292729616165161,4.253842830657959,0.220100000500679,4.375511646270752,50000.0,0.1634000092744827,4.972453594207764,10000.0,20947.30847287178,21694.66338658333,20947.30847287178,743.8167865276337,1.3707172870635986,0.0 -62000,1.271578,2.1927433,,,,,,,,,,,,,, -62100,1.4393641,2.1755772,,,,,,,,,,,,,, -62200,1.4473135,2.1670976,,,,,,,,,,,,,, -62300,1.280692,2.2080095,,,,,,,,,,,,,, -62400,1.2905883,2.1979003,,,,,,,,,,,,,, -62500,1.3033446,2.2455819,,,,,,,,,,,,,, -62600,1.4037167,2.1348116,,,,,,,,,,,,,, -62700,1.3366277,2.216934,,,,,,,,,,,,,, -62800,1.3281568,2.222657,,,,,,,,,,,,,, -62900,1.3985583,2.2615392,,,,,,,,,,,,,, -63000,1.3224692,2.0885527,,,,,,,,,,,,,, -63100,1.352049,2.3120632,,,,,,,,,,,,,, -63200,1.2594845,2.2103138,,,,,,,,,,,,,, -63300,1.2696345,2.100699,,,,,,,,,,,,,, -63400,1.3700321,2.1025553,,,,,,,,,,,,,, -63479,,,0.2836615145206451,3.637618541717529,0.2704199850559234,3.731812715530396,50000.0,0.1897000074386596,4.5615234375,10000.0,21457.2831799984,22222.106975317,21457.2831799984,761.1923484802246,1.4095752239227295,0.0 -63500,1.3273386,2.2740424,,,,,,,,,,,,,, -63600,1.3338205,2.278207,,,,,,,,,,,,,, -63700,1.2575959,2.2589266,,,,,,,,,,,,,, -63800,1.3864138,2.150831,,,,,,,,,,,,,, -63900,1.3409345,2.231776,,,,,,,,,,,,,, -64000,1.5281879,2.0924788,,,,,,,,,,,,,, -64100,1.3145461,2.170275,,,,,,,,,,,,,, -64200,1.2543502,1.9529468,,,,,,,,,,,,,, -64300,1.3907372,2.1854382,,,,,,,,,,,,,, -64400,1.166299,2.0732934,,,,,,,,,,,,,, -64500,1.4150505,2.1955936,,,,,,,,,,,,,, -64600,1.2904804,2.1455622,,,,,,,,,,,,,, -64700,1.3722006,2.1475575,,,,,,,,,,,,,, -64800,1.3483088,2.135486,,,,,,,,,,,,,, -64900,1.2734365,2.142445,,,,,,,,,,,,,, -64993,,,0.2693319320678711,3.968483686447144,0.2618399858474731,4.07616662979126,50000.0,0.1979000121355056,4.845719337463379,10000.0,21967.26203918457,22749.678109169006,21967.26203918457,778.6932566165924,1.4470484256744385,0.0 -65000,1.315363,2.1270103,,,,,,,,,,,,,, -65100,1.3535955,2.226813,,,,,,,,,,,,,, -65200,1.2679131,2.1046934,,,,,,,,,,,,,, -65300,1.3672836,2.3707836,,,,,,,,,,,,,, -65400,1.3390151,2.185392,,,,,,,,,,,,,, -65500,1.3826288,2.2549202,,,,,,,,,,,,,, -65600,1.2827617,2.1081793,,,,,,,,,,,,,, -65700,1.4395405,2.1157079,,,,,,,,,,,,,, -65800,1.5145491,2.0730367,,,,,,,,,,,,,, -65900,1.3838043,2.2013772,,,,,,,,,,,,,, -66000,1.2130936,2.182009,,,,,,,,,,,,,, -66100,1.3304449,2.1184916,,,,,,,,,,,,,, -66200,1.3076501,2.2713463,,,,,,,,,,,,,, -66300,1.3602952,2.2213478,,,,,,,,,,,,,, -66400,1.3376756,2.1553712,,,,,,,,,,,,,, -66500,1.3292062,2.1679502,,,,,,,,,,,,,, -66507,,,0.3372528553009033,3.1919519901275635,0.3183600008487701,3.329773664474488,50000.0,0.2486000061035156,4.025635242462158,10000.0,22477.223001003265,23277.50735592842,22477.223001003265,796.4663238525391,1.4883835315704346,0.0 -66600,1.5561128,2.1860106,,,,,,,,,,,,,, -66700,1.4204084,2.234151,,,,,,,,,,,,,, -66800,1.2149112,2.1381454,,,,,,,,,,,,,, -66900,1.3428466,2.2171254,,,,,,,,,,,,,, -67000,1.4502985,2.1986823,,,,,,,,,,,,,, -67100,1.3313731,2.2324483,,,,,,,,,,,,,, -67200,1.3026209,2.1060636,,,,,,,,,,,,,, -67300,1.4329386,2.0657372,,,,,,,,,,,,,, -67400,1.2925352,2.132687,,,,,,,,,,,,,, -67500,1.5646538,2.1570754,,,,,,,,,,,,,, -67600,1.3783661,2.1871302,,,,,,,,,,,,,, -67700,1.2128394,2.0851183,,,,,,,,,,,,,, -67800,1.3845247,2.186733,,,,,,,,,,,,,, -67900,1.2958393,2.1349502,,,,,,,,,,,,,, -68000,1.3961139,2.2472022,,,,,,,,,,,,,, -68021,,,0.3380899131298065,3.256957054138184,0.293720006942749,3.590768575668335,50000.0,0.2199000120162964,4.267356395721436,10000.0,22987.152262687683,23805.23312044144,22987.152262687683,814.1657972335815,1.5314984321594238,0.0 -68100,1.3835759,2.1959724,,,,,,,,,,,,,, -68200,1.3570093,2.1557245,,,,,,,,,,,,,, -68300,1.2646989,2.092242,,,,,,,,,,,,,, -68400,1.4473761,2.124281,,,,,,,,,,,,,, -68500,1.390787,2.1866915,,,,,,,,,,,,,, -68600,1.4783969,2.05311,,,,,,,,,,,,,, -68700,1.4343283,2.1604643,,,,,,,,,,,,,, -68800,1.5111128,2.1839445,,,,,,,,,,,,,, -68900,1.2751092,2.2638662,,,,,,,,,,,,,, -69000,1.2884033,2.059105,,,,,,,,,,,,,, -69100,1.3872423,2.2320719,,,,,,,,,,,,,, -69200,1.3536619,2.0592518,,,,,,,,,,,,,, -69300,1.338085,2.0256438,,,,,,,,,,,,,, -69400,1.4054544,2.0841372,,,,,,,,,,,,,, -69500,1.2677902,2.1967363,,,,,,,,,,,,,, -69535,,,0.3651546537876129,3.035815715789795,0.3411799967288971,3.226549863815308,50000.0,0.2547000050544739,4.050992488861084,10000.0,23497.329399108887,24333.222053050995,23497.329399108887,831.8838548660278,1.574218988418579,0.0 -69600,1.5121027,2.035021,,,,,,,,,,,,,, -69700,1.349375,2.0721965,,,,,,,,,,,,,, -69800,1.5020525,2.124176,,,,,,,,,,,,,, -69900,1.485408,2.1287785,,,,,,,,,,,,,, -70000,1.3251692,2.1602848,,,,,,,,,,,,,, -70100,1.3106736,2.0915537,,,,,,,,,,,,,, -70200,1.2087969,2.0493293,,,,,,,,,,,,,, -70300,1.3500859,2.072785,,,,,,,,,,,,,, -70400,1.4417714,2.1421702,,,,,,,,,,,,,, -70500,1.3808004,2.1836786,,,,,,,,,,,,,, -70600,1.3789306,2.257018,,,,,,,,,,,,,, -70700,1.2736351,1.967896,,,,,,,,,,,,,, -70800,1.3646331,2.1299505,,,,,,,,,,,,,, -70900,1.5002242,2.2059216,,,,,,,,,,,,,, -71000,1.3098744,2.2269604,,,,,,,,,,,,,, -71049,,,0.3894889950752258,2.9351134300231934,0.3605599999427795,3.162463903427124,50000.0,0.265500009059906,4.025924205780029,10000.0,24007.33136749268,24860.906602859497,24007.33136749268,849.4705073833466,1.6176283359527588,0.0 -71100,1.4892246,2.340547,,,,,,,,,,,,,, -71200,1.3280964,2.22955,,,,,,,,,,,,,, -71300,1.3742777,2.076394,,,,,,,,,,,,,, -71400,1.4350814,2.041234,,,,,,,,,,,,,, -71500,1.2474782,2.1134233,,,,,,,,,,,,,, -71600,1.2529641,2.1391954,,,,,,,,,,,,,, -71700,1.3760586,2.0745046,,,,,,,,,,,,,, -71800,1.8515534,2.253567,,,,,,,,,,,,,, -71900,1.2455019,2.0895517,,,,,,,,,,,,,, -72000,1.346197,2.1824768,,,,,,,,,,,,,, -72100,1.5863371,2.1754787,,,,,,,,,,,,,, -72200,1.4058928,2.110669,,,,,,,,,,,,,, -72300,1.3625801,2.2076259,,,,,,,,,,,,,, -72400,1.4461162,2.1306353,,,,,,,,,,,,,, -72500,1.332,2.2487442,,,,,,,,,,,,,, -72563,,,0.2634526491165161,3.807114124298096,0.2535600066184997,3.922708034515381,50000.0,0.1791000068187713,4.710382461547852,10000.0,24517.32631087303,25388.561143636703,24517.32631087303,867.0384802818298,1.6555404663085938,0.0 -72600,1.3396928,2.2503936,,,,,,,,,,,,,, -72700,1.4934974,2.1207047,,,,,,,,,,,,,, -72800,1.3952783,2.153139,,,,,,,,,,,,,, -72900,1.3806905,2.0323403,,,,,,,,,,,,,, -73000,1.3041407,2.1279695,,,,,,,,,,,,,, -73100,1.3847873,2.1080065,,,,,,,,,,,,,, -73200,1.4217414,2.1074371,,,,,,,,,,,,,, -73300,1.4805183,2.0634995,,,,,,,,,,,,,, -73400,1.4552382,2.1509614,,,,,,,,,,,,,, -73500,1.4815242,2.140298,,,,,,,,,,,,,, -73600,1.5550234,2.1289887,,,,,,,,,,,,,, -73700,1.4387397,2.2087655,,,,,,,,,,,,,, -73800,1.2898628,2.185346,,,,,,,,,,,,,, -73900,1.4855943,2.1659813,,,,,,,,,,,,,, -74000,1.4794576,2.101286,,,,,,,,,,,,,, -74077,,,0.3816764950752258,2.9451732635498047,0.3625199794769287,3.1048009395599365,50000.0,0.2621999979019165,4.026857376098633,10000.0,25027.415630340576,25916.27106308937,25027.415630340576,884.5662899017334,1.6954331398010254,0.0 -74100,1.3817053,2.095843,,,,,,,,,,,,,, -74200,1.7166796,2.0576305,,,,,,,,,,,,,, -74300,1.4947195,2.217872,,,,,,,,,,,,,, -74400,1.4456735,2.185248,,,,,,,,,,,,,, -74500,1.2945436,2.1186762,,,,,,,,,,,,,, -74600,1.4618437,2.2201982,,,,,,,,,,,,,, -74700,1.3151677,2.1366816,,,,,,,,,,,,,, -74800,1.4183832,2.0588732,,,,,,,,,,,,,, -74900,1.3464695,2.1601014,,,,,,,,,,,,,, -75000,1.4151853,2.1689787,,,,,,,,,,,,,, -75100,1.4889903,2.0732055,,,,,,,,,,,,,, -75200,1.3762292,2.1176898,,,,,,,,,,,,,, -75300,1.3971877,2.0451007,,,,,,,,,,,,,, -75400,1.5171899,2.0381837,,,,,,,,,,,,,, -75500,1.5412154,2.1244001,,,,,,,,,,,,,, -75591,,,0.1342075914144516,5.668425559997559,0.1243399977684021,5.764382839202881,50000.0,0.0870000049471855,6.647572040557861,10000.0,25537.43885302544,26443.80044603348,25537.43885302544,901.976181268692,1.7384414672851562,0.0 -75600,1.6530756,2.1841054,,,,,,,,,,,,,, -75700,1.4736409,2.1874526,,,,,,,,,,,,,, -75800,1.3695099,2.1491039,,,,,,,,,,,,,, -75900,1.5362549,2.1535656,,,,,,,,,,,,,, -76000,1.4764858,1.9920366,,,,,,,,,,,,,, -76100,1.5597483,2.1057098,,,,,,,,,,,,,, -76200,1.4593612,2.0076447,,,,,,,,,,,,,, -76300,1.4175496,1.9973917,,,,,,,,,,,,,, -76400,1.3339307,2.1720746,,,,,,,,,,,,,, -76500,1.3245765,2.090363,,,,,,,,,,,,,, -76600,1.3672637,2.0566654,,,,,,,,,,,,,, -76700,1.4739704,2.18186,,,,,,,,,,,,,, -76800,1.3067203,2.109484,,,,,,,,,,,,,, -76900,1.5691506,2.054926,,,,,,,,,,,,,, -77000,1.4121772,2.0358737,,,,,,,,,,,,,, -77100,1.4276667,2.121986,,,,,,,,,,,,,, -77106,,,0.3316725194454193,3.21846866607666,0.2920999825000763,3.503586769104004,50000.0,0.2205000072717666,4.244337558746338,10000.0,26047.63839364052,26971.54065322876,26047.63839364052,919.41881108284,1.7832458019256592,0.0 -77200,1.5426793,2.0209,,,,,,,,,,,,,, -77300,1.4834458,2.1278605,,,,,,,,,,,,,, -77400,1.4957973,2.1123946,,,,,,,,,,,,,, -77500,1.5303128,2.154904,,,,,,,,,,,,,, -77600,1.4112477,2.0525181,,,,,,,,,,,,,, -77700,1.5155092,2.039018,,,,,,,,,,,,,, -77800,1.3639334,2.1400883,,,,,,,,,,,,,, -77900,1.3668094,2.063006,,,,,,,,,,,,,, -78000,1.5424706,2.0682023,,,,,,,,,,,,,, -78100,1.5013599,2.0154617,,,,,,,,,,,,,, -78200,1.4146316,2.0186234,,,,,,,,,,,,,, -78300,1.4844543,2.0849586,,,,,,,,,,,,,, -78400,1.4440688,2.0484116,,,,,,,,,,,,,, -78500,1.4150211,1.976275,,,,,,,,,,,,,, -78600,1.5400218,2.1814966,,,,,,,,,,,,,, -78620,,,0.3699776828289032,3.103986978530884,0.3526600003242492,3.277158737182617,50000.0,0.2669000029563904,4.146973609924316,10000.0,26557.77646183968,27499.20015025139,26557.77646183968,936.843403339386,1.825603008270264,0.0 -78700,1.4547111,2.1427984,,,,,,,,,,,,,, -78800,1.4258378,1.9397334,,,,,,,,,,,,,, -78900,1.5289216,2.078699,,,,,,,,,,,,,, -79000,1.3785979,2.1047184,,,,,,,,,,,,,, -79100,1.3157191,2.119385,,,,,,,,,,,,,, -79200,1.3711535,2.2048616,,,,,,,,,,,,,, -79300,1.472938,2.015612,,,,,,,,,,,,,, -79400,1.449197,2.1279242,,,,,,,,,,,,,, -79500,1.4279215,2.2520726,,,,,,,,,,,,,, -79600,1.6103883,2.1027822,,,,,,,,,,,,,, -79700,1.4095423,2.0366693,,,,,,,,,,,,,, -79800,1.6316181,2.0184388,,,,,,,,,,,,,, -79900,1.5430316,2.2174103,,,,,,,,,,,,,, -80000,1.5236809,1.9665927,,,,,,,,,,,,,, -80100,1.4832177,2.1116717,,,,,,,,,,,,,, -80134,,,0.3420161008834839,3.218864917755127,0.3225799798965454,3.411599636077881,50000.0,0.2421000152826309,4.21008825302124,10000.0,27067.766325950623,28026.72852373123,27067.766325950623,954.2851715087892,1.868180513381958,0.0 -80200,1.650006,2.1311996,,,,,,,,,,,,,, -80300,1.4356791,2.148314,,,,,,,,,,,,,, -80400,1.471603,2.082625,,,,,,,,,,,,,, -80500,1.5251275,2.1434069,,,,,,,,,,,,,, -80600,1.4050612,2.019274,,,,,,,,,,,,,, -80700,1.3621432,2.0884197,,,,,,,,,,,,,, -80800,1.5663004,2.13849,,,,,,,,,,,,,, -80900,1.395597,1.9314164,,,,,,,,,,,,,, -81000,1.4695588,2.0376122,,,,,,,,,,,,,, -81100,1.412953,2.1803663,,,,,,,,,,,,,, -81200,1.5203575,2.056208,,,,,,,,,,,,,, -81300,1.5808947,2.051959,,,,,,,,,,,,,, -81400,1.8334686,2.0879362,,,,,,,,,,,,,, -81500,1.364107,2.184924,,,,,,,,,,,,,, -81600,1.4404653,2.038428,,,,,,,,,,,,,, -81648,,,0.3948501050472259,2.805603265762329,0.3705599904060364,2.987741470336914,50000.0,0.281900018453598,3.800844669342041,10000.0,27577.878240585327,28554.49188780785,27577.878240585327,971.8416512012482,1.9092822074890137,0.0 -81700,1.4725528,2.1285017,,,,,,,,,,,,,, -81800,1.5356951,2.0156932,,,,,,,,,,,,,, -81900,1.5042601,2.0141928,,,,,,,,,,,,,, -82000,1.3610313,2.117133,,,,,,,,,,,,,, -82100,1.5008698,1.995779,,,,,,,,,,,,,, -82200,1.5359075,2.0161774,,,,,,,,,,,,,, -82300,1.3295529,2.025325,,,,,,,,,,,,,, -82400,1.4320565,2.0467331,,,,,,,,,,,,,, -82500,1.5975549,2.0676243,,,,,,,,,,,,,, -82600,1.5348154,2.1586149,,,,,,,,,,,,,, -82700,1.6525123,2.1032572,,,,,,,,,,,,,, -82800,1.5904186,2.0852423,,,,,,,,,,,,,, -82900,1.4178798,2.0526903,,,,,,,,,,,,,, -83000,1.4594427,2.0075417,,,,,,,,,,,,,, -83100,1.5544426,2.099029,,,,,,,,,,,,,, -83162,,,0.2846579849720001,3.92525863647461,0.2743200063705444,4.041182041168213,50000.0,0.1921000033617019,5.06205940246582,10000.0,28087.81955480576,29082.3350391388,28087.81955480576,989.644821882248,1.954078197479248,0.0 -83200,1.4573303,1.8910725,,,,,,,,,,,,,, -83300,1.5405034,2.1136622,,,,,,,,,,,,,, -83400,1.5407013,2.0406754,,,,,,,,,,,,,, -83500,1.5680034,2.022573,,,,,,,,,,,,,, -83600,1.5531363,2.1237736,,,,,,,,,,,,,, -83700,1.7165927,2.1310313,,,,,,,,,,,,,, -83800,1.4027476,2.078023,,,,,,,,,,,,,, -83900,1.5907708,2.1306024,,,,,,,,,,,,,, -84000,1.5550964,2.1463637,,,,,,,,,,,,,, -84100,1.4867644,2.0726464,,,,,,,,,,,,,, -84200,1.8196535,2.1650052,,,,,,,,,,,,,, -84300,1.547042,2.054792,,,,,,,,,,,,,, -84400,1.4353579,1.9912087,,,,,,,,,,,,,, -84500,1.5221416,2.015474,,,,,,,,,,,,,, -84600,1.431887,2.05005,,,,,,,,,,,,,, -84677,,,0.4646643698215484,2.382607698440552,0.4369199872016907,2.5394155979156494,50000.0,0.331900030374527,3.2385549545288086,10000.0,28598.049216508865,29610.035831689835,28598.049216508865,1007.0220937728882,1.9960856437683103,0.0 -84700,1.5071361,1.9517245,,,,,,,,,,,,,, -84800,1.6252289,2.0059884,,,,,,,,,,,,,, -84900,1.5035976,2.267658,,,,,,,,,,,,,, -85000,1.5227554,2.056899,,,,,,,,,,,,,, -85100,1.5759015,2.1598141,,,,,,,,,,,,,, -85200,1.5232805,2.1566703,,,,,,,,,,,,,, -85300,1.658039,2.0626125,,,,,,,,,,,,,, -85400,1.4629257,2.06553,,,,,,,,,,,,,, -85500,1.4048542,2.0385625,,,,,,,,,,,,,, -85600,1.4683146,2.0120392,,,,,,,,,,,,,, -85700,1.551682,2.1697798,,,,,,,,,,,,,, -85800,1.4583822,2.0351932,,,,,,,,,,,,,, -85900,1.6105623,2.0699015,,,,,,,,,,,,,, -86000,1.4958231,2.1413944,,,,,,,,,,,,,, -86100,1.5359058,1.9941761,,,,,,,,,,,,,, -86191,,,0.482122927904129,2.2993392944335938,0.4251599907875061,2.6702864170074463,50000.0,0.3320000171661377,3.385972738265991,10000.0,29108.04040503502,30137.687334775925,29108.04040503502,1024.5876967906952,2.03897762298584,0.0 -86200,1.5454732,1.9897501,,,,,,,,,,,,,, -86300,1.5685985,2.1382885,,,,,,,,,,,,,, -86400,1.5269428,1.9719031,,,,,,,,,,,,,, -86500,1.7542021,2.0536644,,,,,,,,,,,,,, -86600,1.4755309,2.014972,,,,,,,,,,,,,, -86700,1.4203241,2.0721564,,,,,,,,,,,,,, -86800,1.5218066,2.15529,,,,,,,,,,,,,, -86900,1.7076604,2.1072512,,,,,,,,,,,,,, -87000,1.4457239,1.9319177,,,,,,,,,,,,,, -87100,1.5537944,2.0597417,,,,,,,,,,,,,, -87200,1.5520301,2.0783372,,,,,,,,,,,,,, -87300,1.512037,1.9951267,,,,,,,,,,,,,, -87400,1.7274201,2.0788176,,,,,,,,,,,,,, -87500,1.4092827,2.0469458,,,,,,,,,,,,,, -87600,1.5042918,2.0087051,,,,,,,,,,,,,, -87700,1.4533844,1.9648759,,,,,,,,,,,,,, -87705,,,0.3338049948215484,3.208051443099976,0.304419994354248,3.446050882339477,50000.0,0.2369000166654586,4.14555025100708,10000.0,29617.99115371704,30665.211062908173,29617.99115371704,1042.0650265216827,2.081550359725952,0.0 -87800,1.5813948,2.051182,,,,,,,,,,,,,, -87900,1.6634957,2.033969,,,,,,,,,,,,,, -88000,1.51431,1.9658499,,,,,,,,,,,,,, -88100,1.5896428,2.101273,,,,,,,,,,,,,, -88200,1.51245,1.9823455,,,,,,,,,,,,,, -88300,1.5376709,2.091913,,,,,,,,,,,,,, -88400,1.4791863,2.040886,,,,,,,,,,,,,, -88500,1.5237311,1.9582175,,,,,,,,,,,,,, -88600,1.5816717,2.022798,,,,,,,,,,,,,, -88700,1.6402177,2.00848,,,,,,,,,,,,,, -88800,1.5750605,2.0676541,,,,,,,,,,,,,, -88900,1.5732212,2.1054387,,,,,,,,,,,,,, -89000,1.5931852,2.0858867,,,,,,,,,,,,,, -89100,1.4675602,2.0676074,,,,,,,,,,,,,, -89200,1.5781131,1.9041196,,,,,,,,,,,,,, -89219,,,0.3150709569454193,3.5599539279937744,0.2946600019931793,3.770634651184082,50000.0,0.2253000140190124,4.554732322692871,10000.0,30127.92679142952,31192.59133434296,30127.92679142952,1059.4081492424011,2.1293301582336426,0.0 -89300,1.6820763,2.0745664,,,,,,,,,,,,,, -89400,1.4342294,1.97621,,,,,,,,,,,,,, -89500,1.3840318,1.9578446,,,,,,,,,,,,,, -89600,1.7615412,2.1467342,,,,,,,,,,,,,, -89700,1.5562967,1.9323976,,,,,,,,,,,,,, -89800,1.7465199,2.184883,,,,,,,,,,,,,, -89900,1.5379857,2.1225212,,,,,,,,,,,,,, -90000,1.4868149,2.0569692,,,,,,,,,,,,,, -90100,1.6140597,2.003934,,,,,,,,,,,,,, -90200,1.4452655,1.9454781,,,,,,,,,,,,,, -90300,1.5586634,2.0501437,,,,,,,,,,,,,, -90400,1.747987,2.135839,,,,,,,,,,,,,, -90500,1.7278706,2.0316117,,,,,,,,,,,,,, -90600,1.4338516,2.0820725,,,,,,,,,,,,,, -90700,1.6565249,1.95195,,,,,,,,,,,,,, -90733,,,0.3970822691917419,2.781815528869629,0.3685599863529205,2.9695539474487305,50000.0,0.2732000052928924,3.744641542434693,10000.0,30637.95221996308,31720.2343685627,30637.95221996308,1076.9277966022491,2.174423694610596,0.0 -90800,1.5946565,1.9522612,,,,,,,,,,,,,, -90900,1.6237386,1.9928515,,,,,,,,,,,,,, -91000,1.5352658,1.9630518,,,,,,,,,,,,,, -91100,1.6225282,1.9897754,,,,,,,,,,,,,, -91200,1.4233629,1.9889662,,,,,,,,,,,,,, -91300,1.478759,1.9739448,,,,,,,,,,,,,, -91400,1.6861113,1.9503021,,,,,,,,,,,,,, -91500,1.5156922,2.035068,,,,,,,,,,,,,, -91600,1.5665629,2.074631,,,,,,,,,,,,,, -91700,1.5986041,2.1069322,,,,,,,,,,,,,, -91800,1.7284474,2.1247082,,,,,,,,,,,,,, -91900,1.9159877,2.0019326,,,,,,,,,,,,,, -92000,1.5740376,2.0979304,,,,,,,,,,,,,, -92100,1.5979503,2.0132773,,,,,,,,,,,,,, -92200,1.4850283,1.9172014,,,,,,,,,,,,,, -92246,,,0.4543407261371612,2.392561912536621,0.4245999753475189,2.587425708770752,50000.0,0.320000022649765,3.2951889038085938,10000.0,31148.09440755844,32248.54153752327,31148.09440755844,1094.9929592609406,2.2205803394317627,0.0 -92300,1.660739,1.9316522,,,,,,,,,,,,,, -92400,1.8380635,2.0752535,,,,,,,,,,,,,, -92500,1.6152991,2.121569,,,,,,,,,,,,,, -92600,1.5830187,1.9966929,,,,,,,,,,,,,, -92700,1.7167494,1.9807612,,,,,,,,,,,,,, -92800,1.8406265,2.1359499,,,,,,,,,,,,,, -92900,1.5734485,2.0050023,,,,,,,,,,,,,, -93000,1.5362444,2.0395243,,,,,,,,,,,,,, -93100,1.5753245,2.02858,,,,,,,,,,,,,, -93200,1.5942568,2.0123003,,,,,,,,,,,,,, -93300,1.7911276,2.0540743,,,,,,,,,,,,,, -93400,1.6612387,1.961155,,,,,,,,,,,,,, -93500,1.581851,2.0078619,,,,,,,,,,,,,, -93600,1.9702184,1.9552058,,,,,,,,,,,,,, -93700,1.6911948,2.1994038,,,,,,,,,,,,,, -93760,,,0.458685427904129,2.434083938598633,0.4264200031757355,2.6672022342681885,50000.0,0.3368000090122223,3.445615291595459,10000.0,31658.02108550072,32775.989436626434,31658.02108550072,1112.41717171669,2.264895439147949,0.0 -93800,1.6263968,2.0013947,,,,,,,,,,,,,, -93900,1.6749305,2.0103035,,,,,,,,,,,,,, -94000,1.4873682,2.0604599,,,,,,,,,,,,,, -94100,1.529476,1.9022477,,,,,,,,,,,,,, -94200,1.4750212,1.9660889,,,,,,,,,,,,,, -94300,1.625511,2.127382,,,,,,,,,,,,,, -94400,1.5966157,1.8690617,,,,,,,,,,,,,, -94500,1.5798714,2.1183188,,,,,,,,,,,,,, -94600,1.449013,1.9413623,,,,,,,,,,,,,, -94700,1.5826523,2.0807931,,,,,,,,,,,,,, -94800,1.7215909,2.090455,,,,,,,,,,,,,, -94900,1.5804263,2.1301074,,,,,,,,,,,,,, -95000,1.6973963,2.0293138,,,,,,,,,,,,,, -95100,1.6866469,1.937815,,,,,,,,,,,,,, -95200,1.7747443,1.9530277,,,,,,,,,,,,,, -95275,,,0.3795639276504516,2.966967821121216,0.340179979801178,3.3062493801116943,50000.0,0.2595000267028808,4.024302959442139,10000.0,32168.247616052628,33303.67395091057,32168.247616052628,1129.7756674289703,2.3125619888305664,0.0 -95300,1.6200527,1.8991587,,,,,,,,,,,,,, -95400,1.5104252,1.9875108,,,,,,,,,,,,,, -95500,1.6292288,2.045725,,,,,,,,,,,,,, -95600,1.5681847,1.8601918,,,,,,,,,,,,,, -95700,1.6288469,1.9356747,,,,,,,,,,,,,, -95800,1.5635699,2.0725896,,,,,,,,,,,,,, -95900,1.7085338,1.9192348,,,,,,,,,,,,,, -96000,1.9701449,2.025556,,,,,,,,,,,,,, -96100,1.6754658,2.02701,,,,,,,,,,,,,, -96200,1.9097333,1.9602973,,,,,,,,,,,,,, -96300,1.6615785,1.9253345,,,,,,,,,,,,,, -96400,1.6558204,1.9968381,,,,,,,,,,,,,, -96500,1.6682945,2.0715387,,,,,,,,,,,,,, -96600,1.7127819,2.0840437,,,,,,,,,,,,,, -96700,1.6588855,2.0031857,,,,,,,,,,,,,, -96790,,,0.4096579849720001,2.806935787200928,0.3768999874591827,3.023160934448242,50000.0,0.2849000096321106,3.783292055130005,10000.0,32678.433045387268,33831.3767824173,32678.433045387268,1147.1902074813845,2.362378597259521,0.0 -96800,1.7054253,2.066031,,,,,,,,,,,,,, -96900,1.7242146,2.0196304,,,,,,,,,,,,,, -97000,1.619599,1.8162923,,,,,,,,,,,,,, -97100,1.554034,1.9848127,,,,,,,,,,,,,, -97200,1.8185382,2.053211,,,,,,,,,,,,,, -97300,1.6050514,2.004184,,,,,,,,,,,,,, -97400,1.6857927,1.9377627,,,,,,,,,,,,,, -97500,1.6911863,1.8731619,,,,,,,,,,,,,, -97600,1.6742694,2.105166,,,,,,,,,,,,,, -97700,1.5633343,1.9057401,,,,,,,,,,,,,, -97800,2.0821784,2.15282,,,,,,,,,,,,,, -97900,1.6402315,1.9320025,,,,,,,,,,,,,, -98000,1.785865,1.9973724,,,,,,,,,,,,,, -98100,1.8486413,2.0711727,,,,,,,,,,,,,, -98200,1.820671,1.897626,,,,,,,,,,,,,, -98300,1.6966155,1.9957564,,,,,,,,,,,,,, -98304,,,0.351283460855484,3.222519874572754,0.3321599960327148,3.3702187538146973,50000.0,0.2519000172615051,4.132585525512695,10000.0,33188.55919909477,34359.24005818367,33188.55919909477,1164.826674938202,2.409458637237549,0.0 -98400,1.7523139,1.9386835,,,,,,,,,,,,,, -98500,1.8463644,1.9795539,,,,,,,,,,,,,, -98600,1.7820965,2.0372214,,,,,,,,,,,,,, -98700,1.7318848,1.985385,,,,,,,,,,,,,, -98800,1.9002569,2.0795822,,,,,,,,,,,,,, -98900,1.6627842,1.867942,,,,,,,,,,,,,, -99000,1.6833626,2.067589,,,,,,,,,,,,,, -99100,1.6218454,1.9233907,,,,,,,,,,,,,, -99200,2.1027825,2.0081887,,,,,,,,,,,,,, -99300,1.7073921,1.9092523,,,,,,,,,,,,,, -99400,1.9799725,1.9904017,,,,,,,,,,,,,, -99500,1.6247189,1.9935907,,,,,,,,,,,,,, -99600,2.0574193,2.03587,,,,,,,,,,,,,, -99700,1.7869225,2.0091345,,,,,,,,,,,,,, -99800,1.654911,1.84298,,,,,,,,,,,,,, -99818,,,0.328125,3.5659189224243164,0.3105199933052063,3.713387966156006,50000.0,0.2345000058412552,4.5583319664001465,10000.0,33698.464007377625,34887.28036427498,33698.464007377625,1182.862447977066,2.4572508335113525,0.0 -99900,1.8273351,2.0313795,,,,,,,,,,,,,, -100000,1.6472013,1.8589442,,,,,,,,,,,,,, -100100,1.6986525,1.9401548,,,,,,,,,,,,,, -100200,1.780778,1.9615371,,,,,,,,,,,,,, -100300,1.57667,1.9032483,,,,,,,,,,,,,, -100400,1.8006585,1.9782102,,,,,,,,,,,,,, -100500,1.905745,1.8966349,,,,,,,,,,,,,, -100600,1.7272315,1.8651023,,,,,,,,,,,,,, -100700,1.678036,1.9285634,,,,,,,,,,,,,, -100800,1.6432967,1.8670871,,,,,,,,,,,,,, -100900,1.6972932,1.9413887,,,,,,,,,,,,,, -101000,1.64129,1.9437733,,,,,,,,,,,,,, -101100,2.0873652,2.061039,,,,,,,,,,,,,, -101200,1.5973924,1.8881185,,,,,,,,,,,,,, -101300,1.653631,1.9161699,,,,,,,,,,,,,, -101332,,,0.4641461968421936,2.4268994331359863,0.4397799968719482,2.5999388694763184,50000.0,0.3333000242710113,3.3893141746521,10000.0,34208.510954380035,35414.95019340515,34208.510954380035,1200.3929872512815,2.495596408843994,0.0 -101400,1.883077,1.8952031,,,,,,,,,,,,,, -101500,1.8257217,2.0044436,,,,,,,,,,,,,, -101600,1.6943182,1.9180893,,,,,,,,,,,,,, -101700,2.0141017,2.1384583,,,,,,,,,,,,,, -101800,1.866861,1.9074649,,,,,,,,,,,,,, -101900,1.880945,2.002943,,,,,,,,,,,,,, -102000,1.8120478,1.9756894,,,,,,,,,,,,,, -102100,1.7223055,1.9705925,,,,,,,,,,,,,, -102200,1.856046,1.9943173,,,,,,,,,,,,,, -102300,2.012922,1.9619672,,,,,,,,,,,,,, -102400,1.7310222,1.8907506,,,,,,,,,,,,,, -102500,1.660817,1.8579063,,,,,,,,,,,,,, -102600,1.9675868,1.9481974,,,,,,,,,,,,,, -102700,1.8118513,1.9198028,,,,,,,,,,,,,, -102800,1.7582691,1.9973989,,,,,,,,,,,,,, -102846,,,0.4216557741165161,2.711665391921997,0.3911399841308594,2.9025368690490723,50000.0,0.3058000206947326,3.579463481903076,10000.0,34718.491564273834,35942.75317645073,34718.491564273834,1218.1181762218475,2.540222406387329,0.0 -102900,1.8429732,1.9393363,,,,,,,,,,,,,, -103000,2.0578291,1.9210618,,,,,,,,,,,,,, -103100,1.7493879,1.8715011,,,,,,,,,,,,,, -103200,1.7810184,1.9022009,,,,,,,,,,,,,, -103300,1.7070366,1.9606622,,,,,,,,,,,,,, -103400,1.812251,1.9742775,,,,,,,,,,,,,, -103500,1.8169891,1.9113307,,,,,,,,,,,,,, -103600,1.8327672,2.0108259,,,,,,,,,,,,,, -103700,2.0002973,1.8796475,,,,,,,,,,,,,, -103800,1.7636806,1.8289561,,,,,,,,,,,,,, -103900,1.9603919,1.9236504,,,,,,,,,,,,,, -104000,1.582305,1.7935667,,,,,,,,,,,,,, -104100,1.8157727,1.969063,,,,,,,,,,,,,, -104200,1.8117801,1.9395926,,,,,,,,,,,,,, -104300,1.749317,1.9435754,,,,,,,,,,,,,, -104360,,,0.5540298223495483,1.874036312103272,0.4943999946117401,2.240631580352783,50000.0,0.3788000047206878,3.055749654769897,10000.0,35228.44272494316,36470.514008522034,35228.44272494316,1235.8320398330688,2.584373712539673,0.0 -104400,1.7753608,1.9278538,,,,,,,,,,,,,, -104500,2.0520654,2.0256846,,,,,,,,,,,,,, -104600,1.8360134,1.9999192,,,,,,,,,,,,,, -104700,1.7233183,1.9313108,,,,,,,,,,,,,, -104800,1.877952,1.9623682,,,,,,,,,,,,,, -104900,1.8856132,1.9871027,,,,,,,,,,,,,, -105000,1.8395249,1.8287108,,,,,,,,,,,,,, -105100,1.7342716,1.8849713,,,,,,,,,,,,,, -105200,1.953989,2.062417,,,,,,,,,,,,,, -105300,1.6385783,1.8313553,,,,,,,,,,,,,, -105400,1.8765888,1.9523895,,,,,,,,,,,,,, -105500,1.8521664,1.9786999,,,,,,,,,,,,,, -105600,1.8038187,1.9103765,,,,,,,,,,,,,, -105700,1.7927929,1.9183774,,,,,,,,,,,,,, -105800,1.9041269,1.8997077,,,,,,,,,,,,,, -105874,,,0.5051219463348389,2.1285979747772217,0.4580999910831451,2.40424919128418,50000.0,0.3613000214099884,3.1027300357818604,10000.0,35738.398156404495,36997.97931480408,35738.398156404495,1253.241216421127,2.6315665245056152,0.0 -105900,1.8259158,1.8257599,,,,,,,,,,,,,, -106000,1.896245,1.980924,,,,,,,,,,,,,, -106100,1.9151876,1.8906236,,,,,,,,,,,,,, -106200,1.8518711,1.8383647,,,,,,,,,,,,,, -106300,1.7763907,1.8882922,,,,,,,,,,,,,, -106400,1.894825,2.0275862,,,,,,,,,,,,,, -106500,1.7054845,1.9575839,,,,,,,,,,,,,, -106600,1.8916391,1.936226,,,,,,,,,,,,,, -106700,1.8850102,1.921811,,,,,,,,,,,,,, -106800,1.8151398,1.8033366,,,,,,,,,,,,,, -106900,1.90145,1.9021193,,,,,,,,,,,,,, -107000,2.0079417,1.8643918,,,,,,,,,,,,,, -107100,2.0503154,1.8784978,,,,,,,,,,,,,, -107200,1.8505341,1.890169,,,,,,,,,,,,,, -107300,1.7506636,1.8251164,,,,,,,,,,,,,, -107388,,,0.5316087007522583,2.0331740379333496,0.4949599802494049,2.2442522048950195,50000.0,0.3841000199317932,3.025560140609741,10000.0,36248.40759110451,37525.970309495926,36248.40759110451,1271.1150813102722,2.683964490890503,0.0 -107400,1.9336141,1.8859563,,,,,,,,,,,,,, -107500,2.0060334,1.7889919,,,,,,,,,,,,,, -107600,1.7901183,1.9009845,,,,,,,,,,,,,, -107700,1.7166563,1.9704775,,,,,,,,,,,,,, -107800,1.8819244,1.9293411,,,,,,,,,,,,,, -107900,1.8690901,1.941567,,,,,,,,,,,,,, -108000,2.1666203,1.983636,,,,,,,,,,,,,, -108100,1.8291417,1.8892407,,,,,,,,,,,,,, -108200,1.9194613,1.8054374,,,,,,,,,,,,,, -108300,1.8313187,1.8742859,,,,,,,,,,,,,, -108400,1.7789001,1.8914701,,,,,,,,,,,,,, -108500,1.8925805,1.791445,,,,,,,,,,,,,, -108600,1.978135,1.9436754,,,,,,,,,,,,,, -108700,1.9933037,1.9141101,,,,,,,,,,,,,, -108800,1.8068935,1.8681227,,,,,,,,,,,,,, -108900,1.8525268,1.8197889,,,,,,,,,,,,,, -108902,,,0.5332629084587097,2.0092248916625977,0.4962999820709228,2.2380595207214355,50000.0,0.3831000328063965,3.019572496414185,10000.0,36758.39783191681,38053.5044836998,36758.39783191681,1288.5332021713257,2.757606744766236,0.0 -109000,1.9479356,1.9476538,,,,,,,,,,,,,, -109100,1.8936844,1.9650868,,,,,,,,,,,,,, -109200,1.765746,1.8362579,,,,,,,,,,,,,, -109300,1.8654871,1.9492152,,,,,,,,,,,,,, -109400,1.8831242,1.7928692,,,,,,,,,,,,,, -109500,1.937904,1.8536464,,,,,,,,,,,,,, -109600,1.8680987,1.9956586,,,,,,,,,,,,,, -109700,2.2613535,1.9442408,,,,,,,,,,,,,, -109800,2.0047529,1.9247407,,,,,,,,,,,,,, -109900,1.9168322,1.9246857,,,,,,,,,,,,,, -110000,2.0535007,1.9396211,,,,,,,,,,,,,, -110100,1.933815,1.8564466,,,,,,,,,,,,,, -110200,1.8357518,1.8568094,,,,,,,,,,,,,, -110300,1.9420536,1.8070478,,,,,,,,,,,,,, -110400,1.8796456,1.9282284,,,,,,,,,,,,,, -110416,,,0.5140505433082581,2.1145224571228027,0.4885999858379364,2.261136531829834,50000.0,0.3677000105381012,3.080395698547364,10000.0,37268.32622623444,38581.21153998375,37268.32622623444,1306.213036775589,2.8046953678131104,0.0 -110500,1.8254263,1.867159,,,,,,,,,,,,,, -110600,1.9236957,1.827282,,,,,,,,,,,,,, -110700,1.8143574,1.9402899,,,,,,,,,,,,,, -110800,1.8391291,1.785016,,,,,,,,,,,,,, -110900,1.9724457,1.9595296,,,,,,,,,,,,,, -111000,1.9695432,1.7577466,,,,,,,,,,,,,, -111100,1.9229726,1.730381,,,,,,,,,,,,,, -111200,2.021011,1.9402525,,,,,,,,,,,,,, -111300,2.0102005,1.8123068,,,,,,,,,,,,,, -111400,2.1024382,1.8437883,,,,,,,,,,,,,, -111500,2.2046635,1.8579359,,,,,,,,,,,,,, -111600,1.8857328,1.8363011,,,,,,,,,,,,,, -111700,2.0533757,1.824923,,,,,,,,,,,,,, -111800,2.1446447,1.9682857,,,,,,,,,,,,,, -111900,2.0769513,1.864678,,,,,,,,,,,,,, -111930,,,0.5302335619926453,2.0475571155548096,0.4895599782466888,2.2830638885498047,50000.0,0.369700014591217,3.0804600715637207,10000.0,37778.27315187454,39108.91320705414,37778.27315187454,1323.868717432022,2.851999282836914,0.0 -112000,2.0728533,1.9072714,,,,,,,,,,,,,, -112100,2.1461415,1.8460345,,,,,,,,,,,,,, -112200,2.065427,1.8930068,,,,,,,,,,,,,, -112300,1.9264748,1.8565404,,,,,,,,,,,,,, -112400,1.9096757,1.8303239,,,,,,,,,,,,,, -112500,2.0778224,1.867128,,,,,,,,,,,,,, -112600,2.0790284,1.8909682,,,,,,,,,,,,,, -112700,2.127338,1.9076763,,,,,,,,,,,,,, -112800,1.927895,1.7399884,,,,,,,,,,,,,, -112900,2.066287,1.8787737,,,,,,,,,,,,,, -113000,1.9527959,1.9211481,,,,,,,,,,,,,, -113100,2.166307,1.7943226,,,,,,,,,,,,,, -113200,2.2635624,1.8704628,,,,,,,,,,,,,, -113300,2.0634403,1.9269241,,,,,,,,,,,,,, -113400,1.9363514,1.7255225,,,,,,,,,,,,,, -113445,,,0.5375478267669678,1.958598017692566,0.4891199767589569,2.234560489654541,50000.0,0.3817000091075897,2.9632489681243896,10000.0,38288.5008263588,39637.01444840431,38288.5008263588,1341.6378679275513,2.903039693832397,0.0 -113500,1.8321763,1.7756057,,,,,,,,,,,,,, -113600,1.824359,1.8329496,,,,,,,,,,,,,, -113700,2.1316664,1.845776,,,,,,,,,,,,,, -113800,2.2706463,1.8733478,,,,,,,,,,,,,, -113900,2.0208826,1.8027296,,,,,,,,,,,,,, -114000,2.0395117,1.7768083,,,,,,,,,,,,,, -114100,2.2242577,1.933158,,,,,,,,,,,,,, -114200,2.0808513,1.8777261,,,,,,,,,,,,,, -114300,1.9468573,1.8447748,,,,,,,,,,,,,, -114400,2.107306,1.8694779,,,,,,,,,,,,,, -114500,2.2063293,1.7631218,,,,,,,,,,,,,, -114600,2.3965855,1.9549854,,,,,,,,,,,,,, -114700,2.1292863,1.7962284,,,,,,,,,,,,,, -114800,1.899261,1.7986327,,,,,,,,,,,,,, -114900,2.175935,1.8607181,,,,,,,,,,,,,, -114960,,,0.5321866869926453,1.9915037155151367,0.4983199834823608,2.20292329788208,50000.0,0.379800021648407,3.0490152835845947,10000.0,38798.68224453926,40164.63128519058,38798.68224453926,1358.971853017807,2.952772617340088,0.0 -115000,2.236798,1.8196027,,,,,,,,,,,,,, -115100,2.133185,1.8455523,,,,,,,,,,,,,, -115200,2.1836689,1.8704123,,,,,,,,,,,,,, -115300,2.0790813,1.7062447,,,,,,,,,,,,,, -115400,2.0311124,1.960398,,,,,,,,,,,,,, -115500,2.0512981,1.7905477,,,,,,,,,,,,,, -115600,1.9199059,1.7927768,,,,,,,,,,,,,, -115700,2.2000437,1.9858294,,,,,,,,,,,,,, -115800,2.0823774,1.7701025,,,,,,,,,,,,,, -115900,1.984121,1.9733787,,,,,,,,,,,,,, -116000,2.0003757,1.7557342,,,,,,,,,,,,,, -116100,2.2257776,1.9666708,,,,,,,,,,,,,, -116200,1.9768054,1.8114285,,,,,,,,,,,,,, -116300,2.0422492,1.8222834,,,,,,,,,,,,,, -116400,2.0471194,1.8470931,,,,,,,,,,,,,, -116475,,,0.5823102593421936,1.767098069190979,0.5386199951171875,2.0045840740203857,50000.0,0.4178000092506408,2.817715644836426,10000.0,39308.85549354553,40692.430067777634,39308.85549354553,1376.4977297782898,2.999505519866944,0.0 -116500,2.197484,1.7914996,,,,,,,,,,,,,, -116600,1.9409703,1.9221482,,,,,,,,,,,,,, -116700,2.0758464,1.841471,,,,,,,,,,,,,, -116800,2.1512036,1.790264,,,,,,,,,,,,,, -116900,2.0979822,1.7812424,,,,,,,,,,,,,, -117000,2.022211,1.6141633,,,,,,,,,,,,,, -117100,2.1129475,1.7457404,,,,,,,,,,,,,, -117200,2.0989869,1.8543442,,,,,,,,,,,,,, -117300,2.2282717,1.9122108,,,,,,,,,,,,,, -117400,2.2506838,1.8575153,,,,,,,,,,,,,, -117500,2.1162872,1.8115405,,,,,,,,,,,,,, -117600,2.1865597,1.8801618,,,,,,,,,,,,,, -117700,2.2977078,1.7884501,,,,,,,,,,,,,, -117800,2.08365,1.9728947,,,,,,,,,,,,,, -117900,2.2331126,1.875514,,,,,,,,,,,,,, -117989,,,0.4283721148967743,2.6851370334625244,0.3919200003147125,2.899744987487793,50000.0,0.2914000153541565,3.691535711288452,10000.0,39818.84736561775,41220.18163561821,39818.84736561775,1394.1560413837433,3.048726320266724,0.0 -118000,2.1227586,1.8914225,,,,,,,,,,,,,, -118100,2.1407232,1.9800322,,,,,,,,,,,,,, -118200,2.0603561,1.8777266,,,,,,,,,,,,,, -118300,2.148448,1.8644054,,,,,,,,,,,,,, -118400,2.15693,1.7464697,,,,,,,,,,,,,, -118500,2.1545775,1.7996187,,,,,,,,,,,,,, -118600,2.1238859,1.7386607,,,,,,,,,,,,,, -118700,2.200435,1.7989445,,,,,,,,,,,,,, -118800,2.070801,1.7941818,,,,,,,,,,,,,, -118900,2.176191,1.7344447,,,,,,,,,,,,,, -119000,2.0751314,1.8770062,,,,,,,,,,,,,, -119100,2.3002665,1.8439964,,,,,,,,,,,,,, -119200,2.1871204,1.7225361,,,,,,,,,,,,,, -119300,2.1730626,1.8063142,,,,,,,,,,,,,, -119400,2.2530687,1.8506699,,,,,,,,,,,,,, -119500,2.169609,1.783404,,,,,,,,,,,,,, -119503,,,0.447644293308258,2.602922201156616,0.4219799935817718,2.8036701679229736,50000.0,0.3220000267028808,3.63352370262146,10000.0,40328.843329668045,41747.48498153687,40328.843329668045,1411.35599732399,3.1032016277313232,0.0 -119600,2.2548504,1.8645326,,,,,,,,,,,,,, -119700,2.1891012,1.7814966,,,,,,,,,,,,,, -119800,2.3140287,1.9197994,,,,,,,,,,,,,, -119900,2.099543,1.7683605,,,,,,,,,,,,,, -120000,2.157247,1.7764031,,,,,,,,,,,,,, -120100,2.1330175,1.7676295,,,,,,,,,,,,,, -120200,2.2946987,1.7917107,,,,,,,,,,,,,, -120300,2.370034,1.8240047,,,,,,,,,,,,,, -120400,2.1159625,1.673296,,,,,,,,,,,,,, -120500,2.1029434,1.7372818,,,,,,,,,,,,,, -120600,2.184234,1.8201542,,,,,,,,,,,,,, -120700,2.1886728,1.7953202,,,,,,,,,,,,,, -120800,2.074805,1.7494675,,,,,,,,,,,,,, -120900,2.1901932,1.6887224,,,,,,,,,,,,,, -121000,2.2763376,1.8741693,,,,,,,,,,,,,, -121017,,,0.5178372263908386,2.0767385959625244,0.4694999754428863,2.3726882934570312,50000.0,0.3559000194072723,3.172493934631348,10000.0,40838.74562501907,42275.05197453499,40838.74562501907,1428.919724702835,3.1509153842926025,0.0 -121100,2.1402562,1.7472155,,,,,,,,,,,,,, -121200,2.1440544,1.6351316,,,,,,,,,,,,,, -121300,2.2539291,1.7443459,,,,,,,,,,,,,, -121400,2.2525802,1.782263,,,,,,,,,,,,,, -121500,2.1885536,1.8047256,,,,,,,,,,,,,, -121600,2.2422113,1.7459824,,,,,,,,,,,,,, -121700,2.4015102,1.7291986,,,,,,,,,,,,,, -121800,2.3032885,1.7914097,,,,,,,,,,,,,, -121900,2.3761802,1.7368053,,,,,,,,,,,,,, -122000,2.1590648,1.6902797,,,,,,,,,,,,,, -122100,2.1525958,1.8297298,,,,,,,,,,,,,, -122200,2.172686,1.7197967,,,,,,,,,,,,,, -122300,2.2682717,1.802555,,,,,,,,,,,,,, -122400,2.2154286,1.7316808,,,,,,,,,,,,,, -122500,2.4405122,1.7559209,,,,,,,,,,,,,, -122531,,,0.6288264989852905,1.5047695636749268,0.5694999694824219,1.8335968255996704,50000.0,0.4456000328063965,2.6093904972076416,10000.0,41348.64927601814,42802.48861408234,41348.64927601814,1446.3481650352478,3.2026822566986084,0.0 -122600,2.0745165,1.6942041,,,,,,,,,,,,,, -122700,2.3094006,1.8205662,,,,,,,,,,,,,, -122800,2.2872553,1.7798512,,,,,,,,,,,,,, -122900,2.1931844,1.7948365,,,,,,,,,,,,,, -123000,2.235243,1.7636627,,,,,,,,,,,,,, -123100,2.147047,1.6009102,,,,,,,,,,,,,, -123200,2.3035674,1.7323692,,,,,,,,,,,,,, -123300,2.453145,1.7667526,,,,,,,,,,,,,, -123400,2.3543355,1.8157824,,,,,,,,,,,,,, -123500,2.3560534,1.7693422,,,,,,,,,,,,,, -123600,2.6913836,1.6608589,,,,,,,,,,,,,, -123700,2.396377,1.7781746,,,,,,,,,,,,,, -123800,2.176401,1.6685954,,,,,,,,,,,,,, -123900,2.2760975,1.7848077,,,,,,,,,,,,,, -124000,2.247975,1.7462677,,,,,,,,,,,,,, -124045,,,0.6246811151504517,1.5214416980743408,0.5735200047492981,1.811548113822937,50000.0,0.4452000260353088,2.5815134048461914,10000.0,41858.61288642883,43330.34150052071,41858.61288642883,1464.131145477295,3.255311965942383,0.0 -124100,2.187102,1.8010348,,,,,,,,,,,,,, -124200,2.8661551,1.803344,,,,,,,,,,,,,, -124300,2.2042356,1.5895629,,,,,,,,,,,,,, -124400,2.336294,1.8927271,,,,,,,,,,,,,, -124500,2.2301123,1.8513321,,,,,,,,,,,,,, -124600,2.3512857,1.8551735,,,,,,,,,,,,,, -124700,2.3338096,1.7214663,,,,,,,,,,,,,, -124800,2.214262,1.7757305,,,,,,,,,,,,,, -124900,2.410974,1.850269,,,,,,,,,,,,,, -125000,2.1823335,1.8219352,,,,,,,,,,,,,, -125100,2.1539886,1.7837143,,,,,,,,,,,,,, -125200,2.420809,1.7131166,,,,,,,,,,,,,, -125300,2.3533564,1.8415936,,,,,,,,,,,,,, -125400,2.1809576,1.6702082,,,,,,,,,,,,,, -125500,2.5554225,1.7412722,,,,,,,,,,,,,, -125560,,,0.6104312539100647,1.618659257888794,0.5592799782752991,1.915094375610352,50000.0,0.4448000192642212,2.698650360107422,10000.0,42368.84647965431,43858.30543160439,42368.84647965431,1481.754063129425,3.308755159378052,0.0 -125600,2.5384192,1.7482411,,,,,,,,,,,,,, -125700,2.6594756,1.7848516,,,,,,,,,,,,,, -125800,2.4571695,1.8419952,,,,,,,,,,,,,, -125900,2.4112062,1.5690341,,,,,,,,,,,,,, -126000,2.6453145,1.761393,,,,,,,,,,,,,, -126100,2.0867975,1.5754077,,,,,,,,,,,,,, -126200,2.252095,1.6174715,,,,,,,,,,,,,, -126300,2.6461816,1.7732334,,,,,,,,,,,,,, -126400,2.389339,1.7257956,,,,,,,,,,,,,, -126500,2.5156465,1.8098514,,,,,,,,,,,,,, -126600,2.4881663,1.7391417,,,,,,,,,,,,,, -126700,2.5289636,1.8026279,,,,,,,,,,,,,, -126800,2.345719,1.7856847,,,,,,,,,,,,,, -126900,2.4466336,1.7292963,,,,,,,,,,,,,, -127000,2.7035344,1.7338613,,,,,,,,,,,,,, -127074,,,0.5913584232330322,1.713505506515503,0.5462999939918518,2.002792835235596,50000.0,0.4290000200271606,2.7601938247680664,10000.0,42878.92653274536,44385.98586678505,42878.92653274536,1499.2517862319946,3.358278512954712,0.0 -127100,2.6392035,1.734936,,,,,,,,,,,,,, -127200,2.3678765,1.6604936,,,,,,,,,,,,,, -127300,2.2426157,1.7265198,,,,,,,,,,,,,, -127400,2.5891378,1.8160582,,,,,,,,,,,,,, -127500,2.7460158,1.8182256,,,,,,,,,,,,,, -127600,2.466618,1.6784897,,,,,,,,,,,,,, -127700,2.315575,1.7317334,,,,,,,,,,,,,, -127800,2.5379488,1.6245006,,,,,,,,,,,,,, -127900,2.6378746,1.7698762,,,,,,,,,,,,,, -128000,2.5213006,1.8032147,,,,,,,,,,,,,, -128100,2.4910984,1.746613,,,,,,,,,,,,,, -128200,2.2770932,1.6689247,,,,,,,,,,,,,, -128300,2.322212,1.6535454,,,,,,,,,,,,,, -128400,2.5954401,1.6888841,,,,,,,,,,,,,, -128500,2.685609,1.7562796,,,,,,,,,,,,,, -128588,,,0.6278898119926453,1.5220273733139038,0.5804799795150757,1.792826771736145,50000.0,0.4603000283241272,2.5546529293060303,10000.0,43388.88260865212,44913.40585780144,43388.88260865212,1516.6094889640808,3.411818504333496,0.0 -128600,2.425514,1.6780815,,,,,,,,,,,,,, -128700,2.3091452,1.5998129,,,,,,,,,,,,,, -128800,2.4764059,1.738973,,,,,,,,,,,,,, -128900,2.4883099,1.6066074,,,,,,,,,,,,,, -129000,2.41999,1.6136436,,,,,,,,,,,,,, -129100,2.34666,1.5418758,,,,,,,,,,,,,, -129200,2.6848407,1.7270834,,,,,,,,,,,,,, -129300,2.5130534,1.7546996,,,,,,,,,,,,,, -129400,2.6574562,1.6095612,,,,,,,,,,,,,, -129500,2.4299946,1.7040021,,,,,,,,,,,,,, -129600,2.4898846,1.8310806,,,,,,,,,,,,,, -129700,2.5918784,1.6539816,,,,,,,,,,,,,, -129800,2.4170005,1.5598859,,,,,,,,,,,,,, -129900,2.549449,1.685346,,,,,,,,,,,,,, -130000,2.3926308,1.6179862,,,,,,,,,,,,,, -130100,2.427509,1.6489755,,,,,,,,,,,,,, -130102,,,0.6101123690605164,1.616065502166748,0.5356799960136414,2.0169262886047363,50000.0,0.4164000153541565,2.8192737102508545,10000.0,43898.80344581604,45441.27255749703,43898.80344581604,1534.450347661972,3.4636495113372803,0.0 -130200,2.4921305,1.5721067,,,,,,,,,,,,,, -130300,2.5256863,1.7326325,,,,,,,,,,,,,, -130400,2.528752,1.5685201,,,,,,,,,,,,,, -130500,2.5596035,1.6103735,,,,,,,,,,,,,, -130600,2.5448632,1.7610002,,,,,,,,,,,,,, -130700,2.5379136,1.6114541,,,,,,,,,,,,,, -130800,2.623268,1.7585812,,,,,,,,,,,,,, -130900,2.711277,1.7577934,,,,,,,,,,,,,, -131000,2.355732,1.5902172,,,,,,,,,,,,,, -131100,2.6900637,1.6423229,,,,,,,,,,,,,, -131200,2.646236,1.5971009,,,,,,,,,,,,,, -131300,2.624355,1.7132077,,,,,,,,,,,,,, -131400,2.6944206,1.7635175,,,,,,,,,,,,,, -131500,2.637745,1.6425217,,,,,,,,,,,,,, -131600,2.9692314,1.726504,,,,,,,,,,,,,, -131616,,,0.6253786683082581,1.5129410028457642,0.5648399591445923,1.8565038442611688,50000.0,0.4419000148773193,2.6380434036254883,10000.0,44408.81854104996,45968.98812127113,44408.81854104996,1552.0316081047058,3.529684543609619,0.0 -131700,2.5318334,1.6606302,,,,,,,,,,,,,, -131800,2.3910232,1.5690356,,,,,,,,,,,,,, -131900,2.5393164,1.5280856,,,,,,,,,,,,,, -132000,2.64493,1.6087981,,,,,,,,,,,,,, -132100,2.7696347,1.5934651,,,,,,,,,,,,,, -132200,2.6612642,1.6359022,,,,,,,,,,,,,, -132300,2.6399343,1.678906,,,,,,,,,,,,,, -132400,2.8580709,1.6473352,,,,,,,,,,,,,, -132500,2.6047056,1.6938367,,,,,,,,,,,,,, -132600,2.3705475,1.6211021,,,,,,,,,,,,,, -132700,2.6420393,1.6516172,,,,,,,,,,,,,, -132800,2.7594895,1.700669,,,,,,,,,,,,,, -132900,2.5571392,1.6706722,,,,,,,,,,,,,, -133000,2.9000566,1.6805904,,,,,,,,,,,,,, -133100,2.5251076,1.6380432,,,,,,,,,,,,,, -133131,,,0.6495934128761292,1.4116238355636597,0.5912799835205078,1.7342907190322876,50000.0,0.4800000190734863,2.4350662231445312,10000.0,44918.97057008743,46496.75433373451,44918.97057008743,1569.5413410663605,3.5810797214508057,0.0 -133200,2.7455568,1.5707574,,,,,,,,,,,,,, -133300,2.8649404,1.705166,,,,,,,,,,,,,, -133400,3.07465,1.6591074,,,,,,,,,,,,,, -133500,2.7166421,1.6868421,,,,,,,,,,,,,, -133600,2.6824982,1.5415473,,,,,,,,,,,,,, -133700,2.7747355,1.6829505,,,,,,,,,,,,,, -133800,2.6890042,1.594513,,,,,,,,,,,,,, -133900,2.5803926,1.5836301,,,,,,,,,,,,,, -134000,2.7137327,1.6269635,,,,,,,,,,,,,, -134100,2.581505,1.594369,,,,,,,,,,,,,, -134200,2.7557943,1.652221,,,,,,,,,,,,,, -134300,2.4062114,1.5785686,,,,,,,,,,,,,, -134400,2.5672228,1.6539404,,,,,,,,,,,,,, -134500,2.4641464,1.5594603,,,,,,,,,,,,,, -134600,3.2157702,1.6936668,,,,,,,,,,,,,, -134645,,,0.6475008130073547,1.4163345098495483,0.5905199646949768,1.7114334106445312,50000.0,0.4717000126838684,2.4368271827697754,10000.0,45428.92800879479,47024.43297600746,45428.92800879479,1587.159719467163,3.6312103271484375,0.0 -134700,2.6987636,1.6640062,,,,,,,,,,,,,, -134800,2.7353702,1.6418848,,,,,,,,,,,,,, -134900,2.8231506,1.6312003,,,,,,,,,,,,,, -135000,2.7275941,1.722561,,,,,,,,,,,,,, -135100,2.6728005,1.6210728,,,,,,,,,,,,,, -135200,3.0639932,1.671892,,,,,,,,,,,,,, -135300,2.6150987,1.516813,,,,,,,,,,,,,, -135400,2.6614647,1.566245,,,,,,,,,,,,,, -135500,2.6775675,1.5904955,,,,,,,,,,,,,, -135600,2.6266751,1.6139638,,,,,,,,,,,,,, -135700,2.7419355,1.6567022,,,,,,,,,,,,,, -135800,2.7119877,1.6158347,,,,,,,,,,,,,, -135900,3.0915794,1.7015164,,,,,,,,,,,,,, -136000,2.6875486,1.68052,,,,,,,,,,,,,, -136100,2.9322479,1.5814102,,,,,,,,,,,,,, -136159,,,0.6627072691917419,1.3503409624099731,0.6053799986839294,1.6583715677261353,50000.0,0.4859000146389007,2.411069869995117,10000.0,45938.94582152367,47552.14718770981,45938.94582152367,1604.752497673035,3.681864023208618,0.0 -136200,2.6824608,1.5431756,,,,,,,,,,,,,, -136300,2.6748457,1.6697245,,,,,,,,,,,,,, -136400,2.8577943,1.6540716,,,,,,,,,,,,,, -136500,2.7745733,1.6965314,,,,,,,,,,,,,, -136600,2.894098,1.6876111,,,,,,,,,,,,,, -136700,2.7675185,1.6524254,,,,,,,,,,,,,, -136800,3.0931127,1.7254603,,,,,,,,,,,,,, -136900,2.986115,1.609183,,,,,,,,,,,,,, -137000,2.8696613,1.5080154,,,,,,,,,,,,,, -137100,2.8000388,1.5974451,,,,,,,,,,,,,, -137200,2.73808,1.4726785,,,,,,,,,,,,,, -137300,3.0591185,1.6019889,,,,,,,,,,,,,, -137400,3.0831563,1.6024169,,,,,,,,,,,,,, -137500,2.834345,1.5642648,,,,,,,,,,,,,, -137600,2.901763,1.6749618,,,,,,,,,,,,,, -137674,,,0.6785116195678711,1.28064227104187,0.6271799802780151,1.5513432025909424,50000.0,0.5019000172615051,2.331106662750244,10000.0,46449.14306783676,48080.309012174606,46449.14306783676,1622.6106128692627,3.7330148220062256,0.0 -137700,3.1059284,1.6504596,,,,,,,,,,,,,, -137800,2.9453292,1.4898727,,,,,,,,,,,,,, -137900,2.9866645,1.6247308,,,,,,,,,,,,,, -138000,3.007614,1.6673942,,,,,,,,,,,,,, -138100,2.7215261,1.5556911,,,,,,,,,,,,,, -138200,2.890239,1.727089,,,,,,,,,,,,,, -138300,2.8282146,1.5212168,,,,,,,,,,,,,, -138400,2.9234552,1.6332843,,,,,,,,,,,,,, -138500,3.0280704,1.5805147,,,,,,,,,,,,,, -138600,2.6956277,1.4840274,,,,,,,,,,,,,, -138700,3.041647,1.6020839,,,,,,,,,,,,,, -138800,3.1246464,1.6415244,,,,,,,,,,,,,, -138900,2.7314332,1.5662998,,,,,,,,,,,,,, -139000,3.1170292,1.7194262,,,,,,,,,,,,,, -139100,3.16369,1.5814623,,,,,,,,,,,,,, -139188,,,0.667410671710968,1.3150213956832886,0.5842999815940857,1.7889710664749146,50000.0,0.467600017786026,2.5262668132781982,10000.0,46959.10416579247,48607.7670943737,46959.10416579247,1640.0028715133667,3.786928653717041,0.0 -139200,2.9259377,1.5500308,,,,,,,,,,,,,, -139300,2.939247,1.5294396,,,,,,,,,,,,,, -139400,3.057502,1.6519034,,,,,,,,,,,,,, -139500,2.8958387,1.479403,,,,,,,,,,,,,, -139600,2.8530118,1.5459936,,,,,,,,,,,,,, -139700,2.9623754,1.534429,,,,,,,,,,,,,, -139800,3.222123,1.6200317,,,,,,,,,,,,,, -139900,2.6321077,1.4520323,,,,,,,,,,,,,, -140000,2.8166668,1.4707069,,,,,,,,,,,,,, -140100,3.1168373,1.5458038,,,,,,,,,,,,,, -140200,3.0954928,1.5632058,,,,,,,,,,,,,, -140300,3.164726,1.586653,,,,,,,,,,,,,, -140400,2.8836665,1.6191564,,,,,,,,,,,,,, -140500,2.8684804,1.5625119,,,,,,,,,,,,,, -140600,3.2489269,1.6376531,,,,,,,,,,,,,, -140700,2.8646724,1.4452904,,,,,,,,,,,,,, -140701,,,0.7164978981018066,1.0897971391677856,0.6477800011634827,1.4555057287216189,50000.0,0.5172000527381897,2.2077748775482178,10000.0,47469.02995443344,49135.43949460983,47469.02995443344,1657.6408331394196,3.842206716537476,0.0 -140800,3.3081405,1.6070734,,,,,,,,,,,,,, -140900,2.8631527,1.4850608,,,,,,,,,,,,,, -141000,3.0592406,1.5592333,,,,,,,,,,,,,, -141100,3.041829,1.6560614,,,,,,,,,,,,,, -141200,2.977128,1.4649713,,,,,,,,,,,,,, -141300,3.1707127,1.5035157,,,,,,,,,,,,,, -141400,3.0929976,1.4609715,,,,,,,,,,,,,, -141500,2.8639612,1.5836694,,,,,,,,,,,,,, -141600,2.9262242,1.6060157,,,,,,,,,,,,,, -141700,3.1280363,1.5100558,,,,,,,,,,,,,, -141800,3.0257938,1.5872507,,,,,,,,,,,,,, -141900,2.9533167,1.6028153,,,,,,,,,,,,,, -142000,3.4604473,1.5316274,,,,,,,,,,,,,, -142100,3.1235814,1.4962977,,,,,,,,,,,,,, -142200,2.9973,1.5332098,,,,,,,,,,,,,, -142214,,,0.7066326141357422,1.145608901977539,0.6356599926948547,1.4984140396118164,50000.0,0.5021000504493713,2.242482900619507,10000.0,47979.00236058235,49662.91038489342,47979.00236058235,1675.0142815113068,3.91502857208252,0.0 -142300,3.160401,1.4854856,,,,,,,,,,,,,, -142400,3.1393878,1.5110002,,,,,,,,,,,,,, -142500,2.9702444,1.4469256,,,,,,,,,,,,,, -142600,3.1457455,1.5926684,,,,,,,,,,,,,, -142700,3.1732366,1.5729212,,,,,,,,,,,,,, -142800,3.2566466,1.5820348,,,,,,,,,,,,,, -142900,3.0804563,1.6108496,,,,,,,,,,,,,, -143000,3.3159158,1.5065004,,,,,,,,,,,,,, -143100,2.9443264,1.5074135,,,,,,,,,,,,,, -143200,3.0925624,1.5053141,,,,,,,,,,,,,, -143300,3.107665,1.5379198,,,,,,,,,,,,,, -143400,3.23115,1.5993452,,,,,,,,,,,,,, -143500,3.0067873,1.4605136,,,,,,,,,,,,,, -143600,3.4104106,1.5681536,,,,,,,,,,,,,, -143700,3.7671185,1.5747099,,,,,,,,,,,,,, -143729,,,0.6968669891357422,1.1911671161651611,0.6352399587631226,1.5196411609649658,50000.0,0.508400022983551,2.2351131439208984,10000.0,48489.20419001579,50191.19554066658,48489.20419001579,1692.9932827949524,3.9672343730926514,0.0 -143800,3.0649624,1.5024749,,,,,,,,,,,,,, -143900,3.2426481,1.4439645,,,,,,,,,,,,,, -144000,3.4541237,1.565439,,,,,,,,,,,,,, -144100,3.2700667,1.5653093,,,,,,,,,,,,,, -144200,3.5169144,1.4191909,,,,,,,,,,,,,, -144300,3.0903962,1.4816914,,,,,,,,,,,,,, -144400,3.3509636,1.6141899,,,,,,,,,,,,,, -144500,3.2593935,1.488965,,,,,,,,,,,,,, -144600,3.3501182,1.5486656,,,,,,,,,,,,,, -144700,3.3148408,1.5498164,,,,,,,,,,,,,, -144800,3.3271632,1.511207,,,,,,,,,,,,,, -144900,3.2057047,1.5068845,,,,,,,,,,,,,, -145000,3.4611197,1.5231025,,,,,,,,,,,,,, -145100,3.4123082,1.5295861,,,,,,,,,,,,,, -145200,3.3219883,1.4876168,,,,,,,,,,,,,, -145244,,,0.6825175285339355,1.2460888624191284,0.616379976272583,1.6117098331451416,50000.0,0.5085000395774841,2.3389248847961426,10000.0,48999.42097496986,50718.9482088089,48999.42097496986,1710.424084186554,4.019224882125855,0.0 -145300,3.493224,1.6579578,,,,,,,,,,,,,, -145400,3.097268,1.4608719,,,,,,,,,,,,,, -145500,2.90711,1.3712764,,,,,,,,,,,,,, -145600,3.2163055,1.5562003,,,,,,,,,,,,,, -145700,3.2440064,1.4535544,,,,,,,,,,,,,, -145800,3.6246715,1.4855413,,,,,,,,,,,,,, -145900,3.4043374,1.495423,,,,,,,,,,,,,, -146000,3.551348,1.578887,,,,,,,,,,,,,, -146100,3.0790935,1.4117733,,,,,,,,,,,,,, -146200,3.5445843,1.5202888,,,,,,,,,,,,,, -146300,3.352525,1.5523306,,,,,,,,,,,,,, -146400,3.4023144,1.5191177,,,,,,,,,,,,,, -146500,3.3709376,1.4445571,,,,,,,,,,,,,, -146600,3.4691172,1.5411743,,,,,,,,,,,,,, -146700,3.1979933,1.4398086,,,,,,,,,,,,,, -146758,,,0.7297114133834839,1.0454022884368896,0.6575199961662292,1.3876160383224487,50000.0,0.536300003528595,2.1201376914978027,10000.0,49509.32262468338,51246.62455034256,49509.32262468338,1728.0896308422089,4.0746119022369385,0.0 -146800,3.3916116,1.4045445,,,,,,,,,,,,,, -146900,3.382327,1.514017,,,,,,,,,,,,,, -147000,3.4222023,1.4878157,,,,,,,,,,,,,, -147100,3.2735624,1.438897,,,,,,,,,,,,,, -147200,3.342679,1.4309428,,,,,,,,,,,,,, -147300,3.4270988,1.6212333,,,,,,,,,,,,,, -147400,3.5409455,1.4463645,,,,,,,,,,,,,, -147500,3.1981866,1.4584373,,,,,,,,,,,,,, -147600,3.3212602,1.3557147,,,,,,,,,,,,,, -147700,3.720948,1.39696,,,,,,,,,,,,,, -147800,3.2020192,1.3827429,,,,,,,,,,,,,, -147900,3.2440503,1.3442053,,,,,,,,,,,,,, -148000,3.427322,1.4387243,,,,,,,,,,,,,, -148100,3.7140157,1.4157227,,,,,,,,,,,,,, -148200,3.6349936,1.5889082,,,,,,,,,,,,,, -148272,,,0.7495615482330322,0.9471167325973512,0.6625799536705017,1.3929927349090576,50000.0,0.5314000248908997,2.1541640758514404,10000.0,50019.25556206703,51773.96777963638,50019.25556206703,1745.3887765407562,4.133601903915405,0.0 -148300,3.5071185,1.4536659,,,,,,,,,,,,,, -148400,3.2273242,1.3964937,,,,,,,,,,,,,, -148500,3.3704174,1.4729491,,,,,,,,,,,,,, -148600,3.609921,1.575829,,,,,,,,,,,,,, -148700,3.4936776,1.4766916,,,,,,,,,,,,,, -148800,3.5096648,1.4875002,,,,,,,,,,,,,, -148900,3.8152084,1.4322565,,,,,,,,,,,,,, -149000,3.4655645,1.4059,,,,,,,,,,,,,, -149100,3.6218076,1.4426173,,,,,,,,,,,,,, -149200,3.554204,1.5439627,,,,,,,,,,,,,, -149300,3.4618804,1.491471,,,,,,,,,,,,,, -149400,3.705743,1.5187972,,,,,,,,,,,,,, -149500,3.581638,1.4317226,,,,,,,,,,,,,, -149600,3.7548318,1.4692214,,,,,,,,,,,,,, -149700,3.4253612,1.4327284,,,,,,,,,,,,,, -149787,,,0.7493821382522583,0.9529452323913574,0.6685400009155273,1.3474432229995728,50000.0,0.534500002861023,2.082439422607422,10000.0,50529.4074280262,52301.51612615585,50529.4074280262,1762.67853140831,4.18830156326294,0.0 -149800,3.5828137,1.5474102,,,,,,,,,,,,,, -149900,3.3628993,1.4230219,,,,,,,,,,,,,, -150000,3.715127,1.3961445,,,,,,,,,,,,,, -150100,3.5095563,1.4391271,,,,,,,,,,,,,, -150200,3.5462298,1.3491408,,,,,,,,,,,,,, -150300,3.7366576,1.3861271,,,,,,,,,,,,,, -150400,3.740476,1.4664742,,,,,,,,,,,,,, -150500,3.4328687,1.4167917,,,,,,,,,,,,,, -150600,3.398177,1.3569268,,,,,,,,,,,,,, -150700,3.4500535,1.3780252,,,,,,,,,,,,,, -150800,3.4887836,1.5312814,,,,,,,,,,,,,, -150900,3.4474497,1.4406345,,,,,,,,,,,,,, -151000,3.76574,1.3887382,,,,,,,,,,,,,, -151100,3.4286592,1.350294,,,,,,,,,,,,,, -151200,3.8677394,1.3840395,,,,,,,,,,,,,, -151300,3.7650063,1.4099001,,,,,,,,,,,,,, -151301,,,0.7281169891357422,1.0500671863555908,0.6520599722862244,1.4359551668167114,50000.0,0.5250000357627869,2.173976182937622,10000.0,51039.44907045365,52829.5035905838,51039.44907045365,1780.5153470039368,4.245532751083374,0.0 -151400,3.4058115,1.2369688,,,,,,,,,,,,,, -151500,3.4332287,1.32455,,,,,,,,,,,,,, -151600,3.452442,1.3165247,,,,,,,,,,,,,, -151700,3.6271,1.478864,,,,,,,,,,,,,, -151800,3.632082,1.3484473,,,,,,,,,,,,,, -151900,3.5041366,1.3615048,,,,,,,,,,,,,, -152000,3.5197465,1.3039389,,,,,,,,,,,,,, -152100,3.7813208,1.313896,,,,,,,,,,,,,, -152200,3.5240874,1.3748558,,,,,,,,,,,,,, -152300,3.7774112,1.3419436,,,,,,,,,,,,,, -152400,3.7379663,1.3474877,,,,,,,,,,,,,, -152500,3.7851045,1.4187441,,,,,,,,,,,,,, -152600,3.7951622,1.4351137,,,,,,,,,,,,,, -152700,3.9272995,1.5284743,,,,,,,,,,,,,, -152800,3.6464427,1.4495394,,,,,,,,,,,,,, -152815,,,0.7424864172935486,0.9894612431526184,0.6647999882698059,1.3785579204559326,50000.0,0.5455000400543213,2.086134195327759,10000.0,51549.42072844505,53357.01241946221,51549.42072844505,1797.9433093070984,4.3016557693481445,0.0 -152900,4.0149426,1.4387509,,,,,,,,,,,,,, -153000,3.9553616,1.332129,,,,,,,,,,,,,, -153100,3.5211904,1.3239119,,,,,,,,,,,,,, -153200,3.7446003,1.3656831,,,,,,,,,,,,,, -153300,3.972835,1.2993625,,,,,,,,,,,,,, -153400,3.6973107,1.3742708,,,,,,,,,,,,,, -153500,3.6461613,1.4520569,,,,,,,,,,,,,, -153600,3.7834113,1.3269231,,,,,,,,,,,,,, -153700,3.717253,1.2766742,,,,,,,,,,,,,, -153800,3.906441,1.2791858,,,,,,,,,,,,,, -153900,3.6571338,1.3163816,,,,,,,,,,,,,, -154000,3.7852933,1.3684412,,,,,,,,,,,,,, -154100,3.5801022,1.2697552,,,,,,,,,,,,,, -154200,3.9302378,1.4300666,,,,,,,,,,,,,, -154300,3.7927792,1.3239388,,,,,,,,,,,,,, -154329,,,0.7708266973495483,0.8738231658935547,0.6909799575805664,1.2597182989120483,50000.0,0.5574000477790833,2.01480484008789,10000.0,52059.3456428051,53884.340841293335,52059.3456428051,1815.2350759506223,4.359313249588013,0.0 -154400,4.090256,1.3356462,,,,,,,,,,,,,, -154500,3.8983068,1.3934239,,,,,,,,,,,,,, -154600,3.8581996,1.3605998,,,,,,,,,,,,,, -154700,3.6642818,1.2762867,,,,,,,,,,,,,, -154800,4.033273,1.3856156,,,,,,,,,,,,,, -154900,3.6906056,1.3353655,,,,,,,,,,,,,, -155000,4.08186,1.4354955,,,,,,,,,,,,,, -155100,4.0857577,1.4086938,,,,,,,,,,,,,, -155200,3.9467356,1.4201927,,,,,,,,,,,,,, -155300,4.039769,1.3164024,,,,,,,,,,,,,, -155400,3.7779713,1.3699502,,,,,,,,,,,,,, -155500,3.8514206,1.3120105,,,,,,,,,,,,,, -155600,3.7910883,1.3556671,,,,,,,,,,,,,, -155700,4.12981,1.4224952,,,,,,,,,,,,,, -155800,3.8222895,1.3411243,,,,,,,,,,,,,, -155844,,,0.7562380433082581,0.9150132536888124,0.6816799640655518,1.297837734222412,50000.0,0.5478000044822693,2.047738790512085,10000.0,52569.56672739983,54411.96320319176,52569.56672739983,1832.526507616043,4.413227081298828,0.0 -155900,3.9784951,1.3988712,,,,,,,,,,,,,, -156000,3.8427832,1.2979352,,,,,,,,,,,,,, -156100,3.8792655,1.271313,,,,,,,,,,,,,, -156200,4.062735,1.469133,,,,,,,,,,,,,, -156300,3.9973593,1.2620813,,,,,,,,,,,,,, -156400,3.94427,1.3080132,,,,,,,,,,,,,, -156500,3.8627865,1.3311216,,,,,,,,,,,,,, -156600,3.761317,1.1735022,,,,,,,,,,,,,, -156700,3.9786646,1.2447169,,,,,,,,,,,,,, -156800,3.857272,1.3178968,,,,,,,,,,,,,, -156900,3.683483,1.2205014,,,,,,,,,,,,,, -157000,3.8850255,1.2171509,,,,,,,,,,,,,, -157100,3.9094872,1.2317134,,,,,,,,,,,,,, -157200,4.260562,1.3551116,,,,,,,,,,,,,, -157300,4.07127,1.2643672,,,,,,,,,,,,,, -157358,,,0.7893216013908386,0.7879802584648132,0.6877599954605103,1.278788447380066,50000.0,0.5590000152587891,2.0179154872894287,10000.0,53079.48914647102,54939.77777934074,53079.48914647102,1850.3109276294708,4.46701979637146,0.0 -157400,4.1508517,1.3489755,,,,,,,,,,,,,, -157500,4.023927,1.2351108,,,,,,,,,,,,,, -157600,4.2918983,1.3461932,,,,,,,,,,,,,, -157700,4.337308,1.3063562,,,,,,,,,,,,,, -157800,4.2808666,1.4002662,,,,,,,,,,,,,, -157900,4.2504635,1.253674,,,,,,,,,,,,,, -158000,4.1814766,1.2711312,,,,,,,,,,,,,, -158100,4.3859544,1.3222115,,,,,,,,,,,,,, -158200,4.7939262,1.2255397,,,,,,,,,,,,,, -158300,3.8828132,1.2292503,,,,,,,,,,,,,, -158400,4.3548064,1.208513,,,,,,,,,,,,,, -158500,4.4170165,1.3928943,,,,,,,,,,,,,, -158600,4.375092,1.2684109,,,,,,,,,,,,,, -158700,4.136725,1.3081361,,,,,,,,,,,,,, -158800,3.9189112,1.2344515,,,,,,,,,,,,,, -158873,,,0.7989277839660645,0.7345190644264221,0.7048599720001221,1.2018253803253174,50000.0,0.5772000551223755,1.9140307903289795,10000.0,53589.66965150833,55467.48884010315,53589.66965150833,1867.7333896160128,4.52330470085144,0.0 -158900,4.327172,1.2765707,,,,,,,,,,,,,, -159000,4.0566125,1.2066535,,,,,,,,,,,,,, -159100,4.431079,1.2802665,,,,,,,,,,,,,, -159200,4.28545,1.2693402,,,,,,,,,,,,,, -159300,4.501535,1.2501734,,,,,,,,,,,,,, -159400,4.311121,1.2565185,,,,,,,,,,,,,, -159500,4.123731,1.2196774,,,,,,,,,,,,,, -159600,4.618519,1.3685616,,,,,,,,,,,,,, -159700,4.2116246,1.2595577,,,,,,,,,,,,,, -159800,4.167567,1.2140611,,,,,,,,,,,,,, -159900,4.20665,1.2727265,,,,,,,,,,,,,, -160000,3.9531925,1.1290122,,,,,,,,,,,,,, -160100,3.9585042,1.1607136,,,,,,,,,,,,,, -160200,4.893995,1.2025158,,,,,,,,,,,,,, -160300,4.2783203,1.2272671,,,,,,,,,,,,,, -160387,,,0.7996053695678711,0.7400349974632263,0.7064200043678284,1.1957767009735107,50000.0,0.5842000246047974,1.8834655284881592,10000.0,54099.629590034485,55994.96142053604,54099.629590034485,1885.144741296768,4.571201324462891,0.0 -160400,4.0023236,1.2474773,,,,,,,,,,,,,, -160500,4.6080165,1.2372228,,,,,,,,,,,,,, -160600,4.39798,1.1724164,,,,,,,,,,,,,, -160700,4.1733227,1.2866617,,,,,,,,,,,,,, -160800,4.2345743,1.3141227,,,,,,,,,,,,,, -160900,4.36018,1.2463156,,,,,,,,,,,,,, -161000,4.3248205,1.2615502,,,,,,,,,,,,,, -161100,4.6644135,1.3300755,,,,,,,,,,,,,, -161200,4.3615294,1.1936969,,,,,,,,,,,,,, -161300,4.1775637,1.2150285,,,,,,,,,,,,,, -161400,4.515951,1.2160419,,,,,,,,,,,,,, -161500,4.673601,1.2420354,,,,,,,,,,,,,, -161600,4.4867687,1.2699105,,,,,,,,,,,,,, -161700,4.137982,1.1300559,,,,,,,,,,,,,, -161800,4.670782,1.1934316,,,,,,,,,,,,,, -161900,4.4929214,1.1968868,,,,,,,,,,,,,, -161901,,,0.8086734414100647,0.7023515105247498,0.7143399715423584,1.1650846004486084,50000.0,0.5866000056266785,1.8886688947677608,10000.0,54609.69239163399,56522.81906700134,54609.69239163399,1902.832640647888,4.625036954879761,0.0 -162000,4.845638,1.2435789,,,,,,,,,,,,,, -162100,4.4230843,1.2593929,,,,,,,,,,,,,, -162200,3.8923113,1.1205916,,,,,,,,,,,,,, -162300,4.6002812,1.3474859,,,,,,,,,,,,,, -162400,4.4688745,1.2985375,,,,,,,,,,,,,, -162500,4.43406,1.1997832,,,,,,,,,,,,,, -162600,4.3708224,1.1823512,,,,,,,,,,,,,, -162700,4.4822803,1.2646185,,,,,,,,,,,,,, -162800,4.390908,1.1125787,,,,,,,,,,,,,, -162900,4.9296784,1.2906874,,,,,,,,,,,,,, -163000,4.774992,1.2133867,,,,,,,,,,,,,, -163100,4.558675,1.1489238,,,,,,,,,,,,,, -163200,4.5071683,1.2027211,,,,,,,,,,,,,, -163300,4.3786855,1.1323771,,,,,,,,,,,,,, -163400,4.715576,1.1723063,,,,,,,,,,,,,, -163415,,,0.8216477632522583,0.6528841257095337,0.7217999696731567,1.1238961219787598,50000.0,0.5954000353813171,1.8337507247924805,10000.0,55119.59383225441,57050.48101377487,55119.59383225441,1920.478876829148,4.685438394546509,0.0 -163500,4.3816023,1.2458223,,,,,,,,,,,,,, -163600,4.6366253,1.1923995,,,,,,,,,,,,,, -163700,4.3771753,1.1688813,,,,,,,,,,,,,, -163800,4.5328097,1.1442009,,,,,,,,,,,,,, -163900,4.342332,1.2096634,,,,,,,,,,,,,, -164000,4.279121,1.1144831,,,,,,,,,,,,,, -164100,4.302929,1.1340716,,,,,,,,,,,,,, -164200,4.391298,1.1322217,,,,,,,,,,,,,, -164300,4.610348,1.1825596,,,,,,,,,,,,,, -164400,4.5261607,1.2081234,,,,,,,,,,,,,, -164500,4.43783,1.2197866,,,,,,,,,,,,,, -164600,4.2692776,1.1527929,,,,,,,,,,,,,, -164700,4.8541074,1.13285,,,,,,,,,,,,,, -164800,4.8977437,1.2031834,,,,,,,,,,,,,, -164900,4.6173706,1.1213334,,,,,,,,,,,,,, -164929,,,0.8226243257522583,0.6398840546607971,0.7210400104522705,1.1339006423950195,50000.0,0.5958000421524048,1.849045395851136,10000.0,55629.537470817566,57577.97980308533,55629.537470817566,1937.9243867397308,4.741180896759033,0.0 -165000,4.898855,1.1204422,,,,,,,,,,,,,, -165100,4.9632792,1.179707,,,,,,,,,,,,,, -165200,4.6077194,1.1191925,,,,,,,,,,,,,, -165300,4.747089,1.1559634,,,,,,,,,,,,,, -165400,4.646215,1.1580637,,,,,,,,,,,,,, -165500,5.15024,1.1231985,,,,,,,,,,,,,, -165600,4.776306,1.1583176,,,,,,,,,,,,,, -165700,4.925024,1.1348364,,,,,,,,,,,,,, -165800,4.7401605,1.170722,,,,,,,,,,,,,, -165900,4.0792074,1.0195608,,,,,,,,,,,,,, -166000,4.7867937,1.0563565,,,,,,,,,,,,,, -166100,4.3971133,1.0277615,,,,,,,,,,,,,, -166200,4.7757773,1.1281403,,,,,,,,,,,,,, -166300,4.8181777,1.1091542,,,,,,,,,,,,,, -166400,4.5923886,1.1434455,,,,,,,,,,,,,, -166443,,,0.8295599222183228,0.6190734505653381,0.7127199769020081,1.169293999671936,50000.0,0.5837000012397766,1.897260069847107,10000.0,56139.56280255318,58105.27678442001,56139.56280255318,1955.0879509449005,4.795497179031372,0.0 -166500,4.33907,1.0684978,,,,,,,,,,,,,, -166600,4.5874305,1.1131911,,,,,,,,,,,,,, -166700,4.464481,1.0935954,,,,,,,,,,,,,, -166800,4.801839,1.0753049,,,,,,,,,,,,,, -166900,4.438693,1.025687,,,,,,,,,,,,,, -167000,4.711978,1.1315857,,,,,,,,,,,,,, -167100,4.7988195,1.1485724,,,,,,,,,,,,,, -167200,4.7685375,1.1192722,,,,,,,,,,,,,, -167300,4.789621,1.1180922,,,,,,,,,,,,,, -167400,4.920841,1.0420623,,,,,,,,,,,,,, -167500,5.134485,1.1115686,,,,,,,,,,,,,, -167600,4.8369246,1.0752184,,,,,,,,,,,,,, -167700,5.2593374,1.142516,,,,,,,,,,,,,, -167800,4.674764,1.1418939,,,,,,,,,,,,,, -167900,5.139483,1.0601139,,,,,,,,,,,,,, -167957,,,0.8482740521430969,0.5581567883491516,0.7278199791908264,1.1075727939605713,50000.0,0.6022000312805176,1.799255132675171,10000.0,56649.58001804352,58632.78707766533,56649.58001804352,1972.470759391785,4.853919506072998,0.0 -168000,5.152622,1.0900981,,,,,,,,,,,,,, -168100,4.832556,1.0726805,,,,,,,,,,,,,, -168200,4.9363112,1.0413029,,,,,,,,,,,,,, -168300,4.8655124,1.0254523,,,,,,,,,,,,,, -168400,4.884053,1.1647248,,,,,,,,,,,,,, -168500,4.6955934,1.0701658,,,,,,,,,,,,,, -168600,4.844963,1.0736809,,,,,,,,,,,,,, -168700,5.072441,1.1556515,,,,,,,,,,,,,, -168800,5.314481,1.156363,,,,,,,,,,,,,, -168900,4.9184647,1.0820811,,,,,,,,,,,,,, -169000,4.606121,1.0899146,,,,,,,,,,,,,, -169100,4.444344,0.955394,,,,,,,,,,,,,, -169200,5.3325953,1.0491197,,,,,,,,,,,,,, -169300,4.934654,1.1344882,,,,,,,,,,,,,, -169400,4.874605,1.1975598,,,,,,,,,,,,,, -169471,,,0.8517617583274841,0.5366014242172241,0.7337799668312073,1.0894720554351809,50000.0,0.6065000295639038,1.802993655204773,10000.0,57159.64442586899,59160.81916928291,57159.64442586899,1990.327763795853,4.91126561164856,0.0 -169500,4.8757825,1.0881666,,,,,,,,,,,,,, -169600,5.221752,1.1329215,,,,,,,,,,,,,, -169700,5.256373,1.0971118,,,,,,,,,,,,,, -169800,5.017776,0.98831284,,,,,,,,,,,,,, -169900,5.4334984,1.1653174,,,,,,,,,,,,,, -170000,4.898004,1.1054114,,,,,,,,,,,,,, -170100,4.9951315,1.039165,,,,,,,,,,,,,, -170200,4.9502077,1.0432152,,,,,,,,,,,,,, -170300,5.1731396,1.1040508,,,,,,,,,,,,,, -170400,5.181788,1.0897573,,,,,,,,,,,,,, -170500,5.1036005,1.0714278,,,,,,,,,,,,,, -170600,5.1302013,1.1120933,,,,,,,,,,,,,, -170700,4.8921247,1.0230186,,,,,,,,,,,,,, -170800,5.0705433,1.0562823,,,,,,,,,,,,,, -170900,5.1012115,1.0003226,,,,,,,,,,,,,, -170985,,,0.8520607352256775,0.5264204144477844,0.7373600006103516,1.0713787078857422,50000.0,0.6118000149726868,1.7798755168914795,10000.0,57669.60288262367,59688.65509557724,57669.60288262367,2008.0955998897552,4.968322038650513,0.0 -171000,4.946556,1.0024378,,,,,,,,,,,,,, -171100,4.933498,1.0749143,,,,,,,,,,,,,, -171200,5.150162,1.007196,,,,,,,,,,,,,, -171300,5.0532575,1.1232554,,,,,,,,,,,,,, -171400,4.8830795,0.928297,,,,,,,,,,,,,, -171500,5.3743577,1.0335323,,,,,,,,,,,,,, -171600,5.3146257,0.9686057,,,,,,,,,,,,,, -171700,5.533942,1.1165675,,,,,,,,,,,,,, -171800,4.8292365,0.99952894,,,,,,,,,,,,,, -171900,5.092463,1.0427905,,,,,,,,,,,,,, -172000,4.9695168,0.97207284,,,,,,,,,,,,,, -172100,5.125384,1.028945,,,,,,,,,,,,,, -172200,5.383829,1.0886873,,,,,,,,,,,,,, -172300,5.170103,0.9963415,,,,,,,,,,,,,, -172400,4.8789053,0.9519058,,,,,,,,,,,,,, -172499,,,0.8572624325752258,0.5081093311309814,0.7387199997901917,1.0654412508010864,50000.0,0.6099000573158264,1.7793102264404297,10000.0,58179.54089784622,60216.452474832535,58179.54089784622,2025.8408544063568,5.029789209365845,0.0 -172500,5.29864,1.0444715,,,,,,,,,,,,,, -172600,5.387392,0.98678184,,,,,,,,,,,,,, -172700,4.904493,1.0002466,,,,,,,,,,,,,, -172800,5.417091,0.9908376,,,,,,,,,,,,,, -172900,4.7334576,1.0327163,,,,,,,,,,,,,, -173000,4.7139754,0.9648284,,,,,,,,,,,,,, -173100,5.5683513,1.0777264,,,,,,,,,,,,,, -173200,4.9858727,1.0305321,,,,,,,,,,,,,, -173300,5.49471,1.1047944,,,,,,,,,,,,,, -173400,5.0605307,0.92254204,,,,,,,,,,,,,, -173500,5.2163405,1.00899,,,,,,,,,,,,,, -173600,5.213239,1.0425106,,,,,,,,,,,,,, -173700,5.4720845,0.9926208,,,,,,,,,,,,,, -173800,5.3264604,1.0852054,,,,,,,,,,,,,, -173900,4.8887873,0.99235296,,,,,,,,,,,,,, -174000,5.4600077,1.0146196,,,,,,,,,,,,,, -174013,,,0.8667888641357422,0.4688793122768402,0.7406599521636963,1.0524687767028809,50000.0,0.6139000058174133,1.7508876323699951,10000.0,58689.59416794777,60744.05333399773,58689.59416794777,2043.2813086509705,5.085204124450684,0.0 -174100,5.560106,1.0282851,,,,,,,,,,,,,, -174200,5.4040995,1.0276374,,,,,,,,,,,,,, -174300,5.3594294,1.0234404,,,,,,,,,,,,,, -174400,5.2279778,1.0346773,,,,,,,,,,,,,, -174500,5.466712,1.0543201,,,,,,,,,,,,,, -174600,5.075388,1.0663689,,,,,,,,,,,,,, -174700,5.4429717,1.058388,,,,,,,,,,,,,, -174800,5.454043,0.96745396,,,,,,,,,,,,,, -174900,5.6168056,1.0260899,,,,,,,,,,,,,, -175000,5.5650716,1.0468422,,,,,,,,,,,,,, -175100,5.189997,1.0241473,,,,,,,,,,,,,, -175200,5.271299,1.0336345,,,,,,,,,,,,,, -175300,5.0493493,1.0937027,,,,,,,,,,,,,, -175400,5.210881,0.9753548,,,,,,,,,,,,,, -175500,5.703443,1.0644842,,,,,,,,,,,,,, -175527,,,0.8767737150192261,0.4423072636127472,0.7445200085639954,1.0392271280288696,50000.0,0.6184000372886658,1.7471065521240234,10000.0,59199.493911504745,61271.79503440857,59199.493911504745,2061.011162519455,5.145601987838745,0.0 -175600,4.825403,0.98892325,,,,,,,,,,,,,, -175700,5.3612185,1.0514346,,,,,,,,,,,,,, -175800,5.168697,1.0010754,,,,,,,,,,,,,, -175900,5.61401,1.1260207,,,,,,,,,,,,,, -176000,5.384941,1.0254,,,,,,,,,,,,,, -176100,5.0281973,0.9700856,,,,,,,,,,,,,, -176200,5.2431803,1.0290416,,,,,,,,,,,,,, -176300,5.74776,1.0193052,,,,,,,,,,,,,, -176400,5.06914,1.0402359,,,,,,,,,,,,,, -176500,5.3893485,1.0415219,,,,,,,,,,,,,, -176600,5.2266145,0.96166575,,,,,,,,,,,,,, -176700,4.980715,1.014527,,,,,,,,,,,,,, -176800,5.378235,1.0258096,,,,,,,,,,,,,, -176900,5.4152412,1.0259693,,,,,,,,,,,,,, -177000,5.315622,0.9959934,,,,,,,,,,,,,, -177041,,,0.8763552308082581,0.4342447221279144,0.7455599904060364,1.0426863431930542,50000.0,0.616100013256073,1.7452729940414429,10000.0,59709.4871468544,61799.87656879425,59709.4871468544,2078.988668680191,5.2041075229644775,0.0 -177100,5.338844,1.0449015,,,,,,,,,,,,,, -177200,5.595066,0.9206166,,,,,,,,,,,,,, -177300,5.270874,0.9848468,,,,,,,,,,,,,, -177400,5.1467414,0.92121524,,,,,,,,,,,,,, -177500,5.8348255,1.08182,,,,,,,,,,,,,, -177600,5.1614747,0.9247537,,,,,,,,,,,,,, -177700,5.310036,1.0305419,,,,,,,,,,,,,, -177800,5.5639596,0.97521496,,,,,,,,,,,,,, -177900,5.210644,1.0720837,,,,,,,,,,,,,, -178000,5.8826346,0.99442273,,,,,,,,,,,,,, -178100,5.406684,1.0156133,,,,,,,,,,,,,, -178200,5.422265,0.9950417,,,,,,,,,,,,,, -178300,5.3035493,0.9800297,,,,,,,,,,,,,, -178400,4.991267,0.8635881,,,,,,,,,,,,,, -178500,5.354985,0.9476961,,,,,,,,,,,,,, -178556,,,0.8790258169174194,0.4302779138088226,0.7473999857902527,1.0301430225372314,50000.0,0.6222000122070312,1.731208324432373,10000.0,60219.56933808327,62327.75828695297,60219.56933808327,2096.6750016212463,5.264262676239014,0.0 -178600,5.5884624,1.063661,,,,,,,,,,,,,, -178700,5.4473686,0.9627908,,,,,,,,,,,,,, -178800,5.60883,0.9053508,,,,,,,,,,,,,, -178900,5.494836,0.97351557,,,,,,,,,,,,,, -179000,5.1977587,0.92283,,,,,,,,,,,,,, -179100,5.3949075,1.0269213,,,,,,,,,,,,,, -179200,5.447402,0.9288369,,,,,,,,,,,,,, -179300,5.4495335,0.9383844,,,,,,,,,,,,,, -179400,5.335552,0.91307217,,,,,,,,,,,,,, -179500,4.961536,0.9404491,,,,,,,,,,,,,, -179600,5.302427,0.9484787,,,,,,,,,,,,,, -179700,5.069697,0.9400732,,,,,,,,,,,,,, -179800,5.0779014,0.93025625,,,,,,,,,,,,,, -179900,5.4120646,0.9046379,,,,,,,,,,,,,, -180000,5.3879495,0.88771635,,,,,,,,,,,,,, -180071,,,0.878348171710968,0.4257912039756775,0.7497199773788452,1.0273454189300537,50000.0,0.6249000430107117,1.7289284467697144,10000.0,60729.72563242912,62855.63610816002,60729.72563242912,2114.281795501709,5.3266565799713135,0.0 -180100,5.333601,0.97994524,,,,,,,,,,,,,, -180200,5.3965087,0.88727754,,,,,,,,,,,,,, -180300,5.6458144,0.9983356,,,,,,,,,,,,,, -180400,5.724287,1.0445402,,,,,,,,,,,,,, -180500,5.1114244,0.93830824,,,,,,,,,,,,,, -180600,5.375145,0.9066442,,,,,,,,,,,,,, -180700,5.5321994,1.0002086,,,,,,,,,,,,,, -180800,5.270418,0.99436104,,,,,,,,,,,,,, -180900,4.7756557,0.87307656,,,,,,,,,,,,,, -181000,5.1266165,0.9339792,,,,,,,,,,,,,, -181100,5.6750517,0.9796642,,,,,,,,,,,,,, -181200,5.347969,0.9662565,,,,,,,,,,,,,, -181300,5.450307,0.8740126,,,,,,,,,,,,,, -181400,5.3587003,0.9759199,,,,,,,,,,,,,, -181500,5.2967005,0.8602308,,,,,,,,,,,,,, -181585,,,0.8835299611091614,0.412623792886734,0.750220000743866,1.02522611618042,50000.0,0.6249000430107117,1.7295113801956177,10000.0,61239.78214740753,63383.342509269714,61239.78214740753,2131.816429138184,5.389420032501221,0.0 -181600,5.8498077,0.99883765,,,,,,,,,,,,,, -181700,5.460898,0.9652775,,,,,,,,,,,,,, -181800,6.0243177,1.0562912,,,,,,,,,,,,,, -181900,5.39174,1.0118762,,,,,,,,,,,,,, -182000,6.057145,0.96858394,,,,,,,,,,,,,, -182100,5.442118,0.9819824,,,,,,,,,,,,,, -182200,5.756352,1.018435,,,,,,,,,,,,,, -182300,5.11284,0.95014656,,,,,,,,,,,,,, -182400,5.7849264,1.009482,,,,,,,,,,,,,, -182500,5.045401,0.9403887,,,,,,,,,,,,,, -182600,5.0932717,1.0014548,,,,,,,,,,,,,, -182700,5.5004177,0.999756,,,,,,,,,,,,,, -182800,5.776628,0.9417396,,,,,,,,,,,,,, -182900,4.86836,0.95298934,,,,,,,,,,,,,, -183000,5.454142,1.0314943,,,,,,,,,,,,,, -183099,,,0.8856026530265808,0.4069898426532745,0.7508999705314636,1.0227185487747192,50000.0,0.6252000331878662,1.7207680940628052,10000.0,61749.71807217598,63910.96601963043,61749.71807217598,2149.390313625336,5.450626850128174,0.0 -183100,5.2819524,0.95893973,,,,,,,,,,,,,, -183200,5.1571426,0.89826214,,,,,,,,,,,,,, -183300,5.7149916,1.0061417,,,,,,,,,,,,,, -183400,5.6369267,0.97940993,,,,,,,,,,,,,, -183500,5.0687633,0.9276409,,,,,,,,,,,,,, -183600,5.70812,1.0797492,,,,,,,,,,,,,, -183700,5.187946,0.8609237,,,,,,,,,,,,,, -183800,5.5135927,0.9456265,,,,,,,,,,,,,, -183900,5.310844,0.9668104,,,,,,,,,,,,,, -184000,5.2014537,1.0052385,,,,,,,,,,,,,, -184100,5.6780167,1.0068674,,,,,,,,,,,,,, -184200,5.3239965,0.9841039,,,,,,,,,,,,,, -184300,5.383182,0.94821775,,,,,,,,,,,,,, -184400,5.5261474,0.98411846,,,,,,,,,,,,,, -184500,5.479126,1.0090163,,,,,,,,,,,,,, -184600,4.9770827,0.8799009,,,,,,,,,,,,,, -184613,,,0.8846459984779358,0.4113345444202423,0.7509999871253967,1.0213669538497925,50000.0,0.625,1.724463701248169,10000.0,62259.6433134079,64438.64605140686,62259.6433134079,2167.0347170829773,5.507994651794434,0.0 -184700,5.671865,0.95465815,,,,,,,,,,,,,, -184800,5.389598,0.8899372,,,,,,,,,,,,,, -184900,5.336171,0.9419026,,,,,,,,,,,,,, -185000,5.2364535,0.9687122,,,,,,,,,,,,,, -185100,5.164866,0.94267416,,,,,,,,,,,,,, -185200,5.6868634,0.9609963,,,,,,,,,,,,,, -185300,5.2692,0.93026423,,,,,,,,,,,,,, -185400,5.1386633,0.92248076,,,,,,,,,,,,,, -185500,5.3001046,0.92679226,,,,,,,,,,,,,, -185600,5.5587378,1.0242113,,,,,,,,,,,,,, -185700,5.1536627,0.91346717,,,,,,,,,,,,,, -185800,5.320826,0.94959176,,,,,,,,,,,,,, -185900,5.48901,1.0021389,,,,,,,,,,,,,, -186000,5.1676636,0.9677553,,,,,,,,,,,,,, -186100,5.6619916,0.98623586,,,,,,,,,,,,,, -186127,,,0.8859614133834839,0.4078052341938019,0.7506799697875977,1.0217421054840088,50000.0,0.6243000030517578,1.7228504419326782,10000.0,62769.76277685165,64966.37164711952,62769.76277685165,2184.528416633606,5.566797494888306,0.0 -186200,5.0970182,0.84342766,,,,,,,,,,,,,, -186300,5.4333286,0.97810674,,,,,,,,,,,,,, -186400,5.5240254,0.9645752,,,,,,,,,,,,,, -186500,5.3122435,0.9238648,,,,,,,,,,,,,, -186600,5.3962636,0.95740867,,,,,,,,,,,,,, -186666,,,0.8869180083274841,0.4038886129856109,0.7511199712753296,1.0213629007339478,50000.0,0.6237000226974487,1.722780704498291,10000.0,62951.208920001984,65165.74819970131,62951.208920001984,2202.380264520645,5.626325130462647,0.0 -186666,,,,,,,,,,,62951.208920001984,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/eval_measurements.csv deleted file mode 100644 index c532c5110..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.53779411315918,0.0,33.085120677948,1,0,33.085120677948,0.0013000001199543,6.9117279052734375,10000,50.62336611747742,0.0005978954141028,6.911577224731445,0.0006799999973736,6.912051200866699,50000 -35.47820210456848,0.021540880203247,543.0786073207855,1499,0,543.0786073207855,0.1145000085234642,4.87145471572876,10000,578.6318206787109,0.1660355478525161,4.276123523712158,0.1514399945735931,4.407378196716309,50000 -53.395957469940186,0.0544428825378418,1053.254381418228,2998,0,1053.254381418228,0.2322000116109848,3.897366523742676,10000,1106.8110687732697,0.3397241532802582,3.0705549716949463,0.3152599930763244,3.223424196243286,50000 -71.222341299057,0.0895986557006836,1563.3771555423737,4498,0,1563.3771555423737,0.3113000094890594,3.398090362548828,10000,1634.8487372398376,0.4709024131298065,2.321305751800537,0.4215799868106842,2.6061043739318848,50000 -88.89337491989136,0.1208815574645996,2073.5986137390137,6000,0,2073.5986137390137,0.3790000081062317,2.932342052459717,10000,2162.8251535892487,0.5570989847183228,1.8586288690567017,0.4995799958705902,2.171689748764038,50000 -106.47591543197632,0.1529977321624755,2583.842520713806,7503,0,2583.842520713806,0.4199000298976898,2.727141857147217,10000,2690.7365441322327,0.5871731638908386,1.7035075426101685,0.5399199724197388,1.982594013214112,50000 -124.2677731513977,0.1858928203582763,3093.868242740631,9006,0,3093.868242740631,0.4235000312328338,2.7236905097961426,10000,3218.64040517807,0.6002869606018066,1.646314024925232,0.550819993019104,1.9125394821166992,50000 -142.03222179412842,0.2162094116210937,3603.961983203888,10510,0,3603.961983203888,0.4556000232696533,2.51633882522583,10000,3746.58158493042,0.629324734210968,1.4985932111740112,0.5819999575614929,1.7651735544204712,50000 -159.66547989845276,0.2459611892700195,4113.984125375748,12014,0,4113.984125375748,0.462300032377243,2.5003979206085205,10000,4274.320744037628,0.629304826259613,1.5169682502746582,0.5853399634361267,1.7463421821594238,50000 -177.51647090911865,0.2762629985809326,4623.965626001358,13517,0,4623.965626001358,0.461400032043457,2.5226640701293945,10000,4802.236160516739,0.6426379084587097,1.466149091720581,0.589959979057312,1.7425481081008911,50000 -195.36043906211853,0.3092031478881836,5134.007486343384,15021,0,5134.007486343384,0.4645000100135803,2.4317922592163086,10000,5330.208964586258,0.6707589030265808,1.304468035697937,0.5949400067329407,1.69130539894104,50000 -213.22563099861145,0.3367078304290771,5644.199923276901,16526,0,5644.199923276901,0.4793000221252441,2.4198639392852783,10000,5858.348840475082,0.6666334271430969,1.3141627311706543,0.6050199866294861,1.657135009765625,50000 -231.12113857269287,0.3682653903961181,6154.205993413925,18030,0,6154.205993413925,0.4785000085830688,2.3898165225982666,10000,6386.336737394333,0.6596978306770325,1.363717555999756,0.6011999845504761,1.6728837490081787,50000 -249.5828001499176,0.398468017578125,6664.177336454392,19534,0,6664.177336454392,0.4834000170230865,2.365743398666382,10000,6914.852823257446,0.6613320708274841,1.3531110286712646,0.6044999957084656,1.6473594903945925,50000 -267.3385720252991,0.4360337257385254,7174.233859062195,21038,0,7174.233859062195,0.4713000357151031,2.450623750686645,10000,7442.757306337357,0.6508888602256775,1.4064011573791504,0.5990999937057495,1.6903232336044312,50000 -285.31112122535706,0.467952013015747,7684.207465171814,22542,0,7684.207465171814,0.4860000312328338,2.391890287399292,10000,7970.789026260376,0.6646006107330322,1.3387823104858398,0.6141799688339233,1.617658257484436,50000 -303.01013803482056,0.5014698505401611,8194.420420408249,24047,0,8194.420420408249,0.4808000326156616,2.405961513519287,10000,8498.789740085602,0.6842713356018066,1.2391905784606934,0.6031399965286255,1.66865074634552,50000 -320.98294949531555,0.5339152812957764,8704.557604551315,25552,0,8704.557604551315,0.4945000112056732,2.319498300552368,10000,9026.98484826088,0.6896723508834839,1.2271901369094849,0.6210399866104126,1.5833314657211304,50000 -338.75679183006287,0.567237377166748,9214.546523332596,27056,0,9214.546523332596,0.4948000311851501,2.30348801612854,10000,9554.83399772644,0.683015763759613,1.2480367422103882,0.6202600002288818,1.574621319770813,50000 -356.4754137992859,0.6004829406738281,9724.683393001556,28561,0,9724.683393001556,0.4812000095844269,2.4035277366638184,10000,10082.776176929474,0.6730906963348389,1.2945865392684937,0.6128399968147278,1.6189769506454468,50000 -374.3439898490906,0.6367506980895996,10234.78894495964,30066,0,10234.78894495964,0.4984000325202942,2.274780035018921,10000,10610.840461969376,0.6851283311843872,1.2363115549087524,0.6257599592208862,1.53518807888031,50000 -391.9957549571991,0.6751341819763184,10744.854972839355,31570,0,10744.854972839355,0.5035000443458557,2.246699333190918,10000,11138.650857448578,0.686922013759613,1.2258983850479126,0.6343799829483032,1.5134379863739014,50000 -409.8847246170044,0.7104458808898926,11254.89800453186,33074,0,11254.89800453186,0.4880000352859497,2.316349744796753,10000,11666.672231435776,0.7069116830825806,1.1511811017990112,0.6170399785041809,1.5918534994125366,50000 -427.47814416885376,0.7436776161193848,11765.143156051636,34579,0,11765.143156051636,0.4945000112056732,2.303445816040039,10000,12194.597846269608,0.6890744566917419,1.2211211919784546,0.6181399822235107,1.5859718322753906,50000 -445.4265134334564,0.7787868976593018,12275.155267238615,36083,0,12275.155267238615,0.5162000060081482,2.258314847946167,10000,12722.646920681,0.6932995915412903,1.190823316574097,0.6310399770736694,1.5338072776794434,50000 -463.7156083583832,0.8150899410247803,12785.271156549454,37588,0,12785.271156549454,0.4964000284671783,2.281170606613159,10000,13251.140579938889,0.6861447691917419,1.233417272567749,0.6261999607086182,1.5477724075317385,50000 -481.4382588863373,0.8444697856903076,13295.242151021956,39093,0,13295.242151021956,0.5124000310897827,2.243726968765259,10000,13778.91663146019,0.7010124325752258,1.1792408227920532,0.64028000831604,1.4864907264709473,50000 -499.3242871761322,0.8821501731872559,13805.37430691719,40598,0,13805.37430691719,0.4976000189781189,2.290236473083496,10000,14307.025272130966,0.6853276491165161,1.2317752838134766,0.6255800127983093,1.5425355434417725,50000 -517.1469714641571,0.9237699508666992,14315.556084632874,42104,0,14315.556084632874,0.5109000205993652,2.22822904586792,10000,14835.124782562256,0.7300103306770325,1.0265074968338013,0.6344799995422363,1.5053465366363523,50000 -534.8262553215027,0.9605207443237304,14825.481731653214,43609,0,14825.481731653214,0.5078000426292419,2.239952087402344,10000,15362.819847106934,0.7144451141357422,1.1129549741744995,0.638700008392334,1.48816180229187,50000 -552.4992291927338,0.9993839263916016,15335.629780054092,45096,0,15335.629780054092,0.5080000162124634,2.263049840927124,10000,15890.73356628418,0.7016900181770325,1.154717206954956,0.6298800110816956,1.5299546718597412,50000 -570.4190890789032,2.1719322204589844,15844.633338212969,46598,0,15844.633338212969,0.5160000324249268,2.164910793304444,10000,16418.882928848267,0.7108976244926453,1.125377893447876,0.6459800004959106,1.4500062465667725,50000 -588.1098670959473,2.20971941947937,16354.717643976212,48103,0,16354.717643976212,0.511400043964386,2.2049496173858643,10000,16946.749007940292,0.6979033946990967,1.1757428646087646,0.6360799670219421,1.504656195640564,50000 -605.7434694766998,2.248587131500244,16864.65782546997,49608,0,16864.65782546997,0.5209000110626221,2.1950442790985107,10000,17474.416180849075,0.7112563848495483,1.128156304359436,0.6507200002670288,1.4424986839294434,50000 -623.6246781349182,2.2868897914886475,17374.636246919632,51114,0,17374.636246919632,0.5172000527381897,2.2189431190490723,10000,18002.36827158928,0.7567960619926453,0.9291717410087584,0.64301997423172,1.4702551364898682,50000 -642.0388751029968,2.3303985595703125,17884.815304279327,52620,0,17884.815304279327,0.4982000291347503,2.276724100112915,10000,18531.05840587616,0.7090441584587097,1.130107045173645,0.6304599642753601,1.5274958610534668,50000 -659.6608362197876,2.370500087738037,18394.996651649475,54126,0,18394.996651649475,0.513700008392334,2.202472686767578,10000,19058.955248355865,0.7151626348495483,1.1085048913955688,0.640720009803772,1.4724503755569458,50000 -677.4676706790924,2.408476591110229,18905.095246076584,55632,0,18905.095246076584,0.5157999992370605,2.200547933578491,10000,19586.952585458755,0.7169164419174194,1.0998347997665403,0.6484999656677246,1.4559392929077148,50000 -695.7189378738403,2.4459259510040283,19415.004539966583,57138,0,19415.004539966583,0.5157999992370605,2.2013065814971924,10000,20115.204418182373,0.71000075340271,1.1082898378372192,0.6432799696922302,1.463404655456543,50000 -713.5251975059509,2.4837570190429688,19925.25050020218,58644,0,19925.25050020218,0.5159000158309937,2.179934501647949,10000,20643.348040819168,0.7024473547935486,1.1483705043792725,0.6444000005722046,1.4588139057159424,50000 -731.4343252182007,2.5222463607788086,20435.16510415077,60150,0,20435.16510415077,0.5308000445365906,2.1435821056365967,10000,21171.26289153099,0.7673987150192261,0.8855634927749634,0.6535999774932861,1.4235299825668335,50000 -749.1801223754883,2.560462474822998,20945.074239969254,61655,0,20945.074239969254,0.5003000497817993,2.302015781402588,10000,21699.009701251984,0.7128706574440002,1.1097934246063232,0.6341599822044373,1.517980456352234,50000 -766.840833902359,2.6059324741363525,21454.97639608383,63160,0,21454.97639608383,0.5281000137329102,2.1373679637908936,10000,22226.67244052887,0.7302096486091614,1.038726806640625,0.6582399606704712,1.4035637378692627,50000 -784.704512834549,2.6474947929382324,21965.061421632767,64666,0,21965.061421632767,0.5207000374794006,2.1825969219207764,10000,22754.716769218445,0.7167769074440002,1.0958020687103271,0.6484799981117249,1.438090443611145,50000 -802.5071756839752,2.695244073867798,22475.0136077404,66172,0,22475.0136077404,0.5215000510215759,2.1981096267700195,10000,23282.5726454258,0.7169762253761292,1.0895507335662842,0.6572399735450745,1.4136441946029663,50000 -820.9895300865173,2.7346694469451904,22985.089703559875,67678,0,22985.089703559875,0.5365000367164612,2.1053407192230225,10000,23811.22298622132,0.727937638759613,1.04031240940094,0.6620799899101257,1.3774571418762207,50000 -838.5256464481354,2.7713065147399902,23495.21492218972,69184,0,23495.21492218972,0.5349000096321106,2.0840163230896,10000,24338.97371149063,0.7716637253761292,0.8626237511634827,0.661579966545105,1.3844743967056274,50000 -856.152738571167,2.8141791820526123,24005.384941101074,70691,0,24005.384941101074,0.5488000512123108,2.067430257797241,10000,24866.868101119995,0.7581313848495483,0.9096571207046508,0.6702799797058105,1.3560611009597778,50000 -873.9265124797821,2.855938196182251,24515.50654459,72197,0,24515.50654459,0.534500002861023,2.125728368759156,10000,25394.858213186264,0.7457548975944519,0.974716067314148,0.6640200018882751,1.391406536102295,50000 -891.5010304450989,2.8975298404693604,25025.546664714813,73703,0,25025.546664714813,0.5339000225067139,2.1294972896575928,10000,25922.5676074028,0.7359893321990967,1.0156553983688354,0.6577799916267395,1.3957440853118896,50000 -909.7378346920012,2.939307928085327,25535.52632427216,75209,0,25535.52632427216,0.532200038433075,2.108184576034546,10000,26450.878532886505,0.7361288070678711,1.0147454738616943,0.6642199754714966,1.3678507804870603,50000 -927.2890992164612,2.9757421016693115,26045.5726313591,76715,0,26045.5726313591,0.5281000137329102,2.151803970336914,10000,26978.564618349075,0.725027859210968,1.0601333379745483,0.6577199697494507,1.3992598056793213,50000 -945.0738813877106,3.020139217376709,26555.53935289383,78221,0,26555.53935289383,0.523300051689148,2.1556520462036133,10000,27506.41523051262,0.7487444281578064,0.9544544219970704,0.6531999707221985,1.4107705354690552,50000 -962.5798609256744,3.0614492893218994,27065.5693500042,79727,0,27065.5693500042,0.5210000276565552,2.229564905166626,10000,28034.04595661164,0.7374043464660645,0.9866397380828856,0.6531199812889099,1.437100887298584,50000 -980.4816539287568,3.104302167892456,27575.630492925644,81233,0,27575.630492925644,0.5451000332832336,2.054157257080078,10000,28562.10580611229,0.7581512928009033,0.9172185659408568,0.6740999817848206,1.3255287408828735,50000 -998.2937302589417,3.1469409465789795,28085.67872595787,82739,0,28085.67872595787,0.547700047492981,2.067952871322632,10000,29090.06341791153,0.7515744566917419,0.9364935159683228,0.6704399585723877,1.3452812433242798,50000 -1016.1643199920654,3.186960935592652,28595.77198863029,84245,0,28595.77198863029,0.544700026512146,2.050497055053711,10000,29618.119492292404,0.749043345451355,0.9529641270637512,0.6751599907875061,1.3240749835968018,50000 -1033.8766658306122,3.23142409324646,29105.80688929557,85751,0,29105.80688929557,0.5537000298500061,1.9823843240737915,10000,30145.964017629623,0.7525310516357422,0.9264184236526488,0.6801599860191345,1.2922544479370115,50000 -1051.9092426300049,3.2774195671081543,29615.99306058884,87257,0,29615.99306058884,0.5484000444412231,2.032165765762329,10000,30674.282242536545,0.7596858739852905,0.9148987531661988,0.6706599593162537,1.3406606912612915,50000 -1069.4649865627289,3.327728748321533,30125.917417526245,88763,0,30125.917417526245,0.541100025177002,2.063876152038574,10000,31201.865349769592,0.7681162357330322,0.8765470385551453,0.6717999577522278,1.342298150062561,50000 -1087.206655502319,3.3750007152557373,30635.965329885483,90269,0,30635.965329885483,0.5516000390052795,2.005854368209839,10000,31729.75548362732,0.7680763602256775,0.8671398162841797,0.6808800101280212,1.2917416095733645,50000 -1105.1372406482697,3.414245128631592,31146.04019165039,91776,0,31146.04019165039,0.5533000230789185,2.02983021736145,10000,32257.85404109955,0.7579320669174194,0.906115710735321,0.6772199869155884,1.3106396198272705,50000 -1123.2598929405212,3.460043430328369,31656.18217587471,93282,0,31656.18217587471,0.558899998664856,1.9982883930206297,10000,32786.21779823303,0.7669204473495483,0.8753052949905396,0.6854999661445618,1.2829898595809937,50000 -1140.9816081523895,3.503222703933716,32166.163598299023,94788,0,32166.163598299023,0.5600000023841858,2.0052812099456787,10000,33314.017624139786,0.7631935477256775,0.8865559697151184,0.6861400008201599,1.2788174152374268,50000 -1159.3427624702454,3.55094575881958,32676.12209534645,96292,0,32676.12209534645,0.5563000440597534,2.0307154655456543,10000,33842.43831539154,0.7673788070678711,0.8581953048706055,0.6811400055885315,1.2850172519683838,50000 -1177.4779727458954,3.59602165222168,33186.13390159607,97798,0,33186.13390159607,0.5625,1.977084040641785,10000,34370.68486762047,0.7949816584587097,0.7670559883117676,0.6891199946403503,1.256996750831604,50000 -1195.3270378112793,3.64474105834961,33696.05191516876,99303,0,33696.05191516876,0.5552000403404236,1.99160385131836,10000,34898.55397820473,0.7777224183082581,0.8399938344955444,0.6829800009727478,1.2816671133041382,50000 -1213.2627115249634,3.693167209625244,34206.17853355408,100809,0,34206.17853355408,0.55840003490448,2.0301966667175293,10000,35426.718936920166,0.7700493931770325,0.8601894378662109,0.6826599836349487,1.296310305595398,50000 -1231.1673793792725,3.744328260421753,34716.23765563965,102315,0,34716.23765563965,0.5603000521659851,1.9734432697296145,10000,35954.78728866577,0.7760483026504517,0.8299661874771118,0.6898399591445923,1.2503280639648438,50000 -1249.0791852474213,3.794133901596069,35226.270773649216,103821,0,35226.270773649216,0.5681000351905823,1.9737340211868288,10000,36482.83643579483,0.76761794090271,0.8790847659111023,0.685699999332428,1.2791662216186523,50000 -1267.4798917770386,3.841939449310303,35736.3866379261,105327,0,35736.3866379261,0.5646000504493713,1.997875094413757,10000,37011.45383000374,0.7800143361091614,0.810962438583374,0.6937800049781799,1.2509777545928955,50000 -1285.1757638454435,3.8856961727142334,36246.66552352905,106833,0,36246.66552352905,0.5701000094413757,1.9629000425338743,10000,37539.52548789978,0.8098094463348389,0.6974734663963318,0.6970799565315247,1.228915095329285,50000 -1303.0352365970612,3.935134649276733,36756.7243309021,108339,0,36756.7243309021,0.572700023651123,1.9371442794799805,10000,38067.54865336418,0.8022361397743225,0.7281384468078613,0.6987000107765198,1.2137079238891602,50000 -1320.5831859111786,3.982501983642578,37266.74565792084,109845,0,37266.74565792084,0.5742000341415405,1.9284789562225344,10000,38595.21840763092,0.795320451259613,0.7437325119972229,0.6987400054931641,1.229027509689331,50000 -1338.08602809906,4.032949924468994,37776.84398150444,111350,0,37776.84398150444,0.5878000259399414,1.8825327157974243,10000,39122.9255297184,0.8005022406578064,0.7239665985107422,0.7046999931335449,1.1938434839248655,50000 -1356.5964069366455,4.082559108734131,38287.077071905136,112857,0,38287.077071905136,0.5761000514030457,1.942015290260315,10000,39651.77258491516,0.7922711968421936,0.7695667147636414,0.7007399797439575,1.2252343893051147,50000 -1374.487015724182,4.132639169692993,38796.97853899002,114363,0,38796.97853899002,0.5761000514030457,1.938789129257202,10000,40179.66720294952,0.797293484210968,0.7358038425445557,0.7057999968528748,1.2084553241729736,50000 -1392.427879571915,4.181373596191406,39306.96019792557,115869,0,39306.96019792557,0.5750000476837158,1.9542590379714968,10000,40707.69045972824,0.8302773833274841,0.6240458488464355,0.703719973564148,1.2063405513763428,50000 -1410.2165460586548,4.232245206832886,39817.018862485886,117375,0,39817.018862485886,0.5832000374794006,1.9158483743667605,10000,41235.641859054565,0.8184191584587097,0.661994993686676,0.702019989490509,1.206328511238098,50000 -1427.716940164566,4.2863757610321045,40327.03925204277,118881,0,40327.03925204277,0.5807000398635864,1.93263578414917,10000,41763.26990413666,0.8126594424247742,0.6735084652900696,0.7055599689483643,1.2065515518188477,50000 -1446.2023582458496,4.338989973068237,40837.051486730576,120387,0,40837.051486730576,0.5871000289916992,1.8709394931793213,10000,42291.873514175415,0.8160474896430969,0.6632678508758545,0.7124399542808533,1.1698178052902222,50000 -1463.6995224952698,4.385130882263184,41347.19075655937,121894,0,41347.19075655937,0.5918000340461731,1.8801075220108032,10000,42819.61029410362,0.8153699040412903,0.6547273993492126,0.7078199982643127,1.1856532096862793,50000 -1481.5602872371674,4.439347743988037,41857.21672534943,123400,0,41857.21672534943,0.5821000337600708,1.942033529281616,10000,43347.604243040085,0.8108059167861938,0.6760615706443787,0.704539954662323,1.212332248687744,50000 -1499.112335205078,4.489536046981812,42367.14495229721,124906,0,42367.14495229721,0.591200053691864,1.8790547847747805,10000,43875.18793177605,0.8531568646430969,0.5148501992225647,0.7152799963951111,1.1633607149124146,50000 -1516.701696395874,4.539835691452026,42877.105335474014,126412,0,42877.105335474014,0.5944000482559204,1.8772876262664795,10000,44402.84032511711,0.8422951102256775,0.5681675672531128,0.7170599699020386,1.14886736869812,50000 -1535.1918017864227,4.589509010314941,43387.303060531616,127918,0,43387.303060531616,0.5874000191688538,1.876875281333924,10000,44931.63086462021,0.8375916481018066,0.573846697807312,0.7154799699783325,1.1557798385620115,50000 -1552.941791057587,4.631393909454346,43897.30212020874,129424,0,43897.30212020874,0.6003000140190125,1.853930711746216,10000,45459.47427010536,0.8389867544174194,0.5688159465789795,0.7219199538230896,1.1433396339416504,50000 -1570.5166273117063,4.68116021156311,44407.29587292671,130930,0,44407.29587292671,0.5950000286102295,1.887262225151062,10000,45987.14557003975,0.8404615521430969,0.56621253490448,0.7183199524879456,1.1442562341690063,50000 -1588.0953686237335,4.735187530517578,44917.199598789215,132436,0,44917.199598789215,0.5893000364303589,1.8838881254196167,10000,46514.73631954193,0.8365752100944519,0.582433819770813,0.7178199887275696,1.1573113203048706,50000 -1606.6192646026611,4.7905237674713135,45427.28427886963,133942,0,45427.28427886963,0.596500039100647,1.8780877590179443,10000,47043.45353341103,0.8755381107330322,0.4374611675739288,0.7226799726486206,1.1405651569366455,50000 -1624.4392714500427,4.849410057067871,45937.33114624024,135448,0,45937.33114624024,0.5966000556945801,1.880648732185364,10000,47571.43300700188,0.8627630472183228,0.4806753098964691,0.7238199710845947,1.131712555885315,50000 -1642.457855463028,4.900918006896973,46447.43666052818,136954,0,46447.43666052818,0.597000002861023,1.857327580451965,10000,48099.66242289543,0.8583984375,0.4998326003551483,0.7249599695205688,1.131963849067688,50000 -1660.0084781646729,4.961538314819336,46957.42192101479,138459,0,46957.42192101479,0.5968000292778015,1.881799340248108,10000,48627.31255126,0.859773576259613,0.4849222302436828,0.7234799861907959,1.1349554061889648,50000 -1677.843991279602,5.014599561691284,47467.60343265533,139966,0,47467.60343265533,0.6035000085830688,1.8506639003753664,10000,49155.43559336662,0.8603116869926453,0.486088365316391,0.7274799942970276,1.1127536296844482,50000 -1695.720237493515,5.068366050720215,47977.59747934341,141472,0,47977.59747934341,0.601900041103363,1.8711603879928589,10000,49683.41357302666,0.8621651530265808,0.4798415601253509,0.7297399640083313,1.1351284980773926,50000 -1713.8870961666107,5.12295126914978,48487.5690472126,142978,0,48487.5690472126,0.6035000085830688,1.8868118524551392,10000,50211.66066980362,0.8984972834587097,0.3527859449386596,0.7320599555969238,1.1209580898284912,50000 -1731.585001707077,5.165935516357422,48997.5791516304,144484,0,48997.5791516304,0.6080000400543213,1.855978012084961,10000,50739.464485645294,0.8881736397743225,0.3926560580730438,0.7305600047111511,1.1178728342056274,50000 -1749.398912668228,5.220576286315918,49507.801383018494,145989,0,49507.801383018494,0.6055999994277954,1.8426967859268188,10000,51267.6076464653,0.8858019709587097,0.3927642703056335,0.7334799766540527,1.106014609336853,50000 -1767.210108757019,5.27645468711853,50017.77842974663,147495,0,50017.77842974663,0.6131000518798828,1.8579843044281008,10000,51795.50549149513,0.8859414458274841,0.3913815915584564,0.7366200089454651,1.1027542352676392,50000 -1785.0233714580536,5.3356263637542725,50527.86081528664,149001,0,50527.86081528664,0.6097000241279602,1.835785865783692,10000,52323.51455950737,0.8880141973495483,0.386345237493515,0.735319972038269,1.107372164726257,50000 -1802.831404209137,5.39626669883728,51037.921318769455,150507,0,51037.921318769455,0.6142000555992126,1.8414394855499268,10000,52851.49748301506,0.8951889276504517,0.3608154356479645,0.7378399968147278,1.096303939819336,50000 -1820.3468084335327,5.445066690444946,51547.98945236206,152013,0,51547.98945236206,0.6211000084877014,1.8491344451904297,10000,53379.18301439285,0.9246850609779358,0.2611102759838104,0.7400599718093872,1.1027374267578125,50000 -1838.1192009449005,5.501033782958984,52058.14172434807,153519,0,52058.14172434807,0.6139000058174133,1.8494164943695068,10000,53907.217074632645,0.913683831691742,0.2980218827724457,0.7379800081253052,1.1061500310897827,50000 -1855.824599981308,5.557091951370239,52568.189259290695,155025,0,52568.189259290695,0.617400050163269,1.840550422668457,10000,54435.08118414879,0.9156568646430968,0.2907183766365051,0.7414599657058716,1.100637435913086,50000 -1873.38943362236,5.6150617599487305,53078.3854739666,156531,0,53078.3854739666,0.6130000352859497,1.8647652864456177,10000,54962.95487737656,0.9124282598495485,0.294901579618454,0.7405799627304077,1.1055474281311035,50000 -1891.319774627685,5.684006452560425,53588.38181400299,158037,0,53588.38181400299,0.6182000041007996,1.8408807516098025,10000,55491.00404858589,0.91796875,0.2793106138706207,0.7449199557304382,1.0848394632339478,50000 -1909.000571727753,5.731633424758911,54098.50211858749,159543,0,54098.50211858749,0.6185000538825989,1.831809043884277,10000,56018.90651440621,0.9270168542861938,0.254503846168518,0.7486400008201599,1.0716850757598877,50000 -1926.8864195346832,5.78792929649353,54608.54653549194,161048,0,54608.54653549194,0.6205000281333923,1.843653440475464,10000,56546.94776725769,0.9447743892669678,0.2013452053070068,0.7463200092315674,1.0841114521026611,50000 -1944.548994064331,5.8435704708099365,55118.47389388085,162553,0,55118.47389388085,0.6206000447273254,1.835035800933838,10000,57074.64824795723,0.942163586616516,0.2034819424152374,0.7483599781990051,1.0738941431045532,50000 -1962.2046110630035,5.903212070465088,55628.56163620949,164059,0,55628.56163620949,0.6212000250816345,1.8518248796463013,10000,57602.503945589066,0.9382772445678712,0.2111508697271347,0.7485199570655823,1.0775885581970217,50000 -1979.8018288612368,5.963002681732178,56138.76058793068,165565,0,56138.76058793068,0.6224000453948975,1.850699543952942,10000,58130.41319346428,0.9431201815605164,0.2024757117033004,0.7490999698638916,1.07507061958313,50000 -1997.844202041626,6.011024236679077,56648.74323391914,167070,0,56648.74323391914,0.6238000392913818,1.8532116413116453,10000,58658.53840446472,0.9422233700752258,0.2018176466226577,0.7506600022315979,1.0700610876083374,50000 -2015.589801311493,6.071972846984863,57158.81772947312,168576,0,57158.81772947312,0.625700056552887,1.847376108169556,10000,59186.4742205143,0.9480428695678712,0.1823372691869735,0.7523199915885925,1.0641475915908811,50000 -2033.127522945404,6.13216757774353,57668.98382782936,170082,0,57668.98382782936,0.6253000497817993,1.8547792434692385,10000,59714.29183888435,0.9587252736091614,0.1586360782384872,0.7521199584007263,1.072721004486084,50000 -2050.924390077591,6.191601753234863,58179.04970908165,171588,0,58179.04970908165,0.6230000257492065,1.851597547531128,10000,60242.26770663261,0.9563336968421936,0.1618997901678085,0.751039981842041,1.0672721862792969,50000 -2070.0848445892334,6.254953384399414,58689.24033522606,173094,0,58689.24033522606,0.6283000111579895,1.836564540863037,10000,60771.73478055,0.9559350609779358,0.161485806107521,0.7530199885368347,1.0638294219970703,50000 -2087.834163427353,6.306419849395752,59199.163927316666,174599,0,59199.163927316666,0.6271000504493713,1.835718870162964,10000,61299.51197075844,0.9558952450752258,0.1617143154144287,0.7538999915122986,1.0610108375549316,50000 -2105.543501853943,6.366549730300903,59709.09991669655,176104,0,59709.09991669655,0.6297000050544739,1.8355635404586792,10000,61827.27109313011,0.9573700428009032,0.1562838703393936,0.7552399635314941,1.0589247941970823,50000 -2123.2411789894104,6.4263880252838135,60219.06253504753,177609,0,60219.06253504753,0.6300000548362732,1.838411450386048,10000,62355.04538965225,0.9576889276504515,0.1566994935274124,0.7546600103378296,1.05818510055542,50000 -2141.059831142425,6.489727973937988,60729.021720170975,179114,0,60729.021720170975,0.6331000328063965,1.8273247480392456,10000,62882.94174575806,0.9618940949440002,0.1457709223031997,0.7546399831771851,1.054285764694214,50000 -2158.9144999980927,6.554764747619629,61239.021084070206,180620,0,61239.021084070206,0.6304000020027161,1.829865574836731,10000,63410.9137442112,0.9607979655265808,0.1452068239450454,0.7542200088500977,1.0534926652908323,50000 -2176.718202829361,6.606281042098999,61748.97645926476,182125,0,61748.97645926476,0.6331000328063965,1.830521583557129,10000,63938.778237342834,0.9617745280265808,0.1454741656780243,0.7547399997711182,1.053675413131714,50000 -2194.670075416565,6.670706748962402,62258.98048186302,183630,0,62258.98048186302,0.6306000351905823,1.831824541091919,10000,64466.852714538574,0.9620137214660645,0.1412625908851623,0.754859983921051,1.0554081201553345,50000 -2212.4462909698486,6.730313539505005,62769.02976322174,185136,0,62769.02976322174,0.6320000290870667,1.8308967351913452,10000,64994.79127025604,0.9618940949440002,0.14330562949180603,0.754859983921051,1.0536677837371826,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/measurements.csv deleted file mode 100644 index 12725bec1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1985 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.65951836,6.916249,,,,,,,,,,,,,, -1,,,0.0005978954141028,6.911577224731445,0.0006799999973736,6.912051200866699,50000.0,0.0013000001199543,6.9117279052734375,10000.0,33.085120677948,50.62336611747742,33.085120677948,17.53779411315918,0.0,0.0 -100,0.6801074,6.824913,,,,,,,,,,,,,, -200,0.80092597,6.541847,,,,,,,,,,,,,, -300,1.1292177,6.2467513,,,,,,,,,,,,,, -400,4.776688,5.975026,,,,,,,,,,,,,, -500,2.5209382,5.770973,,,,,,,,,,,,,, -600,4.250824,5.532158,,,,,,,,,,,,,, -700,3.5484412,5.4972453,,,,,,,,,,,,,, -800,4.2511196,5.4037676,,,,,,,,,,,,,, -900,3.6300282,5.1845484,,,,,,,,,,,,,, -1000,4.0812607,5.063146,,,,,,,,,,,,,, -1100,5.1402154,5.0276184,,,,,,,,,,,,,, -1200,5.587027,4.8329554,,,,,,,,,,,,,, -1300,3.3602285,4.712384,,,,,,,,,,,,,, -1400,7.227912,4.578289,,,,,,,,,,,,,, -1499,,,0.1660355478525161,4.276123523712158,0.1514399945735931,4.407378196716309,50000.0,0.1145000085234642,4.87145471572876,10000.0,543.0786073207855,578.6318206787109,543.0786073207855,35.47820210456848,0.021540880203247,0.0 -1500,6.9094734,4.6174436,,,,,,,,,,,,,, -1600,4.4961166,4.416278,,,,,,,,,,,,,, -1700,4.171754,4.3264956,,,,,,,,,,,,,, -1800,6.585314,4.249954,,,,,,,,,,,,,, -1900,5.240136,4.2635627,,,,,,,,,,,,,, -2000,5.6986904,4.063851,,,,,,,,,,,,,, -2100,5.5995297,3.87217,,,,,,,,,,,,,, -2200,4.581665,3.9115682,,,,,,,,,,,,,, -2300,5.8460655,3.8097124,,,,,,,,,,,,,, -2400,4.4309034,3.7363355,,,,,,,,,,,,,, -2500,4.046387,3.608087,,,,,,,,,,,,,, -2600,5.0090437,3.5313766,,,,,,,,,,,,,, -2700,5.0263443,3.7378578,,,,,,,,,,,,,, -2800,5.3718185,3.4534519,,,,,,,,,,,,,, -2900,4.609112,3.4156046,,,,,,,,,,,,,, -2998,,,0.3397241532802582,3.0705549716949463,0.3152599930763244,3.223424196243286,50000.0,0.2322000116109848,3.897366523742676,10000.0,1053.254381418228,1106.8110687732697,1053.254381418228,53.395957469940186,0.0544428825378418,0.0 -3000,3.0612223,3.368557,,,,,,,,,,,,,, -3100,3.6212444,3.271155,,,,,,,,,,,,,, -3200,2.5598214,3.258109,,,,,,,,,,,,,, -3300,2.605233,3.4497592,,,,,,,,,,,,,, -3400,2.9336827,3.2978482,,,,,,,,,,,,,, -3500,2.965545,3.2773585,,,,,,,,,,,,,, -3600,5.2903085,3.2145872,,,,,,,,,,,,,, -3700,2.6310997,2.9515755,,,,,,,,,,,,,, -3800,2.1948898,3.0331748,,,,,,,,,,,,,, -3900,3.2714171,3.0115302,,,,,,,,,,,,,, -4000,2.686244,2.9983442,,,,,,,,,,,,,, -4100,3.065295,2.875576,,,,,,,,,,,,,, -4200,3.239785,2.7643309,,,,,,,,,,,,,, -4300,2.9219368,2.8727388,,,,,,,,,,,,,, -4400,3.4247346,2.9280405,,,,,,,,,,,,,, -4498,,,0.4709024131298065,2.321305751800537,0.4215799868106842,2.6061043739318848,50000.0,0.3113000094890594,3.398090362548828,10000.0,1563.3771555423737,1634.8487372398376,1563.3771555423737,71.222341299057,0.0895986557006836,0.0 -4500,2.195379,2.9226048,,,,,,,,,,,,,, -4600,2.610151,2.7819514,,,,,,,,,,,,,, -4700,2.1002684,2.7282138,,,,,,,,,,,,,, -4800,2.6132724,2.882826,,,,,,,,,,,,,, -4900,3.1340663,2.6745749,,,,,,,,,,,,,, -5000,2.486464,2.851747,,,,,,,,,,,,,, -5100,2.4701893,2.6333508,,,,,,,,,,,,,, -5200,2.4577193,2.7222214,,,,,,,,,,,,,, -5300,2.2647586,2.5306845,,,,,,,,,,,,,, -5400,2.3443465,2.6186867,,,,,,,,,,,,,, -5500,2.2991219,2.4165754,,,,,,,,,,,,,, -5600,1.951356,2.4230123,,,,,,,,,,,,,, -5700,2.6775346,2.4069304,,,,,,,,,,,,,, -5800,1.7750196,2.3983228,,,,,,,,,,,,,, -5900,2.6285126,2.5553768,,,,,,,,,,,,,, -6000,,,0.5570989847183228,1.8586288690567017,0.4995799958705902,2.171689748764038,50000.0,0.3790000081062317,2.932342052459717,10000.0,2073.5986137390137,2162.8251535892487,2073.5986137390137,88.89337491989136,0.1208815574645996,0.0 -6000,2.5613983,2.3522058,,,,,,,,,,,,,, -6100,1.9124403,2.429397,,,,,,,,,,,,,, -6200,2.240947,2.4461703,,,,,,,,,,,,,, -6300,1.8580009,2.399385,,,,,,,,,,,,,, -6400,2.427305,2.418312,,,,,,,,,,,,,, -6500,1.7489716,2.3227746,,,,,,,,,,,,,, -6600,1.9032338,2.2699606,,,,,,,,,,,,,, -6700,3.2530067,2.3519793,,,,,,,,,,,,,, -6800,1.7869232,2.3025138,,,,,,,,,,,,,, -6900,2.3496711,2.278365,,,,,,,,,,,,,, -7000,1.5748342,2.4137516,,,,,,,,,,,,,, -7100,1.9433676,2.2872982,,,,,,,,,,,,,, -7200,1.3478229,2.1923704,,,,,,,,,,,,,, -7300,1.8181314,2.2464647,,,,,,,,,,,,,, -7400,1.7239101,2.1191678,,,,,,,,,,,,,, -7500,1.8980765,2.3309662,,,,,,,,,,,,,, -7503,,,0.5871731638908386,1.7035075426101685,0.5399199724197388,1.982594013214112,50000.0,0.4199000298976898,2.727141857147217,10000.0,2583.842520713806,2690.7365441322327,2583.842520713806,106.47591543197632,0.1529977321624755,0.0 -7600,2.1654959,2.1527472,,,,,,,,,,,,,, -7700,2.1889496,2.1814075,,,,,,,,,,,,,, -7800,1.9347785,2.1747904,,,,,,,,,,,,,, -7900,1.9474357,2.2777147,,,,,,,,,,,,,, -8000,1.9943792,2.1952364,,,,,,,,,,,,,, -8100,1.7601043,2.2057333,,,,,,,,,,,,,, -8200,1.8209075,2.1261814,,,,,,,,,,,,,, -8300,1.9400759,2.165527,,,,,,,,,,,,,, -8400,1.9436607,2.2049673,,,,,,,,,,,,,, -8500,2.0566068,2.3074687,,,,,,,,,,,,,, -8600,1.4396471,2.1542857,,,,,,,,,,,,,, -8700,2.0126061,2.1652513,,,,,,,,,,,,,, -8800,1.896215,2.1922479,,,,,,,,,,,,,, -8900,1.9535444,1.9519197,,,,,,,,,,,,,, -9000,1.8396206,2.289687,,,,,,,,,,,,,, -9006,,,0.6002869606018066,1.646314024925232,0.550819993019104,1.9125394821166992,50000.0,0.4235000312328338,2.7236905097961426,10000.0,3093.868242740631,3218.64040517807,3093.868242740631,124.2677731513977,0.1858928203582763,0.0 -9100,1.4652468,2.031365,,,,,,,,,,,,,, -9200,1.9295785,2.2101402,,,,,,,,,,,,,, -9300,1.9060596,2.0953078,,,,,,,,,,,,,, -9400,2.0309098,2.0083497,,,,,,,,,,,,,, -9500,1.619025,1.9918277,,,,,,,,,,,,,, -9600,2.1136134,1.9837949,,,,,,,,,,,,,, -9700,1.4820362,2.0862937,,,,,,,,,,,,,, -9800,2.075265,2.0369837,,,,,,,,,,,,,, -9900,2.1397557,2.0335033,,,,,,,,,,,,,, -10000,1.5160531,2.0524042,,,,,,,,,,,,,, -10100,1.7271024,2.17523,,,,,,,,,,,,,, -10200,1.7425908,2.0310462,,,,,,,,,,,,,, -10300,1.7776552,2.0866792,,,,,,,,,,,,,, -10400,1.4496489,2.0438676,,,,,,,,,,,,,, -10500,1.4961951,2.048204,,,,,,,,,,,,,, -10510,,,0.629324734210968,1.4985932111740112,0.5819999575614929,1.7651735544204712,50000.0,0.4556000232696533,2.51633882522583,10000.0,3603.961983203888,3746.58158493042,3603.961983203888,142.03222179412842,0.2162094116210937,0.0 -10600,1.8262875,1.9633639,,,,,,,,,,,,,, -10700,2.0524802,2.0243487,,,,,,,,,,,,,, -10800,1.7016307,1.9748688,,,,,,,,,,,,,, -10900,1.804152,2.1673694,,,,,,,,,,,,,, -11000,1.7850707,2.1204362,,,,,,,,,,,,,, -11100,1.6689775,1.8890637,,,,,,,,,,,,,, -11200,1.3088424,1.994362,,,,,,,,,,,,,, -11300,1.5127791,1.9101273,,,,,,,,,,,,,, -11400,2.0184174,2.0072248,,,,,,,,,,,,,, -11500,1.9469382,2.078639,,,,,,,,,,,,,, -11600,1.4404993,2.096778,,,,,,,,,,,,,, -11700,1.4216337,1.960299,,,,,,,,,,,,,, -11800,1.8818891,2.0931377,,,,,,,,,,,,,, -11900,1.4477136,2.054998,,,,,,,,,,,,,, -12000,1.5282193,1.9388955,,,,,,,,,,,,,, -12014,,,0.629304826259613,1.5169682502746582,0.5853399634361267,1.7463421821594238,50000.0,0.462300032377243,2.5003979206085205,10000.0,4113.984125375748,4274.320744037628,4113.984125375748,159.66547989845276,0.2459611892700195,0.0 -12100,1.4382838,1.8902469,,,,,,,,,,,,,, -12200,1.6054494,1.896131,,,,,,,,,,,,,, -12300,1.6974492,2.0193546,,,,,,,,,,,,,, -12400,1.4918569,1.9158574,,,,,,,,,,,,,, -12500,1.7549683,1.9473882,,,,,,,,,,,,,, -12600,2.0907161,2.0871282,,,,,,,,,,,,,, -12700,1.6860493,1.9473364,,,,,,,,,,,,,, -12800,1.6332822,1.9642525,,,,,,,,,,,,,, -12900,1.6643994,1.8928678,,,,,,,,,,,,,, -13000,1.4290293,1.9091012,,,,,,,,,,,,,, -13100,1.2983818,1.9078809,,,,,,,,,,,,,, -13200,1.5427691,2.017841,,,,,,,,,,,,,, -13300,1.5605373,2.0507224,,,,,,,,,,,,,, -13400,2.3225205,1.9519832,,,,,,,,,,,,,, -13500,1.641517,1.9767829,,,,,,,,,,,,,, -13517,,,0.6426379084587097,1.466149091720581,0.589959979057312,1.7425481081008911,50000.0,0.461400032043457,2.5226640701293945,10000.0,4623.965626001358,4802.236160516739,4623.965626001358,177.51647090911865,0.2762629985809326,0.0 -13600,1.5570457,1.8922135,,,,,,,,,,,,,, -13700,1.5281484,1.8111831,,,,,,,,,,,,,, -13800,1.2870568,1.9499245,,,,,,,,,,,,,, -13900,1.4719255,1.9381057,,,,,,,,,,,,,, -14000,1.6903636,1.8688005,,,,,,,,,,,,,, -14100,1.4478383,1.9222643,,,,,,,,,,,,,, -14200,1.5045458,1.8571056,,,,,,,,,,,,,, -14300,1.6686702,1.9602396,,,,,,,,,,,,,, -14400,1.6681569,1.9099398,,,,,,,,,,,,,, -14500,1.4047927,1.9002843,,,,,,,,,,,,,, -14600,1.6895922,1.8649497,,,,,,,,,,,,,, -14700,1.5465375,1.8108068,,,,,,,,,,,,,, -14800,1.5290561,1.9324758,,,,,,,,,,,,,, -14900,1.3993473,1.8822954,,,,,,,,,,,,,, -15000,1.5670874,1.8422544,,,,,,,,,,,,,, -15021,,,0.6707589030265808,1.304468035697937,0.5949400067329407,1.69130539894104,50000.0,0.4645000100135803,2.4317922592163086,10000.0,5134.007486343384,5330.208964586258,5134.007486343384,195.36043906211853,0.3092031478881836,0.0 -15100,1.58666,2.0007238,,,,,,,,,,,,,, -15200,1.6561447,1.9527259,,,,,,,,,,,,,, -15300,1.5489208,2.0003736,,,,,,,,,,,,,, -15400,1.8874168,1.9447067,,,,,,,,,,,,,, -15500,1.6542476,1.9564772,,,,,,,,,,,,,, -15600,1.7374036,1.9647624,,,,,,,,,,,,,, -15700,1.5060613,1.7936928,,,,,,,,,,,,,, -15800,1.4657303,1.8457355,,,,,,,,,,,,,, -15900,1.7395409,1.9065213,,,,,,,,,,,,,, -16000,1.4203601,1.940892,,,,,,,,,,,,,, -16100,2.3366423,1.8864363,,,,,,,,,,,,,, -16200,1.6018001,1.8961078,,,,,,,,,,,,,, -16300,1.913952,1.8510005,,,,,,,,,,,,,, -16400,1.4338888,2.050356,,,,,,,,,,,,,, -16500,1.7012851,1.8689356,,,,,,,,,,,,,, -16526,,,0.6666334271430969,1.3141627311706543,0.6050199866294861,1.657135009765625,50000.0,0.4793000221252441,2.4198639392852783,10000.0,5644.199923276901,5858.348840475082,5644.199923276901,213.22563099861145,0.3367078304290771,0.0 -16600,1.510866,1.9499259,,,,,,,,,,,,,, -16700,1.5618409,1.8416284,,,,,,,,,,,,,, -16800,1.2865746,1.7532935,,,,,,,,,,,,,, -16900,1.6084167,1.9813564,,,,,,,,,,,,,, -17000,1.6651453,1.793491,,,,,,,,,,,,,, -17100,1.5569721,1.8088595,,,,,,,,,,,,,, -17200,1.4252516,1.895108,,,,,,,,,,,,,, -17300,1.6473426,1.8029294,,,,,,,,,,,,,, -17400,1.5715815,1.858242,,,,,,,,,,,,,, -17500,1.5956881,1.87083,,,,,,,,,,,,,, -17600,1.5780468,1.9321129,,,,,,,,,,,,,, -17700,1.7271371,1.7936006,,,,,,,,,,,,,, -17800,1.7455671,2.046404,,,,,,,,,,,,,, -17900,1.6306978,1.7718908,,,,,,,,,,,,,, -18000,1.6618335,1.8516438,,,,,,,,,,,,,, -18030,,,0.6596978306770325,1.363717555999756,0.6011999845504761,1.6728837490081787,50000.0,0.4785000085830688,2.3898165225982666,10000.0,6154.205993413925,6386.336737394333,6154.205993413925,231.12113857269287,0.3682653903961181,0.0 -18100,1.6218284,1.8430207,,,,,,,,,,,,,, -18200,2.1147728,1.9249282,,,,,,,,,,,,,, -18300,2.0309865,1.8779802,,,,,,,,,,,,,, -18400,1.5209794,1.8057282,,,,,,,,,,,,,, -18500,1.7782217,1.8705088,,,,,,,,,,,,,, -18600,1.5844799,1.8682226,,,,,,,,,,,,,, -18700,1.9082854,1.8644946,,,,,,,,,,,,,, -18800,1.5038053,1.8458893,,,,,,,,,,,,,, -18900,1.4208075,1.8022275,,,,,,,,,,,,,, -19000,1.4579314,1.7735909,,,,,,,,,,,,,, -19100,1.4099528,1.7956834,,,,,,,,,,,,,, -19200,2.1383982,1.904696,,,,,,,,,,,,,, -19300,1.5671841,1.8741041,,,,,,,,,,,,,, -19400,1.8795221,1.8311453,,,,,,,,,,,,,, -19500,1.8662019,1.7875406,,,,,,,,,,,,,, -19534,,,0.6613320708274841,1.3531110286712646,0.6044999957084656,1.6473594903945925,50000.0,0.4834000170230865,2.365743398666382,10000.0,6664.177336454392,6914.852823257446,6664.177336454392,249.5828001499176,0.398468017578125,0.0 -19600,1.6186876,1.825918,,,,,,,,,,,,,, -19700,1.5582658,1.7900602,,,,,,,,,,,,,, -19800,1.811101,1.7713886,,,,,,,,,,,,,, -19900,1.6502336,1.9958587,,,,,,,,,,,,,, -20000,1.8240585,1.8493695,,,,,,,,,,,,,, -20100,1.6935227,1.7862772,,,,,,,,,,,,,, -20200,1.6406816,1.8266219,,,,,,,,,,,,,, -20300,1.6565113,1.8605309,,,,,,,,,,,,,, -20400,1.8138326,1.9841609,,,,,,,,,,,,,, -20500,1.5865383,1.9197441,,,,,,,,,,,,,, -20600,1.688375,1.9191002,,,,,,,,,,,,,, -20700,1.8876044,1.7620537,,,,,,,,,,,,,, -20800,1.7308532,1.8742727,,,,,,,,,,,,,, -20900,1.7042339,1.8181542,,,,,,,,,,,,,, -21000,1.7582747,1.8569579,,,,,,,,,,,,,, -21038,,,0.6508888602256775,1.4064011573791504,0.5990999937057495,1.6903232336044312,50000.0,0.4713000357151031,2.450623750686645,10000.0,7174.233859062195,7442.757306337357,7174.233859062195,267.3385720252991,0.4360337257385254,0.0 -21100,1.6674912,1.7874599,,,,,,,,,,,,,, -21200,1.7999963,1.7883523,,,,,,,,,,,,,, -21300,1.8085215,1.7820891,,,,,,,,,,,,,, -21400,1.7213031,1.8124833,,,,,,,,,,,,,, -21500,1.6894274,1.8479462,,,,,,,,,,,,,, -21600,1.6978588,1.7729203,,,,,,,,,,,,,, -21700,2.190452,1.70646,,,,,,,,,,,,,, -21800,2.1415427,1.7772756,,,,,,,,,,,,,, -21900,2.1268682,1.7578851,,,,,,,,,,,,,, -22000,1.5689949,1.7497193,,,,,,,,,,,,,, -22100,1.8857905,1.9173505,,,,,,,,,,,,,, -22200,1.8797818,1.7536148,,,,,,,,,,,,,, -22300,1.722426,1.7294835,,,,,,,,,,,,,, -22400,1.5711006,1.7912054,,,,,,,,,,,,,, -22500,1.8726124,1.7379167,,,,,,,,,,,,,, -22542,,,0.6646006107330322,1.3387823104858398,0.6141799688339233,1.617658257484436,50000.0,0.4860000312328338,2.391890287399292,10000.0,7684.207465171814,7970.789026260376,7684.207465171814,285.31112122535706,0.467952013015747,0.0 -22600,1.8753287,1.8015599,,,,,,,,,,,,,, -22700,1.5186721,1.7437942,,,,,,,,,,,,,, -22800,1.5715423,1.7801521,,,,,,,,,,,,,, -22900,1.5551589,1.7811563,,,,,,,,,,,,,, -23000,1.9416584,1.7304505,,,,,,,,,,,,,, -23100,1.7437401,1.7732356,,,,,,,,,,,,,, -23200,1.8880239,1.7672194,,,,,,,,,,,,,, -23300,1.7641551,1.8867543,,,,,,,,,,,,,, -23400,1.5932959,1.8279119,,,,,,,,,,,,,, -23500,1.8077586,1.9251597,,,,,,,,,,,,,, -23600,1.6442419,1.8085495,,,,,,,,,,,,,, -23700,1.5718555,1.7861246,,,,,,,,,,,,,, -23800,2.147084,1.8986014,,,,,,,,,,,,,, -23900,1.7716031,1.7416774,,,,,,,,,,,,,, -24000,1.8395919,1.7927039,,,,,,,,,,,,,, -24047,,,0.6842713356018066,1.2391905784606934,0.6031399965286255,1.66865074634552,50000.0,0.4808000326156616,2.405961513519287,10000.0,8194.420420408249,8498.789740085602,8194.420420408249,303.01013803482056,0.5014698505401611,0.0 -24100,1.6233354,1.7340671,,,,,,,,,,,,,, -24200,1.8193859,1.7521331,,,,,,,,,,,,,, -24300,1.5314604,1.6862398,,,,,,,,,,,,,, -24400,1.7012138,1.782281,,,,,,,,,,,,,, -24500,1.6873993,1.8029251,,,,,,,,,,,,,, -24600,1.8600512,1.8003759,,,,,,,,,,,,,, -24700,1.6834071,1.7649791,,,,,,,,,,,,,, -24800,1.6388735,1.7308779,,,,,,,,,,,,,, -24900,1.8838631,1.8152285,,,,,,,,,,,,,, -25000,1.7216719,1.6648177,,,,,,,,,,,,,, -25100,1.9126034,1.6746176,,,,,,,,,,,,,, -25200,1.5748625,1.7528826,,,,,,,,,,,,,, -25300,2.0776677,1.7260818,,,,,,,,,,,,,, -25400,1.8468809,1.8034831,,,,,,,,,,,,,, -25500,1.6740115,1.7074739,,,,,,,,,,,,,, -25552,,,0.6896723508834839,1.2271901369094849,0.6210399866104126,1.5833314657211304,50000.0,0.4945000112056732,2.319498300552368,10000.0,8704.557604551315,9026.98484826088,8704.557604551315,320.98294949531555,0.5339152812957764,0.0 -25600,1.7297927,1.6715585,,,,,,,,,,,,,, -25700,1.8589133,1.8165315,,,,,,,,,,,,,, -25800,1.7524123,1.7906392,,,,,,,,,,,,,, -25900,1.8894138,1.8897228,,,,,,,,,,,,,, -26000,1.8534954,1.760169,,,,,,,,,,,,,, -26100,1.6197371,1.72723,,,,,,,,,,,,,, -26200,1.7732806,1.8053542,,,,,,,,,,,,,, -26300,1.6698796,1.6944995,,,,,,,,,,,,,, -26400,1.6876096,1.647611,,,,,,,,,,,,,, -26500,2.0223677,1.6956087,,,,,,,,,,,,,, -26600,1.6952701,1.6414807,,,,,,,,,,,,,, -26700,1.8160487,1.7638569,,,,,,,,,,,,,, -26800,1.7347594,1.8434701,,,,,,,,,,,,,, -26900,1.865025,1.7845354,,,,,,,,,,,,,, -27000,1.7438458,1.7382742,,,,,,,,,,,,,, -27056,,,0.683015763759613,1.2480367422103882,0.6202600002288818,1.574621319770813,50000.0,0.4948000311851501,2.30348801612854,10000.0,9214.546523332596,9554.83399772644,9214.546523332596,338.75679183006287,0.567237377166748,0.0 -27100,1.6552968,1.6741133,,,,,,,,,,,,,, -27200,1.8158801,1.6656572,,,,,,,,,,,,,, -27300,1.6903781,1.6453416,,,,,,,,,,,,,, -27400,1.9467394,1.7818835,,,,,,,,,,,,,, -27500,1.6999977,1.6860282,,,,,,,,,,,,,, -27600,1.7694596,1.7481068,,,,,,,,,,,,,, -27700,1.7906499,1.7245324,,,,,,,,,,,,,, -27800,1.9033854,1.7721078,,,,,,,,,,,,,, -27900,1.6390811,1.6748025,,,,,,,,,,,,,, -28000,1.7487787,1.7084095,,,,,,,,,,,,,, -28100,1.8146478,1.756586,,,,,,,,,,,,,, -28200,2.0630198,1.7842809,,,,,,,,,,,,,, -28300,1.7204573,1.7102286,,,,,,,,,,,,,, -28400,1.9774976,1.7010148,,,,,,,,,,,,,, -28500,1.7513007,1.635623,,,,,,,,,,,,,, -28561,,,0.6730906963348389,1.2945865392684937,0.6128399968147278,1.6189769506454468,50000.0,0.4812000095844269,2.4035277366638184,10000.0,9724.683393001556,10082.776176929474,9724.683393001556,356.4754137992859,0.6004829406738281,0.0 -28600,1.878181,1.750647,,,,,,,,,,,,,, -28700,1.7855455,1.7288196,,,,,,,,,,,,,, -28800,2.392274,1.7458619,,,,,,,,,,,,,, -28900,2.0694866,1.6251439,,,,,,,,,,,,,, -29000,1.8586428,1.7019421,,,,,,,,,,,,,, -29100,1.8365239,1.75287,,,,,,,,,,,,,, -29200,1.6896806,1.7915934,,,,,,,,,,,,,, -29300,1.855081,1.8432207,,,,,,,,,,,,,, -29400,2.0108428,1.7853239,,,,,,,,,,,,,, -29500,1.7289288,1.6817583,,,,,,,,,,,,,, -29600,1.9274614,1.8067384,,,,,,,,,,,,,, -29700,1.5886983,1.7368714,,,,,,,,,,,,,, -29800,1.7972792,1.7800735,,,,,,,,,,,,,, -29900,1.6822666,1.6828864,,,,,,,,,,,,,, -30000,1.6705943,1.7078837,,,,,,,,,,,,,, -30066,,,0.6851283311843872,1.2363115549087524,0.6257599592208862,1.53518807888031,50000.0,0.4984000325202942,2.274780035018921,10000.0,10234.78894495964,10610.840461969376,10234.78894495964,374.3439898490906,0.6367506980895996,0.0 -30100,1.867991,1.706202,,,,,,,,,,,,,, -30200,1.7401878,1.7583255,,,,,,,,,,,,,, -30300,1.7184744,1.750212,,,,,,,,,,,,,, -30400,1.5852128,1.7527019,,,,,,,,,,,,,, -30500,1.6898265,1.8542802,,,,,,,,,,,,,, -30600,1.7651017,1.6858339,,,,,,,,,,,,,, -30700,1.6233429,1.6491809,,,,,,,,,,,,,, -30800,1.908377,1.7855583,,,,,,,,,,,,,, -30900,1.773697,1.7772367,,,,,,,,,,,,,, -31000,1.6483022,1.6976433,,,,,,,,,,,,,, -31100,1.8011847,1.7730929,,,,,,,,,,,,,, -31200,2.0005841,1.7225181,,,,,,,,,,,,,, -31300,1.7648627,1.7523545,,,,,,,,,,,,,, -31400,1.6782975,1.7091423,,,,,,,,,,,,,, -31500,1.580819,1.6783819,,,,,,,,,,,,,, -31570,,,0.686922013759613,1.2258983850479126,0.6343799829483032,1.5134379863739014,50000.0,0.5035000443458557,2.246699333190918,10000.0,10744.854972839355,11138.650857448578,10744.854972839355,391.9957549571991,0.6751341819763184,0.0 -31600,1.8766545,1.7116542,,,,,,,,,,,,,, -31700,1.6689154,1.669543,,,,,,,,,,,,,, -31800,1.9787041,1.7776002,,,,,,,,,,,,,, -31900,1.871806,1.6765784,,,,,,,,,,,,,, -32000,1.7385315,1.7483591,,,,,,,,,,,,,, -32100,1.6967547,1.6927043,,,,,,,,,,,,,, -32200,1.563024,1.7006946,,,,,,,,,,,,,, -32300,1.811637,1.5404829,,,,,,,,,,,,,, -32400,1.6155648,1.6886575,,,,,,,,,,,,,, -32500,1.7176429,1.8354051,,,,,,,,,,,,,, -32600,1.59099,1.6825252,,,,,,,,,,,,,, -32700,1.9369941,1.6889702,,,,,,,,,,,,,, -32800,1.9103534,1.8044518,,,,,,,,,,,,,, -32900,1.7726189,1.6551726,,,,,,,,,,,,,, -33000,1.8615838,1.6833179,,,,,,,,,,,,,, -33074,,,0.7069116830825806,1.1511811017990112,0.6170399785041809,1.5918534994125366,50000.0,0.4880000352859497,2.316349744796753,10000.0,11254.89800453186,11666.672231435776,11254.89800453186,409.8847246170044,0.7104458808898926,0.0 -33100,2.071825,1.6227014,,,,,,,,,,,,,, -33200,1.7643852,1.5650253,,,,,,,,,,,,,, -33300,1.6489943,1.7115086,,,,,,,,,,,,,, -33400,1.8043419,1.6922785,,,,,,,,,,,,,, -33500,1.6466217,1.6517287,,,,,,,,,,,,,, -33600,1.9081115,1.7151428,,,,,,,,,,,,,, -33700,1.7092535,1.6932259,,,,,,,,,,,,,, -33800,2.029863,1.8707441,,,,,,,,,,,,,, -33900,1.7419225,1.7028352,,,,,,,,,,,,,, -34000,1.7552046,1.7387413,,,,,,,,,,,,,, -34100,1.728471,1.7178359,,,,,,,,,,,,,, -34200,1.7403098,1.7449573,,,,,,,,,,,,,, -34300,1.6955873,1.6823916,,,,,,,,,,,,,, -34400,1.9982464,1.7127217,,,,,,,,,,,,,, -34500,1.7229168,1.6943678,,,,,,,,,,,,,, -34579,,,0.6890744566917419,1.2211211919784546,0.6181399822235107,1.5859718322753906,50000.0,0.4945000112056732,2.303445816040039,10000.0,11765.143156051636,12194.597846269608,11765.143156051636,427.47814416885376,0.7436776161193848,0.0 -34600,1.6923596,1.5745939,,,,,,,,,,,,,, -34700,2.0348349,1.6623776,,,,,,,,,,,,,, -34800,1.8208818,1.6202525,,,,,,,,,,,,,, -34900,1.6534077,1.6289628,,,,,,,,,,,,,, -35000,1.8425848,1.6694717,,,,,,,,,,,,,, -35100,1.5798287,1.4535689,,,,,,,,,,,,,, -35200,1.799442,1.6191475,,,,,,,,,,,,,, -35300,1.7955716,1.7010398,,,,,,,,,,,,,, -35400,1.5874724,1.6109942,,,,,,,,,,,,,, -35500,1.7678593,1.6843404,,,,,,,,,,,,,, -35600,1.6626009,1.6373085,,,,,,,,,,,,,, -35700,1.7687047,1.676316,,,,,,,,,,,,,, -35800,1.8405764,1.7002187,,,,,,,,,,,,,, -35900,1.9012046,1.7301295,,,,,,,,,,,,,, -36000,1.7906685,1.632307,,,,,,,,,,,,,, -36083,,,0.6932995915412903,1.190823316574097,0.6310399770736694,1.5338072776794434,50000.0,0.5162000060081482,2.258314847946167,10000.0,12275.155267238615,12722.646920681,12275.155267238615,445.4265134334564,0.7787868976593018,0.0 -36100,2.0514276,1.7233045,,,,,,,,,,,,,, -36200,1.61211,1.6264936,,,,,,,,,,,,,, -36300,1.9582349,1.6492631,,,,,,,,,,,,,, -36400,1.8025149,1.7117945,,,,,,,,,,,,,, -36500,1.9775399,1.6676846,,,,,,,,,,,,,, -36600,1.8417516,1.6700749,,,,,,,,,,,,,, -36700,1.8189528,1.6148387,,,,,,,,,,,,,, -36800,1.9233402,1.737422,,,,,,,,,,,,,, -36900,1.8952194,1.5855877,,,,,,,,,,,,,, -37000,1.9790735,1.839503,,,,,,,,,,,,,, -37100,1.6558297,1.7461467,,,,,,,,,,,,,, -37200,1.6991009,1.7071085,,,,,,,,,,,,,, -37300,1.7766292,1.7295935,,,,,,,,,,,,,, -37400,1.9605842,1.7420045,,,,,,,,,,,,,, -37500,1.8259319,1.6944569,,,,,,,,,,,,,, -37588,,,0.6861447691917419,1.233417272567749,0.6261999607086182,1.5477724075317385,50000.0,0.4964000284671783,2.281170606613159,10000.0,12785.271156549454,13251.140579938889,12785.271156549454,463.7156083583832,0.8150899410247803,0.0 -37600,1.7408673,1.6881175,,,,,,,,,,,,,, -37700,1.6409415,1.675383,,,,,,,,,,,,,, -37800,1.5676634,1.6746744,,,,,,,,,,,,,, -37900,1.7807555,1.6793683,,,,,,,,,,,,,, -38000,1.70646,1.6294109,,,,,,,,,,,,,, -38100,1.6341972,1.692044,,,,,,,,,,,,,, -38200,1.6207328,1.6418841,,,,,,,,,,,,,, -38300,1.7378072,1.6032406,,,,,,,,,,,,,, -38400,1.8972458,1.6272043,,,,,,,,,,,,,, -38500,2.0026019,1.6337229,,,,,,,,,,,,,, -38600,1.8567337,1.6346385,,,,,,,,,,,,,, -38700,1.8502,1.6877108,,,,,,,,,,,,,, -38800,1.7444293,1.7505269,,,,,,,,,,,,,, -38900,1.7426406,1.5624444,,,,,,,,,,,,,, -39000,1.7964219,1.761781,,,,,,,,,,,,,, -39093,,,0.7010124325752258,1.1792408227920532,0.64028000831604,1.4864907264709473,50000.0,0.5124000310897827,2.243726968765259,10000.0,13295.242151021956,13778.91663146019,13295.242151021956,481.4382588863373,0.8444697856903076,0.0 -39100,1.8100599,1.6643441,,,,,,,,,,,,,, -39200,1.5725249,1.6015363,,,,,,,,,,,,,, -39300,2.0574703,1.7369529,,,,,,,,,,,,,, -39400,2.0021927,1.5795355,,,,,,,,,,,,,, -39500,1.9049473,1.650343,,,,,,,,,,,,,, -39600,1.7686517,1.5483937,,,,,,,,,,,,,, -39700,1.7842565,1.6543558,,,,,,,,,,,,,, -39800,1.8078381,1.6225668,,,,,,,,,,,,,, -39900,1.7554122,1.6313287,,,,,,,,,,,,,, -40000,1.7854903,1.659977,,,,,,,,,,,,,, -40100,1.818351,1.6107421,,,,,,,,,,,,,, -40200,1.726761,1.6640472,,,,,,,,,,,,,, -40300,1.7957897,1.6225047,,,,,,,,,,,,,, -40400,1.9614317,1.7297403,,,,,,,,,,,,,, -40500,1.7342356,1.6337779,,,,,,,,,,,,,, -40598,,,0.6853276491165161,1.2317752838134766,0.6255800127983093,1.5425355434417725,50000.0,0.4976000189781189,2.290236473083496,10000.0,13805.37430691719,14307.025272130966,13805.37430691719,499.3242871761322,0.8821501731872559,0.0 -40600,1.9652053,1.7156854,,,,,,,,,,,,,, -40700,1.7319577,1.6712799,,,,,,,,,,,,,, -40800,2.0646918,1.7433913,,,,,,,,,,,,,, -40900,1.8553755,1.6901557,,,,,,,,,,,,,, -41000,1.9452896,1.5624733,,,,,,,,,,,,,, -41100,1.6047703,1.6999191,,,,,,,,,,,,,, -41200,2.0576255,1.6827463,,,,,,,,,,,,,, -41300,1.823447,1.7142383,,,,,,,,,,,,,, -41400,2.200202,1.792998,,,,,,,,,,,,,, -41500,1.5944242,1.613289,,,,,,,,,,,,,, -41600,1.7958217,1.7191014,,,,,,,,,,,,,, -41700,1.9240744,1.7138507,,,,,,,,,,,,,, -41800,1.6800349,1.5910357,,,,,,,,,,,,,, -41900,1.8745259,1.7691418,,,,,,,,,,,,,, -42000,1.9405402,1.6833781,,,,,,,,,,,,,, -42100,1.7884967,1.6606398,,,,,,,,,,,,,, -42104,,,0.7300103306770325,1.0265074968338013,0.6344799995422363,1.5053465366363523,50000.0,0.5109000205993652,2.22822904586792,10000.0,14315.556084632874,14835.124782562256,14315.556084632874,517.1469714641571,0.9237699508666992,0.0 -42200,1.9377139,1.7822828,,,,,,,,,,,,,, -42300,1.6800506,1.6181263,,,,,,,,,,,,,, -42400,2.2571774,1.7078569,,,,,,,,,,,,,, -42500,1.7576724,1.6000365,,,,,,,,,,,,,, -42600,1.9111502,1.5800873,,,,,,,,,,,,,, -42700,2.1789644,1.6630261,,,,,,,,,,,,,, -42800,1.6386658,1.6657491,,,,,,,,,,,,,, -42900,1.844062,1.6796001,,,,,,,,,,,,,, -43000,1.7860879,1.5958312,,,,,,,,,,,,,, -43100,1.7043468,1.6383218,,,,,,,,,,,,,, -43200,1.7224094,1.5320451,,,,,,,,,,,,,, -43300,1.6582524,1.6080579,,,,,,,,,,,,,, -43400,1.7459859,1.7616277,,,,,,,,,,,,,, -43500,1.7049195,1.6548314,,,,,,,,,,,,,, -43600,1.8839041,1.4940605,,,,,,,,,,,,,, -43609,,,0.7144451141357422,1.1129549741744995,0.638700008392334,1.48816180229187,50000.0,0.5078000426292419,2.239952087402344,10000.0,14825.481731653214,15362.819847106934,14825.481731653214,534.8262553215027,0.9605207443237304,0.0 -43700,1.6995025,1.6829085,,,,,,,,,,,,,, -43800,1.9443467,1.8136672,,,,,,,,,,,,,, -43900,1.8069569,1.6996342,,,,,,,,,,,,,, -44000,1.7221551,1.6822628,,,,,,,,,,,,,, -44100,1.8295584,1.6947283,,,,,,,,,,,,,, -44200,2.1416311,1.6766973,,,,,,,,,,,,,, -44300,1.7864538,1.6139622,,,,,,,,,,,,,, -44400,1.7747477,1.7078873,,,,,,,,,,,,,, -44500,1.7016854,1.5948895,,,,,,,,,,,,,, -44600,1.8920484,1.6646774,,,,,,,,,,,,,, -44700,1.8396213,1.5601743,,,,,,,,,,,,,, -44800,1.8965155,1.613008,,,,,,,,,,,,,, -44900,1.7114025,1.6129414,,,,,,,,,,,,,, -45000,1.8072617,1.6617558,,,,,,,,,,,,,, -45096,,,0.7016900181770325,1.154717206954956,0.6298800110816956,1.5299546718597412,50000.0,0.5080000162124634,2.263049840927124,10000.0,15335.629780054092,15890.73356628418,15335.629780054092,552.4992291927338,0.9993839263916016,0.0 -45100,1.8633059,1.6557748,,,,,,,,,,,,,, -45200,2.0193555,1.6953547,,,,,,,,,,,,,, -45300,1.8564959,1.6378099,,,,,,,,,,,,,, -45400,2.467938,1.7751708,,,,,,,,,,,,,, -45500,1.7802049,1.6475122,,,,,,,,,,,,,, -45600,1.9362932,1.6030364,,,,,,,,,,,,,, -45700,1.7462169,1.5608327,,,,,,,,,,,,,, -45800,1.8147984,1.6592793,,,,,,,,,,,,,, -45900,1.7043061,1.4849013,,,,,,,,,,,,,, -46000,2.2485862,1.6540484,,,,,,,,,,,,,, -46100,1.8524095,1.5790106,,,,,,,,,,,,,, -46200,1.926396,1.8265934,,,,,,,,,,,,,, -46300,1.8742373,1.7093413,,,,,,,,,,,,,, -46400,1.7067795,1.5510254,,,,,,,,,,,,,, -46500,1.8498683,1.6602684,,,,,,,,,,,,,, -46598,,,0.7108976244926453,1.125377893447876,0.6459800004959106,1.4500062465667725,50000.0,0.5160000324249268,2.164910793304444,10000.0,15844.633338212969,16418.882928848267,15844.633338212969,570.4190890789032,2.1719322204589844,0.0 -46600,1.8542678,1.6918879,,,,,,,,,,,,,, -46700,1.6710488,1.639707,,,,,,,,,,,,,, -46800,1.8166949,1.6100143,,,,,,,,,,,,,, -46900,1.9415135,1.6529989,,,,,,,,,,,,,, -47000,1.9580597,1.662352,,,,,,,,,,,,,, -47100,2.0833359,1.668268,,,,,,,,,,,,,, -47200,1.8709406,1.5853559,,,,,,,,,,,,,, -47300,1.8163253,1.6213788,,,,,,,,,,,,,, -47400,1.7515837,1.6623358,,,,,,,,,,,,,, -47500,1.8035198,1.5532888,,,,,,,,,,,,,, -47600,1.888812,1.6569191,,,,,,,,,,,,,, -47700,1.8308965,1.6883522,,,,,,,,,,,,,, -47800,1.9734036,1.7807453,,,,,,,,,,,,,, -47900,1.9824952,1.6765774,,,,,,,,,,,,,, -48000,1.7613598,1.6480285,,,,,,,,,,,,,, -48100,1.9634473,1.7252908,,,,,,,,,,,,,, -48103,,,0.6979033946990967,1.1757428646087646,0.6360799670219421,1.504656195640564,50000.0,0.511400043964386,2.2049496173858643,10000.0,16354.717643976212,16946.749007940292,16354.717643976212,588.1098670959473,2.20971941947937,0.0 -48200,1.7839718,1.4758805,,,,,,,,,,,,,, -48300,1.9457606,1.6206632,,,,,,,,,,,,,, -48400,1.9050815,1.6661031,,,,,,,,,,,,,, -48500,2.250255,1.4908922,,,,,,,,,,,,,, -48600,2.0776505,1.7035043,,,,,,,,,,,,,, -48700,1.7935716,1.5093577,,,,,,,,,,,,,, -48800,1.8416483,1.669647,,,,,,,,,,,,,, -48900,1.7300516,1.5099137,,,,,,,,,,,,,, -49000,1.9032841,1.7070938,,,,,,,,,,,,,, -49100,1.9908541,1.7380619,,,,,,,,,,,,,, -49200,2.153075,1.5277174,,,,,,,,,,,,,, -49300,2.256664,1.6514629,,,,,,,,,,,,,, -49400,1.8322675,1.5387261,,,,,,,,,,,,,, -49500,1.9403286,1.7606552,,,,,,,,,,,,,, -49600,1.9957045,1.5708407,,,,,,,,,,,,,, -49608,,,0.7112563848495483,1.128156304359436,0.6507200002670288,1.4424986839294434,50000.0,0.5209000110626221,2.1950442790985107,10000.0,16864.65782546997,17474.416180849075,16864.65782546997,605.7434694766998,2.248587131500244,0.0 -49700,1.8654419,1.6659015,,,,,,,,,,,,,, -49800,2.112255,1.5638217,,,,,,,,,,,,,, -49900,2.0724354,1.6103804,,,,,,,,,,,,,, -50000,1.8529246,1.6129119,,,,,,,,,,,,,, -50100,1.7538509,1.6299678,,,,,,,,,,,,,, -50200,1.9867476,1.735983,,,,,,,,,,,,,, -50300,1.7230984,1.6019831,,,,,,,,,,,,,, -50400,1.8105023,1.5862439,,,,,,,,,,,,,, -50500,1.962062,1.6322367,,,,,,,,,,,,,, -50600,2.0159595,1.5126593,,,,,,,,,,,,,, -50700,2.0221028,1.5628953,,,,,,,,,,,,,, -50800,1.9062357,1.6567607,,,,,,,,,,,,,, -50900,1.7435424,1.4476855,,,,,,,,,,,,,, -51000,1.7673987,1.5565107,,,,,,,,,,,,,, -51100,1.6291053,1.5653069,,,,,,,,,,,,,, -51114,,,0.7567960619926453,0.9291717410087584,0.64301997423172,1.4702551364898682,50000.0,0.5172000527381897,2.2189431190490723,10000.0,17374.636246919632,18002.36827158928,17374.636246919632,623.6246781349182,2.2868897914886475,0.0 -51200,2.0067616,1.6066489,,,,,,,,,,,,,, -51300,1.9551138,1.7312925,,,,,,,,,,,,,, -51400,1.6939487,1.6539528,,,,,,,,,,,,,, -51500,2.4513636,1.7541801,,,,,,,,,,,,,, -51600,1.8477618,1.6414094,,,,,,,,,,,,,, -51700,1.8927511,1.6142495,,,,,,,,,,,,,, -51800,1.8264531,1.5864203,,,,,,,,,,,,,, -51900,1.722345,1.6165924,,,,,,,,,,,,,, -52000,1.8276458,1.6024891,,,,,,,,,,,,,, -52100,1.868865,1.5818768,,,,,,,,,,,,,, -52200,1.8303876,1.5184333,,,,,,,,,,,,,, -52300,2.343784,1.5128926,,,,,,,,,,,,,, -52400,1.7982914,1.6150092,,,,,,,,,,,,,, -52500,2.3162463,1.611284,,,,,,,,,,,,,, -52600,1.7329667,1.5874277,,,,,,,,,,,,,, -52620,,,0.7090441584587097,1.130107045173645,0.6304599642753601,1.5274958610534668,50000.0,0.4982000291347503,2.276724100112915,10000.0,17884.815304279327,18531.05840587616,17884.815304279327,642.0388751029968,2.3303985595703125,0.0 -52700,1.7196931,1.620173,,,,,,,,,,,,,, -52800,1.8030959,1.5743322,,,,,,,,,,,,,, -52900,2.1683004,1.5549906,,,,,,,,,,,,,, -53000,1.9137987,1.5717652,,,,,,,,,,,,,, -53100,1.841564,1.5419167,,,,,,,,,,,,,, -53200,2.040341,1.6770254,,,,,,,,,,,,,, -53300,1.9114453,1.6489573,,,,,,,,,,,,,, -53400,1.9235424,1.5909929,,,,,,,,,,,,,, -53500,2.2270887,1.6500432,,,,,,,,,,,,,, -53600,1.949393,1.640965,,,,,,,,,,,,,, -53700,1.7384222,1.639452,,,,,,,,,,,,,, -53800,1.7319759,1.6603303,,,,,,,,,,,,,, -53900,1.8912686,1.5655003,,,,,,,,,,,,,, -54000,1.8169413,1.5767318,,,,,,,,,,,,,, -54100,2.184684,1.6018283,,,,,,,,,,,,,, -54126,,,0.7151626348495483,1.1085048913955688,0.640720009803772,1.4724503755569458,50000.0,0.513700008392334,2.202472686767578,10000.0,18394.996651649475,19058.955248355865,18394.996651649475,659.6608362197876,2.370500087738037,0.0 -54200,1.8036587,1.6135211,,,,,,,,,,,,,, -54300,1.7740297,1.5993006,,,,,,,,,,,,,, -54400,1.8014216,1.5279028,,,,,,,,,,,,,, -54500,1.9268711,1.5179126,,,,,,,,,,,,,, -54600,1.7738781,1.5246966,,,,,,,,,,,,,, -54700,1.9847225,1.5344448,,,,,,,,,,,,,, -54800,2.0345218,1.4497898,,,,,,,,,,,,,, -54900,1.9527113,1.529418,,,,,,,,,,,,,, -55000,1.7459697,1.5368274,,,,,,,,,,,,,, -55100,1.8319168,1.5367997,,,,,,,,,,,,,, -55200,2.1299226,1.5620041,,,,,,,,,,,,,, -55300,1.9634938,1.6747804,,,,,,,,,,,,,, -55400,1.8540375,1.5725566,,,,,,,,,,,,,, -55500,1.8937695,1.4857317,,,,,,,,,,,,,, -55600,1.8982077,1.5712302,,,,,,,,,,,,,, -55632,,,0.7169164419174194,1.0998347997665403,0.6484999656677246,1.4559392929077148,50000.0,0.5157999992370605,2.200547933578491,10000.0,18905.095246076584,19586.952585458755,18905.095246076584,677.4676706790924,2.408476591110229,0.0 -55700,1.8798475,1.5336361,,,,,,,,,,,,,, -55800,1.828735,1.6115408,,,,,,,,,,,,,, -55900,1.8023484,1.64132,,,,,,,,,,,,,, -56000,1.7656777,1.598042,,,,,,,,,,,,,, -56100,1.844919,1.6424872,,,,,,,,,,,,,, -56200,1.8582509,1.5590435,,,,,,,,,,,,,, -56300,2.1081223,1.6023698,,,,,,,,,,,,,, -56400,1.8227059,1.5437367,,,,,,,,,,,,,, -56500,2.0719936,1.581928,,,,,,,,,,,,,, -56600,1.7305651,1.5492477,,,,,,,,,,,,,, -56700,2.0382564,1.6459641,,,,,,,,,,,,,, -56800,1.8792025,1.5335747,,,,,,,,,,,,,, -56900,1.7489061,1.5632796,,,,,,,,,,,,,, -57000,1.9275655,1.6037936,,,,,,,,,,,,,, -57100,1.8790815,1.5842271,,,,,,,,,,,,,, -57138,,,0.71000075340271,1.1082898378372192,0.6432799696922302,1.463404655456543,50000.0,0.5157999992370605,2.2013065814971924,10000.0,19415.004539966583,20115.204418182373,19415.004539966583,695.7189378738403,2.4459259510040283,0.0 -57200,2.067906,1.6006404,,,,,,,,,,,,,, -57300,1.8091063,1.4912072,,,,,,,,,,,,,, -57400,1.9357151,1.5807142,,,,,,,,,,,,,, -57500,1.9087061,1.5527058,,,,,,,,,,,,,, -57600,1.7608345,1.5606062,,,,,,,,,,,,,, -57700,1.9508214,1.5642252,,,,,,,,,,,,,, -57800,1.8589277,1.5561842,,,,,,,,,,,,,, -57900,2.0812063,1.4623029,,,,,,,,,,,,,, -58000,1.9344426,1.6497898,,,,,,,,,,,,,, -58100,1.8026836,1.6426523,,,,,,,,,,,,,, -58200,1.8030969,1.6283234,,,,,,,,,,,,,, -58300,2.1448195,1.5478706,,,,,,,,,,,,,, -58400,1.9411113,1.6297204,,,,,,,,,,,,,, -58500,1.9623678,1.570333,,,,,,,,,,,,,, -58600,2.055989,1.474722,,,,,,,,,,,,,, -58644,,,0.7024473547935486,1.1483705043792725,0.6444000005722046,1.4588139057159424,50000.0,0.5159000158309937,2.179934501647949,10000.0,19925.25050020218,20643.348040819168,19925.25050020218,713.5251975059509,2.4837570190429688,0.0 -58700,1.8134265,1.5492573,,,,,,,,,,,,,, -58800,1.7354327,1.4896618,,,,,,,,,,,,,, -58900,1.8746704,1.5792414,,,,,,,,,,,,,, -59000,1.963271,1.5338318,,,,,,,,,,,,,, -59100,1.8172413,1.5460882,,,,,,,,,,,,,, -59200,1.7583053,1.5700849,,,,,,,,,,,,,, -59300,1.810322,1.5239162,,,,,,,,,,,,,, -59400,1.9213467,1.6464169,,,,,,,,,,,,,, -59500,2.1618106,1.5459682,,,,,,,,,,,,,, -59600,2.28451,1.664904,,,,,,,,,,,,,, -59700,1.8677591,1.4989672,,,,,,,,,,,,,, -59800,1.7865245,1.5310618,,,,,,,,,,,,,, -59900,1.8487772,1.5531447,,,,,,,,,,,,,, -60000,1.8841983,1.5320425,,,,,,,,,,,,,, -60100,1.9523933,1.5530515,,,,,,,,,,,,,, -60150,,,0.7673987150192261,0.8855634927749634,0.6535999774932861,1.4235299825668335,50000.0,0.5308000445365906,2.1435821056365967,10000.0,20435.16510415077,21171.26289153099,20435.16510415077,731.4343252182007,2.5222463607788086,0.0 -60200,1.8512924,1.6110502,,,,,,,,,,,,,, -60300,1.9171003,1.5866985,,,,,,,,,,,,,, -60400,1.9388233,1.6081352,,,,,,,,,,,,,, -60500,1.7613395,1.4465522,,,,,,,,,,,,,, -60600,1.9376347,1.6522987,,,,,,,,,,,,,, -60700,2.0273972,1.4959214,,,,,,,,,,,,,, -60800,1.8026974,1.565502,,,,,,,,,,,,,, -60900,2.1639671,1.7410021,,,,,,,,,,,,,, -61000,1.8506935,1.6929444,,,,,,,,,,,,,, -61100,1.8047038,1.6691687,,,,,,,,,,,,,, -61200,1.895984,1.6085755,,,,,,,,,,,,,, -61300,2.0498874,1.6601708,,,,,,,,,,,,,, -61400,1.8958838,1.5090013,,,,,,,,,,,,,, -61500,1.8698686,1.5442718,,,,,,,,,,,,,, -61600,1.7943599,1.5274992,,,,,,,,,,,,,, -61655,,,0.7128706574440002,1.1097934246063232,0.6341599822044373,1.517980456352234,50000.0,0.5003000497817993,2.302015781402588,10000.0,20945.074239969254,21699.009701251984,20945.074239969254,749.1801223754883,2.560462474822998,0.0 -61700,1.7337645,1.4898541,,,,,,,,,,,,,, -61800,1.8782241,1.5241125,,,,,,,,,,,,,, -61900,2.0889041,1.574733,,,,,,,,,,,,,, -62000,1.9860456,1.4835573,,,,,,,,,,,,,, -62100,1.9381945,1.5295691,,,,,,,,,,,,,, -62200,2.0475051,1.5109506,,,,,,,,,,,,,, -62300,1.7787641,1.5901675,,,,,,,,,,,,,, -62400,2.1206765,1.6059577,,,,,,,,,,,,,, -62500,2.110593,1.5797045,,,,,,,,,,,,,, -62600,1.8573959,1.5330087,,,,,,,,,,,,,, -62700,1.8074081,1.5468553,,,,,,,,,,,,,, -62800,2.1688619,1.593123,,,,,,,,,,,,,, -62900,2.1063614,1.6696548,,,,,,,,,,,,,, -63000,2.049463,1.4495085,,,,,,,,,,,,,, -63100,2.2283304,1.7127054,,,,,,,,,,,,,, -63160,,,0.7302096486091614,1.038726806640625,0.6582399606704712,1.4035637378692627,50000.0,0.5281000137329102,2.1373679637908936,10000.0,21454.97639608383,22226.67244052887,21454.97639608383,766.840833902359,2.6059324741363525,0.0 -63200,1.8698148,1.5757322,,,,,,,,,,,,,, -63300,1.9348751,1.5453824,,,,,,,,,,,,,, -63400,2.0003018,1.4659592,,,,,,,,,,,,,, -63500,1.8774773,1.5791507,,,,,,,,,,,,,, -63600,1.9361247,1.6035782,,,,,,,,,,,,,, -63700,2.0378802,1.6395802,,,,,,,,,,,,,, -63800,1.897014,1.4266019,,,,,,,,,,,,,, -63900,2.0018635,1.5410855,,,,,,,,,,,,,, -64000,1.8644222,1.468987,,,,,,,,,,,,,, -64100,2.0801105,1.5435282,,,,,,,,,,,,,, -64200,1.8887953,1.3853419,,,,,,,,,,,,,, -64300,1.8905518,1.4914796,,,,,,,,,,,,,, -64400,2.0080786,1.5077274,,,,,,,,,,,,,, -64500,1.8377242,1.5592664,,,,,,,,,,,,,, -64600,1.8916135,1.4816318,,,,,,,,,,,,,, -64666,,,0.7167769074440002,1.0958020687103271,0.6484799981117249,1.438090443611145,50000.0,0.5207000374794006,2.1825969219207764,10000.0,21965.061421632767,22754.716769218445,21965.061421632767,784.704512834549,2.6474947929382324,0.0 -64700,1.9721618,1.5201387,,,,,,,,,,,,,, -64800,2.0504224,1.4699959,,,,,,,,,,,,,, -64900,2.0682666,1.4843023,,,,,,,,,,,,,, -65000,2.0225275,1.5208166,,,,,,,,,,,,,, -65100,1.9915438,1.6525257,,,,,,,,,,,,,, -65200,1.9487197,1.551697,,,,,,,,,,,,,, -65300,1.8719555,1.6058068,,,,,,,,,,,,,, -65400,1.9383872,1.6159823,,,,,,,,,,,,,, -65500,2.020166,1.5887693,,,,,,,,,,,,,, -65600,2.064503,1.537889,,,,,,,,,,,,,, -65700,2.0796747,1.530335,,,,,,,,,,,,,, -65800,1.9592676,1.508249,,,,,,,,,,,,,, -65900,2.0864863,1.5455859,,,,,,,,,,,,,, -66000,1.8350803,1.5772953,,,,,,,,,,,,,, -66100,2.031091,1.4098835,,,,,,,,,,,,,, -66172,,,0.7169762253761292,1.0895507335662842,0.6572399735450745,1.4136441946029663,50000.0,0.5215000510215759,2.1981096267700195,10000.0,22475.0136077404,23282.5726454258,22475.0136077404,802.5071756839752,2.695244073867798,0.0 -66200,1.8480234,1.5611175,,,,,,,,,,,,,, -66300,1.7584083,1.5362034,,,,,,,,,,,,,, -66400,2.0479057,1.5357955,,,,,,,,,,,,,, -66500,1.8832899,1.5179683,,,,,,,,,,,,,, -66600,2.0660622,1.5138619,,,,,,,,,,,,,, -66700,1.9605575,1.6060123,,,,,,,,,,,,,, -66800,2.1399713,1.4923007,,,,,,,,,,,,,, -66900,1.9864104,1.5517794,,,,,,,,,,,,,, -67000,1.953503,1.5171143,,,,,,,,,,,,,, -67100,2.0408447,1.5957643,,,,,,,,,,,,,, -67200,2.1158369,1.5116236,,,,,,,,,,,,,, -67300,2.0236156,1.4552541,,,,,,,,,,,,,, -67400,1.9549222,1.4782354,,,,,,,,,,,,,, -67500,1.8956017,1.4543713,,,,,,,,,,,,,, -67600,2.1228158,1.5518234,,,,,,,,,,,,,, -67678,,,0.727937638759613,1.04031240940094,0.6620799899101257,1.3774571418762207,50000.0,0.5365000367164612,2.1053407192230225,10000.0,22985.089703559875,23811.22298622132,22985.089703559875,820.9895300865173,2.7346694469451904,0.0 -67700,2.0262818,1.4481027,,,,,,,,,,,,,, -67800,2.155344,1.5128847,,,,,,,,,,,,,, -67900,2.0402088,1.495344,,,,,,,,,,,,,, -68000,1.9512142,1.5698255,,,,,,,,,,,,,, -68100,2.0866292,1.5117669,,,,,,,,,,,,,, -68200,1.9658655,1.4807682,,,,,,,,,,,,,, -68300,1.875839,1.4508693,,,,,,,,,,,,,, -68400,2.0392344,1.5184029,,,,,,,,,,,,,, -68500,2.1661184,1.5274593,,,,,,,,,,,,,, -68600,2.0775294,1.4090273,,,,,,,,,,,,,, -68700,1.9424772,1.5330884,,,,,,,,,,,,,, -68800,2.1331263,1.4789945,,,,,,,,,,,,,, -68900,1.8681583,1.6126394,,,,,,,,,,,,,, -69000,1.8584942,1.4402046,,,,,,,,,,,,,, -69100,1.9082958,1.6030833,,,,,,,,,,,,,, -69184,,,0.7716637253761292,0.8626237511634827,0.661579966545105,1.3844743967056274,50000.0,0.5349000096321106,2.0840163230896,10000.0,23495.21492218972,24338.97371149063,23495.21492218972,838.5256464481354,2.7713065147399902,0.0 -69200,1.8855596,1.4209807,,,,,,,,,,,,,, -69300,2.078239,1.4564801,,,,,,,,,,,,,, -69400,2.0075576,1.4521812,,,,,,,,,,,,,, -69500,2.0454752,1.5654459,,,,,,,,,,,,,, -69600,1.9698389,1.362079,,,,,,,,,,,,,, -69700,2.083184,1.4971142,,,,,,,,,,,,,, -69800,2.0043838,1.4813148,,,,,,,,,,,,,, -69900,1.9711992,1.4784743,,,,,,,,,,,,,, -70000,2.2600682,1.4917461,,,,,,,,,,,,,, -70100,2.1858563,1.5329044,,,,,,,,,,,,,, -70200,1.8747113,1.4703652,,,,,,,,,,,,,, -70300,1.8826045,1.4377444,,,,,,,,,,,,,, -70400,2.1187227,1.5856953,,,,,,,,,,,,,, -70500,2.2294972,1.4970382,,,,,,,,,,,,,, -70600,2.0038605,1.5197304,,,,,,,,,,,,,, -70691,,,0.7581313848495483,0.9096571207046508,0.6702799797058105,1.3560611009597778,50000.0,0.5488000512123108,2.067430257797241,10000.0,24005.384941101074,24866.868101119995,24005.384941101074,856.152738571167,2.8141791820526123,0.0 -70700,1.8577834,1.4099127,,,,,,,,,,,,,, -70800,1.8009814,1.4645439,,,,,,,,,,,,,, -70900,2.0677521,1.5261531,,,,,,,,,,,,,, -71000,2.0296187,1.5786393,,,,,,,,,,,,,, -71100,2.28879,1.6676717,,,,,,,,,,,,,, -71200,1.939132,1.5559212,,,,,,,,,,,,,, -71300,2.3342457,1.4774363,,,,,,,,,,,,,, -71400,2.117495,1.4622332,,,,,,,,,,,,,, -71500,2.1311677,1.4965265,,,,,,,,,,,,,, -71600,2.1600096,1.5734472,,,,,,,,,,,,,, -71700,1.8831731,1.4479764,,,,,,,,,,,,,, -71800,2.3075888,1.5542476,,,,,,,,,,,,,, -71900,1.8929734,1.5019394,,,,,,,,,,,,,, -72000,2.2592309,1.5744233,,,,,,,,,,,,,, -72100,1.9180231,1.5677285,,,,,,,,,,,,,, -72197,,,0.7457548975944519,0.974716067314148,0.6640200018882751,1.391406536102295,50000.0,0.534500002861023,2.125728368759156,10000.0,24515.50654459,25394.858213186264,24515.50654459,873.9265124797821,2.855938196182251,0.0 -72200,2.2894297,1.4381248,,,,,,,,,,,,,, -72300,1.9988754,1.5126624,,,,,,,,,,,,,, -72400,2.0027528,1.5326982,,,,,,,,,,,,,, -72500,2.2136652,1.642685,,,,,,,,,,,,,, -72600,2.1327467,1.6340296,,,,,,,,,,,,,, -72700,2.0515552,1.4473972,,,,,,,,,,,,,, -72800,2.2606537,1.5378277,,,,,,,,,,,,,, -72900,1.9725035,1.4586443,,,,,,,,,,,,,, -73000,2.2604759,1.5028472,,,,,,,,,,,,,, -73100,1.9492694,1.4390658,,,,,,,,,,,,,, -73200,1.9940507,1.507782,,,,,,,,,,,,,, -73300,1.960637,1.4164734,,,,,,,,,,,,,, -73400,2.197106,1.513655,,,,,,,,,,,,,, -73500,2.1435134,1.5046587,,,,,,,,,,,,,, -73600,2.1349044,1.4738846,,,,,,,,,,,,,, -73700,2.0730214,1.559846,,,,,,,,,,,,,, -73703,,,0.7359893321990967,1.0156553983688354,0.6577799916267395,1.3957440853118896,50000.0,0.5339000225067139,2.1294972896575928,10000.0,25025.546664714813,25922.5676074028,25025.546664714813,891.5010304450989,2.8975298404693604,0.0 -73800,2.1167338,1.6046349,,,,,,,,,,,,,, -73900,1.9637045,1.5074896,,,,,,,,,,,,,, -74000,1.9125143,1.3978208,,,,,,,,,,,,,, -74100,2.1624842,1.4840981,,,,,,,,,,,,,, -74200,2.028356,1.5106703,,,,,,,,,,,,,, -74300,2.334714,1.551913,,,,,,,,,,,,,, -74400,2.0661464,1.6051302,,,,,,,,,,,,,, -74500,2.0109448,1.4843805,,,,,,,,,,,,,, -74600,1.9187106,1.5489776,,,,,,,,,,,,,, -74700,1.862712,1.5018337,,,,,,,,,,,,,, -74800,2.0006533,1.4177897,,,,,,,,,,,,,, -74900,2.0642908,1.5825466,,,,,,,,,,,,,, -75000,2.170594,1.5493603,,,,,,,,,,,,,, -75100,2.1894717,1.4968503,,,,,,,,,,,,,, -75200,1.983293,1.4480053,,,,,,,,,,,,,, -75209,,,0.7361288070678711,1.0147454738616943,0.6642199754714966,1.3678507804870603,50000.0,0.532200038433075,2.108184576034546,10000.0,25535.52632427216,26450.878532886505,25535.52632427216,909.7378346920012,2.939307928085327,0.0 -75300,2.0385954,1.3790531,,,,,,,,,,,,,, -75400,2.1032462,1.3785043,,,,,,,,,,,,,, -75500,2.1816292,1.5240252,,,,,,,,,,,,,, -75600,2.24111,1.5339174,,,,,,,,,,,,,, -75700,2.3020275,1.4972595,,,,,,,,,,,,,, -75800,2.006071,1.505139,,,,,,,,,,,,,, -75900,2.2309923,1.5276992,,,,,,,,,,,,,, -76000,2.0948868,1.4408681,,,,,,,,,,,,,, -76100,1.9620734,1.476874,,,,,,,,,,,,,, -76200,2.0863855,1.3453374,,,,,,,,,,,,,, -76300,2.093731,1.3712914,,,,,,,,,,,,,, -76400,1.9284118,1.4897988,,,,,,,,,,,,,, -76500,2.1653197,1.435611,,,,,,,,,,,,,, -76600,2.0093799,1.4954534,,,,,,,,,,,,,, -76700,2.238721,1.6069899,,,,,,,,,,,,,, -76715,,,0.725027859210968,1.0601333379745483,0.6577199697494507,1.3992598056793213,50000.0,0.5281000137329102,2.151803970336914,10000.0,26045.5726313591,26978.564618349075,26045.5726313591,927.2890992164612,2.9757421016693115,0.0 -76800,2.2582316,1.4752464,,,,,,,,,,,,,, -76900,1.9416248,1.3819453,,,,,,,,,,,,,, -77000,2.1292841,1.3863242,,,,,,,,,,,,,, -77100,1.9431846,1.4830563,,,,,,,,,,,,,, -77200,2.0371904,1.4409661,,,,,,,,,,,,,, -77300,2.2651742,1.5090886,,,,,,,,,,,,,, -77400,2.1917071,1.4927888,,,,,,,,,,,,,, -77500,2.1112301,1.5190305,,,,,,,,,,,,,, -77600,2.0658207,1.4680591,,,,,,,,,,,,,, -77700,2.3322597,1.4992864,,,,,,,,,,,,,, -77800,2.1401536,1.4497633,,,,,,,,,,,,,, -77900,1.9515711,1.4036149,,,,,,,,,,,,,, -78000,1.9839634,1.3851943,,,,,,,,,,,,,, -78100,2.0910952,1.3723018,,,,,,,,,,,,,, -78200,1.9745016,1.411586,,,,,,,,,,,,,, -78221,,,0.7487444281578064,0.9544544219970704,0.6531999707221985,1.4107705354690552,50000.0,0.523300051689148,2.1556520462036133,10000.0,26555.53935289383,27506.41523051262,26555.53935289383,945.0738813877106,3.020139217376709,0.0 -78300,2.1040149,1.4962322,,,,,,,,,,,,,, -78400,2.006037,1.4506954,,,,,,,,,,,,,, -78500,2.1032243,1.3594412,,,,,,,,,,,,,, -78600,2.5035236,1.49426,,,,,,,,,,,,,, -78700,2.2650828,1.4442797,,,,,,,,,,,,,, -78800,2.1180646,1.338014,,,,,,,,,,,,,, -78900,2.0223296,1.4245526,,,,,,,,,,,,,, -79000,2.1757898,1.4929221,,,,,,,,,,,,,, -79100,2.291646,1.5058672,,,,,,,,,,,,,, -79200,2.4781206,1.5838896,,,,,,,,,,,,,, -79300,2.1442943,1.4076382,,,,,,,,,,,,,, -79400,2.3991299,1.4828421,,,,,,,,,,,,,, -79500,2.22845,1.6191477,,,,,,,,,,,,,, -79600,2.1862721,1.4861498,,,,,,,,,,,,,, -79700,2.2209308,1.426306,,,,,,,,,,,,,, -79727,,,0.7374043464660645,0.9866397380828856,0.6531199812889099,1.437100887298584,50000.0,0.5210000276565552,2.229564905166626,10000.0,27065.5693500042,28034.04595661164,27065.5693500042,962.5798609256744,3.0614492893218994,0.0 -79800,2.148368,1.4400215,,,,,,,,,,,,,, -79900,2.087013,1.5236533,,,,,,,,,,,,,, -80000,2.1171408,1.3976119,,,,,,,,,,,,,, -80100,2.1863465,1.4914972,,,,,,,,,,,,,, -80200,2.2253983,1.4384466,,,,,,,,,,,,,, -80300,2.2985375,1.5129374,,,,,,,,,,,,,, -80400,2.2293332,1.429847,,,,,,,,,,,,,, -80500,2.2981586,1.4669609,,,,,,,,,,,,,, -80600,2.041189,1.42208,,,,,,,,,,,,,, -80700,2.0914245,1.5209899,,,,,,,,,,,,,, -80800,2.018185,1.4655235,,,,,,,,,,,,,, -80900,1.923942,1.3535477,,,,,,,,,,,,,, -81000,2.14516,1.4546466,,,,,,,,,,,,,, -81100,2.2151465,1.5243049,,,,,,,,,,,,,, -81200,2.1887598,1.4023926,,,,,,,,,,,,,, -81233,,,0.7581512928009033,0.9172185659408568,0.6740999817848206,1.3255287408828735,50000.0,0.5451000332832336,2.054157257080078,10000.0,27575.630492925644,28562.10580611229,27575.630492925644,980.4816539287568,3.104302167892456,0.0 -81300,1.9630013,1.4568189,,,,,,,,,,,,,, -81400,2.210301,1.4316889,,,,,,,,,,,,,, -81500,2.3673868,1.6268892,,,,,,,,,,,,,, -81600,2.1083605,1.4542017,,,,,,,,,,,,,, -81700,2.1469443,1.4660221,,,,,,,,,,,,,, -81800,2.092406,1.4145595,,,,,,,,,,,,,, -81900,2.3182933,1.3690395,,,,,,,,,,,,,, -82000,2.0042539,1.4726424,,,,,,,,,,,,,, -82100,2.1677501,1.3983586,,,,,,,,,,,,,, -82200,2.298163,1.3807203,,,,,,,,,,,,,, -82300,2.361496,1.4608706,,,,,,,,,,,,,, -82400,2.1363497,1.4311378,,,,,,,,,,,,,, -82500,2.1731217,1.4090434,,,,,,,,,,,,,, -82600,2.0911887,1.5640249,,,,,,,,,,,,,, -82700,2.045123,1.4695247,,,,,,,,,,,,,, -82739,,,0.7515744566917419,0.9364935159683228,0.6704399585723877,1.3452812433242798,50000.0,0.547700047492981,2.067952871322632,10000.0,28085.67872595787,29090.06341791153,28085.67872595787,998.2937302589417,3.1469409465789795,0.0 -82800,2.0332294,1.3955569,,,,,,,,,,,,,, -82900,2.1396015,1.4285492,,,,,,,,,,,,,, -83000,2.027527,1.4059441,,,,,,,,,,,,,, -83100,2.3275213,1.490822,,,,,,,,,,,,,, -83200,2.0322719,1.300411,,,,,,,,,,,,,, -83300,2.1356976,1.4616001,,,,,,,,,,,,,, -83400,2.3581593,1.448452,,,,,,,,,,,,,, -83500,2.2140436,1.4039159,,,,,,,,,,,,,, -83600,2.1517968,1.465409,,,,,,,,,,,,,, -83700,2.2195187,1.4547818,,,,,,,,,,,,,, -83800,2.2424412,1.5148792,,,,,,,,,,,,,, -83900,2.3881598,1.4864907,,,,,,,,,,,,,, -84000,2.1027243,1.5023456,,,,,,,,,,,,,, -84100,2.3844943,1.4635445,,,,,,,,,,,,,, -84200,2.3213644,1.5440068,,,,,,,,,,,,,, -84245,,,0.749043345451355,0.9529641270637512,0.6751599907875061,1.3240749835968018,50000.0,0.544700026512146,2.050497055053711,10000.0,28595.77198863029,29618.119492292404,28595.77198863029,1016.1643199920654,3.186960935592652,0.0 -84300,2.1981044,1.4021499,,,,,,,,,,,,,, -84400,2.1799724,1.4445729,,,,,,,,,,,,,, -84500,2.1310344,1.4142507,,,,,,,,,,,,,, -84600,2.2006564,1.4293466,,,,,,,,,,,,,, -84700,2.1107018,1.325405,,,,,,,,,,,,,, -84800,2.4483585,1.3311489,,,,,,,,,,,,,, -84900,2.8133404,1.5732338,,,,,,,,,,,,,, -85000,2.1693163,1.4795414,,,,,,,,,,,,,, -85100,2.2125852,1.4729563,,,,,,,,,,,,,, -85200,2.1507456,1.5187061,,,,,,,,,,,,,, -85300,2.297574,1.3497597,,,,,,,,,,,,,, -85400,2.4979591,1.4799271,,,,,,,,,,,,,, -85500,2.3279812,1.408926,,,,,,,,,,,,,, -85600,2.2133195,1.3850746,,,,,,,,,,,,,, -85700,2.1726155,1.5301653,,,,,,,,,,,,,, -85751,,,0.7525310516357422,0.9264184236526488,0.6801599860191345,1.2922544479370115,50000.0,0.5537000298500061,1.9823843240737915,10000.0,29105.80688929557,30145.964017629623,29105.80688929557,1033.8766658306122,3.23142409324646,0.0 -85800,2.1526053,1.4593315,,,,,,,,,,,,,, -85900,2.236124,1.4576634,,,,,,,,,,,,,, -86000,2.2207677,1.491648,,,,,,,,,,,,,, -86100,2.5159724,1.444026,,,,,,,,,,,,,, -86200,2.2761817,1.3635585,,,,,,,,,,,,,, -86300,2.2536726,1.510874,,,,,,,,,,,,,, -86400,2.1679566,1.3395065,,,,,,,,,,,,,, -86500,2.3628616,1.3851237,,,,,,,,,,,,,, -86600,2.4367287,1.4773405,,,,,,,,,,,,,, -86700,2.2015958,1.4989282,,,,,,,,,,,,,, -86800,2.4322698,1.535885,,,,,,,,,,,,,, -86900,2.4484653,1.4656857,,,,,,,,,,,,,, -87000,2.2481203,1.4233737,,,,,,,,,,,,,, -87100,2.3441525,1.4267162,,,,,,,,,,,,,, -87200,2.410519,1.4669942,,,,,,,,,,,,,, -87257,,,0.7596858739852905,0.9148987531661988,0.6706599593162537,1.3406606912612915,50000.0,0.5484000444412231,2.032165765762329,10000.0,29615.99306058884,30674.282242536545,29615.99306058884,1051.9092426300049,3.2774195671081543,0.0 -87300,1.946158,1.3365448,,,,,,,,,,,,,, -87400,2.0602694,1.3833798,,,,,,,,,,,,,, -87500,2.19538,1.450001,,,,,,,,,,,,,, -87600,2.1903505,1.453964,,,,,,,,,,,,,, -87700,2.2011518,1.3748014,,,,,,,,,,,,,, -87800,2.2580793,1.4281964,,,,,,,,,,,,,, -87900,2.3261771,1.3827919,,,,,,,,,,,,,, -88000,2.2116723,1.3733644,,,,,,,,,,,,,, -88100,2.251507,1.4917331,,,,,,,,,,,,,, -88200,2.1128051,1.3323663,,,,,,,,,,,,,, -88300,2.345359,1.4162097,,,,,,,,,,,,,, -88400,2.1952717,1.4087819,,,,,,,,,,,,,, -88500,2.1680984,1.3718655,,,,,,,,,,,,,, -88600,2.3038943,1.4134765,,,,,,,,,,,,,, -88700,2.474085,1.3506411,,,,,,,,,,,,,, -88763,,,0.7681162357330322,0.8765470385551453,0.6717999577522278,1.342298150062561,50000.0,0.541100025177002,2.063876152038574,10000.0,30125.917417526245,31201.865349769592,30125.917417526245,1069.4649865627289,3.327728748321533,0.0 -88800,2.0983536,1.4227259,,,,,,,,,,,,,, -88900,2.223403,1.4419101,,,,,,,,,,,,,, -89000,2.4204652,1.4760648,,,,,,,,,,,,,, -89100,2.378594,1.4266132,,,,,,,,,,,,,, -89200,2.5176322,1.3043761,,,,,,,,,,,,,, -89300,2.2092428,1.3813138,,,,,,,,,,,,,, -89400,2.2786021,1.3782867,,,,,,,,,,,,,, -89500,2.1876187,1.3555825,,,,,,,,,,,,,, -89600,2.250286,1.4087245,,,,,,,,,,,,,, -89700,2.5318074,1.3093141,,,,,,,,,,,,,, -89800,2.4995666,1.5278184,,,,,,,,,,,,,, -89900,2.5642657,1.4986387,,,,,,,,,,,,,, -90000,2.1455739,1.4365277,,,,,,,,,,,,,, -90100,2.1895342,1.4214793,,,,,,,,,,,,,, -90200,2.23743,1.3753971,,,,,,,,,,,,,, -90269,,,0.7680763602256775,0.8671398162841797,0.6808800101280212,1.2917416095733645,50000.0,0.5516000390052795,2.005854368209839,10000.0,30635.965329885483,31729.75548362732,30635.965329885483,1087.206655502319,3.3750007152557373,0.0 -90300,2.5008266,1.3580481,,,,,,,,,,,,,, -90400,2.261527,1.4932635,,,,,,,,,,,,,, -90500,2.303205,1.4039706,,,,,,,,,,,,,, -90600,2.1213982,1.4201453,,,,,,,,,,,,,, -90700,2.150909,1.3105502,,,,,,,,,,,,,, -90800,2.4892375,1.3952986,,,,,,,,,,,,,, -90900,2.2051141,1.354514,,,,,,,,,,,,,, -91000,2.1692069,1.3562195,,,,,,,,,,,,,, -91100,2.1664143,1.3322392,,,,,,,,,,,,,, -91200,2.2464983,1.3636765,,,,,,,,,,,,,, -91300,2.1067598,1.3954604,,,,,,,,,,,,,, -91400,2.40537,1.3172159,,,,,,,,,,,,,, -91500,2.2858741,1.4198507,,,,,,,,,,,,,, -91600,2.3982894,1.414763,,,,,,,,,,,,,, -91700,2.4080555,1.4620388,,,,,,,,,,,,,, -91776,,,0.7579320669174194,0.906115710735321,0.6772199869155884,1.3106396198272705,50000.0,0.5533000230789185,2.02983021736145,10000.0,31146.04019165039,32257.85404109955,31146.04019165039,1105.1372406482697,3.414245128631592,0.0 -91800,2.326274,1.4398065,,,,,,,,,,,,,, -91900,2.3252833,1.3287553,,,,,,,,,,,,,, -92000,2.1023529,1.4411123,,,,,,,,,,,,,, -92100,2.4596932,1.3704244,,,,,,,,,,,,,, -92200,2.0272453,1.2970078,,,,,,,,,,,,,, -92300,2.1806748,1.3573445,,,,,,,,,,,,,, -92400,2.3160124,1.4426482,,,,,,,,,,,,,, -92500,2.4130538,1.4836309,,,,,,,,,,,,,, -92600,2.5697238,1.3413416,,,,,,,,,,,,,, -92700,2.222082,1.3631811,,,,,,,,,,,,,, -92800,2.5063982,1.5344148,,,,,,,,,,,,,, -92900,2.4735382,1.3736047,,,,,,,,,,,,,, -93000,2.507868,1.4560596,,,,,,,,,,,,,, -93100,2.9196358,1.3917687,,,,,,,,,,,,,, -93200,2.3607454,1.4648517,,,,,,,,,,,,,, -93282,,,0.7669204473495483,0.8753052949905396,0.6854999661445618,1.2829898595809937,50000.0,0.558899998664856,1.9982883930206297,10000.0,31656.18217587471,32786.21779823303,31656.18217587471,1123.2598929405212,3.460043430328369,0.0 -93300,2.4275713,1.4551919,,,,,,,,,,,,,, -93400,2.3700044,1.336104,,,,,,,,,,,,,, -93500,2.3146753,1.3968694,,,,,,,,,,,,,, -93600,2.4786677,1.3943876,,,,,,,,,,,,,, -93700,2.525721,1.5622789,,,,,,,,,,,,,, -93800,2.3507395,1.370812,,,,,,,,,,,,,, -93900,2.5735366,1.3727019,,,,,,,,,,,,,, -94000,2.191628,1.3804368,,,,,,,,,,,,,, -94100,2.555019,1.3219051,,,,,,,,,,,,,, -94200,2.3311007,1.3949751,,,,,,,,,,,,,, -94300,2.6618958,1.5373936,,,,,,,,,,,,,, -94400,2.3269975,1.3118421,,,,,,,,,,,,,, -94500,2.1694238,1.5460113,,,,,,,,,,,,,, -94600,2.4309633,1.3545322,,,,,,,,,,,,,, -94700,2.5117536,1.4990759,,,,,,,,,,,,,, -94788,,,0.7631935477256775,0.8865559697151184,0.6861400008201599,1.2788174152374268,50000.0,0.5600000023841858,2.0052812099456787,10000.0,32166.163598299023,33314.017624139786,32166.163598299023,1140.9816081523895,3.503222703933716,0.0 -94800,2.594197,1.5080702,,,,,,,,,,,,,, -94900,2.4837961,1.4514413,,,,,,,,,,,,,, -95000,2.456327,1.4034334,,,,,,,,,,,,,, -95100,2.1603212,1.288902,,,,,,,,,,,,,, -95200,2.2774572,1.3542132,,,,,,,,,,,,,, -95300,2.3516507,1.3140467,,,,,,,,,,,,,, -95400,2.3761702,1.4168108,,,,,,,,,,,,,, -95500,2.3307319,1.407539,,,,,,,,,,,,,, -95600,2.2947445,1.3380551,,,,,,,,,,,,,, -95700,2.5961635,1.358078,,,,,,,,,,,,,, -95800,2.2463338,1.4417238,,,,,,,,,,,,,, -95900,2.3803656,1.3249626,,,,,,,,,,,,,, -96000,2.492601,1.3720498,,,,,,,,,,,,,, -96100,2.382735,1.3519806,,,,,,,,,,,,,, -96200,2.6103487,1.3999239,,,,,,,,,,,,,, -96292,,,0.7673788070678711,0.8581953048706055,0.6811400055885315,1.2850172519683838,50000.0,0.5563000440597534,2.0307154655456543,10000.0,32676.12209534645,33842.43831539154,32676.12209534645,1159.3427624702454,3.55094575881958,0.0 -96300,2.4396589,1.3533425,,,,,,,,,,,,,, -96400,2.541401,1.4284462,,,,,,,,,,,,,, -96500,2.5051708,1.3896651,,,,,,,,,,,,,, -96600,2.5063496,1.4507171,,,,,,,,,,,,,, -96700,2.6097355,1.3347876,,,,,,,,,,,,,, -96800,2.4132416,1.4104729,,,,,,,,,,,,,, -96900,2.330965,1.3443534,,,,,,,,,,,,,, -97000,2.4179635,1.2847723,,,,,,,,,,,,,, -97100,2.4486928,1.4467064,,,,,,,,,,,,,, -97200,2.4191892,1.370796,,,,,,,,,,,,,, -97300,2.455288,1.3852712,,,,,,,,,,,,,, -97400,2.4595964,1.3195827,,,,,,,,,,,,,, -97500,2.3906517,1.330078,,,,,,,,,,,,,, -97600,2.516482,1.4253417,,,,,,,,,,,,,, -97700,2.3356028,1.3150128,,,,,,,,,,,,,, -97798,,,0.7949816584587097,0.7670559883117676,0.6891199946403503,1.256996750831604,50000.0,0.5625,1.977084040641785,10000.0,33186.13390159607,34370.68486762047,33186.13390159607,1177.4779727458954,3.59602165222168,0.0 -97800,2.7094803,1.4531925,,,,,,,,,,,,,, -97900,2.5047264,1.2951188,,,,,,,,,,,,,, -98000,2.2835963,1.302565,,,,,,,,,,,,,, -98100,2.42183,1.4408382,,,,,,,,,,,,,, -98200,2.2929506,1.2793146,,,,,,,,,,,,,, -98300,2.3124669,1.3553747,,,,,,,,,,,,,, -98400,2.276359,1.3033208,,,,,,,,,,,,,, -98500,2.5397005,1.3334565,,,,,,,,,,,,,, -98600,2.5062742,1.4285865,,,,,,,,,,,,,, -98700,2.5991712,1.3658324,,,,,,,,,,,,,, -98800,2.5911577,1.416348,,,,,,,,,,,,,, -98900,2.3736436,1.3040509,,,,,,,,,,,,,, -99000,2.5866017,1.4350302,,,,,,,,,,,,,, -99100,2.6145074,1.387741,,,,,,,,,,,,,, -99200,2.5449355,1.3344476,,,,,,,,,,,,,, -99300,2.511642,1.3627928,,,,,,,,,,,,,, -99303,,,0.7777224183082581,0.8399938344955444,0.6829800009727478,1.2816671133041382,50000.0,0.5552000403404236,1.99160385131836,10000.0,33696.05191516876,34898.55397820473,33696.05191516876,1195.3270378112793,3.64474105834961,0.0 -99400,2.3697054,1.3515322,,,,,,,,,,,,,, -99500,2.3388674,1.3395226,,,,,,,,,,,,,, -99600,2.512536,1.3955929,,,,,,,,,,,,,, -99700,2.6464853,1.3598347,,,,,,,,,,,,,, -99800,2.3999689,1.2636659,,,,,,,,,,,,,, -99900,2.4708903,1.3924191,,,,,,,,,,,,,, -100000,2.3705287,1.2399534,,,,,,,,,,,,,, -100100,2.6187813,1.3573436,,,,,,,,,,,,,, -100200,2.5700831,1.3334786,,,,,,,,,,,,,, -100300,2.3235714,1.3145175,,,,,,,,,,,,,, -100400,2.6720326,1.340734,,,,,,,,,,,,,, -100500,2.4467034,1.284675,,,,,,,,,,,,,, -100600,2.582102,1.2378067,,,,,,,,,,,,,, -100700,2.290291,1.2953553,,,,,,,,,,,,,, -100800,2.3435144,1.29597,,,,,,,,,,,,,, -100809,,,0.7700493931770325,0.8601894378662109,0.6826599836349487,1.296310305595398,50000.0,0.55840003490448,2.0301966667175293,10000.0,34206.17853355408,35426.718936920166,34206.17853355408,1213.2627115249634,3.693167209625244,0.0 -100900,2.630988,1.3829432,,,,,,,,,,,,,, -101000,2.546227,1.3598294,,,,,,,,,,,,,, -101100,2.4772873,1.3621206,,,,,,,,,,,,,, -101200,2.4326406,1.29143,,,,,,,,,,,,,, -101300,2.6240625,1.2607125,,,,,,,,,,,,,, -101400,2.5219417,1.2881355,,,,,,,,,,,,,, -101500,2.5058088,1.3353746,,,,,,,,,,,,,, -101600,2.4870605,1.3524994,,,,,,,,,,,,,, -101700,2.3608449,1.4310871,,,,,,,,,,,,,, -101800,2.5031824,1.250985,,,,,,,,,,,,,, -101900,2.2573264,1.3163376,,,,,,,,,,,,,, -102000,2.5258563,1.3345009,,,,,,,,,,,,,, -102100,2.334783,1.3367586,,,,,,,,,,,,,, -102200,2.4469376,1.3631495,,,,,,,,,,,,,, -102300,2.732203,1.3588343,,,,,,,,,,,,,, -102315,,,0.7760483026504517,0.8299661874771118,0.6898399591445923,1.2503280639648438,50000.0,0.5603000521659851,1.9734432697296145,10000.0,34716.23765563965,35954.78728866577,34716.23765563965,1231.1673793792725,3.744328260421753,0.0 -102400,2.4364367,1.3092735,,,,,,,,,,,,,, -102500,2.2856214,1.2570139,,,,,,,,,,,,,, -102600,2.5363853,1.2834545,,,,,,,,,,,,,, -102700,2.3159564,1.287002,,,,,,,,,,,,,, -102800,2.6533163,1.3335882,,,,,,,,,,,,,, -102900,2.5060115,1.3172569,,,,,,,,,,,,,, -103000,2.5280814,1.2910179,,,,,,,,,,,,,, -103100,2.2988877,1.2195778,,,,,,,,,,,,,, -103200,2.2353585,1.3110642,,,,,,,,,,,,,, -103300,2.607772,1.3929551,,,,,,,,,,,,,, -103400,2.5885706,1.3165585,,,,,,,,,,,,,, -103500,2.4592967,1.3293537,,,,,,,,,,,,,, -103600,2.635284,1.3786427,,,,,,,,,,,,,, -103700,2.4357252,1.2304106,,,,,,,,,,,,,, -103800,2.3118744,1.2281874,,,,,,,,,,,,,, -103821,,,0.76761794090271,0.8790847659111023,0.685699999332428,1.2791662216186523,50000.0,0.5681000351905823,1.9737340211868288,10000.0,35226.270773649216,36482.83643579483,35226.270773649216,1249.0791852474213,3.794133901596069,0.0 -103900,2.4997044,1.2596463,,,,,,,,,,,,,, -104000,2.6101506,1.2067254,,,,,,,,,,,,,, -104100,2.6241293,1.2873694,,,,,,,,,,,,,, -104200,2.6064475,1.3449254,,,,,,,,,,,,,, -104300,2.3465974,1.3157223,,,,,,,,,,,,,, -104400,2.3805566,1.3199384,,,,,,,,,,,,,, -104500,3.1695137,1.3502983,,,,,,,,,,,,,, -104600,2.7188492,1.3936826,,,,,,,,,,,,,, -104700,2.390958,1.291368,,,,,,,,,,,,,, -104800,2.5702674,1.3338506,,,,,,,,,,,,,, -104900,2.6750727,1.324675,,,,,,,,,,,,,, -105000,2.3852494,1.1737351,,,,,,,,,,,,,, -105100,2.4588225,1.2465372,,,,,,,,,,,,,, -105200,2.6521385,1.4275244,,,,,,,,,,,,,, -105300,2.3354785,1.269702,,,,,,,,,,,,,, -105327,,,0.7800143361091614,0.810962438583374,0.6937800049781799,1.2509777545928955,50000.0,0.5646000504493713,1.997875094413757,10000.0,35736.3866379261,37011.45383000374,35736.3866379261,1267.4798917770386,3.841939449310303,0.0 -105400,2.6197078,1.3679029,,,,,,,,,,,,,, -105500,2.8859115,1.3660539,,,,,,,,,,,,,, -105600,2.5886984,1.289272,,,,,,,,,,,,,, -105700,2.5035856,1.301286,,,,,,,,,,,,,, -105800,2.461191,1.2927465,,,,,,,,,,,,,, -105900,2.3186688,1.2523313,,,,,,,,,,,,,, -106000,2.610374,1.3488957,,,,,,,,,,,,,, -106100,2.509365,1.2738822,,,,,,,,,,,,,, -106200,2.787187,1.2303779,,,,,,,,,,,,,, -106300,2.375981,1.2297574,,,,,,,,,,,,,, -106400,2.5687492,1.378655,,,,,,,,,,,,,, -106500,2.6857395,1.37227,,,,,,,,,,,,,, -106600,2.7733483,1.3602052,,,,,,,,,,,,,, -106700,2.6547346,1.3874006,,,,,,,,,,,,,, -106800,2.7206194,1.2638316,,,,,,,,,,,,,, -106833,,,0.8098094463348389,0.6974734663963318,0.6970799565315247,1.228915095329285,50000.0,0.5701000094413757,1.9629000425338743,10000.0,36246.66552352905,37539.52548789978,36246.66552352905,1285.1757638454435,3.8856961727142334,0.0 -106900,2.7610176,1.2750162,,,,,,,,,,,,,, -107000,2.4083633,1.2460176,,,,,,,,,,,,,, -107100,2.6866848,1.2817844,,,,,,,,,,,,,, -107200,2.816987,1.2433269,,,,,,,,,,,,,, -107300,2.4855943,1.2433122,,,,,,,,,,,,,, -107400,2.4865463,1.2669696,,,,,,,,,,,,,, -107500,2.554039,1.2074963,,,,,,,,,,,,,, -107600,2.4114652,1.2300681,,,,,,,,,,,,,, -107700,2.7706048,1.4150825,,,,,,,,,,,,,, -107800,2.4608314,1.3040321,,,,,,,,,,,,,, -107900,2.9401052,1.3397709,,,,,,,,,,,,,, -108000,2.6132104,1.355533,,,,,,,,,,,,,, -108100,2.5459414,1.3086948,,,,,,,,,,,,,, -108200,2.7674944,1.246991,,,,,,,,,,,,,, -108300,2.3982146,1.2365716,,,,,,,,,,,,,, -108339,,,0.8022361397743225,0.7281384468078613,0.6987000107765198,1.2137079238891602,50000.0,0.572700023651123,1.9371442794799805,10000.0,36756.7243309021,38067.54865336418,36756.7243309021,1303.0352365970612,3.935134649276733,0.0 -108400,2.7946854,1.3354119,,,,,,,,,,,,,, -108500,2.5097573,1.2262157,,,,,,,,,,,,,, -108600,2.678915,1.3332341,,,,,,,,,,,,,, -108700,2.7747447,1.3079405,,,,,,,,,,,,,, -108800,2.6727011,1.2939495,,,,,,,,,,,,,, -108900,2.716053,1.2377368,,,,,,,,,,,,,, -109000,2.8944445,1.350558,,,,,,,,,,,,,, -109100,2.7367923,1.3807924,,,,,,,,,,,,,, -109200,2.488265,1.2833561,,,,,,,,,,,,,, -109300,2.651826,1.3409536,,,,,,,,,,,,,, -109400,2.5264244,1.2507455,,,,,,,,,,,,,, -109500,2.5455158,1.2471858,,,,,,,,,,,,,, -109600,2.8835204,1.4147755,,,,,,,,,,,,,, -109700,2.6213217,1.2842287,,,,,,,,,,,,,, -109800,2.6006036,1.3022959,,,,,,,,,,,,,, -109845,,,0.795320451259613,0.7437325119972229,0.6987400054931641,1.229027509689331,50000.0,0.5742000341415405,1.9284789562225344,10000.0,37266.74565792084,38595.21840763092,37266.74565792084,1320.5831859111786,3.982501983642578,0.0 -109900,2.5456078,1.2794597,,,,,,,,,,,,,, -110000,2.7574682,1.3204255,,,,,,,,,,,,,, -110100,2.9253688,1.286449,,,,,,,,,,,,,, -110200,2.6573784,1.2795391,,,,,,,,,,,,,, -110300,2.7364345,1.1812873,,,,,,,,,,,,,, -110400,2.8931227,1.2994436,,,,,,,,,,,,,, -110500,2.6634998,1.2286268,,,,,,,,,,,,,, -110600,2.7143197,1.2431896,,,,,,,,,,,,,, -110700,2.4099681,1.3274869,,,,,,,,,,,,,, -110800,2.6773942,1.2339025,,,,,,,,,,,,,, -110900,2.8227098,1.3265158,,,,,,,,,,,,,, -111000,2.6942976,1.1731806,,,,,,,,,,,,,, -111100,2.7078729,1.1359663,,,,,,,,,,,,,, -111200,2.5997746,1.2875295,,,,,,,,,,,,,, -111300,2.7452369,1.2188342,,,,,,,,,,,,,, -111350,,,0.8005022406578064,0.7239665985107422,0.7046999931335449,1.1938434839248655,50000.0,0.5878000259399414,1.8825327157974243,10000.0,37776.84398150444,39122.9255297184,37776.84398150444,1338.08602809906,4.032949924468994,0.0 -111400,2.5219483,1.2777208,,,,,,,,,,,,,, -111500,2.6195047,1.2094352,,,,,,,,,,,,,, -111600,2.8403082,1.2614369,,,,,,,,,,,,,, -111700,2.8020566,1.2582384,,,,,,,,,,,,,, -111800,2.8124068,1.3407162,,,,,,,,,,,,,, -111900,2.7188423,1.2296151,,,,,,,,,,,,,, -112000,2.6553977,1.2590243,,,,,,,,,,,,,, -112100,2.5925996,1.2006619,,,,,,,,,,,,,, -112200,2.618013,1.317886,,,,,,,,,,,,,, -112300,2.5892708,1.2420743,,,,,,,,,,,,,, -112400,2.881387,1.2251205,,,,,,,,,,,,,, -112500,2.7644274,1.2880336,,,,,,,,,,,,,, -112600,2.6861727,1.3125119,,,,,,,,,,,,,, -112700,2.6127236,1.2514971,,,,,,,,,,,,,, -112800,2.53115,1.1932821,,,,,,,,,,,,,, -112857,,,0.7922711968421936,0.7695667147636414,0.7007399797439575,1.2252343893051147,50000.0,0.5761000514030457,1.942015290260315,10000.0,38287.077071905136,39651.77258491516,38287.077071905136,1356.5964069366455,4.082559108734131,0.0 -112900,2.7869778,1.2430022,,,,,,,,,,,,,, -113000,2.988069,1.3363509,,,,,,,,,,,,,, -113100,2.763745,1.1874828,,,,,,,,,,,,,, -113200,2.956376,1.2499229,,,,,,,,,,,,,, -113300,2.8666143,1.3432271,,,,,,,,,,,,,, -113400,2.6410599,1.1317545,,,,,,,,,,,,,, -113500,2.7906775,1.2189473,,,,,,,,,,,,,, -113600,2.907996,1.251494,,,,,,,,,,,,,, -113700,2.7563508,1.1742889,,,,,,,,,,,,,, -113800,2.9957447,1.2867472,,,,,,,,,,,,,, -113900,2.654952,1.1981359,,,,,,,,,,,,,, -114000,2.8041,1.1970298,,,,,,,,,,,,,, -114100,2.731928,1.2304873,,,,,,,,,,,,,, -114200,2.9069686,1.339539,,,,,,,,,,,,,, -114300,2.8351512,1.297698,,,,,,,,,,,,,, -114363,,,0.797293484210968,0.7358038425445557,0.7057999968528748,1.2084553241729736,50000.0,0.5761000514030457,1.938789129257202,10000.0,38796.97853899002,40179.66720294952,38796.97853899002,1374.487015724182,4.132639169692993,0.0 -114400,2.6763377,1.2502499,,,,,,,,,,,,,, -114500,2.8490574,1.1348222,,,,,,,,,,,,,, -114600,3.0177147,1.3336768,,,,,,,,,,,,,, -114700,2.8353896,1.1591021,,,,,,,,,,,,,, -114800,2.8199306,1.2103478,,,,,,,,,,,,,, -114900,2.8388958,1.1897724,,,,,,,,,,,,,, -115000,2.9384298,1.2262222,,,,,,,,,,,,,, -115100,2.735473,1.1947322,,,,,,,,,,,,,, -115200,2.597752,1.2650522,,,,,,,,,,,,,, -115300,2.591539,1.0938193,,,,,,,,,,,,,, -115400,3.1239865,1.3447722,,,,,,,,,,,,,, -115500,2.7524302,1.2034625,,,,,,,,,,,,,, -115600,2.6258924,1.2473464,,,,,,,,,,,,,, -115700,2.806462,1.3647419,,,,,,,,,,,,,, -115800,2.7545474,1.1845355,,,,,,,,,,,,,, -115869,,,0.8302773833274841,0.6240458488464355,0.703719973564148,1.2063405513763428,50000.0,0.5750000476837158,1.9542590379714968,10000.0,39306.96019792557,40707.69045972824,39306.96019792557,1392.427879571915,4.181373596191406,0.0 -115900,3.0425746,1.315463,,,,,,,,,,,,,, -116000,2.7597363,1.2003909,,,,,,,,,,,,,, -116100,2.836798,1.3407261,,,,,,,,,,,,,, -116200,2.7798717,1.2252446,,,,,,,,,,,,,, -116300,2.7918825,1.181803,,,,,,,,,,,,,, -116400,2.8058908,1.2169884,,,,,,,,,,,,,, -116500,2.6709049,1.1498799,,,,,,,,,,,,,, -116600,2.6430826,1.2350053,,,,,,,,,,,,,, -116700,2.830915,1.2520244,,,,,,,,,,,,,, -116800,2.9747145,1.201724,,,,,,,,,,,,,, -116900,2.974572,1.2415667,,,,,,,,,,,,,, -117000,2.605983,1.1062281,,,,,,,,,,,,,, -117100,2.9107451,1.1699023,,,,,,,,,,,,,, -117200,3.0009494,1.1833718,,,,,,,,,,,,,, -117300,2.968248,1.2619287,,,,,,,,,,,,,, -117375,,,0.8184191584587097,0.661994993686676,0.702019989490509,1.206328511238098,50000.0,0.5832000374794006,1.9158483743667605,10000.0,39817.018862485886,41235.641859054565,39817.018862485886,1410.2165460586548,4.232245206832886,0.0 -117400,2.9026356,1.2510735,,,,,,,,,,,,,, -117500,2.7433755,1.2017288,,,,,,,,,,,,,, -117600,2.8429213,1.2010553,,,,,,,,,,,,,, -117700,3.3214924,1.1917331,,,,,,,,,,,,,, -117800,2.940022,1.3140684,,,,,,,,,,,,,, -117900,2.950053,1.2397661,,,,,,,,,,,,,, -118000,2.977709,1.2698386,,,,,,,,,,,,,, -118100,2.8315818,1.2840773,,,,,,,,,,,,,, -118200,2.877135,1.2654033,,,,,,,,,,,,,, -118300,3.0923433,1.2586107,,,,,,,,,,,,,, -118400,2.957252,1.1034824,,,,,,,,,,,,,, -118500,2.7423887,1.2170471,,,,,,,,,,,,,, -118600,2.8491006,1.1579468,,,,,,,,,,,,,, -118700,2.854774,1.1746591,,,,,,,,,,,,,, -118800,3.2829435,1.16546,,,,,,,,,,,,,, -118881,,,0.8126594424247742,0.6735084652900696,0.7055599689483643,1.2065515518188477,50000.0,0.5807000398635864,1.93263578414917,10000.0,40327.03925204277,41763.26990413666,40327.03925204277,1427.716940164566,4.2863757610321045,0.0 -118900,2.9063525,1.2066964,,,,,,,,,,,,,, -119000,3.1129022,1.2269821,,,,,,,,,,,,,, -119100,2.9767942,1.2502648,,,,,,,,,,,,,, -119200,2.983767,1.1697298,,,,,,,,,,,,,, -119300,2.7237897,1.1956053,,,,,,,,,,,,,, -119400,3.142015,1.1758146,,,,,,,,,,,,,, -119500,2.8113303,1.2683448,,,,,,,,,,,,,, -119600,3.112259,1.2687565,,,,,,,,,,,,,, -119700,2.789264,1.1642109,,,,,,,,,,,,,, -119800,3.1446102,1.3200907,,,,,,,,,,,,,, -119900,3.012961,1.177908,,,,,,,,,,,,,, -120000,2.6804342,1.1614554,,,,,,,,,,,,,, -120100,3.023748,1.1997962,,,,,,,,,,,,,, -120200,2.8016818,1.198295,,,,,,,,,,,,,, -120300,3.269065,1.2045394,,,,,,,,,,,,,, -120387,,,0.8160474896430969,0.6632678508758545,0.7124399542808533,1.1698178052902222,50000.0,0.5871000289916992,1.8709394931793213,10000.0,40837.051486730576,42291.873514175415,40837.051486730576,1446.2023582458496,4.338989973068237,0.0 -120400,2.8637683,1.1396351,,,,,,,,,,,,,, -120500,2.8451111,1.1356736,,,,,,,,,,,,,, -120600,2.9075308,1.2258546,,,,,,,,,,,,,, -120700,3.1254368,1.205854,,,,,,,,,,,,,, -120800,2.900872,1.1407248,,,,,,,,,,,,,, -120900,2.759652,1.0693562,,,,,,,,,,,,,, -121000,2.9703526,1.2285151,,,,,,,,,,,,,, -121100,2.9624398,1.1385117,,,,,,,,,,,,,, -121200,2.9086034,1.0476122,,,,,,,,,,,,,, -121300,2.8436372,1.140287,,,,,,,,,,,,,, -121400,2.9626043,1.1760243,,,,,,,,,,,,,, -121500,2.7705827,1.1203489,,,,,,,,,,,,,, -121600,3.1341293,1.223269,,,,,,,,,,,,,, -121700,3.0301564,1.1638343,,,,,,,,,,,,,, -121800,3.2159624,1.2136503,,,,,,,,,,,,,, -121894,,,0.8153699040412903,0.6547273993492126,0.7078199982643127,1.1856532096862793,50000.0,0.5918000340461731,1.8801075220108032,10000.0,41347.19075655937,42819.61029410362,41347.19075655937,1463.6995224952698,4.385130882263184,0.0 -121900,2.6874914,1.0869479,,,,,,,,,,,,,, -122000,2.8772342,1.159986,,,,,,,,,,,,,, -122100,2.8938622,1.2137866,,,,,,,,,,,,,, -122200,2.8396802,1.1303924,,,,,,,,,,,,,, -122300,3.0849087,1.1655173,,,,,,,,,,,,,, -122400,3.1286845,1.1409049,,,,,,,,,,,,,, -122500,2.8672311,1.1395036,,,,,,,,,,,,,, -122600,3.1291416,1.1171045,,,,,,,,,,,,,, -122700,3.4238129,1.1924646,,,,,,,,,,,,,, -122800,3.1269686,1.1811082,,,,,,,,,,,,,, -122900,3.1789694,1.2086024,,,,,,,,,,,,,, -123000,3.0792508,1.2048607,,,,,,,,,,,,,, -123100,2.8740785,1.0761154,,,,,,,,,,,,,, -123200,3.203702,1.1405854,,,,,,,,,,,,,, -123300,2.9930978,1.1848321,,,,,,,,,,,,,, -123400,,,0.8108059167861938,0.6760615706443787,0.704539954662323,1.212332248687744,50000.0,0.5821000337600708,1.942033529281616,10000.0,41857.21672534943,43347.604243040085,41857.21672534943,1481.5602872371674,4.439347743988037,0.0 -123400,3.231952,1.2667814,,,,,,,,,,,,,, -123500,3.0437417,1.1440912,,,,,,,,,,,,,, -123600,3.1357603,1.0980574,,,,,,,,,,,,,, -123700,2.9268773,1.200327,,,,,,,,,,,,,, -123800,2.9461486,1.0888454,,,,,,,,,,,,,, -123900,3.731921,1.1841854,,,,,,,,,,,,,, -124000,2.9798179,1.1457896,,,,,,,,,,,,,, -124100,2.9408596,1.1837163,,,,,,,,,,,,,, -124200,3.0960977,1.119299,,,,,,,,,,,,,, -124300,3.0011837,1.0172309,,,,,,,,,,,,,, -124400,3.1043131,1.1991928,,,,,,,,,,,,,, -124500,3.1578119,1.1912954,,,,,,,,,,,,,, -124600,3.359602,1.2551858,,,,,,,,,,,,,, -124700,3.0316172,1.1276559,,,,,,,,,,,,,, -124800,3.0057833,1.1087494,,,,,,,,,,,,,, -124900,3.0290208,1.1986018,,,,,,,,,,,,,, -124906,,,0.8531568646430969,0.5148501992225647,0.7152799963951111,1.1633607149124146,50000.0,0.591200053691864,1.8790547847747805,10000.0,42367.14495229721,43875.18793177605,42367.14495229721,1499.112335205078,4.489536046981812,0.0 -125000,3.009891,1.2049227,,,,,,,,,,,,,, -125100,3.1958158,1.1778309,,,,,,,,,,,,,, -125200,2.8860283,1.0996366,,,,,,,,,,,,,, -125300,3.0765195,1.137986,,,,,,,,,,,,,, -125400,3.3104894,1.1175933,,,,,,,,,,,,,, -125500,3.1747684,1.227578,,,,,,,,,,,,,, -125600,3.1462243,1.137566,,,,,,,,,,,,,, -125700,3.2960455,1.1305606,,,,,,,,,,,,,, -125800,3.1881905,1.1994317,,,,,,,,,,,,,, -125900,2.8632402,0.99180406,,,,,,,,,,,,,, -126000,2.9778218,1.1161876,,,,,,,,,,,,,, -126100,3.0826843,0.98749495,,,,,,,,,,,,,, -126200,3.1237552,1.1176025,,,,,,,,,,,,,, -126300,3.1170561,1.1720845,,,,,,,,,,,,,, -126400,2.9326825,1.0858524,,,,,,,,,,,,,, -126412,,,0.8422951102256775,0.5681675672531128,0.7170599699020386,1.14886736869812,50000.0,0.5944000482559204,1.8772876262664795,10000.0,42877.105335474014,44402.84032511711,42877.105335474014,1516.701696395874,4.539835691452026,0.0 -126500,3.1416562,1.1752112,,,,,,,,,,,,,, -126600,2.9756048,1.0578909,,,,,,,,,,,,,, -126700,3.2931693,1.1281756,,,,,,,,,,,,,, -126800,3.2497485,1.1806505,,,,,,,,,,,,,, -126900,3.0296023,1.0526259,,,,,,,,,,,,,, -127000,3.2861917,1.0903046,,,,,,,,,,,,,, -127100,3.2073083,1.163051,,,,,,,,,,,,,, -127200,3.2073157,1.0965494,,,,,,,,,,,,,, -127300,3.2565022,1.1040213,,,,,,,,,,,,,, -127400,3.207407,1.1688174,,,,,,,,,,,,,, -127500,3.287132,1.1768548,,,,,,,,,,,,,, -127600,3.0745656,1.086448,,,,,,,,,,,,,, -127700,3.1064067,1.1070297,,,,,,,,,,,,,, -127800,2.891222,1.0525048,,,,,,,,,,,,,, -127900,3.343512,1.1548089,,,,,,,,,,,,,, -127918,,,0.8375916481018066,0.573846697807312,0.7154799699783325,1.1557798385620115,50000.0,0.5874000191688538,1.876875281333924,10000.0,43387.303060531616,44931.63086462021,43387.303060531616,1535.1918017864227,4.589509010314941,0.0 -128000,3.1459131,1.15138,,,,,,,,,,,,,, -128100,3.2998993,1.1682442,,,,,,,,,,,,,, -128200,3.1081033,1.0939732,,,,,,,,,,,,,, -128300,3.323738,1.1290683,,,,,,,,,,,,,, -128400,3.1167457,1.1040057,,,,,,,,,,,,,, -128500,3.101352,1.114776,,,,,,,,,,,,,, -128600,3.4306786,1.1045157,,,,,,,,,,,,,, -128700,3.0467343,1.030208,,,,,,,,,,,,,, -128800,3.1942081,1.1308439,,,,,,,,,,,,,, -128900,3.1282363,1.0580224,,,,,,,,,,,,,, -129000,3.109051,1.0602026,,,,,,,,,,,,,, -129100,3.0739355,1.02791,,,,,,,,,,,,,, -129200,3.2704413,1.137197,,,,,,,,,,,,,, -129300,3.3908398,1.1530871,,,,,,,,,,,,,, -129400,2.9556556,0.9923637,,,,,,,,,,,,,, -129424,,,0.8389867544174194,0.5688159465789795,0.7219199538230896,1.1433396339416504,50000.0,0.6003000140190125,1.853930711746216,10000.0,43897.30212020874,45459.47427010536,43897.30212020874,1552.941791057587,4.631393909454346,0.0 -129500,3.2055655,1.1553317,,,,,,,,,,,,,, -129600,3.1080954,1.173109,,,,,,,,,,,,,, -129700,3.3612657,1.1060307,,,,,,,,,,,,,, -129800,3.118898,0.9529387,,,,,,,,,,,,,, -129900,3.02477,1.0533905,,,,,,,,,,,,,, -130000,3.3490307,1.0885072,,,,,,,,,,,,,, -130100,2.8572054,1.0308218,,,,,,,,,,,,,, -130200,3.1198423,0.9836019,,,,,,,,,,,,,, -130300,3.2087014,1.0587003,,,,,,,,,,,,,, -130400,3.1049519,1.0139655,,,,,,,,,,,,,, -130500,3.0748916,1.0030581,,,,,,,,,,,,,, -130600,3.3265514,1.1155642,,,,,,,,,,,,,, -130700,3.0371907,1.0633783,,,,,,,,,,,,,, -130800,3.5205767,1.2009475,,,,,,,,,,,,,, -130900,3.4654078,1.1195217,,,,,,,,,,,,,, -130930,,,0.8404615521430969,0.56621253490448,0.7183199524879456,1.1442562341690063,50000.0,0.5950000286102295,1.887262225151062,10000.0,44407.29587292671,45987.14557003975,44407.29587292671,1570.5166273117063,4.68116021156311,0.0 -131000,3.2412486,1.0628388,,,,,,,,,,,,,, -131100,3.2249606,1.0523607,,,,,,,,,,,,,, -131200,3.069619,1.0354221,,,,,,,,,,,,,, -131300,3.1695468,1.1057069,,,,,,,,,,,,,, -131400,3.340523,1.140927,,,,,,,,,,,,,, -131500,3.257215,1.0543122,,,,,,,,,,,,,, -131600,3.4154696,1.1282281,,,,,,,,,,,,,, -131700,3.322503,1.0443977,,,,,,,,,,,,,, -131800,3.126458,0.9901531,,,,,,,,,,,,,, -131900,3.1000764,0.95649624,,,,,,,,,,,,,, -132000,3.0342689,1.0521402,,,,,,,,,,,,,, -132100,3.2829056,1.0060414,,,,,,,,,,,,,, -132200,3.3481913,1.0848489,,,,,,,,,,,,,, -132300,3.4584508,1.0716522,,,,,,,,,,,,,, -132400,3.097076,1.0291219,,,,,,,,,,,,,, -132436,,,0.8365752100944519,0.582433819770813,0.7178199887275696,1.1573113203048706,50000.0,0.5893000364303589,1.8838881254196167,10000.0,44917.199598789215,46514.73631954193,44917.199598789215,1588.0953686237335,4.735187530517578,0.0 -132500,3.314865,1.0738307,,,,,,,,,,,,,, -132600,3.196086,1.0461962,,,,,,,,,,,,,, -132700,3.0986373,1.0755563,,,,,,,,,,,,,, -132800,3.224455,1.0958661,,,,,,,,,,,,,, -132900,3.405965,1.0690398,,,,,,,,,,,,,, -133000,3.3926613,1.012504,,,,,,,,,,,,,, -133100,3.273484,1.0934552,,,,,,,,,,,,,, -133200,3.4258666,0.99207,,,,,,,,,,,,,, -133300,4.0402517,1.0848107,,,,,,,,,,,,,, -133400,3.414812,1.0544919,,,,,,,,,,,,,, -133500,3.3758397,1.0659096,,,,,,,,,,,,,, -133600,3.2691567,0.93902177,,,,,,,,,,,,,, -133700,3.2291393,1.0689778,,,,,,,,,,,,,, -133800,3.26289,0.96757394,,,,,,,,,,,,,, -133900,3.141155,0.9910269,,,,,,,,,,,,,, -133942,,,0.8755381107330322,0.4374611675739288,0.7226799726486206,1.1405651569366455,50000.0,0.596500039100647,1.8780877590179443,10000.0,45427.28427886963,47043.45353341103,45427.28427886963,1606.6192646026611,4.7905237674713135,0.0 -134000,3.365006,0.9991516,,,,,,,,,,,,,, -134100,3.4240549,1.0753335,,,,,,,,,,,,,, -134200,3.577089,0.9963289,,,,,,,,,,,,,, -134300,3.2789848,1.0133209,,,,,,,,,,,,,, -134400,3.265319,1.0459814,,,,,,,,,,,,,, -134500,3.302134,1.0365067,,,,,,,,,,,,,, -134600,3.5842993,1.0530398,,,,,,,,,,,,,, -134700,3.2535148,1.0188129,,,,,,,,,,,,,, -134800,3.5736806,1.0503044,,,,,,,,,,,,,, -134900,3.1442714,1.0629262,,,,,,,,,,,,,, -135000,3.4867597,1.119632,,,,,,,,,,,,,, -135100,3.5700152,1.0465057,,,,,,,,,,,,,, -135200,3.725942,1.0614905,,,,,,,,,,,,,, -135300,3.3751616,0.95905787,,,,,,,,,,,,,, -135400,3.179765,0.97602224,,,,,,,,,,,,,, -135448,,,0.8627630472183228,0.4806753098964691,0.7238199710845947,1.131712555885315,50000.0,0.5966000556945801,1.880648732185364,10000.0,45937.33114624024,47571.43300700188,45937.33114624024,1624.4392714500427,4.849410057067871,0.0 -135500,3.1732335,1.0414761,,,,,,,,,,,,,, -135600,3.2748797,1.0049995,,,,,,,,,,,,,, -135700,3.4827173,1.0588865,,,,,,,,,,,,,, -135800,3.53086,1.0603845,,,,,,,,,,,,,, -135900,3.5934784,1.0869865,,,,,,,,,,,,,, -136000,3.5262082,1.0916287,,,,,,,,,,,,,, -136100,3.1800091,0.9928757,,,,,,,,,,,,,, -136200,3.234956,1.0026428,,,,,,,,,,,,,, -136300,3.6529562,1.0857155,,,,,,,,,,,,,, -136400,3.593563,1.0556487,,,,,,,,,,,,,, -136500,3.6139145,1.0905145,,,,,,,,,,,,,, -136600,3.5272102,1.1434199,,,,,,,,,,,,,, -136700,3.4239616,1.0412287,,,,,,,,,,,,,, -136800,3.4033508,1.0893245,,,,,,,,,,,,,, -136900,3.3158336,0.9966346,,,,,,,,,,,,,, -136954,,,0.8583984375,0.4998326003551483,0.7249599695205688,1.131963849067688,50000.0,0.597000002861023,1.857327580451965,10000.0,46447.43666052818,48099.66242289543,46447.43666052818,1642.457855463028,4.900918006896973,0.0 -137000,3.1732516,0.9153453,,,,,,,,,,,,,, -137100,3.0761886,0.9683629,,,,,,,,,,,,,, -137200,3.172568,0.9436656,,,,,,,,,,,,,, -137300,3.8442767,0.98953664,,,,,,,,,,,,,, -137400,3.4797666,1.0364501,,,,,,,,,,,,,, -137500,3.4067578,1.0209353,,,,,,,,,,,,,, -137600,3.8026917,1.0861416,,,,,,,,,,,,,, -137700,3.4082246,0.9894577,,,,,,,,,,,,,, -137800,3.9012303,0.93308485,,,,,,,,,,,,,, -137900,3.5388017,1.0642354,,,,,,,,,,,,,, -138000,3.722795,1.0438683,,,,,,,,,,,,,, -138100,3.6384463,1.0108334,,,,,,,,,,,,,, -138200,3.586888,1.1370754,,,,,,,,,,,,,, -138300,3.620431,0.9081721,,,,,,,,,,,,,, -138400,3.5140035,1.0997665,,,,,,,,,,,,,, -138459,,,0.859773576259613,0.4849222302436828,0.7234799861907959,1.1349554061889648,50000.0,0.5968000292778015,1.881799340248108,10000.0,46957.42192101479,48627.31255126,46957.42192101479,1660.0084781646729,4.961538314819336,0.0 -138500,3.5323892,0.9432061,,,,,,,,,,,,,, -138600,3.2394178,0.95329076,,,,,,,,,,,,,, -138700,3.414103,1.0095495,,,,,,,,,,,,,, -138800,3.5664086,1.0051657,,,,,,,,,,,,,, -138900,3.6125154,0.9720714,,,,,,,,,,,,,, -139000,3.8387423,1.0809278,,,,,,,,,,,,,, -139100,3.628653,0.99365044,,,,,,,,,,,,,, -139200,3.2853231,0.95816827,,,,,,,,,,,,,, -139300,3.469027,0.92388517,,,,,,,,,,,,,, -139400,3.5325189,1.0773373,,,,,,,,,,,,,, -139500,3.7242615,0.9467163,,,,,,,,,,,,,, -139600,3.6301258,0.99612266,,,,,,,,,,,,,, -139700,3.5205922,1.0038645,,,,,,,,,,,,,, -139800,3.858015,1.0377425,,,,,,,,,,,,,, -139900,3.5749385,0.9944605,,,,,,,,,,,,,, -139966,,,0.8603116869926453,0.486088365316391,0.7274799942970276,1.1127536296844482,50000.0,0.6035000085830688,1.8506639003753664,10000.0,47467.60343265533,49155.43559336662,47467.60343265533,1677.843991279602,5.014599561691284,0.0 -140000,3.503114,0.85892403,,,,,,,,,,,,,, -140100,3.5024164,0.9684154,,,,,,,,,,,,,, -140200,3.4572117,0.9299828,,,,,,,,,,,,,, -140300,3.642031,0.97088873,,,,,,,,,,,,,, -140400,3.5504436,0.98617303,,,,,,,,,,,,,, -140500,3.5740159,0.9852618,,,,,,,,,,,,,, -140600,3.9508445,1.0509164,,,,,,,,,,,,,, -140700,3.5487332,0.87050676,,,,,,,,,,,,,, -140800,3.5042076,1.0123603,,,,,,,,,,,,,, -140900,3.6175282,0.9233029,,,,,,,,,,,,,, -141000,3.3928564,0.9349111,,,,,,,,,,,,,, -141100,3.3925235,0.98587304,,,,,,,,,,,,,, -141200,3.5677364,0.91106373,,,,,,,,,,,,,, -141300,3.3386228,0.9297342,,,,,,,,,,,,,, -141400,3.6391766,0.9340919,,,,,,,,,,,,,, -141472,,,0.8621651530265808,0.4798415601253509,0.7297399640083313,1.1351284980773926,50000.0,0.601900041103363,1.8711603879928589,10000.0,47977.59747934341,49683.41357302666,47977.59747934341,1695.720237493515,5.068366050720215,0.0 -141500,3.6050243,1.029397,,,,,,,,,,,,,, -141600,3.4090178,1.0011771,,,,,,,,,,,,,, -141700,3.7513664,0.97079325,,,,,,,,,,,,,, -141800,3.9169395,0.9987396,,,,,,,,,,,,,, -141900,3.6365616,0.99519396,,,,,,,,,,,,,, -142000,3.7324636,1.0243443,,,,,,,,,,,,,, -142100,3.5777671,0.92626643,,,,,,,,,,,,,, -142200,3.4017513,0.92465955,,,,,,,,,,,,,, -142300,3.7224352,0.9133482,,,,,,,,,,,,,, -142400,3.6693473,0.98813415,,,,,,,,,,,,,, -142500,3.3573823,0.89215624,,,,,,,,,,,,,, -142600,3.7104504,0.9612592,,,,,,,,,,,,,, -142700,3.6385152,0.97784066,,,,,,,,,,,,,, -142800,3.8607905,0.98334074,,,,,,,,,,,,,, -142900,3.7079577,1.0313892,,,,,,,,,,,,,, -142978,,,0.8984972834587097,0.3527859449386596,0.7320599555969238,1.1209580898284912,50000.0,0.6035000085830688,1.8868118524551392,10000.0,48487.5690472126,50211.66066980362,48487.5690472126,1713.8870961666107,5.12295126914978,0.0 -143000,4.247408,0.9419458,,,,,,,,,,,,,, -143100,3.5012047,0.91462237,,,,,,,,,,,,,, -143200,3.5858436,0.9421499,,,,,,,,,,,,,, -143300,3.8877769,0.9969387,,,,,,,,,,,,,, -143400,3.8724914,0.9611496,,,,,,,,,,,,,, -143500,3.8687773,0.93584543,,,,,,,,,,,,,, -143600,4.028435,1.0161809,,,,,,,,,,,,,, -143700,3.9032047,1.0054295,,,,,,,,,,,,,, -143800,3.6404967,0.927847,,,,,,,,,,,,,, -143900,3.6424985,0.8658445,,,,,,,,,,,,,, -144000,3.7148814,0.96183777,,,,,,,,,,,,,, -144100,3.5103445,0.972175,,,,,,,,,,,,,, -144200,3.6147184,0.87404406,,,,,,,,,,,,,, -144300,3.6182783,0.9055232,,,,,,,,,,,,,, -144400,3.9396632,1.0475174,,,,,,,,,,,,,, -144484,,,0.8881736397743225,0.3926560580730438,0.7305600047111511,1.1178728342056274,50000.0,0.6080000400543213,1.855978012084961,10000.0,48997.5791516304,50739.464485645294,48997.5791516304,1731.585001707077,5.165935516357422,0.0 -144500,3.4779415,0.8934869,,,,,,,,,,,,,, -144600,3.6346405,0.94678247,,,,,,,,,,,,,, -144700,4.101165,0.9798648,,,,,,,,,,,,,, -144800,3.8899019,0.89828235,,,,,,,,,,,,,, -144900,3.8836353,1.022845,,,,,,,,,,,,,, -145000,4.091866,0.95406485,,,,,,,,,,,,,, -145100,3.854465,0.8856487,,,,,,,,,,,,,, -145200,3.8466463,0.941064,,,,,,,,,,,,,, -145300,4.0646324,1.0590608,,,,,,,,,,,,,, -145400,3.9453323,0.9649959,,,,,,,,,,,,,, -145500,3.6147466,0.8655964,,,,,,,,,,,,,, -145600,3.8962886,0.9449609,,,,,,,,,,,,,, -145700,3.8636324,0.9263992,,,,,,,,,,,,,, -145800,3.5468745,0.8845706,,,,,,,,,,,,,, -145900,3.9224482,0.9521331,,,,,,,,,,,,,, -145989,,,0.8858019709587097,0.3927642703056335,0.7334799766540527,1.106014609336853,50000.0,0.6055999994277954,1.8426967859268188,10000.0,49507.801383018494,51267.6076464653,49507.801383018494,1749.398912668228,5.220576286315918,0.0 -146000,4.258169,0.95772684,,,,,,,,,,,,,, -146100,3.5900419,0.8359331,,,,,,,,,,,,,, -146200,3.7056599,0.9702286,,,,,,,,,,,,,, -146300,4.0869756,0.88408625,,,,,,,,,,,,,, -146400,3.8354824,0.9810562,,,,,,,,,,,,,, -146500,3.774503,0.9017654,,,,,,,,,,,,,, -146600,3.8545399,0.9122367,,,,,,,,,,,,,, -146700,3.7425177,0.8883434,,,,,,,,,,,,,, -146800,3.8539023,0.9126118,,,,,,,,,,,,,, -146900,3.7971687,0.922004,,,,,,,,,,,,,, -147000,3.689243,0.86750704,,,,,,,,,,,,,, -147100,3.5573251,0.8774481,,,,,,,,,,,,,, -147200,4.0021615,0.90258855,,,,,,,,,,,,,, -147300,4.078206,0.9879017,,,,,,,,,,,,,, -147400,3.9264648,0.8740085,,,,,,,,,,,,,, -147495,,,0.8859414458274841,0.3913815915584564,0.7366200089454651,1.1027542352676392,50000.0,0.6131000518798828,1.8579843044281008,10000.0,50017.77842974663,51795.50549149513,50017.77842974663,1767.210108757019,5.27645468711853,0.0 -147500,3.737522,0.8808248,,,,,,,,,,,,,, -147600,3.9673972,0.79135495,,,,,,,,,,,,,, -147700,3.8366106,0.8947325,,,,,,,,,,,,,, -147800,3.6522436,0.84814286,,,,,,,,,,,,,, -147900,4.196546,0.8942234,,,,,,,,,,,,,, -148000,4.0206833,0.9060215,,,,,,,,,,,,,, -148100,3.8380063,0.84995145,,,,,,,,,,,,,, -148200,3.8550858,0.9826666,,,,,,,,,,,,,, -148300,3.5876386,0.87199515,,,,,,,,,,,,,, -148400,3.990086,0.9362303,,,,,,,,,,,,,, -148500,3.9637253,0.88301367,,,,,,,,,,,,,, -148600,4.642593,1.0045732,,,,,,,,,,,,,, -148700,3.8980467,0.8572177,,,,,,,,,,,,,, -148800,3.8776639,0.8813505,,,,,,,,,,,,,, -148900,4.2883925,0.9044811,,,,,,,,,,,,,, -149000,3.632313,0.80873317,,,,,,,,,,,,,, -149001,,,0.8880141973495483,0.386345237493515,0.735319972038269,1.107372164726257,50000.0,0.6097000241279602,1.835785865783692,10000.0,50527.86081528664,52323.51455950737,50527.86081528664,1785.0233714580536,5.3356263637542725,0.0 -149100,3.9843426,0.899948,,,,,,,,,,,,,, -149200,4.310929,0.95478463,,,,,,,,,,,,,, -149300,3.7963665,0.9050162,,,,,,,,,,,,,, -149400,3.9844894,0.9285204,,,,,,,,,,,,,, -149500,4.004958,0.8865826,,,,,,,,,,,,,, -149600,4.0297346,0.9037757,,,,,,,,,,,,,, -149700,3.7945075,0.87525034,,,,,,,,,,,,,, -149800,3.8819554,0.9742726,,,,,,,,,,,,,, -149900,3.865741,0.9045127,,,,,,,,,,,,,, -150000,3.9246335,0.8346295,,,,,,,,,,,,,, -150100,3.971066,0.8375265,,,,,,,,,,,,,, -150200,3.5131772,0.78298163,,,,,,,,,,,,,, -150300,4.165841,0.92852813,,,,,,,,,,,,,, -150400,3.9890385,0.8616191,,,,,,,,,,,,,, -150500,3.9981112,0.85712016,,,,,,,,,,,,,, -150507,,,0.8951889276504517,0.3608154356479645,0.7378399968147278,1.096303939819336,50000.0,0.6142000555992126,1.8414394855499268,10000.0,51037.921318769455,52851.49748301506,51037.921318769455,1802.831404209137,5.39626669883728,0.0 -150600,3.9732084,0.8566242,,,,,,,,,,,,,, -150700,3.8572135,0.81272626,,,,,,,,,,,,,, -150800,4.057274,0.9683213,,,,,,,,,,,,,, -150900,4.072985,0.86117864,,,,,,,,,,,,,, -151000,3.864227,0.8496556,,,,,,,,,,,,,, -151100,3.857935,0.8271589,,,,,,,,,,,,,, -151200,4.0151505,0.8975184,,,,,,,,,,,,,, -151300,4.0602183,0.83861536,,,,,,,,,,,,,, -151400,3.9558907,0.74993044,,,,,,,,,,,,,, -151500,4.091678,0.86129,,,,,,,,,,,,,, -151600,4.0726247,0.8091113,,,,,,,,,,,,,, -151700,4.3615756,0.9073189,,,,,,,,,,,,,, -151800,3.8393447,0.7873346,,,,,,,,,,,,,, -151900,4.0421305,0.8751114,,,,,,,,,,,,,, -152000,4.1443505,0.8110999,,,,,,,,,,,,,, -152013,,,0.9246850609779358,0.2611102759838104,0.7400599718093872,1.1027374267578125,50000.0,0.6211000084877014,1.8491344451904297,10000.0,51547.98945236206,53379.18301439285,51547.98945236206,1820.3468084335327,5.445066690444946,0.0 -152100,3.8293016,0.77990884,,,,,,,,,,,,,, -152200,3.8139513,0.78895324,,,,,,,,,,,,,, -152300,4.0279617,0.8216721,,,,,,,,,,,,,, -152400,3.9901752,0.7596591,,,,,,,,,,,,,, -152500,3.9626603,0.86294025,,,,,,,,,,,,,, -152600,4.218234,0.8765845,,,,,,,,,,,,,, -152700,4.495047,0.9405791,,,,,,,,,,,,,, -152800,4.1170416,0.91510165,,,,,,,,,,,,,, -152900,4.4812255,0.8754558,,,,,,,,,,,,,, -153000,4.0069885,0.8120467,,,,,,,,,,,,,, -153100,3.9063282,0.75938565,,,,,,,,,,,,,, -153200,4.102982,0.84217614,,,,,,,,,,,,,, -153300,4.1972275,0.8058156,,,,,,,,,,,,,, -153400,4.1188946,0.83278704,,,,,,,,,,,,,, -153500,4.0602846,0.8738629,,,,,,,,,,,,,, -153519,,,0.913683831691742,0.2980218827724457,0.7379800081253052,1.1061500310897827,50000.0,0.6139000058174133,1.8494164943695068,10000.0,52058.14172434807,53907.217074632645,52058.14172434807,1838.1192009449005,5.501033782958984,0.0 -153600,4.079966,0.8627415,,,,,,,,,,,,,, -153700,3.9830992,0.7875022,,,,,,,,,,,,,, -153800,4.0503287,0.828897,,,,,,,,,,,,,, -153900,3.8040726,0.79049635,,,,,,,,,,,,,, -154000,4.385111,0.82620066,,,,,,,,,,,,,, -154100,3.935329,0.75483793,,,,,,,,,,,,,, -154200,4.285557,0.8732061,,,,,,,,,,,,,, -154300,3.977769,0.7306315,,,,,,,,,,,,,, -154400,3.9394746,0.77412754,,,,,,,,,,,,,, -154500,4.128496,0.8106868,,,,,,,,,,,,,, -154600,4.123727,0.8614019,,,,,,,,,,,,,, -154700,4.066242,0.81398666,,,,,,,,,,,,,, -154800,3.8049889,0.79660887,,,,,,,,,,,,,, -154900,4.421864,0.80827117,,,,,,,,,,,,,, -155000,4.3795075,0.8684125,,,,,,,,,,,,,, -155025,,,0.9156568646430968,0.2907183766365051,0.7414599657058716,1.100637435913086,50000.0,0.617400050163269,1.840550422668457,10000.0,52568.189259290695,54435.08118414879,52568.189259290695,1855.824599981308,5.557091951370239,0.0 -155100,4.187324,0.85682935,,,,,,,,,,,,,, -155200,4.3093886,0.8789415,,,,,,,,,,,,,, -155300,3.9844787,0.79777396,,,,,,,,,,,,,, -155400,3.9992745,0.88476694,,,,,,,,,,,,,, -155500,3.8962874,0.7817892,,,,,,,,,,,,,, -155600,4.3070073,0.8310146,,,,,,,,,,,,,, -155700,4.2296567,0.8685244,,,,,,,,,,,,,, -155800,4.0580683,0.80334765,,,,,,,,,,,,,, -155900,4.557595,0.93313366,,,,,,,,,,,,,, -156000,4.2397914,0.8013913,,,,,,,,,,,,,, -156100,4.1443143,0.8103236,,,,,,,,,,,,,, -156200,4.3261075,0.890663,,,,,,,,,,,,,, -156300,4.1347404,0.7022103,,,,,,,,,,,,,, -156400,4.101625,0.776937,,,,,,,,,,,,,, -156500,3.9542325,0.8335361,,,,,,,,,,,,,, -156531,,,0.9124282598495485,0.294901579618454,0.7405799627304077,1.1055474281311035,50000.0,0.6130000352859497,1.8647652864456177,10000.0,53078.3854739666,54962.95487737656,53078.3854739666,1873.38943362236,5.6150617599487305,0.0 -156600,3.9263446,0.6860087,,,,,,,,,,,,,, -156700,4.572378,0.7849472,,,,,,,,,,,,,, -156800,4.009138,0.78369844,,,,,,,,,,,,,, -156900,3.9696019,0.70426816,,,,,,,,,,,,,, -157000,3.6688278,0.6806907,,,,,,,,,,,,,, -157100,4.0285335,0.78142864,,,,,,,,,,,,,, -157200,3.8126807,0.7778506,,,,,,,,,,,,,, -157300,4.031209,0.7717186,,,,,,,,,,,,,, -157400,4.38916,0.8575461,,,,,,,,,,,,,, -157500,4.005147,0.76476324,,,,,,,,,,,,,, -157600,4.306368,0.81445014,,,,,,,,,,,,,, -157700,4.337109,0.76824826,,,,,,,,,,,,,, -157800,4.123522,0.86308205,,,,,,,,,,,,,, -157900,4.3345747,0.82703125,,,,,,,,,,,,,, -158000,4.27046,0.74141365,,,,,,,,,,,,,, -158037,,,0.91796875,0.2793106138706207,0.7449199557304382,1.0848394632339478,50000.0,0.6182000041007996,1.8408807516098025,10000.0,53588.38181400299,55491.00404858589,53588.38181400299,1891.319774627685,5.684006452560425,0.0 -158100,4.377762,0.77032024,,,,,,,,,,,,,, -158200,4.092824,0.7394189,,,,,,,,,,,,,, -158300,4.4362054,0.7509594,,,,,,,,,,,,,, -158400,4.1064105,0.71390224,,,,,,,,,,,,,, -158500,4.262589,0.82715744,,,,,,,,,,,,,, -158600,4.3896065,0.80167514,,,,,,,,,,,,,, -158700,4.1270223,0.8549156,,,,,,,,,,,,,, -158800,3.9909337,0.7668411,,,,,,,,,,,,,, -158900,4.2682734,0.7877976,,,,,,,,,,,,,, -159000,4.1132164,0.7111068,,,,,,,,,,,,,, -159100,4.744786,0.77223676,,,,,,,,,,,,,, -159200,4.098203,0.7639824,,,,,,,,,,,,,, -159300,3.890032,0.7072542,,,,,,,,,,,,,, -159400,3.9451365,0.770999,,,,,,,,,,,,,, -159500,4.505579,0.7132619,,,,,,,,,,,,,, -159543,,,0.9270168542861938,0.254503846168518,0.7486400008201599,1.0716850757598877,50000.0,0.6185000538825989,1.831809043884277,10000.0,54098.50211858749,56018.90651440621,54098.50211858749,1909.000571727753,5.731633424758911,0.0 -159600,4.445966,0.811118,,,,,,,,,,,,,, -159700,4.357022,0.77889585,,,,,,,,,,,,,, -159800,4.314091,0.7758924,,,,,,,,,,,,,, -159900,4.363389,0.74519813,,,,,,,,,,,,,, -160000,3.8989048,0.65580285,,,,,,,,,,,,,, -160100,4.280546,0.70177025,,,,,,,,,,,,,, -160200,4.285046,0.7349766,,,,,,,,,,,,,, -160300,4.0880327,0.7275105,,,,,,,,,,,,,, -160400,4.51521,0.75594765,,,,,,,,,,,,,, -160500,4.3937473,0.75188315,,,,,,,,,,,,,, -160600,4.002724,0.6575159,,,,,,,,,,,,,, -160700,4.2527785,0.8226657,,,,,,,,,,,,,, -160800,4.598519,0.7780782,,,,,,,,,,,,,, -160900,4.181983,0.6944938,,,,,,,,,,,,,, -161000,4.223044,0.75806916,,,,,,,,,,,,,, -161048,,,0.9447743892669678,0.2013452053070068,0.7463200092315674,1.0841114521026611,50000.0,0.6205000281333923,1.843653440475464,10000.0,54608.54653549194,56546.94776725769,54608.54653549194,1926.8864195346832,5.78792929649353,0.0 -161100,4.7462077,0.8164015,,,,,,,,,,,,,, -161200,4.2111154,0.76044464,,,,,,,,,,,,,, -161300,3.9947007,0.76083595,,,,,,,,,,,,,, -161400,4.198601,0.67080104,,,,,,,,,,,,,, -161500,4.496291,0.7094698,,,,,,,,,,,,,, -161600,4.4164987,0.76365554,,,,,,,,,,,,,, -161700,4.310314,0.6919793,,,,,,,,,,,,,, -161800,4.2351084,0.6966426,,,,,,,,,,,,,, -161900,4.650704,0.70914644,,,,,,,,,,,,,, -162000,4.1753335,0.759001,,,,,,,,,,,,,, -162100,4.345852,0.7682604,,,,,,,,,,,,,, -162200,3.914234,0.72272265,,,,,,,,,,,,,, -162300,4.4103956,0.82770073,,,,,,,,,,,,,, -162400,4.396818,0.7888949,,,,,,,,,,,,,, -162500,4.1774592,0.7521914,,,,,,,,,,,,,, -162553,,,0.942163586616516,0.2034819424152374,0.7483599781990051,1.0738941431045532,50000.0,0.6206000447273254,1.835035800933838,10000.0,55118.47389388085,57074.64824795723,55118.47389388085,1944.548994064331,5.8435704708099365,0.0 -162600,4.446496,0.7275636,,,,,,,,,,,,,, -162700,4.5863414,0.8585326,,,,,,,,,,,,,, -162800,4.3925333,0.6643535,,,,,,,,,,,,,, -162900,4.8899817,0.7696624,,,,,,,,,,,,,, -163000,4.239535,0.7380306,,,,,,,,,,,,,, -163100,4.359884,0.71476763,,,,,,,,,,,,,, -163200,4.5349717,0.7509688,,,,,,,,,,,,,, -163300,4.32556,0.7167301,,,,,,,,,,,,,, -163400,4.192103,0.74727553,,,,,,,,,,,,,, -163500,4.236754,0.7991405,,,,,,,,,,,,,, -163600,4.555582,0.7578461,,,,,,,,,,,,,, -163700,4.129241,0.7484554,,,,,,,,,,,,,, -163800,4.5929985,0.69362754,,,,,,,,,,,,,, -163900,4.2665334,0.7395704,,,,,,,,,,,,,, -164000,4.337465,0.69617075,,,,,,,,,,,,,, -164059,,,0.9382772445678712,0.2111508697271347,0.7485199570655823,1.0775885581970217,50000.0,0.6212000250816345,1.8518248796463013,10000.0,55628.56163620949,57602.503945589066,55628.56163620949,1962.2046110630035,5.903212070465088,0.0 -164100,4.146561,0.7283324,,,,,,,,,,,,,, -164200,4.5201206,0.71942127,,,,,,,,,,,,,, -164300,4.406873,0.7250798,,,,,,,,,,,,,, -164400,4.6735272,0.7619637,,,,,,,,,,,,,, -164500,4.330656,0.7546522,,,,,,,,,,,,,, -164600,4.372152,0.7295996,,,,,,,,,,,,,, -164700,4.3459797,0.68056977,,,,,,,,,,,,,, -164800,4.6351867,0.7234824,,,,,,,,,,,,,, -164900,4.095115,0.6620742,,,,,,,,,,,,,, -165000,4.514458,0.7144309,,,,,,,,,,,,,, -165100,4.361974,0.7672417,,,,,,,,,,,,,, -165200,4.0280538,0.67114794,,,,,,,,,,,,,, -165300,4.4158425,0.7220127,,,,,,,,,,,,,, -165400,4.507141,0.724824,,,,,,,,,,,,,, -165500,4.4868245,0.69546264,,,,,,,,,,,,,, -165565,,,0.9431201815605164,0.2024757117033004,0.7490999698638916,1.07507061958313,50000.0,0.6224000453948975,1.850699543952942,10000.0,56138.76058793068,58130.41319346428,56138.76058793068,1979.8018288612368,5.963002681732178,0.0 -165600,4.6005826,0.7218827,,,,,,,,,,,,,, -165700,4.550269,0.70244056,,,,,,,,,,,,,, -165800,4.297403,0.72935236,,,,,,,,,,,,,, -165900,4.3114395,0.65488404,,,,,,,,,,,,,, -166000,4.458774,0.6480154,,,,,,,,,,,,,, -166100,4.1761107,0.63891065,,,,,,,,,,,,,, -166200,4.590182,0.69330215,,,,,,,,,,,,,, -166300,4.7030945,0.7005473,,,,,,,,,,,,,, -166400,4.4594283,0.7073255,,,,,,,,,,,,,, -166500,3.9871626,0.64880246,,,,,,,,,,,,,, -166600,4.54482,0.67059445,,,,,,,,,,,,,, -166700,4.7253294,0.71380174,,,,,,,,,,,,,, -166800,4.293394,0.666553,,,,,,,,,,,,,, -166900,4.5908623,0.62041706,,,,,,,,,,,,,, -167000,4.042728,0.6398359,,,,,,,,,,,,,, -167070,,,0.9422233700752258,0.2018176466226577,0.7506600022315979,1.0700610876083374,50000.0,0.6238000392913818,1.8532116413116453,10000.0,56648.74323391914,58658.53840446472,56648.74323391914,1997.844202041626,6.011024236679077,0.0 -167100,4.241358,0.71674407,,,,,,,,,,,,,, -167200,4.376945,0.73801804,,,,,,,,,,,,,, -167300,4.251634,0.65952104,,,,,,,,,,,,,, -167400,4.1096687,0.68210316,,,,,,,,,,,,,, -167500,4.2272925,0.66605747,,,,,,,,,,,,,, -167600,4.2295866,0.6672428,,,,,,,,,,,,,, -167700,4.838688,0.74312687,,,,,,,,,,,,,, -167800,4.3287044,0.70922804,,,,,,,,,,,,,, -167900,4.2778454,0.67375624,,,,,,,,,,,,,, -168000,4.245273,0.64091057,,,,,,,,,,,,,, -168100,4.3727417,0.6848504,,,,,,,,,,,,,, -168200,4.6103826,0.643297,,,,,,,,,,,,,, -168300,3.8992972,0.6102895,,,,,,,,,,,,,, -168400,4.542838,0.7518118,,,,,,,,,,,,,, -168500,4.4762883,0.6514375,,,,,,,,,,,,,, -168576,,,0.9480428695678712,0.1823372691869735,0.7523199915885925,1.0641475915908811,50000.0,0.625700056552887,1.847376108169556,10000.0,57158.81772947312,59186.4742205143,57158.81772947312,2015.589801311493,6.071972846984863,0.0 -168600,4.573614,0.74581337,,,,,,,,,,,,,, -168700,4.8047447,0.7299622,,,,,,,,,,,,,, -168800,4.492614,0.6963788,,,,,,,,,,,,,, -168900,4.2053885,0.6552119,,,,,,,,,,,,,, -169000,4.160291,0.64843065,,,,,,,,,,,,,, -169100,3.884494,0.58296776,,,,,,,,,,,,,, -169200,4.4789004,0.64262176,,,,,,,,,,,,,, -169300,3.9195356,0.682288,,,,,,,,,,,,,, -169400,4.451035,0.77415085,,,,,,,,,,,,,, -169500,4.6400924,0.6388733,,,,,,,,,,,,,, -169600,4.0794063,0.6986841,,,,,,,,,,,,,, -169700,4.7894073,0.7212014,,,,,,,,,,,,,, -169800,4.2728586,0.63175094,,,,,,,,,,,,,, -169900,4.651647,0.75248396,,,,,,,,,,,,,, -170000,4.5175214,0.71499187,,,,,,,,,,,,,, -170082,,,0.9587252736091614,0.1586360782384872,0.7521199584007263,1.072721004486084,50000.0,0.6253000497817993,1.8547792434692385,10000.0,57668.98382782936,59714.29183888435,57668.98382782936,2033.127522945404,6.13216757774353,0.0 -170100,4.358362,0.63390636,,,,,,,,,,,,,, -170200,4.0230722,0.6488799,,,,,,,,,,,,,, -170300,4.6661,0.69784904,,,,,,,,,,,,,, -170400,4.4140406,0.6489035,,,,,,,,,,,,,, -170500,4.1606865,0.65107024,,,,,,,,,,,,,, -170600,4.2902007,0.6992063,,,,,,,,,,,,,, -170700,4.204056,0.62907493,,,,,,,,,,,,,, -170800,4.631465,0.6394198,,,,,,,,,,,,,, -170900,4.5624146,0.63320166,,,,,,,,,,,,,, -171000,4.278925,0.61412644,,,,,,,,,,,,,, -171100,4.377981,0.6648495,,,,,,,,,,,,,, -171200,4.0964375,0.60247886,,,,,,,,,,,,,, -171300,4.5383925,0.68620497,,,,,,,,,,,,,, -171400,4.194836,0.5473083,,,,,,,,,,,,,, -171500,4.389491,0.6507511,,,,,,,,,,,,,, -171588,,,0.9563336968421936,0.1618997901678085,0.751039981842041,1.0672721862792969,50000.0,0.6230000257492065,1.851597547531128,10000.0,58179.04970908165,60242.26770663261,58179.04970908165,2050.924390077591,6.191601753234863,0.0 -171600,4.265093,0.6393514,,,,,,,,,,,,,, -171700,4.37125,0.7135152,,,,,,,,,,,,,, -171800,4.562885,0.6028528,,,,,,,,,,,,,, -171900,4.639116,0.61960685,,,,,,,,,,,,,, -172000,4.396407,0.64679,,,,,,,,,,,,,, -172100,4.7995677,0.652212,,,,,,,,,,,,,, -172200,4.906142,0.68603677,,,,,,,,,,,,,, -172300,4.4557486,0.6501588,,,,,,,,,,,,,, -172400,4.4967566,0.6156101,,,,,,,,,,,,,, -172500,4.802207,0.6846458,,,,,,,,,,,,,, -172600,4.324065,0.62630856,,,,,,,,,,,,,, -172700,4.2636604,0.6111213,,,,,,,,,,,,,, -172800,4.13008,0.60515463,,,,,,,,,,,,,, -172900,4.485752,0.6708732,,,,,,,,,,,,,, -173000,4.134788,0.58817464,,,,,,,,,,,,,, -173094,,,0.9559350609779358,0.161485806107521,0.7530199885368347,1.0638294219970703,50000.0,0.6283000111579895,1.836564540863037,10000.0,58689.24033522606,60771.73478055,58689.24033522606,2070.0848445892334,6.254953384399414,0.0 -173100,4.160212,0.6415992,,,,,,,,,,,,,, -173200,4.2885466,0.6645707,,,,,,,,,,,,,, -173300,4.699148,0.7096751,,,,,,,,,,,,,, -173400,4.254508,0.56055677,,,,,,,,,,,,,, -173500,4.3700976,0.6303349,,,,,,,,,,,,,, -173600,4.6341586,0.65847397,,,,,,,,,,,,,, -173700,4.61932,0.69750816,,,,,,,,,,,,,, -173800,5.0065694,0.70278305,,,,,,,,,,,,,, -173900,4.670669,0.66771555,,,,,,,,,,,,,, -174000,4.573702,0.63372815,,,,,,,,,,,,,, -174100,4.659018,0.66495633,,,,,,,,,,,,,, -174200,4.3829665,0.6634216,,,,,,,,,,,,,, -174300,4.5127196,0.68398273,,,,,,,,,,,,,, -174400,4.1906853,0.64085287,,,,,,,,,,,,,, -174500,4.91076,0.6856111,,,,,,,,,,,,,, -174599,,,0.9558952450752258,0.1617143154144287,0.7538999915122986,1.0610108375549316,50000.0,0.6271000504493713,1.835718870162964,10000.0,59199.163927316666,61299.51197075844,59199.163927316666,2087.834163427353,6.306419849395752,0.0 -174600,4.380221,0.69199795,,,,,,,,,,,,,, -174700,4.47241,0.67756754,,,,,,,,,,,,,, -174800,4.4528146,0.65057784,,,,,,,,,,,,,, -174900,4.488529,0.64507186,,,,,,,,,,,,,, -175000,4.576517,0.68997234,,,,,,,,,,,,,, -175100,4.600994,0.6628512,,,,,,,,,,,,,, -175200,4.6470065,0.6435485,,,,,,,,,,,,,, -175300,4.7661347,0.7373832,,,,,,,,,,,,,, -175400,4.4601545,0.613933,,,,,,,,,,,,,, -175500,4.799649,0.680072,,,,,,,,,,,,,, -175600,4.508797,0.6112422,,,,,,,,,,,,,, -175700,4.075836,0.6471265,,,,,,,,,,,,,, -175800,4.385288,0.6202572,,,,,,,,,,,,,, -175900,4.964283,0.7550105,,,,,,,,,,,,,, -176000,4.9464474,0.72908485,,,,,,,,,,,,,, -176100,4.437423,0.60784274,,,,,,,,,,,,,, -176104,,,0.9573700428009032,0.1562838703393936,0.7552399635314941,1.0589247941970823,50000.0,0.6297000050544739,1.8355635404586792,10000.0,59709.09991669655,61827.27109313011,59709.09991669655,2105.543501853943,6.366549730300903,0.0 -176200,4.630244,0.6555191,,,,,,,,,,,,,, -176300,4.6056046,0.6437645,,,,,,,,,,,,,, -176400,5.0404973,0.6595234,,,,,,,,,,,,,, -176500,4.5698714,0.66868055,,,,,,,,,,,,,, -176600,4.341032,0.5942122,,,,,,,,,,,,,, -176700,4.935252,0.71878093,,,,,,,,,,,,,, -176800,4.343473,0.62222517,,,,,,,,,,,,,, -176900,4.516055,0.66965187,,,,,,,,,,,,,, -177000,4.5543013,0.6595895,,,,,,,,,,,,,, -177100,4.917612,0.7591307,,,,,,,,,,,,,, -177200,4.091444,0.60044885,,,,,,,,,,,,,, -177300,4.6584077,0.69835335,,,,,,,,,,,,,, -177400,4.1362476,0.574371,,,,,,,,,,,,,, -177500,4.742643,0.7032992,,,,,,,,,,,,,, -177600,4.52243,0.61724144,,,,,,,,,,,,,, -177609,,,0.9576889276504515,0.1566994935274124,0.7546600103378296,1.05818510055542,50000.0,0.6300000548362732,1.838411450386048,10000.0,60219.06253504753,62355.04538965225,60219.06253504753,2123.2411789894104,6.4263880252838135,0.0 -177700,4.2418165,0.6319198,,,,,,,,,,,,,, -177800,4.626449,0.7026026,,,,,,,,,,,,,, -177900,4.7764907,0.70783544,,,,,,,,,,,,,, -178000,4.8087764,0.67040676,,,,,,,,,,,,,, -178100,4.814863,0.6458997,,,,,,,,,,,,,, -178200,4.348995,0.6356518,,,,,,,,,,,,,, -178300,4.2119412,0.614013,,,,,,,,,,,,,, -178400,4.220364,0.6022758,,,,,,,,,,,,,, -178500,4.562172,0.6366268,,,,,,,,,,,,,, -178600,5.062126,0.7319252,,,,,,,,,,,,,, -178700,4.43213,0.634988,,,,,,,,,,,,,, -178800,4.474416,0.6326275,,,,,,,,,,,,,, -178900,4.317527,0.62261593,,,,,,,,,,,,,, -179000,4.816504,0.597915,,,,,,,,,,,,,, -179100,4.942334,0.6899636,,,,,,,,,,,,,, -179114,,,0.9618940949440002,0.1457709223031997,0.7546399831771851,1.054285764694214,50000.0,0.6331000328063965,1.8273247480392456,10000.0,60729.021720170975,62882.94174575806,60729.021720170975,2141.059831142425,6.489727973937988,0.0 -179200,4.740948,0.5705929,,,,,,,,,,,,,, -179300,4.573967,0.66618377,,,,,,,,,,,,,, -179400,4.33987,0.5680778,,,,,,,,,,,,,, -179500,4.463926,0.636537,,,,,,,,,,,,,, -179600,5.264066,0.6447028,,,,,,,,,,,,,, -179700,4.031716,0.5894791,,,,,,,,,,,,,, -179800,4.2047257,0.62128735,,,,,,,,,,,,,, -179900,4.7728243,0.5961811,,,,,,,,,,,,,, -180000,4.189294,0.5605238,,,,,,,,,,,,,, -180100,4.413083,0.6410082,,,,,,,,,,,,,, -180200,5.365277,0.58082336,,,,,,,,,,,,,, -180300,4.562582,0.6169249,,,,,,,,,,,,,, -180400,4.248384,0.643865,,,,,,,,,,,,,, -180500,4.2705693,0.6356393,,,,,,,,,,,,,, -180600,4.340384,0.55101764,,,,,,,,,,,,,, -180620,,,0.9607979655265808,0.1452068239450454,0.7542200088500977,1.0534926652908323,50000.0,0.6304000020027161,1.829865574836731,10000.0,61239.021084070206,63410.9137442112,61239.021084070206,2158.9144999980927,6.554764747619629,0.0 -180700,4.7632866,0.6533339,,,,,,,,,,,,,, -180800,4.4765234,0.65886986,,,,,,,,,,,,,, -180900,4.1637535,0.57011473,,,,,,,,,,,,,, -181000,4.5280714,0.5982453,,,,,,,,,,,,,, -181100,4.5108056,0.67535114,,,,,,,,,,,,,, -181200,4.227744,0.61470807,,,,,,,,,,,,,, -181300,4.406019,0.61120766,,,,,,,,,,,,,, -181400,4.713855,0.63746786,,,,,,,,,,,,,, -181500,4.0288463,0.58664227,,,,,,,,,,,,,, -181600,4.5873785,0.62471807,,,,,,,,,,,,,, -181700,4.3661118,0.6231191,,,,,,,,,,,,,, -181800,4.908829,0.7017721,,,,,,,,,,,,,, -181900,4.8000727,0.6904024,,,,,,,,,,,,,, -182000,4.5711737,0.6245921,,,,,,,,,,,,,, -182100,4.365741,0.6296063,,,,,,,,,,,,,, -182125,,,0.9617745280265808,0.1454741656780243,0.7547399997711182,1.053675413131714,50000.0,0.6331000328063965,1.830521583557129,10000.0,61748.97645926476,63938.778237342834,61748.97645926476,2176.718202829361,6.606281042098999,0.0 -182200,4.39265,0.6675839,,,,,,,,,,,,,, -182300,4.680056,0.6429658,,,,,,,,,,,,,, -182400,5.6823773,0.69837177,,,,,,,,,,,,,, -182500,4.920974,0.6424817,,,,,,,,,,,,,, -182600,4.1253514,0.639927,,,,,,,,,,,,,, -182700,4.2928357,0.60896075,,,,,,,,,,,,,, -182800,4.9842134,0.6306653,,,,,,,,,,,,,, -182900,4.323523,0.6214145,,,,,,,,,,,,,, -183000,4.951533,0.6526,,,,,,,,,,,,,, -183100,4.228035,0.66411525,,,,,,,,,,,,,, -183200,4.4659357,0.592696,,,,,,,,,,,,,, -183300,5.08712,0.70375174,,,,,,,,,,,,,, -183400,4.531077,0.6512959,,,,,,,,,,,,,, -183500,4.4413204,0.635223,,,,,,,,,,,,,, -183600,4.634992,0.71264696,,,,,,,,,,,,,, -183630,,,0.9620137214660645,0.1412625908851623,0.754859983921051,1.0554081201553345,50000.0,0.6306000351905823,1.831824541091919,10000.0,62258.98048186302,64466.852714538574,62258.98048186302,2194.670075416565,6.670706748962402,0.0 -183700,4.236908,0.5400254,,,,,,,,,,,,,, -183800,4.5575266,0.6100406,,,,,,,,,,,,,, -183900,4.250189,0.6687763,,,,,,,,,,,,,, -184000,4.1901507,0.6625433,,,,,,,,,,,,,, -184100,4.9006257,0.6531587,,,,,,,,,,,,,, -184200,4.20775,0.6479423,,,,,,,,,,,,,, -184300,4.254473,0.6222254,,,,,,,,,,,,,, -184400,4.243407,0.6510787,,,,,,,,,,,,,, -184500,4.692767,0.67887414,,,,,,,,,,,,,, -184600,4.211574,0.5984837,,,,,,,,,,,,,, -184700,4.6095424,0.6228283,,,,,,,,,,,,,, -184800,4.40742,0.58509696,,,,,,,,,,,,,, -184900,4.6245394,0.64027953,,,,,,,,,,,,,, -185000,4.3230853,0.62838906,,,,,,,,,,,,,, -185100,4.6624827,0.6481297,,,,,,,,,,,,,, -185136,,,0.9618940949440002,0.143305629491806,0.754859983921051,1.0536677837371826,50000.0,0.6320000290870667,1.8308967351913448,10000.0,62769.02976322174,64994.79127025604,62769.02976322174,2212.4462909698486,6.730313539505005,0.0 -185200,4.6287947,0.628445,,,,,,,,,,,,,, -185300,4.400025,0.61362624,,,,,,,,,,,,,, -185400,4.216524,0.5957145,,,,,,,,,,,,,, -185500,4.317547,0.60175604,,,,,,,,,,,,,, -185600,4.3823647,0.67474055,,,,,,,,,,,,,, -185700,4.2319217,0.63736457,,,,,,,,,,,,,, -185800,4.41481,0.60470265,,,,,,,,,,,,,, -185842,,,,,,,,,,,63008.0367565155,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index a974ad124..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.30432653427124,0.0,42.85896849632263,1,0,42.85896849632263,0.0010000000474974,6.907756805419922,10000,83.16339540481567,0.0009960937313735,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -62.0397412776947,0.0269179344177246,462.88109707832336,898,0,462.88109707832336,0.0105000007897615,6.4579176902771,10000,524.9953355789185,0.0148632805794477,6.403157711029053,0.0144999995827674,6.414474010467529,50000 -83.5831470489502,0.0547959804534912,883.0495610237122,1848,0,883.0495610237122,0.0328000001609325,5.980993270874023,10000,966.7851572036744,0.0405468754470348,5.837361812591553,0.0391599982976913,5.86493444442749,50000 -105.20509386062622,0.0855200290679931,1303.4351394176483,2799,0,1303.4351394176483,0.0517000034451484,5.6387457847595215,10000,1408.873723268509,0.0707421824336052,5.413792610168457,0.0650599971413612,5.469770431518555,50000 -127.0231866836548,0.1160211563110351,1723.6793518066406,3749,0,1723.6793518066406,0.0730000063776969,5.3636393547058105,10000,1851.0162580013275,0.1000781208276748,5.093623638153076,0.0906400009989738,5.134443283081055,50000 -148.92200636863708,0.1431868076324463,2143.966569185257,4697,0,2143.966569185257,0.1000000014901161,5.027024269104004,10000,2293.278913974762,0.143593743443489,4.689797401428223,0.1317999958992004,4.742933750152588,50000 -170.99932527542114,0.1723501682281494,2564.190915584564,5641,0,2564.190915584564,0.127700001001358,4.731691837310791,10000,2735.65858912468,0.18896484375,4.296420097351074,0.1714600026607513,4.386421203613281,50000 -195.720308303833,0.2005727291107177,2984.3381164073944,6584,0,2984.3381164073944,0.1648000031709671,4.412674903869629,10000,3180.604308128357,0.2314257770776748,3.941406726837158,0.2192399948835373,4.0269927978515625,50000 -220.51652264595032,0.2327091693878173,3404.479038000107,7526,0,3404.479038000107,0.1961000114679336,4.192534446716309,10000,3625.622143268585,0.279296875,3.652488231658936,0.2576200067996979,3.760484218597412,50000 -252.37112522125244,0.2743103504180908,3824.4939839839935,8473,0,3824.4939839839935,0.2160000056028366,4.029370307922363,10000,4077.582981586456,0.3104882836341858,3.4264824390411377,0.2886799871921539,3.565650224685669,50000 -278.07969093322754,0.3103628158569336,4244.734249830246,9418,0,4244.734249830246,0.2433000057935714,3.854766607284546,10000,4523.618048667908,0.3550195097923279,3.1795032024383545,0.3194199800491333,3.3503851890563965,50000 -306.845290184021,0.3585376739501953,4664.890008926392,10359,0,4664.890008926392,0.267300009727478,3.692898988723755,10000,4972.636468410492,0.3695117235183716,3.0522427558898926,0.3451800048351288,3.1714067459106445,50000 -333.15025997161865,0.3876104354858398,5085.055969715118,11303,0,5085.055969715118,0.2752000093460083,3.60302472114563,10000,5419.18580698967,0.3982031047344208,2.9216232299804688,0.3621799945831299,3.0729660987854004,50000 -361.84075355529785,0.4170222282409668,5505.355977058411,12245,0,5505.355977058411,0.3027999997138977,3.470056533813477,10000,5868.255347728729,0.4265234172344208,2.743405342102051,0.3881599903106689,2.927395820617676,50000 -393.0726861953736,0.4512176513671875,5925.322516679764,13186,0,5925.322516679764,0.3097000122070312,3.434152841567993,10000,6319.537341594696,0.4320507645606994,2.727667093276977,0.4016000032424927,2.8690223693847656,50000 -423.9458937644959,0.4859218597412109,6345.321208238602,14124,0,6345.321208238602,0.3222000300884247,3.3154313564300537,10000,6770.492334604263,0.4466406106948852,2.5795199871063232,0.4178600013256073,2.742652177810669,50000 -454.60587215423584,0.5157270431518555,6765.518880844116,15061,0,6765.518880844116,0.3351000249385834,3.2048323154449463,10000,7221.430178642273,0.46728515625,2.460857391357422,0.432559996843338,2.633084297180176,50000 -485.0385265350342,0.5460660457611084,7185.779655456543,15996,0,7185.779655456543,0.3368000090122223,3.2335305213928223,10000,7672.20264005661,0.4797070324420929,2.4331743717193604,0.428739994764328,2.6840686798095703,50000 -515.7722768783569,0.5766818523406982,7606.167441606522,16932,0,7606.167441606522,0.3441000282764435,3.238809108734131,10000,8123.402855873108,0.4716406166553497,2.5049636363983154,0.4402399957180023,2.659813642501831,50000 -548.9488704204559,0.6240944862365723,8026.408129453659,17870,0,8026.408129453659,0.3556000292301178,3.085073232650757,10000,8576.917443037033,0.4911913871765136,2.32535719871521,0.4554999768733978,2.496633529663086,50000 -582.8137938976288,0.6556878089904785,8446.362316608429,18809,0,8446.362316608429,0.3546000123023987,3.1322598457336426,10000,9030.816624879835,0.5016992092132568,2.336990594863892,0.4569000005722046,2.542475938796997,50000 -619.9755432605743,0.6908133029937744,8866.605522155762,19749,0,8866.605522155762,0.3619000315666199,3.080425500869751,10000,9488.30536866188,0.4941796660423279,2.3316903114318848,0.4655799865722656,2.4985291957855225,50000 -658.7014982700348,0.7256288528442383,9286.701465845108,20689,0,9286.701465845108,0.366100013256073,3.067256450653076,10000,9947.210430145264,0.5057812333106995,2.29396653175354,0.4740999937057495,2.452284574508667,50000 -696.9998137950897,0.7613976001739502,9706.624783277512,21618,0,9706.624783277512,0.3717000186443329,2.9770636558532715,10000,10405.515013217926,0.5250585675239563,2.16995906829834,0.4835799932479858,2.365188837051392,50000 -732.1488399505615,0.7885289192199707,10126.87127161026,22550,0,10126.87127161026,0.3823000192642212,2.93916916847229,10000,10860.985714435576,0.5533007979393005,2.0510404109954834,0.4919999837875366,2.332852840423584,50000 -768.0584897994995,0.816164493560791,10547.10615158081,23484,0,10547.10615158081,0.3897000253200531,2.925745725631714,10000,11317.205980062485,0.5318750143051147,2.154400587081909,0.4979399740695953,2.32008957862854,50000 -802.4123823642731,0.8501076698303223,10967.08280968666,24417,0,10967.08280968666,0.3899000287055969,2.9328761100769043,10000,11771.618568897247,0.5354101657867432,2.1295690536499023,0.4962599873542785,2.312972068786621,50000 -836.2594237327576,0.8776521682739258,11387.141655921936,25349,0,11387.141655921936,0.3992000222206116,2.854811668395996,10000,12225.600093364716,0.5593945384025574,1.996328830718994,0.5059399604797363,2.242471933364868,50000 -867.9045441150665,1.2171745300292969,11806.869129419329,26274,0,11806.869129419329,0.4005000293254852,2.8530092239379883,10000,12677.361342906952,0.5405468344688416,2.104295015335083,0.5084199905395508,2.259407997131348,50000 -898.9583828449249,1.245424747467041,12227.276673793793,27207,0,12227.276673793793,0.4026000201702118,2.8585050106048584,10000,13128.899226903915,0.5548046827316284,2.044733762741089,0.5141599774360657,2.2269556522369385,50000 -932.4390978813173,1.2829210758209229,12647.221132278442,28137,0,12647.221132278442,0.4109000265598297,2.7780418395996094,10000,13582.409242868423,0.5671679377555847,1.9500863552093504,0.5212999582290649,2.152677536010742,50000 -966.9934012889862,1.3104724884033203,13067.440270900726,29071,0,13067.440270900726,0.4243000149726867,2.731639862060547,10000,14037.257809400558,0.570507824420929,1.9200985431671145,0.5317800045013428,2.108691453933716,50000 -999.187772512436,1.343907117843628,13487.520725488665,30004,0,13487.520725488665,0.419400006532669,2.765523672103882,10000,14489.614025115969,0.5688085556030273,1.989999532699585,0.5308399796485901,2.16821813583374,50000 -1031.2494950294497,1.3736467361450195,13907.610777139664,30934,0,13907.610777139664,0.4137000143527984,2.810401201248169,10000,14941.843090057371,0.5673046708106995,1.9936236143112185,0.526419997215271,2.189337968826294,50000 -1065.047378540039,1.4030015468597412,14327.60186433792,31864,0,14327.60186433792,0.4237000346183777,2.7296688556671143,10000,15395.709174633026,0.5966015458106995,1.8194531202316284,0.5347200036048889,2.106844902038574,50000 -1098.5008039474487,1.434175968170166,14747.6784825325,32796,0,14747.6784825325,0.4262000322341919,2.733640193939209,10000,15849.31853699684,0.5743749737739563,1.9367910623550413,0.5367400050163269,2.105886220932007,50000 -1132.3038840293884,1.4670917987823486,15167.680012226105,33728,0,15167.680012226105,0.4223000109195709,2.7446610927581787,10000,16303.204138755798,0.5781444907188416,1.909422755241394,0.536899983882904,2.1106417179107666,50000 -1164.2943496704102,1.4970180988311768,15587.909015655518,34663,0,15587.909015655518,0.427700012922287,2.7481539249420166,10000,16755.502063274384,0.5896288752555847,1.907817840576172,0.5367599725723267,2.140103340148926,50000 -1197.9523015022278,1.5312883853912354,16008.223159313202,35595,0,16008.223159313202,0.4282000064849853,2.7046916484832764,10000,17209.556839942932,0.5802148580551147,1.8932175636291504,0.5425800085067749,2.0681850910186768,50000 -1231.2077586650848,1.5664069652557373,16428.311499357224,36527,0,16428.311499357224,0.4364000260829925,2.7020034790039062,10000,17662.983982801437,0.5877343416213989,1.890945553779602,0.545740008354187,2.086134433746338,50000 -1265.4408011436462,1.5989644527435305,16848.334990262985,37459,0,16848.334990262985,0.4308000206947326,2.711873769760132,10000,18117.32121515274,0.5893945097923279,1.879970908164978,0.5448200106620789,2.087385654449463,50000 -1297.7227976322174,1.631274700164795,17268.433282613754,38392,0,17268.433282613754,0.4370000064373016,2.6826961040496826,10000,18569.78191447258,0.6200585961341858,1.7613177299499512,0.5546599626541138,2.0597145557403564,50000 -1330.3051433563232,1.663404941558838,17688.36795592308,39325,0,17688.36795592308,0.4411000311374664,2.634138584136963,10000,19022.378975868225,0.5931640267372131,1.8362483978271484,0.5575799942016602,2.013503313064575,50000 -1363.0829148292542,1.6961984634399414,18108.57060527801,40258,0,18108.57060527801,0.4405000209808349,2.6993274688720703,10000,19475.44114756584,0.5926952958106995,1.8907710313797,0.5488799810409546,2.0913960933685303,50000 -1394.9611177444458,1.7287757396697998,18528.672277212143,41188,0,18528.672277212143,0.4478000104427337,2.6145589351654053,10000,19927.50222015381,0.6147655844688416,1.7229764461517334,0.560259997844696,1.9867902994155884,50000 -1426.9955496788025,1.7593953609466553,18948.740561246872,42119,0,18948.740561246872,0.4463000297546386,2.636042594909668,10000,20379.68393588066,0.5964453220367432,1.8440462350845337,0.558899998664856,2.013521671295166,50000 -1458.5802025794983,1.7902934551239014,19368.863209962845,43052,0,19368.863209962845,0.4493000209331512,2.613379716873169,10000,20831.470037698746,0.6012109518051147,1.8032618761062624,0.5563399791717529,1.999727249145508,50000 -1492.062311410904,1.8297011852264404,19788.84431171417,43982,0,19788.84431171417,0.4533000290393829,2.577320337295532,10000,21285.02067923546,0.6162695288658142,1.7325092554092407,0.5649799704551697,1.955632925033569,50000 -1524.7891371250153,1.8625941276550293,20208.771056890488,44914,0,20208.771056890488,0.4526000320911407,2.549083709716797,10000,21737.75564527512,0.6068750023841858,1.7359304428100586,0.5651000142097473,1.932178974151612,50000 -1556.786475896835,1.8994367122650144,20628.775985956192,45847,0,20628.775985956192,0.4513000249862671,2.5971016883850098,10000,22189.844033002853,0.6036913990974426,1.786840319633484,0.5648800134658813,1.972473382949829,50000 -1589.9524147510529,1.930569648742676,21048.85280585289,46780,0,21048.85280585289,0.4513000249862671,2.5617563724517822,10000,22643.165594816208,0.6071484088897705,1.7338517904281616,0.564520001411438,1.9355089664459229,50000 -1623.8386886119845,1.963141202926636,21468.807317256927,47715,0,21468.807317256927,0.4541000127792358,2.577409744262696,10000,23097.087222337723,0.6330664157867432,1.6488145589828491,0.5689399838447571,1.946183681488037,50000 -1655.5556573867798,1.9970552921295168,21889.076536417007,48651,0,21889.076536417007,0.4523000121116638,2.5711467266082764,10000,23549.15618038177,0.6096093654632568,1.764771580696106,0.5680999755859375,1.94655179977417,50000 -1687.9413216114044,2.028259038925171,22309.38073515892,49585,0,22309.38073515892,0.4537000358104706,2.5811941623687744,10000,24001.9259724617,0.6127734184265137,1.7522122859954834,0.5688999891281128,1.9498313665390008,50000 -1720.9364750385284,2.060332059860229,22729.444053173065,50516,0,22729.444053173065,0.4613000154495239,2.5196902751922607,10000,24455.06451320648,0.6240820288658142,1.6619887351989746,0.5757799744606018,1.885627508163452,50000 -1753.9146332740784,2.098830223083496,23149.631596565247,51449,0,23149.631596565247,0.4574000239372253,2.5609421730041504,10000,24908.31686663628,0.6103906035423279,1.7757805585861206,0.5678399801254272,1.9511852264404297,50000 -1787.3572108745575,2.134948968887329,23569.55105662346,52382,0,23569.55105662346,0.459600031375885,2.524887084960937,10000,25361.764173030853,0.6201757788658142,1.6868958473205566,0.5767799615859985,1.888272166252136,50000 -1820.436414003372,2.1685073375701904,23989.92184877396,53312,0,23989.92184877396,0.4633000195026397,2.536857843399048,10000,25815.29575109481,0.6251562237739563,1.6901180744171145,0.5781999826431274,1.9079535007476809,50000 -1852.839579820633,2.207714319229126,24410.181674718857,54243,0,24410.181674718857,0.4629000127315521,2.5215423107147217,10000,26268.046141147614,0.6433789134025574,1.6116437911987305,0.5797600150108337,1.8849976062774656,50000 -1886.1577606201167,2.2408063411712646,24830.14268946648,55173,0,24830.14268946648,0.4659000337123871,2.519357681274414,10000,26721.406907081604,0.6248632669448853,1.6888477802276611,0.5834400057792664,1.886579155921936,50000 -1918.165089845657,2.2733821868896484,25250.628484487534,56105,0,25250.628484487534,0.4644000232219696,2.5098447799682617,10000,27173.98124217987,0.6288476586341858,1.6663347482681274,0.5836799740791321,1.8777925968170168,50000 -1951.056599855423,2.3057284355163574,25670.61145210266,57036,0,25670.61145210266,0.4666000306606293,2.5017263889312744,10000,27626.93688774109,0.644726574420929,1.6122007369995115,0.5859599709510803,1.872389435768128,50000 -1984.3666186332705,2.3380661010742188,26091.04658770561,57970,0,26091.04658770561,0.4657000303268432,2.4734883308410645,10000,28080.76435756684,0.6248242259025574,1.6605851650238037,0.5842999815940857,1.8347177505493164,50000 -2017.8947920799253,2.3739945888519287,26511.307859420776,58900,0,26511.307859420776,0.4716000258922577,2.476768016815185,10000,28534.63808321953,0.6384179592132568,1.634387731552124,0.5914799571037292,1.8405027389526367,50000 -2051.0959992408752,2.4093093872070312,26931.3852558136,59829,0,26931.3852558136,0.4670000076293945,2.510704517364502,10000,28988.00011849404,0.634570300579071,1.642100811004639,0.58406001329422,1.868959784507752,50000 -2081.7578916549683,2.442704200744629,27351.664115190502,60758,0,27351.664115190502,0.4716000258922577,2.478414297103882,10000,29439.022877693176,0.6297656297683716,1.6264452934265137,0.5870400071144104,1.827787399291992,50000 -2113.970644235611,2.47971773147583,27771.800651311874,61688,0,27771.800651311874,0.4758000373840332,2.4539406299591064,10000,29891.457787036896,0.6357226371765137,1.6139147281646729,0.5892199873924255,1.826127290725708,50000 -2147.4553532600403,2.5132803916931152,28192.091106176376,62618,0,28192.091106176376,0.4781000316143036,2.4324634075164795,10000,30345.31472611428,0.6410741806030273,1.5769309997558594,0.594760000705719,1.7927261590957642,50000 -2181.003592252732,2.5525503158569336,28612.40795993805,63551,0,28612.40795993805,0.4824000298976898,2.3906712532043457,10000,30799.267755031586,0.6690039038658142,1.4578577280044556,0.5971800088882446,1.771941065788269,50000 -2215.162506103516,2.5897204875946045,29032.35076379776,64482,0,29032.35076379776,0.4761000275611877,2.455456018447876,10000,31253.455913305283,0.6347460746765137,1.6374387741088867,0.5944199562072754,1.825912952423096,50000 -2248.791541337967,2.6287543773651123,29452.668427705765,65414,0,29452.668427705765,0.4762000143527984,2.461789846420288,10000,31707.489936828613,0.6394921541213989,1.6012791395187378,0.5922200083732605,1.8196548223495483,50000 -2282.513402938843,2.665854454040528,29872.83366370201,66344,0,29872.83366370201,0.4827000200748443,2.3863396644592285,10000,32161.46227788925,0.659472644329071,1.5096582174301147,0.6010400056838989,1.7638288736343384,50000 -2316.063986301422,2.7025163173675537,30292.992176771164,67275,0,30292.992176771164,0.4785000085830688,2.4507222175598145,10000,32615.25714874268,0.6406835913658142,1.6189045906066897,0.6001200079917908,1.808770775794983,50000 -2348.187881231308,2.735978603363037,30712.97263765335,68204,0,30712.97263765335,0.4745000302791595,2.447077989578247,10000,33067.442917346954,0.6403124928474426,1.6007063388824463,0.5951399803161621,1.8148831129074097,50000 -2381.223728656769,2.768976926803589,31133.27097249031,69131,0,31133.27097249031,0.4790000319480896,2.451775312423706,10000,33520.85844230652,0.6518359184265137,1.5784968137741089,0.599399983882904,1.818591475486756,50000 -2413.0035569667816,2.803950309753418,31553.39439797401,70062,0,31553.39439797401,0.4794000089168548,2.46195387840271,10000,33972.844121456146,0.6419140696525574,1.6183698177337646,0.5989399552345276,1.82237446308136,50000 -2445.9967498779297,2.849007129669189,31973.70577263832,70992,0,31973.70577263832,0.4846000373363495,2.424272298812866,10000,34426.241938352585,0.64990234375,1.5852410793304443,0.6050199866294861,1.7876743078231812,50000 -2479.398825407028,2.8843657970428467,32393.78838253021,71924,0,32393.78838253021,0.4853000342845917,2.3803160190582275,10000,34879.81064558029,0.6512500047683716,1.5644874572753906,0.6061800122261047,1.7602328062057495,50000 -2513.4549918174744,2.9265990257263184,32814.01478791237,72858,0,32814.01478791237,0.4829000234603882,2.3781180381774902,10000,35334.18482041359,0.6725976467132568,1.4549552202224731,0.6046000123023987,1.746518850326538,50000 -2546.6604709625244,2.96217679977417,33234.2273080349,73792,0,33234.2273080349,0.4903000295162201,2.389852523803711,10000,35787.68767333031,0.6545507907867432,1.5467392206192017,0.6106799840927124,1.7467910051345823,50000 -2579.5513093471527,3.009035110473633,33654.26421165466,74726,0,33654.26421165466,0.4914000332355499,2.389246225357056,10000,36240.71160840988,0.6555468440055847,1.561237096786499,0.6126199960708618,1.7528570890426636,50000 -2612.520934343338,3.050046682357788,34074.34876227379,75657,0,34074.34876227379,0.4906000196933746,2.3438291549682617,10000,36693.855113983154,0.6675195097923279,1.4419922828674316,0.6165199875831604,1.693387746810913,50000 -2643.3242888450623,3.088836431503296,34494.56536388397,76588,0,34494.56536388397,0.4869000315666199,2.4197306632995605,10000,37144.96238017082,0.6521288752555847,1.591919183731079,0.6082599759101868,1.792444348335266,50000 -2676.3921184539795,3.130977869033813,34915.00139904022,77520,0,34915.00139904022,0.4862000346183777,2.3950002193450928,10000,37598.55631041527,0.6513671875,1.5582828521728516,0.6090599894523621,1.7538148164749146,50000 -2708.896152973175,3.167029619216919,35335.17891597748,78452,0,35335.17891597748,0.4884000122547149,2.357494354248047,10000,38051.32184123993,0.6617382764816284,1.5026801824569702,0.6145200133323669,1.7247921228408811,50000 -2742.400264978409,3.214477300643921,35755.31320667267,79383,0,35755.31320667267,0.4864000082015991,2.4213500022888184,10000,38505.05619072914,0.6783593893051147,1.4718852043151855,0.612280011177063,1.779030442237854,50000 -2775.382992506027,3.256436347961426,36175.47290205956,80314,0,36175.47290205956,0.4903000295162201,2.387882232666016,10000,38958.28948068619,0.655468761920929,1.5589125156402588,0.6153199672698975,1.7471002340316772,50000 -2809.1208050251007,3.3032021522521973,36595.45416688919,81246,0,36595.45416688919,0.4992000162601471,2.3286523818969727,10000,39412.10371303558,0.672070324420929,1.45991849899292,0.6218799948692322,1.682550311088562,50000 -2842.265993595124,3.340360879898072,37015.47077083588,82178,0,37015.47077083588,0.4916000366210937,2.343575954437256,10000,39865.35120391846,0.677539050579071,1.4339114427566528,0.6177999973297119,1.7054755687713623,50000 -2871.618143796921,3.3766629695892334,37435.68727660179,83109,0,37435.68727660179,0.5033000111579895,2.3191699981689453,10000,40315.00443935394,0.6681249737739563,1.490073561668396,0.619879961013794,1.6944104433059692,50000 -2903.392266750336,3.419344186782837,37855.84566473961,84038,0,37855.84566473961,0.4980000257492065,2.30544662475586,10000,40767.02870512009,0.6651171445846558,1.465654730796814,0.6218000054359436,1.6730984449386597,50000 -2935.9429478645325,3.4596714973449707,38275.783217191696,84971,0,38275.783217191696,0.503000020980835,2.307889223098755,10000,41219.60556221008,0.6802929639816284,1.428873062133789,0.6230999827384949,1.6761740446090698,50000 -2968.3241169452667,3.498311281204224,38695.95108413696,85903,0,38695.95108413696,0.499500036239624,2.3466873168945312,10000,41672.24136352539,0.6666210889816284,1.5058388710021973,0.623259961605072,1.7001662254333496,50000 -2999.692320823669,3.542412757873535,39116.04857087135,86834,0,39116.04857087135,0.5,2.360891819000244,10000,42123.79933476448,0.6658398509025574,1.5209107398986816,0.6223999857902527,1.7161909341812134,50000 -3033.672209739685,3.5862622261047363,39536.07976317406,87764,0,39536.07976317406,0.4973000288009643,2.3251760005950928,10000,42577.90255379677,0.6809179782867432,1.4363369941711426,0.6262999773025513,1.676653504371643,50000 -3066.316313743592,3.6272923946380615,39956.33709144592,88696,0,39956.33709144592,0.5035000443458557,2.3024938106536865,10000,43030.892899513245,0.6968359351158142,1.3666648864746094,0.6245999932289124,1.67709481716156,50000 -3097.2331142425537,3.6694083213806152,40376.373056173325,89628,0,40376.373056173325,0.5107000470161438,2.262535572052002,10000,43481.935858011246,0.6750390529632568,1.4493118524551392,0.6331999897956848,1.6430368423461914,50000 -3127.205216407776,3.715543746948242,40796.29656982422,90554,0,40796.29656982422,0.5046000480651855,2.265634775161743,10000,43931.92548966408,0.6821093559265137,1.3882906436920166,0.6303799748420715,1.624939203262329,50000 -3159.993196964264,3.765239715576172,41216.31773328781,91481,0,41216.31773328781,0.5101000070571899,2.2855710983276367,10000,44384.83210873604,0.6850780844688416,1.3804093599319458,0.6272000074386597,1.6394050121307373,50000 -3192.737065553665,3.803889751434326,41636.72108960152,92412,0,41636.72108960152,0.5117000341415405,2.257552146911621,10000,44838.06574630737,0.6785351634025574,1.4298800230026243,0.6362599730491638,1.630461573600769,50000 -3223.965080499649,3.8404388427734375,42057.02818584442,93343,0,42057.02818584442,0.5101000070571899,2.2780954837799072,10000,45289.685396671295,0.68212890625,1.4268741607666016,0.6340399980545044,1.639679789543152,50000 -3255.977776527405,3.900721549987793,42476.98532438278,94272,0,42476.98532438278,0.5105000138282776,2.2469406127929688,10000,45741.764285326,0.6900390386581421,1.358396291732788,0.632860004901886,1.6014665365219116,50000 -3286.2305703163147,3.9465410709381104,42896.978048563,95203,0,42896.978048563,0.5120000243186951,2.2654266357421875,10000,46192.10404133797,0.69740229845047,1.3583762645721436,0.6399799585342407,1.6108720302581787,50000 -3316.7220873832703,3.99554705619812,43316.982568740845,96132,0,43316.982568740845,0.5175000429153442,2.22903060913086,10000,46642.69738292694,0.6863671541213989,1.4062321186065674,0.6400399804115295,1.6143362522125244,50000 -3346.5461995601654,4.039233446121216,43736.97590112686,97062,0,43736.97590112686,0.5182999968528748,2.2834768295288086,10000,47092.60734796524,0.685546875,1.4155725240707395,0.6375600099563599,1.633133053779602,50000 -3377.816866159439,4.090280055999756,44156.90197920799,97992,0,44156.90197920799,0.5205000042915344,2.2080962657928467,10000,47543.903435230255,0.7066406011581421,1.2971094846725464,0.6403200030326843,1.582248568534851,50000 -3411.555104732513,4.134812593460083,44576.86071538925,98923,0,44576.86071538925,0.5091000199317932,2.254444599151612,10000,47997.69262051582,0.6883788704872131,1.393426775932312,0.641319990158081,1.6070661544799805,50000 -3443.7383959293365,4.1757354736328125,44996.839626550674,99856,0,44996.839626550674,0.5134000182151794,2.239524602890014,10000,48449.945254564285,0.6871289014816284,1.3720650672912598,0.6365999579429626,1.5965046882629397,50000 -3477.3515434265137,4.214470624923706,45416.97100830078,100785,0,45416.97100830078,0.522599995136261,2.181353807449341,10000,48903.77703499794,0.7017577886581421,1.312267780303955,0.6451999545097351,1.5605461597442627,50000 -3510.946757078171,4.263065338134766,45837.05639505386,101716,0,45837.05639505386,0.5249000191688538,2.2173092365264893,10000,49357.55435633659,0.69349604845047,1.36771821975708,0.6436399817466736,1.5866146087646484,50000 -3541.7239258289337,4.30495023727417,46257.20961642265,102645,0,46257.20961642265,0.5223000049591064,2.250572919845581,10000,49808.57501959801,0.6882616877555847,1.39771568775177,0.6396999955177307,1.6158915758132937,50000 -3571.6743774414062,4.362541913986206,46677.277096033096,103573,0,46677.277096033096,0.5329000353813171,2.161540508270264,10000,50258.69889426232,0.7084375023841858,1.2893983125686646,0.6509000062942505,1.5381120443344116,50000 -3602.106454372406,4.4127349853515625,47097.76482272148,104501,0,47097.76482272148,0.526199996471405,2.236384868621826,10000,50709.71799230576,0.7164257764816284,1.3079707622528076,0.6469599604606628,1.601804494857788,50000 -3635.9119005203247,4.460398435592651,47517.686259269714,105427,0,47517.686259269714,0.5055000185966492,2.2996857166290283,10000,51163.54053092003,0.6813281178474426,1.432881474494934,0.6276599764823914,1.6766186952590942,50000 -3667.913044929504,4.501140356063843,47937.69239234924,106359,0,47937.69239234924,0.5250000357627869,2.1776108741760254,10000,51615.63697743416,0.7081249952316284,1.3071223497390747,0.653439998626709,1.5344473123550415,50000 -3699.998573303223,4.5414369106292725,48357.65385222435,107289,0,48357.65385222435,0.535800039768219,2.1337265968322754,10000,52067.7736222744,0.7215625047683716,1.2222944498062134,0.6552599668502808,1.523241400718689,50000 -3730.917119503021,4.586863279342651,48777.81280827522,108222,0,48777.81280827522,0.5349000096321106,2.1375465393066406,10000,52518.94608283043,0.7074999809265137,1.2904741764068604,0.6606400012969971,1.5038014650344849,50000 -3764.027851104736,4.637696266174316,49197.85523939133,109153,0,49197.85523939133,0.5330000519752502,2.1510488986968994,10000,52972.1984333992,0.7104882597923279,1.2939804792404177,0.6558799743652344,1.531550407409668,50000 -3798.13270521164,4.687514781951904,49617.85936307907,110084,0,49617.85936307907,0.5314000248908997,2.220554113388061,10000,53426.40553641319,0.7136523127555847,1.3198362588882446,0.6524199843406677,1.5822229385375977,50000 -3832.661499977112,4.737124443054199,50038.1953496933,111016,0,50038.1953496933,0.537600040435791,2.108808279037476,10000,53881.36761689186,0.7125585675239563,1.2730872631072998,0.6649599671363831,1.4868433475494385,50000 -3860.9939935207367,4.784902572631836,50458.33082890511,111947,0,50458.33082890511,0.5469000339508057,2.1059060096740723,10000,54329.931025743484,0.7182812094688416,1.2427871227264404,0.6640799641609192,1.481744647026062,50000 -3894.177620410919,4.832216739654541,50878.57825565338,112868,0,50878.57825565338,0.5367000102996826,2.121800184249878,10000,54783.457184791565,0.7190234065055847,1.2374742031097412,0.6615200042724609,1.4908949136734009,50000 -3927.492713212967,4.883016586303711,51298.48287606239,113797,0,51298.48287606239,0.5442000031471252,2.132467031478882,10000,55236.77574682236,0.7408984303474426,1.1787163019180298,0.666979968547821,1.5026799440383911,50000 -3961.1463055610657,4.928993463516235,51718.38898730278,114730,0,51718.38898730278,0.5469000339508057,2.0868093967437744,10000,55690.430280447006,0.7202734351158142,1.223569393157959,0.6654999852180481,1.4597338438034058,50000 -3993.5162620544434,4.973785161972046,52138.31260251999,115661,0,52138.31260251999,0.5440000295639038,2.075801372528076,10000,56142.81825685501,0.7249413728713989,1.2017278671264648,0.6695599555969238,1.4541553258895874,50000 -4026.839727401733,5.018024444580078,52558.60073399544,116592,0,52558.60073399544,0.5515000224113464,2.100456476211548,10000,56596.52179288864,0.7337695360183716,1.1900928020477295,0.675059974193573,1.4655356407165527,50000 -4057.929902076721,5.0658793449401855,52978.923567056656,117526,0,52978.923567056656,0.5491000413894653,2.076021194458008,10000,57048.03144288063,0.730664074420929,1.1986579895019531,0.6748799681663513,1.436489820480347,50000 -4089.267287492752,5.115900993347168,53399.125903368,118457,0,53399.125903368,0.5466000437736511,2.084545135498047,10000,57499.6692841053,0.7255663871765137,1.2088725566864014,0.6714800000190735,1.4509501457214355,50000 -4121.7964906692505,5.169644355773926,53819.215997457504,119385,0,53819.215997457504,0.5509999990463257,2.038119316101074,10000,57952.38986158371,0.7369726300239563,1.1480770111083984,0.6777399778366089,1.4084073305130005,50000 -4155.675201416016,5.214017868041992,54239.2432115078,120316,0,54239.2432115078,0.5499000549316406,2.078464984893799,10000,58406.38825464249,0.7320312261581421,1.1968371868133545,0.6744199991226196,1.4489213228225708,50000 -4186.622891664505,5.255097150802612,54659.52649736405,121247,0,54659.52649736405,0.5569000244140625,2.0281074047088623,10000,58857.70808959007,0.7334765195846558,1.1736871004104614,0.6791200041770935,1.4089908599853516,50000 -4218.173988342285,5.308438777923584,55079.78484630585,122176,0,55079.78484630585,0.5547000169754028,2.018074989318848,10000,59309.61920070648,0.7381835579872131,1.1373696327209473,0.6790800094604492,1.3966187238693235,50000 -4248.474764108658,5.354437828063965,55500.03019499779,123107,0,55500.03019499779,0.5599000453948975,2.0738844871521,10000,59760.25983929634,0.7488671541213989,1.1549774408340454,0.681439995765686,1.463075876235962,50000 -4278.991844892502,5.403183937072754,55920.1936044693,124036,0,55920.1936044693,0.5573000311851501,2.0066962242126465,10000,60211.03713226318,0.7421875,1.1317152976989746,0.6843999624252319,1.3768173456192017,50000 -4311.943066358566,5.446610689163208,56340.23328781128,124965,0,56340.23328781128,0.5586000084877014,2.050241708755493,10000,60664.11954545975,0.742968738079071,1.1672155857086182,0.6843799948692322,1.420559048652649,50000 -4345.138010501862,5.488474130630493,56760.46367549896,125892,0,56760.46367549896,0.563800036907196,1.99488365650177,10000,61117.63439536095,0.7481640577316284,1.098036289215088,0.6840199828147888,1.375110149383545,50000 -4379.761988639832,5.534748315811157,57180.5708527565,126822,0,57180.5708527565,0.5568000078201294,1.996273279190064,10000,61572.45949554443,0.7443749904632568,1.1106796264648438,0.6901800036430359,1.3542922735214231,50000 -4413.811166524887,5.576773881912232,57600.62677979469,127753,0,57600.62677979469,0.5613000392913818,1.97961175441742,10000,62026.65469145775,0.7480273246765137,1.1173925399780271,0.688759982585907,1.3697985410690308,50000 -4447.039078474045,5.6262922286987305,58020.76798009872,128684,0,58020.76798009872,0.566100001335144,1.9980629682540887,10000,62480.12114715576,0.75990229845047,1.0859085321426392,0.6924999952316284,1.37213134765625,50000 -4481.064235448837,5.68550181388855,58440.78542947769,129613,0,58440.78542947769,0.570900022983551,1.9598753452301023,10000,62934.27048492432,0.7639062404632568,1.050950288772583,0.6924999952316284,1.3538199663162231,50000 -4513.996830224991,5.740103721618652,58860.95683288574,130545,0,58860.95683288574,0.5597000122070312,2.01959228515625,10000,63387.47819709778,0.7498437166213989,1.1301133632659912,0.693120002746582,1.3789575099945068,50000 -4548.038062334061,5.784178018569946,59280.94118070602,131476,0,59280.94118070602,0.5682000517845154,1.9795842170715328,10000,63841.59644985199,0.756054699420929,1.0679931640625,0.6955400109291077,1.333950757980347,50000 -4581.237900733948,5.829130172729492,59701.01261425018,132405,0,59701.01261425018,0.5670000314712524,1.9746955633163448,10000,64294.9606654644,0.7659375071525574,1.050911784172058,0.6958400011062622,1.3542274236679075,50000 -4615.49994468689,5.87558388710022,60121.277137994766,133333,0,60121.277137994766,0.5729000568389893,1.9410356283187864,10000,64749.58311963081,0.7583202719688416,1.063936471939087,0.7030799984931946,1.3143028020858765,50000 -4648.055237054825,5.918776750564575,60541.402416706085,134263,0,60541.402416706085,0.5733000040054321,1.947724461555481,10000,65202.35506153107,0.7602148056030273,1.056712985038757,0.700939953327179,1.3244282007217407,50000 -4682.473633766174,5.967818260192871,60961.52054858208,135186,0,60961.52054858208,0.5781000256538391,1.9366137981414795,10000,65656.98781871796,0.7699609398841858,1.0183870792388916,0.7020399570465088,1.3113582134246826,50000 -4716.022467851639,6.017577886581421,61381.7981479168,136116,0,61381.7981479168,0.5782000422477722,1.9222419261932373,10000,66110.9119849205,0.768359363079071,1.0378899574279783,0.705839991569519,1.3026723861694336,50000 -4746.358287096024,6.064196825027466,61801.93469142914,137045,0,61801.93469142914,0.5843000411987305,1.8952151536941528,10000,66561.47880458832,0.7707226276397705,0.9996986985206604,0.7090199589729309,1.2752629518508911,50000 -4777.243156194687,6.118428707122803,62222.269728422165,137975,0,62222.269728422165,0.5821000337600708,1.892650485038757,10000,67012.80084323883,0.7759569883346558,0.9817464351654052,0.708899974822998,1.267372965812683,50000 -4810.671858549118,6.181352853775024,62642.559386491776,138906,0,62642.559386491776,0.5886000394821167,1.8847719430923464,10000,67466.630849123,0.7888085842132568,0.9418398141860962,0.7098599672317505,1.2711541652679443,50000 -4844.179829597473,6.227135896682739,63062.58865451813,139840,0,63062.58865451813,0.5869000554084778,1.8614424467086792,10000,67920.26310777664,0.770312488079071,0.9991475343704224,0.7108599543571472,1.2587929964065552,50000 -4875.891560316086,6.273644685745239,63482.7752828598,140771,0,63482.7752828598,0.5897000432014465,1.8567043542861936,10000,68372.2566754818,0.7780663967132568,0.9655563831329346,0.7139399647712708,1.2472407817840576,50000 -4904.032790899277,6.321924686431885,63903.23413252831,141701,0,63903.23413252831,0.588200032711029,1.8638124465942385,10000,68820.95351719856,0.7883398532867432,0.9381603598594666,0.7156800031661987,1.24841046333313,50000 -4935.975115537643,6.379689931869507,64323.499345541,142629,0,64323.499345541,0.5924000144004822,1.8806276321411133,10000,69273.26719760895,0.7808007597923279,0.9934661388397216,0.7153599858283997,1.273717164993286,50000 -4969.630095720291,6.429242849349976,64743.52144980431,143560,0,64743.52144980431,0.5910000205039978,1.84019148349762,10000,69727.04166722298,0.7847851514816284,0.949887216091156,0.7196199893951416,1.221478819847107,50000 -5003.322789907455,6.475466012954712,65163.633311748505,144490,0,65163.633311748505,0.5952000021934509,1.845787763595581,10000,70180.94100570679,0.7881640195846558,0.9314032793045044,0.718239963054657,1.2326812744140625,50000 -5036.984225511551,6.524138689041138,65583.94645094872,145420,0,65583.94645094872,0.5932000279426575,1.8493428230285645,10000,70635.012373209,0.7870116829872131,0.9516395926475524,0.7205599546432495,1.229424238204956,50000 -5070.531363964081,6.570559024810791,66003.91966462135,146350,0,66003.91966462135,0.6014000177383423,1.814241647720337,10000,71088.62664437294,0.7876952886581421,0.9195204973220824,0.726419985294342,1.2012972831726074,50000 -5103.775855779648,6.61857533454895,66423.87331795692,147280,0,66423.87331795692,0.6032000184059143,1.803131937980652,10000,71541.92117094994,0.7960156202316284,0.888444185256958,0.7275800108909607,1.1892650127410889,50000 -5136.919964790344,6.665120840072632,66843.96156454086,148210,0,66843.96156454086,0.5993000268936157,1.8347848653793333,10000,71995.24836182594,0.7996679544448853,0.9016132950782776,0.7257199883460999,1.2199366092681885,50000 -5169.321353435516,6.71377420425415,67264.05994081497,149141,0,67264.05994081497,0.6061000227928162,1.7923401594161987,10000,72447.8445456028,0.7919726371765137,0.9040732979774476,0.7277799844741821,1.1872663497924805,50000 -5201.246497869492,6.762799263000488,67684.25943183899,150071,0,67684.25943183899,0.6043000221252441,1.791891098022461,10000,72900.0660943985,0.7992578148841858,0.882004976272583,0.7291399836540222,1.1780247688293457,50000 -5232.74213886261,6.808778524398804,68104.4731631279,151000,0,68104.4731631279,0.609000027179718,1.7773709297180176,10000,73351.86920380592,0.8058202862739563,0.8614989519119263,0.7326399683952332,1.1725473403930664,50000 -5265.935954332352,6.8564043045043945,68524.8423511982,151932,0,68524.8423511982,0.6045000553131104,1.7815557718276978,10000,73805.52795624733,0.8000390529632568,0.8734333515167236,0.731220006942749,1.1657071113586426,50000 -5299.06134557724,6.904045104980469,68945.09971499443,152865,0,68945.09971499443,0.6141000390052795,1.7590707540512085,10000,74259.00728917122,0.804492175579071,0.8603943586349487,0.7348399758338928,1.1572157144546509,50000 -5333.035125255585,6.953671216964722,69365.369992733,153797,0,69365.369992733,0.6112000346183777,1.773018479347229,10000,74713.34918832779,0.8055663704872131,0.8574345707893372,0.7340199947357178,1.1686347723007202,50000 -5367.154443502426,7.006609916687012,69785.46291160583,154728,0,69785.46291160583,0.6133000254631042,1.7443084716796875,10000,75167.66317725182,0.8135741949081421,0.8229332566261292,0.737779974937439,1.1446532011032104,50000 -5400.772227048874,7.056041240692139,70205.56688523293,155658,0,70205.56688523293,0.614300012588501,1.7738709449768066,10000,75621.48214673996,0.8089257478713989,0.8678907752037048,0.7383999824523926,1.1707357168197632,50000 -5434.675608158112,7.35836124420166,70625.50346064568,156587,0,70625.50346064568,0.6217000484466553,1.7479978799819946,10000,76075.67284274101,0.8149804472923279,0.8367795944213867,0.7415599822998047,1.1527200937271118,50000 -5468.575320243835,7.409008026123047,71045.69268107414,157511,0,71045.69268107414,0.6192000508308411,1.723163604736328,10000,76529.8608827591,0.8213085532188416,0.7878121137619019,0.745199978351593,1.118953824043274,50000 -5502.602823019028,7.457210540771484,71465.86521029472,158441,0,71465.86521029472,0.6200000047683716,1.7198753356933594,10000,76984.15627121925,0.8138671517372131,0.8201526403427124,0.7439999580383301,1.1131176948547363,50000 -5536.2020580768585,7.512519836425781,71885.79600262642,159370,0,71885.79600262642,0.6217000484466553,1.7308553457260132,10000,77437.7894179821,0.82093745470047,0.8007131218910217,0.7448599934577942,1.126524806022644,50000 -5570.279438018799,7.564241647720337,72306.14024019241,160301,0,72306.14024019241,0.6235000491142273,1.7202221155166626,10000,77892.31031370163,0.8273046612739563,0.7696381211280823,0.7482199668884277,1.115172028541565,50000 -5600.44885635376,7.616266965866089,72726.45551586151,161232,0,72726.45551586151,0.6300000548362732,1.7063981294631958,10000,78342.89540290833,0.8215234279632568,0.79811030626297,0.7492199540138245,1.1067612171173096,50000 -5630.4672927856445,7.676425695419311,73146.63851761818,162159,0,73146.63851761818,0.6338000297546387,1.673032522201538,10000,78793.20580601692,0.8233007788658142,0.7672076225280762,0.7484999895095825,1.0859493017196655,50000 -5663.513606309891,7.735929727554321,73566.73581504822,163088,0,73566.73581504822,0.6310000419616699,1.7050611972808838,10000,79246.45765209198,0.8291601538658142,0.7714617848396301,0.7514199614524841,1.10415518283844,50000 -5697.595949888229,7.793646812438965,73986.66831469536,164021,0,73986.66831469536,0.6335000395774841,1.6779024600982666,10000,79700.57873511314,0.83509761095047,0.7409867644309998,0.7541199922561646,1.079445481300354,50000 -5732.24786067009,7.844249725341797,74407.01210308075,164957,0,74407.01210308075,0.6362000107765198,1.6556838750839231,10000,80155.67283654213,0.8322851657867432,0.7366339564323425,0.7554399967193604,1.0619897842407229,50000 -5763.104225158691,7.89516282081604,74827.03679513931,165887,0,74827.03679513931,0.6294000148773193,1.6705749034881592,10000,80606.65344071388,0.8315820097923279,0.7409655451774597,0.7555999755859375,1.0678353309631348,50000 -5794.742982387543,7.947040319442749,75247.1132349968,166816,0,75247.1132349968,0.6351000070571899,1.669582486152649,10000,81058.46815085411,0.8335937261581421,0.7405239939689636,0.7565000057220459,1.076341986656189,50000 -5827.715883970261,7.998236179351807,75667.09297275543,167747,0,75667.09297275543,0.6355000138282776,1.6743074655532837,10000,81511.52088880539,0.8338671922683716,0.7559254765510559,0.75791996717453,1.0783684253692627,50000 -5861.583888530731,8.05066990852356,76087.34431004524,168678,0,76087.34431004524,0.636400043964386,1.665740728378296,10000,81965.7411146164,0.8360546827316284,0.7370734214782715,0.7594599723815918,1.0704350471496582,50000 -5895.755210876465,8.10885739326477,76507.4817943573,169609,0,76507.4817943573,0.6401000022888184,1.649182677268982,10000,82420.15692543983,0.83984375,0.7221209406852722,0.760159969329834,1.0561344623565674,50000 -5929.933554887772,8.157514572143555,76927.61034536362,170540,0,76927.61034536362,0.6373000144958496,1.6481930017471311,10000,82874.5614593029,0.8365820050239563,0.7234076261520386,0.7600599527359009,1.0516064167022705,50000 -5964.20698928833,8.21371054649353,77347.73178625107,171472,0,77347.73178625107,0.6414000391960144,1.6435645818710327,10000,83329.06152820587,0.8395116925239563,0.7189873456954956,0.7607799768447876,1.0504368543624878,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 01c8a2122..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1906 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.33511066,6.907757,,,,,,,,,,,,,, -1,,,0.0009960937313735,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.85896849632263,83.16339540481567,42.85896849632263,40.30432653427124,0.0,0.0 -100,0.3815669,6.9060698,,,,,,,,,,,,,, -200,0.44061092,6.893231,,,,,,,,,,,,,, -300,0.64990646,6.855405,,,,,,,,,,,,,, -400,0.56644624,6.8186793,,,,,,,,,,,,,, -500,0.6247999,6.8514566,,,,,,,,,,,,,, -600,0.7370425,6.7467074,,,,,,,,,,,,,, -700,0.9734796,6.6812277,,,,,,,,,,,,,, -800,1.7228559,6.669016,,,,,,,,,,,,,, -898,,,0.0148632805794477,6.403157711029053,0.0144999995827674,6.414474010467529,50000.0,0.0105000007897615,6.4579176902771,10000.0,462.88109707832336,524.9953355789185,462.88109707832336,62.0397412776947,0.0269179344177246,0.0 -900,1.0779326,6.56636,,,,,,,,,,,,,, -1000,2.130615,6.537986,,,,,,,,,,,,,, -1100,1.1413043,6.6337376,,,,,,,,,,,,,, -1200,1.2914578,6.5083346,,,,,,,,,,,,,, -1300,3.494979,6.4810996,,,,,,,,,,,,,, -1400,1.8626428,6.705836,,,,,,,,,,,,,, -1500,1.3868357,6.66536,,,,,,,,,,,,,, -1600,1.8369179,6.2595816,,,,,,,,,,,,,, -1700,1.5072191,6.4254236,,,,,,,,,,,,,, -1800,1.9273188,6.564934,,,,,,,,,,,,,, -1848,,,0.0405468754470348,5.837361812591553,0.0391599982976913,5.86493444442749,50000.0,0.0328000001609325,5.980993270874023,10000.0,883.0495610237122,966.7851572036744,883.0495610237122,83.5831470489502,0.0547959804534912,0.0 -1900,1.7330856,6.208303,,,,,,,,,,,,,, -2000,2.1973748,6.1680903,,,,,,,,,,,,,, -2100,1.6239414,6.2010293,,,,,,,,,,,,,, -2200,1.3561915,6.6893725,,,,,,,,,,,,,, -2300,1.8753055,6.1070137,,,,,,,,,,,,,, -2400,2.0352476,6.050651,,,,,,,,,,,,,, -2500,1.4523338,6.424206,,,,,,,,,,,,,, -2600,1.4288659,6.054582,,,,,,,,,,,,,, -2700,1.4354227,6.537649,,,,,,,,,,,,,, -2799,,,0.0707421824336052,5.413792610168457,0.0650599971413612,5.469770431518555,50000.0,0.0517000034451484,5.6387457847595215,10000.0,1303.4351394176483,1408.873723268509,1303.4351394176483,105.20509386062622,0.0855200290679931,0.0 -2800,1.4890145,6.3477697,,,,,,,,,,,,,, -2900,2.0624974,5.953514,,,,,,,,,,,,,, -3000,1.4170961,5.967253,,,,,,,,,,,,,, -3100,1.3996145,5.9511614,,,,,,,,,,,,,, -3200,1.440032,6.441163,,,,,,,,,,,,,, -3300,1.5844904,5.819017,,,,,,,,,,,,,, -3400,1.4682246,5.8761573,,,,,,,,,,,,,, -3500,1.4416735,6.000227,,,,,,,,,,,,,, -3600,1.4744029,5.8086023,,,,,,,,,,,,,, -3700,1.4679388,5.783492,,,,,,,,,,,,,, -3749,,,0.1000781208276748,5.093623638153076,0.0906400009989738,5.134443283081055,50000.0,0.0730000063776969,5.3636393547058105,10000.0,1723.6793518066406,1851.0162580013275,1723.6793518066406,127.0231866836548,0.1160211563110351,0.0 -3800,1.4031733,5.768343,,,,,,,,,,,,,, -3900,1.745681,5.7527113,,,,,,,,,,,,,, -4000,1.4726896,6.183769,,,,,,,,,,,,,, -4100,1.3826708,6.081013,,,,,,,,,,,,,, -4200,1.1008958,6.1935673,,,,,,,,,,,,,, -4300,1.7291987,5.9429183,,,,,,,,,,,,,, -4400,1.3995796,6.513689,,,,,,,,,,,,,, -4500,1.5897392,5.5020227,,,,,,,,,,,,,, -4600,1.4754591,6.409756,,,,,,,,,,,,,, -4697,,,0.143593743443489,4.689797401428223,0.1317999958992004,4.742933750152588,50000.0,0.1000000014901161,5.027024269104004,10000.0,2143.966569185257,2293.278913974762,2143.966569185257,148.92200636863708,0.1431868076324463,0.0 -4700,1.1445404,6.280095,,,,,,,,,,,,,, -4800,1.5736758,5.6276994,,,,,,,,,,,,,, -4900,1.821114,5.4674635,,,,,,,,,,,,,, -5000,1.5287786,5.382552,,,,,,,,,,,,,, -5100,1.7939113,5.3086796,,,,,,,,,,,,,, -5200,1.6615365,5.2498636,,,,,,,,,,,,,, -5300,1.5389667,5.516762,,,,,,,,,,,,,, -5400,2.2357903,5.260895,,,,,,,,,,,,,, -5500,1.6084963,5.398707,,,,,,,,,,,,,, -5600,1.7018138,5.331379,,,,,,,,,,,,,, -5641,,,0.18896484375,4.296420097351074,0.1714600026607513,4.386421203613281,50000.0,0.127700001001358,4.731691837310791,10000.0,2564.190915584564,2735.65858912468,2564.190915584564,170.99932527542114,0.1723501682281494,0.0 -5700,1.3003325,6.1475983,,,,,,,,,,,,,, -5800,2.9912212,5.211799,,,,,,,,,,,,,, -5900,1.4847474,5.6116204,,,,,,,,,,,,,, -6000,1.600532,5.229281,,,,,,,,,,,,,, -6100,1.763965,5.1074324,,,,,,,,,,,,,, -6200,1.4868872,6.296896,,,,,,,,,,,,,, -6300,1.7170597,5.089658,,,,,,,,,,,,,, -6400,1.8728169,4.9879107,,,,,,,,,,,,,, -6500,1.4662707,5.25876,,,,,,,,,,,,,, -6584,,,0.2314257770776748,3.941406726837158,0.2192399948835373,4.0269927978515625,50000.0,0.1648000031709671,4.412674903869629,10000.0,2984.3381164073944,3180.604308128357,2984.3381164073944,195.720308303833,0.2005727291107177,0.0 -6600,1.6138202,5.3673763,,,,,,,,,,,,,, -6700,1.2130203,6.2265177,,,,,,,,,,,,,, -6800,1.4563178,4.903162,,,,,,,,,,,,,, -6900,1.4329655,5.1028643,,,,,,,,,,,,,, -7000,1.8351854,4.870098,,,,,,,,,,,,,, -7100,1.5069767,6.2596617,,,,,,,,,,,,,, -7200,1.3332598,5.658273,,,,,,,,,,,,,, -7300,0.9862061,6.315052,,,,,,,,,,,,,, -7400,1.2331787,6.121927,,,,,,,,,,,,,, -7500,1.3900236,6.048945,,,,,,,,,,,,,, -7526,,,0.279296875,3.652488231658936,0.2576200067996979,3.760484218597412,50000.0,0.1961000114679336,4.192534446716309,10000.0,3404.479038000107,3625.622143268585,3404.479038000107,220.51652264595032,0.2327091693878173,0.0 -7600,1.6720549,5.0273714,,,,,,,,,,,,,, -7700,1.685283,5.031924,,,,,,,,,,,,,, -7800,1.2917787,6.287484,,,,,,,,,,,,,, -7900,1.527833,5.0872307,,,,,,,,,,,,,, -8000,1.6568123,4.7737126,,,,,,,,,,,,,, -8100,1.6311893,4.7786064,,,,,,,,,,,,,, -8200,1.6648618,4.8198166,,,,,,,,,,,,,, -8300,1.3989843,5.6932726,,,,,,,,,,,,,, -8400,1.6590629,4.731575,,,,,,,,,,,,,, -8473,,,0.3104882836341858,3.4264824390411377,0.2886799871921539,3.565650224685669,50000.0,0.2160000056028366,4.029370307922363,10000.0,3824.4939839839935,4077.582981586456,3824.4939839839935,252.37112522125244,0.2743103504180908,0.0 -8500,1.7302619,4.9787617,,,,,,,,,,,,,, -8600,1.7230579,4.587225,,,,,,,,,,,,,, -8700,1.7223849,4.652999,,,,,,,,,,,,,, -8800,1.4480826,4.627886,,,,,,,,,,,,,, -8900,1.8442845,4.467264,,,,,,,,,,,,,, -9000,1.5580941,4.570261,,,,,,,,,,,,,, -9100,1.6297923,4.527457,,,,,,,,,,,,,, -9200,1.4918045,4.539021,,,,,,,,,,,,,, -9300,2.0004668,4.478071,,,,,,,,,,,,,, -9400,1.5490764,4.528556,,,,,,,,,,,,,, -9418,,,0.3550195097923279,3.1795032024383545,0.3194199800491333,3.3503851890563965,50000.0,0.2433000057935714,3.854766607284546,10000.0,4244.734249830246,4523.618048667908,4244.734249830246,278.07969093322754,0.3103628158569336,0.0 -9500,1.6082051,5.10377,,,,,,,,,,,,,, -9600,1.5588588,5.339176,,,,,,,,,,,,,, -9700,1.8700348,4.592454,,,,,,,,,,,,,, -9800,1.1333839,5.6403227,,,,,,,,,,,,,, -9900,1.621632,4.43828,,,,,,,,,,,,,, -10000,1.0281854,6.1483583,,,,,,,,,,,,,, -10100,1.1675003,5.6788893,,,,,,,,,,,,,, -10200,1.596367,4.6737704,,,,,,,,,,,,,, -10300,1.0141855,5.5884185,,,,,,,,,,,,,, -10359,,,0.3695117235183716,3.0522427558898926,0.3451800048351288,3.1714067459106445,50000.0,0.267300009727478,3.692898988723755,10000.0,4664.890008926392,4972.636468410492,4664.890008926392,306.845290184021,0.3585376739501953,0.0 -10400,1.2333952,5.082205,,,,,,,,,,,,,, -10500,1.6849738,4.5605087,,,,,,,,,,,,,, -10600,1.4413723,4.357038,,,,,,,,,,,,,, -10700,1.4421433,4.2818027,,,,,,,,,,,,,, -10800,1.1796523,6.009423,,,,,,,,,,,,,, -10900,1.6436565,4.322519,,,,,,,,,,,,,, -11000,1.6271378,4.3749723,,,,,,,,,,,,,, -11100,1.26115,5.3676457,,,,,,,,,,,,,, -11200,1.7842518,4.268562,,,,,,,,,,,,,, -11300,1.0653958,5.687186,,,,,,,,,,,,,, -11303,,,0.3982031047344208,2.9216232299804688,0.3621799945831299,3.0729660987854004,50000.0,0.2752000093460083,3.60302472114563,10000.0,5085.055969715118,5419.18580698967,5085.055969715118,333.15025997161865,0.3876104354858398,0.0 -11400,1.5982114,4.3944693,,,,,,,,,,,,,, -11500,1.0982196,5.7278476,,,,,,,,,,,,,, -11600,1.4085455,4.9501715,,,,,,,,,,,,,, -11700,1.3931773,4.6683216,,,,,,,,,,,,,, -11800,1.5678655,4.2089896,,,,,,,,,,,,,, -11900,1.4130958,4.1140327,,,,,,,,,,,,,, -12000,1.3415955,4.7128935,,,,,,,,,,,,,, -12100,1.6770313,4.1064568,,,,,,,,,,,,,, -12200,1.3016984,4.358662,,,,,,,,,,,,,, -12245,,,0.4265234172344208,2.743405342102051,0.3881599903106689,2.927395820617676,50000.0,0.3027999997138977,3.470056533813477,10000.0,5505.355977058411,5868.255347728729,5505.355977058411,361.84075355529785,0.4170222282409668,0.0 -12300,2.0987494,4.222327,,,,,,,,,,,,,, -12400,1.3892562,4.2827487,,,,,,,,,,,,,, -12500,1.2281731,4.534379,,,,,,,,,,,,,, -12600,1.6488825,4.150011,,,,,,,,,,,,,, -12700,1.5129715,4.1676483,,,,,,,,,,,,,, -12800,1.9234936,4.127056,,,,,,,,,,,,,, -12900,1.4822131,4.171365,,,,,,,,,,,,,, -13000,1.5467075,4.5148854,,,,,,,,,,,,,, -13100,1.618201,4.4776373,,,,,,,,,,,,,, -13186,,,0.4320507645606994,2.727667093276977,0.4016000032424927,2.8690223693847656,50000.0,0.3097000122070312,3.434152841567993,10000.0,5925.322516679764,6319.537341594696,5925.322516679764,393.0726861953736,0.4512176513671875,0.0 -13200,1.4562464,4.0836587,,,,,,,,,,,,,, -13300,1.7578392,4.2395964,,,,,,,,,,,,,, -13400,1.1516366,5.47408,,,,,,,,,,,,,, -13500,1.3521547,5.3183517,,,,,,,,,,,,,, -13600,1.8971142,4.233157,,,,,,,,,,,,,, -13700,1.6443093,4.137264,,,,,,,,,,,,,, -13800,0.89312196,5.916962,,,,,,,,,,,,,, -13900,1.341496,4.844331,,,,,,,,,,,,,, -14000,1.4921114,4.192223,,,,,,,,,,,,,, -14100,1.5991299,4.0268936,,,,,,,,,,,,,, -14124,,,0.4466406106948852,2.5795199871063232,0.4178600013256073,2.742652177810669,50000.0,0.3222000300884247,3.3154313564300537,10000.0,6345.321208238602,6770.492334604263,6345.321208238602,423.9458937644959,0.4859218597412109,0.0 -14200,1.2678801,4.403783,,,,,,,,,,,,,, -14300,1.4231181,4.058322,,,,,,,,,,,,,, -14400,1.0956874,4.831523,,,,,,,,,,,,,, -14500,2.2234247,4.117581,,,,,,,,,,,,,, -14600,1.3573707,4.8607693,,,,,,,,,,,,,, -14700,0.86979085,5.9932575,,,,,,,,,,,,,, -14800,1.4520903,4.0185785,,,,,,,,,,,,,, -14900,1.4007621,4.323617,,,,,,,,,,,,,, -15000,0.88487756,5.9461513,,,,,,,,,,,,,, -15061,,,0.46728515625,2.460857391357422,0.432559996843338,2.633084297180176,50000.0,0.3351000249385834,3.2048323154449463,10000.0,6765.518880844116,7221.430178642273,6765.518880844116,454.60587215423584,0.5157270431518555,0.0 -15100,1.5005656,4.1541414,,,,,,,,,,,,,, -15200,1.4336051,4.0038104,,,,,,,,,,,,,, -15300,1.4660448,4.0039716,,,,,,,,,,,,,, -15400,1.7844678,4.1167355,,,,,,,,,,,,,, -15500,0.9940483,5.793133,,,,,,,,,,,,,, -15600,1.3018769,4.369115,,,,,,,,,,,,,, -15700,1.3683828,4.050959,,,,,,,,,,,,,, -15800,1.4611548,4.026494,,,,,,,,,,,,,, -15900,1.1221875,4.470697,,,,,,,,,,,,,, -15996,,,0.4797070324420929,2.4331743717193604,0.428739994764328,2.6840686798095703,50000.0,0.3368000090122223,3.2335305213928223,10000.0,7185.779655456543,7672.20264005661,7185.779655456543,485.0385265350342,0.5460660457611084,0.0 -16000,1.3964449,4.3423567,,,,,,,,,,,,,, -16100,1.174214,4.479966,,,,,,,,,,,,,, -16200,1.3070045,5.103544,,,,,,,,,,,,,, -16300,1.3784715,3.9415398,,,,,,,,,,,,,, -16400,1.1195374,5.925735,,,,,,,,,,,,,, -16500,1.3133401,3.9883952,,,,,,,,,,,,,, -16600,1.1036018,5.774896,,,,,,,,,,,,,, -16700,0.9973497,5.6191864,,,,,,,,,,,,,, -16800,1.3751979,4.240373,,,,,,,,,,,,,, -16900,0.9301552,5.7785525,,,,,,,,,,,,,, -16932,,,0.4716406166553497,2.5049636363983154,0.4402399957180023,2.659813642501831,50000.0,0.3441000282764435,3.238809108734131,10000.0,7606.167441606522,8123.402855873108,7606.167441606522,515.7722768783569,0.5766818523406982,0.0 -17000,1.3743173,3.9968054,,,,,,,,,,,,,, -17100,1.2474897,3.8962862,,,,,,,,,,,,,, -17200,1.3119881,4.013825,,,,,,,,,,,,,, -17300,1.3514309,3.966508,,,,,,,,,,,,,, -17400,1.4741132,3.8010192,,,,,,,,,,,,,, -17500,1.3861331,3.9607718,,,,,,,,,,,,,, -17600,1.4394033,4.372599,,,,,,,,,,,,,, -17700,1.2768012,4.152191,,,,,,,,,,,,,, -17800,1.4756603,3.8624704,,,,,,,,,,,,,, -17870,,,0.4911913871765136,2.32535719871521,0.4554999768733978,2.496633529663086,50000.0,0.3556000292301178,3.085073232650757,10000.0,8026.408129453659,8576.917443037033,8026.408129453659,548.9488704204559,0.6240944862365723,0.0 -17900,1.2331799,3.9476347,,,,,,,,,,,,,, -18000,1.1773759,4.0666595,,,,,,,,,,,,,, -18100,1.482135,3.7925568,,,,,,,,,,,,,, -18200,1.0356406,4.8650794,,,,,,,,,,,,,, -18300,1.482793,3.9151683,,,,,,,,,,,,,, -18400,1.2501955,4.1205697,,,,,,,,,,,,,, -18500,1.3039732,3.9076471,,,,,,,,,,,,,, -18600,1.2750295,3.935873,,,,,,,,,,,,,, -18700,1.3920591,3.9350123,,,,,,,,,,,,,, -18800,1.1691134,4.3952107,,,,,,,,,,,,,, -18809,,,0.5016992092132568,2.336990594863892,0.4569000005722046,2.542475938796997,50000.0,0.3546000123023987,3.1322598457336426,10000.0,8446.362316608429,9030.816624879835,8446.362316608429,582.8137938976288,0.6556878089904785,0.0 -18900,1.2835234,3.7938685,,,,,,,,,,,,,, -19000,0.94008386,5.787056,,,,,,,,,,,,,, -19100,1.3388802,3.8271396,,,,,,,,,,,,,, -19200,1.4248797,3.8943138,,,,,,,,,,,,,, -19300,1.0835857,4.4303274,,,,,,,,,,,,,, -19400,1.3362153,3.8729112,,,,,,,,,,,,,, -19500,1.1267589,5.152501,,,,,,,,,,,,,, -19600,0.9924961,5.2142315,,,,,,,,,,,,,, -19700,1.4800142,3.8554962,,,,,,,,,,,,,, -19749,,,0.4941796660423279,2.3316903114318848,0.4655799865722656,2.4985291957855225,50000.0,0.3619000315666199,3.080425500869751,10000.0,8866.605522155762,9488.30536866188,8866.605522155762,619.9755432605743,0.6908133029937744,0.0 -19800,1.4247411,3.820827,,,,,,,,,,,,,, -19900,1.3956721,3.8974745,,,,,,,,,,,,,, -20000,1.5582091,3.8065426,,,,,,,,,,,,,, -20100,1.2567279,3.8107097,,,,,,,,,,,,,, -20200,1.0545884,4.8967032,,,,,,,,,,,,,, -20300,1.0082862,4.7957764,,,,,,,,,,,,,, -20400,1.2911224,3.7263577,,,,,,,,,,,,,, -20500,1.325903,4.2401247,,,,,,,,,,,,,, -20600,1.2292219,3.962033,,,,,,,,,,,,,, -20689,,,0.5057812333106995,2.29396653175354,0.4740999937057495,2.452284574508667,50000.0,0.366100013256073,3.067256450653076,10000.0,9286.701465845108,9947.210430145264,9286.701465845108,658.7014982700348,0.7256288528442383,0.0 -20700,1.3314855,3.7416553,,,,,,,,,,,,,, -20800,1.1923307,4.389763,,,,,,,,,,,,,, -20900,1.4499873,3.7645652,,,,,,,,,,,,,, -21000,0.93893623,5.3330736,,,,,,,,,,,,,, -21100,1.3476652,3.855993,,,,,,,,,,,,,, -21200,0.97093785,5.6260824,,,,,,,,,,,,,, -21300,1.2932109,3.80097,,,,,,,,,,,,,, -21400,1.3970604,3.8947482,,,,,,,,,,,,,, -21500,1.2724465,4.1425953,,,,,,,,,,,,,, -21600,1.3971261,4.35052,,,,,,,,,,,,,, -21618,,,0.5250585675239563,2.16995906829834,0.4835799932479858,2.365188837051392,50000.0,0.3717000186443329,2.9770636558532715,10000.0,9706.624783277512,10405.515013217926,9706.624783277512,696.9998137950897,0.7613976001739502,0.0 -21700,0.89889127,5.7630835,,,,,,,,,,,,,, -21800,0.8876219,5.843638,,,,,,,,,,,,,, -21900,1.5349743,3.823572,,,,,,,,,,,,,, -22000,1.3235927,3.8353882,,,,,,,,,,,,,, -22100,1.3384405,3.7541726,,,,,,,,,,,,,, -22200,0.9477341,5.50204,,,,,,,,,,,,,, -22300,1.3806987,3.8184845,,,,,,,,,,,,,, -22400,1.3473928,3.7079253,,,,,,,,,,,,,, -22500,0.9235801,5.4772224,,,,,,,,,,,,,, -22550,,,0.5533007979393005,2.0510404109954834,0.4919999837875366,2.332852840423584,50000.0,0.3823000192642212,2.93916916847229,10000.0,10126.87127161026,10860.985714435576,10126.87127161026,732.1488399505615,0.7885289192199707,0.0 -22600,1.393467,3.7652788,,,,,,,,,,,,,, -22700,1.2854029,4.3745613,,,,,,,,,,,,,, -22800,1.3701744,3.9223258,,,,,,,,,,,,,, -22900,1.4658985,3.6705947,,,,,,,,,,,,,, -23000,1.4635853,3.7202694,,,,,,,,,,,,,, -23100,1.2132663,3.7535994,,,,,,,,,,,,,, -23200,1.2080135,3.8287385,,,,,,,,,,,,,, -23300,1.3303988,3.8285747,,,,,,,,,,,,,, -23400,1.2508812,3.758007,,,,,,,,,,,,,, -23484,,,0.5318750143051147,2.154400587081909,0.4979399740695953,2.32008957862854,50000.0,0.3897000253200531,2.925745725631714,10000.0,10547.10615158081,11317.205980062485,10547.10615158081,768.0584897994995,0.816164493560791,0.0 -23500,1.3169329,3.7376258,,,,,,,,,,,,,, -23600,1.073695,4.6763663,,,,,,,,,,,,,, -23700,1.3364513,3.7115097,,,,,,,,,,,,,, -23800,1.0740314,5.5907025,,,,,,,,,,,,,, -23900,1.2193197,3.9064713,,,,,,,,,,,,,, -24000,1.1734403,4.121295,,,,,,,,,,,,,, -24100,1.3955309,3.7726388,,,,,,,,,,,,,, -24200,1.0853322,4.982077,,,,,,,,,,,,,, -24300,1.152095,3.8862586,,,,,,,,,,,,,, -24400,1.0484064,5.416642,,,,,,,,,,,,,, -24417,,,0.5354101657867432,2.1295690536499023,0.4962599873542785,2.312972068786621,50000.0,0.3899000287055969,2.9328761100769043,10000.0,10967.08280968666,11771.618568897247,10967.08280968666,802.4123823642731,0.8501076698303223,0.0 -24500,1.3225493,3.6833627,,,,,,,,,,,,,, -24600,1.1222271,4.1514874,,,,,,,,,,,,,, -24700,1.0488237,4.8657694,,,,,,,,,,,,,, -24800,1.299754,3.702248,,,,,,,,,,,,,, -24900,0.9001194,5.695058,,,,,,,,,,,,,, -25000,1.4135826,3.7366583,,,,,,,,,,,,,, -25100,1.1996399,4.0059876,,,,,,,,,,,,,, -25200,1.2952086,3.6356745,,,,,,,,,,,,,, -25300,1.3153713,3.63731,,,,,,,,,,,,,, -25349,,,0.5593945384025574,1.996328830718994,0.5059399604797363,2.242471933364868,50000.0,0.3992000222206116,2.854811668395996,10000.0,11387.141655921936,12225.600093364716,11387.141655921936,836.2594237327576,0.8776521682739258,0.0 -25400,1.0004456,5.3859262,,,,,,,,,,,,,, -25500,0.9691387,4.8275366,,,,,,,,,,,,,, -25600,1.1299334,4.6370697,,,,,,,,,,,,,, -25700,1.0513451,4.448268,,,,,,,,,,,,,, -25800,1.2966197,3.7396746,,,,,,,,,,,,,, -25900,1.139399,4.8146906,,,,,,,,,,,,,, -26000,1.2011029,4.3253827,,,,,,,,,,,,,, -26100,1.3207772,3.7164047,,,,,,,,,,,,,, -26200,1.5131508,3.6388159,,,,,,,,,,,,,, -26274,,,0.5405468344688416,2.104295015335083,0.5084199905395508,2.259407997131348,50000.0,0.4005000293254852,2.8530092239379883,10000.0,11806.869129419329,12677.361342906952,11806.869129419329,867.9045441150665,1.2171745300292969,0.0 -26300,1.236659,3.6013227,,,,,,,,,,,,,, -26400,1.2758584,3.7748222,,,,,,,,,,,,,, -26500,0.9930302,5.771501,,,,,,,,,,,,,, -26600,1.3482964,3.6852207,,,,,,,,,,,,,, -26700,1.2242693,3.944267,,,,,,,,,,,,,, -26800,1.325148,3.4957716,,,,,,,,,,,,,, -26900,1.2960014,4.000495,,,,,,,,,,,,,, -27000,1.3907186,3.5189102,,,,,,,,,,,,,, -27100,1.2889752,3.6035717,,,,,,,,,,,,,, -27200,1.4509848,3.6809182,,,,,,,,,,,,,, -27207,,,0.5548046827316284,2.044733762741089,0.5141599774360657,2.2269556522369385,50000.0,0.4026000201702118,2.8585050106048584,10000.0,12227.276673793793,13128.899226903915,12227.276673793793,898.9583828449249,1.245424747467041,0.0 -27300,1.4596937,3.6173835,,,,,,,,,,,,,, -27400,0.99531674,5.520566,,,,,,,,,,,,,, -27500,1.1412749,4.3210964,,,,,,,,,,,,,, -27600,1.464345,3.5985029,,,,,,,,,,,,,, -27700,1.2525563,4.238814,,,,,,,,,,,,,, -27800,0.99672234,4.561734,,,,,,,,,,,,,, -27900,1.4943119,3.6273415,,,,,,,,,,,,,, -28000,1.3453009,3.8646235,,,,,,,,,,,,,, -28100,1.3566033,3.8012807,,,,,,,,,,,,,, -28137,,,0.5671679377555847,1.9500863552093504,0.5212999582290649,2.152677536010742,50000.0,0.4109000265598297,2.7780418395996094,10000.0,12647.221132278442,13582.409242868423,12647.221132278442,932.4390978813173,1.2829210758209229,0.0 -28200,1.319401,3.6095393,,,,,,,,,,,,,, -28300,1.320996,3.64919,,,,,,,,,,,,,, -28400,1.1737674,4.0732827,,,,,,,,,,,,,, -28500,0.8915265,5.650047,,,,,,,,,,,,,, -28600,0.9648376,5.693259,,,,,,,,,,,,,, -28700,1.2650144,3.853105,,,,,,,,,,,,,, -28800,1.3006015,3.540911,,,,,,,,,,,,,, -28900,1.3114675,3.5323257,,,,,,,,,,,,,, -29000,1.1787704,3.9160905,,,,,,,,,,,,,, -29071,,,0.570507824420929,1.9200985431671145,0.5317800045013428,2.108691453933716,50000.0,0.4243000149726867,2.731639862060547,10000.0,13067.440270900726,14037.257809400558,13067.440270900726,966.9934012889862,1.3104724884033203,0.0 -29100,1.0995171,4.1851053,,,,,,,,,,,,,, -29200,1.144771,4.4684124,,,,,,,,,,,,,, -29300,1.2565298,3.757028,,,,,,,,,,,,,, -29400,1.2989573,3.5344431,,,,,,,,,,,,,, -29500,1.502272,3.6614168,,,,,,,,,,,,,, -29600,1.3354721,3.5554998,,,,,,,,,,,,,, -29700,1.2361548,4.2373796,,,,,,,,,,,,,, -29800,1.1020232,4.711989,,,,,,,,,,,,,, -29900,1.3782694,3.5663352,,,,,,,,,,,,,, -30000,1.3061192,3.5648022,,,,,,,,,,,,,, -30004,,,0.5688085556030273,1.989999532699585,0.5308399796485901,2.16821813583374,50000.0,0.419400006532669,2.765523672103882,10000.0,13487.520725488665,14489.614025115969,13487.520725488665,999.187772512436,1.343907117843628,0.0 -30100,1.4093174,3.54497,,,,,,,,,,,,,, -30200,1.0200958,5.689646,,,,,,,,,,,,,, -30300,1.373394,3.6839721,,,,,,,,,,,,,, -30400,1.1354239,4.7385974,,,,,,,,,,,,,, -30500,1.1983232,3.8040025,,,,,,,,,,,,,, -30600,1.2569438,5.639386,,,,,,,,,,,,,, -30700,1.5716507,3.5502336,,,,,,,,,,,,,, -30800,1.414761,3.7792075,,,,,,,,,,,,,, -30900,1.2171549,4.77932,,,,,,,,,,,,,, -30934,,,0.5673046708106995,1.9936236143112185,0.526419997215271,2.189337968826294,50000.0,0.4137000143527984,2.810401201248169,10000.0,13907.610777139664,14941.843090057371,13907.610777139664,1031.2494950294497,1.3736467361450195,0.0 -31000,1.2032804,4.5745854,,,,,,,,,,,,,, -31100,1.4166961,3.8254814,,,,,,,,,,,,,, -31200,1.5186838,3.637803,,,,,,,,,,,,,, -31300,1.3496021,3.6479805,,,,,,,,,,,,,, -31400,1.3251269,3.4376352,,,,,,,,,,,,,, -31500,1.3474244,3.447157,,,,,,,,,,,,,, -31600,1.3374492,3.538054,,,,,,,,,,,,,, -31700,1.2682081,3.7350106,,,,,,,,,,,,,, -31800,1.2740808,4.1210566,,,,,,,,,,,,,, -31864,,,0.5966015458106995,1.8194531202316284,0.5347200036048889,2.106844902038574,50000.0,0.4237000346183777,2.7296688556671143,10000.0,14327.60186433792,15395.709174633026,14327.60186433792,1065.047378540039,1.4030015468597412,0.0 -31900,1.0716732,5.5842113,,,,,,,,,,,,,, -32000,1.3744962,3.5225027,,,,,,,,,,,,,, -32100,1.374052,3.4524465,,,,,,,,,,,,,, -32200,1.3136814,3.4740915,,,,,,,,,,,,,, -32300,1.0326201,4.500204,,,,,,,,,,,,,, -32400,1.3859912,3.5152793,,,,,,,,,,,,,, -32500,1.0913632,5.1158433,,,,,,,,,,,,,, -32600,1.0836843,5.692502,,,,,,,,,,,,,, -32700,1.303348,3.728844,,,,,,,,,,,,,, -32796,,,0.5743749737739563,1.9367910623550413,0.5367400050163269,2.105886220932007,50000.0,0.4262000322341919,2.733640193939209,10000.0,14747.6784825325,15849.31853699684,14747.6784825325,1098.5008039474487,1.434175968170166,0.0 -32800,1.4368559,3.4703596,,,,,,,,,,,,,, -32900,1.4433358,3.4777403,,,,,,,,,,,,,, -33000,0.96457094,5.707094,,,,,,,,,,,,,, -33100,1.3958989,3.4814,,,,,,,,,,,,,, -33200,1.526689,3.4932728,,,,,,,,,,,,,, -33300,1.3943461,3.4477718,,,,,,,,,,,,,, -33400,1.1416107,5.3672376,,,,,,,,,,,,,, -33500,1.3873292,3.630977,,,,,,,,,,,,,, -33600,1.365072,3.6067908,,,,,,,,,,,,,, -33700,1.0981115,4.6329594,,,,,,,,,,,,,, -33728,,,0.5781444907188416,1.909422755241394,0.536899983882904,2.1106417179107666,50000.0,0.4223000109195709,2.7446610927581787,10000.0,15167.680012226105,16303.204138755798,15167.680012226105,1132.3038840293884,1.4670917987823486,0.0 -33800,1.3690603,3.4045951,,,,,,,,,,,,,, -33900,1.0610689,5.732446,,,,,,,,,,,,,, -34000,1.1821792,4.2757773,,,,,,,,,,,,,, -34100,1.4446803,3.589923,,,,,,,,,,,,,, -34200,1.3848808,3.534983,,,,,,,,,,,,,, -34300,1.2844048,3.7334328,,,,,,,,,,,,,, -34400,1.4448678,3.4556148,,,,,,,,,,,,,, -34500,1.604189,3.4804873,,,,,,,,,,,,,, -34600,1.4795079,3.4933343,,,,,,,,,,,,,, -34663,,,0.5896288752555847,1.907817840576172,0.5367599725723267,2.140103340148926,50000.0,0.427700012922287,2.7481539249420166,10000.0,15587.909015655518,16755.502063274384,15587.909015655518,1164.2943496704102,1.4970180988311768,0.0 -34700,1.3006158,3.4195971,,,,,,,,,,,,,, -34800,1.1695399,4.5144234,,,,,,,,,,,,,, -34900,1.3348869,3.595904,,,,,,,,,,,,,, -35000,1.4246092,3.5451357,,,,,,,,,,,,,, -35100,1.4025989,3.6096714,,,,,,,,,,,,,, -35200,1.3635875,3.5134006,,,,,,,,,,,,,, -35300,1.4384782,3.396733,,,,,,,,,,,,,, -35400,1.2870638,4.3141775,,,,,,,,,,,,,, -35500,0.97306496,5.3059683,,,,,,,,,,,,,, -35595,,,0.5802148580551147,1.8932175636291504,0.5425800085067749,2.0681850910186768,50000.0,0.4282000064849853,2.7046916484832764,10000.0,16008.223159313202,17209.556839942932,16008.223159313202,1197.9523015022278,1.5312883853912354,0.0 -35600,1.2534096,3.9210227,,,,,,,,,,,,,, -35700,1.2064584,4.5768614,,,,,,,,,,,,,, -35800,1.4917996,3.4846685,,,,,,,,,,,,,, -35900,1.3538533,3.627059,,,,,,,,,,,,,, -36000,1.5849044,3.6375546,,,,,,,,,,,,,, -36100,1.3303933,3.4871664,,,,,,,,,,,,,, -36200,1.1702737,4.4923453,,,,,,,,,,,,,, -36300,1.5824432,3.4200737,,,,,,,,,,,,,, -36400,1.3646702,3.4905114,,,,,,,,,,,,,, -36500,1.3481846,4.0751667,,,,,,,,,,,,,, -36527,,,0.5877343416213989,1.890945553779602,0.545740008354187,2.086134433746338,50000.0,0.4364000260829925,2.7020034790039062,10000.0,16428.311499357224,17662.983982801437,16428.311499357224,1231.2077586650848,1.5664069652557373,0.0 -36600,1.0673214,4.682581,,,,,,,,,,,,,, -36700,1.298418,3.7468555,,,,,,,,,,,,,, -36800,1.4904572,3.6735563,,,,,,,,,,,,,, -36900,1.1211259,5.4394135,,,,,,,,,,,,,, -37000,1.3779327,3.5756607,,,,,,,,,,,,,, -37100,1.3826714,3.4365838,,,,,,,,,,,,,, -37200,1.1721017,4.961701,,,,,,,,,,,,,, -37300,1.2533466,4.1519265,,,,,,,,,,,,,, -37400,1.3387086,3.7604237,,,,,,,,,,,,,, -37459,,,0.5893945097923279,1.879970908164978,0.5448200106620789,2.087385654449463,50000.0,0.4308000206947326,2.711873769760132,10000.0,16848.334990262985,18117.32121515274,16848.334990262985,1265.4408011436462,1.5989644527435305,0.0 -37500,1.5315073,3.4886591,,,,,,,,,,,,,, -37600,1.2208064,5.5174866,,,,,,,,,,,,,, -37700,1.3248539,3.7992263,,,,,,,,,,,,,, -37800,1.3353391,3.6325855,,,,,,,,,,,,,, -37900,1.1458236,5.610993,,,,,,,,,,,,,, -38000,1.38431,3.5191455,,,,,,,,,,,,,, -38100,1.3204317,3.5950198,,,,,,,,,,,,,, -38200,1.416633,3.458136,,,,,,,,,,,,,, -38300,1.2516953,4.2443447,,,,,,,,,,,,,, -38392,,,0.6200585961341858,1.7613177299499512,0.5546599626541138,2.0597145557403564,50000.0,0.4370000064373016,2.6826961040496826,10000.0,17268.433282613754,18569.78191447258,17268.433282613754,1297.7227976322174,1.631274700164795,0.0 -38400,1.3914067,3.5103288,,,,,,,,,,,,,, -38500,1.4535043,3.6027012,,,,,,,,,,,,,, -38600,1.2600785,4.3771687,,,,,,,,,,,,,, -38700,1.3893019,3.5045602,,,,,,,,,,,,,, -38800,1.5824436,3.46649,,,,,,,,,,,,,, -38900,1.6109343,3.468131,,,,,,,,,,,,,, -39000,1.3789985,3.4914699,,,,,,,,,,,,,, -39100,1.468334,3.4945977,,,,,,,,,,,,,, -39200,1.4691011,3.4668384,,,,,,,,,,,,,, -39300,1.5307899,3.8680408,,,,,,,,,,,,,, -39325,,,0.5931640267372131,1.8362483978271484,0.5575799942016602,2.013503313064575,50000.0,0.4411000311374664,2.634138584136963,10000.0,17688.36795592308,19022.378975868225,17688.36795592308,1330.3051433563232,1.663404941558838,0.0 -39400,1.0961003,4.963597,,,,,,,,,,,,,, -39500,1.2700716,3.6694715,,,,,,,,,,,,,, -39600,1.388262,4.1161475,,,,,,,,,,,,,, -39700,1.540875,3.4024765,,,,,,,,,,,,,, -39800,1.352289,3.4674928,,,,,,,,,,,,,, -39900,1.2790658,3.7384877,,,,,,,,,,,,,, -40000,1.3410401,3.5736077,,,,,,,,,,,,,, -40100,1.3743919,3.7515028,,,,,,,,,,,,,, -40200,1.4809273,3.3137853,,,,,,,,,,,,,, -40258,,,0.5926952958106995,1.8907710313797,0.5488799810409546,2.0913960933685303,50000.0,0.4405000209808349,2.6993274688720703,10000.0,18108.57060527801,19475.44114756584,18108.57060527801,1363.0829148292542,1.6961984634399414,0.0 -40300,1.2306819,3.993864,,,,,,,,,,,,,, -40400,1.4275529,3.3922753,,,,,,,,,,,,,, -40500,1.3591943,3.8730185,,,,,,,,,,,,,, -40600,1.1394331,5.1908937,,,,,,,,,,,,,, -40700,1.1035147,5.6442447,,,,,,,,,,,,,, -40800,1.3462493,3.4904702,,,,,,,,,,,,,, -40900,1.0266299,5.347706,,,,,,,,,,,,,, -41000,1.453216,3.6548772,,,,,,,,,,,,,, -41100,1.2768362,3.4919007,,,,,,,,,,,,,, -41188,,,0.6147655844688416,1.7229764461517334,0.560259997844696,1.9867902994155884,50000.0,0.4478000104427337,2.6145589351654053,10000.0,18528.672277212143,19927.50222015381,18528.672277212143,1394.9611177444458,1.7287757396697998,0.0 -41200,1.382388,3.5556955,,,,,,,,,,,,,, -41300,1.1529125,4.781762,,,,,,,,,,,,,, -41400,1.3358638,3.3778644,,,,,,,,,,,,,, -41500,1.0635438,5.458711,,,,,,,,,,,,,, -41600,0.98715377,5.613872,,,,,,,,,,,,,, -41700,1.0638647,5.524852,,,,,,,,,,,,,, -41800,1.1384404,4.5284853,,,,,,,,,,,,,, -41900,1.01482,5.4813366,,,,,,,,,,,,,, -42000,1.2124336,4.225887,,,,,,,,,,,,,, -42100,1.3816546,3.533864,,,,,,,,,,,,,, -42119,,,0.5964453220367432,1.8440462350845337,0.558899998664856,2.013521671295166,50000.0,0.4463000297546386,2.636042594909668,10000.0,18948.740561246872,20379.68393588066,18948.740561246872,1426.9955496788025,1.7593953609466553,0.0 -42200,1.5113583,3.400303,,,,,,,,,,,,,, -42300,1.8624357,3.3736012,,,,,,,,,,,,,, -42400,1.4242359,3.5621395,,,,,,,,,,,,,, -42500,1.371036,3.4284315,,,,,,,,,,,,,, -42600,1.3015555,3.870051,,,,,,,,,,,,,, -42700,1.2457107,4.2985454,,,,,,,,,,,,,, -42800,1.4660532,3.5397341,,,,,,,,,,,,,, -42900,1.303128,5.6142373,,,,,,,,,,,,,, -43000,1.5027285,3.3938675,,,,,,,,,,,,,, -43052,,,0.6012109518051147,1.8032618761062624,0.5563399791717529,1.999727249145508,50000.0,0.4493000209331512,2.613379716873169,10000.0,19368.863209962845,20831.470037698746,19368.863209962845,1458.5802025794983,1.7902934551239014,0.0 -43100,1.7120324,3.3581278,,,,,,,,,,,,,, -43200,1.1623122,4.707112,,,,,,,,,,,,,, -43300,1.2149812,5.543624,,,,,,,,,,,,,, -43400,1.4362799,3.4339538,,,,,,,,,,,,,, -43500,1.1842442,5.450888,,,,,,,,,,,,,, -43600,1.4393032,3.4080722,,,,,,,,,,,,,, -43700,1.6350367,3.4603286,,,,,,,,,,,,,, -43800,0.98215586,5.25334,,,,,,,,,,,,,, -43900,1.3308792,3.460457,,,,,,,,,,,,,, -43982,,,0.6162695288658142,1.7325092554092407,0.5649799704551697,1.955632925033569,50000.0,0.4533000290393829,2.577320337295532,10000.0,19788.84431171417,21285.02067923546,19788.84431171417,1492.062311410904,1.8297011852264404,0.0 -44000,1.3304374,3.4551754,,,,,,,,,,,,,, -44100,1.3824569,3.6892736,,,,,,,,,,,,,, -44200,1.3315808,3.7349555,,,,,,,,,,,,,, -44300,1.5116298,3.8838742,,,,,,,,,,,,,, -44400,1.5149051,3.5164855,,,,,,,,,,,,,, -44500,1.4637948,3.461868,,,,,,,,,,,,,, -44600,1.2925482,3.4257624,,,,,,,,,,,,,, -44700,1.1010494,5.486014,,,,,,,,,,,,,, -44800,1.1815528,5.536148,,,,,,,,,,,,,, -44900,1.164428,4.851369,,,,,,,,,,,,,, -44914,,,0.6068750023841858,1.7359304428100586,0.5651000142097473,1.932178974151612,50000.0,0.4526000320911407,2.549083709716797,10000.0,20208.771056890488,21737.75564527512,20208.771056890488,1524.7891371250153,1.8625941276550293,0.0 -45000,1.103652,5.0320306,,,,,,,,,,,,,, -45100,1.3503634,3.722867,,,,,,,,,,,,,, -45200,1.5358176,3.3297286,,,,,,,,,,,,,, -45300,1.1190456,4.8204527,,,,,,,,,,,,,, -45400,1.4405301,3.4783044,,,,,,,,,,,,,, -45500,1.1287029,4.933419,,,,,,,,,,,,,, -45600,1.4600513,3.8888073,,,,,,,,,,,,,, -45700,1.569131,3.381788,,,,,,,,,,,,,, -45800,1.4141794,3.417244,,,,,,,,,,,,,, -45847,,,0.6036913990974426,1.786840319633484,0.5648800134658813,1.972473382949829,50000.0,0.4513000249862671,2.5971016883850098,10000.0,20628.775985956192,22189.844033002853,20628.775985956192,1556.786475896835,1.8994367122650144,0.0 -45900,1.1794645,3.982215,,,,,,,,,,,,,, -46000,1.6764139,3.4199083,,,,,,,,,,,,,, -46100,1.5339693,3.3594813,,,,,,,,,,,,,, -46200,1.2652092,4.7738876,,,,,,,,,,,,,, -46300,1.2592766,3.7552469,,,,,,,,,,,,,, -46400,1.3543419,3.8895884,,,,,,,,,,,,,, -46500,1.1831217,5.6667385,,,,,,,,,,,,,, -46600,1.1883332,5.4256387,,,,,,,,,,,,,, -46700,0.99206036,5.2652273,,,,,,,,,,,,,, -46780,,,0.6071484088897705,1.7338517904281616,0.564520001411438,1.9355089664459229,50000.0,0.4513000249862671,2.5617563724517822,10000.0,21048.85280585289,22643.165594816208,21048.85280585289,1589.9524147510529,1.930569648742676,0.0 -46800,1.495532,3.3998697,,,,,,,,,,,,,, -46900,1.4493083,3.4364529,,,,,,,,,,,,,, -47000,1.4953852,3.3920708,,,,,,,,,,,,,, -47100,1.4485276,3.4369867,,,,,,,,,,,,,, -47200,1.4989452,3.323799,,,,,,,,,,,,,, -47300,1.3447555,3.542851,,,,,,,,,,,,,, -47400,1.4754025,3.2168643,,,,,,,,,,,,,, -47500,1.5365175,3.2920697,,,,,,,,,,,,,, -47600,1.532702,3.2723062,,,,,,,,,,,,,, -47700,1.497766,3.3441942,,,,,,,,,,,,,, -47715,,,0.6330664157867432,1.6488145589828491,0.5689399838447571,1.946183681488037,50000.0,0.4541000127792358,2.577409744262696,10000.0,21468.807317256927,23097.087222337723,21468.807317256927,1623.8386886119845,1.963141202926636,0.0 -47800,1.198439,5.5958157,,,,,,,,,,,,,, -47900,1.3190868,3.9731917,,,,,,,,,,,,,, -48000,1.4046742,5.4464765,,,,,,,,,,,,,, -48100,1.4787468,3.2987735,,,,,,,,,,,,,, -48200,1.3166398,4.223577,,,,,,,,,,,,,, -48300,1.5529503,3.311066,,,,,,,,,,,,,, -48400,1.5371679,3.5310934,,,,,,,,,,,,,, -48500,1.3742191,3.4922597,,,,,,,,,,,,,, -48600,1.663527,3.5275965,,,,,,,,,,,,,, -48651,,,0.6096093654632568,1.764771580696106,0.5680999755859375,1.94655179977417,50000.0,0.4523000121116638,2.5711467266082764,10000.0,21889.076536417007,23549.15618038177,21889.076536417007,1655.5556573867798,1.9970552921295168,0.0 -48700,1.4425895,3.3398368,,,,,,,,,,,,,, -48800,1.3904294,3.7466292,,,,,,,,,,,,,, -48900,1.2120572,4.185484,,,,,,,,,,,,,, -49000,1.2472523,4.8465343,,,,,,,,,,,,,, -49100,1.3098737,5.3789563,,,,,,,,,,,,,, -49200,1.5358516,3.4339867,,,,,,,,,,,,,, -49300,1.5633279,3.5217454,,,,,,,,,,,,,, -49400,1.4997625,3.3904636,,,,,,,,,,,,,, -49500,1.4845558,3.3338215,,,,,,,,,,,,,, -49585,,,0.6127734184265137,1.7522122859954834,0.5688999891281128,1.9498313665390008,50000.0,0.4537000358104706,2.5811941623687744,10000.0,22309.38073515892,24001.9259724617,22309.38073515892,1687.9413216114044,2.028259038925171,0.0 -49600,1.7445515,3.4832706,,,,,,,,,,,,,, -49700,1.187468,5.4441524,,,,,,,,,,,,,, -49800,1.4021412,3.5190473,,,,,,,,,,,,,, -49900,1.3947322,5.006082,,,,,,,,,,,,,, -50000,1.4508666,5.560526,,,,,,,,,,,,,, -50100,1.5835266,3.2408144,,,,,,,,,,,,,, -50200,1.2804539,4.426152,,,,,,,,,,,,,, -50300,1.7427387,3.359579,,,,,,,,,,,,,, -50400,1.2922136,5.572807,,,,,,,,,,,,,, -50500,1.6609607,3.2902603,,,,,,,,,,,,,, -50516,,,0.6240820288658142,1.6619887351989746,0.5757799744606018,1.885627508163452,50000.0,0.4613000154495239,2.5196902751922607,10000.0,22729.444053173065,24455.06451320648,22729.444053173065,1720.9364750385284,2.060332059860229,0.0 -50600,1.5622365,3.4197428,,,,,,,,,,,,,, -50700,1.5382979,3.6509573,,,,,,,,,,,,,, -50800,1.5551323,3.3529382,,,,,,,,,,,,,, -50900,1.3820513,5.4866905,,,,,,,,,,,,,, -51000,1.5229508,3.6232078,,,,,,,,,,,,,, -51100,1.4358735,3.595418,,,,,,,,,,,,,, -51200,1.4227506,3.2887132,,,,,,,,,,,,,, -51300,1.2572428,3.8551407,,,,,,,,,,,,,, -51400,1.5420604,3.397852,,,,,,,,,,,,,, -51449,,,0.6103906035423279,1.7757805585861206,0.5678399801254272,1.9511852264404297,50000.0,0.4574000239372253,2.5609421730041504,10000.0,23149.631596565247,24908.31686663628,23149.631596565247,1753.9146332740784,2.098830223083496,0.0 -51500,1.5691158,3.5189748,,,,,,,,,,,,,, -51600,1.5023681,3.422902,,,,,,,,,,,,,, -51700,1.3913031,3.7440562,,,,,,,,,,,,,, -51800,1.213551,5.4934797,,,,,,,,,,,,,, -51900,1.2854856,5.0527954,,,,,,,,,,,,,, -52000,1.2499423,5.5585594,,,,,,,,,,,,,, -52100,1.5798022,3.3642197,,,,,,,,,,,,,, -52200,1.730485,3.3060627,,,,,,,,,,,,,, -52300,1.7093455,3.3271592,,,,,,,,,,,,,, -52382,,,0.6201757788658142,1.6868958473205566,0.5767799615859985,1.888272166252136,50000.0,0.459600031375885,2.524887084960937,10000.0,23569.55105662346,25361.764173030853,23569.55105662346,1787.3572108745575,2.134948968887329,0.0 -52400,1.4053161,3.8115857,,,,,,,,,,,,,, -52500,1.4845648,3.2544317,,,,,,,,,,,,,, -52600,1.3050668,4.3410873,,,,,,,,,,,,,, -52700,1.2310189,5.4868364,,,,,,,,,,,,,, -52800,1.5975516,3.3835783,,,,,,,,,,,,,, -52900,1.381345,3.1716962,,,,,,,,,,,,,, -53000,1.5849785,3.406253,,,,,,,,,,,,,, -53100,2.0646918,3.381522,,,,,,,,,,,,,, -53200,1.5317427,3.3985085,,,,,,,,,,,,,, -53300,1.4925982,3.3990793,,,,,,,,,,,,,, -53312,,,0.6251562237739563,1.6901180744171145,0.5781999826431274,1.9079535007476809,50000.0,0.4633000195026397,2.536857843399048,10000.0,23989.92184877396,25815.29575109481,23989.92184877396,1820.436414003372,2.1685073375701904,0.0 -53400,1.6565161,3.3534086,,,,,,,,,,,,,, -53500,1.316027,5.454611,,,,,,,,,,,,,, -53600,1.2756563,5.276978,,,,,,,,,,,,,, -53700,1.5934567,3.3362887,,,,,,,,,,,,,, -53800,1.4097395,3.8220925,,,,,,,,,,,,,, -53900,1.383804,4.0213494,,,,,,,,,,,,,, -54000,1.4598277,3.342877,,,,,,,,,,,,,, -54100,1.5928878,3.3709104,,,,,,,,,,,,,, -54200,1.5449114,3.276125,,,,,,,,,,,,,, -54243,,,0.6433789134025574,1.6116437911987305,0.5797600150108337,1.8849976062774656,50000.0,0.4629000127315521,2.5215423107147217,10000.0,24410.181674718857,26268.046141147614,24410.181674718857,1852.839579820633,2.207714319229126,0.0 -54300,1.1430954,5.2945266,,,,,,,,,,,,,, -54400,1.4048015,4.6699796,,,,,,,,,,,,,, -54500,1.032048,5.1591,,,,,,,,,,,,,, -54600,1.4831686,3.4219453,,,,,,,,,,,,,, -54700,1.2499709,4.8525047,,,,,,,,,,,,,, -54800,1.4873922,3.5876036,,,,,,,,,,,,,, -54900,1.5301136,3.6232638,,,,,,,,,,,,,, -55000,1.611704,3.3500931,,,,,,,,,,,,,, -55100,1.5571814,3.3238306,,,,,,,,,,,,,, -55173,,,0.6248632669448853,1.6888477802276611,0.5834400057792664,1.886579155921936,50000.0,0.4659000337123871,2.519357681274414,10000.0,24830.14268946648,26721.406907081604,24830.14268946648,1886.1577606201167,2.2408063411712646,0.0 -55200,1.3616976,5.478109,,,,,,,,,,,,,, -55300,1.5652077,3.3905635,,,,,,,,,,,,,, -55400,1.3335819,3.9123478,,,,,,,,,,,,,, -55500,1.1968995,5.097232,,,,,,,,,,,,,, -55600,1.6478884,3.5232546,,,,,,,,,,,,,, -55700,1.6131194,3.382419,,,,,,,,,,,,,, -55800,1.5305479,3.8176017,,,,,,,,,,,,,, -55900,1.6855494,3.3546543,,,,,,,,,,,,,, -56000,1.671507,3.5467856,,,,,,,,,,,,,, -56100,1.662098,3.50743,,,,,,,,,,,,,, -56105,,,0.6288476586341858,1.6663347482681274,0.5836799740791321,1.8777925968170168,50000.0,0.4644000232219696,2.5098447799682617,10000.0,25250.628484487534,27173.98124217987,25250.628484487534,1918.165089845657,2.2733821868896484,0.0 -56200,1.3850583,3.311985,,,,,,,,,,,,,, -56300,1.4542466,3.200741,,,,,,,,,,,,,, -56400,1.5023487,3.2415314,,,,,,,,,,,,,, -56500,1.681023,3.3487022,,,,,,,,,,,,,, -56600,1.7185037,3.3155158,,,,,,,,,,,,,, -56700,1.2333915,5.033487,,,,,,,,,,,,,, -56800,1.5126139,3.3336241,,,,,,,,,,,,,, -56900,1.6385096,3.3124113,,,,,,,,,,,,,, -57000,1.4490612,3.2440798,,,,,,,,,,,,,, -57036,,,0.644726574420929,1.6122007369995115,0.5859599709510803,1.872389435768128,50000.0,0.4666000306606293,2.5017263889312744,10000.0,25670.61145210266,27626.93688774109,25670.61145210266,1951.056599855423,2.3057284355163574,0.0 -57100,1.4983749,3.302356,,,,,,,,,,,,,, -57200,1.7702936,3.2405646,,,,,,,,,,,,,, -57300,1.4506923,3.221621,,,,,,,,,,,,,, -57400,1.4190139,3.6207619,,,,,,,,,,,,,, -57500,1.2731638,4.6087832,,,,,,,,,,,,,, -57600,1.603738,3.351427,,,,,,,,,,,,,, -57700,1.3879901,3.587517,,,,,,,,,,,,,, -57800,1.5696008,3.2983763,,,,,,,,,,,,,, -57900,1.6433077,3.415452,,,,,,,,,,,,,, -57970,,,0.6248242259025574,1.6605851650238037,0.5842999815940857,1.8347177505493164,50000.0,0.4657000303268432,2.4734883308410645,10000.0,26091.04658770561,28080.76435756684,26091.04658770561,1984.3666186332705,2.3380661010742188,0.0 -58000,1.4365274,5.420682,,,,,,,,,,,,,, -58100,1.348461,5.5463905,,,,,,,,,,,,,, -58200,1.5401797,3.27129,,,,,,,,,,,,,, -58300,1.3013893,4.303942,,,,,,,,,,,,,, -58400,1.6327224,3.301077,,,,,,,,,,,,,, -58500,1.6226492,3.416287,,,,,,,,,,,,,, -58600,1.5347732,3.4295933,,,,,,,,,,,,,, -58700,1.1510726,5.479263,,,,,,,,,,,,,, -58800,1.5864342,3.266893,,,,,,,,,,,,,, -58900,,,0.6384179592132568,1.634387731552124,0.5914799571037292,1.8405027389526367,50000.0,0.4716000258922577,2.476768016815185,10000.0,26511.307859420776,28534.63808321953,26511.307859420776,2017.8947920799253,2.3739945888519287,0.0 -58900,1.1123562,4.7996864,,,,,,,,,,,,,, -59000,1.3227025,4.6990676,,,,,,,,,,,,,, -59100,1.2947797,5.0420976,,,,,,,,,,,,,, -59200,1.5172956,3.4949427,,,,,,,,,,,,,, -59300,1.4510733,3.3084412,,,,,,,,,,,,,, -59400,1.6615175,3.3263547,,,,,,,,,,,,,, -59500,1.4043825,4.970336,,,,,,,,,,,,,, -59600,1.4735208,3.2947788,,,,,,,,,,,,,, -59700,1.4610261,3.6422825,,,,,,,,,,,,,, -59800,1.2371893,5.0307426,,,,,,,,,,,,,, -59829,,,0.634570300579071,1.642100811004639,0.58406001329422,1.868959784507752,50000.0,0.4670000076293945,2.510704517364502,10000.0,26931.3852558136,28988.00011849404,26931.3852558136,2051.0959992408752,2.4093093872070312,0.0 -59900,1.504489,3.3101132,,,,,,,,,,,,,, -60000,1.4496084,3.6386626,,,,,,,,,,,,,, -60100,1.1746718,4.083628,,,,,,,,,,,,,, -60200,1.4413377,3.804346,,,,,,,,,,,,,, -60300,1.6229409,3.2238,,,,,,,,,,,,,, -60400,1.5181961,4.030171,,,,,,,,,,,,,, -60500,1.5803572,3.3420424,,,,,,,,,,,,,, -60600,1.4707938,5.1217947,,,,,,,,,,,,,, -60700,1.6403952,3.3382347,,,,,,,,,,,,,, -60758,,,0.6297656297683716,1.6264452934265137,0.5870400071144104,1.827787399291992,50000.0,0.4716000258922577,2.478414297103882,10000.0,27351.664115190502,29439.022877693176,27351.664115190502,2081.7578916549683,2.442704200744629,0.0 -60800,1.2105286,4.549324,,,,,,,,,,,,,, -60900,1.6213351,3.6136134,,,,,,,,,,,,,, -61000,1.5450014,3.2949924,,,,,,,,,,,,,, -61100,1.6457572,3.3736484,,,,,,,,,,,,,, -61200,1.6005667,3.2452834,,,,,,,,,,,,,, -61300,1.6584898,3.7061136,,,,,,,,,,,,,, -61400,1.4449581,3.3995912,,,,,,,,,,,,,, -61500,1.1472764,4.7328672,,,,,,,,,,,,,, -61600,1.6069117,3.3078268,,,,,,,,,,,,,, -61688,,,0.6357226371765137,1.6139147281646729,0.5892199873924255,1.826127290725708,50000.0,0.4758000373840332,2.4539406299591064,10000.0,27771.800651311874,29891.457787036896,27771.800651311874,2113.970644235611,2.47971773147583,0.0 -61700,1.6142535,3.304206,,,,,,,,,,,,,, -61800,1.6326506,3.4239793,,,,,,,,,,,,,, -61900,1.3495431,4.1056113,,,,,,,,,,,,,, -62000,1.5625657,3.2838912,,,,,,,,,,,,,, -62100,1.416327,3.8220675,,,,,,,,,,,,,, -62200,1.517544,3.3031352,,,,,,,,,,,,,, -62300,1.8436264,3.261653,,,,,,,,,,,,,, -62400,1.6917707,3.353998,,,,,,,,,,,,,, -62500,1.2521776,5.28837,,,,,,,,,,,,,, -62600,1.3972828,3.5366817,,,,,,,,,,,,,, -62618,,,0.6410741806030273,1.5769309997558594,0.594760000705719,1.7927261590957642,50000.0,0.4781000316143036,2.4324634075164795,10000.0,28192.091106176376,30345.31472611428,28192.091106176376,2147.4553532600403,2.5132803916931152,0.0 -62700,1.3280048,3.4527192,,,,,,,,,,,,,, -62800,1.3042964,4.4378185,,,,,,,,,,,,,, -62900,1.5541275,3.3629136,,,,,,,,,,,,,, -63000,1.27828,4.2299337,,,,,,,,,,,,,, -63100,1.604484,3.0840092,,,,,,,,,,,,,, -63200,1.3042941,5.525991,,,,,,,,,,,,,, -63300,1.3978498,5.3254175,,,,,,,,,,,,,, -63400,1.4516248,3.467319,,,,,,,,,,,,,, -63500,1.6442565,3.401492,,,,,,,,,,,,,, -63551,,,0.6690039038658142,1.4578577280044556,0.5971800088882446,1.771941065788269,50000.0,0.4824000298976898,2.3906712532043457,10000.0,28612.40795993805,30799.267755031586,28612.40795993805,2181.003592252732,2.5525503158569336,0.0 -63600,1.5976826,3.6927617,,,,,,,,,,,,,, -63700,1.5826575,3.4397902,,,,,,,,,,,,,, -63800,1.5948763,3.2289546,,,,,,,,,,,,,, -63900,1.3314673,3.7652311,,,,,,,,,,,,,, -64000,1.1883694,5.2371826,,,,,,,,,,,,,, -64100,1.317939,5.408944,,,,,,,,,,,,,, -64200,1.3073605,4.252235,,,,,,,,,,,,,, -64300,1.3676516,5.1584353,,,,,,,,,,,,,, -64400,1.8171817,3.331075,,,,,,,,,,,,,, -64482,,,0.6347460746765137,1.6374387741088867,0.5944199562072754,1.825912952423096,50000.0,0.4761000275611877,2.455456018447876,10000.0,29032.35076379776,31253.455913305283,29032.35076379776,2215.162506103516,2.5897204875946045,0.0 -64500,1.5919131,3.205287,,,,,,,,,,,,,, -64600,1.5820024,3.3634048,,,,,,,,,,,,,, -64700,1.7824969,3.354428,,,,,,,,,,,,,, -64800,1.5807029,3.2723794,,,,,,,,,,,,,, -64900,1.5084149,5.3567877,,,,,,,,,,,,,, -65000,1.3622155,5.4133368,,,,,,,,,,,,,, -65100,1.4742433,5.3583097,,,,,,,,,,,,,, -65200,1.5228542,5.48298,,,,,,,,,,,,,, -65300,1.7951577,3.3748505,,,,,,,,,,,,,, -65400,1.7109349,3.2218585,,,,,,,,,,,,,, -65414,,,0.6394921541213989,1.6012791395187378,0.5922200083732605,1.8196548223495483,50000.0,0.4762000143527984,2.461789846420288,10000.0,29452.668427705765,31707.489936828613,29452.668427705765,2248.791541337967,2.6287543773651123,0.0 -65500,1.6231822,3.190433,,,,,,,,,,,,,, -65600,1.2689961,5.3869267,,,,,,,,,,,,,, -65700,1.645395,3.204382,,,,,,,,,,,,,, -65800,1.3242984,3.9823925,,,,,,,,,,,,,, -65900,1.2433321,5.398849,,,,,,,,,,,,,, -66000,1.5493176,3.1742935,,,,,,,,,,,,,, -66100,1.7453438,3.2919366,,,,,,,,,,,,,, -66200,1.4459465,3.543358,,,,,,,,,,,,,, -66300,1.4462259,3.57048,,,,,,,,,,,,,, -66344,,,0.659472644329071,1.5096582174301147,0.6010400056838989,1.7638288736343384,50000.0,0.4827000200748443,2.3863396644592285,10000.0,29872.83366370201,32161.46227788925,29872.83366370201,2282.513402938843,2.665854454040528,0.0 -66400,1.5569477,3.8874402,,,,,,,,,,,,,, -66500,1.3689376,5.1765695,,,,,,,,,,,,,, -66600,1.8528705,3.2030723,,,,,,,,,,,,,, -66700,1.5702305,3.2959838,,,,,,,,,,,,,, -66800,1.6801336,3.1948576,,,,,,,,,,,,,, -66900,1.2991452,5.0659723,,,,,,,,,,,,,, -67000,1.2941878,5.3893967,,,,,,,,,,,,,, -67100,1.3140473,4.431617,,,,,,,,,,,,,, -67200,1.6734222,5.47831,,,,,,,,,,,,,, -67275,,,0.6406835913658142,1.6189045906066897,0.6001200079917908,1.808770775794983,50000.0,0.4785000085830688,2.4507222175598145,10000.0,30292.992176771164,32615.25714874268,30292.992176771164,2316.063986301422,2.7025163173675537,0.0 -67300,1.6060834,3.2637033,,,,,,,,,,,,,, -67400,1.4524039,3.6126168,,,,,,,,,,,,,, -67500,1.4119623,4.0893345,,,,,,,,,,,,,, -67600,1.2839415,4.1482573,,,,,,,,,,,,,, -67700,1.5011653,4.2728868,,,,,,,,,,,,,, -67800,1.5233786,3.1648126,,,,,,,,,,,,,, -67900,1.5824115,3.3536975,,,,,,,,,,,,,, -68000,1.3800095,3.9385726,,,,,,,,,,,,,, -68100,1.5496217,5.2401876,,,,,,,,,,,,,, -68200,1.5103719,3.1654432,,,,,,,,,,,,,, -68204,,,0.6403124928474426,1.6007063388824463,0.5951399803161621,1.8148831129074097,50000.0,0.4745000302791595,2.447077989578247,10000.0,30712.97263765335,33067.442917346954,30712.97263765335,2348.187881231308,2.735978603363037,0.0 -68300,1.5910008,3.4289813,,,,,,,,,,,,,, -68400,1.3642025,4.14033,,,,,,,,,,,,,, -68500,1.2113974,5.0794077,,,,,,,,,,,,,, -68600,1.7832807,3.2718034,,,,,,,,,,,,,, -68700,1.387542,5.361459,,,,,,,,,,,,,, -68800,1.7055851,3.1315105,,,,,,,,,,,,,, -68900,1.5934743,3.1979256,,,,,,,,,,,,,, -69000,1.6572347,3.1688397,,,,,,,,,,,,,, -69100,1.3114356,3.8530657,,,,,,,,,,,,,, -69131,,,0.6518359184265137,1.5784968137741089,0.599399983882904,1.818591475486756,50000.0,0.4790000319480896,2.451775312423706,10000.0,31133.27097249031,33520.85844230652,31133.27097249031,2381.223728656769,2.768976926803589,0.0 -69200,1.676012,3.21815,,,,,,,,,,,,,, -69300,1.2830222,4.9195065,,,,,,,,,,,,,, -69400,1.3812168,5.423938,,,,,,,,,,,,,, -69500,1.4620105,3.1344388,,,,,,,,,,,,,, -69600,1.8391864,3.1056275,,,,,,,,,,,,,, -69700,1.7789077,3.0398178,,,,,,,,,,,,,, -69800,1.2286471,4.827286,,,,,,,,,,,,,, -69900,1.1619811,4.969931,,,,,,,,,,,,,, -70000,1.8499442,3.269785,,,,,,,,,,,,,, -70062,,,0.6419140696525574,1.6183698177337646,0.5989399552345276,1.82237446308136,50000.0,0.4794000089168548,2.46195387840271,10000.0,31553.39439797401,33972.844121456146,31553.39439797401,2413.0035569667816,2.803950309753418,0.0 -70100,1.3830862,3.8668883,,,,,,,,,,,,,, -70200,1.4180045,3.535344,,,,,,,,,,,,,, -70300,1.6801807,3.1599011,,,,,,,,,,,,,, -70400,2.464447,3.1937683,,,,,,,,,,,,,, -70500,1.5304937,3.1068358,,,,,,,,,,,,,, -70600,1.5004061,3.1751688,,,,,,,,,,,,,, -70700,1.50858,3.662951,,,,,,,,,,,,,, -70800,1.4841799,4.180625,,,,,,,,,,,,,, -70900,1.2304865,4.5978055,,,,,,,,,,,,,, -70992,,,0.64990234375,1.5852410793304443,0.6050199866294861,1.7876743078231812,50000.0,0.4846000373363495,2.424272298812866,10000.0,31973.70577263832,34426.241938352585,31973.70577263832,2445.9967498779297,2.849007129669189,0.0 -71000,1.3851706,5.1688275,,,,,,,,,,,,,, -71100,1.4256396,4.491454,,,,,,,,,,,,,, -71200,1.6008431,3.2249808,,,,,,,,,,,,,, -71300,1.5081888,3.5644007,,,,,,,,,,,,,, -71400,1.4169166,4.859352,,,,,,,,,,,,,, -71500,1.6341839,3.3146908,,,,,,,,,,,,,, -71600,1.6908731,3.2095265,,,,,,,,,,,,,, -71700,1.5263324,3.5930586,,,,,,,,,,,,,, -71800,1.38823,5.3910456,,,,,,,,,,,,,, -71900,1.7885149,3.1246076,,,,,,,,,,,,,, -71924,,,0.6512500047683716,1.5644874572753906,0.6061800122261047,1.7602328062057495,50000.0,0.4853000342845917,2.3803160190582275,10000.0,32393.78838253021,34879.81064558029,32393.78838253021,2479.398825407028,2.8843657970428467,0.0 -72000,1.4639395,5.0911484,,,,,,,,,,,,,, -72100,1.64386,3.4575775,,,,,,,,,,,,,, -72200,1.9091108,3.4282875,,,,,,,,,,,,,, -72300,1.2651471,5.1256104,,,,,,,,,,,,,, -72400,1.7653238,3.2374635,,,,,,,,,,,,,, -72500,1.5397265,4.0329704,,,,,,,,,,,,,, -72600,1.4216036,5.2358766,,,,,,,,,,,,,, -72700,1.4537553,5.011012,,,,,,,,,,,,,, -72800,1.5874435,5.1456556,,,,,,,,,,,,,, -72858,,,0.6725976467132568,1.4549552202224731,0.6046000123023987,1.746518850326538,50000.0,0.4829000234603882,2.3781180381774902,10000.0,32814.01478791237,35334.18482041359,32814.01478791237,2513.4549918174744,2.9265990257263184,0.0 -72900,1.7424577,3.2598145,,,,,,,,,,,,,, -73000,1.2469481,4.9926805,,,,,,,,,,,,,, -73100,1.6323265,3.178835,,,,,,,,,,,,,, -73200,1.6361439,3.2502997,,,,,,,,,,,,,, -73300,1.70106,3.2053695,,,,,,,,,,,,,, -73400,1.2892039,4.6564083,,,,,,,,,,,,,, -73500,1.6862514,3.194808,,,,,,,,,,,,,, -73600,1.4929807,4.8417716,,,,,,,,,,,,,, -73700,1.3677318,4.1235576,,,,,,,,,,,,,, -73792,,,0.6545507907867432,1.5467392206192017,0.6106799840927124,1.7467910051345823,50000.0,0.4903000295162201,2.389852523803711,10000.0,33234.2273080349,35787.68767333031,33234.2273080349,2546.6604709625244,2.96217679977417,0.0 -73800,1.6199818,3.2511642,,,,,,,,,,,,,, -73900,1.8154607,3.1006312,,,,,,,,,,,,,, -74000,1.5161791,3.5249636,,,,,,,,,,,,,, -74100,1.7020983,3.1513484,,,,,,,,,,,,,, -74200,1.5938978,3.1405513,,,,,,,,,,,,,, -74300,1.5009265,3.845849,,,,,,,,,,,,,, -74400,1.7347406,3.187452,,,,,,,,,,,,,, -74500,1.519469,5.447557,,,,,,,,,,,,,, -74600,1.6439437,3.46754,,,,,,,,,,,,,, -74700,1.2344842,5.415341,,,,,,,,,,,,,, -74726,,,0.6555468440055847,1.561237096786499,0.6126199960708618,1.7528570890426636,50000.0,0.4914000332355499,2.389246225357056,10000.0,33654.26421165466,36240.71160840988,33654.26421165466,2579.5513093471527,3.009035110473633,0.0 -74800,1.7530323,3.2482436,,,,,,,,,,,,,, -74900,1.4866145,4.42741,,,,,,,,,,,,,, -75000,1.555438,3.7919102,,,,,,,,,,,,,, -75100,1.6382617,3.1720023,,,,,,,,,,,,,, -75200,1.8266723,3.1990407,,,,,,,,,,,,,, -75300,1.8411355,3.1535063,,,,,,,,,,,,,, -75400,1.2860484,5.2221932,,,,,,,,,,,,,, -75500,1.757268,3.227028,,,,,,,,,,,,,, -75600,1.6201311,3.4825296,,,,,,,,,,,,,, -75657,,,0.6675195097923279,1.4419922828674316,0.6165199875831604,1.693387746810913,50000.0,0.4906000196933746,2.3438291549682617,10000.0,34074.34876227379,36693.855113983154,34074.34876227379,2612.520934343338,3.050046682357788,0.0 -75700,1.7152388,3.1691809,,,,,,,,,,,,,, -75800,1.5392256,3.5779862,,,,,,,,,,,,,, -75900,1.6803912,3.1717877,,,,,,,,,,,,,, -76000,1.8367581,3.1447237,,,,,,,,,,,,,, -76100,1.6023439,3.2524877,,,,,,,,,,,,,, -76200,1.3110597,5.156523,,,,,,,,,,,,,, -76300,1.6476371,3.2548842,,,,,,,,,,,,,, -76400,1.6233957,3.15841,,,,,,,,,,,,,, -76500,1.5546654,3.5132065,,,,,,,,,,,,,, -76588,,,0.6521288752555847,1.591919183731079,0.6082599759101868,1.792444348335266,50000.0,0.4869000315666199,2.4197306632995605,10000.0,34494.56536388397,37144.96238017082,34494.56536388397,2643.3242888450623,3.088836431503296,0.0 -76600,1.4943091,5.38359,,,,,,,,,,,,,, -76700,1.6731535,3.267724,,,,,,,,,,,,,, -76800,1.5076786,3.7059789,,,,,,,,,,,,,, -76900,1.792654,3.1587274,,,,,,,,,,,,,, -77000,1.5583674,3.186892,,,,,,,,,,,,,, -77100,1.7037394,3.4138227,,,,,,,,,,,,,, -77200,1.7498084,5.5038824,,,,,,,,,,,,,, -77300,1.5244073,4.252227,,,,,,,,,,,,,, -77400,1.6849269,3.2023327,,,,,,,,,,,,,, -77500,1.4651424,4.0736136,,,,,,,,,,,,,, -77520,,,0.6513671875,1.5582828521728516,0.6090599894523621,1.7538148164749146,50000.0,0.4862000346183777,2.3950002193450928,10000.0,34915.00139904022,37598.55631041527,34915.00139904022,2676.3921184539795,3.130977869033813,0.0 -77600,1.6896968,3.1948507,,,,,,,,,,,,,, -77700,1.8347946,3.3470047,,,,,,,,,,,,,, -77800,1.701295,3.1536548,,,,,,,,,,,,,, -77900,1.7761006,3.152521,,,,,,,,,,,,,, -78000,1.4446794,4.4007063,,,,,,,,,,,,,, -78100,1.9434739,3.2624068,,,,,,,,,,,,,, -78200,1.7338097,3.1680439,,,,,,,,,,,,,, -78300,1.538898,3.199873,,,,,,,,,,,,,, -78400,1.5385795,4.7864485,,,,,,,,,,,,,, -78452,,,0.6617382764816284,1.5026801824569702,0.6145200133323669,1.7247921228408811,50000.0,0.4884000122547149,2.357494354248047,10000.0,35335.17891597748,38051.32184123993,35335.17891597748,2708.896152973175,3.167029619216919,0.0 -78500,1.7817847,3.2638469,,,,,,,,,,,,,, -78600,1.7055565,3.0831904,,,,,,,,,,,,,, -78700,1.7360799,3.1078234,,,,,,,,,,,,,, -78800,1.7495968,3.477221,,,,,,,,,,,,,, -78900,1.3026168,4.1739616,,,,,,,,,,,,,, -79000,1.8036052,3.3244271,,,,,,,,,,,,,, -79100,1.5582215,3.1376805,,,,,,,,,,,,,, -79200,1.5529711,4.4547048,,,,,,,,,,,,,, -79300,1.8500981,3.2159789,,,,,,,,,,,,,, -79383,,,0.6783593893051147,1.4718852043151855,0.612280011177063,1.779030442237854,50000.0,0.4864000082015991,2.4213500022888184,10000.0,35755.31320667267,38505.05619072914,35755.31320667267,2742.400264978409,3.214477300643921,0.0 -79400,1.7448863,3.0636406,,,,,,,,,,,,,, -79500,1.6989119,3.193265,,,,,,,,,,,,,, -79600,1.6689899,3.388154,,,,,,,,,,,,,, -79700,1.5151275,3.8643496,,,,,,,,,,,,,, -79800,1.4172211,4.8651943,,,,,,,,,,,,,, -79900,1.4466548,4.2876787,,,,,,,,,,,,,, -80000,1.5422649,3.454085,,,,,,,,,,,,,, -80100,1.8048248,3.2310467,,,,,,,,,,,,,, -80200,1.7065902,3.04839,,,,,,,,,,,,,, -80300,1.5021846,3.5116308,,,,,,,,,,,,,, -80314,,,0.655468761920929,1.5589125156402588,0.6153199672698975,1.7471002340316772,50000.0,0.4903000295162201,2.387882232666016,10000.0,36175.47290205956,38958.28948068619,36175.47290205956,2775.382992506027,3.256436347961426,0.0 -80400,1.4560024,3.9963086,,,,,,,,,,,,,, -80500,1.8626243,3.1161313,,,,,,,,,,,,,, -80600,1.5583191,5.327501,,,,,,,,,,,,,, -80700,1.9317765,3.2310488,,,,,,,,,,,,,, -80800,1.6433932,3.2843544,,,,,,,,,,,,,, -80900,1.6919289,3.429954,,,,,,,,,,,,,, -81000,1.7233974,3.5562384,,,,,,,,,,,,,, -81100,1.8441538,3.1214414,,,,,,,,,,,,,, -81200,1.3130851,4.636489,,,,,,,,,,,,,, -81246,,,0.672070324420929,1.45991849899292,0.6218799948692322,1.682550311088562,50000.0,0.4992000162601471,2.3286523818969727,10000.0,36595.45416688919,39412.10371303558,36595.45416688919,2809.1208050251007,3.3032021522521973,0.0 -81300,1.4812577,4.7941747,,,,,,,,,,,,,, -81400,1.515204,5.319288,,,,,,,,,,,,,, -81500,1.6023967,3.9899173,,,,,,,,,,,,,, -81600,1.4577634,3.80508,,,,,,,,,,,,,, -81700,1.3857633,4.3722305,,,,,,,,,,,,,, -81800,1.7459079,3.268541,,,,,,,,,,,,,, -81900,2.0327501,3.1193988,,,,,,,,,,,,,, -82000,1.8124292,3.1660886,,,,,,,,,,,,,, -82100,1.8316209,3.296138,,,,,,,,,,,,,, -82178,,,0.677539050579071,1.4339114427566528,0.6177999973297119,1.7054755687713623,50000.0,0.4916000366210937,2.343575954437256,10000.0,37015.47077083588,39865.35120391846,37015.47077083588,2842.265993595124,3.340360879898072,0.0 -82200,1.441675,4.070748,,,,,,,,,,,,,, -82300,1.5160213,3.7636817,,,,,,,,,,,,,, -82400,1.8166907,3.2188442,,,,,,,,,,,,,, -82500,1.8651149,3.1109395,,,,,,,,,,,,,, -82600,1.6943499,3.53268,,,,,,,,,,,,,, -82700,1.4111267,4.2070074,,,,,,,,,,,,,, -82800,1.3290955,4.16824,,,,,,,,,,,,,, -82900,1.4007545,4.9196515,,,,,,,,,,,,,, -83000,1.8313886,3.2063918,,,,,,,,,,,,,, -83100,1.7395041,5.226446,,,,,,,,,,,,,, -83109,,,0.6681249737739563,1.490073561668396,0.619879961013794,1.6944104433059692,50000.0,0.5033000111579895,2.3191699981689453,10000.0,37435.68727660179,40315.00443935394,37435.68727660179,2871.618143796921,3.3766629695892334,0.0 -83200,2.2015042,3.1451616,,,,,,,,,,,,,, -83300,1.8128233,3.134149,,,,,,,,,,,,,, -83400,1.6397768,3.235239,,,,,,,,,,,,,, -83500,1.9910445,3.1556942,,,,,,,,,,,,,, -83600,1.6131837,3.4268827,,,,,,,,,,,,,, -83700,1.7568089,3.4821498,,,,,,,,,,,,,, -83800,1.7249395,3.2406335,,,,,,,,,,,,,, -83900,1.3534255,4.843944,,,,,,,,,,,,,, -84000,1.6825964,3.1443543,,,,,,,,,,,,,, -84038,,,0.6651171445846558,1.465654730796814,0.6218000054359436,1.6730984449386597,50000.0,0.4980000257492065,2.30544662475586,10000.0,37855.84566473961,40767.02870512009,37855.84566473961,2903.392266750336,3.419344186782837,0.0 -84100,1.7878562,3.1371121,,,,,,,,,,,,,, -84200,1.6967214,3.3398697,,,,,,,,,,,,,, -84300,1.8086962,2.9943767,,,,,,,,,,,,,, -84400,1.8768419,3.0561814,,,,,,,,,,,,,, -84500,1.7875309,3.0413833,,,,,,,,,,,,,, -84600,1.8870096,3.225758,,,,,,,,,,,,,, -84700,1.3737004,4.3165703,,,,,,,,,,,,,, -84800,1.7145312,3.2649963,,,,,,,,,,,,,, -84900,1.7098373,2.9569237,,,,,,,,,,,,,, -84971,,,0.6802929639816284,1.428873062133789,0.6230999827384949,1.6761740446090698,50000.0,0.503000020980835,2.307889223098755,10000.0,38275.783217191696,41219.60556221008,38275.783217191696,2935.9429478645325,3.4596714973449707,0.0 -85000,1.7097057,3.1790738,,,,,,,,,,,,,, -85100,1.4648111,3.920629,,,,,,,,,,,,,, -85200,1.7157692,3.1746404,,,,,,,,,,,,,, -85300,1.536196,4.734055,,,,,,,,,,,,,, -85400,1.614269,4.075532,,,,,,,,,,,,,, -85500,1.7302507,3.3477728,,,,,,,,,,,,,, -85600,1.5912011,3.5358212,,,,,,,,,,,,,, -85700,1.6699785,3.2833393,,,,,,,,,,,,,, -85800,1.7517014,3.221642,,,,,,,,,,,,,, -85900,1.495642,5.245803,,,,,,,,,,,,,, -85903,,,0.6666210889816284,1.5058388710021973,0.623259961605072,1.7001662254333496,50000.0,0.499500036239624,2.3466873168945312,10000.0,38695.95108413696,41672.24136352539,38695.95108413696,2968.3241169452667,3.498311281204224,0.0 -86000,1.3879658,5.273839,,,,,,,,,,,,,, -86100,1.5409633,4.781111,,,,,,,,,,,,,, -86200,1.5257936,4.785886,,,,,,,,,,,,,, -86300,1.7556303,3.1663418,,,,,,,,,,,,,, -86400,1.750875,3.0582414,,,,,,,,,,,,,, -86500,1.6846994,3.3167577,,,,,,,,,,,,,, -86600,1.7589017,3.3178363,,,,,,,,,,,,,, -86700,1.8665415,3.2277565,,,,,,,,,,,,,, -86800,1.9159005,3.2278218,,,,,,,,,,,,,, -86834,,,0.6658398509025574,1.5209107398986816,0.6223999857902527,1.7161909341812134,50000.0,0.5,2.360891819000244,10000.0,39116.04857087135,42123.79933476448,39116.04857087135,2999.692320823669,3.542412757873535,0.0 -86900,1.92056,3.052807,,,,,,,,,,,,,, -87000,1.5684205,5.4573107,,,,,,,,,,,,,, -87100,1.5296482,4.208612,,,,,,,,,,,,,, -87200,1.743186,3.1699493,,,,,,,,,,,,,, -87300,1.7267534,3.0350718,,,,,,,,,,,,,, -87400,1.6847261,3.062407,,,,,,,,,,,,,, -87500,1.4852238,4.559154,,,,,,,,,,,,,, -87600,1.5204388,4.0012074,,,,,,,,,,,,,, -87700,1.839572,3.1160774,,,,,,,,,,,,,, -87764,,,0.6809179782867432,1.4363369941711426,0.6262999773025513,1.676653504371643,50000.0,0.4973000288009643,2.3251760005950928,10000.0,39536.07976317406,42577.90255379677,39536.07976317406,3033.672209739685,3.5862622261047363,0.0 -87800,1.9788023,3.1161983,,,,,,,,,,,,,, -87900,1.503041,4.8900847,,,,,,,,,,,,,, -88000,1.4997703,4.8757944,,,,,,,,,,,,,, -88100,2.0054023,3.302462,,,,,,,,,,,,,, -88200,1.8173721,3.0456004,,,,,,,,,,,,,, -88300,1.9860737,3.0576007,,,,,,,,,,,,,, -88400,1.6790563,3.1735249,,,,,,,,,,,,,, -88500,1.6328298,3.4066086,,,,,,,,,,,,,, -88600,1.7291322,3.051236,,,,,,,,,,,,,, -88696,,,0.6968359351158142,1.3666648864746094,0.6245999932289124,1.67709481716156,50000.0,0.5035000443458557,2.3024938106536865,10000.0,39956.33709144592,43030.892899513245,39956.33709144592,3066.316313743592,3.6272923946380615,0.0 -88700,1.4944265,4.47392,,,,,,,,,,,,,, -88800,1.5801105,5.1469593,,,,,,,,,,,,,, -88900,1.6155825,3.239984,,,,,,,,,,,,,, -89000,1.5910504,3.9050152,,,,,,,,,,,,,, -89100,1.5880237,4.1047893,,,,,,,,,,,,,, -89200,1.7452924,3.0426018,,,,,,,,,,,,,, -89300,1.9108725,3.0852857,,,,,,,,,,,,,, -89400,1.8101596,3.1863558,,,,,,,,,,,,,, -89500,1.9127085,3.1593316,,,,,,,,,,,,,, -89600,1.7778869,3.4329386,,,,,,,,,,,,,, -89628,,,0.6750390529632568,1.4493118524551392,0.6331999897956848,1.6430368423461914,50000.0,0.5107000470161438,2.262535572052002,10000.0,40376.373056173325,43481.935858011246,40376.373056173325,3097.2331142425537,3.6694083213806152,0.0 -89700,1.9486002,3.0279822,,,,,,,,,,,,,, -89800,1.7565926,3.3730574,,,,,,,,,,,,,, -89900,1.6716523,3.0000699,,,,,,,,,,,,,, -90000,2.0645638,3.0023026,,,,,,,,,,,,,, -90100,1.7401576,3.0557868,,,,,,,,,,,,,, -90200,1.7771263,2.9079597,,,,,,,,,,,,,, -90300,1.8766629,3.0717854,,,,,,,,,,,,,, -90400,1.7025872,3.3529737,,,,,,,,,,,,,, -90500,2.1087883,3.4099708,,,,,,,,,,,,,, -90554,,,0.6821093559265137,1.3882906436920166,0.6303799748420715,1.624939203262329,50000.0,0.5046000480651855,2.265634775161743,10000.0,40796.29656982422,43931.92548966408,40796.29656982422,3127.205216407776,3.715543746948242,0.0 -90600,1.6722347,3.5848544,,,,,,,,,,,,,, -90700,2.0000708,3.1451952,,,,,,,,,,,,,, -90800,1.6410571,3.1989088,,,,,,,,,,,,,, -90900,1.6340289,5.287427,,,,,,,,,,,,,, -91000,1.7817467,3.07293,,,,,,,,,,,,,, -91100,1.7265866,3.1570396,,,,,,,,,,,,,, -91200,1.7160882,3.4495733,,,,,,,,,,,,,, -91300,1.759134,3.1324909,,,,,,,,,,,,,, -91400,1.8364289,3.084399,,,,,,,,,,,,,, -91481,,,0.6850780844688416,1.3804093599319458,0.6272000074386597,1.6394050121307373,50000.0,0.5101000070571899,2.2855710983276367,10000.0,41216.31773328781,44384.83210873604,41216.31773328781,3159.993196964264,3.765239715576172,0.0 -91500,1.546068,4.245015,,,,,,,,,,,,,, -91600,1.6174295,3.2293878,,,,,,,,,,,,,, -91700,1.7107309,3.5823948,,,,,,,,,,,,,, -91800,1.6548613,5.3269463,,,,,,,,,,,,,, -91900,1.5515804,4.9534845,,,,,,,,,,,,,, -92000,1.7485572,2.998014,,,,,,,,,,,,,, -92100,1.7835784,3.159623,,,,,,,,,,,,,, -92200,1.9795914,3.3771403,,,,,,,,,,,,,, -92300,1.6394272,4.9987297,,,,,,,,,,,,,, -92400,1.7393336,3.1351285,,,,,,,,,,,,,, -92412,,,0.6785351634025574,1.4298800230026243,0.6362599730491638,1.630461573600769,50000.0,0.5117000341415405,2.257552146911621,10000.0,41636.72108960152,44838.06574630737,41636.72108960152,3192.737065553665,3.803889751434326,0.0 -92500,1.6945547,3.152821,,,,,,,,,,,,,, -92600,1.6596633,3.5989738,,,,,,,,,,,,,, -92700,1.8751303,3.0137353,,,,,,,,,,,,,, -92800,1.7847402,3.0492654,,,,,,,,,,,,,, -92900,1.745689,4.415225,,,,,,,,,,,,,, -93000,1.7712257,3.2036192,,,,,,,,,,,,,, -93100,1.6122906,3.7322197,,,,,,,,,,,,,, -93200,1.8506285,3.0657947,,,,,,,,,,,,,, -93300,1.6242509,5.2677765,,,,,,,,,,,,,, -93343,,,0.68212890625,1.4268741607666016,0.6340399980545044,1.639679789543152,50000.0,0.5101000070571899,2.2780954837799072,10000.0,42057.02818584442,45289.685396671295,42057.02818584442,3223.965080499649,3.8404388427734375,0.0 -93400,1.6247156,3.6773903,,,,,,,,,,,,,, -93500,1.9475758,3.1067626,,,,,,,,,,,,,, -93600,2.0056448,3.131492,,,,,,,,,,,,,, -93700,1.9655939,3.1566513,,,,,,,,,,,,,, -93800,1.7468055,4.258772,,,,,,,,,,,,,, -93900,1.8840066,3.6766877,,,,,,,,,,,,,, -94000,1.9565688,3.0311584,,,,,,,,,,,,,, -94100,1.939974,2.9431007,,,,,,,,,,,,,, -94200,1.7874789,4.9807253,,,,,,,,,,,,,, -94272,,,0.6900390386581421,1.358396291732788,0.632860004901886,1.6014665365219116,50000.0,0.5105000138282776,2.2469406127929688,10000.0,42476.98532438278,45741.764285326,42476.98532438278,3255.977776527405,3.900721549987793,0.0 -94300,1.5454192,4.2912216,,,,,,,,,,,,,, -94400,1.9817595,2.9555101,,,,,,,,,,,,,, -94500,1.8679677,2.9884934,,,,,,,,,,,,,, -94600,1.5703464,5.1123495,,,,,,,,,,,,,, -94700,1.8512373,3.0422468,,,,,,,,,,,,,, -94800,1.893924,3.830222,,,,,,,,,,,,,, -94900,1.7627873,4.9484286,,,,,,,,,,,,,, -95000,1.814544,5.025979,,,,,,,,,,,,,, -95100,1.8922234,3.1325445,,,,,,,,,,,,,, -95200,1.8555514,2.911878,,,,,,,,,,,,,, -95203,,,0.69740229845047,1.3583762645721436,0.6399799585342407,1.6108720302581787,50000.0,0.5120000243186951,2.2654266357421875,10000.0,42896.978048563,46192.10404133797,42896.978048563,3286.2305703163147,3.9465410709381104,0.0 -95300,1.8620523,2.9861815,,,,,,,,,,,,,, -95400,1.7864208,3.7003262,,,,,,,,,,,,,, -95500,1.7794455,3.0702434,,,,,,,,,,,,,, -95600,2.0232317,3.236124,,,,,,,,,,,,,, -95700,2.0508556,3.1704135,,,,,,,,,,,,,, -95800,2.012931,3.0270066,,,,,,,,,,,,,, -95900,1.9605142,5.213504,,,,,,,,,,,,,, -96000,1.8760085,2.9953153,,,,,,,,,,,,,, -96100,1.8833288,5.226363,,,,,,,,,,,,,, -96132,,,0.6863671541213989,1.4062321186065674,0.6400399804115295,1.6143362522125244,50000.0,0.5175000429153442,2.22903060913086,10000.0,43316.982568740845,46642.69738292694,43316.982568740845,3316.7220873832703,3.99554705619812,0.0 -96200,1.791521,4.3760037,,,,,,,,,,,,,, -96300,1.9197905,3.029633,,,,,,,,,,,,,, -96400,1.7882764,3.470837,,,,,,,,,,,,,, -96500,1.8221564,4.933984,,,,,,,,,,,,,, -96600,2.1358407,2.9754379,,,,,,,,,,,,,, -96700,1.7950002,5.239003,,,,,,,,,,,,,, -96800,1.962416,3.1716142,,,,,,,,,,,,,, -96900,1.9658033,3.0201347,,,,,,,,,,,,,, -97000,1.95357,3.0163753,,,,,,,,,,,,,, -97062,,,0.685546875,1.4155725240707395,0.6375600099563599,1.633133053779602,50000.0,0.5182999968528748,2.2834768295288086,10000.0,43736.97590112686,47092.60734796524,43736.97590112686,3346.5461995601654,4.039233446121216,0.0 -97100,1.5871648,3.916845,,,,,,,,,,,,,, -97200,1.8251423,3.1336763,,,,,,,,,,,,,, -97300,1.7691671,2.8742385,,,,,,,,,,,,,, -97400,1.8336169,2.9543633,,,,,,,,,,,,,, -97500,1.6466919,3.4195738,,,,,,,,,,,,,, -97600,1.9610822,3.106442,,,,,,,,,,,,,, -97700,1.6063147,4.328992,,,,,,,,,,,,,, -97800,1.9543244,4.372752,,,,,,,,,,,,,, -97900,1.5655904,4.2713695,,,,,,,,,,,,,, -97992,,,0.7066406011581421,1.2971094846725464,0.6403200030326843,1.582248568534851,50000.0,0.5205000042915344,2.2080962657928467,10000.0,44156.90197920799,47543.903435230255,44156.90197920799,3377.816866159439,4.090280055999756,0.0 -98000,1.9552553,5.17334,,,,,,,,,,,,,, -98100,1.8868597,3.0505896,,,,,,,,,,,,,, -98200,1.89117,3.451421,,,,,,,,,,,,,, -98300,2.0230281,3.7594278,,,,,,,,,,,,,, -98400,2.0189476,3.0284314,,,,,,,,,,,,,, -98500,1.7766505,4.3522053,,,,,,,,,,,,,, -98600,1.9619982,2.9424343,,,,,,,,,,,,,, -98700,2.0354655,3.0867329,,,,,,,,,,,,,, -98800,1.5387607,4.843735,,,,,,,,,,,,,, -98900,2.0217276,3.068943,,,,,,,,,,,,,, -98923,,,0.6883788704872131,1.393426775932312,0.641319990158081,1.6070661544799805,50000.0,0.5091000199317932,2.254444599151612,10000.0,44576.86071538925,47997.69262051582,44576.86071538925,3411.555104732513,4.134812593460083,0.0 -99000,1.8268806,3.623696,,,,,,,,,,,,,, -99100,1.8191845,3.6691117,,,,,,,,,,,,,, -99200,1.631457,3.8121269,,,,,,,,,,,,,, -99300,1.9995348,3.9211924,,,,,,,,,,,,,, -99400,1.956489,4.6877394,,,,,,,,,,,,,, -99500,1.8932769,3.0744007,,,,,,,,,,,,,, -99600,2.0336666,3.0609431,,,,,,,,,,,,,, -99700,2.1911316,3.897142,,,,,,,,,,,,,, -99800,1.8445191,5.229978,,,,,,,,,,,,,, -99856,,,0.6871289014816284,1.3720650672912598,0.6365999579429626,1.5965046882629397,50000.0,0.5134000182151794,2.239524602890014,10000.0,44996.839626550674,48449.945254564285,44996.839626550674,3443.7383959293365,4.1757354736328125,0.0 -99900,2.0011456,3.0545955,,,,,,,,,,,,,, -100000,2.0746636,3.23832,,,,,,,,,,,,,, -100100,2.1032262,3.0752676,,,,,,,,,,,,,, -100200,1.6125511,4.2630486,,,,,,,,,,,,,, -100300,2.1235578,4.926929,,,,,,,,,,,,,, -100400,1.9810011,3.113409,,,,,,,,,,,,,, -100500,1.9262766,3.1063697,,,,,,,,,,,,,, -100600,1.49796,4.6653857,,,,,,,,,,,,,, -100700,1.9591252,3.19549,,,,,,,,,,,,,, -100785,,,0.7017577886581421,1.312267780303955,0.6451999545097351,1.5605461597442627,50000.0,0.522599995136261,2.181353807449341,10000.0,45416.97100830078,48903.77703499794,45416.97100830078,3477.3515434265137,4.214470624923706,0.0 -100800,2.1074336,3.0727587,,,,,,,,,,,,,, -100900,1.9393499,2.9679186,,,,,,,,,,,,,, -101000,2.0335407,5.2009187,,,,,,,,,,,,,, -101100,1.8067209,3.802067,,,,,,,,,,,,,, -101200,1.9690892,3.081893,,,,,,,,,,,,,, -101300,1.7764571,3.2001772,,,,,,,,,,,,,, -101400,1.8508295,3.0361314,,,,,,,,,,,,,, -101500,1.9839643,3.2253377,,,,,,,,,,,,,, -101600,1.9056427,2.9792035,,,,,,,,,,,,,, -101700,2.193856,2.9375336,,,,,,,,,,,,,, -101716,,,0.69349604845047,1.36771821975708,0.6436399817466736,1.5866146087646484,50000.0,0.5249000191688538,2.2173092365264893,10000.0,45837.05639505386,49357.55435633659,45837.05639505386,3510.946757078171,4.263065338134766,0.0 -101800,1.615869,4.5551205,,,,,,,,,,,,,, -101900,2.04749,3.0458405,,,,,,,,,,,,,, -102000,1.8201424,5.0590706,,,,,,,,,,,,,, -102100,1.596886,5.1250997,,,,,,,,,,,,,, -102200,2.0326467,3.014753,,,,,,,,,,,,,, -102300,2.032871,3.1267815,,,,,,,,,,,,,, -102400,1.803225,5.129404,,,,,,,,,,,,,, -102500,1.6945813,4.778289,,,,,,,,,,,,,, -102600,2.0082152,2.986356,,,,,,,,,,,,,, -102645,,,0.6882616877555847,1.39771568775177,0.6396999955177307,1.6158915758132937,50000.0,0.5223000049591064,2.250572919845581,10000.0,46257.20961642265,49808.57501959801,46257.20961642265,3541.7239258289337,4.30495023727417,0.0 -102700,1.6667897,4.8017287,,,,,,,,,,,,,, -102800,2.0546732,2.9978943,,,,,,,,,,,,,, -102900,1.7349681,4.739115,,,,,,,,,,,,,, -103000,2.2153463,3.0894158,,,,,,,,,,,,,, -103100,1.9509207,3.2268379,,,,,,,,,,,,,, -103200,2.0666232,5.1022153,,,,,,,,,,,,,, -103300,1.7745492,3.6098852,,,,,,,,,,,,,, -103400,1.8647311,2.9397564,,,,,,,,,,,,,, -103500,1.9840163,3.1314678,,,,,,,,,,,,,, -103573,,,0.7084375023841858,1.2893983125686646,0.6509000062942505,1.5381120443344116,50000.0,0.5329000353813171,2.161540508270264,10000.0,46677.277096033096,50258.69889426232,46677.277096033096,3571.6743774414062,4.362541913986206,0.0 -103600,1.615406,3.7384725,,,,,,,,,,,,,, -103700,1.7752054,3.7866914,,,,,,,,,,,,,, -103800,2.179744,2.9276946,,,,,,,,,,,,,, -103900,1.9423285,3.008544,,,,,,,,,,,,,, -104000,2.4483032,3.1608896,,,,,,,,,,,,,, -104100,1.8437287,5.195168,,,,,,,,,,,,,, -104200,2.136684,2.9131641,,,,,,,,,,,,,, -104300,1.9705641,3.0234344,,,,,,,,,,,,,, -104400,1.8788702,3.4561994,,,,,,,,,,,,,, -104500,1.917423,3.5623338,,,,,,,,,,,,,, -104501,,,0.7164257764816284,1.3079707622528076,0.6469599604606628,1.601804494857788,50000.0,0.526199996471405,2.236384868621826,10000.0,47097.76482272148,50709.71799230576,47097.76482272148,3602.106454372406,4.4127349853515625,0.0 -104600,1.8922273,2.9151795,,,,,,,,,,,,,, -104700,2.03283,2.9561367,,,,,,,,,,,,,, -104800,1.6692876,4.4191046,,,,,,,,,,,,,, -104900,2.1184592,2.97861,,,,,,,,,,,,,, -105000,2.008763,2.9029472,,,,,,,,,,,,,, -105100,1.916662,3.9830637,,,,,,,,,,,,,, -105200,1.6772853,3.7507029,,,,,,,,,,,,,, -105300,2.0353842,2.9823976,,,,,,,,,,,,,, -105400,2.3247156,3.3130488,,,,,,,,,,,,,, -105427,,,0.6813281178474426,1.432881474494934,0.6276599764823914,1.6766186952590942,50000.0,0.5055000185966492,2.2996857166290283,10000.0,47517.686259269714,51163.54053092003,47517.686259269714,3635.9119005203247,4.460398435592651,0.0 -105500,2.1114821,2.8735168,,,,,,,,,,,,,, -105600,1.7923309,4.845955,,,,,,,,,,,,,, -105700,1.6727122,4.989375,,,,,,,,,,,,,, -105800,1.8012875,3.730073,,,,,,,,,,,,,, -105900,1.7453521,3.9457598,,,,,,,,,,,,,, -106000,2.0563886,3.0474653,,,,,,,,,,,,,, -106100,1.8656638,3.6493394,,,,,,,,,,,,,, -106200,2.1256046,2.8973813,,,,,,,,,,,,,, -106300,2.196552,2.8397899,,,,,,,,,,,,,, -106359,,,0.7081249952316284,1.3071223497390747,0.653439998626709,1.5344473123550415,50000.0,0.5250000357627869,2.1776108741760254,10000.0,47937.69239234924,51615.63697743416,47937.69239234924,3667.913044929504,4.501140356063843,0.0 -106400,2.0915356,2.984655,,,,,,,,,,,,,, -106500,2.1223586,3.0399985,,,,,,,,,,,,,, -106600,2.1256146,2.8739574,,,,,,,,,,,,,, -106700,1.7226006,4.870388,,,,,,,,,,,,,, -106800,2.0160563,3.0902581,,,,,,,,,,,,,, -106900,2.005765,2.906763,,,,,,,,,,,,,, -107000,2.0539296,3.0262573,,,,,,,,,,,,,, -107100,2.1536083,3.0016181,,,,,,,,,,,,,, -107200,1.7430625,4.787407,,,,,,,,,,,,,, -107289,,,0.7215625047683716,1.2222944498062134,0.6552599668502808,1.523241400718689,50000.0,0.535800039768219,2.1337265968322754,10000.0,48357.65385222435,52067.7736222744,48357.65385222435,3699.998573303223,4.5414369106292725,0.0 -107300,2.1087513,3.0578094,,,,,,,,,,,,,, -107400,2.8975646,3.9596386,,,,,,,,,,,,,, -107500,2.0806324,5.1903253,,,,,,,,,,,,,, -107600,2.0710049,3.9079375,,,,,,,,,,,,,, -107700,1.9054886,3.387282,,,,,,,,,,,,,, -107800,2.0975826,2.812378,,,,,,,,,,,,,, -107900,2.0201092,3.0023856,,,,,,,,,,,,,, -108000,2.0457273,3.12168,,,,,,,,,,,,,, -108100,2.315782,2.8926327,,,,,,,,,,,,,, -108200,1.8984523,3.96583,,,,,,,,,,,,,, -108222,,,0.7074999809265137,1.2904741764068604,0.6606400012969971,1.5038014650344849,50000.0,0.5349000096321106,2.1375465393066406,10000.0,48777.81280827522,52518.94608283043,48777.81280827522,3730.917119503021,4.586863279342651,0.0 -108300,2.158533,3.105456,,,,,,,,,,,,,, -108400,2.0592675,5.0736732,,,,,,,,,,,,,, -108500,2.2349467,5.114026,,,,,,,,,,,,,, -108600,1.9485115,4.6091356,,,,,,,,,,,,,, -108700,1.7756062,4.903991,,,,,,,,,,,,,, -108800,1.9322706,4.611401,,,,,,,,,,,,,, -108900,2.5403464,4.2963905,,,,,,,,,,,,,, -109000,1.685874,4.3503036,,,,,,,,,,,,,, -109100,2.0083206,2.9715543,,,,,,,,,,,,,, -109153,,,0.7104882597923279,1.2939804792404177,0.6558799743652344,1.531550407409668,50000.0,0.5330000519752502,2.1510488986968994,10000.0,49197.85523939133,52972.1984333992,49197.85523939133,3764.027851104736,4.637696266174316,0.0 -109200,2.058856,2.943552,,,,,,,,,,,,,, -109300,2.275843,2.8153944,,,,,,,,,,,,,, -109400,4.195169,3.4802303,,,,,,,,,,,,,, -109500,2.0166261,3.0172758,,,,,,,,,,,,,, -109600,2.006994,4.9596834,,,,,,,,,,,,,, -109700,1.7104145,5.0073466,,,,,,,,,,,,,, -109800,1.8245225,4.9858036,,,,,,,,,,,,,, -109900,1.7225659,3.9698484,,,,,,,,,,,,,, -110000,1.770186,5.0018616,,,,,,,,,,,,,, -110084,,,0.7136523127555847,1.3198362588882446,0.6524199843406677,1.5822229385375977,50000.0,0.5314000248908997,2.220554113388061,10000.0,49617.85936307907,53426.40553641319,49617.85936307907,3798.13270521164,4.687514781951904,0.0 -110100,2.1007524,3.113289,,,,,,,,,,,,,, -110200,2.0086555,2.9263442,,,,,,,,,,,,,, -110300,1.6276445,4.5429196,,,,,,,,,,,,,, -110400,2.0212495,3.6455312,,,,,,,,,,,,,, -110500,2.3969321,3.0025816,,,,,,,,,,,,,, -110600,1.8482584,3.1982324,,,,,,,,,,,,,, -110700,1.9024134,3.2768118,,,,,,,,,,,,,, -110800,2.053822,3.0116856,,,,,,,,,,,,,, -110900,2.1175191,2.9881845,,,,,,,,,,,,,, -111000,2.2878,2.951422,,,,,,,,,,,,,, -111016,,,0.7125585675239563,1.2730872631072998,0.6649599671363831,1.4868433475494385,50000.0,0.537600040435791,2.108808279037476,10000.0,50038.1953496933,53881.36761689186,50038.1953496933,3832.661499977112,4.737124443054199,0.0 -111100,1.8857347,2.9148307,,,,,,,,,,,,,, -111200,2.1236944,3.1729875,,,,,,,,,,,,,, -111300,2.0691397,5.149493,,,,,,,,,,,,,, -111400,2.103289,2.8664987,,,,,,,,,,,,,, -111500,2.1314662,3.005967,,,,,,,,,,,,,, -111600,1.9088062,4.0206485,,,,,,,,,,,,,, -111700,1.9460422,4.3511367,,,,,,,,,,,,,, -111800,2.2202604,4.105709,,,,,,,,,,,,,, -111900,1.9335775,3.8387861,,,,,,,,,,,,,, -111947,,,0.7182812094688416,1.2427871227264404,0.6640799641609192,1.481744647026062,50000.0,0.5469000339508057,2.1059060096740723,10000.0,50458.33082890511,54329.931025743484,50458.33082890511,3860.9939935207367,4.784902572631836,0.0 -112000,2.1439588,3.165837,,,,,,,,,,,,,, -112100,2.035024,2.8472514,,,,,,,,,,,,,, -112200,1.9508859,4.3618555,,,,,,,,,,,,,, -112300,2.3289797,2.9070873,,,,,,,,,,,,,, -112400,2.0152357,3.200427,,,,,,,,,,,,,, -112500,1.8834081,3.645537,,,,,,,,,,,,,, -112600,2.2020068,2.9079149,,,,,,,,,,,,,, -112700,1.9148893,3.54027,,,,,,,,,,,,,, -112800,2.1572776,2.8738523,,,,,,,,,,,,,, -112868,,,0.7190234065055847,1.2374742031097412,0.6615200042724609,1.4908949136734009,50000.0,0.5367000102996826,2.121800184249878,10000.0,50878.57825565338,54783.457184791565,50878.57825565338,3894.177620410919,4.832216739654541,0.0 -112900,2.0281563,3.3012347,,,,,,,,,,,,,, -113000,1.9224898,4.7150764,,,,,,,,,,,,,, -113100,2.1894267,2.963845,,,,,,,,,,,,,, -113200,2.3719923,3.1622941,,,,,,,,,,,,,, -113300,2.3165765,2.8920782,,,,,,,,,,,,,, -113400,2.164202,3.0409913,,,,,,,,,,,,,, -113500,2.4555988,2.9018085,,,,,,,,,,,,,, -113600,2.1514227,2.8697891,,,,,,,,,,,,,, -113700,1.986786,3.065067,,,,,,,,,,,,,, -113797,,,0.7408984303474426,1.1787163019180298,0.666979968547821,1.5026799440383911,50000.0,0.5442000031471252,2.132467031478882,10000.0,51298.48287606239,55236.77574682236,51298.48287606239,3927.492713212967,4.883016586303711,0.0 -113800,1.9229275,3.0308213,,,,,,,,,,,,,, -113900,2.1849449,2.8931272,,,,,,,,,,,,,, -114000,2.1497602,4.7124667,,,,,,,,,,,,,, -114100,2.2027805,2.8991716,,,,,,,,,,,,,, -114200,1.8996942,4.3269434,,,,,,,,,,,,,, -114300,2.0633366,2.9760897,,,,,,,,,,,,,, -114400,1.9826747,4.875321,,,,,,,,,,,,,, -114500,2.1601927,3.3857007,,,,,,,,,,,,,, -114600,2.2499099,2.9442155,,,,,,,,,,,,,, -114700,2.0416973,3.5508268,,,,,,,,,,,,,, -114730,,,0.7202734351158142,1.223569393157959,0.6654999852180481,1.4597338438034058,50000.0,0.5469000339508057,2.0868093967437744,10000.0,51718.38898730278,55690.430280447006,51718.38898730278,3961.1463055610657,4.928993463516235,0.0 -114800,1.9544641,4.45623,,,,,,,,,,,,,, -114900,2.1433163,3.7282214,,,,,,,,,,,,,, -115000,2.0811577,3.0136325,,,,,,,,,,,,,, -115100,2.1489358,2.8037734,,,,,,,,,,,,,, -115200,2.2145617,2.888301,,,,,,,,,,,,,, -115300,2.2311463,2.8720684,,,,,,,,,,,,,, -115400,2.0975468,2.9580266,,,,,,,,,,,,,, -115500,2.0859482,3.1629589,,,,,,,,,,,,,, -115600,2.0576005,2.8914242,,,,,,,,,,,,,, -115661,,,0.7249413728713989,1.2017278671264648,0.6695599555969238,1.4541553258895874,50000.0,0.5440000295639038,2.075801372528076,10000.0,52138.31260251999,56142.81825685501,52138.31260251999,3993.5162620544434,4.973785161972046,0.0 -115700,2.3337855,2.8379807,,,,,,,,,,,,,, -115800,1.9788507,4.406853,,,,,,,,,,,,,, -115900,2.2014694,2.7806387,,,,,,,,,,,,,, -116000,2.1050537,3.7301466,,,,,,,,,,,,,, -116100,2.259188,2.8803356,,,,,,,,,,,,,, -116200,2.0067558,3.8618627,,,,,,,,,,,,,, -116300,2.2316701,2.9091973,,,,,,,,,,,,,, -116400,2.036876,3.3116643,,,,,,,,,,,,,, -116500,2.0926993,2.7991498,,,,,,,,,,,,,, -116592,,,0.7337695360183716,1.1900928020477295,0.675059974193573,1.4655356407165527,50000.0,0.5515000224113464,2.100456476211548,10000.0,52558.60073399544,56596.52179288864,52558.60073399544,4026.839727401733,5.018024444580078,0.0 -116600,2.1021183,2.9751859,,,,,,,,,,,,,, -116700,2.1753492,3.0453608,,,,,,,,,,,,,, -116800,2.1508577,4.1624126,,,,,,,,,,,,,, -116900,2.4092798,3.0040345,,,,,,,,,,,,,, -117000,2.486259,2.8310475,,,,,,,,,,,,,, -117100,1.9557247,4.220297,,,,,,,,,,,,,, -117200,2.3441398,2.8693266,,,,,,,,,,,,,, -117300,2.1760595,2.980228,,,,,,,,,,,,,, -117400,2.3777409,2.8599794,,,,,,,,,,,,,, -117500,2.198602,2.827673,,,,,,,,,,,,,, -117526,,,0.730664074420929,1.1986579895019531,0.6748799681663513,1.436489820480347,50000.0,0.5491000413894653,2.076021194458008,10000.0,52978.923567056656,57048.03144288063,52978.923567056656,4057.929902076721,5.0658793449401855,0.0 -117600,2.2243807,3.3749247,,,,,,,,,,,,,, -117700,2.3876853,2.7356913,,,,,,,,,,,,,, -117800,2.2225323,2.7930458,,,,,,,,,,,,,, -117900,2.2175918,2.820287,,,,,,,,,,,,,, -118000,2.448853,2.8573236,,,,,,,,,,,,,, -118100,2.2918596,2.8423567,,,,,,,,,,,,,, -118200,2.2348301,2.8091526,,,,,,,,,,,,,, -118300,2.2667236,2.843639,,,,,,,,,,,,,, -118400,2.220694,4.9360685,,,,,,,,,,,,,, -118457,,,0.7255663871765137,1.2088725566864014,0.6714800000190735,1.4509501457214355,50000.0,0.5466000437736511,2.084545135498047,10000.0,53399.125903368,57499.6692841053,53399.125903368,4089.267287492752,5.115900993347168,0.0 -118500,1.9001539,4.076723,,,,,,,,,,,,,, -118600,2.4066343,2.842811,,,,,,,,,,,,,, -118700,2.1038952,4.381259,,,,,,,,,,,,,, -118800,2.1805303,2.8367307,,,,,,,,,,,,,, -118900,2.1019616,3.346269,,,,,,,,,,,,,, -119000,2.5333142,2.7713535,,,,,,,,,,,,,, -119100,2.220549,2.8719645,,,,,,,,,,,,,, -119200,2.3749552,2.8290687,,,,,,,,,,,,,, -119300,2.4369135,2.8419213,,,,,,,,,,,,,, -119385,,,0.7369726300239563,1.1480770111083984,0.6777399778366089,1.4084073305130005,50000.0,0.5509999990463257,2.038119316101074,10000.0,53819.215997457504,57952.38986158371,53819.215997457504,4121.7964906692505,5.169644355773926,0.0 -119400,2.3001432,3.096837,,,,,,,,,,,,,, -119500,2.0803044,3.135548,,,,,,,,,,,,,, -119600,2.247351,3.2057037,,,,,,,,,,,,,, -119700,2.131621,2.8828464,,,,,,,,,,,,,, -119800,2.1976202,3.9115736,,,,,,,,,,,,,, -119900,2.182934,4.714372,,,,,,,,,,,,,, -120000,2.3773086,2.7683723,,,,,,,,,,,,,, -120100,2.0029533,3.9225318,,,,,,,,,,,,,, -120200,2.3381243,4.4183683,,,,,,,,,,,,,, -120300,2.4608822,4.011881,,,,,,,,,,,,,, -120316,,,0.7320312261581421,1.1968371868133545,0.6744199991226196,1.4489213228225708,50000.0,0.5499000549316406,2.078464984893799,10000.0,54239.2432115078,58406.38825464249,54239.2432115078,4155.675201416016,5.214017868041992,0.0 -120400,2.5341983,2.8551943,,,,,,,,,,,,,, -120500,2.2309325,2.7928724,,,,,,,,,,,,,, -120600,2.7905693,3.8681679,,,,,,,,,,,,,, -120700,2.320518,2.9286911,,,,,,,,,,,,,, -120800,2.3170512,2.8112302,,,,,,,,,,,,,, -120900,2.3715055,2.8940878,,,,,,,,,,,,,, -121000,2.4036138,2.7090678,,,,,,,,,,,,,, -121100,2.2534294,2.839868,,,,,,,,,,,,,, -121200,2.337228,2.689302,,,,,,,,,,,,,, -121247,,,0.7334765195846558,1.1736871004104614,0.6791200041770935,1.4089908599853516,50000.0,0.5569000244140625,2.0281074047088623,10000.0,54659.52649736405,58857.70808959007,54659.52649736405,4186.622891664505,5.255097150802612,0.0 -121300,3.1410944,3.9986322,,,,,,,,,,,,,, -121400,2.1756117,4.219657,,,,,,,,,,,,,, -121500,2.2693017,2.9098263,,,,,,,,,,,,,, -121600,2.2214227,4.2672496,,,,,,,,,,,,,, -121700,2.3367152,2.9353342,,,,,,,,,,,,,, -121800,2.6921594,2.7361975,,,,,,,,,,,,,, -121900,2.303961,4.2161546,,,,,,,,,,,,,, -122000,2.4356694,2.7247689,,,,,,,,,,,,,, -122100,2.404952,2.863844,,,,,,,,,,,,,, -122176,,,0.7381835579872131,1.1373696327209473,0.6790800094604492,1.3966187238693235,50000.0,0.5547000169754028,2.018074989318848,10000.0,55079.78484630585,59309.61920070648,55079.78484630585,4218.173988342285,5.308438777923584,0.0 -122200,2.1848724,4.8862696,,,,,,,,,,,,,, -122300,2.3316267,3.2612605,,,,,,,,,,,,,, -122400,2.2226043,4.9149714,,,,,,,,,,,,,, -122500,2.3054552,2.6932092,,,,,,,,,,,,,, -122600,2.3726974,4.979402,,,,,,,,,,,,,, -122700,2.530676,3.5069008,,,,,,,,,,,,,, -122800,2.4242055,4.998727,,,,,,,,,,,,,, -122900,2.3192055,3.0309744,,,,,,,,,,,,,, -123000,2.6895149,2.7861252,,,,,,,,,,,,,, -123100,2.485748,2.7334816,,,,,,,,,,,,,, -123107,,,0.7488671541213989,1.1549774408340454,0.681439995765686,1.463075876235962,50000.0,0.5599000453948975,2.0738844871521,10000.0,55500.03019499779,59760.25983929634,55500.03019499779,4248.474764108658,5.354437828063965,0.0 -123200,2.4942517,4.726226,,,,,,,,,,,,,, -123300,2.4663157,2.8879802,,,,,,,,,,,,,, -123400,2.172104,3.5919597,,,,,,,,,,,,,, -123500,2.303821,4.0571127,,,,,,,,,,,,,, -123600,2.4505754,2.793887,,,,,,,,,,,,,, -123700,2.3028238,3.2384057,,,,,,,,,,,,,, -123800,2.2232442,4.932774,,,,,,,,,,,,,, -123900,2.2265186,2.8557086,,,,,,,,,,,,,, -124000,2.6948462,4.4436345,,,,,,,,,,,,,, -124036,,,0.7421875,1.1317152976989746,0.6843999624252319,1.3768173456192017,50000.0,0.5573000311851501,2.0066962242126465,10000.0,55920.1936044693,60211.03713226318,55920.1936044693,4278.991844892502,5.403183937072754,0.0 -124100,3.4608126,4.256061,,,,,,,,,,,,,, -124200,2.4612787,3.0139127,,,,,,,,,,,,,, -124300,2.4133675,4.8620567,,,,,,,,,,,,,, -124400,2.5518293,2.8002183,,,,,,,,,,,,,, -124500,2.4472163,2.7276523,,,,,,,,,,,,,, -124600,2.0762076,3.326152,,,,,,,,,,,,,, -124700,2.3517072,2.8009195,,,,,,,,,,,,,, -124800,2.372029,2.7423208,,,,,,,,,,,,,, -124900,2.4803088,2.7257676,,,,,,,,,,,,,, -124965,,,0.742968738079071,1.1672155857086182,0.6843799948692322,1.420559048652649,50000.0,0.5586000084877014,2.050241708755493,10000.0,56340.23328781128,60664.11954545975,56340.23328781128,4311.943066358566,5.446610689163208,0.0 -125000,2.2862344,3.5281057,,,,,,,,,,,,,, -125100,2.2994869,2.716918,,,,,,,,,,,,,, -125200,2.515708,2.732944,,,,,,,,,,,,,, -125300,2.5256462,2.7017126,,,,,,,,,,,,,, -125400,2.615875,3.148155,,,,,,,,,,,,,, -125500,2.4677224,4.5924187,,,,,,,,,,,,,, -125600,2.4295242,2.708416,,,,,,,,,,,,,, -125700,2.4205296,3.9676175,,,,,,,,,,,,,, -125800,2.6081405,2.8627954,,,,,,,,,,,,,, -125892,,,0.7481640577316284,1.098036289215088,0.6840199828147888,1.375110149383545,50000.0,0.563800036907196,1.99488365650177,10000.0,56760.46367549896,61117.63439536095,56760.46367549896,4345.138010501862,5.488474130630493,0.0 -125900,2.2063415,3.490728,,,,,,,,,,,,,, -126000,2.2185433,4.0551257,,,,,,,,,,,,,, -126100,2.5075471,2.767253,,,,,,,,,,,,,, -126200,2.479672,2.7850754,,,,,,,,,,,,,, -126300,2.324943,2.7920833,,,,,,,,,,,,,, -126400,2.3706517,4.507047,,,,,,,,,,,,,, -126500,2.3003135,3.2320309,,,,,,,,,,,,,, -126600,2.642663,2.6800685,,,,,,,,,,,,,, -126700,2.3210208,3.5682864,,,,,,,,,,,,,, -126800,2.5411546,2.752006,,,,,,,,,,,,,, -126822,,,0.7443749904632568,1.1106796264648438,0.6901800036430359,1.3542922735214231,50000.0,0.5568000078201294,1.996273279190064,10000.0,57180.5708527565,61572.45949554443,57180.5708527565,4379.761988639832,5.534748315811157,0.0 -126900,2.5151613,2.8092976,,,,,,,,,,,,,, -127000,2.6844823,2.7533512,,,,,,,,,,,,,, -127100,2.5178556,2.9680107,,,,,,,,,,,,,, -127200,2.6215522,2.76517,,,,,,,,,,,,,, -127300,2.3542135,3.8426332,,,,,,,,,,,,,, -127400,2.246045,2.7453668,,,,,,,,,,,,,, -127500,2.3396723,3.1214814,,,,,,,,,,,,,, -127600,2.7164094,2.8195705,,,,,,,,,,,,,, -127700,2.8018663,2.8419333,,,,,,,,,,,,,, -127753,,,0.7480273246765137,1.1173925399780271,0.688759982585907,1.3697985410690308,50000.0,0.5613000392913818,1.97961175441742,10000.0,57600.62677979469,62026.65469145775,57600.62677979469,4413.811166524887,5.576773881912232,0.0 -127800,2.2908247,4.2199135,,,,,,,,,,,,,, -127900,2.3431654,2.6586466,,,,,,,,,,,,,, -128000,2.4023201,2.870623,,,,,,,,,,,,,, -128100,2.4581773,2.9138749,,,,,,,,,,,,,, -128200,2.4190784,2.952682,,,,,,,,,,,,,, -128300,2.641077,2.7931435,,,,,,,,,,,,,, -128400,2.5866005,2.6464965,,,,,,,,,,,,,, -128500,2.5214746,2.9161818,,,,,,,,,,,,,, -128600,2.6690006,2.8433194,,,,,,,,,,,,,, -128684,,,0.75990229845047,1.0859085321426392,0.6924999952316284,1.37213134765625,50000.0,0.566100001335144,1.9980629682540887,10000.0,58020.76798009872,62480.12114715576,58020.76798009872,4447.039078474045,5.6262922286987305,0.0 -128700,2.6382945,2.9530528,,,,,,,,,,,,,, -128800,2.5821521,2.760471,,,,,,,,,,,,,, -128900,2.5402443,2.7356353,,,,,,,,,,,,,, -129000,2.3718684,4.8603735,,,,,,,,,,,,,, -129100,2.398315,3.49261,,,,,,,,,,,,,, -129200,2.4531894,4.423971,,,,,,,,,,,,,, -129300,2.1111472,2.977996,,,,,,,,,,,,,, -129400,2.130526,3.3149495,,,,,,,,,,,,,, -129500,2.495605,2.718237,,,,,,,,,,,,,, -129600,2.6110914,2.6549904,,,,,,,,,,,,,, -129613,,,0.7639062404632568,1.050950288772583,0.6924999952316284,1.3538199663162231,50000.0,0.570900022983551,1.9598753452301023,10000.0,58440.78542947769,62934.27048492432,58440.78542947769,4481.064235448837,5.68550181388855,0.0 -129700,2.5849125,2.9902627,,,,,,,,,,,,,, -129800,2.6609058,4.6397142,,,,,,,,,,,,,, -129900,2.4096234,4.5643826,,,,,,,,,,,,,, -130000,2.4008594,4.00626,,,,,,,,,,,,,, -130100,2.710632,4.693765,,,,,,,,,,,,,, -130200,2.571098,2.773695,,,,,,,,,,,,,, -130300,2.6011944,3.0181575,,,,,,,,,,,,,, -130400,2.5115635,3.193895,,,,,,,,,,,,,, -130500,2.519387,2.7076073,,,,,,,,,,,,,, -130545,,,0.7498437166213989,1.1301133632659912,0.693120002746582,1.3789575099945068,50000.0,0.5597000122070312,2.01959228515625,10000.0,58860.95683288574,63387.47819709778,58860.95683288574,4513.996830224991,5.740103721618652,0.0 -130600,2.545705,3.518062,,,,,,,,,,,,,, -130700,2.605357,2.756136,,,,,,,,,,,,,, -130800,2.3535256,3.2793396,,,,,,,,,,,,,, -130900,2.6828508,2.6670327,,,,,,,,,,,,,, -131000,2.4836638,3.7987254,,,,,,,,,,,,,, -131100,2.706571,2.6983237,,,,,,,,,,,,,, -131200,2.5822945,3.0802014,,,,,,,,,,,,,, -131300,2.3786116,3.5800848,,,,,,,,,,,,,, -131400,2.835336,2.7204633,,,,,,,,,,,,,, -131476,,,0.756054699420929,1.0679931640625,0.6955400109291077,1.333950757980347,50000.0,0.5682000517845154,1.9795842170715328,10000.0,59280.94118070602,63841.59644985199,59280.94118070602,4548.038062334061,5.784178018569946,0.0 -131500,2.4363608,2.9300747,,,,,,,,,,,,,, -131600,2.5419407,2.9235916,,,,,,,,,,,,,, -131700,2.4548619,2.702619,,,,,,,,,,,,,, -131800,2.8893652,2.7639575,,,,,,,,,,,,,, -131900,2.3937006,3.7441838,,,,,,,,,,,,,, -132000,2.600603,2.6047227,,,,,,,,,,,,,, -132100,2.6711795,2.5328481,,,,,,,,,,,,,, -132200,2.9202166,2.99855,,,,,,,,,,,,,, -132300,2.4719722,3.7571115,,,,,,,,,,,,,, -132400,3.0113597,2.7008681,,,,,,,,,,,,,, -132405,,,0.7659375071525574,1.050911784172058,0.6958400011062622,1.3542274236679075,50000.0,0.5670000314712524,1.9746955633163448,10000.0,59701.01261425018,64294.9606654644,59701.01261425018,4581.237900733948,5.829130172729492,0.0 -132500,2.589831,2.854181,,,,,,,,,,,,,, -132600,2.787167,2.730276,,,,,,,,,,,,,, -132700,2.883353,4.7953234,,,,,,,,,,,,,, -132800,2.4616394,2.6804018,,,,,,,,,,,,,, -132900,2.676008,2.7028518,,,,,,,,,,,,,, -133000,2.7698874,2.6305642,,,,,,,,,,,,,, -133100,2.283636,3.2467027,,,,,,,,,,,,,, -133200,2.5504804,3.423644,,,,,,,,,,,,,, -133300,2.9750917,4.661392,,,,,,,,,,,,,, -133333,,,0.7583202719688416,1.063936471939087,0.7030799984931946,1.3143028020858765,50000.0,0.5729000568389893,1.9410356283187864,10000.0,60121.277137994766,64749.58311963081,60121.277137994766,4615.49994468689,5.87558388710022,0.0 -133400,2.8110685,2.639998,,,,,,,,,,,,,, -133500,2.661981,2.653549,,,,,,,,,,,,,, -133600,3.733183,2.6221755,,,,,,,,,,,,,, -133700,2.6831331,3.2622125,,,,,,,,,,,,,, -133800,3.182508,2.6782997,,,,,,,,,,,,,, -133900,2.734917,3.0383332,,,,,,,,,,,,,, -134000,2.8121147,2.691082,,,,,,,,,,,,,, -134100,2.4440587,2.9855714,,,,,,,,,,,,,, -134200,2.6928508,3.297235,,,,,,,,,,,,,, -134263,,,0.7602148056030273,1.056712985038757,0.700939953327179,1.3244282007217407,50000.0,0.5733000040054321,1.947724461555481,10000.0,60541.402416706085,65202.35506153107,60541.402416706085,4648.055237054825,5.918776750564575,0.0 -134300,2.6748693,2.585936,,,,,,,,,,,,,, -134400,2.632414,4.6074734,,,,,,,,,,,,,, -134500,3.1029465,2.939845,,,,,,,,,,,,,, -134600,2.7328622,2.7367926,,,,,,,,,,,,,, -134700,2.4751627,3.5358765,,,,,,,,,,,,,, -134800,2.9278634,2.7468762,,,,,,,,,,,,,, -134900,3.0530598,2.659624,,,,,,,,,,,,,, -135000,2.3979309,3.7216883,,,,,,,,,,,,,, -135100,2.8806784,2.7036734,,,,,,,,,,,,,, -135186,,,0.7699609398841858,1.0183870792388916,0.7020399570465088,1.3113582134246826,50000.0,0.5781000256538391,1.9366137981414795,10000.0,60961.52054858208,65656.98781871796,60961.52054858208,4682.473633766174,5.967818260192871,0.0 -135200,2.78868,4.46656,,,,,,,,,,,,,, -135300,2.7288635,2.61397,,,,,,,,,,,,,, -135400,2.4903526,3.7046733,,,,,,,,,,,,,, -135500,2.8657458,4.5386333,,,,,,,,,,,,,, -135600,3.2210956,4.4750385,,,,,,,,,,,,,, -135700,3.1357043,2.6680307,,,,,,,,,,,,,, -135800,3.106826,2.658574,,,,,,,,,,,,,, -135900,2.6608589,2.564981,,,,,,,,,,,,,, -136000,3.0897236,2.8251262,,,,,,,,,,,,,, -136100,2.9095912,2.7554524,,,,,,,,,,,,,, -136116,,,0.768359363079071,1.0378899574279783,0.705839991569519,1.3026723861694336,50000.0,0.5782000422477722,1.9222419261932373,10000.0,61381.7981479168,66110.9119849205,61381.7981479168,4716.022467851639,6.017577886581421,0.0 -136200,2.8594084,2.6443758,,,,,,,,,,,,,, -136300,2.6366673,3.0503738,,,,,,,,,,,,,, -136400,2.6483436,3.1490157,,,,,,,,,,,,,, -136500,2.9103026,2.5978,,,,,,,,,,,,,, -136600,2.7994921,2.5860968,,,,,,,,,,,,,, -136700,2.7261243,2.6214685,,,,,,,,,,,,,, -136800,2.6491604,3.0721004,,,,,,,,,,,,,, -136900,2.9348984,2.6333096,,,,,,,,,,,,,, -137000,3.00826,2.6104116,,,,,,,,,,,,,, -137045,,,0.7707226276397705,0.9996986985206604,0.7090199589729309,1.2752629518508911,50000.0,0.5843000411987305,1.8952151536941528,10000.0,61801.93469142914,66561.47880458832,61801.93469142914,4746.358287096024,6.064196825027466,0.0 -137100,2.873725,2.9187145,,,,,,,,,,,,,, -137200,2.9349782,2.765163,,,,,,,,,,,,,, -137300,3.1736,4.648104,,,,,,,,,,,,,, -137400,2.615784,3.8960614,,,,,,,,,,,,,, -137500,3.1114087,2.572638,,,,,,,,,,,,,, -137600,3.0787156,2.8414352,,,,,,,,,,,,,, -137700,2.7229207,3.7081037,,,,,,,,,,,,,, -137800,2.735444,2.9642363,,,,,,,,,,,,,, -137900,2.798462,2.9344838,,,,,,,,,,,,,, -137975,,,0.7759569883346558,0.9817464351654052,0.708899974822998,1.267372965812683,50000.0,0.5821000337600708,1.892650485038757,10000.0,62222.269728422165,67012.80084323883,62222.269728422165,4777.243156194687,6.118428707122803,0.0 -138000,2.7353017,2.7837486,,,,,,,,,,,,,, -138100,2.825523,3.473713,,,,,,,,,,,,,, -138200,2.861181,2.6028779,,,,,,,,,,,,,, -138300,3.1112366,2.632752,,,,,,,,,,,,,, -138400,2.8321774,3.06382,,,,,,,,,,,,,, -138500,2.7533958,3.44058,,,,,,,,,,,,,, -138600,3.075178,2.5569673,,,,,,,,,,,,,, -138700,3.246479,4.4707375,,,,,,,,,,,,,, -138800,2.9656303,2.6536722,,,,,,,,,,,,,, -138900,3.2881808,4.7744765,,,,,,,,,,,,,, -138906,,,0.7888085842132568,0.9418398141860962,0.7098599672317505,1.2711541652679443,50000.0,0.5886000394821167,1.8847719430923464,10000.0,62642.559386491776,67466.630849123,62642.559386491776,4810.671858549118,6.181352853775024,0.0 -139000,3.2091494,4.7145147,,,,,,,,,,,,,, -139100,2.987662,2.5781274,,,,,,,,,,,,,, -139200,2.9386332,2.4991577,,,,,,,,,,,,,, -139300,3.1340387,2.6467733,,,,,,,,,,,,,, -139400,3.1873355,2.6710289,,,,,,,,,,,,,, -139500,2.8003979,2.51243,,,,,,,,,,,,,, -139600,2.9532554,2.531321,,,,,,,,,,,,,, -139700,3.0878036,2.6414146,,,,,,,,,,,,,, -139800,3.1091623,4.4225416,,,,,,,,,,,,,, -139840,,,0.770312488079071,0.9991475343704224,0.7108599543571472,1.2587929964065552,50000.0,0.5869000554084778,1.8614424467086792,10000.0,63062.58865451813,67920.26310777664,63062.58865451813,4844.179829597473,6.227135896682739,0.0 -139900,2.8723474,3.7833843,,,,,,,,,,,,,, -140000,3.0586224,2.6385667,,,,,,,,,,,,,, -140100,2.8668604,3.819766,,,,,,,,,,,,,, -140200,3.1091018,2.7131345,,,,,,,,,,,,,, -140300,2.788995,2.708268,,,,,,,,,,,,,, -140400,3.0668247,4.685711,,,,,,,,,,,,,, -140500,3.1740506,2.7876525,,,,,,,,,,,,,, -140600,2.9363103,3.9065983,,,,,,,,,,,,,, -140700,3.2742743,4.6914725,,,,,,,,,,,,,, -140771,,,0.7780663967132568,0.9655563831329346,0.7139399647712708,1.2472407817840576,50000.0,0.5897000432014465,1.8567043542861936,10000.0,63482.7752828598,68372.2566754818,63482.7752828598,4875.891560316086,6.273644685745239,0.0 -140800,2.979481,2.7770653,,,,,,,,,,,,,, -140900,3.056304,2.63625,,,,,,,,,,,,,, -141000,3.3308961,2.6818688,,,,,,,,,,,,,, -141100,2.9195168,2.660997,,,,,,,,,,,,,, -141200,2.6832397,3.0419474,,,,,,,,,,,,,, -141300,2.9603717,2.582191,,,,,,,,,,,,,, -141400,3.4942238,4.650073,,,,,,,,,,,,,, -141500,3.1969826,2.6319015,,,,,,,,,,,,,, -141600,2.9089189,2.4954705,,,,,,,,,,,,,, -141700,3.221661,2.7096777,,,,,,,,,,,,,, -141701,,,0.7883398532867432,0.9381603598594666,0.7156800031661987,1.24841046333313,50000.0,0.588200032711029,1.8638124465942385,10000.0,63903.23413252831,68820.95351719856,63903.23413252831,4904.032790899277,6.321924686431885,0.0 -141800,3.0179994,2.5975444,,,,,,,,,,,,,, -141900,3.1691158,2.595195,,,,,,,,,,,,,, -142000,3.1739686,2.6287546,,,,,,,,,,,,,, -142100,2.7839618,3.4052029,,,,,,,,,,,,,, -142200,3.5657735,4.6724753,,,,,,,,,,,,,, -142300,2.8616362,2.854314,,,,,,,,,,,,,, -142400,2.9848187,2.583108,,,,,,,,,,,,,, -142500,3.3421881,4.1841717,,,,,,,,,,,,,, -142600,3.7594244,4.46834,,,,,,,,,,,,,, -142629,,,0.7808007597923279,0.9934661388397216,0.7153599858283997,1.273717164993286,50000.0,0.5924000144004822,1.8806276321411133,10000.0,64323.499345541,69273.26719760895,64323.499345541,4935.975115537643,6.379689931869507,0.0 -142700,2.9086144,2.959979,,,,,,,,,,,,,, -142800,3.594709,4.6252375,,,,,,,,,,,,,, -142900,2.7613623,3.152205,,,,,,,,,,,,,, -143000,3.049967,2.925615,,,,,,,,,,,,,, -143100,3.1021042,4.465156,,,,,,,,,,,,,, -143200,3.0064251,3.4625,,,,,,,,,,,,,, -143300,2.8861446,4.0624323,,,,,,,,,,,,,, -143400,2.9207416,2.6586072,,,,,,,,,,,,,, -143500,3.3620043,2.5417504,,,,,,,,,,,,,, -143560,,,0.7847851514816284,0.949887216091156,0.7196199893951416,1.221478819847107,50000.0,0.5910000205039978,1.84019148349762,10000.0,64743.52144980431,69727.04166722298,64743.52144980431,4969.630095720291,6.429242849349976,0.0 -143600,3.1162157,2.4875286,,,,,,,,,,,,,, -143700,3.205206,2.565304,,,,,,,,,,,,,, -143800,3.5220873,4.601513,,,,,,,,,,,,,, -143900,3.379428,3.5752861,,,,,,,,,,,,,, -144000,3.1643128,4.647688,,,,,,,,,,,,,, -144100,3.3272562,2.5382104,,,,,,,,,,,,,, -144200,3.6076484,3.3635292,,,,,,,,,,,,,, -144300,3.739889,4.6176405,,,,,,,,,,,,,, -144400,2.9927962,3.079854,,,,,,,,,,,,,, -144490,,,0.7881640195846558,0.9314032793045044,0.718239963054657,1.2326812744140625,50000.0,0.5952000021934509,1.845787763595581,10000.0,65163.633311748505,70180.94100570679,65163.633311748505,5003.322789907455,6.475466012954712,0.0 -144500,3.1159055,3.8622134,,,,,,,,,,,,,, -144600,2.966437,3.0410676,,,,,,,,,,,,,, -144700,3.1377242,3.3757186,,,,,,,,,,,,,, -144800,3.1902409,2.8563778,,,,,,,,,,,,,, -144900,3.351428,2.6686637,,,,,,,,,,,,,, -145000,3.020085,3.1171746,,,,,,,,,,,,,, -145100,3.345762,2.6840053,,,,,,,,,,,,,, -145200,3.322857,2.5248644,,,,,,,,,,,,,, -145300,3.8032398,4.5703764,,,,,,,,,,,,,, -145400,3.425786,2.595605,,,,,,,,,,,,,, -145420,,,0.7870116829872131,0.9516395926475524,0.7205599546432495,1.229424238204956,50000.0,0.5932000279426575,1.8493428230285645,10000.0,65583.94645094872,70635.012373209,65583.94645094872,5036.984225511551,6.524138689041138,0.0 -145500,3.111999,2.8520064,,,,,,,,,,,,,, -145600,3.1915662,3.3755407,,,,,,,,,,,,,, -145700,3.225176,2.4609141,,,,,,,,,,,,,, -145800,3.138875,2.6289325,,,,,,,,,,,,,, -145900,3.3205729,4.6159325,,,,,,,,,,,,,, -146000,3.2259426,2.5112095,,,,,,,,,,,,,, -146100,3.5441394,2.503284,,,,,,,,,,,,,, -146200,3.1942062,2.5280874,,,,,,,,,,,,,, -146300,3.3964753,2.4433646,,,,,,,,,,,,,, -146350,,,0.7876952886581421,0.9195204973220824,0.726419985294342,1.2012972831726074,50000.0,0.6014000177383423,1.814241647720337,10000.0,66003.91966462135,71088.62664437294,66003.91966462135,5070.531363964081,6.570559024810791,0.0 -146400,3.2076383,3.393395,,,,,,,,,,,,,, -146500,4.143378,4.60875,,,,,,,,,,,,,, -146600,3.331534,2.5419514,,,,,,,,,,,,,, -146700,3.7744017,2.5876188,,,,,,,,,,,,,, -146800,3.855474,4.570594,,,,,,,,,,,,,, -146900,3.816336,4.1153765,,,,,,,,,,,,,, -147000,3.1518595,3.2362335,,,,,,,,,,,,,, -147100,3.2028096,3.2256148,,,,,,,,,,,,,, -147200,3.630041,2.482442,,,,,,,,,,,,,, -147280,,,0.7960156202316284,0.888444185256958,0.7275800108909607,1.1892650127410889,50000.0,0.6032000184059143,1.803131937980652,10000.0,66423.87331795692,71541.92117094994,66423.87331795692,5103.775855779648,6.61857533454895,0.0 -147300,3.3791611,4.130595,,,,,,,,,,,,,, -147400,3.297713,2.803748,,,,,,,,,,,,,, -147500,5.695345,4.1253076,,,,,,,,,,,,,, -147600,3.765473,4.3552103,,,,,,,,,,,,,, -147700,2.9187217,3.8961713,,,,,,,,,,,,,, -147800,4.0802093,4.513192,,,,,,,,,,,,,, -147900,3.5250888,2.5285118,,,,,,,,,,,,,, -148000,3.3439584,2.486485,,,,,,,,,,,,,, -148100,3.455817,4.3289995,,,,,,,,,,,,,, -148200,3.8736918,2.615316,,,,,,,,,,,,,, -148210,,,0.7996679544448853,0.9016132950782776,0.7257199883460999,1.2199366092681885,50000.0,0.5993000268936157,1.8347848653793333,10000.0,66843.96156454086,71995.24836182594,66843.96156454086,5136.919964790344,6.665120840072632,0.0 -148300,3.7232716,2.4692097,,,,,,,,,,,,,, -148400,3.3350744,2.556497,,,,,,,,,,,,,, -148500,3.515229,2.4252644,,,,,,,,,,,,,, -148600,3.6168187,4.029157,,,,,,,,,,,,,, -148700,3.4838424,3.8233843,,,,,,,,,,,,,, -148800,3.6525831,3.6666708,,,,,,,,,,,,,, -148900,3.888196,3.4287176,,,,,,,,,,,,,, -149000,3.9795988,4.247979,,,,,,,,,,,,,, -149100,3.5086231,2.5193524,,,,,,,,,,,,,, -149141,,,0.7919726371765137,0.9040732979774476,0.7277799844741821,1.1872663497924805,50000.0,0.6061000227928162,1.7923401594161987,10000.0,67264.05994081497,72447.8445456028,67264.05994081497,5169.321353435516,6.71377420425415,0.0 -149200,3.311954,2.8642707,,,,,,,,,,,,,, -149300,3.7632084,2.5096917,,,,,,,,,,,,,, -149400,3.812753,2.5768144,,,,,,,,,,,,,, -149500,3.7213519,2.8627768,,,,,,,,,,,,,, -149600,3.1662724,2.3995214,,,,,,,,,,,,,, -149700,4.0272665,2.583787,,,,,,,,,,,,,, -149800,3.500879,2.470336,,,,,,,,,,,,,, -149900,3.4577904,2.5406451,,,,,,,,,,,,,, -150000,4.081754,4.278159,,,,,,,,,,,,,, -150071,,,0.7992578148841858,0.882004976272583,0.7291399836540222,1.1780247688293457,50000.0,0.6043000221252441,1.791891098022461,10000.0,67684.25943183899,72900.0660943985,67684.25943183899,5201.246497869492,6.762799263000488,0.0 -150100,3.8143048,2.4949327,,,,,,,,,,,,,, -150200,3.698367,4.2191825,,,,,,,,,,,,,, -150300,4.187956,4.417368,,,,,,,,,,,,,, -150400,3.7201228,3.4256399,,,,,,,,,,,,,, -150500,4.040747,2.4883087,,,,,,,,,,,,,, -150600,3.420985,3.0469944,,,,,,,,,,,,,, -150700,3.6052904,2.722366,,,,,,,,,,,,,, -150800,3.7061672,2.46071,,,,,,,,,,,,,, -150900,3.7493594,2.5993266,,,,,,,,,,,,,, -151000,,,0.8058202862739563,0.8614989519119263,0.7326399683952332,1.1725473403930664,50000.0,0.609000027179718,1.7773709297180176,10000.0,68104.4731631279,73351.86920380592,68104.4731631279,5232.74213886261,6.808778524398804,0.0 -151000,3.4731605,3.0939622,,,,,,,,,,,,,, -151100,3.398694,2.9821038,,,,,,,,,,,,,, -151200,3.5878265,2.41999,,,,,,,,,,,,,, -151300,3.7680714,2.458597,,,,,,,,,,,,,, -151400,3.4578946,2.616446,,,,,,,,,,,,,, -151500,3.9380767,4.2125273,,,,,,,,,,,,,, -151600,3.836182,2.5003614,,,,,,,,,,,,,, -151700,3.9855552,2.6139908,,,,,,,,,,,,,, -151800,3.8411882,3.7444656,,,,,,,,,,,,,, -151900,4.0308433,4.2450852,,,,,,,,,,,,,, -151932,,,0.8000390529632568,0.8734333515167236,0.731220006942749,1.1657071113586426,50000.0,0.6045000553131104,1.7815557718276978,10000.0,68524.8423511982,73805.52795624733,68524.8423511982,5265.935954332352,6.8564043045043945,0.0 -152000,3.776339,3.2881556,,,,,,,,,,,,,, -152100,3.8408518,2.4456353,,,,,,,,,,,,,, -152200,3.5944262,3.387082,,,,,,,,,,,,,, -152300,3.7331502,4.188884,,,,,,,,,,,,,, -152400,4.717687,4.285228,,,,,,,,,,,,,, -152500,3.6594737,2.4914684,,,,,,,,,,,,,, -152600,3.8453918,2.5708082,,,,,,,,,,,,,, -152700,3.6214666,2.653666,,,,,,,,,,,,,, -152800,3.4600108,2.7868106,,,,,,,,,,,,,, -152865,,,0.804492175579071,0.8603943586349487,0.7348399758338928,1.1572157144546509,50000.0,0.6141000390052795,1.7590707540512085,10000.0,68945.09971499443,74259.00728917122,68945.09971499443,5299.06134557724,6.904045104980469,0.0 -152900,3.83056,2.508308,,,,,,,,,,,,,, -153000,3.5637064,2.4486527,,,,,,,,,,,,,, -153100,3.6825106,3.2969012,,,,,,,,,,,,,, -153200,4.088924,3.979264,,,,,,,,,,,,,, -153300,3.5894082,3.1353984,,,,,,,,,,,,,, -153400,3.5649178,3.376824,,,,,,,,,,,,,, -153500,3.6978374,3.301685,,,,,,,,,,,,,, -153600,3.9864836,2.512234,,,,,,,,,,,,,, -153700,3.9937122,2.3948183,,,,,,,,,,,,,, -153797,,,0.8055663704872131,0.8574345707893372,0.7340199947357178,1.1686347723007202,50000.0,0.6112000346183777,1.773018479347229,10000.0,69365.369992733,74713.34918832779,69365.369992733,5333.035125255585,6.953671216964722,0.0 -153800,3.8593934,2.3869195,,,,,,,,,,,,,, -153900,3.7575858,2.6728258,,,,,,,,,,,,,, -154000,3.8641148,2.43163,,,,,,,,,,,,,, -154100,3.7936678,2.4143426,,,,,,,,,,,,,, -154200,4.041633,2.5575573,,,,,,,,,,,,,, -154300,3.7868528,2.5430741,,,,,,,,,,,,,, -154400,3.472295,2.8331883,,,,,,,,,,,,,, -154500,4.347123,4.215972,,,,,,,,,,,,,, -154600,3.7534282,2.4874887,,,,,,,,,,,,,, -154700,3.4959116,2.3610244,,,,,,,,,,,,,, -154728,,,0.8135741949081421,0.8229332566261292,0.737779974937439,1.1446532011032104,50000.0,0.6133000254631042,1.7443084716796875,10000.0,69785.46291160583,75167.66317725182,69785.46291160583,5367.154443502426,7.006609916687012,0.0 -154800,4.430503,2.8542938,,,,,,,,,,,,,, -154900,3.9765198,2.7640243,,,,,,,,,,,,,, -155000,4.5471134,4.3677893,,,,,,,,,,,,,, -155100,4.34429,2.375637,,,,,,,,,,,,,, -155200,3.7950563,2.6944213,,,,,,,,,,,,,, -155300,4.01537,2.4325204,,,,,,,,,,,,,, -155400,3.6728363,3.2724102,,,,,,,,,,,,,, -155500,4.0853195,2.416975,,,,,,,,,,,,,, -155600,3.7700694,3.3719764,,,,,,,,,,,,,, -155658,,,0.8089257478713989,0.8678907752037048,0.7383999824523926,1.1707357168197632,50000.0,0.614300012588501,1.7738709449768066,10000.0,70205.56688523293,75621.48214673996,70205.56688523293,5400.772227048874,7.056041240692139,0.0 -155700,4.013618,2.42336,,,,,,,,,,,,,, -155800,4.1506557,3.7762928,,,,,,,,,,,,,, -155900,3.6163318,3.1997623,,,,,,,,,,,,,, -156000,4.0192475,2.464807,,,,,,,,,,,,,, -156100,4.205903,3.8090234,,,,,,,,,,,,,, -156200,3.7017293,3.3814452,,,,,,,,,,,,,, -156300,4.2193093,3.5342076,,,,,,,,,,,,,, -156400,4.0750246,3.462687,,,,,,,,,,,,,, -156500,3.9756618,2.5229628,,,,,,,,,,,,,, -156587,,,0.8149804472923279,0.8367795944213867,0.7415599822998047,1.1527200937271118,50000.0,0.6217000484466553,1.7479978799819946,10000.0,70625.50346064568,76075.67284274101,70625.50346064568,5434.675608158112,7.35836124420166,0.0 -156600,4.1764693,2.4627657,,,,,,,,,,,,,, -156700,4.192957,2.5284562,,,,,,,,,,,,,, -156800,4.2077117,2.4711335,,,,,,,,,,,,,, -156900,4.5124507,4.116451,,,,,,,,,,,,,, -157000,3.7983198,3.1360798,,,,,,,,,,,,,, -157100,3.7613523,2.405957,,,,,,,,,,,,,, -157200,4.17808,2.5447187,,,,,,,,,,,,,, -157300,4.038158,2.4397976,,,,,,,,,,,,,, -157400,3.8584452,2.7276306,,,,,,,,,,,,,, -157500,4.467758,2.3977103,,,,,,,,,,,,,, -157511,,,0.8213085532188416,0.7878121137619019,0.745199978351593,1.118953824043274,50000.0,0.6192000508308411,1.723163604736328,10000.0,71045.69268107414,76529.8608827591,71045.69268107414,5468.575320243835,7.409008026123047,0.0 -157600,4.0950646,2.726379,,,,,,,,,,,,,, -157700,4.2141104,2.3340728,,,,,,,,,,,,,, -157800,4.060463,2.426597,,,,,,,,,,,,,, -157900,3.9575322,2.5614972,,,,,,,,,,,,,, -158000,3.9915333,2.5499182,,,,,,,,,,,,,, -158100,4.253127,2.450983,,,,,,,,,,,,,, -158200,3.8239217,2.5733495,,,,,,,,,,,,,, -158300,4.9933276,2.613123,,,,,,,,,,,,,, -158400,4.069451,2.6885598,,,,,,,,,,,,,, -158441,,,0.8138671517372131,0.8201526403427124,0.7439999580383301,1.1131176948547363,50000.0,0.6200000047683716,1.7198753356933594,10000.0,71465.86521029472,76984.15627121925,71465.86521029472,5502.602823019028,7.457210540771484,0.0 -158500,4.3476996,2.3684402,,,,,,,,,,,,,, -158600,4.29762,2.4531329,,,,,,,,,,,,,, -158700,4.001362,2.8243432,,,,,,,,,,,,,, -158800,4.03714,2.9847734,,,,,,,,,,,,,, -158900,3.9425657,2.3395324,,,,,,,,,,,,,, -159000,3.8648136,3.4517903,,,,,,,,,,,,,, -159100,3.9788182,3.7297134,,,,,,,,,,,,,, -159200,4.57658,3.7024639,,,,,,,,,,,,,, -159300,4.0907497,2.4084744,,,,,,,,,,,,,, -159370,,,0.82093745470047,0.8007131218910217,0.7448599934577942,1.126524806022644,50000.0,0.6217000484466553,1.7308553457260132,10000.0,71885.79600262642,77437.7894179821,71885.79600262642,5536.2020580768585,7.512519836425781,0.0 -159400,4.9155307,3.863739,,,,,,,,,,,,,, -159500,4.040636,2.8396294,,,,,,,,,,,,,, -159600,5.0711713,2.3447115,,,,,,,,,,,,,, -159700,4.37224,2.571474,,,,,,,,,,,,,, -159800,4.737862,2.4123864,,,,,,,,,,,,,, -159900,3.9721596,2.3865566,,,,,,,,,,,,,, -160000,4.504187,2.5240152,,,,,,,,,,,,,, -160100,4.016418,2.3450816,,,,,,,,,,,,,, -160200,4.1694613,2.6287694,,,,,,,,,,,,,, -160300,4.5039096,2.4616873,,,,,,,,,,,,,, -160301,,,0.8273046612739563,0.7696381211280823,0.7482199668884277,1.115172028541565,50000.0,0.6235000491142273,1.7202221155166626,10000.0,72306.14024019241,77892.31031370163,72306.14024019241,5570.279438018799,7.564241647720337,0.0 -160400,4.5847445,2.431643,,,,,,,,,,,,,, -160500,4.754954,3.7556658,,,,,,,,,,,,,, -160600,4.9015646,3.7382567,,,,,,,,,,,,,, -160700,4.4347343,3.3823206,,,,,,,,,,,,,, -160800,4.811438,3.916216,,,,,,,,,,,,,, -160900,4.5943856,4.0474463,,,,,,,,,,,,,, -161000,4.125694,2.6156316,,,,,,,,,,,,,, -161100,4.392488,2.3558378,,,,,,,,,,,,,, -161200,4.208253,3.4110188,,,,,,,,,,,,,, -161232,,,0.8215234279632568,0.79811030626297,0.7492199540138245,1.1067612171173096,50000.0,0.6300000548362732,1.7063981294631958,10000.0,72726.45551586151,78342.89540290833,72726.45551586151,5600.44885635376,7.616266965866089,0.0 -161300,4.4420033,2.4036348,,,,,,,,,,,,,, -161400,4.7654424,2.3652458,,,,,,,,,,,,,, -161500,4.388219,2.74258,,,,,,,,,,,,,, -161600,4.4735813,2.3215685,,,,,,,,,,,,,, -161700,4.1979465,2.8148482,,,,,,,,,,,,,, -161800,5.3699713,4.3319964,,,,,,,,,,,,,, -161900,4.7221346,2.5305765,,,,,,,,,,,,,, -162000,4.7364345,2.982617,,,,,,,,,,,,,, -162100,4.796807,3.8067973,,,,,,,,,,,,,, -162159,,,0.8233007788658142,0.7672076225280762,0.7484999895095825,1.0859493017196655,50000.0,0.6338000297546387,1.673032522201538,10000.0,73146.63851761818,78793.20580601692,73146.63851761818,5630.4672927856445,7.676425695419311,0.0 -162200,4.448053,2.436304,,,,,,,,,,,,,, -162300,4.830475,2.4258885,,,,,,,,,,,,,, -162400,4.169002,2.2544656,,,,,,,,,,,,,, -162500,4.3698955,2.2784815,,,,,,,,,,,,,, -162600,4.602059,3.2284992,,,,,,,,,,,,,, -162700,4.712513,2.5114946,,,,,,,,,,,,,, -162800,4.5596795,2.3749187,,,,,,,,,,,,,, -162900,5.4171977,4.3620224,,,,,,,,,,,,,, -163000,4.553608,2.8455043,,,,,,,,,,,,,, -163088,,,0.8291601538658142,0.7714617848396301,0.7514199614524841,1.10415518283844,50000.0,0.6310000419616699,1.7050611972808838,10000.0,73566.73581504822,79246.45765209198,73566.73581504822,5663.513606309891,7.735929727554321,0.0 -163100,4.390659,3.3378122,,,,,,,,,,,,,, -163200,4.612509,2.4524522,,,,,,,,,,,,,, -163300,4.4970245,2.3415444,,,,,,,,,,,,,, -163400,4.4900103,2.3670447,,,,,,,,,,,,,, -163500,4.5347433,2.3333797,,,,,,,,,,,,,, -163600,4.3909287,2.3097014,,,,,,,,,,,,,, -163700,5.0609255,4.190096,,,,,,,,,,,,,, -163800,4.6314406,3.1513975,,,,,,,,,,,,,, -163900,4.8161254,2.3447962,,,,,,,,,,,,,, -164000,4.6342707,2.2481294,,,,,,,,,,,,,, -164021,,,0.83509761095047,0.7409867644309998,0.7541199922561646,1.079445481300354,50000.0,0.6335000395774841,1.6779024600982666,10000.0,73986.66831469536,79700.57873511314,73986.66831469536,5697.595949888229,7.793646812438965,0.0 -164100,4.586197,2.417203,,,,,,,,,,,,,, -164200,4.5371804,2.3577127,,,,,,,,,,,,,, -164300,4.64329,2.359459,,,,,,,,,,,,,, -164400,5.28393,4.06935,,,,,,,,,,,,,, -164500,4.716491,2.3673635,,,,,,,,,,,,,, -164600,5.767026,4.1549444,,,,,,,,,,,,,, -164700,4.4974017,2.2941387,,,,,,,,,,,,,, -164800,5.2290797,3.7233634,,,,,,,,,,,,,, -164900,4.473127,2.3549168,,,,,,,,,,,,,, -164957,,,0.8322851657867432,0.7366339564323425,0.7554399967193604,1.0619897842407229,50000.0,0.6362000107765198,1.6556838750839231,10000.0,74407.01210308075,80155.67283654213,74407.01210308075,5732.24786067009,7.844249725341797,0.0 -165000,4.4380813,3.365471,,,,,,,,,,,,,, -165100,5.031345,2.3856888,,,,,,,,,,,,,, -165200,4.8003173,2.3260021,,,,,,,,,,,,,, -165300,4.7824554,3.6451383,,,,,,,,,,,,,, -165400,4.839137,3.0186682,,,,,,,,,,,,,, -165500,4.658033,2.5134342,,,,,,,,,,,,,, -165600,4.6654305,2.7268014,,,,,,,,,,,,,, -165700,5.8999953,4.258935,,,,,,,,,,,,,, -165800,4.7644415,3.1637805,,,,,,,,,,,,,, -165887,,,0.8315820097923279,0.7409655451774597,0.7555999755859375,1.0678353309631348,50000.0,0.6294000148773193,1.6705749034881592,10000.0,74827.03679513931,80606.65344071388,74827.03679513931,5763.104225158691,7.89516282081604,0.0 -165900,5.066099,2.3586164,,,,,,,,,,,,,, -166000,5.313916,3.9719331,,,,,,,,,,,,,, -166100,4.513039,2.508701,,,,,,,,,,,,,, -166200,5.17104,3.6411583,,,,,,,,,,,,,, -166300,5.054175,2.3451989,,,,,,,,,,,,,, -166400,4.676841,3.259601,,,,,,,,,,,,,, -166500,4.6474686,3.6536102,,,,,,,,,,,,,, -166600,5.882763,3.96272,,,,,,,,,,,,,, -166700,4.4128976,2.306427,,,,,,,,,,,,,, -166800,5.804311,4.357635,,,,,,,,,,,,,, -166816,,,0.8335937261581421,0.7405239939689636,0.7565000057220459,1.076341986656189,50000.0,0.6351000070571899,1.669582486152649,10000.0,75247.1132349968,81058.46815085411,75247.1132349968,5794.742982387543,7.947040319442749,0.0 -166900,4.978141,2.347518,,,,,,,,,,,,,, -167000,4.6620946,2.2234929,,,,,,,,,,,,,, -167100,5.376382,3.5038018,,,,,,,,,,,,,, -167200,4.6557274,2.2763348,,,,,,,,,,,,,, -167300,5.154142,2.3582976,,,,,,,,,,,,,, -167400,4.797354,2.2279348,,,,,,,,,,,,,, -167500,4.8728094,2.6292133,,,,,,,,,,,,,, -167600,4.754238,2.375493,,,,,,,,,,,,,, -167700,7.2383547,4.2300754,,,,,,,,,,,,,, -167747,,,0.8338671922683716,0.7559254765510559,0.75791996717453,1.0783684253692627,50000.0,0.6355000138282776,1.6743074655532837,10000.0,75667.09297275543,81511.52088880539,75667.09297275543,5827.715883970261,7.998236179351807,0.0 -167800,4.7638597,2.3220792,,,,,,,,,,,,,, -167900,4.984024,3.1131568,,,,,,,,,,,,,, -168000,5.043437,2.2508056,,,,,,,,,,,,,, -168100,4.820967,2.2872376,,,,,,,,,,,,,, -168200,4.8943005,2.2905014,,,,,,,,,,,,,, -168300,4.8291264,2.53182,,,,,,,,,,,,,, -168400,5.5999427,2.709173,,,,,,,,,,,,,, -168500,4.8037076,2.298138,,,,,,,,,,,,,, -168600,5.062232,2.5096462,,,,,,,,,,,,,, -168678,,,0.8360546827316284,0.7370734214782715,0.7594599723815918,1.0704350471496582,50000.0,0.636400043964386,1.665740728378296,10000.0,76087.34431004524,81965.7411146164,76087.34431004524,5861.583888530731,8.05066990852356,0.0 -168700,4.8307314,2.2732406,,,,,,,,,,,,,, -168800,4.651676,3.3711803,,,,,,,,,,,,,, -168900,5.777817,4.161727,,,,,,,,,,,,,, -169000,5.190045,3.5970407,,,,,,,,,,,,,, -169100,4.868297,2.2602527,,,,,,,,,,,,,, -169200,4.4252152,2.2074509,,,,,,,,,,,,,, -169300,4.71776,3.0676565,,,,,,,,,,,,,, -169400,5.119131,2.1912055,,,,,,,,,,,,,, -169500,5.3092227,3.5966,,,,,,,,,,,,,, -169600,5.0829077,2.6842651,,,,,,,,,,,,,, -169609,,,0.83984375,0.7221209406852722,0.760159969329834,1.0561344623565674,50000.0,0.6401000022888184,1.649182677268982,10000.0,76507.4817943573,82420.15692543983,76507.4817943573,5895.755210876465,8.10885739326477,0.0 -169700,4.7275596,2.488705,,,,,,,,,,,,,, -169800,4.9616065,2.3425927,,,,,,,,,,,,,, -169900,5.528734,2.2612529,,,,,,,,,,,,,, -170000,4.726496,2.5155344,,,,,,,,,,,,,, -170100,5.2213526,2.2698588,,,,,,,,,,,,,, -170200,5.0272646,2.248812,,,,,,,,,,,,,, -170300,4.9006443,2.2581732,,,,,,,,,,,,,, -170400,4.9877195,2.222002,,,,,,,,,,,,,, -170500,6.3174663,4.0583854,,,,,,,,,,,,,, -170540,,,0.8365820050239563,0.7234076261520386,0.7600599527359009,1.0516064167022705,50000.0,0.6373000144958496,1.6481930017471311,10000.0,76927.61034536362,82874.5614593029,76927.61034536362,5929.933554887772,8.157514572143555,0.0 -170600,5.5563073,2.2874708,,,,,,,,,,,,,, -170700,4.7758327,2.1741686,,,,,,,,,,,,,, -170800,4.885949,2.635352,,,,,,,,,,,,,, -170900,5.123112,3.4189095,,,,,,,,,,,,,, -171000,6.00969,4.125309,,,,,,,,,,,,,, -171100,4.9815354,3.212852,,,,,,,,,,,,,, -171200,4.7817454,2.2703576,,,,,,,,,,,,,, -171300,5.4717646,2.4464834,,,,,,,,,,,,,, -171400,5.243305,2.9556282,,,,,,,,,,,,,, -171472,,,0.8395116925239563,0.7189873456954956,0.7607799768447876,1.0504368543624878,50000.0,0.6414000391960144,1.6435645818710327,10000.0,77347.73178625107,83329.06152820587,77347.73178625107,5964.20698928833,8.21371054649353,0.0 -171500,4.869399,2.2797868,,,,,,,,,,,,,, -171600,5.531389,2.4178586,,,,,,,,,,,,,, -171700,4.80992,2.702098,,,,,,,,,,,,,, -171800,5.2580066,2.2529647,,,,,,,,,,,,,, -171861,,,,,,,,,,,77520.0893175602,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 50dd2f961..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -27.850645065307617,0.0,35.015655517578125,1,0,35.015655517578125,0.0010000000474974,6.907756805419922,10000,62.8664071559906,0.0008593749953433,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -62.341947078704834,0.0177381038665771,455.200243473053,875,0,455.200243473053,0.0130000002682209,6.417356014251709,10000,517.6062908172607,0.0174804683774709,6.352764129638672,0.016339998692274,6.365495681762695,50000 -98.02724885940552,0.0483002662658691,875.5398058891296,1799,0,875.5398058891296,0.0340000018477439,5.986781597137451,10000,973.7104048728944,0.0464843735098838,5.830441951751709,0.0422799997031688,5.864605903625488,50000 -133.39315724372864,0.0725252628326416,1295.7417635917664,2691,0,1295.7417635917664,0.051400002092123,5.651369571685791,10000,1429.3486967086792,0.0708398446440696,5.448312282562256,0.0682799965143203,5.485576152801514,50000 -167.62544560432434,0.0957322120666503,1715.829880952835,3620,0,1715.829880952835,0.0813000053167343,5.356257915496826,10000,1883.7408828735352,0.1144726499915123,5.076627731323242,0.1036200001835823,5.128720760345459,50000 -201.9062662124633,0.1222002506256103,2136.0431559085846,4546,0,2136.0431559085846,0.108300007879734,5.006032943725586,10000,2338.310334444046,0.1610351502895355,4.597641944885254,0.1422600001096725,4.700864315032959,50000 -235.9856154918671,0.146214485168457,2556.31773352623,5473,0,2556.31773352623,0.1407999992370605,4.678713321685791,10000,2792.7371475696564,0.2023632824420929,4.254808902740479,0.1880799978971481,4.334715366363525,50000 -268.4839344024658,0.1717164516448974,2976.2778012752533,6402,0,2976.2778012752533,0.1830000132322311,4.3917012214660645,10000,3245.2699744701385,0.25146484375,3.917474031448364,0.2351599931716919,4.009504795074463,50000 -301.3875472545624,0.1950387954711914,3396.418829917908,7330,0,3396.418829917908,0.2092000097036361,4.218708038330078,10000,3698.38631939888,0.2987304627895355,3.65021538734436,0.2711600065231323,3.7866060733795166,50000 -334.6520891189575,0.2184698581695556,3816.663746833801,8257,0,3816.663746833801,0.2370000183582306,4.004726886749268,10000,4151.96754193306,0.3236913979053497,3.4396770000457764,0.3033199906349182,3.5456032752990723,50000 -368.2774066925049,0.2466275691986084,4236.656970024109,9184,0,4236.656970024109,0.259300023317337,3.76635479927063,10000,4605.663713693619,0.3685546815395355,3.152804613113404,0.3393599987030029,3.286856412887573,50000 -402.5110983848572,0.2724997997283935,4656.829707622528,10112,0,4656.829707622528,0.2835000157356262,3.687035322189331,10000,5060.144756317139,0.4025000035762787,3.0011579990386963,0.3671199977397918,3.165762662887573,50000 -435.4448158740997,0.2978024482727051,5076.838560819626,11041,0,5076.838560819626,0.3059000074863434,3.496697902679444,10000,5513.161741495132,0.4302734136581421,2.806786060333252,0.3969399929046631,2.9504761695861816,50000 -467.0167169570923,0.3260564804077148,5497.176267147064,11969,0,5497.176267147064,0.3264000117778778,3.385175943374634,10000,5965.148057460785,0.4562304615974426,2.67525053024292,0.4208199977874756,2.830482482910156,50000 -500.62565183639526,0.3518826961517334,5917.206485748291,12898,0,5917.206485748291,0.3397000133991241,3.322092294692993,10000,6418.862198352814,0.4791015386581421,2.5726089477539062,0.4387999773025512,2.7511792182922363,50000 -534.2692155838013,0.3798291683197021,6337.3051698207855,13826,0,6337.3051698207855,0.356900006532669,3.201007843017578,10000,6872.681324005127,0.5181445479393005,2.3554434776306152,0.4603399932384491,2.6169192790985107,50000 -562.43803191185,0.4050462245941162,6757.488060712814,14756,0,6757.488060712814,0.3725000321865082,3.11665415763855,10000,7321.107630968094,0.5152734518051147,2.358519554138184,0.4766999781131744,2.536402463912964,50000 -595.4816019535065,0.4494235515594482,7177.518812179565,15681,0,7177.518812179565,0.385200023651123,3.0521280765533447,10000,7774.275252819061,0.5367968678474426,2.26556658744812,0.4979399740695953,2.443150281906128,50000 -626.7855026721954,0.4745340347290039,7597.600474834442,16611,0,7597.600474834442,0.3908000290393829,3.024139881134033,10000,8225.735122203827,0.5549218654632568,2.1838366985321045,0.501039981842041,2.4183454513549805,50000 -658.9955780506134,0.5048494338989258,8017.97752737999,17542,0,8017.97752737999,0.4057000279426574,2.9616432189941406,10000,8678.402475357056,0.5504687428474426,2.1938395500183105,0.5118199586868286,2.3658089637756348,50000 -693.2478134632111,0.5354297161102295,8437.904272794724,18464,0,8437.904272794724,0.4108000099658966,2.897355794906616,10000,9132.661324977877,0.5669335722923279,2.096486091613769,0.5268599987030029,2.2790541648864746,50000 -728.4040546417236,0.5617325305938721,8858.220192193985,19391,0,8858.220192193985,0.4244000315666199,2.8181262016296387,10000,9588.208375692368,0.5886914134025574,1.9833548069000244,0.5397199988365173,2.214489698410034,50000 -763.6083111763,0.5928773880004883,9278.46674156189,20321,0,9278.46674156189,0.4308000206947326,2.7886321544647217,10000,10043.739041805267,0.5872656106948853,1.982076644897461,0.5415999889373779,2.179627656936645,50000 -799.4674112796783,0.6267457008361816,9698.78861618042,21248,0,9698.78861618042,0.4305000305175781,2.7999267578125,10000,10500.002382278442,0.5874413847923279,1.997110605239868,0.5496799945831299,2.174673795700073,50000 -835.1677443981171,0.6544830799102783,10118.962321043016,22177,0,10118.962321043016,0.4414000213146209,2.7625842094421387,10000,10955.953563451769,0.6080663800239563,1.9223328828811648,0.5569199919700623,2.1479055881500244,50000 -869.9834413528442,0.6817433834075928,10539.390924930573,23105,0,10539.390924930573,0.4481000304222107,2.6906943321228027,10000,11411.274099826813,0.6371484398841858,1.7699031829833984,0.5706999897956848,2.06937313079834,50000 -904.6943933963776,0.7084939479827881,10959.36401629448,24032,0,10959.36401629448,0.4580000340938568,2.681931734085083,10000,11866.034049749374,0.6188671588897705,1.8651227951049805,0.5744799971580505,2.0639655590057373,50000 -940.5957970619202,0.7409365177154541,11379.409708976746,24957,0,11379.409708976746,0.4633000195026397,2.642620325088501,10000,12322.063168525696,0.6286327838897705,1.810321569442749,0.5790799856185913,2.0259885787963867,50000 -974.3425529003144,0.7686212062835693,11799.423764944077,25885,0,11799.423764944077,0.4640000164508819,2.601236343383789,10000,12775.900420188904,0.6464648246765137,1.718097686767578,0.5830000042915344,1.9904863834381104,50000 -1008.6455118656158,0.7995619773864746,12219.484512329102,26811,0,12219.484512329102,0.4707000255584717,2.5924384593963623,10000,13230.34390258789,0.6296288967132568,1.7748783826828003,0.5895999670028687,1.974582314491272,50000 -1043.1403737068176,0.8287265300750732,12639.566604852676,27739,0,12639.566604852676,0.4793000221252441,2.5261025428771973,10000,13684.9992685318,0.6431835889816284,1.694810390472412,0.6007599830627441,1.8964165449142456,50000 -1077.1589756011963,0.8670501708984375,13059.846086263657,28669,0,13059.846086263657,0.4789000153541565,2.543222665786743,10000,14139.38765001297,0.6536718606948853,1.666849136352539,0.5978400111198425,1.9145691394805908,50000 -1111.5325186252594,0.903029203414917,13479.923082113266,29597,0,13479.923082113266,0.4761000275611877,2.5695018768310547,10000,14593.923149585724,0.6449218392372131,1.720602035522461,0.5954799652099609,1.935060739517212,50000 -1144.7327094078064,0.9328234195709229,13900.250252962112,30525,0,13900.250252962112,0.4842000305652618,2.5386219024658203,10000,15047.529522657394,0.6502929329872131,1.707275390625,0.602620005607605,1.911301851272583,50000 -1179.5270998477936,0.9612855911254884,14320.317656040192,31451,0,14320.317656040192,0.4948000311851501,2.4632036685943604,10000,15502.468721866608,0.6668164134025574,1.6053072214126587,0.6132000088691711,1.8501198291778564,50000 -1214.3044934272766,0.9927427768707277,14740.255218982697,32379,0,14740.255218982697,0.4894000291824341,2.4860403537750244,10000,15957.264911651611,0.6805077791213989,1.5708024501800537,0.6123799681663513,1.8718345165252688,50000 -1247.9506666660309,1.0217430591583252,15160.616579771042,33309,0,15160.616579771042,0.4953000247478485,2.487512350082397,10000,16411.350699186325,0.6620507836341858,1.677104949951172,0.6145600080490112,1.8774863481521609,50000 -1281.597553730011,1.057400465011597,15580.754612445831,34238,0,15580.754612445831,0.496800035238266,2.435833215713501,10000,16865.220044851303,0.6710546612739563,1.578461527824402,0.6174600124359131,1.8176082372665403,50000 -1315.423936367035,1.0867114067077637,16000.741730213163,35166,0,16000.741730213163,0.4941000342369079,2.4827773571014404,10000,17319.111181020737,0.6815429329872131,1.5850794315338137,0.621239960193634,1.84546422958374,50000 -1350.0655298233032,1.1218464374542236,16421.03150343895,36093,0,16421.03150343895,0.5010000467300415,2.4259908199310303,10000,17774.126088142395,0.6749609112739563,1.5821787118911743,0.6269599795341492,1.7968742847442627,50000 -1385.2672073841095,1.1547789573669434,16841.30184864998,37022,0,16841.30184864998,0.5004000067710876,2.398932933807373,10000,18229.68051123619,0.6784374713897705,1.5355225801467896,0.6299600005149841,1.7512410879135132,50000 -1421.4622313976288,1.187713623046875,17261.51871085167,37949,0,17261.51871085167,0.5121000409126282,2.372727155685425,10000,18686.173897981644,0.6911327838897705,1.483366847038269,0.6319000124931335,1.7460869550704956,50000 -1455.6237666606903,1.2181427478790283,17681.805111408234,38877,0,17681.805111408234,0.5072000026702881,2.3739521503448486,10000,19140.70230269432,0.6896093487739563,1.500171780586243,0.6326599717140198,1.7544867992401123,50000 -1489.073492050171,1.2485857009887695,18101.77188515663,39807,0,18101.77188515663,0.5121999979019165,2.383885622024536,10000,19594.19838809967,0.6860156059265137,1.5207045078277588,0.6359599828720093,1.745492458343506,50000 -1520.9033830165863,1.277255296707153,18522.085191726685,40735,0,18522.085191726685,0.5095000267028809,2.366427183151245,10000,20046.41907286644,0.6874804496765137,1.4901330471038818,0.6361799836158752,1.7246910333633425,50000 -1553.8083319664,1.3109490871429443,18942.435485124588,41661,0,18942.435485124588,0.517300009727478,2.3592562675476074,10000,20499.75671958924,0.7146679759025574,1.4313410520553589,0.6459000110626221,1.7241450548171997,50000 -1588.933512687683,1.3452599048614502,19362.675671339035,42589,0,19362.675671339035,0.5134000182151794,2.363996982574463,10000,20955.205878019333,0.6867382526397705,1.5198066234588623,0.6368399858474731,1.7426916360855105,50000 -1623.3567078113556,1.3867592811584473,19782.83354473114,43520,0,19782.83354473114,0.5212000012397766,2.3035788536071777,10000,21409.87931132317,0.7005859017372131,1.4200528860092163,0.6437000036239624,1.678924322128296,50000 -1656.5718541145325,1.4205570220947266,20203.08823728561,44447,0,20203.08823728561,0.5196000337600708,2.387691020965576,10000,21863.43239045143,0.7023046612739563,1.501721739768982,0.6417799592018127,1.76237154006958,50000 -1691.4364099502563,1.452678680419922,20623.02325534821,45375,0,20623.02325534821,0.5195000171661377,2.3472135066986084,10000,22318.31251120568,0.6984961032867432,1.481454610824585,0.6467799544334412,1.7090678215026855,50000 -1727.9724340438845,1.4842946529388428,21043.11750602722,46301,0,21043.11750602722,0.5258000493049622,2.281585216522217,10000,22775.022524118423,0.7051757574081421,1.4099836349487305,0.6536799669265747,1.652795433998108,50000 -1760.8348679542542,1.5186994075775146,21463.19794869423,47228,0,21463.19794869423,0.5272000432014465,2.330073356628418,10000,23228.048787355423,0.7090820074081421,1.4518290758132937,0.6502000093460083,1.708951473236084,50000 -1794.4216213226318,1.5487060546875,21883.307546377186,48156,0,21883.307546377186,0.5291000008583069,2.2683568000793457,10000,23681.82369303704,0.7205273509025574,1.3598943948745728,0.6526199579238892,1.6453591585159302,50000 -1828.082728624344,1.5821101665496826,22303.582375764847,49084,0,22303.582375764847,0.5312000513076782,2.238567352294922,10000,24135.841794013977,0.7106054425239563,1.384338617324829,0.6567999720573425,1.621684432029724,50000 -1861.705460071564,1.618788242340088,22723.54018163681,50009,0,22723.54018163681,0.5333000421524048,2.27439832687378,10000,24589.507689237595,0.7145702838897705,1.409250259399414,0.6565799713134766,1.664278507232666,50000 -1896.688829660416,1.6524369716644287,23143.67327833176,50934,0,23143.67327833176,0.5382000207901001,2.2213504314422607,10000,25044.706683397293,0.727734386920929,1.318716526031494,0.6599999666213989,1.6194846630096436,50000 -1931.7127187252045,1.6885082721710205,23563.626941919327,51860,0,23563.626941919327,0.5361000299453735,2.21683406829834,10000,25499.76881289482,0.7147851586341858,1.3640855550765991,0.6587399840354919,1.61241352558136,50000 -1965.6869568824768,1.7237651348114014,23983.90749502182,52789,0,23983.90749502182,0.526900053024292,2.2857484817504883,10000,25954.107449531555,0.7134960889816284,1.4122158288955688,0.6552799940109253,1.6602015495300293,50000 -1998.2732956409448,1.7659268379211426,24404.197466611862,53718,0,24404.197466611862,0.5385000109672546,2.231038808822632,10000,26407.075337409973,0.7314453125,1.304121494293213,0.6644399762153625,1.601436972618103,50000 -2032.3752472400663,1.7985684871673584,24824.45699334145,54647,0,24824.45699334145,0.5401000380516052,2.210639476776123,10000,26861.518835544583,0.71533203125,1.3832980394363403,0.6642999649047852,1.607182502746582,50000 -2065.166063785553,1.8309228420257568,25244.383882761,55575,0,25244.383882761,0.5408000349998474,2.2293291091918945,10000,27314.31769967079,0.7239453196525574,1.3593758344650269,0.6665999889373779,1.6094951629638672,50000 -2098.33789563179,1.864485502243042,25665.15626358986,56501,0,25665.15626358986,0.5393000245094299,2.2219364643096924,10000,27768.345024824142,0.7246484160423279,1.3352673053741455,0.6661999821662903,1.6007659435272217,50000 -2132.7292597293854,1.8969478607177728,26085.455446720123,57431,0,26085.455446720123,0.5423000454902649,2.192864894866944,10000,28223.11777973175,0.7429101467132568,1.2352076768875122,0.6694999933242798,1.5599645376205444,50000 -2166.513976097107,1.9372212886810305,26505.45462369919,58358,0,26505.45462369919,0.5409000515937805,2.201468706130981,10000,28676.99065828324,0.7205859422683716,1.346676468849182,0.6641799807548523,1.5929832458496094,50000 -2200.711050987244,1.973011493682861,26925.466319322582,59285,0,26925.466319322582,0.5421000123023987,2.206836462020874,10000,29131.28421139717,0.7235937118530273,1.3309944868087769,0.6640599966049194,1.594617247581482,50000 -2234.163626194,2.0108447074890137,27345.63130736351,60211,0,27345.63130736351,0.547700047492981,2.189788341522217,10000,29584.98865461349,0.7406054735183716,1.2563652992248535,0.6726999878883362,1.5577574968338013,50000 -2266.2198588848114,2.044236660003662,27765.606746673584,61138,0,27765.606746673584,0.5457000136375427,2.157984495162964,10000,30037.10212635994,0.729296863079071,1.288324236869812,0.6774599552154541,1.5246641635894775,50000 -2302.1697578430176,2.4654579162597656,28185.229808330536,62061,0,28185.229808330536,0.5463000535964966,2.169198751449585,10000,30493.144993782043,0.7343164086341858,1.2893426418304443,0.6753199696540833,1.5548889636993408,50000 -2339.451201438904,2.5024900436401367,28605.168552160263,62990,0,28605.168552160263,0.5506000518798828,2.156816005706787,10000,30950.45087170601,0.7407812476158142,1.2475316524505615,0.6754399538040161,1.5307263135910034,50000 -2376.555834293365,2.5417940616607666,29025.421385526657,63916,0,29025.421385526657,0.5541000366210938,2.133321046829224,10000,31407.895943164825,0.7317968606948853,1.2713884115219116,0.6771399974822998,1.5070552825927734,50000 -2410.2451598644257,2.581367254257202,29445.614727020264,64842,0,29445.614727020264,0.5561000108718872,2.1238458156585693,10000,31861.867134332657,0.7364453077316284,1.2684588432312012,0.6774799823760986,1.5218348503112793,50000 -2443.933711528778,2.621212720870972,29865.653613567352,65767,0,29865.653613567352,0.5523000359535217,2.1705105304718018,10000,32315.68285059929,0.7422069907188416,1.2749184370040894,0.6781600117683411,1.548775553703308,50000 -2477.4188113212585,2.660482168197632,30285.574808120728,66693,0,30285.574808120728,0.5541000366210938,2.1089136600494385,10000,32769.177268743515,0.758984386920929,1.163934350013733,0.6821199655532837,1.4920015335083008,50000 -2514.413983345032,2.6944541931152344,30705.76295566559,67619,0,30705.76295566559,0.551300048828125,2.1411116123199463,10000,33226.44407486916,0.7391406297683716,1.2571861743927002,0.6822999715805054,1.5116750001907349,50000 -2548.3767414093018,2.7300000190734863,31125.82836127281,68545,0,31125.82836127281,0.558899998664856,2.133676052093506,10000,33680.55595970154,0.74609375,1.2384456396102903,0.683459997177124,1.5109264850616455,50000 -2581.5695893764496,2.764475584030152,31545.97432255745,69471,0,31545.97432255745,0.5611000061035156,2.103188991546631,10000,34133.97758722305,0.755175769329071,1.170236349105835,0.684939980506897,1.482764482498169,50000 -2616.4249787330627,2.803164482116699,31966.08284807205,70398,0,31966.08284807205,0.5548000335693359,2.153522253036499,10000,34589.028621673584,0.7429296970367432,1.2822554111480713,0.6846599578857422,1.5301166772842407,50000 -2649.695028066635,2.83899450302124,32385.999056339264,71325,0,32385.999056339264,0.5677000284194946,2.128825664520264,10000,35042.29918694496,0.7483007907867432,1.252873182296753,0.6881600022315979,1.5203845500946045,50000 -2683.389586210251,2.873255014419556,32805.99316358566,72251,0,32805.99316358566,0.5610000491142273,2.117690086364746,10000,35496.07172703743,0.7540820240974426,1.2140352725982666,0.6817799806594849,1.519381046295166,50000 -2715.83868432045,2.911292314529419,33226.05568480492,73176,0,33226.05568480492,0.5640000104904175,2.119715929031372,10000,35948.66935944557,0.7476171851158142,1.2500646114349363,0.6918999552726746,1.4931375980377195,50000 -2749.675267457962,2.9467294216156006,33646.12110567093,74100,0,33646.12110567093,0.5685000419616699,2.076725721359253,10000,36402.65527367592,0.7524218559265137,1.1826977729797363,0.6895999908447266,1.460847020149231,50000 -2784.946478128433,2.986032724380493,34066.15981054306,75027,0,34066.15981054306,0.5649999976158142,2.099517107009888,10000,36858.05403661728,0.7589452862739563,1.1854172945022583,0.6945599913597107,1.469211220741272,50000 -2818.763783454895,3.024766683578491,34486.180584430695,75953,0,34486.180584430695,0.5619000196456909,2.1018519401550293,10000,37311.97987627983,0.7703906297683716,1.1538687944412231,0.6911799907684326,1.488873839378357,50000 -2850.4957184791565,3.0609076023101807,34906.22839021683,76878,0,34906.22839021683,0.5715000033378601,2.0509159564971924,10000,37763.84492731094,0.7526757717132568,1.1900469064712524,0.6944400072097778,1.4452449083328247,50000 -2885.9357640743256,3.1037163734436035,35326.41598343849,77804,0,35326.41598343849,0.5711000561714172,2.069314956665039,10000,38219.564494133,0.7564452886581421,1.186036467552185,0.6922799944877625,1.4668896198272705,50000 -2919.8212456703186,3.144604682922364,35746.742753982544,78733,0,35746.742753982544,0.5674000382423401,2.0553789138793945,10000,38673.86698675156,0.7732617259025574,1.1194065809249878,0.6947199702262878,1.444237470626831,50000 -2951.7412304878235,3.184016704559326,36166.86061668396,79660,0,36166.86061668396,0.5734000205993652,2.0305778980255127,10000,39125.99325990677,0.7591992020606995,1.1473793983459473,0.6970199942588806,1.4147883653640747,50000 -2982.663145303726,3.2296504974365234,36586.801466465,80584,0,36586.801466465,0.5719000101089478,2.041655302047729,10000,39576.94997668266,0.7634375095367432,1.142561674118042,0.7002399563789368,1.4214868545532229,50000 -3017.626652240753,3.2731289863586426,37006.7823369503,81508,0,37006.7823369503,0.5731000304222107,2.0733675956726074,10000,40031.986365795135,0.7697460651397705,1.1436564922332764,0.6994400024414062,1.451212763786316,50000 -3053.159845352173,3.3112025260925293,37426.94573545456,82435,0,37426.94573545456,0.5703999996185303,2.048683881759644,10000,40487.77037620544,0.7611523270606995,1.1672053337097168,0.699400007724762,1.4308583736419678,50000 -3090.146003007889,3.34938383102417,37846.88027453423,83361,0,37846.88027453423,0.5679000020027161,2.0731194019317627,10000,40944.77842760086,0.7637695074081421,1.173699975013733,0.6990399956703186,1.4479047060012815,50000 -3126.185777664185,3.386541843414306,38267.15163445473,84287,0,38267.15163445473,0.5777000188827515,2.0119788646698,10000,41401.175797224045,0.7743945121765137,1.1085599660873413,0.70305997133255,1.4031741619110107,50000 -3159.238482236862,3.423867702484131,38687.10668325424,85213,0,38687.10668325424,0.5763000249862671,2.0371885299682617,10000,41854.26957678795,0.7837694883346558,1.086231708526611,0.7024999856948853,1.429186224937439,50000 -3194.0549857616425,3.4639732837677,39107.087277412415,86141,0,39107.087277412415,0.5753000378608704,2.027710199356079,10000,42309.15513443947,0.7665038704872131,1.1527868509292605,0.7004599571228027,1.423740267753601,50000 -3227.807467699051,3.500558376312256,39527.25010895729,87068,0,39527.25010895729,0.581000030040741,2.0015461444854736,10000,42763.15571928024,0.7691406011581421,1.0984034538269043,0.7033999562263489,1.393964767456055,50000 -3264.1564960479736,3.542058229446411,39947.20313882828,87991,0,39947.20313882828,0.5868000388145447,1.9906413555145264,10000,43219.547652721405,0.7862499952316284,1.0665332078933716,0.7097600102424622,1.3905929327011108,50000 -3297.296382665634,3.579483985900879,40367.46312975884,88918,0,40367.46312975884,0.5790000557899475,2.0272250175476074,10000,43673.03454852104,0.7696874737739563,1.142638087272644,0.7088800072669983,1.4129151105880735,50000 -3332.3801724910736,3.621098756790161,40787.6904566288,89848,0,40787.6904566288,0.5851000547409058,2.0033745765686035,10000,44128.43647813797,0.7753515243530273,1.1038709878921509,0.7093200087547302,1.3977713584899902,50000 -3365.9759092330933,3.658256053924561,41207.61959028244,90774,0,41207.61959028244,0.5878000259399414,1.9843614101409912,10000,44582.04789733887,0.781054675579071,1.074851155281067,0.7090799808502197,1.384374976158142,50000 -3399.644933462143,3.699693918228149,41628.0981926918,91701,0,41628.0981926918,0.5909000039100647,1.9665168523788448,10000,45036.28593277931,0.77406245470047,1.1004438400268557,0.7123000025749207,1.3684351444244385,50000 -3428.649636030197,3.740588903427124,42048.34095311165,92629,0,42048.34095311165,0.5853000283241272,1.97959566116333,10000,45485.623708724976,0.7764062285423279,1.1032389402389526,0.7084800004959106,1.3871572017669678,50000 -3467.593858480453,3.784477949142456,42468.40495491028,93553,0,42468.40495491028,0.5873000025749207,1.9714946746826167,10000,45944.72555589676,0.7830468416213989,1.04946768283844,0.7099199891090393,1.3628854751586914,50000 -3502.902225732804,3.828879117965698,42888.62809586525,94479,0,42888.62809586525,0.5933000445365906,1.9169158935546875,10000,46400.34952402115,0.8056640625,0.9454649686813354,0.718459963798523,1.3121576309204102,50000 -3536.771305322647,3.8754472732543945,43308.81820011139,95407,0,43308.81820011139,0.5925000309944153,1.9520272016525269,10000,46854.504509449005,0.7821093797683716,1.0739256143569946,0.7144799828529358,1.3560128211975098,50000 -3572.430454492569,3.914361476898194,43729.01921200752,96335,0,43729.01921200752,0.5966000556945801,1.926663517951965,10000,47310.45342755318,0.78955078125,1.03350567817688,0.7170000076293945,1.3376338481903076,50000 -3606.683182001114,3.956416368484497,44149.1900267601,97262,0,44149.1900267601,0.5884000062942505,1.9976340532302856,10000,47764.96751379967,0.7940233945846558,1.0544887781143188,0.7132399678230286,1.3924885988235474,50000 -3639.32297372818,3.995898723602295,44569.44393587112,98188,0,44569.44393587112,0.5919000506401062,1.9524030685424805,10000,48217.94920134544,0.783886730670929,1.0592858791351318,0.717519998550415,1.3481186628341677,50000 -3674.262171983719,4.035337209701538,44989.47699189186,99116,0,44989.47699189186,0.5948000550270081,1.938390731811524,10000,48673.00974225998,0.7895702719688416,1.0405603647232056,0.7199599742889404,1.3382598161697388,50000 -3707.622751235962,4.07489013671875,45409.69598340988,100045,0,45409.69598340988,0.5946000218391418,1.9406942129135127,10000,49126.67878437042,0.7934765219688416,1.03852379322052,0.719980001449585,1.356076717376709,50000 -3742.466834306717,4.11404824256897,45829.853684186935,100973,0,45829.853684186935,0.5918000340461731,1.9154317378997805,10000,49581.7681043148,0.7863867282867432,1.0272761583328247,0.7179799675941467,1.3234541416168213,50000 -3775.237138032913,4.155972242355347,46250.57767724991,101901,0,46250.57767724991,0.598300039768219,1.912156343460083,10000,50035.35246872902,0.7957617044448853,1.0064743757247925,0.7209399938583374,1.316093921661377,50000 -3810.6633427143097,4.199147939682007,46670.89849615097,102828,0,46670.89849615097,0.5942000150680542,1.957307934761048,10000,50491.19104528427,0.7928906083106995,1.0309524536132812,0.7185800075531006,1.3496639728546145,50000 -3843.5726778507233,4.241317510604858,47091.06334114075,103754,0,47091.06334114075,0.5997000336647034,1.9341633319854736,10000,50944.35725951195,0.8130663633346558,0.960721492767334,0.7226600050926208,1.3447332382202148,50000 -3878.780025720596,4.284324884414673,47511.09459543228,104681,0,47511.09459543228,0.6028000116348267,1.9325157403945925,10000,51399.68806958199,0.79408198595047,1.0491050481796265,0.7261199951171875,1.340632438659668,50000 -3912.86477136612,4.712871551513672,47930.99238157272,105604,0,47930.99238157272,0.6016000509262085,1.9280401468276973,10000,51854.14779257774,0.7981054782867432,1.0043699741363523,0.7231199741363525,1.320081353187561,50000 -3948.129752397537,4.755612373352051,48351.03900790215,106529,0,48351.03900790215,0.603600025177002,1.904357671737671,10000,52309.55121731758,0.8086132407188416,0.9574166536331176,0.7298600077629089,1.299571871757507,50000 -3984.69742679596,4.803227424621582,48771.34631705284,107457,0,48771.34631705284,0.5999000072479248,1.917056679725647,10000,52766.52301621437,0.796582043170929,1.0255680084228516,0.7306199669837952,1.309115648269653,50000 -4018.960418462753,4.847645044326782,49191.41876125336,108382,0,49191.41876125336,0.6070000529289246,1.8809870481491089,10000,53220.95117616653,0.8032421469688416,0.9735230803489684,0.7277199625968933,1.296148419380188,50000 -4052.371610879898,4.887386798858643,49611.47404384613,109308,0,49611.47404384613,0.6025000214576721,1.9155027866363523,10000,53674.50714588165,0.80712890625,0.9809885025024414,0.7274799942970276,1.309841513633728,50000 -4084.464169979096,4.937520742416382,50031.75916481018,110233,0,50031.75916481018,0.6146000027656555,1.8612117767333984,10000,54126.98281383514,0.8057421445846558,0.956943690776825,0.7321999669075012,1.2614250183105469,50000 -4119.084473133087,4.998009443283081,50451.79173922539,111158,0,50451.79173922539,0.6043000221252441,1.9332977533340447,10000,54581.74570274353,0.8017968535423279,1.0286946296691897,0.7317799925804138,1.3278456926345823,50000 -4153.366170406342,5.039664268493652,50871.74392461777,112081,0,50871.74392461777,0.6106000542640686,1.8744934797286987,10000,55036.06898331642,0.8128515481948853,0.9389355778694152,0.7346799969673157,1.2732287645339966,50000 -4188.366625308991,5.085594415664673,51291.82757949829,113004,0,51291.82757949829,0.6077000498771667,1.887868046760559,10000,55491.24841022492,0.8268554210662842,0.8962365388870239,0.7337200045585632,1.2884944677352903,50000 -4222.597744464874,5.1359288692474365,51711.96106266976,113930,0,51711.96106266976,0.6146000027656555,1.8564660549163816,10000,55945.711918354034,0.8086913824081421,0.9540855288505554,0.7350999712944031,1.2699496746063232,50000 -4256.532593727112,5.181121587753296,52131.87544989586,114857,0,52131.87544989586,0.6121000051498413,1.887677788734436,10000,56399.65555882454,0.8112109303474426,0.9619636535644532,0.7340999841690063,1.2924312353134155,50000 -4292.1596002578735,5.225587368011475,52551.95956778526,115782,0,52551.95956778526,0.6084000468254089,1.8864821195602417,10000,56855.459886312485,0.8207421898841858,0.9196900129318236,0.7369199991226196,1.2814921140670776,50000 -4323.595373153687,5.270857572555542,52972.2819852829,116708,0,52972.2819852829,0.6146000027656555,1.8671879768371584,10000,57307.31153726578,0.8111132383346558,0.9671093821525574,0.7394199967384338,1.2774097919464111,50000 -4355.863595724106,5.312883377075195,53392.3686144352,117635,0,53392.3686144352,0.613800048828125,1.872160077095032,10000,57759.757767915726,0.8186327815055847,0.9464924931526184,0.73881995677948,1.275384545326233,50000 -4391.868913650513,5.357113838195801,53812.669481277466,118557,0,53812.669481277466,0.6128000020980835,1.859431505203247,10000,58216.1560792923,0.8215234279632568,0.9040040969848632,0.7385199666023254,1.2626641988754272,50000 -4428.260704755783,5.399370908737183,54232.62202167511,119482,0,54232.62202167511,0.6167000532150269,1.868729591369629,10000,58672.590933561325,0.8153515458106995,0.9577239751815796,0.7410399913787842,1.269242763519287,50000 -4460.699743509293,5.445420742034912,54652.57331991196,120407,0,54652.57331991196,0.6178000569343567,1.8522409200668333,10000,59125.07707071304,0.8184374570846558,0.9347028136253356,0.7399199604988098,1.2617909908294678,50000 -4494.443091392517,5.487497329711914,55072.54042840004,121333,0,55072.54042840004,0.6199000477790833,1.8362445831298828,10000,59578.87798953056,0.8251171708106995,0.8875871300697327,0.743619978427887,1.240481734275818,50000 -4528.681435108185,5.531659364700317,55492.709612846375,122261,0,55492.709612846375,0.6212000250816345,1.827873468399048,10000,60033.37808465958,0.8381249904632568,0.8315138220787048,0.7446199655532837,1.2221808433532717,50000 -4560.943968057632,5.580912590026856,55912.80021524429,123187,0,55912.80021524429,0.6188000440597534,1.8316484689712524,10000,60485.82897758484,0.8201562166213989,0.9035536050796508,0.7415199875831604,1.2345235347747805,50000 -4594.97258067131,5.626370191574097,56332.88788199425,124111,0,56332.88788199425,0.6260000467300415,1.830779194831848,10000,60940.03962993622,0.8274609446525574,0.8888863325119019,0.7472400069236755,1.228013038635254,50000 -4629.3885724544525,5.673748731613159,56752.807072639465,125037,0,56752.807072639465,0.6228000521659851,1.8111176490783687,10000,61394.470771074295,0.8400976657867432,0.8476256132125854,0.7465400099754333,1.231023907661438,50000 -4663.57472038269,5.72042441368103,57172.92175483704,125962,0,57172.92175483704,0.6270000338554382,1.8236942291259768,10000,61848.86698770523,0.8267577886581421,0.9042437076568604,0.744219958782196,1.2379591464996338,50000 -4697.19917845726,5.7643516063690186,57593.29928016663,126889,0,57593.29928016663,0.6248000264167786,1.817178010940552,10000,62302.96178174019,0.829394519329071,0.8763020634651184,0.7468599677085876,1.225918889045715,50000 -4730.561057806015,5.811323881149292,58013.23331975937,127815,0,58013.23331975937,0.6276000142097473,1.7829385995864868,10000,62756.352815151215,0.8350195288658142,0.8415217399597168,0.7514199614524841,1.19737708568573,50000 -4765.972766160965,5.859126567840576,58433.33541345596,128742,0,58433.33541345596,0.629300057888031,1.8096715211868288,10000,63211.96445417404,0.8329101204872131,0.8737267255783081,0.7518799901008606,1.2153342962265017,50000 -4794.663794994354,5.904484272003174,58853.82092785835,129669,0,58853.82092785835,0.628600001335144,1.7966262102127075,10000,63661.234209775925,0.8311132788658142,0.8590825200080872,0.7507799863815308,1.207733988761902,50000 -4828.563596725464,5.966588497161865,59274.19008851051,130594,0,59274.19008851051,0.628000020980835,1.803253412246704,10000,64115.61355304718,0.84033203125,0.8434707522392273,0.7518799901008606,1.2105284929275513,50000 -4860.874860286713,6.009950637817383,59694.75778841972,131523,0,59694.75778841972,0.6252000331878662,1.8083374500274656,10000,64568.58602309227,0.8484765291213989,0.8026385307312012,0.7519800066947937,1.2015353441238403,50000 -4896.653185606003,6.053093671798706,60114.998239040375,132448,0,60114.998239040375,0.6305000185966492,1.7780799865722656,10000,65024.695997715,0.8370116949081421,0.834997832775116,0.7545199990272522,1.1796581745147705,50000 -4929.048456430435,6.100675344467163,60535.22475409508,133375,0,60535.22475409508,0.6372000575065613,1.7991794347763062,10000,65477.413791656494,0.8409179449081421,0.8413550853729248,0.7555800080299377,1.2021714448928833,50000 -4964.578231573105,6.146247386932373,60955.59069728851,134299,0,60955.59069728851,0.6312000155448914,1.78182053565979,10000,65933.40303897858,0.8481249809265137,0.8018088936805725,0.7539399862289429,1.1947689056396484,50000 -5000.347151756287,6.192385196685791,61375.55031371117,135226,0,61375.55031371117,0.638200044631958,1.7809736728668213,10000,66389.22588658333,0.8419921398162842,0.8387060165405273,0.7549200057983398,1.199419617652893,50000 -5035.061300992966,6.240199089050293,61795.65664720535,136151,0,61795.65664720535,0.6385000348091125,1.7745895385742188,10000,66844.14261889458,0.8457421660423279,0.8141063451766968,0.7572399973869324,1.1856240034103394,50000 -5072.519873142242,6.286186456680298,62215.70627474785,137076,0,62215.70627474785,0.6363000273704529,1.7618496417999268,10000,67301.74554777145,0.8503710627555847,0.803084671497345,0.7572799921035767,1.1784809827804563,50000 -5105.9577214717865,6.330903053283691,62635.84506726265,138004,0,62635.84506726265,0.6401000022888184,1.7613399028778076,10000,67755.41611027718,0.8441210985183716,0.8067802786827087,0.7592200040817261,1.1712822914123535,50000 -5141.126784801483,6.374598264694214,63055.99439263344,138929,0,63055.99439263344,0.6415000557899475,1.7597088813781738,10000,68210.82635331154,0.8473241925239563,0.8092468976974487,0.7607199549674988,1.178433895111084,50000 -5172.483269929886,6.423179388046265,63476.15185451508,139855,0,63476.15185451508,0.6410000324249268,1.7686790227890017,10000,68662.43821144104,0.8530859351158142,0.7908298969268799,0.7603999972343445,1.179398775100708,50000 -5208.587812423706,6.4807868003845215,63896.06881427765,140779,0,63896.06881427765,0.6406000256538391,1.7598754167556765,10000,69118.56551671028,0.8641406297683716,0.7573795318603516,0.7615399956703186,1.174419403076172,50000 -5243.268939495087,6.525313138961792,64316.15193653107,141701,0,64316.15193653107,0.6450000405311584,1.7349143028259275,10000,69573.42215561867,0.8532226085662842,0.7790336608886719,0.7630999684333801,1.150395750999451,50000 -5276.812092781067,6.580393314361572,64736.44136500359,142628,0,64736.44136500359,0.6477000117301941,1.7499827146530151,10000,70027.35831856728,0.8541601300239563,0.8001324534416199,0.7625600099563599,1.1798006296157837,50000 -5310.8514902591705,6.647460222244263,65156.379346847534,143555,0,65156.379346847534,0.6407000422477722,1.7233561277389526,10000,70481.45216488838,0.8614062070846558,0.7344046831130981,0.7638199925422668,1.1369210481643677,50000 -5346.491506576538,6.697022199630737,65576.38734292984,144480,0,65576.38734292984,0.6437000036239624,1.7280446290969849,10000,70937.19880628586,0.852832019329071,0.7703531980514526,0.7658999562263489,1.1399189233779907,50000 -5383.4498608112335,6.750396251678467,65996.66863155365,145406,0,65996.66863155365,0.6492000222206116,1.734040141105652,10000,71394.53972387314,0.8582421541213989,0.7630525827407837,0.7664799690246582,1.1516687870025637,50000 -5419.915347337723,6.796186208724976,66416.9311683178,146333,0,66416.9311683178,0.6490000486373901,1.7267446517944336,10000,71851.3624317646,0.8623632788658142,0.7401703000068665,0.7676799893379211,1.1452099084854126,50000 -5454.302681922913,6.847928524017334,66837.13448381424,147259,0,66837.13448381424,0.6520000100135803,1.7251871824264526,10000,72306.05270385742,0.8596093654632568,0.7663670778274536,0.7688199877738953,1.1503161191940308,50000 -5492.399060964584,6.8949620723724365,67257.20768213272,148183,0,67257.20768213272,0.6517000198364258,1.709851622581482,10000,72764.31754755974,0.8612304329872131,0.7369097471237183,0.7663999795913696,1.1357945203781128,50000 -5527.658671617508,6.944510221481323,67677.46754312515,149109,0,67677.46754312515,0.6541000604629517,1.7144485712051392,10000,73219.93533587456,0.8688085675239563,0.7321873903274536,0.7706199884414673,1.136370062828064,50000 -5563.043598890305,6.990314960479736,68097.54691076279,150035,0,68097.54691076279,0.653700053691864,1.7364274263381958,10000,73675.49411344528,0.8769726157188416,0.7250710725784302,0.7684800028800964,1.159796953201294,50000 -5601.793773651123,7.040832042694092,68517.9379067421,150961,0,68517.9379067421,0.6527000069618225,1.708383321762085,10000,74134.7340862751,0.8651366829872131,0.7272911667823792,0.7701199650764465,1.1263177394866943,50000 -5633.890254974365,7.088500022888184,68938.07565045357,151889,0,68938.07565045357,0.6552000045776367,1.6844028234481812,10000,74587.0642721653,0.8691796660423279,0.7026453018188477,0.770039975643158,1.108427882194519,50000 -5667.825475215912,7.13709831237793,69358.09606146812,152814,0,69358.09606146812,0.6562000513076782,1.7009633779525757,10000,75041.11749482155,0.8727929592132568,0.7038527727127075,0.7734599709510803,1.1162408590316772,50000 -5701.598674058914,7.188451766967773,69778.18906927109,153742,0,69778.18906927109,0.6551000475883484,1.710476279258728,10000,75495.08348155022,0.8701562285423279,0.7306351065635681,0.772379994392395,1.131606936454773,50000 -5737.48500084877,7.237675905227661,70198.11680269241,154667,0,70198.11680269241,0.6605000495910645,1.6771475076675415,10000,75950.9940161705,0.8733202815055847,0.6871562004089355,0.7735199928283691,1.1001402139663696,50000 -5771.582266569138,7.296308755874634,70618.42568159103,155588,0,70618.42568159103,0.6607000231742859,1.696984887123108,10000,76405.50797510147,0.874804675579071,0.6974979043006897,0.7736999988555908,1.112828254699707,50000 -5805.957935094833,7.350852727890015,71038.46853804588,156512,0,71038.46853804588,0.6598000526428223,1.6827988624572754,10000,76860.0290813446,0.8740820288658142,0.7038549780845642,0.7731800079345703,1.112505555152893,50000 -5841.424285888672,7.40127420425415,71458.70658421516,157436,0,71458.70658421516,0.6623000502586365,1.684056282043457,10000,77315.83238148689,0.87367182970047,0.6964795589447021,0.7759000062942505,1.108446717262268,50000 -5875.608007192612,7.447989463806152,71878.72174358368,158361,0,71878.72174358368,0.659000039100647,1.6712520122528076,10000,77770.12656998634,0.8794335722923279,0.675494372844696,0.7756199836730957,1.0992703437805176,50000 -5913.313067674637,7.503507614135742,72298.6612625122,159284,0,72298.6612625122,0.6622000336647034,1.6857061386108398,10000,78227.87518262863,0.8847265243530273,0.6685106754302979,0.7770799994468689,1.111339449882507,50000 -5950.919641017914,7.556835651397705,72718.89306354523,160211,0,72718.89306354523,0.6640000343322754,1.670935869216919,10000,78685.8155503273,0.8788671493530273,0.681140124797821,0.7768399715423584,1.102988362312317,50000 -5986.23615694046,7.605699300765991,73138.84251356125,161137,0,73138.84251356125,0.663800060749054,1.6655369997024536,10000,79141.17946362495,0.8810937404632568,0.6635440587997437,0.7783399820327759,1.0918289422988892,50000 -6023.772705554962,7.658042430877685,73558.78927731514,162061,0,73558.78927731514,0.6646000146865845,1.662644624710083,10000,79598.76367640495,0.885546863079071,0.6447293162345886,0.7795199751853943,1.0828590393066406,50000 -6062.477011442184,7.70618200302124,73978.82945775986,162987,0,73978.82945775986,0.666700005531311,1.6550078392028809,10000,80057.60511755943,0.8820117115974426,0.656156063079834,0.7808600068092346,1.076492428779602,50000 -6099.97262597084,7.761868000030518,74399.09590768814,163914,0,74399.09590768814,0.6621000170707703,1.6676348447799685,10000,80515.47159552574,0.8850781321525574,0.6661824584007263,0.780519962310791,1.0905977487564087,50000 -6137.260036468506,7.812313795089722,74819.25520539284,164840,0,74819.25520539284,0.6657000184059143,1.671615719795227,10000,80973.01725029945,0.8875976204872131,0.6562318205833435,0.7809399962425232,1.0894443988800049,50000 -6172.728312015533,7.865967750549316,75239.2183611393,165766,0,75239.2183611393,0.663800060749054,1.661457896232605,10000,81428.55065131187,0.8836132884025574,0.6717751026153564,0.7789999842643738,1.0890315771102903,50000 -6207.636021375656,7.91825795173645,75659.28298974037,166692,0,75659.28298974037,0.6668000221252441,1.6559531688690186,10000,81883.62391614914,0.8857030868530273,0.6525242328643799,0.7821799516677856,1.081518054008484,50000 -6242.713820695877,7.97198224067688,76079.30256128311,167619,0,76079.30256128311,0.6665000319480896,1.6662715673446655,10000,82338.82419419289,0.8876562118530273,0.6507769823074341,0.7819199562072754,1.0878478288650513,50000 -6276.506313562393,8.024755239486694,76499.28124332428,168545,0,76499.28124332428,0.6678000092506409,1.6539418697357178,10000,82792.69729566574,0.8912890553474426,0.6282624006271362,0.7827799916267395,1.078118920326233,50000 -6310.818851947784,8.084598779678345,76919.63348913193,169472,0,76919.63348913193,0.6668000221252441,1.6524428129196167,10000,83247.47095322609,0.8900976181030273,0.6365549564361572,0.7829799652099609,1.0776472091674805,50000 -6348.446104764938,8.135056495666504,77339.8290321827,170400,0,77339.8290321827,0.6687000393867493,1.6623047590255737,10000,83705.39258909225,0.8862695097923279,0.6478437781333923,0.7828199863433838,1.084557294845581,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/measurements.csv deleted file mode 100644 index f94e4770c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1896 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.29836407,6.9077535,,,,,,,,,,,,,, -1,,,0.0008593749953433,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,35.015655517578125,62.8664071559906,35.015655517578125,27.850645065307617,0.0,0.0 -100,0.38148215,6.904769,,,,,,,,,,,,,, -200,0.40988588,6.881121,,,,,,,,,,,,,, -300,0.49031135,6.8398085,,,,,,,,,,,,,, -400,0.52575904,6.8111625,,,,,,,,,,,,,, -500,0.53513014,6.8478746,,,,,,,,,,,,,, -600,0.7661168,6.731335,,,,,,,,,,,,,, -700,0.94130147,6.670276,,,,,,,,,,,,,, -800,1.1368748,6.678386,,,,,,,,,,,,,, -875,,,0.0174804683774709,6.352764129638672,0.016339998692274,6.365495681762695,50000.0,0.0130000002682209,6.417356014251709,10000.0,455.200243473053,517.6062908172607,455.200243473053,62.341947078704834,0.0177381038665771,0.0 -900,0.84834,6.5857697,,,,,,,,,,,,,, -1000,1.2833058,6.5758142,,,,,,,,,,,,,, -1100,1.0452181,6.649188,,,,,,,,,,,,,, -1200,0.86409396,6.534938,,,,,,,,,,,,,, -1300,1.3927343,6.5045214,,,,,,,,,,,,,, -1400,1.0125237,6.7419357,,,,,,,,,,,,,, -1500,0.91151637,6.709257,,,,,,,,,,,,,, -1600,1.0513586,6.324368,,,,,,,,,,,,,, -1700,0.97866243,6.4839964,,,,,,,,,,,,,, -1799,,,0.0464843735098838,5.830441951751709,0.0422799997031688,5.864605903625488,50000.0,0.0340000018477439,5.986781597137451,10000.0,875.5398058891296,973.7104048728944,875.5398058891296,98.02724885940552,0.0483002662658691,0.0 -1800,1.1960188,6.6082,,,,,,,,,,,,,, -1900,1.2162327,6.2866983,,,,,,,,,,,,,, -2000,1.4334325,6.291986,,,,,,,,,,,,,, -2100,2.3739831,6.3280478,,,,,,,,,,,,,, -2200,0.74769735,6.7309704,,,,,,,,,,,,,, -2300,0.91727227,6.2177663,,,,,,,,,,,,,, -2400,1.3038214,6.1891537,,,,,,,,,,,,,, -2500,1.279581,6.5237823,,,,,,,,,,,,,, -2600,1.3122458,6.1936865,,,,,,,,,,,,,, -2691,,,0.0708398446440696,5.448312282562256,0.0682799965143203,5.485576152801514,50000.0,0.051400002092123,5.651369571685791,10000.0,1295.7417635917664,1429.3486967086792,1295.7417635917664,133.39315724372864,0.0725252628326416,0.0 -2700,0.79680914,6.5986686,,,,,,,,,,,,,, -2800,1.0613863,6.411978,,,,,,,,,,,,,, -2900,1.0973257,6.0927873,,,,,,,,,,,,,, -3000,0.979519,6.074169,,,,,,,,,,,,,, -3100,0.812836,6.078933,,,,,,,,,,,,,, -3200,1.2623873,6.5368,,,,,,,,,,,,,, -3300,1.0118568,6.0083327,,,,,,,,,,,,,, -3400,0.9980577,6.0065527,,,,,,,,,,,,,, -3500,1.2186067,6.159245,,,,,,,,,,,,,, -3600,1.1369549,5.956229,,,,,,,,,,,,,, -3620,,,0.1144726499915123,5.076627731323242,0.1036200001835823,5.128720760345459,50000.0,0.0813000053167343,5.356257915496826,10000.0,1715.829880952835,1883.7408828735352,1715.829880952835,167.62544560432434,0.0957322120666503,0.0 -3700,1.0589858,5.957033,,,,,,,,,,,,,, -3800,1.0987488,5.9421196,,,,,,,,,,,,,, -3900,1.040821,5.9112935,,,,,,,,,,,,,, -4000,1.1829375,6.2801313,,,,,,,,,,,,,, -4100,0.8841159,6.184741,,,,,,,,,,,,,, -4200,0.91160107,6.316868,,,,,,,,,,,,,, -4300,0.98572874,6.0703135,,,,,,,,,,,,,, -4400,0.92129606,6.5851736,,,,,,,,,,,,,, -4500,1.016556,5.6813946,,,,,,,,,,,,,, -4546,,,0.1610351502895355,4.597641944885254,0.1422600001096725,4.700864315032959,50000.0,0.108300007879734,5.006032943725586,10000.0,2136.0431559085846,2338.310334444046,2136.0431559085846,201.9062662124633,0.1222002506256103,0.0 -4600,0.8941744,6.5085373,,,,,,,,,,,,,, -4700,1.4205955,6.427918,,,,,,,,,,,,,, -4800,1.079785,5.7929454,,,,,,,,,,,,,, -4900,1.1062212,5.66559,,,,,,,,,,,,,, -5000,1.099382,5.5922203,,,,,,,,,,,,,, -5100,1.1175648,5.5164857,,,,,,,,,,,,,, -5200,1.1710017,5.495788,,,,,,,,,,,,,, -5300,0.9510877,5.6870985,,,,,,,,,,,,,, -5400,1.1963737,5.4767776,,,,,,,,,,,,,, -5473,,,0.2023632824420929,4.254808902740479,0.1880799978971481,4.334715366363525,50000.0,0.1407999992370605,4.678713321685791,10000.0,2556.31773352623,2792.7371475696564,2556.31773352623,235.9856154918671,0.146214485168457,0.0 -5500,0.97264105,5.6016493,,,,,,,,,,,,,, -5600,1.4548857,5.655055,,,,,,,,,,,,,, -5700,0.8706061,6.273406,,,,,,,,,,,,,, -5800,1.2908964,5.4569616,,,,,,,,,,,,,, -5900,1.1163734,5.7970824,,,,,,,,,,,,,, -6000,1.1513643,5.4868784,,,,,,,,,,,,,, -6100,1.3243062,5.370141,,,,,,,,,,,,,, -6200,0.86446905,6.397096,,,,,,,,,,,,,, -6300,1.0161717,5.3222895,,,,,,,,,,,,,, -6400,1.2228597,5.272567,,,,,,,,,,,,,, -6402,,,0.25146484375,3.917474031448364,0.2351599931716919,4.009504795074463,50000.0,0.1830000132322311,4.3917012214660645,10000.0,2976.2778012752533,3245.2699744701385,2976.2778012752533,268.4839344024658,0.1717164516448974,0.0 -6500,0.96046245,5.4453473,,,,,,,,,,,,,, -6600,0.93130136,5.539044,,,,,,,,,,,,,, -6700,0.8254276,6.2840586,,,,,,,,,,,,,, -6800,1.0896927,5.167512,,,,,,,,,,,,,, -6900,0.8849586,5.354709,,,,,,,,,,,,,, -7000,1.0132285,5.16558,,,,,,,,,,,,,, -7100,0.9017557,6.3534417,,,,,,,,,,,,,, -7200,0.8606855,5.840517,,,,,,,,,,,,,, -7300,0.7210417,6.4205866,,,,,,,,,,,,,, -7330,,,0.2987304627895355,3.65021538734436,0.2711600065231323,3.7866060733795166,50000.0,0.2092000097036361,4.218708038330078,10000.0,3396.418829917908,3698.38631939888,3396.418829917908,301.3875472545624,0.1950387954711914,0.0 -7400,0.80200166,6.2340975,,,,,,,,,,,,,, -7500,0.8061865,6.155999,,,,,,,,,,,,,, -7600,1.2433851,5.320231,,,,,,,,,,,,,, -7700,1.1158612,5.286006,,,,,,,,,,,,,, -7800,0.7551457,6.366006,,,,,,,,,,,,,, -7900,0.86731416,5.341492,,,,,,,,,,,,,, -8000,0.9931931,5.078314,,,,,,,,,,,,,, -8100,1.0437162,5.083346,,,,,,,,,,,,,, -8200,1.0201274,5.1057673,,,,,,,,,,,,,, -8257,,,0.3236913979053497,3.4396770000457764,0.3033199906349182,3.5456032752990723,50000.0,0.2370000183582306,4.004726886749268,10000.0,3816.663746833801,4151.96754193306,3816.663746833801,334.6520891189575,0.2184698581695556,0.0 -8300,0.9045565,5.8755717,,,,,,,,,,,,,, -8400,1.3077722,5.046062,,,,,,,,,,,,,, -8500,1.2095388,5.2909307,,,,,,,,,,,,,, -8600,1.2954017,4.9359674,,,,,,,,,,,,,, -8700,1.0259647,4.942609,,,,,,,,,,,,,, -8800,0.98262346,4.940645,,,,,,,,,,,,,, -8900,1.0816014,4.7827053,,,,,,,,,,,,,, -9000,0.96140677,4.889351,,,,,,,,,,,,,, -9100,0.91328955,4.8339252,,,,,,,,,,,,,, -9184,,,0.3685546815395355,3.152804613113404,0.3393599987030029,3.286856412887573,50000.0,0.259300023317337,3.76635479927063,10000.0,4236.656970024109,4605.663713693619,4236.656970024109,368.2774066925049,0.2466275691986084,0.0 -9200,1.0341064,4.8510633,,,,,,,,,,,,,, -9300,0.99503475,4.8144717,,,,,,,,,,,,,, -9400,1.0556763,4.805708,,,,,,,,,,,,,, -9500,0.98937106,5.3819222,,,,,,,,,,,,,, -9600,0.80905473,5.5229063,,,,,,,,,,,,,, -9700,1.020638,4.9089746,,,,,,,,,,,,,, -9800,0.78252953,5.769402,,,,,,,,,,,,,, -9900,0.98305565,4.726537,,,,,,,,,,,,,, -10000,0.68217283,6.24755,,,,,,,,,,,,,, -10100,0.7930018,5.833638,,,,,,,,,,,,,, -10112,,,0.4025000035762787,3.0011579990386963,0.3671199977397918,3.165762662887573,50000.0,0.2835000157356262,3.687035322189331,10000.0,4656.829707622528,5060.144756317139,4656.829707622528,402.5110983848572,0.2724997997283935,0.0 -10200,0.9358989,4.9897957,,,,,,,,,,,,,, -10300,0.6697975,5.741494,,,,,,,,,,,,,, -10400,0.8285263,5.3232746,,,,,,,,,,,,,, -10500,0.9168576,4.7747707,,,,,,,,,,,,,, -10600,0.92697495,4.6675777,,,,,,,,,,,,,, -10700,0.9929579,4.610406,,,,,,,,,,,,,, -10800,0.69027734,6.1010084,,,,,,,,,,,,,, -10900,0.92295223,4.639388,,,,,,,,,,,,,, -11000,1.0308872,4.650672,,,,,,,,,,,,,, -11041,,,0.4302734136581421,2.806786060333252,0.3969399929046631,2.9504761695861816,50000.0,0.3059000074863434,3.496697902679444,10000.0,5076.838560819626,5513.161741495132,5076.838560819626,435.4448158740997,0.2978024482727051,0.0 -11100,0.7808152,5.48993,,,,,,,,,,,,,, -11200,0.94955266,4.5482483,,,,,,,,,,,,,, -11300,0.65337336,5.7790375,,,,,,,,,,,,,, -11400,0.9003775,4.7327247,,,,,,,,,,,,,, -11500,0.7421802,5.853131,,,,,,,,,,,,,, -11600,0.8258037,5.133235,,,,,,,,,,,,,, -11700,0.9395791,4.9526043,,,,,,,,,,,,,, -11800,0.9862305,4.529726,,,,,,,,,,,,,, -11900,0.9371976,4.437964,,,,,,,,,,,,,, -11969,,,0.4562304615974426,2.67525053024292,0.4208199977874756,2.830482482910156,50000.0,0.3264000117778778,3.385175943374634,10000.0,5497.176267147064,5965.148057460785,5497.176267147064,467.0167169570923,0.3260564804077148,0.0 -12000,0.9033206,4.935407,,,,,,,,,,,,,, -12100,0.93642604,4.459603,,,,,,,,,,,,,, -12200,0.855786,4.6101203,,,,,,,,,,,,,, -12300,0.9173169,4.509783,,,,,,,,,,,,,, -12400,1.1260102,4.607253,,,,,,,,,,,,,, -12500,0.77168614,4.723975,,,,,,,,,,,,,, -12600,0.99437594,4.453143,,,,,,,,,,,,,, -12700,0.95712364,4.511554,,,,,,,,,,,,,, -12800,0.9332914,4.4488635,,,,,,,,,,,,,, -12898,,,0.4791015386581421,2.5726089477539062,0.4387999773025512,2.7511792182922363,50000.0,0.3397000133991241,3.322092294692993,10000.0,5917.206485748291,6418.862198352814,5917.206485748291,500.62565183639526,0.3518826961517334,0.0 -12900,0.9639437,4.522727,,,,,,,,,,,,,, -13000,0.7506333,4.7158995,,,,,,,,,,,,,, -13100,0.9259093,4.705798,,,,,,,,,,,,,, -13200,1.0212216,4.3246384,,,,,,,,,,,,,, -13300,1.0264257,4.499808,,,,,,,,,,,,,, -13400,0.66467196,5.544939,,,,,,,,,,,,,, -13500,0.7041699,5.4540606,,,,,,,,,,,,,, -13600,0.92970985,4.431267,,,,,,,,,,,,,, -13700,0.88351506,4.4309177,,,,,,,,,,,,,, -13800,0.6484888,5.9961815,,,,,,,,,,,,,, -13826,,,0.5181445479393005,2.3554434776306152,0.4603399932384491,2.6169192790985107,50000.0,0.356900006532669,3.201007843017578,10000.0,6337.3051698207855,6872.681324005127,6337.3051698207855,534.2692155838013,0.3798291683197021,0.0 -13900,0.81312764,5.013983,,,,,,,,,,,,,, -14000,0.907851,4.4206886,,,,,,,,,,,,,, -14100,1.0236382,4.3317666,,,,,,,,,,,,,, -14200,0.9088968,4.6565914,,,,,,,,,,,,,, -14300,0.9600951,4.3005004,,,,,,,,,,,,,, -14400,0.7697558,4.9876575,,,,,,,,,,,,,, -14500,0.941428,4.3488207,,,,,,,,,,,,,, -14600,0.8777529,5.003766,,,,,,,,,,,,,, -14700,0.6117788,6.02252,,,,,,,,,,,,,, -14756,,,0.5152734518051147,2.358519554138184,0.4766999781131744,2.536402463912964,50000.0,0.3725000321865082,3.11665415763855,10000.0,6757.488060712814,7321.107630968094,6757.488060712814,562.43803191185,0.4050462245941162,0.0 -14800,1.1480031,4.2488976,,,,,,,,,,,,,, -14900,0.8835943,4.5342197,,,,,,,,,,,,,, -15000,0.7915987,6.038218,,,,,,,,,,,,,, -15100,0.8850468,4.3996725,,,,,,,,,,,,,, -15200,0.89314187,4.213795,,,,,,,,,,,,,, -15300,0.8866534,4.26196,,,,,,,,,,,,,, -15400,1.0516242,4.3650866,,,,,,,,,,,,,, -15500,0.6798861,5.795732,,,,,,,,,,,,,, -15600,0.915906,4.5875235,,,,,,,,,,,,,, -15681,,,0.5367968678474426,2.26556658744812,0.4979399740695953,2.443150281906128,50000.0,0.385200023651123,3.0521280765533447,10000.0,7177.518812179565,7774.275252819061,7177.518812179565,595.4816019535065,0.4494235515594482,0.0 -15700,0.9542517,4.272072,,,,,,,,,,,,,, -15800,1.0297389,4.2453113,,,,,,,,,,,,,, -15900,0.7229665,4.657067,,,,,,,,,,,,,, -16000,0.98379695,4.5846677,,,,,,,,,,,,,, -16100,0.78496283,4.674223,,,,,,,,,,,,,, -16200,0.79344213,5.1885805,,,,,,,,,,,,,, -16300,0.94323355,4.178815,,,,,,,,,,,,,, -16400,0.74025434,5.9646664,,,,,,,,,,,,,, -16500,0.9513789,4.2643995,,,,,,,,,,,,,, -16600,0.6507199,5.767072,,,,,,,,,,,,,, -16611,,,0.5549218654632568,2.1838366985321045,0.501039981842041,2.4183454513549805,50000.0,0.3908000290393829,3.024139881134033,10000.0,7597.600474834442,8225.735122203827,7597.600474834442,626.7855026721954,0.4745340347290039,0.0 -16700,0.66679627,5.6646695,,,,,,,,,,,,,, -16800,0.83641285,4.45131,,,,,,,,,,,,,, -16900,0.70873857,5.8100147,,,,,,,,,,,,,, -17000,0.93349046,4.19585,,,,,,,,,,,,,, -17100,0.90398383,4.1641254,,,,,,,,,,,,,, -17200,0.92297727,4.284212,,,,,,,,,,,,,, -17300,0.90347505,4.2899914,,,,,,,,,,,,,, -17400,0.98574686,4.0640664,,,,,,,,,,,,,, -17500,0.8390214,4.168951,,,,,,,,,,,,,, -17542,,,0.5504687428474426,2.1938395500183105,0.5118199586868286,2.3658089637756348,50000.0,0.4057000279426574,2.9616432189941406,10000.0,8017.97752737999,8678.402475357056,8017.97752737999,658.9955780506134,0.5048494338989258,0.0 -17600,0.9881797,4.545062,,,,,,,,,,,,,, -17700,0.8428461,4.356602,,,,,,,,,,,,,, -17800,0.9846513,4.1384068,,,,,,,,,,,,,, -17900,0.92862546,4.1942844,,,,,,,,,,,,,, -18000,0.9699769,4.3126774,,,,,,,,,,,,,, -18100,0.9137299,4.094517,,,,,,,,,,,,,, -18200,0.7960023,4.9656763,,,,,,,,,,,,,, -18300,1.0059828,4.083748,,,,,,,,,,,,,, -18400,0.94854647,4.315483,,,,,,,,,,,,,, -18464,,,0.5669335722923279,2.096486091613769,0.5268599987030029,2.2790541648864746,50000.0,0.4108000099658966,2.897355794906616,10000.0,8437.904272794724,9132.661324977877,8437.904272794724,693.2478134632111,0.5354297161102295,0.0 -18500,0.9415356,4.124601,,,,,,,,,,,,,, -18600,0.909493,4.2215805,,,,,,,,,,,,,, -18700,0.9780783,4.1040225,,,,,,,,,,,,,, -18800,0.7729892,4.5116463,,,,,,,,,,,,,, -18900,0.9124184,4.0492907,,,,,,,,,,,,,, -19000,0.7036077,5.7907305,,,,,,,,,,,,,, -19100,1.0688138,4.0314946,,,,,,,,,,,,,, -19200,0.90366364,4.0967245,,,,,,,,,,,,,, -19300,0.7816852,4.598703,,,,,,,,,,,,,, -19391,,,0.5886914134025574,1.9833548069000244,0.5397199988365173,2.214489698410034,50000.0,0.4244000315666199,2.8181262016296387,10000.0,8858.220192193985,9588.208375692368,8858.220192193985,728.4040546417236,0.5617325305938721,0.0 -19400,0.96556187,4.028498,,,,,,,,,,,,,, -19500,0.7275237,5.19949,,,,,,,,,,,,,, -19600,0.7997617,5.2149754,,,,,,,,,,,,,, -19700,0.97180504,4.114164,,,,,,,,,,,,,, -19800,0.895876,4.0051956,,,,,,,,,,,,,, -19900,0.9353751,4.0629916,,,,,,,,,,,,,, -20000,0.9324842,3.9918022,,,,,,,,,,,,,, -20100,0.85536766,4.0015354,,,,,,,,,,,,,, -20200,0.70504004,4.959666,,,,,,,,,,,,,, -20300,0.74779135,4.9320426,,,,,,,,,,,,,, -20321,,,0.5872656106948853,1.982076644897461,0.5415999889373779,2.179627656936645,50000.0,0.4308000206947326,2.7886321544647217,10000.0,9278.46674156189,10043.739041805267,9278.46674156189,763.6083111763,0.5928773880004883,0.0 -20400,0.89300394,3.9626722,,,,,,,,,,,,,, -20500,0.90068865,4.3778996,,,,,,,,,,,,,, -20600,0.86043787,4.179897,,,,,,,,,,,,,, -20700,0.8858026,3.9894633,,,,,,,,,,,,,, -20800,0.8460804,4.545527,,,,,,,,,,,,,, -20900,0.9439105,3.9746604,,,,,,,,,,,,,, -21000,0.7668844,5.334712,,,,,,,,,,,,,, -21100,0.95639974,4.1042666,,,,,,,,,,,,,, -21200,0.82724106,5.629221,,,,,,,,,,,,,, -21248,,,0.5874413847923279,1.997110605239868,0.5496799945831299,2.174673795700073,50000.0,0.4305000305175781,2.7999267578125,10000.0,9698.78861618042,10500.002382278442,9698.78861618042,799.4674112796783,0.6267457008361816,0.0 -21300,0.92662966,4.0486255,,,,,,,,,,,,,, -21400,0.9429606,4.0440283,,,,,,,,,,,,,, -21500,0.80796444,4.2447443,,,,,,,,,,,,,, -21600,0.8121661,4.5003843,,,,,,,,,,,,,, -21700,0.6872727,5.7509465,,,,,,,,,,,,,, -21800,0.6400201,5.788799,,,,,,,,,,,,,, -21900,0.96547675,3.9743032,,,,,,,,,,,,,, -22000,0.9222591,4.033592,,,,,,,,,,,,,, -22100,0.9570929,3.9933631,,,,,,,,,,,,,, -22177,,,0.6080663800239563,1.9223328828811648,0.5569199919700623,2.1479055881500244,50000.0,0.4414000213146209,2.7625842094421387,10000.0,10118.962321043016,10955.953563451769,10118.962321043016,835.1677443981171,0.6544830799102783,0.0 -22200,0.79231817,5.4963827,,,,,,,,,,,,,, -22300,0.91776663,3.9256046,,,,,,,,,,,,,, -22400,0.9839239,3.9353113,,,,,,,,,,,,,, -22500,0.6859986,5.442223,,,,,,,,,,,,,, -22600,0.9163095,3.9720209,,,,,,,,,,,,,, -22700,0.86750805,4.5183263,,,,,,,,,,,,,, -22800,0.91370094,4.091424,,,,,,,,,,,,,, -22900,0.9273629,3.904544,,,,,,,,,,,,,, -23000,0.9492237,3.9803452,,,,,,,,,,,,,, -23100,0.9446921,3.9597673,,,,,,,,,,,,,, -23105,,,0.6371484398841858,1.7699031829833984,0.5706999897956848,2.06937313079834,50000.0,0.4481000304222107,2.6906943321228027,10000.0,10539.390924930573,11411.274099826813,10539.390924930573,869.9834413528442,0.6817433834075928,0.0 -23200,0.898244,4.01262,,,,,,,,,,,,,, -23300,0.98602474,3.9844902,,,,,,,,,,,,,, -23400,0.9902838,3.952336,,,,,,,,,,,,,, -23500,0.92592776,3.9815993,,,,,,,,,,,,,, -23600,0.9062233,4.8317146,,,,,,,,,,,,,, -23700,0.9534617,3.9436293,,,,,,,,,,,,,, -23800,0.75648725,5.6325474,,,,,,,,,,,,,, -23900,0.9487839,4.085683,,,,,,,,,,,,,, -24000,0.84139717,4.263676,,,,,,,,,,,,,, -24032,,,0.6188671588897705,1.8651227951049805,0.5744799971580505,2.0639655590057373,50000.0,0.4580000340938568,2.681931734085083,10000.0,10959.36401629448,11866.034049749374,10959.36401629448,904.6943933963776,0.7084939479827881,0.0 -24100,0.969813,3.995106,,,,,,,,,,,,,, -24200,0.84329396,5.0438404,,,,,,,,,,,,,, -24300,0.89692646,4.132306,,,,,,,,,,,,,, -24400,0.7542329,5.3932176,,,,,,,,,,,,,, -24500,0.8845948,3.9460921,,,,,,,,,,,,,, -24600,0.8204698,4.329314,,,,,,,,,,,,,, -24700,0.77995664,4.9491243,,,,,,,,,,,,,, -24800,0.9414093,3.8776379,,,,,,,,,,,,,, -24900,0.71687883,5.662775,,,,,,,,,,,,,, -24957,,,0.6286327838897705,1.810321569442749,0.5790799856185913,2.0259885787963867,50000.0,0.4633000195026397,2.642620325088501,10000.0,11379.409708976746,12322.063168525696,11379.409708976746,940.5957970619202,0.7409365177154541,0.0 -25000,0.8320698,3.923002,,,,,,,,,,,,,, -25100,0.82584316,4.1861115,,,,,,,,,,,,,, -25200,0.895615,3.823931,,,,,,,,,,,,,, -25300,1.0315995,3.8870108,,,,,,,,,,,,,, -25400,0.7797279,5.311533,,,,,,,,,,,,,, -25500,0.8314296,4.9046845,,,,,,,,,,,,,, -25600,0.7886216,4.686959,,,,,,,,,,,,,, -25700,0.8669398,4.5496073,,,,,,,,,,,,,, -25800,0.867164,3.9431548,,,,,,,,,,,,,, -25885,,,0.6464648246765137,1.718097686767578,0.5830000042915344,1.9904863834381104,50000.0,0.4640000164508819,2.601236343383789,10000.0,11799.423764944077,12775.900420188904,11799.423764944077,974.3425529003144,0.7686212062835693,0.0 -25900,0.8035726,4.8674164,,,,,,,,,,,,,, -26000,0.8775691,4.4781127,,,,,,,,,,,,,, -26100,0.88389915,3.9338045,,,,,,,,,,,,,, -26200,0.9940375,3.876895,,,,,,,,,,,,,, -26300,1.001064,3.808306,,,,,,,,,,,,,, -26400,1.019315,3.9913054,,,,,,,,,,,,,, -26500,0.7522622,5.702442,,,,,,,,,,,,,, -26600,0.95149773,3.8858123,,,,,,,,,,,,,, -26700,1.0099823,4.0825024,,,,,,,,,,,,,, -26800,0.9589921,3.79167,,,,,,,,,,,,,, -26811,,,0.6296288967132568,1.7748783826828003,0.5895999670028687,1.974582314491272,50000.0,0.4707000255584717,2.5924384593963623,10000.0,12219.484512329102,13230.34390258789,12219.484512329102,1008.6455118656158,0.7995619773864746,0.0 -26900,0.8531818,4.141441,,,,,,,,,,,,,, -27000,1.1355752,3.7334356,,,,,,,,,,,,,, -27100,0.94642913,3.833827,,,,,,,,,,,,,, -27200,0.9404428,3.893908,,,,,,,,,,,,,, -27300,1.0053344,3.8051538,,,,,,,,,,,,,, -27400,0.71736294,5.423447,,,,,,,,,,,,,, -27500,0.8186835,4.469817,,,,,,,,,,,,,, -27600,1.0083638,3.8865924,,,,,,,,,,,,,, -27700,0.8837995,4.350534,,,,,,,,,,,,,, -27739,,,0.6431835889816284,1.694810390472412,0.6007599830627441,1.8964165449142456,50000.0,0.4793000221252441,2.5261025428771973,10000.0,12639.566604852676,13684.9992685318,12639.566604852676,1043.1403737068176,0.8287265300750732,0.0 -27800,0.76086676,4.6047688,,,,,,,,,,,,,, -27900,0.9589759,3.8729484,,,,,,,,,,,,,, -28000,0.9117641,4.0677533,,,,,,,,,,,,,, -28100,0.92649186,3.9423459,,,,,,,,,,,,,, -28200,0.953635,3.7754183,,,,,,,,,,,,,, -28300,1.0625355,3.9302804,,,,,,,,,,,,,, -28400,0.91554147,4.243682,,,,,,,,,,,,,, -28500,0.7473897,5.5766234,,,,,,,,,,,,,, -28600,0.78865737,5.6040115,,,,,,,,,,,,,, -28669,,,0.6536718606948853,1.666849136352539,0.5978400111198425,1.9145691394805908,50000.0,0.4789000153541565,2.543222665786743,10000.0,13059.846086263657,14139.38765001297,13059.846086263657,1077.1589756011963,0.8670501708984375,0.0 -28700,0.93264437,4.026315,,,,,,,,,,,,,, -28800,0.95418525,3.829796,,,,,,,,,,,,,, -28900,0.9731747,3.740002,,,,,,,,,,,,,, -29000,0.87800705,4.112607,,,,,,,,,,,,,, -29100,0.9314516,4.318855,,,,,,,,,,,,,, -29200,0.8167659,4.498102,,,,,,,,,,,,,, -29300,0.91184396,3.9131017,,,,,,,,,,,,,, -29400,0.9941251,3.8018968,,,,,,,,,,,,,, -29500,0.97163886,3.8665433,,,,,,,,,,,,,, -29597,,,0.6449218392372131,1.720602035522461,0.5954799652099609,1.935060739517212,50000.0,0.4761000275611877,2.5695018768310547,10000.0,13479.923082113266,14593.923149585724,13479.923082113266,1111.5325186252594,0.903029203414917,0.0 -29600,1.046413,3.7785046,,,,,,,,,,,,,, -29700,0.80565315,4.372155,,,,,,,,,,,,,, -29800,0.810877,4.780892,,,,,,,,,,,,,, -29900,0.9362742,3.8056064,,,,,,,,,,,,,, -30000,0.94103545,3.8197253,,,,,,,,,,,,,, -30100,0.98000735,3.762395,,,,,,,,,,,,,, -30200,0.79692453,5.6529913,,,,,,,,,,,,,, -30300,0.9456164,3.8752723,,,,,,,,,,,,,, -30400,0.79466385,4.770418,,,,,,,,,,,,,, -30500,0.99940926,3.999677,,,,,,,,,,,,,, -30525,,,0.6502929329872131,1.707275390625,0.602620005607605,1.911301851272583,50000.0,0.4842000305652618,2.5386219024658203,10000.0,13900.250252962112,15047.529522657394,13900.250252962112,1144.7327094078064,0.9328234195709229,0.0 -30600,0.83117795,5.5153494,,,,,,,,,,,,,, -30700,0.9162575,3.7826672,,,,,,,,,,,,,, -30800,0.9677174,3.9558356,,,,,,,,,,,,,, -30900,0.9455687,4.7998567,,,,,,,,,,,,,, -31000,0.8235127,4.606522,,,,,,,,,,,,,, -31100,1.0000398,4.0440392,,,,,,,,,,,,,, -31200,0.989596,3.7696931,,,,,,,,,,,,,, -31300,0.9829964,3.832428,,,,,,,,,,,,,, -31400,0.9658351,3.6860132,,,,,,,,,,,,,, -31451,,,0.6668164134025574,1.6053072214126587,0.6132000088691711,1.8501198291778564,50000.0,0.4948000311851501,2.4632036685943604,10000.0,14320.317656040192,15502.468721866608,14320.317656040192,1179.5270998477936,0.9612855911254884,0.0 -31500,1.0325974,3.7431295,,,,,,,,,,,,,, -31600,0.9923551,3.770452,,,,,,,,,,,,,, -31700,0.95177406,3.9612756,,,,,,,,,,,,,, -31800,0.88827753,4.275523,,,,,,,,,,,,,, -31900,0.82026136,5.514175,,,,,,,,,,,,,, -32000,0.9284022,3.7887266,,,,,,,,,,,,,, -32100,0.8760897,3.6800518,,,,,,,,,,,,,, -32200,0.964291,3.6761632,,,,,,,,,,,,,, -32300,0.8188047,4.5695496,,,,,,,,,,,,,, -32379,,,0.6805077791213989,1.5708024501800537,0.6123799681663513,1.8718345165252688,50000.0,0.4894000291824341,2.4860403537750244,10000.0,14740.255218982697,15957.264911651611,14740.255218982697,1214.3044934272766,0.9927427768707277,0.0 -32400,1.0106074,3.7487743,,,,,,,,,,,,,, -32500,0.76687133,5.0955153,,,,,,,,,,,,,, -32600,0.93592185,5.5461416,,,,,,,,,,,,,, -32700,0.9233208,3.9134233,,,,,,,,,,,,,, -32800,1.0076334,3.7402272,,,,,,,,,,,,,, -32900,0.9588465,3.6856182,,,,,,,,,,,,,, -33000,0.79238594,5.594827,,,,,,,,,,,,,, -33100,1.0200384,3.7114403,,,,,,,,,,,,,, -33200,0.96215975,3.6939652,,,,,,,,,,,,,, -33300,1.0189531,3.7026472,,,,,,,,,,,,,, -33309,,,0.6620507836341858,1.677104949951172,0.6145600080490112,1.8774863481521609,50000.0,0.4953000247478485,2.487512350082397,10000.0,15160.616579771042,16411.350699186325,15160.616579771042,1247.9506666660309,1.0217430591583252,0.0 -33400,0.906142,5.2287607,,,,,,,,,,,,,, -33500,0.95875883,3.796195,,,,,,,,,,,,,, -33600,0.9734967,3.791043,,,,,,,,,,,,,, -33700,0.80106944,4.6642675,,,,,,,,,,,,,, -33800,0.97493595,3.691932,,,,,,,,,,,,,, -33900,0.8892142,5.6111283,,,,,,,,,,,,,, -34000,0.8618441,4.39291,,,,,,,,,,,,,, -34100,1.0800573,3.7548606,,,,,,,,,,,,,, -34200,1.0651654,3.7165916,,,,,,,,,,,,,, -34238,,,0.6710546612739563,1.578461527824402,0.6174600124359131,1.8176082372665403,50000.0,0.496800035238266,2.435833215713501,10000.0,15580.754612445831,16865.220044851303,15580.754612445831,1281.597553730011,1.057400465011597,0.0 -34300,0.95490324,3.898426,,,,,,,,,,,,,, -34400,0.9498165,3.7048142,,,,,,,,,,,,,, -34500,0.9934098,3.6789422,,,,,,,,,,,,,, -34600,0.99203396,3.6716115,,,,,,,,,,,,,, -34700,0.9869701,3.6670277,,,,,,,,,,,,,, -34800,0.8104736,4.55207,,,,,,,,,,,,,, -34900,1.0019908,3.7979577,,,,,,,,,,,,,, -35000,1.003916,3.7756274,,,,,,,,,,,,,, -35100,1.0788318,3.8290644,,,,,,,,,,,,,, -35166,,,0.6815429329872131,1.5850794315338137,0.621239960193634,1.84546422958374,50000.0,0.4941000342369079,2.4827773571014404,10000.0,16000.741730213163,17319.111181020737,16000.741730213163,1315.423936367035,1.0867114067077637,0.0 -35200,1.0809532,3.7541676,,,,,,,,,,,,,, -35300,0.93300587,3.6390266,,,,,,,,,,,,,, -35400,1.0139673,4.432934,,,,,,,,,,,,,, -35500,0.8266467,5.1814275,,,,,,,,,,,,,, -35600,0.91169596,4.115728,,,,,,,,,,,,,, -35700,0.9403917,4.663226,,,,,,,,,,,,,, -35800,1.1506076,3.7000403,,,,,,,,,,,,,, -35900,0.9499887,3.8510427,,,,,,,,,,,,,, -36000,0.99094117,3.8686213,,,,,,,,,,,,,, -36093,,,0.6749609112739563,1.5821787118911743,0.6269599795341492,1.7968742847442627,50000.0,0.5010000467300415,2.4259908199310303,10000.0,16421.03150343895,17774.126088142395,16421.03150343895,1350.0655298233032,1.1218464374542236,0.0 -36100,1.0404501,3.6883245,,,,,,,,,,,,,, -36200,0.90463525,4.553444,,,,,,,,,,,,,, -36300,1.0522516,3.6216056,,,,,,,,,,,,,, -36400,1.1204756,3.7099905,,,,,,,,,,,,,, -36500,0.8710237,4.191329,,,,,,,,,,,,,, -36600,0.8412132,4.7579484,,,,,,,,,,,,,, -36700,1.0582051,3.989148,,,,,,,,,,,,,, -36800,1.0394659,3.7970917,,,,,,,,,,,,,, -36900,0.8516885,5.332053,,,,,,,,,,,,,, -37000,1.0153329,3.8099637,,,,,,,,,,,,,, -37022,,,0.6784374713897705,1.5355225801467896,0.6299600005149841,1.7512410879135132,50000.0,0.5004000067710876,2.398932933807373,10000.0,16841.30184864998,18229.68051123619,16841.30184864998,1385.2672073841095,1.1547789573669434,0.0 -37100,0.9744273,3.675113,,,,,,,,,,,,,, -37200,0.8881933,4.931869,,,,,,,,,,,,,, -37300,0.93328947,4.269449,,,,,,,,,,,,,, -37400,0.9227616,3.890552,,,,,,,,,,,,,, -37500,1.1502314,3.7204745,,,,,,,,,,,,,, -37600,0.86959505,5.442815,,,,,,,,,,,,,, -37700,0.9060835,3.9868731,,,,,,,,,,,,,, -37800,0.9612434,3.891843,,,,,,,,,,,,,, -37900,0.8324446,5.4114456,,,,,,,,,,,,,, -37949,,,0.6911327838897705,1.483366847038269,0.6319000124931335,1.7460869550704956,50000.0,0.5121000409126282,2.372727155685425,10000.0,17261.51871085167,18686.173897981644,17261.51871085167,1421.4622313976288,1.187713623046875,0.0 -38000,1.0588166,3.7438865,,,,,,,,,,,,,, -38100,1.0052075,3.742297,,,,,,,,,,,,,, -38200,1.0729543,3.7200172,,,,,,,,,,,,,, -38300,0.9545739,4.327615,,,,,,,,,,,,,, -38400,0.97153944,3.689851,,,,,,,,,,,,,, -38500,1.0002767,3.730122,,,,,,,,,,,,,, -38600,0.90622085,4.4573574,,,,,,,,,,,,,, -38700,0.97049534,3.7308996,,,,,,,,,,,,,, -38800,0.9869735,3.6736813,,,,,,,,,,,,,, -38877,,,0.6896093487739563,1.500171780586243,0.6326599717140198,1.7544867992401123,50000.0,0.5072000026702881,2.3739521503448486,10000.0,17681.805111408234,19140.70230269432,17681.805111408234,1455.6237666606903,1.2181427478790283,0.0 -38900,1.041456,3.646099,,,,,,,,,,,,,, -39000,1.0563474,3.7328892,,,,,,,,,,,,,, -39100,1.0553325,3.689026,,,,,,,,,,,,,, -39200,0.9582618,3.6220553,,,,,,,,,,,,,, -39300,0.9075746,3.9546824,,,,,,,,,,,,,, -39400,0.88719773,4.923106,,,,,,,,,,,,,, -39500,0.8987111,3.8578155,,,,,,,,,,,,,, -39600,1.0019872,4.2529387,,,,,,,,,,,,,, -39700,1.0156088,3.6270816,,,,,,,,,,,,,, -39800,0.97302747,3.672422,,,,,,,,,,,,,, -39807,,,0.6860156059265137,1.5207045078277588,0.6359599828720093,1.745492458343506,50000.0,0.5121999979019165,2.383885622024536,10000.0,18101.77188515663,19594.19838809967,18101.77188515663,1489.073492050171,1.2485857009887695,0.0 -39900,0.9750376,3.9444065,,,,,,,,,,,,,, -40000,0.98294634,3.7967129,,,,,,,,,,,,,, -40100,0.98692364,3.925381,,,,,,,,,,,,,, -40200,1.0577453,3.6700072,,,,,,,,,,,,,, -40300,0.95620835,4.161645,,,,,,,,,,,,,, -40400,1.042369,3.677665,,,,,,,,,,,,,, -40500,0.9587469,4.0790424,,,,,,,,,,,,,, -40600,0.859397,5.1002517,,,,,,,,,,,,,, -40700,0.8578478,5.4725437,,,,,,,,,,,,,, -40735,,,0.6874804496765137,1.4901330471038818,0.6361799836158752,1.7246910333633425,50000.0,0.5095000267028809,2.366427183151245,10000.0,18522.085191726685,20046.41907286644,18522.085191726685,1520.9033830165863,1.277255296707153,0.0 -40800,0.9846954,3.6915584,,,,,,,,,,,,,, -40900,0.949101,5.215635,,,,,,,,,,,,,, -41000,0.9889896,3.8681242,,,,,,,,,,,,,, -41100,1.0004774,3.717482,,,,,,,,,,,,,, -41200,0.9610138,3.7469878,,,,,,,,,,,,,, -41300,0.90826654,4.7304225,,,,,,,,,,,,,, -41400,1.0634443,3.6097918,,,,,,,,,,,,,, -41500,0.8431336,5.308586,,,,,,,,,,,,,, -41600,0.87552845,5.4565463,,,,,,,,,,,,,, -41661,,,0.7146679759025574,1.4313410520553589,0.6459000110626221,1.7241450548171997,50000.0,0.517300009727478,2.3592562675476074,10000.0,18942.435485124588,20499.75671958924,18942.435485124588,1553.8083319664,1.3109490871429443,0.0 -41700,0.9234254,5.395972,,,,,,,,,,,,,, -41800,0.8711837,4.544059,,,,,,,,,,,,,, -41900,0.8675782,5.3666058,,,,,,,,,,,,,, -42000,0.9382671,4.3473177,,,,,,,,,,,,,, -42100,0.9977382,3.7131183,,,,,,,,,,,,,, -42200,1.0415379,3.6039422,,,,,,,,,,,,,, -42300,1.1188347,3.594501,,,,,,,,,,,,,, -42400,0.97257066,3.7309635,,,,,,,,,,,,,, -42500,0.99864,3.6694593,,,,,,,,,,,,,, -42589,,,0.6867382526397705,1.5198066234588623,0.6368399858474731,1.7426916360855105,50000.0,0.5134000182151794,2.363996982574463,10000.0,19362.675671339035,20955.205878019333,19362.675671339035,1588.933512687683,1.3452599048614502,0.0 -42600,0.8842365,4.029301,,,,,,,,,,,,,, -42700,0.8857321,4.386162,,,,,,,,,,,,,, -42800,0.97974414,3.7353792,,,,,,,,,,,,,, -42900,0.9997124,5.4020753,,,,,,,,,,,,,, -43000,1.0119361,3.588066,,,,,,,,,,,,,, -43100,0.9941302,3.5793583,,,,,,,,,,,,,, -43200,0.87878674,4.7041807,,,,,,,,,,,,,, -43300,0.92174816,5.356333,,,,,,,,,,,,,, -43400,1.0113894,3.6733274,,,,,,,,,,,,,, -43500,0.9538141,5.3353395,,,,,,,,,,,,,, -43520,,,0.7005859017372131,1.4200528860092163,0.6437000036239624,1.678924322128296,50000.0,0.5212000012397766,2.3035788536071777,10000.0,19782.83354473114,21409.87931132317,19782.83354473114,1623.3567078113556,1.3867592811584473,0.0 -43600,0.9580472,3.6235666,,,,,,,,,,,,,, -43700,1.1108235,3.6152215,,,,,,,,,,,,,, -43800,0.8497071,5.1198773,,,,,,,,,,,,,, -43900,1.1253871,3.7366743,,,,,,,,,,,,,, -44000,1.0453336,3.6783001,,,,,,,,,,,,,, -44100,1.011539,3.850791,,,,,,,,,,,,,, -44200,0.9333136,3.9314632,,,,,,,,,,,,,, -44300,0.92740995,4.033474,,,,,,,,,,,,,, -44400,1.1001616,3.7278075,,,,,,,,,,,,,, -44447,,,0.7023046612739563,1.501721739768982,0.6417799592018127,1.76237154006958,50000.0,0.5196000337600708,2.387691020965576,10000.0,20203.08823728561,21863.43239045143,20203.08823728561,1656.5718541145325,1.4205570220947266,0.0 -44500,1.154919,3.6625237,,,,,,,,,,,,,, -44600,0.98180956,3.661238,,,,,,,,,,,,,, -44700,0.96956253,5.3447857,,,,,,,,,,,,,, -44800,1.0185927,5.346094,,,,,,,,,,,,,, -44900,0.8797144,4.79482,,,,,,,,,,,,,, -45000,0.9369408,4.928205,,,,,,,,,,,,,, -45100,0.9552905,3.8494627,,,,,,,,,,,,,, -45200,1.0524403,3.5464668,,,,,,,,,,,,,, -45300,1.012413,4.8003683,,,,,,,,,,,,,, -45375,,,0.6984961032867432,1.481454610824585,0.6467799544334412,1.7090678215026855,50000.0,0.5195000171661377,2.3472135066986084,10000.0,20623.02325534821,22318.31251120568,20623.02325534821,1691.4364099502563,1.452678680419922,0.0 -45400,1.2432559,3.6962004,,,,,,,,,,,,,, -45500,0.8893619,4.8916216,,,,,,,,,,,,,, -45600,0.9761006,3.9923687,,,,,,,,,,,,,, -45700,1.0304142,3.6075866,,,,,,,,,,,,,, -45800,1.0133126,3.604486,,,,,,,,,,,,,, -45900,0.95086765,4.1480603,,,,,,,,,,,,,, -46000,0.98561466,3.5364625,,,,,,,,,,,,,, -46100,1.1061438,3.5718763,,,,,,,,,,,,,, -46200,0.85060316,4.7237077,,,,,,,,,,,,,, -46300,0.9943372,3.905677,,,,,,,,,,,,,, -46301,,,0.7051757574081421,1.4099836349487305,0.6536799669265747,1.652795433998108,50000.0,0.5258000493049622,2.281585216522217,10000.0,21043.11750602722,22775.022524118423,21043.11750602722,1727.9724340438845,1.4842946529388428,0.0 -46400,0.9879013,3.9999352,,,,,,,,,,,,,, -46500,0.9270246,5.4544535,,,,,,,,,,,,,, -46600,1.0138023,5.2653747,,,,,,,,,,,,,, -46700,0.9368387,5.1546555,,,,,,,,,,,,,, -46800,1.0479752,3.5240486,,,,,,,,,,,,,, -46900,1.0978614,3.6052542,,,,,,,,,,,,,, -47000,1.1088651,3.5479188,,,,,,,,,,,,,, -47100,1.0294663,3.5663004,,,,,,,,,,,,,, -47200,0.99417484,3.591897,,,,,,,,,,,,,, -47228,,,0.7090820074081421,1.4518290758132937,0.6502000093460083,1.708951473236084,50000.0,0.5272000432014465,2.330073356628418,10000.0,21463.19794869423,23228.048787355423,21463.19794869423,1760.8348679542542,1.5186994075775146,0.0 -47300,1.0661429,3.721801,,,,,,,,,,,,,, -47400,0.99379396,3.4508,,,,,,,,,,,,,, -47500,1.0739033,3.5124025,,,,,,,,,,,,,, -47600,1.1365958,3.546765,,,,,,,,,,,,,, -47700,1.0554893,3.534822,,,,,,,,,,,,,, -47800,0.98984694,5.401817,,,,,,,,,,,,,, -47900,0.90009147,4.032447,,,,,,,,,,,,,, -48000,0.9176784,5.293222,,,,,,,,,,,,,, -48100,1.0979381,3.4882042,,,,,,,,,,,,,, -48156,,,0.7205273509025574,1.3598943948745728,0.6526199579238892,1.6453591585159302,50000.0,0.5291000008583069,2.2683568000793457,10000.0,21883.307546377186,23681.82369303704,21883.307546377186,1794.4216213226318,1.5487060546875,0.0 -48200,0.93450856,4.303299,,,,,,,,,,,,,, -48300,1.0337292,3.500953,,,,,,,,,,,,,, -48400,0.981686,3.7687352,,,,,,,,,,,,,, -48500,1.0422388,3.7231565,,,,,,,,,,,,,, -48600,1.0431288,3.6779723,,,,,,,,,,,,,, -48700,1.0026882,3.5600142,,,,,,,,,,,,,, -48800,0.99268836,3.9488325,,,,,,,,,,,,,, -48900,1.0041373,4.2676206,,,,,,,,,,,,,, -49000,0.90373915,4.76451,,,,,,,,,,,,,, -49084,,,0.7106054425239563,1.384338617324829,0.6567999720573425,1.621684432029724,50000.0,0.5312000513076782,2.238567352294922,10000.0,22303.582375764847,24135.841794013977,22303.582375764847,1828.082728624344,1.5821101665496826,0.0 -49100,1.0782002,5.205789,,,,,,,,,,,,,, -49200,1.0655675,3.6304946,,,,,,,,,,,,,, -49300,1.0782304,3.6765518,,,,,,,,,,,,,, -49400,1.0337373,3.5548637,,,,,,,,,,,,,, -49500,1.0150543,3.5367174,,,,,,,,,,,,,, -49600,1.0314721,3.6036115,,,,,,,,,,,,,, -49700,0.95470923,5.247381,,,,,,,,,,,,,, -49800,1.0236449,3.6899376,,,,,,,,,,,,,, -49900,1.0258664,4.8956695,,,,,,,,,,,,,, -50000,1.0515388,5.3637767,,,,,,,,,,,,,, -50009,,,0.7145702838897705,1.409250259399414,0.6565799713134766,1.664278507232666,50000.0,0.5333000421524048,2.27439832687378,10000.0,22723.54018163681,24589.507689237595,22723.54018163681,1861.705460071564,1.618788242340088,0.0 -50100,1.018249,3.4924574,,,,,,,,,,,,,, -50200,0.98787177,4.471263,,,,,,,,,,,,,, -50300,1.0322553,3.5940962,,,,,,,,,,,,,, -50400,1.0563252,5.392688,,,,,,,,,,,,,, -50500,1.0395948,3.4916365,,,,,,,,,,,,,, -50600,1.1070052,3.5597737,,,,,,,,,,,,,, -50700,1.0194038,3.7764368,,,,,,,,,,,,,, -50800,1.087234,3.5386415,,,,,,,,,,,,,, -50900,0.96325034,5.3085093,,,,,,,,,,,,,, -50934,,,0.727734386920929,1.318716526031494,0.6599999666213989,1.6194846630096436,50000.0,0.5382000207901001,2.2213504314422607,10000.0,23143.67327833176,25044.706683397293,23143.67327833176,1896.688829660416,1.6524369716644287,0.0 -51000,1.1450317,3.8268926,,,,,,,,,,,,,, -51100,1.0300528,3.7710955,,,,,,,,,,,,,, -51200,1.0574899,3.5498056,,,,,,,,,,,,,, -51300,1.0186421,3.9898443,,,,,,,,,,,,,, -51400,1.0824804,3.526258,,,,,,,,,,,,,, -51500,1.1173637,3.6441472,,,,,,,,,,,,,, -51600,1.0979084,3.6059914,,,,,,,,,,,,,, -51700,1.0185035,3.924263,,,,,,,,,,,,,, -51800,0.98149866,5.3002653,,,,,,,,,,,,,, -51860,,,0.7147851586341858,1.3640855550765991,0.6587399840354919,1.61241352558136,50000.0,0.5361000299453735,2.21683406829834,10000.0,23563.626941919327,25499.76881289482,23563.626941919327,1931.7127187252045,1.6885082721710205,0.0 -51900,0.9943883,4.9539814,,,,,,,,,,,,,, -52000,1.0078298,5.3410997,,,,,,,,,,,,,, -52100,1.1268497,3.522505,,,,,,,,,,,,,, -52200,1.0928944,3.5221097,,,,,,,,,,,,,, -52300,1.105151,3.4890182,,,,,,,,,,,,,, -52400,0.94669205,3.946547,,,,,,,,,,,,,, -52500,1.1447355,3.5096562,,,,,,,,,,,,,, -52600,0.9436217,4.361619,,,,,,,,,,,,,, -52700,1.0348465,5.273991,,,,,,,,,,,,,, -52789,,,0.7134960889816284,1.4122158288955688,0.6552799940109253,1.6602015495300293,50000.0,0.526900053024292,2.2857484817504883,10000.0,23983.90749502182,25954.107449531555,23983.90749502182,1965.6869568824768,1.7237651348114014,0.0 -52800,1.2010124,3.5940094,,,,,,,,,,,,,, -52900,1.0753905,3.3981445,,,,,,,,,,,,,, -53000,1.0354296,3.5405076,,,,,,,,,,,,,, -53100,1.1059718,3.5680285,,,,,,,,,,,,,, -53200,1.0697544,3.5933108,,,,,,,,,,,,,, -53300,1.1805476,3.5799484,,,,,,,,,,,,,, -53400,1.1387501,3.5506806,,,,,,,,,,,,,, -53500,1.0826386,5.2997904,,,,,,,,,,,,,, -53600,0.9872415,5.115101,,,,,,,,,,,,,, -53700,1.051745,3.511382,,,,,,,,,,,,,, -53718,,,0.7314453125,1.304121494293213,0.6644399762153625,1.601436972618103,50000.0,0.5385000109672546,2.231038808822632,10000.0,24404.197466611862,26407.075337409973,24404.197466611862,1998.2732956409448,1.7659268379211426,0.0 -53800,0.9758873,3.9427981,,,,,,,,,,,,,, -53900,1.0012672,4.1399593,,,,,,,,,,,,,, -54000,1.0635815,3.4702063,,,,,,,,,,,,,, -54100,1.1972566,3.5190911,,,,,,,,,,,,,, -54200,1.0970463,3.483344,,,,,,,,,,,,,, -54300,1.0484079,5.1398354,,,,,,,,,,,,,, -54400,1.0145525,4.597485,,,,,,,,,,,,,, -54500,0.94900835,4.9705267,,,,,,,,,,,,,, -54600,1.0748854,3.6402383,,,,,,,,,,,,,, -54647,,,0.71533203125,1.3832980394363403,0.6642999649047852,1.607182502746582,50000.0,0.5401000380516052,2.210639476776123,10000.0,24824.45699334145,26861.518835544583,24824.45699334145,2032.3752472400663,1.7985684871673584,0.0 -54700,0.95383376,4.7562895,,,,,,,,,,,,,, -54800,1.0183586,3.8033068,,,,,,,,,,,,,, -54900,1.029369,3.754346,,,,,,,,,,,,,, -55000,1.1689074,3.5356398,,,,,,,,,,,,,, -55100,1.0598962,3.5118122,,,,,,,,,,,,,, -55200,1.0504007,5.2936125,,,,,,,,,,,,,, -55300,1.1526527,3.5736876,,,,,,,,,,,,,, -55400,0.9236196,4.0017023,,,,,,,,,,,,,, -55500,0.9950278,4.922753,,,,,,,,,,,,,, -55575,,,0.7239453196525574,1.3593758344650269,0.6665999889373779,1.6094951629638672,50000.0,0.5408000349998474,2.2293291091918945,10000.0,25244.383882761,27314.31769967079,25244.383882761,2065.166063785553,1.8309228420257568,0.0 -55600,1.2416337,3.6398876,,,,,,,,,,,,,, -55700,1.0989915,3.5347655,,,,,,,,,,,,,, -55800,1.0386926,3.8960574,,,,,,,,,,,,,, -55900,1.1492496,3.5593872,,,,,,,,,,,,,, -56000,1.1008664,3.7724252,,,,,,,,,,,,,, -56100,1.0074998,3.6315584,,,,,,,,,,,,,, -56200,1.0889908,3.4948745,,,,,,,,,,,,,, -56300,1.0730454,3.4357698,,,,,,,,,,,,,, -56400,1.1249497,3.483361,,,,,,,,,,,,,, -56500,1.14982,3.5289128,,,,,,,,,,,,,, -56501,,,0.7246484160423279,1.3352673053741455,0.6661999821662903,1.6007659435272217,50000.0,0.5393000245094299,2.2219364643096924,10000.0,25665.15626358986,27768.345024824142,25665.15626358986,2098.33789563179,1.864485502243042,0.0 -56600,1.0377907,3.5533233,,,,,,,,,,,,,, -56700,1.0918438,4.9294195,,,,,,,,,,,,,, -56800,1.1905879,3.5844254,,,,,,,,,,,,,, -56900,1.0308824,3.4505534,,,,,,,,,,,,,, -57000,1.1364301,3.486038,,,,,,,,,,,,,, -57100,1.2283133,3.5166552,,,,,,,,,,,,,, -57200,1.10928,3.4190507,,,,,,,,,,,,,, -57300,1.0543127,3.4181046,,,,,,,,,,,,,, -57400,1.0084635,3.7802124,,,,,,,,,,,,,, -57431,,,0.7429101467132568,1.2352076768875122,0.6694999933242798,1.5599645376205444,50000.0,0.5423000454902649,2.192864894866944,10000.0,26085.455446720123,28223.11777973175,26085.455446720123,2132.7292597293854,1.8969478607177728,0.0 -57500,0.9579452,4.5636473,,,,,,,,,,,,,, -57600,1.0653943,3.5170724,,,,,,,,,,,,,, -57700,0.9985331,3.7482107,,,,,,,,,,,,,, -57800,1.0365402,3.4977121,,,,,,,,,,,,,, -57900,1.0194001,3.571296,,,,,,,,,,,,,, -58000,1.1026682,5.2570276,,,,,,,,,,,,,, -58100,1.0327713,5.246602,,,,,,,,,,,,,, -58200,1.086733,3.4738376,,,,,,,,,,,,,, -58300,0.9758465,4.3607183,,,,,,,,,,,,,, -58358,,,0.7205859422683716,1.346676468849182,0.6641799807548523,1.5929832458496094,50000.0,0.5409000515937805,2.201468706130981,10000.0,26505.45462369919,28676.99065828324,26505.45462369919,2166.513976097107,1.9372212886810305,0.0 -58400,1.1414294,3.4893634,,,,,,,,,,,,,, -58500,1.0876399,3.547262,,,,,,,,,,,,,, -58600,1.0874192,3.6080875,,,,,,,,,,,,,, -58700,1.0272484,5.256359,,,,,,,,,,,,,, -58800,1.1357695,3.4540803,,,,,,,,,,,,,, -58900,0.9844383,4.668709,,,,,,,,,,,,,, -59000,0.95068866,4.5883284,,,,,,,,,,,,,, -59100,1.0716968,4.8976,,,,,,,,,,,,,, -59200,1.078924,3.643356,,,,,,,,,,,,,, -59285,,,0.7235937118530273,1.3309944868087769,0.6640599966049194,1.594617247581482,50000.0,0.5421000123023987,2.206836462020874,10000.0,26925.466319322582,29131.28421139717,26925.466319322582,2200.711050987244,1.973011493682861,0.0 -59300,1.1410234,3.5247307,,,,,,,,,,,,,, -59400,1.1206341,3.4817762,,,,,,,,,,,,,, -59500,1.0546998,4.798649,,,,,,,,,,,,,, -59600,1.1879773,3.4916625,,,,,,,,,,,,,, -59700,1.0271826,3.8221097,,,,,,,,,,,,,, -59800,1.038221,4.889551,,,,,,,,,,,,,, -59900,1.1091557,3.5049856,,,,,,,,,,,,,, -60000,1.0324973,3.8048167,,,,,,,,,,,,,, -60100,0.9748196,4.164326,,,,,,,,,,,,,, -60200,1.0050197,3.9001288,,,,,,,,,,,,,, -60211,,,0.7406054735183716,1.2563652992248535,0.6726999878883362,1.5577574968338013,50000.0,0.547700047492981,2.189788341522217,10000.0,27345.63130736351,29584.98865461349,27345.63130736351,2234.163626194,2.0108447074890137,0.0 -60300,1.0745484,3.4522305,,,,,,,,,,,,,, -60400,1.0248286,4.076855,,,,,,,,,,,,,, -60500,1.1231616,3.5741994,,,,,,,,,,,,,, -60600,1.0411211,4.9091806,,,,,,,,,,,,,, -60700,1.1480013,3.4707603,,,,,,,,,,,,,, -60800,1.0252826,4.5564833,,,,,,,,,,,,,, -60900,1.0997852,3.7363253,,,,,,,,,,,,,, -61000,1.0683802,3.523514,,,,,,,,,,,,,, -61100,1.1119194,3.6151474,,,,,,,,,,,,,, -61138,,,0.729296863079071,1.288324236869812,0.6774599552154541,1.5246641635894775,50000.0,0.5457000136375427,2.157984495162964,10000.0,27765.606746673584,30037.10212635994,27765.606746673584,2266.2198588848114,2.044236660003662,0.0 -61200,1.1118387,3.4020147,,,,,,,,,,,,,, -61300,1.0375504,3.794917,,,,,,,,,,,,,, -61400,1.0524288,3.5944579,,,,,,,,,,,,,, -61500,1.0394163,4.683045,,,,,,,,,,,,,, -61600,1.1344637,3.4792016,,,,,,,,,,,,,, -61700,1.1819493,3.5096464,,,,,,,,,,,,,, -61800,1.042159,3.5662162,,,,,,,,,,,,,, -61900,0.9413339,4.1449504,,,,,,,,,,,,,, -62000,1.067962,3.4969244,,,,,,,,,,,,,, -62061,,,0.7343164086341858,1.2893426418304443,0.6753199696540833,1.5548889636993408,50000.0,0.5463000535964966,2.169198751449585,10000.0,28185.229808330536,30493.144993782043,28185.229808330536,2302.1697578430176,2.4654579162597656,0.0 -62100,0.99523336,3.906529,,,,,,,,,,,,,, -62200,1.0452473,3.5156205,,,,,,,,,,,,,, -62300,1.0777647,3.464972,,,,,,,,,,,,,, -62400,1.116918,3.5042958,,,,,,,,,,,,,, -62500,1.0874486,5.066327,,,,,,,,,,,,,, -62600,1.0108792,3.6592019,,,,,,,,,,,,,, -62700,1.0029023,3.6713576,,,,,,,,,,,,,, -62800,0.9660513,4.4137306,,,,,,,,,,,,,, -62900,1.1210475,3.5381382,,,,,,,,,,,,,, -62990,,,0.7407812476158142,1.2475316524505615,0.6754399538040161,1.5307263135910034,50000.0,0.5506000518798828,2.156816005706787,10000.0,28605.168552160263,30950.45087170601,28605.168552160263,2339.451201438904,2.5024900436401367,0.0 -63000,0.9900257,4.2504363,,,,,,,,,,,,,, -63100,1.1057516,3.3586912,,,,,,,,,,,,,, -63200,1.1020396,5.1969776,,,,,,,,,,,,,, -63300,1.1416028,5.143928,,,,,,,,,,,,,, -63400,1.0622137,3.6425796,,,,,,,,,,,,,, -63500,1.0399057,3.6107206,,,,,,,,,,,,,, -63600,1.0251414,3.7607932,,,,,,,,,,,,,, -63700,1.1398848,3.5983708,,,,,,,,,,,,,, -63800,1.1380628,3.4538934,,,,,,,,,,,,,, -63900,1.047993,3.8628025,,,,,,,,,,,,,, -63916,,,0.7317968606948853,1.2713884115219116,0.6771399974822998,1.5070552825927734,50000.0,0.5541000366210938,2.133321046829224,10000.0,29025.421385526657,31407.895943164825,29025.421385526657,2376.555834293365,2.5417940616607666,0.0 -64000,1.1150599,5.0658827,,,,,,,,,,,,,, -64100,1.0500005,5.1453238,,,,,,,,,,,,,, -64200,1.0579096,4.2770095,,,,,,,,,,,,,, -64300,1.1628412,4.98546,,,,,,,,,,,,,, -64400,1.0498406,3.5123553,,,,,,,,,,,,,, -64500,1.088701,3.411773,,,,,,,,,,,,,, -64600,1.1876713,3.5261092,,,,,,,,,,,,,, -64700,1.0870609,3.4807987,,,,,,,,,,,,,, -64800,1.3559253,3.5241072,,,,,,,,,,,,,, -64842,,,0.7364453077316284,1.2684588432312012,0.6774799823760986,1.5218348503112793,50000.0,0.5561000108718872,2.1238458156585693,10000.0,29445.614727020264,31861.867134332657,29445.614727020264,2410.2451598644257,2.581367254257202,0.0 -64900,1.1542126,5.138047,,,,,,,,,,,,,, -65000,1.0539153,5.169981,,,,,,,,,,,,,, -65100,1.0476049,5.152339,,,,,,,,,,,,,, -65200,1.0862622,5.2301946,,,,,,,,,,,,,, -65300,1.1337347,3.4797752,,,,,,,,,,,,,, -65400,1.190487,3.350044,,,,,,,,,,,,,, -65500,1.1116751,3.4292502,,,,,,,,,,,,,, -65600,1.1014649,5.151114,,,,,,,,,,,,,, -65700,1.1950734,3.4168434,,,,,,,,,,,,,, -65767,,,0.7422069907188416,1.2749184370040894,0.6781600117683411,1.548775553703308,50000.0,0.5523000359535217,2.1705105304718018,10000.0,29865.653613567352,32315.68285059929,29865.653613567352,2443.933711528778,2.621212720870972,0.0 -65800,1.1057682,4.0739036,,,,,,,,,,,,,, -65900,1.0833194,5.1324625,,,,,,,,,,,,,, -66000,1.1273183,3.364628,,,,,,,,,,,,,, -66100,1.1207488,3.4298036,,,,,,,,,,,,,, -66200,1.1339073,3.7068684,,,,,,,,,,,,,, -66300,1.0180583,3.6717606,,,,,,,,,,,,,, -66400,1.0425538,3.9610395,,,,,,,,,,,,,, -66500,1.1084374,4.9647193,,,,,,,,,,,,,, -66600,1.1401093,3.344108,,,,,,,,,,,,,, -66693,,,0.758984386920929,1.163934350013733,0.6821199655532837,1.4920015335083008,50000.0,0.5541000366210938,2.1089136600494385,10000.0,30285.574808120728,32769.177268743515,30285.574808120728,2477.4188113212585,2.660482168197632,0.0 -66700,1.2070986,3.4574592,,,,,,,,,,,,,, -66800,1.1459582,3.3977957,,,,,,,,,,,,,, -66900,1.0818938,4.9023705,,,,,,,,,,,,,, -67000,1.1631763,5.162732,,,,,,,,,,,,,, -67100,1.0859178,4.4597263,,,,,,,,,,,,,, -67200,1.2844901,5.216612,,,,,,,,,,,,,, -67300,1.1831902,3.4169207,,,,,,,,,,,,,, -67400,1.0652229,3.7915277,,,,,,,,,,,,,, -67500,1.0541942,4.1279135,,,,,,,,,,,,,, -67600,1.0482485,4.217339,,,,,,,,,,,,,, -67619,,,0.7391406297683716,1.2571861743927002,0.6822999715805054,1.5116750001907349,50000.0,0.551300048828125,2.1411116123199463,10000.0,30705.76295566559,33226.44407486916,30705.76295566559,2514.413983345032,2.6944541931152344,0.0 -67700,1.1116642,4.291999,,,,,,,,,,,,,, -67800,1.1128038,3.404828,,,,,,,,,,,,,, -67900,1.069212,3.5247421,,,,,,,,,,,,,, -68000,1.0659198,4.0376487,,,,,,,,,,,,,, -68100,1.1033432,4.9973993,,,,,,,,,,,,,, -68200,1.1404327,3.4276347,,,,,,,,,,,,,, -68300,1.1645436,3.6474738,,,,,,,,,,,,,, -68400,1.0527626,4.2292323,,,,,,,,,,,,,, -68500,1.0076768,4.9380035,,,,,,,,,,,,,, -68545,,,0.74609375,1.2384456396102903,0.683459997177124,1.5109264850616455,50000.0,0.558899998664856,2.133676052093506,10000.0,31125.82836127281,33680.55595970154,31125.82836127281,2548.3767414093018,2.7300000190734863,0.0 -68600,1.2328901,3.4672484,,,,,,,,,,,,,, -68700,1.2577311,5.124217,,,,,,,,,,,,,, -68800,1.0659647,3.3139548,,,,,,,,,,,,,, -68900,1.1558741,3.3517275,,,,,,,,,,,,,, -69000,1.1497622,3.3752303,,,,,,,,,,,,,, -69100,1.0318723,3.926727,,,,,,,,,,,,,, -69200,1.0698397,3.387788,,,,,,,,,,,,,, -69300,1.0531185,4.7758765,,,,,,,,,,,,,, -69400,1.1515716,5.1388884,,,,,,,,,,,,,, -69471,,,0.755175769329071,1.170236349105835,0.684939980506897,1.482764482498169,50000.0,0.5611000061035156,2.103188991546631,10000.0,31545.97432255745,34133.97758722305,31545.97432255745,2581.5695893764496,2.764475584030152,0.0 -69500,1.0408221,3.3490567,,,,,,,,,,,,,, -69600,1.1000386,3.347468,,,,,,,,,,,,,, -69700,1.1394329,3.278398,,,,,,,,,,,,,, -69800,1.1127573,4.7275352,,,,,,,,,,,,,, -69900,1.0479617,4.850443,,,,,,,,,,,,,, -70000,1.329598,3.4917336,,,,,,,,,,,,,, -70100,1.1382691,3.927823,,,,,,,,,,,,,, -70200,1.0920507,3.664833,,,,,,,,,,,,,, -70300,1.1283389,3.3492217,,,,,,,,,,,,,, -70398,,,0.7429296970367432,1.2822554111480713,0.6846599578857422,1.5301166772842407,50000.0,0.5548000335693359,2.153522253036499,10000.0,31966.08284807205,34589.028621673584,31966.08284807205,2616.4249787330627,2.803164482116699,0.0 -70400,1.2650551,3.3864865,,,,,,,,,,,,,, -70500,1.1314994,3.4240897,,,,,,,,,,,,,, -70600,1.0955927,3.3794992,,,,,,,,,,,,,, -70700,1.1400608,3.814345,,,,,,,,,,,,,, -70800,1.0949762,4.246401,,,,,,,,,,,,,, -70900,1.1179966,4.5529447,,,,,,,,,,,,,, -71000,1.1615449,4.9882927,,,,,,,,,,,,,, -71100,1.101839,4.418654,,,,,,,,,,,,,, -71200,1.1422356,3.434606,,,,,,,,,,,,,, -71300,1.075457,3.7018404,,,,,,,,,,,,,, -71325,,,0.7483007907867432,1.252873182296753,0.6881600022315979,1.5203845500946045,50000.0,0.5677000284194946,2.128825664520264,10000.0,32385.999056339264,35042.29918694496,32385.999056339264,2649.695028066635,2.83899450302124,0.0 -71400,1.0629461,4.715215,,,,,,,,,,,,,, -71500,1.1080257,3.4598644,,,,,,,,,,,,,, -71600,1.2256999,3.356328,,,,,,,,,,,,,, -71700,1.1237262,3.6995785,,,,,,,,,,,,,, -71800,1.2624632,5.1768517,,,,,,,,,,,,,, -71900,1.1602366,3.4006503,,,,,,,,,,,,,, -72000,1.1614747,4.906856,,,,,,,,,,,,,, -72100,1.0690526,3.6050975,,,,,,,,,,,,,, -72200,1.1566824,3.6092155,,,,,,,,,,,,,, -72251,,,0.7540820240974426,1.2140352725982666,0.6817799806594849,1.519381046295166,50000.0,0.5610000491142273,2.117690086364746,10000.0,32805.99316358566,35496.07172703743,32805.99316358566,2683.389586210251,2.873255014419556,0.0 -72300,1.0701144,4.8896513,,,,,,,,,,,,,, -72400,1.2623407,3.4259157,,,,,,,,,,,,,, -72500,1.0267907,4.0852823,,,,,,,,,,,,,, -72600,1.1533014,5.007082,,,,,,,,,,,,,, -72700,1.2316352,4.8159084,,,,,,,,,,,,,, -72800,1.1828825,4.9638834,,,,,,,,,,,,,, -72900,1.2422872,3.442867,,,,,,,,,,,,,, -73000,1.1072938,4.807455,,,,,,,,,,,,,, -73100,1.2361462,3.3575068,,,,,,,,,,,,,, -73176,,,0.7476171851158142,1.2500646114349363,0.6918999552726746,1.4931375980377195,50000.0,0.5640000104904175,2.119715929031372,10000.0,33226.05568480492,35948.66935944557,33226.05568480492,2715.83868432045,2.911292314529419,0.0 -73200,1.197566,3.4247139,,,,,,,,,,,,,, -73300,1.1868743,3.4495473,,,,,,,,,,,,,, -73400,1.0270312,4.5595903,,,,,,,,,,,,,, -73500,1.1690664,3.345459,,,,,,,,,,,,,, -73600,1.1249424,4.709548,,,,,,,,,,,,,, -73700,1.1271336,4.1211905,,,,,,,,,,,,,, -73800,1.0915843,3.4636035,,,,,,,,,,,,,, -73900,1.2230302,3.3301115,,,,,,,,,,,,,, -74000,1.1280613,3.6782336,,,,,,,,,,,,,, -74100,,,0.7524218559265137,1.1826977729797363,0.6895999908447266,1.460847020149231,50000.0,0.5685000419616699,2.076725721359253,10000.0,33646.12110567093,36402.65527367592,33646.12110567093,2749.675267457962,2.9467294216156006,0.0 -74100,1.2444911,3.3540468,,,,,,,,,,,,,, -74200,1.1496826,3.4277744,,,,,,,,,,,,,, -74300,1.1199327,3.931526,,,,,,,,,,,,,, -74400,1.1767042,3.3740664,,,,,,,,,,,,,, -74500,1.2199918,5.13645,,,,,,,,,,,,,, -74600,1.178741,3.5924776,,,,,,,,,,,,,, -74700,1.0959584,5.1860313,,,,,,,,,,,,,, -74800,1.2743156,3.3954568,,,,,,,,,,,,,, -74900,1.1230807,4.410947,,,,,,,,,,,,,, -75000,1.2031977,3.8807938,,,,,,,,,,,,,, -75027,,,0.7589452862739563,1.1854172945022583,0.6945599913597107,1.469211220741272,50000.0,0.5649999976158142,2.099517107009888,10000.0,34066.15981054306,36858.05403661728,34066.15981054306,2784.946478128433,2.986032724380493,0.0 -75100,1.1788818,3.424697,,,,,,,,,,,,,, -75200,1.1988142,3.3942738,,,,,,,,,,,,,, -75300,1.162168,3.3587573,,,,,,,,,,,,,, -75400,1.1114973,5.001313,,,,,,,,,,,,,, -75500,1.1482975,3.3914723,,,,,,,,,,,,,, -75600,1.2168504,3.618608,,,,,,,,,,,,,, -75700,1.2030972,3.3612366,,,,,,,,,,,,,, -75800,1.164466,3.741089,,,,,,,,,,,,,, -75900,1.2390882,3.4049284,,,,,,,,,,,,,, -75953,,,0.7703906297683716,1.1538687944412231,0.6911799907684326,1.488873839378357,50000.0,0.5619000196456909,2.1018519401550293,10000.0,34486.180584430695,37311.97987627983,34486.180584430695,2818.763783454895,3.024766683578491,0.0 -76000,1.1925908,3.3292754,,,,,,,,,,,,,, -76100,1.1204582,3.45435,,,,,,,,,,,,,, -76200,1.1621077,4.970871,,,,,,,,,,,,,, -76300,1.1063405,3.4684389,,,,,,,,,,,,,, -76400,1.2224723,3.362515,,,,,,,,,,,,,, -76500,1.1634284,3.6056044,,,,,,,,,,,,,, -76600,1.2539862,5.129879,,,,,,,,,,,,,, -76700,1.2050104,3.411755,,,,,,,,,,,,,, -76800,1.0453699,3.8251746,,,,,,,,,,,,,, -76878,,,0.7526757717132568,1.1900469064712524,0.6944400072097778,1.4452449083328247,50000.0,0.5715000033378601,2.0509159564971924,10000.0,34906.22839021683,37763.84492731094,34906.22839021683,2850.4957184791565,3.0609076023101807,0.0 -76900,1.241409,3.3729908,,,,,,,,,,,,,, -77000,1.1840847,3.402648,,,,,,,,,,,,,, -77100,1.1261566,3.6011193,,,,,,,,,,,,,, -77200,1.2383188,5.1350584,,,,,,,,,,,,,, -77300,1.0682136,4.2659554,,,,,,,,,,,,,, -77400,1.2183024,3.3881047,,,,,,,,,,,,,, -77500,1.0780584,4.141063,,,,,,,,,,,,,, -77600,1.2136208,3.3944557,,,,,,,,,,,,,, -77700,1.2477816,3.5608227,,,,,,,,,,,,,, -77800,1.3963857,3.3713987,,,,,,,,,,,,,, -77804,,,0.7564452886581421,1.186036467552185,0.6922799944877625,1.4668896198272705,50000.0,0.5711000561714172,2.069314956665039,10000.0,35326.41598343849,38219.564494133,35326.41598343849,2885.9357640743256,3.1037163734436035,0.0 -77900,1.2281327,3.4069679,,,,,,,,,,,,,, -78000,1.2066442,4.3708367,,,,,,,,,,,,,, -78100,1.1686071,3.4317033,,,,,,,,,,,,,, -78200,1.2555901,3.3886693,,,,,,,,,,,,,, -78300,1.1130813,3.4456694,,,,,,,,,,,,,, -78400,1.2133743,4.659067,,,,,,,,,,,,,, -78500,1.0739595,3.3893023,,,,,,,,,,,,,, -78600,1.2430242,3.3246748,,,,,,,,,,,,,, -78700,1.1815639,3.3031294,,,,,,,,,,,,,, -78733,,,0.7732617259025574,1.1194065809249878,0.6947199702262878,1.444237470626831,50000.0,0.5674000382423401,2.0553789138793945,10000.0,35746.742753982544,38673.86698675156,35746.742753982544,2919.8212456703186,3.144604682922364,0.0 -78800,1.1250407,3.6409895,,,,,,,,,,,,,, -78900,1.2065512,4.1746716,,,,,,,,,,,,,, -79000,1.2242891,3.4731069,,,,,,,,,,,,,, -79100,1.2070227,3.306381,,,,,,,,,,,,,, -79200,1.2243927,4.405903,,,,,,,,,,,,,, -79300,1.2908597,3.4003134,,,,,,,,,,,,,, -79400,1.1843969,3.2777753,,,,,,,,,,,,,, -79500,1.2656652,3.412695,,,,,,,,,,,,,, -79600,1.1654843,3.5547528,,,,,,,,,,,,,, -79660,,,0.7591992020606995,1.1473793983459473,0.6970199942588806,1.4147883653640747,50000.0,0.5734000205993652,2.0305778980255127,10000.0,36166.86061668396,39125.99325990677,36166.86061668396,2951.7412304878235,3.184016704559326,0.0 -79700,1.1064891,3.998098,,,,,,,,,,,,,, -79800,1.4438422,4.7367163,,,,,,,,,,,,,, -79900,1.1179688,4.229555,,,,,,,,,,,,,, -80000,1.1889832,3.609089,,,,,,,,,,,,,, -80100,1.2794776,3.385085,,,,,,,,,,,,,, -80200,1.2451339,3.3148527,,,,,,,,,,,,,, -80300,1.1723268,3.6645954,,,,,,,,,,,,,, -80400,1.0676177,4.078266,,,,,,,,,,,,,, -80500,1.4076631,3.3156135,,,,,,,,,,,,,, -80584,,,0.7634375095367432,1.142561674118042,0.7002399563789368,1.4214868545532229,50000.0,0.5719000101089478,2.041655302047729,10000.0,36586.801466465,39576.94997668266,36586.801466465,2982.663145303726,3.2296504974365234,0.0 -80600,1.3267655,5.058346,,,,,,,,,,,,,, -80700,1.2187431,3.3667004,,,,,,,,,,,,,, -80800,1.1621542,3.503415,,,,,,,,,,,,,, -80900,1.1593015,3.618633,,,,,,,,,,,,,, -81000,1.2486224,3.6843138,,,,,,,,,,,,,, -81100,1.3646619,3.356144,,,,,,,,,,,,,, -81200,1.181549,4.575342,,,,,,,,,,,,,, -81300,1.1536244,4.6632524,,,,,,,,,,,,,, -81400,1.4927459,5.075599,,,,,,,,,,,,,, -81500,1.0386066,4.0437536,,,,,,,,,,,,,, -81508,,,0.7697460651397705,1.1436564922332764,0.6994400024414062,1.451212763786316,50000.0,0.5731000304222107,2.0733675956726074,10000.0,37006.7823369503,40031.986365795135,37006.7823369503,3017.626652240753,3.2731289863586426,0.0 -81600,1.1377655,3.9361544,,,,,,,,,,,,,, -81700,1.1118207,4.285017,,,,,,,,,,,,,, -81800,1.1962278,3.38468,,,,,,,,,,,,,, -81900,1.2387936,3.3137932,,,,,,,,,,,,,, -82000,1.1775037,3.279214,,,,,,,,,,,,,, -82100,1.2190979,3.3957257,,,,,,,,,,,,,, -82200,1.1811289,4.0895452,,,,,,,,,,,,,, -82300,1.1747272,3.8907447,,,,,,,,,,,,,, -82400,1.2120111,3.406418,,,,,,,,,,,,,, -82435,,,0.7611523270606995,1.1672053337097168,0.699400007724762,1.4308583736419678,50000.0,0.5703999996185303,2.048683881759644,10000.0,37426.94573545456,40487.77037620544,37426.94573545456,3053.159845352173,3.3112025260925293,0.0 -82500,1.2050428,3.2940466,,,,,,,,,,,,,, -82600,1.2061996,3.7045174,,,,,,,,,,,,,, -82700,1.1199445,4.243548,,,,,,,,,,,,,, -82800,1.0867972,4.2201867,,,,,,,,,,,,,, -82900,1.2172192,4.772333,,,,,,,,,,,,,, -83000,1.2811041,3.3726194,,,,,,,,,,,,,, -83100,1.2472378,4.929949,,,,,,,,,,,,,, -83200,1.264631,3.3862107,,,,,,,,,,,,,, -83300,1.609016,3.3581553,,,,,,,,,,,,,, -83361,,,0.7637695074081421,1.173699975013733,0.6990399956703186,1.4479047060012815,50000.0,0.5679000020027161,2.0731194019317627,10000.0,37846.88027453423,40944.77842760086,37846.88027453423,3090.146003007889,3.34938383102417,0.0 -83400,1.2218302,3.4492736,,,,,,,,,,,,,, -83500,1.2520403,3.2762077,,,,,,,,,,,,,, -83600,1.13909,3.5552897,,,,,,,,,,,,,, -83700,1.225356,3.662704,,,,,,,,,,,,,, -83800,1.331033,3.4681754,,,,,,,,,,,,,, -83900,1.243126,4.708927,,,,,,,,,,,,,, -84000,1.2752436,3.3075817,,,,,,,,,,,,,, -84100,1.1492019,3.3007975,,,,,,,,,,,,,, -84200,1.3516474,3.512801,,,,,,,,,,,,,, -84287,,,0.7743945121765137,1.1085599660873413,0.70305997133255,1.4031741619110107,50000.0,0.5777000188827515,2.0119788646698,10000.0,38267.15163445473,41401.175797224045,38267.15163445473,3126.185777664185,3.386541843414306,0.0 -84300,1.4248416,3.275389,,,,,,,,,,,,,, -84400,1.2671402,3.2751603,,,,,,,,,,,,,, -84500,1.2569754,3.2941206,,,,,,,,,,,,,, -84600,1.364356,3.3710687,,,,,,,,,,,,,, -84700,1.2057482,4.303631,,,,,,,,,,,,,, -84800,1.4428409,3.433719,,,,,,,,,,,,,, -84900,1.2777896,3.242126,,,,,,,,,,,,,, -85000,1.2861208,3.373056,,,,,,,,,,,,,, -85100,1.2341216,3.9989238,,,,,,,,,,,,,, -85200,1.3012421,3.368824,,,,,,,,,,,,,, -85213,,,0.7837694883346558,1.086231708526611,0.7024999856948853,1.429186224937439,50000.0,0.5763000249862671,2.0371885299682617,10000.0,38687.10668325424,41854.26957678795,38687.10668325424,3159.238482236862,3.423867702484131,0.0 -85300,1.2059411,4.649508,,,,,,,,,,,,,, -85400,1.2237476,4.1106405,,,,,,,,,,,,,, -85500,1.2923914,3.5537925,,,,,,,,,,,,,, -85600,1.1765127,3.6792486,,,,,,,,,,,,,, -85700,1.2429318,3.5222163,,,,,,,,,,,,,, -85800,1.2580742,3.4345193,,,,,,,,,,,,,, -85900,1.4127108,4.955584,,,,,,,,,,,,,, -86000,1.3510218,4.986854,,,,,,,,,,,,,, -86100,1.2041199,4.649994,,,,,,,,,,,,,, -86141,,,0.7665038704872131,1.1527868509292605,0.7004599571228027,1.423740267753601,50000.0,0.5753000378608704,2.027710199356079,10000.0,39107.087277412415,42309.15513443947,39107.087277412415,3194.0549857616425,3.4639732837677,0.0 -86200,1.2073534,4.663836,,,,,,,,,,,,,, -86300,1.2993176,3.3431098,,,,,,,,,,,,,, -86400,1.1973853,3.3041523,,,,,,,,,,,,,, -86500,1.3244965,3.4052577,,,,,,,,,,,,,, -86600,1.0933355,3.5342612,,,,,,,,,,,,,, -86700,1.3169551,3.412157,,,,,,,,,,,,,, -86800,1.3090374,3.4355588,,,,,,,,,,,,,, -86900,1.2719989,3.2760022,,,,,,,,,,,,,, -87000,1.5247848,5.1621675,,,,,,,,,,,,,, -87068,,,0.7691406011581421,1.0984034538269043,0.7033999562263489,1.393964767456055,50000.0,0.581000030040741,2.0015461444854736,10000.0,39527.25010895729,42763.15571928024,39527.25010895729,3227.807467699051,3.500558376312256,0.0 -87100,1.1314697,4.195196,,,,,,,,,,,,,, -87200,1.2995502,3.342986,,,,,,,,,,,,,, -87300,1.1637613,3.321199,,,,,,,,,,,,,, -87400,1.1968561,3.2912283,,,,,,,,,,,,,, -87500,1.1414876,4.481322,,,,,,,,,,,,,, -87600,1.1459941,4.0479517,,,,,,,,,,,,,, -87700,1.2910643,3.296591,,,,,,,,,,,,,, -87800,1.3275977,3.2754989,,,,,,,,,,,,,, -87900,1.3575894,4.734151,,,,,,,,,,,,,, -87991,,,0.7862499952316284,1.0665332078933716,0.7097600102424622,1.3905929327011108,50000.0,0.5868000388145447,1.9906413555145264,10000.0,39947.20313882828,43219.547652721405,39947.20313882828,3264.1564960479736,3.542058229446411,0.0 -88000,1.3770727,4.7052813,,,,,,,,,,,,,, -88100,1.3230224,3.5053577,,,,,,,,,,,,,, -88200,1.3469468,3.3495526,,,,,,,,,,,,,, -88300,1.2196841,3.2650151,,,,,,,,,,,,,, -88400,1.2414181,3.4153907,,,,,,,,,,,,,, -88500,1.1952039,3.6217713,,,,,,,,,,,,,, -88600,1.2578728,3.3027027,,,,,,,,,,,,,, -88700,1.3180044,4.389063,,,,,,,,,,,,,, -88800,1.3493662,4.903988,,,,,,,,,,,,,, -88900,1.2509489,3.498323,,,,,,,,,,,,,, -88918,,,0.7696874737739563,1.142638087272644,0.7088800072669983,1.4129151105880735,50000.0,0.5790000557899475,2.0272250175476074,10000.0,40367.46312975884,43673.03454852104,40367.46312975884,3297.296382665634,3.579483985900879,0.0 -89000,1.2676181,3.9422007,,,,,,,,,,,,,, -89100,1.1731839,4.1191163,,,,,,,,,,,,,, -89200,1.2494731,3.2897987,,,,,,,,,,,,,, -89300,1.2889351,3.3122602,,,,,,,,,,,,,, -89400,1.2639765,3.383845,,,,,,,,,,,,,, -89500,1.2672268,3.3605285,,,,,,,,,,,,,, -89600,1.1339724,3.6044145,,,,,,,,,,,,,, -89700,1.2637532,3.2831116,,,,,,,,,,,,,, -89800,1.2056773,3.5887845,,,,,,,,,,,,,, -89848,,,0.7753515243530273,1.1038709878921509,0.7093200087547302,1.3977713584899902,50000.0,0.5851000547409058,2.0033745765686035,10000.0,40787.6904566288,44128.43647813797,40787.6904566288,3332.3801724910736,3.621098756790161,0.0 -89900,1.2461298,3.267349,,,,,,,,,,,,,, -90000,1.3082154,3.2456305,,,,,,,,,,,,,, -90100,1.2786531,3.294753,,,,,,,,,,,,,, -90200,1.3226075,3.2182672,,,,,,,,,,,,,, -90300,1.2708505,3.3248777,,,,,,,,,,,,,, -90400,1.232645,3.513211,,,,,,,,,,,,,, -90500,1.2970088,3.4786546,,,,,,,,,,,,,, -90600,1.2475735,3.7215536,,,,,,,,,,,,,, -90700,1.3634753,3.3699293,,,,,,,,,,,,,, -90774,,,0.781054675579071,1.074851155281067,0.7090799808502197,1.384374976158142,50000.0,0.5878000259399414,1.9843614101409912,10000.0,41207.61959028244,44582.04789733887,41207.61959028244,3365.9759092330933,3.658256053924561,0.0 -90800,1.3122495,3.4036655,,,,,,,,,,,,,, -90900,1.4056269,5.023668,,,,,,,,,,,,,, -91000,1.2689456,3.2975976,,,,,,,,,,,,,, -91100,1.3167237,3.3553157,,,,,,,,,,,,,, -91200,1.3005431,3.6629558,,,,,,,,,,,,,, -91300,1.21021,3.4193075,,,,,,,,,,,,,, -91400,1.4559151,3.3562083,,,,,,,,,,,,,, -91500,1.2118095,4.2457123,,,,,,,,,,,,,, -91600,1.3227351,3.4753218,,,,,,,,,,,,,, -91700,1.2156503,3.737316,,,,,,,,,,,,,, -91701,,,0.77406245470047,1.1004438400268557,0.7123000025749207,1.3684351444244385,50000.0,0.5909000039100647,1.9665168523788448,10000.0,41628.0981926918,45036.28593277931,41628.0981926918,3399.644933462143,3.699693918228149,0.0 -91800,1.5140129,5.054346,,,,,,,,,,,,,, -91900,1.3768382,4.7582436,,,,,,,,,,,,,, -92000,1.2553986,3.19092,,,,,,,,,,,,,, -92100,1.5284317,3.328326,,,,,,,,,,,,,, -92200,1.3365148,3.553738,,,,,,,,,,,,,, -92300,1.3632722,4.8271265,,,,,,,,,,,,,, -92400,1.2935619,3.3395336,,,,,,,,,,,,,, -92500,1.1987853,3.3184001,,,,,,,,,,,,,, -92600,1.2916749,3.7708814,,,,,,,,,,,,,, -92629,,,0.7764062285423279,1.1032389402389526,0.7084800004959106,1.3871572017669678,50000.0,0.5853000283241272,1.97959566116333,10000.0,42048.34095311165,45485.623708724976,42048.34095311165,3428.649636030197,3.740588903427124,0.0 -92700,1.4248734,3.283136,,,,,,,,,,,,,, -92800,1.3089627,3.3204422,,,,,,,,,,,,,, -92900,1.2913562,4.377938,,,,,,,,,,,,,, -93000,1.4604739,3.4024127,,,,,,,,,,,,,, -93100,1.314624,3.7659557,,,,,,,,,,,,,, -93200,1.4365609,3.2761557,,,,,,,,,,,,,, -93300,1.421076,4.981224,,,,,,,,,,,,,, -93400,1.1892413,3.78756,,,,,,,,,,,,,, -93500,1.3214668,3.328636,,,,,,,,,,,,,, -93553,,,0.7830468416213989,1.04946768283844,0.7099199891090393,1.3628854751586914,50000.0,0.5873000025749207,1.9714946746826167,10000.0,42468.40495491028,45944.72555589676,42468.40495491028,3467.593858480453,3.784477949142456,0.0 -93600,1.3324652,3.2959743,,,,,,,,,,,,,, -93700,1.4799339,3.2993567,,,,,,,,,,,,,, -93800,1.21972,4.228961,,,,,,,,,,,,,, -93900,1.2442832,3.767174,,,,,,,,,,,,,, -94000,1.2839711,3.2485247,,,,,,,,,,,,,, -94100,1.2021879,3.1220756,,,,,,,,,,,,,, -94200,1.4415572,4.7326674,,,,,,,,,,,,,, -94300,1.2080225,4.2900763,,,,,,,,,,,,,, -94400,1.3293542,3.208221,,,,,,,,,,,,,, -94479,,,0.8056640625,0.9454649686813354,0.718459963798523,1.3121576309204102,50000.0,0.5933000445365906,1.9169158935546875,10000.0,42888.62809586525,46400.34952402115,42888.62809586525,3502.902225732804,3.828879117965698,0.0 -94500,1.3876246,3.2336957,,,,,,,,,,,,,, -94600,1.340662,4.8513675,,,,,,,,,,,,,, -94700,1.3165945,3.2172542,,,,,,,,,,,,,, -94800,1.3184383,3.9242375,,,,,,,,,,,,,, -94900,1.3868555,4.7821217,,,,,,,,,,,,,, -95000,1.5527773,4.8241963,,,,,,,,,,,,,, -95100,1.3526398,3.3481824,,,,,,,,,,,,,, -95200,1.3353785,3.2095013,,,,,,,,,,,,,, -95300,1.4543988,3.3000987,,,,,,,,,,,,,, -95400,1.2114583,3.8150651,,,,,,,,,,,,,, -95407,,,0.7821093797683716,1.0739256143569946,0.7144799828529358,1.3560128211975098,50000.0,0.5925000309944153,1.9520272016525269,10000.0,43308.81820011139,46854.504509449005,43308.81820011139,3536.771305322647,3.8754472732543945,0.0 -95500,1.3667067,3.285116,,,,,,,,,,,,,, -95600,1.4924654,3.4348507,,,,,,,,,,,,,, -95700,1.2814224,3.40825,,,,,,,,,,,,,, -95800,1.3551003,3.2770238,,,,,,,,,,,,,, -95900,1.3653607,4.898947,,,,,,,,,,,,,, -96000,1.350911,3.1967325,,,,,,,,,,,,,, -96100,1.5090619,4.9278803,,,,,,,,,,,,,, -96200,1.2620386,4.310495,,,,,,,,,,,,,, -96300,1.3053093,3.264132,,,,,,,,,,,,,, -96335,,,0.78955078125,1.03350567817688,0.7170000076293945,1.3376338481903076,50000.0,0.5966000556945801,1.926663517951965,10000.0,43729.01921200752,47310.45342755318,43729.01921200752,3572.430454492569,3.914361476898194,0.0 -96400,1.2938535,3.641836,,,,,,,,,,,,,, -96500,1.4886725,4.754251,,,,,,,,,,,,,, -96600,1.344043,3.2574959,,,,,,,,,,,,,, -96700,1.4620024,4.920784,,,,,,,,,,,,,, -96800,1.3712564,3.351046,,,,,,,,,,,,,, -96900,1.3659154,3.2439828,,,,,,,,,,,,,, -97000,1.2716961,3.2238212,,,,,,,,,,,,,, -97100,1.1945043,3.9366124,,,,,,,,,,,,,, -97200,1.3332417,3.3364675,,,,,,,,,,,,,, -97262,,,0.7940233945846558,1.0544887781143188,0.7132399678230286,1.3924885988235474,50000.0,0.5884000062942505,1.9976340532302856,10000.0,44149.1900267601,47764.96751379967,44149.1900267601,3606.683182001114,3.956416368484497,0.0 -97300,1.3045241,3.1354916,,,,,,,,,,,,,, -97400,1.4297996,3.190577,,,,,,,,,,,,,, -97500,1.225606,3.601527,,,,,,,,,,,,,, -97600,1.2668056,3.29986,,,,,,,,,,,,,, -97700,1.2627684,4.2845984,,,,,,,,,,,,,, -97800,1.2528342,4.2819276,,,,,,,,,,,,,, -97900,1.2980736,4.232853,,,,,,,,,,,,,, -98000,1.5073427,4.8625546,,,,,,,,,,,,,, -98100,1.2937497,3.2488854,,,,,,,,,,,,,, -98188,,,0.783886730670929,1.0592858791351318,0.717519998550415,1.3481186628341677,50000.0,0.5919000506401062,1.9524030685424805,10000.0,44569.44393587112,48217.94920134544,44569.44393587112,3639.32297372818,3.995898723602295,0.0 -98200,1.31236,3.6307018,,,,,,,,,,,,,, -98300,1.2202214,3.8038373,,,,,,,,,,,,,, -98400,1.429753,3.23684,,,,,,,,,,,,,, -98500,1.3568771,4.2439456,,,,,,,,,,,,,, -98600,1.3258575,3.1805987,,,,,,,,,,,,,, -98700,1.3658339,3.2861645,,,,,,,,,,,,,, -98800,1.532829,4.7072473,,,,,,,,,,,,,, -98900,1.3391396,3.2349658,,,,,,,,,,,,,, -99000,1.2780975,3.7296484,,,,,,,,,,,,,, -99100,1.3388457,3.8081956,,,,,,,,,,,,,, -99116,,,0.7895702719688416,1.0405603647232056,0.7199599742889404,1.3382598161697388,50000.0,0.5948000550270081,1.938390731811524,10000.0,44989.47699189186,48673.00974225998,44989.47699189186,3674.262171983719,4.035337209701538,0.0 -99200,1.2476819,3.964488,,,,,,,,,,,,,, -99300,1.2608985,3.979793,,,,,,,,,,,,,, -99400,1.2783185,4.4420733,,,,,,,,,,,,,, -99500,1.3970994,3.2614427,,,,,,,,,,,,,, -99600,1.3094746,3.2018392,,,,,,,,,,,,,, -99700,1.3509461,3.948916,,,,,,,,,,,,,, -99800,1.5813959,4.916288,,,,,,,,,,,,,, -99900,1.4380246,3.258077,,,,,,,,,,,,,, -100000,1.4308829,3.4468436,,,,,,,,,,,,,, -100045,,,0.7934765219688416,1.03852379322052,0.719980001449585,1.356076717376709,50000.0,0.5946000218391418,1.9406942129135127,10000.0,45409.69598340988,49126.67878437042,45409.69598340988,3707.622751235962,4.07489013671875,0.0 -100100,1.3377614,3.232912,,,,,,,,,,,,,, -100200,1.3750896,4.261447,,,,,,,,,,,,,, -100300,1.4843398,4.7093315,,,,,,,,,,,,,, -100400,1.394846,3.361741,,,,,,,,,,,,,, -100500,1.2684381,3.2923698,,,,,,,,,,,,,, -100600,1.4286982,4.562293,,,,,,,,,,,,,, -100700,1.2724552,3.438543,,,,,,,,,,,,,, -100800,1.3181973,3.2752612,,,,,,,,,,,,,, -100900,1.3077866,3.2165468,,,,,,,,,,,,,, -100973,,,0.7863867282867432,1.0272761583328247,0.7179799675941467,1.3234541416168213,50000.0,0.5918000340461731,1.9154317378997805,10000.0,45829.853684186935,49581.7681043148,45829.853684186935,3742.466834306717,4.11404824256897,0.0 -101000,1.5851033,4.8900995,,,,,,,,,,,,,, -101100,1.4398545,3.8662364,,,,,,,,,,,,,, -101200,1.4123148,3.2400625,,,,,,,,,,,,,, -101300,1.2595448,3.3930483,,,,,,,,,,,,,, -101400,1.4215201,3.2530098,,,,,,,,,,,,,, -101500,1.289914,3.4083083,,,,,,,,,,,,,, -101600,1.385369,3.259016,,,,,,,,,,,,,, -101700,1.5709311,3.19951,,,,,,,,,,,,,, -101800,1.3608876,4.4326916,,,,,,,,,,,,,, -101900,1.3428645,3.2344995,,,,,,,,,,,,,, -101901,,,0.7957617044448853,1.0064743757247925,0.7209399938583374,1.316093921661377,50000.0,0.598300039768219,1.912156343460083,10000.0,46250.57767724991,50035.35246872902,46250.57767724991,3775.237138032913,4.155972242355347,0.0 -102000,1.3750952,4.8175297,,,,,,,,,,,,,, -102100,1.750612,4.842959,,,,,,,,,,,,,, -102200,1.4207758,3.2189684,,,,,,,,,,,,,, -102300,1.3927166,3.3355336,,,,,,,,,,,,,, -102400,1.635035,4.873836,,,,,,,,,,,,,, -102500,1.5271916,4.619391,,,,,,,,,,,,,, -102600,1.3658473,3.2390287,,,,,,,,,,,,,, -102700,1.5289519,4.610094,,,,,,,,,,,,,, -102800,1.5252007,3.2086112,,,,,,,,,,,,,, -102828,,,0.7928906083106995,1.0309524536132812,0.7185800075531006,1.3496639728546145,50000.0,0.5942000150680542,1.957307934761048,10000.0,46670.89849615097,50491.19104528427,46670.89849615097,3810.6633427143097,4.199147939682007,0.0 -102900,1.4442269,4.5811577,,,,,,,,,,,,,, -103000,1.3769926,3.2986493,,,,,,,,,,,,,, -103100,1.2824684,3.3953366,,,,,,,,,,,,,, -103200,1.4899179,4.867714,,,,,,,,,,,,,, -103300,1.2937251,3.7506704,,,,,,,,,,,,,, -103400,1.4260441,3.2070587,,,,,,,,,,,,,, -103500,1.4532595,3.297171,,,,,,,,,,,,,, -103600,1.249471,3.8662338,,,,,,,,,,,,,, -103700,1.3687387,3.8679419,,,,,,,,,,,,,, -103754,,,0.8130663633346558,0.960721492767334,0.7226600050926208,1.3447332382202148,50000.0,0.5997000336647034,1.9341633319854736,10000.0,47091.06334114075,50944.35725951195,47091.06334114075,3843.5726778507233,4.241317510604858,0.0 -103800,1.3379389,3.1950996,,,,,,,,,,,,,, -103900,1.472203,3.2489765,,,,,,,,,,,,,, -104000,1.4538255,3.2103426,,,,,,,,,,,,,, -104100,1.5559785,4.8353,,,,,,,,,,,,,, -104200,1.4905028,3.1299853,,,,,,,,,,,,,, -104300,1.4186587,3.254834,,,,,,,,,,,,,, -104400,1.2312906,3.6011517,,,,,,,,,,,,,, -104500,1.3468719,3.7166436,,,,,,,,,,,,,, -104600,1.3250356,3.1884074,,,,,,,,,,,,,, -104681,,,0.79408198595047,1.0491050481796265,0.7261199951171875,1.340632438659668,50000.0,0.6028000116348267,1.9325157403945925,10000.0,47511.09459543228,51399.68806958199,47511.09459543228,3878.780025720596,4.284324884414673,0.0 -104700,1.3911262,3.1907716,,,,,,,,,,,,,, -104800,1.4219282,4.3535614,,,,,,,,,,,,,, -104900,1.436053,3.2227068,,,,,,,,,,,,,, -105000,1.4687716,3.219203,,,,,,,,,,,,,, -105100,1.4453914,4.0466213,,,,,,,,,,,,,, -105200,1.294601,3.8184798,,,,,,,,,,,,,, -105300,1.3872849,3.251849,,,,,,,,,,,,,, -105400,1.3856249,3.3356378,,,,,,,,,,,,,, -105500,1.4899205,3.1615524,,,,,,,,,,,,,, -105600,1.5479838,4.601452,,,,,,,,,,,,,, -105604,,,0.7981054782867432,1.0043699741363523,0.7231199741363525,1.320081353187561,50000.0,0.6016000509262085,1.9280401468276973,10000.0,47930.99238157272,51854.14779257774,47930.99238157272,3912.86477136612,4.712871551513672,0.0 -105700,1.5184159,4.742228,,,,,,,,,,,,,, -105800,1.3029246,3.799643,,,,,,,,,,,,,, -105900,1.3428408,3.9291265,,,,,,,,,,,,,, -106000,1.4341555,3.217196,,,,,,,,,,,,,, -106100,1.3614411,3.757739,,,,,,,,,,,,,, -106200,1.3835784,3.1599567,,,,,,,,,,,,,, -106300,1.5017684,3.180571,,,,,,,,,,,,,, -106400,1.4737356,3.219762,,,,,,,,,,,,,, -106500,1.4747617,3.2442353,,,,,,,,,,,,,, -106529,,,0.8086132407188416,0.9574166536331176,0.7298600077629089,1.299571871757507,50000.0,0.603600025177002,1.904357671737671,10000.0,48351.03900790215,52309.55121731758,48351.03900790215,3948.129752397537,4.755612373352051,0.0 -106600,1.338263,3.086057,,,,,,,,,,,,,, -106700,1.7517134,4.6182165,,,,,,,,,,,,,, -106800,1.4250807,3.3275685,,,,,,,,,,,,,, -106900,1.3546369,3.162219,,,,,,,,,,,,,, -107000,1.5482984,3.2480907,,,,,,,,,,,,,, -107100,1.5217141,3.210548,,,,,,,,,,,,,, -107200,1.5934529,4.605357,,,,,,,,,,,,,, -107300,1.3930868,3.276422,,,,,,,,,,,,,, -107400,1.3442905,4.0104594,,,,,,,,,,,,,, -107457,,,0.796582043170929,1.0255680084228516,0.7306199669837952,1.309115648269653,50000.0,0.5999000072479248,1.917056679725647,10000.0,48771.34631705284,52766.52301621437,48771.34631705284,3984.69742679596,4.803227424621582,0.0 -107500,1.856537,4.888,,,,,,,,,,,,,, -107600,1.3626347,3.9614234,,,,,,,,,,,,,, -107700,1.3807882,3.5589876,,,,,,,,,,,,,, -107800,1.3841038,3.0958014,,,,,,,,,,,,,, -107900,1.3945049,3.1860218,,,,,,,,,,,,,, -108000,1.3651785,3.3279495,,,,,,,,,,,,,, -108100,1.3575557,3.1278393,,,,,,,,,,,,,, -108200,1.3879027,3.9769244,,,,,,,,,,,,,, -108300,1.338858,3.2972398,,,,,,,,,,,,,, -108382,,,0.8032421469688416,0.9735230803489684,0.7277199625968933,1.296148419380188,50000.0,0.6070000529289246,1.8809870481491089,10000.0,49191.41876125336,53220.95117616653,49191.41876125336,4018.960418462753,4.847645044326782,0.0 -108400,1.8168223,4.804697,,,,,,,,,,,,,, -108500,1.6458622,4.827344,,,,,,,,,,,,,, -108600,1.6170114,4.453157,,,,,,,,,,,,,, -108700,1.5402951,4.6960206,,,,,,,,,,,,,, -108800,1.6340078,4.450343,,,,,,,,,,,,,, -108900,1.4517705,4.235058,,,,,,,,,,,,,, -109000,1.3801552,4.294735,,,,,,,,,,,,,, -109100,1.5205041,3.2103903,,,,,,,,,,,,,, -109200,1.4234107,3.150646,,,,,,,,,,,,,, -109300,1.365221,3.0632439,,,,,,,,,,,,,, -109308,,,0.80712890625,0.9809885025024414,0.7274799942970276,1.309841513633728,50000.0,0.6025000214576721,1.9155027866363523,10000.0,49611.47404384613,53674.50714588165,49611.47404384613,4052.371610879898,4.887386798858643,0.0 -109400,1.397136,3.5911794,,,,,,,,,,,,,, -109500,1.4943115,3.2160616,,,,,,,,,,,,,, -109600,1.7148784,4.7210336,,,,,,,,,,,,,, -109700,1.5675995,4.74603,,,,,,,,,,,,,, -109800,1.5449818,4.749156,,,,,,,,,,,,,, -109900,1.3945873,3.9790378,,,,,,,,,,,,,, -110000,1.6944768,4.7471433,,,,,,,,,,,,,, -110100,1.459107,3.3470275,,,,,,,,,,,,,, -110200,1.4887874,3.186556,,,,,,,,,,,,,, -110233,,,0.8057421445846558,0.956943690776825,0.7321999669075012,1.2614250183105469,50000.0,0.6146000027656555,1.8612117767333984,10000.0,50031.75916481018,54126.98281383514,50031.75916481018,4084.464169979096,4.937520742416382,0.0 -110300,1.5169635,4.4192014,,,,,,,,,,,,,, -110400,1.4470735,3.8091464,,,,,,,,,,,,,, -110500,1.538464,3.2057216,,,,,,,,,,,,,, -110600,1.3868467,3.4127452,,,,,,,,,,,,,, -110700,1.415317,3.4869082,,,,,,,,,,,,,, -110800,1.5293772,3.23105,,,,,,,,,,,,,, -110900,1.4663374,3.181163,,,,,,,,,,,,,, -111000,1.6314983,3.1757832,,,,,,,,,,,,,, -111100,1.4411718,3.2304926,,,,,,,,,,,,,, -111158,,,0.8017968535423279,1.0286946296691897,0.7317799925804138,1.3278456926345823,50000.0,0.6043000221252441,1.9332977533340447,10000.0,50451.79173922539,54581.74570274353,50451.79173922539,4119.084473133087,4.998009443283081,0.0 -111200,1.4860328,3.3734612,,,,,,,,,,,,,, -111300,1.7824496,4.816523,,,,,,,,,,,,,, -111400,1.4800408,3.1539755,,,,,,,,,,,,,, -111500,1.5601943,3.227821,,,,,,,,,,,,,, -111600,1.5100474,4.1188307,,,,,,,,,,,,,, -111700,1.678803,4.2950916,,,,,,,,,,,,,, -111800,1.3862698,4.1461124,,,,,,,,,,,,,, -111900,1.4659123,3.962924,,,,,,,,,,,,,, -112000,1.581623,3.35137,,,,,,,,,,,,,, -112081,,,0.8128515481948853,0.9389355778694152,0.7346799969673157,1.2732287645339966,50000.0,0.6106000542640686,1.8744934797286987,10000.0,50871.74392461777,55036.06898331642,50871.74392461777,4153.366170406342,5.039664268493652,0.0 -112100,1.4882878,3.1380346,,,,,,,,,,,,,, -112200,1.5795395,4.282777,,,,,,,,,,,,,, -112300,1.4569542,3.116831,,,,,,,,,,,,,, -112400,1.512467,3.4260855,,,,,,,,,,,,,, -112500,1.4334182,3.730118,,,,,,,,,,,,,, -112600,1.5486642,3.1956344,,,,,,,,,,,,,, -112700,1.3064361,3.6820474,,,,,,,,,,,,,, -112800,1.4893976,3.1301448,,,,,,,,,,,,,, -112900,1.4385886,3.4947875,,,,,,,,,,,,,, -113000,1.7193724,4.575553,,,,,,,,,,,,,, -113004,,,0.8268554210662842,0.8962365388870239,0.7337200045585632,1.2884944677352903,50000.0,0.6077000498771667,1.887868046760559,10000.0,51291.82757949829,55491.24841022492,51291.82757949829,4188.366625308991,5.085594415664673,0.0 -113100,1.4824545,3.185563,,,,,,,,,,,,,, -113200,1.4304426,3.3427641,,,,,,,,,,,,,, -113300,1.5110537,3.142288,,,,,,,,,,,,,, -113400,1.5522983,3.2501073,,,,,,,,,,,,,, -113500,1.5882218,3.125303,,,,,,,,,,,,,, -113600,1.5814065,3.138595,,,,,,,,,,,,,, -113700,1.5715594,3.3132243,,,,,,,,,,,,,, -113800,1.4267404,3.3083777,,,,,,,,,,,,,, -113900,1.4980267,3.208774,,,,,,,,,,,,,, -113930,,,0.8086913824081421,0.9540855288505554,0.7350999712944031,1.2699496746063232,50000.0,0.6146000027656555,1.8564660549163816,10000.0,51711.96106266976,55945.711918354034,51711.96106266976,4222.597744464874,5.1359288692474365,0.0 -114000,1.593671,4.555649,,,,,,,,,,,,,, -114100,1.5525569,3.1436548,,,,,,,,,,,,,, -114200,1.5791826,4.281375,,,,,,,,,,,,,, -114300,1.4865328,3.2193968,,,,,,,,,,,,,, -114400,2.0634196,4.6871395,,,,,,,,,,,,,, -114500,1.3309036,3.6056325,,,,,,,,,,,,,, -114600,1.538159,3.2070873,,,,,,,,,,,,,, -114700,1.4345766,3.6900272,,,,,,,,,,,,,, -114800,1.6717978,4.351899,,,,,,,,,,,,,, -114857,,,0.8112109303474426,0.9619636535644532,0.7340999841690063,1.2924312353134155,50000.0,0.6121000051498413,1.887677788734436,10000.0,52131.87544989586,56399.65555882454,52131.87544989586,4256.532593727112,5.181121587753296,0.0 -114900,1.4355038,3.7691476,,,,,,,,,,,,,, -115000,1.5197622,3.2734091,,,,,,,,,,,,,, -115100,1.5620784,3.0583851,,,,,,,,,,,,,, -115200,1.6936121,3.1490192,,,,,,,,,,,,,, -115300,1.5278943,3.1860323,,,,,,,,,,,,,, -115400,1.5078125,3.2090697,,,,,,,,,,,,,, -115500,1.4865336,3.3913784,,,,,,,,,,,,,, -115600,1.5583394,3.1508799,,,,,,,,,,,,,, -115700,1.4316943,3.1054797,,,,,,,,,,,,,, -115782,,,0.8207421898841858,0.9196900129318236,0.7369199991226196,1.2814921140670776,50000.0,0.6084000468254089,1.8864821195602417,10000.0,52551.95956778526,56855.459886312485,52551.95956778526,4292.1596002578735,5.225587368011475,0.0 -115800,1.6542684,4.3224697,,,,,,,,,,,,,, -115900,1.5641371,3.0976794,,,,,,,,,,,,,, -116000,1.5699329,3.8550649,,,,,,,,,,,,,, -116100,1.5564889,3.1816878,,,,,,,,,,,,,, -116200,1.5287414,3.9658787,,,,,,,,,,,,,, -116300,1.6120083,3.1895032,,,,,,,,,,,,,, -116400,1.5277004,3.5319057,,,,,,,,,,,,,, -116500,1.5079097,3.1285608,,,,,,,,,,,,,, -116600,1.4266828,3.229687,,,,,,,,,,,,,, -116700,1.4523009,3.2556884,,,,,,,,,,,,,, -116708,,,0.8111132383346558,0.9671093821525574,0.7394199967384338,1.2774097919464111,50000.0,0.6146000027656555,1.8671879768371584,10000.0,52972.2819852829,57307.31153726578,52972.2819852829,4323.595373153687,5.270857572555542,0.0 -116800,1.5204356,4.129372,,,,,,,,,,,,,, -116900,1.4675301,3.2481132,,,,,,,,,,,,,, -117000,1.582668,3.1303837,,,,,,,,,,,,,, -117100,1.5705278,4.206607,,,,,,,,,,,,,, -117200,1.3969793,3.1667392,,,,,,,,,,,,,, -117300,1.4711586,3.247898,,,,,,,,,,,,,, -117400,1.6564214,3.1085136,,,,,,,,,,,,,, -117500,1.4887525,3.09293,,,,,,,,,,,,,, -117600,1.5143498,3.5347736,,,,,,,,,,,,,, -117635,,,0.8186327815055847,0.9464924931526184,0.73881995677948,1.275384545326233,50000.0,0.613800048828125,1.872160077095032,10000.0,53392.3686144352,57759.757767915726,53392.3686144352,4355.863595724106,5.312883377075195,0.0 -117700,1.6526979,3.086449,,,,,,,,,,,,,, -117800,1.5259168,3.102128,,,,,,,,,,,,,, -117900,1.7227961,3.0800662,,,,,,,,,,,,,, -118000,1.510534,3.0854542,,,,,,,,,,,,,, -118100,1.6306082,3.1756067,,,,,,,,,,,,,, -118200,1.6830113,3.10563,,,,,,,,,,,,,, -118300,1.7314152,3.1471837,,,,,,,,,,,,,, -118400,2.0025258,4.6955056,,,,,,,,,,,,,, -118500,1.5746828,4.102023,,,,,,,,,,,,,, -118557,,,0.8215234279632568,0.9040040969848632,0.7385199666023254,1.2626641988754272,50000.0,0.6128000020980835,1.859431505203247,10000.0,53812.669481277466,58216.1560792923,53812.669481277466,4391.868913650513,5.357113838195801,0.0 -118600,1.5966281,3.1346438,,,,,,,,,,,,,, -118700,1.6350133,4.32011,,,,,,,,,,,,,, -118800,1.6275034,3.1268928,,,,,,,,,,,,,, -118900,1.5404036,3.5490358,,,,,,,,,,,,,, -119000,1.5478036,3.081617,,,,,,,,,,,,,, -119100,1.6738677,3.1516554,,,,,,,,,,,,,, -119200,1.6372256,3.0922446,,,,,,,,,,,,,, -119300,1.614258,3.090712,,,,,,,,,,,,,, -119400,1.5196815,3.3289587,,,,,,,,,,,,,, -119482,,,0.8153515458106995,0.9577239751815796,0.7410399913787842,1.269242763519287,50000.0,0.6167000532150269,1.868729591369629,10000.0,54232.62202167511,58672.590933561325,54232.62202167511,4428.260704755783,5.399370908737183,0.0 -119500,1.4908459,3.396409,,,,,,,,,,,,,, -119600,1.6705917,3.4196172,,,,,,,,,,,,,, -119700,1.628027,3.1819873,,,,,,,,,,,,,, -119800,1.4870381,3.9459405,,,,,,,,,,,,,, -119900,1.6447452,4.5299306,,,,,,,,,,,,,, -120000,1.6860373,3.0475204,,,,,,,,,,,,,, -120100,1.508081,3.976236,,,,,,,,,,,,,, -120200,1.6332232,4.290867,,,,,,,,,,,,,, -120300,1.5550958,4.068596,,,,,,,,,,,,,, -120400,1.578499,3.112558,,,,,,,,,,,,,, -120407,,,0.8184374570846558,0.9347028136253356,0.7399199604988098,1.2617909908294678,50000.0,0.6178000569343567,1.8522409200668333,10000.0,54652.57331991196,59125.07707071304,54652.57331991196,4460.699743509293,5.445420742034912,0.0 -120500,1.5541048,3.0701752,,,,,,,,,,,,,, -120600,1.5054804,3.908874,,,,,,,,,,,,,, -120700,1.7055947,3.1435025,,,,,,,,,,,,,, -120800,1.977938,3.0555978,,,,,,,,,,,,,, -120900,1.699347,3.1586194,,,,,,,,,,,,,, -121000,1.6818405,3.0419533,,,,,,,,,,,,,, -121100,1.5382844,3.108199,,,,,,,,,,,,,, -121200,1.5759721,2.9716284,,,,,,,,,,,,,, -121300,1.595469,4.0399404,,,,,,,,,,,,,, -121333,,,0.8251171708106995,0.8875871300697327,0.743619978427887,1.240481734275818,50000.0,0.6199000477790833,1.8362445831298828,10000.0,55072.54042840004,59578.87798953056,55072.54042840004,4494.443091392517,5.487497329711914,0.0 -121400,1.6489722,4.164724,,,,,,,,,,,,,, -121500,1.6293253,3.186637,,,,,,,,,,,,,, -121600,1.8328265,4.183148,,,,,,,,,,,,,, -121700,1.6219435,3.196635,,,,,,,,,,,,,, -121800,1.5141476,3.0289419,,,,,,,,,,,,,, -121900,1.6722562,4.1867127,,,,,,,,,,,,,, -122000,1.5154241,3.0078032,,,,,,,,,,,,,, -122100,1.6451135,3.1168737,,,,,,,,,,,,,, -122200,2.3426483,4.6998835,,,,,,,,,,,,,, -122261,,,0.8381249904632568,0.8315138220787048,0.7446199655532837,1.2221808433532717,50000.0,0.6212000250816345,1.827873468399048,10000.0,55492.709612846375,60033.37808465958,55492.709612846375,4528.681435108185,5.531659364700317,0.0 -122300,1.6141217,3.489138,,,,,,,,,,,,,, -122400,2.1669476,4.649017,,,,,,,,,,,,,, -122500,1.5898794,3.0340474,,,,,,,,,,,,,, -122600,2.0541914,4.713959,,,,,,,,,,,,,, -122700,1.6197073,3.662869,,,,,,,,,,,,,, -122800,1.8819234,4.71142,,,,,,,,,,,,,, -122900,1.4912318,3.2686718,,,,,,,,,,,,,, -123000,1.7400179,3.0595355,,,,,,,,,,,,,, -123100,1.692204,3.0566828,,,,,,,,,,,,,, -123187,,,0.8201562166213989,0.9035536050796508,0.7415199875831604,1.2345235347747805,50000.0,0.6188000440597534,1.8316484689712524,10000.0,55912.80021524429,60485.82897758484,55912.80021524429,4560.943968057632,5.580912590026856,0.0 -123200,1.9136004,4.6057143,,,,,,,,,,,,,, -123300,1.7028468,3.1352315,,,,,,,,,,,,,, -123400,1.6757531,3.7357764,,,,,,,,,,,,,, -123500,1.640288,4.0999827,,,,,,,,,,,,,, -123600,1.615371,3.0939884,,,,,,,,,,,,,, -123700,1.5976015,3.449666,,,,,,,,,,,,,, -123800,2.639866,4.6804957,,,,,,,,,,,,,, -123900,1.6803606,3.178341,,,,,,,,,,,,,, -124000,1.778392,4.350226,,,,,,,,,,,,,, -124100,1.6978375,4.2735157,,,,,,,,,,,,,, -124111,,,0.8274609446525574,0.8888863325119019,0.7472400069236755,1.228013038635254,50000.0,0.6260000467300415,1.830779194831848,10000.0,56332.88788199425,60940.03962993622,56332.88788199425,4594.97258067131,5.626370191574097,0.0 -124200,1.7470757,3.2540293,,,,,,,,,,,,,, -124300,1.9273002,4.6329775,,,,,,,,,,,,,, -124400,1.6921232,3.1240897,,,,,,,,,,,,,, -124500,1.5464936,3.0379584,,,,,,,,,,,,,, -124600,1.5627123,3.5623915,,,,,,,,,,,,,, -124700,1.6625893,3.0926297,,,,,,,,,,,,,, -124800,1.7962054,3.0672522,,,,,,,,,,,,,, -124900,2.0893886,3.0715196,,,,,,,,,,,,,, -125000,1.6184832,3.6666145,,,,,,,,,,,,,, -125037,,,0.8400976657867432,0.8476256132125854,0.7465400099754333,1.231023907661438,50000.0,0.6228000521659851,1.8111176490783687,10000.0,56752.807072639465,61394.470771074295,56752.807072639465,4629.3885724544525,5.673748731613159,0.0 -125100,1.7328031,3.0466628,,,,,,,,,,,,,, -125200,1.7156096,3.0742235,,,,,,,,,,,,,, -125300,1.7711694,3.0270982,,,,,,,,,,,,,, -125400,1.6725748,3.365865,,,,,,,,,,,,,, -125500,2.08431,4.4209356,,,,,,,,,,,,,, -125600,1.6088381,2.9798844,,,,,,,,,,,,,, -125700,1.8017427,3.9567661,,,,,,,,,,,,,, -125800,1.6843685,3.106294,,,,,,,,,,,,,, -125900,1.6773677,3.6571565,,,,,,,,,,,,,, -125962,,,0.8267577886581421,0.9042437076568604,0.744219958782196,1.2379591464996338,50000.0,0.6270000338554382,1.8236942291259768,10000.0,57172.92175483704,61848.86698770523,57172.92175483704,4663.57472038269,5.72042441368103,0.0 -126000,1.7018087,4.080203,,,,,,,,,,,,,, -126100,1.658387,3.1033647,,,,,,,,,,,,,, -126200,1.6466902,3.078141,,,,,,,,,,,,,, -126300,1.8318273,3.0977044,,,,,,,,,,,,,, -126400,1.840557,4.364113,,,,,,,,,,,,,, -126500,1.643877,3.4573271,,,,,,,,,,,,,, -126600,1.754954,3.0056312,,,,,,,,,,,,,, -126700,1.5888886,3.7296882,,,,,,,,,,,,,, -126800,1.7360177,3.0890675,,,,,,,,,,,,,, -126889,,,0.829394519329071,0.8763020634651184,0.7468599677085876,1.225918889045715,50000.0,0.6248000264167786,1.817178010940552,10000.0,57593.29928016663,62302.96178174019,57593.29928016663,4697.19917845726,5.7643516063690186,0.0 -126900,1.7687706,3.1032948,,,,,,,,,,,,,, -127000,1.7589113,3.080042,,,,,,,,,,,,,, -127100,1.7350754,3.2323983,,,,,,,,,,,,,, -127200,1.6760905,3.036996,,,,,,,,,,,,,, -127300,1.6072628,3.9240704,,,,,,,,,,,,,, -127400,1.6662143,3.0452318,,,,,,,,,,,,,, -127500,1.6416949,3.3477402,,,,,,,,,,,,,, -127600,1.6357155,3.1975434,,,,,,,,,,,,,, -127700,1.9885268,3.0618277,,,,,,,,,,,,,, -127800,1.7121342,4.154373,,,,,,,,,,,,,, -127815,,,0.8350195288658142,0.8415217399597168,0.7514199614524841,1.19737708568573,50000.0,0.6276000142097473,1.7829385995864868,10000.0,58013.23331975937,62756.352815151215,58013.23331975937,4730.561057806015,5.811323881149292,0.0 -127900,1.7095587,3.028169,,,,,,,,,,,,,, -128000,1.7727076,3.1485295,,,,,,,,,,,,,, -128100,1.6598431,3.1537497,,,,,,,,,,,,,, -128200,1.7547956,3.2309597,,,,,,,,,,,,,, -128300,1.8135018,3.090962,,,,,,,,,,,,,, -128400,1.7548343,2.9536142,,,,,,,,,,,,,, -128500,1.7960458,3.2242954,,,,,,,,,,,,,, -128600,1.6776514,3.091679,,,,,,,,,,,,,, -128700,1.7078915,3.230235,,,,,,,,,,,,,, -128742,,,0.8329101204872131,0.8737267255783081,0.7518799901008606,1.2153342962265017,50000.0,0.629300057888031,1.8096715211868288,10000.0,58433.33541345596,63211.96445417404,58433.33541345596,4765.972766160965,5.859126567840576,0.0 -128800,1.7588387,3.0280516,,,,,,,,,,,,,, -128900,1.8633581,3.0293584,,,,,,,,,,,,,, -129000,2.2037752,4.619896,,,,,,,,,,,,,, -129100,1.737663,3.6504989,,,,,,,,,,,,,, -129200,1.9506118,4.2861166,,,,,,,,,,,,,, -129300,1.6201439,3.287191,,,,,,,,,,,,,, -129400,1.7454579,3.5873022,,,,,,,,,,,,,, -129500,1.9314085,3.021702,,,,,,,,,,,,,, -129600,1.7304498,2.998198,,,,,,,,,,,,,, -129669,,,0.8311132788658142,0.8590825200080872,0.7507799863815308,1.207733988761902,50000.0,0.628600001335144,1.7966262102127075,10000.0,58853.82092785835,63661.234209775925,58853.82092785835,4794.663794994354,5.904484272003174,0.0 -129700,1.9032401,3.2705498,,,,,,,,,,,,,, -129800,2.0116055,4.4627,,,,,,,,,,,,,, -129900,2.057666,4.3895664,,,,,,,,,,,,,, -130000,1.7613894,4.0378036,,,,,,,,,,,,,, -130100,2.1154916,4.5142384,,,,,,,,,,,,,, -130200,1.6929641,3.0740564,,,,,,,,,,,,,, -130300,1.676923,3.2565074,,,,,,,,,,,,,, -130400,1.6149807,3.4265501,,,,,,,,,,,,,, -130500,1.7652029,3.041922,,,,,,,,,,,,,, -130594,,,0.84033203125,0.8434707522392273,0.7518799901008606,1.2105284929275513,50000.0,0.628000020980835,1.803253412246704,10000.0,59274.19008851051,64115.61355304718,59274.19008851051,4828.563596725464,5.966588497161865,0.0 -130600,1.7344449,3.6489727,,,,,,,,,,,,,, -130700,1.742761,3.0288558,,,,,,,,,,,,,, -130800,1.7248276,3.5210507,,,,,,,,,,,,,, -130900,1.7602386,3.012101,,,,,,,,,,,,,, -131000,1.8306192,3.8823736,,,,,,,,,,,,,, -131100,1.7298847,3.0084906,,,,,,,,,,,,,, -131200,1.7557116,3.3625548,,,,,,,,,,,,,, -131300,1.6027316,3.7043977,,,,,,,,,,,,,, -131400,1.8539708,3.0443227,,,,,,,,,,,,,, -131500,1.8818942,3.2209353,,,,,,,,,,,,,, -131523,,,0.8484765291213989,0.8026385307312012,0.7519800066947937,1.2015353441238403,50000.0,0.6252000331878662,1.8083374500274656,10000.0,59694.75778841972,64568.58602309227,59694.75778841972,4860.874860286713,6.009950637817383,0.0 -131600,1.774332,3.2589025,,,,,,,,,,,,,, -131700,1.7990774,3.0233028,,,,,,,,,,,,,, -131800,1.8605614,3.0415766,,,,,,,,,,,,,, -131900,1.792469,3.8678572,,,,,,,,,,,,,, -132000,1.7584476,2.9430375,,,,,,,,,,,,,, -132100,1.8721128,2.8675683,,,,,,,,,,,,,, -132200,1.7456979,3.2929885,,,,,,,,,,,,,, -132300,1.8221946,3.82964,,,,,,,,,,,,,, -132400,1.7966653,3.0040545,,,,,,,,,,,,,, -132448,,,0.8370116949081421,0.834997832775116,0.7545199990272522,1.1796581745147705,50000.0,0.6305000185966492,1.7780799865722656,10000.0,60114.998239040375,65024.695997715,60114.998239040375,4896.653185606003,6.053093671798706,0.0 -132500,1.8051058,3.148125,,,,,,,,,,,,,, -132600,1.7388844,2.9796684,,,,,,,,,,,,,, -132700,1.973491,4.566074,,,,,,,,,,,,,, -132800,2.0108168,3.0251849,,,,,,,,,,,,,, -132900,2.0368226,2.9903147,,,,,,,,,,,,,, -133000,1.7726635,3.0127854,,,,,,,,,,,,,, -133100,1.7036009,3.5222929,,,,,,,,,,,,,, -133200,1.8832644,3.6430638,,,,,,,,,,,,,, -133300,2.049558,4.529075,,,,,,,,,,,,,, -133375,,,0.8409179449081421,0.8413550853729248,0.7555800080299377,1.2021714448928833,50000.0,0.6372000575065613,1.7991794347763062,10000.0,60535.22475409508,65477.413791656494,60535.22475409508,4929.048456430435,6.100675344467163,0.0 -133400,2.0197423,2.9905689,,,,,,,,,,,,,, -133500,1.9635812,2.9838417,,,,,,,,,,,,,, -133600,1.9401634,2.9751132,,,,,,,,,,,,,, -133700,2.0478988,3.4978209,,,,,,,,,,,,,, -133800,1.8131671,2.9949596,,,,,,,,,,,,,, -133900,1.8155355,3.351301,,,,,,,,,,,,,, -134000,1.8906357,3.0546327,,,,,,,,,,,,,, -134100,1.7586399,3.2790158,,,,,,,,,,,,,, -134200,1.8660184,3.470446,,,,,,,,,,,,,, -134299,,,0.8481249809265137,0.8018088936805725,0.7539399862289429,1.1947689056396484,50000.0,0.6312000155448914,1.78182053565979,10000.0,60955.59069728851,65933.40303897858,60955.59069728851,4964.578231573105,6.146247386932373,0.0 -134300,1.7856568,2.9325943,,,,,,,,,,,,,, -134400,2.0565803,4.411559,,,,,,,,,,,,,, -134500,1.7747527,3.2548294,,,,,,,,,,,,,, -134600,1.904423,3.066797,,,,,,,,,,,,,, -134700,1.8974882,3.7300496,,,,,,,,,,,,,, -134800,1.7741885,3.114212,,,,,,,,,,,,,, -134900,1.7928312,2.9519897,,,,,,,,,,,,,, -135000,1.6832774,3.823331,,,,,,,,,,,,,, -135100,1.9012761,3.0791414,,,,,,,,,,,,,, -135200,2.0072336,4.303597,,,,,,,,,,,,,, -135226,,,0.8419921398162842,0.8387060165405273,0.7549200057983398,1.199419617652893,50000.0,0.638200044631958,1.7809736728668213,10000.0,61375.55031371117,66389.22588658333,61375.55031371117,5000.347151756287,6.192385196685791,0.0 -135300,2.0231442,3.048646,,,,,,,,,,,,,, -135400,1.8564711,3.8614182,,,,,,,,,,,,,, -135500,2.1777897,4.4009376,,,,,,,,,,,,,, -135600,2.1960158,4.395927,,,,,,,,,,,,,, -135700,1.859319,2.9476926,,,,,,,,,,,,,, -135800,1.870645,2.9937701,,,,,,,,,,,,,, -135900,1.8184906,2.910099,,,,,,,,,,,,,, -136000,1.7940952,3.096471,,,,,,,,,,,,,, -136100,1.8634299,3.068233,,,,,,,,,,,,,, -136151,,,0.8457421660423279,0.8141063451766968,0.7572399973869324,1.1856240034103394,50000.0,0.6385000348091125,1.7745895385742188,10000.0,61795.65664720535,66844.14261889458,61795.65664720535,5035.061300992966,6.240199089050293,0.0 -136200,1.8386717,3.00207,,,,,,,,,,,,,, -136300,1.8046728,3.351678,,,,,,,,,,,,,, -136400,1.9985375,3.4326317,,,,,,,,,,,,,, -136500,2.015644,3.024043,,,,,,,,,,,,,, -136600,1.7320651,2.8865252,,,,,,,,,,,,,, -136700,2.0811431,3.0148492,,,,,,,,,,,,,, -136800,1.818903,3.3639576,,,,,,,,,,,,,, -136900,1.901812,2.9595935,,,,,,,,,,,,,, -137000,1.923339,2.9462862,,,,,,,,,,,,,, -137076,,,0.8503710627555847,0.803084671497345,0.7572799921035767,1.1784809827804563,50000.0,0.6363000273704529,1.7618496417999268,10000.0,62215.70627474785,67301.74554777145,62215.70627474785,5072.519873142242,6.286186456680298,0.0 -137100,1.8278319,3.226677,,,,,,,,,,,,,, -137200,1.8516448,3.0630112,,,,,,,,,,,,,, -137300,2.2840598,4.476594,,,,,,,,,,,,,, -137400,1.8472867,3.9491131,,,,,,,,,,,,,, -137500,1.8564184,2.8985538,,,,,,,,,,,,,, -137600,1.8240654,3.1866195,,,,,,,,,,,,,, -137700,1.9213202,3.8347733,,,,,,,,,,,,,, -137800,1.8193202,3.2475657,,,,,,,,,,,,,, -137900,2.061348,3.1902094,,,,,,,,,,,,,, -138000,1.9465466,3.0728083,,,,,,,,,,,,,, -138004,,,0.8441210985183716,0.8067802786827087,0.7592200040817261,1.1712822914123535,50000.0,0.6401000022888184,1.7613399028778076,10000.0,62635.84506726265,67755.41611027718,62635.84506726265,5105.9577214717865,6.330903053283691,0.0 -138100,1.8096528,3.612882,,,,,,,,,,,,,, -138200,1.9567798,2.9736178,,,,,,,,,,,,,, -138300,1.9846817,3.0071023,,,,,,,,,,,,,, -138400,1.7138464,3.3443677,,,,,,,,,,,,,, -138500,1.9773312,3.601668,,,,,,,,,,,,,, -138600,1.9222978,2.898231,,,,,,,,,,,,,, -138700,2.2799509,4.3511214,,,,,,,,,,,,,, -138800,1.9090945,2.9732609,,,,,,,,,,,,,, -138900,2.7773638,4.5692077,,,,,,,,,,,,,, -138929,,,0.8473241925239563,0.8092468976974487,0.7607199549674988,1.178433895111084,50000.0,0.6415000557899475,1.7597088813781738,10000.0,63055.99439263344,68210.82635331154,63055.99439263344,5141.126784801483,6.374598264694214,0.0 -139000,2.4095056,4.5352006,,,,,,,,,,,,,, -139100,2.0326738,2.942073,,,,,,,,,,,,,, -139200,1.8485163,2.8529701,,,,,,,,,,,,,, -139300,1.9186469,2.9302094,,,,,,,,,,,,,, -139400,2.0019138,3.0146065,,,,,,,,,,,,,, -139500,1.803104,2.8745914,,,,,,,,,,,,,, -139600,1.9339297,2.9459343,,,,,,,,,,,,,, -139700,1.9911922,2.960847,,,,,,,,,,,,,, -139800,2.1826708,4.290168,,,,,,,,,,,,,, -139855,,,0.8530859351158142,0.7908298969268799,0.7603999972343445,1.179398775100708,50000.0,0.6410000324249268,1.7686790227890017,10000.0,63476.15185451508,68662.43821144104,63476.15185451508,5172.483269929886,6.423179388046265,0.0 -139900,2.5256212,3.8593488,,,,,,,,,,,,,, -140000,2.0293982,2.9426296,,,,,,,,,,,,,, -140100,1.9679372,3.9160523,,,,,,,,,,,,,, -140200,1.8020571,3.0151353,,,,,,,,,,,,,, -140300,1.8717718,3.028074,,,,,,,,,,,,,, -140400,2.4458601,4.437424,,,,,,,,,,,,,, -140500,2.0392501,3.0968795,,,,,,,,,,,,,, -140600,2.1749942,3.9403925,,,,,,,,,,,,,, -140700,2.5122328,4.451141,,,,,,,,,,,,,, -140779,,,0.8641406297683716,0.7573795318603516,0.7615399956703186,1.174419403076172,50000.0,0.6406000256538391,1.7598754167556765,10000.0,63896.06881427765,69118.56551671028,63896.06881427765,5208.587812423706,6.4807868003845215,0.0 -140800,2.1091647,3.062458,,,,,,,,,,,,,, -140900,1.9575961,2.9911394,,,,,,,,,,,,,, -141000,2.0441942,2.9543774,,,,,,,,,,,,,, -141100,1.8056327,2.9773352,,,,,,,,,,,,,, -141200,1.9546201,3.303632,,,,,,,,,,,,,, -141300,2.1395621,2.9303927,,,,,,,,,,,,,, -141400,2.4213972,4.4757566,,,,,,,,,,,,,, -141500,2.1249864,2.9621549,,,,,,,,,,,,,, -141600,1.9377038,2.8918717,,,,,,,,,,,,,, -141700,2.043459,3.0105214,,,,,,,,,,,,,, -141701,,,0.8532226085662842,0.7790336608886719,0.7630999684333801,1.150395750999451,50000.0,0.6450000405311584,1.7349143028259275,10000.0,64316.15193653107,69573.42215561867,64316.15193653107,5243.268939495087,6.525313138961792,0.0 -141800,1.9619166,2.9781027,,,,,,,,,,,,,, -141900,2.1348095,2.927381,,,,,,,,,,,,,, -142000,1.9859146,2.9682748,,,,,,,,,,,,,, -142100,1.9229333,3.6532054,,,,,,,,,,,,,, -142200,2.622701,4.5076027,,,,,,,,,,,,,, -142300,1.9719613,3.1681468,,,,,,,,,,,,,, -142400,1.982521,2.9057593,,,,,,,,,,,,,, -142500,2.1567826,4.161335,,,,,,,,,,,,,, -142600,2.3644683,4.285591,,,,,,,,,,,,,, -142628,,,0.8541601300239563,0.8001324534416199,0.7625600099563599,1.1798006296157837,50000.0,0.6477000117301941,1.7499827146530151,10000.0,64736.44136500359,70027.35831856728,64736.44136500359,5276.812092781067,6.580393314361572,0.0 -142700,1.8604099,3.2406886,,,,,,,,,,,,,, -142800,2.5068095,4.435894,,,,,,,,,,,,,, -142900,1.9230639,3.4255004,,,,,,,,,,,,,, -143000,1.8938985,3.2300272,,,,,,,,,,,,,, -143100,2.261135,4.363378,,,,,,,,,,,,,, -143200,1.9115651,3.6501546,,,,,,,,,,,,,, -143300,2.2684667,4.104866,,,,,,,,,,,,,, -143400,1.9454132,3.0224416,,,,,,,,,,,,,, -143500,2.035957,2.9085658,,,,,,,,,,,,,, -143555,,,0.8614062070846558,0.7344046831130981,0.7638199925422668,1.1369210481643677,50000.0,0.6407000422477722,1.7233561277389526,10000.0,65156.379346847534,70481.45216488838,65156.379346847534,5310.8514902591705,6.647460222244263,0.0 -143600,1.9056628,2.9315906,,,,,,,,,,,,,, -143700,1.9325607,2.9255219,,,,,,,,,,,,,, -143800,2.5373676,4.4238477,,,,,,,,,,,,,, -143900,2.252722,3.7633874,,,,,,,,,,,,,, -144000,2.5037696,4.4566884,,,,,,,,,,,,,, -144100,1.7990234,2.893804,,,,,,,,,,,,,, -144200,2.1192186,3.5885203,,,,,,,,,,,,,, -144300,2.3850484,4.4855046,,,,,,,,,,,,,, -144400,2.1876428,3.3764,,,,,,,,,,,,,, -144480,,,0.852832019329071,0.7703531980514526,0.7658999562263489,1.1399189233779907,50000.0,0.6437000036239624,1.7280446290969849,10000.0,65576.38734292984,70937.19880628586,65576.38734292984,5346.491506576538,6.697022199630737,0.0 -144500,2.24133,3.93181,,,,,,,,,,,,,, -144600,1.9352282,3.3613539,,,,,,,,,,,,,, -144700,1.9796005,3.5625572,,,,,,,,,,,,,, -144800,2.0876992,3.2055306,,,,,,,,,,,,,, -144900,2.0186272,2.99948,,,,,,,,,,,,,, -145000,1.8597388,3.3759315,,,,,,,,,,,,,, -145100,2.1964831,3.1319065,,,,,,,,,,,,,, -145200,2.2158425,2.947545,,,,,,,,,,,,,, -145300,2.8150828,4.417467,,,,,,,,,,,,,, -145400,2.1279802,2.9200568,,,,,,,,,,,,,, -145406,,,0.8582421541213989,0.7630525827407837,0.7664799690246582,1.1516687870025637,50000.0,0.6492000222206116,1.734040141105652,10000.0,65996.66863155365,71394.53972387314,65996.66863155365,5383.4498608112335,6.750396251678467,0.0 -145500,2.0009654,3.214011,,,,,,,,,,,,,, -145600,1.8562875,3.564445,,,,,,,,,,,,,, -145700,2.0368967,2.8972583,,,,,,,,,,,,,, -145800,2.0694232,3.0094354,,,,,,,,,,,,,, -145900,2.5089839,4.4607506,,,,,,,,,,,,,, -146000,2.1071086,2.909724,,,,,,,,,,,,,, -146100,2.0174234,2.9084554,,,,,,,,,,,,,, -146200,1.9830493,2.8695858,,,,,,,,,,,,,, -146300,1.8571547,2.8409524,,,,,,,,,,,,,, -146333,,,0.8623632788658142,0.7401703000068665,0.7676799893379211,1.1452099084854126,50000.0,0.6490000486373901,1.7267446517944336,10000.0,66416.9311683178,71851.3624317646,66416.9311683178,5419.915347337723,6.796186208724976,0.0 -146400,2.0489047,3.5389953,,,,,,,,,,,,,, -146500,2.595894,4.44678,,,,,,,,,,,,,, -146600,2.0776813,2.9237914,,,,,,,,,,,,,, -146700,2.129312,2.9421277,,,,,,,,,,,,,, -146800,3.1576746,4.428259,,,,,,,,,,,,,, -146900,2.0570211,4.1462383,,,,,,,,,,,,,, -147000,2.537463,3.4832258,,,,,,,,,,,,,, -147100,1.9420533,3.4722316,,,,,,,,,,,,,, -147200,2.096438,2.9140086,,,,,,,,,,,,,, -147259,,,0.8596093654632568,0.7663670778274536,0.7688199877738953,1.1503161191940308,50000.0,0.6520000100135803,1.7251871824264526,10000.0,66837.13448381424,72306.05270385742,66837.13448381424,5454.302681922913,6.847928524017334,0.0 -147300,2.3016596,4.0876365,,,,,,,,,,,,,, -147400,1.9544008,3.1584964,,,,,,,,,,,,,, -147500,2.2891674,4.1582007,,,,,,,,,,,,,, -147600,2.3534205,4.2554235,,,,,,,,,,,,,, -147700,2.0292513,3.9537377,,,,,,,,,,,,,, -147800,2.6526759,4.331373,,,,,,,,,,,,,, -147900,2.0678694,2.9124582,,,,,,,,,,,,,, -148000,2.1893637,2.9444356,,,,,,,,,,,,,, -148100,3.1957805,4.2901297,,,,,,,,,,,,,, -148183,,,0.8612304329872131,0.7369097471237183,0.7663999795913696,1.1357945203781128,50000.0,0.6517000198364258,1.709851622581482,10000.0,67257.20768213272,72764.31754755974,67257.20768213272,5492.399060964584,6.8949620723724365,0.0 -148200,2.192278,2.9400175,,,,,,,,,,,,,, -148300,2.090666,2.8632398,,,,,,,,,,,,,, -148400,1.9878772,2.9349632,,,,,,,,,,,,,, -148500,2.1107244,2.8328433,,,,,,,,,,,,,, -148600,2.3731222,4.042704,,,,,,,,,,,,,, -148700,2.3249433,3.923339,,,,,,,,,,,,,, -148800,2.2135642,3.7376754,,,,,,,,,,,,,, -148900,2.1142347,3.6415849,,,,,,,,,,,,,, -149000,2.546147,4.211537,,,,,,,,,,,,,, -149100,2.0537636,2.9394183,,,,,,,,,,,,,, -149109,,,0.8688085675239563,0.7321873903274536,0.7706199884414673,1.136370062828064,50000.0,0.6541000604629517,1.7144485712051392,10000.0,67677.46754312515,73219.93533587456,67677.46754312515,5527.658671617508,6.944510221481323,0.0 -149200,2.1827605,3.2409742,,,,,,,,,,,,,, -149300,2.1281915,2.8555052,,,,,,,,,,,,,, -149400,2.2263138,2.986229,,,,,,,,,,,,,, -149500,2.128229,3.17859,,,,,,,,,,,,,, -149600,2.0916371,2.8468475,,,,,,,,,,,,,, -149700,2.202701,2.9444702,,,,,,,,,,,,,, -149800,2.33771,2.8685777,,,,,,,,,,,,,, -149900,2.2942595,2.9135962,,,,,,,,,,,,,, -150000,2.4397967,4.241616,,,,,,,,,,,,,, -150035,,,0.8769726157188416,0.7250710725784302,0.7684800028800964,1.159796953201294,50000.0,0.653700053691864,1.7364274263381958,10000.0,68097.54691076279,73675.49411344528,68097.54691076279,5563.043598890305,6.990314960479736,0.0 -150100,2.1207206,2.8905838,,,,,,,,,,,,,, -150200,2.589519,4.1047463,,,,,,,,,,,,,, -150300,2.6179538,4.334231,,,,,,,,,,,,,, -150400,2.1714294,3.6497335,,,,,,,,,,,,,, -150500,2.1456838,2.8478396,,,,,,,,,,,,,, -150600,2.0833774,3.3518004,,,,,,,,,,,,,, -150700,2.0236187,3.0951395,,,,,,,,,,,,,, -150800,2.1513417,2.8886142,,,,,,,,,,,,,, -150900,2.2858598,2.9510472,,,,,,,,,,,,,, -150961,,,0.8651366829872131,0.7272911667823792,0.7701199650764465,1.1263177394866943,50000.0,0.6527000069618225,1.708383321762085,10000.0,68517.9379067421,74134.7340862751,68517.9379067421,5601.793773651123,7.040832042694092,0.0 -151000,2.1351073,3.3879964,,,,,,,,,,,,,, -151100,1.9808552,3.2683196,,,,,,,,,,,,,, -151200,2.0856411,2.8419888,,,,,,,,,,,,,, -151300,2.2307189,2.8221028,,,,,,,,,,,,,, -151400,2.1760406,2.9745934,,,,,,,,,,,,,, -151500,2.9750054,4.2082024,,,,,,,,,,,,,, -151600,2.020949,2.8652768,,,,,,,,,,,,,, -151700,2.1643846,3.0217247,,,,,,,,,,,,,, -151800,2.2538466,3.8552234,,,,,,,,,,,,,, -151889,,,0.8691796660423279,0.7026453018188477,0.770039975643158,1.108427882194519,50000.0,0.6552000045776367,1.6844028234481812,10000.0,68938.07565045357,74587.0642721653,68938.07565045357,5633.890254974365,7.088500022888184,0.0 -151900,2.7942848,4.254463,,,,,,,,,,,,,, -152000,2.1481578,3.5161476,,,,,,,,,,,,,, -152100,2.2049942,2.8630767,,,,,,,,,,,,,, -152200,2.0781176,3.5582414,,,,,,,,,,,,,, -152300,2.9974754,4.1873193,,,,,,,,,,,,,, -152400,2.703864,4.184177,,,,,,,,,,,,,, -152500,2.3657458,2.890067,,,,,,,,,,,,,, -152600,2.2139866,2.9786272,,,,,,,,,,,,,, -152700,1.996444,3.0082872,,,,,,,,,,,,,, -152800,2.1183717,3.1319115,,,,,,,,,,,,,, -152814,,,0.8727929592132568,0.7038527727127075,0.7734599709510803,1.1162408590316772,50000.0,0.6562000513076782,1.7009633779525757,10000.0,69358.09606146812,75041.11749482155,69358.09606146812,5667.825475215912,7.13709831237793,0.0 -152900,2.130423,2.9249206,,,,,,,,,,,,,, -153000,2.261627,2.8474498,,,,,,,,,,,,,, -153100,2.1571758,3.5034828,,,,,,,,,,,,,, -153200,2.6130264,4.060331,,,,,,,,,,,,,, -153300,2.2904007,3.4226043,,,,,,,,,,,,,, -153400,2.1661932,3.6073973,,,,,,,,,,,,,, -153500,2.448317,3.5426602,,,,,,,,,,,,,, -153600,2.1420455,2.8847141,,,,,,,,,,,,,, -153700,2.1429372,2.8077445,,,,,,,,,,,,,, -153742,,,0.8701562285423279,0.7306351065635681,0.772379994392395,1.131606936454773,50000.0,0.6551000475883484,1.710476279258728,10000.0,69778.18906927109,75495.08348155022,69778.18906927109,5701.598674058914,7.188451766967773,0.0 -153800,2.143946,2.8455222,,,,,,,,,,,,,, -153900,2.073292,3.0790648,,,,,,,,,,,,,, -154000,2.1823127,2.8508599,,,,,,,,,,,,,, -154100,2.29696,2.891225,,,,,,,,,,,,,, -154200,2.4287028,2.9312987,,,,,,,,,,,,,, -154300,2.1884356,2.9345486,,,,,,,,,,,,,, -154400,2.269541,3.223451,,,,,,,,,,,,,, -154500,2.8783596,4.194627,,,,,,,,,,,,,, -154600,2.259464,2.9298582,,,,,,,,,,,,,, -154667,,,0.8733202815055847,0.6871562004089355,0.7735199928283691,1.1001402139663696,50000.0,0.6605000495910645,1.6771475076675415,10000.0,70198.11680269241,75950.9940161705,70198.11680269241,5737.48500084877,7.237675905227661,0.0 -154700,2.176936,2.8005743,,,,,,,,,,,,,, -154800,2.2366269,3.2196207,,,,,,,,,,,,,, -154900,2.28805,3.1256905,,,,,,,,,,,,,, -155000,3.2405417,4.282035,,,,,,,,,,,,,, -155100,2.3245375,2.8346696,,,,,,,,,,,,,, -155200,2.1421707,3.058859,,,,,,,,,,,,,, -155300,2.1831226,2.8676233,,,,,,,,,,,,,, -155400,2.214321,3.5878882,,,,,,,,,,,,,, -155500,2.2266345,2.845347,,,,,,,,,,,,,, -155588,,,0.874804675579071,0.6974979043006897,0.7736999988555908,1.112828254699707,50000.0,0.6607000231742859,1.696984887123108,10000.0,70618.42568159103,76405.50797510147,70618.42568159103,5771.582266569138,7.296308755874634,0.0 -155600,2.45283,3.6177278,,,,,,,,,,,,,, -155700,2.3188288,2.8584387,,,,,,,,,,,,,, -155800,2.5780225,3.933394,,,,,,,,,,,,,, -155900,2.2975273,3.5174544,,,,,,,,,,,,,, -156000,2.115506,2.899665,,,,,,,,,,,,,, -156100,2.3676755,3.9496033,,,,,,,,,,,,,, -156200,2.3792892,3.6051235,,,,,,,,,,,,,, -156300,2.3215115,3.702243,,,,,,,,,,,,,, -156400,2.3334994,3.6623287,,,,,,,,,,,,,, -156500,2.3055809,2.9588907,,,,,,,,,,,,,, -156512,,,0.8740820288658142,0.7038549780845642,0.7731800079345703,1.112505555152893,50000.0,0.6598000526428223,1.6827988624572754,10000.0,71038.46853804588,76860.0290813446,71038.46853804588,5805.957935094833,7.350852727890015,0.0 -156600,2.3528154,2.8831072,,,,,,,,,,,,,, -156700,2.4394872,3.0070047,,,,,,,,,,,,,, -156800,2.3623958,2.8928301,,,,,,,,,,,,,, -156900,2.498553,4.1399317,,,,,,,,,,,,,, -157000,2.17138,3.396838,,,,,,,,,,,,,, -157100,2.3285058,2.888861,,,,,,,,,,,,,, -157200,2.115873,2.9076009,,,,,,,,,,,,,, -157300,2.3539267,2.849902,,,,,,,,,,,,,, -157400,2.0827062,3.1014698,,,,,,,,,,,,,, -157436,,,0.87367182970047,0.6964795589447021,0.7759000062942505,1.108446717262268,50000.0,0.6623000502586365,1.684056282043457,10000.0,71458.70658421516,77315.83238148689,71458.70658421516,5841.424285888672,7.40127420425415,0.0 -157500,2.2545135,2.7813575,,,,,,,,,,,,,, -157600,2.1737387,3.1145253,,,,,,,,,,,,,, -157700,2.3222997,2.81071,,,,,,,,,,,,,, -157800,2.1419215,2.8537972,,,,,,,,,,,,,, -157900,2.2930064,2.9929967,,,,,,,,,,,,,, -158000,2.3996584,2.955888,,,,,,,,,,,,,, -158100,2.2774985,2.9109457,,,,,,,,,,,,,, -158200,2.2249014,2.9731393,,,,,,,,,,,,,, -158300,2.262205,3.0269594,,,,,,,,,,,,,, -158361,,,0.8794335722923279,0.675494372844696,0.7756199836730957,1.0992703437805176,50000.0,0.659000039100647,1.6712520122528076,10000.0,71878.72174358368,77770.12656998634,71878.72174358368,5875.608007192612,7.447989463806152,0.0 -158400,2.2857442,3.0950174,,,,,,,,,,,,,, -158500,2.278481,2.8002338,,,,,,,,,,,,,, -158600,2.3086019,2.8665586,,,,,,,,,,,,,, -158700,2.1507947,3.1818314,,,,,,,,,,,,,, -158800,2.3602824,3.3586297,,,,,,,,,,,,,, -158900,2.2653725,2.794571,,,,,,,,,,,,,, -159000,2.3623044,3.6825635,,,,,,,,,,,,,, -159100,2.349705,3.8366146,,,,,,,,,,,,,, -159200,2.4667323,3.8849714,,,,,,,,,,,,,, -159284,,,0.8847265243530273,0.6685106754302979,0.7770799994468689,1.111339449882507,50000.0,0.6622000336647034,1.6857061386108398,10000.0,72298.6612625122,78227.87518262863,72298.6612625122,5913.313067674637,7.503507614135742,0.0 -159300,2.2664034,2.7994843,,,,,,,,,,,,,, -159400,2.4647062,3.9610164,,,,,,,,,,,,,, -159500,2.216067,3.1479487,,,,,,,,,,,,,, -159600,2.2435665,2.7500286,,,,,,,,,,,,,, -159700,2.357574,2.9765375,,,,,,,,,,,,,, -159800,2.2509105,2.7945325,,,,,,,,,,,,,, -159900,2.3095026,2.8633235,,,,,,,,,,,,,, -160000,2.508447,2.943634,,,,,,,,,,,,,, -160100,2.3521156,2.8146112,,,,,,,,,,,,,, -160200,2.3001606,3.0891256,,,,,,,,,,,,,, -160211,,,0.8788671493530273,0.681140124797821,0.7768399715423584,1.102988362312317,50000.0,0.6640000343322754,1.670935869216919,10000.0,72718.89306354523,78685.8155503273,72718.89306354523,5950.919641017914,7.556835651397705,0.0 -160300,2.3496335,2.8460941,,,,,,,,,,,,,, -160400,2.7159426,2.8607655,,,,,,,,,,,,,, -160500,2.5340526,3.8423383,,,,,,,,,,,,,, -160600,2.27536,3.8426263,,,,,,,,,,,,,, -160700,3.7269213,3.6039646,,,,,,,,,,,,,, -160800,2.775683,4.0304093,,,,,,,,,,,,,, -160900,2.8118365,4.0863547,,,,,,,,,,,,,, -161000,2.4290087,3.0864372,,,,,,,,,,,,,, -161100,2.3390222,2.8332527,,,,,,,,,,,,,, -161137,,,0.8810937404632568,0.6635440587997437,0.7783399820327759,1.0918289422988892,50000.0,0.663800060749054,1.6655369997024536,10000.0,73138.84251356125,79141.17946362495,73138.84251356125,5986.23615694046,7.605699300765991,0.0 -161200,2.424767,3.6501095,,,,,,,,,,,,,, -161300,2.3497922,2.8615007,,,,,,,,,,,,,, -161400,2.3191957,2.8553295,,,,,,,,,,,,,, -161500,2.3119593,3.077581,,,,,,,,,,,,,, -161600,2.399961,2.7924612,,,,,,,,,,,,,, -161700,2.2608519,3.1739488,,,,,,,,,,,,,, -161800,3.430387,4.3010163,,,,,,,,,,,,,, -161900,2.4341066,2.9896703,,,,,,,,,,,,,, -162000,2.4156277,3.3124793,,,,,,,,,,,,,, -162061,,,0.885546863079071,0.6447293162345886,0.7795199751853943,1.0828590393066406,50000.0,0.6646000146865845,1.662644624710083,10000.0,73558.78927731514,79598.76367640495,73558.78927731514,6023.772705554962,7.658042430877685,0.0 -162100,2.525978,3.904363,,,,,,,,,,,,,, -162200,2.3504205,2.8964787,,,,,,,,,,,,,, -162300,2.3956501,2.853299,,,,,,,,,,,,,, -162400,2.3925962,2.7635584,,,,,,,,,,,,,, -162500,2.3293664,2.7964506,,,,,,,,,,,,,, -162600,2.892248,3.5332031,,,,,,,,,,,,,, -162700,2.431116,2.9657617,,,,,,,,,,,,,, -162800,2.2886484,2.7980337,,,,,,,,,,,,,, -162900,3.3365407,4.2664146,,,,,,,,,,,,,, -162987,,,0.8820117115974426,0.656156063079834,0.7808600068092346,1.076492428779602,50000.0,0.666700005531311,1.6550078392028809,10000.0,73978.82945775986,80057.60511755943,73978.82945775986,6062.477011442184,7.70618200302124,0.0 -163000,2.6083858,3.1971245,,,,,,,,,,,,,, -163100,2.274512,3.5848699,,,,,,,,,,,,,, -163200,2.5546072,2.8830457,,,,,,,,,,,,,, -163300,2.482888,2.8106344,,,,,,,,,,,,,, -163400,2.5388825,2.8357644,,,,,,,,,,,,,, -163500,2.3114815,2.770949,,,,,,,,,,,,,, -163600,2.2819524,2.759707,,,,,,,,,,,,,, -163700,3.0977895,4.1644773,,,,,,,,,,,,,, -163800,2.5192833,3.4484873,,,,,,,,,,,,,, -163900,2.3411078,2.832741,,,,,,,,,,,,,, -163914,,,0.8850781321525574,0.6661824584007263,0.780519962310791,1.0905977487564087,50000.0,0.6621000170707703,1.6676348447799685,10000.0,74399.09590768814,80515.47159552574,74399.09590768814,6099.97262597084,7.761868000030518,0.0 -164000,2.3112605,2.7199879,,,,,,,,,,,,,, -164100,2.2791018,2.8552556,,,,,,,,,,,,,, -164200,2.4448552,2.888016,,,,,,,,,,,,,, -164300,2.305649,2.7967353,,,,,,,,,,,,,, -164400,2.9127536,4.141715,,,,,,,,,,,,,, -164500,2.4244006,2.8663108,,,,,,,,,,,,,, -164600,3.0495827,4.1877427,,,,,,,,,,,,,, -164700,2.2564676,2.8177521,,,,,,,,,,,,,, -164800,2.8844087,3.8490262,,,,,,,,,,,,,, -164840,,,0.8875976204872131,0.6562318205833435,0.7809399962425232,1.0894443988800049,50000.0,0.6657000184059143,1.671615719795227,10000.0,74819.25520539284,80973.01725029945,74819.25520539284,6137.260036468506,7.812313795089722,0.0 -164900,2.3695,2.829602,,,,,,,,,,,,,, -165000,2.691974,3.6554925,,,,,,,,,,,,,, -165100,2.4180613,2.8339748,,,,,,,,,,,,,, -165200,2.3836708,2.8025553,,,,,,,,,,,,,, -165300,2.499592,3.840695,,,,,,,,,,,,,, -165400,2.5171752,3.3637183,,,,,,,,,,,,,, -165500,2.4436402,3.022719,,,,,,,,,,,,,, -165600,2.8164077,3.1120665,,,,,,,,,,,,,, -165700,4.0819325,4.238885,,,,,,,,,,,,,, -165766,,,0.8836132884025574,0.6717751026153564,0.7789999842643738,1.0890315771102903,50000.0,0.663800060749054,1.661457896232605,10000.0,75239.2183611393,81428.55065131187,75239.2183611393,6172.728312015533,7.865967750549316,0.0 -165800,2.3904324,3.4376748,,,,,,,,,,,,,, -165900,2.4427984,2.8462584,,,,,,,,,,,,,, -166000,2.8777921,4.082359,,,,,,,,,,,,,, -166100,2.2808008,2.941553,,,,,,,,,,,,,, -166200,3.1271918,3.854386,,,,,,,,,,,,,, -166300,2.5079412,2.8127778,,,,,,,,,,,,,, -166400,2.5956733,3.5753198,,,,,,,,,,,,,, -166500,2.677073,3.8681948,,,,,,,,,,,,,, -166600,3.0536275,3.9935088,,,,,,,,,,,,,, -166692,,,0.8857030868530273,0.6525242328643799,0.7821799516677856,1.081518054008484,50000.0,0.6668000221252441,1.6559531688690186,10000.0,75659.28298974037,81883.62391614914,75659.28298974037,6207.636021375656,7.91825795173645,0.0 -166700,2.3086364,2.8119454,,,,,,,,,,,,,, -166800,3.180475,4.2799306,,,,,,,,,,,,,, -166900,2.4855387,2.8439152,,,,,,,,,,,,,, -167000,2.4503088,2.7440066,,,,,,,,,,,,,, -167100,2.6113863,3.6981814,,,,,,,,,,,,,, -167200,2.3835835,2.7546387,,,,,,,,,,,,,, -167300,2.5545204,2.826694,,,,,,,,,,,,,, -167400,2.5349307,2.7381022,,,,,,,,,,,,,, -167500,2.596671,3.1105826,,,,,,,,,,,,,, -167600,2.4061837,2.884894,,,,,,,,,,,,,, -167619,,,0.8876562118530273,0.6507769823074341,0.7819199562072754,1.0878478288650513,50000.0,0.6665000319480896,1.6662715673446655,10000.0,76079.30256128311,82338.82419419289,76079.30256128311,6242.713820695877,7.97198224067688,0.0 -167700,3.3069446,4.164506,,,,,,,,,,,,,, -167800,2.633469,2.80235,,,,,,,,,,,,,, -167900,2.4844906,3.442528,,,,,,,,,,,,,, -168000,2.5953405,2.7847779,,,,,,,,,,,,,, -168100,2.3907104,2.8012533,,,,,,,,,,,,,, -168200,2.4185655,2.7965448,,,,,,,,,,,,,, -168300,2.3514373,2.9899569,,,,,,,,,,,,,, -168400,2.7113655,3.1207428,,,,,,,,,,,,,, -168500,2.1984138,2.7565513,,,,,,,,,,,,,, -168545,,,0.8912890553474426,0.6282624006271362,0.7827799916267395,1.078118920326233,50000.0,0.6678000092506409,1.6539418697357178,10000.0,76499.28124332428,82792.69729566574,76499.28124332428,6276.506313562393,8.024755239486694,0.0 -168600,2.4339666,2.9543507,,,,,,,,,,,,,, -168700,2.3851287,2.7454243,,,,,,,,,,,,,, -168800,2.828363,3.65323,,,,,,,,,,,,,, -168900,3.5638013,4.2109175,,,,,,,,,,,,,, -169000,2.642706,3.794094,,,,,,,,,,,,,, -169100,2.6014874,2.749109,,,,,,,,,,,,,, -169200,2.378232,2.7746015,,,,,,,,,,,,,, -169300,2.5666687,3.4229333,,,,,,,,,,,,,, -169400,2.4511914,2.7264438,,,,,,,,,,,,,, -169472,,,0.8900976181030273,0.6365549564361572,0.7829799652099609,1.0776472091674805,50000.0,0.6668000221252441,1.6524428129196167,10000.0,76919.63348913193,83247.47095322609,76919.63348913193,6310.818851947784,8.084598779678345,0.0 -169500,2.6802397,3.822583,,,,,,,,,,,,,, -169600,2.3128355,3.0922208,,,,,,,,,,,,,, -169700,2.3418863,2.940598,,,,,,,,,,,,,, -169800,2.4431717,2.8503428,,,,,,,,,,,,,, -169900,2.5532622,2.7310028,,,,,,,,,,,,,, -170000,2.5540836,3.0569158,,,,,,,,,,,,,, -170100,2.6491263,2.7975879,,,,,,,,,,,,,, -170200,2.4909263,2.7323432,,,,,,,,,,,,,, -170300,2.3509438,2.7666476,,,,,,,,,,,,,, -170400,,,0.8862695097923279,0.6478437781333923,0.7828199863433838,1.084557294845581,50000.0,0.6687000393867493,1.6623047590255735,10000.0,77339.8290321827,83705.39258909225,77339.8290321827,6348.446104764938,8.135056495666504,0.0 -170400,2.450406,2.7484744,,,,,,,,,,,,,, -170500,2.9578195,4.0966754,,,,,,,,,,,,,, -170600,2.5858939,2.786906,,,,,,,,,,,,,, -170700,2.5022779,2.733825,,,,,,,,,,,,,, -170800,2.579488,3.0744135,,,,,,,,,,,,,, -170804,,,,,,,,,,,77520.33187580109,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 87cbf402c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -27.33039879798889,0.0,36.198220014572144,1,0,36.198220014572144,0.0010000000474974,6.907756805419922,10000,63.52872085571289,0.0011328124674037,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -64.24186062812805,0.019777774810791,456.5025894641876,861,0,456.5025894641876,0.0112000005319714,6.476788520812988,10000,520.8086996078491,0.0126953125,6.42455530166626,0.0141799999400973,6.434906482696533,50000 -99.12552428245544,0.0505821704864501,876.7815659046173,1784,0,876.7815659046173,0.0274000018835067,5.98668384552002,10000,976.0506365299224,0.0383203104138374,5.833907127380371,0.0350199975073337,5.863661289215088,50000 -130.37237691879272,0.0803611278533935,1296.93590593338,2708,0,1296.93590593338,0.0492000021040439,5.628585338592529,10000,1427.5301036834717,0.0665039047598838,5.404107093811035,0.0612799972295761,5.447628498077393,50000 -167.9686098098755,0.107553482055664,1716.8910706043243,3629,0,1716.8910706043243,0.0685000047087669,5.34055233001709,10000,1885.157268285752,0.0964843705296516,5.060458660125732,0.089819997549057,5.1000800132751465,50000 -204.1009650230408,0.1321187019348144,2136.898129463196,4551,0,2136.898129463196,0.0927000045776367,5.04966402053833,10000,2341.369452238083,0.1328320354223251,4.688483715057373,0.1245999932289123,4.738650321960449,50000 -241.13143801689148,0.1552557945251464,2557.2378079891205,5471,0,2557.2378079891205,0.1232000067830085,4.732566833496094,10000,2798.8110077381134,0.1771679669618606,4.302088737487793,0.161540001630783,4.381013870239258,50000 -280.170841217041,0.1808259487152099,2977.4658873081207,6394,0,2977.4658873081207,0.1492000073194503,4.503682613372803,10000,3258.1526939868927,0.2212109267711639,3.972068309783936,0.1992399990558624,4.102482318878174,50000 -316.2496666908264,0.205230712890625,3397.741530895233,7317,0,3397.741530895233,0.1853000074625015,4.179815769195557,10000,3714.579946756363,0.2605859339237213,3.638662576675415,0.2458799928426742,3.736096858978272,50000 -352.8280653953552,0.2301778793334961,3817.889068841934,8240,0,3817.889068841934,0.2107000052928924,4.018807411193848,10000,4171.378915309906,0.2956250011920929,3.4536590576171875,0.2745999991893768,3.559567213058472,50000 -390.5762298107147,0.2563011646270752,4238.194683074951,9165,0,4238.194683074951,0.2387000173330307,3.821212768554688,10000,4629.507674217224,0.3382031321525574,3.15322208404541,0.3057200014591217,3.3277082443237305,50000 -425.8897521495819,0.2844047546386719,4658.40030002594,10090,0,4658.40030002594,0.2547000050544739,3.6767055988311768,10000,5085.103245258331,0.3600195348262787,3.029609203338623,0.3360999822616577,3.154364824295044,50000 -457.8882927894592,0.3195030689239502,5078.3580095767975,11012,0,5078.3580095767975,0.2780000269412994,3.559499740600586,10000,5537.1427166461945,0.3813281059265136,2.868154764175415,0.3559399843215942,3.013925552368164,50000 -495.4014315605164,0.8100032806396484,5498.060791969299,11933,0,5498.060791969299,0.2905000150203705,3.4545247554779053,10000,5994.897334814072,0.4097460806369781,2.723881721496582,0.3713999986648559,2.902348041534424,50000 -533.4520351886749,0.8349728584289551,5918.129729747772,12855,0,5918.129729747772,0.3008000254631042,3.418243408203125,10000,6453.089605808258,0.4191406071186065,2.710547924041748,0.3916800022125244,2.8580996990203857,50000 -570.6516087055206,0.8618414402008057,6338.371341228485,13780,0,6338.371341228485,0.3104000091552734,3.3097615242004395,10000,6910.605922698975,0.4387499988079071,2.571505308151245,0.4053399860858917,2.736690282821656,50000 -605.3328831195831,0.8873722553253174,6758.6344385147095,14706,0,6758.6344385147095,0.3216000199317932,3.267530918121338,10000,7365.626133203506,0.4510742127895355,2.485666036605835,0.4161399900913238,2.670121192932129,50000 -641.0547397136688,0.9128148555755616,7178.670622348785,15629,0,7178.670622348785,0.3317000269889831,3.190614938735962,10000,7821.458172559738,0.4708984196186065,2.401356935501098,0.42739999294281,2.6038594245910645,50000 -679.3933956623077,0.9427704811096193,7598.717087745666,16554,0,7598.717087745666,0.3414000272750854,3.112022399902344,10000,8279.92192864418,0.4760351479053497,2.324101448059082,0.4444599747657776,2.4943294525146484,50000 -716.8324489593506,0.9697284698486328,8019.005577802658,17481,0,8019.005577802658,0.3461000025272369,3.0988028049468994,10000,8737.72558760643,0.48291015625,2.2987401485443115,0.4461599886417389,2.4863901138305664,50000 -754.4993402957916,0.9993364810943604,8439.072633743286,18405,0,8439.072633743286,0.3519000113010406,3.06425142288208,10000,9195.537045240402,0.5003905892372131,2.1953961849212646,0.449099987745285,2.455754041671753,50000 -790.1128647327423,1.0250554084777832,8859.13781619072,19326,0,8859.13781619072,0.3622000217437744,2.999075174331665,10000,9651.289949893951,0.4957421720027923,2.2166545391082764,0.4653599858283996,2.381878137588501,50000 -826.458952665329,1.0508620738983154,9279.222533941267,20251,0,9279.222533941267,0.3603000044822693,3.0382838249206543,10000,10107.794107198715,0.494433581829071,2.2306392192840576,0.4623000025749206,2.414182186126709,50000 -865.185097694397,1.0772264003753662,9699.259632349014,21178,0,9699.259632349014,0.3705000281333923,2.947197437286377,10000,10566.632721424105,0.5224609375,2.102747201919556,0.4746599793434143,2.3242554664611816,50000 -901.9166934490204,1.106586217880249,10119.803411722183,22101,0,10119.803411722183,0.3824000060558319,2.909083604812622,10000,11023.986797571182,0.5175585746765137,2.128702402114868,0.4822399914264679,2.302961826324463,50000 -935.33984375,1.1371748447418213,10539.75125527382,23023,0,10539.75125527382,0.3818000257015228,2.911873340606689,10000,11477.437187194824,0.5186718702316284,2.1031439304351807,0.4840799868106842,2.2893989086151123,50000 -972.4291639328004,1.1649210453033447,10959.968374729156,23947,0,10959.968374729156,0.3876000046730041,2.855060577392578,10000,11934.81978225708,0.5406249761581421,2.0112240314483643,0.4944999814033508,2.2325828075408936,50000 -1007.3695442676544,1.192338228225708,11380.2715177536,24875,0,11380.2715177536,0.3956000208854675,2.823624849319458,10000,12390.1386282444,0.5497460961341858,1.9716825485229488,0.5027799606323242,2.185875654220581,50000 -1045.554309129715,1.227823257446289,11800.526648044586,25798,0,11800.526648044586,0.3968000113964081,2.8089516162872314,10000,12848.661672592165,0.5396679639816284,2.003579616546631,0.5015400052070618,2.187553644180298,50000 -1081.7528929710388,1.2576560974121094,12220.824908733368,26720,0,12220.824908733368,0.398900032043457,2.789376735687256,10000,13305.236095428469,0.5524218678474426,1.9628337621688845,0.5076599717140198,2.161849021911621,50000 -1114.786839723587,1.2878186702728271,12641.06434583664,27644,0,12641.06434583664,0.4028000235557556,2.7451486587524414,10000,13758.587542057036,0.5774218440055847,1.7861711978912354,0.5199399590492249,2.085148334503174,50000 -1149.4626867771149,1.321131706237793,13061.35329389572,28568,0,13061.35329389572,0.4127000272274017,2.69579815864563,10000,14213.63385272026,0.5564843416213989,1.902564525604248,0.5218999981880188,2.0633559226989746,50000 -1185.187756061554,1.3507835865020752,13481.723677873611,29493,0,13481.723677873611,0.4073000252246856,2.698023557662964,10000,14669.807126045229,0.5679101347923279,1.855337381362915,0.5210599899291992,2.056286096572876,50000 -1219.5118567943573,1.3829002380371094,13902.074913978577,30418,0,13902.074913978577,0.4095000326633453,2.7296950817108154,10000,15124.563690185549,0.5766796469688416,1.8442224264144893,0.5261200070381165,2.0872247219085693,50000 -1259.265320301056,1.4137554168701172,14322.042706489565,31342,0,14322.042706489565,0.4130000174045563,2.724041700363159,10000,15584.364404439926,0.5602929592132568,1.9060845375061035,0.5257999897003174,2.082095146179199,50000 -1291.5704834461212,1.4477450847625732,14742.00146341324,32267,0,14742.00146341324,0.4201000332832336,2.6634790897369385,10000,16036.710843086244,0.5718163847923279,1.8387211561203003,0.5347399711608887,2.0247344970703125,50000 -1326.4956283569336,1.4855966567993164,15162.095947265623,33192,0,15162.095947265623,0.4217000305652618,2.6554317474365234,10000,16491.816769123077,0.5762890577316284,1.802590131759644,0.5324400067329407,2.0144195556640625,50000 -1362.09840965271,1.5149762630462646,15582.160798072817,34115,0,15582.160798072817,0.424200028181076,2.650291204452514,10000,16947.562863588333,0.5796874761581421,1.8108338117599487,0.533840000629425,2.0201752185821533,50000 -1398.3604464530945,1.5435802936553955,16002.27912735939,35038,0,16002.27912735939,0.4234000146389007,2.645028591156006,10000,17404.02609181404,0.5776171684265137,1.8064240217208865,0.5378199815750122,1.9878385066986084,50000 -1437.056425333023,1.5760712623596191,16422.558165311813,35963,0,16422.558165311813,0.4192000329494476,2.647194623947144,10000,17863.082113027573,0.5839062333106995,1.7839018106460571,0.5367000102996826,1.999006986618042,50000 -1474.281730890274,1.6072852611541748,16842.82310938835,36889,0,16842.82310938835,0.4297000169754028,2.5912978649139404,10000,18320.652250528336,0.6082812547683716,1.6362735033035278,0.5448200106620789,1.936899781227112,50000 -1512.0637485980988,1.636359453201294,17262.928169488907,37814,0,17262.928169488907,0.4208000302314758,2.656367540359497,10000,18778.617341041565,0.5752734541893005,1.8207134008407595,0.5349000096321106,2.0128426551818848,50000 -1548.39102768898,1.6686515808105469,17683.129102230072,38739,0,17683.129102230072,0.4314000308513641,2.607285737991333,10000,19235.227719545364,0.5896679759025574,1.7389613389968872,0.5483999848365784,1.930012464523316,50000 -1585.63733625412,1.6978001594543457,18103.296102762222,39665,0,18103.296102762222,0.4336000084877014,2.5970046520233154,10000,19692.71918320656,0.6051952838897705,1.6787258386611938,0.5502399802207947,1.9336005449295044,50000 -1621.870540380478,1.7322437763214111,18523.21506094933,40591,0,18523.21506094933,0.4307000339031219,2.6018872261047363,10000,20148.956298351288,0.5888671875,1.7536296844482422,0.5473399758338928,1.943955898284912,50000 -1659.0322132110596,1.7727165222167969,18943.480389356613,41515,0,18943.480389356613,0.4418000280857086,2.548300504684448,10000,20606.473484039307,0.5938476324081421,1.7254453897476196,0.5546000003814697,1.912533044815064,50000 -1694.7046456336975,1.80218505859375,19363.40782856941,42440,0,19363.40782856941,0.4280000329017639,2.6177022457122803,10000,21062.1515891552,0.5986718535423279,1.7187060117721558,0.5518999695777893,1.947245478630066,50000 -1731.7015480995178,1.8368933200836184,19783.71187520027,43366,0,19783.71187520027,0.4452000260353088,2.529025793075561,10000,21519.535794973373,0.6058202981948853,1.6716299057006836,0.5638200044631958,1.8678628206253047,50000 -1766.042251586914,1.8706142902374268,20204.07502055168,44292,0,20204.07502055168,0.4427000284194946,2.543543100357056,10000,21974.32190656662,0.6031640768051147,1.6963953971862793,0.562279999256134,1.889854431152344,50000 -1802.985370874405,1.9021704196929927,20624.01596546173,45216,0,20624.01596546173,0.4473000168800354,2.5134053230285645,10000,22431.28592157364,0.6105663776397705,1.6492289304733276,0.564579963684082,1.8600382804870603,50000 -1841.3757576942444,1.9337167739868164,21044.385004997253,46141,0,21044.385004997253,0.4474000334739685,2.504635572433472,10000,22890.125202655792,0.6305859088897705,1.5546183586120603,0.5636999607086182,1.8670766353607176,50000 -1876.8982713222504,1.9686899185180664,21464.45929455757,47066,0,21464.45929455757,0.4488000273704529,2.477980613708496,10000,23345.80527472496,0.60595703125,1.6388249397277832,0.5649600028991699,1.8310381174087524,50000 -1912.115308046341,2.0009567737579346,21884.46242260933,47990,0,21884.46242260933,0.4502000212669372,2.5234055519104004,10000,23801.1062412262,0.6087695360183716,1.669023036956787,0.5663599967956543,1.8735897541046145,50000 -1946.673007249832,2.032533407211304,22304.51851439476,48913,0,22304.51851439476,0.4472000300884247,2.5053343772888184,10000,24255.80009675026,0.6250976324081421,1.5893003940582275,0.5663999915122986,1.850501537322998,50000 -1986.1257948875427,2.0658884048461914,22724.775852918625,49839,0,22724.775852918625,0.4528000354766845,2.4803237915039062,10000,24715.59235072136,0.6113671660423279,1.6454654932022097,0.5694599747657776,1.8334907293319704,50000 -2020.339570283889,2.098045825958252,23144.970401525497,50765,0,23144.970401525497,0.4515000283718109,2.4958276748657227,10000,25170.08160185814,0.6122851371765137,1.6380223035812378,0.5708400011062622,1.8340604305267327,50000 -2057.0334384441376,2.137653350830078,23564.94721007347,51688,0,23564.94721007347,0.4626000225543976,2.4470291137695312,10000,25626.840389966965,0.6258788704872131,1.5507782697677612,0.5779600143432617,1.7888613939285278,50000 -2094.8600878715515,2.1691770553588867,23985.300182819366,52613,0,23985.300182819366,0.4631000161170959,2.449575424194336,10000,26085.101460695267,0.6236132383346558,1.5819681882858276,0.5786799788475037,1.7916535139083862,50000 -2127.224277496338,2.20285701751709,24405.514179944992,53538,0,24405.514179944992,0.4600000083446502,2.4227559566497803,10000,26537.762898921967,0.6176171898841858,1.5891666412353516,0.5787599682807922,1.77160382270813,50000 -2163.457891225815,2.23595929145813,24825.61731696129,54464,0,24825.61731696129,0.4535000324249267,2.493711471557617,10000,26994.1825401783,0.6196874976158142,1.635029435157776,0.5768199563026428,1.8407946825027464,50000 -2198.990670442581,2.656230926513672,25245.34061932564,55384,0,25245.34061932564,0.4604000151157379,2.4601993560791016,10000,27449.90674352646,0.6449999809265137,1.504424214363098,0.5765599608421326,1.8060835599899288,50000 -2239.8955862522125,2.6976802349090576,25665.65577197075,56309,0,25665.65577197075,0.4602000117301941,2.413660049438477,10000,27911.216849565502,0.621777355670929,1.5656689405441284,0.5817599892616272,1.751799702644348,50000 -2277.402089357376,2.7310519218444824,26086.07477974892,57234,0,26086.07477974892,0.4628000259399414,2.429422378540039,10000,28369.223430871964,0.6290820240974426,1.5716264247894287,0.5814799666404724,1.7926048040390017,50000 -2314.5219326019287,2.77089786529541,26506.14146876335,58159,0,26506.14146876335,0.4617000222206116,2.406198501586914,10000,28826.498693466187,0.6390234231948853,1.509754657745361,0.5824999809265137,1.7665115594863892,50000 -2348.078119754792,2.8062548637390137,26926.27724289894,59083,0,26926.27724289894,0.4697000086307525,2.399991273880005,10000,29280.27455854416,0.6219140291213989,1.569996356964111,0.5807999968528748,1.7605829238891602,50000 -2383.9525430202484,2.843522310256958,27346.266345262527,60007,0,27346.266345262527,0.4677000343799591,2.3942172527313232,10000,29736.22474217415,0.6282616853713989,1.5407322645187378,0.5867399573326111,1.7369784116744995,50000 -2422.922830820084,2.881842613220215,27766.477121591568,60933,0,27766.477121591568,0.4727000296115875,2.403726816177368,10000,30195.49317908287,0.6407226324081421,1.5198386907577517,0.5875200033187866,1.75584077835083,50000 -2457.7355239391327,2.923898696899414,28186.81658530236,61858,0,28186.81658530236,0.4716000258922577,2.3759047985076904,10000,30650.73554301262,0.6363281011581421,1.5116547346115112,0.5927199721336365,1.7103793621063232,50000 -2492.127090215683,2.957786798477173,28607.062667131424,62782,0,28607.062667131424,0.4659000337123871,2.4069180488586426,10000,31105.45560526848,0.6325976252555847,1.5511209964752195,0.5876799821853638,1.7578150033950806,50000 -2530.325961828232,3.003383159637451,29027.318502426147,63709,0,29027.318502426147,0.4752000272274017,2.3501551151275635,10000,31564.00478863716,0.6412500143051147,1.4811878204345703,0.594819962978363,1.6979410648345947,50000 -2568.0701701641083,3.045332431793213,29447.26177978516,64633,0,29447.26177978516,0.4707000255584717,2.37775993347168,10000,32021.78300929069,0.6591210961341858,1.411446452140808,0.5887599587440491,1.7354134321212769,50000 -2604.7873711586,3.084872245788574,29867.51460146904,65560,0,29867.51460146904,0.4760000109672546,2.3619251251220703,10000,32478.842272996902,0.6402148008346558,1.5128191709518433,0.5958399772644043,1.723629355430603,50000 -2640.4960482120514,3.11910343170166,30287.78778719902,66486,0,30287.78778719902,0.4793000221252441,2.334697723388672,10000,32934.90845179558,0.6462304592132568,1.4637900590896606,0.6012399792671204,1.669540286064148,50000 -2676.85613656044,3.155380249023437,30708.07630681992,67413,0,30708.07630681992,0.4799000322818756,2.361106634140014,10000,33391.64254951477,0.6551952958106995,1.444200873374939,0.6001799702644348,1.7015769481658936,50000 -2713.703197956085,3.194664239883423,31128.145755052567,68335,0,31128.145755052567,0.4817000329494476,2.323792695999145,10000,33848.646995306015,0.6418554782867432,1.4940626621246338,0.6005799770355225,1.6846686601638794,50000 -2749.776581287384,3.231003999710083,31548.4422082901,69259,0,31548.4422082901,0.4778000116348266,2.381270170211792,10000,34305.10170960426,0.6425195336341858,1.519197940826416,0.5944199562072754,1.7298539876937866,50000 -2783.305285215378,3.269001007080078,31968.580335378647,70184,0,31968.580335378647,0.4826000332832336,2.329708337783813,10000,34758.855362176895,0.6552343368530273,1.4339993000030518,0.6019399762153625,1.6860952377319336,50000 -2818.3009536266327,3.307988405227661,32388.90991282463,71111,0,32388.90991282463,0.4850000143051147,2.3340394496917725,10000,35214.268662929535,0.6490820050239563,1.4559168815612793,0.6016600131988525,1.6674160957336426,50000 -2853.14493060112,3.345659017562866,32809.21870470047,72037,0,32809.21870470047,0.4869000315666199,2.315235376358032,10000,35669.50822305679,0.6522070169448853,1.439338207244873,0.6070799827575684,1.6486047506332395,50000 -2890.5099818706512,3.38305139541626,33229.53257513046,72961,0,33229.53257513046,0.4834000170230865,2.3369715213775635,10000,36127.273602962494,0.6541991829872131,1.4503023624420166,0.6034799814224243,1.6737196445465088,50000 -2926.0756623744965,3.4222118854522705,33649.60353899002,73887,0,33649.60353899002,0.4837000370025635,2.336214303970337,10000,36582.99809360504,0.6709960699081421,1.3740272521972656,0.6000399589538574,1.691603183746338,50000 -2960.50363445282,3.460890054702759,34069.536211013794,74811,0,34069.536211013794,0.488500028848648,2.28943419456482,10000,37037.44539427757,0.6478906273841858,1.458599090576172,0.6050599813461304,1.654366970062256,50000 -2997.275089740753,3.4971187114715576,34489.581547021866,75734,0,34489.581547021866,0.4951000213623047,2.2635269165039062,10000,37494.34624528885,0.6602148413658142,1.4000751972198486,0.6098600029945374,1.6153631210327148,50000 -3035.50843667984,3.5318963527679443,34909.62964272499,76659,0,34909.62964272499,0.4915000200271606,2.289534568786621,10000,37952.71143436432,0.6640429496765137,1.389413833618164,0.6104399561882019,1.645629644393921,50000 -3071.462685823441,3.570381164550781,35329.89449548721,77586,0,35329.89449548721,0.4926000237464905,2.264949083328247,10000,38409.01891493797,0.6577734351158142,1.4027554988861084,0.6133599877357483,1.6025428771972656,50000 -3106.437283039093,3.609963178634644,35749.847000837326,78511,0,35749.847000837326,0.4919000267982483,2.274023532867432,10000,38864.033963918686,0.6576171517372131,1.4034545421600342,0.6142399907112122,1.60440194606781,50000 -3139.2335624694824,3.6466240882873535,36170.04280281067,79436,0,36170.04280281067,0.4887000322341919,2.2839529514312744,10000,39317.1111471653,0.6655468344688416,1.3878076076507568,0.6120399832725525,1.6240185499191284,50000 -3175.607802391052,3.685581922531128,36590.08802413941,80360,0,36590.08802413941,0.4982000291347503,2.274289131164551,10000,39773.61768245697,0.6602929830551147,1.4149186611175537,0.6128799915313721,1.6237813234329224,50000 -3212.750938415528,3.723136425018311,37010.07461738586,81284,0,37010.07461738586,0.4914000332355499,2.2799715995788574,10000,40230.83346366882,0.6623827815055847,1.4213883876800537,0.6172999739646912,1.631671667098999,50000 -3248.683898210525,3.760983467102051,37430.19033193588,82208,0,37430.19033193588,0.4945000112056732,2.2693800926208496,10000,40686.96852493286,0.6692187190055847,1.3704071044921875,0.6172400116920471,1.5985249280929563,50000 -3286.08683013916,3.796276569366455,37850.23502016068,83133,0,37850.23502016068,0.4952000379562378,2.2838222980499268,10000,41144.50082588196,0.6871874928474426,1.3113399744033811,0.6131600141525269,1.6342535018920898,50000 -3322.1507127285004,3.83987021446228,38270.56211447716,84057,0,38270.56211447716,0.5052000284194946,2.2316551208496094,10000,41600.98387527466,0.6665624976158142,1.3889672756195068,0.6204999685287476,1.596768140792847,50000 -3357.9229278564453,3.876704692840576,38690.58512282372,84982,0,38690.58512282372,0.5042999982833862,2.21973204612732,10000,42056.864844083786,0.6727148294448853,1.3522733449935913,0.6233800053596497,1.5781606435775757,50000 -3393.9766433238983,3.9189043045043945,39110.57365632057,85907,0,39110.57365632057,0.4939000308513641,2.280888557434082,10000,42512.99793553352,0.6804882884025574,1.3314138650894165,0.6174600124359131,1.6178598403930664,50000 -3430.34137749672,3.9617397785186768,39530.89902305603,86832,0,39530.89902305603,0.5002000331878662,2.2396240234375,10000,42969.77918601036,0.669238269329071,1.371837854385376,0.6244199872016907,1.578890323638916,50000 -3465.7896132469177,4.005863666534424,39951.18798828125,87751,0,39951.18798828125,0.5024000406265259,2.2063982486724854,10000,43425.609325408936,0.674023449420929,1.3504189252853394,0.6277799606323242,1.565969944000244,50000 -3500.5497431755066,4.045783042907715,40371.13100576401,88675,0,40371.13100576401,0.5085000395774841,2.1928582191467285,10000,43880.40012168884,0.6822851300239563,1.2980283498764038,0.6274799704551697,1.5562463998794556,50000 -3535.699191570282,4.085117340087891,40791.40249633789,89599,0,40791.40249633789,0.5076000094413757,2.21124529838562,10000,44335.90924882889,0.6733984351158142,1.3592572212219238,0.6269800066947937,1.5658754110336304,50000 -3569.6862609386444,4.122882843017578,41211.55885767937,90524,0,41211.55885767937,0.5056000351905823,2.1989150047302246,10000,44790.13888645172,0.6755468845367432,1.3368297815322876,0.6282599568367004,1.5485241413116455,50000 -3606.986034631729,4.161207437515259,41631.89220118523,91448,0,41631.89220118523,0.513700008392334,2.1726202964782715,10000,45247.85872173309,0.6873828172683716,1.2694848775863647,0.6334599852561951,1.521888256072998,50000 -3645.071723461151,4.2016355991363525,42052.14500403404,92373,0,42052.14500403404,0.5078999996185303,2.2144458293914795,10000,45706.28652644157,0.699414074420929,1.262794017791748,0.6269800066947937,1.5701531171798706,50000 -3677.9450783729553,4.251505374908447,42472.23790502548,93298,0,42472.23790502548,0.5161000490188599,2.159391164779663,10000,46159.35181570053,0.6812304258346558,1.297718167304993,0.63646000623703,1.5049082040786743,50000 -3716.609961032867,4.289197206497192,42892.54639649391,94221,0,42892.54639649391,0.5134000182151794,2.162588119506836,10000,46618.41143655777,0.6886913776397705,1.2740516662597656,0.634880006313324,1.5176643133163452,50000 -3754.973846912384,4.333668947219849,43312.84077787399,95146,0,43312.84077787399,0.5163000226020813,2.128796100616455,10000,47077.16275238991,0.7041796445846558,1.1996102333068848,0.6363599896430969,1.4963476657867432,50000 -3792.851813316345,4.377025365829468,43733.06075167656,96071,0,43733.06075167656,0.5179000496864319,2.110060930252075,10000,47535.35248112679,0.6895898580551147,1.2599599361419678,0.6400399804115295,1.4780635833740234,50000 -3830.168391227722,4.4223480224609375,44153.27914762497,96996,0,44153.27914762497,0.5164999961853027,2.145113468170166,10000,47992.982228040695,0.690234363079071,1.2644189596176147,0.642300009727478,1.498186111450195,50000 -3866.2021346092224,4.46885085105896,44573.4611530304,97914,0,44573.4611530304,0.5138000249862671,2.157207727432251,10000,48449.29330062866,0.6972265243530273,1.2512993812561035,0.6369799971580505,1.5186965465545654,50000 -3904.7241473197937,4.510102987289429,44993.504257678986,98837,0,44993.504257678986,0.5199000239372253,2.1083343029022217,10000,48907.94924497605,0.6911327838897705,1.2608174085617063,0.6458399891853333,1.466307282447815,50000 -3944.386614084244,4.552144289016724,45413.90217757225,99760,0,45413.90217757225,0.5199000239372253,2.098806858062744,10000,49368.09958767891,0.6932421922683716,1.244040608406067,0.6432799696922302,1.4662216901779177,50000 -3982.7625353336334,4.600781679153442,45834.00345420837,100680,0,45834.00345420837,0.5200000405311584,2.113632917404175,10000,49826.67389035225,0.7011132836341858,1.218852400779724,0.6454199552536011,1.4721068143844604,50000 -4021.927656650543,4.6470115184783936,46253.99683356285,101600,0,46253.99683356285,0.522100031375885,2.124752998352051,10000,50285.926607847214,0.7016406059265137,1.2149200439453125,0.6446999907493591,1.46499764919281,50000 -4060.9536135196686,4.69239354133606,46674.343448877335,102523,0,46674.343448877335,0.5236999988555908,2.1005005836486816,10000,50745.39287304878,0.6954296827316284,1.231560230255127,0.6503599882125854,1.4419031143188477,50000 -4096.788547039032,4.73948335647583,47094.48765182495,103444,0,47094.48765182495,0.5247000455856323,2.1184866428375244,10000,51201.4682199955,0.7022265195846558,1.2212992906570437,0.6457799673080444,1.4728409051895142,50000 -4134.106735706329,4.77846884727478,47514.57798838616,104365,0,47514.57798838616,0.528700053691864,2.0928614139556885,10000,51658.96434521675,0.7180859446525574,1.164534091949463,0.6500999927520752,1.4714568853378296,50000 -4172.708475351334,4.825288534164429,47934.9159283638,105288,0,47934.9159283638,0.5332000255584717,2.0801806449890137,10000,52117.99890470505,0.6989648342132568,1.2240859270095823,0.6522799730300903,1.4324557781219482,50000 -4214.411760091782,4.873553991317749,48354.99125123024,106209,0,48354.99125123024,0.5314000248908997,2.071938753128052,10000,52579.87425208092,0.7043554782867432,1.194347858428955,0.650879979133606,1.437902331352234,50000 -4255.533785581589,4.913725852966309,48775.022804260254,107131,0,48775.022804260254,0.534000039100647,2.086843490600586,10000,53041.11679935455,0.717578113079071,1.1480196714401243,0.6521399617195129,1.4363003969192505,50000 -4297.330185413361,4.968867778778076,49194.9777302742,108050,0,49194.9777302742,0.5369000434875488,2.042796850204468,10000,53502.97209262848,0.7084179520606995,1.183990240097046,0.661359965801239,1.3978753089904783,50000 -4334.789614200592,5.015544414520264,49614.986085653305,108970,0,49614.986085653305,0.5343000292778015,2.046653509140014,10000,53960.53556919098,0.7154492139816284,1.1541640758514404,0.6593199968338013,1.3946455717086792,50000 -4371.936519861221,5.060314655303955,50035.14333152771,109892,0,50035.14333152771,0.5376999974250793,2.04797911643982,10000,54417.93356990814,0.7190625071525574,1.153494119644165,0.6587399840354919,1.4109584093093872,50000 -4411.016355514526,5.104965448379517,50455.33958125114,110814,0,50455.33958125114,0.5372000336647034,2.042734384536743,10000,54877.30311059952,0.707324206829071,1.1942484378814695,0.6566799879074097,1.4237462282180786,50000 -4449.826142311096,5.151665925979614,50875.68436551094,111735,0,50875.68436551094,0.5392000079154968,2.033914566040039,10000,55336.55248832703,0.7171288728713989,1.144237995147705,0.6643199920654297,1.3885068893432615,50000 -4489.448220968247,5.197967767715454,51295.67676925659,112656,0,51295.67676925659,0.5385000109672546,2.0181593894958496,10000,55796.26206231117,0.7238867282867432,1.1074589490890503,0.6677199602127075,1.3568204641342163,50000 -4526.354462623596,5.241725444793701,51715.61089348793,113578,0,51715.61089348793,0.5476000308990479,1.9904879331588743,10000,56253.195014476776,0.7433788776397705,1.0337588787078855,0.6680999994277954,1.3554725646972656,50000 -4565.85283613205,5.286535739898682,52135.60233712196,114501,0,52135.60233712196,0.5457000136375427,2.03826904296875,10000,56712.77788186073,0.7174609303474426,1.1560657024383545,0.6663399934768677,1.3868917226791382,50000 -4604.891656398773,5.327480792999268,52555.83881497383,115423,0,52555.83881497383,0.5384000539779663,1.9989176988601685,10000,57172.142484903336,0.7275195121765137,1.0962249040603638,0.6660999655723572,1.3563467264175415,50000 -4641.241088867188,5.377520561218262,52975.88430976868,116343,0,52975.88430976868,0.5468000173568726,1.994463562965393,10000,57628.63620185852,0.7395312190055847,1.0397402048110962,0.6725199818611145,1.3434137105941772,50000 -4680.875846385956,5.427154302597046,53396.02493786812,117267,0,53396.02493786812,0.5462000370025635,1.9877381324768064,10000,58088.51064157486,0.7267968654632568,1.0954101085662842,0.6725599765777588,1.3405753374099731,50000 -4720.6627950668335,5.471050262451172,53816.141922950745,118189,0,53816.141922950745,0.551300048828125,1.9637449979782104,10000,58548.5072760582,0.7324609160423279,1.0725865364074707,0.6753399968147278,1.3302712440490725,50000 -4757.115355014801,5.5186426639556885,54236.35823750496,119112,0,54236.35823750496,0.5527000427246094,1.9856129884719849,10000,59005.27227306366,0.7394335865974426,1.0667121410369873,0.6752600073814392,1.3475139141082764,50000 -4792.37073636055,5.564018726348877,54656.45013904572,120033,0,54656.45013904572,0.5491000413894653,1.9687131643295288,10000,59460.71386647224,0.729199230670929,1.0772337913513184,0.675279974937439,1.3255867958068848,50000 -4830.57146859169,5.6144208908081055,55076.50438141823,120955,0,55076.50438141823,0.5529000163078308,1.95231294631958,10000,59919.06721377373,0.7299999594688416,1.0698095560073853,0.6780999898910522,1.307468056678772,50000 -4867.829748630524,5.6567864418029785,55496.58257865906,121875,0,55496.58257865906,0.5550000071525574,1.928056001663208,10000,60376.49431824684,0.7435156106948853,1.0205103158950806,0.6786999702453613,1.2984049320220947,50000 -4906.302020072937,5.70208215713501,55916.64206838608,122794,0,55916.64206838608,0.5601000189781189,1.9321649074554443,10000,60835.11968517304,0.7494726181030273,1.0063884258270264,0.6796199679374695,1.3093905448913574,50000 -4942.790773868561,5.748456001281738,56336.915759801865,123715,0,56336.915759801865,0.5570999979972839,1.92024028301239,10000,61291.97744345665,0.7424414157867432,1.035762071609497,0.6828199625015259,1.2859761714935305,50000 -4983.189473390579,5.796940326690674,56757.10899662972,124636,0,56757.10899662972,0.5631000399589539,1.9321929216384888,10000,61752.66572880745,0.7409374713897705,1.054355263710022,0.6823399662971497,1.299355149269104,50000 -5024.603069782257,5.8421630859375,57177.09319901466,125557,0,57177.09319901466,0.5631999969482422,1.9148192405700684,10000,62214.15758442879,0.7582616806030273,0.9785751104354858,0.6868399977684021,1.288220763206482,50000 -5065.656978368759,5.892143726348877,57597.37649774552,126481,0,57597.37649774552,0.5613000392913818,1.9243862628936768,10000,62675.59431099892,0.744921863079071,1.0215235948562622,0.6889199614524841,1.2752596139907837,50000 -5104.507493257523,5.944280862808228,58017.32027029991,127404,0,58017.32027029991,0.572100043296814,1.8867379426956177,10000,63134.48843693733,0.7497265338897705,1.0092304944992063,0.6896399855613708,1.2700551748275757,50000 -5141.601754665375,5.991713047027588,58437.31800484657,128326,0,58437.31800484657,0.5698000192642212,1.8772536516189573,10000,63591.67713737488,0.7568749785423279,0.9682868123054504,0.6923800110816956,1.2558635473251345,50000 -5177.357530832291,6.036379098892212,58857.65893149376,129247,0,58857.65893149376,0.5692000389099121,1.8783817291259768,10000,64047.86702609062,0.7479491829872131,0.9964887499809264,0.6901599764823914,1.2488561868667605,50000 -5218.325902700424,6.086957216262817,59277.79971027374,130169,0,59277.79971027374,0.5705000162124634,1.873192310333252,10000,64509.075212717056,0.75501948595047,0.979481041431427,0.6937199831008911,1.2402454614639282,50000 -5259.966984272003,6.138098955154419,59698.05014848709,131090,0,59698.05014848709,0.5722000002861023,1.867382287979126,10000,64971.06675004959,0.7644726634025574,0.9445186257362366,0.6960799694061279,1.2283663749694824,50000 -5298.666358947754,6.187035799026489,60117.96479392052,132012,0,60117.96479392052,0.5773000121116638,1.8444945812225344,10000,65429.777752399445,0.763867199420929,0.9349223375320436,0.6986199617385864,1.2158799171447754,50000 -5336.413270950317,6.234623432159424,60538.21879982948,132935,0,60538.21879982948,0.5715000033378601,1.8723223209381104,10000,65887.87410736084,0.7568749785423279,0.9747950434684752,0.6997199654579163,1.2330219745635986,50000 -5372.358961343765,6.28567361831665,60958.368864774704,133860,0,60958.368864774704,0.5763000249862671,1.8516438007354736,10000,66344.07096099854,0.7635351419448853,0.9517484307289124,0.7021600008010864,1.2254695892333984,50000 -5407.867599010468,6.32945704460144,61378.45201802254,134783,0,61378.45201802254,0.5750000476837158,1.8422235250473025,10000,66799.75529813766,0.7740820050239563,0.883020281791687,0.6988799571990967,1.2179479598999023,50000 -5444.643156290054,6.38027548789978,61798.35517120361,135707,0,61798.35517120361,0.5833000540733337,1.828897714614868,10000,67256.5339550972,0.761914074420929,0.9550774097442628,0.6998999714851379,1.2158119678497314,50000 -5480.798250198364,6.4242448806762695,62218.55541443825,136632,0,62218.55541443825,0.5819000005722046,1.79646897315979,10000,67712.98147702217,0.7692968845367432,0.9049091935157776,0.7047799825668335,1.1906603574752808,50000 -5524.29457783699,6.470351934432983,62638.88868141174,137558,0,62638.88868141174,0.5800000429153442,1.7963061332702637,10000,68176.90643644333,0.7782226204872131,0.8680867552757263,0.7066400051116943,1.1834466457366943,50000 -5560.188486814499,6.537533760070801,63058.98669815064,138483,0,63058.98669815064,0.5848000049591064,1.7935667037963867,10000,68633.01377010345,0.768750011920929,0.9112576842308044,0.7040199637413025,1.1906100511550903,50000 -5604.57666349411,6.587302923202515,63478.89449167252,139405,0,63478.89449167252,0.585800051689148,1.7905820608139038,10000,69097.40825605392,0.775195300579071,0.8927517533302307,0.7079600095748901,1.1707615852355957,50000 -5642.757848501205,6.641981601715088,63898.79946422577,140328,0,63898.79946422577,0.5877000093460083,1.781422138214111,10000,69555.59872603416,0.7830859422683716,0.8475367426872253,0.7130999565124512,1.1560105085372925,50000 -5687.743291378021,6.696745872497559,64318.720710515976,141248,0,64318.720710515976,0.5913000106811523,1.7659611701965332,10000,70020.60980701447,0.7801562547683716,0.8689774870872498,0.7124599814414978,1.151098370552063,50000 -5725.94082069397,6.748511791229248,64738.64544630051,142170,0,64738.64544630051,0.58760005235672,1.802656173706055,10000,70478.83239912987,0.7764452695846558,0.8885748982429504,0.7087399959564209,1.1719297170639038,50000 -5764.303265571594,6.794252634048462,65158.87682819367,143093,0,65158.87682819367,0.5895000100135803,1.753939509391785,10000,70937.52128458023,0.7829882502555847,0.8484798073768616,0.7154799699783325,1.1398857831954956,50000 -5808.492747783661,6.839449644088745,65579.17691516876,144016,0,65579.17691516876,0.5962000489234924,1.7356946468353271,10000,71402.10477161407,0.7986913919448853,0.7790917158126831,0.7196399569511414,1.120820164680481,50000 -5846.021278142929,6.884845018386841,65999.29121685028,144940,0,65999.29121685028,0.5929000377655029,1.7483426332473757,10000,71859.84184217453,0.7861132621765137,0.835200846195221,0.718999981880188,1.122183084487915,50000 -5883.058493375778,6.93181300163269,66419.42792582512,145865,0,66419.42792582512,0.600600004196167,1.7311910390853882,10000,72317.11073088646,0.789746105670929,0.8230050206184387,0.7236599922180176,1.11073899269104,50000 -5922.079788208008,6.981026649475098,66839.75595474243,146791,0,66839.75595474243,0.5986000299453735,1.7203985452651978,10000,72776.55774569511,0.7936913967132568,0.7987895011901855,0.7238999605178833,1.1063193082809448,50000 -5959.401660680771,7.031947135925293,67259.88791394234,147715,0,67259.88791394234,0.6008000373840332,1.724432110786438,10000,73234.11113333702,0.7894726395606995,0.8196678757667542,0.7218999862670898,1.1062513589859009,50000 -5996.285425662994,7.0817482471466064,67680.01170182228,148639,0,67680.01170182228,0.6016000509262085,1.7002700567245483,10000,73691.21638917923,0.7953320145606995,0.7963201999664307,0.724079966545105,1.092710256576538,50000 -6035.759362697601,7.138274669647217,68099.94502019882,149563,0,68099.94502019882,0.6046000123023987,1.7144430875778198,10000,74150.72949790955,0.7990038990974426,0.7832265496253967,0.7269600033760071,1.0977861881256104,50000 -6073.976391792297,7.188690185546875,68519.99004364014,150488,0,68519.99004364014,0.6063000559806824,1.6938530206680298,10000,74609.09123158455,0.7956249713897705,0.7945720553398132,0.7290799617767334,1.082510232925415,50000 -6111.824824333191,7.236317157745361,68940.01628899574,151411,0,68940.01628899574,0.610200047492981,1.6871274709701538,10000,75067.06227970123,0.7971875071525574,0.7923197746276855,0.7300999760627747,1.0846587419509888,50000 -6153.824164628983,7.284022569656372,69360.21492862701,152336,0,69360.21492862701,0.6121000051498413,1.6767473220825195,10000,75529.35670304298,0.8050390481948853,0.7471283078193665,0.7305999994277954,1.0720800161361694,50000 -6189.108179092407,7.329882383346558,69780.14233207703,153221,0,69780.14233207703,0.6136000156402588,1.6623154878616333,10000,75984.6604912281,0.8126757740974426,0.722992479801178,0.7340399622917175,1.0540411472320557,50000 -6223.7734811306,7.37775182723999,70200.330078125,154146,0,70200.330078125,0.6131000518798828,1.6655516624450684,10000,76439.60998177528,0.8069726228713989,0.7545697093009949,0.7346999645233154,1.059226393699646,50000 -6261.632649898529,7.428417921066284,70620.54906439781,155071,0,70620.54906439781,0.6117000579833984,1.6622600555419922,10000,76897.78743886948,0.8126562237739563,0.7402761578559875,0.7357199788093567,1.0593024492263794,50000 -6297.965461015701,7.48820424079895,71040.45113134384,155973,0,71040.45113134384,0.6123000383377075,1.667306900024414,10000,77354.13040804863,0.81556636095047,0.7248516082763672,0.7346000075340271,1.0620793104171753,50000 -6335.5825090408325,7.540948152542114,71460.83744478226,156898,0,71460.83744478226,0.6170000433921814,1.6425464153289795,10000,77812.2350218296,0.8093163967132568,0.7406788468360901,0.7397199869155884,1.0419944524765017,50000 -6372.919394493103,7.591905832290649,71880.89972639084,157824,0,71880.89972639084,0.6149000525474548,1.6336321830749512,10000,78269.73344492912,0.8129101395606995,0.7125070095062256,0.7402399778366089,1.0308599472045898,50000 -6411.03583574295,7.643307447433472,72301.06321072578,158748,0,72301.06321072578,0.6207000017166138,1.616907835006714,10000,78728.11376452446,0.8221093416213989,0.6898041367530823,0.7427399754524231,1.029032826423645,50000 -6451.647526979446,7.691318511962891,72721.30015707016,159671,0,72721.30015707016,0.6194000244140625,1.619384765625,10000,79189.0581202507,0.81800776720047,0.7054724097251892,0.7444599866867065,1.0100157260894775,50000 -6487.851769685745,7.754884719848633,73141.47015810013,160594,0,73141.47015810013,0.6256000399589539,1.6140958070755005,10000,79645.54460167885,0.8192187547683716,0.6979327201843262,0.7441799640655518,1.0171470642089844,50000 -6523.876390695572,7.80646276473999,73561.80247449875,161519,0,73561.80247449875,0.626300036907196,1.5992350578308103,10000,80102.00310349464,0.8247851133346558,0.6756976842880249,0.7467600107192993,1.0093703269958496,50000 -6564.9981191158295,7.855017900466919,73982.11289167404,162446,0,73982.11289167404,0.6248000264167786,1.616760015487671,10000,80563.53231620789,0.8268163800239563,0.6841356754302979,0.7471999526023865,1.016317367553711,50000 -6602.321990013123,7.912995100021362,74402.1906940937,163371,0,74402.1906940937,0.6270000338554382,1.6025879383087158,10000,81021.04125189781,0.8298437595367432,0.6641115546226501,0.7482799887657166,0.9978412389755248,50000 -6637.561418771744,7.962857246398926,74822.27236771584,164296,0,74822.27236771584,0.6296000480651855,1.5875972509384155,10000,81476.4604113102,0.8291991949081421,0.6516792178153992,0.7512399554252625,0.986169457435608,50000 -6675.802627325058,8.023729801177979,75242.47238945961,165222,0,75242.47238945961,0.6282000541687012,1.5926567316055298,10000,81935.01135158539,0.8324413895606995,0.6385616660118103,0.7501399517059326,0.9882864356040956,50000 -6713.923867702484,8.082376480102539,75662.8382794857,166148,0,75662.8382794857,0.6331000328063965,1.5844221115112305,10000,82393.60587334633,0.8308203220367432,0.660889208316803,0.7534799575805664,0.9846858382225036,50000 -6753.028325080872,8.135641813278198,76082.94535136223,167072,0,76082.94535136223,0.6313000321388245,1.5846593379974363,10000,82852.91863751411,0.8302929401397705,0.645725667476654,0.7534199953079224,0.9803726077079772,50000 -6790.670788764954,8.188108682632446,76502.94843864441,167997,0,76502.94843864441,0.6322000026702881,1.5767680406570437,10000,83310.66582012177,0.8384179472923279,0.6208813190460205,0.7554399967193604,0.9750881791114808,50000 -6825.619375705719,8.23816442489624,76923.12551164627,168922,0,76923.12551164627,0.6335000395774841,1.5818995237350464,10000,83765.89028906822,0.8319921493530273,0.6606417894363403,0.756060004234314,0.979549527168274,50000 -6860.926365375519,8.290139198303223,77343.40371894836,169849,0,77343.40371894836,0.6391000151634216,1.561057209968567,10000,84221.57668566704,0.8330858945846558,0.6320017576217651,0.7576000094413757,0.9574471712112427,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/measurements.csv deleted file mode 100644 index 698474817..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1890 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.37300017,6.9077563,,,,,,,,,,,,,, -1,,,0.0011328124674037,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,36.198220014572144,63.52872085571289,36.198220014572144,27.33039879798889,0.0,0.0 -100,0.42597833,6.9058266,,,,,,,,,,,,,, -200,0.48652443,6.891159,,,,,,,,,,,,,, -300,0.58670694,6.8476305,,,,,,,,,,,,,, -400,0.71134,6.8066683,,,,,,,,,,,,,, -500,0.7483395,6.841069,,,,,,,,,,,,,, -600,0.88303244,6.7206297,,,,,,,,,,,,,, -700,1.2797852,6.653428,,,,,,,,,,,,,, -800,1.8315693,6.628335,,,,,,,,,,,,,, -861,,,0.0126953125,6.42455530166626,0.0141799999400973,6.434906482696533,50000.0,0.0112000005319714,6.476788520812988,10000.0,456.5025894641876,520.8086996078491,456.5025894641876,64.24186062812805,0.019777774810791,0.0 -900,1.1432458,6.50592,,,,,,,,,,,,,, -1000,2.3264136,6.4774947,,,,,,,,,,,,,, -1100,1.401816,6.5728126,,,,,,,,,,,,,, -1200,1.700348,6.4333305,,,,,,,,,,,,,, -1300,3.5189166,6.3566217,,,,,,,,,,,,,, -1400,2.0752306,6.666839,,,,,,,,,,,,,, -1500,1.5719353,6.6144657,,,,,,,,,,,,,, -1600,1.7858338,6.1199837,,,,,,,,,,,,,, -1700,1.5524312,6.3396482,,,,,,,,,,,,,, -1784,,,0.0383203104138374,5.833907127380371,0.0350199975073337,5.863661289215088,50000.0,0.0274000018835067,5.98668384552002,10000.0,876.7815659046173,976.0506365299224,876.7815659046173,99.12552428245544,0.0505821704864501,0.0 -1800,2.3561606,6.4788866,,,,,,,,,,,,,, -1900,1.9219406,6.0431094,,,,,,,,,,,,,, -2000,2.6393173,6.0248666,,,,,,,,,,,,,, -2100,1.9970801,6.056867,,,,,,,,,,,,,, -2200,1.4138345,6.662715,,,,,,,,,,,,,, -2300,1.9507396,5.932197,,,,,,,,,,,,,, -2400,2.436392,5.8687153,,,,,,,,,,,,,, -2500,1.7913704,6.3180923,,,,,,,,,,,,,, -2600,1.9546446,5.866639,,,,,,,,,,,,,, -2700,1.9884267,6.478634,,,,,,,,,,,,,, -2708,,,0.0665039047598838,5.404107093811035,0.0612799972295761,5.447628498077393,50000.0,0.0492000021040439,5.628585338592529,10000.0,1296.93590593338,1427.5301036834717,1296.93590593338,130.37237691879272,0.0803611278533935,0.0 -2800,1.9222412,6.2002573,,,,,,,,,,,,,, -2900,2.3847506,5.7565646,,,,,,,,,,,,,, -3000,1.7999232,5.7363815,,,,,,,,,,,,,, -3100,1.6850019,5.7869296,,,,,,,,,,,,,, -3200,1.6641718,6.315088,,,,,,,,,,,,,, -3300,2.0082068,5.6021786,,,,,,,,,,,,,, -3400,1.939098,5.7002716,,,,,,,,,,,,,, -3500,2.323231,5.8219953,,,,,,,,,,,,,, -3600,1.9972314,5.5628233,,,,,,,,,,,,,, -3629,,,0.0964843705296516,5.060458660125732,0.089819997549057,5.1000800132751465,50000.0,0.0685000047087669,5.34055233001709,10000.0,1716.8910706043243,1885.157268285752,1716.8910706043243,167.9686098098755,0.107553482055664,0.0 -3700,1.8179845,5.5539594,,,,,,,,,,,,,, -3800,1.8011016,5.5042276,,,,,,,,,,,,,, -3900,1.8146698,5.5123367,,,,,,,,,,,,,, -4000,2.060421,6.0464797,,,,,,,,,,,,,, -4100,1.9223084,5.917954,,,,,,,,,,,,,, -4200,1.2829739,6.081312,,,,,,,,,,,,,, -4300,1.7545686,5.7360806,,,,,,,,,,,,,, -4400,1.7143345,6.459114,,,,,,,,,,,,,, -4500,1.7890946,5.214184,,,,,,,,,,,,,, -4551,,,0.1328320354223251,4.688483715057373,0.1245999932289123,4.738650321960449,50000.0,0.0927000045776367,5.04966402053833,10000.0,2136.898129463196,2341.369452238083,2136.898129463196,204.1009650230408,0.1321187019348144,0.0 -4600,1.5500369,6.3074284,,,,,,,,,,,,,, -4700,1.4556094,6.161683,,,,,,,,,,,,,, -4800,1.801822,5.380245,,,,,,,,,,,,,, -4900,2.3461208,5.212431,,,,,,,,,,,,,, -5000,1.7890658,5.059899,,,,,,,,,,,,,, -5100,2.0211236,4.9536853,,,,,,,,,,,,,, -5200,1.955784,4.9704423,,,,,,,,,,,,,, -5300,1.9382652,5.2398696,,,,,,,,,,,,,, -5400,2.6802046,4.9539366,,,,,,,,,,,,,, -5471,,,0.1771679669618606,4.302088737487793,0.161540001630783,4.381013870239258,50000.0,0.1232000067830085,4.732566833496094,10000.0,2557.2378079891205,2798.8110077381134,2557.2378079891205,241.13143801689148,0.1552557945251464,0.0 -5500,1.7888178,5.094569,,,,,,,,,,,,,, -5600,1.8045131,4.9888196,,,,,,,,,,,,,, -5700,1.6044209,6.0550976,,,,,,,,,,,,,, -5800,2.882344,4.8557234,,,,,,,,,,,,,, -5900,1.7885385,5.391362,,,,,,,,,,,,,, -6000,1.8310385,4.882785,,,,,,,,,,,,,, -6100,1.8634506,4.730268,,,,,,,,,,,,,, -6200,1.593213,6.21535,,,,,,,,,,,,,, -6300,2.056371,4.684146,,,,,,,,,,,,,, -6394,,,0.2212109267711639,3.972068309783936,0.1992399990558624,4.102482318878174,50000.0,0.1492000073194503,4.503682613372803,10000.0,2977.4658873081207,3258.1526939868927,2977.4658873081207,280.170841217041,0.1808259487152099,0.0 -6400,2.3033197,4.641814,,,,,,,,,,,,,, -6500,1.6357594,4.9867954,,,,,,,,,,,,,, -6600,1.8984728,5.070292,,,,,,,,,,,,,, -6700,1.7262775,6.1139474,,,,,,,,,,,,,, -6800,1.7782893,4.509053,,,,,,,,,,,,,, -6900,1.8531541,4.791422,,,,,,,,,,,,,, -7000,1.7937051,4.475444,,,,,,,,,,,,,, -7100,1.9081416,6.1494217,,,,,,,,,,,,,, -7200,1.584979,5.4391303,,,,,,,,,,,,,, -7300,1.1540654,6.191502,,,,,,,,,,,,,, -7317,,,0.2605859339237213,3.638662576675415,0.2458799928426742,3.736096858978272,50000.0,0.1853000074625015,4.179815769195557,10000.0,3397.741530895233,3714.579946756363,3397.741530895233,316.2496666908264,0.205230712890625,0.0 -7400,1.4600132,6.002787,,,,,,,,,,,,,, -7500,1.673348,5.916168,,,,,,,,,,,,,, -7600,1.7499689,4.6484647,,,,,,,,,,,,,, -7700,1.7727379,4.646594,,,,,,,,,,,,,, -7800,1.4548022,6.192534,,,,,,,,,,,,,, -7900,1.8030996,4.804638,,,,,,,,,,,,,, -8000,1.6264025,4.33706,,,,,,,,,,,,,, -8100,1.9302546,4.344308,,,,,,,,,,,,,, -8200,1.9549801,4.437389,,,,,,,,,,,,,, -8240,,,0.2956250011920929,3.4536590576171875,0.2745999991893768,3.559567213058472,50000.0,0.2107000052928924,4.018807411193848,10000.0,3817.889068841934,4171.378915309906,3817.889068841934,352.8280653953552,0.2301778793334961,0.0 -8300,1.7364444,5.525644,,,,,,,,,,,,,, -8400,2.0980084,4.3559895,,,,,,,,,,,,,, -8500,1.8909012,4.6608214,,,,,,,,,,,,,, -8600,1.9261348,4.107976,,,,,,,,,,,,,, -8700,2.0196378,4.245428,,,,,,,,,,,,,, -8800,1.7549663,4.1580977,,,,,,,,,,,,,, -8900,2.868178,4.040077,,,,,,,,,,,,,, -9000,1.89685,4.1056952,,,,,,,,,,,,,, -9100,1.6331667,4.016892,,,,,,,,,,,,,, -9165,,,0.3382031321525574,3.15322208404541,0.3057200014591217,3.3277082443237305,50000.0,0.2387000173330307,3.821212768554688,10000.0,4238.194683074951,4629.507674217224,4238.194683074951,390.5762298107147,0.2563011646270752,0.0 -9200,1.6459957,4.102604,,,,,,,,,,,,,, -9300,2.0934994,4.0115824,,,,,,,,,,,,,, -9400,2.0213246,4.001642,,,,,,,,,,,,,, -9500,1.6250216,4.8347983,,,,,,,,,,,,,, -9600,1.4457774,5.0695286,,,,,,,,,,,,,, -9700,1.8549728,4.1886764,,,,,,,,,,,,,, -9800,1.4225621,5.4451475,,,,,,,,,,,,,, -9900,1.9178866,3.9132235,,,,,,,,,,,,,, -10000,1.351109,6.0125403,,,,,,,,,,,,,, -10090,,,0.3600195348262787,3.029609203338623,0.3360999822616577,3.154364824295044,50000.0,0.2547000050544739,3.6767055988311768,10000.0,4658.40030002594,5085.103245258331,4658.40030002594,425.8897521495819,0.2844047546386719,0.0 -10100,1.2737653,5.4541106,,,,,,,,,,,,,, -10200,1.7679503,4.187911,,,,,,,,,,,,,, -10300,1.2473125,5.397935,,,,,,,,,,,,,, -10400,1.3790884,4.7922125,,,,,,,,,,,,,, -10500,1.7248993,4.089899,,,,,,,,,,,,,, -10600,1.8133392,3.8666573,,,,,,,,,,,,,, -10700,1.8492883,3.7769725,,,,,,,,,,,,,, -10800,1.4210619,5.8528824,,,,,,,,,,,,,, -10900,1.5805113,3.7789664,,,,,,,,,,,,,, -11000,1.9450583,3.8324013,,,,,,,,,,,,,, -11012,,,0.3813281059265136,2.868154764175415,0.3559399843215942,3.013925552368164,50000.0,0.2780000269412994,3.559499740600586,10000.0,5078.3580095767975,5537.1427166461945,5078.3580095767975,457.8882927894592,0.3195030689239502,0.0 -11100,1.5880663,5.067219,,,,,,,,,,,,,, -11200,1.912439,3.6734433,,,,,,,,,,,,,, -11300,1.463501,5.462715,,,,,,,,,,,,,, -11400,1.7538006,3.9291508,,,,,,,,,,,,,, -11500,1.1797587,5.5322537,,,,,,,,,,,,,, -11600,1.6807449,4.5916333,,,,,,,,,,,,,, -11700,1.5401216,4.324398,,,,,,,,,,,,,, -11800,1.8063177,3.6727524,,,,,,,,,,,,,, -11900,1.748687,3.5721314,,,,,,,,,,,,,, -11933,,,0.4097460806369781,2.723881721496582,0.3713999986648559,2.902348041534424,50000.0,0.2905000150203705,3.4545247554779053,10000.0,5498.060791969299,5994.897334814072,5498.060791969299,495.4014315605164,0.8100032806396484,0.0 -12000,1.4882482,4.357831,,,,,,,,,,,,,, -12100,1.9399594,3.5737407,,,,,,,,,,,,,, -12200,1.6102602,3.8778896,,,,,,,,,,,,,, -12300,1.9849244,3.6955218,,,,,,,,,,,,,, -12400,1.7724942,3.7998493,,,,,,,,,,,,,, -12500,1.5651346,4.0863976,,,,,,,,,,,,,, -12600,1.8412005,3.6058052,,,,,,,,,,,,,, -12700,1.9544392,3.6113594,,,,,,,,,,,,,, -12800,2.0443542,3.5639548,,,,,,,,,,,,,, -12855,,,0.4191406071186065,2.710547924041748,0.3916800022125244,2.8580996990203857,50000.0,0.3008000254631042,3.418243408203125,10000.0,5918.129729747772,6453.089605808258,5918.129729747772,533.4520351886749,0.8349728584289551,0.0 -12900,1.7509378,3.6736445,,,,,,,,,,,,,, -13000,1.8469598,4.1114497,,,,,,,,,,,,,, -13100,1.5370184,3.991217,,,,,,,,,,,,,, -13200,1.6356151,3.5036783,,,,,,,,,,,,,, -13300,1.7451177,3.7054837,,,,,,,,,,,,,, -13400,1.2050917,5.2362947,,,,,,,,,,,,,, -13500,1.3602917,5.158098,,,,,,,,,,,,,, -13600,1.853244,3.6172905,,,,,,,,,,,,,, -13700,1.9070017,3.6452413,,,,,,,,,,,,,, -13780,,,0.4387499988079071,2.571505308151245,0.4053399860858917,2.736690282821656,50000.0,0.3104000091552734,3.3097615242004395,10000.0,6338.371341228485,6910.605922698975,6338.371341228485,570.6516087055206,0.8618414402008057,0.0 -13800,1.058157,5.75858,,,,,,,,,,,,,, -13900,1.2992405,4.4645157,,,,,,,,,,,,,, -14000,1.6355872,3.652089,,,,,,,,,,,,,, -14100,1.8251475,3.429179,,,,,,,,,,,,,, -14200,1.6334975,3.9936137,,,,,,,,,,,,,, -14300,1.7437679,3.5373,,,,,,,,,,,,,, -14400,1.6586443,4.5080676,,,,,,,,,,,,,, -14500,1.6560903,3.5126364,,,,,,,,,,,,,, -14600,1.7032757,4.486407,,,,,,,,,,,,,, -14700,1.1175468,5.849221,,,,,,,,,,,,,, -14706,,,0.4510742127895355,2.485666036605835,0.4161399900913238,2.670121192932129,50000.0,0.3216000199317932,3.267530918121338,10000.0,6758.6344385147095,7365.626133203506,6758.6344385147095,605.3328831195831,0.8873722553253174,0.0 -14800,1.5607185,3.434941,,,,,,,,,,,,,, -14900,1.7718534,3.8572607,,,,,,,,,,,,,, -15000,1.0701704,5.751545,,,,,,,,,,,,,, -15100,1.6974045,3.5949583,,,,,,,,,,,,,, -15200,1.7478367,3.3972282,,,,,,,,,,,,,, -15300,1.7372928,3.4354553,,,,,,,,,,,,,, -15400,1.853926,3.532875,,,,,,,,,,,,,, -15500,1.1503448,5.5463104,,,,,,,,,,,,,, -15600,1.6262782,3.9342992,,,,,,,,,,,,,, -15629,,,0.4708984196186065,2.401356935501098,0.42739999294281,2.6038594245910645,50000.0,0.3317000269889831,3.190614938735962,10000.0,7178.670622348785,7821.458172559738,7178.670622348785,641.0547397136688,0.9128148555755616,0.0 -15700,1.6661576,3.4970126,,,,,,,,,,,,,, -15800,1.8249384,3.4343433,,,,,,,,,,,,,, -15900,1.3368934,4.0375423,,,,,,,,,,,,,, -16000,1.5740063,3.880939,,,,,,,,,,,,,, -16100,1.3790416,4.0451617,,,,,,,,,,,,,, -16200,1.4113088,4.8157916,,,,,,,,,,,,,, -16300,1.7253252,3.3391507,,,,,,,,,,,,,, -16400,1.2220346,5.7276397,,,,,,,,,,,,,, -16500,1.5827582,3.3700457,,,,,,,,,,,,,, -16554,,,0.4760351479053497,2.324101448059082,0.4444599747657776,2.4943294525146484,50000.0,0.3414000272750854,3.112022399902344,10000.0,7598.717087745666,8279.92192864418,7598.717087745666,679.3933956623077,0.9427704811096193,0.0 -16600,1.2253634,5.558523,,,,,,,,,,,,,, -16700,1.1666864,5.4599133,,,,,,,,,,,,,, -16800,1.4167883,3.741876,,,,,,,,,,,,,, -16900,1.1088748,5.6331673,,,,,,,,,,,,,, -17000,1.6232889,3.3493097,,,,,,,,,,,,,, -17100,1.4469479,3.278279,,,,,,,,,,,,,, -17200,1.5125697,3.4061096,,,,,,,,,,,,,, -17300,1.5634573,3.459866,,,,,,,,,,,,,, -17400,1.4864755,3.1960344,,,,,,,,,,,,,, -17481,,,0.48291015625,2.2987401485443115,0.4461599886417389,2.4863901138305664,50000.0,0.3461000025272369,3.0988028049468994,10000.0,8019.005577802658,8737.72558760643,8019.005577802658,716.8324489593506,0.9697284698486328,0.0 -17500,1.6736596,3.3203554,,,,,,,,,,,,,, -17600,1.6618156,3.921867,,,,,,,,,,,,,, -17700,1.412342,3.6757994,,,,,,,,,,,,,, -17800,2.11816,3.2874296,,,,,,,,,,,,,, -17900,1.5834728,3.4522097,,,,,,,,,,,,,, -18000,1.3812642,3.570708,,,,,,,,,,,,,, -18100,1.8236918,3.2161026,,,,,,,,,,,,,, -18200,1.1036028,4.538188,,,,,,,,,,,,,, -18300,1.5874361,3.3792603,,,,,,,,,,,,,, -18400,2.63737,3.6604223,,,,,,,,,,,,,, -18405,,,0.5003905892372131,2.1953961849212646,0.449099987745285,2.455754041671753,50000.0,0.3519000113010406,3.06425142288208,10000.0,8439.072633743286,9195.537045240402,8439.072633743286,754.4993402957916,0.9993364810943604,0.0 -18500,1.6952806,3.3395507,,,,,,,,,,,,,, -18600,1.3986392,3.4159877,,,,,,,,,,,,,, -18700,1.7411808,3.4004622,,,,,,,,,,,,,, -18800,1.3078219,3.9337525,,,,,,,,,,,,,, -18900,2.0986183,3.2437174,,,,,,,,,,,,,, -19000,1.1228839,5.5967174,,,,,,,,,,,,,, -19100,1.6515486,3.2002654,,,,,,,,,,,,,, -19200,1.6679966,3.2835126,,,,,,,,,,,,,, -19300,1.4252049,4.031251,,,,,,,,,,,,,, -19326,,,0.4957421720027923,2.2166545391082764,0.4653599858283996,2.381878137588501,50000.0,0.3622000217437744,2.999075174331665,10000.0,8859.13781619072,9651.289949893951,8859.13781619072,790.1128647327423,1.0250554084777832,0.0 -19400,1.5512142,3.2416737,,,,,,,,,,,,,, -19500,1.3895754,4.952037,,,,,,,,,,,,,, -19600,1.2813475,4.94186,,,,,,,,,,,,,, -19700,1.5515683,3.3352804,,,,,,,,,,,,,, -19800,1.6451185,3.171969,,,,,,,,,,,,,, -19900,2.1034746,3.317772,,,,,,,,,,,,,, -20000,1.5842105,3.1587107,,,,,,,,,,,,,, -20100,1.5102069,3.2269993,,,,,,,,,,,,,, -20200,1.2108798,4.619786,,,,,,,,,,,,,, -20251,,,0.494433581829071,2.2306392192840576,0.4623000025749206,2.414182186126709,50000.0,0.3603000044822693,3.0382838249206543,10000.0,9279.222533941267,10107.794107198715,9279.222533941267,826.458952665329,1.0508620738983154,0.0 -20300,1.1246924,4.465823,,,,,,,,,,,,,, -20400,1.4320343,3.0857444,,,,,,,,,,,,,, -20500,1.6290083,3.7645607,,,,,,,,,,,,,, -20600,1.5140907,3.378902,,,,,,,,,,,,,, -20700,1.5600271,3.1240714,,,,,,,,,,,,,, -20800,1.3111343,3.9312906,,,,,,,,,,,,,, -20900,1.6189282,3.1319005,,,,,,,,,,,,,, -21000,1.1020641,5.0902634,,,,,,,,,,,,,, -21100,1.7230084,3.2245162,,,,,,,,,,,,,, -21178,,,0.5224609375,2.102747201919556,0.4746599793434143,2.3242554664611816,50000.0,0.3705000281333923,2.947197437286377,10000.0,9699.259632349014,10566.632721424105,9699.259632349014,865.185097694397,1.0772264003753662,0.0 -21200,1.308347,5.4624434,,,,,,,,,,,,,, -21300,1.5062495,3.1609516,,,,,,,,,,,,,, -21400,1.5610936,3.1981208,,,,,,,,,,,,,, -21500,1.4312605,3.6534672,,,,,,,,,,,,,, -21600,1.5089612,3.894883,,,,,,,,,,,,,, -21700,0.9926527,5.5591025,,,,,,,,,,,,,, -21800,0.9435192,5.6815243,,,,,,,,,,,,,, -21900,1.6657966,3.2126372,,,,,,,,,,,,,, -22000,1.5684496,3.2540538,,,,,,,,,,,,,, -22100,1.5758462,3.0956376,,,,,,,,,,,,,, -22101,,,0.5175585746765137,2.128702402114868,0.4822399914264679,2.302961826324463,50000.0,0.3824000060558319,2.909083604812622,10000.0,10119.803411722183,11023.986797571182,10119.803411722183,901.9166934490204,1.106586217880249,0.0 -22200,1.0832388,5.261098,,,,,,,,,,,,,, -22300,1.5389814,3.160875,,,,,,,,,,,,,, -22400,1.6327549,3.127139,,,,,,,,,,,,,, -22500,1.1408198,5.2268343,,,,,,,,,,,,,, -22600,1.548507,3.154634,,,,,,,,,,,,,, -22700,1.3842317,3.9369597,,,,,,,,,,,,,, -22800,1.5960348,3.2732937,,,,,,,,,,,,,, -22900,1.6611642,3.0357685,,,,,,,,,,,,,, -23000,1.4791068,3.1256187,,,,,,,,,,,,,, -23023,,,0.5186718702316284,2.1031439304351807,0.4840799868106842,2.2893989086151123,50000.0,0.3818000257015228,2.911873340606689,10000.0,10539.75125527382,11477.437187194824,10539.75125527382,935.33984375,1.1371748447418213,0.0 -23100,1.5635062,3.1830118,,,,,,,,,,,,,, -23200,1.4361187,3.2408354,,,,,,,,,,,,,, -23300,1.5760106,3.1291625,,,,,,,,,,,,,, -23400,1.5737684,3.07983,,,,,,,,,,,,,, -23500,1.5348277,3.1781795,,,,,,,,,,,,,, -23600,1.3360467,4.4179835,,,,,,,,,,,,,, -23700,1.6081805,3.137929,,,,,,,,,,,,,, -23800,1.3067956,5.458396,,,,,,,,,,,,,, -23900,1.49166,3.3237162,,,,,,,,,,,,,, -23947,,,0.5406249761581421,2.0112240314483643,0.4944999814033508,2.2325828075408936,50000.0,0.3876000046730041,2.855060577392578,10000.0,10959.968374729156,11934.81978225708,10959.968374729156,972.4291639328004,1.1649210453033447,0.0 -24000,1.3760645,3.6009717,,,,,,,,,,,,,, -24100,1.7777566,3.1578054,,,,,,,,,,,,,, -24200,1.3838092,4.7665634,,,,,,,,,,,,,, -24300,1.6015176,3.3749456,,,,,,,,,,,,,, -24400,1.1493696,5.1798377,,,,,,,,,,,,,, -24500,1.6055325,3.065099,,,,,,,,,,,,,, -24600,1.2580132,3.6851604,,,,,,,,,,,,,, -24700,1.1523952,4.600312,,,,,,,,,,,,,, -24800,1.6022054,3.0549285,,,,,,,,,,,,,, -24875,,,0.5497460961341858,1.9716825485229488,0.5027799606323242,2.185875654220581,50000.0,0.3956000208854675,2.823624849319458,10000.0,11380.2715177536,12390.1386282444,11380.2715177536,1007.3695442676544,1.192338228225708,0.0 -24900,1.0162607,5.47103,,,,,,,,,,,,,, -25000,1.4912624,3.0962596,,,,,,,,,,,,,, -25100,1.3963573,3.4695518,,,,,,,,,,,,,, -25200,1.4804103,2.9448366,,,,,,,,,,,,,, -25300,1.6618015,2.9080424,,,,,,,,,,,,,, -25400,1.257211,5.1442456,,,,,,,,,,,,,, -25500,1.1032321,4.525561,,,,,,,,,,,,,, -25600,1.2811235,4.248979,,,,,,,,,,,,,, -25700,1.2580194,4.0511303,,,,,,,,,,,,,, -25798,,,0.5396679639816284,2.003579616546631,0.5015400052070618,2.187553644180298,50000.0,0.3968000113964081,2.8089516162872314,10000.0,11800.526648044586,12848.661672592165,11800.526648044586,1045.554309129715,1.227823257446289,0.0 -25800,1.6090914,3.095817,,,,,,,,,,,,,, -25900,1.3044322,4.490026,,,,,,,,,,,,,, -26000,1.3400958,3.901606,,,,,,,,,,,,,, -26100,1.6067307,3.0997565,,,,,,,,,,,,,, -26200,1.5012703,2.99315,,,,,,,,,,,,,, -26300,1.4530071,2.8960376,,,,,,,,,,,,,, -26400,1.7408222,3.2289393,,,,,,,,,,,,,, -26500,1.1838748,5.5907125,,,,,,,,,,,,,, -26600,1.6018428,3.0133743,,,,,,,,,,,,,, -26700,1.6067649,3.4163656,,,,,,,,,,,,,, -26720,,,0.5524218678474426,1.9628337621688845,0.5076599717140198,2.161849021911621,50000.0,0.398900032043457,2.789376735687256,10000.0,12220.824908733368,13305.236095428469,12220.824908733368,1081.7528929710388,1.2576560974121094,0.0 -26800,1.6737268,2.853364,,,,,,,,,,,,,, -26900,1.478857,3.4767275,,,,,,,,,,,,,, -27000,1.6278234,2.8178532,,,,,,,,,,,,,, -27100,1.6053476,2.9304547,,,,,,,,,,,,,, -27200,1.6850529,3.0231724,,,,,,,,,,,,,, -27300,1.6419683,2.891236,,,,,,,,,,,,,, -27400,1.172171,5.287425,,,,,,,,,,,,,, -27500,1.4320843,3.850221,,,,,,,,,,,,,, -27600,1.6666467,2.9059834,,,,,,,,,,,,,, -27644,,,0.5774218440055847,1.7861711978912354,0.5199399590492249,2.085148334503174,50000.0,0.4028000235557556,2.7451486587524414,10000.0,12641.06434583664,13758.587542057036,12641.06434583664,1114.786839723587,1.2878186702728271,0.0 -27700,1.6023941,3.7390347,,,,,,,,,,,,,, -27800,1.3451861,4.15167,,,,,,,,,,,,,, -27900,1.783004,3.012341,,,,,,,,,,,,,, -28000,1.6275117,3.3407164,,,,,,,,,,,,,, -28100,1.4788394,3.159636,,,,,,,,,,,,,, -28200,1.5817634,2.9303594,,,,,,,,,,,,,, -28300,1.6720424,2.9991307,,,,,,,,,,,,,, -28400,1.388335,3.529584,,,,,,,,,,,,,, -28500,1.1218783,5.440316,,,,,,,,,,,,,, -28568,,,0.5564843416213989,1.902564525604248,0.5218999981880188,2.0633559226989746,50000.0,0.4127000272274017,2.69579815864563,10000.0,13061.35329389572,14213.63385272026,13061.35329389572,1149.4626867771149,1.321131706237793,0.0 -28600,1.1644249,5.440513,,,,,,,,,,,,,, -28700,1.3831779,3.3031924,,,,,,,,,,,,,, -28800,1.6168388,2.964358,,,,,,,,,,,,,, -28900,1.6519238,2.911161,,,,,,,,,,,,,, -29000,1.364166,3.381165,,,,,,,,,,,,,, -29100,1.3857639,3.7004044,,,,,,,,,,,,,, -29200,1.35266,4.0577054,,,,,,,,,,,,,, -29300,1.5280288,3.2134442,,,,,,,,,,,,,, -29400,1.5549839,2.9464643,,,,,,,,,,,,,, -29493,,,0.5679101347923279,1.855337381362915,0.5210599899291992,2.056286096572876,50000.0,0.4073000252246856,2.698023557662964,10000.0,13481.723677873611,14669.807126045229,13481.723677873611,1185.187756061554,1.3507835865020752,0.0 -29500,1.5854609,3.0468616,,,,,,,,,,,,,, -29600,1.6988492,2.8645813,,,,,,,,,,,,,, -29700,1.3574148,3.8144295,,,,,,,,,,,,,, -29800,1.1807375,4.330884,,,,,,,,,,,,,, -29900,1.5590962,2.930382,,,,,,,,,,,,,, -30000,1.67818,2.8821087,,,,,,,,,,,,,, -30100,1.7035092,2.8007662,,,,,,,,,,,,,, -30200,1.2661836,5.4529324,,,,,,,,,,,,,, -30300,1.6304663,3.044998,,,,,,,,,,,,,, -30400,1.2974606,4.3759484,,,,,,,,,,,,,, -30418,,,0.5766796469688416,1.8442224264144893,0.5261200070381165,2.0872247219085693,50000.0,0.4095000326633453,2.7296950817108154,10000.0,13902.074913978577,15124.563690185549,13902.074913978577,1219.5118567943573,1.3829002380371094,0.0 -30500,1.3636162,3.2345178,,,,,,,,,,,,,, -30600,1.3405038,5.346722,,,,,,,,,,,,,, -30700,1.682161,2.892911,,,,,,,,,,,,,, -30800,1.7137369,3.2119863,,,,,,,,,,,,,, -30900,1.350501,4.4616885,,,,,,,,,,,,,, -31000,1.4548453,4.1697702,,,,,,,,,,,,,, -31100,1.643667,3.2762032,,,,,,,,,,,,,, -31200,1.7483176,2.9793565,,,,,,,,,,,,,, -31300,1.5641514,2.9736269,,,,,,,,,,,,,, -31342,,,0.5602929592132568,1.9060845375061035,0.5257999897003174,2.082095146179199,50000.0,0.4130000174045563,2.724041700363159,10000.0,14322.042706489565,15584.364404439926,14322.042706489565,1259.265320301056,1.4137554168701172,0.0 -31400,1.7815005,2.7897904,,,,,,,,,,,,,, -31500,1.671011,2.787352,,,,,,,,,,,,,, -31600,1.5658042,2.9071436,,,,,,,,,,,,,, -31700,1.5848938,3.1637602,,,,,,,,,,,,,, -31800,1.254641,3.715336,,,,,,,,,,,,,, -31900,1.1569042,5.3840885,,,,,,,,,,,,,, -32000,1.6660972,2.9295292,,,,,,,,,,,,,, -32100,1.5742198,2.7773244,,,,,,,,,,,,,, -32200,1.6157466,2.7741728,,,,,,,,,,,,,, -32267,,,0.5718163847923279,1.8387211561203003,0.5347399711608887,2.0247344970703125,50000.0,0.4201000332832336,2.6634790897369385,10000.0,14742.00146341324,16036.710843086244,14742.00146341324,1291.5704834461212,1.4477450847625732,0.0 -32300,1.1862491,4.0996437,,,,,,,,,,,,,, -32400,1.6990504,2.9036949,,,,,,,,,,,,,, -32500,1.3419925,4.827058,,,,,,,,,,,,,, -32600,1.450073,5.424258,,,,,,,,,,,,,, -32700,1.6380489,3.10659,,,,,,,,,,,,,, -32800,1.6988572,2.8026912,,,,,,,,,,,,,, -32900,1.4743272,2.7855883,,,,,,,,,,,,,, -33000,1.1361022,5.5026326,,,,,,,,,,,,,, -33100,1.7280283,2.764421,,,,,,,,,,,,,, -33192,,,0.5762890577316284,1.802590131759644,0.5324400067329407,2.0144195556640625,50000.0,0.4217000305652618,2.6554317474365234,10000.0,15162.095947265623,16491.816769123077,15162.095947265623,1326.4956283569336,1.4855966567993164,0.0 -33200,1.8164574,2.8647978,,,,,,,,,,,,,, -33300,1.6595231,2.8001277,,,,,,,,,,,,,, -33400,1.3892373,5.0609994,,,,,,,,,,,,,, -33500,1.654238,2.9987204,,,,,,,,,,,,,, -33600,1.6030756,2.9578044,,,,,,,,,,,,,, -33700,1.2852015,4.3116307,,,,,,,,,,,,,, -33800,1.6578285,2.7657824,,,,,,,,,,,,,, -33900,1.3231609,5.4869175,,,,,,,,,,,,,, -34000,1.3203821,3.856331,,,,,,,,,,,,,, -34100,1.6929438,2.870143,,,,,,,,,,,,,, -34115,,,0.5796874761581421,1.8108338117599487,0.533840000629425,2.0201752185821533,50000.0,0.424200028181076,2.650291204452514,10000.0,15582.160798072817,16947.562863588333,15582.160798072817,1362.09840965271,1.5149762630462646,0.0 -34200,1.6625174,2.9347425,,,,,,,,,,,,,, -34300,1.7419711,3.1070886,,,,,,,,,,,,,, -34400,1.6050385,2.751775,,,,,,,,,,,,,, -34500,1.8286694,2.86703,,,,,,,,,,,,,, -34600,1.6118354,2.7986522,,,,,,,,,,,,,, -34700,1.7007697,2.7792985,,,,,,,,,,,,,, -34800,1.2990096,4.1044135,,,,,,,,,,,,,, -34900,1.7893267,2.9449,,,,,,,,,,,,,, -35000,1.7185395,2.9239755,,,,,,,,,,,,,, -35038,,,0.5776171684265137,1.8064240217208865,0.5378199815750122,1.9878385066986084,50000.0,0.4234000146389007,2.645028591156006,10000.0,16002.27912735939,17404.02609181404,16002.27912735939,1398.3604464530945,1.5435802936553955,0.0 -35100,1.6775104,3.014629,,,,,,,,,,,,,, -35200,1.7342504,2.8464472,,,,,,,,,,,,,, -35300,1.7162131,2.7529604,,,,,,,,,,,,,, -35400,1.5041804,3.9766898,,,,,,,,,,,,,, -35500,1.1755595,4.9938846,,,,,,,,,,,,,, -35600,1.5091228,3.4033382,,,,,,,,,,,,,, -35700,1.2962493,4.2105308,,,,,,,,,,,,,, -35800,1.8465116,2.8243232,,,,,,,,,,,,,, -35900,1.5072592,2.9503555,,,,,,,,,,,,,, -35963,,,0.5839062333106995,1.7839018106460571,0.5367000102996826,1.999006986618042,50000.0,0.4192000329494476,2.647194623947144,10000.0,16422.558165311813,17863.082113027573,16422.558165311813,1437.056425333023,1.5760712623596191,0.0 -36000,1.4766531,2.9360347,,,,,,,,,,,,,, -36100,1.6420459,2.8128223,,,,,,,,,,,,,, -36200,1.6943916,4.1681705,,,,,,,,,,,,,, -36300,1.8740718,2.6990438,,,,,,,,,,,,,, -36400,1.8511367,2.8367686,,,,,,,,,,,,,, -36500,1.5093114,3.6277056,,,,,,,,,,,,,, -36600,1.3161163,4.3469477,,,,,,,,,,,,,, -36700,1.6273096,3.152906,,,,,,,,,,,,,, -36800,1.8194957,2.9925838,,,,,,,,,,,,,, -36889,,,0.6082812547683716,1.6362735033035278,0.5448200106620789,1.936899781227112,50000.0,0.4297000169754028,2.5912978649139404,10000.0,16842.82310938835,18320.652250528336,16842.82310938835,1474.281730890274,1.6072852611541748,0.0 -36900,1.3000292,5.213604,,,,,,,,,,,,,, -37000,1.6043881,2.9531562,,,,,,,,,,,,,, -37100,1.6679748,2.8034143,,,,,,,,,,,,,, -37200,1.4611108,4.6608157,,,,,,,,,,,,,, -37300,1.5274178,3.7051253,,,,,,,,,,,,,, -37400,1.5346107,3.1661239,,,,,,,,,,,,,, -37500,1.6096884,2.8524642,,,,,,,,,,,,,, -37600,1.5611882,5.3367124,,,,,,,,,,,,,, -37700,1.5383921,3.2581828,,,,,,,,,,,,,, -37800,1.6700472,3.039171,,,,,,,,,,,,,, -37814,,,0.5752734541893005,1.8207134008407595,0.5349000096321106,2.0128426551818848,50000.0,0.4208000302314758,2.656367540359497,10000.0,17262.928169488907,18778.617341041565,17262.928169488907,1512.0637485980988,1.636359453201294,0.0 -37900,1.2280855,5.3374104,,,,,,,,,,,,,, -38000,1.7034382,2.8532226,,,,,,,,,,,,,, -38100,1.7731766,2.9761906,,,,,,,,,,,,,, -38200,1.7328243,2.8420708,,,,,,,,,,,,,, -38300,1.4241459,3.800828,,,,,,,,,,,,,, -38400,1.8922381,2.8766348,,,,,,,,,,,,,, -38500,1.7461287,2.8563883,,,,,,,,,,,,,, -38600,1.5116045,3.9136305,,,,,,,,,,,,,, -38700,1.637153,2.8579779,,,,,,,,,,,,,, -38739,,,0.5896679759025574,1.7389613389968872,0.5483999848365784,1.930012464523316,50000.0,0.4314000308513641,2.607285737991333,10000.0,17683.129102230072,19235.227719545364,17683.129102230072,1548.39102768898,1.6686515808105469,0.0 -38800,1.769731,2.7972715,,,,,,,,,,,,,, -38900,1.6927491,2.7894704,,,,,,,,,,,,,, -39000,1.742252,2.7665992,,,,,,,,,,,,,, -39100,1.7668966,2.7852411,,,,,,,,,,,,,, -39200,1.8503375,2.7612307,,,,,,,,,,,,,, -39300,1.5899825,3.3263097,,,,,,,,,,,,,, -39400,1.3577576,4.6820445,,,,,,,,,,,,,, -39500,1.5894073,3.074456,,,,,,,,,,,,,, -39600,1.6481194,3.6474986,,,,,,,,,,,,,, -39665,,,0.6051952838897705,1.6787258386611938,0.5502399802207947,1.9336005449295044,50000.0,0.4336000084877014,2.5970046520233154,10000.0,18103.296102762222,19692.71918320656,18103.296102762222,1585.63733625412,1.6978001594543457,0.0 -39700,1.5611911,2.661648,,,,,,,,,,,,,, -39800,1.5463932,2.7153027,,,,,,,,,,,,,, -39900,1.6932119,3.1992285,,,,,,,,,,,,,, -40000,1.6980314,2.96357,,,,,,,,,,,,,, -40100,1.6048465,3.1798635,,,,,,,,,,,,,, -40200,1.6439555,2.6820185,,,,,,,,,,,,,, -40300,1.5786579,3.5122852,,,,,,,,,,,,,, -40400,1.8307443,2.798053,,,,,,,,,,,,,, -40500,1.8086036,3.351059,,,,,,,,,,,,,, -40591,,,0.5888671875,1.7536296844482422,0.5473399758338928,1.943955898284912,50000.0,0.4307000339031219,2.6018872261047363,10000.0,18523.21506094933,20148.956298351288,18523.21506094933,1621.870540380478,1.7322437763214111,0.0 -40600,1.2885412,4.936473,,,,,,,,,,,,,, -40700,1.2003343,5.4402914,,,,,,,,,,,,,, -40800,1.7014769,2.8034642,,,,,,,,,,,,,, -40900,1.1781977,5.074484,,,,,,,,,,,,,, -41000,1.6513112,3.0942173,,,,,,,,,,,,,, -41100,1.6461152,2.8419185,,,,,,,,,,,,,, -41200,1.7024134,2.9136374,,,,,,,,,,,,,, -41300,1.2750541,4.455202,,,,,,,,,,,,,, -41400,1.6992089,2.6742048,,,,,,,,,,,,,, -41500,1.1895036,5.232976,,,,,,,,,,,,,, -41515,,,0.5938476324081421,1.7254453897476196,0.5546000003814697,1.912533044815064,50000.0,0.4418000280857086,2.548300504684448,10000.0,18943.480389356613,20606.473484039307,18943.480389356613,1659.0322132110596,1.7727165222167969,0.0 -41600,1.1892271,5.3732595,,,,,,,,,,,,,, -41700,1.3206964,5.312147,,,,,,,,,,,,,, -41800,1.3112215,4.179912,,,,,,,,,,,,,, -41900,1.206262,5.2302256,,,,,,,,,,,,,, -42000,1.3748783,3.8599277,,,,,,,,,,,,,, -42100,1.7057569,2.8496099,,,,,,,,,,,,,, -42200,1.7588516,2.6766405,,,,,,,,,,,,,, -42300,1.8967683,2.6890385,,,,,,,,,,,,,, -42400,1.5410005,2.934111,,,,,,,,,,,,,, -42440,,,0.5986718535423279,1.7187060117721558,0.5518999695777893,1.947245478630066,50000.0,0.4280000329017639,2.6177022457122803,10000.0,19363.40782856941,21062.1515891552,19363.40782856941,1694.7046456336975,1.80218505859375,0.0 -42500,1.6558673,2.7959142,,,,,,,,,,,,,, -42600,1.4447519,3.3372014,,,,,,,,,,,,,, -42700,1.4550909,3.8372033,,,,,,,,,,,,,, -42800,1.681265,2.8915014,,,,,,,,,,,,,, -42900,1.3323346,5.3312917,,,,,,,,,,,,,, -43000,1.8133358,2.7125823,,,,,,,,,,,,,, -43100,1.7024717,2.6939936,,,,,,,,,,,,,, -43200,1.443584,4.357647,,,,,,,,,,,,,, -43300,1.4530462,5.303346,,,,,,,,,,,,,, -43366,,,0.6058202981948853,1.6716299057006836,0.5638200044631958,1.8678628206253047,50000.0,0.4452000260353088,2.529025793075561,10000.0,19783.71187520027,21519.535794973373,19783.71187520027,1731.7015480995178,1.8368933200836184,0.0 -43400,1.7013353,2.7859626,,,,,,,,,,,,,, -43500,1.3288811,5.162587,,,,,,,,,,,,,, -43600,1.7133684,2.7125213,,,,,,,,,,,,,, -43700,1.8900753,2.791152,,,,,,,,,,,,,, -43800,1.1535802,4.9806857,,,,,,,,,,,,,, -43900,1.6461555,2.8825996,,,,,,,,,,,,,, -44000,1.6258726,2.798766,,,,,,,,,,,,,, -44100,1.6237692,3.0962787,,,,,,,,,,,,,, -44200,1.6429461,3.1958122,,,,,,,,,,,,,, -44292,,,0.6031640768051147,1.6963953971862793,0.562279999256134,1.889854431152344,50000.0,0.4427000284194946,2.543543100357056,10000.0,20204.07502055168,21974.32190656662,20204.07502055168,1766.042251586914,1.8706142902374268,0.0 -44300,1.5852637,3.3370361,,,,,,,,,,,,,, -44400,1.8276932,2.9207222,,,,,,,,,,,,,, -44500,1.6680237,2.7889888,,,,,,,,,,,,,, -44600,1.8905245,2.814564,,,,,,,,,,,,,, -44700,1.5235058,5.250491,,,,,,,,,,,,,, -44800,1.3404512,5.3034782,,,,,,,,,,,,,, -44900,1.2775,4.4862175,,,,,,,,,,,,,, -45000,1.4380205,4.7432637,,,,,,,,,,,,,, -45100,1.5971119,3.144681,,,,,,,,,,,,,, -45200,1.6971787,2.659863,,,,,,,,,,,,,, -45216,,,0.6105663776397705,1.6492289304733276,0.564579963684082,1.8600382804870603,50000.0,0.4473000168800354,2.5134053230285645,10000.0,20624.01596546173,22431.28592157364,20624.01596546173,1802.985370874405,1.9021704196929927,0.0 -45300,1.3562677,4.5150094,,,,,,,,,,,,,, -45400,1.9046504,2.7842257,,,,,,,,,,,,,, -45500,1.3579247,4.618365,,,,,,,,,,,,,, -45600,1.5482012,3.2894151,,,,,,,,,,,,,, -45700,1.5639957,2.639912,,,,,,,,,,,,,, -45800,1.653368,2.7511902,,,,,,,,,,,,,, -45900,1.5931704,3.5478804,,,,,,,,,,,,,, -46000,2.03828,2.676154,,,,,,,,,,,,,, -46100,1.8633982,2.6740487,,,,,,,,,,,,,, -46141,,,0.6305859088897705,1.5546183586120603,0.5636999607086182,1.8670766353607176,50000.0,0.4474000334739685,2.504635572433472,10000.0,21044.385004997253,22890.125202655792,21044.385004997253,1841.3757576942444,1.9337167739868164,0.0 -46200,1.3947538,4.4647346,,,,,,,,,,,,,, -46300,1.6337618,3.209984,,,,,,,,,,,,,, -46400,1.499475,3.354684,,,,,,,,,,,,,, -46500,1.4344488,5.418329,,,,,,,,,,,,,, -46600,1.6711646,5.196038,,,,,,,,,,,,,, -46700,1.3343142,5.036292,,,,,,,,,,,,,, -46800,1.6352338,2.6762557,,,,,,,,,,,,,, -46900,1.6637764,2.656109,,,,,,,,,,,,,, -47000,1.815074,2.67217,,,,,,,,,,,,,, -47066,,,0.60595703125,1.6388249397277832,0.5649600028991699,1.8310381174087524,50000.0,0.4488000273704529,2.477980613708496,10000.0,21464.45929455757,23345.80527472496,21464.45929455757,1876.8982713222504,1.9686899185180664,0.0 -47100,1.7010157,2.7003534,,,,,,,,,,,,,, -47200,1.655738,2.6954641,,,,,,,,,,,,,, -47300,1.868801,2.8758752,,,,,,,,,,,,,, -47400,1.7149496,2.5044312,,,,,,,,,,,,,, -47500,1.7428986,2.529687,,,,,,,,,,,,,, -47600,1.7441639,2.5668333,,,,,,,,,,,,,, -47700,1.8087193,2.6049976,,,,,,,,,,,,,, -47800,1.2078879,5.3520436,,,,,,,,,,,,,, -47900,1.354286,3.4088783,,,,,,,,,,,,,, -47990,,,0.6087695360183716,1.669023036956787,0.5663599967956543,1.8735897541046145,50000.0,0.4502000212669372,2.5234055519104004,10000.0,21884.46242260933,23801.1062412262,21884.46242260933,1912.115308046341,2.0009567737579346,0.0 -48000,1.4071574,5.2312727,,,,,,,,,,,,,, -48100,1.8399972,2.633653,,,,,,,,,,,,,, -48200,1.4465659,3.7422316,,,,,,,,,,,,,, -48300,1.8134918,2.6363306,,,,,,,,,,,,,, -48400,1.5929892,2.9708335,,,,,,,,,,,,,, -48500,1.6423103,2.9097812,,,,,,,,,,,,,, -48600,1.7771682,2.8554807,,,,,,,,,,,,,, -48700,1.6752936,2.637176,,,,,,,,,,,,,, -48800,1.4833995,3.1767397,,,,,,,,,,,,,, -48900,1.4238981,3.742129,,,,,,,,,,,,,, -48913,,,0.6250976324081421,1.5893003940582275,0.5663999915122986,1.850501537322998,50000.0,0.4472000300884247,2.5053343772888184,10000.0,22304.51851439476,24255.80009675026,22304.51851439476,1946.673007249832,2.032533407211304,0.0 -49000,1.3991693,4.468645,,,,,,,,,,,,,, -49100,1.5110896,5.1600404,,,,,,,,,,,,,, -49200,1.8421386,2.851973,,,,,,,,,,,,,, -49300,1.8436347,2.8776376,,,,,,,,,,,,,, -49400,1.7536244,2.6763024,,,,,,,,,,,,,, -49500,1.8453714,2.6640916,,,,,,,,,,,,,, -49600,1.7941831,2.8726296,,,,,,,,,,,,,, -49700,1.2343717,5.1549625,,,,,,,,,,,,,, -49800,1.7918568,2.863432,,,,,,,,,,,,,, -49839,,,0.6113671660423279,1.6454654932022097,0.5694599747657776,1.8334907293319704,50000.0,0.4528000354766845,2.4803237915039062,10000.0,22724.775852918625,24715.59235072136,22724.775852918625,1986.1257948875427,2.0658884048461914,0.0 -49900,1.5331292,4.668319,,,,,,,,,,,,,, -50000,1.5168471,5.3089957,,,,,,,,,,,,,, -50100,1.8623258,2.5725896,,,,,,,,,,,,,, -50200,1.5386893,4.0409107,,,,,,,,,,,,,, -50300,1.6426343,2.7170262,,,,,,,,,,,,,, -50400,1.5401794,5.324452,,,,,,,,,,,,,, -50500,1.9517945,2.6152225,,,,,,,,,,,,,, -50600,1.6243131,2.6959312,,,,,,,,,,,,,, -50700,1.5559016,3.0390663,,,,,,,,,,,,,, -50765,,,0.6122851371765137,1.6380223035812378,0.5708400011062622,1.8340604305267327,50000.0,0.4515000283718109,2.4958276748657227,10000.0,23144.970401525497,25170.08160185814,23144.970401525497,2020.339570283889,2.098045825958252,0.0 -50800,1.7086135,2.611475,,,,,,,,,,,,,, -50900,1.2573162,5.242277,,,,,,,,,,,,,, -51000,1.6117175,3.0036552,,,,,,,,,,,,,, -51100,1.6563929,2.9304142,,,,,,,,,,,,,, -51200,1.7317266,2.661984,,,,,,,,,,,,,, -51300,1.670887,3.2756286,,,,,,,,,,,,,, -51400,1.6308194,2.6525788,,,,,,,,,,,,,, -51500,1.8254825,2.8077,,,,,,,,,,,,,, -51600,1.8311436,2.764017,,,,,,,,,,,,,, -51688,,,0.6258788704872131,1.5507782697677612,0.5779600143432617,1.7888613939285278,50000.0,0.4626000225543976,2.4470291137695312,10000.0,23564.94721007347,25626.840389966965,23564.94721007347,2057.0334384441376,2.137653350830078,0.0 -51700,1.4757166,3.1988993,,,,,,,,,,,,,, -51800,1.3755474,5.207524,,,,,,,,,,,,,, -51900,1.3933579,4.71828,,,,,,,,,,,,,, -52000,1.5034969,5.3344674,,,,,,,,,,,,,, -52100,1.8326118,2.6410017,,,,,,,,,,,,,, -52200,1.8293729,2.638525,,,,,,,,,,,,,, -52300,1.8450122,2.6198862,,,,,,,,,,,,,, -52400,1.5466734,3.3582363,,,,,,,,,,,,,, -52500,1.7726543,2.5592184,,,,,,,,,,,,,, -52600,1.3822682,3.866578,,,,,,,,,,,,,, -52613,,,0.6236132383346558,1.5819681882858276,0.5786799788475037,1.7916535139083862,50000.0,0.4631000161170959,2.449575424194336,10000.0,23985.300182819366,26085.101460695267,23985.300182819366,2094.8600878715515,2.1691770553588867,0.0 -52700,1.4087508,5.2087975,,,,,,,,,,,,,, -52800,1.8973842,2.6955817,,,,,,,,,,,,,, -52900,1.7552439,2.4352815,,,,,,,,,,,,,, -53000,1.9058326,2.6521063,,,,,,,,,,,,,, -53100,1.7847925,2.7027743,,,,,,,,,,,,,, -53200,1.7311532,2.7329862,,,,,,,,,,,,,, -53300,1.6206868,2.7031956,,,,,,,,,,,,,, -53400,1.8553543,2.634798,,,,,,,,,,,,,, -53500,1.5248749,5.2731733,,,,,,,,,,,,,, -53538,,,0.6176171898841858,1.5891666412353516,0.5787599682807922,1.77160382270813,50000.0,0.4600000083446502,2.4227559566497803,10000.0,24405.514179944992,26537.762898921967,24405.514179944992,2127.224277496338,2.20285701751709,0.0 -53600,1.3061516,5.0220423,,,,,,,,,,,,,, -53700,1.834394,2.6826456,,,,,,,,,,,,,, -53800,1.58393,3.2867227,,,,,,,,,,,,,, -53900,1.5840565,3.5995677,,,,,,,,,,,,,, -54000,1.8290675,2.685316,,,,,,,,,,,,,, -54100,1.8350399,2.648011,,,,,,,,,,,,,, -54200,1.7444959,2.5346966,,,,,,,,,,,,,, -54300,1.30499,5.075966,,,,,,,,,,,,,, -54400,1.775013,4.3192654,,,,,,,,,,,,,, -54464,,,0.6196874976158142,1.635029435157776,0.5768199563026428,1.8407946825027464,50000.0,0.4535000324249267,2.493711471557617,10000.0,24825.61731696129,26994.1825401783,24825.61731696129,2163.457891225815,2.23595929145813,0.0 -54500,1.3346176,4.84982,,,,,,,,,,,,,, -54600,2.1565642,2.852388,,,,,,,,,,,,,, -54700,1.340215,4.484654,,,,,,,,,,,,,, -54800,1.788591,2.9749615,,,,,,,,,,,,,, -54900,1.6210922,2.9241483,,,,,,,,,,,,,, -55000,1.8360444,2.5656664,,,,,,,,,,,,,, -55100,1.9468335,2.6974018,,,,,,,,,,,,,, -55200,1.4271812,5.265208,,,,,,,,,,,,,, -55300,1.8714848,2.6747508,,,,,,,,,,,,,, -55384,,,0.6449999809265137,1.504424214363098,0.5765599608421326,1.8060835599899288,50000.0,0.4604000151157379,2.4601993560791016,10000.0,25245.34061932564,27449.90674352646,25245.34061932564,2198.990670442581,2.656230926513672,0.0 -55400,1.701008,3.4161162,,,,,,,,,,,,,, -55500,1.4560128,4.763506,,,,,,,,,,,,,, -55600,1.9195746,2.8994145,,,,,,,,,,,,,, -55700,1.9588834,2.6719413,,,,,,,,,,,,,, -55800,1.5824037,3.231483,,,,,,,,,,,,,, -55900,1.9401937,2.7199383,,,,,,,,,,,,,, -56000,1.8762082,2.88928,,,,,,,,,,,,,, -56100,1.7814361,2.8634634,,,,,,,,,,,,,, -56200,1.593654,2.5971303,,,,,,,,,,,,,, -56300,1.7482984,2.5193136,,,,,,,,,,,,,, -56309,,,0.621777355670929,1.5656689405441284,0.5817599892616272,1.751799702644348,50000.0,0.4602000117301941,2.413660049438477,10000.0,25665.65577197075,27911.216849565502,25665.65577197075,2239.8955862522125,2.6976802349090576,0.0 -56400,1.8573277,2.5138366,,,,,,,,,,,,,, -56500,1.812493,2.59054,,,,,,,,,,,,,, -56600,1.8841243,2.6108117,,,,,,,,,,,,,, -56700,1.415225,4.7161794,,,,,,,,,,,,,, -56800,1.9485031,2.632871,,,,,,,,,,,,,, -56900,1.9770434,2.6206791,,,,,,,,,,,,,, -57000,1.7703556,2.6449366,,,,,,,,,,,,,, -57100,1.8620182,2.6304588,,,,,,,,,,,,,, -57200,1.8429142,2.4842992,,,,,,,,,,,,,, -57234,,,0.6290820240974426,1.5716264247894287,0.5814799666404724,1.7926048040390017,50000.0,0.4628000259399414,2.429422378540039,10000.0,26086.07477974892,28369.223430871964,26086.07477974892,2277.402089357376,2.7310519218444824,0.0 -57300,1.7164456,2.426474,,,,,,,,,,,,,, -57400,1.611927,3.0264852,,,,,,,,,,,,,, -57500,1.4980942,4.232062,,,,,,,,,,,,,, -57600,1.7724403,2.617171,,,,,,,,,,,,,, -57700,1.6687634,2.987369,,,,,,,,,,,,,, -57800,1.8712789,2.528617,,,,,,,,,,,,,, -57900,1.6588035,2.7374368,,,,,,,,,,,,,, -58000,1.6527678,5.1662207,,,,,,,,,,,,,, -58100,1.4117678,5.2987857,,,,,,,,,,,,,, -58159,,,0.6390234231948853,1.509754657745361,0.5824999809265137,1.7665115594863892,50000.0,0.4617000222206116,2.406198501586914,10000.0,26506.14146876335,28826.498693466187,26506.14146876335,2314.5219326019287,2.77089786529541,0.0 -58200,1.8006989,2.5623677,,,,,,,,,,,,,, -58300,1.422024,3.9178627,,,,,,,,,,,,,, -58400,1.998741,2.6087482,,,,,,,,,,,,,, -58500,2.0569263,2.7128856,,,,,,,,,,,,,, -58600,2.1529753,2.8178058,,,,,,,,,,,,,, -58700,1.3200434,5.1938114,,,,,,,,,,,,,, -58800,1.8935249,2.5492935,,,,,,,,,,,,,, -58900,1.4456884,4.449765,,,,,,,,,,,,,, -59000,1.4579605,4.2566,,,,,,,,,,,,,, -59083,,,0.6219140291213989,1.569996356964111,0.5807999968528748,1.7605829238891602,50000.0,0.4697000086307525,2.399991273880005,10000.0,26926.27724289894,29280.27455854416,26926.27724289894,2348.078119754792,2.8062548637390137,0.0 -59100,1.6755581,4.7070017,,,,,,,,,,,,,, -59200,1.7319807,2.7892556,,,,,,,,,,,,,, -59300,1.8212686,2.6301177,,,,,,,,,,,,,, -59400,1.8236878,2.6074514,,,,,,,,,,,,,, -59500,1.5675632,4.6229706,,,,,,,,,,,,,, -59600,1.9373765,2.6139216,,,,,,,,,,,,,, -59700,1.682399,3.0603445,,,,,,,,,,,,,, -59800,1.4739211,4.7297473,,,,,,,,,,,,,, -59900,1.9228665,2.5402145,,,,,,,,,,,,,, -60000,1.7237867,3.0709121,,,,,,,,,,,,,, -60007,,,0.6282616853713989,1.5407322645187378,0.5867399573326111,1.7369784116744995,50000.0,0.4677000343799591,2.3942172527313232,10000.0,27346.266345262527,29736.22474217415,27346.266345262527,2383.9525430202484,2.843522310256958,0.0 -60100,1.6695901,3.5817385,,,,,,,,,,,,,, -60200,1.6532723,3.3225384,,,,,,,,,,,,,, -60300,1.7154202,2.5150177,,,,,,,,,,,,,, -60400,1.7381582,3.5749528,,,,,,,,,,,,,, -60500,1.7002269,2.6893399,,,,,,,,,,,,,, -60600,1.50115,4.7792606,,,,,,,,,,,,,, -60700,1.9555886,2.6509967,,,,,,,,,,,,,, -60800,1.5475605,4.225151,,,,,,,,,,,,,, -60900,1.7022285,2.9358308,,,,,,,,,,,,,, -60933,,,0.6407226324081421,1.5198386907577517,0.5875200033187866,1.75584077835083,50000.0,0.4727000296115875,2.403726816177368,10000.0,27766.477121591568,30195.49317908287,27766.477121591568,2422.922830820084,2.881842613220215,0.0 -61000,1.8383454,2.6734915,,,,,,,,,,,,,, -61100,1.9147409,2.747397,,,,,,,,,,,,,, -61200,1.8241686,2.5300357,,,,,,,,,,,,,, -61300,1.8345573,3.0514503,,,,,,,,,,,,,, -61400,1.6164205,2.7471504,,,,,,,,,,,,,, -61500,1.6312335,4.3875604,,,,,,,,,,,,,, -61600,2.0217102,2.586555,,,,,,,,,,,,,, -61700,1.9692131,2.539685,,,,,,,,,,,,,, -61800,1.7896278,2.7888165,,,,,,,,,,,,,, -61858,,,0.6363281011581421,1.5116547346115112,0.5927199721336365,1.7103793621063232,50000.0,0.4716000258922577,2.3759047985076904,10000.0,28186.81658530236,30650.73554301262,28186.81658530236,2457.7355239391327,2.923898696899414,0.0 -61900,1.4379796,3.647668,,,,,,,,,,,,,, -62000,1.9921027,2.626394,,,,,,,,,,,,,, -62100,1.5188631,3.222078,,,,,,,,,,,,,, -62200,1.9414092,2.6203012,,,,,,,,,,,,,, -62300,2.0281084,2.5301106,,,,,,,,,,,,,, -62400,1.9405025,2.617451,,,,,,,,,,,,,, -62500,1.646721,5.0383396,,,,,,,,,,,,,, -62600,1.6422204,2.9486141,,,,,,,,,,,,,, -62700,1.6865834,2.8317251,,,,,,,,,,,,,, -62782,,,0.6325976252555847,1.5511209964752195,0.5876799821853638,1.7578150033950806,50000.0,0.4659000337123871,2.4069180488586426,10000.0,28607.062667131424,31105.45560526848,28607.062667131424,2492.127090215683,2.957786798477173,0.0 -62800,1.3495598,3.9968011,,,,,,,,,,,,,, -62900,2.0218105,2.7107813,,,,,,,,,,,,,, -63000,1.8972619,3.790136,,,,,,,,,,,,,, -63100,1.8368883,2.3508508,,,,,,,,,,,,,, -63200,1.5698223,5.2537,,,,,,,,,,,,,, -63300,1.53408,5.0564756,,,,,,,,,,,,,, -63400,1.6692215,2.9885025,,,,,,,,,,,,,, -63500,1.7394073,2.7315567,,,,,,,,,,,,,, -63600,1.8234787,3.0512362,,,,,,,,,,,,,, -63700,1.77458,2.7614894,,,,,,,,,,,,,, -63709,,,0.6412500143051147,1.4811878204345703,0.594819962978363,1.6979410648345947,50000.0,0.4752000272274017,2.3501551151275635,10000.0,29027.318502426147,31564.00478863716,29027.318502426147,2530.325961828232,3.003383159637451,0.0 -63800,1.7086287,2.543862,,,,,,,,,,,,,, -63900,1.6004783,3.1642776,,,,,,,,,,,,,, -64000,1.4354694,4.9444723,,,,,,,,,,,,,, -64100,1.4205966,5.18249,,,,,,,,,,,,,, -64200,1.5015397,3.813425,,,,,,,,,,,,,, -64300,1.671387,4.88787,,,,,,,,,,,,,, -64400,1.8991596,2.601875,,,,,,,,,,,,,, -64500,2.1094854,2.5347435,,,,,,,,,,,,,, -64600,1.8539164,2.733582,,,,,,,,,,,,,, -64633,,,0.6591210961341858,1.411446452140808,0.5887599587440491,1.7354134321212769,50000.0,0.4707000255584717,2.37775993347168,10000.0,29447.26177978516,32021.78300929069,29447.26177978516,2568.0701701641083,3.045332431793213,0.0 -64700,2.1116183,2.542954,,,,,,,,,,,,,, -64800,1.7766927,2.5996764,,,,,,,,,,,,,, -64900,1.6509649,5.1219835,,,,,,,,,,,,,, -65000,1.6147387,5.169769,,,,,,,,,,,,,, -65100,1.601854,5.098976,,,,,,,,,,,,,, -65200,1.7171973,5.23759,,,,,,,,,,,,,, -65300,1.9977808,2.676724,,,,,,,,,,,,,, -65400,1.7448114,2.4377975,,,,,,,,,,,,,, -65500,1.8444533,2.4915688,,,,,,,,,,,,,, -65560,,,0.6402148008346558,1.5128191709518433,0.5958399772644043,1.723629355430603,50000.0,0.4760000109672546,2.3619251251220703,10000.0,29867.51460146904,32478.842272996902,29867.51460146904,2604.7873711586,3.084872245788574,0.0 -65600,1.5061095,5.1432595,,,,,,,,,,,,,, -65700,1.7866008,2.403999,,,,,,,,,,,,,, -65800,1.7443463,3.524503,,,,,,,,,,,,,, -65900,1.4165598,5.1158614,,,,,,,,,,,,,, -66000,1.8802202,2.4448893,,,,,,,,,,,,,, -66100,2.0837924,2.5383677,,,,,,,,,,,,,, -66200,1.7813336,3.0136929,,,,,,,,,,,,,, -66300,1.6940273,2.9512596,,,,,,,,,,,,,, -66400,1.7791399,3.3100584,,,,,,,,,,,,,, -66486,,,0.6462304592132568,1.4637900590896606,0.6012399792671204,1.669540286064148,50000.0,0.4793000221252441,2.334697723388672,10000.0,30287.78778719902,32934.90845179558,30287.78778719902,2640.4960482120514,3.11910343170166,0.0 -66500,1.4732611,4.887159,,,,,,,,,,,,,, -66600,1.895872,2.493001,,,,,,,,,,,,,, -66700,1.9775025,2.5523577,,,,,,,,,,,,,, -66800,1.8720498,2.4662104,,,,,,,,,,,,,, -66900,1.4312798,4.7541885,,,,,,,,,,,,,, -67000,1.7387229,5.046111,,,,,,,,,,,,,, -67100,1.3554116,4.0435104,,,,,,,,,,,,,, -67200,1.8627379,5.287076,,,,,,,,,,,,,, -67300,1.9403273,2.5739849,,,,,,,,,,,,,, -67400,1.8091803,3.0804172,,,,,,,,,,,,,, -67413,,,0.6551952958106995,1.444200873374939,0.6001799702644348,1.7015769481658936,50000.0,0.4799000322818756,2.361106634140014,10000.0,30708.07630681992,33391.64254951477,30708.07630681992,2676.85613656044,3.155380249023437,0.0 -67500,1.6033944,3.5884633,,,,,,,,,,,,,, -67600,1.4999031,3.7175505,,,,,,,,,,,,,, -67700,1.5120479,3.8245754,,,,,,,,,,,,,, -67800,1.8949149,2.4389386,,,,,,,,,,,,,, -67900,1.9951879,2.658758,,,,,,,,,,,,,, -68000,1.7091537,3.4585943,,,,,,,,,,,,,, -68100,1.520302,4.949183,,,,,,,,,,,,,, -68200,1.8400993,2.4455552,,,,,,,,,,,,,, -68300,1.6747528,2.8022313,,,,,,,,,,,,,, -68335,,,0.6418554782867432,1.4940626621246338,0.6005799770355225,1.6846686601638794,50000.0,0.4817000329494476,2.323792695999145,10000.0,31128.145755052567,33848.646995306015,31128.145755052567,2713.703197956085,3.194664239883423,0.0 -68400,1.5476502,3.6803155,,,,,,,,,,,,,, -68500,1.3668845,4.8206687,,,,,,,,,,,,,, -68600,2.024341,2.571824,,,,,,,,,,,,,, -68700,1.5484811,5.017375,,,,,,,,,,,,,, -68800,2.1564367,2.440725,,,,,,,,,,,,,, -68900,1.8757964,2.4286788,,,,,,,,,,,,,, -69000,1.8500644,2.482967,,,,,,,,,,,,,, -69100,1.5367233,3.296997,,,,,,,,,,,,,, -69200,1.791865,2.4768448,,,,,,,,,,,,,, -69259,,,0.6425195336341858,1.519197940826416,0.5944199562072754,1.7298539876937866,50000.0,0.4778000116348266,2.381270170211792,10000.0,31548.4422082901,34305.10170960426,31548.4422082901,2749.776581287384,3.231003999710083,0.0 -69300,1.5142462,4.591647,,,,,,,,,,,,,, -69400,1.5050514,5.0782547,,,,,,,,,,,,,, -69500,1.7616997,2.4677575,,,,,,,,,,,,,, -69600,2.265809,2.3931375,,,,,,,,,,,,,, -69700,1.7544423,2.345945,,,,,,,,,,,,,, -69800,1.4357879,4.4545116,,,,,,,,,,,,,, -69900,1.3603299,4.732403,,,,,,,,,,,,,, -70000,1.998222,2.540416,,,,,,,,,,,,,, -70100,1.5953103,3.2879972,,,,,,,,,,,,,, -70184,,,0.6552343368530273,1.4339993000030518,0.6019399762153625,1.6860952377319336,50000.0,0.4826000332832336,2.329708337783813,10000.0,31968.580335378647,34758.855362176895,31968.580335378647,2783.305285215378,3.269001007080078,0.0 -70200,1.7164377,2.9285738,,,,,,,,,,,,,, -70300,1.8153329,2.3990903,,,,,,,,,,,,,, -70400,1.8817484,2.5085928,,,,,,,,,,,,,, -70500,1.818385,2.445788,,,,,,,,,,,,,, -70600,1.889512,2.527133,,,,,,,,,,,,,, -70700,1.7725137,3.1105604,,,,,,,,,,,,,, -70800,1.4679728,3.6792445,,,,,,,,,,,,,, -70900,1.4330662,4.274659,,,,,,,,,,,,,, -71000,1.645874,4.9025936,,,,,,,,,,,,,, -71100,1.4733726,4.068574,,,,,,,,,,,,,, -71111,,,0.6490820050239563,1.4559168815612793,0.6016600131988525,1.6674160957336426,50000.0,0.4850000143051147,2.3340394496917725,10000.0,32388.90991282463,35214.268662929535,32388.90991282463,2818.3009536266327,3.307988405227661,0.0 -71200,2.0401437,2.535132,,,,,,,,,,,,,, -71300,1.713425,2.9185586,,,,,,,,,,,,,, -71400,1.5220402,4.477433,,,,,,,,,,,,,, -71500,1.7463676,2.617052,,,,,,,,,,,,,, -71600,2.1447625,2.4282088,,,,,,,,,,,,,, -71700,1.7083106,2.972118,,,,,,,,,,,,,, -71800,1.7961851,5.110629,,,,,,,,,,,,,, -71900,1.8533236,2.4153478,,,,,,,,,,,,,, -72000,1.5219353,4.817188,,,,,,,,,,,,,, -72037,,,0.6522070169448853,1.439338207244873,0.6070799827575684,1.6486047506332395,50000.0,0.4869000315666199,2.315235376358032,10000.0,32809.21870470047,35669.50822305679,32809.21870470047,2853.14493060112,3.345659017562866,0.0 -72100,1.7336125,2.7826488,,,,,,,,,,,,,, -72200,1.9457363,2.8020651,,,,,,,,,,,,,, -72300,1.4345623,4.8654776,,,,,,,,,,,,,, -72400,1.9173559,2.5091984,,,,,,,,,,,,,, -72500,2.1704957,3.550307,,,,,,,,,,,,,, -72600,1.595586,4.95952,,,,,,,,,,,,,, -72700,1.7387221,4.677106,,,,,,,,,,,,,, -72800,1.8592494,4.8852305,,,,,,,,,,,,,, -72900,2.1982403,2.5421653,,,,,,,,,,,,,, -72961,,,0.6541991829872131,1.4503023624420166,0.6034799814224243,1.6737196445465088,50000.0,0.4834000170230865,2.3369715213775635,10000.0,33229.53257513046,36127.273602962494,33229.53257513046,2890.5099818706512,3.38305139541626,0.0 -73000,1.4179122,4.654792,,,,,,,,,,,,,, -73100,1.9175744,2.420017,,,,,,,,,,,,,, -73200,1.8490965,2.5270867,,,,,,,,,,,,,, -73300,1.9427776,2.5176907,,,,,,,,,,,,,, -73400,1.4772805,4.266408,,,,,,,,,,,,,, -73500,1.982846,2.4601119,,,,,,,,,,,,,, -73600,1.6674553,4.483461,,,,,,,,,,,,,, -73700,1.5884016,3.6248565,,,,,,,,,,,,,, -73800,1.7741554,2.5563588,,,,,,,,,,,,,, -73887,,,0.6709960699081421,1.3740272521972656,0.6000399589538574,1.691603183746338,50000.0,0.4837000370025635,2.336214303970337,10000.0,33649.60353899002,36582.99809360504,33649.60353899002,2926.0756623744965,3.4222118854522705,0.0 -73900,1.839991,2.321743,,,,,,,,,,,,,, -74000,1.7261285,2.9303308,,,,,,,,,,,,,, -74100,2.0021822,2.4354014,,,,,,,,,,,,,, -74200,1.8389685,2.4963572,,,,,,,,,,,,,, -74300,1.646725,3.3327188,,,,,,,,,,,,,, -74400,1.9440773,2.4988143,,,,,,,,,,,,,, -74500,1.6959355,5.1366153,,,,,,,,,,,,,, -74600,1.8062394,2.8553061,,,,,,,,,,,,,, -74700,1.585364,5.1563497,,,,,,,,,,,,,, -74800,2.0083628,2.5226521,,,,,,,,,,,,,, -74811,,,0.6478906273841858,1.458599090576172,0.6050599813461304,1.654366970062256,50000.0,0.488500028848648,2.28943419456482,10000.0,34069.536211013794,37037.44539427757,34069.536211013794,2960.50363445282,3.460890054702759,0.0 -74900,1.578372,3.9957087,,,,,,,,,,,,,, -75000,1.756043,3.2510257,,,,,,,,,,,,,, -75100,1.9048537,2.4953246,,,,,,,,,,,,,, -75200,2.0072293,2.4564133,,,,,,,,,,,,,, -75300,1.8219575,2.408299,,,,,,,,,,,,,, -75400,1.4011438,4.882656,,,,,,,,,,,,,, -75500,2.0194952,2.4801857,,,,,,,,,,,,,, -75600,1.8903114,2.934157,,,,,,,,,,,,,, -75700,1.9468098,2.4546435,,,,,,,,,,,,,, -75734,,,0.6602148413658142,1.4000751972198486,0.6098600029945374,1.6153631210327148,50000.0,0.4951000213623047,2.2635269165039062,10000.0,34489.581547021866,37494.34624528885,34489.581547021866,2997.275089740753,3.4971187114715576,0.0 -75800,1.7166309,2.9973197,,,,,,,,,,,,,, -75900,1.830271,2.4305692,,,,,,,,,,,,,, -76000,1.9336432,2.3875449,,,,,,,,,,,,,, -76100,1.9091182,2.6505573,,,,,,,,,,,,,, -76200,1.5168608,4.8747234,,,,,,,,,,,,,, -76300,1.8335966,2.5809996,,,,,,,,,,,,,, -76400,2.0350559,2.475803,,,,,,,,,,,,,, -76500,1.946517,2.8936481,,,,,,,,,,,,,, -76600,1.8208874,5.084793,,,,,,,,,,,,,, -76659,,,0.6640429496765137,1.389413833618164,0.6104399561882019,1.645629644393921,50000.0,0.4915000200271606,2.289534568786621,10000.0,34909.62964272499,37952.71143436432,34909.62964272499,3035.50843667984,3.5318963527679443,0.0 -76700,1.9372343,2.519752,,,,,,,,,,,,,, -76800,1.7433604,3.1237228,,,,,,,,,,,,,, -76900,2.110989,2.4552,,,,,,,,,,,,,, -77000,1.9002705,2.4344523,,,,,,,,,,,,,, -77100,1.8906205,2.8310978,,,,,,,,,,,,,, -77200,1.8374726,5.170589,,,,,,,,,,,,,, -77300,1.7115568,3.7333462,,,,,,,,,,,,,, -77400,2.2481277,2.5202522,,,,,,,,,,,,,, -77500,1.6743873,3.6372151,,,,,,,,,,,,,, -77586,,,0.6577734351158142,1.4027554988861084,0.6133599877357483,1.6025428771972656,50000.0,0.4926000237464905,2.264949083328247,10000.0,35329.89449548721,38409.01891493797,35329.89449548721,3071.462685823441,3.570381164550781,0.0 -77600,1.8640174,2.4411728,,,,,,,,,,,,,, -77700,2.0025227,2.640816,,,,,,,,,,,,,, -77800,2.0944788,2.4227178,,,,,,,,,,,,,, -77900,2.0283396,2.4271975,,,,,,,,,,,,,, -78000,1.829429,4.0578613,,,,,,,,,,,,,, -78100,2.082491,2.5204604,,,,,,,,,,,,,, -78200,2.019188,2.3803544,,,,,,,,,,,,,, -78300,1.7879624,2.5444114,,,,,,,,,,,,,, -78400,1.6977426,4.4855995,,,,,,,,,,,,,, -78500,1.9176874,2.5509713,,,,,,,,,,,,,, -78511,,,0.6576171517372131,1.4034545421600342,0.6142399907112122,1.60440194606781,50000.0,0.4919000267982483,2.274023532867432,10000.0,35749.847000837326,38864.033963918686,35749.847000837326,3106.437283039093,3.609963178634644,0.0 -78600,1.9541191,2.4187126,,,,,,,,,,,,,, -78700,2.1135902,2.392499,,,,,,,,,,,,,, -78800,1.7468259,2.8722029,,,,,,,,,,,,,, -78900,1.5390453,3.7352915,,,,,,,,,,,,,, -79000,2.0449355,2.6119545,,,,,,,,,,,,,, -79100,2.0330389,2.3641453,,,,,,,,,,,,,, -79200,1.7304685,4.0083737,,,,,,,,,,,,,, -79300,1.8837395,2.4567893,,,,,,,,,,,,,, -79400,1.9222056,2.3439522,,,,,,,,,,,,,, -79436,,,0.6655468344688416,1.3878076076507568,0.6120399832725525,1.6240185499191284,50000.0,0.4887000322341919,2.2839529514312744,10000.0,36170.04280281067,39317.1111471653,36170.04280281067,3139.2335624694824,3.6466240882873535,0.0 -79500,1.7562793,2.4071267,,,,,,,,,,,,,, -79600,1.8631303,2.6807256,,,,,,,,,,,,,, -79700,1.6427069,3.3876922,,,,,,,,,,,,,, -79800,1.532889,4.536592,,,,,,,,,,,,,, -79900,1.6471877,3.7774289,,,,,,,,,,,,,, -80000,1.8292673,2.8256633,,,,,,,,,,,,,, -80100,1.9966989,2.4288754,,,,,,,,,,,,,, -80200,2.0129201,2.313651,,,,,,,,,,,,,, -80300,1.8137516,2.9599357,,,,,,,,,,,,,, -80360,,,0.6602929830551147,1.4149186611175537,0.6128799915313721,1.6237813234329224,50000.0,0.4982000291347503,2.274289131164551,10000.0,36590.08802413941,39773.61768245697,36590.08802413941,3175.607802391052,3.685581922531128,0.0 -80400,1.687401,3.5204337,,,,,,,,,,,,,, -80500,2.0355692,2.3505526,,,,,,,,,,,,,, -80600,1.6914518,5.029158,,,,,,,,,,,,,, -80700,1.9007697,2.4399142,,,,,,,,,,,,,, -80800,1.8203086,2.6330137,,,,,,,,,,,,,, -80900,1.8536081,2.8258321,,,,,,,,,,,,,, -81000,1.6971166,2.8783598,,,,,,,,,,,,,, -81100,2.213278,2.3221047,,,,,,,,,,,,,, -81200,1.65083,4.315887,,,,,,,,,,,,,, -81284,,,0.6623827815055847,1.4213883876800537,0.6172999739646912,1.631671667098999,50000.0,0.4914000332355499,2.2799715995788574,10000.0,37010.07461738586,40230.83346366882,37010.07461738586,3212.750938415528,3.723136425018311,0.0 -81300,1.5127534,4.50208,,,,,,,,,,,,,, -81400,1.7192092,4.978182,,,,,,,,,,,,,, -81500,1.6557944,3.4983604,,,,,,,,,,,,,, -81600,1.5858883,3.2980156,,,,,,,,,,,,,, -81700,1.7909079,3.9909382,,,,,,,,,,,,,, -81800,2.0208194,2.5399246,,,,,,,,,,,,,, -81900,2.1718824,2.3737745,,,,,,,,,,,,,, -82000,1.9677632,2.386674,,,,,,,,,,,,,, -82100,1.9464653,2.5366774,,,,,,,,,,,,,, -82200,1.6186798,3.616919,,,,,,,,,,,,,, -82208,,,0.6692187190055847,1.3704071044921875,0.6172400116920471,1.5985249280929563,50000.0,0.4945000112056732,2.2693800926208496,10000.0,37430.19033193588,40686.96852493286,37430.19033193588,3248.683898210525,3.760983467102051,0.0 -82300,1.9115113,3.2561378,,,,,,,,,,,,,, -82400,2.0937922,2.4911144,,,,,,,,,,,,,, -82500,2.0224462,2.3746839,,,,,,,,,,,,,, -82600,1.9204481,2.96813,,,,,,,,,,,,,, -82700,1.5687317,3.7525074,,,,,,,,,,,,,, -82800,1.5970072,3.7505004,,,,,,,,,,,,,, -82900,1.5431473,4.6161275,,,,,,,,,,,,,, -83000,2.0789955,2.4177592,,,,,,,,,,,,,, -83100,1.8071272,4.8805156,,,,,,,,,,,,,, -83133,,,0.6871874928474426,1.3113399744033811,0.6131600141525269,1.6342535018920898,50000.0,0.4952000379562378,2.2838222980499268,10000.0,37850.23502016068,41144.50082588196,37850.23502016068,3286.08683013916,3.796276569366455,0.0 -83200,2.1047199,2.4271193,,,,,,,,,,,,,, -83300,2.2073045,2.4410906,,,,,,,,,,,,,, -83400,1.8915052,2.5136304,,,,,,,,,,,,,, -83500,2.1705987,2.3764443,,,,,,,,,,,,,, -83600,1.9796947,2.777493,,,,,,,,,,,,,, -83700,1.8223776,2.887035,,,,,,,,,,,,,, -83800,1.8736705,2.4821527,,,,,,,,,,,,,, -83900,1.7341138,4.498501,,,,,,,,,,,,,, -84000,2.0051544,2.3810127,,,,,,,,,,,,,, -84057,,,0.6665624976158142,1.3889672756195068,0.6204999685287476,1.596768140792847,50000.0,0.5052000284194946,2.2316551208496094,10000.0,38270.56211447716,41600.98387527466,38270.56211447716,3322.1507127285004,3.83987021446228,0.0 -84100,2.130461,2.309981,,,,,,,,,,,,,, -84200,2.0813994,2.642513,,,,,,,,,,,,,, -84300,1.9149473,2.276351,,,,,,,,,,,,,, -84400,2.1140962,2.4238725,,,,,,,,,,,,,, -84500,2.0992417,2.3177896,,,,,,,,,,,,,, -84600,2.0824468,2.5092623,,,,,,,,,,,,,, -84700,1.641766,3.9077568,,,,,,,,,,,,,, -84800,2.6178844,2.5195847,,,,,,,,,,,,,, -84900,1.9829279,2.267094,,,,,,,,,,,,,, -84982,,,0.6727148294448853,1.3522733449935913,0.6233800053596497,1.5781606435775757,50000.0,0.5042999982833862,2.21973204612732,10000.0,38690.58512282372,42056.864844083786,38690.58512282372,3357.9229278564453,3.876704692840576,0.0 -85000,2.2694573,2.5026054,,,,,,,,,,,,,, -85100,1.8657036,3.4031615,,,,,,,,,,,,,, -85200,2.016574,2.343867,,,,,,,,,,,,,, -85300,1.6481142,4.3877993,,,,,,,,,,,,,, -85400,1.6944101,3.5844393,,,,,,,,,,,,,, -85500,2.0691657,2.7075882,,,,,,,,,,,,,, -85600,1.9183147,2.904053,,,,,,,,,,,,,, -85700,1.982631,2.5971572,,,,,,,,,,,,,, -85800,1.8728801,2.5258718,,,,,,,,,,,,,, -85900,1.8629578,4.965603,,,,,,,,,,,,,, -85907,,,0.6804882884025574,1.3314138650894165,0.6174600124359131,1.6178598403930664,50000.0,0.4939000308513641,2.280888557434082,10000.0,39110.57365632057,42512.99793553352,39110.57365632057,3393.9766433238983,3.9189043045043945,0.0 -86000,1.5819038,4.964382,,,,,,,,,,,,,, -86100,1.893463,4.4703608,,,,,,,,,,,,,, -86200,1.800272,4.4180875,,,,,,,,,,,,,, -86300,1.9942681,2.3869076,,,,,,,,,,,,,, -86400,1.9021347,2.3930948,,,,,,,,,,,,,, -86500,2.041293,2.5337996,,,,,,,,,,,,,, -86600,1.760905,2.6132507,,,,,,,,,,,,,, -86700,2.040985,2.434185,,,,,,,,,,,,,, -86800,2.0565145,2.5868309,,,,,,,,,,,,,, -86832,,,0.669238269329071,1.371837854385376,0.6244199872016907,1.578890323638916,50000.0,0.5002000331878662,2.2396240234375,10000.0,39530.89902305603,42969.77918601036,39530.89902305603,3430.34137749672,3.9617397785186768,0.0 -86900,2.0931053,2.3088036,,,,,,,,,,,,,, -87000,1.6406027,5.1522417,,,,,,,,,,,,,, -87100,1.8579865,3.7953963,,,,,,,,,,,,,, -87200,2.0067503,2.3958588,,,,,,,,,,,,,, -87300,1.9496585,2.2733665,,,,,,,,,,,,,, -87400,1.9810866,2.2844489,,,,,,,,,,,,,, -87500,1.5591958,4.1638207,,,,,,,,,,,,,, -87600,1.7290593,3.5076454,,,,,,,,,,,,,, -87700,2.207049,2.3685458,,,,,,,,,,,,,, -87751,,,0.674023449420929,1.3504189252853394,0.6277799606323242,1.565969944000244,50000.0,0.5024000406265259,2.2063982486724854,10000.0,39951.18798828125,43425.609325408936,39951.18798828125,3465.7896132469177,4.005863666534424,0.0 -87800,1.9500477,2.2781346,,,,,,,,,,,,,, -87900,1.8312081,4.5989027,,,,,,,,,,,,,, -88000,1.7905369,4.515996,,,,,,,,,,,,,, -88100,2.161512,2.660937,,,,,,,,,,,,,, -88200,2.0020604,2.2761474,,,,,,,,,,,,,, -88300,2.2483947,2.330315,,,,,,,,,,,,,, -88400,1.9877156,2.4230127,,,,,,,,,,,,,, -88500,1.8957741,2.7575932,,,,,,,,,,,,,, -88600,2.0332222,2.290637,,,,,,,,,,,,,, -88675,,,0.6822851300239563,1.2980283498764038,0.6274799704551697,1.5562463998794556,50000.0,0.5085000395774841,2.1928582191467285,10000.0,40371.13100576401,43880.40012168884,40371.13100576401,3500.5497431755066,4.045783042907715,0.0 -88700,1.7198936,4.027759,,,,,,,,,,,,,, -88800,1.7376337,4.8148704,,,,,,,,,,,,,, -88900,1.7478477,2.5722,,,,,,,,,,,,,, -89000,1.8064313,3.3955288,,,,,,,,,,,,,, -89100,1.7436432,3.594494,,,,,,,,,,,,,, -89200,2.25338,2.365682,,,,,,,,,,,,,, -89300,2.1032393,2.3594382,,,,,,,,,,,,,, -89400,2.223981,2.4313047,,,,,,,,,,,,,, -89500,2.1692562,2.34758,,,,,,,,,,,,,, -89599,,,0.6733984351158142,1.3592572212219238,0.6269800066947937,1.5658754110336304,50000.0,0.5076000094413757,2.21124529838562,10000.0,40791.40249633789,44335.90924882889,40791.40249633789,3535.699191570282,4.085117340087891,0.0 -89600,2.0850072,2.825571,,,,,,,,,,,,,, -89700,2.0236366,2.3130467,,,,,,,,,,,,,, -89800,1.888883,2.7243755,,,,,,,,,,,,,, -89900,2.0289423,2.3085718,,,,,,,,,,,,,, -90000,2.2069468,2.239375,,,,,,,,,,,,,, -90100,2.0424578,2.309905,,,,,,,,,,,,,, -90200,2.031005,2.2451663,,,,,,,,,,,,,, -90300,2.3117857,2.3667963,,,,,,,,,,,,,, -90400,1.9377823,2.6909158,,,,,,,,,,,,,, -90500,2.0372827,2.606418,,,,,,,,,,,,,, -90524,,,0.6755468845367432,1.3368297815322876,0.6282599568367004,1.5485241413116455,50000.0,0.5056000351905823,2.1989150047302246,10000.0,41211.55885767937,44790.13888645172,41211.55885767937,3569.6862609386444,4.122882843017578,0.0 -90600,1.8570291,2.9286814,,,,,,,,,,,,,, -90700,2.1463435,2.484473,,,,,,,,,,,,,, -90800,2.130941,2.5373673,,,,,,,,,,,,,, -90900,1.712022,5.009117,,,,,,,,,,,,,, -91000,1.9788038,2.3250074,,,,,,,,,,,,,, -91100,2.0526192,2.355002,,,,,,,,,,,,,, -91200,1.8876996,2.8696678,,,,,,,,,,,,,, -91300,2.014439,2.432015,,,,,,,,,,,,,, -91400,2.0695715,2.3542528,,,,,,,,,,,,,, -91448,,,0.6873828172683716,1.2694848775863647,0.6334599852561951,1.521888256072998,50000.0,0.513700008392334,2.1726202964782715,10000.0,41631.89220118523,45247.85872173309,41631.89220118523,3606.986034631729,4.161207437515259,0.0 -91500,1.7663615,3.8408496,,,,,,,,,,,,,, -91600,2.1031199,2.602621,,,,,,,,,,,,,, -91700,2.196042,3.0254056,,,,,,,,,,,,,, -91800,1.9565754,4.9592543,,,,,,,,,,,,,, -91900,1.8147033,4.6261573,,,,,,,,,,,,,, -92000,1.9953758,2.146196,,,,,,,,,,,,,, -92100,2.2130256,2.4076686,,,,,,,,,,,,,, -92200,2.0926676,2.6579268,,,,,,,,,,,,,, -92300,2.0385032,4.748847,,,,,,,,,,,,,, -92373,,,0.699414074420929,1.262794017791748,0.6269800066947937,1.5701531171798706,50000.0,0.5078999996185303,2.2144458293914795,10000.0,42052.14500403404,45706.28652644157,42052.14500403404,3645.071723461151,4.2016355991363525,0.0 -92400,2.0711107,2.4082713,,,,,,,,,,,,,, -92500,2.1452904,2.3636963,,,,,,,,,,,,,, -92600,1.811637,3.044616,,,,,,,,,,,,,, -92700,2.0905337,2.2834787,,,,,,,,,,,,,, -92800,2.0646107,2.3520687,,,,,,,,,,,,,, -92900,1.7599219,3.9563677,,,,,,,,,,,,,, -93000,2.1960156,2.5511093,,,,,,,,,,,,,, -93100,1.8550268,3.0714843,,,,,,,,,,,,,, -93200,2.3482337,2.396177,,,,,,,,,,,,,, -93298,,,0.6812304258346558,1.297718167304993,0.63646000623703,1.5049082040786743,50000.0,0.5161000490188599,2.159391164779663,10000.0,42472.23790502548,46159.35181570053,42472.23790502548,3677.9450783729553,4.251505374908447,0.0 -93300,1.7996644,4.968138,,,,,,,,,,,,,, -93400,1.8127049,3.1568317,,,,,,,,,,,,,, -93500,2.4245727,2.370476,,,,,,,,,,,,,, -93600,2.2813947,2.3143141,,,,,,,,,,,,,, -93700,2.2811522,2.4350164,,,,,,,,,,,,,, -93800,1.6798536,3.794177,,,,,,,,,,,,,, -93900,1.8604643,3.0919693,,,,,,,,,,,,,, -94000,2.18218,2.2157073,,,,,,,,,,,,,, -94100,2.020839,2.1015978,,,,,,,,,,,,,, -94200,1.9539652,4.5718474,,,,,,,,,,,,,, -94221,,,0.6886913776397705,1.2740516662597656,0.634880006313324,1.5176643133163452,50000.0,0.5134000182151794,2.162588119506836,10000.0,42892.54639649391,46618.41143655777,42892.54639649391,3716.609961032867,4.289197206497192,0.0 -94300,1.708486,3.863974,,,,,,,,,,,,,, -94400,2.036537,2.2260056,,,,,,,,,,,,,, -94500,2.2643483,2.2680526,,,,,,,,,,,,,, -94600,1.7893602,4.7472,,,,,,,,,,,,,, -94700,2.196281,2.264698,,,,,,,,,,,,,, -94800,1.9175758,3.2522612,,,,,,,,,,,,,, -94900,1.8915989,4.6057963,,,,,,,,,,,,,, -95000,2.2651262,4.7023687,,,,,,,,,,,,,, -95100,2.2510777,2.444527,,,,,,,,,,,,,, -95146,,,0.7041796445846558,1.1996102333068848,0.6363599896430969,1.4963476657867432,50000.0,0.5163000226020813,2.128796100616455,10000.0,43312.84077787399,47077.16275238991,43312.84077787399,3754.973846912384,4.333668947219849,0.0 -95200,2.0943065,2.1293192,,,,,,,,,,,,,, -95300,2.2162488,2.2563083,,,,,,,,,,,,,, -95400,1.946216,3.1383731,,,,,,,,,,,,,, -95500,2.0392265,2.3295808,,,,,,,,,,,,,, -95600,2.3793674,2.5384011,,,,,,,,,,,,,, -95700,2.105615,2.5464616,,,,,,,,,,,,,, -95800,2.4009175,2.3077545,,,,,,,,,,,,,, -95900,1.8892254,4.864751,,,,,,,,,,,,,, -96000,2.0960913,2.2249272,,,,,,,,,,,,,, -96071,,,0.6895898580551147,1.2599599361419678,0.6400399804115295,1.4780635833740234,50000.0,0.5179000496864319,2.110060930252075,10000.0,43733.06075167656,47535.35248112679,43733.06075167656,3792.851813316345,4.377025365829468,0.0 -96100,1.7089694,4.8403625,,,,,,,,,,,,,, -96200,1.778577,3.8905401,,,,,,,,,,,,,, -96300,2.213282,2.3070173,,,,,,,,,,,,,, -96400,2.0136151,2.8659174,,,,,,,,,,,,,, -96500,2.0013568,4.5970135,,,,,,,,,,,,,, -96600,2.0940015,2.2314036,,,,,,,,,,,,,, -96700,1.8185476,4.852652,,,,,,,,,,,,,, -96800,2.3104618,2.3918228,,,,,,,,,,,,,, -96900,2.2036324,2.213354,,,,,,,,,,,,,, -96996,,,0.690234363079071,1.2644189596176147,0.642300009727478,1.498186111450195,50000.0,0.5164999961853027,2.145113468170166,10000.0,44153.27914762497,47992.982228040695,44153.27914762497,3830.168391227722,4.4223480224609375,0.0 -97000,2.0923884,2.1856349,,,,,,,,,,,,,, -97100,1.6818993,3.3838696,,,,,,,,,,,,,, -97200,2.1714396,2.3740156,,,,,,,,,,,,,, -97300,2.1706326,2.0869021,,,,,,,,,,,,,, -97400,2.113436,2.1303174,,,,,,,,,,,,,, -97500,2.1691096,2.8166184,,,,,,,,,,,,,, -97600,2.1675842,2.3220794,,,,,,,,,,,,,, -97700,1.765228,3.929395,,,,,,,,,,,,,, -97800,2.1059146,3.9214396,,,,,,,,,,,,,, -97900,1.81245,3.8198442,,,,,,,,,,,,,, -97914,,,0.6972265243530273,1.2512993812561035,0.6369799971580505,1.5186965465545654,50000.0,0.5138000249862671,2.157207727432251,10000.0,44573.4611530304,48449.29330062866,44573.4611530304,3866.2021346092224,4.46885085105896,0.0 -98000,2.496704,4.851774,,,,,,,,,,,,,, -98100,2.1498754,2.300816,,,,,,,,,,,,,, -98200,2.038095,2.855242,,,,,,,,,,,,,, -98300,1.8602216,3.1915421,,,,,,,,,,,,,, -98400,2.1215851,2.2051473,,,,,,,,,,,,,, -98500,1.8699536,3.782223,,,,,,,,,,,,,, -98600,2.0277085,2.0842247,,,,,,,,,,,,,, -98700,2.2309096,2.3527367,,,,,,,,,,,,,, -98800,1.8165379,4.5657067,,,,,,,,,,,,,, -98837,,,0.6911327838897705,1.2608174085617063,0.6458399891853333,1.466307282447815,50000.0,0.5199000239372253,2.1083343029022217,10000.0,44993.504257678986,48907.94924497605,44993.504257678986,3904.7241473197937,4.510102987289429,0.0 -98900,2.3210497,2.290267,,,,,,,,,,,,,, -99000,1.9428736,3.0126228,,,,,,,,,,,,,, -99100,1.9932364,3.1739597,,,,,,,,,,,,,, -99200,1.8754681,3.2937148,,,,,,,,,,,,,, -99300,1.8180102,3.391234,,,,,,,,,,,,,, -99400,2.0251138,4.2455406,,,,,,,,,,,,,, -99500,2.0411983,2.3778448,,,,,,,,,,,,,, -99600,2.1976767,2.283049,,,,,,,,,,,,,, -99700,1.8878868,3.3006642,,,,,,,,,,,,,, -99760,,,0.6932421922683716,1.244040608406067,0.6432799696922302,1.4662216901779177,50000.0,0.5199000239372253,2.098806858062744,10000.0,45413.90217757225,49368.09958767891,45413.90217757225,3944.386614084244,4.552144289016724,0.0 -99800,1.8417679,4.824582,,,,,,,,,,,,,, -99900,2.0447466,2.3132575,,,,,,,,,,,,,, -100000,2.3043938,2.5080552,,,,,,,,,,,,,, -100100,2.4123576,2.2449577,,,,,,,,,,,,,, -100200,2.061159,3.8566728,,,,,,,,,,,,,, -100300,1.9995092,4.532547,,,,,,,,,,,,,, -100400,2.1600785,2.4131966,,,,,,,,,,,,,, -100500,2.198488,2.353929,,,,,,,,,,,,,, -100600,1.7965008,4.233755,,,,,,,,,,,,,, -100680,,,0.7011132836341858,1.218852400779724,0.6454199552536011,1.4721068143844604,50000.0,0.5200000405311584,2.113632917404175,10000.0,45834.00345420837,49826.67389035225,45834.00345420837,3982.7625353336334,4.600781679153442,0.0 -100700,2.2428033,2.5700579,,,,,,,,,,,,,, -100800,2.1687434,2.3113668,,,,,,,,,,,,,, -100900,2.5313218,2.1933565,,,,,,,,,,,,,, -101000,2.5582829,4.9121137,,,,,,,,,,,,,, -101100,2.1623642,3.3354478,,,,,,,,,,,,,, -101200,2.252008,2.2760491,,,,,,,,,,,,,, -101300,2.0936801,2.5342493,,,,,,,,,,,,,, -101400,2.2094615,2.2869349,,,,,,,,,,,,,, -101500,2.3363981,2.5285664,,,,,,,,,,,,,, -101600,,,0.7016406059265137,1.2149200439453125,0.6446999907493591,1.46499764919281,50000.0,0.522100031375885,2.124752998352051,10000.0,46253.99683356285,50285.926607847214,46253.99683356285,4021.927656650543,4.6470115184783936,0.0 -101600,2.2441814,2.2480679,,,,,,,,,,,,,, -101700,2.4879968,2.1965005,,,,,,,,,,,,,, -101800,1.9454134,4.136139,,,,,,,,,,,,,, -101900,2.3101315,2.2831347,,,,,,,,,,,,,, -102000,1.8717817,4.6808133,,,,,,,,,,,,,, -102100,1.7327234,4.7926407,,,,,,,,,,,,,, -102200,2.35024,2.3014126,,,,,,,,,,,,,, -102300,2.1350672,2.3225253,,,,,,,,,,,,,, -102400,2.3714218,4.830178,,,,,,,,,,,,,, -102500,1.9530979,4.371133,,,,,,,,,,,,,, -102523,,,0.6954296827316284,1.231560230255127,0.6503599882125854,1.4419031143188477,50000.0,0.5236999988555908,2.1005005836486816,10000.0,46674.343448877335,50745.39287304878,46674.343448877335,4060.9536135196686,4.69239354133606,0.0 -102600,2.1464581,2.162324,,,,,,,,,,,,,, -102700,1.8386191,4.390252,,,,,,,,,,,,,, -102800,2.2546666,2.2507977,,,,,,,,,,,,,, -102900,1.8902057,4.3452716,,,,,,,,,,,,,, -103000,2.127714,2.3019,,,,,,,,,,,,,, -103100,2.100966,2.466372,,,,,,,,,,,,,, -103200,2.07462,4.7618055,,,,,,,,,,,,,, -103300,2.1308844,3.0446718,,,,,,,,,,,,,, -103400,2.1788206,2.1672301,,,,,,,,,,,,,, -103444,,,0.7022265195846558,1.2212992906570437,0.6457799673080444,1.4728409051895142,50000.0,0.5247000455856323,2.1184866428375244,10000.0,47094.48765182495,51201.4682199955,47094.48765182495,4096.788547039032,4.73948335647583,0.0 -103500,2.309433,2.4132445,,,,,,,,,,,,,, -103600,1.9117931,3.258737,,,,,,,,,,,,,, -103700,2.1115117,3.2288795,,,,,,,,,,,,,, -103800,2.289006,2.2076707,,,,,,,,,,,,,, -103900,2.3909214,2.262187,,,,,,,,,,,,,, -104000,2.543698,2.2189353,,,,,,,,,,,,,, -104100,2.009141,4.747068,,,,,,,,,,,,,, -104200,2.2382402,2.120685,,,,,,,,,,,,,, -104300,2.3289568,2.2146869,,,,,,,,,,,,,, -104365,,,0.7180859446525574,1.164534091949463,0.6500999927520752,1.4714568853378296,50000.0,0.528700053691864,2.0928614139556885,10000.0,47514.57798838616,51658.96434521675,47514.57798838616,4134.106735706329,4.77846884727478,0.0 -104400,2.1006145,2.8069324,,,,,,,,,,,,,, -104500,2.538785,3.0108929,,,,,,,,,,,,,, -104600,2.4274592,2.1656897,,,,,,,,,,,,,, -104700,2.3306246,2.2016487,,,,,,,,,,,,,, -104800,1.8306719,4.018528,,,,,,,,,,,,,, -104900,2.286635,2.2227275,,,,,,,,,,,,,, -105000,2.2766328,2.143734,,,,,,,,,,,,,, -105100,2.2768729,3.500635,,,,,,,,,,,,,, -105200,1.9239242,3.277268,,,,,,,,,,,,,, -105288,,,0.6989648342132568,1.2240859270095823,0.6522799730300903,1.4324557781219482,50000.0,0.5332000255584717,2.0801806449890137,10000.0,47934.9159283638,52117.99890470505,47934.9159283638,4172.708475351334,4.825288534164429,0.0 -105300,2.3457515,2.2210858,,,,,,,,,,,,,, -105400,2.3761096,2.4638267,,,,,,,,,,,,,, -105500,2.3239217,2.0339913,,,,,,,,,,,,,, -105600,2.2185426,4.436831,,,,,,,,,,,,,, -105700,1.9244536,4.611205,,,,,,,,,,,,,, -105800,1.9737076,3.1765954,,,,,,,,,,,,,, -105900,2.002959,3.3580291,,,,,,,,,,,,,, -106000,2.617327,2.320363,,,,,,,,,,,,,, -106100,1.9903355,3.1243582,,,,,,,,,,,,,, -106200,2.205506,2.151547,,,,,,,,,,,,,, -106209,,,0.7043554782867432,1.194347858428955,0.650879979133606,1.437902331352234,50000.0,0.5314000248908997,2.071938753128052,10000.0,48354.99125123024,52579.87425208092,48354.99125123024,4214.411760091782,4.873553991317749,0.0 -106300,2.2669797,2.0576262,,,,,,,,,,,,,, -106400,2.617219,2.2127628,,,,,,,,,,,,,, -106500,2.237221,2.2821352,,,,,,,,,,,,,, -106600,2.389659,2.033925,,,,,,,,,,,,,, -106700,2.0845413,4.494281,,,,,,,,,,,,,, -106800,2.2187655,2.398715,,,,,,,,,,,,,, -106900,2.5590403,2.061169,,,,,,,,,,,,,, -107000,2.5051353,2.2479787,,,,,,,,,,,,,, -107100,2.4889073,2.292463,,,,,,,,,,,,,, -107131,,,0.717578113079071,1.1480196714401243,0.6521399617195129,1.4363003969192505,50000.0,0.534000039100647,2.086843490600586,10000.0,48775.022804260254,53041.11679935455,48775.022804260254,4255.533785581589,4.913725852966309,0.0 -107200,2.0354283,4.40293,,,,,,,,,,,,,, -107300,2.2369208,2.2680922,,,,,,,,,,,,,, -107400,2.0227416,3.4145024,,,,,,,,,,,,,, -107500,2.1209493,4.738518,,,,,,,,,,,,,, -107600,2.032954,3.3516772,,,,,,,,,,,,,, -107700,2.125485,2.7691677,,,,,,,,,,,,,, -107800,2.287195,2.071297,,,,,,,,,,,,,, -107900,2.2865875,2.2585187,,,,,,,,,,,,,, -108000,2.2459505,2.3841777,,,,,,,,,,,,,, -108050,,,0.7084179520606995,1.183990240097046,0.661359965801239,1.3978753089904783,50000.0,0.5369000434875488,2.042796850204468,10000.0,49194.9777302742,53502.97209262848,49194.9777302742,4297.330185413361,4.968867778778076,0.0 -108100,2.5881207,2.1318738,,,,,,,,,,,,,, -108200,2.0343149,3.4488525,,,,,,,,,,,,,, -108300,2.267424,2.3716633,,,,,,,,,,,,,, -108400,2.318045,4.701435,,,,,,,,,,,,,, -108500,2.0574977,4.771083,,,,,,,,,,,,,, -108600,2.1152055,4.10467,,,,,,,,,,,,,, -108700,1.968782,4.5266075,,,,,,,,,,,,,, -108800,2.0706334,4.1964765,,,,,,,,,,,,,, -108900,2.11267,3.8706222,,,,,,,,,,,,,, -108970,,,0.7154492139816284,1.1541640758514404,0.6593199968338013,1.3946455717086792,50000.0,0.5343000292778015,2.046653509140014,10000.0,49614.986085653305,53960.53556919098,49614.986085653305,4334.789614200592,5.015544414520264,0.0 -109000,1.8581278,3.9327338,,,,,,,,,,,,,, -109100,2.3893154,2.1747959,,,,,,,,,,,,,, -109200,2.4610798,2.121983,,,,,,,,,,,,,, -109300,2.4167523,2.0073895,,,,,,,,,,,,,, -109400,2.1356626,2.8460746,,,,,,,,,,,,,, -109500,2.410475,2.2309446,,,,,,,,,,,,,, -109600,2.4607573,4.5934577,,,,,,,,,,,,,, -109700,2.1543963,4.7048836,,,,,,,,,,,,,, -109800,2.0714056,4.6304526,,,,,,,,,,,,,, -109892,,,0.7190625071525574,1.153494119644165,0.6587399840354919,1.4109584093093872,50000.0,0.5376999974250793,2.04797911643982,10000.0,50035.14333152771,54417.93356990814,50035.14333152771,4371.936519861221,5.060314655303955,0.0 -109900,1.9383601,3.4466488,,,,,,,,,,,,,, -110000,2.221527,4.634578,,,,,,,,,,,,,, -110100,2.6305132,2.492147,,,,,,,,,,,,,, -110200,2.3232415,2.130474,,,,,,,,,,,,,, -110300,1.9103655,4.1299977,,,,,,,,,,,,,, -110400,2.0115194,3.0588665,,,,,,,,,,,,,, -110500,2.4055493,2.228787,,,,,,,,,,,,,, -110600,2.1612754,2.520629,,,,,,,,,,,,,, -110700,2.1777215,2.6435964,,,,,,,,,,,,,, -110800,2.2972045,2.2421951,,,,,,,,,,,,,, -110814,,,0.707324206829071,1.1942484378814695,0.6566799879074097,1.4237462282180786,50000.0,0.5372000336647034,2.042734384536743,10000.0,50455.33958125114,54877.30311059952,50455.33958125114,4411.016355514526,5.104965448379517,0.0 -110900,2.5444078,2.2486885,,,,,,,,,,,,,, -111000,2.4087224,2.1608143,,,,,,,,,,,,,, -111100,2.5688736,2.171113,,,,,,,,,,,,,, -111200,2.6552281,2.5552835,,,,,,,,,,,,,, -111300,2.1496809,4.742508,,,,,,,,,,,,,, -111400,2.3709986,2.0850785,,,,,,,,,,,,,, -111500,2.3906238,2.2311523,,,,,,,,,,,,,, -111600,1.9914683,3.5721388,,,,,,,,,,,,,, -111700,2.1011915,3.9552176,,,,,,,,,,,,,, -111735,,,0.7171288728713989,1.144237995147705,0.6643199920654297,1.3885068893432615,50000.0,0.5392000079154968,2.033914566040039,10000.0,50875.68436551094,55336.55248832703,50875.68436551094,4449.826142311096,5.151665925979614,0.0 -111800,2.084995,3.6318617,,,,,,,,,,,,,, -111900,2.1083763,3.2832065,,,,,,,,,,,,,, -112000,2.3409994,2.4044406,,,,,,,,,,,,,, -112100,2.2688162,2.095293,,,,,,,,,,,,,, -112200,2.1657884,3.9292178,,,,,,,,,,,,,, -112300,2.6910696,2.1426105,,,,,,,,,,,,,, -112400,2.393326,2.5754342,,,,,,,,,,,,,, -112500,1.979413,3.0782824,,,,,,,,,,,,,, -112600,2.437251,2.079402,,,,,,,,,,,,,, -112656,,,0.7238867282867432,1.1074589490890503,0.6677199602127075,1.3568204641342163,50000.0,0.5385000109672546,2.0181593894958496,10000.0,51295.67676925659,55796.26206231117,51295.67676925659,4489.448220968247,5.197967767715454,0.0 -112700,2.0735993,2.945571,,,,,,,,,,,,,, -112800,2.371988,2.0854692,,,,,,,,,,,,,, -112900,2.328408,2.608185,,,,,,,,,,,,,, -113000,2.1197503,4.3058405,,,,,,,,,,,,,, -113100,2.5174947,2.1836896,,,,,,,,,,,,,, -113200,2.6745698,2.498498,,,,,,,,,,,,,, -113300,2.5064566,2.075929,,,,,,,,,,,,,, -113400,2.592104,2.297665,,,,,,,,,,,,,, -113500,2.642384,2.0761967,,,,,,,,,,,,,, -113578,,,0.7433788776397705,1.0337588787078855,0.6680999994277954,1.3554725646972656,50000.0,0.5476000308990479,1.9904879331588743,10000.0,51715.61089348793,56253.195014476776,51715.61089348793,4526.354462623596,5.241725444793701,0.0 -113600,2.7619216,2.0905395,,,,,,,,,,,,,, -113700,2.4101005,2.324842,,,,,,,,,,,,,, -113800,2.3268604,2.401925,,,,,,,,,,,,,, -113900,2.5815086,2.1837513,,,,,,,,,,,,,, -114000,2.2831068,4.322057,,,,,,,,,,,,,, -114100,2.4417868,2.0716925,,,,,,,,,,,,,, -114200,2.3277853,3.880866,,,,,,,,,,,,,, -114300,2.3757794,2.1400046,,,,,,,,,,,,,, -114400,2.4251559,4.49689,,,,,,,,,,,,,, -114500,2.1153977,2.8407125,,,,,,,,,,,,,, -114501,,,0.7174609303474426,1.1560657024383545,0.6663399934768677,1.3868917226791382,50000.0,0.5457000136375427,2.03826904296875,10000.0,52135.60233712196,56712.77788186073,52135.60233712196,4565.85283613205,5.286535739898682,0.0 -114600,2.4774532,2.1393132,,,,,,,,,,,,,, -114700,2.1919343,2.934181,,,,,,,,,,,,,, -114800,2.151253,4.0001764,,,,,,,,,,,,,, -114900,2.186737,3.102269,,,,,,,,,,,,,, -115000,2.513326,2.2892509,,,,,,,,,,,,,, -115100,2.6425986,1.9523948,,,,,,,,,,,,,, -115200,2.4367094,2.0524118,,,,,,,,,,,,,, -115300,2.4486058,2.2285056,,,,,,,,,,,,,, -115400,2.4873552,2.1402729,,,,,,,,,,,,,, -115423,,,0.7275195121765137,1.0962249040603638,0.6660999655723572,1.3563467264175415,50000.0,0.5384000539779663,1.9989176988601685,10000.0,52555.83881497383,57172.142484903336,52555.83881497383,4604.891656398773,5.327480792999268,0.0 -115500,2.4503696,2.5160487,,,,,,,,,,,,,, -115600,2.6749105,2.1746264,,,,,,,,,,,,,, -115700,2.334808,2.0058463,,,,,,,,,,,,,, -115800,2.3065898,4.021101,,,,,,,,,,,,,, -115900,2.5809464,2.007594,,,,,,,,,,,,,, -116000,2.2073817,3.2260947,,,,,,,,,,,,,, -116100,2.5694566,2.1422017,,,,,,,,,,,,,, -116200,2.3240507,3.4040234,,,,,,,,,,,,,, -116300,2.5474205,2.1376402,,,,,,,,,,,,,, -116343,,,0.7395312190055847,1.0397402048110962,0.6725199818611145,1.3434137105941772,50000.0,0.5468000173568726,1.994463562965393,10000.0,52975.88430976868,57628.63620185852,52975.88430976868,4641.241088867188,5.377520561218262,0.0 -116400,2.2519438,2.655924,,,,,,,,,,,,,, -116500,2.4183004,2.02516,,,,,,,,,,,,,, -116600,2.337585,2.2334409,,,,,,,,,,,,,, -116700,2.543452,2.3473918,,,,,,,,,,,,,, -116800,2.2753227,3.703163,,,,,,,,,,,,,, -116900,2.3500125,2.2952883,,,,,,,,,,,,,, -117000,2.710053,2.0127492,,,,,,,,,,,,,, -117100,2.2666945,3.8042412,,,,,,,,,,,,,, -117200,2.4696903,2.086147,,,,,,,,,,,,,, -117267,,,0.7267968654632568,1.0954101085662842,0.6725599765777588,1.3405753374099731,50000.0,0.5462000370025635,1.9877381324768064,10000.0,53396.02493786812,58088.51064157486,53396.02493786812,4680.875846385956,5.427154302597046,0.0 -117300,2.4959002,2.2844481,,,,,,,,,,,,,, -117400,2.3930604,2.0208008,,,,,,,,,,,,,, -117500,2.6312954,2.07971,,,,,,,,,,,,,, -117600,2.3341503,2.7184658,,,,,,,,,,,,,, -117700,2.369432,1.9630387,,,,,,,,,,,,,, -117800,2.5618277,2.0493567,,,,,,,,,,,,,, -117900,2.4059327,1.9975123,,,,,,,,,,,,,, -118000,2.5897193,2.0405476,,,,,,,,,,,,,, -118100,2.6240568,2.0955846,,,,,,,,,,,,,, -118189,,,0.7324609160423279,1.0725865364074707,0.6753399968147278,1.3302712440490725,50000.0,0.551300048828125,1.9637449979782104,10000.0,53816.141922950745,58548.5072760582,53816.141922950745,4720.6627950668335,5.471050262451172,0.0 -118200,2.536024,2.0234687,,,,,,,,,,,,,, -118300,2.8056583,2.0886266,,,,,,,,,,,,,, -118400,2.453296,4.604456,,,,,,,,,,,,,, -118500,2.2045386,3.633,,,,,,,,,,,,,, -118600,2.9988778,2.117253,,,,,,,,,,,,,, -118700,2.4907417,3.9099133,,,,,,,,,,,,,, -118800,2.8597288,2.0869126,,,,,,,,,,,,,, -118900,2.4653635,2.7332575,,,,,,,,,,,,,, -119000,2.930037,1.9759195,,,,,,,,,,,,,, -119100,2.5682614,2.1295493,,,,,,,,,,,,,, -119112,,,0.7394335865974426,1.0667121410369873,0.6752600073814392,1.3475139141082764,50000.0,0.5527000427246094,1.9856129884719849,10000.0,54236.35823750496,59005.27227306366,54236.35823750496,4757.115355014801,5.5186426639556885,0.0 -119200,2.5278707,2.027934,,,,,,,,,,,,,, -119300,2.8288386,2.0549917,,,,,,,,,,,,,, -119400,2.4585416,2.4228103,,,,,,,,,,,,,, -119500,2.2364323,2.4563932,,,,,,,,,,,,,, -119600,2.557365,2.5204601,,,,,,,,,,,,,, -119700,2.4730117,2.1485019,,,,,,,,,,,,,, -119800,2.4392369,3.4006,,,,,,,,,,,,,, -119900,2.5684216,4.3038783,,,,,,,,,,,,,, -120000,2.6494682,2.012068,,,,,,,,,,,,,, -120033,,,0.729199230670929,1.0772337913513184,0.675279974937439,1.3255867958068848,50000.0,0.5491000413894653,1.9687131643295288,10000.0,54656.45013904572,59460.71386647224,54656.45013904572,4792.37073636055,5.564018726348877,0.0 -120100,2.2357519,3.4756863,,,,,,,,,,,,,, -120200,2.623279,3.9771078,,,,,,,,,,,,,, -120300,2.2957273,3.533763,,,,,,,,,,,,,, -120400,2.569287,2.129435,,,,,,,,,,,,,, -120500,2.4135478,2.033609,,,,,,,,,,,,,, -120600,2.256791,3.3320558,,,,,,,,,,,,,, -120700,2.4634478,2.1091392,,,,,,,,,,,,,, -120800,2.5048661,1.9832859,,,,,,,,,,,,,, -120900,2.8638773,2.0995514,,,,,,,,,,,,,, -120955,,,0.7299999594688416,1.0698095560073853,0.6780999898910522,1.307468056678772,50000.0,0.5529000163078308,1.95231294631958,10000.0,55076.50438141823,59919.06721377373,55076.50438141823,4830.57146859169,5.6144208908081055,0.0 -121000,2.7614868,1.9293876,,,,,,,,,,,,,, -121100,2.4705014,2.008113,,,,,,,,,,,,,, -121200,2.7320373,1.9370265,,,,,,,,,,,,,, -121300,2.4589176,3.5018246,,,,,,,,,,,,,, -121400,2.3877392,3.776824,,,,,,,,,,,,,, -121500,2.8146393,2.1968954,,,,,,,,,,,,,, -121600,2.3097215,3.797229,,,,,,,,,,,,,, -121700,2.49772,2.225124,,,,,,,,,,,,,, -121800,2.9395912,1.9526706,,,,,,,,,,,,,, -121875,,,0.7435156106948853,1.0205103158950806,0.6786999702453613,1.2984049320220947,50000.0,0.5550000071525574,1.928056001663208,10000.0,55496.58257865906,60376.49431824684,55496.58257865906,4867.829748630524,5.6567864418029785,0.0 -121900,2.3300514,3.7414625,,,,,,,,,,,,,, -122000,2.5170124,1.8821465,,,,,,,,,,,,,, -122100,2.6079044,2.0440059,,,,,,,,,,,,,, -122200,2.5430849,4.5032706,,,,,,,,,,,,,, -122300,3.024065,2.5867937,,,,,,,,,,,,,, -122400,2.760115,4.5359654,,,,,,,,,,,,,, -122500,2.8350518,1.9481199,,,,,,,,,,,,,, -122600,2.6333723,4.488834,,,,,,,,,,,,,, -122700,2.4337258,2.9468446,,,,,,,,,,,,,, -122794,,,0.7494726181030273,1.0063884258270264,0.6796199679374695,1.3093905448913574,50000.0,0.5601000189781189,1.9321649074554443,10000.0,55916.64206838608,60835.11968517304,55916.64206838608,4906.302020072937,5.70208215713501,0.0 -122800,2.5887287,4.5569525,,,,,,,,,,,,,, -122900,2.6134343,2.3738155,,,,,,,,,,,,,, -123000,2.8639848,1.9404898,,,,,,,,,,,,,, -123100,2.7626028,1.9677141,,,,,,,,,,,,,, -123200,2.9966197,4.373678,,,,,,,,,,,,,, -123300,2.6462722,2.1603308,,,,,,,,,,,,,, -123400,2.3386083,3.0614383,,,,,,,,,,,,,, -123500,2.3353324,3.596727,,,,,,,,,,,,,, -123600,2.906376,1.9915233,,,,,,,,,,,,,, -123700,2.6892862,2.6453958,,,,,,,,,,,,,, -123715,,,0.7424414157867432,1.035762071609497,0.6828199625015259,1.2859761714935305,50000.0,0.5570999979972839,1.92024028301239,10000.0,56336.915759801865,61291.97744345665,56336.915759801865,4942.790773868561,5.748456001281738,0.0 -123800,2.8583694,4.53552,,,,,,,,,,,,,, -123900,2.9004877,2.0710359,,,,,,,,,,,,,, -124000,2.710842,3.9968572,,,,,,,,,,,,,, -124100,2.4262567,3.8902178,,,,,,,,,,,,,, -124200,2.6753132,2.280857,,,,,,,,,,,,,, -124300,2.828898,4.4741864,,,,,,,,,,,,,, -124400,2.8676808,2.0182438,,,,,,,,,,,,,, -124500,2.8428285,1.9282475,,,,,,,,,,,,,, -124600,2.519962,2.709589,,,,,,,,,,,,,, -124636,,,0.7409374713897705,1.054355263710022,0.6823399662971497,1.299355149269104,50000.0,0.5631000399589539,1.9321929216384888,10000.0,56757.10899662972,61752.66572880745,56757.10899662972,4983.189473390579,5.796940326690674,0.0 -124700,2.7272055,2.054721,,,,,,,,,,,,,, -124800,2.8905866,1.966248,,,,,,,,,,,,,, -124900,2.7673748,1.9483953,,,,,,,,,,,,,, -125000,2.6282358,2.9766622,,,,,,,,,,,,,, -125100,2.7419474,1.941865,,,,,,,,,,,,,, -125200,2.8528228,2.0091732,,,,,,,,,,,,,, -125300,3.00266,1.9085169,,,,,,,,,,,,,, -125400,2.5774746,2.4343483,,,,,,,,,,,,,, -125500,2.6800666,4.114108,,,,,,,,,,,,,, -125557,,,0.7582616806030273,0.9785751104354858,0.6868399977684021,1.288220763206482,50000.0,0.5631999969482422,1.9148192405700684,10000.0,57177.09319901466,62214.15758442879,57177.09319901466,5024.603069782257,5.8421630859375,0.0 -125600,2.6255705,1.8843081,,,,,,,,,,,,,, -125700,2.75154,3.4887679,,,,,,,,,,,,,, -125800,3.0180573,2.1137104,,,,,,,,,,,,,, -125900,2.5512085,2.9490511,,,,,,,,,,,,,, -126000,2.6419916,3.575731,,,,,,,,,,,,,, -126100,3.3869524,2.0645216,,,,,,,,,,,,,, -126200,2.7895045,1.9562877,,,,,,,,,,,,,, -126300,2.956941,2.0593286,,,,,,,,,,,,,, -126400,2.597974,4.0443096,,,,,,,,,,,,,, -126481,,,0.744921863079071,1.0215235948562622,0.6889199614524841,1.2752596139907837,50000.0,0.5613000392913818,1.9243862628936768,10000.0,57597.37649774552,62675.59431099892,57597.37649774552,5065.656978368759,5.892143726348877,0.0 -126500,2.6971476,2.5254858,,,,,,,,,,,,,, -126600,3.029172,1.8771944,,,,,,,,,,,,,, -126700,2.4474237,3.0208948,,,,,,,,,,,,,, -126800,2.8496232,1.9516565,,,,,,,,,,,,,, -126900,2.9003801,2.015558,,,,,,,,,,,,,, -127000,2.8939354,2.0179737,,,,,,,,,,,,,, -127100,3.1076813,2.250421,,,,,,,,,,,,,, -127200,3.1233325,1.9961689,,,,,,,,,,,,,, -127300,2.6863096,3.3467667,,,,,,,,,,,,,, -127400,2.7225404,1.9006706,,,,,,,,,,,,,, -127404,,,0.7497265338897705,1.0092304944992063,0.6896399855613708,1.2700551748275757,50000.0,0.572100043296814,1.8867379426956177,10000.0,58017.32027029991,63134.48843693733,58017.32027029991,5104.507493257523,5.944280862808228,0.0 -127500,2.449443,2.4855995,,,,,,,,,,,,,, -127600,3.081644,2.0864418,,,,,,,,,,,,,, -127700,2.9264476,1.9879935,,,,,,,,,,,,,, -127800,2.8255937,3.7229915,,,,,,,,,,,,,, -127900,2.8809779,1.9000459,,,,,,,,,,,,,, -128000,2.7835696,2.0804286,,,,,,,,,,,,,, -128100,2.954255,2.140007,,,,,,,,,,,,,, -128200,2.7372794,2.3197308,,,,,,,,,,,,,, -128300,3.0633411,1.9888949,,,,,,,,,,,,,, -128326,,,0.7568749785423279,0.9682868123054504,0.6923800110816956,1.2558635473251345,50000.0,0.5698000192642212,1.8772536516189573,10000.0,58437.31800484657,63591.67713737488,58437.31800484657,5141.601754665375,5.991713047027588,0.0 -128400,2.9403694,1.827213,,,,,,,,,,,,,, -128500,2.7773883,2.1624334,,,,,,,,,,,,,, -128600,3.1561077,2.046884,,,,,,,,,,,,,, -128700,2.815964,2.1814072,,,,,,,,,,,,,, -128800,3.1591084,1.901117,,,,,,,,,,,,,, -128900,2.9793782,1.9441564,,,,,,,,,,,,,, -129000,2.8679326,4.5051517,,,,,,,,,,,,,, -129100,2.5826323,2.8856082,,,,,,,,,,,,,, -129200,2.787979,4.0243073,,,,,,,,,,,,,, -129247,,,0.7479491829872131,0.9964887499809264,0.6901599764823914,1.2488561868667605,50000.0,0.5692000389099121,1.8783817291259768,10000.0,58857.65893149376,64047.86702609062,58857.65893149376,5177.357530832291,6.036379098892212,0.0 -129300,2.4502113,2.3303227,,,,,,,,,,,,,, -129400,2.4983416,2.7416134,,,,,,,,,,,,,, -129500,2.9465735,1.9345812,,,,,,,,,,,,,, -129600,3.2272205,1.880481,,,,,,,,,,,,,, -129700,2.9127252,2.2889347,,,,,,,,,,,,,, -129800,3.248922,4.2532425,,,,,,,,,,,,,, -129900,3.0496376,4.1698637,,,,,,,,,,,,,, -130000,2.6134024,3.4683,,,,,,,,,,,,,, -130100,2.7051682,4.2372303,,,,,,,,,,,,,, -130169,,,0.75501948595047,0.979481041431427,0.6937199831008911,1.2402454614639282,50000.0,0.5705000162124634,1.873192310333252,10000.0,59277.79971027374,64509.075212717056,59277.79971027374,5218.325902700424,6.086957216262817,0.0 -130200,2.9494283,2.003627,,,,,,,,,,,,,, -130300,2.8616316,2.355172,,,,,,,,,,,,,, -130400,3.078238,2.515328,,,,,,,,,,,,,, -130500,3.114644,1.9873273,,,,,,,,,,,,,, -130600,2.4647422,2.9337406,,,,,,,,,,,,,, -130700,3.0738726,1.963497,,,,,,,,,,,,,, -130800,3.0933456,2.6806324,,,,,,,,,,,,,, -130900,3.4602213,1.8648899,,,,,,,,,,,,,, -131000,3.1171446,3.291775,,,,,,,,,,,,,, -131090,,,0.7644726634025574,0.9445186257362366,0.6960799694061279,1.2283663749694824,50000.0,0.5722000002861023,1.867382287979126,10000.0,59698.05014848709,64971.06675004959,59698.05014848709,5259.966984272003,6.138098955154419,0.0 -131100,3.2234883,1.913284,,,,,,,,,,,,,, -131200,2.8562863,2.427536,,,,,,,,,,,,,, -131300,2.7394693,3.05035,,,,,,,,,,,,,, -131400,3.0164437,1.9340143,,,,,,,,,,,,,, -131500,2.7869432,2.19079,,,,,,,,,,,,,, -131600,2.979138,2.2199483,,,,,,,,,,,,,, -131700,2.8609843,1.8495553,,,,,,,,,,,,,, -131800,3.0870638,1.9172673,,,,,,,,,,,,,, -131900,2.9316356,3.227059,,,,,,,,,,,,,, -132000,2.956385,1.7859991,,,,,,,,,,,,,, -132012,,,0.763867199420929,0.9349223375320436,0.6986199617385864,1.2158799171447754,50000.0,0.5773000121116638,1.8444945812225344,10000.0,60117.96479392052,65429.777752399445,60117.96479392052,5298.666358947754,6.187035799026489,0.0 -132100,3.1191607,1.7195902,,,,,,,,,,,,,, -132200,2.946316,2.3147135,,,,,,,,,,,,,, -132300,2.9455316,3.2163014,,,,,,,,,,,,,, -132400,3.2829378,1.878391,,,,,,,,,,,,,, -132500,2.8720403,2.0936205,,,,,,,,,,,,,, -132600,3.13123,1.9098448,,,,,,,,,,,,,, -132700,3.1805382,4.4195437,,,,,,,,,,,,,, -132800,2.9445422,1.9026481,,,,,,,,,,,,,, -132900,2.973351,1.8992742,,,,,,,,,,,,,, -132935,,,0.7568749785423279,0.9747950434684752,0.6997199654579163,1.2330219745635986,50000.0,0.5715000033378601,1.8723223209381104,10000.0,60538.21879982948,65887.87410736084,60538.21879982948,5336.413270950317,6.234623432159424,0.0 -133000,3.1445634,1.900179,,,,,,,,,,,,,, -133100,3.1140819,2.713271,,,,,,,,,,,,,, -133200,2.891459,2.8727565,,,,,,,,,,,,,, -133300,3.052402,4.23892,,,,,,,,,,,,,, -133400,3.1078842,1.8158029,,,,,,,,,,,,,, -133500,3.327575,1.8632022,,,,,,,,,,,,,, -133600,3.2083786,1.8115686,,,,,,,,,,,,,, -133700,2.923737,2.6581175,,,,,,,,,,,,,, -133800,3.823679,1.8558778,,,,,,,,,,,,,, -133860,,,0.7635351419448853,0.9517484307289124,0.7021600008010864,1.2254695892333984,50000.0,0.5763000249862671,1.8516438007354736,10000.0,60958.368864774704,66344.07096099854,60958.368864774704,5372.358961343765,6.28567361831665,0.0 -133900,3.0510826,2.4046493,,,,,,,,,,,,,, -134000,3.3510625,1.9630069,,,,,,,,,,,,,, -134100,2.7427037,2.3421075,,,,,,,,,,,,,, -134200,3.2244334,2.6362011,,,,,,,,,,,,,, -134300,2.9091756,1.7196137,,,,,,,,,,,,,, -134400,3.184522,4.1811466,,,,,,,,,,,,,, -134500,2.8617465,2.2385192,,,,,,,,,,,,,, -134600,3.199396,1.9743977,,,,,,,,,,,,,, -134700,2.8705485,2.963106,,,,,,,,,,,,,, -134783,,,0.7740820050239563,0.883020281791687,0.6988799571990967,1.2179479598999023,50000.0,0.5750000476837158,1.8422235250473025,10000.0,61378.45201802254,66799.75529813766,61378.45201802254,5407.867599010468,6.32945704460144,0.0 -134800,3.1718254,2.061259,,,,,,,,,,,,,, -134900,3.1830509,1.7844965,,,,,,,,,,,,,, -135000,3.0950904,3.1687596,,,,,,,,,,,,,, -135100,3.4047246,1.951361,,,,,,,,,,,,,, -135200,2.92735,3.9634583,,,,,,,,,,,,,, -135300,3.3421288,1.8945729,,,,,,,,,,,,,, -135400,2.715973,3.1936693,,,,,,,,,,,,,, -135500,3.0234647,4.1221204,,,,,,,,,,,,,, -135600,3.1121428,4.101854,,,,,,,,,,,,,, -135700,3.4862523,1.8431666,,,,,,,,,,,,,, -135707,,,0.761914074420929,0.9550774097442628,0.6998999714851379,1.2158119678497314,50000.0,0.5833000540733337,1.828897714614868,10000.0,61798.35517120361,67256.5339550972,61798.35517120361,5444.643156290054,6.38027548789978,0.0 -135800,3.2738934,1.776808,,,,,,,,,,,,,, -135900,3.1568472,1.8235399,,,,,,,,,,,,,, -136000,3.2040393,2.0728793,,,,,,,,,,,,,, -136100,3.5709636,1.923731,,,,,,,,,,,,,, -136200,3.1221778,1.877029,,,,,,,,,,,,,, -136300,3.1038954,2.370121,,,,,,,,,,,,,, -136400,2.937459,2.556182,,,,,,,,,,,,,, -136500,3.5026674,1.8849714,,,,,,,,,,,,,, -136600,3.282268,1.7520446,,,,,,,,,,,,,, -136632,,,0.7692968845367432,0.9049091935157776,0.7047799825668335,1.1906603574752808,50000.0,0.5819000005722046,1.79646897315979,10000.0,62218.55541443825,67712.98147702217,62218.55541443825,5480.798250198364,6.4242448806762695,0.0 -136700,3.5093482,1.8815341,,,,,,,,,,,,,, -136800,3.1046245,2.4225936,,,,,,,,,,,,,, -136900,3.2160814,1.8156397,,,,,,,,,,,,,, -137000,3.307202,1.7698307,,,,,,,,,,,,,, -137100,3.3047965,2.1875792,,,,,,,,,,,,,, -137200,3.535988,1.9845624,,,,,,,,,,,,,, -137300,3.2110248,4.229994,,,,,,,,,,,,,, -137400,3.1444218,3.3577607,,,,,,,,,,,,,, -137500,3.2092109,1.7478265,,,,,,,,,,,,,, -137558,,,0.7782226204872131,0.8680867552757263,0.7066400051116943,1.1834466457366943,50000.0,0.5800000429153442,1.7963061332702637,10000.0,62638.88868141174,68176.90643644333,62638.88868141174,5524.29457783699,6.470351934432983,0.0 -137600,3.316852,2.1059804,,,,,,,,,,,,,, -137700,3.251811,3.2077456,,,,,,,,,,,,,, -137800,3.1659272,2.288481,,,,,,,,,,,,,, -137900,3.281427,2.2681985,,,,,,,,,,,,,, -138000,3.1319826,1.9683911,,,,,,,,,,,,,, -138100,3.0911574,2.8640745,,,,,,,,,,,,,, -138200,3.4053285,1.7607133,,,,,,,,,,,,,, -138300,3.4040215,1.8526671,,,,,,,,,,,,,, -138400,3.0321784,2.3809078,,,,,,,,,,,,,, -138483,,,0.768750011920929,0.9112576842308044,0.7040199637413025,1.1906100511550903,50000.0,0.5848000049591064,1.7935667037963867,10000.0,63058.98669815064,68633.01377010345,63058.98669815064,5560.188486814499,6.537533760070801,0.0 -138500,3.410994,2.8407614,,,,,,,,,,,,,, -138600,3.6860564,1.7153707,,,,,,,,,,,,,, -138700,3.704127,4.0737715,,,,,,,,,,,,,, -138800,3.5668461,1.7964357,,,,,,,,,,,,,, -138900,3.3387637,4.329763,,,,,,,,,,,,,, -139000,4.2670712,4.3247657,,,,,,,,,,,,,, -139100,3.5087144,1.7691185,,,,,,,,,,,,,, -139200,3.1181037,1.6322618,,,,,,,,,,,,,, -139300,3.5114598,1.8437636,,,,,,,,,,,,,, -139400,3.6529374,1.8436929,,,,,,,,,,,,,, -139405,,,0.775195300579071,0.8927517533302307,0.7079600095748901,1.1707615852355957,50000.0,0.585800051689148,1.7905820608139038,10000.0,63478.89449167252,69097.40825605392,63478.89449167252,5604.57666349411,6.587302923202515,0.0 -139500,3.3728433,1.6287972,,,,,,,,,,,,,, -139600,3.5363681,1.7684115,,,,,,,,,,,,,, -139700,3.8261225,1.7941316,,,,,,,,,,,,,, -139800,3.4745266,3.9756105,,,,,,,,,,,,,, -139900,3.0277548,3.2392125,,,,,,,,,,,,,, -140000,3.6086593,1.8366061,,,,,,,,,,,,,, -140100,3.0976772,3.2701397,,,,,,,,,,,,,, -140200,3.545101,1.9545057,,,,,,,,,,,,,, -140300,3.4946434,1.9408214,,,,,,,,,,,,,, -140328,,,0.7830859422683716,0.8475367426872253,0.7130999565124512,1.1560105085372925,50000.0,0.5877000093460083,1.781422138214111,10000.0,63898.79946422577,69555.59872603416,63898.79946422577,5642.757848501205,6.641981601715088,0.0 -140400,3.2978067,4.220694,,,,,,,,,,,,,, -140500,3.888579,2.0398216,,,,,,,,,,,,,, -140600,3.2378523,3.4082422,,,,,,,,,,,,,, -140700,4.3042192,4.232524,,,,,,,,,,,,,, -140800,4.17108,1.9686404,,,,,,,,,,,,,, -140900,3.422002,1.8449831,,,,,,,,,,,,,, -141000,3.6244516,1.8340125,,,,,,,,,,,,,, -141100,3.4226327,1.8502107,,,,,,,,,,,,,, -141200,3.2646918,2.4010692,,,,,,,,,,,,,, -141248,,,0.7801562547683716,0.8689774870872498,0.7124599814414978,1.151098370552063,50000.0,0.5913000106811523,1.7659611701965332,10000.0,64318.720710515976,70020.60980701447,64318.720710515976,5687.743291378021,6.696745872497559,0.0 -141300,3.3724241,1.6942029,,,,,,,,,,,,,, -141400,3.5548093,4.262327,,,,,,,,,,,,,, -141500,3.1741004,1.7522273,,,,,,,,,,,,,, -141600,3.6868925,1.6787851,,,,,,,,,,,,,, -141700,3.8783445,1.91708,,,,,,,,,,,,,, -141800,3.4344945,1.7419398,,,,,,,,,,,,,, -141900,3.6992686,1.8104169,,,,,,,,,,,,,, -142000,3.4898992,1.7672007,,,,,,,,,,,,,, -142100,3.1901934,2.8893676,,,,,,,,,,,,,, -142170,,,0.7764452695846558,0.8885748982429504,0.7087399959564209,1.1719297170639038,50000.0,0.58760005235672,1.802656173706055,10000.0,64738.64544630051,70478.83239912987,64738.64544630051,5725.94082069397,6.748511791229248,0.0 -142200,4.242071,4.3381357,,,,,,,,,,,,,, -142300,3.4374268,2.175965,,,,,,,,,,,,,, -142400,3.540012,1.6935447,,,,,,,,,,,,,, -142500,3.3946202,3.725673,,,,,,,,,,,,,, -142600,4.5726805,4.03217,,,,,,,,,,,,,, -142700,3.5361202,2.285632,,,,,,,,,,,,,, -142800,4.212102,4.2150145,,,,,,,,,,,,,, -142900,3.5925803,2.530478,,,,,,,,,,,,,, -143000,3.6898742,2.2191308,,,,,,,,,,,,,, -143093,,,0.7829882502555847,0.8484798073768616,0.7154799699783325,1.1398857831954956,50000.0,0.5895000100135803,1.753939509391785,10000.0,65158.87682819367,70937.52128458023,65158.87682819367,5764.303265571594,6.794252634048462,0.0 -143100,3.795315,4.05919,,,,,,,,,,,,,, -143200,3.3043268,2.8798769,,,,,,,,,,,,,, -143300,3.7492473,3.526691,,,,,,,,,,,,,, -143400,3.5181677,1.8603787,,,,,,,,,,,,,, -143500,3.7537365,1.7126486,,,,,,,,,,,,,, -143600,3.906632,1.7453496,,,,,,,,,,,,,, -143700,3.6971083,1.7333473,,,,,,,,,,,,,, -143800,4.1704855,4.1656895,,,,,,,,,,,,,, -143900,3.318805,3.0340347,,,,,,,,,,,,,, -144000,3.9180744,4.1515465,,,,,,,,,,,,,, -144016,,,0.7986913919448853,0.7790917158126831,0.7196399569511414,1.120820164680481,50000.0,0.5962000489234924,1.7356946468353271,10000.0,65579.17691516876,71402.10477161407,65579.17691516876,5808.492747783661,6.839449644088745,0.0 -144100,3.608887,1.7167962,,,,,,,,,,,,,, -144200,3.5880442,2.7628007,,,,,,,,,,,,,, -144300,4.265957,4.211528,,,,,,,,,,,,,, -144400,3.5174813,2.3787823,,,,,,,,,,,,,, -144500,3.5324361,3.3219633,,,,,,,,,,,,,, -144600,3.2014825,2.4372127,,,,,,,,,,,,,, -144700,3.622613,2.7736223,,,,,,,,,,,,,, -144800,3.4970844,2.1803713,,,,,,,,,,,,,, -144900,3.9247992,1.8272429,,,,,,,,,,,,,, -144940,,,0.7861132621765137,0.835200846195221,0.718999981880188,1.122183084487915,50000.0,0.5929000377655029,1.7483426332473757,10000.0,65999.29121685028,71859.84184217453,65999.29121685028,5846.021278142929,6.884845018386841,0.0 -145000,3.5595639,2.5512652,,,,,,,,,,,,,, -145100,4.1547055,1.9991962,,,,,,,,,,,,,, -145200,3.7791758,1.6891117,,,,,,,,,,,,,, -145300,3.8049827,4.107191,,,,,,,,,,,,,, -145400,4.035031,1.7561641,,,,,,,,,,,,,, -145500,3.594986,2.1399302,,,,,,,,,,,,,, -145600,3.4534962,2.7822905,,,,,,,,,,,,,, -145700,3.6322112,1.6030089,,,,,,,,,,,,,, -145800,3.8830664,1.8804338,,,,,,,,,,,,,, -145865,,,0.789746105670929,0.8230050206184387,0.7236599922180176,1.11073899269104,50000.0,0.600600004196167,1.7311910390853882,10000.0,66419.42792582512,72317.11073088646,66419.42792582512,5883.058493375778,6.93181300163269,0.0 -145900,4.029062,4.165863,,,,,,,,,,,,,, -146000,3.7942636,1.7263763,,,,,,,,,,,,,, -146100,3.8649936,1.7251912,,,,,,,,,,,,,, -146200,3.6646473,1.703821,,,,,,,,,,,,,, -146300,3.7937717,1.6449955,,,,,,,,,,,,,, -146400,3.6329167,2.7538111,,,,,,,,,,,,,, -146500,4.603946,4.184206,,,,,,,,,,,,,, -146600,3.966431,1.7410121,,,,,,,,,,,,,, -146700,3.5704322,1.7662629,,,,,,,,,,,,,, -146791,,,0.7936913967132568,0.7987895011901855,0.7238999605178833,1.1063193082809448,50000.0,0.5986000299453735,1.7203985452651978,10000.0,66839.75595474243,72776.55774569511,66839.75595474243,5922.079788208008,6.981026649475098,0.0 -146800,4.4389997,4.135357,,,,,,,,,,,,,, -146900,3.281002,3.5972652,,,,,,,,,,,,,, -147000,3.3604515,2.6550088,,,,,,,,,,,,,, -147100,3.5211644,2.6262803,,,,,,,,,,,,,, -147200,4.1490355,1.6613685,,,,,,,,,,,,,, -147300,3.7280445,3.6189566,,,,,,,,,,,,,, -147400,3.7799551,2.1550043,,,,,,,,,,,,,, -147500,4.556594,3.6860173,,,,,,,,,,,,,, -147600,3.8012373,3.8555539,,,,,,,,,,,,,, -147700,3.3948588,3.3621976,,,,,,,,,,,,,, -147715,,,0.7894726395606995,0.8196678757667542,0.7218999862670898,1.1062513589859009,50000.0,0.6008000373840332,1.724432110786438,10000.0,67259.88791394234,73234.11113333702,67259.88791394234,5959.401660680771,7.031947135925293,0.0 -147800,4.2797985,3.9919014,,,,,,,,,,,,,, -147900,4.046797,1.7025193,,,,,,,,,,,,,, -148000,4.0623035,1.7168726,,,,,,,,,,,,,, -148100,4.205604,3.841926,,,,,,,,,,,,,, -148200,4.206076,1.7565036,,,,,,,,,,,,,, -148300,4.095108,1.6255158,,,,,,,,,,,,,, -148400,3.8249993,1.7588247,,,,,,,,,,,,,, -148500,3.8540354,1.5674219,,,,,,,,,,,,,, -148600,4.0782957,3.5167596,,,,,,,,,,,,,, -148639,,,0.7953320145606995,0.7963201999664307,0.724079966545105,1.092710256576538,50000.0,0.6016000509262085,1.7002700567245483,10000.0,67680.01170182228,73691.21638917923,67680.01170182228,5996.285425662994,7.0817482471466064,0.0 -148700,3.5611303,3.3002167,,,,,,,,,,,,,, -148800,3.8894627,3.1398351,,,,,,,,,,,,,, -148900,4.01231,2.829893,,,,,,,,,,,,,, -149000,4.0276327,3.8014069,,,,,,,,,,,,,, -149100,4.3596206,1.7480555,,,,,,,,,,,,,, -149200,4.379641,2.2314858,,,,,,,,,,,,,, -149300,4.185598,1.6593021,,,,,,,,,,,,,, -149400,4.055706,1.7940093,,,,,,,,,,,,,, -149500,3.9893377,2.2033637,,,,,,,,,,,,,, -149563,,,0.7990038990974426,0.7832265496253967,0.7269600033760071,1.0977861881256104,50000.0,0.6046000123023987,1.7144430875778198,10000.0,68099.94502019882,74150.72949790955,68099.94502019882,6035.759362697601,7.138274669647217,0.0 -149600,3.9866903,1.5562757,,,,,,,,,,,,,, -149700,4.213221,1.747929,,,,,,,,,,,,,, -149800,4.2934036,1.6248229,,,,,,,,,,,,,, -149900,4.166405,1.7031218,,,,,,,,,,,,,, -150000,4.318555,3.812175,,,,,,,,,,,,,, -150100,3.974452,1.6385247,,,,,,,,,,,,,, -150200,3.8984954,3.6684868,,,,,,,,,,,,,, -150300,4.614601,3.9984012,,,,,,,,,,,,,, -150400,3.72269,2.859815,,,,,,,,,,,,,, -150488,,,0.7956249713897705,0.7945720553398132,0.7290799617767334,1.082510232925415,50000.0,0.6063000559806824,1.6938530206680298,10000.0,68519.99004364014,74609.09123158455,68519.99004364014,6073.976391792297,7.188690185546875,0.0 -150500,4.181776,1.6265122,,,,,,,,,,,,,, -150600,3.8894458,2.4319353,,,,,,,,,,,,,, -150700,3.8374712,1.9637929,,,,,,,,,,,,,, -150800,4.2172766,1.6203051,,,,,,,,,,,,,, -150900,4.295711,1.7634019,,,,,,,,,,,,,, -151000,4.101585,2.4614453,,,,,,,,,,,,,, -151100,3.7625046,2.3054662,,,,,,,,,,,,,, -151200,4.202943,1.5344905,,,,,,,,,,,,,, -151300,4.5938067,1.5876813,,,,,,,,,,,,,, -151400,4.1006584,1.7904023,,,,,,,,,,,,,, -151411,,,0.7971875071525574,0.7923197746276855,0.7300999760627747,1.0846587419509888,50000.0,0.610200047492981,1.6871274709701538,10000.0,68940.01628899574,75067.06227970123,68940.01628899574,6111.824824333191,7.236317157745361,0.0 -151500,4.693741,3.7248182,,,,,,,,,,,,,, -151600,4.1542287,1.6384057,,,,,,,,,,,,,, -151700,4.274308,1.8584958,,,,,,,,,,,,,, -151800,4.1671004,3.1930728,,,,,,,,,,,,,, -151900,4.788556,3.8348255,,,,,,,,,,,,,, -152000,4.160566,2.6693938,,,,,,,,,,,,,, -152100,4.547795,1.6258726,,,,,,,,,,,,,, -152200,3.9243345,2.8052282,,,,,,,,,,,,,, -152300,4.5419135,3.7480915,,,,,,,,,,,,,, -152336,,,0.8050390481948853,0.7471283078193665,0.7305999994277954,1.0720800161361694,50000.0,0.6121000051498413,1.6767473220825195,10000.0,69360.21492862701,75529.35670304298,69360.21492862701,6153.824164628983,7.284022569656372,0.0 -152400,5.1132216,3.8023148,,,,,,,,,,,,,, -152500,4.8716288,1.645042,,,,,,,,,,,,,, -152600,4.3113165,1.7577667,,,,,,,,,,,,,, -152700,4.621731,1.8984938,,,,,,,,,,,,,, -152800,4.0807214,2.0415938,,,,,,,,,,,,,, -152900,4.672558,1.7230871,,,,,,,,,,,,,, -153000,4.278436,1.6021423,,,,,,,,,,,,,, -153100,4.1087065,2.7033887,,,,,,,,,,,,,, -153200,4.871598,3.5134525,,,,,,,,,,,,,, -153221,,,0.8126757740974426,0.722992479801178,0.7340399622917175,1.0540411472320557,50000.0,0.6136000156402588,1.6623154878616333,10000.0,69780.14233207703,75984.6604912281,69780.14233207703,6189.108179092407,7.329882383346558,0.0 -153300,4.2406063,2.5132594,,,,,,,,,,,,,, -153400,3.9255404,2.8344984,,,,,,,,,,,,,, -153500,4.142918,2.648895,,,,,,,,,,,,,, -153600,4.340967,1.6523981,,,,,,,,,,,,,, -153700,4.2832456,1.5809001,,,,,,,,,,,,,, -153800,4.519532,1.5603948,,,,,,,,,,,,,, -153900,3.9500244,1.9745289,,,,,,,,,,,,,, -154000,4.4580717,1.6005886,,,,,,,,,,,,,, -154100,4.421088,1.6253713,,,,,,,,,,,,,, -154146,,,0.8069726228713989,0.7545697093009949,0.7346999645233154,1.059226393699646,50000.0,0.6131000518798828,1.6655516624450684,10000.0,70200.330078125,76439.60998177528,70200.330078125,6223.7734811306,7.37775182723999,0.0 -154200,4.794111,1.683695,,,,,,,,,,,,,, -154300,4.4656734,1.789176,,,,,,,,,,,,,, -154400,4.4891353,2.1611738,,,,,,,,,,,,,, -154500,5.0135098,3.7820776,,,,,,,,,,,,,, -154600,4.4466267,1.6835287,,,,,,,,,,,,,, -154700,4.453322,1.5233811,,,,,,,,,,,,,, -154800,4.355202,2.2304661,,,,,,,,,,,,,, -154900,4.559506,2.0147376,,,,,,,,,,,,,, -155000,5.3674226,3.8965137,,,,,,,,,,,,,, -155071,,,0.8126562237739563,0.7402761578559875,0.7357199788093567,1.0593024492263794,50000.0,0.6117000579833984,1.6622600555419922,10000.0,70620.54906439781,76897.78743886948,70620.54906439781,6261.632649898529,7.428417921066284,0.0 -155100,4.249114,1.5055903,,,,,,,,,,,,,, -155200,4.1490273,1.9774919,,,,,,,,,,,,,, -155300,4.4577465,1.6131845,,,,,,,,,,,,,, -155400,4.398376,2.773466,,,,,,,,,,,,,, -155500,4.804682,1.57678,,,,,,,,,,,,,, -155600,4.2852626,2.8639083,,,,,,,,,,,,,, -155700,4.5645146,1.6383785,,,,,,,,,,,,,, -155800,4.632609,3.2987647,,,,,,,,,,,,,, -155900,4.3101077,2.62774,,,,,,,,,,,,,, -155973,,,0.81556636095047,0.7248516082763672,0.7346000075340271,1.0620793104171753,50000.0,0.6123000383377075,1.667306900024414,10000.0,71040.45113134384,77354.13040804863,71040.45113134384,6297.965461015701,7.48820424079895,0.0 -156000,4.3946285,1.659982,,,,,,,,,,,,,, -156100,4.490282,3.3314228,,,,,,,,,,,,,, -156200,4.566746,2.8309534,,,,,,,,,,,,,, -156300,4.414945,2.994616,,,,,,,,,,,,,, -156400,4.451059,2.900298,,,,,,,,,,,,,, -156500,4.691236,1.7779094,,,,,,,,,,,,,, -156600,4.8514667,1.6290011,,,,,,,,,,,,,, -156700,4.6198473,1.7901471,,,,,,,,,,,,,, -156800,4.7077456,1.6458243,,,,,,,,,,,,,, -156898,,,0.8093163967132568,0.7406788468360901,0.7397199869155884,1.0419944524765017,50000.0,0.6170000433921814,1.6425464153289795,10000.0,71460.83744478226,77812.2350218296,71460.83744478226,6335.5825090408325,7.540948152542114,0.0 -156900,4.941903,3.5634916,,,,,,,,,,,,,, -157000,4.7885027,2.504806,,,,,,,,,,,,,, -157100,4.8083477,1.5779753,,,,,,,,,,,,,, -157200,4.278741,1.7196852,,,,,,,,,,,,,, -157300,4.923038,1.5647014,,,,,,,,,,,,,, -157400,4.341804,1.9938533,,,,,,,,,,,,,, -157500,4.826235,1.5215325,,,,,,,,,,,,,, -157600,4.7126584,1.9411197,,,,,,,,,,,,,, -157700,5.2023873,1.5267311,,,,,,,,,,,,,, -157800,4.828542,1.6155643,,,,,,,,,,,,,, -157824,,,0.8129101395606995,0.7125070095062256,0.7402399778366089,1.0308599472045898,50000.0,0.6149000525474548,1.6336321830749512,10000.0,71880.89972639084,78269.73344492912,71880.89972639084,6372.919394493103,7.591905832290649,0.0 -157900,5.0596657,1.837933,,,,,,,,,,,,,, -158000,4.776202,1.7739675,,,,,,,,,,,,,, -158100,4.9091682,1.6279366,,,,,,,,,,,,,, -158200,5.232339,1.8089641,,,,,,,,,,,,,, -158300,4.7696457,1.8015767,,,,,,,,,,,,,, -158400,4.3433247,1.9675449,,,,,,,,,,,,,, -158500,4.6433015,1.5389774,,,,,,,,,,,,,, -158600,5.1291313,1.6138041,,,,,,,,,,,,,, -158700,4.6712985,2.1847513,,,,,,,,,,,,,, -158748,,,0.8221093416213989,0.6898041367530823,0.7427399754524231,1.029032826423645,50000.0,0.6207000017166138,1.616907835006714,10000.0,72301.06321072578,78728.11376452446,72301.06321072578,6411.03583574295,7.643307447433472,0.0 -158800,4.4699397,2.3500216,,,,,,,,,,,,,, -158900,4.740914,1.4531915,,,,,,,,,,,,,, -159000,4.8643093,2.9247499,,,,,,,,,,,,,, -159100,4.4389114,3.169278,,,,,,,,,,,,,, -159200,5.2521315,3.2024136,,,,,,,,,,,,,, -159300,4.763033,1.5261805,,,,,,,,,,,,,, -159400,4.7284555,3.3667336,,,,,,,,,,,,,, -159500,4.727487,2.118129,,,,,,,,,,,,,, -159600,4.8343487,1.4629421,,,,,,,,,,,,,, -159671,,,0.81800776720047,0.7054724097251892,0.7444599866867065,1.0100157260894775,50000.0,0.6194000244140625,1.619384765625,10000.0,72721.30015707016,79189.0581202507,72721.30015707016,6451.647526979446,7.691318511962891,0.0 -159700,5.0650163,1.7875435,,,,,,,,,,,,,, -159800,5.40844,1.5027204,,,,,,,,,,,,,, -159900,4.737208,1.551955,,,,,,,,,,,,,, -160000,5.349137,1.7754927,,,,,,,,,,,,,, -160100,5.172685,1.5145924,,,,,,,,,,,,,, -160200,4.566337,1.9341512,,,,,,,,,,,,,, -160300,5.0153213,1.582783,,,,,,,,,,,,,, -160400,5.028288,1.6517085,,,,,,,,,,,,,, -160500,5.094339,3.191927,,,,,,,,,,,,,, -160594,,,0.8192187547683716,0.6979327201843262,0.7441799640655518,1.0171470642089844,50000.0,0.6256000399589539,1.6140958070755005,10000.0,73141.47015810013,79645.54460167885,73141.47015810013,6487.851769685745,7.754884719848633,0.0 -160600,5.259407,3.2088826,,,,,,,,,,,,,, -160700,5.1299157,2.8001735,,,,,,,,,,,,,, -160800,6.2974916,3.4651225,,,,,,,,,,,,,, -160900,5.2266197,3.4966106,,,,,,,,,,,,,, -161000,4.661122,1.9104419,,,,,,,,,,,,,, -161100,5.313792,1.471397,,,,,,,,,,,,,, -161200,4.7584467,2.8617988,,,,,,,,,,,,,, -161300,5.4306693,1.5511134,,,,,,,,,,,,,, -161400,4.7933865,1.5368477,,,,,,,,,,,,,, -161500,4.9637556,2.0067563,,,,,,,,,,,,,, -161519,,,0.8247851133346558,0.6756976842880249,0.7467600107192993,1.0093703269958496,50000.0,0.626300036907196,1.5992350578308103,10000.0,73561.80247449875,80102.00310349464,73561.80247449875,6523.876390695572,7.80646276473999,0.0 -161600,4.984381,1.5213493,,,,,,,,,,,,,, -161700,5.029701,2.1139674,,,,,,,,,,,,,, -161800,6.0444293,3.8418052,,,,,,,,,,,,,, -161900,5.11365,1.7631917,,,,,,,,,,,,,, -162000,4.4505615,2.2759266,,,,,,,,,,,,,, -162100,5.787649,3.3063598,,,,,,,,,,,,,, -162200,5.5023293,1.57779,,,,,,,,,,,,,, -162300,5.6248474,1.5318663,,,,,,,,,,,,,, -162400,5.0480595,1.3887233,,,,,,,,,,,,,, -162446,,,0.8268163800239563,0.6841356754302979,0.7471999526023865,1.016317367553711,50000.0,0.6248000264167786,1.616760015487671,10000.0,73982.11289167404,80563.53231620789,73982.11289167404,6564.9981191158295,7.855017900466919,0.0 -162500,4.9026012,1.4626601,,,,,,,,,,,,,, -162600,5.2039213,2.6434958,,,,,,,,,,,,,, -162700,4.9696436,1.7791514,,,,,,,,,,,,,, -162800,5.1644745,1.5276304,,,,,,,,,,,,,, -162900,6.57222,3.9308794,,,,,,,,,,,,,, -163000,5.2916436,2.1268666,,,,,,,,,,,,,, -163100,4.885157,2.761108,,,,,,,,,,,,,, -163200,5.954459,1.6101515,,,,,,,,,,,,,, -163300,5.424368,1.4967059,,,,,,,,,,,,,, -163371,,,0.8298437595367432,0.6641115546226501,0.7482799887657166,0.9978412389755248,50000.0,0.6270000338554382,1.6025879383087158,10000.0,74402.1906940937,81021.04125189781,74402.1906940937,6602.321990013123,7.912995100021362,0.0 -163400,5.5267434,1.48181,,,,,,,,,,,,,, -163500,5.4971685,1.5016077,,,,,,,,,,,,,, -163600,5.3433666,1.3952314,,,,,,,,,,,,,, -163700,5.872387,3.7018843,,,,,,,,,,,,,, -163800,4.9726996,2.4833238,,,,,,,,,,,,,, -163900,5.233634,1.4710844,,,,,,,,,,,,,, -164000,4.93768,1.3619984,,,,,,,,,,,,,, -164100,5.2603254,1.5762203,,,,,,,,,,,,,, -164200,5.5374584,1.5783898,,,,,,,,,,,,,, -164296,,,0.8291991949081421,0.6516792178153992,0.7512399554252625,0.986169457435608,50000.0,0.6296000480651855,1.5875972509384155,10000.0,74822.27236771584,81476.4604113102,74822.27236771584,6637.561418771744,7.962857246398926,0.0 -164300,5.3086386,1.5130337,,,,,,,,,,,,,, -164400,6.045804,3.6009226,,,,,,,,,,,,,, -164500,5.591949,1.5462494,,,,,,,,,,,,,, -164600,6.1483083,3.6499815,,,,,,,,,,,,,, -164700,5.6376724,1.4964327,,,,,,,,,,,,,, -164800,6.07838,3.1959653,,,,,,,,,,,,,, -164900,5.5892596,1.5023198,,,,,,,,,,,,,, -165000,5.0863924,2.8439016,,,,,,,,,,,,,, -165100,5.4215574,1.5247053,,,,,,,,,,,,,, -165200,5.511482,1.4742349,,,,,,,,,,,,,, -165222,,,0.8324413895606995,0.6385616660118103,0.7501399517059326,0.9882864356040956,50000.0,0.6282000541687012,1.5926567316055298,10000.0,75242.47238945961,81935.01135158539,75242.47238945961,6675.802627325058,8.023729801177979,0.0 -165300,5.3496785,3.1277683,,,,,,,,,,,,,, -165400,5.6394205,2.408053,,,,,,,,,,,,,, -165500,5.735985,1.8023314,,,,,,,,,,,,,, -165600,5.7921124,2.048784,,,,,,,,,,,,,, -165700,6.300249,3.8039443,,,,,,,,,,,,,, -165800,5.302912,2.5506122,,,,,,,,,,,,,, -165900,5.7910147,1.5076374,,,,,,,,,,,,,, -166000,5.9719586,3.4916635,,,,,,,,,,,,,, -166100,5.2013345,1.7528249,,,,,,,,,,,,,, -166148,,,0.8308203220367432,0.660889208316803,0.7534799575805664,0.9846858382225036,50000.0,0.6331000328063965,1.5844221115112305,10000.0,75662.8382794857,82393.60587334633,75662.8382794857,6713.923867702484,8.082376480102539,0.0 -166200,5.9934163,3.1252704,,,,,,,,,,,,,, -166300,5.477733,1.4961331,,,,,,,,,,,,,, -166400,5.897803,2.6726265,,,,,,,,,,,,,, -166500,5.633115,3.150876,,,,,,,,,,,,,, -166600,6.1202474,3.444935,,,,,,,,,,,,,, -166700,5.4076285,1.5500861,,,,,,,,,,,,,, -166800,6.025914,3.8567433,,,,,,,,,,,,,, -166900,5.78902,1.4690502,,,,,,,,,,,,,, -167000,5.5283,1.4400272,,,,,,,,,,,,,, -167072,,,0.8302929401397705,0.645725667476654,0.7534199953079224,0.9803726077079772,50000.0,0.6313000321388245,1.5846593379974363,10000.0,76082.94535136223,82852.91863751411,76082.94535136223,6753.028325080872,8.135641813278198,0.0 -167100,5.874711,2.9459333,,,,,,,,,,,,,, -167200,6.049611,1.4265174,,,,,,,,,,,,,, -167300,5.98208,1.4982803,,,,,,,,,,,,,, -167400,5.6598105,1.384367,,,,,,,,,,,,,, -167500,5.4158998,1.9069251,,,,,,,,,,,,,, -167600,5.621063,1.5502809,,,,,,,,,,,,,, -167700,7.959579,3.7179267,,,,,,,,,,,,,, -167800,6.126358,1.533714,,,,,,,,,,,,,, -167900,5.9005685,2.496551,,,,,,,,,,,,,, -167997,,,0.8384179472923279,0.6208813190460205,0.7554399967193604,0.9750881791114808,50000.0,0.6322000026702881,1.5767680406570437,10000.0,76502.94843864441,83310.66582012177,76502.94843864441,6790.670788764954,8.188108682632446,0.0 -168000,5.5051227,1.4127979,,,,,,,,,,,,,, -168100,5.8951325,1.4406197,,,,,,,,,,,,,, -168200,5.595158,1.515266,,,,,,,,,,,,,, -168300,5.867398,1.7653993,,,,,,,,,,,,,, -168400,6.482553,2.0161905,,,,,,,,,,,,,, -168500,5.422512,1.4650745,,,,,,,,,,,,,, -168600,5.8706536,1.714192,,,,,,,,,,,,,, -168700,5.9624767,1.3894646,,,,,,,,,,,,,, -168800,5.4186444,2.8098269,,,,,,,,,,,,,, -168900,6.1843033,3.6741729,,,,,,,,,,,,,, -168922,,,0.8319921493530273,0.6606417894363403,0.756060004234314,0.979549527168274,50000.0,0.6335000395774841,1.5818995237350464,10000.0,76923.12551164627,83765.89028906822,76923.12551164627,6825.619375705719,8.23816442489624,0.0 -169000,5.830075,3.1399343,,,,,,,,,,,,,, -169100,5.8365173,1.343837,,,,,,,,,,,,,, -169200,5.558229,1.3744226,,,,,,,,,,,,,, -169300,5.7786665,2.4061615,,,,,,,,,,,,,, -169400,6.0645657,1.3805387,,,,,,,,,,,,,, -169500,5.8311806,3.0603755,,,,,,,,,,,,,, -169600,5.8069015,1.9780349,,,,,,,,,,,,,, -169700,5.503387,1.7333037,,,,,,,,,,,,,, -169800,6.374658,1.5111549,,,,,,,,,,,,,, -169849,,,0.8330858945846558,0.6320017576217651,0.7576000094413757,0.9574471712112428,50000.0,0.6391000151634216,1.561057209968567,10000.0,77343.40371894836,84221.57668566704,77343.40371894836,6860.926365375519,8.290139198303223,0.0 -169900,6.30094,1.3667868,,,,,,,,,,,,,, -170000,5.428663,1.761017,,,,,,,,,,,,,, -170100,6.147312,1.4606872,,,,,,,,,,,,,, -170200,5.9543157,1.4155104,,,,,,,,,,,,,, -170245,,,,,,,,,,,77520.00591540337,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/eval_measurements.csv deleted file mode 100644 index a3295ca12..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -28.21615719795227,0.0,34.864447355270386,1,0,34.864447355270386,0.0010000000474974,6.907756805419922,10000,63.08093595504761,0.0010351561941206,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -67.2409439086914,0.0178401470184326,455.0325014591217,860,0,455.0325014591217,0.0260000005364418,6.065794467926025,10000,522.3378174304962,0.0338671877980232,5.902928829193115,0.030639998614788,5.9316911697387695,50000 -112.5071575641632,0.0444440841674804,875.0685901641846,1778,0,875.0685901641846,0.0545000024139881,5.616189956665039,10000,987.7160251140594,0.0747656226158142,5.325953960418701,0.0690400004386901,5.398590087890625,50000 -150.98039412498474,0.0706682205200195,1295.2597908973694,2697,0,1295.2597908973694,0.0668000057339668,5.455690860748291,10000,1446.455206632614,0.0935742184519767,5.147493362426758,0.0866400003433227,5.2080559730529785,50000 -194.3020474910736,0.0982718467712402,1715.3475415706637,3618,0,1715.3475415706637,0.0971000045537948,5.038281440734863,10000,1909.941129684448,0.138710930943489,4.64686393737793,0.1281400024890899,4.720630645751953,50000 -236.8234026432037,0.1237492561340332,2135.6945128440857,4541,0,2135.6945128440857,0.1171000078320503,4.889176845550537,10000,2372.883920669556,0.1616406142711639,4.394140720367432,0.1517399996519088,4.4949469566345215,50000 -277.5315291881561,0.1509375572204589,2556.0419986248016,5457,0,2556.0419986248016,0.1381999999284744,4.630496978759766,10000,2834.015328407288,0.1947851479053497,4.121174335479736,0.1805399954319,4.207534790039063,50000 -317.28538703918457,0.1799032688140869,2976.1662895679474,6379,0,2976.1662895679474,0.1568000018596649,4.494888305664063,10000,3293.9713699817657,0.2248437404632568,3.996538639068604,0.2078799903392791,4.081354141235352,50000 -357.51818895339966,0.21305513381958,3396.437614440918,7299,0,3396.437614440918,0.1620000004768371,4.486764430999756,10000,3754.5577280521393,0.2309374958276748,3.953554630279541,0.2123799920082092,4.063230991363525,50000 -396.7017076015472,0.2488820552825927,3816.479942083359,8218,0,3816.479942083359,0.1733000129461288,4.334794521331787,10000,4213.867478132248,0.2625390589237213,3.628416776657105,0.2255199998617172,3.8713552951812735,50000 -439.81812286376953,0.2767214775085449,4236.704192399979,9142,0,4236.704192399979,0.1899000108242035,4.1776041984558105,10000,4677.285391569138,0.2725390493869781,3.545860767364502,0.2532599866390228,3.675501585006714,50000 -477.5397388935089,0.3062405586242676,4656.660678386688,10064,0,4656.660678386688,0.1834000051021576,4.324698448181152,10000,5135.04213309288,0.2577343583106994,3.733825206756592,0.2406199872493744,3.855492115020752,50000 -517.0797474384308,0.3357918262481689,5076.703187704086,10987,0,5076.703187704086,0.1914000064134597,4.227960586547852,10000,5594.703353404999,0.2759570181369781,3.5516586303710938,0.2458799928426742,3.746058225631714,50000 -555.4356439113617,0.3640127182006836,5496.717879772186,11908,0,5496.717879772186,0.2025000154972076,4.11765718460083,10000,6053.150849819183,0.2885546684265136,3.454838752746582,0.2699199914932251,3.566235542297364,50000 -592.7188277244568,0.3962678909301758,5916.797674417496,12828,0,5916.797674417496,0.215700015425682,4.02402925491333,10000,6510.595469713211,0.3011523485183716,3.359318256378174,0.2775000035762787,3.493115901947021,50000 -633.5849332809448,0.4264430999755859,6336.788841247559,13749,0,6336.788841247559,0.2234000116586685,3.946427583694458,10000,6971.531793117523,0.3214453160762787,3.2423040866851807,0.2875199913978576,3.422166109085083,50000 -670.5265691280365,0.4606430530548095,6756.758174657822,14674,0,6756.758174657822,0.2208000123500824,3.968396186828613,10000,7428.526203870773,0.3084179759025574,3.319817066192627,0.2889599800109863,3.429847478866577,50000 -712.6638605594635,0.490314245223999,7176.900419473648,15597,0,7176.900419473648,0.2218000143766403,3.960664987564087,10000,7890.88374376297,0.3149218559265136,3.302459239959717,0.2924000024795532,3.4336822032928467,50000 -749.902556180954,0.5184402465820312,7597.258638143539,16518,0,7597.258638143539,0.2164000123739242,4.005506992340088,10000,8348.557153463364,0.3135937452316284,3.3022851943969727,0.2899599969387054,3.4567174911499023,50000 -790.755383014679,0.5460658073425293,8017.4315502643585,17442,0,8017.4315502643585,0.2416000068187713,3.832712650299072,10000,8809.659813642502,0.3493554592132568,3.045586585998535,0.311519980430603,3.2577993869781494,50000 -830.8956499099731,0.5777482986450195,8437.51029419899,18365,0,8437.51029419899,0.2379000186920166,3.853330612182617,10000,9269.958486557009,0.334765613079071,3.173603057861328,0.3100999891757965,3.299381732940674,50000 -867.5589497089386,0.6066219806671143,8857.471648454666,19286,0,8857.471648454666,0.234700009226799,3.879996299743652,10000,9726.6599547863,0.328125,3.206876516342163,0.3041799962520599,3.352670431137085,50000 -910.8073680400848,0.6355025768280029,9277.629351854324,20209,0,9277.629351854324,0.2412000149488449,3.833801746368408,10000,10190.143986225128,0.3524218797683716,3.068400621414185,0.311379998922348,3.3076298236846924,50000 -955.7451527118684,0.6662991046905518,9697.635033369064,21131,0,9697.635033369064,0.2560000121593475,3.727675914764404,10000,10655.167280435562,0.3587304651737213,3.019661903381348,0.3378599882125854,3.147328615188598,50000 -995.0893120765686,0.700777530670166,10117.811880588531,22053,0,10117.811880588531,0.2412000149488449,3.8103978633880615,10000,11114.772018909454,0.3477343618869781,3.0713601112365723,0.3216799795627594,3.2210161685943604,50000 -1035.9921073913574,0.7372596263885498,10538.091529846191,22976,0,10538.091529846191,0.255400002002716,3.708144664764404,10000,11576.040191173552,0.3622460961341858,2.9715335369110107,0.3338399827480316,3.14302396774292,50000 -1075.7954070568085,0.7705366611480713,10958.286712884905,23900,0,10958.286712884905,0.2514000236988067,3.753178596496582,10000,12036.120604991913,0.3491992056369781,3.0883233547210693,0.3263799846172333,3.207261085510254,50000 -1118.106077671051,0.8033137321472168,11378.547968149183,24825,0,11378.547968149183,0.2499000132083892,3.739815473556519,10000,12498.776052236555,0.3492382764816284,3.050524473190308,0.3290999829769134,3.18064284324646,50000 -1154.6575469970703,0.8338415622711182,11798.491114139557,25746,0,11798.491114139557,0.255700021982193,3.705024480819702,10000,12955.351569652556,0.3688476383686065,2.933948516845703,0.3354199826717376,3.105673789978028,50000 -1192.788717508316,0.8701872825622559,12218.91095137596,26668,0,12218.91095137596,0.2708000242710113,3.657903671264648,10000,13413.987750291824,0.3754101395606994,2.905060052871704,0.3468399941921234,3.06069564819336,50000 -1233.934255361557,0.9021317958831788,12638.921879768372,27590,0,12638.921879768372,0.2639000117778778,3.661264181137085,10000,13875.224354982376,0.3759374916553497,2.8957693576812744,0.3487599790096283,3.039360284805298,50000 -1270.6803524494171,0.9330759048461914,13059.063136339188,28514,0,13059.063136339188,0.2562000155448913,3.7381186485290527,10000,14332.19144487381,0.3650195300579071,2.984952926635742,0.3395199775695801,3.1396758556365967,50000 -1311.5622823238373,0.9640073776245116,13479.126105070114,29437,0,13479.126105070114,0.2629000246524811,3.685798168182373,10000,14793.215996265411,0.3907031118869781,2.8178062438964844,0.3384000062942505,3.126830577850342,50000 -1350.0582914352417,0.9966273307800292,13899.084905862808,30356,0,13899.084905862808,0.2647000253200531,3.649627447128296,10000,15251.75150704384,0.3720507621765136,2.9231514930725098,0.346699982881546,3.082463026046753,50000 -1389.423010110855,1.0317187309265137,14319.371235609056,31278,0,14319.371235609056,0.2718999981880188,3.6703150272369385,10000,15711.486956119536,0.3739062547683716,2.9314498901367188,0.3473199903964996,3.0704736709594727,50000 -1430.7917094230652,1.0686018466949463,14739.30179309845,32202,0,14739.30179309845,0.2857000231742859,3.5437114238739014,10000,16172.872477769852,0.4043163955211639,2.71505069732666,0.3649599850177765,2.9368364810943604,50000 -1470.1178567409515,1.098562479019165,15159.38614320755,33126,0,15159.38614320755,0.2803000211715698,3.540454387664795,10000,16632.36226463318,0.3913476467132568,2.812678575515747,0.3640399873256683,2.9665307998657227,50000 -1509.067530155182,1.1289377212524414,15579.40675020218,34048,0,15579.40675020218,0.2736000120639801,3.5886569023132324,10000,17091.41108250618,0.382148414850235,2.845215082168579,0.3581399917602539,2.999690532684326,50000 -1550.8924548625946,1.168405055999756,15999.715313196182,34973,0,15999.715313196182,0.2762000262737274,3.589385271072388,10000,17553.633061647415,0.3923242092132568,2.8107075691223145,0.3644599914550781,2.975851058959961,50000 -1595.15216255188,1.2027950286865234,16419.95436644554,35899,0,16419.95436644554,0.2859000265598297,3.4763166904449463,10000,18018.21585845948,0.399726539850235,2.7130661010742188,0.3743399977684021,2.8684468269348145,50000 -1635.3084263801577,1.2354793548583984,16840.155430793762,36822,0,16840.155430793762,0.2893000245094299,3.473834991455078,10000,18478.654767751694,0.4059179723262787,2.700996398925781,0.3773199915885925,2.852568387985229,50000 -1677.465342760086,1.268315076828003,17260.433208703995,37742,0,17260.433208703995,0.3015000224113464,3.418804407119751,10000,18941.17035579681,0.4121874868869781,2.6718852519989014,0.3816199898719787,2.845273971557617,50000 -1715.8985350131989,1.3021540641784668,17680.66015148163,38663,0,17680.66015148163,0.2939000129699707,3.432493209838867,10000,19399.91322350502,0.4493359327316284,2.501319169998169,0.3834999799728393,2.8455910682678223,50000 -1756.2016875743866,1.339077711105347,18100.73324918747,39587,0,18100.73324918747,0.2675000131130218,3.623452425003052,10000,19860.374609947205,0.3864062428474426,2.9066293239593506,0.3610599935054779,3.036476850509644,50000 -1794.0471782684326,1.3741655349731443,18521.10993671417,40510,0,18521.10993671417,0.289900004863739,3.5260701179504395,10000,20318.680195093155,0.401660144329071,2.776913642883301,0.3719799816608429,2.9364328384399414,50000 -1832.2530777454376,1.404280662536621,18941.358196496964,41434,0,18941.358196496964,0.2957000136375427,3.4568240642547607,10000,20777.212375164032,0.423164039850235,2.5967025756835938,0.3805999755859375,2.832966089248657,50000 -1870.96342420578,1.441824436187744,19361.352199077606,42358,0,19361.352199077606,0.299200028181076,3.452272891998291,10000,21236.00361180305,0.4092773199081421,2.675566434860229,0.3887199759483337,2.8146812915802,50000 -1908.5935270786283,1.474895715713501,19781.59530377388,43280,0,19781.59530377388,0.2969000041484833,3.4342548847198486,10000,21693.95869255066,0.4135546684265136,2.665776252746582,0.3831200003623962,2.8275225162506104,50000 -1947.7522237300875,1.50919771194458,20201.73545074463,44204,0,20201.73545074463,0.305400013923645,3.4004056453704834,10000,22153.340349674225,0.4298437535762787,2.616574048995972,0.3955000042915344,2.792757034301758,50000 -1987.8704631328585,1.5406432151794434,20621.955298423767,45130,0,20621.955298423767,0.3004000186920166,3.436731338500977,10000,22613.75858616829,0.4178906083106994,2.693882465362549,0.3900199830532074,2.8285815715789795,50000 -2028.9680247306824,1.5772721767425537,21041.939405441284,46053,0,21041.939405441284,0.3076000213623047,3.3951525688171387,10000,23074.92538762093,0.4268945157527923,2.6220920085906982,0.3996999859809875,2.773547410964966,50000 -2068.98393702507,1.6132659912109375,21462.19115138054,46976,0,21462.19115138054,0.3146000206470489,3.3456552028656006,10000,23535.27784347534,0.4323046803474426,2.536307573318481,0.405239999294281,2.6976065635681152,50000 -2108.731509923935,1.6486258506774902,21882.690141916275,47901,0,21882.690141916275,0.2980000078678131,3.4122321605682373,10000,23995.60861515999,0.4513085782527923,2.50014066696167,0.3981199860572815,2.7976644039154053,50000 -2151.260172843933,2.1510133743286133,22302.164969682693,48825,0,22302.164969682693,0.3126000165939331,3.3366458415985107,10000,24458.163280963898,0.4263085722923279,2.605025053024292,0.4018400013446808,2.7520787715911865,50000 -2195.242655277252,2.196490526199341,22722.07783985138,49748,0,22722.07783985138,0.31700000166893,3.326667070388794,10000,24922.15366792679,0.4362890422344208,2.5398664474487305,0.4076199829578399,2.714684247970581,50000 -2232.8479537963867,2.2305619716644287,23142.06842494011,50670,0,23142.06842494011,0.3070000112056732,3.3801825046539307,10000,25379.83161520958,0.4491601586341858,2.490548610687256,0.402319997549057,2.752784013748169,50000 -2276.712729215622,2.2754602432250977,23562.160203695297,51593,0,23562.160203695297,0.3153000175952911,3.344174385070801,10000,25843.881860017776,0.4297265410423279,2.5958352088928223,0.4063799977302551,2.737758159637451,50000 -2319.089858531952,2.310560941696167,23982.16872549057,52517,0,23982.16872549057,0.3210000097751617,3.283968210220337,10000,26306.350472450256,0.4399023354053497,2.52112889289856,0.4089199900627136,2.6849489212036133,50000 -2358.8874881267548,2.348323106765747,24402.40174674988,53440,0,24402.40174674988,0.3276000022888183,3.283280611038208,10000,26766.47193908692,0.4496484398841858,2.463557720184326,0.4154399931430816,2.6775717735290527,50000 -2398.257269382477,2.3831875324249268,24822.649163007736,54363,0,24822.649163007736,0.3160000145435333,3.333500862121582,10000,27226.17301273346,0.4367187321186065,2.556265592575073,0.4078799784183502,2.71284294128418,50000 -2440.921109676361,2.425750970840454,25242.692935943604,55287,0,25242.692935943604,0.3257000148296356,3.252009153366089,10000,27688.975242853165,0.4478320181369781,2.4865450859069824,0.4236199855804443,2.6358258724212646,50000 -2481.1670627593994,2.4647953510284424,25662.820051670074,56210,0,25662.820051670074,0.326200008392334,3.243481397628784,10000,28149.4397380352,0.4586718678474426,2.4052109718322754,0.421559989452362,2.599883794784546,50000 -2522.4550380706787,2.499359130859375,26083.14991784096,57131,0,26083.14991784096,0.329800009727478,3.234283447265625,10000,28611.14157915116,0.4611132740974426,2.4022233486175537,0.4213999807834625,2.619320869445801,50000 -2565.287081718445,2.5334596633911133,26503.37421274185,58055,0,26503.37421274185,0.3215000033378601,3.2995548248291016,10000,29074.28104519844,0.4460351467132568,2.520827054977417,0.4186599850654602,2.6796395778656006,50000 -2605.8996703624725,2.574249744415283,26923.577298164368,58979,0,26923.577298164368,0.3387000262737274,3.2149226665496826,10000,29535.1870803833,0.4637890458106994,2.4174797534942627,0.4327999949455261,2.582634925842285,50000 -2646.5336713790894,2.614187479019165,27343.765065908432,59903,0,27343.765065908432,0.3347000181674957,3.231863021850586,10000,29996.09737586975,0.478808581829071,2.342374563217163,0.4243399798870086,2.625638961791992,50000 -2688.3805034160614,2.658954620361328,27763.803510665894,60824,0,27763.803510665894,0.3337000012397766,3.2204599380493164,10000,30458.07638692856,0.4606054723262787,2.44146728515625,0.428659975528717,2.598587989807129,50000 -2727.574634075165,2.6977574825286865,28183.8014895916,61748,0,28183.8014895916,0.3223000168800354,3.279371738433838,10000,30917.355994701385,0.4458398222923279,2.5256237983703613,0.4161399900913238,2.6907958984375,50000 -2768.981840610504,2.7327027320861816,28603.97628927231,62670,0,28603.97628927231,0.3391000032424927,3.1987009048461914,10000,31379.02168869972,0.4750781059265136,2.330697059631348,0.4386599957942962,2.538495779037476,50000 -2807.4011034965515,2.7702078819274902,29023.923259735107,63592,0,29023.923259735107,0.3415000140666961,3.151350259780884,10000,31837.474231004715,0.4749414026737213,2.3414981365203857,0.44200000166893,2.4970359802246094,50000 -2847.411164045334,2.807270765304565,29443.90414404869,64514,0,29443.90414404869,0.3314000070095062,3.231776237487793,10000,32297.55117797852,0.4628320336341858,2.441080331802368,0.4318199753761291,2.605339527130127,50000 -2887.384506225586,2.84656310081482,29864.123959302902,65437,0,29864.123959302902,0.3297000229358673,3.290002822875977,10000,32757.831778764725,0.4589453041553497,2.488585948944092,0.4240599870681762,2.6650540828704834,50000 -2929.774935245514,2.8840224742889404,30284.387431383133,66360,0,30284.387431383133,0.3379000127315521,3.1919198036193848,10000,33220.57291150093,0.467089831829071,2.380722045898437,0.4341999888420105,2.5594401359558105,50000 -2972.3694083690643,2.92598557472229,30704.6074051857,67284,0,30704.6074051857,0.3418000042438507,3.1739964485168457,10000,33683.47833299637,0.4660546779632568,2.3730216026306152,0.4387199878692627,2.5333032608032227,50000 -3009.159858226776,2.972452402114868,31124.684980630875,68207,0,31124.684980630875,0.3398000299930572,3.1727473735809326,10000,34140.442061424255,0.4681445062160492,2.362294435501098,0.4355399906635284,2.5348927974700928,50000 -3048.2066645622253,3.0094172954559326,31545.37599992752,69130,0,31545.37599992752,0.3315000236034393,3.226326704025269,10000,34600.265894174576,0.4894726574420929,2.2886223793029785,0.4311199784278869,2.5940520763397217,50000 -3090.044335842133,3.047203779220581,31965.58531689644,70053,0,31965.58531689644,0.356900006532669,3.05293607711792,10000,35062.400102853775,0.485644519329071,2.2598440647125244,0.4579599797725677,2.3936898708343506,50000 -3127.865446805954,3.08792495727539,32385.602464437485,70977,0,32385.602464437485,0.349700003862381,3.116337776184082,10000,35520.32847523689,0.482421875,2.2963407039642334,0.4460999965667724,2.483259916305542,50000 -3170.2405710220337,3.131728172302246,32805.70732951164,71899,0,32805.70732951164,0.3488000035285949,3.096506118774414,10000,35982.90106034279,0.4937304556369781,2.219071626663208,0.4499399960041046,2.446985006332397,50000 -3212.5319378376007,3.1703007221221924,33225.77615022659,72822,0,33225.77615022659,0.3605000078678131,3.04154372215271,10000,36445.34771442413,0.4931640625,2.2360095977783203,0.4639399945735931,2.3808131217956543,50000 -3248.851069688797,3.207165241241455,33646.14836072922,73744,0,33646.14836072922,0.3540000021457672,3.107847213745117,10000,36902.12452173233,0.4816796779632568,2.3041117191314697,0.452019989490509,2.4658422470092773,50000 -3288.8602344989777,3.24563980102539,34066.290695905685,74668,0,34066.290695905685,0.3610000312328338,3.050954341888428,10000,37362.36349225044,0.5001562237739563,2.175226926803589,0.4656799733638763,2.3801429271698,50000 -3331.9556045532227,3.292352199554444,34486.66932654381,75590,0,34486.66932654381,0.3633000254631042,3.0194509029388428,10000,37825.93365931511,0.4973046779632568,2.2028791904449463,0.4642199873924255,2.3795573711395264,50000 -3369.877052307129,3.3313422203063965,34906.83700180054,76511,0,34906.83700180054,0.3605000078678131,3.0538387298583984,10000,38284.10998988152,0.4972656071186065,2.211764097213745,0.464739978313446,2.384872436523437,50000 -3412.295135498047,3.371867179870605,35326.80542087555,77432,0,35326.80542087555,0.3682000041007995,3.008074760437012,10000,38746.58549690247,0.5082421898841858,2.163210153579712,0.4681800007820129,2.3664684295654297,50000 -3456.013298511505,3.413897752761841,35746.876959085464,78356,0,35746.876959085464,0.3723000288009643,3.007336378097534,10000,39210.46626496315,0.5389648079872131,2.0390491485595703,0.4726599752902984,2.385406494140625,50000 -3497.8065342903137,3.455860614776612,36166.997678518295,79280,0,36166.997678518295,0.3680000305175781,2.9997689723968506,10000,39672.47076559067,0.5073632597923279,2.175447940826416,0.4732399880886078,2.356532573699951,50000 -3537.0906777381897,3.501105070114136,36587.06204032898,80204,0,36587.06204032898,0.3631000220775604,3.042214870452881,10000,40131.91340637207,0.5040234327316284,2.234501361846924,0.469760000705719,2.405181884765625,50000 -3578.367530584336,3.5425891876220703,37007.36421537399,81127,0,37007.36421537399,0.3806000053882599,2.9327199459075928,10000,40593.58241772652,0.5310351252555847,2.0206193923950195,0.484059989452362,2.27791166305542,50000 -3619.866011381149,3.581484079360962,37427.59137535095,82046,0,37427.59137535095,0.3708000183105469,2.9745867252349854,10000,41055.39542388916,0.5068749785423279,2.1524007320404053,0.4747999906539917,2.324914693832397,50000 -3662.954482316971,3.627816438674927,37847.52794003487,82966,0,37847.52794003487,0.3758000135421753,2.963873147964477,10000,41518.515270233154,0.5151953101158142,2.132296800613404,0.4791599810123443,2.317836284637451,50000 -3703.213614225388,3.6741631031036377,38267.72216629982,83888,0,38267.72216629982,0.3775000274181366,2.935068368911743,10000,41979.064101696014,0.5267773270606995,2.065271615982056,0.4811999797821045,2.2873899936676025,50000 -3743.798425197601,3.718817710876465,38687.88592624664,84811,0,38687.88592624664,0.3895000219345093,2.879626035690308,10000,42439.90636992455,0.52587890625,2.056821823120117,0.4928999841213226,2.2359001636505127,50000 -3784.9390771389008,3.758352756500244,39108.16635656357,85733,0,39108.16635656357,0.3757000267505646,2.9468016624450684,10000,42901.415877103806,0.517382800579071,2.116588830947876,0.4849999845027923,2.2845561504364014,50000 -3825.567506790161,3.802644729614258,39528.129534721375,86655,0,39528.129534721375,0.3890000283718109,2.89555287361145,10000,43362.101089954376,0.5392773151397705,2.0084869861602783,0.5009399652481079,2.2077176570892334,50000 -3867.839832305908,3.842219352722168,39948.1226670742,87576,0,39948.1226670742,0.3804000318050384,2.9136619567871094,10000,43824.45436620712,0.5444140434265137,1.9728312492370603,0.492499977350235,2.241783618927002,50000 -3905.0960161685935,3.8879823684692383,40368.45597243309,88499,0,40368.45597243309,0.3840000033378601,2.8978774547576904,10000,44282.1383099556,0.5226757526397705,2.081761598587036,0.491599977016449,2.257352590560913,50000 -3948.579337596893,3.9286277294158936,40788.48864960671,89420,0,40788.48864960671,0.3827000260353088,2.921517372131348,10000,44745.744921684265,0.5298827886581421,2.077981472015381,0.4909399747848511,2.27143931388855,50000 -3993.872991085053,3.9703309535980233,41208.83841848373,90343,0,41208.83841848373,0.3936000168323517,2.8484578132629395,10000,45211.47863817215,0.5615038871765137,1.8954758644104004,0.5057199597358704,2.180562734603882,50000 -4037.57505941391,4.016332626342773,41629.06167125702,91267,0,41629.06167125702,0.3892000317573547,2.885896921157837,10000,45675.498945236206,0.5329882502555847,2.074381828308105,0.5001400113105774,2.2381608486175537,50000 -4078.279564142227,4.529749393463135,42048.947972774506,92188,0,42048.947972774506,0.3930000066757202,2.836609601974488,10000,46136.65203642845,0.5424413681030273,1.9888652563095093,0.5024399757385254,2.1791367530822754,50000 -4120.957757472992,4.573975324630737,42469.17872309685,93110,0,42469.17872309685,0.4075000286102295,2.773594617843628,10000,46599.65407395363,0.5608007907867432,1.895853400230408,0.5133000016212463,2.126741886138916,50000 -4160.462041378021,4.615373611450195,42889.36215043068,94031,0,42889.36215043068,0.4058000147342682,2.794235944747925,10000,47059.432107925415,0.5450390577316284,1.9546549320220947,0.5126599669456482,2.1251654624938965,50000 -4202.813487291336,4.658671140670776,43309.5159702301,94953,0,43309.5159702301,0.4006000161170959,2.837176322937012,10000,47522.02918076515,0.5402538776397705,2.001154661178589,0.5038999915122986,2.17992639541626,50000 -4242.95054268837,4.701233148574829,43729.85532140732,95875,0,43729.85532140732,0.3922000229358673,2.864680528640747,10000,47982.59704566002,0.5454687476158142,2.008453845977783,0.5034399628639221,2.21916127204895,50000 -4283.885046005249,4.746007680892944,44149.91594219208,96795,0,44149.91594219208,0.409600019454956,2.749626398086548,10000,48443.68581080437,0.5572656393051147,1.9084703922271729,0.5183199644088745,2.112875461578369,50000 -4323.22603559494,4.789544582366943,44570.11754012108,97717,0,44570.11754012108,0.4025000333786011,2.8153152465820312,10000,48903.320981264114,0.5448241829872131,1.973633050918579,0.5102199912071228,2.158968210220337,50000 -4363.465132236481,4.829688310623169,44990.19146609306,98639,0,44990.19146609306,0.4062000215053558,2.79468321800232,10000,49363.72312116623,0.5551952719688416,1.929055452346801,0.5148800015449524,2.135251760482788,50000 -4405.708698511124,4.869691371917725,45410.33908557892,99560,0,45410.33908557892,0.4070000052452087,2.797633647918701,10000,49826.20372629166,0.5842382907867432,1.811358332633972,0.5146600008010864,2.145191669464112,50000 -4444.26401591301,4.912209749221802,45830.48991537094,100482,0,45830.48991537094,0.4086000323295593,2.762568473815918,10000,50285.001859903336,0.5579687356948853,1.9476128816604608,0.5199399590492249,2.1243672370910645,50000 -4480.708422660828,4.960192918777466,46250.49274253845,101404,0,46250.49274253845,0.4086000323295593,2.7876458168029785,10000,50741.545766592026,0.5579296946525574,1.942072868347168,0.5202599763870239,2.1204233169555664,50000 -4521.99057674408,5.001747369766235,46670.66937637329,102326,0,46670.66937637329,0.4100000262260437,2.769395112991333,10000,51203.09488844872,0.5746093392372131,1.825226068496704,0.5278399586677551,2.0796656608581543,50000 -4560.5888912677765,5.0430192947387695,47090.8395152092,103247,0,47090.8395152092,0.4067000150680542,2.7682337760925293,10000,51661.95341157913,0.5560937523841858,1.95291531085968,0.5209599733352661,2.1211466789245605,50000 -4603.852442026138,5.094490051269531,47510.9509665966,104172,0,47510.9509665966,0.4191000163555145,2.6920504570007324,10000,52125.42877531052,0.5689452886581421,1.8281947374343872,0.5344399809837341,2.012808561325073,50000 -4643.782491207123,5.150495529174805,47930.90433549881,105094,0,47930.90433549881,0.4271000325679779,2.6605894565582275,10000,52585.41693139076,0.5889062285423279,1.7638208866119385,0.5421199798583984,1.9910471439361568,50000 -4681.881660699844,5.203859329223633,48350.86739444733,106015,0,48350.86739444733,0.4189000129699707,2.725624322891236,10000,53043.58160710335,0.5689062476158142,1.88288688659668,0.5313199758529663,2.059181928634644,50000 -4720.600524902344,5.249224662780762,48770.8959608078,106937,0,48770.8959608078,0.431300014257431,2.653942584991455,10000,53502.4229733944,0.5787695050239563,1.8417892456054688,0.5412399768829346,2.022822380065918,50000 -4761.017573356628,5.296191930770874,49191.23649168015,107860,0,49191.23649168015,0.4325000345706939,2.625818014144897,10000,53963.276510477066,0.5873827934265137,1.7708899974822998,0.5433200001716614,1.9900877475738523,50000 -4798.291470050812,5.3400962352752686,49611.455362319946,108784,0,49611.455362319946,0.4370000064373016,2.632002353668213,10000,54420.8622033596,0.6213085651397705,1.629205584526062,0.5485599637031555,1.96748161315918,50000 -4840.195640087128,5.386731386184692,50031.65995430946,109705,0,50031.65995430946,0.4294000267982483,2.6271207332611084,10000,54883.06630086899,0.588671863079071,1.755498290061951,0.5467999577522278,1.954467177391052,50000 -4879.87920832634,5.428632497787476,50451.89798164368,110628,0,50451.89798164368,0.4384000301361084,2.5746302604675293,10000,55343.07885932922,0.595898449420929,1.7019425630569458,0.5554800033569336,1.911659836769104,50000 -4918.517581939697,5.472025394439697,50871.97425460816,111551,0,50871.97425460816,0.4394000172615051,2.577054500579834,10000,55801.88543534279,0.6108593344688416,1.6443220376968384,0.5592799782752991,1.905354380607605,50000 -4958.763142347336,5.519868135452271,51292.22484111786,112474,0,51292.22484111786,0.4414000213146209,2.582552194595337,10000,56262.47783136368,0.5939062237739563,1.7227932214736938,0.5597599744796753,1.904951810836792,50000 -5000.821955919266,5.5690460205078125,51712.5141146183,113395,0,51712.5141146183,0.4389000236988067,2.563261270523072,10000,56724.9230401516,0.6010937094688416,1.6848231554031372,0.5621399879455566,1.882049918174744,50000 -5034.966580152512,5.616888523101807,52132.70584130287,114316,0,52132.70584130287,0.442300021648407,2.572660207748413,10000,57179.35532331467,0.6037890315055847,1.70114004611969,0.55485999584198,1.942366003990173,50000 -5076.205003976822,5.660736322402954,52552.76272273064,115237,0,52552.76272273064,0.448600023984909,2.519437313079834,10000,57640.74333691597,0.6056835651397705,1.67186439037323,0.5682199597358704,1.854199171066284,50000 -5117.389803171158,5.706765413284302,52972.98017334938,116160,0,52972.98017334938,0.4535000324249267,2.516984462738037,10000,58102.24056863785,0.6073437333106995,1.675774097442627,0.5725799798965454,1.846383810043335,50000 -5158.30740070343,5.750935316085815,53393.17074465752,117083,0,53393.17074465752,0.4607000350952148,2.4718594551086426,10000,58563.44180321693,0.6232812404632568,1.5855211019515991,0.5778599977493286,1.8170597553253167,50000 -5198.218531131744,5.809661865234375,53813.120841264725,118005,0,53813.120841264725,0.462300032377243,2.502239465713501,10000,59023.410198926926,0.6318945288658142,1.554277420043945,0.5718399882316589,1.8370327949523928,50000 -5240.525000333786,5.854615688323975,54233.34093928337,118929,0,54233.34093928337,0.4620000123977661,2.4387903213500977,10000,59486.030648469925,0.6213085651397705,1.6146199703216553,0.5793200135231018,1.8063288927078247,50000 -5275.923225164413,5.900138854980469,54653.69129776955,119852,0,54653.69129776955,0.4583000242710113,2.520430088043213,10000,59941.873532533646,0.6187695264816284,1.6314858198165894,0.5747199654579163,1.853618025779724,50000 -5316.057541370392,5.947498083114624,55073.72949528694,120773,0,55073.72949528694,0.4687000215053558,2.431938886642456,10000,60402.14140725136,0.6486914157867432,1.4689342975616455,0.5861999988555908,1.7645444869995115,50000 -5357.858735084534,5.998109817504883,55493.78798747063,121692,0,55493.78798747063,0.4652000367641449,2.4484989643096924,10000,60864.09947562218,0.6215234398841858,1.6078003644943235,0.5834999680519104,1.7998595237731934,50000 -5397.7881071567535,6.045310974121094,55914.06940293312,122615,0,55914.06940293312,0.4740000367164612,2.3974080085754395,10000,61324.40569233894,0.6367577910423279,1.5228012800216677,0.5913800001144409,1.7486042976379397,50000 -5440.216062545776,6.09236216545105,56334.1478741169,123540,0,56334.1478741169,0.4682000279426574,2.441822290420532,10000,61787.00834584236,0.644335925579071,1.5328458547592163,0.5850600004196167,1.7931092977523804,50000 -5480.426479578018,6.140239953994751,56754.07245540619,124462,0,56754.07245540619,0.4686000347137451,2.41546893119812,10000,62247.24000930786,0.6310741901397705,1.5595864057540894,0.5909199714660645,1.7579379081726074,50000 -5523.234518289566,6.193271636962891,57174.0851354599,125382,0,57174.0851354599,0.4755000174045563,2.390855073928833,10000,62710.162212610245,0.6387109160423279,1.5228734016418457,0.5959399938583374,1.7375160455703735,50000 -5563.646774530411,6.24152421951294,57594.21831464768,126304,0,57594.21831464768,0.4809000194072723,2.3549671173095703,10000,63170.8043320179,0.6492577791213989,1.4756730794906616,0.5982199907302856,1.706802487373352,50000 -5607.780184268951,6.286353588104248,58014.53499889374,127228,0,58014.53499889374,0.472100019454956,2.3971104621887207,10000,63635.348907232285,0.6450781226158142,1.5104620456695557,0.5978999733924866,1.7296024560928345,50000 -5649.124925851822,6.335825204849243,58434.6179728508,128150,0,58434.6179728508,0.4826000332832336,2.3594110012054443,10000,64096.87507081032,0.6434765458106995,1.4898433685302734,0.6008999943733215,1.686245083808899,50000 -5689.88328742981,6.38477087020874,58854.83988237381,129073,0,58854.83988237381,0.4907000362873077,2.321425199508667,10000,64557.95269560814,0.6571874618530273,1.4405206441879272,0.6071999669075012,1.6774921417236328,50000 -5733.991091012955,6.439935207366943,59274.73733615875,129996,0,59274.73733615875,0.4869000315666199,2.335404634475708,10000,65022.062368392944,0.6743749976158142,1.362080216407776,0.6096000075340271,1.675639510154724,50000 -5772.714293003082,6.497478723526001,59694.95303297043,130801,0,59694.95303297043,0.4888000190258026,2.307673454284668,10000,65481.10044121742,0.6576952934265137,1.4295475482940674,0.6107199788093567,1.6498571634292605,50000 -5811.607728004456,6.554529428482056,60115.202788591385,131725,0,60115.202788591385,0.4941000342369079,2.286726474761963,10000,65940.35055208206,0.6663476228713989,1.3924933671951294,0.6150000095367432,1.636078119277954,50000 -5851.074378013611,6.603094100952148,60535.19282770157,132646,0,60535.19282770157,0.4978000223636627,2.2578117847442627,10000,66399.90402460098,0.6916796565055847,1.280269742012024,0.6193599700927734,1.6119118928909302,50000 -5891.558381795883,6.649211645126343,60955.23886442184,133570,0,60955.23886442184,0.4993000328540802,2.2752673625946045,10000,66860.528911829,0.6681445240974426,1.408696532249451,0.624239981174469,1.609699010848999,50000 -5933.31524014473,6.696808576583862,61375.3524119854,134493,0,61375.3524119854,0.5034000277519226,2.2548952102661133,10000,67322.49597454071,0.6716992259025574,1.369166374206543,0.6208400130271912,1.6039738655090332,50000 -5971.316042661667,7.2534918785095215,61794.85926914215,135413,0,61794.85926914215,0.49590003490448,2.288580894470215,10000,67780.60915780067,0.6820898056030273,1.3533567190170288,0.6220999956130981,1.6352964639663696,50000 -6013.095870256424,7.303630352020264,62215.11662912369,136335,0,62215.11662912369,0.5031000375747681,2.237124443054199,10000,68242.74508333206,0.67431640625,1.3767963647842407,0.6293999552726746,1.5818134546279907,50000 -6051.310604095459,7.360164165496826,62635.25749588013,137257,0,62635.25749588013,0.509600043296814,2.2336137294769287,10000,68701.20639562607,0.6786132454872131,1.3735076189041138,0.6330199837684631,1.5894250869750977,50000 -6095.903052806854,7.405869483947754,63055.25727057457,138175,0,63055.25727057457,0.5169000029563904,2.163562297821045,10000,69165.89258766174,0.6922070384025574,1.2771286964416504,0.6366400122642517,1.529810905456543,50000 -6135.814694881439,7.4559032917022705,63475.19632101059,139094,0,63475.19632101059,0.5109000205993652,2.2069404125213623,10000,69625.84164237976,0.6813281178474426,1.333216428756714,0.6382799744606018,1.5517253875732422,50000 -6172.251842260361,7.511298418045044,63895.5168299675,140016,0,63895.5168299675,0.5188000202178955,2.150461196899414,10000,70082.7032134533,0.6907812356948853,1.2872713804244995,0.6406999826431274,1.5119783878326416,50000 -6214.205435991287,7.5685484409332275,64315.69512438774,140938,0,64315.69512438774,0.5188000202178955,2.1491525173187256,10000,70544.94150233269,0.6983007788658142,1.247533082962036,0.6419399976730347,1.5060182809829712,50000 -6255.516871213913,7.614865064620972,64735.88982725144,141862,0,64735.88982725144,0.5161000490188599,2.164371967315674,10000,71006.54262113571,0.7147851586341858,1.200818657875061,0.6396399736404419,1.5271742343902588,50000 -6293.555589437485,7.669762372970581,65156.19044685364,142784,0,65156.19044685364,0.5197000503540039,2.190412759780884,10000,71464.98599791527,0.6896093487739563,1.3375908136367798,0.6431199908256531,1.5491735935211182,50000 -6338.054198503494,7.726184368133545,65576.20358800888,143705,0,65576.20358800888,0.525600016117096,2.1392199993133545,10000,71929.60329174995,0.7058203220367432,1.2415093183517456,0.6496599912643433,1.489941954612732,50000 -6381.687472581863,7.781303644180298,65996.54626846313,144624,0,65996.54626846313,0.5315000414848328,2.090836763381958,10000,72393.68397283554,0.7205273509025574,1.1597256660461426,0.6532999873161316,1.449857473373413,50000 -6425.086839675903,7.835384845733643,66416.76606321335,145543,0,66416.76606321335,0.5302000045776367,2.0874135494232178,10000,72857.4061486721,0.7084179520606995,1.2068352699279783,0.6576600074768066,1.4416378736495972,50000 -6465.893684387207,7.883531093597412,66836.82033586502,146464,0,66836.82033586502,0.5349000096321106,2.0838398933410645,10000,73318.36433267593,0.7099999785423279,1.185759425163269,0.6562199592590332,1.433472990989685,50000 -6505.38139629364,7.932077884674072,67257.22652721405,147388,0,67257.22652721405,0.5333000421524048,2.0673604011535645,10000,73778.35560202599,0.71875,1.1505794525146484,0.6584999561309814,1.419907569885254,50000 -6547.186131954193,7.988610029220581,67677.38717126846,148311,0,67677.38717126846,0.5368000268936157,2.0515666007995605,10000,74240.42641305923,0.7202538847923279,1.1635055541992188,0.6650800108909607,1.4090675115585327,50000 -6585.98409485817,8.03798794746399,68097.53750610352,149233,0,68097.53750610352,0.5412999987602234,2.040438652038574,10000,74699.47302079201,0.7199999690055847,1.1488616466522217,0.6662200093269348,1.3948432207107544,50000 -6628.50556063652,8.088744640350342,68517.46277451515,150153,0,68517.46277451515,0.547700047492981,2.0186476707458496,10000,75162.01907277107,0.7326562404632568,1.1124690771102903,0.6711199879646301,1.3810205459594729,50000 -6667.386779785156,8.145340204238892,68937.43747639656,151075,0,68937.43747639656,0.5461000204086304,2.029315233230591,10000,75620.98027873039,0.7326562404632568,1.1057353019714355,0.673039972782135,1.3807213306427002,50000 -6709.3519904613495,8.198824405670166,69357.67158484459,151999,0,69357.67158484459,0.5471000075340271,1.9906830787658687,10000,76083.28164362907,0.7305663824081421,1.1032029390335083,0.6752399802207947,1.345995306968689,50000 -6747.357416629791,8.25408148765564,69777.8404636383,152922,0,69777.8404636383,0.5550000071525574,1.967795968055725,10000,76541.56018471718,0.7342382669448853,1.077256202697754,0.6779199838638306,1.3336390256881714,50000 -6788.15131855011,8.30648159980774,70197.90728449821,153844,0,70197.90728449821,0.5613000392913818,1.933016657829285,10000,77002.52214980125,0.755175769329071,0.9857242107391356,0.6827999949455261,1.3085020780563354,50000 -6829.035629749298,8.359469890594482,70618.0102212429,154766,0,70618.0102212429,0.5594000220298767,1.937770247459412,10000,77463.61097240448,0.7420898079872131,1.0481141805648804,0.6864399909973145,1.2994166612625122,50000 -6870.133940458298,8.413815975189209,71038.28627920151,155688,0,71038.28627920151,0.563800036907196,1.915860891342163,10000,77925.08789467812,0.7477148175239563,1.0200145244598389,0.6866199970245361,1.2940551042556765,50000 -6912.691474199295,8.466750383377075,71458.23702192307,156608,0,71458.23702192307,0.5662000179290771,1.920639157295227,10000,78387.6969614029,0.7582226395606995,0.9849535822868348,0.6900399923324585,1.2790842056274414,50000 -6956.333927631378,8.52064037322998,71878.5546181202,157528,0,71878.5546181202,0.5696000456809998,1.915724515914917,10000,78851.75995445251,0.7495703101158142,1.0092881917953491,0.6941199898719788,1.263446569442749,50000 -7001.015208482742,8.571056127548218,72298.76634907722,158452,0,72298.76634907722,0.5752000212669373,1.885601282119751,10000,79316.75279378891,0.7608984112739563,0.9779459238052368,0.6959199905395508,1.251377820968628,50000 -7042.348134994507,8.622626304626465,72719.03267765045,159372,0,72719.03267765045,0.5725000500679016,1.8790431022644043,10000,79778.45223927498,0.7640234231948853,0.9492297768592834,0.6974599957466125,1.241568922996521,50000 -7082.851930856705,8.679443597793579,73139.28754115105,160296,0,73139.28754115105,0.5748000144958496,1.8601552248001096,10000,80239.31615614891,0.7643359303474426,0.9486043453216552,0.703000009059906,1.21571147441864,50000 -7125.049999952316,8.735102891921997,73559.51320672035,161219,0,73559.51320672035,0.5754000544548035,1.8825575113296509,10000,80701.84395289421,0.763378918170929,0.9769259691238404,0.7021999955177307,1.2391090393066406,50000 -7167.17483496666,8.798385381698608,73979.57974791527,162142,0,73979.57974791527,0.579300045967102,1.8460499048233032,10000,81164.14709663391,0.7700781226158142,0.9373509287834167,0.7052599787712097,1.2130261659622192,50000 -7212.307286977768,8.858510732650757,74399.50149416924,163064,0,74399.50149416924,0.5833000540733337,1.832371592521668,10000,81629.3094651699,0.7858203053474426,0.8678929209709167,0.7084400057792664,1.2065235376358032,50000 -7251.045498132706,8.912059307098389,74819.62465143204,163987,0,74819.62465143204,0.5903000235557556,1.815895795822144,10000,82088.27244186401,0.7753320336341858,0.9123504161834716,0.7140399813652039,1.184422492980957,50000 -7291.754102945328,8.975491523742676,75239.51298332214,164909,0,75239.51298332214,0.5915000438690186,1.7959223985671997,10000,82548.98171710968,0.7803320288658142,0.8811646103858948,0.7166399955749512,1.1691982746124268,50000 -7335.747500419617,9.027804374694824,75659.53583550453,165831,0,75659.53583550453,0.5896000266075134,1.802293419837952,10000,83013.09928798676,0.788378894329071,0.8617101907730103,0.7175999879837036,1.1666576862335205,50000 -7375.801503419876,9.081603765487673,76079.45230317116,166751,0,76079.45230317116,0.5955000519752502,1.7816680669784546,10000,83473.17249202728,0.7841405868530273,0.869185209274292,0.7195799946784973,1.1540203094482422,50000 -7417.107574224472,9.143556594848633,76499.65189146996,167674,0,76499.65189146996,0.5976999998092651,1.7546520233154297,10000,83934.78953242302,0.7900976538658142,0.8440901041030884,0.7210599780082703,1.1322126388549805,50000 -7460.537466287613,9.19641399383545,76919.8663828373,168597,0,76919.8663828373,0.6009000539779663,1.7406065464019775,10000,84398.5364639759,0.7947070002555847,0.8163579106330872,0.7242000102996826,1.1301054954528809,50000 -7503.553560256958,9.254878997802734,77339.84844470024,169519,0,77339.84844470024,0.602400004863739,1.7409383058547974,10000,84861.64176750183,0.7934960722923279,0.8324291706085205,0.725820004940033,1.1191824674606323,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/measurements.csv deleted file mode 100644 index 14f8f57c9..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1887 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.37333363,6.9077563,,,,,,,,,,,,,, -1,,,0.0010351561941206,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,34.864447355270386,63.08093595504761,34.864447355270386,28.21615719795227,0.0,0.0 -100,0.5345217,6.830349,,,,,,,,,,,,,, -200,0.9364716,6.7053638,,,,,,,,,,,,,, -300,0.98852146,6.5289936,,,,,,,,,,,,,, -400,0.8801767,6.5275545,,,,,,,,,,,,,, -500,0.861733,6.7618575,,,,,,,,,,,,,, -600,0.86091596,6.379199,,,,,,,,,,,,,, -700,0.9657758,6.234465,,,,,,,,,,,,,, -800,1.1662923,6.4302483,,,,,,,,,,,,,, -860,,,0.0338671877980232,5.902928829193115,0.030639998614788,5.9316911697387695,50000.0,0.0260000005364418,6.065794467926025,10000.0,455.0325014591217,522.3378174304962,455.0325014591217,67.2409439086914,0.0178401470184326,0.0 -900,0.7381636,6.107238,,,,,,,,,,,,,, -1000,0.723859,6.171989,,,,,,,,,,,,,, -1100,0.67814565,6.4193106,,,,,,,,,,,,,, -1200,0.5571498,6.0707545,,,,,,,,,,,,,, -1300,0.63382703,6.0564666,,,,,,,,,,,,,, -1400,0.5477051,6.6601357,,,,,,,,,,,,,, -1500,0.6671792,6.6208715,,,,,,,,,,,,,, -1600,0.6723983,5.8444195,,,,,,,,,,,,,, -1700,0.54010713,6.196493,,,,,,,,,,,,,, -1778,,,0.0747656226158142,5.325953960418701,0.0690400004386901,5.398590087890625,50000.0,0.0545000024139881,5.616189956665039,10000.0,875.0685901641846,987.7160251140594,875.0685901641846,112.5071575641632,0.0444440841674804,0.0 -1800,0.56270444,6.429225,,,,,,,,,,,,,, -1900,0.6312672,5.831458,,,,,,,,,,,,,, -2000,0.56851804,5.820255,,,,,,,,,,,,,, -2100,0.54557943,5.819797,,,,,,,,,,,,,, -2200,0.4190826,6.6528745,,,,,,,,,,,,,, -2300,0.4741094,5.753438,,,,,,,,,,,,,, -2400,0.5341279,5.5847235,,,,,,,,,,,,,, -2500,0.5426925,6.2303395,,,,,,,,,,,,,, -2600,0.5800073,5.729277,,,,,,,,,,,,,, -2697,,,0.0935742184519767,5.147493362426758,0.0866400003433227,5.2080559730529785,50000.0,0.0668000057339668,5.455690860748291,10000.0,1295.2597908973694,1446.455206632614,1295.2597908973694,150.98039412498474,0.0706682205200195,0.0 -2700,0.5016589,6.4109893,,,,,,,,,,,,,, -2800,0.46235412,6.036454,,,,,,,,,,,,,, -2900,0.5771683,5.3611937,,,,,,,,,,,,,, -3000,0.5537741,5.495505,,,,,,,,,,,,,, -3100,0.5665852,5.512565,,,,,,,,,,,,,, -3200,0.44448748,6.210102,,,,,,,,,,,,,, -3300,0.54148966,5.328244,,,,,,,,,,,,,, -3400,0.5846938,5.3761177,,,,,,,,,,,,,, -3500,0.49323693,5.5508947,,,,,,,,,,,,,, -3600,0.49192053,5.147485,,,,,,,,,,,,,, -3618,,,0.138710930943489,4.64686393737793,0.1281400024890899,4.720630645751953,50000.0,0.0971000045537948,5.038281440734863,10000.0,1715.3475415706637,1909.941129684448,1715.3475415706637,194.3020474910736,0.0982718467712402,0.0 -3700,0.79903245,5.4338326,,,,,,,,,,,,,, -3800,0.50508946,5.2219353,,,,,,,,,,,,,, -3900,0.9387077,5.4544945,,,,,,,,,,,,,, -4000,0.6403164,5.906692,,,,,,,,,,,,,, -4100,0.46122968,5.771821,,,,,,,,,,,,,, -4200,0.5156849,5.922494,,,,,,,,,,,,,, -4300,0.4699634,5.5534325,,,,,,,,,,,,,, -4400,0.702306,6.426707,,,,,,,,,,,,,, -4500,0.7102047,4.936615,,,,,,,,,,,,,, -4541,,,0.1616406142711639,4.394140720367432,0.1517399996519088,4.4949469566345215,50000.0,0.1171000078320503,4.889176845550537,10000.0,2135.6945128440857,2372.883920669556,2135.6945128440857,236.8234026432037,0.1237492561340332,0.0 -4600,0.7220094,6.293412,,,,,,,,,,,,,, -4700,0.97017777,6.199792,,,,,,,,,,,,,, -4800,0.628732,5.071438,,,,,,,,,,,,,, -4900,0.81743056,5.0476336,,,,,,,,,,,,,, -5000,0.6670151,4.8491817,,,,,,,,,,,,,, -5100,0.8336706,4.9192047,,,,,,,,,,,,,, -5200,0.7229728,4.8399897,,,,,,,,,,,,,, -5300,0.83726746,5.153849,,,,,,,,,,,,,, -5400,0.845339,4.8716607,,,,,,,,,,,,,, -5457,,,0.1947851479053497,4.121174335479736,0.1805399954319,4.207534790039063,50000.0,0.1381999999284744,4.630496978759766,10000.0,2556.0419986248016,2834.015328407288,2556.0419986248016,277.5315291881561,0.1509375572204589,0.0 -5500,0.8130439,5.106575,,,,,,,,,,,,,, -5600,0.69235235,4.9899826,,,,,,,,,,,,,, -5700,0.7521791,6.0215054,,,,,,,,,,,,,, -5800,0.8681135,4.8728247,,,,,,,,,,,,,, -5900,0.7952735,5.381453,,,,,,,,,,,,,, -6000,0.65075225,4.9122086,,,,,,,,,,,,,, -6100,0.86832887,4.705046,,,,,,,,,,,,,, -6200,0.7761517,6.2213216,,,,,,,,,,,,,, -6300,0.6938967,4.6706457,,,,,,,,,,,,,, -6379,,,0.2248437404632568,3.996538639068604,0.2078799903392791,4.081354141235352,50000.0,0.1568000018596649,4.494888305664063,10000.0,2976.1662895679474,3293.9713699817657,2976.1662895679474,317.28538703918457,0.1799032688140869,0.0 -6400,0.7649852,4.637851,,,,,,,,,,,,,, -6500,0.74889725,4.9760127,,,,,,,,,,,,,, -6600,0.6443669,5.005858,,,,,,,,,,,,,, -6700,0.64572346,6.115892,,,,,,,,,,,,,, -6800,0.782837,4.607653,,,,,,,,,,,,,, -6900,0.77175915,4.821822,,,,,,,,,,,,,, -7000,0.7787801,4.5532656,,,,,,,,,,,,,, -7100,0.81063735,6.210451,,,,,,,,,,,,,, -7200,0.7205615,5.562838,,,,,,,,,,,,,, -7299,,,0.2309374958276748,3.953554630279541,0.2123799920082092,4.063230991363525,50000.0,0.1620000004768371,4.486764430999756,10000.0,3396.437614440918,3754.5577280521393,3396.437614440918,357.51818895339966,0.21305513381958,0.0 -7300,0.9402281,6.3526516,,,,,,,,,,,,,, -7400,0.67207456,6.0364795,,,,,,,,,,,,,, -7500,0.66310495,5.935543,,,,,,,,,,,,,, -7600,0.8095516,4.8077145,,,,,,,,,,,,,, -7700,0.84992737,4.845103,,,,,,,,,,,,,, -7800,0.89595366,6.3169913,,,,,,,,,,,,,, -7900,0.86425614,4.955548,,,,,,,,,,,,,, -8000,0.7494382,4.569913,,,,,,,,,,,,,, -8100,0.8374708,4.583261,,,,,,,,,,,,,, -8200,0.78581756,4.586575,,,,,,,,,,,,,, -8218,,,0.2625390589237213,3.628416776657105,0.2255199998617172,3.8713552951812735,50000.0,0.1733000129461288,4.334794521331787,10000.0,3816.479942083359,4213.867478132248,3816.479942083359,396.7017076015472,0.2488820552825927,0.0 -8300,0.89553565,5.664416,,,,,,,,,,,,,, -8400,0.9657134,4.5964446,,,,,,,,,,,,,, -8500,0.7581696,4.853735,,,,,,,,,,,,,, -8600,0.83575386,4.5036526,,,,,,,,,,,,,, -8700,1.1039188,4.645851,,,,,,,,,,,,,, -8800,1.0359488,4.609297,,,,,,,,,,,,,, -8900,0.85944927,4.335394,,,,,,,,,,,,,, -9000,0.78284806,4.5002017,,,,,,,,,,,,,, -9100,0.7520169,4.472762,,,,,,,,,,,,,, -9142,,,0.2725390493869781,3.545860767364502,0.2532599866390228,3.675501585006714,50000.0,0.1899000108242035,4.1776041984558105,10000.0,4236.704192399979,4677.285391569138,4236.704192399979,439.81812286376953,0.2767214775085449,0.0 -9200,0.94739234,4.5199957,,,,,,,,,,,,,, -9300,1.0630181,4.434876,,,,,,,,,,,,,, -9400,1.0938321,4.556095,,,,,,,,,,,,,, -9500,0.86849105,5.1286917,,,,,,,,,,,,,, -9600,0.71082324,5.316351,,,,,,,,,,,,,, -9700,0.94749725,4.6676664,,,,,,,,,,,,,, -9800,0.9119557,5.7402315,,,,,,,,,,,,,, -9900,0.8370245,4.4339094,,,,,,,,,,,,,, -10000,0.7940009,6.2470617,,,,,,,,,,,,,, -10064,,,0.2577343583106994,3.733825206756592,0.2406199872493744,3.855492115020752,50000.0,0.1834000051021576,4.324698448181152,10000.0,4656.660678386688,5135.04213309288,4656.660678386688,477.5397388935089,0.3062405586242676,0.0 -10100,0.8595941,5.763239,,,,,,,,,,,,,, -10200,0.8305899,4.7467546,,,,,,,,,,,,,, -10300,0.7892189,5.6632714,,,,,,,,,,,,,, -10400,0.7124142,5.13175,,,,,,,,,,,,,, -10500,0.8139036,4.543969,,,,,,,,,,,,,, -10600,0.9051397,4.485305,,,,,,,,,,,,,, -10700,1.0894763,4.310498,,,,,,,,,,,,,, -10800,0.85420436,6.048172,,,,,,,,,,,,,, -10900,0.9581427,4.392799,,,,,,,,,,,,,, -10987,,,0.2759570181369781,3.5516586303710938,0.2458799928426742,3.746058225631714,50000.0,0.1914000064134597,4.227960586547852,10000.0,5076.703187704086,5594.703353404999,5076.703187704086,517.0797474384308,0.3357918262481689,0.0 -11000,0.8771809,4.4703026,,,,,,,,,,,,,, -11100,1.0196195,5.4963036,,,,,,,,,,,,,, -11200,0.9724316,4.352787,,,,,,,,,,,,,, -11300,0.5866623,5.7537446,,,,,,,,,,,,,, -11400,1.0181857,4.5743113,,,,,,,,,,,,,, -11500,0.69712543,5.8792605,,,,,,,,,,,,,, -11600,0.76017636,5.061146,,,,,,,,,,,,,, -11700,0.935608,4.7559037,,,,,,,,,,,,,, -11800,0.91860217,4.273209,,,,,,,,,,,,,, -11900,0.7908355,4.158858,,,,,,,,,,,,,, -11908,,,0.2885546684265136,3.454838752746582,0.2699199914932251,3.566235542297364,50000.0,0.2025000154972076,4.11765718460083,10000.0,5496.717879772186,6053.150849819183,5496.717879772186,555.4356439113617,0.3640127182006836,0.0 -12000,0.8410576,4.8575177,,,,,,,,,,,,,, -12100,0.8294626,4.1488147,,,,,,,,,,,,,, -12200,0.82694125,4.4688654,,,,,,,,,,,,,, -12300,1.030882,4.3184457,,,,,,,,,,,,,, -12400,1.1441875,4.4100313,,,,,,,,,,,,,, -12500,0.71952355,4.594874,,,,,,,,,,,,,, -12600,0.8435116,4.3214474,,,,,,,,,,,,,, -12700,1.008525,4.272443,,,,,,,,,,,,,, -12800,0.8457687,4.2258425,,,,,,,,,,,,,, -12828,,,0.3011523485183716,3.359318256378174,0.2775000035762787,3.493115901947021,50000.0,0.215700015425682,4.02402925491333,10000.0,5916.797674417496,6510.595469713211,5916.797674417496,592.7188277244568,0.3962678909301758,0.0 -12900,0.9495214,4.3235397,,,,,,,,,,,,,, -13000,0.70594805,4.634172,,,,,,,,,,,,,, -13100,1.0048368,4.618791,,,,,,,,,,,,,, -13200,0.873649,4.1434755,,,,,,,,,,,,,, -13300,0.9573612,4.3750277,,,,,,,,,,,,,, -13400,1.0946894,5.692624,,,,,,,,,,,,,, -13500,0.91139233,5.5217857,,,,,,,,,,,,,, -13600,0.83885616,4.3374805,,,,,,,,,,,,,, -13700,0.80198866,4.238555,,,,,,,,,,,,,, -13749,,,0.3214453160762787,3.2423040866851807,0.2875199913978576,3.422166109085083,50000.0,0.2234000116586685,3.946427583694458,10000.0,6336.788841247559,6971.531793117523,6336.788841247559,633.5849332809448,0.4264430999755859,0.0 -13800,0.5854787,6.06793,,,,,,,,,,,,,, -13900,0.9342197,4.947026,,,,,,,,,,,,,, -14000,0.82699674,4.3618803,,,,,,,,,,,,,, -14100,1.1241596,4.2236695,,,,,,,,,,,,,, -14200,0.8229772,4.545107,,,,,,,,,,,,,, -14300,1.127177,4.306148,,,,,,,,,,,,,, -14400,0.7551563,4.9804683,,,,,,,,,,,,,, -14500,0.94870716,4.214695,,,,,,,,,,,,,, -14600,1.086679,5.1614003,,,,,,,,,,,,,, -14674,,,0.3084179759025574,3.319817066192627,0.2889599800109863,3.429847478866577,50000.0,0.2208000123500824,3.968396186828613,10000.0,6756.758174657822,7428.526203870773,6756.758174657822,670.5265691280365,0.4606430530548095,0.0 -14700,0.9265988,6.2114506,,,,,,,,,,,,,, -14800,0.84498715,4.150548,,,,,,,,,,,,,, -14900,0.99207616,4.5686255,,,,,,,,,,,,,, -15000,0.7288443,6.096277,,,,,,,,,,,,,, -15100,0.9211667,4.3985624,,,,,,,,,,,,,, -15200,1.2224329,4.2257085,,,,,,,,,,,,,, -15300,0.97203946,4.222899,,,,,,,,,,,,,, -15400,0.9732195,4.213538,,,,,,,,,,,,,, -15500,0.80561894,5.9479275,,,,,,,,,,,,,, -15597,,,0.3149218559265136,3.302459239959717,0.2924000024795532,3.4336822032928467,50000.0,0.2218000143766403,3.960664987564087,10000.0,7176.900419473648,7890.88374376297,7176.900419473648,712.6638605594635,0.490314245223999,0.0 -15600,1.1591256,4.674486,,,,,,,,,,,,,, -15700,0.9667464,4.297616,,,,,,,,,,,,,, -15800,0.8667739,4.260016,,,,,,,,,,,,,, -15900,0.7916176,4.671741,,,,,,,,,,,,,, -16000,0.7968277,4.520251,,,,,,,,,,,,,, -16100,0.7467047,4.5769825,,,,,,,,,,,,,, -16200,0.80803925,5.3015075,,,,,,,,,,,,,, -16300,0.925234,4.201272,,,,,,,,,,,,,, -16400,0.8053812,6.13556,,,,,,,,,,,,,, -16500,0.9210095,4.2355413,,,,,,,,,,,,,, -16518,,,0.3135937452316284,3.3022851943969727,0.2899599969387054,3.4567174911499023,50000.0,0.2164000123739242,4.005506992340088,10000.0,7597.258638143539,8348.557153463364,7597.258638143539,749.902556180954,0.5184402465820312,0.0 -16600,0.78641796,5.9600544,,,,,,,,,,,,,, -16700,0.7931481,5.866458,,,,,,,,,,,,,, -16800,0.95144814,4.4847775,,,,,,,,,,,,,, -16900,0.92076457,6.0285826,,,,,,,,,,,,,, -17000,0.8137289,4.1605997,,,,,,,,,,,,,, -17100,0.89423645,4.1791253,,,,,,,,,,,,,, -17200,0.86809486,4.229105,,,,,,,,,,,,,, -17300,1.0974802,4.27354,,,,,,,,,,,,,, -17400,0.82118934,4.0340524,,,,,,,,,,,,,, -17442,,,0.3493554592132568,3.045586585998535,0.311519980430603,3.2577993869781494,50000.0,0.2416000068187713,3.832712650299072,10000.0,8017.4315502643585,8809.659813642502,8017.4315502643585,790.755383014679,0.5460658073425293,0.0 -17500,1.005329,4.205557,,,,,,,,,,,,,, -17600,0.96709317,4.557013,,,,,,,,,,,,,, -17700,0.92830175,4.4166183,,,,,,,,,,,,,, -17800,0.94677466,4.20364,,,,,,,,,,,,,, -17900,0.9623034,4.1682863,,,,,,,,,,,,,, -18000,0.8622224,4.260439,,,,,,,,,,,,,, -18100,0.9062119,3.9904249,,,,,,,,,,,,,, -18200,0.82833546,5.085025,,,,,,,,,,,,,, -18300,0.9255688,4.1305656,,,,,,,,,,,,,, -18365,,,0.334765613079071,3.173603057861328,0.3100999891757965,3.299381732940674,50000.0,0.2379000186920166,3.853330612182617,10000.0,8437.51029419899,9269.958486557009,8437.51029419899,830.8956499099731,0.5777482986450195,0.0 -18400,0.9865733,4.405195,,,,,,,,,,,,,, -18500,0.8120871,4.0509543,,,,,,,,,,,,,, -18600,0.8563953,4.238135,,,,,,,,,,,,,, -18700,0.832796,4.182831,,,,,,,,,,,,,, -18800,0.8790368,4.600828,,,,,,,,,,,,,, -18900,1.0515175,4.010655,,,,,,,,,,,,,, -19000,0.79271966,5.9884157,,,,,,,,,,,,,, -19100,0.9412007,3.9469004,,,,,,,,,,,,,, -19200,0.98486006,4.07668,,,,,,,,,,,,,, -19286,,,0.328125,3.206876516342163,0.3041799962520599,3.352670431137085,50000.0,0.234700009226799,3.879996299743652,10000.0,8857.471648454666,9726.6599547863,8857.471648454666,867.5589497089386,0.6066219806671143,0.0 -19300,0.7625678,4.68409,,,,,,,,,,,,,, -19400,0.983722,4.1365185,,,,,,,,,,,,,, -19500,0.8834165,5.3734717,,,,,,,,,,,,,, -19600,1.0220696,5.516304,,,,,,,,,,,,,, -19700,0.8592908,4.1194534,,,,,,,,,,,,,, -19800,0.89548105,4.019303,,,,,,,,,,,,,, -19900,0.82331854,4.0890975,,,,,,,,,,,,,, -20000,0.94610053,4.116624,,,,,,,,,,,,,, -20100,0.9503924,4.099891,,,,,,,,,,,,,, -20200,0.8989408,5.2086873,,,,,,,,,,,,,, -20209,,,0.3524218797683716,3.068400621414185,0.311379998922348,3.3076298236846924,50000.0,0.2412000149488449,3.833801746368408,10000.0,9277.629351854324,10190.143986225128,9277.629351854324,910.8073680400848,0.6355025768280029,0.0 -20300,0.81926364,5.094794,,,,,,,,,,,,,, -20400,0.7977714,3.8723893,,,,,,,,,,,,,, -20500,0.79876196,4.4864964,,,,,,,,,,,,,, -20600,0.75189084,4.1343145,,,,,,,,,,,,,, -20700,1.0603358,4.017812,,,,,,,,,,,,,, -20800,0.95392334,4.708803,,,,,,,,,,,,,, -20900,0.7929478,4.0438275,,,,,,,,,,,,,, -21000,0.6902105,5.6009116,,,,,,,,,,,,,, -21100,0.997385,4.175238,,,,,,,,,,,,,, -21131,,,0.3587304651737213,3.019661903381348,0.3378599882125854,3.147328615188598,50000.0,0.2560000121593475,3.727675914764404,10000.0,9697.635033369064,10655.167280435562,9697.635033369064,955.7451527118684,0.6662991046905518,0.0 -21200,0.68219244,5.851604,,,,,,,,,,,,,, -21300,0.94494087,4.0271907,,,,,,,,,,,,,, -21400,0.8249885,4.244356,,,,,,,,,,,,,, -21500,0.8369322,4.383905,,,,,,,,,,,,,, -21600,0.7783769,4.531444,,,,,,,,,,,,,, -21700,0.65223306,6.0448847,,,,,,,,,,,,,, -21800,0.6784818,6.080043,,,,,,,,,,,,,, -21900,0.8285996,4.038012,,,,,,,,,,,,,, -22000,0.78612155,4.068596,,,,,,,,,,,,,, -22053,,,0.3477343618869781,3.0713601112365723,0.3216799795627594,3.2210161685943604,50000.0,0.2412000149488449,3.8103978633880615,10000.0,10117.811880588531,11114.772018909454,10117.811880588531,995.0893120765686,0.700777530670166,0.0 -22100,0.9868493,4.025068,,,,,,,,,,,,,, -22200,0.71670604,5.798018,,,,,,,,,,,,,, -22300,0.8837645,4.1079087,,,,,,,,,,,,,, -22400,0.8997026,3.9535823,,,,,,,,,,,,,, -22500,0.8809871,5.759507,,,,,,,,,,,,,, -22600,0.85851085,4.0307326,,,,,,,,,,,,,, -22700,0.9193845,4.7153444,,,,,,,,,,,,,, -22800,1.0225773,4.200569,,,,,,,,,,,,,, -22900,1.0498298,3.9357755,,,,,,,,,,,,,, -22976,,,0.3622460961341858,2.9715335369110107,0.3338399827480316,3.14302396774292,50000.0,0.255400002002716,3.708144664764404,10000.0,10538.091529846191,11576.040191173552,10538.091529846191,1035.9921073913574,0.7372596263885498,0.0 -23000,0.7916837,4.050954,,,,,,,,,,,,,, -23100,1.0379369,4.029197,,,,,,,,,,,,,, -23200,0.9063908,4.099342,,,,,,,,,,,,,, -23300,1.0695556,4.0888796,,,,,,,,,,,,,, -23400,1.025324,4.100216,,,,,,,,,,,,,, -23500,0.81110924,4.0598593,,,,,,,,,,,,,, -23600,1.0522133,5.0386505,,,,,,,,,,,,,, -23700,0.92706454,4.0190854,,,,,,,,,,,,,, -23800,0.9634167,6.018855,,,,,,,,,,,,,, -23900,,,0.3491992056369781,3.0883233547210693,0.3263799846172333,3.207261085510254,50000.0,0.2514000236988067,3.753178596496582,10000.0,10958.286712884905,12036.120604991913,10958.286712884905,1075.7954070568085,0.7705366611480713,0.0 -23900,0.88463306,4.1638374,,,,,,,,,,,,,, -24000,0.961491,4.497767,,,,,,,,,,,,,, -24100,0.9810447,4.1569967,,,,,,,,,,,,,, -24200,0.8607526,5.3055134,,,,,,,,,,,,,, -24300,0.83680755,4.2240486,,,,,,,,,,,,,, -24400,0.76955366,5.710074,,,,,,,,,,,,,, -24500,0.91706115,3.9563017,,,,,,,,,,,,,, -24600,1.0406744,4.519926,,,,,,,,,,,,,, -24700,0.8033859,5.2314014,,,,,,,,,,,,,, -24800,0.9548317,3.9856706,,,,,,,,,,,,,, -24825,,,0.3492382764816284,3.050524473190308,0.3290999829769134,3.18064284324646,50000.0,0.2499000132083892,3.739815473556519,10000.0,11378.547968149183,12498.776052236555,11378.547968149183,1118.106077671051,0.8033137321472168,0.0 -24900,0.63956684,5.988489,,,,,,,,,,,,,, -25000,0.97820795,3.9995074,,,,,,,,,,,,,, -25100,0.923824,4.3415413,,,,,,,,,,,,,, -25200,0.9642212,4.0083113,,,,,,,,,,,,,, -25300,0.9190958,3.881112,,,,,,,,,,,,,, -25400,0.83621776,5.645939,,,,,,,,,,,,,, -25500,0.74275297,5.176154,,,,,,,,,,,,,, -25600,0.7836739,4.935535,,,,,,,,,,,,,, -25700,0.82065845,4.795205,,,,,,,,,,,,,, -25746,,,0.3688476383686065,2.933948516845703,0.3354199826717376,3.105673789978028,50000.0,0.255700021982193,3.705024480819702,10000.0,11798.491114139557,12955.351569652556,11798.491114139557,1154.6575469970703,0.8338415622711182,0.0 -25800,0.9414855,4.057428,,,,,,,,,,,,,, -25900,0.8532907,5.1688223,,,,,,,,,,,,,, -26000,0.99551487,4.71432,,,,,,,,,,,,,, -26100,0.9915348,4.0441136,,,,,,,,,,,,,, -26200,0.9229811,3.9255488,,,,,,,,,,,,,, -26300,1.0652256,3.972159,,,,,,,,,,,,,, -26400,0.987719,4.106896,,,,,,,,,,,,,, -26500,0.76679623,6.075879,,,,,,,,,,,,,, -26600,0.9422828,3.902112,,,,,,,,,,,,,, -26668,,,0.3754101395606994,2.905060052871704,0.3468399941921234,3.06069564819336,50000.0,0.2708000242710113,3.657903671264648,10000.0,12218.91095137596,13413.987750291824,12218.91095137596,1192.788717508316,0.8701872825622559,0.0 -26700,0.78818774,4.1990714,,,,,,,,,,,,,, -26800,0.9733687,3.8891253,,,,,,,,,,,,,, -26900,0.8979288,4.2440453,,,,,,,,,,,,,, -27000,0.95427746,3.8084004,,,,,,,,,,,,,, -27100,1.3595551,4.0651393,,,,,,,,,,,,,, -27200,0.9192895,3.980703,,,,,,,,,,,,,, -27300,0.9160927,3.9242241,,,,,,,,,,,,,, -27400,0.78548163,5.8594093,,,,,,,,,,,,,, -27500,0.7563912,4.661805,,,,,,,,,,,,,, -27590,,,0.3759374916553497,2.8957693576812744,0.3487599790096283,3.039360284805298,50000.0,0.2639000117778778,3.661264181137085,10000.0,12638.921879768372,13875.224354982376,12638.921879768372,1233.934255361557,0.9021317958831788,0.0 -27600,1.0125085,3.9261856,,,,,,,,,,,,,, -27700,0.9636456,4.566968,,,,,,,,,,,,,, -27800,0.8537059,4.8970566,,,,,,,,,,,,,, -27900,1.0828197,3.9554281,,,,,,,,,,,,,, -28000,1.0871116,4.2539067,,,,,,,,,,,,,, -28100,0.8467363,4.0195293,,,,,,,,,,,,,, -28200,1.0963868,3.9379067,,,,,,,,,,,,,, -28300,1.1053274,3.9981687,,,,,,,,,,,,,, -28400,0.8690098,4.3571477,,,,,,,,,,,,,, -28500,0.80142295,6.024173,,,,,,,,,,,,,, -28514,,,0.3650195300579071,2.984952926635742,0.3395199775695801,3.1396758556365967,50000.0,0.2562000155448913,3.7381186485290527,10000.0,13059.063136339188,14332.19144487381,13059.063136339188,1270.6803524494171,0.9330759048461914,0.0 -28600,0.59818137,5.9696097,,,,,,,,,,,,,, -28700,1.2983594,4.331139,,,,,,,,,,,,,, -28800,0.94927245,3.8846526,,,,,,,,,,,,,, -28900,0.9853845,3.8224075,,,,,,,,,,,,,, -29000,0.9180677,4.2209306,,,,,,,,,,,,,, -29100,0.70439947,4.4721856,,,,,,,,,,,,,, -29200,0.7864968,4.79575,,,,,,,,,,,,,, -29300,1.1699421,4.1319017,,,,,,,,,,,,,, -29400,0.9191146,3.9572105,,,,,,,,,,,,,, -29437,,,0.3907031118869781,2.8178062438964844,0.3384000062942505,3.126830577850342,50000.0,0.2629000246524811,3.685798168182373,10000.0,13479.126105070114,14793.215996265411,13479.126105070114,1311.5622823238373,0.9640073776245116,0.0 -29500,0.94802785,3.9752026,,,,,,,,,,,,,, -29600,1.0161892,3.7942367,,,,,,,,,,,,,, -29700,0.7578343,4.5926385,,,,,,,,,,,,,, -29800,0.7468435,5.0662036,,,,,,,,,,,,,, -29900,0.86336225,3.8505309,,,,,,,,,,,,,, -30000,0.96751857,3.8520508,,,,,,,,,,,,,, -30100,0.98311883,3.9602852,,,,,,,,,,,,,, -30200,0.82237375,6.0703254,,,,,,,,,,,,,, -30300,1.1755269,4.0874257,,,,,,,,,,,,,, -30356,,,0.3720507621765136,2.9231514930725098,0.346699982881546,3.082463026046753,50000.0,0.2647000253200531,3.649627447128296,10000.0,13899.084905862808,15251.75150704384,13899.084905862808,1350.0582914352417,0.9966273307800292,0.0 -30400,0.78837687,5.038559,,,,,,,,,,,,,, -30500,0.93651146,4.1518626,,,,,,,,,,,,,, -30600,1.059604,6.038608,,,,,,,,,,,,,, -30700,0.9972076,3.903611,,,,,,,,,,,,,, -30800,1.0120974,4.2165785,,,,,,,,,,,,,, -30900,0.83541155,5.105652,,,,,,,,,,,,,, -31000,0.8080617,4.953506,,,,,,,,,,,,,, -31100,1.1168658,4.2842755,,,,,,,,,,,,,, -31200,0.9860507,4.081298,,,,,,,,,,,,,, -31278,,,0.3739062547683716,2.9314498901367188,0.3473199903964996,3.0704736709594727,50000.0,0.2718999981880188,3.6703150272369385,10000.0,14319.371235609056,15711.486956119536,14319.371235609056,1389.423010110855,1.0317187309265137,0.0 -31300,0.88974327,3.9364462,,,,,,,,,,,,,, -31400,1.0294636,3.8456619,,,,,,,,,,,,,, -31500,1.172203,3.7901058,,,,,,,,,,,,,, -31600,1.0212882,3.978037,,,,,,,,,,,,,, -31700,0.9540274,4.000467,,,,,,,,,,,,,, -31800,0.8760758,4.473259,,,,,,,,,,,,,, -31900,0.6924268,5.9597316,,,,,,,,,,,,,, -32000,0.98860604,3.856291,,,,,,,,,,,,,, -32100,1.2245817,3.8459387,,,,,,,,,,,,,, -32200,1.089177,3.8362823,,,,,,,,,,,,,, -32202,,,0.4043163955211639,2.71505069732666,0.3649599850177765,2.9368364810943604,50000.0,0.2857000231742859,3.5437114238739014,10000.0,14739.30179309845,16172.872477769852,14739.30179309845,1430.7917094230652,1.0686018466949463,0.0 -32300,0.7924706,4.8608346,,,,,,,,,,,,,, -32400,1.1519039,3.9558702,,,,,,,,,,,,,, -32500,0.73501736,5.4263725,,,,,,,,,,,,,, -32600,0.74762243,6.00259,,,,,,,,,,,,,, -32700,0.970569,4.039507,,,,,,,,,,,,,, -32800,1.033137,3.7395773,,,,,,,,,,,,,, -32900,0.9834595,3.8279734,,,,,,,,,,,,,, -33000,0.68532634,6.001791,,,,,,,,,,,,,, -33100,0.9731256,3.7714105,,,,,,,,,,,,,, -33126,,,0.3913476467132568,2.812678575515747,0.3640399873256683,2.9665307998657227,50000.0,0.2803000211715698,3.540454387664795,10000.0,15159.38614320755,16632.36226463318,15159.38614320755,1470.1178567409515,1.098562479019165,0.0 -33200,0.95251137,3.816175,,,,,,,,,,,,,, -33300,1.1817489,3.7151828,,,,,,,,,,,,,, -33400,0.8680091,5.688826,,,,,,,,,,,,,, -33500,0.7766522,3.948949,,,,,,,,,,,,,, -33600,0.88335735,3.8901045,,,,,,,,,,,,,, -33700,0.7829539,4.9917707,,,,,,,,,,,,,, -33800,1.0164624,3.6964514,,,,,,,,,,,,,, -33900,0.7043676,6.07036,,,,,,,,,,,,,, -34000,0.9731448,4.650126,,,,,,,,,,,,,, -34048,,,0.382148414850235,2.845215082168579,0.3581399917602539,2.999690532684326,50000.0,0.2736000120639801,3.5886569023132324,10000.0,15579.40675020218,17091.41108250618,15579.40675020218,1509.067530155182,1.1289377212524414,0.0 -34100,0.92334896,3.824052,,,,,,,,,,,,,, -34200,1.0627848,3.9378142,,,,,,,,,,,,,, -34300,0.96516496,3.9462562,,,,,,,,,,,,,, -34400,1.0065918,3.7674968,,,,,,,,,,,,,, -34500,0.93386066,3.780937,,,,,,,,,,,,,, -34600,1.0435681,3.7521112,,,,,,,,,,,,,, -34700,1.2482569,3.7570674,,,,,,,,,,,,,, -34800,0.8959937,4.8656993,,,,,,,,,,,,,, -34900,1.3199806,3.871949,,,,,,,,,,,,,, -34973,,,0.3923242092132568,2.8107075691223145,0.3644599914550781,2.975851058959961,50000.0,0.2762000262737274,3.589385271072388,10000.0,15999.715313196182,17553.633061647415,15999.715313196182,1550.8924548625946,1.168405055999756,0.0 -35000,1.2130052,3.9476082,,,,,,,,,,,,,, -35100,1.0791879,3.939434,,,,,,,,,,,,,, -35200,1.0051278,3.8512158,,,,,,,,,,,,,, -35300,1.0476835,3.7510123,,,,,,,,,,,,,, -35400,0.99646205,4.6145625,,,,,,,,,,,,,, -35500,0.78013104,5.6343036,,,,,,,,,,,,,, -35600,0.9351166,4.223998,,,,,,,,,,,,,, -35700,0.9313392,4.9605103,,,,,,,,,,,,,, -35800,1.0893091,3.9184952,,,,,,,,,,,,,, -35899,,,0.399726539850235,2.7130661010742188,0.3743399977684021,2.8684468269348145,50000.0,0.2859000265598297,3.4763166904449463,10000.0,16419.95436644554,18018.21585845948,16419.95436644554,1595.15216255188,1.2027950286865234,0.0 -35900,1.0718638,3.8748112,,,,,,,,,,,,,, -36000,1.0317712,3.9307365,,,,,,,,,,,,,, -36100,1.2944659,3.8631265,,,,,,,,,,,,,, -36200,0.9584057,4.8852625,,,,,,,,,,,,,, -36300,0.93946064,3.7413216,,,,,,,,,,,,,, -36400,0.90706223,3.737418,,,,,,,,,,,,,, -36500,1.1819615,4.3890905,,,,,,,,,,,,,, -36600,0.87316096,5.095472,,,,,,,,,,,,,, -36700,1.0326457,4.043505,,,,,,,,,,,,,, -36800,1.0490568,4.0195045,,,,,,,,,,,,,, -36822,,,0.4059179723262787,2.700996398925781,0.3773199915885925,2.852568387985229,50000.0,0.2893000245094299,3.473834991455078,10000.0,16840.155430793762,18478.654767751694,16840.155430793762,1635.3084263801577,1.2354793548583984,0.0 -36900,0.84607476,5.8203154,,,,,,,,,,,,,, -37000,0.9690029,4.0159974,,,,,,,,,,,,,, -37100,0.99627036,3.7704968,,,,,,,,,,,,,, -37200,0.8963906,5.3614693,,,,,,,,,,,,,, -37300,1.0028548,4.4688706,,,,,,,,,,,,,, -37400,0.97022724,4.0328507,,,,,,,,,,,,,, -37500,0.99260986,3.844193,,,,,,,,,,,,,, -37600,0.90293133,5.9593077,,,,,,,,,,,,,, -37700,0.8114886,3.987616,,,,,,,,,,,,,, -37742,,,0.4121874868869781,2.6718852519989014,0.3816199898719787,2.845273971557617,50000.0,0.3015000224113464,3.418804407119751,10000.0,17260.433208703995,18941.17035579681,17260.433208703995,1677.465342760086,1.268315076828003,0.0 -37800,0.95453435,3.8851337,,,,,,,,,,,,,, -37900,0.7378444,5.9010925,,,,,,,,,,,,,, -38000,1.220489,3.9542909,,,,,,,,,,,,,, -38100,1.0224855,3.9343438,,,,,,,,,,,,,, -38200,1.0586592,3.7928488,,,,,,,,,,,,,, -38300,0.7973225,4.5383406,,,,,,,,,,,,,, -38400,0.9847094,3.834006,,,,,,,,,,,,,, -38500,1.0213363,3.8631365,,,,,,,,,,,,,, -38600,0.91747177,4.7114725,,,,,,,,,,,,,, -38663,,,0.4493359327316284,2.501319169998169,0.3834999799728393,2.8455910682678223,50000.0,0.2939000129699707,3.432493209838867,10000.0,17680.66015148163,19399.91322350502,17680.66015148163,1715.8985350131989,1.3021540641784668,0.0 -38700,0.95493126,3.76758,,,,,,,,,,,,,, -38800,1.0331391,3.89819,,,,,,,,,,,,,, -38900,1.2082547,3.8709292,,,,,,,,,,,,,, -39000,1.0380274,3.739202,,,,,,,,,,,,,, -39100,1.062902,3.8507097,,,,,,,,,,,,,, -39200,0.9068427,3.679257,,,,,,,,,,,,,, -39300,0.9379365,4.1782885,,,,,,,,,,,,,, -39400,0.83738285,5.3153787,,,,,,,,,,,,,, -39500,0.98785335,3.9173353,,,,,,,,,,,,,, -39587,,,0.3864062428474426,2.9066293239593506,0.3610599935054779,3.036476850509644,50000.0,0.2675000131130218,3.623452425003052,10000.0,18100.73324918747,19860.374609947205,18100.73324918747,1756.2016875743866,1.339077711105347,0.0 -39600,0.91187036,4.4203715,,,,,,,,,,,,,, -39700,0.95200944,3.7079413,,,,,,,,,,,,,, -39800,1.2077103,3.6951797,,,,,,,,,,,,,, -39900,1.0638609,4.096092,,,,,,,,,,,,,, -40000,0.9457968,3.792462,,,,,,,,,,,,,, -40100,1.0945491,4.0599513,,,,,,,,,,,,,, -40200,1.1457373,3.657606,,,,,,,,,,,,,, -40300,0.9694944,4.279361,,,,,,,,,,,,,, -40400,1.1775495,3.7131233,,,,,,,,,,,,,, -40500,0.90441805,4.0987673,,,,,,,,,,,,,, -40510,,,0.401660144329071,2.776913642883301,0.3719799816608429,2.9364328384399414,50000.0,0.289900004863739,3.5260701179504395,10000.0,18521.10993671417,20318.680195093155,18521.10993671417,1794.0471782684326,1.3741655349731443,0.0 -40600,0.75331074,5.566181,,,,,,,,,,,,,, -40700,0.8890809,5.9730816,,,,,,,,,,,,,, -40800,1.0019432,3.7281046,,,,,,,,,,,,,, -40900,0.8873243,5.720485,,,,,,,,,,,,,, -41000,1.0168384,3.9580731,,,,,,,,,,,,,, -41100,0.88303524,3.7356703,,,,,,,,,,,,,, -41200,1.4427589,3.9905872,,,,,,,,,,,,,, -41300,0.9214504,5.146849,,,,,,,,,,,,,, -41400,0.9896025,3.6666367,,,,,,,,,,,,,, -41434,,,0.423164039850235,2.5967025756835938,0.3805999755859375,2.832966089248657,50000.0,0.2957000136375427,3.4568240642547607,10000.0,18941.358196496964,20777.212375164032,18941.358196496964,1832.2530777454376,1.404280662536621,0.0 -41500,0.73457927,5.8298106,,,,,,,,,,,,,, -41600,0.6631973,5.9463687,,,,,,,,,,,,,, -41700,0.7495197,5.866023,,,,,,,,,,,,,, -41800,0.8459196,4.928897,,,,,,,,,,,,,, -41900,0.68337023,5.8157573,,,,,,,,,,,,,, -42000,0.8832632,4.6094227,,,,,,,,,,,,,, -42100,0.8809438,3.7426808,,,,,,,,,,,,,, -42200,1.1057673,3.5751295,,,,,,,,,,,,,, -42300,1.1825838,3.7046165,,,,,,,,,,,,,, -42358,,,0.4092773199081421,2.675566434860229,0.3887199759483337,2.8146812915802,50000.0,0.299200028181076,3.452272891998291,10000.0,19361.352199077606,21236.00361180305,19361.352199077606,1870.96342420578,1.441824436187744,0.0 -42400,0.9285173,3.8533635,,,,,,,,,,,,,, -42500,1.1236713,3.7681253,,,,,,,,,,,,,, -42600,0.863732,4.115941,,,,,,,,,,,,,, -42700,0.8710488,4.643573,,,,,,,,,,,,,, -42800,1.0177982,3.8302178,,,,,,,,,,,,,, -42900,0.861486,5.91977,,,,,,,,,,,,,, -43000,0.99237996,3.6501837,,,,,,,,,,,,,, -43100,0.9357534,3.6266484,,,,,,,,,,,,,, -43200,0.86284107,5.077509,,,,,,,,,,,,,, -43280,,,0.4135546684265136,2.665776252746582,0.3831200003623962,2.8275225162506104,50000.0,0.2969000041484833,3.4342548847198486,10000.0,19781.59530377388,21693.95869255066,19781.59530377388,1908.5935270786283,1.474895715713501,0.0 -43300,0.7402072,5.88854,,,,,,,,,,,,,, -43400,1.0293809,3.632482,,,,,,,,,,,,,, -43500,0.81487054,5.842734,,,,,,,,,,,,,, -43600,1.0724455,3.7023191,,,,,,,,,,,,,, -43700,1.1894219,3.68954,,,,,,,,,,,,,, -43800,0.798612,5.68165,,,,,,,,,,,,,, -43900,1.028026,3.8367908,,,,,,,,,,,,,, -44000,1.0122848,3.6787245,,,,,,,,,,,,,, -44100,1.0675461,3.936364,,,,,,,,,,,,,, -44200,0.94227415,4.0529814,,,,,,,,,,,,,, -44204,,,0.4298437535762787,2.616574048995972,0.3955000042915344,2.792757034301758,50000.0,0.305400013923645,3.4004056453704834,10000.0,20201.73545074463,22153.340349674225,20201.73545074463,1947.7522237300875,1.50919771194458,0.0 -44300,1.0002013,4.1451435,,,,,,,,,,,,,, -44400,1.1726916,3.8106833,,,,,,,,,,,,,, -44500,0.942711,3.6295745,,,,,,,,,,,,,, -44600,0.93779075,3.7143831,,,,,,,,,,,,,, -44700,0.76365936,5.82193,,,,,,,,,,,,,, -44800,0.8054341,5.861146,,,,,,,,,,,,,, -44900,0.67543465,5.180569,,,,,,,,,,,,,, -45000,0.8556224,5.392503,,,,,,,,,,,,,, -45100,0.97400635,3.9185011,,,,,,,,,,,,,, -45130,,,0.4178906083106994,2.693882465362549,0.3900199830532074,2.8285815715789795,50000.0,0.3004000186920166,3.436731338500977,10000.0,20621.955298423767,22613.75858616829,20621.955298423767,1987.8704631328585,1.5406432151794434,0.0 -45200,0.84915966,3.5680366,,,,,,,,,,,,,, -45300,0.8185925,5.1446514,,,,,,,,,,,,,, -45400,1.0754498,3.8232493,,,,,,,,,,,,,, -45500,0.73882926,5.311696,,,,,,,,,,,,,, -45600,0.9507821,4.127742,,,,,,,,,,,,,, -45700,1.1106892,3.6825557,,,,,,,,,,,,,, -45800,1.2377228,3.7128997,,,,,,,,,,,,,, -45900,0.92703235,4.286577,,,,,,,,,,,,,, -46000,1.2868134,3.7459538,,,,,,,,,,,,,, -46053,,,0.4268945157527923,2.6220920085906982,0.3996999859809875,2.773547410964966,50000.0,0.3076000213623047,3.3951525688171387,10000.0,21041.939405441284,23074.92538762093,21041.939405441284,2028.9680247306824,1.5772721767425537,0.0 -46100,0.9923983,3.627321,,,,,,,,,,,,,, -46200,0.74650985,5.0991178,,,,,,,,,,,,,, -46300,0.95917946,3.971021,,,,,,,,,,,,,, -46400,0.9900485,4.123592,,,,,,,,,,,,,, -46500,0.9281977,6.000811,,,,,,,,,,,,,, -46600,1.0081555,5.7803,,,,,,,,,,,,,, -46700,0.7300355,5.616311,,,,,,,,,,,,,, -46800,0.916952,3.6349607,,,,,,,,,,,,,, -46900,0.96119785,3.50517,,,,,,,,,,,,,, -46976,,,0.4323046803474426,2.536307573318481,0.405239999294281,2.6976065635681152,50000.0,0.3146000206470489,3.3456552028656006,10000.0,21462.19115138054,23535.27784347534,21462.19115138054,2068.98393702507,1.6132659912109375,0.0 -47000,1.0621858,3.5783653,,,,,,,,,,,,,, -47100,0.88833535,3.5639114,,,,,,,,,,,,,, -47200,1.0959018,3.6212142,,,,,,,,,,,,,, -47300,1.0596956,3.7610354,,,,,,,,,,,,,, -47400,0.99200535,3.4275463,,,,,,,,,,,,,, -47500,1.0585154,3.529731,,,,,,,,,,,,,, -47600,1.1032639,3.5934088,,,,,,,,,,,,,, -47700,1.2581664,3.6705818,,,,,,,,,,,,,, -47800,0.7665672,5.906576,,,,,,,,,,,,,, -47900,0.7352021,4.148534,,,,,,,,,,,,,, -47901,,,0.4513085782527923,2.50014066696167,0.3981199860572815,2.7976644039154053,50000.0,0.2980000078678131,3.4122321605682373,10000.0,21882.690141916275,23995.60861515999,21882.690141916275,2108.731509923935,1.6486258506774902,0.0 -48000,0.98574847,5.8506002,,,,,,,,,,,,,, -48100,0.98846525,3.5049386,,,,,,,,,,,,,, -48200,1.0131544,4.4947133,,,,,,,,,,,,,, -48300,1.1782454,3.66011,,,,,,,,,,,,,, -48400,1.145505,3.8830376,,,,,,,,,,,,,, -48500,1.0161175,3.7994642,,,,,,,,,,,,,, -48600,1.0149015,3.6945581,,,,,,,,,,,,,, -48700,1.1097399,3.6022024,,,,,,,,,,,,,, -48800,1.0025939,4.0237446,,,,,,,,,,,,,, -48825,,,0.4263085722923279,2.605025053024292,0.4018400013446808,2.7520787715911865,50000.0,0.3126000165939331,3.3366458415985107,10000.0,22302.164969682693,24458.163280963898,22302.164969682693,2151.260172843933,2.1510133743286133,0.0 -48900,0.88895375,4.4418364,,,,,,,,,,,,,, -49000,0.75512445,5.1334414,,,,,,,,,,,,,, -49100,0.83829075,5.750117,,,,,,,,,,,,,, -49200,1.0877352,3.7767549,,,,,,,,,,,,,, -49300,1.005882,3.844298,,,,,,,,,,,,,, -49400,1.0359585,3.7129717,,,,,,,,,,,,,, -49500,0.98803353,3.5132246,,,,,,,,,,,,,, -49600,0.9983302,3.7552512,,,,,,,,,,,,,, -49700,0.95793015,5.8622656,,,,,,,,,,,,,, -49748,,,0.4362890422344208,2.5398664474487305,0.4076199829578399,2.714684247970581,50000.0,0.31700000166893,3.326667070388794,10000.0,22722.07783985138,24922.15366792679,22722.07783985138,2195.242655277252,2.196490526199341,0.0 -49800,1.0309058,3.7117875,,,,,,,,,,,,,, -49900,0.8985425,5.3298664,,,,,,,,,,,,,, -50000,1.0667986,6.0195494,,,,,,,,,,,,,, -50100,0.9958615,3.564746,,,,,,,,,,,,,, -50200,0.9185058,4.706111,,,,,,,,,,,,,, -50300,1.0829109,3.5152729,,,,,,,,,,,,,, -50400,0.96275836,5.9268804,,,,,,,,,,,,,, -50500,1.0829011,3.6076012,,,,,,,,,,,,,, -50600,1.1033812,3.6576564,,,,,,,,,,,,,, -50670,,,0.4491601586341858,2.490548610687256,0.402319997549057,2.752784013748169,50000.0,0.3070000112056732,3.3801825046539307,10000.0,23142.06842494011,25379.83161520958,23142.06842494011,2232.8479537963867,2.2305619716644287,0.0 -50700,1.0347688,3.8535275,,,,,,,,,,,,,, -50800,1.4309311,3.6359832,,,,,,,,,,,,,, -50900,1.0398384,5.8885555,,,,,,,,,,,,,, -51000,0.9657848,3.8384929,,,,,,,,,,,,,, -51100,0.9177467,3.7108374,,,,,,,,,,,,,, -51200,0.9020129,3.4774172,,,,,,,,,,,,,, -51300,0.88545823,4.0349216,,,,,,,,,,,,,, -51400,1.036375,3.5817673,,,,,,,,,,,,,, -51500,1.179483,3.734589,,,,,,,,,,,,,, -51593,,,0.4297265410423279,2.5958352088928223,0.4063799977302551,2.737758159637451,50000.0,0.3153000175952911,3.344174385070801,10000.0,23562.160203695297,25843.881860017776,23562.160203695297,2276.712729215622,2.2754602432250977,0.0 -51600,0.93390876,3.626811,,,,,,,,,,,,,, -51700,0.9096067,3.9600334,,,,,,,,,,,,,, -51800,0.7403327,5.7626595,,,,,,,,,,,,,, -51900,0.927046,5.401924,,,,,,,,,,,,,, -52000,1.0909066,5.977151,,,,,,,,,,,,,, -52100,1.3127271,3.6989546,,,,,,,,,,,,,, -52200,1.3831378,3.5708437,,,,,,,,,,,,,, -52300,1.0792074,3.589654,,,,,,,,,,,,,, -52400,0.81764895,3.9840899,,,,,,,,,,,,,, -52500,1.1774577,3.5569012,,,,,,,,,,,,,, -52517,,,0.4399023354053497,2.52112889289856,0.4089199900627136,2.6849489212036133,50000.0,0.3210000097751617,3.283968210220337,10000.0,23982.16872549057,26306.350472450256,23982.16872549057,2319.089858531952,2.310560941696167,0.0 -52600,0.82054245,4.6310797,,,,,,,,,,,,,, -52700,0.8972247,5.829413,,,,,,,,,,,,,, -52800,0.98621464,3.6593606,,,,,,,,,,,,,, -52900,1.0728554,3.3688173,,,,,,,,,,,,,, -53000,1.0567961,3.6242945,,,,,,,,,,,,,, -53100,1.0485088,3.5955596,,,,,,,,,,,,,, -53200,0.9286081,3.6091506,,,,,,,,,,,,,, -53300,1.0754659,3.6992645,,,,,,,,,,,,,, -53400,1.0184685,3.471205,,,,,,,,,,,,,, -53440,,,0.4496484398841858,2.463557720184326,0.4154399931430816,2.6775717735290527,50000.0,0.3276000022888183,3.283280611038208,10000.0,24402.40174674988,26766.47193908692,24402.40174674988,2358.8874881267548,2.348323106765747,0.0 -53500,0.9777777,5.8297234,,,,,,,,,,,,,, -53600,0.7422197,5.6373725,,,,,,,,,,,,,, -53700,1.0570472,3.5458844,,,,,,,,,,,,,, -53800,0.92389435,4.0335817,,,,,,,,,,,,,, -53900,0.8960961,4.3097525,,,,,,,,,,,,,, -54000,1.0083438,3.5245135,,,,,,,,,,,,,, -54100,1.3950592,3.6587956,,,,,,,,,,,,,, -54200,1.0751499,3.5300438,,,,,,,,,,,,,, -54300,0.91534287,5.7937336,,,,,,,,,,,,,, -54363,,,0.4367187321186065,2.556265592575073,0.4078799784183502,2.71284294128418,50000.0,0.3160000145435333,3.333500862121582,10000.0,24822.649163007736,27226.17301273346,24822.649163007736,2398.257269382477,2.3831875324249268,0.0 -54400,0.9849993,5.021352,,,,,,,,,,,,,, -54500,0.7604299,5.547454,,,,,,,,,,,,,, -54600,1.0193809,3.6981664,,,,,,,,,,,,,, -54700,0.73417675,5.1952963,,,,,,,,,,,,,, -54800,0.89851874,3.7715187,,,,,,,,,,,,,, -54900,0.92235607,3.8696835,,,,,,,,,,,,,, -55000,0.9820533,3.5693054,,,,,,,,,,,,,, -55100,1.0035207,3.5880747,,,,,,,,,,,,,, -55200,0.8686074,5.887261,,,,,,,,,,,,,, -55287,,,0.4478320181369781,2.4865450859069824,0.4236199855804443,2.6358258724212646,50000.0,0.3257000148296356,3.252009153366089,10000.0,25242.692935943604,27688.975242853165,25242.692935943604,2440.921109676361,2.425750970840454,0.0 -55300,1.0421109,3.554968,,,,,,,,,,,,,, -55400,0.89555115,4.0921617,,,,,,,,,,,,,, -55500,0.957026,5.439595,,,,,,,,,,,,,, -55600,1.0145212,3.7523332,,,,,,,,,,,,,, -55700,1.1045284,3.718773,,,,,,,,,,,,,, -55800,1.0691164,4.040621,,,,,,,,,,,,,, -55900,1.2548308,3.6849897,,,,,,,,,,,,,, -56000,0.95639116,3.7414389,,,,,,,,,,,,,, -56100,1.074103,3.7386875,,,,,,,,,,,,,, -56200,0.9062052,3.4778388,,,,,,,,,,,,,, -56210,,,0.4586718678474426,2.4052109718322754,0.421559989452362,2.599883794784546,50000.0,0.326200008392334,3.243481397628784,10000.0,25662.820051670074,28149.4397380352,25662.820051670074,2481.1670627593994,2.4647953510284424,0.0 -56300,1.0277143,3.4348345,,,,,,,,,,,,,, -56400,1.0327018,3.4685223,,,,,,,,,,,,,, -56500,1.0678996,3.6208053,,,,,,,,,,,,,, -56600,1.1230822,3.652601,,,,,,,,,,,,,, -56700,0.84058356,5.368209,,,,,,,,,,,,,, -56800,1.0145246,3.5157294,,,,,,,,,,,,,, -56900,1.1670253,3.635631,,,,,,,,,,,,,, -57000,0.963997,3.4198575,,,,,,,,,,,,,, -57100,1.1459386,3.481813,,,,,,,,,,,,,, -57131,,,0.4611132740974426,2.4022233486175537,0.4213999807834625,2.619320869445801,50000.0,0.329800009727478,3.234283447265625,10000.0,26083.14991784096,28611.14157915116,26083.14991784096,2522.4550380706787,2.499359130859375,0.0 -57200,0.96657014,3.4879713,,,,,,,,,,,,,, -57300,1.1081977,3.4280753,,,,,,,,,,,,,, -57400,0.98523796,3.8375726,,,,,,,,,,,,,, -57500,0.7535532,4.926066,,,,,,,,,,,,,, -57600,1.0385897,3.4718823,,,,,,,,,,,,,, -57700,0.98924875,3.7385352,,,,,,,,,,,,,, -57800,1.106545,3.4933903,,,,,,,,,,,,,, -57900,1.1342014,3.6619186,,,,,,,,,,,,,, -58000,1.0078453,5.824627,,,,,,,,,,,,,, -58055,,,0.4460351467132568,2.520827054977417,0.4186599850654602,2.6796395778656006,50000.0,0.3215000033378601,3.2995548248291016,10000.0,26503.37421274185,29074.28104519844,26503.37421274185,2565.287081718445,2.5334596633911133,0.0 -58100,0.8330503,5.8787146,,,,,,,,,,,,,, -58200,1.0429332,3.5217206,,,,,,,,,,,,,, -58300,0.88063276,4.627369,,,,,,,,,,,,,, -58400,1.2151852,3.60236,,,,,,,,,,,,,, -58500,1.1443682,3.7295918,,,,,,,,,,,,,, -58600,1.1348802,3.5892315,,,,,,,,,,,,,, -58700,0.8163235,5.8257184,,,,,,,,,,,,,, -58800,1.17471,3.5199676,,,,,,,,,,,,,, -58900,0.76131976,5.145465,,,,,,,,,,,,,, -58979,,,0.4637890458106994,2.4174797534942627,0.4327999949455261,2.582634925842285,50000.0,0.3387000262737274,3.2149226665496826,10000.0,26923.577298164368,29535.1870803833,26923.577298164368,2605.8996703624725,2.574249744415283,0.0 -59000,0.8188513,5.0246935,,,,,,,,,,,,,, -59100,1.0448385,5.405388,,,,,,,,,,,,,, -59200,0.94672596,3.6369073,,,,,,,,,,,,,, -59300,1.1382786,3.5516284,,,,,,,,,,,,,, -59400,1.1291555,3.4776762,,,,,,,,,,,,,, -59500,0.92255974,5.3039384,,,,,,,,,,,,,, -59600,1.055338,3.4810624,,,,,,,,,,,,,, -59700,0.95643336,3.8407562,,,,,,,,,,,,,, -59800,0.98383504,5.4129663,,,,,,,,,,,,,, -59900,0.99437106,3.409141,,,,,,,,,,,,,, -59903,,,0.478808581829071,2.342374563217163,0.4243399798870086,2.625638961791992,50000.0,0.3347000181674957,3.231863021850586,10000.0,27343.765065908432,29996.09737586975,27343.765065908432,2646.5336713790894,2.614187479019165,0.0 -60000,1.0258608,3.8891404,,,,,,,,,,,,,, -60100,0.9499396,4.370033,,,,,,,,,,,,,, -60200,0.9840912,4.0497417,,,,,,,,,,,,,, -60300,1.0918481,3.2995749,,,,,,,,,,,,,, -60400,0.99270207,4.246805,,,,,,,,,,,,,, -60500,1.1550485,3.5937274,,,,,,,,,,,,,, -60600,0.83966625,5.426981,,,,,,,,,,,,,, -60700,1.0890616,3.5803447,,,,,,,,,,,,,, -60800,0.7687486,4.8676195,,,,,,,,,,,,,, -60824,,,0.4606054723262787,2.44146728515625,0.428659975528717,2.598587989807129,50000.0,0.3337000012397766,3.2204599380493164,10000.0,27763.803510665894,30458.07638692856,27763.803510665894,2688.3805034160614,2.658954620361328,0.0 -60900,1.2585032,3.8663032,,,,,,,,,,,,,, -61000,1.1531408,3.5905197,,,,,,,,,,,,,, -61100,1.00122,3.5706227,,,,,,,,,,,,,, -61200,1.32664,3.471226,,,,,,,,,,,,,, -61300,0.95984316,3.8890862,,,,,,,,,,,,,, -61400,0.9107776,3.5332935,,,,,,,,,,,,,, -61500,0.9428238,5.127219,,,,,,,,,,,,,, -61600,1.1953437,3.529088,,,,,,,,,,,,,, -61700,1.2358949,3.4054596,,,,,,,,,,,,,, -61748,,,0.4458398222923279,2.5256237983703613,0.4161399900913238,2.6907958984375,50000.0,0.3223000168800354,3.279371738433838,10000.0,28183.8014895916,30917.355994701385,28183.8014895916,2727.574634075165,2.6977574825286865,0.0 -61800,1.0307678,3.6635942,,,,,,,,,,,,,, -61900,0.94089216,4.318039,,,,,,,,,,,,,, -62000,1.1796731,3.5643663,,,,,,,,,,,,,, -62100,0.9140764,3.9663055,,,,,,,,,,,,,, -62200,1.0639538,3.4915104,,,,,,,,,,,,,, -62300,1.055549,3.4227042,,,,,,,,,,,,,, -62400,1.3659239,3.5423899,,,,,,,,,,,,,, -62500,0.761949,5.6344247,,,,,,,,,,,,,, -62600,1.0214536,3.7135773,,,,,,,,,,,,,, -62670,,,0.4750781059265136,2.330697059631348,0.4386599957942962,2.538495779037476,50000.0,0.3391000032424927,3.1987009048461914,10000.0,28603.97628927231,31379.02168869972,28603.97628927231,2768.981840610504,2.7327027320861816,0.0 -62700,1.0337404,3.6923091,,,,,,,,,,,,,, -62800,0.90074605,4.742394,,,,,,,,,,,,,, -62900,1.0176909,3.4656148,,,,,,,,,,,,,, -63000,0.88380885,4.407897,,,,,,,,,,,,,, -63100,1.0368327,3.2272892,,,,,,,,,,,,,, -63200,0.90628535,5.866467,,,,,,,,,,,,,, -63300,0.81638145,5.655496,,,,,,,,,,,,,, -63400,1.1378341,3.7530687,,,,,,,,,,,,,, -63500,0.9918745,3.612386,,,,,,,,,,,,,, -63592,,,0.4749414026737213,2.3414981365203857,0.44200000166893,2.4970359802246094,50000.0,0.3415000140666961,3.151350259780884,10000.0,29023.923259735107,31837.474231004715,29023.923259735107,2807.4011034965515,2.7702078819274902,0.0 -63600,1.3174912,3.9687114,,,,,,,,,,,,,, -63700,1.1623954,3.651843,,,,,,,,,,,,,, -63800,0.9819686,3.344381,,,,,,,,,,,,,, -63900,0.89946175,3.8575616,,,,,,,,,,,,,, -64000,0.79485464,5.5428686,,,,,,,,,,,,,, -64100,0.87277627,5.801633,,,,,,,,,,,,,, -64200,0.8237398,4.551134,,,,,,,,,,,,,, -64300,0.92116374,5.558177,,,,,,,,,,,,,, -64400,1.1218487,3.502835,,,,,,,,,,,,,, -64500,1.1381861,3.374325,,,,,,,,,,,,,, -64514,,,0.4628320336341858,2.441080331802368,0.4318199753761291,2.605339527130127,50000.0,0.3314000070095062,3.231776237487793,10000.0,29443.90414404869,32297.55117797852,29443.90414404869,2847.411164045334,2.807270765304565,0.0 -64600,0.9979322,3.5642557,,,,,,,,,,,,,, -64700,1.2852002,3.511495,,,,,,,,,,,,,, -64800,1.0835171,3.438896,,,,,,,,,,,,,, -64900,0.9298335,5.7178345,,,,,,,,,,,,,, -65000,0.92145985,5.757779,,,,,,,,,,,,,, -65100,0.8642601,5.7449903,,,,,,,,,,,,,, -65200,0.8080079,5.771218,,,,,,,,,,,,,, -65300,1.0003382,3.4745214,,,,,,,,,,,,,, -65400,1.0519857,3.3322332,,,,,,,,,,,,,, -65437,,,0.4589453041553497,2.488585948944092,0.4240599870681762,2.6650540828704834,50000.0,0.3297000229358673,3.290002822875977,10000.0,29864.123959302902,32757.831778764725,29864.123959302902,2887.384506225586,2.84656310081482,0.0 -65500,1.1402619,3.5285268,,,,,,,,,,,,,, -65600,0.91107875,5.791683,,,,,,,,,,,,,, -65700,1.09298,3.402423,,,,,,,,,,,,,, -65800,1.014268,4.2488856,,,,,,,,,,,,,, -65900,0.66527927,5.75437,,,,,,,,,,,,,, -66000,1.0174521,3.4203084,,,,,,,,,,,,,, -66100,1.1383051,3.529345,,,,,,,,,,,,,, -66200,1.0342625,3.7708845,,,,,,,,,,,,,, -66300,1.0891455,3.7558522,,,,,,,,,,,,,, -66360,,,0.467089831829071,2.380722045898437,0.4341999888420105,2.5594401359558105,50000.0,0.3379000127315521,3.1919198036193848,10000.0,30284.387431383133,33220.57291150093,30284.387431383133,2929.774935245514,2.8840224742889404,0.0 -66400,0.96183014,4.048172,,,,,,,,,,,,,, -66500,1.002251,5.5629787,,,,,,,,,,,,,, -66600,1.2281672,3.3337705,,,,,,,,,,,,,, -66700,1.4769586,3.5339973,,,,,,,,,,,,,, -66800,1.0891507,3.3136544,,,,,,,,,,,,,, -66900,0.89098334,5.450243,,,,,,,,,,,,,, -67000,0.83598787,5.715145,,,,,,,,,,,,,, -67100,0.8605083,4.716804,,,,,,,,,,,,,, -67200,0.96659786,5.845385,,,,,,,,,,,,,, -67284,,,0.4660546779632568,2.3730216026306152,0.4387199878692627,2.5333032608032227,50000.0,0.3418000042438507,3.1739964485168457,10000.0,30704.6074051857,33683.47833299637,30704.6074051857,2972.3694083690643,2.92598557472229,0.0 -67300,1.3667587,3.4690506,,,,,,,,,,,,,, -67400,1.146221,3.8146272,,,,,,,,,,,,,, -67500,0.91746813,4.347039,,,,,,,,,,,,,, -67600,0.84537566,4.36427,,,,,,,,,,,,,, -67700,0.8772973,4.505682,,,,,,,,,,,,,, -67800,1.166393,3.4254968,,,,,,,,,,,,,, -67900,0.9791271,3.3952887,,,,,,,,,,,,,, -68000,1.0388776,4.1710434,,,,,,,,,,,,,, -68100,1.0936519,5.7067018,,,,,,,,,,,,,, -68200,1.1015197,3.345111,,,,,,,,,,,,,, -68207,,,0.4681445062160492,2.362294435501098,0.4355399906635284,2.5348927974700928,50000.0,0.3398000299930572,3.1727473735809326,10000.0,31124.684980630875,34140.442061424255,31124.684980630875,3009.159858226776,2.972452402114868,0.0 -68300,1.0930728,3.590509,,,,,,,,,,,,,, -68400,0.9483489,4.403062,,,,,,,,,,,,,, -68500,0.84741503,5.4503913,,,,,,,,,,,,,, -68600,1.070538,3.408006,,,,,,,,,,,,,, -68700,0.933756,5.7583523,,,,,,,,,,,,,, -68800,1.0716169,3.3328407,,,,,,,,,,,,,, -68900,1.2364125,3.399455,,,,,,,,,,,,,, -69000,1.142698,3.2993472,,,,,,,,,,,,,, -69100,0.8353824,4.025131,,,,,,,,,,,,,, -69130,,,0.4894726574420929,2.2886223793029785,0.4311199784278869,2.5940520763397217,50000.0,0.3315000236034393,3.226326704025269,10000.0,31545.37599992752,34600.265894174576,31545.37599992752,3048.2066645622253,3.0094172954559326,0.0 -69200,1.0350384,3.346495,,,,,,,,,,,,,, -69300,0.77900493,5.2445784,,,,,,,,,,,,,, -69400,0.9367516,5.763915,,,,,,,,,,,,,, -69500,1.0245386,3.343081,,,,,,,,,,,,,, -69600,1.0930868,3.3614347,,,,,,,,,,,,,, -69700,1.1639526,3.2485695,,,,,,,,,,,,,, -69800,0.8537641,5.156531,,,,,,,,,,,,,, -69900,0.94096285,5.3836365,,,,,,,,,,,,,, -70000,1.1734362,3.5497756,,,,,,,,,,,,,, -70053,,,0.485644519329071,2.2598440647125244,0.4579599797725677,2.3936898708343506,50000.0,0.356900006532669,3.05293607711792,10000.0,31965.58531689644,35062.400102853775,31965.58531689644,3090.044335842133,3.047203779220581,0.0 -70100,0.9344618,4.0780687,,,,,,,,,,,,,, -70200,1.2072477,3.7409205,,,,,,,,,,,,,, -70300,1.2745802,3.283952,,,,,,,,,,,,,, -70400,1.0088844,3.298495,,,,,,,,,,,,,, -70500,1.0487922,3.207415,,,,,,,,,,,,,, -70600,1.0082227,3.3367116,,,,,,,,,,,,,, -70700,1.0283837,3.8303182,,,,,,,,,,,,,, -70800,0.89445066,4.409361,,,,,,,,,,,,,, -70900,0.7839945,4.8901086,,,,,,,,,,,,,, -70977,,,0.482421875,2.2963407039642334,0.4460999965667724,2.483259916305542,50000.0,0.349700003862381,3.116337776184082,10000.0,32385.602464437485,35520.32847523689,32385.602464437485,3127.865446805954,3.08792495727539,0.0 -71000,0.78616244,5.507391,,,,,,,,,,,,,, -71100,0.908653,4.779861,,,,,,,,,,,,,, -71200,1.0126517,3.3490517,,,,,,,,,,,,,, -71300,1.0810882,3.6968558,,,,,,,,,,,,,, -71400,0.8771388,5.1759915,,,,,,,,,,,,,, -71500,1.0805435,3.5135825,,,,,,,,,,,,,, -71600,1.2808338,3.4012043,,,,,,,,,,,,,, -71700,1.1217557,3.7989428,,,,,,,,,,,,,, -71800,0.8790468,5.7603345,,,,,,,,,,,,,, -71899,,,0.4937304556369781,2.219071626663208,0.4499399960041046,2.446985006332397,50000.0,0.3488000035285949,3.096506118774414,10000.0,32805.70732951164,35982.90106034279,32805.70732951164,3170.2405710220337,3.131728172302246,0.0 -71900,1.1282934,3.3566086,,,,,,,,,,,,,, -72000,0.8805299,5.4946437,,,,,,,,,,,,,, -72100,1.0432407,3.4989154,,,,,,,,,,,,,, -72200,1.1257793,3.642952,,,,,,,,,,,,,, -72300,0.9035068,5.5165668,,,,,,,,,,,,,, -72400,1.1089963,3.305917,,,,,,,,,,,,,, -72500,1.0061572,4.2563367,,,,,,,,,,,,,, -72600,0.9041531,5.556142,,,,,,,,,,,,,, -72700,0.9111258,5.3213105,,,,,,,,,,,,,, -72800,0.916454,5.523387,,,,,,,,,,,,,, -72822,,,0.4931640625,2.2360095977783203,0.4639399945735931,2.3808131217956543,50000.0,0.3605000078678131,3.04154372215271,10000.0,33225.77615022659,36445.34771442413,33225.77615022659,3212.5319378376007,3.1703007221221924,0.0 -72900,1.0789064,3.3987744,,,,,,,,,,,,,, -73000,0.70944875,5.2755766,,,,,,,,,,,,,, -73100,1.2695943,3.3101718,,,,,,,,,,,,,, -73200,1.3273574,3.4713786,,,,,,,,,,,,,, -73300,1.4082006,3.4001412,,,,,,,,,,,,,, -73400,0.8368727,4.963084,,,,,,,,,,,,,, -73500,1.1079677,3.3066416,,,,,,,,,,,,,, -73600,0.9223207,5.187183,,,,,,,,,,,,,, -73700,0.9241051,4.27193,,,,,,,,,,,,,, -73744,,,0.4816796779632568,2.3041117191314697,0.452019989490509,2.4658422470092773,50000.0,0.3540000021457672,3.107847213745117,10000.0,33646.14836072922,36902.12452173233,33646.14836072922,3248.851069688797,3.207165241241455,0.0 -73800,1.1253672,3.3997514,,,,,,,,,,,,,, -73900,1.2735846,3.3008873,,,,,,,,,,,,,, -74000,0.9890273,3.683381,,,,,,,,,,,,,, -74100,1.028779,3.343028,,,,,,,,,,,,,, -74200,1.6388369,3.3149471,,,,,,,,,,,,,, -74300,1.1835489,3.9997072,,,,,,,,,,,,,, -74400,1.1719126,3.3501582,,,,,,,,,,,,,, -74500,0.993616,5.7750463,,,,,,,,,,,,,, -74600,1.0136169,3.56454,,,,,,,,,,,,,, -74668,,,0.5001562237739563,2.175226926803589,0.4656799733638763,2.3801429271698,50000.0,0.3610000312328338,3.050954341888428,10000.0,34066.290695905685,37362.36349225044,34066.290695905685,3288.8602344989777,3.24563980102539,0.0 -74700,0.84578806,5.770762,,,,,,,,,,,,,, -74800,1.1822624,3.4461975,,,,,,,,,,,,,, -74900,0.8888867,4.6765356,,,,,,,,,,,,,, -75000,1.1076388,3.9907446,,,,,,,,,,,,,, -75100,1.1097894,3.3456259,,,,,,,,,,,,,, -75200,1.0120466,3.328835,,,,,,,,,,,,,, -75300,1.0568292,3.2574906,,,,,,,,,,,,,, -75400,0.89559764,5.6017942,,,,,,,,,,,,,, -75500,1.0811903,3.344496,,,,,,,,,,,,,, -75590,,,0.4973046779632568,2.2028791904449463,0.4642199873924255,2.3795573711395264,50000.0,0.3633000254631042,3.0194509029388428,10000.0,34486.66932654381,37825.93365931511,34486.66932654381,3331.9556045532227,3.292352199554444,0.0 -75600,1.0581739,3.67198,,,,,,,,,,,,,, -75700,1.053945,3.2756495,,,,,,,,,,,,,, -75800,1.0847735,3.7621145,,,,,,,,,,,,,, -75900,1.3406662,3.272534,,,,,,,,,,,,,, -76000,1.201217,3.2014468,,,,,,,,,,,,,, -76100,1.3463705,3.4495058,,,,,,,,,,,,,, -76200,0.80339825,5.505494,,,,,,,,,,,,,, -76300,1.0246463,3.3840904,,,,,,,,,,,,,, -76400,1.1481179,3.2597923,,,,,,,,,,,,,, -76500,0.9127926,3.6276276,,,,,,,,,,,,,, -76511,,,0.4972656071186065,2.211764097213745,0.464739978313446,2.384872436523437,50000.0,0.3605000078678131,3.0538387298583984,10000.0,34906.83700180054,38284.10998988152,34906.83700180054,3369.877052307129,3.3313422203063965,0.0 -76600,0.9676514,5.697114,,,,,,,,,,,,,, -76700,1.1840652,3.4033296,,,,,,,,,,,,,, -76800,0.8966729,3.8475704,,,,,,,,,,,,,, -76900,1.0818524,3.2807682,,,,,,,,,,,,,, -77000,1.0509171,3.2692456,,,,,,,,,,,,,, -77100,1.2475208,3.6127594,,,,,,,,,,,,,, -77200,0.9004393,5.8198614,,,,,,,,,,,,,, -77300,0.9796191,4.4414134,,,,,,,,,,,,,, -77400,1.0451201,3.2871015,,,,,,,,,,,,,, -77432,,,0.5082421898841858,2.163210153579712,0.4681800007820129,2.3664684295654297,50000.0,0.3682000041007995,3.008074760437012,10000.0,35326.80542087555,38746.58549690247,35326.80542087555,3412.295135498047,3.371867179870605,0.0 -77500,0.9170125,4.260623,,,,,,,,,,,,,, -77600,1.1257257,3.369975,,,,,,,,,,,,,, -77700,1.0659351,3.5731745,,,,,,,,,,,,,, -77800,1.0872568,3.3477373,,,,,,,,,,,,,, -77900,1.1894965,3.2664106,,,,,,,,,,,,,, -78000,0.91029954,4.6999216,,,,,,,,,,,,,, -78100,1.0118781,3.3108726,,,,,,,,,,,,,, -78200,1.2135918,3.2916129,,,,,,,,,,,,,, -78300,0.9749296,3.2557878,,,,,,,,,,,,,, -78356,,,0.5389648079872131,2.0390491485595703,0.4726599752902984,2.385406494140625,50000.0,0.3723000288009643,3.007336378097534,10000.0,35746.876959085464,39210.46626496315,35746.876959085464,3456.013298511505,3.413897752761841,0.0 -78400,0.98360026,5.1335506,,,,,,,,,,,,,, -78500,1.0757447,3.4140759,,,,,,,,,,,,,, -78600,1.1065032,3.2436187,,,,,,,,,,,,,, -78700,1.2029066,3.2637205,,,,,,,,,,,,,, -78800,1.0324984,3.5051131,,,,,,,,,,,,,, -78900,0.8555227,4.34658,,,,,,,,,,,,,, -79000,1.0524173,3.4635997,,,,,,,,,,,,,, -79100,1.1327064,3.2980058,,,,,,,,,,,,,, -79200,0.86229944,4.770926,,,,,,,,,,,,,, -79280,,,0.5073632597923279,2.175447940826416,0.4732399880886078,2.356532573699951,50000.0,0.3680000305175781,2.9997689723968506,10000.0,36166.997678518295,39672.47076559067,36166.997678518295,3497.8065342903137,3.455860614776612,0.0 -79300,1.2644209,3.3385503,,,,,,,,,,,,,, -79400,1.0524645,3.0581892,,,,,,,,,,,,,, -79500,1.2155924,3.3149633,,,,,,,,,,,,,, -79600,0.95990795,3.392847,,,,,,,,,,,,,, -79700,0.9270567,4.0704103,,,,,,,,,,,,,, -79800,0.8661671,5.1390276,,,,,,,,,,,,,, -79900,0.801765,4.4422255,,,,,,,,,,,,,, -80000,1.0110688,3.467477,,,,,,,,,,,,,, -80100,1.138302,3.3196137,,,,,,,,,,,,,, -80200,1.3058892,3.1973362,,,,,,,,,,,,,, -80204,,,0.5040234327316284,2.234501361846924,0.469760000705719,2.405181884765625,50000.0,0.3631000220775604,3.042214870452881,10000.0,36587.06204032898,40131.91340637207,36587.06204032898,3537.0906777381897,3.501105070114136,0.0 -80300,0.97667336,3.6375296,,,,,,,,,,,,,, -80400,0.9142707,4.189972,,,,,,,,,,,,,, -80500,1.1397941,3.2590215,,,,,,,,,,,,,, -80600,1.053978,5.6939926,,,,,,,,,,,,,, -80700,1.0695432,3.2354434,,,,,,,,,,,,,, -80800,0.97056645,3.31277,,,,,,,,,,,,,, -80900,1.079496,3.5394378,,,,,,,,,,,,,, -81000,1.0002891,3.709034,,,,,,,,,,,,,, -81100,1.0493412,3.173347,,,,,,,,,,,,,, -81127,,,0.5310351252555847,2.0206193923950195,0.484059989452362,2.27791166305542,50000.0,0.3806000053882599,2.9327199459075928,10000.0,37007.36421537399,40593.58241772652,37007.36421537399,3578.367530584336,3.5425891876220703,0.0 -81200,0.91517335,4.9617167,,,,,,,,,,,,,, -81300,0.983263,5.178178,,,,,,,,,,,,,, -81400,0.96324027,5.6844215,,,,,,,,,,,,,, -81500,0.8840748,4.172225,,,,,,,,,,,,,, -81600,1.0145392,3.9904885,,,,,,,,,,,,,, -81700,0.96768355,4.5638695,,,,,,,,,,,,,, -81800,1.274526,3.383729,,,,,,,,,,,,,, -81900,0.99326795,3.1874819,,,,,,,,,,,,,, -82000,1.0411226,3.1060803,,,,,,,,,,,,,, -82046,,,0.5068749785423279,2.1524007320404053,0.4747999906539917,2.324914693832397,50000.0,0.3708000183105469,2.9745867252349854,10000.0,37427.59137535095,41055.39542388916,37427.59137535095,3619.866011381149,3.581484079360962,0.0 -82100,1.3161789,3.377957,,,,,,,,,,,,,, -82200,0.9419672,4.2802706,,,,,,,,,,,,,, -82300,1.0154483,3.9080937,,,,,,,,,,,,,, -82400,1.1707498,3.27709,,,,,,,,,,,,,, -82500,1.0782313,3.2027504,,,,,,,,,,,,,, -82600,1.0246334,3.6886566,,,,,,,,,,,,,, -82700,0.8663761,4.390079,,,,,,,,,,,,,, -82800,0.8703515,4.3729525,,,,,,,,,,,,,, -82900,0.77953446,5.196924,,,,,,,,,,,,,, -82966,,,0.5151953101158142,2.132296800613404,0.4791599810123443,2.317836284637451,50000.0,0.3758000135421753,2.963873147964477,10000.0,37847.52794003487,41518.515270233154,37847.52794003487,3662.954482316971,3.627816438674927,0.0 -83000,1.2145611,3.2875147,,,,,,,,,,,,,, -83100,0.98695606,5.5090895,,,,,,,,,,,,,, -83200,1.1663836,3.2283857,,,,,,,,,,,,,, -83300,1.1149127,3.2870893,,,,,,,,,,,,,, -83400,1.0522245,3.273655,,,,,,,,,,,,,, -83500,1.2675802,3.1749954,,,,,,,,,,,,,, -83600,1.2788235,3.4561307,,,,,,,,,,,,,, -83700,1.1071053,3.5688186,,,,,,,,,,,,,, -83800,1.0709335,3.2472656,,,,,,,,,,,,,, -83888,,,0.5267773270606995,2.065271615982056,0.4811999797821045,2.2873899936676025,50000.0,0.3775000274181366,2.935068368911743,10000.0,38267.72216629982,41979.064101696014,38267.72216629982,3703.213614225388,3.6741631031036377,0.0 -83900,0.84480166,5.185192,,,,,,,,,,,,,, -84000,1.2240161,3.2070246,,,,,,,,,,,,,, -84100,1.1210269,3.1355636,,,,,,,,,,,,,, -84200,1.218932,3.4535828,,,,,,,,,,,,,, -84300,1.0588877,3.0734384,,,,,,,,,,,,,, -84400,1.0420303,3.1467252,,,,,,,,,,,,,, -84500,1.0508991,3.017173,,,,,,,,,,,,,, -84600,1.4050102,3.4335341,,,,,,,,,,,,,, -84700,0.88152105,4.463702,,,,,,,,,,,,,, -84800,1.1890137,3.3617897,,,,,,,,,,,,,, -84811,,,0.52587890625,2.056821823120117,0.4928999841213226,2.2359001636505127,50000.0,0.3895000219345093,2.879626035690308,10000.0,38687.88592624664,42439.90636992455,38687.88592624664,3743.798425197601,3.718817710876465,0.0 -84900,1.2507305,3.064208,,,,,,,,,,,,,, -85000,1.1825438,3.3392236,,,,,,,,,,,,,, -85100,0.91180366,4.050479,,,,,,,,,,,,,, -85200,1.1612775,3.2293289,,,,,,,,,,,,,, -85300,0.8891635,5.0345964,,,,,,,,,,,,,, -85400,1.0652126,4.3420444,,,,,,,,,,,,,, -85500,1.1982805,3.4902568,,,,,,,,,,,,,, -85600,1.0974768,3.6400518,,,,,,,,,,,,,, -85700,0.9950403,3.386143,,,,,,,,,,,,,, -85733,,,0.517382800579071,2.116588830947876,0.4849999845027923,2.2845561504364014,50000.0,0.3757000267505646,2.9468016624450684,10000.0,39108.16635656357,42901.415877103806,39108.16635656357,3784.9390771389008,3.758352756500244,0.0 -85800,1.2173973,3.2888558,,,,,,,,,,,,,, -85900,1.005275,5.6302614,,,,,,,,,,,,,, -86000,0.75775635,5.572264,,,,,,,,,,,,,, -86100,0.97087485,5.098194,,,,,,,,,,,,,, -86200,1.1059104,5.1483984,,,,,,,,,,,,,, -86300,1.087235,3.128291,,,,,,,,,,,,,, -86400,1.19707,3.1872344,,,,,,,,,,,,,, -86500,1.0800065,3.2810874,,,,,,,,,,,,,, -86600,1.0486845,3.3968236,,,,,,,,,,,,,, -86655,,,0.5392773151397705,2.0084869861602783,0.5009399652481079,2.2077176570892334,50000.0,0.3890000283718109,2.89555287361145,10000.0,39528.129534721375,43362.101089954376,39528.129534721375,3825.567506790161,3.802644729614258,0.0 -86700,1.2469023,3.2963057,,,,,,,,,,,,,, -86800,1.1310214,3.3820114,,,,,,,,,,,,,, -86900,1.0413381,2.990498,,,,,,,,,,,,,, -87000,0.91239655,5.7911034,,,,,,,,,,,,,, -87100,1.0329871,4.431375,,,,,,,,,,,,,, -87200,1.092376,3.2375603,,,,,,,,,,,,,, -87300,1.1529917,3.0856164,,,,,,,,,,,,,, -87400,1.129322,3.0636897,,,,,,,,,,,,,, -87500,0.8670651,4.8185344,,,,,,,,,,,,,, -87576,,,0.5444140434265137,1.9728312492370603,0.492499977350235,2.241783618927002,50000.0,0.3804000318050384,2.9136619567871094,10000.0,39948.1226670742,43824.45436620712,39948.1226670742,3867.839832305908,3.842219352722168,0.0 -87600,1.0547113,4.186554,,,,,,,,,,,,,, -87700,1.2334481,3.1257005,,,,,,,,,,,,,, -87800,1.0865828,3.1333344,,,,,,,,,,,,,, -87900,0.93951875,5.237,,,,,,,,,,,,,, -88000,1.0019993,5.218179,,,,,,,,,,,,,, -88100,1.3980509,3.3611722,,,,,,,,,,,,,, -88200,1.2400856,3.0582948,,,,,,,,,,,,,, -88300,1.0900509,3.0733829,,,,,,,,,,,,,, -88400,1.0498924,3.1724622,,,,,,,,,,,,,, -88499,,,0.5226757526397705,2.081761598587036,0.491599977016449,2.257352590560913,50000.0,0.3840000033378601,2.8978774547576904,10000.0,40368.45597243309,44282.1383099556,40368.45597243309,3905.0960161685935,3.8879823684692383,0.0 -88500,1.0518757,3.4418168,,,,,,,,,,,,,, -88600,1.0752511,3.0974333,,,,,,,,,,,,,, -88700,0.97719496,4.693027,,,,,,,,,,,,,, -88800,0.92655617,5.4202046,,,,,,,,,,,,,, -88900,1.0420306,3.2731595,,,,,,,,,,,,,, -89000,0.9839283,4.0391045,,,,,,,,,,,,,, -89100,0.9714381,4.285013,,,,,,,,,,,,,, -89200,1.3258878,3.209947,,,,,,,,,,,,,, -89300,1.1392237,3.122204,,,,,,,,,,,,,, -89400,1.1677923,3.0824947,,,,,,,,,,,,,, -89420,,,0.5298827886581421,2.077981472015381,0.4909399747848511,2.27143931388855,50000.0,0.3827000260353088,2.921517372131348,10000.0,40788.48864960671,44745.744921684265,40788.48864960671,3948.579337596893,3.9286277294158936,0.0 -89500,1.1225007,3.1830664,,,,,,,,,,,,,, -89600,1.0379709,3.498728,,,,,,,,,,,,,, -89700,1.2670292,3.0596852,,,,,,,,,,,,,, -89800,1.0096668,3.4372065,,,,,,,,,,,,,, -89900,1.1107553,3.0446386,,,,,,,,,,,,,, -90000,1.2531469,3.045445,,,,,,,,,,,,,, -90100,1.1571568,3.0562506,,,,,,,,,,,,,, -90200,1.2905792,3.0637465,,,,,,,,,,,,,, -90300,1.0552337,3.1037579,,,,,,,,,,,,,, -90343,,,0.5615038871765137,1.8954758644104004,0.5057199597358704,2.180562734603882,50000.0,0.3936000168323517,2.8484578132629395,10000.0,41208.83841848373,45211.47863817215,41208.83841848373,3993.872991085053,3.9703309535980233,0.0 -90400,1.0547091,3.4769518,,,,,,,,,,,,,, -90500,1.0720096,3.3050582,,,,,,,,,,,,,, -90600,0.95223606,3.5414395,,,,,,,,,,,,,, -90700,1.1081065,3.1790516,,,,,,,,,,,,,, -90800,1.0436236,3.2012644,,,,,,,,,,,,,, -90900,0.95616245,5.6478243,,,,,,,,,,,,,, -91000,1.4005305,3.1069477,,,,,,,,,,,,,, -91100,1.0671387,3.0765648,,,,,,,,,,,,,, -91200,1.0237428,3.5072796,,,,,,,,,,,,,, -91267,,,0.5329882502555847,2.074381828308105,0.5001400113105774,2.2381608486175537,50000.0,0.3892000317573547,2.885896921157837,10000.0,41629.06167125702,45675.498945236206,41629.06167125702,4037.57505941391,4.016332626342773,0.0 -91300,1.1239802,3.2367587,,,,,,,,,,,,,, -91400,1.0632541,3.1205885,,,,,,,,,,,,,, -91500,0.9776434,4.428287,,,,,,,,,,,,,, -91600,1.1887535,3.370875,,,,,,,,,,,,,, -91700,1.1254324,3.7655425,,,,,,,,,,,,,, -91800,1.11152,5.6567726,,,,,,,,,,,,,, -91900,1.1583797,5.3443556,,,,,,,,,,,,,, -92000,1.2871755,2.9495437,,,,,,,,,,,,,, -92100,1.2393198,3.2577815,,,,,,,,,,,,,, -92188,,,0.5424413681030273,1.9888652563095093,0.5024399757385254,2.1791367530822754,50000.0,0.3930000066757202,2.836609601974488,10000.0,42048.947972774506,46136.65203642845,42048.947972774506,4078.279564142227,4.529749393463135,0.0 -92200,1.1539278,3.3711262,,,,,,,,,,,,,, -92300,0.8599294,5.408015,,,,,,,,,,,,,, -92400,1.1023394,3.1700754,,,,,,,,,,,,,, -92500,1.1640044,3.1742275,,,,,,,,,,,,,, -92600,1.0413458,3.6882868,,,,,,,,,,,,,, -92700,1.1408695,3.0196502,,,,,,,,,,,,,, -92800,1.2706988,3.0838394,,,,,,,,,,,,,, -92900,1.1491567,4.677867,,,,,,,,,,,,,, -93000,1.274426,3.3593295,,,,,,,,,,,,,, -93100,0.93745816,3.7581153,,,,,,,,,,,,,, -93110,,,0.5608007907867432,1.895853400230408,0.5133000016212463,2.126741886138916,50000.0,0.4075000286102295,2.773594617843628,10000.0,42469.17872309685,46599.65407395363,42469.17872309685,4120.957757472992,4.573975324630737,0.0 -93200,1.1788723,3.0376139,,,,,,,,,,,,,, -93300,0.8186726,5.552652,,,,,,,,,,,,,, -93400,1.0281216,3.8064559,,,,,,,,,,,,,, -93500,1.2012339,3.051149,,,,,,,,,,,,,, -93600,1.3107555,3.120069,,,,,,,,,,,,,, -93700,1.7772276,3.1907237,,,,,,,,,,,,,, -93800,0.96591765,4.452961,,,,,,,,,,,,,, -93900,1.0283431,3.788022,,,,,,,,,,,,,, -94000,1.0783116,2.933456,,,,,,,,,,,,,, -94031,,,0.5450390577316284,1.9546549320220947,0.5126599669456482,2.1251654624938965,50000.0,0.4058000147342682,2.794235944747925,10000.0,42889.36215043068,47059.432107925415,42889.36215043068,4160.462041378021,4.615373611450195,0.0 -94100,1.1862866,2.9304674,,,,,,,,,,,,,, -94200,1.298515,5.418929,,,,,,,,,,,,,, -94300,0.90639615,4.472189,,,,,,,,,,,,,, -94400,1.2198589,2.9958901,,,,,,,,,,,,,, -94500,1.4406445,3.047522,,,,,,,,,,,,,, -94600,0.88490325,5.427791,,,,,,,,,,,,,, -94700,1.1394858,2.9983072,,,,,,,,,,,,,, -94800,0.9451423,3.9109297,,,,,,,,,,,,,, -94900,1.0166923,5.3002152,,,,,,,,,,,,,, -94953,,,0.5402538776397705,2.001154661178589,0.5038999915122986,2.17992639541626,50000.0,0.4006000161170959,2.837176322937012,10000.0,43309.5159702301,47522.02918076515,43309.5159702301,4202.813487291336,4.658671140670776,0.0 -95000,1.0845355,5.3118935,,,,,,,,,,,,,, -95100,1.0890507,3.134642,,,,,,,,,,,,,, -95200,1.1755681,2.949277,,,,,,,,,,,,,, -95300,1.1668637,3.0496356,,,,,,,,,,,,,, -95400,1.0767723,3.8872561,,,,,,,,,,,,,, -95500,1.1124301,3.0786698,,,,,,,,,,,,,, -95600,1.1866924,3.2916248,,,,,,,,,,,,,, -95700,1.3616964,3.1954215,,,,,,,,,,,,,, -95800,1.2651007,3.0424018,,,,,,,,,,,,,, -95875,,,0.5454687476158142,2.008453845977783,0.5034399628639221,2.21916127204895,50000.0,0.3922000229358673,2.864680528640747,10000.0,43729.85532140732,47982.59704566002,43729.85532140732,4242.95054268837,4.701233148574829,0.0 -95900,0.905168,5.492555,,,,,,,,,,,,,, -96000,1.1858149,2.9746864,,,,,,,,,,,,,, -96100,0.88475686,5.488603,,,,,,,,,,,,,, -96200,0.9684521,4.4849515,,,,,,,,,,,,,, -96300,1.2774893,3.0549254,,,,,,,,,,,,,, -96400,1.1031259,3.52891,,,,,,,,,,,,,, -96500,0.984979,5.2330136,,,,,,,,,,,,,, -96600,1.1385281,3.0014205,,,,,,,,,,,,,, -96700,0.85976267,5.4685707,,,,,,,,,,,,,, -96795,,,0.5572656393051147,1.9084703922271729,0.5183199644088745,2.112875461578369,50000.0,0.409600019454956,2.749626398086548,10000.0,44149.91594219208,48443.68581080437,44149.91594219208,4283.885046005249,4.746007680892944,0.0 -96800,1.2312616,3.1798923,,,,,,,,,,,,,, -96900,1.3562132,2.9685547,,,,,,,,,,,,,, -97000,1.2086891,2.8391542,,,,,,,,,,,,,, -97100,1.0792309,4.0934563,,,,,,,,,,,,,, -97200,1.3019273,3.212426,,,,,,,,,,,,,, -97300,1.2065749,2.880091,,,,,,,,,,,,,, -97400,1.254471,2.8984916,,,,,,,,,,,,,, -97500,1.0084012,3.454867,,,,,,,,,,,,,, -97600,1.1906613,3.0721776,,,,,,,,,,,,,, -97700,0.92762613,4.514675,,,,,,,,,,,,,, -97717,,,0.5448241829872131,1.973633050918579,0.5102199912071228,2.158968210220337,50000.0,0.4025000333786011,2.8153152465820312,10000.0,44570.11754012108,48903.320981264114,44570.11754012108,4323.22603559494,4.789544582366943,0.0 -97800,1.0358433,4.506735,,,,,,,,,,,,,, -97900,0.9745361,4.4801264,,,,,,,,,,,,,, -98000,1.1332641,5.543242,,,,,,,,,,,,,, -98100,1.1913652,3.05265,,,,,,,,,,,,,, -98200,1.0108131,3.5023248,,,,,,,,,,,,,, -98300,0.95768195,3.8366587,,,,,,,,,,,,,, -98400,1.2391254,2.9240565,,,,,,,,,,,,,, -98500,1.0393109,4.4038987,,,,,,,,,,,,,, -98600,1.2074115,2.9029243,,,,,,,,,,,,,, -98639,,,0.5551952719688416,1.929055452346801,0.5148800015449524,2.135251760482788,50000.0,0.4062000215053558,2.79468321800232,10000.0,44990.19146609306,49363.72312116623,44990.19146609306,4363.465132236481,4.829688310623169,0.0 -98700,1.14941,3.0577147,,,,,,,,,,,,,, -98800,0.8526192,5.153429,,,,,,,,,,,,,, -98900,1.3172148,3.0199528,,,,,,,,,,,,,, -99000,1.018626,3.6790128,,,,,,,,,,,,,, -99100,1.0364788,3.7809083,,,,,,,,,,,,,, -99200,0.88390684,3.8958106,,,,,,,,,,,,,, -99300,1.044169,4.016456,,,,,,,,,,,,,, -99400,1.0015658,4.9341345,,,,,,,,,,,,,, -99500,1.0811671,3.0070565,,,,,,,,,,,,,, -99560,,,0.5842382907867432,1.811358332633972,0.5146600008010864,2.145191669464112,50000.0,0.4070000052452087,2.797633647918701,10000.0,45410.33908557892,49826.20372629166,45410.33908557892,4405.708698511124,4.869691371917725,0.0 -99600,1.396038,3.0556011,,,,,,,,,,,,,, -99700,0.90872204,3.9070323,,,,,,,,,,,,,, -99800,0.9214943,5.4514236,,,,,,,,,,,,,, -99900,1.1886265,2.9839964,,,,,,,,,,,,,, -100000,1.1184065,3.2307372,,,,,,,,,,,,,, -100100,1.3128589,3.0168035,,,,,,,,,,,,,, -100200,1.1371635,4.510786,,,,,,,,,,,,,, -100300,1.0182351,5.2411113,,,,,,,,,,,,,, -100400,1.3349597,3.169248,,,,,,,,,,,,,, -100482,,,0.5579687356948853,1.9476128816604608,0.5199399590492249,2.1243672370910645,50000.0,0.4086000323295593,2.762568473815918,10000.0,45830.48991537094,50285.001859903336,45830.48991537094,4444.26401591301,4.912209749221802,0.0 -100500,1.1454761,3.0962186,,,,,,,,,,,,,, -100600,0.9751876,4.9305096,,,,,,,,,,,,,, -100700,1.1067722,3.2313673,,,,,,,,,,,,,, -100800,1.1332518,3.0082579,,,,,,,,,,,,,, -100900,1.1758977,2.9943066,,,,,,,,,,,,,, -101000,1.1362547,5.5494127,,,,,,,,,,,,,, -101100,1.0683715,3.8274956,,,,,,,,,,,,,, -101200,1.3591164,2.955882,,,,,,,,,,,,,, -101300,1.0985383,3.1308796,,,,,,,,,,,,,, -101400,1.1259259,2.9538016,,,,,,,,,,,,,, -101404,,,0.5579296946525574,1.942072868347168,0.5202599763870239,2.1204233169555664,50000.0,0.4086000323295593,2.7876458168029785,10000.0,46250.49274253845,50741.545766592026,46250.49274253845,4480.708422660828,4.960192918777466,0.0 -101500,1.1532238,3.2284744,,,,,,,,,,,,,, -101600,1.2202655,3.005118,,,,,,,,,,,,,, -101700,1.2608756,2.8800383,,,,,,,,,,,,,, -101800,1.0726889,4.7940154,,,,,,,,,,,,,, -101900,1.238675,2.9675949,,,,,,,,,,,,,, -102000,0.84742755,5.3717237,,,,,,,,,,,,,, -102100,1.0926847,5.423752,,,,,,,,,,,,,, -102200,1.3050134,2.9882145,,,,,,,,,,,,,, -102300,1.3551911,3.075016,,,,,,,,,,,,,, -102326,,,0.5746093392372131,1.825226068496704,0.5278399586677551,2.0796656608581543,50000.0,0.4100000262260437,2.769395112991333,10000.0,46670.66937637329,51203.09488844872,46670.66937637329,4521.99057674408,5.001747369766235,0.0 -102400,1.1603523,5.48189,,,,,,,,,,,,,, -102500,1.0075636,5.020833,,,,,,,,,,,,,, -102600,1.2561133,2.9538493,,,,,,,,,,,,,, -102700,0.9341094,5.0460463,,,,,,,,,,,,,, -102800,1.2831329,2.8955445,,,,,,,,,,,,,, -102900,0.9208587,5.0214224,,,,,,,,,,,,,, -103000,1.4172117,3.0724375,,,,,,,,,,,,,, -103100,1.2033123,3.1906543,,,,,,,,,,,,,, -103200,0.953171,5.4147916,,,,,,,,,,,,,, -103247,,,0.5560937523841858,1.95291531085968,0.5209599733352661,2.1211466789245605,50000.0,0.4067000150680542,2.7682337760925293,10000.0,47090.8395152092,51661.95341157913,47090.8395152092,4560.5888912677765,5.0430192947387695,0.0 -103300,1.1089547,3.6394033,,,,,,,,,,,,,, -103400,1.2040528,2.887946,,,,,,,,,,,,,, -103500,1.1362638,3.0702305,,,,,,,,,,,,,, -103600,1.012965,3.8802326,,,,,,,,,,,,,, -103700,1.0609914,3.8645184,,,,,,,,,,,,,, -103800,1.259074,2.8721855,,,,,,,,,,,,,, -103900,1.2124506,2.9542022,,,,,,,,,,,,,, -104000,1.3020141,2.9875293,,,,,,,,,,,,,, -104100,0.9683976,5.437676,,,,,,,,,,,,,, -104172,,,0.5689452886581421,1.8281947374343872,0.5344399809837341,2.012808561325073,50000.0,0.4191000163555145,2.6920504570007324,10000.0,47510.9509665966,52125.42877531052,47510.9509665966,4603.852442026138,5.094490051269531,0.0 -104200,1.217754,2.7939186,,,,,,,,,,,,,, -104300,1.2756892,2.9926815,,,,,,,,,,,,,, -104400,1.1031123,3.5112047,,,,,,,,,,,,,, -104500,1.2731059,3.6760416,,,,,,,,,,,,,, -104600,1.2425783,2.8844018,,,,,,,,,,,,,, -104700,1.1782582,2.8793385,,,,,,,,,,,,,, -104800,1.0731413,4.6317763,,,,,,,,,,,,,, -104900,1.3796209,2.9374669,,,,,,,,,,,,,, -105000,1.3080037,2.8581913,,,,,,,,,,,,,, -105094,,,0.5889062285423279,1.7638208866119385,0.5421199798583984,1.9910471439361568,50000.0,0.4271000325679779,2.6605894565582275,10000.0,47930.90433549881,52585.41693139076,47930.90433549881,4643.782491207123,5.150495529174805,0.0 -105100,1.38548,4.1385145,,,,,,,,,,,,,, -105200,1.0326165,3.792016,,,,,,,,,,,,,, -105300,1.3240554,2.967534,,,,,,,,,,,,,, -105400,1.3654325,3.1817586,,,,,,,,,,,,,, -105500,1.2496648,2.744947,,,,,,,,,,,,,, -105600,0.9959645,5.1323743,,,,,,,,,,,,,, -105700,0.8942642,5.1658216,,,,,,,,,,,,,, -105800,1.0569816,3.759093,,,,,,,,,,,,,, -105900,1.0898334,4.0147495,,,,,,,,,,,,,, -106000,1.2917786,3.046843,,,,,,,,,,,,,, -106015,,,0.5689062476158142,1.88288688659668,0.5313199758529663,2.059181928634644,50000.0,0.4189000129699707,2.725624322891236,10000.0,48350.86739444733,53043.58160710335,48350.86739444733,4681.881660699844,5.203859329223633,0.0 -106100,1.0842558,3.7976186,,,,,,,,,,,,,, -106200,1.2088317,2.7490072,,,,,,,,,,,,,, -106300,1.3087004,2.739695,,,,,,,,,,,,,, -106400,1.1963692,2.8274515,,,,,,,,,,,,,, -106500,1.2444249,2.9248435,,,,,,,,,,,,,, -106600,1.1848866,2.7161543,,,,,,,,,,,,,, -106700,1.0201145,5.111496,,,,,,,,,,,,,, -106800,1.2750673,3.0281343,,,,,,,,,,,,,, -106900,1.2539458,2.76861,,,,,,,,,,,,,, -106937,,,0.5787695050239563,1.8417892456054688,0.5412399768829346,2.022822380065918,50000.0,0.431300014257431,2.653942584991455,10000.0,48770.8959608078,53502.4229733944,48770.8959608078,4720.600524902344,5.249224662780762,0.0 -107000,1.2471874,2.9288826,,,,,,,,,,,,,, -107100,1.3374608,2.9165428,,,,,,,,,,,,,, -107200,0.97215194,5.0085025,,,,,,,,,,,,,, -107300,1.2113658,2.9666283,,,,,,,,,,,,,, -107400,1.010794,3.9435096,,,,,,,,,,,,,, -107500,1.0386807,5.441752,,,,,,,,,,,,,, -107600,1.0648211,3.9200115,,,,,,,,,,,,,, -107700,1.1463444,3.4332762,,,,,,,,,,,,,, -107800,1.3515307,2.7966113,,,,,,,,,,,,,, -107860,,,0.5873827934265137,1.7708899974822998,0.5433200001716614,1.9900877475738523,50000.0,0.4325000345706939,2.625818014144897,10000.0,49191.23649168015,53963.276510477066,49191.23649168015,4761.017573356628,5.296191930770874,0.0 -107900,1.3245759,2.9538426,,,,,,,,,,,,,, -108000,1.1206177,3.0450783,,,,,,,,,,,,,, -108100,1.2785944,2.850055,,,,,,,,,,,,,, -108200,1.0677718,4.0076313,,,,,,,,,,,,,, -108300,1.2457978,2.9804068,,,,,,,,,,,,,, -108400,1.1087404,5.3809237,,,,,,,,,,,,,, -108500,1.0019408,5.4036283,,,,,,,,,,,,,, -108600,1.2347603,4.8203325,,,,,,,,,,,,,, -108700,0.9177917,5.1427445,,,,,,,,,,,,,, -108784,,,0.6213085651397705,1.629205584526062,0.5485599637031555,1.96748161315918,50000.0,0.4370000064373016,2.632002353668213,10000.0,49611.455362319946,54420.8622033596,49611.455362319946,4798.291470050812,5.3400962352752686,0.0 -108800,1.0169934,4.822703,,,,,,,,,,,,,, -108900,1.1450014,4.4566298,,,,,,,,,,,,,, -109000,1.0879271,4.510077,,,,,,,,,,,,,, -109100,1.1746653,2.8526053,,,,,,,,,,,,,, -109200,1.1983087,2.871746,,,,,,,,,,,,,, -109300,1.329485,2.6388335,,,,,,,,,,,,,, -109400,1.1783215,3.3883963,,,,,,,,,,,,,, -109500,1.2490938,2.875857,,,,,,,,,,,,,, -109600,1.3075297,5.2675695,,,,,,,,,,,,,, -109700,0.99340695,5.2667856,,,,,,,,,,,,,, -109705,,,0.588671863079071,1.755498290061951,0.5467999577522278,1.954467177391052,50000.0,0.4294000267982483,2.6271207332611084,10000.0,50031.65995430946,54883.06630086899,50031.65995430946,4840.195640087128,5.386731386184692,0.0 -109800,1.0596347,5.272332,,,,,,,,,,,,,, -109900,1.0677488,4.0565786,,,,,,,,,,,,,, -110000,0.9092621,5.2936544,,,,,,,,,,,,,, -110100,1.2517631,3.1063805,,,,,,,,,,,,,, -110200,1.2624835,2.794733,,,,,,,,,,,,,, -110300,1.0180252,4.7201395,,,,,,,,,,,,,, -110400,1.1198797,3.7465692,,,,,,,,,,,,,, -110500,1.3138129,2.8835957,,,,,,,,,,,,,, -110600,1.380882,3.1871696,,,,,,,,,,,,,, -110628,,,0.595898449420929,1.7019425630569458,0.5554800033569336,1.911659836769104,50000.0,0.4384000301361084,2.5746302604675293,10000.0,50451.89798164368,55343.07885932922,50451.89798164368,4879.87920832634,5.428632497787476,0.0 -110700,1.2829238,3.2469077,,,,,,,,,,,,,, -110800,1.2319261,2.9037533,,,,,,,,,,,,,, -110900,1.2404172,2.8504393,,,,,,,,,,,,,, -111000,1.249018,2.799963,,,,,,,,,,,,,, -111100,1.276903,2.8556478,,,,,,,,,,,,,, -111200,1.2306353,3.142137,,,,,,,,,,,,,, -111300,0.9380551,5.403353,,,,,,,,,,,,,, -111400,1.2496779,2.7730224,,,,,,,,,,,,,, -111500,1.36212,2.9857764,,,,,,,,,,,,,, -111551,,,0.6108593344688416,1.6443220376968384,0.5592799782752991,1.905354380607605,50000.0,0.4394000172615051,2.577054500579834,10000.0,50871.97425460816,55801.88543534279,50871.97425460816,4918.517581939697,5.472025394439697,0.0 -111600,1.0025146,4.205665,,,,,,,,,,,,,, -111700,1.0644587,4.49348,,,,,,,,,,,,,, -111800,1.1430357,4.264676,,,,,,,,,,,,,, -111900,1.1676028,3.8581011,,,,,,,,,,,,,, -112000,1.1540557,2.9859881,,,,,,,,,,,,,, -112100,1.285392,2.765592,,,,,,,,,,,,,, -112200,1.1239283,4.5455413,,,,,,,,,,,,,, -112300,1.4515182,2.8313038,,,,,,,,,,,,,, -112400,1.3061882,3.1275496,,,,,,,,,,,,,, -112474,,,0.5939062237739563,1.7227932214736938,0.5597599744796753,1.904951810836792,50000.0,0.4414000213146209,2.582552194595337,10000.0,51292.22484111786,56262.47783136368,51292.22484111786,4958.763142347336,5.519868135452271,0.0 -112500,1.0617968,3.610017,,,,,,,,,,,,,, -112600,1.3876034,2.8224375,,,,,,,,,,,,,, -112700,1.149304,3.5302496,,,,,,,,,,,,,, -112800,1.3719712,2.749453,,,,,,,,,,,,,, -112900,1.3013052,3.2691827,,,,,,,,,,,,,, -113000,1.0193396,4.9118757,,,,,,,,,,,,,, -113100,1.4332042,2.8857193,,,,,,,,,,,,,, -113200,1.5565902,3.1203346,,,,,,,,,,,,,, -113300,1.3026191,2.7686934,,,,,,,,,,,,,, -113395,,,0.6010937094688416,1.6848231554031372,0.5621399879455566,1.882049918174744,50000.0,0.4389000236988067,2.563261270523072,10000.0,51712.5141146183,56724.9230401516,51712.5141146183,5000.821955919266,5.5690460205078125,0.0 -113400,1.3318703,2.9003298,,,,,,,,,,,,,, -113500,1.2874005,2.7221513,,,,,,,,,,,,,, -113600,1.2707291,2.6990957,,,,,,,,,,,,,, -113700,1.3320864,2.9897807,,,,,,,,,,,,,, -113800,1.220192,2.9838674,,,,,,,,,,,,,, -113900,1.685937,2.8540142,,,,,,,,,,,,,, -114000,1.0360836,4.9067464,,,,,,,,,,,,,, -114100,1.5239906,2.6420588,,,,,,,,,,,,,, -114200,1.0334514,4.4849424,,,,,,,,,,,,,, -114300,1.2827637,2.812601,,,,,,,,,,,,,, -114316,,,0.6037890315055847,1.70114004611969,0.55485999584198,1.942366003990173,50000.0,0.442300021648407,2.572660207748413,10000.0,52132.70584130287,57179.35532331467,52132.70584130287,5034.966580152512,5.616888523101807,0.0 -114400,1.1139613,5.1446004,,,,,,,,,,,,,, -114500,1.1159934,3.2870524,,,,,,,,,,,,,, -114600,1.347975,2.7218437,,,,,,,,,,,,,, -114700,1.1773313,3.4700518,,,,,,,,,,,,,, -114800,1.0235685,4.5788774,,,,,,,,,,,,,, -114900,1.0575267,3.6609747,,,,,,,,,,,,,, -115000,1.2910618,2.987799,,,,,,,,,,,,,, -115100,1.2499127,2.646266,,,,,,,,,,,,,, -115200,1.3830434,2.7032933,,,,,,,,,,,,,, -115237,,,0.6056835651397705,1.67186439037323,0.5682199597358704,1.854199171066284,50000.0,0.448600023984909,2.519437313079834,10000.0,52552.76272273064,57640.74333691597,52552.76272273064,5076.205003976822,5.660736322402954,0.0 -115300,1.3227513,2.7593174,,,,,,,,,,,,,, -115400,1.2726586,2.7873733,,,,,,,,,,,,,, -115500,1.1657839,3.0503178,,,,,,,,,,,,,, -115600,1.2635459,2.7708,,,,,,,,,,,,,, -115700,1.582906,2.7122972,,,,,,,,,,,,,, -115800,1.0745709,4.6009974,,,,,,,,,,,,,, -115900,1.3804291,2.6402013,,,,,,,,,,,,,, -116000,1.0939856,3.795377,,,,,,,,,,,,,, -116100,1.3635439,2.8095531,,,,,,,,,,,,,, -116160,,,0.6073437333106995,1.675774097442627,0.5725799798965454,1.846383810043335,50000.0,0.4535000324249267,2.516984462738037,10000.0,52972.98017334938,58102.24056863785,52972.98017334938,5117.389803171158,5.706765413284302,0.0 -116200,1.1228926,3.9697685,,,,,,,,,,,,,, -116300,1.3375305,2.7553318,,,,,,,,,,,,,, -116400,1.2038525,3.240815,,,,,,,,,,,,,, -116500,1.3978615,2.6299024,,,,,,,,,,,,,, -116600,1.2481586,2.8532968,,,,,,,,,,,,,, -116700,1.2637392,2.898556,,,,,,,,,,,,,, -116800,1.0533721,4.234681,,,,,,,,,,,,,, -116900,1.3700331,2.815363,,,,,,,,,,,,,, -117000,1.4816687,2.7292616,,,,,,,,,,,,,, -117083,,,0.6232812404632568,1.5855211019515991,0.5778599977493286,1.8170597553253167,50000.0,0.4607000350952148,2.4718594551086426,10000.0,53393.17074465752,58563.44180321693,53393.17074465752,5158.30740070343,5.750935316085815,0.0 -117100,1.2063676,4.4410486,,,,,,,,,,,,,, -117200,1.2660105,2.685029,,,,,,,,,,,,,, -117300,1.2530653,2.8669972,,,,,,,,,,,,,, -117400,1.4126006,2.724762,,,,,,,,,,,,,, -117500,1.4215845,2.6984704,,,,,,,,,,,,,, -117600,1.2643732,3.25087,,,,,,,,,,,,,, -117700,1.4075125,2.6089027,,,,,,,,,,,,,, -117800,1.3625057,2.6113853,,,,,,,,,,,,,, -117900,1.3378353,2.7168093,,,,,,,,,,,,,, -118000,1.3398978,2.7114704,,,,,,,,,,,,,, -118005,,,0.6318945288658142,1.554277420043945,0.5718399882316589,1.8370327949523928,50000.0,0.462300032377243,2.502239465713501,10000.0,53813.120841264725,59023.410198926926,53813.120841264725,5198.218531131744,5.809661865234375,0.0 -118100,1.2685273,2.6886597,,,,,,,,,,,,,, -118200,1.3297892,2.7279568,,,,,,,,,,,,,, -118300,1.4505614,2.7909002,,,,,,,,,,,,,, -118400,1.1289093,5.201138,,,,,,,,,,,,,, -118500,1.097318,4.153406,,,,,,,,,,,,,, -118600,1.3793652,2.6771417,,,,,,,,,,,,,, -118700,1.1794361,4.466561,,,,,,,,,,,,,, -118800,1.4572887,2.727128,,,,,,,,,,,,,, -118900,1.1604259,3.2559025,,,,,,,,,,,,,, -118929,,,0.6213085651397705,1.6146199703216553,0.5793200135231018,1.8063288927078247,50000.0,0.4620000123977661,2.4387903213500977,10000.0,54233.34093928337,59486.030648469925,54233.34093928337,5240.525000333786,5.854615688323975,0.0 -119000,1.3346689,2.6187086,,,,,,,,,,,,,, -119100,1.505075,2.7336583,,,,,,,,,,,,,, -119200,1.3172988,2.5978224,,,,,,,,,,,,,, -119300,1.3152982,2.7513022,,,,,,,,,,,,,, -119400,1.4771769,3.0846355,,,,,,,,,,,,,, -119500,1.3459655,3.031255,,,,,,,,,,,,,, -119600,1.2704116,3.0490208,,,,,,,,,,,,,, -119700,1.3473859,2.7701073,,,,,,,,,,,,,, -119800,1.17982,3.9333613,,,,,,,,,,,,,, -119852,,,0.6187695264816284,1.6314858198165894,0.5747199654579163,1.853618025779724,50000.0,0.4583000242710113,2.520430088043213,10000.0,54653.69129776955,59941.873532533646,54653.69129776955,5275.923225164413,5.900138854980469,0.0 -119900,1.2856003,4.942613,,,,,,,,,,,,,, -120000,1.3950762,2.6015127,,,,,,,,,,,,,, -120100,1.1221424,3.975175,,,,,,,,,,,,,, -120200,1.339151,4.50573,,,,,,,,,,,,,, -120300,1.221971,4.087133,,,,,,,,,,,,,, -120400,1.4070836,2.7361066,,,,,,,,,,,,,, -120500,1.4428021,2.6590164,,,,,,,,,,,,,, -120600,1.3033047,3.9285314,,,,,,,,,,,,,, -120700,1.4281945,2.7380767,,,,,,,,,,,,,, -120773,,,0.6486914157867432,1.4689342975616455,0.5861999988555908,1.7645444869995115,50000.0,0.4687000215053558,2.431938886642456,10000.0,55073.72949528694,60402.14140725136,55073.72949528694,5316.057541370392,5.947498083114624,0.0 -120800,1.2855188,2.5763655,,,,,,,,,,,,,, -120900,1.2868221,2.6868377,,,,,,,,,,,,,, -121000,1.3226533,2.562758,,,,,,,,,,,,,, -121100,1.3757575,2.6249824,,,,,,,,,,,,,, -121200,1.3128846,2.4901526,,,,,,,,,,,,,, -121300,1.1081345,4.0205503,,,,,,,,,,,,,, -121400,1.1443772,4.347476,,,,,,,,,,,,,, -121500,1.587358,2.825269,,,,,,,,,,,,,, -121600,1.0139179,4.360526,,,,,,,,,,,,,, -121692,,,0.6215234398841858,1.6078003644943235,0.5834999680519104,1.7998595237731934,50000.0,0.4652000367641449,2.4484989643096924,10000.0,55493.78798747063,60864.09947562218,55493.78798747063,5357.858735084534,5.998109817504883,0.0 -121700,1.4002774,2.781529,,,,,,,,,,,,,, -121800,1.4710732,2.5746036,,,,,,,,,,,,,, -121900,1.0559927,4.314037,,,,,,,,,,,,,, -122000,1.3770341,2.5240638,,,,,,,,,,,,,, -122100,1.4039725,2.6314874,,,,,,,,,,,,,, -122200,1.1944395,5.1471877,,,,,,,,,,,,,, -122300,1.346212,3.1402266,,,,,,,,,,,,,, -122400,1.1061419,5.082319,,,,,,,,,,,,,, -122500,1.4657197,2.5577972,,,,,,,,,,,,,, -122600,1.2266734,5.1731834,,,,,,,,,,,,,, -122615,,,0.6367577910423279,1.5228012800216677,0.5913800001144409,1.7486042976379397,50000.0,0.4740000367164612,2.3974080085754395,10000.0,55914.06940293312,61324.40569233894,55914.06940293312,5397.7881071567535,6.045310974121094,0.0 -122700,1.1788479,3.4268475,,,,,,,,,,,,,, -122800,1.175152,5.265839,,,,,,,,,,,,,, -122900,1.324504,2.9418106,,,,,,,,,,,,,, -123000,1.4083717,2.5650668,,,,,,,,,,,,,, -123100,1.5430653,2.625712,,,,,,,,,,,,,, -123200,1.2343717,5.0197015,,,,,,,,,,,,,, -123300,1.569428,2.7505608,,,,,,,,,,,,,, -123400,1.2400175,3.5580978,,,,,,,,,,,,,, -123500,1.1120007,4.049123,,,,,,,,,,,,,, -123540,,,0.644335925579071,1.5328458547592163,0.5850600004196167,1.7931092977523804,50000.0,0.4682000279426574,2.441822290420532,10000.0,56334.1478741169,61787.00834584236,56334.1478741169,5440.216062545776,6.09236216545105,0.0 -123600,1.4440559,2.644501,,,,,,,,,,,,,, -123700,1.3292899,3.150086,,,,,,,,,,,,,, -123800,1.1289408,5.0771303,,,,,,,,,,,,,, -123900,1.4149162,2.7098181,,,,,,,,,,,,,, -124000,1.2446332,4.5905347,,,,,,,,,,,,,, -124100,1.0753871,4.4007287,,,,,,,,,,,,,, -124200,1.4132831,2.8102665,,,,,,,,,,,,,, -124300,1.230438,5.0868673,,,,,,,,,,,,,, -124400,1.4360301,2.6166747,,,,,,,,,,,,,, -124462,,,0.6310741901397705,1.5595864057540894,0.5909199714660645,1.7579379081726074,50000.0,0.4686000347137451,2.41546893119812,10000.0,56754.07245540619,62247.24000930786,56754.07245540619,5480.426479578018,6.140239953994751,0.0 -124500,1.4335614,2.5380692,,,,,,,,,,,,,, -124600,1.3417119,3.3240025,,,,,,,,,,,,,, -124700,1.3494054,2.686839,,,,,,,,,,,,,, -124800,1.4477278,2.5333886,,,,,,,,,,,,,, -124900,1.4947413,2.5369313,,,,,,,,,,,,,, -125000,1.2572136,3.4996774,,,,,,,,,,,,,, -125100,1.4932504,2.557165,,,,,,,,,,,,,, -125200,1.4561561,2.4948232,,,,,,,,,,,,,, -125300,1.5740992,2.5258913,,,,,,,,,,,,,, -125382,,,0.6387109160423279,1.5228734016418457,0.5959399938583374,1.7375160455703735,50000.0,0.4755000174045563,2.390855073928833,10000.0,57174.0851354599,62710.162212610245,57174.0851354599,5523.234518289566,6.193271636962891,0.0 -125400,1.4642397,3.003955,,,,,,,,,,,,,, -125500,1.1429048,4.7163644,,,,,,,,,,,,,, -125600,1.5272853,2.4192605,,,,,,,,,,,,,, -125700,1.2233282,3.919183,,,,,,,,,,,,,, -125800,1.5006562,2.7267091,,,,,,,,,,,,,, -125900,1.2267734,3.424847,,,,,,,,,,,,,, -126000,1.1945604,4.178192,,,,,,,,,,,,,, -126100,1.4223049,2.6522682,,,,,,,,,,,,,, -126200,1.5205188,2.5940228,,,,,,,,,,,,,, -126300,1.4588906,2.5906281,,,,,,,,,,,,,, -126304,,,0.6492577791213989,1.4756730794906616,0.5982199907302856,1.706802487373352,50000.0,0.4809000194072723,2.3549671173095703,10000.0,57594.21831464768,63170.8043320179,57594.21831464768,5563.646774530411,6.24152421951294,0.0 -126400,1.1782869,4.611234,,,,,,,,,,,,,, -126500,1.2884454,3.0403903,,,,,,,,,,,,,, -126600,1.452651,2.471651,,,,,,,,,,,,,, -126700,1.2585607,3.503167,,,,,,,,,,,,,, -126800,1.4797986,2.4908776,,,,,,,,,,,,,, -126900,1.4726664,2.5094423,,,,,,,,,,,,,, -127000,1.403573,2.4892294,,,,,,,,,,,,,, -127100,1.3499129,2.7468796,,,,,,,,,,,,,, -127200,1.5763458,2.5575356,,,,,,,,,,,,,, -127228,,,0.6450781226158142,1.5104620456695557,0.5978999733924866,1.7296024560928345,50000.0,0.472100019454956,2.3971104621887207,10000.0,58014.53499889374,63635.348907232285,58014.53499889374,5607.780184268951,6.286353588104248,0.0 -127300,1.1828653,3.843798,,,,,,,,,,,,,, -127400,1.4743806,2.446717,,,,,,,,,,,,,, -127500,1.2712648,2.9423268,,,,,,,,,,,,,, -127600,1.6913784,2.6592855,,,,,,,,,,,,,, -127700,1.5549353,2.6139112,,,,,,,,,,,,,, -127800,1.2272255,4.2601776,,,,,,,,,,,,,, -127900,1.6900474,2.5183136,,,,,,,,,,,,,, -128000,1.3640919,2.6314597,,,,,,,,,,,,,, -128100,1.4154323,2.685546,,,,,,,,,,,,,, -128150,,,0.6434765458106995,1.4898433685302734,0.6008999943733215,1.686245083808899,50000.0,0.4826000332832336,2.3594110012054443,10000.0,58434.6179728508,64096.87507081032,58434.6179728508,5649.124925851822,6.335825204849243,0.0 -128200,1.3083423,2.801561,,,,,,,,,,,,,, -128300,1.7419736,2.6457067,,,,,,,,,,,,,, -128400,1.3896295,2.3295858,,,,,,,,,,,,,, -128500,1.4751264,2.7496467,,,,,,,,,,,,,, -128600,1.5764744,2.5996077,,,,,,,,,,,,,, -128700,1.5613843,2.7499905,,,,,,,,,,,,,, -128800,1.4497344,2.4537003,,,,,,,,,,,,,, -128900,1.5472218,2.6155958,,,,,,,,,,,,,, -129000,1.2221397,5.0250816,,,,,,,,,,,,,, -129073,,,0.6571874618530273,1.4405206441879272,0.6071999669075012,1.6774921417236328,50000.0,0.4907000362873077,2.321425199508667,10000.0,58854.83988237381,64557.95269560814,58854.83988237381,5689.88328742981,6.38477087020874,0.0 -129100,1.4106946,3.4089117,,,,,,,,,,,,,, -129200,1.2831051,4.614445,,,,,,,,,,,,,, -129300,1.4211308,2.858245,,,,,,,,,,,,,, -129400,1.39281,3.2352366,,,,,,,,,,,,,, -129500,1.3834224,2.4970825,,,,,,,,,,,,,, -129600,1.7133704,2.4808705,,,,,,,,,,,,,, -129700,1.312935,2.7740765,,,,,,,,,,,,,, -129800,1.3885331,4.9062195,,,,,,,,,,,,,, -129900,1.1914557,4.752328,,,,,,,,,,,,,, -129996,,,0.6743749976158142,1.362080216407776,0.6096000075340271,1.675639510154724,50000.0,0.4869000315666199,2.335404634475708,10000.0,59274.73733615875,65022.062368392944,59274.73733615875,5733.991091012955,6.439935207366943,0.0 -130000,1.2742138,4.028843,,,,,,,,,,,,,, -130100,1.2935038,4.866732,,,,,,,,,,,,,, -130200,1.5659105,2.5872548,,,,,,,,,,,,,, -130300,1.320179,2.8038616,,,,,,,,,,,,,, -130400,1.4423541,3.068245,,,,,,,,,,,,,, -130500,1.5543182,2.6252887,,,,,,,,,,,,,, -130600,1.2792146,3.3785849,,,,,,,,,,,,,, -130700,1.6226479,2.525203,,,,,,,,,,,,,, -130800,1.3410989,3.086322,,,,,,,,,,,,,, -130801,,,0.6576952934265137,1.4295475482940674,0.6107199788093567,1.6498571634292605,50000.0,0.4888000190258026,2.307673454284668,10000.0,59694.95303297043,65481.10044121742,59694.95303297043,5772.714293003082,6.497478723526001,0.0 -130900,1.7573724,2.4470608,,,,,,,,,,,,,, -131000,1.3546815,3.8137343,,,,,,,,,,,,,, -131100,1.5380081,2.4072905,,,,,,,,,,,,,, -131200,1.3628291,2.9096816,,,,,,,,,,,,,, -131300,1.32266,3.4585536,,,,,,,,,,,,,, -131400,1.3728172,2.4661825,,,,,,,,,,,,,, -131500,1.5561501,2.733561,,,,,,,,,,,,,, -131600,1.5472324,2.679109,,,,,,,,,,,,,, -131700,1.5702255,2.4653962,,,,,,,,,,,,,, -131725,,,0.6663476228713989,1.3924933671951294,0.6150000095367432,1.636078119277954,50000.0,0.4941000342369079,2.286726474761963,10000.0,60115.202788591385,65940.35055208206,60115.202788591385,5811.607728004456,6.554529428482056,0.0 -131800,1.5713962,2.4811654,,,,,,,,,,,,,, -131900,1.3049562,3.7379832,,,,,,,,,,,,,, -132000,1.4910252,2.3516178,,,,,,,,,,,,,, -132100,1.4830455,2.276513,,,,,,,,,,,,,, -132200,1.5310397,2.8133307,,,,,,,,,,,,,, -132300,1.3834037,3.7926233,,,,,,,,,,,,,, -132400,1.6632564,2.386518,,,,,,,,,,,,,, -132500,1.6525393,2.5659144,,,,,,,,,,,,,, -132600,1.5748619,2.4831908,,,,,,,,,,,,,, -132646,,,0.6916796565055847,1.280269742012024,0.6193599700927734,1.6119118928909302,50000.0,0.4978000223636627,2.2578117847442627,10000.0,60535.19282770157,66399.90402460098,60535.19282770157,5851.074378013611,6.603094100952148,0.0 -132700,1.3567876,4.9572835,,,,,,,,,,,,,, -132800,1.5360968,2.453431,,,,,,,,,,,,,, -132900,1.4831829,2.3729498,,,,,,,,,,,,,, -133000,1.5720046,2.396145,,,,,,,,,,,,,, -133100,1.4324316,3.1562,,,,,,,,,,,,,, -133200,1.3869247,3.3033433,,,,,,,,,,,,,, -133300,1.2513666,4.885457,,,,,,,,,,,,,, -133400,1.5188854,2.428272,,,,,,,,,,,,,, -133500,1.643646,2.4357843,,,,,,,,,,,,,, -133570,,,0.6681445240974426,1.408696532249451,0.624239981174469,1.609699010848999,50000.0,0.4993000328540802,2.2752673625946045,10000.0,60955.23886442184,66860.528911829,60955.23886442184,5891.558381795883,6.649211645126343,0.0 -133600,1.7406875,2.3844526,,,,,,,,,,,,,, -133700,1.5546725,3.1219027,,,,,,,,,,,,,, -133800,1.7443286,2.4060657,,,,,,,,,,,,,, -133900,1.5337154,2.8730316,,,,,,,,,,,,,, -134000,1.5672787,2.4715106,,,,,,,,,,,,,, -134100,1.6365923,2.769959,,,,,,,,,,,,,, -134200,1.4729066,3.0997195,,,,,,,,,,,,,, -134300,1.6094096,2.2994804,,,,,,,,,,,,,, -134400,1.2953905,4.680738,,,,,,,,,,,,,, -134493,,,0.6716992259025574,1.369166374206543,0.6208400130271912,1.6039738655090332,50000.0,0.5034000277519226,2.2548952102661133,10000.0,61375.3524119854,67322.49597454071,61375.3524119854,5933.31524014473,6.696808576583862,0.0 -134500,1.5275447,2.6660762,,,,,,,,,,,,,, -134600,1.662837,2.4401212,,,,,,,,,,,,,, -134700,1.3456967,3.4040344,,,,,,,,,,,,,, -134800,1.5983899,2.5492287,,,,,,,,,,,,,, -134900,1.6981511,2.329463,,,,,,,,,,,,,, -135000,1.3907812,3.6984072,,,,,,,,,,,,,, -135100,1.6427523,2.4409895,,,,,,,,,,,,,, -135200,1.3323109,4.554095,,,,,,,,,,,,,, -135300,1.6330955,2.3773057,,,,,,,,,,,,,, -135400,1.341154,3.56315,,,,,,,,,,,,,, -135413,,,0.6820898056030273,1.3533567190170288,0.6220999956130981,1.6352964639663696,50000.0,0.49590003490448,2.288580894470215,10000.0,61794.85926914215,67780.60915780067,61794.85926914215,5971.316042661667,7.2534918785095215,0.0 -135500,1.346983,4.7533875,,,,,,,,,,,,,, -135600,1.3832138,4.6928954,,,,,,,,,,,,,, -135700,1.6288676,2.368847,,,,,,,,,,,,,, -135800,1.6683736,2.386862,,,,,,,,,,,,,, -135900,1.6922164,2.2820847,,,,,,,,,,,,,, -136000,1.656181,2.5892668,,,,,,,,,,,,,, -136100,1.6926063,2.4383054,,,,,,,,,,,,,, -136200,1.5310415,2.4060218,,,,,,,,,,,,,, -136300,1.4844972,2.8622885,,,,,,,,,,,,,, -136335,,,0.67431640625,1.3767963647842407,0.6293999552726746,1.5818134546279907,50000.0,0.5031000375747681,2.237124443054199,10000.0,62215.11662912369,68242.74508333206,62215.11662912369,6013.095870256424,7.303630352020264,0.0 -136400,1.3451737,2.9839492,,,,,,,,,,,,,, -136500,1.6451614,2.3602138,,,,,,,,,,,,,, -136600,1.7046381,2.2040715,,,,,,,,,,,,,, -136700,1.6113266,2.3778212,,,,,,,,,,,,,, -136800,1.6499504,2.9069939,,,,,,,,,,,,,, -136900,1.6651026,2.3379812,,,,,,,,,,,,,, -137000,1.8905079,2.2949011,,,,,,,,,,,,,, -137100,1.5540848,2.635954,,,,,,,,,,,,,, -137200,1.669626,2.4628317,,,,,,,,,,,,,, -137257,,,0.6786132454872131,1.3735076189041138,0.6330199837684631,1.5894250869750977,50000.0,0.509600043296814,2.2336137294769287,10000.0,62635.25749588013,68701.20639562607,62635.25749588013,6051.310604095459,7.360164165496826,0.0 -137300,1.4618105,4.79229,,,,,,,,,,,,,, -137400,1.3759753,3.7932215,,,,,,,,,,,,,, -137500,1.7964499,2.2764242,,,,,,,,,,,,,, -137600,1.6126548,2.5866091,,,,,,,,,,,,,, -137700,1.4396712,3.6672468,,,,,,,,,,,,,, -137800,1.4970856,2.691538,,,,,,,,,,,,,, -137900,1.5993216,2.6997602,,,,,,,,,,,,,, -138000,1.7785943,2.4774837,,,,,,,,,,,,,, -138100,1.5473989,3.3327193,,,,,,,,,,,,,, -138175,,,0.6922070384025574,1.2771286964416504,0.6366400122642517,1.529810905456543,50000.0,0.5169000029563904,2.163562297821045,10000.0,63055.25727057457,69165.89258766174,63055.25727057457,6095.903052806854,7.405869483947754,0.0 -138200,1.6748564,2.2820969,,,,,,,,,,,,,, -138300,1.670158,2.3341296,,,,,,,,,,,,,, -138400,1.5098673,2.8256454,,,,,,,,,,,,,, -138500,1.3662283,3.2596807,,,,,,,,,,,,,, -138600,1.8163021,2.241005,,,,,,,,,,,,,, -138700,1.4830214,4.5715523,,,,,,,,,,,,,, -138800,1.7018299,2.3367224,,,,,,,,,,,,,, -138900,1.4314294,4.9696674,,,,,,,,,,,,,, -139000,1.5059756,4.8951387,,,,,,,,,,,,,, -139094,,,0.6813281178474426,1.333216428756714,0.6382799744606018,1.5517253875732422,50000.0,0.5109000205993652,2.2069404125213623,10000.0,63475.19632101059,69625.84164237976,63475.19632101059,6135.814694881439,7.4559032917022705,0.0 -139100,1.7778455,2.2130792,,,,,,,,,,,,,, -139200,1.7332146,2.2243738,,,,,,,,,,,,,, -139300,1.8270589,2.311189,,,,,,,,,,,,,, -139400,1.9100978,2.3068082,,,,,,,,,,,,,, -139500,1.7703044,2.1357853,,,,,,,,,,,,,, -139600,1.7784874,2.2381546,,,,,,,,,,,,,, -139700,1.696775,2.2631276,,,,,,,,,,,,,, -139800,1.4141536,4.4674034,,,,,,,,,,,,,, -139900,1.3569177,3.6722448,,,,,,,,,,,,,, -140000,1.7793899,2.3209808,,,,,,,,,,,,,, -140016,,,0.6907812356948853,1.2872713804244995,0.6406999826431274,1.5119783878326416,50000.0,0.5188000202178955,2.150461196899414,10000.0,63895.5168299675,70082.7032134533,63895.5168299675,6172.251842260361,7.511298418045044,0.0 -140100,1.4350128,3.7637687,,,,,,,,,,,,,, -140200,1.7383301,2.4449162,,,,,,,,,,,,,, -140300,1.6671784,2.360004,,,,,,,,,,,,,, -140400,1.475175,4.7317047,,,,,,,,,,,,,, -140500,1.7780735,2.512598,,,,,,,,,,,,,, -140600,1.4350361,3.8683922,,,,,,,,,,,,,, -140700,1.4312571,4.775051,,,,,,,,,,,,,, -140800,1.6831449,2.4271977,,,,,,,,,,,,,, -140900,1.6642067,2.2840009,,,,,,,,,,,,,, -140938,,,0.6983007788658142,1.247533082962036,0.6419399976730347,1.5060182809829712,50000.0,0.5188000202178955,2.1491525173187256,10000.0,64315.69512438774,70544.94150233269,64315.69512438774,6214.205435991287,7.5685484409332275,0.0 -141000,1.7132865,2.3246064,,,,,,,,,,,,,, -141100,1.8060263,2.4160922,,,,,,,,,,,,,, -141200,1.5308938,2.8516784,,,,,,,,,,,,,, -141300,1.6688052,2.2036448,,,,,,,,,,,,,, -141400,1.505849,4.8212986,,,,,,,,,,,,,, -141500,1.7759145,2.2569284,,,,,,,,,,,,,, -141600,1.6187223,2.189209,,,,,,,,,,,,,, -141700,1.8471477,2.4067404,,,,,,,,,,,,,, -141800,1.69833,2.2532768,,,,,,,,,,,,,, -141862,,,0.7147851586341858,1.200818657875061,0.6396399736404419,1.5271742343902588,50000.0,0.5161000490188599,2.164371967315674,10000.0,64735.88982725144,71006.54262113571,64735.88982725144,6255.516871213913,7.614865064620972,0.0 -141900,1.7040802,2.2363002,,,,,,,,,,,,,, -142000,1.8424273,2.2530067,,,,,,,,,,,,,, -142100,1.4685855,3.2814198,,,,,,,,,,,,,, -142200,1.6208391,4.8225303,,,,,,,,,,,,,, -142300,1.636983,2.603828,,,,,,,,,,,,,, -142400,1.8341151,2.177742,,,,,,,,,,,,,, -142500,1.538622,4.204163,,,,,,,,,,,,,, -142600,2.2618232,4.6417694,,,,,,,,,,,,,, -142700,1.7181919,2.6886158,,,,,,,,,,,,,, -142784,,,0.6896093487739563,1.3375908136367798,0.6431199908256531,1.5491735935211182,50000.0,0.5197000503540039,2.190412759780884,10000.0,65156.19044685364,71464.98599791527,65156.19044685364,6293.555589437485,7.669762372970581,0.0 -142800,1.4600065,4.6924505,,,,,,,,,,,,,, -142900,1.6297035,2.9310927,,,,,,,,,,,,,, -143000,1.7308042,2.6612964,,,,,,,,,,,,,, -143100,1.5378196,4.603426,,,,,,,,,,,,,, -143200,1.6068734,3.3430965,,,,,,,,,,,,,, -143300,1.6712959,4.018328,,,,,,,,,,,,,, -143400,1.720754,2.3373199,,,,,,,,,,,,,, -143500,1.8910481,2.2096746,,,,,,,,,,,,,, -143600,1.9694765,2.1639407,,,,,,,,,,,,,, -143700,1.8469676,2.2649374,,,,,,,,,,,,,, -143705,,,0.7058203220367432,1.2415093183517456,0.6496599912643433,1.489941954612732,50000.0,0.525600016117096,2.1392199993133545,10000.0,65576.20358800888,71929.60329174995,65576.20358800888,6338.054198503494,7.726184368133545,0.0 -143800,1.63041,4.751976,,,,,,,,,,,,,, -143900,1.5683908,3.4752722,,,,,,,,,,,,,, -144000,1.7342826,4.7404814,,,,,,,,,,,,,, -144100,1.8147172,2.1127567,,,,,,,,,,,,,, -144200,1.6562754,3.183939,,,,,,,,,,,,,, -144300,1.799108,4.8164306,,,,,,,,,,,,,, -144400,1.778465,2.8045511,,,,,,,,,,,,,, -144500,1.5499873,3.823105,,,,,,,,,,,,,, -144600,1.65298,2.7900171,,,,,,,,,,,,,, -144624,,,0.7205273509025574,1.1597256660461426,0.6532999873161316,1.449857473373413,50000.0,0.5315000414848328,2.090836763381958,10000.0,65996.54626846313,72393.68397283554,65996.54626846313,6381.687472581863,7.781303644180298,0.0 -144700,1.4933091,3.2147477,,,,,,,,,,,,,, -144800,1.8551207,2.6298456,,,,,,,,,,,,,, -144900,1.896525,2.3147137,,,,,,,,,,,,,, -145000,1.869127,2.9388752,,,,,,,,,,,,,, -145100,1.8188727,2.4040413,,,,,,,,,,,,,, -145200,1.935987,2.156636,,,,,,,,,,,,,, -145300,1.5840279,4.6718216,,,,,,,,,,,,,, -145400,1.8850242,2.2533512,,,,,,,,,,,,,, -145500,1.805083,2.5204751,,,,,,,,,,,,,, -145543,,,0.7084179520606995,1.2068352699279783,0.6576600074768066,1.4416378736495972,50000.0,0.5302000045776367,2.0874135494232178,10000.0,66416.76606321335,72857.4061486721,66416.76606321335,6425.086839675903,7.835384845733643,0.0 -145600,1.7245139,3.2215302,,,,,,,,,,,,,, -145700,1.8601495,2.1345408,,,,,,,,,,,,,, -145800,1.8716315,2.3099854,,,,,,,,,,,,,, -145900,1.8006102,4.7032046,,,,,,,,,,,,,, -146000,1.937443,2.1430695,,,,,,,,,,,,,, -146100,2.0620615,2.1731164,,,,,,,,,,,,,, -146200,1.9416399,2.1698885,,,,,,,,,,,,,, -146300,1.9157803,2.0337205,,,,,,,,,,,,,, -146400,1.5628132,3.1362393,,,,,,,,,,,,,, -146464,,,0.7099999785423279,1.185759425163269,0.6562199592590332,1.433472990989685,50000.0,0.5349000096321106,2.0838398933410645,10000.0,66836.82033586502,73318.36433267593,66836.82033586502,6465.893684387207,7.883531093597412,0.0 -146500,2.1212134,4.7777014,,,,,,,,,,,,,, -146600,1.8368001,2.0929825,,,,,,,,,,,,,, -146700,1.999408,2.2151957,,,,,,,,,,,,,, -146800,1.858932,4.6971955,,,,,,,,,,,,,, -146900,1.7928946,4.111972,,,,,,,,,,,,,, -147000,1.6425611,2.9643376,,,,,,,,,,,,,, -147100,1.6368599,2.9491653,,,,,,,,,,,,,, -147200,1.8876032,2.1168807,,,,,,,,,,,,,, -147300,1.593946,4.113139,,,,,,,,,,,,,, -147388,,,0.71875,1.1505794525146484,0.6584999561309814,1.419907569885254,50000.0,0.5333000421524048,2.0673604011535645,10000.0,67257.22652721405,73778.35560202599,67257.22652721405,6505.38139629364,7.932077884674072,0.0 -147400,1.7380205,2.5104039,,,,,,,,,,,,,, -147500,1.7819862,4.1528673,,,,,,,,,,,,,, -147600,1.709091,4.3999486,,,,,,,,,,,,,, -147700,1.6827183,3.875571,,,,,,,,,,,,,, -147800,1.7341665,4.5248585,,,,,,,,,,,,,, -147900,1.9575807,2.0959036,,,,,,,,,,,,,, -148000,1.8772923,2.1745934,,,,,,,,,,,,,, -148100,1.729854,4.3352413,,,,,,,,,,,,,, -148200,2.0423005,2.2150996,,,,,,,,,,,,,, -148300,1.866824,2.0758927,,,,,,,,,,,,,, -148311,,,0.7202538847923279,1.1635055541992188,0.6650800108909607,1.4090675115585327,50000.0,0.5368000268936157,2.0515666007995605,10000.0,67677.38717126846,74240.42641305923,67677.38717126846,6547.186131954193,7.988610029220581,0.0 -148400,1.8402829,2.186917,,,,,,,,,,,,,, -148500,1.7906021,1.9578217,,,,,,,,,,,,,, -148600,1.6858908,3.9679368,,,,,,,,,,,,,, -148700,1.7338436,3.7901282,,,,,,,,,,,,,, -148800,1.7473398,3.5043519,,,,,,,,,,,,,, -148900,1.8261942,3.2376428,,,,,,,,,,,,,, -149000,1.7536938,4.266993,,,,,,,,,,,,,, -149100,1.9888453,2.1354296,,,,,,,,,,,,,, -149200,1.8986752,2.5644343,,,,,,,,,,,,,, -149233,,,0.7199999690055847,1.1488616466522217,0.6662200093269348,1.3948432207107544,50000.0,0.5412999987602234,2.040438652038574,10000.0,68097.53750610352,74699.47302079201,68097.53750610352,6585.98409485817,8.03798794746399,0.0 -149300,2.0284362,2.0912993,,,,,,,,,,,,,, -149400,2.1228256,2.2655668,,,,,,,,,,,,,, -149500,1.9934237,2.5009522,,,,,,,,,,,,,, -149600,1.8755033,1.9988601,,,,,,,,,,,,,, -149700,1.9558355,2.154824,,,,,,,,,,,,,, -149800,2.150663,2.0649989,,,,,,,,,,,,,, -149900,2.0467691,2.1986365,,,,,,,,,,,,,, -150000,1.8530463,4.33597,,,,,,,,,,,,,, -150100,2.0904477,2.0831614,,,,,,,,,,,,,, -150153,,,0.7326562404632568,1.1124690771102903,0.6711199879646301,1.3810205459594729,50000.0,0.547700047492981,2.0186476707458496,10000.0,68517.46277451515,75162.01907277107,68517.46277451515,6628.50556063652,8.088744640350342,0.0 -150200,1.9302018,4.196253,,,,,,,,,,,,,, -150300,1.8287065,4.4241405,,,,,,,,,,,,,, -150400,1.7164394,3.2288857,,,,,,,,,,,,,, -150500,2.032816,2.036312,,,,,,,,,,,,,, -150600,2.0117457,2.75988,,,,,,,,,,,,,, -150700,2.0239513,2.3799396,,,,,,,,,,,,,, -150800,2.0501313,2.0770216,,,,,,,,,,,,,, -150900,2.1149309,2.1608877,,,,,,,,,,,,,, -151000,1.7976682,2.783251,,,,,,,,,,,,,, -151075,,,0.7326562404632568,1.1057353019714355,0.673039972782135,1.3807213306427002,50000.0,0.5461000204086304,2.029315233230591,10000.0,68937.43747639656,75620.98027873039,68937.43747639656,6667.386779785156,8.145340204238892,0.0 -151100,1.8377657,2.6251116,,,,,,,,,,,,,, -151200,2.26397,2.0084126,,,,,,,,,,,,,, -151300,2.1391857,2.0372505,,,,,,,,,,,,,, -151400,2.1553233,2.230249,,,,,,,,,,,,,, -151500,1.9139545,4.226973,,,,,,,,,,,,,, -151600,2.0902352,2.09134,,,,,,,,,,,,,, -151700,2.1829584,2.210122,,,,,,,,,,,,,, -151800,1.8686774,3.6289635,,,,,,,,,,,,,, -151900,2.1405618,4.2956405,,,,,,,,,,,,,, -151999,,,0.7305663824081421,1.1032029390335083,0.6752399802207947,1.345995306968689,50000.0,0.5471000075340271,1.9906830787658687,10000.0,69357.67158484459,76083.28164362907,69357.67158484459,6709.3519904613495,8.198824405670166,0.0 -152000,1.8529009,3.0649624,,,,,,,,,,,,,, -152100,1.9740015,2.0050507,,,,,,,,,,,,,, -152200,1.7646848,3.1488478,,,,,,,,,,,,,, -152300,2.0409606,4.201487,,,,,,,,,,,,,, -152400,2.1423407,4.3034935,,,,,,,,,,,,,, -152500,2.290858,2.1147268,,,,,,,,,,,,,, -152600,2.2204163,2.171442,,,,,,,,,,,,,, -152700,1.9115965,2.1789098,,,,,,,,,,,,,, -152800,2.1026983,2.4708335,,,,,,,,,,,,,, -152900,2.077436,2.1438372,,,,,,,,,,,,,, -152922,,,0.7342382669448853,1.077256202697754,0.6779199838638306,1.3336390256881714,50000.0,0.5550000071525574,1.967795968055725,10000.0,69777.8404636383,76541.56018471718,69777.8404636383,6747.357416629791,8.25408148765564,0.0 -153000,2.0504067,2.0332656,,,,,,,,,,,,,, -153100,1.9627819,3.0432563,,,,,,,,,,,,,, -153200,2.1643555,3.9427462,,,,,,,,,,,,,, -153300,1.8316325,2.848432,,,,,,,,,,,,,, -153400,1.8643327,3.1836514,,,,,,,,,,,,,, -153500,1.9282559,3.05667,,,,,,,,,,,,,, -153600,2.1968002,2.01004,,,,,,,,,,,,,, -153700,2.2000208,1.9826564,,,,,,,,,,,,,, -153800,2.3293493,1.9337833,,,,,,,,,,,,,, -153844,,,0.755175769329071,0.9857242107391356,0.6827999949455261,1.3085020780563354,50000.0,0.5613000392913818,1.933016657829285,10000.0,70197.90728449821,77002.52214980125,70197.90728449821,6788.15131855011,8.30648159980774,0.0 -153900,2.0076423,2.2755346,,,,,,,,,,,,,, -154000,2.0544813,1.9732826,,,,,,,,,,,,,, -154100,2.000024,1.988837,,,,,,,,,,,,,, -154200,2.4561894,2.0683937,,,,,,,,,,,,,, -154300,2.3609266,2.1258538,,,,,,,,,,,,,, -154400,2.0084357,2.4754617,,,,,,,,,,,,,, -154500,1.9584078,4.225456,,,,,,,,,,,,,, -154600,2.2332416,2.0996726,,,,,,,,,,,,,, -154700,2.0617392,1.8561833,,,,,,,,,,,,,, -154766,,,0.7420898079872131,1.0481141805648804,0.6864399909973145,1.2994166612625122,50000.0,0.5594000220298767,1.937770247459412,10000.0,70618.0102212429,77463.61097240448,70618.0102212429,6829.035629749298,8.359469890594482,0.0 -154800,2.0399334,2.5457015,,,,,,,,,,,,,, -154900,2.179543,2.4144146,,,,,,,,,,,,,, -155000,2.1658142,4.444418,,,,,,,,,,,,,, -155100,2.252373,1.8992714,,,,,,,,,,,,,, -155200,2.07128,2.2399669,,,,,,,,,,,,,, -155300,2.2764037,1.8701594,,,,,,,,,,,,,, -155400,1.9125541,3.0757742,,,,,,,,,,,,,, -155500,2.4047332,1.9949473,,,,,,,,,,,,,, -155600,1.9410366,3.1704187,,,,,,,,,,,,,, -155688,,,0.7477148175239563,1.0200145244598389,0.6866199970245361,1.2940551042556765,50000.0,0.563800036907196,1.915860891342163,10000.0,71038.28627920151,77925.08789467812,71038.28627920151,6870.133940458298,8.413815975189209,0.0 -155700,2.4574628,2.0471058,,,,,,,,,,,,,, -155800,2.0327697,3.6993227,,,,,,,,,,,,,, -155900,1.8886883,2.922522,,,,,,,,,,,,,, -156000,2.2263815,2.0318055,,,,,,,,,,,,,, -156100,2.1678596,3.736506,,,,,,,,,,,,,, -156200,2.0669005,3.1817436,,,,,,,,,,,,,, -156300,1.9592674,3.2719607,,,,,,,,,,,,,, -156400,1.9681119,3.27948,,,,,,,,,,,,,, -156500,2.1955981,2.1818151,,,,,,,,,,,,,, -156600,2.2792006,1.9672649,,,,,,,,,,,,,, -156608,,,0.7582226395606995,0.9849535822868348,0.6900399923324585,1.2790842056274414,50000.0,0.5662000179290771,1.920639157295227,10000.0,71458.23702192307,78387.6969614029,71458.23702192307,6912.691474199295,8.466750383377075,0.0 -156700,2.1199226,2.0985317,,,,,,,,,,,,,, -156800,2.369357,1.9913006,,,,,,,,,,,,,, -156900,1.9328455,3.981201,,,,,,,,,,,,,, -157000,1.9611527,2.8238032,,,,,,,,,,,,,, -157100,2.4364862,1.9731228,,,,,,,,,,,,,, -157200,2.6197221,2.0782633,,,,,,,,,,,,,, -157300,2.2719765,2.0062575,,,,,,,,,,,,,, -157400,2.1152494,2.3105943,,,,,,,,,,,,,, -157500,2.4355946,1.9183463,,,,,,,,,,,,,, -157528,,,0.7495703101158142,1.0092881917953491,0.6941199898719788,1.263446569442749,50000.0,0.5696000456809998,1.915724515914917,10000.0,71878.5546181202,78851.75995445251,71878.5546181202,6956.333927631378,8.52064037322998,0.0 -157600,2.1128063,2.3046525,,,,,,,,,,,,,, -157700,2.30365,1.8930168,,,,,,,,,,,,,, -157800,2.3969824,1.9744182,,,,,,,,,,,,,, -157900,2.2710624,2.1279092,,,,,,,,,,,,,, -158000,2.3451824,2.1106274,,,,,,,,,,,,,, -158100,2.4473305,1.8991199,,,,,,,,,,,,,, -158200,2.3060775,2.1346948,,,,,,,,,,,,,, -158300,2.2576537,2.1850593,,,,,,,,,,,,,, -158400,2.22631,2.32343,,,,,,,,,,,,,, -158452,,,0.7608984112739563,0.9779459238052368,0.6959199905395508,1.251377820968628,50000.0,0.5752000212669373,1.885601282119751,10000.0,72298.76634907722,79316.75279378891,72298.76634907722,7001.015208482742,8.571056127548218,0.0 -158500,2.4484527,1.9137886,,,,,,,,,,,,,, -158600,2.4539292,1.9774761,,,,,,,,,,,,,, -158700,2.2886202,2.4843547,,,,,,,,,,,,,, -158800,2.2996044,2.695665,,,,,,,,,,,,,, -158900,2.3915753,1.8581135,,,,,,,,,,,,,, -159000,2.1754594,3.2204123,,,,,,,,,,,,,, -159100,2.2260857,3.580957,,,,,,,,,,,,,, -159200,2.3224876,3.6326156,,,,,,,,,,,,,, -159300,2.2967677,1.8611163,,,,,,,,,,,,,, -159372,,,0.7640234231948853,0.9492297768592834,0.6974599957466125,1.241568922996521,50000.0,0.5725000500679016,1.8790431022644043,10000.0,72719.03267765045,79778.45223927498,72719.03267765045,7042.348134994507,8.622626304626465,0.0 -159400,2.170144,3.8037891,,,,,,,,,,,,,, -159500,2.2427273,2.4669607,,,,,,,,,,,,,, -159600,2.5198123,1.830071,,,,,,,,,,,,,, -159700,2.3746097,2.110302,,,,,,,,,,,,,, -159800,2.67451,1.8684814,,,,,,,,,,,,,, -159900,2.5148063,1.9520067,,,,,,,,,,,,,, -160000,2.57744,2.1069307,,,,,,,,,,,,,, -160100,2.5211859,1.9209718,,,,,,,,,,,,,, -160200,2.281677,2.1867766,,,,,,,,,,,,,, -160296,,,0.7643359303474426,0.9486043453216552,0.703000009059906,1.21571147441864,50000.0,0.5748000144958496,1.8601552248001096,10000.0,73139.28754115105,80239.31615614891,73139.28754115105,7082.851930856705,8.679443597793579,0.0 -160300,2.469796,1.9297806,,,,,,,,,,,,,, -160400,2.5879653,1.9844508,,,,,,,,,,,,,, -160500,2.3893132,3.5918717,,,,,,,,,,,,,, -160600,2.271126,3.5229726,,,,,,,,,,,,,, -160700,2.2956305,3.132323,,,,,,,,,,,,,, -160800,2.6831474,3.8562505,,,,,,,,,,,,,, -160900,2.288559,3.9122765,,,,,,,,,,,,,, -161000,2.5461354,2.235147,,,,,,,,,,,,,, -161100,2.6284933,1.8931504,,,,,,,,,,,,,, -161200,2.3799605,3.2393515,,,,,,,,,,,,,, -161219,,,0.763378918170929,0.9769259691238404,0.7021999955177307,1.2391090393066406,50000.0,0.5754000544548035,1.8825575113296509,10000.0,73559.51320672035,80701.84395289421,73559.51320672035,7125.049999952316,8.735102891921997,0.0 -161300,2.467727,1.9025127,,,,,,,,,,,,,, -161400,2.4767513,1.8406355,,,,,,,,,,,,,, -161500,2.4155064,2.3539202,,,,,,,,,,,,,, -161600,2.4313507,1.8270754,,,,,,,,,,,,,, -161700,2.3724387,2.4050612,,,,,,,,,,,,,, -161800,2.4327142,4.298862,,,,,,,,,,,,,, -161900,2.5575516,2.1194243,,,,,,,,,,,,,, -162000,2.2884965,2.6207454,,,,,,,,,,,,,, -162100,2.3930345,3.618421,,,,,,,,,,,,,, -162142,,,0.7700781226158142,0.9373509287834167,0.7052599787712097,1.2130261659622192,50000.0,0.579300045967102,1.8460499048233032,10000.0,73979.57974791527,81164.14709663391,73979.57974791527,7167.17483496666,8.798385381698608,0.0 -162200,2.6666884,1.9772091,,,,,,,,,,,,,, -162300,2.5901008,1.8452121,,,,,,,,,,,,,, -162400,2.4838955,1.7218003,,,,,,,,,,,,,, -162500,2.451184,1.7200596,,,,,,,,,,,,,, -162600,2.3893838,2.9606586,,,,,,,,,,,,,, -162700,2.4350371,2.057278,,,,,,,,,,,,,, -162800,2.575477,1.7961801,,,,,,,,,,,,,, -162900,2.481702,4.289983,,,,,,,,,,,,,, -163000,2.3243732,2.4241254,,,,,,,,,,,,,, -163064,,,0.7858203053474426,0.8678929209709167,0.7084400057792664,1.2065235376358032,50000.0,0.5833000540733337,1.832371592521668,10000.0,74399.50149416924,81629.3094651699,74399.50149416924,7212.307286977768,8.858510732650757,0.0 -163100,2.3399167,3.0712807,,,,,,,,,,,,,, -163200,2.650847,1.9771104,,,,,,,,,,,,,, -163300,2.8241715,1.84609,,,,,,,,,,,,,, -163400,2.8404555,1.8667828,,,,,,,,,,,,,, -163500,2.7838395,1.7908455,,,,,,,,,,,,,, -163600,2.5920348,1.7472687,,,,,,,,,,,,,, -163700,2.7656221,4.16846,,,,,,,,,,,,,, -163800,2.3806295,2.7877169,,,,,,,,,,,,,, -163900,2.6718051,1.8152364,,,,,,,,,,,,,, -163987,,,0.7753320336341858,0.9123504161834716,0.7140399813652039,1.184422492980957,50000.0,0.5903000235557556,1.815895795822144,10000.0,74819.62465143204,82088.27244186401,74819.62465143204,7251.045498132706,8.912059307098389,0.0 -164000,2.7510376,1.7157527,,,,,,,,,,,,,, -164100,2.7633219,1.877255,,,,,,,,,,,,,, -164200,2.6927917,1.9005914,,,,,,,,,,,,,, -164300,2.7842338,1.8728364,,,,,,,,,,,,,, -164400,2.6565027,3.9928102,,,,,,,,,,,,,, -164500,2.8087354,1.8960568,,,,,,,,,,,,,, -164600,2.99345,4.0498214,,,,,,,,,,,,,, -164700,2.637167,1.7273798,,,,,,,,,,,,,, -164800,2.8289745,3.5318942,,,,,,,,,,,,,, -164900,2.7311897,1.7646828,,,,,,,,,,,,,, -164909,,,0.7803320288658142,0.8811646103858948,0.7166399955749512,1.1691982746124268,50000.0,0.5915000438690186,1.7959223985671997,10000.0,75239.51298332214,82548.98171710968,75239.51298332214,7291.754102945328,8.975491523742676,0.0 -165000,2.4434068,3.1486473,,,,,,,,,,,,,, -165100,2.614224,1.7915722,,,,,,,,,,,,,, -165200,2.7282045,1.7896429,,,,,,,,,,,,,, -165300,2.417757,3.4840553,,,,,,,,,,,,,, -165400,2.5112495,2.6643846,,,,,,,,,,,,,, -165500,2.592555,2.0085866,,,,,,,,,,,,,, -165600,2.7936165,2.337101,,,,,,,,,,,,,, -165700,2.7280684,4.2864637,,,,,,,,,,,,,, -165800,2.4560382,2.8052263,,,,,,,,,,,,,, -165831,,,0.788378894329071,0.8617101907730103,0.7175999879837036,1.1666576862335205,50000.0,0.5896000266075134,1.802293419837952,10000.0,75659.53583550453,83013.09928798676,75659.53583550453,7335.747500419617,9.027804374694824,0.0 -165900,3.063157,1.8809786,,,,,,,,,,,,,, -166000,2.7284386,3.9088497,,,,,,,,,,,,,, -166100,2.8412418,2.0038571,,,,,,,,,,,,,, -166200,2.6567369,3.4484713,,,,,,,,,,,,,, -166300,2.6721218,1.758626,,,,,,,,,,,,,, -166400,2.8046942,2.934549,,,,,,,,,,,,,, -166500,2.6813624,3.5073307,,,,,,,,,,,,,, -166600,2.7552829,3.8153796,,,,,,,,,,,,,, -166700,2.782726,1.7955706,,,,,,,,,,,,,, -166751,,,0.7841405868530273,0.869185209274292,0.7195799946784973,1.1540203094482422,50000.0,0.5955000519752502,1.7816680669784546,10000.0,76079.45230317116,83473.17249202728,76079.45230317116,7375.801503419876,9.081603765487673,0.0 -166800,2.8702846,4.2771387,,,,,,,,,,,,,, -166900,2.6582146,1.7969371,,,,,,,,,,,,,, -167000,3.0300167,1.7211517,,,,,,,,,,,,,, -167100,2.6829262,3.310674,,,,,,,,,,,,,, -167200,2.830964,1.6764305,,,,,,,,,,,,,, -167300,3.278455,1.8131756,,,,,,,,,,,,,, -167400,2.7016196,1.5950531,,,,,,,,,,,,,, -167500,2.8571954,2.1989474,,,,,,,,,,,,,, -167600,3.0061202,1.808112,,,,,,,,,,,,,, -167674,,,0.7900976538658142,0.8440901041030884,0.7210599780082703,1.1322126388549805,50000.0,0.5976999998092651,1.7546520233154297,10000.0,76499.65189146996,83934.78953242302,76499.65189146996,7417.107574224472,9.143556594848633,0.0 -167700,3.1996996,4.132632,,,,,,,,,,,,,, -167800,3.0004606,1.8159108,,,,,,,,,,,,,, -167900,2.4569366,2.7791855,,,,,,,,,,,,,, -168000,2.9150968,1.6877317,,,,,,,,,,,,,, -168100,2.848478,1.656286,,,,,,,,,,,,,, -168200,2.9382436,1.7630076,,,,,,,,,,,,,, -168300,2.8632512,1.9954437,,,,,,,,,,,,,, -168400,2.8460631,2.179005,,,,,,,,,,,,,, -168500,2.9048052,1.6593455,,,,,,,,,,,,,, -168597,,,0.7947070002555847,0.8163579106330872,0.7242000102996826,1.1301054954528809,50000.0,0.6009000539779663,1.7406065464019775,10000.0,76919.8663828373,84398.5364639759,76919.8663828373,7460.537466287613,9.19641399383545,0.0 -168600,2.8333306,1.9848583,,,,,,,,,,,,,, -168700,3.080619,1.7005336,,,,,,,,,,,,,, -168800,2.8428473,3.115286,,,,,,,,,,,,,, -168900,3.1141064,4.134634,,,,,,,,,,,,,, -169000,2.887319,3.3786478,,,,,,,,,,,,,, -169100,2.9178813,1.7180877,,,,,,,,,,,,,, -169200,2.7806334,1.6324717,,,,,,,,,,,,,, -169300,2.6753488,2.700599,,,,,,,,,,,,,, -169400,2.8799162,1.6166155,,,,,,,,,,,,,, -169500,3.0294776,3.3469112,,,,,,,,,,,,,, -169519,,,0.7934960722923279,0.8324291706085205,0.725820004940033,1.1191824674606323,50000.0,0.602400004863739,1.7409383058547974,10000.0,77339.84844470024,84861.64176750183,77339.84844470024,7503.553560256958,9.254878997802734,0.0 -169600,2.7647853,2.1878436,,,,,,,,,,,,,, -169700,2.6832922,1.9035865,,,,,,,,,,,,,, -169800,2.9250576,1.7831861,,,,,,,,,,,,,, -169900,3.0676959,1.6354208,,,,,,,,,,,,,, -169921,,,,,,,,,,,77520.14989256859,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/eval_measurements.csv deleted file mode 100644 index c7355adcd..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -28.80561327934265,0.0,39.67635035514832,1,0,39.67635035514832,0.0010000000474974,6.907756805419922,10000,68.48205900192261,0.0009960937313735,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -69.23552632331848,0.0189437866210937,460.0015518665314,851,0,460.0015518665314,0.0230000019073486,6.076109886169434,10000,529.3016893863678,0.034257810562849,5.893809795379639,0.0277399998158216,5.958045482635498,50000 -113.24756622314452,0.0433447360992431,879.9417262077332,1762,0,879.9417262077332,0.0534000024199485,5.5414252281188965,10000,993.3262569904329,0.0724414065480232,5.277771472930908,0.0684999972581863,5.329521656036377,50000 -156.88109421730042,0.0699703693389892,1300.3221225738523,2680,0,1300.3221225738523,0.0815000012516975,5.237037658691406,10000,1457.4158778190613,0.1116406247019767,4.909753799438477,0.1025599986314773,4.962595939636231,50000 -197.1125226020813,0.0976169109344482,1720.324553012848,3597,0,1720.324553012848,0.1203000023961067,4.7625908851623535,10000,1917.725405454636,0.1716992110013961,4.300407409667969,0.1581199914216995,4.405792236328125,50000 -242.0457983016968,0.1311566829681396,2140.311147928238,4512,0,2140.311147928238,0.1652000099420547,4.379038333892822,10000,2382.7273893356323,0.2356249988079071,3.8305723667144775,0.2140599936246872,3.9363467693328857,50000 -285.71435475349426,0.1571612358093261,2560.448935985565,5427,0,2560.448935985565,0.1862000077962875,4.206192493438721,10000,2846.608163833618,0.264453113079071,3.639159202575684,0.2471199929714203,3.742400884628296,50000 -331.25124192237854,0.1834251880645752,2980.792924880981,6344,0,2980.792924880981,0.2149000167846679,4.034944534301758,10000,3312.563648700714,0.3071679472923279,3.325169086456299,0.2760599851608276,3.493992328643799,50000 -370.6109387874603,0.2109639644622802,3401.106611251831,7263,0,3401.106611251831,0.2467000186443328,3.774560689926148,10000,3772.312874078751,0.34326171875,3.0648884773254395,0.3183799982070923,3.202389717102051,50000 -416.5219187736511,0.2373898029327392,3821.290620088577,8179,0,3821.290620088577,0.2595000267028808,3.666433095932007,10000,4238.481454372406,0.3716796934604645,2.930692672729492,0.3402999937534332,3.099444627761841,50000 -456.7812805175781,0.2679200172424316,4241.387323856354,9097,0,4241.387323856354,0.285500019788742,3.554089307785034,10000,4698.91641163826,0.3936523497104645,2.7883851528167725,0.3587999939918518,2.984739303588867,50000 -495.365522146225,0.2954919338226318,4661.656677007675,10015,0,4661.656677007675,0.2871000170707702,3.4907071590423584,10000,5157.845973968506,0.4039062261581421,2.737269163131714,0.3750399947166443,2.905771493911743,50000 -537.7116749286652,0.320845365524292,5082.045745134354,10934,0,5082.045745134354,0.2964000105857849,3.415822029113769,10000,5620.655026435852,0.423164039850235,2.606679677963257,0.3898399770259857,2.784226894378662,50000 -580.7291917800903,0.3462169170379638,5502.399238586426,11853,0,5502.399238586426,0.3145000040531158,3.317158460617065,10000,6084.099897861481,0.4403710961341858,2.523897886276245,0.4095799922943115,2.7042064666748047,50000 -624.3629055023193,0.3769981861114502,5922.57227897644,12770,0,5922.57227897644,0.3165000081062317,3.3088340759277344,10000,6547.986525058746,0.4703515470027923,2.3473031520843506,0.4133599996566772,2.656049966812134,50000 -667.0851843357086,0.4046244621276855,6342.573100805283,13689,0,6342.573100805283,0.3374000191688537,3.1996219158172607,10000,7010.78583407402,0.4599999785423279,2.3901679515838623,0.4278199970722198,2.5512709617614746,50000 -711.7846443653107,0.4368345737457275,6762.732401609421,14606,0,6762.732401609421,0.3416000306606293,3.1861348152160645,10000,7475.724648714066,0.4755273461341858,2.356476306915283,0.4366599917411804,2.5436606407165527,50000 -755.9537699222565,0.4638805389404297,7182.700782775879,15523,0,7182.700782775879,0.3447000086307525,3.096624851226806,10000,7939.937408447266,0.4975781142711639,2.191378593444824,0.4444399774074554,2.4651525020599365,50000 -800.8106706142426,0.4935455322265625,7602.853166103363,16441,0,7602.853166103363,0.3492000102996826,3.093480348587036,10000,8405.024823188782,0.486152321100235,2.2722318172454834,0.4551399946212768,2.4447381496429443,50000 -841.0064516067505,0.5205333232879639,8023.229565858841,17359,0,8023.229565858841,0.3612000048160553,3.026905298233032,10000,8865.672659635544,0.5013867020606995,2.1820547580718994,0.4605799913406372,2.391518831253052,50000 -885.2275066375732,0.5578651428222656,8443.607513904572,18273,0,8443.607513904572,0.3623000085353851,3.014296293258667,10000,9330.35705280304,0.5225195288658142,2.098334789276123,0.4728599786758423,2.336740732192993,50000 -929.2668771743774,0.5929629802703857,8863.76538324356,19190,0,8863.76538324356,0.3705000281333923,2.9391238689422607,10000,9794.63821029663,0.5135741829872131,2.104053258895874,0.4769399762153625,2.300676822662353,50000 -969.891283750534,0.623236894607544,9283.831866025925,20105,0,9283.831866025925,0.3801000118255615,2.9380943775177,10000,10255.40760755539,0.5260156393051147,2.0405845642089844,0.4841799736022949,2.2629568576812744,50000 -1013.2559487819672,0.6527237892150879,9703.953919887545,21022,0,9703.953919887545,0.3781000077724457,2.9239919185638428,10000,10718.971504211426,0.53369140625,2.0299012660980225,0.4887999892234802,2.2722764015197754,50000 -1054.5985553264618,0.6823267936706543,10124.150616884232,21939,0,10124.150616884232,0.3904000222682953,2.829202175140381,10000,11180.589293718338,0.5484570264816284,1.911619544029236,0.5009399652481079,2.1616222858428955,50000 -1098.3198697566986,0.7099740505218506,10544.085697889328,22852,0,10544.085697889328,0.3939000070095062,2.801138162612915,10000,11644.321509361269,0.5477929711341858,1.9292596578598025,0.5063599944114685,2.1407365798950195,50000 -1138.9947321414948,0.7432305812835693,10964.31832265854,23768,0,10964.31832265854,0.3993000090122223,2.8055648803710938,10000,12105.310588121414,0.5549609065055847,1.911234259605408,0.5130199790000916,2.142449378967285,50000 -1181.7955298423767,0.773090124130249,11384.320001840591,24687,0,11384.320001840591,0.3990000188350677,2.804219961166382,10000,12568.191683530807,0.5798242092132568,1.821606397628784,0.5144199728965759,2.138012886047364,50000 -1224.8035056591034,0.8047606945037842,11804.5945789814,25606,0,11804.5945789814,0.4036000072956085,2.8055355548858643,10000,13031.554756641388,0.5553905963897705,1.961132287979126,0.5149999856948853,2.1607208251953125,50000 -1268.5251622200012,0.8329160213470459,12224.88507938385,26524,0,12224.88507938385,0.4075000286102295,2.782655477523804,10000,13495.64342713356,0.5643359422683716,1.889811038970948,0.5171599984169006,2.1188228130340576,50000 -1311.8416454792025,0.8607838153839111,12644.999958276749,27444,0,12644.999958276749,0.4106000065803528,2.772930145263672,10000,13959.151177167892,0.5781835913658142,1.8442374467849727,0.5189599990844727,2.119422197341919,50000 -1357.0431625843048,0.895582914352417,13065.262185811996,28361,0,13065.262185811996,0.4104000329971313,2.7521159648895264,10000,14424.698271989822,0.5642968416213989,1.883542776107788,0.5241999626159668,2.089167356491089,50000 -1400.45316529274,0.9288933277130128,13485.303869009018,29278,0,13485.303869009018,0.424200028181076,2.681459426879883,10000,14888.231466531754,0.5771288871765137,1.784113883972168,0.5302799940109253,2.028592348098755,50000 -1445.575161933899,0.9585323333740234,13905.338715076448,30193,0,13905.338715076448,0.4198000133037567,2.670804500579834,10000,15353.465184688568,0.5874413847923279,1.7344344854354858,0.5366799831390381,1.992333173751831,50000 -1490.0121433734894,0.9918410778045654,14325.543489933014,31110,0,14325.543489933014,0.4220000207424164,2.683004379272461,10000,15818.188180923462,0.5776171684265137,1.811983942985535,0.5379199981689453,2.004590034484864,50000 -1533.8274431228638,1.0221738815307615,14745.545434951782,32027,0,14745.545434951782,0.4261000156402588,2.6556942462921143,10000,16282.084006547928,0.5791406035423279,1.7816259860992432,0.5400199890136719,1.9844132661819456,50000 -1578.6870546340942,1.0637776851654053,15165.510427713394,32942,0,15165.510427713394,0.428600013256073,2.61394476890564,10000,16746.997824668884,0.5919530987739563,1.684620022773743,0.5460000038146973,1.936972737312317,50000 -1622.3610565662384,1.1009063720703125,15585.75400686264,33857,0,15585.75400686264,0.4273000061511993,2.6944146156311035,10000,17211.00025987625,0.5914843678474426,1.794201374053955,0.5385199785232544,2.034351110458374,50000 -1666.5135188102722,1.1385252475738523,16005.86516070366,34774,0,16005.86516070366,0.4339000284671783,2.622077703475952,10000,17675.351983308792,0.5895116925239563,1.737478733062744,0.5465199947357178,1.960172414779663,50000 -1712.2519705295565,1.1743512153625488,16425.89613389969,35690,0,16425.89613389969,0.434000015258789,2.628279209136963,10000,18141.204838514328,0.5945702791213989,1.753453493118286,0.5510799884796143,1.965092182159424,50000 -1758.4259514808657,1.2110042572021484,16846.11967420578,36606,0,16846.11967420578,0.4409000277519226,2.563768148422241,10000,18607.68721485138,0.6235546469688416,1.566106200218201,0.553059995174408,1.8976249694824217,50000 -1803.445976257324,1.2457661628723145,17266.167387723923,37521,0,17266.167387723923,0.4422000348567962,2.561940670013428,10000,19072.837619304657,0.5977148413658142,1.7096517086029053,0.5522400140762329,1.920454382896424,50000 -1849.2419004440308,1.2778377532958984,17686.52992963791,38438,0,17686.52992963791,0.4454000294208526,2.5424721240997314,10000,19539.07708930969,0.602343738079071,1.6719639301300049,0.5600799918174744,1.897687554359436,50000 -1894.6081516742704,1.3122048377990725,18106.487027406693,39354,0,18106.487027406693,0.4482000172138214,2.514027833938598,10000,20004.48260617256,0.6241015195846558,1.555986762046814,0.5655399560928345,1.844379186630249,50000 -1939.2537033557887,1.3499047756195068,18526.620589256287,40270,0,18526.620589256287,0.4452000260353088,2.520315647125244,10000,20469.3478808403,0.6108984351158142,1.6063635349273682,0.56277996301651,1.844841718673706,50000 -1983.9678659439087,1.3874144554138184,18946.871083021164,41184,0,18946.871083021164,0.4467000067234039,2.5229926109313965,10000,20934.40208363533,0.6096289157867432,1.6241741180419922,0.5638999938964844,1.85176420211792,50000 -2029.7209577560425,1.4249577522277832,19367.147775888443,42094,0,19367.147775888443,0.4445000290870666,2.5553810596466064,10000,21400.516935825348,0.6166796684265137,1.6326512098312378,0.5676999688148499,1.8969939947128296,50000 -2073.206840276718,1.4591336250305176,19787.37788057328,43004,0,19787.37788057328,0.4534000158309936,2.4737699031829834,10000,21864.3144903183,0.6131054759025574,1.5964422225952148,0.5736199617385864,1.7933223247528076,50000 -2119.591502904892,1.496941328048706,20207.759017944336,43923,0,20207.759017944336,0.4507000148296356,2.493374824523926,10000,22331.16711997986,0.6158202886581421,1.6060982942581177,0.5713199973106384,1.8300766944885247,50000 -2164.257128477097,1.5289440155029297,20628.109059095383,44841,0,20628.109059095383,0.4593000113964081,2.492283821105957,10000,22796.262968063354,0.6277539134025574,1.5462582111358645,0.5756999850273132,1.8153027296066284,50000 -2209.939630508423,1.5598695278167725,21048.3162214756,45758,0,21048.3162214756,0.4622000157833099,2.4568777084350586,10000,23262.23146867752,0.6309570074081421,1.5428788661956787,0.5822799801826477,1.779966950416565,50000 -2252.3399090766907,1.5934512615203855,21468.28194952011,46675,0,21468.28194952011,0.4599000215530395,2.479991912841797,10000,23724.68097090721,0.6208202838897705,1.5916742086410522,0.574679970741272,1.8139941692352293,50000 -2296.0664880275726,1.631007432937622,21888.7586414814,47590,0,21888.7586414814,0.4625000357627868,2.438570737838745,10000,24188.97009658813,0.6354101300239563,1.507873773574829,0.5806800127029419,1.768570065498352,50000 -2340.759823322296,1.6684105396270752,22309.00814270973,48509,0,22309.00814270973,0.4689000248908996,2.4309449195861816,10000,24653.99893283844,0.6536523103713989,1.431609869003296,0.5843600034713745,1.7661563158035278,50000 -2387.934749364853,1.710059642791748,22729.02753067017,49425,0,22729.02753067017,0.4668000340461731,2.44736123085022,10000,25121.2830452919,0.6244921684265137,1.5902440547943115,0.5818600058555603,1.7915534973144531,50000 -2432.709528207779,1.7498483657836914,23149.305313825607,50342,0,23149.305313825607,0.4702000319957733,2.4052207469940186,10000,25586.42380213737,0.6416601538658142,1.4782509803771973,0.5866000056266785,1.7392513751983645,50000 -2477.1309113502502,1.788116216659546,23569.449481725693,51259,0,23569.449481725693,0.4697000086307525,2.399967670440674,10000,26051.07665014267,0.6513866782188416,1.4301273822784424,0.5900200009346008,1.729769229888916,50000 -2522.665075063705,1.8231792449951167,23989.585379123688,52176,0,23989.585379123688,0.460500031709671,2.463486433029175,10000,26516.830800056458,0.6253125071525574,1.5959267616271973,0.5811799764633179,1.8141032457351685,50000 -2567.257641553879,1.857563018798828,24409.85411667824,53091,0,24409.85411667824,0.4796000123023987,2.35530948638916,10000,26981.77447628975,0.6471874713897705,1.444928526878357,0.5928999781608582,1.692473292350769,50000 -2610.7226634025574,1.9003996849060056,24830.11319756508,54008,0,24830.11319756508,0.4711000323295593,2.407909393310547,10000,27445.589923620224,0.6502734422683716,1.464459776878357,0.5907399654388428,1.7406047582626345,50000 -2656.829050302505,1.9344263076782229,25250.349819660187,54927,0,25250.349819660187,0.4720000326633453,2.425620555877685,10000,27912.01579046249,0.6294726133346558,1.5730618238449097,0.5880399942398071,1.7772213220596311,50000 -2700.1080725193024,1.9745028018951416,25670.50795698166,55845,0,25670.50795698166,0.4745000302791595,2.4103524684906006,10000,28375.542023181915,0.6401562094688416,1.5094945430755615,0.5906800031661987,1.7373265027999878,50000 -2742.359834194184,2.0240936279296875,26090.748301029205,56763,0,26090.748301029205,0.4695000350475311,2.4129831790924072,10000,28838.13270163536,0.646484375,1.4634003639221191,0.5952399969100952,1.7165286540985107,50000 -2787.934552669525,2.0629003047943115,26510.98332166672,57682,0,26510.98332166672,0.4797000288963318,2.3530988693237305,10000,29304.03017401696,0.6556054353713989,1.399447321891785,0.599299967288971,1.666598558425903,50000 -2834.242516517639,2.105926275253296,26931.549178361893,58601,0,26931.549178361893,0.4820000231266022,2.3680901527404785,10000,29770.996066331863,0.6381640434265137,1.4684827327728271,0.5956400036811829,1.700318694114685,50000 -2877.9978761672974,2.141942977905273,27351.7313952446,59519,0,27351.7313952446,0.4812000095844269,2.367074728012085,10000,30235.01827669144,0.6510156393051147,1.446966528892517,0.5954599976539612,1.7019143104553225,50000 -2924.465174674988,2.1778712272644043,27772.0975227356,60437,0,27772.0975227356,0.4795000255107879,2.367875814437866,10000,30701.9356341362,0.6662304401397705,1.356255054473877,0.5960400104522705,1.6892080307006836,50000 -2963.984852075577,2.220782995223999,28192.18399953842,61357,0,28192.18399953842,0.4777000248432159,2.391491651535034,10000,31161.63326120377,0.6473046541213989,1.4790903329849243,0.5974400043487549,1.7262574434280396,50000 -3009.6463346481323,2.254441261291504,28612.51757788658,62274,0,28612.51757788658,0.4845000207424164,2.359397649765014,10000,31627.710637569427,0.6483398079872131,1.4805115461349487,0.5988399982452393,1.7147594690322876,50000 -3052.322532892227,2.297640323638916,29032.7456882,63191,0,29032.7456882,0.4908000230789184,2.2709903717041016,10000,32090.70645928383,0.6765429377555847,1.3175029754638672,0.6112399697303772,1.6248478889465332,50000 -3096.3077614307404,2.3357009887695312,29452.76300573349,64108,0,29452.76300573349,0.4914000332355499,2.2900989055633545,10000,32554.79523062706,0.6566015481948853,1.395091533660889,0.6128999590873718,1.6186541318893433,50000 -3141.72727894783,2.376920223236084,29873.10894012451,65024,0,29873.10894012451,0.487600028514862,2.3212289810180664,10000,33020.64948439598,0.6565819978713989,1.4120858907699585,0.6054199934005737,1.6623183488845823,50000 -3187.0775051116943,2.411798477172852,30293.38259911537,65942,0,30293.38259911537,0.4879000186920166,2.3212332725524902,10000,33486.357377290726,0.6722265481948853,1.4081395864486694,0.6116799712181091,1.6795889139175415,50000 -3230.855504989624,2.4521572589874268,30713.42664384842,66860,0,30713.42664384842,0.4897000193595886,2.283788681030273,10000,33950.26797604561,0.6582812070846558,1.39947247505188,0.6097399592399597,1.6316499710083008,50000 -3275.497382879257,2.491576910018921,31133.74489402771,67778,0,31133.74489402771,0.4891000092029571,2.292510986328125,10000,34415.31527900696,0.6619336009025574,1.376253962516785,0.6098399758338928,1.6209572553634644,50000 -3319.9401018619537,2.5363433361053467,31553.782564401627,68697,0,31553.782564401627,0.4903000295162201,2.307218551635742,10000,34879.88833665848,0.6661718487739563,1.374730348587036,0.6109600067138672,1.637009620666504,50000 -3359.4197540283203,2.5785627365112305,31974.15192389488,69615,0,31974.15192389488,0.4954000115394592,2.2856035232543945,10000,35339.82807254791,0.6862890720367432,1.303639531135559,0.6137999892234802,1.6325457096099854,50000 -3404.967422485352,2.6233327388763428,32394.14501833916,70532,0,32394.14501833916,0.5014000535011292,2.240478515625,10000,35805.46223473549,0.6647265553474426,1.3673946857452393,0.618619978427887,1.584282636642456,50000 -3451.6360535621643,2.662638902664185,32814.301644325256,71449,0,32814.301644325256,0.4984000325202942,2.2308449745178223,10000,36272.37492394447,0.6814843416213989,1.3013520240783691,0.6236599683761597,1.567617654800415,50000 -3498.489343166352,2.706322431564331,33234.60793232918,72368,0,33234.60793232918,0.4963000118732452,2.253334045410156,10000,36739.626620054245,0.6871093511581421,1.261122465133667,0.6153599619865417,1.602860927581787,50000 -3538.5254430770874,2.7534587383270264,33654.99967765808,73287,0,33654.99967765808,0.5010000467300415,2.2361388206481934,10000,37200.15159320831,0.6668750047683716,1.3529757261276243,0.6220799684524536,1.573674201965332,50000 -3582.1325442790985,2.7938594818115234,34075.14587640762,74204,0,34075.14587640762,0.4923000335693359,2.295822858810425,10000,37663.99420571327,0.6692773103713989,1.3801852464675903,0.6154199838638306,1.623300552368164,50000 -3625.7986521720886,2.837161064147949,34495.48349118233,75122,0,34495.48349118233,0.4949000179767608,2.308547258377075,10000,38128.09023118019,0.6792187094688416,1.330292582511902,0.6165399551391602,1.6309008598327637,50000 -3670.284652233124,2.877194404602051,34915.78348207474,76040,0,34915.78348207474,0.4988000094890594,2.2662360668182373,10000,38592.96436548233,0.668652355670929,1.3684196472167969,0.6186800003051758,1.6112457513809204,50000 -3714.829068660736,2.9171230792999268,35335.89672112465,76959,0,35335.89672112465,0.509600043296814,2.2034332752227783,10000,39057.71058821678,0.6804882884025574,1.2870023250579834,0.6269800066947937,1.5367377996444702,50000 -3757.898426532746,2.958980321884156,35755.94821023941,77879,0,35755.94821023941,0.4988000094890594,2.237107753753662,10000,39520.92169165611,0.6826757788658142,1.2964460849761963,0.6247599720954895,1.580952763557434,50000 -3803.0521445274353,3.0000874996185303,36176.19422531128,78797,0,36176.19422531128,0.5067000389099121,2.2033281326293945,10000,39986.41137290001,0.6805663704872131,1.3006861209869385,0.6326799988746643,1.5357210636138916,50000 -3849.35542011261,3.0385146141052246,36596.54655098915,79714,0,36596.54655098915,0.5046000480651855,2.24481201171875,10000,40453.15292882919,0.6774023175239563,1.3380963802337646,0.6273800134658813,1.583625555038452,50000 -3892.457942724228,3.081490993499756,37016.87239527702,80632,0,37016.87239527702,0.509600043296814,2.193935871124268,10000,40916.673253536224,0.6942577958106995,1.2454754114151,0.6340599656105042,1.517375946044922,50000 -3936.6828587055206,3.125739812850952,37437.10391163826,81549,0,37437.10391163826,0.5088000297546387,2.1980860233306885,10000,41381.222581624985,0.7085741758346558,1.1876144409179688,0.6307399868965149,1.532135248184204,50000 -3979.00262093544,3.170383214950561,37857.19369125366,82465,0,37857.19369125366,0.5051000118255615,2.229015350341797,10000,41843.72469615936,0.6794531345367432,1.320271611213684,0.6279999613761902,1.5583600997924805,50000 -4024.61803984642,3.2090845108032227,38277.31278705597,83383,0,38277.31278705597,0.5099000334739685,2.218010187149048,10000,42309.54677009583,0.6886913776397705,1.2884644269943235,0.6291199922561646,1.5523051023483276,50000 -4069.077126741409,3.251221179962158,38697.50476980209,84301,0,38697.50476980209,0.5146000385284424,2.178715229034424,10000,42774.288845300674,0.7104882597923279,1.1823593378067017,0.6382799744606018,1.511709451675415,50000 -4113.385211467743,3.2891671657562256,39117.63734054565,85212,0,39117.63734054565,0.5046000480651855,2.224882364273072,10000,43238.81547117233,0.6768164038658142,1.3248145580291748,0.6277599930763245,1.5532524585723877,50000 -4159.022690296173,3.333111047744751,39537.63686776161,86128,0,39537.63686776161,0.5139999985694885,2.168210744857788,10000,43704.54471921921,0.6905273199081421,1.2369657754898071,0.6373599767684937,1.5001189708709717,50000 -4204.034997463226,3.376688241958618,39957.94800043106,87043,0,39957.94800043106,0.5199000239372253,2.167684316635132,10000,44169.95958900452,0.7060351371765137,1.2162786722183228,0.64301997423172,1.5114175081253052,50000 -4250.808509349823,3.416672706604004,40378.29867053032,87961,0,40378.29867053032,0.5170000195503235,2.170346975326538,10000,44637.17258501053,0.6932812333106995,1.2655775547027588,0.6387799978256226,1.5086603164672852,50000 -4295.005742549896,3.462848424911499,40798.624522686005,88880,0,40798.624522686005,0.5182000398635864,2.1785120964050293,10000,45101.79077982903,0.7005273103713989,1.252591609954834,0.6423400044441223,1.520622730255127,50000 -4340.285962820053,3.503141403198242,41218.92868351936,89796,0,41218.92868351936,0.5162000060081482,2.1909217834472656,10000,45567.463312625885,0.7016210556030273,1.2490087747573853,0.6393600106239319,1.5409966707229614,50000 -4381.028985977173,3.543948173522949,41639.11161088944,90713,0,41639.11161088944,0.5216000080108643,2.1296322345733643,10000,46028.47916865349,0.69837886095047,1.2028743028640747,0.6417199969291687,1.4627127647399902,50000 -4427.328744649887,3.5883829593658447,42059.214393138885,91631,0,42059.214393138885,0.5245000123977661,2.149744272232056,10000,46494.97500014305,0.6951367259025574,1.2514410018920898,0.6416599750518799,1.5035076141357422,50000 -4472.958468675613,3.6285886764526367,42479.155719041824,92549,0,42479.155719041824,0.5261000394821167,2.147681713104248,10000,46960.63577723503,0.703320324420929,1.215058445930481,0.6431199908256531,1.4973742961883545,50000 -4519.443992853165,3.677181482315064,42899.42155408859,93468,0,42899.42155408859,0.5208000540733337,2.1210744380950928,10000,47427.48384642601,0.7299999594688416,1.073001265525818,0.6488400101661682,1.4530651569366455,50000 -4565.002544879913,3.720834970474243,43319.45416808128,94385,0,43319.45416808128,0.5323000550270081,2.078791618347168,10000,47893.16714167595,0.7069921493530273,1.1722054481506348,0.6530399918556213,1.4273556470870972,50000 -4611.628629922867,3.7676095962524414,43739.83526778221,95303,0,43739.83526778221,0.5348000526428223,2.1036524772644043,10000,48360.269610881805,0.712109386920929,1.1649584770202637,0.6550599932670593,1.4420679807662964,50000 -4656.578102588654,3.80673885345459,44159.95757508278,96220,0,44159.95757508278,0.5303000211715698,2.125929832458496,10000,48825.42972564697,0.7240039110183716,1.1309791803359983,0.6502199769020081,1.4678243398666382,50000 -4702.385105133057,3.8552112579345694,44579.869277477264,97138,0,44579.869277477264,0.5326000452041626,2.099363327026367,10000,49291.24484491348,0.7091991901397705,1.208737015724182,0.6511200070381165,1.454700946807861,50000 -4748.463057041168,3.8989665508270255,45000.0894985199,98053,0,45000.0894985199,0.5384000539779663,2.0602378845214844,10000,49757.63449978829,0.714550793170929,1.1435582637786863,0.658519983291626,1.4108588695526123,50000 -4790.124536275864,3.9481842517852783,45420.03511428833,98971,0,45420.03511428833,0.5272000432014465,2.114023447036743,10000,50219.339812517166,0.7219140529632568,1.1263810396194458,0.6539799571037292,1.4435505867004397,50000 -4836.862557888031,3.9894981384277344,45840.2887210846,99888,0,45840.2887210846,0.5340999960899353,2.083680391311645,10000,50686.42065501213,0.7099218368530273,1.1682158708572388,0.6560999751091003,1.4241340160369873,50000 -4883.661042928696,4.037500381469727,46260.74950814247,100807,0,46260.74950814247,0.534000039100647,2.062699556350708,10000,51153.77632880211,0.7173827886581421,1.1348272562026978,0.6578800082206726,1.4096122980117798,50000 -4932.40603351593,4.089587211608887,46681.04964232445,101725,0,46681.04964232445,0.5418000221252441,2.032041311264038,10000,51622.92252993584,0.7226757407188416,1.1128582954406738,0.6645799875259399,1.3956934213638306,50000 -4975.847240924835,4.131916999816895,47101.4250562191,102642,0,47101.4250562191,0.5415000319480896,2.086282730102539,10000,52086.83007669449,0.7225781083106995,1.1422998905181885,0.6594600081443787,1.4308340549468994,50000 -5022.208994150162,4.173872470855713,47521.5824341774,103560,0,47521.5824341774,0.5383000373840332,2.04221248626709,10000,52553.43996691704,0.7223241925239563,1.1243679523468018,0.6632999777793884,1.3928260803222656,50000 -5067.726749658585,4.222776651382446,47941.7282936573,104475,0,47941.7282936573,0.5473000407218933,2.005379199981689,10000,53019.200771570206,0.73095703125,1.0725715160369873,0.6665999889373779,1.363909125328064,50000 -5112.407633304596,4.270694017410278,48361.7201230526,105393,0,48361.7201230526,0.5384000539779663,2.039724588394165,10000,53483.97005271912,0.7487890720367432,1.0079954862594604,0.6653000116348267,1.3744615316390991,50000 -5157.802113294601,4.31640887260437,48782.00963258743,106309,0,48782.00963258743,0.5458000302314758,2.01989221572876,10000,53949.74782347679,0.7193945050239563,1.12610924243927,0.6675999760627747,1.3651163578033447,50000 -5202.594695091248,4.361671686172485,49202.203140735626,107226,0,49202.203140735626,0.5460000038146973,2.00260591506958,10000,54414.82877445221,0.7369921803474426,1.0469598770141602,0.6708799600601196,1.35093891620636,50000 -5247.3822016716,4.404340744018555,49622.17845416069,108140,0,49622.17845416069,0.5439000129699707,2.017230749130249,10000,54879.68249583244,0.7404687404632568,1.0386266708374023,0.6669600009918213,1.367197036743164,50000 -5296.62574672699,4.448246955871582,50042.48392677307,109057,0,50042.48392677307,0.5503000020980835,1.989863038063049,10000,55349.32324099541,0.7318944931030273,1.066878318786621,0.674839973449707,1.330930471420288,50000 -5343.052444219589,4.493457078933716,50462.67532229424,109975,0,50462.67532229424,0.5470000505447388,2.032017946243286,10000,55816.034700632095,0.7314843535423279,1.0798393487930298,0.6667199730873108,1.3757458925247192,50000 -5386.50735449791,4.539769411087036,50882.97862505913,110892,0,50882.97862505913,0.5520000457763672,1.9919664859771729,10000,56279.88754367829,0.7429296970367432,1.0271873474121094,0.6739000082015991,1.34721040725708,50000 -5434.5012810230255,4.5913405418396,51303.28646111488,111809,0,51303.28646111488,0.5468000173568726,1.9874333143234253,10000,56748.28958415985,0.7333788871765137,1.0607457160949707,0.6744999885559082,1.334320902824402,50000 -5480.232470989227,4.634597301483154,51723.2543463707,112725,0,51723.2543463707,0.5455000400543213,2.0178678035736084,10000,57214.07962059975,0.732714831829071,1.0809401273727417,0.6717199683189392,1.3633190393447876,50000 -5526.704137802124,4.677294731140137,52143.389543771744,113642,0,52143.389543771744,0.5550000071525574,1.975521326065064,10000,57680.77817702293,0.7489062547683716,0.9951205253601074,0.67603999376297,1.3217229843139648,50000 -5572.480983495712,4.731476306915283,52563.71753501892,114559,0,52563.71753501892,0.5550000071525574,1.9714163541793823,10000,58146.98582029343,0.745898425579071,1.0103362798690796,0.6803799867630005,1.3071465492248535,50000 -5618.007117033005,4.778652191162109,52983.67902421951,115477,0,52983.67902421951,0.5524000525474548,1.993807315826416,10000,58612.56900882721,0.7442968487739563,1.0438237190246582,0.6782799959182739,1.3390940427780151,50000 -5664.597292423248,4.826295614242554,53403.994445085526,116394,0,53403.994445085526,0.5547000169754028,1.96835458278656,10000,59079.57103824616,0.7509765625,1.0075005292892456,0.6782000064849854,1.3190701007843018,50000 -5710.4811136722565,4.872780084609985,53824.31056809425,117310,0,53824.31056809425,0.5551000237464905,1.9470292329788208,10000,59545.86528062821,0.7648828029632568,0.9348188638687134,0.6816399693489075,1.3038502931594849,50000 -5757.854747772217,4.916918992996216,54244.33873510361,118227,0,54244.33873510361,0.5621000528335571,1.951754570007324,10000,60013.35964846611,0.7464257478713989,1.029775619506836,0.6835199594497681,1.3185490369796753,50000 -5801.55264043808,4.960393190383911,54664.30675196648,119144,0,54664.30675196648,0.5582000017166138,1.942766547203064,10000,60477.11735010147,0.7531836032867432,0.9818204045295716,0.6846799850463867,1.296289563179016,50000 -5845.588560342789,5.010600090026856,55084.343988895416,120060,0,55084.343988895416,0.5561000108718872,2.008143424987793,10000,60941.289311409,0.7615820169448853,1.0147477388381958,0.6805799603462219,1.354517936706543,50000 -5891.7737855911255,5.061231851577759,55504.44954395294,120978,0,55504.44954395294,0.5636000037193298,1.9418689012527464,10000,61407.6800467968,0.7485546469688416,1.011296033859253,0.6860799789428711,1.2949484586715698,50000 -5936.102333545685,5.1081626415252686,55924.67421007157,121895,0,55924.67421007157,0.5637000203132629,1.916074275970459,10000,61872.32814979553,0.7560937404632568,0.9640070796012878,0.685259997844696,1.2753068208694458,50000 -5980.293684959412,5.162423133850098,56344.74541926384,122812,0,56344.74541926384,0.5688000321388245,1.9093610048294067,10000,62336.69419193268,0.7655858993530273,0.926956832408905,0.6911399960517883,1.2588539123535156,50000 -6024.918617486954,5.211941957473755,56765.02931785584,123732,0,56765.02931785584,0.5617000460624695,1.9466270208358765,10000,62801.70062446594,0.7517968416213989,1.0083987712860107,0.6892200112342834,1.288521409034729,50000 -6071.47722363472,5.257215738296509,57185.08319354057,124648,0,57185.08319354057,0.5719000101089478,1.9025684595108032,10000,63268.40723109245,0.7625976204872131,0.9537436366081238,0.6960600018501282,1.2512370347976685,50000 -6115.479986190796,5.305594205856323,57605.48809599877,125566,0,57605.48809599877,0.5688000321388245,1.9237370491027832,10000,63732.91115617752,0.7717187404632568,0.9368151426315308,0.6967399716377258,1.2695817947387695,50000 -6159.436917304993,5.356076002120972,58025.8363161087,126484,0,58025.8363161087,0.5696000456809998,1.9149582386016848,10000,64197.31534385681,0.7656054496765137,0.9447650909423828,0.6959199905395508,1.2564334869384766,50000 -6205.239112854004,5.409669399261475,58445.861562252045,127403,0,58445.861562252045,0.5748000144958496,1.8830746412277224,10000,64663.24557733536,0.7616015672683716,0.9457040429115297,0.696399986743927,1.245165467262268,50000 -6248.568606853485,5.454848766326904,58865.82812094688,128313,0,58865.82812094688,0.5778000354766846,1.8673149347305296,10000,65126.63528132439,0.7695116996765137,0.9001871347427368,0.6969000101089478,1.226536512374878,50000 -6289.945489883423,5.501060485839844,59285.90663409233,129230,0,59285.90663409233,0.5782999992370605,1.8833563327789309,10000,65588.18467664719,0.7834765315055847,0.878718376159668,0.7016800045967102,1.2358816862106323,50000 -6335.249958515167,5.549408435821533,59706.12528705597,130148,0,59706.12528705597,0.576200008392334,1.8715360164642327,10000,66053.80458760262,0.7674999833106995,0.9318538308143616,0.7002800107002258,1.227251648902893,50000 -6381.077441215515,5.593918085098267,60126.05570912361,131067,0,60126.05570912361,0.576200008392334,1.8704854249954224,10000,66519.6564245224,0.7759179472923279,0.9067143797874452,0.7016400098800659,1.23651921749115,50000 -6426.768949747086,5.6505677700042725,60546.168536663055,131984,0,60546.168536663055,0.584600031375885,1.861109495162964,10000,66985.56582260132,0.7800976634025574,0.899075448513031,0.7061600089073181,1.231971263885498,50000 -6469.7969336509705,5.7013936042785645,60966.1720366478,132902,0,60966.1720366478,0.5800999999046326,1.8424324989318848,10000,67448.69650387764,0.7770116925239563,0.8806419372558594,0.7057600021362305,1.1938964128494265,50000 -6513.367963075638,5.74661111831665,61386.14681506157,133822,0,61386.14681506157,0.5730000138282776,1.9204665422439573,10000,67912.33575248718,0.7735351324081421,0.9415649771690368,0.7021200060844421,1.2566030025482178,50000 -6558.35079741478,5.793826341629028,61806.27965426445,134742,0,61806.27965426445,0.5799000263214111,1.858611226081848,10000,68377.54789853096,0.7837499976158142,0.8793671727180481,0.7064200043678284,1.2217317819595337,50000 -6602.141601085663,5.841196537017822,62226.222628593445,135659,0,62226.222628593445,0.5830000042915344,1.8402303457260127,10000,68841.37681627274,0.7777929306030273,0.8910472393035889,0.708579957485199,1.2002291679382324,50000 -6647.391499996185,5.891319513320923,62646.28657245636,136576,0,62646.28657245636,0.5878000259399414,1.8325586318969729,10000,69306.7888610363,0.7836328148841858,0.8509073257446289,0.7106199860572815,1.1824514865875244,50000 -6688.541975975037,5.941382884979248,63066.573899030685,137494,0,63066.573899030685,0.5861999988555908,1.8159235715866089,10000,69768.32604432106,0.7882031202316284,0.8359939455986023,0.7127199769020081,1.1701191663742063,50000 -6735.836848020554,5.990756034851074,63486.502836704254,138412,0,63486.502836704254,0.5888000130653381,1.816422939300537,10000,70235.64741063118,0.7888476252555847,0.8295444250106812,0.7119199633598328,1.178058624267578,50000 -6781.783026695252,6.037896156311035,63906.77509832382,139332,0,63906.77509832382,0.5918000340461731,1.8050135374069207,10000,70701.96142339706,0.7895312309265137,0.8278987407684326,0.7135799527168274,1.1600850820541382,50000 -6825.600820064545,6.085220098495483,64326.802830696106,140251,0,64326.802830696106,0.5911000370979309,1.8107106685638428,10000,71165.90361714363,0.7918164134025574,0.8284945487976074,0.7123599648475647,1.180811047554016,50000 -6873.079336643219,6.138157367706299,64746.913219451904,141169,0,64746.913219451904,0.5927000045776367,1.7914958000183103,10000,71633.59379124641,0.8020898103713989,0.7641717195510864,0.714199960231781,1.146799087524414,50000 -6917.972964763641,6.189510822296143,65167.130407333374,142084,0,65167.130407333374,0.5960000157356262,1.807189583778381,10000,72098.80448961258,0.7903515696525574,0.8418607115745544,0.718459963798523,1.1597744226455688,50000 -6960.956964015961,6.237833261489868,65587.38845300674,143000,0,65587.38845300674,0.5969000458717346,1.820393681526184,10000,72562.14328551292,0.7939453125,0.8365713953971863,0.7152599692344666,1.183610916137695,50000 -7008.814731359482,6.289177417755127,66007.30506968498,143917,0,66007.30506968498,0.5915000438690186,1.8317683935165403,10000,73030.0181658268,0.8037304282188416,0.8105840682983398,0.7167999744415283,1.1841614246368408,50000 -7051.8511662483215,6.337826013565064,66427.55252289772,144834,0,66427.55252289772,0.5910000205039978,1.8147989511489868,10000,73493.39942002296,0.7900585532188416,0.8303824663162231,0.7147600054740906,1.1643662452697754,50000 -7096.628978729248,6.390943288803101,66847.67327642441,145753,0,66847.67327642441,0.600600004196167,1.7750864028930664,10000,73958.39906048775,0.8025780916213989,0.7942774295806885,0.7219399809837341,1.1464852094650269,50000 -7141.1040625572205,6.439411401748657,67267.95164108276,146669,0,67267.95164108276,0.6003000140190125,1.7748686075210571,10000,74423.24955821037,0.8016601204872131,0.7725005149841309,0.7207199931144714,1.135236740112305,50000 -7188.041525125504,6.497358798980713,67687.97954654694,147587,0,67687.97954654694,0.5991000533103943,1.7693490982055664,10000,74890.32166147232,0.7973241806030273,0.8021615147590637,0.7207799553871155,1.1381536722183228,50000 -7233.060445070267,6.546364784240723,68107.91986322403,148506,0,68107.91986322403,0.6051000356674194,1.7475770711898804,10000,75355.3785700798,0.8033788800239563,0.7781947255134583,0.7247799634933472,1.1242923736572266,50000 -7277.828405618668,6.59432315826416,68527.9065463543,149422,0,68527.9065463543,0.6030000448226929,1.751966118812561,10000,75820.22951364517,0.8059765696525574,0.7608180046081543,0.7228999733924866,1.124770998954773,50000 -7320.086533069611,6.65229082107544,68947.92167925835,150341,0,68947.92167925835,0.6095000505447388,1.739927053451538,10000,76282.6093711853,0.818359375,0.733532190322876,0.7276999950408936,1.10873281955719,50000 -7365.671882867813,6.706914186477661,69367.8230752945,151259,0,69367.8230752945,0.609000027179718,1.7397416830062866,10000,76748.19864630699,0.8030859231948853,0.7721931338310242,0.7271599769592285,1.108335256576538,50000 -7411.831959962845,6.759262323379517,69787.99724078178,152177,0,69787.99724078178,0.6035000085830688,1.7696969509124756,10000,77214.63326454163,0.8080468773841858,0.78282630443573,0.7276399731636047,1.1330397129058838,50000 -7456.8257756233215,6.808479070663452,70208.0367231369,153094,0,70208.0367231369,0.6070000529289246,1.7434462308883667,10000,77679.76422834396,0.8205859065055847,0.7144888639450073,0.7292799949645996,1.1079001426696775,50000 -7503.045190811157,6.8642542362213135,70628.20611071587,154011,0,70628.20611071587,0.6048000454902649,1.7228014469146729,10000,78146.25729894638,0.8125,0.7371184229850769,0.7307400107383728,1.0911295413970947,50000 -7549.361694574356,6.916259765625,71048.3641808033,154931,0,71048.3641808033,0.6079000234603882,1.7812516689300537,10000,78612.83320403099,0.8117187023162842,0.7915002107620239,0.7309199571609497,1.1422063112258911,50000 -7592.454738378525,6.974083662033081,71468.62477397919,155850,0,71468.62477397919,0.6098000407218933,1.736724853515625,10000,79076.29350209236,0.8199804425239563,0.7226928472518921,0.7334399819374084,1.1000254154205322,50000 -7636.801145553589,7.02562427520752,71888.80693101883,156768,0,71888.80693101883,0.6101000308990479,1.7127060890197754,10000,79540.92209506035,0.8161718845367432,0.7212274074554443,0.7339999675750732,1.0783400535583496,50000 -7682.525693893433,7.078494548797607,72308.92112541199,157685,0,72308.92112541199,0.6080000400543213,1.7125824689865112,10000,80006.86278057098,0.8190429210662842,0.7114618420600891,0.7349599599838257,1.075613021850586,50000 -7727.382160902023,7.130783319473267,72728.9054980278,158600,0,72728.9054980278,0.6096000075340271,1.695249319076538,10000,80471.80396866798,0.8244531154632568,0.6801031231880188,0.7355999946594238,1.058942794799805,50000 -7773.549040794373,7.186861276626587,73148.94817018509,159516,0,73148.94817018509,0.6140000224113464,1.6969444751739502,10000,80938.11805319786,0.8199218511581421,0.7040935158729553,0.737559974193573,1.067233324050903,50000 -7819.126227378845,7.240983247756958,73568.97808933258,160432,0,73568.97808933258,0.6107000112533569,1.701716661453247,10000,81403.82748365402,0.8233007788658142,0.6964321732521057,0.7393199801445007,1.0613622665405271,50000 -7861.79163479805,7.296997785568237,73989.23976898193,161350,0,73989.23976898193,0.6128000020980835,1.720652461051941,10000,81866.85880875587,0.8231640458106995,0.701531708240509,0.7382000088691711,1.0777002573013306,50000 -7907.331022977829,7.35301947593689,74409.15633821487,162268,0,74409.15633821487,0.6132000088691711,1.7002379894256592,10000,82332.41932559013,0.8293749690055847,0.6697791218757629,0.7391200065612793,1.058064103126526,50000 -7953.901052236557,7.404864072799683,74829.30160307884,163188,0,74829.30160307884,0.6170000433921814,1.7050611972808838,10000,82799.23604655266,0.8237695097923279,0.6979393362998962,0.739579975605011,1.063696026802063,50000 -7999.356866836548,7.457672595977783,75249.34499502182,164106,0,75249.34499502182,0.617400050163269,1.6836295127868652,10000,83264.83662986755,0.8280468583106995,0.6734815835952759,0.7416799664497375,1.0467333793640137,50000 -8047.415571212769,7.507667779922485,75669.39481902122,165024,0,75669.39481902122,0.619100034236908,1.6772472858428955,10000,83733.0430316925,0.8347070217132568,0.6491798162460327,0.7430399656295776,1.0401471853256226,50000 -8094.137541294098,7.558278083801269,76089.31130671501,165943,0,76089.31130671501,0.6195000410079956,1.6991186141967771,10000,84199.780534029,0.8305078148841858,0.6913408637046814,0.7436800003051758,1.0642290115356443,50000 -8138.949564218521,7.613423585891724,76509.37047219276,166860,0,76509.37047219276,0.6202000379562378,1.6646264791488647,10000,84664.75477600098,0.83509761095047,0.6374510526657104,0.7443199753761292,1.0309221744537354,50000 -8184.667482852936,7.671286582946777,76929.66366410255,167778,0,76929.66366410255,0.6232000589370728,1.6723334789276123,10000,85130.87297224998,0.8340038657188416,0.6508774161338806,0.7443599700927734,1.0415518283843994,50000 -8231.38061952591,7.726383686065674,77349.58613371849,168694,0,77349.58613371849,0.6210000514984131,1.6536414623260498,10000,85597.61217308044,0.8335155844688416,0.6408066153526306,0.74617999792099,1.0215263366699219,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/measurements.csv deleted file mode 100644 index ed516796c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1878 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.35481423,6.9077563,,,,,,,,,,,,,, -1,,,0.0009960937313735,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,39.67635035514832,68.48205900192261,39.67635035514832,28.80561327934265,0.0,0.0 -100,0.51505494,6.882768,,,,,,,,,,,,,, -200,0.55649036,6.7787814,,,,,,,,,,,,,, -300,0.6884644,6.6875052,,,,,,,,,,,,,, -400,0.72869354,6.601669,,,,,,,,,,,,,, -500,1.1129255,6.7335253,,,,,,,,,,,,,, -600,1.0685781,6.4183674,,,,,,,,,,,,,, -700,1.2517183,6.278539,,,,,,,,,,,,,, -800,1.5896072,6.4174848,,,,,,,,,,,,,, -851,,,0.034257810562849,5.893809795379639,0.0277399998158216,5.958045482635498,50000.0,0.0230000019073486,6.076109886169434,10000.0,460.0015518665314,529.3016893863678,460.0015518665314,69.23552632331848,0.0189437866210937,0.0 -900,1.2343698,6.1546264,,,,,,,,,,,,,, -1000,1.4698266,6.1267533,,,,,,,,,,,,,, -1100,1.0656388,6.372529,,,,,,,,,,,,,, -1200,0.91435415,6.0777807,,,,,,,,,,,,,, -1300,1.0949916,6.010666,,,,,,,,,,,,,, -1400,0.9325736,6.620752,,,,,,,,,,,,,, -1500,0.9543835,6.5680895,,,,,,,,,,,,,, -1600,1.1754575,5.7222548,,,,,,,,,,,,,, -1700,0.8718817,6.088678,,,,,,,,,,,,,, -1762,,,0.0724414065480232,5.277771472930908,0.0684999972581863,5.329521656036377,50000.0,0.0534000024199485,5.5414252281188965,10000.0,879.9417262077332,993.3262569904329,879.9417262077332,113.24756622314452,0.0433447360992431,0.0 -1800,0.8259153,6.373461,,,,,,,,,,,,,, -1900,1.0979453,5.707238,,,,,,,,,,,,,, -2000,1.196276,5.683844,,,,,,,,,,,,,, -2100,0.88663256,5.748126,,,,,,,,,,,,,, -2200,0.7592352,6.621974,,,,,,,,,,,,,, -2300,0.9219961,5.5969067,,,,,,,,,,,,,, -2400,1.0318683,5.4772177,,,,,,,,,,,,,, -2500,0.85778296,6.1755385,,,,,,,,,,,,,, -2600,1.0317562,5.519588,,,,,,,,,,,,,, -2680,,,0.1116406247019767,4.909753799438477,0.1025599986314773,4.962595939636231,50000.0,0.0815000012516975,5.237037658691406,10000.0,1300.3221225738523,1457.4158778190613,1300.3221225738523,156.88109421730042,0.0699703693389892,0.0 -2700,0.8455367,6.3614297,,,,,,,,,,,,,, -2800,0.7103569,5.953796,,,,,,,,,,,,,, -2900,1.1340603,5.2270365,,,,,,,,,,,,,, -3000,0.91503763,5.2999496,,,,,,,,,,,,,, -3100,0.8386076,5.28157,,,,,,,,,,,,,, -3200,0.8508658,6.1753464,,,,,,,,,,,,,, -3300,1.0910822,5.163763,,,,,,,,,,,,,, -3400,1.0141369,5.1852326,,,,,,,,,,,,,, -3500,0.96470857,5.3317103,,,,,,,,,,,,,, -3597,,,0.1716992110013961,4.300407409667969,0.1581199914216995,4.405792236328125,50000.0,0.1203000023961067,4.7625908851623535,10000.0,1720.324553012848,1917.725405454636,1720.324553012848,197.1125226020813,0.0976169109344482,0.0 -3600,0.9438086,5.040763,,,,,,,,,,,,,, -3700,0.8958748,5.0127378,,,,,,,,,,,,,, -3800,0.82254964,4.995363,,,,,,,,,,,,,, -3900,1.0672226,5.0834603,,,,,,,,,,,,,, -4000,1.1159508,5.752025,,,,,,,,,,,,,, -4100,0.8212864,5.6170697,,,,,,,,,,,,,, -4200,0.64426154,5.813441,,,,,,,,,,,,,, -4300,0.75314116,5.4178605,,,,,,,,,,,,,, -4400,0.6589047,6.303141,,,,,,,,,,,,,, -4500,0.8780346,4.5980043,,,,,,,,,,,,,, -4512,,,0.2356249988079071,3.8305723667144775,0.2140599936246872,3.9363467693328857,50000.0,0.1652000099420547,4.379038333892822,10000.0,2140.311147928238,2382.7273893356323,2140.311147928238,242.0457983016968,0.1311566829681396,0.0 -4600,0.7161315,6.130403,,,,,,,,,,,,,, -4700,0.8602618,5.998777,,,,,,,,,,,,,, -4800,0.7200593,4.7774525,,,,,,,,,,,,,, -4900,0.9234136,4.633643,,,,,,,,,,,,,, -5000,0.8017137,4.4832697,,,,,,,,,,,,,, -5100,0.86022496,4.416117,,,,,,,,,,,,,, -5200,0.75946504,4.3858147,,,,,,,,,,,,,, -5300,0.782014,4.749908,,,,,,,,,,,,,, -5400,0.9250894,4.3976707,,,,,,,,,,,,,, -5427,,,0.264453113079071,3.639159202575684,0.2471199929714203,3.742400884628296,50000.0,0.1862000077962875,4.206192493438721,10000.0,2560.448935985565,2846.608163833618,2560.448935985565,285.71435475349426,0.1571612358093261,0.0 -5500,0.8648523,4.6010294,,,,,,,,,,,,,, -5600,0.8425251,4.574733,,,,,,,,,,,,,, -5700,0.6376389,5.813009,,,,,,,,,,,,,, -5800,1.0765839,4.3152685,,,,,,,,,,,,,, -5900,0.71047354,5.1189504,,,,,,,,,,,,,, -6000,0.83316964,4.5020223,,,,,,,,,,,,,, -6100,0.80566394,4.218577,,,,,,,,,,,,,, -6200,0.7267476,6.009871,,,,,,,,,,,,,, -6300,0.97280115,4.201082,,,,,,,,,,,,,, -6344,,,0.3071679472923279,3.325169086456299,0.2760599851608276,3.493992328643799,50000.0,0.2149000167846679,4.034944534301758,10000.0,2980.792924880981,3312.563648700714,2980.792924880981,331.25124192237854,0.1834251880645752,0.0 -6400,0.8778372,4.064634,,,,,,,,,,,,,, -6500,0.8881416,4.52019,,,,,,,,,,,,,, -6600,0.7468045,4.710378,,,,,,,,,,,,,, -6700,0.65811497,5.941773,,,,,,,,,,,,,, -6800,0.8722521,4.0880313,,,,,,,,,,,,,, -6900,0.7725175,4.382661,,,,,,,,,,,,,, -7000,0.82224476,4.0209284,,,,,,,,,,,,,, -7100,0.7717102,5.994628,,,,,,,,,,,,,, -7200,0.6093837,5.1901965,,,,,,,,,,,,,, -7263,,,0.34326171875,3.0648884773254395,0.3183799982070923,3.202389717102051,50000.0,0.2467000186443328,3.774560689926148,10000.0,3401.106611251831,3772.312874078751,3401.106611251831,370.6109387874603,0.2109639644622802,0.0 -7300,0.5321136,6.0421753,,,,,,,,,,,,,, -7400,0.5450035,5.7774625,,,,,,,,,,,,,, -7500,0.5934555,5.699511,,,,,,,,,,,,,, -7600,0.80528593,4.3082075,,,,,,,,,,,,,, -7700,0.89507663,4.3009763,,,,,,,,,,,,,, -7800,0.67356247,6.0356064,,,,,,,,,,,,,, -7900,0.7334469,4.4818335,,,,,,,,,,,,,, -8000,0.7766802,4.0131655,,,,,,,,,,,,,, -8100,0.8682313,4.0169973,,,,,,,,,,,,,, -8179,,,0.3716796934604645,2.930692672729492,0.3402999937534332,3.099444627761841,50000.0,0.2595000267028808,3.666433095932007,10000.0,3821.290620088577,4238.481454372406,3821.290620088577,416.5219187736511,0.2373898029327392,0.0 -8200,0.8063001,4.0742245,,,,,,,,,,,,,, -8300,0.6418281,5.3268714,,,,,,,,,,,,,, -8400,1.0762689,3.964694,,,,,,,,,,,,,, -8500,0.7035727,4.410484,,,,,,,,,,,,,, -8600,0.8292494,3.7657363,,,,,,,,,,,,,, -8700,0.84321433,3.9468906,,,,,,,,,,,,,, -8800,0.87330556,3.8672388,,,,,,,,,,,,,, -8900,0.90457594,3.725369,,,,,,,,,,,,,, -9000,0.93731785,3.817317,,,,,,,,,,,,,, -9097,,,0.3936523497104645,2.7883851528167725,0.3587999939918518,2.984739303588867,50000.0,0.285500019788742,3.554089307785034,10000.0,4241.387323856354,4698.91641163826,4241.387323856354,456.7812805175781,0.2679200172424316,0.0 -9100,0.77233994,3.814423,,,,,,,,,,,,,, -9200,0.861786,3.8581383,,,,,,,,,,,,,, -9300,0.91435426,3.6837704,,,,,,,,,,,,,, -9400,0.9919195,3.7997549,,,,,,,,,,,,,, -9500,0.74125093,4.68095,,,,,,,,,,,,,, -9600,0.7095078,4.894421,,,,,,,,,,,,,, -9700,1.0590891,3.900954,,,,,,,,,,,,,, -9800,0.72678053,5.3235264,,,,,,,,,,,,,, -9900,0.90156025,3.7023056,,,,,,,,,,,,,, -10000,0.68206203,5.9574084,,,,,,,,,,,,,, -10015,,,0.4039062261581421,2.737269163131714,0.3750399947166443,2.905771493911743,50000.0,0.2871000170707702,3.4907071590423584,10000.0,4661.656677007675,5157.845973968506,4661.656677007675,495.365522146225,0.2954919338226318,0.0 -10100,0.7868653,5.4189696,,,,,,,,,,,,,, -10200,0.8705279,4.0537953,,,,,,,,,,,,,, -10300,0.6088798,5.274397,,,,,,,,,,,,,, -10400,0.70666647,4.647433,,,,,,,,,,,,,, -10500,0.8820708,3.8936768,,,,,,,,,,,,,, -10600,0.8397086,3.680163,,,,,,,,,,,,,, -10700,0.89773417,3.540579,,,,,,,,,,,,,, -10800,0.7183798,5.7510967,,,,,,,,,,,,,, -10900,0.98866844,3.6122236,,,,,,,,,,,,,, -10934,,,0.423164039850235,2.606679677963257,0.3898399770259857,2.784226894378662,50000.0,0.2964000105857849,3.415822029113769,10000.0,5082.045745134354,5620.655026435852,5082.045745134354,537.7116749286652,0.320845365524292,0.0 -11000,0.9111131,3.6677427,,,,,,,,,,,,,, -11100,0.7124607,4.9292727,,,,,,,,,,,,,, -11200,1.2438997,3.630439,,,,,,,,,,,,,, -11300,0.7112497,5.3676867,,,,,,,,,,,,,, -11400,0.8928308,3.8701892,,,,,,,,,,,,,, -11500,0.6855998,5.468418,,,,,,,,,,,,,, -11600,0.79274714,4.498906,,,,,,,,,,,,,, -11700,0.87885714,4.150953,,,,,,,,,,,,,, -11800,1.0802404,3.6095393,,,,,,,,,,,,,, -11853,,,0.4403710961341858,2.523897886276245,0.4095799922943115,2.7042064666748047,50000.0,0.3145000040531158,3.317158460617065,10000.0,5502.399238586426,6084.099897861481,5502.399238586426,580.7291917800903,0.3462169170379638,0.0 -11900,0.8722808,3.404137,,,,,,,,,,,,,, -12000,0.8100908,4.231626,,,,,,,,,,,,,, -12100,0.98219216,3.4446428,,,,,,,,,,,,,, -12200,0.9009311,3.7541046,,,,,,,,,,,,,, -12300,1.014623,3.540711,,,,,,,,,,,,,, -12400,0.94889635,3.6707792,,,,,,,,,,,,,, -12500,0.7770348,3.9271924,,,,,,,,,,,,,, -12600,1.0115117,3.5445385,,,,,,,,,,,,,, -12700,0.9518884,3.448646,,,,,,,,,,,,,, -12770,,,0.4703515470027923,2.3473031520843506,0.4133599996566772,2.656049966812134,50000.0,0.3165000081062317,3.3088340759277344,10000.0,5922.57227897644,6547.986525058746,5922.57227897644,624.3629055023193,0.3769981861114502,0.0 -12800,0.89718837,3.504731,,,,,,,,,,,,,, -12900,1.0331,3.5809224,,,,,,,,,,,,,, -13000,0.8028285,3.919968,,,,,,,,,,,,,, -13100,0.91384715,3.9164183,,,,,,,,,,,,,, -13200,0.9558134,3.4103236,,,,,,,,,,,,,, -13300,0.99761415,3.581141,,,,,,,,,,,,,, -13400,0.7256346,5.1782203,,,,,,,,,,,,,, -13500,0.9336431,5.0577703,,,,,,,,,,,,,, -13600,1.020357,3.5382156,,,,,,,,,,,,,, -13689,,,0.4599999785423279,2.3901679515838623,0.4278199970722198,2.5512709617614746,50000.0,0.3374000191688537,3.1996219158172607,10000.0,6342.573100805283,7010.78583407402,6342.573100805283,667.0851843357086,0.4046244621276855,0.0 -13700,0.9553179,3.5226204,,,,,,,,,,,,,, -13800,0.68498087,5.7097034,,,,,,,,,,,,,, -13900,0.80066997,4.3953633,,,,,,,,,,,,,, -14000,0.9104122,3.5579052,,,,,,,,,,,,,, -14100,0.93924826,3.2933006,,,,,,,,,,,,,, -14200,1.0017058,3.892248,,,,,,,,,,,,,, -14300,1.3531758,3.3948116,,,,,,,,,,,,,, -14400,0.86327356,4.415147,,,,,,,,,,,,,, -14500,1.162553,3.4772327,,,,,,,,,,,,,, -14600,0.89088744,4.458549,,,,,,,,,,,,,, -14606,,,0.4755273461341858,2.356476306915283,0.4366599917411804,2.5436606407165527,50000.0,0.3416000306606293,3.1861348152160645,10000.0,6762.732401609421,7475.724648714066,6762.732401609421,711.7846443653107,0.4368345737457275,0.0 -14700,0.6803331,5.7928367,,,,,,,,,,,,,, -14800,1.0118022,3.357528,,,,,,,,,,,,,, -14900,1.0236912,3.7971706,,,,,,,,,,,,,, -15000,0.80229384,5.7465887,,,,,,,,,,,,,, -15100,1.0097079,3.5222135,,,,,,,,,,,,,, -15200,1.022681,3.336365,,,,,,,,,,,,,, -15300,1.028671,3.3387377,,,,,,,,,,,,,, -15400,0.9973735,3.4259062,,,,,,,,,,,,,, -15500,0.78709894,5.5294194,,,,,,,,,,,,,, -15523,,,0.4975781142711639,2.191378593444824,0.4444399774074554,2.4651525020599365,50000.0,0.3447000086307525,3.096624851226806,10000.0,7182.700782775879,7939.937408447266,7182.700782775879,755.9537699222565,0.4638805389404297,0.0 -15600,0.9444227,3.84737,,,,,,,,,,,,,, -15700,1.0363942,3.3318443,,,,,,,,,,,,,, -15800,1.0221397,3.4243684,,,,,,,,,,,,,, -15900,0.7800649,4.02727,,,,,,,,,,,,,, -16000,0.98512673,3.8445058,,,,,,,,,,,,,, -16100,0.84341663,4.030142,,,,,,,,,,,,,, -16200,0.9017305,4.760366,,,,,,,,,,,,,, -16300,0.9872183,3.3073115,,,,,,,,,,,,,, -16400,0.8154846,5.7031713,,,,,,,,,,,,,, -16441,,,0.486152321100235,2.2722318172454834,0.4551399946212768,2.4447381496429443,50000.0,0.3492000102996826,3.093480348587036,10000.0,7602.853166103363,8405.024823188782,7602.853166103363,800.8106706142426,0.4935455322265625,0.0 -16500,1.0503374,3.3532867,,,,,,,,,,,,,, -16600,0.8069811,5.538958,,,,,,,,,,,,,, -16700,0.80832297,5.361334,,,,,,,,,,,,,, -16800,0.9736477,3.7087245,,,,,,,,,,,,,, -16900,0.72807735,5.5325146,,,,,,,,,,,,,, -17000,1.1163956,3.3320115,,,,,,,,,,,,,, -17100,0.9981054,3.2797773,,,,,,,,,,,,,, -17200,0.9778502,3.3606296,,,,,,,,,,,,,, -17300,1.0737516,3.389006,,,,,,,,,,,,,, -17359,,,0.5013867020606995,2.1820547580718994,0.4605799913406372,2.391518831253052,50000.0,0.3612000048160553,3.026905298233032,10000.0,8023.229565858841,8865.672659635544,8023.229565858841,841.0064516067505,0.5205333232879639,0.0 -17400,1.0157855,3.0561566,,,,,,,,,,,,,, -17500,1.0955107,3.2894068,,,,,,,,,,,,,, -17600,1.2057781,3.8583324,,,,,,,,,,,,,, -17700,1.0472726,3.6168258,,,,,,,,,,,,,, -17800,1.2725986,3.2378037,,,,,,,,,,,,,, -17900,1.0510173,3.338227,,,,,,,,,,,,,, -18000,0.9202512,3.4858992,,,,,,,,,,,,,, -18100,1.0852911,3.13479,,,,,,,,,,,,,, -18200,0.76033497,4.5054455,,,,,,,,,,,,,, -18273,,,0.5225195288658142,2.098334789276123,0.4728599786758423,2.336740732192993,50000.0,0.3623000085353851,3.014296293258667,10000.0,8443.607513904572,9330.35705280304,8443.607513904572,885.2275066375732,0.5578651428222656,0.0 -18300,1.057422,3.2500718,,,,,,,,,,,,,, -18400,1.2299212,3.513308,,,,,,,,,,,,,, -18500,0.9900044,3.217335,,,,,,,,,,,,,, -18600,0.9577952,3.3536305,,,,,,,,,,,,,, -18700,1.0136153,3.2324421,,,,,,,,,,,,,, -18800,0.9224457,3.9209518,,,,,,,,,,,,,, -18900,1.2150041,3.1417499,,,,,,,,,,,,,, -19000,0.70868385,5.523512,,,,,,,,,,,,,, -19100,1.1758717,3.1625242,,,,,,,,,,,,,, -19190,,,0.5135741829872131,2.104053258895874,0.4769399762153625,2.300676822662353,50000.0,0.3705000281333923,2.9391238689422607,10000.0,8863.76538324356,9794.63821029663,8863.76538324356,929.2668771743774,0.5929629802703857,0.0 -19200,1.0680891,3.221698,,,,,,,,,,,,,, -19300,0.87258387,4.0011144,,,,,,,,,,,,,, -19400,1.047076,3.2029872,,,,,,,,,,,,,, -19500,0.89204866,4.8371925,,,,,,,,,,,,,, -19600,0.81578934,4.813334,,,,,,,,,,,,,, -19700,1.0347329,3.2214832,,,,,,,,,,,,,, -19800,1.0340889,3.1577554,,,,,,,,,,,,,, -19900,1.4669547,3.27287,,,,,,,,,,,,,, -20000,1.2346728,3.1536183,,,,,,,,,,,,,, -20100,0.92957443,3.151758,,,,,,,,,,,,,, -20105,,,0.5260156393051147,2.0405845642089844,0.4841799736022949,2.2629568576812744,50000.0,0.3801000118255615,2.9380943775177,10000.0,9283.831866025925,10255.40760755539,9283.831866025925,969.891283750534,0.623236894607544,0.0 -20200,0.8132613,4.552374,,,,,,,,,,,,,, -20300,0.80569756,4.4515324,,,,,,,,,,,,,, -20400,1.0702442,3.0841887,,,,,,,,,,,,,, -20500,1.0046846,3.7629476,,,,,,,,,,,,,, -20600,1.0349013,3.3507907,,,,,,,,,,,,,, -20700,1.1048704,3.096387,,,,,,,,,,,,,, -20800,0.83151174,3.9464536,,,,,,,,,,,,,, -20900,1.1428037,3.044766,,,,,,,,,,,,,, -21000,0.8220563,5.046962,,,,,,,,,,,,,, -21022,,,0.53369140625,2.0299012660980225,0.4887999892234802,2.2722764015197754,50000.0,0.3781000077724457,2.9239919185638428,10000.0,9703.953919887545,10718.971504211426,9703.953919887545,1013.2559487819672,0.6527237892150879,0.0 -21100,1.0406046,3.1603677,,,,,,,,,,,,,, -21200,0.87333,5.404953,,,,,,,,,,,,,, -21300,1.092828,3.2137353,,,,,,,,,,,,,, -21400,1.2225847,3.1988316,,,,,,,,,,,,,, -21500,0.97234845,3.5542233,,,,,,,,,,,,,, -21600,0.96798414,3.8779604,,,,,,,,,,,,,, -21700,0.7468983,5.560317,,,,,,,,,,,,,, -21800,0.8076804,5.6320233,,,,,,,,,,,,,, -21900,1.0486765,3.1970181,,,,,,,,,,,,,, -21939,,,0.5484570264816284,1.911619544029236,0.5009399652481079,2.1616222858428955,50000.0,0.3904000222682953,2.829202175140381,10000.0,10124.150616884232,11180.589293718338,10124.150616884232,1054.5985553264618,0.6823267936706543,0.0 -22000,1.0081575,3.2249517,,,,,,,,,,,,,, -22100,1.0653982,3.0831664,,,,,,,,,,,,,, -22200,0.9029059,5.282081,,,,,,,,,,,,,, -22300,1.1557736,3.0916324,,,,,,,,,,,,,, -22400,1.078937,3.0142202,,,,,,,,,,,,,, -22500,0.79969674,5.189562,,,,,,,,,,,,,, -22600,1.1232446,3.102105,,,,,,,,,,,,,, -22700,0.9899743,3.9334545,,,,,,,,,,,,,, -22800,1.1520973,3.3094034,,,,,,,,,,,,,, -22852,,,0.5477929711341858,1.9292596578598025,0.5063599944114685,2.1407365798950195,50000.0,0.3939000070095062,2.801138162612915,10000.0,10544.085697889328,11644.321509361269,10544.085697889328,1098.3198697566986,0.7099740505218506,0.0 -22900,1.102242,2.9894648,,,,,,,,,,,,,, -23000,1.0835642,3.034747,,,,,,,,,,,,,, -23100,1.1012148,3.0956156,,,,,,,,,,,,,, -23200,1.2055157,3.2438703,,,,,,,,,,,,,, -23300,1.0934104,3.1433911,,,,,,,,,,,,,, -23400,1.2414125,3.09875,,,,,,,,,,,,,, -23500,1.034873,3.0956385,,,,,,,,,,,,,, -23600,1.0448636,4.4348297,,,,,,,,,,,,,, -23700,1.0958825,3.079375,,,,,,,,,,,,,, -23768,,,0.5549609065055847,1.911234259605408,0.5130199790000916,2.142449378967285,50000.0,0.3993000090122223,2.8055648803710938,10000.0,10964.31832265854,12105.310588121414,10964.31832265854,1138.9947321414948,0.7432305812835693,0.0 -23800,1.0011728,5.4309697,,,,,,,,,,,,,, -23900,1.0375304,3.326466,,,,,,,,,,,,,, -24000,1.0636017,3.6530907,,,,,,,,,,,,,, -24100,1.0890653,3.0995715,,,,,,,,,,,,,, -24200,0.9954303,4.698985,,,,,,,,,,,,,, -24300,1.0577302,3.3405757,,,,,,,,,,,,,, -24400,0.8115222,5.1519866,,,,,,,,,,,,,, -24500,1.0951815,3.0603092,,,,,,,,,,,,,, -24600,0.97941905,3.7164469,,,,,,,,,,,,,, -24687,,,0.5798242092132568,1.821606397628784,0.5144199728965759,2.138012886047364,50000.0,0.3990000188350677,2.804219961166382,10000.0,11384.320001840591,12568.191683530807,11384.320001840591,1181.7955298423767,0.773090124130249,0.0 -24700,0.8875263,4.5769544,,,,,,,,,,,,,, -24800,1.0736783,2.987112,,,,,,,,,,,,,, -24900,0.74992585,5.4745536,,,,,,,,,,,,,, -25000,1.2410457,3.0787272,,,,,,,,,,,,,, -25100,0.9628372,3.5189943,,,,,,,,,,,,,, -25200,1.0905647,3.0008068,,,,,,,,,,,,,, -25300,1.1135186,2.9381938,,,,,,,,,,,,,, -25400,0.82389504,5.0566278,,,,,,,,,,,,,, -25500,0.87329733,4.513094,,,,,,,,,,,,,, -25600,0.9000217,4.231817,,,,,,,,,,,,,, -25606,,,0.5553905963897705,1.961132287979126,0.5149999856948853,2.1607208251953125,50000.0,0.4036000072956085,2.8055355548858643,10000.0,11804.5945789814,13031.554756641388,11804.5945789814,1224.8035056591034,0.8047606945037842,0.0 -25700,0.8652629,4.0440307,,,,,,,,,,,,,, -25800,1.168766,3.120882,,,,,,,,,,,,,, -25900,1.0071372,4.5421443,,,,,,,,,,,,,, -26000,1.0186714,3.8739393,,,,,,,,,,,,,, -26100,1.1485273,3.0721729,,,,,,,,,,,,,, -26200,1.1220727,3.031762,,,,,,,,,,,,,, -26300,1.1320058,2.8677654,,,,,,,,,,,,,, -26400,1.030917,3.1512175,,,,,,,,,,,,,, -26500,0.8700101,5.5351424,,,,,,,,,,,,,, -26524,,,0.5643359422683716,1.889811038970948,0.5171599984169006,2.1188228130340576,50000.0,0.4075000286102295,2.782655477523804,10000.0,12224.88507938385,13495.64342713356,12224.88507938385,1268.5251622200012,0.8329160213470459,0.0 -26600,1.080142,3.0051007,,,,,,,,,,,,,, -26700,1.001572,3.3567424,,,,,,,,,,,,,, -26800,1.1991733,2.9624054,,,,,,,,,,,,,, -26900,1.0650506,3.4441195,,,,,,,,,,,,,, -27000,1.1202304,2.842979,,,,,,,,,,,,,, -27100,1.0175279,2.926558,,,,,,,,,,,,,, -27200,1.086007,3.0812335,,,,,,,,,,,,,, -27300,1.1430466,2.9527557,,,,,,,,,,,,,, -27400,0.90426534,5.2335606,,,,,,,,,,,,,, -27444,,,0.5781835913658142,1.8442374467849727,0.5189599990844727,2.119422197341919,50000.0,0.4106000065803528,2.772930145263672,10000.0,12644.999958276749,13959.151177167892,12644.999958276749,1311.8416454792025,0.8607838153839111,0.0 -27500,0.84970385,3.874135,,,,,,,,,,,,,, -27600,1.0682046,2.9213939,,,,,,,,,,,,,, -27700,0.91448843,3.7954211,,,,,,,,,,,,,, -27800,0.86929685,4.1403008,,,,,,,,,,,,,, -27900,1.1066101,2.9856167,,,,,,,,,,,,,, -28000,1.1303891,3.3082047,,,,,,,,,,,,,, -28100,1.0700067,3.138032,,,,,,,,,,,,,, -28200,1.0820552,2.9567766,,,,,,,,,,,,,, -28300,1.2362471,3.0205023,,,,,,,,,,,,,, -28361,,,0.5642968416213989,1.883542776107788,0.5241999626159668,2.089167356491089,50000.0,0.4104000329971313,2.7521159648895264,10000.0,13065.262185811996,14424.698271989822,13065.262185811996,1357.0431625843048,0.895582914352417,0.0 -28400,1.0245297,3.544591,,,,,,,,,,,,,, -28500,0.765691,5.4096236,,,,,,,,,,,,,, -28600,0.78894323,5.4525313,,,,,,,,,,,,,, -28700,1.1378531,3.2752252,,,,,,,,,,,,,, -28800,1.2165359,2.9229841,,,,,,,,,,,,,, -28900,1.1274505,2.8605366,,,,,,,,,,,,,, -29000,0.9792619,3.3856986,,,,,,,,,,,,,, -29100,0.9769533,3.7307253,,,,,,,,,,,,,, -29200,0.99126714,4.1327643,,,,,,,,,,,,,, -29278,,,0.5771288871765137,1.784113883972168,0.5302799940109253,2.028592348098755,50000.0,0.424200028181076,2.681459426879883,10000.0,13485.303869009018,14888.231466531754,13485.303869009018,1400.45316529274,0.9288933277130128,0.0 -29300,1.1263239,3.1497438,,,,,,,,,,,,,, -29400,1.3901259,2.9423008,,,,,,,,,,,,,, -29500,1.0788217,3.0044448,,,,,,,,,,,,,, -29600,1.1711923,2.856721,,,,,,,,,,,,,, -29700,1.0048532,3.8225493,,,,,,,,,,,,,, -29800,0.96218574,4.3642445,,,,,,,,,,,,,, -29900,1.0478753,2.8891506,,,,,,,,,,,,,, -30000,1.1324292,2.8986726,,,,,,,,,,,,,, -30100,1.0986985,2.8250208,,,,,,,,,,,,,, -30193,,,0.5874413847923279,1.7344344854354858,0.5366799831390381,1.992333173751831,50000.0,0.4198000133037567,2.670804500579834,10000.0,13905.338715076448,15353.465184688568,13905.338715076448,1445.575161933899,0.9585323333740234,0.0 -30200,0.8949322,5.488049,,,,,,,,,,,,,, -30300,1.1654447,2.9609315,,,,,,,,,,,,,, -30400,0.8940252,4.3539042,,,,,,,,,,,,,, -30500,1.0226799,3.2081714,,,,,,,,,,,,,, -30600,1.0543294,5.374751,,,,,,,,,,,,,, -30700,1.0887829,2.868462,,,,,,,,,,,,,, -30800,1.1501865,3.1589117,,,,,,,,,,,,,, -30900,1.0509322,4.3999014,,,,,,,,,,,,,, -31000,0.8709227,4.1755137,,,,,,,,,,,,,, -31100,1.159287,3.3298006,,,,,,,,,,,,,, -31110,,,0.5776171684265137,1.811983942985535,0.5379199981689453,2.004590034484864,50000.0,0.4220000207424164,2.683004379272461,10000.0,14325.543489933014,15818.188180923462,14325.543489933014,1490.0121433734894,0.9918410778045654,0.0 -31200,1.1481229,2.9869678,,,,,,,,,,,,,, -31300,1.0130218,2.9712677,,,,,,,,,,,,,, -31400,1.1120598,2.768235,,,,,,,,,,,,,, -31500,1.0728589,2.7522988,,,,,,,,,,,,,, -31600,1.25494,2.9049392,,,,,,,,,,,,,, -31700,1.1318986,3.11735,,,,,,,,,,,,,, -31800,1.3311677,3.7640438,,,,,,,,,,,,,, -31900,0.95208526,5.356084,,,,,,,,,,,,,, -32000,1.0607697,2.9310904,,,,,,,,,,,,,, -32027,,,0.5791406035423279,1.7816259860992432,0.5400199890136719,1.9844132661819456,50000.0,0.4261000156402588,2.6556942462921143,10000.0,14745.545434951782,16282.084006547928,14745.545434951782,1533.8274431228638,1.0221738815307615,0.0 -32100,1.1267945,2.7540119,,,,,,,,,,,,,, -32200,1.2715669,2.8492522,,,,,,,,,,,,,, -32300,0.88156766,4.085593,,,,,,,,,,,,,, -32400,1.1482954,2.8321676,,,,,,,,,,,,,, -32500,0.95032215,4.7763395,,,,,,,,,,,,,, -32600,0.91537815,5.3867846,,,,,,,,,,,,,, -32700,1.2118088,3.1490068,,,,,,,,,,,,,, -32800,1.178396,2.8647223,,,,,,,,,,,,,, -32900,0.9779637,2.752336,,,,,,,,,,,,,, -32942,,,0.5919530987739563,1.684620022773743,0.5460000038146973,1.936972737312317,50000.0,0.428600013256073,2.61394476890564,10000.0,15165.510427713394,16746.997824668884,15165.510427713394,1578.6870546340942,1.0637776851654053,0.0 -33000,0.7972855,5.4751434,,,,,,,,,,,,,, -33100,1.1537099,2.687965,,,,,,,,,,,,,, -33200,1.1499362,2.8622766,,,,,,,,,,,,,, -33300,1.1056796,2.730651,,,,,,,,,,,,,, -33400,1.0424929,5.0383277,,,,,,,,,,,,,, -33500,1.0875577,3.0221653,,,,,,,,,,,,,, -33600,1.129516,2.9303527,,,,,,,,,,,,,, -33700,0.89419067,4.2794228,,,,,,,,,,,,,, -33800,1.1443007,2.725261,,,,,,,,,,,,,, -33857,,,0.5914843678474426,1.794201374053955,0.5385199785232544,2.034351110458374,50000.0,0.4273000061511993,2.6944146156311035,10000.0,15585.75400686264,17211.00025987625,15585.75400686264,1622.3610565662384,1.1009063720703125,0.0 -33900,0.8675617,5.489604,,,,,,,,,,,,,, -34000,0.9347475,3.7877057,,,,,,,,,,,,,, -34100,1.2281443,2.8291485,,,,,,,,,,,,,, -34200,1.1128825,2.906988,,,,,,,,,,,,,, -34300,1.0652373,3.1413126,,,,,,,,,,,,,, -34400,1.7382997,2.7940717,,,,,,,,,,,,,, -34500,1.0944768,2.86275,,,,,,,,,,,,,, -34600,1.1030655,2.7422178,,,,,,,,,,,,,, -34700,1.2660909,2.755452,,,,,,,,,,,,,, -34774,,,0.5895116925239563,1.737478733062744,0.5465199947357178,1.960172414779663,50000.0,0.4339000284671783,2.622077703475952,10000.0,16005.86516070366,17675.351983308792,16005.86516070366,1666.5135188102722,1.1385252475738523,0.0 -34800,0.89065623,4.1184816,,,,,,,,,,,,,, -34900,1.1494788,2.962237,,,,,,,,,,,,,, -35000,1.1987002,2.8899198,,,,,,,,,,,,,, -35100,1.249363,2.9304357,,,,,,,,,,,,,, -35200,1.1640472,2.8643212,,,,,,,,,,,,,, -35300,1.1169944,2.6490293,,,,,,,,,,,,,, -35400,1.0283295,3.9610949,,,,,,,,,,,,,, -35500,0.78813356,4.933935,,,,,,,,,,,,,, -35600,1.0472397,3.4529426,,,,,,,,,,,,,, -35690,,,0.5945702791213989,1.753453493118286,0.5510799884796143,1.965092182159424,50000.0,0.434000015258789,2.628279209136963,10000.0,16425.89613389969,18141.204838514328,16425.89613389969,1712.2519705295565,1.1743512153625488,0.0 -35700,0.89937,4.224055,,,,,,,,,,,,,, -35800,1.4741876,2.8435688,,,,,,,,,,,,,, -35900,1.192591,2.980597,,,,,,,,,,,,,, -36000,1.0723871,3.0369189,,,,,,,,,,,,,, -36100,1.1940097,2.842441,,,,,,,,,,,,,, -36200,1.0688637,4.1795135,,,,,,,,,,,,,, -36300,1.25868,2.67755,,,,,,,,,,,,,, -36400,1.2501478,2.8237906,,,,,,,,,,,,,, -36500,1.0204791,3.6225214,,,,,,,,,,,,,, -36600,0.9075231,4.330707,,,,,,,,,,,,,, -36606,,,0.6235546469688416,1.566106200218201,0.553059995174408,1.8976249694824217,50000.0,0.4409000277519226,2.563768148422241,10000.0,16846.11967420578,18607.68721485138,16846.11967420578,1758.4259514808657,1.2110042572021484,0.0 -36700,1.0444994,3.1640437,,,,,,,,,,,,,, -36800,1.0451578,2.9446259,,,,,,,,,,,,,, -36900,0.9931995,5.1859837,,,,,,,,,,,,,, -37000,1.0793421,2.9625144,,,,,,,,,,,,,, -37100,1.145749,2.7390888,,,,,,,,,,,,,, -37200,1.028417,4.659342,,,,,,,,,,,,,, -37300,1.148298,3.7105656,,,,,,,,,,,,,, -37400,1.0265647,3.1813765,,,,,,,,,,,,,, -37500,1.2217416,2.8559742,,,,,,,,,,,,,, -37521,,,0.5977148413658142,1.7096517086029053,0.5522400140762329,1.920454382896424,50000.0,0.4422000348567962,2.561940670013428,10000.0,17266.167387723923,19072.837619304657,17266.167387723923,1803.445976257324,1.2457661628723145,0.0 -37600,1.1271724,5.3316913,,,,,,,,,,,,,, -37700,1.0153769,3.2318988,,,,,,,,,,,,,, -37800,1.1531718,3.0787804,,,,,,,,,,,,,, -37900,0.8527018,5.284351,,,,,,,,,,,,,, -38000,1.1108755,2.8168242,,,,,,,,,,,,,, -38100,1.2738123,2.9039876,,,,,,,,,,,,,, -38200,1.1435667,2.8062336,,,,,,,,,,,,,, -38300,0.9092074,3.8068752,,,,,,,,,,,,,, -38400,1.1776242,2.835163,,,,,,,,,,,,,, -38438,,,0.602343738079071,1.6719639301300049,0.5600799918174744,1.897687554359436,50000.0,0.4454000294208526,2.5424721240997314,10000.0,17686.52992963791,19539.07708930969,17686.52992963791,1849.2419004440308,1.2778377532958984,0.0 -38500,1.1087081,2.9186382,,,,,,,,,,,,,, -38600,0.994157,3.9763734,,,,,,,,,,,,,, -38700,1.1197758,2.847795,,,,,,,,,,,,,, -38800,1.2520218,2.7847252,,,,,,,,,,,,,, -38900,1.1355171,2.8106582,,,,,,,,,,,,,, -39000,1.2194585,2.7736723,,,,,,,,,,,,,, -39100,1.2298194,2.790537,,,,,,,,,,,,,, -39200,1.2002231,2.7422783,,,,,,,,,,,,,, -39300,1.1325471,3.2518766,,,,,,,,,,,,,, -39354,,,0.6241015195846558,1.555986762046814,0.5655399560928345,1.844379186630249,50000.0,0.4482000172138214,2.514027833938598,10000.0,18106.487027406693,20004.48260617256,18106.487027406693,1894.6081516742704,1.3122048377990725,0.0 -39400,0.968573,4.63597,,,,,,,,,,,,,, -39500,1.0988799,3.0170267,,,,,,,,,,,,,, -39600,0.9785965,3.630353,,,,,,,,,,,,,, -39700,1.2387606,2.6724207,,,,,,,,,,,,,, -39800,1.123091,2.7416756,,,,,,,,,,,,,, -39900,1.0514168,3.203973,,,,,,,,,,,,,, -40000,1.0974431,2.9572077,,,,,,,,,,,,,, -40100,1.0355589,3.1230054,,,,,,,,,,,,,, -40200,1.1663672,2.6796298,,,,,,,,,,,,,, -40270,,,0.6108984351158142,1.6063635349273682,0.56277996301651,1.844841718673706,50000.0,0.4452000260353088,2.520315647125244,10000.0,18526.620589256287,20469.3478808403,18526.620589256287,1939.2537033557887,1.3499047756195068,0.0 -40300,0.9585463,3.4596515,,,,,,,,,,,,,, -40400,1.3680557,2.7054603,,,,,,,,,,,,,, -40500,1.0003769,3.3450012,,,,,,,,,,,,,, -40600,0.87141913,4.885213,,,,,,,,,,,,,, -40700,0.9350667,5.377904,,,,,,,,,,,,,, -40800,1.1083658,2.8092122,,,,,,,,,,,,,, -40900,0.9624717,5.045082,,,,,,,,,,,,,, -41000,1.062567,3.1009088,,,,,,,,,,,,,, -41100,1.2336066,2.8395631,,,,,,,,,,,,,, -41184,,,0.6096289157867432,1.6241741180419922,0.5638999938964844,1.85176420211792,50000.0,0.4467000067234039,2.5229926109313965,10000.0,18946.871083021164,20934.40208363533,18946.871083021164,1983.9678659439087,1.3874144554138184,0.0 -41200,1.217174,2.9401727,,,,,,,,,,,,,, -41300,0.8987172,4.4794817,,,,,,,,,,,,,, -41400,1.1156458,2.6598644,,,,,,,,,,,,,, -41500,0.8894194,5.148108,,,,,,,,,,,,,, -41600,0.92119694,5.3885465,,,,,,,,,,,,,, -41700,0.9606105,5.271581,,,,,,,,,,,,,, -41800,0.9931987,4.188057,,,,,,,,,,,,,, -41900,0.91166884,5.2217026,,,,,,,,,,,,,, -42000,1.023418,3.8215437,,,,,,,,,,,,,, -42094,,,0.6166796684265137,1.6326512098312378,0.5676999688148499,1.8969939947128296,50000.0,0.4445000290870666,2.5553810596466064,10000.0,19367.147775888443,21400.516935825348,19367.147775888443,2029.7209577560425,1.4249577522277832,0.0 -42100,1.0856603,2.7922843,,,,,,,,,,,,,, -42200,1.149588,2.6757479,,,,,,,,,,,,,, -42300,1.5212513,2.6923676,,,,,,,,,,,,,, -42400,1.0716587,2.9104667,,,,,,,,,,,,,, -42500,1.1678175,2.7883215,,,,,,,,,,,,,, -42600,1.0539274,3.3222713,,,,,,,,,,,,,, -42700,1.0896683,3.8870869,,,,,,,,,,,,,, -42800,1.2577852,2.8897152,,,,,,,,,,,,,, -42900,1.0452514,5.3476863,,,,,,,,,,,,,, -43000,1.2771574,2.7041938,,,,,,,,,,,,,, -43004,,,0.6131054759025574,1.5964422225952148,0.5736199617385864,1.7933223247528076,50000.0,0.4534000158309936,2.4737699031829834,10000.0,19787.37788057328,21864.3144903183,19787.37788057328,2073.206840276718,1.4591336250305176,0.0 -43100,1.2384778,2.6524775,,,,,,,,,,,,,, -43200,1.025245,4.335183,,,,,,,,,,,,,, -43300,1.0093188,5.2708206,,,,,,,,,,,,,, -43400,1.1240351,2.8020253,,,,,,,,,,,,,, -43500,1.045128,5.2112813,,,,,,,,,,,,,, -43600,1.169842,2.736357,,,,,,,,,,,,,, -43700,1.1645403,2.7342653,,,,,,,,,,,,,, -43800,0.83326125,4.9778643,,,,,,,,,,,,,, -43900,1.2659682,2.9360392,,,,,,,,,,,,,, -43923,,,0.6158202886581421,1.6060982942581177,0.5713199973106384,1.8300766944885247,50000.0,0.4507000148296356,2.493374824523926,10000.0,20207.759017944336,22331.16711997986,20207.759017944336,2119.591502904892,1.496941328048706,0.0 -44000,1.0362344,2.8262901,,,,,,,,,,,,,, -44100,1.1140321,3.0011303,,,,,,,,,,,,,, -44200,1.2097591,3.2730277,,,,,,,,,,,,,, -44300,1.1416429,3.3351724,,,,,,,,,,,,,, -44400,1.4259362,2.9176145,,,,,,,,,,,,,, -44500,1.1741102,2.7727306,,,,,,,,,,,,,, -44600,1.1990905,2.796431,,,,,,,,,,,,,, -44700,0.94704866,5.2428284,,,,,,,,,,,,,, -44800,0.98795164,5.241023,,,,,,,,,,,,,, -44841,,,0.6277539134025574,1.5462582111358645,0.5756999850273132,1.8153027296066284,50000.0,0.4593000113964081,2.492283821105957,10000.0,20628.109059095383,22796.262968063354,20628.109059095383,2164.257128477097,1.5289440155029297,0.0 -44900,0.8605609,4.4795547,,,,,,,,,,,,,, -45000,0.9427876,4.6950855,,,,,,,,,,,,,, -45100,1.0816582,3.173271,,,,,,,,,,,,,, -45200,1.2150046,2.6272924,,,,,,,,,,,,,, -45300,0.99524194,4.446806,,,,,,,,,,,,,, -45400,1.2517524,2.740994,,,,,,,,,,,,,, -45500,0.93954307,4.6021814,,,,,,,,,,,,,, -45600,1.162228,3.2833245,,,,,,,,,,,,,, -45700,1.1894757,2.6268554,,,,,,,,,,,,,, -45758,,,0.6309570074081421,1.5428788661956787,0.5822799801826477,1.779966950416565,50000.0,0.4622000157833099,2.4568777084350586,10000.0,21048.3162214756,23262.23146867752,21048.3162214756,2209.939630508423,1.5598695278167725,0.0 -45800,1.1698214,2.701476,,,,,,,,,,,,,, -45900,1.0300163,3.5449371,,,,,,,,,,,,,, -46000,1.1571803,2.6052566,,,,,,,,,,,,,, -46100,1.1833547,2.651425,,,,,,,,,,,,,, -46200,0.9232885,4.461251,,,,,,,,,,,,,, -46300,1.084335,3.226396,,,,,,,,,,,,,, -46400,1.0227628,3.2902253,,,,,,,,,,,,,, -46500,0.9329874,5.3264236,,,,,,,,,,,,,, -46600,1.1067768,5.1298122,,,,,,,,,,,,,, -46675,,,0.6208202838897705,1.5916742086410522,0.574679970741272,1.8139941692352293,50000.0,0.4599000215530395,2.479991912841797,10000.0,21468.28194952011,23724.68097090721,21468.28194952011,2252.3399090766907,1.5934512615203855,0.0 -46700,0.91093343,4.9753704,,,,,,,,,,,,,, -46800,1.1370927,2.6601007,,,,,,,,,,,,,, -46900,1.1115704,2.6139073,,,,,,,,,,,,,, -47000,1.1834141,2.6021152,,,,,,,,,,,,,, -47100,1.1596739,2.666082,,,,,,,,,,,,,, -47200,1.072834,2.6783826,,,,,,,,,,,,,, -47300,1.1397586,2.84404,,,,,,,,,,,,,, -47400,1.1966345,2.5051122,,,,,,,,,,,,,, -47500,1.2268753,2.5915658,,,,,,,,,,,,,, -47590,,,0.6354101300239563,1.507873773574829,0.5806800127029419,1.768570065498352,50000.0,0.4625000357627868,2.438570737838745,10000.0,21888.7586414814,24188.97009658813,21888.7586414814,2296.0664880275726,1.631007432937622,0.0 -47600,1.2517623,2.5442457,,,,,,,,,,,,,, -47700,1.239344,2.6139338,,,,,,,,,,,,,, -47800,0.9164586,5.280665,,,,,,,,,,,,,, -47900,0.9862864,3.3891075,,,,,,,,,,,,,, -48000,1.276154,5.23933,,,,,,,,,,,,,, -48100,1.2032893,2.609734,,,,,,,,,,,,,, -48200,0.99216974,3.785577,,,,,,,,,,,,,, -48300,1.2014396,2.5973706,,,,,,,,,,,,,, -48400,1.1774228,2.885491,,,,,,,,,,,,,, -48500,1.0811334,2.9036055,,,,,,,,,,,,,, -48509,,,0.6536523103713989,1.431609869003296,0.5843600034713745,1.7661563158035278,50000.0,0.4689000248908996,2.4309449195861816,10000.0,22309.00814270973,24653.99893283844,22309.00814270973,2340.759823322296,1.6684105396270752,0.0 -48600,1.3794404,2.8142467,,,,,,,,,,,,,, -48700,1.3575408,2.6165118,,,,,,,,,,,,,, -48800,1.082524,3.2296715,,,,,,,,,,,,,, -48900,1.0999957,3.7251768,,,,,,,,,,,,,, -49000,1.0566013,4.4686136,,,,,,,,,,,,,, -49100,1.0685208,5.0552955,,,,,,,,,,,,,, -49200,1.2618113,2.7346158,,,,,,,,,,,,,, -49300,1.1124958,2.8301027,,,,,,,,,,,,,, -49400,1.2168425,2.7077928,,,,,,,,,,,,,, -49425,,,0.6244921684265137,1.5902440547943115,0.5818600058555603,1.7915534973144531,50000.0,0.4668000340461731,2.44736123085022,10000.0,22729.02753067017,25121.2830452919,22729.02753067017,2387.934749364853,1.710059642791748,0.0 -49500,1.1903566,2.6319659,,,,,,,,,,,,,, -49600,1.1768422,2.794569,,,,,,,,,,,,,, -49700,1.020001,5.1750517,,,,,,,,,,,,,, -49800,1.1165252,2.8219395,,,,,,,,,,,,,, -49900,1.069981,4.6096025,,,,,,,,,,,,,, -50000,0.9793189,5.242741,,,,,,,,,,,,,, -50100,1.1570427,2.530229,,,,,,,,,,,,,, -50200,1.0642401,3.9616334,,,,,,,,,,,,,, -50300,1.1954097,2.6962428,,,,,,,,,,,,,, -50342,,,0.6416601538658142,1.4782509803771973,0.5866000056266785,1.7392513751983645,50000.0,0.4702000319957733,2.4052207469940186,10000.0,23149.305313825607,25586.42380213737,23149.305313825607,2432.709528207779,1.7498483657836914,0.0 -50400,1.0729936,5.302037,,,,,,,,,,,,,, -50500,1.1885638,2.5208435,,,,,,,,,,,,,, -50600,1.3188404,2.690339,,,,,,,,,,,,,, -50700,1.1528497,3.0340383,,,,,,,,,,,,,, -50800,1.2758181,2.6543252,,,,,,,,,,,,,, -50900,0.95018816,5.1904116,,,,,,,,,,,,,, -51000,1.1628971,3.0256407,,,,,,,,,,,,,, -51100,1.0785923,2.8907468,,,,,,,,,,,,,, -51200,1.141289,2.6073375,,,,,,,,,,,,,, -51259,,,0.6513866782188416,1.4301273822784424,0.5900200009346008,1.729769229888916,50000.0,0.4697000086307525,2.399967670440674,10000.0,23569.449481725693,26051.07665014267,23569.449481725693,2477.1309113502502,1.788116216659546,0.0 -51300,1.3212168,3.2870896,,,,,,,,,,,,,, -51400,1.2199658,2.631341,,,,,,,,,,,,,, -51500,1.1450262,2.772272,,,,,,,,,,,,,, -51600,1.2247962,2.758501,,,,,,,,,,,,,, -51700,1.114654,3.1615427,,,,,,,,,,,,,, -51800,0.98976,5.1445956,,,,,,,,,,,,,, -51900,1.1443732,4.7375183,,,,,,,,,,,,,, -52000,1.0536016,5.2926064,,,,,,,,,,,,,, -52100,1.1987177,2.6245575,,,,,,,,,,,,,, -52176,,,0.6253125071525574,1.5959267616271973,0.5811799764633179,1.8141032457351685,50000.0,0.460500031709671,2.463486433029175,10000.0,23989.585379123688,26516.830800056458,23989.585379123688,2522.665075063705,1.8231792449951167,0.0 -52200,1.1607908,2.5382428,,,,,,,,,,,,,, -52300,1.3278757,2.5969949,,,,,,,,,,,,,, -52400,1.0845888,3.2555294,,,,,,,,,,,,,, -52500,1.2892513,2.5257514,,,,,,,,,,,,,, -52600,1.0256116,3.8711722,,,,,,,,,,,,,, -52700,1.1052928,5.202765,,,,,,,,,,,,,, -52800,1.2403504,2.700744,,,,,,,,,,,,,, -52900,1.1637591,2.4447615,,,,,,,,,,,,,, -53000,1.188359,2.5801578,,,,,,,,,,,,,, -53091,,,0.6471874713897705,1.444928526878357,0.5928999781608582,1.692473292350769,50000.0,0.4796000123023987,2.35530948638916,10000.0,24409.85411667824,26981.77447628975,24409.85411667824,2567.257641553879,1.857563018798828,0.0 -53100,1.1564286,2.6427433,,,,,,,,,,,,,, -53200,1.2177092,2.738417,,,,,,,,,,,,,, -53300,1.2397305,2.70089,,,,,,,,,,,,,, -53400,1.1859185,2.6263413,,,,,,,,,,,,,, -53500,1.0894756,5.2279897,,,,,,,,,,,,,, -53600,0.9876654,4.939114,,,,,,,,,,,,,, -53700,1.2418388,2.5861619,,,,,,,,,,,,,, -53800,1.0746886,3.2832003,,,,,,,,,,,,,, -53900,1.1262549,3.5916758,,,,,,,,,,,,,, -54000,1.3319834,2.5863993,,,,,,,,,,,,,, -54008,,,0.6502734422683716,1.464459776878357,0.5907399654388428,1.7406047582626345,50000.0,0.4711000323295593,2.407909393310547,10000.0,24830.11319756508,27445.589923620224,24830.11319756508,2610.7226634025574,1.9003996849060056,0.0 -54100,1.3829792,2.633584,,,,,,,,,,,,,, -54200,1.2741287,2.5242424,,,,,,,,,,,,,, -54300,0.99810076,5.084698,,,,,,,,,,,,,, -54400,1.0668882,4.320258,,,,,,,,,,,,,, -54500,0.91539276,4.8071594,,,,,,,,,,,,,, -54600,1.3095745,2.8445017,,,,,,,,,,,,,, -54700,0.9660889,4.5083137,,,,,,,,,,,,,, -54800,1.1404657,2.930766,,,,,,,,,,,,,, -54900,1.1337427,2.9943254,,,,,,,,,,,,,, -54927,,,0.6294726133346558,1.5730618238449097,0.5880399942398071,1.7772213220596311,50000.0,0.4720000326633453,2.425620555877685,10000.0,25250.349819660187,27912.01579046249,25250.349819660187,2656.829050302505,1.9344263076782229,0.0 -55000,1.1421297,2.5399508,,,,,,,,,,,,,, -55100,1.2780129,2.6591692,,,,,,,,,,,,,, -55200,1.0556475,5.23892,,,,,,,,,,,,,, -55300,1.3221055,2.615472,,,,,,,,,,,,,, -55400,1.1612399,3.355899,,,,,,,,,,,,,, -55500,0.9982592,4.736016,,,,,,,,,,,,,, -55600,1.2708758,2.841709,,,,,,,,,,,,,, -55700,1.2862699,2.6167877,,,,,,,,,,,,,, -55800,1.0816875,3.2007926,,,,,,,,,,,,,, -55845,,,0.6401562094688416,1.5094945430755615,0.5906800031661987,1.7373265027999878,50000.0,0.4745000302791595,2.4103524684906006,10000.0,25670.50795698166,28375.542023181915,25670.50795698166,2700.1080725193024,1.9745028018951416,0.0 -55900,1.3551508,2.6833344,,,,,,,,,,,,,, -56000,1.2392638,2.878447,,,,,,,,,,,,,, -56100,1.3041738,2.7686455,,,,,,,,,,,,,, -56200,1.2122896,2.6461945,,,,,,,,,,,,,, -56300,1.2866545,2.484857,,,,,,,,,,,,,, -56400,1.3397176,2.5233364,,,,,,,,,,,,,, -56500,1.2290103,2.566418,,,,,,,,,,,,,, -56600,1.4728596,2.6163328,,,,,,,,,,,,,, -56700,1.0244023,4.7122087,,,,,,,,,,,,,, -56763,,,0.646484375,1.4634003639221191,0.5952399969100952,1.7165286540985107,50000.0,0.4695000350475311,2.4129831790924072,10000.0,26090.748301029205,28838.13270163536,26090.748301029205,2742.359834194184,2.0240936279296875,0.0 -56800,1.2907206,2.6481645,,,,,,,,,,,,,, -56900,1.2833647,2.6197448,,,,,,,,,,,,,, -57000,1.2586259,2.585296,,,,,,,,,,,,,, -57100,1.3130827,2.620536,,,,,,,,,,,,,, -57200,1.3015753,2.4813335,,,,,,,,,,,,,, -57300,1.2547482,2.4768803,,,,,,,,,,,,,, -57400,1.2133838,2.9967744,,,,,,,,,,,,,, -57500,1.058346,4.1869254,,,,,,,,,,,,,, -57600,1.4503973,2.6336608,,,,,,,,,,,,,, -57682,,,0.6556054353713989,1.399447321891785,0.599299967288971,1.666598558425903,50000.0,0.4797000288963318,2.3530988693237305,10000.0,26510.98332166672,29304.03017401696,26510.98332166672,2787.934552669525,2.0629003047943115,0.0 -57700,1.0991777,2.9621084,,,,,,,,,,,,,, -57800,1.5685402,2.538513,,,,,,,,,,,,,, -57900,1.2196852,2.7303836,,,,,,,,,,,,,, -58000,1.3073187,5.1632214,,,,,,,,,,,,,, -58100,0.9942118,5.1524115,,,,,,,,,,,,,, -58200,1.2862031,2.554792,,,,,,,,,,,,,, -58300,1.0380663,3.914201,,,,,,,,,,,,,, -58400,1.2571373,2.5476825,,,,,,,,,,,,,, -58500,1.5346105,2.6984806,,,,,,,,,,,,,, -58600,1.4763879,2.7655284,,,,,,,,,,,,,, -58601,,,0.6381640434265137,1.4684827327728271,0.5956400036811829,1.700318694114685,50000.0,0.4820000231266022,2.3680901527404785,10000.0,26931.549178361893,29770.996066331863,26931.549178361893,2834.242516517639,2.105926275253296,0.0 -58700,1.0330402,5.174204,,,,,,,,,,,,,, -58800,1.4211528,2.5533195,,,,,,,,,,,,,, -58900,0.9354575,4.4604025,,,,,,,,,,,,,, -59000,1.0653316,4.2543716,,,,,,,,,,,,,, -59100,1.0292596,4.6337667,,,,,,,,,,,,,, -59200,1.0867571,2.787784,,,,,,,,,,,,,, -59300,1.2935851,2.618028,,,,,,,,,,,,,, -59400,1.3438281,2.5810778,,,,,,,,,,,,,, -59500,1.0940706,4.584988,,,,,,,,,,,,,, -59519,,,0.6510156393051147,1.446966528892517,0.5954599976539612,1.7019143104553225,50000.0,0.4812000095844269,2.367074728012085,10000.0,27351.7313952446,30235.01827669144,27351.7313952446,2877.9978761672974,2.141942977905273,0.0 -59600,1.2230736,2.5723214,,,,,,,,,,,,,, -59700,1.2904071,3.074059,,,,,,,,,,,,,, -59800,1.1068608,4.686555,,,,,,,,,,,,,, -59900,1.3115159,2.5144734,,,,,,,,,,,,,, -60000,1.337763,3.033666,,,,,,,,,,,,,, -60100,1.0233095,3.5454602,,,,,,,,,,,,,, -60200,1.0801543,3.2323737,,,,,,,,,,,,,, -60300,1.3042146,2.4490857,,,,,,,,,,,,,, -60400,1.1399932,3.5298557,,,,,,,,,,,,,, -60437,,,0.6662304401397705,1.356255054473877,0.5960400104522705,1.6892080307006836,50000.0,0.4795000255107879,2.367875814437866,10000.0,27772.0975227356,30701.9356341362,27772.0975227356,2924.465174674988,2.1778712272644043,0.0 -60500,1.35453,2.638059,,,,,,,,,,,,,, -60600,1.0493308,4.7431865,,,,,,,,,,,,,, -60700,1.3238207,2.555734,,,,,,,,,,,,,, -60800,1.0706704,4.1487174,,,,,,,,,,,,,, -60900,1.2242316,2.94293,,,,,,,,,,,,,, -61000,1.3688934,2.5896444,,,,,,,,,,,,,, -61100,1.299383,2.686348,,,,,,,,,,,,,, -61200,1.1588033,2.4952838,,,,,,,,,,,,,, -61300,1.1301409,3.0129924,,,,,,,,,,,,,, -61357,,,0.6473046541213989,1.4790903329849243,0.5974400043487549,1.7262574434280396,50000.0,0.4777000248432159,2.391491651535034,10000.0,28192.18399953842,31161.63326120377,28192.18399953842,2963.984852075577,2.220782995223999,0.0 -61400,1.1821448,2.69427,,,,,,,,,,,,,, -61500,1.0238419,4.35985,,,,,,,,,,,,,, -61600,1.2451583,2.5561452,,,,,,,,,,,,,, -61700,1.3638536,2.5657547,,,,,,,,,,,,,, -61800,1.1600333,2.7141361,,,,,,,,,,,,,, -61900,1.0953579,3.5980043,,,,,,,,,,,,,, -62000,1.4055269,2.5929365,,,,,,,,,,,,,, -62100,1.2060589,3.1844883,,,,,,,,,,,,,, -62200,1.2563723,2.5991747,,,,,,,,,,,,,, -62274,,,0.6483398079872131,1.4805115461349487,0.5988399982452393,1.7147594690322876,50000.0,0.4845000207424164,2.359397649765014,10000.0,28612.51757788658,31627.710637569427,28612.51757788658,3009.6463346481323,2.254441261291504,0.0 -62300,1.364522,2.4724,,,,,,,,,,,,,, -62400,1.1858959,2.6281872,,,,,,,,,,,,,, -62500,1.0631526,4.9590464,,,,,,,,,,,,,, -62600,1.1277235,2.9221923,,,,,,,,,,,,,, -62700,1.162946,2.8210893,,,,,,,,,,,,,, -62800,0.9917362,4.005134,,,,,,,,,,,,,, -62900,1.25739,2.595766,,,,,,,,,,,,,, -63000,1.1681585,3.7002215,,,,,,,,,,,,,, -63100,1.3023982,2.3546677,,,,,,,,,,,,,, -63191,,,0.6765429377555847,1.3175029754638672,0.6112399697303772,1.6248478889465332,50000.0,0.4908000230789184,2.2709903717041016,10000.0,29032.7456882,32090.70645928383,29032.7456882,3052.322532892227,2.297640323638916,0.0 -63200,1.0355015,5.192359,,,,,,,,,,,,,, -63300,1.1286433,5.0073247,,,,,,,,,,,,,, -63400,1.1697651,2.8264341,,,,,,,,,,,,,, -63500,1.2161297,2.7554407,,,,,,,,,,,,,, -63600,1.2695334,3.0213678,,,,,,,,,,,,,, -63700,1.1630785,2.7302287,,,,,,,,,,,,,, -63800,1.2476243,2.5079134,,,,,,,,,,,,,, -63900,1.1068078,3.1517541,,,,,,,,,,,,,, -64000,0.94348323,4.891645,,,,,,,,,,,,,, -64100,1.0800915,5.1055865,,,,,,,,,,,,,, -64108,,,0.6566015481948853,1.395091533660889,0.6128999590873718,1.6186541318893433,50000.0,0.4914000332355499,2.2900989055633545,10000.0,29452.76300573349,32554.79523062706,29452.76300573349,3096.3077614307404,2.3357009887695312,0.0 -64200,0.98299503,3.7885907,,,,,,,,,,,,,, -64300,1.205042,4.8053575,,,,,,,,,,,,,, -64400,1.2883877,2.5892458,,,,,,,,,,,,,, -64500,1.3505349,2.5047731,,,,,,,,,,,,,, -64600,1.3699397,2.731217,,,,,,,,,,,,,, -64700,1.3176619,2.55661,,,,,,,,,,,,,, -64800,1.2032825,2.5104077,,,,,,,,,,,,,, -64900,1.1947523,5.091879,,,,,,,,,,,,,, -65000,1.230805,5.102867,,,,,,,,,,,,,, -65024,,,0.6565819978713989,1.4120858907699585,0.6054199934005737,1.6623183488845823,50000.0,0.487600028514862,2.3212289810180664,10000.0,29873.10894012451,33020.64948439598,29873.10894012451,3141.72727894783,2.376920223236084,0.0 -65100,1.1658746,5.0613384,,,,,,,,,,,,,, -65200,1.1425402,5.1854506,,,,,,,,,,,,,, -65300,1.3021826,2.6163332,,,,,,,,,,,,,, -65400,1.251143,2.4252393,,,,,,,,,,,,,, -65500,1.3090991,2.5228992,,,,,,,,,,,,,, -65600,0.9750711,5.065139,,,,,,,,,,,,,, -65700,1.2036594,2.3548155,,,,,,,,,,,,,, -65800,1.0427773,3.4748673,,,,,,,,,,,,,, -65900,0.94310445,5.027237,,,,,,,,,,,,,, -65942,,,0.6722265481948853,1.4081395864486694,0.6116799712181091,1.6795889139175415,50000.0,0.4879000186920166,2.3212332725524902,10000.0,30293.38259911537,33486.357377290726,30293.38259911537,3187.0775051116943,2.411798477172852,0.0 -66000,1.3279402,2.4194334,,,,,,,,,,,,,, -66100,1.3919185,2.5251837,,,,,,,,,,,,,, -66200,1.2031093,2.9571404,,,,,,,,,,,,,, -66300,1.3390083,2.9525144,,,,,,,,,,,,,, -66400,1.2680038,3.3372028,,,,,,,,,,,,,, -66500,1.045368,4.8750544,,,,,,,,,,,,,, -66600,1.4275682,2.4592416,,,,,,,,,,,,,, -66700,1.6481175,2.55866,,,,,,,,,,,,,, -66800,1.3870407,2.4198463,,,,,,,,,,,,,, -66860,,,0.6582812070846558,1.39947247505188,0.6097399592399597,1.6316499710083008,50000.0,0.4897000193595886,2.283788681030273,10000.0,30713.42664384842,33950.26797604561,30713.42664384842,3230.855504989624,2.4521572589874268,0.0 -66900,1.0797129,4.7417855,,,,,,,,,,,,,, -67000,1.1265024,5.0935426,,,,,,,,,,,,,, -67100,0.98307055,3.98997,,,,,,,,,,,,,, -67200,1.2388995,5.1975517,,,,,,,,,,,,,, -67300,1.3090516,2.4890459,,,,,,,,,,,,,, -67400,1.1996214,3.0358982,,,,,,,,,,,,,, -67500,1.1647309,3.5854032,,,,,,,,,,,,,, -67600,1.1138983,3.6636562,,,,,,,,,,,,,, -67700,1.0806692,3.8398683,,,,,,,,,,,,,, -67778,,,0.6619336009025574,1.376253962516785,0.6098399758338928,1.6209572553634644,50000.0,0.4891000092029571,2.292510986328125,10000.0,31133.74489402771,34415.31527900696,31133.74489402771,3275.497382879257,2.491576910018921,0.0 -67800,1.3394375,2.5192955,,,,,,,,,,,,,, -67900,1.2230103,2.5629652,,,,,,,,,,,,,, -68000,1.2209282,3.4225612,,,,,,,,,,,,,, -68100,1.2573186,4.939247,,,,,,,,,,,,,, -68200,1.3750991,2.4093287,,,,,,,,,,,,,, -68300,1.3059973,2.7660165,,,,,,,,,,,,,, -68400,1.0525603,3.6352181,,,,,,,,,,,,,, -68500,0.9752643,4.8025494,,,,,,,,,,,,,, -68600,1.3152531,2.5517673,,,,,,,,,,,,,, -68697,,,0.6661718487739563,1.374730348587036,0.6109600067138672,1.637009620666504,50000.0,0.4903000295162201,2.307218551635742,10000.0,31553.782564401627,34879.88833665848,31553.782564401627,3319.9401018619537,2.5363433361053467,0.0 -68700,1.1891992,5.011734,,,,,,,,,,,,,, -68800,1.2550759,2.344183,,,,,,,,,,,,,, -68900,1.2921792,2.4291825,,,,,,,,,,,,,, -69000,1.4876969,2.4442103,,,,,,,,,,,,,, -69100,1.0404207,3.2772405,,,,,,,,,,,,,, -69200,1.2471448,2.476325,,,,,,,,,,,,,, -69300,1.0133288,4.5465975,,,,,,,,,,,,,, -69400,1.0357667,5.0388465,,,,,,,,,,,,,, -69500,1.2455791,2.4311035,,,,,,,,,,,,,, -69600,1.276182,2.3259335,,,,,,,,,,,,,, -69615,,,0.6862890720367432,1.303639531135559,0.6137999892234802,1.6325457096099854,50000.0,0.4954000115394592,2.2856035232543945,10000.0,31974.15192389488,35339.82807254791,31974.15192389488,3359.4197540283203,2.5785627365112305,0.0 -69700,1.2922637,2.2822807,,,,,,,,,,,,,, -69800,1.0201955,4.39592,,,,,,,,,,,,,, -69900,1.0590835,4.6610174,,,,,,,,,,,,,, -70000,1.4113407,2.5923164,,,,,,,,,,,,,, -70100,1.0427616,3.2872443,,,,,,,,,,,,,, -70200,1.3723836,2.9068775,,,,,,,,,,,,,, -70300,1.3663173,2.4210813,,,,,,,,,,,,,, -70400,1.3225403,2.4960918,,,,,,,,,,,,,, -70500,1.2492517,2.3596125,,,,,,,,,,,,,, -70532,,,0.6647265553474426,1.3673946857452393,0.618619978427887,1.584282636642456,50000.0,0.5014000535011292,2.240478515625,10000.0,32394.14501833916,35805.46223473549,32394.14501833916,3404.967422485352,2.6233327388763428,0.0 -70600,1.3903877,2.4831855,,,,,,,,,,,,,, -70700,1.24099,3.0895665,,,,,,,,,,,,,, -70800,1.1766592,3.687613,,,,,,,,,,,,,, -70900,0.99334043,4.218916,,,,,,,,,,,,,, -71000,1.1080385,4.8360653,,,,,,,,,,,,,, -71100,1.0700388,4.0574493,,,,,,,,,,,,,, -71200,1.3029041,2.5117383,,,,,,,,,,,,,, -71300,1.3088995,2.9165804,,,,,,,,,,,,,, -71400,1.151494,4.497271,,,,,,,,,,,,,, -71449,,,0.6814843416213989,1.3013520240783691,0.6236599683761597,1.567617654800415,50000.0,0.4984000325202942,2.2308449745178223,10000.0,32814.301644325256,36272.37492394447,32814.301644325256,3451.6360535621643,2.662638902664185,0.0 -71500,1.3290651,2.5514958,,,,,,,,,,,,,, -71600,1.4543121,2.4476924,,,,,,,,,,,,,, -71700,1.1954697,2.9890547,,,,,,,,,,,,,, -71800,1.2115532,5.074474,,,,,,,,,,,,,, -71900,1.3901677,2.394882,,,,,,,,,,,,,, -72000,1.185549,4.7589993,,,,,,,,,,,,,, -72100,1.2311102,2.779051,,,,,,,,,,,,,, -72200,1.297073,2.7484806,,,,,,,,,,,,,, -72300,1.0159822,4.7373915,,,,,,,,,,,,,, -72368,,,0.6871093511581421,1.261122465133667,0.6153599619865417,1.602860927581787,50000.0,0.4963000118732452,2.253334045410156,10000.0,33234.60793232918,36739.626620054245,33234.60793232918,3498.489343166352,2.706322431564331,0.0 -72400,1.325364,2.4057736,,,,,,,,,,,,,, -72500,1.2697786,3.512175,,,,,,,,,,,,,, -72600,1.151325,4.8954935,,,,,,,,,,,,,, -72700,1.3213079,4.6126184,,,,,,,,,,,,,, -72800,1.274351,4.846816,,,,,,,,,,,,,, -72900,1.382611,2.4704704,,,,,,,,,,,,,, -73000,1.1183108,4.6097903,,,,,,,,,,,,,, -73100,1.2721858,2.4401119,,,,,,,,,,,,,, -73200,1.3763272,2.5658607,,,,,,,,,,,,,, -73287,,,0.6668750047683716,1.3529757261276243,0.6220799684524536,1.573674201965332,50000.0,0.5010000467300415,2.2361388206481934,10000.0,33654.99967765808,37200.15159320831,33654.99967765808,3538.5254430770874,2.7534587383270264,0.0 -73300,1.3537174,2.572242,,,,,,,,,,,,,, -73400,1.1617647,4.234084,,,,,,,,,,,,,, -73500,1.387378,2.4556732,,,,,,,,,,,,,, -73600,1.133214,4.487252,,,,,,,,,,,,,, -73700,1.190981,3.6092203,,,,,,,,,,,,,, -73800,1.361348,2.495752,,,,,,,,,,,,,, -73900,1.3810099,2.371092,,,,,,,,,,,,,, -74000,1.2460108,2.9352307,,,,,,,,,,,,,, -74100,1.3592631,2.4501495,,,,,,,,,,,,,, -74200,1.2961797,2.4328644,,,,,,,,,,,,,, -74204,,,0.6692773103713989,1.3801852464675903,0.6154199838638306,1.623300552368164,50000.0,0.4923000335693359,2.295822858810425,10000.0,34075.14587640762,37663.99420571327,34075.14587640762,3582.1325442790985,2.7938594818115234,0.0 -74300,1.1393898,3.2872128,,,,,,,,,,,,,, -74400,1.4220842,2.3978894,,,,,,,,,,,,,, -74500,1.1171049,5.066608,,,,,,,,,,,,,, -74600,1.2157453,2.7590108,,,,,,,,,,,,,, -74700,0.93277556,5.123307,,,,,,,,,,,,,, -74800,1.3661698,2.4360666,,,,,,,,,,,,,, -74900,1.124453,3.953754,,,,,,,,,,,,,, -75000,1.3202349,3.2304344,,,,,,,,,,,,,, -75100,1.3759695,2.4593923,,,,,,,,,,,,,, -75122,,,0.6792187094688416,1.330292582511902,0.6165399551391602,1.6309008598327637,50000.0,0.4949000179767608,2.308547258377075,10000.0,34495.48349118233,38128.09023118019,34495.48349118233,3625.7986521720886,2.837161064147949,0.0 -75200,1.33972,2.4079425,,,,,,,,,,,,,, -75300,1.4454342,2.34256,,,,,,,,,,,,,, -75400,1.1548465,4.8433676,,,,,,,,,,,,,, -75500,1.7916392,2.414957,,,,,,,,,,,,,, -75600,1.3657489,2.8088663,,,,,,,,,,,,,, -75700,1.3590212,2.3578944,,,,,,,,,,,,,, -75800,1.1506548,2.9502475,,,,,,,,,,,,,, -75900,1.5239954,2.384844,,,,,,,,,,,,,, -76000,1.3076032,2.329563,,,,,,,,,,,,,, -76040,,,0.668652355670929,1.3684196472167969,0.6186800003051758,1.6112457513809204,50000.0,0.4988000094890594,2.2662360668182373,10000.0,34915.78348207474,38592.96436548233,34915.78348207474,3670.284652233124,2.877194404602051,0.0 -76100,1.2309695,2.5706244,,,,,,,,,,,,,, -76200,1.1444278,4.8419366,,,,,,,,,,,,,, -76300,1.3878618,2.5390215,,,,,,,,,,,,,, -76400,1.477032,2.4385881,,,,,,,,,,,,,, -76500,1.3277533,2.8779316,,,,,,,,,,,,,, -76600,1.4082596,5.065243,,,,,,,,,,,,,, -76700,1.3112009,2.5062265,,,,,,,,,,,,,, -76800,1.2630296,3.1332746,,,,,,,,,,,,,, -76900,1.3783238,2.4185236,,,,,,,,,,,,,, -76959,,,0.6804882884025574,1.2870023250579834,0.6269800066947937,1.5367377996444702,50000.0,0.509600043296814,2.2034332752227783,10000.0,35335.89672112465,39057.71058821678,35335.89672112465,3714.829068660736,2.9171230792999268,0.0 -77000,1.2594969,2.3788762,,,,,,,,,,,,,, -77100,1.2252616,2.766881,,,,,,,,,,,,,, -77200,1.1713694,5.124359,,,,,,,,,,,,,, -77300,1.085566,3.788675,,,,,,,,,,,,,, -77400,1.3503635,2.4462001,,,,,,,,,,,,,, -77500,1.1910559,3.6117182,,,,,,,,,,,,,, -77600,1.2557819,2.4892898,,,,,,,,,,,,,, -77700,1.3932797,2.6479232,,,,,,,,,,,,,, -77800,1.5541307,2.372135,,,,,,,,,,,,,, -77879,,,0.6826757788658142,1.2964460849761963,0.6247599720954895,1.580952763557434,50000.0,0.4988000094890594,2.237107753753662,10000.0,35755.94821023941,39520.92169165611,35755.94821023941,3757.898426532746,2.958980321884156,0.0 -77900,1.8064166,2.4325943,,,,,,,,,,,,,, -78000,1.3915926,4.0568523,,,,,,,,,,,,,, -78100,1.5281276,2.5185688,,,,,,,,,,,,,, -78200,1.3504089,2.4308841,,,,,,,,,,,,,, -78300,1.2107248,2.4409413,,,,,,,,,,,,,, -78400,1.2999701,4.428824,,,,,,,,,,,,,, -78500,1.2350428,2.4310749,,,,,,,,,,,,,, -78600,1.2710454,2.429087,,,,,,,,,,,,,, -78700,1.3427851,2.3411944,,,,,,,,,,,,,, -78797,,,0.6805663704872131,1.3006861209869385,0.6326799988746643,1.5357210636138916,50000.0,0.5067000389099121,2.2033281326293945,10000.0,36176.19422531128,39986.41137290001,36176.19422531128,3803.0521445274353,3.0000874996185303,0.0 -78800,1.2813487,2.8188376,,,,,,,,,,,,,, -78900,1.136766,3.699154,,,,,,,,,,,,,, -79000,1.3836503,2.5401716,,,,,,,,,,,,,, -79100,1.3602924,2.439101,,,,,,,,,,,,,, -79200,1.2302761,4.0239863,,,,,,,,,,,,,, -79300,1.3227652,2.4453132,,,,,,,,,,,,,, -79400,1.2371738,2.2570248,,,,,,,,,,,,,, -79500,1.2988944,2.4431992,,,,,,,,,,,,,, -79600,1.3072405,2.6912827,,,,,,,,,,,,,, -79700,1.2036035,3.344731,,,,,,,,,,,,,, -79714,,,0.6774023175239563,1.3380963802337646,0.6273800134658813,1.583625555038452,50000.0,0.5046000480651855,2.24481201171875,10000.0,36596.54655098915,40453.15292882919,36596.54655098915,3849.35542011261,3.0385146141052246,0.0 -79800,1.1379592,4.4867105,,,,,,,,,,,,,, -79900,1.1617469,3.7996469,,,,,,,,,,,,,, -80000,1.4109496,2.7936502,,,,,,,,,,,,,, -80100,1.3241073,2.4404461,,,,,,,,,,,,,, -80200,1.2896314,2.3194695,,,,,,,,,,,,,, -80300,1.2165309,2.9016252,,,,,,,,,,,,,, -80400,1.332957,3.512255,,,,,,,,,,,,,, -80500,1.3979805,2.3712034,,,,,,,,,,,,,, -80600,1.2316569,4.937988,,,,,,,,,,,,,, -80632,,,0.6942577958106995,1.2454754114151,0.6340599656105042,1.517375946044922,50000.0,0.509600043296814,2.193935871124268,10000.0,37016.87239527702,40916.673253536224,37016.87239527702,3892.457942724228,3.081490993499756,0.0 -80700,1.3313954,2.4501367,,,,,,,,,,,,,, -80800,1.5903836,2.6176214,,,,,,,,,,,,,, -80900,1.2061557,2.802205,,,,,,,,,,,,,, -81000,1.1702754,2.84983,,,,,,,,,,,,,, -81100,1.4331579,2.3384228,,,,,,,,,,,,,, -81200,1.256537,4.2849255,,,,,,,,,,,,,, -81300,1.1749942,4.433774,,,,,,,,,,,,,, -81400,1.1353239,4.908223,,,,,,,,,,,,,, -81500,1.2400943,3.4931622,,,,,,,,,,,,,, -81549,,,0.7085741758346558,1.1876144409179688,0.6307399868965149,1.532135248184204,50000.0,0.5088000297546387,2.1980860233306885,10000.0,37437.10391163826,41381.222581624985,37437.10391163826,3936.6828587055206,3.125739812850952,0.0 -81600,1.1891233,3.3217518,,,,,,,,,,,,,, -81700,1.263191,3.9678571,,,,,,,,,,,,,, -81800,1.5888275,2.4824243,,,,,,,,,,,,,, -81900,1.4309298,2.361088,,,,,,,,,,,,,, -82000,1.2525716,2.329119,,,,,,,,,,,,,, -82100,1.4292675,2.4849646,,,,,,,,,,,,,, -82200,1.2385044,3.5795865,,,,,,,,,,,,,, -82300,1.2622631,3.1686988,,,,,,,,,,,,,, -82400,1.3887001,2.4212532,,,,,,,,,,,,,, -82465,,,0.6794531345367432,1.320271611213684,0.6279999613761902,1.5583600997924805,50000.0,0.5051000118255615,2.229015350341797,10000.0,37857.19369125366,41843.72469615936,37857.19369125366,3979.00262093544,3.170383214950561,0.0 -82500,1.2949547,2.3554382,,,,,,,,,,,,,, -82600,1.2657579,2.9328344,,,,,,,,,,,,,, -82700,1.1177787,3.7519753,,,,,,,,,,,,,, -82800,1.1142008,3.7322721,,,,,,,,,,,,,, -82900,1.1370226,4.5467277,,,,,,,,,,,,,, -83000,1.4921976,2.4581268,,,,,,,,,,,,,, -83100,1.2167659,4.843635,,,,,,,,,,,,,, -83200,1.3256731,2.372278,,,,,,,,,,,,,, -83300,1.3262674,2.431157,,,,,,,,,,,,,, -83383,,,0.6886913776397705,1.2884644269943235,0.6291199922561646,1.5523051023483276,50000.0,0.5099000334739685,2.218010187149048,10000.0,38277.31278705597,42309.54677009583,38277.31278705597,4024.61803984642,3.2090845108032227,0.0 -83400,1.3908442,2.5519063,,,,,,,,,,,,,, -83500,1.6418694,2.3503242,,,,,,,,,,,,,, -83600,1.3190141,2.730023,,,,,,,,,,,,,, -83700,1.3465513,2.8766937,,,,,,,,,,,,,, -83800,1.3415366,2.4581654,,,,,,,,,,,,,, -83900,1.1904254,4.44119,,,,,,,,,,,,,, -84000,1.326596,2.3421454,,,,,,,,,,,,,, -84100,1.297607,2.2651522,,,,,,,,,,,,,, -84200,1.3207645,2.6459029,,,,,,,,,,,,,, -84300,1.481578,2.2587476,,,,,,,,,,,,,, -84301,,,0.7104882597923279,1.1823593378067017,0.6382799744606018,1.511709451675415,50000.0,0.5146000385284424,2.178715229034424,10000.0,38697.50476980209,42774.288845300674,38697.50476980209,4069.077126741409,3.251221179962158,0.0 -84400,1.4928676,2.353109,,,,,,,,,,,,,, -84500,1.3516474,2.2942114,,,,,,,,,,,,,, -84600,1.3812706,2.4589462,,,,,,,,,,,,,, -84700,1.2803326,3.8808892,,,,,,,,,,,,,, -84800,1.7382765,2.4949133,,,,,,,,,,,,,, -84900,1.3163383,2.2521114,,,,,,,,,,,,,, -85000,1.5261978,2.3957293,,,,,,,,,,,,,, -85100,1.4292382,3.3923888,,,,,,,,,,,,,, -85200,1.4529843,2.4155405,,,,,,,,,,,,,, -85212,,,0.6768164038658142,1.3248145580291748,0.6277599930763245,1.5532524585723877,50000.0,0.5046000480651855,2.224882364273072,10000.0,39117.63734054565,43238.81547117233,39117.63734054565,4113.385211467743,3.2891671657562256,0.0 -85300,1.0974582,4.3563724,,,,,,,,,,,,,, -85400,1.2014031,3.5774953,,,,,,,,,,,,,, -85500,1.2868252,2.628014,,,,,,,,,,,,,, -85600,1.2294269,2.876846,,,,,,,,,,,,,, -85700,1.212785,2.5878377,,,,,,,,,,,,,, -85800,1.4221066,2.5283442,,,,,,,,,,,,,, -85900,1.3548498,4.8930445,,,,,,,,,,,,,, -86000,1.0842174,4.8859253,,,,,,,,,,,,,, -86100,1.2593884,4.408622,,,,,,,,,,,,,, -86128,,,0.6905273199081421,1.2369657754898071,0.6373599767684937,1.5001189708709717,50000.0,0.5139999985694885,2.168210744857788,10000.0,39537.63686776161,43704.54471921921,39537.63686776161,4159.022690296173,3.333111047744751,0.0 -86200,1.3646071,4.3689284,,,,,,,,,,,,,, -86300,1.4218872,2.3910298,,,,,,,,,,,,,, -86400,1.381205,2.2998881,,,,,,,,,,,,,, -86500,1.4112605,2.549025,,,,,,,,,,,,,, -86600,1.4522607,2.6386068,,,,,,,,,,,,,, -86700,1.7370929,2.4679465,,,,,,,,,,,,,, -86800,1.5244203,2.5831468,,,,,,,,,,,,,, -86900,1.3753852,2.30269,,,,,,,,,,,,,, -87000,1.2047037,5.0308037,,,,,,,,,,,,,, -87043,,,0.7060351371765137,1.2162786722183228,0.64301997423172,1.5114175081253052,50000.0,0.5199000239372253,2.167684316635132,10000.0,39957.94800043106,44169.95958900452,39957.94800043106,4204.034997463226,3.376688241958618,0.0 -87100,1.3292934,3.706297,,,,,,,,,,,,,, -87200,1.5618757,2.3465426,,,,,,,,,,,,,, -87300,1.3447163,2.216455,,,,,,,,,,,,,, -87400,1.3528492,2.2214804,,,,,,,,,,,,,, -87500,1.1612991,4.1477976,,,,,,,,,,,,,, -87600,1.204299,3.4922197,,,,,,,,,,,,,, -87700,1.3605795,2.2894804,,,,,,,,,,,,,, -87800,1.3830857,2.336557,,,,,,,,,,,,,, -87900,1.2666467,4.5596437,,,,,,,,,,,,,, -87961,,,0.6932812333106995,1.2655775547027588,0.6387799978256226,1.5086603164672852,50000.0,0.5170000195503235,2.170346975326538,10000.0,40378.29867053032,44637.17258501053,40378.29867053032,4250.808509349823,3.416672706604004,0.0 -88000,1.18894,4.456615,,,,,,,,,,,,,, -88100,1.3672175,2.4971757,,,,,,,,,,,,,, -88200,1.4749073,2.3144593,,,,,,,,,,,,,, -88300,1.6321062,2.2005908,,,,,,,,,,,,,, -88400,1.4241892,2.384794,,,,,,,,,,,,,, -88500,1.3867928,2.720461,,,,,,,,,,,,,, -88600,1.3512158,2.2750454,,,,,,,,,,,,,, -88700,1.3779522,4.0230646,,,,,,,,,,,,,, -88800,1.2284023,4.78724,,,,,,,,,,,,,, -88880,,,0.7005273103713989,1.252591609954834,0.6423400044441223,1.520622730255127,50000.0,0.5182000398635864,2.1785120964050293,10000.0,40798.624522686005,45101.79077982903,40798.624522686005,4295.005742549896,3.462848424911499,0.0 -88900,1.2426717,2.5694878,,,,,,,,,,,,,, -89000,1.2831653,3.3654723,,,,,,,,,,,,,, -89100,1.1663144,3.5972452,,,,,,,,,,,,,, -89200,1.508686,2.2690349,,,,,,,,,,,,,, -89300,1.5272661,2.3221705,,,,,,,,,,,,,, -89400,1.3947078,2.4021418,,,,,,,,,,,,,, -89500,1.6004338,2.412394,,,,,,,,,,,,,, -89600,1.2588868,2.8488686,,,,,,,,,,,,,, -89700,1.5247009,2.350657,,,,,,,,,,,,,, -89796,,,0.7016210556030273,1.2490087747573853,0.6393600106239319,1.5409966707229614,50000.0,0.5162000060081482,2.1909217834472656,10000.0,41218.92868351936,45567.463312625885,41218.92868351936,4340.285962820053,3.503141403198242,0.0 -89800,1.3975742,2.692666,,,,,,,,,,,,,, -89900,1.6500998,2.2518435,,,,,,,,,,,,,, -90000,1.5463151,2.2041593,,,,,,,,,,,,,, -90100,1.3752956,2.247694,,,,,,,,,,,,,, -90200,1.4867878,2.1948342,,,,,,,,,,,,,, -90300,1.4526536,2.3295326,,,,,,,,,,,,,, -90400,1.3744729,2.7259634,,,,,,,,,,,,,, -90500,1.3179505,2.5833585,,,,,,,,,,,,,, -90600,1.207351,2.9141836,,,,,,,,,,,,,, -90700,1.6901963,2.3599906,,,,,,,,,,,,,, -90713,,,0.69837886095047,1.2028743028640747,0.6417199969291687,1.4627127647399902,50000.0,0.5216000080108643,2.1296322345733643,10000.0,41639.11161088944,46028.47916865349,41639.11161088944,4381.028985977173,3.543948173522949,0.0 -90800,1.2960062,2.4969018,,,,,,,,,,,,,, -90900,1.358201,4.9766655,,,,,,,,,,,,,, -91000,1.4569072,2.3192215,,,,,,,,,,,,,, -91100,1.3570642,2.3215597,,,,,,,,,,,,,, -91200,1.4218144,2.838863,,,,,,,,,,,,,, -91300,1.3615896,2.4658031,,,,,,,,,,,,,, -91400,1.4128678,2.3031223,,,,,,,,,,,,,, -91500,1.3035086,3.798931,,,,,,,,,,,,,, -91600,1.4079295,2.587213,,,,,,,,,,,,,, -91631,,,0.6951367259025574,1.2514410018920898,0.6416599750518799,1.5035076141357422,50000.0,0.5245000123977661,2.149744272232056,10000.0,42059.214393138885,46494.97500014305,42059.214393138885,4427.328744649887,3.5883829593658447,0.0 -91700,1.2560538,2.993421,,,,,,,,,,,,,, -91800,1.2591901,4.9017525,,,,,,,,,,,,,, -91900,1.3747226,4.5962415,,,,,,,,,,,,,, -92000,1.3803871,2.1275911,,,,,,,,,,,,,, -92100,1.6746818,2.3915062,,,,,,,,,,,,,, -92200,1.409501,2.6017056,,,,,,,,,,,,,, -92300,1.3463358,4.6558332,,,,,,,,,,,,,, -92400,1.4154392,2.3655634,,,,,,,,,,,,,, -92500,1.4645661,2.378336,,,,,,,,,,,,,, -92549,,,0.703320324420929,1.215058445930481,0.6431199908256531,1.4973742961883545,50000.0,0.5261000394821167,2.147681713104248,10000.0,42479.155719041824,46960.63577723503,42479.155719041824,4472.958468675613,3.6285886764526367,0.0 -92600,1.2982702,3.0011444,,,,,,,,,,,,,, -92700,1.6439623,2.294889,,,,,,,,,,,,,, -92800,1.5504853,2.337013,,,,,,,,,,,,,, -92900,1.2309734,3.9597747,,,,,,,,,,,,,, -93000,1.5815609,2.5262737,,,,,,,,,,,,,, -93100,1.3234113,3.085944,,,,,,,,,,,,,, -93200,1.7069907,2.285988,,,,,,,,,,,,,, -93300,1.408431,4.8752007,,,,,,,,,,,,,, -93400,1.2909206,3.1707165,,,,,,,,,,,,,, -93468,,,0.7299999594688416,1.073001265525818,0.6488400101661682,1.4530651569366455,50000.0,0.5208000540733337,2.1210744380950928,10000.0,42899.42155408859,47427.48384642601,42899.42155408859,4519.443992853165,3.677181482315064,0.0 -93500,1.722117,2.3415878,,,,,,,,,,,,,, -93600,1.5745794,2.3350883,,,,,,,,,,,,,, -93700,1.409449,2.3019335,,,,,,,,,,,,,, -93800,1.1961204,3.7629297,,,,,,,,,,,,,, -93900,1.3339202,3.0419447,,,,,,,,,,,,,, -94000,1.4679002,2.1860027,,,,,,,,,,,,,, -94100,1.3614823,2.0791168,,,,,,,,,,,,,, -94200,1.669143,4.616843,,,,,,,,,,,,,, -94300,1.2065301,3.883254,,,,,,,,,,,,,, -94385,,,0.7069921493530273,1.1722054481506348,0.6530399918556213,1.4273556470870972,50000.0,0.5323000550270081,2.078791618347168,10000.0,43319.45416808128,47893.16714167595,43319.45416808128,4565.002544879913,3.720834970474243,0.0 -94400,1.4933511,2.1492064,,,,,,,,,,,,,, -94500,1.5341913,2.1839335,,,,,,,,,,,,,, -94600,1.1944185,4.691222,,,,,,,,,,,,,, -94700,1.427351,2.2807214,,,,,,,,,,,,,, -94800,1.4023787,3.2606568,,,,,,,,,,,,,, -94900,1.3542473,4.5953617,,,,,,,,,,,,,, -95000,1.5406929,4.6835427,,,,,,,,,,,,,, -95100,1.4101744,2.386497,,,,,,,,,,,,,, -95200,1.4649634,2.119925,,,,,,,,,,,,,, -95300,1.3798239,2.1999745,,,,,,,,,,,,,, -95303,,,0.712109386920929,1.1649584770202637,0.6550599932670593,1.4420679807662964,50000.0,0.5348000526428223,2.1036524772644043,10000.0,43739.83526778221,48360.269610881805,43739.83526778221,4611.628629922867,3.7676095962524414,0.0 -95400,1.3544726,3.1661782,,,,,,,,,,,,,, -95500,1.5162967,2.32759,,,,,,,,,,,,,, -95600,1.4716103,2.5420017,,,,,,,,,,,,,, -95700,1.3253068,2.4607038,,,,,,,,,,,,,, -95800,1.5557926,2.3163013,,,,,,,,,,,,,, -95900,1.2256184,4.775695,,,,,,,,,,,,,, -96000,1.511374,2.1887054,,,,,,,,,,,,,, -96100,1.3129745,4.7683263,,,,,,,,,,,,,, -96200,1.4768713,3.8397443,,,,,,,,,,,,,, -96220,,,0.7240039110183716,1.1309791803359983,0.6502199769020081,1.4678243398666382,50000.0,0.5303000211715698,2.125929832458496,10000.0,44159.95757508278,48825.42972564697,44159.95757508278,4656.578102588654,3.80673885345459,0.0 -96300,1.4196594,2.243288,,,,,,,,,,,,,, -96400,1.3549093,2.7855806,,,,,,,,,,,,,, -96500,1.2913387,4.539915,,,,,,,,,,,,,, -96600,1.5635601,2.2927713,,,,,,,,,,,,,, -96700,1.2045444,4.798977,,,,,,,,,,,,,, -96800,1.6109453,2.3740292,,,,,,,,,,,,,, -96900,1.4890227,2.1472988,,,,,,,,,,,,,, -97000,1.8595061,2.1064835,,,,,,,,,,,,,, -97100,1.2076521,3.331306,,,,,,,,,,,,,, -97138,,,0.7091991901397705,1.208737015724182,0.6511200070381165,1.454700946807861,50000.0,0.5326000452041626,2.099363327026367,10000.0,44579.869277477264,49291.24484491348,44579.869277477264,4702.385105133057,3.8552112579345694,0.0 -97200,1.3773527,2.4032383,,,,,,,,,,,,,, -97300,1.6207649,2.0350044,,,,,,,,,,,,,, -97400,1.5230402,2.1063352,,,,,,,,,,,,,, -97500,1.3006507,2.7706635,,,,,,,,,,,,,, -97600,1.4821639,2.2642312,,,,,,,,,,,,,, -97700,1.1927791,3.859953,,,,,,,,,,,,,, -97800,1.3870001,3.7912967,,,,,,,,,,,,,, -97900,1.4002703,3.7601287,,,,,,,,,,,,,, -98000,1.6477492,4.7822833,,,,,,,,,,,,,, -98053,,,0.714550793170929,1.1435582637786863,0.658519983291626,1.4108588695526123,50000.0,0.5384000539779663,2.0602378845214844,10000.0,45000.0894985199,49757.63449978829,45000.0894985199,4748.463057041168,3.8989665508270255,0.0 -98100,1.4686301,2.251827,,,,,,,,,,,,,, -98200,1.3107722,2.8003042,,,,,,,,,,,,,, -98300,1.2571921,3.1588755,,,,,,,,,,,,,, -98400,1.4068321,2.1713135,,,,,,,,,,,,,, -98500,1.3235167,3.7655241,,,,,,,,,,,,,, -98600,1.4942304,2.0648034,,,,,,,,,,,,,, -98700,1.5932333,2.2435443,,,,,,,,,,,,,, -98800,1.2388484,4.4241214,,,,,,,,,,,,,, -98900,1.4981966,2.260273,,,,,,,,,,,,,, -98971,,,0.7219140529632568,1.1263810396194458,0.6539799571037292,1.4435505867004397,50000.0,0.5272000432014465,2.114023447036743,10000.0,45420.03511428833,50219.339812517166,45420.03511428833,4790.124536275864,3.9481842517852783,0.0 -99000,1.3917556,2.970546,,,,,,,,,,,,,, -99100,1.2907141,3.1135292,,,,,,,,,,,,,, -99200,1.3154317,3.296824,,,,,,,,,,,,,, -99300,1.1923199,3.377852,,,,,,,,,,,,,, -99400,1.3245168,4.197578,,,,,,,,,,,,,, -99500,1.8107206,2.3408694,,,,,,,,,,,,,, -99600,1.5208777,2.2174883,,,,,,,,,,,,,, -99700,1.387312,3.292375,,,,,,,,,,,,,, -99800,1.3140209,4.764408,,,,,,,,,,,,,, -99888,,,0.7099218368530273,1.1682158708572388,0.6560999751091003,1.4241340160369873,50000.0,0.5340999960899353,2.083680391311645,10000.0,45840.2887210846,50686.42065501213,45840.2887210846,4836.862557888031,3.9894981384277344,0.0 -99900,1.3309408,2.1532671,,,,,,,,,,,,,, -100000,1.677938,2.5097938,,,,,,,,,,,,,, -100100,1.7253994,2.2770565,,,,,,,,,,,,,, -100200,1.3884352,3.8066754,,,,,,,,,,,,,, -100300,1.2627742,4.494467,,,,,,,,,,,,,, -100400,1.4920415,2.422891,,,,,,,,,,,,,, -100500,1.6448687,2.308283,,,,,,,,,,,,,, -100600,1.2733612,4.2371397,,,,,,,,,,,,,, -100700,1.3232526,2.5206223,,,,,,,,,,,,,, -100800,1.5759815,2.3081176,,,,,,,,,,,,,, -100807,,,0.7173827886581421,1.1348272562026978,0.6578800082206726,1.4096122980117798,50000.0,0.534000039100647,2.062699556350708,10000.0,46260.74950814247,51153.77632880211,46260.74950814247,4883.661042928696,4.037500381469727,0.0 -100900,1.4368259,2.1799715,,,,,,,,,,,,,, -101000,1.4414649,4.788266,,,,,,,,,,,,,, -101100,1.2715412,3.2433398,,,,,,,,,,,,,, -101200,1.6310971,2.2290835,,,,,,,,,,,,,, -101300,1.451393,2.5382113,,,,,,,,,,,,,, -101400,1.7566973,2.2673833,,,,,,,,,,,,,, -101500,1.4402354,2.5077548,,,,,,,,,,,,,, -101600,1.4286805,2.1791952,,,,,,,,,,,,,, -101700,1.4672954,2.117537,,,,,,,,,,,,,, -101725,,,0.7226757407188416,1.1128582954406738,0.6645799875259399,1.3956934213638306,50000.0,0.5418000221252441,2.032041311264038,10000.0,46681.04964232445,51622.92252993584,46681.04964232445,4932.40603351593,4.089587211608887,0.0 -101800,1.289584,4.1205955,,,,,,,,,,,,,, -101900,1.6610192,2.2201173,,,,,,,,,,,,,, -102000,1.340571,4.6563287,,,,,,,,,,,,,, -102100,1.3312036,4.7140694,,,,,,,,,,,,,, -102200,1.5399519,2.219591,,,,,,,,,,,,,, -102300,1.7575786,2.348929,,,,,,,,,,,,,, -102400,1.4746884,4.7307014,,,,,,,,,,,,,, -102500,1.3396451,4.2669244,,,,,,,,,,,,,, -102600,1.7403752,2.1554163,,,,,,,,,,,,,, -102642,,,0.7225781083106995,1.1422998905181885,0.6594600081443787,1.4308340549468994,50000.0,0.5415000319480896,2.086282730102539,10000.0,47101.4250562191,52086.83007669449,47101.4250562191,4975.847240924835,4.131916999816895,0.0 -102700,1.2660649,4.394008,,,,,,,,,,,,,, -102800,1.6032126,2.1409872,,,,,,,,,,,,,, -102900,1.2098948,4.3807096,,,,,,,,,,,,,, -103000,1.4831454,2.3554251,,,,,,,,,,,,,, -103100,1.3465291,2.460049,,,,,,,,,,,,,, -103200,1.3913796,4.696104,,,,,,,,,,,,,, -103300,1.3306249,2.995739,,,,,,,,,,,,,, -103400,1.3455245,2.087598,,,,,,,,,,,,,, -103500,1.4740307,2.3895278,,,,,,,,,,,,,, -103560,,,0.7223241925239563,1.1243679523468018,0.6632999777793884,1.3928260803222656,50000.0,0.5383000373840332,2.04221248626709,10000.0,47521.5824341774,52553.43996691704,47521.5824341774,5022.208994150162,4.173872470855713,0.0 -103600,1.2384917,3.2693186,,,,,,,,,,,,,, -103700,1.2995746,3.211738,,,,,,,,,,,,,, -103800,1.746349,2.1621523,,,,,,,,,,,,,, -103900,1.5081072,2.2066212,,,,,,,,,,,,,, -104000,1.5506325,2.2004402,,,,,,,,,,,,,, -104100,1.3320383,4.7313747,,,,,,,,,,,,,, -104200,1.5115651,2.068396,,,,,,,,,,,,,, -104300,1.6845808,2.2171643,,,,,,,,,,,,,, -104400,1.3150645,2.7820504,,,,,,,,,,,,,, -104475,,,0.73095703125,1.0725715160369873,0.6665999889373779,1.363909125328064,50000.0,0.5473000407218933,2.005379199981689,10000.0,47941.7282936573,53019.200771570206,47941.7282936573,5067.726749658585,4.222776651382446,0.0 -104500,1.4604149,2.976816,,,,,,,,,,,,,, -104600,1.5847567,2.1173134,,,,,,,,,,,,,, -104700,1.4385934,2.154326,,,,,,,,,,,,,, -104800,1.3225769,3.9771566,,,,,,,,,,,,,, -104900,1.5498257,2.1876307,,,,,,,,,,,,,, -105000,1.6931112,2.1930532,,,,,,,,,,,,,, -105100,1.4286379,3.4992697,,,,,,,,,,,,,, -105200,1.2075775,3.1652744,,,,,,,,,,,,,, -105300,1.5490994,2.2720203,,,,,,,,,,,,,, -105393,,,0.7487890720367432,1.0079954862594604,0.6653000116348267,1.3744615316390991,50000.0,0.5384000539779663,2.039724588394165,10000.0,48361.7201230526,53483.97005271912,48361.7201230526,5112.407633304596,4.270694017410278,0.0 -105400,1.533706,2.370441,,,,,,,,,,,,,, -105500,1.4435685,2.0495486,,,,,,,,,,,,,, -105600,1.3779638,4.338624,,,,,,,,,,,,,, -105700,1.3041004,4.529716,,,,,,,,,,,,,, -105800,1.4196596,3.1711302,,,,,,,,,,,,,, -105900,1.4185083,3.3833776,,,,,,,,,,,,,, -106000,1.6302145,2.2354686,,,,,,,,,,,,,, -106100,1.3988084,3.0732672,,,,,,,,,,,,,, -106200,1.5410932,2.1053915,,,,,,,,,,,,,, -106300,1.7305579,2.0494463,,,,,,,,,,,,,, -106309,,,0.7193945050239563,1.12610924243927,0.6675999760627747,1.3651163578033447,50000.0,0.5458000302314758,2.01989221572876,10000.0,48782.00963258743,53949.74782347679,48782.00963258743,5157.802113294601,4.31640887260437,0.0 -106400,1.5254114,2.165327,,,,,,,,,,,,,, -106500,1.6870493,2.2448258,,,,,,,,,,,,,, -106600,1.6518674,2.0672135,,,,,,,,,,,,,, -106700,1.3425953,4.425268,,,,,,,,,,,,,, -106800,1.4630075,2.313169,,,,,,,,,,,,,, -106900,1.6289074,2.0606368,,,,,,,,,,,,,, -107000,1.6233371,2.2017612,,,,,,,,,,,,,, -107100,1.6610045,2.2395635,,,,,,,,,,,,,, -107200,1.4398547,4.394596,,,,,,,,,,,,,, -107226,,,0.7369921803474426,1.0469598770141602,0.6708799600601196,1.35093891620636,50000.0,0.5460000038146973,2.00260591506958,10000.0,49202.203140735626,54414.82877445221,49202.203140735626,5202.594695091248,4.361671686172485,0.0 -107300,1.4658968,2.2187865,,,,,,,,,,,,,, -107400,1.3420005,3.4178643,,,,,,,,,,,,,, -107500,1.4218833,4.7553897,,,,,,,,,,,,,, -107600,1.4336598,3.3577654,,,,,,,,,,,,,, -107700,1.4251082,2.7195728,,,,,,,,,,,,,, -107800,1.8049479,2.0500765,,,,,,,,,,,,,, -107900,1.5724833,2.2591662,,,,,,,,,,,,,, -108000,1.4889096,2.3489664,,,,,,,,,,,,,, -108100,1.6952082,2.0972767,,,,,,,,,,,,,, -108140,,,0.7404687404632568,1.0386266708374023,0.6669600009918213,1.367197036743164,50000.0,0.5439000129699707,2.017230749130249,10000.0,49622.17845416069,54879.68249583244,49622.17845416069,5247.3822016716,4.404340744018555,0.0 -108200,1.5263925,3.4337294,,,,,,,,,,,,,, -108300,1.6449958,2.3012068,,,,,,,,,,,,,, -108400,1.5652543,4.597021,,,,,,,,,,,,,, -108500,1.3739102,4.659663,,,,,,,,,,,,,, -108600,1.4566169,4.1223063,,,,,,,,,,,,,, -108700,1.3270011,4.4431386,,,,,,,,,,,,,, -108800,1.446904,4.127452,,,,,,,,,,,,,, -108900,1.4094216,3.8520393,,,,,,,,,,,,,, -109000,1.3202175,3.8676612,,,,,,,,,,,,,, -109057,,,0.7318944931030273,1.066878318786621,0.674839973449707,1.330930471420288,50000.0,0.5503000020980835,1.989863038063049,10000.0,50042.48392677307,55349.32324099541,50042.48392677307,5296.62574672699,4.448246955871582,0.0 -109100,1.741214,2.1297245,,,,,,,,,,,,,, -109200,1.9570891,2.103009,,,,,,,,,,,,,, -109300,1.6797842,1.9361285,,,,,,,,,,,,,, -109400,1.6138023,2.799802,,,,,,,,,,,,,, -109500,1.6232079,2.1221828,,,,,,,,,,,,,, -109600,1.5642128,4.5982876,,,,,,,,,,,,,, -109700,1.4599519,4.64149,,,,,,,,,,,,,, -109800,1.4144756,4.5618963,,,,,,,,,,,,,, -109900,1.3094164,3.4698389,,,,,,,,,,,,,, -109975,,,0.7314843535423279,1.0798393487930298,0.6667199730873108,1.3757458925247192,50000.0,0.5470000505447388,2.032017946243286,10000.0,50462.67532229424,55816.034700632095,50462.67532229424,5343.052444219589,4.493457078933716,0.0 -110000,1.3062214,4.606292,,,,,,,,,,,,,, -110100,1.5629324,2.4333467,,,,,,,,,,,,,, -110200,1.8937806,2.110199,,,,,,,,,,,,,, -110300,1.4029679,4.1340055,,,,,,,,,,,,,, -110400,1.6076833,3.1093209,,,,,,,,,,,,,, -110500,1.6269437,2.1589775,,,,,,,,,,,,,, -110600,1.3800062,2.480411,,,,,,,,,,,,,, -110700,1.6696917,2.6608627,,,,,,,,,,,,,, -110800,1.6172532,2.2064402,,,,,,,,,,,,,, -110892,,,0.7429296970367432,1.0271873474121094,0.6739000082015991,1.34721040725708,50000.0,0.5520000457763672,1.9919664859771729,10000.0,50882.97862505913,56279.88754367829,50882.97862505913,5386.50735449791,4.539769411087036,0.0 -110900,1.6683894,2.1610172,,,,,,,,,,,,,, -111000,1.6679326,2.0837724,,,,,,,,,,,,,, -111100,1.7072842,2.11256,,,,,,,,,,,,,, -111200,2.106134,2.508706,,,,,,,,,,,,,, -111300,1.4824661,4.710743,,,,,,,,,,,,,, -111400,1.6995896,2.0383854,,,,,,,,,,,,,, -111500,1.8964199,2.2088094,,,,,,,,,,,,,, -111600,1.3994086,3.5452397,,,,,,,,,,,,,, -111700,1.4996831,3.8622913,,,,,,,,,,,,,, -111800,1.5700053,3.6431878,,,,,,,,,,,,,, -111809,,,0.7333788871765137,1.0607457160949707,0.6744999885559082,1.334320902824402,50000.0,0.5468000173568726,1.9874333143234253,10000.0,51303.28646111488,56748.28958415985,51303.28646111488,5434.5012810230255,4.5913405418396,0.0 -111900,1.3936007,3.298483,,,,,,,,,,,,,, -112000,1.5755514,2.4454346,,,,,,,,,,,,,, -112100,1.7576777,2.0402315,,,,,,,,,,,,,, -112200,1.5592946,3.9042978,,,,,,,,,,,,,, -112300,1.7362756,2.1031947,,,,,,,,,,,,,, -112400,1.6827788,2.5172296,,,,,,,,,,,,,, -112500,1.4383789,3.0567133,,,,,,,,,,,,,, -112600,1.542577,2.0845108,,,,,,,,,,,,,, -112700,1.5562962,2.931026,,,,,,,,,,,,,, -112725,,,0.732714831829071,1.0809401273727417,0.6717199683189392,1.3633190393447876,50000.0,0.5455000400543213,2.0178678035736084,10000.0,51723.2543463707,57214.07962059975,51723.2543463707,5480.232470989227,4.634597301483154,0.0 -112800,1.5915639,2.0915024,,,,,,,,,,,,,, -112900,1.6727498,2.652281,,,,,,,,,,,,,, -113000,1.548704,4.302472,,,,,,,,,,,,,, -113100,1.6243488,2.1468205,,,,,,,,,,,,,, -113200,1.697935,2.469475,,,,,,,,,,,,,, -113300,1.6624372,2.0420694,,,,,,,,,,,,,, -113400,1.6979272,2.2869208,,,,,,,,,,,,,, -113500,1.9473131,2.046499,,,,,,,,,,,,,, -113600,1.6072645,2.0374055,,,,,,,,,,,,,, -113642,,,0.7489062547683716,0.9951205253601074,0.67603999376297,1.3217229843139648,50000.0,0.5550000071525574,1.975521326065064,10000.0,52143.389543771744,57680.77817702293,52143.389543771744,5526.704137802124,4.677294731140137,0.0 -113700,1.4309113,2.277311,,,,,,,,,,,,,, -113800,1.5758021,2.3942695,,,,,,,,,,,,,, -113900,1.817001,2.1707053,,,,,,,,,,,,,, -114000,1.5123004,4.2903733,,,,,,,,,,,,,, -114100,1.6386054,2.0663035,,,,,,,,,,,,,, -114200,1.5825998,3.8337443,,,,,,,,,,,,,, -114300,1.6303351,2.118526,,,,,,,,,,,,,, -114400,1.672858,4.402605,,,,,,,,,,,,,, -114500,1.3854545,2.7820883,,,,,,,,,,,,,, -114559,,,0.745898425579071,1.0103362798690796,0.6803799867630005,1.3071465492248535,50000.0,0.5550000071525574,1.9714163541793823,10000.0,52563.71753501892,58146.98582029343,52563.71753501892,5572.480983495712,4.731476306915283,0.0 -114600,1.6318614,2.049687,,,,,,,,,,,,,, -114700,1.5866708,2.9375966,,,,,,,,,,,,,, -114800,1.5821579,3.9704947,,,,,,,,,,,,,, -114900,1.5991071,3.0703351,,,,,,,,,,,,,, -115000,1.6301112,2.2584295,,,,,,,,,,,,,, -115100,1.5670974,1.9602954,,,,,,,,,,,,,, -115200,1.8605736,2.0613456,,,,,,,,,,,,,, -115300,1.591705,2.1478922,,,,,,,,,,,,,, -115400,1.5843168,2.136228,,,,,,,,,,,,,, -115477,,,0.7442968487739563,1.0438237190246582,0.6782799959182739,1.3390940427780151,50000.0,0.5524000525474548,1.993807315826416,10000.0,52983.67902421951,58612.56900882721,52983.67902421951,5618.007117033005,4.778652191162109,0.0 -115500,1.5222895,2.4876976,,,,,,,,,,,,,, -115600,1.6619468,2.0237935,,,,,,,,,,,,,, -115700,1.6069676,1.9922383,,,,,,,,,,,,,, -115800,1.4675273,3.9946043,,,,,,,,,,,,,, -115900,1.595443,1.9932959,,,,,,,,,,,,,, -116000,1.5109105,3.229848,,,,,,,,,,,,,, -116100,1.8545341,2.1700613,,,,,,,,,,,,,, -116200,1.4903977,3.41977,,,,,,,,,,,,,, -116300,1.7320554,2.0698857,,,,,,,,,,,,,, -116394,,,0.7509765625,1.0075005292892456,0.6782000064849854,1.3190701007843018,50000.0,0.5547000169754028,1.96835458278656,10000.0,53403.994445085526,59079.57103824616,53403.994445085526,5664.597292423248,4.826295614242554,0.0 -116400,1.6041138,2.617467,,,,,,,,,,,,,, -116500,1.6209977,2.0075624,,,,,,,,,,,,,, -116600,1.4373957,2.1845765,,,,,,,,,,,,,, -116700,1.7065431,2.3279746,,,,,,,,,,,,,, -116800,1.5141745,3.6794038,,,,,,,,,,,,,, -116900,1.7399585,2.2390687,,,,,,,,,,,,,, -117000,1.747234,2.0159113,,,,,,,,,,,,,, -117100,1.4841708,3.7959137,,,,,,,,,,,,,, -117200,1.7039559,2.0668378,,,,,,,,,,,,,, -117300,1.8186485,2.2367816,,,,,,,,,,,,,, -117310,,,0.7648828029632568,0.9348188638687134,0.6816399693489075,1.3038502931594849,50000.0,0.5551000237464905,1.9470292329788208,10000.0,53824.31056809425,59545.86528062821,53824.31056809425,5710.4811136722565,4.872780084609985,0.0 -117400,1.7551033,2.059718,,,,,,,,,,,,,, -117500,1.9282532,1.9917167,,,,,,,,,,,,,, -117600,1.5774728,2.6268425,,,,,,,,,,,,,, -117700,1.6260068,1.952743,,,,,,,,,,,,,, -117800,1.6498829,2.0066502,,,,,,,,,,,,,, -117900,1.8120983,1.9649299,,,,,,,,,,,,,, -118000,1.7082926,2.0350366,,,,,,,,,,,,,, -118100,1.9054116,2.076577,,,,,,,,,,,,,, -118200,1.822577,2.0493927,,,,,,,,,,,,,, -118227,,,0.7464257478713989,1.029775619506836,0.6835199594497681,1.3185490369796753,50000.0,0.5621000528335571,1.951754570007324,10000.0,54244.33873510361,60013.35964846611,54244.33873510361,5757.854747772217,4.916918992996216,0.0 -118300,1.5848004,2.0221498,,,,,,,,,,,,,, -118400,1.6395421,4.528708,,,,,,,,,,,,,, -118500,1.6145221,3.6132188,,,,,,,,,,,,,, -118600,1.7874819,2.071597,,,,,,,,,,,,,, -118700,1.5789583,3.9264758,,,,,,,,,,,,,, -118800,1.7079452,1.9831417,,,,,,,,,,,,,, -118900,1.6940736,2.687819,,,,,,,,,,,,,, -119000,1.7778238,1.8998724,,,,,,,,,,,,,, -119100,1.9504229,2.0552871,,,,,,,,,,,,,, -119144,,,0.7531836032867432,0.9818204045295716,0.6846799850463867,1.296289563179016,50000.0,0.5582000017166138,1.942766547203064,10000.0,54664.30675196648,60477.11735010147,54664.30675196648,5801.55264043808,4.960393190383911,0.0 -119200,1.6920694,1.9475915,,,,,,,,,,,,,, -119300,1.7554567,2.0535336,,,,,,,,,,,,,, -119400,1.6382781,2.4368672,,,,,,,,,,,,,, -119500,1.5633543,2.4436502,,,,,,,,,,,,,, -119600,1.726533,2.5284557,,,,,,,,,,,,,, -119700,1.708142,2.1202135,,,,,,,,,,,,,, -119800,1.527415,3.3851292,,,,,,,,,,,,,, -119900,1.6040471,4.220422,,,,,,,,,,,,,, -120000,1.8977568,1.9404941,,,,,,,,,,,,,, -120060,,,0.7615820169448853,1.0147477388381958,0.6805799603462219,1.354517936706543,50000.0,0.5561000108718872,2.008143424987793,10000.0,55084.343988895416,60941.289311409,55084.343988895416,5845.588560342789,5.010600090026856,0.0 -120100,1.5413464,3.4076118,,,,,,,,,,,,,, -120200,1.618692,3.9393573,,,,,,,,,,,,,, -120300,1.4787815,3.5487943,,,,,,,,,,,,,, -120400,1.6497865,2.0552974,,,,,,,,,,,,,, -120500,1.6727964,2.0076735,,,,,,,,,,,,,, -120600,1.462339,3.288091,,,,,,,,,,,,,, -120700,1.7031149,2.089211,,,,,,,,,,,,,, -120800,1.8872913,1.9172399,,,,,,,,,,,,,, -120900,1.7801785,2.0196981,,,,,,,,,,,,,, -120978,,,0.7485546469688416,1.011296033859253,0.6860799789428711,1.2949484586715698,50000.0,0.5636000037193298,1.9418689012527464,10000.0,55504.44954395294,61407.6800467968,55504.44954395294,5891.7737855911255,5.061231851577759,0.0 -121000,1.7340136,1.8921528,,,,,,,,,,,,,, -121100,1.786664,2.0244217,,,,,,,,,,,,,, -121200,1.7433358,1.9003171,,,,,,,,,,,,,, -121300,1.5621829,3.5200558,,,,,,,,,,,,,, -121400,1.6271237,3.6666896,,,,,,,,,,,,,, -121500,1.8347199,2.1035051,,,,,,,,,,,,,, -121600,1.6312126,3.773524,,,,,,,,,,,,,, -121700,1.5626087,2.1516135,,,,,,,,,,,,,, -121800,1.8594273,1.8679882,,,,,,,,,,,,,, -121895,,,0.7560937404632568,0.9640070796012878,0.685259997844696,1.2753068208694458,50000.0,0.5637000203132629,1.916074275970459,10000.0,55924.67421007157,61872.32814979553,55924.67421007157,5936.102333545685,5.1081626415252686,0.0 -121900,1.5535706,3.6961856,,,,,,,,,,,,,, -122000,1.6919374,1.8803065,,,,,,,,,,,,,, -122100,1.6663686,2.0632308,,,,,,,,,,,,,, -122200,1.7976708,4.5034676,,,,,,,,,,,,,, -122300,1.8187364,2.6225927,,,,,,,,,,,,,, -122400,1.7027305,4.5059624,,,,,,,,,,,,,, -122500,1.79291,1.9005048,,,,,,,,,,,,,, -122600,1.596266,4.514431,,,,,,,,,,,,,, -122700,1.5971566,2.9158046,,,,,,,,,,,,,, -122800,1.558317,4.5489244,,,,,,,,,,,,,, -122812,,,0.7655858993530273,0.926956832408905,0.6911399960517883,1.2588539123535156,50000.0,0.5688000321388245,1.9093610048294067,10000.0,56344.74541926384,62336.69419193268,56344.74541926384,5980.293684959412,5.162423133850098,0.0 -122900,1.7960874,2.3836777,,,,,,,,,,,,,, -123000,1.9977285,1.9149742,,,,,,,,,,,,,, -123100,2.1360786,1.9220383,,,,,,,,,,,,,, -123200,1.9789628,4.3565054,,,,,,,,,,,,,, -123300,1.9372238,2.0882971,,,,,,,,,,,,,, -123400,1.594541,3.0128522,,,,,,,,,,,,,, -123500,1.6877302,3.5703685,,,,,,,,,,,,,, -123600,1.7240748,1.937362,,,,,,,,,,,,,, -123700,1.5081259,2.5834112,,,,,,,,,,,,,, -123732,,,0.7517968416213989,1.0083987712860107,0.6892200112342834,1.288521409034729,50000.0,0.5617000460624695,1.9466270208358765,10000.0,56765.02931785584,62801.70062446594,56765.02931785584,6024.918617486954,5.211941957473755,0.0 -123800,1.7578878,4.5162354,,,,,,,,,,,,,, -123900,1.836155,2.051063,,,,,,,,,,,,,, -124000,1.6592332,3.989605,,,,,,,,,,,,,, -124100,1.762499,3.8527174,,,,,,,,,,,,,, -124200,1.7971447,2.1713455,,,,,,,,,,,,,, -124300,1.7212353,4.4109297,,,,,,,,,,,,,, -124400,1.8185353,2.022836,,,,,,,,,,,,,, -124500,1.8435723,1.9014411,,,,,,,,,,,,,, -124600,1.6629243,2.7315676,,,,,,,,,,,,,, -124648,,,0.7625976204872131,0.9537436366081238,0.6960600018501282,1.2512370347976685,50000.0,0.5719000101089478,1.9025684595108032,10000.0,57185.08319354057,63268.40723109245,57185.08319354057,6071.47722363472,5.257215738296509,0.0 -124700,1.766927,2.0459673,,,,,,,,,,,,,, -124800,1.9320203,1.8838493,,,,,,,,,,,,,, -124900,2.0691428,1.8794018,,,,,,,,,,,,,, -125000,1.6328677,2.9865556,,,,,,,,,,,,,, -125100,1.9160057,1.9009823,,,,,,,,,,,,,, -125200,1.8166999,1.9449253,,,,,,,,,,,,,, -125300,2.0914881,1.9156922,,,,,,,,,,,,,, -125400,1.8337678,2.3792374,,,,,,,,,,,,,, -125500,1.6640192,4.154638,,,,,,,,,,,,,, -125566,,,0.7717187404632568,0.9368151426315308,0.6967399716377258,1.2695817947387695,50000.0,0.5688000321388245,1.9237370491027832,10000.0,57605.48809599877,63732.91115617752,57605.48809599877,6115.479986190796,5.305594205856323,0.0 -125600,1.7914798,1.8355727,,,,,,,,,,,,,, -125700,1.806911,3.4402518,,,,,,,,,,,,,, -125800,1.8063564,2.1388893,,,,,,,,,,,,,, -125900,1.784458,2.9641187,,,,,,,,,,,,,, -126000,1.6800838,3.602517,,,,,,,,,,,,,, -126100,1.9779018,1.9870574,,,,,,,,,,,,,, -126200,1.7950609,1.9649363,,,,,,,,,,,,,, -126300,2.2256522,2.020818,,,,,,,,,,,,,, -126400,1.664104,3.9785752,,,,,,,,,,,,,, -126484,,,0.7656054496765137,0.9447650909423828,0.6959199905395508,1.2564334869384766,50000.0,0.5696000456809998,1.9149582386016848,10000.0,58025.8363161087,64197.31534385681,58025.8363161087,6159.436917304993,5.356076002120972,0.0 -126500,1.6911987,2.5279791,,,,,,,,,,,,,, -126600,1.708844,1.8666835,,,,,,,,,,,,,, -126700,1.8637648,3.0123453,,,,,,,,,,,,,, -126800,1.801968,1.9317613,,,,,,,,,,,,,, -126900,1.885333,2.0158107,,,,,,,,,,,,,, -127000,1.7793472,1.9008791,,,,,,,,,,,,,, -127100,1.7861879,2.2737215,,,,,,,,,,,,,, -127200,2.1467962,1.9355335,,,,,,,,,,,,,, -127300,1.7227715,3.3156629,,,,,,,,,,,,,, -127400,1.9175937,1.8999338,,,,,,,,,,,,,, -127403,,,0.7616015672683716,0.9457040429115297,0.696399986743927,1.245165467262268,50000.0,0.5748000144958496,1.8830746412277224,10000.0,58445.861562252045,64663.24557733536,58445.861562252045,6205.239112854004,5.409669399261475,0.0 -127500,1.7171274,2.4539056,,,,,,,,,,,,,, -127600,1.9078038,2.0728602,,,,,,,,,,,,,, -127700,2.0848386,1.9454405,,,,,,,,,,,,,, -127800,1.5785801,3.7313485,,,,,,,,,,,,,, -127900,1.7923033,1.9211835,,,,,,,,,,,,,, -128000,1.7063391,2.0984433,,,,,,,,,,,,,, -128100,1.7199174,2.1205828,,,,,,,,,,,,,, -128200,1.7932452,2.2464638,,,,,,,,,,,,,, -128300,1.9827387,1.9602815,,,,,,,,,,,,,, -128313,,,0.7695116996765137,0.9001871347427368,0.6969000101089478,1.226536512374878,50000.0,0.5778000354766846,1.8673149347305296,10000.0,58865.82812094688,65126.63528132439,58865.82812094688,6248.568606853485,5.454848766326904,0.0 -128400,1.8155885,1.7893236,,,,,,,,,,,,,, -128500,1.7940046,2.1228008,,,,,,,,,,,,,, -128600,1.8359951,1.9773577,,,,,,,,,,,,,, -128700,2.0110893,2.1406474,,,,,,,,,,,,,, -128800,2.2385123,1.9769571,,,,,,,,,,,,,, -128900,1.8436297,1.9379169,,,,,,,,,,,,,, -129000,1.8076401,4.426574,,,,,,,,,,,,,, -129100,1.8237464,2.9035568,,,,,,,,,,,,,, -129200,1.7202351,4.023366,,,,,,,,,,,,,, -129230,,,0.7834765315055847,0.878718376159668,0.7016800045967102,1.2358816862106323,50000.0,0.5782999992370605,1.8833563327789309,10000.0,59285.90663409233,65588.18467664719,59285.90663409233,6289.945489883423,5.501060485839844,0.0 -129300,1.736627,2.2281046,,,,,,,,,,,,,, -129400,1.8639473,2.7698617,,,,,,,,,,,,,, -129500,1.9860454,1.9199216,,,,,,,,,,,,,, -129600,2.0968242,1.8437825,,,,,,,,,,,,,, -129700,1.7026125,2.2541819,,,,,,,,,,,,,, -129800,2.0774403,4.214578,,,,,,,,,,,,,, -129900,1.946953,4.1040044,,,,,,,,,,,,,, -130000,1.9541023,3.479782,,,,,,,,,,,,,, -130100,1.8090886,4.278258,,,,,,,,,,,,,, -130148,,,0.7674999833106995,0.9318538308143616,0.7002800107002258,1.227251648902893,50000.0,0.576200008392334,1.8715360164642327,10000.0,59706.12528705597,66053.80458760262,59706.12528705597,6335.249958515167,5.549408435821533,0.0 -130200,1.8610848,1.9298178,,,,,,,,,,,,,, -130300,1.8308194,2.324808,,,,,,,,,,,,,, -130400,1.7282407,2.514221,,,,,,,,,,,,,, -130500,1.9953773,1.9169891,,,,,,,,,,,,,, -130600,1.7846699,2.954384,,,,,,,,,,,,,, -130700,2.0344791,1.9417193,,,,,,,,,,,,,, -130800,1.7878828,2.7083955,,,,,,,,,,,,,, -130900,1.8815789,1.8463918,,,,,,,,,,,,,, -131000,1.9030905,3.2810457,,,,,,,,,,,,,, -131067,,,0.7759179472923279,0.9067143797874452,0.7016400098800659,1.23651921749115,50000.0,0.576200008392334,1.8704854249954224,10000.0,60126.05570912361,66519.6564245224,60126.05570912361,6381.077441215515,5.593918085098267,0.0 -131100,1.9540573,1.8226367,,,,,,,,,,,,,, -131200,1.779493,2.3823223,,,,,,,,,,,,,, -131300,1.6092638,3.020935,,,,,,,,,,,,,, -131400,2.074635,1.960342,,,,,,,,,,,,,, -131500,1.9668555,2.233986,,,,,,,,,,,,,, -131600,2.0466251,2.1844833,,,,,,,,,,,,,, -131700,2.2153416,1.9331938,,,,,,,,,,,,,, -131800,1.9336697,1.9650443,,,,,,,,,,,,,, -131900,1.7592586,3.2289455,,,,,,,,,,,,,, -131984,,,0.7800976634025574,0.899075448513031,0.7061600089073181,1.231971263885498,50000.0,0.584600031375885,1.861109495162964,10000.0,60546.168536663055,66985.56582260132,60546.168536663055,6426.768949747086,5.6505677700042725,0.0 -132000,2.0381496,1.7260435,,,,,,,,,,,,,, -132100,1.9096642,1.666004,,,,,,,,,,,,,, -132200,1.8359059,2.3526933,,,,,,,,,,,,,, -132300,1.7356323,3.219598,,,,,,,,,,,,,, -132400,2.1923423,1.8717724,,,,,,,,,,,,,, -132500,1.9600155,2.0921307,,,,,,,,,,,,,, -132600,2.179864,1.834744,,,,,,,,,,,,,, -132700,1.9599601,4.345562,,,,,,,,,,,,,, -132800,1.8500973,1.8251415,,,,,,,,,,,,,, -132900,2.0915341,1.8571554,,,,,,,,,,,,,, -132902,,,0.7770116925239563,0.8806419372558594,0.7057600021362305,1.1938964128494265,50000.0,0.5800999999046326,1.8424324989318848,10000.0,60966.1720366478,67448.69650387764,60966.1720366478,6469.7969336509705,5.7013936042785645,0.0 -133000,1.9386799,1.8173949,,,,,,,,,,,,,, -133100,1.650016,2.6628447,,,,,,,,,,,,,, -133200,2.0972853,2.8445442,,,,,,,,,,,,,, -133300,2.591588,4.2722573,,,,,,,,,,,,,, -133400,2.1460948,1.8714061,,,,,,,,,,,,,, -133500,2.1908798,1.839875,,,,,,,,,,,,,, -133600,2.4809892,1.7901833,,,,,,,,,,,,,, -133700,1.7865332,2.5883522,,,,,,,,,,,,,, -133800,2.2655897,1.8299949,,,,,,,,,,,,,, -133822,,,0.7735351324081421,0.9415649771690368,0.7021200060844421,1.2566030025482178,50000.0,0.5730000138282776,1.9204665422439573,10000.0,61386.14681506157,67912.33575248718,61386.14681506157,6513.367963075638,5.74661111831665,0.0 -133900,2.1413596,2.36382,,,,,,,,,,,,,, -134000,2.1162431,1.9854345,,,,,,,,,,,,,, -134100,1.7794001,2.3574047,,,,,,,,,,,,,, -134200,1.9036418,2.6495779,,,,,,,,,,,,,, -134300,2.234661,1.6911825,,,,,,,,,,,,,, -134400,1.8131402,4.1888275,,,,,,,,,,,,,, -134500,1.9413418,2.2851555,,,,,,,,,,,,,, -134600,2.0942192,1.939968,,,,,,,,,,,,,, -134700,1.849656,3.0067806,,,,,,,,,,,,,, -134742,,,0.7837499976158142,0.8793671727180481,0.7064200043678284,1.2217317819595337,50000.0,0.5799000263214111,1.858611226081848,10000.0,61806.27965426445,68377.54789853096,61806.27965426445,6558.35079741478,5.793826341629028,0.0 -134800,1.9425957,1.9764466,,,,,,,,,,,,,, -134900,2.1016936,1.7847167,,,,,,,,,,,,,, -135000,2.020882,3.208064,,,,,,,,,,,,,, -135100,1.9177374,1.9412144,,,,,,,,,,,,,, -135200,1.8396252,3.956991,,,,,,,,,,,,,, -135300,2.0486786,1.8784517,,,,,,,,,,,,,, -135400,1.7660737,3.1635652,,,,,,,,,,,,,, -135500,1.9266826,4.1659045,,,,,,,,,,,,,, -135600,2.067854,4.1247416,,,,,,,,,,,,,, -135659,,,0.7777929306030273,0.8910472393035889,0.708579957485199,1.2002291679382324,50000.0,0.5830000042915344,1.8402303457260127,10000.0,62226.222628593445,68841.37681627274,62226.222628593445,6602.141601085663,5.841196537017822,0.0 -135700,2.0771096,1.7551267,,,,,,,,,,,,,, -135800,2.2739167,1.8294904,,,,,,,,,,,,,, -135900,2.270748,1.7458165,,,,,,,,,,,,,, -136000,1.9349475,2.042045,,,,,,,,,,,,,, -136100,2.0319026,1.9256277,,,,,,,,,,,,,, -136200,2.0086691,1.8278642,,,,,,,,,,,,,, -136300,1.9264144,2.4078314,,,,,,,,,,,,,, -136400,1.7594656,2.5707533,,,,,,,,,,,,,, -136500,2.1897855,1.8239337,,,,,,,,,,,,,, -136576,,,0.7836328148841858,0.8509073257446289,0.7106199860572815,1.1824514865875244,50000.0,0.5878000259399414,1.8325586318969729,10000.0,62646.28657245636,69306.7888610363,62646.28657245636,6647.391499996185,5.891319513320923,0.0 -136600,2.159587,1.7055846,,,,,,,,,,,,,, -136700,2.0871234,1.8198608,,,,,,,,,,,,,, -136800,1.8818924,2.4132857,,,,,,,,,,,,,, -136900,2.0755105,1.7721331,,,,,,,,,,,,,, -137000,2.1786358,1.7647119,,,,,,,,,,,,,, -137100,2.0045924,2.1095128,,,,,,,,,,,,,, -137200,2.1083808,1.9951947,,,,,,,,,,,,,, -137300,2.163423,4.207585,,,,,,,,,,,,,, -137400,1.9534222,3.3462477,,,,,,,,,,,,,, -137494,,,0.7882031202316284,0.8359939455986023,0.7127199769020081,1.1701191663742063,50000.0,0.5861999988555908,1.8159235715866089,10000.0,63066.573899030685,69768.32604432106,63066.573899030685,6688.541975975037,5.941382884979248,0.0 -137500,2.0430803,1.7129236,,,,,,,,,,,,,, -137600,1.7934223,2.1265228,,,,,,,,,,,,,, -137700,2.0373166,3.2407274,,,,,,,,,,,,,, -137800,1.8839242,2.2738943,,,,,,,,,,,,,, -137900,1.9826391,2.2579124,,,,,,,,,,,,,, -138000,2.0985434,1.9657804,,,,,,,,,,,,,, -138100,2.3601513,2.8733115,,,,,,,,,,,,,, -138200,2.1402805,1.764719,,,,,,,,,,,,,, -138300,2.546577,1.7899704,,,,,,,,,,,,,, -138400,2.0471168,2.3417985,,,,,,,,,,,,,, -138412,,,0.7888476252555847,0.8295444250106812,0.7119199633598328,1.178058624267578,50000.0,0.5888000130653381,1.816422939300537,10000.0,63486.502836704254,70235.64741063118,63486.502836704254,6735.836848020554,5.990756034851074,0.0 -138500,1.8496189,2.8136811,,,,,,,,,,,,,, -138600,2.0298202,1.6740696,,,,,,,,,,,,,, -138700,2.1573193,4.0740747,,,,,,,,,,,,,, -138800,2.1948454,1.8339237,,,,,,,,,,,,,, -138900,2.009768,4.371686,,,,,,,,,,,,,, -139000,2.1372113,4.311814,,,,,,,,,,,,,, -139100,2.117327,1.7112272,,,,,,,,,,,,,, -139200,2.0846593,1.7249795,,,,,,,,,,,,,, -139300,2.4984467,1.8209556,,,,,,,,,,,,,, -139332,,,0.7895312309265137,0.8278987407684326,0.7135799527168274,1.1600850820541382,50000.0,0.5918000340461731,1.8050135374069207,10000.0,63906.77509832382,70701.96142339706,63906.77509832382,6781.783026695252,6.037896156311035,0.0 -139400,2.0915928,1.8406007,,,,,,,,,,,,,, -139500,2.6093087,1.6469789,,,,,,,,,,,,,, -139600,2.3283315,1.7610549,,,,,,,,,,,,,, -139700,3.1348612,1.7521437,,,,,,,,,,,,,, -139800,1.9688271,3.9273922,,,,,,,,,,,,,, -139900,2.2178473,3.2822192,,,,,,,,,,,,,, -140000,2.119878,1.8341316,,,,,,,,,,,,,, -140100,1.8569901,3.3299003,,,,,,,,,,,,,, -140200,2.1801891,1.9141864,,,,,,,,,,,,,, -140251,,,0.7918164134025574,0.8284945487976074,0.7123599648475647,1.180811047554016,50000.0,0.5911000370979309,1.8107106685638428,10000.0,64326.802830696106,71165.90361714363,64326.802830696106,6825.600820064545,6.085220098495483,0.0 -140300,2.1002362,1.9380715,,,,,,,,,,,,,, -140400,2.0004992,4.207368,,,,,,,,,,,,,, -140500,2.2788937,2.0234346,,,,,,,,,,,,,, -140600,1.9845626,3.383141,,,,,,,,,,,,,, -140700,1.9584278,4.1924877,,,,,,,,,,,,,, -140800,2.1821651,1.8872702,,,,,,,,,,,,,, -140900,2.142945,1.7994442,,,,,,,,,,,,,, -141000,2.0781424,1.7914917,,,,,,,,,,,,,, -141100,2.1005065,1.8314989,,,,,,,,,,,,,, -141169,,,0.8020898103713989,0.7641717195510864,0.714199960231781,1.146799087524414,50000.0,0.5927000045776367,1.7914958000183103,10000.0,64746.913219451904,71633.59379124641,64746.913219451904,6873.079336643219,6.138157367706299,0.0 -141200,1.8282261,2.4078937,,,,,,,,,,,,,, -141300,2.0743544,1.7383736,,,,,,,,,,,,,, -141400,4.198553,4.2643147,,,,,,,,,,,,,, -141500,2.3095942,1.8032494,,,,,,,,,,,,,, -141600,2.0797865,1.6591356,,,,,,,,,,,,,, -141700,2.000022,1.8357275,,,,,,,,,,,,,, -141800,2.2655923,1.7682521,,,,,,,,,,,,,, -141900,2.2711713,1.7657485,,,,,,,,,,,,,, -142000,2.2052467,1.7421576,,,,,,,,,,,,,, -142084,,,0.7903515696525574,0.8418607115745544,0.718459963798523,1.1597744226455688,50000.0,0.5960000157356262,1.807189583778381,10000.0,65167.130407333374,72098.80448961258,65167.130407333374,6917.972964763641,6.189510822296143,0.0 -142100,1.9392322,2.8725572,,,,,,,,,,,,,, -142200,2.0410829,4.2661,,,,,,,,,,,,,, -142300,2.1967046,2.1387248,,,,,,,,,,,,,, -142400,2.1394124,1.7509679,,,,,,,,,,,,,, -142500,2.3120518,3.7362683,,,,,,,,,,,,,, -142600,2.3677661,4.001456,,,,,,,,,,,,,, -142700,1.9825238,2.2678308,,,,,,,,,,,,,, -142800,2.206221,4.2037067,,,,,,,,,,,,,, -142900,2.164972,2.584123,,,,,,,,,,,,,, -143000,,,0.7939453125,0.8365713953971863,0.7152599692344666,1.183610916137695,50000.0,0.5969000458717346,1.820393681526184,10000.0,65587.38845300674,72562.14328551292,65587.38845300674,6960.956964015961,6.237833261489868,0.0 -143000,2.2476885,2.2278712,,,,,,,,,,,,,, -143100,2.1927516,4.1134806,,,,,,,,,,,,,, -143200,2.123268,2.9706202,,,,,,,,,,,,,, -143300,2.1621773,3.6016397,,,,,,,,,,,,,, -143400,2.0403538,1.7911506,,,,,,,,,,,,,, -143500,2.460069,1.7162583,,,,,,,,,,,,,, -143600,2.0959997,1.6666443,,,,,,,,,,,,,, -143700,2.0011277,1.6857064,,,,,,,,,,,,,, -143800,2.4267678,4.160328,,,,,,,,,,,,,, -143900,2.0261858,3.083529,,,,,,,,,,,,,, -143917,,,0.8037304282188416,0.8105840682983398,0.7167999744415283,1.1841614246368408,50000.0,0.5915000438690186,1.8317683935165403,10000.0,66007.30506968498,73030.0181658268,66007.30506968498,7008.814731359482,6.289177417755127,0.0 -144000,1.9976627,4.191102,,,,,,,,,,,,,, -144100,2.0617704,1.6896344,,,,,,,,,,,,,, -144200,1.9658415,2.7989345,,,,,,,,,,,,,, -144300,3.6384835,4.2034674,,,,,,,,,,,,,, -144400,2.0279684,2.4501595,,,,,,,,,,,,,, -144500,1.9428643,3.3280702,,,,,,,,,,,,,, -144600,2.9052703,2.4403267,,,,,,,,,,,,,, -144700,2.117588,2.8104334,,,,,,,,,,,,,, -144800,1.9985206,2.2215636,,,,,,,,,,,,,, -144834,,,0.7900585532188416,0.8303824663162231,0.7147600054740906,1.1643662452697754,50000.0,0.5910000205039978,1.8147989511489868,10000.0,66427.55252289772,73493.39942002296,66427.55252289772,7051.8511662483215,6.337826013565064,0.0 -144900,2.5043652,1.8354856,,,,,,,,,,,,,, -145000,2.1497831,2.5535462,,,,,,,,,,,,,, -145100,2.3753278,1.9603647,,,,,,,,,,,,,, -145200,2.1145122,1.7194074,,,,,,,,,,,,,, -145300,2.4902644,4.163534,,,,,,,,,,,,,, -145400,2.4614375,1.7729828,,,,,,,,,,,,,, -145500,2.182297,2.1999545,,,,,,,,,,,,,, -145600,2.504586,2.8379283,,,,,,,,,,,,,, -145700,2.3402112,1.6801379,,,,,,,,,,,,,, -145753,,,0.8025780916213989,0.7942774295806885,0.7219399809837341,1.1464852094650269,50000.0,0.600600004196167,1.7750864028930664,10000.0,66847.67327642441,73958.39906048775,66847.67327642441,7096.628978729248,6.390943288803101,0.0 -145800,2.3007915,1.8727276,,,,,,,,,,,,,, -145900,2.307043,4.156337,,,,,,,,,,,,,, -146000,2.3435736,1.7154142,,,,,,,,,,,,,, -146100,2.2639265,1.6879708,,,,,,,,,,,,,, -146200,2.185513,1.711448,,,,,,,,,,,,,, -146300,2.2734423,1.6747504,,,,,,,,,,,,,, -146400,2.064461,2.8096468,,,,,,,,,,,,,, -146500,2.5506895,4.1696134,,,,,,,,,,,,,, -146600,2.1680055,1.6939343,,,,,,,,,,,,,, -146669,,,0.8016601204872131,0.7725005149841309,0.7207199931144714,1.135236740112305,50000.0,0.6003000140190125,1.7748686075210571,10000.0,67267.95164108276,74423.24955821037,67267.95164108276,7141.1040625572205,6.439411401748657,0.0 -146700,2.6335607,1.7660941,,,,,,,,,,,,,, -146800,2.306194,4.155113,,,,,,,,,,,,,, -146900,2.0557728,3.6130886,,,,,,,,,,,,,, -147000,2.2739854,2.6639605,,,,,,,,,,,,,, -147100,2.0749724,2.5797431,,,,,,,,,,,,,, -147200,2.2200627,1.6913499,,,,,,,,,,,,,, -147300,2.1502283,3.7064793,,,,,,,,,,,,,, -147400,1.9861562,2.1203861,,,,,,,,,,,,,, -147500,3.534478,3.6960819,,,,,,,,,,,,,, -147587,,,0.7973241806030273,0.8021615147590637,0.7207799553871155,1.1381536722183228,50000.0,0.5991000533103943,1.7693490982055664,10000.0,67687.97954654694,74890.32166147232,67687.97954654694,7188.041525125504,6.497358798980713,0.0 -147600,2.2063227,3.9120593,,,,,,,,,,,,,, -147700,1.9731598,3.421628,,,,,,,,,,,,,, -147800,2.4260507,4.0682178,,,,,,,,,,,,,, -147900,2.2508287,1.6813767,,,,,,,,,,,,,, -148000,2.381452,1.7163234,,,,,,,,,,,,,, -148100,2.2293017,3.9021757,,,,,,,,,,,,,, -148200,2.37797,1.7816837,,,,,,,,,,,,,, -148300,2.4863892,1.5875816,,,,,,,,,,,,,, -148400,2.2244768,1.7389425,,,,,,,,,,,,,, -148500,2.282054,1.5701722,,,,,,,,,,,,,, -148506,,,0.8033788800239563,0.7781947255134583,0.7247799634933472,1.1242923736572266,50000.0,0.6051000356674194,1.7475770711898804,10000.0,68107.91986322403,75355.3785700798,68107.91986322403,7233.060445070267,6.546364784240723,0.0 -148600,2.1796243,3.5171785,,,,,,,,,,,,,, -148700,2.2429175,3.348194,,,,,,,,,,,,,, -148800,2.5211878,3.1776187,,,,,,,,,,,,,, -148900,2.2144694,2.8733845,,,,,,,,,,,,,, -149000,2.034016,3.8388162,,,,,,,,,,,,,, -149100,2.2455611,1.748673,,,,,,,,,,,,,, -149200,2.2214646,2.1854508,,,,,,,,,,,,,, -149300,2.5009863,1.6522679,,,,,,,,,,,,,, -149400,2.3950026,1.7351505,,,,,,,,,,,,,, -149422,,,0.8059765696525574,0.7608180046081543,0.7228999733924866,1.124770998954773,50000.0,0.6030000448226929,1.751966118812561,10000.0,68527.9065463543,75820.22951364517,68527.9065463543,7277.828405618668,6.59432315826416,0.0 -149500,2.125857,2.1590154,,,,,,,,,,,,,, -149600,2.363439,1.5430514,,,,,,,,,,,,,, -149700,2.2894526,1.7291489,,,,,,,,,,,,,, -149800,2.7017183,1.6943314,,,,,,,,,,,,,, -149900,2.6605356,1.7599568,,,,,,,,,,,,,, -150000,2.314825,3.9160604,,,,,,,,,,,,,, -150100,2.294265,1.6778874,,,,,,,,,,,,,, -150200,2.2554758,3.7811155,,,,,,,,,,,,,, -150300,2.3954036,4.023459,,,,,,,,,,,,,, -150341,,,0.818359375,0.733532190322876,0.7276999950408936,1.10873281955719,50000.0,0.6095000505447388,1.739927053451538,10000.0,68947.92167925835,76282.6093711853,68947.92167925835,7320.086533069611,6.65229082107544,0.0 -150400,2.073535,2.9611092,,,,,,,,,,,,,, -150500,2.5341363,1.6113727,,,,,,,,,,,,,, -150600,2.292881,2.396637,,,,,,,,,,,,,, -150700,2.7130952,2.0254433,,,,,,,,,,,,,, -150800,2.4346075,1.6305029,,,,,,,,,,,,,, -150900,2.461174,1.7706844,,,,,,,,,,,,,, -151000,2.114151,2.4473538,,,,,,,,,,,,,, -151100,1.9856858,2.3112469,,,,,,,,,,,,,, -151200,2.290554,1.5147141,,,,,,,,,,,,,, -151259,,,0.8030859231948853,0.7721931338310242,0.7271599769592285,1.108335256576538,50000.0,0.609000027179718,1.7397416830062866,10000.0,69367.8230752945,76748.19864630699,69367.8230752945,7365.671882867813,6.706914186477661,0.0 -151300,2.316984,1.5721678,,,,,,,,,,,,,, -151400,2.2861645,1.8011496,,,,,,,,,,,,,, -151500,2.3943608,3.7793236,,,,,,,,,,,,,, -151600,2.3320224,1.6689887,,,,,,,,,,,,,, -151700,2.295302,1.8686659,,,,,,,,,,,,,, -151800,2.4202864,3.263423,,,,,,,,,,,,,, -151900,2.7424796,3.8369498,,,,,,,,,,,,,, -152000,2.2427697,2.716628,,,,,,,,,,,,,, -152100,2.4392173,1.6462748,,,,,,,,,,,,,, -152177,,,0.8080468773841858,0.78282630443573,0.7276399731636047,1.1330397129058838,50000.0,0.6035000085830688,1.7696969509124756,10000.0,69787.99724078178,77214.63326454163,69787.99724078178,7411.831959962845,6.759262323379517,0.0 -152200,2.9774008,2.8026464,,,,,,,,,,,,,, -152300,2.4414806,3.8144372,,,,,,,,,,,,,, -152400,2.9117396,3.8220277,,,,,,,,,,,,,, -152500,2.644621,1.6799741,,,,,,,,,,,,,, -152600,2.6102197,1.8182042,,,,,,,,,,,,,, -152700,2.3423321,1.8588536,,,,,,,,,,,,,, -152800,2.238285,2.0614486,,,,,,,,,,,,,, -152900,3.609674,1.7723973,,,,,,,,,,,,,, -153000,2.3619163,1.6051759,,,,,,,,,,,,,, -153094,,,0.8205859065055847,0.7144888639450073,0.7292799949645996,1.1079001426696775,50000.0,0.6070000529289246,1.7434462308883667,10000.0,70208.0367231369,77679.76422834396,70208.0367231369,7456.8257756233215,6.808479070663452,0.0 -153100,2.5932572,2.7234066,,,,,,,,,,,,,, -153200,2.4047697,3.5549762,,,,,,,,,,,,,, -153300,2.227585,2.5331962,,,,,,,,,,,,,, -153400,2.0989408,2.9147983,,,,,,,,,,,,,, -153500,2.466275,2.743722,,,,,,,,,,,,,, -153600,2.4169693,1.7100312,,,,,,,,,,,,,, -153700,2.9167216,1.5976925,,,,,,,,,,,,,, -153800,2.689271,1.5873315,,,,,,,,,,,,,, -153900,2.2834837,1.9586464,,,,,,,,,,,,,, -154000,2.5174415,1.5893413,,,,,,,,,,,,,, -154011,,,0.8125,0.7371184229850769,0.7307400107383728,1.0911295413970947,50000.0,0.6048000454902649,1.7228014469146729,10000.0,70628.20611071587,78146.25729894638,70628.20611071587,7503.045190811157,6.8642542362213135,0.0 -154100,2.2807107,1.6463661,,,,,,,,,,,,,, -154200,3.1398885,1.690717,,,,,,,,,,,,,, -154300,2.5340972,1.7777779,,,,,,,,,,,,,, -154400,2.308639,2.1889055,,,,,,,,,,,,,, -154500,2.7468066,3.8032365,,,,,,,,,,,,,, -154600,2.4022002,1.6686282,,,,,,,,,,,,,, -154700,2.7282503,1.4892428,,,,,,,,,,,,,, -154800,2.382576,2.2497656,,,,,,,,,,,,,, -154900,2.4770837,2.019167,,,,,,,,,,,,,, -154931,,,0.8117187023162842,0.7915002107620239,0.7309199571609497,1.1422063112258911,50000.0,0.6079000234603882,1.7812516689300537,10000.0,71048.3641808033,78612.83320403099,71048.3641808033,7549.361694574356,6.916259765625,0.0 -155000,2.5976274,3.9522645,,,,,,,,,,,,,, -155100,2.4710507,1.5458817,,,,,,,,,,,,,, -155200,2.1620343,1.9555037,,,,,,,,,,,,,, -155300,2.641018,1.5871203,,,,,,,,,,,,,, -155400,2.489533,2.7925982,,,,,,,,,,,,,, -155500,3.0666046,1.6281099,,,,,,,,,,,,,, -155600,2.4565957,2.9121215,,,,,,,,,,,,,, -155700,2.6845355,1.6518115,,,,,,,,,,,,,, -155800,2.347258,3.3703363,,,,,,,,,,,,,, -155850,,,0.8199804425239563,0.7226928472518921,0.7334399819374084,1.1000254154205322,50000.0,0.6098000407218933,1.736724853515625,10000.0,71468.62477397919,79076.29350209236,71468.62477397919,7592.454738378525,6.974083662033081,0.0 -155900,2.437902,2.6957479,,,,,,,,,,,,,, -156000,2.4825222,1.6658368,,,,,,,,,,,,,, -156100,2.5043163,3.39325,,,,,,,,,,,,,, -156200,2.4079545,2.844847,,,,,,,,,,,,,, -156300,3.0053532,3.0698,,,,,,,,,,,,,, -156400,2.714908,3.0100574,,,,,,,,,,,,,, -156500,2.6143513,1.8314226,,,,,,,,,,,,,, -156600,2.5897057,1.6653157,,,,,,,,,,,,,, -156700,2.3359597,1.7982978,,,,,,,,,,,,,, -156768,,,0.8161718845367432,0.7212274074554443,0.7339999675750732,1.0783400535583496,50000.0,0.6101000308990479,1.7127060890197754,10000.0,71888.80693101883,79540.92209506035,71888.80693101883,7636.801145553589,7.02562427520752,0.0 -156800,2.5928733,1.6621327,,,,,,,,,,,,,, -156900,2.6127026,3.6623616,,,,,,,,,,,,,, -157000,2.6694098,2.56596,,,,,,,,,,,,,, -157100,2.7451138,1.6216692,,,,,,,,,,,,,, -157200,3.406022,1.7418394,,,,,,,,,,,,,, -157300,2.611841,1.6389369,,,,,,,,,,,,,, -157400,2.4478633,2.0393853,,,,,,,,,,,,,, -157500,2.5437467,1.537382,,,,,,,,,,,,,, -157600,2.3480678,2.0173376,,,,,,,,,,,,,, -157685,,,0.8190429210662842,0.7114618420600891,0.7349599599838257,1.075613021850586,50000.0,0.6080000400543213,1.7125824689865112,10000.0,72308.92112541199,80006.86278057098,72308.92112541199,7682.525693893433,7.078494548797607,0.0 -157700,2.6537454,1.5526322,,,,,,,,,,,,,, -157800,2.622187,1.6802983,,,,,,,,,,,,,, -157900,2.8148563,1.8598179,,,,,,,,,,,,,, -158000,2.8609328,1.7593513,,,,,,,,,,,,,, -158100,2.2310772,1.539893,,,,,,,,,,,,,, -158200,2.4345758,1.8897438,,,,,,,,,,,,,, -158300,3.6314185,1.8384285,,,,,,,,,,,,,, -158400,2.4051366,1.9590989,,,,,,,,,,,,,, -158500,2.4849162,1.5421655,,,,,,,,,,,,,, -158600,,,0.8244531154632568,0.6801031231880188,0.7355999946594238,1.058942794799805,50000.0,0.6096000075340271,1.695249319076538,10000.0,72728.9054980278,80471.80396866798,72728.9054980278,7727.382160902023,7.130783319473267,0.0 -158600,2.5981631,1.631819,,,,,,,,,,,,,, -158700,2.4617734,2.2360363,,,,,,,,,,,,,, -158800,2.2837894,2.399142,,,,,,,,,,,,,, -158900,2.5223608,1.5126678,,,,,,,,,,,,,, -159000,2.612571,2.976578,,,,,,,,,,,,,, -159100,2.5642757,3.2835166,,,,,,,,,,,,,, -159200,2.5001924,3.3344316,,,,,,,,,,,,,, -159300,2.4792447,1.584709,,,,,,,,,,,,,, -159400,2.8443787,3.5198193,,,,,,,,,,,,,, -159500,2.485581,2.133187,,,,,,,,,,,,,, -159516,,,0.8199218511581421,0.7040935158729553,0.737559974193573,1.067233324050903,50000.0,0.6140000224113464,1.6969444751739502,10000.0,73148.94817018509,80938.11805319786,73148.94817018509,7773.549040794373,7.186861276626587,0.0 -159600,2.6888254,1.5446907,,,,,,,,,,,,,, -159700,2.5629718,1.7756238,,,,,,,,,,,,,, -159800,2.9632206,1.5167506,,,,,,,,,,,,,, -159900,2.6000085,1.567032,,,,,,,,,,,,,, -160000,2.5029926,1.7693565,,,,,,,,,,,,,, -160100,2.6445339,1.5854596,,,,,,,,,,,,,, -160200,2.5499148,1.9680482,,,,,,,,,,,,,, -160300,2.5934968,1.5422348,,,,,,,,,,,,,, -160400,3.1774037,1.6054146,,,,,,,,,,,,,, -160432,,,0.8233007788658142,0.6964321732521057,0.7393199801445007,1.0613622665405271,50000.0,0.6107000112533569,1.701716661453247,10000.0,73568.97808933258,81403.82748365402,73568.97808933258,7819.126227378845,7.240983247756958,0.0 -160500,2.5921001,3.3226857,,,,,,,,,,,,,, -160600,2.793936,3.2799969,,,,,,,,,,,,,, -160700,2.878528,2.8893223,,,,,,,,,,,,,, -160800,2.862955,3.5129282,,,,,,,,,,,,,, -160900,2.5197287,3.6189957,,,,,,,,,,,,,, -161000,3.047865,1.9781649,,,,,,,,,,,,,, -161100,2.7215464,1.5223542,,,,,,,,,,,,,, -161200,2.7441375,2.9694228,,,,,,,,,,,,,, -161300,3.8027415,1.6422358,,,,,,,,,,,,,, -161350,,,0.8231640458106995,0.701531708240509,0.7382000088691711,1.0777002573013306,50000.0,0.6128000020980835,1.720652461051941,10000.0,73989.23976898193,81866.85880875587,73989.23976898193,7861.79163479805,7.296997785568237,0.0 -161400,2.6415179,1.5606503,,,,,,,,,,,,,, -161500,2.8834343,2.0831616,,,,,,,,,,,,,, -161600,2.549697,1.4853649,,,,,,,,,,,,,, -161700,2.6029754,2.1592195,,,,,,,,,,,,,, -161800,2.9898925,4.0172944,,,,,,,,,,,,,, -161900,2.8623145,1.7975525,,,,,,,,,,,,,, -162000,2.6439033,2.3818526,,,,,,,,,,,,,, -162100,2.6014743,3.3691118,,,,,,,,,,,,,, -162200,3.2056189,1.5435057,,,,,,,,,,,,,, -162268,,,0.8293749690055847,0.6697791218757629,0.7391200065612793,1.058064103126526,50000.0,0.6132000088691711,1.7002379894256592,10000.0,74409.15633821487,82332.41932559013,74409.15633821487,7907.331022977829,7.35301947593689,0.0 -162300,2.823541,1.5949782,,,,,,,,,,,,,, -162400,2.7731347,1.4447138,,,,,,,,,,,,,, -162500,2.8602715,1.5431811,,,,,,,,,,,,,, -162600,2.809358,2.7598407,,,,,,,,,,,,,, -162700,2.4403615,1.7503616,,,,,,,,,,,,,, -162800,3.820758,1.520612,,,,,,,,,,,,,, -162900,2.8950088,4.085947,,,,,,,,,,,,,, -163000,2.628925,2.1916838,,,,,,,,,,,,,, -163100,2.7579176,2.8448014,,,,,,,,,,,,,, -163188,,,0.8237695097923279,0.6979393362998962,0.739579975605011,1.063696026802063,50000.0,0.6170000433921814,1.7050611972808838,10000.0,74829.30160307884,82799.23604655266,74829.30160307884,7953.901052236557,7.404864072799683,0.0 -163200,2.531879,1.7025225,,,,,,,,,,,,,, -163300,5.863553,1.5635386,,,,,,,,,,,,,, -163400,2.7578673,1.5148399,,,,,,,,,,,,,, -163500,3.3666832,1.4878474,,,,,,,,,,,,,, -163600,3.0694792,1.469166,,,,,,,,,,,,,, -163700,3.1626365,3.7860932,,,,,,,,,,,,,, -163800,2.518465,2.6085937,,,,,,,,,,,,,, -163900,2.630817,1.5599599,,,,,,,,,,,,,, -164000,2.5322528,1.4160056,,,,,,,,,,,,,, -164100,2.6659656,1.6570814,,,,,,,,,,,,,, -164106,,,0.8280468583106995,0.6734815835952759,0.7416799664497375,1.0467333793640137,50000.0,0.617400050163269,1.6836295127868652,10000.0,75249.34499502182,83264.83662986755,75249.34499502182,7999.356866836548,7.457672595977783,0.0 -164200,3.0343142,1.5763156,,,,,,,,,,,,,, -164300,3.2110012,1.5698705,,,,,,,,,,,,,, -164400,2.7646496,3.6994135,,,,,,,,,,,,,, -164500,2.5060017,1.581934,,,,,,,,,,,,,, -164600,2.8217778,3.7149005,,,,,,,,,,,,,, -164700,2.6989465,1.4857975,,,,,,,,,,,,,, -164800,2.9729137,3.253534,,,,,,,,,,,,,, -164900,3.0295353,1.5018438,,,,,,,,,,,,,, -165000,2.990224,2.9306872,,,,,,,,,,,,,, -165024,,,0.8347070217132568,0.6491798162460327,0.7430399656295776,1.0401471853256226,50000.0,0.619100034236908,1.6772472858428955,10000.0,75669.39481902122,83733.0430316925,75669.39481902122,8047.415571212769,7.507667779922485,0.0 -165100,2.3218706,1.5452926,,,,,,,,,,,,,, -165200,3.4298744,1.5451398,,,,,,,,,,,,,, -165300,2.5738006,3.2839785,,,,,,,,,,,,,, -165400,3.2031877,2.485166,,,,,,,,,,,,,, -165500,2.6295888,1.8341095,,,,,,,,,,,,,, -165600,2.8622413,2.1223757,,,,,,,,,,,,,, -165700,3.0256264,3.9094048,,,,,,,,,,,,,, -165800,2.701936,2.6198394,,,,,,,,,,,,,, -165900,2.9322655,1.591031,,,,,,,,,,,,,, -165943,,,0.8305078148841858,0.6913408637046814,0.7436800003051758,1.0642290115356443,50000.0,0.6195000410079956,1.6991186141967771,10000.0,76089.31130671501,84199.780534029,76089.31130671501,8094.137541294098,7.558278083801269,0.0 -166000,3.1453269,3.6489418,,,,,,,,,,,,,, -166100,3.0558627,1.8479748,,,,,,,,,,,,,, -166200,3.2144356,3.2746596,,,,,,,,,,,,,, -166300,2.9475753,1.5450206,,,,,,,,,,,,,, -166400,2.8195329,2.766524,,,,,,,,,,,,,, -166500,3.022612,3.2949207,,,,,,,,,,,,,, -166600,2.8776326,3.5494885,,,,,,,,,,,,,, -166700,2.755557,1.5511286,,,,,,,,,,,,,, -166800,2.8237543,4.000556,,,,,,,,,,,,,, -166860,,,0.83509761095047,0.6374510526657104,0.7443199753761292,1.0309221744537354,50000.0,0.6202000379562378,1.6646264791488647,10000.0,76509.37047219276,84664.75477600098,76509.37047219276,8138.949564218521,7.613423585891724,0.0 -166900,2.7180002,1.5809206,,,,,,,,,,,,,, -167000,2.813129,1.4211155,,,,,,,,,,,,,, -167100,2.939491,3.0998368,,,,,,,,,,,,,, -167200,2.6656954,1.5107787,,,,,,,,,,,,,, -167300,3.5286286,1.5232769,,,,,,,,,,,,,, -167400,3.4668546,1.4158306,,,,,,,,,,,,,, -167500,2.5441651,2.004249,,,,,,,,,,,,,, -167600,3.0250995,1.5588095,,,,,,,,,,,,,, -167700,3.0772069,3.8913603,,,,,,,,,,,,,, -167778,,,0.8340038657188416,0.6508774161338806,0.7443599700927734,1.0415518283843994,50000.0,0.6232000589370728,1.6723334789276123,10000.0,76929.66366410255,85130.87297224998,76929.66366410255,8184.667482852936,7.671286582946777,0.0 -167800,2.8456633,1.5350362,,,,,,,,,,,,,, -167900,2.783136,2.6222744,,,,,,,,,,,,,, -168000,2.7647913,1.4790905,,,,,,,,,,,,,, -168100,2.5818193,1.4683019,,,,,,,,,,,,,, -168200,2.747293,1.5237672,,,,,,,,,,,,,, -168300,2.612568,1.7605696,,,,,,,,,,,,,, -168400,2.7081625,2.0694585,,,,,,,,,,,,,, -168500,2.7285712,1.4686364,,,,,,,,,,,,,, -168600,3.4924018,1.8467288,,,,,,,,,,,,,, -168694,,,0.8335155844688416,0.6408066153526306,0.74617999792099,1.021526336669922,50000.0,0.6210000514984131,1.6536414623260498,10000.0,77349.58613371849,85597.61217308044,77349.58613371849,8231.38061952591,7.726383686065674,0.0 -168700,2.7072027,1.4381235,,,,,,,,,,,,,, -168800,3.103116,2.915212,,,,,,,,,,,,,, -168900,3.369955,3.9116251,,,,,,,,,,,,,, -169000,2.9597657,3.2364635,,,,,,,,,,,,,, -169071,,,,,,,,,,,77520.06314373016,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index eb7a65b04..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -158.21125888824463,0.0,61.96003007888794,1,0,61.96003007888794,30.208836,2472,0.9085978916580344,220.17135214805603,30.99089,0.94417778980905,30.141817,5348,0.9043706614402812 -265.20961356163025,0.0419168472290039,1502.605571269989,1729,0,1502.605571269989,6.5954204,2472,0.899579550301627,1767.9254868030548,6.527625,0.944635537887994,6.5964723,5348,0.8966179750330672 -387.1395664215088,0.0902779102325439,2942.91507768631,3495,0,2942.91507768631,3.3570948,2472,0.6746694290414965,3330.286596775055,3.2871444,0.6894920491170125,3.6583788,5348,0.7243403458296727 -523.3349103927612,0.1417605876922607,4383.12352013588,5219,0,4383.12352013588,0.8465067,2472,0.2771921272317348,4906.813807249069,0.78254795,0.2698947346226629,1.1469847,5348,0.3340799598366433 -659.7774543762207,0.1961519718170166,5823.346606016159,6943,0,5823.346606016159,0.58340687,2472,0.1945646212905977,6483.608228683472,0.5334083,0.1861156782456618,0.85424995,5348,0.2556938316421599 -792.5681960582733,0.2492585182189941,7263.875959396362,8684,0,7263.875959396362,0.50006676,2472,0.1700688562549509,8057.054099321365,0.45842528,0.1636695731587988,0.7623629,5348,0.2306593162574703 -928.0997140407562,0.3020815849304199,8704.147989988327,10391,0,8704.147989988327,0.45460317,2472,0.1544086283590275,9632.983120441437,0.4250593,0.1507914068209825,0.6983394,5348,0.2119582532801683 -1065.1332168579102,0.350053071975708,10144.166356801989,12089,0,10144.166356801989,0.41433823,2472,0.1409623626429427,11210.15414071083,0.36514083,0.1322825601429391,0.65220815,5348,0.1972349073636038 -1201.91868019104,0.4008769989013672,11584.574691057203,13828,0,11584.574691057203,0.39117342,2472,0.1323908760384295,12787.471153974531,0.31292966,0.1176421643828492,0.6289795,5348,0.190650433976655 -1337.3448441028595,0.4517827033996582,13024.453626871107,15518,0,13024.453626871107,0.3671398,2472,0.125627120021124,14362.89952802658,0.28706837,0.106113242077151,0.59261024,5348,0.1804261563860702 -1471.637094259262,0.5028841495513916,14464.984570026398,17235,0,14464.984570026398,0.3576985,2472,0.1215038693559198,15937.846490383148,0.28261438,0.1072414958142727,0.58001304,5348,0.1747685296928854 -1606.7800455093384,0.55181884765625,15905.405641078947,18968,0,15905.405641078947,0.33700582,2472,0.1146588670200881,17513.532631635666,0.27884072,0.1019963510404823,0.5530038,5348,0.1681068190814563 -1741.54425239563,0.6050012111663818,17345.66361284256,20669,0,17345.66361284256,0.32689303,2472,0.1120183616679869,19088.68205499649,0.2795932,0.102994169942227,0.5383996,5348,0.1635112042248761 -1877.2550451755524,0.659116268157959,18786.160464525223,22379,0,18786.160464525223,0.32418817,2472,0.1106168626734101,20665.01668071747,0.26681316,0.0988424735055046,0.53043646,5348,0.1613002886741265 -2013.0226137638087,0.7118988037109375,20226.1611392498,24105,0,20226.1611392498,0.31326857,2472,0.1062092498933642,22240.911847114563,0.24471778,0.0913865546218487,0.52384573,5348,0.1589252440213561 -2149.603051900864,0.7754313945770264,21666.239350557327,25782,0,21666.239350557327,0.30197123,2472,0.1013547823614242,23817.70459675789,0.23044355,0.0857560186755001,0.5116114,5348,0.1542330826341755 -2285.3351743221283,0.8276486396789551,23106.38236856461,27484,0,23106.38236856461,0.30364478,2472,0.1011110434058456,25393.70451760292,0.22562861,0.082642092235939,0.4992585,5348,0.1501974376550778 -2419.71209859848,0.8837933540344238,24546.66773033142,29193,0,24546.66773033142,0.29264167,2472,0.0988564580667438,26968.495937109,0.23682073,0.0862042815045935,0.48714328,5348,0.1461907566351603 -2553.862573862076,0.9391980171203612,25986.74062180519,30862,0,25986.74062180519,0.28424188,2472,0.0946519610830134,28542.845601081848,0.22417934,0.0807763471315633,0.48772222,5348,0.1471755312472846 -2689.2480068206787,0.999396562576294,27426.87792801857,32543,0,27426.87792801857,0.27928787,2472,0.0940426136940669,30118.500452518463,0.22964644,0.0826993512064067,0.47304323,5348,0.1424254419417438 -2827.5657880306244,1.057802677154541,28866.8210606575,34241,0,28866.8210606575,0.27463698,2472,0.0918489630938598,31696.89235520363,0.22303934,0.0796327674584652,0.4615741,5348,0.1392587157380499 -2963.170545816421,1.1170177459716797,30306.81025290489,35914,0,30306.81025290489,0.26810637,2472,0.0904271525196514,33272.61613845825,0.17427287,0.0660188801082377,0.45529997,5348,0.1377525898606833 -3099.013829946518,1.1702604293823242,31746.826417922974,37620,0,31746.826417922974,0.2616381,2472,0.0897771819714419,34848.603251457214,0.19530179,0.0711178222212836,0.45110688,5348,0.1353485812487328 -3234.586481332779,1.2274115085601809,33187.17974424362,39313,0,33187.17974424362,0.25463516,2472,0.0848820912802388,36424.66038656235,0.2397092,0.0867527649257172,0.4422507,5348,0.1319984166368981 -3367.3648619651794,1.2825305461883545,34627.428320646286,41000,0,34627.428320646286,0.25042498,2472,0.0831352954319257,37997.81482315064,0.24868596,0.0896817422295253,0.43414676,5348,0.1291309846780655 -3499.707387447357,1.3385326862335205,36067.43534851074,42715,0,36067.43534851074,0.23770487,2472,0.0797635732130887,39570.2938041687,0.28319895,0.1032108615919181,0.42378053,5348,0.1277986425557797 -3632.162237882614,1.3978519439697266,37507.68049240112,44418,0,37507.68049240112,0.24019764,2472,0.0809010216724554,41143.12516140938,0.2481551,0.0890143598137168,0.41698194,5348,0.1253367060254689 -3765.296792268753,1.455432415008545,38948.10423493385,46117,0,38948.10423493385,0.23128738,2472,0.0772449373387768,42716.81401634216,0.22114661,0.0816975270840524,0.40540525,5348,0.1197176979445243 -3901.612589597702,1.514012098312378,40389.05780625343,47828,0,40389.05780625343,0.22929518,2472,0.0757825036053053,44294.21702575684,0.18673953,0.0712133900314304,0.3940457,5348,0.1176612568427353 -4037.291281223297,1.5761985778808594,41829.449006557465,49511,0,41829.449006557465,0.2227393,2472,0.0738529035403083,45870.42122077942,0.20579855,0.0762191313604197,0.39688087,5348,0.1164158065979899 -4173.671866893768,1.6306617259979248,43269.72302722931,51202,0,43269.72302722931,0.21779981,2472,0.0727154550809416,47447.20176792145,0.18011282,0.0690227194640096,0.38041675,5348,0.112331888353592 -4309.752601146698,1.781308889389038,44709.526916742325,52918,0,44709.526916742325,0.21114749,2472,0.0694452907602624,49023.30984520912,0.17563458,0.0666475010577246,0.37358853,5348,0.1102464832926228 -4444.613988637924,1.8514833450317385,46149.56543898583,54594,0,46149.56543898583,0.20493606,2472,0.0703390002640505,50598.354439258575,0.16827415,0.0645134575569358,0.36485976,5348,0.1085762283132355 -4578.047208547592,1.909917116165161,47589.52333641052,56314,0,47589.52333641052,0.19563222,2472,0.0660532569617939,52171.8772919178,0.16584644,0.0628298020901296,0.35868764,5348,0.1056122498238025 -4711.303059339523,1.9719979763031008,49029.40422821045,58024,0,49029.40422821045,0.19396327,2472,0.0637174253041659,53745.149267435074,0.16346063,0.0612073753798233,0.35139942,5348,0.1030730760690114 -4847.475305557251,2.020612955093384,50469.57779741287,59706,0,50469.57779741287,0.18521993,2472,0.0618893831373265,55321.613298892975,0.13584337,0.0522171408074965,0.3399327,5348,0.099790494028597 -4982.617788314819,2.086501359939575,51910.06654167175,61398,0,51910.06654167175,0.18194032,2472,0.059980094651961,56897.38406395912,0.13234085,0.0504312902195244,0.33066174,5348,0.0966816957432634 -5119.124755144119,2.150652885437012,53350.15077781677,63107,0,53350.15077781677,0.18186228,2472,0.0592691893648569,58474.11357855797,0.11835707,0.0454961389099674,0.3270045,5348,0.094084594070112 -5254.06840133667,2.2103593349456787,54790.40292072296,64784,0,54790.40292072296,0.17583282,2472,0.0571161619239128,60049.44139790535,0.114238106,0.0437322432061581,0.31845886,5348,0.0914874923969607 -5388.650343894959,2.275513172149658,56230.28919124603,66491,0,56230.28919124603,0.17403618,2472,0.0551662502792842,61624.050307273865,0.105487265,0.0400259385970796,0.31323752,5348,0.0898655106828736 -5525.277727842331,2.334338903427124,57670.17176914215,68195,0,57670.17176914215,0.16796601,2472,0.0546381492088639,63200.69050574303,0.099455096,0.0379664168778986,0.30718902,5348,0.0880118172953454 -5661.63925409317,2.3968987464904785,59110.58864855766,69881,0,59110.58864855766,0.16488022,2472,0.0529116649401824,64777.60359311104,0.08041966,0.030858299457803,0.30663946,5348,0.0876352858260038 -5795.302344799042,2.4600815773010254,60550.75520849228,71588,0,60550.75520849228,0.16116796,2472,0.051896085958605,66351.57162237167,0.07414443,0.028042570786109555,0.2998368,5348,0.08489336435695183 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index 2a782efcf..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,768 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,54.19962,30.951794,,,,,,,,,,,,,, -1,,,30.99089,0.94417778980905,30.141817,0.9043706614402812,5348.0,30.208836,0.9085978916580344,2472.0,61.96003007888794,220.17135214805603,61.96003007888794,158.21125888824463,0.0,0.0 -100,34.3021,10.599904,,,,,,,,,,,,,, -200,2.932757,6.2059727,,,,,,,,,,,,,, -300,1.0351027,5.854243,,,,,,,,,,,,,, -400,0.26643464,5.832925,,,,,,,,,,,,,, -500,0.25737756,5.8068495,,,,,,,,,,,,,, -600,0.4364609,5.813186,,,,,,,,,,,,,, -700,0.38151062,5.8107295,,,,,,,,,,,,,, -800,0.30340022,5.7847657,,,,,,,,,,,,,, -900,0.54796505,5.784335,,,,,,,,,,,,,, -1000,0.3331883,5.7871943,,,,,,,,,,,,,, -1100,0.27945787,5.7693367,,,,,,,,,,,,,, -1200,0.23603041,5.7959504,,,,,,,,,,,,,, -1300,0.37656757,5.7955317,,,,,,,,,,,,,, -1400,0.3712578,5.7845955,,,,,,,,,,,,,, -1500,0.2653813,5.7728243,,,,,,,,,,,,,, -1600,1.9574066,5.7288647,,,,,,,,,,,,,, -1700,0.7584992,5.5853615,,,,,,,,,,,,,, -1729,,,6.527625,0.944635537887994,6.5964723,0.8966179750330672,5348.0,6.5954204,0.899579550301627,2472.0,1502.605571269989,1767.9254868030548,1502.605571269989,265.20961356163025,0.0419168472290039,0.0 -1800,0.76931566,5.4998364,,,,,,,,,,,,,, -1900,0.88576365,5.2894044,,,,,,,,,,,,,, -2000,0.9054967,4.7271476,,,,,,,,,,,,,, -2100,1.165263,4.0673876,,,,,,,,,,,,,, -2200,0.86118484,3.8086166,,,,,,,,,,,,,, -2300,1.1050234,3.4833252,,,,,,,,,,,,,, -2400,1.133331,3.4050841,,,,,,,,,,,,,, -2500,0.94116265,3.2035744,,,,,,,,,,,,,, -2600,0.977584,3.0996299,,,,,,,,,,,,,, -2700,0.9635643,3.015257,,,,,,,,,,,,,, -2800,1.0476876,2.8622227,,,,,,,,,,,,,, -2900,1.0897795,2.8149512,,,,,,,,,,,,,, -3000,0.91584104,2.739287,,,,,,,,,,,,,, -3100,0.9810251,2.7542353,,,,,,,,,,,,,, -3200,1.0603007,2.6537328,,,,,,,,,,,,,, -3300,1.0164123,2.6215734,,,,,,,,,,,,,, -3400,1.1874396,2.573163,,,,,,,,,,,,,, -3495,,,3.2871444,0.6894920491170125,3.6583788,0.7243403458296727,5348.0,3.3570948,0.6746694290414965,2472.0,2942.91507768631,3330.286596775055,2942.91507768631,387.1395664215088,0.0902779102325439,0.0 -3500,0.99307406,2.5291927,,,,,,,,,,,,,, -3600,0.9062632,2.436075,,,,,,,,,,,,,, -3700,1.1126887,2.459356,,,,,,,,,,,,,, -3800,1.1225502,2.3955345,,,,,,,,,,,,,, -3900,0.9811544,2.3517172,,,,,,,,,,,,,, -4000,1.026482,2.3335595,,,,,,,,,,,,,, -4100,0.9310614,2.2759314,,,,,,,,,,,,,, -4200,0.85149664,2.3366482,,,,,,,,,,,,,, -4300,0.97763157,2.173264,,,,,,,,,,,,,, -4400,0.82654,2.1664336,,,,,,,,,,,,,, -4500,0.89529544,2.1306124,,,,,,,,,,,,,, -4600,1.0655924,2.2055917,,,,,,,,,,,,,, -4700,0.8460405,2.103369,,,,,,,,,,,,,, -4800,0.8048951,2.0763338,,,,,,,,,,,,,, -4900,1.0713862,2.069076,,,,,,,,,,,,,, -5000,0.81408936,2.0463946,,,,,,,,,,,,,, -5100,0.9638177,2.0049632,,,,,,,,,,,,,, -5200,0.9072069,1.93249,,,,,,,,,,,,,, -5219,,,0.78254795,0.2698947346226629,1.1469847,0.3340799598366433,5348.0,0.8465067,0.2771921272317348,2472.0,4383.12352013588,4906.813807249069,4383.12352013588,523.3349103927612,0.1417605876922607,0.0 -5300,0.885901,2.0362496,,,,,,,,,,,,,, -5400,0.93635994,1.9745896,,,,,,,,,,,,,, -5500,0.9933123,1.9635887,,,,,,,,,,,,,, -5600,0.94881,1.9933805,,,,,,,,,,,,,, -5700,0.85642284,1.9163088,,,,,,,,,,,,,, -5800,0.87650746,1.9295073,,,,,,,,,,,,,, -5900,0.7520006,1.9783832,,,,,,,,,,,,,, -6000,0.7492279,1.8653291,,,,,,,,,,,,,, -6100,0.96672344,1.8802377,,,,,,,,,,,,,, -6200,0.9033152,1.8135442,,,,,,,,,,,,,, -6300,0.8610135,1.923154,,,,,,,,,,,,,, -6400,0.77161056,1.8468862,,,,,,,,,,,,,, -6500,0.70371115,1.8076043,,,,,,,,,,,,,, -6600,0.79324186,1.8491433,,,,,,,,,,,,,, -6700,0.72238743,1.8442955,,,,,,,,,,,,,, -6800,0.79457486,1.7954131,,,,,,,,,,,,,, -6900,0.9237172,1.7999542,,,,,,,,,,,,,, -6943,,,0.5334083,0.1861156782456618,0.85424995,0.2556938316421599,5348.0,0.58340687,0.1945646212905977,2472.0,5823.346606016159,6483.608228683472,5823.346606016159,659.7774543762207,0.1961519718170166,0.0 -7000,0.84360516,1.7982931,,,,,,,,,,,,,, -7100,0.94879776,1.761334,,,,,,,,,,,,,, -7200,1.0132645,1.7884037,,,,,,,,,,,,,, -7300,0.91865504,1.7756293,,,,,,,,,,,,,, -7400,0.8963354,1.7994933,,,,,,,,,,,,,, -7500,0.8783976,1.7611295,,,,,,,,,,,,,, -7600,0.8378237,1.7493545,,,,,,,,,,,,,, -7700,0.83828247,1.7787338,,,,,,,,,,,,,, -7800,0.7954198,1.7249634,,,,,,,,,,,,,, -7900,0.6731003,1.752616,,,,,,,,,,,,,, -8000,0.72224355,1.6940055,,,,,,,,,,,,,, -8100,0.678538,1.6621445,,,,,,,,,,,,,, -8200,0.80891496,1.7266089,,,,,,,,,,,,,, -8300,0.7888917,1.6995832,,,,,,,,,,,,,, -8400,0.618063,1.6272997,,,,,,,,,,,,,, -8500,0.73544353,1.7077968,,,,,,,,,,,,,, -8600,0.7043598,1.6593627,,,,,,,,,,,,,, -8684,,,0.45842528,0.1636695731587988,0.7623629,0.2306593162574703,5348.0,0.50006676,0.1700688562549509,2472.0,7263.875959396362,8057.054099321365,7263.875959396362,792.5681960582733,0.2492585182189941,0.0 -8700,0.7132828,1.7144253,,,,,,,,,,,,,, -8800,0.8286831,1.6514418,,,,,,,,,,,,,, -8900,0.86885834,1.6613334,,,,,,,,,,,,,, -9000,1.0283356,1.6524261,,,,,,,,,,,,,, -9100,0.7574715,1.6190292,,,,,,,,,,,,,, -9200,1.0999166,1.6842065,,,,,,,,,,,,,, -9300,0.7510121,1.6147251,,,,,,,,,,,,,, -9400,0.64723045,1.6079016,,,,,,,,,,,,,, -9500,0.78648025,1.6154841,,,,,,,,,,,,,, -9600,0.81519455,1.6353118,,,,,,,,,,,,,, -9700,0.59346604,1.6524637,,,,,,,,,,,,,, -9800,0.8145823,1.5795889,,,,,,,,,,,,,, -9900,0.84147084,1.6026464,,,,,,,,,,,,,, -10000,0.8008255,1.6278892,,,,,,,,,,,,,, -10100,0.7235706,1.5888907,,,,,,,,,,,,,, -10200,0.6946253,1.6734022,,,,,,,,,,,,,, -10300,0.7844911,1.6349138,,,,,,,,,,,,,, -10391,,,0.4250593,0.1507914068209825,0.6983394,0.2119582532801683,5348.0,0.45460317,0.1544086283590275,2472.0,8704.147989988327,9632.983120441437,8704.147989988327,928.0997140407562,0.3020815849304199,0.0 -10400,0.68954897,1.5433452,,,,,,,,,,,,,, -10500,0.72935593,1.635221,,,,,,,,,,,,,, -10600,0.82716864,1.564945,,,,,,,,,,,,,, -10700,0.71387655,1.6075269,,,,,,,,,,,,,, -10800,0.6669136,1.6567501,,,,,,,,,,,,,, -10900,0.642738,1.6195475,,,,,,,,,,,,,, -11000,0.59384334,1.5989413,,,,,,,,,,,,,, -11100,0.64718044,1.569447,,,,,,,,,,,,,, -11200,0.8876569,1.5884732,,,,,,,,,,,,,, -11300,0.773195,1.5560219,,,,,,,,,,,,,, -11400,0.69197446,1.5132705,,,,,,,,,,,,,, -11500,0.7079817,1.5022359,,,,,,,,,,,,,, -11600,0.830248,1.554652,,,,,,,,,,,,,, -11700,0.7292553,1.5692152,,,,,,,,,,,,,, -11800,0.6161142,1.6153779,,,,,,,,,,,,,, -11900,0.83043814,1.595378,,,,,,,,,,,,,, -12000,0.74174696,1.5515031,,,,,,,,,,,,,, -12089,,,0.36514083,0.1322825601429391,0.65220815,0.1972349073636038,5348.0,0.41433823,0.1409623626429427,2472.0,10144.166356801989,11210.15414071083,10144.166356801989,1065.1332168579102,0.350053071975708,0.0 -12100,0.6541654,1.5337014,,,,,,,,,,,,,, -12200,0.70557845,1.5172164,,,,,,,,,,,,,, -12300,0.82574594,1.4982258,,,,,,,,,,,,,, -12400,0.605324,1.531617,,,,,,,,,,,,,, -12500,0.63840497,1.5081764,,,,,,,,,,,,,, -12600,0.7425679,1.5215,,,,,,,,,,,,,, -12700,0.6425596,1.5207582,,,,,,,,,,,,,, -12800,0.89181036,1.5042467,,,,,,,,,,,,,, -12900,0.6399073,1.5390197,,,,,,,,,,,,,, -13000,0.84956914,1.5479641,,,,,,,,,,,,,, -13100,0.62125546,1.5516434,,,,,,,,,,,,,, -13200,0.88787556,1.5306338,,,,,,,,,,,,,, -13300,0.62696844,1.5585104,,,,,,,,,,,,,, -13400,0.65088516,1.5239317,,,,,,,,,,,,,, -13500,0.7342667,1.4964592,,,,,,,,,,,,,, -13600,0.70898944,1.4726741,,,,,,,,,,,,,, -13700,0.60427517,1.5401143,,,,,,,,,,,,,, -13800,0.83830905,1.5231441,,,,,,,,,,,,,, -13828,,,0.31292966,0.1176421643828492,0.6289795,0.190650433976655,5348.0,0.39117342,0.1323908760384295,2472.0,11584.574691057203,12787.471153974531,11584.574691057203,1201.91868019104,0.4008769989013672,0.0 -13900,0.6124131,1.4976956,,,,,,,,,,,,,, -14000,0.70219594,1.5138173,,,,,,,,,,,,,, -14100,0.61959124,1.5546327,,,,,,,,,,,,,, -14200,0.6180617,1.5443621,,,,,,,,,,,,,, -14300,0.72760934,1.4988126,,,,,,,,,,,,,, -14400,0.7525038,1.5218047,,,,,,,,,,,,,, -14500,0.67047125,1.5210286,,,,,,,,,,,,,, -14600,0.5637617,1.4354826,,,,,,,,,,,,,, -14700,0.63182545,1.4850918,,,,,,,,,,,,,, -14800,0.68405896,1.4809942,,,,,,,,,,,,,, -14900,0.7406014,1.4694549,,,,,,,,,,,,,, -15000,0.648963,1.4739832,,,,,,,,,,,,,, -15100,0.80734533,1.482332,,,,,,,,,,,,,, -15200,0.7629685,1.4409605,,,,,,,,,,,,,, -15300,0.5990668,1.4735967,,,,,,,,,,,,,, -15400,0.7374908,1.4701905,,,,,,,,,,,,,, -15500,0.78151774,1.4332346,,,,,,,,,,,,,, -15518,,,0.28706837,0.106113242077151,0.59261024,0.1804261563860702,5348.0,0.3671398,0.125627120021124,2472.0,13024.453626871107,14362.89952802658,13024.453626871107,1337.3448441028595,0.4517827033996582,0.0 -15600,0.66586167,1.4446175,,,,,,,,,,,,,, -15700,0.6399541,1.4836506,,,,,,,,,,,,,, -15800,0.67337805,1.423807,,,,,,,,,,,,,, -15900,0.6222766,1.4999422,,,,,,,,,,,,,, -16000,0.58642405,1.445693,,,,,,,,,,,,,, -16100,0.6717128,1.4408715,,,,,,,,,,,,,, -16200,0.5491122,1.490891,,,,,,,,,,,,,, -16300,0.5849523,1.4726977,,,,,,,,,,,,,, -16400,0.59725046,1.4196707,,,,,,,,,,,,,, -16500,0.7457869,1.4356616,,,,,,,,,,,,,, -16600,0.7360927,1.4296829,,,,,,,,,,,,,, -16700,0.62130886,1.3868109,,,,,,,,,,,,,, -16800,0.6956814,1.4486121,,,,,,,,,,,,,, -16900,0.69578695,1.4492956,,,,,,,,,,,,,, -17000,0.7539934,1.4004295,,,,,,,,,,,,,, -17100,0.7154955,1.423577,,,,,,,,,,,,,, -17200,0.8388117,1.4424572,,,,,,,,,,,,,, -17235,,,0.28261438,0.1072414958142727,0.58001304,0.1747685296928854,5348.0,0.3576985,0.1215038693559198,2472.0,14464.984570026398,15937.846490383148,14464.984570026398,1471.637094259262,0.5028841495513916,0.0 -17300,0.6468824,1.4018247,,,,,,,,,,,,,, -17400,0.635477,1.4573559,,,,,,,,,,,,,, -17500,0.63590825,1.4727068,,,,,,,,,,,,,, -17600,0.58479285,1.4492059,,,,,,,,,,,,,, -17700,0.6853016,1.4526631,,,,,,,,,,,,,, -17800,0.65234935,1.3665922,,,,,,,,,,,,,, -17900,0.8325668,1.4413937,,,,,,,,,,,,,, -18000,0.72421503,1.4095196,,,,,,,,,,,,,, -18100,0.6264392,1.4062287,,,,,,,,,,,,,, -18200,0.7360878,1.4125181,,,,,,,,,,,,,, -18300,0.7553614,1.4351375,,,,,,,,,,,,,, -18400,0.77040994,1.4464325,,,,,,,,,,,,,, -18500,0.53379726,1.4389483,,,,,,,,,,,,,, -18600,0.72517544,1.3029721,,,,,,,,,,,,,, -18700,0.6643532,1.4491894,,,,,,,,,,,,,, -18800,0.6616309,1.4222324,,,,,,,,,,,,,, -18900,0.6606286,1.368264,,,,,,,,,,,,,, -18968,,,0.27884072,0.1019963510404823,0.5530038,0.1681068190814563,5348.0,0.33700582,0.1146588670200881,2472.0,15905.405641078947,17513.532631635666,15905.405641078947,1606.7800455093384,0.55181884765625,0.0 -19000,0.95685196,1.4068286,,,,,,,,,,,,,, -19100,0.61110157,1.3701239,,,,,,,,,,,,,, -19200,0.8749874,1.4029294,,,,,,,,,,,,,, -19300,0.6796596,1.4405766,,,,,,,,,,,,,, -19400,0.68873614,1.4450439,,,,,,,,,,,,,, -19500,0.6145224,1.4039792,,,,,,,,,,,,,, -19600,0.57758284,1.3905513,,,,,,,,,,,,,, -19700,1.14372,1.433258,,,,,,,,,,,,,, -19800,0.89853674,1.417602,,,,,,,,,,,,,, -19900,0.58956265,1.3487293,,,,,,,,,,,,,, -20000,0.5360638,1.330615,,,,,,,,,,,,,, -20100,0.8196081,1.4083061,,,,,,,,,,,,,, -20200,0.6783298,1.3775331,,,,,,,,,,,,,, -20300,0.68296134,1.3847952,,,,,,,,,,,,,, -20400,0.6499227,1.4443938,,,,,,,,,,,,,, -20500,0.599993,1.4006051,,,,,,,,,,,,,, -20600,0.72245693,1.3842404,,,,,,,,,,,,,, -20669,,,0.2795932,0.102994169942227,0.5383996,0.1635112042248761,5348.0,0.32689303,0.1120183616679869,2472.0,17345.66361284256,19088.68205499649,17345.66361284256,1741.54425239563,0.6050012111663818,0.0 -20700,0.7361153,1.3575628,,,,,,,,,,,,,, -20800,0.7315613,1.4323316,,,,,,,,,,,,,, -20900,0.8576879,1.340481,,,,,,,,,,,,,, -21000,0.6485606,1.3814075,,,,,,,,,,,,,, -21100,0.7916432,1.3826418,,,,,,,,,,,,,, -21200,0.74419403,1.3817157,,,,,,,,,,,,,, -21300,0.76621366,1.351154,,,,,,,,,,,,,, -21400,0.68989056,1.4057693,,,,,,,,,,,,,, -21500,0.6561492,1.4225934,,,,,,,,,,,,,, -21600,0.61712945,1.422159,,,,,,,,,,,,,, -21700,0.65835243,1.3561989,,,,,,,,,,,,,, -21800,0.67817533,1.3577447,,,,,,,,,,,,,, -21900,0.71838063,1.3971575,,,,,,,,,,,,,, -22000,0.6424991,1.3553933,,,,,,,,,,,,,, -22100,0.6364011,1.3574029,,,,,,,,,,,,,, -22200,0.76701874,1.3673556,,,,,,,,,,,,,, -22300,0.5974121,1.318927,,,,,,,,,,,,,, -22379,,,0.26681316,0.0988424735055046,0.53043646,0.1613002886741265,5348.0,0.32418817,0.1106168626734101,2472.0,18786.160464525223,20665.01668071747,18786.160464525223,1877.2550451755524,0.659116268157959,0.0 -22400,0.7307391,1.3273461,,,,,,,,,,,,,, -22500,0.76861984,1.4243032,,,,,,,,,,,,,, -22600,0.7978624,1.3290378,,,,,,,,,,,,,, -22700,0.6935139,1.3780848,,,,,,,,,,,,,, -22800,0.7315191,1.3440946,,,,,,,,,,,,,, -22900,0.69788337,1.3466097,,,,,,,,,,,,,, -23000,0.82566,1.3829454,,,,,,,,,,,,,, -23100,0.77417946,1.3432437,,,,,,,,,,,,,, -23200,0.6949909,1.3341293,,,,,,,,,,,,,, -23300,0.63966626,1.3793069,,,,,,,,,,,,,, -23400,0.7203176,1.3693883,,,,,,,,,,,,,, -23500,0.7195913,1.2847493,,,,,,,,,,,,,, -23600,0.7629206,1.3480747,,,,,,,,,,,,,, -23700,0.65988165,1.3029736,,,,,,,,,,,,,, -23800,0.69051033,1.3853482,,,,,,,,,,,,,, -23900,0.6140464,1.3635609,,,,,,,,,,,,,, -24000,0.8318002,1.3242277,,,,,,,,,,,,,, -24100,1.174067,1.3984586,,,,,,,,,,,,,, -24105,,,0.24471778,0.0913865546218487,0.52384573,0.1589252440213561,5348.0,0.31326857,0.1062092498933642,2472.0,20226.1611392498,22240.911847114563,20226.1611392498,2013.0226137638087,0.7118988037109375,0.0 -24200,0.8279611,1.385788,,,,,,,,,,,,,, -24300,0.6531667,1.4046034,,,,,,,,,,,,,, -24400,0.6298403,1.3569056,,,,,,,,,,,,,, -24500,0.67756003,1.3353378,,,,,,,,,,,,,, -24600,0.62336046,1.3473357,,,,,,,,,,,,,, -24700,0.73938566,1.3144137,,,,,,,,,,,,,, -24800,0.60253537,1.3346663,,,,,,,,,,,,,, -24900,0.7131653,1.3376514,,,,,,,,,,,,,, -25000,1.0137072,1.3181905,,,,,,,,,,,,,, -25100,0.6856062,1.3634067,,,,,,,,,,,,,, -25200,0.6396416,1.3156645,,,,,,,,,,,,,, -25300,0.72371155,1.3862548,,,,,,,,,,,,,, -25400,0.73257816,1.3541552,,,,,,,,,,,,,, -25500,0.8374844,1.390343,,,,,,,,,,,,,, -25600,0.71984714,1.3213828,,,,,,,,,,,,,, -25700,0.6267279,1.3921206,,,,,,,,,,,,,, -25782,,,0.23044355,0.0857560186755001,0.5116114,0.1542330826341755,5348.0,0.30197123,0.1013547823614242,2472.0,21666.239350557327,23817.70459675789,21666.239350557327,2149.603051900864,0.7754313945770264,0.0 -25800,0.7462265,1.2918179,,,,,,,,,,,,,, -25900,0.68715703,1.2825235,,,,,,,,,,,,,, -26000,0.7107881,1.3916801,,,,,,,,,,,,,, -26100,0.7622514,1.2835432,,,,,,,,,,,,,, -26200,0.66056484,1.3166258,,,,,,,,,,,,,, -26300,0.6918486,1.336669,,,,,,,,,,,,,, -26400,1.1057209,1.3154395,,,,,,,,,,,,,, -26500,0.7938357,1.3230135,,,,,,,,,,,,,, -26600,0.67899114,1.3991287,,,,,,,,,,,,,, -26700,0.7093399,1.301342,,,,,,,,,,,,,, -26800,0.7304602,1.3396785,,,,,,,,,,,,,, -26900,0.7278955,1.3058634,,,,,,,,,,,,,, -27000,0.7363955,1.331932,,,,,,,,,,,,,, -27100,0.6317831,1.2889081,,,,,,,,,,,,,, -27200,0.8010329,1.302989,,,,,,,,,,,,,, -27300,0.7376796,1.3056476,,,,,,,,,,,,,, -27400,0.82651156,1.3356062,,,,,,,,,,,,,, -27484,,,0.22562861,0.082642092235939,0.4992585,0.1501974376550778,5348.0,0.30364478,0.1011110434058456,2472.0,23106.38236856461,25393.70451760292,23106.38236856461,2285.3351743221283,0.8276486396789551,0.0 -27500,0.62668186,1.323312,,,,,,,,,,,,,, -27600,0.7117739,1.3609407,,,,,,,,,,,,,, -27700,0.70537865,1.2301209,,,,,,,,,,,,,, -27800,0.66998667,1.3203673,,,,,,,,,,,,,, -27900,0.7969268,1.280186,,,,,,,,,,,,,, -28000,0.7587093,1.3640412,,,,,,,,,,,,,, -28100,0.6398658,1.2632756,,,,,,,,,,,,,, -28200,0.72199225,1.3475226,,,,,,,,,,,,,, -28300,1.1093037,1.3644617,,,,,,,,,,,,,, -28400,0.6971881,1.3174942,,,,,,,,,,,,,, -28500,0.65430033,1.3398255,,,,,,,,,,,,,, -28600,0.70669156,1.3344057,,,,,,,,,,,,,, -28700,0.6733177,1.2707053,,,,,,,,,,,,,, -28800,0.84104776,1.3401209,,,,,,,,,,,,,, -28900,0.6443919,1.315418,,,,,,,,,,,,,, -29000,0.61910087,1.29879,,,,,,,,,,,,,, -29100,0.6885299,1.2649475,,,,,,,,,,,,,, -29193,,,0.23682073,0.0862042815045935,0.48714328,0.1461907566351603,5348.0,0.29264167,0.0988564580667438,2472.0,24546.66773033142,26968.495937109,24546.66773033142,2419.71209859848,0.8837933540344238,0.0 -29200,0.7712755,1.3133379,,,,,,,,,,,,,, -29300,0.7264091,1.2623869,,,,,,,,,,,,,, -29400,0.767476,1.321614,,,,,,,,,,,,,, -29500,0.6380586,1.2838441,,,,,,,,,,,,,, -29600,0.762157,1.2562425,,,,,,,,,,,,,, -29700,0.6821499,1.3447663,,,,,,,,,,,,,, -29800,0.6364841,1.2736888,,,,,,,,,,,,,, -29900,0.9088386,1.2624395,,,,,,,,,,,,,, -30000,0.7726925,1.3000442,,,,,,,,,,,,,, -30100,0.69727373,1.3091214,,,,,,,,,,,,,, -30200,0.65750784,1.2884374,,,,,,,,,,,,,, -30300,0.77528614,1.3198541,,,,,,,,,,,,,, -30400,0.6707005,1.3591197,,,,,,,,,,,,,, -30500,0.72705346,1.3368665,,,,,,,,,,,,,, -30600,0.69529146,1.2927227,,,,,,,,,,,,,, -30700,0.85161316,1.3293957,,,,,,,,,,,,,, -30800,0.81555915,1.2491447,,,,,,,,,,,,,, -30862,,,0.22417934,0.0807763471315633,0.48772222,0.1471755312472846,5348.0,0.28424188,0.0946519610830134,2472.0,25986.74062180519,28542.845601081848,25986.74062180519,2553.862573862076,0.9391980171203612,0.0 -30900,0.77923006,1.2575136,,,,,,,,,,,,,, -31000,0.78551954,1.2884511,,,,,,,,,,,,,, -31100,0.6973568,1.2385246,,,,,,,,,,,,,, -31200,0.6317925,1.2530293,,,,,,,,,,,,,, -31300,0.7068032,1.3183745,,,,,,,,,,,,,, -31400,0.6654973,1.2942669,,,,,,,,,,,,,, -31500,0.68027914,1.2581484,,,,,,,,,,,,,, -31600,0.6799808,1.2770244,,,,,,,,,,,,,, -31700,0.6461001,1.2633444,,,,,,,,,,,,,, -31800,0.77279204,1.2590209,,,,,,,,,,,,,, -31900,0.8162062,1.3208084,,,,,,,,,,,,,, -32000,0.83528566,1.2728342,,,,,,,,,,,,,, -32100,0.69050664,1.3056977,,,,,,,,,,,,,, -32200,0.8139385,1.2495466,,,,,,,,,,,,,, -32300,0.7457748,1.2164435,,,,,,,,,,,,,, -32400,0.77542174,1.2787852,,,,,,,,,,,,,, -32500,0.81301886,1.2910819,,,,,,,,,,,,,, -32543,,,0.22964644,0.0826993512064067,0.47304323,0.1424254419417438,5348.0,0.27928787,0.0940426136940669,2472.0,27426.87792801857,30118.500452518463,27426.87792801857,2689.2480068206787,0.999396562576294,0.0 -32600,0.67421174,1.2985922,,,,,,,,,,,,,, -32700,0.71038103,1.2580689,,,,,,,,,,,,,, -32800,0.7492985,1.282615,,,,,,,,,,,,,, -32900,0.74353284,1.3103178,,,,,,,,,,,,,, -33000,0.8188803,1.2311205,,,,,,,,,,,,,, -33100,0.704264,1.232032,,,,,,,,,,,,,, -33200,0.6173429,1.209061,,,,,,,,,,,,,, -33300,0.7165188,1.255232,,,,,,,,,,,,,, -33400,0.7293122,1.2547756,,,,,,,,,,,,,, -33500,0.84969497,1.2270072,,,,,,,,,,,,,, -33600,0.6657056,1.3066754,,,,,,,,,,,,,, -33700,0.7357403,1.2497176,,,,,,,,,,,,,, -33800,0.76052433,1.2806417,,,,,,,,,,,,,, -33900,1.0802714,1.260197,,,,,,,,,,,,,, -34000,0.7505127,1.2682112,,,,,,,,,,,,,, -34100,0.6804598,1.1924919,,,,,,,,,,,,,, -34200,0.6848476,1.2995006,,,,,,,,,,,,,, -34241,,,0.22303934,0.0796327674584652,0.4615741,0.1392587157380499,5348.0,0.27463698,0.0918489630938598,2472.0,28866.8210606575,31696.89235520363,28866.8210606575,2827.5657880306244,1.057802677154541,0.0 -34300,0.7193836,1.2524736,,,,,,,,,,,,,, -34400,0.70989895,1.2860979,,,,,,,,,,,,,, -34500,0.7963191,1.2259371,,,,,,,,,,,,,, -34600,0.7568306,1.2185181,,,,,,,,,,,,,, -34700,0.8723297,1.2389096,,,,,,,,,,,,,, -34800,0.68993515,1.3073533,,,,,,,,,,,,,, -34900,0.68644166,1.321035,,,,,,,,,,,,,, -35000,0.7643071,1.2739513,,,,,,,,,,,,,, -35100,0.6991389,1.2125659,,,,,,,,,,,,,, -35200,0.7355853,1.2609068,,,,,,,,,,,,,, -35300,0.82586634,1.2806354,,,,,,,,,,,,,, -35400,0.7333495,1.1772132,,,,,,,,,,,,,, -35500,0.70796126,1.2822484,,,,,,,,,,,,,, -35600,0.76340675,1.2745552,,,,,,,,,,,,,, -35700,1.0478466,1.289695,,,,,,,,,,,,,, -35800,0.8640685,1.2451388,,,,,,,,,,,,,, -35900,0.72967756,1.235068,,,,,,,,,,,,,, -35914,,,0.17427287,0.0660188801082377,0.45529997,0.1377525898606833,5348.0,0.26810637,0.0904271525196514,2472.0,30306.81025290489,33272.61613845825,30306.81025290489,2963.170545816421,1.1170177459716797,0.0 -36000,0.7718379,1.2316959,,,,,,,,,,,,,, -36100,0.64672595,1.2470331,,,,,,,,,,,,,, -36200,0.7158055,1.1668773,,,,,,,,,,,,,, -36300,0.7007555,1.223153,,,,,,,,,,,,,, -36400,0.6627181,1.1832198,,,,,,,,,,,,,, -36500,0.8505196,1.298041,,,,,,,,,,,,,, -36600,0.65608966,1.2028424,,,,,,,,,,,,,, -36700,0.6846243,1.217978,,,,,,,,,,,,,, -36800,0.69155383,1.1963992,,,,,,,,,,,,,, -36900,0.73977566,1.263731,,,,,,,,,,,,,, -37000,0.7525657,1.2089695,,,,,,,,,,,,,, -37100,0.81824857,1.2027909,,,,,,,,,,,,,, -37200,0.70824176,1.2052431,,,,,,,,,,,,,, -37300,0.7564023,1.2164983,,,,,,,,,,,,,, -37400,0.676105,1.2420219,,,,,,,,,,,,,, -37500,0.77040094,1.2463224,,,,,,,,,,,,,, -37600,0.7968494,1.2200058,,,,,,,,,,,,,, -37620,,,0.19530179,0.0711178222212836,0.45110688,0.1353485812487328,5348.0,0.2616381,0.0897771819714419,2472.0,31746.826417922974,34848.603251457214,31746.826417922974,3099.013829946518,1.1702604293823242,0.0 -37700,0.76301765,1.2081985,,,,,,,,,,,,,, -37800,0.690854,1.264179,,,,,,,,,,,,,, -37900,0.6664843,1.2127084,,,,,,,,,,,,,, -38000,0.84593356,1.2111108,,,,,,,,,,,,,, -38100,0.7030793,1.2490914,,,,,,,,,,,,,, -38200,0.68868226,1.2069899,,,,,,,,,,,,,, -38300,0.90510553,1.2290845,,,,,,,,,,,,,, -38400,0.68822175,1.2242967,,,,,,,,,,,,,, -38500,0.7036268,1.2569116,,,,,,,,,,,,,, -38600,0.7869445,1.2557346,,,,,,,,,,,,,, -38700,0.84972364,1.2328299,,,,,,,,,,,,,, -38800,0.7132479,1.2141709,,,,,,,,,,,,,, -38900,0.8058344,1.2916062,,,,,,,,,,,,,, -39000,0.6514719,1.2441269,,,,,,,,,,,,,, -39100,0.6776603,1.214105,,,,,,,,,,,,,, -39200,0.68089795,1.2402093,,,,,,,,,,,,,, -39300,0.7127027,1.2242062,,,,,,,,,,,,,, -39313,,,0.2397092,0.0867527649257172,0.4422507,0.1319984166368981,5348.0,0.25463516,0.0848820912802388,2472.0,33187.17974424362,36424.66038656235,33187.17974424362,3234.586481332779,1.2274115085601809,0.0 -39400,0.8229604,1.2152512,,,,,,,,,,,,,, -39500,0.694861,1.2268084,,,,,,,,,,,,,, -39600,0.6636973,1.2052897,,,,,,,,,,,,,, -39700,0.7830393,1.2071326,,,,,,,,,,,,,, -39800,0.80184865,1.2406646,,,,,,,,,,,,,, -39900,0.71599054,1.1925156,,,,,,,,,,,,,, -40000,0.8456893,1.2246156,,,,,,,,,,,,,, -40100,0.74566066,1.20089,,,,,,,,,,,,,, -40200,0.7495653,1.1714991,,,,,,,,,,,,,, -40300,0.829436,1.2019656,,,,,,,,,,,,,, -40400,0.8155006,1.2191092,,,,,,,,,,,,,, -40500,0.75581,1.2132462,,,,,,,,,,,,,, -40600,0.69174,1.1307098,,,,,,,,,,,,,, -40700,0.93372935,1.2429394,,,,,,,,,,,,,, -40800,0.7607271,1.2234579,,,,,,,,,,,,,, -40900,0.8814649,1.1678036,,,,,,,,,,,,,, -41000,,,0.24868596,0.0896817422295253,0.43414676,0.1291309846780655,5348.0,0.25042498,0.0831352954319257,2472.0,34627.428320646286,37997.81482315064,34627.428320646286,3367.3648619651794,1.2825305461883545,0.0 -41000,0.88051784,1.2034899,,,,,,,,,,,,,, -41100,0.69917667,1.1744772,,,,,,,,,,,,,, -41200,1.046176,1.1732249,,,,,,,,,,,,,, -41300,0.70398843,1.221615,,,,,,,,,,,,,, -41400,0.76194316,1.1120641,,,,,,,,,,,,,, -41500,0.67939216,1.2423693,,,,,,,,,,,,,, -41600,0.7281695,1.1637343,,,,,,,,,,,,,, -41700,0.8889751,1.2371058,,,,,,,,,,,,,, -41800,0.70550144,1.2370659,,,,,,,,,,,,,, -41900,0.92563444,1.207524,,,,,,,,,,,,,, -42000,0.99357533,1.1909618,,,,,,,,,,,,,, -42100,0.7196744,1.1729013,,,,,,,,,,,,,, -42200,0.700216,1.1519617,,,,,,,,,,,,,, -42300,0.79538155,1.118507,,,,,,,,,,,,,, -42400,0.8634991,1.1674684,,,,,,,,,,,,,, -42500,0.8085238,1.2429076,,,,,,,,,,,,,, -42600,0.7330839,1.152032,,,,,,,,,,,,,, -42700,0.68043506,1.1939379,,,,,,,,,,,,,, -42715,,,0.28319895,0.1032108615919181,0.42378053,0.1277986425557797,5348.0,0.23770487,0.0797635732130887,2472.0,36067.43534851074,39570.2938041687,36067.43534851074,3499.707387447357,1.3385326862335205,0.0 -42800,0.69923407,1.1969633,,,,,,,,,,,,,, -42900,0.68754953,1.183289,,,,,,,,,,,,,, -43000,0.7436788,1.1885818,,,,,,,,,,,,,, -43100,0.69360566,1.2125796,,,,,,,,,,,,,, -43200,0.91347814,1.2116395,,,,,,,,,,,,,, -43300,0.85701,1.1582437,,,,,,,,,,,,,, -43400,0.7731314,1.1716973,,,,,,,,,,,,,, -43500,0.7343838,1.1850482,,,,,,,,,,,,,, -43600,0.80246204,1.1497043,,,,,,,,,,,,,, -43700,0.74339426,1.1374003,,,,,,,,,,,,,, -43800,0.61043704,1.1322328,,,,,,,,,,,,,, -43900,0.8110908,1.112261,,,,,,,,,,,,,, -44000,0.83025694,1.1695745,,,,,,,,,,,,,, -44100,0.7561555,1.1560968,,,,,,,,,,,,,, -44200,0.8799511,1.163604,,,,,,,,,,,,,, -44300,0.75914913,1.1209819,,,,,,,,,,,,,, -44400,0.8129188,1.1593486,,,,,,,,,,,,,, -44418,,,0.2481551,0.0890143598137168,0.41698194,0.1253367060254689,5348.0,0.24019764,0.0809010216724554,2472.0,37507.68049240112,41143.12516140938,37507.68049240112,3632.162237882614,1.3978519439697266,0.0 -44500,0.7261423,1.1397482,,,,,,,,,,,,,, -44600,0.7268536,1.1669261,,,,,,,,,,,,,, -44700,0.80778617,1.2076569,,,,,,,,,,,,,, -44800,0.9180664,1.2015492,,,,,,,,,,,,,, -44900,0.6759551,1.1802491,,,,,,,,,,,,,, -45000,0.7420583,1.1857127,,,,,,,,,,,,,, -45100,0.7923753,1.1504896,,,,,,,,,,,,,, -45200,0.6510319,1.1502715,,,,,,,,,,,,,, -45300,0.7834931,1.2086028,,,,,,,,,,,,,, -45400,0.8149974,1.1117504,,,,,,,,,,,,,, -45500,0.91990596,1.0834713,,,,,,,,,,,,,, -45600,0.79879695,1.1625065,,,,,,,,,,,,,, -45700,0.75968385,1.1045635,,,,,,,,,,,,,, -45800,0.6935701,1.1350763,,,,,,,,,,,,,, -45900,0.6733658,1.0996171,,,,,,,,,,,,,, -46000,0.7559227,1.1623884,,,,,,,,,,,,,, -46100,0.70829517,1.1232989,,,,,,,,,,,,,, -46117,,,0.22114661,0.0816975270840524,0.40540525,0.1197176979445243,5348.0,0.23128738,0.0772449373387768,2472.0,38948.10423493385,42716.81401634216,38948.10423493385,3765.296792268753,1.455432415008545,0.0 -46200,0.83314216,1.0930526,,,,,,,,,,,,,, -46300,1.1529976,1.1365176,,,,,,,,,,,,,, -46400,0.81885785,1.1964978,,,,,,,,,,,,,, -46500,0.7245957,1.0846953,,,,,,,,,,,,,, -46600,0.70814544,1.1025565,,,,,,,,,,,,,, -46700,0.74765795,1.1031014,,,,,,,,,,,,,, -46800,0.8225291,1.1338079,,,,,,,,,,,,,, -46900,0.74757564,1.1499224,,,,,,,,,,,,,, -47000,0.82568914,1.1514275,,,,,,,,,,,,,, -47100,0.76748854,1.0840901,,,,,,,,,,,,,, -47200,0.75513303,1.1403675,,,,,,,,,,,,,, -47300,0.7804188,1.1308985,,,,,,,,,,,,,, -47400,0.8539159,1.1029098,,,,,,,,,,,,,, -47500,0.94586074,1.1429071,,,,,,,,,,,,,, -47600,0.7762182,1.1502001,,,,,,,,,,,,,, -47700,0.78801453,1.1166859,,,,,,,,,,,,,, -47800,0.82344335,1.0912608,,,,,,,,,,,,,, -47828,,,0.18673953,0.0712133900314304,0.3940457,0.1176612568427353,5348.0,0.22929518,0.0757825036053053,2472.0,40389.05780625343,44294.21702575684,40389.05780625343,3901.612589597702,1.514012098312378,0.0 -47900,0.74394006,1.1191773,,,,,,,,,,,,,, -48000,0.7705349,1.132191,,,,,,,,,,,,,, -48100,0.7482879,1.1420674,,,,,,,,,,,,,, -48200,0.7859356,1.142155,,,,,,,,,,,,,, -48300,1.0932049,1.1383792,,,,,,,,,,,,,, -48400,0.79393035,1.114203,,,,,,,,,,,,,, -48500,0.80591744,1.130164,,,,,,,,,,,,,, -48600,0.81955385,1.0522693,,,,,,,,,,,,,, -48700,0.89032793,1.1000531,,,,,,,,,,,,,, -48800,0.8279148,1.0475398,,,,,,,,,,,,,, -48900,0.8171042,1.0929841,,,,,,,,,,,,,, -49000,0.9503799,1.1566513,,,,,,,,,,,,,, -49100,0.80221105,1.1663291,,,,,,,,,,,,,, -49200,0.92831546,1.1209877,,,,,,,,,,,,,, -49300,0.8293706,1.0965772,,,,,,,,,,,,,, -49400,1.1092343,1.1528864,,,,,,,,,,,,,, -49500,0.81967145,1.0648024,,,,,,,,,,,,,, -49511,,,0.20579855,0.0762191313604197,0.39688087,0.1164158065979899,5348.0,0.2227393,0.0738529035403083,2472.0,41829.449006557465,45870.42122077942,41829.449006557465,4037.291281223297,1.5761985778808594,0.0 -49600,1.2680261,1.1346153,,,,,,,,,,,,,, -49700,0.7722906,1.0615184,,,,,,,,,,,,,, -49800,0.7330689,1.0660189,,,,,,,,,,,,,, -49900,0.7900686,1.0717349,,,,,,,,,,,,,, -50000,0.88749576,1.1824068,,,,,,,,,,,,,, -50100,0.89371485,1.1198164,,,,,,,,,,,,,, -50200,0.7853887,1.0883644,,,,,,,,,,,,,, -50300,0.92039686,1.107888,,,,,,,,,,,,,, -50400,0.73784703,1.1020855,,,,,,,,,,,,,, -50500,0.84234273,1.0848491,,,,,,,,,,,,,, -50600,0.87646157,1.1070607,,,,,,,,,,,,,, -50700,0.85896045,1.0694294,,,,,,,,,,,,,, -50800,0.9384646,1.116012,,,,,,,,,,,,,, -50900,0.77382016,1.0466684,,,,,,,,,,,,,, -51000,0.78903574,1.0892721,,,,,,,,,,,,,, -51100,0.81751335,1.0951504,,,,,,,,,,,,,, -51200,0.8883863,1.0998214,,,,,,,,,,,,,, -51202,,,0.18011282,0.0690227194640096,0.38041675,0.112331888353592,5348.0,0.21779981,0.0727154550809416,2472.0,43269.72302722931,47447.20176792145,43269.72302722931,4173.671866893768,1.6306617259979248,0.0 -51300,0.7540945,1.104456,,,,,,,,,,,,,, -51400,0.8012298,1.0829426,,,,,,,,,,,,,, -51500,0.7639917,1.0808932,,,,,,,,,,,,,, -51600,0.82691836,1.0715263,,,,,,,,,,,,,, -51700,0.83136374,1.051812,,,,,,,,,,,,,, -51800,0.7844783,1.0871623,,,,,,,,,,,,,, -51900,0.77613866,1.0552096,,,,,,,,,,,,,, -52000,0.785979,1.0461708,,,,,,,,,,,,,, -52100,0.92193204,1.0464784,,,,,,,,,,,,,, -52200,0.817265,1.0784032,,,,,,,,,,,,,, -52300,0.9143305,1.043628,,,,,,,,,,,,,, -52400,0.88974446,1.0712341,,,,,,,,,,,,,, -52500,1.0156363,1.1333698,,,,,,,,,,,,,, -52600,0.97346807,1.0475366,,,,,,,,,,,,,, -52700,0.78789914,1.0507243,,,,,,,,,,,,,, -52800,0.9278012,1.0318755,,,,,,,,,,,,,, -52900,0.922252,1.0574629,,,,,,,,,,,,,, -52918,,,0.17563458,0.0666475010577246,0.37358853,0.1102464832926228,5348.0,0.21114749,0.0694452907602624,2472.0,44709.526916742325,49023.30984520912,44709.526916742325,4309.752601146698,1.781308889389038,0.0 -53000,1.15516,1.0460466,,,,,,,,,,,,,, -53100,0.85639507,1.0777148,,,,,,,,,,,,,, -53200,0.83470505,1.0517826,,,,,,,,,,,,,, -53300,0.84053123,1.040362,,,,,,,,,,,,,, -53400,0.8300606,1.0920022,,,,,,,,,,,,,, -53500,0.77566797,1.0719192,,,,,,,,,,,,,, -53600,0.8050565,1.0748365,,,,,,,,,,,,,, -53700,0.970073,1.019551,,,,,,,,,,,,,, -53800,0.82197,1.0382081,,,,,,,,,,,,,, -53900,0.9450502,1.034149,,,,,,,,,,,,,, -54000,0.882698,1.080683,,,,,,,,,,,,,, -54100,1.0346545,1.1129961,,,,,,,,,,,,,, -54200,0.96684337,1.042392,,,,,,,,,,,,,, -54300,0.94162583,1.0670698,,,,,,,,,,,,,, -54400,0.8185338,1.0292674,,,,,,,,,,,,,, -54500,0.79954094,1.0500447,,,,,,,,,,,,,, -54594,,,0.16827415,0.0645134575569358,0.36485976,0.1085762283132355,5348.0,0.20493606,0.0703390002640505,2472.0,46149.56543898583,50598.354439258575,46149.56543898583,4444.613988637924,1.8514833450317385,0.0 -54600,0.8806687,1.0303435,,,,,,,,,,,,,, -54700,0.8568609,1.0844454,,,,,,,,,,,,,, -54800,0.91929233,1.1118929,,,,,,,,,,,,,, -54900,1.137407,1.0514233,,,,,,,,,,,,,, -55000,0.85496074,1.1156358,,,,,,,,,,,,,, -55100,0.80496794,1.0056815,,,,,,,,,,,,,, -55200,0.9293919,1.0598135,,,,,,,,,,,,,, -55300,0.84525514,0.99651736,,,,,,,,,,,,,, -55400,0.9679677,1.0313896,,,,,,,,,,,,,, -55500,0.8408825,1.0870007,,,,,,,,,,,,,, -55600,1.2487694,1.0425344,,,,,,,,,,,,,, -55700,0.78270507,1.0378656,,,,,,,,,,,,,, -55800,0.99919665,1.0458853,,,,,,,,,,,,,, -55900,0.8772426,1.0340008,,,,,,,,,,,,,, -56000,1.1948583,1.0515958,,,,,,,,,,,,,, -56100,1.2254827,1.0038968,,,,,,,,,,,,,, -56200,1.0643924,0.981818,,,,,,,,,,,,,, -56300,0.93388957,1.0251743,,,,,,,,,,,,,, -56314,,,0.16584644,0.0628298020901296,0.35868764,0.1056122498238025,5348.0,0.19563222,0.0660532569617939,2472.0,47589.52333641052,52171.8772919178,47589.52333641052,4578.047208547592,1.909917116165161,0.0 -56400,1.0187218,1.0550755,,,,,,,,,,,,,, -56500,0.94634897,1.0181668,,,,,,,,,,,,,, -56600,2.0519702,1.0554354,,,,,,,,,,,,,, -56700,0.9985618,0.9951867,,,,,,,,,,,,,, -56800,0.98475355,0.98505574,,,,,,,,,,,,,, -56900,0.8607492,0.9896292,,,,,,,,,,,,,, -57000,0.92394626,1.0387223,,,,,,,,,,,,,, -57100,0.85579795,1.0361719,,,,,,,,,,,,,, -57200,0.9023914,0.98675466,,,,,,,,,,,,,, -57300,1.0187452,1.0502414,,,,,,,,,,,,,, -57400,0.9674642,0.9807892,,,,,,,,,,,,,, -57500,0.864701,1.0132235,,,,,,,,,,,,,, -57600,0.8267489,1.0465716,,,,,,,,,,,,,, -57700,1.0306867,1.0249121,,,,,,,,,,,,,, -57800,0.8547436,1.0157942,,,,,,,,,,,,,, -57900,1.1393551,0.9923132,,,,,,,,,,,,,, -58000,0.9166301,0.9505,,,,,,,,,,,,,, -58024,,,0.16346063,0.0612073753798233,0.35139942,0.1030730760690114,5348.0,0.19396327,0.0637174253041659,2472.0,49029.40422821045,53745.149267435074,49029.40422821045,4711.303059339523,1.9719979763031008,0.0 -58100,0.8829891,1.0218631,,,,,,,,,,,,,, -58200,0.9979568,1.0252826,,,,,,,,,,,,,, -58300,0.86171067,1.0100434,,,,,,,,,,,,,, -58400,1.1204964,0.9706481,,,,,,,,,,,,,, -58500,1.0554582,0.9692442,,,,,,,,,,,,,, -58600,0.86386675,0.9929895,,,,,,,,,,,,,, -58700,0.994331,1.006627,,,,,,,,,,,,,, -58800,1.6085485,0.9765048,,,,,,,,,,,,,, -58900,1.0972594,0.9994896,,,,,,,,,,,,,, -59000,0.8496107,1.0192397,,,,,,,,,,,,,, -59100,0.972081,1.0048836,,,,,,,,,,,,,, -59200,0.89739174,1.0053201,,,,,,,,,,,,,, -59300,0.8986013,0.98579675,,,,,,,,,,,,,, -59400,0.9480394,1.0156268,,,,,,,,,,,,,, -59500,0.8993761,1.01418,,,,,,,,,,,,,, -59600,1.1568174,0.9818138,,,,,,,,,,,,,, -59700,0.9418277,0.992097,,,,,,,,,,,,,, -59706,,,0.13584337,0.0522171408074965,0.3399327,0.099790494028597,5348.0,0.18521993,0.0618893831373265,2472.0,50469.57779741287,55321.613298892975,50469.57779741287,4847.475305557251,2.020612955093384,0.0 -59800,0.9205809,0.9933954,,,,,,,,,,,,,, -59900,0.96702623,0.9832436,,,,,,,,,,,,,, -60000,0.9334758,1.016871,,,,,,,,,,,,,, -60100,1.2135285,1.0457457,,,,,,,,,,,,,, -60200,0.94700855,1.0090808,,,,,,,,,,,,,, -60300,0.88058394,1.0068913,,,,,,,,,,,,,, -60400,1.1726155,1.0403686,,,,,,,,,,,,,, -60500,0.99250853,0.9881513,,,,,,,,,,,,,, -60600,0.90840846,1.0231782,,,,,,,,,,,,,, -60700,0.9902139,1.020046,,,,,,,,,,,,,, -60800,0.9011922,0.96403366,,,,,,,,,,,,,, -60900,1.0810778,0.95134526,,,,,,,,,,,,,, -61000,0.91631687,0.98403287,,,,,,,,,,,,,, -61100,0.9235507,1.01353,,,,,,,,,,,,,, -61200,0.92186856,0.99774885,,,,,,,,,,,,,, -61300,1.0194548,0.97973424,,,,,,,,,,,,,, -61398,,,0.13234085,0.0504312902195244,0.33066174,0.0966816957432634,5348.0,0.18194032,0.059980094651961,2472.0,51910.06654167175,56897.38406395912,51910.06654167175,4982.617788314819,2.086501359939575,0.0 -61400,0.9876745,0.9441144,,,,,,,,,,,,,, -61500,0.87979126,1.0091808,,,,,,,,,,,,,, -61600,1.1191708,0.9481761,,,,,,,,,,,,,, -61700,0.9210057,0.9848827,,,,,,,,,,,,,, -61800,1.0266136,0.95678586,,,,,,,,,,,,,, -61900,0.8762882,0.9453976,,,,,,,,,,,,,, -62000,0.9394555,0.9532498,,,,,,,,,,,,,, -62100,1.0280273,1.0180652,,,,,,,,,,,,,, -62200,0.8556576,0.938468,,,,,,,,,,,,,, -62300,1.0219814,0.9565694,,,,,,,,,,,,,, -62400,1.0389256,0.9999363,,,,,,,,,,,,,, -62500,0.98654056,0.957512,,,,,,,,,,,,,, -62600,0.94200706,0.9754311,,,,,,,,,,,,,, -62700,0.9554139,0.94278955,,,,,,,,,,,,,, -62800,1.3400074,0.98919815,,,,,,,,,,,,,, -62900,1.3567216,0.98369837,,,,,,,,,,,,,, -63000,1.0296079,0.9442222,,,,,,,,,,,,,, -63100,1.1623846,0.95858604,,,,,,,,,,,,,, -63107,,,0.11835707,0.0454961389099674,0.3270045,0.094084594070112,5348.0,0.18186228,0.0592691893648569,2472.0,53350.15077781677,58474.11357855797,53350.15077781677,5119.124755144119,2.150652885437012,0.0 -63200,1.0519112,0.9728085,,,,,,,,,,,,,, -63300,1.0447224,0.9512702,,,,,,,,,,,,,, -63400,1.1640955,0.97138745,,,,,,,,,,,,,, -63500,1.1004025,0.9980609,,,,,,,,,,,,,, -63600,1.209348,0.9710101,,,,,,,,,,,,,, -63700,1.0153426,0.9585148,,,,,,,,,,,,,, -63800,1.1797775,0.92341727,,,,,,,,,,,,,, -63900,0.96608055,0.967705,,,,,,,,,,,,,, -64000,1.1030753,0.9484977,,,,,,,,,,,,,, -64100,1.6534485,0.93684554,,,,,,,,,,,,,, -64200,1.0232399,0.9517622,,,,,,,,,,,,,, -64300,1.1322695,0.9529215,,,,,,,,,,,,,, -64400,1.1374695,0.9118763,,,,,,,,,,,,,, -64500,1.062128,0.9640387,,,,,,,,,,,,,, -64600,1.0112821,1.0014738,,,,,,,,,,,,,, -64700,1.3663876,0.9393968,,,,,,,,,,,,,, -64784,,,0.114238106,0.0437322432061581,0.31845886,0.0914874923969607,5348.0,0.17583282,0.0571161619239128,2472.0,54790.40292072296,60049.44139790535,54790.40292072296,5254.06840133667,2.2103593349456787,0.0 -64800,1.0245441,0.9235665,,,,,,,,,,,,,, -64900,1.1118112,0.92376524,,,,,,,,,,,,,, -65000,1.1368512,0.9544832,,,,,,,,,,,,,, -65100,1.0779444,0.9760084,,,,,,,,,,,,,, -65200,1.0539286,0.93747836,,,,,,,,,,,,,, -65300,1.0909477,0.9306168,,,,,,,,,,,,,, -65400,1.1889269,0.9308849,,,,,,,,,,,,,, -65500,1.0177242,0.9165942,,,,,,,,,,,,,, -65600,1.0500052,0.92954415,,,,,,,,,,,,,, -65700,1.176867,0.9649693,,,,,,,,,,,,,, -65800,1.2129158,0.9191377,,,,,,,,,,,,,, -65900,0.99053025,0.90398866,,,,,,,,,,,,,, -66000,1.2005105,0.94492793,,,,,,,,,,,,,, -66100,1.0984038,0.92304945,,,,,,,,,,,,,, -66200,1.0415006,0.940561,,,,,,,,,,,,,, -66300,0.9789381,0.90619206,,,,,,,,,,,,,, -66400,1.1280928,0.94511193,,,,,,,,,,,,,, -66491,,,0.105487265,0.0400259385970796,0.31323752,0.0898655106828736,5348.0,0.17403618,0.0551662502792842,2472.0,56230.28919124603,61624.050307273865,56230.28919124603,5388.650343894959,2.275513172149658,0.0 -66500,1.3149676,0.90123695,,,,,,,,,,,,,, -66600,0.9834498,0.8966993,,,,,,,,,,,,,, -66700,1.338875,0.9563768,,,,,,,,,,,,,, -66800,1.0740384,0.92571664,,,,,,,,,,,,,, -66900,1.2324227,0.89920425,,,,,,,,,,,,,, -67000,1.0464182,0.91067064,,,,,,,,,,,,,, -67100,1.0153555,0.91869533,,,,,,,,,,,,,, -67200,1.5843872,0.90499234,,,,,,,,,,,,,, -67300,1.5207292,0.8798313,,,,,,,,,,,,,, -67400,1.0134885,0.91800845,,,,,,,,,,,,,, -67500,1.0154684,0.89368844,,,,,,,,,,,,,, -67600,1.0024942,0.9029597,,,,,,,,,,,,,, -67700,1.1415308,0.9186275,,,,,,,,,,,,,, -67800,1.2543343,0.9495928,,,,,,,,,,,,,, -67900,1.4294491,0.91935927,,,,,,,,,,,,,, -68000,1.2825718,0.8990597,,,,,,,,,,,,,, -68100,1.1589744,0.9163807,,,,,,,,,,,,,, -68195,,,0.099455096,0.0379664168778986,0.30718902,0.0880118172953454,5348.0,0.16796601,0.0546381492088639,2472.0,57670.17176914215,63200.69050574303,57670.17176914215,5525.277727842331,2.334338903427124,0.0 -68200,1.0492101,0.9015856,,,,,,,,,,,,,, -68300,1.1756614,0.8894092,,,,,,,,,,,,,, -68400,1.2473031,0.8801529,,,,,,,,,,,,,, -68500,1.2229769,0.89472514,,,,,,,,,,,,,, -68600,1.2659159,0.9172504,,,,,,,,,,,,,, -68700,1.3643398,0.86355406,,,,,,,,,,,,,, -68800,1.2711647,0.8674519,,,,,,,,,,,,,, -68900,1.0698347,0.86499655,,,,,,,,,,,,,, -69000,1.0329494,0.87258554,,,,,,,,,,,,,, -69100,1.2048652,0.9462891,,,,,,,,,,,,,, -69200,1.4137146,0.9068899,,,,,,,,,,,,,, -69300,1.080373,0.9224217,,,,,,,,,,,,,, -69400,1.108445,0.8568283,,,,,,,,,,,,,, -69500,1.1072903,0.933364,,,,,,,,,,,,,, -69600,1.051838,0.8658438,,,,,,,,,,,,,, -69700,1.1904719,0.887897,,,,,,,,,,,,,, -69800,1.0855438,0.87445647,,,,,,,,,,,,,, -69881,,,0.08041966,0.030858299457803,0.30663946,0.0876352858260038,5348.0,0.16488022,0.0529116649401824,2472.0,59110.58864855766,64777.60359311104,59110.58864855766,5661.63925409317,2.3968987464904785,0.0 -69900,1.1053042,0.85888135,,,,,,,,,,,,,, -70000,1.0591165,0.89513385,,,,,,,,,,,,,, -70100,1.124354,0.89444435,,,,,,,,,,,,,, -70200,1.0789748,0.88597155,,,,,,,,,,,,,, -70300,1.2905236,0.89825904,,,,,,,,,,,,,, -70400,1.2438085,0.85583204,,,,,,,,,,,,,, -70500,1.085541,0.8491242,,,,,,,,,,,,,, -70600,1.2141032,0.8973155,,,,,,,,,,,,,, -70700,1.0486642,0.91287255,,,,,,,,,,,,,, -70800,1.2321107,0.92138296,,,,,,,,,,,,,, -70900,1.1550182,0.88299674,,,,,,,,,,,,,, -71000,1.1921747,0.87508076,,,,,,,,,,,,,, -71100,1.2584609,0.8710226,,,,,,,,,,,,,, -71200,1.311889,0.8580136,,,,,,,,,,,,,, -71300,1.3109095,0.9061311,,,,,,,,,,,,,, -71400,1.1748669,0.85541224,,,,,,,,,,,,,, -71500,1.2039218,0.8924937,,,,,,,,,,,,,, -71588,,,0.07414443,0.0280425707861095,0.2998368,0.0848933643569518,5348.0,0.16116796,0.051896085958605,2472.0,60550.75520849228,66351.57162237167,60550.75520849228,5795.302344799042,2.4600815773010254,0.0 -71600,1.0715926,0.82243073,,,,,,,,,,,,,, -71700,1.3208086,0.8949372,,,,,,,,,,,,,, -71800,1.1175965,0.80199224,,,,,,,,,,,,,, -71900,1.0719097,0.88192695,,,,,,,,,,,,,, -72000,1.2724043,0.9083863,,,,,,,,,,,,,, -72100,1.2252376,0.86369556,,,,,,,,,,,,,, -72200,1.0747806,0.85956997,,,,,,,,,,,,,, -72206,,,,,,,,,,,61068.42524552345,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 92902e256..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -117.74000239372252,0.0,35.15062928199768,1,0,35.15062928199768,30.208836,2472,0.9085978916580344,152.89069986343384,31.28297,0.9402553710331124,30.141817,5348,0.9043706614402812 -232.7037012577057,0.030081033706665,1475.9704265594482,1672,0,1475.9704265594482,6.379735,2472,0.8667357260374139,1708.7733316421509,6.5428987,0.8976848691695108,6.410177,5348,0.8668623343020169 -365.6510236263275,0.0797476768493652,2916.499334096909,3369,0,2916.499334096909,2.5972764,2472,0.532935226372555,3282.37225317955,2.9755747,0.6129117945215133,2.9443042,5348,0.5842320206223389 -499.9926204681397,0.1351144313812255,4356.664489269257,5038,0,4356.664489269257,0.6751439,2472,0.2215790222005565,4857.005820035934,0.84934884,0.2720468094502097,0.9519117,5348,0.2816745030267337 -636.2041306495667,0.1858932971954345,5796.641132116318,6718,0,5796.641132116318,0.47832748,2472,0.1601161822354924,6433.316517829895,0.58703554,0.2014097413937189,0.7285059,5348,0.2218735819728318 -770.8168325424194,0.2430033683776855,7236.699079275131,8411,0,7236.699079275131,0.42482418,2472,0.1440903459062011,8008.11719751358,0.49299544,0.1700506915739706,0.6641883,5348,0.2000347567510161 -930.1696033477784,0.2932603359222412,8677.11043047905,10118,0,8677.11043047905,0.3799917,2472,0.1304003412345378,9608.002731323242,0.29567704,0.1097389899676925,0.60095006,5348,0.1838632128754453 -1066.4482853412628,0.3423874378204345,10117.716738939283,11832,0,10117.716738939283,0.3490801,2472,0.1201836166798692,11185.009489297869,0.2633589,0.0978113489841189,0.5699334,5348,0.1738223736930013 -1201.1347353458405,0.3927774429321289,11558.25578045845,13501,0,11558.25578045845,0.33060178,2472,0.1119371153494607,12760.356534957886,0.25779513,0.0963556119347762,0.5493172,5348,0.1670737712040317 -1337.5061967372894,0.443948745727539,12998.907264947891,15177,0,12998.907264947891,0.3165861,2472,0.1080169804805719,14337.500767707825,0.22716044,0.0864846815260038,0.5236656,5348,0.1590024812458364 -1474.2307941913605,0.4952020645141601,14439.008511543274,16862,0,14439.008511543274,0.3099842,2472,0.1052342940710499,15914.451095581057,0.22985189,0.0873535887743841,0.5217149,5348,0.1570329320215878 -1610.9025394916534,0.5452065467834473,15879.444328546524,18526,0,15879.444328546524,0.29305145,2472,0.0995470517742164,17491.679450511932,0.2079086,0.0806949683182493,0.49227136,5348,0.1503229481448584 -1746.945957183838,0.5986299514770508,17320.42205142975,20203,0,17320.42205142975,0.28144392,2472,0.0955253590071699,19068.82412481308,0.22827052,0.08320993218302,0.48092973,5348,0.1457949158596986 -1883.9370305538173,0.6513259410858154,18760.71983551979,21887,0,18760.71983551979,0.2813647,2472,0.095748786383117,20646.238900899887,0.2024693,0.0782813825275657,0.4742581,5348,0.1438736399007501 -2023.754251241684,0.7084658145904541,20201.21184659004,23551,0,20201.21184659004,0.2655093,2472,0.091300550443808,22226.676994800568,0.16744016,0.0667004249365727,0.45430538,5348,0.1394228448400706 -2161.89958691597,0.764528751373291,21641.502986192703,25245,0,21641.502986192703,0.2632382,2472,0.0904271525196514,23805.24548983574,0.16750503,0.0649266793314691,0.45686898,5348,0.1384573795340664 -2295.1566026210785,0.8176376819610596,23081.486287117004,26921,0,23081.486287117004,0.25737813,2472,0.088781914569496,25378.6113049984,0.15699868,0.0647888841016803,0.44483903,5348,0.1347886113712503 -2430.638184070587,0.8726804256439209,24521.77783679962,28601,0,24521.77783679962,0.24619368,2472,0.0839883817764507,26954.51037287712,0.16138557,0.062552889186055,0.43052354,5348,0.1311584618206744 -2567.4729528427124,0.9269707202911376,25962.238719701767,30300,0,25962.238719701767,0.24171406,2472,0.0827696869985578,28531.932410240173,0.15786675,0.0623297506176489,0.42404762,5348,0.1281558647190013 -2703.1411702632904,1.0707290172576904,27402.67057681084,31982,0,27402.67057681084,0.24244688,2472,0.0822009627688745,30108.24808359146,0.1505317,0.0602347318274219,0.42554152,5348,0.1288896183515645 -2840.8364021778107,1.121816635131836,28843.26565551757,33698,0,28843.26565551757,0.23038326,2472,0.0779355310462494,31686.66285228729,0.14140327,0.0557607767315864,0.4030115,5348,0.1216003552912326 -2976.497559785843,1.1860380172729492,30283.51073741913,35403,0,30283.51073741913,0.22733217,2472,0.0772246257591453,33262.70490074158,0.12408025,0.0507194035988632,0.41467035,5348,0.1220734332911746 -3113.3532407283783,1.2495365142822266,31723.40802645684,37064,0,31723.40802645684,0.22285151,2472,0.0752950256941482,34839.59212565422,0.12518206,0.0486296714584302,0.3949641,5348,0.1198238991281848 -3248.6328341960907,1.309312343597412,33163.48185658455,38724,0,33163.48185658455,0.21985793,2472,0.0746247435663071,36415.07633471489,0.1323574,0.0523750270504219,0.40408602,5348,0.1195149502302634 -3383.161405324936,1.3682467937469482,34603.7865755558,40411,0,34603.7865755558,0.21655618,2472,0.0717608108382588,37990.041987895966,0.120666884,0.0473931067543348,0.39027813,5348,0.1158558367205074 -3521.0500218868256,1.421659231185913,36043.68631386757,42087,0,36043.68631386757,0.21440364,2472,0.0719436150549428,39567.956351041794,0.11153517,0.0455627372524879,0.3821042,5348,0.1134518281085569 -3658.994877815247,1.4797418117523191,37483.879881858826,43777,0,37483.879881858826,0.20935404,2472,0.0694249791806308,41146.22732448578,0.1301486,0.0486029122392758,0.38293865,5348,0.1129111675371945 -3798.062462329865,1.5324275493621826,38923.82216382027,45457,0,38923.82216382027,0.2068121,2472,0.0691203054861576,42725.361545324326,0.08889699,0.035825625038456,0.37530112,5348,0.1112312579047472 -3936.06809425354,1.5900306701660156,40364.56965065002,47141,0,40364.56965065002,0.19967549,2472,0.0666219811914772,44304.246056079865,0.094712876,0.0381311850779195,0.3694132,5348,0.1085665736601755 -4071.692668676376,1.653782606124878,41804.55584001541,48834,0,41804.55584001541,0.1959807,2472,0.0662563727581094,45879.99431490898,0.11606108,0.0464308483742607,0.36678696,5348,0.108219006150014 -4207.30445432663,1.7124838829040527,43244.480036735535,50487,0,43244.480036735535,0.19380894,2472,0.0641033453171653,47455.66095995903,0.11821188,0.0462240558060097,0.3563446,5348,0.1046081659055581 -4341.327328681946,1.7806909084320068,44685.18374609947,52173,0,44685.18374609947,0.19001366,2472,0.0631080779152194,49030.528040885925,0.13081229,0.0529921570447373,0.36061352,5348,0.1045019647218977 -4476.6968767642975,1.838486671447754,46125.39481854439,53875,0,46125.39481854439,0.19238628,2472,0.0637783600430605,50606.23990535736,0.110869765,0.0435027585397347,0.35066342,5348,0.1025517248037691 -4612.31334400177,1.8981482982635496,47565.87292122841,55536,0,47565.87292122841,0.18675168,2472,0.0619706294558527,52182.46559405327,0.09881861,0.0399136879843109,0.3454917,5348,0.0991919055388744 -4747.195596456528,1.9564628601074217,49005.80888533592,57214,0,49005.80888533592,0.18420044,2472,0.0593910588426461,53757.41523528099,0.081867896,0.0328155161319974,0.34621996,5348,0.0992884520694748 -4883.638375282288,2.0142009258270264,50446.31366991997,58908,0,50446.31366991997,0.18286462,2472,0.0597972904352771,55334.49369978905,0.09207707,0.0386608058288431,0.34705952,5348,0.099771184722477 -5019.637094259262,2.0751595497131348,51886.75729846954,60577,0,51886.75729846954,0.17879401,2472,0.0577051977332277,56911.06728410721,0.07713567,0.030381777350016,0.34140152,5348,0.0971451190901454 -5157.233544111252,2.1355743408203125,53327.07553648949,62265,0,53327.07553648949,0.17833206,2472,0.0573192777202283,58489.11463499069,0.07505095,0.0310540394767666,0.33734754,5348,0.0952624617434372 -5294.676733493805,2.2014431953430176,54766.94156455994,63929,0,54766.94156455994,0.17739047,2472,0.0574614587776491,60066.56161522865,0.06894256,0.0279005010031024,0.33377695,5348,0.093379804396729 -5428.8484699726105,2.258314847946167,56207.575248003006,65615,0,56207.575248003006,0.17490257,2472,0.0573395892998598,61641.49627780914,0.07285346,0.0289643687731488,0.33438018,5348,0.0934763509273294 -5564.061456441879,2.3223886489868164,57647.53215622902,67319,0,57647.53215622902,0.1737761,2472,0.0559787134645461,63216.80353283882,0.07366558,0.0289878081494336,0.33114108,5348,0.0924239937437848 -5699.908368587494,2.387371301651001,59088.42539978027,68991,0,59088.42539978027,0.17034464,2472,0.0554506123941258,64793.68052625656,0.0607663,0.0242742533464907,0.32713333,5348,0.0918640238663023 -5834.580038309097,2.4519057273864746,60528.64407491684,70651,0,60528.64407491684,0.17073686,2472,0.0549021997440741,66368.70748090744,0.063274354,0.02436748237245956,0.3291911,5348,0.09122681676433958 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/measurements.csv deleted file mode 100644 index e7d6504d7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/measurements.csv +++ /dev/null @@ -1,759 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,55.65534,31.496857,,,,,,,,,,,,,, -1,,,31.28297,0.9402553710331124,30.141817,0.9043706614402812,5348.0,30.208836,0.9085978916580344,2472.0,35.15062928199768,152.89069986343384,35.15062928199768,117.74000239372252,0.0,0.0 -100,3.5460367,6.660453,,,,,,,,,,,,,, -200,0.63467115,5.9230337,,,,,,,,,,,,,, -300,0.5973869,5.840604,,,,,,,,,,,,,, -400,0.7159481,5.8158545,,,,,,,,,,,,,, -500,3.586852,5.823735,,,,,,,,,,,,,, -600,2.58799,5.8174133,,,,,,,,,,,,,, -700,4.7324595,5.807006,,,,,,,,,,,,,, -800,2.8516448,5.7919765,,,,,,,,,,,,,, -900,0.9616946,5.786746,,,,,,,,,,,,,, -1000,2.6366029,5.780922,,,,,,,,,,,,,, -1100,0.41061834,5.738054,,,,,,,,,,,,,, -1200,1.6531086,5.6274524,,,,,,,,,,,,,, -1300,0.9649554,5.495645,,,,,,,,,,,,,, -1400,2.478767,5.2962184,,,,,,,,,,,,,, -1500,2.087988,4.5219917,,,,,,,,,,,,,, -1600,1.2401243,3.985641,,,,,,,,,,,,,, -1672,,,6.5428987,0.8976848691695108,6.410177,0.8668623343020169,5348.0,6.379735,0.8667357260374139,2472.0,1475.9704265594482,1708.7733316421509,1475.9704265594482,232.7037012577057,0.030081033706665,0.0 -1700,1.2578862,3.6984205,,,,,,,,,,,,,, -1800,1.0317254,3.4495447,,,,,,,,,,,,,, -1900,1.126558,3.294468,,,,,,,,,,,,,, -2000,1.1344782,3.174609,,,,,,,,,,,,,, -2100,1.1077406,3.0779524,,,,,,,,,,,,,, -2200,1.2198116,2.9980273,,,,,,,,,,,,,, -2300,1.0519574,2.8807623,,,,,,,,,,,,,, -2400,1.2090063,2.8306582,,,,,,,,,,,,,, -2500,0.8764725,2.751259,,,,,,,,,,,,,, -2600,1.3351263,2.7005806,,,,,,,,,,,,,, -2700,1.0923162,2.6340835,,,,,,,,,,,,,, -2800,0.94355845,2.5550418,,,,,,,,,,,,,, -2900,1.0261002,2.5016544,,,,,,,,,,,,,, -3000,0.81180024,2.4879365,,,,,,,,,,,,,, -3100,1.2587835,2.4356225,,,,,,,,,,,,,, -3200,1.5813141,2.3579366,,,,,,,,,,,,,, -3300,0.850811,2.3146813,,,,,,,,,,,,,, -3369,,,2.9755747,0.6129117945215133,2.9443042,0.5842320206223389,5348.0,2.5972764,0.532935226372555,2472.0,2916.499334096909,3282.37225317955,2916.499334096909,365.6510236263275,0.0797476768493652,0.0 -3400,1.1862497,2.2642512,,,,,,,,,,,,,, -3500,0.8934999,2.2892594,,,,,,,,,,,,,, -3600,0.76360446,2.1820655,,,,,,,,,,,,,, -3700,1.0168115,2.1660695,,,,,,,,,,,,,, -3800,0.9896475,2.0566769,,,,,,,,,,,,,, -3900,0.9699648,2.1182768,,,,,,,,,,,,,, -4000,0.92488927,2.0648625,,,,,,,,,,,,,, -4100,0.70064807,1.9836016,,,,,,,,,,,,,, -4200,0.83219784,2.004149,,,,,,,,,,,,,, -4300,0.8046751,2.0217352,,,,,,,,,,,,,, -4400,0.77761936,2.010346,,,,,,,,,,,,,, -4500,0.9127341,1.9149112,,,,,,,,,,,,,, -4600,0.91729873,1.871301,,,,,,,,,,,,,, -4700,0.788374,1.9078677,,,,,,,,,,,,,, -4800,0.62050563,1.8635247,,,,,,,,,,,,,, -4900,0.6936256,1.8640975,,,,,,,,,,,,,, -5000,0.8446352,1.8538053,,,,,,,,,,,,,, -5038,,,0.84934884,0.2720468094502097,0.9519117,0.2816745030267337,5348.0,0.6751439,0.2215790222005565,2472.0,4356.664489269257,4857.005820035934,4356.664489269257,499.9926204681397,0.1351144313812255,0.0 -5100,0.6766945,1.8245921,,,,,,,,,,,,,, -5200,0.9053873,1.7853582,,,,,,,,,,,,,, -5300,0.64330745,1.8111768,,,,,,,,,,,,,, -5400,0.8223378,1.7890978,,,,,,,,,,,,,, -5500,0.66640025,1.7177796,,,,,,,,,,,,,, -5600,0.97442174,1.7998974,,,,,,,,,,,,,, -5700,0.8645282,1.7352675,,,,,,,,,,,,,, -5800,0.6106753,1.691118,,,,,,,,,,,,,, -5900,0.650081,1.7509849,,,,,,,,,,,,,, -6000,0.7340776,1.7113643,,,,,,,,,,,,,, -6100,0.8798186,1.7189333,,,,,,,,,,,,,, -6200,0.7470706,1.6369526,,,,,,,,,,,,,, -6300,0.7169226,1.7050018,,,,,,,,,,,,,, -6400,0.58552444,1.6444604,,,,,,,,,,,,,, -6500,0.69274306,1.6729848,,,,,,,,,,,,,, -6600,0.6177419,1.6821475,,,,,,,,,,,,,, -6700,0.6226477,1.6610866,,,,,,,,,,,,,, -6718,,,0.58703554,0.2014097413937189,0.7285059,0.2218735819728318,5348.0,0.47832748,0.1601161822354924,2472.0,5796.641132116318,6433.316517829895,5796.641132116318,636.2041306495667,0.1858932971954345,0.0 -6800,0.6310816,1.6315587,,,,,,,,,,,,,, -6900,0.6191093,1.5968645,,,,,,,,,,,,,, -7000,0.5995303,1.6341879,,,,,,,,,,,,,, -7100,0.7727753,1.6035329,,,,,,,,,,,,,, -7200,0.65745145,1.591284,,,,,,,,,,,,,, -7300,0.59274393,1.6147314,,,,,,,,,,,,,, -7400,0.70083994,1.6230382,,,,,,,,,,,,,, -7500,0.58585584,1.5960802,,,,,,,,,,,,,, -7600,0.63220745,1.6541637,,,,,,,,,,,,,, -7700,0.6276093,1.6252707,,,,,,,,,,,,,, -7800,0.83877903,1.5889275,,,,,,,,,,,,,, -7900,0.7855565,1.5887717,,,,,,,,,,,,,, -8000,0.9163538,1.6198282,,,,,,,,,,,,,, -8100,0.6222086,1.4925594,,,,,,,,,,,,,, -8200,0.7476115,1.5883794,,,,,,,,,,,,,, -8300,0.8765327,1.523773,,,,,,,,,,,,,, -8400,0.79717916,1.4737582,,,,,,,,,,,,,, -8411,,,0.49299544,0.1700506915739706,0.6641883,0.2000347567510161,5348.0,0.42482418,0.1440903459062011,2472.0,7236.699079275131,8008.11719751358,7236.699079275131,770.8168325424194,0.2430033683776855,0.0 -8500,0.5783225,1.5099348,,,,,,,,,,,,,, -8600,0.7642483,1.5061133,,,,,,,,,,,,,, -8700,0.7261298,1.5285932,,,,,,,,,,,,,, -8800,0.57980406,1.4956536,,,,,,,,,,,,,, -8900,0.5987759,1.5359657,,,,,,,,,,,,,, -9000,0.7515243,1.5476224,,,,,,,,,,,,,, -9100,0.57447916,1.5203238,,,,,,,,,,,,,, -9200,0.5871458,1.4990442,,,,,,,,,,,,,, -9300,0.6473755,1.4592754,,,,,,,,,,,,,, -9400,0.68317235,1.5484189,,,,,,,,,,,,,, -9500,0.6591369,1.5072801,,,,,,,,,,,,,, -9600,0.63970995,1.5009631,,,,,,,,,,,,,, -9700,0.67291284,1.4383215,,,,,,,,,,,,,, -9800,0.6710104,1.4904565,,,,,,,,,,,,,, -9900,0.5853493,1.5213066,,,,,,,,,,,,,, -10000,0.5668193,1.476167,,,,,,,,,,,,,, -10100,0.66655785,1.5072322,,,,,,,,,,,,,, -10118,,,0.29567704,0.1097389899676925,0.60095006,0.1838632128754453,5348.0,0.3799917,0.1304003412345378,2472.0,8677.11043047905,9608.002731323242,8677.11043047905,930.1696033477784,0.2932603359222412,0.0 -10200,0.7626423,1.4945351,,,,,,,,,,,,,, -10300,0.6241157,1.5081612,,,,,,,,,,,,,, -10400,0.5652436,1.4617248,,,,,,,,,,,,,, -10500,0.5805408,1.4469161,,,,,,,,,,,,,, -10600,0.68573785,1.4274435,,,,,,,,,,,,,, -10700,0.632974,1.4684049,,,,,,,,,,,,,, -10800,0.7082151,1.5304327,,,,,,,,,,,,,, -10900,0.65631336,1.4649652,,,,,,,,,,,,,, -11000,0.58690256,1.4498552,,,,,,,,,,,,,, -11100,0.52002716,1.4346122,,,,,,,,,,,,,, -11200,0.78269184,1.4216102,,,,,,,,,,,,,, -11300,0.5799982,1.4270382,,,,,,,,,,,,,, -11400,0.6461435,1.3841985,,,,,,,,,,,,,, -11500,0.5983232,1.4785842,,,,,,,,,,,,,, -11600,0.613635,1.4069406,,,,,,,,,,,,,, -11700,0.56952095,1.410047,,,,,,,,,,,,,, -11800,0.64298433,1.3812139,,,,,,,,,,,,,, -11832,,,0.2633589,0.0978113489841189,0.5699334,0.1738223736930013,5348.0,0.3490801,0.1201836166798692,2472.0,10117.716738939283,11185.009489297869,10117.716738939283,1066.4482853412628,0.3423874378204345,0.0 -11900,0.5965605,1.417431,,,,,,,,,,,,,, -12000,0.7976388,1.4541192,,,,,,,,,,,,,, -12100,0.52933776,1.3847505,,,,,,,,,,,,,, -12200,0.63091916,1.4206148,,,,,,,,,,,,,, -12300,0.7263533,1.4338994,,,,,,,,,,,,,, -12400,0.5266436,1.3813672,,,,,,,,,,,,,, -12500,0.59208137,1.3681934,,,,,,,,,,,,,, -12600,0.5407434,1.3478855,,,,,,,,,,,,,, -12700,0.583774,1.3958652,,,,,,,,,,,,,, -12800,0.53519225,1.4125887,,,,,,,,,,,,,, -12900,0.5632229,1.3772717,,,,,,,,,,,,,, -13000,0.62389946,1.3699812,,,,,,,,,,,,,, -13100,0.6310672,1.3887848,,,,,,,,,,,,,, -13200,0.56621855,1.3531082,,,,,,,,,,,,,, -13300,0.6055219,1.3716155,,,,,,,,,,,,,, -13400,0.6516232,1.3819195,,,,,,,,,,,,,, -13500,0.5454213,1.3831558,,,,,,,,,,,,,, -13501,,,0.25779513,0.0963556119347762,0.5493172,0.1670737712040317,5348.0,0.33060178,0.1119371153494607,2472.0,11558.25578045845,12760.356534957886,11558.25578045845,1201.1347353458405,0.3927774429321289,0.0 -13600,0.5476699,1.361814,,,,,,,,,,,,,, -13700,0.59967893,1.3776288,,,,,,,,,,,,,, -13800,0.671678,1.3749843,,,,,,,,,,,,,, -13900,0.60145813,1.3611002,,,,,,,,,,,,,, -14000,0.69167495,1.3614767,,,,,,,,,,,,,, -14100,0.85319644,1.3754181,,,,,,,,,,,,,, -14200,0.6539933,1.4022086,,,,,,,,,,,,,, -14300,0.6304278,1.3889016,,,,,,,,,,,,,, -14400,0.55435526,1.4247397,,,,,,,,,,,,,, -14500,0.6040876,1.3939384,,,,,,,,,,,,,, -14600,0.62343353,1.3689415,,,,,,,,,,,,,, -14700,0.77185595,1.3961359,,,,,,,,,,,,,, -14800,0.53939193,1.3036963,,,,,,,,,,,,,, -14900,0.58409476,1.3038367,,,,,,,,,,,,,, -15000,0.6277574,1.3452849,,,,,,,,,,,,,, -15100,0.5607586,1.3642845,,,,,,,,,,,,,, -15177,,,0.22716044,0.0864846815260038,0.5236656,0.1590024812458364,5348.0,0.3165861,0.1080169804805719,2472.0,12998.907264947891,14337.500767707825,12998.907264947891,1337.5061967372894,0.443948745727539,0.0 -15200,0.66547996,1.3843321,,,,,,,,,,,,,, -15300,0.62239516,1.3226368,,,,,,,,,,,,,, -15400,0.553594,1.3740004,,,,,,,,,,,,,, -15500,0.63493246,1.2912935,,,,,,,,,,,,,, -15600,0.6602118,1.3312826,,,,,,,,,,,,,, -15700,0.6988399,1.3589365,,,,,,,,,,,,,, -15800,0.7114914,1.2730293,,,,,,,,,,,,,, -15900,0.61077017,1.3341434,,,,,,,,,,,,,, -16000,0.5833273,1.2887193,,,,,,,,,,,,,, -16100,0.51466095,1.2835833,,,,,,,,,,,,,, -16200,0.70362246,1.3924646,,,,,,,,,,,,,, -16300,0.6555161,1.2915282,,,,,,,,,,,,,, -16400,0.55911905,1.3811188,,,,,,,,,,,,,, -16500,0.6834128,1.3415177,,,,,,,,,,,,,, -16600,0.70366216,1.306151,,,,,,,,,,,,,, -16700,0.6114095,1.3114053,,,,,,,,,,,,,, -16800,0.6325811,1.364156,,,,,,,,,,,,,, -16862,,,0.22985189,0.0873535887743841,0.5217149,0.1570329320215878,5348.0,0.3099842,0.1052342940710499,2472.0,14439.008511543274,15914.451095581057,14439.008511543274,1474.2307941913605,0.4952020645141601,0.0 -16900,0.72356486,1.3401258,,,,,,,,,,,,,, -17000,0.62826025,1.311433,,,,,,,,,,,,,, -17100,0.5859078,1.2782542,,,,,,,,,,,,,, -17200,0.55147463,1.3154132,,,,,,,,,,,,,, -17300,0.60992855,1.3103769,,,,,,,,,,,,,, -17400,0.6892987,1.3408211,,,,,,,,,,,,,, -17500,0.6129293,1.3197808,,,,,,,,,,,,,, -17600,0.5260965,1.2716851,,,,,,,,,,,,,, -17700,0.65497756,1.2987686,,,,,,,,,,,,,, -17800,0.69053656,1.2898481,,,,,,,,,,,,,, -17900,0.7192791,1.3512701,,,,,,,,,,,,,, -18000,0.6456443,1.2806324,,,,,,,,,,,,,, -18100,0.5259511,1.2226527,,,,,,,,,,,,,, -18200,0.5983411,1.270805,,,,,,,,,,,,,, -18300,0.63547146,1.2713313,,,,,,,,,,,,,, -18400,0.64628345,1.3221656,,,,,,,,,,,,,, -18500,0.5194967,1.2988937,,,,,,,,,,,,,, -18526,,,0.2079086,0.0806949683182493,0.49227136,0.1503229481448584,5348.0,0.29305145,0.0995470517742164,2472.0,15879.444328546524,17491.679450511932,15879.444328546524,1610.9025394916534,0.5452065467834473,0.0 -18600,0.727663,1.2752972,,,,,,,,,,,,,, -18700,0.519134,1.3227943,,,,,,,,,,,,,, -18800,0.5464832,1.3079715,,,,,,,,,,,,,, -18900,0.6717008,1.2708592,,,,,,,,,,,,,, -19000,0.55043143,1.2699635,,,,,,,,,,,,,, -19100,0.5493623,1.2970016,,,,,,,,,,,,,, -19200,0.62293595,1.3008294,,,,,,,,,,,,,, -19300,0.6077231,1.3126283,,,,,,,,,,,,,, -19400,0.5589188,1.2270902,,,,,,,,,,,,,, -19500,0.6224177,1.3301955,,,,,,,,,,,,,, -19600,0.6416088,1.2811036,,,,,,,,,,,,,, -19700,0.6916502,1.270365,,,,,,,,,,,,,, -19800,0.691069,1.3125712,,,,,,,,,,,,,, -19900,0.8167288,1.2717026,,,,,,,,,,,,,, -20000,0.5890329,1.2356873,,,,,,,,,,,,,, -20100,0.6133539,1.2064264,,,,,,,,,,,,,, -20200,0.6072641,1.2745733,,,,,,,,,,,,,, -20203,,,0.22827052,0.08320993218302,0.48092973,0.1457949158596986,5348.0,0.28144392,0.0955253590071699,2472.0,17320.42205142975,19068.82412481308,17320.42205142975,1746.945957183838,0.5986299514770508,0.0 -20300,0.67610556,1.2796859,,,,,,,,,,,,,, -20400,0.63172746,1.3027414,,,,,,,,,,,,,, -20500,0.56401163,1.3059698,,,,,,,,,,,,,, -20600,0.5880883,1.2570208,,,,,,,,,,,,,, -20700,0.59958756,1.2472085,,,,,,,,,,,,,, -20800,0.61117715,1.2368243,,,,,,,,,,,,,, -20900,0.6909385,1.1970546,,,,,,,,,,,,,, -21000,0.6315257,1.2442341,,,,,,,,,,,,,, -21100,0.54606,1.2333494,,,,,,,,,,,,,, -21200,0.60353744,1.2418553,,,,,,,,,,,,,, -21300,0.63157904,1.2393068,,,,,,,,,,,,,, -21400,0.76347834,1.2556525,,,,,,,,,,,,,, -21500,0.57219213,1.2596315,,,,,,,,,,,,,, -21600,0.57237643,1.3241484,,,,,,,,,,,,,, -21700,0.5619634,1.2601299,,,,,,,,,,,,,, -21800,0.67741233,1.2256962,,,,,,,,,,,,,, -21887,,,0.2024693,0.0782813825275657,0.4742581,0.1438736399007501,5348.0,0.2813647,0.095748786383117,2472.0,18760.71983551979,20646.238900899887,18760.71983551979,1883.9370305538173,0.6513259410858154,0.0 -21900,0.7047042,1.2476199,,,,,,,,,,,,,, -22000,0.6349834,1.2275587,,,,,,,,,,,,,, -22100,0.5707013,1.2209709,,,,,,,,,,,,,, -22200,0.57621074,1.2487543,,,,,,,,,,,,,, -22300,0.652597,1.1780349,,,,,,,,,,,,,, -22400,0.7395086,1.2336779,,,,,,,,,,,,,, -22500,0.6729703,1.2618923,,,,,,,,,,,,,, -22600,0.68707097,1.2357078,,,,,,,,,,,,,, -22700,0.6029684,1.2145468,,,,,,,,,,,,,, -22800,0.6117777,1.202252,,,,,,,,,,,,,, -22900,0.5933043,1.2180344,,,,,,,,,,,,,, -23000,0.54326,1.1940966,,,,,,,,,,,,,, -23100,0.68923414,1.2722851,,,,,,,,,,,,,, -23200,0.61532634,1.2356071,,,,,,,,,,,,,, -23300,0.64701325,1.1877352,,,,,,,,,,,,,, -23400,0.6824105,1.189693,,,,,,,,,,,,,, -23500,0.612169,1.246711,,,,,,,,,,,,,, -23551,,,0.16744016,0.0667004249365727,0.45430538,0.1394228448400706,5348.0,0.2655093,0.091300550443808,2472.0,20201.21184659004,22226.676994800568,20201.21184659004,2023.754251241684,0.7084658145904541,0.0 -23600,0.6821936,1.2193506,,,,,,,,,,,,,, -23700,0.6265399,1.1760387,,,,,,,,,,,,,, -23800,0.5894744,1.2954286,,,,,,,,,,,,,, -23900,0.76480824,1.2168516,,,,,,,,,,,,,, -24000,0.57429945,1.2498785,,,,,,,,,,,,,, -24100,0.64995867,1.1879025,,,,,,,,,,,,,, -24200,0.70021755,1.2234936,,,,,,,,,,,,,, -24300,0.5653169,1.2050344,,,,,,,,,,,,,, -24400,0.6051813,1.2410825,,,,,,,,,,,,,, -24500,0.5196281,1.2205833,,,,,,,,,,,,,, -24600,0.53914404,1.1875564,,,,,,,,,,,,,, -24700,0.57947946,1.2293165,,,,,,,,,,,,,, -24800,0.63284695,1.1450942,,,,,,,,,,,,,, -24900,0.59715444,1.2354392,,,,,,,,,,,,,, -25000,0.674155,1.1721573,,,,,,,,,,,,,, -25100,0.613595,1.1738671,,,,,,,,,,,,,, -25200,0.6240115,1.2194282,,,,,,,,,,,,,, -25245,,,0.16750503,0.0649266793314691,0.45686898,0.1384573795340664,5348.0,0.2632382,0.0904271525196514,2472.0,21641.502986192703,23805.24548983574,21641.502986192703,2161.89958691597,0.764528751373291,0.0 -25300,0.69046175,1.2777548,,,,,,,,,,,,,, -25400,0.8533684,1.2057186,,,,,,,,,,,,,, -25500,0.57769734,1.2455978,,,,,,,,,,,,,, -25600,0.5114425,1.233677,,,,,,,,,,,,,, -25700,0.60653675,1.20785,,,,,,,,,,,,,, -25800,0.64378303,1.223438,,,,,,,,,,,,,, -25900,0.61116284,1.1530671,,,,,,,,,,,,,, -26000,0.5223776,1.1863966,,,,,,,,,,,,,, -26100,0.5335659,1.1869334,,,,,,,,,,,,,, -26200,0.6686947,1.1987833,,,,,,,,,,,,,, -26300,0.5674954,1.2102916,,,,,,,,,,,,,, -26400,0.6634069,1.1582681,,,,,,,,,,,,,, -26500,0.5796547,1.192786,,,,,,,,,,,,,, -26600,0.6443898,1.258915,,,,,,,,,,,,,, -26700,0.61534464,1.2487152,,,,,,,,,,,,,, -26800,0.64293385,1.1964236,,,,,,,,,,,,,, -26900,0.6353956,1.2108504,,,,,,,,,,,,,, -26921,,,0.15699868,0.0647888841016803,0.44483903,0.1347886113712503,5348.0,0.25737813,0.088781914569496,2472.0,23081.486287117004,25378.6113049984,23081.486287117004,2295.1566026210785,0.8176376819610596,0.0 -27000,0.5121652,1.1427993,,,,,,,,,,,,,, -27100,0.5711171,1.1443841,,,,,,,,,,,,,, -27200,0.65071493,1.1955785,,,,,,,,,,,,,, -27300,0.55668247,1.1721199,,,,,,,,,,,,,, -27400,0.7332946,1.1466377,,,,,,,,,,,,,, -27500,0.5684212,1.2160251,,,,,,,,,,,,,, -27600,0.66306,1.1960537,,,,,,,,,,,,,, -27700,0.57005423,1.2371094,,,,,,,,,,,,,, -27800,0.6298046,1.191144,,,,,,,,,,,,,, -27900,0.65659755,1.173385,,,,,,,,,,,,,, -28000,0.63538766,1.2161567,,,,,,,,,,,,,, -28100,0.5919751,1.1967314,,,,,,,,,,,,,, -28200,0.9976849,1.2031549,,,,,,,,,,,,,, -28300,0.56818616,1.1706452,,,,,,,,,,,,,, -28400,0.64973253,1.1943781,,,,,,,,,,,,,, -28500,0.54009575,1.1641797,,,,,,,,,,,,,, -28600,0.66949695,1.1580898,,,,,,,,,,,,,, -28601,,,0.16138557,0.062552889186055,0.43052354,0.1311584618206744,5348.0,0.24619368,0.0839883817764507,2472.0,24521.77783679962,26954.51037287712,24521.77783679962,2430.638184070587,0.8726804256439209,0.0 -28700,0.58230895,1.1237613,,,,,,,,,,,,,, -28800,0.7020556,1.2235447,,,,,,,,,,,,,, -28900,0.5938143,1.1968373,,,,,,,,,,,,,, -29000,0.58055985,1.1715062,,,,,,,,,,,,,, -29100,0.5810806,1.1824194,,,,,,,,,,,,,, -29200,0.553701,1.1720859,,,,,,,,,,,,,, -29300,0.602626,1.1636757,,,,,,,,,,,,,, -29400,0.5607995,1.1414249,,,,,,,,,,,,,, -29500,0.6209973,1.1588705,,,,,,,,,,,,,, -29600,0.582861,1.1392083,,,,,,,,,,,,,, -29700,0.64018905,1.2177309,,,,,,,,,,,,,, -29800,0.71705586,1.1498282,,,,,,,,,,,,,, -29900,0.6195194,1.2186105,,,,,,,,,,,,,, -30000,0.71962523,1.1655798,,,,,,,,,,,,,, -30100,0.554867,1.1652359,,,,,,,,,,,,,, -30200,0.7075187,1.1820238,,,,,,,,,,,,,, -30300,,,0.15786675,0.0623297506176489,0.42404762,0.1281558647190013,5348.0,0.24171406,0.0827696869985578,2472.0,25962.238719701767,28531.932410240173,25962.238719701767,2567.4729528427124,0.9269707202911376,0.0 -30300,0.6430073,1.1711458,,,,,,,,,,,,,, -30400,0.7698446,1.2142531,,,,,,,,,,,,,, -30500,0.6738613,1.1530533,,,,,,,,,,,,,, -30600,0.8400936,1.1858144,,,,,,,,,,,,,, -30700,0.6192139,1.1738961,,,,,,,,,,,,,, -30800,0.76909965,1.1417258,,,,,,,,,,,,,, -30900,0.5951845,1.1624061,,,,,,,,,,,,,, -31000,0.5968161,1.1317344,,,,,,,,,,,,,, -31100,0.6881333,1.1602162,,,,,,,,,,,,,, -31200,0.53339434,1.1520002,,,,,,,,,,,,,, -31300,0.6338372,1.2054751,,,,,,,,,,,,,, -31400,0.66626,1.198939,,,,,,,,,,,,,, -31500,0.5312983,1.2045865,,,,,,,,,,,,,, -31600,0.6009537,1.1691183,,,,,,,,,,,,,, -31700,0.68535274,1.113796,,,,,,,,,,,,,, -31800,0.6444939,1.1459911,,,,,,,,,,,,,, -31900,0.5050573,1.1744349,,,,,,,,,,,,,, -31982,,,0.1505317,0.0602347318274219,0.42554152,0.1288896183515645,5348.0,0.24244688,0.0822009627688745,2472.0,27402.67057681084,30108.24808359146,27402.67057681084,2703.1411702632904,1.0707290172576904,0.0 -32000,0.67888886,1.1101221,,,,,,,,,,,,,, -32100,0.5807746,1.1370897,,,,,,,,,,,,,, -32200,0.6307081,1.1609776,,,,,,,,,,,,,, -32300,0.58973193,1.1875145,,,,,,,,,,,,,, -32400,0.5943667,1.167517,,,,,,,,,,,,,, -32500,0.5510779,1.1206177,,,,,,,,,,,,,, -32600,0.7136931,1.1664524,,,,,,,,,,,,,, -32700,0.60633624,1.1385721,,,,,,,,,,,,,, -32800,0.57245123,1.1510823,,,,,,,,,,,,,, -32900,0.7431682,1.1283448,,,,,,,,,,,,,, -33000,0.700413,1.112183,,,,,,,,,,,,,, -33100,0.5931646,1.1101427,,,,,,,,,,,,,, -33200,0.5937588,1.0786146,,,,,,,,,,,,,, -33300,0.64139,1.1692722,,,,,,,,,,,,,, -33400,0.63654333,1.1253122,,,,,,,,,,,,,, -33500,0.55604845,1.0788187,,,,,,,,,,,,,, -33600,0.56398386,1.1636965,,,,,,,,,,,,,, -33698,,,0.14140327,0.0557607767315864,0.4030115,0.1216003552912326,5348.0,0.23038326,0.0779355310462494,2472.0,28843.26565551757,31686.66285228729,28843.26565551757,2840.8364021778107,1.121816635131836,0.0 -33700,0.6602458,1.1573125,,,,,,,,,,,,,, -33800,0.7124999,1.160516,,,,,,,,,,,,,, -33900,0.5971715,1.1247259,,,,,,,,,,,,,, -34000,0.59912807,1.0999489,,,,,,,,,,,,,, -34100,0.5751119,1.1446509,,,,,,,,,,,,,, -34200,0.66666216,1.1362208,,,,,,,,,,,,,, -34300,0.6802127,1.0965284,,,,,,,,,,,,,, -34400,0.60137326,1.1532263,,,,,,,,,,,,,, -34500,0.6545143,1.1562283,,,,,,,,,,,,,, -34600,0.70080274,1.0820811,,,,,,,,,,,,,, -34700,0.66100186,1.0818045,,,,,,,,,,,,,, -34800,0.56504744,1.1330957,,,,,,,,,,,,,, -34900,0.6328332,1.1284496,,,,,,,,,,,,,, -35000,0.79151034,1.157974,,,,,,,,,,,,,, -35100,0.73652965,1.1488777,,,,,,,,,,,,,, -35200,0.63654035,1.1418388,,,,,,,,,,,,,, -35300,0.64782006,1.1083376,,,,,,,,,,,,,, -35400,0.5746977,1.0796843,,,,,,,,,,,,,, -35403,,,0.12408025,0.0507194035988632,0.41467035,0.1220734332911746,5348.0,0.22733217,0.0772246257591453,2472.0,30283.51073741913,33262.70490074158,30283.51073741913,2976.497559785843,1.1860380172729492,0.0 -35500,0.78594625,1.1527704,,,,,,,,,,,,,, -35600,0.63200796,1.1338226,,,,,,,,,,,,,, -35700,0.60396564,1.0990841,,,,,,,,,,,,,, -35800,0.58125824,1.1636388,,,,,,,,,,,,,, -35900,0.6065663,1.1125481,,,,,,,,,,,,,, -36000,0.72710323,1.1463366,,,,,,,,,,,,,, -36100,0.7364337,1.1048203,,,,,,,,,,,,,, -36200,0.6857127,1.1111766,,,,,,,,,,,,,, -36300,0.51842535,1.0984397,,,,,,,,,,,,,, -36400,0.5946273,1.1083606,,,,,,,,,,,,,, -36500,1.1351986,1.1033658,,,,,,,,,,,,,, -36600,0.78683263,1.1017826,,,,,,,,,,,,,, -36700,0.5824834,1.1197245,,,,,,,,,,,,,, -36800,0.694488,1.0978552,,,,,,,,,,,,,, -36900,0.6331714,1.1716914,,,,,,,,,,,,,, -37000,0.6137942,1.0499216,,,,,,,,,,,,,, -37064,,,0.12518206,0.0486296714584302,0.3949641,0.1198238991281848,5348.0,0.22285151,0.0752950256941482,2472.0,31723.40802645684,34839.59212565422,31723.40802645684,3113.3532407283783,1.2495365142822266,0.0 -37100,0.5813903,1.0629605,,,,,,,,,,,,,, -37200,0.5651438,1.062028,,,,,,,,,,,,,, -37300,0.6918192,1.1353043,,,,,,,,,,,,,, -37400,0.65518576,1.0570716,,,,,,,,,,,,,, -37500,0.5710714,1.0861841,,,,,,,,,,,,,, -37600,0.6165958,1.1118153,,,,,,,,,,,,,, -37700,0.6293094,1.0603526,,,,,,,,,,,,,, -37800,0.64886516,1.13051,,,,,,,,,,,,,, -37900,0.514008,1.0410856,,,,,,,,,,,,,, -38000,0.5370148,1.0404932,,,,,,,,,,,,,, -38100,0.6826493,1.0836085,,,,,,,,,,,,,, -38200,0.66906726,1.0912914,,,,,,,,,,,,,, -38300,0.7821987,1.1216819,,,,,,,,,,,,,, -38400,0.7175939,1.152831,,,,,,,,,,,,,, -38500,0.6246448,1.1469604,,,,,,,,,,,,,, -38600,0.5282299,1.0888969,,,,,,,,,,,,,, -38700,0.6356921,1.0895025,,,,,,,,,,,,,, -38724,,,0.1323574,0.0523750270504219,0.40408602,0.1195149502302634,5348.0,0.21985793,0.0746247435663071,2472.0,33163.48185658455,36415.07633471489,33163.48185658455,3248.6328341960907,1.309312343597412,0.0 -38800,0.6682677,1.0437866,,,,,,,,,,,,,, -38900,0.71031785,1.1472162,,,,,,,,,,,,,, -39000,0.6160686,1.127379,,,,,,,,,,,,,, -39100,0.67663324,1.0788835,,,,,,,,,,,,,, -39200,0.64907724,1.0584813,,,,,,,,,,,,,, -39300,0.5791553,1.0802447,,,,,,,,,,,,,, -39400,0.5917748,1.1185303,,,,,,,,,,,,,, -39500,0.69264084,1.114262,,,,,,,,,,,,,, -39600,0.6412035,1.0736493,,,,,,,,,,,,,, -39700,0.6509288,1.0307262,,,,,,,,,,,,,, -39800,0.6750567,1.1009269,,,,,,,,,,,,,, -39900,0.6591741,1.0355853,,,,,,,,,,,,,, -40000,0.6334091,1.0884315,,,,,,,,,,,,,, -40100,0.5577739,1.0684859,,,,,,,,,,,,,, -40200,0.7259226,1.087579,,,,,,,,,,,,,, -40300,0.63830024,1.0874871,,,,,,,,,,,,,, -40400,0.742903,1.0925173,,,,,,,,,,,,,, -40411,,,0.120666884,0.0473931067543348,0.39027813,0.1158558367205074,5348.0,0.21655618,0.0717608108382588,2472.0,34603.7865755558,37990.041987895966,34603.7865755558,3383.161405324936,1.3682467937469482,0.0 -40500,0.79504836,1.0446494,,,,,,,,,,,,,, -40600,0.72980064,1.0433719,,,,,,,,,,,,,, -40700,0.6805598,1.0492609,,,,,,,,,,,,,, -40800,0.77261525,1.1323152,,,,,,,,,,,,,, -40900,0.66675323,1.1445794,,,,,,,,,,,,,, -41000,0.741924,1.0532101,,,,,,,,,,,,,, -41100,0.60705394,1.0915562,,,,,,,,,,,,,, -41200,0.7134365,1.0436121,,,,,,,,,,,,,, -41300,0.6669555,1.0457267,,,,,,,,,,,,,, -41400,0.67173404,1.044759,,,,,,,,,,,,,, -41500,0.627519,1.0379146,,,,,,,,,,,,,, -41600,0.7561915,1.062876,,,,,,,,,,,,,, -41700,0.7673497,1.070692,,,,,,,,,,,,,, -41800,0.7217858,1.0482637,,,,,,,,,,,,,, -41900,0.62906516,1.0660211,,,,,,,,,,,,,, -42000,0.5998377,1.0146646,,,,,,,,,,,,,, -42087,,,0.11153517,0.0455627372524879,0.3821042,0.1134518281085569,5348.0,0.21440364,0.0719436150549428,2472.0,36043.68631386757,39567.956351041794,36043.68631386757,3521.0500218868256,1.421659231185913,0.0 -42100,0.6407231,1.0456845,,,,,,,,,,,,,, -42200,0.73490584,1.0153065,,,,,,,,,,,,,, -42300,0.72896457,1.0785433,,,,,,,,,,,,,, -42400,0.6272899,1.0632204,,,,,,,,,,,,,, -42500,0.68709517,1.0735795,,,,,,,,,,,,,, -42600,0.62849474,1.0438868,,,,,,,,,,,,,, -42700,0.61006457,1.0503614,,,,,,,,,,,,,, -42800,0.6533434,1.0612897,,,,,,,,,,,,,, -42900,0.5437226,1.0486268,,,,,,,,,,,,,, -43000,0.6646526,1.0403712,,,,,,,,,,,,,, -43100,0.62061715,1.0464758,,,,,,,,,,,,,, -43200,0.7575076,1.0793698,,,,,,,,,,,,,, -43300,0.721459,1.1070011,,,,,,,,,,,,,, -43400,0.6002256,0.9966881,,,,,,,,,,,,,, -43500,0.5959831,1.0555972,,,,,,,,,,,,,, -43600,0.6490993,1.0541646,,,,,,,,,,,,,, -43700,0.75451505,1.0445689,,,,,,,,,,,,,, -43777,,,0.1301486,0.0486029122392758,0.38293865,0.1129111675371945,5348.0,0.20935404,0.0694249791806308,2472.0,37483.879881858826,41146.22732448578,37483.879881858826,3658.994877815247,1.4797418117523191,0.0 -43800,0.6376418,0.99576974,,,,,,,,,,,,,, -43900,0.72244555,1.0587119,,,,,,,,,,,,,, -44000,0.5662587,0.99918294,,,,,,,,,,,,,, -44100,0.68487567,1.0792265,,,,,,,,,,,,,, -44200,0.64388204,1.0661469,,,,,,,,,,,,,, -44300,0.64356244,1.0512807,,,,,,,,,,,,,, -44400,0.6506553,1.0098282,,,,,,,,,,,,,, -44500,0.68751997,1.0740958,,,,,,,,,,,,,, -44600,0.8154671,1.0596783,,,,,,,,,,,,,, -44700,0.63552064,1.0032914,,,,,,,,,,,,,, -44800,0.79831225,0.9927091,,,,,,,,,,,,,, -44900,0.6209727,1.0599642,,,,,,,,,,,,,, -45000,0.6676062,1.069678,,,,,,,,,,,,,, -45100,0.7098215,1.0821711,,,,,,,,,,,,,, -45200,0.8966735,1.0390625,,,,,,,,,,,,,, -45300,0.7100824,0.99325705,,,,,,,,,,,,,, -45400,0.6022477,1.0062445,,,,,,,,,,,,,, -45457,,,0.08889699,0.035825625038456,0.37530112,0.1112312579047472,5348.0,0.2068121,0.0691203054861576,2472.0,38923.82216382027,42725.361545324326,38923.82216382027,3798.062462329865,1.5324275493621826,0.0 -45500,0.54472923,0.9966872,,,,,,,,,,,,,, -45600,0.79924035,1.0320405,,,,,,,,,,,,,, -45700,0.8170862,1.077717,,,,,,,,,,,,,, -45800,0.60322726,0.9993349,,,,,,,,,,,,,, -45900,0.6015059,1.0350862,,,,,,,,,,,,,, -46000,0.6867329,1.0385059,,,,,,,,,,,,,, -46100,0.90816635,1.0662874,,,,,,,,,,,,,, -46200,0.8658246,1.0166101,,,,,,,,,,,,,, -46300,0.6278811,1.0378773,,,,,,,,,,,,,, -46400,0.74830383,1.0267698,,,,,,,,,,,,,, -46500,0.63663954,1.0091351,,,,,,,,,,,,,, -46600,0.6090489,1.0198326,,,,,,,,,,,,,, -46700,0.7174384,0.97556585,,,,,,,,,,,,,, -46800,0.8158925,1.0501857,,,,,,,,,,,,,, -46900,0.6732949,1.0500121,,,,,,,,,,,,,, -47000,0.6730135,1.0427852,,,,,,,,,,,,,, -47100,0.6178435,0.99331886,,,,,,,,,,,,,, -47141,,,0.094712876,0.0381311850779195,0.3694132,0.1085665736601755,5348.0,0.19967549,0.0666219811914772,2472.0,40364.56965065002,44304.246056079865,40364.56965065002,3936.06809425354,1.5900306701660156,0.0 -47200,0.6054558,1.0538986,,,,,,,,,,,,,, -47300,0.70494246,1.0466503,,,,,,,,,,,,,, -47400,0.6733564,1.0389715,,,,,,,,,,,,,, -47500,0.73549765,1.0425923,,,,,,,,,,,,,, -47600,0.80645657,1.0164824,,,,,,,,,,,,,, -47700,0.56596005,0.9725338,,,,,,,,,,,,,, -47800,0.6558816,1.0036409,,,,,,,,,,,,,, -47900,0.541457,1.0416555,,,,,,,,,,,,,, -48000,0.719761,1.0386666,,,,,,,,,,,,,, -48100,0.7605032,1.0033146,,,,,,,,,,,,,, -48200,0.8950533,1.034025,,,,,,,,,,,,,, -48300,0.778925,1.0472794,,,,,,,,,,,,,, -48400,0.86208624,1.0820838,,,,,,,,,,,,,, -48500,0.6935178,1.0237439,,,,,,,,,,,,,, -48600,0.76198024,1.0209875,,,,,,,,,,,,,, -48700,0.62150687,0.9827876,,,,,,,,,,,,,, -48800,0.5977343,0.9599541,,,,,,,,,,,,,, -48834,,,0.11606108,0.0464308483742607,0.36678696,0.108219006150014,5348.0,0.1959807,0.0662563727581094,2472.0,41804.55584001541,45879.99431490898,41804.55584001541,4071.692668676376,1.653782606124878,0.0 -48900,0.8183634,0.99329746,,,,,,,,,,,,,, -49000,0.7214743,1.0304363,,,,,,,,,,,,,, -49100,0.6416349,1.01687,,,,,,,,,,,,,, -49200,0.78633225,0.99519295,,,,,,,,,,,,,, -49300,0.68114805,1.0220649,,,,,,,,,,,,,, -49400,0.79807293,0.9648691,,,,,,,,,,,,,, -49500,0.68819386,0.9630741,,,,,,,,,,,,,, -49600,0.6325177,1.0183318,,,,,,,,,,,,,, -49700,0.6538073,0.9676338,,,,,,,,,,,,,, -49800,0.72967756,0.9373606,,,,,,,,,,,,,, -49900,0.7183679,0.97795546,,,,,,,,,,,,,, -50000,0.6930508,0.98946637,,,,,,,,,,,,,, -50100,0.60403335,0.97056973,,,,,,,,,,,,,, -50200,0.729398,1.0012865,,,,,,,,,,,,,, -50300,0.71221316,1.0356457,,,,,,,,,,,,,, -50400,0.5893537,1.0397357,,,,,,,,,,,,,, -50487,,,0.11821188,0.0462240558060097,0.3563446,0.1046081659055581,5348.0,0.19380894,0.0641033453171653,2472.0,43244.480036735535,47455.66095995903,43244.480036735535,4207.30445432663,1.7124838829040527,0.0 -50500,0.7476564,0.9363188,,,,,,,,,,,,,, -50600,0.74591625,0.9781539,,,,,,,,,,,,,, -50700,0.7544404,0.9209946,,,,,,,,,,,,,, -50800,0.7352888,0.9777348,,,,,,,,,,,,,, -50900,0.6379102,0.9909736,,,,,,,,,,,,,, -51000,0.6661021,0.9881593,,,,,,,,,,,,,, -51100,0.838324,0.9606061,,,,,,,,,,,,,, -51200,0.71415865,1.0373802,,,,,,,,,,,,,, -51300,0.6354701,1.0076815,,,,,,,,,,,,,, -51400,0.65454984,0.96764565,,,,,,,,,,,,,, -51500,0.7416014,0.985522,,,,,,,,,,,,,, -51600,0.6811653,0.97952056,,,,,,,,,,,,,, -51700,0.83763486,0.9632429,,,,,,,,,,,,,, -51800,0.7728974,0.9724922,,,,,,,,,,,,,, -51900,0.6956566,0.99966425,,,,,,,,,,,,,, -52000,1.279499,0.98000455,,,,,,,,,,,,,, -52100,0.65643615,0.980685,,,,,,,,,,,,,, -52173,,,0.13081229,0.0529921570447373,0.36061352,0.1045019647218977,5348.0,0.19001366,0.0631080779152194,2472.0,44685.18374609947,49030.528040885925,44685.18374609947,4341.327328681946,1.7806909084320068,0.0 -52200,0.66823685,0.9493951,,,,,,,,,,,,,, -52300,0.6425284,0.95367676,,,,,,,,,,,,,, -52400,0.7617361,1.0203711,,,,,,,,,,,,,, -52500,0.8295063,0.9896198,,,,,,,,,,,,,, -52600,0.7517373,0.987512,,,,,,,,,,,,,, -52700,0.61159426,0.9807003,,,,,,,,,,,,,, -52800,0.6037705,0.9599428,,,,,,,,,,,,,, -52900,0.7397675,0.96988034,,,,,,,,,,,,,, -53000,0.9216055,0.98372006,,,,,,,,,,,,,, -53100,0.755603,0.9800207,,,,,,,,,,,,,, -53200,0.86935216,0.9935356,,,,,,,,,,,,,, -53300,0.8133298,0.996307,,,,,,,,,,,,,, -53400,0.88056153,0.99994016,,,,,,,,,,,,,, -53500,0.773195,0.94574296,,,,,,,,,,,,,, -53600,0.6497891,0.9740089,,,,,,,,,,,,,, -53700,0.7303997,0.9655664,,,,,,,,,,,,,, -53800,0.69932693,0.95117575,,,,,,,,,,,,,, -53875,,,0.110869765,0.0435027585397347,0.35066342,0.1025517248037691,5348.0,0.19238628,0.0637783600430605,2472.0,46125.39481854439,50606.23990535736,46125.39481854439,4476.6968767642975,1.838486671447754,0.0 -53900,0.72171706,0.9474175,,,,,,,,,,,,,, -54000,0.633324,0.98248965,,,,,,,,,,,,,, -54100,0.7775148,0.97927,,,,,,,,,,,,,, -54200,0.7132553,0.9584814,,,,,,,,,,,,,, -54300,0.9007946,1.018337,,,,,,,,,,,,,, -54400,0.71296567,0.97109574,,,,,,,,,,,,,, -54500,0.6068695,0.9384914,,,,,,,,,,,,,, -54600,0.79141957,0.9454896,,,,,,,,,,,,,, -54700,0.94004065,0.9902899,,,,,,,,,,,,,, -54800,0.62087035,0.9899211,,,,,,,,,,,,,, -54900,0.6816768,0.93335795,,,,,,,,,,,,,, -55000,0.72229433,0.9929821,,,,,,,,,,,,,, -55100,0.6658572,0.8860259,,,,,,,,,,,,,, -55200,0.76396453,0.952767,,,,,,,,,,,,,, -55300,0.83851177,0.9933725,,,,,,,,,,,,,, -55400,0.699611,0.970124,,,,,,,,,,,,,, -55500,0.63447267,0.9627306,,,,,,,,,,,,,, -55536,,,0.09881861,0.0399136879843109,0.3454917,0.0991919055388744,5348.0,0.18675168,0.0619706294558527,2472.0,47565.87292122841,52182.46559405327,47565.87292122841,4612.31334400177,1.8981482982635496,0.0 -55600,1.0046269,0.97795606,,,,,,,,,,,,,, -55700,0.6909929,0.9775885,,,,,,,,,,,,,, -55800,0.6670385,0.92096424,,,,,,,,,,,,,, -55900,0.704608,0.93638027,,,,,,,,,,,,,, -56000,0.7149687,0.9486242,,,,,,,,,,,,,, -56100,0.65993214,0.94440156,,,,,,,,,,,,,, -56200,0.6900584,0.93055093,,,,,,,,,,,,,, -56300,0.7952602,0.9700926,,,,,,,,,,,,,, -56400,0.76675665,0.9422614,,,,,,,,,,,,,, -56500,0.9090456,0.95684856,,,,,,,,,,,,,, -56600,0.70175415,0.94530183,,,,,,,,,,,,,, -56700,0.6682965,0.9573566,,,,,,,,,,,,,, -56800,0.7681956,0.964304,,,,,,,,,,,,,, -56900,0.7348693,0.9248137,,,,,,,,,,,,,, -57000,0.66710097,0.9757102,,,,,,,,,,,,,, -57100,0.7943798,0.98588425,,,,,,,,,,,,,, -57200,0.8176937,0.8877381,,,,,,,,,,,,,, -57214,,,0.081867896,0.0328155161319974,0.34621996,0.0992884520694748,5348.0,0.18420044,0.0593910588426461,2472.0,49005.80888533592,53757.41523528099,49005.80888533592,4747.195596456528,1.9564628601074217,0.0 -57300,0.7861485,0.87386495,,,,,,,,,,,,,, -57400,0.72720057,0.93318915,,,,,,,,,,,,,, -57500,0.6503864,0.9574715,,,,,,,,,,,,,, -57600,0.75411355,0.90829426,,,,,,,,,,,,,, -57700,0.7070206,0.9010051,,,,,,,,,,,,,, -57800,0.6999941,0.921748,,,,,,,,,,,,,, -57900,0.6189797,0.8925543,,,,,,,,,,,,,, -58000,0.7257383,0.94496787,,,,,,,,,,,,,, -58100,0.68106115,0.9407749,,,,,,,,,,,,,, -58200,0.73323786,0.9235056,,,,,,,,,,,,,, -58300,0.89552605,0.9241793,,,,,,,,,,,,,, -58400,0.69786507,0.8699805,,,,,,,,,,,,,, -58500,0.9152265,0.93686664,,,,,,,,,,,,,, -58600,0.66713125,0.8903672,,,,,,,,,,,,,, -58700,0.6808034,0.911648,,,,,,,,,,,,,, -58800,0.7736684,0.94577324,,,,,,,,,,,,,, -58900,0.83538926,0.9485173,,,,,,,,,,,,,, -58908,,,0.09207707,0.0386608058288431,0.34705952,0.099771184722477,5348.0,0.18286462,0.0597972904352771,2472.0,50446.31366991997,55334.49369978905,50446.31366991997,4883.638375282288,2.0142009258270264,0.0 -59000,0.8205705,0.8727582,,,,,,,,,,,,,, -59100,0.8644072,0.92789495,,,,,,,,,,,,,, -59200,0.9292822,0.9780957,,,,,,,,,,,,,, -59300,0.8340237,0.89598197,,,,,,,,,,,,,, -59400,0.9042469,0.9070006,,,,,,,,,,,,,, -59500,0.79715765,0.9440254,,,,,,,,,,,,,, -59600,0.6741279,0.9144948,,,,,,,,,,,,,, -59700,0.76332724,0.94810313,,,,,,,,,,,,,, -59800,1.0503604,0.8812831,,,,,,,,,,,,,, -59900,0.82874686,0.92361176,,,,,,,,,,,,,, -60000,0.854742,0.92395264,,,,,,,,,,,,,, -60100,0.73163867,0.90700555,,,,,,,,,,,,,, -60200,0.7804597,0.917339,,,,,,,,,,,,,, -60300,0.98536026,0.9466818,,,,,,,,,,,,,, -60400,0.6419417,0.90505105,,,,,,,,,,,,,, -60500,0.7328878,0.90339077,,,,,,,,,,,,,, -60577,,,0.07713567,0.030381777350016,0.34140152,0.0971451190901454,5348.0,0.17879401,0.0577051977332277,2472.0,51886.75729846954,56911.06728410721,51886.75729846954,5019.637094259262,2.0751595497131348,0.0 -60600,0.96224153,0.926414,,,,,,,,,,,,,, -60700,0.6940764,0.9082225,,,,,,,,,,,,,, -60800,0.7816132,0.8923618,,,,,,,,,,,,,, -60900,0.78892714,0.892098,,,,,,,,,,,,,, -61000,1.0107106,0.87589055,,,,,,,,,,,,,, -61100,0.7677353,0.9473364,,,,,,,,,,,,,, -61200,0.7643004,0.88913244,,,,,,,,,,,,,, -61300,0.7833746,0.90849024,,,,,,,,,,,,,, -61400,0.79672307,0.90925044,,,,,,,,,,,,,, -61500,0.92638516,0.9415647,,,,,,,,,,,,,, -61600,0.77792525,0.91049707,,,,,,,,,,,,,, -61700,0.9481865,0.864285,,,,,,,,,,,,,, -61800,0.7687734,0.9159809,,,,,,,,,,,,,, -61900,0.7188935,0.8875986,,,,,,,,,,,,,, -62000,0.7375127,0.9163103,,,,,,,,,,,,,, -62100,0.7607649,1.0056973,,,,,,,,,,,,,, -62200,0.91665894,0.8782424,,,,,,,,,,,,,, -62265,,,0.07505095,0.0310540394767666,0.33734754,0.0952624617434372,5348.0,0.17833206,0.0573192777202283,2472.0,53327.07553648949,58489.11463499069,53327.07553648949,5157.233544111252,2.1355743408203125,0.0 -62300,0.7695425,0.8610561,,,,,,,,,,,,,, -62400,0.92930174,0.8997421,,,,,,,,,,,,,, -62500,0.7070499,0.88422793,,,,,,,,,,,,,, -62600,0.7158658,0.8810182,,,,,,,,,,,,,, -62700,0.8731003,0.8808446,,,,,,,,,,,,,, -62800,0.85138744,0.86452824,,,,,,,,,,,,,, -62900,0.7253467,0.88785625,,,,,,,,,,,,,, -63000,1.0160185,0.8558051,,,,,,,,,,,,,, -63100,0.8717173,0.94123864,,,,,,,,,,,,,, -63200,1.2285242,0.90812737,,,,,,,,,,,,,, -63300,0.76193875,0.9192571,,,,,,,,,,,,,, -63400,0.70305157,0.8724472,,,,,,,,,,,,,, -63500,0.95337164,0.85017145,,,,,,,,,,,,,, -63600,0.83588743,0.9148372,,,,,,,,,,,,,, -63700,0.70504946,0.91831356,,,,,,,,,,,,,, -63800,0.7536099,0.84409183,,,,,,,,,,,,,, -63900,0.7845666,0.8960478,,,,,,,,,,,,,, -63929,,,0.06894256,0.0279005010031024,0.33377695,0.093379804396729,5348.0,0.17739047,0.0574614587776491,2472.0,54766.94156455994,60066.56161522865,54766.94156455994,5294.676733493805,2.2014431953430176,0.0 -64000,0.77436244,0.89029,,,,,,,,,,,,,, -64100,0.8067924,0.8767291,,,,,,,,,,,,,, -64200,0.831571,0.88589483,,,,,,,,,,,,,, -64300,0.93187827,0.90901613,,,,,,,,,,,,,, -64400,0.7288361,0.87161297,,,,,,,,,,,,,, -64500,0.8228785,0.8858122,,,,,,,,,,,,,, -64600,0.6314381,0.90398127,,,,,,,,,,,,,, -64700,0.71428716,0.8725052,,,,,,,,,,,,,, -64800,0.9736614,0.88380986,,,,,,,,,,,,,, -64900,0.7741239,0.8837203,,,,,,,,,,,,,, -65000,0.7720092,0.8872192,,,,,,,,,,,,,, -65100,0.90105885,0.886175,,,,,,,,,,,,,, -65200,0.88917625,0.8609511,,,,,,,,,,,,,, -65300,0.8580708,0.8388089,,,,,,,,,,,,,, -65400,1.0189201,0.871996,,,,,,,,,,,,,, -65500,0.73192865,0.8753495,,,,,,,,,,,,,, -65600,0.7435934,0.88696194,,,,,,,,,,,,,, -65615,,,0.07285346,0.0289643687731488,0.33438018,0.0934763509273294,5348.0,0.17490257,0.0573395892998598,2472.0,56207.575248003006,61641.49627780914,56207.575248003006,5428.8484699726105,2.258314847946167,0.0 -65700,0.91364384,0.87668234,,,,,,,,,,,,,, -65800,0.71835345,0.88269633,,,,,,,,,,,,,, -65900,0.6646364,0.85710216,,,,,,,,,,,,,, -66000,0.836541,0.86927074,,,,,,,,,,,,,, -66100,0.840013,0.8943811,,,,,,,,,,,,,, -66200,0.8219048,0.8698887,,,,,,,,,,,,,, -66300,0.845209,0.8544883,,,,,,,,,,,,,, -66400,0.7552241,0.85177124,,,,,,,,,,,,,, -66500,0.9929545,0.82093674,,,,,,,,,,,,,, -66600,0.82370055,0.87840796,,,,,,,,,,,,,, -66700,0.9196965,0.87274164,,,,,,,,,,,,,, -66800,0.957534,0.8842864,,,,,,,,,,,,,, -66900,0.73304737,0.86683357,,,,,,,,,,,,,, -67000,0.798157,0.89956766,,,,,,,,,,,,,, -67100,1.1072494,0.8822131,,,,,,,,,,,,,, -67200,0.79881495,0.8651874,,,,,,,,,,,,,, -67300,0.7569697,0.8758191,,,,,,,,,,,,,, -67319,,,0.07366558,0.0289878081494336,0.33114108,0.0924239937437848,5348.0,0.1737761,0.0559787134645461,2472.0,57647.53215622902,63216.80353283882,57647.53215622902,5564.061456441879,2.3223886489868164,0.0 -67400,1.0936184,0.89192325,,,,,,,,,,,,,, -67500,0.79229695,0.9492217,,,,,,,,,,,,,, -67600,0.94801223,0.88810223,,,,,,,,,,,,,, -67700,0.7026678,0.87777126,,,,,,,,,,,,,, -67800,0.7594342,0.89991593,,,,,,,,,,,,,, -67900,0.7945936,0.87831765,,,,,,,,,,,,,, -68000,0.7840088,0.8401508,,,,,,,,,,,,,, -68100,0.8332259,0.8220984,,,,,,,,,,,,,, -68200,0.77797294,0.85876596,,,,,,,,,,,,,, -68300,0.6820751,0.82639307,,,,,,,,,,,,,, -68400,1.3583248,0.8528919,,,,,,,,,,,,,, -68500,1.0030878,0.86600715,,,,,,,,,,,,,, -68600,0.8767847,0.8758581,,,,,,,,,,,,,, -68700,0.7254902,0.852304,,,,,,,,,,,,,, -68800,0.78419924,0.8749315,,,,,,,,,,,,,, -68900,0.7492663,0.8278364,,,,,,,,,,,,,, -68991,,,0.0607663,0.0242742533464907,0.32713333,0.0918640238663023,5348.0,0.17034464,0.0554506123941258,2472.0,59088.42539978027,64793.68052625656,59088.42539978027,5699.908368587494,2.387371301651001,0.0 -69000,1.3596876,0.8045478,,,,,,,,,,,,,, -69100,0.8972515,0.88192517,,,,,,,,,,,,,, -69200,0.9452104,0.8999501,,,,,,,,,,,,,, -69300,1.1802593,0.8566544,,,,,,,,,,,,,, -69400,0.8787824,0.87097865,,,,,,,,,,,,,, -69500,0.7518301,0.89432836,,,,,,,,,,,,,, -69600,0.8009477,0.87064195,,,,,,,,,,,,,, -69700,0.81016445,0.8645948,,,,,,,,,,,,,, -69800,0.81286293,0.8515697,,,,,,,,,,,,,, -69900,0.7799689,0.8407082,,,,,,,,,,,,,, -70000,0.718971,0.8809638,,,,,,,,,,,,,, -70100,0.7218017,0.84159017,,,,,,,,,,,,,, -70200,0.85101634,0.85791594,,,,,,,,,,,,,, -70300,1.4207835,0.86796594,,,,,,,,,,,,,, -70400,1.1794899,0.8456957,,,,,,,,,,,,,, -70500,0.7160158,0.8388278,,,,,,,,,,,,,, -70600,1.098553,0.8551095,,,,,,,,,,,,,, -70651,,,0.063274354,0.0243674823724595,0.3291911,0.0912268167643395,5348.0,0.17073686,0.0549021997440741,2472.0,60528.64407491684,66368.70748090744,60528.64407491684,5834.580038309097,2.4519057273864746,0.0 -70700,0.71836054,0.8266632,,,,,,,,,,,,,, -70800,0.83761144,0.9332447,,,,,,,,,,,,,, -70900,0.8406012,0.8712852,,,,,,,,,,,,,, -71000,0.8222241,0.87002075,,,,,,,,,,,,,, -71100,0.9757726,0.8782835,,,,,,,,,,,,,, -71200,0.81325233,0.8953381,,,,,,,,,,,,,, -71300,0.8463123,0.8563871,,,,,,,,,,,,,, -71314,,,,,,,,,,,61068.68136191368,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 388189452..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -117.5308644771576,0.0,34.96774101257324,1,0,34.96774101257324,30.208836,2472,0.9085978916580344,152.49866819381714,31.732792,0.9418907725440387,30.141817,5348,0.9043706614402812 -229.59972071647644,0.0315151214599609,1475.8832149505615,1680,0,1475.8832149505615,6.5108743,2472,0.899579550301627,1705.5838894844055,6.382445,0.9391896477614642,6.4972925,5348,0.8966179750330672 -350.76149916648865,0.0831305980682373,2916.303174495697,3371,0,2916.303174495697,3.5268924,2472,0.6986777161659862,3267.289516210556,3.8691866,0.7855008984725966,3.8361292,5348,0.7482549214593973 -487.6277189254761,0.1381878852844238,4356.832535982132,5028,0,4356.832535982132,0.95981365,2472,0.3003473280116994,4844.810940742493,1.2378957,0.3726033340564353,1.2712711,5348,0.359597207874335 -625.5019948482513,0.1880762577056884,5798.092703580856,6698,0,5798.092703580856,0.6022405,2472,0.1995206467206954,6424.067903518677,0.73739135,0.2415693151924872,0.87928903,5348,0.260743215192562 -760.9348003864288,0.2422053813934326,7238.464210033417,8371,0,7238.464210033417,0.5196976,2472,0.1704750878475819,7999.998436927795,0.6227412,0.2060978816320464,0.7817897,5348,0.2322523340123772 -896.4183924198151,0.2991371154785156,8678.964599847794,10043,0,8678.964599847794,0.4573212,2472,0.1549570410090792,9576.111248016356,0.6011116,0.2039046405496584,0.70295954,5348,0.21234443940257 -1029.9016468524933,0.3551392555236816,10118.942187547684,11739,0,10118.942187547684,0.4183106,2472,0.1423638616375195,11149.701045513151,0.51052195,0.1751270459160206,0.6549365,5348,0.1973604178533844 -1166.7034318447113,0.4065909385681152,11559.051738500595,13408,0,11559.051738500595,0.40097797,2472,0.1355391708813194,12726.73373579979,0.51378894,0.1716898590868122,0.63236403,5348,0.1921758691601417 -1302.8229024410248,0.4605817794799804,12999.760911226273,15097,0,12999.760911226273,0.37557527,2472,0.128369183271383,14303.689206123352,0.42208117,0.1474570592080328,0.6005494,5348,0.1824729428347992 -1437.2882940769196,0.5167298316955566,14440.03445315361,16782,0,14440.03445315361,0.35843527,2472,0.1197976966668697,15878.557571172714,0.42158687,0.1475592811097697,0.5759482,5348,0.1734265329175396 -1565.2552840709686,0.5599899291992188,15880.001097917557,18502,0,15880.001097917557,0.346281,2472,0.1170150102573477,17446.604358911514,0.40602422,0.1438895161019885,0.56187373,5348,0.1689081552854398 -1702.2899096012115,0.6148602962493896,17320.446226596832,20363,0,17320.446226596832,0.32706234,2472,0.1114699490179351,19024.20640349388,0.26609316,0.0986750889965472,0.5398281,5348,0.1643125404288596 -1833.1509010791776,0.6625058650970459,18761.039699316025,22227,0,18761.039699316025,0.32297328,2472,0.1081997846972559,20595.77546405792,0.2531626,0.0921339491439612,0.5354374,5348,0.1603927512864825 -1961.5512776374817,0.7109498977661133,20201.548386335373,24083,0,20201.548386335373,0.30685455,2472,0.1041780919302094,22164.801256656647,0.24372207,0.0909669311543335,0.516554,5348,0.1562122865114842 -2089.888088941574,0.7663748264312744,21641.7920794487,25925,0,21641.7920794487,0.29840583,2472,0.1015782097373712,23733.50555229187,0.2315767,0.0863608729376742,0.5072757,5348,0.1532483080220512 -2218.8633258342743,0.8157093524932861,23081.903094768524,27770,0,23081.903094768524,0.29091525,2472,0.0979627485629557,25302.70827627182,0.2280684,0.0847369992111042,0.49400762,5348,0.1494154107572144 -2349.9573764801025,0.8724699020385742,24522.0075917244,29607,0,24522.0075917244,0.28334573,2472,0.0948550768793289,26874.034334897995,0.21079952,0.078806826269645,0.48705956,5348,0.1460749007984398 -2479.606421470642,0.9324443340301514,25962.538994789124,31457,0,25962.538994789124,0.28156066,2472,0.095748786383117,28444.344913959503,0.24844179,0.0889036306393992,0.48052293,5348,0.1438253666354499 -2609.657151460648,0.989710807800293,27403.022592306137,33310,0,27403.022592306137,0.27785712,2472,0.0933926431458574,30015.006851911545,0.20790268,0.0784953216281683,0.46711764,5348,0.1401276345134537 -2740.5739209651947,1.044804573059082,28843.551401615143,35145,0,28843.551401615143,0.26721686,2472,0.0905083988381776,31586.579282283783,0.18601915,0.0706503207801541,0.45886686,5348,0.1364009384322774 -2871.2297463417053,1.0952837467193604,30283.677778959274,36979,0,30283.677778959274,0.25798786,2472,0.0860804744785001,33157.48137187958,0.178505,0.0666595795572328,0.4410436,5348,0.1324618399837802 -3001.384289264679,1.155475616455078,31723.92041254044,38817,0,31723.92041254044,0.25531837,2472,0.0869335608230252,34728.01073908806,0.17209044,0.0670704134366925,0.43647054,5348,0.1301929965146702 -3132.688977241516,1.2110445499420166,33163.837792634964,40660,0,33163.837792634964,0.24571325,2472,0.0827696869985578,36299.36015820503,0.18338138,0.0674733290877416,0.4277794,5348,0.1277310599843594 -3264.1424934864044,1.2691049575805664,34604.3055870533,42494,0,34604.3055870533,0.2410632,2472,0.0803119858631405,37871.41095805168,0.1785813,0.0671149421050691,0.42371103,5348,0.126253898066173 -3396.860850095749,1.3300681114196775,36044.50428843498,44317,0,36044.50428843498,0.2344677,2472,0.0794182763593524,39444.46064519882,0.15933198,0.06103466511816,0.4057534,5348,0.1218996495360939 -3527.749038219452,1.389116287231445,37484.47999191284,46154,0,37484.47999191284,0.22770649,2472,0.0764324741535149,41015.45380330086,0.15425502,0.0592063748676683,0.39746863,5348,0.1196597700261641 -3659.4434146881104,1.4481637477874756,38924.55057263374,47997,0,38924.55057263374,0.22139996,2472,0.0748481709422541,42587.35040640831,0.13490936,0.0518983999577546,0.39513963,5348,0.1172750707203336 -3791.970413923264,1.503833293914795,40364.931837558746,49845,0,40364.931837558746,0.21825966,2472,0.0737513456421505,44160.38792181015,0.13607302,0.0515339115510741,0.38606805,5348,0.1145717678635218 -3923.273575782776,1.5587153434753418,41805.41350221634,51671,0,41805.41350221634,0.21176472,2472,0.0718217455771535,45732.29967999458,0.13753137,0.0521109640474325,0.3752853,5348,0.1104395763538237 -4053.979643106461,1.6170992851257324,43246.1699757576,53500,0,43246.1699757576,0.20651865,2472,0.0683890886194219,47303.89137673378,0.120480634,0.0451126197781576,0.37096584,5348,0.1088175946397366 -4185.1421592235565,1.6726500988006592,44686.96074914932,55323,0,44686.96074914932,0.19546556,2472,0.0659720106432677,48875.97346353531,0.119767986,0.0465471313071518,0.35839888,5348,0.1048978054973594 -4317.38533616066,1.7271640300750732,46126.98453450203,57163,0,46126.98453450203,0.19099317,2472,0.0644689537505331,50448.36845588684,0.13565873,0.0476978145593217,0.35523295,5348,0.1047529857014588 -4449.386815786362,1.7888824939727783,47567.41899180412,59000,0,47567.41899180412,0.18731277,2472,0.0619706294558527,52020.93994116783,0.0892466,0.0350036959467783,0.3397383,5348,0.0990181217837937 -4580.630652904511,1.845522403717041,49007.59355258942,60810,0,49007.59355258942,0.18456773,2472,0.059980094651961,53592.48714256287,0.088539205,0.0335192043610336,0.3364041,5348,0.0972127016615657 -4710.4846367836,1.9030003547668457,50448.088452100754,62641,0,50448.088452100754,0.17749391,2472,0.0579692482684378,55162.96579504013,0.10741339,0.0409584474735139,0.32483196,5348,0.093843227743611 -4840.52444434166,1.964537143707276,51888.67530345917,64473,0,51888.67530345917,0.17470016,2472,0.0573192777202283,56733.72640347481,0.10360827,0.039054374563204,0.3186228,5348,0.091564729621441 -4969.87771320343,2.021014928817749,53328.80698037148,66305,0,53328.80698037148,0.16910523,2472,0.0560599597830723,58303.341222286224,0.12077924,0.0472280929930092,0.31349313,5348,0.0898268920706334 -5100.211203813553,2.079517364501953,54768.67781591416,68111,0,54768.67781591416,0.16567686,2472,0.0543131639347592,59873.67728209496,0.09411724,0.0350127446624418,0.30283234,5348,0.0869304961526207 -5231.957772254944,2.138704776763916,56209.37894654274,69919,0,56209.37894654274,0.16325577,2472,0.0523023175512359,61446.25600028038,0.09114316,0.0354409071780348,0.30461708,5348,0.0860229587649767 -5365.79065990448,2.1979687213897705,57650.042269945145,71727,0,57650.042269945145,0.15865384,2472,0.0511445575122377,63020.883692502975,0.06677296,0.0253651451166043,0.29442066,5348,0.0829720883980034 -5498.555616617203,2.2566795349121094,59090.54529213905,73563,0,59090.54529213905,0.15709297,2472,0.0503930290658704,64594.28295874596,0.07948388,0.0303998949085375,0.29232764,5348,0.0821417882348397 -5631.136070251465,2.3183770179748535,60530.85288310051,75365,0,60530.85288310051,0.15438108,2472,0.05033209432697581,66167.30356693268,0.06929343,0.025832999331997328,0.29072267,5348,0.08138872529615648 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/measurements.csv deleted file mode 100644 index ff8877c6f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/measurements.csv +++ /dev/null @@ -1,806 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,50.381653,30.986305,,,,,,,,,,,,,, -1,,,31.732792,0.9418907725440387,30.141817,0.9043706614402812,5348.0,30.208836,0.9085978916580344,2472.0,34.96774101257324,152.49866819381714,34.96774101257324,117.5308644771576,0.0,0.0 -100,26.344826,10.763748,,,,,,,,,,,,,, -200,1.8537667,6.207562,,,,,,,,,,,,,, -300,0.63974255,5.85892,,,,,,,,,,,,,, -400,0.496367,5.825487,,,,,,,,,,,,,, -500,0.6286103,5.800247,,,,,,,,,,,,,, -600,0.49764857,5.805335,,,,,,,,,,,,,, -700,0.4258598,5.8000426,,,,,,,,,,,,,, -800,0.5793303,5.780876,,,,,,,,,,,,,, -900,0.71394044,5.7713943,,,,,,,,,,,,,, -1000,0.27521506,5.7827373,,,,,,,,,,,,,, -1100,0.34027898,5.7782235,,,,,,,,,,,,,, -1200,0.95107204,5.793951,,,,,,,,,,,,,, -1300,0.39018092,5.799106,,,,,,,,,,,,,, -1400,0.91668314,5.786225,,,,,,,,,,,,,, -1500,0.7473002,5.7764115,,,,,,,,,,,,,, -1600,0.35061288,5.7073026,,,,,,,,,,,,,, -1680,,,6.382445,0.9391896477614642,6.4972925,0.8966179750330672,5348.0,6.5108743,0.899579550301627,2472.0,1475.8832149505615,1705.5838894844055,1475.8832149505615,229.59972071647644,0.0315151214599609,0.0 -1700,0.87677455,5.5840025,,,,,,,,,,,,,, -1800,1.6480999,5.455385,,,,,,,,,,,,,, -1900,1.0412316,5.2704687,,,,,,,,,,,,,, -2000,1.1176548,4.661881,,,,,,,,,,,,,, -2100,1.1740141,4.115972,,,,,,,,,,,,,, -2200,0.93788356,3.7633157,,,,,,,,,,,,,, -2300,1.3249353,3.4827948,,,,,,,,,,,,,, -2400,1.1400748,3.351273,,,,,,,,,,,,,, -2500,0.9284382,3.1851647,,,,,,,,,,,,,, -2600,0.93070143,3.0626948,,,,,,,,,,,,,, -2700,1.0016719,3.0111053,,,,,,,,,,,,,, -2800,0.9766339,2.9133017,,,,,,,,,,,,,, -2900,1.1352164,2.8247528,,,,,,,,,,,,,, -3000,1.0535344,2.7791357,,,,,,,,,,,,,, -3100,0.9934089,2.6860921,,,,,,,,,,,,,, -3200,1.0024524,2.618062,,,,,,,,,,,,,, -3300,1.189138,2.6096728,,,,,,,,,,,,,, -3371,,,3.8691866,0.7855008984725966,3.8361292,0.7482549214593973,5348.0,3.5268924,0.6986777161659862,2472.0,2916.303174495697,3267.289516210556,2916.303174495697,350.76149916648865,0.0831305980682373,0.0 -3400,1.1262169,2.544555,,,,,,,,,,,,,, -3500,0.9838374,2.5458674,,,,,,,,,,,,,, -3600,1.1595916,2.4622824,,,,,,,,,,,,,, -3700,1.2013037,2.4296043,,,,,,,,,,,,,, -3800,1.0702555,2.3588188,,,,,,,,,,,,,, -3900,0.9803295,2.3795662,,,,,,,,,,,,,, -4000,1.1276942,2.3517787,,,,,,,,,,,,,, -4100,1.1809188,2.2894723,,,,,,,,,,,,,, -4200,0.9808955,2.2968621,,,,,,,,,,,,,, -4300,1.2973837,2.2327342,,,,,,,,,,,,,, -4400,0.89279896,2.2673526,,,,,,,,,,,,,, -4500,0.97877055,2.1041682,,,,,,,,,,,,,, -4600,0.94982004,2.119099,,,,,,,,,,,,,, -4700,0.89580345,2.1208286,,,,,,,,,,,,,, -4800,0.87509793,2.0706055,,,,,,,,,,,,,, -4900,0.8643679,2.0398872,,,,,,,,,,,,,, -5000,0.9055061,2.1210847,,,,,,,,,,,,,, -5028,,,1.2378957,0.3726033340564353,1.2712711,0.359597207874335,5348.0,0.95981365,0.3003473280116994,2472.0,4356.832535982132,4844.810940742493,4356.832535982132,487.6277189254761,0.1381878852844238,0.0 -5100,0.8724866,1.9885439,,,,,,,,,,,,,, -5200,1.0762197,1.9506179,,,,,,,,,,,,,, -5300,0.9896116,2.0321195,,,,,,,,,,,,,, -5400,0.8707035,1.9917476,,,,,,,,,,,,,, -5500,0.95185363,1.9531325,,,,,,,,,,,,,, -5600,0.78335416,1.9499493,,,,,,,,,,,,,, -5700,0.8617293,1.9003617,,,,,,,,,,,,,, -5800,0.91902524,1.8860849,,,,,,,,,,,,,, -5900,0.86791015,1.888651,,,,,,,,,,,,,, -6000,0.7266469,1.8028966,,,,,,,,,,,,,, -6100,0.74144685,1.8699172,,,,,,,,,,,,,, -6200,0.80445474,1.8214875,,,,,,,,,,,,,, -6300,0.7928609,1.8546042,,,,,,,,,,,,,, -6400,1.0374897,1.8477199,,,,,,,,,,,,,, -6500,0.9111255,1.9092429,,,,,,,,,,,,,, -6600,0.7376215,1.8434881,,,,,,,,,,,,,, -6698,,,0.73739135,0.2415693151924872,0.87928903,0.260743215192562,5348.0,0.6022405,0.1995206467206954,2472.0,5798.092703580856,6424.067903518677,5798.092703580856,625.5019948482513,0.1880762577056884,0.0 -6700,0.7972412,1.798822,,,,,,,,,,,,,, -6800,0.7492715,1.8344951,,,,,,,,,,,,,, -6900,0.9215514,1.7796534,,,,,,,,,,,,,, -7000,0.68512625,1.7866005,,,,,,,,,,,,,, -7100,0.74260736,1.7662839,,,,,,,,,,,,,, -7200,0.8636535,1.7585702,,,,,,,,,,,,,, -7300,0.7814596,1.7827992,,,,,,,,,,,,,, -7400,0.9234576,1.7418053,,,,,,,,,,,,,, -7500,0.92928165,1.7525059,,,,,,,,,,,,,, -7600,0.8349206,1.728157,,,,,,,,,,,,,, -7700,0.89779043,1.7376484,,,,,,,,,,,,,, -7800,0.8648127,1.7270182,,,,,,,,,,,,,, -7900,0.79186684,1.737461,,,,,,,,,,,,,, -8000,0.83320355,1.7319624,,,,,,,,,,,,,, -8100,0.76335275,1.6889318,,,,,,,,,,,,,, -8200,0.8392153,1.7588339,,,,,,,,,,,,,, -8300,0.7862716,1.7298915,,,,,,,,,,,,,, -8371,,,0.6227412,0.2060978816320464,0.7817897,0.2322523340123772,5348.0,0.5196976,0.1704750878475819,2472.0,7238.464210033417,7999.998436927795,7238.464210033417,760.9348003864288,0.2422053813934326,0.0 -8400,0.78451586,1.597633,,,,,,,,,,,,,, -8500,0.7422435,1.7017267,,,,,,,,,,,,,, -8600,0.76795083,1.6866158,,,,,,,,,,,,,, -8700,0.7511821,1.6683137,,,,,,,,,,,,,, -8800,0.8295415,1.6031662,,,,,,,,,,,,,, -8900,0.6316371,1.7129529,,,,,,,,,,,,,, -9000,0.6521023,1.6038316,,,,,,,,,,,,,, -9100,0.65918994,1.6543056,,,,,,,,,,,,,, -9200,0.75245285,1.6855624,,,,,,,,,,,,,, -9300,0.696784,1.6251101,,,,,,,,,,,,,, -9400,0.6188909,1.5731145,,,,,,,,,,,,,, -9500,0.88179654,1.5749009,,,,,,,,,,,,,, -9600,0.7071505,1.6535579,,,,,,,,,,,,,, -9700,0.91930294,1.6524466,,,,,,,,,,,,,, -9800,0.74752516,1.5802193,,,,,,,,,,,,,, -9900,0.9162703,1.6197731,,,,,,,,,,,,,, -10000,0.6873138,1.569904,,,,,,,,,,,,,, -10043,,,0.6011116,0.2039046405496584,0.70295954,0.21234443940257,5348.0,0.4573212,0.1549570410090792,2472.0,8678.964599847794,9576.111248016356,8678.964599847794,896.4183924198151,0.2991371154785156,0.0 -10100,0.7272463,1.5829484,,,,,,,,,,,,,, -10200,0.8089655,1.600578,,,,,,,,,,,,,, -10300,0.6763894,1.6463946,,,,,,,,,,,,,, -10400,0.64891726,1.5460777,,,,,,,,,,,,,, -10500,0.70659345,1.5830033,,,,,,,,,,,,,, -10600,0.64019614,1.5681865,,,,,,,,,,,,,, -10700,0.706675,1.6124648,,,,,,,,,,,,,, -10800,0.8029325,1.6182206,,,,,,,,,,,,,, -10900,0.6301495,1.5524968,,,,,,,,,,,,,, -11000,0.6171067,1.5723903,,,,,,,,,,,,,, -11100,0.56647706,1.5439826,,,,,,,,,,,,,, -11200,0.59742117,1.5541248,,,,,,,,,,,,,, -11300,0.6624522,1.6055243,,,,,,,,,,,,,, -11400,0.7106156,1.544136,,,,,,,,,,,,,, -11500,0.8303741,1.5332142,,,,,,,,,,,,,, -11600,0.5537313,1.5370493,,,,,,,,,,,,,, -11700,0.754779,1.542112,,,,,,,,,,,,,, -11739,,,0.51052195,0.1751270459160206,0.6549365,0.1973604178533844,5348.0,0.4183106,0.1423638616375195,2472.0,10118.942187547684,11149.701045513151,10118.942187547684,1029.9016468524933,0.3551392555236816,0.0 -11800,0.8053734,1.5729247,,,,,,,,,,,,,, -11900,0.738388,1.5768516,,,,,,,,,,,,,, -12000,0.66726744,1.5847591,,,,,,,,,,,,,, -12100,1.0894946,1.4579504,,,,,,,,,,,,,, -12200,0.785641,1.5519537,,,,,,,,,,,,,, -12300,0.66883993,1.5340488,,,,,,,,,,,,,, -12400,0.7263732,1.5197339,,,,,,,,,,,,,, -12500,0.845203,1.5732981,,,,,,,,,,,,,, -12600,0.7930397,1.4799485,,,,,,,,,,,,,, -12700,0.88026,1.5730958,,,,,,,,,,,,,, -12800,0.6837469,1.5700186,,,,,,,,,,,,,, -12900,0.5900645,1.5555879,,,,,,,,,,,,,, -13000,0.69211024,1.4742086,,,,,,,,,,,,,, -13100,0.9073921,1.567606,,,,,,,,,,,,,, -13200,0.71975243,1.5233167,,,,,,,,,,,,,, -13300,0.83976305,1.6018567,,,,,,,,,,,,,, -13400,0.6836006,1.5215393,,,,,,,,,,,,,, -13408,,,0.51378894,0.1716898590868122,0.63236403,0.1921758691601417,5348.0,0.40097797,0.1355391708813194,2472.0,11559.051738500595,12726.73373579979,11559.051738500595,1166.7034318447113,0.4065909385681152,0.0 -13500,0.8130952,1.5277876,,,,,,,,,,,,,, -13600,1.2436051,1.5691402,,,,,,,,,,,,,, -13700,0.86698306,1.5207206,,,,,,,,,,,,,, -13800,0.7256714,1.4824393,,,,,,,,,,,,,, -13900,0.5971005,1.4598302,,,,,,,,,,,,,, -14000,0.640405,1.5112702,,,,,,,,,,,,,, -14100,0.73333174,1.5031371,,,,,,,,,,,,,, -14200,0.7511292,1.5317253,,,,,,,,,,,,,, -14300,0.6296382,1.5144224,,,,,,,,,,,,,, -14400,0.66592366,1.532094,,,,,,,,,,,,,, -14500,0.7224228,1.5565637,,,,,,,,,,,,,, -14600,0.58175725,1.4352169,,,,,,,,,,,,,, -14700,0.5486085,1.4586558,,,,,,,,,,,,,, -14800,0.6856817,1.4289824,,,,,,,,,,,,,, -14900,0.60500133,1.5217252,,,,,,,,,,,,,, -15000,0.6653488,1.4455023,,,,,,,,,,,,,, -15097,,,0.42208117,0.1474570592080328,0.6005494,0.1824729428347992,5348.0,0.37557527,0.128369183271383,2472.0,12999.760911226273,14303.689206123352,12999.760911226273,1302.8229024410248,0.4605817794799804,0.0 -15100,0.7362665,1.4814746,,,,,,,,,,,,,, -15200,0.73305553,1.4980661,,,,,,,,,,,,,, -15300,0.7000144,1.4852077,,,,,,,,,,,,,, -15400,0.7503088,1.498764,,,,,,,,,,,,,, -15500,0.64076144,1.4646859,,,,,,,,,,,,,, -15600,0.5854158,1.3901371,,,,,,,,,,,,,, -15700,0.6820771,1.4745154,,,,,,,,,,,,,, -15800,0.7252482,1.4409906,,,,,,,,,,,,,, -15900,0.6517745,1.4551417,,,,,,,,,,,,,, -16000,0.7605533,1.4813541,,,,,,,,,,,,,, -16100,0.72754985,1.4250629,,,,,,,,,,,,,, -16200,0.68586236,1.4832573,,,,,,,,,,,,,, -16300,0.79714274,1.4332744,,,,,,,,,,,,,, -16400,0.85406727,1.4708285,,,,,,,,,,,,,, -16500,0.67965454,1.4891944,,,,,,,,,,,,,, -16600,0.678899,1.4096518,,,,,,,,,,,,,, -16700,0.79985476,1.4295099,,,,,,,,,,,,,, -16782,,,0.42158687,0.1475592811097697,0.5759482,0.1734265329175396,5348.0,0.35843527,0.1197976966668697,2472.0,14440.03445315361,15878.557571172714,14440.03445315361,1437.2882940769196,0.5167298316955566,0.0 -16800,0.63190925,1.4361362,,,,,,,,,,,,,, -16900,0.93359095,1.432533,,,,,,,,,,,,,, -17000,0.598071,1.4317609,,,,,,,,,,,,,, -17100,0.71053594,1.4368149,,,,,,,,,,,,,, -17200,0.66938114,1.4680126,,,,,,,,,,,,,, -17300,0.5927224,1.4173362,,,,,,,,,,,,,, -17400,0.68278325,1.4298621,,,,,,,,,,,,,, -17500,0.69404566,1.4448724,,,,,,,,,,,,,, -17600,0.7341216,1.4214562,,,,,,,,,,,,,, -17700,0.7668637,1.464415,,,,,,,,,,,,,, -17800,0.6302365,1.4139099,,,,,,,,,,,,,, -17900,0.76909477,1.4141672,,,,,,,,,,,,,, -18000,0.72649074,1.419686,,,,,,,,,,,,,, -18100,0.7182737,1.4436724,,,,,,,,,,,,,, -18200,0.6307541,1.4624716,,,,,,,,,,,,,, -18300,0.73378235,1.4127821,,,,,,,,,,,,,, -18400,0.74648964,1.4511914,,,,,,,,,,,,,, -18500,0.65589947,1.4443429,,,,,,,,,,,,,, -18502,,,0.40602422,0.1438895161019885,0.56187373,0.1689081552854398,5348.0,0.346281,0.1170150102573477,2472.0,15880.001097917557,17446.604358911514,15880.001097917557,1565.2552840709686,0.5599899291992188,0.0 -18600,0.6569165,1.3599238,,,,,,,,,,,,,, -18700,0.64520997,1.4904858,,,,,,,,,,,,,, -18800,0.62061733,1.3913378,,,,,,,,,,,,,, -18900,0.7340102,1.4159728,,,,,,,,,,,,,, -19000,0.6957357,1.3939824,,,,,,,,,,,,,, -19100,0.7944274,1.4051019,,,,,,,,,,,,,, -19200,0.6546836,1.3848972,,,,,,,,,,,,,, -19300,0.65241927,1.3893809,,,,,,,,,,,,,, -19400,0.6056626,1.4275075,,,,,,,,,,,,,, -19500,0.65392524,1.467661,,,,,,,,,,,,,, -19600,0.6715254,1.3972274,,,,,,,,,,,,,, -19700,0.6268025,1.3875068,,,,,,,,,,,,,, -19800,0.6821443,1.4101552,,,,,,,,,,,,,, -19900,0.7131004,1.4598509,,,,,,,,,,,,,, -20000,0.6273115,1.3502146,,,,,,,,,,,,,, -20100,0.5584491,1.3922302,,,,,,,,,,,,,, -20200,0.59007037,1.3952104,,,,,,,,,,,,,, -20300,0.65658295,1.4687693,,,,,,,,,,,,,, -20363,,,0.26609316,0.0986750889965472,0.5398281,0.1643125404288596,5348.0,0.32706234,0.1114699490179351,2472.0,17320.446226596832,19024.20640349388,17320.446226596832,1702.2899096012115,0.6148602962493896,0.0 -20400,1.0124904,1.4429891,,,,,,,,,,,,,, -20500,0.62697107,1.4539,,,,,,,,,,,,,, -20600,0.7354467,1.3997494,,,,,,,,,,,,,, -20700,0.6741571,1.3658515,,,,,,,,,,,,,, -20800,0.6355004,1.468507,,,,,,,,,,,,,, -20900,0.74176955,1.3984035,,,,,,,,,,,,,, -21000,0.8257611,1.4275995,,,,,,,,,,,,,, -21100,0.8374326,1.4230223,,,,,,,,,,,,,, -21200,0.684965,1.3591318,,,,,,,,,,,,,, -21300,0.8131148,1.3678966,,,,,,,,,,,,,, -21400,0.74155504,1.418034,,,,,,,,,,,,,, -21500,0.6947854,1.3930963,,,,,,,,,,,,,, -21600,0.6490522,1.4036366,,,,,,,,,,,,,, -21700,0.6997731,1.4089363,,,,,,,,,,,,,, -21800,0.5960054,1.3308018,,,,,,,,,,,,,, -21900,0.63825405,1.3624989,,,,,,,,,,,,,, -22000,0.66221535,1.3876997,,,,,,,,,,,,,, -22100,0.5682937,1.3696485,,,,,,,,,,,,,, -22200,0.7071851,1.3656852,,,,,,,,,,,,,, -22227,,,0.2531626,0.0921339491439612,0.5354374,0.1603927512864825,5348.0,0.32297328,0.1081997846972559,2472.0,18761.039699316025,20595.77546405792,18761.039699316025,1833.1509010791776,0.6625058650970459,0.0 -22300,0.71235335,1.4165641,,,,,,,,,,,,,, -22400,0.6253919,1.3620901,,,,,,,,,,,,,, -22500,0.6084062,1.3855091,,,,,,,,,,,,,, -22600,0.6345813,1.3690075,,,,,,,,,,,,,, -22700,0.67850846,1.3492879,,,,,,,,,,,,,, -22800,0.7678607,1.3449124,,,,,,,,,,,,,, -22900,0.7382919,1.3563544,,,,,,,,,,,,,, -23000,0.69488716,1.3815415,,,,,,,,,,,,,, -23100,0.7432118,1.3874013,,,,,,,,,,,,,, -23200,0.74835443,1.399479,,,,,,,,,,,,,, -23300,0.86485934,1.3339753,,,,,,,,,,,,,, -23400,0.6249033,1.3767457,,,,,,,,,,,,,, -23500,0.6897648,1.2639055,,,,,,,,,,,,,, -23600,0.6492406,1.2981322,,,,,,,,,,,,,, -23700,0.59619594,1.3178717,,,,,,,,,,,,,, -23800,0.85458845,1.3614407,,,,,,,,,,,,,, -23900,0.78476906,1.3639456,,,,,,,,,,,,,, -24000,0.6917179,1.3390366,,,,,,,,,,,,,, -24083,,,0.24372207,0.0909669311543335,0.516554,0.1562122865114842,5348.0,0.30685455,0.1041780919302094,2472.0,20201.548386335373,22164.801256656647,20201.548386335373,1961.5512776374817,0.7109498977661133,0.0 -24100,0.76642054,1.3440158,,,,,,,,,,,,,, -24200,0.7285386,1.3958579,,,,,,,,,,,,,, -24300,0.6760391,1.4007838,,,,,,,,,,,,,, -24400,0.82463497,1.3217815,,,,,,,,,,,,,, -24500,0.84654146,1.4340222,,,,,,,,,,,,,, -24600,0.62461483,1.3987525,,,,,,,,,,,,,, -24700,0.6632109,1.2991345,,,,,,,,,,,,,, -24800,0.5982821,1.2929165,,,,,,,,,,,,,, -24900,0.6090117,1.3710504,,,,,,,,,,,,,, -25000,0.65364987,1.3474636,,,,,,,,,,,,,, -25100,0.61020094,1.3111718,,,,,,,,,,,,,, -25200,0.75307626,1.3383707,,,,,,,,,,,,,, -25300,0.7909155,1.3834751,,,,,,,,,,,,,, -25400,0.72082996,1.3470027,,,,,,,,,,,,,, -25500,0.76443493,1.3637676,,,,,,,,,,,,,, -25600,0.6936721,1.3681962,,,,,,,,,,,,,, -25700,0.79787064,1.3672059,,,,,,,,,,,,,, -25800,0.7526753,1.3045083,,,,,,,,,,,,,, -25900,0.7308089,1.2807242,,,,,,,,,,,,,, -25925,,,0.2315767,0.0863608729376742,0.5072757,0.1532483080220512,5348.0,0.29840583,0.1015782097373712,2472.0,21641.7920794487,23733.50555229187,21641.7920794487,2089.888088941574,0.7663748264312744,0.0 -26000,0.61246306,1.3376995,,,,,,,,,,,,,, -26100,0.72716826,1.2751573,,,,,,,,,,,,,, -26200,0.6206992,1.3069955,,,,,,,,,,,,,, -26300,0.6972925,1.3392091,,,,,,,,,,,,,, -26400,0.7409414,1.277683,,,,,,,,,,,,,, -26500,0.6787324,1.2487152,,,,,,,,,,,,,, -26600,0.82966614,1.337474,,,,,,,,,,,,,, -26700,0.69630927,1.365617,,,,,,,,,,,,,, -26800,0.6736173,1.2979754,,,,,,,,,,,,,, -26900,0.65752035,1.2786783,,,,,,,,,,,,,, -27000,0.74240065,1.2815498,,,,,,,,,,,,,, -27100,0.6472189,1.2960957,,,,,,,,,,,,,, -27200,0.7234193,1.328696,,,,,,,,,,,,,, -27300,0.6938025,1.2880623,,,,,,,,,,,,,, -27400,0.8277514,1.3194655,,,,,,,,,,,,,, -27500,0.79387224,1.3313525,,,,,,,,,,,,,, -27600,0.9604713,1.310361,,,,,,,,,,,,,, -27700,0.7371702,1.3087524,,,,,,,,,,,,,, -27770,,,0.2280684,0.0847369992111042,0.49400762,0.1494154107572144,5348.0,0.29091525,0.0979627485629557,2472.0,23081.903094768524,25302.70827627182,23081.903094768524,2218.8633258342743,0.8157093524932861,0.0 -27800,0.6671195,1.3442708,,,,,,,,,,,,,, -27900,0.62363726,1.2834488,,,,,,,,,,,,,, -28000,0.6906648,1.3448665,,,,,,,,,,,,,, -28100,0.81566447,1.2626945,,,,,,,,,,,,,, -28200,0.77283025,1.3516918,,,,,,,,,,,,,, -28300,0.76323265,1.3451886,,,,,,,,,,,,,, -28400,0.6986555,1.3112609,,,,,,,,,,,,,, -28500,0.6929332,1.2935529,,,,,,,,,,,,,, -28600,0.68705016,1.3136082,,,,,,,,,,,,,, -28700,0.6370867,1.2781737,,,,,,,,,,,,,, -28800,0.8345162,1.3093754,,,,,,,,,,,,,, -28900,0.6295987,1.2881376,,,,,,,,,,,,,, -29000,0.61466986,1.2673063,,,,,,,,,,,,,, -29100,0.73081976,1.271464,,,,,,,,,,,,,, -29200,0.7089423,1.3086017,,,,,,,,,,,,,, -29300,0.7142604,1.3133601,,,,,,,,,,,,,, -29400,0.6376814,1.345294,,,,,,,,,,,,,, -29500,0.71945196,1.2670351,,,,,,,,,,,,,, -29600,0.75407994,1.2914541,,,,,,,,,,,,,, -29607,,,0.21079952,0.078806826269645,0.48705956,0.1460749007984398,5348.0,0.28334573,0.0948550768793289,2472.0,24522.0075917244,26874.034334897995,24522.0075917244,2349.9573764801025,0.8724699020385742,0.0 -29700,0.789501,1.3536595,,,,,,,,,,,,,, -29800,0.69561964,1.2640644,,,,,,,,,,,,,, -29900,0.9410612,1.2489705,,,,,,,,,,,,,, -30000,0.60923135,1.3004618,,,,,,,,,,,,,, -30100,0.6725915,1.2863597,,,,,,,,,,,,,, -30200,0.81989336,1.3041385,,,,,,,,,,,,,, -30300,0.84489226,1.308595,,,,,,,,,,,,,, -30400,0.7333481,1.3152938,,,,,,,,,,,,,, -30500,0.7157795,1.3336383,,,,,,,,,,,,,, -30600,0.708658,1.2800931,,,,,,,,,,,,,, -30700,0.73991823,1.2671349,,,,,,,,,,,,,, -30800,0.74078506,1.292337,,,,,,,,,,,,,, -30900,0.9744725,1.2412834,,,,,,,,,,,,,, -31000,0.65235007,1.2579182,,,,,,,,,,,,,, -31100,0.6941907,1.2822548,,,,,,,,,,,,,, -31200,0.6548274,1.2441427,,,,,,,,,,,,,, -31300,0.6613033,1.2875437,,,,,,,,,,,,,, -31400,0.7265932,1.3239418,,,,,,,,,,,,,, -31457,,,0.24844179,0.0889036306393992,0.48052293,0.1438253666354499,5348.0,0.28156066,0.095748786383117,2472.0,25962.538994789124,28444.344913959503,25962.538994789124,2479.606421470642,0.9324443340301514,0.0 -31500,0.6768072,1.2657819,,,,,,,,,,,,,, -31600,0.62225634,1.2566506,,,,,,,,,,,,,, -31700,0.6766326,1.2734014,,,,,,,,,,,,,, -31800,0.78865373,1.2523243,,,,,,,,,,,,,, -31900,0.7201941,1.3182029,,,,,,,,,,,,,, -32000,0.7099429,1.2069205,,,,,,,,,,,,,, -32100,0.69987315,1.2750357,,,,,,,,,,,,,, -32200,0.6425916,1.3352606,,,,,,,,,,,,,, -32300,0.70116013,1.2415603,,,,,,,,,,,,,, -32400,0.72873455,1.2634262,,,,,,,,,,,,,, -32500,0.6573688,1.2476873,,,,,,,,,,,,,, -32600,0.6439556,1.2174164,,,,,,,,,,,,,, -32700,0.69297296,1.2724916,,,,,,,,,,,,,, -32800,0.66412675,1.3027626,,,,,,,,,,,,,, -32900,0.72290033,1.2969844,,,,,,,,,,,,,, -33000,1.1029927,1.3065761,,,,,,,,,,,,,, -33100,0.68538064,1.2119725,,,,,,,,,,,,,, -33200,0.7301633,1.2214485,,,,,,,,,,,,,, -33300,0.63571817,1.2353379,,,,,,,,,,,,,, -33310,,,0.20790268,0.0784953216281683,0.46711764,0.1401276345134537,5348.0,0.27785712,0.0933926431458574,2472.0,27403.022592306137,30015.006851911545,27403.022592306137,2609.657151460648,0.989710807800293,0.0 -33400,0.7760545,1.3073459,,,,,,,,,,,,,, -33500,0.69702184,1.2519072,,,,,,,,,,,,,, -33600,0.84276336,1.2775085,,,,,,,,,,,,,, -33700,0.65379304,1.2726842,,,,,,,,,,,,,, -33800,0.7200337,1.2664969,,,,,,,,,,,,,, -33900,0.73547983,1.2866825,,,,,,,,,,,,,, -34000,0.7580713,1.2720306,,,,,,,,,,,,,, -34100,0.7451405,1.2643945,,,,,,,,,,,,,, -34200,0.82299256,1.2940605,,,,,,,,,,,,,, -34300,0.65758425,1.2422737,,,,,,,,,,,,,, -34400,0.59396344,1.2618226,,,,,,,,,,,,,, -34500,0.70137495,1.2786373,,,,,,,,,,,,,, -34600,0.6460822,1.2918166,,,,,,,,,,,,,, -34700,0.76401025,1.246556,,,,,,,,,,,,,, -34800,0.80281097,1.267346,,,,,,,,,,,,,, -34900,0.8203,1.2922775,,,,,,,,,,,,,, -35000,0.7738933,1.2695508,,,,,,,,,,,,,, -35100,0.6987484,1.2435187,,,,,,,,,,,,,, -35145,,,0.18601915,0.0706503207801541,0.45886686,0.1364009384322774,5348.0,0.26721686,0.0905083988381776,2472.0,28843.551401615143,31586.579282283783,28843.551401615143,2740.5739209651947,1.044804573059082,0.0 -35200,0.68581605,1.2634807,,,,,,,,,,,,,, -35300,0.74095464,1.2363974,,,,,,,,,,,,,, -35400,0.96102214,1.2632886,,,,,,,,,,,,,, -35500,0.6701982,1.2777828,,,,,,,,,,,,,, -35600,0.83149433,1.2532722,,,,,,,,,,,,,, -35700,0.71560687,1.1892265,,,,,,,,,,,,,, -35800,0.71078426,1.2755492,,,,,,,,,,,,,, -35900,1.010033,1.299587,,,,,,,,,,,,,, -36000,0.63218343,1.2180768,,,,,,,,,,,,,, -36100,0.7614344,1.2440765,,,,,,,,,,,,,, -36200,0.9422396,1.1649885,,,,,,,,,,,,,, -36300,0.6351352,1.2157235,,,,,,,,,,,,,, -36400,0.57955754,1.1805738,,,,,,,,,,,,,, -36500,0.72621727,1.2099986,,,,,,,,,,,,,, -36600,0.64923155,1.1944325,,,,,,,,,,,,,, -36700,0.64670384,1.2167771,,,,,,,,,,,,,, -36800,0.7087601,1.2380384,,,,,,,,,,,,,, -36900,0.8738222,1.2659013,,,,,,,,,,,,,, -36979,,,0.178505,0.0666595795572328,0.4410436,0.1324618399837802,5348.0,0.25798786,0.0860804744785001,2472.0,30283.677778959274,33157.48137187958,30283.677778959274,2871.2297463417053,1.0952837467193604,0.0 -37000,0.725418,1.2205021,,,,,,,,,,,,,, -37100,0.6575056,1.1771917,,,,,,,,,,,,,, -37200,0.7716887,1.1886548,,,,,,,,,,,,,, -37300,0.81383586,1.2514994,,,,,,,,,,,,,, -37400,1.0985878,1.2402829,,,,,,,,,,,,,, -37500,0.74121773,1.214153,,,,,,,,,,,,,, -37600,0.6770891,1.1881533,,,,,,,,,,,,,, -37700,0.6995174,1.2506131,,,,,,,,,,,,,, -37800,0.79016244,1.2668821,,,,,,,,,,,,,, -37900,0.7199695,1.2084153,,,,,,,,,,,,,, -38000,0.9205315,1.1978189,,,,,,,,,,,,,, -38100,0.6837857,1.2429245,,,,,,,,,,,,,, -38200,0.7436861,1.2462893,,,,,,,,,,,,,, -38300,0.74041796,1.208882,,,,,,,,,,,,,, -38400,0.85181975,1.2595131,,,,,,,,,,,,,, -38500,0.65319157,1.1996772,,,,,,,,,,,,,, -38600,0.65960366,1.1852864,,,,,,,,,,,,,, -38700,0.69046605,1.2508552,,,,,,,,,,,,,, -38800,0.947972,1.2114887,,,,,,,,,,,,,, -38817,,,0.17209044,0.0670704134366925,0.43647054,0.1301929965146702,5348.0,0.25531837,0.0869335608230252,2472.0,31723.92041254044,34728.01073908806,31723.92041254044,3001.384289264679,1.155475616455078,0.0 -38900,0.8401746,1.2492949,,,,,,,,,,,,,, -39000,0.7305332,1.2737061,,,,,,,,,,,,,, -39100,0.739383,1.2175251,,,,,,,,,,,,,, -39200,0.73632324,1.207275,,,,,,,,,,,,,, -39300,0.6304924,1.2012016,,,,,,,,,,,,,, -39400,0.66034406,1.2234669,,,,,,,,,,,,,, -39500,0.69193274,1.2281474,,,,,,,,,,,,,, -39600,0.7376429,1.2530626,,,,,,,,,,,,,, -39700,0.7312758,1.1741728,,,,,,,,,,,,,, -39800,0.74995303,1.2115588,,,,,,,,,,,,,, -39900,0.70996314,1.1949518,,,,,,,,,,,,,, -40000,0.83961934,1.2395463,,,,,,,,,,,,,, -40100,0.8320636,1.2383952,,,,,,,,,,,,,, -40200,0.7617471,1.2396591,,,,,,,,,,,,,, -40300,0.7667733,1.239835,,,,,,,,,,,,,, -40400,0.7998902,1.1901581,,,,,,,,,,,,,, -40500,0.70223755,1.207617,,,,,,,,,,,,,, -40600,0.7881948,1.1643025,,,,,,,,,,,,,, -40660,,,0.18338138,0.0674733290877416,0.4277794,0.1277310599843594,5348.0,0.24571325,0.0827696869985578,2472.0,33163.837792634964,36299.36015820503,33163.837792634964,3132.688977241516,1.2110445499420166,0.0 -40700,0.8033058,1.2487048,,,,,,,,,,,,,, -40800,0.7224223,1.2275112,,,,,,,,,,,,,, -40900,0.65127623,1.1953022,,,,,,,,,,,,,, -41000,0.7529644,1.1740016,,,,,,,,,,,,,, -41100,0.7012825,1.212821,,,,,,,,,,,,,, -41200,0.70105946,1.2135932,,,,,,,,,,,,,, -41300,0.9833216,1.177597,,,,,,,,,,,,,, -41400,0.717149,1.177627,,,,,,,,,,,,,, -41500,0.6995945,1.1565094,,,,,,,,,,,,,, -41600,0.81979436,1.2072822,,,,,,,,,,,,,, -41700,0.8388303,1.2106825,,,,,,,,,,,,,, -41800,0.8032606,1.1757269,,,,,,,,,,,,,, -41900,0.842051,1.1868353,,,,,,,,,,,,,, -42000,0.80579704,1.1480439,,,,,,,,,,,,,, -42100,0.83876437,1.1514516,,,,,,,,,,,,,, -42200,0.82787764,1.1768823,,,,,,,,,,,,,, -42300,0.70777166,1.1502416,,,,,,,,,,,,,, -42400,0.77337104,1.2004328,,,,,,,,,,,,,, -42494,,,0.1785813,0.0671149421050691,0.42371103,0.126253898066173,5348.0,0.2410632,0.0803119858631405,2472.0,34604.3055870533,37871.41095805168,34604.3055870533,3264.1424934864044,1.2691049575805664,0.0 -42500,0.81180245,1.2154238,,,,,,,,,,,,,, -42600,0.7076889,1.1694931,,,,,,,,,,,,,, -42700,0.77373236,1.1924279,,,,,,,,,,,,,, -42800,0.8375398,1.1514891,,,,,,,,,,,,,, -42900,0.83267516,1.1762766,,,,,,,,,,,,,, -43000,0.78902113,1.1697991,,,,,,,,,,,,,, -43100,0.76492774,1.1812124,,,,,,,,,,,,,, -43200,0.734727,1.1527066,,,,,,,,,,,,,, -43300,0.7492245,1.159779,,,,,,,,,,,,,, -43400,0.8192841,1.16976,,,,,,,,,,,,,, -43500,0.76318765,1.1747012,,,,,,,,,,,,,, -43600,0.83128387,1.1790411,,,,,,,,,,,,,, -43700,0.88661575,1.1342114,,,,,,,,,,,,,, -43800,0.74237984,1.1486726,,,,,,,,,,,,,, -43900,0.7552385,1.1794457,,,,,,,,,,,,,, -44000,0.690698,1.103161,,,,,,,,,,,,,, -44100,0.64381546,1.1388427,,,,,,,,,,,,,, -44200,0.68027645,1.1806687,,,,,,,,,,,,,, -44300,0.7444509,1.1656711,,,,,,,,,,,,,, -44317,,,0.15933198,0.06103466511816,0.4057534,0.1218996495360939,5348.0,0.2344677,0.0794182763593524,2472.0,36044.50428843498,39444.46064519882,36044.50428843498,3396.860850095749,1.3300681114196775,0.0 -44400,0.68094933,1.1881871,,,,,,,,,,,,,, -44500,0.7427185,1.1644608,,,,,,,,,,,,,, -44600,0.77268726,1.1653855,,,,,,,,,,,,,, -44700,1.010356,1.1173141,,,,,,,,,,,,,, -44800,0.82051957,1.1684538,,,,,,,,,,,,,, -44900,0.79804033,1.1746596,,,,,,,,,,,,,, -45000,0.7350594,1.1616833,,,,,,,,,,,,,, -45100,0.7245099,1.1224749,,,,,,,,,,,,,, -45200,0.7522,1.1278741,,,,,,,,,,,,,, -45300,0.719603,1.1723237,,,,,,,,,,,,,, -45400,0.968688,1.173511,,,,,,,,,,,,,, -45500,0.8712712,1.121443,,,,,,,,,,,,,, -45600,0.8221362,1.119145,,,,,,,,,,,,,, -45700,0.75291204,1.093239,,,,,,,,,,,,,, -45800,0.75595105,1.1944002,,,,,,,,,,,,,, -45900,0.7443716,1.1185352,,,,,,,,,,,,,, -46000,0.7908984,1.1814252,,,,,,,,,,,,,, -46100,0.80807585,1.116398,,,,,,,,,,,,,, -46154,,,0.15425502,0.0592063748676683,0.39746863,0.1196597700261641,5348.0,0.22770649,0.0764324741535149,2472.0,37484.47999191284,41015.45380330086,37484.47999191284,3527.749038219452,1.389116287231445,0.0 -46200,0.8587368,1.1426219,,,,,,,,,,,,,, -46300,0.8418962,1.1117197,,,,,,,,,,,,,, -46400,0.6929338,1.1651449,,,,,,,,,,,,,, -46500,0.7248685,1.1237905,,,,,,,,,,,,,, -46600,0.6855604,1.1282712,,,,,,,,,,,,,, -46700,0.7745891,1.135457,,,,,,,,,,,,,, -46800,0.7674938,1.1570389,,,,,,,,,,,,,, -46900,0.7461658,1.1204166,,,,,,,,,,,,,, -47000,0.849226,1.1284379,,,,,,,,,,,,,, -47100,1.1192985,1.0594904,,,,,,,,,,,,,, -47200,0.8163098,1.1416225,,,,,,,,,,,,,, -47300,0.74251336,1.176703,,,,,,,,,,,,,, -47400,0.7798837,1.1197524,,,,,,,,,,,,,, -47500,0.930316,1.1330119,,,,,,,,,,,,,, -47600,0.8852346,1.1179426,,,,,,,,,,,,,, -47700,0.9639835,1.1049058,,,,,,,,,,,,,, -47800,0.7446115,1.1533589,,,,,,,,,,,,,, -47900,1.0162448,1.1240308,,,,,,,,,,,,,, -47997,,,0.13490936,0.0518983999577546,0.39513963,0.1172750707203336,5348.0,0.22139996,0.0748481709422541,2472.0,38924.55057263374,42587.35040640831,38924.55057263374,3659.4434146881104,1.4481637477874756,0.0 -48000,0.744648,1.1215062,,,,,,,,,,,,,, -48100,0.88898385,1.1186523,,,,,,,,,,,,,, -48200,0.94224465,1.1837686,,,,,,,,,,,,,, -48300,0.90398395,1.1040998,,,,,,,,,,,,,, -48400,0.7700344,1.1259761,,,,,,,,,,,,,, -48500,0.972609,1.1162591,,,,,,,,,,,,,, -48600,1.0387485,1.1186389,,,,,,,,,,,,,, -48700,0.84731597,1.0942436,,,,,,,,,,,,,, -48800,0.8735725,1.0809972,,,,,,,,,,,,,, -48900,0.92750764,1.0933734,,,,,,,,,,,,,, -49000,0.736801,1.1659393,,,,,,,,,,,,,, -49100,0.7143545,1.1226454,,,,,,,,,,,,,, -49200,0.73160535,1.0940992,,,,,,,,,,,,,, -49300,0.6962025,1.0866094,,,,,,,,,,,,,, -49400,1.0102556,1.1144834,,,,,,,,,,,,,, -49500,0.7286829,1.0616608,,,,,,,,,,,,,, -49600,0.7152963,1.1117129,,,,,,,,,,,,,, -49700,0.7455189,1.0846047,,,,,,,,,,,,,, -49800,0.8141828,1.0039954,,,,,,,,,,,,,, -49845,,,0.13607302,0.0515339115510741,0.38606805,0.1145717678635218,5348.0,0.21825966,0.0737513456421505,2472.0,40364.931837558746,44160.38792181015,40364.931837558746,3791.970413923264,1.503833293914795,0.0 -49900,0.83162594,1.1263117,,,,,,,,,,,,,, -50000,0.8862502,1.0911416,,,,,,,,,,,,,, -50100,1.0139849,1.0763785,,,,,,,,,,,,,, -50200,0.76793534,1.1223379,,,,,,,,,,,,,, -50300,0.73991436,1.1127708,,,,,,,,,,,,,, -50400,0.858866,1.0803857,,,,,,,,,,,,,, -50500,0.9010245,1.0275809,,,,,,,,,,,,,, -50600,0.935896,1.1329614,,,,,,,,,,,,,, -50700,0.88343287,1.0459944,,,,,,,,,,,,,, -50800,0.90250987,1.0566112,,,,,,,,,,,,,, -50900,0.8351449,1.0496141,,,,,,,,,,,,,, -51000,1.0526135,1.1287842,,,,,,,,,,,,,, -51100,0.843772,1.095754,,,,,,,,,,,,,, -51200,0.7652732,1.074385,,,,,,,,,,,,,, -51300,0.9242325,1.140304,,,,,,,,,,,,,, -51400,0.8813339,1.0702902,,,,,,,,,,,,,, -51500,0.83954376,1.1173356,,,,,,,,,,,,,, -51600,0.8277881,1.118638,,,,,,,,,,,,,, -51671,,,0.13753137,0.0521109640474325,0.3752853,0.1104395763538237,5348.0,0.21176472,0.0718217455771535,2472.0,41805.41350221634,45732.29967999458,41805.41350221634,3923.273575782776,1.5587153434753418,0.0 -51700,0.9585651,1.093283,,,,,,,,,,,,,, -51800,0.8357396,1.0560299,,,,,,,,,,,,,, -51900,0.76720273,1.0966074,,,,,,,,,,,,,, -52000,0.8346626,1.0782471,,,,,,,,,,,,,, -52100,0.7678623,1.0447799,,,,,,,,,,,,,, -52200,0.7885657,1.0895077,,,,,,,,,,,,,, -52300,0.89191794,1.0792372,,,,,,,,,,,,,, -52400,0.7751941,1.0937215,,,,,,,,,,,,,, -52500,0.9303035,1.0774288,,,,,,,,,,,,,, -52600,0.8481858,1.0605286,,,,,,,,,,,,,, -52700,0.8260827,1.0497428,,,,,,,,,,,,,, -52800,0.7893319,1.0281987,,,,,,,,,,,,,, -52900,0.7661029,1.0387635,,,,,,,,,,,,,, -53000,1.0852197,1.0599102,,,,,,,,,,,,,, -53100,0.8261377,1.0925245,,,,,,,,,,,,,, -53200,0.8061592,1.0559111,,,,,,,,,,,,,, -53300,0.9712237,1.0661004,,,,,,,,,,,,,, -53400,0.9186549,1.0895435,,,,,,,,,,,,,, -53500,,,0.120480634,0.0451126197781576,0.37096584,0.1088175946397366,5348.0,0.20651865,0.0683890886194219,2472.0,43246.1699757576,47303.89137673378,43246.1699757576,4053.979643106461,1.6170992851257324,0.0 -53500,0.87243974,1.0606108,,,,,,,,,,,,,, -53600,1.1720756,1.0623055,,,,,,,,,,,,,, -53700,0.81481224,0.9982871,,,,,,,,,,,,,, -53800,0.7670862,1.0381204,,,,,,,,,,,,,, -53900,0.88542825,1.0646031,,,,,,,,,,,,,, -54000,0.88267976,1.049318,,,,,,,,,,,,,, -54100,1.0772245,0.99616045,,,,,,,,,,,,,, -54200,0.79760677,1.0512825,,,,,,,,,,,,,, -54300,0.9385167,1.0773432,,,,,,,,,,,,,, -54400,0.8046858,1.0271605,,,,,,,,,,,,,, -54500,0.86920446,1.1226237,,,,,,,,,,,,,, -54600,0.9483814,1.0324589,,,,,,,,,,,,,, -54700,0.8760016,1.0915401,,,,,,,,,,,,,, -54800,0.964738,1.0763612,,,,,,,,,,,,,, -54900,0.9760585,0.99961615,,,,,,,,,,,,,, -55000,0.8369362,1.0441206,,,,,,,,,,,,,, -55100,0.8167888,1.0300945,,,,,,,,,,,,,, -55200,0.91146076,1.0898676,,,,,,,,,,,,,, -55300,0.8805431,1.0446507,,,,,,,,,,,,,, -55323,,,0.119767986,0.0465471313071518,0.35839888,0.1048978054973594,5348.0,0.19546556,0.0659720106432677,2472.0,44686.96074914932,48875.97346353531,44686.96074914932,4185.1421592235565,1.6726500988006592,0.0 -55400,1.0654386,1.0327955,,,,,,,,,,,,,, -55500,0.94246244,1.072404,,,,,,,,,,,,,, -55600,0.9791033,1.049745,,,,,,,,,,,,,, -55700,0.85240436,1.0142015,,,,,,,,,,,,,, -55800,0.95482445,1.0390759,,,,,,,,,,,,,, -55900,0.90079474,1.0493,,,,,,,,,,,,,, -56000,0.97477275,1.028423,,,,,,,,,,,,,, -56100,0.8479377,0.987484,,,,,,,,,,,,,, -56200,1.0113432,1.0641907,,,,,,,,,,,,,, -56300,1.0856696,1.0201887,,,,,,,,,,,,,, -56400,0.9505877,1.0637444,,,,,,,,,,,,,, -56500,0.9434516,1.0109253,,,,,,,,,,,,,, -56600,0.97308826,1.0411991,,,,,,,,,,,,,, -56700,0.90817434,1.0245061,,,,,,,,,,,,,, -56800,0.837648,0.9985648,,,,,,,,,,,,,, -56900,0.8398548,0.9971792,,,,,,,,,,,,,, -57000,0.85176563,1.059987,,,,,,,,,,,,,, -57100,0.8974101,1.0317086,,,,,,,,,,,,,, -57163,,,0.13565873,0.0476978145593217,0.35523295,0.1047529857014588,5348.0,0.19099317,0.0644689537505331,2472.0,46126.98453450203,50448.36845588684,46126.98453450203,4317.38533616066,1.7271640300750732,0.0 -57200,1.008187,0.94925576,,,,,,,,,,,,,, -57300,1.0076562,1.0194374,,,,,,,,,,,,,, -57400,0.96837467,1.0372664,,,,,,,,,,,,,, -57500,1.0470097,1.076487,,,,,,,,,,,,,, -57600,1.010519,1.0354178,,,,,,,,,,,,,, -57700,0.88010937,0.9745457,,,,,,,,,,,,,, -57800,0.9749968,0.9955672,,,,,,,,,,,,,, -57900,0.9585498,1.0473675,,,,,,,,,,,,,, -58000,0.91071665,1.001078,,,,,,,,,,,,,, -58100,0.87327313,1.0392188,,,,,,,,,,,,,, -58200,0.90451443,1.0321473,,,,,,,,,,,,,, -58300,0.9756543,0.9829784,,,,,,,,,,,,,, -58400,0.93280196,0.99483,,,,,,,,,,,,,, -58500,1.0023057,1.0295594,,,,,,,,,,,,,, -58600,0.8563677,0.9859212,,,,,,,,,,,,,, -58700,1.0331166,1.0197868,,,,,,,,,,,,,, -58800,0.9116826,0.99106866,,,,,,,,,,,,,, -58900,1.0245736,0.9967543,,,,,,,,,,,,,, -59000,,,0.0892466,0.0350036959467783,0.3397383,0.0990181217837937,5348.0,0.18731277,0.0619706294558527,2472.0,47567.41899180412,52020.93994116783,47567.41899180412,4449.386815786362,1.7888824939727783,0.0 -59000,0.9482307,1.0056573,,,,,,,,,,,,,, -59100,0.8243064,0.9657101,,,,,,,,,,,,,, -59200,0.87190986,1.0406426,,,,,,,,,,,,,, -59300,1.0940601,0.9958963,,,,,,,,,,,,,, -59400,1.4685034,0.9936838,,,,,,,,,,,,,, -59500,0.9187382,1.0395647,,,,,,,,,,,,,, -59600,1.0002456,0.9521964,,,,,,,,,,,,,, -59700,0.9442919,1.006419,,,,,,,,,,,,,, -59800,0.89049447,0.96681696,,,,,,,,,,,,,, -59900,0.9080448,0.9698381,,,,,,,,,,,,,, -60000,0.9659065,1.0006648,,,,,,,,,,,,,, -60100,0.92337584,0.9847823,,,,,,,,,,,,,, -60200,1.0049245,0.98024106,,,,,,,,,,,,,, -60300,0.9209971,0.98317325,,,,,,,,,,,,,, -60400,1.0035881,1.0247998,,,,,,,,,,,,,, -60500,0.9611712,0.9695714,,,,,,,,,,,,,, -60600,1.0530396,0.9975486,,,,,,,,,,,,,, -60700,1.0077336,0.98024607,,,,,,,,,,,,,, -60800,1.00129,0.95742404,,,,,,,,,,,,,, -60810,,,0.088539205,0.0335192043610336,0.3364041,0.0972127016615657,5348.0,0.18456773,0.059980094651961,2472.0,49007.59355258942,53592.48714256287,49007.59355258942,4580.630652904511,1.845522403717041,0.0 -60900,0.996053,0.98551214,,,,,,,,,,,,,, -61000,1.045146,0.97664475,,,,,,,,,,,,,, -61100,1.0053194,1.0215218,,,,,,,,,,,,,, -61200,1.167322,0.9890876,,,,,,,,,,,,,, -61300,1.0333304,1.0264504,,,,,,,,,,,,,, -61400,0.96363014,1.0248634,,,,,,,,,,,,,, -61500,1.1978118,0.99101216,,,,,,,,,,,,,, -61600,0.96989757,0.9864584,,,,,,,,,,,,,, -61700,1.2796031,0.9443514,,,,,,,,,,,,,, -61800,1.1614032,0.9881383,,,,,,,,,,,,,, -61900,1.1110067,0.9171547,,,,,,,,,,,,,, -62000,0.89865875,0.88560444,,,,,,,,,,,,,, -62100,1.3598614,1.0472819,,,,,,,,,,,,,, -62200,0.9928717,0.973057,,,,,,,,,,,,,, -62300,0.96486664,0.99409586,,,,,,,,,,,,,, -62400,0.91642296,0.91370183,,,,,,,,,,,,,, -62500,0.8920706,0.93685555,,,,,,,,,,,,,, -62600,1.1923854,0.96987754,,,,,,,,,,,,,, -62641,,,0.10741339,0.0409584474735139,0.32483196,0.093843227743611,5348.0,0.17749391,0.0579692482684378,2472.0,50448.088452100754,55162.96579504013,50448.088452100754,4710.4846367836,1.9030003547668457,0.0 -62700,1.0203325,0.9524763,,,,,,,,,,,,,, -62800,1.0633141,0.95181566,,,,,,,,,,,,,, -62900,0.9999192,0.9634365,,,,,,,,,,,,,, -63000,1.0456501,0.978724,,,,,,,,,,,,,, -63100,0.9706815,0.943223,,,,,,,,,,,,,, -63200,1.0637422,0.9514858,,,,,,,,,,,,,, -63300,1.1686275,0.9193041,,,,,,,,,,,,,, -63400,1.0142634,0.95156014,,,,,,,,,,,,,, -63500,0.99253005,0.9386111,,,,,,,,,,,,,, -63600,1.3124723,0.9966669,,,,,,,,,,,,,, -63700,1.0809451,0.91246337,,,,,,,,,,,,,, -63800,1.040159,0.9480788,,,,,,,,,,,,,, -63900,1.0229021,0.95536786,,,,,,,,,,,,,, -64000,1.0718316,0.95658684,,,,,,,,,,,,,, -64100,1.0003899,0.9087108,,,,,,,,,,,,,, -64200,1.0551542,0.8929804,,,,,,,,,,,,,, -64300,0.9856225,0.9217429,,,,,,,,,,,,,, -64400,1.0600796,0.94080937,,,,,,,,,,,,,, -64473,,,0.10360827,0.039054374563204,0.3186228,0.091564729621441,5348.0,0.17470016,0.0573192777202283,2472.0,51888.67530345917,56733.72640347481,51888.67530345917,4840.52444434166,1.964537143707276,0.0 -64500,1.2478008,0.9938206,,,,,,,,,,,,,, -64600,1.1024079,0.9490681,,,,,,,,,,,,,, -64700,1.2971485,0.9558991,,,,,,,,,,,,,, -64800,1.1260825,0.98019254,,,,,,,,,,,,,, -64900,1.0751168,0.93953145,,,,,,,,,,,,,, -65000,1.1863599,0.9860204,,,,,,,,,,,,,, -65100,0.9719559,0.89856833,,,,,,,,,,,,,, -65200,1.1699739,0.8907753,,,,,,,,,,,,,, -65300,0.9903922,0.9438327,,,,,,,,,,,,,, -65400,1.072526,0.914657,,,,,,,,,,,,,, -65500,1.0906029,0.9345518,,,,,,,,,,,,,, -65600,0.9401968,0.9422413,,,,,,,,,,,,,, -65700,1.215109,0.92840546,,,,,,,,,,,,,, -65800,0.97596204,0.9184518,,,,,,,,,,,,,, -65900,1.0943907,0.9246528,,,,,,,,,,,,,, -66000,1.0254571,0.8858833,,,,,,,,,,,,,, -66100,1.3922551,0.92498076,,,,,,,,,,,,,, -66200,1.1591532,0.94658345,,,,,,,,,,,,,, -66300,1.2445158,0.965106,,,,,,,,,,,,,, -66305,,,0.12077924,0.0472280929930092,0.31349313,0.0898268920706334,5348.0,0.16910523,0.0560599597830723,2472.0,53328.80698037148,58303.341222286224,53328.80698037148,4969.87771320343,2.021014928817749,0.0 -66400,1.2606586,0.96594054,,,,,,,,,,,,,, -66500,0.9944219,0.8596481,,,,,,,,,,,,,, -66600,1.0773646,0.9422514,,,,,,,,,,,,,, -66700,1.0428483,0.9173579,,,,,,,,,,,,,, -66800,1.0924313,0.9140485,,,,,,,,,,,,,, -66900,1.2916965,0.94000125,,,,,,,,,,,,,, -67000,1.5589423,0.9378728,,,,,,,,,,,,,, -67100,1.0853179,0.9191372,,,,,,,,,,,,,, -67200,0.99952984,0.8916313,,,,,,,,,,,,,, -67300,1.0281308,0.867456,,,,,,,,,,,,,, -67400,1.5036283,0.9272045,,,,,,,,,,,,,, -67500,1.1575646,0.9387014,,,,,,,,,,,,,, -67600,1.3162476,0.89753604,,,,,,,,,,,,,, -67700,1.0991164,0.89486015,,,,,,,,,,,,,, -67800,1.2018493,0.9006579,,,,,,,,,,,,,, -67900,1.0977341,0.8466623,,,,,,,,,,,,,, -68000,1.0620854,0.88974833,,,,,,,,,,,,,, -68100,1.1141634,0.8534434,,,,,,,,,,,,,, -68111,,,0.09411724,0.0350127446624418,0.30283234,0.0869304961526207,5348.0,0.16567686,0.0543131639347592,2472.0,54768.67781591416,59873.67728209496,54768.67781591416,5100.211203813553,2.079517364501953,0.0 -68200,1.0966836,0.8840797,,,,,,,,,,,,,, -68300,1.3014146,0.9172913,,,,,,,,,,,,,, -68400,1.1718965,0.9311903,,,,,,,,,,,,,, -68500,1.0679768,0.87595993,,,,,,,,,,,,,, -68600,1.0639845,0.89495444,,,,,,,,,,,,,, -68700,1.0964429,0.91236,,,,,,,,,,,,,, -68800,1.097775,0.89917064,,,,,,,,,,,,,, -68900,1.2048591,0.86328715,,,,,,,,,,,,,, -69000,1.0782105,0.88361245,,,,,,,,,,,,,, -69100,1.3484403,0.9274926,,,,,,,,,,,,,, -69200,1.125383,0.90195215,,,,,,,,,,,,,, -69300,1.119098,0.89542687,,,,,,,,,,,,,, -69400,1.0791504,0.8689683,,,,,,,,,,,,,, -69500,1.0699021,0.88111526,,,,,,,,,,,,,, -69600,1.1659647,0.8962378,,,,,,,,,,,,,, -69700,1.3013803,0.8895924,,,,,,,,,,,,,, -69800,1.2226585,0.8482044,,,,,,,,,,,,,, -69900,1.4274813,0.87105316,,,,,,,,,,,,,, -69919,,,0.09114316,0.0354409071780348,0.30461708,0.0860229587649767,5348.0,0.16325577,0.0523023175512359,2472.0,56209.37894654274,61446.25600028038,56209.37894654274,5231.957772254944,2.138704776763916,0.0 -70000,1.199239,0.9012478,,,,,,,,,,,,,, -70100,1.2597286,0.9022399,,,,,,,,,,,,,, -70200,1.0786433,0.8856536,,,,,,,,,,,,,, -70300,1.093302,0.90384877,,,,,,,,,,,,,, -70400,1.376065,0.88341653,,,,,,,,,,,,,, -70500,1.1810757,0.87975746,,,,,,,,,,,,,, -70600,1.0371507,0.8775099,,,,,,,,,,,,,, -70700,1.1599951,0.8862041,,,,,,,,,,,,,, -70800,1.132655,0.9581704,,,,,,,,,,,,,, -70900,1.1954173,0.9075747,,,,,,,,,,,,,, -71000,1.1864449,0.90660226,,,,,,,,,,,,,, -71100,1.080587,0.8743978,,,,,,,,,,,,,, -71200,1.1870294,0.91508275,,,,,,,,,,,,,, -71300,1.2756621,0.86030525,,,,,,,,,,,,,, -71400,1.5441039,0.8738105,,,,,,,,,,,,,, -71500,1.0444884,0.8790699,,,,,,,,,,,,,, -71600,1.1510484,0.8301602,,,,,,,,,,,,,, -71700,1.1711398,0.88986886,,,,,,,,,,,,,, -71727,,,0.06677296,0.0253651451166043,0.29442066,0.0829720883980034,5348.0,0.15865384,0.0511445575122377,2472.0,57650.042269945145,63020.883692502975,57650.042269945145,5365.79065990448,2.1979687213897705,0.0 -71800,1.1544787,0.8530566,,,,,,,,,,,,,, -71900,1.50341,0.82722867,,,,,,,,,,,,,, -72000,1.3416057,0.92506707,,,,,,,,,,,,,, -72100,1.1381657,0.8528194,,,,,,,,,,,,,, -72200,1.0497339,0.85382843,,,,,,,,,,,,,, -72300,1.1976839,0.85380995,,,,,,,,,,,,,, -72400,1.2073328,0.8720283,,,,,,,,,,,,,, -72500,1.1036478,0.8752708,,,,,,,,,,,,,, -72600,1.3592578,0.86959755,,,,,,,,,,,,,, -72700,1.331028,0.86338615,,,,,,,,,,,,,, -72800,1.1334833,0.8777264,,,,,,,,,,,,,, -72900,1.1134368,0.8921795,,,,,,,,,,,,,, -73000,1.2033229,0.89935863,,,,,,,,,,,,,, -73100,1.1079537,0.8373892,,,,,,,,,,,,,, -73200,1.4886532,0.90071946,,,,,,,,,,,,,, -73300,1.3842769,0.8494511,,,,,,,,,,,,,, -73400,1.2012358,0.8466366,,,,,,,,,,,,,, -73500,1.210445,0.83558303,,,,,,,,,,,,,, -73563,,,0.07948388,0.0303998949085375,0.29232764,0.0821417882348397,5348.0,0.15709297,0.0503930290658704,2472.0,59090.54529213905,64594.28295874596,59090.54529213905,5498.555616617203,2.2566795349121094,0.0 -73600,1.1568732,0.8698323,,,,,,,,,,,,,, -73700,1.1941267,0.83972013,,,,,,,,,,,,,, -73800,1.286941,0.85088724,,,,,,,,,,,,,, -73900,1.4836189,0.8688408,,,,,,,,,,,,,, -74000,1.2340769,0.87208074,,,,,,,,,,,,,, -74100,1.6921941,0.8622453,,,,,,,,,,,,,, -74200,1.2632291,0.87821734,,,,,,,,,,,,,, -74300,1.1345729,0.8044813,,,,,,,,,,,,,, -74400,1.1805092,0.8444003,,,,,,,,,,,,,, -74500,1.2603452,0.8454713,,,,,,,,,,,,,, -74600,1.2446022,0.8661894,,,,,,,,,,,,,, -74700,1.2507929,0.86038303,,,,,,,,,,,,,, -74800,1.2459906,0.8874345,,,,,,,,,,,,,, -74900,1.2488753,0.84376717,,,,,,,,,,,,,, -75000,1.1480608,0.81425995,,,,,,,,,,,,,, -75100,1.2382565,0.81966007,,,,,,,,,,,,,, -75200,1.3494326,0.8208674,,,,,,,,,,,,,, -75300,1.6075435,0.8679661,,,,,,,,,,,,,, -75365,,,0.06929343,0.0258329993319973,0.29072267,0.0813887252961564,5348.0,0.15438108,0.0503320943269758,2472.0,60530.85288310051,66167.30356693268,60530.85288310051,5631.136070251465,2.318377017974853,0.0 -75400,1.4259524,0.87921697,,,,,,,,,,,,,, -75500,1.3252176,0.8898231,,,,,,,,,,,,,, -75600,1.0909312,0.851233,,,,,,,,,,,,,, -75700,1.319133,0.8421127,,,,,,,,,,,,,, -75800,1.232341,0.84340125,,,,,,,,,,,,,, -75900,1.4615594,0.87916964,,,,,,,,,,,,,, -76000,1.2416338,0.8203891,,,,,,,,,,,,,, -76049,,,,,,,,,,,61068.082364320755,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/eval_measurements.csv deleted file mode 100644 index c9593e039..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -115.58325505256651,0.0,35.08036541938782,1,0,35.08036541938782,30.208836,2472,0.9085978916580344,150.66368436813354,30.807152,0.9418356167355484,30.141817,5348,0.9043706614402812 -239.9583158493042,0.0279610157012939,1474.9875185489657,1774,0,1474.9875185489657,2.5539382,2472,0.5610870757418804,1715.042043685913,3.0434783,0.6561460888525483,2.9196923,5348,0.6208521196790794 -370.7067046165466,0.0781688690185546,2915.676682949066,3603,0,2915.676682949066,0.9011528,2472,0.2833262242804623,3286.603778839112,1.1530465,0.3532306010182116,1.2137563,5348,0.34704615889628 -501.77516174316406,0.1309382915496826,4356.196928501129,5426,0,4356.196928501129,0.7714357,2472,0.2490605894420409,4858.320121049881,1.018678,0.3176850277429747,1.0831897,5348,0.3157071550633828 -633.5774214267731,0.1800930500030517,5796.309594631195,7221,0,5796.309594631195,0.6883703,2472,0.2286271403327036,6430.357679367065,0.879369,0.284893824790215,0.98021287,5348,0.2919180899234386 -763.5807249546051,0.2328414916992187,7236.248072147369,9054,0,7236.248072147369,0.6508851,2472,0.2148558893425141,8000.426544189453,0.9016759,0.2929333952193084,0.95180404,5348,0.282427565965417 -896.5786190032959,0.2828640937805176,8676.789829492569,10891,0,8676.789829492569,0.65417844,2472,0.2198119147726118,9574.090694189072,0.8089695,0.2691599672493282,0.95127535,5348,0.2845902082508665 -1027.3806738853457,0.3326706886291504,10117.304208755491,12730,0,10117.304208755491,0.6159002,2472,0.2037860784433205,11145.531847715378,0.8335272,0.2669985010888317,0.91189045,5348,0.2689013970282978 -1158.6492466926577,0.3859102725982666,11557.22633099556,14542,0,11557.22633099556,0.6003142,2472,0.1986472487965389,12716.850360155106,0.81453586,0.2609976846979583,0.8834435,5348,0.2635334099269142 -1291.026449918747,0.4386816024780273,12997.759685993196,16361,0,12997.759685993196,0.5641267,2472,0.188532082140028,14289.888085842133,0.7476235,0.2428817993931665,0.8414512,5348,0.250461009683617 -1424.5227282047272,0.4945368766784668,14437.954718589785,18186,0,14437.954718589785,0.5645912,2472,0.1878211768529238,15863.708154201508,0.6964414,0.2297980367821279,0.83934134,5348,0.2496403641735134 -1554.9639530181885,0.5451233386993408,15878.121164798737,20022,0,15878.121164798737,0.54033,2472,0.1794324944650945,17434.440548181534,0.6533358,0.2152976453253372,0.8104945,5348,0.2420614615213802 -1686.3452100753784,0.5957059860229492,17318.338774442673,21846,0,17318.338774442673,0.52948534,2472,0.1805090081855666,19006.164157390594,0.7111166,0.2384679209741439,0.8012005,5348,0.2396188342971895 -1819.2048013210297,0.6472318172454834,18758.5142223835,23656,0,18758.5142223835,0.5160569,2472,0.1741108606016289,20579.324452877045,0.6636916,0.2185747650370637,0.7818518,5348,0.2352549311140504 -1951.8814685344696,0.7016665935516357,20199.02674293518,25469,0,20199.02674293518,0.50244045,2472,0.1679564519732699,22152.642586946487,0.6760904,0.2182878715546995,0.7688105,5348,0.2299931451963273 -2085.1744623184204,0.7583460807800293,21639.47647380829,27307,0,21639.47647380829,0.49013776,2472,0.1643613023784859,23726.517678260803,0.5880113,0.1996346712314533,0.7525327,5348,0.2239589870338009 -2215.55966258049,0.8103816509246826,23079.74537658692,29139,0,23079.74537658692,0.47743583,2472,0.1630613612820669,25297.29909467697,0.5951339,0.2034078153011018,0.74104714,5348,0.2230514496461569 -2345.2983088493347,0.8707478046417236,24520.142600536343,30948,0,24520.142600536343,0.46406707,2472,0.1560944894684459,26867.569379091263,0.5952117,0.2003793207538572,0.70316494,5348,0.2136574722187358 -2487.0760929584503,0.9291160106658936,25960.58037161827,32794,0,25960.58037161827,0.45321754,2472,0.1542461357219751,28449.918375730515,0.42188603,0.1499192788185452,0.6963656,5348,0.2094673527906774 -2621.817850112915,0.9812886714935304,27400.81459593773,34626,0,27400.81459593773,0.42954934,2472,0.145755895435988,30025.02007961273,0.3637514,0.1303431418243586,0.67736506,5348,0.2064261370767641 -2754.01162815094,1.035853385925293,28840.77037382126,36470,0,28840.77037382126,0.41625735,2472,0.1412467247577844,31597.29881477356,0.3567816,0.1272822637748813,0.66085625,5348,0.2003823242611776 -2886.7089653015137,1.089693307876587,30281.36931490898,38286,0,30281.36931490898,0.40459725,2472,0.1378546909593159,33170.72305393219,0.3339406,0.1217827672728213,0.637265,5348,0.1931896077314461 -3021.955439567566,1.148374080657959,31721.717066049576,40096,0,31721.717066049576,0.39331612,2472,0.1330814697459021,34746.4487221241,0.3399412,0.124161671315805,0.62865573,5348,0.1924461994458229 -3157.081080198288,1.2041044235229492,33162.12484502792,41912,0,33162.12484502792,0.3796361,2472,0.1285316759084354,36322.11321473122,0.3111691,0.1121283985854748,0.60674435,5348,0.1846162758141286 -3290.895377635956,1.264575481414795,34602.23737287521,43754,0,34602.23737287521,0.36802,2472,0.1236162736376007,37896.17702317238,0.3599174,0.1250047824399735,0.59408325,5348,0.1794317271208859 -3423.787575244904,1.323909044265747,36042.2100276947,45578,0,36042.2100276947,0.35863748,2472,0.1223975788597079,39469.1763818264,0.29631695,0.1095491215770804,0.5799549,5348,0.1743340703051835 -3557.1513023376465,1.3801522254943848,37485.10423183441,47381,0,37485.10423183441,0.3452232,2472,0.1170962565758739,41045.56417584419,0.28037357,0.1038601149552454,0.56814826,5348,0.1708294312443882 -3692.1489946842194,1.4441523551940918,38925.50800943375,49209,0,38925.50800943375,0.33619353,2472,0.1127495785347226,42621.10450196266,0.2589628,0.0947917933854136,0.55032986,5348,0.168048891163096 -3825.167426109314,1.4977757930755615,40365.552735090256,51041,0,40365.552735090256,0.32483763,2472,0.1091950520992017,44194.29576420784,0.24425602,0.0925137498728103,0.53289413,5348,0.1615223456945074 -3959.577538490296,1.5582191944122314,41805.85911726952,52875,0,41805.85911726952,0.31024233,2472,0.1048483740580504,45769.149518728256,0.25342596,0.0934737874636159,0.51060355,5348,0.1543682477770161 -4092.6102225780487,1.617872714996338,43246.37415957451,54678,0,43246.37415957451,0.29198828,2472,0.0984299148944813,47342.83117246628,0.23427062,0.085687837923127,0.4877075,5348,0.148855440879732 -4226.572660923004,1.6772680282592771,44686.48836612701,56498,0,44686.48836612701,0.28381628,2472,0.0955862937460646,48917.04101586342,0.22292952,0.0828723003094216,0.46993592,5348,0.143149540921247 -4359.694350004196,1.7328250408172607,46126.97168445587,58324,0,46126.97168445587,0.27153856,2472,0.091117746227124,50490.77713441849,0.2015512,0.0759762241959345,0.45896965,5348,0.1386118539830271 -4491.721256494522,1.7927451133728027,47566.9754588604,60165,0,47566.9754588604,0.26202387,2472,0.0866288871285519,52062.943118810654,0.18237127,0.0681878160264327,0.4402338,5348,0.132664587698041 -4625.361068725586,1.8530685901641848,49006.85757493973,61990,0,49006.85757493973,0.2465062,2472,0.082850933317084,53636.60034799576,0.16777992,0.0616180573167638,0.42649555,5348,0.1305309093717717 -4757.340369462967,1.9127659797668457,50447.22711467743,63797,0,50447.22711467743,0.23887035,2472,0.0810838258891394,55209.0825612545,0.16459136,0.0626281978427584,0.41153353,5348,0.1246319163520858 -4889.8466360569,1.9750051498413088,51887.10120534897,65617,0,51887.10120534897,0.22475353,2472,0.0755184530700952,56781.60054159165,0.14450066,0.0529541422504166,0.3941793,5348,0.118501211658959 -5022.408258914948,2.0354931354522705,53327.44103908539,67460,0,53327.44103908539,0.21985038,2472,0.0726748319216785,58354.63788485527,0.14300667,0.0543315273329536,0.38965482,5348,0.1168116473734516 -5153.785512447357,2.1095762252807617,54767.72738194466,69292,0,54767.72738194466,0.20976654,2472,0.0695468486584201,59926.45112419128,0.15330945,0.0531289329793872,0.3706188,5348,0.1106809426803247 -5287.613696575165,2.169984579086304,56208.11917066574,71111,0,56208.11917066574,0.2021894,2472,0.0670079012044766,61500.805604457855,0.104285374,0.0391234694455748,0.36181018,5348,0.1078038560684321 -5422.824957847595,2.234661340713501,57648.46272945404,72948,0,57648.46272945404,0.19403726,2472,0.0655657790506367,63076.50105428696,0.102518566,0.0376658462535337,0.34724754,5348,0.1040095774158355 -5556.046092987061,2.296834707260132,59088.35154867172,74785,0,59088.35154867172,0.18991898,2472,0.063575244246745,64649.74883103371,0.12474066,0.0474371254207602,0.34358847,5348,0.1028124004363903 -5688.373752355576,2.361799716949463,60528.67248296738,76633,0,60528.67248296738,0.18748328,2472,0.062112810513273616,66222.53850531578,0.12209161,0.04441446214856531,0.33922857,5348,0.10118076406924317 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/measurements.csv deleted file mode 100644 index 37bd27a3b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/measurements.csv +++ /dev/null @@ -1,819 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,51.327003,31.740528,,,,,,,,,,,,,, -1,,,30.807152,0.9418356167355484,30.141817,0.9043706614402812,5348.0,30.208836,0.9085978916580344,2472.0,35.08036541938782,150.66368436813354,35.08036541938782,115.58325505256651,0.0,0.0 -100,2.913668,5.805973,,,,,,,,,,,,,, -200,3.2766166,5.8001566,,,,,,,,,,,,,, -300,2.977524,5.621324,,,,,,,,,,,,,, -400,2.0815504,5.553656,,,,,,,,,,,,,, -500,1.3978415,5.501116,,,,,,,,,,,,,, -600,2.8577,5.4544997,,,,,,,,,,,,,, -700,1.5390961,5.2930565,,,,,,,,,,,,,, -800,0.9051763,4.6282926,,,,,,,,,,,,,, -900,1.5698236,4.0717797,,,,,,,,,,,,,, -1000,2.9744227,3.8075185,,,,,,,,,,,,,, -1100,2.7680771,3.523741,,,,,,,,,,,,,, -1200,0.95023733,3.3035824,,,,,,,,,,,,,, -1300,1.0069748,3.1543443,,,,,,,,,,,,,, -1400,1.3719654,3.0891993,,,,,,,,,,,,,, -1500,1.383478,2.9872732,,,,,,,,,,,,,, -1600,1.6706014,2.9225397,,,,,,,,,,,,,, -1700,0.8865236,2.8426466,,,,,,,,,,,,,, -1774,,,3.0434783,0.6561460888525483,2.9196923,0.6208521196790794,5348.0,2.5539382,0.5610870757418804,2472.0,1474.9875185489657,1715.042043685913,1474.9875185489657,239.9583158493042,0.0279610157012939,0.0 -1800,1.703599,2.72208,,,,,,,,,,,,,, -1900,1.9691,2.6466093,,,,,,,,,,,,,, -2000,1.6619976,2.607233,,,,,,,,,,,,,, -2100,0.817355,2.5399368,,,,,,,,,,,,,, -2200,0.54391444,2.5041473,,,,,,,,,,,,,, -2300,1.1969945,2.411631,,,,,,,,,,,,,, -2400,0.54971176,2.3861024,,,,,,,,,,,,,, -2500,0.9204755,2.4227216,,,,,,,,,,,,,, -2600,0.76504195,2.3566017,,,,,,,,,,,,,, -2700,0.834093,2.3417137,,,,,,,,,,,,,, -2800,1.2266383,2.3304985,,,,,,,,,,,,,, -2900,0.55304825,2.278685,,,,,,,,,,,,,, -3000,1.5233542,2.3568156,,,,,,,,,,,,,, -3100,0.89315695,2.2034905,,,,,,,,,,,,,, -3200,0.70686543,2.1746578,,,,,,,,,,,,,, -3300,0.91939986,2.1629481,,,,,,,,,,,,,, -3400,1.3110164,2.281736,,,,,,,,,,,,,, -3500,0.9752169,2.1983469,,,,,,,,,,,,,, -3600,1.1243263,2.1857727,,,,,,,,,,,,,, -3603,,,1.1530465,0.3532306010182116,1.2137563,0.34704615889628,5348.0,0.9011528,0.2833262242804623,2472.0,2915.676682949066,3286.603778839112,2915.676682949066,370.7067046165466,0.0781688690185546,0.0 -3700,1.202691,2.1782768,,,,,,,,,,,,,, -3800,1.0794083,2.0966969,,,,,,,,,,,,,, -3900,0.83553517,2.1477234,,,,,,,,,,,,,, -4000,0.60595125,2.0410094,,,,,,,,,,,,,, -4100,0.56699646,2.1134748,,,,,,,,,,,,,, -4200,0.9313837,2.127647,,,,,,,,,,,,,, -4300,0.53222746,2.1254845,,,,,,,,,,,,,, -4400,0.4619082,2.081369,,,,,,,,,,,,,, -4500,0.81363046,2.0556529,,,,,,,,,,,,,, -4600,0.6906248,2.1118674,,,,,,,,,,,,,, -4700,1.113687,2.0058906,,,,,,,,,,,,,, -4800,0.79699165,2.112159,,,,,,,,,,,,,, -4900,0.6151718,2.0543532,,,,,,,,,,,,,, -5000,0.4673009,2.091731,,,,,,,,,,,,,, -5100,0.46728483,2.054836,,,,,,,,,,,,,, -5200,0.45539615,2.0208457,,,,,,,,,,,,,, -5300,1.0065519,2.0555391,,,,,,,,,,,,,, -5400,0.5852523,2.0470254,,,,,,,,,,,,,, -5426,,,1.018678,0.3176850277429747,1.0831897,0.3157071550633828,5348.0,0.7714357,0.2490605894420409,2472.0,4356.196928501129,4858.320121049881,4356.196928501129,501.77516174316406,0.1309382915496826,0.0 -5500,0.6827183,2.006588,,,,,,,,,,,,,, -5600,0.42355144,2.0151355,,,,,,,,,,,,,, -5700,0.4408083,1.985532,,,,,,,,,,,,,, -5800,0.947965,2.0499578,,,,,,,,,,,,,, -5900,1.2669332,2.076466,,,,,,,,,,,,,, -6000,0.6864027,1.9806731,,,,,,,,,,,,,, -6100,0.5851519,2.0084774,,,,,,,,,,,,,, -6200,1.1451197,2.0131803,,,,,,,,,,,,,, -6300,0.97971493,2.0475862,,,,,,,,,,,,,, -6400,1.1484877,2.0134847,,,,,,,,,,,,,, -6500,0.6929764,1.9671475,,,,,,,,,,,,,, -6600,0.5756703,2.1039984,,,,,,,,,,,,,, -6700,0.8218309,2.0075676,,,,,,,,,,,,,, -6800,1.0014747,1.9893292,,,,,,,,,,,,,, -6900,0.44774848,1.9478574,,,,,,,,,,,,,, -7000,0.4524958,2.004385,,,,,,,,,,,,,, -7100,0.8279339,1.9793674,,,,,,,,,,,,,, -7200,0.49063668,1.9985273,,,,,,,,,,,,,, -7221,,,0.879369,0.284893824790215,0.98021287,0.2919180899234386,5348.0,0.6883703,0.2286271403327036,2472.0,5796.309594631195,6430.357679367065,5796.309594631195,633.5774214267731,0.1800930500030517,0.0 -7300,0.82989126,1.9313495,,,,,,,,,,,,,, -7400,0.4967496,1.9620832,,,,,,,,,,,,,, -7500,1.3355808,1.943529,,,,,,,,,,,,,, -7600,0.5431243,1.9878259,,,,,,,,,,,,,, -7700,0.7110138,2.0015595,,,,,,,,,,,,,, -7800,0.78698754,1.9891019,,,,,,,,,,,,,, -7900,0.8989867,1.9495841,,,,,,,,,,,,,, -8000,1.1112984,1.9871505,,,,,,,,,,,,,, -8100,0.74919343,1.8665923,,,,,,,,,,,,,, -8200,0.7766685,1.9862654,,,,,,,,,,,,,, -8300,0.55153227,1.9133166,,,,,,,,,,,,,, -8400,0.6367904,1.9220923,,,,,,,,,,,,,, -8500,1.0959033,1.9328719,,,,,,,,,,,,,, -8600,1.1801794,1.9536079,,,,,,,,,,,,,, -8700,0.97409284,1.9791505,,,,,,,,,,,,,, -8800,1.13105,1.9538475,,,,,,,,,,,,,, -8900,0.52997446,1.8336642,,,,,,,,,,,,,, -9000,1.3637114,1.9023397,,,,,,,,,,,,,, -9054,,,0.9016759,0.2929333952193084,0.95180404,0.282427565965417,5348.0,0.6508851,0.2148558893425141,2472.0,7236.248072147369,8000.426544189453,7236.248072147369,763.5807249546051,0.2328414916992187,0.0 -9100,0.9469202,1.9818598,,,,,,,,,,,,,, -9200,0.5431623,1.8582536,,,,,,,,,,,,,, -9300,0.9771902,1.911231,,,,,,,,,,,,,, -9400,1.2743657,1.9404975,,,,,,,,,,,,,, -9500,0.7260515,1.849168,,,,,,,,,,,,,, -9600,0.5719195,1.9367403,,,,,,,,,,,,,, -9700,0.84212327,1.9187403,,,,,,,,,,,,,, -9800,0.6883667,1.8944147,,,,,,,,,,,,,, -9900,0.6059759,1.9315456,,,,,,,,,,,,,, -10000,0.44589344,1.9013149,,,,,,,,,,,,,, -10100,0.8065694,1.85485,,,,,,,,,,,,,, -10200,0.6716825,1.8854276,,,,,,,,,,,,,, -10300,0.5932027,1.8760154,,,,,,,,,,,,,, -10400,0.53399986,1.8676547,,,,,,,,,,,,,, -10500,0.8689937,1.8917665,,,,,,,,,,,,,, -10600,0.7043835,1.9007016,,,,,,,,,,,,,, -10700,0.5891786,1.8882465,,,,,,,,,,,,,, -10800,0.91637945,1.9216013,,,,,,,,,,,,,, -10891,,,0.8089695,0.2691599672493282,0.95127535,0.2845902082508665,5348.0,0.65417844,0.2198119147726118,2472.0,8676.789829492569,9574.090694189072,8676.789829492569,896.5786190032959,0.2828640937805176,0.0 -10900,0.61794,1.9202918,,,,,,,,,,,,,, -11000,0.95334214,1.931894,,,,,,,,,,,,,, -11100,0.8095213,1.8987658,,,,,,,,,,,,,, -11200,0.84295315,1.8680053,,,,,,,,,,,,,, -11300,1.1629004,1.916806,,,,,,,,,,,,,, -11400,0.7171881,1.803393,,,,,,,,,,,,,, -11500,0.95748836,1.9272224,,,,,,,,,,,,,, -11600,1.0685259,1.8662354,,,,,,,,,,,,,, -11700,0.687069,1.8842753,,,,,,,,,,,,,, -11800,0.7949505,1.9663959,,,,,,,,,,,,,, -11900,0.5806151,1.8539473,,,,,,,,,,,,,, -12000,0.6835972,1.8577424,,,,,,,,,,,,,, -12100,1.1626582,1.84497,,,,,,,,,,,,,, -12200,1.2230796,1.883193,,,,,,,,,,,,,, -12300,0.521605,1.8973385,,,,,,,,,,,,,, -12400,0.9298602,1.8611755,,,,,,,,,,,,,, -12500,0.4599402,1.8634907,,,,,,,,,,,,,, -12600,1.0508263,1.8889992,,,,,,,,,,,,,, -12700,0.7980777,1.8587403,,,,,,,,,,,,,, -12730,,,0.8335272,0.2669985010888317,0.91189045,0.2689013970282978,5348.0,0.6159002,0.2037860784433205,2472.0,10117.304208755491,11145.531847715378,10117.304208755491,1027.3806738853457,0.3326706886291504,0.0 -12800,0.6119388,1.8811579,,,,,,,,,,,,,, -12900,0.8831468,1.8577437,,,,,,,,,,,,,, -13000,0.5288061,1.8428662,,,,,,,,,,,,,, -13100,0.846746,1.8468062,,,,,,,,,,,,,, -13200,1.143792,1.9157108,,,,,,,,,,,,,, -13300,0.6577989,1.899914,,,,,,,,,,,,,, -13400,0.5561591,1.8478537,,,,,,,,,,,,,, -13500,0.59122694,1.8148593,,,,,,,,,,,,,, -13600,0.5655393,1.8592705,,,,,,,,,,,,,, -13700,0.5679241,1.8313689,,,,,,,,,,,,,, -13800,1.1557288,1.8841279,,,,,,,,,,,,,, -13900,0.8711048,1.8063773,,,,,,,,,,,,,, -14000,0.7251461,1.8507825,,,,,,,,,,,,,, -14100,0.55307156,1.8684943,,,,,,,,,,,,,, -14200,0.82419455,1.8540386,,,,,,,,,,,,,, -14300,0.611316,1.8185737,,,,,,,,,,,,,, -14400,0.94481647,1.8818074,,,,,,,,,,,,,, -14500,0.65584004,1.8890027,,,,,,,,,,,,,, -14542,,,0.81453586,0.2609976846979583,0.8834435,0.2635334099269142,5348.0,0.6003142,0.1986472487965389,2472.0,11557.22633099556,12716.850360155106,11557.22633099556,1158.6492466926577,0.3859102725982666,0.0 -14600,0.72067493,1.8115175,,,,,,,,,,,,,, -14700,0.64380914,1.8773477,,,,,,,,,,,,,, -14800,0.5608746,1.8015373,,,,,,,,,,,,,, -14900,0.5312583,1.7533324,,,,,,,,,,,,,, -15000,0.46424142,1.7624584,,,,,,,,,,,,,, -15100,0.69547933,1.8117535,,,,,,,,,,,,,, -15200,0.619486,1.846811,,,,,,,,,,,,,, -15300,0.5706891,1.8575708,,,,,,,,,,,,,, -15400,0.7759898,1.8649248,,,,,,,,,,,,,, -15500,0.61407393,1.7756233,,,,,,,,,,,,,, -15600,0.7135708,1.8291885,,,,,,,,,,,,,, -15700,0.63143426,1.7923969,,,,,,,,,,,,,, -15800,0.9729713,1.8088325,,,,,,,,,,,,,, -15900,0.5549796,1.8339323,,,,,,,,,,,,,, -16000,0.8298087,1.8705714,,,,,,,,,,,,,, -16100,0.88915825,1.8019555,,,,,,,,,,,,,, -16200,0.57445794,1.8624668,,,,,,,,,,,,,, -16300,0.50259984,1.7999908,,,,,,,,,,,,,, -16361,,,0.7476235,0.2428817993931665,0.8414512,0.250461009683617,5348.0,0.5641267,0.188532082140028,2472.0,12997.759685993196,14289.888085842133,12997.759685993196,1291.026449918747,0.4386816024780273,0.0 -16400,0.53067756,1.7877907,,,,,,,,,,,,,, -16500,0.9323449,1.8039126,,,,,,,,,,,,,, -16600,0.88735926,1.8124669,,,,,,,,,,,,,, -16700,0.6991669,1.8032122,,,,,,,,,,,,,, -16800,0.5076196,1.7652291,,,,,,,,,,,,,, -16900,0.7745582,1.8502293,,,,,,,,,,,,,, -17000,0.86640847,1.8144792,,,,,,,,,,,,,, -17100,0.8483526,1.8532598,,,,,,,,,,,,,, -17200,0.5303563,1.845219,,,,,,,,,,,,,, -17300,0.54011357,1.7757604,,,,,,,,,,,,,, -17400,0.5674009,1.7435912,,,,,,,,,,,,,, -17500,0.80416274,1.7791476,,,,,,,,,,,,,, -17600,1.1582975,1.7644836,,,,,,,,,,,,,, -17700,0.6721361,1.7973208,,,,,,,,,,,,,, -17800,0.74542046,1.7852972,,,,,,,,,,,,,, -17900,0.7759851,1.833181,,,,,,,,,,,,,, -18000,0.7393186,1.7765172,,,,,,,,,,,,,, -18100,0.52688336,1.7445853,,,,,,,,,,,,,, -18186,,,0.6964414,0.2297980367821279,0.83934134,0.2496403641735134,5348.0,0.5645912,0.1878211768529238,2472.0,14437.954718589785,15863.708154201508,14437.954718589785,1424.5227282047272,0.4945368766784668,0.0 -18200,0.55605423,1.8075576,,,,,,,,,,,,,, -18300,0.8362481,1.7453485,,,,,,,,,,,,,, -18400,0.8159907,1.8243685,,,,,,,,,,,,,, -18500,0.58316994,1.7655607,,,,,,,,,,,,,, -18600,0.6536486,1.7343827,,,,,,,,,,,,,, -18700,0.92747945,1.7704242,,,,,,,,,,,,,, -18800,0.65024096,1.7860705,,,,,,,,,,,,,, -18900,0.8204114,1.7705508,,,,,,,,,,,,,, -19000,0.5769952,1.692316,,,,,,,,,,,,,, -19100,0.6069065,1.7279881,,,,,,,,,,,,,, -19200,0.5096139,1.7438036,,,,,,,,,,,,,, -19300,0.46548524,1.7760133,,,,,,,,,,,,,, -19400,0.4795043,1.7705082,,,,,,,,,,,,,, -19500,0.6292451,1.7826132,,,,,,,,,,,,,, -19600,0.7016655,1.7311311,,,,,,,,,,,,,, -19700,0.39650995,1.7629564,,,,,,,,,,,,,, -19800,0.6566356,1.7037264,,,,,,,,,,,,,, -19900,0.552849,1.859302,,,,,,,,,,,,,, -20000,0.48245773,1.7343941,,,,,,,,,,,,,, -20022,,,0.6533358,0.2152976453253372,0.8104945,0.2420614615213802,5348.0,0.54033,0.1794324944650945,2472.0,15878.121164798737,17434.440548181534,15878.121164798737,1554.9639530181885,0.5451233386993408,0.0 -20100,0.4749735,1.7532417,,,,,,,,,,,,,, -20200,0.490057,1.7308912,,,,,,,,,,,,,, -20300,0.6168751,1.7639229,,,,,,,,,,,,,, -20400,0.5464795,1.793289,,,,,,,,,,,,,, -20500,0.92285764,1.7773228,,,,,,,,,,,,,, -20600,0.49655595,1.7414757,,,,,,,,,,,,,, -20700,0.8339886,1.7892241,,,,,,,,,,,,,, -20800,0.6293452,1.7828859,,,,,,,,,,,,,, -20900,0.7279699,1.6988559,,,,,,,,,,,,,, -21000,0.42501193,1.7383898,,,,,,,,,,,,,, -21100,0.51816696,1.714195,,,,,,,,,,,,,, -21200,0.62416255,1.6829342,,,,,,,,,,,,,, -21300,0.6988433,1.780429,,,,,,,,,,,,,, -21400,0.80381066,1.769983,,,,,,,,,,,,,, -21500,0.7712105,1.7451708,,,,,,,,,,,,,, -21600,0.49676996,1.722162,,,,,,,,,,,,,, -21700,0.54543245,1.7234735,,,,,,,,,,,,,, -21800,0.586447,1.6757189,,,,,,,,,,,,,, -21846,,,0.7111166,0.2384679209741439,0.8012005,0.2396188342971895,5348.0,0.52948534,0.1805090081855666,2472.0,17318.338774442673,19006.164157390594,17318.338774442673,1686.3452100753784,0.5957059860229492,0.0 -21900,0.6685477,1.7284203,,,,,,,,,,,,,, -22000,0.7685714,1.6815062,,,,,,,,,,,,,, -22100,1.1376926,1.7033525,,,,,,,,,,,,,, -22200,0.5286547,1.6743426,,,,,,,,,,,,,, -22300,0.5871335,1.694069,,,,,,,,,,,,,, -22400,0.7379253,1.7646329,,,,,,,,,,,,,, -22500,0.8721402,1.7285645,,,,,,,,,,,,,, -22600,0.59021246,1.7177376,,,,,,,,,,,,,, -22700,0.7492581,1.7138901,,,,,,,,,,,,,, -22800,1.2233378,1.7410043,,,,,,,,,,,,,, -22900,0.70081294,1.7081022,,,,,,,,,,,,,, -23000,0.6739941,1.6432426,,,,,,,,,,,,,, -23100,0.6528473,1.7262324,,,,,,,,,,,,,, -23200,0.5960381,1.6539807,,,,,,,,,,,,,, -23300,0.5625639,1.6984009,,,,,,,,,,,,,, -23400,0.66782236,1.7219467,,,,,,,,,,,,,, -23500,0.67952216,1.6353394,,,,,,,,,,,,,, -23600,0.90375,1.6629369,,,,,,,,,,,,,, -23656,,,0.6636916,0.2185747650370637,0.7818518,0.2352549311140504,5348.0,0.5160569,0.1741108606016289,2472.0,18758.5142223835,20579.324452877045,18758.5142223835,1819.2048013210297,0.6472318172454834,0.0 -23700,0.46445453,1.6833566,,,,,,,,,,,,,, -23800,0.5998236,1.6629833,,,,,,,,,,,,,, -23900,0.59297615,1.6704756,,,,,,,,,,,,,, -24000,0.592544,1.6667918,,,,,,,,,,,,,, -24100,0.61294997,1.7317476,,,,,,,,,,,,,, -24200,0.62836343,1.7154711,,,,,,,,,,,,,, -24300,0.6776429,1.7019973,,,,,,,,,,,,,, -24400,0.5113494,1.713785,,,,,,,,,,,,,, -24500,1.0191822,1.6887292,,,,,,,,,,,,,, -24600,0.62875426,1.6771,,,,,,,,,,,,,, -24700,0.58573836,1.6144296,,,,,,,,,,,,,, -24800,0.70666426,1.666807,,,,,,,,,,,,,, -24900,0.6277035,1.6781363,,,,,,,,,,,,,, -25000,0.6511078,1.6801989,,,,,,,,,,,,,, -25100,0.70205164,1.6494359,,,,,,,,,,,,,, -25200,0.64736766,1.6685765,,,,,,,,,,,,,, -25300,0.5905246,1.6595529,,,,,,,,,,,,,, -25400,0.72406244,1.7155771,,,,,,,,,,,,,, -25469,,,0.6760904,0.2182878715546995,0.7688105,0.2299931451963273,5348.0,0.50244045,0.1679564519732699,2472.0,20199.02674293518,22152.642586946487,20199.02674293518,1951.8814685344696,0.7016665935516357,0.0 -25500,0.85330164,1.6926093,,,,,,,,,,,,,, -25600,0.69440657,1.6690494,,,,,,,,,,,,,, -25700,0.4973908,1.6387848,,,,,,,,,,,,,, -25800,0.5415485,1.6509285,,,,,,,,,,,,,, -25900,0.8984592,1.6503458,,,,,,,,,,,,,, -26000,0.50268185,1.6559131,,,,,,,,,,,,,, -26100,0.6409499,1.6827762,,,,,,,,,,,,,, -26200,0.73344505,1.6353124,,,,,,,,,,,,,, -26300,0.4889843,1.6309925,,,,,,,,,,,,,, -26400,0.7632051,1.6831001,,,,,,,,,,,,,, -26500,0.66133213,1.6674851,,,,,,,,,,,,,, -26600,0.8392673,1.7085662,,,,,,,,,,,,,, -26700,0.88871473,1.672987,,,,,,,,,,,,,, -26800,0.5382022,1.6648728,,,,,,,,,,,,,, -26900,0.8677697,1.5947778,,,,,,,,,,,,,, -27000,0.60707796,1.6632917,,,,,,,,,,,,,, -27100,0.5140589,1.6003625,,,,,,,,,,,,,, -27200,0.6902258,1.6901382,,,,,,,,,,,,,, -27300,0.6499044,1.608729,,,,,,,,,,,,,, -27307,,,0.5880113,0.1996346712314533,0.7525327,0.2239589870338009,5348.0,0.49013776,0.1643613023784859,2472.0,21639.47647380829,23726.517678260803,21639.47647380829,2085.1744623184204,0.7583460807800293,0.0 -27400,0.56264955,1.6641963,,,,,,,,,,,,,, -27500,0.572481,1.7125641,,,,,,,,,,,,,, -27600,0.8044392,1.7095886,,,,,,,,,,,,,, -27700,0.5559614,1.6915101,,,,,,,,,,,,,, -27800,0.48723826,1.6694726,,,,,,,,,,,,,, -27900,0.5673756,1.674396,,,,,,,,,,,,,, -28000,0.6383665,1.6206815,,,,,,,,,,,,,, -28100,0.7518555,1.6679006,,,,,,,,,,,,,, -28200,0.58260685,1.7088901,,,,,,,,,,,,,, -28300,0.51730233,1.6044501,,,,,,,,,,,,,, -28400,0.842471,1.6692071,,,,,,,,,,,,,, -28500,0.537613,1.6369145,,,,,,,,,,,,,, -28600,0.8815822,1.6391906,,,,,,,,,,,,,, -28700,0.7764774,1.6078032,,,,,,,,,,,,,, -28800,0.70501196,1.6583275,,,,,,,,,,,,,, -28900,0.5888702,1.6344721,,,,,,,,,,,,,, -29000,0.6539966,1.6137102,,,,,,,,,,,,,, -29100,0.72623473,1.6326274,,,,,,,,,,,,,, -29139,,,0.5951339,0.2034078153011018,0.74104714,0.2230514496461569,5348.0,0.47743583,0.1630613612820669,2472.0,23079.74537658692,25297.29909467697,23079.74537658692,2215.55966258049,0.8103816509246826,0.0 -29200,0.77388364,1.6568322,,,,,,,,,,,,,, -29300,0.7050679,1.6100172,,,,,,,,,,,,,, -29400,0.7263696,1.7039609,,,,,,,,,,,,,, -29500,0.62415427,1.5712515,,,,,,,,,,,,,, -29600,0.62073076,1.6859756,,,,,,,,,,,,,, -29700,0.77889717,1.6655651,,,,,,,,,,,,,, -29800,0.8542127,1.651523,,,,,,,,,,,,,, -29900,0.6615421,1.6319263,,,,,,,,,,,,,, -30000,0.6330873,1.700094,,,,,,,,,,,,,, -30100,0.69541514,1.655886,,,,,,,,,,,,,, -30200,0.52361935,1.6496606,,,,,,,,,,,,,, -30300,0.5234537,1.6573215,,,,,,,,,,,,,, -30400,0.5742333,1.6910859,,,,,,,,,,,,,, -30500,0.524227,1.598502,,,,,,,,,,,,,, -30600,0.7536979,1.6417145,,,,,,,,,,,,,, -30700,0.79151046,1.639414,,,,,,,,,,,,,, -30800,0.53911114,1.6017203,,,,,,,,,,,,,, -30900,0.64194083,1.5820123,,,,,,,,,,,,,, -30948,,,0.5952117,0.2003793207538572,0.70316494,0.2136574722187358,5348.0,0.46406707,0.1560944894684459,2472.0,24520.142600536343,26867.569379091263,24520.142600536343,2345.2983088493347,0.8707478046417236,0.0 -31000,0.56543756,1.5520201,,,,,,,,,,,,,, -31100,0.5778311,1.6270963,,,,,,,,,,,,,, -31200,0.49403065,1.5815468,,,,,,,,,,,,,, -31300,0.700703,1.5699613,,,,,,,,,,,,,, -31400,0.6080138,1.590259,,,,,,,,,,,,,, -31500,0.5652554,1.5774058,,,,,,,,,,,,,, -31600,0.55582255,1.585872,,,,,,,,,,,,,, -31700,0.53576916,1.5859749,,,,,,,,,,,,,, -31800,0.5398196,1.6021472,,,,,,,,,,,,,, -31900,0.7132159,1.5913842,,,,,,,,,,,,,, -32000,0.5018901,1.6226218,,,,,,,,,,,,,, -32100,0.85972726,1.6451851,,,,,,,,,,,,,, -32200,0.64131904,1.63665,,,,,,,,,,,,,, -32300,0.51106143,1.6322812,,,,,,,,,,,,,, -32400,0.7760323,1.6763325,,,,,,,,,,,,,, -32500,0.5544656,1.6303899,,,,,,,,,,,,,, -32600,0.56395817,1.623508,,,,,,,,,,,,,, -32700,0.5741079,1.618805,,,,,,,,,,,,,, -32794,,,0.42188603,0.1499192788185452,0.6963656,0.2094673527906774,5348.0,0.45321754,0.1542461357219751,2472.0,25960.58037161827,28449.918375730515,25960.58037161827,2487.0760929584503,0.9291160106658936,0.0 -32800,0.6085827,1.5959072,,,,,,,,,,,,,, -32900,0.53558373,1.5909418,,,,,,,,,,,,,, -33000,0.54013175,1.5348554,,,,,,,,,,,,,, -33100,0.69327295,1.5526385,,,,,,,,,,,,,, -33200,0.68255985,1.6007674,,,,,,,,,,,,,, -33300,0.7771495,1.5802835,,,,,,,,,,,,,, -33400,0.6026241,1.5589583,,,,,,,,,,,,,, -33500,0.81019646,1.6081332,,,,,,,,,,,,,, -33600,0.7693665,1.6089083,,,,,,,,,,,,,, -33700,0.8018136,1.665018,,,,,,,,,,,,,, -33800,0.693323,1.6312289,,,,,,,,,,,,,, -33900,0.60108477,1.6465071,,,,,,,,,,,,,, -34000,0.5158423,1.6293296,,,,,,,,,,,,,, -34100,0.5893047,1.5986843,,,,,,,,,,,,,, -34200,0.52929825,1.6060623,,,,,,,,,,,,,, -34300,0.84561425,1.5440214,,,,,,,,,,,,,, -34400,0.5388568,1.5128303,,,,,,,,,,,,,, -34500,0.61213523,1.5857455,,,,,,,,,,,,,, -34600,0.5229687,1.5277838,,,,,,,,,,,,,, -34626,,,0.3637514,0.1303431418243586,0.67736506,0.2064261370767641,5348.0,0.42954934,0.145755895435988,2472.0,27400.81459593773,30025.02007961273,27400.81459593773,2621.817850112915,0.9812886714935304,0.0 -34700,0.54556376,1.5500728,,,,,,,,,,,,,, -34800,0.6863294,1.605403,,,,,,,,,,,,,, -34900,0.74005586,1.648523,,,,,,,,,,,,,, -35000,0.89914227,1.5751978,,,,,,,,,,,,,, -35100,0.5457148,1.5378454,,,,,,,,,,,,,, -35200,0.67998695,1.6305326,,,,,,,,,,,,,, -35300,0.7202959,1.5174433,,,,,,,,,,,,,, -35400,0.75203,1.5525279,,,,,,,,,,,,,, -35500,0.7903938,1.5393288,,,,,,,,,,,,,, -35600,0.5526961,1.5025932,,,,,,,,,,,,,, -35700,0.5919677,1.5807761,,,,,,,,,,,,,, -35800,0.54790086,1.5532157,,,,,,,,,,,,,, -35900,0.6585145,1.4791616,,,,,,,,,,,,,, -36000,0.65117085,1.5703042,,,,,,,,,,,,,, -36100,0.6462351,1.5356885,,,,,,,,,,,,,, -36200,0.6310339,1.5152655,,,,,,,,,,,,,, -36300,0.54163355,1.5497395,,,,,,,,,,,,,, -36400,0.8583407,1.5444729,,,,,,,,,,,,,, -36470,,,0.3567816,0.1272822637748813,0.66085625,0.2003823242611776,5348.0,0.41625735,0.1412467247577844,2472.0,28840.77037382126,31597.29881477356,28840.77037382126,2754.01162815094,1.035853385925293,0.0 -36500,0.69363123,1.5381091,,,,,,,,,,,,,, -36600,0.62297994,1.5756178,,,,,,,,,,,,,, -36700,0.7251446,1.5568233,,,,,,,,,,,,,, -36800,0.65332997,1.564483,,,,,,,,,,,,,, -36900,0.56551504,1.6306052,,,,,,,,,,,,,, -37000,0.6566304,1.4950521,,,,,,,,,,,,,, -37100,0.61614716,1.5200486,,,,,,,,,,,,,, -37200,0.6302459,1.4397792,,,,,,,,,,,,,, -37300,0.6573528,1.5916916,,,,,,,,,,,,,, -37400,0.65352035,1.5270865,,,,,,,,,,,,,, -37500,0.8318183,1.5685625,,,,,,,,,,,,,, -37600,0.5504072,1.5517159,,,,,,,,,,,,,, -37700,0.6737422,1.53652,,,,,,,,,,,,,, -37800,0.5987488,1.619592,,,,,,,,,,,,,, -37900,0.6002708,1.4837403,,,,,,,,,,,,,, -38000,0.6413483,1.512328,,,,,,,,,,,,,, -38100,0.6004978,1.5836982,,,,,,,,,,,,,, -38200,0.54188126,1.5130466,,,,,,,,,,,,,, -38286,,,0.3339406,0.1217827672728213,0.637265,0.1931896077314461,5348.0,0.40459725,0.1378546909593159,2472.0,30281.36931490898,33170.72305393219,30281.36931490898,2886.7089653015137,1.089693307876587,0.0 -38300,0.5084936,1.4743023,,,,,,,,,,,,,, -38400,0.6617227,1.5506628,,,,,,,,,,,,,, -38500,0.52845013,1.5347056,,,,,,,,,,,,,, -38600,0.796306,1.5687679,,,,,,,,,,,,,, -38700,0.55620676,1.5359303,,,,,,,,,,,,,, -38800,0.5749717,1.4988546,,,,,,,,,,,,,, -38900,0.82209986,1.5466498,,,,,,,,,,,,,, -39000,0.67861414,1.4919834,,,,,,,,,,,,,, -39100,0.5827904,1.4819719,,,,,,,,,,,,,, -39200,0.6072674,1.5470148,,,,,,,,,,,,,, -39300,0.5516229,1.5509852,,,,,,,,,,,,,, -39400,0.5892535,1.549576,,,,,,,,,,,,,, -39500,0.8001724,1.529076,,,,,,,,,,,,,, -39600,0.51915807,1.509276,,,,,,,,,,,,,, -39700,0.6303524,1.488005,,,,,,,,,,,,,, -39800,0.62814844,1.6134666,,,,,,,,,,,,,, -39900,0.63285476,1.5080453,,,,,,,,,,,,,, -40000,0.6635528,1.517246,,,,,,,,,,,,,, -40096,,,0.3399412,0.124161671315805,0.62865573,0.1924461994458229,5348.0,0.39331612,0.1330814697459021,2472.0,31721.717066049576,34746.4487221241,31721.717066049576,3021.955439567566,1.148374080657959,0.0 -40100,0.6320931,1.4990722,,,,,,,,,,,,,, -40200,0.6092756,1.531443,,,,,,,,,,,,,, -40300,0.6002133,1.5348799,,,,,,,,,,,,,, -40400,0.60053545,1.4855245,,,,,,,,,,,,,, -40500,0.6322985,1.4768386,,,,,,,,,,,,,, -40600,0.53012127,1.4861016,,,,,,,,,,,,,, -40700,0.52923214,1.4854908,,,,,,,,,,,,,, -40800,0.5832207,1.5682174,,,,,,,,,,,,,, -40900,0.77702165,1.5152016,,,,,,,,,,,,,, -41000,0.5774953,1.5022589,,,,,,,,,,,,,, -41100,0.5473748,1.5311173,,,,,,,,,,,,,, -41200,0.5970183,1.462345,,,,,,,,,,,,,, -41300,0.586713,1.4798713,,,,,,,,,,,,,, -41400,0.52176046,1.4659872,,,,,,,,,,,,,, -41500,0.6966562,1.4733217,,,,,,,,,,,,,, -41600,0.57920605,1.5100381,,,,,,,,,,,,,, -41700,0.6659059,1.5063338,,,,,,,,,,,,,, -41800,0.5228604,1.4972005,,,,,,,,,,,,,, -41900,0.60627466,1.5153185,,,,,,,,,,,,,, -41912,,,0.3111691,0.1121283985854748,0.60674435,0.1846162758141286,5348.0,0.3796361,0.1285316759084354,2472.0,33162.12484502792,36322.11321473122,33162.12484502792,3157.081080198288,1.2041044235229492,0.0 -42000,0.65668267,1.4915215,,,,,,,,,,,,,, -42100,0.6661582,1.4758549,,,,,,,,,,,,,, -42200,0.5494689,1.4586205,,,,,,,,,,,,,, -42300,0.57896054,1.3973415,,,,,,,,,,,,,, -42400,0.5783878,1.4763744,,,,,,,,,,,,,, -42500,0.58203584,1.5136688,,,,,,,,,,,,,, -42600,0.6794426,1.4313034,,,,,,,,,,,,,, -42700,0.5787635,1.440918,,,,,,,,,,,,,, -42800,0.63071173,1.5548732,,,,,,,,,,,,,, -42900,0.6783899,1.4672838,,,,,,,,,,,,,, -43000,0.6246107,1.4953843,,,,,,,,,,,,,, -43100,0.5904269,1.4864334,,,,,,,,,,,,,, -43200,0.5517198,1.4526262,,,,,,,,,,,,,, -43300,0.64088446,1.418514,,,,,,,,,,,,,, -43400,0.56171596,1.4478427,,,,,,,,,,,,,, -43500,0.5762009,1.4230007,,,,,,,,,,,,,, -43600,0.5636622,1.4616323,,,,,,,,,,,,,, -43700,0.7592205,1.4340973,,,,,,,,,,,,,, -43754,,,0.3599174,0.1250047824399735,0.59408325,0.1794317271208859,5348.0,0.36802,0.1236162736376007,2472.0,34602.23737287521,37896.17702317238,34602.23737287521,3290.895377635956,1.264575481414795,0.0 -43800,0.5632923,1.4125475,,,,,,,,,,,,,, -43900,0.73674065,1.4897991,,,,,,,,,,,,,, -44000,0.59586835,1.4640781,,,,,,,,,,,,,, -44100,0.68186736,1.4753861,,,,,,,,,,,,,, -44200,0.4669968,1.3877573,,,,,,,,,,,,,, -44300,0.7073276,1.4262375,,,,,,,,,,,,,, -44400,0.7170594,1.4263173,,,,,,,,,,,,,, -44500,0.53696287,1.4212651,,,,,,,,,,,,,, -44600,0.61003745,1.4578779,,,,,,,,,,,,,, -44700,0.7965466,1.4380767,,,,,,,,,,,,,, -44800,0.58822477,1.441013,,,,,,,,,,,,,, -44900,0.722539,1.444533,,,,,,,,,,,,,, -45000,0.7238267,1.4449998,,,,,,,,,,,,,, -45100,0.5367155,1.4422916,,,,,,,,,,,,,, -45200,0.56610364,1.4828632,,,,,,,,,,,,,, -45300,0.6993363,1.4871379,,,,,,,,,,,,,, -45400,0.5689739,1.3850113,,,,,,,,,,,,,, -45500,0.6716793,1.4138645,,,,,,,,,,,,,, -45578,,,0.29631695,0.1095491215770804,0.5799549,0.1743340703051835,5348.0,0.35863748,0.1223975788597079,2472.0,36042.2100276947,39469.1763818264,36042.2100276947,3423.787575244904,1.323909044265747,0.0 -45600,0.5439074,1.4029955,,,,,,,,,,,,,, -45700,0.6067779,1.3961271,,,,,,,,,,,,,, -45800,0.5217631,1.3838228,,,,,,,,,,,,,, -45900,0.6758232,1.3935859,,,,,,,,,,,,,, -46000,0.66037184,1.4566948,,,,,,,,,,,,,, -46100,0.6414954,1.4314363,,,,,,,,,,,,,, -46200,0.5641779,1.4360601,,,,,,,,,,,,,, -46300,0.51350063,1.433366,,,,,,,,,,,,,, -46400,0.67092437,1.432351,,,,,,,,,,,,,, -46500,0.6747395,1.4055735,,,,,,,,,,,,,, -46600,0.60638,1.3976365,,,,,,,,,,,,,, -46700,0.6099412,1.3997356,,,,,,,,,,,,,, -46800,0.73953617,1.4323114,,,,,,,,,,,,,, -46900,0.70863384,1.4037446,,,,,,,,,,,,,, -47000,0.6139029,1.4728261,,,,,,,,,,,,,, -47100,0.5293919,1.416657,,,,,,,,,,,,,, -47200,0.5225181,1.3890966,,,,,,,,,,,,,, -47300,0.5556242,1.4469036,,,,,,,,,,,,,, -47381,,,0.28037357,0.1038601149552454,0.56814826,0.1708294312443882,5348.0,0.3452232,0.1170962565758739,2472.0,37485.10423183441,41045.56417584419,37485.10423183441,3557.1513023376465,1.3801522254943848,0.0 -47400,0.73777217,1.4493718,,,,,,,,,,,,,, -47500,0.6995498,1.3903809,,,,,,,,,,,,,, -47600,0.7349516,1.4197286,,,,,,,,,,,,,, -47700,0.5802082,1.338615,,,,,,,,,,,,,, -47800,0.6105017,1.409671,,,,,,,,,,,,,, -47900,0.6030326,1.3696599,,,,,,,,,,,,,, -48000,0.56905884,1.4510547,,,,,,,,,,,,,, -48100,0.6732094,1.4101095,,,,,,,,,,,,,, -48200,0.75385594,1.4250755,,,,,,,,,,,,,, -48300,0.6505101,1.4082067,,,,,,,,,,,,,, -48400,0.6811464,1.4088674,,,,,,,,,,,,,, -48500,0.64118123,1.4671876,,,,,,,,,,,,,, -48600,0.5875039,1.331689,,,,,,,,,,,,,, -48700,0.55052805,1.3879247,,,,,,,,,,,,,, -48800,0.61133146,1.418819,,,,,,,,,,,,,, -48900,0.56857306,1.327094,,,,,,,,,,,,,, -49000,0.6085536,1.4310832,,,,,,,,,,,,,, -49100,0.68897474,1.3930004,,,,,,,,,,,,,, -49200,0.71188754,1.3398856,,,,,,,,,,,,,, -49209,,,0.2589628,0.0947917933854136,0.55032986,0.168048891163096,5348.0,0.33619353,0.1127495785347226,2472.0,38925.50800943375,42621.10450196266,38925.50800943375,3692.1489946842194,1.4441523551940918,0.0 -49300,0.56328213,1.377641,,,,,,,,,,,,,, -49400,0.5936495,1.395055,,,,,,,,,,,,,, -49500,0.5111541,1.3204637,,,,,,,,,,,,,, -49600,0.597788,1.3953347,,,,,,,,,,,,,, -49700,0.73254246,1.3640622,,,,,,,,,,,,,, -49800,0.49661633,1.3599956,,,,,,,,,,,,,, -49900,0.58088547,1.3669306,,,,,,,,,,,,,, -50000,0.6494016,1.3550643,,,,,,,,,,,,,, -50100,0.61704,1.3504424,,,,,,,,,,,,,, -50200,0.7181895,1.3679338,,,,,,,,,,,,,, -50300,0.51491505,1.4288867,,,,,,,,,,,,,, -50400,0.67130065,1.3531685,,,,,,,,,,,,,, -50500,0.59644043,1.304628,,,,,,,,,,,,,, -50600,0.65420115,1.3192599,,,,,,,,,,,,,, -50700,0.5817225,1.3717226,,,,,,,,,,,,,, -50800,0.63702655,1.3605392,,,,,,,,,,,,,, -50900,0.54441494,1.3219731,,,,,,,,,,,,,, -51000,0.70496494,1.333759,,,,,,,,,,,,,, -51041,,,0.24425602,0.0925137498728103,0.53289413,0.1615223456945074,5348.0,0.32483763,0.1091950520992017,2472.0,40365.552735090256,44194.29576420784,40365.552735090256,3825.167426109314,1.4977757930755615,0.0 -51100,0.59743476,1.3375663,,,,,,,,,,,,,, -51200,0.6030359,1.3248862,,,,,,,,,,,,,, -51300,0.7362929,1.3639101,,,,,,,,,,,,,, -51400,0.63906616,1.3176054,,,,,,,,,,,,,, -51500,0.7521683,1.3986435,,,,,,,,,,,,,, -51600,0.55919045,1.3954648,,,,,,,,,,,,,, -51700,0.60288876,1.3539258,,,,,,,,,,,,,, -51800,0.6935627,1.3542186,,,,,,,,,,,,,, -51900,0.58238995,1.3276953,,,,,,,,,,,,,, -52000,0.640564,1.3254468,,,,,,,,,,,,,, -52100,0.54543376,1.3182033,,,,,,,,,,,,,, -52200,0.7519768,1.3212022,,,,,,,,,,,,,, -52300,0.58560526,1.3558621,,,,,,,,,,,,,, -52400,0.67560345,1.3235716,,,,,,,,,,,,,, -52500,0.606716,1.332109,,,,,,,,,,,,,, -52600,0.670301,1.3293281,,,,,,,,,,,,,, -52700,0.5864488,1.2406703,,,,,,,,,,,,,, -52800,0.6222474,1.3142364,,,,,,,,,,,,,, -52875,,,0.25342596,0.0934737874636159,0.51060355,0.1543682477770161,5348.0,0.31024233,0.1048483740580504,2472.0,41805.85911726952,45769.149518728256,41805.85911726952,3959.577538490296,1.5582191944122314,0.0 -52900,0.67467993,1.3565847,,,,,,,,,,,,,, -53000,0.60591567,1.3228161,,,,,,,,,,,,,, -53100,0.544691,1.2957393,,,,,,,,,,,,,, -53200,0.6078536,1.3339669,,,,,,,,,,,,,, -53300,0.69970274,1.3221562,,,,,,,,,,,,,, -53400,0.6431786,1.3429627,,,,,,,,,,,,,, -53500,0.7471424,1.3595021,,,,,,,,,,,,,, -53600,0.64088726,1.3426609,,,,,,,,,,,,,, -53700,0.6155944,1.3127445,,,,,,,,,,,,,, -53800,0.641812,1.2637413,,,,,,,,,,,,,, -53900,0.5992783,1.2727455,,,,,,,,,,,,,, -54000,0.6092431,1.333279,,,,,,,,,,,,,, -54100,0.69384235,1.3733543,,,,,,,,,,,,,, -54200,0.70812714,1.301016,,,,,,,,,,,,,, -54300,0.6491245,1.3475729,,,,,,,,,,,,,, -54400,0.71510553,1.2687907,,,,,,,,,,,,,, -54500,0.5446024,1.3417094,,,,,,,,,,,,,, -54600,0.59157014,1.3027961,,,,,,,,,,,,,, -54678,,,0.23427062,0.085687837923127,0.4877075,0.148855440879732,5348.0,0.29198828,0.0984299148944813,2472.0,43246.37415957451,47342.83117246628,43246.37415957451,4092.6102225780487,1.617872714996338,0.0 -54700,0.7195359,1.2883898,,,,,,,,,,,,,, -54800,0.74695385,1.307672,,,,,,,,,,,,,, -54900,0.7035978,1.3165784,,,,,,,,,,,,,, -55000,0.5608879,1.2928109,,,,,,,,,,,,,, -55100,0.5938328,1.2483628,,,,,,,,,,,,,, -55200,0.6433116,1.343066,,,,,,,,,,,,,, -55300,0.791208,1.2938536,,,,,,,,,,,,,, -55400,0.6170634,1.2721597,,,,,,,,,,,,,, -55500,0.5540912,1.3042065,,,,,,,,,,,,,, -55600,0.72811234,1.2948841,,,,,,,,,,,,,, -55700,0.66845304,1.2950158,,,,,,,,,,,,,, -55800,0.65897787,1.2670617,,,,,,,,,,,,,, -55900,0.6148419,1.2667934,,,,,,,,,,,,,, -56000,0.60642594,1.2861813,,,,,,,,,,,,,, -56100,0.5864798,1.2523928,,,,,,,,,,,,,, -56200,0.6858789,1.2991472,,,,,,,,,,,,,, -56300,0.6816345,1.2629466,,,,,,,,,,,,,, -56400,0.72152025,1.2871435,,,,,,,,,,,,,, -56498,,,0.22292952,0.0828723003094216,0.46993592,0.143149540921247,5348.0,0.28381628,0.0955862937460646,2472.0,44686.48836612701,48917.04101586342,44686.48836612701,4226.572660923004,1.6772680282592771,0.0 -56500,0.6131751,1.2427813,,,,,,,,,,,,,, -56600,0.66203517,1.2901803,,,,,,,,,,,,,, -56700,0.72320765,1.2893987,,,,,,,,,,,,,, -56800,0.7688947,1.2499127,,,,,,,,,,,,,, -56900,0.6438947,1.2934632,,,,,,,,,,,,,, -57000,0.68677026,1.2992966,,,,,,,,,,,,,, -57100,0.694051,1.2925084,,,,,,,,,,,,,, -57200,0.66772074,1.232877,,,,,,,,,,,,,, -57300,0.6891138,1.2547469,,,,,,,,,,,,,, -57400,0.6490779,1.2466675,,,,,,,,,,,,,, -57500,0.60407376,1.291769,,,,,,,,,,,,,, -57600,0.71455455,1.298284,,,,,,,,,,,,,, -57700,0.81679875,1.2190745,,,,,,,,,,,,,, -57800,0.5828423,1.1989758,,,,,,,,,,,,,, -57900,0.6651895,1.1833586,,,,,,,,,,,,,, -58000,0.58603567,1.2635372,,,,,,,,,,,,,, -58100,0.65635157,1.2807457,,,,,,,,,,,,,, -58200,0.65142643,1.2272575,,,,,,,,,,,,,, -58300,0.7267477,1.1912472,,,,,,,,,,,,,, -58324,,,0.2015512,0.0759762241959345,0.45896965,0.1386118539830271,5348.0,0.27153856,0.091117746227124,2472.0,46126.97168445587,50490.77713441849,46126.97168445587,4359.694350004196,1.7328250408172607,0.0 -58400,0.6116981,1.1562301,,,,,,,,,,,,,, -58500,0.6804753,1.2676667,,,,,,,,,,,,,, -58600,0.6839953,1.2138867,,,,,,,,,,,,,, -58700,0.6604306,1.2617921,,,,,,,,,,,,,, -58800,0.6997333,1.245908,,,,,,,,,,,,,, -58900,0.6318478,1.2358847,,,,,,,,,,,,,, -59000,0.64357656,1.1976029,,,,,,,,,,,,,, -59100,0.73189193,1.2348123,,,,,,,,,,,,,, -59200,0.7422421,1.2588351,,,,,,,,,,,,,, -59300,0.64752233,1.2396246,,,,,,,,,,,,,, -59400,0.579466,1.2575034,,,,,,,,,,,,,, -59500,0.6530127,1.2298617,,,,,,,,,,,,,, -59600,0.67387295,1.1799213,,,,,,,,,,,,,, -59700,0.60832626,1.1933068,,,,,,,,,,,,,, -59800,0.77345645,1.1768535,,,,,,,,,,,,,, -59900,0.6502624,1.1806228,,,,,,,,,,,,,, -60000,0.8718834,1.1896955,,,,,,,,,,,,,, -60100,0.6726531,1.2300817,,,,,,,,,,,,,, -60165,,,0.18237127,0.0681878160264327,0.4402338,0.132664587698041,5348.0,0.26202387,0.0866288871285519,2472.0,47566.9754588604,52062.943118810654,47566.9754588604,4491.721256494522,1.7927451133728027,0.0 -60200,0.6582402,1.2753795,,,,,,,,,,,,,, -60300,0.5806591,1.2005786,,,,,,,,,,,,,, -60400,0.7284846,1.2862487,,,,,,,,,,,,,, -60500,0.709738,1.2011731,,,,,,,,,,,,,, -60600,0.7005088,1.2147253,,,,,,,,,,,,,, -60700,0.60131365,1.2005471,,,,,,,,,,,,,, -60800,0.6838778,1.2040174,,,,,,,,,,,,,, -60900,0.5960218,1.1814259,,,,,,,,,,,,,, -61000,0.536582,1.1566672,,,,,,,,,,,,,, -61100,0.6872671,1.225426,,,,,,,,,,,,,, -61200,0.6327037,1.205772,,,,,,,,,,,,,, -61300,0.7936497,1.1739637,,,,,,,,,,,,,, -61400,0.60486686,1.2041211,,,,,,,,,,,,,, -61500,0.78612065,1.200063,,,,,,,,,,,,,, -61600,0.69126827,1.2273495,,,,,,,,,,,,,, -61700,0.77302647,1.2110595,,,,,,,,,,,,,, -61800,0.6232179,1.1785948,,,,,,,,,,,,,, -61900,0.6964097,1.1512002,,,,,,,,,,,,,, -61990,,,0.16777992,0.0616180573167638,0.42649555,0.1305309093717717,5348.0,0.2465062,0.082850933317084,2472.0,49006.85757493973,53636.60034799576,49006.85757493973,4625.361068725586,1.8530685901641848,0.0 -62000,0.6090993,1.1497018,,,,,,,,,,,,,, -62100,0.6358119,1.1722574,,,,,,,,,,,,,, -62200,0.607289,1.1471202,,,,,,,,,,,,,, -62300,0.6681402,1.1422541,,,,,,,,,,,,,, -62400,0.7000225,1.1818067,,,,,,,,,,,,,, -62500,0.64917314,1.1443447,,,,,,,,,,,,,, -62600,0.7613389,1.15214,,,,,,,,,,,,,, -62700,0.83174574,1.152633,,,,,,,,,,,,,, -62800,0.70676935,1.1570948,,,,,,,,,,,,,, -62900,0.60457116,1.1626838,,,,,,,,,,,,,, -63000,0.66204125,1.1639979,,,,,,,,,,,,,, -63100,0.68672293,1.1602216,,,,,,,,,,,,,, -63200,0.68991554,1.0976713,,,,,,,,,,,,,, -63300,0.74610275,1.1478624,,,,,,,,,,,,,, -63400,0.9591877,1.108963,,,,,,,,,,,,,, -63500,0.77767575,1.1982839,,,,,,,,,,,,,, -63600,0.7248619,1.1473352,,,,,,,,,,,,,, -63700,0.6165833,1.1884965,,,,,,,,,,,,,, -63797,,,0.16459136,0.0626281978427584,0.41153353,0.1246319163520858,5348.0,0.23887035,0.0810838258891394,2472.0,50447.22711467743,55209.0825612545,50447.22711467743,4757.340369462967,1.9127659797668457,0.0 -63800,0.61731356,1.1280618,,,,,,,,,,,,,, -63900,0.70925844,1.1318284,,,,,,,,,,,,,, -64000,0.7568588,1.1611062,,,,,,,,,,,,,, -64100,0.6342375,1.1270039,,,,,,,,,,,,,, -64200,0.74421144,1.0870446,,,,,,,,,,,,,, -64300,0.8172056,1.1345189,,,,,,,,,,,,,, -64400,0.7056772,1.1339618,,,,,,,,,,,,,, -64500,0.6023361,1.1672606,,,,,,,,,,,,,, -64600,0.71206635,1.1763303,,,,,,,,,,,,,, -64700,0.7543056,1.1765484,,,,,,,,,,,,,, -64800,0.68500584,1.1586502,,,,,,,,,,,,,, -64900,0.65382767,1.1367716,,,,,,,,,,,,,, -65000,0.7676573,1.1260567,,,,,,,,,,,,,, -65100,0.6982565,1.125841,,,,,,,,,,,,,, -65200,0.66017234,1.0707169,,,,,,,,,,,,,, -65300,0.69958997,1.165775,,,,,,,,,,,,,, -65400,0.75423783,1.1142029,,,,,,,,,,,,,, -65500,0.73023736,1.1260608,,,,,,,,,,,,,, -65600,0.69882023,1.1135348,,,,,,,,,,,,,, -65617,,,0.14450066,0.0529541422504166,0.3941793,0.118501211658959,5348.0,0.22475353,0.0755184530700952,2472.0,51887.10120534897,56781.60054159165,51887.10120534897,4889.8466360569,1.9750051498413088,0.0 -65700,0.6393573,1.1273513,,,,,,,,,,,,,, -65800,0.709927,1.0965521,,,,,,,,,,,,,, -65900,0.6905479,1.1074939,,,,,,,,,,,,,, -66000,0.6673729,1.1073254,,,,,,,,,,,,,, -66100,0.7314462,1.1133634,,,,,,,,,,,,,, -66200,0.7616949,1.1520414,,,,,,,,,,,,,, -66300,0.8181208,1.1285167,,,,,,,,,,,,,, -66400,0.81365097,1.1013573,,,,,,,,,,,,,, -66500,0.86371094,1.0752962,,,,,,,,,,,,,, -66600,0.74403405,1.0901927,,,,,,,,,,,,,, -66700,0.77267724,1.0975875,,,,,,,,,,,,,, -66800,0.75732744,1.1388834,,,,,,,,,,,,,, -66900,0.7995586,1.1124759,,,,,,,,,,,,,, -67000,0.80111545,1.1084447,,,,,,,,,,,,,, -67100,0.95766205,1.0945783,,,,,,,,,,,,,, -67200,0.70285267,1.1101931,,,,,,,,,,,,,, -67300,0.7104508,1.0656794,,,,,,,,,,,,,, -67400,0.82650125,1.1270844,,,,,,,,,,,,,, -67460,,,0.14300667,0.0543315273329536,0.38965482,0.1168116473734516,5348.0,0.21985038,0.0726748319216785,2472.0,53327.44103908539,58354.63788485527,53327.44103908539,5022.408258914948,2.0354931354522705,0.0 -67500,0.7315819,1.1174433,,,,,,,,,,,,,, -67600,0.7402416,1.0900646,,,,,,,,,,,,,, -67700,0.8044973,1.0917314,,,,,,,,,,,,,, -67800,0.7442979,1.1427422,,,,,,,,,,,,,, -67900,0.73482037,1.0833863,,,,,,,,,,,,,, -68000,0.8029623,1.0307181,,,,,,,,,,,,,, -68100,0.77808416,1.0986763,,,,,,,,,,,,,, -68200,0.9871705,1.0699506,,,,,,,,,,,,,, -68300,0.72399646,1.0653716,,,,,,,,,,,,,, -68400,0.75947607,1.0401003,,,,,,,,,,,,,, -68500,0.73082,1.0209628,,,,,,,,,,,,,, -68600,0.6597005,1.0420498,,,,,,,,,,,,,, -68700,0.742617,1.0517086,,,,,,,,,,,,,, -68800,0.7862022,1.0544578,,,,,,,,,,,,,, -68900,0.7393884,1.0542759,,,,,,,,,,,,,, -69000,0.8115779,1.0749578,,,,,,,,,,,,,, -69100,0.7956187,1.0836157,,,,,,,,,,,,,, -69200,0.77349585,1.0528122,,,,,,,,,,,,,, -69292,,,0.15330945,0.0531289329793872,0.3706188,0.1106809426803247,5348.0,0.20976654,0.0695468486584201,2472.0,54767.72738194466,59926.45112419128,54767.72738194466,5153.785512447357,2.1095762252807617,0.0 -69300,0.683197,1.0957847,,,,,,,,,,,,,, -69400,0.86631536,1.0599173,,,,,,,,,,,,,, -69500,0.77128685,1.0298897,,,,,,,,,,,,,, -69600,0.7882759,1.0611669,,,,,,,,,,,,,, -69700,0.8977627,1.0374769,,,,,,,,,,,,,, -69800,0.7480784,1.0142518,,,,,,,,,,,,,, -69900,0.9024469,1.0310861,,,,,,,,,,,,,, -70000,0.7268849,1.0653293,,,,,,,,,,,,,, -70100,0.94106865,1.038154,,,,,,,,,,,,,, -70200,0.6719468,1.0047674,,,,,,,,,,,,,, -70300,0.7080009,0.9951775,,,,,,,,,,,,,, -70400,0.98249876,1.0044976,,,,,,,,,,,,,, -70500,0.75803494,1.0022928,,,,,,,,,,,,,, -70600,1.0036314,1.0354042,,,,,,,,,,,,,, -70700,0.7693166,1.0782042,,,,,,,,,,,,,, -70800,0.7609993,1.079936,,,,,,,,,,,,,, -70900,0.71963155,1.0629996,,,,,,,,,,,,,, -71000,0.93555737,1.0694674,,,,,,,,,,,,,, -71100,1.3394235,1.0497632,,,,,,,,,,,,,, -71111,,,0.104285374,0.0391234694455748,0.36181018,0.1078038560684321,5348.0,0.2021894,0.0670079012044766,2472.0,56208.11917066574,61500.805604457855,56208.11917066574,5287.613696575165,2.169984579086304,0.0 -71200,0.74110407,1.0880355,,,,,,,,,,,,,, -71300,0.8416505,0.99938756,,,,,,,,,,,,,, -71400,1.0781791,1.0363173,,,,,,,,,,,,,, -71500,1.20892,1.0274804,,,,,,,,,,,,,, -71600,0.77216244,0.97109133,,,,,,,,,,,,,, -71700,0.7282132,1.0128669,,,,,,,,,,,,,, -71800,0.88187534,0.9794724,,,,,,,,,,,,,, -71900,0.727707,0.9846788,,,,,,,,,,,,,, -72000,0.7170913,1.0888572,,,,,,,,,,,,,, -72100,0.8427117,1.0180436,,,,,,,,,,,,,, -72200,0.7403946,1.0419414,,,,,,,,,,,,,, -72300,0.9113353,1.0236164,,,,,,,,,,,,,, -72400,0.7569747,0.98098844,,,,,,,,,,,,,, -72500,0.77490723,1.0042626,,,,,,,,,,,,,, -72600,0.715964,1.0214233,,,,,,,,,,,,,, -72700,0.79462993,1.0044783,,,,,,,,,,,,,, -72800,0.6904888,1.0314177,,,,,,,,,,,,,, -72900,0.8648653,1.0520935,,,,,,,,,,,,,, -72948,,,0.102518566,0.0376658462535337,0.34724754,0.1040095774158355,5348.0,0.19403726,0.0655657790506367,2472.0,57648.46272945404,63076.50105428696,57648.46272945404,5422.824957847595,2.234661340713501,0.0 -73000,0.76418865,0.98166955,,,,,,,,,,,,,, -73100,0.8349472,0.96228313,,,,,,,,,,,,,, -73200,0.78237253,1.0327675,,,,,,,,,,,,,, -73300,0.81229645,0.99838424,,,,,,,,,,,,,, -73400,0.71702814,0.95221,,,,,,,,,,,,,, -73500,0.8063864,1.0095738,,,,,,,,,,,,,, -73600,0.7999925,0.9728885,,,,,,,,,,,,,, -73700,0.8141773,0.9742018,,,,,,,,,,,,,, -73800,0.7447915,0.9595936,,,,,,,,,,,,,, -73900,0.8326202,0.9963786,,,,,,,,,,,,,, -74000,0.84781075,1.0139066,,,,,,,,,,,,,, -74100,0.8844852,1.0267122,,,,,,,,,,,,,, -74200,0.7113835,0.9951886,,,,,,,,,,,,,, -74300,0.7175497,0.94086266,,,,,,,,,,,,,, -74400,0.8188503,0.93280697,,,,,,,,,,,,,, -74500,0.8681963,1.0181842,,,,,,,,,,,,,, -74600,0.7116164,0.97358406,,,,,,,,,,,,,, -74700,0.7670402,1.0458051,,,,,,,,,,,,,, -74785,,,0.12474066,0.0474371254207602,0.34358847,0.1028124004363903,5348.0,0.18991898,0.063575244246745,2472.0,59088.35154867172,64649.74883103371,59088.35154867172,5556.046092987061,2.296834707260132,0.0 -74800,0.8700774,0.99701494,,,,,,,,,,,,,, -74900,0.88602763,0.99948704,,,,,,,,,,,,,, -75000,0.857287,0.99745435,,,,,,,,,,,,,, -75100,0.7553262,1.0071303,,,,,,,,,,,,,, -75200,0.78304076,1.0523881,,,,,,,,,,,,,, -75300,0.8524129,0.99943304,,,,,,,,,,,,,, -75400,0.9494779,1.0117552,,,,,,,,,,,,,, -75500,1.0912913,0.99869007,,,,,,,,,,,,,, -75600,1.166823,1.0198913,,,,,,,,,,,,,, -75700,1.1780938,0.9636844,,,,,,,,,,,,,, -75800,0.77826136,1.0064341,,,,,,,,,,,,,, -75900,0.8065996,1.0393928,,,,,,,,,,,,,, -76000,0.80782527,0.96397275,,,,,,,,,,,,,, -76100,0.86238563,0.94817066,,,,,,,,,,,,,, -76200,0.8972235,0.99760354,,,,,,,,,,,,,, -76300,0.80734783,0.9760605,,,,,,,,,,,,,, -76400,0.92505234,0.9392375,,,,,,,,,,,,,, -76500,1.140403,0.97532326,,,,,,,,,,,,,, -76600,0.754435,1.0031235,,,,,,,,,,,,,, -76633,,,0.12209161,0.0444144621485653,0.33922857,0.1011807640692431,5348.0,0.18748328,0.0621128105132736,2472.0,60528.67248296738,66222.53850531578,60528.67248296738,5688.373752355576,2.361799716949463,0.0 -76700,0.8481122,0.9653985,,,,,,,,,,,,,, -76800,0.94129324,0.966837,,,,,,,,,,,,,, -76900,0.7866487,0.9757045,,,,,,,,,,,,,, -77000,0.79008657,0.94153947,,,,,,,,,,,,,, -77100,0.78811246,1.0015397,,,,,,,,,,,,,, -77200,0.9092731,0.9738318,,,,,,,,,,,,,, -77300,0.8357037,0.97625184,,,,,,,,,,,,,, -77324,,,,,,,,,,,61068.41437482834,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/eval_measurements.csv deleted file mode 100644 index d86fb1352..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -116.5929560661316,0.0,36.17317390441895,1,0,36.17317390441895,30.208836,2472,0.9085978916580344,152.76622414588928,31.451271,0.9408900814105828,30.14182,5348,0.9043706614402812 -242.3737189769745,0.0271456241607666,1476.5014803409576,1796,0,1476.5014803409576,3.0440004,2472,0.5801799605955356,1718.9732348918917,3.44108,0.6439037536543498,3.3643177,5348,0.6250132751479576 -373.27892112731934,0.0788459777832031,2916.602714776993,3624,0,2916.602714776993,0.6575261,2472,0.2147137082850933,3290.107299566269,0.90955204,0.288583466746811,0.95395464,5348,0.2790677467005223 -504.3946657180786,0.1390526294708252,4356.862602472305,5454,0,4356.862602472305,0.48497862,2472,0.1616598622874901,4861.620371341705,0.55690676,0.1896396760044132,0.74646324,5348,0.2249920349112254 -636.3247225284576,0.191321849822998,5796.924175024033,7256,0,5796.924175024033,0.42860916,2472,0.1437044258932017,6433.738090276718,0.55555606,0.1888104982499259,0.6697497,5348,0.2014443360977823 -768.995701789856,0.251901626586914,7237.313529729843,9080,0,7237.313529729843,0.39310047,2472,0.1300956675400646,8006.933594942093,0.45142564,0.1560763418045296,0.624928,5348,0.1884298637728453 -901.823336839676,0.3069281578063965,8677.727895021439,10907,0,8677.727895021439,0.37033734,2472,0.1250787073710722,9580.304859876633,0.43169728,0.1528120328344837,0.59886086,5348,0.1800978981820288 -1034.6947605609894,0.3596227169036865,10117.946796894072,12736,0,10117.946796894072,0.35084808,2472,0.117807161862978,11153.522399902344,0.4072509,0.1403770491803278,0.5681042,5348,0.1724707222645954 -1166.9670646190643,0.4112875461578369,11558.43038058281,14545,0,11558.43038058281,0.33382007,2472,0.1137042227774054,12726.40248823166,0.39576054,0.1385089016958059,0.55048543,5348,0.1663013989592284 -1299.6520047187803,0.4664990901947021,12998.585268974304,16358,0,12998.585268974304,0.33373216,2472,0.1095200373733065,14299.370248556135,0.40191302,0.1366346600588508,0.5395458,5348,0.1615706189598076 -1432.9095618724823,0.5199224948883057,14438.571347236631,18169,0,14438.571347236631,0.31530783,2472,0.1063311193711535,15872.740940332413,0.37821412,0.1335672436824805,0.52589005,5348,0.1580177066337121 -1562.391996383667,0.5783975124359131,15878.452617168428,20001,0,15878.452617168428,0.30063546,2472,0.1019235065911075,17442.23735189438,0.37082225,0.1336668088612551,0.50519156,5348,0.150931191287641 -1695.3766107559204,0.6333692073822021,17318.92253255844,21815,0,17318.92253255844,0.29525468,2472,0.0974955822314301,19015.82189750672,0.31139514,0.1098601009739526,0.49194026,5348,0.1469824381860837 -1826.374429941177,0.6851708889007568,18759.53474497795,23615,0,18759.53474497795,0.28810266,2472,0.0963175106128003,20587.556434631348,0.32840356,0.1159002532103454,0.48064968,5348,0.144983925002655 -1958.3797991275787,0.7363095283508301,20199.681203603745,25420,0,20199.681203603745,0.27983314,2472,0.0948347652996973,22159.8330514431,0.3194967,0.1130051041507338,0.47430897,5348,0.1427247361866051 -2091.949378967285,0.7944025993347168,21639.995112657547,27253,0,21639.995112657547,0.2673006,2472,0.088781914569496,23733.84988975525,0.30807894,0.1090600525061238,0.4601984,5348,0.1377139712484432 -2224.193256378174,0.8488309383392334,23080.38262534141,29061,0,23080.38262534141,0.26807737,2472,0.0894725082769687,25306.609971761703,0.2758949,0.1011349817592217,0.45608607,5348,0.135165142840592 -2356.3475427627563,0.904953956604004,24520.94918370247,30857,0,24520.94918370247,0.25917184,2472,0.0856742428858692,26879.460943460464,0.24288544,0.0886891000941861,0.44085503,5348,0.1318922154532376 -2488.2610454559326,0.9572756290435792,25961.47345638275,32659,0,25961.47345638275,0.25159425,2472,0.0839274470375561,28452.025376319885,0.27391443,0.0989485234644992,0.43026587,5348,0.1293723510045666 -2620.2989869117737,1.022853136062622,27401.712621688843,34494,0,27401.712621688843,0.24414307,2472,0.0828306217374525,30024.442529678345,0.25215727,0.0928667264214493,0.4266692,5348,0.1278179518618998 -2751.0915093421936,1.0802249908447266,28842.23300004005,36309,0,28842.23300004005,0.24109034,2472,0.0801901163853512,31595.88714289665,0.26872647,0.0920601350633074,0.41945672,5348,0.124429168637825 -2881.2546286582947,1.1364960670471191,30282.535950660706,38107,0,30282.535950660706,0.23581171,2472,0.0775699226128816,33166.483761549,0.21433543,0.0794151752505141,0.40897462,5348,0.1207700551280689 -3013.71555685997,1.1902563571929932,31722.782329320908,39910,0,31722.782329320908,0.2262993,2472,0.0743200698718339,34739.318687200546,0.19636467,0.0730434975637368,0.40372702,5348,0.1181826081079776 -3144.373094320297,1.2504262924194336,33163.08148312569,41747,0,33163.08148312569,0.21935742,2472,0.0716592529401011,36310.41373419762,0.21572936,0.0796678366578203,0.395749,5348,0.1162806414551493 -3284.943300962448,1.312373399734497,34603.13644838333,43579,0,34603.13644838333,0.21774586,2472,0.0726342087624154,37891.17783498764,0.1478013,0.0551468587783352,0.38760498,5348,0.113480792067737 -3419.207637071609,1.3747718334197998,36043.40504741669,45391,0,36043.40504741669,0.21235178,2472,0.0704202465825767,39465.84740304947,0.12686214,0.0475608169637049,0.37629935,5348,0.1112891858231074 -3554.784821510315,1.4327740669250488,37484.02665233612,47211,0,37484.02665233612,0.20668186,2472,0.0686937623138951,41042.178755521774,0.1242431,0.0480329439796949,0.3734636,5348,0.1081803875377738 -3689.043043851853,1.4914908409118652,38924.49155926704,49025,0,38924.49155926704,0.20458083,2472,0.0670891475230028,42617.03544545174,0.12451019,0.0470675119749111,0.36440134,5348,0.1062977301910655 -3820.602771282196,1.5500528812408447,40364.53780937195,50856,0,40364.53780937195,0.1984191,2472,0.0658501411654784,44188.77558088303,0.11252448,0.0436147960437567,0.35983658,5348,0.1053129555789412 -3954.864999771118,1.6051020622253418,41804.75467848778,52665,0,41804.75467848778,0.19205548,2472,0.0630065200170617,45763.38461208344,0.10154488,0.0393358401807406,0.34599328,5348,0.100929743089682 -4087.795857667923,1.659766435623169,43245.01002693176,54472,0,43245.01002693176,0.18795243,2472,0.0621940568317998,47336.69872045517,0.116006896,0.0441252374833898,0.3398731,5348,0.1001863348040588 -4221.798095703125,1.7223410606384275,44685.05935645104,56285,0,44685.05935645104,0.1826138,2472,0.0587613998740682,48910.888063669205,0.10221567,0.0367382905171829,0.3412774,5348,0.097502341253367 -4356.356876373291,1.7787425518035889,46125.22264838219,58115,0,46125.22264838219,0.18172288,2472,0.0592082546259622,50485.74221920967,0.095418066,0.0367302365224802,0.3292542,5348,0.0944707801925137 -4489.549370765686,1.8408918380737305,47565.73729014397,59930,0,47565.73729014397,0.17646027,2472,0.0567505534905449,52059.585463523865,0.0784812,0.0307128037937166,0.32586622,5348,0.0934763509273294 -4623.5429792404175,1.8968467712402344,49006.57438802719,61736,0,49006.57438802719,0.17256908,2472,0.0555521702922836,53634.54512619972,0.07178165,0.0287519646911267,0.31728697,5348,0.0908309759888778 -4757.461458683014,1.9577386379241943,50446.679966926575,63544,0,50446.679966926575,0.16620629,2472,0.0538866207624967,55208.704122543335,0.078863874,0.0309223185523101,0.3153456,5348,0.0890062465605298 -4891.862742900848,2.0222373008728027,51887.27757143974,65377,0,51887.27757143974,0.16686207,2472,0.0535413239087603,56783.84284877777,0.072210774,0.0280750267393629,0.3099662,5348,0.0889579732952296 -5025.178674221039,2.083531856536865,53327.70025777817,67194,0,53327.70025777817,0.16483718,2472,0.0523835638697621,58357.71854352951,0.06876019,0.0267672908902852,0.30606192,5348,0.0876642497851839 -5157.87672996521,2.142979860305786,54768.31900382042,68998,0,54768.31900382042,0.1602182,2472,0.0511851806715008,59931.1682472229,0.060892846,0.0234735113973606,0.30153775,5348,0.0862160518261776 -5292.408869028091,2.206705570220948,56208.43280529976,70808,0,56208.43280529976,0.15855892,2472,0.0502711595880811,61505.95262885094,0.05510921,0.0211103238444546,0.29850617,5348,0.0841113374590884 -5428.279653549194,2.27409291267395,57648.75893139839,72636,0,57648.75893139839,0.158432,2472,0.0497633700972924,63082.29228138924,0.059947684,0.0222034592160474,0.29521114,5348,0.0831651814592042 -5559.786453962326,2.330268144607544,59089.160449266434,74455,0,59089.160449266434,0.15610263,2472,0.0496618121991347,64654.3326253891,0.057411306,0.0215887926147866,0.29396224,5348,0.082431427826641 -5693.6730353832245,2.3865671157836914,60529.28284430504,76255,0,60529.28284430504,0.15516719,2472,0.0494383848231877,66228.47186899185,0.05751698,0.021302694306068883,0.29176736,5348,0.08214178823483978 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/measurements.csv deleted file mode 100644 index 57fb8fdc3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/measurements.csv +++ /dev/null @@ -1,815 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,53.16104,31.063494,,,,,,,,,,,,,, -1,,,31.451271,0.9408900814105828,30.14182,0.9043706614402812,5348.0,30.208836,0.9085978916580344,2472.0,36.17317390441895,152.76622414588928,36.17317390441895,116.5929560661316,0.0,0.0 -100,0.7267284,5.980682,,,,,,,,,,,,,, -200,0.5947843,5.8547816,,,,,,,,,,,,,, -300,0.32863864,5.825823,,,,,,,,,,,,,, -400,2.3709545,5.803504,,,,,,,,,,,,,, -500,0.54797375,5.8018937,,,,,,,,,,,,,, -600,6.995308,5.848922,,,,,,,,,,,,,, -700,2.026809,5.695526,,,,,,,,,,,,,, -800,1.1992017,5.5224485,,,,,,,,,,,,,, -900,2.6210892,5.398271,,,,,,,,,,,,,, -1000,1.1969976,4.412748,,,,,,,,,,,,,, -1100,1.8614469,3.832489,,,,,,,,,,,,,, -1200,1.5104018,3.415266,,,,,,,,,,,,,, -1300,0.88736546,3.1881993,,,,,,,,,,,,,, -1400,0.8103376,2.9331892,,,,,,,,,,,,,, -1500,0.8884041,2.9069536,,,,,,,,,,,,,, -1600,0.65637594,2.7261972,,,,,,,,,,,,,, -1700,0.868662,2.6306865,,,,,,,,,,,,,, -1796,,,3.44108,0.6439037536543498,3.3643177,0.6250132751479576,5348.0,3.0440004,0.5801799605955356,2472.0,1476.5014803409576,1718.9732348918917,1476.5014803409576,242.3737189769745,0.0271456241607666,0.0 -1800,0.87631077,2.4582996,,,,,,,,,,,,,, -1900,0.5697287,2.3741596,,,,,,,,,,,,,, -2000,0.6316353,2.360782,,,,,,,,,,,,,, -2100,0.53699493,2.242546,,,,,,,,,,,,,, -2200,0.7059006,2.2834864,,,,,,,,,,,,,, -2300,0.5904657,2.1566458,,,,,,,,,,,,,, -2400,0.65745264,2.11414,,,,,,,,,,,,,, -2500,1.1165895,2.1067886,,,,,,,,,,,,,, -2600,0.45424277,2.0227993,,,,,,,,,,,,,, -2700,0.74259686,2.0832815,,,,,,,,,,,,,, -2800,0.5178594,1.9939625,,,,,,,,,,,,,, -2900,0.5020457,1.9657445,,,,,,,,,,,,,, -3000,0.6524366,1.9758954,,,,,,,,,,,,,, -3100,0.77329427,1.9718874,,,,,,,,,,,,,, -3200,0.65459627,1.9018377,,,,,,,,,,,,,, -3300,0.5960454,1.8535764,,,,,,,,,,,,,, -3400,0.60469294,1.8912444,,,,,,,,,,,,,, -3500,0.62503326,1.9112631,,,,,,,,,,,,,, -3600,0.8037757,1.873149,,,,,,,,,,,,,, -3624,,,0.90955204,0.288583466746811,0.95395464,0.2790677467005223,5348.0,0.6575261,0.2147137082850933,2472.0,2916.602714776993,3290.107299566269,2916.602714776993,373.27892112731934,0.0788459777832031,0.0 -3700,0.75286835,1.8467127,,,,,,,,,,,,,, -3800,0.7594918,1.8513565,,,,,,,,,,,,,, -3900,0.5434944,1.8163298,,,,,,,,,,,,,, -4000,0.5396772,1.8474374,,,,,,,,,,,,,, -4100,0.5988913,1.789353,,,,,,,,,,,,,, -4200,0.45726702,1.768996,,,,,,,,,,,,,, -4300,0.8676368,1.7654425,,,,,,,,,,,,,, -4400,0.6828073,1.7492541,,,,,,,,,,,,,, -4500,0.504184,1.7744958,,,,,,,,,,,,,, -4600,0.48563215,1.712307,,,,,,,,,,,,,, -4700,0.49355516,1.7359241,,,,,,,,,,,,,, -4800,0.5569483,1.7618011,,,,,,,,,,,,,, -4900,0.56509036,1.7225181,,,,,,,,,,,,,, -5000,0.6115635,1.706735,,,,,,,,,,,,,, -5100,0.4457644,1.6920089,,,,,,,,,,,,,, -5200,0.4910211,1.69176,,,,,,,,,,,,,, -5300,0.45712662,1.6780072,,,,,,,,,,,,,, -5400,0.5341835,1.6540524,,,,,,,,,,,,,, -5454,,,0.55690676,0.1896396760044132,0.74646324,0.2249920349112254,5348.0,0.48497862,0.1616598622874901,2472.0,4356.862602472305,4861.620371341705,4356.862602472305,504.3946657180786,0.1390526294708252,0.0 -5500,0.48504776,1.6685467,,,,,,,,,,,,,, -5600,0.55607814,1.6654364,,,,,,,,,,,,,, -5700,0.54860216,1.7172865,,,,,,,,,,,,,, -5800,0.41057107,1.6465893,,,,,,,,,,,,,, -5900,0.8532159,1.6792506,,,,,,,,,,,,,, -6000,0.5102273,1.6658115,,,,,,,,,,,,,, -6100,0.5469113,1.6739902,,,,,,,,,,,,,, -6200,0.4506288,1.654396,,,,,,,,,,,,,, -6300,0.52881736,1.6551086,,,,,,,,,,,,,, -6400,0.42295358,1.6399757,,,,,,,,,,,,,, -6500,0.5030524,1.656832,,,,,,,,,,,,,, -6600,0.6516423,1.6621304,,,,,,,,,,,,,, -6700,0.5269378,1.6389759,,,,,,,,,,,,,, -6800,0.5140815,1.58493,,,,,,,,,,,,,, -6900,0.46389234,1.5504928,,,,,,,,,,,,,, -7000,0.479042,1.5976387,,,,,,,,,,,,,, -7100,0.46973753,1.6063148,,,,,,,,,,,,,, -7200,0.57627475,1.5879129,,,,,,,,,,,,,, -7256,,,0.55555606,0.1888104982499259,0.6697497,0.2014443360977823,5348.0,0.42860916,0.1437044258932017,2472.0,5796.924175024033,6433.738090276718,5796.924175024033,636.3247225284576,0.191321849822998,0.0 -7300,0.74590415,1.6233064,,,,,,,,,,,,,, -7400,0.57369363,1.5599918,,,,,,,,,,,,,, -7500,0.49053225,1.5823313,,,,,,,,,,,,,, -7600,0.5307257,1.5956463,,,,,,,,,,,,,, -7700,0.47913966,1.5685358,,,,,,,,,,,,,, -7800,0.49431148,1.6351156,,,,,,,,,,,,,, -7900,0.43763155,1.6022071,,,,,,,,,,,,,, -8000,0.5118086,1.6038314,,,,,,,,,,,,,, -8100,0.5128286,1.527932,,,,,,,,,,,,,, -8200,0.4936815,1.5703567,,,,,,,,,,,,,, -8300,0.50440276,1.5418022,,,,,,,,,,,,,, -8400,0.59962463,1.5070589,,,,,,,,,,,,,, -8500,0.5601612,1.5914644,,,,,,,,,,,,,, -8600,0.49699536,1.560203,,,,,,,,,,,,,, -8700,0.5009486,1.5165219,,,,,,,,,,,,,, -8800,0.54899377,1.4833175,,,,,,,,,,,,,, -8900,0.56820494,1.524118,,,,,,,,,,,,,, -9000,0.37851426,1.536685,,,,,,,,,,,,,, -9080,,,0.45142564,0.1560763418045296,0.624928,0.1884298637728453,5348.0,0.39310047,0.1300956675400646,2472.0,7237.313529729843,8006.933594942093,7237.313529729843,768.995701789856,0.251901626586914,0.0 -9100,0.57588005,1.542958,,,,,,,,,,,,,, -9200,0.5154503,1.5644947,,,,,,,,,,,,,, -9300,0.5244034,1.4812039,,,,,,,,,,,,,, -9400,0.5544119,1.539081,,,,,,,,,,,,,, -9500,0.49015263,1.4472643,,,,,,,,,,,,,, -9600,0.42607793,1.5248604,,,,,,,,,,,,,, -9700,0.45456672,1.5492365,,,,,,,,,,,,,, -9800,0.51696014,1.497823,,,,,,,,,,,,,, -9900,0.45594054,1.5413321,,,,,,,,,,,,,, -10000,0.50595677,1.4900416,,,,,,,,,,,,,, -10100,0.5678579,1.4841266,,,,,,,,,,,,,, -10200,0.4926864,1.5045747,,,,,,,,,,,,,, -10300,0.5211786,1.5488625,,,,,,,,,,,,,, -10400,0.42660517,1.4657955,,,,,,,,,,,,,, -10500,0.52244705,1.5566348,,,,,,,,,,,,,, -10600,0.5028611,1.4605273,,,,,,,,,,,,,, -10700,0.52403855,1.5149639,,,,,,,,,,,,,, -10800,0.5248078,1.5174875,,,,,,,,,,,,,, -10900,0.40590665,1.4864832,,,,,,,,,,,,,, -10907,,,0.43169728,0.1528120328344837,0.59886086,0.1800978981820288,5348.0,0.37033734,0.1250787073710722,2472.0,8677.727895021439,9580.304859876633,8677.727895021439,901.823336839676,0.3069281578063965,0.0 -11000,0.4201826,1.5070109,,,,,,,,,,,,,, -11100,0.42977566,1.4535222,,,,,,,,,,,,,, -11200,0.49669918,1.4630345,,,,,,,,,,,,,, -11300,0.47401,1.4821254,,,,,,,,,,,,,, -11400,0.50011903,1.4501326,,,,,,,,,,,,,, -11500,0.50273305,1.5007806,,,,,,,,,,,,,, -11600,0.453244,1.4001598,,,,,,,,,,,,,, -11700,0.45722792,1.4952353,,,,,,,,,,,,,, -11800,0.5333573,1.4878454,,,,,,,,,,,,,, -11900,0.43728718,1.5233269,,,,,,,,,,,,,, -12000,0.45472628,1.4364896,,,,,,,,,,,,,, -12100,0.38573802,1.440662,,,,,,,,,,,,,, -12200,0.4530883,1.446875,,,,,,,,,,,,,, -12300,0.58242625,1.4171354,,,,,,,,,,,,,, -12400,0.41514394,1.4472104,,,,,,,,,,,,,, -12500,0.47483394,1.4425573,,,,,,,,,,,,,, -12600,0.42710295,1.4944336,,,,,,,,,,,,,, -12700,0.50317526,1.4658703,,,,,,,,,,,,,, -12736,,,0.4072509,0.1403770491803278,0.5681042,0.1724707222645954,5348.0,0.35084808,0.117807161862978,2472.0,10117.946796894072,11153.522399902344,10117.946796894072,1034.6947605609894,0.3596227169036865,0.0 -12800,0.5080637,1.392449,,,,,,,,,,,,,, -12900,0.57677674,1.4803725,,,,,,,,,,,,,, -13000,0.7577733,1.4869016,,,,,,,,,,,,,, -13100,0.41555935,1.4900781,,,,,,,,,,,,,, -13200,0.43647683,1.4518883,,,,,,,,,,,,,, -13300,0.5356385,1.4827999,,,,,,,,,,,,,, -13400,0.5924099,1.3929671,,,,,,,,,,,,,, -13500,0.42593414,1.4721817,,,,,,,,,,,,,, -13600,0.56581885,1.3809651,,,,,,,,,,,,,, -13700,0.56305295,1.4714439,,,,,,,,,,,,,, -13800,0.5624695,1.4468347,,,,,,,,,,,,,, -13900,0.40565726,1.4666597,,,,,,,,,,,,,, -14000,0.6045008,1.4959991,,,,,,,,,,,,,, -14100,0.53908217,1.4621892,,,,,,,,,,,,,, -14200,0.41222754,1.3905746,,,,,,,,,,,,,, -14300,0.4589015,1.3872803,,,,,,,,,,,,,, -14400,0.5374596,1.5051494,,,,,,,,,,,,,, -14500,0.52248746,1.4401003,,,,,,,,,,,,,, -14545,,,0.39576054,0.1385089016958059,0.55048543,0.1663013989592284,5348.0,0.33382007,0.1137042227774054,2472.0,11558.43038058281,12726.40248823166,11558.43038058281,1166.9670646190643,0.4112875461578369,0.0 -14600,0.5983822,1.3812685,,,,,,,,,,,,,, -14700,0.4659985,1.405246,,,,,,,,,,,,,, -14800,0.45997304,1.4104505,,,,,,,,,,,,,, -14900,0.5511453,1.3967527,,,,,,,,,,,,,, -15000,0.46455938,1.4378785,,,,,,,,,,,,,, -15100,0.54920167,1.4010184,,,,,,,,,,,,,, -15200,0.44457775,1.4003386,,,,,,,,,,,,,, -15300,0.520205,1.3621697,,,,,,,,,,,,,, -15400,0.4790203,1.453191,,,,,,,,,,,,,, -15500,0.48618042,1.4098866,,,,,,,,,,,,,, -15600,0.4997873,1.4271617,,,,,,,,,,,,,, -15700,0.5588614,1.446998,,,,,,,,,,,,,, -15800,0.45171425,1.3873512,,,,,,,,,,,,,, -15900,0.44001007,1.4569194,,,,,,,,,,,,,, -16000,0.4441902,1.396489,,,,,,,,,,,,,, -16100,0.48312682,1.3996619,,,,,,,,,,,,,, -16200,0.4575005,1.4341329,,,,,,,,,,,,,, -16300,0.48673776,1.3477442,,,,,,,,,,,,,, -16358,,,0.40191302,0.1366346600588508,0.5395458,0.1615706189598076,5348.0,0.33373216,0.1095200373733065,2472.0,12998.585268974304,14299.370248556135,12998.585268974304,1299.6520047187803,0.4664990901947021,0.0 -16400,0.54629767,1.4928473,,,,,,,,,,,,,, -16500,0.51734036,1.4285553,,,,,,,,,,,,,, -16600,0.4883779,1.3856938,,,,,,,,,,,,,, -16700,0.4307826,1.3605208,,,,,,,,,,,,,, -16800,0.5142241,1.4443096,,,,,,,,,,,,,, -16900,0.4385312,1.3815017,,,,,,,,,,,,,, -17000,0.4610101,1.3703339,,,,,,,,,,,,,, -17100,0.60911244,1.415019,,,,,,,,,,,,,, -17200,0.4307535,1.421454,,,,,,,,,,,,,, -17300,0.47234356,1.3624622,,,,,,,,,,,,,, -17400,0.60849535,1.3433132,,,,,,,,,,,,,, -17500,0.5422925,1.4106618,,,,,,,,,,,,,, -17600,0.4607119,1.3886125,,,,,,,,,,,,,, -17700,0.48332238,1.327424,,,,,,,,,,,,,, -17800,0.5336804,1.3384433,,,,,,,,,,,,,, -17900,0.5880802,1.373095,,,,,,,,,,,,,, -18000,0.45493045,1.3622291,,,,,,,,,,,,,, -18100,0.5397216,1.3234122,,,,,,,,,,,,,, -18169,,,0.37821412,0.1335672436824805,0.52589005,0.1580177066337121,5348.0,0.31530783,0.1063311193711535,2472.0,14438.571347236631,15872.740940332413,14438.571347236631,1432.9095618724823,0.5199224948883057,0.0 -18200,0.55467594,1.4197016,,,,,,,,,,,,,, -18300,0.48997712,1.3223692,,,,,,,,,,,,,, -18400,0.5670439,1.4421939,,,,,,,,,,,,,, -18500,0.46485117,1.3628023,,,,,,,,,,,,,, -18600,0.501566,1.3554913,,,,,,,,,,,,,, -18700,0.45921504,1.4085063,,,,,,,,,,,,,, -18800,0.4758113,1.401499,,,,,,,,,,,,,, -18900,0.5580688,1.3991315,,,,,,,,,,,,,, -19000,0.5206257,1.3595895,,,,,,,,,,,,,, -19100,0.4418727,1.3306054,,,,,,,,,,,,,, -19200,0.4968073,1.3567421,,,,,,,,,,,,,, -19300,0.5035639,1.3903534,,,,,,,,,,,,,, -19400,0.50565654,1.385112,,,,,,,,,,,,,, -19500,0.513464,1.3748027,,,,,,,,,,,,,, -19600,0.46884003,1.3701023,,,,,,,,,,,,,, -19700,0.5405344,1.3264875,,,,,,,,,,,,,, -19800,0.6265248,1.3544161,,,,,,,,,,,,,, -19900,0.4416883,1.4219539,,,,,,,,,,,,,, -20000,0.46808887,1.3473831,,,,,,,,,,,,,, -20001,,,0.37082225,0.1336668088612551,0.50519156,0.150931191287641,5348.0,0.30063546,0.1019235065911075,2472.0,15878.452617168428,17442.23735189438,15878.452617168428,1562.391996383667,0.5783975124359131,0.0 -20100,0.57240313,1.3618282,,,,,,,,,,,,,, -20200,0.52476656,1.420017,,,,,,,,,,,,,, -20300,0.4898919,1.405143,,,,,,,,,,,,,, -20400,0.5408933,1.339015,,,,,,,,,,,,,, -20500,0.53149146,1.3413712,,,,,,,,,,,,,, -20600,0.534497,1.3832763,,,,,,,,,,,,,, -20700,0.49581736,1.341498,,,,,,,,,,,,,, -20800,0.45986155,1.3600101,,,,,,,,,,,,,, -20900,0.41257468,1.2190447,,,,,,,,,,,,,, -21000,0.48395148,1.3263991,,,,,,,,,,,,,, -21100,0.48209134,1.314953,,,,,,,,,,,,,, -21200,0.54256314,1.3723394,,,,,,,,,,,,,, -21300,0.6818297,1.4153068,,,,,,,,,,,,,, -21400,0.4678745,1.336183,,,,,,,,,,,,,, -21500,0.5087437,1.3055806,,,,,,,,,,,,,, -21600,0.58559734,1.3563834,,,,,,,,,,,,,, -21700,0.45297438,1.3567947,,,,,,,,,,,,,, -21800,0.55500054,1.3130599,,,,,,,,,,,,,, -21815,,,0.31139514,0.1098601009739526,0.49194026,0.1469824381860837,5348.0,0.29525468,0.0974955822314301,2472.0,17318.92253255844,19015.82189750672,17318.92253255844,1695.3766107559204,0.6333692073822021,0.0 -21900,0.5047571,1.3243806,,,,,,,,,,,,,, -22000,0.5525561,1.3033292,,,,,,,,,,,,,, -22100,0.49654248,1.3397107,,,,,,,,,,,,,, -22200,0.4371696,1.3360277,,,,,,,,,,,,,, -22300,0.5082929,1.34736,,,,,,,,,,,,,, -22400,0.54076135,1.3116225,,,,,,,,,,,,,, -22500,0.5595855,1.3938613,,,,,,,,,,,,,, -22600,0.4791194,1.3188373,,,,,,,,,,,,,, -22700,0.51678944,1.318927,,,,,,,,,,,,,, -22800,0.6295565,1.2833626,,,,,,,,,,,,,, -22900,0.47846833,1.319618,,,,,,,,,,,,,, -23000,0.61547685,1.2891531,,,,,,,,,,,,,, -23100,0.5155072,1.2915875,,,,,,,,,,,,,, -23200,0.51655114,1.3122332,,,,,,,,,,,,,, -23300,0.6896333,1.284764,,,,,,,,,,,,,, -23400,0.5620434,1.3242594,,,,,,,,,,,,,, -23500,0.49136278,1.3001171,,,,,,,,,,,,,, -23600,0.5214307,1.302725,,,,,,,,,,,,,, -23615,,,0.32840356,0.1159002532103454,0.48064968,0.144983925002655,5348.0,0.28810266,0.0963175106128003,2472.0,18759.53474497795,20587.556434631348,18759.53474497795,1826.374429941177,0.6851708889007568,0.0 -23700,0.4752838,1.3370388,,,,,,,,,,,,,, -23800,0.5272462,1.3056164,,,,,,,,,,,,,, -23900,0.4243006,1.2818153,,,,,,,,,,,,,, -24000,0.46776262,1.2943233,,,,,,,,,,,,,, -24100,0.5025303,1.288668,,,,,,,,,,,,,, -24200,0.497914,1.2855529,,,,,,,,,,,,,, -24300,0.42085385,1.3008513,,,,,,,,,,,,,, -24400,0.43664598,1.2983143,,,,,,,,,,,,,, -24500,0.63474977,1.300461,,,,,,,,,,,,,, -24600,0.54859936,1.3356177,,,,,,,,,,,,,, -24700,0.55084825,1.3038565,,,,,,,,,,,,,, -24800,0.43461096,1.2377956,,,,,,,,,,,,,, -24900,0.45436883,1.2913004,,,,,,,,,,,,,, -25000,0.5664312,1.30999,,,,,,,,,,,,,, -25100,0.5269295,1.2972571,,,,,,,,,,,,,, -25200,0.5672119,1.2277207,,,,,,,,,,,,,, -25300,0.5204917,1.2929811,,,,,,,,,,,,,, -25400,0.56865203,1.3015722,,,,,,,,,,,,,, -25420,,,0.3194967,0.1130051041507338,0.47430897,0.1427247361866051,5348.0,0.27983314,0.0948347652996973,2472.0,20199.681203603745,22159.8330514431,20199.681203603745,1958.3797991275787,0.7363095283508301,0.0 -25500,0.59754884,1.3081249,,,,,,,,,,,,,, -25600,0.43123573,1.3030334,,,,,,,,,,,,,, -25700,0.45100293,1.3149188,,,,,,,,,,,,,, -25800,0.49443302,1.2987642,,,,,,,,,,,,,, -25900,0.59501624,1.3252159,,,,,,,,,,,,,, -26000,0.5755105,1.3548262,,,,,,,,,,,,,, -26100,0.4365909,1.2997167,,,,,,,,,,,,,, -26200,0.57205266,1.291085,,,,,,,,,,,,,, -26300,0.56217134,1.2608756,,,,,,,,,,,,,, -26400,0.53383094,1.2406095,,,,,,,,,,,,,, -26500,0.48797885,1.303907,,,,,,,,,,,,,, -26600,0.4981662,1.3212107,,,,,,,,,,,,,, -26700,0.56194234,1.3182065,,,,,,,,,,,,,, -26800,0.5060959,1.2826488,,,,,,,,,,,,,, -26900,0.5406141,1.3193395,,,,,,,,,,,,,, -27000,0.5204593,1.3272811,,,,,,,,,,,,,, -27100,0.4925984,1.2551553,,,,,,,,,,,,,, -27200,0.49590927,1.2591605,,,,,,,,,,,,,, -27253,,,0.30807894,0.1090600525061238,0.4601984,0.1377139712484432,5348.0,0.2673006,0.088781914569496,2472.0,21639.995112657547,23733.84988975525,21639.995112657547,2091.949378967285,0.7944025993347168,0.0 -27300,0.54493827,1.3086461,,,,,,,,,,,,,, -27400,0.43802685,1.2865742,,,,,,,,,,,,,, -27500,0.5624945,1.3236473,,,,,,,,,,,,,, -27600,0.55052215,1.2924103,,,,,,,,,,,,,, -27700,0.6926688,1.2820183,,,,,,,,,,,,,, -27800,0.6231818,1.2571527,,,,,,,,,,,,,, -27900,0.5248144,1.2977024,,,,,,,,,,,,,, -28000,0.5063938,1.2470163,,,,,,,,,,,,,, -28100,0.5015371,1.2811406,,,,,,,,,,,,,, -28200,0.5479876,1.3387668,,,,,,,,,,,,,, -28300,0.49611804,1.3075999,,,,,,,,,,,,,, -28400,0.4544332,1.3058275,,,,,,,,,,,,,, -28500,0.64236003,1.3618793,,,,,,,,,,,,,, -28600,0.5696599,1.3020805,,,,,,,,,,,,,, -28700,0.47475547,1.2487571,,,,,,,,,,,,,, -28800,0.47723514,1.2526718,,,,,,,,,,,,,, -28900,0.52175033,1.2657553,,,,,,,,,,,,,, -29000,0.63395274,1.3040286,,,,,,,,,,,,,, -29061,,,0.2758949,0.1011349817592217,0.45608607,0.135165142840592,5348.0,0.26807737,0.0894725082769687,2472.0,23080.38262534141,25306.609971761703,23080.38262534141,2224.193256378174,0.8488309383392334,0.0 -29100,0.5156371,1.2302334,,,,,,,,,,,,,, -29200,0.43440253,1.303362,,,,,,,,,,,,,, -29300,0.46752292,1.2485036,,,,,,,,,,,,,, -29400,0.49965084,1.2633578,,,,,,,,,,,,,, -29500,0.51217157,1.243679,,,,,,,,,,,,,, -29600,0.5777662,1.2607735,,,,,,,,,,,,,, -29700,0.52274036,1.2933158,,,,,,,,,,,,,, -29800,0.50138915,1.2760323,,,,,,,,,,,,,, -29900,0.61623406,1.2240899,,,,,,,,,,,,,, -30000,0.47912768,1.2689788,,,,,,,,,,,,,, -30100,0.44002488,1.3122036,,,,,,,,,,,,,, -30200,0.75868475,1.2698392,,,,,,,,,,,,,, -30300,0.47629276,1.2620581,,,,,,,,,,,,,, -30400,0.48392093,1.3167139,,,,,,,,,,,,,, -30500,0.5786136,1.2618213,,,,,,,,,,,,,, -30600,0.5477713,1.2395598,,,,,,,,,,,,,, -30700,0.4897036,1.3870372,,,,,,,,,,,,,, -30800,0.57677156,1.2650219,,,,,,,,,,,,,, -30857,,,0.24288544,0.0886891000941861,0.44085503,0.1318922154532376,5348.0,0.25917184,0.0856742428858692,2472.0,24520.94918370247,26879.460943460464,24520.94918370247,2356.3475427627563,0.904953956604004,0.0 -30900,0.5959841,1.2817492,,,,,,,,,,,,,, -31000,0.43443364,1.1752626,,,,,,,,,,,,,, -31100,0.73441,1.2791388,,,,,,,,,,,,,, -31200,0.40739602,1.218336,,,,,,,,,,,,,, -31300,0.52844095,1.2932849,,,,,,,,,,,,,, -31400,0.48487815,1.2311877,,,,,,,,,,,,,, -31500,0.535629,1.2349712,,,,,,,,,,,,,, -31600,0.4836718,1.2577733,,,,,,,,,,,,,, -31700,0.46982607,1.2419764,,,,,,,,,,,,,, -31800,0.46714497,1.2593179,,,,,,,,,,,,,, -31900,0.65103924,1.319023,,,,,,,,,,,,,, -32000,0.5835743,1.3128306,,,,,,,,,,,,,, -32100,0.614922,1.2671584,,,,,,,,,,,,,, -32200,0.523446,1.1840497,,,,,,,,,,,,,, -32300,0.52963793,1.2366625,,,,,,,,,,,,,, -32400,0.46102557,1.275656,,,,,,,,,,,,,, -32500,0.49551257,1.2114975,,,,,,,,,,,,,, -32600,0.525708,1.204487,,,,,,,,,,,,,, -32659,,,0.27391443,0.0989485234644992,0.43026587,0.1293723510045666,5348.0,0.25159425,0.0839274470375561,2472.0,25961.47345638275,28452.025376319885,25961.47345638275,2488.2610454559326,0.9572756290435792,0.0 -32700,0.58005303,1.2470771,,,,,,,,,,,,,, -32800,0.4763343,1.2405001,,,,,,,,,,,,,, -32900,0.496107,1.2443143,,,,,,,,,,,,,, -33000,0.5314007,1.2446772,,,,,,,,,,,,,, -33100,0.53293276,1.2341431,,,,,,,,,,,,,, -33200,0.53182185,1.2350605,,,,,,,,,,,,,, -33300,0.48474124,1.2224878,,,,,,,,,,,,,, -33400,0.5816342,1.2669746,,,,,,,,,,,,,, -33500,0.5664296,1.1815536,,,,,,,,,,,,,, -33600,0.54269034,1.2218752,,,,,,,,,,,,,, -33700,0.5119837,1.2073785,,,,,,,,,,,,,, -33800,0.51986134,1.201887,,,,,,,,,,,,,, -33900,0.5519293,1.2682062,,,,,,,,,,,,,, -34000,0.50568473,1.25914,,,,,,,,,,,,,, -34100,0.5712617,1.1928945,,,,,,,,,,,,,, -34200,0.6562689,1.2323035,,,,,,,,,,,,,, -34300,0.48228866,1.2069306,,,,,,,,,,,,,, -34400,0.5377682,1.2045994,,,,,,,,,,,,,, -34494,,,0.25215727,0.0928667264214493,0.4266692,0.1278179518618998,5348.0,0.24414307,0.0828306217374525,2472.0,27401.712621688843,30024.442529678345,27401.712621688843,2620.2989869117737,1.022853136062622,0.0 -34500,0.5983088,1.2583797,,,,,,,,,,,,,, -34600,0.5461884,1.1697519,,,,,,,,,,,,,, -34700,0.5320554,1.1841191,,,,,,,,,,,,,, -34800,0.5998251,1.254803,,,,,,,,,,,,,, -34900,0.5162685,1.2431693,,,,,,,,,,,,,, -35000,0.52443516,1.2367578,,,,,,,,,,,,,, -35100,0.49037367,1.1654378,,,,,,,,,,,,,, -35200,0.47976974,1.2286044,,,,,,,,,,,,,, -35300,0.47680357,1.2380192,,,,,,,,,,,,,, -35400,0.50299275,1.180023,,,,,,,,,,,,,, -35500,0.49529594,1.216995,,,,,,,,,,,,,, -35600,0.5924393,1.2414565,,,,,,,,,,,,,, -35700,0.4828561,1.2489457,,,,,,,,,,,,,, -35800,0.5514246,1.2137094,,,,,,,,,,,,,, -35900,0.5847468,1.2430604,,,,,,,,,,,,,, -36000,0.63680863,1.2169564,,,,,,,,,,,,,, -36100,0.5420551,1.1855252,,,,,,,,,,,,,, -36200,0.47793084,1.1649063,,,,,,,,,,,,,, -36300,0.48645774,1.2232226,,,,,,,,,,,,,, -36309,,,0.26872647,0.0920601350633074,0.41945672,0.124429168637825,5348.0,0.24109034,0.0801901163853512,2472.0,28842.23300004005,31595.88714289665,28842.23300004005,2751.0915093421936,1.0802249908447266,0.0 -36400,0.5135631,1.1846213,,,,,,,,,,,,,, -36500,0.5946386,1.1920968,,,,,,,,,,,,,, -36600,0.5402479,1.1788353,,,,,,,,,,,,,, -36700,0.46695516,1.2514313,,,,,,,,,,,,,, -36800,0.63681155,1.2376045,,,,,,,,,,,,,, -36900,0.51671314,1.1998333,,,,,,,,,,,,,, -37000,0.8652708,1.191466,,,,,,,,,,,,,, -37100,0.4897429,1.1978283,,,,,,,,,,,,,, -37200,0.49080557,1.149541,,,,,,,,,,,,,, -37300,0.51816136,1.20962,,,,,,,,,,,,,, -37400,0.6148675,1.1758534,,,,,,,,,,,,,, -37500,0.5060882,1.2265869,,,,,,,,,,,,,, -37600,0.5124387,1.1864452,,,,,,,,,,,,,, -37700,0.4930115,1.1985315,,,,,,,,,,,,,, -37800,0.62762827,1.2554096,,,,,,,,,,,,,, -37900,0.60788375,1.1520041,,,,,,,,,,,,,, -38000,0.48850986,1.1591256,,,,,,,,,,,,,, -38100,0.5446634,1.2514908,,,,,,,,,,,,,, -38107,,,0.21433543,0.0794151752505141,0.40897462,0.1207700551280689,5348.0,0.23581171,0.0775699226128816,2472.0,30282.535950660706,33166.483761549,30282.535950660706,2881.2546286582947,1.1364960670471191,0.0 -38200,0.52350515,1.1457428,,,,,,,,,,,,,, -38300,0.5418876,1.258423,,,,,,,,,,,,,, -38400,0.6165533,1.1940053,,,,,,,,,,,,,, -38500,0.7172097,1.1788616,,,,,,,,,,,,,, -38600,0.46806636,1.1783717,,,,,,,,,,,,,, -38700,0.5495634,1.1717688,,,,,,,,,,,,,, -38800,0.6153435,1.1358237,,,,,,,,,,,,,, -38900,0.6203085,1.1986375,,,,,,,,,,,,,, -39000,0.5348927,1.2143464,,,,,,,,,,,,,, -39100,0.6418597,1.2353501,,,,,,,,,,,,,, -39200,0.58457446,1.1722558,,,,,,,,,,,,,, -39300,0.5315562,1.1663185,,,,,,,,,,,,,, -39400,0.5281387,1.1930091,,,,,,,,,,,,,, -39500,0.5375327,1.178618,,,,,,,,,,,,,, -39600,0.6137164,1.1571311,,,,,,,,,,,,,, -39700,0.6925608,1.1765996,,,,,,,,,,,,,, -39800,0.484942,1.1646338,,,,,,,,,,,,,, -39900,0.54294705,1.1511267,,,,,,,,,,,,,, -39910,,,0.19636467,0.0730434975637368,0.40372702,0.1181826081079776,5348.0,0.2262993,0.0743200698718339,2472.0,31722.782329320908,34739.318687200546,31722.782329320908,3013.71555685997,1.1902563571929932,0.0 -40000,0.58736736,1.1433738,,,,,,,,,,,,,, -40100,0.64114046,1.2097569,,,,,,,,,,,,,, -40200,0.5980906,1.1740774,,,,,,,,,,,,,, -40300,0.5801392,1.1996496,,,,,,,,,,,,,, -40400,0.5644043,1.1160747,,,,,,,,,,,,,, -40500,0.7118627,1.1486806,,,,,,,,,,,,,, -40600,0.66552633,1.1327199,,,,,,,,,,,,,, -40700,0.4857072,1.1571203,,,,,,,,,,,,,, -40800,0.5853765,1.2140822,,,,,,,,,,,,,, -40900,0.5520669,1.1618084,,,,,,,,,,,,,, -41000,0.47925714,1.1691695,,,,,,,,,,,,,, -41100,0.5349485,1.1604538,,,,,,,,,,,,,, -41200,0.53398997,1.1715307,,,,,,,,,,,,,, -41300,0.5737051,1.1650852,,,,,,,,,,,,,, -41400,0.58835685,1.2279519,,,,,,,,,,,,,, -41500,0.6389907,1.1613426,,,,,,,,,,,,,, -41600,0.53028756,1.1422814,,,,,,,,,,,,,, -41700,0.7418815,1.1864806,,,,,,,,,,,,,, -41747,,,0.21572936,0.0796678366578203,0.395749,0.1162806414551493,5348.0,0.21935742,0.0716592529401011,2472.0,33163.08148312569,36310.41373419762,33163.08148312569,3144.373094320297,1.2504262924194336,0.0 -41800,0.6535466,1.1892338,,,,,,,,,,,,,, -41900,0.5718912,1.1273808,,,,,,,,,,,,,, -42000,0.5576783,1.1551675,,,,,,,,,,,,,, -42100,0.7601202,1.1479319,,,,,,,,,,,,,, -42200,0.70166075,1.2055099,,,,,,,,,,,,,, -42300,0.5050915,1.1403025,,,,,,,,,,,,,, -42400,0.5984332,1.166134,,,,,,,,,,,,,, -42500,0.5924786,1.1378887,,,,,,,,,,,,,, -42600,0.49878204,1.1244735,,,,,,,,,,,,,, -42700,0.48591334,1.1870253,,,,,,,,,,,,,, -42800,0.58732724,1.144101,,,,,,,,,,,,,, -42900,0.5322867,1.1482805,,,,,,,,,,,,,, -43000,0.52989334,1.1820359,,,,,,,,,,,,,, -43100,0.69214,1.1536206,,,,,,,,,,,,,, -43200,0.5654476,1.1155361,,,,,,,,,,,,,, -43300,0.5862685,1.1226224,,,,,,,,,,,,,, -43400,0.49361795,1.1084344,,,,,,,,,,,,,, -43500,0.5564397,1.1177748,,,,,,,,,,,,,, -43579,,,0.1478013,0.0551468587783352,0.38760498,0.113480792067737,5348.0,0.21774586,0.0726342087624154,2472.0,34603.13644838333,37891.17783498764,34603.13644838333,3284.943300962448,1.312373399734497,0.0 -43600,0.5722182,1.0621886,,,,,,,,,,,,,, -43700,0.49926743,1.0827326,,,,,,,,,,,,,, -43800,0.45652264,1.0852449,,,,,,,,,,,,,, -43900,0.5469091,1.169942,,,,,,,,,,,,,, -44000,0.557609,1.1205802,,,,,,,,,,,,,, -44100,0.57621765,1.1488755,,,,,,,,,,,,,, -44200,0.5320827,1.1449491,,,,,,,,,,,,,, -44300,0.59658897,1.1006731,,,,,,,,,,,,,, -44400,0.492967,1.125722,,,,,,,,,,,,,, -44500,0.6209067,1.1386768,,,,,,,,,,,,,, -44600,0.63824993,1.1505287,,,,,,,,,,,,,, -44700,0.5188898,1.1630803,,,,,,,,,,,,,, -44800,0.49271023,1.0514708,,,,,,,,,,,,,, -44900,0.527032,1.181929,,,,,,,,,,,,,, -45000,0.5463258,1.1096343,,,,,,,,,,,,,, -45100,0.5554528,1.1646248,,,,,,,,,,,,,, -45200,0.59117544,1.1126196,,,,,,,,,,,,,, -45300,0.5555738,1.1260403,,,,,,,,,,,,,, -45391,,,0.12686214,0.0475608169637049,0.37629935,0.1112891858231074,5348.0,0.21235178,0.0704202465825767,2472.0,36043.40504741669,39465.84740304947,36043.40504741669,3419.207637071609,1.3747718334197998,0.0 -45400,0.6510251,1.1200159,,,,,,,,,,,,,, -45500,0.4822121,1.1194935,,,,,,,,,,,,,, -45600,0.5569794,1.1090654,,,,,,,,,,,,,, -45700,0.5881864,1.1548388,,,,,,,,,,,,,, -45800,0.5857825,1.1234677,,,,,,,,,,,,,, -45900,0.552177,1.0826076,,,,,,,,,,,,,, -46000,0.6248966,1.1601964,,,,,,,,,,,,,, -46100,0.5779099,1.0983752,,,,,,,,,,,,,, -46200,0.5967643,1.0960106,,,,,,,,,,,,,, -46300,0.55500466,1.1346936,,,,,,,,,,,,,, -46400,0.5829537,1.1426951,,,,,,,,,,,,,, -46500,0.6794894,1.0931267,,,,,,,,,,,,,, -46600,0.6833388,1.0888844,,,,,,,,,,,,,, -46700,0.5708667,1.0782033,,,,,,,,,,,,,, -46800,0.6906957,1.1702207,,,,,,,,,,,,,, -46900,0.6224761,1.0900447,,,,,,,,,,,,,, -47000,0.6115112,1.1415588,,,,,,,,,,,,,, -47100,0.7146157,1.0800208,,,,,,,,,,,,,, -47200,0.56703967,1.1008193,,,,,,,,,,,,,, -47211,,,0.1242431,0.0480329439796949,0.3734636,0.1081803875377738,5348.0,0.20668186,0.0686937623138951,2472.0,37484.02665233612,41042.178755521774,37484.02665233612,3554.784821510315,1.4327740669250488,0.0 -47300,0.5640988,1.0971789,,,,,,,,,,,,,, -47400,0.50106853,1.1175755,,,,,,,,,,,,,, -47500,0.5342957,1.1619567,,,,,,,,,,,,,, -47600,0.55066437,1.1028386,,,,,,,,,,,,,, -47700,0.5901452,1.0832541,,,,,,,,,,,,,, -47800,0.5293155,1.0835698,,,,,,,,,,,,,, -47900,0.61198974,1.1407524,,,,,,,,,,,,,, -48000,0.6218817,1.0968969,,,,,,,,,,,,,, -48100,0.5847245,1.0942518,,,,,,,,,,,,,, -48200,0.6515476,1.0888004,,,,,,,,,,,,,, -48300,0.55761576,1.1254411,,,,,,,,,,,,,, -48400,0.7100646,1.1044865,,,,,,,,,,,,,, -48500,0.5610421,1.1180657,,,,,,,,,,,,,, -48600,0.53101856,1.0982352,,,,,,,,,,,,,, -48700,0.5807228,1.0446904,,,,,,,,,,,,,, -48800,0.74384725,1.0752686,,,,,,,,,,,,,, -48900,0.52884513,1.095788,,,,,,,,,,,,,, -49000,0.55168617,1.0942546,,,,,,,,,,,,,, -49025,,,0.12451019,0.0470675119749111,0.36440134,0.1062977301910655,5348.0,0.20458083,0.0670891475230028,2472.0,38924.49155926704,42617.03544545174,38924.49155926704,3689.043043851853,1.4914908409118652,0.0 -49100,0.6253623,1.0974008,,,,,,,,,,,,,, -49200,0.5452111,1.0924627,,,,,,,,,,,,,, -49300,0.52043945,1.0866458,,,,,,,,,,,,,, -49400,0.64896494,1.0818315,,,,,,,,,,,,,, -49500,0.60521245,1.0356354,,,,,,,,,,,,,, -49600,0.6343791,1.0970153,,,,,,,,,,,,,, -49700,0.58402807,1.0848061,,,,,,,,,,,,,, -49800,0.57361335,1.0432496,,,,,,,,,,,,,, -49900,0.5341043,1.0739467,,,,,,,,,,,,,, -50000,0.6589286,1.0909706,,,,,,,,,,,,,, -50100,0.59561485,1.0870527,,,,,,,,,,,,,, -50200,0.5545997,1.0810633,,,,,,,,,,,,,, -50300,0.5905967,1.0532745,,,,,,,,,,,,,, -50400,0.5663696,1.0953679,,,,,,,,,,,,,, -50500,0.5285265,1.0529507,,,,,,,,,,,,,, -50600,0.60401535,1.0766939,,,,,,,,,,,,,, -50700,0.88864046,1.093146,,,,,,,,,,,,,, -50800,0.52543396,1.0318941,,,,,,,,,,,,,, -50856,,,0.11252448,0.0436147960437567,0.35983658,0.1053129555789412,5348.0,0.1984191,0.0658501411654784,2472.0,40364.53780937195,44188.77558088303,40364.53780937195,3820.602771282196,1.5500528812408447,0.0 -50900,0.53551626,1.048874,,,,,,,,,,,,,, -51000,0.5642274,1.0690471,,,,,,,,,,,,,, -51100,0.53239244,1.1027614,,,,,,,,,,,,,, -51200,0.602214,1.1235921,,,,,,,,,,,,,, -51300,0.7102248,1.0828562,,,,,,,,,,,,,, -51400,0.6020523,1.0925128,,,,,,,,,,,,,, -51500,0.6770331,1.0588156,,,,,,,,,,,,,, -51600,0.70885384,1.0554403,,,,,,,,,,,,,, -51700,0.6218951,1.0996441,,,,,,,,,,,,,, -51800,0.7350124,1.0509636,,,,,,,,,,,,,, -51900,0.6445479,1.0727209,,,,,,,,,,,,,, -52000,0.604234,1.0454357,,,,,,,,,,,,,, -52100,0.6064952,1.0463215,,,,,,,,,,,,,, -52200,0.605426,0.98934543,,,,,,,,,,,,,, -52300,0.5220032,1.0623444,,,,,,,,,,,,,, -52400,0.6528823,1.0708183,,,,,,,,,,,,,, -52500,0.5417214,1.0762156,,,,,,,,,,,,,, -52600,0.64396703,1.0402889,,,,,,,,,,,,,, -52665,,,0.10154488,0.0393358401807406,0.34599328,0.100929743089682,5348.0,0.19205548,0.0630065200170617,2472.0,41804.75467848778,45763.38461208344,41804.75467848778,3954.864999771118,1.6051020622253418,0.0 -52700,0.539107,1.0832356,,,,,,,,,,,,,, -52800,0.56240165,1.036948,,,,,,,,,,,,,, -52900,0.594309,1.0607655,,,,,,,,,,,,,, -53000,0.5860029,1.072617,,,,,,,,,,,,,, -53100,0.66367316,1.0851897,,,,,,,,,,,,,, -53200,0.7217513,1.0869938,,,,,,,,,,,,,, -53300,0.5446538,1.0850405,,,,,,,,,,,,,, -53400,0.6127608,1.0482633,,,,,,,,,,,,,, -53500,0.52350295,1.0023925,,,,,,,,,,,,,, -53600,0.5779395,1.0330594,,,,,,,,,,,,,, -53700,0.5503523,1.0364476,,,,,,,,,,,,,, -53800,0.5755447,1.0078967,,,,,,,,,,,,,, -53900,0.70843506,1.0151523,,,,,,,,,,,,,, -54000,0.6388469,1.0086601,,,,,,,,,,,,,, -54100,0.71186113,1.022549,,,,,,,,,,,,,, -54200,0.7204798,1.0319391,,,,,,,,,,,,,, -54300,0.49802017,1.0429907,,,,,,,,,,,,,, -54400,0.5955081,0.99405175,,,,,,,,,,,,,, -54472,,,0.116006896,0.0441252374833898,0.3398731,0.1001863348040588,5348.0,0.18795243,0.0621940568317998,2472.0,43245.01002693176,47336.69872045517,43245.01002693176,4087.795857667923,1.659766435623169,0.0 -54500,0.6099982,1.0834296,,,,,,,,,,,,,, -54600,0.61576134,0.9882208,,,,,,,,,,,,,, -54700,0.65729606,1.0506855,,,,,,,,,,,,,, -54800,0.60678333,1.0348358,,,,,,,,,,,,,, -54900,0.62545544,1.0131563,,,,,,,,,,,,,, -55000,0.62810117,1.033695,,,,,,,,,,,,,, -55100,0.65697813,0.9950699,,,,,,,,,,,,,, -55200,0.63667464,1.0215664,,,,,,,,,,,,,, -55300,0.63606477,1.0389124,,,,,,,,,,,,,, -55400,0.63937676,1.0222226,,,,,,,,,,,,,, -55500,0.7352382,1.0721191,,,,,,,,,,,,,, -55600,0.6924561,1.0514297,,,,,,,,,,,,,, -55700,0.7585145,1.0345923,,,,,,,,,,,,,, -55800,0.6188052,1.0471812,,,,,,,,,,,,,, -55900,0.7108915,1.0598644,,,,,,,,,,,,,, -56000,0.629931,1.0359929,,,,,,,,,,,,,, -56100,0.6049313,0.96815073,,,,,,,,,,,,,, -56200,0.6073003,1.0388821,,,,,,,,,,,,,, -56285,,,0.10221567,0.0367382905171829,0.3412774,0.097502341253367,5348.0,0.1826138,0.0587613998740682,2472.0,44685.05935645104,48910.888063669205,44685.05935645104,4221.798095703125,1.7223410606384275,0.0 -56300,0.74586725,1.0478555,,,,,,,,,,,,,, -56400,0.56781435,1.0360057,,,,,,,,,,,,,, -56500,0.6805601,1.0464277,,,,,,,,,,,,,, -56600,0.667235,1.0420628,,,,,,,,,,,,,, -56700,0.6796938,1.0029229,,,,,,,,,,,,,, -56800,0.66389024,0.9799044,,,,,,,,,,,,,, -56900,0.7079634,1.014863,,,,,,,,,,,,,, -57000,0.55412424,1.0220456,,,,,,,,,,,,,, -57100,0.6492611,1.0111699,,,,,,,,,,,,,, -57200,0.65284425,0.9790265,,,,,,,,,,,,,, -57300,0.7460411,1.0235007,,,,,,,,,,,,,, -57400,0.53887445,0.93920237,,,,,,,,,,,,,, -57500,0.63363373,1.0233638,,,,,,,,,,,,,, -57600,0.7363586,1.037775,,,,,,,,,,,,,, -57700,0.6058364,0.97984934,,,,,,,,,,,,,, -57800,0.61876434,0.99924,,,,,,,,,,,,,, -57900,0.58845806,0.9833587,,,,,,,,,,,,,, -58000,0.7279387,1.0001624,,,,,,,,,,,,,, -58100,0.63871044,1.0098305,,,,,,,,,,,,,, -58115,,,0.095418066,0.0367302365224802,0.3292542,0.0944707801925137,5348.0,0.18172288,0.0592082546259622,2472.0,46125.22264838219,50485.74221920967,46125.22264838219,4356.356876373291,1.7787425518035889,0.0 -58200,0.58403176,0.9928555,,,,,,,,,,,,,, -58300,0.55817753,1.0593606,,,,,,,,,,,,,, -58400,0.79117185,0.97516483,,,,,,,,,,,,,, -58500,0.7140864,0.9969865,,,,,,,,,,,,,, -58600,0.59535474,0.953222,,,,,,,,,,,,,, -58700,0.7142656,1.0136157,,,,,,,,,,,,,, -58800,0.87538445,1.0178,,,,,,,,,,,,,, -58900,0.65122384,0.96846455,,,,,,,,,,,,,, -59000,0.8412113,0.9758583,,,,,,,,,,,,,, -59100,0.63342154,0.9693029,,,,,,,,,,,,,, -59200,0.56701607,0.99800867,,,,,,,,,,,,,, -59300,0.5527941,0.9681371,,,,,,,,,,,,,, -59400,0.6845722,0.98046225,,,,,,,,,,,,,, -59500,0.79457086,1.0160624,,,,,,,,,,,,,, -59600,0.7922184,1.0037456,,,,,,,,,,,,,, -59700,0.7858404,0.9593889,,,,,,,,,,,,,, -59800,0.7174199,0.9862157,,,,,,,,,,,,,, -59900,0.5838883,0.9789092,,,,,,,,,,,,,, -59930,,,0.0784812,0.0307128037937166,0.32586622,0.0934763509273294,5348.0,0.17646027,0.0567505534905449,2472.0,47565.73729014397,52059.585463523865,47565.73729014397,4489.549370765686,1.8408918380737305,0.0 -60000,0.72158647,0.98732483,,,,,,,,,,,,,, -60100,0.6684737,1.0181397,,,,,,,,,,,,,, -60200,0.59915936,0.9871982,,,,,,,,,,,,,, -60300,0.5635998,1.002029,,,,,,,,,,,,,, -60400,0.6888448,1.025129,,,,,,,,,,,,,, -60500,0.9213573,0.9949475,,,,,,,,,,,,,, -60600,0.6024477,0.9907301,,,,,,,,,,,,,, -60700,0.77333784,0.9867024,,,,,,,,,,,,,, -60800,0.78620255,1.0152649,,,,,,,,,,,,,, -60900,0.59693515,1.0001965,,,,,,,,,,,,,, -61000,0.6993312,1.002143,,,,,,,,,,,,,, -61100,0.7597167,1.0323211,,,,,,,,,,,,,, -61200,0.6079025,1.0067139,,,,,,,,,,,,,, -61300,0.6451427,0.9629003,,,,,,,,,,,,,, -61400,0.65889704,1.0139507,,,,,,,,,,,,,, -61500,0.6749081,0.96368295,,,,,,,,,,,,,, -61600,0.6007143,0.9738884,,,,,,,,,,,,,, -61700,0.5588691,0.9459401,,,,,,,,,,,,,, -61736,,,0.07178165,0.0287519646911267,0.31728697,0.0908309759888778,5348.0,0.17256908,0.0555521702922836,2472.0,49006.57438802719,53634.54512619972,49006.57438802719,4623.5429792404175,1.8968467712402344,0.0 -61800,0.7557779,0.9805225,,,,,,,,,,,,,, -61900,0.6748026,0.9414806,,,,,,,,,,,,,, -62000,0.646684,0.9461662,,,,,,,,,,,,,, -62100,0.7742704,0.98044276,,,,,,,,,,,,,, -62200,0.62300676,0.98101145,,,,,,,,,,,,,, -62300,0.6630245,0.9920972,,,,,,,,,,,,,, -62400,0.87902045,0.96938956,,,,,,,,,,,,,, -62500,0.6452883,0.97090465,,,,,,,,,,,,,, -62600,0.64793056,0.94547915,,,,,,,,,,,,,, -62700,0.63541853,0.96257424,,,,,,,,,,,,,, -62800,0.5798585,0.9572721,,,,,,,,,,,,,, -62900,0.5759014,0.97468567,,,,,,,,,,,,,, -63000,0.69249785,0.95877486,,,,,,,,,,,,,, -63100,0.67659074,0.9564641,,,,,,,,,,,,,, -63200,0.5957149,0.97063607,,,,,,,,,,,,,, -63300,0.9628462,0.97460693,,,,,,,,,,,,,, -63400,0.6224033,0.9263619,,,,,,,,,,,,,, -63500,0.614802,0.96052974,,,,,,,,,,,,,, -63544,,,0.078863874,0.0309223185523101,0.3153456,0.0890062465605298,5348.0,0.16620629,0.0538866207624967,2472.0,50446.679966926575,55208.704122543335,50446.679966926575,4757.461458683014,1.9577386379241943,0.0 -63600,0.54373366,0.984943,,,,,,,,,,,,,, -63700,0.7003053,0.98248917,,,,,,,,,,,,,, -63800,0.894167,0.921088,,,,,,,,,,,,,, -63900,0.67794573,0.95957863,,,,,,,,,,,,,, -64000,0.6349755,0.9775409,,,,,,,,,,,,,, -64100,0.627843,0.9372321,,,,,,,,,,,,,, -64200,0.70902604,0.98697287,,,,,,,,,,,,,, -64300,0.7507208,0.93020475,,,,,,,,,,,,,, -64400,0.91054887,0.92237216,,,,,,,,,,,,,, -64500,1.0847008,0.97315097,,,,,,,,,,,,,, -64600,0.6216388,0.97257507,,,,,,,,,,,,,, -64700,0.7630994,0.9819627,,,,,,,,,,,,,, -64800,0.5343177,0.9697809,,,,,,,,,,,,,, -64900,0.7172954,0.92766356,,,,,,,,,,,,,, -65000,0.6734307,0.94079334,,,,,,,,,,,,,, -65100,0.85484326,0.9497004,,,,,,,,,,,,,, -65200,0.6356684,0.93734133,,,,,,,,,,,,,, -65300,0.7652901,0.9376279,,,,,,,,,,,,,, -65377,,,0.072210774,0.0280750267393629,0.3099662,0.0889579732952296,5348.0,0.16686207,0.0535413239087603,2472.0,51887.27757143974,56783.84284877777,51887.27757143974,4891.862742900848,2.0222373008728027,0.0 -65400,0.64340764,0.9563217,,,,,,,,,,,,,, -65500,0.675321,0.9181692,,,,,,,,,,,,,, -65600,0.65847874,0.92373955,,,,,,,,,,,,,, -65700,0.69638216,0.9416376,,,,,,,,,,,,,, -65800,0.7043892,0.9439946,,,,,,,,,,,,,, -65900,0.66277784,0.9588669,,,,,,,,,,,,,, -66000,0.6778772,0.9486509,,,,,,,,,,,,,, -66100,0.69417936,0.97777873,,,,,,,,,,,,,, -66200,0.82397085,0.91567415,,,,,,,,,,,,,, -66300,0.619235,0.9415149,,,,,,,,,,,,,, -66400,0.93888646,0.95549655,,,,,,,,,,,,,, -66500,0.73774946,0.8727042,,,,,,,,,,,,,, -66600,0.68889624,0.8998037,,,,,,,,,,,,,, -66700,0.80153966,0.92138124,,,,,,,,,,,,,, -66800,0.76117116,0.93239915,,,,,,,,,,,,,, -66900,0.66517574,0.9732887,,,,,,,,,,,,,, -67000,0.72346455,0.9565093,,,,,,,,,,,,,, -67100,1.274399,0.9511656,,,,,,,,,,,,,, -67194,,,0.06876019,0.0267672908902852,0.30606192,0.0876642497851839,5348.0,0.16483718,0.0523835638697621,2472.0,53327.70025777817,58357.71854352951,53327.70025777817,5025.178674221039,2.083531856536865,0.0 -67200,0.72012055,0.9304152,,,,,,,,,,,,,, -67300,0.7081412,0.8822393,,,,,,,,,,,,,, -67400,0.6681836,0.94236845,,,,,,,,,,,,,, -67500,0.6836705,0.98422855,,,,,,,,,,,,,, -67600,0.75192034,0.9360547,,,,,,,,,,,,,, -67700,0.62759095,0.9166041,,,,,,,,,,,,,, -67800,0.6295666,0.93738025,,,,,,,,,,,,,, -67900,0.7116404,0.90877384,,,,,,,,,,,,,, -68000,0.6417971,0.9372863,,,,,,,,,,,,,, -68100,0.6707054,0.9117065,,,,,,,,,,,,,, -68200,0.6396331,0.9133608,,,,,,,,,,,,,, -68300,0.6406322,0.8838813,,,,,,,,,,,,,, -68400,0.7036108,0.9467315,,,,,,,,,,,,,, -68500,0.66161907,0.86037546,,,,,,,,,,,,,, -68600,0.76588476,0.91706157,,,,,,,,,,,,,, -68700,0.70718956,0.89349526,,,,,,,,,,,,,, -68800,0.66499156,0.89747846,,,,,,,,,,,,,, -68900,0.65995103,0.8981164,,,,,,,,,,,,,, -68998,,,0.060892846,0.0234735113973606,0.30153775,0.0862160518261776,5348.0,0.1602182,0.0511851806715008,2472.0,54768.31900382042,59931.1682472229,54768.31900382042,5157.87672996521,2.142979860305786,0.0 -69000,0.68282557,0.93148243,,,,,,,,,,,,,, -69100,0.6611111,0.9117061,,,,,,,,,,,,,, -69200,0.60238963,0.95932156,,,,,,,,,,,,,, -69300,0.6947465,0.94759864,,,,,,,,,,,,,, -69400,0.81820804,0.9111263,,,,,,,,,,,,,, -69500,0.649548,0.9117233,,,,,,,,,,,,,, -69600,0.5904985,0.8954685,,,,,,,,,,,,,, -69700,0.7578811,0.89632094,,,,,,,,,,,,,, -69800,0.6242142,0.8678452,,,,,,,,,,,,,, -69900,0.80958396,0.8924598,,,,,,,,,,,,,, -70000,0.68338484,0.91482747,,,,,,,,,,,,,, -70100,0.7313811,0.8948128,,,,,,,,,,,,,, -70200,0.73693097,0.8715966,,,,,,,,,,,,,, -70300,0.6676286,0.9041207,,,,,,,,,,,,,, -70400,0.6118239,0.91193676,,,,,,,,,,,,,, -70500,0.7336369,0.88316756,,,,,,,,,,,,,, -70600,0.57578033,0.90498877,,,,,,,,,,,,,, -70700,0.57623833,0.90480787,,,,,,,,,,,,,, -70800,0.67058057,0.9233258,,,,,,,,,,,,,, -70808,,,0.05510921,0.0211103238444546,0.29850617,0.0841113374590884,5348.0,0.15855892,0.0502711595880811,2472.0,56208.43280529976,61505.95262885094,56208.43280529976,5292.408869028091,2.206705570220948,0.0 -70900,0.80451965,0.9136451,,,,,,,,,,,,,, -71000,0.6258044,0.91171134,,,,,,,,,,,,,, -71100,0.778014,0.8866034,,,,,,,,,,,,,, -71200,0.88831645,0.9192585,,,,,,,,,,,,,, -71300,0.80810964,0.8851986,,,,,,,,,,,,,, -71400,0.62770665,0.8897743,,,,,,,,,,,,,, -71500,0.60834247,0.89494234,,,,,,,,,,,,,, -71600,0.649678,0.8622108,,,,,,,,,,,,,, -71700,0.7422218,0.9072295,,,,,,,,,,,,,, -71800,0.7940734,0.8860293,,,,,,,,,,,,,, -71900,0.6261519,0.8739318,,,,,,,,,,,,,, -72000,0.66222745,0.93907773,,,,,,,,,,,,,, -72100,0.6547651,0.88974416,,,,,,,,,,,,,, -72200,0.66106904,0.8861896,,,,,,,,,,,,,, -72300,0.6709451,0.8242285,,,,,,,,,,,,,, -72400,0.63804984,0.8779848,,,,,,,,,,,,,, -72500,1.2074183,0.90666336,,,,,,,,,,,,,, -72600,0.6578562,0.8687562,,,,,,,,,,,,,, -72636,,,0.059947684,0.0222034592160474,0.29521114,0.0831651814592042,5348.0,0.158432,0.0497633700972924,2472.0,57648.75893139839,63082.29228138924,57648.75893139839,5428.279653549194,2.27409291267395,0.0 -72700,0.76037717,0.8846775,,,,,,,,,,,,,, -72800,0.8810499,0.921081,,,,,,,,,,,,,, -72900,0.8333668,0.87748665,,,,,,,,,,,,,, -73000,0.7654715,0.91658634,,,,,,,,,,,,,, -73100,0.6791277,0.86501026,,,,,,,,,,,,,, -73200,0.7196177,0.90913785,,,,,,,,,,,,,, -73300,0.70046234,0.8722578,,,,,,,,,,,,,, -73400,0.81533223,0.8701471,,,,,,,,,,,,,, -73500,0.795059,0.8822004,,,,,,,,,,,,,, -73600,0.75762856,0.88170207,,,,,,,,,,,,,, -73700,0.8183144,0.8593964,,,,,,,,,,,,,, -73800,0.64964443,0.84653264,,,,,,,,,,,,,, -73900,0.74670535,0.90059876,,,,,,,,,,,,,, -74000,0.70849365,0.89199287,,,,,,,,,,,,,, -74100,0.5690579,0.8632935,,,,,,,,,,,,,, -74200,0.74838084,0.8842021,,,,,,,,,,,,,, -74300,0.666281,0.86171305,,,,,,,,,,,,,, -74400,1.0608127,0.86398005,,,,,,,,,,,,,, -74455,,,0.057411306,0.0215887926147866,0.29396224,0.082431427826641,5348.0,0.15610263,0.0496618121991347,2472.0,59089.160449266434,64654.3326253891,59089.160449266434,5559.786453962326,2.330268144607544,0.0 -74500,0.733427,0.9241276,,,,,,,,,,,,,, -74600,0.6187129,0.81199086,,,,,,,,,,,,,, -74700,0.66265684,0.87532884,,,,,,,,,,,,,, -74800,0.65869355,0.8444329,,,,,,,,,,,,,, -74900,0.98342067,0.8446436,,,,,,,,,,,,,, -75000,0.8067601,0.8704201,,,,,,,,,,,,,, -75100,0.72763914,0.8548028,,,,,,,,,,,,,, -75200,0.88964355,0.88371235,,,,,,,,,,,,,, -75300,0.9935409,0.9117662,,,,,,,,,,,,,, -75400,0.64074004,0.85713834,,,,,,,,,,,,,, -75500,0.69233364,0.9189581,,,,,,,,,,,,,, -75600,1.1392998,0.9043458,,,,,,,,,,,,,, -75700,0.782333,0.88033324,,,,,,,,,,,,,, -75800,0.659488,0.8810511,,,,,,,,,,,,,, -75900,0.71593124,0.8865827,,,,,,,,,,,,,, -76000,0.7035078,0.86864954,,,,,,,,,,,,,, -76100,0.93992317,0.8905173,,,,,,,,,,,,,, -76200,0.6931815,0.8945333,,,,,,,,,,,,,, -76255,,,0.05751698,0.0213026943060688,0.29176736,0.0821417882348397,5348.0,0.15516719,0.0494383848231877,2472.0,60529.28284430504,66228.47186899185,60529.28284430504,5693.6730353832245,2.3865671157836914,0.0 -76300,0.71171755,0.8375066,,,,,,,,,,,,,, -76400,0.77068275,0.86174583,,,,,,,,,,,,,, -76500,0.7418491,0.9233072,,,,,,,,,,,,,, -76600,0.6703217,0.8862985,,,,,,,,,,,,,, -76700,0.7129224,0.87402004,,,,,,,,,,,,,, -76800,0.7411842,0.8943901,,,,,,,,,,,,,, -76900,0.69574153,0.8746201,,,,,,,,,,,,,, -76955,,,,,,,,,,,61068.16817140579,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index a9451e9d0..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,30 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -204.89513993263245,0.0,43.85169339179993,1,0,43.85169339179993,28.255945,2472,2.343875043162107,248.746901512146,29.405813,2.4046810903108926,28.149015,5348,2.185263137569152 -315.39243960380554,0.0418043136596679,1483.9563069343567,1737,0,1483.9563069343567,6.369111,2472,0.899579550301627,1799.4650785923004,6.6251674,0.9446252394389404,6.4828987,5348,0.8966179750330672 -436.4544777870178,0.0805685520172119,2924.2540802955627,3504,0,2924.2540802955627,3.4329967,2472,0.6812706924217496,3360.9410836696625,3.5062532,0.7003595876728018,3.887965,5348,0.7383395927667339 -577.4759426116943,0.1174452304840087,4364.524667263031,5218,0,4364.524667263031,0.794886,2472,0.2441248755915747,4942.343545675278,0.7218365,0.2371396492559735,1.1679587,5348,0.3178311787365921 -715.8265788555145,0.1547858715057373,5804.555928230286,6930,0,5804.555928230286,0.5532459,2472,0.178376292324254,6520.837166547775,0.5146038,0.1715433191257631,0.8652903,5348,0.2476997789084449 -859.2827861309052,0.1961219310760498,7245.780788183212,8665,0,7245.780788183212,0.5063116,2472,0.16480815713038,8105.636641263962,0.46659666,0.1569886277153232,0.8137787,5348,0.2347432345018681 -991.9161324501038,0.239854097366333,8685.838481426239,10391,0,8685.838481426239,0.45595506,2472,0.1479089228769321,9678.446252822876,0.43561387,0.1443470560306426,0.7429956,5348,0.2146808654431003 -1125.863081932068,0.2782533168792724,10125.72972393036,12124,0,10125.72972393036,0.44069472,2472,0.1428310279690451,11252.399189710615,0.39994377,0.1340533469465892,0.71666336,5348,0.2062620079747434 -1262.8164167404177,0.3170113563537597,11567.119065999985,13863,0,11567.119065999985,0.41449788,2472,0.1320862023439563,12830.857915639876,0.34301874,0.1186406302815985,0.683678,5348,0.1980072796084072 -1401.224974155426,0.3565559387207031,13007.51091980934,15594,0,13007.51091980934,0.39703676,2472,0.1269270611175431,14409.774055480955,0.31959072,0.108528533960387,0.67307526,5348,0.1947343522210529 -1544.6018645763395,0.3955864906311035,14447.46181845665,17329,0,14447.46181845665,0.39617884,2472,0.1280238864176467,15993.21812081337,0.32896024,0.1162404071497905,0.6574222,5348,0.1909786921806965 -1678.731039762497,0.4372310638427734,15887.366396427156,19065,0,15887.366396427156,0.3724943,2472,0.1213616882984989,17567.369089365005,0.32634243,0.1093841719770485,0.6247462,5348,0.1807544145901117 -1814.3625190258024,0.4785201549530029,17327.25873041153,20805,0,17327.25873041153,0.35500604,2472,0.1163853512887697,19143.011053800583,0.31835115,0.1072308544595984,0.6072395,5348,0.1752319530397675 -1949.846531391144,0.5190324783325195,18767.39107894897,22510,0,18767.39107894897,0.34486303,2472,0.1103731237178315,20718.74327898025,0.29313225,0.0993517851630826,0.5852561,5348,0.1685316238160981 -2084.754161596298,0.5621540546417236,20207.345369815823,24245,0,20207.345369815823,0.33417776,2472,0.1059045761988909,22293.72416448593,0.27176955,0.0930756724800472,0.57057077,5348,0.164611834673721 -2217.822919368744,0.6007964611053467,21647.66396617889,25965,0,21647.66396617889,0.31846216,2472,0.1046249466821034,23867.22459101677,0.2505488,0.0890581392315918,0.5587426,5348,0.1626229761433522 -2352.263339281082,0.6415128707885742,23087.62013030052,27671,0,23087.62013030052,0.31094378,2472,0.1007048118132147,25441.73743915558,0.2422054,0.0829080197604606,0.5420699,5348,0.1576701391235506 -2490.2884378433228,0.685178279876709,24528.12004709244,29371,0,24528.12004709244,0.30119932,2472,0.0959519021794324,27020.380325078964,0.25566694,0.0881056075576356,0.52995616,5348,0.1546771966749374 -2622.43061876297,0.7326092720031738,25968.7009601593,31104,0,25968.7009601593,0.29052296,2472,0.0922755062661223,28593.22857093811,0.23302117,0.0778354738766895,0.51413625,5348,0.1478996302267878 -2758.963734149933,0.7755627632141113,27408.875204086304,32821,0,27408.875204086304,0.28367037,2472,0.0900615440862835,30170.055107831955,0.2403235,0.0805778012934776,0.50835735,5348,0.1464900508800216 -2893.2134778499603,0.820319414138794,28848.791372060776,34549,0,28848.791372060776,0.26772746,2472,0.0863242134340787,31744.34097886085,0.22173432,0.0732813282607073,0.48096427,5348,0.13927802504417 -3029.243176460266,0.8601865768432617,30288.866734981537,36283,0,30288.866734981537,0.2593685,2472,0.0838665122986614,33320.562363147736,0.16659415,0.0579881717454362,0.46956345,5348,0.1348368846365505 -3168.526449203491,0.9034137725830078,31729.2007689476,37984,0,31729.2007689476,0.24856593,2472,0.0794792110982471,34900.29788470268,0.1848998,0.0630979282380519,0.4517904,5348,0.1306564198615522 -3301.976773738861,0.9474256038665771,33169.299060583115,39697,0,33169.299060583115,0.2419337,2472,0.0779761542055125,36473.96536016464,0.23087613,0.0794465266891946,0.44235966,5348,0.1274221110864381 -3436.159049510956,0.9893152713775636,34609.19066643715,41414,0,34609.19066643715,0.23373564,2472,0.075620010968253,38048.15858435631,0.2321512,0.0798170281507963,0.4297312,5348,0.1231837183930795 -3569.5378217697144,1.0312442779541016,36049.38084578514,43118,0,36049.38084578514,0.22829898,2472,0.073101375093941,39621.84571695328,0.26301807,0.0910260834976104,0.42406932,5348,0.1224016914952161 -3701.4083173274994,1.073310613632202,37489.69993138313,44835,0,37489.69993138313,0.22612567,2472,0.0724514045457315,41194.15355873108,0.23237117,0.0770743425042243,0.42018595,5348,0.1203452503934271 -3834.3490607738495,1.1359052658081057,38929.57273769379,46551,0,38929.57273769379,0.22410408,2472,0.0715983182012065,42767.10988521576,0.21373685,0.0742041797957979,0.41774118,5348,0.1196887339853442 -3969.007215976715,1.1929059028625488,40149.33516192436,48000,0,40149.33516192436,0.22410974,2472,0.07137489082525948,44121.655792713165,0.19231977,0.06766794636157,0.41810927,5348,0.11975631655676454 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index ea786281d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,511 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.398708,32.61559,,,,,,,,,,,,,, -1,,,29.405813,2.4046810903108926,28.149015,2.185263137569152,5348.0,28.255945,2.343875043162107,2472.0,43.85169339179993,248.746901512146,43.85169339179993,204.89513993263245,0.0,0.0 -100,5.1085215,9.070861,,,,,,,,,,,,,, -200,1.777264,6.510295,,,,,,,,,,,,,, -300,0.52845174,5.9052043,,,,,,,,,,,,,, -400,0.3313454,5.8404036,,,,,,,,,,,,,, -500,0.37015814,5.830102,,,,,,,,,,,,,, -600,0.6390093,5.8130426,,,,,,,,,,,,,, -700,0.38070086,5.742335,,,,,,,,,,,,,, -800,0.59922594,5.6289177,,,,,,,,,,,,,, -900,0.4995264,5.5087457,,,,,,,,,,,,,, -1000,0.52683777,5.389214,,,,,,,,,,,,,, -1100,0.7466925,5.221553,,,,,,,,,,,,,, -1200,0.8765487,4.796853,,,,,,,,,,,,,, -1300,1.0552169,4.3396077,,,,,,,,,,,,,, -1400,1.2957213,4.01505,,,,,,,,,,,,,, -1500,1.8548398,3.7182877,,,,,,,,,,,,,, -1600,3.1208067,3.5173218,,,,,,,,,,,,,, -1700,1.8177955,3.302434,,,,,,,,,,,,,, -1737,,,6.6251674,0.9446252394389404,6.4828987,0.8966179750330672,5348.0,6.369111,0.899579550301627,2472.0,1483.9563069343567,1799.4650785923004,1483.9563069343567,315.39243960380554,0.0418043136596679,0.0 -1800,2.4579785,3.1480901,,,,,,,,,,,,,, -1900,2.012009,3.0571241,,,,,,,,,,,,,, -2000,2.4520738,2.965278,,,,,,,,,,,,,, -2100,2.6451144,2.858619,,,,,,,,,,,,,, -2200,2.8728986,2.751056,,,,,,,,,,,,,, -2300,3.566946,2.7148726,,,,,,,,,,,,,, -2400,2.4731784,2.6518383,,,,,,,,,,,,,, -2500,3.9669971,2.6543612,,,,,,,,,,,,,, -2600,2.6564171,2.5630388,,,,,,,,,,,,,, -2700,3.0697412,2.444043,,,,,,,,,,,,,, -2800,2.5152693,2.4802027,,,,,,,,,,,,,, -2900,2.7918656,2.4740226,,,,,,,,,,,,,, -3000,2.7497365,2.368186,,,,,,,,,,,,,, -3100,4.8078036,2.3284032,,,,,,,,,,,,,, -3200,3.9053826,2.2233317,,,,,,,,,,,,,, -3300,4.913415,2.3054109,,,,,,,,,,,,,, -3400,3.5811586,2.151973,,,,,,,,,,,,,, -3500,4.179702,2.1721992,,,,,,,,,,,,,, -3504,,,3.5062532,0.7003595876728018,3.887965,0.7383395927667339,5348.0,3.4329967,0.6812706924217496,2472.0,2924.2540802955627,3360.9410836696625,2924.2540802955627,436.4544777870178,0.0805685520172119,0.0 -3600,3.1829526,2.154426,,,,,,,,,,,,,, -3700,3.4416227,2.1347575,,,,,,,,,,,,,, -3800,3.5311472,2.1965299,,,,,,,,,,,,,, -3900,3.907694,2.177648,,,,,,,,,,,,,, -4000,4.3847604,2.0996594,,,,,,,,,,,,,, -4100,3.4973483,2.1148484,,,,,,,,,,,,,, -4200,4.187183,2.0517867,,,,,,,,,,,,,, -4300,3.5326555,2.0810215,,,,,,,,,,,,,, -4400,4.5431185,2.0592303,,,,,,,,,,,,,, -4500,4.4239244,2.0807314,,,,,,,,,,,,,, -4600,3.0550838,2.0665774,,,,,,,,,,,,,, -4700,3.5137184,2.037611,,,,,,,,,,,,,, -4800,6.6926565,2.0264046,,,,,,,,,,,,,, -4900,3.816354,1.9780138,,,,,,,,,,,,,, -5000,4.329701,1.9567621,,,,,,,,,,,,,, -5100,4.1264153,1.9815439,,,,,,,,,,,,,, -5200,5.1036057,1.9750175,,,,,,,,,,,,,, -5218,,,0.7218365,0.2371396492559735,1.1679587,0.3178311787365921,5348.0,0.794886,0.2441248755915747,2472.0,4364.524667263031,4942.343545675278,4364.524667263031,577.4759426116943,0.1174452304840087,0.0 -5300,4.715253,1.9312832,,,,,,,,,,,,,, -5400,4.6010056,1.9054682,,,,,,,,,,,,,, -5500,3.131726,1.9735377,,,,,,,,,,,,,, -5600,4.436184,1.8797752,,,,,,,,,,,,,, -5700,4.90608,1.9372238,,,,,,,,,,,,,, -5800,3.384652,1.8795785,,,,,,,,,,,,,, -5900,3.6930077,1.8717954,,,,,,,,,,,,,, -6000,4.4883885,1.8274652,,,,,,,,,,,,,, -6100,3.2653232,1.8936096,,,,,,,,,,,,,, -6200,2.902279,1.7421424,,,,,,,,,,,,,, -6300,3.0558023,1.8415514,,,,,,,,,,,,,, -6400,3.3506541,1.8629076,,,,,,,,,,,,,, -6500,4.939678,1.8705146,,,,,,,,,,,,,, -6600,3.5608222,1.772387,,,,,,,,,,,,,, -6700,5.2793293,1.8865469,,,,,,,,,,,,,, -6800,2.4084136,1.7424691,,,,,,,,,,,,,, -6900,8.998476,1.8027107,,,,,,,,,,,,,, -6930,,,0.5146038,0.1715433191257631,0.8652903,0.2476997789084449,5348.0,0.5532459,0.178376292324254,2472.0,5804.555928230286,6520.837166547775,5804.555928230286,715.8265788555145,0.1547858715057373,0.0 -7000,3.1535628,1.7718748,,,,,,,,,,,,,, -7100,2.4368565,1.7199066,,,,,,,,,,,,,, -7200,2.9826777,1.853961,,,,,,,,,,,,,, -7300,3.9120917,1.7712891,,,,,,,,,,,,,, -7400,3.8803568,1.8002744,,,,,,,,,,,,,, -7500,3.4297466,1.7567983,,,,,,,,,,,,,, -7600,2.3862135,1.7660505,,,,,,,,,,,,,, -7700,3.0300927,1.7446017,,,,,,,,,,,,,, -7800,3.553201,1.7625946,,,,,,,,,,,,,, -7900,4.003544,1.7037511,,,,,,,,,,,,,, -8000,3.7526674,1.7017454,,,,,,,,,,,,,, -8100,3.1776042,1.6999923,,,,,,,,,,,,,, -8200,3.286068,1.7199624,,,,,,,,,,,,,, -8300,2.9268498,1.7126951,,,,,,,,,,,,,, -8400,3.4740915,1.7072426,,,,,,,,,,,,,, -8500,2.3861845,1.6728392,,,,,,,,,,,,,, -8600,2.7464814,1.6659585,,,,,,,,,,,,,, -8665,,,0.46659666,0.1569886277153232,0.8137787,0.2347432345018681,5348.0,0.5063116,0.16480815713038,2472.0,7245.780788183212,8105.636641263962,7245.780788183212,859.2827861309052,0.1961219310760498,0.0 -8700,2.9009328,1.691641,,,,,,,,,,,,,, -8800,4.0621657,1.6662779,,,,,,,,,,,,,, -8900,2.2540052,1.6753441,,,,,,,,,,,,,, -9000,4.540156,1.6747118,,,,,,,,,,,,,, -9100,3.8015285,1.7016534,,,,,,,,,,,,,, -9200,3.176475,1.7374355,,,,,,,,,,,,,, -9300,4.6500416,1.683031,,,,,,,,,,,,,, -9400,2.845159,1.6812923,,,,,,,,,,,,,, -9500,3.7906134,1.6621112,,,,,,,,,,,,,, -9600,4.7979655,1.6681948,,,,,,,,,,,,,, -9700,3.6013248,1.6980381,,,,,,,,,,,,,, -9800,3.2631228,1.6938215,,,,,,,,,,,,,, -9900,2.224533,1.6844302,,,,,,,,,,,,,, -10000,4.2638297,1.6988908,,,,,,,,,,,,,, -10100,2.679822,1.7072344,,,,,,,,,,,,,, -10200,3.3613842,1.65708,,,,,,,,,,,,,, -10300,2.9026303,1.6788588,,,,,,,,,,,,,, -10391,,,0.43561387,0.1443470560306426,0.7429956,0.2146808654431003,5348.0,0.45595506,0.1479089228769321,2472.0,8685.838481426239,9678.446252822876,8685.838481426239,991.9161324501038,0.239854097366333,0.0 -10400,3.3928342,1.7078441,,,,,,,,,,,,,, -10500,3.74102,1.7292085,,,,,,,,,,,,,, -10600,3.2053485,1.677165,,,,,,,,,,,,,, -10700,2.6231897,1.7037622,,,,,,,,,,,,,, -10800,2.8935196,1.6150247,,,,,,,,,,,,,, -10900,2.8665645,1.6930839,,,,,,,,,,,,,, -11000,2.8037918,1.722218,,,,,,,,,,,,,, -11100,2.8584664,1.6613529,,,,,,,,,,,,,, -11200,3.2998636,1.6368965,,,,,,,,,,,,,, -11300,3.684375,1.6390733,,,,,,,,,,,,,, -11400,2.1991744,1.6058906,,,,,,,,,,,,,, -11500,2.6797268,1.6797432,,,,,,,,,,,,,, -11600,4.9568524,1.7143075,,,,,,,,,,,,,, -11700,2.454987,1.6018586,,,,,,,,,,,,,, -11800,4.627812,1.6111356,,,,,,,,,,,,,, -11900,2.6002045,1.5767281,,,,,,,,,,,,,, -12000,3.6064668,1.611024,,,,,,,,,,,,,, -12100,2.2886732,1.6403111,,,,,,,,,,,,,, -12124,,,0.39994377,0.1340533469465892,0.71666336,0.2062620079747434,5348.0,0.44069472,0.1428310279690451,2472.0,10125.72972393036,11252.399189710615,10125.72972393036,1125.863081932068,0.2782533168792724,0.0 -12200,3.0434697,1.587339,,,,,,,,,,,,,, -12300,3.0273068,1.6103655,,,,,,,,,,,,,, -12400,1.6783377,1.5467458,,,,,,,,,,,,,, -12500,2.2396467,1.6260211,,,,,,,,,,,,,, -12600,3.2343464,1.6614057,,,,,,,,,,,,,, -12700,2.0881593,1.5546253,,,,,,,,,,,,,, -12800,2.7867875,1.5838304,,,,,,,,,,,,,, -12900,3.3464208,1.6482874,,,,,,,,,,,,,, -13000,4.3424587,1.668267,,,,,,,,,,,,,, -13100,3.3395538,1.6328875,,,,,,,,,,,,,, -13200,3.249789,1.6316302,,,,,,,,,,,,,, -13300,2.7837815,1.6338847,,,,,,,,,,,,,, -13400,2.8443582,1.5960108,,,,,,,,,,,,,, -13500,3.5556462,1.5586448,,,,,,,,,,,,,, -13600,3.767112,1.6512512,,,,,,,,,,,,,, -13700,2.776813,1.576949,,,,,,,,,,,,,, -13800,3.1870463,1.551065,,,,,,,,,,,,,, -13863,,,0.34301874,0.1186406302815985,0.683678,0.1980072796084072,5348.0,0.41449788,0.1320862023439563,2472.0,11567.119065999985,12830.857915639876,11567.119065999985,1262.8164167404177,0.3170113563537597,0.0 -13900,3.356286,1.6200022,,,,,,,,,,,,,, -14000,3.7446523,1.6168766,,,,,,,,,,,,,, -14100,3.4834316,1.6116185,,,,,,,,,,,,,, -14200,2.3770218,1.5610515,,,,,,,,,,,,,, -14300,1.8962398,1.5166023,,,,,,,,,,,,,, -14400,2.454485,1.6098466,,,,,,,,,,,,,, -14500,2.09971,1.6050199,,,,,,,,,,,,,, -14600,2.909656,1.5646151,,,,,,,,,,,,,, -14700,2.6909344,1.5690645,,,,,,,,,,,,,, -14800,4.17412,1.5956329,,,,,,,,,,,,,, -14900,2.2499175,1.463492,,,,,,,,,,,,,, -15000,2.999102,1.5675049,,,,,,,,,,,,,, -15100,2.047156,1.5501621,,,,,,,,,,,,,, -15200,2.6940985,1.6158174,,,,,,,,,,,,,, -15300,2.3427699,1.6080428,,,,,,,,,,,,,, -15400,2.5872269,1.6694969,,,,,,,,,,,,,, -15500,2.9253333,1.5057847,,,,,,,,,,,,,, -15594,,,0.31959072,0.108528533960387,0.67307526,0.1947343522210529,5348.0,0.39703676,0.1269270611175431,2472.0,13007.51091980934,14409.774055480955,13007.51091980934,1401.224974155426,0.3565559387207031,0.0 -15600,2.4480212,1.5533508,,,,,,,,,,,,,, -15700,2.817168,1.5075995,,,,,,,,,,,,,, -15800,4.509166,1.4829713,,,,,,,,,,,,,, -15900,2.9899197,1.5507681,,,,,,,,,,,,,, -16000,2.0272572,1.4831529,,,,,,,,,,,,,, -16100,2.7678916,1.5923656,,,,,,,,,,,,,, -16200,3.4699795,1.6144701,,,,,,,,,,,,,, -16300,4.4199023,1.5870776,,,,,,,,,,,,,, -16400,4.3517876,1.5339222,,,,,,,,,,,,,, -16500,2.1518106,1.5295863,,,,,,,,,,,,,, -16600,2.6055639,1.5316223,,,,,,,,,,,,,, -16700,3.455813,1.521837,,,,,,,,,,,,,, -16800,2.0366285,1.5015993,,,,,,,,,,,,,, -16900,3.1007535,1.5981091,,,,,,,,,,,,,, -17000,2.083427,1.504032,,,,,,,,,,,,,, -17100,3.7302308,1.5367081,,,,,,,,,,,,,, -17200,2.4289582,1.5356686,,,,,,,,,,,,,, -17300,4.165705,1.5666935,,,,,,,,,,,,,, -17329,,,0.32896024,0.1162404071497905,0.6574222,0.1909786921806965,5348.0,0.39617884,0.1280238864176467,2472.0,14447.46181845665,15993.21812081337,14447.46181845665,1544.6018645763395,0.3955864906311035,0.0 -17400,3.3722851,1.5402972,,,,,,,,,,,,,, -17500,3.2387235,1.5775318,,,,,,,,,,,,,, -17600,4.0160108,1.4904186,,,,,,,,,,,,,, -17700,3.000713,1.551325,,,,,,,,,,,,,, -17800,3.1084695,1.5676576,,,,,,,,,,,,,, -17900,5.395583,1.5127381,,,,,,,,,,,,,, -18000,3.2203696,1.5345638,,,,,,,,,,,,,, -18100,2.490646,1.5610688,,,,,,,,,,,,,, -18200,4.7960486,1.5106941,,,,,,,,,,,,,, -18300,4.1584506,1.4789149,,,,,,,,,,,,,, -18400,1.8493059,1.5668654,,,,,,,,,,,,,, -18500,4.4045315,1.5202417,,,,,,,,,,,,,, -18600,3.016946,1.520046,,,,,,,,,,,,,, -18700,3.4783938,1.4543767,,,,,,,,,,,,,, -18800,2.5046988,1.4733509,,,,,,,,,,,,,, -18900,2.5557191,1.4932349,,,,,,,,,,,,,, -19000,4.850944,1.4936566,,,,,,,,,,,,,, -19065,,,0.32634243,0.1093841719770485,0.6247462,0.1807544145901117,5348.0,0.3724943,0.1213616882984989,2472.0,15887.366396427156,17567.369089365005,15887.366396427156,1678.731039762497,0.4372310638427734,0.0 -19100,3.3866398,1.525444,,,,,,,,,,,,,, -19200,3.0151734,1.5373561,,,,,,,,,,,,,, -19300,2.6043038,1.5412257,,,,,,,,,,,,,, -19400,1.9343466,1.5011717,,,,,,,,,,,,,, -19500,3.3572316,1.4948598,,,,,,,,,,,,,, -19600,2.9761326,1.4904805,,,,,,,,,,,,,, -19700,3.3888628,1.4913148,,,,,,,,,,,,,, -19800,2.146643,1.4905722,,,,,,,,,,,,,, -19900,2.4987009,1.5068676,,,,,,,,,,,,,, -20000,3.931678,1.4820955,,,,,,,,,,,,,, -20100,3.0600765,1.5328562,,,,,,,,,,,,,, -20200,1.8773036,1.45835,,,,,,,,,,,,,, -20300,2.6371732,1.5153986,,,,,,,,,,,,,, -20400,2.6605568,1.529596,,,,,,,,,,,,,, -20500,3.0022116,1.5085684,,,,,,,,,,,,,, -20600,2.1589832,1.4581497,,,,,,,,,,,,,, -20700,2.8310611,1.5445117,,,,,,,,,,,,,, -20800,2.5914047,1.4632986,,,,,,,,,,,,,, -20805,,,0.31835115,0.1072308544595984,0.6072395,0.1752319530397675,5348.0,0.35500604,0.1163853512887697,2472.0,17327.25873041153,19143.011053800583,17327.25873041153,1814.3625190258024,0.4785201549530029,0.0 -20900,1.9798741,1.4959025,,,,,,,,,,,,,, -21000,3.0615115,1.4814564,,,,,,,,,,,,,, -21100,2.5997448,1.5149312,,,,,,,,,,,,,, -21200,2.8481498,1.4696213,,,,,,,,,,,,,, -21300,2.7388365,1.4543864,,,,,,,,,,,,,, -21400,2.2783241,1.4707798,,,,,,,,,,,,,, -21500,4.116015,1.5024179,,,,,,,,,,,,,, -21600,2.8501174,1.4634813,,,,,,,,,,,,,, -21700,2.8999784,1.4472675,,,,,,,,,,,,,, -21800,2.584408,1.4181728,,,,,,,,,,,,,, -21900,2.3703074,1.4468038,,,,,,,,,,,,,, -22000,2.438949,1.4485054,,,,,,,,,,,,,, -22100,2.4233475,1.4544568,,,,,,,,,,,,,, -22200,2.377291,1.4741534,,,,,,,,,,,,,, -22300,3.2739952,1.4207674,,,,,,,,,,,,,, -22400,4.932386,1.4364787,,,,,,,,,,,,,, -22500,1.8989036,1.4437104,,,,,,,,,,,,,, -22510,,,0.29313225,0.0993517851630826,0.5852561,0.1685316238160981,5348.0,0.34486303,0.1103731237178315,2472.0,18767.39107894897,20718.74327898025,18767.39107894897,1949.846531391144,0.5190324783325195,0.0 -22600,2.6663775,1.4293433,,,,,,,,,,,,,, -22700,2.796776,1.3946108,,,,,,,,,,,,,, -22800,2.340767,1.4560599,,,,,,,,,,,,,, -22900,2.6178608,1.4425273,,,,,,,,,,,,,, -23000,2.2824085,1.4243901,,,,,,,,,,,,,, -23100,2.9237957,1.5125483,,,,,,,,,,,,,, -23200,2.931218,1.4210213,,,,,,,,,,,,,, -23300,2.50591,1.45215,,,,,,,,,,,,,, -23400,3.1478088,1.448272,,,,,,,,,,,,,, -23500,2.342274,1.3644447,,,,,,,,,,,,,, -23600,3.9857192,1.4481276,,,,,,,,,,,,,, -23700,3.441248,1.4435256,,,,,,,,,,,,,, -23800,2.678869,1.4002593,,,,,,,,,,,,,, -23900,2.5533419,1.414851,,,,,,,,,,,,,, -24000,2.2779503,1.4536966,,,,,,,,,,,,,, -24100,2.9530747,1.4561458,,,,,,,,,,,,,, -24200,2.6462805,1.3638576,,,,,,,,,,,,,, -24245,,,0.27176955,0.0930756724800472,0.57057077,0.164611834673721,5348.0,0.33417776,0.1059045761988909,2472.0,20207.345369815823,22293.72416448593,20207.345369815823,2084.754161596298,0.5621540546417236,0.0 -24300,3.298874,1.4741187,,,,,,,,,,,,,, -24400,2.2325995,1.4275393,,,,,,,,,,,,,, -24500,2.4199495,1.5006964,,,,,,,,,,,,,, -24600,2.977685,1.4117272,,,,,,,,,,,,,, -24700,1.934321,1.3775134,,,,,,,,,,,,,, -24800,3.513759,1.3802975,,,,,,,,,,,,,, -24900,2.0539205,1.3817971,,,,,,,,,,,,,, -25000,3.1064672,1.4115872,,,,,,,,,,,,,, -25100,2.3091624,1.3573048,,,,,,,,,,,,,, -25200,1.9236559,1.4122704,,,,,,,,,,,,,, -25300,2.6786783,1.4242384,,,,,,,,,,,,,, -25400,1.9606974,1.4076523,,,,,,,,,,,,,, -25500,2.274435,1.416517,,,,,,,,,,,,,, -25600,3.194168,1.4288863,,,,,,,,,,,,,, -25700,3.372857,1.3894101,,,,,,,,,,,,,, -25800,2.5623198,1.4160967,,,,,,,,,,,,,, -25900,3.0505373,1.4421637,,,,,,,,,,,,,, -25965,,,0.2505488,0.0890581392315918,0.5587426,0.1626229761433522,5348.0,0.31846216,0.1046249466821034,2472.0,21647.66396617889,23867.22459101677,21647.66396617889,2217.822919368744,0.6007964611053467,0.0 -26000,2.7741418,1.4211246,,,,,,,,,,,,,, -26100,3.7328007,1.464163,,,,,,,,,,,,,, -26200,2.7773454,1.3668406,,,,,,,,,,,,,, -26300,3.6046748,1.4416758,,,,,,,,,,,,,, -26400,3.4904246,1.3970839,,,,,,,,,,,,,, -26500,2.693756,1.4055924,,,,,,,,,,,,,, -26600,2.7541332,1.4088197,,,,,,,,,,,,,, -26700,3.2093725,1.3912085,,,,,,,,,,,,,, -26800,2.8518643,1.4265014,,,,,,,,,,,,,, -26900,2.3673983,1.4050558,,,,,,,,,,,,,, -27000,3.591821,1.4442421,,,,,,,,,,,,,, -27100,2.466089,1.4324738,,,,,,,,,,,,,, -27200,2.3637054,1.332602,,,,,,,,,,,,,, -27300,3.9746976,1.4259362,,,,,,,,,,,,,, -27400,2.6590922,1.4062718,,,,,,,,,,,,,, -27500,2.1749516,1.3608582,,,,,,,,,,,,,, -27600,3.3297234,1.399826,,,,,,,,,,,,,, -27671,,,0.2422054,0.0829080197604606,0.5420699,0.1576701391235506,5348.0,0.31094378,0.1007048118132147,2472.0,23087.62013030052,25441.73743915558,23087.62013030052,2352.263339281082,0.6415128707885742,0.0 -27700,2.6428688,1.3796078,,,,,,,,,,,,,, -27800,2.1601503,1.3441703,,,,,,,,,,,,,, -27900,3.923819,1.3533974,,,,,,,,,,,,,, -28000,3.3620503,1.3759923,,,,,,,,,,,,,, -28100,2.4533172,1.3341255,,,,,,,,,,,,,, -28200,3.343889,1.4218106,,,,,,,,,,,,,, -28300,2.1623452,1.3509059,,,,,,,,,,,,,, -28400,2.680612,1.4176755,,,,,,,,,,,,,, -28500,2.8557124,1.4059393,,,,,,,,,,,,,, -28600,2.343812,1.3666201,,,,,,,,,,,,,, -28700,3.352298,1.3813787,,,,,,,,,,,,,, -28800,3.7501216,1.3902661,,,,,,,,,,,,,, -28900,2.4256673,1.3614725,,,,,,,,,,,,,, -29000,3.6112907,1.3923995,,,,,,,,,,,,,, -29100,3.097819,1.3056467,,,,,,,,,,,,,, -29200,3.0545917,1.3303158,,,,,,,,,,,,,, -29300,2.79122,1.3499415,,,,,,,,,,,,,, -29371,,,0.25566694,0.0881056075576356,0.52995616,0.1546771966749374,5348.0,0.30119932,0.0959519021794324,2472.0,24528.12004709244,27020.380325078964,24528.12004709244,2490.2884378433228,0.685178279876709,0.0 -29400,2.5743725,1.3723707,,,,,,,,,,,,,, -29500,2.4502578,1.3681598,,,,,,,,,,,,,, -29600,2.330742,1.3280388,,,,,,,,,,,,,, -29700,2.5098462,1.3044442,,,,,,,,,,,,,, -29800,3.8867037,1.3463246,,,,,,,,,,,,,, -29900,2.8142502,1.3691523,,,,,,,,,,,,,, -30000,2.2532518,1.3129417,,,,,,,,,,,,,, -30100,2.8390882,1.3172894,,,,,,,,,,,,,, -30200,2.4108057,1.3324722,,,,,,,,,,,,,, -30300,2.9218652,1.3306943,,,,,,,,,,,,,, -30400,2.7147,1.3607571,,,,,,,,,,,,,, -30500,2.541322,1.3107916,,,,,,,,,,,,,, -30600,2.6319664,1.3306614,,,,,,,,,,,,,, -30700,2.5277565,1.3471785,,,,,,,,,,,,,, -30800,2.442511,1.2804868,,,,,,,,,,,,,, -30900,2.335741,1.2674917,,,,,,,,,,,,,, -31000,2.7574518,1.334347,,,,,,,,,,,,,, -31100,2.2037482,1.2879992,,,,,,,,,,,,,, -31104,,,0.23302117,0.0778354738766895,0.51413625,0.1478996302267878,5348.0,0.29052296,0.0922755062661223,2472.0,25968.7009601593,28593.22857093811,25968.7009601593,2622.43061876297,0.7326092720031738,0.0 -31200,2.5360198,1.2957902,,,,,,,,,,,,,, -31300,2.770438,1.3732498,,,,,,,,,,,,,, -31400,2.3713934,1.2955005,,,,,,,,,,,,,, -31500,1.9913324,1.3228464,,,,,,,,,,,,,, -31600,4.0089846,1.3536623,,,,,,,,,,,,,, -31700,2.3535159,1.3300353,,,,,,,,,,,,,, -31800,3.0582156,1.3726403,,,,,,,,,,,,,, -31900,2.1200864,1.3070556,,,,,,,,,,,,,, -32000,2.838867,1.3375986,,,,,,,,,,,,,, -32100,2.6018796,1.3096122,,,,,,,,,,,,,, -32200,2.714751,1.275454,,,,,,,,,,,,,, -32300,2.760114,1.2719471,,,,,,,,,,,,,, -32400,3.2757738,1.2815362,,,,,,,,,,,,,, -32500,2.8597655,1.2566406,,,,,,,,,,,,,, -32600,2.7721298,1.2746118,,,,,,,,,,,,,, -32700,3.7028534,1.3171161,,,,,,,,,,,,,, -32800,3.4423003,1.3466848,,,,,,,,,,,,,, -32821,,,0.2403235,0.0805778012934776,0.50835735,0.1464900508800216,5348.0,0.28367037,0.0900615440862835,2472.0,27408.875204086304,30170.055107831955,27408.875204086304,2758.963734149933,0.7755627632141113,0.0 -32900,3.5951462,1.3460969,,,,,,,,,,,,,, -33000,2.3925135,1.2577784,,,,,,,,,,,,,, -33100,3.4521646,1.2685013,,,,,,,,,,,,,, -33200,1.8254269,1.2711492,,,,,,,,,,,,,, -33300,1.8367789,1.2470931,,,,,,,,,,,,,, -33400,3.1882758,1.3350348,,,,,,,,,,,,,, -33500,3.16906,1.379792,,,,,,,,,,,,,, -33600,3.9901338,1.268804,,,,,,,,,,,,,, -33700,3.983926,1.3296295,,,,,,,,,,,,,, -33800,2.7188137,1.2821845,,,,,,,,,,,,,, -33900,3.024435,1.2797796,,,,,,,,,,,,,, -34000,6.0855713,1.2563846,,,,,,,,,,,,,, -34100,3.2856221,1.2832072,,,,,,,,,,,,,, -34200,1.925548,1.2568822,,,,,,,,,,,,,, -34300,2.0272124,1.3071307,,,,,,,,,,,,,, -34400,6.5896807,1.3098743,,,,,,,,,,,,,, -34500,2.1006327,1.2441913,,,,,,,,,,,,,, -34549,,,0.22173432,0.0732813282607073,0.48096427,0.13927802504417,5348.0,0.26772746,0.0863242134340787,2472.0,28848.791372060776,31744.34097886085,28848.791372060776,2893.2134778499603,0.820319414138794,0.0 -34600,2.6695192,1.2409929,,,,,,,,,,,,,, -34700,2.102579,1.2399482,,,,,,,,,,,,,, -34800,2.7520704,1.2620498,,,,,,,,,,,,,, -34900,2.8897007,1.2709544,,,,,,,,,,,,,, -35000,2.6336985,1.344961,,,,,,,,,,,,,, -35100,2.7945027,1.2552366,,,,,,,,,,,,,, -35200,3.4576738,1.2801378,,,,,,,,,,,,,, -35300,3.0344439,1.2404473,,,,,,,,,,,,,, -35400,2.6334584,1.2476748,,,,,,,,,,,,,, -35500,2.6059062,1.236118,,,,,,,,,,,,,, -35600,2.6435993,1.2339915,,,,,,,,,,,,,, -35700,2.5542247,1.2972373,,,,,,,,,,,,,, -35800,2.9109852,1.2593325,,,,,,,,,,,,,, -35900,3.4160118,1.2899487,,,,,,,,,,,,,, -36000,3.280583,1.2705353,,,,,,,,,,,,,, -36100,4.423746,1.1953074,,,,,,,,,,,,,, -36200,2.6439505,1.3013911,,,,,,,,,,,,,, -36283,,,0.16659415,0.0579881717454362,0.46956345,0.1348368846365505,5348.0,0.2593685,0.0838665122986614,2472.0,30288.866734981537,33320.562363147736,30288.866734981537,3029.243176460266,0.8601865768432617,0.0 -36300,2.5848088,1.2470998,,,,,,,,,,,,,, -36400,3.2501013,1.2565297,,,,,,,,,,,,,, -36500,3.5774484,1.2708672,,,,,,,,,,,,,, -36600,5.1392717,1.2696502,,,,,,,,,,,,,, -36700,2.4002683,1.214413,,,,,,,,,,,,,, -36800,3.5355613,1.2068527,,,,,,,,,,,,,, -36900,2.8025336,1.2525303,,,,,,,,,,,,,, -37000,2.6337774,1.2542927,,,,,,,,,,,,,, -37100,3.0206563,1.1743705,,,,,,,,,,,,,, -37200,2.2756178,1.2205284,,,,,,,,,,,,,, -37300,3.696907,1.2525947,,,,,,,,,,,,,, -37400,3.4663513,1.2307186,,,,,,,,,,,,,, -37500,3.2925582,1.2320942,,,,,,,,,,,,,, -37600,3.0758793,1.2307698,,,,,,,,,,,,,, -37700,4.7007647,1.2314947,,,,,,,,,,,,,, -37800,2.83132,1.212129,,,,,,,,,,,,,, -37900,2.3505628,1.2478162,,,,,,,,,,,,,, -37984,,,0.1848998,0.0630979282380519,0.4517904,0.1306564198615522,5348.0,0.24856593,0.0794792110982471,2472.0,31729.2007689476,34900.29788470268,31729.2007689476,3168.526449203491,0.9034137725830078,0.0 -38000,2.4340703,1.2153543,,,,,,,,,,,,,, -38100,3.0229995,1.1985983,,,,,,,,,,,,,, -38200,3.0085216,1.1958055,,,,,,,,,,,,,, -38300,2.2934682,1.2012639,,,,,,,,,,,,,, -38400,3.5496504,1.1514125,,,,,,,,,,,,,, -38500,2.9505813,1.1621324,,,,,,,,,,,,,, -38600,3.004905,1.2061982,,,,,,,,,,,,,, -38700,3.9139276,1.2027833,,,,,,,,,,,,,, -38800,2.4150596,1.2461687,,,,,,,,,,,,,, -38900,3.773776,1.2069247,,,,,,,,,,,,,, -39000,3.8856127,1.2196939,,,,,,,,,,,,,, -39100,2.7515442,1.2430546,,,,,,,,,,,,,, -39200,8.125044,1.2060655,,,,,,,,,,,,,, -39300,4.946021,1.1616833,,,,,,,,,,,,,, -39400,3.6632938,1.1873014,,,,,,,,,,,,,, -39500,2.4359355,1.139596,,,,,,,,,,,,,, -39600,2.0897539,1.222404,,,,,,,,,,,,,, -39697,,,0.23087613,0.0794465266891946,0.44235966,0.1274221110864381,5348.0,0.2419337,0.0779761542055125,2472.0,33169.299060583115,36473.96536016464,33169.299060583115,3301.976773738861,0.9474256038665771,0.0 -39700,2.897393,1.2224524,,,,,,,,,,,,,, -39800,5.041082,1.1990694,,,,,,,,,,,,,, -39900,3.5586975,1.193958,,,,,,,,,,,,,, -40000,2.526235,1.1976626,,,,,,,,,,,,,, -40100,3.0177352,1.2020949,,,,,,,,,,,,,, -40200,2.3218067,1.1877761,,,,,,,,,,,,,, -40300,2.8042188,1.1287814,,,,,,,,,,,,,, -40400,4.3236556,1.2132485,,,,,,,,,,,,,, -40500,3.7032008,1.1324279,,,,,,,,,,,,,, -40600,3.8644931,1.154324,,,,,,,,,,,,,, -40700,3.2654026,1.1610534,,,,,,,,,,,,,, -40800,2.554062,1.1614232,,,,,,,,,,,,,, -40900,2.9614396,1.1865816,,,,,,,,,,,,,, -41000,3.418402,1.2015053,,,,,,,,,,,,,, -41100,3.2245133,1.1861706,,,,,,,,,,,,,, -41200,2.109629,1.1167799,,,,,,,,,,,,,, -41300,3.2641187,1.196206,,,,,,,,,,,,,, -41400,3.0412376,1.2578672,,,,,,,,,,,,,, -41414,,,0.2321512,0.0798170281507963,0.4297312,0.1231837183930795,5348.0,0.23373564,0.075620010968253,2472.0,34609.19066643715,38048.15858435631,34609.19066643715,3436.159049510956,0.9893152713775636,0.0 -41500,2.2958927,1.2034014,,,,,,,,,,,,,, -41600,2.261411,1.1598289,,,,,,,,,,,,,, -41700,3.2263734,1.1285043,,,,,,,,,,,,,, -41800,2.8748908,1.2124578,,,,,,,,,,,,,, -41900,2.4221652,1.1373851,,,,,,,,,,,,,, -42000,2.3089523,1.1402026,,,,,,,,,,,,,, -42100,3.15977,1.1750817,,,,,,,,,,,,,, -42200,4.058865,1.1676943,,,,,,,,,,,,,, -42300,2.8129413,1.1436245,,,,,,,,,,,,,, -42400,3.1653445,1.1927518,,,,,,,,,,,,,, -42500,3.245519,1.1474745,,,,,,,,,,,,,, -42600,1.7622523,1.1586378,,,,,,,,,,,,,, -42700,4.568808,1.1824749,,,,,,,,,,,,,, -42800,2.2501338,1.1574192,,,,,,,,,,,,,, -42900,3.219719,1.118246,,,,,,,,,,,,,, -43000,4.188193,1.1121429,,,,,,,,,,,,,, -43100,4.9621477,1.1957631,,,,,,,,,,,,,, -43118,,,0.26301807,0.0910260834976104,0.42406932,0.1224016914952161,5348.0,0.22829898,0.073101375093941,2472.0,36049.38084578514,39621.84571695328,36049.38084578514,3569.5378217697144,1.0312442779541016,0.0 -43200,5.130556,1.1316788,,,,,,,,,,,,,, -43300,4.4968596,1.1526048,,,,,,,,,,,,,, -43400,3.111443,1.1440542,,,,,,,,,,,,,, -43500,6.1589766,1.2021731,,,,,,,,,,,,,, -43600,9.22407,1.1994746,,,,,,,,,,,,,, -43700,3.504544,1.1076084,,,,,,,,,,,,,, -43800,2.7202759,1.1205769,,,,,,,,,,,,,, -43900,3.207946,1.1173007,,,,,,,,,,,,,, -44000,3.3367016,1.0941443,,,,,,,,,,,,,, -44100,3.6641998,1.1568327,,,,,,,,,,,,,, -44200,3.9537554,1.1634455,,,,,,,,,,,,,, -44300,5.8085527,1.131967,,,,,,,,,,,,,, -44400,2.5748367,1.1326025,,,,,,,,,,,,,, -44500,2.3931293,1.0738696,,,,,,,,,,,,,, -44600,3.334624,1.1189551,,,,,,,,,,,,,, -44700,2.3993907,1.1624117,,,,,,,,,,,,,, -44800,3.6202695,1.1648202,,,,,,,,,,,,,, -44835,,,0.23237117,0.0770743425042243,0.42018595,0.1203452503934271,5348.0,0.22612567,0.0724514045457315,2472.0,37489.69993138313,41194.15355873108,37489.69993138313,3701.4083173274994,1.073310613632202,0.0 -44900,3.228583,1.155534,,,,,,,,,,,,,, -45000,2.777526,1.1681049,,,,,,,,,,,,,, -45100,3.0768182,1.1514113,,,,,,,,,,,,,, -45200,3.7147665,1.146391,,,,,,,,,,,,,, -45300,2.4335675,1.1295159,,,,,,,,,,,,,, -45400,2.997644,1.170912,,,,,,,,,,,,,, -45500,3.787315,1.1663052,,,,,,,,,,,,,, -45600,4.415588,1.1503332,,,,,,,,,,,,,, -45700,2.050951,1.0464627,,,,,,,,,,,,,, -45800,3.5125923,1.0591418,,,,,,,,,,,,,, -45900,4.4526877,1.1613237,,,,,,,,,,,,,, -46000,2.8363862,1.1509557,,,,,,,,,,,,,, -46100,2.5655813,1.120207,,,,,,,,,,,,,, -46200,3.6499107,1.106671,,,,,,,,,,,,,, -46300,2.5941997,1.1597105,,,,,,,,,,,,,, -46400,2.443227,1.1792766,,,,,,,,,,,,,, -46500,4.3684773,1.1525736,,,,,,,,,,,,,, -46551,,,0.21373685,0.0742041797957979,0.41774118,0.1196887339853442,5348.0,0.22410408,0.0715983182012065,2472.0,38929.57273769379,42767.10988521576,38929.57273769379,3834.3490607738495,1.1359052658081057,0.0 -46600,2.9772,1.1293186,,,,,,,,,,,,,, -46700,2.7484894,1.2045398,,,,,,,,,,,,,, -46800,3.3085196,1.1463293,,,,,,,,,,,,,, -46900,4.8487606,1.1683153,,,,,,,,,,,,,, -47000,3.132278,1.1115932,,,,,,,,,,,,,, -47100,3.3363118,1.1528901,,,,,,,,,,,,,, -47200,10.579948,1.1361341,,,,,,,,,,,,,, -47300,2.6855063,1.1498421,,,,,,,,,,,,,, -47400,2.5297744,1.1221875,,,,,,,,,,,,,, -47500,3.4582846,1.0979085,,,,,,,,,,,,,, -47600,5.2559605,1.1370599,,,,,,,,,,,,,, -47700,3.0348928,1.1438888,,,,,,,,,,,,,, -47800,3.9331484,1.0789859,,,,,,,,,,,,,, -47900,4.222039,1.1145976,,,,,,,,,,,,,, -48000,,,0.19231977,0.06766794636157,0.41810927,0.1197563165567645,5348.0,0.22410974,0.0713748908252594,2472.0,40149.33516192436,44121.655792713165,40149.33516192436,3969.007215976715,1.1929059028625488,0.0 -48000,,,,,,,,,,,40149.33516192436,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 16367d4e2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -159.01518487930298,0.0,15.45589280128479,1,0,15.45589280128479,28.256151,2472,2.344362521073264,174.47113847732544,29.524904,2.289365492296078,28.14923,5348,2.1855334678548326 -273.5006091594696,0.0295906066894531,1456.418749332428,1699,0,1456.418749332428,6.3231583,2472,0.8925923669083745,1730.0256130695343,6.5281034,0.9363028259749184,6.363745,5348,0.8913754984214642 -390.15674901008606,0.0819911956787109,2896.712926864624,3407,0,2896.712926864624,3.7111945,2472,0.786301870696484,3287.10879445076,4.3004704,0.8745647056270951,4.174814,5348,0.8286105988781293 -524.7541451454163,0.1371448040008545,4337.139240503311,5093,0,4337.139240503311,0.6458774,2472,0.2056141206101598,4862.266368150711,0.91354764,0.2718951000690131,0.99934816,5348,0.279000164129102 -657.7007877826691,0.1902203559875488,5777.388965606689,6795,0,5777.388965606689,0.49709377,2472,0.1617817317652793,6435.596847295761,0.68832296,0.2178153222760207,0.79865134,5348,0.2307944814003109 -791.83607172966,0.2452392578125,7217.569427251816,8514,0,7217.569427251816,0.44906363,2472,0.1469136554749862,8010.0493178367615,0.5999228,0.1935764928827366,0.72859186,5348,0.2121706556474893 -928.1383287906648,0.2975411415100097,8658.101991176605,10205,0,8658.101991176605,0.40845668,2472,0.1322080718217455,9587.016305446625,0.53137237,0.1741598852198085,0.68165743,5348,0.1969935410371028 -1061.1018662452698,0.3483004570007324,10098.519153118134,11908,0,10098.519153118134,0.3866501,2472,0.1245302947210204,11160.529643058777,0.50215024,0.166799780931952,0.64968437,5348,0.1875030170790812 -1194.6305973529816,0.4078514575958252,11538.94403219223,13621,0,11538.94403219223,0.3623559,2472,0.1185790018889769,12734.627663373947,0.44562355,0.1491876088023925,0.6191698,5348,0.1785145350801819 -1328.6569051742554,0.4639625549316406,12979.073741436005,15293,0,12979.073741436005,0.34927937,2472,0.1144963743830357,14308.919946432114,0.44809407,0.1485734782534432,0.6029571,5348,0.1749519681010262 -1464.3784432411194,0.5136878490447998,14419.393539190292,17003,0,14419.393539190292,0.33758086,2472,0.1090731826214124,15885.092690944672,0.42328086,0.1401969991726482,0.58921814,5348,0.1701149869179451 -1600.6581590175629,0.5700757503509521,15859.808149814606,18687,0,15859.808149814606,0.32256627,2472,0.1058436414599963,17461.924512147903,0.40703046,0.1380073719048515,0.56702036,5348,0.1645442521023007 -1733.951742887497,0.6257908344268799,17300.169142961502,20376,0,17300.169142961502,0.31212524,2472,0.1036499908597891,19035.71454906464,0.35868832,0.1200126644541563,0.5564612,5348,0.160730664143584 -1869.3748636245728,0.6759476661682129,18740.1226541996,22081,0,18740.1226541996,0.30617353,2472,0.1011313549854772,20611.221473693848,0.32458857,0.1098593623444518,0.5423642,5348,0.1574770460623497 -2003.326298236847,0.7289724349975586,20180.643664598465,23760,0,20180.643664598465,0.29169166,2472,0.0963987569313265,22185.8266851902,0.36119395,0.1249985887984465,0.5224131,5348,0.1516263263079641 -2136.82527923584,0.7829153537750244,21621.02370738983,25457,0,21621.02370738983,0.28612068,2472,0.092600491540227,23759.84077358246,0.3251517,0.10952787258248,0.51066995,5348,0.1475134441043861 -2272.730340003968,0.8456952571868896,23060.985421419144,27173,0,23060.985421419144,0.28231353,2472,0.0910364999085979,25335.851276874542,0.3157251,0.1056232280028406,0.50370485,5348,0.1455438948801374 -2406.2770829200745,0.8960163593292236,24501.122561454773,28848,0,24501.122561454773,0.27055633,2472,0.0877460240082871,26909.666396856308,0.28317812,0.0987635239567233,0.4912923,5348,0.1415372138602199 -2543.348155975342,0.951606273651123,25941.73451423645,30561,0,25941.73451423645,0.26829788,2472,0.0875022850527085,28487.4870493412,0.2763036,0.0966358705988807,0.48683694,5348,0.1402627996562943 -2677.9681293964386,1.001542568206787,27382.11405301094,32288,0,27382.11405301094,0.26102412,2472,0.084333678630187,30062.6204662323,0.25569287,0.0897257371193316,0.46944138,5348,0.1357830406364347 -2832.421259641648,1.0486581325531006,28822.875962257385,33983,0,28822.875962257385,0.2518262,2472,0.0820790932910852,31657.9614071846,0.15601371,0.0559534943036898,0.46224105,5348,0.1343155333713083 -2968.835390806198,1.0996172428131104,30262.88162112236,35663,0,30262.88162112236,0.24787179,2472,0.0796213921556679,33234.50987505913,0.15231837,0.0532382995745299,0.45944563,5348,0.1315349932900161 -3104.18021607399,1.1535747051239014,31703.97305607796,37356,0,31703.97305607796,0.23940639,2472,0.0773464952369345,34811.08276438713,0.1572897,0.055789000766405,0.44850588,5348,0.1289282369638047 -3242.2381801605225,1.2139358520507812,33144.23564553261,39037,0,33144.23564553261,0.2376608,2472,0.076290293096094,36389.54399847984,0.13856104,0.049722430920778,0.44071487,5348,0.1269683423926161 -3379.543396711349,1.2748100757598877,34584.34844779968,40732,0,34584.34844779968,0.23402792,2472,0.0753356488534113,37967.1045422554,0.1460249,0.0518470339231544,0.4341819,5348,0.1248636280255269 -3513.8202295303345,1.3349525928497314,36024.96976852417,42425,0,36024.96976852417,0.23212434,2472,0.0747872362033595,39542.14425396919,0.13446575,0.0488422537318217,0.43170702,5348,0.1238691987603425 -3649.597271680832,1.391124963760376,37465.89845681191,44094,0,37465.89845681191,0.23024347,2472,0.0743810046107285,41118.98565912247,0.15975782,0.0529069018799897,0.4287825,5348,0.1227975322706778 -3784.859577178955,1.4510438442230225,38906.24567842484,45785,0,38906.24567842484,0.22979745,2472,0.07413726565515,42694.73621749878,0.14198789,0.050932994062765,0.42813993,5348,0.1223051449646157 -3920.495556592941,1.513451337814331,40346.5943479538,47459,0,40346.5943479538,0.22964966,2472,0.0739138382792029,44270.86405515671,0.1208012,0.0438509485377796,0.4275235,5348,0.1221120519034148 -4057.6739218235016,1.570603847503662,40782.78246617317,48000,0,40782.78246617317,0.22965063,2472,0.07399508459772916,44844.3178961277,0.12780735,0.0460682191722342,0.42751795,5348,0.12202516002587448 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/measurements.csv deleted file mode 100644 index 8f3218165..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,22.45417,33.534977,,,,,,,,,,,,,, -1,,,29.524904,2.289365492296078,28.14923,2.1855334678548326,5348.0,28.256151,2.344362521073264,2472.0,15.45589280128479,174.47113847732544,15.45589280128479,159.01518487930298,0.0,0.0 -100,2.7577605,7.383607,,,,,,,,,,,,,, -200,0.794999,5.9575896,,,,,,,,,,,,,, -300,0.8565822,5.866565,,,,,,,,,,,,,, -400,1.0740551,5.791782,,,,,,,,,,,,,, -500,1.6073877,5.743721,,,,,,,,,,,,,, -600,0.58529663,5.606693,,,,,,,,,,,,,, -700,0.55366355,5.474323,,,,,,,,,,,,,, -800,1.5316026,5.2313952,,,,,,,,,,,,,, -900,1.6449729,4.77405,,,,,,,,,,,,,, -1000,1.2343005,4.3318624,,,,,,,,,,,,,, -1100,1.5018227,3.9412153,,,,,,,,,,,,,, -1200,2.2914524,3.6726847,,,,,,,,,,,,,, -1300,2.1964688,3.4880342,,,,,,,,,,,,,, -1400,1.7445439,3.3404646,,,,,,,,,,,,,, -1500,2.1173775,3.1492357,,,,,,,,,,,,,, -1600,2.0232575,3.098775,,,,,,,,,,,,,, -1699,,,6.5281034,0.9363028259749184,6.363745,0.8913754984214642,5348.0,6.3231583,0.8925923669083745,2472.0,1456.418749332428,1730.0256130695343,1456.418749332428,273.5006091594696,0.0295906066894531,0.0 -1700,2.5522985,3.0240345,,,,,,,,,,,,,, -1800,2.319901,2.8394785,,,,,,,,,,,,,, -1900,2.47641,2.7488756,,,,,,,,,,,,,, -2000,2.6001263,2.7049668,,,,,,,,,,,,,, -2100,3.2794647,2.6523237,,,,,,,,,,,,,, -2200,2.5165567,2.5089145,,,,,,,,,,,,,, -2300,2.5007968,2.4934216,,,,,,,,,,,,,, -2400,2.579081,2.420997,,,,,,,,,,,,,, -2500,1.9827996,2.43219,,,,,,,,,,,,,, -2600,2.5036695,2.4037826,,,,,,,,,,,,,, -2700,2.3979466,2.3775764,,,,,,,,,,,,,, -2800,3.2230926,2.3018415,,,,,,,,,,,,,, -2900,4.276971,2.2834964,,,,,,,,,,,,,, -3000,2.9150684,2.2473373,,,,,,,,,,,,,, -3100,2.751388,2.1778405,,,,,,,,,,,,,, -3200,2.3442483,2.1317375,,,,,,,,,,,,,, -3300,2.9760478,2.2008345,,,,,,,,,,,,,, -3400,3.987654,2.1628125,,,,,,,,,,,,,, -3407,,,4.3004704,0.8745647056270951,4.174814,0.8286105988781293,5348.0,3.7111945,0.786301870696484,2472.0,2896.712926864624,3287.10879445076,2896.712926864624,390.15674901008606,0.0819911956787109,0.0 -3500,2.4645875,2.1274061,,,,,,,,,,,,,, -3600,2.174849,2.043341,,,,,,,,,,,,,, -3700,4.302718,2.004914,,,,,,,,,,,,,, -3800,3.3246224,1.9901123,,,,,,,,,,,,,, -3900,3.6569505,2.0074296,,,,,,,,,,,,,, -4000,3.3690352,1.9686253,,,,,,,,,,,,,, -4100,3.16188,1.9609493,,,,,,,,,,,,,, -4200,3.3671074,1.9355191,,,,,,,,,,,,,, -4300,4.29941,1.9714103,,,,,,,,,,,,,, -4400,3.0667648,1.9241294,,,,,,,,,,,,,, -4500,4.8902864,1.946899,,,,,,,,,,,,,, -4600,2.7454607,1.951438,,,,,,,,,,,,,, -4700,3.3571286,1.9656603,,,,,,,,,,,,,, -4800,2.8243055,1.8186811,,,,,,,,,,,,,, -4900,2.7409189,1.8269409,,,,,,,,,,,,,, -5000,5.9441547,1.8190298,,,,,,,,,,,,,, -5093,,,0.91354764,0.2718951000690131,0.99934816,0.279000164129102,5348.0,0.6458774,0.2056141206101598,2472.0,4337.139240503311,4862.266368150711,4337.139240503311,524.7541451454163,0.1371448040008545,0.0 -5100,4.238278,1.8587016,,,,,,,,,,,,,, -5200,4.444533,1.8484938,,,,,,,,,,,,,, -5300,2.799898,1.8108144,,,,,,,,,,,,,, -5400,3.5972757,1.8022892,,,,,,,,,,,,,, -5500,3.31158,1.8498075,,,,,,,,,,,,,, -5600,2.7603693,1.78074,,,,,,,,,,,,,, -5700,3.226093,1.7535646,,,,,,,,,,,,,, -5800,2.7562091,1.7158266,,,,,,,,,,,,,, -5900,3.6725166,1.8499258,,,,,,,,,,,,,, -6000,4.6752844,1.7853042,,,,,,,,,,,,,, -6100,2.3951423,1.7717408,,,,,,,,,,,,,, -6200,2.8123925,1.7649422,,,,,,,,,,,,,, -6300,3.042001,1.654778,,,,,,,,,,,,,, -6400,2.7915027,1.7644885,,,,,,,,,,,,,, -6500,2.848739,1.712228,,,,,,,,,,,,,, -6600,3.5913289,1.7146999,,,,,,,,,,,,,, -6700,3.233406,1.7278234,,,,,,,,,,,,,, -6795,,,0.68832296,0.2178153222760207,0.79865134,0.2307944814003109,5348.0,0.49709377,0.1617817317652793,2472.0,5777.388965606689,6435.596847295761,5777.388965606689,657.7007877826691,0.1902203559875488,0.0 -6800,2.9013047,1.7020476,,,,,,,,,,,,,, -6900,3.0410728,1.74812,,,,,,,,,,,,,, -7000,2.9666445,1.6608808,,,,,,,,,,,,,, -7100,3.5791287,1.6923163,,,,,,,,,,,,,, -7200,3.008774,1.7076343,,,,,,,,,,,,,, -7300,2.6824608,1.6473278,,,,,,,,,,,,,, -7400,2.4565501,1.6424577,,,,,,,,,,,,,, -7500,3.2798972,1.6936029,,,,,,,,,,,,,, -7600,2.651932,1.6879994,,,,,,,,,,,,,, -7700,2.1780806,1.7322674,,,,,,,,,,,,,, -7800,5.051147,1.7224908,,,,,,,,,,,,,, -7900,2.2696967,1.6122396,,,,,,,,,,,,,, -8000,2.6406512,1.6723979,,,,,,,,,,,,,, -8100,1.9984496,1.6053252,,,,,,,,,,,,,, -8200,4.042363,1.6516312,,,,,,,,,,,,,, -8300,4.0505047,1.6422731,,,,,,,,,,,,,, -8400,2.3319,1.6069044,,,,,,,,,,,,,, -8500,1.9234849,1.541212,,,,,,,,,,,,,, -8514,,,0.5999228,0.1935764928827366,0.72859186,0.2121706556474893,5348.0,0.44906363,0.1469136554749862,2472.0,7217.569427251816,8010.0493178367615,7217.569427251816,791.83607172966,0.2452392578125,0.0 -8600,1.9305114,1.6194788,,,,,,,,,,,,,, -8700,2.217231,1.6414981,,,,,,,,,,,,,, -8800,3.6010957,1.5970796,,,,,,,,,,,,,, -8900,3.3241568,1.6316901,,,,,,,,,,,,,, -9000,2.0985944,1.6137509,,,,,,,,,,,,,, -9100,4.413601,1.6143436,,,,,,,,,,,,,, -9200,2.0426521,1.6324726,,,,,,,,,,,,,, -9300,2.7825077,1.5600138,,,,,,,,,,,,,, -9400,2.5395052,1.6186572,,,,,,,,,,,,,, -9500,3.0280344,1.5837529,,,,,,,,,,,,,, -9600,2.0274284,1.5853449,,,,,,,,,,,,,, -9700,2.572979,1.5859741,,,,,,,,,,,,,, -9800,3.350795,1.5622137,,,,,,,,,,,,,, -9900,4.0037637,1.5955095,,,,,,,,,,,,,, -10000,2.884682,1.6103027,,,,,,,,,,,,,, -10100,2.0537102,1.556551,,,,,,,,,,,,,, -10200,3.2380092,1.5751926,,,,,,,,,,,,,, -10205,,,0.53137237,0.1741598852198085,0.68165743,0.1969935410371028,5348.0,0.40845668,0.1322080718217455,2472.0,8658.101991176605,9587.016305446625,8658.101991176605,928.1383287906648,0.2975411415100097,0.0 -10300,2.065848,1.5844026,,,,,,,,,,,,,, -10400,2.3048697,1.5398074,,,,,,,,,,,,,, -10500,3.2718568,1.5984011,,,,,,,,,,,,,, -10600,3.2003956,1.6338894,,,,,,,,,,,,,, -10700,2.624602,1.52384,,,,,,,,,,,,,, -10800,2.675922,1.50957,,,,,,,,,,,,,, -10900,2.3145368,1.5444045,,,,,,,,,,,,,, -11000,2.697383,1.5769546,,,,,,,,,,,,,, -11100,2.0986304,1.5652857,,,,,,,,,,,,,, -11200,2.8985221,1.5455711,,,,,,,,,,,,,, -11300,3.7146623,1.5690013,,,,,,,,,,,,,, -11400,2.2569451,1.5049328,,,,,,,,,,,,,, -11500,2.9228847,1.4801358,,,,,,,,,,,,,, -11600,3.081018,1.593466,,,,,,,,,,,,,, -11700,2.4448473,1.4718295,,,,,,,,,,,,,, -11800,2.9331806,1.5411038,,,,,,,,,,,,,, -11900,3.7859938,1.5603735,,,,,,,,,,,,,, -11908,,,0.50215024,0.166799780931952,0.64968437,0.1875030170790812,5348.0,0.3866501,0.1245302947210204,2472.0,10098.519153118134,11160.529643058777,10098.519153118134,1061.1018662452698,0.3483004570007324,0.0 -12000,2.9104598,1.4951053,,,,,,,,,,,,,, -12100,1.7787594,1.4962015,,,,,,,,,,,,,, -12200,2.2195675,1.5094528,,,,,,,,,,,,,, -12300,4.695114,1.5441542,,,,,,,,,,,,,, -12400,2.1947927,1.4059951,,,,,,,,,,,,,, -12500,2.3998477,1.5457917,,,,,,,,,,,,,, -12600,1.9071679,1.5271622,,,,,,,,,,,,,, -12700,2.4901686,1.4754092,,,,,,,,,,,,,, -12800,3.4298182,1.5513904,,,,,,,,,,,,,, -12900,2.1308544,1.5084305,,,,,,,,,,,,,, -13000,2.0411153,1.5296777,,,,,,,,,,,,,, -13100,2.0487828,1.5060806,,,,,,,,,,,,,, -13200,2.9486835,1.5448289,,,,,,,,,,,,,, -13300,2.2872953,1.4463981,,,,,,,,,,,,,, -13400,3.653674,1.4628052,,,,,,,,,,,,,, -13500,2.1230438,1.4961748,,,,,,,,,,,,,, -13600,2.6147068,1.4544537,,,,,,,,,,,,,, -13621,,,0.44562355,0.1491876088023925,0.6191698,0.1785145350801819,5348.0,0.3623559,0.1185790018889769,2472.0,11538.94403219223,12734.627663373947,11538.94403219223,1194.6305973529816,0.4078514575958252,0.0 -13700,2.2119315,1.4774821,,,,,,,,,,,,,, -13800,2.440574,1.422646,,,,,,,,,,,,,, -13900,2.9034903,1.477201,,,,,,,,,,,,,, -14000,2.5563145,1.5087487,,,,,,,,,,,,,, -14100,3.553871,1.4616325,,,,,,,,,,,,,, -14200,2.050807,1.4954317,,,,,,,,,,,,,, -14300,3.1706023,1.5232965,,,,,,,,,,,,,, -14400,2.7036295,1.5033851,,,,,,,,,,,,,, -14500,2.2335396,1.4025002,,,,,,,,,,,,,, -14600,2.7620034,1.4493026,,,,,,,,,,,,,, -14700,1.847509,1.4180734,,,,,,,,,,,,,, -14800,4.677427,1.449445,,,,,,,,,,,,,, -14900,2.2018864,1.3827034,,,,,,,,,,,,,, -15000,1.7649194,1.4272872,,,,,,,,,,,,,, -15100,5.8095775,1.4528112,,,,,,,,,,,,,, -15200,1.9849666,1.448923,,,,,,,,,,,,,, -15293,,,0.44809407,0.1485734782534432,0.6029571,0.1749519681010262,5348.0,0.34927937,0.1144963743830357,2472.0,12979.073741436005,14308.919946432114,12979.073741436005,1328.6569051742554,0.4639625549316406,0.0 -15300,2.7128205,1.4473062,,,,,,,,,,,,,, -15400,2.693484,1.5351484,,,,,,,,,,,,,, -15500,2.722153,1.4454879,,,,,,,,,,,,,, -15600,1.9938968,1.3882453,,,,,,,,,,,,,, -15700,2.556783,1.4166148,,,,,,,,,,,,,, -15800,2.7027051,1.4032739,,,,,,,,,,,,,, -15900,2.5691626,1.4988725,,,,,,,,,,,,,, -16000,2.3134031,1.3897526,,,,,,,,,,,,,, -16100,3.0811713,1.4401904,,,,,,,,,,,,,, -16200,4.224279,1.52863,,,,,,,,,,,,,, -16300,2.501877,1.4730825,,,,,,,,,,,,,, -16400,2.1378944,1.4551955,,,,,,,,,,,,,, -16500,2.3031597,1.4559993,,,,,,,,,,,,,, -16600,2.794286,1.3766141,,,,,,,,,,,,,, -16700,2.5720239,1.4660897,,,,,,,,,,,,,, -16800,1.8925855,1.3609555,,,,,,,,,,,,,, -16900,2.7511613,1.3900168,,,,,,,,,,,,,, -17000,3.6850722,1.3833616,,,,,,,,,,,,,, -17003,,,0.42328086,0.1401969991726482,0.58921814,0.1701149869179451,5348.0,0.33758086,0.1090731826214124,2472.0,14419.393539190292,15885.092690944672,14419.393539190292,1464.3784432411194,0.5136878490447998,0.0 -17100,2.7478838,1.4003068,,,,,,,,,,,,,, -17200,3.2943614,1.4216986,,,,,,,,,,,,,, -17300,2.3240452,1.4833547,,,,,,,,,,,,,, -17400,2.6103647,1.3919727,,,,,,,,,,,,,, -17500,2.171864,1.3977293,,,,,,,,,,,,,, -17600,4.107323,1.4607731,,,,,,,,,,,,,, -17700,2.229076,1.4353317,,,,,,,,,,,,,, -17800,2.03995,1.4000032,,,,,,,,,,,,,, -17900,2.1766498,1.4563252,,,,,,,,,,,,,, -18000,2.6093733,1.4202201,,,,,,,,,,,,,, -18100,2.9846258,1.4426782,,,,,,,,,,,,,, -18200,3.2995954,1.4268736,,,,,,,,,,,,,, -18300,2.3974147,1.3496823,,,,,,,,,,,,,, -18400,2.4260778,1.4452741,,,,,,,,,,,,,, -18500,3.060063,1.4072179,,,,,,,,,,,,,, -18600,2.7673814,1.3939558,,,,,,,,,,,,,, -18687,,,0.40703046,0.1380073719048515,0.56702036,0.1645442521023007,5348.0,0.32256627,0.1058436414599963,2472.0,15859.808149814606,17461.924512147903,15859.808149814606,1600.6581590175629,0.5700757503509521,0.0 -18700,2.2927775,1.3910831,,,,,,,,,,,,,, -18800,2.2173102,1.4415159,,,,,,,,,,,,,, -18900,3.6352391,1.4234489,,,,,,,,,,,,,, -19000,2.1981342,1.4592144,,,,,,,,,,,,,, -19100,2.147264,1.4524865,,,,,,,,,,,,,, -19200,2.7228708,1.4376441,,,,,,,,,,,,,, -19300,3.1337707,1.4789021,,,,,,,,,,,,,, -19400,2.2853334,1.3690346,,,,,,,,,,,,,, -19500,3.2946036,1.348211,,,,,,,,,,,,,, -19600,2.9710345,1.4195527,,,,,,,,,,,,,, -19700,2.923521,1.4126965,,,,,,,,,,,,,, -19800,2.167028,1.4001513,,,,,,,,,,,,,, -19900,2.6489983,1.4540664,,,,,,,,,,,,,, -20000,6.0127907,1.3998905,,,,,,,,,,,,,, -20100,2.6935568,1.4159956,,,,,,,,,,,,,, -20200,3.9487035,1.3654398,,,,,,,,,,,,,, -20300,2.3422067,1.4507736,,,,,,,,,,,,,, -20376,,,0.35868832,0.1200126644541563,0.5564612,0.160730664143584,5348.0,0.31212524,0.1036499908597891,2472.0,17300.169142961502,19035.71454906464,17300.169142961502,1733.951742887497,0.6257908344268799,0.0 -20400,3.3608072,1.3868083,,,,,,,,,,,,,, -20500,2.2255752,1.3669007,,,,,,,,,,,,,, -20600,4.304128,1.3559843,,,,,,,,,,,,,, -20700,3.7910066,1.3347677,,,,,,,,,,,,,, -20800,2.0850103,1.4111774,,,,,,,,,,,,,, -20900,3.7211623,1.4280249,,,,,,,,,,,,,, -21000,4.4950285,1.4113302,,,,,,,,,,,,,, -21100,2.8138938,1.3803694,,,,,,,,,,,,,, -21200,3.07351,1.3572284,,,,,,,,,,,,,, -21300,1.814109,1.4029162,,,,,,,,,,,,,, -21400,2.0724473,1.3681065,,,,,,,,,,,,,, -21500,2.7664006,1.4075242,,,,,,,,,,,,,, -21600,2.5343425,1.3590016,,,,,,,,,,,,,, -21700,2.0710328,1.3280817,,,,,,,,,,,,,, -21800,2.6662514,1.3938575,,,,,,,,,,,,,, -21900,3.1199687,1.3811134,,,,,,,,,,,,,, -22000,3.8795917,1.3609333,,,,,,,,,,,,,, -22081,,,0.32458857,0.1098593623444518,0.5423642,0.1574770460623497,5348.0,0.30617353,0.1011313549854772,2472.0,18740.1226541996,20611.221473693848,18740.1226541996,1869.3748636245728,0.6759476661682129,0.0 -22100,2.334998,1.2993933,,,,,,,,,,,,,, -22200,2.626137,1.4394116,,,,,,,,,,,,,, -22300,2.237539,1.366594,,,,,,,,,,,,,, -22400,2.133205,1.281978,,,,,,,,,,,,,, -22500,2.3066623,1.3293544,,,,,,,,,,,,,, -22600,2.7513115,1.3421901,,,,,,,,,,,,,, -22700,2.422966,1.3814036,,,,,,,,,,,,,, -22800,2.5102055,1.3048439,,,,,,,,,,,,,, -22900,2.1844757,1.3086963,,,,,,,,,,,,,, -23000,2.3684218,1.3257354,,,,,,,,,,,,,, -23100,2.2554603,1.3513745,,,,,,,,,,,,,, -23200,2.786255,1.3592247,,,,,,,,,,,,,, -23300,2.1913836,1.3166938,,,,,,,,,,,,,, -23400,2.531389,1.368354,,,,,,,,,,,,,, -23500,2.7352698,1.3437395,,,,,,,,,,,,,, -23600,2.586785,1.3532861,,,,,,,,,,,,,, -23700,1.9161073,1.3403811,,,,,,,,,,,,,, -23760,,,0.36119395,0.1249985887984465,0.5224131,0.1516263263079641,5348.0,0.29169166,0.0963987569313265,2472.0,20180.643664598465,22185.8266851902,20180.643664598465,2003.326298236847,0.7289724349975586,0.0 -23800,2.3329077,1.2517338,,,,,,,,,,,,,, -23900,2.8151414,1.293344,,,,,,,,,,,,,, -24000,2.292568,1.3674011,,,,,,,,,,,,,, -24100,2.3842964,1.3326095,,,,,,,,,,,,,, -24200,2.0881882,1.317312,,,,,,,,,,,,,, -24300,2.2858183,1.3291613,,,,,,,,,,,,,, -24400,1.8473048,1.338059,,,,,,,,,,,,,, -24500,3.2297535,1.3462312,,,,,,,,,,,,,, -24600,2.1642838,1.309293,,,,,,,,,,,,,, -24700,1.9022164,1.3152571,,,,,,,,,,,,,, -24800,3.0498605,1.3271366,,,,,,,,,,,,,, -24900,3.4307737,1.3793215,,,,,,,,,,,,,, -25000,2.2316628,1.3514926,,,,,,,,,,,,,, -25100,3.832429,1.2769517,,,,,,,,,,,,,, -25200,2.5110533,1.267152,,,,,,,,,,,,,, -25300,2.2900279,1.3199756,,,,,,,,,,,,,, -25400,1.9389684,1.265323,,,,,,,,,,,,,, -25457,,,0.3251517,0.10952787258248,0.51066995,0.1475134441043861,5348.0,0.28612068,0.092600491540227,2472.0,21621.02370738983,23759.84077358246,21621.02370738983,2136.82527923584,0.7829153537750244,0.0 -25500,3.1010797,1.3071569,,,,,,,,,,,,,, -25600,2.850557,1.277966,,,,,,,,,,,,,, -25700,2.819539,1.2961394,,,,,,,,,,,,,, -25800,2.9409158,1.2523743,,,,,,,,,,,,,, -25900,2.9019914,1.3313197,,,,,,,,,,,,,, -26000,1.806782,1.299783,,,,,,,,,,,,,, -26100,2.4425724,1.3291868,,,,,,,,,,,,,, -26200,2.3894925,1.259916,,,,,,,,,,,,,, -26300,2.2826195,1.283863,,,,,,,,,,,,,, -26400,2.0937648,1.3435346,,,,,,,,,,,,,, -26500,2.153138,1.2618589,,,,,,,,,,,,,, -26600,2.4468822,1.3551308,,,,,,,,,,,,,, -26700,2.929278,1.2854079,,,,,,,,,,,,,, -26800,3.3246617,1.3839191,,,,,,,,,,,,,, -26900,2.4292934,1.2812148,,,,,,,,,,,,,, -27000,2.7295978,1.3105949,,,,,,,,,,,,,, -27100,3.17731,1.3014948,,,,,,,,,,,,,, -27173,,,0.3157251,0.1056232280028406,0.50370485,0.1455438948801374,5348.0,0.28231353,0.0910364999085979,2472.0,23060.985421419144,25335.851276874542,23060.985421419144,2272.730340003968,0.8456952571868896,0.0 -27200,2.1472585,1.2625436,,,,,,,,,,,,,, -27300,2.1604476,1.3163944,,,,,,,,,,,,,, -27400,1.8254498,1.3250923,,,,,,,,,,,,,, -27500,1.5185518,1.2447708,,,,,,,,,,,,,, -27600,3.0033479,1.269023,,,,,,,,,,,,,, -27700,2.4502351,1.3046317,,,,,,,,,,,,,, -27800,2.2012942,1.2726824,,,,,,,,,,,,,, -27900,2.5361683,1.2503005,,,,,,,,,,,,,, -28000,2.5379677,1.2251126,,,,,,,,,,,,,, -28100,2.0897617,1.2717494,,,,,,,,,,,,,, -28200,2.5179846,1.3331302,,,,,,,,,,,,,, -28300,2.3402894,1.2642382,,,,,,,,,,,,,, -28400,3.1108732,1.2475137,,,,,,,,,,,,,, -28500,2.4350405,1.3381418,,,,,,,,,,,,,, -28600,2.038597,1.31125,,,,,,,,,,,,,, -28700,2.1707804,1.2593932,,,,,,,,,,,,,, -28800,3.0660684,1.2589824,,,,,,,,,,,,,, -28848,,,0.28317812,0.0987635239567233,0.4912923,0.1415372138602199,5348.0,0.27055633,0.0877460240082871,2472.0,24501.122561454773,26909.666396856308,24501.122561454773,2406.2770829200745,0.8960163593292236,0.0 -28900,2.3467093,1.2517239,,,,,,,,,,,,,, -29000,3.8634837,1.2264702,,,,,,,,,,,,,, -29100,2.1475878,1.2567061,,,,,,,,,,,,,, -29200,2.1114993,1.235269,,,,,,,,,,,,,, -29300,1.7375482,1.2741864,,,,,,,,,,,,,, -29400,2.799837,1.2866943,,,,,,,,,,,,,, -29500,2.6439857,1.3232758,,,,,,,,,,,,,, -29600,2.2299833,1.304304,,,,,,,,,,,,,, -29700,2.5181177,1.2301458,,,,,,,,,,,,,, -29800,1.794587,1.3088988,,,,,,,,,,,,,, -29900,2.6132905,1.2801452,,,,,,,,,,,,,, -30000,2.5938141,1.213421,,,,,,,,,,,,,, -30100,2.6098907,1.2405372,,,,,,,,,,,,,, -30200,2.1083522,1.2292856,,,,,,,,,,,,,, -30300,1.8567731,1.2024279,,,,,,,,,,,,,, -30400,2.9527202,1.2724961,,,,,,,,,,,,,, -30500,2.5644834,1.2399515,,,,,,,,,,,,,, -30561,,,0.2763036,0.0966358705988807,0.48683694,0.1402627996562943,5348.0,0.26829788,0.0875022850527085,2472.0,25941.73451423645,28487.4870493412,25941.73451423645,2543.348155975342,0.951606273651123,0.0 -30600,2.1194656,1.3066154,,,,,,,,,,,,,, -30700,2.3235917,1.2156001,,,,,,,,,,,,,, -30800,2.883419,1.2625792,,,,,,,,,,,,,, -30900,1.5351869,1.2376242,,,,,,,,,,,,,, -31000,2.689855,1.3192006,,,,,,,,,,,,,, -31100,3.6598802,1.2636741,,,,,,,,,,,,,, -31200,2.5771105,1.2324755,,,,,,,,,,,,,, -31300,2.3758953,1.2396109,,,,,,,,,,,,,, -31400,3.4749541,1.2182943,,,,,,,,,,,,,, -31500,2.1788664,1.2267928,,,,,,,,,,,,,, -31600,2.6230915,1.2221628,,,,,,,,,,,,,, -31700,1.7602772,1.2181913,,,,,,,,,,,,,, -31800,2.2508183,1.2742906,,,,,,,,,,,,,, -31900,3.8397043,1.2131228,,,,,,,,,,,,,, -32000,2.1322193,1.2372962,,,,,,,,,,,,,, -32100,3.1764503,1.3106322,,,,,,,,,,,,,, -32200,2.5214028,1.2106959,,,,,,,,,,,,,, -32288,,,0.25569287,0.0897257371193316,0.46944138,0.1357830406364347,5348.0,0.26102412,0.084333678630187,2472.0,27382.11405301094,30062.6204662323,27382.11405301094,2677.9681293964386,1.001542568206787,0.0 -32300,2.604627,1.2324195,,,,,,,,,,,,,, -32400,2.5574954,1.2138143,,,,,,,,,,,,,, -32500,2.3175182,1.2113049,,,,,,,,,,,,,, -32600,3.5129395,1.167862,,,,,,,,,,,,,, -32700,2.5519166,1.2160972,,,,,,,,,,,,,, -32800,2.800166,1.2720044,,,,,,,,,,,,,, -32900,1.8832369,1.219888,,,,,,,,,,,,,, -33000,3.4398289,1.2305173,,,,,,,,,,,,,, -33100,3.315761,1.278595,,,,,,,,,,,,,, -33200,1.6933392,1.193135,,,,,,,,,,,,,, -33300,2.654122,1.217884,,,,,,,,,,,,,, -33400,2.6624715,1.2100012,,,,,,,,,,,,,, -33500,2.4991462,1.2473748,,,,,,,,,,,,,, -33600,2.5027325,1.2047251,,,,,,,,,,,,,, -33700,2.4319067,1.2122428,,,,,,,,,,,,,, -33800,2.360877,1.2269825,,,,,,,,,,,,,, -33900,2.2440662,1.1790314,,,,,,,,,,,,,, -33983,,,0.15601371,0.0559534943036898,0.46224105,0.1343155333713083,5348.0,0.2518262,0.0820790932910852,2472.0,28822.875962257385,31657.9614071846,28822.875962257385,2832.421259641648,1.0486581325531006,0.0 -34000,3.6126103,1.1992625,,,,,,,,,,,,,, -34100,2.6460943,1.2157248,,,,,,,,,,,,,, -34200,2.5089235,1.2385871,,,,,,,,,,,,,, -34300,1.6186767,1.1493902,,,,,,,,,,,,,, -34400,1.7183483,1.157907,,,,,,,,,,,,,, -34500,2.0848362,1.2209046,,,,,,,,,,,,,, -34600,3.3171477,1.2384126,,,,,,,,,,,,,, -34700,3.2087915,1.1570477,,,,,,,,,,,,,, -34800,2.0471413,1.1831905,,,,,,,,,,,,,, -34900,2.2340717,1.1724634,,,,,,,,,,,,,, -35000,2.2872427,1.2573991,,,,,,,,,,,,,, -35100,1.8330052,1.1898842,,,,,,,,,,,,,, -35200,2.735189,1.1964756,,,,,,,,,,,,,, -35300,2.0972931,1.1891212,,,,,,,,,,,,,, -35400,1.99595,1.1948217,,,,,,,,,,,,,, -35500,3.3247693,1.2058008,,,,,,,,,,,,,, -35600,1.7641449,1.1600955,,,,,,,,,,,,,, -35663,,,0.15231837,0.0532382995745299,0.45944563,0.1315349932900161,5348.0,0.24787179,0.0796213921556679,2472.0,30262.88162112236,33234.50987505913,30262.88162112236,2968.835390806198,1.0996172428131104,0.0 -35700,2.5304432,1.205143,,,,,,,,,,,,,, -35800,2.9383175,1.186779,,,,,,,,,,,,,, -35900,1.4899191,1.2163962,,,,,,,,,,,,,, -36000,2.903448,1.259198,,,,,,,,,,,,,, -36100,2.012894,1.1991922,,,,,,,,,,,,,, -36200,2.4603434,1.168832,,,,,,,,,,,,,, -36300,1.711495,1.1465969,,,,,,,,,,,,,, -36400,2.3421326,1.2173963,,,,,,,,,,,,,, -36500,2.2340143,1.2169905,,,,,,,,,,,,,, -36600,3.6316593,1.1782181,,,,,,,,,,,,,, -36700,2.629095,1.1752905,,,,,,,,,,,,,, -36800,2.3738656,1.1857111,,,,,,,,,,,,,, -36900,1.6499445,1.1814693,,,,,,,,,,,,,, -37000,2.4025764,1.1829443,,,,,,,,,,,,,, -37100,2.0077262,1.1447453,,,,,,,,,,,,,, -37200,1.6652176,1.1157322,,,,,,,,,,,,,, -37300,2.9178865,1.1376489,,,,,,,,,,,,,, -37356,,,0.1572897,0.055789000766405,0.44850588,0.1289282369638047,5348.0,0.23940639,0.0773464952369345,2472.0,31703.97305607796,34811.08276438713,31703.97305607796,3104.18021607399,1.1535747051239014,0.0 -37400,2.676315,1.2372941,,,,,,,,,,,,,, -37500,1.8620678,1.1738721,,,,,,,,,,,,,, -37600,2.5114236,1.1401552,,,,,,,,,,,,,, -37700,2.477773,1.1633278,,,,,,,,,,,,,, -37800,2.4062467,1.1630393,,,,,,,,,,,,,, -37900,2.3565116,1.1812105,,,,,,,,,,,,,, -38000,3.3986301,1.1551846,,,,,,,,,,,,,, -38100,2.3816214,1.1752523,,,,,,,,,,,,,, -38200,1.9646856,1.1746274,,,,,,,,,,,,,, -38300,2.9213505,1.1821052,,,,,,,,,,,,,, -38400,1.7314458,1.1309034,,,,,,,,,,,,,, -38500,2.244329,1.1373677,,,,,,,,,,,,,, -38600,2.7695692,1.1728058,,,,,,,,,,,,,, -38700,3.2754128,1.1743423,,,,,,,,,,,,,, -38800,2.1639936,1.1462781,,,,,,,,,,,,,, -38900,2.2480736,1.1778253,,,,,,,,,,,,,, -39000,2.1564062,1.1338598,,,,,,,,,,,,,, -39037,,,0.13856104,0.049722430920778,0.44071487,0.1269683423926161,5348.0,0.2376608,0.076290293096094,2472.0,33144.23564553261,36389.54399847984,33144.23564553261,3242.2381801605225,1.2139358520507812,0.0 -39100,2.041168,1.1757214,,,,,,,,,,,,,, -39200,1.8815985,1.1725736,,,,,,,,,,,,,, -39300,2.959932,1.1666852,,,,,,,,,,,,,, -39400,2.950188,1.1742318,,,,,,,,,,,,,, -39500,1.6981493,1.1254752,,,,,,,,,,,,,, -39600,3.7869377,1.1572554,,,,,,,,,,,,,, -39700,2.2860537,1.1556233,,,,,,,,,,,,,, -39800,2.4612284,1.14781,,,,,,,,,,,,,, -39900,2.3080978,1.1281798,,,,,,,,,,,,,, -40000,3.670863,1.2182508,,,,,,,,,,,,,, -40100,2.080334,1.1484783,,,,,,,,,,,,,, -40200,2.474618,1.137895,,,,,,,,,,,,,, -40300,1.8026341,1.1436504,,,,,,,,,,,,,, -40400,3.3051786,1.0834426,,,,,,,,,,,,,, -40500,2.9368098,1.1455244,,,,,,,,,,,,,, -40600,3.7368114,1.15075,,,,,,,,,,,,,, -40700,2.7818282,1.1729909,,,,,,,,,,,,,, -40732,,,0.1460249,0.0518470339231544,0.4341819,0.1248636280255269,5348.0,0.23402792,0.0753356488534113,2472.0,34584.34844779968,37967.1045422554,34584.34844779968,3379.543396711349,1.2748100757598877,0.0 -40800,1.9111979,1.13995,,,,,,,,,,,,,, -40900,4.237692,1.096611,,,,,,,,,,,,,, -41000,1.5761949,1.1211214,,,,,,,,,,,,,, -41100,2.1860664,1.0701826,,,,,,,,,,,,,, -41200,3.0330334,1.13939,,,,,,,,,,,,,, -41300,2.8109732,1.1448601,,,,,,,,,,,,,, -41400,2.180318,1.1889591,,,,,,,,,,,,,, -41500,4.2819047,1.1898868,,,,,,,,,,,,,, -41600,3.0931442,1.1591525,,,,,,,,,,,,,, -41700,2.9670835,1.1020126,,,,,,,,,,,,,, -41800,2.494124,1.1396234,,,,,,,,,,,,,, -41900,2.996877,1.1561104,,,,,,,,,,,,,, -42000,2.4081407,1.1326793,,,,,,,,,,,,,, -42100,2.3788602,1.2031076,,,,,,,,,,,,,, -42200,5.0484,1.0972817,,,,,,,,,,,,,, -42300,4.815595,1.1112851,,,,,,,,,,,,,, -42400,3.9656942,1.0919783,,,,,,,,,,,,,, -42425,,,0.13446575,0.0488422537318217,0.43170702,0.1238691987603425,5348.0,0.23212434,0.0747872362033595,2472.0,36024.96976852417,39542.14425396919,36024.96976852417,3513.8202295303345,1.3349525928497314,0.0 -42500,2.1716163,1.1534225,,,,,,,,,,,,,, -42600,3.7813647,1.1089145,,,,,,,,,,,,,, -42700,2.2349534,1.1411449,,,,,,,,,,,,,, -42800,2.9633865,1.1302575,,,,,,,,,,,,,, -42900,2.8774958,1.141377,,,,,,,,,,,,,, -43000,2.4459481,1.1461397,,,,,,,,,,,,,, -43100,3.368505,1.1303315,,,,,,,,,,,,,, -43200,3.5800068,1.1067053,,,,,,,,,,,,,, -43300,2.5901299,1.1519672,,,,,,,,,,,,,, -43400,3.6004148,1.1084558,,,,,,,,,,,,,, -43500,3.2760477,1.2074472,,,,,,,,,,,,,, -43600,2.3299775,1.1684014,,,,,,,,,,,,,, -43700,2.3858774,1.1504151,,,,,,,,,,,,,, -43800,2.7542093,1.1197332,,,,,,,,,,,,,, -43900,2.8330112,1.1235667,,,,,,,,,,,,,, -44000,2.9917152,1.1722994,,,,,,,,,,,,,, -44094,,,0.15975782,0.0529069018799897,0.4287825,0.1227975322706778,5348.0,0.23024347,0.0743810046107285,2472.0,37465.89845681191,41118.98565912247,37465.89845681191,3649.597271680832,1.391124963760376,0.0 -44100,3.107829,1.1759793,,,,,,,,,,,,,, -44200,2.7199562,1.1419022,,,,,,,,,,,,,, -44300,2.2730336,1.1246662,,,,,,,,,,,,,, -44400,2.2785597,1.1063209,,,,,,,,,,,,,, -44500,3.91636,1.1116368,,,,,,,,,,,,,, -44600,3.120739,1.1144602,,,,,,,,,,,,,, -44700,2.7142382,1.1346257,,,,,,,,,,,,,, -44800,2.141653,1.1320456,,,,,,,,,,,,,, -44900,2.8225548,1.0912747,,,,,,,,,,,,,, -45000,2.7442718,1.1488992,,,,,,,,,,,,,, -45100,4.5226197,1.1169779,,,,,,,,,,,,,, -45200,4.7304215,1.1143503,,,,,,,,,,,,,, -45300,3.3685234,1.1301513,,,,,,,,,,,,,, -45400,3.3416476,1.1129168,,,,,,,,,,,,,, -45500,2.718508,1.1710308,,,,,,,,,,,,,, -45600,4.2373867,1.1393584,,,,,,,,,,,,,, -45700,2.4158704,1.1353337,,,,,,,,,,,,,, -45785,,,0.14198789,0.050932994062765,0.42813993,0.1223051449646157,5348.0,0.22979745,0.07413726565515,2472.0,38906.24567842484,42694.73621749878,38906.24567842484,3784.859577178955,1.4510438442230225,0.0 -45800,1.8461953,1.0937111,,,,,,,,,,,,,, -45900,2.4331412,1.1193871,,,,,,,,,,,,,, -46000,2.379735,1.1584779,,,,,,,,,,,,,, -46100,2.6317282,1.1289828,,,,,,,,,,,,,, -46200,2.114127,1.1120638,,,,,,,,,,,,,, -46300,4.0224304,1.0989608,,,,,,,,,,,,,, -46400,2.366819,1.1539319,,,,,,,,,,,,,, -46500,2.925727,1.1649294,,,,,,,,,,,,,, -46600,1.9326982,1.0988271,,,,,,,,,,,,,, -46700,3.0938568,1.1354978,,,,,,,,,,,,,, -46800,2.1974075,1.0843254,,,,,,,,,,,,,, -46900,1.9336641,1.159223,,,,,,,,,,,,,, -47000,3.8459136,1.1474905,,,,,,,,,,,,,, -47100,1.8317205,1.110192,,,,,,,,,,,,,, -47200,3.5093348,1.1300517,,,,,,,,,,,,,, -47300,2.339557,1.1364893,,,,,,,,,,,,,, -47400,3.4949327,1.1349746,,,,,,,,,,,,,, -47459,,,0.1208012,0.0438509485377796,0.4275235,0.1221120519034148,5348.0,0.22964966,0.0739138382792029,2472.0,40346.5943479538,44270.86405515671,40346.5943479538,3920.495556592941,1.513451337814331,0.0 -47500,2.5071452,1.1223722,,,,,,,,,,,,,, -47600,2.3632638,1.1283282,,,,,,,,,,,,,, -47700,3.2150195,1.1287947,,,,,,,,,,,,,, -47800,2.472579,1.1213895,,,,,,,,,,,,,, -47900,3.4415748,1.0962477,,,,,,,,,,,,,, -48000,,,0.12780735,0.0460682191722342,0.42751795,0.1220251600258744,5348.0,0.22965063,0.0739950845977291,2472.0,40782.78246617317,44844.3178961277,40782.78246617317,4057.6739218235016,1.570603847503662,0.0 -48000,,,,,,,,,,,40782.78246617317,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 1b7818486..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -164.4026050567627,0.0,15.222284317016602,1,0,15.222284317016602,28.256304,2472,2.3443422094936324,179.6249499320984,29.058031,2.360243429556226,28.14937,5348,2.1855045038956527 -275.77101039886475,0.0288538932800292,1455.277383327484,1708,0,1455.277383327484,6.1304574,2472,0.899579550301627,1731.153431415558,6.197022,0.943809484010489,6.2215743,5348,0.8966179750330672 -392.0126388072968,0.0841245651245117,2895.461868286133,3430,0,2895.461868286133,3.945421,2472,0.8374667397883533,3287.7166328430176,4.192091,0.8861044829697827,4.297596,5348,0.8586076059356808 -527.4732053279877,0.1376943588256836,4335.978147506714,5111,0,4335.978147506714,0.74433875,2472,0.2344159405276948,4863.82788681984,0.75358486,0.2408147853119326,1.097035,5348,0.3027988839221062 -665.1103904247284,0.1945838928222656,5776.205506324768,6798,0,5776.205506324768,0.57972324,2472,0.1880039810696077,6441.831294298172,0.56135374,0.1861137746199915,0.89679164,5348,0.2572771947440069 -800.5987780094147,0.255300760269165,7216.254794597626,8513,0,7216.254794597626,0.49849275,2472,0.1604817906688603,8017.511333227158,0.43911177,0.1493955322637619,0.7938834,5348,0.2276953377680373 -938.0026862621309,0.3100192546844482,8656.777244329453,10207,0,8656.777244329453,0.46262303,2472,0.147421444965775,9595.573912382126,0.39790037,0.134293730165625,0.73437,5348,0.2113307008312656 -1072.4567618370056,0.3640625476837158,10096.870054721832,11906,0,10096.870054721832,0.43567663,2472,0.1397436678650498,11170.25594997406,0.40488553,0.136518069681887,0.7124341,5348,0.2063199358931036 -1206.3002724647522,0.4186415672302246,11537.339000701904,13599,0,11537.339000701904,0.41858706,2472,0.135376678244267,12744.704414606094,0.37492138,0.1258041847568515,0.6971934,5348,0.201154696505981 -1345.0807433128357,0.4681658744812011,12977.885040283203,15296,0,12977.885040283203,0.40154976,2472,0.131781528649483,14324.161016225817,0.36053848,0.1228070175438596,0.6717796,5348,0.1930834065477857 -1481.8800811767578,0.5250308513641357,14418.47855091095,17012,0,14418.47855091095,0.38370186,2472,0.1228038104523388,15901.694470643995,0.38336515,0.122435304253486,0.6500841,5348,0.1883815905075451 -1621.8828246593475,0.5793395042419434,15858.525933265686,18700,0,15858.525933265686,0.37486923,2472,0.1191477261186602,17481.879634141922,0.2896522,0.0982423036692167,0.62595093,5348,0.1799820423453083 -1756.674932718277,0.6285579204559326,17298.88808941841,20401,0,17298.88808941841,0.36329266,2472,0.1185993134686084,19057.16420722008,0.3050606,0.10250043128911,0.6120852,5348,0.1774428685905172 -1891.644542694092,0.6785871982574463,18739.611248254776,22128,0,18739.611248254776,0.34999764,2472,0.1121199195661446,20632.98897361756,0.4048885,0.13275771099435,0.60599583,5348,0.1750871332438668 -2025.3478963375087,0.7357115745544434,20180.07414293289,23829,0,20180.07414293289,0.33509344,2472,0.108057603639835,22207.29289650917,0.4211561,0.1378736712524745,0.5772926,5348,0.1673344468366529 -2158.018583536148,0.7869384288787842,21620.75711417198,25527,0,21620.75711417198,0.32599014,2472,0.1042187150894725,23780.77961206436,0.4618884,0.1521776963059216,0.565899,5348,0.1642932311227396 -2288.8654062747955,0.8404562473297119,23061.056136369705,27250,0,23061.056136369705,0.31674242,2472,0.1024922308207909,25352.06184864044,0.39754912,0.1286418593731658,0.54623437,5348,0.1589445533274761 -2423.746869325638,0.89430832862854,24501.34368991852,28939,0,24501.34368991852,0.30347437,2472,0.0970284158999045,26927.3649828434,0.35738322,0.118875696087301,0.5297782,5348,0.1531807254506309 -2559.73819565773,0.9447612762451172,25941.561491966248,30640,0,25941.561491966248,0.29775858,2472,0.0966831190461682,28503.707077503204,0.3072533,0.1059137139551105,0.51826257,5348,0.1500526178591772 -2694.1472787857056,0.9967937469482422,27381.578418970108,32345,0,27381.578418970108,0.2845581,2472,0.0911380578067556,30078.26686859131,0.34496155,0.1164896053751008,0.49741358,5348,0.1455921681454376 -2830.545342445373,1.0511739253997805,28822.132719278336,34019,0,28822.132719278336,0.26872075,2472,0.0860195397396055,31655.35311293602,0.2937716,0.0989305956582183,0.48239866,5348,0.1393938808808905 -2964.973935842514,1.1049518585205078,30262.328066825867,35722,0,30262.328066825867,0.25975537,2472,0.0832977880689781,33230.11437892914,0.279953,0.096801433933893,0.4692077,5348,0.136082334881296 -3101.6122257709503,1.1635775566101074,31702.204449653625,37434,0,31702.204449653625,0.2533827,2472,0.0805760363983507,34806.76887631416,0.27683035,0.0935626034620739,0.4581124,5348,0.1331666296571632 -3235.8583705425262,1.2178065776824951,33142.71718811989,39117,0,33142.71718811989,0.24226296,2472,0.0769808868035667,36381.66338968277,0.27329835,0.0938528237962617,0.44634154,5348,0.1302316151269104 -3369.8367607593536,1.2777178287506104,34583.25195026398,40807,0,34583.25195026398,0.23748754,2472,0.0763106046757256,37956.318457603455,0.26125973,0.0869634244483009,0.43017662,5348,0.1248539733724668 -3504.3984801769257,1.3354296684265137,36023.209208488464,42493,0,36023.209208488464,0.23134708,2472,0.0739747730180976,39530.97588849068,0.23121312,0.0805833849035775,0.42062643,5348,0.1218706855769137 -3639.836142063141,1.39624285697937,37463.483968019485,44175,0,37463.483968019485,0.22691439,2472,0.0724107813864684,41106.83026814461,0.2266846,0.0783503848103599,0.4167648,5348,0.1213589889647315 -3772.66555261612,1.4497272968292236,38903.80582237244,45875,0,38903.80582237244,0.2248237,2472,0.0717811224178904,42680.11693024635,0.21549411,0.07470870837705,0.4130225,5348,0.1205962713729882 -3906.2811391353607,1.5107195377349854,40344.11418509483,47567,0,40344.11418509483,0.22458875,2472,0.0714764487234172,44254.18384337425,0.22340734,0.0765954799448001,0.413083,5348,0.1203452503934271 -4039.852513551712,1.5739483833312988,40700.30122280121,48000,0,40700.30122280121,0.22453754,2472,0.07129364450673328,44744.03060722351,0.2246082,0.0775943396226415,0.41315985,5348,0.12034525039342711 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/measurements.csv deleted file mode 100644 index d0326bafb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.116364,32.883373,,,,,,,,,,,,,, -1,,,29.058031,2.360243429556226,28.14937,2.1855045038956527,5348.0,28.256304,2.3443422094936324,2472.0,15.222284317016602,179.6249499320984,15.222284317016602,164.4026050567627,0.0,0.0 -100,5.7899575,9.10455,,,,,,,,,,,,,, -200,1.6940492,6.4983997,,,,,,,,,,,,,, -300,0.5279651,5.913055,,,,,,,,,,,,,, -400,0.4028226,5.855239,,,,,,,,,,,,,, -500,0.3627856,5.818173,,,,,,,,,,,,,, -600,0.42951548,5.7981715,,,,,,,,,,,,,, -700,0.7810807,5.7550545,,,,,,,,,,,,,, -800,0.5919283,5.6368623,,,,,,,,,,,,,, -900,0.47743398,5.523805,,,,,,,,,,,,,, -1000,0.52411705,5.406255,,,,,,,,,,,,,, -1100,1.3652543,5.2018414,,,,,,,,,,,,,, -1200,0.8588487,4.808166,,,,,,,,,,,,,, -1300,1.7869838,4.3451285,,,,,,,,,,,,,, -1400,2.503166,4.0274677,,,,,,,,,,,,,, -1500,2.2064807,3.7243543,,,,,,,,,,,,,, -1600,1.9470832,3.495843,,,,,,,,,,,,,, -1700,1.6716436,3.312048,,,,,,,,,,,,,, -1708,,,6.197022,0.943809484010489,6.2215743,0.8966179750330672,5348.0,6.1304574,0.899579550301627,2472.0,1455.277383327484,1731.153431415558,1455.277383327484,275.77101039886475,0.0288538932800292,0.0 -1800,2.1714623,3.1308105,,,,,,,,,,,,,, -1900,2.3536522,3.102767,,,,,,,,,,,,,, -2000,2.840626,3.020813,,,,,,,,,,,,,, -2100,2.4071875,2.8902752,,,,,,,,,,,,,, -2200,2.4005697,2.724623,,,,,,,,,,,,,, -2300,2.457866,2.6779993,,,,,,,,,,,,,, -2400,3.139471,2.6181061,,,,,,,,,,,,,, -2500,2.4891636,2.618306,,,,,,,,,,,,,, -2600,2.8208697,2.5581954,,,,,,,,,,,,,, -2700,3.932886,2.491493,,,,,,,,,,,,,, -2800,5.93129,2.4877279,,,,,,,,,,,,,, -2900,5.386731,2.4208114,,,,,,,,,,,,,, -3000,3.1416233,2.4097373,,,,,,,,,,,,,, -3100,3.3935792,2.2507014,,,,,,,,,,,,,, -3200,3.1351645,2.276831,,,,,,,,,,,,,, -3300,5.7485466,2.3250253,,,,,,,,,,,,,, -3400,3.3606148,2.2279885,,,,,,,,,,,,,, -3430,,,4.192091,0.8861044829697827,4.297596,0.8586076059356808,5348.0,3.945421,0.8374667397883533,2472.0,2895.461868286133,3287.7166328430176,2895.461868286133,392.0126388072968,0.0841245651245117,0.0 -3500,3.6239212,2.2301142,,,,,,,,,,,,,, -3600,3.3505409,2.1503263,,,,,,,,,,,,,, -3700,3.224593,2.121372,,,,,,,,,,,,,, -3800,6.2731075,2.174557,,,,,,,,,,,,,, -3900,3.429079,2.1909437,,,,,,,,,,,,,, -4000,4.972707,2.1380029,,,,,,,,,,,,,, -4100,5.537512,2.0599287,,,,,,,,,,,,,, -4200,3.5581663,2.056336,,,,,,,,,,,,,, -4300,2.7944763,2.1102557,,,,,,,,,,,,,, -4400,2.7652867,2.0816562,,,,,,,,,,,,,, -4500,4.2205,2.0683367,,,,,,,,,,,,,, -4600,3.9352176,2.0000472,,,,,,,,,,,,,, -4700,3.0568285,1.9950639,,,,,,,,,,,,,, -4800,3.0310023,1.9985951,,,,,,,,,,,,,, -4900,3.406628,1.9679483,,,,,,,,,,,,,, -5000,3.1053865,1.9506752,,,,,,,,,,,,,, -5100,3.7991016,2.030519,,,,,,,,,,,,,, -5111,,,0.75358486,0.2408147853119326,1.097035,0.3027988839221062,5348.0,0.74433875,0.2344159405276948,2472.0,4335.978147506714,4863.82788681984,4335.978147506714,527.4732053279877,0.1376943588256836,0.0 -5200,4.175105,1.9066278,,,,,,,,,,,,,, -5300,4.720848,1.9532671,,,,,,,,,,,,,, -5400,2.7750213,1.9378343,,,,,,,,,,,,,, -5500,3.9382627,1.9417855,,,,,,,,,,,,,, -5600,4.1722574,1.890569,,,,,,,,,,,,,, -5700,4.4486046,1.8976651,,,,,,,,,,,,,, -5800,4.39754,1.8827868,,,,,,,,,,,,,, -5900,3.645513,1.9292806,,,,,,,,,,,,,, -6000,3.1700675,1.8583639,,,,,,,,,,,,,, -6100,2.7713325,1.8812337,,,,,,,,,,,,,, -6200,3.0822353,1.8313541,,,,,,,,,,,,,, -6300,3.0285635,1.8536593,,,,,,,,,,,,,, -6400,5.1032524,1.8493828,,,,,,,,,,,,,, -6500,4.272853,1.7927507,,,,,,,,,,,,,, -6600,4.7452106,1.8010336,,,,,,,,,,,,,, -6700,4.2796993,1.8346484,,,,,,,,,,,,,, -6798,,,0.56135374,0.1861137746199915,0.89679164,0.2572771947440069,5348.0,0.57972324,0.1880039810696077,2472.0,5776.205506324768,6441.831294298172,5776.205506324768,665.1103904247284,0.1945838928222656,0.0 -6800,2.4808457,1.7959282,,,,,,,,,,,,,, -6900,2.1928933,1.7964357,,,,,,,,,,,,,, -7000,1.8763452,1.7138916,,,,,,,,,,,,,, -7100,2.2030602,1.7063336,,,,,,,,,,,,,, -7200,4.2791867,1.8392378,,,,,,,,,,,,,, -7300,2.5626054,1.7009364,,,,,,,,,,,,,, -7400,3.7711563,1.7559216,,,,,,,,,,,,,, -7500,3.101652,1.8026377,,,,,,,,,,,,,, -7600,2.8949182,1.860974,,,,,,,,,,,,,, -7700,3.1233864,1.7908356,,,,,,,,,,,,,, -7800,2.0046806,1.8395942,,,,,,,,,,,,,, -7900,3.5519202,1.7347556,,,,,,,,,,,,,, -8000,2.7388494,1.7075418,,,,,,,,,,,,,, -8100,2.9635615,1.7207884,,,,,,,,,,,,,, -8200,4.07396,1.7597922,,,,,,,,,,,,,, -8300,3.075035,1.6528456,,,,,,,,,,,,,, -8400,3.0499344,1.7434822,,,,,,,,,,,,,, -8500,2.788544,1.7268461,,,,,,,,,,,,,, -8513,,,0.43911177,0.1493955322637619,0.7938834,0.2276953377680373,5348.0,0.49849275,0.1604817906688603,2472.0,7216.254794597626,8017.511333227158,7216.254794597626,800.5987780094147,0.255300760269165,0.0 -8600,3.9598296,1.7459836,,,,,,,,,,,,,, -8700,3.5424201,1.775305,,,,,,,,,,,,,, -8800,2.9670234,1.6923417,,,,,,,,,,,,,, -8900,2.0752966,1.6654073,,,,,,,,,,,,,, -9000,2.2399714,1.6832379,,,,,,,,,,,,,, -9100,2.5955865,1.7399364,,,,,,,,,,,,,, -9200,2.5920846,1.6417301,,,,,,,,,,,,,, -9300,4.168919,1.7379489,,,,,,,,,,,,,, -9400,2.7783413,1.7048513,,,,,,,,,,,,,, -9500,4.4829426,1.6867803,,,,,,,,,,,,,, -9600,2.8203702,1.639288,,,,,,,,,,,,,, -9700,1.8278894,1.6264637,,,,,,,,,,,,,, -9800,3.313238,1.5835115,,,,,,,,,,,,,, -9900,2.2768981,1.7779745,,,,,,,,,,,,,, -10000,2.5167744,1.7427665,,,,,,,,,,,,,, -10100,2.2713459,1.6671922,,,,,,,,,,,,,, -10200,2.4770362,1.6871561,,,,,,,,,,,,,, -10207,,,0.39790037,0.134293730165625,0.73437,0.2113307008312656,5348.0,0.46262303,0.147421444965775,2472.0,8656.777244329453,9595.573912382126,8656.777244329453,938.0026862621309,0.3100192546844482,0.0 -10300,1.9463332,1.655946,,,,,,,,,,,,,, -10400,3.301753,1.699263,,,,,,,,,,,,,, -10500,3.6499438,1.6415735,,,,,,,,,,,,,, -10600,3.303678,1.6774098,,,,,,,,,,,,,, -10700,2.9302452,1.6620864,,,,,,,,,,,,,, -10800,2.8093472,1.6645495,,,,,,,,,,,,,, -10900,3.141614,1.6915654,,,,,,,,,,,,,, -11000,4.458789,1.6553327,,,,,,,,,,,,,, -11100,3.0642695,1.6627883,,,,,,,,,,,,,, -11200,2.5963194,1.6578761,,,,,,,,,,,,,, -11300,3.702911,1.64634,,,,,,,,,,,,,, -11400,4.426281,1.6078613,,,,,,,,,,,,,, -11500,2.806781,1.6494846,,,,,,,,,,,,,, -11600,3.9110131,1.614607,,,,,,,,,,,,,, -11700,2.5035765,1.5949067,,,,,,,,,,,,,, -11800,2.9803472,1.6480671,,,,,,,,,,,,,, -11900,4.4930625,1.5764287,,,,,,,,,,,,,, -11906,,,0.40488553,0.136518069681887,0.7124341,0.2063199358931036,5348.0,0.43567663,0.1397436678650498,2472.0,10096.870054721832,11170.25594997406,10096.870054721832,1072.4567618370056,0.3640625476837158,0.0 -12000,2.73928,1.6832291,,,,,,,,,,,,,, -12100,2.494835,1.650966,,,,,,,,,,,,,, -12200,2.2923958,1.6383733,,,,,,,,,,,,,, -12300,2.50641,1.5921246,,,,,,,,,,,,,, -12400,3.519066,1.5959284,,,,,,,,,,,,,, -12500,3.860552,1.6386479,,,,,,,,,,,,,, -12600,4.518633,1.6366655,,,,,,,,,,,,,, -12700,2.8296633,1.6594361,,,,,,,,,,,,,, -12800,2.5375545,1.593838,,,,,,,,,,,,,, -12900,3.1227314,1.6103414,,,,,,,,,,,,,, -13000,3.010142,1.6624705,,,,,,,,,,,,,, -13100,2.1009066,1.631535,,,,,,,,,,,,,, -13200,2.429522,1.6083891,,,,,,,,,,,,,, -13300,3.0392537,1.6318868,,,,,,,,,,,,,, -13400,2.765835,1.511598,,,,,,,,,,,,,, -13500,2.4522202,1.5702326,,,,,,,,,,,,,, -13599,,,0.37492138,0.1258041847568515,0.6971934,0.201154696505981,5348.0,0.41858706,0.135376678244267,2472.0,11537.339000701904,12744.704414606094,11537.339000701904,1206.3002724647522,0.4186415672302246,0.0 -13600,3.664741,1.5480185,,,,,,,,,,,,,, -13700,2.152839,1.6088296,,,,,,,,,,,,,, -13800,2.3888843,1.6151681,,,,,,,,,,,,,, -13900,3.1382496,1.595799,,,,,,,,,,,,,, -14000,2.6993544,1.5836291,,,,,,,,,,,,,, -14100,3.2437093,1.5986035,,,,,,,,,,,,,, -14200,2.3876998,1.6223152,,,,,,,,,,,,,, -14300,2.82371,1.5778893,,,,,,,,,,,,,, -14400,3.6508288,1.5864832,,,,,,,,,,,,,, -14500,3.3466165,1.5666265,,,,,,,,,,,,,, -14600,3.468916,1.5665449,,,,,,,,,,,,,, -14700,2.2468443,1.5226244,,,,,,,,,,,,,, -14800,2.9167147,1.5260861,,,,,,,,,,,,,, -14900,2.6470413,1.5398456,,,,,,,,,,,,,, -15000,3.2786102,1.5212469,,,,,,,,,,,,,, -15100,2.4634438,1.6219181,,,,,,,,,,,,,, -15200,3.1215754,1.5527905,,,,,,,,,,,,,, -15296,,,0.36053848,0.1228070175438596,0.6717796,0.1930834065477857,5348.0,0.40154976,0.131781528649483,2472.0,12977.885040283203,14324.161016225817,12977.885040283203,1345.0807433128357,0.4681658744812011,0.0 -15300,2.9710436,1.6879253,,,,,,,,,,,,,, -15400,2.4895377,1.5852952,,,,,,,,,,,,,, -15500,2.8079557,1.5001847,,,,,,,,,,,,,, -15600,2.3839052,1.4941614,,,,,,,,,,,,,, -15700,2.126052,1.5599728,,,,,,,,,,,,,, -15800,2.5181727,1.5604299,,,,,,,,,,,,,, -15900,2.8993,1.6185904,,,,,,,,,,,,,, -16000,3.8049018,1.5660443,,,,,,,,,,,,,, -16100,3.3118649,1.5702677,,,,,,,,,,,,,, -16200,2.8066547,1.5681134,,,,,,,,,,,,,, -16300,2.7096183,1.5824554,,,,,,,,,,,,,, -16400,2.5668995,1.5611279,,,,,,,,,,,,,, -16500,2.9341285,1.5717324,,,,,,,,,,,,,, -16600,3.0075939,1.5483744,,,,,,,,,,,,,, -16700,2.6589692,1.5146313,,,,,,,,,,,,,, -16800,2.388733,1.5593431,,,,,,,,,,,,,, -16900,1.9144622,1.500097,,,,,,,,,,,,,, -17000,2.7854555,1.4885246,,,,,,,,,,,,,, -17012,,,0.38336515,0.122435304253486,0.6500841,0.1883815905075451,5348.0,0.38370186,0.1228038104523388,2472.0,14418.47855091095,15901.694470643995,14418.47855091095,1481.8800811767578,0.5250308513641357,0.0 -17100,2.8957422,1.5487267,,,,,,,,,,,,,, -17200,4.182768,1.5617752,,,,,,,,,,,,,, -17300,2.2651458,1.5456504,,,,,,,,,,,,,, -17400,2.867671,1.5834591,,,,,,,,,,,,,, -17500,4.66451,1.6200359,,,,,,,,,,,,,, -17600,2.6686695,1.596133,,,,,,,,,,,,,, -17700,2.8410318,1.5567306,,,,,,,,,,,,,, -17800,3.614726,1.468256,,,,,,,,,,,,,, -17900,4.0350657,1.5577794,,,,,,,,,,,,,, -18000,2.9739769,1.6136103,,,,,,,,,,,,,, -18100,2.7989297,1.5657923,,,,,,,,,,,,,, -18200,3.6061199,1.5517024,,,,,,,,,,,,,, -18300,1.9988503,1.4513059,,,,,,,,,,,,,, -18400,2.4001935,1.5417656,,,,,,,,,,,,,, -18500,2.962035,1.5033197,,,,,,,,,,,,,, -18600,2.1756074,1.5619773,,,,,,,,,,,,,, -18700,,,0.2896522,0.0982423036692167,0.62595093,0.1799820423453083,5348.0,0.37486923,0.1191477261186602,2472.0,15858.525933265686,17481.879634141922,15858.525933265686,1621.8828246593475,0.5793395042419434,0.0 -18700,3.5677135,1.4924085,,,,,,,,,,,,,, -18800,3.2278683,1.4753143,,,,,,,,,,,,,, -18900,3.5787072,1.5458461,,,,,,,,,,,,,, -19000,1.977475,1.4853152,,,,,,,,,,,,,, -19100,3.4148664,1.4912486,,,,,,,,,,,,,, -19200,2.1704147,1.5587387,,,,,,,,,,,,,, -19300,3.697057,1.4843822,,,,,,,,,,,,,, -19400,2.7033172,1.4991494,,,,,,,,,,,,,, -19500,3.0167706,1.5326363,,,,,,,,,,,,,, -19600,3.13887,1.4952708,,,,,,,,,,,,,, -19700,2.2046494,1.514614,,,,,,,,,,,,,, -19800,4.274617,1.5327259,,,,,,,,,,,,,, -19900,2.3961875,1.4582968,,,,,,,,,,,,,, -20000,2.531922,1.4986917,,,,,,,,,,,,,, -20100,3.0864315,1.5398493,,,,,,,,,,,,,, -20200,2.9011505,1.4966437,,,,,,,,,,,,,, -20300,1.6463615,1.4764717,,,,,,,,,,,,,, -20400,2.7095363,1.4814436,,,,,,,,,,,,,, -20401,,,0.3050606,0.10250043128911,0.6120852,0.1774428685905172,5348.0,0.36329266,0.1185993134686084,2472.0,17298.88808941841,19057.16420722008,17298.88808941841,1756.674932718277,0.6285579204559326,0.0 -20500,4.101335,1.5208206,,,,,,,,,,,,,, -20600,2.630385,1.4750924,,,,,,,,,,,,,, -20700,3.0326033,1.4809265,,,,,,,,,,,,,, -20800,2.7347057,1.4383633,,,,,,,,,,,,,, -20900,2.7722604,1.471301,,,,,,,,,,,,,, -21000,2.4489043,1.493403,,,,,,,,,,,,,, -21100,2.4379237,1.4800649,,,,,,,,,,,,,, -21200,3.3666825,1.4729306,,,,,,,,,,,,,, -21300,2.4455557,1.4258721,,,,,,,,,,,,,, -21400,3.6802979,1.4520713,,,,,,,,,,,,,, -21500,2.5842896,1.4067769,,,,,,,,,,,,,, -21600,2.0005949,1.5119013,,,,,,,,,,,,,, -21700,4.3797026,1.4255196,,,,,,,,,,,,,, -21800,2.9456322,1.4986323,,,,,,,,,,,,,, -21900,2.870328,1.4857143,,,,,,,,,,,,,, -22000,3.2334774,1.5287373,,,,,,,,,,,,,, -22100,2.9898863,1.4531872,,,,,,,,,,,,,, -22128,,,0.4048885,0.13275771099435,0.60599583,0.1750871332438668,5348.0,0.34999764,0.1121199195661446,2472.0,18739.611248254776,20632.98897361756,18739.611248254776,1891.644542694092,0.6785871982574463,0.0 -22200,3.2560384,1.5522524,,,,,,,,,,,,,, -22300,2.4141176,1.4490393,,,,,,,,,,,,,, -22400,5.111145,1.4428092,,,,,,,,,,,,,, -22500,5.005746,1.4978529,,,,,,,,,,,,,, -22600,2.5679362,1.4363613,,,,,,,,,,,,,, -22700,3.9366658,1.4421954,,,,,,,,,,,,,, -22800,3.1118817,1.4329294,,,,,,,,,,,,,, -22900,2.3411984,1.4475522,,,,,,,,,,,,,, -23000,2.7785954,1.4259155,,,,,,,,,,,,,, -23100,3.5756402,1.5236455,,,,,,,,,,,,,, -23200,2.8709404,1.4185122,,,,,,,,,,,,,, -23300,2.2823193,1.4562128,,,,,,,,,,,,,, -23400,3.1969547,1.4483848,,,,,,,,,,,,,, -23500,2.8826437,1.4659449,,,,,,,,,,,,,, -23600,3.9376132,1.4521991,,,,,,,,,,,,,, -23700,3.5413816,1.4087608,,,,,,,,,,,,,, -23800,3.0242124,1.3925627,,,,,,,,,,,,,, -23829,,,0.4211561,0.1378736712524745,0.5772926,0.1673344468366529,5348.0,0.33509344,0.108057603639835,2472.0,20180.07414293289,22207.29289650917,20180.07414293289,2025.3478963375087,0.7357115745544434,0.0 -23900,3.754793,1.3870438,,,,,,,,,,,,,, -24000,3.2698746,1.4512682,,,,,,,,,,,,,, -24100,2.48718,1.4470279,,,,,,,,,,,,,, -24200,2.5714905,1.4583924,,,,,,,,,,,,,, -24300,2.903913,1.4906241,,,,,,,,,,,,,, -24400,2.6613584,1.4892329,,,,,,,,,,,,,, -24500,2.7999434,1.4461552,,,,,,,,,,,,,, -24600,4.2911334,1.4583703,,,,,,,,,,,,,, -24700,3.9415226,1.4128215,,,,,,,,,,,,,, -24800,2.5461218,1.3968471,,,,,,,,,,,,,, -24900,2.5859835,1.4063207,,,,,,,,,,,,,, -25000,2.6048844,1.4489836,,,,,,,,,,,,,, -25100,5.7260566,1.4362607,,,,,,,,,,,,,, -25200,3.8288953,1.4345626,,,,,,,,,,,,,, -25300,2.7222629,1.4330252,,,,,,,,,,,,,, -25400,3.2287257,1.3572973,,,,,,,,,,,,,, -25500,2.5993211,1.4892944,,,,,,,,,,,,,, -25527,,,0.4618884,0.1521776963059216,0.565899,0.1642932311227396,5348.0,0.32599014,0.1042187150894725,2472.0,21620.75711417198,23780.77961206436,21620.75711417198,2158.018583536148,0.7869384288787842,0.0 -25600,2.6100874,1.4375243,,,,,,,,,,,,,, -25700,3.2312822,1.4398279,,,,,,,,,,,,,, -25800,2.5644047,1.431405,,,,,,,,,,,,,, -25900,4.015628,1.3953468,,,,,,,,,,,,,, -26000,2.7190623,1.4139873,,,,,,,,,,,,,, -26100,2.2399137,1.4133027,,,,,,,,,,,,,, -26200,3.3510876,1.361223,,,,,,,,,,,,,, -26300,2.5186348,1.3923649,,,,,,,,,,,,,, -26400,2.0536394,1.3900402,,,,,,,,,,,,,, -26500,2.6105702,1.4068995,,,,,,,,,,,,,, -26600,2.933124,1.4202402,,,,,,,,,,,,,, -26700,2.9593585,1.3602535,,,,,,,,,,,,,, -26800,1.8752804,1.361675,,,,,,,,,,,,,, -26900,3.0810623,1.4378518,,,,,,,,,,,,,, -27000,2.566835,1.4110482,,,,,,,,,,,,,, -27100,2.496329,1.4220201,,,,,,,,,,,,,, -27200,3.8402843,1.4042451,,,,,,,,,,,,,, -27250,,,0.39754912,0.1286418593731658,0.54623437,0.1589445533274761,5348.0,0.31674242,0.1024922308207909,2472.0,23061.056136369705,25352.06184864044,23061.056136369705,2288.8654062747955,0.8404562473297119,0.0 -27300,2.937834,1.4572355,,,,,,,,,,,,,, -27400,2.2620273,1.4196985,,,,,,,,,,,,,, -27500,3.1237307,1.3708243,,,,,,,,,,,,,, -27600,3.3323011,1.3685917,,,,,,,,,,,,,, -27700,2.5461977,1.3769262,,,,,,,,,,,,,, -27800,3.2582662,1.3428555,,,,,,,,,,,,,, -27900,2.697706,1.3964276,,,,,,,,,,,,,, -28000,2.120923,1.3936343,,,,,,,,,,,,,, -28100,4.5454006,1.3598071,,,,,,,,,,,,,, -28200,3.7280488,1.3871226,,,,,,,,,,,,,, -28300,1.9739282,1.3071308,,,,,,,,,,,,,, -28400,3.7392251,1.3628234,,,,,,,,,,,,,, -28500,3.7392092,1.3897358,,,,,,,,,,,,,, -28600,3.917198,1.3890234,,,,,,,,,,,,,, -28700,2.249738,1.3670448,,,,,,,,,,,,,, -28800,2.8687603,1.3417641,,,,,,,,,,,,,, -28900,2.358896,1.3637599,,,,,,,,,,,,,, -28939,,,0.35738322,0.118875696087301,0.5297782,0.1531807254506309,5348.0,0.30347437,0.0970284158999045,2472.0,24501.34368991852,26927.3649828434,24501.34368991852,2423.746869325638,0.89430832862854,0.0 -29000,1.8676637,1.3661119,,,,,,,,,,,,,, -29100,3.983415,1.3793117,,,,,,,,,,,,,, -29200,3.6533,1.2854764,,,,,,,,,,,,,, -29300,2.5849962,1.34087,,,,,,,,,,,,,, -29400,1.8566076,1.3722452,,,,,,,,,,,,,, -29500,2.8514307,1.3697015,,,,,,,,,,,,,, -29600,2.7276182,1.3644092,,,,,,,,,,,,,, -29700,2.4500356,1.3359473,,,,,,,,,,,,,, -29800,2.9788196,1.3218669,,,,,,,,,,,,,, -29900,2.7728136,1.332149,,,,,,,,,,,,,, -30000,3.1608193,1.2727237,,,,,,,,,,,,,, -30100,2.7489796,1.3001032,,,,,,,,,,,,,, -30200,2.103997,1.388659,,,,,,,,,,,,,, -30300,2.1564074,1.3099694,,,,,,,,,,,,,, -30400,2.9006083,1.3497727,,,,,,,,,,,,,, -30500,2.8699663,1.2960324,,,,,,,,,,,,,, -30600,2.4399881,1.2950621,,,,,,,,,,,,,, -30640,,,0.3072533,0.1059137139551105,0.51826257,0.1500526178591772,5348.0,0.29775858,0.0966831190461682,2472.0,25941.561491966248,28503.707077503204,25941.561491966248,2559.73819565773,0.9447612762451172,0.0 -30700,2.314336,1.3608391,,,,,,,,,,,,,, -30800,2.842539,1.3447365,,,,,,,,,,,,,, -30900,2.247893,1.3124605,,,,,,,,,,,,,, -31000,2.2638152,1.3225131,,,,,,,,,,,,,, -31100,2.3769522,1.2922933,,,,,,,,,,,,,, -31200,2.0915263,1.3170756,,,,,,,,,,,,,, -31300,1.9388384,1.2987667,,,,,,,,,,,,,, -31400,2.582011,1.3147187,,,,,,,,,,,,,, -31500,2.4222877,1.3066286,,,,,,,,,,,,,, -31600,2.3449764,1.3406321,,,,,,,,,,,,,, -31700,3.56109,1.3675774,,,,,,,,,,,,,, -31800,2.3698008,1.3229058,,,,,,,,,,,,,, -31900,3.0970805,1.3231156,,,,,,,,,,,,,, -32000,2.4856644,1.2808065,,,,,,,,,,,,,, -32100,3.0509193,1.3712579,,,,,,,,,,,,,, -32200,5.8060527,1.3084564,,,,,,,,,,,,,, -32300,2.5281909,1.3070939,,,,,,,,,,,,,, -32345,,,0.34496155,0.1164896053751008,0.49741358,0.1455921681454376,5348.0,0.2845581,0.0911380578067556,2472.0,27381.578418970108,30078.26686859131,27381.578418970108,2694.1472787857056,0.9967937469482422,0.0 -32400,2.5207632,1.3471272,,,,,,,,,,,,,, -32500,2.1532156,1.3225911,,,,,,,,,,,,,, -32600,3.4017918,1.3275708,,,,,,,,,,,,,, -32700,2.7316275,1.2832481,,,,,,,,,,,,,, -32800,2.848236,1.3409303,,,,,,,,,,,,,, -32900,3.702418,1.2701001,,,,,,,,,,,,,, -33000,2.3021405,1.2981951,,,,,,,,,,,,,, -33100,2.0874834,1.3281054,,,,,,,,,,,,,, -33200,2.3649082,1.3057482,,,,,,,,,,,,,, -33300,1.9696074,1.2062552,,,,,,,,,,,,,, -33400,2.5407135,1.2911351,,,,,,,,,,,,,, -33500,3.2930427,1.2938709,,,,,,,,,,,,,, -33600,2.6658816,1.2982596,,,,,,,,,,,,,, -33700,2.9702592,1.3010864,,,,,,,,,,,,,, -33800,3.030579,1.2860488,,,,,,,,,,,,,, -33900,3.0203273,1.3381898,,,,,,,,,,,,,, -34000,3.0168388,1.2854934,,,,,,,,,,,,,, -34019,,,0.2937716,0.0989305956582183,0.48239866,0.1393938808808905,5348.0,0.26872075,0.0860195397396055,2472.0,28822.132719278336,31655.35311293602,28822.132719278336,2830.545342445373,1.0511739253997805,0.0 -34100,2.2032485,1.2444124,,,,,,,,,,,,,, -34200,2.6393793,1.2330501,,,,,,,,,,,,,, -34300,1.8955432,1.2679648,,,,,,,,,,,,,, -34400,2.8831177,1.2543284,,,,,,,,,,,,,, -34500,3.65722,1.3013673,,,,,,,,,,,,,, -34600,3.3937953,1.2945179,,,,,,,,,,,,,, -34700,2.4194088,1.2272928,,,,,,,,,,,,,, -34800,3.768267,1.2559441,,,,,,,,,,,,,, -34900,2.318958,1.2795663,,,,,,,,,,,,,, -35000,2.9618537,1.3061415,,,,,,,,,,,,,, -35100,3.1833115,1.29024,,,,,,,,,,,,,, -35200,2.7391474,1.2214637,,,,,,,,,,,,,, -35300,3.1880586,1.2424357,,,,,,,,,,,,,, -35400,3.047967,1.2820332,,,,,,,,,,,,,, -35500,2.470822,1.2104884,,,,,,,,,,,,,, -35600,2.894937,1.2569494,,,,,,,,,,,,,, -35700,3.7578042,1.211461,,,,,,,,,,,,,, -35722,,,0.279953,0.096801433933893,0.4692077,0.136082334881296,5348.0,0.25975537,0.0832977880689781,2472.0,30262.328066825867,33230.11437892914,30262.328066825867,2964.973935842514,1.1049518585205078,0.0 -35800,3.1759045,1.2493587,,,,,,,,,,,,,, -35900,2.8517613,1.2892147,,,,,,,,,,,,,, -36000,1.9770558,1.2125905,,,,,,,,,,,,,, -36100,4.5915203,1.2524551,,,,,,,,,,,,,, -36200,3.636261,1.2121125,,,,,,,,,,,,,, -36300,2.5889754,1.2380694,,,,,,,,,,,,,, -36400,2.4797432,1.2899793,,,,,,,,,,,,,, -36500,2.9296799,1.2567618,,,,,,,,,,,,,, -36600,2.1254299,1.2302877,,,,,,,,,,,,,, -36700,3.1383114,1.2073481,,,,,,,,,,,,,, -36800,2.362559,1.2108973,,,,,,,,,,,,,, -36900,2.5411634,1.2601825,,,,,,,,,,,,,, -37000,2.622844,1.2765024,,,,,,,,,,,,,, -37100,2.635074,1.1778418,,,,,,,,,,,,,, -37200,2.2466412,1.2423297,,,,,,,,,,,,,, -37300,2.801066,1.2743433,,,,,,,,,,,,,, -37400,3.959563,1.1685213,,,,,,,,,,,,,, -37434,,,0.27683035,0.0935626034620739,0.4581124,0.1331666296571632,5348.0,0.2533827,0.0805760363983507,2472.0,31702.204449653625,34806.76887631416,31702.204449653625,3101.6122257709503,1.1635775566101074,0.0 -37500,2.6878188,1.2181299,,,,,,,,,,,,,, -37600,3.5668614,1.2357734,,,,,,,,,,,,,, -37700,8.556174,1.2332109,,,,,,,,,,,,,, -37800,2.7993314,1.1861838,,,,,,,,,,,,,, -37900,4.5939493,1.2823012,,,,,,,,,,,,,, -38000,3.4227433,1.180999,,,,,,,,,,,,,, -38100,3.1512191,1.1909723,,,,,,,,,,,,,, -38200,2.9847567,1.1990187,,,,,,,,,,,,,, -38300,5.477371,1.2163244,,,,,,,,,,,,,, -38400,7.9594026,1.2365594,,,,,,,,,,,,,, -38500,2.7255363,1.2110314,,,,,,,,,,,,,, -38600,2.824238,1.2249017,,,,,,,,,,,,,, -38700,2.4260502,1.2307922,,,,,,,,,,,,,, -38800,5.136106,1.2033522,,,,,,,,,,,,,, -38900,2.989393,1.2105091,,,,,,,,,,,,,, -39000,2.1404002,1.1740538,,,,,,,,,,,,,, -39100,3.4038367,1.168159,,,,,,,,,,,,,, -39117,,,0.27329835,0.0938528237962617,0.44634154,0.1302316151269104,5348.0,0.24226296,0.0769808868035667,2472.0,33142.71718811989,36381.66338968277,33142.71718811989,3235.8583705425262,1.2178065776824951,0.0 -39200,3.2937217,1.2232159,,,,,,,,,,,,,, -39300,3.0128603,1.1935414,,,,,,,,,,,,,, -39400,6.380681,1.2172344,,,,,,,,,,,,,, -39500,3.0744069,1.1628574,,,,,,,,,,,,,, -39600,3.4957867,1.1978984,,,,,,,,,,,,,, -39700,2.5959404,1.1714724,,,,,,,,,,,,,, -39800,2.9715135,1.2135277,,,,,,,,,,,,,, -39900,2.40512,1.1905068,,,,,,,,,,,,,, -40000,3.3474402,1.1835886,,,,,,,,,,,,,, -40100,3.048717,1.2453648,,,,,,,,,,,,,, -40200,4.2973633,1.1713176,,,,,,,,,,,,,, -40300,4.7604065,1.1803141,,,,,,,,,,,,,, -40400,3.6350887,1.1720703,,,,,,,,,,,,,, -40500,3.7147717,1.1520188,,,,,,,,,,,,,, -40600,2.961699,1.1507516,,,,,,,,,,,,,, -40700,3.0578434,1.205605,,,,,,,,,,,,,, -40800,3.682955,1.2274446,,,,,,,,,,,,,, -40807,,,0.26125973,0.0869634244483009,0.43017662,0.1248539733724668,5348.0,0.23748754,0.0763106046757256,2472.0,34583.25195026398,37956.318457603455,34583.25195026398,3369.8367607593536,1.2777178287506104,0.0 -40900,2.431764,1.1633621,,,,,,,,,,,,,, -41000,2.637277,1.1768701,,,,,,,,,,,,,, -41100,3.4670694,1.1150914,,,,,,,,,,,,,, -41200,3.4691405,1.1138252,,,,,,,,,,,,,, -41300,3.2590442,1.1745563,,,,,,,,,,,,,, -41400,3.1143248,1.1556953,,,,,,,,,,,,,, -41500,3.0967407,1.1516658,,,,,,,,,,,,,, -41600,3.850338,1.2085427,,,,,,,,,,,,,, -41700,3.557001,1.1566926,,,,,,,,,,,,,, -41800,2.442137,1.1662892,,,,,,,,,,,,,, -41900,2.7947695,1.1545721,,,,,,,,,,,,,, -42000,5.352449,1.1943983,,,,,,,,,,,,,, -42100,3.4136624,1.2176065,,,,,,,,,,,,,, -42200,2.976207,1.1185426,,,,,,,,,,,,,, -42300,2.4292557,1.1731488,,,,,,,,,,,,,, -42400,3.6425836,1.141552,,,,,,,,,,,,,, -42493,,,0.23121312,0.0805833849035775,0.42062643,0.1218706855769137,5348.0,0.23134708,0.0739747730180976,2472.0,36023.209208488464,39530.97588849068,36023.209208488464,3504.3984801769257,1.3354296684265137,0.0 -42500,2.6491275,1.1920593,,,,,,,,,,,,,, -42600,4.7814975,1.1586523,,,,,,,,,,,,,, -42700,2.7747948,1.1695793,,,,,,,,,,,,,, -42800,2.3171237,1.168946,,,,,,,,,,,,,, -42900,3.6794543,1.155412,,,,,,,,,,,,,, -43000,3.2302172,1.1432419,,,,,,,,,,,,,, -43100,2.4710715,1.159032,,,,,,,,,,,,,, -43200,2.7840326,1.159867,,,,,,,,,,,,,, -43300,3.2584362,1.1357551,,,,,,,,,,,,,, -43400,4.0827975,1.1221217,,,,,,,,,,,,,, -43500,2.5922272,1.2080395,,,,,,,,,,,,,, -43600,2.9375982,1.1584022,,,,,,,,,,,,,, -43700,5.708089,1.148655,,,,,,,,,,,,,, -43800,2.1318371,1.0845739,,,,,,,,,,,,,, -43900,3.7480416,1.116667,,,,,,,,,,,,,, -44000,3.0576224,1.1701664,,,,,,,,,,,,,, -44100,2.9245527,1.1172861,,,,,,,,,,,,,, -44175,,,0.2266846,0.0783503848103599,0.4167648,0.1213589889647315,5348.0,0.22691439,0.0724107813864684,2472.0,37463.483968019485,41106.83026814461,37463.483968019485,3639.836142063141,1.39624285697937,0.0 -44200,2.5622544,1.0905747,,,,,,,,,,,,,, -44300,4.626583,1.1768677,,,,,,,,,,,,,, -44400,3.254254,1.152445,,,,,,,,,,,,,, -44500,3.5006914,1.1734096,,,,,,,,,,,,,, -44600,3.57862,1.1513116,,,,,,,,,,,,,, -44700,2.6641636,1.1291035,,,,,,,,,,,,,, -44800,2.4231443,1.1372465,,,,,,,,,,,,,, -44900,3.4367845,1.161101,,,,,,,,,,,,,, -45000,4.8015075,1.1569929,,,,,,,,,,,,,, -45100,3.3575182,1.1273707,,,,,,,,,,,,,, -45200,2.209318,1.1423604,,,,,,,,,,,,,, -45300,3.4082947,1.1311226,,,,,,,,,,,,,, -45400,3.8962123,1.1321799,,,,,,,,,,,,,, -45500,2.320548,1.1579467,,,,,,,,,,,,,, -45600,2.7969189,1.1734473,,,,,,,,,,,,,, -45700,2.9439492,1.1002165,,,,,,,,,,,,,, -45800,2.2257,1.1605223,,,,,,,,,,,,,, -45875,,,0.21549411,0.07470870837705,0.4130225,0.1205962713729882,5348.0,0.2248237,0.0717811224178904,2472.0,38903.80582237244,42680.11693024635,38903.80582237244,3772.66555261612,1.4497272968292236,0.0 -45900,4.532062,1.1292852,,,,,,,,,,,,,, -46000,2.6489,1.1104892,,,,,,,,,,,,,, -46100,3.292837,1.1771842,,,,,,,,,,,,,, -46200,3.9841323,1.1027213,,,,,,,,,,,,,, -46300,3.1835048,1.0938745,,,,,,,,,,,,,, -46400,5.083072,1.1468396,,,,,,,,,,,,,, -46500,2.9898539,1.1967065,,,,,,,,,,,,,, -46600,3.0111947,1.1047055,,,,,,,,,,,,,, -46700,4.018654,1.1693646,,,,,,,,,,,,,, -46800,3.8248873,1.1298561,,,,,,,,,,,,,, -46900,2.9096513,1.1734431,,,,,,,,,,,,,, -47000,2.8059583,1.1628114,,,,,,,,,,,,,, -47100,3.6109707,1.1617988,,,,,,,,,,,,,, -47200,5.6517057,1.1211638,,,,,,,,,,,,,, -47300,2.2601626,1.1427479,,,,,,,,,,,,,, -47400,3.9663393,1.1512464,,,,,,,,,,,,,, -47500,2.3106203,1.145401,,,,,,,,,,,,,, -47567,,,0.22340734,0.0765954799448001,0.413083,0.1203452503934271,5348.0,0.22458875,0.0714764487234172,2472.0,40344.11418509483,44254.18384337425,40344.11418509483,3906.2811391353607,1.5107195377349854,0.0 -47600,2.8014867,1.1014278,,,,,,,,,,,,,, -47700,3.8437896,1.130714,,,,,,,,,,,,,, -47800,3.237839,1.129389,,,,,,,,,,,,,, -47900,2.816023,1.126314,,,,,,,,,,,,,, -48000,,,0.2246082,0.0775943396226415,0.41315985,0.1203452503934271,5348.0,0.22453754,0.0712936445067332,2472.0,40700.30122280121,44744.03060722351,40700.30122280121,4039.852513551712,1.5739483833312988,0.0 -48000,,,,,,,,,,,40700.30122280121,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 6aaab688a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -160.12119221687317,0.0,15.801752090454102,1,0,15.801752090454102,28.256042,2472,2.3442000284362114,175.9230306148529,28.983004,2.33386175402295,28.14912,5348,2.185417612018112 -290.99935603141785,0.0297057628631591,1456.0460293293,1688,0,1456.0460293293,1.3364433,2472,0.374322101029797,1747.1521837711334,1.6165359,0.4345034367663282,1.767953,5348,0.4371240719464746 -425.304594039917,0.0848915576934814,2896.040989637375,3393,0,2896.040989637375,0.71331537,2472,0.2247273170434464,3321.5905945301056,0.9240469,0.2783773278611686,1.0730573,5348,0.2966971431881595 -559.1721725463867,0.1350078582763672,4336.36754155159,5060,0,4336.36754155159,0.6102613,2472,0.1905226169439197,4895.912990808487,0.891339,0.2699624467434197,0.94086546,5348,0.2633403168657134 -696.0269293785095,0.1918809413909912,5776.8087022304535,6747,0,5776.8087022304535,0.57886106,2472,0.1817683261227225,6473.349189996719,0.77726054,0.2364288038052074,0.8973155,5348,0.2494376164592525 -831.1587114334106,0.2507824897766113,7217.110549926758,8445,0,7217.110549926758,0.59052604,2472,0.1910100948550769,8048.926936626434,0.8788927,0.2620743476444855,0.9430648,5348,0.2627127644168107 -966.1767551898956,0.3870253562927246,8657.196389913559,10120,0,8657.196389913559,0.52912104,2472,0.1654784392582211,9624.246983766556,0.711383,0.216222748320544,0.83408403,5348,0.2332853818898018 -1102.631802558899,0.4416139125823974,10097.694558858871,11820,0,10097.694558858871,0.51529604,2472,0.1637113318302764,11201.33668255806,0.7196112,0.2243399922553017,0.83782446,5348,0.2366355465016364 -1238.2635581493378,0.4966261386871338,11537.564745664597,13509,0,11537.564745664597,0.5031106,2472,0.1594459001076513,12776.975717544556,0.68961746,0.2111366964375537,0.8070212,5348,0.2289504426658428 -1384.0197823047638,0.5527534484863281,12977.657584190369,15220,0,12977.657584190369,0.4793862,2472,0.1495541608270875,14362.963431596756,0.44087306,0.1439790155509756,0.77578115,5348,0.218031030054935 -1526.2666816711426,0.6094729900360107,14418.166572093964,16948,0,14418.166572093964,0.46206802,2472,0.148721386062194,15945.859293460846,0.42074928,0.1369093745802465,0.75233626,5348,0.2137443640962762 -1664.0769486427307,0.6706218719482422,15858.704028606417,18622,0,15858.704028606417,0.458045,2472,0.1455730912193041,17524.35138106346,0.41944084,0.1386746170712854,0.7515029,5348,0.213251976790214 -1802.022501707077,0.7339036464691162,17298.691420793533,20312,0,17298.691420793533,0.4342024,2472,0.1382609225519468,19102.427579641346,0.37821147,0.126417685163833,0.7153456,5348,0.2044276238933353 -1937.7433910369875,0.7955629825592041,18738.98100972176,22015,0,18738.98100972176,0.41897473,2472,0.1338533097719009,20678.58448934555,0.38218117,0.1272201019168034,0.69446677,5348,0.1973700725064444 -2073.9834084510803,0.8556327819824219,20178.99764108658,23682,0,20178.99764108658,0.40583488,2472,0.1318221518087461,22254.981443166733,0.35579237,0.1210676985855049,0.69894844,5348,0.1971093968738233 -2209.2528836727142,0.906546115875244,21619.240039110184,25354,0,21619.240039110184,0.38012487,2472,0.1212804419799727,23830.625638246536,0.3752873,0.1209741880406944,0.635908,5348,0.1815654054471552 -2346.485763788224,0.9582936763763428,23059.45344424248,27057,0,23059.45344424248,0.36712766,2472,0.1185993134686084,25408.206853866577,0.32275435,0.1098269556544179,0.6262908,5348,0.1797986039371675 -2484.727519273758,1.015852451324463,24500.440947532654,28739,0,24500.440947532654,0.3531975,2472,0.1150854101923506,26987.57475304604,0.2909292,0.0987520261869914,0.6012382,5348,0.1724996862237755 -2623.338259458542,1.0719337463378906,25941.052482128143,30429,0,25941.052482128143,0.33602273,2472,0.1067982857026791,28566.935750246048,0.27494448,0.0924123634624073,0.58237994,5348,0.1684640412446778 -2760.744163513184,1.128051519393921,27381.33598518372,32123,0,27381.33598518372,0.31771612,2472,0.1033046940060528,30144.7624783516,0.2601245,0.0908591731266149,0.5635271,5348,0.160508607123203 -2897.3103017807007,1.1814250946044922,28821.37881207466,33800,0,28821.37881207466,0.306206,2472,0.0981252412000081,31721.50653815269,0.26487103,0.0886311124678291,0.5352882,5348,0.153209689409811 -3033.224186897278,1.2355103492736816,30261.592493772507,35496,0,30261.592493772507,0.28559756,2472,0.0918286515142282,33297.76939201355,0.24313995,0.0822007972071422,0.5057727,5348,0.1451480541046757 -3171.0469262599945,1.2923777103424072,31701.920878887177,37194,0,31701.920878887177,0.2720657,2472,0.0870554303008144,34876.05977129936,0.21898215,0.0749193774173812,0.48544854,5348,0.1398573042277725 -3309.1948528289795,1.345574140548706,33141.95910906792,38888,0,33141.95910906792,0.2549981,2472,0.0828306217374525,36454.38322663307,0.2012175,0.0686038810344557,0.46407133,5348,0.1329349179837222 -3446.0249009132385,1.402059555053711,34582.12906932831,40597,0,34582.12906932831,0.24118751,2472,0.0770215099628298,38031.52192878723,0.16660897,0.0576332048370914,0.44376504,5348,0.1274510750456182 -3584.731765270233,1.4607226848602295,36022.32275509834,42267,0,36022.32275509834,0.23010056,2472,0.0739341498588345,39610.5627257824,0.16422431,0.0553903991336949,0.4237094,5348,0.1222279077401353 -3721.962750196457,1.517216682434082,37463.19466948509,43960,0,37463.19466948509,0.22181489,2472,0.0717811224178904,41188.80369210243,0.16956198,0.0570378974095783,0.41393122,5348,0.1190225629242013 -3856.989844560623,1.5770831108093262,38903.46937012672,45652,0,38903.46937012672,0.21725857,2472,0.0700140149899457,42764.24932599068,0.14950223,0.0494265207849142,0.4087675,5348,0.1171012869652529 -3994.8145458698273,1.6236886978149414,40343.7261993885,47317,0,40343.7261993885,0.21572423,2472,0.0690999939065261,44342.45700645447,0.16066147,0.0549748414936208,0.4061986,5348,0.1159137646388677 -4131.348959207535,1.6855216026306152,40915.90963792801,48000,0,40915.90963792801,0.21571341,2472,0.06903905916763146,45051.27542638779,0.18988594,0.05882446800023402,0.40646556,5348,0.11600065651640808 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/measurements.csv deleted file mode 100644 index d9c07ca4b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.402025,33.182224,,,,,,,,,,,,,, -1,,,28.983004,2.33386175402295,28.14912,2.185417612018112,5348.0,28.256042,2.3442000284362114,2472.0,15.801752090454102,175.9230306148529,15.801752090454102,160.12119221687317,0.0,0.0 -100,2.679085,5.772139,,,,,,,,,,,,,, -200,2.7632885,4.759522,,,,,,,,,,,,,, -300,2.3266256,3.6058547,,,,,,,,,,,,,, -400,2.1590006,3.2106543,,,,,,,,,,,,,, -500,2.180483,3.0468225,,,,,,,,,,,,,, -600,2.348636,2.8213403,,,,,,,,,,,,,, -700,2.1228402,2.6121035,,,,,,,,,,,,,, -800,2.450852,2.5945804,,,,,,,,,,,,,, -900,3.2810428,2.5915444,,,,,,,,,,,,,, -1000,3.0836682,2.5805156,,,,,,,,,,,,,, -1100,2.1638103,2.372737,,,,,,,,,,,,,, -1200,2.3843753,2.3358305,,,,,,,,,,,,,, -1300,3.7789755,2.3662305,,,,,,,,,,,,,, -1400,2.680866,2.3000336,,,,,,,,,,,,,, -1500,2.172665,2.2394242,,,,,,,,,,,,,, -1600,3.2575545,2.2793574,,,,,,,,,,,,,, -1688,,,1.6165359,0.4345034367663282,1.767953,0.4371240719464746,5348.0,1.3364433,0.374322101029797,2472.0,1456.0460293293,1747.1521837711334,1456.0460293293,290.99935603141785,0.0297057628631591,0.0 -1700,2.4285355,2.215107,,,,,,,,,,,,,, -1800,3.038354,2.1840332,,,,,,,,,,,,,, -1900,3.056319,2.180257,,,,,,,,,,,,,, -2000,5.7863064,2.1576147,,,,,,,,,,,,,, -2100,3.2626357,2.144678,,,,,,,,,,,,,, -2200,5.669501,2.095285,,,,,,,,,,,,,, -2300,4.141226,2.1571527,,,,,,,,,,,,,, -2400,2.0999765,2.1142812,,,,,,,,,,,,,, -2500,3.6034653,2.1122825,,,,,,,,,,,,,, -2600,4.1103745,2.0720584,,,,,,,,,,,,,, -2700,3.40626,2.050753,,,,,,,,,,,,,, -2800,3.6083953,2.0730243,,,,,,,,,,,,,, -2900,2.79783,2.046029,,,,,,,,,,,,,, -3000,1.8778179,2.0617385,,,,,,,,,,,,,, -3100,3.4412627,2.0132868,,,,,,,,,,,,,, -3200,3.4494417,2.0442274,,,,,,,,,,,,,, -3300,3.3756604,2.0936143,,,,,,,,,,,,,, -3393,,,0.9240469,0.2783773278611686,1.0730573,0.2966971431881595,5348.0,0.71331537,0.2247273170434464,2472.0,2896.040989637375,3321.5905945301056,2896.040989637375,425.304594039917,0.0848915576934814,0.0 -3400,2.4839432,2.0671134,,,,,,,,,,,,,, -3500,2.7382452,2.024447,,,,,,,,,,,,,, -3600,4.122017,2.0354633,,,,,,,,,,,,,, -3700,2.8532073,2.0609684,,,,,,,,,,,,,, -3800,4.2972326,1.9819506,,,,,,,,,,,,,, -3900,4.3768244,1.9752278,,,,,,,,,,,,,, -4000,3.6324136,2.0129268,,,,,,,,,,,,,, -4100,2.7815678,1.9922222,,,,,,,,,,,,,, -4200,5.2241344,1.9282914,,,,,,,,,,,,,, -4300,2.3758097,1.9902561,,,,,,,,,,,,,, -4400,2.2320104,1.9671952,,,,,,,,,,,,,, -4500,3.057143,2.0341947,,,,,,,,,,,,,, -4600,2.7976542,2.0119948,,,,,,,,,,,,,, -4700,7.24488,2.0050962,,,,,,,,,,,,,, -4800,3.255035,1.9337038,,,,,,,,,,,,,, -4900,3.2504687,1.8924412,,,,,,,,,,,,,, -5000,2.7521684,2.0069475,,,,,,,,,,,,,, -5060,,,0.891339,0.2699624467434197,0.94086546,0.2633403168657134,5348.0,0.6102613,0.1905226169439197,2472.0,4336.36754155159,4895.912990808487,4336.36754155159,559.1721725463867,0.1350078582763672,0.0 -5100,3.226483,2.0870478,,,,,,,,,,,,,, -5200,2.9891715,1.9791205,,,,,,,,,,,,,, -5300,1.9661472,1.9332005,,,,,,,,,,,,,, -5400,4.7280006,1.9262583,,,,,,,,,,,,,, -5500,4.087809,1.9495469,,,,,,,,,,,,,, -5600,2.793913,1.9593377,,,,,,,,,,,,,, -5700,2.5189416,1.983565,,,,,,,,,,,,,, -5800,2.6883297,1.9103695,,,,,,,,,,,,,, -5900,2.7731576,1.9524264,,,,,,,,,,,,,, -6000,3.896665,1.9514835,,,,,,,,,,,,,, -6100,2.0712268,1.982647,,,,,,,,,,,,,, -6200,3.549993,1.9048715,,,,,,,,,,,,,, -6300,2.9417589,1.8999379,,,,,,,,,,,,,, -6400,2.866662,1.9200985,,,,,,,,,,,,,, -6500,6.46656,2.0359862,,,,,,,,,,,,,, -6600,3.5624003,1.9022435,,,,,,,,,,,,,, -6700,4.9184575,1.9310429,,,,,,,,,,,,,, -6747,,,0.77726054,0.2364288038052074,0.8973155,0.2494376164592525,5348.0,0.57886106,0.1817683261227225,2472.0,5776.8087022304535,6473.349189996719,5776.8087022304535,696.0269293785095,0.1918809413909912,0.0 -6800,2.7176163,1.9436809,,,,,,,,,,,,,, -6900,5.1440053,1.9615953,,,,,,,,,,,,,, -7000,3.9231067,1.8999319,,,,,,,,,,,,,, -7100,2.6946511,1.8099413,,,,,,,,,,,,,, -7200,2.1283677,1.9594963,,,,,,,,,,,,,, -7300,2.7589526,1.8024377,,,,,,,,,,,,,, -7400,3.3087645,1.9079921,,,,,,,,,,,,,, -7500,1.8875259,1.818358,,,,,,,,,,,,,, -7600,3.431,1.9141575,,,,,,,,,,,,,, -7700,3.089834,1.9173094,,,,,,,,,,,,,, -7800,2.7078211,1.9230124,,,,,,,,,,,,,, -7900,3.5810082,1.8298855,,,,,,,,,,,,,, -8000,2.6977072,1.7644106,,,,,,,,,,,,,, -8100,2.499477,1.9230084,,,,,,,,,,,,,, -8200,2.9069834,1.9074574,,,,,,,,,,,,,, -8300,1.6274717,1.8440002,,,,,,,,,,,,,, -8400,3.4830139,2.2143133,,,,,,,,,,,,,, -8445,,,0.8788927,0.2620743476444855,0.9430648,0.2627127644168107,5348.0,0.59052604,0.1910100948550769,2472.0,7217.110549926758,8048.926936626434,7217.110549926758,831.1587114334106,0.2507824897766113,0.0 -8500,1.9773184,1.8565063,,,,,,,,,,,,,, -8600,2.8339155,1.8514671,,,,,,,,,,,,,, -8700,2.1689944,1.84947,,,,,,,,,,,,,, -8800,3.7861598,1.9215717,,,,,,,,,,,,,, -8900,1.8640302,1.7972558,,,,,,,,,,,,,, -9000,2.3174043,1.8224236,,,,,,,,,,,,,, -9100,2.0008965,1.8580443,,,,,,,,,,,,,, -9200,2.6861746,1.7866447,,,,,,,,,,,,,, -9300,3.3915203,1.7835224,,,,,,,,,,,,,, -9400,3.6509902,1.7912948,,,,,,,,,,,,,, -9500,2.4303596,1.9066842,,,,,,,,,,,,,, -9600,2.980489,1.746624,,,,,,,,,,,,,, -9700,2.3174376,1.8691227,,,,,,,,,,,,,, -9800,3.0511239,1.8000435,,,,,,,,,,,,,, -9900,3.8317661,1.8144952,,,,,,,,,,,,,, -10000,3.5085502,1.8967698,,,,,,,,,,,,,, -10100,2.2281725,1.8144652,,,,,,,,,,,,,, -10120,,,0.711383,0.216222748320544,0.83408403,0.2332853818898018,5348.0,0.52912104,0.1654784392582211,2472.0,8657.196389913559,9624.246983766556,8657.196389913559,966.1767551898956,0.3870253562927246,0.0 -10200,2.7257411,1.8571098,,,,,,,,,,,,,, -10300,2.6226394,1.8040175,,,,,,,,,,,,,, -10400,2.8272834,1.8098073,,,,,,,,,,,,,, -10500,5.191779,1.8444442,,,,,,,,,,,,,, -10600,3.6067157,1.8593233,,,,,,,,,,,,,, -10700,2.5331256,1.861956,,,,,,,,,,,,,, -10800,3.031512,1.8340915,,,,,,,,,,,,,, -10900,14.24997,2.2990036,,,,,,,,,,,,,, -11000,2.5948877,1.9076046,,,,,,,,,,,,,, -11100,2.7550325,1.863475,,,,,,,,,,,,,, -11200,2.9268878,1.7585795,,,,,,,,,,,,,, -11300,2.4003637,1.7750382,,,,,,,,,,,,,, -11400,3.8205843,1.8111521,,,,,,,,,,,,,, -11500,3.8867106,1.8445376,,,,,,,,,,,,,, -11600,2.0911806,1.8408862,,,,,,,,,,,,,, -11700,2.5667474,1.7542974,,,,,,,,,,,,,, -11800,2.1948133,1.788257,,,,,,,,,,,,,, -11820,,,0.7196112,0.2243399922553017,0.83782446,0.2366355465016364,5348.0,0.51529604,0.1637113318302764,2472.0,10097.694558858871,11201.33668255806,10097.694558858871,1102.631802558899,0.4416139125823974,0.0 -11900,2.4445379,1.7899476,,,,,,,,,,,,,, -12000,3.103115,1.7575577,,,,,,,,,,,,,, -12100,2.9072938,1.8348047,,,,,,,,,,,,,, -12200,2.3855426,1.7550339,,,,,,,,,,,,,, -12300,1.9582512,1.8392292,,,,,,,,,,,,,, -12400,3.7374265,1.8164203,,,,,,,,,,,,,, -12500,5.171058,1.8344153,,,,,,,,,,,,,, -12600,3.5440097,1.787369,,,,,,,,,,,,,, -12700,2.8996625,1.7143142,,,,,,,,,,,,,, -12800,3.3319342,1.8052906,,,,,,,,,,,,,, -12900,2.0133145,1.7891711,,,,,,,,,,,,,, -13000,2.8044505,1.7630744,,,,,,,,,,,,,, -13100,4.067434,1.8030671,,,,,,,,,,,,,, -13200,2.2785368,1.8735725,,,,,,,,,,,,,, -13300,3.831596,1.7398248,,,,,,,,,,,,,, -13400,4.767216,1.8314523,,,,,,,,,,,,,, -13500,4.2835093,1.7373546,,,,,,,,,,,,,, -13509,,,0.68961746,0.2111366964375537,0.8070212,0.2289504426658428,5348.0,0.5031106,0.1594459001076513,2472.0,11537.564745664597,12776.975717544556,11537.564745664597,1238.2635581493378,0.4966261386871338,0.0 -13600,1.9557632,1.7704852,,,,,,,,,,,,,, -13700,2.9038727,1.786328,,,,,,,,,,,,,, -13800,2.9076817,1.7231454,,,,,,,,,,,,,, -13900,2.7732215,1.8083725,,,,,,,,,,,,,, -14000,2.6018064,1.796545,,,,,,,,,,,,,, -14100,6.036249,1.7502131,,,,,,,,,,,,,, -14200,3.5084174,1.7179226,,,,,,,,,,,,,, -14300,2.3091414,1.778173,,,,,,,,,,,,,, -14400,1.94116,1.7332479,,,,,,,,,,,,,, -14500,2.8619084,1.7795618,,,,,,,,,,,,,, -14600,2.3991668,1.73222,,,,,,,,,,,,,, -14700,2.927577,1.696835,,,,,,,,,,,,,, -14800,2.1981397,1.7597377,,,,,,,,,,,,,, -14900,3.7178814,1.7605896,,,,,,,,,,,,,, -15000,2.269092,1.7499211,,,,,,,,,,,,,, -15100,2.7136462,1.7131618,,,,,,,,,,,,,, -15200,2.1235545,1.7205948,,,,,,,,,,,,,, -15220,,,0.44087306,0.1439790155509756,0.77578115,0.218031030054935,5348.0,0.4793862,0.1495541608270875,2472.0,12977.657584190369,14362.963431596756,12977.657584190369,1384.0197823047638,0.5527534484863281,0.0 -15300,2.1244311,1.7531749,,,,,,,,,,,,,, -15400,2.6752336,1.766789,,,,,,,,,,,,,, -15500,2.6575081,1.7531359,,,,,,,,,,,,,, -15600,1.6609335,1.6920135,,,,,,,,,,,,,, -15700,2.8268652,1.721535,,,,,,,,,,,,,, -15800,3.0482552,1.7460228,,,,,,,,,,,,,, -15900,2.6414835,1.7851447,,,,,,,,,,,,,, -16000,1.6753511,1.6529676,,,,,,,,,,,,,, -16100,2.885196,1.7176877,,,,,,,,,,,,,, -16200,3.7562754,1.7016506,,,,,,,,,,,,,, -16300,3.2309036,1.7337633,,,,,,,,,,,,,, -16400,1.5874014,1.73501,,,,,,,,,,,,,, -16500,2.3049748,1.8194934,,,,,,,,,,,,,, -16600,1.6829145,1.6794695,,,,,,,,,,,,,, -16700,2.8906102,1.6699677,,,,,,,,,,,,,, -16800,2.6293771,1.7170092,,,,,,,,,,,,,, -16900,2.6821408,1.7458239,,,,,,,,,,,,,, -16948,,,0.42074928,0.1369093745802465,0.75233626,0.2137443640962762,5348.0,0.46206802,0.148721386062194,2472.0,14418.166572093964,15945.859293460846,14418.166572093964,1526.2666816711426,0.6094729900360107,0.0 -17000,1.8597531,1.6409502,,,,,,,,,,,,,, -17100,1.7847077,1.6500641,,,,,,,,,,,,,, -17200,1.9909204,1.6845772,,,,,,,,,,,,,, -17300,2.2917953,1.7222109,,,,,,,,,,,,,, -17400,1.902858,1.7222935,,,,,,,,,,,,,, -17500,2.3319652,1.7102234,,,,,,,,,,,,,, -17600,4.2645206,1.6788515,,,,,,,,,,,,,, -17700,2.315838,1.6927559,,,,,,,,,,,,,, -17800,3.7232327,1.6831633,,,,,,,,,,,,,, -17900,2.1863854,1.7206755,,,,,,,,,,,,,, -18000,3.0093179,1.7550792,,,,,,,,,,,,,, -18100,2.9750955,1.8448002,,,,,,,,,,,,,, -18200,1.6647451,1.7788349,,,,,,,,,,,,,, -18300,2.161308,1.7111443,,,,,,,,,,,,,, -18400,2.7273772,1.7701867,,,,,,,,,,,,,, -18500,2.2467082,1.7429008,,,,,,,,,,,,,, -18600,2.0474627,1.6983312,,,,,,,,,,,,,, -18622,,,0.41944084,0.1386746170712854,0.7515029,0.213251976790214,5348.0,0.458045,0.1455730912193041,2472.0,15858.704028606417,17524.35138106346,15858.704028606417,1664.0769486427307,0.6706218719482422,0.0 -18700,3.6217983,1.7134703,,,,,,,,,,,,,, -18800,2.7596323,1.6711714,,,,,,,,,,,,,, -18900,2.3670728,1.6936857,,,,,,,,,,,,,, -19000,2.0139146,1.6889627,,,,,,,,,,,,,, -19100,2.5081844,1.7013142,,,,,,,,,,,,,, -19200,2.7627916,1.7470373,,,,,,,,,,,,,, -19300,2.808157,1.6995893,,,,,,,,,,,,,, -19400,2.9473925,1.6901183,,,,,,,,,,,,,, -19500,3.1359706,1.6885021,,,,,,,,,,,,,, -19600,4.0867267,1.703784,,,,,,,,,,,,,, -19700,2.797865,1.7224311,,,,,,,,,,,,,, -19800,1.6558168,1.6515304,,,,,,,,,,,,,, -19900,1.8963374,1.6313272,,,,,,,,,,,,,, -20000,2.0872047,1.6741861,,,,,,,,,,,,,, -20100,2.7535408,1.703382,,,,,,,,,,,,,, -20200,2.6909451,1.6338544,,,,,,,,,,,,,, -20300,5.098877,1.6403631,,,,,,,,,,,,,, -20312,,,0.37821147,0.126417685163833,0.7153456,0.2044276238933353,5348.0,0.4342024,0.1382609225519468,2472.0,17298.691420793533,19102.427579641346,17298.691420793533,1802.022501707077,0.7339036464691162,0.0 -20400,7.290337,1.7055498,,,,,,,,,,,,,, -20500,3.6134255,1.6812786,,,,,,,,,,,,,, -20600,3.6259065,1.6608372,,,,,,,,,,,,,, -20700,2.2075737,1.702934,,,,,,,,,,,,,, -20800,2.5082467,1.6683165,,,,,,,,,,,,,, -20900,3.1041892,1.7460183,,,,,,,,,,,,,, -21000,2.1619456,1.6237508,,,,,,,,,,,,,, -21100,3.6147165,1.635853,,,,,,,,,,,,,, -21200,2.3735025,1.6452264,,,,,,,,,,,,,, -21300,2.9886012,1.6913419,,,,,,,,,,,,,, -21400,2.0977998,1.6304076,,,,,,,,,,,,,, -21500,3.5858335,1.6914097,,,,,,,,,,,,,, -21600,2.4487755,1.653933,,,,,,,,,,,,,, -21700,2.2446814,1.6288488,,,,,,,,,,,,,, -21800,5.0543327,1.7141923,,,,,,,,,,,,,, -21900,2.3363483,1.6000116,,,,,,,,,,,,,, -22000,2.081191,1.6411935,,,,,,,,,,,,,, -22015,,,0.38218117,0.1272201019168034,0.69446677,0.1973700725064444,5348.0,0.41897473,0.1338533097719009,2472.0,18738.98100972176,20678.58448934555,18738.98100972176,1937.7433910369875,0.7955629825592041,0.0 -22100,3.4874516,1.6533824,,,,,,,,,,,,,, -22200,2.1513636,1.7041137,,,,,,,,,,,,,, -22300,2.1246245,1.6204642,,,,,,,,,,,,,, -22400,3.0804942,1.601451,,,,,,,,,,,,,, -22500,3.3619714,1.6760914,,,,,,,,,,,,,, -22600,2.0688043,1.6343708,,,,,,,,,,,,,, -22700,2.1243107,1.610533,,,,,,,,,,,,,, -22800,2.199071,1.6383865,,,,,,,,,,,,,, -22900,3.3798106,1.603978,,,,,,,,,,,,,, -23000,1.8751509,1.5740948,,,,,,,,,,,,,, -23100,2.998097,1.6920183,,,,,,,,,,,,,, -23200,2.1930213,1.5966421,,,,,,,,,,,,,, -23300,2.4366825,1.5630035,,,,,,,,,,,,,, -23400,2.4127874,1.5405922,,,,,,,,,,,,,, -23500,1.840619,1.5778962,,,,,,,,,,,,,, -23600,2.1010058,1.5712845,,,,,,,,,,,,,, -23682,,,0.35579237,0.1210676985855049,0.69894844,0.1971093968738233,5348.0,0.40583488,0.1318221518087461,2472.0,20178.99764108658,22254.981443166733,20178.99764108658,2073.9834084510803,0.8556327819824219,0.0 -23700,2.0779932,1.5268558,,,,,,,,,,,,,, -23800,1.5299577,1.5389427,,,,,,,,,,,,,, -23900,2.4159918,1.536626,,,,,,,,,,,,,, -24000,2.1967633,1.67132,,,,,,,,,,,,,, -24100,1.4086194,1.6443528,,,,,,,,,,,,,, -24200,1.8917248,1.569824,,,,,,,,,,,,,, -24300,3.2544427,1.6301624,,,,,,,,,,,,,, -24400,3.0433269,1.5305064,,,,,,,,,,,,,, -24500,2.2272599,1.6086618,,,,,,,,,,,,,, -24600,2.1617758,1.5644957,,,,,,,,,,,,,, -24700,2.3001528,1.5969658,,,,,,,,,,,,,, -24800,2.854851,1.57562,,,,,,,,,,,,,, -24900,3.188765,1.5607198,,,,,,,,,,,,,, -25000,3.2209854,1.6003052,,,,,,,,,,,,,, -25100,2.4602046,1.5800676,,,,,,,,,,,,,, -25200,2.250078,1.5271399,,,,,,,,,,,,,, -25300,2.4497054,1.5877502,,,,,,,,,,,,,, -25354,,,0.3752873,0.1209741880406944,0.635908,0.1815654054471552,5348.0,0.38012487,0.1212804419799727,2472.0,21619.240039110184,23830.625638246536,21619.240039110184,2209.2528836727142,0.906546115875244,0.0 -25400,2.2493784,1.586888,,,,,,,,,,,,,, -25500,2.6280737,1.543523,,,,,,,,,,,,,, -25600,2.3791952,1.489314,,,,,,,,,,,,,, -25700,3.5259876,1.5655743,,,,,,,,,,,,,, -25800,1.7513149,1.5416929,,,,,,,,,,,,,, -25900,2.4915295,1.5562725,,,,,,,,,,,,,, -26000,3.539939,1.551396,,,,,,,,,,,,,, -26100,3.1977348,1.5344574,,,,,,,,,,,,,, -26200,2.8235133,1.512681,,,,,,,,,,,,,, -26300,2.8021786,1.5530623,,,,,,,,,,,,,, -26400,1.9628384,1.5185298,,,,,,,,,,,,,, -26500,1.8101431,1.5755364,,,,,,,,,,,,,, -26600,2.2407837,1.593983,,,,,,,,,,,,,, -26700,5.2705846,1.5891882,,,,,,,,,,,,,, -26800,1.8981128,1.5452536,,,,,,,,,,,,,, -26900,2.0853176,1.5620961,,,,,,,,,,,,,, -27000,2.9911623,1.5389681,,,,,,,,,,,,,, -27057,,,0.32275435,0.1098269556544179,0.6262908,0.1797986039371675,5348.0,0.36712766,0.1185993134686084,2472.0,23059.45344424248,25408.206853866577,23059.45344424248,2346.485763788224,0.9582936763763428,0.0 -27100,3.0404348,1.538788,,,,,,,,,,,,,, -27200,2.3972895,1.5682088,,,,,,,,,,,,,, -27300,2.7970881,1.5857118,,,,,,,,,,,,,, -27400,3.426501,1.5734181,,,,,,,,,,,,,, -27500,1.9576288,1.5157492,,,,,,,,,,,,,, -27600,2.0692675,1.573466,,,,,,,,,,,,,, -27700,2.2570777,1.5680972,,,,,,,,,,,,,, -27800,2.393945,1.5420688,,,,,,,,,,,,,, -27900,2.4204602,1.5523721,,,,,,,,,,,,,, -28000,2.6195972,1.5290691,,,,,,,,,,,,,, -28100,1.8332328,1.5047457,,,,,,,,,,,,,, -28200,1.7429999,1.5870053,,,,,,,,,,,,,, -28300,2.8257203,1.5479617,,,,,,,,,,,,,, -28400,2.5547836,1.540502,,,,,,,,,,,,,, -28500,2.9956334,1.5815994,,,,,,,,,,,,,, -28600,2.1099079,1.5111977,,,,,,,,,,,,,, -28700,2.5100908,1.5244898,,,,,,,,,,,,,, -28739,,,0.2909292,0.0987520261869914,0.6012382,0.1724996862237755,5348.0,0.3531975,0.1150854101923506,2472.0,24500.440947532654,26987.57475304604,24500.440947532654,2484.727519273758,1.015852451324463,0.0 -28800,2.4801257,1.5066308,,,,,,,,,,,,,, -28900,2.3812652,1.4591476,,,,,,,,,,,,,, -29000,2.677497,1.447249,,,,,,,,,,,,,, -29100,3.9448705,1.4414628,,,,,,,,,,,,,, -29200,2.3891904,1.4949087,,,,,,,,,,,,,, -29300,2.2785666,1.4879693,,,,,,,,,,,,,, -29400,2.2448938,1.4289092,,,,,,,,,,,,,, -29500,2.9058754,1.5871772,,,,,,,,,,,,,, -29600,2.6398127,1.497914,,,,,,,,,,,,,, -29700,2.2880173,1.4825245,,,,,,,,,,,,,, -29800,2.5556254,1.5485088,,,,,,,,,,,,,, -29900,2.228269,1.5199765,,,,,,,,,,,,,, -30000,2.5770693,1.4517375,,,,,,,,,,,,,, -30100,2.6594903,1.460034,,,,,,,,,,,,,, -30200,1.9548686,1.4818919,,,,,,,,,,,,,, -30300,2.2980855,1.4912974,,,,,,,,,,,,,, -30400,3.2442787,1.4784914,,,,,,,,,,,,,, -30429,,,0.27494448,0.0924123634624073,0.58237994,0.1684640412446778,5348.0,0.33602273,0.1067982857026791,2472.0,25941.052482128143,28566.935750246048,25941.052482128143,2623.338259458542,1.0719337463378906,0.0 -30500,2.2365673,1.4523531,,,,,,,,,,,,,, -30600,2.1447601,1.456797,,,,,,,,,,,,,, -30700,1.6983795,1.4088892,,,,,,,,,,,,,, -30800,1.9785447,1.4500053,,,,,,,,,,,,,, -30900,2.6248233,1.4468365,,,,,,,,,,,,,, -31000,4.072221,1.3973033,,,,,,,,,,,,,, -31100,2.1316264,1.3950686,,,,,,,,,,,,,, -31200,2.8320255,1.4916406,,,,,,,,,,,,,, -31300,1.9277601,1.4992411,,,,,,,,,,,,,, -31400,3.5428734,1.441722,,,,,,,,,,,,,, -31500,2.7442431,1.4617574,,,,,,,,,,,,,, -31600,1.9712093,1.4991621,,,,,,,,,,,,,, -31700,2.931349,1.4691665,,,,,,,,,,,,,, -31800,1.782046,1.5105525,,,,,,,,,,,,,, -31900,2.379107,1.4009844,,,,,,,,,,,,,, -32000,2.3232253,1.4504502,,,,,,,,,,,,,, -32100,3.7316923,1.4429258,,,,,,,,,,,,,, -32123,,,0.2601245,0.0908591731266149,0.5635271,0.160508607123203,5348.0,0.31771612,0.1033046940060528,2472.0,27381.33598518372,30144.7624783516,27381.33598518372,2760.744163513184,1.128051519393921,0.0 -32200,2.342091,1.4297441,,,,,,,,,,,,,, -32300,2.5517504,1.4344054,,,,,,,,,,,,,, -32400,2.8314574,1.4636794,,,,,,,,,,,,,, -32500,2.6648884,1.404216,,,,,,,,,,,,,, -32600,4.4626975,1.3904401,,,,,,,,,,,,,, -32700,2.5264573,1.4938525,,,,,,,,,,,,,, -32800,2.3558104,1.4575362,,,,,,,,,,,,,, -32900,2.70173,1.4111435,,,,,,,,,,,,,, -33000,1.9615967,1.3394873,,,,,,,,,,,,,, -33100,2.9912934,1.379305,,,,,,,,,,,,,, -33200,2.0992863,1.4362907,,,,,,,,,,,,,, -33300,3.5142536,1.3768836,,,,,,,,,,,,,, -33400,2.505246,1.4238667,,,,,,,,,,,,,, -33500,2.7156932,1.3854321,,,,,,,,,,,,,, -33600,2.435697,1.406179,,,,,,,,,,,,,, -33700,2.4975386,1.4235722,,,,,,,,,,,,,, -33800,,,0.26487103,0.0886311124678291,0.5352882,0.153209689409811,5348.0,0.306206,0.0981252412000081,2472.0,28821.37881207466,31721.50653815269,28821.37881207466,2897.3103017807007,1.1814250946044922,0.0 -33800,2.4305685,1.4425074,,,,,,,,,,,,,, -33900,2.3204021,1.3690442,,,,,,,,,,,,,, -34000,1.769089,1.380257,,,,,,,,,,,,,, -34100,2.0287433,1.3733879,,,,,,,,,,,,,, -34200,2.7895546,1.3629147,,,,,,,,,,,,,, -34300,2.5930367,1.3825613,,,,,,,,,,,,,, -34400,1.994054,1.4276241,,,,,,,,,,,,,, -34500,2.2332928,1.3698657,,,,,,,,,,,,,, -34600,1.9137244,1.4100655,,,,,,,,,,,,,, -34700,1.9542568,1.3591446,,,,,,,,,,,,,, -34800,2.7813108,1.3661773,,,,,,,,,,,,,, -34900,2.518691,1.4032869,,,,,,,,,,,,,, -35000,5.139729,1.3805482,,,,,,,,,,,,,, -35100,3.2664988,1.3502672,,,,,,,,,,,,,, -35200,3.2580142,1.4447485,,,,,,,,,,,,,, -35300,1.4775172,1.3492992,,,,,,,,,,,,,, -35400,2.0873322,1.3535788,,,,,,,,,,,,,, -35496,,,0.24313995,0.0822007972071422,0.5057727,0.1451480541046757,5348.0,0.28559756,0.0918286515142282,2472.0,30261.592493772507,33297.76939201355,30261.592493772507,3033.224186897278,1.2355103492736816,0.0 -35500,1.5918313,1.3787267,,,,,,,,,,,,,, -35600,1.4868482,1.323712,,,,,,,,,,,,,, -35700,2.5286868,1.3167784,,,,,,,,,,,,,, -35800,2.6101272,1.3200282,,,,,,,,,,,,,, -35900,1.9475386,1.3688763,,,,,,,,,,,,,, -36000,2.6002836,1.368429,,,,,,,,,,,,,, -36100,1.640196,1.3293006,,,,,,,,,,,,,, -36200,2.6255605,1.2838227,,,,,,,,,,,,,, -36300,2.9467957,1.2848277,,,,,,,,,,,,,, -36400,1.9000856,1.336713,,,,,,,,,,,,,, -36500,1.8314599,1.2524546,,,,,,,,,,,,,, -36600,2.261323,1.3111064,,,,,,,,,,,,,, -36700,1.6385969,1.2845261,,,,,,,,,,,,,, -36800,3.2858486,1.2863553,,,,,,,,,,,,,, -36900,3.2811985,1.3067323,,,,,,,,,,,,,, -37000,2.4535012,1.2880384,,,,,,,,,,,,,, -37100,2.9771273,1.2848873,,,,,,,,,,,,,, -37194,,,0.21898215,0.0749193774173812,0.48544854,0.1398573042277725,5348.0,0.2720657,0.0870554303008144,2472.0,31701.920878887177,34876.05977129936,31701.920878887177,3171.0469262599945,1.2923777103424072,0.0 -37200,1.5474418,1.2641402,,,,,,,,,,,,,, -37300,1.7610567,1.3028084,,,,,,,,,,,,,, -37400,2.156486,1.2768223,,,,,,,,,,,,,, -37500,1.6623255,1.268355,,,,,,,,,,,,,, -37600,1.8470974,1.2821685,,,,,,,,,,,,,, -37700,2.093409,1.2350676,,,,,,,,,,,,,, -37800,2.8156712,1.2808913,,,,,,,,,,,,,, -37900,2.2079444,1.3093799,,,,,,,,,,,,,, -38000,2.529984,1.300514,,,,,,,,,,,,,, -38100,3.4685469,1.3076705,,,,,,,,,,,,,, -38200,3.712459,1.3504345,,,,,,,,,,,,,, -38300,2.1216092,1.3738016,,,,,,,,,,,,,, -38400,1.4373497,1.2986168,,,,,,,,,,,,,, -38500,2.1068883,1.2422315,,,,,,,,,,,,,, -38600,2.3449621,1.2505593,,,,,,,,,,,,,, -38700,2.4221842,1.2485483,,,,,,,,,,,,,, -38800,2.2067018,1.2706034,,,,,,,,,,,,,, -38888,,,0.2012175,0.0686038810344557,0.46407133,0.1329349179837222,5348.0,0.2549981,0.0828306217374525,2472.0,33141.95910906792,36454.38322663307,33141.95910906792,3309.1948528289795,1.345574140548706,0.0 -38900,2.2299798,1.2512783,,,,,,,,,,,,,, -39000,1.9500977,1.2254575,,,,,,,,,,,,,, -39100,1.5072635,1.2848145,,,,,,,,,,,,,, -39200,2.4226441,1.2558731,,,,,,,,,,,,,, -39300,2.343526,1.2596251,,,,,,,,,,,,,, -39400,4.247792,1.3209633,,,,,,,,,,,,,, -39500,2.4742777,1.2690239,,,,,,,,,,,,,, -39600,2.8347337,1.2111028,,,,,,,,,,,,,, -39700,4.2523985,1.2593325,,,,,,,,,,,,,, -39800,2.5614145,1.2597848,,,,,,,,,,,,,, -39900,2.2689207,1.2708827,,,,,,,,,,,,,, -40000,3.0356805,1.3291664,,,,,,,,,,,,,, -40100,2.3224692,1.2574492,,,,,,,,,,,,,, -40200,1.9208509,1.2680588,,,,,,,,,,,,,, -40300,2.3660805,1.2048564,,,,,,,,,,,,,, -40400,3.316835,1.2386068,,,,,,,,,,,,,, -40500,4.5080276,1.1732733,,,,,,,,,,,,,, -40597,,,0.16660897,0.0576332048370914,0.44376504,0.1274510750456182,5348.0,0.24118751,0.0770215099628298,2472.0,34582.12906932831,38031.52192878723,34582.12906932831,3446.0249009132385,1.402059555053711,0.0 -40600,4.9670444,1.2369435,,,,,,,,,,,,,, -40700,2.083215,1.2213452,,,,,,,,,,,,,, -40800,1.7926558,1.2593261,,,,,,,,,,,,,, -40900,1.735864,1.2366799,,,,,,,,,,,,,, -41000,2.6829185,1.2604822,,,,,,,,,,,,,, -41100,2.8313072,1.1852596,,,,,,,,,,,,,, -41200,2.2261994,1.1769763,,,,,,,,,,,,,, -41300,2.1215785,1.1714965,,,,,,,,,,,,,, -41400,2.2134948,1.2056998,,,,,,,,,,,,,, -41500,2.4395113,1.221748,,,,,,,,,,,,,, -41600,2.7673552,1.177962,,,,,,,,,,,,,, -41700,2.5303142,1.224452,,,,,,,,,,,,,, -41800,1.410359,1.185755,,,,,,,,,,,,,, -41900,3.2150488,1.1696073,,,,,,,,,,,,,, -42000,1.3873019,1.1741194,,,,,,,,,,,,,, -42100,2.0776105,1.2406626,,,,,,,,,,,,,, -42200,2.1869822,1.1702296,,,,,,,,,,,,,, -42267,,,0.16422431,0.0553903991336949,0.4237094,0.1222279077401353,5348.0,0.23010056,0.0739341498588345,2472.0,36022.32275509834,39610.5627257824,36022.32275509834,3584.731765270233,1.4607226848602295,0.0 -42300,1.6719017,1.1699139,,,,,,,,,,,,,, -42400,3.7699223,1.1805105,,,,,,,,,,,,,, -42500,3.9244566,1.1513819,,,,,,,,,,,,,, -42600,2.4406133,1.2050757,,,,,,,,,,,,,, -42700,1.6554675,1.2025272,,,,,,,,,,,,,, -42800,2.6056938,1.2566981,,,,,,,,,,,,,, -42900,3.198595,1.1528031,,,,,,,,,,,,,, -43000,2.3868134,1.213785,,,,,,,,,,,,,, -43100,4.295318,1.1392734,,,,,,,,,,,,,, -43200,2.4747417,1.17445,,,,,,,,,,,,,, -43300,2.6536722,1.1751488,,,,,,,,,,,,,, -43400,2.7247875,1.1712074,,,,,,,,,,,,,, -43500,6.5716834,1.2211405,,,,,,,,,,,,,, -43600,2.2557378,1.2268993,,,,,,,,,,,,,, -43700,2.6360269,1.2042385,,,,,,,,,,,,,, -43800,1.9502686,1.1207912,,,,,,,,,,,,,, -43900,2.2976818,1.181202,,,,,,,,,,,,,, -43960,,,0.16956198,0.0570378974095783,0.41393122,0.1190225629242013,5348.0,0.22181489,0.0717811224178904,2472.0,37463.19466948509,41188.80369210243,37463.19466948509,3721.962750196457,1.517216682434082,0.0 -44000,2.384778,1.1597308,,,,,,,,,,,,,, -44100,2.887459,1.1163864,,,,,,,,,,,,,, -44200,2.2114065,1.2097418,,,,,,,,,,,,,, -44300,1.6158651,1.1732537,,,,,,,,,,,,,, -44400,2.6200128,1.1659396,,,,,,,,,,,,,, -44500,2.0131006,1.1596704,,,,,,,,,,,,,, -44600,2.4385366,1.2156897,,,,,,,,,,,,,, -44700,2.9738634,1.2048491,,,,,,,,,,,,,, -44800,1.6563052,1.1376069,,,,,,,,,,,,,, -44900,2.4125693,1.1613483,,,,,,,,,,,,,, -45000,2.9115748,1.185433,,,,,,,,,,,,,, -45100,3.0843217,1.1718326,,,,,,,,,,,,,, -45200,2.6739662,1.2487214,,,,,,,,,,,,,, -45300,2.295576,1.113597,,,,,,,,,,,,,, -45400,1.6729254,1.1013728,,,,,,,,,,,,,, -45500,2.256197,1.174069,,,,,,,,,,,,,, -45600,1.6389074,1.1587495,,,,,,,,,,,,,, -45652,,,0.14950223,0.0494265207849142,0.4087675,0.1171012869652529,5348.0,0.21725857,0.0700140149899457,2472.0,38903.46937012672,42764.24932599068,38903.46937012672,3856.989844560623,1.5770831108093262,0.0 -45700,2.0311992,1.0757637,,,,,,,,,,,,,, -45800,2.4514196,1.1362596,,,,,,,,,,,,,, -45900,4.2978544,1.1113275,,,,,,,,,,,,,, -46000,3.048061,1.1754123,,,,,,,,,,,,,, -46100,1.9653622,1.1163757,,,,,,,,,,,,,, -46200,2.2975583,1.1130465,,,,,,,,,,,,,, -46300,1.9023966,1.1799093,,,,,,,,,,,,,, -46400,3.0053506,1.12178,,,,,,,,,,,,,, -46500,2.6575947,1.2047203,,,,,,,,,,,,,, -46600,4.1253815,1.1193739,,,,,,,,,,,,,, -46700,2.52269,1.1822609,,,,,,,,,,,,,, -46800,2.3322732,1.1607736,,,,,,,,,,,,,, -46900,1.9389585,1.1228415,,,,,,,,,,,,,, -47000,2.0968342,1.142294,,,,,,,,,,,,,, -47100,5.2010593,1.1320963,,,,,,,,,,,,,, -47200,2.6955955,1.1121788,,,,,,,,,,,,,, -47300,2.5975242,1.131176,,,,,,,,,,,,,, -47317,,,0.16066147,0.0549748414936208,0.4061986,0.1159137646388677,5348.0,0.21572423,0.0690999939065261,2472.0,40343.7261993885,44342.45700645447,40343.7261993885,3994.8145458698273,1.6236886978149414,0.0 -47400,3.6297467,1.1454483,,,,,,,,,,,,,, -47500,3.6005273,1.091243,,,,,,,,,,,,,, -47600,2.2095008,1.1006105,,,,,,,,,,,,,, -47700,4.949297,1.1267474,,,,,,,,,,,,,, -47800,3.6232226,1.1916837,,,,,,,,,,,,,, -47900,2.2475789,1.1449461,,,,,,,,,,,,,, -48000,,,0.18988594,0.058824468000234,0.40646556,0.116000656516408,5348.0,0.21571341,0.0690390591676314,2472.0,40915.90963792801,45051.27542638779,40915.90963792801,4131.348959207535,1.6855216026306152,0.0 -48000,,,,,,,,,,,40915.90963792801,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/eval_measurements.csv deleted file mode 100644 index cc21d2677..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -163.03839111328125,0.0,16.60268473625183,1,0,16.60268473625183,28.256365,2472,2.3440578473787905,179.64115977287292,28.652134,2.3640764239661616,28.14943,5348,2.1854755399364723 -293.5546774864197,0.0314514636993408,1457.0688242912292,1692,0,1457.0688242912292,2.6875026,2472,0.5876749334795767,1750.7331321239471,2.7661793,0.616729432815946,3.1783566,5348,0.6568639755930371 -428.7920391559601,0.0820395946502685,2897.5816576480865,3396,0,2897.5816576480865,0.61222464,2472,0.1921475433144435,3326.616904258728,0.77670133,0.2395568610625213,0.9372019,5348,0.2630506772739122 -564.2852621078491,0.1306955814361572,4338.747329473496,5071,0,4338.747329473496,0.5147943,2472,0.165884670850852,4903.404318571091,0.67181903,0.2107846644031234,0.81510663,5348,0.2328991957674001 -698.5102126598358,0.1890890598297119,5778.721389293671,6744,0,5778.721389293671,0.45724174,2472,0.1476245607620904,6477.741781234741,0.7183804,0.2258982279304178,0.7535185,5348,0.2154146190756635 -832.8487877845764,0.2436397075653076,7219.063415050507,8435,0,7219.063415050507,0.4333417,2472,0.1384640383482623,8052.558868169785,0.58081234,0.181722517129294,0.719909,5348,0.2052289600973189 -967.5219428539276,0.2929205894470215,8659.439930438995,10111,0,8659.439930438995,0.4085987,2472,0.1312331159994312,9627.739500045776,0.57047266,0.1831359970882951,0.67904717,5348,0.1958832559351979 -1104.7788660526276,0.3447701930999756,10100.39093017578,11807,0,10100.39093017578,0.39502275,2472,0.1270489305953324,11206.082193851473,0.43884623,0.1453479301718804,0.6524572,5348,0.1888932871197273 -1239.711059808731,0.4017910957336426,11540.737623214722,13483,0,11540.737623214722,0.37856,2472,0.1242865557654418,12781.498124361038,0.5149647,0.1668655376632475,0.643958,5348,0.1846355851202487 -1376.2743470668793,0.4509057998657226,12981.104486703873,15183,0,12981.104486703873,0.36153093,2472,0.1167915828814007,14358.558901309969,0.4569584,0.1493814295257181,0.6158217,5348,0.1773849406721569 -1510.1858768463137,0.5034916400909424,14421.737709760666,16899,0,14421.737709760666,0.35192952,2472,0.1134604838218268,15933.23994922638,0.43455628,0.1444184104049378,0.59773993,5348,0.1722390105911544 -1645.4841482639313,0.5546281337738037,15861.863068819046,18554,0,15861.863068819046,0.34360522,2472,0.1104543700363577,17508.79342675209,0.42747203,0.1402834559825565,0.5999306,5348,0.1729244909584174 -1779.5542323589325,0.6197030544281006,17301.77611398697,20255,0,17301.77611398697,0.32855484,2472,0.1051327361728921,19082.924453496933,0.4165472,0.1373503718636462,0.57209927,5348,0.1654228255307645 -1913.847899436951,0.6720757484436035,18741.892809152603,21960,0,18741.892809152603,0.32678494,2472,0.1036703024394207,20657.470319747925,0.45802358,0.1451975568283515,0.5748465,5348,0.1641484113268389 -2049.448930501938,0.7180941104888916,20182.922751426697,23637,0,20182.922751426697,0.31266233,2472,0.1005016960168992,22234.225139141083,0.36998138,0.1244089883425191,0.5380809,5348,0.1556233526748216 -2184.9646167755127,0.7669081687927246,21623.12057375908,25321,0,21623.12057375908,0.30099636,2472,0.096825300103589,23810.06982064247,0.3866852,0.1330297052680436,0.53082854,5348,0.1544261756953764 -2321.4288444519043,0.825782060623169,23063.543320417404,27004,0,23063.543320417404,0.29252425,2472,0.0944691568663294,25387.0976998806,0.3206794,0.1080222660607426,0.5141854,5348,0.1482665070430694 -2456.543736219406,0.8892910480499268,24503.62016057968,28672,0,24503.62016057968,0.28394303,2472,0.0919911441512806,26962.43369913101,0.34442776,0.1134760598433213,0.5025977,5348,0.1444915376965928 -2592.422735452652,0.9475498199462892,25944.00483107567,30368,0,25944.00483107567,0.27818447,2472,0.088558487193549,28538.838061094284,0.3366657,0.1129947038176884,0.48924,5348,0.1405234752889154 -2728.975687980652,1.0051610469818115,27383.97795963288,32050,0,27383.97795963288,0.26197803,2472,0.0844758596876079,30115.50089669228,0.31884333,0.1084934699441537,0.4744674,5348,0.1363912837792174 -2863.3213727474213,1.0570778846740725,28823.93050980568,33728,0,28823.93050980568,0.25633433,2472,0.082221274348506,31689.932183027267,0.26275682,0.0906634322464177,0.46272892,5348,0.1334079959836643 -2997.9378378391266,1.1216213703155518,30263.79677867889,35434,0,30263.79677867889,0.24987762,2472,0.0803729206020352,33264.56129407883,0.23715778,0.0822559468865204,0.45252073,5348,0.1306467652084922 -3131.8751130104065,1.180323600769043,31703.80060362816,37095,0,31703.80060362816,0.2403281,2472,0.0776917920906709,34838.64279150963,0.2662122,0.0921220819411118,0.4400807,5348,0.1263697539028935 -3266.877792835236,1.2335264682769775,33144.33669090271,38792,0,33144.33669090271,0.23511161,2472,0.0753356488534113,36414.3170132637,0.23776919,0.0816136566212724,0.43106434,5348,0.1244967512092452 -3403.8705430030823,1.2884671688079834,34584.92214655876,40493,0,34584.92214655876,0.2303546,2472,0.073954461438466,37992.03388214111,0.245853,0.0821549486668751,0.42453703,5348,0.1226913310870174 -3539.95130443573,1.3435046672821045,36025.603747844696,42166,0,36025.603747844696,0.22683771,2472,0.072248288749416,39568.92971038818,0.19963317,0.0690241674824582,0.41599953,5348,0.119881827046545 -3673.904772281647,1.4003708362579346,37465.7214448452,43858,0,37465.7214448452,0.22416562,2472,0.0712123981882071,41143.13983345032,0.19884929,0.0696530193958415,0.41424644,5348,0.1192156559854021 -3808.381865978241,1.463677167892456,38905.63683629036,45559,0,38905.63683629036,0.22256012,2472,0.0708671013344707,42717.67965936661,0.21639155,0.0758470642510123,0.40952334,5348,0.118057097618197 -3951.201397418976,1.5179204940795898,40346.24051046372,47255,0,40346.24051046372,0.22253822,2472,0.070765543436313,44301.23940682411,0.14392106,0.0514502725017065,0.40911275,5348,0.1177288394141556 -4087.9201798439026,1.5780909061431885,40963.57310843468,48000,0,40963.57310843468,0.22244853,2472,0.07060305079926066,45055.39094734192,0.13062766,0.046137194585132935,0.40922388,5348,0.11775780337333577 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/measurements.csv deleted file mode 100644 index 205458f30..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,19.97156,33.3628,,,,,,,,,,,,,, -1,,,28.652134,2.3640764239661616,28.14943,2.1854755399364723,5348.0,28.256365,2.3440578473787905,2472.0,16.60268473625183,179.64115977287292,16.60268473625183,163.03839111328125,0.0,0.0 -100,0.6764528,5.974327,,,,,,,,,,,,,, -200,0.81621385,5.8461623,,,,,,,,,,,,,, -300,0.35137692,5.6963544,,,,,,,,,,,,,, -400,0.4978052,5.4592085,,,,,,,,,,,,,, -500,3.12195,4.8659573,,,,,,,,,,,,,, -600,2.518604,3.9639368,,,,,,,,,,,,,, -700,2.133242,3.519647,,,,,,,,,,,,,, -800,2.7007182,3.1820345,,,,,,,,,,,,,, -900,1.9007668,2.9566088,,,,,,,,,,,,,, -1000,1.9064113,2.8953915,,,,,,,,,,,,,, -1100,1.916658,2.6843143,,,,,,,,,,,,,, -1200,2.3349862,2.6010392,,,,,,,,,,,,,, -1300,2.7594743,2.5207837,,,,,,,,,,,,,, -1400,2.137585,2.4555142,,,,,,,,,,,,,, -1500,2.06146,2.3973935,,,,,,,,,,,,,, -1600,2.5934317,2.3687,,,,,,,,,,,,,, -1692,,,2.7661793,0.616729432815946,3.1783566,0.6568639755930371,5348.0,2.6875026,0.5876749334795767,2472.0,1457.0688242912292,1750.7331321239471,1457.0688242912292,293.5546774864197,0.0314514636993408,0.0 -1700,2.065997,2.2679536,,,,,,,,,,,,,, -1800,2.0645704,2.298472,,,,,,,,,,,,,, -1900,2.1103375,2.238489,,,,,,,,,,,,,, -2000,3.1600158,2.1203363,,,,,,,,,,,,,, -2100,2.3110313,2.1612856,,,,,,,,,,,,,, -2200,1.8614471,2.1284633,,,,,,,,,,,,,, -2300,3.1144905,2.1432502,,,,,,,,,,,,,, -2400,2.0346124,2.03621,,,,,,,,,,,,,, -2500,2.0530136,2.0510998,,,,,,,,,,,,,, -2600,2.4034193,2.0400167,,,,,,,,,,,,,, -2700,1.8855661,2.017541,,,,,,,,,,,,,, -2800,2.068966,2.0046606,,,,,,,,,,,,,, -2900,2.2186134,1.9746678,,,,,,,,,,,,,, -3000,1.9830955,1.9679114,,,,,,,,,,,,,, -3100,2.1880627,1.9469041,,,,,,,,,,,,,, -3200,2.4529264,1.9226111,,,,,,,,,,,,,, -3300,2.2975678,1.9544233,,,,,,,,,,,,,, -3396,,,0.77670133,0.2395568610625213,0.9372019,0.2630506772739122,5348.0,0.61222464,0.1921475433144435,2472.0,2897.5816576480865,3326.616904258728,2897.5816576480865,428.7920391559601,0.0820395946502685,0.0 -3400,3.0261195,1.9084653,,,,,,,,,,,,,, -3500,1.9404007,1.813303,,,,,,,,,,,,,, -3600,3.1117308,1.9079797,,,,,,,,,,,,,, -3700,2.1625483,1.8625969,,,,,,,,,,,,,, -3800,2.0456288,1.8474367,,,,,,,,,,,,,, -3900,1.9961656,1.8193164,,,,,,,,,,,,,, -4000,2.538927,1.875909,,,,,,,,,,,,,, -4100,2.4532301,1.7647384,,,,,,,,,,,,,, -4200,2.2374988,1.7748616,,,,,,,,,,,,,, -4300,2.7782674,1.7877648,,,,,,,,,,,,,, -4400,2.7902932,1.857867,,,,,,,,,,,,,, -4500,2.5017405,1.807187,,,,,,,,,,,,,, -4600,2.434175,1.8035121,,,,,,,,,,,,,, -4700,1.7115916,1.8133581,,,,,,,,,,,,,, -4800,2.689726,1.8069055,,,,,,,,,,,,,, -4900,2.025163,1.8486514,,,,,,,,,,,,,, -5000,2.531794,1.745615,,,,,,,,,,,,,, -5071,,,0.67181903,0.2107846644031234,0.81510663,0.2328991957674001,5348.0,0.5147943,0.165884670850852,2472.0,4338.747329473496,4903.404318571091,4338.747329473496,564.2852621078491,0.1306955814361572,0.0 -5100,2.9572678,1.7375916,,,,,,,,,,,,,, -5200,3.0506253,1.8035394,,,,,,,,,,,,,, -5300,2.505675,1.7978303,,,,,,,,,,,,,, -5400,2.5878925,1.6870785,,,,,,,,,,,,,, -5500,3.1214387,1.7670187,,,,,,,,,,,,,, -5600,2.9207842,1.7119411,,,,,,,,,,,,,, -5700,2.9379878,1.7436165,,,,,,,,,,,,,, -5800,2.9911861,1.6827505,,,,,,,,,,,,,, -5900,1.8512429,1.729753,,,,,,,,,,,,,, -6000,3.3212252,1.8009892,,,,,,,,,,,,,, -6100,2.0908062,1.6731743,,,,,,,,,,,,,, -6200,2.2451422,1.7037318,,,,,,,,,,,,,, -6300,3.5402348,1.7545251,,,,,,,,,,,,,, -6400,2.8774297,1.7557352,,,,,,,,,,,,,, -6500,3.5243418,1.7123917,,,,,,,,,,,,,, -6600,2.8174994,1.720168,,,,,,,,,,,,,, -6700,1.82747,1.7176809,,,,,,,,,,,,,, -6744,,,0.7183804,0.2258982279304178,0.7535185,0.2154146190756635,5348.0,0.45724174,0.1476245607620904,2472.0,5778.721389293671,6477.741781234741,5778.721389293671,698.5102126598358,0.1890890598297119,0.0 -6800,1.9208033,1.6455631,,,,,,,,,,,,,, -6900,3.039464,1.6850964,,,,,,,,,,,,,, -7000,2.6599514,1.6460751,,,,,,,,,,,,,, -7100,2.1253874,1.5906415,,,,,,,,,,,,,, -7200,2.179585,1.6991526,,,,,,,,,,,,,, -7300,2.5701811,1.6107625,,,,,,,,,,,,,, -7400,2.943845,1.6780832,,,,,,,,,,,,,, -7500,2.5407972,1.6164218,,,,,,,,,,,,,, -7600,3.0505388,1.7718903,,,,,,,,,,,,,, -7700,3.1217997,1.7207695,,,,,,,,,,,,,, -7800,4.2191663,1.6296585,,,,,,,,,,,,,, -7900,3.1005912,1.6941937,,,,,,,,,,,,,, -8000,3.066215,1.6266915,,,,,,,,,,,,,, -8100,1.860813,1.6627873,,,,,,,,,,,,,, -8200,3.0694904,1.6469055,,,,,,,,,,,,,, -8300,5.713393,1.6515063,,,,,,,,,,,,,, -8400,3.3981578,1.6346661,,,,,,,,,,,,,, -8435,,,0.58081234,0.181722517129294,0.719909,0.2052289600973189,5348.0,0.4333417,0.1384640383482623,2472.0,7219.063415050507,8052.558868169785,7219.063415050507,832.8487877845764,0.2436397075653076,0.0 -8500,3.0783455,1.6561189,,,,,,,,,,,,,, -8600,3.0826705,1.6548958,,,,,,,,,,,,,, -8700,2.2723877,1.6590768,,,,,,,,,,,,,, -8800,2.4421759,1.6477842,,,,,,,,,,,,,, -8900,2.3486257,1.6430461,,,,,,,,,,,,,, -9000,2.984601,1.6193453,,,,,,,,,,,,,, -9100,2.4364421,1.5763372,,,,,,,,,,,,,, -9200,1.7280567,1.622752,,,,,,,,,,,,,, -9300,2.6167972,1.6187203,,,,,,,,,,,,,, -9400,2.658753,1.6046913,,,,,,,,,,,,,, -9500,2.6071632,1.6001781,,,,,,,,,,,,,, -9600,2.058335,1.5965515,,,,,,,,,,,,,, -9700,2.5237553,1.5369959,,,,,,,,,,,,,, -9800,2.9531844,1.6394887,,,,,,,,,,,,,, -9900,2.56956,1.605667,,,,,,,,,,,,,, -10000,7.6572957,1.6757822,,,,,,,,,,,,,, -10100,2.2904112,1.5753306,,,,,,,,,,,,,, -10111,,,0.57047266,0.1831359970882951,0.67904717,0.1958832559351979,5348.0,0.4085987,0.1312331159994312,2472.0,8659.439930438995,9627.739500045776,8659.439930438995,967.5219428539276,0.2929205894470215,0.0 -10200,2.9424567,1.5551798,,,,,,,,,,,,,, -10300,4.3107,1.587351,,,,,,,,,,,,,, -10400,3.5606663,1.5882354,,,,,,,,,,,,,, -10500,3.1025631,1.6061687,,,,,,,,,,,,,, -10600,3.7586322,1.6235286,,,,,,,,,,,,,, -10700,3.7841792,1.634376,,,,,,,,,,,,,, -10800,2.4558346,1.5740241,,,,,,,,,,,,,, -10900,2.8614821,1.5658292,,,,,,,,,,,,,, -11000,2.3821023,1.6171768,,,,,,,,,,,,,, -11100,4.4142475,1.6251762,,,,,,,,,,,,,, -11200,1.8930466,1.6164032,,,,,,,,,,,,,, -11300,3.6598659,1.5379887,,,,,,,,,,,,,, -11400,1.9426116,1.50417,,,,,,,,,,,,,, -11500,4.280859,1.5847509,,,,,,,,,,,,,, -11600,4.0955706,1.6311859,,,,,,,,,,,,,, -11700,2.4320147,1.5604827,,,,,,,,,,,,,, -11800,2.335886,1.6059858,,,,,,,,,,,,,, -11807,,,0.43884623,0.1453479301718804,0.6524572,0.1888932871197273,5348.0,0.39502275,0.1270489305953324,2472.0,10100.39093017578,11206.082193851473,10100.39093017578,1104.7788660526276,0.3447701930999756,0.0 -11900,2.0920854,1.5988739,,,,,,,,,,,,,, -12000,3.288314,1.530616,,,,,,,,,,,,,, -12100,2.3767529,1.5885011,,,,,,,,,,,,,, -12200,3.0722303,1.5417876,,,,,,,,,,,,,, -12300,2.8821795,1.5957863,,,,,,,,,,,,,, -12400,3.2884536,1.5034353,,,,,,,,,,,,,, -12500,3.1059172,1.5422149,,,,,,,,,,,,,, -12600,3.7647521,1.6530242,,,,,,,,,,,,,, -12700,1.7508106,1.5468508,,,,,,,,,,,,,, -12800,2.6007583,1.5990493,,,,,,,,,,,,,, -12900,1.7361333,1.5737156,,,,,,,,,,,,,, -13000,3.0956593,1.5803629,,,,,,,,,,,,,, -13100,4.8551025,1.5477027,,,,,,,,,,,,,, -13200,2.5799775,1.6153066,,,,,,,,,,,,,, -13300,2.3266444,1.5489657,,,,,,,,,,,,,, -13400,2.04146,1.5632435,,,,,,,,,,,,,, -13483,,,0.5149647,0.1668655376632475,0.643958,0.1846355851202487,5348.0,0.37856,0.1242865557654418,2472.0,11540.737623214722,12781.498124361038,11540.737623214722,1239.711059808731,0.4017910957336426,0.0 -13500,2.162764,1.5210471,,,,,,,,,,,,,, -13600,1.9939514,1.5242784,,,,,,,,,,,,,, -13700,3.1878276,1.6034981,,,,,,,,,,,,,, -13800,2.4351857,1.5490184,,,,,,,,,,,,,, -13900,2.0774755,1.5147369,,,,,,,,,,,,,, -14000,2.8798335,1.5144216,,,,,,,,,,,,,, -14100,3.3405185,1.5274396,,,,,,,,,,,,,, -14200,2.5253196,1.5351319,,,,,,,,,,,,,, -14300,4.424787,1.5619253,,,,,,,,,,,,,, -14400,2.8027253,1.503172,,,,,,,,,,,,,, -14500,2.8636096,1.4872937,,,,,,,,,,,,,, -14600,3.041764,1.4752587,,,,,,,,,,,,,, -14700,2.367052,1.5227889,,,,,,,,,,,,,, -14800,2.9100397,1.5645382,,,,,,,,,,,,,, -14900,2.6435401,1.4285389,,,,,,,,,,,,,, -15000,3.1778553,1.485853,,,,,,,,,,,,,, -15100,2.8184242,1.4721686,,,,,,,,,,,,,, -15183,,,0.4569584,0.1493814295257181,0.6158217,0.1773849406721569,5348.0,0.36153093,0.1167915828814007,2472.0,12981.104486703873,14358.558901309969,12981.104486703873,1376.2743470668793,0.4509057998657226,0.0 -15200,2.3516014,1.515737,,,,,,,,,,,,,, -15300,3.4710355,1.5498551,,,,,,,,,,,,,, -15400,3.26612,1.5646219,,,,,,,,,,,,,, -15500,2.431029,1.5436053,,,,,,,,,,,,,, -15600,3.5724792,1.4485717,,,,,,,,,,,,,, -15700,1.612815,1.4945827,,,,,,,,,,,,,, -15800,2.0328624,1.468972,,,,,,,,,,,,,, -15900,3.0514312,1.5275633,,,,,,,,,,,,,, -16000,3.994744,1.5492465,,,,,,,,,,,,,, -16100,3.327753,1.5301118,,,,,,,,,,,,,, -16200,2.8974674,1.5678988,,,,,,,,,,,,,, -16300,2.8818226,1.547135,,,,,,,,,,,,,, -16400,3.126888,1.5089366,,,,,,,,,,,,,, -16500,2.5421078,1.4991934,,,,,,,,,,,,,, -16600,3.9994895,1.4687809,,,,,,,,,,,,,, -16700,3.7236958,1.4805137,,,,,,,,,,,,,, -16800,2.1941838,1.4863323,,,,,,,,,,,,,, -16899,,,0.43455628,0.1444184104049378,0.59773993,0.1722390105911544,5348.0,0.35192952,0.1134604838218268,2472.0,14421.737709760666,15933.23994922638,14421.737709760666,1510.1858768463137,0.5034916400909424,0.0 -16900,1.759346,1.4688437,,,,,,,,,,,,,, -17000,2.7816327,1.4913048,,,,,,,,,,,,,, -17100,2.6237202,1.4472687,,,,,,,,,,,,,, -17200,4.0193276,1.4849254,,,,,,,,,,,,,, -17300,2.9134064,1.4420879,,,,,,,,,,,,,, -17400,2.9635744,1.4924188,,,,,,,,,,,,,, -17500,1.9016325,1.4651527,,,,,,,,,,,,,, -17600,3.5401726,1.483245,,,,,,,,,,,,,, -17700,1.5135566,1.4824507,,,,,,,,,,,,,, -17800,2.722364,1.4407765,,,,,,,,,,,,,, -17900,3.8236098,1.4374101,,,,,,,,,,,,,, -18000,2.6224723,1.5083044,,,,,,,,,,,,,, -18100,2.898105,1.5207963,,,,,,,,,,,,,, -18200,5.4103937,1.444075,,,,,,,,,,,,,, -18300,1.9141707,1.5172772,,,,,,,,,,,,,, -18400,3.1483908,1.5137312,,,,,,,,,,,,,, -18500,2.0446641,1.5541264,,,,,,,,,,,,,, -18554,,,0.42747203,0.1402834559825565,0.5999306,0.1729244909584174,5348.0,0.34360522,0.1104543700363577,2472.0,15861.863068819046,17508.79342675209,15861.863068819046,1645.4841482639313,0.5546281337738037,0.0 -18600,1.6353147,1.4548701,,,,,,,,,,,,,, -18700,2.3644543,1.5016155,,,,,,,,,,,,,, -18800,2.853226,1.483192,,,,,,,,,,,,,, -18900,2.331504,1.503237,,,,,,,,,,,,,, -19000,2.0118675,1.4997408,,,,,,,,,,,,,, -19100,2.428699,1.4414822,,,,,,,,,,,,,, -19200,2.2585578,1.482087,,,,,,,,,,,,,, -19300,1.996089,1.4813308,,,,,,,,,,,,,, -19400,2.543516,1.4760721,,,,,,,,,,,,,, -19500,1.7930732,1.4302771,,,,,,,,,,,,,, -19600,2.2695491,1.4222782,,,,,,,,,,,,,, -19700,2.1586,1.4975789,,,,,,,,,,,,,, -19800,2.3164117,1.4679714,,,,,,,,,,,,,, -19900,2.9781752,1.460714,,,,,,,,,,,,,, -20000,2.5170724,1.4280893,,,,,,,,,,,,,, -20100,1.957732,1.4958907,,,,,,,,,,,,,, -20200,1.9155838,1.4758831,,,,,,,,,,,,,, -20255,,,0.4165472,0.1373503718636462,0.57209927,0.1654228255307645,5348.0,0.32855484,0.1051327361728921,2472.0,17301.77611398697,19082.924453496933,17301.77611398697,1779.5542323589325,0.6197030544281006,0.0 -20300,2.3074145,1.477262,,,,,,,,,,,,,, -20400,2.3151672,1.4706446,,,,,,,,,,,,,, -20500,2.5834146,1.5012177,,,,,,,,,,,,,, -20600,4.0693097,1.4838138,,,,,,,,,,,,,, -20700,2.292322,1.4202914,,,,,,,,,,,,,, -20800,2.506834,1.4653115,,,,,,,,,,,,,, -20900,2.2439919,1.4329299,,,,,,,,,,,,,, -21000,2.6762767,1.4712871,,,,,,,,,,,,,, -21100,2.5089145,1.4161052,,,,,,,,,,,,,, -21200,2.4922557,1.4232882,,,,,,,,,,,,,, -21300,2.8619525,1.4123768,,,,,,,,,,,,,, -21400,2.8819122,1.4506842,,,,,,,,,,,,,, -21500,3.9128077,1.4930241,,,,,,,,,,,,,, -21600,3.6293302,1.4272252,,,,,,,,,,,,,, -21700,3.469381,1.4338573,,,,,,,,,,,,,, -21800,3.4413147,1.4372472,,,,,,,,,,,,,, -21900,2.3072555,1.4330913,,,,,,,,,,,,,, -21960,,,0.45802358,0.1451975568283515,0.5748465,0.1641484113268389,5348.0,0.32678494,0.1036703024394207,2472.0,18741.892809152603,20657.470319747925,18741.892809152603,1913.847899436951,0.6720757484436035,0.0 -22000,2.898135,1.4575381,,,,,,,,,,,,,, -22100,2.9212596,1.4267774,,,,,,,,,,,,,, -22200,2.7161114,1.4492222,,,,,,,,,,,,,, -22300,2.1634464,1.3427547,,,,,,,,,,,,,, -22400,2.311857,1.3846431,,,,,,,,,,,,,, -22500,3.411644,1.4380256,,,,,,,,,,,,,, -22600,3.2996507,1.465291,,,,,,,,,,,,,, -22700,1.7114676,1.3867863,,,,,,,,,,,,,, -22800,2.098394,1.4066918,,,,,,,,,,,,,, -22900,2.3095276,1.4655353,,,,,,,,,,,,,, -23000,2.199413,1.4106869,,,,,,,,,,,,,, -23100,4.222838,1.4760401,,,,,,,,,,,,,, -23200,1.9220623,1.4370857,,,,,,,,,,,,,, -23300,2.6514747,1.4129522,,,,,,,,,,,,,, -23400,1.6184065,1.3948865,,,,,,,,,,,,,, -23500,2.3842137,1.3896552,,,,,,,,,,,,,, -23600,2.41078,1.4409575,,,,,,,,,,,,,, -23637,,,0.36998138,0.1244089883425191,0.5380809,0.1556233526748216,5348.0,0.31266233,0.1005016960168992,2472.0,20182.922751426697,22234.225139141083,20182.922751426697,2049.448930501938,0.7180941104888916,0.0 -23700,3.187863,1.3674142,,,,,,,,,,,,,, -23800,2.3118968,1.3814262,,,,,,,,,,,,,, -23900,2.1974523,1.3189355,,,,,,,,,,,,,, -24000,1.7703434,1.3988175,,,,,,,,,,,,,, -24100,2.2602065,1.4365028,,,,,,,,,,,,,, -24200,2.5811818,1.3903257,,,,,,,,,,,,,, -24300,2.9723966,1.4166691,,,,,,,,,,,,,, -24400,2.1786895,1.4245336,,,,,,,,,,,,,, -24500,2.1629696,1.4025447,,,,,,,,,,,,,, -24600,1.7200326,1.3388965,,,,,,,,,,,,,, -24700,2.2898414,1.4164318,,,,,,,,,,,,,, -24800,1.8870592,1.2925851,,,,,,,,,,,,,, -24900,2.5251102,1.3898442,,,,,,,,,,,,,, -25000,2.4677162,1.3834606,,,,,,,,,,,,,, -25100,2.8690214,1.4489591,,,,,,,,,,,,,, -25200,1.9675428,1.3866022,,,,,,,,,,,,,, -25300,2.9925702,1.3786176,,,,,,,,,,,,,, -25321,,,0.3866852,0.1330297052680436,0.53082854,0.1544261756953764,5348.0,0.30099636,0.096825300103589,2472.0,21623.12057375908,23810.06982064247,21623.12057375908,2184.9646167755127,0.7669081687927246,0.0 -25400,2.5687938,1.3764095,,,,,,,,,,,,,, -25500,2.8788052,1.4112507,,,,,,,,,,,,,, -25600,2.5245326,1.3773977,,,,,,,,,,,,,, -25700,2.6513891,1.3754516,,,,,,,,,,,,,, -25800,3.8421538,1.4303131,,,,,,,,,,,,,, -25900,2.8596497,1.4088126,,,,,,,,,,,,,, -26000,2.151019,1.4167914,,,,,,,,,,,,,, -26100,2.1179175,1.3771207,,,,,,,,,,,,,, -26200,2.264943,1.3555641,,,,,,,,,,,,,, -26300,1.7839309,1.3134071,,,,,,,,,,,,,, -26400,1.8003156,1.3608347,,,,,,,,,,,,,, -26500,3.490991,1.432617,,,,,,,,,,,,,, -26600,2.2538505,1.3841629,,,,,,,,,,,,,, -26700,1.94915,1.3718549,,,,,,,,,,,,,, -26800,1.8289999,1.3198541,,,,,,,,,,,,,, -26900,2.7807791,1.3787062,,,,,,,,,,,,,, -27000,2.3978186,1.3454019,,,,,,,,,,,,,, -27004,,,0.3206794,0.1080222660607426,0.5141854,0.1482665070430694,5348.0,0.29252425,0.0944691568663294,2472.0,23063.543320417404,25387.0976998806,23063.543320417404,2321.4288444519043,0.825782060623169,0.0 -27100,2.5013154,1.2921778,,,,,,,,,,,,,, -27200,2.8965118,1.3138239,,,,,,,,,,,,,, -27300,2.3256788,1.3469687,,,,,,,,,,,,,, -27400,3.1509712,1.383617,,,,,,,,,,,,,, -27500,2.1266708,1.3430351,,,,,,,,,,,,,, -27600,2.9277961,1.3352257,,,,,,,,,,,,,, -27700,3.5144775,1.4021744,,,,,,,,,,,,,, -27800,1.9312797,1.3330288,,,,,,,,,,,,,, -27900,2.6522474,1.3617932,,,,,,,,,,,,,, -28000,1.9888768,1.3196906,,,,,,,,,,,,,, -28100,1.6055635,1.3290745,,,,,,,,,,,,,, -28200,1.805569,1.3328816,,,,,,,,,,,,,, -28300,2.01336,1.2710623,,,,,,,,,,,,,, -28400,3.1484766,1.3542395,,,,,,,,,,,,,, -28500,2.1555662,1.3079164,,,,,,,,,,,,,, -28600,2.6379623,1.3304023,,,,,,,,,,,,,, -28672,,,0.34442776,0.1134760598433213,0.5025977,0.1444915376965928,5348.0,0.28394303,0.0919911441512806,2472.0,24503.62016057968,26962.43369913101,24503.62016057968,2456.543736219406,0.8892910480499268,0.0 -28700,3.2416236,1.3769727,,,,,,,,,,,,,, -28800,2.8128452,1.338463,,,,,,,,,,,,,, -28900,2.3161743,1.3209437,,,,,,,,,,,,,, -29000,2.5300992,1.3161993,,,,,,,,,,,,,, -29100,2.4014134,1.2913765,,,,,,,,,,,,,, -29200,2.1555505,1.322758,,,,,,,,,,,,,, -29300,2.5504127,1.3522025,,,,,,,,,,,,,, -29400,2.1874726,1.3513954,,,,,,,,,,,,,, -29500,2.1509476,1.3190032,,,,,,,,,,,,,, -29600,2.5553646,1.3587126,,,,,,,,,,,,,, -29700,2.1583292,1.2819895,,,,,,,,,,,,,, -29800,2.281729,1.3173391,,,,,,,,,,,,,, -29900,3.528208,1.3389325,,,,,,,,,,,,,, -30000,2.4505708,1.307382,,,,,,,,,,,,,, -30100,1.9556847,1.2707566,,,,,,,,,,,,,, -30200,1.9909201,1.296238,,,,,,,,,,,,,, -30300,2.612081,1.3146256,,,,,,,,,,,,,, -30368,,,0.3366657,0.1129947038176884,0.48924,0.1405234752889154,5348.0,0.27818447,0.088558487193549,2472.0,25944.00483107567,28538.838061094284,25944.00483107567,2592.422735452652,0.9475498199462892,0.0 -30400,3.479038,1.3012847,,,,,,,,,,,,,, -30500,1.6023792,1.2843016,,,,,,,,,,,,,, -30600,1.8652593,1.2877138,,,,,,,,,,,,,, -30700,2.677276,1.2956189,,,,,,,,,,,,,, -30800,2.6460564,1.3011808,,,,,,,,,,,,,, -30900,1.6564211,1.2739158,,,,,,,,,,,,,, -31000,2.9718747,1.2577789,,,,,,,,,,,,,, -31100,2.1543748,1.2551665,,,,,,,,,,,,,, -31200,2.3368173,1.2957801,,,,,,,,,,,,,, -31300,2.4090254,1.3305156,,,,,,,,,,,,,, -31400,3.2875192,1.3046817,,,,,,,,,,,,,, -31500,1.7559481,1.2692312,,,,,,,,,,,,,, -31600,2.609461,1.26584,,,,,,,,,,,,,, -31700,3.1758168,1.2553089,,,,,,,,,,,,,, -31800,2.601479,1.3116568,,,,,,,,,,,,,, -31900,2.553893,1.2743013,,,,,,,,,,,,,, -32000,2.2674685,1.3202583,,,,,,,,,,,,,, -32050,,,0.31884333,0.1084934699441537,0.4744674,0.1363912837792174,5348.0,0.26197803,0.0844758596876079,2472.0,27383.97795963288,30115.50089669228,27383.97795963288,2728.975687980652,1.0051610469818115,0.0 -32100,2.0636013,1.3492357,,,,,,,,,,,,,, -32200,5.330937,1.2863352,,,,,,,,,,,,,, -32300,2.437819,1.2655836,,,,,,,,,,,,,, -32400,2.5743353,1.3249645,,,,,,,,,,,,,, -32500,2.1291277,1.2580895,,,,,,,,,,,,,, -32600,2.9510603,1.2595255,,,,,,,,,,,,,, -32700,3.074325,1.2308615,,,,,,,,,,,,,, -32800,1.9651377,1.3403277,,,,,,,,,,,,,, -32900,3.4627333,1.2356793,,,,,,,,,,,,,, -33000,2.220267,1.2507086,,,,,,,,,,,,,, -33100,2.986897,1.2253662,,,,,,,,,,,,,, -33200,2.1621,1.2110131,,,,,,,,,,,,,, -33300,2.0568066,1.2485391,,,,,,,,,,,,,, -33400,2.7019227,1.2700175,,,,,,,,,,,,,, -33500,3.204135,1.2990346,,,,,,,,,,,,,, -33600,1.9835413,1.2593486,,,,,,,,,,,,,, -33700,2.7411568,1.2534302,,,,,,,,,,,,,, -33728,,,0.26275682,0.0906634322464177,0.46272892,0.1334079959836643,5348.0,0.25633433,0.082221274348506,2472.0,28823.93050980568,31689.932183027267,28823.93050980568,2863.3213727474213,1.0570778846740725,0.0 -33800,1.7618475,1.3101288,,,,,,,,,,,,,, -33900,3.6236334,1.2204791,,,,,,,,,,,,,, -34000,2.157185,1.2646946,,,,,,,,,,,,,, -34100,2.4555068,1.2342257,,,,,,,,,,,,,, -34200,1.7468916,1.2831389,,,,,,,,,,,,,, -34300,4.1285787,1.2325579,,,,,,,,,,,,,, -34400,2.651117,1.2895695,,,,,,,,,,,,,, -34500,2.0690746,1.2626393,,,,,,,,,,,,,, -34600,3.8736744,1.23668,,,,,,,,,,,,,, -34700,1.7970433,1.2400714,,,,,,,,,,,,,, -34800,1.9718077,1.2559913,,,,,,,,,,,,,, -34900,1.9346123,1.2940993,,,,,,,,,,,,,, -35000,2.9434812,1.2994279,,,,,,,,,,,,,, -35100,2.371198,1.2221718,,,,,,,,,,,,,, -35200,3.7940996,1.2416713,,,,,,,,,,,,,, -35300,2.4908128,1.1987936,,,,,,,,,,,,,, -35400,3.95674,1.2406597,,,,,,,,,,,,,, -35434,,,0.23715778,0.0822559468865204,0.45252073,0.1306467652084922,5348.0,0.24987762,0.0803729206020352,2472.0,30263.79677867889,33264.56129407883,30263.79677867889,2997.9378378391266,1.1216213703155518,0.0 -35500,3.3996663,1.2035706,,,,,,,,,,,,,, -35600,2.0465941,1.2209945,,,,,,,,,,,,,, -35700,1.9488319,1.2148578,,,,,,,,,,,,,, -35800,2.909103,1.2398611,,,,,,,,,,,,,, -35900,3.2966223,1.2486209,,,,,,,,,,,,,, -36000,3.8471344,1.2008266,,,,,,,,,,,,,, -36100,2.1874433,1.1926237,,,,,,,,,,,,,, -36200,2.181227,1.2716595,,,,,,,,,,,,,, -36300,1.8928149,1.2054563,,,,,,,,,,,,,, -36400,5.675743,1.1901733,,,,,,,,,,,,,, -36500,2.556581,1.1767017,,,,,,,,,,,,,, -36600,2.469177,1.1867843,,,,,,,,,,,,,, -36700,1.8420767,1.1729608,,,,,,,,,,,,,, -36800,2.193925,1.1472309,,,,,,,,,,,,,, -36900,2.3896139,1.1890672,,,,,,,,,,,,,, -37000,1.6170136,1.2025706,,,,,,,,,,,,,, -37095,,,0.2662122,0.0921220819411118,0.4400807,0.1263697539028935,5348.0,0.2403281,0.0776917920906709,2472.0,31703.80060362816,34838.64279150963,31703.80060362816,3131.8751130104065,1.180323600769043,0.0 -37100,2.3361835,1.1661474,,,,,,,,,,,,,, -37200,1.9316827,1.1321003,,,,,,,,,,,,,, -37300,2.3031058,1.253222,,,,,,,,,,,,,, -37400,2.5597246,1.2742499,,,,,,,,,,,,,, -37500,3.6544251,1.1502382,,,,,,,,,,,,,, -37600,2.073495,1.225052,,,,,,,,,,,,,, -37700,2.58946,1.1588856,,,,,,,,,,,,,, -37800,1.8835222,1.2097727,,,,,,,,,,,,,, -37900,3.189687,1.2108743,,,,,,,,,,,,,, -38000,3.1082604,1.1968851,,,,,,,,,,,,,, -38100,2.1835465,1.1932542,,,,,,,,,,,,,, -38200,2.3485954,1.1775479,,,,,,,,,,,,,, -38300,2.24339,1.139957,,,,,,,,,,,,,, -38400,2.2114627,1.1909893,,,,,,,,,,,,,, -38500,2.5765257,1.1411531,,,,,,,,,,,,,, -38600,1.9615417,1.1922364,,,,,,,,,,,,,, -38700,2.955509,1.2159698,,,,,,,,,,,,,, -38792,,,0.23776919,0.0816136566212724,0.43106434,0.1244967512092452,5348.0,0.23511161,0.0753356488534113,2472.0,33144.33669090271,36414.3170132637,33144.33669090271,3266.877792835236,1.2335264682769775,0.0 -38800,2.2055597,1.2253611,,,,,,,,,,,,,, -38900,4.872925,1.2262586,,,,,,,,,,,,,, -39000,2.3310597,1.1751475,,,,,,,,,,,,,, -39100,1.8039068,1.1534587,,,,,,,,,,,,,, -39200,2.0933418,1.2302811,,,,,,,,,,,,,, -39300,3.6893318,1.2079655,,,,,,,,,,,,,, -39400,3.5538301,1.1890799,,,,,,,,,,,,,, -39500,2.481145,1.1726943,,,,,,,,,,,,,, -39600,1.8677827,1.1234739,,,,,,,,,,,,,, -39700,1.7503325,1.199142,,,,,,,,,,,,,, -39800,2.8347433,1.2074256,,,,,,,,,,,,,, -39900,3.7255669,1.2384366,,,,,,,,,,,,,, -40000,3.2800393,1.1819841,,,,,,,,,,,,,, -40100,3.3125074,1.2054069,,,,,,,,,,,,,, -40200,2.940457,1.1983771,,,,,,,,,,,,,, -40300,1.5420908,1.1824244,,,,,,,,,,,,,, -40400,3.2996287,1.2383344,,,,,,,,,,,,,, -40493,,,0.245853,0.0821549486668751,0.42453703,0.1226913310870174,5348.0,0.2303546,0.073954461438466,2472.0,34584.92214655876,37992.03388214111,34584.92214655876,3403.8705430030823,1.2884671688079834,0.0 -40500,2.6214957,1.1687092,,,,,,,,,,,,,, -40600,1.7395962,1.1584071,,,,,,,,,,,,,, -40700,2.1507826,1.1264765,,,,,,,,,,,,,, -40800,2.5817363,1.1800663,,,,,,,,,,,,,, -40900,2.4159327,1.1689152,,,,,,,,,,,,,, -41000,3.0038097,1.0973221,,,,,,,,,,,,,, -41100,2.2418542,1.1612183,,,,,,,,,,,,,, -41200,2.7508328,1.1357701,,,,,,,,,,,,,, -41300,2.6633008,1.1593689,,,,,,,,,,,,,, -41400,4.583047,1.1285005,,,,,,,,,,,,,, -41500,2.7472997,1.1380895,,,,,,,,,,,,,, -41600,2.697177,1.1695117,,,,,,,,,,,,,, -41700,3.3037686,1.1819857,,,,,,,,,,,,,, -41800,1.3967395,1.1236378,,,,,,,,,,,,,, -41900,2.8711681,1.116848,,,,,,,,,,,,,, -42000,1.9421113,1.199407,,,,,,,,,,,,,, -42100,2.779494,1.1871037,,,,,,,,,,,,,, -42166,,,0.19963317,0.0690241674824582,0.41599953,0.119881827046545,5348.0,0.22683771,0.072248288749416,2472.0,36025.603747844696,39568.92971038818,36025.603747844696,3539.95130443573,1.3435046672821045,0.0 -42200,1.7416813,1.1055902,,,,,,,,,,,,,, -42300,2.5265498,1.1254131,,,,,,,,,,,,,, -42400,2.4563327,1.1540731,,,,,,,,,,,,,, -42500,1.9267952,1.1814868,,,,,,,,,,,,,, -42600,2.2605352,1.0751992,,,,,,,,,,,,,, -42700,2.316642,1.1804173,,,,,,,,,,,,,, -42800,3.8810258,1.1641468,,,,,,,,,,,,,, -42900,4.877115,1.1467357,,,,,,,,,,,,,, -43000,1.8070699,1.2093275,,,,,,,,,,,,,, -43100,2.1196299,1.148168,,,,,,,,,,,,,, -43200,2.7274418,1.1212568,,,,,,,,,,,,,, -43300,2.0236487,1.1443036,,,,,,,,,,,,,, -43400,2.0204673,1.13092,,,,,,,,,,,,,, -43500,6.681984,1.1953605,,,,,,,,,,,,,, -43600,2.5208614,1.1541524,,,,,,,,,,,,,, -43700,1.8666756,1.1300437,,,,,,,,,,,,,, -43800,9.776022,1.1217419,,,,,,,,,,,,,, -43858,,,0.19884929,0.0696530193958415,0.41424644,0.1192156559854021,5348.0,0.22416562,0.0712123981882071,2472.0,37465.7214448452,41143.13983345032,37465.7214448452,3673.904772281647,1.4003708362579346,0.0 -43900,2.5643516,1.1262488,,,,,,,,,,,,,, -44000,1.8983699,1.1342235,,,,,,,,,,,,,, -44100,2.7198722,1.1547194,,,,,,,,,,,,,, -44200,2.5709734,1.1706852,,,,,,,,,,,,,, -44300,2.6151094,1.1201501,,,,,,,,,,,,,, -44400,2.325447,1.1064252,,,,,,,,,,,,,, -44500,2.553275,1.1273353,,,,,,,,,,,,,, -44600,3.654486,1.1209638,,,,,,,,,,,,,, -44700,6.7093387,1.179887,,,,,,,,,,,,,, -44800,1.8456391,1.1615473,,,,,,,,,,,,,, -44900,3.557736,1.1152289,,,,,,,,,,,,,, -45000,4.398363,1.158078,,,,,,,,,,,,,, -45100,2.038878,1.1476153,,,,,,,,,,,,,, -45200,5.875288,1.1552075,,,,,,,,,,,,,, -45300,2.3726935,1.1632367,,,,,,,,,,,,,, -45400,3.2580013,1.0750841,,,,,,,,,,,,,, -45500,2.6250541,1.1901085,,,,,,,,,,,,,, -45559,,,0.21639155,0.0758470642510123,0.40952334,0.118057097618197,5348.0,0.22256012,0.0708671013344707,2472.0,38905.63683629036,42717.67965936661,38905.63683629036,3808.381865978241,1.463677167892456,0.0 -45600,2.889078,1.1844239,,,,,,,,,,,,,, -45700,2.0147092,1.1109675,,,,,,,,,,,,,, -45800,3.3497872,1.1420019,,,,,,,,,,,,,, -45900,2.3332572,1.1309856,,,,,,,,,,,,,, -46000,2.5551434,1.1782155,,,,,,,,,,,,,, -46100,3.0336447,1.1227216,,,,,,,,,,,,,, -46200,2.572378,1.1155179,,,,,,,,,,,,,, -46300,3.1547039,1.1168884,,,,,,,,,,,,,, -46400,2.343409,1.1626161,,,,,,,,,,,,,, -46500,3.60673,1.1908226,,,,,,,,,,,,,, -46600,2.5894709,1.158323,,,,,,,,,,,,,, -46700,2.2103112,1.179596,,,,,,,,,,,,,, -46800,2.0047736,1.1057904,,,,,,,,,,,,,, -46900,1.9321479,1.1593839,,,,,,,,,,,,,, -47000,3.3765492,1.1077394,,,,,,,,,,,,,, -47100,1.6833522,1.1074891,,,,,,,,,,,,,, -47200,2.1362329,1.152596,,,,,,,,,,,,,, -47255,,,0.14392106,0.0514502725017065,0.40911275,0.1177288394141556,5348.0,0.22253822,0.070765543436313,2472.0,40346.24051046372,44301.23940682411,40346.24051046372,3951.201397418976,1.5179204940795898,0.0 -47300,2.4503574,1.1394595,,,,,,,,,,,,,, -47400,3.367509,1.1401918,,,,,,,,,,,,,, -47500,3.0219357,1.104136,,,,,,,,,,,,,, -47600,2.7378664,1.0976082,,,,,,,,,,,,,, -47700,2.444543,1.1436592,,,,,,,,,,,,,, -47800,2.8784187,1.1606178,,,,,,,,,,,,,, -47900,2.505443,1.1507068,,,,,,,,,,,,,, -48000,,,0.13062766,0.0461371945851329,0.40922388,0.1177578033733357,5348.0,0.22244853,0.0706030507992606,2472.0,40963.57310843468,45055.39094734192,40963.57310843468,4087.920179843903,1.5780909061431885,0.0 -48000,,,,,,,,,,,40963.57310843468,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index fbb1ceb50..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -304.7409582138061,0.0,18.134385585784912,1,0,18.134385585784912,0.5047287344932556,0.7438743114471436,0.0269396591122225,43793,322.87538480758667,0.5092470049858093,0.7391904592514038,0.0221835560731189,0.5064646005630493,0.7422076463699341,0.0253144192803209,43793 -421.0076684951782,0.029463768005371,258.18942499160767,747,0,258.18942499160767,0.983142077922821,0.0816605240106582,0.0369262199382961,43793,679.2467248439789,0.9867319464683532,0.0704648867249488,0.0333204718944733,0.9841179251670836,0.0788265094161033,0.0337082857474339,43793 -542.6084208488464,0.0602407455444335,498.358202457428,1500,0,498.358202457428,0.9835611581802368,0.0604051724076271,0.0883908861125532,43793,1041.0658912658691,0.9873274564743042,0.0489182732999324,0.0835694140901646,0.9845482110977172,0.0574894733726978,0.0848161563026899,43793 -665.5626842975616,0.0868616104125976,738.4302983283997,2245,0,738.4302983283997,0.9840468168258668,0.0566889606416225,0.1269769745888034,43793,1404.1376762390137,0.9878538846969604,0.0445154048502445,0.1271006322314452,0.9850134253501892,0.0538621060550212,0.1286356176045245,43793 -790.7114028930664,0.1153392791748046,978.480020046234,2983,0,978.480020046234,0.9842873215675354,0.0547768212854862,0.1546334280093363,43793,1769.383313894272,0.987955927848816,0.0427691042423248,0.1545525683655089,0.9852378964424132,0.0521178096532821,0.1480877293086972,43793 -913.8746719360352,0.1453778743743896,1218.683366060257,3727,0,1218.683366060257,0.9844751358032228,0.0530382059514522,0.1746518429197931,43793,2132.798628091812,0.9880806803703308,0.0414005331695079,0.186444753574127,0.9853852987289428,0.0503360740840435,0.1722589120405069,43793 -1039.562193632126,0.1721065044403076,1458.725549697876,4481,0,1458.725549697876,0.9846848845481871,0.0517690926790237,0.1815007698173546,43793,2498.5742666721344,0.9885159134864808,0.0399757809937,0.2038349252000051,0.9855237007141112,0.0492455773055553,0.1806705486653925,43793 -1166.363187789917,0.2028937339782714,1698.9222049713137,5228,0,1698.9222049713137,0.984925389289856,0.051138449460268,0.1908380836181807,43793,2865.626400947571,0.9886970520019532,0.0385160632431507,0.2270452773609716,0.9857315421104432,0.0484458357095718,0.192269829851199,43793 -1292.633964061737,0.2294843196868896,1939.038080215454,5996,0,1939.038080215454,0.9850820899009703,0.0504555068910121,0.2009072419200461,43793,3232.058907032013,0.9888204336166382,0.0382240451872348,0.2375763856274924,0.9858850240707396,0.0478058569133281,0.2007561521562516,43793 -1414.3643636703491,0.256913423538208,2179.076942920685,6757,0,2179.076942920685,0.9850327968597412,0.0497962236404418,0.2034223256774653,43793,3593.875339746475,0.9890872836112976,0.0374721959233284,0.2541572959008027,0.9859747290611268,0.0472970940172672,0.2038481776836391,43793 -1540.1427791118622,0.2850782871246338,2419.183886051178,7522,0,2419.183886051178,0.9851840138435364,0.0496272966265678,0.212045926660142,43793,3959.8080909252167,0.989159345626831,0.0369904600083828,0.2583130676710853,0.9860392808914183,0.0471784397959709,0.2126993248703816,43793 -1665.422149181366,0.3125410079956054,2659.2870647907257,8281,0,2659.2870647907257,0.9853053092956544,0.0490796826779842,0.2183831775864873,43793,4325.237766027451,0.9891362190246582,0.037117350846529,0.2612576523422221,0.9862495064735411,0.0464592687785625,0.2196623737679801,43793 -1792.234255552292,0.341059923171997,2899.4726645946503,9050,0,2899.4726645946503,0.985432505607605,0.0485740005970001,0.2298582522393411,43793,4692.283871173859,0.9894993901252748,0.0358033999800682,0.2808998759655631,0.9862414002418518,0.0460271462798118,0.2250249947677923,43793 -1920.9666464328768,0.368741512298584,3139.7406141757965,9810,0,3139.7406141757965,0.9851722121238708,0.0490936003625392,0.2129779767804349,43793,5061.331539392471,0.9892947673797609,0.0360191650688648,0.2829000146369129,0.9861679077148438,0.04629398137331,0.2249200090916605,43793 -2055.779053211212,0.3968634605407715,3379.834892272949,10564,0,3379.834892272949,0.9855567812919616,0.048363097012043,0.2423461242788807,43793,5436.286778688431,0.9896652698516846,0.0348614193499088,0.2978428145748989,0.9864857792854308,0.0457414388656616,0.2394883731962805,43793 -2186.676379680633,0.4272780418395996,3619.9956657886505,11314,0,3619.9956657886505,0.9855571985244752,0.0484842956066131,0.2412491077403872,43793,5807.399445056915,0.989776849746704,0.0341957435011863,0.3297334600062838,0.9864062070846558,0.0457753986120224,0.2374500720500937,43793 -2317.691652059555,0.4562311172485351,3860.237807989121,12064,0,3860.237807989121,0.9856915473937988,0.0479581244289875,0.2449247904619788,43793,6178.707715749741,0.9900395274162292,0.0332000404596328,0.3556511829583499,0.986559271812439,0.0451599508523941,0.2459968183186453,43793 -2446.9008860588074,0.4860799312591553,4100.330991983414,12819,0,4100.330991983414,0.985736608505249,0.047790914773941,0.2415669797862868,43793,6548.0606780052185,0.9901712536811828,0.0327010117471218,0.3520209612511431,0.9864898324012756,0.0451404564082622,0.2395079632055661,43793 -2573.6196570396423,0.514338493347168,4340.384344100952,13572,0,4340.384344100952,0.98572438955307,0.0475961826741695,0.2534813540282405,43793,6914.880608558655,0.9905293583869934,0.0316002890467643,0.368745897355923,0.9866258502006532,0.0448617041110992,0.2575587049856996,43793 -2705.533165216446,0.5428283214569092,4580.391752004623,14326,0,4580.391752004623,0.9858802556991576,0.0469530262053012,0.2496725852030368,43793,7286.849791765213,0.9907206296920776,0.0310354493558406,0.3946692529617591,0.986632764339447,0.0443393066525459,0.251086335599621,43793 -2835.9652168750763,0.5726001262664795,4820.388184309006,15080,0,4820.388184309006,0.9857193827629088,0.0474602282047271,0.2438626242540415,43793,7657.327804088592,0.9904849529266356,0.0318308845162391,0.376709458590151,0.9865093231201172,0.0448063239455223,0.2508813830749123,43793 -2965.335473537445,0.6022679805755615,5060.620755910873,15827,0,5060.620755910873,0.985835611820221,0.0477200932800769,0.2473249334432067,43793,8026.981867313385,0.9903941750526428,0.0317538008093833,0.3847419663950548,0.98667573928833,0.0448875240981578,0.2546998476400984,43793 -3095.985473155976,0.6344373226165771,5300.623321056366,16575,0,5300.623321056366,0.985753893852234,0.0474314987659454,0.2495560120245449,43793,8397.686516284943,0.9903529286384584,0.0319628044962883,0.3737279558476397,0.9867126941680908,0.0445161461830139,0.2609148835963837,43793 -3227.5997779369354,0.6642558574676514,5540.597485303879,17330,0,5540.597485303879,0.9857543110847472,0.047792412340641,0.2385126380118646,43793,8769.324578762054,0.990510880947113,0.0314424484968185,0.3802214268304177,0.9867005348205566,0.044713731855154,0.261765784988841,43793 -3361.631055355072,0.6942794322967529,5780.7910261154175,18077,0,5780.7910261154175,0.9859607219696044,0.0470982193946838,0.2513989695053806,43793,9143.603893518448,0.990651547908783,0.0309187453240156,0.3985708604516049,0.9868072867393494,0.0443336330354213,0.266102079928784,43793 -3493.475923061371,0.7263171672821045,6020.8155081272125,18829,0,6020.8155081272125,0.9857859015464784,0.0480262599885463,0.2439402673153815,43793,9515.525021791458,0.9905492067337036,0.0310190469026565,0.3952029213450361,0.9866875410079956,0.0448851659893989,0.2635644689347876,43793 -3622.156748771668,0.7575254440307617,6260.985202074051,19584,0,6260.985202074051,0.9859375357627868,0.0473642535507679,0.2539158233362041,43793,9884.426526784897,0.9906980395317078,0.0305677652359008,0.4242602740789738,0.9866716861724854,0.0448466949164867,0.2587407145618029,43793 -3750.527698040009,0.7873868942260742,6501.081083774567,20334,0,6501.081083774567,0.9859299659729004,0.0473265685141086,0.2548134925721332,43793,10252.942810297012,0.9910008907318116,0.0294814500957727,0.4198984412224145,0.9867601990699768,0.0444596223533153,0.2566778579806512,43793 -3876.483735561371,0.8186118602752686,6741.23437333107,21083,0,6741.23437333107,0.9858819246292114,0.047335647046566,0.2524632223612293,43793,10619.10322880745,0.9911450147628784,0.0289554018527269,0.438380159959826,0.9866806268692015,0.0445852689445018,0.2575463262048126,43793 -4001.835470676422,0.8482942581176758,6981.344316244125,21830,0,6981.344316244125,0.9859269857406616,0.0477520860731601,0.2534212336427425,43793,10984.614884376526,0.9911766052246094,0.0291900876909494,0.4514105935616376,0.9867467880249025,0.0448013246059417,0.2649502321140494,43793 -4129.834286689758,0.8798651695251465,7221.317653656006,22585,0,7221.317653656006,0.9858440160751344,0.04732321575284,0.246868166257547,43793,11352.638570547104,0.9911478757858276,0.0292267147451639,0.422738707506834,0.9867796897888184,0.044557336717844,0.2584770706554278,43793 -4261.829008817673,0.9099385738372804,7461.457053661346,23334,0,7461.457053661346,0.9857808351516724,0.0471952743828296,0.2496342910386312,43793,11724.822311639786,0.9909042119979858,0.0301217660307884,0.4214129838863844,0.9866956472396852,0.0443504191935062,0.2666798122420432,43793 -4390.558969974518,1.252744436264038,7701.260676383972,24086,0,7701.260676383972,0.9860095381736756,0.046928908675909,0.2543803166805494,43793,12093.718660354614,0.99093496799469,0.0298298224806785,0.4233810275004899,0.9868283867836,0.0444389320909976,0.2710360070079665,43793 -4518.463991880417,1.2852094173431396,7941.512230634689,24831,0,7941.512230634689,0.9857420921325684,0.0475777871906757,0.2484946569880513,43793,12461.927044153214,0.9908654093742372,0.0299554914236068,0.4146248075562608,0.9866400361061096,0.0446502827107906,0.2619743777067673,43793 -4643.603357791901,1.3162312507629397,8181.642039299011,25587,0,8181.642039299011,0.985908031463623,0.047651320695877,0.2526763650455053,43793,12827.246604681017,0.9910185933113098,0.0293855872005224,0.432371964851221,0.9868754744529724,0.0446766056120395,0.2670693539503197,43793 -4772.377794981003,1.3500714302062988,8421.819328069687,26330,0,8421.819328069687,0.9859522581100464,0.0470491051673889,0.2559921718547284,43793,13196.254113912582,0.991351306438446,0.0284819547086954,0.4572311420610337,0.986766278743744,0.0443845465779304,0.2718749601531016,43793 -4902.130871295929,1.381518840789795,8661.81754231453,27090,0,8661.81754231453,0.985910177230835,0.0474141016602516,0.2518513987897968,43793,13566.056858778,0.9912184476852416,0.0286246985197067,0.4483122074480636,0.9867537021636964,0.0445280410349369,0.2673226801916273,43793 -5027.87125992775,1.414813995361328,8902.068484067917,27850,0,8902.068484067917,0.9858503341674804,0.0474634394049644,0.2518078150224259,43793,13932.1010928154,0.9915094375610352,0.0278093516826629,0.4788816882412824,0.9867618083953856,0.0447036251425743,0.2695098683814103,43793 -5155.106744766235,1.451047420501709,9142.265007972715,28604,0,9142.265007972715,0.9858406782150269,0.0476547665894031,0.2527613423048759,43793,14299.589767217636,0.9915740489959716,0.0273893792182207,0.4774020733554986,0.9868312478065492,0.0447375103831291,0.263646713481197,43793 -5276.8527200222015,1.4818134307861328,9382.218895435331,29360,0,9382.218895435331,0.9858920574188232,0.0477716363966465,0.2477014584042329,43793,14661.340048074722,0.9915950894355774,0.0272927861660718,0.4815457537935733,0.9867784976959229,0.0447870045900344,0.2628334545799982,43793 -5401.89927482605,1.514939308166504,9622.306058883669,30109,0,9622.306058883669,0.9860133528709412,0.0474919974803924,0.254933399067772,43793,15026.52691435814,0.9915949702262878,0.0276272725313901,0.4659208937616314,0.9868267774581908,0.0446105264127254,0.2654210941093273,43793 -5528.822056770325,1.5469331741333008,9862.304508924484,30861,0,9862.304508924484,0.9858651161193848,0.0474503748118877,0.2607415900214805,43793,15393.499891757963,0.9914368391036988,0.0282070115208625,0.4585877945100001,0.9867272973060608,0.0447770431637764,0.2648167534202912,43793 -5652.796845912933,1.580089807510376,10102.278964281082,31613,0,10102.278964281082,0.9859716296195984,0.0472341999411582,0.2509376027609071,43793,15757.501960992811,0.9915117025375366,0.0278496518731117,0.4719675642555955,0.9867539405822754,0.0446218624711036,0.2626403302936321,43793 -5779.776381969452,1.6129043102264404,10342.485783815384,32375,0,10342.485783815384,0.9860483407974244,0.0476890802383422,0.2538761775615324,43793,16124.741267204285,0.9914616346359252,0.0277968235313892,0.4661418373102133,0.986866533756256,0.0448258966207504,0.264903256071144,43793 -5905.846033334732,1.644451379776001,10582.662477493286,33127,0,10582.662477493286,0.9858246445655824,0.0473682954907417,0.2573437097629211,43793,16491.039191007614,0.9914880990982056,0.027717201039195,0.476126487256679,0.9866778254508972,0.044719535857439,0.2619186380504609,43793 -6031.545439004898,1.678142547607422,10822.833097219467,33880,0,10822.833097219467,0.986042022705078,0.0473588556051254,0.2607110022505314,43793,16856.96304345131,0.9916728138923644,0.0270904246717691,0.4776111416574321,0.9869416356086732,0.0447549112141132,0.2737131680794807,43793 -6158.037148237228,1.71140456199646,11062.853231668472,34630,0,11062.853231668472,0.9858322143554688,0.0475003346800804,0.256020744104348,43793,17223.528245449066,0.9918762445449828,0.0265015661716461,0.4959326874687609,0.9868482947349548,0.0447657331824302,0.2684190313777011,43793 -6286.22608089447,1.7442903518676758,11302.85993719101,35374,0,11302.85993719101,0.9859994649887084,0.047745082527399,0.2512933375163518,43793,17591.77684688568,0.9919270873069764,0.0261572357267141,0.5077010842706066,0.9868158102035522,0.0450122393667697,0.2640386329595336,43793 -6413.524546384811,1.7778651714324951,11542.92440032959,36121,0,11542.92440032959,0.9859430193901062,0.0478298552334308,0.2534297712322391,43793,17959.193194389343,0.9923211336135864,0.0249617137014865,0.5338029568777007,0.986887276172638,0.0448975339531898,0.2611574561332976,43793 -6544.811107635498,1.813506364822388,11782.954986095428,36868,0,11782.954986095428,0.986023485660553,0.0480411872267723,0.2562980574986526,43793,18330.56738352776,0.9920997619628906,0.0254573356360197,0.5130001361147138,0.9868791699409484,0.0451982542872428,0.2632707293366008,43793 -6670.475694179535,1.847543478012085,12023.092227220535,37621,0,12023.092227220535,0.9859792590141296,0.0478082187473773,0.2564197364137381,43793,18696.423310041428,0.991813063621521,0.0264731142669916,0.5066786252637335,0.9869157075881958,0.0449858866631984,0.270892764058419,43793 -6795.817242145538,1.8814423084259035,12263.215424776075,38381,0,12263.215424776075,0.985987663269043,0.0477124713361263,0.2526040420471432,43793,19061.941546201702,0.9919701814651488,0.0260833278298378,0.4938227566410164,0.9869075417518616,0.0448095016181468,0.2723583684879261,43793 -6917.844294548035,1.9166576862335205,12503.383272647858,39140,0,12503.383272647858,0.9859459400177002,0.0479544922709465,0.25299035868724,43793,19424.19128537178,0.9919205904006958,0.0262812860310077,0.5019598940760397,0.9867547154426576,0.0452037937939167,0.2664720998561808,43793 -7043.957558870316,1.951773881912232,12743.645746707916,39896,0,12743.645746707916,0.9860289096832277,0.0478225275874137,0.2570222310540754,43793,19790.62196087837,0.9918465614318848,0.0261474233120679,0.5186325621587466,0.9868633151054382,0.0452602319419384,0.2675940791495391,43793 -7164.391726016998,1.987480878829956,12983.84726190567,40660,0,12983.84726190567,0.9858823418617249,0.0478564389050006,0.2583532536156636,43793,20151.31301140785,0.9921106100082396,0.0257124546915292,0.50393287827044,0.9866834878921508,0.0454183742403984,0.2655199069094247,43793 -7288.951065063477,2.023068428039551,13223.94919514656,41411,0,13223.94919514656,0.985870122909546,0.0485646724700927,0.2541695344242801,43793,20516.02917122841,0.9921404719352722,0.0253216736018657,0.5251935954997733,0.9867870211601256,0.0456630028784275,0.265891050572245,43793 -7416.767649650574,2.0580596923828125,13464.0667886734,42171,0,13464.0667886734,0.9859126806259156,0.0486242473125457,0.2552605737327201,43793,20884.01805138588,0.992421567440033,0.0244387220591306,0.5435246945200095,0.9867402911186218,0.0457667410373687,0.2670472319105248,43793 -7546.203330516815,2.094203472137451,13704.141545772552,42920,0,13704.141545772552,0.9860095381736756,0.0487369447946548,0.2568638627028144,43793,21253.58636689186,0.9927170872688292,0.0233779214322567,0.5733930541995389,0.9868364930152892,0.0459333881735801,0.2704952414247853,43793 -7669.8636927604675,2.1300570964813232,13944.12907576561,43675,0,13944.12907576561,0.9858528971672058,0.0487676002085208,0.2534971891102581,43793,21617.28958058357,0.9927061200141908,0.023393128067255,0.5666941320422774,0.9868308305740356,0.0457053445279598,0.2666772814062301,43793 -7793.7537133693695,2.167064905166626,14184.10730624199,44425,0,14184.10730624199,0.9857774972915648,0.0494213365018367,0.2502955025357254,43793,21981.214041233063,0.9929761290550232,0.0225698053836822,0.5881898919000075,0.9867346286773682,0.0464401394128799,0.2643463754657173,43793 -7919.652441978455,2.202032327651977,14424.139218568802,45172,0,14424.139218568802,0.9857913851737976,0.0486513152718544,0.2532316248694131,43793,22347.200195789337,0.9928836226463318,0.0231638737022876,0.5603069323792206,0.9867216348648072,0.0458410233259201,0.2698805055719386,43793 -8042.348666667938,2.241867780685425,14664.235578775406,45928,0,14664.235578775406,0.9859084486961364,0.0492966175079345,0.2490588585001329,43793,22710.05312728882,0.9925437569618224,0.0238797627389431,0.5548479854623927,0.986866533756256,0.0462991744279861,0.2675998202951851,43793 -8164.049585580826,2.278223991394043,14904.3414375782,46685,0,14904.3414375782,0.985787570476532,0.0489269718527793,0.255223206791752,43793,23071.91637802124,0.9926679134368896,0.023723516613245,0.5522737730499209,0.986710250377655,0.0460764281451702,0.2673692998549654,43793 -8290.571984052658,2.3142735958099365,15144.60926938057,47437,0,15144.60926938057,0.9859118461608888,0.0494793020188808,0.2525168562310033,43793,23438.763329267505,0.9925719499588012,0.0237914603203535,0.5576811836200534,0.986846685409546,0.0464228764176368,0.2657642561931333,43793 -8415.009822368622,2.3507189750671387,15384.76617383957,48189,0,15384.76617383957,0.9858225584030152,0.0492965169250965,0.2521494930765173,43793,23803.41541624069,0.9928362369537354,0.0230195168405771,0.558678018020372,0.9867460131645204,0.0464305877685546,0.2732741377262889,43793 -8537.169777154922,2.38702654838562,15624.714596033096,48939,0,15624.714596033096,0.9858726859092712,0.0493974350392818,0.2561368281989875,43793,24165.57988381385,0.9928126335144044,0.0228299181908369,0.5758678220084759,0.986866533756256,0.0463282391428947,0.2763794516470148,43793 -8657.844844341278,2.4283759593963623,15864.714323282242,49683,0,15864.714323282242,0.9858827590942384,0.0499323159456253,0.2530611175182334,43793,24526.31983613968,0.9930967688560486,0.0219309777021408,0.5948772431380066,0.986885666847229,0.0467839241027832,0.2694161972938923,43793 -8778.289328336716,2.465296030044556,16104.945301055908,50443,0,16104.945301055908,0.9857075810432434,0.0498334541916847,0.2552904516568788,43793,24887.052001953125,0.9935494065284728,0.0208131670951843,0.6246684589240922,0.9867971539497375,0.0465557985007762,0.2757770272471453,43793 -8900.391793251038,2.501399278640747,16345.135685920715,51201,0,16345.135685920715,0.9858195781707764,0.0507305338978767,0.2476450730734251,43793,25249.40115904808,0.993834674358368,0.019873609766364,0.6402711633729733,0.986703395843506,0.0474903881549835,0.2653409913394465,43793 -9021.43328166008,2.5409858226776123,16585.19568347931,51945,0,16585.19568347931,0.9858056902885436,0.0504932701587677,0.253286375824104,43793,25610.56532263756,0.9937007427215576,0.020169697701931,0.6355450158745422,0.9866883754730223,0.0474072881042957,0.2672529701239556,43793 -9140.794448375702,2.5773093700408936,16825.145767211914,52698,0,16825.145767211914,0.9858128428459167,0.0505554601550102,0.2467603426558873,43793,25969.932988643646,0.9938126802444458,0.019981313496828,0.6342326759145145,0.9867126941680908,0.0471597947180271,0.2665871337103819,43793 -9264.3684155941,2.6135025024414062,17065.106918811798,53458,0,17065.106918811798,0.9857555627822876,0.0508050434291362,0.2525310318683968,43793,26333.523721456528,0.993604838848114,0.0205391012132167,0.6163999862960505,0.9866956472396852,0.047807291150093,0.2672979128963146,43793 -9388.612121343613,2.655043125152588,17305.169719457626,54198,0,17305.169719457626,0.9857187271118164,0.0511037185788154,0.2504868933628825,43793,26697.89265561104,0.993502676486969,0.0207406654953956,0.6146754178331644,0.986743152141571,0.047809213399887,0.2679407266434734,43793 -9512.002286434174,2.691478490829468,17545.279770612717,54955,0,17545.279770612717,0.985710084438324,0.0512884557247161,0.2556887617367147,43793,27061.449331760406,0.9935206770896912,0.0204508285969495,0.6152469816588861,0.9867601990699768,0.0481349416077137,0.2696156510939548,43793 -9630.57388138771,2.7295162677764893,17785.269641399384,55711,0,17785.269641399384,0.9857816696166992,0.0518056154251098,0.2560246636285639,43793,27420.06883215904,0.9935681819915771,0.0202482622116804,0.6322992889381533,0.9867435693740844,0.0486103147268295,0.2674156022578643,43793 -9754.26741719246,2.77383804321289,18025.47614264488,56465,0,18025.47614264488,0.98571515083313,0.0521701723337173,0.2486022880193414,43793,27784.033193588257,0.9937283992767334,0.019526956602931,0.6458244588409702,0.986622989177704,0.0490015819668769,0.2654806746801502,43793 -9873.88046002388,2.812410593032837,18265.4832341671,57218,0,18265.4832341671,0.9856405854225159,0.05209792032837868,0.25158631821213995,43793,28143.712094783783,0.994199812412262,0.01858672685921192,0.6527285246331709,0.9865426421165466,0.04894685745239258,0.2711905727317665,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index f086c8e5c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,658 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3602805,0.73957366,,,,,,,,,,,,,,,,, -1,,,0.5092470049858093,0.7391904592514038,0.0221835560731189,0.5064646005630493,0.7422076463699341,0.0253144192803209,43793.0,0.5047287344932556,0.7438743114471436,0.0269396591122225,43793.0,18.134385585784912,322.87538480758667,18.134385585784912,304.7409582138061,0.0,0.0 -100,0.6410757,0.49086633,,,,,,,,,,,,,,,,, -200,0.40028146,0.36105412,,,,,,,,,,,,,,,,, -300,0.2900596,0.25938684,,,,,,,,,,,,,,,,, -400,0.19244447,0.17325296,,,,,,,,,,,,,,,,, -500,0.11756218,0.12124865,,,,,,,,,,,,,,,,, -600,0.074283645,0.08897363,,,,,,,,,,,,,,,,, -700,0.04829801,0.07373614,,,,,,,,,,,,,,,,, -747,,,0.9867319464683532,0.0704648867249488,0.0333204718944733,0.9841179251670836,0.0788265094161033,0.0337082857474339,43793.0,0.983142077922821,0.0816605240106582,0.0369262199382961,43793.0,258.18942499160767,679.2467248439789,258.18942499160767,421.0076684951782,0.029463768005371,0.0 -800,0.08523462,0.06405493,,,,,,,,,,,,,,,,, -900,0.0935932,0.058690917,,,,,,,,,,,,,,,,, -1000,0.10805061,0.05908417,,,,,,,,,,,,,,,,, -1100,0.038734186,0.053865183,,,,,,,,,,,,,,,,, -1200,0.18394944,0.050147902,,,,,,,,,,,,,,,,, -1300,0.180009,0.04961869,,,,,,,,,,,,,,,,, -1400,0.09363408,0.049552456,,,,,,,,,,,,,,,,, -1500,,,0.9873274564743042,0.0489182732999324,0.0835694140901646,0.9845482110977172,0.0574894733726978,0.0848161563026899,43793.0,0.9835611581802368,0.0604051724076271,0.0883908861125532,43793.0,498.358202457428,1041.0658912658691,498.358202457428,542.6084208488464,0.0602407455444335,0.0 -1500,0.362386,0.050941344,,,,,,,,,,,,,,,,, -1600,0.3347069,0.04505878,,,,,,,,,,,,,,,,, -1700,0.16034609,0.050992392,,,,,,,,,,,,,,,,, -1800,0.1698825,0.054396052,,,,,,,,,,,,,,,,, -1900,0.11209466,0.049737602,,,,,,,,,,,,,,,,, -2000,0.2999754,0.054708734,,,,,,,,,,,,,,,,, -2100,0.16854769,0.05156015,,,,,,,,,,,,,,,,, -2200,0.26465592,0.042395093,,,,,,,,,,,,,,,,, -2245,,,0.9878538846969604,0.0445154048502445,0.1271006322314452,0.9850134253501892,0.0538621060550212,0.1286356176045245,43793.0,0.9840468168258668,0.0566889606416225,0.1269769745888034,43793.0,738.4302983283997,1404.1376762390137,738.4302983283997,665.5626842975616,0.0868616104125976,0.0 -2300,0.06781817,0.04115476,,,,,,,,,,,,,,,,, -2400,0.228929,0.047571715,,,,,,,,,,,,,,,,, -2500,0.13504602,0.047458943,,,,,,,,,,,,,,,,, -2600,0.1338194,0.04486001,,,,,,,,,,,,,,,,, -2700,0.09543726,0.04015502,,,,,,,,,,,,,,,,, -2800,0.21646881,0.044227254,,,,,,,,,,,,,,,,, -2900,0.054344196,0.04391369,,,,,,,,,,,,,,,,, -2983,,,0.987955927848816,0.0427691042423248,0.1545525683655089,0.9852378964424132,0.0521178096532821,0.1480877293086972,43793.0,0.9842873215675354,0.0547768212854862,0.1546334280093363,43793.0,978.480020046234,1769.383313894272,978.480020046234,790.7114028930664,0.1153392791748046,0.0 -3000,0.15997611,0.041751623,,,,,,,,,,,,,,,,, -3100,0.11646138,0.043886688,,,,,,,,,,,,,,,,, -3200,0.09258105,0.046789393,,,,,,,,,,,,,,,,, -3300,0.047158554,0.04106905,,,,,,,,,,,,,,,,, -3400,0.12519452,0.047374662,,,,,,,,,,,,,,,,, -3500,0.0955483,0.047442537,,,,,,,,,,,,,,,,, -3600,0.103953026,0.044931874,,,,,,,,,,,,,,,,, -3700,0.09307338,0.041704185,,,,,,,,,,,,,,,,, -3727,,,0.9880806803703308,0.0414005331695079,0.186444753574127,0.9853852987289428,0.0503360740840435,0.1722589120405069,43793.0,0.9844751358032228,0.0530382059514522,0.1746518429197931,43793.0,1218.683366060257,2132.798628091812,1218.683366060257,913.8746719360352,0.1453778743743896,0.0 -3800,0.105174705,0.04406657,,,,,,,,,,,,,,,,, -3900,0.06783569,0.043092206,,,,,,,,,,,,,,,,, -4000,0.09944194,0.037757892,,,,,,,,,,,,,,,,, -4100,0.07669856,0.04155025,,,,,,,,,,,,,,,,, -4200,0.046947055,0.041408163,,,,,,,,,,,,,,,,, -4300,0.041696716,0.04141068,,,,,,,,,,,,,,,,, -4400,0.27917707,0.047565147,,,,,,,,,,,,,,,,, -4481,,,0.9885159134864808,0.0399757809937,0.2038349252000051,0.9855237007141112,0.0492455773055553,0.1806705486653925,43793.0,0.9846848845481871,0.0517690926790237,0.1815007698173546,43793.0,1458.725549697876,2498.5742666721344,1458.725549697876,1039.562193632126,0.1721065044403076,0.0 -4500,0.057529137,0.040093046,,,,,,,,,,,,,,,,, -4600,0.071352996,0.043816496,,,,,,,,,,,,,,,,, -4700,0.06564144,0.040230956,,,,,,,,,,,,,,,,, -4800,0.11592184,0.045219753,,,,,,,,,,,,,,,,, -4900,0.0646675,0.047737207,,,,,,,,,,,,,,,,, -5000,0.056036565,0.04218325,,,,,,,,,,,,,,,,, -5100,0.030606197,0.03924094,,,,,,,,,,,,,,,,, -5200,0.029410431,0.037804175,,,,,,,,,,,,,,,,, -5228,,,0.9886970520019532,0.0385160632431507,0.2270452773609716,0.9857315421104432,0.0484458357095718,0.192269829851199,43793.0,0.984925389289856,0.051138449460268,0.1908380836181807,43793.0,1698.9222049713137,2865.626400947571,1698.9222049713137,1166.363187789917,0.2028937339782714,0.0 -5300,0.094819695,0.039752588,,,,,,,,,,,,,,,,, -5400,0.09328389,0.04043171,,,,,,,,,,,,,,,,, -5500,0.05417134,0.044599786,,,,,,,,,,,,,,,,, -5600,0.07710872,0.039249934,,,,,,,,,,,,,,,,, -5700,0.05342732,0.039993286,,,,,,,,,,,,,,,,, -5800,0.064322524,0.04305338,,,,,,,,,,,,,,,,, -5900,0.045568716,0.043371364,,,,,,,,,,,,,,,,, -5996,,,0.9888204336166382,0.0382240451872348,0.2375763856274924,0.9858850240707396,0.0478058569133281,0.2007561521562516,43793.0,0.9850820899009703,0.0504555068910121,0.2009072419200461,43793.0,1939.038080215454,3232.058907032013,1939.038080215454,1292.633964061737,0.2294843196868896,0.0 -6000,0.035342142,0.036545284,,,,,,,,,,,,,,,,, -6100,0.0571316,0.044168465,,,,,,,,,,,,,,,,, -6200,0.042004544,0.03832966,,,,,,,,,,,,,,,,, -6300,0.048051234,0.041910063,,,,,,,,,,,,,,,,, -6400,0.028696073,0.03939201,,,,,,,,,,,,,,,,, -6500,0.033606753,0.03765546,,,,,,,,,,,,,,,,, -6600,0.034239184,0.041569863,,,,,,,,,,,,,,,,, -6700,0.06109573,0.043239266,,,,,,,,,,,,,,,,, -6757,,,0.9890872836112976,0.0374721959233284,0.2541572959008027,0.9859747290611268,0.0472970940172672,0.2038481776836391,43793.0,0.9850327968597412,0.0497962236404418,0.2034223256774653,43793.0,2179.076942920685,3593.875339746475,2179.076942920685,1414.3643636703491,0.256913423538208,0.0 -6800,0.02708919,0.03774938,,,,,,,,,,,,,,,,, -6900,0.025139013,0.038105752,,,,,,,,,,,,,,,,, -7000,0.056832034,0.03938669,,,,,,,,,,,,,,,,, -7100,0.021533193,0.03499056,,,,,,,,,,,,,,,,, -7200,0.025728455,0.03847789,,,,,,,,,,,,,,,,, -7300,0.022348845,0.03825816,,,,,,,,,,,,,,,,, -7400,0.022181625,0.04051082,,,,,,,,,,,,,,,,, -7500,0.036350187,0.04126757,,,,,,,,,,,,,,,,, -7522,,,0.989159345626831,0.0369904600083828,0.2583130676710853,0.9860392808914183,0.0471784397959709,0.2126993248703816,43793.0,0.9851840138435364,0.0496272966265678,0.212045926660142,43793.0,2419.183886051178,3959.8080909252167,2419.183886051178,1540.1427791118622,0.2850782871246338,0.0 -7600,0.039540313,0.041927297,,,,,,,,,,,,,,,,, -7700,0.032564223,0.045076065,,,,,,,,,,,,,,,,, -7800,0.02884492,0.03816714,,,,,,,,,,,,,,,,, -7900,0.04162379,0.040194273,,,,,,,,,,,,,,,,, -8000,0.024647813,0.039291244,,,,,,,,,,,,,,,,, -8100,0.026102776,0.042551838,,,,,,,,,,,,,,,,, -8200,0.064008564,0.03899405,,,,,,,,,,,,,,,,, -8281,,,0.9891362190246582,0.037117350846529,0.2612576523422221,0.9862495064735411,0.0464592687785625,0.2196623737679801,43793.0,0.9853053092956544,0.0490796826779842,0.2183831775864873,43793.0,2659.2870647907257,4325.237766027451,2659.2870647907257,1665.422149181366,0.3125410079956054,0.0 -8300,0.02961861,0.043979947,,,,,,,,,,,,,,,,, -8400,0.027514221,0.035701405,,,,,,,,,,,,,,,,, -8500,0.029894616,0.043382376,,,,,,,,,,,,,,,,, -8600,0.031674214,0.039775398,,,,,,,,,,,,,,,,, -8700,0.02594498,0.03977437,,,,,,,,,,,,,,,,, -8800,0.020978658,0.038048323,,,,,,,,,,,,,,,,, -8900,0.037303902,0.04098896,,,,,,,,,,,,,,,,, -9000,0.030689517,0.03786298,,,,,,,,,,,,,,,,, -9050,,,0.9894993901252748,0.0358033999800682,0.2808998759655631,0.9862414002418518,0.0460271462798118,0.2250249947677923,43793.0,0.985432505607605,0.0485740005970001,0.2298582522393411,43793.0,2899.4726645946503,4692.283871173859,2899.4726645946503,1792.234255552292,0.341059923171997,0.0 -9100,0.02571278,0.04350158,,,,,,,,,,,,,,,,, -9200,0.037767723,0.039480343,,,,,,,,,,,,,,,,, -9300,0.0306333,0.04090809,,,,,,,,,,,,,,,,, -9400,0.026739923,0.03736642,,,,,,,,,,,,,,,,, -9500,0.030806724,0.038235776,,,,,,,,,,,,,,,,, -9600,0.026355274,0.036228724,,,,,,,,,,,,,,,,, -9700,0.04192772,0.037462264,,,,,,,,,,,,,,,,, -9800,0.03551106,0.040617593,,,,,,,,,,,,,,,,, -9810,,,0.9892947673797609,0.0360191650688648,0.2829000146369129,0.9861679077148438,0.04629398137331,0.2249200090916605,43793.0,0.9851722121238708,0.0490936003625392,0.2129779767804349,43793.0,3139.7406141757965,5061.331539392471,3139.7406141757965,1920.9666464328768,0.368741512298584,0.0 -9900,0.023182759,0.035527572,,,,,,,,,,,,,,,,, -10000,0.021782426,0.04081412,,,,,,,,,,,,,,,,, -10100,0.026067315,0.039315537,,,,,,,,,,,,,,,,, -10200,0.034388985,0.040411044,,,,,,,,,,,,,,,,, -10300,0.028125701,0.03847059,,,,,,,,,,,,,,,,, -10400,0.026414214,0.034318194,,,,,,,,,,,,,,,,, -10500,0.025659502,0.03758497,,,,,,,,,,,,,,,,, -10564,,,0.9896652698516846,0.0348614193499088,0.2978428145748989,0.9864857792854308,0.0457414388656616,0.2394883731962805,43793.0,0.9855567812919616,0.048363097012043,0.2423461242788807,43793.0,3379.834892272949,5436.286778688431,3379.834892272949,2055.779053211212,0.3968634605407715,0.0 -10600,0.0227283,0.0403812,,,,,,,,,,,,,,,,, -10700,0.02266281,0.037981108,,,,,,,,,,,,,,,,, -10800,0.022377133,0.035878696,,,,,,,,,,,,,,,,, -10900,0.03385325,0.037715863,,,,,,,,,,,,,,,,, -11000,0.027745906,0.038171362,,,,,,,,,,,,,,,,, -11100,0.03355294,0.036798596,,,,,,,,,,,,,,,,, -11200,0.027100453,0.03636525,,,,,,,,,,,,,,,,, -11300,0.038944736,0.039678127,,,,,,,,,,,,,,,,, -11314,,,0.989776849746704,0.0341957435011863,0.3297334600062838,0.9864062070846558,0.0457753986120224,0.2374500720500937,43793.0,0.9855571985244752,0.0484842956066131,0.2412491077403872,43793.0,3619.9956657886505,5807.399445056915,3619.9956657886505,2186.676379680633,0.4272780418395996,0.0 -11400,0.024053967,0.037765797,,,,,,,,,,,,,,,,, -11500,0.037099537,0.03680772,,,,,,,,,,,,,,,,, -11600,0.027803317,0.033524476,,,,,,,,,,,,,,,,, -11700,0.026705442,0.036362357,,,,,,,,,,,,,,,,, -11800,0.023866443,0.037601765,,,,,,,,,,,,,,,,, -11900,0.049218915,0.040222,,,,,,,,,,,,,,,,, -12000,0.03375869,0.037933365,,,,,,,,,,,,,,,,, -12064,,,0.9900395274162292,0.0332000404596328,0.3556511829583499,0.986559271812439,0.0451599508523941,0.2459968183186453,43793.0,0.9856915473937988,0.0479581244289875,0.2449247904619788,43793.0,3860.237807989121,6178.707715749741,3860.237807989121,2317.691652059555,0.4562311172485351,0.0 -12100,0.03490409,0.03439087,,,,,,,,,,,,,,,,, -12200,0.031256076,0.036570977,,,,,,,,,,,,,,,,, -12300,0.029429207,0.03536356,,,,,,,,,,,,,,,,, -12400,0.03437275,0.034023,,,,,,,,,,,,,,,,, -12500,0.039733887,0.034723803,,,,,,,,,,,,,,,,, -12600,0.03186765,0.035108414,,,,,,,,,,,,,,,,, -12700,0.047650702,0.034485392,,,,,,,,,,,,,,,,, -12800,0.053911515,0.03665238,,,,,,,,,,,,,,,,, -12819,,,0.9901712536811828,0.0327010117471218,0.3520209612511431,0.9864898324012756,0.0451404564082622,0.2395079632055661,43793.0,0.985736608505249,0.047790914773941,0.2415669797862868,43793.0,4100.330991983414,6548.0606780052185,4100.330991983414,2446.9008860588074,0.4860799312591553,0.0 -12900,0.05520495,0.037591618,,,,,,,,,,,,,,,,, -13000,0.04236428,0.041252382,,,,,,,,,,,,,,,,, -13100,0.029373467,0.040202197,,,,,,,,,,,,,,,,, -13200,0.04504938,0.032586902,,,,,,,,,,,,,,,,, -13300,0.051285956,0.037438225,,,,,,,,,,,,,,,,, -13400,0.043406863,0.037832525,,,,,,,,,,,,,,,,, -13500,0.029990634,0.034249436,,,,,,,,,,,,,,,,, -13572,,,0.9905293583869934,0.0316002890467643,0.368745897355923,0.9866258502006532,0.0448617041110992,0.2575587049856996,43793.0,0.98572438955307,0.0475961826741695,0.2534813540282405,43793.0,4340.384344100952,6914.880608558655,4340.384344100952,2573.6196570396423,0.514338493347168,0.0 -13600,0.056186646,0.0348758,,,,,,,,,,,,,,,,, -13700,0.045794837,0.036282655,,,,,,,,,,,,,,,,, -13800,0.03582824,0.033489726,,,,,,,,,,,,,,,,, -13900,0.034713335,0.03335566,,,,,,,,,,,,,,,,, -14000,0.04481325,0.035865255,,,,,,,,,,,,,,,,, -14100,0.032526143,0.0319582,,,,,,,,,,,,,,,,, -14200,0.036200162,0.031515665,,,,,,,,,,,,,,,,, -14300,0.043707155,0.03733552,,,,,,,,,,,,,,,,, -14326,,,0.9907206296920776,0.0310354493558406,0.3946692529617591,0.986632764339447,0.0443393066525459,0.251086335599621,43793.0,0.9858802556991576,0.0469530262053012,0.2496725852030368,43793.0,4580.391752004623,7286.849791765213,4580.391752004623,2705.533165216446,0.5428283214569092,0.0 -14400,0.059641674,0.035896078,,,,,,,,,,,,,,,,, -14500,0.037956532,0.033069484,,,,,,,,,,,,,,,,, -14600,0.048538778,0.037203707,,,,,,,,,,,,,,,,, -14700,0.059704475,0.03961403,,,,,,,,,,,,,,,,, -14800,0.043252777,0.03544521,,,,,,,,,,,,,,,,, -14900,0.054938003,0.03657567,,,,,,,,,,,,,,,,, -15000,0.052627582,0.03510514,,,,,,,,,,,,,,,,, -15080,,,0.9904849529266356,0.0318308845162391,0.376709458590151,0.9865093231201172,0.0448063239455223,0.2508813830749123,43793.0,0.9857193827629088,0.0474602282047271,0.2438626242540415,43793.0,4820.388184309006,7657.327804088592,4820.388184309006,2835.9652168750763,0.5726001262664795,0.0 -15100,0.06575453,0.03920303,,,,,,,,,,,,,,,,, -15200,0.050126124,0.039847206,,,,,,,,,,,,,,,,, -15300,0.04362441,0.035034757,,,,,,,,,,,,,,,,, -15400,0.0475436,0.035545174,,,,,,,,,,,,,,,,, -15500,0.06997517,0.03585543,,,,,,,,,,,,,,,,, -15600,0.056813605,0.035162337,,,,,,,,,,,,,,,,, -15700,0.051938582,0.032584604,,,,,,,,,,,,,,,,, -15800,0.050246187,0.033882126,,,,,,,,,,,,,,,,, -15827,,,0.9903941750526428,0.0317538008093833,0.3847419663950548,0.98667573928833,0.0448875240981578,0.2546998476400984,43793.0,0.985835611820221,0.0477200932800769,0.2473249334432067,43793.0,5060.620755910873,8026.981867313385,5060.620755910873,2965.335473537445,0.6022679805755615,0.0 -15900,0.050258797,0.03298019,,,,,,,,,,,,,,,,, -16000,0.076754265,0.034893453,,,,,,,,,,,,,,,,, -16100,0.04570125,0.034467544,,,,,,,,,,,,,,,,, -16200,0.056084126,0.036321808,,,,,,,,,,,,,,,,, -16300,0.057382822,0.035269808,,,,,,,,,,,,,,,,, -16400,0.072563775,0.033610985,,,,,,,,,,,,,,,,, -16500,0.06885053,0.03350681,,,,,,,,,,,,,,,,, -16575,,,0.9903529286384584,0.0319628044962883,0.3737279558476397,0.9867126941680908,0.0445161461830139,0.2609148835963837,43793.0,0.985753893852234,0.0474314987659454,0.2495560120245449,43793.0,5300.623321056366,8397.686516284943,5300.623321056366,3095.985473155976,0.6344373226165771,0.0 -16600,0.04894679,0.03576402,,,,,,,,,,,,,,,,, -16700,0.06630433,0.03442499,,,,,,,,,,,,,,,,, -16800,0.05997898,0.03614558,,,,,,,,,,,,,,,,, -16900,0.066836536,0.03718469,,,,,,,,,,,,,,,,, -17000,0.055159736,0.03444828,,,,,,,,,,,,,,,,, -17100,0.047963776,0.031152435,,,,,,,,,,,,,,,,, -17200,0.061428618,0.03457578,,,,,,,,,,,,,,,,, -17300,0.06846207,0.035686143,,,,,,,,,,,,,,,,, -17330,,,0.990510880947113,0.0314424484968185,0.3802214268304177,0.9867005348205566,0.044713731855154,0.261765784988841,43793.0,0.9857543110847472,0.047792412340641,0.2385126380118646,43793.0,5540.597485303879,8769.324578762054,5540.597485303879,3227.5997779369354,0.6642558574676514,0.0 -17400,0.055850603,0.03991692,,,,,,,,,,,,,,,,, -17500,0.094905905,0.037295625,,,,,,,,,,,,,,,,, -17600,0.056742202,0.03435528,,,,,,,,,,,,,,,,, -17700,0.05554853,0.03555731,,,,,,,,,,,,,,,,, -17800,0.054125424,0.03469606,,,,,,,,,,,,,,,,, -17900,0.099334046,0.03548863,,,,,,,,,,,,,,,,, -18000,0.06857005,0.038686134,,,,,,,,,,,,,,,,, -18077,,,0.990651547908783,0.0309187453240156,0.3985708604516049,0.9868072867393494,0.0443336330354213,0.266102079928784,43793.0,0.9859607219696044,0.0470982193946838,0.2513989695053806,43793.0,5780.7910261154175,9143.603893518448,5780.7910261154175,3361.631055355072,0.6942794322967529,0.0 -18100,0.08736213,0.0346118,,,,,,,,,,,,,,,,, -18200,0.065201834,0.03207556,,,,,,,,,,,,,,,,, -18300,0.06644592,0.037735272,,,,,,,,,,,,,,,,, -18400,0.08274746,0.032430075,,,,,,,,,,,,,,,,, -18500,0.075410806,0.03533591,,,,,,,,,,,,,,,,, -18600,0.06871564,0.0355669,,,,,,,,,,,,,,,,, -18700,0.06286932,0.03806667,,,,,,,,,,,,,,,,, -18800,0.06731543,0.035607833,,,,,,,,,,,,,,,,, -18829,,,0.9905492067337036,0.0310190469026565,0.3952029213450361,0.9866875410079956,0.0448851659893989,0.2635644689347876,43793.0,0.9857859015464784,0.0480262599885463,0.2439402673153815,43793.0,6020.8155081272125,9515.525021791458,6020.8155081272125,3493.475923061371,0.7263171672821045,0.0 -18900,0.063734256,0.034502264,,,,,,,,,,,,,,,,, -19000,0.058442447,0.034111183,,,,,,,,,,,,,,,,, -19100,0.07229619,0.03329375,,,,,,,,,,,,,,,,, -19200,0.05629614,0.03240436,,,,,,,,,,,,,,,,, -19300,0.065515496,0.035482954,,,,,,,,,,,,,,,,, -19400,0.057515875,0.034512203,,,,,,,,,,,,,,,,, -19500,0.09478555,0.036695056,,,,,,,,,,,,,,,,, -19584,,,0.9906980395317078,0.0305677652359008,0.4242602740789738,0.9866716861724854,0.0448466949164867,0.2587407145618029,43793.0,0.9859375357627868,0.0473642535507679,0.2539158233362041,43793.0,6260.985202074051,9884.426526784897,6260.985202074051,3622.156748771668,0.7575254440307617,0.0 -19600,0.059921686,0.03224696,,,,,,,,,,,,,,,,, -19700,0.05747511,0.032989174,,,,,,,,,,,,,,,,, -19800,0.06217112,0.034768455,,,,,,,,,,,,,,,,, -19900,0.058276672,0.033648178,,,,,,,,,,,,,,,,, -20000,0.060597505,0.035208493,,,,,,,,,,,,,,,,, -20100,0.05525888,0.032657564,,,,,,,,,,,,,,,,, -20200,0.07295973,0.034010395,,,,,,,,,,,,,,,,, -20300,0.07379199,0.034065295,,,,,,,,,,,,,,,,, -20334,,,0.9910008907318116,0.0294814500957727,0.4198984412224145,0.9867601990699768,0.0444596223533153,0.2566778579806512,43793.0,0.9859299659729004,0.0473265685141086,0.2548134925721332,43793.0,6501.081083774567,10252.942810297012,6501.081083774567,3750.527698040009,0.7873868942260742,0.0 -20400,0.09064492,0.038345512,,,,,,,,,,,,,,,,, -20500,0.087185405,0.033503205,,,,,,,,,,,,,,,,, -20600,0.065949604,0.03295248,,,,,,,,,,,,,,,,, -20700,0.063214496,0.03421887,,,,,,,,,,,,,,,,, -20800,0.104525566,0.034046117,,,,,,,,,,,,,,,,, -20900,0.05792636,0.030833514,,,,,,,,,,,,,,,,, -21000,0.08357871,0.03558199,,,,,,,,,,,,,,,,, -21083,,,0.9911450147628784,0.0289554018527269,0.438380159959826,0.9866806268692015,0.0445852689445018,0.2575463262048126,43793.0,0.9858819246292114,0.047335647046566,0.2524632223612293,43793.0,6741.23437333107,10619.10322880745,6741.23437333107,3876.483735561371,0.8186118602752686,0.0 -21100,0.07244632,0.03662355,,,,,,,,,,,,,,,,, -21200,0.07727023,0.034483746,,,,,,,,,,,,,,,,, -21300,0.11401589,0.03355091,,,,,,,,,,,,,,,,, -21400,0.10763151,0.034871716,,,,,,,,,,,,,,,,, -21500,0.071703725,0.03336892,,,,,,,,,,,,,,,,, -21600,0.06942757,0.03493149,,,,,,,,,,,,,,,,, -21700,0.0806185,0.031608857,,,,,,,,,,,,,,,,, -21800,0.07225669,0.033244096,,,,,,,,,,,,,,,,, -21830,,,0.9911766052246094,0.0291900876909494,0.4514105935616376,0.9867467880249025,0.0448013246059417,0.2649502321140494,43793.0,0.9859269857406616,0.0477520860731601,0.2534212336427425,43793.0,6981.344316244125,10984.614884376526,6981.344316244125,4001.835470676422,0.8482942581176758,0.0 -21900,0.07564987,0.03283607,,,,,,,,,,,,,,,,, -22000,0.10018704,0.03137104,,,,,,,,,,,,,,,,, -22100,0.07210544,0.035550024,,,,,,,,,,,,,,,,, -22200,0.081202276,0.037259877,,,,,,,,,,,,,,,,, -22300,0.066614024,0.035723224,,,,,,,,,,,,,,,,, -22400,0.06697526,0.03453742,,,,,,,,,,,,,,,,, -22500,0.072255045,0.037493277,,,,,,,,,,,,,,,,, -22585,,,0.9911478757858276,0.0292267147451639,0.422738707506834,0.9867796897888184,0.044557336717844,0.2584770706554278,43793.0,0.9858440160751344,0.04732321575284,0.246868166257547,43793.0,7221.317653656006,11352.638570547104,7221.317653656006,4129.834286689758,0.8798651695251465,0.0 -22600,0.071454704,0.033650376,,,,,,,,,,,,,,,,, -22700,0.08660182,0.03253293,,,,,,,,,,,,,,,,, -22800,0.07068395,0.03552698,,,,,,,,,,,,,,,,, -22900,0.06852462,0.031142492,,,,,,,,,,,,,,,,, -23000,0.06559799,0.036324926,,,,,,,,,,,,,,,,, -23100,0.079037264,0.031787574,,,,,,,,,,,,,,,,, -23200,0.07390123,0.036067344,,,,,,,,,,,,,,,,, -23300,0.068186015,0.03410095,,,,,,,,,,,,,,,,, -23334,,,0.9909042119979858,0.0301217660307884,0.4214129838863844,0.9866956472396852,0.0443504191935062,0.2666798122420432,43793.0,0.9857808351516724,0.0471952743828296,0.2496342910386312,43793.0,7461.457053661346,11724.822311639786,7461.457053661346,4261.829008817673,0.9099385738372804,0.0 -23400,0.08297999,0.030672565,,,,,,,,,,,,,,,,, -23500,0.07723312,0.03202799,,,,,,,,,,,,,,,,, -23600,0.06943443,0.031345338,,,,,,,,,,,,,,,,, -23700,0.08063764,0.033775248,,,,,,,,,,,,,,,,, -23800,0.072253354,0.035071425,,,,,,,,,,,,,,,,, -23900,0.0850408,0.033192106,,,,,,,,,,,,,,,,, -24000,0.09179044,0.032375373,,,,,,,,,,,,,,,,, -24086,,,0.99093496799469,0.0298298224806785,0.4233810275004899,0.9868283867836,0.0444389320909976,0.2710360070079665,43793.0,0.9860095381736756,0.046928908675909,0.2543803166805494,43793.0,7701.260676383972,12093.718660354614,7701.260676383972,4390.558969974518,1.252744436264038,0.0 -24100,0.10247901,0.039097827,,,,,,,,,,,,,,,,, -24200,0.061169375,0.033228565,,,,,,,,,,,,,,,,, -24300,0.065721735,0.031635817,,,,,,,,,,,,,,,,, -24400,0.06659834,0.0358654,,,,,,,,,,,,,,,,, -24500,0.07789589,0.034781378,,,,,,,,,,,,,,,,, -24600,0.07714717,0.033012748,,,,,,,,,,,,,,,,, -24700,0.076893106,0.03375523,,,,,,,,,,,,,,,,, -24800,0.06626461,0.0354171,,,,,,,,,,,,,,,,, -24831,,,0.9908654093742372,0.0299554914236068,0.4146248075562608,0.9866400361061096,0.0446502827107906,0.2619743777067673,43793.0,0.9857420921325684,0.0475777871906757,0.2484946569880513,43793.0,7941.512230634689,12461.927044153214,7941.512230634689,4518.463991880417,1.2852094173431396,0.0 -24900,0.07609089,0.03380963,,,,,,,,,,,,,,,,, -25000,0.074536696,0.03377588,,,,,,,,,,,,,,,,, -25100,0.119022176,0.03513629,,,,,,,,,,,,,,,,, -25200,0.09516807,0.031163253,,,,,,,,,,,,,,,,, -25300,0.078748316,0.03273316,,,,,,,,,,,,,,,,, -25400,0.08064929,0.034906473,,,,,,,,,,,,,,,,, -25500,0.06552965,0.033103686,,,,,,,,,,,,,,,,, -25587,,,0.9910185933113098,0.0293855872005224,0.432371964851221,0.9868754744529724,0.0446766056120395,0.2670693539503197,43793.0,0.985908031463623,0.047651320695877,0.2526763650455053,43793.0,8181.642039299011,12827.246604681017,8181.642039299011,4643.603357791901,1.3162312507629397,0.0 -25600,0.0771851,0.032972787,,,,,,,,,,,,,,,,, -25700,0.07949363,0.032602683,,,,,,,,,,,,,,,,, -25800,0.07175028,0.03351815,,,,,,,,,,,,,,,,, -25900,0.10210972,0.035199363,,,,,,,,,,,,,,,,, -26000,0.098027326,0.034548774,,,,,,,,,,,,,,,,, -26100,0.09510643,0.03551696,,,,,,,,,,,,,,,,, -26200,0.10014035,0.0363641,,,,,,,,,,,,,,,,, -26300,0.09299055,0.032475803,,,,,,,,,,,,,,,,, -26330,,,0.991351306438446,0.0284819547086954,0.4572311420610337,0.986766278743744,0.0443845465779304,0.2718749601531016,43793.0,0.9859522581100464,0.0470491051673889,0.2559921718547284,43793.0,8421.819328069687,13196.254113912582,8421.819328069687,4772.377794981003,1.3500714302062988,0.0 -26400,0.089031,0.035318162,,,,,,,,,,,,,,,,, -26500,0.07957814,0.03125726,,,,,,,,,,,,,,,,, -26600,0.09937268,0.030631477,,,,,,,,,,,,,,,,, -26700,0.07255517,0.033187103,,,,,,,,,,,,,,,,, -26800,0.06734344,0.031155376,,,,,,,,,,,,,,,,, -26900,0.066124655,0.033906683,,,,,,,,,,,,,,,,, -27000,0.0760836,0.034196757,,,,,,,,,,,,,,,,, -27090,,,0.9912184476852416,0.0286246985197067,0.4483122074480636,0.9867537021636964,0.0445280410349369,0.2673226801916273,43793.0,0.985910177230835,0.0474141016602516,0.2518513987897968,43793.0,8661.81754231453,13566.056858778,8661.81754231453,4902.130871295929,1.381518840789795,0.0 -27100,0.06693612,0.031714447,,,,,,,,,,,,,,,,, -27200,0.09138554,0.03523301,,,,,,,,,,,,,,,,, -27300,0.10352392,0.03127005,,,,,,,,,,,,,,,,, -27400,0.067418516,0.03181215,,,,,,,,,,,,,,,,, -27500,0.07522644,0.030873751,,,,,,,,,,,,,,,,, -27600,0.08805649,0.034251142,,,,,,,,,,,,,,,,, -27700,0.07663164,0.034474272,,,,,,,,,,,,,,,,, -27800,0.09871478,0.03425004,,,,,,,,,,,,,,,,, -27850,,,0.9915094375610352,0.0278093516826629,0.4788816882412824,0.9867618083953856,0.0447036251425743,0.2695098683814103,43793.0,0.9858503341674804,0.0474634394049644,0.2518078150224259,43793.0,8902.068484067917,13932.1010928154,8902.068484067917,5027.87125992775,1.414813995361328,0.0 -27900,0.08566208,0.03425524,,,,,,,,,,,,,,,,, -28000,0.07253477,0.032033637,,,,,,,,,,,,,,,,, -28100,0.07478689,0.032839507,,,,,,,,,,,,,,,,, -28200,0.078246206,0.03522161,,,,,,,,,,,,,,,,, -28300,0.07390142,0.032774303,,,,,,,,,,,,,,,,, -28400,0.078386985,0.032029193,,,,,,,,,,,,,,,,, -28500,0.07456879,0.03264539,,,,,,,,,,,,,,,,, -28600,0.06891185,0.030853681,,,,,,,,,,,,,,,,, -28604,,,0.9915740489959716,0.0273893792182207,0.4774020733554986,0.9868312478065492,0.0447375103831291,0.263646713481197,43793.0,0.9858406782150269,0.0476547665894031,0.2527613423048759,43793.0,9142.265007972715,14299.589767217636,9142.265007972715,5155.106744766235,1.451047420501709,0.0 -28700,0.085413314,0.032154094,,,,,,,,,,,,,,,,, -28800,0.08105415,0.033068217,,,,,,,,,,,,,,,,, -28900,0.08955051,0.03314308,,,,,,,,,,,,,,,,, -29000,0.09350074,0.03427792,,,,,,,,,,,,,,,,, -29100,0.09274854,0.03286193,,,,,,,,,,,,,,,,, -29200,0.08805036,0.029895185,,,,,,,,,,,,,,,,, -29300,0.085406885,0.032848574,,,,,,,,,,,,,,,,, -29360,,,0.9915950894355774,0.0272927861660718,0.4815457537935733,0.9867784976959229,0.0447870045900344,0.2628334545799982,43793.0,0.9858920574188232,0.0477716363966465,0.2477014584042329,43793.0,9382.218895435331,14661.340048074722,9382.218895435331,5276.8527200222015,1.4818134307861328,0.0 -29400,0.12099649,0.032413993,,,,,,,,,,,,,,,,, -29500,0.101535626,0.035626568,,,,,,,,,,,,,,,,, -29600,0.07141458,0.032094393,,,,,,,,,,,,,,,,, -29700,0.07400352,0.032262724,,,,,,,,,,,,,,,,, -29800,0.069992214,0.030780487,,,,,,,,,,,,,,,,, -29900,0.08121522,0.03290953,,,,,,,,,,,,,,,,, -30000,0.07785508,0.03480123,,,,,,,,,,,,,,,,, -30100,0.09729279,0.03330896,,,,,,,,,,,,,,,,, -30109,,,0.9915949702262878,0.0276272725313901,0.4659208937616314,0.9868267774581908,0.0446105264127254,0.2654210941093273,43793.0,0.9860133528709412,0.0474919974803924,0.254933399067772,43793.0,9622.306058883669,15026.52691435814,9622.306058883669,5401.89927482605,1.514939308166504,0.0 -30200,0.122332744,0.03289756,,,,,,,,,,,,,,,,, -30300,0.070175275,0.03485915,,,,,,,,,,,,,,,,, -30400,0.0743547,0.030382184,,,,,,,,,,,,,,,,, -30500,0.08349605,0.030830005,,,,,,,,,,,,,,,,, -30600,0.082084976,0.03369145,,,,,,,,,,,,,,,,, -30700,0.082840875,0.03424286,,,,,,,,,,,,,,,,, -30800,0.097457744,0.031644575,,,,,,,,,,,,,,,,, -30861,,,0.9914368391036988,0.0282070115208625,0.4585877945100001,0.9867272973060608,0.0447770431637764,0.2648167534202912,43793.0,0.9858651161193848,0.0474503748118877,0.2607415900214805,43793.0,9862.304508924484,15393.499891757963,9862.304508924484,5528.822056770325,1.5469331741333008,0.0 -30900,0.06385794,0.029235909,,,,,,,,,,,,,,,,, -31000,0.075250454,0.03177599,,,,,,,,,,,,,,,,, -31100,0.07259093,0.029811261,,,,,,,,,,,,,,,,, -31200,0.06865238,0.030976346,,,,,,,,,,,,,,,,, -31300,0.118771054,0.03460996,,,,,,,,,,,,,,,,, -31400,0.079032525,0.032047525,,,,,,,,,,,,,,,,, -31500,0.10157565,0.03317355,,,,,,,,,,,,,,,,, -31600,0.13614634,0.031769067,,,,,,,,,,,,,,,,, -31613,,,0.9915117025375366,0.0278496518731117,0.4719675642555955,0.9867539405822754,0.0446218624711036,0.2626403302936321,43793.0,0.9859716296195984,0.0472341999411582,0.2509376027609071,43793.0,10102.278964281082,15757.501960992811,10102.278964281082,5652.796845912933,1.580089807510376,0.0 -31700,0.08871866,0.03165118,,,,,,,,,,,,,,,,, -31800,0.11074813,0.032927778,,,,,,,,,,,,,,,,, -31900,0.08417187,0.034338366,,,,,,,,,,,,,,,,, -32000,0.09175835,0.032034945,,,,,,,,,,,,,,,,, -32100,0.069292754,0.031454112,,,,,,,,,,,,,,,,, -32200,0.07427833,0.032627393,,,,,,,,,,,,,,,,, -32300,0.07914648,0.03154033,,,,,,,,,,,,,,,,, -32375,,,0.9914616346359252,0.0277968235313892,0.4661418373102133,0.986866533756256,0.0448258966207504,0.264903256071144,43793.0,0.9860483407974244,0.0476890802383422,0.2538761775615324,43793.0,10342.485783815384,16124.741267204285,10342.485783815384,5779.776381969452,1.6129043102264404,0.0 -32400,0.07155393,0.03168514,,,,,,,,,,,,,,,,, -32500,0.09991196,0.02926192,,,,,,,,,,,,,,,,, -32600,0.08051558,0.03429834,,,,,,,,,,,,,,,,, -32700,0.13185638,0.028859392,,,,,,,,,,,,,,,,, -32800,0.08265793,0.031098796,,,,,,,,,,,,,,,,, -32900,0.08720893,0.033622507,,,,,,,,,,,,,,,,, -33000,0.09240262,0.028826209,,,,,,,,,,,,,,,,, -33100,0.09247661,0.032638535,,,,,,,,,,,,,,,,, -33127,,,0.9914880990982056,0.027717201039195,0.476126487256679,0.9866778254508972,0.044719535857439,0.2619186380504609,43793.0,0.9858246445655824,0.0473682954907417,0.2573437097629211,43793.0,10582.662477493286,16491.039191007614,10582.662477493286,5905.846033334732,1.644451379776001,0.0 -33200,0.079444975,0.03201948,,,,,,,,,,,,,,,,, -33300,0.11465143,0.033851385,,,,,,,,,,,,,,,,, -33400,0.07217504,0.03113123,,,,,,,,,,,,,,,,, -33500,0.10516276,0.02899668,,,,,,,,,,,,,,,,, -33600,0.08736199,0.033782415,,,,,,,,,,,,,,,,, -33700,0.09407841,0.03130115,,,,,,,,,,,,,,,,, -33800,0.14850251,0.032890756,,,,,,,,,,,,,,,,, -33880,,,0.9916728138923644,0.0270904246717691,0.4776111416574321,0.9869416356086732,0.0447549112141132,0.2737131680794807,43793.0,0.986042022705078,0.0473588556051254,0.2607110022505314,43793.0,10822.833097219467,16856.96304345131,10822.833097219467,6031.545439004898,1.678142547607422,0.0 -33900,0.084675096,0.03131265,,,,,,,,,,,,,,,,, -34000,0.08220972,0.03395298,,,,,,,,,,,,,,,,, -34100,0.09581819,0.033009585,,,,,,,,,,,,,,,,, -34200,0.120490775,0.031323493,,,,,,,,,,,,,,,,, -34300,0.10498638,0.033503797,,,,,,,,,,,,,,,,, -34400,0.088241,0.033824734,,,,,,,,,,,,,,,,, -34500,0.08633454,0.031420507,,,,,,,,,,,,,,,,, -34600,0.093915336,0.034299158,,,,,,,,,,,,,,,,, -34630,,,0.9918762445449828,0.0265015661716461,0.4959326874687609,0.9868482947349548,0.0447657331824302,0.2684190313777011,43793.0,0.9858322143554688,0.0475003346800804,0.256020744104348,43793.0,11062.853231668472,17223.528245449066,11062.853231668472,6158.037148237228,1.71140456199646,0.0 -34700,0.10465358,0.030336248,,,,,,,,,,,,,,,,, -34800,0.1243878,0.03416718,,,,,,,,,,,,,,,,, -34900,0.09549887,0.030401146,,,,,,,,,,,,,,,,, -35000,0.09163072,0.031001879,,,,,,,,,,,,,,,,, -35100,0.08523836,0.027905833,,,,,,,,,,,,,,,,, -35200,0.09595831,0.03214887,,,,,,,,,,,,,,,,, -35300,0.09577866,0.029634653,,,,,,,,,,,,,,,,, -35374,,,0.9919270873069764,0.0261572357267141,0.5077010842706066,0.9868158102035522,0.0450122393667697,0.2640386329595336,43793.0,0.9859994649887084,0.047745082527399,0.2512933375163518,43793.0,11302.85993719101,17591.77684688568,11302.85993719101,6286.22608089447,1.7442903518676758,0.0 -35400,0.12317393,0.031734426,,,,,,,,,,,,,,,,, -35500,0.0789233,0.030523844,,,,,,,,,,,,,,,,, -35600,0.09900943,0.031501107,,,,,,,,,,,,,,,,, -35700,0.078065395,0.03031578,,,,,,,,,,,,,,,,, -35800,0.09080776,0.03238954,,,,,,,,,,,,,,,,, -35900,0.09051785,0.031058963,,,,,,,,,,,,,,,,, -36000,0.08336974,0.031302433,,,,,,,,,,,,,,,,, -36100,0.1000009,0.030553093,,,,,,,,,,,,,,,,, -36121,,,0.9923211336135864,0.0249617137014865,0.5338029568777007,0.986887276172638,0.0448975339531898,0.2611574561332976,43793.0,0.9859430193901062,0.0478298552334308,0.2534297712322391,43793.0,11542.92440032959,17959.193194389343,11542.92440032959,6413.524546384811,1.7778651714324951,0.0 -36200,0.09654907,0.032330614,,,,,,,,,,,,,,,,, -36300,0.09547742,0.029303985,,,,,,,,,,,,,,,,, -36400,0.09847539,0.030909715,,,,,,,,,,,,,,,,, -36500,0.09907828,0.028650073,,,,,,,,,,,,,,,,, -36600,0.08474321,0.0305897,,,,,,,,,,,,,,,,, -36700,0.08731146,0.029164044,,,,,,,,,,,,,,,,, -36800,0.113451056,0.029809216,,,,,,,,,,,,,,,,, -36868,,,0.9920997619628906,0.0254573356360197,0.5130001361147138,0.9868791699409484,0.0451982542872428,0.2632707293366008,43793.0,0.986023485660553,0.0480411872267723,0.2562980574986526,43793.0,11782.954986095428,18330.56738352776,11782.954986095428,6544.811107635498,1.813506364822388,0.0 -36900,0.075002536,0.028720016,,,,,,,,,,,,,,,,, -37000,0.082655035,0.02917953,,,,,,,,,,,,,,,,, -37100,0.077563785,0.029725764,,,,,,,,,,,,,,,,, -37200,0.08239626,0.03329687,,,,,,,,,,,,,,,,, -37300,0.081298344,0.030886978,,,,,,,,,,,,,,,,, -37400,0.09332178,0.029492242,,,,,,,,,,,,,,,,, -37500,0.09350729,0.032327335,,,,,,,,,,,,,,,,, -37600,0.091107816,0.03343917,,,,,,,,,,,,,,,,, -37621,,,0.991813063621521,0.0264731142669916,0.5066786252637335,0.9869157075881958,0.0449858866631984,0.270892764058419,43793.0,0.9859792590141296,0.0478082187473773,0.2564197364137381,43793.0,12023.092227220535,18696.423310041428,12023.092227220535,6670.475694179535,1.847543478012085,0.0 -37700,0.08473992,0.029189076,,,,,,,,,,,,,,,,, -37800,0.09774509,0.029133013,,,,,,,,,,,,,,,,, -37900,0.08133784,0.029189792,,,,,,,,,,,,,,,,, -38000,0.12955995,0.02811334,,,,,,,,,,,,,,,,, -38100,0.09339433,0.028940257,,,,,,,,,,,,,,,,, -38200,0.092202075,0.032855704,,,,,,,,,,,,,,,,, -38300,0.08299625,0.029107854,,,,,,,,,,,,,,,,, -38381,,,0.9919701814651488,0.0260833278298378,0.4938227566410164,0.9869075417518616,0.0448095016181468,0.2723583684879261,43793.0,0.985987663269043,0.0477124713361263,0.2526040420471432,43793.0,12263.215424776075,19061.941546201702,12263.215424776075,6795.817242145538,1.8814423084259035,0.0 -38400,0.10309056,0.030142646,,,,,,,,,,,,,,,,, -38500,0.13126487,0.032078143,,,,,,,,,,,,,,,,, -38600,0.08410689,0.02934161,,,,,,,,,,,,,,,,, -38700,0.11842255,0.030607054,,,,,,,,,,,,,,,,, -38800,0.100945964,0.02928161,,,,,,,,,,,,,,,,, -38900,0.09755617,0.030611724,,,,,,,,,,,,,,,,, -39000,0.11256044,0.03346533,,,,,,,,,,,,,,,,, -39100,0.09144977,0.032318123,,,,,,,,,,,,,,,,, -39140,,,0.9919205904006958,0.0262812860310077,0.5019598940760397,0.9867547154426576,0.0452037937939167,0.2664720998561808,43793.0,0.9859459400177002,0.0479544922709465,0.25299035868724,43793.0,12503.383272647858,19424.19128537178,12503.383272647858,6917.844294548035,1.9166576862335205,0.0 -39200,0.08459207,0.0280421,,,,,,,,,,,,,,,,, -39300,0.08988781,0.028558908,,,,,,,,,,,,,,,,, -39400,0.09102586,0.031083291,,,,,,,,,,,,,,,,, -39500,0.099483,0.031241074,,,,,,,,,,,,,,,,, -39600,0.099496305,0.03042587,,,,,,,,,,,,,,,,, -39700,0.10085796,0.030343823,,,,,,,,,,,,,,,,, -39800,0.09419078,0.02847528,,,,,,,,,,,,,,,,, -39896,,,0.9918465614318848,0.0261474233120679,0.5186325621587466,0.9868633151054382,0.0452602319419384,0.2675940791495391,43793.0,0.9860289096832277,0.0478225275874137,0.2570222310540754,43793.0,12743.645746707916,19790.62196087837,12743.645746707916,7043.957558870316,1.951773881912232,0.0 -39900,0.10946118,0.03120162,,,,,,,,,,,,,,,,, -40000,0.09422329,0.030460214,,,,,,,,,,,,,,,,, -40100,0.11687991,0.029992247,,,,,,,,,,,,,,,,, -40200,0.14038914,0.029765617,,,,,,,,,,,,,,,,, -40300,0.13443296,0.033670746,,,,,,,,,,,,,,,,, -40400,0.10674529,0.02908355,,,,,,,,,,,,,,,,, -40500,0.10110569,0.029778302,,,,,,,,,,,,,,,,, -40600,0.10300961,0.031163491,,,,,,,,,,,,,,,,, -40660,,,0.9921106100082396,0.0257124546915292,0.50393287827044,0.9866834878921508,0.0454183742403984,0.2655199069094247,43793.0,0.9858823418617249,0.0478564389050006,0.2583532536156636,43793.0,12983.84726190567,20151.31301140785,12983.84726190567,7164.391726016998,1.987480878829956,0.0 -40700,0.09614224,0.033162948,,,,,,,,,,,,,,,,, -40800,0.11325433,0.028728815,,,,,,,,,,,,,,,,, -40900,0.0879477,0.028128486,,,,,,,,,,,,,,,,, -41000,0.109822236,0.033545017,,,,,,,,,,,,,,,,, -41100,0.103882805,0.031509645,,,,,,,,,,,,,,,,, -41200,0.08117849,0.030104186,,,,,,,,,,,,,,,,, -41300,0.11258986,0.030211572,,,,,,,,,,,,,,,,, -41400,0.098710336,0.0315502,,,,,,,,,,,,,,,,, -41411,,,0.9921404719352722,0.0253216736018657,0.5251935954997733,0.9867870211601256,0.0456630028784275,0.265891050572245,43793.0,0.985870122909546,0.0485646724700927,0.2541695344242801,43793.0,13223.94919514656,20516.02917122841,13223.94919514656,7288.951065063477,2.023068428039551,0.0 -41500,0.1026507,0.031350516,,,,,,,,,,,,,,,,, -41600,0.0943045,0.028541664,,,,,,,,,,,,,,,,, -41700,0.09564296,0.029877119,,,,,,,,,,,,,,,,, -41800,0.11382861,0.032029375,,,,,,,,,,,,,,,,, -41900,0.100676015,0.02951493,,,,,,,,,,,,,,,,, -42000,0.102176294,0.029567625,,,,,,,,,,,,,,,,, -42100,0.12886602,0.031125242,,,,,,,,,,,,,,,,, -42171,,,0.992421567440033,0.0244387220591306,0.5435246945200095,0.9867402911186218,0.0457667410373687,0.2670472319105248,43793.0,0.9859126806259156,0.0486242473125457,0.2552605737327201,43793.0,13464.0667886734,20884.01805138588,13464.0667886734,7416.767649650574,2.0580596923828125,0.0 -42200,0.10083389,0.030989569,,,,,,,,,,,,,,,,, -42300,0.11779652,0.029799415,,,,,,,,,,,,,,,,, -42400,0.13851365,0.03053871,,,,,,,,,,,,,,,,, -42500,0.091181,0.03100554,,,,,,,,,,,,,,,,, -42600,0.09046419,0.029130366,,,,,,,,,,,,,,,,, -42700,0.15048783,0.030149838,,,,,,,,,,,,,,,,, -42800,0.106629886,0.025736568,,,,,,,,,,,,,,,,, -42900,0.106706895,0.029984605,,,,,,,,,,,,,,,,, -42920,,,0.9927170872688292,0.0233779214322567,0.5733930541995389,0.9868364930152892,0.0459333881735801,0.2704952414247853,43793.0,0.9860095381736756,0.0487369447946548,0.2568638627028144,43793.0,13704.141545772552,21253.58636689186,13704.141545772552,7546.203330516815,2.094203472137451,0.0 -43000,0.09634066,0.02842475,,,,,,,,,,,,,,,,, -43100,0.123148836,0.032314993,,,,,,,,,,,,,,,,, -43200,0.10557683,0.030237993,,,,,,,,,,,,,,,,, -43300,0.09789816,0.027300255,,,,,,,,,,,,,,,,, -43400,0.10323496,0.031067012,,,,,,,,,,,,,,,,, -43500,0.09160436,0.028279606,,,,,,,,,,,,,,,,, -43600,0.12320145,0.031313017,,,,,,,,,,,,,,,,, -43675,,,0.9927061200141908,0.023393128067255,0.5666941320422774,0.9868308305740356,0.0457053445279598,0.2666772814062301,43793.0,0.9858528971672058,0.0487676002085208,0.2534971891102581,43793.0,13944.12907576561,21617.28958058357,13944.12907576561,7669.8636927604675,2.1300570964813232,0.0 -43700,0.18584849,0.030324804,,,,,,,,,,,,,,,,, -43800,0.13633803,0.029554574,,,,,,,,,,,,,,,,, -43900,0.1057481,0.030684393,,,,,,,,,,,,,,,,, -44000,0.09894284,0.028519938,,,,,,,,,,,,,,,,, -44100,0.112350464,0.02900283,,,,,,,,,,,,,,,,, -44200,0.11165144,0.030086812,,,,,,,,,,,,,,,,, -44300,0.12719183,0.031663198,,,,,,,,,,,,,,,,, -44400,0.11147098,0.027482776,,,,,,,,,,,,,,,,, -44425,,,0.9929761290550232,0.0225698053836822,0.5881898919000075,0.9867346286773682,0.0464401394128799,0.2643463754657173,43793.0,0.9857774972915648,0.0494213365018367,0.2502955025357254,43793.0,14184.10730624199,21981.214041233063,14184.10730624199,7793.7537133693695,2.167064905166626,0.0 -44500,0.12303709,0.032252967,,,,,,,,,,,,,,,,, -44600,0.09658993,0.028638255,,,,,,,,,,,,,,,,, -44700,0.0976461,0.028377276,,,,,,,,,,,,,,,,, -44800,0.12747811,0.031957664,,,,,,,,,,,,,,,,, -44900,0.13626996,0.029390017,,,,,,,,,,,,,,,,, -45000,0.12869409,0.028078172,,,,,,,,,,,,,,,,, -45100,0.11234666,0.028193763,,,,,,,,,,,,,,,,, -45172,,,0.9928836226463318,0.0231638737022876,0.5603069323792206,0.9867216348648072,0.0458410233259201,0.2698805055719386,43793.0,0.9857913851737976,0.0486513152718544,0.2532316248694131,43793.0,14424.139218568802,22347.200195789337,14424.139218568802,7919.652441978455,2.202032327651977,0.0 -45200,0.11912407,0.02980999,,,,,,,,,,,,,,,,, -45300,0.111569926,0.027450116,,,,,,,,,,,,,,,,, -45400,0.11207073,0.028355561,,,,,,,,,,,,,,,,, -45500,0.13543092,0.02891205,,,,,,,,,,,,,,,,, -45600,0.13385531,0.033504516,,,,,,,,,,,,,,,,, -45700,0.14966632,0.027788637,,,,,,,,,,,,,,,,, -45800,0.1402558,0.029732293,,,,,,,,,,,,,,,,, -45900,0.14380267,0.028985197,,,,,,,,,,,,,,,,, -45928,,,0.9925437569618224,0.0238797627389431,0.5548479854623927,0.986866533756256,0.0462991744279861,0.2675998202951851,43793.0,0.9859084486961364,0.0492966175079345,0.2490588585001329,43793.0,14664.235578775406,22710.05312728882,14664.235578775406,8042.348666667938,2.241867780685425,0.0 -46000,0.11160788,0.030720163,,,,,,,,,,,,,,,,, -46100,0.10663037,0.028013853,,,,,,,,,,,,,,,,, -46200,0.09953581,0.028126163,,,,,,,,,,,,,,,,, -46300,0.12112611,0.028141731,,,,,,,,,,,,,,,,, -46400,0.10580791,0.029992467,,,,,,,,,,,,,,,,, -46500,0.11887659,0.029787395,,,,,,,,,,,,,,,,, -46600,0.10600076,0.029313115,,,,,,,,,,,,,,,,, -46685,,,0.9926679134368896,0.023723516613245,0.5522737730499209,0.986710250377655,0.0460764281451702,0.2673692998549654,43793.0,0.985787570476532,0.0489269718527793,0.255223206791752,43793.0,14904.3414375782,23071.91637802124,14904.3414375782,8164.049585580826,2.278223991394043,0.0 -46700,0.121254936,0.02668661,,,,,,,,,,,,,,,,, -46800,0.109939836,0.027908094,,,,,,,,,,,,,,,,, -46900,0.122681335,0.029526116,,,,,,,,,,,,,,,,, -47000,0.11371989,0.029255858,,,,,,,,,,,,,,,,, -47100,0.12545817,0.02794874,,,,,,,,,,,,,,,,, -47200,0.11900955,0.028241925,,,,,,,,,,,,,,,,, -47300,0.13553767,0.027796378,,,,,,,,,,,,,,,,, -47400,0.10341247,0.024767926,,,,,,,,,,,,,,,,, -47437,,,0.9925719499588012,0.0237914603203535,0.5576811836200534,0.986846685409546,0.0464228764176368,0.2657642561931333,43793.0,0.9859118461608888,0.0494793020188808,0.2525168562310033,43793.0,15144.60926938057,23438.763329267505,15144.60926938057,8290.571984052658,2.3142735958099365,0.0 -47500,0.116123535,0.027559381,,,,,,,,,,,,,,,,, -47600,0.12530562,0.026231913,,,,,,,,,,,,,,,,, -47700,0.11960884,0.030340906,,,,,,,,,,,,,,,,, -47800,0.13236351,0.028656336,,,,,,,,,,,,,,,,, -47900,0.1405656,0.027687455,,,,,,,,,,,,,,,,, -48000,0.12087709,0.02767678,,,,,,,,,,,,,,,,, -48100,0.15608495,0.027065037,,,,,,,,,,,,,,,,, -48189,,,0.9928362369537354,0.0230195168405771,0.558678018020372,0.9867460131645204,0.0464305877685546,0.2732741377262889,43793.0,0.9858225584030152,0.0492965169250965,0.2521494930765173,43793.0,15384.76617383957,23803.41541624069,15384.76617383957,8415.009822368622,2.3507189750671387,0.0 -48200,0.10247039,0.026713206,,,,,,,,,,,,,,,,, -48300,0.13768758,0.030963361,,,,,,,,,,,,,,,,, -48400,0.13131577,0.027316479,,,,,,,,,,,,,,,,, -48500,0.12434023,0.02805689,,,,,,,,,,,,,,,,, -48600,0.15080535,0.02770796,,,,,,,,,,,,,,,,, -48700,0.11758622,0.028850602,,,,,,,,,,,,,,,,, -48800,0.12569556,0.027902216,,,,,,,,,,,,,,,,, -48900,0.15313539,0.027619323,,,,,,,,,,,,,,,,, -48939,,,0.9928126335144044,0.0228299181908369,0.5758678220084759,0.986866533756256,0.0463282391428947,0.2763794516470148,43793.0,0.9858726859092712,0.0493974350392818,0.2561368281989875,43793.0,15624.714596033096,24165.57988381385,15624.714596033096,8537.169777154922,2.38702654838562,0.0 -49000,0.10945153,0.026204173,,,,,,,,,,,,,,,,, -49100,0.11296707,0.025812687,,,,,,,,,,,,,,,,, -49200,0.10637259,0.02485261,,,,,,,,,,,,,,,,, -49300,0.12154729,0.026225762,,,,,,,,,,,,,,,,, -49400,0.11342826,0.026865767,,,,,,,,,,,,,,,,, -49500,0.14311875,0.029773401,,,,,,,,,,,,,,,,, -49600,0.12070797,0.026701283,,,,,,,,,,,,,,,,, -49683,,,0.9930967688560486,0.0219309777021408,0.5948772431380066,0.986885666847229,0.0467839241027832,0.2694161972938923,43793.0,0.9858827590942384,0.0499323159456253,0.2530611175182334,43793.0,15864.714323282242,24526.31983613968,15864.714323282242,8657.844844341278,2.4283759593963623,0.0 -49700,0.14288963,0.027786115,,,,,,,,,,,,,,,,, -49800,0.1189359,0.025632473,,,,,,,,,,,,,,,,, -49900,0.116658874,0.029058179,,,,,,,,,,,,,,,,, -50000,0.12186116,0.02650856,,,,,,,,,,,,,,,,, -50100,0.14685273,0.03139087,,,,,,,,,,,,,,,,, -50200,0.1477174,0.027037371,,,,,,,,,,,,,,,,, -50300,0.120193794,0.02606146,,,,,,,,,,,,,,,,, -50400,0.14041235,0.028305842,,,,,,,,,,,,,,,,, -50443,,,0.9935494065284728,0.0208131670951843,0.6246684589240922,0.9867971539497375,0.0465557985007762,0.2757770272471453,43793.0,0.9857075810432434,0.0498334541916847,0.2552904516568788,43793.0,16104.945301055908,24887.052001953125,16104.945301055908,8778.289328336716,2.465296030044556,0.0 -50500,0.14144042,0.026985144,,,,,,,,,,,,,,,,, -50600,0.13666297,0.026994094,,,,,,,,,,,,,,,,, -50700,0.12190612,0.026517482,,,,,,,,,,,,,,,,, -50800,0.13272811,0.0302684,,,,,,,,,,,,,,,,, -50900,0.121208206,0.02790879,,,,,,,,,,,,,,,,, -51000,0.12459074,0.02765715,,,,,,,,,,,,,,,,, -51100,0.12485046,0.027473388,,,,,,,,,,,,,,,,, -51200,0.13407284,0.026236333,,,,,,,,,,,,,,,,, -51201,,,0.993834674358368,0.019873609766364,0.6402711633729733,0.986703395843506,0.0474903881549835,0.2653409913394465,43793.0,0.9858195781707764,0.0507305338978767,0.2476450730734251,43793.0,16345.135685920715,25249.40115904808,16345.135685920715,8900.391793251038,2.501399278640747,0.0 -51300,0.14748788,0.024695499,,,,,,,,,,,,,,,,, -51400,0.186285,0.027112328,,,,,,,,,,,,,,,,, -51500,0.12724794,0.026551666,,,,,,,,,,,,,,,,, -51600,0.13637798,0.026551154,,,,,,,,,,,,,,,,, -51700,0.16020547,0.025932014,,,,,,,,,,,,,,,,, -51800,0.12467053,0.02751093,,,,,,,,,,,,,,,,, -51900,0.14814104,0.025864411,,,,,,,,,,,,,,,,, -51945,,,0.9937007427215576,0.020169697701931,0.6355450158745422,0.9866883754730223,0.0474072881042957,0.2672529701239556,43793.0,0.9858056902885436,0.0504932701587677,0.253286375824104,43793.0,16585.19568347931,25610.56532263756,16585.19568347931,9021.43328166008,2.5409858226776123,0.0 -52000,0.1338143,0.026742855,,,,,,,,,,,,,,,,, -52100,0.15437244,0.025650345,,,,,,,,,,,,,,,,, -52200,0.14728026,0.027066827,,,,,,,,,,,,,,,,, -52300,0.15814185,0.028158635,,,,,,,,,,,,,,,,, -52400,0.13693771,0.027430484,,,,,,,,,,,,,,,,, -52500,0.12523927,0.024152333,,,,,,,,,,,,,,,,, -52600,0.14336374,0.02445104,,,,,,,,,,,,,,,,, -52698,,,0.9938126802444458,0.019981313496828,0.6342326759145145,0.9867126941680908,0.0471597947180271,0.2665871337103819,43793.0,0.9858128428459167,0.0505554601550102,0.2467603426558873,43793.0,16825.145767211914,25969.932988643646,16825.145767211914,9140.794448375702,2.5773093700408936,0.0 -52700,0.13215475,0.025300724,,,,,,,,,,,,,,,,, -52800,0.16819331,0.028909873,,,,,,,,,,,,,,,,, -52900,0.13633448,0.025994988,,,,,,,,,,,,,,,,, -53000,0.15294276,0.028738478,,,,,,,,,,,,,,,,, -53100,0.15231092,0.025568044,,,,,,,,,,,,,,,,, -53200,0.1258914,0.023572223,,,,,,,,,,,,,,,,, -53300,0.1309281,0.023201853,,,,,,,,,,,,,,,,, -53400,0.14134035,0.027003475,,,,,,,,,,,,,,,,, -53458,,,0.993604838848114,0.0205391012132167,0.6163999862960505,0.9866956472396852,0.047807291150093,0.2672979128963146,43793.0,0.9857555627822876,0.0508050434291362,0.2525310318683968,43793.0,17065.106918811798,26333.523721456528,17065.106918811798,9264.3684155941,2.6135025024414062,0.0 -53500,0.15301184,0.026839707,,,,,,,,,,,,,,,,, -53600,0.13289745,0.026081802,,,,,,,,,,,,,,,,, -53700,0.16322215,0.025941372,,,,,,,,,,,,,,,,, -53800,0.1353199,0.024604293,,,,,,,,,,,,,,,,, -53900,0.141248,0.025333362,,,,,,,,,,,,,,,,, -54000,0.1647293,0.027000066,,,,,,,,,,,,,,,,, -54100,0.14098798,0.024550062,,,,,,,,,,,,,,,,, -54198,,,0.993502676486969,0.0207406654953956,0.6146754178331644,0.986743152141571,0.047809213399887,0.2679407266434734,43793.0,0.9857187271118164,0.0511037185788154,0.2504868933628825,43793.0,17305.169719457626,26697.89265561104,17305.169719457626,9388.612121343613,2.655043125152588,0.0 -54200,0.18276449,0.026890218,,,,,,,,,,,,,,,,, -54300,0.15644808,0.027757209,,,,,,,,,,,,,,,,, -54400,0.19540092,0.027367722,,,,,,,,,,,,,,,,, -54500,0.16674596,0.024409227,,,,,,,,,,,,,,,,, -54600,0.16326492,0.022934766,,,,,,,,,,,,,,,,, -54700,0.17526871,0.027438201,,,,,,,,,,,,,,,,, -54800,0.20420763,0.025075149,,,,,,,,,,,,,,,,, -54900,0.15952973,0.027435342,,,,,,,,,,,,,,,,, -54955,,,0.9935206770896912,0.0204508285969495,0.6152469816588861,0.9867601990699768,0.0481349416077137,0.2696156510939548,43793.0,0.985710084438324,0.0512884557247161,0.2556887617367147,43793.0,17545.279770612717,27061.449331760406,17545.279770612717,9512.002286434174,2.691478490829468,0.0 -55000,0.17649905,0.026161905,,,,,,,,,,,,,,,,, -55100,0.16608389,0.026170202,,,,,,,,,,,,,,,,, -55200,0.16647059,0.024523553,,,,,,,,,,,,,,,,, -55300,0.14674988,0.024106668,,,,,,,,,,,,,,,,, -55400,0.1762251,0.028502196,,,,,,,,,,,,,,,,, -55500,0.15628678,0.024456102,,,,,,,,,,,,,,,,, -55600,0.14629792,0.024521708,,,,,,,,,,,,,,,,, -55700,0.15066215,0.025590846,,,,,,,,,,,,,,,,, -55711,,,0.9935681819915771,0.0202482622116804,0.6322992889381533,0.9867435693740844,0.0486103147268295,0.2674156022578643,43793.0,0.9857816696166992,0.0518056154251098,0.2560246636285639,43793.0,17785.269641399384,27420.06883215904,17785.269641399384,9630.57388138771,2.7295162677764893,0.0 -55800,0.1552559,0.023367595,,,,,,,,,,,,,,,,, -55900,0.17202291,0.02207502,,,,,,,,,,,,,,,,, -56000,0.13403477,0.02276422,,,,,,,,,,,,,,,,, -56100,0.2081695,0.02586664,,,,,,,,,,,,,,,,, -56200,0.17601387,0.025911247,,,,,,,,,,,,,,,,, -56300,0.13757338,0.022839544,,,,,,,,,,,,,,,,, -56400,0.15317892,0.024690077,,,,,,,,,,,,,,,,, -56465,,,0.9937283992767334,0.019526956602931,0.6458244588409702,0.986622989177704,0.0490015819668769,0.2654806746801502,43793.0,0.98571515083313,0.0521701723337173,0.2486022880193414,43793.0,18025.47614264488,27784.033193588257,18025.47614264488,9754.26741719246,2.77383804321289,0.0 -56500,0.15343091,0.0248316,,,,,,,,,,,,,,,,, -56600,0.17465052,0.024376208,,,,,,,,,,,,,,,,, -56700,0.15849686,0.024578394,,,,,,,,,,,,,,,,, -56800,0.16182147,0.023654316,,,,,,,,,,,,,,,,, -56900,0.17421497,0.027338421,,,,,,,,,,,,,,,,, -57000,0.17624074,0.026108006,,,,,,,,,,,,,,,,, -57100,0.14070188,0.02199973,,,,,,,,,,,,,,,,, -57200,0.15354358,0.024049701,,,,,,,,,,,,,,,,, -57218,,,0.994199812412262,0.0185867268592119,0.6527285246331709,0.9865426421165466,0.0489468574523925,0.2711905727317665,43793.0,0.985640585422516,0.0520979203283786,0.2515863182121399,43793.0,18265.4832341671,28143.712094783783,18265.4832341671,9873.88046002388,2.812410593032837,0.0 -57300,0.15314943,0.023253394,,,,,,,,,,,,,,,,, -57400,0.16960827,0.02218262,,,,,,,,,,,,,,,,, -57500,0.1506865,0.025428077,,,,,,,,,,,,,,,,, -57600,0.15765081,0.022802848,,,,,,,,,,,,,,,,, -57700,0.1813331,0.0244967,,,,,,,,,,,,,,,,, -57800,0.15899116,0.02367617,,,,,,,,,,,,,,,,, -57878,,,,,,,,,,,,,,18477.15989422798,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/eval_measurements.csv deleted file mode 100644 index b8a97a27a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -120.44142413139345,0.0,12.25728178024292,1,0,12.25728178024292,0.5047284960746765,0.7438743114471436,0.0269645186056544,43793,132.69875383377075,0.5092054009437561,0.7392170429229736,0.0224031943319573,0.5064647793769836,0.7422076463699341,0.0253016651271869,43793 -241.9868643283844,0.0215451717376709,252.3354289531708,740,0,252.3354289531708,0.983142077922821,0.0788627117872238,0.0411033540555226,43793,494.3655803203583,0.9868019819259644,0.0672526881098747,0.0377212887011602,0.9841179251670836,0.0759080797433853,0.0384024595839474,43793 -364.3091459274292,0.0499124526977539,492.3489384651184,1486,0,492.3489384651184,0.9831559658050536,0.0658572912216186,0.0769709794151032,43793,856.7501437664032,0.9867600202560424,0.0523919649422168,0.0754291005588222,0.984138250350952,0.0624441169202327,0.0756716920682181,43793 -486.101655960083,0.0771832466125488,732.587010383606,2240,0,732.587010383606,0.98364919424057,0.058329090476036,0.1170758496200459,43793,1218.8280143737793,0.9872581958770752,0.0459452904760837,0.1198747155496232,0.9846627116203308,0.0551317036151886,0.1179890674866964,43793 -608.3868906497955,0.1059887409210205,972.7253069877625,2992,0,972.7253069877625,0.9839326739311218,0.0568287819623947,0.1407996392358726,43793,1581.3003692626953,0.987499177455902,0.0444191396236419,0.1499075671709747,0.984927773475647,0.0537693053483963,0.1389333456550538,43793 -732.8152990341187,0.1326751708984375,1212.699630022049,3750,0,1212.699630022049,0.984113335609436,0.0551243908703327,0.1527203936176814,43793,1945.7496001720428,0.9879691004753112,0.0425723828375339,0.1707366148038464,0.985083281993866,0.0522946529090404,0.1516738952771516,43793 -852.7306742668152,0.1593849658966064,1452.8758084774015,4505,0,1452.8758084774015,0.9843845963478088,0.0536156743764877,0.1683795304310171,43793,2305.887991666794,0.988291561603546,0.0410981923341751,0.1936304617762851,0.9853597283363342,0.0507598668336868,0.1696635347704308,43793 -970.4168586730956,0.186838150024414,1693.1669998168943,5253,0,1693.1669998168943,0.9847253561019896,0.0517366044223308,0.1818915693414284,43793,2663.9131495952606,0.9885271787643432,0.0402484983205795,0.2153660692381714,0.985521674156189,0.0495162829756736,0.1760585594420274,43793 -1089.2654836177826,0.2155897617340088,1933.394582033157,6001,0,1933.394582033157,0.9848209619522096,0.0510535649955272,0.2013832889900116,43793,3023.038010835648,0.9886388182640076,0.0387742333114147,0.2430892577206733,0.9857465624809264,0.0483935810625553,0.2052181951988453,43793 -1205.7322103977203,0.2425727844238281,2173.5919332504272,6758,0,2173.5919332504272,0.985062301158905,0.0504746623337268,0.214576990825933,43793,3379.749230146408,0.9886777997016908,0.0383498258888721,0.2587568848428609,0.9859300255775452,0.0477911457419395,0.2149441046105157,43793 -1325.9005455970764,0.2703602313995361,2413.561702489853,7517,0,2413.561702489853,0.9852644801139832,0.0494988821446895,0.22878059163496,43793,3739.935199022293,0.9893007874488832,0.0369826629757881,0.2733339480948629,0.9862178564071656,0.0469463691115379,0.2218923235563141,43793 -1448.121309518814,0.2982769012451172,2653.61355137825,8271,0,2653.61355137825,0.9852551817893982,0.0497630648314952,0.2241731692471199,43793,4102.255522012711,0.9892852902412416,0.0360622964799404,0.3039532837698702,0.9862227439880372,0.0468678250908851,0.2347299525440665,43793 -1563.973773241043,0.3276290893554687,2893.6486990451813,9022,0,2893.6486990451813,0.985448122024536,0.0490517504513263,0.2379081424435728,43793,4458.192250013351,0.989479660987854,0.0352666191756725,0.3198014137231609,0.986312448978424,0.0462530441582202,0.2304908775967825,43793 -1681.5201325416565,0.3571395874023437,3133.746441602707,9774,0,3133.746441602707,0.9855083227157592,0.0486259721219539,0.2380824970141881,43793,4815.885853528976,0.9896937608718872,0.0343454629182815,0.3432716140602922,0.9864041805267334,0.0458738133311271,0.2375307320474216,43793 -1797.7493011951449,0.3862593173980713,3373.892038345337,10525,0,3373.892038345337,0.9856384992599488,0.0480903573334217,0.2470925305778434,43793,5172.309501647949,0.9900869131088256,0.0333064012229442,0.3712344743367008,0.9865828156471252,0.0453640073537826,0.2488915138371127,43793 -1917.783516407013,0.4158000946044922,3614.049623250961,11281,0,3614.049623250961,0.985637664794922,0.0483180433511734,0.2483078179636221,43793,5532.550810575485,0.9902627468109132,0.0328679829835891,0.3630038486881168,0.986607551574707,0.0455364808440208,0.2463440279685129,43793 -2037.843641042709,0.4441356658935547,3854.0442173480974,12033,0,3854.0442173480974,0.9856178760528564,0.0485445819795131,0.2410220960823785,43793,5892.653676509857,0.9904043674468994,0.0323328152298927,0.3812291308726584,0.9865823984146118,0.0455449484288692,0.2456691085848675,43793 -2155.8782205581665,0.473520278930664,4094.288563728333,12788,0,4094.288563728333,0.9856410026550292,0.0485776364803314,0.2459781030991462,43793,6250.9817860126495,0.9902794361114502,0.0322381928563118,0.3840976298833426,0.986572265625,0.0456129871308803,0.2533034074588097,43793 -2276.909010410309,0.5021927356719971,4334.27161693573,13530,0,4334.27161693573,0.9857252836227416,0.0479781739413738,0.247405006703782,43793,6612.0441880226135,0.9906366467475892,0.031362771987915,0.4005199168896542,0.9865929484367372,0.0450634248554706,0.2480390896979932,43793 -2393.3545455932617,0.5375001430511475,4574.2215077877045,14272,0,4574.2215077877045,0.9857500791549684,0.0482579730451107,0.2457903663583589,43793,6968.494510889053,0.9906654953956604,0.03089433722198,0.4312031362505946,0.9866920113563538,0.0452400967478752,0.2555988970728234,43793 -2510.4126312732697,0.5691132545471191,4814.304394006729,15008,0,4814.304394006729,0.9857884645462036,0.0488702729344368,0.2505672932539682,43793,7325.689950466156,0.990509271621704,0.0307738613337278,0.433661814830482,0.986723244190216,0.0458978489041328,0.2576038652916376,43793 -2631.9517204761505,0.5978469848632812,5054.479521274567,15758,0,5054.479521274567,0.985659956932068,0.0484933592379093,0.2410234352551857,43793,7687.453608751297,0.9911080598831176,0.0294422395527362,0.4476567930154197,0.9865406155586244,0.0454853437840938,0.2522223362011558,43793 -2751.486648082733,0.627873420715332,5294.512858390808,16507,0,5294.512858390808,0.98580402135849,0.0486540645360946,0.2458925109038763,43793,8047.071723461151,0.9912184476852416,0.0284808743745088,0.4712633126096623,0.9866733551025392,0.0455878898501396,0.2604579027919871,43793 -2867.3289663791656,0.6588001251220703,5534.532721757889,17256,0,5534.532721757889,0.9859118461608888,0.0483159162104129,0.25409906685934,43793,8402.984964370728,0.9917227029800416,0.0272119287401437,0.5129358647554417,0.9868016242980956,0.0451838597655296,0.2661670100367674,43793 -2982.892689228058,0.689293384552002,5774.725650072098,18010,0,5774.725650072098,0.9857736825942992,0.0492622181773185,0.2460779115552562,43793,8758.792022228241,0.9915904998779296,0.0272315237671136,0.520161026037464,0.9867346286773682,0.0460444428026676,0.261773279133514,43793 -3101.454934358597,0.7188436985015869,6014.775934457779,18764,0,6014.775934457779,0.9858528971672058,0.0488565750420093,0.2505815096182588,43793,9117.453943014145,0.9917588233947754,0.0268174763768911,0.5093664431630106,0.9866693019866944,0.0456399135291576,0.2636575558792305,43793 -3214.5177624225616,0.7481474876403809,6254.735512256622,19513,0,6254.735512256622,0.9858368635177612,0.0486247912049293,0.2516486277602932,43793,9470.525463342668,0.9917542934417723,0.0272373333573341,0.492260297195812,0.9866904020309448,0.0456603653728961,0.2597286529450691,43793 -3329.5672364234924,0.7790920734405518,6494.936870098114,20277,0,6494.936870098114,0.9858145713806152,0.04971195012331,0.2459385786965277,43793,9825.827875614166,0.9916300177574158,0.0272173155099153,0.5153293861606729,0.9866891503334044,0.0464046634733676,0.2622391754031578,43793 -3444.192747116089,0.8117268085479736,6735.167747020721,21037,0,6735.167747020721,0.9858802556991576,0.0489316806197166,0.2508172272557681,43793,10180.737467765808,0.991977035999298,0.0264404602348804,0.5115002049030158,0.9867467880249025,0.0457068383693695,0.2609522036510439,43793 -3558.997751712799,0.842423677444458,6975.293369054794,21796,0,6975.293369054794,0.9857686161994934,0.0502894409000873,0.2469956287644825,43793,10535.719145298004,0.9917345643043518,0.0266399029642343,0.5187101817819291,0.986717164516449,0.0467933230102062,0.2639815040254276,43793 -3673.886606693268,0.8734378814697266,7215.500328779221,22547,0,7215.500328779221,0.9858317971229552,0.0495360493659973,0.2483276572328031,43793,10890.866336345673,0.992116153240204,0.0254687368869781,0.5342317305319013,0.986757755279541,0.0462933629751205,0.2647903676022655,43793 -3789.942592382431,0.9038915634155272,7455.528498411179,23294,0,7455.528498411179,0.9859139323234558,0.0498078241944313,0.2459262801428786,43793,11247.001574993134,0.9923154711723328,0.0248916465789079,0.5460445997356278,0.9866899847984314,0.0465769246220588,0.254345836560073,43793 -3904.790680646896,0.9359946250915528,7695.701915502548,24044,0,7695.701915502548,0.9857686161994934,0.0509364232420921,0.2463427566125356,43793,11602.075031757357,0.9923921823501588,0.0241668839007616,0.5712409322426214,0.9866307377815248,0.0474363043904304,0.2526439866836966,43793 -4021.984845399857,0.9708609580993652,7935.791239023209,24798,0,7935.791239023209,0.9857513904571532,0.0502862893044948,0.2472731325056807,43793,11959.413158655168,0.9930984377861024,0.0224706102162599,0.5961115454126245,0.9866713285446168,0.0468453131616115,0.2598697709526262,43793 -4147.293561458588,1.0030477046966553,8176.01961183548,25538,0,8176.01961183548,0.9857223033905028,0.0505571216344833,0.2435503236909003,43793,12325.004558086395,0.9933809638023376,0.0218420680612325,0.6266066179462839,0.986694872379303,0.0469269268214702,0.2572042716750418,43793 -4262.936295509338,1.0376760959625244,8416.271092653275,26291,0,8416.271092653275,0.9857041835784912,0.0503971055150032,0.2480874887484809,43793,12680.953973293304,0.9928025603294371,0.0234451312571764,0.5761827578002454,0.9866213798522948,0.0469994507730007,0.2511379755020969,43793 -4382.595624923706,1.070669651031494,8656.47072982788,27040,0,8656.47072982788,0.9857812523841858,0.0514933206140995,0.2459544175952203,43793,13040.865669727324,0.992655873298645,0.0236091390252113,0.5897485965947075,0.9866039156913756,0.048078216612339,0.2507343853926724,43793 -4496.867026567459,1.102473258972168,8896.424266338348,27791,0,8896.424266338348,0.9857842326164246,0.0510902516543865,0.2458477064293119,43793,13395.142271757126,0.9928238987922668,0.0232392903417348,0.5776398059058571,0.9865588545799256,0.0476624406874179,0.2508783567306901,43793 -4612.7197296619415,1.1351377964019775,9136.532777786257,28536,0,9136.532777786257,0.9857046008110046,0.0521145649254322,0.2429519902380678,43793,13751.155616044998,0.9926972389221193,0.023307790979743,0.5889419244248999,0.9865353107452391,0.0485096462070941,0.2496172886362365,43793 -4732.798652887344,1.166337966918945,9376.589265823364,29288,0,9376.589265823364,0.9856772422790528,0.0517030544579029,0.2464958534102047,43793,14111.342035531998,0.9929196834564208,0.0228444803506135,0.5824909793466204,0.9865726828575134,0.0482287295162677,0.2497479171584527,43793 -4844.209650993347,1.199488401412964,9616.7351603508,30037,0,9616.7351603508,0.9855828881263732,0.0522092543542385,0.239892505514401,43793,14462.951917409897,0.9931166768074036,0.0221618507057428,0.6110780190593907,0.9865093231201172,0.0487466380000114,0.2475532409059372,43793 -4960.09548664093,1.2316246032714844,9856.721096038818,30790,0,9856.721096038818,0.9857505559921264,0.0530433468520641,0.2417635372968837,43793,14818.875288009644,0.9931876063346864,0.0214119311422109,0.6239691444803845,0.9865888953208924,0.0495238825678825,0.2529329067237749,43793 -5071.77366065979,1.265533208847046,10096.81805896759,31529,0,10096.81805896759,0.9855732321739196,0.0524737164378166,0.2433974481567167,43793,15170.70418548584,0.9937662482261658,0.020005514845252,0.6605400162041181,0.9865211248397828,0.0490268692374229,0.2478563465140137,43793 -5190.624104499817,1.299309253692627,10336.932120800018,32270,0,10336.932120800018,0.9855959415435792,0.0532693080604076,0.2317083155339858,43793,15529.72211575508,0.9942206740379332,0.0189791806042194,0.6804974564516768,0.9864314198493958,0.0497116670012474,0.2431531213486542,43793 -5309.052172183991,1.3314814567565918,10577.190371990204,33024,0,10577.190371990204,0.985584557056427,0.053100511431694,0.2379802507963694,43793,15888.460463523865,0.9941815733909608,0.0190620310604572,0.6768720292775423,0.9864902496337892,0.0494199618697166,0.2456189152667728,43793 -5426.022752046585,1.362900733947754,10817.216018676758,33781,0,10817.216018676758,0.9855904579162598,0.0534843131899833,0.2413656636156215,43793,16245.507838010788,0.9938918352127076,0.0196193568408489,0.6653318432091149,0.9864837527275084,0.0499318204820156,0.2432102650466545,43793 -5542.016725063324,1.3957459926605225,11057.21807861328,34531,0,11057.21807861328,0.9854485392570496,0.053956814110279,0.2340651284593916,43793,16601.556545495987,0.9931476712226868,0.0218456238508224,0.6065809626161704,0.9863153100013732,0.0502861440181732,0.2356881656238881,43793 -5656.931999921799,1.4285671710968018,11297.394924879074,35289,0,11297.394924879074,0.9854581952095032,0.0542607828974723,0.2370528405868589,43793,16956.7015914917,0.9932732582092284,0.0213515106588602,0.6102445285731928,0.9863283038139344,0.0507453233003616,0.2422555799690798,43793 -5772.85106253624,1.4614310264587402,11537.618818044662,36047,0,11537.618818044662,0.9854379892349244,0.0543590039014816,0.2314366596590499,43793,17312.897706270218,0.9936595559120178,0.0201679468154907,0.6438834465263861,0.9863104224205016,0.0507057793438434,0.2338064859736999,43793 -5890.137528896332,1.8974733352661133,11777.46779179573,36797,0,11777.46779179573,0.985491931438446,0.0551511570811271,0.2346267707140708,43793,17670.489493846893,0.9936636686325072,0.0199106354266405,0.6476884772247021,0.9864094853401184,0.0514586195349693,0.2377665407309361,43793 -6002.826015472412,1.9313278198242188,12017.64917397499,37549,0,12017.64917397499,0.9854986667633056,0.0555910542607307,0.2343527281268959,43793,18023.41338586808,0.9935628175735474,0.0201002769172191,0.6463171003589824,0.9863964915275574,0.0519293472170829,0.2382274855362635,43793 -6117.448066949844,1.9647307395935056,12257.801238059998,38299,0,12257.801238059998,0.9853967428207396,0.0553966537117958,0.2333304576637424,43793,18378.24086880684,0.9944259524345398,0.0178183428943157,0.7018950843743552,0.9862747192382812,0.0519014485180377,0.2361996599409608,43793 -6234.392640352249,1.999298334121704,12497.982424497604,39053,0,12497.982424497604,0.985359251499176,0.055601317435503,0.2355170950865257,43793,18735.421145915985,0.9949982762336732,0.0165216345340013,0.7224225479696811,0.9863018989562988,0.0520194247364997,0.2346073985838009,43793 -6351.638965606689,2.03450345993042,12738.079112768171,39792,0,12738.079112768171,0.9853891134262084,0.0565869100391864,0.2302643002277187,43793,19092.81990480423,0.9952055215835572,0.0157686341553926,0.7366664288057573,0.9863213896751404,0.0527669936418533,0.2330148621436922,43793 -6470.0264637470245,2.069436550140381,12978.10954594612,40539,0,12978.10954594612,0.9852867722511292,0.0564857460558414,0.2309985263228809,43793,19451.29268836975,0.9952002763748168,0.0160813517868518,0.7377185027103813,0.9861488342285156,0.0530131682753562,0.2330946343841542,43793 -6589.574536561966,2.107372522354126,13218.214102506638,41274,0,13218.214102506638,0.985367238521576,0.0568977668881416,0.228333392060031,43793,19811.006506443024,0.9948239326477052,0.0167916342616081,0.7085789565068721,0.9862288236618042,0.0531344339251518,0.2288084561811299,43793 -6704.920714616776,2.143451690673828,13458.43649482727,42022,0,13458.43649482727,0.9854017496109008,0.05723188072443,0.2287596771324323,43793,20166.63150715828,0.9945333003997804,0.0173073932528495,0.7135553817663034,0.9861322045326232,0.0535982139408588,0.2313976091081589,43793 -6821.973500728607,2.1787869930267334,13698.556680202484,42774,0,13698.556680202484,0.9853150248527528,0.0580363795161247,0.2279296138834921,43793,20523.85961985588,0.9939397573471068,0.0189362727105617,0.6697946583231249,0.9862284064292908,0.0541604533791542,0.2336562058187641,43793 -6937.488364696503,2.2132139205932617,13938.790473937988,43522,0,13938.790473937988,0.9853845238685608,0.0585099793970584,0.2311539778899343,43793,20879.66306090355,0.994067132472992,0.0184782650321722,0.6751643354404507,0.9861873984336852,0.0549007356166839,0.2255450868821931,43793 -7050.623664855957,2.248157262802124,14178.823575496674,44274,0,14178.823575496674,0.9851073622703552,0.0578879825770854,0.2267297435673207,43793,21232.88632273674,0.9948219656944276,0.0166021138429641,0.7204996206656971,0.9860494136810304,0.0540257133543491,0.2289064375174395,43793 -7161.983935594559,2.284372329711914,14418.789083957672,45021,0,14418.789083957672,0.9852181673049928,0.058937769383192,0.2193470926250697,43793,21584.268087387085,0.99458110332489,0.0170487016439437,0.7150672404338416,0.9861192107200624,0.054877046495676,0.221464246573604,43793 -7272.9426436424255,2.3197762966156006,14658.8561296463,45776,0,14658.8561296463,0.985298991203308,0.0596415922045707,0.2246596163663179,43793,21935.34939599037,0.994462788105011,0.0171351470053195,0.7201120524918182,0.9861781001091005,0.0556235834956169,0.2299968563395169,43793 -7386.477149009705,2.3562400341033936,14898.866291761398,46528,0,14898.866291761398,0.9852783679962158,0.0603801384568214,0.2273424703023917,43793,22288.94997239113,0.9965863823890686,0.0121391955763101,0.8180281852483683,0.986148476600647,0.0563682243227958,0.231505998139778,43793 -7496.979228496551,2.3927478790283203,15139.000858545303,47279,0,15139.000858545303,0.9851911664009094,0.0600710287690162,0.2206319943723637,43793,22639.64312577248,0.996531307697296,0.0124534703791141,0.8025178704840774,0.9861151576042176,0.0560289360582828,0.2313723651048697,43793 -7612.681452512741,2.4296257495880127,15379.0965590477,48028,0,15379.0965590477,0.985040843486786,0.0601809658110141,0.2190393447107278,43793,22995.49823999405,0.9962067008018494,0.0131070259958505,0.7931497384251549,0.986046552658081,0.0562718920409679,0.2260776234022609,43793 -7727.149502515793,2.468273878097534,15619.057337284088,48773,0,15619.057337284088,0.9852885007858276,0.061463788151741,0.2218434503684044,43793,23349.98605751992,0.9955082535743712,0.0142448795959353,0.774797287980069,0.986183762550354,0.0573287643492221,0.228134299896944,43793 -7836.429103851318,2.5045437812805176,15859.288888454435,49526,0,15859.288888454435,0.9851848483085632,0.0611521825194358,0.2181409027100457,43793,23699.553965330124,0.995168685913086,0.0149586545303463,0.7747991965325125,0.9860441088676452,0.0570524372160434,0.2230743490184299,43793 -7947.837135314941,2.5418262481689453,16099.445001363754,50272,0,16099.445001363754,0.9852353930473328,0.0616249740123748,0.221362975017089,43793,24051.175461292267,0.99516099691391,0.0150730079039931,0.7670223106612885,0.986040472984314,0.0577150806784629,0.2250570354018474,43793 -8055.256649494171,2.5816922187805176,16339.63740158081,51016,0,16339.63740158081,0.9851894974708556,0.0623769909143447,0.2168952545827823,43793,24398.847739219666,0.994962215423584,0.0152889490127563,0.7472203813139899,0.9861119389533995,0.0582580640912056,0.220357172150594,43793 -8166.622575998306,2.619372129440308,16579.84620285034,51765,0,16579.84620285034,0.9851473569869996,0.0621534325182437,0.2212461640478424,43793,24750.48007750511,0.9956825971603394,0.0138268237933516,0.788454380210666,0.9860782027244568,0.0579963140189647,0.2223265183184763,43793 -8280.05424952507,2.6561763286590576,16819.83504796028,52515,0,16819.83504796028,0.9851554036140442,0.0629678145051002,0.2186763948138787,43793,25103.957667827606,0.9962998032569884,0.0124083021655678,0.8206180664414136,0.9861164093017578,0.0587662868201732,0.2223903278764082,43793 -8392.8616604805,2.697401523590088,17059.884666204453,53261,0,17059.884666204453,0.985207200050354,0.0634720176458358,0.2194677612263874,43793,25456.877187013622,0.997327983379364,0.0103095285594463,0.8675313433747215,0.9860234260559082,0.0592837184667587,0.2185852231293817,43793 -8505.566602230072,2.735142230987549,17299.912273406982,54013,0,17299.912273406982,0.9851477742195128,0.063467264175415,0.2142464943903636,43793,25809.669014453888,0.996036410331726,0.0128533076494932,0.8124292085607783,0.9860871434211732,0.0590945109724998,0.2162775368447429,43793 -8611.9836332798,2.773247480392456,17539.928025960922,54762,0,17539.928025960922,0.9850918054580688,0.0639261305332183,0.2151637094748986,43793,26156.15963792801,0.9976819157600404,0.0096607785671949,0.8719512232661151,0.9859962463378906,0.059577465057373,0.2174690053207167,43793 -8719.667748212814,2.8100857734680176,17780.066515922546,55525,0,17780.066515922546,0.9851404428482056,0.0652488842606544,0.214197982853505,43793,26504.03938817978,0.9970630407333374,0.010564861819148,0.8574635624911153,0.9860855340957642,0.0606580302119255,0.2188187087485164,43793 -8833.25545835495,2.8469173908233643,18020.076694488525,56277,0,18020.076694488525,0.9851536750793456,0.0651892498135566,0.2122587202567049,43793,26857.69357442856,0.9965944290161132,0.0112590510398149,0.8476336181469304,0.9860563278198242,0.0609218515455722,0.2172369927346399,43793 -8949.157200813293,2.883453845977783,18260.329787254333,57029,0,18260.329787254333,0.985084593296051,0.06547699123620987,0.2110058729860776,43793,27213.905391216278,0.9959918260574341,0.012544776313006878,0.8311303477558817,0.9860153198242188,0.06112876161932945,0.21309542368130246,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/measurements.csv deleted file mode 100644 index 6520a52a3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/measurements.csv +++ /dev/null @@ -1,656 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3629005,0.73958975,,,,,,,,,,,,,,,,, -1,,,0.5092054009437561,0.7392170429229736,0.0224031943319573,0.5064647793769836,0.7422076463699341,0.0253016651271869,43793.0,0.5047284960746765,0.7438743114471436,0.0269645186056544,43793.0,12.25728178024292,132.69875383377075,12.25728178024292,120.44142413139345,0.0,0.0 -100,0.4808574,0.44099918,,,,,,,,,,,,,,,,, -200,0.38297063,0.33363512,,,,,,,,,,,,,,,,, -300,0.2753433,0.2292048,,,,,,,,,,,,,,,,, -400,0.18350214,0.15001622,,,,,,,,,,,,,,,,, -500,0.108127244,0.10663183,,,,,,,,,,,,,,,,, -600,0.06689893,0.08002707,,,,,,,,,,,,,,,,, -700,0.050253335,0.06799725,,,,,,,,,,,,,,,,, -740,,,0.9868019819259644,0.0672526881098747,0.0377212887011602,0.9841179251670836,0.0759080797433853,0.0384024595839474,43793.0,0.983142077922821,0.0788627117872238,0.0411033540555226,43793.0,252.3354289531708,494.3655803203583,252.3354289531708,241.9868643283844,0.0215451717376709,0.0 -800,0.041874092,0.06057031,,,,,,,,,,,,,,,,, -900,0.02691705,0.056484792,,,,,,,,,,,,,,,,, -1000,0.0420182,0.058245055,,,,,,,,,,,,,,,,, -1100,0.023898397,0.053572636,,,,,,,,,,,,,,,,, -1200,0.033425447,0.049806528,,,,,,,,,,,,,,,,, -1300,0.04495272,0.049098205,,,,,,,,,,,,,,,,, -1400,0.028590066,0.05068437,,,,,,,,,,,,,,,,, -1486,,,0.9867600202560424,0.0523919649422168,0.0754291005588222,0.984138250350952,0.0624441169202327,0.0756716920682181,43793.0,0.9831559658050536,0.0658572912216186,0.0769709794151032,43793.0,492.3489384651184,856.7501437664032,492.3489384651184,364.3091459274292,0.0499124526977539,0.0 -1500,0.1369007,0.052592576,,,,,,,,,,,,,,,,, -1600,0.027008569,0.047057815,,,,,,,,,,,,,,,,, -1700,0.05964297,0.05211306,,,,,,,,,,,,,,,,, -1800,0.052559126,0.055155344,,,,,,,,,,,,,,,,, -1900,0.03718618,0.051482916,,,,,,,,,,,,,,,,, -2000,0.06785948,0.054633517,,,,,,,,,,,,,,,,, -2100,0.07821021,0.05304722,,,,,,,,,,,,,,,,, -2200,0.06236004,0.043811034,,,,,,,,,,,,,,,,, -2240,,,0.9872581958770752,0.0459452904760837,0.1198747155496232,0.9846627116203308,0.0551317036151886,0.1179890674866964,43793.0,0.98364919424057,0.058329090476036,0.1170758496200459,43793.0,732.587010383606,1218.8280143737793,732.587010383606,486.101655960083,0.0771832466125488,0.0 -2300,0.07093475,0.044074655,,,,,,,,,,,,,,,,, -2400,0.069040954,0.04923712,,,,,,,,,,,,,,,,, -2500,0.030072901,0.048720043,,,,,,,,,,,,,,,,, -2600,0.03576217,0.04711239,,,,,,,,,,,,,,,,, -2700,0.035910588,0.043158274,,,,,,,,,,,,,,,,, -2800,0.13001873,0.04835405,,,,,,,,,,,,,,,,, -2900,0.019548818,0.046035226,,,,,,,,,,,,,,,,, -2992,,,0.987499177455902,0.0444191396236419,0.1499075671709747,0.984927773475647,0.0537693053483963,0.1389333456550538,43793.0,0.9839326739311218,0.0568287819623947,0.1407996392358726,43793.0,972.7253069877625,1581.3003692626953,972.7253069877625,608.3868906497955,0.1059887409210205,0.0 -3000,0.040884104,0.043720752,,,,,,,,,,,,,,,,, -3100,0.020725852,0.046005532,,,,,,,,,,,,,,,,, -3200,0.026774332,0.049195126,,,,,,,,,,,,,,,,, -3300,0.02314051,0.04349512,,,,,,,,,,,,,,,,, -3400,0.035759907,0.048977356,,,,,,,,,,,,,,,,, -3500,0.028464416,0.0510147,,,,,,,,,,,,,,,,, -3600,0.034922346,0.04663686,,,,,,,,,,,,,,,,, -3700,0.0374221,0.045592934,,,,,,,,,,,,,,,,, -3750,,,0.9879691004753112,0.0425723828375339,0.1707366148038464,0.985083281993866,0.0522946529090404,0.1516738952771516,43793.0,0.984113335609436,0.0551243908703327,0.1527203936176814,43793.0,1212.699630022049,1945.7496001720428,1212.699630022049,732.8152990341187,0.1326751708984375,0.0 -3800,0.040262595,0.046298288,,,,,,,,,,,,,,,,, -3900,0.022441143,0.046430703,,,,,,,,,,,,,,,,, -4000,0.03292108,0.04137175,,,,,,,,,,,,,,,,, -4100,0.034134917,0.04366538,,,,,,,,,,,,,,,,, -4200,0.034424704,0.044111017,,,,,,,,,,,,,,,,, -4300,0.017791802,0.044145554,,,,,,,,,,,,,,,,, -4400,0.063005544,0.047960818,,,,,,,,,,,,,,,,, -4500,0.021414608,0.042866945,,,,,,,,,,,,,,,,, -4505,,,0.988291561603546,0.0410981923341751,0.1936304617762851,0.9853597283363342,0.0507598668336868,0.1696635347704308,43793.0,0.9843845963478088,0.0536156743764877,0.1683795304310171,43793.0,1452.8758084774015,2305.887991666794,1452.8758084774015,852.7306742668152,0.1593849658966064,0.0 -4600,0.018829973,0.04617383,,,,,,,,,,,,,,,,, -4700,0.039226152,0.042966034,,,,,,,,,,,,,,,,, -4800,0.029697899,0.04674733,,,,,,,,,,,,,,,,, -4900,0.01949074,0.04942446,,,,,,,,,,,,,,,,, -5000,0.017913913,0.044358175,,,,,,,,,,,,,,,,, -5100,0.01929789,0.04199,,,,,,,,,,,,,,,,, -5200,0.011929977,0.04051881,,,,,,,,,,,,,,,,, -5253,,,0.9885271787643432,0.0402484983205795,0.2153660692381714,0.985521674156189,0.0495162829756736,0.1760585594420274,43793.0,0.9847253561019896,0.0517366044223308,0.1818915693414284,43793.0,1693.1669998168943,2663.9131495952606,1693.1669998168943,970.4168586730956,0.186838150024414,0.0 -5300,0.04814299,0.043045223,,,,,,,,,,,,,,,,, -5400,0.023561046,0.042695694,,,,,,,,,,,,,,,,, -5500,0.02545079,0.046029825,,,,,,,,,,,,,,,,, -5600,0.021275079,0.042160146,,,,,,,,,,,,,,,,, -5700,0.014285329,0.04205296,,,,,,,,,,,,,,,,, -5800,0.03383029,0.04543103,,,,,,,,,,,,,,,,, -5900,0.020481884,0.044695307,,,,,,,,,,,,,,,,, -6000,0.014645288,0.039068308,,,,,,,,,,,,,,,,, -6001,,,0.9886388182640076,0.0387742333114147,0.2430892577206733,0.9857465624809264,0.0483935810625553,0.2052181951988453,43793.0,0.9848209619522096,0.0510535649955272,0.2013832889900116,43793.0,1933.394582033157,3023.038010835648,1933.394582033157,1089.2654836177826,0.2155897617340088,0.0 -6100,0.022228919,0.0460318,,,,,,,,,,,,,,,,, -6200,0.022617348,0.040433876,,,,,,,,,,,,,,,,, -6300,0.014728956,0.043067344,,,,,,,,,,,,,,,,, -6400,0.01720358,0.041452356,,,,,,,,,,,,,,,,, -6500,0.013957953,0.04020835,,,,,,,,,,,,,,,,, -6600,0.028844729,0.043277252,,,,,,,,,,,,,,,,, -6700,0.027856365,0.044947747,,,,,,,,,,,,,,,,, -6758,,,0.9886777997016908,0.0383498258888721,0.2587568848428609,0.9859300255775452,0.0477911457419395,0.2149441046105157,43793.0,0.985062301158905,0.0504746623337268,0.214576990825933,43793.0,2173.5919332504272,3379.749230146408,2173.5919332504272,1205.7322103977203,0.2425727844238281,0.0 -6800,0.015329155,0.039557602,,,,,,,,,,,,,,,,, -6900,0.015649186,0.040290937,,,,,,,,,,,,,,,,, -7000,0.02259306,0.04151819,,,,,,,,,,,,,,,,, -7100,0.012336362,0.03723069,,,,,,,,,,,,,,,,, -7200,0.015385642,0.039897792,,,,,,,,,,,,,,,,, -7300,0.014128588,0.04039552,,,,,,,,,,,,,,,,, -7400,0.025054535,0.042574435,,,,,,,,,,,,,,,,, -7500,0.021992216,0.042870738,,,,,,,,,,,,,,,,, -7517,,,0.9893007874488832,0.0369826629757881,0.2733339480948629,0.9862178564071656,0.0469463691115379,0.2218923235563141,43793.0,0.9852644801139832,0.0494988821446895,0.22878059163496,43793.0,2413.561702489853,3739.935199022293,2413.561702489853,1325.9005455970764,0.2703602313995361,0.0 -7600,0.021109497,0.04365589,,,,,,,,,,,,,,,,, -7700,0.01784635,0.045671813,,,,,,,,,,,,,,,,, -7800,0.014111357,0.039380793,,,,,,,,,,,,,,,,, -7900,0.01538893,0.041189294,,,,,,,,,,,,,,,,, -8000,0.013363793,0.040979233,,,,,,,,,,,,,,,,, -8100,0.021156361,0.043772474,,,,,,,,,,,,,,,,, -8200,0.039661907,0.041262213,,,,,,,,,,,,,,,,, -8271,,,0.9892852902412416,0.0360622964799404,0.3039532837698702,0.9862227439880372,0.0468678250908851,0.2347299525440665,43793.0,0.9852551817893982,0.0497630648314952,0.2241731692471199,43793.0,2653.61355137825,4102.255522012711,2653.61355137825,1448.121309518814,0.2982769012451172,0.0 -8300,0.016662382,0.044823356,,,,,,,,,,,,,,,,, -8400,0.030056266,0.039238434,,,,,,,,,,,,,,,,, -8500,0.013854158,0.0438358,,,,,,,,,,,,,,,,, -8600,0.01407483,0.040460136,,,,,,,,,,,,,,,,, -8700,0.016129376,0.041144118,,,,,,,,,,,,,,,,, -8800,0.016592879,0.04012928,,,,,,,,,,,,,,,,, -8900,0.023191432,0.04157338,,,,,,,,,,,,,,,,, -9000,0.0148254465,0.039406486,,,,,,,,,,,,,,,,, -9022,,,0.989479660987854,0.0352666191756725,0.3198014137231609,0.986312448978424,0.0462530441582202,0.2304908775967825,43793.0,0.985448122024536,0.0490517504513263,0.2379081424435728,43793.0,2893.6486990451813,4458.192250013351,2893.6486990451813,1563.973773241043,0.3276290893554687,0.0 -9100,0.014882458,0.04489697,,,,,,,,,,,,,,,,, -9200,0.015806183,0.040871732,,,,,,,,,,,,,,,,, -9300,0.019761294,0.04255205,,,,,,,,,,,,,,,,, -9400,0.013758177,0.038938303,,,,,,,,,,,,,,,,, -9500,0.017349454,0.039340846,,,,,,,,,,,,,,,,, -9600,0.013504926,0.03766906,,,,,,,,,,,,,,,,, -9700,0.023705237,0.03942273,,,,,,,,,,,,,,,,, -9774,,,0.9896937608718872,0.0343454629182815,0.3432716140602922,0.9864041805267334,0.0458738133311271,0.2375307320474216,43793.0,0.9855083227157592,0.0486259721219539,0.2380824970141881,43793.0,3133.746441602707,4815.885853528976,3133.746441602707,1681.5201325416565,0.3571395874023437,0.0 -9800,0.014001304,0.04122029,,,,,,,,,,,,,,,,, -9900,0.013141775,0.037359834,,,,,,,,,,,,,,,,, -10000,0.018196782,0.04204687,,,,,,,,,,,,,,,,, -10100,0.017647414,0.040999055,,,,,,,,,,,,,,,,, -10200,0.029839104,0.0421023,,,,,,,,,,,,,,,,, -10300,0.019653834,0.039820757,,,,,,,,,,,,,,,,, -10400,0.015386586,0.036832083,,,,,,,,,,,,,,,,, -10500,0.02033401,0.039774444,,,,,,,,,,,,,,,,, -10525,,,0.9900869131088256,0.0333064012229442,0.3712344743367008,0.9865828156471252,0.0453640073537826,0.2488915138371127,43793.0,0.9856384992599488,0.0480903573334217,0.2470925305778434,43793.0,3373.892038345337,5172.309501647949,3373.892038345337,1797.7493011951449,0.3862593173980713,0.0 -10600,0.0182419,0.04155783,,,,,,,,,,,,,,,,, -10700,0.016355716,0.040162157,,,,,,,,,,,,,,,,, -10800,0.014443124,0.03703387,,,,,,,,,,,,,,,,, -10900,0.017948378,0.039734505,,,,,,,,,,,,,,,,, -11000,0.016880987,0.039959937,,,,,,,,,,,,,,,,, -11100,0.0151933925,0.038701836,,,,,,,,,,,,,,,,, -11200,0.013843781,0.037903275,,,,,,,,,,,,,,,,, -11281,,,0.9902627468109132,0.0328679829835891,0.3630038486881168,0.986607551574707,0.0455364808440208,0.2463440279685129,43793.0,0.985637664794922,0.0483180433511734,0.2483078179636221,43793.0,3614.049623250961,5532.550810575485,3614.049623250961,1917.783516407013,0.4158000946044922,0.0 -11300,0.017288702,0.040832605,,,,,,,,,,,,,,,,, -11400,0.014612814,0.039495483,,,,,,,,,,,,,,,,, -11500,0.019290706,0.03856466,,,,,,,,,,,,,,,,, -11600,0.013782036,0.03550627,,,,,,,,,,,,,,,,, -11700,0.014726845,0.038166467,,,,,,,,,,,,,,,,, -11800,0.016501972,0.039752986,,,,,,,,,,,,,,,,, -11900,0.020284092,0.041540626,,,,,,,,,,,,,,,,, -12000,0.016237378,0.039757967,,,,,,,,,,,,,,,,, -12033,,,0.9904043674468994,0.0323328152298927,0.3812291308726584,0.9865823984146118,0.0455449484288692,0.2456691085848675,43793.0,0.9856178760528564,0.0485445819795131,0.2410220960823785,43793.0,3854.0442173480974,5892.653676509857,3854.0442173480974,2037.843641042709,0.4441356658935547,0.0 -12100,0.015644522,0.03607673,,,,,,,,,,,,,,,,, -12200,0.01818508,0.03799874,,,,,,,,,,,,,,,,, -12300,0.015715258,0.03706095,,,,,,,,,,,,,,,,, -12400,0.013793069,0.03685797,,,,,,,,,,,,,,,,, -12500,0.019519104,0.036685698,,,,,,,,,,,,,,,,, -12600,0.016446555,0.037085235,,,,,,,,,,,,,,,,, -12700,0.018140946,0.036497205,,,,,,,,,,,,,,,,, -12788,,,0.9902794361114502,0.0322381928563118,0.3840976298833426,0.986572265625,0.0456129871308803,0.2533034074588097,43793.0,0.9856410026550292,0.0485776364803314,0.2459781030991462,43793.0,4094.288563728333,6250.9817860126495,4094.288563728333,2155.8782205581665,0.473520278930664,0.0 -12800,0.024640564,0.038245827,,,,,,,,,,,,,,,,, -12900,0.019956226,0.038809504,,,,,,,,,,,,,,,,, -13000,0.01912824,0.041984424,,,,,,,,,,,,,,,,, -13100,0.017040052,0.04142825,,,,,,,,,,,,,,,,, -13200,0.02410711,0.03577059,,,,,,,,,,,,,,,,, -13300,0.014968814,0.037994523,,,,,,,,,,,,,,,,, -13400,0.021449363,0.040062875,,,,,,,,,,,,,,,,, -13500,0.013948858,0.0353881,,,,,,,,,,,,,,,,, -13530,,,0.9906366467475892,0.031362771987915,0.4005199168896542,0.9865929484367372,0.0450634248554706,0.2480390896979932,43793.0,0.9857252836227416,0.0479781739413738,0.247405006703782,43793.0,4334.27161693573,6612.0441880226135,4334.27161693573,2276.909010410309,0.5021927356719971,0.0 -13600,0.022429595,0.036683775,,,,,,,,,,,,,,,,, -13700,0.019656673,0.03742568,,,,,,,,,,,,,,,,, -13800,0.02073924,0.035295647,,,,,,,,,,,,,,,,, -13900,0.015251582,0.034541633,,,,,,,,,,,,,,,,, -14000,0.017532714,0.037320636,,,,,,,,,,,,,,,,, -14100,0.016400497,0.03463339,,,,,,,,,,,,,,,,, -14200,0.017659724,0.034310266,,,,,,,,,,,,,,,,, -14272,,,0.9906654953956604,0.03089433722198,0.4312031362505946,0.9866920113563538,0.0452400967478752,0.2555988970728234,43793.0,0.9857500791549684,0.0482579730451107,0.2457903663583589,43793.0,4574.2215077877045,6968.494510889053,4574.2215077877045,2393.3545455932617,0.5375001430511475,0.0 -14300,0.020652667,0.038906265,,,,,,,,,,,,,,,,, -14400,0.026631543,0.036799654,,,,,,,,,,,,,,,,, -14500,0.015681192,0.033929355,,,,,,,,,,,,,,,,, -14600,0.017360082,0.038111117,,,,,,,,,,,,,,,,, -14700,0.01821726,0.04035226,,,,,,,,,,,,,,,,, -14800,0.024493728,0.03608419,,,,,,,,,,,,,,,,, -14900,0.021524148,0.037821136,,,,,,,,,,,,,,,,, -15000,0.02125665,0.03684442,,,,,,,,,,,,,,,,, -15008,,,0.990509271621704,0.0307738613337278,0.433661814830482,0.986723244190216,0.0458978489041328,0.2576038652916376,43793.0,0.9857884645462036,0.0488702729344368,0.2505672932539682,43793.0,4814.304394006729,7325.689950466156,4814.304394006729,2510.4126312732697,0.5691132545471191,0.0 -15100,0.025555471,0.040677756,,,,,,,,,,,,,,,,, -15200,0.025924558,0.040409286,,,,,,,,,,,,,,,,, -15300,0.01946985,0.0364627,,,,,,,,,,,,,,,,, -15400,0.022016324,0.037302252,,,,,,,,,,,,,,,,, -15500,0.02066824,0.03680235,,,,,,,,,,,,,,,,, -15600,0.018983485,0.036465086,,,,,,,,,,,,,,,,, -15700,0.020107111,0.03377728,,,,,,,,,,,,,,,,, -15758,,,0.9911080598831176,0.0294422395527362,0.4476567930154197,0.9865406155586244,0.0454853437840938,0.2522223362011558,43793.0,0.985659956932068,0.0484933592379093,0.2410234352551857,43793.0,5054.479521274567,7687.453608751297,5054.479521274567,2631.9517204761505,0.5978469848632812,0.0 -15800,0.021129083,0.0354642,,,,,,,,,,,,,,,,, -15900,0.02488704,0.035101987,,,,,,,,,,,,,,,,, -16000,0.021430496,0.035657085,,,,,,,,,,,,,,,,, -16100,0.021225788,0.035394683,,,,,,,,,,,,,,,,, -16200,0.021711687,0.037141822,,,,,,,,,,,,,,,,, -16300,0.022484323,0.036503524,,,,,,,,,,,,,,,,, -16400,0.022814186,0.0354808,,,,,,,,,,,,,,,,, -16500,0.019834155,0.035186402,,,,,,,,,,,,,,,,, -16507,,,0.9912184476852416,0.0284808743745088,0.4712633126096623,0.9866733551025392,0.0455878898501396,0.2604579027919871,43793.0,0.98580402135849,0.0486540645360946,0.2458925109038763,43793.0,5294.512858390808,8047.071723461151,5294.512858390808,2751.486648082733,0.627873420715332,0.0 -16600,0.0214916,0.037187863,,,,,,,,,,,,,,,,, -16700,0.02362353,0.035839774,,,,,,,,,,,,,,,,, -16800,0.021303305,0.036762882,,,,,,,,,,,,,,,,, -16900,0.024827436,0.0378402,,,,,,,,,,,,,,,,, -17000,0.019856669,0.035669617,,,,,,,,,,,,,,,,, -17100,0.019670252,0.03302417,,,,,,,,,,,,,,,,, -17200,0.024907833,0.03525471,,,,,,,,,,,,,,,,, -17256,,,0.9917227029800416,0.0272119287401437,0.5129358647554417,0.9868016242980956,0.0451838597655296,0.2661670100367674,43793.0,0.9859118461608888,0.0483159162104129,0.25409906685934,43793.0,5534.532721757889,8402.984964370728,5534.532721757889,2867.3289663791656,0.6588001251220703,0.0 -17300,0.023018494,0.035994783,,,,,,,,,,,,,,,,, -17400,0.028858304,0.040685322,,,,,,,,,,,,,,,,, -17500,0.034746926,0.03792398,,,,,,,,,,,,,,,,, -17600,0.02329925,0.035487197,,,,,,,,,,,,,,,,, -17700,0.021745184,0.034990173,,,,,,,,,,,,,,,,, -17800,0.02011822,0.0348544,,,,,,,,,,,,,,,,, -17900,0.023493728,0.035359293,,,,,,,,,,,,,,,,, -18000,0.03360892,0.038955636,,,,,,,,,,,,,,,,, -18010,,,0.9915904998779296,0.0272315237671136,0.520161026037464,0.9867346286773682,0.0460444428026676,0.261773279133514,43793.0,0.9857736825942992,0.0492622181773185,0.2460779115552562,43793.0,5774.725650072098,8758.792022228241,5774.725650072098,2982.892689228058,0.689293384552002,0.0 -18100,0.024592457,0.034449007,,,,,,,,,,,,,,,,, -18200,0.024097474,0.03312947,,,,,,,,,,,,,,,,, -18300,0.02271278,0.0363624,,,,,,,,,,,,,,,,, -18400,0.032393575,0.033884734,,,,,,,,,,,,,,,,, -18500,0.02964362,0.035853196,,,,,,,,,,,,,,,,, -18600,0.024907626,0.035534367,,,,,,,,,,,,,,,,, -18700,0.026614144,0.03753515,,,,,,,,,,,,,,,,, -18764,,,0.9917588233947754,0.0268174763768911,0.5093664431630106,0.9866693019866944,0.0456399135291576,0.2636575558792305,43793.0,0.9858528971672058,0.0488565750420093,0.2505815096182588,43793.0,6014.775934457779,9117.453943014145,6014.775934457779,3101.454934358597,0.7188436985015869,0.0 -18800,0.023547096,0.035588063,,,,,,,,,,,,,,,,, -18900,0.02531788,0.035255887,,,,,,,,,,,,,,,,, -19000,0.023038257,0.03418201,,,,,,,,,,,,,,,,, -19100,0.028275557,0.03367347,,,,,,,,,,,,,,,,, -19200,0.026395312,0.033526223,,,,,,,,,,,,,,,,, -19300,0.026207613,0.035658218,,,,,,,,,,,,,,,,, -19400,0.028072596,0.035253603,,,,,,,,,,,,,,,,, -19500,0.03223864,0.036216754,,,,,,,,,,,,,,,,, -19513,,,0.9917542934417723,0.0272373333573341,0.492260297195812,0.9866904020309448,0.0456603653728961,0.2597286529450691,43793.0,0.9858368635177612,0.0486247912049293,0.2516486277602932,43793.0,6254.735512256622,9470.525463342668,6254.735512256622,3214.5177624225616,0.7481474876403809,0.0 -19600,0.023477858,0.032553297,,,,,,,,,,,,,,,,, -19700,0.023564765,0.033286113,,,,,,,,,,,,,,,,, -19800,0.023611756,0.03497944,,,,,,,,,,,,,,,,, -19900,0.02661714,0.03429543,,,,,,,,,,,,,,,,, -20000,0.02374444,0.034881208,,,,,,,,,,,,,,,,, -20100,0.025329096,0.032502357,,,,,,,,,,,,,,,,, -20200,0.028291643,0.03445099,,,,,,,,,,,,,,,,, -20277,,,0.9916300177574158,0.0272173155099153,0.5153293861606729,0.9866891503334044,0.0464046634733676,0.2622391754031578,43793.0,0.9858145713806152,0.04971195012331,0.2459385786965277,43793.0,6494.936870098114,9825.827875614166,6494.936870098114,3329.5672364234924,0.7790920734405518,0.0 -20300,0.028227527,0.034231976,,,,,,,,,,,,,,,,, -20400,0.029156437,0.037379086,,,,,,,,,,,,,,,,, -20500,0.03133032,0.033674873,,,,,,,,,,,,,,,,, -20600,0.025532763,0.033124395,,,,,,,,,,,,,,,,, -20700,0.025075192,0.033952713,,,,,,,,,,,,,,,,, -20800,0.025497451,0.03290558,,,,,,,,,,,,,,,,, -20900,0.022336895,0.03104013,,,,,,,,,,,,,,,,, -21000,0.029052604,0.03494999,,,,,,,,,,,,,,,,, -21037,,,0.991977035999298,0.0264404602348804,0.5115002049030158,0.9867467880249025,0.0457068383693695,0.2609522036510439,43793.0,0.9858802556991576,0.0489316806197166,0.2508172272557681,43793.0,6735.167747020721,10180.737467765808,6735.167747020721,3444.192747116089,0.8117268085479736,0.0 -21100,0.03224792,0.03541613,,,,,,,,,,,,,,,,, -21200,0.024970915,0.033567898,,,,,,,,,,,,,,,,, -21300,0.029874563,0.03236063,,,,,,,,,,,,,,,,, -21400,0.040867813,0.035220478,,,,,,,,,,,,,,,,, -21500,0.026158266,0.032118928,,,,,,,,,,,,,,,,, -21600,0.031353343,0.034489367,,,,,,,,,,,,,,,,, -21700,0.029056186,0.032652073,,,,,,,,,,,,,,,,, -21796,,,0.9917345643043518,0.0266399029642343,0.5187101817819291,0.986717164516449,0.0467933230102062,0.2639815040254276,43793.0,0.9857686161994934,0.0502894409000873,0.2469956287644825,43793.0,6975.293369054794,10535.719145298004,6975.293369054794,3558.997751712799,0.842423677444458,0.0 -21800,0.0310071,0.03260361,,,,,,,,,,,,,,,,, -21900,0.031793647,0.032357328,,,,,,,,,,,,,,,,, -22000,0.025328778,0.031253282,,,,,,,,,,,,,,,,, -22100,0.027240282,0.035303045,,,,,,,,,,,,,,,,, -22200,0.03651537,0.035973657,,,,,,,,,,,,,,,,, -22300,0.032408074,0.03467021,,,,,,,,,,,,,,,,, -22400,0.0393183,0.033727728,,,,,,,,,,,,,,,,, -22500,0.035004847,0.03541711,,,,,,,,,,,,,,,,, -22547,,,0.992116153240204,0.0254687368869781,0.5342317305319013,0.986757755279541,0.0462933629751205,0.2647903676022655,43793.0,0.9858317971229552,0.0495360493659973,0.2483276572328031,43793.0,7215.500328779221,10890.866336345673,7215.500328779221,3673.886606693268,0.8734378814697266,0.0 -22600,0.03514867,0.033257317,,,,,,,,,,,,,,,,, -22700,0.028273923,0.03195807,,,,,,,,,,,,,,,,, -22800,0.031582057,0.03492022,,,,,,,,,,,,,,,,, -22900,0.027233783,0.030164387,,,,,,,,,,,,,,,,, -23000,0.03319302,0.03494593,,,,,,,,,,,,,,,,, -23100,0.02847766,0.03073287,,,,,,,,,,,,,,,,, -23200,0.036321275,0.034850262,,,,,,,,,,,,,,,,, -23294,,,0.9923154711723328,0.0248916465789079,0.5460445997356278,0.9866899847984314,0.0465769246220588,0.254345836560073,43793.0,0.9859139323234558,0.0498078241944313,0.2459262801428786,43793.0,7455.528498411179,11247.001574993134,7455.528498411179,3789.942592382431,0.9038915634155272,0.0 -23300,0.030387467,0.03381214,,,,,,,,,,,,,,,,, -23400,0.029438283,0.03065878,,,,,,,,,,,,,,,,, -23500,0.03433064,0.0324492,,,,,,,,,,,,,,,,, -23600,0.031760767,0.031556353,,,,,,,,,,,,,,,,, -23700,0.032180917,0.032385543,,,,,,,,,,,,,,,,, -23800,0.032083742,0.033740863,,,,,,,,,,,,,,,,, -23900,0.02971203,0.032463133,,,,,,,,,,,,,,,,, -24000,0.04256886,0.032520887,,,,,,,,,,,,,,,,, -24044,,,0.9923921823501588,0.0241668839007616,0.5712409322426214,0.9866307377815248,0.0474363043904304,0.2526439866836966,43793.0,0.9857686161994934,0.0509364232420921,0.2463427566125356,43793.0,7695.701915502548,11602.075031757357,7695.701915502548,3904.790680646896,0.9359946250915528,0.0 -24100,0.04564901,0.037938282,,,,,,,,,,,,,,,,, -24200,0.031058693,0.03251565,,,,,,,,,,,,,,,,, -24300,0.03596606,0.030914081,,,,,,,,,,,,,,,,, -24400,0.03727678,0.034787886,,,,,,,,,,,,,,,,, -24500,0.034055054,0.034419235,,,,,,,,,,,,,,,,, -24600,0.03397289,0.032647062,,,,,,,,,,,,,,,,, -24700,0.03483228,0.03186635,,,,,,,,,,,,,,,,, -24798,,,0.9930984377861024,0.0224706102162599,0.5961115454126245,0.9866713285446168,0.0468453131616115,0.2598697709526262,43793.0,0.9857513904571532,0.0502862893044948,0.2472731325056807,43793.0,7935.791239023209,11959.413158655168,7935.791239023209,4021.984845399857,0.9708609580993652,0.0 -24800,0.042779077,0.034252685,,,,,,,,,,,,,,,,, -24900,0.03873827,0.032223538,,,,,,,,,,,,,,,,, -25000,0.03546264,0.031802468,,,,,,,,,,,,,,,,, -25100,0.037815705,0.0333107,,,,,,,,,,,,,,,,, -25200,0.040396158,0.03047881,,,,,,,,,,,,,,,,, -25300,0.03339384,0.032382406,,,,,,,,,,,,,,,,, -25400,0.038153276,0.032922506,,,,,,,,,,,,,,,,, -25500,0.03339478,0.031983647,,,,,,,,,,,,,,,,, -25538,,,0.9933809638023376,0.0218420680612325,0.6266066179462839,0.986694872379303,0.0469269268214702,0.2572042716750418,43793.0,0.9857223033905028,0.0505571216344833,0.2435503236909003,43793.0,8176.01961183548,12325.004558086395,8176.01961183548,4147.293561458588,1.0030477046966553,0.0 -25600,0.03391549,0.03152055,,,,,,,,,,,,,,,,, -25700,0.040286046,0.031243714,,,,,,,,,,,,,,,,, -25800,0.034445684,0.03248886,,,,,,,,,,,,,,,,, -25900,0.039308257,0.033723798,,,,,,,,,,,,,,,,, -26000,0.038874023,0.033135667,,,,,,,,,,,,,,,,, -26100,0.0383726,0.033102505,,,,,,,,,,,,,,,,, -26200,0.03694368,0.033638168,,,,,,,,,,,,,,,,, -26291,,,0.9928025603294371,0.0234451312571764,0.5761827578002454,0.9866213798522948,0.0469994507730007,0.2511379755020969,43793.0,0.9857041835784912,0.0503971055150032,0.2480874887484809,43793.0,8416.271092653275,12680.953973293304,8416.271092653275,4262.936295509338,1.0376760959625244,0.0 -26300,0.041413162,0.032059226,,,,,,,,,,,,,,,,, -26400,0.0452701,0.033588856,,,,,,,,,,,,,,,,, -26500,0.044903055,0.03147089,,,,,,,,,,,,,,,,, -26600,0.04029111,0.029587202,,,,,,,,,,,,,,,,, -26700,0.035532653,0.031418987,,,,,,,,,,,,,,,,, -26800,0.034170315,0.029861504,,,,,,,,,,,,,,,,, -26900,0.040065072,0.03271444,,,,,,,,,,,,,,,,, -27000,0.04716744,0.03373628,,,,,,,,,,,,,,,,, -27040,,,0.992655873298645,0.0236091390252113,0.5897485965947075,0.9866039156913756,0.048078216612339,0.2507343853926724,43793.0,0.9857812523841858,0.0514933206140995,0.2459544175952203,43793.0,8656.47072982788,13040.865669727324,8656.47072982788,4382.595624923706,1.070669651031494,0.0 -27100,0.04640654,0.030934047,,,,,,,,,,,,,,,,, -27200,0.038051818,0.032006882,,,,,,,,,,,,,,,,, -27300,0.03765858,0.030444968,,,,,,,,,,,,,,,,, -27400,0.035953864,0.03059263,,,,,,,,,,,,,,,,, -27500,0.048884682,0.030073907,,,,,,,,,,,,,,,,, -27600,0.044566084,0.031453863,,,,,,,,,,,,,,,,, -27700,0.05461363,0.03271628,,,,,,,,,,,,,,,,, -27791,,,0.9928238987922668,0.0232392903417348,0.5776398059058571,0.9865588545799256,0.0476624406874179,0.2508783567306901,43793.0,0.9857842326164246,0.0510902516543865,0.2458477064293119,43793.0,8896.424266338348,13395.142271757126,8896.424266338348,4496.867026567459,1.102473258972168,0.0 -27800,0.04569939,0.030855536,,,,,,,,,,,,,,,,, -27900,0.042065036,0.032275766,,,,,,,,,,,,,,,,, -28000,0.044688266,0.031758215,,,,,,,,,,,,,,,,, -28100,0.04424023,0.032076105,,,,,,,,,,,,,,,,, -28200,0.041794796,0.031953074,,,,,,,,,,,,,,,,, -28300,0.042957887,0.030664021,,,,,,,,,,,,,,,,, -28400,0.046849225,0.030850623,,,,,,,,,,,,,,,,, -28500,0.040625997,0.031036362,,,,,,,,,,,,,,,,, -28536,,,0.9926972389221193,0.023307790979743,0.5889419244248999,0.9865353107452391,0.0485096462070941,0.2496172886362365,43793.0,0.9857046008110046,0.0521145649254322,0.2429519902380678,43793.0,9136.532777786257,13751.155616044998,9136.532777786257,4612.7197296619415,1.1351377964019775,0.0 -28600,0.040138546,0.029754976,,,,,,,,,,,,,,,,, -28700,0.043064207,0.029948294,,,,,,,,,,,,,,,,, -28800,0.047540028,0.031620495,,,,,,,,,,,,,,,,, -28900,0.049103998,0.031533483,,,,,,,,,,,,,,,,, -29000,0.045761637,0.03205459,,,,,,,,,,,,,,,,, -29100,0.04400602,0.030917456,,,,,,,,,,,,,,,,, -29200,0.043589428,0.029006422,,,,,,,,,,,,,,,,, -29288,,,0.9929196834564208,0.0228444803506135,0.5824909793466204,0.9865726828575134,0.0482287295162677,0.2497479171584527,43793.0,0.9856772422790528,0.0517030544579029,0.2464958534102047,43793.0,9376.589265823364,14111.342035531998,9376.589265823364,4732.798652887344,1.166337966918945,0.0 -29300,0.059011545,0.03165785,,,,,,,,,,,,,,,,, -29400,0.042415783,0.03059765,,,,,,,,,,,,,,,,, -29500,0.0445487,0.032460533,,,,,,,,,,,,,,,,, -29600,0.046272993,0.029931411,,,,,,,,,,,,,,,,, -29700,0.049707625,0.030362034,,,,,,,,,,,,,,,,, -29800,0.043081608,0.029949624,,,,,,,,,,,,,,,,, -29900,0.04810178,0.030293602,,,,,,,,,,,,,,,,, -30000,0.049859244,0.032270763,,,,,,,,,,,,,,,,, -30037,,,0.9931166768074036,0.0221618507057428,0.6110780190593907,0.9865093231201172,0.0487466380000114,0.2475532409059372,43793.0,0.9855828881263732,0.0522092543542385,0.239892505514401,43793.0,9616.7351603508,14462.951917409897,9616.7351603508,4844.209650993347,1.199488401412964,0.0 -30100,0.048847433,0.031220604,,,,,,,,,,,,,,,,, -30200,0.05730227,0.0317253,,,,,,,,,,,,,,,,, -30300,0.05093946,0.032478593,,,,,,,,,,,,,,,,, -30400,0.04506232,0.029243806,,,,,,,,,,,,,,,,, -30500,0.05665817,0.029392604,,,,,,,,,,,,,,,,, -30600,0.07089264,0.03147456,,,,,,,,,,,,,,,,, -30700,0.0475198,0.0316367,,,,,,,,,,,,,,,,, -30790,,,0.9931876063346864,0.0214119311422109,0.6239691444803845,0.9865888953208924,0.0495238825678825,0.2529329067237749,43793.0,0.9857505559921264,0.0530433468520641,0.2417635372968837,43793.0,9856.721096038818,14818.875288009644,9856.721096038818,4960.09548664093,1.2316246032714844,0.0 -30800,0.049247783,0.029158246,,,,,,,,,,,,,,,,, -30900,0.043436915,0.028759925,,,,,,,,,,,,,,,,, -31000,0.04066469,0.02859293,,,,,,,,,,,,,,,,, -31100,0.0537043,0.027780192,,,,,,,,,,,,,,,,, -31200,0.048415776,0.030205052,,,,,,,,,,,,,,,,, -31300,0.04918663,0.0311796,,,,,,,,,,,,,,,,, -31400,0.06374696,0.029481327,,,,,,,,,,,,,,,,, -31500,0.055614837,0.030779831,,,,,,,,,,,,,,,,, -31529,,,0.9937662482261658,0.020005514845252,0.6605400162041181,0.9865211248397828,0.0490268692374229,0.2478563465140137,43793.0,0.9855732321739196,0.0524737164378166,0.2433974481567167,43793.0,10096.81805896759,15170.70418548584,10096.81805896759,5071.77366065979,1.265533208847046,0.0 -31600,0.04743301,0.02970096,,,,,,,,,,,,,,,,, -31700,0.062180486,0.030425463,,,,,,,,,,,,,,,,, -31800,0.059602607,0.029966287,,,,,,,,,,,,,,,,, -31900,0.058149517,0.031506475,,,,,,,,,,,,,,,,, -32000,0.055127505,0.029986667,,,,,,,,,,,,,,,,, -32100,0.047357287,0.029776607,,,,,,,,,,,,,,,,, -32200,0.049175438,0.030826489,,,,,,,,,,,,,,,,, -32270,,,0.9942206740379332,0.0189791806042194,0.6804974564516768,0.9864314198493958,0.0497116670012474,0.2431531213486542,43793.0,0.9855959415435792,0.0532693080604076,0.2317083155339858,43793.0,10336.932120800018,15529.72211575508,10336.932120800018,5190.624104499817,1.299309253692627,0.0 -32300,0.052014317,0.030557042,,,,,,,,,,,,,,,,, -32400,0.049383488,0.03007396,,,,,,,,,,,,,,,,, -32500,0.05381314,0.027640382,,,,,,,,,,,,,,,,, -32600,0.05864082,0.030989986,,,,,,,,,,,,,,,,, -32700,0.06688648,0.028763983,,,,,,,,,,,,,,,,, -32800,0.056403887,0.029261341,,,,,,,,,,,,,,,,, -32900,0.06003562,0.031807195,,,,,,,,,,,,,,,,, -33000,0.049984112,0.027204918,,,,,,,,,,,,,,,,, -33024,,,0.9941815733909608,0.0190620310604572,0.6768720292775423,0.9864902496337892,0.0494199618697166,0.2456189152667728,43793.0,0.985584557056427,0.053100511431694,0.2379802507963694,43793.0,10577.190371990204,15888.460463523865,10577.190371990204,5309.052172183991,1.3314814567565918,0.0 -33100,0.056763645,0.02993699,,,,,,,,,,,,,,,,, -33200,0.06046072,0.029086407,,,,,,,,,,,,,,,,, -33300,0.05311883,0.029748788,,,,,,,,,,,,,,,,, -33400,0.052155316,0.028904067,,,,,,,,,,,,,,,,, -33500,0.061567444,0.027912209,,,,,,,,,,,,,,,,, -33600,0.05902857,0.030835692,,,,,,,,,,,,,,,,, -33700,0.052534994,0.028906468,,,,,,,,,,,,,,,,, -33781,,,0.9938918352127076,0.0196193568408489,0.6653318432091149,0.9864837527275084,0.0499318204820156,0.2432102650466545,43793.0,0.9855904579162598,0.0534843131899833,0.2413656636156215,43793.0,10817.216018676758,16245.507838010788,10817.216018676758,5426.022752046585,1.362900733947754,0.0 -33800,0.06353984,0.030102715,,,,,,,,,,,,,,,,, -33900,0.057546712,0.028382562,,,,,,,,,,,,,,,,, -34000,0.06173232,0.030178541,,,,,,,,,,,,,,,,, -34100,0.055138756,0.03044477,,,,,,,,,,,,,,,,, -34200,0.05327658,0.029483814,,,,,,,,,,,,,,,,, -34300,0.063078366,0.030225098,,,,,,,,,,,,,,,,, -34400,0.06479648,0.03093874,,,,,,,,,,,,,,,,, -34500,0.053966183,0.02900942,,,,,,,,,,,,,,,,, -34531,,,0.9931476712226868,0.0218456238508224,0.6065809626161704,0.9863153100013732,0.0502861440181732,0.2356881656238881,43793.0,0.9854485392570496,0.053956814110279,0.2340651284593916,43793.0,11057.21807861328,16601.556545495987,11057.21807861328,5542.016725063324,1.3957459926605225,0.0 -34600,0.07289142,0.031307876,,,,,,,,,,,,,,,,, -34700,0.06233201,0.027981248,,,,,,,,,,,,,,,,, -34800,0.06019645,0.030524585,,,,,,,,,,,,,,,,, -34900,0.0646574,0.027754582,,,,,,,,,,,,,,,,, -35000,0.06338943,0.028580189,,,,,,,,,,,,,,,,, -35100,0.053641595,0.027005032,,,,,,,,,,,,,,,,, -35200,0.058161113,0.029660363,,,,,,,,,,,,,,,,, -35289,,,0.9932732582092284,0.0213515106588602,0.6102445285731928,0.9863283038139344,0.0507453233003616,0.2422555799690798,43793.0,0.9854581952095032,0.0542607828974723,0.2370528405868589,43793.0,11297.394924879074,16956.7015914917,11297.394924879074,5656.931999921799,1.4285671710968018,0.0 -35300,0.05983239,0.028882898,,,,,,,,,,,,,,,,, -35400,0.059789617,0.028305063,,,,,,,,,,,,,,,,, -35500,0.056112207,0.028075881,,,,,,,,,,,,,,,,, -35600,0.08292904,0.02910004,,,,,,,,,,,,,,,,, -35700,0.065291315,0.029077103,,,,,,,,,,,,,,,,, -35800,0.080063626,0.029402431,,,,,,,,,,,,,,,,, -35900,0.06582139,0.02788606,,,,,,,,,,,,,,,,, -36000,0.06458185,0.028615762,,,,,,,,,,,,,,,,, -36047,,,0.9936595559120178,0.0201679468154907,0.6438834465263861,0.9863104224205016,0.0507057793438434,0.2338064859736999,43793.0,0.9854379892349244,0.0543590039014816,0.2314366596590499,43793.0,11537.618818044662,17312.897706270218,11537.618818044662,5772.85106253624,1.4614310264587402,0.0 -36100,0.059201103,0.02833176,,,,,,,,,,,,,,,,, -36200,0.06361746,0.028575365,,,,,,,,,,,,,,,,, -36300,0.05970494,0.02744892,,,,,,,,,,,,,,,,, -36400,0.069663905,0.029603263,,,,,,,,,,,,,,,,, -36500,0.06103731,0.026869351,,,,,,,,,,,,,,,,, -36600,0.061636884,0.0280668,,,,,,,,,,,,,,,,, -36700,0.061443694,0.028011763,,,,,,,,,,,,,,,,, -36797,,,0.9936636686325072,0.0199106354266405,0.6476884772247021,0.9864094853401184,0.0514586195349693,0.2377665407309361,43793.0,0.985491931438446,0.0551511570811271,0.2346267707140708,43793.0,11777.46779179573,17670.489493846893,11777.46779179573,5890.137528896332,1.8974733352661133,0.0 -36800,0.062639624,0.028030638,,,,,,,,,,,,,,,,, -36900,0.06329808,0.026415646,,,,,,,,,,,,,,,,, -37000,0.064024985,0.027479518,,,,,,,,,,,,,,,,, -37100,0.065858856,0.027531099,,,,,,,,,,,,,,,,, -37200,0.07767673,0.029321983,,,,,,,,,,,,,,,,, -37300,0.066469505,0.028836431,,,,,,,,,,,,,,,,, -37400,0.06589168,0.027308049,,,,,,,,,,,,,,,,, -37500,0.065450385,0.028797563,,,,,,,,,,,,,,,,, -37549,,,0.9935628175735474,0.0201002769172191,0.6463171003589824,0.9863964915275574,0.0519293472170829,0.2382274855362635,43793.0,0.9854986667633056,0.0555910542607307,0.2343527281268959,43793.0,12017.64917397499,18023.41338586808,12017.64917397499,6002.826015472412,1.9313278198242188,0.0 -37600,0.0801134,0.030299304,,,,,,,,,,,,,,,,, -37700,0.07509145,0.02773243,,,,,,,,,,,,,,,,, -37800,0.073868595,0.027441403,,,,,,,,,,,,,,,,, -37900,0.071012326,0.027686449,,,,,,,,,,,,,,,,, -38000,0.06579101,0.026300069,,,,,,,,,,,,,,,,, -38100,0.063549206,0.02726251,,,,,,,,,,,,,,,,, -38200,0.07381503,0.028849194,,,,,,,,,,,,,,,,, -38299,,,0.9944259524345398,0.0178183428943157,0.7018950843743552,0.9862747192382812,0.0519014485180377,0.2361996599409608,43793.0,0.9853967428207396,0.0553966537117958,0.2333304576637424,43793.0,12257.801238059998,18378.24086880684,12257.801238059998,6117.448066949844,1.9647307395935056,0.0 -38300,0.07372468,0.02670208,,,,,,,,,,,,,,,,, -38400,0.06402663,0.026882803,,,,,,,,,,,,,,,,, -38500,0.07309164,0.029227631,,,,,,,,,,,,,,,,, -38600,0.06910076,0.027370667,,,,,,,,,,,,,,,,, -38700,0.07081007,0.026981387,,,,,,,,,,,,,,,,, -38800,0.06493725,0.025687177,,,,,,,,,,,,,,,,, -38900,0.067853734,0.027074637,,,,,,,,,,,,,,,,, -39000,0.087793,0.030263582,,,,,,,,,,,,,,,,, -39053,,,0.9949982762336732,0.0165216345340013,0.7224225479696811,0.9863018989562988,0.0520194247364997,0.2346073985838009,43793.0,0.985359251499176,0.055601317435503,0.2355170950865257,43793.0,12497.982424497604,18735.421145915985,12497.982424497604,6234.392640352249,1.999298334121704,0.0 -39100,0.069599874,0.02946493,,,,,,,,,,,,,,,,, -39200,0.069194436,0.026153062,,,,,,,,,,,,,,,,, -39300,0.07229822,0.026815565,,,,,,,,,,,,,,,,, -39400,0.06622265,0.02718358,,,,,,,,,,,,,,,,, -39500,0.07108161,0.027928354,,,,,,,,,,,,,,,,, -39600,0.07368795,0.027067553,,,,,,,,,,,,,,,,, -39700,0.07226055,0.027698439,,,,,,,,,,,,,,,,, -39792,,,0.9952055215835572,0.0157686341553926,0.7366664288057573,0.9863213896751404,0.0527669936418533,0.2330148621436922,43793.0,0.9853891134262084,0.0565869100391864,0.2302643002277187,43793.0,12738.079112768171,19092.81990480423,12738.079112768171,6351.638965606689,2.03450345993042,0.0 -39800,0.062756255,0.026342098,,,,,,,,,,,,,,,,, -39900,0.07519913,0.027612697,,,,,,,,,,,,,,,,, -40000,0.06438484,0.027127296,,,,,,,,,,,,,,,,, -40100,0.07208116,0.027567072,,,,,,,,,,,,,,,,, -40200,0.07395913,0.026426256,,,,,,,,,,,,,,,,, -40300,0.08484233,0.02881361,,,,,,,,,,,,,,,,, -40400,0.07019797,0.026707055,,,,,,,,,,,,,,,,, -40500,0.076140516,0.026813855,,,,,,,,,,,,,,,,, -40539,,,0.9952002763748168,0.0160813517868518,0.7377185027103813,0.9861488342285156,0.0530131682753562,0.2330946343841542,43793.0,0.9852867722511292,0.0564857460558414,0.2309985263228809,43793.0,12978.10954594612,19451.29268836975,12978.10954594612,6470.0264637470245,2.069436550140381,0.0 -40600,0.07805977,0.028252278,,,,,,,,,,,,,,,,, -40700,0.08282327,0.029031046,,,,,,,,,,,,,,,,, -40800,0.06635835,0.026019525,,,,,,,,,,,,,,,,, -40900,0.0662945,0.025890572,,,,,,,,,,,,,,,,, -41000,0.078289405,0.02950712,,,,,,,,,,,,,,,,, -41100,0.07280345,0.028294569,,,,,,,,,,,,,,,,, -41200,0.082717,0.027373198,,,,,,,,,,,,,,,,, -41274,,,0.9948239326477052,0.0167916342616081,0.7085789565068721,0.9862288236618042,0.0531344339251518,0.2288084561811299,43793.0,0.985367238521576,0.0568977668881416,0.228333392060031,43793.0,13218.214102506638,19811.006506443024,13218.214102506638,6589.574536561966,2.107372522354126,0.0 -41300,0.08054119,0.026357219,,,,,,,,,,,,,,,,, -41400,0.08238219,0.028081613,,,,,,,,,,,,,,,,, -41500,0.10703981,0.027223332,,,,,,,,,,,,,,,,, -41600,0.091375194,0.026762152,,,,,,,,,,,,,,,,, -41700,0.072653964,0.026240231,,,,,,,,,,,,,,,,, -41800,0.077801086,0.027454317,,,,,,,,,,,,,,,,, -41900,0.07753202,0.026402792,,,,,,,,,,,,,,,,, -42000,0.07255739,0.02662382,,,,,,,,,,,,,,,,, -42022,,,0.9945333003997804,0.0173073932528495,0.7135553817663034,0.9861322045326232,0.0535982139408588,0.2313976091081589,43793.0,0.9854017496109008,0.05723188072443,0.2287596771324323,43793.0,13458.43649482727,20166.63150715828,13458.43649482727,6704.920714616776,2.143451690673828,0.0 -42100,0.07662168,0.027403492,,,,,,,,,,,,,,,,, -42200,0.08071981,0.027290206,,,,,,,,,,,,,,,,, -42300,0.0736628,0.026788415,,,,,,,,,,,,,,,,, -42400,0.09336833,0.026543396,,,,,,,,,,,,,,,,, -42500,0.080635644,0.027639268,,,,,,,,,,,,,,,,, -42600,0.09109598,0.027130917,,,,,,,,,,,,,,,,, -42700,0.08514557,0.027753644,,,,,,,,,,,,,,,,, -42774,,,0.9939397573471068,0.0189362727105617,0.6697946583231249,0.9862284064292908,0.0541604533791542,0.2336562058187641,43793.0,0.9853150248527528,0.0580363795161247,0.2279296138834921,43793.0,13698.556680202484,20523.85961985588,13698.556680202484,6821.973500728607,2.1787869930267334,0.0 -42800,0.06371082,0.024261527,,,,,,,,,,,,,,,,, -42900,0.07728255,0.026599977,,,,,,,,,,,,,,,,, -43000,0.08785899,0.025602372,,,,,,,,,,,,,,,,, -43100,0.09545844,0.027683137,,,,,,,,,,,,,,,,, -43200,0.07617843,0.02684898,,,,,,,,,,,,,,,,, -43300,0.077601224,0.025415817,,,,,,,,,,,,,,,,, -43400,0.08038797,0.027461424,,,,,,,,,,,,,,,,, -43500,0.06712968,0.02521833,,,,,,,,,,,,,,,,, -43522,,,0.994067132472992,0.0184782650321722,0.6751643354404507,0.9861873984336852,0.0549007356166839,0.2255450868821931,43793.0,0.9853845238685608,0.0585099793970584,0.2311539778899343,43793.0,13938.790473937988,20879.66306090355,13938.790473937988,6937.488364696503,2.2132139205932617,0.0 -43600,0.093930975,0.028294299,,,,,,,,,,,,,,,,, -43700,0.07815655,0.02659882,,,,,,,,,,,,,,,,, -43800,0.082266,0.026244313,,,,,,,,,,,,,,,,, -43900,0.08084699,0.026786793,,,,,,,,,,,,,,,,, -44000,0.09251419,0.025746379,,,,,,,,,,,,,,,,, -44100,0.08210834,0.025968377,,,,,,,,,,,,,,,,, -44200,0.09081885,0.0269774,,,,,,,,,,,,,,,,, -44274,,,0.9948219656944276,0.0166021138429641,0.7204996206656971,0.9860494136810304,0.0540257133543491,0.2289064375174395,43793.0,0.9851073622703552,0.0578879825770854,0.2267297435673207,43793.0,14178.823575496674,21232.88632273674,14178.823575496674,7050.623664855957,2.248157262802124,0.0 -44300,0.11239807,0.02813178,,,,,,,,,,,,,,,,, -44400,0.072862536,0.024612635,,,,,,,,,,,,,,,,, -44500,0.084415704,0.02764658,,,,,,,,,,,,,,,,, -44600,0.085025184,0.025964104,,,,,,,,,,,,,,,,, -44700,0.07335897,0.02593211,,,,,,,,,,,,,,,,, -44800,0.082769945,0.027359782,,,,,,,,,,,,,,,,, -44900,0.079568684,0.025665328,,,,,,,,,,,,,,,,, -45000,0.0930647,0.024780523,,,,,,,,,,,,,,,,, -45021,,,0.99458110332489,0.0170487016439437,0.7150672404338416,0.9861192107200624,0.054877046495676,0.221464246573604,43793.0,0.9852181673049928,0.058937769383192,0.2193470926250697,43793.0,14418.789083957672,21584.268087387085,14418.789083957672,7161.983935594559,2.284372329711914,0.0 -45100,0.08929021,0.024572022,,,,,,,,,,,,,,,,, -45200,0.08611225,0.026980627,,,,,,,,,,,,,,,,, -45300,0.08390668,0.025048073,,,,,,,,,,,,,,,,, -45400,0.08363319,0.026117599,,,,,,,,,,,,,,,,, -45500,0.09041779,0.025156645,,,,,,,,,,,,,,,,, -45600,0.09845267,0.028684663,,,,,,,,,,,,,,,,, -45700,0.08606813,0.025091773,,,,,,,,,,,,,,,,, -45776,,,0.994462788105011,0.0171351470053195,0.7201120524918182,0.9861781001091005,0.0556235834956169,0.2299968563395169,43793.0,0.985298991203308,0.0596415922045707,0.2246596163663179,43793.0,14658.8561296463,21935.34939599037,14658.8561296463,7272.9426436424255,2.3197762966156006,0.0 -45800,0.09292931,0.025462072,,,,,,,,,,,,,,,,, -45900,0.083327815,0.025300605,,,,,,,,,,,,,,,,, -46000,0.09236914,0.026752032,,,,,,,,,,,,,,,,, -46100,0.08786043,0.025407,,,,,,,,,,,,,,,,, -46200,0.08920224,0.025845857,,,,,,,,,,,,,,,,, -46300,0.08931253,0.025188468,,,,,,,,,,,,,,,,, -46400,0.08084515,0.026268361,,,,,,,,,,,,,,,,, -46500,0.08449222,0.02538025,,,,,,,,,,,,,,,,, -46528,,,0.9965863823890686,0.0121391955763101,0.8180281852483683,0.986148476600647,0.0563682243227958,0.231505998139778,43793.0,0.9852783679962158,0.0603801384568214,0.2273424703023917,43793.0,14898.866291761398,22288.94997239113,14898.866291761398,7386.477149009705,2.3562400341033936,0.0 -46600,0.089466736,0.025351904,,,,,,,,,,,,,,,,, -46700,0.09503011,0.024723334,,,,,,,,,,,,,,,,, -46800,0.09027067,0.024443544,,,,,,,,,,,,,,,,, -46900,0.08930247,0.025815161,,,,,,,,,,,,,,,,, -47000,0.08666977,0.026421668,,,,,,,,,,,,,,,,, -47100,0.08632072,0.025513956,,,,,,,,,,,,,,,,, -47200,0.09482259,0.025302162,,,,,,,,,,,,,,,,, -47279,,,0.996531307697296,0.0124534703791141,0.8025178704840774,0.9861151576042176,0.0560289360582828,0.2313723651048697,43793.0,0.9851911664009094,0.0600710287690162,0.2206319943723637,43793.0,15139.000858545303,22639.64312577248,15139.000858545303,7496.979228496551,2.3927478790283203,0.0 -47300,0.08492183,0.024178257,,,,,,,,,,,,,,,,, -47400,0.08226642,0.022555795,,,,,,,,,,,,,,,,, -47500,0.09547399,0.02431649,,,,,,,,,,,,,,,,, -47600,0.09194943,0.024730962,,,,,,,,,,,,,,,,, -47700,0.09543539,0.027443757,,,,,,,,,,,,,,,,, -47800,0.09741076,0.025362587,,,,,,,,,,,,,,,,, -47900,0.09102237,0.024541767,,,,,,,,,,,,,,,,, -48000,0.095166,0.024177924,,,,,,,,,,,,,,,,, -48028,,,0.9962067008018494,0.0131070259958505,0.7931497384251549,0.986046552658081,0.0562718920409679,0.2260776234022609,43793.0,0.985040843486786,0.0601809658110141,0.2190393447107278,43793.0,15379.0965590477,22995.49823999405,15379.0965590477,7612.681452512741,2.4296257495880127,0.0 -48100,0.09456917,0.0247627,,,,,,,,,,,,,,,,, -48200,0.07791885,0.024478404,,,,,,,,,,,,,,,,, -48300,0.10072435,0.02698427,,,,,,,,,,,,,,,,, -48400,0.09775837,0.024450447,,,,,,,,,,,,,,,,, -48500,0.08571001,0.02434052,,,,,,,,,,,,,,,,, -48600,0.09407117,0.024418993,,,,,,,,,,,,,,,,, -48700,0.1077308,0.02519655,,,,,,,,,,,,,,,,, -48773,,,0.9955082535743712,0.0142448795959353,0.774797287980069,0.986183762550354,0.0573287643492221,0.228134299896944,43793.0,0.9852885007858276,0.061463788151741,0.2218434503684044,43793.0,15619.057337284088,23349.98605751992,15619.057337284088,7727.149502515793,2.468273878097534,0.0 -48800,0.08097139,0.024618013,,,,,,,,,,,,,,,,, -48900,0.08692122,0.025237255,,,,,,,,,,,,,,,,, -49000,0.088292725,0.024141124,,,,,,,,,,,,,,,,, -49100,0.08107887,0.023599114,,,,,,,,,,,,,,,,, -49200,0.09313769,0.023359446,,,,,,,,,,,,,,,,, -49300,0.08575093,0.022909706,,,,,,,,,,,,,,,,, -49400,0.097603284,0.024437578,,,,,,,,,,,,,,,,, -49500,0.0881734,0.026209908,,,,,,,,,,,,,,,,, -49526,,,0.995168685913086,0.0149586545303463,0.7747991965325125,0.9860441088676452,0.0570524372160434,0.2230743490184299,43793.0,0.9851848483085632,0.0611521825194358,0.2181409027100457,43793.0,15859.288888454435,23699.553965330124,15859.288888454435,7836.429103851318,2.5045437812805176,0.0 -49600,0.085850015,0.024656918,,,,,,,,,,,,,,,,, -49700,0.09334434,0.024781905,,,,,,,,,,,,,,,,, -49800,0.08380328,0.022419944,,,,,,,,,,,,,,,,, -49900,0.08859641,0.02480146,,,,,,,,,,,,,,,,, -50000,0.08588086,0.024169847,,,,,,,,,,,,,,,,, -50100,0.105217904,0.02627712,,,,,,,,,,,,,,,,, -50200,0.09022266,0.024137694,,,,,,,,,,,,,,,,, -50272,,,0.99516099691391,0.0150730079039931,0.7670223106612885,0.986040472984314,0.0577150806784629,0.2250570354018474,43793.0,0.9852353930473328,0.0616249740123748,0.221362975017089,43793.0,16099.445001363754,24051.175461292267,16099.445001363754,7947.837135314941,2.5418262481689453,0.0 -50300,0.10815634,0.023721965,,,,,,,,,,,,,,,,, -50400,0.095012724,0.025078872,,,,,,,,,,,,,,,,, -50500,0.09085965,0.024229188,,,,,,,,,,,,,,,,, -50600,0.0905128,0.024024125,,,,,,,,,,,,,,,,, -50700,0.08695099,0.02266098,,,,,,,,,,,,,,,,, -50800,0.09942648,0.02621108,,,,,,,,,,,,,,,,, -50900,0.09943845,0.024474531,,,,,,,,,,,,,,,,, -51000,0.09347429,0.024124537,,,,,,,,,,,,,,,,, -51016,,,0.994962215423584,0.0152889490127563,0.7472203813139899,0.9861119389533995,0.0582580640912056,0.220357172150594,43793.0,0.9851894974708556,0.0623769909143447,0.2168952545827823,43793.0,16339.63740158081,24398.847739219666,16339.63740158081,8055.256649494171,2.5816922187805176,0.0 -51100,0.09912682,0.024401532,,,,,,,,,,,,,,,,, -51200,0.09666826,0.02362765,,,,,,,,,,,,,,,,, -51300,0.07297556,0.022012863,,,,,,,,,,,,,,,,, -51400,0.1009173,0.024418456,,,,,,,,,,,,,,,,, -51500,0.08895071,0.02402187,,,,,,,,,,,,,,,,, -51600,0.09178734,0.023779208,,,,,,,,,,,,,,,,, -51700,0.09534986,0.023251664,,,,,,,,,,,,,,,,, -51765,,,0.9956825971603394,0.0138268237933516,0.788454380210666,0.9860782027244568,0.0579963140189647,0.2223265183184763,43793.0,0.9851473569869996,0.0621534325182437,0.2212461640478424,43793.0,16579.84620285034,24750.48007750511,16579.84620285034,8166.622575998306,2.619372129440308,0.0 -51800,0.092657514,0.024119481,,,,,,,,,,,,,,,,, -51900,0.08727418,0.02325753,,,,,,,,,,,,,,,,, -52000,0.09898718,0.023630427,,,,,,,,,,,,,,,,, -52100,0.074764304,0.022806728,,,,,,,,,,,,,,,,, -52200,0.104742736,0.024192201,,,,,,,,,,,,,,,,, -52300,0.08749336,0.025047675,,,,,,,,,,,,,,,,, -52400,0.08633823,0.024592651,,,,,,,,,,,,,,,,, -52500,0.09484529,0.022406466,,,,,,,,,,,,,,,,, -52515,,,0.9962998032569884,0.0124083021655678,0.8206180664414136,0.9861164093017578,0.0587662868201732,0.2223903278764082,43793.0,0.9851554036140442,0.0629678145051002,0.2186763948138787,43793.0,16819.83504796028,25103.957667827606,16819.83504796028,8280.05424952507,2.6561763286590576,0.0 -52600,0.13063665,0.024249472,,,,,,,,,,,,,,,,, -52700,0.09017052,0.02233965,,,,,,,,,,,,,,,,, -52800,0.111983545,0.024559865,,,,,,,,,,,,,,,,, -52900,0.09408688,0.023345584,,,,,,,,,,,,,,,,, -53000,0.104131125,0.025359841,,,,,,,,,,,,,,,,, -53100,0.095687374,0.023099452,,,,,,,,,,,,,,,,, -53200,0.08734953,0.022501275,,,,,,,,,,,,,,,,, -53261,,,0.997327983379364,0.0103095285594463,0.8675313433747215,0.9860234260559082,0.0592837184667587,0.2185852231293817,43793.0,0.985207200050354,0.0634720176458358,0.2194677612263874,43793.0,17059.884666204453,25456.877187013622,17059.884666204453,8392.8616604805,2.697401523590088,0.0 -53300,0.10369163,0.022113327,,,,,,,,,,,,,,,,, -53400,0.08633039,0.024228375,,,,,,,,,,,,,,,,, -53500,0.092549205,0.023478571,,,,,,,,,,,,,,,,, -53600,0.089393824,0.023295721,,,,,,,,,,,,,,,,, -53700,0.10024439,0.023103345,,,,,,,,,,,,,,,,, -53800,0.08543531,0.022358129,,,,,,,,,,,,,,,,, -53900,0.09748576,0.023640422,,,,,,,,,,,,,,,,, -54000,0.09101576,0.024011604,,,,,,,,,,,,,,,,, -54013,,,0.996036410331726,0.0128533076494932,0.8124292085607783,0.9860871434211732,0.0590945109724998,0.2162775368447429,43793.0,0.9851477742195128,0.063467264175415,0.2142464943903636,43793.0,17299.912273406982,25809.669014453888,17299.912273406982,8505.566602230072,2.735142230987549,0.0 -54100,0.08214886,0.022505604,,,,,,,,,,,,,,,,, -54200,0.08796183,0.023795547,,,,,,,,,,,,,,,,, -54300,0.090234935,0.023655845,,,,,,,,,,,,,,,,, -54400,0.09164407,0.023659745,,,,,,,,,,,,,,,,, -54500,0.0897741,0.021826472,,,,,,,,,,,,,,,,, -54600,0.09060173,0.021983234,,,,,,,,,,,,,,,,, -54700,0.097178966,0.023309352,,,,,,,,,,,,,,,,, -54762,,,0.9976819157600404,0.0096607785671949,0.8719512232661151,0.9859962463378906,0.059577465057373,0.2174690053207167,43793.0,0.9850918054580688,0.0639261305332183,0.2151637094748986,43793.0,17539.928025960922,26156.15963792801,17539.928025960922,8611.9836332798,2.773247480392456,0.0 -54800,0.087271325,0.022326248,,,,,,,,,,,,,,,,, -54900,0.1074372,0.024296923,,,,,,,,,,,,,,,,, -55000,0.101487316,0.023094052,,,,,,,,,,,,,,,,, -55100,0.09023569,0.023657482,,,,,,,,,,,,,,,,, -55200,0.08984548,0.022604248,,,,,,,,,,,,,,,,, -55300,0.079654954,0.02221404,,,,,,,,,,,,,,,,, -55400,0.09050845,0.024065526,,,,,,,,,,,,,,,,, -55500,0.08227889,0.022153132,,,,,,,,,,,,,,,,, -55525,,,0.9970630407333374,0.010564861819148,0.8574635624911153,0.9860855340957642,0.0606580302119255,0.2188187087485164,43793.0,0.9851404428482056,0.0652488842606544,0.214197982853505,43793.0,17780.066515922546,26504.03938817978,17780.066515922546,8719.667748212814,2.8100857734680176,0.0 -55600,0.089750536,0.022832686,,,,,,,,,,,,,,,,, -55700,0.10113472,0.023415573,,,,,,,,,,,,,,,,, -55800,0.09118111,0.02172876,,,,,,,,,,,,,,,,, -55900,0.07684975,0.021634972,,,,,,,,,,,,,,,,, -56000,0.08302972,0.021704009,,,,,,,,,,,,,,,,, -56100,0.08019517,0.02272665,,,,,,,,,,,,,,,,, -56200,0.09797044,0.023069894,,,,,,,,,,,,,,,,, -56277,,,0.9965944290161132,0.0112590510398149,0.8476336181469304,0.9860563278198242,0.0609218515455722,0.2172369927346399,43793.0,0.9851536750793456,0.0651892498135566,0.2122587202567049,43793.0,18020.076694488525,26857.69357442856,18020.076694488525,8833.25545835495,2.8469173908233643,0.0 -56300,0.080682755,0.021931821,,,,,,,,,,,,,,,,, -56400,0.08883506,0.0224705,,,,,,,,,,,,,,,,, -56500,0.09191642,0.022477461,,,,,,,,,,,,,,,,, -56600,0.0997111,0.022291921,,,,,,,,,,,,,,,,, -56700,0.07924792,0.021906435,,,,,,,,,,,,,,,,, -56800,0.083435036,0.020923603,,,,,,,,,,,,,,,,, -56900,0.089987606,0.023406686,,,,,,,,,,,,,,,,, -57000,0.09735815,0.022675328,,,,,,,,,,,,,,,,, -57029,,,0.995991826057434,0.0125447763130068,0.8311303477558817,0.9860153198242188,0.0611287616193294,0.2130954236813024,43793.0,0.985084593296051,0.0654769912362098,0.2110058729860776,43793.0,18260.329787254333,27213.90539121628,18260.329787254333,8949.157200813293,2.883453845977783,0.0 -57100,0.088873714,0.020733574,,,,,,,,,,,,,,,,, -57200,0.08832413,0.02242202,,,,,,,,,,,,,,,,, -57300,0.08431222,0.021461934,,,,,,,,,,,,,,,,, -57400,0.08832639,0.021466516,,,,,,,,,,,,,,,,, -57500,0.07895166,0.023493074,,,,,,,,,,,,,,,,, -57600,0.08559872,0.02065691,,,,,,,,,,,,,,,,, -57700,,,,,,,,,,,,,,18477.080607652664,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 4b9a918f5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -115.04727101325987,0.0,11.409060716629028,1,0,11.409060716629028,0.5047287344932556,0.7438743114471436,0.0269289900473865,43793,126.4563901424408,0.5091323852539062,0.7392957210540771,0.0224870301959398,0.5064647793769836,0.7422075867652893,0.0252986296517555,43793 -228.1838824748993,0.0212934017181396,251.55146312713623,748,0,251.55146312713623,0.983142077922821,0.0805582851171493,0.0400774903969423,43793,479.7769241333008,0.9866933822631836,0.069809466600418,0.0358362544300362,0.9841179251670836,0.0777350962162017,0.0368130067090984,43793 -336.9643814563751,0.0503711700439453,491.7307615280152,1492,0,491.7307615280152,0.9834933280944824,0.0614174157381057,0.0807156931335104,43793,828.7861225605011,0.987136721611023,0.0488488338887691,0.0804861564040043,0.9844698905944824,0.0582043789327144,0.0804741995938276,43793 -449.5094120502472,0.0773499011993408,731.8950526714325,2240,0,731.8950526714325,0.9840232133865356,0.0565517283976078,0.1219216110893134,43793,1181.5423827171326,0.987756609916687,0.0444316267967224,0.1284717842173021,0.985007345676422,0.0535632334649562,0.1232223915642816,43793 -566.2745454311371,0.1053874492645263,972.0586864948272,2982,0,972.0586864948272,0.9840303659439088,0.0546653345227241,0.1447432721257186,43793,1538.52033162117,0.9878631830215454,0.0435997433960437,0.1540226900623495,0.9849075078964232,0.0523733645677566,0.139691764104344,43793 -679.1746046543121,0.1389689445495605,1212.2938175201416,3716,0,1212.2938175201416,0.9845231771469116,0.0531827509403228,0.1678773727123591,43793,1891.7137813568115,0.9883297085762024,0.0408373698592186,0.1751967621768912,0.9854157567024232,0.0504284836351871,0.1665435963180997,43793 -792.9642839431763,0.1680970191955566,1452.2720143795011,4456,0,1452.2720143795011,0.9844696521759032,0.0521303638815879,0.1820111961167904,43793,2245.531072616577,0.9884167313575744,0.0401438362896442,0.2029430039185905,0.9853593111038208,0.049749881029129,0.1753748990749337,43793 -901.5599522590636,0.1956448554992675,1692.4591073989868,5202,0,1692.4591073989868,0.9848099946975708,0.0512166544795036,0.1928349290109782,43793,2594.3612883090973,0.9886823296546936,0.0386181212961673,0.2132548290412207,0.9857100248336792,0.0483920536935329,0.1913639220445636,43793 -1012.5432093143464,0.2226417064666748,1932.6851942539213,5953,0,1932.6851942539213,0.9849632978439332,0.0504787340760231,0.197353427924065,43793,2945.617326259613,0.9889214038848876,0.0379679910838604,0.2413591196015644,0.9860051274299622,0.0477324537932872,0.2028683202818477,43793 -1120.6893513202667,0.2512521743774414,2172.6391608715057,6704,0,2172.6391608715057,0.9851145148277284,0.049725517630577,0.208911943517931,43793,3293.7659606933594,0.989250123500824,0.0368568673729896,0.2679609941091441,0.9859207272529602,0.0471237525343894,0.2060660896620549,43793 -1233.30415892601,0.2791018486022949,2412.912478208542,7446,0,2412.912478208542,0.9852392077445984,0.0492012463510036,0.2160302095265996,43793,3646.7018489837646,0.9892452955245972,0.0365756042301654,0.2556209530754341,0.9861403107643129,0.046624518930912,0.2181407743537345,43793 -1342.8322954177856,0.3075456619262695,2653.0472116470337,8192,0,2653.0472116470337,0.985280454158783,0.0489580743014812,0.22005215130504,43793,3996.4135501384735,0.9891985654830932,0.0363474003970623,0.2777754718600149,0.9861192107200624,0.0464285761117935,0.2241906615481037,43793 -1453.4429540634155,0.335129976272583,2893.168367624283,8944,0,2893.168367624283,0.9853727221488952,0.0489822104573249,0.2192443785054851,43793,4347.192717552185,0.9893425703048706,0.0358811616897583,0.2778406643591576,0.9862227439880372,0.0463167168200016,0.213456634775387,43793 -1561.425509929657,0.3628628253936767,3133.286482810974,9700,0,3133.286482810974,0.9855323433876038,0.0486967712640762,0.2262703918585826,43793,4695.340826034546,0.989570677280426,0.0349280498921871,0.2965503820496965,0.9864078760147096,0.0460746698081493,0.23116527704377,43793 -1671.5593321323397,0.3925163745880127,3373.3551876544952,10446,0,3373.3551876544952,0.9855976104736328,0.0481628887355327,0.2321888170443082,43793,5045.592380523682,0.9897487759590148,0.0343241393566131,0.3145140324309475,0.9864723682403564,0.0455073565244674,0.2354921471673552,43793 -1779.7080972194672,0.420879602432251,3613.492337703705,11191,0,3613.492337703705,0.9856485724449158,0.0476858727633953,0.2412774564387674,43793,5393.927290678024,0.9900938272476196,0.0334394387900829,0.3254542392372632,0.986543834209442,0.0448717139661312,0.2416672091869643,43793 -1894.3356800079343,0.4488670825958252,3853.50465130806,11940,0,3853.50465130806,0.9855904579162598,0.0475993864238262,0.2349569160331164,43793,5748.61470246315,0.9900937676429749,0.0329126082360744,0.3425646620707906,0.9863781929016112,0.0451198928058147,0.2363973885692383,43793 -2006.2802784442904,0.4773871898651123,4093.458253622055,12683,0,4093.458253622055,0.9857320189476012,0.047577790915966,0.2476646823485428,43793,6100.5614466667175,0.9903525710105896,0.0320491008460521,0.3773028926927707,0.9866164922714232,0.0448366589844226,0.2527637025524815,43793 -2121.336286067962,0.5057508945465088,4333.447480678558,13417,0,4333.447480678558,0.9857416749000548,0.0473575927317142,0.2451571180269876,43793,6455.656747102737,0.9905537962913512,0.0310894008725881,0.3724031313377338,0.9866245985031128,0.0445709228515625,0.2492293152620507,43793 -2229.1408503055573,0.5367867946624756,4573.6459176540375,14151,0,4573.6459176540375,0.9856612086296082,0.0475810840725898,0.2457801394641009,43793,6803.714690446854,0.9907062649726868,0.0309913195669651,0.378599764563951,0.9865682125091552,0.044715654104948,0.2523916541291809,43793 -2343.599368095398,0.567765474319458,4813.645154476166,14903,0,4813.645154476166,0.9858141541481018,0.0474219284951686,0.2527017177476719,43793,7158.222968816757,0.9905675649642944,0.0312392432242631,0.3900551148423534,0.98665589094162,0.0447624139487743,0.2534447646793866,43793 -2449.626982450485,0.5992722511291504,5053.900189876556,15652,0,5053.900189876556,0.9858461618423462,0.0471496619284153,0.2551532981849261,43793,7504.557371377945,0.9906066060066224,0.0311218444257974,0.381319950241573,0.9867772459983826,0.0444102399051189,0.2583747608772624,43793 -2561.516103744507,0.6315131187438965,5294.087451457977,16395,0,5294.087451457977,0.985654890537262,0.0475700162351131,0.2462189962544967,43793,7856.685685396194,0.9906516671180724,0.0308571141213178,0.3869455478987839,0.98651784658432,0.0449075438082218,0.2555116961913109,43793 -2670.1546428203583,0.6623170375823975,5534.083811998367,17141,0,5534.083811998367,0.9855812191963196,0.0473384335637092,0.2524884362959509,43793,8205.371302127838,0.990618884563446,0.0309921987354755,0.3856777667274151,0.9865024089813232,0.0446387305855751,0.2601376590016442,43793 -2781.957170248032,0.6927800178527832,5774.310196876526,17887,0,5774.310196876526,0.9856747388839722,0.0483092926442623,0.2461299896463264,43793,8557.450502157211,0.9904924035072328,0.031084245070815,0.3930251830262817,0.9865880608558656,0.0455077663064003,0.2632867709525448,43793 -2888.7862842082977,0.7243075370788574,6014.308893442154,18633,0,6014.308893442154,0.9858941435813904,0.0473372787237167,0.2532866247875107,43793,8904.329939126968,0.9910005927085876,0.0293161552399396,0.4221835982123053,0.9868016242980956,0.0446500442922115,0.2631362447173095,43793 -3000.428964138031,0.754981517791748,6254.484782218933,19385,0,6254.484782218933,0.9859779477119446,0.0470069982111454,0.2564808137063278,43793,9256.199064016342,0.991144597530365,0.0288536138832569,0.4362776975542201,0.986806094646454,0.0443035177886486,0.2714599200929934,43793 -3106.488452196121,0.7869946956634521,6494.502206325531,20137,0,6494.502206325531,0.9858676195144652,0.0475185513496398,0.2565515376288197,43793,9602.328292369844,0.9912317991256714,0.0284095779061317,0.4510648548895596,0.9867587685585022,0.0445445626974105,0.2627217483865065,43793 -3217.7399446964264,0.8169877529144287,6734.702629804611,20892,0,6734.702629804611,0.985849916934967,0.0476750880479812,0.2547641389335,43793,9953.830657482147,0.9913666248321532,0.0281404834240674,0.4477376798318661,0.9866806268692015,0.0450030602514743,0.2594265373774643,43793 -3330.0670692920685,0.84893798828125,6974.657138586044,21632,0,6974.657138586044,0.9857699275016784,0.0473132506012916,0.2545662576834401,43793,10306.164506912231,0.9910850524902344,0.029171073809266,0.4186773832495749,0.986506462097168,0.0445950105786323,0.260642723999837,43793 -3441.468363761902,0.8800814151763916,7214.641557455063,22380,0,7214.641557455063,0.9858149886131288,0.0473951138556003,0.260235189498801,43793,10657.6014316082,0.9910112023353576,0.0293988939374685,0.413999674447663,0.9867098927497864,0.0447753965854644,0.2613197464129227,43793 -3550.564575910568,0.9132204055786132,7454.6159999370575,23133,0,7454.6159999370575,0.9857193827629088,0.047521024942398,0.256449131398079,43793,11006.72515630722,0.9910687804222108,0.0292721986770629,0.427519522785992,0.98666113615036,0.0447103604674339,0.2611749950881226,43793 -3665.318177700042,0.9451286792755128,7694.792117834091,23882,0,7694.792117834091,0.9859809279441832,0.0475788004696369,0.2583359648758452,43793,11361.706496953964,0.9910465478897096,0.0290975049138069,0.439029657393755,0.986792266368866,0.0448956228792667,0.2642872001662467,43793 -3773.929664850235,0.9760401248931884,7934.79806470871,24616,0,7934.79806470871,0.9857947826385498,0.0476335324347019,0.2508257458319819,43793,11710.37767481804,0.9911187291145324,0.0287637133151292,0.4408214315236216,0.9866234064102172,0.0449697487056255,0.2635825782785605,43793 -3887.515341043472,1.0077130794525146,8174.913177490234,25358,0,8174.913177490234,0.9857442378997804,0.0476606115698814,0.2516790343245108,43793,12064.129575967789,0.991275668144226,0.0282500777393579,0.4428742902391496,0.986655056476593,0.0448307059705257,0.2637613334042595,43793 -3996.383995056152,1.039290189743042,8414.918694972992,26103,0,8414.918694972992,0.9858486652374268,0.0474232174456119,0.2543806985051701,43793,12413.05533671379,0.9914022088050842,0.0277662333101034,0.4523410083924511,0.9866623878479004,0.044797908514738,0.2641685438259154,43793 -4104.547461509705,1.0705244541168213,8655.010743618011,26844,0,8655.010743618011,0.9857585430145264,0.0474692322313785,0.2568551148478193,43793,12761.361777067184,0.991820216178894,0.0265189036726951,0.4992254146540879,0.986660361289978,0.0447358042001724,0.2714495760771454,43793 -4213.402389764786,1.1035490036010742,8895.237357616425,27591,0,8895.237357616425,0.9857147336006165,0.0480520650744438,0.2612988077710391,43793,13110.495953798294,0.99174964427948,0.0268844552338123,0.4890036321106513,0.98653244972229,0.0451820157468318,0.2685428630516191,43793 -4322.653234481812,1.1350586414337158,9135.220917224884,28339,0,9135.220917224884,0.9858225584030152,0.0483605340123176,0.2572859701949634,43793,13459.781569480896,0.991543173789978,0.0273032449185848,0.4728328304098546,0.9867244958877563,0.0452097095549106,0.267856911482501,43793 -4429.404903411865,1.1667847633361816,9375.281760692596,29087,0,9375.281760692596,0.9858547449111938,0.0476239696145057,0.2617479458606887,43793,13806.64552783966,0.9913496375083924,0.0279213413596153,0.446704589850163,0.9866745471954346,0.0448725298047065,0.2700355602727555,43793 -4536.267310619354,1.1982874870300293,9615.312079906464,29841,0,9615.312079906464,0.9858208894729614,0.0481179356575012,0.2530234007583115,43793,14153.5893740654,0.9912735819816588,0.0282083544880151,0.4498426126219626,0.9865994453430176,0.0453692749142646,0.2591981850120724,43793 -4645.564736843109,1.230731964111328,9855.56041288376,30587,0,9855.56041288376,0.9857636094093324,0.0484786257147789,0.254502359861954,43793,14503.186990499496,0.9914898872375488,0.0273872166872024,0.4727808883611701,0.986668050289154,0.0453085117042064,0.2628570416310531,43793 -4750.071592330933,1.2652955055236816,10095.512593507769,31331,0,10095.512593507769,0.9859392046928406,0.0476431138813495,0.2611744513547663,43793,14847.700211763382,0.9915563464164734,0.0272246785461902,0.4779938282806382,0.986726939678192,0.0449745245277881,0.2681475382524397,43793 -4857.177111625671,1.300189971923828,10335.648624420166,32084,0,10335.648624420166,0.9857879877090454,0.0480624511837959,0.2527473212298942,43793,15194.996549367905,0.9915878772735596,0.0269397553056478,0.4846764472143901,0.9865304231643676,0.0454461760818958,0.2681769855204736,43793 -4962.369988441467,1.332615613937378,10575.60150718689,32842,0,10575.60150718689,0.9858031868934632,0.0478540509939193,0.2626084049281989,43793,15540.194693565369,0.9919490218162536,0.0257465578615665,0.5127910015596717,0.9866639971733092,0.0449895188212394,0.2654056363679506,43793 -5067.671158790588,1.3651981353759766,10815.79008245468,33592,0,10815.79008245468,0.9859409332275392,0.0479836985468864,0.2574362116662635,43793,15885.7362678051,0.9920307397842408,0.0254134628921747,0.5067656067708026,0.9868279695510864,0.045044295489788,0.2658500783934601,43793 -5174.355105876923,1.3994545936584473,11056.026224136353,34348,0,11056.026224136353,0.985846996307373,0.0481024570763111,0.2553736626151114,43793,16232.710386753082,0.9923913478851318,0.0244236923754215,0.5342409464257092,0.9867200255393982,0.0453106351196765,0.2671155695587615,43793 -5284.490906953812,1.4411022663116455,11296.04472565651,35099,0,11296.04472565651,0.9857711791992188,0.0481272302567958,0.2624014605046485,43793,16582.92563867569,0.9922991991043092,0.0248646866530179,0.5406791037882595,0.9865483045578004,0.0454125665128231,0.2769924562069861,43793 -5393.113869190216,1.474443435668945,11536.114127397535,35851,0,11536.114127397535,0.985917329788208,0.048857532441616,0.258611242636092,43793,16931.671364068985,0.9920918941497804,0.0254524517804384,0.5169067383191193,0.9867861866950988,0.0457796640694141,0.2736346291879268,43793 -5502.739744663239,1.5083093643188477,11776.257864952087,36606,0,11776.257864952087,0.9859451055526732,0.0486032329499721,0.2502696397802891,43793,17281.49413871765,0.9918672442436218,0.0259287282824516,0.4994652663793621,0.9868308305740356,0.0456037297844886,0.2706534819480221,43793 -5610.971168756485,1.541703224182129,12016.488486766815,37345,0,12016.488486766815,0.9858027696609496,0.0489179119467735,0.2587711548087072,43793,17630.008954048157,0.9918071031570436,0.0262284409254789,0.4914974538833261,0.986737072467804,0.0458212234079837,0.269111434883714,43793 -5714.142570018768,1.5757763385772705,12256.664202690125,38090,0,12256.664202690125,0.9859261512756348,0.048667199909687,0.2615826043412462,43793,17973.409933567047,0.991980254650116,0.0255869254469871,0.5115828856662827,0.9867179989814758,0.0459587611258029,0.2713995649094152,43793 -5819.092994213104,1.608827829360962,12496.846721887589,38833,0,12496.846721887589,0.9857484102249146,0.0489294081926345,0.2565882521212089,43793,18318.59648966789,0.9920877814292908,0.0252947751432657,0.521443788126027,0.986579179763794,0.0461673587560653,0.2693436348083421,43793 -5931.902855873108,1.6432619094848633,12737.104538679125,39566,0,12737.104538679125,0.985687792301178,0.0492235198616981,0.25741389805365,43793,18671.7216861248,0.9922082424163818,0.0247992202639579,0.5314044332998933,0.9865458607673644,0.0463249869644641,0.267102513934124,43793 -6041.992013454437,1.677354335784912,12977.295271873474,40305,0,12977.295271873474,0.9857577085494996,0.0494484901428222,0.2498049455988531,43793,19022.058073043823,0.9925952553749084,0.023490697145462,0.5565252775312892,0.986531674861908,0.0463293083012104,0.2676215261520471,43793 -6147.838949918747,1.7124638557434082,13217.297476530077,41054,0,13217.297476530077,0.9858141541481018,0.049161035567522,0.2650616828523954,43793,19367.962020874023,0.9928678274154664,0.0224870759993791,0.581720308999962,0.9867366552352904,0.0461860261857509,0.2734461547168441,43793 -6253.418445587158,1.7486231327056885,13457.432827711104,41810,0,13457.432827711104,0.9857454895973206,0.049967210739851,0.2580405016937968,43793,19713.732904434204,0.9927249550819396,0.0229897052049636,0.5715105094814787,0.9866583347320556,0.0467988699674606,0.2642473790358029,43793 -6366.196691989899,1.7837131023406982,13697.439344882963,42556,0,13697.439344882963,0.9856923818588256,0.0494961068034172,0.2557325618023192,43793,20066.57444548607,0.992652952671051,0.0233211573213338,0.5512938937985699,0.9865283966064452,0.0467281602323055,0.2636467984267341,43793 -6472.557656049728,1.8214752674102783,13937.537743330002,43306,0,13937.537743330002,0.985849916934967,0.0496870279312133,0.2605483988073737,43793,20413.0918610096,0.992501735687256,0.0237785894423723,0.5646825159913036,0.9867951273918152,0.0462379977107048,0.2821592979825692,43793 -6580.312434911728,1.8559012413024905,14177.75285601616,44051,0,14177.75285601616,0.9858596324920654,0.0500569939613342,0.2565946215782902,43793,20761.1159992218,0.9924307465553284,0.0238613691180944,0.555371738352912,0.9867119193077089,0.0468770824372768,0.2732673112388979,43793 -6687.529041051865,1.891420125961304,14417.734598636627,44798,0,14417.734598636627,0.9858166575431824,0.0504205785691738,0.2533213945513642,43793,21108.37005758285,0.992325484752655,0.0240122228860855,0.5586604499035724,0.986703395843506,0.0472040958702564,0.2669402476371244,43793 -6794.665002822876,1.9260263442993164,14657.84571671486,45551,0,14657.84571671486,0.9856654405593872,0.0509008951485157,0.2581821676206569,43793,21455.67155122757,0.9925184845924376,0.0234085004776716,0.5609331101563908,0.9865624904632568,0.0477843768894672,0.2695785599802789,43793 -6900.194499254227,1.9617371559143064,14897.96158337593,46308,0,14897.96158337593,0.9856970310211182,0.050856564193964,0.2524435396842576,43793,21801.37257266045,0.9927812218666076,0.0225784312933683,0.5735938940547376,0.986531674861908,0.0475295148789882,0.2711253417235155,43793 -7008.161518335342,2.0013952255249023,15138.184672355652,47046,0,15138.184672355652,0.9854986667633056,0.0510683432221412,0.2488496460511818,43793,22149.626344442368,0.9932058453559875,0.0213255230337381,0.6142257470793187,0.9863327741622924,0.0477554388344287,0.2685979413267775,43793 -7115.730270385742,2.038448095321656,15378.320225954056,47797,0,15378.320225954056,0.9856852293014526,0.051686979830265,0.2522580826705521,43793,22497.388407230377,0.9933127760887146,0.0208547096699476,0.6183224958318798,0.9866700768470764,0.04813152551651,0.2735240056114583,43793 -7222.079265594482,2.07443904876709,15618.439760684969,48545,0,15618.439760684969,0.9857058525085448,0.0517389625310897,0.2559660860530463,43793,22843.91304540634,0.9937769770622252,0.0196090713143348,0.657254856675352,0.9864935278892516,0.0484823510050773,0.2643268234435679,43793 -7328.919096469879,2.4131574630737305,15858.267110586166,49290,0,15858.267110586166,0.9856612086296082,0.0520920641720294,0.2503146224776895,43793,23190.93936944008,0.9935559630393982,0.0202613193541765,0.6167782802913997,0.9866542816162108,0.0487692281603813,0.2706925333525783,43793 -7438.20524597168,2.4488024711608887,16098.356865167618,50032,0,16098.356865167618,0.9856587052345276,0.0523470975458622,0.2541978403734496,43793,23540.37098479271,0.9932066798210144,0.0211821142584085,0.6060510859266648,0.9864216446876526,0.0490487702190876,0.2732556869305221,43793 -7546.218659877777,2.485704183578491,16338.490710258484,50782,0,16338.490710258484,0.9856283664703368,0.0528356507420539,0.2524023510446382,43793,23888.57576751709,0.9932138323783876,0.0211225748062133,0.6076637757132363,0.9864943027496338,0.0495653636753559,0.2643919542063089,43793 -7655.152938842773,2.5216963291168213,16578.616221904755,51533,0,16578.616221904755,0.9855205416679382,0.0532411523163318,0.2527073064198517,43793,24237.691730499268,0.9932173490524292,0.0210808832198381,0.6193763045917757,0.986382246017456,0.0498485118150711,0.2682000829620698,43793 -7756.792923927307,2.5575668811798096,16818.712033748627,52279,0,16818.712033748627,0.9856444001197816,0.0534443370997905,0.2497437597936006,43793,24579.483705997467,0.9932751059532166,0.0206806678324937,0.6200296858576568,0.9865409731864928,0.0498331263661384,0.2641053459337334,43793 -7864.91987657547,2.594414710998535,17058.962177991867,53026,0,17058.962177991867,0.9855024218559264,0.0538224689662456,0.2479787630190198,43793,24927.917595624924,0.9934790134429932,0.0199831202626228,0.6333494312558867,0.9864338040351868,0.0501837916672229,0.2680832358381055,43793 -7971.463971138,2.6330935955047607,17298.90731549263,53772,0,17298.90731549263,0.9854165315628052,0.0540853217244148,0.2510609877039152,43793,25274.465651512142,0.9938586950302124,0.0189154967665672,0.6609112039038493,0.986370086669922,0.0502746887505054,0.2667339680899637,43793 -8074.933919668198,2.669600009918213,17538.941791057587,54520,0,17538.941791057587,0.9854194521903992,0.0547283142805099,0.2489122732522624,43793,25618.02638888359,0.9943572282791138,0.0174389947205781,0.6922163307395394,0.9864155650138856,0.0508133694529533,0.262530649754857,43793 -8181.723499298096,2.7075655460357666,17779.046933174133,55267,0,17779.046933174133,0.9853533506393432,0.0552709624171257,0.2484196984421053,43793,25964.97895693779,0.9947988390922546,0.0162368603050708,0.7333053532778322,0.9862548112869264,0.0513133332133293,0.2627462523088922,43793 -8290.270651817322,2.7457399368286133,18019.02331376076,56017,0,18019.02331376076,0.9853945970535278,0.0558489337563514,0.2532018150713649,43793,26313.56055831909,0.9949828386306764,0.0157982297241687,0.7273481342521675,0.9862483143806458,0.0518519133329391,0.2608411574989082,43793 -8400.08038520813,2.7842960357666016,18259.248502254486,56766,0,18259.248502254486,0.9853697419166565,0.05587721988558769,0.2519444893929446,43793,26663.653777122498,0.9939604997634888,0.018462734296917915,0.6661507763472906,0.9862211346626282,0.05208123102784157,0.26548057966644023,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/measurements.csv deleted file mode 100644 index 2fc1bd3fd..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/measurements.csv +++ /dev/null @@ -1,654 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3577504,0.7395575,,,,,,,,,,,,,,,,, -1,,,0.5091323852539062,0.7392957210540771,0.0224870301959398,0.5064647793769836,0.7422075867652893,0.0252986296517555,43793.0,0.5047287344932556,0.7438743114471436,0.0269289900473865,43793.0,11.409060716629028,126.4563901424408,11.409060716629028,115.04727101325987,0.0,0.0 -100,0.6409674,0.49093008,,,,,,,,,,,,,,,,, -200,0.399678,0.36154094,,,,,,,,,,,,,,,,, -300,0.28872114,0.2600051,,,,,,,,,,,,,,,,, -400,0.19085127,0.1734185,,,,,,,,,,,,,,,,, -500,0.11599374,0.121233374,,,,,,,,,,,,,,,,, -600,0.072615385,0.088651516,,,,,,,,,,,,,,,,, -700,0.05308028,0.07355888,,,,,,,,,,,,,,,,, -748,,,0.9866933822631836,0.069809466600418,0.0358362544300362,0.9841179251670836,0.0777350962162017,0.0368130067090984,43793.0,0.983142077922821,0.0805582851171493,0.0400774903969423,43793.0,251.55146312713623,479.7769241333008,251.55146312713623,228.1838824748993,0.0212934017181396,0.0 -800,0.098516785,0.06374102,,,,,,,,,,,,,,,,, -900,0.13409114,0.05849568,,,,,,,,,,,,,,,,, -1000,0.07311197,0.058961626,,,,,,,,,,,,,,,,, -1100,0.037872065,0.053800557,,,,,,,,,,,,,,,,, -1200,0.16516636,0.049108736,,,,,,,,,,,,,,,,, -1300,0.12737033,0.04793763,,,,,,,,,,,,,,,,, -1400,0.07462031,0.049932286,,,,,,,,,,,,,,,,, -1492,,,0.987136721611023,0.0488488338887691,0.0804861564040043,0.9844698905944824,0.0582043789327144,0.0804741995938276,43793.0,0.9834933280944824,0.0614174157381057,0.0807156931335104,43793.0,491.7307615280152,828.7861225605011,491.7307615280152,336.9643814563751,0.0503711700439453,0.0 -1500,0.3771057,0.050226945,,,,,,,,,,,,,,,,, -1600,0.2926956,0.043503597,,,,,,,,,,,,,,,,, -1700,0.21939488,0.05050567,,,,,,,,,,,,,,,,, -1800,0.20367083,0.05437785,,,,,,,,,,,,,,,,, -1900,0.13683909,0.04831048,,,,,,,,,,,,,,,,, -2000,0.22287358,0.055412725,,,,,,,,,,,,,,,,, -2100,0.16411163,0.051236425,,,,,,,,,,,,,,,,, -2200,0.23180118,0.04080877,,,,,,,,,,,,,,,,, -2240,,,0.987756609916687,0.0444316267967224,0.1284717842173021,0.985007345676422,0.0535632334649562,0.1232223915642816,43793.0,0.9840232133865356,0.0565517283976078,0.1219216110893134,43793.0,731.8950526714325,1181.5423827171326,731.8950526714325,449.5094120502472,0.0773499011993408,0.0 -2300,0.10427606,0.03987392,,,,,,,,,,,,,,,,, -2400,0.19421287,0.045513425,,,,,,,,,,,,,,,,, -2500,0.17879653,0.046243463,,,,,,,,,,,,,,,,, -2600,0.17068651,0.043365154,,,,,,,,,,,,,,,,, -2700,0.08866114,0.038412422,,,,,,,,,,,,,,,,, -2800,0.24694847,0.04198673,,,,,,,,,,,,,,,,, -2900,0.07155318,0.041887395,,,,,,,,,,,,,,,,, -2982,,,0.9878631830215454,0.0435997433960437,0.1540226900623495,0.9849075078964232,0.0523733645677566,0.139691764104344,43793.0,0.9840303659439088,0.0546653345227241,0.1447432721257186,43793.0,972.0586864948272,1538.52033162117,972.0586864948272,566.2745454311371,0.1053874492645263,0.0 -3000,0.21902576,0.040333357,,,,,,,,,,,,,,,,, -3100,0.1694753,0.041914206,,,,,,,,,,,,,,,,, -3200,0.09253975,0.044779714,,,,,,,,,,,,,,,,, -3300,0.057287116,0.038437206,,,,,,,,,,,,,,,,, -3400,0.1723044,0.045125462,,,,,,,,,,,,,,,,, -3500,0.075849004,0.045267798,,,,,,,,,,,,,,,,, -3600,0.12332454,0.042844683,,,,,,,,,,,,,,,,, -3700,0.1260932,0.04014081,,,,,,,,,,,,,,,,, -3716,,,0.9883297085762024,0.0408373698592186,0.1751967621768912,0.9854157567024232,0.0504284836351871,0.1665435963180997,43793.0,0.9845231771469116,0.0531827509403228,0.1678773727123591,43793.0,1212.2938175201416,1891.7137813568115,1212.2938175201416,679.1746046543121,0.1389689445495605,0.0 -3800,0.10702493,0.04163233,,,,,,,,,,,,,,,,, -3900,0.08317313,0.04093457,,,,,,,,,,,,,,,,, -4000,0.111873515,0.035029724,,,,,,,,,,,,,,,,, -4100,0.08987873,0.03868885,,,,,,,,,,,,,,,,, -4200,0.050092213,0.03914178,,,,,,,,,,,,,,,,, -4300,0.045597076,0.038543437,,,,,,,,,,,,,,,,, -4400,0.24244986,0.043717608,,,,,,,,,,,,,,,,, -4456,,,0.9884167313575744,0.0401438362896442,0.2029430039185905,0.9853593111038208,0.049749881029129,0.1753748990749337,43793.0,0.9844696521759032,0.0521303638815879,0.1820111961167904,43793.0,1452.2720143795011,2245.531072616577,1452.2720143795011,792.9642839431763,0.1680970191955566,0.0 -4500,0.05686752,0.03828095,,,,,,,,,,,,,,,,, -4600,0.08394143,0.041461103,,,,,,,,,,,,,,,,, -4700,0.08038283,0.037235048,,,,,,,,,,,,,,,,, -4800,0.14753605,0.043281883,,,,,,,,,,,,,,,,, -4900,0.070231214,0.045126267,,,,,,,,,,,,,,,,, -5000,0.05342373,0.039254487,,,,,,,,,,,,,,,,, -5100,0.064220734,0.037441462,,,,,,,,,,,,,,,,, -5200,0.039380543,0.035054523,,,,,,,,,,,,,,,,, -5202,,,0.9886823296546936,0.0386181212961673,0.2132548290412207,0.9857100248336792,0.0483920536935329,0.1913639220445636,43793.0,0.9848099946975708,0.0512166544795036,0.1928349290109782,43793.0,1692.4591073989868,2594.3612883090973,1692.4591073989868,901.5599522590636,0.1956448554992675,0.0 -5300,0.08124476,0.036286805,,,,,,,,,,,,,,,,, -5400,0.0935552,0.037423614,,,,,,,,,,,,,,,,, -5500,0.07483863,0.04265656,,,,,,,,,,,,,,,,, -5600,0.07152682,0.035841223,,,,,,,,,,,,,,,,, -5700,0.030683028,0.03651831,,,,,,,,,,,,,,,,, -5800,0.055558078,0.040273916,,,,,,,,,,,,,,,,, -5900,0.06099552,0.040706728,,,,,,,,,,,,,,,,, -5953,,,0.9889214038848876,0.0379679910838604,0.2413591196015644,0.9860051274299622,0.0477324537932872,0.2028683202818477,43793.0,0.9849632978439332,0.0504787340760231,0.197353427924065,43793.0,1932.6851942539213,2945.617326259613,1932.6851942539213,1012.5432093143464,0.2226417064666748,0.0 -6000,0.036658004,0.032730732,,,,,,,,,,,,,,,,, -6100,0.0627621,0.04137043,,,,,,,,,,,,,,,,, -6200,0.03948813,0.034731727,,,,,,,,,,,,,,,,, -6300,0.034123998,0.038845643,,,,,,,,,,,,,,,,, -6400,0.048537828,0.036129154,,,,,,,,,,,,,,,,, -6500,0.029423136,0.035604376,,,,,,,,,,,,,,,,, -6600,0.06704572,0.038648676,,,,,,,,,,,,,,,,, -6700,0.050375696,0.040082965,,,,,,,,,,,,,,,,, -6704,,,0.989250123500824,0.0368568673729896,0.2679609941091441,0.9859207272529602,0.0471237525343894,0.2060660896620549,43793.0,0.9851145148277284,0.049725517630577,0.208911943517931,43793.0,2172.6391608715057,3293.7659606933594,2172.6391608715057,1120.6893513202667,0.2512521743774414,0.0 -6800,0.036150463,0.03497692,,,,,,,,,,,,,,,,, -6900,0.029616833,0.03497853,,,,,,,,,,,,,,,,, -7000,0.04411382,0.036623504,,,,,,,,,,,,,,,,, -7100,0.028695155,0.031885896,,,,,,,,,,,,,,,,, -7200,0.042085662,0.0360412,,,,,,,,,,,,,,,,, -7300,0.027865719,0.035015326,,,,,,,,,,,,,,,,, -7400,0.033843275,0.03777812,,,,,,,,,,,,,,,,, -7446,,,0.9892452955245972,0.0365756042301654,0.2556209530754341,0.9861403107643129,0.046624518930912,0.2181407743537345,43793.0,0.9852392077445984,0.0492012463510036,0.2160302095265996,43793.0,2412.912478208542,3646.7018489837646,2412.912478208542,1233.30415892601,0.2791018486022949,0.0 -7500,0.02684988,0.03811773,,,,,,,,,,,,,,,,, -7600,0.03718904,0.038550712,,,,,,,,,,,,,,,,, -7700,0.036546737,0.042498197,,,,,,,,,,,,,,,,, -7800,0.027782418,0.034596078,,,,,,,,,,,,,,,,, -7900,0.036705595,0.035842873,,,,,,,,,,,,,,,,, -8000,0.036762085,0.036590777,,,,,,,,,,,,,,,,, -8100,0.026771232,0.0394265,,,,,,,,,,,,,,,,, -8192,,,0.9891985654830932,0.0363474003970623,0.2777754718600149,0.9861192107200624,0.0464285761117935,0.2241906615481037,43793.0,0.985280454158783,0.0489580743014812,0.22005215130504,43793.0,2653.0472116470337,3996.4135501384735,2653.0472116470337,1342.8322954177856,0.3075456619262695,0.0 -8200,0.07914484,0.036124095,,,,,,,,,,,,,,,,, -8300,0.034777578,0.04100023,,,,,,,,,,,,,,,,, -8400,0.044610307,0.033002667,,,,,,,,,,,,,,,,, -8500,0.036769092,0.040218618,,,,,,,,,,,,,,,,, -8600,0.039574407,0.036582667,,,,,,,,,,,,,,,,, -8700,0.039094284,0.037290066,,,,,,,,,,,,,,,,, -8800,0.023551555,0.03508019,,,,,,,,,,,,,,,,, -8900,0.05029769,0.037819404,,,,,,,,,,,,,,,,, -8944,,,0.9893425703048706,0.0358811616897583,0.2778406643591576,0.9862227439880372,0.0463167168200016,0.213456634775387,43793.0,0.9853727221488952,0.0489822104573249,0.2192443785054851,43793.0,2893.168367624283,4347.192717552185,2893.168367624283,1453.4429540634155,0.335129976272583,0.0 -9000,0.02521121,0.033745617,,,,,,,,,,,,,,,,, -9100,0.026236229,0.040739942,,,,,,,,,,,,,,,,, -9200,0.04405834,0.036237538,,,,,,,,,,,,,,,,, -9300,0.054877944,0.038417526,,,,,,,,,,,,,,,,, -9400,0.03164271,0.033755,,,,,,,,,,,,,,,,, -9500,0.045484632,0.035340965,,,,,,,,,,,,,,,,, -9600,0.025724893,0.031597592,,,,,,,,,,,,,,,,, -9700,,,0.989570677280426,0.0349280498921871,0.2965503820496965,0.9864078760147096,0.0460746698081493,0.23116527704377,43793.0,0.9855323433876038,0.0486967712640762,0.2262703918585826,43793.0,3133.286482810974,4695.340826034546,3133.286482810974,1561.425509929657,0.3628628253936767,0.0 -9700,0.035660665,0.033754505,,,,,,,,,,,,,,,,, -9800,0.04108704,0.037089754,,,,,,,,,,,,,,,,, -9900,0.023782728,0.032308694,,,,,,,,,,,,,,,,, -10000,0.032287568,0.038150303,,,,,,,,,,,,,,,,, -10100,0.035501245,0.03685046,,,,,,,,,,,,,,,,, -10200,0.04418978,0.03705341,,,,,,,,,,,,,,,,, -10300,0.02562702,0.03470487,,,,,,,,,,,,,,,,, -10400,0.027350035,0.030223541,,,,,,,,,,,,,,,,, -10446,,,0.9897487759590148,0.0343241393566131,0.3145140324309475,0.9864723682403564,0.0455073565244674,0.2354921471673552,43793.0,0.9855976104736328,0.0481628887355327,0.2321888170443082,43793.0,3373.3551876544952,5045.592380523682,3373.3551876544952,1671.5593321323397,0.3925163745880127,0.0 -10500,0.033737056,0.03456356,,,,,,,,,,,,,,,,, -10600,0.041261762,0.03754908,,,,,,,,,,,,,,,,, -10700,0.032951284,0.035096128,,,,,,,,,,,,,,,,, -10800,0.03766444,0.031603437,,,,,,,,,,,,,,,,, -10900,0.041211728,0.033928756,,,,,,,,,,,,,,,,, -11000,0.037002917,0.035297032,,,,,,,,,,,,,,,,, -11100,0.029049646,0.033124976,,,,,,,,,,,,,,,,, -11191,,,0.9900938272476196,0.0334394387900829,0.3254542392372632,0.986543834209442,0.0448717139661312,0.2416672091869643,43793.0,0.9856485724449158,0.0476858727633953,0.2412774564387674,43793.0,3613.492337703705,5393.927290678024,3613.492337703705,1779.7080972194672,0.420879602432251,0.0 -11200,0.029068071,0.032832783,,,,,,,,,,,,,,,,, -11300,0.053938765,0.036300674,,,,,,,,,,,,,,,,, -11400,0.028775103,0.034749035,,,,,,,,,,,,,,,,, -11500,0.038403254,0.033098474,,,,,,,,,,,,,,,,, -11600,0.044387203,0.029845893,,,,,,,,,,,,,,,,, -11700,0.03900567,0.03316526,,,,,,,,,,,,,,,,, -11800,0.041671425,0.034535002,,,,,,,,,,,,,,,,, -11900,0.039330676,0.037024535,,,,,,,,,,,,,,,,, -11940,,,0.9900937676429749,0.0329126082360744,0.3425646620707906,0.9863781929016112,0.0451198928058147,0.2363973885692383,43793.0,0.9855904579162598,0.0475993864238262,0.2349569160331164,43793.0,3853.50465130806,5748.61470246315,3853.50465130806,1894.3356800079343,0.4488670825958252,0.0 -12000,0.034017928,0.035041817,,,,,,,,,,,,,,,,, -12100,0.036022954,0.03112206,,,,,,,,,,,,,,,,, -12200,0.03361942,0.033363607,,,,,,,,,,,,,,,,, -12300,0.032698832,0.03252634,,,,,,,,,,,,,,,,, -12400,0.03971781,0.030377569,,,,,,,,,,,,,,,,, -12500,0.04548681,0.030737106,,,,,,,,,,,,,,,,, -12600,0.038233127,0.031324606,,,,,,,,,,,,,,,,, -12683,,,0.9903525710105896,0.0320491008460521,0.3773028926927707,0.9866164922714232,0.0448366589844226,0.2527637025524815,43793.0,0.9857320189476012,0.047577790915966,0.2476646823485428,43793.0,4093.458253622055,6100.5614466667175,4093.458253622055,2006.2802784442904,0.4773871898651123,0.0 -12700,0.048227128,0.030311922,,,,,,,,,,,,,,,,, -12800,0.05940573,0.033140205,,,,,,,,,,,,,,,,, -12900,0.04622706,0.032519292,,,,,,,,,,,,,,,,, -13000,0.07819373,0.038030464,,,,,,,,,,,,,,,,, -13100,0.039225098,0.036853395,,,,,,,,,,,,,,,,, -13200,0.03385353,0.028752033,,,,,,,,,,,,,,,,, -13300,0.042813502,0.033633627,,,,,,,,,,,,,,,,, -13400,0.046973664,0.034821875,,,,,,,,,,,,,,,,, -13417,,,0.9905537962913512,0.0310894008725881,0.3724031313377338,0.9866245985031128,0.0445709228515625,0.2492293152620507,43793.0,0.9857416749000548,0.0473575927317142,0.2451571180269876,43793.0,4333.447480678558,6455.656747102737,4333.447480678558,2121.336286067962,0.5057508945465088,0.0 -13500,0.03578965,0.029886253,,,,,,,,,,,,,,,,, -13600,0.05325577,0.030483818,,,,,,,,,,,,,,,,, -13700,0.047901392,0.031624738,,,,,,,,,,,,,,,,, -13800,0.04549363,0.029274028,,,,,,,,,,,,,,,,, -13900,0.058355637,0.029241586,,,,,,,,,,,,,,,,, -14000,0.05235594,0.03199891,,,,,,,,,,,,,,,,, -14100,0.05361472,0.028412845,,,,,,,,,,,,,,,,, -14151,,,0.9907062649726868,0.0309913195669651,0.378599764563951,0.9865682125091552,0.044715654104948,0.2523916541291809,43793.0,0.9856612086296082,0.0475810840725898,0.2457801394641009,43793.0,4573.6459176540375,6803.714690446854,4573.6459176540375,2229.1408503055573,0.5367867946624756,0.0 -14200,0.05301141,0.028447658,,,,,,,,,,,,,,,,, -14300,0.07867539,0.033778373,,,,,,,,,,,,,,,,, -14400,0.0897146,0.03212406,,,,,,,,,,,,,,,,, -14500,0.048343595,0.02858196,,,,,,,,,,,,,,,,, -14600,0.06522705,0.033192717,,,,,,,,,,,,,,,,, -14700,0.064586796,0.03647,,,,,,,,,,,,,,,,, -14800,0.05590579,0.031848352,,,,,,,,,,,,,,,,, -14900,0.083376326,0.032917205,,,,,,,,,,,,,,,,, -14903,,,0.9905675649642944,0.0312392432242631,0.3900551148423534,0.98665589094162,0.0447624139487743,0.2534447646793866,43793.0,0.9858141541481018,0.0474219284951686,0.2527017177476719,43793.0,4813.645154476166,7158.222968816757,4813.645154476166,2343.599368095398,0.567765474319458,0.0 -15000,0.06036621,0.031371575,,,,,,,,,,,,,,,,, -15100,0.059185244,0.035736144,,,,,,,,,,,,,,,,, -15200,0.05245202,0.03561887,,,,,,,,,,,,,,,,, -15300,0.059977647,0.03155472,,,,,,,,,,,,,,,,, -15400,0.060205333,0.03195032,,,,,,,,,,,,,,,,, -15500,0.06969128,0.032459848,,,,,,,,,,,,,,,,, -15600,0.065994106,0.032313075,,,,,,,,,,,,,,,,, -15652,,,0.9906066060066224,0.0311218444257974,0.381319950241573,0.9867772459983826,0.0444102399051189,0.2583747608772624,43793.0,0.9858461618423462,0.0471496619284153,0.2551532981849261,43793.0,5053.900189876556,7504.557371377945,5053.900189876556,2449.626982450485,0.5992722511291504,0.0 -15700,0.052812856,0.028067447,,,,,,,,,,,,,,,,, -15800,0.06583417,0.030113831,,,,,,,,,,,,,,,,, -15900,0.0655622,0.02879543,,,,,,,,,,,,,,,,, -16000,0.08741749,0.032130837,,,,,,,,,,,,,,,,, -16100,0.05562324,0.03013091,,,,,,,,,,,,,,,,, -16200,0.07356718,0.032739796,,,,,,,,,,,,,,,,, -16300,0.07939664,0.031370934,,,,,,,,,,,,,,,,, -16395,,,0.9906516671180724,0.0308571141213178,0.3869455478987839,0.98651784658432,0.0449075438082218,0.2555116961913109,43793.0,0.985654890537262,0.0475700162351131,0.2462189962544967,43793.0,5294.087451457977,7856.685685396194,5294.087451457977,2561.516103744507,0.6315131187438965,0.0 -16400,0.083877884,0.030557498,,,,,,,,,,,,,,,,, -16500,0.06317895,0.029754823,,,,,,,,,,,,,,,,, -16600,0.071020566,0.032880645,,,,,,,,,,,,,,,,, -16700,0.06702443,0.0307844,,,,,,,,,,,,,,,,, -16800,0.059911977,0.032946248,,,,,,,,,,,,,,,,, -16900,0.12630881,0.032373603,,,,,,,,,,,,,,,,, -17000,0.06657,0.030851161,,,,,,,,,,,,,,,,, -17100,0.06807847,0.026694147,,,,,,,,,,,,,,,,, -17141,,,0.990618884563446,0.0309921987354755,0.3856777667274151,0.9865024089813232,0.0446387305855751,0.2601376590016442,43793.0,0.9855812191963196,0.0473384335637092,0.2524884362959509,43793.0,5534.083811998367,8205.371302127838,5534.083811998367,2670.1546428203583,0.6623170375823975,0.0 -17200,0.077148415,0.031105973,,,,,,,,,,,,,,,,, -17300,0.050712675,0.030722743,,,,,,,,,,,,,,,,, -17400,0.07793234,0.03555929,,,,,,,,,,,,,,,,, -17500,0.10867419,0.03274223,,,,,,,,,,,,,,,,, -17600,0.06366967,0.030919442,,,,,,,,,,,,,,,,, -17700,0.06812235,0.030070726,,,,,,,,,,,,,,,,, -17800,0.06048418,0.03001927,,,,,,,,,,,,,,,,, -17887,,,0.9904924035072328,0.031084245070815,0.3930251830262817,0.9865880608558656,0.0455077663064003,0.2632867709525448,43793.0,0.9856747388839722,0.0483092926442623,0.2461299896463264,43793.0,5774.310196876526,8557.450502157211,5774.310196876526,2781.957170248032,0.6927800178527832,0.0 -17900,0.095036305,0.03142233,,,,,,,,,,,,,,,,, -18000,0.111042395,0.034852758,,,,,,,,,,,,,,,,, -18100,0.12196791,0.03145592,,,,,,,,,,,,,,,,, -18200,0.06937908,0.028236495,,,,,,,,,,,,,,,,, -18300,0.093664154,0.032758094,,,,,,,,,,,,,,,,, -18400,0.11428145,0.028647006,,,,,,,,,,,,,,,,, -18500,0.10086213,0.031042395,,,,,,,,,,,,,,,,, -18600,0.075211495,0.032625508,,,,,,,,,,,,,,,,, -18633,,,0.9910005927085876,0.0293161552399396,0.4221835982123053,0.9868016242980956,0.0446500442922115,0.2631362447173095,43793.0,0.9858941435813904,0.0473372787237167,0.2532866247875107,43793.0,6014.308893442154,8904.329939126968,6014.308893442154,2888.7862842082977,0.7243075370788574,0.0 -18700,0.09326199,0.033933923,,,,,,,,,,,,,,,,, -18800,0.10619978,0.03136144,,,,,,,,,,,,,,,,, -18900,0.082036376,0.030441606,,,,,,,,,,,,,,,,, -19000,0.08232802,0.029808879,,,,,,,,,,,,,,,,, -19100,0.10940476,0.029726507,,,,,,,,,,,,,,,,, -19200,0.075314954,0.027447121,,,,,,,,,,,,,,,,, -19300,0.06756652,0.031421628,,,,,,,,,,,,,,,,, -19385,,,0.991144597530365,0.0288536138832569,0.4362776975542201,0.986806094646454,0.0443035177886486,0.2714599200929934,43793.0,0.9859779477119446,0.0470069982111454,0.2564808137063278,43793.0,6254.484782218933,9256.199064016342,6254.484782218933,3000.428964138031,0.754981517791748,0.0 -19400,0.07986476,0.029361432,,,,,,,,,,,,,,,,, -19500,0.07785343,0.032581303,,,,,,,,,,,,,,,,, -19600,0.07439653,0.028207878,,,,,,,,,,,,,,,,, -19700,0.062145453,0.028600994,,,,,,,,,,,,,,,,, -19800,0.11424782,0.030632894,,,,,,,,,,,,,,,,, -19900,0.11780153,0.028650321,,,,,,,,,,,,,,,,, -20000,0.07119805,0.031523157,,,,,,,,,,,,,,,,, -20100,0.06751107,0.027730463,,,,,,,,,,,,,,,,, -20137,,,0.9912317991256714,0.0284095779061317,0.4510648548895596,0.9867587685585022,0.0445445626974105,0.2627217483865065,43793.0,0.9858676195144652,0.0475185513496398,0.2565515376288197,43793.0,6494.502206325531,9602.328292369844,6494.502206325531,3106.488452196121,0.7869946956634521,0.0 -20200,0.08392026,0.030184545,,,,,,,,,,,,,,,,, -20300,0.09206973,0.03074185,,,,,,,,,,,,,,,,, -20400,0.10978636,0.033885747,,,,,,,,,,,,,,,,, -20500,0.06848459,0.028268164,,,,,,,,,,,,,,,,, -20600,0.11011211,0.028243497,,,,,,,,,,,,,,,,, -20700,0.068097204,0.028934805,,,,,,,,,,,,,,,,, -20800,0.08896868,0.028404983,,,,,,,,,,,,,,,,, -20892,,,0.9913666248321532,0.0281404834240674,0.4477376798318661,0.9866806268692015,0.0450030602514743,0.2594265373774643,43793.0,0.985849916934967,0.0476750880479812,0.2547641389335,43793.0,6734.702629804611,9953.830657482147,6734.702629804611,3217.7399446964264,0.8169877529144287,0.0 -20900,0.084875524,0.027024703,,,,,,,,,,,,,,,,, -21000,0.14992775,0.031760696,,,,,,,,,,,,,,,,, -21100,0.09092924,0.03272212,,,,,,,,,,,,,,,,, -21200,0.0782376,0.030138202,,,,,,,,,,,,,,,,, -21300,0.09782265,0.02804834,,,,,,,,,,,,,,,,, -21400,0.081100784,0.031061515,,,,,,,,,,,,,,,,, -21500,0.10156834,0.028561989,,,,,,,,,,,,,,,,, -21600,0.082714975,0.030566622,,,,,,,,,,,,,,,,, -21632,,,0.9910850524902344,0.029171073809266,0.4186773832495749,0.986506462097168,0.0445950105786323,0.260642723999837,43793.0,0.9857699275016784,0.0473132506012916,0.2545662576834401,43793.0,6974.657138586044,10306.164506912231,6974.657138586044,3330.0670692920685,0.84893798828125,0.0 -21700,0.102859765,0.028859401,,,,,,,,,,,,,,,,, -21800,0.087606,0.028921383,,,,,,,,,,,,,,,,, -21900,0.11623614,0.0292548,,,,,,,,,,,,,,,,, -22000,0.09502483,0.02616194,,,,,,,,,,,,,,,,, -22100,0.087685846,0.03141291,,,,,,,,,,,,,,,,, -22200,0.10018681,0.03302588,,,,,,,,,,,,,,,,, -22300,0.093601115,0.03169868,,,,,,,,,,,,,,,,, -22380,,,0.9910112023353576,0.0293988939374685,0.413999674447663,0.9867098927497864,0.0447753965854644,0.2613197464129227,43793.0,0.9858149886131288,0.0473951138556003,0.260235189498801,43793.0,7214.641557455063,10657.6014316082,7214.641557455063,3441.468363761902,0.8800814151763916,0.0 -22400,0.09637475,0.029803948,,,,,,,,,,,,,,,,, -22500,0.09029157,0.03258139,,,,,,,,,,,,,,,,, -22600,0.077512555,0.028492698,,,,,,,,,,,,,,,,, -22700,0.1257537,0.02741359,,,,,,,,,,,,,,,,, -22800,0.11008849,0.032226577,,,,,,,,,,,,,,,,, -22900,0.08040983,0.027526896,,,,,,,,,,,,,,,,, -23000,0.11904879,0.032666974,,,,,,,,,,,,,,,,, -23100,0.08651732,0.027417576,,,,,,,,,,,,,,,,, -23133,,,0.9910687804222108,0.0292721986770629,0.427519522785992,0.98666113615036,0.0447103604674339,0.2611749950881226,43793.0,0.9857193827629088,0.047521024942398,0.256449131398079,43793.0,7454.6159999370575,11006.72515630722,7454.6159999370575,3550.564575910568,0.9132204055786132,0.0 -23200,0.10556473,0.031950343,,,,,,,,,,,,,,,,, -23300,0.09933172,0.030182647,,,,,,,,,,,,,,,,, -23400,0.1031243,0.026926072,,,,,,,,,,,,,,,,, -23500,0.13095902,0.028715637,,,,,,,,,,,,,,,,, -23600,0.11512202,0.027000545,,,,,,,,,,,,,,,,, -23700,0.0937226,0.030055163,,,,,,,,,,,,,,,,, -23800,0.08178323,0.030677535,,,,,,,,,,,,,,,,, -23882,,,0.9910465478897096,0.0290975049138069,0.439029657393755,0.986792266368866,0.0448956228792667,0.2642872001662467,43793.0,0.9859809279441832,0.0475788004696369,0.2583359648758452,43793.0,7694.792117834091,11361.706496953964,7694.792117834091,3665.318177700042,0.9451286792755128,0.0 -23900,0.09419485,0.028223457,,,,,,,,,,,,,,,,, -24000,0.11633725,0.028372567,,,,,,,,,,,,,,,,, -24100,0.17032328,0.03552303,,,,,,,,,,,,,,,,, -24200,0.09110016,0.029147143,,,,,,,,,,,,,,,,, -24300,0.09996577,0.027355459,,,,,,,,,,,,,,,,, -24400,0.103642955,0.03234976,,,,,,,,,,,,,,,,, -24500,0.10220784,0.031043772,,,,,,,,,,,,,,,,, -24600,0.08604578,0.028097518,,,,,,,,,,,,,,,,, -24616,,,0.9911187291145324,0.0287637133151292,0.4408214315236216,0.9866234064102172,0.0449697487056255,0.2635825782785605,43793.0,0.9857947826385498,0.0476335324347019,0.2508257458319819,43793.0,7934.79806470871,11710.37767481804,7934.79806470871,3773.929664850235,0.9760401248931884,0.0 -24700,0.10874573,0.02946469,,,,,,,,,,,,,,,,, -24800,0.10351778,0.031540908,,,,,,,,,,,,,,,,, -24900,0.09296248,0.030613806,,,,,,,,,,,,,,,,, -25000,0.1131015,0.028524516,,,,,,,,,,,,,,,,, -25100,0.15606874,0.031054024,,,,,,,,,,,,,,,,, -25200,0.09462523,0.02777894,,,,,,,,,,,,,,,,, -25300,0.12313176,0.028790755,,,,,,,,,,,,,,,,, -25358,,,0.991275668144226,0.0282500777393579,0.4428742902391496,0.986655056476593,0.0448307059705257,0.2637613334042595,43793.0,0.9857442378997804,0.0476606115698814,0.2516790343245108,43793.0,8174.913177490234,12064.129575967789,8174.913177490234,3887.515341043472,1.0077130794525146,0.0 -25400,0.119740576,0.028920054,,,,,,,,,,,,,,,,, -25500,0.09126381,0.028110642,,,,,,,,,,,,,,,,, -25600,0.083176546,0.028297639,,,,,,,,,,,,,,,,, -25700,0.09677001,0.028150607,,,,,,,,,,,,,,,,, -25800,0.09121842,0.030133432,,,,,,,,,,,,,,,,, -25900,0.09693356,0.030752705,,,,,,,,,,,,,,,,, -26000,0.10431232,0.0313055,,,,,,,,,,,,,,,,, -26100,0.09911956,0.030114306,,,,,,,,,,,,,,,,, -26103,,,0.9914022088050842,0.0277662333101034,0.4523410083924511,0.9866623878479004,0.044797908514738,0.2641685438259154,43793.0,0.9858486652374268,0.0474232174456119,0.2543806985051701,43793.0,8414.918694972992,12413.05533671379,8414.918694972992,3996.383995056152,1.039290189743042,0.0 -26200,0.102336794,0.03215817,,,,,,,,,,,,,,,,, -26300,0.112399265,0.0289382,,,,,,,,,,,,,,,,, -26400,0.10999418,0.030498546,,,,,,,,,,,,,,,,, -26500,0.09141514,0.026150463,,,,,,,,,,,,,,,,, -26600,0.13226256,0.025873365,,,,,,,,,,,,,,,,, -26700,0.08534764,0.028231282,,,,,,,,,,,,,,,,, -26800,0.07593383,0.026361858,,,,,,,,,,,,,,,,, -26844,,,0.991820216178894,0.0265189036726951,0.4992254146540879,0.986660361289978,0.0447358042001724,0.2714495760771454,43793.0,0.9857585430145264,0.0474692322313785,0.2568551148478193,43793.0,8655.010743618011,12761.361777067184,8655.010743618011,4104.547461509705,1.0705244541168213,0.0 -26900,0.077196494,0.028586509,,,,,,,,,,,,,,,,, -27000,0.10251771,0.029890772,,,,,,,,,,,,,,,,, -27100,0.07365918,0.026633885,,,,,,,,,,,,,,,,, -27200,0.12296347,0.030178718,,,,,,,,,,,,,,,,, -27300,0.098914295,0.027435431,,,,,,,,,,,,,,,,, -27400,0.098261125,0.027913138,,,,,,,,,,,,,,,,, -27500,0.12988836,0.025849385,,,,,,,,,,,,,,,,, -27591,,,0.99174964427948,0.0268844552338123,0.4890036321106513,0.98653244972229,0.0451820157468318,0.2685428630516191,43793.0,0.9857147336006165,0.0480520650744438,0.2612988077710391,43793.0,8895.237357616425,13110.495953798294,8895.237357616425,4213.402389764786,1.1035490036010742,0.0 -27600,0.11069943,0.029045923,,,,,,,,,,,,,,,,, -27700,0.100468375,0.030575853,,,,,,,,,,,,,,,,, -27800,0.098413154,0.02750464,,,,,,,,,,,,,,,,, -27900,0.09825123,0.029657189,,,,,,,,,,,,,,,,, -28000,0.08567467,0.027965559,,,,,,,,,,,,,,,,, -28100,0.09788155,0.027739445,,,,,,,,,,,,,,,,, -28200,0.14617562,0.030862942,,,,,,,,,,,,,,,,, -28300,0.1260514,0.02879449,,,,,,,,,,,,,,,,, -28339,,,0.991543173789978,0.0273032449185848,0.4728328304098546,0.9867244958877563,0.0452097095549106,0.267856911482501,43793.0,0.9858225584030152,0.0483605340123176,0.2572859701949634,43793.0,9135.220917224884,13459.781569480896,9135.220917224884,4322.653234481812,1.1350586414337158,0.0 -28400,0.14656901,0.028908156,,,,,,,,,,,,,,,,, -28500,0.08688746,0.028070614,,,,,,,,,,,,,,,,, -28600,0.099576846,0.026053337,,,,,,,,,,,,,,,,, -28700,0.10061708,0.027362399,,,,,,,,,,,,,,,,, -28800,0.10604052,0.03011609,,,,,,,,,,,,,,,,, -28900,0.09493108,0.02814605,,,,,,,,,,,,,,,,, -29000,0.09203853,0.030322965,,,,,,,,,,,,,,,,, -29087,,,0.9913496375083924,0.0279213413596153,0.446704589850163,0.9866745471954346,0.0448725298047065,0.2700355602727555,43793.0,0.9858547449111938,0.0476239696145057,0.2617479458606887,43793.0,9375.281760692596,13806.64552783966,9375.281760692596,4429.404903411865,1.1667847633361816,0.0 -29100,0.09504881,0.027095802,,,,,,,,,,,,,,,,, -29200,0.08987166,0.025102124,,,,,,,,,,,,,,,,, -29300,0.21860233,0.029266678,,,,,,,,,,,,,,,,, -29400,0.09597576,0.027095618,,,,,,,,,,,,,,,,, -29500,0.08937205,0.030128159,,,,,,,,,,,,,,,,, -29600,0.080559954,0.027339479,,,,,,,,,,,,,,,,, -29700,0.094141036,0.027300373,,,,,,,,,,,,,,,,, -29800,0.092050955,0.02616923,,,,,,,,,,,,,,,,, -29841,,,0.9912735819816588,0.0282083544880151,0.4498426126219626,0.9865994453430176,0.0453692749142646,0.2591981850120724,43793.0,0.9858208894729614,0.0481179356575012,0.2530234007583115,43793.0,9615.312079906464,14153.5893740654,9615.312079906464,4536.267310619354,1.1982874870300293,0.0 -29900,0.09886873,0.028205233,,,,,,,,,,,,,,,,, -30000,0.14068325,0.03067871,,,,,,,,,,,,,,,,, -30100,0.11737578,0.029176699,,,,,,,,,,,,,,,,, -30200,0.19978118,0.029761849,,,,,,,,,,,,,,,,, -30300,0.08503668,0.03018526,,,,,,,,,,,,,,,,, -30400,0.10209632,0.025828866,,,,,,,,,,,,,,,,, -30500,0.1238424,0.026519235,,,,,,,,,,,,,,,,, -30587,,,0.9914898872375488,0.0273872166872024,0.4727808883611701,0.986668050289154,0.0453085117042064,0.2628570416310531,43793.0,0.9857636094093324,0.0484786257147789,0.254502359861954,43793.0,9855.56041288376,14503.186990499496,9855.56041288376,4645.564736843109,1.230731964111328,0.0 -30600,0.09585542,0.028734518,,,,,,,,,,,,,,,,, -30700,0.09568805,0.029588057,,,,,,,,,,,,,,,,, -30800,0.12942646,0.027734213,,,,,,,,,,,,,,,,, -30900,0.08471707,0.025562532,,,,,,,,,,,,,,,,, -31000,0.1007007,0.026898105,,,,,,,,,,,,,,,,, -31100,0.116816305,0.023897843,,,,,,,,,,,,,,,,, -31200,0.10868389,0.026628463,,,,,,,,,,,,,,,,, -31300,0.115255505,0.029235322,,,,,,,,,,,,,,,,, -31331,,,0.9915563464164734,0.0272246785461902,0.4779938282806382,0.986726939678192,0.0449745245277881,0.2681475382524397,43793.0,0.9859392046928406,0.0476431138813495,0.2611744513547663,43793.0,10095.512593507769,14847.700211763382,10095.512593507769,4750.071592330933,1.2652955055236816,0.0 -31400,0.11731528,0.027372262,,,,,,,,,,,,,,,,, -31500,0.10343293,0.028206632,,,,,,,,,,,,,,,,, -31600,0.084769234,0.02683612,,,,,,,,,,,,,,,,, -31700,0.09479092,0.026970498,,,,,,,,,,,,,,,,, -31800,0.18109362,0.02853073,,,,,,,,,,,,,,,,, -31900,0.12622172,0.030074028,,,,,,,,,,,,,,,,, -32000,0.102393076,0.027072085,,,,,,,,,,,,,,,,, -32084,,,0.9915878772735596,0.0269397553056478,0.4846764472143901,0.9865304231643676,0.0454461760818958,0.2681769855204736,43793.0,0.9857879877090454,0.0480624511837959,0.2527473212298942,43793.0,10335.648624420166,15194.996549367905,10335.648624420166,4857.177111625671,1.300189971923828,0.0 -32100,0.09786163,0.027382553,,,,,,,,,,,,,,,,, -32200,0.11015079,0.028108679,,,,,,,,,,,,,,,,, -32300,0.1155535,0.028454002,,,,,,,,,,,,,,,,, -32400,0.09763437,0.027571695,,,,,,,,,,,,,,,,, -32500,0.11050883,0.02475548,,,,,,,,,,,,,,,,, -32600,0.17590629,0.029236618,,,,,,,,,,,,,,,,, -32700,0.1365606,0.02352922,,,,,,,,,,,,,,,,, -32800,0.10471296,0.026285445,,,,,,,,,,,,,,,,, -32842,,,0.9919490218162536,0.0257465578615665,0.5127910015596717,0.9866639971733092,0.0449895188212394,0.2654056363679506,43793.0,0.9858031868934632,0.0478540509939193,0.2626084049281989,43793.0,10575.60150718689,15540.194693565369,10575.60150718689,4962.369988441467,1.332615613937378,0.0 -32900,0.100618534,0.031161495,,,,,,,,,,,,,,,,, -33000,0.087787405,0.02285283,,,,,,,,,,,,,,,,, -33100,0.09748227,0.028047279,,,,,,,,,,,,,,,,, -33200,0.104948595,0.027124995,,,,,,,,,,,,,,,,, -33300,0.109919086,0.028881,,,,,,,,,,,,,,,,, -33400,0.09403877,0.025756024,,,,,,,,,,,,,,,,, -33500,0.14245352,0.024111578,,,,,,,,,,,,,,,,, -33592,,,0.9920307397842408,0.0254134628921747,0.5067656067708026,0.9868279695510864,0.045044295489788,0.2658500783934601,43793.0,0.9859409332275392,0.0479836985468864,0.2574362116662635,43793.0,10815.79008245468,15885.7362678051,10815.79008245468,5067.671158790588,1.3651981353759766,0.0 -33600,0.10890853,0.028359165,,,,,,,,,,,,,,,,, -33700,0.09570089,0.025654238,,,,,,,,,,,,,,,,, -33800,0.117074914,0.026473152,,,,,,,,,,,,,,,,, -33900,0.10381877,0.026263988,,,,,,,,,,,,,,,,, -34000,0.13061623,0.028942263,,,,,,,,,,,,,,,,, -34100,0.12414586,0.028465562,,,,,,,,,,,,,,,,, -34200,0.13650149,0.025436917,,,,,,,,,,,,,,,,, -34300,0.13561,0.02853014,,,,,,,,,,,,,,,,, -34348,,,0.9923913478851318,0.0244236923754215,0.5342409464257092,0.9867200255393982,0.0453106351196765,0.2671155695587615,43793.0,0.985846996307373,0.0481024570763111,0.2553736626151114,43793.0,11056.026224136353,16232.710386753082,11056.026224136353,5174.355105876923,1.3994545936584473,0.0 -34400,0.12155726,0.029157227,,,,,,,,,,,,,,,,, -34500,0.095868975,0.026368864,,,,,,,,,,,,,,,,, -34600,0.15300892,0.030389363,,,,,,,,,,,,,,,,, -34700,0.10513388,0.026291292,,,,,,,,,,,,,,,,, -34800,0.10347259,0.029231872,,,,,,,,,,,,,,,,, -34900,0.10623486,0.025613759,,,,,,,,,,,,,,,,, -35000,0.116887555,0.02569254,,,,,,,,,,,,,,,,, -35099,,,0.9922991991043092,0.0248646866530179,0.5406791037882595,0.9865483045578004,0.0454125665128231,0.2769924562069861,43793.0,0.9857711791992188,0.0481272302567958,0.2624014605046485,43793.0,11296.04472565651,16582.92563867569,11296.04472565651,5284.490906953812,1.4411022663116455,0.0 -35100,0.10099718,0.022243628,,,,,,,,,,,,,,,,, -35200,0.0999389,0.026608767,,,,,,,,,,,,,,,,, -35300,0.10017306,0.024007414,,,,,,,,,,,,,,,,, -35400,0.09943647,0.025778241,,,,,,,,,,,,,,,,, -35500,0.11737779,0.025752373,,,,,,,,,,,,,,,,, -35600,0.12673493,0.026886515,,,,,,,,,,,,,,,,, -35700,0.13594131,0.026476867,,,,,,,,,,,,,,,,, -35800,0.10121665,0.026538447,,,,,,,,,,,,,,,,, -35851,,,0.9920918941497804,0.0254524517804384,0.5169067383191193,0.9867861866950988,0.0457796640694141,0.2736346291879268,43793.0,0.985917329788208,0.048857532441616,0.258611242636092,43793.0,11536.114127397535,16931.671364068985,11536.114127397535,5393.113869190216,1.474443435668945,0.0 -35900,0.130899,0.024766799,,,,,,,,,,,,,,,,, -36000,0.12258773,0.02584539,,,,,,,,,,,,,,,,, -36100,0.116149575,0.025442386,,,,,,,,,,,,,,,,, -36200,0.12716354,0.026692314,,,,,,,,,,,,,,,,, -36300,0.10443445,0.024265915,,,,,,,,,,,,,,,,, -36400,0.115276046,0.02530532,,,,,,,,,,,,,,,,, -36500,0.09595693,0.022437269,,,,,,,,,,,,,,,,, -36600,0.10174371,0.024535697,,,,,,,,,,,,,,,,, -36606,,,0.9918672442436218,0.0259287282824516,0.4994652663793621,0.9868308305740356,0.0456037297844886,0.2706534819480221,43793.0,0.9859451055526732,0.0486032329499721,0.2502696397802891,43793.0,11776.257864952087,17281.49413871765,11776.257864952087,5502.739744663239,1.5083093643188477,0.0 -36700,0.11438386,0.025311233,,,,,,,,,,,,,,,,, -36800,0.13504282,0.025012916,,,,,,,,,,,,,,,,, -36900,0.098023206,0.023509795,,,,,,,,,,,,,,,,, -37000,0.111730866,0.02410519,,,,,,,,,,,,,,,,, -37100,0.103640005,0.02440521,,,,,,,,,,,,,,,,, -37200,0.12235517,0.028438332,,,,,,,,,,,,,,,,, -37300,0.13142262,0.026554821,,,,,,,,,,,,,,,,, -37345,,,0.9918071031570436,0.0262284409254789,0.4914974538833261,0.986737072467804,0.0458212234079837,0.269111434883714,43793.0,0.9858027696609496,0.0489179119467735,0.2587711548087072,43793.0,12016.488486766815,17630.008954048157,12016.488486766815,5610.971168756485,1.541703224182129,0.0 -37400,0.1108258,0.023611473,,,,,,,,,,,,,,,,, -37500,0.10716722,0.027397718,,,,,,,,,,,,,,,,, -37600,0.13214794,0.030047985,,,,,,,,,,,,,,,,, -37700,0.115239546,0.024144672,,,,,,,,,,,,,,,,, -37800,0.13711286,0.024640212,,,,,,,,,,,,,,,,, -37900,0.11502521,0.022943046,,,,,,,,,,,,,,,,, -38000,0.11055653,0.023621641,,,,,,,,,,,,,,,,, -38090,,,0.991980254650116,0.0255869254469871,0.5115828856662827,0.9867179989814758,0.0459587611258029,0.2713995649094152,43793.0,0.9859261512756348,0.048667199909687,0.2615826043412462,43793.0,12256.664202690125,17973.409933567047,12256.664202690125,5714.142570018768,1.5757763385772705,0.0 -38100,0.13675305,0.024299935,,,,,,,,,,,,,,,,, -38200,0.13667673,0.028256102,,,,,,,,,,,,,,,,, -38300,0.11306944,0.023928365,,,,,,,,,,,,,,,,, -38400,0.11826993,0.024009189,,,,,,,,,,,,,,,,, -38500,0.13263628,0.026845125,,,,,,,,,,,,,,,,, -38600,0.11999197,0.02386947,,,,,,,,,,,,,,,,, -38700,0.1424192,0.024631198,,,,,,,,,,,,,,,,, -38800,0.11501854,0.0238901,,,,,,,,,,,,,,,,, -38833,,,0.9920877814292908,0.0252947751432657,0.521443788126027,0.986579179763794,0.0461673587560653,0.2693436348083421,43793.0,0.9857484102249146,0.0489294081926345,0.2565882521212089,43793.0,12496.846721887589,18318.59648966789,12496.846721887589,5819.092994213104,1.608827829360962,0.0 -38900,0.1422826,0.025778214,,,,,,,,,,,,,,,,, -39000,0.11691188,0.028440693,,,,,,,,,,,,,,,,, -39100,0.12787353,0.027667556,,,,,,,,,,,,,,,,, -39200,0.1257471,0.02241007,,,,,,,,,,,,,,,,, -39300,0.1190355,0.02339757,,,,,,,,,,,,,,,,, -39400,0.12719679,0.024904637,,,,,,,,,,,,,,,,, -39500,0.12784824,0.024868907,,,,,,,,,,,,,,,,, -39566,,,0.9922082424163818,0.0247992202639579,0.5314044332998933,0.9865458607673644,0.0463249869644641,0.267102513934124,43793.0,0.985687792301178,0.0492235198616981,0.25741389805365,43793.0,12737.104538679125,18671.7216861248,12737.104538679125,5931.902855873108,1.6432619094848633,0.0 -39600,0.18530376,0.026045356,,,,,,,,,,,,,,,,, -39700,0.11099311,0.025108388,,,,,,,,,,,,,,,,, -39800,0.10603078,0.02302197,,,,,,,,,,,,,,,,, -39900,0.11399189,0.025626495,,,,,,,,,,,,,,,,, -40000,0.11353357,0.025159381,,,,,,,,,,,,,,,,, -40100,0.14003937,0.02455127,,,,,,,,,,,,,,,,, -40200,0.15407614,0.024201948,,,,,,,,,,,,,,,,, -40300,0.16052458,0.027700786,,,,,,,,,,,,,,,,, -40305,,,0.9925952553749084,0.023490697145462,0.5565252775312892,0.986531674861908,0.0463293083012104,0.2676215261520471,43793.0,0.9857577085494996,0.0494484901428222,0.2498049455988531,43793.0,12977.295271873474,19022.058073043823,12977.295271873474,6041.992013454437,1.677354335784912,0.0 -40400,0.14937435,0.024542449,,,,,,,,,,,,,,,,, -40500,0.14651582,0.024903046,,,,,,,,,,,,,,,,, -40600,0.15570387,0.025623921,,,,,,,,,,,,,,,,, -40700,0.12781794,0.028025482,,,,,,,,,,,,,,,,, -40800,0.12760556,0.024742858,,,,,,,,,,,,,,,,, -40900,0.117486686,0.02211987,,,,,,,,,,,,,,,,, -41000,0.1165993,0.027500223,,,,,,,,,,,,,,,,, -41054,,,0.9928678274154664,0.0224870759993791,0.581720308999962,0.9867366552352904,0.0461860261857509,0.2734461547168441,43793.0,0.9858141541481018,0.049161035567522,0.2650616828523954,43793.0,13217.297476530077,19367.962020874023,13217.297476530077,6147.838949918747,1.7124638557434082,0.0 -41100,0.15642451,0.027529765,,,,,,,,,,,,,,,,, -41200,0.13195595,0.025344666,,,,,,,,,,,,,,,,, -41300,0.12335857,0.023474213,,,,,,,,,,,,,,,,, -41400,0.13725384,0.027269006,,,,,,,,,,,,,,,,, -41500,0.15298927,0.026329666,,,,,,,,,,,,,,,,, -41600,0.15722829,0.022938123,,,,,,,,,,,,,,,,, -41700,0.12920575,0.025649704,,,,,,,,,,,,,,,,, -41800,0.13324723,0.025401888,,,,,,,,,,,,,,,,, -41810,,,0.9927249550819396,0.0229897052049636,0.5715105094814787,0.9866583347320556,0.0467988699674606,0.2642473790358029,43793.0,0.9857454895973206,0.049967210739851,0.2580405016937968,43793.0,13457.432827711104,19713.732904434204,13457.432827711104,6253.418445587158,1.7486231327056885,0.0 -41900,0.14451607,0.02373725,,,,,,,,,,,,,,,,, -42000,0.13230267,0.024018798,,,,,,,,,,,,,,,,, -42100,0.18477766,0.027270261,,,,,,,,,,,,,,,,, -42200,0.114688694,0.02544942,,,,,,,,,,,,,,,,, -42300,0.118469104,0.025451649,,,,,,,,,,,,,,,,, -42400,0.14738362,0.026002927,,,,,,,,,,,,,,,,, -42500,0.13626416,0.025935493,,,,,,,,,,,,,,,,, -42556,,,0.992652952671051,0.0233211573213338,0.5512938937985699,0.9865283966064452,0.0467281602323055,0.2636467984267341,43793.0,0.9856923818588256,0.0494961068034172,0.2557325618023192,43793.0,13697.439344882963,20066.57444548607,13697.439344882963,6366.196691989899,1.7837131023406982,0.0 -42600,0.13491765,0.02461847,,,,,,,,,,,,,,,,, -42700,0.19940166,0.025354942,,,,,,,,,,,,,,,,, -42800,0.11629461,0.020303115,,,,,,,,,,,,,,,,, -42900,0.13554342,0.024075858,,,,,,,,,,,,,,,,, -43000,0.15947723,0.022335691,,,,,,,,,,,,,,,,, -43100,0.14733939,0.025244731,,,,,,,,,,,,,,,,, -43200,0.14028287,0.023811353,,,,,,,,,,,,,,,,, -43300,0.13492048,0.021913934,,,,,,,,,,,,,,,,, -43306,,,0.992501735687256,0.0237785894423723,0.5646825159913036,0.9867951273918152,0.0462379977107048,0.2821592979825692,43793.0,0.985849916934967,0.0496870279312133,0.2605483988073737,43793.0,13937.537743330002,20413.0918610096,13937.537743330002,6472.557656049728,1.8214752674102783,0.0 -43400,0.15631202,0.026612954,,,,,,,,,,,,,,,,, -43500,0.12356666,0.021382866,,,,,,,,,,,,,,,,, -43600,0.15657762,0.026505344,,,,,,,,,,,,,,,,, -43700,0.20213866,0.024987828,,,,,,,,,,,,,,,,, -43800,0.18685715,0.022975715,,,,,,,,,,,,,,,,, -43900,0.15118101,0.02469034,,,,,,,,,,,,,,,,, -44000,0.11393057,0.023185773,,,,,,,,,,,,,,,,, -44051,,,0.9924307465553284,0.0238613691180944,0.555371738352912,0.9867119193077089,0.0468770824372768,0.2732673112388979,43793.0,0.9858596324920654,0.0500569939613342,0.2565946215782902,43793.0,14177.75285601616,20761.1159992218,14177.75285601616,6580.312434911728,1.8559012413024905,0.0 -44100,0.14630571,0.023971094,,,,,,,,,,,,,,,,, -44200,0.12974858,0.024545874,,,,,,,,,,,,,,,,, -44300,0.15842088,0.026676277,,,,,,,,,,,,,,,,, -44400,0.13384622,0.02127945,,,,,,,,,,,,,,,,, -44500,0.12747136,0.026508536,,,,,,,,,,,,,,,,, -44600,0.15195505,0.023233509,,,,,,,,,,,,,,,,, -44700,0.11747961,0.023229755,,,,,,,,,,,,,,,,, -44798,,,0.992325484752655,0.0240122228860855,0.5586604499035724,0.986703395843506,0.0472040958702564,0.2669402476371244,43793.0,0.9858166575431824,0.0504205785691738,0.2533213945513642,43793.0,14417.734598636627,21108.37005758285,14417.734598636627,6687.529041051865,1.891420125961304,0.0 -44800,0.17833823,0.026846882,,,,,,,,,,,,,,,,, -44900,0.13445398,0.02265853,,,,,,,,,,,,,,,,, -45000,0.14252636,0.022920912,,,,,,,,,,,,,,,,, -45100,0.14773712,0.023304228,,,,,,,,,,,,,,,,, -45200,0.14949077,0.025109166,,,,,,,,,,,,,,,,, -45300,0.19706185,0.022402717,,,,,,,,,,,,,,,,, -45400,0.15123585,0.023378368,,,,,,,,,,,,,,,,, -45500,0.13424587,0.023098875,,,,,,,,,,,,,,,,, -45551,,,0.9925184845924376,0.0234085004776716,0.5609331101563908,0.9865624904632568,0.0477843768894672,0.2695785599802789,43793.0,0.9856654405593872,0.0509008951485157,0.2581821676206569,43793.0,14657.84571671486,21455.67155122757,14657.84571671486,6794.665002822876,1.9260263442993164,0.0 -45600,0.14402984,0.027036121,,,,,,,,,,,,,,,,, -45700,0.24376604,0.022269968,,,,,,,,,,,,,,,,, -45800,0.16220836,0.023585165,,,,,,,,,,,,,,,,, -45900,0.14022377,0.0220131,,,,,,,,,,,,,,,,, -46000,0.14247175,0.024099626,,,,,,,,,,,,,,,,, -46100,0.12224053,0.022494633,,,,,,,,,,,,,,,,, -46200,0.14718057,0.023141008,,,,,,,,,,,,,,,,, -46300,0.17069115,0.023036899,,,,,,,,,,,,,,,,, -46308,,,0.9927812218666076,0.0225784312933683,0.5735938940547376,0.986531674861908,0.0475295148789882,0.2711253417235155,43793.0,0.9856970310211182,0.050856564193964,0.2524435396842576,43793.0,14897.96158337593,21801.37257266045,14897.96158337593,6900.194499254227,1.9617371559143064,0.0 -46400,0.13399746,0.022935681,,,,,,,,,,,,,,,,, -46500,0.17621039,0.023930836,,,,,,,,,,,,,,,,, -46600,0.1609307,0.023881456,,,,,,,,,,,,,,,,, -46700,0.14370541,0.021020733,,,,,,,,,,,,,,,,, -46800,0.1395501,0.022335593,,,,,,,,,,,,,,,,, -46900,0.12563887,0.02395445,,,,,,,,,,,,,,,,, -47000,0.14916591,0.022813056,,,,,,,,,,,,,,,,, -47046,,,0.9932058453559875,0.0213255230337381,0.6142257470793187,0.9863327741622924,0.0477554388344287,0.2685979413267775,43793.0,0.9854986667633056,0.0510683432221412,0.2488496460511818,43793.0,15138.184672355652,22149.626344442368,15138.184672355652,7008.161518335342,2.0013952255249023,0.0 -47100,0.15157686,0.02233933,,,,,,,,,,,,,,,,, -47200,0.16642232,0.02344006,,,,,,,,,,,,,,,,, -47300,0.1889987,0.022129038,,,,,,,,,,,,,,,,, -47400,0.13716142,0.018573446,,,,,,,,,,,,,,,,, -47500,0.12776445,0.021537751,,,,,,,,,,,,,,,,, -47600,0.13969463,0.020715095,,,,,,,,,,,,,,,,, -47700,0.16716145,0.023135591,,,,,,,,,,,,,,,,, -47797,,,0.9933127760887146,0.0208547096699476,0.6183224958318798,0.9866700768470764,0.04813152551651,0.2735240056114583,43793.0,0.9856852293014526,0.051686979830265,0.2522580826705521,43793.0,15378.320225954056,22497.388407230377,15378.320225954056,7115.730270385742,2.038448095321656,0.0 -47800,0.2042473,0.02383445,,,,,,,,,,,,,,,,, -47900,0.15185812,0.02253408,,,,,,,,,,,,,,,,, -48000,0.16555804,0.022339616,,,,,,,,,,,,,,,,, -48100,0.15767074,0.021493614,,,,,,,,,,,,,,,,, -48200,0.123893686,0.020376842,,,,,,,,,,,,,,,,, -48300,0.21379879,0.025320705,,,,,,,,,,,,,,,,, -48400,0.1418642,0.0202673,,,,,,,,,,,,,,,,, -48500,0.16626593,0.023053315,,,,,,,,,,,,,,,,, -48545,,,0.9937769770622252,0.0196090713143348,0.657254856675352,0.9864935278892516,0.0484823510050773,0.2643268234435679,43793.0,0.9857058525085448,0.0517389625310897,0.2559660860530463,43793.0,15618.439760684969,22843.91304540634,15618.439760684969,7222.079265594482,2.07443904876709,0.0 -48600,0.16052492,0.022018408,,,,,,,,,,,,,,,,, -48700,0.1509278,0.022179365,,,,,,,,,,,,,,,,, -48800,0.16608839,0.023159245,,,,,,,,,,,,,,,,, -48900,0.12899667,0.021306293,,,,,,,,,,,,,,,,, -49000,0.15568848,0.021195406,,,,,,,,,,,,,,,,, -49100,0.17193265,0.019880287,,,,,,,,,,,,,,,,, -49200,0.14009783,0.019335806,,,,,,,,,,,,,,,,, -49290,,,0.9935559630393982,0.0202613193541765,0.6167782802913997,0.9866542816162108,0.0487692281603813,0.2706925333525783,43793.0,0.9856612086296082,0.0520920641720294,0.2503146224776895,43793.0,15858.267110586166,23190.93936944008,15858.267110586166,7328.919096469879,2.4131574630737305,0.0 -49300,0.14342351,0.01986459,,,,,,,,,,,,,,,,, -49400,0.17009667,0.02099492,,,,,,,,,,,,,,,,, -49500,0.16247594,0.023927856,,,,,,,,,,,,,,,,, -49600,0.15058844,0.020913834,,,,,,,,,,,,,,,,, -49700,0.19435193,0.02132259,,,,,,,,,,,,,,,,, -49800,0.16104598,0.018763458,,,,,,,,,,,,,,,,, -49900,0.14295053,0.022770759,,,,,,,,,,,,,,,,, -50000,0.15471005,0.020028349,,,,,,,,,,,,,,,,, -50032,,,0.9932066798210144,0.0211821142584085,0.6060510859266648,0.9864216446876526,0.0490487702190876,0.2732556869305221,43793.0,0.9856587052345276,0.0523470975458622,0.2541978403734496,43793.0,16098.356865167618,23540.37098479271,16098.356865167618,7438.20524597168,2.4488024711608887,0.0 -50100,0.1754245,0.025277158,,,,,,,,,,,,,,,,, -50200,0.1479995,0.020199511,,,,,,,,,,,,,,,,, -50300,0.20067881,0.02044453,,,,,,,,,,,,,,,,, -50400,0.17847279,0.02325321,,,,,,,,,,,,,,,,, -50500,0.16736773,0.020811934,,,,,,,,,,,,,,,,, -50600,0.14537567,0.020511933,,,,,,,,,,,,,,,,, -50700,0.17296316,0.020308383,,,,,,,,,,,,,,,,, -50782,,,0.9932138323783876,0.0211225748062133,0.6076637757132363,0.9864943027496338,0.0495653636753559,0.2643919542063089,43793.0,0.9856283664703368,0.0528356507420539,0.2524023510446382,43793.0,16338.490710258484,23888.57576751709,16338.490710258484,7546.218659877777,2.485704183578491,0.0 -50800,0.21054506,0.023553899,,,,,,,,,,,,,,,,, -50900,0.22636674,0.022544198,,,,,,,,,,,,,,,,, -51000,0.18031855,0.021180842,,,,,,,,,,,,,,,,, -51100,0.16215351,0.02139309,,,,,,,,,,,,,,,,, -51200,0.1773092,0.020687195,,,,,,,,,,,,,,,,, -51300,0.17004305,0.019549271,,,,,,,,,,,,,,,,, -51400,0.18500452,0.020245772,,,,,,,,,,,,,,,,, -51500,0.17753378,0.020693446,,,,,,,,,,,,,,,,, -51533,,,0.9932173490524292,0.0210808832198381,0.6193763045917757,0.986382246017456,0.0498485118150711,0.2682000829620698,43793.0,0.9855205416679382,0.0532411523163318,0.2527073064198517,43793.0,16578.616221904755,24237.691730499268,16578.616221904755,7655.152938842773,2.5216963291168213,0.0 -51600,0.18737327,0.020946214,,,,,,,,,,,,,,,,, -51700,0.2253934,0.019990996,,,,,,,,,,,,,,,,, -51800,0.1741776,0.020857163,,,,,,,,,,,,,,,,, -51900,0.17217277,0.018927805,,,,,,,,,,,,,,,,, -52000,0.20032667,0.021285316,,,,,,,,,,,,,,,,, -52100,0.17521629,0.019264719,,,,,,,,,,,,,,,,, -52200,0.20452657,0.021327404,,,,,,,,,,,,,,,,, -52279,,,0.9932751059532166,0.0206806678324937,0.6200296858576568,0.9865409731864928,0.0498331263661384,0.2641053459337334,43793.0,0.9856444001197816,0.0534443370997905,0.2497437597936006,43793.0,16818.712033748627,24579.483705997467,16818.712033748627,7756.792923927307,2.5575668811798096,0.0 -52300,0.1984582,0.02137154,,,,,,,,,,,,,,,,, -52400,0.16215527,0.02062321,,,,,,,,,,,,,,,,, -52500,0.19027826,0.01806189,,,,,,,,,,,,,,,,, -52600,0.17172533,0.017841486,,,,,,,,,,,,,,,,, -52700,0.167183,0.019187534,,,,,,,,,,,,,,,,, -52800,0.24893631,0.022295479,,,,,,,,,,,,,,,,, -52900,0.17084134,0.020159181,,,,,,,,,,,,,,,,, -53000,0.24278271,0.022667501,,,,,,,,,,,,,,,,, -53026,,,0.9934790134429932,0.0199831202626228,0.6333494312558867,0.9864338040351868,0.0501837916672229,0.2680832358381055,43793.0,0.9855024218559264,0.0538224689662456,0.2479787630190198,43793.0,17058.962177991867,24927.917595624924,17058.962177991867,7864.91987657547,2.594414710998535,0.0 -53100,0.25487897,0.019202726,,,,,,,,,,,,,,,,, -53200,0.18583392,0.016397288,,,,,,,,,,,,,,,,, -53300,0.21149874,0.017613353,,,,,,,,,,,,,,,,, -53400,0.19886431,0.019822665,,,,,,,,,,,,,,,,, -53500,0.19465232,0.020796293,,,,,,,,,,,,,,,,, -53600,0.20071456,0.020289956,,,,,,,,,,,,,,,,, -53700,0.19440821,0.019639108,,,,,,,,,,,,,,,,, -53772,,,0.9938586950302124,0.0189154967665672,0.6609112039038493,0.986370086669922,0.0502746887505054,0.2667339680899637,43793.0,0.9854165315628052,0.0540853217244148,0.2510609877039152,43793.0,17298.90731549263,25274.465651512142,17298.90731549263,7971.463971138,2.6330935955047607,0.0 -53800,0.16434348,0.017995574,,,,,,,,,,,,,,,,, -53900,0.177148,0.01890974,,,,,,,,,,,,,,,,, -54000,0.23482433,0.020461457,,,,,,,,,,,,,,,,, -54100,0.20498776,0.018539922,,,,,,,,,,,,,,,,, -54200,0.2221066,0.0187823,,,,,,,,,,,,,,,,, -54300,0.1883635,0.021731881,,,,,,,,,,,,,,,,, -54400,0.21518591,0.019991785,,,,,,,,,,,,,,,,, -54500,0.22759075,0.018700879,,,,,,,,,,,,,,,,, -54520,,,0.9943572282791138,0.0174389947205781,0.6922163307395394,0.9864155650138856,0.0508133694529533,0.262530649754857,43793.0,0.9854194521903992,0.0547283142805099,0.2489122732522624,43793.0,17538.941791057587,25618.02638888359,17538.941791057587,8074.933919668198,2.669600009918213,0.0 -54600,0.2288263,0.015949639,,,,,,,,,,,,,,,,, -54700,0.20278609,0.020036077,,,,,,,,,,,,,,,,, -54800,0.2329768,0.018662296,,,,,,,,,,,,,,,,, -54900,0.20044935,0.019160803,,,,,,,,,,,,,,,,, -55000,0.22090724,0.01904103,,,,,,,,,,,,,,,,, -55100,0.21271025,0.019341998,,,,,,,,,,,,,,,,, -55200,0.19175869,0.01772143,,,,,,,,,,,,,,,,, -55267,,,0.9947988390922546,0.0162368603050708,0.7333053532778322,0.9862548112869264,0.0513133332133293,0.2627462523088922,43793.0,0.9853533506393432,0.0552709624171257,0.2484196984421053,43793.0,17779.046933174133,25964.97895693779,17779.046933174133,8181.723499298096,2.7075655460357666,0.0 -55300,0.21499027,0.01879877,,,,,,,,,,,,,,,,, -55400,0.22784226,0.021656398,,,,,,,,,,,,,,,,, -55500,0.1881943,0.017866297,,,,,,,,,,,,,,,,, -55600,0.21472976,0.017768444,,,,,,,,,,,,,,,,, -55700,0.19372799,0.018774308,,,,,,,,,,,,,,,,, -55800,0.19987313,0.016671324,,,,,,,,,,,,,,,,, -55900,0.19001158,0.015459053,,,,,,,,,,,,,,,,, -56000,0.18828148,0.015452033,,,,,,,,,,,,,,,,, -56017,,,0.9949828386306764,0.0157982297241687,0.7273481342521675,0.9862483143806458,0.0518519133329391,0.2608411574989082,43793.0,0.9853945970535278,0.0558489337563514,0.2532018150713649,43793.0,18019.02331376076,26313.56055831909,18019.02331376076,8290.270651817322,2.7457399368286133,0.0 -56100,0.19737788,0.018561747,,,,,,,,,,,,,,,,, -56200,0.22346011,0.018832916,,,,,,,,,,,,,,,,, -56300,0.19951978,0.015140264,,,,,,,,,,,,,,,,, -56400,0.19286299,0.016601708,,,,,,,,,,,,,,,,, -56500,0.18963781,0.017638879,,,,,,,,,,,,,,,,, -56600,0.22036724,0.018290155,,,,,,,,,,,,,,,,, -56700,0.17831491,0.016774911,,,,,,,,,,,,,,,,, -56766,,,0.9939604997634888,0.0184627342969179,0.6661507763472906,0.9862211346626282,0.0520812310278415,0.2654805796664402,43793.0,0.9853697419166564,0.0558772198855876,0.2519444893929446,43793.0,18259.24850225449,26663.653777122498,18259.24850225449,8400.08038520813,2.7842960357666016,0.0 -56800,0.20971093,0.016925527,,,,,,,,,,,,,,,,, -56900,0.2469034,0.019293783,,,,,,,,,,,,,,,,, -57000,0.21751586,0.019023016,,,,,,,,,,,,,,,,, -57100,0.20274226,0.016092962,,,,,,,,,,,,,,,,, -57200,0.22164497,0.017546287,,,,,,,,,,,,,,,,, -57300,0.19375291,0.015872078,,,,,,,,,,,,,,,,, -57400,0.24418765,0.015394491,,,,,,,,,,,,,,,,, -57446,,,,,,,,,,,,,,18477.078250169754,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/eval_measurements.csv deleted file mode 100644 index afdc91c81..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -107.80116128921507,0.0,12.185586929321287,1,0,12.185586929321287,0.5047287344932556,0.7438743114471436,0.0269311415807948,43793,119.9868142604828,0.5091689229011536,0.7391471266746521,0.0227705634173207,0.5064647793769836,0.7422076463699341,0.0252988304021323,43793 -215.41246700286865,0.0240700244903564,252.15884280204773,745,0,252.15884280204773,0.983149230480194,0.0662031918764114,0.0460700617813824,43793,467.6156971454621,0.9868181943893432,0.0527527704834938,0.0483171707343765,0.9841492176055908,0.063015766441822,0.0477191124991385,43793 -321.669025182724,0.0508973598480224,492.1526997089386,1490,0,492.1526997089386,0.98329496383667,0.0615998916327953,0.0745604968623183,43793,813.9127051830292,0.9869986772537231,0.0491938292980194,0.0780179580921825,0.9842535257339478,0.0585653334856033,0.0744461513894765,43793 -429.0665693283081,0.0775086879730224,732.1452407836914,2240,0,732.1452407836914,0.9836757183074952,0.0591330528259277,0.1034593614915365,43793,1161.3493824005127,0.9874287843704224,0.0456875301897525,0.1117813627247076,0.984661102294922,0.0556657016277313,0.1072617279992498,43793 -531.9732112884521,0.1048529148101806,972.3031947612762,2992,0,972.3031947612762,0.983911156654358,0.0559541843831539,0.1336785072545706,43793,1504.4612703323364,0.987829566001892,0.042909987270832,0.1490976674809723,0.9849030375480652,0.0530146658420562,0.1296165186489274,43793 -637.5870983600616,0.1319179534912109,1212.3853332996368,3743,0,1212.3853332996368,0.9843466877937316,0.0536787956953048,0.1512781169298654,43793,1850.2039790153503,0.9880852103233336,0.0412999391555786,0.1683103936134479,0.9852334856987,0.0507333017885685,0.1494133137828359,43793 -744.434494972229,0.1595597267150879,1452.6483166217804,4498,0,1452.6483166217804,0.98445326089859,0.0527374073863029,0.1563751314541928,43793,2197.36225771904,0.9884041547775269,0.0404745452105999,0.1948218218889239,0.9853795766830444,0.0500588230788707,0.1574081830732877,43793 -851.1025938987732,0.1864519119262695,1692.851990699768,5254,0,1692.851990699768,0.9845661520957948,0.0523319989442825,0.1708755911669754,43793,2544.280503988266,0.9886635541915894,0.0394278690218925,0.1902771159028346,0.9854867458343506,0.0497402250766754,0.165965228441913,43793 -954.8695249557496,0.2135336399078369,1933.0624508857727,6008,0,1933.0624508857727,0.9846979379653932,0.0519571416079998,0.1727666752466523,43793,2888.3051455020905,0.988634705543518,0.0391428656876087,0.1998674039412336,0.9856597185134888,0.0488583073019981,0.18008463135684,43793 -1062.309351682663,0.2424702644348144,2173.258081674576,6753,0,2173.258081674576,0.9844822883605956,0.0525984466075897,0.1626187375082621,43793,3235.989268064499,0.9883594512939452,0.0401516146957874,0.192082496409469,0.9853593111038208,0.0498590171337127,0.168534194680524,43793 -1166.7139718532562,0.2698376178741455,2413.3347930908203,7492,0,2413.3347930908203,0.98464697599411,0.0516998767852783,0.1845904393962085,43793,3580.518236398697,0.9885475635528564,0.0390818305313587,0.1999583433178592,0.9856422543525696,0.048810314387083,0.1828529834684011,43793 -1274.9142456054688,0.2987575531005859,2653.599544286728,8243,0,2653.599544286728,0.9848407506942748,0.0517675057053566,0.1818865562994291,43793,3929.0323138237,0.9886873364448548,0.0387571267783641,0.2114938379224503,0.9858253002166748,0.0486804917454719,0.1836186190699482,43793 -1383.1852016448977,0.3282186985015869,2893.7972707748413,9003,0,2893.7972707748413,0.9847792387008668,0.0514960400760173,0.1920021524379697,43793,4277.550271987915,0.9886422753334044,0.0389967672526836,0.2087046357847387,0.9857802391052246,0.0487343408167362,0.1894534784275115,43793 -1491.027760028839,0.3559896945953369,3133.9032673835754,9760,0,3133.9032673835754,0.9846423864364624,0.0516324192285537,0.1834430491442636,43793,4625.546536207199,0.988835632801056,0.0383559502661228,0.2161694728775924,0.9855862259864808,0.0487497597932815,0.1840018137573714,43793 -1596.7612025737762,0.384077787399292,3374.0375757217407,10503,0,3374.0375757217407,0.9848563075065612,0.0519847050309181,0.1847305208964719,43793,4971.461939096451,0.9887315034866332,0.0383016318082809,0.2269017371348451,0.9857413172721864,0.0489319637417793,0.1896570640462174,43793 -1705.79274225235,0.4135866165161133,3614.018998861313,11254,0,3614.018998861313,0.9848538041114808,0.0512711331248283,0.1845297036709734,43793,5320.525203466415,0.9888266921043396,0.0380459837615489,0.2298554811650588,0.9857701063156128,0.0484095476567745,0.1863925213489288,43793 -1810.805812358856,0.4426114559173584,3854.225952386856,12003,0,3854.225952386856,0.983836591243744,0.0538304969668388,0.1767750627461585,43793,5665.794174194336,0.988034188747406,0.0402605235576629,0.2159543490357493,0.9848340153694152,0.0508625023066997,0.1719574172968483,43793 -1917.3220510482788,0.4725079536437988,4094.2330226898193,12752,0,4094.2330226898193,0.9844094514846802,0.0528549253940582,0.1829934808872338,43793,6012.367550611496,0.9886979460716248,0.0385139733552932,0.212413276491084,0.9853869080543518,0.0496855191886425,0.1816565361883723,43793 -2025.405591011048,0.5020263195037842,4334.211813926697,13508,0,4334.211813926697,0.9848297834396362,0.0509868077933788,0.1925377068346132,43793,6360.479429960251,0.98878014087677,0.0384350679814815,0.2341694003337173,0.9857327938079834,0.0484113171696662,0.18682739612154,43793 -2134.202548503876,0.5312979221343994,4574.353425979614,14261,0,4574.353425979614,0.9848424196243286,0.0520608276128768,0.1777015158910258,43793,6709.469097614288,0.9888355135917664,0.0383663959801197,0.2215462984661317,0.9857721328735352,0.0490221828222274,0.1812746860597814,43793 -2244.846806049347,0.5642621517181396,4814.332322835922,14993,0,4814.332322835922,0.984917402267456,0.0510648749768734,0.1841147673501316,43793,7060.149246931076,0.9888824820518494,0.0379234850406646,0.2257129817583256,0.9858293533325196,0.0481456443667411,0.1915524691425417,43793 -2349.802227497101,0.5941531658172607,5054.521810770035,15741,0,5054.521810770035,0.9846630096435548,0.052205454558134,0.1708737522269427,43793,7405.344657182693,0.9887647032737732,0.0387309081852436,0.2083025958420364,0.9856645464897156,0.0490984246134758,0.1780559326371,43793 -2456.497183799744,0.624169111251831,5294.745937824249,16489,0,5294.745937824249,0.9848487377166748,0.0515647567808628,0.1753466000554308,43793,7752.313290119171,0.9889504313468932,0.037890437990427,0.2155687895466367,0.9857696890830994,0.0486704371869564,0.1825648845056662,43793 -2561.243638038636,0.6534569263458252,5534.832649469376,17242,0,5534.832649469376,0.9849043488502502,0.0510930605232715,0.1818741040714724,43793,8097.196264505386,0.9890180826187134,0.0376419089734554,0.2201965623201106,0.9857839345932008,0.0482412688434124,0.1835722451651859,43793 -2669.7339446544647,0.6886923313140869,5775.044799804688,17989,0,5775.044799804688,0.9846507906913756,0.0521830730140209,0.1831122481482856,43793,8445.953789234161,0.9887692928314208,0.0380432158708572,0.227398046227228,0.9855777025222778,0.0491019561886787,0.1862749647511783,43793 -2775.819948196411,0.7185335159301758,6015.091109275818,18740,0,6015.091109275818,0.984965443611145,0.0507509633898735,0.1886865838804157,43793,8792.135896921158,0.9890387058258056,0.0370751470327377,0.2429982094044376,0.9859284162521362,0.0477214083075523,0.1912126206648628,43793 -2880.935016393661,0.7496731281280518,6255.089984178543,19487,0,6255.089984178543,0.9846385717391968,0.0519570372998714,0.1779373353694832,43793,9137.301024913788,0.9886627197265624,0.0383425168693065,0.2338148047161424,0.9854279160499572,0.0490564405918121,0.1800967657073751,43793 -2989.8407900333405,0.7810525894165039,6495.277824878693,20231,0,6495.277824878693,0.9849616289138794,0.0513088703155517,0.1883093383042809,43793,9486.45058965683,0.988877534866333,0.037773884832859,0.2241492141391161,0.985895574092865,0.0481287725269794,0.1903862533615728,43793 -3096.3421173095703,0.8141474723815918,6735.258136510849,20989,0,6735.258136510849,0.9849637150764464,0.0508499443531036,0.1929505335081207,43793,9832.98583817482,0.9890268445014954,0.0376813746988773,0.2302380819582901,0.9859223365783693,0.0477700121700763,0.1952306653002494,43793 -3204.208570957184,0.8466379642486572,6975.445188045502,21742,0,6975.445188045502,0.9847792387008668,0.0514329299330711,0.1829139607213623,43793,10181.09177494049,0.9889109134674072,0.0379564836621284,0.228778463186817,0.9857218265533448,0.0485035479068756,0.1889965521575424,43793 -3309.9134817123413,0.8789269924163818,7215.630878448486,22481,0,7215.630878448486,0.984722375869751,0.0518376901745796,0.1856335490853918,43793,10527.035320281982,0.9888846278190612,0.0379314087331295,0.2274528755384141,0.9856629371643066,0.0488991737365722,0.1854297338063141,43793 -3414.523555278778,0.9095168113708496,7455.803935050964,23233,0,7455.803935050964,0.9848858118057252,0.0515651702880859,0.1854361962732914,43793,10871.868999242784,0.9888824224472046,0.0377452857792377,0.2347773893975781,0.9858500957489014,0.0483477227389812,0.1909631095072125,43793 -3517.945199251175,0.9403893947601318,7695.835782766342,23985,0,7695.835782766342,0.9850140810012816,0.0502418987452983,0.1828531633615252,43793,11215.373279094696,0.9890407919883728,0.0375085286796093,0.2314681183015991,0.9859633445739746,0.0474525094032287,0.1906305339511292,43793 -3625.518209695816,0.9731786251068116,7936.069957971573,24741,0,7936.069957971573,0.9850648045539856,0.0509517341852188,0.1927351888664488,43793,11563.233343362808,0.9891539812088012,0.0367405265569686,0.2390435111815593,0.985897183418274,0.0479989349842071,0.1873103703045062,43793 -3730.604122877121,1.4308359622955322,8175.831615447998,25490,0,8175.831615447998,0.985063135623932,0.0512276515364646,0.1845998296445171,43793,11908.558824062347,0.9891579151153564,0.0368153937160968,0.2493061489164509,0.9859349131584167,0.048080027103424,0.1947319230197419,43793 -3837.456799507141,1.4622154235839844,8415.865903615952,26244,0,8415.865903615952,0.9848095774650574,0.0511777698993682,0.1883528025284292,43793,12255.497112989426,0.9890082478523254,0.0371231213212013,0.2619470465234472,0.9857433438301086,0.0480293482542037,0.1981787973978353,43793 -3939.849988222122,1.493131399154663,8655.825356006622,27000,0,8655.825356006622,0.985078752040863,0.0505181364715099,0.1862611174670923,43793,12597.900453329086,0.9892248511314392,0.03641227632761,0.248816781902566,0.9859369397163392,0.047542754560709,0.1895510494609632,43793 -4043.614537715912,1.5242531299591064,8895.789571285248,27749,0,8895.789571285248,0.985129714012146,0.0505303442478179,0.1984891361260043,43793,12941.680038928986,0.9891148209571838,0.0368561372160911,0.2498081527409343,0.986055076122284,0.0475734174251556,0.2024660222425728,43793 -4149.741677761078,1.555681228637695,9136.02029132843,28500,0,9136.02029132843,0.9850189089775084,0.0510314255952835,0.1900455540619972,43793,13288.089123249054,0.9891347885131836,0.0373249910771846,0.2263010704364689,0.9859467148780824,0.0480697453022003,0.1921317430397762,43793 -4256.041279792786,1.5881264209747314,9376.250892162325,29250,0,9376.250892162325,0.9850140810012816,0.0502172112464904,0.1876090685663295,43793,13634.6715593338,0.9891011714935304,0.0372898466885089,0.2516501756761238,0.9858760833740234,0.0475060418248176,0.1935256298053031,43793 -4364.7704746723175,1.6209070682525637,9616.499011993408,30004,0,9616.499011993408,0.9850319623947144,0.0508443377912044,0.1901787682076284,43793,13983.701684236526,0.989096701145172,0.0369564853608608,0.2388970122226297,0.9860494136810304,0.0476928651332855,0.2032375966967203,43793 -4472.7656536102295,1.652916431427002,9856.662395954132,30751,0,9856.662395954132,0.9849410057067872,0.0504181608557701,0.1896190516759259,43793,14331.912987470629,0.9891952872276306,0.0367727316915988,0.2468313038552545,0.9859515428543092,0.0474362857639789,0.2034833376790149,43793 -4583.790787220001,1.6917784214019775,10096.915367603302,31490,0,10096.915367603302,0.9851625561714172,0.050069410353899,0.1946114423504687,43793,14683.252610206604,0.9892831444740297,0.0365368127822876,0.2541385304879885,0.9860900044441224,0.0471078380942344,0.2012027431813133,43793 -4689.847971916199,1.7260665893554688,10337.149015903473,32239,0,10337.149015903473,0.9850172400474548,0.0500676706433296,0.1955522357274692,43793,15029.597774505615,0.9892269372940063,0.0368354618549346,0.2552228813312813,0.985957682132721,0.0472273305058479,0.1943741291666913,43793 -4792.9368715286255,1.7592804431915283,10577.234008789062,32988,0,10577.234008789062,0.9851486682891846,0.0502247475087642,0.1932763301956823,43793,15372.825038433077,0.9894253015518188,0.0358051843941211,0.2698218583076517,0.9860225915908812,0.0472880974411964,0.2004973368649585,43793 -4895.532358884811,1.7915773391723633,10817.288613796234,33734,0,10817.288613796234,0.9850395321846008,0.0497498586773872,0.1964414351587808,43793,15715.528035879135,0.9893887639045716,0.0357852131128311,0.2618129740028855,0.9860676527023317,0.0469548366963863,0.197591852607744,43793 -5000.842721700668,1.825816631317139,11057.456293582916,34482,0,11057.456293582916,0.985135555267334,0.0503055602312088,0.1877756219260508,43793,16061.060340881348,0.9893254041671752,0.0362496450543403,0.2509931044289288,0.9861057996749878,0.0473644398152828,0.1936171298186332,43793 -5103.608122110367,1.8610637187957764,11297.589622735975,35237,0,11297.589622735975,0.9851351380348206,0.0506232194602489,0.1930257648180393,43793,16404.014665842056,0.9891011714935304,0.0366949811577796,0.2536961986611331,0.9861447811126708,0.0473889149725437,0.2023330538214926,43793 -5205.648674488068,1.895133256912232,11537.72613310814,35987,0,11537.72613310814,0.9851229190826416,0.0503481589257717,0.1972139998452186,43793,16746.245433568954,0.9890627861022948,0.0370471514761447,0.2397458856785428,0.9860416650772096,0.0474542863667011,0.1940947287240224,43793 -5312.391384601593,1.928436040878296,11777.842349767683,36737,0,11777.842349767683,0.9852808713912964,0.0496365427970886,0.2027815439346703,43793,17093.15770673752,0.9894454479217528,0.0358574204146862,0.2619400303030704,0.9862446784973145,0.046729139983654,0.2127803122106341,43793 -5414.43673324585,1.9633357524871824,12017.802143096924,37487,0,12017.802143096924,0.9852720499038696,0.0496285483241081,0.2091661338432342,43793,17435.217926979065,0.9894551038742064,0.0357481762766838,0.2561994339752247,0.9861322045326232,0.0467582568526268,0.2066152129075412,43793 -5521.393639087677,1.9962775707244875,12257.834669589996,38244,0,12257.834669589996,0.9849477410316468,0.0501863174140453,0.1923162326230432,43793,17782.260478019714,0.9893867373466492,0.0362753197550773,0.2601817317611524,0.985926389694214,0.0473117418587207,0.2024692703405251,43793 -5623.934894800186,2.029609203338623,12497.840120315552,39001,0,12497.840120315552,0.9850168228149414,0.0501539148390293,0.1939370750438576,43793,18124.86086988449,0.9894042611122132,0.0361329726874828,0.2705461838739672,0.9859158396720886,0.0472590513527393,0.2061454407305062,43793 -5728.363832473755,2.065474271774292,12737.917171955109,39756,0,12737.917171955109,0.9850248098373412,0.0500476472079753,0.1950534085546705,43793,18469.422612428665,0.9893486499786376,0.0360335744917392,0.256148759447705,0.9859779477119446,0.0470002144575119,0.2051794239423218,43793 -5834.631159543991,2.10075044631958,12977.919306516647,40505,0,12977.919306516647,0.9851389527320862,0.0496848039329052,0.1942995179941897,43793,18815.74730205536,0.9895033836364746,0.035438735038042,0.2753709398089208,0.986076593399048,0.0467179007828235,0.2022420670019095,43793 -5937.405626773834,2.137742519378662,13218.008937835692,41252,0,13218.008937835692,0.9852118492126464,0.0496916957199573,0.1984080443646017,43793,19158.668475151066,0.9897860288619996,0.0345465615391731,0.2853760124042723,0.9862000346183776,0.0466513372957706,0.2100652647208077,43793 -6043.70458650589,2.17232346534729,13458.234591960909,42004,0,13458.234591960909,0.9853630065917968,0.0496775656938552,0.1955947536136331,43793,19505.24766516685,0.989495038986206,0.0356295444071292,0.266926135120057,0.9862678050994872,0.0465229675173759,0.2150185824859295,43793 -6148.046504974365,2.210094928741455,13698.329906463625,42737,0,13698.329906463625,0.9852128624916076,0.0497658960521221,0.2048573545565305,43793,19849.74711751938,0.9894113540649414,0.0357576794922351,0.2688181003628976,0.9861470460891724,0.0467201247811317,0.2114273808923142,43793 -6251.22530412674,2.2450668811798096,13938.442008972168,43494,0,13938.442008972168,0.9853647351264954,0.0495802052319049,0.2040145778814236,43793,20193.09298181533,0.989499032497406,0.0354002267122268,0.2650605700785743,0.9862678050994872,0.0466506890952587,0.2114251127695588,43793 -6353.971425771713,2.28100848197937,14178.594394683838,44251,0,14178.594394683838,0.9851962327957152,0.0495260991156101,0.1989728517536515,43793,20536.04729604721,0.9895753860473632,0.0354918055236339,0.2742665919504021,0.9862418174743652,0.0463673397898674,0.2104525270200389,43793 -6458.873847723007,2.3189632892608643,14418.677802801132,45006,0,14418.677802801132,0.9851511716842652,0.0497616715729236,0.1953420900467133,43793,20881.09065937996,0.9895598888397216,0.0351086929440498,0.2776030484718139,0.9861411452293396,0.0466781482100486,0.2058879765164011,43793 -6560.754787683487,2.355625629425049,14658.762327671053,45765,0,14658.762327671053,0.9853546023368835,0.0493918880820274,0.2038367222214243,43793,21223.11307501793,0.9897001385688782,0.0346075557172298,0.2768291465979394,0.9862962365150452,0.0462343730032444,0.2151871670410814,43793 -6664.698795795441,2.390665292739868,14898.90408039093,46515,0,14898.90408039093,0.985383689403534,0.0493339970707893,0.207040467005581,43793,21567.253492355347,0.989697515964508,0.0346462801098823,0.2963236467176596,0.9863830804824828,0.0461990013718605,0.2210658521639429,43793 -6770.398800849915,2.428494691848755,15138.879612207413,47270,0,15138.879612207413,0.9853179454803468,0.0488965511322021,0.2030546784519302,43793,21912.986404657364,0.9897758364677428,0.0343558825552463,0.3042403984564562,0.9862929582595824,0.0459770634770393,0.212829559312736,43793 -6871.342316389084,2.464463233947754,15379.08549618721,48027,0,15379.08549618721,0.9853202700614928,0.049766506999731,0.2046247702905341,43793,22254.191664218903,0.98962664604187,0.0343369282782077,0.286865701081385,0.98627507686615,0.0465668700635433,0.2209927814720388,43793 -6974.270753145218,2.500919818878174,15619.100360155106,48784,0,15619.100360155106,0.9854194521903992,0.0493343286216259,0.2104642647012252,43793,22597.190956115723,0.9900445342063904,0.0335013419389724,0.3042157934923303,0.9863396286964417,0.0461429804563522,0.2207973284185761,43793 -7074.364864110947,2.5379621982574463,15859.136625528336,49536,0,15859.136625528336,0.9853870272636414,0.0492770634591579,0.2131225694710239,43793,22937.37840652466,0.9898927807807922,0.0341162048280239,0.3005784614493065,0.986319363117218,0.0463139228522777,0.2191926410291987,43793 -7178.855273962021,2.573608636856079,16099.339428424835,50294,0,16099.339428424835,0.9856048226356506,0.0489003583788871,0.2149699454561261,43793,23282.12690448761,0.9897878170013428,0.0343374647200107,0.2856928937956072,0.9864285588264464,0.0458579137921333,0.2164103157897776,43793 -7282.207628250122,2.6092591285705566,16339.413675069807,51040,0,16339.413675069807,0.9853739738464355,0.0486579872667789,0.2081253269428495,43793,23625.6093814373,0.9899283647537231,0.0339838527143001,0.3022169393236976,0.9863327741622924,0.04576126486063,0.2156734050628471,43793 -7388.165745735168,2.646596670150757,16579.61815905571,51787,0,16579.61815905571,0.9853882789611816,0.0487843081355094,0.2069404616550198,43793,23971.82929778099,0.9897343516349792,0.0343607477843761,0.2915450888909189,0.986326277256012,0.045857734978199,0.2213827129714896,43793 -7490.901193618774,2.687276124954224,16819.64329266548,52526,0,16819.64329266548,0.9855268597602844,0.0486095622181892,0.212095926354634,43793,24314.65186572075,0.9898606538772584,0.0337282866239547,0.308960692285714,0.9864127039909364,0.0457212105393409,0.2221381761078136,43793 -7597.760349988937,2.723623275756836,17059.861709833145,53280,0,17059.861709833145,0.9855159521102904,0.0484535545110702,0.2134460965103021,43793,24661.785895824432,0.9900662302970886,0.0333186835050582,0.3060979508321819,0.9864590167999268,0.045574489980936,0.2243359206043439,43793 -7701.634706735611,2.759958982467652,17299.852272987366,54027,0,17299.852272987366,0.985465407371521,0.0490322783589363,0.2094928839266382,43793,25005.707471370697,0.9900087714195251,0.0333698242902755,0.3049010737919717,0.9865007996559144,0.0456825047731399,0.2264214793265886,43793 -7803.162973642349,2.796483993530273,17539.843241930008,54781,0,17539.843241930008,0.985546052455902,0.0486128292977809,0.2145471011356968,43793,25347.28339076042,0.9901589155197144,0.032963290810585,0.3163235677717774,0.9865801930427552,0.0455623939633369,0.2280526233849575,43793 -7910.455152988434,2.8337621688842773,17779.891257047653,55533,0,17779.891257047653,0.985528528690338,0.048599362373352,0.2175347020578288,43793,25694.680911540985,0.990270733833313,0.0325558632612228,0.323719534949942,0.9865257740020752,0.0453366674482822,0.2321489278343596,43793 -8012.519022226334,2.870720386505127,18020.05457186699,56285,0,18020.05457186699,0.9855323433876038,0.0487038306891918,0.2174664089875856,43793,26036.965087890625,0.9902195930480956,0.0326430946588516,0.3289052356829938,0.9864122867584229,0.0456474311649799,0.2347272987081529,43793 -8117.47713804245,2.908522367477417,18260.280821561813,57039,0,18260.280821561813,0.9855504631996155,0.04861481487751007,0.2146260371624704,43793,26382.207079172134,0.9902625679969788,0.032707177102565765,0.3176695183003912,0.9864926934242249,0.04545671492815018,0.2280582942071802,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/measurements.csv deleted file mode 100644 index 9f7384459..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/measurements.csv +++ /dev/null @@ -1,657 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3577504,0.7395575,,,,,,,,,,,,,,,,, -1,,,0.5091689229011536,0.7391471266746521,0.0227705634173207,0.5064647793769836,0.7422076463699341,0.0252988304021323,43793.0,0.5047287344932556,0.7438743114471436,0.0269311415807948,43793.0,12.185586929321287,119.9868142604828,12.185586929321287,107.80116128921507,0.0,0.0 -100,0.117479414,0.12054831,,,,,,,,,,,,,,,,, -200,0.009640586,0.05104791,,,,,,,,,,,,,,,,, -300,0.038240507,0.05818461,,,,,,,,,,,,,,,,, -400,0.01950717,0.053819984,,,,,,,,,,,,,,,,, -500,0.015619927,0.057029653,,,,,,,,,,,,,,,,, -600,0.010656175,0.05191454,,,,,,,,,,,,,,,,, -700,0.013504599,0.05250624,,,,,,,,,,,,,,,,, -745,,,0.9868181943893432,0.0527527704834938,0.0483171707343765,0.9841492176055908,0.063015766441822,0.0477191124991385,43793.0,0.983149230480194,0.0662031918764114,0.0460700617813824,43793.0,252.15884280204773,467.6156971454621,252.15884280204773,215.41246700286865,0.0240700244903564,0.0 -800,0.009882028,0.050166186,,,,,,,,,,,,,,,,, -900,0.012532623,0.050259784,,,,,,,,,,,,,,,,, -1000,0.010208196,0.052873574,,,,,,,,,,,,,,,,, -1100,0.009962817,0.050683305,,,,,,,,,,,,,,,,, -1200,0.007068973,0.0480172,,,,,,,,,,,,,,,,, -1300,0.009565147,0.046699185,,,,,,,,,,,,,,,,, -1400,0.010800557,0.048305456,,,,,,,,,,,,,,,,, -1490,,,0.9869986772537231,0.0491938292980194,0.0780179580921825,0.9842535257339478,0.0585653334856033,0.0744461513894765,43793.0,0.98329496383667,0.0615998916327953,0.0745604968623183,43793.0,492.1526997089386,813.9127051830292,492.1526997089386,321.669025182724,0.0508973598480224,0.0 -1500,0.016901027,0.051463954,,,,,,,,,,,,,,,,, -1600,0.009415901,0.043663017,,,,,,,,,,,,,,,,, -1700,0.0062868833,0.050853707,,,,,,,,,,,,,,,,, -1800,0.007637621,0.05619156,,,,,,,,,,,,,,,,, -1900,0.005671479,0.051082984,,,,,,,,,,,,,,,,, -2000,0.021941433,0.054319978,,,,,,,,,,,,,,,,, -2100,0.011747443,0.051753685,,,,,,,,,,,,,,,,, -2200,0.014752629,0.040549267,,,,,,,,,,,,,,,,, -2240,,,0.9874287843704224,0.0456875301897525,0.1117813627247076,0.984661102294922,0.0556657016277313,0.1072617279992498,43793.0,0.9836757183074952,0.0591330528259277,0.1034593614915365,43793.0,732.1452407836914,1161.3493824005127,732.1452407836914,429.0665693283081,0.0775086879730224,0.0 -2300,0.016485032,0.040136013,,,,,,,,,,,,,,,,, -2400,0.0125034,0.043962713,,,,,,,,,,,,,,,,, -2500,0.007355565,0.04522062,,,,,,,,,,,,,,,,, -2600,0.013074838,0.0443658,,,,,,,,,,,,,,,,, -2700,0.020992804,0.039467964,,,,,,,,,,,,,,,,, -2800,0.024946412,0.042453133,,,,,,,,,,,,,,,,, -2900,0.011828899,0.042234913,,,,,,,,,,,,,,,,, -2992,,,0.987829566001892,0.042909987270832,0.1490976674809723,0.9849030375480652,0.0530146658420562,0.1296165186489274,43793.0,0.983911156654358,0.0559541843831539,0.1336785072545706,43793.0,972.3031947612762,1504.4612703323364,972.3031947612762,531.9732112884521,0.1048529148101806,0.0 -3000,0.033655893,0.040872026,,,,,,,,,,,,,,,,, -3100,0.008348005,0.041897014,,,,,,,,,,,,,,,,, -3200,0.018534048,0.045973916,,,,,,,,,,,,,,,,, -3300,0.012251414,0.03814528,,,,,,,,,,,,,,,,, -3400,0.048224654,0.046020888,,,,,,,,,,,,,,,,, -3500,0.048359748,0.04812019,,,,,,,,,,,,,,,,, -3600,0.03147209,0.04387403,,,,,,,,,,,,,,,,, -3700,0.045016482,0.041038107,,,,,,,,,,,,,,,,, -3743,,,0.9880852103233336,0.0412999391555786,0.1683103936134479,0.9852334856987,0.0507333017885685,0.1494133137828359,43793.0,0.9843466877937316,0.0536787956953048,0.1512781169298654,43793.0,1212.3853332996368,1850.2039790153503,1212.3853332996368,637.5870983600616,0.1319179534912109,0.0 -3800,0.17684767,0.0453629,,,,,,,,,,,,,,,,, -3900,0.017617611,0.042479876,,,,,,,,,,,,,,,,, -4000,0.027233621,0.034789715,,,,,,,,,,,,,,,,, -4100,0.042613674,0.03974522,,,,,,,,,,,,,,,,, -4200,0.03270555,0.03974466,,,,,,,,,,,,,,,,, -4300,0.01982747,0.04003552,,,,,,,,,,,,,,,,, -4400,0.054860104,0.04207107,,,,,,,,,,,,,,,,, -4498,,,0.9884041547775269,0.0404745452105999,0.1948218218889239,0.9853795766830444,0.0500588230788707,0.1574081830732877,43793.0,0.98445326089859,0.0527374073863029,0.1563751314541928,43793.0,1452.6483166217804,2197.36225771904,1452.6483166217804,744.434494972229,0.1595597267150879,0.0 -4500,0.019421918,0.03845976,,,,,,,,,,,,,,,,, -4600,0.035457637,0.042838767,,,,,,,,,,,,,,,,, -4700,0.034061275,0.03797273,,,,,,,,,,,,,,,,, -4800,0.041369185,0.043605197,,,,,,,,,,,,,,,,, -4900,0.046156794,0.047457848,,,,,,,,,,,,,,,,, -5000,0.044483334,0.042184014,,,,,,,,,,,,,,,,, -5100,0.031835526,0.03782045,,,,,,,,,,,,,,,,, -5200,0.055594187,0.036823135,,,,,,,,,,,,,,,,, -5254,,,0.9886635541915894,0.0394278690218925,0.1902771159028346,0.9854867458343506,0.0497402250766754,0.165965228441913,43793.0,0.9845661520957948,0.0523319989442825,0.1708755911669754,43793.0,1692.851990699768,2544.280503988266,1692.851990699768,851.1025938987732,0.1864519119262695,0.0 -5300,0.023190005,0.036871254,,,,,,,,,,,,,,,,, -5400,0.05843465,0.038929917,,,,,,,,,,,,,,,,, -5500,0.04311291,0.044146247,,,,,,,,,,,,,,,,, -5600,0.028194299,0.03742474,,,,,,,,,,,,,,,,, -5700,0.023847764,0.037760597,,,,,,,,,,,,,,,,, -5800,0.04406922,0.042267792,,,,,,,,,,,,,,,,, -5900,0.036433555,0.043356158,,,,,,,,,,,,,,,,, -6000,0.020273335,0.03494073,,,,,,,,,,,,,,,,, -6008,,,0.988634705543518,0.0391428656876087,0.1998674039412336,0.9856597185134888,0.0488583073019981,0.18008463135684,43793.0,0.9846979379653932,0.0519571416079998,0.1727666752466523,43793.0,1933.0624508857727,2888.3051455020905,1933.0624508857727,954.8695249557496,0.2135336399078369,0.0 -6100,0.0324099,0.04364893,,,,,,,,,,,,,,,,, -6200,0.040075533,0.03768764,,,,,,,,,,,,,,,,, -6300,0.056966342,0.040340062,,,,,,,,,,,,,,,,, -6400,0.050821852,0.03831524,,,,,,,,,,,,,,,,, -6500,0.02297011,0.03742421,,,,,,,,,,,,,,,,, -6600,0.029970251,0.040451445,,,,,,,,,,,,,,,,, -6700,0.04134276,0.042903423,,,,,,,,,,,,,,,,, -6753,,,0.9883594512939452,0.0401516146957874,0.192082496409469,0.9853593111038208,0.0498590171337127,0.168534194680524,43793.0,0.9844822883605956,0.0525984466075897,0.1626187375082621,43793.0,2173.258081674576,3235.989268064499,2173.258081674576,1062.309351682663,0.2424702644348144,0.0 -6800,0.04075159,0.036587484,,,,,,,,,,,,,,,,, -6900,0.026625207,0.03622286,,,,,,,,,,,,,,,,, -7000,0.024209239,0.038462162,,,,,,,,,,,,,,,,, -7100,0.030473577,0.033905093,,,,,,,,,,,,,,,,, -7200,0.0258603,0.036836687,,,,,,,,,,,,,,,,, -7300,0.08064461,0.040741123,,,,,,,,,,,,,,,,, -7400,0.026672674,0.040713955,,,,,,,,,,,,,,,,, -7492,,,0.9885475635528564,0.0390818305313587,0.1999583433178592,0.9856422543525696,0.048810314387083,0.1828529834684011,43793.0,0.98464697599411,0.0516998767852783,0.1845904393962085,43793.0,2413.3347930908203,3580.518236398697,2413.3347930908203,1166.7139718532562,0.2698376178741455,0.0 -7500,0.09865828,0.041088965,,,,,,,,,,,,,,,,, -7600,0.054843247,0.041207768,,,,,,,,,,,,,,,,, -7700,0.031358805,0.043712292,,,,,,,,,,,,,,,,, -7800,0.03881483,0.036029276,,,,,,,,,,,,,,,,, -7900,0.048256535,0.03802056,,,,,,,,,,,,,,,,, -8000,0.06139671,0.0381462,,,,,,,,,,,,,,,,, -8100,0.03772462,0.041078296,,,,,,,,,,,,,,,,, -8200,0.06791566,0.037060443,,,,,,,,,,,,,,,,, -8243,,,0.9886873364448548,0.0387571267783641,0.2114938379224503,0.9858253002166748,0.0486804917454719,0.1836186190699482,43793.0,0.9848407506942748,0.0517675057053566,0.1818865562994291,43793.0,2653.599544286728,3929.0323138237,2653.599544286728,1274.9142456054688,0.2987575531005859,0.0 -8300,0.033504527,0.043830656,,,,,,,,,,,,,,,,, -8400,0.020050198,0.035188463,,,,,,,,,,,,,,,,, -8500,0.055279925,0.043557003,,,,,,,,,,,,,,,,, -8600,0.023280183,0.038079195,,,,,,,,,,,,,,,,, -8700,0.0783518,0.039984837,,,,,,,,,,,,,,,,, -8800,0.0332893,0.037175182,,,,,,,,,,,,,,,,, -8900,0.026506275,0.039841846,,,,,,,,,,,,,,,,, -9000,0.063192695,0.038968705,,,,,,,,,,,,,,,,, -9003,,,0.9886422753334044,0.0389967672526836,0.2087046357847387,0.9857802391052246,0.0487343408167362,0.1894534784275115,43793.0,0.9847792387008668,0.0514960400760173,0.1920021524379697,43793.0,2893.7972707748413,4277.550271987915,2893.7972707748413,1383.1852016448977,0.3282186985015869,0.0 -9100,0.04476434,0.044425987,,,,,,,,,,,,,,,,, -9200,0.07698731,0.039958116,,,,,,,,,,,,,,,,, -9300,0.07239662,0.04249152,,,,,,,,,,,,,,,,, -9400,0.034004316,0.03588864,,,,,,,,,,,,,,,,, -9500,0.0657211,0.039139345,,,,,,,,,,,,,,,,, -9600,0.0507887,0.03520561,,,,,,,,,,,,,,,,, -9700,0.026308803,0.036147382,,,,,,,,,,,,,,,,, -9760,,,0.988835632801056,0.0383559502661228,0.2161694728775924,0.9855862259864808,0.0487497597932815,0.1840018137573714,43793.0,0.9846423864364624,0.0516324192285537,0.1834430491442636,43793.0,3133.9032673835754,4625.546536207199,3133.9032673835754,1491.027760028839,0.3559896945953369,0.0 -9800,0.02328144,0.041518506,,,,,,,,,,,,,,,,, -9900,0.031107709,0.036183417,,,,,,,,,,,,,,,,, -10000,0.035417225,0.04222198,,,,,,,,,,,,,,,,, -10100,0.045951724,0.039502904,,,,,,,,,,,,,,,,, -10200,0.039932,0.042981666,,,,,,,,,,,,,,,,, -10300,0.050467093,0.03869482,,,,,,,,,,,,,,,,, -10400,0.035225317,0.034953807,,,,,,,,,,,,,,,,, -10500,0.024642562,0.037786737,,,,,,,,,,,,,,,,, -10503,,,0.9887315034866332,0.0383016318082809,0.2269017371348451,0.9857413172721864,0.0489319637417793,0.1896570640462174,43793.0,0.9848563075065612,0.0519847050309181,0.1847305208964719,43793.0,3374.0375757217407,4971.461939096451,3374.0375757217407,1596.7612025737762,0.384077787399292,0.0 -10600,0.07118722,0.042183504,,,,,,,,,,,,,,,,, -10700,0.036366653,0.03864407,,,,,,,,,,,,,,,,, -10800,0.04760532,0.037796333,,,,,,,,,,,,,,,,, -10900,0.06264858,0.03895353,,,,,,,,,,,,,,,,, -11000,0.09792163,0.04000336,,,,,,,,,,,,,,,,, -11100,0.045434464,0.037687223,,,,,,,,,,,,,,,,, -11200,0.04630467,0.0369498,,,,,,,,,,,,,,,,, -11254,,,0.9888266921043396,0.0380459837615489,0.2298554811650588,0.9857701063156128,0.0484095476567745,0.1863925213489288,43793.0,0.9848538041114808,0.0512711331248283,0.1845297036709734,43793.0,3614.018998861313,5320.525203466415,3614.018998861313,1705.79274225235,0.4135866165161133,0.0 -11300,0.046100654,0.041460235,,,,,,,,,,,,,,,,, -11400,0.030882793,0.038910326,,,,,,,,,,,,,,,,, -11500,0.023736935,0.037582338,,,,,,,,,,,,,,,,, -11600,0.043260727,0.033617456,,,,,,,,,,,,,,,,, -11700,0.05622556,0.036584068,,,,,,,,,,,,,,,,, -11800,0.029916247,0.03915243,,,,,,,,,,,,,,,,, -11900,0.036177665,0.042230688,,,,,,,,,,,,,,,,, -12000,0.027669776,0.04008108,,,,,,,,,,,,,,,,, -12003,,,0.988034188747406,0.0402605235576629,0.2159543490357493,0.9848340153694152,0.0508625023066997,0.1719574172968483,43793.0,0.983836591243744,0.0538304969668388,0.1767750627461585,43793.0,3854.225952386856,5665.794174194336,3854.225952386856,1810.805812358856,0.4426114559173584,0.0 -12100,0.019291312,0.035373405,,,,,,,,,,,,,,,,, -12200,0.03150459,0.03766592,,,,,,,,,,,,,,,,, -12300,0.026327461,0.03726619,,,,,,,,,,,,,,,,, -12400,0.031024352,0.03584748,,,,,,,,,,,,,,,,, -12500,0.05704779,0.036808446,,,,,,,,,,,,,,,,, -12600,0.051901586,0.03702908,,,,,,,,,,,,,,,,, -12700,0.05327164,0.03604117,,,,,,,,,,,,,,,,, -12752,,,0.9886979460716248,0.0385139733552932,0.212413276491084,0.9853869080543518,0.0496855191886425,0.1816565361883723,43793.0,0.9844094514846802,0.0528549253940582,0.1829934808872338,43793.0,4094.2330226898193,6012.367550611496,4094.2330226898193,1917.3220510482788,0.4725079536437988,0.0 -12800,0.040350754,0.03826846,,,,,,,,,,,,,,,,, -12900,0.04367856,0.039367087,,,,,,,,,,,,,,,,, -13000,0.05659852,0.045335207,,,,,,,,,,,,,,,,, -13100,0.050714877,0.043613393,,,,,,,,,,,,,,,,, -13200,0.038308088,0.034114804,,,,,,,,,,,,,,,,, -13300,0.041704334,0.03957626,,,,,,,,,,,,,,,,, -13400,0.035742,0.041807137,,,,,,,,,,,,,,,,, -13500,0.04048616,0.035464387,,,,,,,,,,,,,,,,, -13508,,,0.98878014087677,0.0384350679814815,0.2341694003337173,0.9857327938079834,0.0484113171696662,0.18682739612154,43793.0,0.9848297834396362,0.0509868077933788,0.1925377068346132,43793.0,4334.211813926697,6360.479429960251,4334.211813926697,2025.405591011048,0.5020263195037842,0.0 -13600,0.046819124,0.03812292,,,,,,,,,,,,,,,,, -13700,0.024580713,0.037145596,,,,,,,,,,,,,,,,, -13800,0.030272491,0.034851074,,,,,,,,,,,,,,,,, -13900,0.04277601,0.032960977,,,,,,,,,,,,,,,,, -14000,0.05835191,0.040280573,,,,,,,,,,,,,,,,, -14100,0.06277782,0.033911765,,,,,,,,,,,,,,,,, -14200,0.023862438,0.033423446,,,,,,,,,,,,,,,,, -14261,,,0.9888355135917664,0.0383663959801197,0.2215462984661317,0.9857721328735352,0.0490221828222274,0.1812746860597814,43793.0,0.9848424196243286,0.0520608276128768,0.1777015158910258,43793.0,4574.353425979614,6709.469097614288,4574.353425979614,2134.202548503876,0.5312979221343994,0.0 -14300,0.046531495,0.041074485,,,,,,,,,,,,,,,,, -14400,0.05372894,0.03901509,,,,,,,,,,,,,,,,, -14500,0.046261087,0.034626883,,,,,,,,,,,,,,,,, -14600,0.049344566,0.039417148,,,,,,,,,,,,,,,,, -14700,0.02111436,0.04306685,,,,,,,,,,,,,,,,, -14800,0.037029885,0.037979912,,,,,,,,,,,,,,,,, -14900,0.031534676,0.038386337,,,,,,,,,,,,,,,,, -14993,,,0.9888824820518494,0.0379234850406646,0.2257129817583256,0.9858293533325196,0.0481456443667411,0.1915524691425417,43793.0,0.984917402267456,0.0510648749768734,0.1841147673501316,43793.0,4814.332322835922,7060.149246931076,4814.332322835922,2244.846806049347,0.5642621517181396,0.0 -15000,0.033988442,0.037351158,,,,,,,,,,,,,,,,, -15100,0.056891073,0.044675123,,,,,,,,,,,,,,,,, -15200,0.04131216,0.04182432,,,,,,,,,,,,,,,,, -15300,0.0465089,0.036673438,,,,,,,,,,,,,,,,, -15400,0.057834912,0.040718626,,,,,,,,,,,,,,,,, -15500,0.026596524,0.037836757,,,,,,,,,,,,,,,,, -15600,0.024389744,0.03790697,,,,,,,,,,,,,,,,, -15700,0.027080677,0.033320505,,,,,,,,,,,,,,,,, -15741,,,0.9887647032737732,0.0387309081852436,0.2083025958420364,0.9856645464897156,0.0490984246134758,0.1780559326371,43793.0,0.9846630096435548,0.052205454558134,0.1708737522269427,43793.0,5054.521810770035,7405.344657182693,5054.521810770035,2349.802227497101,0.5941531658172607,0.0 -15800,0.084624894,0.0379931,,,,,,,,,,,,,,,,, -15900,0.04840392,0.035905123,,,,,,,,,,,,,,,,, -16000,0.032499682,0.037778415,,,,,,,,,,,,,,,,, -16100,0.023646135,0.037685234,,,,,,,,,,,,,,,,, -16200,0.0382811,0.040319078,,,,,,,,,,,,,,,,, -16300,0.043838784,0.038386747,,,,,,,,,,,,,,,,, -16400,0.11423241,0.03934078,,,,,,,,,,,,,,,,, -16489,,,0.9889504313468932,0.037890437990427,0.2155687895466367,0.9857696890830994,0.0486704371869564,0.1825648845056662,43793.0,0.9848487377166748,0.0515647567808628,0.1753466000554308,43793.0,5294.745937824249,7752.313290119171,5294.745937824249,2456.497183799744,0.624169111251831,0.0 -16500,0.037484832,0.036435958,,,,,,,,,,,,,,,,, -16600,0.030036872,0.039710246,,,,,,,,,,,,,,,,, -16700,0.04323624,0.03747676,,,,,,,,,,,,,,,,, -16800,0.035911486,0.039607774,,,,,,,,,,,,,,,,, -16900,0.04315001,0.039304398,,,,,,,,,,,,,,,,, -17000,0.039044555,0.037952825,,,,,,,,,,,,,,,,, -17100,0.048877355,0.034408174,,,,,,,,,,,,,,,,, -17200,0.03755772,0.037725355,,,,,,,,,,,,,,,,, -17242,,,0.9890180826187134,0.0376419089734554,0.2201965623201106,0.9857839345932008,0.0482412688434124,0.1835722451651859,43793.0,0.9849043488502502,0.0510930605232715,0.1818741040714724,43793.0,5534.832649469376,8097.196264505386,5534.832649469376,2561.243638038636,0.6534569263458252,0.0 -17300,0.060871605,0.038200956,,,,,,,,,,,,,,,,, -17400,0.053588107,0.0442413,,,,,,,,,,,,,,,,, -17500,0.04293901,0.039692294,,,,,,,,,,,,,,,,, -17600,0.044565268,0.037364542,,,,,,,,,,,,,,,,, -17700,0.026321864,0.0364126,,,,,,,,,,,,,,,,, -17800,0.04608783,0.036309924,,,,,,,,,,,,,,,,, -17900,0.04625012,0.037851267,,,,,,,,,,,,,,,,, -17989,,,0.9887692928314208,0.0380432158708572,0.227398046227228,0.9855777025222778,0.0491019561886787,0.1862749647511783,43793.0,0.9846507906913756,0.0521830730140209,0.1831122481482856,43793.0,5775.044799804688,8445.953789234161,5775.044799804688,2669.7339446544647,0.6886923313140869,0.0 -18000,0.062290188,0.044526633,,,,,,,,,,,,,,,,, -18100,0.05004039,0.038892373,,,,,,,,,,,,,,,,, -18200,0.027409758,0.035604883,,,,,,,,,,,,,,,,, -18300,0.065940745,0.040512774,,,,,,,,,,,,,,,,, -18400,0.024714949,0.034731027,,,,,,,,,,,,,,,,, -18500,0.046527047,0.03873596,,,,,,,,,,,,,,,,, -18600,0.026049994,0.040489037,,,,,,,,,,,,,,,,, -18700,0.06632776,0.04200948,,,,,,,,,,,,,,,,, -18740,,,0.9890387058258056,0.0370751470327377,0.2429982094044376,0.9859284162521362,0.0477214083075523,0.1912126206648628,43793.0,0.984965443611145,0.0507509633898735,0.1886865838804157,43793.0,6015.091109275818,8792.135896921158,6015.091109275818,2775.819948196411,0.7185335159301758,0.0 -18800,0.055638153,0.04149358,,,,,,,,,,,,,,,,, -18900,0.054742374,0.039979972,,,,,,,,,,,,,,,,, -19000,0.02856995,0.037303876,,,,,,,,,,,,,,,,, -19100,0.02490305,0.036281753,,,,,,,,,,,,,,,,, -19200,0.02908181,0.033635154,,,,,,,,,,,,,,,,, -19300,0.024892293,0.03881007,,,,,,,,,,,,,,,,, -19400,0.06561888,0.039995432,,,,,,,,,,,,,,,,, -19487,,,0.9886627197265624,0.0383425168693065,0.2338148047161424,0.9854279160499572,0.0490564405918121,0.1800967657073751,43793.0,0.9846385717391968,0.0519570372998714,0.1779373353694832,43793.0,6255.089984178543,9137.301024913788,6255.089984178543,2880.935016393661,0.7496731281280518,0.0 -19500,0.039623298,0.04162172,,,,,,,,,,,,,,,,, -19600,0.038888305,0.033831686,,,,,,,,,,,,,,,,, -19700,0.044269096,0.035777017,,,,,,,,,,,,,,,,, -19800,0.07786961,0.037905015,,,,,,,,,,,,,,,,, -19900,0.04434175,0.038336277,,,,,,,,,,,,,,,,, -20000,0.030637804,0.039728142,,,,,,,,,,,,,,,,, -20100,0.057162054,0.036120128,,,,,,,,,,,,,,,,, -20200,0.057227448,0.041605126,,,,,,,,,,,,,,,,, -20231,,,0.988877534866333,0.037773884832859,0.2241492141391161,0.985895574092865,0.0481287725269794,0.1903862533615728,43793.0,0.9849616289138794,0.0513088703155517,0.1883093383042809,43793.0,6495.277824878693,9486.45058965683,6495.277824878693,2989.8407900333405,0.7810525894165039,0.0 -20300,0.030091964,0.03835064,,,,,,,,,,,,,,,,, -20400,0.05490429,0.042145018,,,,,,,,,,,,,,,,, -20500,0.035674926,0.03633442,,,,,,,,,,,,,,,,, -20600,0.10272891,0.037723955,,,,,,,,,,,,,,,,, -20700,0.03030219,0.03911831,,,,,,,,,,,,,,,,, -20800,0.03991645,0.036827125,,,,,,,,,,,,,,,,, -20900,0.023273483,0.033661023,,,,,,,,,,,,,,,,, -20989,,,0.9890268445014954,0.0376813746988773,0.2302380819582901,0.9859223365783693,0.0477700121700763,0.1952306653002494,43793.0,0.9849637150764464,0.0508499443531036,0.1929505335081207,43793.0,6735.258136510849,9832.98583817482,6735.258136510849,3096.3421173095703,0.8141474723815918,0.0 -21000,0.029783932,0.039390825,,,,,,,,,,,,,,,,, -21100,0.044160403,0.041105196,,,,,,,,,,,,,,,,, -21200,0.028732497,0.037758097,,,,,,,,,,,,,,,,, -21300,0.04818837,0.036277194,,,,,,,,,,,,,,,,, -21400,0.05505332,0.03915062,,,,,,,,,,,,,,,,, -21500,0.026424147,0.037247065,,,,,,,,,,,,,,,,, -21600,0.034975633,0.037400156,,,,,,,,,,,,,,,,, -21700,0.045607865,0.035815705,,,,,,,,,,,,,,,,, -21742,,,0.9889109134674072,0.0379564836621284,0.228778463186817,0.9857218265533448,0.0485035479068756,0.1889965521575424,43793.0,0.9847792387008668,0.0514329299330711,0.1829139607213623,43793.0,6975.445188045502,10181.09177494049,6975.445188045502,3204.208570957184,0.8466379642486572,0.0 -21800,0.046829384,0.03710349,,,,,,,,,,,,,,,,, -21900,0.029358242,0.03620053,,,,,,,,,,,,,,,,, -22000,0.03760221,0.035192758,,,,,,,,,,,,,,,,, -22100,0.029053675,0.039548166,,,,,,,,,,,,,,,,, -22200,0.03357747,0.04222906,,,,,,,,,,,,,,,,, -22300,0.031082353,0.037663378,,,,,,,,,,,,,,,,, -22400,0.09377913,0.039066553,,,,,,,,,,,,,,,,, -22481,,,0.9888846278190612,0.0379314087331295,0.2274528755384141,0.9856629371643066,0.0488991737365722,0.1854297338063141,43793.0,0.984722375869751,0.0518376901745796,0.1856335490853918,43793.0,7215.630878448486,10527.035320281982,7215.630878448486,3309.9134817123413,0.8789269924163818,0.0 -22500,0.040683858,0.041549455,,,,,,,,,,,,,,,,, -22600,0.09415896,0.039156716,,,,,,,,,,,,,,,,, -22700,0.07010933,0.037963774,,,,,,,,,,,,,,,,, -22800,0.033433247,0.040314667,,,,,,,,,,,,,,,,, -22900,0.039768603,0.034986507,,,,,,,,,,,,,,,,, -23000,0.069532335,0.040562674,,,,,,,,,,,,,,,,, -23100,0.061531637,0.03574948,,,,,,,,,,,,,,,,, -23200,0.052937638,0.04341176,,,,,,,,,,,,,,,,, -23233,,,0.9888824224472046,0.0377452857792377,0.2347773893975781,0.9858500957489014,0.0483477227389812,0.1909631095072125,43793.0,0.9848858118057252,0.0515651702880859,0.1854361962732914,43793.0,7455.803935050964,10871.868999242784,7455.803935050964,3414.523555278778,0.9095168113708496,0.0 -23300,0.051318448,0.03939635,,,,,,,,,,,,,,,,, -23400,0.045719326,0.03577848,,,,,,,,,,,,,,,,, -23500,0.05101668,0.034734625,,,,,,,,,,,,,,,,, -23600,0.036604024,0.036357082,,,,,,,,,,,,,,,,, -23700,0.03463142,0.036536682,,,,,,,,,,,,,,,,, -23800,0.040557567,0.041453637,,,,,,,,,,,,,,,,, -23900,0.08896416,0.03664098,,,,,,,,,,,,,,,,, -23985,,,0.9890407919883728,0.0375085286796093,0.2314681183015991,0.9859633445739746,0.0474525094032287,0.1906305339511292,43793.0,0.9850140810012816,0.0502418987452983,0.1828531633615252,43793.0,7695.835782766342,11215.373279094696,7695.835782766342,3517.945199251175,0.9403893947601318,0.0 -24000,0.05806933,0.036730878,,,,,,,,,,,,,,,,, -24100,0.11879284,0.045959026,,,,,,,,,,,,,,,,, -24200,0.03578568,0.0356035,,,,,,,,,,,,,,,,, -24300,0.03408293,0.034644358,,,,,,,,,,,,,,,,, -24400,0.036007795,0.04045418,,,,,,,,,,,,,,,,, -24500,0.03538838,0.036619738,,,,,,,,,,,,,,,,, -24600,0.034105025,0.035337944,,,,,,,,,,,,,,,,, -24700,0.036188304,0.037120778,,,,,,,,,,,,,,,,, -24741,,,0.9891539812088012,0.0367405265569686,0.2390435111815593,0.985897183418274,0.0479989349842071,0.1873103703045062,43793.0,0.9850648045539856,0.0509517341852188,0.1927351888664488,43793.0,7936.069957971573,11563.233343362808,7936.069957971573,3625.518209695816,0.9731786251068116,0.0 -24800,0.028390968,0.040085234,,,,,,,,,,,,,,,,, -24900,0.08467775,0.039550312,,,,,,,,,,,,,,,,, -25000,0.055110365,0.036781114,,,,,,,,,,,,,,,,, -25100,0.061159015,0.042594265,,,,,,,,,,,,,,,,, -25200,0.032915212,0.03517361,,,,,,,,,,,,,,,,, -25300,0.033421595,0.03792475,,,,,,,,,,,,,,,,, -25400,0.03315272,0.039403275,,,,,,,,,,,,,,,,, -25490,,,0.9891579151153564,0.0368153937160968,0.2493061489164509,0.9859349131584167,0.048080027103424,0.1947319230197419,43793.0,0.985063135623932,0.0512276515364646,0.1845998296445171,43793.0,8175.831615447998,11908.558824062347,8175.831615447998,3730.604122877121,1.4308359622955322,0.0 -25500,0.049175072,0.037898544,,,,,,,,,,,,,,,,, -25600,0.03495895,0.035420824,,,,,,,,,,,,,,,,, -25700,0.04184018,0.035941105,,,,,,,,,,,,,,,,, -25800,0.02758823,0.03820017,,,,,,,,,,,,,,,,, -25900,0.030135505,0.038256355,,,,,,,,,,,,,,,,, -26000,0.042031728,0.038373854,,,,,,,,,,,,,,,,, -26100,0.04294635,0.039056472,,,,,,,,,,,,,,,,, -26200,0.062577285,0.04335935,,,,,,,,,,,,,,,,, -26244,,,0.9890082478523254,0.0371231213212013,0.2619470465234472,0.9857433438301086,0.0480293482542037,0.1981787973978353,43793.0,0.9848095774650574,0.0511777698993682,0.1883528025284292,43793.0,8415.865903615952,12255.497112989426,8415.865903615952,3837.456799507141,1.4622154235839844,0.0 -26300,0.0668723,0.036104307,,,,,,,,,,,,,,,,, -26400,0.04488291,0.039294843,,,,,,,,,,,,,,,,, -26500,0.02858777,0.035434812,,,,,,,,,,,,,,,,, -26600,0.03974989,0.03332016,,,,,,,,,,,,,,,,, -26700,0.025718758,0.036174703,,,,,,,,,,,,,,,,, -26800,0.04644421,0.035317857,,,,,,,,,,,,,,,,, -26900,0.061230574,0.03909501,,,,,,,,,,,,,,,,, -27000,,,0.9892248511314392,0.03641227632761,0.248816781902566,0.9859369397163392,0.047542754560709,0.1895510494609632,43793.0,0.985078752040863,0.0505181364715099,0.1862611174670923,43793.0,8655.825356006622,12597.900453329086,8655.825356006622,3939.849988222122,1.493131399154663,0.0 -27000,0.051992603,0.039262626,,,,,,,,,,,,,,,,, -27100,0.03947142,0.035490423,,,,,,,,,,,,,,,,, -27200,0.05160597,0.04091923,,,,,,,,,,,,,,,,, -27300,0.051676117,0.034746405,,,,,,,,,,,,,,,,, -27400,0.025440887,0.03527617,,,,,,,,,,,,,,,,, -27500,0.033799667,0.034136627,,,,,,,,,,,,,,,,, -27600,0.043606088,0.03833265,,,,,,,,,,,,,,,,, -27700,0.053249683,0.039849907,,,,,,,,,,,,,,,,, -27749,,,0.9891148209571838,0.0368561372160911,0.2498081527409343,0.986055076122284,0.0475734174251556,0.2024660222425728,43793.0,0.985129714012146,0.0505303442478179,0.1984891361260043,43793.0,8895.789571285248,12941.680038928986,8895.789571285248,4043.614537715912,1.5242531299591064,0.0 -27800,0.05816809,0.03822709,,,,,,,,,,,,,,,,, -27900,0.032781728,0.037715916,,,,,,,,,,,,,,,,, -28000,0.095122814,0.037873577,,,,,,,,,,,,,,,,, -28100,0.03249442,0.036467005,,,,,,,,,,,,,,,,, -28200,0.05588234,0.040956747,,,,,,,,,,,,,,,,, -28300,0.0557372,0.034834355,,,,,,,,,,,,,,,,, -28400,0.03029185,0.034739893,,,,,,,,,,,,,,,,, -28500,,,0.9891347885131836,0.0373249910771846,0.2263010704364689,0.9859467148780824,0.0480697453022003,0.1921317430397762,43793.0,0.9850189089775084,0.0510314255952835,0.1900455540619972,43793.0,9136.02029132843,13288.089123249054,9136.02029132843,4149.741677761078,1.555681228637695,0.0 -28500,0.03770011,0.037458677,,,,,,,,,,,,,,,,, -28600,0.0340051,0.03549347,,,,,,,,,,,,,,,,, -28700,0.0357348,0.03629116,,,,,,,,,,,,,,,,, -28800,0.028968435,0.0379948,,,,,,,,,,,,,,,,, -28900,0.040519558,0.036911264,,,,,,,,,,,,,,,,, -29000,0.06348344,0.03912286,,,,,,,,,,,,,,,,, -29100,0.07346244,0.035762593,,,,,,,,,,,,,,,,, -29200,0.06256511,0.031663533,,,,,,,,,,,,,,,,, -29250,,,0.9891011714935304,0.0372898466885089,0.2516501756761238,0.9858760833740234,0.0475060418248176,0.1935256298053031,43793.0,0.9850140810012816,0.0502172112464904,0.1876090685663295,43793.0,9376.250892162325,13634.6715593338,9376.250892162325,4256.041279792786,1.5881264209747314,0.0 -29300,0.05079349,0.03831539,,,,,,,,,,,,,,,,, -29400,0.037605807,0.03723165,,,,,,,,,,,,,,,,, -29500,0.052204736,0.041809388,,,,,,,,,,,,,,,,, -29600,0.06705146,0.03671299,,,,,,,,,,,,,,,,, -29700,0.050649848,0.03630807,,,,,,,,,,,,,,,,, -29800,0.022406856,0.034225807,,,,,,,,,,,,,,,,, -29900,0.026952187,0.037140597,,,,,,,,,,,,,,,,, -30000,0.03583426,0.0401434,,,,,,,,,,,,,,,,, -30004,,,0.989096701145172,0.0369564853608608,0.2388970122226297,0.9860494136810304,0.0476928651332855,0.2032375966967203,43793.0,0.9850319623947144,0.0508443377912044,0.1901787682076284,43793.0,9616.499011993408,13983.701684236526,9616.499011993408,4364.7704746723175,1.6209070682525637,0.0 -30100,0.15786745,0.040727608,,,,,,,,,,,,,,,,, -30200,0.07819922,0.038845014,,,,,,,,,,,,,,,,, -30300,0.055073924,0.041424587,,,,,,,,,,,,,,,,, -30400,0.07280237,0.03476714,,,,,,,,,,,,,,,,, -30500,0.03370691,0.032593995,,,,,,,,,,,,,,,,, -30600,0.063758485,0.03774626,,,,,,,,,,,,,,,,, -30700,0.071574196,0.039175462,,,,,,,,,,,,,,,,, -30751,,,0.9891952872276306,0.0367727316915988,0.2468313038552545,0.9859515428543092,0.0474362857639789,0.2034833376790149,43793.0,0.9849410057067872,0.0504181608557701,0.1896190516759259,43793.0,9856.662395954132,14331.912987470629,9856.662395954132,4472.7656536102295,1.652916431427002,0.0 -30800,0.05151334,0.035405498,,,,,,,,,,,,,,,,, -30900,0.029489594,0.03344987,,,,,,,,,,,,,,,,, -31000,0.03621743,0.037034653,,,,,,,,,,,,,,,,, -31100,0.09605679,0.033820055,,,,,,,,,,,,,,,,, -31200,0.08705487,0.036383145,,,,,,,,,,,,,,,,, -31300,0.04831174,0.04071425,,,,,,,,,,,,,,,,, -31400,0.03678143,0.03610435,,,,,,,,,,,,,,,,, -31490,,,0.9892831444740297,0.0365368127822876,0.2541385304879885,0.9860900044441224,0.0471078380942344,0.2012027431813133,43793.0,0.9851625561714172,0.050069410353899,0.1946114423504687,43793.0,10096.915367603302,14683.252610206604,10096.915367603302,4583.790787220001,1.6917784214019775,0.0 -31500,0.038554735,0.037263624,,,,,,,,,,,,,,,,, -31600,0.048906032,0.036470924,,,,,,,,,,,,,,,,, -31700,0.04199536,0.036705554,,,,,,,,,,,,,,,,, -31800,0.054069933,0.037826177,,,,,,,,,,,,,,,,, -31900,0.041179284,0.038760018,,,,,,,,,,,,,,,,, -32000,0.053695507,0.035693698,,,,,,,,,,,,,,,,, -32100,0.04463321,0.038028996,,,,,,,,,,,,,,,,, -32200,0.072017424,0.038780406,,,,,,,,,,,,,,,,, -32239,,,0.9892269372940063,0.0368354618549346,0.2552228813312813,0.985957682132721,0.0472273305058479,0.1943741291666913,43793.0,0.9850172400474548,0.0500676706433296,0.1955522357274692,43793.0,10337.149015903473,15029.597774505615,10337.149015903473,4689.847971916199,1.7260665893554688,0.0 -32300,0.088498145,0.03764504,,,,,,,,,,,,,,,,, -32400,0.03892491,0.035570767,,,,,,,,,,,,,,,,, -32500,0.07380541,0.032691985,,,,,,,,,,,,,,,,, -32600,0.055789627,0.040696047,,,,,,,,,,,,,,,,, -32700,0.052124046,0.034805104,,,,,,,,,,,,,,,,, -32800,0.069110885,0.03577942,,,,,,,,,,,,,,,,, -32900,0.067367755,0.038923725,,,,,,,,,,,,,,,,, -32988,,,0.9894253015518188,0.0358051843941211,0.2698218583076517,0.9860225915908812,0.0472880974411964,0.2004973368649585,43793.0,0.9851486682891846,0.0502247475087642,0.1932763301956823,43793.0,10577.234008789062,15372.825038433077,10577.234008789062,4792.9368715286255,1.7592804431915283,0.0 -33000,0.08450956,0.032940887,,,,,,,,,,,,,,,,, -33100,0.06326723,0.03800539,,,,,,,,,,,,,,,,, -33200,0.03570158,0.03641954,,,,,,,,,,,,,,,,, -33300,0.057046093,0.03706663,,,,,,,,,,,,,,,,, -33400,0.054873973,0.034898907,,,,,,,,,,,,,,,,, -33500,0.042668052,0.034688096,,,,,,,,,,,,,,,,, -33600,0.052842088,0.038039494,,,,,,,,,,,,,,,,, -33700,0.07102231,0.03383788,,,,,,,,,,,,,,,,, -33734,,,0.9893887639045716,0.0357852131128311,0.2618129740028855,0.9860676527023317,0.0469548366963863,0.197591852607744,43793.0,0.9850395321846008,0.0497498586773872,0.1964414351587808,43793.0,10817.288613796234,15715.528035879135,10817.288613796234,4895.532358884811,1.7915773391723633,0.0 -33800,0.076879025,0.03923452,,,,,,,,,,,,,,,,, -33900,0.06181419,0.03862922,,,,,,,,,,,,,,,,, -34000,0.052848432,0.037636846,,,,,,,,,,,,,,,,, -34100,0.047753368,0.039704077,,,,,,,,,,,,,,,,, -34200,0.03577135,0.03549661,,,,,,,,,,,,,,,,, -34300,0.036138438,0.03801823,,,,,,,,,,,,,,,,, -34400,0.03613675,0.04034842,,,,,,,,,,,,,,,,, -34482,,,0.9893254041671752,0.0362496450543403,0.2509931044289288,0.9861057996749878,0.0473644398152828,0.1936171298186332,43793.0,0.985135555267334,0.0503055602312088,0.1877756219260508,43793.0,11057.456293582916,16061.060340881348,11057.456293582916,5000.842721700668,1.825816631317139,0.0 -34500,0.07096426,0.035817415,,,,,,,,,,,,,,,,, -34600,0.040064026,0.040358376,,,,,,,,,,,,,,,,, -34700,0.03785283,0.03464619,,,,,,,,,,,,,,,,, -34800,0.03932907,0.038212802,,,,,,,,,,,,,,,,, -34900,0.061021503,0.033713903,,,,,,,,,,,,,,,,, -35000,0.043286704,0.03379708,,,,,,,,,,,,,,,,, -35100,0.11018994,0.031950753,,,,,,,,,,,,,,,,, -35200,0.029690731,0.03538565,,,,,,,,,,,,,,,,, -35237,,,0.9891011714935304,0.0366949811577796,0.2536961986611331,0.9861447811126708,0.0473889149725437,0.2023330538214926,43793.0,0.9851351380348206,0.0506232194602489,0.1930257648180393,43793.0,11297.589622735975,16404.014665842056,11297.589622735975,5103.608122110367,1.8610637187957764,0.0 -35300,0.06003226,0.03289594,,,,,,,,,,,,,,,,, -35400,0.05166959,0.038036756,,,,,,,,,,,,,,,,, -35500,0.06574131,0.034421008,,,,,,,,,,,,,,,,, -35600,0.03842953,0.035928994,,,,,,,,,,,,,,,,, -35700,0.09222156,0.03758385,,,,,,,,,,,,,,,,, -35800,0.042530973,0.036662046,,,,,,,,,,,,,,,,, -35900,0.05504321,0.03497208,,,,,,,,,,,,,,,,, -35987,,,0.9890627861022948,0.0370471514761447,0.2397458856785428,0.9860416650772096,0.0474542863667011,0.1940947287240224,43793.0,0.9851229190826416,0.0503481589257717,0.1972139998452186,43793.0,11537.72613310814,16746.245433568954,11537.72613310814,5205.648674488068,1.895133256912232,0.0 -36000,0.045182712,0.03466388,,,,,,,,,,,,,,,,, -36100,0.05114028,0.03760087,,,,,,,,,,,,,,,,, -36200,0.053586606,0.03930931,,,,,,,,,,,,,,,,, -36300,0.057234347,0.033736292,,,,,,,,,,,,,,,,, -36400,0.045396127,0.035357464,,,,,,,,,,,,,,,,, -36500,0.05519725,0.033236463,,,,,,,,,,,,,,,,, -36600,0.092997655,0.03623445,,,,,,,,,,,,,,,,, -36700,0.033769596,0.032223176,,,,,,,,,,,,,,,,, -36737,,,0.9894454479217528,0.0358574204146862,0.2619400303030704,0.9862446784973145,0.046729139983654,0.2127803122106341,43793.0,0.9852808713912964,0.0496365427970886,0.2027815439346703,43793.0,11777.842349767683,17093.15770673752,11777.842349767683,5312.391384601593,1.928436040878296,0.0 -36800,0.0860682,0.035166357,,,,,,,,,,,,,,,,, -36900,0.036796384,0.031973682,,,,,,,,,,,,,,,,, -37000,0.07180716,0.0351743,,,,,,,,,,,,,,,,, -37100,0.033986695,0.032473337,,,,,,,,,,,,,,,,, -37200,0.04950186,0.038529903,,,,,,,,,,,,,,,,, -37300,0.042076536,0.03791012,,,,,,,,,,,,,,,,, -37400,0.08357249,0.033062104,,,,,,,,,,,,,,,,, -37487,,,0.9894551038742064,0.0357481762766838,0.2561994339752247,0.9861322045326232,0.0467582568526268,0.2066152129075412,43793.0,0.9852720499038696,0.0496285483241081,0.2091661338432342,43793.0,12017.802143096924,17435.217926979065,12017.802143096924,5414.43673324585,1.9633357524871824,0.0 -37500,0.04908762,0.037587382,,,,,,,,,,,,,,,,, -37600,0.044102777,0.038301107,,,,,,,,,,,,,,,,, -37700,0.03999214,0.034628775,,,,,,,,,,,,,,,,, -37800,0.037849985,0.033013728,,,,,,,,,,,,,,,,, -37900,0.04840089,0.03405067,,,,,,,,,,,,,,,,, -38000,0.03276867,0.03284574,,,,,,,,,,,,,,,,, -38100,0.028875176,0.034699738,,,,,,,,,,,,,,,,, -38200,0.043630883,0.039430626,,,,,,,,,,,,,,,,, -38244,,,0.9893867373466492,0.0362753197550773,0.2601817317611524,0.985926389694214,0.0473117418587207,0.2024692703405251,43793.0,0.9849477410316468,0.0501863174140453,0.1923162326230432,43793.0,12257.834669589996,17782.260478019714,12257.834669589996,5521.393639087677,1.9962775707244875,0.0 -38300,0.06534316,0.032657403,,,,,,,,,,,,,,,,, -38400,0.033721015,0.034569222,,,,,,,,,,,,,,,,, -38500,0.057663057,0.038359515,,,,,,,,,,,,,,,,, -38600,0.040099785,0.034867376,,,,,,,,,,,,,,,,, -38700,0.052726842,0.036018636,,,,,,,,,,,,,,,,, -38800,0.055927627,0.032407217,,,,,,,,,,,,,,,,, -38900,0.038638107,0.036205072,,,,,,,,,,,,,,,,, -39000,0.104965135,0.042623587,,,,,,,,,,,,,,,,, -39001,,,0.9894042611122132,0.0361329726874828,0.2705461838739672,0.9859158396720886,0.0472590513527393,0.2061454407305062,43793.0,0.9850168228149414,0.0501539148390293,0.1939370750438576,43793.0,12497.840120315552,18124.86086988449,12497.840120315552,5623.934894800186,2.029609203338623,0.0 -39100,0.042036068,0.03888781,,,,,,,,,,,,,,,,, -39200,0.039364148,0.031876698,,,,,,,,,,,,,,,,, -39300,0.038136024,0.034837876,,,,,,,,,,,,,,,,, -39400,0.04854663,0.036777876,,,,,,,,,,,,,,,,, -39500,0.048266437,0.038369022,,,,,,,,,,,,,,,,, -39600,0.033451058,0.03489569,,,,,,,,,,,,,,,,, -39700,0.046866853,0.03558188,,,,,,,,,,,,,,,,, -39756,,,0.9893486499786376,0.0360335744917392,0.256148759447705,0.9859779477119446,0.0470002144575119,0.2051794239423218,43793.0,0.9850248098373412,0.0500476472079753,0.1950534085546705,43793.0,12737.917171955109,18469.422612428665,12737.917171955109,5728.363832473755,2.065474271774292,0.0 -39800,0.042541403,0.033728894,,,,,,,,,,,,,,,,, -39900,0.058238458,0.03760274,,,,,,,,,,,,,,,,, -40000,0.042283803,0.03625047,,,,,,,,,,,,,,,,, -40100,0.0642201,0.036639433,,,,,,,,,,,,,,,,, -40200,0.09094809,0.036705635,,,,,,,,,,,,,,,,, -40300,0.09414872,0.041313365,,,,,,,,,,,,,,,,, -40400,0.054306824,0.03683688,,,,,,,,,,,,,,,,, -40500,0.07270849,0.034215398,,,,,,,,,,,,,,,,, -40505,,,0.9895033836364746,0.035438735038042,0.2753709398089208,0.986076593399048,0.0467179007828235,0.2022420670019095,43793.0,0.9851389527320862,0.0496848039329052,0.1942995179941897,43793.0,12977.919306516647,18815.74730205536,12977.919306516647,5834.631159543991,2.10075044631958,0.0 -40600,0.04297115,0.035649356,,,,,,,,,,,,,,,,, -40700,0.053514462,0.03950143,,,,,,,,,,,,,,,,, -40800,0.038603667,0.034408823,,,,,,,,,,,,,,,,, -40900,0.051171657,0.032162447,,,,,,,,,,,,,,,,, -41000,0.047134083,0.040324986,,,,,,,,,,,,,,,,, -41100,0.0675394,0.03713526,,,,,,,,,,,,,,,,, -41200,0.04439571,0.03751595,,,,,,,,,,,,,,,,, -41252,,,0.9897860288619996,0.0345465615391731,0.2853760124042723,0.9862000346183776,0.0466513372957706,0.2100652647208077,43793.0,0.9852118492126464,0.0496916957199573,0.1984080443646017,43793.0,13218.008937835692,19158.668475151066,13218.008937835692,5937.405626773834,2.137742519378662,0.0 -41300,0.044496365,0.03510032,,,,,,,,,,,,,,,,, -41400,0.052644886,0.03924808,,,,,,,,,,,,,,,,, -41500,0.061753586,0.037890162,,,,,,,,,,,,,,,,, -41600,0.055673286,0.036136124,,,,,,,,,,,,,,,,, -41700,0.045436304,0.03638424,,,,,,,,,,,,,,,,, -41800,0.044032793,0.037210554,,,,,,,,,,,,,,,,, -41900,0.0699815,0.034876283,,,,,,,,,,,,,,,,, -42000,0.05610467,0.034450736,,,,,,,,,,,,,,,,, -42004,,,0.989495038986206,0.0356295444071292,0.266926135120057,0.9862678050994872,0.0465229675173759,0.2150185824859295,43793.0,0.9853630065917968,0.0496775656938552,0.1955947536136331,43793.0,13458.234591960909,19505.24766516685,13458.234591960909,6043.70458650589,2.17232346534729,0.0 -42100,0.053487577,0.039786734,,,,,,,,,,,,,,,,, -42200,0.07424088,0.038387716,,,,,,,,,,,,,,,,, -42300,0.04702048,0.03672463,,,,,,,,,,,,,,,,, -42400,0.04709597,0.03661851,,,,,,,,,,,,,,,,, -42500,0.03526362,0.03781255,,,,,,,,,,,,,,,,, -42600,0.04407282,0.03609623,,,,,,,,,,,,,,,,, -42700,0.09254483,0.036354527,,,,,,,,,,,,,,,,, -42737,,,0.9894113540649414,0.0357576794922351,0.2688181003628976,0.9861470460891724,0.0467201247811317,0.2114273808923142,43793.0,0.9852128624916076,0.0497658960521221,0.2048573545565305,43793.0,13698.329906463625,19849.74711751938,13698.329906463625,6148.046504974365,2.210094928741455,0.0 -42800,0.05564997,0.03258962,,,,,,,,,,,,,,,,, -42900,0.043884605,0.036898036,,,,,,,,,,,,,,,,, -43000,0.047019884,0.033178598,,,,,,,,,,,,,,,,, -43100,0.070275515,0.040580355,,,,,,,,,,,,,,,,, -43200,0.045156628,0.03445382,,,,,,,,,,,,,,,,, -43300,0.047107514,0.03232686,,,,,,,,,,,,,,,,, -43400,0.060932167,0.03891957,,,,,,,,,,,,,,,,, -43494,,,0.989499032497406,0.0354002267122268,0.2650605700785743,0.9862678050994872,0.0466506890952587,0.2114251127695588,43793.0,0.9853647351264954,0.0495802052319049,0.2040145778814236,43793.0,13938.442008972168,20193.09298181533,13938.442008972168,6251.22530412674,2.2450668811798096,0.0 -43500,0.037871376,0.03302388,,,,,,,,,,,,,,,,, -43600,0.06710652,0.038714767,,,,,,,,,,,,,,,,, -43700,0.06295344,0.036800668,,,,,,,,,,,,,,,,, -43800,0.0528166,0.03555663,,,,,,,,,,,,,,,,, -43900,0.045676652,0.036436997,,,,,,,,,,,,,,,,, -44000,0.082141556,0.033071533,,,,,,,,,,,,,,,,, -44100,0.057612315,0.035449307,,,,,,,,,,,,,,,,, -44200,0.06445121,0.03843423,,,,,,,,,,,,,,,,, -44251,,,0.9895753860473632,0.0354918055236339,0.2742665919504021,0.9862418174743652,0.0463673397898674,0.2104525270200389,43793.0,0.9851962327957152,0.0495260991156101,0.1989728517536515,43793.0,14178.594394683838,20536.04729604721,14178.594394683838,6353.971425771713,2.28100848197937,0.0 -44300,0.092805766,0.038363207,,,,,,,,,,,,,,,,, -44400,0.047085438,0.0333657,,,,,,,,,,,,,,,,, -44500,0.10313663,0.038646434,,,,,,,,,,,,,,,,, -44600,0.03619881,0.034319192,,,,,,,,,,,,,,,,, -44700,0.034394816,0.03486447,,,,,,,,,,,,,,,,, -44800,0.07404827,0.039902322,,,,,,,,,,,,,,,,, -44900,0.13566485,0.035279978,,,,,,,,,,,,,,,,, -45000,0.06311647,0.032514155,,,,,,,,,,,,,,,,, -45006,,,0.9895598888397216,0.0351086929440498,0.2776030484718139,0.9861411452293396,0.0466781482100486,0.2058879765164011,43793.0,0.9851511716842652,0.0497616715729236,0.1953420900467133,43793.0,14418.677802801132,20881.09065937996,14418.677802801132,6458.873847723007,2.3189632892608643,0.0 -45100,0.03938617,0.033634964,,,,,,,,,,,,,,,,, -45200,0.06936411,0.038979452,,,,,,,,,,,,,,,,, -45300,0.057270996,0.03288461,,,,,,,,,,,,,,,,, -45400,0.049259193,0.034956917,,,,,,,,,,,,,,,,, -45500,0.06685074,0.03525169,,,,,,,,,,,,,,,,, -45600,0.10653842,0.041774213,,,,,,,,,,,,,,,,, -45700,0.051144894,0.031667195,,,,,,,,,,,,,,,,, -45765,,,0.9897001385688782,0.0346075557172298,0.2768291465979394,0.9862962365150452,0.0462343730032444,0.2151871670410814,43793.0,0.9853546023368835,0.0493918880820274,0.2038367222214243,43793.0,14658.762327671053,21223.11307501793,14658.762327671053,6560.754787683487,2.355625629425049,0.0 -45800,0.04603366,0.035724025,,,,,,,,,,,,,,,,, -45900,0.06104049,0.035515405,,,,,,,,,,,,,,,,, -46000,0.04355075,0.03745443,,,,,,,,,,,,,,,,, -46100,0.05407206,0.033437103,,,,,,,,,,,,,,,,, -46200,0.05909199,0.034443937,,,,,,,,,,,,,,,,, -46300,0.061025497,0.035665765,,,,,,,,,,,,,,,,, -46400,0.058857422,0.036001153,,,,,,,,,,,,,,,,, -46500,0.07088558,0.035485454,,,,,,,,,,,,,,,,, -46515,,,0.989697515964508,0.0346462801098823,0.2963236467176596,0.9863830804824828,0.0461990013718605,0.2210658521639429,43793.0,0.985383689403534,0.0493339970707893,0.207040467005581,43793.0,14898.90408039093,21567.253492355347,14898.90408039093,6664.698795795441,2.390665292739868,0.0 -46600,0.06780134,0.03534999,,,,,,,,,,,,,,,,, -46700,0.058296878,0.033173546,,,,,,,,,,,,,,,,, -46800,0.04192619,0.032959357,,,,,,,,,,,,,,,,, -46900,0.071072675,0.03674687,,,,,,,,,,,,,,,,, -47000,0.053802427,0.037504166,,,,,,,,,,,,,,,,, -47100,0.087848455,0.035034183,,,,,,,,,,,,,,,,, -47200,0.047485646,0.0347751,,,,,,,,,,,,,,,,, -47270,,,0.9897758364677428,0.0343558825552463,0.3042403984564562,0.9862929582595824,0.0459770634770393,0.212829559312736,43793.0,0.9853179454803468,0.0488965511322021,0.2030546784519302,43793.0,15138.879612207413,21912.986404657364,15138.879612207413,6770.398800849915,2.428494691848755,0.0 -47300,0.061683577,0.035265002,,,,,,,,,,,,,,,,, -47400,0.05944506,0.02724714,,,,,,,,,,,,,,,,, -47500,0.123833224,0.032438137,,,,,,,,,,,,,,,,, -47600,0.07054657,0.032128572,,,,,,,,,,,,,,,,, -47700,0.10860533,0.03964002,,,,,,,,,,,,,,,,, -47800,0.052185554,0.037040744,,,,,,,,,,,,,,,,, -47900,0.04766967,0.034812737,,,,,,,,,,,,,,,,, -48000,0.04147406,0.03224304,,,,,,,,,,,,,,,,, -48027,,,0.98962664604187,0.0343369282782077,0.286865701081385,0.98627507686615,0.0465668700635433,0.2209927814720388,43793.0,0.9853202700614928,0.049766506999731,0.2046247702905341,43793.0,15379.08549618721,22254.191664218903,15379.08549618721,6871.342316389084,2.464463233947754,0.0 -48100,0.078079835,0.034812793,,,,,,,,,,,,,,,,, -48200,0.096669815,0.03335765,,,,,,,,,,,,,,,,, -48300,0.050477143,0.03996085,,,,,,,,,,,,,,,,, -48400,0.04664737,0.032872126,,,,,,,,,,,,,,,,, -48500,0.048208147,0.034383,,,,,,,,,,,,,,,,, -48600,0.069481656,0.032835323,,,,,,,,,,,,,,,,, -48700,0.04650113,0.03494716,,,,,,,,,,,,,,,,, -48784,,,0.9900445342063904,0.0335013419389724,0.3042157934923303,0.9863396286964417,0.0461429804563522,0.2207973284185761,43793.0,0.9854194521903992,0.0493343286216259,0.2104642647012252,43793.0,15619.100360155106,22597.190956115723,15619.100360155106,6974.270753145218,2.500919818878174,0.0 -48800,0.068347305,0.034310266,,,,,,,,,,,,,,,,, -48900,0.052486546,0.03467091,,,,,,,,,,,,,,,,, -49000,0.06551747,0.033795312,,,,,,,,,,,,,,,,, -49100,0.079664744,0.031331602,,,,,,,,,,,,,,,,, -49200,0.052145835,0.032032143,,,,,,,,,,,,,,,,, -49300,0.059026133,0.03250304,,,,,,,,,,,,,,,,, -49400,0.04735698,0.034940835,,,,,,,,,,,,,,,,, -49500,0.059381932,0.0380031,,,,,,,,,,,,,,,,, -49536,,,0.9898927807807922,0.0341162048280239,0.3005784614493065,0.986319363117218,0.0463139228522777,0.2191926410291987,43793.0,0.9853870272636414,0.0492770634591579,0.2131225694710239,43793.0,15859.136625528336,22937.37840652466,15859.136625528336,7074.364864110947,2.5379621982574463,0.0 -49600,0.06295913,0.033062406,,,,,,,,,,,,,,,,, -49700,0.06396627,0.035247125,,,,,,,,,,,,,,,,, -49800,0.060327407,0.031084996,,,,,,,,,,,,,,,,, -49900,0.05445169,0.036231305,,,,,,,,,,,,,,,,, -50000,0.07362859,0.033439662,,,,,,,,,,,,,,,,, -50100,0.07881648,0.039294466,,,,,,,,,,,,,,,,, -50200,0.05359964,0.033213727,,,,,,,,,,,,,,,,, -50294,,,0.9897878170013428,0.0343374647200107,0.2856928937956072,0.9864285588264464,0.0458579137921333,0.2164103157897776,43793.0,0.9856048226356506,0.0489003583788871,0.2149699454561261,43793.0,16099.339428424835,23282.12690448761,16099.339428424835,7178.855273962021,2.573608636856079,0.0 -50300,0.08184106,0.03202115,,,,,,,,,,,,,,,,, -50400,0.07375996,0.037824463,,,,,,,,,,,,,,,,, -50500,0.07945084,0.03312862,,,,,,,,,,,,,,,,, -50600,0.053769406,0.034691043,,,,,,,,,,,,,,,,, -50700,0.07802944,0.032994322,,,,,,,,,,,,,,,,, -50800,0.05915486,0.038059622,,,,,,,,,,,,,,,,, -50900,0.10248533,0.03787325,,,,,,,,,,,,,,,,, -51000,0.090094194,0.03451349,,,,,,,,,,,,,,,,, -51040,,,0.9899283647537231,0.0339838527143001,0.3022169393236976,0.9863327741622924,0.04576126486063,0.2156734050628471,43793.0,0.9853739738464355,0.0486579872667789,0.2081253269428495,43793.0,16339.413675069807,23625.6093814373,16339.413675069807,7282.207628250122,2.6092591285705566,0.0 -51100,0.062072214,0.035251793,,,,,,,,,,,,,,,,, -51200,0.061126858,0.032102812,,,,,,,,,,,,,,,,, -51300,0.08948365,0.033682182,,,,,,,,,,,,,,,,, -51400,0.05916855,0.034862824,,,,,,,,,,,,,,,,, -51500,0.08611162,0.034430586,,,,,,,,,,,,,,,,, -51600,0.0573933,0.03472549,,,,,,,,,,,,,,,,, -51700,0.116329625,0.031157332,,,,,,,,,,,,,,,,, -51787,,,0.9897343516349792,0.0343607477843761,0.2915450888909189,0.986326277256012,0.045857734978199,0.2213827129714896,43793.0,0.9853882789611816,0.0487843081355094,0.2069404616550198,43793.0,16579.61815905571,23971.82929778099,16579.61815905571,7388.165745735168,2.646596670150757,0.0 -51800,0.056154232,0.033555076,,,,,,,,,,,,,,,,, -51900,0.0564714,0.03404357,,,,,,,,,,,,,,,,, -52000,0.056080304,0.034161042,,,,,,,,,,,,,,,,, -52100,0.079922065,0.03350389,,,,,,,,,,,,,,,,, -52200,0.04500989,0.034367204,,,,,,,,,,,,,,,,, -52300,0.061787922,0.03800273,,,,,,,,,,,,,,,,, -52400,0.092867404,0.037247766,,,,,,,,,,,,,,,,, -52500,0.105746076,0.031129228,,,,,,,,,,,,,,,,, -52526,,,0.9898606538772584,0.0337282866239547,0.308960692285714,0.9864127039909364,0.0457212105393409,0.2221381761078136,43793.0,0.9855268597602844,0.0486095622181892,0.212095926354634,43793.0,16819.64329266548,24314.65186572075,16819.64329266548,7490.901193618774,2.687276124954224,0.0 -52600,0.060359817,0.02910605,,,,,,,,,,,,,,,,, -52700,0.07258003,0.032759264,,,,,,,,,,,,,,,,, -52800,0.07784694,0.038925886,,,,,,,,,,,,,,,,, -52900,0.06019736,0.03355787,,,,,,,,,,,,,,,,, -53000,0.08874825,0.03803012,,,,,,,,,,,,,,,,, -53100,0.07689898,0.03254073,,,,,,,,,,,,,,,,, -53200,0.08027416,0.030431442,,,,,,,,,,,,,,,,, -53280,,,0.9900662302970886,0.0333186835050582,0.3060979508321819,0.9864590167999268,0.045574489980936,0.2243359206043439,43793.0,0.9855159521102904,0.0484535545110702,0.2134460965103021,43793.0,17059.861709833145,24661.785895824432,17059.861709833145,7597.760349988937,2.723623275756836,0.0 -53300,0.084341586,0.030318538,,,,,,,,,,,,,,,,, -53400,0.05981239,0.03500536,,,,,,,,,,,,,,,,, -53500,0.0645572,0.034279447,,,,,,,,,,,,,,,,, -53600,0.101039425,0.03397945,,,,,,,,,,,,,,,,, -53700,0.090601474,0.034397837,,,,,,,,,,,,,,,,, -53800,0.07096815,0.030557124,,,,,,,,,,,,,,,,, -53900,0.061385337,0.03342843,,,,,,,,,,,,,,,,, -54000,0.087393776,0.035449784,,,,,,,,,,,,,,,,, -54027,,,0.9900087714195251,0.0333698242902755,0.3049010737919717,0.9865007996559144,0.0456825047731399,0.2264214793265886,43793.0,0.985465407371521,0.0490322783589363,0.2094928839266382,43793.0,17299.852272987366,25005.707471370697,17299.852272987366,7701.634706735611,2.759958982467652,0.0 -54100,0.087077565,0.032724462,,,,,,,,,,,,,,,,, -54200,0.07811074,0.03455428,,,,,,,,,,,,,,,,, -54300,0.06267749,0.035296146,,,,,,,,,,,,,,,,, -54400,0.063577734,0.03604057,,,,,,,,,,,,,,,,, -54500,0.1390167,0.03399691,,,,,,,,,,,,,,,,, -54600,0.086070195,0.03360781,,,,,,,,,,,,,,,,, -54700,0.05916419,0.034355916,,,,,,,,,,,,,,,,, -54781,,,0.9901589155197144,0.032963290810585,0.3163235677717774,0.9865801930427552,0.0455623939633369,0.2280526233849575,43793.0,0.985546052455902,0.0486128292977809,0.2145471011356968,43793.0,17539.843241930008,25347.28339076042,17539.843241930008,7803.162973642349,2.796483993530273,0.0 -54800,0.099486254,0.03175839,,,,,,,,,,,,,,,,, -54900,0.12594873,0.03600929,,,,,,,,,,,,,,,,, -55000,0.05961805,0.033335876,,,,,,,,,,,,,,,,, -55100,0.0861726,0.035534967,,,,,,,,,,,,,,,,, -55200,0.08278855,0.033715233,,,,,,,,,,,,,,,,, -55300,0.11464524,0.03212755,,,,,,,,,,,,,,,,, -55400,0.06141176,0.036828514,,,,,,,,,,,,,,,,, -55500,0.06124725,0.032125253,,,,,,,,,,,,,,,,, -55533,,,0.990270733833313,0.0325558632612228,0.323719534949942,0.9865257740020752,0.0453366674482822,0.2321489278343596,43793.0,0.985528528690338,0.048599362373352,0.2175347020578288,43793.0,17779.891257047653,25694.680911540985,17779.891257047653,7910.455152988434,2.8337621688842773,0.0 -55600,0.07664045,0.03219823,,,,,,,,,,,,,,,,, -55700,0.07684686,0.035084575,,,,,,,,,,,,,,,,, -55800,0.10698258,0.030458637,,,,,,,,,,,,,,,,, -55900,0.11698421,0.0289486,,,,,,,,,,,,,,,,, -56000,0.06310935,0.029957967,,,,,,,,,,,,,,,,, -56100,0.07194962,0.033873964,,,,,,,,,,,,,,,,, -56200,0.07436297,0.03425463,,,,,,,,,,,,,,,,, -56285,,,0.9902195930480956,0.0326430946588516,0.3289052356829938,0.9864122867584229,0.0456474311649799,0.2347272987081529,43793.0,0.9855323433876038,0.0487038306891918,0.2174664089875856,43793.0,18020.05457186699,26036.965087890625,18020.05457186699,8012.519022226334,2.870720386505127,0.0 -56300,0.073082216,0.030644054,,,,,,,,,,,,,,,,, -56400,0.057576936,0.03175543,,,,,,,,,,,,,,,,, -56500,0.076018445,0.03136466,,,,,,,,,,,,,,,,, -56600,0.07706365,0.033592038,,,,,,,,,,,,,,,,, -56700,0.09724063,0.032536283,,,,,,,,,,,,,,,,, -56800,0.08214008,0.029917687,,,,,,,,,,,,,,,,, -56900,0.09393139,0.037258405,,,,,,,,,,,,,,,,, -57000,0.07727213,0.03580569,,,,,,,,,,,,,,,,, -57039,,,0.9902625679969788,0.0327071771025657,0.3176695183003912,0.9864926934242249,0.0454567149281501,0.2280582942071802,43793.0,0.9855504631996156,0.04861481487751,0.2146260371624704,43793.0,18260.280821561813,26382.20707917213,18260.280821561813,8117.47713804245,2.908522367477417,0.0 -57100,0.10217946,0.028825615,,,,,,,,,,,,,,,,, -57200,0.05780444,0.032747187,,,,,,,,,,,,,,,,, -57300,0.0655151,0.030967027,,,,,,,,,,,,,,,,, -57400,0.087690055,0.029658033,,,,,,,,,,,,,,,,, -57500,0.07670836,0.036662042,,,,,,,,,,,,,,,,, -57600,0.07890093,0.032511946,,,,,,,,,,,,,,,,, -57700,0.0950976,0.03257438,,,,,,,,,,,,,,,,, -57715,,,,,,,,,,,,,,18477.03369784355,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/eval_measurements.csv deleted file mode 100644 index bd0aa6768..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -103.7207329273224,0.0,13.580633878707886,1,0,13.580633878707886,0.5047284960746765,0.7438743114471436,0.0269269992061447,43793,117.30141496658324,0.5093092918395996,0.7391443848609924,0.0232773786467988,0.5064646005630493,0.7422075867652893,0.0252931536418489,43793 -207.2520191669464,0.0207283496856689,253.804012298584,750,0,253.804012298584,0.9832671880722046,0.064820185303688,0.0545995080153342,43793,461.0968973636627,0.986806094646454,0.0515382997691631,0.0542356702645199,0.9842458367347716,0.0615352354943752,0.0538402358803808,43793 -311.0966327190399,0.0467598438262939,494.0419046878815,1495,0,494.0419046878815,0.9834954738616944,0.0607746839523315,0.0840767077035999,43793,805.2264924049377,0.9871118068695068,0.0480010174214839,0.083969882400915,0.9844828844070436,0.057401418685913,0.0852865402316337,43793 -416.1317677497864,0.0737733840942382,734.1674497127533,2243,0,734.1674497127533,0.9837182760238647,0.0578309632837772,0.123550536234461,43793,1150.4348917007446,0.987299919128418,0.0455262772738933,0.1279417429584762,0.9847471714019777,0.0546827651560306,0.1286468551470923,43793 -517.6807377338409,0.1014776229858398,974.1777114868164,2991,0,974.1777114868164,0.9842160940170288,0.0548467971384525,0.1484768175352116,43793,1492.042894601822,0.9879700541496276,0.0423735380172729,0.1602362821606831,0.9851559400558472,0.051908191293478,0.1475911315728504,43793 -618.0172493457794,0.1294021606445312,1214.414056301117,3736,0,1214.414056301117,0.984274685382843,0.0530230216681957,0.1659362918288775,43793,1832.6644973754885,0.9882538318634032,0.0412957370281219,0.1869888306182944,0.9851595759391784,0.0504656732082366,0.1660371829803102,43793 -726.6796119213104,0.1562621593475341,1454.5042502880096,4488,0,1454.5042502880096,0.9847304224967957,0.051666285842657,0.1886703707765753,43793,2181.4644510746,0.9885920882225036,0.0392466597259044,0.2131411379152453,0.9855939149856568,0.0488959811627864,0.1844909876500332,43793 -835.3410527706146,0.1837186813354492,1694.6122086048126,5241,0,1694.6122086048126,0.9850661158561708,0.0502384155988693,0.2137931187659128,43793,2530.281988620758,0.9889898300170898,0.0376451984047889,0.2380226997982301,0.9859138131141664,0.0475182346999645,0.2031912993116336,43793 -943.2999262809752,0.2107102870941162,1934.768192768097,5993,0,1934.768192768097,0.9853339791297911,0.0492023862898349,0.2152423660366171,43793,2878.444193124771,0.9889129996299744,0.0374315716326236,0.2512646496902923,0.9860774278640748,0.0465936623513698,0.2131519201854946,43793 -1043.2435188293457,0.2378804683685302,2174.874867916107,6748,0,2174.874867916107,0.9853845238685608,0.0488400496542453,0.2211056133914604,43793,3218.5420083999634,0.989357590675354,0.0359286963939666,0.2719629300015331,0.9862214922904968,0.0462537966668605,0.2262959319746798,43793 -1149.6194591522217,0.2666077613830566,2415.1428520679474,7493,0,2415.1428520679474,0.9854055643081664,0.0487098097801208,0.2348598280664689,43793,3565.234797239304,0.9894558787345886,0.0353555269539356,0.3062496271384285,0.9862868785858154,0.0459942333400249,0.2279327528108939,43793 -1254.9102308750153,0.3024904727935791,2655.087381839752,8242,0,2655.087381839752,0.9856169819831848,0.0478177182376384,0.2419383466019386,43793,3910.5265514850616,0.989901602268219,0.0337459854781627,0.3307906793603523,0.9863956570625304,0.0453762374818325,0.2373848047517522,43793 -1359.43941283226,0.332345962524414,2895.369548559189,8996,0,2895.369548559189,0.9856843948364258,0.0478332489728927,0.2460210322111171,43793,4255.388598442078,0.99005526304245,0.0332898646593093,0.3302585783276252,0.9864711761474608,0.0451908968389034,0.2407431674503511,43793 -1464.221135854721,0.3602263927459717,3135.624180316925,9748,0,3135.624180316925,0.9858174920082092,0.0472788773477077,0.2526541590940708,43793,4600.473388910294,0.990172564983368,0.0328486375510692,0.3479799884843341,0.9866335391998292,0.0446899943053722,0.2526990602814843,43793 -1567.718020439148,0.3888866901397705,3375.697806596756,10498,0,3375.697806596756,0.9857223033905028,0.0474354512989521,0.2465764344200082,43793,4944.093198299408,0.9901142716407776,0.033044509589672,0.342539331587014,0.9865771532058716,0.0446068868041038,0.2567563124017319,43793 -1673.7032914161682,0.4190025329589844,3615.875280857086,11243,0,3615.875280857086,0.9858258962631226,0.0469792783260345,0.2500176737745309,43793,5290.306088924408,0.9903327226638794,0.0322994701564312,0.3561079984004907,0.9867005348205566,0.0442373193800449,0.2580287784778692,43793 -1774.1074786186218,0.4494631290435791,3855.9695858955374,11987,0,3855.9695858955374,0.9858882427215576,0.0469309277832508,0.2575036834951336,43793,5630.856683969498,0.990323007106781,0.0320759005844593,0.3612385092120623,0.9867829084396362,0.0440880171954631,0.2602869942875813,43793 -1875.4612171649933,0.4779088497161865,4096.190707445145,12742,0,4096.190707445145,0.9859021306037904,0.047086376696825,0.2540120004735609,43793,5972.480631828308,0.9905083775520324,0.0314913280308246,0.3642103863357995,0.9867606163024902,0.0440910868346691,0.2677271230835614,43793 -1983.5776450634005,0.5073034763336182,4336.178166389465,13482,0,4336.178166389465,0.9860032200813292,0.0466867499053478,0.2611335866362385,43793,6320.635596513748,0.9906244874000548,0.0309746414422988,0.389609197822945,0.9867532849311828,0.0439917035400867,0.2688608755476995,43793 -2087.2807455062866,0.5352981090545654,4576.319272518158,14236,0,4576.319272518158,0.9858457446098328,0.0466507412493228,0.2567049480858492,43793,6664.528880357742,0.990672528743744,0.0306298565119504,0.3995707839872688,0.986504077911377,0.0440920814871788,0.2636364530613028,43793 -2194.837212085724,0.5642292499542236,4816.31423330307,14991,0,4816.31423330307,0.9859707951545716,0.0472371615469455,0.2638035088881912,43793,7012.129799127579,0.9907553195953368,0.0302142277359962,0.4122109902822474,0.986757755279541,0.0445883460342884,0.2677942141421415,43793 -2295.940425634384,0.5946774482727051,5056.359443902969,15729,0,5056.359443902969,0.9859463572502136,0.0468619801104068,0.2607283811938001,43793,7353.330956935883,0.9911349415779114,0.029304539784789,0.430178080129186,0.9867480397224426,0.0441630408167839,0.2715012183204673,43793 -2398.510513544082,0.6254196166992188,5296.309223651886,16471,0,5296.309223651886,0.9859893321990968,0.0467883460223674,0.2560449718034864,43793,7695.902285814285,0.9909190535545348,0.0297707431018352,0.4091028998139538,0.9867866039276124,0.0440623573958873,0.2620690775753606,43793 -2501.5175173282623,0.6656818389892578,5536.299431324005,17210,0,5536.299431324005,0.9859653115272522,0.0465751178562641,0.2614984133559347,43793,8038.960066795349,0.9909037947654724,0.0299575366079807,0.4211723660599582,0.9867788553237916,0.044044554233551,0.2709214835239603,43793 -2609.689152956009,0.6960053443908691,5776.414937496185,17958,0,5776.414937496185,0.986120343208313,0.046714399009943,0.2669996794222807,43793,8387.298349618912,0.9907863736152648,0.030131721869111,0.4060708456613312,0.9868494868278505,0.0441721640527248,0.2710522605596167,43793 -2716.024253845215,0.7282171249389648,6016.625653028488,18709,0,6016.625653028488,0.986135482788086,0.0465055629611015,0.2658340993299205,43793,8733.897119998932,0.9909077882766724,0.0298218131065368,0.4070056823467816,0.9868308305740356,0.0438988842070102,0.2765967366141773,43793 -2819.025629043579,0.7590713500976562,6256.626332998276,19465,0,6256.626332998276,0.9860706329345704,0.0468316636979579,0.2596261907954604,43793,9076.95063853264,0.9910405278205872,0.029201403260231,0.4292364006707121,0.9868568181991576,0.0440428256988525,0.2791245527026243,43793 -2922.849404811859,0.792198657989502,6496.625156402588,20215,0,6496.625156402588,0.9858280420303344,0.0466070771217346,0.2692436956815187,43793,9420.829341888428,0.9910680651664734,0.0291820932179689,0.439064296870055,0.986634373664856,0.0441146455705165,0.2813325085483499,43793 -3022.484857082367,0.8230326175689697,6736.826936483383,20963,0,6736.826936483383,0.9860310554504396,0.0468637309968471,0.2672742096154502,43793,9760.718255281448,0.9911518692970276,0.0288035813719034,0.4462325316914583,0.9868718385696412,0.0439016968011856,0.2812189409328452,43793 -3124.1582732200623,0.8541834354400635,6977.031692266464,21713,0,6977.031692266464,0.9860352277755736,0.0469725541770458,0.2679272339349347,43793,10102.648232460022,0.9913423657417296,0.0279766172170639,0.4702118638735022,0.9869785904884338,0.0438735596835613,0.285625691321616,43793 -3230.3336248397827,0.8844475746154785,7217.031193494797,22456,0,7217.031193494797,0.9862037301063538,0.0466453246772289,0.2640839126761976,43793,10448.873989105225,0.9913394451141356,0.0280387885868549,0.4604188786782695,0.9869635701179504,0.0437427349388599,0.2827766819212199,43793 -3331.6023383140564,0.9160325527191162,7457.124490976334,23204,0,7457.124490976334,0.986123263835907,0.0469043962657451,0.2627899405295638,43793,10790.28831577301,0.9916256070137024,0.0272593423724174,0.4747280372458277,0.9869043231010436,0.0440606139600276,0.2845651073448388,43793 -3436.4192943573,0.9561116695404052,7697.121632099152,23956,0,7697.121632099152,0.985987663269043,0.0470348894596099,0.2644944977788918,43793,11135.16293668747,0.9913517236709596,0.0281870812177658,0.4601179168266128,0.9868783354759216,0.0442613177001476,0.280703321938136,43793 -3540.332468509674,0.987633228302002,7937.132448196411,24711,0,7937.132448196411,0.9862121343612672,0.0468297265470027,0.2668786160394751,43793,11479.138907432556,0.9913532137870787,0.0281356312334537,0.4598455509794903,0.9870455861091614,0.0438859760761261,0.288337047625382,43793 -3641.869652032852,1.0197508335113523,8177.254436969757,25462,0,8177.254436969757,0.9861923456192015,0.046654887497425,0.2683779409036351,43793,11820.850269317629,0.991429328918457,0.0277799628674983,0.4671925935508493,0.9870439767837524,0.0436918511986732,0.2888897962900462,43793 -3745.5998179912567,1.0510282516479492,8417.495745420456,26214,0,8417.495745420456,0.9861291646957396,0.0465985499322414,0.2675201139818264,43793,12164.87309885025,0.991330623626709,0.028105879202485,0.4568819825250968,0.986963987350464,0.043882068246603,0.2856185672583953,43793 -3846.632305622101,1.0833914279937744,8657.630143404007,26964,0,8657.630143404007,0.986011266708374,0.0466346591711044,0.2703870400534875,43793,12506.092809200289,0.9913989305496216,0.0279050935059785,0.4681628812771156,0.9868897199630736,0.043874591588974,0.2818689651514881,43793 -3949.8536190986633,1.1148099899291992,8897.8357629776,27709,0,8897.8357629776,0.9860866665840148,0.0466265827417373,0.2637615840012572,43793,12849.57141637802,0.9916355013847352,0.0270985979586839,0.4803630402502638,0.986885666847229,0.043906919658184,0.286745084223735,43793 -4051.599108457565,1.148134708404541,9137.999815702438,28458,0,9137.999815702438,0.9860950708389282,0.0467545948922634,0.2691075067365001,43793,13191.534855604172,0.9917035102844238,0.0267012659460306,0.4912260841639426,0.9870139360427856,0.0437790490686893,0.2853140712726124,43793 -4153.242586612701,1.1815519332885742,9378.210587263107,29204,0,9378.210587263107,0.986171305179596,0.0467007234692573,0.2694446762523158,43793,13533.442611694336,0.9919525980949402,0.0259919613599777,0.5057251405123364,0.9870707392692566,0.0437920428812503,0.290256895570826,43793 -4259.714983224869,1.2135634422302246,9618.264122486116,29961,0,9618.264122486116,0.9861283302307128,0.0471483767032623,0.2724732373198176,43793,13880.021289348602,0.991941213607788,0.0258173458278179,0.5112644960667561,0.9870638251304626,0.0439641214907169,0.289960550844474,43793 -4363.952226638794,1.246518850326538,9858.291661024094,30714,0,9858.291661024094,0.9860677123069764,0.0471735931932926,0.2675294987951833,43793,14224.340274810793,0.9917917847633362,0.0263051856309175,0.5038807477785993,0.9869644045829772,0.0440908707678318,0.2894453883646158,43793 -4470.956292152405,1.2795672416687012,10098.319123268127,31458,0,10098.319123268127,0.9861093759536744,0.0466942265629768,0.2704962118586615,43793,14571.425382375715,0.9918238520622252,0.0264401771128177,0.4987664840541888,0.9869713187217712,0.0437646061182022,0.2861938749291255,43793 -4574.533927679062,1.313849925994873,10338.541726350784,32201,0,10338.541726350784,0.9861093759536744,0.0469667054712772,0.2685648392049272,43793,14915.281131029127,0.9917375445365906,0.0267392918467521,0.4871506207529771,0.9870427250862122,0.0438420996069908,0.2848324278871966,43793 -4676.352701425552,1.3461999893188477,10578.52193403244,32947,0,10578.52193403244,0.9862593412399292,0.0471234656870365,0.2710033452222284,43793,15257.13302230835,0.9917527437210084,0.0266880467534065,0.4959007838136359,0.9871665239334106,0.0442121364176273,0.2930568691400217,43793 -4777.501918077469,1.3808777332305908,10818.582146406174,33704,0,10818.582146406174,0.9861788749694824,0.046940054744482,0.271115682743845,43793,15598.397938489914,0.9918839931488036,0.0260620433837175,0.5066806571838988,0.9870265126228333,0.0439733080565929,0.2861925315494497,43793 -4878.264587640762,1.417754888534546,11058.546568632126,34457,0,11058.546568632126,0.98611319065094,0.0471945516765117,0.2731183121212644,43793,15939.182308912275,0.991776704788208,0.02636144682765,0.5051048917352821,0.9870297312736512,0.044214628636837,0.290110529395656,43793 -4981.795501708984,1.4510157108306885,11298.666824102402,35208,0,11298.666824102402,0.9862593412399292,0.0469382219016552,0.2796964435428016,43793,16282.887593269348,0.9921960234642028,0.025136025622487,0.5219924353902371,0.9870204329490662,0.0439331047236919,0.2923408581900801,43793 -5079.614475250244,1.4840147495269775,11538.917538404465,35951,0,11538.917538404465,0.9862323999404908,0.0474622808396816,0.2696253351743681,43793,16621.011003494263,0.9921252131462096,0.0250017885118722,0.5476866748148286,0.9870662689208984,0.0444701761007308,0.2893286274359,43793 -5180.689831018448,1.5187046527862549,11779.103993415833,36702,0,11779.103993415833,0.9861683249473572,0.0470186807215213,0.2739539266716618,43793,16962.328053236008,0.992489457130432,0.0240619573742151,0.5486724927922778,0.987052857875824,0.0439661256968975,0.2908290268421575,43793 -5283.982161521912,1.553828477859497,12019.094329595566,37449,0,12019.094329595566,0.9862315058708192,0.0470658540725708,0.2740260003075513,43793,17305.666412353516,0.9924684762954712,0.0240710340440273,0.5645409408952597,0.9871202707290648,0.0441614389419555,0.2904596654457011,43793 -5384.793403148651,1.5871310234069824,12259.080934286118,38196,0,12259.080934286118,0.9862298369407654,0.0478082112967968,0.2720654290337988,43793,17646.517784833908,0.9922829866409302,0.0246264319866895,0.533122242307951,0.9870577454566956,0.0446517802774906,0.2929242246556005,43793 -5485.143780708313,1.6230242252349854,12499.300460100174,38952,0,12499.300460100174,0.9861696362495422,0.0471304431557655,0.2749916187429383,43793,17987.14372611046,0.9922152757644652,0.0249605812132358,0.5255462330459101,0.986975371837616,0.0441494546830654,0.2922564708468797,43793 -5582.348515987396,1.6569876670837402,12739.280121326448,39710,0,12739.280121326448,0.9861927628517152,0.0475735031068325,0.2701267080078491,43793,18324.382556915283,0.992161512374878,0.0250895153731107,0.5311539849439371,0.9870626330375672,0.0443945117294788,0.2939392503608551,43793 -5681.157953500748,1.6909382343292236,12979.375619888306,40468,0,12979.375619888306,0.9862176179885864,0.0470668077468872,0.2757043997765046,43793,18663.341647863388,0.9923699498176576,0.0245824754238128,0.5457115756354987,0.9870626330375672,0.0439769886434078,0.2900833526106087,43793 -5793.282270669937,1.726207971572876,13219.553693056108,41216,0,13219.553693056108,0.9861923456192015,0.0471550598740577,0.2781001258652558,43793,19015.69958639145,0.9924188256263732,0.0242145117372274,0.5426725337847191,0.9870244860649108,0.044272493571043,0.290808543338558,43793 -5894.937569618225,1.7861547470092771,13459.66164469719,41962,0,13459.66164469719,0.9861965775489808,0.0475791618227958,0.2741416501901633,43793,19357.54460954666,0.9924597144126892,0.0239833146333694,0.5485732064088567,0.9870756268501282,0.0444074310362339,0.2981488644825977,43793 -5998.332057952881,1.8204026222229004,13699.857156276705,42711,0,13699.857156276705,0.9863145351409912,0.0470815673470497,0.2796231683478654,43793,19701.188717842106,0.9925682544708252,0.0235717296600341,0.5703393827897597,0.9870427250862122,0.0443245023488998,0.2951427843127839,43793 -6099.189184188843,1.8538603782653809,13939.90434885025,43463,0,13939.90434885025,0.986214280128479,0.0480096787214279,0.2711062177589242,43793,20042.14708971977,0.9928522109985352,0.0226859804242849,0.587405296771667,0.987025260925293,0.0448318682610988,0.2926435323088024,43793 -6203.414827346802,1.889482498168945,14180.02293419838,44220,0,14180.02293419838,0.9862829446792604,0.0478932149708271,0.2773757847676095,43793,20386.546948194504,0.9930338859558104,0.0222247261554002,0.6030117376077448,0.9871328473091124,0.0447461754083633,0.2994731789222761,43793 -6301.98726439476,1.9259374141693115,14420.104619264604,44972,0,14420.104619264604,0.9862542748451232,0.0480829142034053,0.2777546667384631,43793,20725.25773501396,0.9930641651153564,0.0220830775797367,0.6060911317596043,0.987050473690033,0.0448730997741222,0.2943794044568053,43793 -6404.840629816055,1.961230993270874,14660.14389538765,45724,0,14660.14389538765,0.986276626586914,0.0473332963883876,0.2744227430076576,43793,21068.20579099655,0.993030309677124,0.0223230738192796,0.5868794150076238,0.9870455861091614,0.0445633865892887,0.2950067238532902,43793 -6509.072257757187,1.9975545406341555,14900.095911979675,46471,0,14900.095911979675,0.98615700006485,0.0479879155755043,0.2757587022034949,43793,21412.44612908364,0.9928067922592164,0.0229206141084432,0.5767878589178458,0.9869668483734132,0.0449565276503562,0.2892812244489832,43793 -6612.673577547073,2.0322277545928955,15140.2622089386,47223,0,15140.2622089386,0.986199915409088,0.0484224893152713,0.2699944692662746,43793,21756.26882982254,0.9927178621292114,0.0231426432728767,0.5750837991380785,0.9870151281356812,0.0452580600976944,0.2898598701086695,43793 -6714.644959449768,2.067057847976685,15380.479937553406,47963,0,15380.479937553406,0.9862197637557985,0.0482847988605499,0.2736227194885751,43793,22098.51292657852,0.9928799867630004,0.0225607007741928,0.5870601530022526,0.9871381521224976,0.0450849756598472,0.2975629992554879,43793 -6819.330502986908,2.104949951171875,15620.667268753052,48705,0,15620.667268753052,0.9861965775489808,0.0484400577843189,0.268240220401425,43793,22443.44396972656,0.992908239364624,0.0224849469959735,0.5866123590096386,0.9870979189872742,0.0450217388570308,0.2982286863077764,43793 -6918.238473892212,2.142441749572754,15860.90509223938,49454,0,15860.90509223938,0.9861738085746764,0.0482622310519218,0.2745028253297134,43793,22782.647591114044,0.9932429790496826,0.0214418750256299,0.6142716018141637,0.9870386719703674,0.0449598133563995,0.2990870103980558,43793 -7020.40039730072,2.1796648502349854,16101.125497341156,50198,0,16101.125497341156,0.98628968000412,0.0490044243633747,0.2771777891807733,43793,23125.086805820465,0.993171751499176,0.0213980432599782,0.6153558358088966,0.9871393442153932,0.0456449352204799,0.3004330684502952,43793 -7119.673288583756,2.215346097946167,16341.301450490952,50942,0,16341.301450490952,0.986243724822998,0.0487203709781169,0.2774862370226513,43793,23464.592199087143,0.9934930801391602,0.0205563474446535,0.6322967831703095,0.9871681928634644,0.045472003519535,0.2960142246072738,43793 -7223.04040813446,2.2525784969329834,16581.445430278778,51688,0,16581.445430278778,0.9861788749694824,0.0486474893987178,0.2764061666842379,43793,23808.160900831223,0.9936957955360411,0.0200053192675113,0.6529153022018896,0.9870293140411376,0.0454819798469543,0.2990586756813242,43793 -7327.888458013535,2.29858922958374,16821.530868291855,52442,0,16821.530868291855,0.9862736463546752,0.049133013933897,0.2778174709848799,43793,24153.16040802002,0.9935854077339172,0.0202750619500875,0.6460475938094383,0.987106442451477,0.0458005554974079,0.3005454513036418,43793 -7426.502200841904,2.33558201789856,17061.780789375305,53189,0,17061.780789375305,0.9861940741539,0.0492467880249023,0.2722941436238601,43793,24492.081380605698,0.9934902787208556,0.0206452459096908,0.6349567310737267,0.9871125817298888,0.0457358434796333,0.2995423135091453,43793 -7531.053488969803,2.381908178329468,17301.95036458969,53935,0,17301.95036458969,0.9861367344856262,0.0490672029554843,0.2779133849380018,43793,24836.86819219589,0.9932562112808228,0.0211492143571376,0.6126243232067949,0.9869729280471802,0.0458820350468158,0.2990511593052038,43793 -7633.643043994904,2.4187357425689697,17542.20799946785,54684,0,17542.20799946785,0.9862450361251832,0.0496200770139694,0.2787473506538628,43793,25179.772585392,0.9933159947395324,0.0209419373422861,0.6317724607619654,0.987216055393219,0.0461957119405269,0.2959647151213805,43793 -7737.169899225235,2.456860542297364,17782.284705638885,55434,0,17782.284705638885,0.9863107204437256,0.049470916390419,0.277157897267368,43793,25523.43450474739,0.9935021996498108,0.0203129090368747,0.6375259346868289,0.9871361255645752,0.0462104454636573,0.2995275510186174,43793 -7841.688732147217,2.4949471950531006,18022.405507564545,56177,0,18022.405507564545,0.9862247705459596,0.0495844110846519,0.2745376159234082,43793,25868.13215303421,0.9936658143997192,0.0198207218199968,0.6528672863038035,0.9871364831924438,0.0461660847067832,0.3025037933798346,43793 -7941.999848365784,2.5373575687408447,18262.658086299896,56924,0,18262.658086299896,0.9862277507781982,0.04944280534982681,0.27963990690568774,43793,26208.75817298889,0.9937217235565186,0.019551070407032967,0.6607373302507992,0.9871413707733154,0.04628141224384308,0.3001707093408799,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/measurements.csv deleted file mode 100644 index b1e3bb2b5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/measurements.csv +++ /dev/null @@ -1,655 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.6822484,0.73993963,,,,,,,,,,,,,,,,, -1,,,0.5093092918395996,0.7391443848609924,0.0232773786467988,0.5064646005630493,0.7422075867652893,0.0252931536418489,43793.0,0.5047284960746765,0.7438743114471436,0.0269269992061447,43793.0,13.580633878707886,117.30141496658324,13.580633878707886,103.7207329273224,0.0,0.0 -100,0.30844277,0.28630748,,,,,,,,,,,,,,,,, -200,0.09958939,0.10910574,,,,,,,,,,,,,,,,, -300,0.03186115,0.07073575,,,,,,,,,,,,,,,,, -400,0.035819672,0.05903004,,,,,,,,,,,,,,,,, -500,0.025683044,0.05906123,,,,,,,,,,,,,,,,, -600,0.017127005,0.05228337,,,,,,,,,,,,,,,,, -700,0.026687348,0.05071632,,,,,,,,,,,,,,,,, -750,,,0.986806094646454,0.0515382997691631,0.0542356702645199,0.9842458367347716,0.0615352354943752,0.0538402358803808,43793.0,0.9832671880722046,0.064820185303688,0.0545995080153342,43793.0,253.804012298584,461.0968973636627,253.804012298584,207.2520191669464,0.0207283496856689,0.0 -800,0.014552038,0.049253695,,,,,,,,,,,,,,,,, -900,0.016398085,0.049218137,,,,,,,,,,,,,,,,, -1000,0.017486507,0.05031804,,,,,,,,,,,,,,,,, -1100,0.017394537,0.049401537,,,,,,,,,,,,,,,,, -1200,0.015124977,0.05005229,,,,,,,,,,,,,,,,, -1300,0.012625506,0.046881717,,,,,,,,,,,,,,,,, -1400,0.014980021,0.049485665,,,,,,,,,,,,,,,,, -1495,,,0.9871118068695068,0.0480010174214839,0.083969882400915,0.9844828844070436,0.057401418685913,0.0852865402316337,43793.0,0.9834954738616944,0.0607746839523315,0.0840767077035999,43793.0,494.0419046878815,805.2264924049377,494.0419046878815,311.0966327190399,0.0467598438262939,0.0 -1500,0.045317743,0.05273185,,,,,,,,,,,,,,,,, -1600,0.021235483,0.04304016,,,,,,,,,,,,,,,,, -1700,0.01931009,0.05178459,,,,,,,,,,,,,,,,, -1800,0.018702602,0.05406435,,,,,,,,,,,,,,,,, -1900,0.01529531,0.049655497,,,,,,,,,,,,,,,,, -2000,0.021184761,0.054395437,,,,,,,,,,,,,,,,, -2100,0.012230952,0.051681604,,,,,,,,,,,,,,,,, -2200,0.024575382,0.040767968,,,,,,,,,,,,,,,,, -2243,,,0.987299919128418,0.0455262772738933,0.1279417429584762,0.9847471714019777,0.0546827651560306,0.1286468551470923,43793.0,0.9837182760238647,0.0578309632837772,0.123550536234461,43793.0,734.1674497127533,1150.4348917007446,734.1674497127533,416.1317677497864,0.0737733840942382,0.0 -2300,0.012708174,0.040023174,,,,,,,,,,,,,,,,, -2400,0.01902501,0.044929937,,,,,,,,,,,,,,,,, -2500,0.03033742,0.048666853,,,,,,,,,,,,,,,,, -2600,0.018537696,0.044210948,,,,,,,,,,,,,,,,, -2700,0.009316569,0.038751546,,,,,,,,,,,,,,,,, -2800,0.022883793,0.04277721,,,,,,,,,,,,,,,,, -2900,0.011581413,0.042596295,,,,,,,,,,,,,,,,, -2991,,,0.9879700541496276,0.0423735380172729,0.1602362821606831,0.9851559400558472,0.051908191293478,0.1475911315728504,43793.0,0.9842160940170288,0.0548467971384525,0.1484768175352116,43793.0,974.1777114868164,1492.042894601822,974.1777114868164,517.6807377338409,0.1014776229858398,0.0 -3000,0.015730416,0.039539386,,,,,,,,,,,,,,,,, -3100,0.015083541,0.042847123,,,,,,,,,,,,,,,,, -3200,0.014150587,0.045711953,,,,,,,,,,,,,,,,, -3300,0.012131287,0.03909364,,,,,,,,,,,,,,,,, -3400,0.013585965,0.04489986,,,,,,,,,,,,,,,,, -3500,0.013884426,0.046837892,,,,,,,,,,,,,,,,, -3600,0.013825706,0.04277332,,,,,,,,,,,,,,,,, -3700,0.017524034,0.040461954,,,,,,,,,,,,,,,,, -3736,,,0.9882538318634032,0.0412957370281219,0.1869888306182944,0.9851595759391784,0.0504656732082366,0.1660371829803102,43793.0,0.984274685382843,0.0530230216681957,0.1659362918288775,43793.0,1214.414056301117,1832.6644973754885,1214.414056301117,618.0172493457794,0.1294021606445312,0.0 -3800,0.014297964,0.042807247,,,,,,,,,,,,,,,,, -3900,0.011802253,0.04137255,,,,,,,,,,,,,,,,, -4000,0.015719038,0.03470147,,,,,,,,,,,,,,,,, -4100,0.01130149,0.038264032,,,,,,,,,,,,,,,,, -4200,0.012062432,0.038918566,,,,,,,,,,,,,,,,, -4300,0.012746202,0.038784463,,,,,,,,,,,,,,,,, -4400,0.023885569,0.042032864,,,,,,,,,,,,,,,,, -4488,,,0.9885920882225036,0.0392466597259044,0.2131411379152453,0.9855939149856568,0.0488959811627864,0.1844909876500332,43793.0,0.9847304224967957,0.051666285842657,0.1886703707765753,43793.0,1454.5042502880096,2181.4644510746,1454.5042502880096,726.6796119213104,0.1562621593475341,0.0 -4500,0.010781999,0.037847884,,,,,,,,,,,,,,,,, -4600,0.011751143,0.0418436,,,,,,,,,,,,,,,,, -4700,0.01247252,0.037130106,,,,,,,,,,,,,,,,, -4800,0.016952347,0.04149737,,,,,,,,,,,,,,,,, -4900,0.013207966,0.044041645,,,,,,,,,,,,,,,,, -5000,0.027507445,0.040087316,,,,,,,,,,,,,,,,, -5100,0.009887483,0.03719158,,,,,,,,,,,,,,,,, -5200,0.014089693,0.03477419,,,,,,,,,,,,,,,,, -5241,,,0.9889898300170898,0.0376451984047889,0.2380226997982301,0.9859138131141664,0.0475182346999645,0.2031912993116336,43793.0,0.9850661158561708,0.0502384155988693,0.2137931187659128,43793.0,1694.6122086048126,2530.281988620758,1694.6122086048126,835.3410527706146,0.1837186813354492,0.0 -5300,0.018807758,0.03633782,,,,,,,,,,,,,,,,, -5400,0.014906427,0.037370708,,,,,,,,,,,,,,,,, -5500,0.012689859,0.0420153,,,,,,,,,,,,,,,,, -5600,0.013369196,0.036086936,,,,,,,,,,,,,,,,, -5700,0.011339256,0.035903584,,,,,,,,,,,,,,,,, -5800,0.02773902,0.04072581,,,,,,,,,,,,,,,,, -5900,0.016711239,0.040635936,,,,,,,,,,,,,,,,, -5993,,,0.9889129996299744,0.0374315716326236,0.2512646496902923,0.9860774278640748,0.0465936623513698,0.2131519201854946,43793.0,0.9853339791297911,0.0492023862898349,0.2152423660366171,43793.0,1934.768192768097,2878.444193124771,1934.768192768097,943.2999262809752,0.2107102870941162,0.0 -6000,0.014020402,0.033879668,,,,,,,,,,,,,,,,, -6100,0.019684067,0.041867323,,,,,,,,,,,,,,,,, -6200,0.017333247,0.034078725,,,,,,,,,,,,,,,,, -6300,0.01199608,0.038318932,,,,,,,,,,,,,,,,, -6400,0.017048221,0.035908394,,,,,,,,,,,,,,,,, -6500,0.012007099,0.034975752,,,,,,,,,,,,,,,,, -6600,0.015465303,0.038455017,,,,,,,,,,,,,,,,, -6700,0.0107259,0.039480448,,,,,,,,,,,,,,,,, -6748,,,0.989357590675354,0.0359286963939666,0.2719629300015331,0.9862214922904968,0.0462537966668605,0.2262959319746798,43793.0,0.9853845238685608,0.0488400496542453,0.2211056133914604,43793.0,2174.874867916107,3218.5420083999634,2174.874867916107,1043.2435188293457,0.2378804683685302,0.0 -6800,0.012660387,0.033878066,,,,,,,,,,,,,,,,, -6900,0.015360426,0.034699712,,,,,,,,,,,,,,,,, -7000,0.019956972,0.035891194,,,,,,,,,,,,,,,,, -7100,0.009733217,0.031483468,,,,,,,,,,,,,,,,, -7200,0.017324362,0.035891186,,,,,,,,,,,,,,,,, -7300,0.013824456,0.034074333,,,,,,,,,,,,,,,,, -7400,0.016984496,0.037222374,,,,,,,,,,,,,,,,, -7493,,,0.9894558787345886,0.0353555269539356,0.3062496271384285,0.9862868785858154,0.0459942333400249,0.2279327528108939,43793.0,0.9854055643081664,0.0487098097801208,0.2348598280664689,43793.0,2415.1428520679474,3565.234797239304,2415.1428520679474,1149.6194591522217,0.2666077613830566,0.0 -7500,0.015678102,0.038314577,,,,,,,,,,,,,,,,, -7600,0.01442435,0.038405415,,,,,,,,,,,,,,,,, -7700,0.020938192,0.041911233,,,,,,,,,,,,,,,,, -7800,0.020674584,0.03446281,,,,,,,,,,,,,,,,, -7900,0.022332061,0.035681006,,,,,,,,,,,,,,,,, -8000,0.013232895,0.035782643,,,,,,,,,,,,,,,,, -8100,0.015640901,0.03823643,,,,,,,,,,,,,,,,, -8200,0.03296369,0.03506887,,,,,,,,,,,,,,,,, -8242,,,0.989901602268219,0.0337459854781627,0.3307906793603523,0.9863956570625304,0.0453762374818325,0.2373848047517522,43793.0,0.9856169819831848,0.0478177182376384,0.2419383466019386,43793.0,2655.087381839752,3910.5265514850616,2655.087381839752,1254.9102308750153,0.3024904727935791,0.0 -8300,0.023457436,0.041547094,,,,,,,,,,,,,,,,, -8400,0.015038191,0.032298516,,,,,,,,,,,,,,,,, -8500,0.022489497,0.03974186,,,,,,,,,,,,,,,,, -8600,0.018366668,0.035623066,,,,,,,,,,,,,,,,, -8700,0.017948441,0.036096796,,,,,,,,,,,,,,,,, -8800,0.017665071,0.03515437,,,,,,,,,,,,,,,,, -8900,0.028361924,0.03625804,,,,,,,,,,,,,,,,, -8996,,,0.99005526304245,0.0332898646593093,0.3302585783276252,0.9864711761474608,0.0451908968389034,0.2407431674503511,43793.0,0.9856843948364258,0.0478332489728927,0.2460210322111171,43793.0,2895.369548559189,4255.388598442078,2895.369548559189,1359.43941283226,0.332345962524414,0.0 -9000,0.022229966,0.034086257,,,,,,,,,,,,,,,,, -9100,0.017511437,0.040215693,,,,,,,,,,,,,,,,, -9200,0.02384682,0.03462075,,,,,,,,,,,,,,,,, -9300,0.024616167,0.037595697,,,,,,,,,,,,,,,,, -9400,0.021631803,0.03309302,,,,,,,,,,,,,,,,, -9500,0.028591845,0.034291945,,,,,,,,,,,,,,,,, -9600,0.020265907,0.03151688,,,,,,,,,,,,,,,,, -9700,0.018417066,0.032655127,,,,,,,,,,,,,,,,, -9748,,,0.990172564983368,0.0328486375510692,0.3479799884843341,0.9866335391998292,0.0446899943053722,0.2526990602814843,43793.0,0.9858174920082092,0.0472788773477077,0.2526541590940708,43793.0,3135.624180316925,4600.473388910294,3135.624180316925,1464.221135854721,0.3602263927459717,0.0 -9800,0.026276985,0.035623778,,,,,,,,,,,,,,,,, -9900,0.021263866,0.03260198,,,,,,,,,,,,,,,,, -10000,0.0301605,0.037271723,,,,,,,,,,,,,,,,, -10100,0.03065554,0.034744047,,,,,,,,,,,,,,,,, -10200,0.0247254,0.036793202,,,,,,,,,,,,,,,,, -10300,0.027084045,0.035451576,,,,,,,,,,,,,,,,, -10400,0.020922074,0.030597495,,,,,,,,,,,,,,,,, -10498,,,0.9901142716407776,0.033044509589672,0.342539331587014,0.9865771532058716,0.0446068868041038,0.2567563124017319,43793.0,0.9857223033905028,0.0474354512989521,0.2465764344200082,43793.0,3375.697806596756,4944.093198299408,3375.697806596756,1567.718020439148,0.3888866901397705,0.0 -10500,0.024734624,0.03413873,,,,,,,,,,,,,,,,, -10600,0.022202237,0.036343805,,,,,,,,,,,,,,,,, -10700,0.020666337,0.034782898,,,,,,,,,,,,,,,,, -10800,0.018994404,0.030404331,,,,,,,,,,,,,,,,, -10900,0.025415806,0.034351133,,,,,,,,,,,,,,,,, -11000,0.029336825,0.03507297,,,,,,,,,,,,,,,,, -11100,0.02131171,0.03211019,,,,,,,,,,,,,,,,, -11200,0.019789139,0.031937018,,,,,,,,,,,,,,,,, -11243,,,0.9903327226638794,0.0322994701564312,0.3561079984004907,0.9867005348205566,0.0442373193800449,0.2580287784778692,43793.0,0.9858258962631226,0.0469792783260345,0.2500176737745309,43793.0,3615.875280857086,5290.306088924408,3615.875280857086,1673.7032914161682,0.4190025329589844,0.0 -11300,0.026654745,0.03620753,,,,,,,,,,,,,,,,, -11400,0.018834312,0.034221623,,,,,,,,,,,,,,,,, -11500,0.022651054,0.032538917,,,,,,,,,,,,,,,,, -11600,0.020689528,0.030184805,,,,,,,,,,,,,,,,, -11700,0.027028527,0.03261209,,,,,,,,,,,,,,,,, -11800,0.029925037,0.035035945,,,,,,,,,,,,,,,,, -11900,0.037701618,0.036798514,,,,,,,,,,,,,,,,, -11987,,,0.990323007106781,0.0320759005844593,0.3612385092120623,0.9867829084396362,0.0440880171954631,0.2602869942875813,43793.0,0.9858882427215576,0.0469309277832508,0.2575036834951336,43793.0,3855.9695858955374,5630.856683969498,3855.9695858955374,1774.1074786186218,0.4494631290435791,0.0 -12000,0.029908132,0.035350062,,,,,,,,,,,,,,,,, -12100,0.025488388,0.030858424,,,,,,,,,,,,,,,,, -12200,0.026565908,0.03327901,,,,,,,,,,,,,,,,, -12300,0.027849605,0.03307687,,,,,,,,,,,,,,,,, -12400,0.027244093,0.030370992,,,,,,,,,,,,,,,,, -12500,0.025416335,0.03194479,,,,,,,,,,,,,,,,, -12600,0.023417778,0.03102662,,,,,,,,,,,,,,,,, -12700,0.035488725,0.030716533,,,,,,,,,,,,,,,,, -12742,,,0.9905083775520324,0.0314913280308246,0.3642103863357995,0.9867606163024902,0.0440910868346691,0.2677271230835614,43793.0,0.9859021306037904,0.047086376696825,0.2540120004735609,43793.0,4096.190707445145,5972.480631828308,4096.190707445145,1875.4612171649933,0.4779088497161865,0.0 -12800,0.026882282,0.03233419,,,,,,,,,,,,,,,,, -12900,0.03851984,0.03316256,,,,,,,,,,,,,,,,, -13000,0.029389895,0.037145633,,,,,,,,,,,,,,,,, -13100,0.02793384,0.03740531,,,,,,,,,,,,,,,,, -13200,0.03438821,0.030081885,,,,,,,,,,,,,,,,, -13300,0.032042213,0.03390949,,,,,,,,,,,,,,,,, -13400,0.031777907,0.034651466,,,,,,,,,,,,,,,,, -13482,,,0.9906244874000548,0.0309746414422988,0.389609197822945,0.9867532849311828,0.0439917035400867,0.2688608755476995,43793.0,0.9860032200813292,0.0466867499053478,0.2611335866362385,43793.0,4336.178166389465,6320.635596513748,4336.178166389465,1983.5776450634005,0.5073034763336182,0.0 -13500,0.0247527,0.030487714,,,,,,,,,,,,,,,,, -13600,0.05176085,0.032288056,,,,,,,,,,,,,,,,, -13700,0.026812127,0.031945802,,,,,,,,,,,,,,,,, -13800,0.027401093,0.030112214,,,,,,,,,,,,,,,,, -13900,0.03181436,0.029411295,,,,,,,,,,,,,,,,, -14000,0.032727268,0.032236516,,,,,,,,,,,,,,,,, -14100,0.03400262,0.027907906,,,,,,,,,,,,,,,,, -14200,0.032807034,0.028808603,,,,,,,,,,,,,,,,, -14236,,,0.990672528743744,0.0306298565119504,0.3995707839872688,0.986504077911377,0.0440920814871788,0.2636364530613028,43793.0,0.9858457446098328,0.0466507412493228,0.2567049480858492,43793.0,4576.319272518158,6664.528880357742,4576.319272518158,2087.2807455062866,0.5352981090545654,0.0 -14300,0.043393604,0.034094643,,,,,,,,,,,,,,,,, -14400,0.031511,0.032581802,,,,,,,,,,,,,,,,, -14500,0.030820448,0.028874539,,,,,,,,,,,,,,,,, -14600,0.040167585,0.033755478,,,,,,,,,,,,,,,,, -14700,0.03308061,0.0368372,,,,,,,,,,,,,,,,, -14800,0.03242112,0.032083604,,,,,,,,,,,,,,,,, -14900,0.046490856,0.0334266,,,,,,,,,,,,,,,,, -14991,,,0.9907553195953368,0.0302142277359962,0.4122109902822474,0.986757755279541,0.0445883460342884,0.2677942141421415,43793.0,0.9859707951545716,0.0472371615469455,0.2638035088881912,43793.0,4816.31423330307,7012.129799127579,4816.31423330307,2194.837212085724,0.5642292499542236,0.0 -15000,0.039510086,0.03140608,,,,,,,,,,,,,,,,, -15100,0.037890065,0.03563528,,,,,,,,,,,,,,,,, -15200,0.039620653,0.03678184,,,,,,,,,,,,,,,,, -15300,0.037212778,0.03147789,,,,,,,,,,,,,,,,, -15400,0.035727795,0.033020504,,,,,,,,,,,,,,,,, -15500,0.040449694,0.033627965,,,,,,,,,,,,,,,,, -15600,0.040272728,0.0320415,,,,,,,,,,,,,,,,, -15700,0.043359946,0.029165555,,,,,,,,,,,,,,,,, -15729,,,0.9911349415779114,0.029304539784789,0.430178080129186,0.9867480397224426,0.0441630408167839,0.2715012183204673,43793.0,0.9859463572502136,0.0468619801104068,0.2607283811938001,43793.0,5056.359443902969,7353.330956935883,5056.359443902969,2295.940425634384,0.5946774482727051,0.0 -15800,0.031945944,0.030599719,,,,,,,,,,,,,,,,, -15900,0.03314164,0.029055635,,,,,,,,,,,,,,,,, -16000,0.047757216,0.031848025,,,,,,,,,,,,,,,,, -16100,0.033160668,0.03149555,,,,,,,,,,,,,,,,, -16200,0.03520024,0.033864334,,,,,,,,,,,,,,,,, -16300,0.040976115,0.032651443,,,,,,,,,,,,,,,,, -16400,0.037085116,0.030178377,,,,,,,,,,,,,,,,, -16471,,,0.9909190535545348,0.0297707431018352,0.4091028998139538,0.9867866039276124,0.0440623573958873,0.2620690775753606,43793.0,0.9859893321990968,0.0467883460223674,0.2560449718034864,43793.0,5296.309223651886,7695.902285814285,5296.309223651886,2398.510513544082,0.6254196166992188,0.0 -16500,0.034525886,0.030428536,,,,,,,,,,,,,,,,, -16600,0.042689685,0.032810427,,,,,,,,,,,,,,,,, -16700,0.04033169,0.030190254,,,,,,,,,,,,,,,,, -16800,0.0417036,0.0328574,,,,,,,,,,,,,,,,, -16900,0.043017283,0.03392076,,,,,,,,,,,,,,,,, -17000,0.037993822,0.03173436,,,,,,,,,,,,,,,,, -17100,0.043980483,0.027037317,,,,,,,,,,,,,,,,, -17200,0.03883272,0.032048363,,,,,,,,,,,,,,,,, -17210,,,0.9909037947654724,0.0299575366079807,0.4211723660599582,0.9867788553237916,0.044044554233551,0.2709214835239603,43793.0,0.9859653115272522,0.0465751178562641,0.2614984133559347,43793.0,5536.299431324005,8038.960066795349,5536.299431324005,2501.5175173282623,0.6656818389892578,0.0 -17300,0.031591184,0.0314021,,,,,,,,,,,,,,,,, -17400,0.04436894,0.037511084,,,,,,,,,,,,,,,,, -17500,0.05516659,0.034260895,,,,,,,,,,,,,,,,, -17600,0.049132664,0.03157108,,,,,,,,,,,,,,,,, -17700,0.039938465,0.03041527,,,,,,,,,,,,,,,,, -17800,0.03786429,0.031174075,,,,,,,,,,,,,,,,, -17900,0.07138623,0.03251131,,,,,,,,,,,,,,,,, -17958,,,0.9907863736152648,0.030131721869111,0.4060708456613312,0.9868494868278505,0.0441721640527248,0.2710522605596167,43793.0,0.986120343208313,0.046714399009943,0.2669996794222807,43793.0,5776.414937496185,8387.298349618912,5776.414937496185,2609.689152956009,0.6960053443908691,0.0 -18000,0.04672356,0.035144698,,,,,,,,,,,,,,,,, -18100,0.060882956,0.031529933,,,,,,,,,,,,,,,,, -18200,0.03506509,0.029418662,,,,,,,,,,,,,,,,, -18300,0.048599243,0.033854797,,,,,,,,,,,,,,,,, -18400,0.040352903,0.029548978,,,,,,,,,,,,,,,,, -18500,0.038255088,0.031411324,,,,,,,,,,,,,,,,, -18600,0.040904734,0.03281376,,,,,,,,,,,,,,,,, -18700,0.04114768,0.034922104,,,,,,,,,,,,,,,,, -18709,,,0.9909077882766724,0.0298218131065368,0.4070056823467816,0.9868308305740356,0.0438988842070102,0.2765967366141773,43793.0,0.986135482788086,0.0465055629611015,0.2658340993299205,43793.0,6016.625653028488,8733.897119998932,6016.625653028488,2716.024253845215,0.7282171249389648,0.0 -18800,0.04601279,0.03161406,,,,,,,,,,,,,,,,, -18900,0.043327667,0.031443883,,,,,,,,,,,,,,,,, -19000,0.038540013,0.02928916,,,,,,,,,,,,,,,,, -19100,0.044607162,0.029524606,,,,,,,,,,,,,,,,, -19200,0.040439468,0.029806893,,,,,,,,,,,,,,,,, -19300,0.041294783,0.03249111,,,,,,,,,,,,,,,,, -19400,0.04432755,0.029800462,,,,,,,,,,,,,,,,, -19465,,,0.9910405278205872,0.029201403260231,0.4292364006707121,0.9868568181991576,0.0440428256988525,0.2791245527026243,43793.0,0.9860706329345704,0.0468316636979579,0.2596261907954604,43793.0,6256.626332998276,9076.95063853264,6256.626332998276,2819.025629043579,0.7590713500976562,0.0 -19500,0.043838896,0.031965636,,,,,,,,,,,,,,,,, -19600,0.043403745,0.02784872,,,,,,,,,,,,,,,,, -19700,0.037822273,0.029529233,,,,,,,,,,,,,,,,, -19800,0.05383147,0.031371098,,,,,,,,,,,,,,,,, -19900,0.046447404,0.030559443,,,,,,,,,,,,,,,,, -20000,0.050581217,0.032277867,,,,,,,,,,,,,,,,, -20100,0.044715077,0.028258119,,,,,,,,,,,,,,,,, -20200,0.043454222,0.03146793,,,,,,,,,,,,,,,,, -20215,,,0.9910680651664734,0.0291820932179689,0.439064296870055,0.986634373664856,0.0441146455705165,0.2813325085483499,43793.0,0.9858280420303344,0.0466070771217346,0.2692436956815187,43793.0,6496.625156402588,9420.829341888428,6496.625156402588,2922.849404811859,0.792198657989502,0.0 -20300,0.046328582,0.030662103,,,,,,,,,,,,,,,,, -20400,0.04510921,0.03528918,,,,,,,,,,,,,,,,, -20500,0.047417123,0.028880766,,,,,,,,,,,,,,,,, -20600,0.045258883,0.029707972,,,,,,,,,,,,,,,,, -20700,0.042219162,0.030604769,,,,,,,,,,,,,,,,, -20800,0.05037254,0.030006824,,,,,,,,,,,,,,,,, -20900,0.04416389,0.02761419,,,,,,,,,,,,,,,,, -20963,,,0.9911518692970276,0.0288035813719034,0.4462325316914583,0.9868718385696412,0.0439016968011856,0.2812189409328452,43793.0,0.9860310554504396,0.0468637309968471,0.2672742096154502,43793.0,6736.826936483383,9760.718255281448,6736.826936483383,3022.484857082367,0.8230326175689697,0.0 -21000,0.046132587,0.03148344,,,,,,,,,,,,,,,,, -21100,0.05635858,0.033211607,,,,,,,,,,,,,,,,, -21200,0.046749447,0.030747024,,,,,,,,,,,,,,,,, -21300,0.05492916,0.029451838,,,,,,,,,,,,,,,,, -21400,0.049497508,0.031212997,,,,,,,,,,,,,,,,, -21500,0.04921113,0.030011514,,,,,,,,,,,,,,,,, -21600,0.05117053,0.031286985,,,,,,,,,,,,,,,,, -21700,0.044834808,0.028213184,,,,,,,,,,,,,,,,, -21713,,,0.9913423657417296,0.0279766172170639,0.4702118638735022,0.9869785904884338,0.0438735596835613,0.285625691321616,43793.0,0.9860352277755736,0.0469725541770458,0.2679272339349347,43793.0,6977.031692266464,10102.648232460022,6977.031692266464,3124.1582732200623,0.8541834354400635,0.0 -21800,0.043387312,0.030039826,,,,,,,,,,,,,,,,, -21900,0.043853257,0.029257722,,,,,,,,,,,,,,,,, -22000,0.05247441,0.027562749,,,,,,,,,,,,,,,,, -22100,0.062094178,0.03285776,,,,,,,,,,,,,,,,, -22200,0.05163566,0.033869904,,,,,,,,,,,,,,,,, -22300,0.04839327,0.03166308,,,,,,,,,,,,,,,,, -22400,0.06245857,0.03206363,,,,,,,,,,,,,,,,, -22456,,,0.9913394451141356,0.0280387885868549,0.4604188786782695,0.9869635701179504,0.0437427349388599,0.2827766819212199,43793.0,0.9862037301063538,0.0466453246772289,0.2640839126761976,43793.0,7217.031193494797,10448.873989105225,7217.031193494797,3230.3336248397827,0.8844475746154785,0.0 -22500,0.055678997,0.034198236,,,,,,,,,,,,,,,,, -22600,0.048818514,0.029576601,,,,,,,,,,,,,,,,, -22700,0.047712877,0.02889804,,,,,,,,,,,,,,,,, -22800,0.052007586,0.031564146,,,,,,,,,,,,,,,,, -22900,0.048970636,0.027376888,,,,,,,,,,,,,,,,, -23000,0.049410872,0.03300907,,,,,,,,,,,,,,,,, -23100,0.045668498,0.02856014,,,,,,,,,,,,,,,,, -23200,0.05610357,0.033707824,,,,,,,,,,,,,,,,, -23204,,,0.9916256070137024,0.0272593423724174,0.4747280372458277,0.9869043231010436,0.0440606139600276,0.2845651073448388,43793.0,0.986123263835907,0.0469043962657451,0.2627899405295638,43793.0,7457.124490976334,10790.28831577301,7457.124490976334,3331.6023383140564,0.9160325527191162,0.0 -23300,0.053384878,0.030686453,,,,,,,,,,,,,,,,, -23400,0.045682125,0.02716943,,,,,,,,,,,,,,,,, -23500,0.064379014,0.02841579,,,,,,,,,,,,,,,,, -23600,0.053989023,0.027818521,,,,,,,,,,,,,,,,, -23700,0.05759028,0.029839274,,,,,,,,,,,,,,,,, -23800,0.047486704,0.03293434,,,,,,,,,,,,,,,,, -23900,0.05377316,0.030040676,,,,,,,,,,,,,,,,, -23956,,,0.9913517236709596,0.0281870812177658,0.4601179168266128,0.9868783354759216,0.0442613177001476,0.280703321938136,43793.0,0.985987663269043,0.0470348894596099,0.2644944977788918,43793.0,7697.121632099152,11135.16293668747,7697.121632099152,3436.4192943573,0.9561116695404052,0.0 -24000,0.064726,0.02887722,,,,,,,,,,,,,,,,, -24100,0.06974245,0.03700547,,,,,,,,,,,,,,,,, -24200,0.050565865,0.030583432,,,,,,,,,,,,,,,,, -24300,0.055739176,0.028263258,,,,,,,,,,,,,,,,, -24400,0.059408978,0.033353187,,,,,,,,,,,,,,,,, -24500,0.058616977,0.0320942,,,,,,,,,,,,,,,,, -24600,0.05108016,0.029253555,,,,,,,,,,,,,,,,, -24700,0.056751937,0.030768268,,,,,,,,,,,,,,,,, -24711,,,0.9913532137870787,0.0281356312334537,0.4598455509794903,0.9870455861091614,0.0438859760761261,0.288337047625382,43793.0,0.9862121343612672,0.0468297265470027,0.2668786160394751,43793.0,7937.132448196411,11479.138907432556,7937.132448196411,3540.332468509674,0.987633228302002,0.0 -24800,0.053022057,0.033792827,,,,,,,,,,,,,,,,, -24900,0.052693333,0.030715734,,,,,,,,,,,,,,,,, -25000,0.04973752,0.02961099,,,,,,,,,,,,,,,,, -25100,0.08454163,0.032678377,,,,,,,,,,,,,,,,, -25200,0.058327857,0.029224873,,,,,,,,,,,,,,,,, -25300,0.055105843,0.03014172,,,,,,,,,,,,,,,,, -25400,0.058627035,0.030497625,,,,,,,,,,,,,,,,, -25462,,,0.991429328918457,0.0277799628674983,0.4671925935508493,0.9870439767837524,0.0436918511986732,0.2888897962900462,43793.0,0.9861923456192015,0.046654887497425,0.2683779409036351,43793.0,8177.254436969757,11820.850269317629,8177.254436969757,3641.869652032852,1.0197508335113523,0.0 -25500,0.056948636,0.029343,,,,,,,,,,,,,,,,, -25600,0.04861697,0.029360058,,,,,,,,,,,,,,,,, -25700,0.049492,0.028382536,,,,,,,,,,,,,,,,, -25800,0.05645642,0.030827956,,,,,,,,,,,,,,,,, -25900,0.04774199,0.030548591,,,,,,,,,,,,,,,,, -26000,0.0580501,0.03126267,,,,,,,,,,,,,,,,, -26100,0.06387335,0.030787576,,,,,,,,,,,,,,,,, -26200,0.056360397,0.032452125,,,,,,,,,,,,,,,,, -26214,,,0.991330623626709,0.028105879202485,0.4568819825250968,0.986963987350464,0.043882068246603,0.2856185672583953,43793.0,0.9861291646957396,0.0465985499322414,0.2675201139818264,43793.0,8417.495745420456,12164.87309885025,8417.495745420456,3745.5998179912567,1.0510282516479492,0.0 -26300,0.054509144,0.029430052,,,,,,,,,,,,,,,,, -26400,0.06344946,0.03217152,,,,,,,,,,,,,,,,, -26500,0.053895585,0.027117146,,,,,,,,,,,,,,,,, -26600,0.062319648,0.026693786,,,,,,,,,,,,,,,,, -26700,0.055914227,0.03016545,,,,,,,,,,,,,,,,, -26800,0.052340195,0.027716432,,,,,,,,,,,,,,,,, -26900,0.056075037,0.03082601,,,,,,,,,,,,,,,,, -26964,,,0.9913989305496216,0.0279050935059785,0.4681628812771156,0.9868897199630736,0.043874591588974,0.2818689651514881,43793.0,0.986011266708374,0.0466346591711044,0.2703870400534875,43793.0,8657.630143404007,12506.092809200289,8657.630143404007,3846.632305622101,1.0833914279937744,0.0 -27000,0.07183692,0.030243605,,,,,,,,,,,,,,,,, -27100,0.057274837,0.027670687,,,,,,,,,,,,,,,,, -27200,0.05594057,0.031422783,,,,,,,,,,,,,,,,, -27300,0.06084533,0.028777186,,,,,,,,,,,,,,,,, -27400,0.05308084,0.030040933,,,,,,,,,,,,,,,,, -27500,0.067217514,0.0267337,,,,,,,,,,,,,,,,, -27600,0.06480537,0.031929906,,,,,,,,,,,,,,,,, -27700,0.061667293,0.032063644,,,,,,,,,,,,,,,,, -27709,,,0.9916355013847352,0.0270985979586839,0.4803630402502638,0.986885666847229,0.043906919658184,0.286745084223735,43793.0,0.9860866665840148,0.0466265827417373,0.2637615840012572,43793.0,8897.8357629776,12849.57141637802,8897.8357629776,3949.8536190986633,1.1148099899291992,0.0 -27800,0.0574411,0.030113826,,,,,,,,,,,,,,,,, -27900,0.06345062,0.032355648,,,,,,,,,,,,,,,,, -28000,0.051000763,0.028721128,,,,,,,,,,,,,,,,, -28100,0.057133827,0.030210365,,,,,,,,,,,,,,,,, -28200,0.05375867,0.031111704,,,,,,,,,,,,,,,,, -28300,0.051738307,0.029699504,,,,,,,,,,,,,,,,, -28400,0.049851056,0.028622994,,,,,,,,,,,,,,,,, -28458,,,0.9917035102844238,0.0267012659460306,0.4912260841639426,0.9870139360427856,0.0437790490686893,0.2853140712726124,43793.0,0.9860950708389282,0.0467545948922634,0.2691075067365001,43793.0,9137.999815702438,13191.534855604172,9137.999815702438,4051.599108457565,1.148134708404541,0.0 -28500,0.07049203,0.028957076,,,,,,,,,,,,,,,,, -28600,0.051427636,0.027819537,,,,,,,,,,,,,,,,, -28700,0.05645692,0.027069671,,,,,,,,,,,,,,,,, -28800,0.056768205,0.029980445,,,,,,,,,,,,,,,,, -28900,0.052854694,0.0305609,,,,,,,,,,,,,,,,, -29000,0.057415884,0.031719875,,,,,,,,,,,,,,,,, -29100,0.053189833,0.028677963,,,,,,,,,,,,,,,,, -29200,0.05288439,0.025271479,,,,,,,,,,,,,,,,, -29204,,,0.9919525980949402,0.0259919613599777,0.5057251405123364,0.9870707392692566,0.0437920428812503,0.290256895570826,43793.0,0.986171305179596,0.0467007234692573,0.2694446762523158,43793.0,9378.210587263107,13533.442611694336,9378.210587263107,4153.242586612701,1.1815519332885742,0.0 -29300,0.07286534,0.030497892,,,,,,,,,,,,,,,,, -29400,0.05117185,0.027845873,,,,,,,,,,,,,,,,, -29500,0.05934758,0.030462934,,,,,,,,,,,,,,,,, -29600,0.053894293,0.027755165,,,,,,,,,,,,,,,,, -29700,0.057707027,0.028382195,,,,,,,,,,,,,,,,, -29800,0.05744629,0.026204674,,,,,,,,,,,,,,,,, -29900,0.05590501,0.029009914,,,,,,,,,,,,,,,,, -29961,,,0.991941213607788,0.0258173458278179,0.5112644960667561,0.9870638251304626,0.0439641214907169,0.289960550844474,43793.0,0.9861283302307128,0.0471483767032623,0.2724732373198176,43793.0,9618.264122486116,13880.021289348602,9618.264122486116,4259.714983224869,1.2135634422302246,0.0 -30000,0.055984847,0.031070786,,,,,,,,,,,,,,,,, -30100,0.055826604,0.030017078,,,,,,,,,,,,,,,,, -30200,0.086276874,0.03084483,,,,,,,,,,,,,,,,, -30300,0.058087755,0.030613981,,,,,,,,,,,,,,,,, -30400,0.05661787,0.026540369,,,,,,,,,,,,,,,,, -30500,0.0749736,0.026772011,,,,,,,,,,,,,,,,, -30600,0.057853103,0.029380154,,,,,,,,,,,,,,,,, -30700,0.059122853,0.031447444,,,,,,,,,,,,,,,,, -30714,,,0.9917917847633362,0.0263051856309175,0.5038807477785993,0.9869644045829772,0.0440908707678318,0.2894453883646158,43793.0,0.9860677123069764,0.0471735931932926,0.2675294987951833,43793.0,9858.291661024094,14224.340274810793,9858.291661024094,4363.952226638794,1.246518850326538,0.0 -30800,0.062981755,0.02752997,,,,,,,,,,,,,,,,, -30900,0.056720268,0.025092024,,,,,,,,,,,,,,,,, -31000,0.05909593,0.028413985,,,,,,,,,,,,,,,,, -31100,0.054236393,0.025920793,,,,,,,,,,,,,,,,, -31200,0.057164833,0.028268881,,,,,,,,,,,,,,,,, -31300,0.06858953,0.031529017,,,,,,,,,,,,,,,,, -31400,0.06995159,0.029131351,,,,,,,,,,,,,,,,, -31458,,,0.9918238520622252,0.0264401771128177,0.4987664840541888,0.9869713187217712,0.0437646061182022,0.2861938749291255,43793.0,0.9861093759536744,0.0466942265629768,0.2704962118586615,43793.0,10098.319123268127,14571.425382375715,10098.319123268127,4470.956292152405,1.2795672416687012,0.0 -31500,0.056747798,0.029041015,,,,,,,,,,,,,,,,, -31600,0.069274,0.028170213,,,,,,,,,,,,,,,,, -31700,0.07040847,0.029127127,,,,,,,,,,,,,,,,, -31800,0.06871668,0.029608054,,,,,,,,,,,,,,,,, -31900,0.05992817,0.030858543,,,,,,,,,,,,,,,,, -32000,0.05820417,0.02793283,,,,,,,,,,,,,,,,, -32100,0.05997461,0.028870754,,,,,,,,,,,,,,,,, -32200,0.059355773,0.030467479,,,,,,,,,,,,,,,,, -32201,,,0.9917375445365906,0.0267392918467521,0.4871506207529771,0.9870427250862122,0.0438420996069908,0.2848324278871966,43793.0,0.9861093759536744,0.0469667054712772,0.2685648392049272,43793.0,10338.541726350784,14915.281131029127,10338.541726350784,4574.533927679062,1.313849925994873,0.0 -32300,0.06402548,0.028436173,,,,,,,,,,,,,,,,, -32400,0.060049288,0.028796073,,,,,,,,,,,,,,,,, -32500,0.053771365,0.026189405,,,,,,,,,,,,,,,,, -32600,0.077253714,0.030509101,,,,,,,,,,,,,,,,, -32700,0.066636845,0.02644273,,,,,,,,,,,,,,,,, -32800,0.05933928,0.027303597,,,,,,,,,,,,,,,,, -32900,0.06772022,0.030726897,,,,,,,,,,,,,,,,, -32947,,,0.9917527437210084,0.0266880467534065,0.4959007838136359,0.9871665239334106,0.0442121364176273,0.2930568691400217,43793.0,0.9862593412399292,0.0471234656870365,0.2710033452222284,43793.0,10578.52193403244,15257.13302230835,10578.52193403244,4676.352701425552,1.3461999893188477,0.0 -33000,0.077635,0.024311552,,,,,,,,,,,,,,,,, -33100,0.06348241,0.029462583,,,,,,,,,,,,,,,,, -33200,0.069271594,0.028138066,,,,,,,,,,,,,,,,, -33300,0.060441464,0.029891603,,,,,,,,,,,,,,,,, -33400,0.055493873,0.027252192,,,,,,,,,,,,,,,,, -33500,0.06735158,0.025626674,,,,,,,,,,,,,,,,, -33600,0.06566087,0.028748032,,,,,,,,,,,,,,,,, -33700,0.055486806,0.02677941,,,,,,,,,,,,,,,,, -33704,,,0.9918839931488036,0.0260620433837175,0.5066806571838988,0.9870265126228333,0.0439733080565929,0.2861925315494497,43793.0,0.9861788749694824,0.046940054744482,0.271115682743845,43793.0,10818.582146406174,15598.397938489914,10818.582146406174,4777.501918077469,1.3808777332305908,0.0 -33800,0.0875117,0.029131891,,,,,,,,,,,,,,,,, -33900,0.06529258,0.028649706,,,,,,,,,,,,,,,,, -34000,0.063326955,0.02988456,,,,,,,,,,,,,,,,, -34100,0.070909545,0.030518716,,,,,,,,,,,,,,,,, -34200,0.06445044,0.026602618,,,,,,,,,,,,,,,,, -34300,0.07660071,0.030540854,,,,,,,,,,,,,,,,, -34400,0.064347155,0.03219216,,,,,,,,,,,,,,,,, -34457,,,0.991776704788208,0.02636144682765,0.5051048917352821,0.9870297312736512,0.044214628636837,0.290110529395656,43793.0,0.98611319065094,0.0471945516765117,0.2731183121212644,43793.0,11058.546568632126,15939.182308912275,11058.546568632126,4878.264587640762,1.417754888534546,0.0 -34500,0.05551472,0.026901716,,,,,,,,,,,,,,,,, -34600,0.069195345,0.030536152,,,,,,,,,,,,,,,,, -34700,0.061425477,0.026991377,,,,,,,,,,,,,,,,, -34800,0.064079195,0.031319413,,,,,,,,,,,,,,,,, -34900,0.06083958,0.026714494,,,,,,,,,,,,,,,,, -35000,0.06296641,0.027765265,,,,,,,,,,,,,,,,, -35100,0.06187138,0.023976406,,,,,,,,,,,,,,,,, -35200,0.06914677,0.029386435,,,,,,,,,,,,,,,,, -35208,,,0.9921960234642028,0.025136025622487,0.5219924353902371,0.9870204329490662,0.0439331047236919,0.2923408581900801,43793.0,0.9862593412399292,0.0469382219016552,0.2796964435428016,43793.0,11298.666824102402,16282.887593269348,11298.666824102402,4981.795501708984,1.4510157108306885,0.0 -35300,0.05411381,0.025608985,,,,,,,,,,,,,,,,, -35400,0.058254484,0.027344866,,,,,,,,,,,,,,,,, -35500,0.065922275,0.027868865,,,,,,,,,,,,,,,,, -35600,0.06485946,0.028538516,,,,,,,,,,,,,,,,, -35700,0.07348854,0.027163746,,,,,,,,,,,,,,,,, -35800,0.07541028,0.02947459,,,,,,,,,,,,,,,,, -35900,0.0670903,0.027843488,,,,,,,,,,,,,,,,, -35951,,,0.9921252131462096,0.0250017885118722,0.5476866748148286,0.9870662689208984,0.0444701761007308,0.2893286274359,43793.0,0.9862323999404908,0.0474622808396816,0.2696253351743681,43793.0,11538.917538404465,16621.011003494263,11538.917538404465,5079.614475250244,1.4840147495269775,0.0 -36000,0.06976856,0.026738605,,,,,,,,,,,,,,,,, -36100,0.073348805,0.027580215,,,,,,,,,,,,,,,,, -36200,0.07339769,0.029168285,,,,,,,,,,,,,,,,, -36300,0.06067125,0.025429625,,,,,,,,,,,,,,,,, -36400,0.06961791,0.027413728,,,,,,,,,,,,,,,,, -36500,0.06291839,0.02362083,,,,,,,,,,,,,,,,, -36600,0.06763736,0.027406367,,,,,,,,,,,,,,,,, -36700,0.06407026,0.025893724,,,,,,,,,,,,,,,,, -36702,,,0.992489457130432,0.0240619573742151,0.5486724927922778,0.987052857875824,0.0439661256968975,0.2908290268421575,43793.0,0.9861683249473572,0.0470186807215213,0.2739539266716618,43793.0,11779.103993415833,16962.328053236008,11779.103993415833,5180.689831018448,1.5187046527862549,0.0 -36800,0.06659346,0.02632022,,,,,,,,,,,,,,,,, -36900,0.0694293,0.026137883,,,,,,,,,,,,,,,,, -37000,0.060128678,0.026001716,,,,,,,,,,,,,,,,, -37100,0.06277629,0.026757907,,,,,,,,,,,,,,,,, -37200,0.07161296,0.030231984,,,,,,,,,,,,,,,,, -37300,0.06749673,0.028728174,,,,,,,,,,,,,,,,, -37400,0.06436504,0.026870443,,,,,,,,,,,,,,,,, -37449,,,0.9924684762954712,0.0240710340440273,0.5645409408952597,0.9871202707290648,0.0441614389419555,0.2904596654457011,43793.0,0.9862315058708192,0.0470658540725708,0.2740260003075513,43793.0,12019.094329595566,17305.666412353516,12019.094329595566,5283.982161521912,1.553828477859497,0.0 -37500,0.06794818,0.029111272,,,,,,,,,,,,,,,,, -37600,0.06602362,0.030307882,,,,,,,,,,,,,,,,, -37700,0.057151426,0.025282366,,,,,,,,,,,,,,,,, -37800,0.0696695,0.025926195,,,,,,,,,,,,,,,,, -37900,0.07566925,0.025805384,,,,,,,,,,,,,,,,, -38000,0.07166633,0.025664268,,,,,,,,,,,,,,,,, -38100,0.0681658,0.025199553,,,,,,,,,,,,,,,,, -38196,,,0.9922829866409302,0.0246264319866895,0.533122242307951,0.9870577454566956,0.0446517802774906,0.2929242246556005,43793.0,0.9862298369407654,0.0478082112967968,0.2720654290337988,43793.0,12259.080934286118,17646.517784833908,12259.080934286118,5384.793403148651,1.5871310234069824,0.0 -38200,0.06580879,0.030931814,,,,,,,,,,,,,,,,, -38300,0.070083894,0.025251163,,,,,,,,,,,,,,,,, -38400,0.06310825,0.026246801,,,,,,,,,,,,,,,,, -38500,0.0679376,0.028901437,,,,,,,,,,,,,,,,, -38600,0.082156666,0.025679672,,,,,,,,,,,,,,,,, -38700,0.07488025,0.026290838,,,,,,,,,,,,,,,,, -38800,0.083611935,0.026672611,,,,,,,,,,,,,,,,, -38900,0.06785419,0.027449824,,,,,,,,,,,,,,,,, -38952,,,0.9922152757644652,0.0249605812132358,0.5255462330459101,0.986975371837616,0.0441494546830654,0.2922564708468797,43793.0,0.9861696362495422,0.0471304431557655,0.2749916187429383,43793.0,12499.300460100174,17987.14372611046,12499.300460100174,5485.143780708313,1.6230242252349854,0.0 -39000,0.092780225,0.03166115,,,,,,,,,,,,,,,,, -39100,0.066226296,0.028534492,,,,,,,,,,,,,,,,, -39200,0.063760966,0.024286423,,,,,,,,,,,,,,,,, -39300,0.07117394,0.025754465,,,,,,,,,,,,,,,,, -39400,0.064062834,0.02675267,,,,,,,,,,,,,,,,, -39500,0.06881362,0.028696889,,,,,,,,,,,,,,,,, -39600,0.09631581,0.027571805,,,,,,,,,,,,,,,,, -39700,0.0735291,0.027416607,,,,,,,,,,,,,,,,, -39710,,,0.992161512374878,0.0250895153731107,0.5311539849439371,0.9870626330375672,0.0443945117294788,0.2939392503608551,43793.0,0.9861927628517152,0.0475735031068325,0.2701267080078491,43793.0,12739.280121326448,18324.382556915283,12739.280121326448,5582.348515987396,1.6569876670837402,0.0 -39800,0.07274482,0.0261799,,,,,,,,,,,,,,,,, -39900,0.067427434,0.028329642,,,,,,,,,,,,,,,,, -40000,0.07186345,0.027987802,,,,,,,,,,,,,,,,, -40100,0.060892444,0.026640834,,,,,,,,,,,,,,,,, -40200,0.07147802,0.02628702,,,,,,,,,,,,,,,,, -40300,0.07890315,0.029815111,,,,,,,,,,,,,,,,, -40400,0.074386865,0.026854318,,,,,,,,,,,,,,,,, -40468,,,0.9923699498176576,0.0245824754238128,0.5457115756354987,0.9870626330375672,0.0439769886434078,0.2900833526106087,43793.0,0.9862176179885864,0.0470668077468872,0.2757043997765046,43793.0,12979.375619888306,18663.341647863388,12979.375619888306,5681.157953500748,1.6909382343292236,0.0 -40500,0.06691908,0.026064003,,,,,,,,,,,,,,,,, -40600,0.06598654,0.026953364,,,,,,,,,,,,,,,,, -40700,0.07816751,0.030013863,,,,,,,,,,,,,,,,, -40800,0.082806416,0.026569413,,,,,,,,,,,,,,,,, -40900,0.06750578,0.022694089,,,,,,,,,,,,,,,,, -41000,0.07962176,0.031869605,,,,,,,,,,,,,,,,, -41100,0.07668267,0.028523713,,,,,,,,,,,,,,,,, -41200,0.06983754,0.028388571,,,,,,,,,,,,,,,,, -41216,,,0.9924188256263732,0.0242145117372274,0.5426725337847191,0.9870244860649108,0.044272493571043,0.290808543338558,43793.0,0.9861923456192015,0.0471550598740577,0.2781001258652558,43793.0,13219.553693056108,19015.69958639145,13219.553693056108,5793.282270669937,1.726207971572876,0.0 -41300,0.07913593,0.026259815,,,,,,,,,,,,,,,,, -41400,0.092110164,0.029391043,,,,,,,,,,,,,,,,, -41500,0.083513096,0.027997402,,,,,,,,,,,,,,,,, -41600,0.07588852,0.025529025,,,,,,,,,,,,,,,,, -41700,0.07047263,0.027567323,,,,,,,,,,,,,,,,, -41800,0.078793496,0.02754391,,,,,,,,,,,,,,,,, -41900,0.07576817,0.026424406,,,,,,,,,,,,,,,,, -41962,,,0.9924597144126892,0.0239833146333694,0.5485732064088567,0.9870756268501282,0.0444074310362339,0.2981488644825977,43793.0,0.9861965775489808,0.0475791618227958,0.2741416501901633,43793.0,13459.66164469719,19357.54460954666,13459.66164469719,5894.937569618225,1.7861547470092771,0.0 -42000,0.06353391,0.025836399,,,,,,,,,,,,,,,,, -42100,0.07884145,0.027798824,,,,,,,,,,,,,,,,, -42200,0.07428237,0.027912611,,,,,,,,,,,,,,,,, -42300,0.07509798,0.026644051,,,,,,,,,,,,,,,,, -42400,0.07509032,0.027887434,,,,,,,,,,,,,,,,, -42500,0.073301084,0.028185068,,,,,,,,,,,,,,,,, -42600,0.06582553,0.026503667,,,,,,,,,,,,,,,,, -42700,0.086861655,0.026214419,,,,,,,,,,,,,,,,, -42711,,,0.9925682544708252,0.0235717296600341,0.5703393827897597,0.9870427250862122,0.0443245023488998,0.2951427843127839,43793.0,0.9863145351409912,0.0470815673470497,0.2796231683478654,43793.0,13699.857156276705,19701.188717842106,13699.857156276705,5998.332057952881,1.8204026222229004,0.0 -42800,0.06935648,0.022428958,,,,,,,,,,,,,,,,, -42900,0.08658065,0.027587404,,,,,,,,,,,,,,,,, -43000,0.06745556,0.024797214,,,,,,,,,,,,,,,,, -43100,0.081397586,0.028676622,,,,,,,,,,,,,,,,, -43200,0.076011516,0.026472531,,,,,,,,,,,,,,,,, -43300,0.0728008,0.024294697,,,,,,,,,,,,,,,,, -43400,0.07227451,0.027315838,,,,,,,,,,,,,,,,, -43463,,,0.9928522109985352,0.0226859804242849,0.587405296771667,0.987025260925293,0.0448318682610988,0.2926435323088024,43793.0,0.986214280128479,0.0480096787214279,0.2711062177589242,43793.0,13939.90434885025,20042.14708971977,13939.90434885025,6099.189184188843,1.8538603782653809,0.0 -43500,0.07497702,0.024678476,,,,,,,,,,,,,,,,, -43600,0.07879367,0.029805632,,,,,,,,,,,,,,,,, -43700,0.07191937,0.02673936,,,,,,,,,,,,,,,,, -43800,0.07896849,0.025939135,,,,,,,,,,,,,,,,, -43900,0.08844229,0.027509864,,,,,,,,,,,,,,,,, -44000,0.07632587,0.025707796,,,,,,,,,,,,,,,,, -44100,0.08802839,0.026451351,,,,,,,,,,,,,,,,, -44200,0.08867644,0.028045334,,,,,,,,,,,,,,,,, -44220,,,0.9930338859558104,0.0222247261554002,0.6030117376077448,0.9871328473091124,0.0447461754083633,0.2994731789222761,43793.0,0.9862829446792604,0.0478932149708271,0.2773757847676095,43793.0,14180.02293419838,20386.546948194504,14180.02293419838,6203.414827346802,1.889482498168945,0.0 -44300,0.08614086,0.027977679,,,,,,,,,,,,,,,,, -44400,0.07285104,0.023583362,,,,,,,,,,,,,,,,, -44500,0.08160515,0.028883694,,,,,,,,,,,,,,,,, -44600,0.0856614,0.025657078,,,,,,,,,,,,,,,,, -44700,0.080116056,0.024456445,,,,,,,,,,,,,,,,, -44800,0.08595193,0.028803587,,,,,,,,,,,,,,,,, -44900,0.09775485,0.027151734,,,,,,,,,,,,,,,,, -44972,,,0.9930641651153564,0.0220830775797367,0.6060911317596043,0.987050473690033,0.0448730997741222,0.2943794044568053,43793.0,0.9862542748451232,0.0480829142034053,0.2777546667384631,43793.0,14420.104619264604,20725.25773501396,14420.104619264604,6301.98726439476,1.9259374141693115,0.0 -45000,0.07478112,0.023919506,,,,,,,,,,,,,,,,, -45100,0.070159845,0.02453021,,,,,,,,,,,,,,,,, -45200,0.08276482,0.026516462,,,,,,,,,,,,,,,,, -45300,0.074608795,0.024107404,,,,,,,,,,,,,,,,, -45400,0.08944993,0.026837371,,,,,,,,,,,,,,,,, -45500,0.07935729,0.02475977,,,,,,,,,,,,,,,,, -45600,0.09491152,0.031870577,,,,,,,,,,,,,,,,, -45700,0.09253866,0.024343584,,,,,,,,,,,,,,,,, -45724,,,0.993030309677124,0.0223230738192796,0.5868794150076238,0.9870455861091614,0.0445633865892887,0.2950067238532902,43793.0,0.986276626586914,0.0473332963883876,0.2744227430076576,43793.0,14660.14389538765,21068.20579099655,14660.14389538765,6404.840629816055,1.961230993270874,0.0 -45800,0.076614164,0.024562547,,,,,,,,,,,,,,,,, -45900,0.09150752,0.025956832,,,,,,,,,,,,,,,,, -46000,0.100313984,0.026763093,,,,,,,,,,,,,,,,, -46100,0.075047806,0.024711587,,,,,,,,,,,,,,,,, -46200,0.08047188,0.024993848,,,,,,,,,,,,,,,,, -46300,0.086321935,0.025064249,,,,,,,,,,,,,,,,, -46400,0.078494415,0.024710415,,,,,,,,,,,,,,,,, -46471,,,0.9928067922592164,0.0229206141084432,0.5767878589178458,0.9869668483734132,0.0449565276503562,0.2892812244489832,43793.0,0.98615700006485,0.0479879155755043,0.2757587022034949,43793.0,14900.095911979675,21412.44612908364,14900.095911979675,6509.072257757187,1.9975545406341555,0.0 -46500,0.08926397,0.02605091,,,,,,,,,,,,,,,,, -46600,0.08557401,0.026113912,,,,,,,,,,,,,,,,, -46700,0.08019742,0.024366233,,,,,,,,,,,,,,,,, -46800,0.092154294,0.025061527,,,,,,,,,,,,,,,,, -46900,0.074305065,0.026518393,,,,,,,,,,,,,,,,, -47000,0.08732109,0.027060008,,,,,,,,,,,,,,,,, -47100,0.07877535,0.024650276,,,,,,,,,,,,,,,,, -47200,0.10209365,0.02600726,,,,,,,,,,,,,,,,, -47223,,,0.9927178621292114,0.0231426432728767,0.5750837991380785,0.9870151281356812,0.0452580600976944,0.2898598701086695,43793.0,0.986199915409088,0.0484224893152713,0.2699944692662746,43793.0,15140.2622089386,21756.26882982254,15140.2622089386,6612.673577547073,2.0322277545928955,0.0 -47300,0.09752702,0.025840074,,,,,,,,,,,,,,,,, -47400,0.07189098,0.020903235,,,,,,,,,,,,,,,,, -47500,0.078229256,0.024249557,,,,,,,,,,,,,,,,, -47600,0.08381035,0.023491718,,,,,,,,,,,,,,,,, -47700,0.0907342,0.026825268,,,,,,,,,,,,,,,,, -47800,0.0976141,0.027236298,,,,,,,,,,,,,,,,, -47900,0.08598062,0.024779247,,,,,,,,,,,,,,,,, -47963,,,0.9928799867630004,0.0225607007741928,0.5870601530022526,0.9871381521224976,0.0450849756598472,0.2975629992554879,43793.0,0.9862197637557985,0.0482847988605499,0.2736227194885751,43793.0,15380.479937553406,22098.51292657852,15380.479937553406,6714.644959449768,2.067057847976685,0.0 -48000,0.08836605,0.025229173,,,,,,,,,,,,,,,,, -48100,0.07746484,0.023373725,,,,,,,,,,,,,,,,, -48200,0.091510735,0.024657983,,,,,,,,,,,,,,,,, -48300,0.103883,0.027111575,,,,,,,,,,,,,,,,, -48400,0.08086488,0.02312373,,,,,,,,,,,,,,,,, -48500,0.094345,0.025087124,,,,,,,,,,,,,,,,, -48600,0.09078236,0.024753293,,,,,,,,,,,,,,,,, -48700,0.08416943,0.026520703,,,,,,,,,,,,,,,,, -48705,,,0.992908239364624,0.0224849469959735,0.5866123590096386,0.9870979189872742,0.0450217388570308,0.2982286863077764,43793.0,0.9861965775489808,0.0484400577843189,0.268240220401425,43793.0,15620.667268753052,22443.44396972656,15620.667268753052,6819.330502986908,2.104949951171875,0.0 -48800,0.094395444,0.025407126,,,,,,,,,,,,,,,,, -48900,0.09342567,0.025504485,,,,,,,,,,,,,,,,, -49000,0.076884784,0.024386493,,,,,,,,,,,,,,,,, -49100,0.08509174,0.021925597,,,,,,,,,,,,,,,,, -49200,0.07250042,0.021863326,,,,,,,,,,,,,,,,, -49300,0.1012692,0.023229882,,,,,,,,,,,,,,,,, -49400,0.091622494,0.025087165,,,,,,,,,,,,,,,,, -49454,,,0.9932429790496826,0.0214418750256299,0.6142716018141637,0.9870386719703674,0.0449598133563995,0.2990870103980558,43793.0,0.9861738085746764,0.0482622310519218,0.2745028253297134,43793.0,15860.90509223938,22782.647591114044,15860.90509223938,6918.238473892212,2.142441749572754,0.0 -49500,0.093740165,0.02791032,,,,,,,,,,,,,,,,, -49600,0.08978742,0.024305096,,,,,,,,,,,,,,,,, -49700,0.09025784,0.025360567,,,,,,,,,,,,,,,,, -49800,0.08510582,0.023043219,,,,,,,,,,,,,,,,, -49900,0.090448156,0.02670275,,,,,,,,,,,,,,,,, -50000,0.10592834,0.024848845,,,,,,,,,,,,,,,,, -50100,0.09893644,0.028269583,,,,,,,,,,,,,,,,, -50198,,,0.993171751499176,0.0213980432599782,0.6153558358088966,0.9871393442153932,0.0456449352204799,0.3004330684502952,43793.0,0.98628968000412,0.0490044243633747,0.2771777891807733,43793.0,16101.125497341156,23125.086805820465,16101.125497341156,7020.40039730072,2.1796648502349854,0.0 -50200,0.091609046,0.025426261,,,,,,,,,,,,,,,,, -50300,0.09735354,0.023441011,,,,,,,,,,,,,,,,, -50400,0.090225145,0.026738787,,,,,,,,,,,,,,,,, -50500,0.097047545,0.023835965,,,,,,,,,,,,,,,,, -50600,0.081890136,0.023900313,,,,,,,,,,,,,,,,, -50700,0.094948165,0.023694165,,,,,,,,,,,,,,,,, -50800,0.099445544,0.026237354,,,,,,,,,,,,,,,,, -50900,0.10567819,0.025771463,,,,,,,,,,,,,,,,, -50942,,,0.9934930801391602,0.0205563474446535,0.6322967831703095,0.9871681928634644,0.045472003519535,0.2960142246072738,43793.0,0.986243724822998,0.0487203709781169,0.2774862370226513,43793.0,16341.301450490952,23464.592199087143,16341.301450490952,7119.673288583756,2.215346097946167,0.0 -51000,0.09447281,0.025786567,,,,,,,,,,,,,,,,, -51100,0.08863707,0.025317976,,,,,,,,,,,,,,,,, -51200,0.08502144,0.023177916,,,,,,,,,,,,,,,,, -51300,0.09789216,0.02256113,,,,,,,,,,,,,,,,, -51400,0.10836923,0.02376794,,,,,,,,,,,,,,,,, -51500,0.092389695,0.023429642,,,,,,,,,,,,,,,,, -51600,0.10454,0.024912277,,,,,,,,,,,,,,,,, -51688,,,0.9936957955360411,0.0200053192675113,0.6529153022018896,0.9870293140411376,0.0454819798469543,0.2990586756813242,43793.0,0.9861788749694824,0.0486474893987178,0.2764061666842379,43793.0,16581.445430278778,23808.160900831223,16581.445430278778,7223.04040813446,2.2525784969329834,0.0 -51700,0.10399417,0.023250807,,,,,,,,,,,,,,,,, -51800,0.1057785,0.026473409,,,,,,,,,,,,,,,,, -51900,0.08896009,0.022845712,,,,,,,,,,,,,,,,, -52000,0.091142945,0.02465728,,,,,,,,,,,,,,,,, -52100,0.10021226,0.024077367,,,,,,,,,,,,,,,,, -52200,0.09195073,0.024707574,,,,,,,,,,,,,,,,, -52300,0.09064296,0.025610337,,,,,,,,,,,,,,,,, -52400,0.09604711,0.02523638,,,,,,,,,,,,,,,,, -52442,,,0.9935854077339172,0.0202750619500875,0.6460475938094383,0.987106442451477,0.0458005554974079,0.3005454513036418,43793.0,0.9862736463546752,0.049133013933897,0.2778174709848799,43793.0,16821.530868291855,24153.16040802002,16821.530868291855,7327.888458013535,2.29858922958374,0.0 -52500,0.085408635,0.02104549,,,,,,,,,,,,,,,,, -52600,0.092642024,0.020397851,,,,,,,,,,,,,,,,, -52700,0.09500007,0.023306185,,,,,,,,,,,,,,,,, -52800,0.10718391,0.025846956,,,,,,,,,,,,,,,,, -52900,0.1002641,0.02425359,,,,,,,,,,,,,,,,, -53000,0.11094903,0.026017252,,,,,,,,,,,,,,,,, -53100,0.10303601,0.023182133,,,,,,,,,,,,,,,,, -53189,,,0.9934902787208556,0.0206452459096908,0.6349567310737267,0.9871125817298888,0.0457358434796333,0.2995423135091453,43793.0,0.9861940741539,0.0492467880249023,0.2722941436238601,43793.0,17061.780789375305,24492.081380605698,17061.780789375305,7426.502200841904,2.33558201789856,0.0 -53200,0.092583805,0.02068312,,,,,,,,,,,,,,,,, -53300,0.09908063,0.02096944,,,,,,,,,,,,,,,,, -53400,0.09451361,0.02421778,,,,,,,,,,,,,,,,, -53500,0.0949155,0.023778269,,,,,,,,,,,,,,,,, -53600,0.105130546,0.02399618,,,,,,,,,,,,,,,,, -53700,0.10052334,0.022790294,,,,,,,,,,,,,,,,, -53800,0.08684714,0.022537911,,,,,,,,,,,,,,,,, -53900,0.09480066,0.023074593,,,,,,,,,,,,,,,,, -53935,,,0.9932562112808228,0.0211492143571376,0.6126243232067949,0.9869729280471802,0.0458820350468158,0.2990511593052038,43793.0,0.9861367344856262,0.0490672029554843,0.2779133849380018,43793.0,17301.95036458969,24836.86819219589,17301.95036458969,7531.053488969803,2.381908178329468,0.0 -54000,0.114963435,0.024907367,,,,,,,,,,,,,,,,, -54100,0.107702546,0.022325378,,,,,,,,,,,,,,,,, -54200,0.10191463,0.024201559,,,,,,,,,,,,,,,,, -54300,0.10695609,0.025965143,,,,,,,,,,,,,,,,, -54400,0.12309308,0.026581358,,,,,,,,,,,,,,,,, -54500,0.13426533,0.023802716,,,,,,,,,,,,,,,,, -54600,0.09718835,0.020383172,,,,,,,,,,,,,,,,, -54684,,,0.9933159947395324,0.0209419373422861,0.6317724607619654,0.987216055393219,0.0461957119405269,0.2959647151213805,43793.0,0.9862450361251832,0.0496200770139694,0.2787473506538628,43793.0,17542.20799946785,25179.772585392,17542.20799946785,7633.643043994904,2.4187357425689697,0.0 -54700,0.10898853,0.024202393,,,,,,,,,,,,,,,,, -54800,0.14751545,0.021763524,,,,,,,,,,,,,,,,, -54900,0.111070074,0.025473682,,,,,,,,,,,,,,,,, -55000,0.099085726,0.024514265,,,,,,,,,,,,,,,,, -55100,0.121151835,0.023870427,,,,,,,,,,,,,,,,, -55200,0.11300375,0.021826623,,,,,,,,,,,,,,,,, -55300,0.11038531,0.022876633,,,,,,,,,,,,,,,,, -55400,0.10054141,0.027047837,,,,,,,,,,,,,,,,, -55434,,,0.9935021996498108,0.0203129090368747,0.6375259346868289,0.9871361255645752,0.0462104454636573,0.2995275510186174,43793.0,0.9863107204437256,0.049470916390419,0.277157897267368,43793.0,17782.284705638885,25523.43450474739,17782.284705638885,7737.169899225235,2.456860542297364,0.0 -55500,0.09886604,0.022911115,,,,,,,,,,,,,,,,, -55600,0.10166023,0.021971934,,,,,,,,,,,,,,,,, -55700,0.11639723,0.025014842,,,,,,,,,,,,,,,,, -55800,0.12030506,0.021477295,,,,,,,,,,,,,,,,, -55900,0.10385101,0.018656192,,,,,,,,,,,,,,,,, -56000,0.123737365,0.02084493,,,,,,,,,,,,,,,,, -56100,0.10790147,0.023602394,,,,,,,,,,,,,,,,, -56177,,,0.9936658143997192,0.0198207218199968,0.6528672863038035,0.9871364831924438,0.0461660847067832,0.3025037933798346,43793.0,0.9862247705459596,0.0495844110846519,0.2745376159234082,43793.0,18022.405507564545,25868.13215303421,18022.405507564545,7841.688732147217,2.4949471950531006,0.0 -56200,0.10340992,0.02413137,,,,,,,,,,,,,,,,, -56300,0.10064136,0.019386645,,,,,,,,,,,,,,,,, -56400,0.10729888,0.021825848,,,,,,,,,,,,,,,,, -56500,0.09062133,0.021049047,,,,,,,,,,,,,,,,, -56600,0.11464233,0.02314689,,,,,,,,,,,,,,,,, -56700,0.10225072,0.022282701,,,,,,,,,,,,,,,,, -56800,0.09538424,0.020639813,,,,,,,,,,,,,,,,, -56900,0.11718491,0.024862746,,,,,,,,,,,,,,,,, -56924,,,0.9937217235565186,0.0195510704070329,0.6607373302507992,0.9871413707733154,0.046281412243843,0.3001707093408799,43793.0,0.9862277507781982,0.0494428053498268,0.2796399069056877,43793.0,18262.658086299896,26208.75817298889,18262.658086299896,7941.999848365784,2.5373575687408447,0.0 -57000,0.10977856,0.024089055,,,,,,,,,,,,,,,,, -57100,0.10301069,0.020863254,,,,,,,,,,,,,,,,, -57200,0.102679394,0.022394868,,,,,,,,,,,,,,,,, -57300,0.112738304,0.021192934,,,,,,,,,,,,,,,,, -57400,0.115915604,0.020414244,,,,,,,,,,,,,,,,, -57500,0.1230275,0.02354922,,,,,,,,,,,,,,,,, -57590,,,,,,,,,,,,,,18477.211141824722,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index c483ed895..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -891.9083139896393,0.0,37.09946036338806,1,0,37.09946036338806,0.0007088489946909,0.0,11.041826248168944,3003,929.0078175067902,0.000594088807702,0.0,11.064360618591309,0.0004835649742744,0.0,11.036645889282228,3000 -1467.1980648040771,0.0300376415252685,877.1772933006287,2401,0,877.1772933006287,0.3828597962856293,7.93982270373288,4.295248985290527,3003,2344.4883332252502,0.4122332334518432,14.315394056700264,3.973730564117432,0.3997098505496979,9.72599437822303,4.093648433685303,3000 -1925.8410477638245,0.056603193283081,1717.1984317302704,4801,0,1717.1984317302704,0.5431642532348633,18.73440224873982,2.7641193866729736,3003,3643.25822019577,0.542317271232605,24.2609666755636,2.7344980239868164,0.5440229773521423,20.12233410151373,2.7259020805358887,3000 -2351.8377647399902,0.0824933052062988,2557.0957322120667,7203,0,2557.0957322120667,0.5903899073600769,21.95396396643552,2.3085203170776367,3003,4909.256707191467,0.5805636644363403,27.023267640432262,2.379575252532959,0.5879034399986267,23.28663282112053,2.307112455368042,3000 -2804.38107419014,0.109868049621582,3397.0145201683044,9607,0,3397.0145201683044,0.6115739941596985,23.36949635105416,2.1070995330810547,3003,6201.824702739716,0.593815803527832,28.30504380712284,2.2298953533172607,0.6075063943862915,24.75181928129517,2.134629487991333,3000 -3475.285562515259,0.1377413272857666,4237.137367010117,12012,0,4237.137367010117,0.6252629160881042,24.26036476993152,1.965983867645264,3003,7712.961602926254,0.6022624969482422,28.608856006115754,2.149235486984253,0.619471549987793,25.498765397658293,2.007272720336914,3000 -3925.826815366745,0.1641829013824463,5077.038963317871,14418,0,5077.038963317871,0.6387659311294556,25.881130375271567,1.878983736038208,3003,9003.50999879837,0.6124773025512695,29.56920209281836,2.0575904846191406,0.6307671070098877,26.21993126585849,1.9267817735672,3000 -4377.803135633469,0.1926159858703613,5916.943907022476,16823,0,5916.943907022476,0.6472604870796204,26.0036035069853,1.81446385383606,3003,10295.500350952148,0.6192818880081177,29.80887983767533,2.010906934738159,0.6372766494750977,26.629292980287577,1.866188883781433,3000 -4964.684417009354,0.2200927734375,6757.095978498459,19230,0,6757.095978498459,0.6535820364952087,26.51818607626821,1.7662612199783323,3003,11722.64000749588,0.636422872543335,30.822602520333803,1.8756887912750244,0.6443069577217102,27.01061448954327,1.820154428482056,3000 -5431.484809875488,0.2487683296203613,7597.150082349777,21638,0,7597.150082349777,0.6585091352462769,27.19216074685112,1.7311699390411377,3003,13029.59965801239,0.6261128783226013,30.657585430152547,1.9536588191986084,0.6480762958526611,27.471663978033018,1.7900378704071045,3000 -5933.594583034515,0.2760488986968994,8437.384685277939,24047,0,8437.384685277939,0.6549997329711914,26.310299172789826,1.7625569105148315,3003,14372.047944068909,0.6246453523635864,29.88082264721335,1.9670435190200808,0.6471339464187622,27.139849535608672,1.80759871006012,3000 -6415.389216184616,0.3032636642456054,9277.445498466492,26455,0,9277.445498466492,0.6604846119880676,27.430558262691022,1.7240240573883057,3003,15694.008920431135,0.6341571807861328,30.91961353683156,1.888005971908569,0.6495393514633179,27.610013018571298,1.773059964179993,3000 -6918.62993311882,0.3362345695495605,10117.52335381508,28863,0,10117.52335381508,0.6618906855583191,27.259525740209053,1.6961952447891235,3003,17037.438900470734,0.6321583390235901,30.918711635523824,1.9047863483428955,0.6536558866500854,27.96373892400448,1.7527354955673218,3000 -7450.977185487747,0.3637206554412842,10957.66311454773,31272,0,10957.66311454773,0.6626111268997192,27.10428323835168,1.6905726194381714,3003,18410.030876636505,0.6326280832290649,30.91387757741362,1.9133557081222528,0.6534698605537415,27.38010984222516,1.7469457387924194,3000 -7931.514752388,0.3925554752349853,11797.693771123886,33681,0,11797.693771123886,0.6662018895149231,27.610379038748825,1.6690125465393066,3003,19730.70421743393,0.63736891746521,30.778807132463136,1.876145958900452,0.6563588976860046,27.70838599360045,1.7285524606704712,3000 -8416.673792600632,0.4278111457824707,12637.642409086227,36090,0,12637.642409086227,0.6704433560371399,28.013954515918293,1.653224229812622,3003,21055.924880743027,0.6397969126701355,30.72574395429576,1.858922243118286,0.6573135852813721,28.437746236206387,1.7214807271957395,3000 -8890.188046693802,0.4626927375793457,13477.828237771988,38500,0,13477.828237771988,0.6685724258422852,27.752310460703928,1.6480258703231812,3003,22369.73627972603,0.6449896693229675,31.57133577679328,1.8099712133407595,0.6588138937950134,28.44258842361144,1.7067581415176392,3000 -9357.548652887344,0.4925673007965088,14318.0378780365,40910,0,14318.0378780365,0.6722212433815002,28.249942435103367,1.6383522748947144,3003,23677.41546010971,0.6428003907203674,31.276876161097288,1.8337481021881104,0.6606861352920532,28.58794224496653,1.705464482307434,3000 -9851.644168376924,0.5287151336669922,15158.246175765991,43320,0,15158.246175765991,0.6720702052116394,28.10957228259788,1.6318928003311155,3003,25011.8331720829,0.6413927674293518,30.77948761714045,1.848698258399964,0.6622360348701477,28.52823324647621,1.69834566116333,3000 -10347.062133073809,0.5609724521636963,15998.48146367073,45730,0,15998.48146367073,0.6729068756103516,27.98403783607923,1.6298214197158811,3003,26347.59620976448,0.642379641532898,31.599102572516426,1.8306187391281128,0.6620624661445618,28.548053027735303,1.6910277605056765,3000 -10828.267130374908,0.5938632488250732,16838.670334100723,48140,0,16838.670334100723,0.674429178237915,28.14769900835641,1.6186769008636477,3003,27669.10016155243,0.643227219581604,31.4491060968431,1.8408466577529907,0.6614425182342529,28.538099099304283,1.6888902187347412,3000 -11365.696782827376,0.6260786056518555,17678.634654521942,50549,0,17678.634654521942,0.6748009920120239,28.20650007439196,1.6188348531723022,3003,29046.603624105453,0.6617670655250549,32.493355001995035,1.7062171697616575,0.6642819046974182,28.50163118714814,1.6819127798080444,3000 -11840.77304840088,0.6584103107452393,18518.534535884857,52959,0,18518.534535884857,0.6778455972671509,28.40794194410856,1.6047146320343018,3003,30361.687136888504,0.6417939066886902,31.521758366440302,1.8330038785934448,0.6663029789924622,28.710437477264808,1.6756658554077148,3000 -12342.030277490616,0.6971557140350342,19358.49755549431,55368,0,19358.49755549431,0.6779385209083557,28.4380119668229,1.60168719291687,3003,31703.02485537529,0.6438504457473755,31.664923167422646,1.8326653242111208,0.664095938205719,28.63440920485693,1.6701427698135376,3000 -12818.775514364244,0.7295718193054199,20198.622362852097,57777,0,20198.622362852097,0.6783917546272278,28.49080923108632,1.5919768810272217,3003,33020.00865364075,0.6529799699783325,32.104956986509244,1.7605587244033811,0.6671584844589233,28.91055358960645,1.6584866046905518,3000 -13294.7455265522,0.7640244960784912,21038.77587127685,60187,0,21038.77587127685,0.6815757751464844,28.60529602499556,1.579599380493164,3003,34336.24448752403,0.6493332982063293,31.857854688348205,1.788055658340454,0.6688447594642639,28.90816655813359,1.6495282649993896,3000 -13757.336695432665,0.7973370552062988,21878.77816271782,62597,0,21878.77816271782,0.6829004883766174,28.931010286915832,1.5770306587219238,3003,35638.94716525078,0.6954829692840576,35.48335512968942,1.5248199701309204,0.6700350642204285,29.23137089067413,1.6441853046417236,3000 -14245.597965955734,0.8305079936981201,22718.93424129486,65007,0,22718.93424129486,0.6821103096008301,28.572309209482075,1.5699652433395386,3003,36967.4750084877,0.6498906016349792,32.04606986938242,1.779046893119812,0.6694399118423462,28.851999438402675,1.64409077167511,3000 -14714.698032855988,0.8699114322662354,23559.06813430786,67416,0,23559.06813430786,0.6850618720054626,28.808396256578643,1.5574126243591309,3003,38276.82657814026,0.6513877511024475,31.90224548614695,1.778394341468811,0.6698490977287292,29.310274944733973,1.6329776048660278,3000 -15217.340163707731,0.9094698429107666,24399.16579413414,69825,0,24399.16579413414,0.6853756308555603,28.84675359208095,1.5609307289123535,3003,39619.68317270279,0.6634706258773804,32.673334074985306,1.7029547691345217,0.6726140975952148,29.223958515493404,1.6234630346298218,3000 -15726.62624669075,0.9434311389923096,25239.123465538025,72234,0,25239.123465538025,0.6883621215820312,29.265421368484017,1.5452096462249756,3003,40969.03985142708,0.6572171449661255,32.34493762341491,1.744814395904541,0.6746351718902588,29.395188395213623,1.619512677192688,3000 -16209.47501564026,0.977283239364624,26079.2656018734,74644,0,26079.2656018734,0.6876532435417175,29.1380813702546,1.5402637720108032,3003,42292.14107131958,0.6548944115638733,32.62063170484274,1.7606871128082275,0.6750690937042236,29.763167120385383,1.6131248474121094,3000 -16745.17874765396,1.0129663944244385,26919.287441253666,77053,0,26919.287441253666,0.6893498301506042,29.18517059563849,1.530578374862671,3003,43667.97906756401,0.6639668941497803,32.92362728757387,1.694888710975647,0.6748707294464111,29.34807239685257,1.6068533658981323,3000 -17228.647493600845,1.0498454570770264,27759.34330010414,79462,0,27759.34330010414,0.6904653906822205,29.331326809471992,1.5268282890319824,3003,44991.61689281464,0.6601611971855164,32.68691853658534,1.7258120775222778,0.6769289970397949,29.752262409047173,1.5998398065567017,3000 -17722.64498400688,1.086064338684082,28599.2469124794,81870,0,28599.2469124794,0.6923711895942688,29.556238156403044,1.518932819366455,3003,46325.63169527054,0.6784107685089111,33.84269230391496,1.6096141338348389,0.6793591976165771,29.80917299937564,1.5915672779083252,3000 -18211.00092744828,1.1236786842346191,29439.36522555352,84279,0,29439.36522555352,0.6955435872077942,29.76918535946277,1.5063549280166626,3003,47654.22263765335,0.6645780205726624,32.87321657993862,1.6899528503417969,0.6802023649215698,29.55673975736919,1.5834990739822388,3000 -18689.371471881863,1.1595537662506104,30279.430206775665,86688,0,30279.430206775665,0.6947998404502869,30.07424609365526,1.4972296953201294,3003,48972.770466566086,0.6623689532279968,33.333061211211565,1.7069313526153564,0.6798799633979797,29.87338660571638,1.5803080797195437,3000 -19172.79546570778,1.1961100101470947,31119.458785533905,89097,0,31119.458785533905,0.6962872743606567,29.97291099023428,1.4902687072753906,3003,50296.33786916733,0.6741820573806763,33.7021119458403,1.636649250984192,0.6820126175880432,29.77797063662827,1.5686326026916504,3000 -19661.390310049057,1.2334024906158447,31959.54501605034,91506,0,31959.54501605034,0.6995874643325806,30.32512532130016,1.481274127960205,3003,51625.13171696663,0.667885959148407,33.10366408069063,1.6756535768508911,0.6826573610305786,30.156964677978845,1.5624314546585083,3000 -20132.38963317871,1.2875926494598389,32799.61296272278,93915,0,32799.61296272278,0.7001568675041199,30.08239175584561,1.4738965034484863,3003,52936.32979607582,0.7043598890304565,36.35120024978537,1.4692833423614502,0.6835501194000244,30.0273109673757,1.558236002922058,3000 -20667.71193766594,1.3252015113830566,33639.83898591995,96324,0,33639.83898591995,0.7004590034484863,30.379588836077428,1.4679011106491089,3003,54311.99412155152,0.6754699945449829,33.64124946940375,1.6206367015838623,0.6841948628425598,30.02768930762145,1.5504695177078247,3000 -21192.948457717896,1.361955642700195,34480.07180213928,98734,0,34480.07180213928,0.700714647769928,30.23568171788698,1.4622043371200562,3003,55677.57796263695,0.6730888485908508,34.10313197148955,1.640097737312317,0.6855587363243103,30.445706302183336,1.5443347692489624,3000 -21685.645532608032,1.3999087810516355,35320.00034117699,101142,0,35320.00034117699,0.702934205532074,30.4668122978629,1.454554796218872,3003,57010.319242954254,0.686933696269989,34.618318349974814,1.5602630376815796,0.6874186396598816,30.18892686900693,1.5365707874298096,3000 -22168.00412297249,1.4385159015655518,36160.209800720215,103551,0,36160.209800720215,0.7046424150466919,30.654866411054087,1.4481626749038696,3003,58333.0024998188,0.6821432709693909,34.26272973822794,1.5882389545440674,0.688956081867218,30.360145618817835,1.5335873365402222,3000 -22680.11134338379,1.4766552448272705,37000.21959018707,105959,0,37000.21959018707,0.7055836319923401,30.78548127622312,1.4425654411315918,3003,59685.23624706268,0.6814278364181519,34.65526969945151,1.598901867866516,0.6893280744552612,30.57867552999859,1.5266581773757937,3000 -23180.20060920716,1.51641845703125,37840.39820337296,108368,0,37840.39820337296,0.7045029401779175,30.56989330909604,1.443967342376709,3003,61025.62011170387,0.6921325922012329,35.40294324144317,1.5298690795898438,0.6879394054412842,30.47344674934183,1.527682185173035,3000 -23671.07068610192,1.5580275058746338,38680.33823132515,110776,0,38680.33823132515,0.705258309841156,30.330988152869008,1.4395705461502075,3003,62356.54861474037,0.6881821751594543,34.94405351615512,1.555271863937378,0.6900472044944763,30.51655914551938,1.5213935375213623,3000 -24168.117631196976,1.6014611721038818,39520.55152916908,113185,0,39520.55152916908,0.7078961133956909,30.702009793502405,1.430131435394287,3003,63693.93018531799,0.7068272233009338,36.19503096365993,1.4597846269607544,0.6921550631523132,30.68075410546633,1.5146585702896118,3000 -24675.07525396347,1.6436638832092283,40360.71843266487,115594,0,40360.71843266487,0.7095230221748352,31.016669589638788,1.4238560199737549,3003,65041.172853946686,0.6958543062210083,35.537028715325654,1.5099217891693115,0.6910639405250549,30.63486420797816,1.5157564878463743,3000 -25169.378446102142,1.6847825050354004,41200.7603263855,118003,0,41200.7603263855,0.70961594581604,31.12338429193865,1.4236197471618652,3003,66375.6348798275,0.6981893181800842,35.49911535816573,1.5048353672027588,0.6911135315895081,30.810753206662906,1.5123188495635986,3000 -25667.826288461685,1.732635259628296,42040.92917585373,120412,0,42040.92917585373,0.7101737260818481,31.0455792076282,1.4174529314041138,3003,67714.37622475624,0.7060102224349976,36.32610595477296,1.4636739492416382,0.6930354237556458,30.81666176363183,1.5079798698425293,3000 -26163.87600445748,1.7744617462158203,42881.0361392498,122820,0,42881.0361392498,0.7105455994606018,31.13267295113105,1.418480634689331,3003,69050.65186357498,0.7075486779212952,36.11652502233829,1.4473143815994265,0.6921798586845398,30.694070873645327,1.5096970796585083,3000 -26669.166478395466,1.816622018814087,43720.94148516655,125227,0,43720.94148516655,0.7112312316894531,31.124645847204345,1.41438090801239,3003,70395.96740603447,0.7090739607810974,36.57344895473294,1.4433223009109497,0.6934942007064819,30.81755869738153,1.5054748058319092,3000 -27158.02793073654,1.858067274093628,44561.00401854515,127635,0,44561.00401854515,0.7111498713493347,31.248859219822087,1.4130109548568726,3003,71725.01059532166,0.7085887789726257,36.61241848155233,1.438305377960205,0.6935437917709351,30.862560161613207,1.5044677257537842,3000 -27658.42151737213,1.900503635406494,45400.904284238815,130043,0,45400.904284238815,0.7116844058036804,31.10714427248701,1.4134386777877808,3003,73065.42194890976,0.7091085910797119,36.99864153134804,1.4369813203811646,0.6933081746101379,30.757680438251448,1.504925012588501,3000 -28166.858474493027,1.944511890411377,46241.0879445076,132451,0,46241.0879445076,0.7116495370864868,31.19698552073992,1.4128319025039673,3003,74414.16377663612,0.7096970677375793,36.13791974967528,1.4432387351989746,0.693717360496521,30.865012416464875,1.504041075706482,3000 -28674.002204179764,1.9894163608551025,46548.72388243675,133333,0,46548.72388243675,0.7118006348609924,31.18247574009484,1.413087010383606,3003,75229.0163321495,0.7093749642372131,36.541089722712044,1.4428577423095703,0.693630576133728,30.869243773906362,1.5042879581451416,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 3093392fc..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.0723386,11.050199,,,,,,,,,,,,,,,,, -1,,,0.000594088807702,11.064360618591309,0.0,0.0004835649742744,11.036645889282228,0.0,3000.0,0.0007088489946909,11.041826248168944,0.0,3003.0,37.09946036338806,929.0078175067902,37.09946036338806,891.9083139896393,0.0,0.0 -100,0.4189977,8.939543,,,,,,,,,,,,,,,,, -200,0.16424973,8.590473,,,,,,,,,,,,,,,,, -300,0.17960836,8.396744,,,,,,,,,,,,,,,,, -400,0.25386724,7.968256,,,,,,,,,,,,,,,,, -500,0.2975194,7.671663,,,,,,,,,,,,,,,,, -600,0.5425824,7.4480743,,,,,,,,,,,,,,,,, -700,0.48073766,7.202721,,,,,,,,,,,,,,,,, -800,0.80870485,6.913717,,,,,,,,,,,,,,,,, -900,0.6763217,6.761005,,,,,,,,,,,,,,,,, -1000,0.5554898,6.5891056,,,,,,,,,,,,,,,,, -1100,0.5135228,6.391354,,,,,,,,,,,,,,,,, -1200,0.96777505,6.2369094,,,,,,,,,,,,,,,,, -1300,0.5139368,6.1117487,,,,,,,,,,,,,,,,, -1400,0.7465646,5.876919,,,,,,,,,,,,,,,,, -1500,0.6132424,5.7851853,,,,,,,,,,,,,,,,, -1600,0.6910087,5.6901445,,,,,,,,,,,,,,,,, -1700,0.73555195,5.5804634,,,,,,,,,,,,,,,,, -1800,0.7982794,5.4437733,,,,,,,,,,,,,,,,, -1900,1.1343764,5.3424697,,,,,,,,,,,,,,,,, -2000,0.78638446,5.2491565,,,,,,,,,,,,,,,,, -2100,0.78086805,5.124227,,,,,,,,,,,,,,,,, -2200,0.8694172,5.073522,,,,,,,,,,,,,,,,, -2300,0.78215533,4.9696784,,,,,,,,,,,,,,,,, -2400,0.68108225,4.806545,,,,,,,,,,,,,,,,, -2401,,,0.4122332334518432,3.973730564117432,14.315394056700264,0.3997098505496979,4.093648433685303,9.72599437822303,3000.0,0.3828597962856293,4.295248985290527,7.93982270373288,3003.0,877.1772933006287,2344.4883332252502,877.1772933006287,1467.1980648040771,0.0300376415252685,0.0 -2500,0.9648552,4.712875,,,,,,,,,,,,,,,,, -2600,0.9527727,4.71868,,,,,,,,,,,,,,,,, -2700,0.8151246,4.613298,,,,,,,,,,,,,,,,, -2800,1.0406384,4.547945,,,,,,,,,,,,,,,,, -2900,1.08444,4.5377693,,,,,,,,,,,,,,,,, -3000,0.99986094,4.436871,,,,,,,,,,,,,,,,, -3100,0.950237,4.2656465,,,,,,,,,,,,,,,,, -3200,0.63304114,4.352903,,,,,,,,,,,,,,,,, -3300,0.69173014,4.2752457,,,,,,,,,,,,,,,,, -3400,0.8617949,4.1878257,,,,,,,,,,,,,,,,, -3500,0.7238631,4.132709,,,,,,,,,,,,,,,,, -3600,0.6176629,4.2502127,,,,,,,,,,,,,,,,, -3700,0.7302055,4.0741367,,,,,,,,,,,,,,,,, -3800,0.64870405,4.0591774,,,,,,,,,,,,,,,,, -3900,0.7148585,3.9961016,,,,,,,,,,,,,,,,, -4000,0.62825155,4.050451,,,,,,,,,,,,,,,,, -4100,0.59788555,3.8224993,,,,,,,,,,,,,,,,, -4200,0.6999702,3.9275458,,,,,,,,,,,,,,,,, -4300,0.61008215,3.9754012,,,,,,,,,,,,,,,,, -4400,0.7366874,3.8787758,,,,,,,,,,,,,,,,, -4500,0.62703085,3.8105671,,,,,,,,,,,,,,,,, -4600,0.6344661,3.8208725,,,,,,,,,,,,,,,,, -4700,0.5360622,3.7529764,,,,,,,,,,,,,,,,, -4800,0.75979674,3.8429453,,,,,,,,,,,,,,,,, -4801,,,0.542317271232605,2.7344980239868164,24.2609666755636,0.5440229773521423,2.7259020805358887,20.12233410151373,3000.0,0.5431642532348633,2.7641193866729736,18.73440224873982,3003.0,1717.1984317302704,3643.25822019577,1717.1984317302704,1925.8410477638245,0.056603193283081,0.0 -4900,0.5589283,3.7239738,,,,,,,,,,,,,,,,, -5000,0.61992764,3.703364,,,,,,,,,,,,,,,,, -5100,0.5873449,3.7678783,,,,,,,,,,,,,,,,, -5200,0.68282723,3.7184064,,,,,,,,,,,,,,,,, -5300,0.48878798,3.7111886,,,,,,,,,,,,,,,,, -5400,0.52137417,3.6915784,,,,,,,,,,,,,,,,, -5500,0.5004371,3.6114962,,,,,,,,,,,,,,,,, -5600,0.50808156,3.631603,,,,,,,,,,,,,,,,, -5700,0.5261708,3.6762798,,,,,,,,,,,,,,,,, -5800,0.5136988,3.6318038,,,,,,,,,,,,,,,,, -5900,0.66456735,3.6552281,,,,,,,,,,,,,,,,, -6000,0.49375382,3.6127393,,,,,,,,,,,,,,,,, -6100,0.6682128,3.5738995,,,,,,,,,,,,,,,,, -6200,0.5057216,3.585975,,,,,,,,,,,,,,,,, -6300,0.48659176,3.5285301,,,,,,,,,,,,,,,,, -6400,0.40759087,3.5931163,,,,,,,,,,,,,,,,, -6500,0.6141207,3.5723078,,,,,,,,,,,,,,,,, -6600,0.44339928,3.5607524,,,,,,,,,,,,,,,,, -6700,0.4469394,3.51521,,,,,,,,,,,,,,,,, -6800,0.47114983,3.5076911,,,,,,,,,,,,,,,,, -6900,0.50720835,3.529172,,,,,,,,,,,,,,,,, -7000,0.47240666,3.4961388,,,,,,,,,,,,,,,,, -7100,0.62237614,3.476899,,,,,,,,,,,,,,,,, -7200,0.44636658,3.5371766,,,,,,,,,,,,,,,,, -7203,,,0.5805636644363403,2.379575252532959,27.023267640432262,0.5879034399986267,2.307112455368042,23.28663282112053,3000.0,0.5903899073600769,2.3085203170776367,21.95396396643552,3003.0,2557.0957322120667,4909.256707191467,2557.0957322120667,2351.8377647399902,0.0824933052062988,0.0 -7300,0.46465993,3.432716,,,,,,,,,,,,,,,,, -7400,0.39880192,3.459524,,,,,,,,,,,,,,,,, -7500,0.38901022,3.3875484,,,,,,,,,,,,,,,,, -7600,0.51204234,3.5260239,,,,,,,,,,,,,,,,, -7700,0.44939968,3.3514705,,,,,,,,,,,,,,,,, -7800,0.37245718,3.3368778,,,,,,,,,,,,,,,,, -7900,0.43896016,3.3803535,,,,,,,,,,,,,,,,, -8000,0.35212615,3.5170288,,,,,,,,,,,,,,,,, -8100,0.36297658,3.3398578,,,,,,,,,,,,,,,,, -8200,0.3998358,3.4141235,,,,,,,,,,,,,,,,, -8300,0.38474798,3.3943677,,,,,,,,,,,,,,,,, -8400,0.40479916,3.375346,,,,,,,,,,,,,,,,, -8500,0.37190616,3.4458716,,,,,,,,,,,,,,,,, -8600,0.34823346,3.3641913,,,,,,,,,,,,,,,,, -8700,0.38051122,3.3740106,,,,,,,,,,,,,,,,, -8800,0.3524118,3.4404438,,,,,,,,,,,,,,,,, -8900,0.33806318,3.3296337,,,,,,,,,,,,,,,,, -9000,0.32267666,3.3291054,,,,,,,,,,,,,,,,, -9100,0.34676978,3.3802135,,,,,,,,,,,,,,,,, -9200,0.36493698,3.4417162,,,,,,,,,,,,,,,,, -9300,0.33280283,3.2764966,,,,,,,,,,,,,,,,, -9400,0.32354736,3.3960025,,,,,,,,,,,,,,,,, -9500,0.31061962,3.3475335,,,,,,,,,,,,,,,,, -9600,0.2924667,3.3104808,,,,,,,,,,,,,,,,, -9607,,,0.593815803527832,2.2298953533172607,28.30504380712284,0.6075063943862915,2.134629487991333,24.75181928129517,3000.0,0.6115739941596985,2.1070995330810547,23.36949635105416,3003.0,3397.0145201683044,6201.824702739716,3397.0145201683044,2804.38107419014,0.109868049621582,0.0 -9700,0.36947587,3.2811024,,,,,,,,,,,,,,,,, -9800,0.2847013,3.343935,,,,,,,,,,,,,,,,, -9900,0.27550572,3.2677462,,,,,,,,,,,,,,,,, -10000,0.28337276,3.2282615,,,,,,,,,,,,,,,,, -10100,0.29991764,3.2755444,,,,,,,,,,,,,,,,, -10200,0.3917131,3.3297126,,,,,,,,,,,,,,,,, -10300,0.28544015,3.197742,,,,,,,,,,,,,,,,, -10400,0.27486983,3.2125618,,,,,,,,,,,,,,,,, -10500,0.27349907,3.3135488,,,,,,,,,,,,,,,,, -10600,0.2952698,3.2485209,,,,,,,,,,,,,,,,, -10700,0.27820665,3.299675,,,,,,,,,,,,,,,,, -10800,0.28084093,3.2919807,,,,,,,,,,,,,,,,, -10900,0.25379288,3.2300189,,,,,,,,,,,,,,,,, -11000,0.28091446,3.3172739,,,,,,,,,,,,,,,,, -11100,0.26205716,3.2391002,,,,,,,,,,,,,,,,, -11200,0.27539143,3.2647398,,,,,,,,,,,,,,,,, -11300,0.26948228,3.301154,,,,,,,,,,,,,,,,, -11400,0.27306473,3.177237,,,,,,,,,,,,,,,,, -11500,0.25330696,3.3221433,,,,,,,,,,,,,,,,, -11600,0.27883533,3.190311,,,,,,,,,,,,,,,,, -11700,0.29102492,3.2397459,,,,,,,,,,,,,,,,, -11800,0.28025132,3.2204847,,,,,,,,,,,,,,,,, -11900,0.2642047,3.2520082,,,,,,,,,,,,,,,,, -12000,0.23671255,3.26585,,,,,,,,,,,,,,,,, -12012,,,0.6022624969482422,2.149235486984253,28.608856006115754,0.619471549987793,2.007272720336914,25.498765397658293,3000.0,0.6252629160881042,1.965983867645264,24.26036476993152,3003.0,4237.137367010117,7712.961602926254,4237.137367010117,3475.285562515259,0.1377413272857666,0.0 -12100,0.25649744,3.2020557,,,,,,,,,,,,,,,,, -12200,0.23722793,3.2035222,,,,,,,,,,,,,,,,, -12300,0.2349013,3.2522326,,,,,,,,,,,,,,,,, -12400,0.22909717,3.1521087,,,,,,,,,,,,,,,,, -12500,0.2457433,3.243241,,,,,,,,,,,,,,,,, -12600,0.26869217,3.1760101,,,,,,,,,,,,,,,,, -12700,0.30730832,3.1858726,,,,,,,,,,,,,,,,, -12800,0.24904343,3.2957425,,,,,,,,,,,,,,,,, -12900,0.24739604,3.1157606,,,,,,,,,,,,,,,,, -13000,0.24093197,3.211424,,,,,,,,,,,,,,,,, -13100,0.308507,3.2573292,,,,,,,,,,,,,,,,, -13200,0.24263038,3.2331421,,,,,,,,,,,,,,,,, -13300,0.23270984,3.2245986,,,,,,,,,,,,,,,,, -13400,0.2743817,3.2050014,,,,,,,,,,,,,,,,, -13500,0.2506716,3.2429981,,,,,,,,,,,,,,,,, -13600,0.22551316,3.1464417,,,,,,,,,,,,,,,,, -13700,0.277286,3.1801376,,,,,,,,,,,,,,,,, -13800,0.2521915,3.1362135,,,,,,,,,,,,,,,,, -13900,0.24580424,3.14184,,,,,,,,,,,,,,,,, -14000,0.26818585,3.171025,,,,,,,,,,,,,,,,, -14100,0.2552638,3.1877708,,,,,,,,,,,,,,,,, -14200,0.26358816,3.1000257,,,,,,,,,,,,,,,,, -14300,0.25866708,3.121542,,,,,,,,,,,,,,,,, -14400,0.35845625,3.1413598,,,,,,,,,,,,,,,,, -14418,,,0.6124773025512695,2.0575904846191406,29.56920209281836,0.6307671070098877,1.9267817735672,26.21993126585849,3000.0,0.6387659311294556,1.878983736038208,25.881130375271567,3003.0,5077.038963317871,9003.50999879837,5077.038963317871,3925.826815366745,0.1641829013824463,0.0 -14500,0.2761798,3.1185312,,,,,,,,,,,,,,,,, -14600,0.24191102,3.1925232,,,,,,,,,,,,,,,,, -14700,0.24836272,3.1168454,,,,,,,,,,,,,,,,, -14800,0.31038225,3.1727443,,,,,,,,,,,,,,,,, -14900,0.24482755,3.1435456,,,,,,,,,,,,,,,,, -15000,0.23128358,3.1528246,,,,,,,,,,,,,,,,, -15100,0.30718377,3.118035,,,,,,,,,,,,,,,,, -15200,0.2996521,3.1517162,,,,,,,,,,,,,,,,, -15300,0.2453735,3.145319,,,,,,,,,,,,,,,,, -15400,0.31840926,3.1400578,,,,,,,,,,,,,,,,, -15500,0.2493558,3.1230986,,,,,,,,,,,,,,,,, -15600,0.2605249,3.1953263,,,,,,,,,,,,,,,,, -15700,0.26092273,3.1775584,,,,,,,,,,,,,,,,, -15800,0.29306045,3.076605,,,,,,,,,,,,,,,,, -15900,0.2957731,3.1686866,,,,,,,,,,,,,,,,, -16000,0.27954707,3.0188606,,,,,,,,,,,,,,,,, -16100,0.2777589,3.1505241,,,,,,,,,,,,,,,,, -16200,0.24471428,3.0484207,,,,,,,,,,,,,,,,, -16300,0.25360292,3.1192167,,,,,,,,,,,,,,,,, -16400,0.2483096,3.057439,,,,,,,,,,,,,,,,, -16500,0.26235384,3.0574381,,,,,,,,,,,,,,,,, -16600,0.26155773,3.075743,,,,,,,,,,,,,,,,, -16700,0.2293744,3.048357,,,,,,,,,,,,,,,,, -16800,0.2913952,3.095084,,,,,,,,,,,,,,,,, -16823,,,0.6192818880081177,2.010906934738159,29.80887983767533,0.6372766494750977,1.866188883781433,26.629292980287577,3000.0,0.6472604870796204,1.81446385383606,26.0036035069853,3003.0,5916.943907022476,10295.500350952148,5916.943907022476,4377.803135633469,0.1926159858703613,0.0 -16900,0.29234245,3.1707551,,,,,,,,,,,,,,,,, -17000,0.28023347,3.1891198,,,,,,,,,,,,,,,,, -17100,0.3486461,3.1386013,,,,,,,,,,,,,,,,, -17200,0.32831055,3.0717134,,,,,,,,,,,,,,,,, -17300,0.28412035,3.0263877,,,,,,,,,,,,,,,,, -17400,0.3257456,3.117484,,,,,,,,,,,,,,,,, -17500,0.29490608,3.0422804,,,,,,,,,,,,,,,,, -17600,0.28648224,3.0915456,,,,,,,,,,,,,,,,, -17700,0.27088618,3.074155,,,,,,,,,,,,,,,,, -17800,0.2914233,3.120818,,,,,,,,,,,,,,,,, -17900,0.30035663,3.0118446,,,,,,,,,,,,,,,,, -18000,0.33197945,3.0705364,,,,,,,,,,,,,,,,, -18100,0.3482403,3.007872,,,,,,,,,,,,,,,,, -18200,0.3025795,3.0739837,,,,,,,,,,,,,,,,, -18300,0.29928574,3.1054165,,,,,,,,,,,,,,,,, -18400,0.37551534,3.0440493,,,,,,,,,,,,,,,,, -18500,0.28111237,3.089026,,,,,,,,,,,,,,,,, -18600,0.28805357,3.0393088,,,,,,,,,,,,,,,,, -18700,0.29041803,3.0646698,,,,,,,,,,,,,,,,, -18800,0.31731325,2.9943407,,,,,,,,,,,,,,,,, -18900,0.31655267,3.0669646,,,,,,,,,,,,,,,,, -19000,0.33348453,2.9728198,,,,,,,,,,,,,,,,, -19100,0.33628199,3.0668066,,,,,,,,,,,,,,,,, -19200,0.30730882,3.0734897,,,,,,,,,,,,,,,,, -19230,,,0.636422872543335,1.8756887912750244,30.822602520333803,0.6443069577217102,1.820154428482056,27.01061448954327,3000.0,0.6535820364952087,1.7662612199783323,26.51818607626821,3003.0,6757.095978498459,11722.64000749588,6757.095978498459,4964.684417009354,0.2200927734375,0.0 -19300,0.33170968,2.9675195,,,,,,,,,,,,,,,,, -19400,0.3003093,3.05155,,,,,,,,,,,,,,,,, -19500,0.38828444,3.0536907,,,,,,,,,,,,,,,,, -19600,0.2798976,3.074431,,,,,,,,,,,,,,,,, -19700,0.2957574,3.0557253,,,,,,,,,,,,,,,,, -19800,0.3221381,3.109547,,,,,,,,,,,,,,,,, -19900,0.39548644,3.0314674,,,,,,,,,,,,,,,,, -20000,0.33229962,3.042297,,,,,,,,,,,,,,,,, -20100,0.31754082,3.0231714,,,,,,,,,,,,,,,,, -20200,0.31459528,3.0380714,,,,,,,,,,,,,,,,, -20300,0.29866818,3.0160701,,,,,,,,,,,,,,,,, -20400,0.2895588,3.1351426,,,,,,,,,,,,,,,,, -20500,0.2922265,3.0829537,,,,,,,,,,,,,,,,, -20600,0.3150173,3.0774636,,,,,,,,,,,,,,,,, -20700,0.3354919,3.0508249,,,,,,,,,,,,,,,,, -20800,0.45975694,3.0210977,,,,,,,,,,,,,,,,, -20900,0.34248915,2.9823596,,,,,,,,,,,,,,,,, -21000,0.32233104,3.1062298,,,,,,,,,,,,,,,,, -21100,0.4051584,3.0430856,,,,,,,,,,,,,,,,, -21200,0.31902146,3.05426,,,,,,,,,,,,,,,,, -21300,0.35790735,3.0068855,,,,,,,,,,,,,,,,, -21400,0.3691591,3.1259115,,,,,,,,,,,,,,,,, -21500,0.33930898,2.994496,,,,,,,,,,,,,,,,, -21600,0.327855,3.0074193,,,,,,,,,,,,,,,,, -21638,,,0.6261128783226013,1.9536588191986084,30.657585430152547,0.6480762958526611,1.7900378704071045,27.471663978033018,3000.0,0.6585091352462769,1.7311699390411377,27.19216074685112,3003.0,7597.150082349777,13029.59965801239,7597.150082349777,5431.484809875488,0.2487683296203613,0.0 -21700,0.31967646,2.9987493,,,,,,,,,,,,,,,,, -21800,0.33214882,3.0852504,,,,,,,,,,,,,,,,, -21900,0.33367026,3.010917,,,,,,,,,,,,,,,,, -22000,0.37829125,3.021897,,,,,,,,,,,,,,,,, -22100,0.35195008,3.0601082,,,,,,,,,,,,,,,,, -22200,0.3620827,3.0248592,,,,,,,,,,,,,,,,, -22300,0.35357356,3.012206,,,,,,,,,,,,,,,,, -22400,1.2556627,3.233935,,,,,,,,,,,,,,,,, -22500,0.40061265,3.080781,,,,,,,,,,,,,,,,, -22600,0.32391703,3.0191693,,,,,,,,,,,,,,,,, -22700,0.3557115,3.077967,,,,,,,,,,,,,,,,, -22800,0.32623807,2.9078486,,,,,,,,,,,,,,,,, -22900,0.31827936,3.0688443,,,,,,,,,,,,,,,,, -23000,0.35737,3.0512607,,,,,,,,,,,,,,,,, -23100,0.35798588,3.0277538,,,,,,,,,,,,,,,,, -23200,0.40441403,4.951005,,,,,,,,,,,,,,,,, -23300,0.8103988,4.800269,,,,,,,,,,,,,,,,, -23400,0.4633919,4.7491145,,,,,,,,,,,,,,,,, -23500,0.9162581,4.73256,,,,,,,,,,,,,,,,, -23600,1.7985848,4.723015,,,,,,,,,,,,,,,,, -23700,3.8238082,4.5235023,,,,,,,,,,,,,,,,, -23800,0.45383963,3.2167916,,,,,,,,,,,,,,,,, -23900,0.33084968,3.056258,,,,,,,,,,,,,,,,, -24000,0.34717715,3.0329657,,,,,,,,,,,,,,,,, -24047,,,0.6246453523635864,1.9670435190200808,29.88082264721335,0.6471339464187622,1.80759871006012,27.139849535608672,3000.0,0.6549997329711914,1.7625569105148315,26.310299172789826,3003.0,8437.384685277939,14372.047944068909,8437.384685277939,5933.594583034515,0.2760488986968994,0.0 -24100,0.33554503,3.1059432,,,,,,,,,,,,,,,,, -24200,0.32146445,3.0242348,,,,,,,,,,,,,,,,, -24300,0.31605777,3.126899,,,,,,,,,,,,,,,,, -24400,0.41115415,3.0635707,,,,,,,,,,,,,,,,, -24500,0.32870996,3.019026,,,,,,,,,,,,,,,,, -24600,0.32940075,2.9976394,,,,,,,,,,,,,,,,, -24700,0.31162694,2.9599712,,,,,,,,,,,,,,,,, -24800,0.3103319,3.0419047,,,,,,,,,,,,,,,,, -24900,0.31605253,3.0346017,,,,,,,,,,,,,,,,, -25000,0.36257458,3.0440142,,,,,,,,,,,,,,,,, -25100,0.39073664,3.040219,,,,,,,,,,,,,,,,, -25200,0.37548482,3.0149846,,,,,,,,,,,,,,,,, -25300,0.3979779,3.075325,,,,,,,,,,,,,,,,, -25400,0.31207344,3.0147836,,,,,,,,,,,,,,,,, -25500,0.3266771,2.93045,,,,,,,,,,,,,,,,, -25600,0.3194738,3.0399013,,,,,,,,,,,,,,,,, -25700,0.3444063,3.0024753,,,,,,,,,,,,,,,,, -25800,0.34251478,2.9295177,,,,,,,,,,,,,,,,, -25900,0.37787107,3.0667784,,,,,,,,,,,,,,,,, -26000,0.3538953,2.9029894,,,,,,,,,,,,,,,,, -26100,0.29191592,3.0658846,,,,,,,,,,,,,,,,, -26200,0.3547889,3.0745838,,,,,,,,,,,,,,,,, -26300,0.33395585,3.017019,,,,,,,,,,,,,,,,, -26400,0.34638646,3.1209092,,,,,,,,,,,,,,,,, -26455,,,0.6341571807861328,1.888005971908569,30.91961353683156,0.6495393514633179,1.773059964179993,27.610013018571298,3000.0,0.6604846119880676,1.7240240573883057,27.430558262691022,3003.0,9277.445498466492,15694.008920431135,9277.445498466492,6415.389216184616,0.3032636642456054,0.0 -26500,0.38181126,3.0662756,,,,,,,,,,,,,,,,, -26600,0.37549728,3.0394502,,,,,,,,,,,,,,,,, -26700,0.3264214,3.0150638,,,,,,,,,,,,,,,,, -26800,0.35898492,3.0587785,,,,,,,,,,,,,,,,, -26900,0.3792943,3.0130424,,,,,,,,,,,,,,,,, -27000,0.35221946,2.9789255,,,,,,,,,,,,,,,,, -27100,0.45654702,3.0418968,,,,,,,,,,,,,,,,, -27200,0.35707584,2.9869251,,,,,,,,,,,,,,,,, -27300,0.3794677,2.996762,,,,,,,,,,,,,,,,, -27400,0.31511322,2.9461389,,,,,,,,,,,,,,,,, -27500,0.37863076,2.965122,,,,,,,,,,,,,,,,, -27600,0.6740385,3.0480287,,,,,,,,,,,,,,,,, -27700,0.3531155,3.0822992,,,,,,,,,,,,,,,,, -27800,0.3182686,3.0452094,,,,,,,,,,,,,,,,, -27900,0.4266684,2.9999032,,,,,,,,,,,,,,,,, -28000,0.3650755,2.9955118,,,,,,,,,,,,,,,,, -28100,0.43415835,2.9604275,,,,,,,,,,,,,,,,, -28200,0.4944147,3.028823,,,,,,,,,,,,,,,,, -28300,0.3920546,2.9758136,,,,,,,,,,,,,,,,, -28400,0.3721593,2.997543,,,,,,,,,,,,,,,,, -28500,0.5467609,3.051574,,,,,,,,,,,,,,,,, -28600,0.3460258,3.036574,,,,,,,,,,,,,,,,, -28700,0.33977795,3.0010767,,,,,,,,,,,,,,,,, -28800,0.33848295,3.0464072,,,,,,,,,,,,,,,,, -28863,,,0.6321583390235901,1.9047863483428955,30.918711635523824,0.6536558866500854,1.7527354955673218,27.96373892400448,3000.0,0.6618906855583191,1.6961952447891235,27.259525740209053,3003.0,10117.52335381508,17037.438900470734,10117.52335381508,6918.62993311882,0.3362345695495605,0.0 -28900,0.36644328,3.076702,,,,,,,,,,,,,,,,, -29000,0.34352198,2.990071,,,,,,,,,,,,,,,,, -29100,0.37043813,3.0281658,,,,,,,,,,,,,,,,, -29200,0.38884985,3.0269089,,,,,,,,,,,,,,,,, -29300,0.33586192,3.0736358,,,,,,,,,,,,,,,,, -29400,0.45217404,3.0115228,,,,,,,,,,,,,,,,, -29500,0.35272953,3.0370128,,,,,,,,,,,,,,,,, -29600,0.37294182,2.9872398,,,,,,,,,,,,,,,,, -29700,0.32346275,3.0217502,,,,,,,,,,,,,,,,, -29800,0.35375842,2.9888103,,,,,,,,,,,,,,,,, -29900,0.3686023,3.0185277,,,,,,,,,,,,,,,,, -30000,0.38578662,3.012173,,,,,,,,,,,,,,,,, -30100,0.36295572,3.0437958,,,,,,,,,,,,,,,,, -30200,0.36326063,3.0135527,,,,,,,,,,,,,,,,, -30300,0.38640127,3.0236444,,,,,,,,,,,,,,,,, -30400,0.33298004,2.944222,,,,,,,,,,,,,,,,, -30500,0.36283216,3.0241752,,,,,,,,,,,,,,,,, -30600,0.37246954,2.9243214,,,,,,,,,,,,,,,,, -30700,0.30697942,2.996847,,,,,,,,,,,,,,,,, -30800,0.34580967,2.9465544,,,,,,,,,,,,,,,,, -30900,0.3883292,2.989435,,,,,,,,,,,,,,,,, -31000,0.5217367,2.9999576,,,,,,,,,,,,,,,,, -31100,0.34031558,3.0550592,,,,,,,,,,,,,,,,, -31200,0.43036947,2.9633608,,,,,,,,,,,,,,,,, -31272,,,0.6326280832290649,1.9133557081222528,30.91387757741362,0.6534698605537415,1.7469457387924194,27.38010984222516,3000.0,0.6626111268997192,1.6905726194381714,27.10428323835168,3003.0,10957.66311454773,18410.030876636505,10957.66311454773,7450.977185487747,0.3637206554412842,0.0 -31300,0.3527978,3.0241156,,,,,,,,,,,,,,,,, -31400,0.37120378,3.0000408,,,,,,,,,,,,,,,,, -31500,0.33684456,3.0122855,,,,,,,,,,,,,,,,, -31600,0.40507144,2.9728687,,,,,,,,,,,,,,,,, -31700,0.42334768,3.034522,,,,,,,,,,,,,,,,, -31800,0.36071432,2.9203851,,,,,,,,,,,,,,,,, -31900,0.39341202,3.0147297,,,,,,,,,,,,,,,,, -32000,0.358254,2.9879231,,,,,,,,,,,,,,,,, -32100,0.3708614,3.017419,,,,,,,,,,,,,,,,, -32200,0.4511008,2.9978235,,,,,,,,,,,,,,,,, -32300,0.3642645,3.0297496,,,,,,,,,,,,,,,,, -32400,0.39234808,3.046992,,,,,,,,,,,,,,,,, -32500,0.45324507,2.9798117,,,,,,,,,,,,,,,,, -32600,0.324661,2.9929411,,,,,,,,,,,,,,,,, -32700,0.3413221,2.9849694,,,,,,,,,,,,,,,,, -32800,0.3589403,3.033182,,,,,,,,,,,,,,,,, -32900,0.40201846,2.981211,,,,,,,,,,,,,,,,, -33000,0.4154511,2.9672396,,,,,,,,,,,,,,,,, -33100,0.39592287,3.0058513,,,,,,,,,,,,,,,,, -33200,0.34501827,3.0204275,,,,,,,,,,,,,,,,, -33300,0.44334337,2.964961,,,,,,,,,,,,,,,,, -33400,0.46238232,3.0177586,,,,,,,,,,,,,,,,, -33500,0.35901427,2.9631226,,,,,,,,,,,,,,,,, -33600,0.34995556,3.0303648,,,,,,,,,,,,,,,,, -33681,,,0.63736891746521,1.876145958900452,30.778807132463136,0.6563588976860046,1.7285524606704712,27.70838599360045,3000.0,0.6662018895149231,1.6690125465393066,27.610379038748825,3003.0,11797.693771123886,19730.70421743393,11797.693771123886,7931.514752388,0.3925554752349853,0.0 -33700,0.41661236,3.042301,,,,,,,,,,,,,,,,, -33800,0.40023756,2.9488273,,,,,,,,,,,,,,,,, -33900,0.323034,2.9025512,,,,,,,,,,,,,,,,, -34000,0.4029363,2.9675553,,,,,,,,,,,,,,,,, -34100,0.3813762,2.9754076,,,,,,,,,,,,,,,,, -34200,0.36010852,2.9926457,,,,,,,,,,,,,,,,, -34300,0.38689965,2.9456062,,,,,,,,,,,,,,,,, -34400,0.341096,2.9497027,,,,,,,,,,,,,,,,, -34500,0.36510417,2.9095075,,,,,,,,,,,,,,,,, -34600,0.4225929,2.9664521,,,,,,,,,,,,,,,,, -34700,0.38348332,3.0318887,,,,,,,,,,,,,,,,, -34800,0.3586507,2.9313636,,,,,,,,,,,,,,,,, -34900,0.40955073,2.9921165,,,,,,,,,,,,,,,,, -35000,0.4249765,2.9387228,,,,,,,,,,,,,,,,, -35100,0.36146754,2.975209,,,,,,,,,,,,,,,,, -35200,0.33812585,2.9130406,,,,,,,,,,,,,,,,, -35300,0.37054926,2.9770968,,,,,,,,,,,,,,,,, -35400,0.43368796,2.9614835,,,,,,,,,,,,,,,,, -35500,0.34703708,2.9853776,,,,,,,,,,,,,,,,, -35600,0.34499523,2.9558008,,,,,,,,,,,,,,,,, -35700,0.38273567,3.011263,,,,,,,,,,,,,,,,, -35800,0.38619623,2.988782,,,,,,,,,,,,,,,,, -35900,0.37392673,3.0080667,,,,,,,,,,,,,,,,, -36000,0.3391919,2.9354174,,,,,,,,,,,,,,,,, -36090,,,0.6397969126701355,1.858922243118286,30.72574395429576,0.6573135852813721,1.7214807271957395,28.437746236206387,3000.0,0.6704433560371399,1.653224229812622,28.013954515918293,3003.0,12637.642409086227,21055.924880743027,12637.642409086227,8416.673792600632,0.4278111457824707,0.0 -36100,0.44385767,2.9063737,,,,,,,,,,,,,,,,, -36200,0.33844116,2.98869,,,,,,,,,,,,,,,,, -36300,0.33211547,2.9055903,,,,,,,,,,,,,,,,, -36400,0.3393165,3.0171845,,,,,,,,,,,,,,,,, -36500,0.38864183,2.975186,,,,,,,,,,,,,,,,, -36600,0.3658602,2.9309075,,,,,,,,,,,,,,,,, -36700,0.3797833,2.8893993,,,,,,,,,,,,,,,,, -36800,0.3925277,2.9308202,,,,,,,,,,,,,,,,, -36900,0.44163728,3.0230029,,,,,,,,,,,,,,,,, -37000,0.43422243,2.9894774,,,,,,,,,,,,,,,,, -37100,0.33839706,2.9221716,,,,,,,,,,,,,,,,, -37200,0.33104488,2.9793653,,,,,,,,,,,,,,,,, -37300,0.38225842,3.0162215,,,,,,,,,,,,,,,,, -37400,0.39074126,3.0713432,,,,,,,,,,,,,,,,, -37500,0.34131894,2.998153,,,,,,,,,,,,,,,,, -37600,0.37660944,3.0403755,,,,,,,,,,,,,,,,, -37700,0.36989388,3.0335283,,,,,,,,,,,,,,,,, -37800,0.34776425,2.898425,,,,,,,,,,,,,,,,, -37900,0.38580102,2.9886472,,,,,,,,,,,,,,,,, -38000,0.38694248,2.983973,,,,,,,,,,,,,,,,, -38100,0.36357978,2.9982815,,,,,,,,,,,,,,,,, -38200,0.40853325,2.9461343,,,,,,,,,,,,,,,,, -38300,0.3498758,2.954144,,,,,,,,,,,,,,,,, -38400,0.34673524,2.9814944,,,,,,,,,,,,,,,,, -38500,,,0.6449896693229675,1.8099712133407595,31.57133577679328,0.6588138937950134,1.7067581415176392,28.44258842361144,3000.0,0.6685724258422852,1.6480258703231812,27.752310460703928,3003.0,13477.828237771988,22369.73627972603,13477.828237771988,8890.188046693802,0.4626927375793457,0.0 -38500,0.36839217,2.9878879,,,,,,,,,,,,,,,,, -38600,0.36799803,2.9489515,,,,,,,,,,,,,,,,, -38700,0.3580553,2.9598694,,,,,,,,,,,,,,,,, -38800,0.34867203,2.9477305,,,,,,,,,,,,,,,,, -38900,0.359257,2.9535313,,,,,,,,,,,,,,,,, -39000,0.37839782,2.9857645,,,,,,,,,,,,,,,,, -39100,0.36512813,2.9759817,,,,,,,,,,,,,,,,, -39200,0.3475339,2.8727734,,,,,,,,,,,,,,,,, -39300,0.3575704,3.0020075,,,,,,,,,,,,,,,,, -39400,0.41807538,2.9118347,,,,,,,,,,,,,,,,, -39500,0.32469746,2.9833472,,,,,,,,,,,,,,,,, -39600,0.37182468,2.892242,,,,,,,,,,,,,,,,, -39700,0.3461079,2.9544868,,,,,,,,,,,,,,,,, -39800,0.3993152,2.9466493,,,,,,,,,,,,,,,,, -39900,0.33079788,2.9343164,,,,,,,,,,,,,,,,, -40000,0.38237888,2.971024,,,,,,,,,,,,,,,,, -40100,0.3348972,2.9323823,,,,,,,,,,,,,,,,, -40200,0.37203276,2.999049,,,,,,,,,,,,,,,,, -40300,0.332455,2.9285882,,,,,,,,,,,,,,,,, -40400,0.3446807,2.9848568,,,,,,,,,,,,,,,,, -40500,0.32516843,2.9769127,,,,,,,,,,,,,,,,, -40600,0.37376276,2.8548367,,,,,,,,,,,,,,,,, -40700,0.41948164,2.9617894,,,,,,,,,,,,,,,,, -40800,0.3331064,2.9680884,,,,,,,,,,,,,,,,, -40900,0.3769537,3.0333617,,,,,,,,,,,,,,,,, -40910,,,0.6428003907203674,1.8337481021881104,31.276876161097288,0.6606861352920532,1.705464482307434,28.58794224496653,3000.0,0.6722212433815002,1.6383522748947144,28.249942435103367,3003.0,14318.0378780365,23677.41546010971,14318.0378780365,9357.548652887344,0.4925673007965088,0.0 -41000,0.41150585,3.0063841,,,,,,,,,,,,,,,,, -41100,0.35505363,2.9919977,,,,,,,,,,,,,,,,, -41200,0.40674728,2.9094815,,,,,,,,,,,,,,,,, -41300,0.41899255,3.0013547,,,,,,,,,,,,,,,,, -41400,0.3797595,2.9727833,,,,,,,,,,,,,,,,, -41500,0.35613784,2.9267204,,,,,,,,,,,,,,,,, -41600,0.433192,2.9969761,,,,,,,,,,,,,,,,, -41700,0.48042494,2.9586725,,,,,,,,,,,,,,,,, -41800,0.36095145,3.0062377,,,,,,,,,,,,,,,,, -41900,0.3453037,2.9118292,,,,,,,,,,,,,,,,, -42000,0.33643326,3.0130434,,,,,,,,,,,,,,,,, -42100,0.34439334,2.9663005,,,,,,,,,,,,,,,,, -42200,0.437041,2.9381077,,,,,,,,,,,,,,,,, -42300,0.34411597,2.9107883,,,,,,,,,,,,,,,,, -42400,0.35943854,2.8870413,,,,,,,,,,,,,,,,, -42500,0.35986924,2.9433022,,,,,,,,,,,,,,,,, -42600,0.35365465,2.8997114,,,,,,,,,,,,,,,,, -42700,0.34222013,2.9684854,,,,,,,,,,,,,,,,, -42800,0.3578856,2.8865814,,,,,,,,,,,,,,,,, -42900,0.3479402,2.8756528,,,,,,,,,,,,,,,,, -43000,0.34232584,2.9591703,,,,,,,,,,,,,,,,, -43100,0.36402842,2.9702418,,,,,,,,,,,,,,,,, -43200,0.38171187,2.9765146,,,,,,,,,,,,,,,,, -43300,0.36502105,2.896486,,,,,,,,,,,,,,,,, -43320,,,0.6413927674293518,1.848698258399964,30.77948761714045,0.6622360348701477,1.69834566116333,28.52823324647621,3000.0,0.6720702052116394,1.6318928003311155,28.10957228259788,3003.0,15158.246175765991,25011.8331720829,15158.246175765991,9851.644168376924,0.5287151336669922,0.0 -43400,0.34522328,2.9976504,,,,,,,,,,,,,,,,, -43500,0.35319644,2.932988,,,,,,,,,,,,,,,,, -43600,0.3617107,2.9217124,,,,,,,,,,,,,,,,, -43700,0.37231237,2.9856026,,,,,,,,,,,,,,,,, -43800,0.34769568,2.9896164,,,,,,,,,,,,,,,,, -43900,0.41869125,2.9447849,,,,,,,,,,,,,,,,, -44000,0.3742565,2.9576542,,,,,,,,,,,,,,,,, -44100,0.3496294,2.9131956,,,,,,,,,,,,,,,,, -44200,0.3419783,2.9652276,,,,,,,,,,,,,,,,, -44300,0.38803506,2.9335883,,,,,,,,,,,,,,,,, -44400,0.36320284,2.904542,,,,,,,,,,,,,,,,, -44500,0.35261223,2.9870129,,,,,,,,,,,,,,,,, -44600,0.32318544,2.8719404,,,,,,,,,,,,,,,,, -44700,0.33572066,2.9397905,,,,,,,,,,,,,,,,, -44800,0.36520013,2.96145,,,,,,,,,,,,,,,,, -44900,0.35670453,2.9985454,,,,,,,,,,,,,,,,, -45000,0.32032368,2.9275002,,,,,,,,,,,,,,,,, -45100,0.34379226,2.947551,,,,,,,,,,,,,,,,, -45200,0.37277398,2.9283206,,,,,,,,,,,,,,,,, -45300,0.38955906,3.019192,,,,,,,,,,,,,,,,, -45400,0.3730519,2.965506,,,,,,,,,,,,,,,,, -45500,0.3315798,2.9606287,,,,,,,,,,,,,,,,, -45600,0.3595002,2.9527097,,,,,,,,,,,,,,,,, -45700,0.37972945,2.913202,,,,,,,,,,,,,,,,, -45730,,,0.642379641532898,1.8306187391281128,31.599102572516426,0.6620624661445618,1.6910277605056765,28.548053027735303,3000.0,0.6729068756103516,1.6298214197158811,27.98403783607923,3003.0,15998.48146367073,26347.59620976448,15998.48146367073,10347.062133073809,0.5609724521636963,0.0 -45800,0.31877908,2.8915386,,,,,,,,,,,,,,,,, -45900,0.3532235,2.9464526,,,,,,,,,,,,,,,,, -46000,0.34915167,2.9263847,,,,,,,,,,,,,,,,, -46100,0.39213443,2.9837947,,,,,,,,,,,,,,,,, -46200,0.3780412,2.9770494,,,,,,,,,,,,,,,,, -46300,0.33689436,2.9248323,,,,,,,,,,,,,,,,, -46400,0.39994282,2.8950505,,,,,,,,,,,,,,,,, -46500,0.40489617,2.9386046,,,,,,,,,,,,,,,,, -46600,0.35937598,2.8725734,,,,,,,,,,,,,,,,, -46700,0.3938976,3.0340986,,,,,,,,,,,,,,,,, -46800,0.338953,2.8687644,,,,,,,,,,,,,,,,, -46900,0.35911304,2.910063,,,,,,,,,,,,,,,,, -47000,0.34725553,2.9363797,,,,,,,,,,,,,,,,, -47100,0.35132638,2.9096305,,,,,,,,,,,,,,,,, -47200,0.3205466,2.8679008,,,,,,,,,,,,,,,,, -47300,0.31157926,2.8788116,,,,,,,,,,,,,,,,, -47400,0.38010246,2.9049826,,,,,,,,,,,,,,,,, -47500,0.34563518,2.9272842,,,,,,,,,,,,,,,,, -47600,0.33758187,2.9252646,,,,,,,,,,,,,,,,, -47700,0.38331327,2.9681509,,,,,,,,,,,,,,,,, -47800,0.347349,2.8787754,,,,,,,,,,,,,,,,, -47900,0.37917966,2.9338791,,,,,,,,,,,,,,,,, -48000,0.32270113,2.877758,,,,,,,,,,,,,,,,, -48100,0.36561638,2.9621093,,,,,,,,,,,,,,,,, -48140,,,0.643227219581604,1.8408466577529907,31.4491060968431,0.6614425182342529,1.6888902187347412,28.538099099304283,3000.0,0.674429178237915,1.6186769008636477,28.14769900835641,3003.0,16838.670334100723,27669.10016155243,16838.670334100723,10828.267130374908,0.5938632488250732,0.0 -48200,0.30308795,2.9348507,,,,,,,,,,,,,,,,, -48300,0.3656923,2.9613414,,,,,,,,,,,,,,,,, -48400,0.3565238,2.9247534,,,,,,,,,,,,,,,,, -48500,0.34611458,2.8979766,,,,,,,,,,,,,,,,, -48600,0.3301931,2.983437,,,,,,,,,,,,,,,,, -48700,0.34949088,2.9174733,,,,,,,,,,,,,,,,, -48800,0.31951752,2.884602,,,,,,,,,,,,,,,,, -48900,0.3435319,3.015014,,,,,,,,,,,,,,,,, -49000,0.3623376,2.9439182,,,,,,,,,,,,,,,,, -49100,0.36683774,2.9429867,,,,,,,,,,,,,,,,, -49200,0.37226647,2.9636838,,,,,,,,,,,,,,,,, -49300,0.44781694,2.9638722,,,,,,,,,,,,,,,,, -49400,0.34892938,2.9114292,,,,,,,,,,,,,,,,, -49500,0.40112287,2.9840846,,,,,,,,,,,,,,,,, -49600,0.351469,2.9582837,,,,,,,,,,,,,,,,, -49700,0.38180077,2.8964748,,,,,,,,,,,,,,,,, -49800,0.36001083,2.9206204,,,,,,,,,,,,,,,,, -49900,0.38832048,2.991858,,,,,,,,,,,,,,,,, -50000,0.36119276,2.9365158,,,,,,,,,,,,,,,,, -50100,0.35138074,2.9524667,,,,,,,,,,,,,,,,, -50200,0.37970835,3.0021982,,,,,,,,,,,,,,,,, -50300,0.33164823,2.8614144,,,,,,,,,,,,,,,,, -50400,0.31774983,2.9067225,,,,,,,,,,,,,,,,, -50500,0.39097118,2.9372714,,,,,,,,,,,,,,,,, -50549,,,0.6617670655250549,1.7062171697616575,32.493355001995035,0.6642819046974182,1.6819127798080444,28.50163118714814,3000.0,0.6748009920120239,1.6188348531723022,28.20650007439196,3003.0,17678.634654521942,29046.603624105453,17678.634654521942,11365.696782827376,0.6260786056518555,0.0 -50600,0.36424735,2.9333467,,,,,,,,,,,,,,,,, -50700,0.3546461,3.012265,,,,,,,,,,,,,,,,, -50800,0.3235164,2.8694072,,,,,,,,,,,,,,,,, -50900,0.41561487,2.9384365,,,,,,,,,,,,,,,,, -51000,0.3233962,2.9210858,,,,,,,,,,,,,,,,, -51100,0.35556623,2.8938637,,,,,,,,,,,,,,,,, -51200,0.38061157,2.978263,,,,,,,,,,,,,,,,, -51300,0.41310495,2.953616,,,,,,,,,,,,,,,,, -51400,0.36999953,2.9884548,,,,,,,,,,,,,,,,, -51500,0.3359097,2.8300796,,,,,,,,,,,,,,,,, -51600,0.3509856,2.9575253,,,,,,,,,,,,,,,,, -51700,0.4980116,2.9100103,,,,,,,,,,,,,,,,, -51800,0.32030365,2.981756,,,,,,,,,,,,,,,,, -51900,0.36580762,2.9552398,,,,,,,,,,,,,,,,, -52000,0.349761,2.950932,,,,,,,,,,,,,,,,, -52100,0.41756156,2.9690154,,,,,,,,,,,,,,,,, -52200,0.35098046,2.895122,,,,,,,,,,,,,,,,, -52300,0.36078274,2.979927,,,,,,,,,,,,,,,,, -52400,0.32393274,2.940219,,,,,,,,,,,,,,,,, -52500,0.34393936,2.9039958,,,,,,,,,,,,,,,,, -52600,0.33878177,2.893368,,,,,,,,,,,,,,,,, -52700,0.35129917,2.945137,,,,,,,,,,,,,,,,, -52800,0.31941223,2.9290786,,,,,,,,,,,,,,,,, -52900,0.332716,2.897949,,,,,,,,,,,,,,,,, -52959,,,0.6417939066886902,1.8330038785934448,31.521758366440302,0.6663029789924622,1.6756658554077148,28.710437477264808,3000.0,0.6778455972671509,1.6047146320343018,28.40794194410856,3003.0,18518.534535884857,30361.687136888504,18518.534535884857,11840.77304840088,0.6584103107452393,0.0 -53000,0.5674177,2.9506311,,,,,,,,,,,,,,,,, -53100,0.3445877,2.9219763,,,,,,,,,,,,,,,,, -53200,0.39961052,2.907761,,,,,,,,,,,,,,,,, -53300,0.3155157,2.871451,,,,,,,,,,,,,,,,, -53400,0.34020686,2.8861842,,,,,,,,,,,,,,,,, -53500,0.33793324,2.9159563,,,,,,,,,,,,,,,,, -53600,0.3517171,2.9204113,,,,,,,,,,,,,,,,, -53700,0.33374968,2.8397768,,,,,,,,,,,,,,,,, -53800,0.32498077,2.909721,,,,,,,,,,,,,,,,, -53900,0.3334037,2.9191008,,,,,,,,,,,,,,,,, -54000,0.35646248,2.9259527,,,,,,,,,,,,,,,,, -54100,0.34363276,2.8601,,,,,,,,,,,,,,,,, -54200,0.388806,2.9964054,,,,,,,,,,,,,,,,, -54300,0.36341342,2.8421257,,,,,,,,,,,,,,,,, -54400,0.38323605,2.8705478,,,,,,,,,,,,,,,,, -54500,0.33209708,2.8778584,,,,,,,,,,,,,,,,, -54600,0.3519774,2.9420006,,,,,,,,,,,,,,,,, -54700,0.34808183,2.8826585,,,,,,,,,,,,,,,,, -54800,0.36797887,2.8971577,,,,,,,,,,,,,,,,, -54900,0.34192252,2.9038405,,,,,,,,,,,,,,,,, -55000,0.32783723,2.8953846,,,,,,,,,,,,,,,,, -55100,0.36496392,2.886903,,,,,,,,,,,,,,,,, -55200,0.3292008,2.8568563,,,,,,,,,,,,,,,,, -55300,0.37262335,2.9699357,,,,,,,,,,,,,,,,, -55368,,,0.6438504457473755,1.8326653242111208,31.664923167422646,0.664095938205719,1.6701427698135376,28.63440920485693,3000.0,0.6779385209083557,1.60168719291687,28.4380119668229,3003.0,19358.49755549431,31703.02485537529,19358.49755549431,12342.030277490616,0.6971557140350342,0.0 -55400,0.37744546,2.939764,,,,,,,,,,,,,,,,, -55500,0.33950144,2.933517,,,,,,,,,,,,,,,,, -55600,0.34290922,2.9539418,,,,,,,,,,,,,,,,, -55700,0.346322,2.9544432,,,,,,,,,,,,,,,,, -55800,0.34978062,2.9356825,,,,,,,,,,,,,,,,, -55900,0.35593873,2.8575625,,,,,,,,,,,,,,,,, -56000,0.3232321,2.8629696,,,,,,,,,,,,,,,,, -56100,0.33971563,2.846754,,,,,,,,,,,,,,,,, -56200,0.33304626,2.9357479,,,,,,,,,,,,,,,,, -56300,0.40885416,2.9100153,,,,,,,,,,,,,,,,, -56400,0.37215164,2.938904,,,,,,,,,,,,,,,,, -56500,0.35120878,2.8801084,,,,,,,,,,,,,,,,, -56600,0.33566403,2.9792757,,,,,,,,,,,,,,,,, -56700,0.35702962,2.8373513,,,,,,,,,,,,,,,,, -56800,0.36675116,2.968954,,,,,,,,,,,,,,,,, -56900,0.35461873,2.8815966,,,,,,,,,,,,,,,,, -57000,0.47986695,2.9281733,,,,,,,,,,,,,,,,, -57100,0.33787665,2.8949976,,,,,,,,,,,,,,,,, -57200,0.3556958,2.933138,,,,,,,,,,,,,,,,, -57300,0.33996728,2.9896228,,,,,,,,,,,,,,,,, -57400,0.44231004,2.909468,,,,,,,,,,,,,,,,, -57500,0.34697828,2.8631728,,,,,,,,,,,,,,,,, -57600,0.3633189,2.8502252,,,,,,,,,,,,,,,,, -57700,0.33262295,2.9119465,,,,,,,,,,,,,,,,, -57777,,,0.6529799699783325,1.7605587244033811,32.104956986509244,0.6671584844589233,1.6584866046905518,28.91055358960645,3000.0,0.6783917546272278,1.5919768810272217,28.49080923108632,3003.0,20198.622362852097,33020.00865364075,20198.622362852097,12818.775514364244,0.7295718193054199,0.0 -57800,0.3628075,2.8903065,,,,,,,,,,,,,,,,, -57900,0.35249835,2.8745067,,,,,,,,,,,,,,,,, -58000,0.34550494,2.860415,,,,,,,,,,,,,,,,, -58100,0.33186132,2.8267756,,,,,,,,,,,,,,,,, -58200,0.33075693,2.8536708,,,,,,,,,,,,,,,,, -58300,0.35358047,2.9018776,,,,,,,,,,,,,,,,, -58400,0.33069906,2.8654742,,,,,,,,,,,,,,,,, -58500,0.31573293,2.8640301,,,,,,,,,,,,,,,,, -58600,0.30843726,2.9028866,,,,,,,,,,,,,,,,, -58700,0.33064577,2.9168127,,,,,,,,,,,,,,,,, -58800,0.36752713,2.9229164,,,,,,,,,,,,,,,,, -58900,0.3753418,2.9232247,,,,,,,,,,,,,,,,, -59000,0.367803,2.8963392,,,,,,,,,,,,,,,,, -59100,0.3287156,2.914113,,,,,,,,,,,,,,,,, -59200,0.35650578,2.9072835,,,,,,,,,,,,,,,,, -59300,0.33359334,2.9080546,,,,,,,,,,,,,,,,, -59400,0.33958572,2.9232986,,,,,,,,,,,,,,,,, -59500,0.37222755,2.895827,,,,,,,,,,,,,,,,, -59600,0.3478281,2.8805077,,,,,,,,,,,,,,,,, -59700,0.32622746,2.8501265,,,,,,,,,,,,,,,,, -59800,0.3614963,2.8890138,,,,,,,,,,,,,,,,, -59900,0.36286002,2.7990944,,,,,,,,,,,,,,,,, -60000,0.31841612,2.8613527,,,,,,,,,,,,,,,,, -60100,0.34565508,2.9143832,,,,,,,,,,,,,,,,, -60187,,,0.6493332982063293,1.788055658340454,31.857854688348205,0.6688447594642639,1.6495282649993896,28.90816655813359,3000.0,0.6815757751464844,1.579599380493164,28.60529602499556,3003.0,21038.77587127685,34336.24448752403,21038.77587127685,13294.7455265522,0.7640244960784912,0.0 -60200,0.32199562,2.9292207,,,,,,,,,,,,,,,,, -60300,0.38764375,2.9036262,,,,,,,,,,,,,,,,, -60400,0.3520466,2.867693,,,,,,,,,,,,,,,,, -60500,0.3318672,2.8525999,,,,,,,,,,,,,,,,, -60600,0.3578865,2.9030538,,,,,,,,,,,,,,,,, -60700,0.33058855,2.8569887,,,,,,,,,,,,,,,,, -60800,0.35316256,2.9466238,,,,,,,,,,,,,,,,, -60900,0.3605609,2.987785,,,,,,,,,,,,,,,,, -61000,0.37561437,2.9181292,,,,,,,,,,,,,,,,, -61100,0.43908462,2.878644,,,,,,,,,,,,,,,,, -61200,0.36017773,2.8878112,,,,,,,,,,,,,,,,, -61300,0.3802557,2.8831875,,,,,,,,,,,,,,,,, -61400,0.35434183,2.8926477,,,,,,,,,,,,,,,,, -61500,0.33153126,2.8767247,,,,,,,,,,,,,,,,, -61600,0.35447466,2.8280916,,,,,,,,,,,,,,,,, -61700,0.35852268,2.8710337,,,,,,,,,,,,,,,,, -61800,0.3601364,2.9219005,,,,,,,,,,,,,,,,, -61900,0.33814865,2.8501825,,,,,,,,,,,,,,,,, -62000,0.3508383,2.9868033,,,,,,,,,,,,,,,,, -62100,0.3650156,2.8164277,,,,,,,,,,,,,,,,, -62200,0.33813435,2.8861763,,,,,,,,,,,,,,,,, -62300,0.38900432,2.8833323,,,,,,,,,,,,,,,,, -62400,0.3611663,2.866753,,,,,,,,,,,,,,,,, -62500,0.35519204,2.9193358,,,,,,,,,,,,,,,,, -62597,,,0.6954829692840576,1.5248199701309204,35.48335512968942,0.6700350642204285,1.6441853046417236,29.23137089067413,3000.0,0.6829004883766174,1.5770306587219238,28.931010286915832,3003.0,21878.77816271782,35638.94716525078,21878.77816271782,13757.336695432665,0.7973370552062988,0.0 -62600,0.3808213,2.892209,,,,,,,,,,,,,,,,, -62700,0.3427107,2.9100754,,,,,,,,,,,,,,,,, -62800,0.34681335,2.8487515,,,,,,,,,,,,,,,,, -62900,0.37804705,2.8720646,,,,,,,,,,,,,,,,, -63000,0.35299885,2.9141471,,,,,,,,,,,,,,,,, -63100,0.35239795,2.8361392,,,,,,,,,,,,,,,,, -63200,0.33944133,2.8004694,,,,,,,,,,,,,,,,, -63300,0.35987058,2.8736563,,,,,,,,,,,,,,,,, -63400,0.39864194,2.916272,,,,,,,,,,,,,,,,, -63500,0.34820575,2.9303024,,,,,,,,,,,,,,,,, -63600,0.37252715,2.925956,,,,,,,,,,,,,,,,, -63700,0.38739505,2.9156778,,,,,,,,,,,,,,,,, -63800,0.3501081,2.911757,,,,,,,,,,,,,,,,, -63900,0.34876603,2.890685,,,,,,,,,,,,,,,,, -64000,0.34742245,2.9052117,,,,,,,,,,,,,,,,, -64100,0.3380297,2.7817092,,,,,,,,,,,,,,,,, -64200,0.329045,2.963474,,,,,,,,,,,,,,,,, -64300,0.33026594,2.9023244,,,,,,,,,,,,,,,,, -64400,0.34637743,2.9943693,,,,,,,,,,,,,,,,, -64500,0.35822466,2.9299948,,,,,,,,,,,,,,,,, -64600,0.36374438,2.9059305,,,,,,,,,,,,,,,,, -64700,0.33538646,2.8165152,,,,,,,,,,,,,,,,, -64800,0.3464459,2.9019248,,,,,,,,,,,,,,,,, -64900,0.3370749,2.9036007,,,,,,,,,,,,,,,,, -65000,0.37120366,2.888855,,,,,,,,,,,,,,,,, -65007,,,0.6498906016349792,1.779046893119812,32.04606986938242,0.6694399118423462,1.64409077167511,28.851999438402675,3000.0,0.6821103096008301,1.5699652433395386,28.572309209482075,3003.0,22718.93424129486,36967.4750084877,22718.93424129486,14245.597965955734,0.8305079936981201,0.0 -65100,0.38256153,2.8633614,,,,,,,,,,,,,,,,, -65200,0.34142852,2.8908741,,,,,,,,,,,,,,,,, -65300,0.35608256,2.7689438,,,,,,,,,,,,,,,,, -65400,0.3617513,2.9599378,,,,,,,,,,,,,,,,, -65500,0.37053305,2.9084387,,,,,,,,,,,,,,,,, -65600,0.39040107,2.9169328,,,,,,,,,,,,,,,,, -65700,0.41702053,2.808055,,,,,,,,,,,,,,,,, -65800,0.3304709,2.8508084,,,,,,,,,,,,,,,,, -65900,0.33863157,2.889853,,,,,,,,,,,,,,,,, -66000,0.37140322,2.8354049,,,,,,,,,,,,,,,,, -66100,0.37686574,2.8550313,,,,,,,,,,,,,,,,, -66200,0.34267312,2.903106,,,,,,,,,,,,,,,,, -66300,0.3546967,2.8415508,,,,,,,,,,,,,,,,, -66400,0.348417,2.8291926,,,,,,,,,,,,,,,,, -66500,0.36286363,2.8861055,,,,,,,,,,,,,,,,, -66600,0.3510441,2.9220397,,,,,,,,,,,,,,,,, -66700,0.34860015,2.8768065,,,,,,,,,,,,,,,,, -66800,0.3864098,2.9279568,,,,,,,,,,,,,,,,, -66900,0.34505674,2.7992241,,,,,,,,,,,,,,,,, -67000,0.40406287,2.9095435,,,,,,,,,,,,,,,,, -67100,0.35759804,2.9148505,,,,,,,,,,,,,,,,, -67200,0.34512523,2.8467486,,,,,,,,,,,,,,,,, -67300,0.35844532,2.833888,,,,,,,,,,,,,,,,, -67400,0.3514582,2.871887,,,,,,,,,,,,,,,,, -67416,,,0.6513877511024475,1.778394341468811,31.90224548614695,0.6698490977287292,1.6329776048660278,29.310274944733973,3000.0,0.6850618720054626,1.5574126243591309,28.808396256578643,3003.0,23559.06813430786,38276.82657814026,23559.06813430786,14714.698032855988,0.8699114322662354,0.0 -67500,0.3452799,2.8366663,,,,,,,,,,,,,,,,, -67600,0.36529276,2.85087,,,,,,,,,,,,,,,,, -67700,0.36074668,2.832023,,,,,,,,,,,,,,,,, -67800,0.33284774,2.8712466,,,,,,,,,,,,,,,,, -67900,0.35250932,2.9305503,,,,,,,,,,,,,,,,, -68000,0.33827257,2.7727108,,,,,,,,,,,,,,,,, -68100,0.4207243,2.9022577,,,,,,,,,,,,,,,,, -68200,0.36270243,2.9878275,,,,,,,,,,,,,,,,, -68300,0.36333978,2.8567019,,,,,,,,,,,,,,,,, -68400,0.35787034,2.818819,,,,,,,,,,,,,,,,, -68500,0.31939504,2.859669,,,,,,,,,,,,,,,,, -68600,0.35018238,2.8695834,,,,,,,,,,,,,,,,, -68700,0.36385906,2.8616376,,,,,,,,,,,,,,,,, -68800,0.36092,2.9571912,,,,,,,,,,,,,,,,, -68900,0.33820942,2.8464744,,,,,,,,,,,,,,,,, -69000,0.40124136,2.8169131,,,,,,,,,,,,,,,,, -69100,0.3991407,2.915351,,,,,,,,,,,,,,,,, -69200,0.38116506,2.8601632,,,,,,,,,,,,,,,,, -69300,0.33977577,2.8406816,,,,,,,,,,,,,,,,, -69400,0.3534544,2.8690052,,,,,,,,,,,,,,,,, -69500,0.36619705,2.8864443,,,,,,,,,,,,,,,,, -69600,0.37749583,2.8138514,,,,,,,,,,,,,,,,, -69700,0.37002656,2.9521642,,,,,,,,,,,,,,,,, -69800,0.348148,2.8205106,,,,,,,,,,,,,,,,, -69825,,,0.6634706258773804,1.7029547691345217,32.673334074985306,0.6726140975952148,1.6234630346298218,29.223958515493404,3000.0,0.6853756308555603,1.5609307289123535,28.84675359208095,3003.0,24399.16579413414,39619.68317270279,24399.16579413414,15217.340163707731,0.9094698429107666,0.0 -69900,0.39345497,2.8922563,,,,,,,,,,,,,,,,, -70000,0.35266966,2.7878828,,,,,,,,,,,,,,,,, -70100,0.33743823,2.8239954,,,,,,,,,,,,,,,,, -70200,0.399245,2.837888,,,,,,,,,,,,,,,,, -70300,0.35832492,2.8242548,,,,,,,,,,,,,,,,, -70400,0.39556676,2.8112059,,,,,,,,,,,,,,,,, -70500,0.39765185,2.9248493,,,,,,,,,,,,,,,,, -70600,0.3673901,2.8151321,,,,,,,,,,,,,,,,, -70700,0.43280336,2.8282735,,,,,,,,,,,,,,,,, -70800,0.3848774,2.8573768,,,,,,,,,,,,,,,,, -70900,0.38781705,2.8483198,,,,,,,,,,,,,,,,, -71000,0.35442275,2.8328528,,,,,,,,,,,,,,,,, -71100,0.36027762,2.8459852,,,,,,,,,,,,,,,,, -71200,0.35761905,2.8632717,,,,,,,,,,,,,,,,, -71300,0.3507256,2.8210359,,,,,,,,,,,,,,,,, -71400,0.38131133,2.8737977,,,,,,,,,,,,,,,,, -71500,0.35894933,2.8871222,,,,,,,,,,,,,,,,, -71600,0.37864614,2.8438938,,,,,,,,,,,,,,,,, -71700,0.42112952,2.8262887,,,,,,,,,,,,,,,,, -71800,0.3438642,2.847652,,,,,,,,,,,,,,,,, -71900,0.35476157,2.74874,,,,,,,,,,,,,,,,, -72000,0.35589153,2.8957553,,,,,,,,,,,,,,,,, -72100,0.35864565,2.8365662,,,,,,,,,,,,,,,,, -72200,0.36610383,2.880483,,,,,,,,,,,,,,,,, -72234,,,0.6572171449661255,1.744814395904541,32.34493762341491,0.6746351718902588,1.619512677192688,29.395188395213623,3000.0,0.6883621215820312,1.5452096462249756,29.265421368484017,3003.0,25239.123465538025,40969.03985142708,25239.123465538025,15726.62624669075,0.9434311389923096,0.0 -72300,0.33681282,2.8011825,,,,,,,,,,,,,,,,, -72400,0.42202955,2.8339942,,,,,,,,,,,,,,,,, -72500,0.41905233,2.83929,,,,,,,,,,,,,,,,, -72600,0.37581086,2.7822475,,,,,,,,,,,,,,,,, -72700,0.35820475,2.8465164,,,,,,,,,,,,,,,,, -72800,0.35852608,2.872096,,,,,,,,,,,,,,,,, -72900,0.3706104,2.8581653,,,,,,,,,,,,,,,,, -73000,0.35818174,2.8698547,,,,,,,,,,,,,,,,, -73100,0.36884236,2.8153853,,,,,,,,,,,,,,,,, -73200,0.33685684,2.84683,,,,,,,,,,,,,,,,, -73300,0.33558,2.84066,,,,,,,,,,,,,,,,, -73400,0.36326882,2.8844907,,,,,,,,,,,,,,,,, -73500,0.34602314,2.7987535,,,,,,,,,,,,,,,,, -73600,0.36721867,2.788141,,,,,,,,,,,,,,,,, -73700,0.341525,2.8663878,,,,,,,,,,,,,,,,, -73800,0.3654016,2.8836536,,,,,,,,,,,,,,,,, -73900,0.36086145,2.8566673,,,,,,,,,,,,,,,,, -74000,0.3822784,2.8811328,,,,,,,,,,,,,,,,, -74100,0.36635646,2.8412008,,,,,,,,,,,,,,,,, -74200,0.34356794,2.8414245,,,,,,,,,,,,,,,,, -74300,0.34364197,2.8533657,,,,,,,,,,,,,,,,, -74400,0.3671325,2.8295221,,,,,,,,,,,,,,,,, -74500,0.3184331,2.8353631,,,,,,,,,,,,,,,,, -74600,0.3689911,2.7722144,,,,,,,,,,,,,,,,, -74644,,,0.6548944115638733,1.7606871128082275,32.62063170484274,0.6750690937042236,1.6131248474121094,29.763167120385383,3000.0,0.6876532435417175,1.5402637720108032,29.1380813702546,3003.0,26079.2656018734,42292.14107131958,26079.2656018734,16209.47501564026,0.977283239364624,0.0 -74700,0.3660091,2.8499491,,,,,,,,,,,,,,,,, -74800,0.37721696,2.8868616,,,,,,,,,,,,,,,,, -74900,0.33714893,2.786221,,,,,,,,,,,,,,,,, -75000,0.37494934,2.875732,,,,,,,,,,,,,,,,, -75100,0.35299933,2.8188157,,,,,,,,,,,,,,,,, -75200,0.36290583,2.812223,,,,,,,,,,,,,,,,, -75300,0.36390215,2.8315783,,,,,,,,,,,,,,,,, -75400,0.35436442,2.807307,,,,,,,,,,,,,,,,, -75500,0.3542114,2.7937195,,,,,,,,,,,,,,,,, -75600,0.37523192,2.8505251,,,,,,,,,,,,,,,,, -75700,0.33529705,2.7829602,,,,,,,,,,,,,,,,, -75800,0.35170668,2.8352036,,,,,,,,,,,,,,,,, -75900,0.3771898,2.9222603,,,,,,,,,,,,,,,,, -76000,0.38247228,2.8836877,,,,,,,,,,,,,,,,, -76100,0.36896235,2.7644663,,,,,,,,,,,,,,,,, -76200,0.34750873,2.8298345,,,,,,,,,,,,,,,,, -76300,0.3697449,2.8653948,,,,,,,,,,,,,,,,, -76400,0.41992956,2.9242547,,,,,,,,,,,,,,,,, -76500,0.36517346,2.8093417,,,,,,,,,,,,,,,,, -76600,0.36504602,2.8235466,,,,,,,,,,,,,,,,, -76700,0.35393396,2.8024194,,,,,,,,,,,,,,,,, -76800,0.35594484,2.845332,,,,,,,,,,,,,,,,, -76900,0.38340598,2.9024482,,,,,,,,,,,,,,,,, -77000,0.3693536,2.80184,,,,,,,,,,,,,,,,, -77053,,,0.6639668941497803,1.694888710975647,32.92362728757387,0.6748707294464111,1.6068533658981323,29.34807239685257,3000.0,0.6893498301506042,1.530578374862671,29.18517059563849,3003.0,26919.287441253666,43667.97906756401,26919.287441253666,16745.17874765396,1.0129663944244385,0.0 -77100,0.35105085,2.8426094,,,,,,,,,,,,,,,,, -77200,0.34897032,2.83682,,,,,,,,,,,,,,,,, -77300,0.37605056,2.8332183,,,,,,,,,,,,,,,,, -77400,0.3825921,2.8509974,,,,,,,,,,,,,,,,, -77500,0.343378,2.8004622,,,,,,,,,,,,,,,,, -77600,0.38175306,2.7610657,,,,,,,,,,,,,,,,, -77700,0.3778841,2.80073,,,,,,,,,,,,,,,,, -77800,0.38908172,2.8335733,,,,,,,,,,,,,,,,, -77900,0.39540985,2.9430056,,,,,,,,,,,,,,,,, -78000,0.39043015,2.8774137,,,,,,,,,,,,,,,,, -78100,0.38449886,2.808095,,,,,,,,,,,,,,,,, -78200,0.40845466,2.837265,,,,,,,,,,,,,,,,, -78300,0.37301812,2.8530765,,,,,,,,,,,,,,,,, -78400,0.35760802,2.8396838,,,,,,,,,,,,,,,,, -78500,0.4295464,2.8842452,,,,,,,,,,,,,,,,, -78600,0.38645425,2.8789532,,,,,,,,,,,,,,,,, -78700,0.37864718,2.8406224,,,,,,,,,,,,,,,,, -78800,0.35673207,2.8301058,,,,,,,,,,,,,,,,, -78900,0.34772024,2.8156948,,,,,,,,,,,,,,,,, -79000,0.3938065,2.8295267,,,,,,,,,,,,,,,,, -79100,0.42249814,2.8217404,,,,,,,,,,,,,,,,, -79200,0.37463364,2.7916634,,,,,,,,,,,,,,,,, -79300,0.35929498,2.7981741,,,,,,,,,,,,,,,,, -79400,0.37463674,2.7895563,,,,,,,,,,,,,,,,, -79462,,,0.6601611971855164,1.7258120775222778,32.68691853658534,0.6769289970397949,1.5998398065567017,29.752262409047173,3000.0,0.6904653906822205,1.5268282890319824,29.331326809471992,3003.0,27759.34330010414,44991.61689281464,27759.34330010414,17228.647493600845,1.0498454570770264,0.0 -79500,0.40447247,2.8096552,,,,,,,,,,,,,,,,, -79600,0.5787765,2.8223221,,,,,,,,,,,,,,,,, -79700,0.386214,2.896508,,,,,,,,,,,,,,,,, -79800,0.3886136,2.8940024,,,,,,,,,,,,,,,,, -79900,0.36544788,2.8131628,,,,,,,,,,,,,,,,, -80000,0.38063097,2.783597,,,,,,,,,,,,,,,,, -80100,0.38254103,2.900853,,,,,,,,,,,,,,,,, -80200,0.37278125,2.803612,,,,,,,,,,,,,,,,, -80300,0.3683521,2.8732393,,,,,,,,,,,,,,,,, -80400,0.3885896,2.8231592,,,,,,,,,,,,,,,,, -80500,0.37677196,2.8362002,,,,,,,,,,,,,,,,, -80600,0.35599,2.8175008,,,,,,,,,,,,,,,,, -80700,0.4235729,2.8944113,,,,,,,,,,,,,,,,, -80800,0.41268983,2.85784,,,,,,,,,,,,,,,,, -80900,0.375014,2.8168702,,,,,,,,,,,,,,,,, -81000,0.36809745,2.7536848,,,,,,,,,,,,,,,,, -81100,0.34729004,2.8074954,,,,,,,,,,,,,,,,, -81200,0.3695953,2.8394976,,,,,,,,,,,,,,,,, -81300,0.39713767,2.7895222,,,,,,,,,,,,,,,,, -81400,0.36110094,2.79653,,,,,,,,,,,,,,,,, -81500,0.35031468,2.738251,,,,,,,,,,,,,,,,, -81600,0.38382584,2.7950256,,,,,,,,,,,,,,,,, -81700,0.3494029,2.828563,,,,,,,,,,,,,,,,, -81800,0.37710702,2.8269258,,,,,,,,,,,,,,,,, -81870,,,0.6784107685089111,1.6096141338348389,33.84269230391496,0.6793591976165771,1.5915672779083252,29.80917299937564,3000.0,0.6923711895942688,1.518932819366455,29.556238156403044,3003.0,28599.2469124794,46325.63169527054,28599.2469124794,17722.64498400688,1.086064338684082,0.0 -81900,0.38220546,2.8153234,,,,,,,,,,,,,,,,, -82000,0.36661398,2.7869349,,,,,,,,,,,,,,,,, -82100,0.41503248,2.86915,,,,,,,,,,,,,,,,, -82200,0.3849406,2.7629933,,,,,,,,,,,,,,,,, -82300,0.3836513,2.8847306,,,,,,,,,,,,,,,,, -82400,0.38091272,2.8335345,,,,,,,,,,,,,,,,, -82500,0.36948994,2.8561244,,,,,,,,,,,,,,,,, -82600,0.3502618,2.765516,,,,,,,,,,,,,,,,, -82700,0.37285408,2.8148146,,,,,,,,,,,,,,,,, -82800,0.3598685,2.809273,,,,,,,,,,,,,,,,, -82900,0.38432193,2.831128,,,,,,,,,,,,,,,,, -83000,0.37644956,2.8040464,,,,,,,,,,,,,,,,, -83100,0.39935482,2.7476072,,,,,,,,,,,,,,,,, -83200,0.40091613,2.795568,,,,,,,,,,,,,,,,, -83300,0.36112472,2.8278997,,,,,,,,,,,,,,,,, -83400,0.39895788,2.7498503,,,,,,,,,,,,,,,,, -83500,0.37352794,2.8068423,,,,,,,,,,,,,,,,, -83600,0.35811588,2.79608,,,,,,,,,,,,,,,,, -83700,0.39911732,2.7998254,,,,,,,,,,,,,,,,, -83800,0.38892084,2.8596327,,,,,,,,,,,,,,,,, -83900,0.3739509,2.7705002,,,,,,,,,,,,,,,,, -84000,0.42624784,2.8737752,,,,,,,,,,,,,,,,, -84100,0.37532133,2.8010974,,,,,,,,,,,,,,,,, -84200,0.37005517,2.740713,,,,,,,,,,,,,,,,, -84279,,,0.6645780205726624,1.6899528503417969,32.87321657993862,0.6802023649215698,1.5834990739822388,29.55673975736919,3000.0,0.6955435872077942,1.5063549280166626,29.76918535946277,3003.0,29439.36522555352,47654.22263765335,29439.36522555352,18211.00092744828,1.1236786842346191,0.0 -84300,0.3933884,2.8230085,,,,,,,,,,,,,,,,, -84400,0.39270982,2.8221607,,,,,,,,,,,,,,,,, -84500,0.39598885,2.8183577,,,,,,,,,,,,,,,,, -84600,0.38878337,2.7293873,,,,,,,,,,,,,,,,, -84700,0.38169822,2.785972,,,,,,,,,,,,,,,,, -84800,0.42446914,2.8533144,,,,,,,,,,,,,,,,, -84900,0.38346842,2.8035378,,,,,,,,,,,,,,,,, -85000,0.36655936,2.7616923,,,,,,,,,,,,,,,,, -85100,0.38418669,2.728557,,,,,,,,,,,,,,,,, -85200,0.42209452,2.8244128,,,,,,,,,,,,,,,,, -85300,0.37592223,2.8392122,,,,,,,,,,,,,,,,, -85400,0.39234585,2.775175,,,,,,,,,,,,,,,,, -85500,0.39085278,2.785321,,,,,,,,,,,,,,,,, -85600,0.3844825,2.7943044,,,,,,,,,,,,,,,,, -85700,0.41405424,2.8099985,,,,,,,,,,,,,,,,, -85800,0.41189143,2.7799358,,,,,,,,,,,,,,,,, -85900,0.39096874,2.8471997,,,,,,,,,,,,,,,,, -86000,0.38075867,2.8941371,,,,,,,,,,,,,,,,, -86100,0.38148013,2.854979,,,,,,,,,,,,,,,,, -86200,0.40628952,2.8121736,,,,,,,,,,,,,,,,, -86300,0.4255727,2.7946544,,,,,,,,,,,,,,,,, -86400,0.3699792,2.744417,,,,,,,,,,,,,,,,, -86500,0.38555637,2.7969007,,,,,,,,,,,,,,,,, -86600,0.44799325,2.814606,,,,,,,,,,,,,,,,, -86688,,,0.6623689532279968,1.7069313526153564,33.333061211211565,0.6798799633979797,1.5803080797195437,29.87338660571638,3000.0,0.6947998404502869,1.4972296953201294,30.07424609365526,3003.0,30279.430206775665,48972.770466566086,30279.430206775665,18689.371471881863,1.1595537662506104,0.0 -86700,0.41040537,2.7940354,,,,,,,,,,,,,,,,, -86800,0.4297653,2.764508,,,,,,,,,,,,,,,,, -86900,0.4103508,2.800479,,,,,,,,,,,,,,,,, -87000,0.38950813,2.748076,,,,,,,,,,,,,,,,, -87100,0.39144784,2.8232725,,,,,,,,,,,,,,,,, -87200,0.379431,2.7731926,,,,,,,,,,,,,,,,, -87300,0.38511115,2.774627,,,,,,,,,,,,,,,,, -87400,0.40277392,2.8202796,,,,,,,,,,,,,,,,, -87500,0.4036035,2.7870874,,,,,,,,,,,,,,,,, -87600,0.39914817,2.8143713,,,,,,,,,,,,,,,,, -87700,0.40474162,2.8159502,,,,,,,,,,,,,,,,, -87800,0.38058385,2.7884545,,,,,,,,,,,,,,,,, -87900,0.4120359,2.8588881,,,,,,,,,,,,,,,,, -88000,0.41841638,2.7513993,,,,,,,,,,,,,,,,, -88100,0.38398066,2.8018098,,,,,,,,,,,,,,,,, -88200,0.39390057,2.7995148,,,,,,,,,,,,,,,,, -88300,0.40287474,2.7232854,,,,,,,,,,,,,,,,, -88400,0.40826496,2.8477058,,,,,,,,,,,,,,,,, -88500,0.45038998,2.7585256,,,,,,,,,,,,,,,,, -88600,0.40949786,2.7957494,,,,,,,,,,,,,,,,, -88700,0.37662596,2.8181124,,,,,,,,,,,,,,,,, -88800,0.40747568,2.8367324,,,,,,,,,,,,,,,,, -88900,0.40472585,2.6792815,,,,,,,,,,,,,,,,, -89000,0.39874417,2.7444997,,,,,,,,,,,,,,,,, -89097,,,0.6741820573806763,1.636649250984192,33.7021119458403,0.6820126175880432,1.5686326026916504,29.77797063662827,3000.0,0.6962872743606567,1.4902687072753906,29.97291099023428,3003.0,31119.458785533905,50296.33786916733,31119.458785533905,19172.79546570778,1.1961100101470947,0.0 -89100,0.3825082,2.822919,,,,,,,,,,,,,,,,, -89200,0.4011492,2.828002,,,,,,,,,,,,,,,,, -89300,0.39745322,2.7807524,,,,,,,,,,,,,,,,, -89400,0.4092346,2.6941574,,,,,,,,,,,,,,,,, -89500,0.44069853,2.7586682,,,,,,,,,,,,,,,,, -89600,0.40838853,2.7851114,,,,,,,,,,,,,,,,, -89700,0.39232665,2.7325468,,,,,,,,,,,,,,,,, -89800,0.42763352,2.7687485,,,,,,,,,,,,,,,,, -89900,0.41943508,2.7440376,,,,,,,,,,,,,,,,, -90000,0.43769792,2.799989,,,,,,,,,,,,,,,,, -90100,0.3936336,2.81843,,,,,,,,,,,,,,,,, -90200,0.41250348,2.7929287,,,,,,,,,,,,,,,,, -90300,0.39218864,2.7689002,,,,,,,,,,,,,,,,, -90400,0.3919901,2.81744,,,,,,,,,,,,,,,,, -90500,0.39132074,2.7768908,,,,,,,,,,,,,,,,, -90600,0.4158072,2.7965736,,,,,,,,,,,,,,,,, -90700,0.41524315,2.7128158,,,,,,,,,,,,,,,,, -90800,0.4170445,2.7683432,,,,,,,,,,,,,,,,, -90900,0.4024598,2.7788808,,,,,,,,,,,,,,,,, -91000,0.42692718,2.8181913,,,,,,,,,,,,,,,,, -91100,0.41925648,2.7635517,,,,,,,,,,,,,,,,, -91200,0.4205291,2.767354,,,,,,,,,,,,,,,,, -91300,0.44957662,2.7889423,,,,,,,,,,,,,,,,, -91400,0.43157047,2.7276444,,,,,,,,,,,,,,,,, -91500,0.40481254,2.766114,,,,,,,,,,,,,,,,, -91506,,,0.667885959148407,1.6756535768508911,33.10366408069063,0.6826573610305786,1.5624314546585083,30.156964677978845,3000.0,0.6995874643325806,1.481274127960205,30.32512532130016,3003.0,31959.54501605034,51625.13171696663,31959.54501605034,19661.390310049057,1.2334024906158447,0.0 -91600,0.40931785,2.7055237,,,,,,,,,,,,,,,,, -91700,0.4086039,2.792702,,,,,,,,,,,,,,,,, -91800,0.44019362,2.7255738,,,,,,,,,,,,,,,,, -91900,0.41325107,2.7369413,,,,,,,,,,,,,,,,, -92000,0.446276,2.8257089,,,,,,,,,,,,,,,,, -92100,0.44141418,2.815012,,,,,,,,,,,,,,,,, -92200,0.42429283,2.7562168,,,,,,,,,,,,,,,,, -92300,0.4234354,2.730456,,,,,,,,,,,,,,,,, -92400,0.39435828,2.7248166,,,,,,,,,,,,,,,,, -92500,0.45115086,2.8170042,,,,,,,,,,,,,,,,, -92600,0.4326224,2.838362,,,,,,,,,,,,,,,,, -92700,0.44080797,2.789548,,,,,,,,,,,,,,,,, -92800,0.4298643,2.7528663,,,,,,,,,,,,,,,,, -92900,0.42746776,2.727254,,,,,,,,,,,,,,,,, -93000,0.42177635,2.7903545,,,,,,,,,,,,,,,,, -93100,0.42164108,2.7750957,,,,,,,,,,,,,,,,, -93200,0.42996085,2.8181872,,,,,,,,,,,,,,,,, -93300,0.42985672,2.7555327,,,,,,,,,,,,,,,,, -93400,0.43100795,2.7502534,,,,,,,,,,,,,,,,, -93500,0.40236226,2.7468088,,,,,,,,,,,,,,,,, -93600,0.4493875,2.7929997,,,,,,,,,,,,,,,,, -93700,0.4376638,2.761123,,,,,,,,,,,,,,,,, -93800,0.4279427,2.6785104,,,,,,,,,,,,,,,,, -93900,0.42035776,2.7570186,,,,,,,,,,,,,,,,, -93915,,,0.7043598890304565,1.4692833423614502,36.35120024978537,0.6835501194000244,1.558236002922058,30.0273109673757,3000.0,0.7001568675041199,1.4738965034484863,30.08239175584561,3003.0,32799.61296272278,52936.32979607582,32799.61296272278,20132.38963317871,1.2875926494598389,0.0 -94000,0.43146166,2.7562106,,,,,,,,,,,,,,,,, -94100,0.4162662,2.7471528,,,,,,,,,,,,,,,,, -94200,0.47243986,2.8232656,,,,,,,,,,,,,,,,, -94300,0.4376607,2.7825444,,,,,,,,,,,,,,,,, -94400,0.41807494,2.8275802,,,,,,,,,,,,,,,,, -94500,0.41817167,2.768149,,,,,,,,,,,,,,,,, -94600,0.42694253,2.7511516,,,,,,,,,,,,,,,,, -94700,0.4277082,2.7288094,,,,,,,,,,,,,,,,, -94800,0.4385463,2.6887121,,,,,,,,,,,,,,,,, -94900,0.44105825,2.7350132,,,,,,,,,,,,,,,,, -95000,0.42901722,2.796166,,,,,,,,,,,,,,,,, -95100,0.44165218,2.7530262,,,,,,,,,,,,,,,,, -95200,0.43086538,2.7159843,,,,,,,,,,,,,,,,, -95300,0.43689528,2.745754,,,,,,,,,,,,,,,,, -95400,0.46227702,2.6860917,,,,,,,,,,,,,,,,, -95500,0.4483607,2.75134,,,,,,,,,,,,,,,,, -95600,0.42057854,2.773465,,,,,,,,,,,,,,,,, -95700,0.4469106,2.8373182,,,,,,,,,,,,,,,,, -95800,0.4531352,2.7923918,,,,,,,,,,,,,,,,, -95900,0.465911,2.8199666,,,,,,,,,,,,,,,,, -96000,0.4356208,2.706177,,,,,,,,,,,,,,,,, -96100,0.4667542,2.8086371,,,,,,,,,,,,,,,,, -96200,0.46459672,2.7308643,,,,,,,,,,,,,,,,, -96300,0.46501553,2.77285,,,,,,,,,,,,,,,,, -96324,,,0.6754699945449829,1.6206367015838623,33.64124946940375,0.6841948628425598,1.5504695177078247,30.02768930762145,3000.0,0.7004590034484863,1.4679011106491089,30.379588836077428,3003.0,33639.83898591995,54311.99412155152,33639.83898591995,20667.71193766594,1.3252015113830566,0.0 -96400,0.4437687,2.820407,,,,,,,,,,,,,,,,, -96500,0.44469374,2.7446408,,,,,,,,,,,,,,,,, -96600,0.43394843,2.7269788,,,,,,,,,,,,,,,,, -96700,0.44503468,2.7330055,,,,,,,,,,,,,,,,, -96800,0.47664115,2.7346842,,,,,,,,,,,,,,,,, -96900,0.4192759,2.7386796,,,,,,,,,,,,,,,,, -97000,0.4826864,2.7745037,,,,,,,,,,,,,,,,, -97100,0.4636737,2.662086,,,,,,,,,,,,,,,,, -97200,0.46015048,2.7667322,,,,,,,,,,,,,,,,, -97300,0.4460753,2.772498,,,,,,,,,,,,,,,,, -97400,0.45764962,2.7620876,,,,,,,,,,,,,,,,, -97500,0.44330436,2.7156096,,,,,,,,,,,,,,,,, -97600,0.44185063,2.6975868,,,,,,,,,,,,,,,,, -97700,0.44124207,2.7326233,,,,,,,,,,,,,,,,, -97800,0.47411793,2.6844337,,,,,,,,,,,,,,,,, -97900,0.4752457,2.7356293,,,,,,,,,,,,,,,,, -98000,0.4543032,2.6819801,,,,,,,,,,,,,,,,, -98100,0.45116323,2.6566718,,,,,,,,,,,,,,,,, -98200,0.44679475,2.7033443,,,,,,,,,,,,,,,,, -98300,0.48308265,2.785282,,,,,,,,,,,,,,,,, -98400,0.44150075,2.6636138,,,,,,,,,,,,,,,,, -98500,0.46019238,2.719999,,,,,,,,,,,,,,,,, -98600,0.4654555,2.7254653,,,,,,,,,,,,,,,,, -98700,0.4604716,2.743972,,,,,,,,,,,,,,,,, -98734,,,0.6730888485908508,1.640097737312317,34.10313197148955,0.6855587363243103,1.5443347692489624,30.445706302183336,3000.0,0.700714647769928,1.4622043371200562,30.23568171788698,3003.0,34480.07180213928,55677.57796263695,34480.07180213928,21192.948457717896,1.361955642700195,0.0 -98800,0.49341252,2.7252712,,,,,,,,,,,,,,,,, -98900,0.4695367,2.7284029,,,,,,,,,,,,,,,,, -99000,0.4620215,2.7751877,,,,,,,,,,,,,,,,, -99100,0.46138212,2.68021,,,,,,,,,,,,,,,,, -99200,0.43534997,2.6978495,,,,,,,,,,,,,,,,, -99300,0.46422118,2.7564347,,,,,,,,,,,,,,,,, -99400,0.45699242,2.7160716,,,,,,,,,,,,,,,,, -99500,0.46929651,2.7452822,,,,,,,,,,,,,,,,, -99600,0.44732234,2.664467,,,,,,,,,,,,,,,,, -99700,0.47466162,2.7446547,,,,,,,,,,,,,,,,, -99800,0.47151548,2.7183514,,,,,,,,,,,,,,,,, -99900,0.4418139,2.7207546,,,,,,,,,,,,,,,,, -100000,0.47663972,2.7283485,,,,,,,,,,,,,,,,, -100100,0.48816645,2.7932048,,,,,,,,,,,,,,,,, -100200,0.45917103,2.6772194,,,,,,,,,,,,,,,,, -100300,0.4689488,2.7247987,,,,,,,,,,,,,,,,, -100400,0.4789802,2.7597418,,,,,,,,,,,,,,,,, -100500,0.4924953,2.7245345,,,,,,,,,,,,,,,,, -100600,0.48457164,2.7112458,,,,,,,,,,,,,,,,, -100700,0.48195922,2.6199136,,,,,,,,,,,,,,,,, -100800,0.48401454,2.744716,,,,,,,,,,,,,,,,, -100900,0.47111493,2.7604237,,,,,,,,,,,,,,,,, -101000,0.48442218,2.6613276,,,,,,,,,,,,,,,,, -101100,0.47668892,2.7104173,,,,,,,,,,,,,,,,, -101142,,,0.686933696269989,1.5602630376815796,34.618318349974814,0.6874186396598816,1.5365707874298096,30.18892686900693,3000.0,0.702934205532074,1.454554796218872,30.4668122978629,3003.0,35320.00034117699,57010.319242954254,35320.00034117699,21685.645532608032,1.3999087810516355,0.0 -101200,0.4514771,2.6445234,,,,,,,,,,,,,,,,, -101300,0.47505093,2.6950884,,,,,,,,,,,,,,,,, -101400,0.46720704,2.6883628,,,,,,,,,,,,,,,,, -101500,0.48491883,2.752648,,,,,,,,,,,,,,,,, -101600,0.51167864,2.7438824,,,,,,,,,,,,,,,,, -101700,0.49752682,2.7011251,,,,,,,,,,,,,,,,, -101800,0.4918345,2.686941,,,,,,,,,,,,,,,,, -101900,0.4816648,2.7458837,,,,,,,,,,,,,,,,, -102000,0.4707369,2.6976569,,,,,,,,,,,,,,,,, -102100,0.48088697,2.672102,,,,,,,,,,,,,,,,, -102200,0.5012948,2.7608936,,,,,,,,,,,,,,,,, -102300,0.49997032,2.7461853,,,,,,,,,,,,,,,,, -102400,0.4825284,2.6682804,,,,,,,,,,,,,,,,, -102500,0.4870818,2.6750631,,,,,,,,,,,,,,,,, -102600,0.49075434,2.6761527,,,,,,,,,,,,,,,,, -102700,0.50062615,2.698816,,,,,,,,,,,,,,,,, -102800,0.47860852,2.6909695,,,,,,,,,,,,,,,,, -102900,0.51385444,2.7719898,,,,,,,,,,,,,,,,, -103000,0.49556047,2.7390938,,,,,,,,,,,,,,,,, -103100,0.4884771,2.7451134,,,,,,,,,,,,,,,,, -103200,0.48747572,2.657813,,,,,,,,,,,,,,,,, -103300,0.5253061,2.745106,,,,,,,,,,,,,,,,, -103400,0.52693653,2.7750506,,,,,,,,,,,,,,,,, -103500,0.49177936,2.6265361,,,,,,,,,,,,,,,,, -103551,,,0.6821432709693909,1.5882389545440674,34.26272973822794,0.688956081867218,1.5335873365402222,30.360145618817835,3000.0,0.7046424150466919,1.4481626749038696,30.654866411054087,3003.0,36160.209800720215,58333.0024998188,36160.209800720215,22168.00412297249,1.4385159015655518,0.0 -103600,0.49711326,2.6912532,,,,,,,,,,,,,,,,, -103700,0.49997756,2.6741967,,,,,,,,,,,,,,,,, -103800,0.5284735,2.7515163,,,,,,,,,,,,,,,,, -103900,0.5348183,2.6580071,,,,,,,,,,,,,,,,, -104000,0.48706174,2.7106736,,,,,,,,,,,,,,,,, -104100,0.49799907,2.6968663,,,,,,,,,,,,,,,,, -104200,0.5877933,2.8002722,,,,,,,,,,,,,,,,, -104300,0.53916186,2.6344495,,,,,,,,,,,,,,,,, -104400,0.51534915,2.7152126,,,,,,,,,,,,,,,,, -104500,0.5093423,2.7051263,,,,,,,,,,,,,,,,, -104600,0.49752864,2.6672323,,,,,,,,,,,,,,,,, -104700,0.5068476,2.685809,,,,,,,,,,,,,,,,, -104800,0.5146546,2.6397424,,,,,,,,,,,,,,,,, -104900,0.523846,2.6956599,,,,,,,,,,,,,,,,, -105000,0.5043292,2.6912568,,,,,,,,,,,,,,,,, -105100,0.5069701,2.7478201,,,,,,,,,,,,,,,,, -105200,0.5015968,2.6659071,,,,,,,,,,,,,,,,, -105300,0.52816933,2.7404106,,,,,,,,,,,,,,,,, -105400,0.50141925,2.6597579,,,,,,,,,,,,,,,,, -105500,0.5335423,2.722623,,,,,,,,,,,,,,,,, -105600,0.5127723,2.6969206,,,,,,,,,,,,,,,,, -105700,0.538568,2.701091,,,,,,,,,,,,,,,,, -105800,0.5407788,2.7363832,,,,,,,,,,,,,,,,, -105900,0.52608395,2.6574206,,,,,,,,,,,,,,,,, -105959,,,0.6814278364181519,1.598901867866516,34.65526969945151,0.6893280744552612,1.5266581773757937,30.57867552999859,3000.0,0.7055836319923401,1.4425654411315918,30.78548127622312,3003.0,37000.21959018707,59685.23624706268,37000.21959018707,22680.11134338379,1.4766552448272705,0.0 -106000,0.5641104,2.7137852,,,,,,,,,,,,,,,,, -106100,0.5274703,2.6702552,,,,,,,,,,,,,,,,, -106200,0.55083925,2.6425173,,,,,,,,,,,,,,,,, -106300,0.53131217,2.6359756,,,,,,,,,,,,,,,,, -106400,0.53240496,2.6306915,,,,,,,,,,,,,,,,, -106500,0.5153597,2.6049485,,,,,,,,,,,,,,,,, -106600,0.5458562,2.707265,,,,,,,,,,,,,,,,, -106700,0.52621025,2.6767888,,,,,,,,,,,,,,,,, -106800,0.53824043,2.696195,,,,,,,,,,,,,,,,, -106900,0.566153,2.7015147,,,,,,,,,,,,,,,,, -107000,0.5415681,2.6673648,,,,,,,,,,,,,,,,, -107100,0.5657205,2.6774564,,,,,,,,,,,,,,,,, -107200,0.56597316,2.73998,,,,,,,,,,,,,,,,, -107300,0.54041684,2.6227698,,,,,,,,,,,,,,,,, -107400,0.5647594,2.6856766,,,,,,,,,,,,,,,,, -107500,0.5609065,2.7012424,,,,,,,,,,,,,,,,, -107600,0.5222777,2.6726143,,,,,,,,,,,,,,,,, -107700,0.52805626,2.6340292,,,,,,,,,,,,,,,,, -107800,0.56535774,2.7239323,,,,,,,,,,,,,,,,, -107900,0.56502104,2.7327414,,,,,,,,,,,,,,,,, -108000,0.5481369,2.7064228,,,,,,,,,,,,,,,,, -108100,0.56991154,2.7127929,,,,,,,,,,,,,,,,, -108200,0.55807287,2.6207485,,,,,,,,,,,,,,,,, -108300,0.54915476,2.63152,,,,,,,,,,,,,,,,, -108368,,,0.6921325922012329,1.5298690795898438,35.40294324144317,0.6879394054412842,1.527682185173035,30.47344674934183,3000.0,0.7045029401779175,1.443967342376709,30.56989330909604,3003.0,37840.39820337296,61025.62011170387,37840.39820337296,23180.20060920716,1.51641845703125,0.0 -108400,0.5829122,2.671086,,,,,,,,,,,,,,,,, -108500,0.5664863,2.6924548,,,,,,,,,,,,,,,,, -108600,0.58451396,2.7027755,,,,,,,,,,,,,,,,, -108700,0.58507204,2.7488234,,,,,,,,,,,,,,,,, -108800,0.56266487,2.672177,,,,,,,,,,,,,,,,, -108900,0.5681797,2.7204459,,,,,,,,,,,,,,,,, -109000,0.5687382,2.646578,,,,,,,,,,,,,,,,, -109100,0.60309154,2.619979,,,,,,,,,,,,,,,,, -109200,0.54225564,2.669223,,,,,,,,,,,,,,,,, -109300,0.57236433,2.6613955,,,,,,,,,,,,,,,,, -109400,0.5707766,2.719985,,,,,,,,,,,,,,,,, -109500,0.5573244,2.7172246,,,,,,,,,,,,,,,,, -109600,0.5715935,2.6450574,,,,,,,,,,,,,,,,, -109700,0.53528297,2.5994864,,,,,,,,,,,,,,,,, -109800,0.57966894,2.6965542,,,,,,,,,,,,,,,,, -109900,0.5786694,2.7393792,,,,,,,,,,,,,,,,, -110000,0.5631526,2.7197616,,,,,,,,,,,,,,,,, -110100,0.57163155,2.6454344,,,,,,,,,,,,,,,,, -110200,0.5610514,2.69719,,,,,,,,,,,,,,,,, -110300,0.58226436,2.6496446,,,,,,,,,,,,,,,,, -110400,0.597279,2.6722393,,,,,,,,,,,,,,,,, -110500,0.5819123,2.701745,,,,,,,,,,,,,,,,, -110600,0.60515016,2.682523,,,,,,,,,,,,,,,,, -110700,0.5859498,2.7011313,,,,,,,,,,,,,,,,, -110776,,,0.6881821751594543,1.555271863937378,34.94405351615512,0.6900472044944763,1.5213935375213623,30.51655914551938,3000.0,0.705258309841156,1.4395705461502075,30.330988152869008,3003.0,38680.33823132515,62356.54861474037,38680.33823132515,23671.07068610192,1.5580275058746338,0.0 -110800,0.58925337,2.6541343,,,,,,,,,,,,,,,,, -110900,0.5420224,2.6486766,,,,,,,,,,,,,,,,, -111000,0.5950928,2.680668,,,,,,,,,,,,,,,,, -111100,0.59243846,2.6412137,,,,,,,,,,,,,,,,, -111200,0.5956773,2.6446805,,,,,,,,,,,,,,,,, -111300,0.616678,2.6069992,,,,,,,,,,,,,,,,, -111400,0.5847963,2.695103,,,,,,,,,,,,,,,,, -111500,0.59159195,2.6768782,,,,,,,,,,,,,,,,, -111600,0.58029056,2.6821527,,,,,,,,,,,,,,,,, -111700,0.6035409,2.5734894,,,,,,,,,,,,,,,,, -111800,0.6437331,2.7096367,,,,,,,,,,,,,,,,, -111900,0.59263384,2.7184494,,,,,,,,,,,,,,,,, -112000,0.5852177,2.6737502,,,,,,,,,,,,,,,,, -112100,0.60824335,2.6646547,,,,,,,,,,,,,,,,, -112200,0.604724,2.7198117,,,,,,,,,,,,,,,,, -112300,0.6147579,2.674034,,,,,,,,,,,,,,,,, -112400,0.5928908,2.595067,,,,,,,,,,,,,,,,, -112500,0.62064606,2.6989658,,,,,,,,,,,,,,,,, -112600,0.61182517,2.594871,,,,,,,,,,,,,,,,, -112700,0.60990334,2.6530488,,,,,,,,,,,,,,,,, -112800,0.6056965,2.653127,,,,,,,,,,,,,,,,, -112900,0.60526526,2.6388898,,,,,,,,,,,,,,,,, -113000,0.62833124,2.685431,,,,,,,,,,,,,,,,, -113100,0.5986333,2.6737764,,,,,,,,,,,,,,,,, -113185,,,0.7068272233009338,1.4597846269607544,36.19503096365993,0.6921550631523132,1.5146585702896118,30.68075410546633,3000.0,0.7078961133956909,1.430131435394287,30.702009793502405,3003.0,39520.55152916908,63693.93018531799,39520.55152916908,24168.117631196976,1.6014611721038818,0.0 -113200,0.60613495,2.6408627,,,,,,,,,,,,,,,,, -113300,0.6150192,2.6689835,,,,,,,,,,,,,,,,, -113400,0.60480684,2.7621617,,,,,,,,,,,,,,,,, -113500,0.63896346,2.5937726,,,,,,,,,,,,,,,,, -113600,0.62318945,2.6351762,,,,,,,,,,,,,,,,, -113700,0.67173135,2.7216804,,,,,,,,,,,,,,,,, -113800,0.5967607,2.614603,,,,,,,,,,,,,,,,, -113900,0.6393931,2.7414074,,,,,,,,,,,,,,,,, -114000,0.6178933,2.5773804,,,,,,,,,,,,,,,,, -114100,0.6280483,2.6314802,,,,,,,,,,,,,,,,, -114200,0.626814,2.6706276,,,,,,,,,,,,,,,,, -114300,0.62536937,2.6905305,,,,,,,,,,,,,,,,, -114400,0.6211071,2.6237867,,,,,,,,,,,,,,,,, -114500,0.63943267,2.6878269,,,,,,,,,,,,,,,,, -114600,0.6263075,2.627158,,,,,,,,,,,,,,,,, -114700,0.63530606,2.6474638,,,,,,,,,,,,,,,,, -114800,0.6395153,2.6886728,,,,,,,,,,,,,,,,, -114900,0.6238987,2.6880903,,,,,,,,,,,,,,,,, -115000,0.62183255,2.6402578,,,,,,,,,,,,,,,,, -115100,0.6182534,2.5709188,,,,,,,,,,,,,,,,, -115200,0.62857366,2.5949793,,,,,,,,,,,,,,,,, -115300,0.62947947,2.670445,,,,,,,,,,,,,,,,, -115400,0.6477713,2.5665112,,,,,,,,,,,,,,,,, -115500,0.66840285,2.6962807,,,,,,,,,,,,,,,,, -115594,,,0.6958543062210083,1.5099217891693115,35.537028715325654,0.6910639405250549,1.5157564878463743,30.63486420797816,3000.0,0.7095230221748352,1.4238560199737549,31.016669589638788,3003.0,40360.71843266487,65041.172853946686,40360.71843266487,24675.07525396347,1.6436638832092283,0.0 -115600,0.6520418,2.6210961,,,,,,,,,,,,,,,,, -115700,0.647936,2.6761444,,,,,,,,,,,,,,,,, -115800,0.6460466,2.6842475,,,,,,,,,,,,,,,,, -115900,0.6620305,2.6626942,,,,,,,,,,,,,,,,, -116000,0.6446637,2.6321187,,,,,,,,,,,,,,,,, -116100,0.6322687,2.6369727,,,,,,,,,,,,,,,,, -116200,0.6379219,2.6107364,,,,,,,,,,,,,,,,, -116300,0.65212995,2.6696763,,,,,,,,,,,,,,,,, -116400,0.6620713,2.6649933,,,,,,,,,,,,,,,,, -116500,0.6436798,2.6520405,,,,,,,,,,,,,,,,, -116600,0.6648441,2.6247191,,,,,,,,,,,,,,,,, -116700,0.65901625,2.6204324,,,,,,,,,,,,,,,,, -116800,0.641853,2.6580133,,,,,,,,,,,,,,,,, -116900,0.65007347,2.617439,,,,,,,,,,,,,,,,, -117000,0.6649337,2.6084542,,,,,,,,,,,,,,,,, -117100,0.6772049,2.6369455,,,,,,,,,,,,,,,,, -117200,0.64147943,2.6138356,,,,,,,,,,,,,,,,, -117300,0.65542156,2.5599566,,,,,,,,,,,,,,,,, -117400,0.62200356,2.6016347,,,,,,,,,,,,,,,,, -117500,0.68656665,2.7114725,,,,,,,,,,,,,,,,, -117600,0.6342301,2.5766006,,,,,,,,,,,,,,,,, -117700,0.628831,2.633012,,,,,,,,,,,,,,,,, -117800,0.64749616,2.574122,,,,,,,,,,,,,,,,, -117900,0.65291286,2.6191404,,,,,,,,,,,,,,,,, -118000,0.7030256,2.636038,,,,,,,,,,,,,,,,, -118003,,,0.6981893181800842,1.5048353672027588,35.49911535816573,0.6911135315895081,1.5123188495635986,30.810753206662906,3000.0,0.70961594581604,1.4236197471618652,31.12338429193865,3003.0,41200.7603263855,66375.6348798275,41200.7603263855,25169.378446102142,1.6847825050354004,0.0 -118100,0.68053824,2.6644373,,,,,,,,,,,,,,,,, -118200,0.67551714,2.6113596,,,,,,,,,,,,,,,,, -118300,0.66333,2.6204925,,,,,,,,,,,,,,,,, -118400,0.67768764,2.6287794,,,,,,,,,,,,,,,,, -118500,0.6505826,2.6220353,,,,,,,,,,,,,,,,, -118600,0.6936988,2.6409807,,,,,,,,,,,,,,,,, -118700,0.6941198,2.7023194,,,,,,,,,,,,,,,,, -118800,0.66966647,2.6606042,,,,,,,,,,,,,,,,, -118900,0.68282723,2.616941,,,,,,,,,,,,,,,,, -119000,0.6876504,2.6541262,,,,,,,,,,,,,,,,, -119100,0.6509661,2.594009,,,,,,,,,,,,,,,,, -119200,0.655546,2.6210785,,,,,,,,,,,,,,,,, -119300,0.6867273,2.6093266,,,,,,,,,,,,,,,,, -119400,0.686913,2.659876,,,,,,,,,,,,,,,,, -119500,0.6759195,2.6244318,,,,,,,,,,,,,,,,, -119600,0.6595215,2.5603638,,,,,,,,,,,,,,,,, -119700,0.7033164,2.6946886,,,,,,,,,,,,,,,,, -119800,0.6872937,2.645047,,,,,,,,,,,,,,,,, -119900,0.6582665,2.5965762,,,,,,,,,,,,,,,,, -120000,0.67280865,2.5687244,,,,,,,,,,,,,,,,, -120100,0.6847849,2.5976877,,,,,,,,,,,,,,,,, -120200,0.6660331,2.6247919,,,,,,,,,,,,,,,,, -120300,0.67872226,2.592253,,,,,,,,,,,,,,,,, -120400,0.7078553,2.6265223,,,,,,,,,,,,,,,,, -120412,,,0.7060102224349976,1.4636739492416382,36.32610595477296,0.6930354237556458,1.5079798698425293,30.81666176363183,3000.0,0.7101737260818481,1.4174529314041138,31.0455792076282,3003.0,42040.92917585373,67714.37622475624,42040.92917585373,25667.826288461685,1.732635259628296,0.0 -120500,0.6902722,2.5593736,,,,,,,,,,,,,,,,, -120600,0.6778987,2.6348634,,,,,,,,,,,,,,,,, -120700,0.6770751,2.5998983,,,,,,,,,,,,,,,,, -120800,0.7184648,2.5772812,,,,,,,,,,,,,,,,, -120900,0.7027112,2.5913908,,,,,,,,,,,,,,,,, -121000,0.67369276,2.6004708,,,,,,,,,,,,,,,,, -121100,0.67828715,2.6398923,,,,,,,,,,,,,,,,, -121200,0.69982755,2.5877154,,,,,,,,,,,,,,,,, -121300,0.6770525,2.656185,,,,,,,,,,,,,,,,, -121400,0.6613425,2.5641892,,,,,,,,,,,,,,,,, -121500,0.69521296,2.58878,,,,,,,,,,,,,,,,, -121600,0.6699902,2.5839257,,,,,,,,,,,,,,,,, -121700,0.68457496,2.6160219,,,,,,,,,,,,,,,,, -121800,0.7104297,2.5941355,,,,,,,,,,,,,,,,, -121900,0.71621686,2.6299448,,,,,,,,,,,,,,,,, -122000,0.70410645,2.637621,,,,,,,,,,,,,,,,, -122100,0.697639,2.5783427,,,,,,,,,,,,,,,,, -122200,0.7198913,2.6300876,,,,,,,,,,,,,,,,, -122300,0.7162643,2.6313524,,,,,,,,,,,,,,,,, -122400,0.6739605,2.5258052,,,,,,,,,,,,,,,,, -122500,0.70008916,2.6045394,,,,,,,,,,,,,,,,, -122600,0.6856222,2.5627809,,,,,,,,,,,,,,,,, -122700,0.69768995,2.6782095,,,,,,,,,,,,,,,,, -122800,0.71349156,2.5954928,,,,,,,,,,,,,,,,, -122820,,,0.7075486779212952,1.4473143815994265,36.11652502233829,0.6921798586845398,1.5096970796585083,30.694070873645327,3000.0,0.7105455994606018,1.418480634689331,31.13267295113105,3003.0,42881.0361392498,69050.65186357498,42881.0361392498,26163.87600445748,1.7744617462158203,0.0 -122900,0.7182447,2.5869455,,,,,,,,,,,,,,,,, -123000,0.7400379,2.614063,,,,,,,,,,,,,,,,, -123100,0.6821823,2.5969095,,,,,,,,,,,,,,,,, -123200,0.72479945,2.5824585,,,,,,,,,,,,,,,,, -123300,0.7276922,2.6171498,,,,,,,,,,,,,,,,, -123400,0.6945836,2.6283765,,,,,,,,,,,,,,,,, -123500,0.7093809,2.6344733,,,,,,,,,,,,,,,,, -123600,0.70059484,2.5894425,,,,,,,,,,,,,,,,, -123700,0.69711125,2.6374826,,,,,,,,,,,,,,,,, -123800,0.69080645,2.574238,,,,,,,,,,,,,,,,, -123900,0.6846289,2.6059391,,,,,,,,,,,,,,,,, -124000,0.6902735,2.6396759,,,,,,,,,,,,,,,,, -124100,0.69728786,2.578778,,,,,,,,,,,,,,,,, -124200,0.71933734,2.590745,,,,,,,,,,,,,,,,, -124300,0.7194927,2.6649582,,,,,,,,,,,,,,,,, -124400,0.71483463,2.5860956,,,,,,,,,,,,,,,,, -124500,0.70395166,2.665683,,,,,,,,,,,,,,,,, -124600,0.7074257,2.604321,,,,,,,,,,,,,,,,, -124700,0.7010731,2.6195805,,,,,,,,,,,,,,,,, -124800,0.70661354,2.5836508,,,,,,,,,,,,,,,,, -124900,0.70509547,2.6195765,,,,,,,,,,,,,,,,, -125000,0.6858317,2.584985,,,,,,,,,,,,,,,,, -125100,0.7216268,2.638186,,,,,,,,,,,,,,,,, -125200,0.79856324,2.5353208,,,,,,,,,,,,,,,,, -125227,,,0.7090739607810974,1.4433223009109497,36.57344895473294,0.6934942007064819,1.5054748058319092,30.81755869738153,3000.0,0.7112312316894531,1.41438090801239,31.124645847204345,3003.0,43720.94148516655,70395.96740603447,43720.94148516655,26669.166478395466,1.816622018814087,0.0 -125300,0.7420466,2.6993618,,,,,,,,,,,,,,,,, -125400,0.70737517,2.6457903,,,,,,,,,,,,,,,,, -125500,0.7046714,2.6199615,,,,,,,,,,,,,,,,, -125600,0.7065943,2.5843823,,,,,,,,,,,,,,,,, -125700,0.74719656,2.5951889,,,,,,,,,,,,,,,,, -125800,0.6970531,2.5817697,,,,,,,,,,,,,,,,, -125900,0.71319664,2.6145544,,,,,,,,,,,,,,,,, -126000,0.71660626,2.6273637,,,,,,,,,,,,,,,,, -126100,0.72401255,2.6031191,,,,,,,,,,,,,,,,, -126200,0.68319714,2.5647116,,,,,,,,,,,,,,,,, -126300,0.7013909,2.5720158,,,,,,,,,,,,,,,,, -126400,0.68748444,2.4906816,,,,,,,,,,,,,,,,, -126500,0.7050425,2.5983448,,,,,,,,,,,,,,,,, -126600,0.6835114,2.5942528,,,,,,,,,,,,,,,,, -126700,0.7116886,2.6348615,,,,,,,,,,,,,,,,, -126800,0.70544755,2.6584513,,,,,,,,,,,,,,,,, -126900,0.7164226,2.549409,,,,,,,,,,,,,,,,, -127000,0.7065557,2.5820918,,,,,,,,,,,,,,,,, -127100,0.67726904,2.5473788,,,,,,,,,,,,,,,,, -127200,0.7188135,2.634083,,,,,,,,,,,,,,,,, -127300,0.7062385,2.6420891,,,,,,,,,,,,,,,,, -127400,0.7119586,2.6083567,,,,,,,,,,,,,,,,, -127500,0.7206473,2.5509703,,,,,,,,,,,,,,,,, -127600,0.7061867,2.5821517,,,,,,,,,,,,,,,,, -127635,,,0.7085887789726257,1.438305377960205,36.61241848155233,0.6935437917709351,1.5044677257537842,30.862560161613207,3000.0,0.7111498713493347,1.4130109548568726,31.248859219822087,3003.0,44561.00401854515,71725.01059532166,44561.00401854515,27158.02793073654,1.858067274093628,0.0 -127700,0.73321,2.5995033,,,,,,,,,,,,,,,,, -127800,0.69942474,2.5777314,,,,,,,,,,,,,,,,, -127900,0.7208014,2.5715532,,,,,,,,,,,,,,,,, -128000,0.71941507,2.6558654,,,,,,,,,,,,,,,,, -128100,0.74829346,2.5944629,,,,,,,,,,,,,,,,, -128200,0.72079945,2.580546,,,,,,,,,,,,,,,,, -128300,0.7066244,2.5527222,,,,,,,,,,,,,,,,, -128400,0.7094569,2.5724266,,,,,,,,,,,,,,,,, -128500,0.7280655,2.5283368,,,,,,,,,,,,,,,,, -128600,0.7271764,2.5720954,,,,,,,,,,,,,,,,, -128700,0.6720437,2.5083983,,,,,,,,,,,,,,,,, -128800,0.71168,2.561645,,,,,,,,,,,,,,,,, -128900,0.71783316,2.5813725,,,,,,,,,,,,,,,,, -129000,0.7265012,2.6272302,,,,,,,,,,,,,,,,, -129100,0.71469307,2.640731,,,,,,,,,,,,,,,,, -129200,0.725518,2.596793,,,,,,,,,,,,,,,,, -129300,0.71538424,2.6030846,,,,,,,,,,,,,,,,, -129400,0.7056739,2.597928,,,,,,,,,,,,,,,,, -129500,0.7262487,2.6352096,,,,,,,,,,,,,,,,, -129600,0.7067985,2.5855904,,,,,,,,,,,,,,,,, -129700,0.7099242,2.5590742,,,,,,,,,,,,,,,,, -129800,0.71814865,2.564251,,,,,,,,,,,,,,,,, -129900,0.71996665,2.6037126,,,,,,,,,,,,,,,,, -130000,0.7300397,2.586521,,,,,,,,,,,,,,,,, -130043,,,0.7091085910797119,1.4369813203811646,36.99864153134804,0.6933081746101379,1.504925012588501,30.757680438251448,3000.0,0.7116844058036804,1.4134386777877808,31.10714427248701,3003.0,45400.904284238815,73065.42194890976,45400.904284238815,27658.42151737213,1.900503635406494,0.0 -130100,0.71806014,2.583422,,,,,,,,,,,,,,,,, -130200,0.72105193,2.6213374,,,,,,,,,,,,,,,,, -130300,0.71265453,2.5714128,,,,,,,,,,,,,,,,, -130400,0.7010968,2.498098,,,,,,,,,,,,,,,,, -130500,0.70567584,2.6066463,,,,,,,,,,,,,,,,, -130600,0.7166721,2.5936036,,,,,,,,,,,,,,,,, -130700,0.711208,2.5991745,,,,,,,,,,,,,,,,, -130800,0.7051755,2.5587711,,,,,,,,,,,,,,,,, -130900,0.7401762,2.6785724,,,,,,,,,,,,,,,,, -131000,0.7101158,2.5573823,,,,,,,,,,,,,,,,, -131100,0.72980326,2.574953,,,,,,,,,,,,,,,,, -131200,0.69436,2.588321,,,,,,,,,,,,,,,,, -131300,0.6987671,2.6206684,,,,,,,,,,,,,,,,, -131400,0.71427786,2.5592642,,,,,,,,,,,,,,,,, -131500,0.6951618,2.577053,,,,,,,,,,,,,,,,, -131600,0.70894706,2.6103919,,,,,,,,,,,,,,,,, -131700,0.73293376,2.6083636,,,,,,,,,,,,,,,,, -131800,0.70795935,2.6314948,,,,,,,,,,,,,,,,, -131900,0.7037443,2.5709136,,,,,,,,,,,,,,,,, -132000,0.6890368,2.5598893,,,,,,,,,,,,,,,,, -132100,0.7120936,2.563241,,,,,,,,,,,,,,,,, -132200,0.69995284,2.5806763,,,,,,,,,,,,,,,,, -132300,0.6942988,2.5420482,,,,,,,,,,,,,,,,, -132400,0.72477436,2.5584517,,,,,,,,,,,,,,,,, -132451,,,0.7096970677375793,1.4432387351989746,36.13791974967528,0.693717360496521,1.504041075706482,30.865012416464875,3000.0,0.7116495370864868,1.4128319025039673,31.19698552073992,3003.0,46241.0879445076,74414.16377663612,46241.0879445076,28166.858474493027,1.944511890411377,0.0 -132500,0.69065744,2.63499,,,,,,,,,,,,,,,,, -132600,0.6958037,2.6018918,,,,,,,,,,,,,,,,, -132700,0.71690625,2.5879846,,,,,,,,,,,,,,,,, -132800,0.70987076,2.5573766,,,,,,,,,,,,,,,,, -132900,0.6771097,2.5965955,,,,,,,,,,,,,,,,, -133000,0.6949295,2.581716,,,,,,,,,,,,,,,,, -133100,0.7054751,2.5855265,,,,,,,,,,,,,,,,, -133200,0.69798976,2.509618,,,,,,,,,,,,,,,,, -133300,0.7285554,2.6763813,,,,,,,,,,,,,,,,, -133333,,,0.7093749642372131,1.4428577423095703,36.54108972271205,0.693630576133728,1.5042879581451416,30.86924377390636,3000.0,0.7118006348609924,1.413087010383606,31.18247574009484,3003.0,46548.72388243675,75229.0163321495,46548.72388243675,28674.002204179764,1.9894163608551023,0.0 -133333,,,,,,,,,,,,,,46548.72388243675,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 1f3d40548..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -868.8139111995697,0.0,27.26857042312622,1,0,27.26857042312622,0.0007088489946909,0.0,11.041826248168944,3003,896.0825242996216,0.0006168890395201,0.0,11.06764030456543,0.0004835649742744,0.0,11.036645889282228,3000 -1559.3806607723236,0.0209846496582031,867.3339140415192,2403,0,867.3339140415192,0.3978037536144256,8.912712367020626,4.211139678955078,3003,2426.8114387989044,0.4234203398227691,15.385226949250676,3.922987222671509,0.4070997238159179,10.213285356315824,4.05037260055542,3000 -2021.2432608604431,0.0460124015808105,1707.4406251907349,4805,0,1707.4406251907349,0.5467666387557983,18.429889335573424,2.8108057975769043,3003,3728.882633447647,0.5433375835418701,23.48978144292716,2.8315465450286865,0.5458456873893738,19.935954988012053,2.782279253005981,3000 -2486.478396892548,0.0709567070007324,2547.405416727066,7209,0,2547.405416727066,0.5943059921264648,22.20622453217967,2.385735273361206,3003,5034.181886911392,0.5829526782035828,26.772562741454795,2.4544622898101807,0.5922927260398865,23.585131664602347,2.3894851207733154,3000 -2967.781727552414,0.0971579551696777,3387.3830687999725,9614,0,3387.3830687999725,0.6229039430618286,23.902748515407488,2.132373332977295,3003,6355.566185951233,0.6014232039451599,28.633667442316284,2.2858991622924805,0.6167809367179871,24.9768614171484,2.163081407546997,3000 -3458.648379802704,0.1283042430877685,4227.473156452179,12019,0,4227.473156452179,0.6383940577507019,25.06606033830548,1.9973318576812744,3003,7686.628827095032,0.6114262938499451,29.34431349214636,2.1870195865631104,0.6296511888504028,25.76193795439446,2.044898271560669,3000 -3928.399883508682,0.1582114696502685,5067.463086843491,14423,0,5067.463086843491,0.6490268111228943,26.05126954299677,1.909260869026184,3003,8996.47902727127,0.6259955763816833,30.1025163364132,2.062871217727661,0.6408227682113647,26.82538380800523,1.9591337442398071,3000 -4388.0984263420105,0.1860949993133545,5907.40859413147,16828,0,5907.40859413147,0.6569984555244446,26.831555098397647,1.8533170223236084,3003,10296.226013422012,0.6307092905044556,30.666321567767124,2.047231912612915,0.647034764289856,27.453680639560886,1.90785562992096,3000 -4849.583470821381,0.2175691127777099,6747.634348392487,19234,0,6747.634348392487,0.6652141213417053,27.350837919366988,1.8207148313522337,3003,11598.04559969902,0.6473884582519531,31.799686839622684,1.9363336563110352,0.6544494032859802,27.786832108362752,1.8840500116348269,3000 -5376.148508548737,0.24692964553833,7587.8516755104065,21640,0,7587.8516755104065,0.6675498485565186,27.407634766447824,1.796425461769104,3003,12964.936593532562,0.6389160752296448,31.14960287562537,1.9853614568710327,0.6562844514846802,28.15702780958249,1.854095458984375,3000 -5892.5287890434265,0.274691104888916,8428.0314412117,24047,0,8428.0314412117,0.668502688407898,27.702566928617625,1.7730231285095217,3003,14321.60050368309,0.6381778717041016,31.64516066956896,1.978065729141236,0.6574996113777161,27.9298678608305,1.8356502056121824,3000 -6366.634291410446,0.3060090541839599,9267.948320865631,26453,0,9267.948320865631,0.6716518402099609,28.19627620748597,1.732503056526184,3003,15635.728284358978,0.650377094745636,31.99286060995921,1.885564208030701,0.6624096632003784,28.47187284902378,1.792500615119934,3000 -6856.061012983322,0.3347887992858886,10108.041011333466,28861,0,10108.041011333466,0.6757306456565857,28.2438093623282,1.708932399749756,3003,16965.351059913635,0.6490204334259033,31.745865182709256,1.8970963954925537,0.6626080274581909,28.336580553316416,1.7790693044662476,3000 -7542.333392858505,0.3696925640106201,10948.245263814926,31268,0,10948.245263814926,0.6761257648468018,27.99109156876941,1.7377357482910156,3003,18491.9369943142,0.6462864875793457,32.04996210677026,1.938828706741333,0.6645174622535706,28.524153144156863,1.804949164390564,3000 -8058.018579006195,0.3987507820129394,11788.358795166016,33674,0,11788.358795166016,0.6791703104972839,28.64386057392573,1.6936718225479126,3003,19847.839690446854,0.6534045934677124,31.9049431373965,1.8654236793518064,0.667592465877533,28.721269757279806,1.7680360078811646,3000 -8535.275834321976,0.4287540912628174,12628.493074417114,36080,0,12628.493074417114,0.6801928877830505,28.412542674194228,1.681604504585266,3003,21165.33632516861,0.6488655209541321,32.563431170556136,1.889739751815796,0.6692167520523071,29.021798663119707,1.754157543182373,3000 -9007.191690206528,0.4582366943359375,13468.718967199326,38487,0,13468.718967199326,0.6821103096008301,28.864725008764605,1.679569959640503,3003,22477.581879138947,0.6632543206214905,32.92000144955914,1.8003400564193726,0.66898113489151,28.953533798359764,1.75272798538208,3000 -9488.803978919985,0.4883334636688232,14308.6868288517,40893,0,14308.6868288517,0.6849921941757202,29.026990271202685,1.6443243026733398,3003,23799.267735242844,0.6533471941947937,32.56679466777,1.8556861877441408,0.6716221570968628,29.09871189610873,1.7270926237106323,3000 -10092.543762683868,0.5201573371887207,15148.599982261658,43300,0,15148.599982261658,0.685631275177002,28.75444159865201,1.6498874425888062,3003,25243.02674794197,0.6563020348548889,32.51621448980282,1.8518691062927248,0.6724033355712891,29.32289596719666,1.733350157737732,3000 -10597.39832019806,0.5518801212310791,15988.6215569973,45706,0,15988.6215569973,0.6875951886177063,29.44166874318576,1.6285982131958008,3003,26588.011610507965,0.6638458371162415,33.13194819782453,1.7846899032592771,0.6739159822463989,29.235713821138607,1.7038495540618896,3000 -11084.258370399475,0.5843021869659424,16828.530921697617,48112,0,16828.530921697617,0.6886642575263977,29.52146011305364,1.616563081741333,3003,27914.88912725449,0.6603615880012512,32.8939815446997,1.8153201341629028,0.6750814914703369,29.456297594517302,1.6979299783706665,3000 -11547.249927043917,0.6154048442840576,17668.503177165985,50518,0,17668.503177165985,0.6907094717025757,29.501761063552795,1.609735131263733,3003,29217.95916581154,0.669950544834137,33.24540770074166,1.7502673864364624,0.6771769523620605,29.52148956201149,1.6892192363739014,3000 -12130.155924797058,0.6474461555480957,18508.68393635749,52924,0,18508.68393635749,0.6914182901382446,29.61664679715476,1.6207830905914309,3003,30641.154708862305,0.6591640114784241,32.61795057033321,1.8146765232086184,0.6771769523620605,29.5259645039707,1.7013094425201416,3000 -12755.12444281578,0.679734468460083,19348.650020122528,55330,0,19348.650020122528,0.6917552947998047,29.35782497609175,1.6076879501342771,3003,32106.20058512688,0.6601080894470215,32.35156533613473,1.815306544303894,0.677424967288971,29.76204785307309,1.687734603881836,3000 -13255.915263652802,0.7129116058349609,20188.56041240692,57737,0,20188.56041240692,0.6926733255386353,29.221216444179152,1.5996352434158323,3003,33447.00986433029,0.6684832572937012,33.19071769118694,1.7588273286819458,0.6785284876823425,29.483006700913737,1.682852268218994,3000 -13797.100610494614,0.7455697059631348,21028.63009619713,60144,0,21028.63009619713,0.6945558190345764,29.65021273085204,1.593664526939392,3003,34828.37162208557,0.6650095582008362,33.05104338028556,1.7817625999450684,0.6779953241348267,29.6121101877778,1.6812210083007812,3000 -14454.052380561829,0.7788140773773193,21868.660029172897,62550,0,21868.660029172897,0.6963453888893127,29.76351245408125,1.5847262144088743,3003,36325.46452903748,0.7026515007019043,35.26644664948576,1.5845561027526855,0.6804379224777222,30.021677303349723,1.6678342819213867,3000 -14921.166560411451,0.8124041557312012,22708.81247138977,64956,0,22708.81247138977,0.6964964270591736,30.08841979180686,1.5666615962982178,3003,37632.84275865555,0.6722168922424316,33.19875507863805,1.7237827777862549,0.6810950636863708,29.83164370579329,1.657042384147644,3000 -15426.43873333931,0.8456981182098389,23548.917605161667,67362,0,23548.917605161667,0.6967636942863464,29.93042932332061,1.5596245527267456,3003,38978.3294506073,0.6712957620620728,33.597483300602605,1.734437108039856,0.6810950636863708,29.94199898103257,1.6491377353668213,3000 -15941.064447641373,0.8869020938873291,24389.033936738968,69769,0,24389.033936738968,0.6988670229911804,30.00534103427137,1.5575608015060425,3003,40333.188530921936,0.6853699684143066,34.29055829084901,1.6466609239578247,0.6822109818458557,29.90150956813121,1.6531620025634766,3000 -16386.192722558975,0.9233071804046632,25228.926471233368,72175,0,25228.926471233368,0.6993783116340637,29.9560497314722,1.5530868768692017,3003,41618.32129120827,0.6768722534179688,33.83223992527047,1.701812744140625,0.6833640933036804,29.873267425150825,1.6468572616577148,3000 -16898.12716984749,0.9586780071258544,26069.1646475792,74582,0,26069.1646475792,0.6996455788612366,29.979354940034053,1.5487751960754397,3003,42970.60527634621,0.6750902533531189,33.31560347389563,1.7135508060455322,0.6835005283355713,30.202027501872564,1.641509175300598,3000 -17383.99588394165,1.000993251800537,26909.29024219513,76989,0,26909.29024219513,0.7015281319618225,30.30835640115028,1.541305422782898,3003,44296.71672439575,0.6826446056365967,34.03202156257172,1.6596330404281616,0.6867862939834595,30.34294338222907,1.6362948417663574,3000 -17849.676304340363,1.0389277935028076,27749.502287864685,79395,0,27749.502287864685,0.7026785612106323,30.35795776214528,1.537244200706482,3003,45602.72605991364,0.6748082637786865,34.39932828355682,1.7070127725601196,0.6856703758239746,30.385912673364096,1.6362732648849487,3000 -18349.50480556488,1.0803887844085691,28589.634435892105,81801,0,28589.634435892105,0.7022950649261475,30.29497923799136,1.538099765777588,3003,46942.80326747894,0.7017195820808411,35.8828094371977,1.5646322965621948,0.685732364654541,30.01887627978177,1.630754470825195,3000 -18835.574059724808,1.1179816722869873,29429.53974795341,84207,0,29429.53974795341,0.7048166990280151,30.50286774566312,1.5285508632659912,3003,48268.892318964005,0.683625340461731,34.68102130080084,1.652172327041626,0.6857695579528809,30.57193751932736,1.6258001327514648,3000 -19418.148502588272,1.1565618515014648,30269.54340934753,86613,0,30269.54340934753,0.705142080783844,30.38943963120881,1.5264335870742798,3003,49691.58460474014,0.6862362027168274,34.51882434228991,1.647615671157837,0.6874062418937683,30.232125626879338,1.6228716373443604,3000 -19918.049089193344,1.1962296962738037,31109.477601766583,89019,0,31109.477601766583,0.7048166990280151,30.56950362536727,1.5245695114135742,3003,51031.5335791111,0.6950206756591797,35.72473038291398,1.5976887941360474,0.6871954202651978,30.638430798079032,1.61845600605011,3000 -20395.571761369705,1.2364046573638916,31949.395862579346,91425,0,31949.395862579346,0.7056185007095337,30.65265721547933,1.5198183059692385,3003,52349.0916249752,0.693752646446228,34.55984270647691,1.605541110038757,0.688559353351593,30.754953917388068,1.6157660484313965,3000 -20934.71894097328,1.274902582168579,32789.459950208664,93831,0,32789.459950208664,0.7061066031455994,30.63082315061668,1.518967866897583,3003,53728.417481184006,0.7316763997077942,38.14423812644837,1.4330400228500366,0.6875426173210144,30.55295250101604,1.6164175271987915,3000 -21441.42243003845,1.31276273727417,33629.5340692997,96237,0,33629.5340692997,0.7069665193557739,30.65861045851249,1.5066982507705688,3003,55075.31088614464,0.7016276121139526,35.49628524369535,1.558693528175354,0.6899976134300232,30.619367228391614,1.6058191061019895,3000 -21960.90658211708,1.353524923324585,34469.55902791023,98644,0,34469.55902791023,0.7068967819213867,30.476182677665605,1.5110329389572144,3003,56434.93587017059,0.6927933096885681,35.66692699405113,1.601349115371704,0.6894148588180542,30.435671431315622,1.6130400896072388,3000 -22414.529752969745,1.3917343616485596,35309.74000072479,101051,0,35309.74000072479,0.7070245742797852,30.72464167955525,1.507551908493042,3003,57728.85360980034,0.7072663307189941,36.582741766440975,1.5225698947906494,0.6892164945602417,30.412669470445955,1.6107611656188965,3000 -22943.753257989883,1.431570529937744,36149.81416296959,103458,0,36149.81416296959,0.707466185092926,30.460744589687334,1.5031170845031738,3003,59098.26532077789,0.6998023390769958,35.472857329654225,1.5615475177764893,0.6898860335350037,30.398652314834955,1.608350396156311,3000 -23452.08066368103,1.4725525379180908,36989.73922157288,105864,0,36989.73922157288,0.707524299621582,30.83850855238789,1.5049904584884644,3003,60446.634615659714,0.7032337188720703,36.129084059359286,1.5497827529907229,0.6897744536399841,30.705780856663885,1.6095569133758545,3000 -23941.40462422371,1.512440204620361,37829.69793653488,108269,0,37829.69793653488,0.7091279029846191,30.73570104529433,1.499886393547058,3003,61776.03334188461,0.7132049202919006,36.41010826418029,1.4947891235351562,0.6916218996047974,30.63931633701646,1.6021170616149902,3000 -24402.493275880814,1.5521199703216553,38669.68802213669,110674,0,38669.68802213669,0.7097786664962769,30.64078556582809,1.4970381259918213,3003,63077.23090171814,0.7082228064537048,36.62626379524225,1.514228343963623,0.6904811859130859,30.574834332474047,1.604583740234375,3000 -24868.04579520225,1.593592405319214,39509.64947485924,113080,0,39509.64947485924,0.7103829383850098,30.95519761628077,1.4962009191513062,3003,64382.861696243286,0.7224063277244568,37.28813238914175,1.443524718284607,0.6914855241775513,30.60543726425647,1.6032272577285769,3000 -25343.89398145676,1.635310173034668,40349.80061197281,115487,0,40349.80061197281,0.7099064588546753,30.68944762531302,1.4967427253723145,3003,65698.97678422928,0.713723361492157,36.435484973088045,1.487293004989624,0.6912871599197388,30.65375561094501,1.6010518074035645,3000 -25802.397783994675,1.6849846839904783,41189.93091821671,117894,0,41189.93091821671,0.7090465426445007,30.936780229896527,1.4972790479660034,3003,66997.73618650436,0.7144909501075745,37.108529161669374,1.490625500679016,0.6910887360572815,30.71050729467754,1.6030627489089966,3000 -26271.71107816696,1.7338509559631348,42030.07278776169,120300,0,42030.07278776169,0.7098715901374817,30.98069747621216,1.4950560331344604,3003,68307.31502318382,0.7231228351593018,37.805040018899184,1.442205548286438,0.6926014423370361,30.68321910491765,1.6032183170318604,3000 -26738.382657289505,1.7781829833984375,42870.12835860253,122710,0,42870.12835860253,0.7106037139892578,30.8694776006975,1.4928041696548462,3003,69614.16217589378,0.7189242243766785,37.58180576067513,1.4655243158340454,0.6922046542167664,30.80034796220132,1.6004964113235474,3000 -27202.7842271328,1.8205416202545168,43710.05010247231,125116,0,43710.05010247231,0.7102667093276978,30.91072129641423,1.4937689304351809,3003,70918.60228586197,0.7213461399078369,37.47862548500293,1.4562888145446775,0.6918947100639343,30.825696629345693,1.6007949113845823,3000 -27669.634046316147,1.8627715110778809,44550.06398534775,127522,0,44550.06398534775,0.710719883441925,30.998903514021368,1.493880033493042,3003,72225.58293819427,0.7226781845092773,37.238805367993805,1.445623517036438,0.6923658847808838,30.86297775830037,1.6010756492614746,3000 -28143.56909942627,1.9056484699249268,45389.95973515511,129928,0,45389.95973515511,0.7106850147247314,30.92520062066843,1.4928480386734009,3003,73539.53221654892,0.7224118113517761,37.6780744346327,1.4469929933547974,0.6922914981842041,30.87728482321264,1.6008161306381226,3000 -28613.24405527115,1.9580817222595213,46230.10717344284,132335,0,46230.10717344284,0.7106037139892578,30.889636565732552,1.492241621017456,3003,74849.48199796677,0.7235233783721924,37.51305828525132,1.4436975717544556,0.6922170519828796,30.95629465203926,1.6004061698913574,3000 -29085.583038330078,2.0043785572052,46578.47412419319,133333,0,46578.47412419319,0.7105804681777954,30.862555894325826,1.4921976327896118,3003,75670.26692295074,0.721368134021759,37.40015536663024,1.4508073329925537,0.6922046542167664,30.959032282956052,1.6004267930984497,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/measurements.csv deleted file mode 100644 index 068d4dc69..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.562566,11.047717,,,,,,,,,,,,,,,,, -1,,,0.0006168890395201,11.06764030456543,0.0,0.0004835649742744,11.036645889282228,0.0,3000.0,0.0007088489946909,11.041826248168944,0.0,3003.0,27.26857042312622,896.0825242996216,27.26857042312622,868.8139111995697,0.0,0.0 -100,0.2592753,9.05992,,,,,,,,,,,,,,,,, -200,0.2224346,8.706313,,,,,,,,,,,,,,,,, -300,1.1062071,8.366786,,,,,,,,,,,,,,,,, -400,0.98816025,8.043669,,,,,,,,,,,,,,,,, -500,0.7191279,7.8348227,,,,,,,,,,,,,,,,, -600,0.8250634,7.678504,,,,,,,,,,,,,,,,, -700,0.69967175,7.4668326,,,,,,,,,,,,,,,,, -800,0.7018196,7.2386637,,,,,,,,,,,,,,,,, -900,0.5786748,7.1178274,,,,,,,,,,,,,,,,, -1000,0.47409943,6.988267,,,,,,,,,,,,,,,,, -1100,0.5683264,6.8439584,,,,,,,,,,,,,,,,, -1200,0.5753018,6.7279963,,,,,,,,,,,,,,,,, -1300,0.5221525,6.6380734,,,,,,,,,,,,,,,,, -1400,0.7020857,6.464896,,,,,,,,,,,,,,,,, -1500,0.6246643,6.3957243,,,,,,,,,,,,,,,,, -1600,0.62739414,6.319633,,,,,,,,,,,,,,,,, -1700,0.6578382,6.2348785,,,,,,,,,,,,,,,,, -1800,0.55716443,6.107324,,,,,,,,,,,,,,,,, -1900,0.75224334,6.035811,,,,,,,,,,,,,,,,, -2000,0.6864018,5.9462085,,,,,,,,,,,,,,,,, -2100,0.63935405,5.8231854,,,,,,,,,,,,,,,,, -2200,0.6191799,5.766654,,,,,,,,,,,,,,,,, -2300,0.5642418,5.658942,,,,,,,,,,,,,,,,, -2400,0.7018257,5.5513916,,,,,,,,,,,,,,,,, -2403,,,0.4234203398227691,3.922987222671509,15.385226949250676,0.4070997238159179,4.05037260055542,10.213285356315824,3000.0,0.3978037536144256,4.211139678955078,8.912712367020626,3003.0,867.3339140415192,2426.8114387989044,867.3339140415192,1559.3806607723236,0.0209846496582031,0.0 -2500,0.63084704,5.448476,,,,,,,,,,,,,,,,, -2600,0.6625712,5.4643807,,,,,,,,,,,,,,,,, -2700,0.68245125,5.388047,,,,,,,,,,,,,,,,, -2800,0.47328174,5.289172,,,,,,,,,,,,,,,,, -2900,0.480792,5.28347,,,,,,,,,,,,,,,,, -3000,0.5490552,5.21779,,,,,,,,,,,,,,,,, -3100,0.6030887,5.0897408,,,,,,,,,,,,,,,,, -3200,0.6455108,5.1849813,,,,,,,,,,,,,,,,, -3300,0.5068755,5.086654,,,,,,,,,,,,,,,,, -3400,0.60082006,5.0324984,,,,,,,,,,,,,,,,, -3500,0.4476434,4.960403,,,,,,,,,,,,,,,,, -3600,0.4448728,5.098901,,,,,,,,,,,,,,,,, -3700,0.43890417,4.930079,,,,,,,,,,,,,,,,, -3800,0.49710393,4.931765,,,,,,,,,,,,,,,,, -3900,0.4356366,4.876384,,,,,,,,,,,,,,,,, -4000,0.38196746,4.917715,,,,,,,,,,,,,,,,, -4100,0.3950644,4.738012,,,,,,,,,,,,,,,,, -4200,0.42104027,4.8183823,,,,,,,,,,,,,,,,, -4300,0.3792873,4.865501,,,,,,,,,,,,,,,,, -4400,0.39514893,4.778925,,,,,,,,,,,,,,,,, -4500,0.37239105,4.7273397,,,,,,,,,,,,,,,,, -4600,0.4490163,4.7607164,,,,,,,,,,,,,,,,, -4700,0.3878577,4.6913915,,,,,,,,,,,,,,,,, -4800,0.3826147,4.756342,,,,,,,,,,,,,,,,, -4805,,,0.5433375835418701,2.8315465450286865,23.48978144292716,0.5458456873893738,2.782279253005981,19.935954988012053,3000.0,0.5467666387557983,2.8108057975769043,18.429889335573424,3003.0,1707.4406251907349,3728.882633447647,1707.4406251907349,2021.2432608604431,0.0460124015808105,0.0 -4900,0.39236617,4.6707687,,,,,,,,,,,,,,,,, -5000,0.41610402,4.668042,,,,,,,,,,,,,,,,, -5100,0.3324193,4.6991405,,,,,,,,,,,,,,,,, -5200,0.3320615,4.6577597,,,,,,,,,,,,,,,,, -5300,0.36451295,4.6614223,,,,,,,,,,,,,,,,, -5400,0.31959862,4.6394906,,,,,,,,,,,,,,,,, -5500,0.2957892,4.573168,,,,,,,,,,,,,,,,, -5600,0.3433841,4.5846934,,,,,,,,,,,,,,,,, -5700,0.29646954,4.629944,,,,,,,,,,,,,,,,, -5800,0.31083035,4.5890775,,,,,,,,,,,,,,,,, -5900,0.33029118,4.616156,,,,,,,,,,,,,,,,, -6000,0.34403172,4.5864115,,,,,,,,,,,,,,,,, -6100,0.32795167,4.5334835,,,,,,,,,,,,,,,,, -6200,0.30047494,4.558327,,,,,,,,,,,,,,,,, -6300,0.2945861,4.506973,,,,,,,,,,,,,,,,, -6400,0.25952667,4.555329,,,,,,,,,,,,,,,,, -6500,0.3269615,4.548392,,,,,,,,,,,,,,,,, -6600,0.26850024,4.5326114,,,,,,,,,,,,,,,,, -6700,0.2664488,4.4887047,,,,,,,,,,,,,,,,, -6800,0.25108254,4.478877,,,,,,,,,,,,,,,,, -6900,0.29644904,4.4971585,,,,,,,,,,,,,,,,, -7000,0.29228228,4.474088,,,,,,,,,,,,,,,,, -7100,0.26422405,4.4449215,,,,,,,,,,,,,,,,, -7200,0.23286176,4.48745,,,,,,,,,,,,,,,,, -7209,,,0.5829526782035828,2.4544622898101807,26.772562741454795,0.5922927260398865,2.3894851207733154,23.585131664602347,3000.0,0.5943059921264648,2.385735273361206,22.20622453217967,3003.0,2547.405416727066,5034.181886911392,2547.405416727066,2486.478396892548,0.0709567070007324,0.0 -7300,0.25889128,4.4064555,,,,,,,,,,,,,,,,, -7400,0.23614126,4.4207683,,,,,,,,,,,,,,,,, -7500,0.22192213,4.3573723,,,,,,,,,,,,,,,,, -7600,0.214025,4.476133,,,,,,,,,,,,,,,,, -7700,0.28619957,4.3310475,,,,,,,,,,,,,,,,, -7800,0.22106816,4.305487,,,,,,,,,,,,,,,,, -7900,0.27136734,4.347011,,,,,,,,,,,,,,,,, -8000,0.2268239,4.470519,,,,,,,,,,,,,,,,, -8100,0.22202152,4.321971,,,,,,,,,,,,,,,,, -8200,0.22582006,4.378681,,,,,,,,,,,,,,,,, -8300,0.22612604,4.348706,,,,,,,,,,,,,,,,, -8400,0.20336773,4.3253727,,,,,,,,,,,,,,,,, -8500,0.20060332,4.3982506,,,,,,,,,,,,,,,,, -8600,0.20743184,4.335174,,,,,,,,,,,,,,,,, -8700,0.20439859,4.335959,,,,,,,,,,,,,,,,, -8800,0.22910608,4.3871303,,,,,,,,,,,,,,,,, -8900,0.20762187,4.289026,,,,,,,,,,,,,,,,, -9000,0.18523808,4.288216,,,,,,,,,,,,,,,,, -9100,0.18946137,4.3312287,,,,,,,,,,,,,,,,, -9200,0.24029742,4.3938003,,,,,,,,,,,,,,,,, -9300,0.21580249,4.2523556,,,,,,,,,,,,,,,,, -9400,0.19036518,4.346597,,,,,,,,,,,,,,,,, -9500,0.188667,4.309164,,,,,,,,,,,,,,,,, -9600,0.27430573,4.279812,,,,,,,,,,,,,,,,, -9614,,,0.6014232039451599,2.2858991622924805,28.633667442316284,0.6167809367179871,2.163081407546997,24.9768614171484,3000.0,0.6229039430618286,2.132373332977295,23.902748515407488,3003.0,3387.3830687999725,6355.566185951233,3387.3830687999725,2967.781727552414,0.0971579551696777,0.0 -9700,0.19696526,4.2473116,,,,,,,,,,,,,,,,, -9800,0.18289231,4.294917,,,,,,,,,,,,,,,,, -9900,0.17859109,4.2378745,,,,,,,,,,,,,,,,, -10000,0.17815511,4.1965995,,,,,,,,,,,,,,,,, -10100,0.17465088,4.2374983,,,,,,,,,,,,,,,,, -10200,0.18485461,4.287072,,,,,,,,,,,,,,,,, -10300,0.17636018,4.1726356,,,,,,,,,,,,,,,,, -10400,0.1789003,4.179076,,,,,,,,,,,,,,,,, -10500,0.16398941,4.269984,,,,,,,,,,,,,,,,, -10600,0.17496781,4.216596,,,,,,,,,,,,,,,,, -10700,0.17178819,4.2664,,,,,,,,,,,,,,,,, -10800,0.1761545,4.2459006,,,,,,,,,,,,,,,,, -10900,0.17459477,4.2023487,,,,,,,,,,,,,,,,, -11000,0.18924607,4.2785163,,,,,,,,,,,,,,,,, -11100,0.17080912,4.202644,,,,,,,,,,,,,,,,, -11200,0.18174323,4.2277665,,,,,,,,,,,,,,,,, -11300,0.1675401,4.256275,,,,,,,,,,,,,,,,, -11400,0.20107046,4.154194,,,,,,,,,,,,,,,,, -11500,0.16932897,4.27894,,,,,,,,,,,,,,,,, -11600,0.1709299,4.1626067,,,,,,,,,,,,,,,,, -11700,0.15751322,4.1960845,,,,,,,,,,,,,,,,, -11800,0.17141458,4.183102,,,,,,,,,,,,,,,,, -11900,0.16576111,4.2107534,,,,,,,,,,,,,,,,, -12000,0.16832632,4.228216,,,,,,,,,,,,,,,,, -12019,,,0.6114262938499451,2.1870195865631104,29.34431349214636,0.6296511888504028,2.044898271560669,25.76193795439446,3000.0,0.6383940577507019,1.9973318576812744,25.06606033830548,3003.0,4227.473156452179,7686.628827095032,4227.473156452179,3458.648379802704,0.1283042430877685,0.0 -12100,0.1567538,4.171562,,,,,,,,,,,,,,,,, -12200,0.16184346,4.1698995,,,,,,,,,,,,,,,,, -12300,0.15728635,4.211117,,,,,,,,,,,,,,,,, -12400,0.15295957,4.121087,,,,,,,,,,,,,,,,, -12500,0.18942276,4.2009606,,,,,,,,,,,,,,,,, -12600,0.17108837,4.1480494,,,,,,,,,,,,,,,,, -12700,0.17809984,4.160778,,,,,,,,,,,,,,,,, -12800,0.16765884,4.2503347,,,,,,,,,,,,,,,,, -12900,0.16932969,4.095819,,,,,,,,,,,,,,,,, -13000,0.16125342,4.1739945,,,,,,,,,,,,,,,,, -13100,0.18094295,4.2141576,,,,,,,,,,,,,,,,, -13200,0.1601454,4.1913457,,,,,,,,,,,,,,,,, -13300,0.16749881,4.1846223,,,,,,,,,,,,,,,,, -13400,0.16416487,4.1655602,,,,,,,,,,,,,,,,, -13500,0.15714452,4.198824,,,,,,,,,,,,,,,,, -13600,0.16034745,4.129385,,,,,,,,,,,,,,,,, -13700,0.1746025,4.149331,,,,,,,,,,,,,,,,, -13800,0.161326,4.112795,,,,,,,,,,,,,,,,, -13900,0.15324531,4.1154428,,,,,,,,,,,,,,,,, -14000,0.17096783,4.1549764,,,,,,,,,,,,,,,,, -14100,0.15936273,4.1554403,,,,,,,,,,,,,,,,, -14200,0.16466571,4.0815506,,,,,,,,,,,,,,,,, -14300,0.16535462,4.1043434,,,,,,,,,,,,,,,,, -14400,0.16987124,4.1180806,,,,,,,,,,,,,,,,, -14423,,,0.6259955763816833,2.062871217727661,30.1025163364132,0.6408227682113647,1.9591337442398071,26.82538380800523,3000.0,0.6490268111228943,1.909260869026184,26.05126954299677,3003.0,5067.463086843491,8996.47902727127,5067.463086843491,3928.399883508682,0.1582114696502685,0.0 -14500,0.17484367,4.1030855,,,,,,,,,,,,,,,,, -14600,0.16159016,4.1648474,,,,,,,,,,,,,,,,, -14700,0.15150036,4.102749,,,,,,,,,,,,,,,,, -14800,0.19664557,4.151707,,,,,,,,,,,,,,,,, -14900,0.16168971,4.123936,,,,,,,,,,,,,,,,, -15000,0.16177784,4.131224,,,,,,,,,,,,,,,,, -15100,0.1813087,4.1132755,,,,,,,,,,,,,,,,, -15200,0.16946822,4.123862,,,,,,,,,,,,,,,,, -15300,0.16551676,4.132626,,,,,,,,,,,,,,,,, -15400,0.15798894,4.125591,,,,,,,,,,,,,,,,, -15500,0.14637624,4.1081553,,,,,,,,,,,,,,,,, -15600,0.15531516,4.1788225,,,,,,,,,,,,,,,,, -15700,0.1522552,4.155885,,,,,,,,,,,,,,,,, -15800,0.14850865,4.0650535,,,,,,,,,,,,,,,,, -15900,0.14268573,4.1422343,,,,,,,,,,,,,,,,, -16000,0.1874812,4.0209723,,,,,,,,,,,,,,,,, -16100,0.15383454,4.1274276,,,,,,,,,,,,,,,,, -16200,0.14916404,4.041917,,,,,,,,,,,,,,,,, -16300,0.1576934,4.106678,,,,,,,,,,,,,,,,, -16400,0.14442334,4.047832,,,,,,,,,,,,,,,,, -16500,0.20352444,4.047899,,,,,,,,,,,,,,,,, -16600,0.1627285,4.068321,,,,,,,,,,,,,,,,, -16700,0.1991795,4.0380387,,,,,,,,,,,,,,,,, -16800,0.16795097,4.0857434,,,,,,,,,,,,,,,,, -16828,,,0.6307092905044556,2.047231912612915,30.666321567767124,0.647034764289856,1.90785562992096,27.453680639560886,3000.0,0.6569984555244446,1.8533170223236084,26.831555098397647,3003.0,5907.40859413147,10296.226013422012,5907.40859413147,4388.0984263420105,0.1860949993133545,0.0 -16900,0.15193513,4.1445026,,,,,,,,,,,,,,,,, -17000,0.16095427,4.1649165,,,,,,,,,,,,,,,,, -17100,0.16437656,4.1269727,,,,,,,,,,,,,,,,, -17200,0.17020679,4.0682526,,,,,,,,,,,,,,,,, -17300,0.16088712,4.021797,,,,,,,,,,,,,,,,, -17400,0.17744611,4.102976,,,,,,,,,,,,,,,,, -17500,0.15199573,4.041462,,,,,,,,,,,,,,,,, -17600,0.15604444,4.08086,,,,,,,,,,,,,,,,, -17700,0.17029704,4.065984,,,,,,,,,,,,,,,,, -17800,0.15847345,4.1078944,,,,,,,,,,,,,,,,, -17900,0.1883178,4.011142,,,,,,,,,,,,,,,,, -18000,0.24107291,4.0575657,,,,,,,,,,,,,,,,, -18100,0.20348682,4.0035906,,,,,,,,,,,,,,,,, -18200,0.1622735,4.0663023,,,,,,,,,,,,,,,,, -18300,0.17701702,4.089575,,,,,,,,,,,,,,,,, -18400,0.19891636,4.035203,,,,,,,,,,,,,,,,, -18500,0.15873042,4.072477,,,,,,,,,,,,,,,,, -18600,0.16153188,4.036273,,,,,,,,,,,,,,,,, -18700,0.15856145,4.059093,,,,,,,,,,,,,,,,, -18800,0.16240504,3.984478,,,,,,,,,,,,,,,,, -18900,0.15668006,4.048007,,,,,,,,,,,,,,,,, -19000,0.16077535,3.976594,,,,,,,,,,,,,,,,, -19100,0.18203005,4.06263,,,,,,,,,,,,,,,,, -19200,0.15693052,4.0571246,,,,,,,,,,,,,,,,, -19234,,,0.6473884582519531,1.9363336563110352,31.799686839622684,0.6544494032859802,1.8840500116348269,27.786832108362752,3000.0,0.6652141213417053,1.8207148313522337,27.350837919366988,3003.0,6747.634348392487,11598.04559969902,6747.634348392487,4849.583470821381,0.2175691127777099,0.0 -19300,0.17116328,3.9682417,,,,,,,,,,,,,,,,, -19400,0.15184744,4.0348644,,,,,,,,,,,,,,,,, -19500,0.21488227,4.040685,,,,,,,,,,,,,,,,, -19600,0.16193171,4.0645537,,,,,,,,,,,,,,,,, -19700,0.15807642,4.0467973,,,,,,,,,,,,,,,,, -19800,0.18285878,4.089028,,,,,,,,,,,,,,,,, -19900,0.1620277,4.0172796,,,,,,,,,,,,,,,,, -20000,0.18593825,4.0277967,,,,,,,,,,,,,,,,, -20100,0.15991704,4.0065756,,,,,,,,,,,,,,,,, -20200,0.15785716,4.02771,,,,,,,,,,,,,,,,, -20300,0.17733625,4.01089,,,,,,,,,,,,,,,,, -20400,0.15411729,4.111868,,,,,,,,,,,,,,,,, -20500,0.1636299,4.065465,,,,,,,,,,,,,,,,, -20600,0.20715204,4.06258,,,,,,,,,,,,,,,,, -20700,0.16861561,4.0291533,,,,,,,,,,,,,,,,, -20800,0.18243639,4.009923,,,,,,,,,,,,,,,,, -20900,0.1597214,3.9761887,,,,,,,,,,,,,,,,, -21000,0.1630525,4.0828032,,,,,,,,,,,,,,,,, -21100,0.17036194,4.031411,,,,,,,,,,,,,,,,, -21200,0.20754889,4.0382547,,,,,,,,,,,,,,,,, -21300,0.17631994,3.9938328,,,,,,,,,,,,,,,,, -21400,0.49677536,4.25232,,,,,,,,,,,,,,,,, -21500,0.16419467,4.0082707,,,,,,,,,,,,,,,,, -21600,0.18140203,4.009423,,,,,,,,,,,,,,,,, -21640,,,0.6389160752296448,1.9853614568710327,31.14960287562537,0.6562844514846802,1.854095458984375,28.15702780958249,3000.0,0.6675498485565186,1.796425461769104,27.407634766447824,3003.0,7587.8516755104065,12964.936593532562,7587.8516755104065,5376.148508548737,0.24692964553833,0.0 -21700,0.1589603,3.9889207,,,,,,,,,,,,,,,,, -21800,0.15942092,4.067013,,,,,,,,,,,,,,,,, -21900,0.15856421,3.992617,,,,,,,,,,,,,,,,, -22000,0.16095996,4.0029826,,,,,,,,,,,,,,,,, -22100,0.17268056,4.0363755,,,,,,,,,,,,,,,,, -22200,0.2068796,4.040164,,,,,,,,,,,,,,,,, -22300,0.21420889,3.9943016,,,,,,,,,,,,,,,,, -22400,0.16824566,4.0721216,,,,,,,,,,,,,,,,, -22500,0.16084947,4.028165,,,,,,,,,,,,,,,,, -22600,0.18055409,3.995229,,,,,,,,,,,,,,,,, -22700,0.16484456,4.0509677,,,,,,,,,,,,,,,,, -22800,0.16452663,3.9039445,,,,,,,,,,,,,,,,, -22900,0.17353417,4.0486135,,,,,,,,,,,,,,,,, -23000,0.16805552,4.0309825,,,,,,,,,,,,,,,,, -23100,0.17279838,4.0193405,,,,,,,,,,,,,,,,, -23200,0.18392594,4.01013,,,,,,,,,,,,,,,,, -23300,0.15749048,3.9645529,,,,,,,,,,,,,,,,, -23400,0.17202479,4.012308,,,,,,,,,,,,,,,,, -23500,0.18615125,3.9968114,,,,,,,,,,,,,,,,, -23600,0.24577576,4.0582786,,,,,,,,,,,,,,,,, -23700,0.16165848,3.991314,,,,,,,,,,,,,,,,, -23800,0.17336877,4.0494976,,,,,,,,,,,,,,,,, -23900,0.15837336,3.9922535,,,,,,,,,,,,,,,,, -24000,0.17280789,3.9902456,,,,,,,,,,,,,,,,, -24047,,,0.6381778717041016,1.978065729141236,31.64516066956896,0.6574996113777161,1.8356502056121824,27.9298678608305,3000.0,0.668502688407898,1.7730231285095217,27.702566928617625,3003.0,8428.0314412117,14321.60050368309,8428.0314412117,5892.5287890434265,0.274691104888916,0.0 -24100,0.19020334,4.056404,,,,,,,,,,,,,,,,, -24200,0.18139672,3.9952247,,,,,,,,,,,,,,,,, -24300,0.17399071,4.095223,,,,,,,,,,,,,,,,, -24400,0.17137112,4.0357246,,,,,,,,,,,,,,,,, -24500,0.1891478,3.9982889,,,,,,,,,,,,,,,,, -24600,0.22107035,3.9725316,,,,,,,,,,,,,,,,, -24700,0.18668424,3.9343495,,,,,,,,,,,,,,,,, -24800,0.21357162,4.0138383,,,,,,,,,,,,,,,,, -24900,0.24736208,4.0012903,,,,,,,,,,,,,,,,, -25000,0.20529985,4.010097,,,,,,,,,,,,,,,,, -25100,0.17206283,4.003896,,,,,,,,,,,,,,,,, -25200,0.18236136,3.9816062,,,,,,,,,,,,,,,,, -25300,0.18954997,4.0268373,,,,,,,,,,,,,,,,, -25400,0.17883177,3.9823165,,,,,,,,,,,,,,,,, -25500,0.19909382,3.906015,,,,,,,,,,,,,,,,, -25600,0.17288189,3.9966645,,,,,,,,,,,,,,,,, -25700,0.17792997,3.9737732,,,,,,,,,,,,,,,,, -25800,0.17047176,3.9037597,,,,,,,,,,,,,,,,, -25900,0.17793085,4.0206594,,,,,,,,,,,,,,,,, -26000,0.1799477,3.8765016,,,,,,,,,,,,,,,,, -26100,0.18524301,4.0292015,,,,,,,,,,,,,,,,, -26200,0.18365815,4.028508,,,,,,,,,,,,,,,,, -26300,0.18287376,3.9821894,,,,,,,,,,,,,,,,, -26400,0.2324398,4.077704,,,,,,,,,,,,,,,,, -26453,,,0.650377094745636,1.885564208030701,31.99286060995921,0.6624096632003784,1.792500615119934,28.47187284902378,3000.0,0.6716518402099609,1.732503056526184,28.19627620748597,3003.0,9267.948320865631,15635.728284358978,9267.948320865631,6366.634291410446,0.3060090541839599,0.0 -26500,0.19907787,4.0174108,,,,,,,,,,,,,,,,, -26600,0.20273831,3.9926379,,,,,,,,,,,,,,,,, -26700,0.24913777,3.973742,,,,,,,,,,,,,,,,, -26800,0.21284491,4.019382,,,,,,,,,,,,,,,,, -26900,0.18576081,3.9667652,,,,,,,,,,,,,,,,, -27000,0.18395706,3.9375286,,,,,,,,,,,,,,,,, -27100,0.2027031,3.99399,,,,,,,,,,,,,,,,, -27200,0.19186686,3.9510946,,,,,,,,,,,,,,,,, -27300,0.25244084,5.811736,,,,,,,,,,,,,,,,, -27400,0.21218602,5.595997,,,,,,,,,,,,,,,,, -27500,0.29291454,5.5795426,,,,,,,,,,,,,,,,, -27600,0.42885816,5.5702624,,,,,,,,,,,,,,,,, -27700,0.24738857,5.552463,,,,,,,,,,,,,,,,, -27800,0.29103172,5.507767,,,,,,,,,,,,,,,,, -27900,0.49756974,5.484077,,,,,,,,,,,,,,,,, -28000,0.36160576,4.279431,,,,,,,,,,,,,,,,, -28100,0.23334455,3.9873588,,,,,,,,,,,,,,,,, -28200,0.218407,3.9943056,,,,,,,,,,,,,,,,, -28300,0.27601832,3.954486,,,,,,,,,,,,,,,,, -28400,0.21044146,3.9686048,,,,,,,,,,,,,,,,, -28500,0.2139426,4.0062904,,,,,,,,,,,,,,,,, -28600,0.17206115,3.9915118,,,,,,,,,,,,,,,,, -28700,0.17746896,3.959823,,,,,,,,,,,,,,,,, -28800,0.18183678,3.9972768,,,,,,,,,,,,,,,,, -28861,,,0.6490204334259033,1.8970963954925537,31.745865182709256,0.6626080274581909,1.7790693044662476,28.336580553316416,3000.0,0.6757306456565857,1.708932399749756,28.2438093623282,3003.0,10108.041011333466,16965.351059913635,10108.041011333466,6856.061012983322,0.3347887992858886,0.0 -28900,0.2102271,4.022375,,,,,,,,,,,,,,,,, -29000,0.1948953,3.9479783,,,,,,,,,,,,,,,,, -29100,0.21144314,3.9873905,,,,,,,,,,,,,,,,, -29200,0.18069723,3.9835916,,,,,,,,,,,,,,,,, -29300,0.18367234,4.026177,,,,,,,,,,,,,,,,, -29400,0.1996697,3.9685905,,,,,,,,,,,,,,,,, -29500,0.20752294,3.9859753,,,,,,,,,,,,,,,,, -29600,0.25398755,3.9482045,,,,,,,,,,,,,,,,, -29700,0.18282811,3.9786975,,,,,,,,,,,,,,,,, -29800,0.21593532,3.950339,,,,,,,,,,,,,,,,, -29900,0.18956085,3.9740024,,,,,,,,,,,,,,,,, -30000,0.20273742,3.9707463,,,,,,,,,,,,,,,,, -30100,0.25842237,3.989899,,,,,,,,,,,,,,,,, -30200,0.20177098,3.9625602,,,,,,,,,,,,,,,,, -30300,0.19616702,3.979771,,,,,,,,,,,,,,,,, -30400,0.20864405,3.9204693,,,,,,,,,,,,,,,,, -30500,0.20418267,3.982744,,,,,,,,,,,,,,,,, -30600,0.19061299,3.9017746,,,,,,,,,,,,,,,,, -30700,0.22471914,3.9543383,,,,,,,,,,,,,,,,, -30800,0.23361528,3.9103975,,,,,,,,,,,,,,,,, -30900,0.22200927,3.951345,,,,,,,,,,,,,,,,, -31000,0.29514205,3.968087,,,,,,,,,,,,,,,,, -31100,0.19963121,4.006147,,,,,,,,,,,,,,,,, -31200,0.18830657,3.923484,,,,,,,,,,,,,,,,, -31268,,,0.6462864875793457,1.938828706741333,32.04996210677026,0.6645174622535706,1.804949164390564,28.524153144156863,3000.0,0.6761257648468018,1.7377357482910156,27.99109156876941,3003.0,10948.245263814926,18491.9369943142,10948.245263814926,7542.333392858505,0.3696925640106201,0.0 -31300,0.20821111,3.984094,,,,,,,,,,,,,,,,, -31400,0.18766293,3.9626298,,,,,,,,,,,,,,,,, -31500,0.221703,3.967011,,,,,,,,,,,,,,,,, -31600,0.21675953,3.9302468,,,,,,,,,,,,,,,,, -31700,0.1972951,3.9824524,,,,,,,,,,,,,,,,, -31800,0.3235003,3.8898916,,,,,,,,,,,,,,,,, -31900,0.20974271,3.975831,,,,,,,,,,,,,,,,, -32000,0.2529183,3.953208,,,,,,,,,,,,,,,,, -32100,0.21148679,3.9676511,,,,,,,,,,,,,,,,, -32200,0.22618,3.9616024,,,,,,,,,,,,,,,,, -32300,0.20850952,3.9819062,,,,,,,,,,,,,,,,, -32400,0.25880888,4.0046244,,,,,,,,,,,,,,,,, -32500,0.27720934,3.9365919,,,,,,,,,,,,,,,,, -32600,0.24054582,3.9517193,,,,,,,,,,,,,,,,, -32700,0.18390323,3.9449446,,,,,,,,,,,,,,,,, -32800,0.19479589,3.98714,,,,,,,,,,,,,,,,, -32900,0.19293532,3.9375823,,,,,,,,,,,,,,,,, -33000,0.21184456,3.9245982,,,,,,,,,,,,,,,,, -33100,0.2049112,3.9552271,,,,,,,,,,,,,,,,, -33200,0.2080167,3.976094,,,,,,,,,,,,,,,,, -33300,0.23689607,3.9209516,,,,,,,,,,,,,,,,, -33400,0.22596204,3.968907,,,,,,,,,,,,,,,,, -33500,0.21032104,3.9350228,,,,,,,,,,,,,,,,, -33600,0.2233962,3.9998164,,,,,,,,,,,,,,,,, -33674,,,0.6534045934677124,1.8654236793518064,31.9049431373965,0.667592465877533,1.7680360078811646,28.721269757279806,3000.0,0.6791703104972839,1.6936718225479126,28.64386057392573,3003.0,11788.358795166016,19847.839690446854,11788.358795166016,8058.018579006195,0.3987507820129394,0.0 -33700,0.29195634,3.9979265,,,,,,,,,,,,,,,,, -33800,0.22128959,3.9246814,,,,,,,,,,,,,,,,, -33900,0.7214327,4.848242,,,,,,,,,,,,,,,,, -34000,0.236445,4.0396557,,,,,,,,,,,,,,,,, -34100,0.23174486,3.9941814,,,,,,,,,,,,,,,,, -34200,0.19310367,3.986097,,,,,,,,,,,,,,,,, -34300,0.20529683,3.9311364,,,,,,,,,,,,,,,,, -34400,0.20587303,3.9219892,,,,,,,,,,,,,,,,, -34500,0.17713831,3.8867948,,,,,,,,,,,,,,,,, -34600,0.24216521,3.9265904,,,,,,,,,,,,,,,,, -34700,0.19930038,3.990141,,,,,,,,,,,,,,,,, -34800,0.19117178,3.8923118,,,,,,,,,,,,,,,,, -34900,0.17785846,3.9447722,,,,,,,,,,,,,,,,, -35000,0.2620335,3.8985872,,,,,,,,,,,,,,,,, -35100,0.20802945,3.920258,,,,,,,,,,,,,,,,, -35200,0.23314133,3.8715477,,,,,,,,,,,,,,,,, -35300,0.20596127,3.9316242,,,,,,,,,,,,,,,,, -35400,0.24477857,3.9174771,,,,,,,,,,,,,,,,, -35500,0.25544515,3.943121,,,,,,,,,,,,,,,,, -35600,0.21684025,3.9073002,,,,,,,,,,,,,,,,, -35700,0.22630051,3.9523122,,,,,,,,,,,,,,,,, -35800,0.21134683,3.9403558,,,,,,,,,,,,,,,,, -35900,0.26237482,3.9555392,,,,,,,,,,,,,,,,, -36000,0.19300956,3.90112,,,,,,,,,,,,,,,,, -36080,,,0.6488655209541321,1.889739751815796,32.563431170556136,0.6692167520523071,1.754157543182373,29.021798663119707,3000.0,0.6801928877830505,1.681604504585266,28.412542674194228,3003.0,12628.493074417114,21165.33632516861,12628.493074417114,8535.275834321976,0.4287540912628174,0.0 -36100,0.22193678,3.868378,,,,,,,,,,,,,,,,, -36200,0.2006444,3.9459658,,,,,,,,,,,,,,,,, -36300,0.23036149,3.8687825,,,,,,,,,,,,,,,,, -36400,0.21492618,3.974823,,,,,,,,,,,,,,,,, -36500,0.2610066,3.9309928,,,,,,,,,,,,,,,,, -36600,0.23278217,3.8931727,,,,,,,,,,,,,,,,, -36700,0.23148538,3.8650534,,,,,,,,,,,,,,,,, -36800,0.28765902,3.8901408,,,,,,,,,,,,,,,,, -36900,0.2095393,3.9784112,,,,,,,,,,,,,,,,, -37000,0.21991816,3.9368374,,,,,,,,,,,,,,,,, -37100,0.20215432,3.8889394,,,,,,,,,,,,,,,,, -37200,0.19679414,3.9321427,,,,,,,,,,,,,,,,, -37300,0.22541729,3.9660072,,,,,,,,,,,,,,,,, -37400,0.28931323,4.0188766,,,,,,,,,,,,,,,,, -37500,0.21217823,3.9477117,,,,,,,,,,,,,,,,, -37600,0.20795625,3.992511,,,,,,,,,,,,,,,,, -37700,0.23726107,3.9759703,,,,,,,,,,,,,,,,, -37800,0.2083969,3.8623984,,,,,,,,,,,,,,,,, -37900,0.22784677,3.9430733,,,,,,,,,,,,,,,,, -38000,0.21295887,3.9322743,,,,,,,,,,,,,,,,, -38100,0.25004384,3.9485974,,,,,,,,,,,,,,,,, -38200,0.22097102,3.9053352,,,,,,,,,,,,,,,,, -38300,0.22139432,3.914268,,,,,,,,,,,,,,,,, -38400,0.22431053,3.9415164,,,,,,,,,,,,,,,,, -38487,,,0.6632543206214905,1.8003400564193726,32.92000144955914,0.66898113489151,1.75272798538208,28.953533798359764,3000.0,0.6821103096008301,1.679569959640503,28.864725008764605,3003.0,13468.718967199326,22477.581879138947,13468.718967199326,9007.191690206528,0.4582366943359375,0.0 -38500,0.2252478,3.9405022,,,,,,,,,,,,,,,,, -38600,0.22189415,3.9067156,,,,,,,,,,,,,,,,, -38700,0.26716897,3.913757,,,,,,,,,,,,,,,,, -38800,0.21162574,3.9091756,,,,,,,,,,,,,,,,, -38900,0.21467462,3.9070961,,,,,,,,,,,,,,,,, -39000,0.22308801,3.938059,,,,,,,,,,,,,,,,, -39100,0.2121113,3.931008,,,,,,,,,,,,,,,,, -39200,0.24428046,3.8337023,,,,,,,,,,,,,,,,, -39300,0.29673314,3.9471703,,,,,,,,,,,,,,,,, -39400,0.24824747,3.8756723,,,,,,,,,,,,,,,,, -39500,0.20873377,3.9340494,,,,,,,,,,,,,,,,, -39600,0.25839126,3.85574,,,,,,,,,,,,,,,,, -39700,0.24739009,3.9065092,,,,,,,,,,,,,,,,, -39800,0.25395358,3.9121132,,,,,,,,,,,,,,,,, -39900,0.319253,3.9020545,,,,,,,,,,,,,,,,, -40000,0.21988809,3.9257095,,,,,,,,,,,,,,,,, -40100,0.19713187,3.8955634,,,,,,,,,,,,,,,,, -40200,0.23206261,3.9687529,,,,,,,,,,,,,,,,, -40300,0.20745902,3.9068449,,,,,,,,,,,,,,,,, -40400,0.22109032,3.95291,,,,,,,,,,,,,,,,, -40500,0.24832097,3.9449828,,,,,,,,,,,,,,,,, -40600,0.20977326,3.8283896,,,,,,,,,,,,,,,,, -40700,0.30280274,3.9236033,,,,,,,,,,,,,,,,, -40800,0.22420274,3.9305067,,,,,,,,,,,,,,,,, -40893,,,0.6533471941947937,1.8556861877441408,32.56679466777,0.6716221570968628,1.7270926237106323,29.09871189610873,3000.0,0.6849921941757202,1.6443243026733398,29.026990271202685,3003.0,14308.6868288517,23799.267735242844,14308.6868288517,9488.803978919985,0.4883334636688232,0.0 -40900,0.26384744,3.9825952,,,,,,,,,,,,,,,,, -41000,0.27198726,3.9498074,,,,,,,,,,,,,,,,, -41100,0.22707734,3.943823,,,,,,,,,,,,,,,,, -41200,0.33801702,3.8748271,,,,,,,,,,,,,,,,, -41300,0.22925694,3.9325507,,,,,,,,,,,,,,,,, -41400,0.21711345,3.9204087,,,,,,,,,,,,,,,,, -41500,0.21031499,3.8880122,,,,,,,,,,,,,,,,, -41600,0.25053513,3.9409668,,,,,,,,,,,,,,,,, -41700,0.25413823,3.9138997,,,,,,,,,,,,,,,,, -41800,0.22055805,3.9511294,,,,,,,,,,,,,,,,, -41900,0.2670587,3.8777492,,,,,,,,,,,,,,,,, -42000,0.29677546,3.9561906,,,,,,,,,,,,,,,,, -42100,0.23603147,3.9178479,,,,,,,,,,,,,,,,, -42200,0.2523126,3.8928456,,,,,,,,,,,,,,,,, -42300,0.24093527,3.8692575,,,,,,,,,,,,,,,,, -42400,0.24537832,3.8524938,,,,,,,,,,,,,,,,, -42500,0.29005978,3.8917234,,,,,,,,,,,,,,,,, -42600,0.20792942,3.861183,,,,,,,,,,,,,,,,, -42700,0.23466021,3.9019122,,,,,,,,,,,,,,,,, -42800,0.21915112,3.8449123,,,,,,,,,,,,,,,,, -42900,0.20904066,3.8379943,,,,,,,,,,,,,,,,, -43000,0.24832606,3.9022124,,,,,,,,,,,,,,,,, -43100,0.30194175,3.9193058,,,,,,,,,,,,,,,,, -43200,0.25152233,3.9119866,,,,,,,,,,,,,,,,, -43300,,,0.6563020348548889,1.8518691062927248,32.51621448980282,0.6724033355712891,1.733350157737732,29.32289596719666,3000.0,0.685631275177002,1.6498874425888062,28.75444159865201,3003.0,15148.599982261658,25243.02674794197,15148.599982261658,10092.543762683868,0.5201573371887207,0.0 -43300,0.22220714,3.8517828,,,,,,,,,,,,,,,,, -43400,0.29836485,3.9334629,,,,,,,,,,,,,,,,, -43500,0.22195812,3.882144,,,,,,,,,,,,,,,,, -43600,0.25599658,3.8761003,,,,,,,,,,,,,,,,, -43700,0.23234518,3.9257374,,,,,,,,,,,,,,,,, -43800,0.22093885,3.9248567,,,,,,,,,,,,,,,,, -43900,0.2250653,3.8981936,,,,,,,,,,,,,,,,, -44000,0.28868508,3.901973,,,,,,,,,,,,,,,,, -44100,0.22651967,3.864452,,,,,,,,,,,,,,,,, -44200,0.23829132,3.9029078,,,,,,,,,,,,,,,,, -44300,0.30920693,3.88364,,,,,,,,,,,,,,,,, -44400,0.24304198,3.8607628,,,,,,,,,,,,,,,,, -44500,0.23622324,3.9293659,,,,,,,,,,,,,,,,, -44600,0.21366423,3.8268945,,,,,,,,,,,,,,,,, -44700,0.2623054,3.892,,,,,,,,,,,,,,,,, -44800,0.28383875,3.9070125,,,,,,,,,,,,,,,,, -44900,0.25162008,3.9398937,,,,,,,,,,,,,,,,, -45000,0.261201,3.8817258,,,,,,,,,,,,,,,,, -45100,0.2411459,3.894238,,,,,,,,,,,,,,,,, -45200,0.28639042,3.8823462,,,,,,,,,,,,,,,,, -45300,0.23535721,3.9600513,,,,,,,,,,,,,,,,, -45400,0.22589074,3.9051712,,,,,,,,,,,,,,,,, -45500,0.2305522,3.906752,,,,,,,,,,,,,,,,, -45600,0.2805388,3.894608,,,,,,,,,,,,,,,,, -45700,0.23818901,3.8615706,,,,,,,,,,,,,,,,, -45706,,,0.6638458371162415,1.7846899032592771,33.13194819782453,0.6739159822463989,1.7038495540618896,29.235713821138607,3000.0,0.6875951886177063,1.6285982131958008,29.44166874318576,3003.0,15988.6215569973,26588.011610507965,15988.6215569973,10597.39832019806,0.5518801212310791,0.0 -45800,0.23441054,3.8518188,,,,,,,,,,,,,,,,, -45900,0.23628886,3.8947723,,,,,,,,,,,,,,,,, -46000,0.21664587,3.8824978,,,,,,,,,,,,,,,,, -46100,0.22688042,3.9307957,,,,,,,,,,,,,,,,, -46200,0.25705895,3.9234107,,,,,,,,,,,,,,,,, -46300,0.24781786,3.8804379,,,,,,,,,,,,,,,,, -46400,0.26419938,3.8575459,,,,,,,,,,,,,,,,, -46500,0.22876838,3.8813002,,,,,,,,,,,,,,,,, -46600,0.25085667,3.840015,,,,,,,,,,,,,,,,, -46700,0.22661924,3.972223,,,,,,,,,,,,,,,,, -46800,0.26949567,3.8372166,,,,,,,,,,,,,,,,, -46900,0.26642343,3.8696873,,,,,,,,,,,,,,,,, -47000,0.24128419,3.891435,,,,,,,,,,,,,,,,, -47100,0.24585658,3.8620992,,,,,,,,,,,,,,,,, -47200,0.20849344,3.830705,,,,,,,,,,,,,,,,, -47300,0.22275445,3.842067,,,,,,,,,,,,,,,,, -47400,0.27036455,3.8574235,,,,,,,,,,,,,,,,, -47500,0.23119397,3.87664,,,,,,,,,,,,,,,,, -47600,0.25543058,3.8764105,,,,,,,,,,,,,,,,, -47700,0.29209068,3.9118364,,,,,,,,,,,,,,,,, -47800,0.24716106,3.8341503,,,,,,,,,,,,,,,,, -47900,0.24052377,3.86777,,,,,,,,,,,,,,,,, -48000,0.21834373,3.8296397,,,,,,,,,,,,,,,,, -48100,0.27204245,3.903711,,,,,,,,,,,,,,,,, -48112,,,0.6603615880012512,1.8153201341629028,32.8939815446997,0.6750814914703369,1.6979299783706665,29.456297594517302,3000.0,0.6886642575263977,1.616563081741333,29.52146011305364,3003.0,16828.530921697617,27914.88912725449,16828.530921697617,11084.258370399475,0.5843021869659424,0.0 -48200,0.26059666,3.878296,,,,,,,,,,,,,,,,, -48300,0.24778576,3.9050121,,,,,,,,,,,,,,,,, -48400,0.22343512,3.873094,,,,,,,,,,,,,,,,, -48500,0.21997288,3.8492184,,,,,,,,,,,,,,,,, -48600,0.26741475,3.9202478,,,,,,,,,,,,,,,,, -48700,0.23399842,3.8739188,,,,,,,,,,,,,,,,, -48800,0.25047246,3.8348823,,,,,,,,,,,,,,,,, -48900,0.2446602,3.948564,,,,,,,,,,,,,,,,, -49000,0.2981983,3.88067,,,,,,,,,,,,,,,,, -49100,0.23724735,3.882853,,,,,,,,,,,,,,,,, -49200,0.2388087,3.9037254,,,,,,,,,,,,,,,,, -49300,0.3237528,3.9050558,,,,,,,,,,,,,,,,, -49400,0.27006957,3.8539424,,,,,,,,,,,,,,,,, -49500,0.23423757,3.9146652,,,,,,,,,,,,,,,,, -49600,0.2426809,3.8948712,,,,,,,,,,,,,,,,, -49700,0.24781443,3.8502614,,,,,,,,,,,,,,,,, -49800,0.26560065,3.869961,,,,,,,,,,,,,,,,, -49900,0.25413284,3.924253,,,,,,,,,,,,,,,,, -50000,1.7293766,5.572414,,,,,,,,,,,,,,,,, -50100,0.27271616,3.985427,,,,,,,,,,,,,,,,, -50200,0.26247263,3.976699,,,,,,,,,,,,,,,,, -50300,0.22657867,3.8379745,,,,,,,,,,,,,,,,, -50400,0.22599593,3.8607821,,,,,,,,,,,,,,,,, -50500,0.27965018,3.892822,,,,,,,,,,,,,,,,, -50518,,,0.669950544834137,1.7502673864364624,33.24540770074166,0.6771769523620605,1.6892192363739014,29.52148956201149,3000.0,0.6907094717025757,1.609735131263733,29.501761063552795,3003.0,17668.503177165985,29217.95916581154,17668.503177165985,11547.249927043917,0.6154048442840576,0.0 -50600,0.2441526,3.8687682,,,,,,,,,,,,,,,,, -50700,0.24662957,3.9425907,,,,,,,,,,,,,,,,, -50800,0.25465295,3.8143609,,,,,,,,,,,,,,,,, -50900,0.24586177,3.877175,,,,,,,,,,,,,,,,, -51000,0.24049129,3.8608632,,,,,,,,,,,,,,,,, -51100,0.2383368,3.8394582,,,,,,,,,,,,,,,,, -51200,0.22150527,3.9058475,,,,,,,,,,,,,,,,, -51300,0.24252792,3.8940117,,,,,,,,,,,,,,,,, -51400,0.23346092,3.912082,,,,,,,,,,,,,,,,, -51500,0.22200035,3.785356,,,,,,,,,,,,,,,,, -51600,0.2295942,3.895255,,,,,,,,,,,,,,,,, -51700,0.23350321,3.8533666,,,,,,,,,,,,,,,,, -51800,0.23742495,3.9176483,,,,,,,,,,,,,,,,, -51900,0.27043387,3.8967383,,,,,,,,,,,,,,,,, -52000,0.23545761,3.8821213,,,,,,,,,,,,,,,,, -52100,0.25079697,3.8897512,,,,,,,,,,,,,,,,, -52200,0.2478731,3.8460515,,,,,,,,,,,,,,,,, -52300,0.24690333,3.9152813,,,,,,,,,,,,,,,,, -52400,0.23176269,3.8805552,,,,,,,,,,,,,,,,, -52500,0.3150258,3.850323,,,,,,,,,,,,,,,,, -52600,0.2495207,3.8469028,,,,,,,,,,,,,,,,, -52700,0.2724817,3.890788,,,,,,,,,,,,,,,,, -52800,0.24730046,3.8783276,,,,,,,,,,,,,,,,, -52900,0.23585092,3.853567,,,,,,,,,,,,,,,,, -52924,,,0.6591640114784241,1.8146765232086184,32.61795057033321,0.6771769523620605,1.7013094425201416,29.5259645039707,3000.0,0.6914182901382446,1.6207830905914309,29.61664679715476,3003.0,18508.68393635749,30641.154708862305,18508.68393635749,12130.155924797058,0.6474461555480957,0.0 -53000,0.27626997,3.8964183,,,,,,,,,,,,,,,,, -53100,0.25663525,3.8730042,,,,,,,,,,,,,,,,, -53200,0.3108674,3.85566,,,,,,,,,,,,,,,,, -53300,0.29373583,3.8283672,,,,,,,,,,,,,,,,, -53400,0.22518723,3.8413725,,,,,,,,,,,,,,,,, -53500,0.24252115,3.8722723,,,,,,,,,,,,,,,,, -53600,0.23711796,3.869609,,,,,,,,,,,,,,,,, -53700,0.25299165,3.8003812,,,,,,,,,,,,,,,,, -53800,0.28299028,3.857374,,,,,,,,,,,,,,,,, -53900,0.25034457,3.8655982,,,,,,,,,,,,,,,,, -54000,0.2492761,3.8729503,,,,,,,,,,,,,,,,, -54100,0.23771901,3.813648,,,,,,,,,,,,,,,,, -54200,0.27948773,3.9287007,,,,,,,,,,,,,,,,, -54300,0.2624968,3.8075244,,,,,,,,,,,,,,,,, -54400,0.24056366,3.824562,,,,,,,,,,,,,,,,, -54500,0.26411378,3.828882,,,,,,,,,,,,,,,,, -54600,0.24050695,3.8836052,,,,,,,,,,,,,,,,, -54700,0.24726278,3.825455,,,,,,,,,,,,,,,,, -54800,0.26005486,3.8418248,,,,,,,,,,,,,,,,, -54900,0.23766167,3.8438346,,,,,,,,,,,,,,,,, -55000,0.29051778,3.8488858,,,,,,,,,,,,,,,,, -55100,0.24730393,3.8342023,,,,,,,,,,,,,,,,, -55200,0.2508357,3.8104563,,,,,,,,,,,,,,,,, -55300,0.22682025,3.9010575,,,,,,,,,,,,,,,,, -55330,,,0.6601080894470215,1.815306544303894,32.35156533613473,0.677424967288971,1.687734603881836,29.76204785307309,3000.0,0.6917552947998047,1.6076879501342771,29.35782497609175,3003.0,19348.650020122528,32106.20058512688,19348.650020122528,12755.12444281578,0.679734468460083,0.0 -55400,0.26133877,3.8726246,,,,,,,,,,,,,,,,, -55500,0.24660535,3.8704638,,,,,,,,,,,,,,,,, -55600,0.25215995,3.8898952,,,,,,,,,,,,,,,,, -55700,0.25284234,3.8867137,,,,,,,,,,,,,,,,, -55800,0.24655123,3.874868,,,,,,,,,,,,,,,,, -55900,0.2468998,3.8077252,,,,,,,,,,,,,,,,, -56000,0.2894071,3.8162177,,,,,,,,,,,,,,,,, -56100,0.24748474,3.7975662,,,,,,,,,,,,,,,,, -56200,0.2603348,3.8927531,,,,,,,,,,,,,,,,, -56300,0.24936236,3.8629105,,,,,,,,,,,,,,,,, -56400,0.3449494,3.9001253,,,,,,,,,,,,,,,,, -56500,0.2608814,3.8313396,,,,,,,,,,,,,,,,, -56600,0.26607043,3.9245546,,,,,,,,,,,,,,,,, -56700,0.2512182,3.8026495,,,,,,,,,,,,,,,,, -56800,0.25239086,3.9090862,,,,,,,,,,,,,,,,, -56900,0.24133661,3.8361182,,,,,,,,,,,,,,,,, -57000,0.27206233,3.8675866,,,,,,,,,,,,,,,,, -57100,0.24629837,3.8400605,,,,,,,,,,,,,,,,, -57200,0.25698465,3.8835213,,,,,,,,,,,,,,,,, -57300,0.2573152,3.923613,,,,,,,,,,,,,,,,, -57400,0.2433494,3.8556683,,,,,,,,,,,,,,,,, -57500,0.23191816,3.8183272,,,,,,,,,,,,,,,,, -57600,0.25947398,3.80066,,,,,,,,,,,,,,,,, -57700,0.26762503,3.8550951,,,,,,,,,,,,,,,,, -57737,,,0.6684832572937012,1.7588273286819458,33.19071769118694,0.6785284876823425,1.682852268218994,29.483006700913737,3000.0,0.6926733255386353,1.5996352434158323,29.221216444179152,3003.0,20188.56041240692,33447.00986433029,20188.56041240692,13255.915263652802,0.7129116058349609,0.0 -57800,0.2901458,3.8344202,,,,,,,,,,,,,,,,, -57900,0.25374988,3.8213592,,,,,,,,,,,,,,,,, -58000,0.24498959,3.8097932,,,,,,,,,,,,,,,,, -58100,0.25875434,3.7911274,,,,,,,,,,,,,,,,, -58200,0.25494215,3.8025463,,,,,,,,,,,,,,,,, -58300,0.28260806,3.8356886,,,,,,,,,,,,,,,,, -58400,0.24946532,3.809577,,,,,,,,,,,,,,,,, -58500,0.2431428,3.8179224,,,,,,,,,,,,,,,,, -58600,0.23520617,3.8461313,,,,,,,,,,,,,,,,, -58700,0.2652438,3.8632476,,,,,,,,,,,,,,,,, -58800,0.25123933,3.8673677,,,,,,,,,,,,,,,,, -58900,0.2623599,3.8725834,,,,,,,,,,,,,,,,, -59000,0.2544265,3.838854,,,,,,,,,,,,,,,,, -59100,0.2812793,3.8610244,,,,,,,,,,,,,,,,, -59200,0.29753268,3.841945,,,,,,,,,,,,,,,,, -59300,0.24468729,3.8547692,,,,,,,,,,,,,,,,, -59400,0.26275146,3.8629923,,,,,,,,,,,,,,,,, -59500,0.26563755,3.8357482,,,,,,,,,,,,,,,,, -59600,0.29713032,3.8290799,,,,,,,,,,,,,,,,, -59700,0.3381491,3.802795,,,,,,,,,,,,,,,,, -59800,0.2373975,3.833119,,,,,,,,,,,,,,,,, -59900,0.232831,3.755429,,,,,,,,,,,,,,,,, -60000,0.23836035,3.8136547,,,,,,,,,,,,,,,,, -60100,0.24336347,3.85692,,,,,,,,,,,,,,,,, -60144,,,0.6650095582008362,1.7817625999450684,33.05104338028556,0.6779953241348267,1.6812210083007812,29.6121101877778,3000.0,0.6945558190345764,1.593664526939392,29.65021273085204,3003.0,21028.63009619713,34828.37162208557,21028.63009619713,13797.100610494614,0.7455697059631348,0.0 -60200,0.24202473,3.863583,,,,,,,,,,,,,,,,, -60300,0.24186659,3.8536239,,,,,,,,,,,,,,,,, -60400,0.36979026,3.8133657,,,,,,,,,,,,,,,,, -60500,0.24324174,3.7998753,,,,,,,,,,,,,,,,, -60600,0.2589553,3.8457637,,,,,,,,,,,,,,,,, -60700,0.25585485,3.8008144,,,,,,,,,,,,,,,,, -60800,0.30018365,3.883471,,,,,,,,,,,,,,,,, -60900,0.24045788,3.9138172,,,,,,,,,,,,,,,,, -61000,0.27020505,3.8583717,,,,,,,,,,,,,,,,, -61100,0.29608998,3.8304925,,,,,,,,,,,,,,,,, -61200,0.25662035,3.8273225,,,,,,,,,,,,,,,,, -61300,0.26962724,3.8239665,,,,,,,,,,,,,,,,, -61400,0.25238985,3.830818,,,,,,,,,,,,,,,,, -61500,0.25284752,3.820085,,,,,,,,,,,,,,,,, -61600,0.28942692,3.787474,,,,,,,,,,,,,,,,, -61700,0.27203694,3.8186646,,,,,,,,,,,,,,,,, -61800,0.29075152,3.8548057,,,,,,,,,,,,,,,,, -61900,0.23770943,3.7939856,,,,,,,,,,,,,,,,, -62000,0.28635064,3.9114728,,,,,,,,,,,,,,,,, -62100,0.2823219,3.772758,,,,,,,,,,,,,,,,, -62200,0.27434012,3.8292663,,,,,,,,,,,,,,,,, -62300,0.2845828,3.830787,,,,,,,,,,,,,,,,, -62400,0.25835535,3.8182411,,,,,,,,,,,,,,,,, -62500,0.29028255,3.870364,,,,,,,,,,,,,,,,, -62550,,,0.7026515007019043,1.5845561027526855,35.26644664948576,0.6804379224777222,1.6678342819213867,30.021677303349723,3000.0,0.6963453888893127,1.5847262144088743,29.76351245408125,3003.0,21868.660029172897,36325.46452903748,21868.660029172897,14454.052380561829,0.7788140773773193,0.0 -62600,0.27047333,3.8316813,,,,,,,,,,,,,,,,, -62700,0.26773417,3.8583581,,,,,,,,,,,,,,,,, -62800,0.2733978,3.8028953,,,,,,,,,,,,,,,,, -62900,0.28779635,3.822361,,,,,,,,,,,,,,,,, -63000,0.2625954,3.8522463,,,,,,,,,,,,,,,,, -63100,0.2969318,3.7974932,,,,,,,,,,,,,,,,, -63200,0.24027811,3.7594151,,,,,,,,,,,,,,,,, -63300,0.24456261,3.8175423,,,,,,,,,,,,,,,,, -63400,0.2535184,3.8549933,,,,,,,,,,,,,,,,, -63500,0.28936347,3.865727,,,,,,,,,,,,,,,,, -63600,0.27525675,3.8659685,,,,,,,,,,,,,,,,, -63700,0.27150905,3.8545034,,,,,,,,,,,,,,,,, -63800,0.27062577,3.857386,,,,,,,,,,,,,,,,, -63900,0.2506967,3.832746,,,,,,,,,,,,,,,,, -64000,0.2499257,3.8375328,,,,,,,,,,,,,,,,, -64100,0.26673916,3.7498937,,,,,,,,,,,,,,,,, -64200,0.25520816,3.8934174,,,,,,,,,,,,,,,,, -64300,0.25786737,3.8377113,,,,,,,,,,,,,,,,, -64400,0.25342235,3.9233193,,,,,,,,,,,,,,,,, -64500,0.26308745,3.8648717,,,,,,,,,,,,,,,,, -64600,0.27062237,3.8414586,,,,,,,,,,,,,,,,, -64700,0.29522142,3.7658052,,,,,,,,,,,,,,,,, -64800,0.2596061,3.8480482,,,,,,,,,,,,,,,,, -64900,0.25600684,3.8478436,,,,,,,,,,,,,,,,, -64956,,,0.6722168922424316,1.7237827777862549,33.19875507863805,0.6810950636863708,1.657042384147644,29.83164370579329,3000.0,0.6964964270591736,1.5666615962982178,30.08841979180686,3003.0,22708.81247138977,37632.84275865555,22708.81247138977,14921.166560411451,0.8124041557312012,0.0 -65000,0.2654662,3.8352635,,,,,,,,,,,,,,,,, -65100,0.2667582,3.8068237,,,,,,,,,,,,,,,,, -65200,0.27191645,3.8334975,,,,,,,,,,,,,,,,, -65300,0.26246765,3.73491,,,,,,,,,,,,,,,,, -65400,0.27101532,3.889497,,,,,,,,,,,,,,,,, -65500,0.2668757,3.84983,,,,,,,,,,,,,,,,, -65600,0.28233382,3.855197,,,,,,,,,,,,,,,,, -65700,0.27891284,3.770193,,,,,,,,,,,,,,,,, -65800,0.2508243,3.7960548,,,,,,,,,,,,,,,,, -65900,0.25566614,3.8280168,,,,,,,,,,,,,,,,, -66000,0.26101905,3.7776108,,,,,,,,,,,,,,,,, -66100,0.24893223,3.8003619,,,,,,,,,,,,,,,,, -66200,0.2577433,3.8412876,,,,,,,,,,,,,,,,, -66300,0.27736658,3.7890828,,,,,,,,,,,,,,,,, -66400,0.2821093,3.7797709,,,,,,,,,,,,,,,,, -66500,0.29558495,3.8315625,,,,,,,,,,,,,,,,, -66600,0.25882077,3.8593519,,,,,,,,,,,,,,,,, -66700,0.27360752,3.8150892,,,,,,,,,,,,,,,,, -66800,0.25282565,3.863888,,,,,,,,,,,,,,,,, -66900,0.26534307,3.7586205,,,,,,,,,,,,,,,,, -67000,0.2902869,3.8499753,,,,,,,,,,,,,,,,, -67100,0.32868204,3.8472793,,,,,,,,,,,,,,,,, -67200,0.26057816,3.7887983,,,,,,,,,,,,,,,,, -67300,0.2742534,3.7856286,,,,,,,,,,,,,,,,, -67362,,,0.6712957620620728,1.734437108039856,33.597483300602605,0.6810950636863708,1.6491377353668213,29.94199898103257,3000.0,0.6967636942863464,1.5596245527267456,29.93042932332061,3003.0,23548.917605161667,38978.3294506073,23548.917605161667,15426.43873333931,0.8456981182098389,0.0 -67400,0.28496814,3.8126774,,,,,,,,,,,,,,,,, -67500,0.25236472,3.785182,,,,,,,,,,,,,,,,, -67600,0.27808616,3.7948797,,,,,,,,,,,,,,,,, -67700,0.23865376,3.780715,,,,,,,,,,,,,,,,, -67800,0.28463063,3.8214445,,,,,,,,,,,,,,,,, -67900,0.3262982,3.8681095,,,,,,,,,,,,,,,,, -68000,0.25226423,3.7324018,,,,,,,,,,,,,,,,, -68100,0.2872369,3.8400295,,,,,,,,,,,,,,,,, -68200,0.28451192,3.9078236,,,,,,,,,,,,,,,,, -68300,0.28205764,3.799593,,,,,,,,,,,,,,,,, -68400,0.2646201,3.7686381,,,,,,,,,,,,,,,,, -68500,0.26093882,3.8016489,,,,,,,,,,,,,,,,, -68600,0.26145497,3.8176916,,,,,,,,,,,,,,,,, -68700,0.2740826,3.8084288,,,,,,,,,,,,,,,,, -68800,0.2665332,3.886006,,,,,,,,,,,,,,,,, -68900,0.24427631,3.7927485,,,,,,,,,,,,,,,,, -69000,0.3102391,3.7654479,,,,,,,,,,,,,,,,, -69100,0.27988657,3.8510623,,,,,,,,,,,,,,,,, -69200,0.29904732,3.802756,,,,,,,,,,,,,,,,, -69300,0.24194238,3.7838297,,,,,,,,,,,,,,,,, -69400,0.30924094,3.813964,,,,,,,,,,,,,,,,, -69500,0.2618283,3.8295705,,,,,,,,,,,,,,,,, -69600,0.2383677,3.765986,,,,,,,,,,,,,,,,, -69700,0.29078624,3.8763492,,,,,,,,,,,,,,,,, -69769,,,0.6853699684143066,1.6466609239578247,34.29055829084901,0.6822109818458557,1.6531620025634766,29.90150956813121,3000.0,0.6988670229911804,1.5575608015060425,30.00534103427137,3003.0,24389.033936738968,40333.188530921936,24389.033936738968,15941.064447641373,0.8869020938873291,0.0 -69800,0.27249116,3.7671402,,,,,,,,,,,,,,,,, -69900,0.29946083,3.820633,,,,,,,,,,,,,,,,, -70000,0.2564586,3.7455585,,,,,,,,,,,,,,,,, -70100,0.26287925,3.7748716,,,,,,,,,,,,,,,,, -70200,0.27334926,3.782229,,,,,,,,,,,,,,,,, -70300,0.29121444,3.7716231,,,,,,,,,,,,,,,,, -70400,0.3114573,3.7562268,,,,,,,,,,,,,,,,, -70500,0.2721695,3.8527489,,,,,,,,,,,,,,,,, -70600,0.2781206,3.7630134,,,,,,,,,,,,,,,,, -70700,0.26089245,3.7742846,,,,,,,,,,,,,,,,, -70800,0.3377229,3.801844,,,,,,,,,,,,,,,,, -70900,0.25426733,3.7885728,,,,,,,,,,,,,,,,, -71000,0.2767844,3.7768736,,,,,,,,,,,,,,,,, -71100,0.27345434,3.7917054,,,,,,,,,,,,,,,,, -71200,0.27047798,3.8054423,,,,,,,,,,,,,,,,, -71300,0.26514667,3.7620392,,,,,,,,,,,,,,,,, -71400,0.3245142,3.8211513,,,,,,,,,,,,,,,,, -71500,0.28130043,3.823434,,,,,,,,,,,,,,,,, -71600,0.2651632,3.7914042,,,,,,,,,,,,,,,,, -71700,0.29009464,3.7652013,,,,,,,,,,,,,,,,, -71800,0.2948072,3.7979581,,,,,,,,,,,,,,,,, -71900,0.29837516,3.7050483,,,,,,,,,,,,,,,,, -72000,0.2857001,3.8348637,,,,,,,,,,,,,,,,, -72100,0.26516786,3.7842717,,,,,,,,,,,,,,,,, -72175,,,0.6768722534179688,1.701812744140625,33.83223992527047,0.6833640933036804,1.6468572616577148,29.873267425150825,3000.0,0.6993783116340637,1.5530868768692017,29.9560497314722,3003.0,25228.926471233368,41618.32129120827,25228.926471233368,16386.192722558975,0.9233071804046632,0.0 -72200,0.27156708,3.819386,,,,,,,,,,,,,,,,, -72300,0.2916024,3.754329,,,,,,,,,,,,,,,,, -72400,0.291434,3.777815,,,,,,,,,,,,,,,,, -72500,0.2775171,3.7764196,,,,,,,,,,,,,,,,, -72600,0.3192963,3.7360032,,,,,,,,,,,,,,,,, -72700,0.29806665,3.7885528,,,,,,,,,,,,,,,,, -72800,0.26071167,3.8084528,,,,,,,,,,,,,,,,, -72900,0.28764826,3.8052976,,,,,,,,,,,,,,,,, -73000,0.33271605,3.8138695,,,,,,,,,,,,,,,,, -73100,0.29827625,3.7658393,,,,,,,,,,,,,,,,, -73200,0.25744268,3.7816162,,,,,,,,,,,,,,,,, -73300,0.28429037,3.7777867,,,,,,,,,,,,,,,,, -73400,0.27860808,3.8238964,,,,,,,,,,,,,,,,, -73500,0.26470384,3.743707,,,,,,,,,,,,,,,,, -73600,0.26075444,3.749357,,,,,,,,,,,,,,,,, -73700,0.256266,3.7948687,,,,,,,,,,,,,,,,, -73800,0.26449883,3.8112428,,,,,,,,,,,,,,,,, -73900,0.28384027,3.7963445,,,,,,,,,,,,,,,,, -74000,0.30687332,3.8163204,,,,,,,,,,,,,,,,, -74100,0.26482907,3.778538,,,,,,,,,,,,,,,,, -74200,0.27108347,3.7864766,,,,,,,,,,,,,,,,, -74300,0.26622933,3.7948022,,,,,,,,,,,,,,,,, -74400,0.2911504,3.7791896,,,,,,,,,,,,,,,,, -74500,0.25932676,3.7846413,,,,,,,,,,,,,,,,, -74582,,,0.6750902533531189,1.7135508060455322,33.31560347389563,0.6835005283355713,1.641509175300598,30.202027501872564,3000.0,0.6996455788612366,1.5487751960754397,29.979354940034053,3003.0,26069.1646475792,42970.60527634621,26069.1646475792,16898.12716984749,0.9586780071258544,0.0 -74600,0.27392033,3.7263887,,,,,,,,,,,,,,,,, -74700,0.2644316,3.7916925,,,,,,,,,,,,,,,,, -74800,0.27823865,3.8218343,,,,,,,,,,,,,,,,, -74900,0.25446478,3.73638,,,,,,,,,,,,,,,,, -75000,0.28579476,3.8208606,,,,,,,,,,,,,,,,, -75100,0.26980487,3.765606,,,,,,,,,,,,,,,,, -75200,0.27082068,3.7583935,,,,,,,,,,,,,,,,, -75300,0.2797348,3.7789698,,,,,,,,,,,,,,,,, -75400,0.27618152,3.7547657,,,,,,,,,,,,,,,,, -75500,0.26161072,3.736078,,,,,,,,,,,,,,,,, -75600,0.28001451,3.7910042,,,,,,,,,,,,,,,,, -75700,0.29859275,3.738684,,,,,,,,,,,,,,,,, -75800,0.29486617,3.780768,,,,,,,,,,,,,,,,, -75900,0.3019905,3.8469036,,,,,,,,,,,,,,,,, -76000,0.28905517,3.824768,,,,,,,,,,,,,,,,, -76100,0.28591347,3.7154706,,,,,,,,,,,,,,,,, -76200,0.28487206,3.7752,,,,,,,,,,,,,,,,, -76300,0.28159243,3.8100028,,,,,,,,,,,,,,,,, -76400,0.27720147,3.8480017,,,,,,,,,,,,,,,,, -76500,0.28338504,3.75378,,,,,,,,,,,,,,,,, -76600,0.2804757,3.77035,,,,,,,,,,,,,,,,, -76700,0.27681562,3.7514443,,,,,,,,,,,,,,,,, -76800,0.2807219,3.7904112,,,,,,,,,,,,,,,,, -76900,0.2866698,3.8324702,,,,,,,,,,,,,,,,, -76989,,,0.6826446056365967,1.6596330404281616,34.03202156257172,0.6867862939834595,1.6362948417663574,30.34294338222907,3000.0,0.7015281319618225,1.541305422782898,30.30835640115028,3003.0,26909.29024219513,44296.71672439575,26909.29024219513,17383.99588394165,1.000993251800537,0.0 -77000,0.2857142,3.7510643,,,,,,,,,,,,,,,,, -77100,0.2717529,3.7833695,,,,,,,,,,,,,,,,, -77200,0.2738809,3.7811215,,,,,,,,,,,,,,,,, -77300,0.2896416,3.775028,,,,,,,,,,,,,,,,, -77400,0.28284645,3.7954278,,,,,,,,,,,,,,,,, -77500,0.26223728,3.755425,,,,,,,,,,,,,,,,, -77600,0.2887808,3.7162015,,,,,,,,,,,,,,,,, -77700,0.2853727,3.7510266,,,,,,,,,,,,,,,,, -77800,0.30178374,3.7805512,,,,,,,,,,,,,,,,, -77900,0.30093193,3.8736508,,,,,,,,,,,,,,,,, -78000,0.27246934,3.8123949,,,,,,,,,,,,,,,,, -78100,0.2716886,3.757193,,,,,,,,,,,,,,,,, -78200,0.3457239,3.7864463,,,,,,,,,,,,,,,,, -78300,0.2959015,3.7994502,,,,,,,,,,,,,,,,, -78400,0.27921543,3.7846594,,,,,,,,,,,,,,,,, -78500,0.28597587,3.813646,,,,,,,,,,,,,,,,, -78600,0.29269305,3.8186004,,,,,,,,,,,,,,,,, -78700,0.2746508,3.7794776,,,,,,,,,,,,,,,,, -78800,0.28362533,3.7688532,,,,,,,,,,,,,,,,, -78900,0.28134406,3.772106,,,,,,,,,,,,,,,,, -79000,0.29554674,3.7631533,,,,,,,,,,,,,,,,, -79100,0.27136028,3.770165,,,,,,,,,,,,,,,,, -79200,0.2968059,3.7404754,,,,,,,,,,,,,,,,, -79300,0.29377678,3.7437105,,,,,,,,,,,,,,,,, -79395,,,0.6748082637786865,1.7070127725601196,34.39932828355682,0.6856703758239746,1.6362732648849487,30.385912673364096,3000.0,0.7026785612106323,1.537244200706482,30.35795776214528,3003.0,27749.502287864685,45602.72605991364,27749.502287864685,17849.676304340363,1.0389277935028076,0.0 -79400,0.2991266,3.742049,,,,,,,,,,,,,,,,, -79500,0.28746882,3.7557714,,,,,,,,,,,,,,,,, -79600,0.30885667,3.7561612,,,,,,,,,,,,,,,,, -79700,0.29791042,3.8173616,,,,,,,,,,,,,,,,, -79800,0.27756754,3.8167562,,,,,,,,,,,,,,,,, -79900,0.28583354,3.754958,,,,,,,,,,,,,,,,, -80000,0.2914994,3.728364,,,,,,,,,,,,,,,,, -80100,0.2939666,3.837362,,,,,,,,,,,,,,,,, -80200,0.2748704,3.751708,,,,,,,,,,,,,,,,, -80300,0.29043657,3.8087847,,,,,,,,,,,,,,,,, -80400,0.28876558,3.7653573,,,,,,,,,,,,,,,,, -80500,0.29451954,3.7777698,,,,,,,,,,,,,,,,, -80600,0.29259205,3.761775,,,,,,,,,,,,,,,,, -80700,0.28429297,3.825726,,,,,,,,,,,,,,,,, -80800,0.29990143,3.7952583,,,,,,,,,,,,,,,,, -80900,0.27932945,3.762193,,,,,,,,,,,,,,,,, -81000,0.2750919,3.719279,,,,,,,,,,,,,,,,, -81100,0.27493957,3.7500181,,,,,,,,,,,,,,,,, -81200,0.30381182,3.773748,,,,,,,,,,,,,,,,, -81300,0.27656373,3.7426627,,,,,,,,,,,,,,,,, -81400,0.26941222,3.7468195,,,,,,,,,,,,,,,,, -81500,0.31907395,3.699884,,,,,,,,,,,,,,,,, -81600,0.2883647,3.75037,,,,,,,,,,,,,,,,, -81700,0.2784229,3.7755117,,,,,,,,,,,,,,,,, -81800,0.30825618,3.7775233,,,,,,,,,,,,,,,,, -81801,,,0.7017195820808411,1.5646322965621948,35.8828094371977,0.685732364654541,1.630754470825195,30.01887627978177,3000.0,0.7022950649261475,1.538099765777588,30.29497923799136,3003.0,28589.634435892105,46942.80326747894,28589.634435892105,18349.50480556488,1.0803887844085691,0.0 -81900,0.27636397,3.7598734,,,,,,,,,,,,,,,,, -82000,0.27124685,3.7404268,,,,,,,,,,,,,,,,, -82100,0.3005027,3.8033712,,,,,,,,,,,,,,,,, -82200,0.27696627,3.7181754,,,,,,,,,,,,,,,,, -82300,0.27770427,3.8232145,,,,,,,,,,,,,,,,, -82400,0.28366756,3.7770236,,,,,,,,,,,,,,,,, -82500,0.2819195,3.7979438,,,,,,,,,,,,,,,,, -82600,0.27300605,3.7182987,,,,,,,,,,,,,,,,, -82700,0.3016309,3.7701373,,,,,,,,,,,,,,,,, -82800,0.2705824,3.7547448,,,,,,,,,,,,,,,,, -82900,0.27814603,3.7751667,,,,,,,,,,,,,,,,, -83000,0.27410343,3.7528694,,,,,,,,,,,,,,,,, -83100,0.31159824,3.7032545,,,,,,,,,,,,,,,,, -83200,0.2715318,3.7388217,,,,,,,,,,,,,,,,, -83300,0.29844382,3.7748704,,,,,,,,,,,,,,,,, -83400,0.28277898,3.7060812,,,,,,,,,,,,,,,,, -83500,0.31907818,3.7506752,,,,,,,,,,,,,,,,, -83600,0.29202643,3.7434134,,,,,,,,,,,,,,,,, -83700,0.29874712,3.7504056,,,,,,,,,,,,,,,,, -83800,0.2862908,3.791963,,,,,,,,,,,,,,,,, -83900,0.28124237,3.7215033,,,,,,,,,,,,,,,,, -84000,0.28114325,3.803718,,,,,,,,,,,,,,,,, -84100,0.2901515,3.7508128,,,,,,,,,,,,,,,,, -84200,0.31118932,3.7056868,,,,,,,,,,,,,,,,, -84207,,,0.683625340461731,1.652172327041626,34.68102130080084,0.6857695579528809,1.6258001327514648,30.57193751932736,3000.0,0.7048166990280151,1.5285508632659912,30.50286774566312,3003.0,29429.53974795341,48268.892318964005,29429.53974795341,18835.574059724808,1.1179816722869873,0.0 -84300,0.31084108,3.765871,,,,,,,,,,,,,,,,, -84400,0.32219252,3.7695029,,,,,,,,,,,,,,,,, -84500,0.30106002,3.7659695,,,,,,,,,,,,,,,,, -84600,0.30881172,3.685477,,,,,,,,,,,,,,,,, -84700,0.31562534,3.7304811,,,,,,,,,,,,,,,,, -84800,0.29903656,3.7929351,,,,,,,,,,,,,,,,, -84900,0.286407,3.7446456,,,,,,,,,,,,,,,,, -85000,0.2824842,3.7189577,,,,,,,,,,,,,,,,, -85100,0.2904239,3.683318,,,,,,,,,,,,,,,,, -85200,0.31137532,3.7648313,,,,,,,,,,,,,,,,, -85300,0.298882,3.7854264,,,,,,,,,,,,,,,,, -85400,0.32301253,3.7285135,,,,,,,,,,,,,,,,, -85500,0.32197547,3.7271159,,,,,,,,,,,,,,,,, -85600,0.2886192,3.7366908,,,,,,,,,,,,,,,,, -85700,0.3122045,3.749446,,,,,,,,,,,,,,,,, -85800,0.30580282,3.723424,,,,,,,,,,,,,,,,, -85900,0.29040295,3.7887135,,,,,,,,,,,,,,,,, -86000,0.30024868,3.8284547,,,,,,,,,,,,,,,,, -86100,0.30620456,3.7964065,,,,,,,,,,,,,,,,, -86200,0.31761396,3.7545452,,,,,,,,,,,,,,,,, -86300,0.29046,3.7448494,,,,,,,,,,,,,,,,, -86400,0.2848237,3.695221,,,,,,,,,,,,,,,,, -86500,0.2984234,3.7463603,,,,,,,,,,,,,,,,, -86600,0.30650073,3.7546124,,,,,,,,,,,,,,,,, -86613,,,0.6862362027168274,1.647615671157837,34.51882434228991,0.6874062418937683,1.6228716373443604,30.232125626879338,3000.0,0.705142080783844,1.5264335870742798,30.38943963120881,3003.0,30269.54340934753,49691.58460474014,30269.54340934753,19418.148502588272,1.1565618515014648,0.0 -86700,0.35061252,3.7353117,,,,,,,,,,,,,,,,, -86800,0.2929998,3.7057164,,,,,,,,,,,,,,,,, -86900,0.33585113,3.7467048,,,,,,,,,,,,,,,,, -87000,0.29291454,3.7038746,,,,,,,,,,,,,,,,, -87100,0.3020398,3.7683568,,,,,,,,,,,,,,,,, -87200,0.29461643,3.725874,,,,,,,,,,,,,,,,, -87300,0.294259,3.7233975,,,,,,,,,,,,,,,,, -87400,0.30048373,3.7643542,,,,,,,,,,,,,,,,, -87500,0.28987464,3.7335434,,,,,,,,,,,,,,,,, -87600,0.29398346,3.7596855,,,,,,,,,,,,,,,,, -87700,0.2994892,3.760112,,,,,,,,,,,,,,,,, -87800,0.27860436,3.7432656,,,,,,,,,,,,,,,,, -87900,0.3254514,3.796094,,,,,,,,,,,,,,,,, -88000,0.31127188,3.7147825,,,,,,,,,,,,,,,,, -88100,0.30989367,3.7457125,,,,,,,,,,,,,,,,, -88200,0.32406938,3.744602,,,,,,,,,,,,,,,,, -88300,0.30744776,3.6871703,,,,,,,,,,,,,,,,, -88400,0.31210756,3.7912004,,,,,,,,,,,,,,,,, -88500,0.3264458,3.7159646,,,,,,,,,,,,,,,,, -88600,0.31848586,3.745501,,,,,,,,,,,,,,,,, -88700,0.29022476,3.7594366,,,,,,,,,,,,,,,,, -88800,0.3274535,3.780849,,,,,,,,,,,,,,,,, -88900,0.2820052,3.6470826,,,,,,,,,,,,,,,,, -89000,0.3004169,3.7074249,,,,,,,,,,,,,,,,, -89019,,,0.6950206756591797,1.5976887941360474,35.72473038291398,0.6871954202651978,1.61845600605011,30.638430798079032,3000.0,0.7048166990280151,1.5245695114135742,30.56950362536727,3003.0,31109.477601766583,51031.5335791111,31109.477601766583,19918.049089193344,1.1962296962738037,0.0 -89100,0.31614447,3.7623405,,,,,,,,,,,,,,,,, -89200,0.3122179,3.774277,,,,,,,,,,,,,,,,, -89300,0.2876651,3.7279646,,,,,,,,,,,,,,,,, -89400,0.30616575,3.6570535,,,,,,,,,,,,,,,,, -89500,0.30636394,3.7094998,,,,,,,,,,,,,,,,, -89600,0.30327722,3.7287624,,,,,,,,,,,,,,,,, -89700,0.30243662,3.693781,,,,,,,,,,,,,,,,, -89800,0.3016275,3.7242925,,,,,,,,,,,,,,,,, -89900,0.30099884,3.6954803,,,,,,,,,,,,,,,,, -90000,0.29321042,3.7456048,,,,,,,,,,,,,,,,, -90100,0.30977824,3.76222,,,,,,,,,,,,,,,,, -90200,0.30330783,3.7400825,,,,,,,,,,,,,,,,, -90300,0.29862016,3.7221067,,,,,,,,,,,,,,,,, -90400,0.32780656,3.7618308,,,,,,,,,,,,,,,,, -90500,0.3179383,3.728668,,,,,,,,,,,,,,,,, -90600,0.29368737,3.747593,,,,,,,,,,,,,,,,, -90700,0.28876254,3.6697662,,,,,,,,,,,,,,,,, -90800,0.3468697,3.722152,,,,,,,,,,,,,,,,, -90900,0.3305293,3.7258313,,,,,,,,,,,,,,,,, -91000,0.3127759,3.764712,,,,,,,,,,,,,,,,, -91100,0.3153522,3.7098575,,,,,,,,,,,,,,,,, -91200,0.31379706,3.7192898,,,,,,,,,,,,,,,,, -91300,0.30828673,3.7303324,,,,,,,,,,,,,,,,, -91400,0.30996293,3.6814382,,,,,,,,,,,,,,,,, -91425,,,0.693752646446228,1.605541110038757,34.55984270647691,0.688559353351593,1.6157660484313965,30.754953917388068,3000.0,0.7056185007095337,1.5198183059692385,30.65265721547933,3003.0,31949.395862579346,52349.0916249752,31949.395862579346,20395.571761369705,1.2364046573638916,0.0 -91500,0.31770608,3.7184906,,,,,,,,,,,,,,,,, -91600,0.28383082,3.668785,,,,,,,,,,,,,,,,, -91700,0.29831883,3.7379313,,,,,,,,,,,,,,,,, -91800,0.34036815,3.687771,,,,,,,,,,,,,,,,, -91900,0.29148844,3.695014,,,,,,,,,,,,,,,,, -92000,0.33094296,3.7607288,,,,,,,,,,,,,,,,, -92100,0.31153622,3.7536325,,,,,,,,,,,,,,,,, -92200,0.3131057,3.7052858,,,,,,,,,,,,,,,,, -92300,0.3189312,3.689019,,,,,,,,,,,,,,,,, -92400,0.3007541,3.6839225,,,,,,,,,,,,,,,,, -92500,0.3318184,3.7652662,,,,,,,,,,,,,,,,, -92600,0.33761007,3.7813463,,,,,,,,,,,,,,,,, -92700,0.318453,3.7425785,,,,,,,,,,,,,,,,, -92800,0.3158081,3.7075472,,,,,,,,,,,,,,,,, -92900,0.2918357,3.6843271,,,,,,,,,,,,,,,,, -93000,0.30275497,3.74097,,,,,,,,,,,,,,,,, -93100,0.31547824,3.720307,,,,,,,,,,,,,,,,, -93200,0.3140665,3.7604957,,,,,,,,,,,,,,,,, -93300,0.3177248,3.707401,,,,,,,,,,,,,,,,, -93400,0.31038818,3.7024157,,,,,,,,,,,,,,,,, -93500,0.31225473,3.7001438,,,,,,,,,,,,,,,,, -93600,0.31559956,3.7377458,,,,,,,,,,,,,,,,, -93700,0.3014023,3.7118773,,,,,,,,,,,,,,,,, -93800,0.3058963,3.6445107,,,,,,,,,,,,,,,,, -93831,,,0.7316763997077942,1.4330400228500366,38.14423812644837,0.6875426173210144,1.6164175271987915,30.55295250101604,3000.0,0.7061066031455994,1.518967866897583,30.63082315061668,3003.0,32789.459950208664,53728.417481184006,32789.459950208664,20934.71894097328,1.274902582168579,0.0 -93900,0.3138994,3.7180293,,,,,,,,,,,,,,,,, -94000,0.3004865,3.7076645,,,,,,,,,,,,,,,,, -94100,0.29599294,3.7056038,,,,,,,,,,,,,,,,, -94200,0.3289784,3.7612152,,,,,,,,,,,,,,,,, -94300,0.3050933,3.7323735,,,,,,,,,,,,,,,,, -94400,0.33973593,3.7656705,,,,,,,,,,,,,,,,, -94500,0.3186388,3.719492,,,,,,,,,,,,,,,,, -94600,0.32924578,3.706181,,,,,,,,,,,,,,,,, -94700,0.3325712,3.6893604,,,,,,,,,,,,,,,,, -94800,0.31776956,3.6575742,,,,,,,,,,,,,,,,, -94900,0.30263716,3.693276,,,,,,,,,,,,,,,,, -95000,0.33696392,3.7454393,,,,,,,,,,,,,,,,, -95100,0.33536795,3.7066052,,,,,,,,,,,,,,,,, -95200,0.32445887,3.6734805,,,,,,,,,,,,,,,,, -95300,0.32580099,3.6969151,,,,,,,,,,,,,,,,, -95400,0.29690307,3.6476755,,,,,,,,,,,,,,,,, -95500,0.3428785,3.706882,,,,,,,,,,,,,,,,, -95600,0.32037067,3.729008,,,,,,,,,,,,,,,,, -95700,0.34544608,3.772867,,,,,,,,,,,,,,,,, -95800,0.32821226,3.7414305,,,,,,,,,,,,,,,,, -95900,0.3305292,3.7599225,,,,,,,,,,,,,,,,, -96000,0.31895486,3.6655838,,,,,,,,,,,,,,,,, -96100,0.33872363,3.7516222,,,,,,,,,,,,,,,,, -96200,0.32959494,3.6838849,,,,,,,,,,,,,,,,, -96237,,,0.7016276121139526,1.558693528175354,35.49628524369535,0.6899976134300232,1.6058191061019895,30.619367228391614,3000.0,0.7069665193557739,1.5066982507705688,30.65861045851249,3003.0,33629.5340692997,55075.31088614464,33629.5340692997,21441.42243003845,1.31276273727417,0.0 -96300,0.34877366,3.7225049,,,,,,,,,,,,,,,,, -96400,0.33004147,3.7618215,,,,,,,,,,,,,,,,, -96500,0.32628927,3.7048109,,,,,,,,,,,,,,,,, -96600,0.31837818,3.6836824,,,,,,,,,,,,,,,,, -96700,0.32059407,3.689784,,,,,,,,,,,,,,,,, -96800,0.3244647,3.6920056,,,,,,,,,,,,,,,,, -96900,0.30779403,3.697148,,,,,,,,,,,,,,,,, -97000,0.32783395,3.7221262,,,,,,,,,,,,,,,,, -97100,0.34894353,3.6342194,,,,,,,,,,,,,,,,, -97200,0.32253903,3.7199767,,,,,,,,,,,,,,,,, -97300,0.31703883,3.7234256,,,,,,,,,,,,,,,,, -97400,0.32853743,3.7160387,,,,,,,,,,,,,,,,, -97500,0.3237136,3.6765606,,,,,,,,,,,,,,,,, -97600,0.3045884,3.6597888,,,,,,,,,,,,,,,,, -97700,0.3087513,3.6932719,,,,,,,,,,,,,,,,, -97800,0.3205641,3.6443663,,,,,,,,,,,,,,,,, -97900,0.325911,3.6823988,,,,,,,,,,,,,,,,, -98000,0.33992764,3.6506407,,,,,,,,,,,,,,,,, -98100,0.3151832,3.6224706,,,,,,,,,,,,,,,,, -98200,0.31547028,3.6733456,,,,,,,,,,,,,,,,, -98300,0.3587903,3.7376482,,,,,,,,,,,,,,,,, -98400,0.30933416,3.6359267,,,,,,,,,,,,,,,,, -98500,0.33041856,3.68325,,,,,,,,,,,,,,,,, -98600,0.3240413,3.686512,,,,,,,,,,,,,,,,, -98644,,,0.6927933096885681,1.601349115371704,35.66692699405113,0.6894148588180542,1.6130400896072388,30.435671431315622,3000.0,0.7068967819213867,1.5110329389572144,30.476182677665605,3003.0,34469.55902791023,56434.93587017059,34469.55902791023,21960.90658211708,1.353524923324585,0.0 -98700,0.33231393,3.6971798,,,,,,,,,,,,,,,,, -98800,0.33270675,3.6836364,,,,,,,,,,,,,,,,, -98900,0.31394538,3.6820095,,,,,,,,,,,,,,,,, -99000,0.33303064,3.7247524,,,,,,,,,,,,,,,,, -99100,0.31748483,3.64401,,,,,,,,,,,,,,,,, -99200,0.30321452,3.668704,,,,,,,,,,,,,,,,, -99300,0.32332265,3.7113786,,,,,,,,,,,,,,,,, -99400,0.32375014,3.6836183,,,,,,,,,,,,,,,,, -99500,0.33393943,3.7057385,,,,,,,,,,,,,,,,, -99600,0.31837714,3.632714,,,,,,,,,,,,,,,,, -99700,0.33074766,3.697203,,,,,,,,,,,,,,,,, -99800,0.32178342,3.6808496,,,,,,,,,,,,,,,,, -99900,0.3550095,3.6803973,,,,,,,,,,,,,,,,, -100000,0.33527616,3.688065,,,,,,,,,,,,,,,,, -100100,0.34002283,3.745837,,,,,,,,,,,,,,,,, -100200,0.3328669,3.6437364,,,,,,,,,,,,,,,,, -100300,0.3334463,3.6857088,,,,,,,,,,,,,,,,, -100400,0.35210705,3.7129698,,,,,,,,,,,,,,,,, -100500,0.3382416,3.6934597,,,,,,,,,,,,,,,,, -100600,0.33955178,3.6749427,,,,,,,,,,,,,,,,, -100700,0.3365352,3.5973911,,,,,,,,,,,,,,,,, -100800,0.34233037,3.7042894,,,,,,,,,,,,,,,,, -100900,0.32661423,3.7231026,,,,,,,,,,,,,,,,, -101000,0.33493024,3.6284416,,,,,,,,,,,,,,,,, -101051,,,0.7072663307189941,1.5225698947906494,36.582741766440975,0.6892164945602417,1.6107611656188965,30.412669470445955,3000.0,0.7070245742797852,1.507551908493042,30.72464167955525,3003.0,35309.74000072479,57728.85360980034,35309.74000072479,22414.529752969745,1.3917343616485596,0.0 -101100,0.34039316,3.6712887,,,,,,,,,,,,,,,,, -101200,0.32754147,3.6171014,,,,,,,,,,,,,,,,, -101300,0.31438783,3.6644096,,,,,,,,,,,,,,,,, -101400,0.33140442,3.65925,,,,,,,,,,,,,,,,, -101500,0.3425125,3.7061942,,,,,,,,,,,,,,,,, -101600,0.3748868,3.6970081,,,,,,,,,,,,,,,,, -101700,0.3280698,3.6622498,,,,,,,,,,,,,,,,, -101800,0.3214159,3.6518178,,,,,,,,,,,,,,,,, -101900,0.32024002,3.7064419,,,,,,,,,,,,,,,,, -102000,0.31494933,3.665982,,,,,,,,,,,,,,,,, -102100,0.33107215,3.6391547,,,,,,,,,,,,,,,,, -102200,0.36840156,3.7189949,,,,,,,,,,,,,,,,, -102300,0.34085986,3.71004,,,,,,,,,,,,,,,,, -102400,0.32369623,3.6446776,,,,,,,,,,,,,,,,, -102500,0.32229403,3.6524363,,,,,,,,,,,,,,,,, -102600,0.3377369,3.6444564,,,,,,,,,,,,,,,,, -102700,0.3383208,3.6593225,,,,,,,,,,,,,,,,, -102800,0.32854357,3.6560779,,,,,,,,,,,,,,,,, -102900,0.34524447,3.7193162,,,,,,,,,,,,,,,,, -103000,0.34652594,3.6974862,,,,,,,,,,,,,,,,, -103100,0.33560777,3.703618,,,,,,,,,,,,,,,,, -103200,0.33188188,3.6307585,,,,,,,,,,,,,,,,, -103300,0.34582606,3.7018306,,,,,,,,,,,,,,,,, -103400,0.3582868,3.7249546,,,,,,,,,,,,,,,,, -103458,,,0.6998023390769958,1.5615475177764893,35.472857329654225,0.6898860335350037,1.608350396156311,30.398652314834955,3000.0,0.707466185092926,1.5031170845031738,30.460744589687334,3003.0,36149.81416296959,59098.26532077789,36149.81416296959,22943.753257989883,1.431570529937744,0.0 -103500,0.33668706,3.6071124,,,,,,,,,,,,,,,,, -103600,0.32689965,3.660812,,,,,,,,,,,,,,,,, -103700,0.33490676,3.646875,,,,,,,,,,,,,,,,, -103800,0.3598016,3.7158527,,,,,,,,,,,,,,,,, -103900,0.32854903,3.6265218,,,,,,,,,,,,,,,,, -104000,0.34276712,3.6703866,,,,,,,,,,,,,,,,, -104100,0.33435932,3.663948,,,,,,,,,,,,,,,,, -104200,0.37596184,3.7441826,,,,,,,,,,,,,,,,, -104300,0.34437314,3.6041882,,,,,,,,,,,,,,,,, -104400,0.33905512,3.6809587,,,,,,,,,,,,,,,,, -104500,0.33621436,3.672651,,,,,,,,,,,,,,,,, -104600,0.33647934,3.6450648,,,,,,,,,,,,,,,,, -104700,0.34989226,3.6552927,,,,,,,,,,,,,,,,, -104800,0.3498142,3.6235013,,,,,,,,,,,,,,,,, -104900,0.33950353,3.657523,,,,,,,,,,,,,,,,, -105000,0.35784754,3.6680155,,,,,,,,,,,,,,,,, -105100,0.3534816,3.707846,,,,,,,,,,,,,,,,, -105200,0.3482628,3.6439297,,,,,,,,,,,,,,,,, -105300,0.34163252,3.7013066,,,,,,,,,,,,,,,,, -105400,0.33334687,3.633984,,,,,,,,,,,,,,,,, -105500,0.37416103,3.6875281,,,,,,,,,,,,,,,,, -105600,0.35718915,3.669057,,,,,,,,,,,,,,,,, -105700,0.36024165,3.6681862,,,,,,,,,,,,,,,,, -105800,0.36842504,3.6952047,,,,,,,,,,,,,,,,, -105864,,,0.7032337188720703,1.5497827529907229,36.129084059359286,0.6897744536399841,1.6095569133758545,30.705780856663885,3000.0,0.707524299621582,1.5049904584884644,30.83850855238789,3003.0,36989.73922157288,60446.634615659714,36989.73922157288,23452.08066368103,1.4725525379180908,0.0 -105900,0.34158733,3.6346176,,,,,,,,,,,,,,,,, -106000,0.35855192,3.6768842,,,,,,,,,,,,,,,,, -106100,0.3467902,3.6358967,,,,,,,,,,,,,,,,, -106200,0.330967,3.6126027,,,,,,,,,,,,,,,,, -106300,0.36110985,3.6105301,,,,,,,,,,,,,,,,, -106400,0.33616862,3.613503,,,,,,,,,,,,,,,,, -106500,0.3472967,3.5889325,,,,,,,,,,,,,,,,, -106600,0.36491814,3.680433,,,,,,,,,,,,,,,,, -106700,0.36001313,3.6492903,,,,,,,,,,,,,,,,, -106800,0.36085403,3.6605806,,,,,,,,,,,,,,,,, -106900,0.3712309,3.6715002,,,,,,,,,,,,,,,,, -107000,0.35600042,3.6359596,,,,,,,,,,,,,,,,, -107100,0.34948424,3.6496236,,,,,,,,,,,,,,,,, -107200,0.36859715,3.6986578,,,,,,,,,,,,,,,,, -107300,0.32284626,3.6034093,,,,,,,,,,,,,,,,, -107400,0.35185856,3.6560097,,,,,,,,,,,,,,,,, -107500,0.3396247,3.669613,,,,,,,,,,,,,,,,, -107600,0.33794045,3.6420627,,,,,,,,,,,,,,,,, -107700,0.335717,3.6155381,,,,,,,,,,,,,,,,, -107800,0.37574342,3.6830742,,,,,,,,,,,,,,,,, -107900,0.3618843,3.6946843,,,,,,,,,,,,,,,,, -108000,0.36026323,3.6767664,,,,,,,,,,,,,,,,, -108100,0.35811234,3.6853654,,,,,,,,,,,,,,,,, -108200,0.36297733,3.6004443,,,,,,,,,,,,,,,,, -108269,,,0.7132049202919006,1.4947891235351562,36.41010826418029,0.6916218996047974,1.6021170616149902,30.63931633701646,3000.0,0.7091279029846191,1.499886393547058,30.73570104529433,3003.0,37829.69793653488,61776.03334188461,37829.69793653488,23941.40462422371,1.512440204620361,0.0 -108300,0.3427932,3.609861,,,,,,,,,,,,,,,,, -108400,0.354469,3.6439967,,,,,,,,,,,,,,,,, -108500,0.36179328,3.658078,,,,,,,,,,,,,,,,, -108600,0.36948183,3.6751986,,,,,,,,,,,,,,,,, -108700,0.36190215,3.7101629,,,,,,,,,,,,,,,,, -108800,0.36426556,3.6477664,,,,,,,,,,,,,,,,, -108900,0.36100936,3.6847072,,,,,,,,,,,,,,,,, -109000,0.35931075,3.6317492,,,,,,,,,,,,,,,,, -109100,0.35788754,3.6010396,,,,,,,,,,,,,,,,, -109200,0.35342765,3.6446972,,,,,,,,,,,,,,,,, -109300,0.34635374,3.6399057,,,,,,,,,,,,,,,,, -109400,0.36736092,3.6896229,,,,,,,,,,,,,,,,, -109500,0.35012347,3.6783056,,,,,,,,,,,,,,,,, -109600,0.3800331,3.6275728,,,,,,,,,,,,,,,,, -109700,0.3301892,3.5901148,,,,,,,,,,,,,,,,, -109800,0.36269206,3.6658466,,,,,,,,,,,,,,,,, -109900,0.3543515,3.7032094,,,,,,,,,,,,,,,,, -110000,0.37096357,3.6888828,,,,,,,,,,,,,,,,, -110100,0.3411283,3.6304035,,,,,,,,,,,,,,,,, -110200,0.3588388,3.6723568,,,,,,,,,,,,,,,,, -110300,0.38792962,3.6309447,,,,,,,,,,,,,,,,, -110400,0.37309402,3.6500316,,,,,,,,,,,,,,,,, -110500,0.37107265,3.6722522,,,,,,,,,,,,,,,,, -110600,0.36994967,3.6555274,,,,,,,,,,,,,,,,, -110674,,,0.7082228064537048,1.514228343963623,36.62626379524225,0.6904811859130859,1.604583740234375,30.574834332474047,3000.0,0.7097786664962769,1.4970381259918213,30.64078556582809,3003.0,38669.68802213669,63077.23090171814,38669.68802213669,24402.493275880814,1.5521199703216553,0.0 -110700,0.36971095,3.6677856,,,,,,,,,,,,,,,,, -110800,0.35324094,3.6374772,,,,,,,,,,,,,,,,, -110900,0.34843862,3.631721,,,,,,,,,,,,,,,,, -111000,0.3952803,3.6506186,,,,,,,,,,,,,,,,, -111100,0.3593933,3.6236,,,,,,,,,,,,,,,,, -111200,0.34991097,3.6252682,,,,,,,,,,,,,,,,, -111300,0.37810376,3.6005046,,,,,,,,,,,,,,,,, -111400,0.3719036,3.6614685,,,,,,,,,,,,,,,,, -111500,0.3640828,3.6587625,,,,,,,,,,,,,,,,, -111600,0.36356756,3.6669164,,,,,,,,,,,,,,,,, -111700,0.38611835,3.5635197,,,,,,,,,,,,,,,,, -111800,0.3782532,3.6808279,,,,,,,,,,,,,,,,, -111900,0.3800952,3.6884928,,,,,,,,,,,,,,,,, -112000,0.35792363,3.6505392,,,,,,,,,,,,,,,,, -112100,0.35796714,3.637494,,,,,,,,,,,,,,,,, -112200,0.38020828,3.6944728,,,,,,,,,,,,,,,,, -112300,0.35194877,3.6487837,,,,,,,,,,,,,,,,, -112400,0.36354917,3.5785022,,,,,,,,,,,,,,,,, -112500,0.36921406,3.6843066,,,,,,,,,,,,,,,,, -112600,0.36219496,3.5872917,,,,,,,,,,,,,,,,, -112700,0.3853198,3.64254,,,,,,,,,,,,,,,,, -112800,0.34789994,3.6420465,,,,,,,,,,,,,,,,, -112900,0.37916777,3.6297603,,,,,,,,,,,,,,,,, -113000,0.368249,3.661022,,,,,,,,,,,,,,,,, -113080,,,0.7224063277244568,1.443524718284607,37.28813238914175,0.6914855241775513,1.6032272577285769,30.60543726425647,3000.0,0.7103829383850098,1.4962009191513062,30.95519761628077,3003.0,39509.64947485924,64382.861696243286,39509.64947485924,24868.04579520225,1.593592405319214,0.0 -113100,0.38396296,3.6529467,,,,,,,,,,,,,,,,, -113200,0.36554685,3.6214335,,,,,,,,,,,,,,,,, -113300,0.37625706,3.651604,,,,,,,,,,,,,,,,, -113400,0.3825323,3.7240138,,,,,,,,,,,,,,,,, -113500,0.37604642,3.5845168,,,,,,,,,,,,,,,,, -113600,0.36855158,3.626553,,,,,,,,,,,,,,,,, -113700,0.37642333,3.7034047,,,,,,,,,,,,,,,,, -113800,0.3541622,3.6037269,,,,,,,,,,,,,,,,, -113900,0.37461537,3.707555,,,,,,,,,,,,,,,,, -114000,0.36464724,3.5717413,,,,,,,,,,,,,,,,, -114100,0.36404186,3.615102,,,,,,,,,,,,,,,,, -114200,0.36692244,3.6513107,,,,,,,,,,,,,,,,, -114300,0.37055737,3.6702993,,,,,,,,,,,,,,,,, -114400,0.37617776,3.6133316,,,,,,,,,,,,,,,,, -114500,0.38127998,3.6661313,,,,,,,,,,,,,,,,, -114600,0.357323,3.6151438,,,,,,,,,,,,,,,,, -114700,0.3745093,3.635466,,,,,,,,,,,,,,,,, -114800,0.3800197,3.6714635,,,,,,,,,,,,,,,,, -114900,0.3822174,3.6699991,,,,,,,,,,,,,,,,, -115000,0.36412913,3.6277041,,,,,,,,,,,,,,,,, -115100,0.36167854,3.5722048,,,,,,,,,,,,,,,,, -115200,0.38436642,3.5960784,,,,,,,,,,,,,,,,, -115300,0.36976215,3.6663065,,,,,,,,,,,,,,,,, -115400,0.36487707,3.5666955,,,,,,,,,,,,,,,,, -115487,,,0.713723361492157,1.487293004989624,36.435484973088045,0.6912871599197388,1.6010518074035645,30.65375561094501,3000.0,0.7099064588546753,1.4967427253723145,30.68944762531302,3003.0,40349.80061197281,65698.97678422928,40349.80061197281,25343.89398145676,1.635310173034668,0.0 -115500,0.39088202,3.671459,,,,,,,,,,,,,,,,, -115600,0.38819093,3.6155221,,,,,,,,,,,,,,,,, -115700,0.37545225,3.6609402,,,,,,,,,,,,,,,,, -115800,0.38869756,3.6642978,,,,,,,,,,,,,,,,, -115900,0.37820578,3.6425776,,,,,,,,,,,,,,,,, -116000,0.3820084,3.6147845,,,,,,,,,,,,,,,,, -116100,0.37316567,3.6371593,,,,,,,,,,,,,,,,, -116200,0.36008382,3.602319,,,,,,,,,,,,,,,,, -116300,0.38108426,3.654441,,,,,,,,,,,,,,,,, -116400,0.4138864,3.646159,,,,,,,,,,,,,,,,, -116500,0.3719636,3.6402843,,,,,,,,,,,,,,,,, -116600,0.3918235,3.6158206,,,,,,,,,,,,,,,,, -116700,0.37510958,3.618018,,,,,,,,,,,,,,,,, -116800,0.40280306,3.6480515,,,,,,,,,,,,,,,,, -116900,0.3875313,3.6157758,,,,,,,,,,,,,,,,, -117000,0.38172552,3.6052108,,,,,,,,,,,,,,,,, -117100,0.38169682,3.6277926,,,,,,,,,,,,,,,,, -117200,0.3707842,3.6061425,,,,,,,,,,,,,,,,, -117300,0.3956229,3.561478,,,,,,,,,,,,,,,,, -117400,0.3693765,3.601528,,,,,,,,,,,,,,,,, -117500,0.38547915,3.6849194,,,,,,,,,,,,,,,,, -117600,0.36585936,3.5787477,,,,,,,,,,,,,,,,, -117700,0.3749697,3.6299925,,,,,,,,,,,,,,,,, -117800,0.36647376,3.5792644,,,,,,,,,,,,,,,,, -117894,,,0.7144909501075745,1.490625500679016,37.108529161669374,0.6910887360572815,1.6030627489089966,30.71050729467754,3000.0,0.7090465426445007,1.4972790479660034,30.936780229896527,3003.0,41189.93091821671,66997.73618650436,41189.93091821671,25802.397783994675,1.6849846839904783,0.0 -117900,0.37639007,3.6143546,,,,,,,,,,,,,,,,, -118000,0.39059508,3.6272852,,,,,,,,,,,,,,,,, -118100,0.3758977,3.6506476,,,,,,,,,,,,,,,,, -118200,0.36561215,3.6033065,,,,,,,,,,,,,,,,, -118300,0.38201588,3.6192636,,,,,,,,,,,,,,,,, -118400,0.3805443,3.618354,,,,,,,,,,,,,,,,, -118500,0.3879925,3.6207094,,,,,,,,,,,,,,,,, -118600,0.39273477,3.6280274,,,,,,,,,,,,,,,,, -118700,0.40376064,3.6851819,,,,,,,,,,,,,,,,, -118800,0.36828294,3.6502705,,,,,,,,,,,,,,,,, -118900,0.40982196,3.614712,,,,,,,,,,,,,,,,, -119000,0.39197966,3.6418228,,,,,,,,,,,,,,,,, -119100,0.37461066,3.5980384,,,,,,,,,,,,,,,,, -119200,0.38885513,3.6267703,,,,,,,,,,,,,,,,, -119300,0.38169336,3.6173065,,,,,,,,,,,,,,,,, -119400,0.39789715,3.6528177,,,,,,,,,,,,,,,,, -119500,0.37864786,3.6254253,,,,,,,,,,,,,,,,, -119600,0.36391044,3.5680733,,,,,,,,,,,,,,,,, -119700,0.40485427,3.6828303,,,,,,,,,,,,,,,,, -119800,0.39512825,3.647404,,,,,,,,,,,,,,,,, -119900,0.37426332,3.6031826,,,,,,,,,,,,,,,,, -120000,0.41340822,3.5776863,,,,,,,,,,,,,,,,, -120100,0.38167122,3.5968997,,,,,,,,,,,,,,,,, -120200,0.37235352,3.6249216,,,,,,,,,,,,,,,,, -120300,,,0.7231228351593018,1.442205548286438,37.805040018899184,0.6926014423370361,1.6032183170318604,30.68321910491765,3000.0,0.7098715901374817,1.4950560331344604,30.98069747621216,3003.0,42030.07278776169,68307.31502318382,42030.07278776169,26271.71107816696,1.7338509559631348,0.0 -120300,0.37427604,3.5984297,,,,,,,,,,,,,,,,, -120400,0.3866681,3.6268814,,,,,,,,,,,,,,,,, -120500,0.36842257,3.566448,,,,,,,,,,,,,,,,, -120600,0.37799752,3.631665,,,,,,,,,,,,,,,,, -120700,0.39574325,3.6070101,,,,,,,,,,,,,,,,, -120800,0.3722482,3.591265,,,,,,,,,,,,,,,,, -120900,0.38098937,3.5999336,,,,,,,,,,,,,,,,, -121000,0.38493583,3.5999334,,,,,,,,,,,,,,,,, -121100,0.3836096,3.6457486,,,,,,,,,,,,,,,,, -121200,0.39913288,3.5902753,,,,,,,,,,,,,,,,, -121300,0.38455206,3.650585,,,,,,,,,,,,,,,,, -121400,0.35780165,3.5725732,,,,,,,,,,,,,,,,, -121500,0.3807809,3.595095,,,,,,,,,,,,,,,,, -121600,0.3625324,3.5904024,,,,,,,,,,,,,,,,, -121700,0.38644668,3.6171646,,,,,,,,,,,,,,,,, -121800,0.3895909,3.606101,,,,,,,,,,,,,,,,, -121900,0.37701216,3.6264322,,,,,,,,,,,,,,,,, -122000,0.37580055,3.6322908,,,,,,,,,,,,,,,,, -122100,0.37380788,3.5849955,,,,,,,,,,,,,,,,, -122200,0.3805996,3.6320145,,,,,,,,,,,,,,,,, -122300,0.41411847,3.6336272,,,,,,,,,,,,,,,,, -122400,0.35891134,3.5466192,,,,,,,,,,,,,,,,, -122500,0.38404375,3.614268,,,,,,,,,,,,,,,,, -122600,0.36986357,3.576423,,,,,,,,,,,,,,,,, -122700,0.38906312,3.6723826,,,,,,,,,,,,,,,,, -122710,,,0.7189242243766785,1.4655243158340454,37.58180576067513,0.6922046542167664,1.6004964113235474,30.80034796220132,3000.0,0.7106037139892578,1.4928041696548462,30.8694776006975,3003.0,42870.12835860253,69614.16217589378,42870.12835860253,26738.382657289505,1.7781829833984375,0.0 -122800,0.39033353,3.6050277,,,,,,,,,,,,,,,,, -122900,0.3980435,3.5972164,,,,,,,,,,,,,,,,, -123000,0.3888151,3.6180763,,,,,,,,,,,,,,,,, -123100,0.3838007,3.6056504,,,,,,,,,,,,,,,,, -123200,0.3967258,3.5912642,,,,,,,,,,,,,,,,, -123300,0.39255813,3.623242,,,,,,,,,,,,,,,,, -123400,0.37049696,3.6319664,,,,,,,,,,,,,,,,, -123500,0.40832525,3.6343942,,,,,,,,,,,,,,,,, -123600,0.3705564,3.599128,,,,,,,,,,,,,,,,, -123700,0.3719457,3.6382706,,,,,,,,,,,,,,,,, -123800,0.37902853,3.5886533,,,,,,,,,,,,,,,,, -123900,0.37512887,3.6135411,,,,,,,,,,,,,,,,, -124000,0.39346278,3.6422803,,,,,,,,,,,,,,,,, -124100,0.376801,3.5920267,,,,,,,,,,,,,,,,, -124200,0.40255994,3.5998266,,,,,,,,,,,,,,,,, -124300,0.3921362,3.6614912,,,,,,,,,,,,,,,,, -124400,0.38954124,3.6001444,,,,,,,,,,,,,,,,, -124500,0.4035891,3.6666024,,,,,,,,,,,,,,,,, -124600,0.38258263,3.6092165,,,,,,,,,,,,,,,,, -124700,0.39083374,3.6220765,,,,,,,,,,,,,,,,, -124800,0.38459837,3.598307,,,,,,,,,,,,,,,,, -124900,0.37175965,3.6253722,,,,,,,,,,,,,,,,, -125000,0.37472227,3.607207,,,,,,,,,,,,,,,,, -125100,0.4012012,3.6451654,,,,,,,,,,,,,,,,, -125116,,,0.7213461399078369,1.4562888145446775,37.47862548500293,0.6918947100639343,1.6007949113845823,30.825696629345693,3000.0,0.7102667093276978,1.4937689304351809,30.91072129641423,3003.0,43710.05010247231,70918.60228586197,43710.05010247231,27202.7842271328,1.8205416202545168,0.0 -125200,0.39092588,3.559071,,,,,,,,,,,,,,,,, -125300,0.4023068,3.7009776,,,,,,,,,,,,,,,,, -125400,0.40092072,3.650985,,,,,,,,,,,,,,,,, -125500,0.3822943,3.6265123,,,,,,,,,,,,,,,,, -125600,0.38012883,3.5900028,,,,,,,,,,,,,,,,, -125700,0.395516,3.6065574,,,,,,,,,,,,,,,,, -125800,0.37381652,3.6003177,,,,,,,,,,,,,,,,, -125900,0.38872665,3.6255789,,,,,,,,,,,,,,,,, -126000,0.40096012,3.632129,,,,,,,,,,,,,,,,, -126100,0.37575653,3.6166732,,,,,,,,,,,,,,,,, -126200,0.39140853,3.5859075,,,,,,,,,,,,,,,,, -126300,0.37479693,3.5848475,,,,,,,,,,,,,,,,, -126400,0.37871653,3.5191119,,,,,,,,,,,,,,,,, -126500,0.38486132,3.6142318,,,,,,,,,,,,,,,,, -126600,0.37839797,3.6047447,,,,,,,,,,,,,,,,, -126700,0.39977697,3.6397562,,,,,,,,,,,,,,,,, -126800,0.38863987,3.6552904,,,,,,,,,,,,,,,,, -126900,0.37503213,3.5647032,,,,,,,,,,,,,,,,, -127000,0.38647082,3.599258,,,,,,,,,,,,,,,,, -127100,0.37166157,3.5729134,,,,,,,,,,,,,,,,, -127200,0.38905612,3.64048,,,,,,,,,,,,,,,,, -127300,0.40068746,3.6545086,,,,,,,,,,,,,,,,, -127400,0.38778493,3.6293523,,,,,,,,,,,,,,,,, -127500,0.38908508,3.5728464,,,,,,,,,,,,,,,,, -127522,,,0.7226781845092773,1.445623517036438,37.238805367993805,0.6923658847808838,1.6010756492614746,30.86297775830037,3000.0,0.710719883441925,1.493880033493042,30.998903514021368,3003.0,44550.06398534775,72225.58293819427,44550.06398534775,27669.634046316147,1.8627715110778809,0.0 -127600,0.38168904,3.5968046,,,,,,,,,,,,,,,,, -127700,0.39516234,3.6141725,,,,,,,,,,,,,,,,, -127800,0.37393588,3.5977235,,,,,,,,,,,,,,,,, -127900,0.3865732,3.588866,,,,,,,,,,,,,,,,, -128000,0.38198698,3.661424,,,,,,,,,,,,,,,,, -128100,0.4134392,3.6124797,,,,,,,,,,,,,,,,, -128200,0.37195024,3.6022725,,,,,,,,,,,,,,,,, -128300,0.38376367,3.5758994,,,,,,,,,,,,,,,,, -128400,0.38908753,3.5959785,,,,,,,,,,,,,,,,, -128500,0.36620143,3.5535367,,,,,,,,,,,,,,,,, -128600,0.3729243,3.595149,,,,,,,,,,,,,,,,, -128700,0.3618724,3.5382056,,,,,,,,,,,,,,,,, -128800,0.3877852,3.5832102,,,,,,,,,,,,,,,,, -128900,0.3766971,3.603162,,,,,,,,,,,,,,,,, -129000,0.4093981,3.6393385,,,,,,,,,,,,,,,,, -129100,0.3822908,3.6429124,,,,,,,,,,,,,,,,, -129200,0.3903945,3.6091714,,,,,,,,,,,,,,,,, -129300,0.38740894,3.6131825,,,,,,,,,,,,,,,,, -129400,0.38541567,3.6190248,,,,,,,,,,,,,,,,, -129500,0.39461637,3.6485858,,,,,,,,,,,,,,,,, -129600,0.38071138,3.6058328,,,,,,,,,,,,,,,,, -129700,0.38851583,3.5739057,,,,,,,,,,,,,,,,, -129800,0.37684426,3.5838895,,,,,,,,,,,,,,,,, -129900,0.38678977,3.6215608,,,,,,,,,,,,,,,,, -129928,,,0.7224118113517761,1.4469929933547974,37.6780744346327,0.6922914981842041,1.6008161306381226,30.87728482321264,3000.0,0.7106850147247314,1.4928480386734009,30.92520062066843,3003.0,45389.95973515511,73539.53221654892,45389.95973515511,28143.56909942627,1.9056484699249268,0.0 -130000,0.3750921,3.603476,,,,,,,,,,,,,,,,, -130100,0.39110124,3.6055024,,,,,,,,,,,,,,,,, -130200,0.39388904,3.6419175,,,,,,,,,,,,,,,,, -130300,0.38581675,3.593704,,,,,,,,,,,,,,,,, -130400,0.35755837,3.5260766,,,,,,,,,,,,,,,,, -130500,0.3796562,3.6262434,,,,,,,,,,,,,,,,, -130600,0.4016845,3.6115067,,,,,,,,,,,,,,,,, -130700,0.38105562,3.6168215,,,,,,,,,,,,,,,,, -130800,0.37435016,3.5851605,,,,,,,,,,,,,,,,, -130900,0.3995883,3.6828344,,,,,,,,,,,,,,,,, -131000,0.37136438,3.5812957,,,,,,,,,,,,,,,,, -131100,0.3898926,3.5943823,,,,,,,,,,,,,,,,, -131200,0.35786295,3.6084585,,,,,,,,,,,,,,,,, -131300,0.3958829,3.6447947,,,,,,,,,,,,,,,,, -131400,0.38705865,3.5803635,,,,,,,,,,,,,,,,, -131500,0.38111702,3.6037111,,,,,,,,,,,,,,,,, -131600,0.3851484,3.6252468,,,,,,,,,,,,,,,,, -131700,0.40145227,3.622037,,,,,,,,,,,,,,,,, -131800,0.37338623,3.6491816,,,,,,,,,,,,,,,,, -131900,0.38477048,3.5971773,,,,,,,,,,,,,,,,, -132000,0.38957867,3.5851579,,,,,,,,,,,,,,,,, -132100,0.38251346,3.5810504,,,,,,,,,,,,,,,,, -132200,0.3769524,3.601448,,,,,,,,,,,,,,,,, -132300,0.37188506,3.5729795,,,,,,,,,,,,,,,,, -132335,,,0.7235233783721924,1.4436975717544556,37.51305828525132,0.6922170519828796,1.6004061698913574,30.95629465203926,3000.0,0.7106037139892578,1.492241621017456,30.889636565732552,3003.0,46230.10717344284,74849.48199796677,46230.10717344284,28613.24405527115,1.9580817222595213,0.0 -132400,0.38759288,3.5793507,,,,,,,,,,,,,,,,, -132500,0.38513246,3.6459944,,,,,,,,,,,,,,,,, -132600,0.38133225,3.6219878,,,,,,,,,,,,,,,,, -132700,0.4310438,3.6139722,,,,,,,,,,,,,,,,, -132800,0.3843189,3.5822895,,,,,,,,,,,,,,,,, -132900,0.37623993,3.619977,,,,,,,,,,,,,,,,, -133000,0.37634492,3.6015077,,,,,,,,,,,,,,,,, -133100,0.37727267,3.6012466,,,,,,,,,,,,,,,,, -133200,0.35574055,3.5478523,,,,,,,,,,,,,,,,, -133300,0.40533355,3.6818671,,,,,,,,,,,,,,,,, -133333,,,0.721368134021759,1.4508073329925537,37.40015536663024,0.6922046542167664,1.6004267930984497,30.95903228295605,3000.0,0.7105804681777954,1.4921976327896118,30.862555894325823,3003.0,46578.47412419319,75670.26692295074,46578.47412419319,29085.58303833008,2.0043785572052,0.0 -133333,,,,,,,,,,,,,,46578.47412419319,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 163ebf302..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -870.8081820011139,0.0,25.81904888153076,1,0,25.81904888153076,0.0007088489946909,0.0,11.041826248168944,3003,896.6272666454315,0.0005957768880762,0.0,11.064934730529783,0.0004835649742744,0.0,11.036645889282228,3000 -1450.2326774597168,0.0231471061706542,865.8087675571442,2403,0,865.8087675571442,0.3866132199764251,8.388442673207205,4.231210708618164,3003,2316.1432354450226,0.413316398859024,14.63394950552278,3.916454315185547,0.4011481404304504,9.957431113516414,4.038556098937988,3000 -1927.821095943451,0.0491232872009277,1705.7114703655243,4804,0,1705.7114703655243,0.543048083782196,18.875308714605914,2.6712963581085205,3003,3633.737921953201,0.5378835201263428,24.378417805692475,2.7001044750213623,0.5419771671295166,20.41463598087644,2.6425657272338867,3000 -2404.573896408081,0.0743830204010009,2545.865830183029,7207,0,2545.865830183029,0.5880890488624573,22.14029524164679,2.223965167999268,3003,4950.748953819275,0.580255389213562,27.101228921682285,2.2710001468658447,0.5857831835746765,23.166204664087257,2.226724624633789,3000 -2842.7774851322174,0.1002109050750732,3385.8169617652893,9611,0,3385.8169617652893,0.6116321086883545,23.503151632503048,2.003997325897217,3003,6229.004224777222,0.592166006565094,28.45783172233359,2.1426703929901123,0.606316089630127,24.673418436166767,2.0390303134918213,3000 -3454.847516775131,0.1273379325866699,4225.967286348343,12015,0,4225.967286348343,0.6232293248176575,24.42601174196864,1.8754557371139529,3003,7681.333123445511,0.6005766987800598,28.8104042804232,2.0656661987304688,0.6186903715133667,25.631044209587834,1.916280031204224,3000 -3892.275357961655,0.1567516326904297,5066.110929250717,14421,0,5066.110929250717,0.6382430195808411,25.50415595748084,1.775894045829773,3003,8959.01217341423,0.6168133616447449,29.425459383473545,1.9349526166915887,0.630469560623169,26.62938777326784,1.8252501487731927,3000 -4358.376479148865,0.1858303546905517,5906.1692090034485,16827,0,5906.1692090034485,0.6456103920936584,26.09951750501349,1.7116901874542236,3003,10265.277593374252,0.6208845973014832,29.833422011187903,1.901543140411377,0.6382933855056763,26.58655602133679,1.7594693899154663,3000 -4948.353693962097,0.2140419483184814,6746.2333035469055,19234,0,6746.2333035469055,0.6555923819541931,26.70090522635841,1.65578031539917,3003,11695.425820350649,0.6365357041358948,31.482361801789065,1.7768452167510986,0.6452617049217224,27.29098019252068,1.712559938430786,3000 -5539.945268630981,0.2424554824829101,7586.293897151947,21641,0,7586.293897151947,0.6589041948318481,26.76084053464521,1.6343481540679932,3003,13127.183210134506,0.6268560886383057,30.100959194200843,1.837093949317932,0.6466503739356995,27.4922055487141,1.69412362575531,3000 -6044.751909255981,0.2706022262573242,8426.46519112587,24049,0,8426.46519112587,0.657300591468811,26.49205587264576,1.6291598081588743,3003,14472.267474412918,0.6285372376441956,30.61643964662627,1.8340964317321773,0.6487582325935364,27.35033583597104,1.6795746088027954,3000 -6594.221925020218,0.2991580963134765,9266.6213285923,26457,0,9266.6213285923,0.6611701846122742,27.17247322045552,1.60666024684906,3003,15862.000746011734,0.6383501887321472,31.067359641378346,1.7676606178283691,0.6521803736686707,27.703317488145554,1.6639560461044312,3000 -7149.713710069656,0.3287866115570068,10106.590117931366,28863,0,10106.590117931366,0.6616930961608887,27.073841281803126,1.588875412940979,3003,17257.571030139923,0.6295340657234192,30.57459531827869,1.8214620351791384,0.6530855298042297,27.829608372370423,1.650708794593811,3000 -7678.296708583832,0.3578181266784668,10946.506821632383,31270,0,10946.506821632383,0.6649003624916077,27.376689492425207,1.5839022397994995,3003,18626.177748918533,0.6339073777198792,31.101627683228955,1.8032761812210083,0.6544494032859802,27.79121788889197,1.6417778730392456,3000 -8185.084298849106,0.3881950378417969,11786.677701950071,33679,0,11786.677701950071,0.6679217219352722,27.7718476746868,1.5678555965423584,3003,19973.24415802956,0.640436053276062,31.079749336354755,1.7520774602890017,0.6569168567657471,27.985500631619285,1.6260370016098022,3000 -8668.399793624878,0.4177160263061523,12626.625022888184,36087,0,12626.625022888184,0.6699668765068054,27.99159121351021,1.5581209659576416,3003,21296.61293053627,0.6369031071662903,30.80821034725975,1.7718901634216309,0.6572268009185791,28.23449537713964,1.6264159679412842,3000 -9162.348851919174,0.4478676319122314,13466.692353248596,38495,0,13466.692353248596,0.6695834398269653,27.83423761879015,1.555444359779358,3003,22630.738989830017,0.6430981159210205,31.545833523294696,1.7214125394821167,0.6587147116661072,28.382361296399072,1.6191388368606567,3000 -9674.442895889282,0.4786746501922607,14306.815686225891,40904,0,14306.815686225891,0.668677031993866,27.83588488357372,1.5514838695526123,3003,23983.063989400864,0.6393962502479553,31.326564034977284,1.7517564296722412,0.6592354774475098,28.034270082296626,1.6140036582946775,3000 -10184.291736602783,0.5092837810516357,15146.73096871376,43312,0,15146.73096871376,0.6726861000061035,28.08085070114241,1.533181071281433,3003,25332.93573999405,0.6379010081291199,31.306862035234403,1.7598146200180054,0.6621120572090149,28.58756001053182,1.6029443740844729,3000 -10889.123789072037,0.5416944026947021,15986.845165491104,45720,0,15986.845165491104,0.6723374724388123,27.78312192161673,1.5314693450927734,3003,26877.99145627021,0.642835795879364,31.51465397782276,1.7258001565933228,0.6610457301139832,28.238541337906103,1.5972543954849243,3000 -11406.725434064863,0.5741453170776367,16826.8217689991,48128,0,16826.8217689991,0.6742199659347534,27.99158132715858,1.5215578079223633,3003,28235.679394960403,0.6452369689941406,31.82425775905154,1.726482629776001,0.6602894067764282,28.259296946768995,1.590267300605774,3000 -11933.50832438469,0.6121892929077148,17666.926292657852,50537,0,17666.926292657852,0.675358772277832,28.05993585302966,1.518868088722229,3003,29602.6824324131,0.6565302610397339,32.767179635812575,1.640212893486023,0.6637363433837891,28.6552268817002,1.5860275030136108,3000 -12486.558814764025,0.6497421264648438,18507.14643883705,52946,0,18507.14643883705,0.6761606335639954,28.09270029743737,1.5175516605377195,3003,30996.068150520325,0.6452774405479431,31.52326751196575,1.7110553979873655,0.6643067002296448,28.57172511434688,1.580517292022705,3000 -13019.85465836525,0.6876661777496338,19347.095779180527,55354,0,19347.095779180527,0.6759397983551025,28.14646888038358,1.5070810317993164,3003,32369.428835868835,0.6445077657699585,31.342913567592323,1.712929129600525,0.663823127746582,28.29432199708726,1.576424241065979,3000 -13559.9481112957,0.7214796543121338,20187.175048589703,57761,0,20187.175048589703,0.6782522797584534,28.270571580751486,1.4985307455062866,3003,33749.7142393589,0.649020254611969,32.0713855796427,1.682659387588501,0.6664641499519348,28.97768596478849,1.5685975551605225,3000 -14104.895560979843,0.7536098957061768,21027.074191093445,60169,0,21027.074191093445,0.6790889501571655,28.51629844762099,1.4897910356521606,3003,35134.669848680496,0.648051917552948,31.88436570280064,1.6906379461288452,0.6664765477180481,28.958729574332384,1.5599160194396973,3000 -14563.95576763153,0.7935366630554199,21867.165781974792,62577,0,21867.165781974792,0.678345263004303,28.343151127526827,1.491255760192871,3003,36433.94098830223,0.6952161192893982,35.779476196254954,1.432866096496582,0.6689935326576233,29.131601118820857,1.5532121658325195,3000 -15045.877503871918,0.8332781791687012,22707.37787055969,64986,0,22707.37787055969,0.6832142472267151,28.673106638091205,1.4675744771957395,3003,37756.19357728958,0.6524246335029602,31.827023341578347,1.6702783107757568,0.6694151163101196,28.90430783369693,1.539971113204956,3000 -15647.19865846634,0.9477291107177734,23547.34636592865,67394,0,23547.34636592865,0.6817849278450012,28.59808848117703,1.4733388423919678,3003,39197.67491483688,0.6514863967895508,31.94321390091731,1.6692497730255127,0.6705062389373779,29.242645233820426,1.540156364440918,3000 -16195.264653921127,0.9868950843811036,24387.48620867729,69801,0,24387.48620867729,0.684980571269989,29.0403300225907,1.4600976705551147,3003,40585.99910902977,0.6611226201057434,32.63959978255694,1.6098320484161377,0.6719693541526794,29.068353792762743,1.5329428911209106,3000 -16704.657161474228,1.0241799354553225,25227.61577558517,72213,0,25227.61577558517,0.6867700815200806,29.133928657891733,1.4527868032455444,3003,41935.63665962219,0.6569823026657104,32.19100745654006,1.6434437036514282,0.6732092499732971,29.36287108991364,1.5275042057037354,3000 -17256.88942360878,1.0615234375,26067.56249308586,74621,0,26067.56249308586,0.6867700815200806,28.99523113730168,1.4440298080444336,3003,43327.928844451904,0.6536204218864441,32.52132858933652,1.6660853624343872,0.6733828186988831,29.318499071867286,1.5218721628189087,3000 -17774.62933588028,1.0971648693084717,26907.49367928505,77029,0,26907.49367928505,0.6889199018478394,29.566770359025814,1.4393484592437744,3003,44685.71225547791,0.6623920798301697,33.60163001132954,1.6091350317001345,0.6767553687095642,29.60513743671253,1.5119057893753052,3000 -18292.526223659515,1.133357286453247,27747.6175262928,79440,0,27747.6175262928,0.6889663934707642,29.096602843757115,1.430628776550293,3003,46043.84932875633,0.6589401364326477,33.36118043719307,1.6313480138778689,0.6741020083427429,29.13616132874532,1.511282444000244,3000 -18959.974626541138,1.171102523803711,28587.77638578415,81848,0,28587.77638578415,0.6912091374397278,29.329210019172148,1.4248101711273191,3003,47551.57074832916,0.6777380108833313,33.80495966434508,1.5110163688659668,0.6777721047401428,29.83331729621893,1.4969340562820437,3000 -19460.1951122284,1.209423542022705,29427.856281757355,84257,0,29427.856281757355,0.6941258907318115,29.71491599447524,1.4093471765518188,3003,48891.9859354496,0.6639957427978516,33.48077142570666,1.589513659477234,0.6791732311248779,29.60846193664788,1.4911015033721924,3000 -20017.788903951645,1.2460203170776367,30268.08896493912,86666,0,30268.08896493912,0.693370521068573,29.51739396229348,1.4054758548736572,3003,50289.92531251907,0.6675595045089722,32.93528053933675,1.5793704986572266,0.6773629784584045,29.58336269676801,1.487136960029602,3000 -20590.925078868862,1.283271074295044,31108.269670009613,89075,0,31108.269670009613,0.696914792060852,29.432154247263547,1.3995591402053833,3003,51703.35502099991,0.6773558855056763,33.889525427289215,1.5134402513504028,0.6815290451049805,29.90880797413811,1.4774235486984253,3000 -21150.12527155876,1.319873571395874,31948.175379037857,91482,0,31948.175379037857,0.6981697678565979,29.75321280222057,1.38344144821167,3003,53102.57399082184,0.6696231365203857,33.075913430571255,1.5658307075500488,0.6815414428710938,29.879511675630194,1.4682722091674805,3000 -21653.25521636009,1.3652818202972412,32788.21763634682,93890,0,32788.21763634682,0.697519063949585,29.793031114372457,1.383360743522644,3003,54445.86746001244,0.7072305679321289,36.74364605152375,1.355831503868103,0.6831037402153015,30.11848958134699,1.466017246246338,3000 -22213.611531734467,1.4041471481323242,33628.44213843346,96299,0,33628.44213843346,0.6997618079185486,30.24534871633443,1.3693721294403076,3003,55846.56341743469,0.6774893403053284,34.065012769999825,1.5120842456817627,0.6842816472053528,30.26416928198846,1.452805519104004,3000 -22746.94253396988,1.4424107074737549,34468.39393520355,98707,0,34468.39393520355,0.6995061635971069,30.188361116708307,1.36750328540802,3003,57219.96085047722,0.6742954254150391,34.28313749510387,1.5320757627487185,0.6845172047615051,30.17301385143447,1.4518380165100098,3000 -23276.88682627678,1.4825043678283691,35308.40327715874,101115,0,35308.40327715874,0.701225996017456,30.21796655190293,1.3644455671310425,3003,58590.03084850311,0.6891303658485413,34.77914901250355,1.4438997507095337,0.6863027215003967,30.29630986323259,1.444441556930542,3000 -23805.79743504524,1.5310685634613037,36148.475451231,103522,0,36148.475451231,0.7032827734947205,30.08292498174387,1.3556402921676636,3003,59959.14341711998,0.6818674206733704,34.39548551799863,1.4904634952545166,0.6874682307243347,30.242459944963176,1.439511775970459,3000 -24323.759889364243,1.5710361003875732,36988.668182611465,105930,0,36988.668182611465,0.7045843005180359,30.341159836823227,1.3529555797576904,3003,61317.41492795944,0.6841477751731873,34.1736090872612,1.4826703071594238,0.68779057264328,30.294645014237727,1.4393343925476074,3000 -24850.823167324063,1.614072561264038,37828.766984939575,108337,0,37828.766984939575,0.7040265202522278,30.276564514232,1.3469500541687012,3003,62684.69976902008,0.6883156895637512,35.21473145739453,1.451397180557251,0.6892040967941284,30.54711567710777,1.4314101934432983,3000 -25363.45061326027,1.654350519180298,38668.92079091072,110745,0,38668.92079091072,0.7047585844993591,30.64520647586331,1.3457623720169067,3003,64037.598447322845,0.6906625032424927,35.11437290724362,1.4407415390014648,0.6892908811569214,30.62353490460124,1.4316939115524292,3000 -25987.774053812027,1.6964483261108398,39508.85071182251,113152,0,39508.85071182251,0.705909013748169,30.24255994252021,1.339033603668213,3003,65501.97001576424,0.7029312252998352,36.05811097263953,1.376530647277832,0.691299557685852,30.798797566594644,1.4230973720550537,3000 -26519.51696491241,1.7374508380889893,40348.93051624298,115560,0,40348.93051624298,0.7081633806228638,30.60836370943533,1.335689663887024,3003,66873.91002559662,0.7006024122238159,35.91853943560247,1.3854455947875977,0.6913491487503052,30.759350276539493,1.4228339195251465,3000 -27035.818858623505,1.779308557510376,41188.847472667694,117969,0,41188.847472667694,0.7076637148857117,30.73994588003593,1.326730251312256,3003,68230.25590658188,0.6944806575775146,35.57020902703697,1.42534601688385,0.691969096660614,30.88002752465824,1.4174968004226685,3000 -27588.76014494896,1.8211784362792969,42028.750698804855,120376,0,42028.750698804855,0.7080472111701965,30.485523564052155,1.3269232511520386,3003,69623.21995973587,0.7052478790283203,36.04900926844258,1.3589560985565186,0.6927626132965088,31.06434313028884,1.4140264987945557,3000 -28117.84645557404,1.864339828491211,42868.7630045414,122788,0,42868.7630045414,0.709337055683136,30.99924664179277,1.32336163520813,3003,70992.43810915947,0.7059203386306763,36.08858036119556,1.3507002592086792,0.6928866505622864,30.96063553398092,1.412611722946167,3000 -28649.569895029068,1.914734125137329,43708.71370458603,125195,0,43708.71370458603,0.7094649076461792,30.60444715357637,1.3227872848510742,3003,72364.23915076256,0.7114030718803406,37.05163518567466,1.329108476638794,0.6935189962387085,30.95792946222888,1.4125992059707642,3000 -29168.70892763137,1.9580953121185305,44548.620322942734,127603,0,44548.620322942734,0.7101272344589233,30.63157748084628,1.321054458618164,3003,73723.40506076813,0.7080180048942566,36.59462651717983,1.3448046445846558,0.6943497061729431,30.97203620968157,1.4095133543014526,3000 -29694.35936927796,2.001633644104004,45388.499920129776,130010,0,45388.499920129776,0.7105107307434082,30.66140440468964,1.320068359375,3003,75089.05698800087,0.7112451791763306,37.12121517508366,1.3329684734344482,0.6941761374473572,31.13628827962512,1.4096252918243408,3000 -30233.320573568344,2.052457571029663,46228.42295074463,132416,0,46228.42295074463,0.7104874849319458,30.713799107089372,1.3203920125961304,3003,76468.07191419601,0.7100916504859924,37.0542430438851,1.3323860168457031,0.6940397620201111,31.14853330741573,1.4101505279541016,3000 -30753.830042362213,2.0963242053985596,46548.19679522514,133333,0,46548.19679522514,0.7104526162147522,30.73949967608062,1.3205291032791138,3003,77308.42940235138,0.7102746367454529,37.00281637120834,1.3364615440368652,0.6941141486167908,31.145169924199735,1.4101862907409668,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/measurements.csv deleted file mode 100644 index 89cc2443d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.592714,11.052679,,,,,,,,,,,,,,,,, -1,,,0.0005957768880762,11.064934730529783,0.0,0.0004835649742744,11.036645889282228,0.0,3000.0,0.0007088489946909,11.041826248168944,0.0,3003.0,25.81904888153076,896.6272666454315,25.81904888153076,870.8081820011139,0.0,0.0 -100,0.4681868,8.697461,,,,,,,,,,,,,,,,, -200,0.18712346,8.3111515,,,,,,,,,,,,,,,,, -300,0.20741197,8.086752,,,,,,,,,,,,,,,,, -400,0.2908298,7.6012554,,,,,,,,,,,,,,,,, -500,0.3411262,7.267909,,,,,,,,,,,,,,,,, -600,0.6625122,7.0151267,,,,,,,,,,,,,,,,, -700,0.53656554,6.734341,,,,,,,,,,,,,,,,, -800,1.0281254,6.4052477,,,,,,,,,,,,,,,,, -900,0.82552105,6.22725,,,,,,,,,,,,,,,,, -1000,0.61394733,6.0266924,,,,,,,,,,,,,,,,, -1100,0.6189682,5.7974906,,,,,,,,,,,,,,,,, -1200,1.0285046,5.6107087,,,,,,,,,,,,,,,,, -1300,1.5139393,5.487386,,,,,,,,,,,,,,,,, -1400,0.74749976,5.1831026,,,,,,,,,,,,,,,,, -1500,0.8098886,5.076829,,,,,,,,,,,,,,,,, -1600,0.8217279,4.958217,,,,,,,,,,,,,,,,, -1700,0.9257667,4.824285,,,,,,,,,,,,,,,,, -1800,0.88512135,4.6586,,,,,,,,,,,,,,,,, -1900,1.4460672,4.5378466,,,,,,,,,,,,,,,,, -2000,0.8387927,4.419963,,,,,,,,,,,,,,,,, -2100,1.0295527,4.2755837,,,,,,,,,,,,,,,,, -2200,1.0105151,4.21023,,,,,,,,,,,,,,,,, -2300,0.9296192,4.0935464,,,,,,,,,,,,,,,,, -2400,0.75842565,3.8946235,,,,,,,,,,,,,,,,, -2403,,,0.413316398859024,3.916454315185547,14.63394950552278,0.4011481404304504,4.038556098937988,9.957431113516414,3000.0,0.3866132199764251,4.231210708618164,8.388442673207205,3003.0,865.8087675571442,2316.1432354450226,865.8087675571442,1450.2326774597168,0.0231471061706542,0.0 -2500,1.1828837,3.7850266,,,,,,,,,,,,,,,,, -2600,0.9653972,3.7851515,,,,,,,,,,,,,,,,, -2700,0.91851866,3.6630082,,,,,,,,,,,,,,,,, -2800,0.8346634,3.573356,,,,,,,,,,,,,,,,, -2900,0.9627428,3.5654683,,,,,,,,,,,,,,,,, -3000,1.2038976,3.4554634,,,,,,,,,,,,,,,,, -3100,1.0925719,3.2538428,,,,,,,,,,,,,,,,, -3200,0.77891666,3.3519263,,,,,,,,,,,,,,,,, -3300,0.8416723,3.2602115,,,,,,,,,,,,,,,,, -3400,0.8267934,3.1497436,,,,,,,,,,,,,,,,, -3500,0.79682696,3.0963447,,,,,,,,,,,,,,,,, -3600,0.70992094,3.2270143,,,,,,,,,,,,,,,,, -3700,0.76004416,3.018071,,,,,,,,,,,,,,,,, -3800,0.90953743,3.006879,,,,,,,,,,,,,,,,, -3900,0.7197831,2.9263241,,,,,,,,,,,,,,,,, -4000,0.97711575,3.0014098,,,,,,,,,,,,,,,,, -4100,0.7758073,2.7327747,,,,,,,,,,,,,,,,, -4200,1.036137,2.856352,,,,,,,,,,,,,,,,, -4300,0.7312067,2.9084134,,,,,,,,,,,,,,,,, -4400,0.6235975,2.7776556,,,,,,,,,,,,,,,,, -4500,0.6100672,2.7083077,,,,,,,,,,,,,,,,, -4600,0.6857011,2.7222297,,,,,,,,,,,,,,,,, -4700,0.6011423,2.6399064,,,,,,,,,,,,,,,,, -4800,0.62791485,2.7368932,,,,,,,,,,,,,,,,, -4804,,,0.5378835201263428,2.7001044750213623,24.378417805692475,0.5419771671295166,2.6425657272338867,20.41463598087644,3000.0,0.543048083782196,2.6712963581085205,18.875308714605914,3003.0,1705.7114703655243,3633.737921953201,1705.7114703655243,1927.821095943451,0.0491232872009277,0.0 -4900,0.7012809,2.6127758,,,,,,,,,,,,,,,,, -5000,0.60776997,2.5820625,,,,,,,,,,,,,,,,, -5100,0.64816934,2.6604483,,,,,,,,,,,,,,,,, -5200,0.6622174,2.5948513,,,,,,,,,,,,,,,,, -5300,0.5674035,2.5941114,,,,,,,,,,,,,,,,, -5400,0.5617918,2.56837,,,,,,,,,,,,,,,,, -5500,0.65343,2.4806337,,,,,,,,,,,,,,,,, -5600,0.6340234,2.5058022,,,,,,,,,,,,,,,,, -5700,0.7897588,2.5731363,,,,,,,,,,,,,,,,, -5800,0.7426128,2.5169554,,,,,,,,,,,,,,,,, -5900,0.6952588,2.5308506,,,,,,,,,,,,,,,,, -6000,0.5335246,2.480517,,,,,,,,,,,,,,,,, -6100,0.69597805,2.4298394,,,,,,,,,,,,,,,,, -6200,0.6550032,2.4585273,,,,,,,,,,,,,,,,, -6300,0.6755178,2.3830435,,,,,,,,,,,,,,,,, -6400,0.5734746,2.4704227,,,,,,,,,,,,,,,,, -6500,0.49167645,2.4335833,,,,,,,,,,,,,,,,, -6600,0.5382071,2.427095,,,,,,,,,,,,,,,,, -6700,0.5506985,2.3694,,,,,,,,,,,,,,,,, -6800,0.5553699,2.3649554,,,,,,,,,,,,,,,,, -6900,0.62438107,2.385627,,,,,,,,,,,,,,,,, -7000,0.47748977,2.3449442,,,,,,,,,,,,,,,,, -7100,0.61434424,2.32917,,,,,,,,,,,,,,,,, -7200,0.48428503,2.4006286,,,,,,,,,,,,,,,,, -7207,,,0.580255389213562,2.2710001468658447,27.101228921682285,0.5857831835746765,2.226724624633789,23.166204664087257,3000.0,0.5880890488624573,2.223965167999268,22.14029524164679,3003.0,2545.865830183029,4950.748953819275,2545.865830183029,2404.573896408081,0.0743830204010009,0.0 -7300,0.4711456,2.2699316,,,,,,,,,,,,,,,,, -7400,0.52122945,2.3093605,,,,,,,,,,,,,,,,, -7500,0.45637473,2.2322285,,,,,,,,,,,,,,,,, -7600,0.50672036,2.3793283,,,,,,,,,,,,,,,,, -7700,0.4764448,2.1862571,,,,,,,,,,,,,,,,, -7800,0.5018725,2.1767404,,,,,,,,,,,,,,,,, -7900,0.4731151,2.2061443,,,,,,,,,,,,,,,,, -8000,0.4492985,2.3729563,,,,,,,,,,,,,,,,, -8100,0.49520278,2.1872578,,,,,,,,,,,,,,,,, -8200,0.4463452,2.2611308,,,,,,,,,,,,,,,,, -8300,0.45652878,2.232826,,,,,,,,,,,,,,,,, -8400,0.4434976,2.2042356,,,,,,,,,,,,,,,,, -8500,0.4222139,2.2914107,,,,,,,,,,,,,,,,, -8600,0.38615593,2.2125158,,,,,,,,,,,,,,,,, -8700,0.42441085,2.2129526,,,,,,,,,,,,,,,,, -8800,0.41059148,2.2832365,,,,,,,,,,,,,,,,, -8900,0.36835888,2.1612475,,,,,,,,,,,,,,,,, -9000,0.36057857,2.159263,,,,,,,,,,,,,,,,, -9100,0.3769983,2.2267625,,,,,,,,,,,,,,,,, -9200,0.42217946,2.2962272,,,,,,,,,,,,,,,,, -9300,0.35362437,2.1047294,,,,,,,,,,,,,,,,, -9400,0.41034177,2.243636,,,,,,,,,,,,,,,,, -9500,0.3577381,2.1864405,,,,,,,,,,,,,,,,, -9600,0.3744964,2.1440997,,,,,,,,,,,,,,,,, -9611,,,0.592166006565094,2.1426703929901123,28.45783172233359,0.606316089630127,2.0390303134918213,24.673418436166767,3000.0,0.6116321086883545,2.003997325897217,23.503151632503048,3003.0,3385.8169617652893,6229.004224777222,3385.8169617652893,2842.7774851322174,0.1002109050750732,0.0 -9700,0.38674787,2.1063695,,,,,,,,,,,,,,,,, -9800,0.32187632,2.1780746,,,,,,,,,,,,,,,,, -9900,0.34319246,2.100873,,,,,,,,,,,,,,,,, -10000,0.3395971,2.0458608,,,,,,,,,,,,,,,,, -10100,0.2970691,2.0983589,,,,,,,,,,,,,,,,, -10200,0.35189086,2.1584067,,,,,,,,,,,,,,,,, -10300,0.3969047,2.0232956,,,,,,,,,,,,,,,,, -10400,0.33126807,2.0314026,,,,,,,,,,,,,,,,, -10500,0.3064937,2.142928,,,,,,,,,,,,,,,,, -10600,0.32559404,2.0767415,,,,,,,,,,,,,,,,, -10700,0.31402373,2.1355135,,,,,,,,,,,,,,,,, -10800,0.3058915,2.1180089,,,,,,,,,,,,,,,,, -10900,0.30223498,2.0735786,,,,,,,,,,,,,,,,, -11000,0.3185493,2.1620598,,,,,,,,,,,,,,,,, -11100,0.30122188,2.067061,,,,,,,,,,,,,,,,, -11200,0.32274607,2.0991933,,,,,,,,,,,,,,,,, -11300,0.3170115,2.1486914,,,,,,,,,,,,,,,,, -11400,0.3010861,2.007144,,,,,,,,,,,,,,,,, -11500,0.29479596,2.1603131,,,,,,,,,,,,,,,,, -11600,0.3304868,2.0148501,,,,,,,,,,,,,,,,, -11700,0.2914129,2.0644972,,,,,,,,,,,,,,,,, -11800,0.29909322,2.048775,,,,,,,,,,,,,,,,, -11900,0.28286052,2.074661,,,,,,,,,,,,,,,,, -12000,0.3185562,2.1083853,,,,,,,,,,,,,,,,, -12015,,,0.6005766987800598,2.0656661987304688,28.8104042804232,0.6186903715133667,1.916280031204224,25.631044209587834,3000.0,0.6232293248176575,1.8754557371139529,24.42601174196864,3003.0,4225.967286348343,7681.333123445511,4225.967286348343,3454.847516775131,0.1273379325866699,0.0 -12100,0.28101876,2.0261433,,,,,,,,,,,,,,,,, -12200,0.30301583,2.033987,,,,,,,,,,,,,,,,, -12300,0.2704767,2.0841904,,,,,,,,,,,,,,,,, -12400,0.25128418,1.9735012,,,,,,,,,,,,,,,,, -12500,0.33014014,2.0755975,,,,,,,,,,,,,,,,, -12600,0.3455937,2.009761,,,,,,,,,,,,,,,,, -12700,0.33199266,2.0101738,,,,,,,,,,,,,,,,, -12800,0.28427854,2.1435316,,,,,,,,,,,,,,,,, -12900,0.26898327,1.9314381,,,,,,,,,,,,,,,,, -13000,0.29856727,2.0331528,,,,,,,,,,,,,,,,, -13100,0.34032497,2.0947247,,,,,,,,,,,,,,,,, -13200,0.2920345,2.0711167,,,,,,,,,,,,,,,,, -13300,0.27446258,2.0580554,,,,,,,,,,,,,,,,, -13400,0.32284144,2.0327425,,,,,,,,,,,,,,,,, -13500,0.27302742,2.0847173,,,,,,,,,,,,,,,,, -13600,0.26919872,1.9754106,,,,,,,,,,,,,,,,, -13700,0.30579767,2.012624,,,,,,,,,,,,,,,,, -13800,0.25018173,1.9599034,,,,,,,,,,,,,,,,, -13900,0.2726137,1.9638746,,,,,,,,,,,,,,,,, -14000,0.29458407,2.0022929,,,,,,,,,,,,,,,,, -14100,0.31185958,2.018892,,,,,,,,,,,,,,,,, -14200,0.2693221,1.9155159,,,,,,,,,,,,,,,,, -14300,0.28756556,1.9401166,,,,,,,,,,,,,,,,, -14400,0.31566638,1.9561839,,,,,,,,,,,,,,,,, -14421,,,0.6168133616447449,1.9349526166915887,29.425459383473545,0.630469560623169,1.8252501487731927,26.62938777326784,3000.0,0.6382430195808411,1.775894045829773,25.50415595748084,3003.0,5066.110929250717,8959.01217341423,5066.110929250717,3892.275357961655,0.1567516326904297,0.0 -14500,0.29153976,1.9467006,,,,,,,,,,,,,,,,, -14600,0.2638954,2.0305665,,,,,,,,,,,,,,,,, -14700,0.28107762,1.937986,,,,,,,,,,,,,,,,, -14800,0.3792091,2.0086896,,,,,,,,,,,,,,,,, -14900,0.31422785,1.970697,,,,,,,,,,,,,,,,, -15000,0.29401964,1.9797343,,,,,,,,,,,,,,,,, -15100,0.31947842,1.9434891,,,,,,,,,,,,,,,,, -15200,0.3379171,1.9766681,,,,,,,,,,,,,,,,, -15300,0.30764967,1.9781438,,,,,,,,,,,,,,,,, -15400,0.3373255,1.9706547,,,,,,,,,,,,,,,,, -15500,0.28504652,1.9488202,,,,,,,,,,,,,,,,, -15600,0.27217123,2.0258074,,,,,,,,,,,,,,,,, -15700,0.286722,2.0083568,,,,,,,,,,,,,,,,, -15800,0.3064458,1.8950356,,,,,,,,,,,,,,,,, -15900,0.29276034,1.9923228,,,,,,,,,,,,,,,,, -16000,0.44158784,1.8409418,,,,,,,,,,,,,,,,, -16100,0.3036794,1.9753591,,,,,,,,,,,,,,,,, -16200,0.27621055,1.868546,,,,,,,,,,,,,,,,, -16300,0.31133652,1.9533942,,,,,,,,,,,,,,,,, -16400,0.28820184,1.872526,,,,,,,,,,,,,,,,, -16500,0.33226755,1.8736644,,,,,,,,,,,,,,,,, -16600,0.32054216,1.8965061,,,,,,,,,,,,,,,,, -16700,0.35175428,1.8627661,,,,,,,,,,,,,,,,, -16800,0.31447482,1.9258661,,,,,,,,,,,,,,,,, -16827,,,0.6208845973014832,1.901543140411377,29.833422011187903,0.6382933855056763,1.7594693899154663,26.58655602133679,3000.0,0.6456103920936584,1.7116901874542236,26.09951750501349,3003.0,5906.1692090034485,10265.277593374252,5906.1692090034485,4358.376479148865,0.1858303546905517,0.0 -16900,0.30690083,1.9973823,,,,,,,,,,,,,,,,, -17000,0.29605404,2.025813,,,,,,,,,,,,,,,,, -17100,0.36491522,1.9666078,,,,,,,,,,,,,,,,, -17200,0.30533203,1.9000626,,,,,,,,,,,,,,,,, -17300,0.33760622,1.8502197,,,,,,,,,,,,,,,,, -17400,0.37122205,1.9476541,,,,,,,,,,,,,,,,, -17500,0.30493787,1.8606447,,,,,,,,,,,,,,,,, -17600,0.3671687,1.9161143,,,,,,,,,,,,,,,,, -17700,0.38578403,1.8982528,,,,,,,,,,,,,,,,, -17800,0.30357122,1.95856,,,,,,,,,,,,,,,,, -17900,0.3039239,1.834689,,,,,,,,,,,,,,,,, -18000,2.9788396,2.0383184,,,,,,,,,,,,,,,,, -18100,0.5362924,1.9721509,,,,,,,,,,,,,,,,, -18200,0.37028185,1.9415822,,,,,,,,,,,,,,,,, -18300,0.33792698,1.9409802,,,,,,,,,,,,,,,,, -18400,0.30252934,1.8533918,,,,,,,,,,,,,,,,, -18500,0.29934216,1.9112282,,,,,,,,,,,,,,,,, -18600,0.2912912,1.857013,,,,,,,,,,,,,,,,, -18700,0.37390813,1.8813515,,,,,,,,,,,,,,,,, -18800,0.32371536,1.7967116,,,,,,,,,,,,,,,,, -18900,0.3043273,1.8818431,,,,,,,,,,,,,,,,, -19000,0.33829907,1.7786016,,,,,,,,,,,,,,,,, -19100,0.37218285,1.8896068,,,,,,,,,,,,,,,,, -19200,0.3109969,1.8831685,,,,,,,,,,,,,,,,, -19234,,,0.6365357041358948,1.7768452167510986,31.482361801789065,0.6452617049217224,1.712559938430786,27.29098019252068,3000.0,0.6555923819541931,1.65578031539917,26.70090522635841,3003.0,6746.2333035469055,11695.425820350649,6746.2333035469055,4948.353693962097,0.2140419483184814,0.0 -19300,0.35650665,1.775511,,,,,,,,,,,,,,,,, -19400,0.33004484,1.8750563,,,,,,,,,,,,,,,,, -19500,0.40404645,1.8740579,,,,,,,,,,,,,,,,, -19600,0.3815835,1.8973186,,,,,,,,,,,,,,,,, -19700,0.29945394,1.8823389,,,,,,,,,,,,,,,,, -19800,0.30673698,1.9321772,,,,,,,,,,,,,,,,, -19900,0.3897186,1.8517203,,,,,,,,,,,,,,,,, -20000,0.35035676,1.8659406,,,,,,,,,,,,,,,,, -20100,0.38608027,1.8392985,,,,,,,,,,,,,,,,, -20200,0.32464695,1.8620508,,,,,,,,,,,,,,,,, -20300,0.3369587,1.8412588,,,,,,,,,,,,,,,,, -20400,0.37811947,1.9733369,,,,,,,,,,,,,,,,, -20500,0.3334584,1.9135917,,,,,,,,,,,,,,,,, -20600,0.32635248,1.9074267,,,,,,,,,,,,,,,,, -20700,0.35632175,1.8706045,,,,,,,,,,,,,,,,, -20800,0.49347717,1.8450594,,,,,,,,,,,,,,,,, -20900,0.33351538,1.810716,,,,,,,,,,,,,,,,, -21000,0.3497889,1.9332534,,,,,,,,,,,,,,,,, -21100,0.4530204,1.8711554,,,,,,,,,,,,,,,,, -21200,0.37630928,1.8849922,,,,,,,,,,,,,,,,, -21300,0.38358766,1.8258224,,,,,,,,,,,,,,,,, -21400,0.343954,1.9646896,,,,,,,,,,,,,,,,, -21500,0.33751842,1.8084923,,,,,,,,,,,,,,,,, -21600,0.3589019,1.8290564,,,,,,,,,,,,,,,,, -21641,,,0.6268560886383057,1.837093949317932,30.100959194200843,0.6466503739356995,1.69412362575531,27.4922055487141,3000.0,0.6589041948318481,1.6343481540679932,26.76084053464521,3003.0,7586.293897151947,13127.183210134506,7586.293897151947,5539.945268630981,0.2424554824829101,0.0 -21700,0.47602093,1.8226043,,,,,,,,,,,,,,,,, -21800,0.37331852,1.9144955,,,,,,,,,,,,,,,,, -21900,0.3501207,1.8340235,,,,,,,,,,,,,,,,, -22000,0.5078604,1.849787,,,,,,,,,,,,,,,,, -22100,0.48605302,1.8944321,,,,,,,,,,,,,,,,, -22200,0.35269216,1.8506491,,,,,,,,,,,,,,,,, -22300,0.36656082,1.8289539,,,,,,,,,,,,,,,,, -22400,0.37116992,1.9374105,,,,,,,,,,,,,,,,, -22500,0.35883173,1.8810202,,,,,,,,,,,,,,,,, -22600,0.37377328,1.8449992,,,,,,,,,,,,,,,,, -22700,0.40147316,1.9174094,,,,,,,,,,,,,,,,, -22800,0.36623585,1.7186309,,,,,,,,,,,,,,,,, -22900,0.37932068,1.8983781,,,,,,,,,,,,,,,,, -23000,0.38124767,1.8757979,,,,,,,,,,,,,,,,, -23100,0.36822146,1.8575021,,,,,,,,,,,,,,,,, -23200,0.40049782,1.8569038,,,,,,,,,,,,,,,,, -23300,0.34923545,1.7909491,,,,,,,,,,,,,,,,, -23400,0.4062172,1.8588067,,,,,,,,,,,,,,,,, -23500,0.411525,1.8385633,,,,,,,,,,,,,,,,, -23600,0.5045752,1.9093283,,,,,,,,,,,,,,,,, -23700,0.4364473,1.8382365,,,,,,,,,,,,,,,,, -23800,0.39444208,1.9077592,,,,,,,,,,,,,,,,, -23900,0.3687346,1.8338307,,,,,,,,,,,,,,,,, -24000,0.4413933,1.8259134,,,,,,,,,,,,,,,,, -24049,,,0.6285372376441956,1.8340964317321773,30.61643964662627,0.6487582325935364,1.6795746088027954,27.35033583597104,3000.0,0.657300591468811,1.6291598081588743,26.49205587264576,3003.0,8426.46519112587,14472.267474412918,8426.46519112587,6044.751909255981,0.2706022262573242,0.0 -24100,0.39843705,1.9184698,,,,,,,,,,,,,,,,, -24200,0.39462656,1.8463316,,,,,,,,,,,,,,,,, -24300,0.38078249,1.9653939,,,,,,,,,,,,,,,,, -24400,0.4316173,1.8943826,,,,,,,,,,,,,,,,, -24500,0.40630358,1.8381015,,,,,,,,,,,,,,,,, -24600,0.40006852,1.816814,,,,,,,,,,,,,,,,, -24700,0.4027758,1.7675481,,,,,,,,,,,,,,,,, -24800,0.42499587,1.8657136,,,,,,,,,,,,,,,,, -24900,0.39765325,1.8452255,,,,,,,,,,,,,,,,, -25000,0.48908514,1.8714839,,,,,,,,,,,,,,,,, -25100,0.4295737,1.8588016,,,,,,,,,,,,,,,,, -25200,0.46818736,1.8306619,,,,,,,,,,,,,,,,, -25300,0.49887735,1.9026623,,,,,,,,,,,,,,,,, -25400,0.3818433,1.8305919,,,,,,,,,,,,,,,,, -25500,0.38474393,1.7333717,,,,,,,,,,,,,,,,, -25600,0.3757308,1.8568472,,,,,,,,,,,,,,,,, -25700,0.35838667,1.8084259,,,,,,,,,,,,,,,,, -25800,0.44696385,1.7345161,,,,,,,,,,,,,,,,, -25900,0.42638788,1.8745223,,,,,,,,,,,,,,,,, -26000,0.48120418,1.6907576,,,,,,,,,,,,,,,,, -26100,0.36689323,1.8873512,,,,,,,,,,,,,,,,, -26200,0.38518885,1.8823318,,,,,,,,,,,,,,,,, -26300,0.39889076,1.8201277,,,,,,,,,,,,,,,,, -26400,0.35600054,1.9431458,,,,,,,,,,,,,,,,, -26457,,,0.6383501887321472,1.7676606178283691,31.067359641378346,0.6521803736686707,1.6639560461044312,27.703317488145554,3000.0,0.6611701846122742,1.60666024684906,27.17247322045552,3003.0,9266.6213285923,15862.000746011734,9266.6213285923,6594.221925020218,0.2991580963134765,0.0 -26500,0.41413504,1.8805035,,,,,,,,,,,,,,,,, -26600,0.36760116,1.8496878,,,,,,,,,,,,,,,,, -26700,0.4900926,1.8263333,,,,,,,,,,,,,,,,, -26800,0.40216067,1.8903328,,,,,,,,,,,,,,,,, -26900,0.4129022,1.8279676,,,,,,,,,,,,,,,,, -27000,0.38057044,1.7834455,,,,,,,,,,,,,,,,, -27100,0.39557752,1.8536677,,,,,,,,,,,,,,,,, -27200,0.3627808,1.792689,,,,,,,,,,,,,,,,, -27300,0.3825536,1.814895,,,,,,,,,,,,,,,,, -27400,0.37376747,1.7514449,,,,,,,,,,,,,,,,, -27500,0.41178948,1.7706591,,,,,,,,,,,,,,,,, -27600,0.7923808,1.8693359,,,,,,,,,,,,,,,,, -27700,0.46283662,1.9018244,,,,,,,,,,,,,,,,, -27800,0.4072631,1.8573602,,,,,,,,,,,,,,,,, -27900,0.44468793,1.8044447,,,,,,,,,,,,,,,,, -28000,0.3973455,1.8129128,,,,,,,,,,,,,,,,, -28100,0.45392975,1.7688222,,,,,,,,,,,,,,,,, -28200,0.38846084,1.8345301,,,,,,,,,,,,,,,,, -28300,0.41543514,1.7790856,,,,,,,,,,,,,,,,, -28400,0.41446313,1.8095365,,,,,,,,,,,,,,,,, -28500,0.43921763,1.8556062,,,,,,,,,,,,,,,,, -28600,0.39476117,1.8485055,,,,,,,,,,,,,,,,, -28700,0.37767646,1.8096567,,,,,,,,,,,,,,,,, -28800,0.37490347,1.8652512,,,,,,,,,,,,,,,,, -28863,,,0.6295340657234192,1.8214620351791384,30.57459531827869,0.6530855298042297,1.650708794593811,27.829608372370423,3000.0,0.6616930961608887,1.588875412940979,27.073841281803126,3003.0,10106.590117931366,17257.571030139923,10106.590117931366,7149.713710069656,0.3287866115570068,0.0 -28900,0.4354887,1.8953948,,,,,,,,,,,,,,,,, -29000,0.41814902,1.7944139,,,,,,,,,,,,,,,,, -29100,0.43074492,1.8502779,,,,,,,,,,,,,,,,, -29200,0.6651599,1.8559847,,,,,,,,,,,,,,,,, -29300,0.40734696,1.8982071,,,,,,,,,,,,,,,,, -29400,0.36458483,1.826256,,,,,,,,,,,,,,,,, -29500,0.36967945,1.8467114,,,,,,,,,,,,,,,,, -29600,0.45404676,1.7888813,,,,,,,,,,,,,,,,, -29700,0.40077883,1.8279141,,,,,,,,,,,,,,,,, -29800,0.4113802,1.7927915,,,,,,,,,,,,,,,,, -29900,0.37613255,1.8285052,,,,,,,,,,,,,,,,, -30000,0.42421353,1.8333848,,,,,,,,,,,,,,,,, -30100,0.39868513,1.8691437,,,,,,,,,,,,,,,,, -30200,0.407058,1.8235698,,,,,,,,,,,,,,,,, -30300,0.40144458,1.8384845,,,,,,,,,,,,,,,,, -30400,0.3414149,1.7475722,,,,,,,,,,,,,,,,, -30500,0.44084743,1.8342987,,,,,,,,,,,,,,,,, -30600,0.39556706,1.7280313,,,,,,,,,,,,,,,,, -30700,0.35945106,1.8032175,,,,,,,,,,,,,,,,, -30800,0.38714686,1.7490222,,,,,,,,,,,,,,,,, -30900,0.40265238,1.7917564,,,,,,,,,,,,,,,,, -31000,0.6266661,1.8109075,,,,,,,,,,,,,,,,, -31100,0.43310624,1.8737925,,,,,,,,,,,,,,,,, -31200,0.40640584,1.7662675,,,,,,,,,,,,,,,,, -31270,,,0.6339073777198792,1.8032761812210083,31.101627683228955,0.6544494032859802,1.6417778730392456,27.79121788889197,3000.0,0.6649003624916077,1.5839022397994995,27.376689492425207,3003.0,10946.506821632383,18626.177748918533,10946.506821632383,7678.296708583832,0.3578181266784668,0.0 -31300,0.5257296,1.8366225,,,,,,,,,,,,,,,,, -31400,0.40389058,1.8139735,,,,,,,,,,,,,,,,, -31500,0.49080938,1.8331722,,,,,,,,,,,,,,,,, -31600,0.8148243,1.8065109,,,,,,,,,,,,,,,,, -31700,0.47817317,1.8533349,,,,,,,,,,,,,,,,, -31800,0.36206627,1.728167,,,,,,,,,,,,,,,,, -31900,0.49294418,1.8390095,,,,,,,,,,,,,,,,, -32000,0.3466213,1.8036317,,,,,,,,,,,,,,,,, -32100,0.38315162,1.8314763,,,,,,,,,,,,,,,,, -32200,0.5579533,1.8214624,,,,,,,,,,,,,,,,, -32300,0.43069276,1.8460625,,,,,,,,,,,,,,,,, -32400,0.41862482,1.8664871,,,,,,,,,,,,,,,,, -32500,0.5064471,1.789961,,,,,,,,,,,,,,,,, -32600,0.36637568,1.8029361,,,,,,,,,,,,,,,,, -32700,0.39749712,1.7949551,,,,,,,,,,,,,,,,, -32800,0.46905383,1.8521523,,,,,,,,,,,,,,,,, -32900,0.40810597,1.7934918,,,,,,,,,,,,,,,,, -33000,0.441092,1.7768908,,,,,,,,,,,,,,,,, -33100,0.4462568,1.8149987,,,,,,,,,,,,,,,,, -33200,0.44330484,1.8435574,,,,,,,,,,,,,,,,, -33300,0.4523312,1.7715373,,,,,,,,,,,,,,,,, -33400,0.41579527,1.841986,,,,,,,,,,,,,,,,, -33500,0.41816843,1.7697034,,,,,,,,,,,,,,,,, -33600,0.4368937,1.8544669,,,,,,,,,,,,,,,,, -33679,,,0.640436053276062,1.7520774602890017,31.079749336354755,0.6569168567657471,1.6260370016098022,27.985500631619285,3000.0,0.6679217219352722,1.5678555965423584,27.7718476746868,3003.0,11786.677701950071,19973.24415802956,11786.677701950071,8185.084298849106,0.3881950378417969,0.0 -33700,0.43318537,1.8710295,,,,,,,,,,,,,,,,, -33800,0.4277744,1.7570332,,,,,,,,,,,,,,,,, -33900,0.4399366,1.7101915,,,,,,,,,,,,,,,,, -34000,0.49399465,1.7800604,,,,,,,,,,,,,,,,, -34100,0.38867408,1.783105,,,,,,,,,,,,,,,,, -34200,0.36523175,1.8095001,,,,,,,,,,,,,,,,, -34300,0.38432908,1.7604625,,,,,,,,,,,,,,,,, -34400,0.43068495,1.7629986,,,,,,,,,,,,,,,,, -34500,0.4129934,1.7142671,,,,,,,,,,,,,,,,, -34600,0.45872664,1.7851048,,,,,,,,,,,,,,,,, -34700,0.45321095,1.8507404,,,,,,,,,,,,,,,,, -34800,0.4267019,1.7532784,,,,,,,,,,,,,,,,, -34900,0.39487544,1.8030181,,,,,,,,,,,,,,,,, -35000,0.4955556,1.747761,,,,,,,,,,,,,,,,, -35100,0.41453642,1.7903068,,,,,,,,,,,,,,,,, -35200,0.44067836,1.72639,,,,,,,,,,,,,,,,, -35300,0.41722226,1.7946166,,,,,,,,,,,,,,,,, -35400,0.49676296,1.7723048,,,,,,,,,,,,,,,,, -35500,0.40744486,1.803582,,,,,,,,,,,,,,,,, -35600,0.4139087,1.7616501,,,,,,,,,,,,,,,,, -35700,0.41630387,1.8247916,,,,,,,,,,,,,,,,, -35800,0.41396883,1.798922,,,,,,,,,,,,,,,,, -35900,0.3787964,1.8229378,,,,,,,,,,,,,,,,, -36000,0.39931872,1.7395967,,,,,,,,,,,,,,,,, -36087,,,0.6369031071662903,1.7718901634216309,30.80821034725975,0.6572268009185791,1.6264159679412842,28.23449537713964,3000.0,0.6699668765068054,1.5581209659576416,27.99159121351021,3003.0,12626.625022888184,21296.61293053627,12626.625022888184,8668.399793624878,0.4177160263061523,0.0 -36100,0.43195158,1.7040739,,,,,,,,,,,,,,,,, -36200,0.37257484,1.809257,,,,,,,,,,,,,,,,, -36300,0.3758641,1.7089205,,,,,,,,,,,,,,,,, -36400,0.3836461,1.840695,,,,,,,,,,,,,,,,, -36500,0.42472285,1.7935594,,,,,,,,,,,,,,,,, -36600,0.4179904,1.7371637,,,,,,,,,,,,,,,,, -36700,0.40459108,1.6851387,,,,,,,,,,,,,,,,, -36800,0.38443044,1.7390559,,,,,,,,,,,,,,,,, -36900,0.4288173,1.84529,,,,,,,,,,,,,,,,, -37000,0.41352886,1.8032954,,,,,,,,,,,,,,,,, -37100,0.3673671,1.7225509,,,,,,,,,,,,,,,,, -37200,0.41366848,1.7933958,,,,,,,,,,,,,,,,, -37300,0.42519486,1.832704,,,,,,,,,,,,,,,,, -37400,0.49144456,1.9039506,,,,,,,,,,,,,,,,, -37500,0.38228655,1.8158337,,,,,,,,,,,,,,,,, -37600,0.41992068,1.8663614,,,,,,,,,,,,,,,,, -37700,0.44997668,1.8648772,,,,,,,,,,,,,,,,, -37800,0.41570553,1.7007599,,,,,,,,,,,,,,,,, -37900,0.46365008,1.8013965,,,,,,,,,,,,,,,,, -38000,0.36135665,1.794828,,,,,,,,,,,,,,,,, -38100,0.39649576,1.8092728,,,,,,,,,,,,,,,,, -38200,0.42177644,1.7582196,,,,,,,,,,,,,,,,, -38300,0.41498277,1.7649109,,,,,,,,,,,,,,,,, -38400,0.4333645,1.7998385,,,,,,,,,,,,,,,,, -38495,,,0.6430981159210205,1.7214125394821167,31.545833523294696,0.6587147116661072,1.6191388368606567,28.382361296399072,3000.0,0.6695834398269653,1.555444359779358,27.83423761879015,3003.0,13466.692353248596,22630.738989830017,13466.692353248596,9162.348851919174,0.4478676319122314,0.0 -38500,0.41172525,1.8026764,,,,,,,,,,,,,,,,, -38600,0.39682642,1.7588168,,,,,,,,,,,,,,,,, -38700,0.45504484,1.7757332,,,,,,,,,,,,,,,,, -38800,0.41891003,1.7597667,,,,,,,,,,,,,,,,, -38900,0.39867616,1.7635188,,,,,,,,,,,,,,,,, -39000,0.3740487,1.8001848,,,,,,,,,,,,,,,,, -39100,0.4000051,1.7898135,,,,,,,,,,,,,,,,, -39200,0.38880208,1.6732616,,,,,,,,,,,,,,,,, -39300,0.39972758,1.8156614,,,,,,,,,,,,,,,,, -39400,0.42689827,1.7161036,,,,,,,,,,,,,,,,, -39500,0.46413705,1.8006881,,,,,,,,,,,,,,,,, -39600,0.38557833,1.6975703,,,,,,,,,,,,,,,,, -39700,0.4022504,1.7693664,,,,,,,,,,,,,,,,, -39800,0.38231367,1.75192,,,,,,,,,,,,,,,,, -39900,0.35911837,1.7499373,,,,,,,,,,,,,,,,, -40000,0.36239272,1.7753762,,,,,,,,,,,,,,,,, -40100,0.3747684,1.7405736,,,,,,,,,,,,,,,,, -40200,0.3802153,1.8138928,,,,,,,,,,,,,,,,, -40300,0.35151193,1.7339629,,,,,,,,,,,,,,,,, -40400,0.39359435,1.8039595,,,,,,,,,,,,,,,,, -40500,0.37983426,1.7970906,,,,,,,,,,,,,,,,, -40600,0.399809,1.6590871,,,,,,,,,,,,,,,,, -40700,0.38872957,1.7681669,,,,,,,,,,,,,,,,, -40800,0.35889426,1.7828826,,,,,,,,,,,,,,,,, -40900,0.42545816,1.8587282,,,,,,,,,,,,,,,,, -40904,,,0.6393962502479553,1.7517564296722412,31.326564034977284,0.6592354774475098,1.6140036582946775,28.034270082296626,3000.0,0.668677031993866,1.5514838695526123,27.83588488357372,3003.0,14306.815686225891,23983.063989400864,14306.815686225891,9674.442895889282,0.4786746501922607,0.0 -41000,0.407726,1.8104964,,,,,,,,,,,,,,,,, -41100,0.3879425,1.8091794,,,,,,,,,,,,,,,,, -41200,0.45824444,1.715696,,,,,,,,,,,,,,,,, -41300,0.36801618,1.7891465,,,,,,,,,,,,,,,,, -41400,0.38792387,1.7764425,,,,,,,,,,,,,,,,, -41500,0.38770652,1.7308669,,,,,,,,,,,,,,,,, -41600,0.38912603,1.8052411,,,,,,,,,,,,,,,,, -41700,0.45472112,1.7781419,,,,,,,,,,,,,,,,, -41800,0.44527668,1.8320075,,,,,,,,,,,,,,,,, -41900,0.3421823,1.7180132,,,,,,,,,,,,,,,,, -42000,0.40476015,1.843054,,,,,,,,,,,,,,,,, -42100,0.39174914,1.7754042,,,,,,,,,,,,,,,,, -42200,0.47467354,1.768282,,,,,,,,,,,,,,,,, -42300,0.4029829,1.7230177,,,,,,,,,,,,,,,,, -42400,0.43953925,1.6969018,,,,,,,,,,,,,,,,, -42500,0.5264967,1.7608668,,,,,,,,,,,,,,,,, -42600,0.36750892,1.7103642,,,,,,,,,,,,,,,,, -42700,0.7359581,1.785565,,,,,,,,,,,,,,,,, -42800,0.47761023,1.6927513,,,,,,,,,,,,,,,,, -42900,0.37261942,1.6750561,,,,,,,,,,,,,,,,, -43000,0.42034376,1.7720472,,,,,,,,,,,,,,,,, -43100,0.43132427,1.7837903,,,,,,,,,,,,,,,,, -43200,0.41745594,1.7961214,,,,,,,,,,,,,,,,, -43300,0.41242853,1.7035198,,,,,,,,,,,,,,,,, -43312,,,0.6379010081291199,1.7598146200180054,31.306862035234403,0.6621120572090149,1.6029443740844729,28.58756001053182,3000.0,0.6726861000061035,1.533181071281433,28.08085070114241,3003.0,15146.73096871376,25332.93573999405,15146.73096871376,10184.291736602783,0.5092837810516357,0.0 -43400,0.42873847,1.819648,,,,,,,,,,,,,,,,, -43500,0.45860195,1.7439518,,,,,,,,,,,,,,,,, -43600,0.3954683,1.7358524,,,,,,,,,,,,,,,,, -43700,0.42830232,1.800385,,,,,,,,,,,,,,,,, -43800,0.3957093,1.801367,,,,,,,,,,,,,,,,, -43900,0.4072992,1.7568432,,,,,,,,,,,,,,,,, -44000,0.48098373,1.779177,,,,,,,,,,,,,,,,, -44100,0.38019177,1.7152363,,,,,,,,,,,,,,,,, -44200,0.3576386,1.774436,,,,,,,,,,,,,,,,, -44300,0.49092287,1.749026,,,,,,,,,,,,,,,,, -44400,0.39444587,1.7124288,,,,,,,,,,,,,,,,, -44500,0.37469938,1.8038298,,,,,,,,,,,,,,,,, -44600,0.39511698,1.673862,,,,,,,,,,,,,,,,, -44700,0.3845035,1.760305,,,,,,,,,,,,,,,,, -44800,0.47523984,1.7705517,,,,,,,,,,,,,,,,, -44900,0.36193344,1.8191113,,,,,,,,,,,,,,,,, -45000,0.37122777,1.7403395,,,,,,,,,,,,,,,,, -45100,0.36214796,1.7569573,,,,,,,,,,,,,,,,, -45200,0.36409384,1.7362927,,,,,,,,,,,,,,,,, -45300,0.43511024,1.8461479,,,,,,,,,,,,,,,,, -45400,0.3513093,1.776937,,,,,,,,,,,,,,,,, -45500,0.38529864,1.7752159,,,,,,,,,,,,,,,,, -45600,0.3839652,1.7637514,,,,,,,,,,,,,,,,, -45700,0.4138768,1.7204678,,,,,,,,,,,,,,,,, -45720,,,0.642835795879364,1.7258001565933228,31.51465397782276,0.6610457301139832,1.5972543954849243,28.238541337906103,3000.0,0.6723374724388123,1.5314693450927734,27.78312192161673,3003.0,15986.845165491104,26877.99145627021,15986.845165491104,10889.123789072037,0.5416944026947021,0.0 -45800,0.35681078,1.6998404,,,,,,,,,,,,,,,,, -45900,0.4456277,1.7542946,,,,,,,,,,,,,,,,, -46000,0.36964765,1.7378762,,,,,,,,,,,,,,,,, -46100,0.36405453,1.8011335,,,,,,,,,,,,,,,,, -46200,0.42807415,1.7932395,,,,,,,,,,,,,,,,, -46300,0.625062,1.7417772,,,,,,,,,,,,,,,,, -46400,0.43992493,1.6900151,,,,,,,,,,,,,,,,, -46500,0.41977972,1.7572336,,,,,,,,,,,,,,,,, -46600,0.37134853,1.674025,,,,,,,,,,,,,,,,, -46700,0.50932235,1.9083923,,,,,,,,,,,,,,,,, -46800,0.46375775,1.6755862,,,,,,,,,,,,,,,,, -46900,0.4009929,1.7199981,,,,,,,,,,,,,,,,, -47000,0.36324662,1.7511237,,,,,,,,,,,,,,,,, -47100,0.36983082,1.718868,,,,,,,,,,,,,,,,, -47200,0.33799934,1.6718557,,,,,,,,,,,,,,,,, -47300,0.3550884,1.6805917,,,,,,,,,,,,,,,,, -47400,0.3726245,1.7116232,,,,,,,,,,,,,,,,, -47500,0.35390645,1.7312441,,,,,,,,,,,,,,,,, -47600,0.3649548,1.7440895,,,,,,,,,,,,,,,,, -47700,0.39238065,1.7838428,,,,,,,,,,,,,,,,, -47800,0.37250978,1.6810665,,,,,,,,,,,,,,,,, -47900,0.37649497,1.7338183,,,,,,,,,,,,,,,,, -48000,0.3942076,1.6858786,,,,,,,,,,,,,,,,, -48100,0.4148966,1.7778388,,,,,,,,,,,,,,,,, -48128,,,0.6452369689941406,1.726482629776001,31.82425775905154,0.6602894067764282,1.590267300605774,28.259296946768995,3000.0,0.6742199659347534,1.5215578079223633,27.99158132715858,3003.0,16826.8217689991,28235.679394960403,16826.8217689991,11406.725434064863,0.5741453170776367,0.0 -48200,0.36133337,1.7458642,,,,,,,,,,,,,,,,, -48300,0.42826292,1.7845745,,,,,,,,,,,,,,,,, -48400,0.35520348,1.7343016,,,,,,,,,,,,,,,,, -48500,0.3563875,1.7033254,,,,,,,,,,,,,,,,, -48600,0.38591844,1.8011823,,,,,,,,,,,,,,,,, -48700,0.33675578,1.7246991,,,,,,,,,,,,,,,,, -48800,0.3480138,1.6945363,,,,,,,,,,,,,,,,, -48900,0.39466617,1.836304,,,,,,,,,,,,,,,,, -49000,0.40248552,1.7549666,,,,,,,,,,,,,,,,, -49100,0.34608394,1.7579844,,,,,,,,,,,,,,,,, -49200,0.4058329,1.7858053,,,,,,,,,,,,,,,,, -49300,0.41956246,1.7732402,,,,,,,,,,,,,,,,, -49400,0.37764475,1.7117977,,,,,,,,,,,,,,,,, -49500,0.394222,1.8042489,,,,,,,,,,,,,,,,, -49600,0.3778511,1.767372,,,,,,,,,,,,,,,,, -49700,0.411103,1.7171855,,,,,,,,,,,,,,,,, -49800,0.4082131,1.7341912,,,,,,,,,,,,,,,,, -49900,0.4138741,1.8185315,,,,,,,,,,,,,,,,, -50000,0.37222865,1.74781,,,,,,,,,,,,,,,,, -50100,0.42091104,1.7685179,,,,,,,,,,,,,,,,, -50200,0.44836164,1.8249098,,,,,,,,,,,,,,,,, -50300,0.397922,1.6637429,,,,,,,,,,,,,,,,, -50400,0.36505955,1.7174263,,,,,,,,,,,,,,,,, -50500,0.41553256,1.7559097,,,,,,,,,,,,,,,,, -50537,,,0.6565302610397339,1.640212893486023,32.767179635812575,0.6637363433837891,1.5860275030136108,28.6552268817002,3000.0,0.675358772277832,1.518868088722229,28.05993585302966,3003.0,17666.926292657852,29602.6824324131,17666.926292657852,11933.50832438469,0.6121892929077148,0.0 -50600,0.35050258,1.749441,,,,,,,,,,,,,,,,, -50700,0.4104343,1.8270733,,,,,,,,,,,,,,,,, -50800,0.39111328,1.6686921,,,,,,,,,,,,,,,,, -50900,0.40042984,1.7514058,,,,,,,,,,,,,,,,, -51000,0.376613,1.7291907,,,,,,,,,,,,,,,,, -51100,0.36849394,1.6913477,,,,,,,,,,,,,,,,, -51200,0.44459605,1.7965022,,,,,,,,,,,,,,,,, -51300,0.38459158,1.7670282,,,,,,,,,,,,,,,,, -51400,0.3897969,1.8050703,,,,,,,,,,,,,,,,, -51500,0.3604187,1.6338931,,,,,,,,,,,,,,,,, -51600,0.39860493,1.7714999,,,,,,,,,,,,,,,,, -51700,0.38087726,1.7199167,,,,,,,,,,,,,,,,, -51800,0.3683184,1.8031504,,,,,,,,,,,,,,,,, -51900,0.39511615,1.7779454,,,,,,,,,,,,,,,,, -52000,0.3839939,1.7624687,,,,,,,,,,,,,,,,, -52100,0.34887984,1.7672209,,,,,,,,,,,,,,,,, -52200,0.38866487,1.7065763,,,,,,,,,,,,,,,,, -52300,0.4340925,1.796979,,,,,,,,,,,,,,,,, -52400,0.370716,1.7541,,,,,,,,,,,,,,,,, -52500,0.3568783,1.7112535,,,,,,,,,,,,,,,,, -52600,0.34159046,1.7020844,,,,,,,,,,,,,,,,, -52700,0.35496905,1.7617582,,,,,,,,,,,,,,,,, -52800,0.3945247,1.7425499,,,,,,,,,,,,,,,,, -52900,0.36088946,1.7094153,,,,,,,,,,,,,,,,, -52946,,,0.6452774405479431,1.7110553979873655,31.52326751196575,0.6643067002296448,1.580517292022705,28.57172511434688,3000.0,0.6761606335639954,1.5175516605377195,28.09270029743737,3003.0,18507.14643883705,30996.068150520325,18507.14643883705,12486.558814764025,0.6497421264648438,0.0 -53000,0.39854062,1.7612114,,,,,,,,,,,,,,,,, -53100,0.35672238,1.7368536,,,,,,,,,,,,,,,,, -53200,0.582981,1.7242653,,,,,,,,,,,,,,,,, -53300,0.37640297,1.6864408,,,,,,,,,,,,,,,,, -53400,0.35668,1.6925173,,,,,,,,,,,,,,,,, -53500,0.36676705,1.7366097,,,,,,,,,,,,,,,,, -53600,0.41523018,1.7332268,,,,,,,,,,,,,,,,, -53700,0.36177474,1.6441088,,,,,,,,,,,,,,,,, -53800,0.4227148,1.7172093,,,,,,,,,,,,,,,,, -53900,0.3900916,1.7257892,,,,,,,,,,,,,,,,, -54000,0.3485283,1.736557,,,,,,,,,,,,,,,,, -54100,0.37913686,1.6644627,,,,,,,,,,,,,,,,, -54200,0.4034837,1.821375,,,,,,,,,,,,,,,,, -54300,0.44424838,1.65161,,,,,,,,,,,,,,,,, -54400,0.38999322,1.6725907,,,,,,,,,,,,,,,,, -54500,0.35953748,1.6866745,,,,,,,,,,,,,,,,, -54600,0.39341524,1.7592536,,,,,,,,,,,,,,,,, -54700,0.37066913,1.6801113,,,,,,,,,,,,,,,,, -54800,0.37519893,1.7076683,,,,,,,,,,,,,,,,, -54900,0.39798194,1.7044224,,,,,,,,,,,,,,,,, -55000,0.38201594,1.7047868,,,,,,,,,,,,,,,,, -55100,0.4712077,1.6946168,,,,,,,,,,,,,,,,, -55200,0.3825887,1.656428,,,,,,,,,,,,,,,,, -55300,0.38359183,1.781304,,,,,,,,,,,,,,,,, -55354,,,0.6445077657699585,1.712929129600525,31.342913567592323,0.663823127746582,1.576424241065979,28.29432199708726,3000.0,0.6759397983551025,1.5070810317993164,28.14646888038358,3003.0,19347.095779180527,32369.428835868835,19347.095779180527,13019.85465836525,0.6876661777496338,0.0 -55400,0.37330282,1.746178,,,,,,,,,,,,,,,,, -55500,0.45189676,1.7410225,,,,,,,,,,,,,,,,, -55600,0.39904514,1.776917,,,,,,,,,,,,,,,,, -55700,0.3993853,1.7649468,,,,,,,,,,,,,,,,, -55800,0.46758217,1.766083,,,,,,,,,,,,,,,,, -55900,0.37547636,1.6638142,,,,,,,,,,,,,,,,, -56000,0.4193196,1.6636138,,,,,,,,,,,,,,,,, -56100,0.3668561,1.6414329,,,,,,,,,,,,,,,,, -56200,0.38798425,1.7460463,,,,,,,,,,,,,,,,, -56300,0.4199858,1.712395,,,,,,,,,,,,,,,,, -56400,0.39683753,1.7508435,,,,,,,,,,,,,,,,, -56500,0.4417155,1.6746043,,,,,,,,,,,,,,,,, -56600,0.39896464,1.7908535,,,,,,,,,,,,,,,,, -56700,0.3673751,1.6365671,,,,,,,,,,,,,,,,, -56800,0.39821956,1.7855948,,,,,,,,,,,,,,,,, -56900,0.4017414,1.6880603,,,,,,,,,,,,,,,,, -57000,0.43517512,1.7314644,,,,,,,,,,,,,,,,, -57100,0.3544866,1.6957358,,,,,,,,,,,,,,,,, -57200,0.38725397,1.7395866,,,,,,,,,,,,,,,,, -57300,0.39979732,1.8089463,,,,,,,,,,,,,,,,, -57400,0.46848568,1.7138085,,,,,,,,,,,,,,,,, -57500,0.41450998,1.6672376,,,,,,,,,,,,,,,,, -57600,0.3719635,1.6443522,,,,,,,,,,,,,,,,, -57700,0.40862653,1.7269995,,,,,,,,,,,,,,,,, -57761,,,0.649020254611969,1.682659387588501,32.0713855796427,0.6664641499519348,1.5685975551605225,28.97768596478849,3000.0,0.6782522797584534,1.4985307455062866,28.270571580751486,3003.0,20187.175048589703,33749.7142393589,20187.175048589703,13559.9481112957,0.7214796543121338,0.0 -57800,0.42367202,1.6935652,,,,,,,,,,,,,,,,, -57900,0.3965036,1.6813657,,,,,,,,,,,,,,,,, -58000,0.35484004,1.6562023,,,,,,,,,,,,,,,,, -58100,0.3570205,1.6217179,,,,,,,,,,,,,,,,, -58200,0.42837575,1.6497843,,,,,,,,,,,,,,,,, -58300,0.37120938,1.703854,,,,,,,,,,,,,,,,, -58400,0.41213778,1.6734967,,,,,,,,,,,,,,,,, -58500,0.38800335,1.6728909,,,,,,,,,,,,,,,,, -58600,0.4001819,1.7139845,,,,,,,,,,,,,,,,, -58700,0.38455987,1.7272526,,,,,,,,,,,,,,,,, -58800,0.4346644,1.7407954,,,,,,,,,,,,,,,,, -58900,0.39586234,1.7412963,,,,,,,,,,,,,,,,, -59000,0.39135653,1.7055995,,,,,,,,,,,,,,,,, -59100,0.3416369,1.7252432,,,,,,,,,,,,,,,,, -59200,0.40530473,1.7176309,,,,,,,,,,,,,,,,, -59300,0.37368894,1.7158905,,,,,,,,,,,,,,,,, -59400,0.37206566,1.731172,,,,,,,,,,,,,,,,, -59500,0.37626988,1.6971457,,,,,,,,,,,,,,,,, -59600,0.41830006,1.6866595,,,,,,,,,,,,,,,,, -59700,0.3418199,1.6514939,,,,,,,,,,,,,,,,, -59800,0.34785965,1.6958219,,,,,,,,,,,,,,,,, -59900,0.40436772,1.5908568,,,,,,,,,,,,,,,,, -60000,0.37103617,1.6661627,,,,,,,,,,,,,,,,, -60100,0.38628554,1.7256707,,,,,,,,,,,,,,,,, -60169,,,0.648051917552948,1.6906379461288452,31.88436570280064,0.6664765477180481,1.5599160194396973,28.958729574332384,3000.0,0.6790889501571655,1.4897910356521606,28.51629844762099,3003.0,21027.074191093445,35134.669848680496,21027.074191093445,14104.895560979843,0.7536098957061768,0.0 -60200,0.34539196,1.7436582,,,,,,,,,,,,,,,,, -60300,0.44618267,1.7173849,,,,,,,,,,,,,,,,, -60400,0.3991463,1.66948,,,,,,,,,,,,,,,,, -60500,0.36956415,1.6551493,,,,,,,,,,,,,,,,, -60600,0.40682113,1.7115637,,,,,,,,,,,,,,,,, -60700,0.3586404,1.6572552,,,,,,,,,,,,,,,,, -60800,0.41483894,1.7602049,,,,,,,,,,,,,,,,, -60900,0.40652415,1.8035405,,,,,,,,,,,,,,,,, -61000,0.38122305,1.731871,,,,,,,,,,,,,,,,, -61100,0.3705281,1.6794348,,,,,,,,,,,,,,,,, -61200,0.38275638,1.6835662,,,,,,,,,,,,,,,,, -61300,0.38533017,1.6861651,,,,,,,,,,,,,,,,, -61400,0.42031205,1.6989462,,,,,,,,,,,,,,,,, -61500,0.37193808,1.6767173,,,,,,,,,,,,,,,,, -61600,0.37099263,1.6306758,,,,,,,,,,,,,,,,, -61700,0.3790926,1.6711332,,,,,,,,,,,,,,,,, -61800,0.39529648,1.7293425,,,,,,,,,,,,,,,,, -61900,0.34811085,1.6435072,,,,,,,,,,,,,,,,, -62000,0.37538043,1.8095465,,,,,,,,,,,,,,,,, -62100,0.3495855,1.6145415,,,,,,,,,,,,,,,,, -62200,0.39388406,1.6934137,,,,,,,,,,,,,,,,, -62300,0.40828368,1.6961983,,,,,,,,,,,,,,,,, -62400,0.38186648,1.6740981,,,,,,,,,,,,,,,,, -62500,0.40872732,1.7308804,,,,,,,,,,,,,,,,, -62577,,,0.6952161192893982,1.432866096496582,35.779476196254954,0.6689935326576233,1.5532121658325195,29.131601118820857,3000.0,0.678345263004303,1.491255760192871,28.343151127526827,3003.0,21867.165781974792,36433.94098830223,21867.165781974792,14563.95576763153,0.7935366630554199,0.0 -62600,0.37547138,1.6943275,,,,,,,,,,,,,,,,, -62700,0.37443605,1.7149309,,,,,,,,,,,,,,,,, -62800,0.3725081,1.6496431,,,,,,,,,,,,,,,,, -62900,0.3895265,1.6797035,,,,,,,,,,,,,,,,, -63000,0.40147623,1.7219784,,,,,,,,,,,,,,,,, -63100,0.35051158,1.6351374,,,,,,,,,,,,,,,,, -63200,0.39693528,1.5953041,,,,,,,,,,,,,,,,, -63300,0.36293715,1.671655,,,,,,,,,,,,,,,,, -63400,0.36719936,1.72711,,,,,,,,,,,,,,,,, -63500,0.3951032,1.7432274,,,,,,,,,,,,,,,,, -63600,0.3834596,1.7473187,,,,,,,,,,,,,,,,, -63700,0.38994458,1.7156161,,,,,,,,,,,,,,,,, -63800,0.37786877,1.7273124,,,,,,,,,,,,,,,,, -63900,0.3809718,1.6980593,,,,,,,,,,,,,,,,, -64000,0.38233268,1.7131482,,,,,,,,,,,,,,,,, -64100,0.38541397,1.5794891,,,,,,,,,,,,,,,,, -64200,0.36713213,1.7839804,,,,,,,,,,,,,,,,, -64300,0.37971336,1.7062274,,,,,,,,,,,,,,,,, -64400,0.37910438,1.8191359,,,,,,,,,,,,,,,,, -64500,0.41054818,1.7564456,,,,,,,,,,,,,,,,, -64600,0.40682355,1.7213587,,,,,,,,,,,,,,,,, -64700,0.41117144,1.6241933,,,,,,,,,,,,,,,,, -64800,0.39356825,1.7145909,,,,,,,,,,,,,,,,, -64900,0.36420596,1.7168465,,,,,,,,,,,,,,,,, -64986,,,0.6524246335029602,1.6702783107757568,31.827023341578347,0.6694151163101196,1.539971113204956,28.90430783369693,3000.0,0.6832142472267151,1.4675744771957395,28.673106638091205,3003.0,22707.37787055969,37756.19357728958,22707.37787055969,15045.877503871918,0.8332781791687012,0.0 -65000,0.40893424,1.702528,,,,,,,,,,,,,,,,, -65100,0.41302356,1.6677238,,,,,,,,,,,,,,,,, -65200,0.422726,1.697104,,,,,,,,,,,,,,,,, -65300,0.46465895,1.5631875,,,,,,,,,,,,,,,,, -65400,0.3789656,1.7722528,,,,,,,,,,,,,,,,, -65500,0.41112792,1.7209488,,,,,,,,,,,,,,,,, -65600,0.44347352,1.7215272,,,,,,,,,,,,,,,,, -65700,0.40235168,1.6085511,,,,,,,,,,,,,,,,, -65800,0.38663274,1.6487155,,,,,,,,,,,,,,,,, -65900,0.38115174,1.6969482,,,,,,,,,,,,,,,,, -66000,0.383822,1.6221782,,,,,,,,,,,,,,,,, -66100,0.3937369,1.6549788,,,,,,,,,,,,,,,,, -66200,0.37412894,1.718425,,,,,,,,,,,,,,,,, -66300,0.38250324,1.6420726,,,,,,,,,,,,,,,,, -66400,0.37368798,1.6305017,,,,,,,,,,,,,,,,, -66500,0.40642968,1.6955962,,,,,,,,,,,,,,,,, -66600,0.37809908,1.7371968,,,,,,,,,,,,,,,,, -66700,0.36196482,1.6826189,,,,,,,,,,,,,,,,, -66800,0.41835174,1.7481519,,,,,,,,,,,,,,,,, -66900,0.37145653,1.5979444,,,,,,,,,,,,,,,,, -67000,0.43198964,1.7250824,,,,,,,,,,,,,,,,, -67100,0.3716007,1.7258707,,,,,,,,,,,,,,,,, -67200,0.38106975,1.6468803,,,,,,,,,,,,,,,,, -67300,0.4335107,1.6405737,,,,,,,,,,,,,,,,, -67394,,,0.6514863967895508,1.6692497730255127,31.94321390091731,0.6705062389373779,1.540156364440918,29.242645233820426,3000.0,0.6817849278450012,1.4733388423919678,28.59808848117703,3003.0,23547.34636592865,39197.67491483688,23547.34636592865,15647.19865846634,0.9477291107177734,0.0 -67400,0.4308767,1.6784067,,,,,,,,,,,,,,,,, -67500,0.38068986,1.6359223,,,,,,,,,,,,,,,,, -67600,0.38910243,1.6496145,,,,,,,,,,,,,,,,, -67700,0.38046262,1.628977,,,,,,,,,,,,,,,,, -67800,0.39196855,1.6783665,,,,,,,,,,,,,,,,, -67900,0.39364842,1.7468232,,,,,,,,,,,,,,,,, -68000,0.37470436,1.5697869,,,,,,,,,,,,,,,,, -68100,0.42309955,1.7148943,,,,,,,,,,,,,,,,, -68200,0.40237963,1.8107637,,,,,,,,,,,,,,,,, -68300,0.38318813,1.6596242,,,,,,,,,,,,,,,,, -68400,0.3744678,1.6175506,,,,,,,,,,,,,,,,, -68500,0.39809403,1.6644716,,,,,,,,,,,,,,,,, -68600,0.3743064,1.6750078,,,,,,,,,,,,,,,,, -68700,0.417348,1.6704215,,,,,,,,,,,,,,,,, -68800,0.43733186,1.7699182,,,,,,,,,,,,,,,,, -68900,0.40364042,1.649304,,,,,,,,,,,,,,,,, -69000,0.4324774,1.6108072,,,,,,,,,,,,,,,,, -69100,0.39589065,1.7243237,,,,,,,,,,,,,,,,, -69200,0.41851062,1.6655761,,,,,,,,,,,,,,,,, -69300,0.39993355,1.635737,,,,,,,,,,,,,,,,, -69400,0.39870724,1.6749618,,,,,,,,,,,,,,,,, -69500,0.3944898,1.6997284,,,,,,,,,,,,,,,,, -69600,0.38347223,1.6110235,,,,,,,,,,,,,,,,, -69700,0.420919,1.768295,,,,,,,,,,,,,,,,, -69800,0.36125863,1.6194729,,,,,,,,,,,,,,,,, -69801,,,0.6611226201057434,1.6098320484161377,32.63959978255694,0.6719693541526794,1.5329428911209106,29.068353792762743,3000.0,0.684980571269989,1.4600976705551147,29.0403300225907,3003.0,24387.48620867729,40585.99910902977,24387.48620867729,16195.264653921127,0.9868950843811036,0.0 -69900,0.43552536,1.695028,,,,,,,,,,,,,,,,, -70000,0.4206665,1.5860821,,,,,,,,,,,,,,,,, -70100,0.355354,1.6299636,,,,,,,,,,,,,,,,, -70200,0.40089625,1.6377039,,,,,,,,,,,,,,,,, -70300,0.39516467,1.6287472,,,,,,,,,,,,,,,,, -70400,0.39658943,1.6136988,,,,,,,,,,,,,,,,, -70500,0.42288724,1.7345247,,,,,,,,,,,,,,,,, -70600,0.3927513,1.6054785,,,,,,,,,,,,,,,,, -70700,0.40899092,1.634365,,,,,,,,,,,,,,,,, -70800,0.45804495,1.6674652,,,,,,,,,,,,,,,,, -70900,0.36242437,1.6538934,,,,,,,,,,,,,,,,, -71000,0.3741424,1.6358277,,,,,,,,,,,,,,,,, -71100,0.36968395,1.6510462,,,,,,,,,,,,,,,,, -71200,0.39079487,1.6642187,,,,,,,,,,,,,,,,, -71300,0.37095773,1.623108,,,,,,,,,,,,,,,,, -71400,0.42311305,1.6827552,,,,,,,,,,,,,,,,, -71500,0.38040182,1.6882926,,,,,,,,,,,,,,,,, -71600,0.376335,1.648313,,,,,,,,,,,,,,,,, -71700,0.43483442,1.6258881,,,,,,,,,,,,,,,,, -71800,0.40096363,1.6511047,,,,,,,,,,,,,,,,, -71900,0.38429642,1.5324528,,,,,,,,,,,,,,,,, -72000,0.39251938,1.7014729,,,,,,,,,,,,,,,,, -72100,0.38746497,1.6394013,,,,,,,,,,,,,,,,, -72200,0.37491888,1.6876818,,,,,,,,,,,,,,,,, -72213,,,0.6569823026657104,1.6434437036514282,32.19100745654006,0.6732092499732971,1.5275042057037354,29.36287108991364,3000.0,0.6867700815200806,1.4527868032455444,29.133928657891733,3003.0,25227.61577558517,41935.63665962219,25227.61577558517,16704.657161474228,1.0241799354553225,0.0 -72300,0.38201153,1.5950372,,,,,,,,,,,,,,,,, -72400,0.4691968,1.6366671,,,,,,,,,,,,,,,,, -72500,0.4182459,1.6414791,,,,,,,,,,,,,,,,, -72600,0.39028993,1.5755125,,,,,,,,,,,,,,,,, -72700,0.41702417,1.6523739,,,,,,,,,,,,,,,,, -72800,0.37529594,1.6783562,,,,,,,,,,,,,,,,, -72900,0.41142288,1.6633227,,,,,,,,,,,,,,,,, -73000,0.38625187,1.6743332,,,,,,,,,,,,,,,,, -73100,0.4101902,1.6137745,,,,,,,,,,,,,,,,, -73200,0.3751369,1.6473428,,,,,,,,,,,,,,,,, -73300,0.39116755,1.6496018,,,,,,,,,,,,,,,,, -73400,0.434574,1.7000072,,,,,,,,,,,,,,,,, -73500,0.37861946,1.5976291,,,,,,,,,,,,,,,,, -73600,0.40227604,1.592643,,,,,,,,,,,,,,,,, -73700,0.41789252,1.6713594,,,,,,,,,,,,,,,,, -73800,0.37970865,1.6880103,,,,,,,,,,,,,,,,, -73900,0.3830334,1.6620095,,,,,,,,,,,,,,,,, -74000,0.418297,1.6835529,,,,,,,,,,,,,,,,, -74100,0.3656102,1.6443646,,,,,,,,,,,,,,,,, -74200,0.48215243,1.6495732,,,,,,,,,,,,,,,,, -74300,0.3919387,1.6592103,,,,,,,,,,,,,,,,, -74400,0.37246913,1.6334367,,,,,,,,,,,,,,,,, -74500,0.37041035,1.6440345,,,,,,,,,,,,,,,,, -74600,0.38552362,1.573137,,,,,,,,,,,,,,,,, -74621,,,0.6536204218864441,1.6660853624343872,32.52132858933652,0.6733828186988831,1.5218721628189087,29.318499071867286,3000.0,0.6867700815200806,1.4440298080444336,28.99523113730168,3003.0,26067.56249308586,43327.928844451904,26067.56249308586,17256.88942360878,1.0615234375,0.0 -74700,0.38678816,1.6527662,,,,,,,,,,,,,,,,, -74800,0.38056108,1.6933552,,,,,,,,,,,,,,,,, -74900,0.3968656,1.5810791,,,,,,,,,,,,,,,,, -75000,0.4256489,1.6872418,,,,,,,,,,,,,,,,, -75100,0.39654568,1.6152681,,,,,,,,,,,,,,,,, -75200,0.42580777,1.609481,,,,,,,,,,,,,,,,, -75300,0.38657847,1.6291488,,,,,,,,,,,,,,,,, -75400,0.4205032,1.6023977,,,,,,,,,,,,,,,,, -75500,0.405382,1.5872017,,,,,,,,,,,,,,,,, -75600,0.3814071,1.6482658,,,,,,,,,,,,,,,,, -75700,0.36830625,1.5792307,,,,,,,,,,,,,,,,, -75800,0.40566403,1.6354327,,,,,,,,,,,,,,,,, -75900,0.40496746,1.7307929,,,,,,,,,,,,,,,,, -76000,0.42749885,1.6938691,,,,,,,,,,,,,,,,, -76100,0.3724899,1.5524853,,,,,,,,,,,,,,,,, -76200,0.40155742,1.6371965,,,,,,,,,,,,,,,,, -76300,0.39569205,1.6751496,,,,,,,,,,,,,,,,, -76400,0.3997147,1.7402916,,,,,,,,,,,,,,,,, -76500,0.42003462,1.6053144,,,,,,,,,,,,,,,,, -76600,0.39721674,1.623971,,,,,,,,,,,,,,,,, -76700,0.40090942,1.5972561,,,,,,,,,,,,,,,,, -76800,0.3959483,1.6524453,,,,,,,,,,,,,,,,, -76900,0.42625502,1.7127798,,,,,,,,,,,,,,,,, -77000,0.4019475,1.6021127,,,,,,,,,,,,,,,,, -77029,,,0.6623920798301697,1.6091350317001345,33.60163001132954,0.6767553687095642,1.5119057893753052,29.60513743671253,3000.0,0.6889199018478394,1.4393484592437744,29.566770359025814,3003.0,26907.49367928505,44685.71225547791,26907.49367928505,17774.62933588028,1.0971648693084717,0.0 -77100,0.38662574,1.6489823,,,,,,,,,,,,,,,,, -77200,0.42658493,1.6429796,,,,,,,,,,,,,,,,, -77300,0.39173096,1.6363027,,,,,,,,,,,,,,,,, -77400,0.39343518,1.6557655,,,,,,,,,,,,,,,,, -77500,0.40149188,1.5950068,,,,,,,,,,,,,,,,, -77600,0.42844185,1.5547551,,,,,,,,,,,,,,,,, -77700,0.40830222,1.6026869,,,,,,,,,,,,,,,,, -77800,0.41268593,1.6289487,,,,,,,,,,,,,,,,, -77900,0.39607194,1.7670761,,,,,,,,,,,,,,,,, -78000,0.41572732,1.6849606,,,,,,,,,,,,,,,,, -78100,0.40718958,1.6001085,,,,,,,,,,,,,,,,, -78200,0.41549316,1.6463325,,,,,,,,,,,,,,,,, -78300,0.39685792,1.66072,,,,,,,,,,,,,,,,, -78400,0.49394315,1.6460729,,,,,,,,,,,,,,,,, -78500,0.40016717,1.6912476,,,,,,,,,,,,,,,,, -78600,0.38487807,1.6866589,,,,,,,,,,,,,,,,, -78700,0.4192497,1.6434507,,,,,,,,,,,,,,,,, -78800,0.40022972,1.6285086,,,,,,,,,,,,,,,,, -78900,0.39920133,1.6207545,,,,,,,,,,,,,,,,, -79000,0.43278047,1.6227875,,,,,,,,,,,,,,,,, -79100,0.41230094,1.6187006,,,,,,,,,,,,,,,,, -79200,0.38365597,1.5876313,,,,,,,,,,,,,,,,, -79300,0.40759888,1.5938226,,,,,,,,,,,,,,,,, -79400,0.38418984,1.5935867,,,,,,,,,,,,,,,,, -79440,,,0.6589401364326477,1.6313480138778689,33.36118043719307,0.6741020083427429,1.511282444000244,29.13616132874532,3000.0,0.6889663934707642,1.430628776550293,29.096602843757115,3003.0,27747.6175262928,46043.84932875633,27747.6175262928,18292.526223659515,1.133357286453247,0.0 -79500,0.42823794,1.6093849,,,,,,,,,,,,,,,,, -79600,0.4011875,1.6142272,,,,,,,,,,,,,,,,, -79700,0.42269948,1.7020077,,,,,,,,,,,,,,,,, -79800,0.4559563,1.6977983,,,,,,,,,,,,,,,,, -79900,0.4105055,1.6146264,,,,,,,,,,,,,,,,, -80000,0.39528176,1.5757943,,,,,,,,,,,,,,,,, -80100,0.43600172,1.7157345,,,,,,,,,,,,,,,,, -80200,0.38573486,1.6052998,,,,,,,,,,,,,,,,, -80300,0.43952474,1.6823205,,,,,,,,,,,,,,,,, -80400,0.39833316,1.6248082,,,,,,,,,,,,,,,,, -80500,0.44028905,1.632521,,,,,,,,,,,,,,,,, -80600,0.40821528,1.6239965,,,,,,,,,,,,,,,,, -80700,0.44806778,1.7016795,,,,,,,,,,,,,,,,, -80800,0.43566856,1.6652007,,,,,,,,,,,,,,,,, -80900,0.4205947,1.6176045,,,,,,,,,,,,,,,,, -81000,0.39879933,1.5469251,,,,,,,,,,,,,,,,, -81100,0.405032,1.609225,,,,,,,,,,,,,,,,, -81200,0.43916076,1.6400352,,,,,,,,,,,,,,,,, -81300,0.42004985,1.5835733,,,,,,,,,,,,,,,,, -81400,0.3854002,1.5964664,,,,,,,,,,,,,,,,, -81500,0.40486407,1.528645,,,,,,,,,,,,,,,,, -81600,0.38860244,1.5981435,,,,,,,,,,,,,,,,, -81700,0.41071782,1.6330054,,,,,,,,,,,,,,,,, -81800,0.43529427,1.6312743,,,,,,,,,,,,,,,,, -81848,,,0.6777380108833313,1.5110163688659668,33.80495966434508,0.6777721047401428,1.4969340562820437,29.83331729621893,3000.0,0.6912091374397278,1.4248101711273191,29.329210019172148,3003.0,28587.77638578415,47551.57074832916,28587.77638578415,18959.974626541138,1.171102523803711,0.0 -81900,0.42737424,1.6120023,,,,,,,,,,,,,,,,, -82000,0.40118372,1.58021,,,,,,,,,,,,,,,,, -82100,0.40982595,1.6715465,,,,,,,,,,,,,,,,, -82200,0.41907078,1.5556787,,,,,,,,,,,,,,,,, -82300,0.41507605,1.6911235,,,,,,,,,,,,,,,,, -82400,0.4213508,1.6341791,,,,,,,,,,,,,,,,, -82500,0.42868897,1.6614736,,,,,,,,,,,,,,,,, -82600,0.39769575,1.5609765,,,,,,,,,,,,,,,,, -82700,0.40788126,1.6210979,,,,,,,,,,,,,,,,, -82800,0.43244144,1.6134409,,,,,,,,,,,,,,,,, -82900,0.43107036,1.6378075,,,,,,,,,,,,,,,,, -83000,0.40419158,1.5990835,,,,,,,,,,,,,,,,, -83100,0.40984347,1.5358882,,,,,,,,,,,,,,,,, -83200,0.4508771,1.5971562,,,,,,,,,,,,,,,,, -83300,0.41757044,1.6418073,,,,,,,,,,,,,,,,, -83400,0.42051363,1.5454246,,,,,,,,,,,,,,,,, -83500,0.40530947,1.6069062,,,,,,,,,,,,,,,,, -83600,0.41282088,1.5948639,,,,,,,,,,,,,,,,, -83700,0.4510191,1.5972621,,,,,,,,,,,,,,,,, -83800,0.4243107,1.6682842,,,,,,,,,,,,,,,,, -83900,0.4205117,1.5701721,,,,,,,,,,,,,,,,, -84000,0.41279736,1.6751143,,,,,,,,,,,,,,,,, -84100,0.42160204,1.5960131,,,,,,,,,,,,,,,,, -84200,0.39560303,1.5416032,,,,,,,,,,,,,,,,, -84257,,,0.6639957427978516,1.589513659477234,33.48077142570666,0.6791732311248779,1.4911015033721924,29.60846193664788,3000.0,0.6941258907318115,1.4093471765518188,29.71491599447524,3003.0,29427.856281757355,48891.9859354496,29427.856281757355,19460.1951122284,1.209423542022705,0.0 -84300,0.42867196,1.6299477,,,,,,,,,,,,,,,,, -84400,0.4132966,1.6232986,,,,,,,,,,,,,,,,, -84500,0.4836592,1.6209532,,,,,,,,,,,,,,,,, -84600,0.40737462,1.5145478,,,,,,,,,,,,,,,,, -84700,0.45509136,1.585069,,,,,,,,,,,,,,,,, -84800,0.44578573,1.6552448,,,,,,,,,,,,,,,,, -84900,0.4322558,1.5975101,,,,,,,,,,,,,,,,, -85000,0.39296153,1.5536474,,,,,,,,,,,,,,,,, -85100,0.43255106,1.5216863,,,,,,,,,,,,,,,,, -85200,0.42903477,1.6243203,,,,,,,,,,,,,,,,, -85300,0.42443526,1.6341404,,,,,,,,,,,,,,,,, -85400,0.4068488,1.5663036,,,,,,,,,,,,,,,,, -85500,0.43305627,1.5828602,,,,,,,,,,,,,,,,, -85600,0.4231886,1.5835817,,,,,,,,,,,,,,,,, -85700,0.4491622,1.6048465,,,,,,,,,,,,,,,,, -85800,0.47722986,1.5746657,,,,,,,,,,,,,,,,, -85900,0.44642785,1.6529665,,,,,,,,,,,,,,,,, -86000,0.5238687,1.7027379,,,,,,,,,,,,,,,,, -86100,0.40584624,1.6578108,,,,,,,,,,,,,,,,, -86200,0.4406922,1.6110812,,,,,,,,,,,,,,,,, -86300,0.44635063,1.5966977,,,,,,,,,,,,,,,,, -86400,0.44526792,1.5396457,,,,,,,,,,,,,,,,, -86500,0.42082494,1.5953302,,,,,,,,,,,,,,,,, -86600,0.587192,1.6100768,,,,,,,,,,,,,,,,, -86666,,,0.6675595045089722,1.5793704986572266,32.93528053933675,0.6773629784584045,1.487136960029602,29.58336269676801,3000.0,0.693370521068573,1.4054758548736572,29.51739396229348,3003.0,30268.08896493912,50289.92531251907,30268.08896493912,20017.788903951645,1.2460203170776367,0.0 -86700,0.45467398,1.5997665,,,,,,,,,,,,,,,,, -86800,0.4593604,1.5580958,,,,,,,,,,,,,,,,, -86900,0.45541313,1.5992295,,,,,,,,,,,,,,,,, -87000,0.41164795,1.546263,,,,,,,,,,,,,,,,, -87100,0.4178733,1.6245474,,,,,,,,,,,,,,,,, -87200,0.43651864,1.572626,,,,,,,,,,,,,,,,, -87300,0.42282552,1.5751612,,,,,,,,,,,,,,,,, -87400,0.441572,1.624796,,,,,,,,,,,,,,,,, -87500,0.45005965,1.5818764,,,,,,,,,,,,,,,,, -87600,0.4342206,1.613934,,,,,,,,,,,,,,,,, -87700,0.4239274,1.6191794,,,,,,,,,,,,,,,,, -87800,0.39962146,1.5852066,,,,,,,,,,,,,,,,, -87900,0.4576227,1.6640337,,,,,,,,,,,,,,,,, -88000,0.4431219,1.5502315,,,,,,,,,,,,,,,,, -88100,0.43763006,1.604281,,,,,,,,,,,,,,,,, -88200,0.4403599,1.6034138,,,,,,,,,,,,,,,,, -88300,0.43482208,1.5084282,,,,,,,,,,,,,,,,, -88400,0.46780515,1.6545678,,,,,,,,,,,,,,,,, -88500,0.47932205,1.5557973,,,,,,,,,,,,,,,,, -88600,0.43715316,1.5978537,,,,,,,,,,,,,,,,, -88700,0.43942627,1.6178299,,,,,,,,,,,,,,,,, -88800,0.4583292,1.6382194,,,,,,,,,,,,,,,,, -88900,0.4198255,1.4624281,,,,,,,,,,,,,,,,, -89000,0.45597684,1.538584,,,,,,,,,,,,,,,,, -89075,,,0.6773558855056763,1.5134402513504028,33.889525427289215,0.6815290451049805,1.4774235486984253,29.90880797413811,3000.0,0.696914792060852,1.3995591402053833,29.432154247263547,3003.0,31108.269670009613,51703.35502099991,31108.269670009613,20590.925078868862,1.283271074295044,0.0 -89100,0.45243353,1.6241963,,,,,,,,,,,,,,,,, -89200,0.470417,1.6380131,,,,,,,,,,,,,,,,, -89300,0.43460286,1.581487,,,,,,,,,,,,,,,,, -89400,0.4306787,1.4748547,,,,,,,,,,,,,,,,, -89500,0.44503507,1.5547868,,,,,,,,,,,,,,,,, -89600,0.45470366,1.585333,,,,,,,,,,,,,,,,, -89700,0.45194328,1.5264977,,,,,,,,,,,,,,,,, -89800,0.45503598,1.5664959,,,,,,,,,,,,,,,,, -89900,0.45833054,1.5361238,,,,,,,,,,,,,,,,, -90000,0.4610406,1.5983516,,,,,,,,,,,,,,,,, -90100,0.4682386,1.6238279,,,,,,,,,,,,,,,,, -90200,0.4712126,1.5906701,,,,,,,,,,,,,,,,, -90300,0.4577462,1.5626483,,,,,,,,,,,,,,,,, -90400,0.4607282,1.6173542,,,,,,,,,,,,,,,,, -90500,0.44673204,1.5727006,,,,,,,,,,,,,,,,, -90600,0.43104362,1.5949395,,,,,,,,,,,,,,,,, -90700,0.4430202,1.4999368,,,,,,,,,,,,,,,,, -90800,0.4370195,1.5692536,,,,,,,,,,,,,,,,, -90900,0.43662563,1.5718889,,,,,,,,,,,,,,,,, -91000,0.43619683,1.6215863,,,,,,,,,,,,,,,,, -91100,0.45054448,1.5565822,,,,,,,,,,,,,,,,, -91200,0.4537331,1.5655775,,,,,,,,,,,,,,,,, -91300,0.45984188,1.5855751,,,,,,,,,,,,,,,,, -91400,0.46333423,1.5218651,,,,,,,,,,,,,,,,, -91482,,,0.6696231365203857,1.5658307075500488,33.075913430571255,0.6815414428710938,1.4682722091674805,29.879511675630194,3000.0,0.6981697678565979,1.38344144821167,29.75321280222057,3003.0,31948.175379037857,53102.57399082184,31948.175379037857,21150.12527155876,1.319873571395874,0.0 -91500,0.43177125,1.5592196,,,,,,,,,,,,,,,,, -91600,0.44626534,1.4950302,,,,,,,,,,,,,,,,, -91700,0.45903197,1.5902364,,,,,,,,,,,,,,,,, -91800,0.49557784,1.5146077,,,,,,,,,,,,,,,,, -91900,0.45800504,1.5272455,,,,,,,,,,,,,,,,, -92000,0.4859556,1.6263665,,,,,,,,,,,,,,,,, -92100,0.45232317,1.6115092,,,,,,,,,,,,,,,,, -92200,0.46011645,1.5512894,,,,,,,,,,,,,,,,, -92300,0.5268897,1.5293951,,,,,,,,,,,,,,,,, -92400,0.43089372,1.5158832,,,,,,,,,,,,,,,,, -92500,0.47062862,1.6150466,,,,,,,,,,,,,,,,, -92600,0.48684794,1.6426421,,,,,,,,,,,,,,,,, -92700,0.4667741,1.5902033,,,,,,,,,,,,,,,,, -92800,0.4603852,1.5524008,,,,,,,,,,,,,,,,, -92900,0.43406314,1.5203711,,,,,,,,,,,,,,,,, -93000,0.43553156,1.5851456,,,,,,,,,,,,,,,,, -93100,0.46862757,1.576895,,,,,,,,,,,,,,,,, -93200,0.4925456,1.6165875,,,,,,,,,,,,,,,,, -93300,0.4787579,1.5439985,,,,,,,,,,,,,,,,, -93400,0.5130878,1.546215,,,,,,,,,,,,,,,,, -93500,0.44393316,1.5342062,,,,,,,,,,,,,,,,, -93600,0.48638657,1.5911602,,,,,,,,,,,,,,,,, -93700,0.46666163,1.5615611,,,,,,,,,,,,,,,,, -93800,0.48236018,1.4604403,,,,,,,,,,,,,,,,, -93890,,,0.7072305679321289,1.355831503868103,36.74364605152375,0.6831037402153015,1.466017246246338,30.11848958134699,3000.0,0.697519063949585,1.383360743522644,29.793031114372457,3003.0,32788.21763634682,54445.86746001244,32788.21763634682,21653.25521636009,1.3652818202972412,0.0 -93900,0.45299032,1.5531678,,,,,,,,,,,,,,,,, -94000,0.46805075,1.5551966,,,,,,,,,,,,,,,,, -94100,0.47528642,1.5394199,,,,,,,,,,,,,,,,, -94200,0.4905776,1.6172433,,,,,,,,,,,,,,,,, -94300,0.48377368,1.5822101,,,,,,,,,,,,,,,,, -94400,0.48556247,1.6260124,,,,,,,,,,,,,,,,, -94500,0.4598641,1.5612227,,,,,,,,,,,,,,,,, -94600,0.47298986,1.5439764,,,,,,,,,,,,,,,,, -94700,0.47846827,1.5227973,,,,,,,,,,,,,,,,, -94800,0.489046,1.4771893,,,,,,,,,,,,,,,,, -94900,0.46297216,1.5275754,,,,,,,,,,,,,,,,, -95000,0.48868376,1.5960506,,,,,,,,,,,,,,,,, -95100,0.48749766,1.5481952,,,,,,,,,,,,,,,,, -95200,0.46217984,1.5058864,,,,,,,,,,,,,,,,, -95300,0.47679034,1.5413164,,,,,,,,,,,,,,,,, -95400,0.5147645,1.4700432,,,,,,,,,,,,,,,,, -95500,0.4665521,1.5391085,,,,,,,,,,,,,,,,, -95600,0.484771,1.5697612,,,,,,,,,,,,,,,,, -95700,0.5017068,1.6417216,,,,,,,,,,,,,,,,, -95800,0.52755696,1.5981716,,,,,,,,,,,,,,,,, -95900,0.4853087,1.6190209,,,,,,,,,,,,,,,,, -96000,0.4884509,1.4852293,,,,,,,,,,,,,,,,, -96100,0.48842362,1.60394,,,,,,,,,,,,,,,,, -96200,0.5148201,1.5222937,,,,,,,,,,,,,,,,, -96299,,,0.6774893403053284,1.5120842456817627,34.065012769999825,0.6842816472053528,1.452805519104004,30.26416928198846,3000.0,0.6997618079185486,1.3693721294403076,30.24534871633443,3003.0,33628.44213843346,55846.56341743469,33628.44213843346,22213.611531734467,1.4041471481323242,0.0 -96300,0.4879642,1.5681993,,,,,,,,,,,,,,,,, -96400,0.505099,1.6200651,,,,,,,,,,,,,,,,, -96500,0.48881915,1.5396311,,,,,,,,,,,,,,,,, -96600,0.46636608,1.5146489,,,,,,,,,,,,,,,,, -96700,0.4999463,1.5201743,,,,,,,,,,,,,,,,, -96800,0.51048696,1.5294937,,,,,,,,,,,,,,,,, -96900,0.47372955,1.5278825,,,,,,,,,,,,,,,,, -97000,0.50413686,1.5716865,,,,,,,,,,,,,,,,, -97100,0.5019866,1.4435766,,,,,,,,,,,,,,,,, -97200,0.49800178,1.566202,,,,,,,,,,,,,,,,, -97300,0.49146074,1.5594234,,,,,,,,,,,,,,,,, -97400,0.5191201,1.556056,,,,,,,,,,,,,,,,, -97500,0.48325694,1.4973563,,,,,,,,,,,,,,,,, -97600,0.46972945,1.4824661,,,,,,,,,,,,,,,,, -97700,0.48910853,1.5213958,,,,,,,,,,,,,,,,, -97800,0.51370287,1.46751,,,,,,,,,,,,,,,,, -97900,0.498026,1.52067,,,,,,,,,,,,,,,,, -98000,0.48510808,1.4663447,,,,,,,,,,,,,,,,, -98100,0.5057561,1.4405451,,,,,,,,,,,,,,,,, -98200,0.47369447,1.4882637,,,,,,,,,,,,,,,,, -98300,0.51468027,1.5810592,,,,,,,,,,,,,,,,, -98400,0.49024805,1.4448729,,,,,,,,,,,,,,,,, -98500,0.5330457,1.5075413,,,,,,,,,,,,,,,,, -98600,0.69847536,1.5058366,,,,,,,,,,,,,,,,, -98700,0.4939073,1.5351743,,,,,,,,,,,,,,,,, -98707,,,0.6742954254150391,1.5320757627487185,34.28313749510387,0.6845172047615051,1.4518380165100098,30.17301385143447,3000.0,0.6995061635971069,1.36750328540802,30.188361116708307,3003.0,34468.39393520355,57219.96085047722,34468.39393520355,22746.94253396988,1.4424107074737549,0.0 -98800,0.5216345,1.5233979,,,,,,,,,,,,,,,,, -98900,0.52744704,1.5217811,,,,,,,,,,,,,,,,, -99000,0.5039681,1.571659,,,,,,,,,,,,,,,,, -99100,0.50241244,1.4619799,,,,,,,,,,,,,,,,, -99200,0.49094078,1.4886268,,,,,,,,,,,,,,,,, -99300,0.51080835,1.5514877,,,,,,,,,,,,,,,,, -99400,0.51080173,1.5101485,,,,,,,,,,,,,,,,, -99500,0.5367973,1.5400738,,,,,,,,,,,,,,,,, -99600,0.49493515,1.4480637,,,,,,,,,,,,,,,,, -99700,0.52607596,1.531446,,,,,,,,,,,,,,,,, -99800,0.51702034,1.5068945,,,,,,,,,,,,,,,,, -99900,0.50124824,1.5192906,,,,,,,,,,,,,,,,, -100000,0.52639073,1.516689,,,,,,,,,,,,,,,,, -100100,0.5319493,1.5888698,,,,,,,,,,,,,,,,, -100200,0.50840014,1.4625823,,,,,,,,,,,,,,,,, -100300,0.5095969,1.5177233,,,,,,,,,,,,,,,,, -100400,0.52519035,1.5526519,,,,,,,,,,,,,,,,, -100500,0.53235143,1.5168203,,,,,,,,,,,,,,,,, -100600,0.50641817,1.4985188,,,,,,,,,,,,,,,,, -100700,0.51306725,1.397781,,,,,,,,,,,,,,,,, -100800,0.5387401,1.5448266,,,,,,,,,,,,,,,,, -100900,0.5214521,1.5522151,,,,,,,,,,,,,,,,, -101000,0.51863176,1.4356673,,,,,,,,,,,,,,,,, -101100,0.53239,1.4957895,,,,,,,,,,,,,,,,, -101115,,,0.6891303658485413,1.4438997507095337,34.77914901250355,0.6863027215003967,1.444441556930542,30.29630986323259,3000.0,0.701225996017456,1.3644455671310425,30.21796655190293,3003.0,35308.40327715874,58590.03084850311,35308.40327715874,23276.88682627678,1.4825043678283691,0.0 -101200,0.49656448,1.4261239,,,,,,,,,,,,,,,,, -101300,0.50195074,1.4841553,,,,,,,,,,,,,,,,, -101400,0.5178091,1.4787292,,,,,,,,,,,,,,,,, -101500,0.5489638,1.5477883,,,,,,,,,,,,,,,,, -101600,0.53334767,1.5351651,,,,,,,,,,,,,,,,, -101700,0.55066484,1.4884424,,,,,,,,,,,,,,,,, -101800,0.5313174,1.4697156,,,,,,,,,,,,,,,,, -101900,0.54450625,1.5415668,,,,,,,,,,,,,,,,, -102000,0.5301043,1.4904472,,,,,,,,,,,,,,,,, -102100,0.52807426,1.4574693,,,,,,,,,,,,,,,,, -102200,0.57034516,1.5603138,,,,,,,,,,,,,,,,, -102300,0.5550031,1.5413247,,,,,,,,,,,,,,,,, -102400,0.52319527,1.4601375,,,,,,,,,,,,,,,,, -102500,0.5235148,1.4586229,,,,,,,,,,,,,,,,, -102600,0.54140955,1.4699018,,,,,,,,,,,,,,,,, -102700,0.54724145,1.485109,,,,,,,,,,,,,,,,, -102800,0.5385294,1.4748647,,,,,,,,,,,,,,,,, -102900,0.570569,1.5669562,,,,,,,,,,,,,,,,, -103000,0.54435533,1.5345358,,,,,,,,,,,,,,,,, -103100,0.53766406,1.5392723,,,,,,,,,,,,,,,,, -103200,0.5400703,1.4418695,,,,,,,,,,,,,,,,, -103300,0.5762322,1.5347625,,,,,,,,,,,,,,,,, -103400,0.58615965,1.5657482,,,,,,,,,,,,,,,,, -103500,0.5369538,1.403327,,,,,,,,,,,,,,,,, -103522,,,0.6818674206733704,1.4904634952545166,34.39548551799863,0.6874682307243347,1.439511775970459,30.242459944963176,3000.0,0.7032827734947205,1.3556402921676636,30.08292498174387,3003.0,36148.475451231,59959.14341711998,36148.475451231,23805.79743504524,1.5310685634613037,0.0 -103600,0.55366224,1.4797903,,,,,,,,,,,,,,,,, -103700,0.54682046,1.4603841,,,,,,,,,,,,,,,,, -103800,0.54400915,1.5513082,,,,,,,,,,,,,,,,, -103900,0.5678917,1.4373504,,,,,,,,,,,,,,,,, -104000,0.53295547,1.5016044,,,,,,,,,,,,,,,,, -104100,0.5441607,1.4807993,,,,,,,,,,,,,,,,, -104200,0.6087642,1.5921267,,,,,,,,,,,,,,,,, -104300,0.5760913,1.4082403,,,,,,,,,,,,,,,,, -104400,0.5527486,1.5066179,,,,,,,,,,,,,,,,, -104500,0.5634142,1.4931494,,,,,,,,,,,,,,,,, -104600,0.57519233,1.4479274,,,,,,,,,,,,,,,,, -104700,0.554972,1.4731278,,,,,,,,,,,,,,,,, -104800,0.5950734,1.4280292,,,,,,,,,,,,,,,,, -104900,0.56514645,1.4798356,,,,,,,,,,,,,,,,, -105000,0.5795883,1.4837233,,,,,,,,,,,,,,,,, -105100,0.5667646,1.5389642,,,,,,,,,,,,,,,,, -105200,0.5530456,1.4486493,,,,,,,,,,,,,,,,, -105300,0.556998,1.5294322,,,,,,,,,,,,,,,,, -105400,0.5593757,1.4383943,,,,,,,,,,,,,,,,, -105500,0.5915107,1.5097308,,,,,,,,,,,,,,,,, -105600,0.58910686,1.4829587,,,,,,,,,,,,,,,,, -105700,0.5990949,1.4907134,,,,,,,,,,,,,,,,, -105800,0.61752605,1.5250733,,,,,,,,,,,,,,,,, -105900,0.56210196,1.4414246,,,,,,,,,,,,,,,,, -105930,,,0.6841477751731873,1.4826703071594238,34.1736090872612,0.68779057264328,1.4393343925476074,30.294645014237727,3000.0,0.7045843005180359,1.3529555797576904,30.341159836823227,3003.0,36988.668182611465,61317.41492795944,36988.668182611465,24323.759889364243,1.5710361003875732,0.0 -106000,0.59010565,1.5072348,,,,,,,,,,,,,,,,, -106100,0.5795681,1.458459,,,,,,,,,,,,,,,,, -106200,0.57959026,1.4234765,,,,,,,,,,,,,,,,, -106300,0.592799,1.4126835,,,,,,,,,,,,,,,,, -106400,0.5931563,1.4050941,,,,,,,,,,,,,,,,, -106500,0.5708512,1.3788674,,,,,,,,,,,,,,,,, -106600,0.60424453,1.498155,,,,,,,,,,,,,,,,, -106700,0.6074325,1.4638991,,,,,,,,,,,,,,,,, -106800,0.59793925,1.4798052,,,,,,,,,,,,,,,,, -106900,0.586638,1.4972076,,,,,,,,,,,,,,,,, -107000,0.612805,1.4524955,,,,,,,,,,,,,,,,, -107100,0.6055871,1.4597852,,,,,,,,,,,,,,,,, -107200,0.6041274,1.5296608,,,,,,,,,,,,,,,,, -107300,0.5725041,1.3968273,,,,,,,,,,,,,,,,, -107400,0.6259488,1.4738982,,,,,,,,,,,,,,,,, -107500,0.61186945,1.4889708,,,,,,,,,,,,,,,,, -107600,0.57071483,1.4553533,,,,,,,,,,,,,,,,, -107700,0.5729144,1.4150702,,,,,,,,,,,,,,,,, -107800,0.64175415,1.5113117,,,,,,,,,,,,,,,,, -107900,0.619009,1.5240449,,,,,,,,,,,,,,,,, -108000,0.5944255,1.4922048,,,,,,,,,,,,,,,,, -108100,0.58824825,1.504115,,,,,,,,,,,,,,,,, -108200,0.62153417,1.3892245,,,,,,,,,,,,,,,,, -108300,0.5919661,1.4165164,,,,,,,,,,,,,,,,, -108337,,,0.6883156895637512,1.451397180557251,35.21473145739453,0.6892040967941284,1.4314101934432983,30.54711567710777,3000.0,0.7040265202522278,1.3469500541687012,30.276564514232,3003.0,37828.766984939575,62684.69976902008,37828.766984939575,24850.823167324063,1.614072561264038,0.0 -108400,0.6155516,1.4577585,,,,,,,,,,,,,,,,, -108500,0.6220002,1.4778049,,,,,,,,,,,,,,,,, -108600,0.63181955,1.485694,,,,,,,,,,,,,,,,, -108700,0.6330787,1.5464946,,,,,,,,,,,,,,,,, -108800,0.62110925,1.4513402,,,,,,,,,,,,,,,,, -108900,0.61680865,1.5107687,,,,,,,,,,,,,,,,, -109000,0.62234783,1.432685,,,,,,,,,,,,,,,,, -109100,0.627981,1.4007242,,,,,,,,,,,,,,,,, -109200,0.62195116,1.4547118,,,,,,,,,,,,,,,,, -109300,0.639525,1.4435569,,,,,,,,,,,,,,,,, -109400,0.6517261,1.5141696,,,,,,,,,,,,,,,,, -109500,0.6229255,1.5035644,,,,,,,,,,,,,,,,, -109600,0.6323245,1.4237425,,,,,,,,,,,,,,,,, -109700,0.62245166,1.3766775,,,,,,,,,,,,,,,,, -109800,0.72480917,1.4780973,,,,,,,,,,,,,,,,, -109900,0.63471085,1.5293847,,,,,,,,,,,,,,,,, -110000,0.61849105,1.505914,,,,,,,,,,,,,,,,, -110100,0.6145366,1.4309535,,,,,,,,,,,,,,,,, -110200,0.62834597,1.4811596,,,,,,,,,,,,,,,,, -110300,0.63689363,1.4291109,,,,,,,,,,,,,,,,, -110400,0.6550469,1.4618288,,,,,,,,,,,,,,,,, -110500,0.6328283,1.489409,,,,,,,,,,,,,,,,, -110600,0.6430636,1.467179,,,,,,,,,,,,,,,,, -110700,0.6135607,1.4793892,,,,,,,,,,,,,,,,, -110745,,,0.6906625032424927,1.4407415390014648,35.11437290724362,0.6892908811569214,1.4316939115524292,30.62353490460124,3000.0,0.7047585844993591,1.3457623720169067,30.64520647586331,3003.0,38668.92079091072,64037.598447322845,38668.92079091072,25363.45061326027,1.654350519180298,0.0 -110800,0.65404755,1.4410369,,,,,,,,,,,,,,,,, -110900,0.6193764,1.4281213,,,,,,,,,,,,,,,,, -111000,0.66075224,1.4678793,,,,,,,,,,,,,,,,, -111100,0.6658303,1.4248813,,,,,,,,,,,,,,,,, -111200,0.64950526,1.4233196,,,,,,,,,,,,,,,,, -111300,0.648121,1.3821409,,,,,,,,,,,,,,,,, -111400,0.6403802,1.4769732,,,,,,,,,,,,,,,,, -111500,0.6262646,1.4593129,,,,,,,,,,,,,,,,, -111600,0.62620157,1.470655,,,,,,,,,,,,,,,,, -111700,0.6414371,1.3414085,,,,,,,,,,,,,,,,, -111800,0.7117994,1.5057619,,,,,,,,,,,,,,,,, -111900,0.63859075,1.5064296,,,,,,,,,,,,,,,,, -112000,0.66198367,1.4618747,,,,,,,,,,,,,,,,, -112100,0.6388884,1.4478539,,,,,,,,,,,,,,,,, -112200,0.66378105,1.5161113,,,,,,,,,,,,,,,,, -112300,0.6700416,1.4555433,,,,,,,,,,,,,,,,, -112400,0.6409613,1.3682601,,,,,,,,,,,,,,,,, -112500,0.68137693,1.4912308,,,,,,,,,,,,,,,,, -112600,0.66359097,1.3633298,,,,,,,,,,,,,,,,, -112700,0.6670024,1.4334911,,,,,,,,,,,,,,,,, -112800,0.67525035,1.4343386,,,,,,,,,,,,,,,,, -112900,0.6640103,1.4170616,,,,,,,,,,,,,,,,, -113000,0.6699083,1.4705775,,,,,,,,,,,,,,,,, -113100,0.6823347,1.462054,,,,,,,,,,,,,,,,, -113152,,,0.7029312252998352,1.376530647277832,36.05811097263953,0.691299557685852,1.4230973720550537,30.798797566594644,3000.0,0.705909013748169,1.339033603668213,30.24255994252021,3003.0,39508.85071182251,65501.97001576424,39508.85071182251,25987.774053812027,1.6964483261108398,0.0 -113200,0.6579034,1.4168315,,,,,,,,,,,,,,,,, -113300,0.6814581,1.4500864,,,,,,,,,,,,,,,,, -113400,0.6681151,1.5527842,,,,,,,,,,,,,,,,, -113500,0.68076247,1.3654956,,,,,,,,,,,,,,,,, -113600,0.66719085,1.4137275,,,,,,,,,,,,,,,,, -113700,0.7014228,1.5194925,,,,,,,,,,,,,,,,, -113800,0.6898875,1.3970146,,,,,,,,,,,,,,,,, -113900,0.6919,1.5298257,,,,,,,,,,,,,,,,, -114000,0.6829546,1.3494834,,,,,,,,,,,,,,,,, -114100,0.6948905,1.4055804,,,,,,,,,,,,,,,,, -114200,0.6746414,1.4514725,,,,,,,,,,,,,,,,, -114300,0.6791145,1.4804312,,,,,,,,,,,,,,,,, -114400,0.7025365,1.4038838,,,,,,,,,,,,,,,,, -114500,0.6917449,1.4720861,,,,,,,,,,,,,,,,, -114600,0.68566144,1.3991368,,,,,,,,,,,,,,,,, -114700,0.67068815,1.4310801,,,,,,,,,,,,,,,,, -114800,0.69852656,1.4752067,,,,,,,,,,,,,,,,, -114900,0.67930686,1.4718484,,,,,,,,,,,,,,,,, -115000,0.7143114,1.4230978,,,,,,,,,,,,,,,,, -115100,0.6798837,1.3426077,,,,,,,,,,,,,,,,, -115200,0.70427793,1.3622835,,,,,,,,,,,,,,,,, -115300,0.70303386,1.4671106,,,,,,,,,,,,,,,,, -115400,0.70471835,1.3384705,,,,,,,,,,,,,,,,, -115500,0.7236391,1.4722288,,,,,,,,,,,,,,,,, -115560,,,0.7006024122238159,1.3854455947875977,35.91853943560247,0.6913491487503052,1.4228339195251465,30.759350276539493,3000.0,0.7081633806228638,1.335689663887024,30.60836370943533,3003.0,40348.93051624298,66873.91002559662,40348.93051624298,26519.51696491241,1.7374508380889893,0.0 -115600,0.70775235,1.395713,,,,,,,,,,,,,,,,, -115700,0.71292305,1.4600804,,,,,,,,,,,,,,,,, -115800,0.7036051,1.4642518,,,,,,,,,,,,,,,,, -115900,0.7606046,1.4470243,,,,,,,,,,,,,,,,, -116000,0.72451264,1.406619,,,,,,,,,,,,,,,,, -116100,0.6786765,1.4194391,,,,,,,,,,,,,,,,, -116200,0.70885754,1.3881533,,,,,,,,,,,,,,,,, -116300,0.72925615,1.4537951,,,,,,,,,,,,,,,,, -116400,0.72639805,1.4450791,,,,,,,,,,,,,,,,, -116500,0.7181903,1.4356349,,,,,,,,,,,,,,,,, -116600,0.72723323,1.4006625,,,,,,,,,,,,,,,,, -116700,0.7190925,1.3974087,,,,,,,,,,,,,,,,, -116800,0.712764,1.440148,,,,,,,,,,,,,,,,, -116900,0.717189,1.3940557,,,,,,,,,,,,,,,,, -117000,0.76837796,1.384648,,,,,,,,,,,,,,,,, -117100,0.720564,1.4175711,,,,,,,,,,,,,,,,, -117200,0.69467866,1.3826712,,,,,,,,,,,,,,,,, -117300,0.71600276,1.3253492,,,,,,,,,,,,,,,,, -117400,0.70223695,1.3762248,,,,,,,,,,,,,,,,, -117500,0.7397738,1.5015624,,,,,,,,,,,,,,,,, -117600,0.7209788,1.3509253,,,,,,,,,,,,,,,,, -117700,0.71817964,1.4143975,,,,,,,,,,,,,,,,, -117800,0.71786165,1.3461155,,,,,,,,,,,,,,,,, -117900,0.7174779,1.3947973,,,,,,,,,,,,,,,,, -117969,,,0.6944806575775146,1.42534601688385,35.57020902703697,0.691969096660614,1.4174968004226685,30.88002752465824,3000.0,0.7076637148857117,1.326730251312256,30.73994588003593,3003.0,41188.847472667694,68230.25590658188,41188.847472667694,27035.818858623505,1.779308557510376,0.0 -118000,0.74087363,1.4148184,,,,,,,,,,,,,,,,, -118100,0.72762674,1.4423689,,,,,,,,,,,,,,,,, -118200,0.7094288,1.381082,,,,,,,,,,,,,,,,, -118300,0.745118,1.4006466,,,,,,,,,,,,,,,,, -118400,0.73633325,1.403647,,,,,,,,,,,,,,,,, -118500,0.74189794,1.4116528,,,,,,,,,,,,,,,,, -118600,0.75077,1.4198115,,,,,,,,,,,,,,,,, -118700,0.77430433,1.4936327,,,,,,,,,,,,,,,,, -118800,0.72847164,1.4374365,,,,,,,,,,,,,,,,, -118900,0.7632991,1.3933138,,,,,,,,,,,,,,,,, -119000,0.7406414,1.4366771,,,,,,,,,,,,,,,,, -119100,0.7318324,1.372249,,,,,,,,,,,,,,,,, -119200,0.7473117,1.4080381,,,,,,,,,,,,,,,,, -119300,0.7359192,1.3963072,,,,,,,,,,,,,,,,, -119400,0.7886134,1.4429986,,,,,,,,,,,,,,,,, -119500,0.74973726,1.3990827,,,,,,,,,,,,,,,,, -119600,0.7286851,1.3328292,,,,,,,,,,,,,,,,, -119700,0.78349143,1.4805688,,,,,,,,,,,,,,,,, -119800,0.7202784,1.4265467,,,,,,,,,,,,,,,,, -119900,0.74300647,1.3676721,,,,,,,,,,,,,,,,, -120000,0.7377245,1.3393754,,,,,,,,,,,,,,,,, -120100,0.7326132,1.3729923,,,,,,,,,,,,,,,,, -120200,0.741316,1.4083724,,,,,,,,,,,,,,,,, -120300,0.73100346,1.3679904,,,,,,,,,,,,,,,,, -120376,,,0.7052478790283203,1.3589560985565186,36.04900926844258,0.6927626132965088,1.4140264987945557,31.06434313028884,3000.0,0.7080472111701965,1.3269232511520386,30.485523564052155,3003.0,42028.750698804855,69623.21995973587,42028.750698804855,27588.76014494896,1.8211784362792969,0.0 -120400,0.7549506,1.4046078,,,,,,,,,,,,,,,,, -120500,0.7544587,1.3282323,,,,,,,,,,,,,,,,, -120600,0.74301726,1.4157366,,,,,,,,,,,,,,,,, -120700,0.7359558,1.3738277,,,,,,,,,,,,,,,,, -120800,0.7733209,1.354308,,,,,,,,,,,,,,,,, -120900,0.7502174,1.3694663,,,,,,,,,,,,,,,,, -121000,0.7578606,1.3713706,,,,,,,,,,,,,,,,, -121100,0.75065714,1.4245282,,,,,,,,,,,,,,,,, -121200,0.77181846,1.3628994,,,,,,,,,,,,,,,,, -121300,0.75226843,1.4304358,,,,,,,,,,,,,,,,, -121400,0.75620943,1.331715,,,,,,,,,,,,,,,,, -121500,0.74834114,1.3629026,,,,,,,,,,,,,,,,, -121600,0.72828484,1.3488269,,,,,,,,,,,,,,,,, -121700,0.7638128,1.3930829,,,,,,,,,,,,,,,,, -121800,0.7819503,1.3692911,,,,,,,,,,,,,,,,, -121900,0.767092,1.408913,,,,,,,,,,,,,,,,, -122000,0.7737822,1.4116522,,,,,,,,,,,,,,,,, -122100,0.77943844,1.3462387,,,,,,,,,,,,,,,,, -122200,0.7683131,1.4036516,,,,,,,,,,,,,,,,, -122300,0.7732956,1.4107968,,,,,,,,,,,,,,,,, -122400,0.76604366,1.2925491,,,,,,,,,,,,,,,,, -122500,0.7937968,1.383333,,,,,,,,,,,,,,,,, -122600,0.77098227,1.3353714,,,,,,,,,,,,,,,,, -122700,0.7774964,1.4634329,,,,,,,,,,,,,,,,, -122788,,,0.7059203386306763,1.3507002592086792,36.08858036119556,0.6928866505622864,1.412611722946167,30.96063553398092,3000.0,0.709337055683136,1.32336163520813,30.99924664179277,3003.0,42868.7630045414,70992.43810915947,42868.7630045414,28117.84645557404,1.864339828491211,0.0 -122800,0.7795088,1.3704913,,,,,,,,,,,,,,,,, -122900,0.7909768,1.3573914,,,,,,,,,,,,,,,,, -123000,0.75041455,1.3875451,,,,,,,,,,,,,,,,, -123100,0.7744022,1.3727604,,,,,,,,,,,,,,,,, -123200,0.7747314,1.3562349,,,,,,,,,,,,,,,,, -123300,0.78913003,1.3980068,,,,,,,,,,,,,,,,, -123400,0.7627008,1.4039682,,,,,,,,,,,,,,,,, -123500,0.78942156,1.4145453,,,,,,,,,,,,,,,,, -123600,0.7808117,1.3650738,,,,,,,,,,,,,,,,, -123700,0.7666364,1.4153252,,,,,,,,,,,,,,,,, -123800,0.78447664,1.3406391,,,,,,,,,,,,,,,,, -123900,0.74167657,1.3813411,,,,,,,,,,,,,,,,, -124000,0.7630008,1.4210557,,,,,,,,,,,,,,,,, -124100,0.7363813,1.3493679,,,,,,,,,,,,,,,,, -124200,0.80169755,1.3639067,,,,,,,,,,,,,,,,, -124300,0.77820253,1.4465908,,,,,,,,,,,,,,,,, -124400,0.7786554,1.358296,,,,,,,,,,,,,,,,, -124500,0.7604148,1.4448124,,,,,,,,,,,,,,,,, -124600,0.8081956,1.3744938,,,,,,,,,,,,,,,,, -124700,0.7724111,1.3974622,,,,,,,,,,,,,,,,, -124800,0.78455687,1.361275,,,,,,,,,,,,,,,,, -124900,0.78824353,1.3955287,,,,,,,,,,,,,,,,, -125000,0.7571925,1.3589001,,,,,,,,,,,,,,,,, -125100,0.7928888,1.4199617,,,,,,,,,,,,,,,,, -125195,,,0.7114030718803406,1.329108476638794,37.05163518567466,0.6935189962387085,1.4125992059707642,30.95792946222888,3000.0,0.7094649076461792,1.3227872848510742,30.60444715357637,3003.0,43708.71370458603,72364.23915076256,43708.71370458603,28649.569895029068,1.914734125137329,0.0 -125200,0.7691015,1.301408,,,,,,,,,,,,,,,,, -125300,0.820955,1.4874547,,,,,,,,,,,,,,,,, -125400,0.7970905,1.4304489,,,,,,,,,,,,,,,,, -125500,0.77340627,1.3979864,,,,,,,,,,,,,,,,, -125600,0.7523049,1.3513373,,,,,,,,,,,,,,,,, -125700,0.7700405,1.3671895,,,,,,,,,,,,,,,,, -125800,0.7569825,1.3541695,,,,,,,,,,,,,,,,, -125900,0.7681018,1.3879207,,,,,,,,,,,,,,,,, -126000,0.7871219,1.4102142,,,,,,,,,,,,,,,,, -126100,0.768816,1.3799742,,,,,,,,,,,,,,,,, -126200,0.75989115,1.3408872,,,,,,,,,,,,,,,,, -126300,0.7732995,1.3486683,,,,,,,,,,,,,,,,, -126400,0.7783548,1.2511907,,,,,,,,,,,,,,,,, -126500,0.7821582,1.3698417,,,,,,,,,,,,,,,,, -126600,0.7813586,1.3713707,,,,,,,,,,,,,,,,, -126700,0.801824,1.4114214,,,,,,,,,,,,,,,,, -126800,0.79982984,1.4284004,,,,,,,,,,,,,,,,, -126900,0.7546176,1.3183708,,,,,,,,,,,,,,,,, -127000,0.7962172,1.3554131,,,,,,,,,,,,,,,,, -127100,0.7645401,1.3161952,,,,,,,,,,,,,,,,, -127200,0.7930383,1.4130552,,,,,,,,,,,,,,,,, -127300,0.80973935,1.4330142,,,,,,,,,,,,,,,,, -127400,0.81070673,1.3848505,,,,,,,,,,,,,,,,, -127500,0.79061997,1.3231883,,,,,,,,,,,,,,,,, -127600,0.8043832,1.3542225,,,,,,,,,,,,,,,,, -127603,,,0.7080180048942566,1.3448046445846558,36.59462651717983,0.6943497061729431,1.4095133543014526,30.97203620968157,3000.0,0.7101272344589233,1.321054458618164,30.63157748084628,3003.0,44548.620322942734,73723.40506076813,44548.620322942734,29168.70892763137,1.9580953121185305,0.0 -127700,0.7990934,1.3763273,,,,,,,,,,,,,,,,, -127800,0.78387445,1.3450882,,,,,,,,,,,,,,,,, -127900,0.7923431,1.3408202,,,,,,,,,,,,,,,,, -128000,0.78000194,1.4366211,,,,,,,,,,,,,,,,, -128100,0.83624536,1.3680657,,,,,,,,,,,,,,,,, -128200,0.77870286,1.3536423,,,,,,,,,,,,,,,,, -128300,0.7897099,1.3222772,,,,,,,,,,,,,,,,, -128400,0.76828986,1.3472373,,,,,,,,,,,,,,,,, -128500,0.78606504,1.294306,,,,,,,,,,,,,,,,, -128600,0.763068,1.3461003,,,,,,,,,,,,,,,,, -128700,0.75729066,1.2680268,,,,,,,,,,,,,,,,, -128800,0.78372234,1.3355137,,,,,,,,,,,,,,,,, -128900,0.81457794,1.3572738,,,,,,,,,,,,,,,,, -129000,0.80473447,1.3977319,,,,,,,,,,,,,,,,, -129100,0.759225,1.4238706,,,,,,,,,,,,,,,,, -129200,0.8011727,1.3721112,,,,,,,,,,,,,,,,, -129300,0.785493,1.3768222,,,,,,,,,,,,,,,,, -129400,0.774865,1.3758649,,,,,,,,,,,,,,,,, -129500,0.7990866,1.4096864,,,,,,,,,,,,,,,,, -129600,0.8044244,1.3568214,,,,,,,,,,,,,,,,, -129700,0.78031695,1.3324689,,,,,,,,,,,,,,,,, -129800,0.7699328,1.3399508,,,,,,,,,,,,,,,,, -129900,0.7981281,1.3799243,,,,,,,,,,,,,,,,, -130000,0.7994543,1.3585951,,,,,,,,,,,,,,,,, -130010,,,0.7112451791763306,1.3329684734344482,37.12121517508366,0.6941761374473572,1.4096252918243408,31.13628827962512,3000.0,0.7105107307434082,1.320068359375,30.66140440468964,3003.0,45388.499920129776,75089.05698800087,45388.499920129776,29694.35936927796,2.001633644104004,0.0 -130100,0.79297364,1.3530996,,,,,,,,,,,,,,,,, -130200,0.7910426,1.4002082,,,,,,,,,,,,,,,,, -130300,0.7918937,1.3425331,,,,,,,,,,,,,,,,, -130400,0.7540089,1.2560602,,,,,,,,,,,,,,,,, -130500,0.7938953,1.387018,,,,,,,,,,,,,,,,, -130600,0.791818,1.3671906,,,,,,,,,,,,,,,,, -130700,0.7719122,1.3765217,,,,,,,,,,,,,,,,, -130800,0.76202375,1.3272092,,,,,,,,,,,,,,,,, -130900,0.81499314,1.4670178,,,,,,,,,,,,,,,,, -131000,0.786973,1.3269256,,,,,,,,,,,,,,,,, -131100,0.8131984,1.3465655,,,,,,,,,,,,,,,,, -131200,0.7758406,1.3606677,,,,,,,,,,,,,,,,, -131300,0.769608,1.3983061,,,,,,,,,,,,,,,,, -131400,0.7763448,1.3221742,,,,,,,,,,,,,,,,, -131500,0.7806427,1.353149,,,,,,,,,,,,,,,,, -131600,0.7765436,1.3891078,,,,,,,,,,,,,,,,, -131700,0.80939525,1.3868902,,,,,,,,,,,,,,,,, -131800,0.7841853,1.4152086,,,,,,,,,,,,,,,,, -131900,0.78026325,1.3389902,,,,,,,,,,,,,,,,, -132000,0.7563223,1.3283676,,,,,,,,,,,,,,,,, -132100,0.79335576,1.3310044,,,,,,,,,,,,,,,,, -132200,0.75663745,1.351036,,,,,,,,,,,,,,,,, -132300,0.75954324,1.3089403,,,,,,,,,,,,,,,,, -132400,0.7920304,1.3243307,,,,,,,,,,,,,,,,, -132416,,,0.7100916504859924,1.3323860168457031,37.0542430438851,0.6940397620201111,1.4101505279541016,31.14853330741573,3000.0,0.7104874849319458,1.3203920125961304,30.71379910708937,3003.0,46228.42295074463,76468.07191419601,46228.42295074463,30233.320573568344,2.052457571029663,0.0 -132500,0.76163405,1.4078419,,,,,,,,,,,,,,,,, -132600,0.7986993,1.3827224,,,,,,,,,,,,,,,,, -132700,0.80353624,1.3635267,,,,,,,,,,,,,,,,, -132800,0.763233,1.3287084,,,,,,,,,,,,,,,,, -132900,0.76810807,1.3809853,,,,,,,,,,,,,,,,, -133000,0.7654377,1.3540769,,,,,,,,,,,,,,,,, -133100,0.76523185,1.3500909,,,,,,,,,,,,,,,,, -133200,0.7822426,1.2780553,,,,,,,,,,,,,,,,, -133300,0.8046751,1.4648685,,,,,,,,,,,,,,,,, -133333,,,0.7102746367454529,1.3364615440368652,37.00281637120834,0.6941141486167908,1.4101862907409668,31.14516992419973,3000.0,0.7104526162147522,1.3205291032791138,30.73949967608062,3003.0,46548.19679522514,77308.42940235138,46548.19679522514,30753.830042362213,2.09632420539856,0.0 -133333,,,,,,,,,,,,,,46548.19679522514,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 6181d8fca..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -871.6384983062744,0.0,27.327144145965576,1,0,27.327144145965576,0.0007088489946909,0.0,11.041826248168944,3003,898.9657063484192,0.0004938441561535,0.0,11.066301345825195,0.0004835649742744,0.0,11.036645889282228,3000 -1335.3256645202637,0.0222003459930419,867.5621616840363,2407,0,867.5621616840363,0.5244320631027222,17.176751782717933,2.647041082382202,3003,2202.989140987396,0.5285830497741699,22.527051695575885,2.639150857925415,0.5278669595718384,18.724227858741123,2.6120784282684326,3000 -1913.8280260562897,0.0496740341186523,1707.8006699085236,4815,0,1707.8006699085236,0.5975945591926575,22.466849339434948,2.0635406970977783,3003,3621.835510253906,0.5797721147537231,27.448449614774468,2.221475601196289,0.5930862426757812,23.89420124644968,2.0885164737701416,3000 -2400.76931643486,0.0744776725769043,2547.789595603943,7223,0,2547.789595603943,0.611388087272644,23.430505913385765,1.9671002626419067,3003,4948.868293046951,0.5924304723739624,28.30629469243196,2.1243929862976074,0.6062664985656738,24.714309696920097,2.0095314979553223,3000 -2852.701284646988,0.1004605293273925,3387.7001433372498,9631,0,3387.7001433372498,0.612155020236969,22.814568874859702,1.955398797988892,3003,6240.813716411591,0.5950818657875061,27.14673853695992,2.1058120727539062,0.6074815988540649,24.029534022545416,1.9885951280593872,3000 -3327.1690158843994,0.1277108192443847,4227.669346570969,12039,0,4227.669346570969,0.6214398145675659,23.91240416453053,1.895939826965332,3003,7555.354757547378,0.595694899559021,28.15082233737513,2.1055915355682373,0.612738847732544,24.70037868232737,1.944647431373596,3000 -3874.010820150376,0.155163288116455,5067.656465530396,14447,0,5067.656465530396,0.6236011981964111,24.321233521773724,1.880143523216248,3003,8942.286494970322,0.597169041633606,28.419443225761,2.078855276107788,0.615293025970459,25.01238163048505,1.933337569236756,3000 -4361.068868637085,0.181708812713623,5907.831691741943,16856,0,5907.831691741943,0.6249956488609314,24.49852003101253,1.88040554523468,3003,10269.621250391006,0.5945723056793213,28.754664656766057,2.086899518966675,0.6146854758262634,24.68478496541946,1.9335582256317136,3000 -4951.462516546249,0.2087688446044922,6748.037788152695,19264,0,6748.037788152695,0.6264830827713013,23.86853125338317,1.8554015159606927,3003,11700.324100255966,0.6116220355033875,28.895200180970697,1.968550443649292,0.6213685870170593,25.17453094857682,1.8958373069763184,3000 -5447.454230070114,0.2373852729797363,7588.130204200745,21673,0,7588.130204200745,0.629283607006073,24.44263277773216,1.8413125276565552,3003,13036.512573957443,0.6006941199302673,28.42460930168897,2.057503700256348,0.6233896613121033,25.493374681381614,1.8967119455337524,3000 -5961.189096927643,0.2700085639953613,8428.177636146545,24082,0,8428.177636146545,0.6273081302642822,23.7675562804176,1.846322655677796,3003,14390.40482354164,0.6017192602157593,27.65494665264454,2.061342716217041,0.6205998659133911,24.93630152938552,1.899630427360535,3000 -6543.995737552643,0.2978191375732422,9268.377697706224,26491,0,9268.377697706224,0.6311893463134766,24.47399505397761,1.8257577419281008,3003,15813.514918327332,0.6068245768547058,28.52083851534977,2.0195443630218506,0.6253611445426941,25.04661834767373,1.86841082572937,3000 -7099.407840251923,0.3309357166290283,10108.486330747604,28899,0,10108.486330747604,0.6313404440879822,24.2225975925976,1.8199553489685056,3003,17209.145723819733,0.6058840155601501,28.15505565929569,2.0256903171539307,0.6242700219154358,25.25882325875779,1.872414231300354,3000 -7554.298825263977,0.3605096340179443,10948.709831953049,31308,0,10948.709831953049,0.6303178071975708,24.27951532927107,1.821932077407837,3003,18504.36586952209,0.6748192310333252,32.83433125238866,1.52360999584198,0.6257950663566589,25.46695972247443,1.85858416557312,3000 -8083.440821886063,0.3903145790100097,11788.792954921722,33716,0,11788.792954921722,0.6331764459609985,24.776988322226952,1.7970689535140991,3003,19873.69760823249,0.606469988822937,28.84137114817753,2.019516468048096,0.627865731716156,25.98349001394717,1.842737078666687,3000 -8564.694946527481,0.4211187362670898,12628.945209980013,36125,0,12628.945209980013,0.6359886527061462,24.281479120561823,1.7861478328704834,3003,21195.21056318283,0.6092419028282166,28.618722834476483,2.0129339694976807,0.6275929808616638,25.35805133784494,1.8477492332458496,3000 -9122.370000362396,0.4585793018341064,13468.92739701271,38534,0,13468.92739701271,0.6339085698127747,24.38247752713862,1.7903285026550293,3003,22592.98215246201,0.6119536757469177,28.499661182879315,1.9667437076568604,0.6280888915061951,26.057433104129547,1.836946964263916,3000 -9651.330287218094,0.4973683357238769,14309.143985748293,40943,0,14309.143985748293,0.6396607160568237,24.5530504118486,1.7795579433441162,3003,23962.27396202088,0.6094376444816589,29.30273771814245,1.9945857524871824,0.6275309324264526,25.75543886358571,1.846041321754456,3000 -10196.409768819807,0.5284297466278076,15149.22114944458,43352,0,15149.22114944458,0.638893723487854,25.21839585129908,1.768803954124451,3003,25347.535893440247,0.6117132902145386,28.89464208910684,1.979415774345398,0.6324781775474548,26.215222262675244,1.8176498413085933,3000 -10797.866746902466,0.5607788562774658,15989.326357841492,45761,0,15989.326357841492,0.6413805484771729,25.37917327989407,1.7728618383407593,3003,26789.20568847656,0.6118826866149902,28.652251811486824,1.96248745918274,0.630605936050415,26.06368270027007,1.8224587440490725,3000 -11330.634264469149,0.5914275646209717,16829.423892498016,48173,0,16829.423892498016,0.6418104767799377,25.306392423383475,1.7519155740737915,3003,28162.17760157585,0.6136108636856079,29.17189039490762,1.967371225357056,0.6328873634338379,25.978765786257235,1.8119760751724243,3000 -11825.816824674606,0.6280605792999268,17669.362218618393,50581,0,17669.362218618393,0.6434489488601685,25.11578170365946,1.7368030548095703,3003,29497.41099047661,0.6235958337783813,29.89978511352561,1.867612361907959,0.6353052258491516,26.18907230106764,1.788408279418945,3000 -12396.352447509766,0.6613750457763672,18509.431342840195,52989,0,18509.431342840195,0.646528422832489,25.58466393499604,1.7195571660995483,3003,30908.126499176025,0.6194369792938232,29.233347669535043,1.9303407669067385,0.6357763409614563,26.26460212317322,1.7921721935272217,3000 -12955.963208198547,0.6942222118377686,19349.615591049194,55397,0,19349.615591049194,0.6478763818740845,25.67578430644092,1.710266351699829,3003,32308.030866384503,0.619412362575531,29.55223406319987,1.933494210243225,0.6385537385940552,26.649243759342824,1.7798618078231812,3000 -13580.615655899048,0.7280442714691162,20189.71707105637,57805,0,20189.71707105637,0.648561954498291,25.331366094109622,1.6985454559326172,3003,33772.894074201584,0.6222780346870422,29.94380117071459,1.898165822029113,0.6379957795143127,26.4487990260602,1.757357835769653,3000 -14125.097086429596,0.7669491767883301,21029.91014480591,60213,0,21029.91014480591,0.6507001519203186,26.085239423423623,1.698147892951965,3003,35157.684792757034,0.6185654997825623,29.45453204450172,1.9307162761688232,0.6402896642684937,26.69694660475321,1.7573803663253784,3000 -14679.639946699142,0.8013193607330322,21869.91243505478,62620,0,21869.91243505478,0.6520132422447205,25.953660552026207,1.678233027458191,3003,36552.34384179115,0.6608275771141052,33.19012386447994,1.6121519804000854,0.6453608870506287,26.712245705946223,1.729269027709961,3000 -15204.78667140007,0.8344008922576904,22709.923098564148,65027,0,22709.923098564148,0.6520829796791077,26.03838315165289,1.6709034442901611,3003,37917.61070728302,0.6231135129928589,30.1586215499912,1.8940107822418213,0.6435257792472839,26.89136604491987,1.7233736515045166,3000 -15805.35246348381,0.8699424266815186,23549.88878941536,67434,0,23549.88878941536,0.6572192311286926,26.566984573098747,1.64961040019989,3003,39358.25233221054,0.6210484504699707,30.423989758952263,1.9104690551757808,0.6472083330154419,26.998232384875,1.7089643478393557,3000 -16365.558494091034,0.9100024700164796,24390.009131908417,69842,0,24390.009131908417,0.6600197553634644,26.693438091255096,1.6326422691345217,3003,40758.694177389145,0.6296305060386658,30.28551842580955,1.8308836221694944,0.6483986377716064,27.293684458325167,1.700315237045288,3000 -16875.51067852974,0.9470679759979248,25230.24514484405,72250,0,25230.24514484405,0.6600894927978516,26.730419230948563,1.632683038711548,3003,42108.99506902695,0.6270090937614441,30.405617139420453,1.8659878969192505,0.6491177678108215,27.24253045501503,1.6966098546981812,3000 -17419.357144355774,0.9834208488464355,26070.14436650276,74657,0,26070.14436650276,0.6641218066215515,26.914830269027004,1.609385967254639,3003,43492.85367035866,0.6295642852783203,29.989682284103964,1.8524365425109863,0.6508660912513733,27.60271767408957,1.6791478395462036,3000 -17933.644248008728,1.0192360877990725,26910.25663924217,77065,0,26910.25663924217,0.6620184779167175,26.76849545725773,1.6064454317092896,3003,44847.3646748066,0.6358433365821838,30.153082714253948,1.804227352142334,0.6521183848381042,27.59902273365211,1.671108961105347,3000 -18513.77645015717,1.0562868118286133,27750.299865961075,79473,0,27750.299865961075,0.6640869379043579,26.96958389750291,1.592431664466858,3003,46267.65453505516,0.632390558719635,30.46417054723414,1.817828059196472,0.654263436794281,27.5063002215074,1.6578301191329956,3000 -19080.33007502556,1.0926096439361572,28590.505031108856,81881,0,28590.505031108856,0.6693974733352661,27.44548816720142,1.568246603012085,3003,47674.52781748772,0.6480662226676941,31.98326599336038,1.706380844116211,0.6571152210235596,27.902996214779897,1.6383646726608276,3000 -19643.68313574791,1.1292719841003418,29430.63443851471,84289,0,29430.63443851471,0.6694555878639221,27.313128787818524,1.5594232082366943,3003,49078.12288761139,0.6433086395263672,30.503595770525028,1.7546846866607666,0.6575739979743958,27.81755635340168,1.631597876548767,3000 -20193.09274697304,1.167518138885498,30270.717396259308,86697,0,30270.717396259308,0.6706408858299255,27.45172026614546,1.5513267517089844,3003,50467.73133611679,0.6414464712142944,30.919897634747738,1.7629821300506592,0.6594710350036621,27.64385552238076,1.6154887676239014,3000 -20896.967987298965,1.2072625160217283,31110.686940193176,89105,0,31110.686940193176,0.676021158695221,27.724933282657265,1.5245511531829834,3003,52011.69216299057,0.6479898691177368,31.870726073454623,1.7014189958572388,0.664678692817688,27.893390280404773,1.591239333152771,3000 -21431.26032066345,1.246953010559082,31950.802713632584,91513,0,31950.802713632584,0.6771948337554932,27.948103940076336,1.5109649896621704,3003,53386.21637392044,0.6436875462532043,31.34514009640573,1.7378581762313845,0.6646538972854614,28.456537818777825,1.5829776525497437,3000 -21988.401491642,1.291767120361328,32790.77405810356,93921,0,32790.77405810356,0.6811574101448059,28.463327942895543,1.4931126832962036,3003,54783.45065832138,0.6853712797164917,34.03844929654699,1.4687334299087524,0.6677164435386658,28.313632988196403,1.5667170286178589,3000 -22617.3301281929,1.331843376159668,33630.75802206993,96329,0,33630.75802206993,0.6832374930381775,28.521573655414183,1.4707233905792236,3003,56252.48037528992,0.6571729779243469,31.70680713304561,1.6591682434082031,0.6698614954948425,28.72916283015469,1.549560785293579,3000 -23148.10334300995,1.372218370437622,34470.64178228378,98737,0,34470.64178228378,0.6842019557952881,28.679556840476543,1.4612979888916016,3003,57623.255120038986,0.6514151692390442,31.84842778417848,1.690314531326294,0.6744739413261414,29.34685369574843,1.5323798656463623,3000 -23697.752576589584,1.4147746562957764,35310.656386613846,101145,0,35310.656386613846,0.688176155090332,29.02306537809828,1.4475120306015017,3003,59013.03903198242,0.6625736951828003,32.59588528405028,1.604433298110962,0.6739407777786255,29.29180447115428,1.5239664316177368,3000 -24358.2102060318,1.455638408660889,36150.72316431999,103553,0,36150.72316431999,0.6907559037208557,29.4018099388418,1.4269720315933228,3003,60513.682002067566,0.6571160554885864,32.57483756565918,1.6496434211730957,0.6794335842132568,29.71928640001056,1.5055320262908936,3000 -24949.888291597366,1.4942302703857422,36990.85962986946,105962,0,36990.85962986946,0.6929289698600769,29.46607487977802,1.4130771160125732,3003,61945.61144256592,0.6626680493354797,32.79830573550928,1.6218568086624146,0.6800907254219055,29.62048571366333,1.4952389001846311,3000 -25514.119074106216,1.5336592197418213,37831.08743548393,108370,0,37831.08743548393,0.6960200071334839,29.64580006144558,1.3971692323684692,3003,63350.19052934647,0.6737005710601807,33.575351495616545,1.5475784540176392,0.6833640933036804,29.811560649958405,1.479899525642395,3000 -26094.5704498291,1.57651948928833,38671.22283697128,110781,0,38671.22283697128,0.6975306868553162,29.50214132674396,1.3853771686553955,3003,64770.899523973465,0.6674370765686035,33.23845615206641,1.5858745574951172,0.6833516955375671,30.19010167590585,1.4708929061889648,3000 -26716.692593574524,1.61653470993042,39511.19406795502,113188,0,39511.19406795502,0.7002033591270447,29.95588310322253,1.369438409805298,3003,66233.1094198227,0.686324417591095,34.47618891337282,1.4715728759765625,0.6864266991615295,29.93864461495176,1.4545800685882568,3000 -27242.97893857956,1.6643104553222656,40351.23788332939,115596,0,40351.23788332939,0.7025042176246643,30.259197723267093,1.358932375907898,3003,67599.56405115128,0.6802150011062622,34.051214524478446,1.505571722984314,0.6871086359024048,30.445699454155303,1.4513825178146362,3000 -27726.58704996109,1.705122947692871,41191.43535208702,118005,0,41191.43535208702,0.7049096822738647,30.587941780081074,1.342985987663269,3003,68923.48609733582,0.680402934551239,34.02548684398224,1.5085806846618652,0.6894644498825073,30.400086035153965,1.437278389930725,3000 -28305.8653011322,1.7545392513275146,42031.62937164307,120413,0,42031.62937164307,0.7053279876708984,30.600486228597944,1.3383585214614868,3003,70343.08654594421,0.6904399991035461,35.09718083113306,1.4429056644439695,0.6901216506958008,30.587810108397143,1.4297069311141968,3000 -28883.33780145645,1.80470871925354,42871.6183693409,122820,0,42871.6183693409,0.7075591087341309,30.70963575294762,1.3309905529022217,3003,71760.67544698715,0.688624382019043,35.01808765515029,1.460661768913269,0.6914607286453247,30.7654477246006,1.4257858991622925,3000 -29431.05741333961,1.8475875854492188,43711.80949354172,125228,0,43711.80949354172,0.7094881534576416,30.738778424579746,1.3231710195541382,3003,73148.70480275154,0.699196994304657,36.03404700717277,1.396242618560791,0.6935189962387085,31.030971558360275,1.420162796974182,3000 -29983.50764322281,1.889930009841919,44551.95158267021,127636,0,44551.95158267021,0.7097554206848145,30.757323033480343,1.319035291671753,3003,74541.41567277908,0.696530818939209,35.25999172604652,1.4135314226150513,0.6937545537948608,31.00998771095735,1.4164506196975708,3000 -30523.580441236496,1.9321520328521729,45392.100281476974,130043,0,45392.100281476974,0.7098251581192017,30.876633593370595,1.3173716068267822,3003,75921.76205563545,0.6964461803436279,35.426351940087365,1.4212472438812256,0.693804144859314,30.97718072563005,1.414819836616516,3000 -31058.94911122322,1.9750447273254397,46232.19126367569,132451,0,46232.19126367569,0.7100808024406433,30.8854601382876,1.3164043426513672,3003,77297.3417544365,0.6961711049079895,35.54502780300605,1.419246792793274,0.6945357322692871,31.107725195668912,1.414249300956726,3000 -31597.15144467354,2.018925905227661,46539.62941074371,133333,0,46539.62941074371,0.7099761962890625,30.94145452506206,1.316562533378601,3003,78143.05393791199,0.6953193545341492,35.78710381048599,1.4255940914154053,0.6945977210998535,31.089424411746926,1.4143544435501099,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/measurements.csv deleted file mode 100644 index 871768dad..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.592714,11.052679,,,,,,,,,,,,,,,,, -1,,,0.0004938441561535,11.066301345825195,0.0,0.0004835649742744,11.036645889282228,0.0,3000.0,0.0007088489946909,11.041826248168944,0.0,3003.0,27.327144145965576,898.9657063484192,27.327144145965576,871.6384983062744,0.0,0.0 -100,0.5568858,7.5837264,,,,,,,,,,,,,,,,, -200,0.5368911,6.5762672,,,,,,,,,,,,,,,,, -300,0.58873624,5.885468,,,,,,,,,,,,,,,,, -400,0.38714546,5.3317547,,,,,,,,,,,,,,,,, -500,0.41671997,5.009959,,,,,,,,,,,,,,,,, -600,0.5747637,4.7823896,,,,,,,,,,,,,,,,, -700,0.5024608,4.458038,,,,,,,,,,,,,,,,, -800,0.49296826,4.0973625,,,,,,,,,,,,,,,,, -900,0.429829,3.9813647,,,,,,,,,,,,,,,,, -1000,0.46468723,3.8193138,,,,,,,,,,,,,,,,, -1100,0.47245887,3.650995,,,,,,,,,,,,,,,,, -1200,0.38255563,3.5871954,,,,,,,,,,,,,,,,, -1300,0.3342064,3.5784283,,,,,,,,,,,,,,,,, -1400,0.319643,3.3677516,,,,,,,,,,,,,,,,, -1500,0.38953033,3.277452,,,,,,,,,,,,,,,,, -1600,0.38463855,3.242651,,,,,,,,,,,,,,,,, -1700,0.26568553,3.1338124,,,,,,,,,,,,,,,,, -1800,0.26104116,2.9815557,,,,,,,,,,,,,,,,, -1900,0.31719986,2.9055722,,,,,,,,,,,,,,,,, -2000,0.2813896,2.8770254,,,,,,,,,,,,,,,,, -2100,0.18977995,2.7339077,,,,,,,,,,,,,,,,, -2200,0.25899264,2.7744927,,,,,,,,,,,,,,,,, -2300,0.1652533,2.6883833,,,,,,,,,,,,,,,,, -2400,0.23351912,2.6340175,,,,,,,,,,,,,,,,, -2407,,,0.5285830497741699,2.639150857925415,22.527051695575885,0.5278669595718384,2.6120784282684326,18.724227858741123,3000.0,0.5244320631027222,2.647041082382202,17.176751782717933,3003.0,867.5621616840363,2202.989140987396,867.5621616840363,1335.3256645202637,0.0222003459930419,0.0 -2500,0.27033052,2.5681543,,,,,,,,,,,,,,,,, -2600,0.19807485,2.560554,,,,,,,,,,,,,,,,, -2700,0.44046867,2.5488687,,,,,,,,,,,,,,,,, -2800,0.3491405,2.5016198,,,,,,,,,,,,,,,,, -2900,0.17676814,2.4908986,,,,,,,,,,,,,,,,, -3000,0.16483594,2.4348714,,,,,,,,,,,,,,,,, -3100,0.27830288,2.3002884,,,,,,,,,,,,,,,,, -3200,0.1953323,2.419587,,,,,,,,,,,,,,,,, -3300,0.38513997,2.3897295,,,,,,,,,,,,,,,,, -3400,0.27487683,2.309545,,,,,,,,,,,,,,,,, -3500,0.3310529,2.2602694,,,,,,,,,,,,,,,,, -3600,0.2103852,2.4380646,,,,,,,,,,,,,,,,, -3700,0.39016005,2.2759652,,,,,,,,,,,,,,,,, -3800,0.61311346,2.2972407,,,,,,,,,,,,,,,,, -3900,0.27262527,2.2486293,,,,,,,,,,,,,,,,, -4000,0.5079274,2.3426507,,,,,,,,,,,,,,,,, -4100,0.34731272,2.1207016,,,,,,,,,,,,,,,,, -4200,0.34101528,2.2494595,,,,,,,,,,,,,,,,, -4300,0.4792763,2.3391883,,,,,,,,,,,,,,,,, -4400,0.3504698,2.2398067,,,,,,,,,,,,,,,,, -4500,0.21986562,2.1759543,,,,,,,,,,,,,,,,, -4600,0.3578437,2.216097,,,,,,,,,,,,,,,,, -4700,0.49367574,2.1646028,,,,,,,,,,,,,,,,, -4800,0.70047826,2.2772093,,,,,,,,,,,,,,,,, -4815,,,0.5797721147537231,2.221475601196289,27.448449614774468,0.5930862426757812,2.0885164737701416,23.89420124644968,3000.0,0.5975945591926575,2.0635406970977783,22.466849339434948,3003.0,1707.8006699085236,3621.835510253906,1707.8006699085236,1913.8280260562897,0.0496740341186523,0.0 -4900,0.3508569,2.1731462,,,,,,,,,,,,,,,,, -5000,0.7233518,2.1798532,,,,,,,,,,,,,,,,, -5100,0.30662987,2.2407415,,,,,,,,,,,,,,,,, -5200,0.43273178,2.233076,,,,,,,,,,,,,,,,, -5300,0.27678424,2.2244568,,,,,,,,,,,,,,,,, -5400,0.33051252,2.2132463,,,,,,,,,,,,,,,,, -5500,0.3079934,2.1421914,,,,,,,,,,,,,,,,, -5600,0.3589503,2.147031,,,,,,,,,,,,,,,,, -5700,0.5118942,2.2392719,,,,,,,,,,,,,,,,, -5800,0.35469282,2.2012122,,,,,,,,,,,,,,,,, -5900,0.9474307,2.2321305,,,,,,,,,,,,,,,,, -6000,0.25832266,2.1787858,,,,,,,,,,,,,,,,, -6100,0.56095773,2.162343,,,,,,,,,,,,,,,,, -6200,0.30183896,2.181886,,,,,,,,,,,,,,,,, -6300,0.28615332,2.1259694,,,,,,,,,,,,,,,,, -6400,0.4830696,2.2217722,,,,,,,,,,,,,,,,, -6500,0.27430543,2.205361,,,,,,,,,,,,,,,,, -6600,0.48026696,2.1977737,,,,,,,,,,,,,,,,, -6700,0.297447,2.1585858,,,,,,,,,,,,,,,,, -6800,0.4979333,2.1672602,,,,,,,,,,,,,,,,, -6900,1.0662823,2.210566,,,,,,,,,,,,,,,,, -7000,0.34056845,2.1720657,,,,,,,,,,,,,,,,, -7100,0.44328785,2.145709,,,,,,,,,,,,,,,,, -7200,0.52781487,2.2174284,,,,,,,,,,,,,,,,, -7223,,,0.5924304723739624,2.1243929862976074,28.30629469243196,0.6062664985656738,2.0095314979553223,24.714309696920097,3000.0,0.611388087272644,1.9671002626419067,23.430505913385765,3003.0,2547.789595603943,4948.868293046951,2547.789595603943,2400.76931643486,0.0744776725769043,0.0 -7300,0.28763357,2.1244118,,,,,,,,,,,,,,,,, -7400,0.37343448,2.1397104,,,,,,,,,,,,,,,,, -7500,0.46572113,2.0899684,,,,,,,,,,,,,,,,, -7600,0.6492601,2.233908,,,,,,,,,,,,,,,,, -7700,0.42885008,2.0417717,,,,,,,,,,,,,,,,, -7800,0.39445135,2.0530772,,,,,,,,,,,,,,,,, -7900,0.50973046,2.0799437,,,,,,,,,,,,,,,,, -8000,0.73284125,2.2701936,,,,,,,,,,,,,,,,, -8100,0.76880074,2.0862522,,,,,,,,,,,,,,,,, -8200,0.68130094,2.1539767,,,,,,,,,,,,,,,,, -8300,0.2887853,2.1306617,,,,,,,,,,,,,,,,, -8400,0.30423966,2.1144154,,,,,,,,,,,,,,,,, -8500,0.29344806,2.1827998,,,,,,,,,,,,,,,,, -8600,0.30850968,2.1154552,,,,,,,,,,,,,,,,, -8700,0.45174554,2.1381645,,,,,,,,,,,,,,,,, -8800,0.33356345,2.2067785,,,,,,,,,,,,,,,,, -8900,0.29120478,2.0797963,,,,,,,,,,,,,,,,, -9000,0.66848344,2.1011326,,,,,,,,,,,,,,,,, -9100,0.4148424,2.1610053,,,,,,,,,,,,,,,,, -9200,0.45152995,2.231775,,,,,,,,,,,,,,,,, -9300,0.6099838,2.0691388,,,,,,,,,,,,,,,,, -9400,0.4108289,2.1768236,,,,,,,,,,,,,,,,, -9500,0.3975405,2.143816,,,,,,,,,,,,,,,,, -9600,0.8167134,2.0979707,,,,,,,,,,,,,,,,, -9631,,,0.5950818657875061,2.1058120727539062,27.14673853695992,0.6074815988540649,1.9885951280593872,24.029534022545416,3000.0,0.612155020236969,1.955398797988892,22.814568874859702,3003.0,3387.7001433372498,6240.813716411591,3387.7001433372498,2852.701284646988,0.1004605293273925,0.0 -9700,0.44896728,2.0774143,,,,,,,,,,,,,,,,, -9800,0.41799855,2.1563373,,,,,,,,,,,,,,,,, -9900,0.36820865,2.0852363,,,,,,,,,,,,,,,,, -10000,0.37699535,2.019713,,,,,,,,,,,,,,,,, -10100,0.34821793,2.0841835,,,,,,,,,,,,,,,,, -10200,0.4060187,2.1537735,,,,,,,,,,,,,,,,, -10300,0.2542715,1.9964145,,,,,,,,,,,,,,,,, -10400,0.28535348,2.0357435,,,,,,,,,,,,,,,,, -10500,0.5574662,2.1562648,,,,,,,,,,,,,,,,, -10600,0.3419455,2.0743322,,,,,,,,,,,,,,,,, -10700,0.4113638,2.1377797,,,,,,,,,,,,,,,,, -10800,0.35080084,2.1235235,,,,,,,,,,,,,,,,, -10900,0.26963243,2.0714989,,,,,,,,,,,,,,,,, -11000,0.9216164,2.1800852,,,,,,,,,,,,,,,,, -11100,0.5163356,2.0719023,,,,,,,,,,,,,,,,, -11200,0.6247187,2.117878,,,,,,,,,,,,,,,,, -11300,0.5330853,2.163258,,,,,,,,,,,,,,,,, -11400,0.4211602,2.0257056,,,,,,,,,,,,,,,,, -11500,0.2833814,2.1871197,,,,,,,,,,,,,,,,, -11600,0.33371022,2.042307,,,,,,,,,,,,,,,,, -11700,0.7167447,2.103484,,,,,,,,,,,,,,,,, -11800,0.46187028,2.0750568,,,,,,,,,,,,,,,,, -11900,0.73910815,2.1201236,,,,,,,,,,,,,,,,, -12000,0.960607,2.1638155,,,,,,,,,,,,,,,,, -12039,,,0.595694899559021,2.1055915355682373,28.15082233737513,0.612738847732544,1.944647431373596,24.70037868232737,3000.0,0.6214398145675659,1.895939826965332,23.91240416453053,3003.0,4227.669346570969,7555.354757547378,4227.669346570969,3327.1690158843994,0.1277108192443847,0.0 -12100,0.30449596,2.06593,,,,,,,,,,,,,,,,, -12200,0.40788665,2.0761843,,,,,,,,,,,,,,,,, -12300,0.25868803,2.1240518,,,,,,,,,,,,,,,,, -12400,0.49643844,2.0206761,,,,,,,,,,,,,,,,, -12500,0.96375203,2.143327,,,,,,,,,,,,,,,,, -12600,0.4322834,2.0602267,,,,,,,,,,,,,,,,, -12700,0.802134,2.0721784,,,,,,,,,,,,,,,,, -12800,0.60788655,2.2151203,,,,,,,,,,,,,,,,, -12900,0.4917611,2.0020669,,,,,,,,,,,,,,,,, -13000,0.7280117,2.1081085,,,,,,,,,,,,,,,,, -13100,0.68341583,2.1565533,,,,,,,,,,,,,,,,, -13200,0.38202623,2.1440508,,,,,,,,,,,,,,,,, -13300,0.5684739,2.1424787,,,,,,,,,,,,,,,,, -13400,0.4623737,2.1058207,,,,,,,,,,,,,,,,, -13500,0.49858564,2.1588404,,,,,,,,,,,,,,,,, -13600,0.45077735,2.0600498,,,,,,,,,,,,,,,,, -13700,0.565657,2.1007254,,,,,,,,,,,,,,,,, -13800,0.322098,2.0467858,,,,,,,,,,,,,,,,, -13900,0.39007184,2.059091,,,,,,,,,,,,,,,,, -14000,0.29066586,2.0963426,,,,,,,,,,,,,,,,, -14100,0.46301177,2.1164327,,,,,,,,,,,,,,,,, -14200,0.37842122,2.0071628,,,,,,,,,,,,,,,,, -14300,0.38848054,2.0460904,,,,,,,,,,,,,,,,, -14400,0.81640345,2.07407,,,,,,,,,,,,,,,,, -14447,,,0.597169041633606,2.078855276107788,28.419443225761,0.615293025970459,1.933337569236756,25.01238163048505,3000.0,0.6236011981964111,1.880143523216248,24.321233521773724,3003.0,5067.656465530396,8942.286494970322,5067.656465530396,3874.010820150376,0.155163288116455,0.0 -14500,0.3205256,2.0535586,,,,,,,,,,,,,,,,, -14600,0.31295186,2.1485069,,,,,,,,,,,,,,,,, -14700,0.707607,2.0521889,,,,,,,,,,,,,,,,, -14800,0.46874097,2.12484,,,,,,,,,,,,,,,,, -14900,0.36584738,2.0741751,,,,,,,,,,,,,,,,, -15000,0.54065055,2.1016,,,,,,,,,,,,,,,,, -15100,0.84730685,2.0822978,,,,,,,,,,,,,,,,, -15200,0.5077755,2.1033654,,,,,,,,,,,,,,,,, -15300,0.38587978,2.0979593,,,,,,,,,,,,,,,,, -15400,0.571299,2.1079004,,,,,,,,,,,,,,,,, -15500,0.7598442,2.0748844,,,,,,,,,,,,,,,,, -15600,0.53295666,2.1616054,,,,,,,,,,,,,,,,, -15700,0.44672891,2.1453114,,,,,,,,,,,,,,,,, -15800,0.42608428,2.0199363,,,,,,,,,,,,,,,,, -15900,0.3699812,2.1333857,,,,,,,,,,,,,,,,, -16000,0.45407492,1.9716694,,,,,,,,,,,,,,,,, -16100,0.30375186,2.1238008,,,,,,,,,,,,,,,,, -16200,0.8437269,2.0251405,,,,,,,,,,,,,,,,, -16300,0.2902003,2.0875628,,,,,,,,,,,,,,,,, -16400,0.49227777,2.0102563,,,,,,,,,,,,,,,,, -16500,0.9220041,2.0317938,,,,,,,,,,,,,,,,, -16600,0.5498446,2.0422723,,,,,,,,,,,,,,,,, -16700,0.52622676,2.01648,,,,,,,,,,,,,,,,, -16800,0.47463298,2.0785277,,,,,,,,,,,,,,,,, -16856,,,0.5945723056793213,2.086899518966675,28.754664656766057,0.6146854758262634,1.9335582256317136,24.68478496541946,3000.0,0.6249956488609314,1.88040554523468,24.49852003101253,3003.0,5907.831691741943,10269.621250391006,5907.831691741943,4361.068868637085,0.181708812713623,0.0 -16900,0.30445224,2.1418965,,,,,,,,,,,,,,,,, -17000,0.40227628,2.1858644,,,,,,,,,,,,,,,,, -17100,0.26719642,2.1180558,,,,,,,,,,,,,,,,, -17200,0.394153,2.0543742,,,,,,,,,,,,,,,,, -17300,0.27818838,2.007886,,,,,,,,,,,,,,,,, -17400,0.8876437,2.1201117,,,,,,,,,,,,,,,,, -17500,0.29516175,2.0300133,,,,,,,,,,,,,,,,, -17600,0.65024054,2.084795,,,,,,,,,,,,,,,,, -17700,0.76410365,2.054426,,,,,,,,,,,,,,,,, -17800,0.3637372,2.128349,,,,,,,,,,,,,,,,, -17900,0.4915328,1.9933088,,,,,,,,,,,,,,,,, -18000,0.8163965,2.0658565,,,,,,,,,,,,,,,,, -18100,0.39508525,1.9885173,,,,,,,,,,,,,,,,, -18200,0.5312223,2.0722842,,,,,,,,,,,,,,,,, -18300,0.38646394,2.1135495,,,,,,,,,,,,,,,,, -18400,0.29599664,2.0182047,,,,,,,,,,,,,,,,, -18500,0.3012822,2.0929239,,,,,,,,,,,,,,,,, -18600,0.5877527,2.0509956,,,,,,,,,,,,,,,,, -18700,0.284959,2.0606394,,,,,,,,,,,,,,,,, -18800,0.5594936,1.9859607,,,,,,,,,,,,,,,,, -18900,0.3337019,2.0732038,,,,,,,,,,,,,,,,, -19000,0.31415588,1.9661201,,,,,,,,,,,,,,,,, -19100,0.41134983,2.0826368,,,,,,,,,,,,,,,,, -19200,0.5532759,2.077475,,,,,,,,,,,,,,,,, -19264,,,0.6116220355033875,1.968550443649292,28.895200180970697,0.6213685870170593,1.8958373069763184,25.17453094857682,3000.0,0.6264830827713013,1.8554015159606927,23.86853125338317,3003.0,6748.037788152695,11700.324100255966,6748.037788152695,4951.462516546249,0.2087688446044922,0.0 -19300,0.4666625,1.9568262,,,,,,,,,,,,,,,,, -19400,0.5233101,2.0507138,,,,,,,,,,,,,,,,, -19500,0.8288173,2.0589566,,,,,,,,,,,,,,,,, -19600,0.35461357,2.090922,,,,,,,,,,,,,,,,, -19700,0.47974706,2.0773783,,,,,,,,,,,,,,,,, -19800,0.29827216,2.1088428,,,,,,,,,,,,,,,,, -19900,0.30317333,2.036993,,,,,,,,,,,,,,,,, -20000,0.59085613,2.0579906,,,,,,,,,,,,,,,,, -20100,0.32120275,2.025984,,,,,,,,,,,,,,,,, -20200,0.34024113,2.0561886,,,,,,,,,,,,,,,,, -20300,0.39620167,2.0209577,,,,,,,,,,,,,,,,, -20400,0.28967038,2.164527,,,,,,,,,,,,,,,,, -20500,0.5372959,2.1100771,,,,,,,,,,,,,,,,, -20600,0.34206358,2.1149466,,,,,,,,,,,,,,,,, -20700,0.39085254,2.059055,,,,,,,,,,,,,,,,, -20800,0.4473942,2.019779,,,,,,,,,,,,,,,,, -20900,0.3702512,1.9920902,,,,,,,,,,,,,,,,, -21000,0.6763721,2.1418574,,,,,,,,,,,,,,,,, -21100,0.26458135,2.058702,,,,,,,,,,,,,,,,, -21200,0.31531793,2.078815,,,,,,,,,,,,,,,,, -21300,0.53807515,2.0147862,,,,,,,,,,,,,,,,, -21400,0.29259977,2.1398675,,,,,,,,,,,,,,,,, -21500,0.36395174,1.9881469,,,,,,,,,,,,,,,,, -21600,0.811255,2.0478163,,,,,,,,,,,,,,,,, -21673,,,0.6006941199302673,2.057503700256348,28.42460930168897,0.6233896613121033,1.8967119455337524,25.493374681381614,3000.0,0.629283607006073,1.8413125276565552,24.44263277773216,3003.0,7588.130204200745,13036.512573957443,7588.130204200745,5447.454230070114,0.2373852729797363,0.0 -21700,0.70193404,2.0217843,,,,,,,,,,,,,,,,, -21800,0.4958037,2.1346836,,,,,,,,,,,,,,,,, -21900,0.42049408,2.0274022,,,,,,,,,,,,,,,,, -22000,0.5915885,2.0507677,,,,,,,,,,,,,,,,, -22100,0.73441875,2.1024568,,,,,,,,,,,,,,,,, -22200,0.91897374,2.0519457,,,,,,,,,,,,,,,,, -22300,0.44193882,2.0273705,,,,,,,,,,,,,,,,, -22400,0.38351095,2.1370785,,,,,,,,,,,,,,,,, -22500,0.40192217,2.0899465,,,,,,,,,,,,,,,,, -22600,0.45791176,2.049008,,,,,,,,,,,,,,,,, -22700,0.45417276,2.1132252,,,,,,,,,,,,,,,,, -22800,0.4892048,1.906649,,,,,,,,,,,,,,,,, -22900,0.3189847,2.0897834,,,,,,,,,,,,,,,,, -23000,0.3117706,2.084938,,,,,,,,,,,,,,,,, -23100,0.37798968,2.0710795,,,,,,,,,,,,,,,,, -23200,0.6355887,2.0770996,,,,,,,,,,,,,,,,, -23300,1.0608106,2.0072696,,,,,,,,,,,,,,,,, -23400,0.3648723,2.0668838,,,,,,,,,,,,,,,,, -23500,0.26168826,2.040763,,,,,,,,,,,,,,,,, -23600,0.40198922,2.1293592,,,,,,,,,,,,,,,,, -23700,0.39053327,2.035591,,,,,,,,,,,,,,,,, -23800,0.55466235,2.1134849,,,,,,,,,,,,,,,,, -23900,0.33280507,2.049361,,,,,,,,,,,,,,,,, -24000,0.29418424,2.0290709,,,,,,,,,,,,,,,,, -24082,,,0.6017192602157593,2.061342716217041,27.65494665264454,0.6205998659133911,1.899630427360535,24.93630152938552,3000.0,0.6273081302642822,1.846322655677796,23.7675562804176,3003.0,8428.177636146545,14390.40482354164,8428.177636146545,5961.189096927643,0.2700085639953613,0.0 -24100,0.28810495,2.130411,,,,,,,,,,,,,,,,, -24200,0.41536567,2.0607097,,,,,,,,,,,,,,,,, -24300,0.3352812,2.2043507,,,,,,,,,,,,,,,,, -24400,0.842476,2.1174638,,,,,,,,,,,,,,,,, -24500,0.5774828,2.0477664,,,,,,,,,,,,,,,,, -24600,0.28139922,2.0164106,,,,,,,,,,,,,,,,, -24700,0.41777605,1.9625436,,,,,,,,,,,,,,,,, -24800,0.36293635,2.069205,,,,,,,,,,,,,,,,, -24900,0.64577484,2.0858488,,,,,,,,,,,,,,,,, -25000,0.783368,2.0732527,,,,,,,,,,,,,,,,, -25100,0.4552804,2.060017,,,,,,,,,,,,,,,,, -25200,0.4915104,2.0481389,,,,,,,,,,,,,,,,, -25300,0.56705946,2.104895,,,,,,,,,,,,,,,,, -25400,0.35410804,2.037148,,,,,,,,,,,,,,,,, -25500,0.32638955,1.939823,,,,,,,,,,,,,,,,, -25600,0.4288313,2.0613515,,,,,,,,,,,,,,,,, -25700,0.42937434,2.0394855,,,,,,,,,,,,,,,,, -25800,0.604427,1.9467846,,,,,,,,,,,,,,,,, -25900,0.44810793,2.0925663,,,,,,,,,,,,,,,,, -26000,0.38114658,1.8986593,,,,,,,,,,,,,,,,, -26100,0.30042446,2.100378,,,,,,,,,,,,,,,,, -26200,0.55314875,2.1064968,,,,,,,,,,,,,,,,, -26300,0.36457816,2.0378435,,,,,,,,,,,,,,,,, -26400,0.45976505,2.1859713,,,,,,,,,,,,,,,,, -26491,,,0.6068245768547058,2.0195443630218506,28.52083851534977,0.6253611445426941,1.86841082572937,25.04661834767373,3000.0,0.6311893463134766,1.8257577419281008,24.47399505397761,3003.0,9268.377697706224,15813.514918327332,9268.377697706224,6543.995737552643,0.2978191375732422,0.0 -26500,0.44888318,2.0955637,,,,,,,,,,,,,,,,, -26600,0.52217436,2.0615916,,,,,,,,,,,,,,,,, -26700,0.5144492,2.0364158,,,,,,,,,,,,,,,,, -26800,0.7969588,2.1112816,,,,,,,,,,,,,,,,, -26900,0.5174717,2.054411,,,,,,,,,,,,,,,,, -27000,0.9241155,2.002475,,,,,,,,,,,,,,,,, -27100,0.4055177,2.077184,,,,,,,,,,,,,,,,, -27200,0.6563166,2.0256898,,,,,,,,,,,,,,,,, -27300,0.6308397,2.0449739,,,,,,,,,,,,,,,,, -27400,0.34400493,1.9638631,,,,,,,,,,,,,,,,, -27500,0.5104184,1.9895353,,,,,,,,,,,,,,,,, -27600,0.35023868,2.0813525,,,,,,,,,,,,,,,,, -27700,0.26867837,2.1055148,,,,,,,,,,,,,,,,, -27800,0.55876565,2.0833414,,,,,,,,,,,,,,,,, -27900,0.53896874,2.025988,,,,,,,,,,,,,,,,, -28000,0.33726594,2.0229547,,,,,,,,,,,,,,,,, -28100,0.5506268,1.998261,,,,,,,,,,,,,,,,, -28200,0.36800477,2.0496335,,,,,,,,,,,,,,,,, -28300,0.34038317,1.9950262,,,,,,,,,,,,,,,,, -28400,0.34651443,2.0234756,,,,,,,,,,,,,,,,, -28500,0.3220793,2.0663843,,,,,,,,,,,,,,,,, -28600,0.28214332,2.070642,,,,,,,,,,,,,,,,, -28700,0.3075565,2.0234942,,,,,,,,,,,,,,,,, -28800,0.80785686,2.1084027,,,,,,,,,,,,,,,,, -28899,,,0.6058840155601501,2.0256903171539307,28.15505565929569,0.6242700219154358,1.872414231300354,25.25882325875779,3000.0,0.6313404440879822,1.8199553489685056,24.2225975925976,3003.0,10108.486330747604,17209.145723819733,10108.486330747604,7099.407840251923,0.3309357166290283,0.0 -28900,0.45448896,2.1192474,,,,,,,,,,,,,,,,, -29000,0.28847474,2.0049732,,,,,,,,,,,,,,,,, -29100,0.31512883,2.0712712,,,,,,,,,,,,,,,,, -29200,0.64007896,2.0706012,,,,,,,,,,,,,,,,, -29300,0.42767864,2.1211932,,,,,,,,,,,,,,,,, -29400,0.47993052,2.058997,,,,,,,,,,,,,,,,, -29500,0.30944315,2.0737085,,,,,,,,,,,,,,,,, -29600,0.4649076,2.0135098,,,,,,,,,,,,,,,,, -29700,0.37287894,2.030322,,,,,,,,,,,,,,,,, -29800,0.26479685,2.005491,,,,,,,,,,,,,,,,, -29900,0.31986603,2.0488853,,,,,,,,,,,,,,,,, -30000,0.5425778,2.0581439,,,,,,,,,,,,,,,,, -30100,0.28998724,2.0751228,,,,,,,,,,,,,,,,, -30200,0.40894267,2.0401542,,,,,,,,,,,,,,,,, -30300,0.31909975,2.0443451,,,,,,,,,,,,,,,,, -30400,0.32613438,1.9754572,,,,,,,,,,,,,,,,, -30500,0.31611395,2.0673666,,,,,,,,,,,,,,,,, -30600,0.2839024,1.9311346,,,,,,,,,,,,,,,,, -30700,0.30536216,2.0360408,,,,,,,,,,,,,,,,, -30800,0.45079795,1.967058,,,,,,,,,,,,,,,,, -30900,0.3662199,2.0180972,,,,,,,,,,,,,,,,, -31000,0.45302147,2.0434632,,,,,,,,,,,,,,,,, -31100,0.40631732,2.110486,,,,,,,,,,,,,,,,, -31200,0.34150487,1.9794397,,,,,,,,,,,,,,,,, -31300,0.39098957,2.0721216,,,,,,,,,,,,,,,,, -31308,,,0.6748192310333252,1.52360999584198,32.83433125238866,0.6257950663566589,1.85858416557312,25.46695972247443,3000.0,0.6303178071975708,1.821932077407837,24.27951532927107,3003.0,10948.709831953049,18504.36586952209,10948.709831953049,7554.298825263977,0.3605096340179443,0.0 -31400,0.4353416,2.0422318,,,,,,,,,,,,,,,,, -31500,0.36676612,2.0521202,,,,,,,,,,,,,,,,, -31600,0.55142385,2.0228426,,,,,,,,,,,,,,,,, -31700,0.46087652,2.0782907,,,,,,,,,,,,,,,,, -31800,0.35112765,1.9517945,,,,,,,,,,,,,,,,, -31900,0.4506957,2.0643494,,,,,,,,,,,,,,,,, -32000,0.35024145,2.0181997,,,,,,,,,,,,,,,,, -32100,0.28551707,2.0674179,,,,,,,,,,,,,,,,, -32200,0.30298078,2.0366611,,,,,,,,,,,,,,,,, -32300,0.42466426,2.0762212,,,,,,,,,,,,,,,,, -32400,0.31491834,2.0976477,,,,,,,,,,,,,,,,, -32500,0.45081362,2.0198805,,,,,,,,,,,,,,,,, -32600,0.27617812,2.0272107,,,,,,,,,,,,,,,,, -32700,0.3045046,2.0235217,,,,,,,,,,,,,,,,, -32800,0.27389705,2.0729604,,,,,,,,,,,,,,,,, -32900,0.32629448,2.0234165,,,,,,,,,,,,,,,,, -33000,0.3636141,1.9996521,,,,,,,,,,,,,,,,, -33100,0.3003295,2.0395243,,,,,,,,,,,,,,,,, -33200,0.32769525,2.0706363,,,,,,,,,,,,,,,,, -33300,0.31874815,2.0013673,,,,,,,,,,,,,,,,, -33400,0.40147638,2.068129,,,,,,,,,,,,,,,,, -33500,0.34041494,1.9921232,,,,,,,,,,,,,,,,, -33600,0.38314152,2.0835238,,,,,,,,,,,,,,,,, -33700,0.5536237,2.0897589,,,,,,,,,,,,,,,,, -33716,,,0.606469988822937,2.019516468048096,28.84137114817753,0.627865731716156,1.842737078666687,25.98349001394717,3000.0,0.6331764459609985,1.7970689535140991,24.776988322226952,3003.0,11788.792954921722,19873.69760823249,11788.792954921722,8083.440821886063,0.3903145790100097,0.0 -33800,0.2945097,1.9748361,,,,,,,,,,,,,,,,, -33900,0.50949603,1.9365953,,,,,,,,,,,,,,,,, -34000,0.3167077,1.998706,,,,,,,,,,,,,,,,, -34100,0.30458415,2.0110006,,,,,,,,,,,,,,,,, -34200,0.34617737,2.0315647,,,,,,,,,,,,,,,,, -34300,0.46477646,1.9726156,,,,,,,,,,,,,,,,, -34400,0.28304663,1.9895821,,,,,,,,,,,,,,,,, -34500,0.48806286,1.948722,,,,,,,,,,,,,,,,, -34600,0.32187158,2.0045788,,,,,,,,,,,,,,,,, -34700,0.39860386,2.0908153,,,,,,,,,,,,,,,,, -34800,0.31929296,1.9658779,,,,,,,,,,,,,,,,, -34900,0.32658148,2.0386107,,,,,,,,,,,,,,,,, -35000,0.51859283,1.9889096,,,,,,,,,,,,,,,,, -35100,0.35506666,1.9998294,,,,,,,,,,,,,,,,, -35200,0.4513901,1.9404349,,,,,,,,,,,,,,,,, -35300,0.5214799,2.0098424,,,,,,,,,,,,,,,,, -35400,0.25373688,1.9908073,,,,,,,,,,,,,,,,, -35500,0.30677357,2.0275323,,,,,,,,,,,,,,,,, -35600,0.3339768,2.0029743,,,,,,,,,,,,,,,,, -35700,0.3711779,2.0518966,,,,,,,,,,,,,,,,, -35800,0.4073715,2.0282092,,,,,,,,,,,,,,,,, -35900,0.3724468,2.0577986,,,,,,,,,,,,,,,,, -36000,0.51990485,1.9571891,,,,,,,,,,,,,,,,, -36100,0.38653588,1.9256413,,,,,,,,,,,,,,,,, -36125,,,0.6092419028282166,2.0129339694976807,28.618722834476483,0.6275929808616638,1.8477492332458496,25.35805133784494,3000.0,0.6359886527061462,1.7861478328704834,24.281479120561823,3003.0,12628.945209980013,21195.21056318283,12628.945209980013,8564.694946527481,0.4211187362670898,0.0 -36200,0.25851986,2.0276206,,,,,,,,,,,,,,,,, -36300,0.3434311,1.9240577,,,,,,,,,,,,,,,,, -36400,0.3193281,2.0772896,,,,,,,,,,,,,,,,, -36500,0.42542815,2.0198178,,,,,,,,,,,,,,,,, -36600,0.3299604,1.9574742,,,,,,,,,,,,,,,,, -36700,0.34762645,1.9171349,,,,,,,,,,,,,,,,, -36800,0.393789,1.9577217,,,,,,,,,,,,,,,,, -36900,0.44651648,2.0861855,,,,,,,,,,,,,,,,, -37000,0.35455492,2.025791,,,,,,,,,,,,,,,,, -37100,0.27507538,1.9489889,,,,,,,,,,,,,,,,, -37200,0.4381337,2.0145912,,,,,,,,,,,,,,,,, -37300,0.35721576,2.0736995,,,,,,,,,,,,,,,,, -37400,0.5610544,2.1355767,,,,,,,,,,,,,,,,, -37500,0.29208094,2.0555336,,,,,,,,,,,,,,,,, -37600,0.31616643,2.0853128,,,,,,,,,,,,,,,,, -37700,0.5285646,2.0915053,,,,,,,,,,,,,,,,, -37800,0.25521427,1.9182751,,,,,,,,,,,,,,,,, -37900,0.31119356,2.0309548,,,,,,,,,,,,,,,,, -38000,0.48171327,2.0131035,,,,,,,,,,,,,,,,, -38100,0.3762746,2.0356472,,,,,,,,,,,,,,,,, -38200,0.4725952,1.9912379,,,,,,,,,,,,,,,,, -38300,0.29236493,1.9877053,,,,,,,,,,,,,,,,, -38400,0.2570975,2.0143464,,,,,,,,,,,,,,,,, -38500,0.38686606,2.0396605,,,,,,,,,,,,,,,,, -38534,,,0.6119536757469177,1.9667437076568604,28.499661182879315,0.6280888915061951,1.836946964263916,26.057433104129547,3000.0,0.6339085698127747,1.7903285026550293,24.38247752713862,3003.0,13468.92739701271,22592.98215246201,13468.92739701271,9122.370000362396,0.4585793018341064,0.0 -38600,0.26310387,1.9836868,,,,,,,,,,,,,,,,, -38700,0.26909456,1.9930178,,,,,,,,,,,,,,,,, -38800,0.39793053,1.9843732,,,,,,,,,,,,,,,,, -38900,0.5273247,2.0112884,,,,,,,,,,,,,,,,, -39000,0.2994263,2.037094,,,,,,,,,,,,,,,,, -39100,0.2895738,2.0237823,,,,,,,,,,,,,,,,, -39200,0.24452864,1.8898463,,,,,,,,,,,,,,,,, -39300,0.772744,2.0513904,,,,,,,,,,,,,,,,, -39400,0.45697626,1.9380387,,,,,,,,,,,,,,,,, -39500,0.35397717,2.0378003,,,,,,,,,,,,,,,,, -39600,0.3449047,1.9129565,,,,,,,,,,,,,,,,, -39700,0.38708845,2.0005724,,,,,,,,,,,,,,,,, -39800,0.3362185,1.9824668,,,,,,,,,,,,,,,,, -39900,0.34021395,1.9537283,,,,,,,,,,,,,,,,, -40000,0.4051396,2.0063064,,,,,,,,,,,,,,,,, -40100,0.5505054,1.9561762,,,,,,,,,,,,,,,,, -40200,0.3268964,2.0436375,,,,,,,,,,,,,,,,, -40300,0.25800022,1.9602065,,,,,,,,,,,,,,,,, -40400,0.4133746,2.026115,,,,,,,,,,,,,,,,, -40500,0.4093181,2.023493,,,,,,,,,,,,,,,,, -40600,0.4162189,1.8768117,,,,,,,,,,,,,,,,, -40700,0.6378062,2.0037005,,,,,,,,,,,,,,,,, -40800,0.32143468,2.0004709,,,,,,,,,,,,,,,,, -40900,0.42243722,2.0803685,,,,,,,,,,,,,,,,, -40943,,,0.6094376444816589,1.9945857524871824,29.30273771814245,0.6275309324264526,1.846041321754456,25.75543886358571,3000.0,0.6396607160568237,1.7795579433441162,24.5530504118486,3003.0,14309.143985748293,23962.27396202088,14309.143985748293,9651.330287218094,0.4973683357238769,0.0 -41000,0.43458724,2.0414307,,,,,,,,,,,,,,,,, -41100,0.32287517,2.0337183,,,,,,,,,,,,,,,,, -41200,0.30407283,1.9401907,,,,,,,,,,,,,,,,, -41300,0.31598872,2.0164025,,,,,,,,,,,,,,,,, -41400,0.2678852,1.9856175,,,,,,,,,,,,,,,,, -41500,0.28706476,1.9702617,,,,,,,,,,,,,,,,, -41600,0.37810785,2.0141456,,,,,,,,,,,,,,,,, -41700,0.48474464,2.0135329,,,,,,,,,,,,,,,,, -41800,0.29156357,2.0566885,,,,,,,,,,,,,,,,, -41900,0.38197532,1.9416474,,,,,,,,,,,,,,,,, -42000,0.5250528,2.0556161,,,,,,,,,,,,,,,,, -42100,0.53408563,2.0056183,,,,,,,,,,,,,,,,, -42200,0.33503884,1.9721723,,,,,,,,,,,,,,,,, -42300,0.3082784,1.9396857,,,,,,,,,,,,,,,,, -42400,0.33220527,1.9138234,,,,,,,,,,,,,,,,, -42500,0.2960083,1.969977,,,,,,,,,,,,,,,,, -42600,0.31599143,1.9241109,,,,,,,,,,,,,,,,, -42700,0.37551793,2.0082958,,,,,,,,,,,,,,,,, -42800,0.3018578,1.9105808,,,,,,,,,,,,,,,,, -42900,0.28167406,1.9039721,,,,,,,,,,,,,,,,, -43000,0.2671977,1.9976836,,,,,,,,,,,,,,,,, -43100,0.27597266,2.01455,,,,,,,,,,,,,,,,, -43200,0.45857173,2.012333,,,,,,,,,,,,,,,,, -43300,0.34199113,1.9244847,,,,,,,,,,,,,,,,, -43352,,,0.6117132902145386,1.979415774345398,28.89464208910684,0.6324781775474548,1.8176498413085933,26.215222262675244,3000.0,0.638893723487854,1.768803954124451,25.21839585129908,3003.0,15149.22114944458,25347.535893440247,15149.22114944458,10196.409768819807,0.5284297466278076,0.0 -43400,0.75466686,2.047881,,,,,,,,,,,,,,,,, -43500,0.4506662,1.9573522,,,,,,,,,,,,,,,,, -43600,0.40457904,1.9545099,,,,,,,,,,,,,,,,, -43700,0.44932657,2.035326,,,,,,,,,,,,,,,,, -43800,0.41010723,2.0375206,,,,,,,,,,,,,,,,, -43900,0.36063412,1.9859893,,,,,,,,,,,,,,,,, -44000,0.38696787,1.9974312,,,,,,,,,,,,,,,,, -44100,0.35617605,1.923154,,,,,,,,,,,,,,,,, -44200,0.30096757,1.993921,,,,,,,,,,,,,,,,, -44300,0.67615694,1.9719087,,,,,,,,,,,,,,,,, -44400,0.53382564,1.9328511,,,,,,,,,,,,,,,,, -44500,0.29274186,2.026511,,,,,,,,,,,,,,,,, -44600,0.40656888,1.8962067,,,,,,,,,,,,,,,,, -44700,0.52006704,1.9787445,,,,,,,,,,,,,,,,, -44800,0.4911208,1.9976356,,,,,,,,,,,,,,,,, -44900,0.43341795,2.04247,,,,,,,,,,,,,,,,, -45000,0.3704042,1.9628898,,,,,,,,,,,,,,,,, -45100,0.4613437,1.9776942,,,,,,,,,,,,,,,,, -45200,0.46992192,1.9609542,,,,,,,,,,,,,,,,, -45300,0.48603654,2.0715783,,,,,,,,,,,,,,,,, -45400,0.47398695,1.9934773,,,,,,,,,,,,,,,,, -45500,0.42625922,1.996652,,,,,,,,,,,,,,,,, -45600,0.4380374,1.9911896,,,,,,,,,,,,,,,,, -45700,0.37585744,1.9496212,,,,,,,,,,,,,,,,, -45761,,,0.6118826866149902,1.96248745918274,28.652251811486824,0.630605936050415,1.8224587440490725,26.06368270027007,3000.0,0.6413805484771729,1.7728618383407593,25.37917327989407,3003.0,15989.326357841492,26789.20568847656,15989.326357841492,10797.866746902466,0.5607788562774658,0.0 -45800,0.35371307,1.914222,,,,,,,,,,,,,,,,, -45900,0.35925096,1.9760851,,,,,,,,,,,,,,,,, -46000,0.5486693,1.9503608,,,,,,,,,,,,,,,,, -46100,0.3084904,2.022204,,,,,,,,,,,,,,,,, -46200,0.33559853,2.025622,,,,,,,,,,,,,,,,, -46300,0.29880273,1.9557025,,,,,,,,,,,,,,,,, -46400,0.44609144,1.9180475,,,,,,,,,,,,,,,,, -46500,0.36559704,1.9494851,,,,,,,,,,,,,,,,, -46600,0.3425939,1.8997549,,,,,,,,,,,,,,,,, -46700,0.3638333,2.0839157,,,,,,,,,,,,,,,,, -46800,0.5106582,1.8944803,,,,,,,,,,,,,,,,, -46900,0.24517334,1.9300046,,,,,,,,,,,,,,,,, -47000,0.4586835,1.9949836,,,,,,,,,,,,,,,,, -47100,0.2362375,1.9312487,,,,,,,,,,,,,,,,, -47200,0.3832106,1.8877746,,,,,,,,,,,,,,,,, -47300,0.24992071,1.8945144,,,,,,,,,,,,,,,,, -47400,0.31469184,1.9458758,,,,,,,,,,,,,,,,, -47500,0.46739638,1.9517106,,,,,,,,,,,,,,,,, -47600,0.32885382,1.9599191,,,,,,,,,,,,,,,,, -47700,0.3916777,2.0217037,,,,,,,,,,,,,,,,, -47800,0.4194945,1.8983611,,,,,,,,,,,,,,,,, -47900,0.34412733,1.9590453,,,,,,,,,,,,,,,,, -48000,0.29886925,1.9047072,,,,,,,,,,,,,,,,, -48100,0.38015938,2.005688,,,,,,,,,,,,,,,,, -48173,,,0.6136108636856079,1.967371225357056,29.17189039490762,0.6328873634338379,1.8119760751724243,25.978765786257235,3000.0,0.6418104767799377,1.7519155740737915,25.306392423383475,3003.0,16829.423892498016,28162.17760157585,16829.423892498016,11330.634264469149,0.5914275646209717,0.0 -48200,0.28914332,1.9701551,,,,,,,,,,,,,,,,, -48300,0.30623484,1.9905869,,,,,,,,,,,,,,,,, -48400,0.40732896,1.9513631,,,,,,,,,,,,,,,,, -48500,0.44422686,1.9088644,,,,,,,,,,,,,,,,, -48600,0.29646426,2.0174892,,,,,,,,,,,,,,,,, -48700,0.4583073,1.9471679,,,,,,,,,,,,,,,,, -48800,0.34008804,1.9059805,,,,,,,,,,,,,,,,, -48900,0.36596328,2.057743,,,,,,,,,,,,,,,,, -49000,0.42493758,1.971727,,,,,,,,,,,,,,,,, -49100,0.31557763,1.9600724,,,,,,,,,,,,,,,,, -49200,0.31223932,2.0061502,,,,,,,,,,,,,,,,, -49300,0.34561795,1.9845302,,,,,,,,,,,,,,,,, -49400,0.38932708,1.9397794,,,,,,,,,,,,,,,,, -49500,0.30508092,2.0403671,,,,,,,,,,,,,,,,, -49600,0.27110004,1.9830114,,,,,,,,,,,,,,,,, -49700,0.5402114,1.9147037,,,,,,,,,,,,,,,,, -49800,0.45906726,1.9456098,,,,,,,,,,,,,,,,, -49900,0.35755664,2.042567,,,,,,,,,,,,,,,,, -50000,0.3456086,1.9638046,,,,,,,,,,,,,,,,, -50100,0.42355657,1.9948267,,,,,,,,,,,,,,,,, -50200,0.3302988,2.0555105,,,,,,,,,,,,,,,,, -50300,0.37695765,1.8748201,,,,,,,,,,,,,,,,, -50400,0.42057055,1.935918,,,,,,,,,,,,,,,,, -50500,0.6159601,1.9769136,,,,,,,,,,,,,,,,, -50581,,,0.6235958337783813,1.867612361907959,29.89978511352561,0.6353052258491516,1.788408279418945,26.18907230106764,3000.0,0.6434489488601685,1.7368030548095703,25.11578170365946,3003.0,17669.362218618393,29497.41099047661,17669.362218618393,11825.816824674606,0.6280605792999268,0.0 -50600,0.7072126,1.9637115,,,,,,,,,,,,,,,,, -50700,0.6078873,2.0659494,,,,,,,,,,,,,,,,, -50800,0.3126931,1.8893884,,,,,,,,,,,,,,,,, -50900,0.44463164,1.9753262,,,,,,,,,,,,,,,,, -51000,0.26928052,1.9566464,,,,,,,,,,,,,,,,, -51100,0.38357344,1.933418,,,,,,,,,,,,,,,,, -51200,0.39335537,2.0143301,,,,,,,,,,,,,,,,, -51300,0.56987345,2.0044312,,,,,,,,,,,,,,,,, -51400,0.31673095,2.0255451,,,,,,,,,,,,,,,,, -51500,0.31277218,1.8353424,,,,,,,,,,,,,,,,, -51600,0.5758838,1.9964378,,,,,,,,,,,,,,,,, -51700,0.30731827,1.928967,,,,,,,,,,,,,,,,, -51800,0.28676808,2.0182762,,,,,,,,,,,,,,,,, -51900,0.31517646,1.9992045,,,,,,,,,,,,,,,,, -52000,0.41248283,1.9849471,,,,,,,,,,,,,,,,, -52100,0.4608593,1.9803934,,,,,,,,,,,,,,,,, -52200,0.3469726,1.912009,,,,,,,,,,,,,,,,, -52300,0.3116895,2.0165725,,,,,,,,,,,,,,,,, -52400,0.46655256,1.9623051,,,,,,,,,,,,,,,,, -52500,0.39313048,1.9244794,,,,,,,,,,,,,,,,, -52600,0.33468264,1.9202586,,,,,,,,,,,,,,,,, -52700,0.25591862,1.9780906,,,,,,,,,,,,,,,,, -52800,0.470908,1.9626449,,,,,,,,,,,,,,,,, -52900,0.5867512,1.9097912,,,,,,,,,,,,,,,,, -52989,,,0.6194369792938232,1.9303407669067385,29.233347669535043,0.6357763409614563,1.7921721935272217,26.26460212317322,3000.0,0.646528422832489,1.7195571660995483,25.58466393499604,3003.0,18509.431342840195,30908.126499176025,18509.431342840195,12396.352447509766,0.6613750457763672,0.0 -53000,0.6089669,1.98782,,,,,,,,,,,,,,,,, -53100,0.4081439,1.9648312,,,,,,,,,,,,,,,,, -53200,0.43686652,1.9390616,,,,,,,,,,,,,,,,, -53300,0.36294416,1.8975431,,,,,,,,,,,,,,,,, -53400,0.4232914,1.9087087,,,,,,,,,,,,,,,,, -53500,0.35912532,1.9531451,,,,,,,,,,,,,,,,, -53600,0.31686178,1.9452281,,,,,,,,,,,,,,,,, -53700,0.27736998,1.855325,,,,,,,,,,,,,,,,, -53800,0.2519554,1.925637,,,,,,,,,,,,,,,,, -53900,0.3728635,1.9271834,,,,,,,,,,,,,,,,, -54000,0.44245034,1.9575342,,,,,,,,,,,,,,,,, -54100,0.34031194,1.8786643,,,,,,,,,,,,,,,,, -54200,0.596092,2.035073,,,,,,,,,,,,,,,,, -54300,0.36971802,1.8556744,,,,,,,,,,,,,,,,, -54400,0.5355311,1.8852916,,,,,,,,,,,,,,,,, -54500,0.28841564,1.8886443,,,,,,,,,,,,,,,,, -54600,0.3748787,1.9728023,,,,,,,,,,,,,,,,, -54700,0.25914785,1.8816557,,,,,,,,,,,,,,,,, -54800,0.36583775,1.9271462,,,,,,,,,,,,,,,,, -54900,0.35414937,1.9195511,,,,,,,,,,,,,,,,, -55000,0.3327881,1.9043577,,,,,,,,,,,,,,,,, -55100,0.27245805,1.9015409,,,,,,,,,,,,,,,,, -55200,0.7146201,1.8721739,,,,,,,,,,,,,,,,, -55300,0.34536007,2.0085924,,,,,,,,,,,,,,,,, -55397,,,0.619412362575531,1.933494210243225,29.55223406319987,0.6385537385940552,1.7798618078231812,26.649243759342824,3000.0,0.6478763818740845,1.710266351699829,25.67578430644092,3003.0,19349.615591049194,32308.030866384503,19349.615591049194,12955.963208198547,0.6942222118377686,0.0 -55400,0.32373062,1.9666405,,,,,,,,,,,,,,,,, -55500,0.2862501,1.9518071,,,,,,,,,,,,,,,,, -55600,0.5661685,1.9918414,,,,,,,,,,,,,,,,, -55700,0.41843426,1.980531,,,,,,,,,,,,,,,,, -55800,0.24551791,1.9571676,,,,,,,,,,,,,,,,, -55900,0.29742423,1.8694862,,,,,,,,,,,,,,,,, -56000,0.29019526,1.8742775,,,,,,,,,,,,,,,,, -56100,0.32337973,1.8574067,,,,,,,,,,,,,,,,, -56200,0.27901378,1.9624594,,,,,,,,,,,,,,,,, -56300,0.6162966,1.9358573,,,,,,,,,,,,,,,,, -56400,0.3137431,1.972236,,,,,,,,,,,,,,,,, -56500,0.32818198,1.8828256,,,,,,,,,,,,,,,,, -56600,0.511447,2.0187533,,,,,,,,,,,,,,,,, -56700,0.3571033,1.8458849,,,,,,,,,,,,,,,,, -56800,0.41570106,1.9986115,,,,,,,,,,,,,,,,, -56900,0.60489494,1.9004443,,,,,,,,,,,,,,,,, -57000,0.56871045,1.9360703,,,,,,,,,,,,,,,,, -57100,0.43441227,1.9148828,,,,,,,,,,,,,,,,, -57200,0.325924,1.957181,,,,,,,,,,,,,,,,, -57300,0.37744015,2.03044,,,,,,,,,,,,,,,,, -57400,0.5313022,1.9203582,,,,,,,,,,,,,,,,, -57500,0.25744247,1.8724335,,,,,,,,,,,,,,,,, -57600,0.28234065,1.8465575,,,,,,,,,,,,,,,,, -57700,0.3309802,1.9400727,,,,,,,,,,,,,,,,, -57800,0.474722,1.9078264,,,,,,,,,,,,,,,,, -57805,,,0.6222780346870422,1.898165822029113,29.94380117071459,0.6379957795143127,1.757357835769653,26.4487990260602,3000.0,0.648561954498291,1.6985454559326172,25.331366094109622,3003.0,20189.71707105637,33772.894074201584,20189.71707105637,13580.615655899048,0.7280442714691162,0.0 -57900,0.5338463,1.898549,,,,,,,,,,,,,,,,, -58000,0.28384373,1.8686824,,,,,,,,,,,,,,,,, -58100,0.27749977,1.8228127,,,,,,,,,,,,,,,,, -58200,0.32304993,1.8478316,,,,,,,,,,,,,,,,, -58300,0.5320528,1.9087855,,,,,,,,,,,,,,,,, -58400,0.50569654,1.8856231,,,,,,,,,,,,,,,,, -58500,0.35368234,1.8535296,,,,,,,,,,,,,,,,, -58600,0.26622516,1.916913,,,,,,,,,,,,,,,,, -58700,0.34056264,1.9385682,,,,,,,,,,,,,,,,, -58800,0.3205434,1.950305,,,,,,,,,,,,,,,,, -58900,0.5761269,1.9509244,,,,,,,,,,,,,,,,, -59000,0.40598366,1.9124676,,,,,,,,,,,,,,,,, -59100,0.26222152,1.9257176,,,,,,,,,,,,,,,,, -59200,0.33881545,1.9283209,,,,,,,,,,,,,,,,, -59300,0.32813355,1.9264517,,,,,,,,,,,,,,,,, -59400,0.4931291,1.9561758,,,,,,,,,,,,,,,,, -59500,0.34406206,1.9122036,,,,,,,,,,,,,,,,, -59600,0.36046797,1.8917333,,,,,,,,,,,,,,,,, -59700,0.39930186,1.8569998,,,,,,,,,,,,,,,,, -59800,0.3252545,1.8977652,,,,,,,,,,,,,,,,, -59900,0.4077991,1.7955627,,,,,,,,,,,,,,,,, -60000,0.2573264,1.8592175,,,,,,,,,,,,,,,,, -60100,0.42482126,1.9327742,,,,,,,,,,,,,,,,, -60200,0.45459682,1.9400375,,,,,,,,,,,,,,,,, -60213,,,0.6185654997825623,1.9307162761688232,29.45453204450172,0.6402896642684937,1.7573803663253784,26.69694660475321,3000.0,0.6507001519203186,1.698147892951965,26.085239423423623,3003.0,21029.91014480591,35157.684792757034,21029.91014480591,14125.097086429596,0.7669491767883301,0.0 -60300,0.4266125,1.9214952,,,,,,,,,,,,,,,,, -60400,0.28692514,1.8794212,,,,,,,,,,,,,,,,, -60500,0.27300116,1.8510869,,,,,,,,,,,,,,,,, -60600,0.4211005,1.9292332,,,,,,,,,,,,,,,,, -60700,0.35086533,1.8638057,,,,,,,,,,,,,,,,, -60800,0.59508544,1.9707371,,,,,,,,,,,,,,,,, -60900,0.32293752,2.0053396,,,,,,,,,,,,,,,,, -61000,0.60775214,1.9516059,,,,,,,,,,,,,,,,, -61100,0.32479015,1.8827295,,,,,,,,,,,,,,,,, -61200,0.27356657,1.8969383,,,,,,,,,,,,,,,,, -61300,0.5388412,1.8868425,,,,,,,,,,,,,,,,, -61400,0.44324034,1.9073102,,,,,,,,,,,,,,,,, -61500,0.29228505,1.8853929,,,,,,,,,,,,,,,,, -61600,0.34070608,1.8310529,,,,,,,,,,,,,,,,, -61700,0.3058391,1.8726425,,,,,,,,,,,,,,,,, -61800,0.32010666,1.934503,,,,,,,,,,,,,,,,, -61900,0.26500088,1.8420908,,,,,,,,,,,,,,,,, -62000,0.31630015,2.0212288,,,,,,,,,,,,,,,,, -62100,0.36712316,1.814225,,,,,,,,,,,,,,,,, -62200,0.29601347,1.8963836,,,,,,,,,,,,,,,,, -62300,0.34681088,1.8876479,,,,,,,,,,,,,,,,, -62400,0.4903071,1.8736485,,,,,,,,,,,,,,,,, -62500,0.2768277,1.9231571,,,,,,,,,,,,,,,,, -62600,0.32689783,1.9152057,,,,,,,,,,,,,,,,, -62620,,,0.6608275771141052,1.6121519804000854,33.19012386447994,0.6453608870506287,1.729269027709961,26.712245705946223,3000.0,0.6520132422447205,1.678233027458191,25.953660552026207,3003.0,21869.91243505478,36552.34384179115,21869.91243505478,14679.639946699142,0.8013193607330322,0.0 -62700,0.29168707,1.921433,,,,,,,,,,,,,,,,, -62800,0.3076653,1.8573215,,,,,,,,,,,,,,,,, -62900,0.29082632,1.8907386,,,,,,,,,,,,,,,,, -63000,0.39156055,1.9278072,,,,,,,,,,,,,,,,, -63100,0.40827414,1.8576405,,,,,,,,,,,,,,,,, -63200,0.32318643,1.8025507,,,,,,,,,,,,,,,,, -63300,0.28364643,1.8832418,,,,,,,,,,,,,,,,, -63400,0.36219814,1.9398649,,,,,,,,,,,,,,,,, -63500,0.42528364,1.9490229,,,,,,,,,,,,,,,,, -63600,0.33505982,1.9556246,,,,,,,,,,,,,,,,, -63700,0.35501662,1.9260242,,,,,,,,,,,,,,,,, -63800,0.3707243,1.9344059,,,,,,,,,,,,,,,,, -63900,0.40306273,1.9048697,,,,,,,,,,,,,,,,, -64000,0.27844182,1.9174715,,,,,,,,,,,,,,,,, -64100,0.404547,1.7772752,,,,,,,,,,,,,,,,, -64200,0.31042436,1.9937015,,,,,,,,,,,,,,,,, -64300,0.2708568,1.9015293,,,,,,,,,,,,,,,,, -64400,0.49312842,2.0332084,,,,,,,,,,,,,,,,, -64500,0.28762576,1.9351567,,,,,,,,,,,,,,,,, -64600,0.37318873,1.9273869,,,,,,,,,,,,,,,,, -64700,0.340431,1.8146399,,,,,,,,,,,,,,,,, -64800,0.3534751,1.9249945,,,,,,,,,,,,,,,,, -64900,0.37763616,1.9281086,,,,,,,,,,,,,,,,, -65000,0.38980293,1.9156966,,,,,,,,,,,,,,,,, -65027,,,0.6231135129928589,1.8940107822418213,30.1586215499912,0.6435257792472839,1.7233736515045166,26.89136604491987,3000.0,0.6520829796791077,1.6709034442901611,26.03838315165289,3003.0,22709.923098564148,37917.61070728302,22709.923098564148,15204.78667140007,0.8344008922576904,0.0 -65100,0.37278768,1.8724788,,,,,,,,,,,,,,,,, -65200,0.30483896,1.8998238,,,,,,,,,,,,,,,,, -65300,0.29226544,1.7613583,,,,,,,,,,,,,,,,, -65400,0.33159694,1.986451,,,,,,,,,,,,,,,,, -65500,0.4064213,1.9291952,,,,,,,,,,,,,,,,, -65600,0.4728929,1.9323473,,,,,,,,,,,,,,,,, -65700,0.29071635,1.813556,,,,,,,,,,,,,,,,, -65800,0.31736392,1.8654717,,,,,,,,,,,,,,,,, -65900,0.28428417,1.8888013,,,,,,,,,,,,,,,,, -66000,0.4022087,1.8190012,,,,,,,,,,,,,,,,, -66100,0.27828062,1.8566275,,,,,,,,,,,,,,,,, -66200,0.28057775,1.9082277,,,,,,,,,,,,,,,,, -66300,0.36056387,1.8397729,,,,,,,,,,,,,,,,, -66400,0.3182611,1.8364866,,,,,,,,,,,,,,,,, -66500,0.35419774,1.897539,,,,,,,,,,,,,,,,, -66600,0.3867357,1.9521544,,,,,,,,,,,,,,,,, -66700,0.29908517,1.8897444,,,,,,,,,,,,,,,,, -66800,0.32012522,1.9545768,,,,,,,,,,,,,,,,, -66900,0.2845203,1.8049134,,,,,,,,,,,,,,,,, -67000,0.30684483,1.9291545,,,,,,,,,,,,,,,,, -67100,0.2845805,1.9251125,,,,,,,,,,,,,,,,, -67200,0.31041914,1.8449873,,,,,,,,,,,,,,,,, -67300,0.35150516,1.82513,,,,,,,,,,,,,,,,, -67400,0.57237023,1.8816888,,,,,,,,,,,,,,,,, -67434,,,0.6210484504699707,1.9104690551757808,30.423989758952263,0.6472083330154419,1.7089643478393557,26.998232384875,3000.0,0.6572192311286926,1.64961040019989,26.566984573098747,3003.0,23549.88878941536,39358.25233221054,23549.88878941536,15805.35246348381,0.8699424266815186,0.0 -67500,0.35827315,1.8354429,,,,,,,,,,,,,,,,, -67600,0.29089463,1.8431164,,,,,,,,,,,,,,,,, -67700,0.2760997,1.8290336,,,,,,,,,,,,,,,,, -67800,0.5493253,1.8764541,,,,,,,,,,,,,,,,, -67900,0.26057512,1.9475161,,,,,,,,,,,,,,,,, -68000,0.36414787,1.7675346,,,,,,,,,,,,,,,,, -68100,0.35705072,1.910333,,,,,,,,,,,,,,,,, -68200,0.3488842,2.0290344,,,,,,,,,,,,,,,,, -68300,0.3880862,1.8547854,,,,,,,,,,,,,,,,, -68400,0.27787733,1.8123741,,,,,,,,,,,,,,,,, -68500,0.2618153,1.8636382,,,,,,,,,,,,,,,,, -68600,0.323034,1.8660707,,,,,,,,,,,,,,,,, -68700,0.44750777,1.8655916,,,,,,,,,,,,,,,,, -68800,0.32020068,1.9765347,,,,,,,,,,,,,,,,, -68900,0.3249175,1.8448746,,,,,,,,,,,,,,,,, -69000,0.39720416,1.8047976,,,,,,,,,,,,,,,,, -69100,0.3243894,1.930573,,,,,,,,,,,,,,,,, -69200,0.29935753,1.8505756,,,,,,,,,,,,,,,,, -69300,0.49435198,1.838206,,,,,,,,,,,,,,,,, -69400,0.48317346,1.8742902,,,,,,,,,,,,,,,,, -69500,0.27225232,1.8839597,,,,,,,,,,,,,,,,, -69600,0.390964,1.7982489,,,,,,,,,,,,,,,,, -69700,0.31202996,1.9737837,,,,,,,,,,,,,,,,, -69800,0.31608602,1.81654,,,,,,,,,,,,,,,,, -69842,,,0.6296305060386658,1.8308836221694944,30.28551842580955,0.6483986377716064,1.700315237045288,27.293684458325167,3000.0,0.6600197553634644,1.6326422691345217,26.693438091255096,3003.0,24390.009131908417,40758.694177389145,24390.009131908417,16365.558494091034,0.9100024700164796,0.0 -69900,0.2954894,1.9048682,,,,,,,,,,,,,,,,, -70000,0.31152698,1.7724963,,,,,,,,,,,,,,,,, -70100,0.26933756,1.8237797,,,,,,,,,,,,,,,,, -70200,0.47633505,1.8337485,,,,,,,,,,,,,,,,, -70300,0.30871382,1.8260533,,,,,,,,,,,,,,,,, -70400,0.3100776,1.809497,,,,,,,,,,,,,,,,, -70500,0.3325689,1.930739,,,,,,,,,,,,,,,,, -70600,0.3864382,1.8003975,,,,,,,,,,,,,,,,, -70700,0.42618728,1.8279922,,,,,,,,,,,,,,,,, -70800,0.48642012,1.8684798,,,,,,,,,,,,,,,,, -70900,0.32240435,1.8492124,,,,,,,,,,,,,,,,, -71000,0.31154278,1.8249154,,,,,,,,,,,,,,,,, -71100,0.28245795,1.8427918,,,,,,,,,,,,,,,,, -71200,0.2747695,1.8632259,,,,,,,,,,,,,,,,, -71300,0.32362625,1.8115691,,,,,,,,,,,,,,,,, -71400,0.30684432,1.8704811,,,,,,,,,,,,,,,,, -71500,0.3088099,1.8758813,,,,,,,,,,,,,,,,, -71600,0.2558977,1.8325642,,,,,,,,,,,,,,,,, -71700,0.34465465,1.8247573,,,,,,,,,,,,,,,,, -71800,0.27959386,1.843686,,,,,,,,,,,,,,,,, -71900,0.28987625,1.7294655,,,,,,,,,,,,,,,,, -72000,0.38157687,1.8932192,,,,,,,,,,,,,,,,, -72100,0.3333294,1.8282024,,,,,,,,,,,,,,,,, -72200,0.4086513,1.8840251,,,,,,,,,,,,,,,,, -72250,,,0.6270090937614441,1.8659878969192505,30.405617139420453,0.6491177678108215,1.6966098546981812,27.24253045501503,3000.0,0.6600894927978516,1.632683038711548,26.730419230948563,3003.0,25230.24514484405,42108.99506902695,25230.24514484405,16875.51067852974,0.9470679759979248,0.0 -72300,0.28642532,1.7988677,,,,,,,,,,,,,,,,, -72400,0.33379558,1.8198845,,,,,,,,,,,,,,,,, -72500,0.41106778,1.8227893,,,,,,,,,,,,,,,,, -72600,0.5408815,1.7667149,,,,,,,,,,,,,,,,, -72700,0.3502402,1.8417307,,,,,,,,,,,,,,,,, -72800,0.29619032,1.8717608,,,,,,,,,,,,,,,,, -72900,0.3471884,1.8534845,,,,,,,,,,,,,,,,, -73000,0.2961872,1.8632306,,,,,,,,,,,,,,,,, -73100,0.30389592,1.7986474,,,,,,,,,,,,,,,,, -73200,0.3754091,1.8407693,,,,,,,,,,,,,,,,, -73300,0.29485205,1.8344389,,,,,,,,,,,,,,,,, -73400,0.30264425,1.8969485,,,,,,,,,,,,,,,,, -73500,0.3468156,1.7886026,,,,,,,,,,,,,,,,, -73600,0.33643028,1.7759796,,,,,,,,,,,,,,,,, -73700,0.29174364,1.8613213,,,,,,,,,,,,,,,,, -73800,0.40127066,1.8867983,,,,,,,,,,,,,,,,, -73900,0.2668563,1.8490872,,,,,,,,,,,,,,,,, -74000,0.34231824,1.8816899,,,,,,,,,,,,,,,,, -74100,0.29241684,1.8307459,,,,,,,,,,,,,,,,, -74200,0.32933182,1.8396236,,,,,,,,,,,,,,,,, -74300,0.34359527,1.8536801,,,,,,,,,,,,,,,,, -74400,0.35086805,1.8165619,,,,,,,,,,,,,,,,, -74500,0.2634549,1.8229995,,,,,,,,,,,,,,,,, -74600,0.2740073,1.7521775,,,,,,,,,,,,,,,,, -74657,,,0.6295642852783203,1.8524365425109863,29.989682284103964,0.6508660912513733,1.6791478395462036,27.60271767408957,3000.0,0.6641218066215515,1.609385967254639,26.914830269027004,3003.0,26070.14436650276,43492.85367035866,26070.14436650276,17419.357144355774,0.9834208488464355,0.0 -74700,0.30430317,1.8482988,,,,,,,,,,,,,,,,, -74800,0.28471833,1.8828799,,,,,,,,,,,,,,,,, -74900,0.2966164,1.7657589,,,,,,,,,,,,,,,,, -75000,0.32721436,1.8788275,,,,,,,,,,,,,,,,, -75100,0.43537888,1.7966641,,,,,,,,,,,,,,,,, -75200,0.31204408,1.784964,,,,,,,,,,,,,,,,, -75300,0.46674874,1.8238378,,,,,,,,,,,,,,,,, -75400,0.3252587,1.7802668,,,,,,,,,,,,,,,,, -75500,0.30530038,1.7787378,,,,,,,,,,,,,,,,, -75600,0.29614663,1.848141,,,,,,,,,,,,,,,,, -75700,0.32629102,1.7643234,,,,,,,,,,,,,,,,, -75800,0.31027594,1.821905,,,,,,,,,,,,,,,,, -75900,0.3075956,1.9109327,,,,,,,,,,,,,,,,, -76000,0.3683462,1.8845218,,,,,,,,,,,,,,,,, -76100,0.28374368,1.7283262,,,,,,,,,,,,,,,,, -76200,0.37877634,1.8052698,,,,,,,,,,,,,,,,, -76300,0.56784976,1.8659228,,,,,,,,,,,,,,,,, -76400,0.42727473,1.9285705,,,,,,,,,,,,,,,,, -76500,0.279582,1.787581,,,,,,,,,,,,,,,,, -76600,0.35148138,1.8095493,,,,,,,,,,,,,,,,, -76700,0.36840814,1.7875981,,,,,,,,,,,,,,,,, -76800,0.3071573,1.8322666,,,,,,,,,,,,,,,,, -76900,0.38345495,1.888303,,,,,,,,,,,,,,,,, -77000,0.2948453,1.78976,,,,,,,,,,,,,,,,, -77065,,,0.6358433365821838,1.804227352142334,30.153082714253948,0.6521183848381042,1.671108961105347,27.59902273365211,3000.0,0.6620184779167175,1.6064454317092896,26.76849545725773,3003.0,26910.25663924217,44847.3646748066,26910.25663924217,17933.644248008728,1.0192360877990725,0.0 -77100,0.2803314,1.8232337,,,,,,,,,,,,,,,,, -77200,0.30927202,1.8273331,,,,,,,,,,,,,,,,, -77300,0.42092428,1.8205371,,,,,,,,,,,,,,,,, -77400,0.27792066,1.8329294,,,,,,,,,,,,,,,,, -77500,0.4829127,1.7819359,,,,,,,,,,,,,,,,, -77600,0.30092946,1.7294412,,,,,,,,,,,,,,,,, -77700,0.2756002,1.7872512,,,,,,,,,,,,,,,,, -77800,0.41819626,1.8182467,,,,,,,,,,,,,,,,, -77900,0.47020522,1.9539036,,,,,,,,,,,,,,,,, -78000,0.3503591,1.881822,,,,,,,,,,,,,,,,, -78100,0.33154523,1.7961193,,,,,,,,,,,,,,,,, -78200,0.33953762,1.822401,,,,,,,,,,,,,,,,, -78300,0.29929584,1.8336843,,,,,,,,,,,,,,,,, -78400,0.3632681,1.8295289,,,,,,,,,,,,,,,,, -78500,0.42168662,1.8678228,,,,,,,,,,,,,,,,, -78600,0.3199654,1.8634088,,,,,,,,,,,,,,,,, -78700,0.27380094,1.8250699,,,,,,,,,,,,,,,,, -78800,0.35213593,1.8089038,,,,,,,,,,,,,,,,, -78900,0.33617958,1.7951496,,,,,,,,,,,,,,,,, -79000,0.31138518,1.8075114,,,,,,,,,,,,,,,,, -79100,0.27893725,1.7945416,,,,,,,,,,,,,,,,, -79200,0.27789047,1.7558879,,,,,,,,,,,,,,,,, -79300,0.3554429,1.7694944,,,,,,,,,,,,,,,,, -79400,0.29290447,1.7714385,,,,,,,,,,,,,,,,, -79473,,,0.632390558719635,1.817828059196472,30.46417054723414,0.654263436794281,1.6578301191329956,27.5063002215074,3000.0,0.6640869379043579,1.592431664466858,26.96958389750291,3003.0,27750.299865961075,46267.65453505516,27750.299865961075,18513.77645015717,1.0562868118286133,0.0 -79500,0.29530844,1.7954289,,,,,,,,,,,,,,,,, -79600,0.3274993,1.8015599,,,,,,,,,,,,,,,,, -79700,0.33346683,1.8895564,,,,,,,,,,,,,,,,, -79800,0.42393914,1.884333,,,,,,,,,,,,,,,,, -79900,0.30174533,1.7984021,,,,,,,,,,,,,,,,, -80000,0.29749465,1.7464564,,,,,,,,,,,,,,,,, -80100,0.36554593,1.929871,,,,,,,,,,,,,,,,, -80200,0.32895947,1.7815983,,,,,,,,,,,,,,,,, -80300,0.32703266,1.8623207,,,,,,,,,,,,,,,,, -80400,0.38311937,1.8079094,,,,,,,,,,,,,,,,, -80500,0.3660837,1.8167304,,,,,,,,,,,,,,,,, -80600,0.32564276,1.7980857,,,,,,,,,,,,,,,,, -80700,0.31317034,1.8883226,,,,,,,,,,,,,,,,, -80800,0.29864222,1.8316945,,,,,,,,,,,,,,,,, -80900,0.3087468,1.7903959,,,,,,,,,,,,,,,,, -81000,0.31678346,1.7248285,,,,,,,,,,,,,,,,, -81100,0.30436772,1.7828109,,,,,,,,,,,,,,,,, -81200,0.34081298,1.8057675,,,,,,,,,,,,,,,,, -81300,0.30766588,1.7568364,,,,,,,,,,,,,,,,, -81400,0.40494886,1.7721901,,,,,,,,,,,,,,,,, -81500,0.31232142,1.6966019,,,,,,,,,,,,,,,,, -81600,0.40148818,1.771371,,,,,,,,,,,,,,,,, -81700,0.29565677,1.7946506,,,,,,,,,,,,,,,,, -81800,0.31273624,1.7962921,,,,,,,,,,,,,,,,, -81881,,,0.6480662226676941,1.706380844116211,31.98326599336038,0.6571152210235596,1.6383646726608276,27.902996214779897,3000.0,0.6693974733352661,1.568246603012085,27.44548816720142,3003.0,28590.505031108856,47674.52781748772,28590.505031108856,19080.33007502556,1.0926096439361572,0.0 -81900,0.29078934,1.7795013,,,,,,,,,,,,,,,,, -82000,0.27607083,1.7580358,,,,,,,,,,,,,,,,, -82100,0.3514257,1.8478885,,,,,,,,,,,,,,,,, -82200,0.33589998,1.7381985,,,,,,,,,,,,,,,,, -82300,0.36216158,1.8742813,,,,,,,,,,,,,,,,, -82400,0.28374648,1.804694,,,,,,,,,,,,,,,,, -82500,0.37492564,1.8323078,,,,,,,,,,,,,,,,, -82600,0.30322608,1.72916,,,,,,,,,,,,,,,,, -82700,0.33357918,1.7920852,,,,,,,,,,,,,,,,, -82800,0.33190396,1.7814286,,,,,,,,,,,,,,,,, -82900,0.33897066,1.8151803,,,,,,,,,,,,,,,,, -83000,0.34138563,1.7766314,,,,,,,,,,,,,,,,, -83100,0.3059769,1.6969825,,,,,,,,,,,,,,,,, -83200,0.26734674,1.7597903,,,,,,,,,,,,,,,,, -83300,0.4001755,1.8079809,,,,,,,,,,,,,,,,, -83400,0.3393458,1.7171782,,,,,,,,,,,,,,,,, -83500,0.3063399,1.7690935,,,,,,,,,,,,,,,,, -83600,0.27474287,1.7651362,,,,,,,,,,,,,,,,, -83700,0.29170135,1.7878884,,,,,,,,,,,,,,,,, -83800,0.41848338,1.8542086,,,,,,,,,,,,,,,,, -83900,0.31238815,1.7336031,,,,,,,,,,,,,,,,, -84000,0.34307432,1.8468864,,,,,,,,,,,,,,,,, -84100,0.32864988,1.776951,,,,,,,,,,,,,,,,, -84200,0.28439167,1.713574,,,,,,,,,,,,,,,,, -84289,,,0.6433086395263672,1.7546846866607666,30.503595770525028,0.6575739979743958,1.631597876548767,27.81755635340168,3000.0,0.6694555878639221,1.5594232082366943,27.313128787818524,3003.0,29430.63443851471,49078.12288761139,29430.63443851471,19643.68313574791,1.1292719841003418,0.0 -84300,0.30813217,1.7931451,,,,,,,,,,,,,,,,, -84400,0.32971796,1.8087856,,,,,,,,,,,,,,,,, -84500,0.35789123,1.7902172,,,,,,,,,,,,,,,,, -84600,0.3084719,1.685139,,,,,,,,,,,,,,,,, -84700,0.29529026,1.7378254,,,,,,,,,,,,,,,,, -84800,0.31444845,1.8307854,,,,,,,,,,,,,,,,, -84900,0.29342356,1.7779334,,,,,,,,,,,,,,,,, -85000,0.3281446,1.7272848,,,,,,,,,,,,,,,,, -85100,0.27567434,1.6856751,,,,,,,,,,,,,,,,, -85200,0.38036078,1.7905658,,,,,,,,,,,,,,,,, -85300,0.32890812,1.817677,,,,,,,,,,,,,,,,, -85400,0.3441123,1.7458146,,,,,,,,,,,,,,,,, -85500,0.28389472,1.7519401,,,,,,,,,,,,,,,,, -85600,0.2718278,1.7528646,,,,,,,,,,,,,,,,, -85700,0.28291994,1.7653534,,,,,,,,,,,,,,,,, -85800,0.31821778,1.7388395,,,,,,,,,,,,,,,,, -85900,0.3084238,1.8105887,,,,,,,,,,,,,,,,, -86000,0.30375355,1.8728939,,,,,,,,,,,,,,,,, -86100,0.33264723,1.8318974,,,,,,,,,,,,,,,,, -86200,0.29965734,1.7738413,,,,,,,,,,,,,,,,, -86300,0.31114617,1.7610193,,,,,,,,,,,,,,,,, -86400,0.32305714,1.7051475,,,,,,,,,,,,,,,,, -86500,0.34017423,1.7642963,,,,,,,,,,,,,,,,, -86600,0.42673996,1.784605,,,,,,,,,,,,,,,,, -86697,,,0.6414464712142944,1.7629821300506592,30.919897634747738,0.6594710350036621,1.6154887676239014,27.64385552238076,3000.0,0.6706408858299255,1.5513267517089844,27.45172026614546,3003.0,30270.717396259308,50467.73133611679,30270.717396259308,20193.09274697304,1.167518138885498,0.0 -86700,0.32971603,1.7686847,,,,,,,,,,,,,,,,, -86800,0.44771636,1.7271625,,,,,,,,,,,,,,,,, -86900,0.35418832,1.766598,,,,,,,,,,,,,,,,, -87000,0.32105672,1.702199,,,,,,,,,,,,,,,,, -87100,0.34270397,1.805652,,,,,,,,,,,,,,,,, -87200,0.28932104,1.7369783,,,,,,,,,,,,,,,,, -87300,0.30810434,1.7324648,,,,,,,,,,,,,,,,, -87400,0.28979123,1.7972796,,,,,,,,,,,,,,,,, -87500,0.3388558,1.7535832,,,,,,,,,,,,,,,,, -87600,0.3346031,1.786569,,,,,,,,,,,,,,,,, -87700,0.32663018,1.7896272,,,,,,,,,,,,,,,,, -87800,0.36884585,1.7478149,,,,,,,,,,,,,,,,, -87900,0.32638747,1.8398663,,,,,,,,,,,,,,,,, -88000,0.31416154,1.7108233,,,,,,,,,,,,,,,,, -88100,0.3099996,1.767685,,,,,,,,,,,,,,,,, -88200,0.2823219,1.7557768,,,,,,,,,,,,,,,,, -88300,0.406611,1.6725905,,,,,,,,,,,,,,,,, -88400,0.3525888,1.8199614,,,,,,,,,,,,,,,,, -88500,0.34034425,1.7134962,,,,,,,,,,,,,,,,, -88600,0.35915318,1.7687168,,,,,,,,,,,,,,,,, -88700,0.3367042,1.7886553,,,,,,,,,,,,,,,,, -88800,0.33198464,1.8136079,,,,,,,,,,,,,,,,, -88900,0.31847802,1.6258385,,,,,,,,,,,,,,,,, -89000,0.2799977,1.7000288,,,,,,,,,,,,,,,,, -89100,0.3561257,1.7976972,,,,,,,,,,,,,,,,, -89105,,,0.6479898691177368,1.7014189958572388,31.870726073454623,0.664678692817688,1.591239333152771,27.893390280404773,3000.0,0.676021158695221,1.5245511531829834,27.724933282657265,3003.0,31110.686940193176,52011.69216299057,31110.686940193176,20896.967987298965,1.2072625160217283,0.0 -89200,0.3075986,1.7941517,,,,,,,,,,,,,,,,, -89300,0.30240998,1.7404494,,,,,,,,,,,,,,,,, -89400,0.33974898,1.6232916,,,,,,,,,,,,,,,,, -89500,0.29717258,1.717194,,,,,,,,,,,,,,,,, -89600,0.3128229,1.7446632,,,,,,,,,,,,,,,,, -89700,0.3161841,1.6837851,,,,,,,,,,,,,,,,, -89800,0.3050698,1.7265463,,,,,,,,,,,,,,,,, -89900,0.28784811,1.699341,,,,,,,,,,,,,,,,, -90000,0.3122055,1.7677269,,,,,,,,,,,,,,,,, -90100,0.29727006,1.7905278,,,,,,,,,,,,,,,,, -90200,0.2793894,1.765207,,,,,,,,,,,,,,,,, -90300,0.36107302,1.7251648,,,,,,,,,,,,,,,,, -90400,0.32163316,1.7751604,,,,,,,,,,,,,,,,, -90500,0.3199959,1.7270936,,,,,,,,,,,,,,,,, -90600,0.30426517,1.7613273,,,,,,,,,,,,,,,,, -90700,0.3148321,1.6542991,,,,,,,,,,,,,,,,, -90800,0.35657075,1.727328,,,,,,,,,,,,,,,,, -90900,0.33939904,1.739958,,,,,,,,,,,,,,,,, -91000,0.4102437,1.7753413,,,,,,,,,,,,,,,,, -91100,0.32828933,1.7258716,,,,,,,,,,,,,,,,, -91200,0.41491464,1.7177316,,,,,,,,,,,,,,,,, -91300,0.5466566,1.7459675,,,,,,,,,,,,,,,,, -91400,0.3188105,1.6726886,,,,,,,,,,,,,,,,, -91500,0.29491088,1.724599,,,,,,,,,,,,,,,,, -91513,,,0.6436875462532043,1.7378581762313845,31.34514009640573,0.6646538972854614,1.5829776525497437,28.456537818777825,3000.0,0.6771948337554932,1.5109649896621704,27.948103940076336,3003.0,31950.802713632584,53386.21637392044,31950.802713632584,21431.26032066345,1.246953010559082,0.0 -91600,0.2884364,1.6410228,,,,,,,,,,,,,,,,, -91700,0.32299677,1.7472051,,,,,,,,,,,,,,,,, -91800,0.30299684,1.6753178,,,,,,,,,,,,,,,,, -91900,0.32565,1.6695957,,,,,,,,,,,,,,,,, -92000,0.3099486,1.778942,,,,,,,,,,,,,,,,, -92100,0.32289562,1.7762585,,,,,,,,,,,,,,,,, -92200,0.30310974,1.7025046,,,,,,,,,,,,,,,,, -92300,0.33076242,1.68705,,,,,,,,,,,,,,,,, -92400,0.29705158,1.6584612,,,,,,,,,,,,,,,,, -92500,0.31664047,1.7833128,,,,,,,,,,,,,,,,, -92600,0.30914146,1.8079015,,,,,,,,,,,,,,,,, -92700,0.37600988,1.7509059,,,,,,,,,,,,,,,,, -92800,0.32020828,1.7005392,,,,,,,,,,,,,,,,, -92900,0.30209845,1.6626579,,,,,,,,,,,,,,,,, -93000,0.3133614,1.7391186,,,,,,,,,,,,,,,,, -93100,0.30734032,1.7235363,,,,,,,,,,,,,,,,, -93200,0.3183718,1.7801969,,,,,,,,,,,,,,,,, -93300,0.28641242,1.7064141,,,,,,,,,,,,,,,,, -93400,0.33309886,1.6987801,,,,,,,,,,,,,,,,, -93500,0.31302387,1.6801764,,,,,,,,,,,,,,,,, -93600,0.33426255,1.7488304,,,,,,,,,,,,,,,,, -93700,0.28836277,1.7095617,,,,,,,,,,,,,,,,, -93800,0.30572316,1.6048679,,,,,,,,,,,,,,,,, -93900,0.3189005,1.6992387,,,,,,,,,,,,,,,,, -93921,,,0.6853712797164917,1.4687334299087524,34.03844929654699,0.6677164435386658,1.5667170286178589,28.313632988196403,3000.0,0.6811574101448059,1.4931126832962036,28.463327942895543,3003.0,32790.77405810356,54783.45065832138,32790.77405810356,21988.401491642,1.291767120361328,0.0 -94000,0.31647548,1.7044526,,,,,,,,,,,,,,,,, -94100,0.29299462,1.6891043,,,,,,,,,,,,,,,,, -94200,0.34107146,1.7797781,,,,,,,,,,,,,,,,, -94300,0.31603375,1.7383127,,,,,,,,,,,,,,,,, -94400,0.3099661,1.7819127,,,,,,,,,,,,,,,,, -94500,0.36766812,1.7125231,,,,,,,,,,,,,,,,, -94600,0.3461483,1.6990254,,,,,,,,,,,,,,,,, -94700,0.33728668,1.6782784,,,,,,,,,,,,,,,,, -94800,0.33324805,1.6290736,,,,,,,,,,,,,,,,, -94900,0.32457092,1.6762308,,,,,,,,,,,,,,,,, -95000,0.3202476,1.7526675,,,,,,,,,,,,,,,,, -95100,0.31373405,1.7012731,,,,,,,,,,,,,,,,, -95200,0.31293908,1.6637766,,,,,,,,,,,,,,,,, -95300,0.35709372,1.6892283,,,,,,,,,,,,,,,,, -95400,0.3142467,1.6227292,,,,,,,,,,,,,,,,, -95500,0.3184056,1.6955452,,,,,,,,,,,,,,,,, -95600,0.29378384,1.7211494,,,,,,,,,,,,,,,,, -95700,0.33460557,1.7936203,,,,,,,,,,,,,,,,, -95800,0.35769743,1.7529244,,,,,,,,,,,,,,,,, -95900,0.3620455,1.7770525,,,,,,,,,,,,,,,,, -96000,0.8366703,1.6560079,,,,,,,,,,,,,,,,, -96100,0.32122394,1.7604482,,,,,,,,,,,,,,,,, -96200,0.34812298,1.6681619,,,,,,,,,,,,,,,,, -96300,0.7661078,1.7314696,,,,,,,,,,,,,,,,, -96329,,,0.6571729779243469,1.6591682434082031,31.70680713304561,0.6698614954948425,1.549560785293579,28.72916283015469,3000.0,0.6832374930381775,1.4707233905792236,28.521573655414183,3003.0,33630.75802206993,56252.48037528992,33630.75802206993,22617.3301281929,1.331843376159668,0.0 -96400,0.33904374,1.7740188,,,,,,,,,,,,,,,,, -96500,0.32073137,1.6949909,,,,,,,,,,,,,,,,, -96600,0.27908444,1.6554643,,,,,,,,,,,,,,,,, -96700,0.31830943,1.671673,,,,,,,,,,,,,,,,, -96800,0.30559865,1.676546,,,,,,,,,,,,,,,,, -96900,0.3172583,1.684804,,,,,,,,,,,,,,,,, -97000,0.30484352,1.7222717,,,,,,,,,,,,,,,,, -97100,0.37575504,1.5801301,,,,,,,,,,,,,,,,, -97200,0.3658182,1.7180048,,,,,,,,,,,,,,,,, -97300,0.34963137,1.7151502,,,,,,,,,,,,,,,,, -97400,0.28972805,1.7036673,,,,,,,,,,,,,,,,, -97500,0.38695338,1.6505255,,,,,,,,,,,,,,,,, -97600,0.30694875,1.6306107,,,,,,,,,,,,,,,,, -97700,0.28969285,1.671731,,,,,,,,,,,,,,,,, -97800,0.29900387,1.6124737,,,,,,,,,,,,,,,,, -97900,0.3130174,1.6604276,,,,,,,,,,,,,,,,, -98000,0.3116384,1.6146795,,,,,,,,,,,,,,,,, -98100,0.315566,1.5761678,,,,,,,,,,,,,,,,, -98200,0.32563135,1.6294799,,,,,,,,,,,,,,,,, -98300,0.367892,1.7333219,,,,,,,,,,,,,,,,, -98400,0.3540752,1.5872422,,,,,,,,,,,,,,,,, -98500,0.3093507,1.6640878,,,,,,,,,,,,,,,,, -98600,0.31316036,1.6552613,,,,,,,,,,,,,,,,, -98700,0.31901032,1.6835164,,,,,,,,,,,,,,,,, -98737,,,0.6514151692390442,1.690314531326294,31.84842778417848,0.6744739413261414,1.5323798656463623,29.34685369574843,3000.0,0.6842019557952881,1.4612979888916016,28.679556840476543,3003.0,34470.64178228378,57623.255120038986,34470.64178228378,23148.10334300995,1.372218370437622,0.0 -98800,0.31073377,1.659256,,,,,,,,,,,,,,,,, -98900,0.32821715,1.6649196,,,,,,,,,,,,,,,,, -99000,0.3747413,1.7170918,,,,,,,,,,,,,,,,, -99100,0.30649105,1.6095744,,,,,,,,,,,,,,,,, -99200,0.3233282,1.6271806,,,,,,,,,,,,,,,,, -99300,0.3355339,1.7003272,,,,,,,,,,,,,,,,, -99400,0.3099986,1.6525185,,,,,,,,,,,,,,,,, -99500,0.31840113,1.6817688,,,,,,,,,,,,,,,,, -99600,0.2950369,1.5821519,,,,,,,,,,,,,,,,, -99700,0.3115512,1.6824094,,,,,,,,,,,,,,,,, -99800,0.30427155,1.6461012,,,,,,,,,,,,,,,,, -99900,0.37664908,1.6597012,,,,,,,,,,,,,,,,, -100000,0.33579683,1.6583608,,,,,,,,,,,,,,,,, -100100,0.34540087,1.7453904,,,,,,,,,,,,,,,,, -100200,0.34473494,1.6026472,,,,,,,,,,,,,,,,, -100300,0.3042706,1.6527748,,,,,,,,,,,,,,,,, -100400,0.30720904,1.6967238,,,,,,,,,,,,,,,,, -100500,0.31533733,1.6610488,,,,,,,,,,,,,,,,, -100600,0.34320086,1.6413145,,,,,,,,,,,,,,,,, -100700,0.34341568,1.5321345,,,,,,,,,,,,,,,,, -100800,0.33585754,1.6771381,,,,,,,,,,,,,,,,, -100900,0.34447002,1.6960208,,,,,,,,,,,,,,,,, -101000,0.29099825,1.5746675,,,,,,,,,,,,,,,,, -101100,0.322258,1.6466638,,,,,,,,,,,,,,,,, -101145,,,0.6625736951828003,1.604433298110962,32.59588528405028,0.6739407777786255,1.5239664316177368,29.29180447115428,3000.0,0.688176155090332,1.4475120306015017,29.02306537809828,3003.0,35310.656386613846,59013.03903198242,35310.656386613846,23697.752576589584,1.4147746562957764,0.0 -101200,0.3331114,1.5563858,,,,,,,,,,,,,,,,, -101300,0.3273399,1.6201357,,,,,,,,,,,,,,,,, -101400,0.32536873,1.6187398,,,,,,,,,,,,,,,,, -101500,0.36354336,1.6946859,,,,,,,,,,,,,,,,, -101600,0.3281241,1.6859103,,,,,,,,,,,,,,,,, -101700,0.3245071,1.6270893,,,,,,,,,,,,,,,,, -101800,0.35377923,1.605976,,,,,,,,,,,,,,,,, -101900,0.32660514,1.6800437,,,,,,,,,,,,,,,,, -102000,0.29509717,1.6262246,,,,,,,,,,,,,,,,, -102100,0.29514682,1.5867516,,,,,,,,,,,,,,,,, -102200,0.31757802,1.7089162,,,,,,,,,,,,,,,,, -102300,0.31216824,1.6782782,,,,,,,,,,,,,,,,, -102400,0.36401314,1.5964093,,,,,,,,,,,,,,,,, -102500,0.28452033,1.5846833,,,,,,,,,,,,,,,,, -102600,0.32551417,1.5984204,,,,,,,,,,,,,,,,, -102700,0.35085404,1.6237217,,,,,,,,,,,,,,,,, -102800,0.32238275,1.6087533,,,,,,,,,,,,,,,,, -102900,0.33500782,1.7130954,,,,,,,,,,,,,,,,, -103000,0.2976476,1.6605371,,,,,,,,,,,,,,,,, -103100,0.3454729,1.6689004,,,,,,,,,,,,,,,,, -103200,0.32299462,1.568336,,,,,,,,,,,,,,,,, -103300,0.3189585,1.6697861,,,,,,,,,,,,,,,,, -103400,0.36162552,1.709909,,,,,,,,,,,,,,,,, -103500,0.30796614,1.534509,,,,,,,,,,,,,,,,, -103553,,,0.6571160554885864,1.6496434211730957,32.57483756565918,0.6794335842132568,1.5055320262908936,29.71928640001056,3000.0,0.6907559037208557,1.4269720315933228,29.4018099388418,3003.0,36150.72316431999,60513.682002067566,36150.72316431999,24358.2102060318,1.455638408660889,0.0 -103600,0.3149243,1.6141685,,,,,,,,,,,,,,,,, -103700,0.31629106,1.5969111,,,,,,,,,,,,,,,,, -103800,0.32318616,1.6894895,,,,,,,,,,,,,,,,, -103900,0.37949443,1.57201,,,,,,,,,,,,,,,,, -104000,0.30400825,1.6351722,,,,,,,,,,,,,,,,, -104100,0.33913204,1.6132995,,,,,,,,,,,,,,,,, -104200,0.32064855,1.7351395,,,,,,,,,,,,,,,,, -104300,0.32501203,1.5416073,,,,,,,,,,,,,,,,, -104400,0.3076194,1.6346624,,,,,,,,,,,,,,,,, -104500,0.3185606,1.6333231,,,,,,,,,,,,,,,,, -104600,0.31721216,1.5793965,,,,,,,,,,,,,,,,, -104700,0.3151167,1.6082466,,,,,,,,,,,,,,,,, -104800,0.3140132,1.5510358,,,,,,,,,,,,,,,,, -104900,0.37220836,1.6212656,,,,,,,,,,,,,,,,, -105000,0.34393907,1.6204157,,,,,,,,,,,,,,,,, -105100,0.34933025,1.6761183,,,,,,,,,,,,,,,,, -105200,0.31244907,1.5820278,,,,,,,,,,,,,,,,, -105300,0.332262,1.6677706,,,,,,,,,,,,,,,,, -105400,0.3152701,1.5653859,,,,,,,,,,,,,,,,, -105500,0.33642256,1.641657,,,,,,,,,,,,,,,,, -105600,0.3324121,1.6225312,,,,,,,,,,,,,,,,, -105700,0.38148028,1.6105965,,,,,,,,,,,,,,,,, -105800,0.33117738,1.663739,,,,,,,,,,,,,,,,, -105900,0.31481707,1.5605111,,,,,,,,,,,,,,,,, -105962,,,0.6626680493354797,1.6218568086624146,32.79830573550928,0.6800907254219055,1.4952389001846311,29.62048571366333,3000.0,0.6929289698600769,1.4130771160125732,29.46607487977802,3003.0,36990.85962986946,61945.61144256592,36990.85962986946,24949.888291597366,1.4942302703857422,0.0 -106000,0.43429577,1.6456121,,,,,,,,,,,,,,,,, -106100,0.36632228,1.5840017,,,,,,,,,,,,,,,,, -106200,0.31272423,1.5416104,,,,,,,,,,,,,,,,, -106300,0.3311218,1.5378889,,,,,,,,,,,,,,,,, -106400,0.3192571,1.5299493,,,,,,,,,,,,,,,,, -106500,0.32518464,1.4972647,,,,,,,,,,,,,,,,, -106600,0.3267109,1.62977,,,,,,,,,,,,,,,,, -106700,0.33912855,1.5941274,,,,,,,,,,,,,,,,, -106800,0.3319775,1.5979346,,,,,,,,,,,,,,,,, -106900,0.35877156,1.6266179,,,,,,,,,,,,,,,,, -107000,0.356412,1.5819086,,,,,,,,,,,,,,,,, -107100,0.31859723,1.5832773,,,,,,,,,,,,,,,,, -107200,0.31779027,1.6591597,,,,,,,,,,,,,,,,, -107300,0.3156598,1.5174199,,,,,,,,,,,,,,,,, -107400,0.31626403,1.5971274,,,,,,,,,,,,,,,,, -107500,0.32192442,1.6135304,,,,,,,,,,,,,,,,, -107600,0.29593414,1.5792245,,,,,,,,,,,,,,,,, -107700,0.32799995,1.5337042,,,,,,,,,,,,,,,,, -107800,0.33637682,1.6499286,,,,,,,,,,,,,,,,, -107900,0.32989287,1.652776,,,,,,,,,,,,,,,,, -108000,0.34132588,1.6143081,,,,,,,,,,,,,,,,, -108100,0.30524457,1.623076,,,,,,,,,,,,,,,,, -108200,0.32897714,1.5100772,,,,,,,,,,,,,,,,, -108300,0.3327176,1.5305241,,,,,,,,,,,,,,,,, -108370,,,0.6737005710601807,1.5475784540176392,33.575351495616545,0.6833640933036804,1.479899525642395,29.811560649958405,3000.0,0.6960200071334839,1.3971692323684692,29.64580006144558,3003.0,37831.08743548393,63350.19052934647,37831.08743548393,25514.119074106216,1.5336592197418213,0.0 -108400,0.3255405,1.579912,,,,,,,,,,,,,,,,, -108500,0.34872222,1.6093912,,,,,,,,,,,,,,,,, -108600,0.3271911,1.6156223,,,,,,,,,,,,,,,,, -108700,0.3421449,1.6677538,,,,,,,,,,,,,,,,, -108800,0.3369351,1.5747547,,,,,,,,,,,,,,,,, -108900,0.33249593,1.6262223,,,,,,,,,,,,,,,,, -109000,0.352858,1.5570474,,,,,,,,,,,,,,,,, -109100,0.34055024,1.5204428,,,,,,,,,,,,,,,,, -109200,0.3580142,1.5861313,,,,,,,,,,,,,,,,, -109300,0.34413826,1.5689317,,,,,,,,,,,,,,,,, -109400,0.3261503,1.6255469,,,,,,,,,,,,,,,,, -109500,0.36525455,1.628153,,,,,,,,,,,,,,,,, -109600,0.33467886,1.5338848,,,,,,,,,,,,,,,,, -109700,0.33290315,1.4903803,,,,,,,,,,,,,,,,, -109800,0.35075256,1.601542,,,,,,,,,,,,,,,,, -109900,0.3475987,1.6566535,,,,,,,,,,,,,,,,, -110000,0.33084825,1.6305206,,,,,,,,,,,,,,,,, -110100,0.32448068,1.543845,,,,,,,,,,,,,,,,, -110200,0.38783363,1.6005428,,,,,,,,,,,,,,,,, -110300,0.3379645,1.5429672,,,,,,,,,,,,,,,,, -110400,0.39019877,1.5811157,,,,,,,,,,,,,,,,, -110500,0.3519663,1.6048602,,,,,,,,,,,,,,,,, -110600,0.36416698,1.5864211,,,,,,,,,,,,,,,,, -110700,0.3429558,1.6025782,,,,,,,,,,,,,,,,, -110781,,,0.6674370765686035,1.5858745574951172,33.23845615206641,0.6833516955375671,1.4708929061889648,30.19010167590585,3000.0,0.6975306868553162,1.3853771686553955,29.50214132674396,3003.0,38671.22283697128,64770.899523973465,38671.22283697128,26094.5704498291,1.57651948928833,0.0 -110800,0.34080207,1.5636696,,,,,,,,,,,,,,,,, -110900,0.34253317,1.5500882,,,,,,,,,,,,,,,,, -111000,0.3429303,1.578513,,,,,,,,,,,,,,,,, -111100,0.34876132,1.5468009,,,,,,,,,,,,,,,,, -111200,0.34075773,1.5375824,,,,,,,,,,,,,,,,, -111300,0.32722807,1.5005207,,,,,,,,,,,,,,,,, -111400,0.34034014,1.5914682,,,,,,,,,,,,,,,,, -111500,0.33860818,1.5786165,,,,,,,,,,,,,,,,, -111600,0.3368781,1.588419,,,,,,,,,,,,,,,,, -111700,0.33101663,1.4541821,,,,,,,,,,,,,,,,, -111800,0.34425664,1.6252613,,,,,,,,,,,,,,,,, -111900,0.34300545,1.6202841,,,,,,,,,,,,,,,,, -112000,0.33455145,1.5716599,,,,,,,,,,,,,,,,, -112100,0.35190004,1.5551382,,,,,,,,,,,,,,,,, -112200,0.34154046,1.6273775,,,,,,,,,,,,,,,,, -112300,0.33010426,1.5654223,,,,,,,,,,,,,,,,, -112400,0.33886215,1.473822,,,,,,,,,,,,,,,,, -112500,0.35704157,1.600497,,,,,,,,,,,,,,,,, -112600,0.35154116,1.4793187,,,,,,,,,,,,,,,,, -112700,0.34672633,1.5576341,,,,,,,,,,,,,,,,, -112800,0.34998965,1.5404124,,,,,,,,,,,,,,,,, -112900,0.35390034,1.5340108,,,,,,,,,,,,,,,,, -113000,0.37074217,1.5819968,,,,,,,,,,,,,,,,, -113100,0.36227196,1.5731446,,,,,,,,,,,,,,,,, -113188,,,0.686324417591095,1.4715728759765625,34.47618891337282,0.6864266991615295,1.4545800685882568,29.93864461495176,3000.0,0.7002033591270447,1.369438409805298,29.95588310322253,3003.0,39511.19406795502,66233.1094198227,39511.19406795502,26716.692593574524,1.61653470993042,0.0 -113200,0.35679123,1.5244789,,,,,,,,,,,,,,,,, -113300,0.35937002,1.5620927,,,,,,,,,,,,,,,,, -113400,0.3475303,1.6628419,,,,,,,,,,,,,,,,, -113500,0.3385575,1.4686427,,,,,,,,,,,,,,,,, -113600,0.39434627,1.5245539,,,,,,,,,,,,,,,,, -113700,0.37241527,1.6345255,,,,,,,,,,,,,,,,, -113800,0.34922206,1.5066003,,,,,,,,,,,,,,,,, -113900,0.3477334,1.6473105,,,,,,,,,,,,,,,,, -114000,0.35288763,1.4483402,,,,,,,,,,,,,,,,, -114100,0.33892956,1.5118009,,,,,,,,,,,,,,,,, -114200,0.3499193,1.5560713,,,,,,,,,,,,,,,,, -114300,0.34993362,1.5941542,,,,,,,,,,,,,,,,, -114400,0.36131766,1.5014598,,,,,,,,,,,,,,,,, -114500,0.36383638,1.5796328,,,,,,,,,,,,,,,,, -114600,0.35618496,1.5092494,,,,,,,,,,,,,,,,, -114700,0.34627497,1.5426286,,,,,,,,,,,,,,,,, -114800,0.37273708,1.5711851,,,,,,,,,,,,,,,,, -114900,0.34582573,1.5817785,,,,,,,,,,,,,,,,, -115000,0.37022415,1.5246787,,,,,,,,,,,,,,,,, -115100,0.32291415,1.4437408,,,,,,,,,,,,,,,,, -115200,0.3524129,1.4736506,,,,,,,,,,,,,,,,, -115300,0.34787992,1.5687783,,,,,,,,,,,,,,,,, -115400,0.35469073,1.4365122,,,,,,,,,,,,,,,,, -115500,0.35478064,1.5865645,,,,,,,,,,,,,,,,, -115596,,,0.6802150011062622,1.505571722984314,34.051214524478446,0.6871086359024048,1.4513825178146362,30.445699454155303,3000.0,0.7025042176246643,1.358932375907898,30.259197723267093,3003.0,40351.23788332939,67599.56405115128,40351.23788332939,27242.97893857956,1.6643104553222656,0.0 -115600,0.37303102,1.5033733,,,,,,,,,,,,,,,,, -115700,0.34496805,1.5616288,,,,,,,,,,,,,,,,, -115800,0.37337264,1.5699005,,,,,,,,,,,,,,,,, -115900,0.35214487,1.5591836,,,,,,,,,,,,,,,,, -116000,0.3578336,1.5085952,,,,,,,,,,,,,,,,, -116100,0.35474217,1.5212239,,,,,,,,,,,,,,,,, -116200,0.34100685,1.4906832,,,,,,,,,,,,,,,,, -116300,0.33706313,1.5446535,,,,,,,,,,,,,,,,, -116400,0.3913897,1.557004,,,,,,,,,,,,,,,,, -116500,0.36597556,1.5357195,,,,,,,,,,,,,,,,, -116600,0.35984764,1.5021805,,,,,,,,,,,,,,,,, -116700,0.35144463,1.4921708,,,,,,,,,,,,,,,,, -116800,0.35926193,1.5468673,,,,,,,,,,,,,,,,, -116900,0.36942074,1.4888153,,,,,,,,,,,,,,,,, -117000,0.37280244,1.485664,,,,,,,,,,,,,,,,, -117100,0.3855748,1.5156761,,,,,,,,,,,,,,,,, -117200,0.34688425,1.4792897,,,,,,,,,,,,,,,,, -117300,0.35598943,1.4229043,,,,,,,,,,,,,,,,, -117400,0.35046765,1.4711936,,,,,,,,,,,,,,,,, -117500,0.39219895,1.6079612,,,,,,,,,,,,,,,,, -117600,0.34754226,1.4516835,,,,,,,,,,,,,,,,, -117700,0.3575336,1.5053055,,,,,,,,,,,,,,,,, -117800,0.36148453,1.4392112,,,,,,,,,,,,,,,,, -117900,0.364563,1.4889896,,,,,,,,,,,,,,,,, -118000,0.3895693,1.512121,,,,,,,,,,,,,,,,, -118005,,,0.680402934551239,1.5085806846618652,34.02548684398224,0.6894644498825073,1.437278389930725,30.400086035153965,3000.0,0.7049096822738647,1.342985987663269,30.587941780081074,3003.0,41191.43535208702,68923.48609733582,41191.43535208702,27726.58704996109,1.705122947692871,0.0 -118100,0.35948318,1.5433402,,,,,,,,,,,,,,,,, -118200,0.37419882,1.47604,,,,,,,,,,,,,,,,, -118300,0.3588299,1.4972327,,,,,,,,,,,,,,,,, -118400,0.35198095,1.4994154,,,,,,,,,,,,,,,,, -118500,0.36544958,1.5037054,,,,,,,,,,,,,,,,, -118600,0.40535483,1.5199322,,,,,,,,,,,,,,,,, -118700,0.3836785,1.59046,,,,,,,,,,,,,,,,, -118800,0.35360557,1.5287403,,,,,,,,,,,,,,,,, -118900,0.3827666,1.4801971,,,,,,,,,,,,,,,,, -119000,0.4300461,1.5353261,,,,,,,,,,,,,,,,, -119100,0.34936842,1.4521657,,,,,,,,,,,,,,,,, -119200,0.36769643,1.5039753,,,,,,,,,,,,,,,,, -119300,0.36765632,1.4866241,,,,,,,,,,,,,,,,, -119400,0.35172576,1.5341043,,,,,,,,,,,,,,,,, -119500,0.36845174,1.4929193,,,,,,,,,,,,,,,,, -119600,0.36586744,1.4208072,,,,,,,,,,,,,,,,, -119700,0.38477427,1.5750033,,,,,,,,,,,,,,,,, -119800,0.37550223,1.5254086,,,,,,,,,,,,,,,,, -119900,0.340529,1.4508493,,,,,,,,,,,,,,,,, -120000,0.38572624,1.4316366,,,,,,,,,,,,,,,,, -120100,0.38901976,1.4501935,,,,,,,,,,,,,,,,, -120200,0.37623712,1.4940771,,,,,,,,,,,,,,,,, -120300,0.3570935,1.4542044,,,,,,,,,,,,,,,,, -120400,0.3694526,1.4972522,,,,,,,,,,,,,,,,, -120413,,,0.6904399991035461,1.4429056644439695,35.09718083113306,0.6901216506958008,1.4297069311141968,30.587810108397143,3000.0,0.7053279876708984,1.3383585214614868,30.600486228597944,3003.0,42031.62937164307,70343.08654594421,42031.62937164307,28305.8653011322,1.7545392513275146,0.0 -120500,0.3811883,1.4155921,,,,,,,,,,,,,,,,, -120600,0.38834134,1.5082448,,,,,,,,,,,,,,,,, -120700,0.38479194,1.4639484,,,,,,,,,,,,,,,,, -120800,0.36999297,1.4469596,,,,,,,,,,,,,,,,, -120900,0.36585802,1.4581572,,,,,,,,,,,,,,,,, -121000,0.39480197,1.4575307,,,,,,,,,,,,,,,,, -121100,0.37697205,1.5140971,,,,,,,,,,,,,,,,, -121200,0.36742008,1.4434841,,,,,,,,,,,,,,,,, -121300,0.40830943,1.5211977,,,,,,,,,,,,,,,,, -121400,0.36536232,1.4217501,,,,,,,,,,,,,,,,, -121500,0.37615353,1.4471437,,,,,,,,,,,,,,,,, -121600,0.3748824,1.4393587,,,,,,,,,,,,,,,,, -121700,0.38511154,1.4820276,,,,,,,,,,,,,,,,, -121800,0.38587284,1.4563949,,,,,,,,,,,,,,,,, -121900,0.37698525,1.4896758,,,,,,,,,,,,,,,,, -122000,0.40549618,1.4988657,,,,,,,,,,,,,,,,, -122100,0.37209216,1.4240186,,,,,,,,,,,,,,,,, -122200,0.38421333,1.4892694,,,,,,,,,,,,,,,,, -122300,0.38424402,1.5029613,,,,,,,,,,,,,,,,, -122400,0.35257146,1.3747045,,,,,,,,,,,,,,,,, -122500,0.36942834,1.4602875,,,,,,,,,,,,,,,,, -122600,0.36133513,1.412776,,,,,,,,,,,,,,,,, -122700,0.39316738,1.5445327,,,,,,,,,,,,,,,,, -122800,0.3891645,1.4612395,,,,,,,,,,,,,,,,, -122820,,,0.688624382019043,1.460661768913269,35.01808765515029,0.6914607286453247,1.4257858991622925,30.7654477246006,3000.0,0.7075591087341309,1.3309905529022217,30.70963575294762,3003.0,42871.6183693409,71760.67544698715,42871.6183693409,28883.33780145645,1.80470871925354,0.0 -122900,0.39547017,1.4416679,,,,,,,,,,,,,,,,, -123000,0.38182747,1.4732225,,,,,,,,,,,,,,,,, -123100,0.36149782,1.4502442,,,,,,,,,,,,,,,,, -123200,0.39213413,1.4418008,,,,,,,,,,,,,,,,, -123300,0.39905438,1.4764779,,,,,,,,,,,,,,,,, -123400,0.37605503,1.4917257,,,,,,,,,,,,,,,,, -123500,0.36923516,1.4913315,,,,,,,,,,,,,,,,, -123600,0.375693,1.434083,,,,,,,,,,,,,,,,, -123700,0.3840619,1.5015241,,,,,,,,,,,,,,,,, -123800,0.38396522,1.4249707,,,,,,,,,,,,,,,,, -123900,0.3807451,1.4596062,,,,,,,,,,,,,,,,, -124000,0.38060924,1.4930712,,,,,,,,,,,,,,,,, -124100,0.37084126,1.4299316,,,,,,,,,,,,,,,,, -124200,0.3818892,1.4533445,,,,,,,,,,,,,,,,, -124300,0.39371705,1.5259633,,,,,,,,,,,,,,,,, -124400,0.38189325,1.4360536,,,,,,,,,,,,,,,,, -124500,0.4097509,1.5290511,,,,,,,,,,,,,,,,, -124600,0.40387738,1.4528136,,,,,,,,,,,,,,,,, -124700,0.39646074,1.4796188,,,,,,,,,,,,,,,,, -124800,0.39274275,1.4340534,,,,,,,,,,,,,,,,, -124900,0.38279724,1.4671129,,,,,,,,,,,,,,,,, -125000,0.39035088,1.4329078,,,,,,,,,,,,,,,,, -125100,0.37734824,1.4947368,,,,,,,,,,,,,,,,, -125200,0.3885962,1.379658,,,,,,,,,,,,,,,,, -125228,,,0.699196994304657,1.396242618560791,36.03404700717277,0.6935189962387085,1.420162796974182,31.030971558360275,3000.0,0.7094881534576416,1.3231710195541382,30.738778424579746,3003.0,43711.80949354172,73148.70480275154,43711.80949354172,29431.05741333961,1.8475875854492188,0.0 -125300,0.4178764,1.5675678,,,,,,,,,,,,,,,,, -125400,0.3738199,1.503736,,,,,,,,,,,,,,,,, -125500,0.38617125,1.4714122,,,,,,,,,,,,,,,,, -125600,0.3847341,1.4268357,,,,,,,,,,,,,,,,, -125700,0.3731469,1.438758,,,,,,,,,,,,,,,,, -125800,0.37680817,1.4304937,,,,,,,,,,,,,,,,, -125900,0.3912749,1.4652534,,,,,,,,,,,,,,,,, -126000,0.37915695,1.4896914,,,,,,,,,,,,,,,,, -126100,0.3862394,1.4507197,,,,,,,,,,,,,,,,, -126200,0.39173755,1.4092652,,,,,,,,,,,,,,,,, -126300,0.38414603,1.414682,,,,,,,,,,,,,,,,, -126400,0.35892895,1.3160647,,,,,,,,,,,,,,,,, -126500,0.38082784,1.447198,,,,,,,,,,,,,,,,, -126600,0.37472904,1.4373722,,,,,,,,,,,,,,,,, -126700,0.38749662,1.4905976,,,,,,,,,,,,,,,,, -126800,0.3840786,1.5028491,,,,,,,,,,,,,,,,, -126900,0.37626714,1.3868324,,,,,,,,,,,,,,,,, -127000,0.40708274,1.428619,,,,,,,,,,,,,,,,, -127100,0.37270218,1.3791858,,,,,,,,,,,,,,,,, -127200,0.3773927,1.4788798,,,,,,,,,,,,,,,,, -127300,0.393161,1.5110446,,,,,,,,,,,,,,,,, -127400,0.38429853,1.4494936,,,,,,,,,,,,,,,,, -127500,0.39697307,1.3922565,,,,,,,,,,,,,,,,, -127600,0.3929659,1.4182966,,,,,,,,,,,,,,,,, -127636,,,0.696530818939209,1.4135314226150513,35.25999172604652,0.6937545537948608,1.4164506196975708,31.00998771095735,3000.0,0.7097554206848145,1.319035291671753,30.757323033480343,3003.0,44551.95158267021,74541.41567277908,44551.95158267021,29983.50764322281,1.889930009841919,0.0 -127700,0.38966084,1.4245691,,,,,,,,,,,,,,,,, -127800,0.37899402,1.4194056,,,,,,,,,,,,,,,,, -127900,0.37685758,1.4148908,,,,,,,,,,,,,,,,, -128000,0.40184608,1.5063168,,,,,,,,,,,,,,,,, -128100,0.3886084,1.4351166,,,,,,,,,,,,,,,,, -128200,0.38462082,1.4159523,,,,,,,,,,,,,,,,, -128300,0.38346884,1.3887888,,,,,,,,,,,,,,,,, -128400,0.38728574,1.4127246,,,,,,,,,,,,,,,,, -128500,0.3662838,1.3551918,,,,,,,,,,,,,,,,, -128600,0.37539706,1.4074491,,,,,,,,,,,,,,,,, -128700,0.3758913,1.3369633,,,,,,,,,,,,,,,,, -128800,0.3880715,1.4050922,,,,,,,,,,,,,,,,, -128900,0.38128278,1.4207416,,,,,,,,,,,,,,,,, -129000,0.4037899,1.4700147,,,,,,,,,,,,,,,,, -129100,0.3729625,1.4787714,,,,,,,,,,,,,,,,, -129200,0.40141153,1.4419593,,,,,,,,,,,,,,,,, -129300,0.3966801,1.440214,,,,,,,,,,,,,,,,, -129400,0.36960366,1.4348359,,,,,,,,,,,,,,,,, -129500,0.4035106,1.4817698,,,,,,,,,,,,,,,,, -129600,0.4130842,1.4287332,,,,,,,,,,,,,,,,, -129700,0.38683927,1.3896748,,,,,,,,,,,,,,,,, -129800,0.38062534,1.3985401,,,,,,,,,,,,,,,,, -129900,0.39335817,1.4552485,,,,,,,,,,,,,,,,, -130000,0.39489222,1.4179853,,,,,,,,,,,,,,,,, -130043,,,0.6964461803436279,1.4212472438812256,35.426351940087365,0.693804144859314,1.414819836616516,30.97718072563005,3000.0,0.7098251581192017,1.3173716068267822,30.876633593370595,3003.0,45392.100281476974,75921.76205563545,45392.100281476974,30523.580441236496,1.9321520328521729,0.0 -130100,0.38261107,1.4150169,,,,,,,,,,,,,,,,, -130200,0.38738135,1.4680413,,,,,,,,,,,,,,,,, -130300,0.37730333,1.4053625,,,,,,,,,,,,,,,,, -130400,0.37295535,1.3193479,,,,,,,,,,,,,,,,, -130500,0.37082237,1.459462,,,,,,,,,,,,,,,,, -130600,0.3822659,1.427818,,,,,,,,,,,,,,,,, -130700,0.39196098,1.4375752,,,,,,,,,,,,,,,,, -130800,0.38666895,1.3962593,,,,,,,,,,,,,,,,, -130900,0.41704854,1.5267334,,,,,,,,,,,,,,,,, -131000,0.38555196,1.386097,,,,,,,,,,,,,,,,, -131100,0.3909506,1.414173,,,,,,,,,,,,,,,,, -131200,0.3680069,1.4260648,,,,,,,,,,,,,,,,, -131300,0.37761948,1.4608508,,,,,,,,,,,,,,,,, -131400,0.3660447,1.3868387,,,,,,,,,,,,,,,,, -131500,0.37923956,1.4105912,,,,,,,,,,,,,,,,, -131600,0.3892599,1.4502002,,,,,,,,,,,,,,,,, -131700,0.40123937,1.4488255,,,,,,,,,,,,,,,,, -131800,0.3833199,1.4803927,,,,,,,,,,,,,,,,, -131900,0.38120413,1.4025584,,,,,,,,,,,,,,,,, -132000,0.37638384,1.3858626,,,,,,,,,,,,,,,,, -132100,0.3853386,1.4004027,,,,,,,,,,,,,,,,, -132200,0.37016934,1.4133468,,,,,,,,,,,,,,,,, -132300,0.39099997,1.3759028,,,,,,,,,,,,,,,,, -132400,0.38065305,1.385019,,,,,,,,,,,,,,,,, -132451,,,0.6961711049079895,1.419246792793274,35.54502780300605,0.6945357322692871,1.414249300956726,31.107725195668912,3000.0,0.7100808024406433,1.3164043426513672,30.8854601382876,3003.0,46232.19126367569,77297.3417544365,46232.19126367569,31058.94911122322,1.97504472732544,0.0 -132500,0.38715267,1.475498,,,,,,,,,,,,,,,,, -132600,0.37447336,1.4451339,,,,,,,,,,,,,,,,, -132700,0.39567143,1.4280767,,,,,,,,,,,,,,,,, -132800,0.37280622,1.3869855,,,,,,,,,,,,,,,,, -132900,0.37122625,1.4354092,,,,,,,,,,,,,,,,, -133000,0.3820715,1.4201531,,,,,,,,,,,,,,,,, -133100,0.3812767,1.414835,,,,,,,,,,,,,,,,, -133200,0.3686248,1.3346071,,,,,,,,,,,,,,,,, -133300,0.39591604,1.5303085,,,,,,,,,,,,,,,,, -133333,,,0.6953193545341492,1.4255940914154053,35.78710381048599,0.6945977210998535,1.41435444355011,31.08942441174693,3000.0,0.7099761962890625,1.316562533378601,30.94145452506206,3003.0,46539.62941074371,78143.05393791199,46539.62941074371,31597.15144467354,2.018925905227661,0.0 -133333,,,,,,,,,,,,,,46539.62941074371,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/eval_measurements.csv deleted file mode 100644 index b01d87b9d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,59 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -869.8805475234985,0.0,30.065114498138428,1,0,30.065114498138428,0.0007088489946909,0.0,11.041826248168944,3003,899.9457039833069,0.0006145372171886,0.0,11.063870429992676,0.0004835649742744,0.0,11.036645889282228,3000 -1351.6217308044434,0.0189239978790283,870.2016100883484,2349,0,870.2016100883484,0.5108012557029724,16.717221966085745,2.912444591522217,3003,2221.916384935379,0.5141674280166626,22.965880254505947,2.880971193313598,0.5140295624732971,18.528226390624614,2.860582113265991,3000 -1821.823566198349,0.0436756610870361,1710.351322889328,4699,0,1710.351322889328,0.5915867686271667,22.13067104185569,2.126992464065552,3003,3532.369195461273,0.5758695602416992,27.23287427611181,2.2680907249450684,0.589490532875061,23.513794603427275,2.1476547718048096,3000 -2287.665424346924,0.0689821243286132,2550.4531705379486,7049,0,2550.4531705379486,0.620010495185852,23.77728489354792,1.8880869150161743,3003,4838.414577007294,0.6065484881401062,29.012708543686976,1.9995412826538088,0.6178224682807922,25.41065344375216,1.92452335357666,3000 -2796.3146154880524,0.0997548103332519,3390.548141002655,9399,0,3390.548141002655,0.6367672085762024,25.14278859816317,1.7608153820037842,3003,6187.265499591827,0.6141012907028198,29.74698626680871,1.939665675163269,0.6299859881401062,26.071460310919417,1.8144594430923464,3000 -3285.607243537903,0.1272540092468261,4230.704621553421,11749,0,4230.704621553421,0.6456801295280457,25.485975409278893,1.687228083610535,3003,7516.816064357758,0.6164451241493225,29.7451596694926,1.9157977104187007,0.6390249133110046,26.594450639942533,1.7488797903060913,3000 -3802.454508543015,0.1534552574157714,5070.610538482666,14099,0,5070.610538482666,0.651943564414978,26.12009908804103,1.6426197290420532,3003,8873.669828891754,0.6286578178405762,30.895011292840326,1.8235194683074951,0.6437737941741943,26.906824146885256,1.7144992351531982,3000 -4301.075335741043,0.1794331073760986,5910.604150533676,16449,0,5910.604150533676,0.6589855551719666,26.91923126334531,1.607773780822754,3003,10212.38568687439,0.6269233226776123,30.456122777081536,1.830869555473328,0.6486342549324036,27.54467483066716,1.6799792051315308,3000 -4762.190878391266,0.2115025520324707,6750.731124639511,18800,0,6750.731124639511,0.6605542898178101,26.86495937979053,1.596093773841858,3003,11513.73326444626,0.671360194683075,34.06164879315195,1.5273321866989136,0.6512752175331116,27.3727786688078,1.659209132194519,3000 -5275.910169124603,0.2403328418731689,7590.68705701828,21150,0,7590.68705701828,0.6623671054840088,26.76010784968576,1.5717984437942505,3003,12867.510741472244,0.6342859268188477,30.72447289999782,1.768483281135559,0.6521803736686707,27.661841579977537,1.6406331062316897,3000 -5830.0237646102905,0.2706208229064941,8430.784281015396,23500,0,8430.784281015396,0.667851984500885,27.18575754374719,1.5515425205230713,3003,14261.82667016983,0.6302686333656311,30.47776465991185,1.80091655254364,0.6556645035743713,27.54684052250174,1.6268913745880127,3000 -6342.428351640701,0.3031198978424072,9270.913045167925,25851,0,9270.913045167925,0.6690372824668884,27.224962151330093,1.541406273841858,3003,15614.465742111206,0.6471864581108093,31.524659902627377,1.6817272901535034,0.6565448641777039,27.799491595963826,1.6173309087753296,3000 -7036.736354827881,0.3339135646820068,10111.160809516909,28201,0,10111.160809516909,0.6693510413169861,27.50862562138666,1.5309597253799438,3003,17149.126772880554,0.6392074823379517,31.261923161305248,1.7369545698165894,0.6590991020202637,28.174983894107854,1.6041685342788696,3000 -7603.56007528305,0.3622970581054687,10951.253396511078,30552,0,10951.253396511078,0.6702225208282471,27.48608402007468,1.522372841835022,3003,18556.146406650543,0.6390373110771179,31.27872196488971,1.7507790327072144,0.6580327749252319,27.93156543775106,1.595872402191162,3000 -8134.691674232483,0.390512466430664,11791.432320356367,32903,0,11791.432320356367,0.6717913150787354,27.377931603736805,1.5121753215789795,3003,19927.55940771103,0.6409906148910522,31.478213278866903,1.71613872051239,0.6610209345817566,28.024704001168377,1.5890172719955444,3000 -8690.012436389923,0.4196438789367676,12631.56686782837,35253,0,12631.56686782837,0.6747545599937439,28.059762524152184,1.4998594522476196,3003,21323.122226953503,0.6431419849395752,31.608488104923683,1.7123115062713623,0.6615169048309326,28.002704988103364,1.5796931982040403,3000 -9245.60012793541,0.45088791847229,13471.785254716871,37603,0,13471.785254716871,0.6743013262748718,27.469852057344102,1.49647319316864,3003,22719.0359723568,0.6761975288391113,34.013655184488385,1.4833768606185913,0.6622856259346008,28.19304704726278,1.571945309638977,3000 -9767.721544027328,0.4817836284637451,14311.73891043663,39953,0,14311.73891043663,0.6754401326179504,28.240845223490155,1.4862949848175049,3003,24081.21613359452,0.6511838436126709,31.470296217893207,1.662556767463684,0.665162205696106,28.531168082604847,1.560990333557129,3000 -10280.534869909286,0.5133066177368164,15151.733073234558,42301,0,15151.733073234558,0.6777874827384949,28.08367316038388,1.4830689430236816,3003,25434.131522655487,0.6459149718284607,31.73541057923894,1.6967586278915403,0.6646910905838013,28.590118849023177,1.5630953311920166,3000 -10799.775356054306,0.5443167686462402,15991.678433418274,44650,0,15991.678433418274,0.6779733896255493,27.80825715299147,1.475687861442566,3003,26793.424834012985,0.6544402837753296,32.2549238963339,1.6286474466323853,0.666154146194458,28.460654397908776,1.5522340536117554,3000 -11432.476187705994,0.577672004699707,16831.602340459824,46999,0,16831.602340459824,0.6801813244819641,28.38958892681242,1.458451747894287,3003,28266.15973854065,0.6529893279075623,31.74572989276444,1.6469743251800537,0.6668609380722046,28.47980213878581,1.5455957651138306,3000 -11935.63941526413,0.6097137928009033,17671.68727684021,49349,0,17671.68727684021,0.6800883412361145,28.21943189274275,1.457780838012695,3003,29609.51440000534,0.64932781457901,31.92990226128372,1.6753994226455688,0.6689067482948303,28.933472139073533,1.5383224487304688,3000 -12620.1957821846,0.643932580947876,18511.66797375679,51699,0,18511.66797375679,0.6813085079193115,28.43325483477639,1.4497711658477783,3003,31134.161219596863,0.6542662978172302,32.459203013070265,1.6272282600402832,0.6693159341812134,28.620420565903288,1.5343632698059082,3000 -13282.779415607452,0.6774072647094727,19351.66876530648,54048,0,19351.66876530648,0.68177330493927,28.328816470806625,1.448590636253357,3003,32636.85775065422,0.6529883146286011,32.510313488174106,1.654189944267273,0.6688695549964905,28.57319480788474,1.5301557779312134,3000 -13884.517220497131,0.7161824703216553,20191.756650686264,56397,0,20191.756650686264,0.682877242565155,28.169233789147054,1.4382997751235962,3003,34078.80337572098,0.68346107006073,34.018130430164184,1.446035623550415,0.6704070568084717,29.086346659127816,1.5175199508666992,3000 -14473.63262438774,0.7506313323974609,21031.882925748825,58747,0,21031.882925748825,0.6855267286300659,28.621032891839445,1.4299596548080444,3003,35508.15443897247,0.6575668454170227,32.5052435486537,1.6168078184127808,0.6729984879493713,29.125088035511126,1.5107536315917969,3000 -14998.93830871582,0.7828867435455322,21871.777238607407,61096,0,21871.777238607407,0.6879902482032776,29.14197211825932,1.4211199283599854,3003,36873.4613969326,0.6561620235443115,32.63350373980873,1.6366279125213623,0.6729736924171448,29.140785546387807,1.512277603149414,3000 -15558.572283506392,0.8171112537384033,22711.711676359177,63444,0,22711.711676359177,0.6889896392822266,28.9947816834239,1.410127878189087,3003,38273.14426493645,0.6674286723136902,33.49294815381641,1.547412633895874,0.6748335361480713,29.477780759940647,1.4961423873901367,3000 -16119.899282455444,0.8502511978149414,23551.849817037582,65794,0,23551.849817037582,0.6903724670410156,29.56408904455485,1.4074007272720337,3003,39674.71835613251,0.6565250754356384,32.66615467451501,1.6203813552856443,0.676073431968689,29.329756231686886,1.4927388429641724,3000 -16683.019134521484,0.8892166614532471,24391.99292993545,68144,0,24391.99292993545,0.6897798180580139,29.068497099009715,1.4021321535110474,3003,41078.09699392319,0.6547269821166992,32.47219353442573,1.634839653968811,0.677561342716217,29.4492775140218,1.4846450090408323,3000 -17196.86828827858,0.9239933490753174,25232.03781723976,70494,0,25232.03781723976,0.6899541020393372,29.03229432623568,1.3922793865203855,3003,42432.101380348206,0.6648585796356201,32.83949660411032,1.5671465396881104,0.6763090491294861,29.078669687795948,1.4787070751190186,3000 -17712.821888685226,0.9654061794281006,26072.20747256279,72843,0,26072.20747256279,0.6910812854766846,29.18977920484513,1.3824280500411987,3003,43788.344264507294,0.6621139049530029,33.390701136287404,1.5815826654434204,0.6767553687095642,29.35268774776388,1.473919153213501,3000 -18301.44424295425,1.0022211074829102,26912.381454229355,75193,0,26912.381454229355,0.693544864654541,29.07100048590121,1.378287672996521,3003,45217.25480270386,0.6858174800872803,34.366781063424106,1.426164984703064,0.6785904765129089,29.51441599734365,1.4634785652160645,3000 -19005.60043978691,1.0388991832733154,27752.45276355744,77543,0,27752.45276355744,0.6951484680175781,29.485973265345606,1.3676152229309082,3003,46761.59459590912,0.6669697165489197,33.21220426533364,1.5403841733932495,0.6799419522285461,29.43968950403044,1.4549518823623655,3000 -19533.910708904263,1.0820260047912598,28592.432891607285,79893,0,28592.432891607285,0.6960781216621399,29.612354085956323,1.3559688329696655,3003,48130.00425624848,0.6644576787948608,32.65712602644397,1.573724389076233,0.6827317476272583,29.785093513403424,1.4516927003860474,3000 -20242.84255671501,1.1201717853546145,29432.561252594,82243,0,29432.561252594,0.6966823935508728,29.61672950952151,1.3540998697280884,3003,49679.17920422554,0.6750671863555908,33.87395627976873,1.4929388761520386,0.6826077699661255,29.84936934505664,1.4436779022216797,3000 -20943.49081230164,1.1579155921936035,30272.66680049896,84593,0,30272.66680049896,0.6992040276527405,29.948791249609343,1.3429124355316162,3003,51220.04585957527,0.6679847240447998,33.47505030763876,1.5398011207580566,0.6833269000053406,29.752018414615385,1.4395651817321775,3000 -21512.685092926025,1.1974620819091797,31112.70547223091,86944,0,31112.70547223091,0.7015630006790161,30.01117652283077,1.336978793144226,3003,52629.39412069321,0.66844642162323,33.3239685944154,1.5463353395462036,0.6863275170326233,30.17483410773,1.4329047203063965,3000 -22114.81109213829,1.2344374656677246,31952.67540025711,89293,0,31952.67540025711,0.7022950649261475,30.109991566006315,1.3293834924697876,3003,54071.60535812378,0.6765536069869995,33.80841067188661,1.4846397638320925,0.6877037882804871,30.144694232267003,1.4231077432632446,3000 -22637.1190366745,1.2718265056610107,32792.68555688858,91642,0,32792.68555688858,0.7022950649261475,30.012051023290432,1.3214539289474487,3003,55434.03802442551,0.6729382872581482,33.947059749229915,1.5041873455047607,0.6867242455482483,30.009697354141604,1.4200741052627563,3000 -23166.96990251541,1.3109593391418457,33632.6453063488,93991,0,33632.6453063488,0.7032014727592468,29.828300853864768,1.3106350898742676,3003,56803.9655148983,0.6925990581512451,34.75326842015933,1.38995623588562,0.6878277659416199,30.225550510148068,1.4126427173614502,3000 -23863.072603464127,1.351414918899536,34472.82131195068,96341,0,34472.82131195068,0.7055023312568665,30.228751807442475,1.3078140020370483,3003,58340.36324286461,0.6803255081176758,33.819762322880095,1.4658100605010986,0.6885097622871399,29.990235745955843,1.407971978187561,3000 -24379.07930803299,1.391371726989746,35312.71530985832,98690,0,35312.71530985832,0.7070013284683228,30.51100426873089,1.3015625476837158,3003,59696.38206171989,0.678549587726593,34.745071742947765,1.4740723371505735,0.689811646938324,30.306885565787923,1.4059818983078003,3000 -24952.356118440628,1.4311096668243408,36152.8135163784,101040,0,36152.8135163784,0.7060601115226746,30.27591467570254,1.2959932088851929,3003,61109.87213683128,0.6873783469200134,34.822233159168285,1.4144078493118286,0.689452052116394,30.43498417126523,1.394218683242798,3000 -25546.56099653244,1.4709479808807373,36992.80270028114,103390,0,36992.80270028114,0.7068851590156555,30.54459233862504,1.2904762029647827,3003,62544.18134522438,0.6787706017494202,34.936464710336374,1.472381591796875,0.6904067993164062,30.4547279180483,1.394535779953003,3000 -26074.223981142044,1.5178804397583008,37832.90144467354,105740,0,37832.90144467354,0.7088722586631775,30.58754004769873,1.2860301733016968,3003,63912.06481075287,0.6793479919433594,34.524389490258514,1.474918246269226,0.6916218996047974,30.763161599157005,1.387118220329285,3000 -26644.022025108337,1.5571041107177734,38672.90054440498,108090,0,38672.90054440498,0.7091395258903503,30.70939975137136,1.276688575744629,3003,65321.97705411911,0.6895817518234253,34.956725807169725,1.4022433757781982,0.6924774646759033,30.50776881508136,1.3789490461349487,3000 -27234.110072135925,1.5973656177520752,39513.07421565056,110440,0,39513.07421565056,0.7114055156707764,30.80145430874005,1.2721518278121948,3003,66752.35574197769,0.6895023584365845,35.16845176199788,1.4150652885437012,0.6943125128746033,30.72178388616816,1.378193974494934,3000 -27799.28997278213,1.638994216918945,40353.12133765221,112791,0,40353.12133765221,0.7100808024406433,30.55024025701441,1.2701268196105957,3003,68157.70009493828,0.6971345543861389,35.54626248357229,1.3650158643722534,0.6938785314559937,30.809153601006447,1.3746583461761477,3000 -28490.20787382126,1.6810777187347412,41193.01627731323,115141,0,41193.01627731323,0.7127999663352966,30.843691765002152,1.2630198001861572,3003,69688.62987804413,0.6917914748191833,35.197873516392164,1.399938702583313,0.6941389441490173,30.793833684223024,1.3688514232635498,3000 -29044.949915885925,1.7298576831817627,42032.95151424408,117491,0,42032.95151424408,0.7123932838439941,30.727742222734168,1.262866735458374,3003,71083.43123865128,0.6893499493598938,35.249029116488906,1.4066481590270996,0.6947464942932129,30.907103979729687,1.3684260845184326,3000 -29609.83997654915,1.772599458694458,42873.06673336029,119841,0,42873.06673336029,0.7137296199798584,30.89031023851516,1.258353590965271,3003,72488.55532360077,0.6982017755508423,35.42124147020956,1.360303521156311,0.696234405040741,30.81000255265572,1.3644174337387085,3000 -30275.60227298737,1.814788818359375,43712.94831061363,122190,0,43712.94831061363,0.7138806581497192,31.07766958165444,1.2562905550003052,3003,73994.31863379478,0.6944260001182556,35.73075424767851,1.3774758577346802,0.6958252191543579,30.87195959719736,1.363082766532898,3000 -30864.93314909935,1.8580679893493648,44553.16763043404,124540,0,44553.16763043404,0.7133693695068359,30.94660623697596,1.2556757926940918,3003,75423.9879014492,0.6927077174186707,35.60168167605079,1.3844177722930908,0.6960111856460571,31.036330806615823,1.361225128173828,3000 -31472.653499126434,1.9009826183319087,45393.23704409599,126890,0,45393.23704409599,0.7141130566596985,30.891123087192003,1.2537082433700562,3003,76871.89690589905,0.6990454196929932,35.55317402084076,1.3546652793884275,0.6965071558952332,30.923319946605933,1.360289216041565,3000 -32112.87235379219,1.943530559539795,46233.19505262375,129239,0,46233.19505262375,0.714043378829956,30.942558677133007,1.252797245979309,3003,78352.19459056854,0.6980040073394775,35.294106469705945,1.359439492225647,0.6968419551849365,30.871485970728653,1.3597447872161863,3000 -32698.58621668816,1.9933269023895264,47073.37712454796,131588,0,47073.37712454796,0.714043378829956,30.929931785958644,1.252558946609497,3003,79778.21843957901,0.6947080492973328,35.40285800418826,1.3761208057403564,0.6967551708221436,30.906944090047656,1.3596832752227783,3000 -33292.195563316345,2.0389645099639893,47697.16169524193,133333,0,47697.16169524193,0.7140317559242249,30.921156057591386,1.2525193691253662,3003,80995.71348690987,0.6980851888656616,35.45882877194227,1.3578859567642212,0.6968915462493896,30.90651925945721,1.359694004058838,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/measurements.csv deleted file mode 100644 index 3895848d8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1394 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.717681,11.032805,,,,,,,,,,,,,,,,, -1,,,0.0006145372171886,11.063870429992676,0.0,0.0004835649742744,11.036645889282228,0.0,3000.0,0.0007088489946909,11.041826248168944,0.0,3003.0,30.065114498138428,899.9457039833069,30.065114498138428,869.8805475234985,0.0,0.0 -100,0.16883129,8.255187,,,,,,,,,,,,,,,,, -200,0.568055,7.432907,,,,,,,,,,,,,,,,, -300,0.6005472,6.8771214,,,,,,,,,,,,,,,,, -400,0.45084456,6.282825,,,,,,,,,,,,,,,,, -500,0.36807558,5.860792,,,,,,,,,,,,,,,,, -600,0.35619077,5.561835,,,,,,,,,,,,,,,,, -700,0.4322507,5.266047,,,,,,,,,,,,,,,,, -800,0.7127715,4.982022,,,,,,,,,,,,,,,,, -900,0.5225726,4.813824,,,,,,,,,,,,,,,,, -1000,0.5303154,4.556843,,,,,,,,,,,,,,,,, -1100,0.53954506,4.2719283,,,,,,,,,,,,,,,,, -1200,0.553711,4.115427,,,,,,,,,,,,,,,,, -1300,0.5937428,4.0353394,,,,,,,,,,,,,,,,, -1400,0.53078926,3.7172914,,,,,,,,,,,,,,,,, -1500,0.6158643,3.6267145,,,,,,,,,,,,,,,,, -1600,0.68593,3.5660255,,,,,,,,,,,,,,,,, -1700,0.6303174,3.516877,,,,,,,,,,,,,,,,, -1800,0.5276762,3.335792,,,,,,,,,,,,,,,,, -1900,0.43161523,3.2170444,,,,,,,,,,,,,,,,, -2000,0.390192,3.1880329,,,,,,,,,,,,,,,,, -2100,0.41922495,3.085104,,,,,,,,,,,,,,,,, -2200,0.3804072,3.0833647,,,,,,,,,,,,,,,,, -2300,0.5281038,3.0522695,,,,,,,,,,,,,,,,, -2349,,,0.5141674280166626,2.880971193313598,22.965880254505947,0.5140295624732971,2.860582113265991,18.528226390624614,3000.0,0.5108012557029724,2.912444591522217,16.717221966085745,3003.0,870.2016100883484,2221.916384935379,870.2016100883484,1351.6217308044434,0.0189239978790283,0.0 -2400,0.3794611,2.927327,,,,,,,,,,,,,,,,, -2500,0.32846597,2.841686,,,,,,,,,,,,,,,,, -2600,0.3135362,2.8681955,,,,,,,,,,,,,,,,, -2700,0.37832546,2.8418572,,,,,,,,,,,,,,,,, -2800,0.35523057,2.7884207,,,,,,,,,,,,,,,,, -2900,0.30884793,2.7859244,,,,,,,,,,,,,,,,, -3000,0.31046394,2.7306082,,,,,,,,,,,,,,,,, -3100,0.26976183,2.5452626,,,,,,,,,,,,,,,,, -3200,0.22881934,2.6819062,,,,,,,,,,,,,,,,, -3300,0.24443203,2.614486,,,,,,,,,,,,,,,,, -3400,0.2518608,2.5465794,,,,,,,,,,,,,,,,, -3500,0.22732209,2.4908924,,,,,,,,,,,,,,,,, -3600,0.22108752,2.6612794,,,,,,,,,,,,,,,,, -3700,0.2058108,2.4807932,,,,,,,,,,,,,,,,, -3800,0.18980442,2.4745338,,,,,,,,,,,,,,,,, -3900,0.23309794,2.4362805,,,,,,,,,,,,,,,,, -4000,0.2044445,2.5289934,,,,,,,,,,,,,,,,, -4100,0.18440907,2.2809148,,,,,,,,,,,,,,,,, -4200,0.20842636,2.3985422,,,,,,,,,,,,,,,,, -4300,0.18668155,2.48198,,,,,,,,,,,,,,,,, -4400,0.18247697,2.3796453,,,,,,,,,,,,,,,,, -4500,0.1879587,2.313947,,,,,,,,,,,,,,,,, -4600,0.19456416,2.3426695,,,,,,,,,,,,,,,,, -4699,,,0.5758695602416992,2.2680907249450684,27.23287427611181,0.589490532875061,2.1476547718048096,23.513794603427275,3000.0,0.5915867686271667,2.126992464065552,22.13067104185569,3003.0,1710.351322889328,3532.369195461273,1710.351322889328,1821.823566198349,0.0436756610870361,0.0 -4700,0.17690098,2.2769165,,,,,,,,,,,,,,,,, -4800,0.18978807,2.3698232,,,,,,,,,,,,,,,,, -4900,0.15938333,2.2700264,,,,,,,,,,,,,,,,, -5000,0.15983823,2.2569509,,,,,,,,,,,,,,,,, -5100,0.15338542,2.317498,,,,,,,,,,,,,,,,, -5200,0.20335571,2.285846,,,,,,,,,,,,,,,,, -5300,0.16538292,2.2856946,,,,,,,,,,,,,,,,, -5400,0.15247993,2.2801394,,,,,,,,,,,,,,,,, -5500,0.15306251,2.1989207,,,,,,,,,,,,,,,,, -5600,0.16117713,2.2097294,,,,,,,,,,,,,,,,, -5700,0.16901483,2.2852247,,,,,,,,,,,,,,,,, -5800,0.14989954,2.2333894,,,,,,,,,,,,,,,,, -5900,0.14576913,2.2547274,,,,,,,,,,,,,,,,, -6000,0.18075825,2.222018,,,,,,,,,,,,,,,,, -6100,0.1617684,2.1810246,,,,,,,,,,,,,,,,, -6200,0.1635096,2.2109506,,,,,,,,,,,,,,,,, -6300,0.18813221,2.1436005,,,,,,,,,,,,,,,,, -6400,0.19042523,2.2384367,,,,,,,,,,,,,,,,, -6500,0.16618347,2.2059882,,,,,,,,,,,,,,,,, -6600,0.18049897,2.204409,,,,,,,,,,,,,,,,, -6700,0.1513997,2.1499581,,,,,,,,,,,,,,,,, -6800,0.15698148,2.151864,,,,,,,,,,,,,,,,, -6900,0.20707539,2.1808646,,,,,,,,,,,,,,,,, -7000,0.14863613,2.1489637,,,,,,,,,,,,,,,,, -7049,,,0.6065484881401062,1.9995412826538088,29.012708543686976,0.6178224682807922,1.92452335357666,25.41065344375216,3000.0,0.620010495185852,1.8880869150161743,23.77728489354792,3003.0,2550.4531705379486,4838.414577007294,2550.4531705379486,2287.665424346924,0.0689821243286132,0.0 -7100,0.16762614,2.1212914,,,,,,,,,,,,,,,,, -7200,0.20552814,2.1868489,,,,,,,,,,,,,,,,, -7300,0.17960072,2.085382,,,,,,,,,,,,,,,,, -7400,0.19595902,2.1173635,,,,,,,,,,,,,,,,, -7500,0.15288723,2.0508037,,,,,,,,,,,,,,,,, -7600,0.15790677,2.1878767,,,,,,,,,,,,,,,,, -7700,0.24347386,2.0003264,,,,,,,,,,,,,,,,, -7800,0.15961458,2.0117002,,,,,,,,,,,,,,,,, -7900,0.19171733,2.0401924,,,,,,,,,,,,,,,,, -8000,0.16592665,2.2052326,,,,,,,,,,,,,,,,, -8100,0.2284381,2.0254273,,,,,,,,,,,,,,,,, -8200,0.17745167,2.0843284,,,,,,,,,,,,,,,,, -8300,0.18547443,2.0671637,,,,,,,,,,,,,,,,, -8400,0.1641682,2.047502,,,,,,,,,,,,,,,,, -8500,0.19956304,2.1391714,,,,,,,,,,,,,,,,, -8600,0.17019251,2.0543964,,,,,,,,,,,,,,,,, -8700,0.1822863,2.0576322,,,,,,,,,,,,,,,,, -8800,0.17097159,2.1256516,,,,,,,,,,,,,,,,, -8900,0.21023358,2.0194786,,,,,,,,,,,,,,,,, -9000,0.15616873,2.011052,,,,,,,,,,,,,,,,, -9100,0.18001842,2.0787854,,,,,,,,,,,,,,,,, -9200,0.229785,2.1427135,,,,,,,,,,,,,,,,, -9300,0.1506812,1.9707092,,,,,,,,,,,,,,,,, -9399,,,0.6141012907028198,1.939665675163269,29.74698626680871,0.6299859881401062,1.8144594430923464,26.071460310919417,3000.0,0.6367672085762024,1.7608153820037842,25.14278859816317,3003.0,3390.548141002655,6187.265499591827,3390.548141002655,2796.3146154880524,0.0997548103332519,0.0 -9400,0.1760852,2.0991912,,,,,,,,,,,,,,,,, -9500,0.21665397,2.0452006,,,,,,,,,,,,,,,,, -9600,0.17456125,2.0004852,,,,,,,,,,,,,,,,, -9700,0.17908394,1.9775466,,,,,,,,,,,,,,,,, -9800,0.18613918,2.0612178,,,,,,,,,,,,,,,,, -9900,0.19143349,1.9897802,,,,,,,,,,,,,,,,, -10000,0.18226205,1.931681,,,,,,,,,,,,,,,,, -10100,0.14328986,1.9763998,,,,,,,,,,,,,,,,, -10200,0.17797047,2.043988,,,,,,,,,,,,,,,,, -10300,0.17336826,1.8982981,,,,,,,,,,,,,,,,, -10400,0.1788328,1.9239737,,,,,,,,,,,,,,,,, -10500,0.18860757,2.0306215,,,,,,,,,,,,,,,,, -10600,0.18730995,1.9563181,,,,,,,,,,,,,,,,, -10700,0.18015778,2.0236638,,,,,,,,,,,,,,,,, -10800,0.17534165,2.0091383,,,,,,,,,,,,,,,,, -10900,0.16605236,1.9624511,,,,,,,,,,,,,,,,, -11000,0.23784095,2.0648863,,,,,,,,,,,,,,,,, -11100,0.2256989,1.9571772,,,,,,,,,,,,,,,,, -11200,0.33621666,1.9792907,,,,,,,,,,,,,,,,, -11300,0.18194956,2.0383537,,,,,,,,,,,,,,,,, -11400,0.23367597,1.9089859,,,,,,,,,,,,,,,,, -11500,0.1786068,2.0635943,,,,,,,,,,,,,,,,, -11600,0.2657094,1.9212557,,,,,,,,,,,,,,,,, -11700,0.2186911,1.9688576,,,,,,,,,,,,,,,,, -11749,,,0.6164451241493225,1.9157977104187007,29.7451596694926,0.6390249133110046,1.7488797903060913,26.594450639942533,3000.0,0.6456801295280457,1.687228083610535,25.485975409278893,3003.0,4230.704621553421,7516.816064357758,4230.704621553421,3285.607243537903,0.1272540092468261,0.0 -11800,0.29445466,1.949369,,,,,,,,,,,,,,,,, -11900,0.18599114,2.0000215,,,,,,,,,,,,,,,,, -12000,0.1950868,2.0239038,,,,,,,,,,,,,,,,, -12100,0.19868645,1.9359846,,,,,,,,,,,,,,,,, -12200,0.16393478,1.9333849,,,,,,,,,,,,,,,,, -12300,0.16966757,1.995674,,,,,,,,,,,,,,,,, -12400,0.15793073,1.8857573,,,,,,,,,,,,,,,,, -12500,0.18337637,1.993022,,,,,,,,,,,,,,,,, -12600,0.16205037,1.9104164,,,,,,,,,,,,,,,,, -12700,0.23368247,1.9410429,,,,,,,,,,,,,,,,, -12800,0.19445033,2.0532174,,,,,,,,,,,,,,,,, -12900,0.23594719,1.8515569,,,,,,,,,,,,,,,,, -13000,0.16364627,1.9598011,,,,,,,,,,,,,,,,, -13100,0.27796367,2.0105937,,,,,,,,,,,,,,,,, -13200,0.17769249,1.9895723,,,,,,,,,,,,,,,,, -13300,0.18611774,1.9887898,,,,,,,,,,,,,,,,, -13400,0.22202186,1.9505748,,,,,,,,,,,,,,,,, -13500,0.18226595,2.0082264,,,,,,,,,,,,,,,,, -13600,0.22805634,1.906826,,,,,,,,,,,,,,,,, -13700,0.18683618,1.9420699,,,,,,,,,,,,,,,,, -13800,0.1640953,1.8955073,,,,,,,,,,,,,,,,, -13900,0.19414964,1.9006668,,,,,,,,,,,,,,,,, -14000,0.18945496,1.9414028,,,,,,,,,,,,,,,,, -14099,,,0.6286578178405762,1.8235194683074951,30.895011292840326,0.6437737941741943,1.7144992351531982,26.906824146885256,3000.0,0.651943564414978,1.6426197290420532,26.12009908804103,3003.0,5070.610538482666,8873.669828891754,5070.610538482666,3802.454508543015,0.1534552574157714,0.0 -14100,0.21434961,1.9534329,,,,,,,,,,,,,,,,, -14200,0.3883765,1.8810344,,,,,,,,,,,,,,,,, -14300,0.19630265,1.8915744,,,,,,,,,,,,,,,,, -14400,0.1931719,1.9055912,,,,,,,,,,,,,,,,, -14500,0.2195052,1.8909853,,,,,,,,,,,,,,,,, -14600,0.22070736,1.9734929,,,,,,,,,,,,,,,,, -14700,0.17968883,1.8873031,,,,,,,,,,,,,,,,, -14800,0.2497622,1.9511298,,,,,,,,,,,,,,,,, -14900,0.1821397,1.9221469,,,,,,,,,,,,,,,,, -15000,0.18868694,1.9363265,,,,,,,,,,,,,,,,, -15100,0.26639295,1.9065185,,,,,,,,,,,,,,,,, -15200,0.19371201,1.932638,,,,,,,,,,,,,,,,, -15300,0.18517464,1.9398321,,,,,,,,,,,,,,,,, -15400,0.19140723,1.9275733,,,,,,,,,,,,,,,,, -15500,0.17339246,1.9030691,,,,,,,,,,,,,,,,, -15600,0.21705203,1.9917482,,,,,,,,,,,,,,,,, -15700,0.19739881,1.9690999,,,,,,,,,,,,,,,,, -15800,0.17160235,1.850749,,,,,,,,,,,,,,,,, -15900,0.23259677,1.9604748,,,,,,,,,,,,,,,,, -16000,0.24771504,1.7979985,,,,,,,,,,,,,,,,, -16100,0.19544333,1.9373626,,,,,,,,,,,,,,,,, -16200,0.16985828,1.8333431,,,,,,,,,,,,,,,,, -16300,0.17466031,1.9258562,,,,,,,,,,,,,,,,, -16400,0.18332775,1.8401302,,,,,,,,,,,,,,,,, -16449,,,0.6269233226776123,1.830869555473328,30.456122777081536,0.6486342549324036,1.6799792051315308,27.54467483066716,3000.0,0.6589855551719666,1.607773780822754,26.91923126334531,3003.0,5910.604150533676,10212.38568687439,5910.604150533676,4301.075335741043,0.1794331073760986,0.0 -16500,0.18593384,1.8366308,,,,,,,,,,,,,,,,, -16600,0.20780982,1.8753178,,,,,,,,,,,,,,,,, -16700,0.30227864,1.8349051,,,,,,,,,,,,,,,,, -16800,0.19539477,1.8961427,,,,,,,,,,,,,,,,, -16900,0.17088915,1.9729178,,,,,,,,,,,,,,,,, -17000,0.1858572,2.0097752,,,,,,,,,,,,,,,,, -17100,0.2644845,1.9437159,,,,,,,,,,,,,,,,, -17200,0.16997719,1.8745909,,,,,,,,,,,,,,,,, -17300,0.19074224,1.8178183,,,,,,,,,,,,,,,,, -17400,0.25440943,1.9295185,,,,,,,,,,,,,,,,, -17500,0.16106534,1.8389131,,,,,,,,,,,,,,,,, -17600,0.17007059,1.897902,,,,,,,,,,,,,,,,, -17700,0.19825497,1.8850211,,,,,,,,,,,,,,,,, -17800,0.21228245,1.9345373,,,,,,,,,,,,,,,,, -17900,0.21797805,1.8087685,,,,,,,,,,,,,,,,, -18000,0.1736568,1.8764923,,,,,,,,,,,,,,,,, -18100,0.17427717,1.7957882,,,,,,,,,,,,,,,,, -18200,0.19392258,1.8850543,,,,,,,,,,,,,,,,, -18300,0.23456146,1.9174623,,,,,,,,,,,,,,,,, -18400,0.25195128,1.8344754,,,,,,,,,,,,,,,,, -18500,0.19221339,1.8993244,,,,,,,,,,,,,,,,, -18600,0.20687066,1.8492402,,,,,,,,,,,,,,,,, -18700,0.1949129,1.8736812,,,,,,,,,,,,,,,,, -18800,,,0.671360194683075,1.5273321866989136,34.06164879315195,0.6512752175331116,1.659209132194519,27.3727786688078,3000.0,0.6605542898178101,1.596093773841858,26.86495937979053,3003.0,6750.731124639511,11513.73326444626,6750.731124639511,4762.190878391266,0.2115025520324707,0.0 -18800,0.22626917,1.7856091,,,,,,,,,,,,,,,,, -18900,0.18307865,1.8828182,,,,,,,,,,,,,,,,, -19000,0.18915857,1.7760453,,,,,,,,,,,,,,,,, -19100,0.2186645,1.885531,,,,,,,,,,,,,,,,, -19200,0.195995,1.8879049,,,,,,,,,,,,,,,,, -19300,0.21808341,1.7671671,,,,,,,,,,,,,,,,, -19400,0.18434094,1.860273,,,,,,,,,,,,,,,,, -19500,0.29287493,1.864613,,,,,,,,,,,,,,,,, -19600,0.17739631,1.8861585,,,,,,,,,,,,,,,,, -19700,0.18324775,1.8638426,,,,,,,,,,,,,,,,, -19800,0.22857267,1.9248224,,,,,,,,,,,,,,,,, -19900,0.1882083,1.8315011,,,,,,,,,,,,,,,,, -20000,0.20436127,1.8474575,,,,,,,,,,,,,,,,, -20100,0.17489834,1.825701,,,,,,,,,,,,,,,,, -20200,0.17162642,1.863176,,,,,,,,,,,,,,,,, -20300,0.18134342,1.8230754,,,,,,,,,,,,,,,,, -20400,0.2738484,1.9707639,,,,,,,,,,,,,,,,, -20500,0.28263944,1.9065275,,,,,,,,,,,,,,,,, -20600,0.2562378,1.9041564,,,,,,,,,,,,,,,,, -20700,0.18314935,1.8620607,,,,,,,,,,,,,,,,, -20800,0.28973886,1.8289349,,,,,,,,,,,,,,,,, -20900,0.20075582,1.7967051,,,,,,,,,,,,,,,,, -21000,0.17422134,1.9206715,,,,,,,,,,,,,,,,, -21100,0.2319186,1.8524542,,,,,,,,,,,,,,,,, -21150,,,0.6342859268188477,1.768483281135559,30.72447289999782,0.6521803736686707,1.6406331062316897,27.661841579977537,3000.0,0.6623671054840088,1.5717984437942505,26.76010784968576,3003.0,7590.68705701828,12867.510741472244,7590.68705701828,5275.910169124603,0.2403328418731689,0.0 -21200,0.18982525,1.8701483,,,,,,,,,,,,,,,,, -21300,0.21041359,1.8175368,,,,,,,,,,,,,,,,, -21400,0.17247179,1.952145,,,,,,,,,,,,,,,,, -21500,0.19744918,1.8013729,,,,,,,,,,,,,,,,, -21600,0.1792484,1.8189089,,,,,,,,,,,,,,,,, -21700,0.25908545,1.8483951,,,,,,,,,,,,,,,,, -21800,0.18495986,1.9215465,,,,,,,,,,,,,,,,, -21900,0.22239506,1.8232174,,,,,,,,,,,,,,,,, -22000,0.18912959,1.8398463,,,,,,,,,,,,,,,,, -22100,0.24104986,1.8830774,,,,,,,,,,,,,,,,, -22200,0.2072942,1.8422774,,,,,,,,,,,,,,,,, -22300,0.22540487,1.8281007,,,,,,,,,,,,,,,,, -22400,0.17929462,1.9276657,,,,,,,,,,,,,,,,, -22500,0.28099275,1.8781948,,,,,,,,,,,,,,,,, -22600,0.21286625,1.8381394,,,,,,,,,,,,,,,,, -22700,0.18972446,1.9023032,,,,,,,,,,,,,,,,, -22800,0.22907017,1.7068692,,,,,,,,,,,,,,,,, -22900,0.20929675,1.9008688,,,,,,,,,,,,,,,,, -23000,0.18499768,1.8742748,,,,,,,,,,,,,,,,, -23100,0.17740695,1.8584948,,,,,,,,,,,,,,,,, -23200,0.2017672,1.857424,,,,,,,,,,,,,,,,, -23300,0.1843958,1.7924403,,,,,,,,,,,,,,,,, -23400,0.17212294,1.8620538,,,,,,,,,,,,,,,,, -23500,,,0.6302686333656311,1.80091655254364,30.47776465991185,0.6556645035743713,1.6268913745880127,27.54684052250174,3000.0,0.667851984500885,1.5515425205230713,27.18575754374719,3003.0,8430.784281015396,14261.82667016983,8430.784281015396,5830.0237646102905,0.2706208229064941,0.0 -23500,0.18746792,1.834688,,,,,,,,,,,,,,,,, -23600,0.26692516,1.9071944,,,,,,,,,,,,,,,,, -23700,0.20272651,1.8367865,,,,,,,,,,,,,,,,, -23800,0.1988028,1.8973366,,,,,,,,,,,,,,,,, -23900,0.19735901,1.8299698,,,,,,,,,,,,,,,,, -24000,0.21114093,1.8245217,,,,,,,,,,,,,,,,, -24100,0.18998714,1.9264975,,,,,,,,,,,,,,,,, -24200,0.17090744,1.8390914,,,,,,,,,,,,,,,,, -24300,0.19288412,1.9698051,,,,,,,,,,,,,,,,, -24400,0.20683855,1.8900213,,,,,,,,,,,,,,,,, -24500,0.18908101,1.843176,,,,,,,,,,,,,,,,, -24600,0.18981552,1.8228797,,,,,,,,,,,,,,,,, -24700,0.19010632,1.7577692,,,,,,,,,,,,,,,,, -24800,0.18551907,1.860153,,,,,,,,,,,,,,,,, -24900,0.18705161,1.8491117,,,,,,,,,,,,,,,,, -25000,0.26323125,1.8646883,,,,,,,,,,,,,,,,, -25100,0.22618529,1.8517572,,,,,,,,,,,,,,,,, -25200,0.18167378,1.821015,,,,,,,,,,,,,,,,, -25300,0.2543828,1.8897383,,,,,,,,,,,,,,,,, -25400,0.16498181,1.8231541,,,,,,,,,,,,,,,,, -25500,0.17101744,1.727381,,,,,,,,,,,,,,,,, -25600,0.21908544,1.8516073,,,,,,,,,,,,,,,,, -25700,0.18336043,1.8071852,,,,,,,,,,,,,,,,, -25800,0.19987145,1.7324227,,,,,,,,,,,,,,,,, -25851,,,0.6471864581108093,1.6817272901535034,31.524659902627377,0.6565448641777039,1.6173309087753296,27.799491595963826,3000.0,0.6690372824668884,1.541406273841858,27.224962151330093,3003.0,9270.913045167925,15614.465742111206,9270.913045167925,6342.428351640701,0.3031198978424072,0.0 -25900,0.20029473,1.8779278,,,,,,,,,,,,,,,,, -26000,0.19068535,1.6984247,,,,,,,,,,,,,,,,, -26100,0.24839295,1.908454,,,,,,,,,,,,,,,,, -26200,0.26426515,1.9015661,,,,,,,,,,,,,,,,, -26300,0.17784777,1.8272415,,,,,,,,,,,,,,,,, -26400,0.24246965,1.9471741,,,,,,,,,,,,,,,,, -26500,0.20054872,1.8796377,,,,,,,,,,,,,,,,, -26600,0.1910717,1.85015,,,,,,,,,,,,,,,,, -26700,0.20227161,1.8216959,,,,,,,,,,,,,,,,, -26800,0.22051862,1.8845376,,,,,,,,,,,,,,,,, -26900,0.19937447,1.824133,,,,,,,,,,,,,,,,, -27000,0.2062074,1.790739,,,,,,,,,,,,,,,,, -27100,0.21696259,1.8642726,,,,,,,,,,,,,,,,, -27200,0.19195233,1.8020383,,,,,,,,,,,,,,,,, -27300,0.24671324,1.8143119,,,,,,,,,,,,,,,,, -27400,0.19530971,1.7569706,,,,,,,,,,,,,,,,, -27500,0.1941771,1.7781931,,,,,,,,,,,,,,,,, -27600,0.18513606,1.8738616,,,,,,,,,,,,,,,,, -27700,0.23928289,1.8947415,,,,,,,,,,,,,,,,, -27800,0.22598624,1.8594754,,,,,,,,,,,,,,,,, -27900,0.20171615,1.8172373,,,,,,,,,,,,,,,,, -28000,0.23989986,1.8235376,,,,,,,,,,,,,,,,, -28100,0.18355218,1.772053,,,,,,,,,,,,,,,,, -28200,0.18828927,1.8267689,,,,,,,,,,,,,,,,, -28201,,,0.6392074823379517,1.7369545698165894,31.261923161305248,0.6590991020202637,1.6041685342788696,28.174983894107854,3000.0,0.6693510413169861,1.5309597253799438,27.50862562138666,3003.0,10111.160809516909,17149.126772880554,10111.160809516909,7036.736354827881,0.3339135646820068,0.0 -28300,0.17058189,1.7831784,,,,,,,,,,,,,,,,, -28400,0.22064525,1.8038067,,,,,,,,,,,,,,,,, -28500,0.25769,1.8639352,,,,,,,,,,,,,,,,, -28600,0.18587497,1.8505992,,,,,,,,,,,,,,,,, -28700,0.17510353,1.8124864,,,,,,,,,,,,,,,,, -28800,0.20846988,1.8628161,,,,,,,,,,,,,,,,, -28900,0.29323825,1.9078782,,,,,,,,,,,,,,,,, -29000,0.24357527,1.8071392,,,,,,,,,,,,,,,,, -29100,0.18800732,1.8520691,,,,,,,,,,,,,,,,, -29200,0.23057286,1.8508146,,,,,,,,,,,,,,,,, -29300,0.18311952,1.8974566,,,,,,,,,,,,,,,,, -29400,0.19884834,1.8236462,,,,,,,,,,,,,,,,, -29500,0.2091104,1.86766,,,,,,,,,,,,,,,,, -29600,0.19767118,1.7983902,,,,,,,,,,,,,,,,, -29700,0.21064812,1.8355955,,,,,,,,,,,,,,,,, -29800,0.19722736,1.8011762,,,,,,,,,,,,,,,,, -29900,0.19913822,1.8412347,,,,,,,,,,,,,,,,, -30000,0.2020644,1.8306195,,,,,,,,,,,,,,,,, -30100,0.22876939,1.8597394,,,,,,,,,,,,,,,,, -30200,0.2378427,1.8132308,,,,,,,,,,,,,,,,, -30300,0.21472064,1.8402727,,,,,,,,,,,,,,,,, -30400,0.187572,1.7572504,,,,,,,,,,,,,,,,, -30500,0.24529198,1.8497775,,,,,,,,,,,,,,,,, -30552,,,0.6390373110771179,1.7507790327072144,31.27872196488971,0.6580327749252319,1.595872402191162,27.93156543775106,3000.0,0.6702225208282471,1.522372841835022,27.48608402007468,3003.0,10951.253396511078,18556.146406650543,10951.253396511078,7603.56007528305,0.3622970581054687,0.0 -30600,0.19968565,1.7308704,,,,,,,,,,,,,,,,, -30700,0.17697904,1.8099341,,,,,,,,,,,,,,,,, -30800,0.2020352,1.7538346,,,,,,,,,,,,,,,,, -30900,0.30139324,1.7967256,,,,,,,,,,,,,,,,, -31000,0.3575427,1.8227589,,,,,,,,,,,,,,,,, -31100,0.2015191,1.8737091,,,,,,,,,,,,,,,,, -31200,0.18772626,1.7630694,,,,,,,,,,,,,,,,, -31300,0.2037486,1.841847,,,,,,,,,,,,,,,,, -31400,0.18846208,1.8220425,,,,,,,,,,,,,,,,, -31500,0.23933914,1.8266578,,,,,,,,,,,,,,,,, -31600,0.18567693,1.7811148,,,,,,,,,,,,,,,,, -31700,0.19718492,1.850464,,,,,,,,,,,,,,,,, -31800,0.17356336,1.7257596,,,,,,,,,,,,,,,,, -31900,0.18331403,1.8369262,,,,,,,,,,,,,,,,, -32000,0.16876006,1.8134155,,,,,,,,,,,,,,,,, -32100,0.18994936,1.8336103,,,,,,,,,,,,,,,,, -32200,0.2171381,1.8241687,,,,,,,,,,,,,,,,, -32300,0.19677934,1.858593,,,,,,,,,,,,,,,,, -32400,0.20831242,1.8633281,,,,,,,,,,,,,,,,, -32500,0.20788506,1.7877246,,,,,,,,,,,,,,,,, -32600,0.23449865,1.8063728,,,,,,,,,,,,,,,,, -32700,0.2062023,1.8080976,,,,,,,,,,,,,,,,, -32800,0.20887923,1.8567203,,,,,,,,,,,,,,,,, -32900,0.18995142,1.8011243,,,,,,,,,,,,,,,,, -32903,,,0.6409906148910522,1.71613872051239,31.478213278866903,0.6610209345817566,1.5890172719955444,28.024704001168377,3000.0,0.6717913150787354,1.5121753215789795,27.377931603736805,3003.0,11791.432320356367,19927.55940771103,11791.432320356367,8134.691674232483,0.390512466430664,0.0 -33000,0.26450065,1.77953,,,,,,,,,,,,,,,,, -33100,0.1991913,1.8194956,,,,,,,,,,,,,,,,, -33200,0.19151883,1.848165,,,,,,,,,,,,,,,,, -33300,0.19291095,1.7778226,,,,,,,,,,,,,,,,, -33400,0.1894757,1.8440696,,,,,,,,,,,,,,,,, -33500,0.20899726,1.7704598,,,,,,,,,,,,,,,,, -33600,0.19672126,1.8533348,,,,,,,,,,,,,,,,, -33700,0.22295837,1.8684086,,,,,,,,,,,,,,,,, -33800,0.21040025,1.7621673,,,,,,,,,,,,,,,,, -33900,0.22861601,1.7322098,,,,,,,,,,,,,,,,, -34000,0.2207887,1.7766825,,,,,,,,,,,,,,,,, -34100,0.19523518,1.7951983,,,,,,,,,,,,,,,,, -34200,0.17874007,1.8058277,,,,,,,,,,,,,,,,, -34300,0.18386044,1.7599181,,,,,,,,,,,,,,,,, -34400,0.21245152,1.7578945,,,,,,,,,,,,,,,,, -34500,0.19356124,1.7150639,,,,,,,,,,,,,,,,, -34600,0.22657827,1.7929258,,,,,,,,,,,,,,,,, -34700,0.19990484,1.8583272,,,,,,,,,,,,,,,,, -34800,0.24039209,1.7402999,,,,,,,,,,,,,,,,, -34900,0.17409676,1.8165848,,,,,,,,,,,,,,,,, -35000,0.23494525,1.7580241,,,,,,,,,,,,,,,,, -35100,0.2319709,1.8028089,,,,,,,,,,,,,,,,, -35200,0.20146815,1.7196336,,,,,,,,,,,,,,,,, -35253,,,0.6431419849395752,1.7123115062713623,31.608488104923683,0.6615169048309326,1.5796931982040403,28.002704988103364,3000.0,0.6747545599937439,1.4998594522476196,28.059762524152184,3003.0,12631.56686782837,21323.122226953503,12631.56686782837,8690.012436389923,0.4196438789367676,0.0 -35300,0.18766397,1.7979052,,,,,,,,,,,,,,,,, -35400,0.23890898,1.7664082,,,,,,,,,,,,,,,,, -35500,0.25927925,1.8091729,,,,,,,,,,,,,,,,, -35600,0.21030407,1.7677796,,,,,,,,,,,,,,,,, -35700,0.24618858,1.8297122,,,,,,,,,,,,,,,,, -35800,0.21462937,1.8042802,,,,,,,,,,,,,,,,, -35900,0.19957343,1.8214538,,,,,,,,,,,,,,,,, -36000,0.18457283,1.7409732,,,,,,,,,,,,,,,,, -36100,0.21710709,1.7137249,,,,,,,,,,,,,,,,, -36200,0.23549668,1.8286533,,,,,,,,,,,,,,,,, -36300,0.19311179,1.716496,,,,,,,,,,,,,,,,, -36400,0.21068694,1.8497515,,,,,,,,,,,,,,,,, -36500,0.19284049,1.7868639,,,,,,,,,,,,,,,,, -36600,0.20272194,1.7301089,,,,,,,,,,,,,,,,, -36700,0.20258275,1.6955053,,,,,,,,,,,,,,,,, -36800,0.25023553,1.7434391,,,,,,,,,,,,,,,,, -36900,0.22260684,1.853219,,,,,,,,,,,,,,,,, -37000,0.23625766,1.7979488,,,,,,,,,,,,,,,,, -37100,0.18205944,1.7377461,,,,,,,,,,,,,,,,, -37200,0.19147241,1.7905331,,,,,,,,,,,,,,,,, -37300,0.19579753,1.8334483,,,,,,,,,,,,,,,,, -37400,0.33104742,1.9147072,,,,,,,,,,,,,,,,, -37500,0.21894406,1.8156663,,,,,,,,,,,,,,,,, -37600,0.20400418,1.8779739,,,,,,,,,,,,,,,,, -37603,,,0.6761975288391113,1.4833768606185913,34.013655184488385,0.6622856259346008,1.571945309638977,28.19304704726278,3000.0,0.6743013262748718,1.49647319316864,27.469852057344102,3003.0,13471.785254716871,22719.0359723568,13471.785254716871,9245.60012793541,0.45088791847229,0.0 -37700,0.2216816,1.8619276,,,,,,,,,,,,,,,,, -37800,0.18570271,1.7066227,,,,,,,,,,,,,,,,, -37900,0.20606536,1.8023357,,,,,,,,,,,,,,,,, -38000,0.23495671,1.7923746,,,,,,,,,,,,,,,,, -38100,0.19356345,1.8161187,,,,,,,,,,,,,,,,, -38200,0.20126194,1.7577906,,,,,,,,,,,,,,,,, -38300,0.20451829,1.770516,,,,,,,,,,,,,,,,, -38400,0.20279938,1.8025098,,,,,,,,,,,,,,,,, -38500,0.18873005,1.8083181,,,,,,,,,,,,,,,,, -38600,0.19031018,1.7665957,,,,,,,,,,,,,,,,, -38700,0.21394233,1.7747017,,,,,,,,,,,,,,,,, -38800,0.20178477,1.7681803,,,,,,,,,,,,,,,,, -38900,0.20569782,1.7750769,,,,,,,,,,,,,,,,, -39000,0.2045538,1.8095851,,,,,,,,,,,,,,,,, -39100,0.18768723,1.7912352,,,,,,,,,,,,,,,,, -39200,0.1719792,1.6775472,,,,,,,,,,,,,,,,, -39300,0.19822235,1.8213375,,,,,,,,,,,,,,,,, -39400,0.23799114,1.7280266,,,,,,,,,,,,,,,,, -39500,0.18425733,1.802267,,,,,,,,,,,,,,,,, -39600,0.20618296,1.7023194,,,,,,,,,,,,,,,,, -39700,0.20366645,1.7777524,,,,,,,,,,,,,,,,, -39800,0.21230024,1.7586206,,,,,,,,,,,,,,,,, -39900,0.20840032,1.7504158,,,,,,,,,,,,,,,,, -39953,,,0.6511838436126709,1.662556767463684,31.470296217893207,0.665162205696106,1.560990333557129,28.531168082604847,3000.0,0.6754401326179504,1.4862949848175049,28.240845223490155,3003.0,14311.73891043663,24081.21613359452,14311.73891043663,9767.721544027328,0.4817836284637451,0.0 -40000,0.20863159,1.7808398,,,,,,,,,,,,,,,,, -40100,0.23186181,1.7490467,,,,,,,,,,,,,,,,, -40200,0.20942324,1.8253037,,,,,,,,,,,,,,,,, -40300,0.19778392,1.7443717,,,,,,,,,,,,,,,,, -40400,0.1982447,1.8069429,,,,,,,,,,,,,,,,, -40500,0.19980885,1.8016887,,,,,,,,,,,,,,,,, -40600,0.20110907,1.6496595,,,,,,,,,,,,,,,,, -40700,0.20428412,1.7832952,,,,,,,,,,,,,,,,, -40800,0.20179982,1.7921094,,,,,,,,,,,,,,,,, -40900,0.22801977,1.8589809,,,,,,,,,,,,,,,,, -41000,0.19827512,1.8087399,,,,,,,,,,,,,,,,, -41100,0.19159152,1.8089083,,,,,,,,,,,,,,,,, -41200,0.23671988,1.7137003,,,,,,,,,,,,,,,,, -41300,0.19425973,1.8030826,,,,,,,,,,,,,,,,, -41400,0.20574827,1.7740366,,,,,,,,,,,,,,,,, -41500,0.24575986,1.7425535,,,,,,,,,,,,,,,,, -41600,0.22892909,1.8013477,,,,,,,,,,,,,,,,, -41700,0.24175547,1.7809447,,,,,,,,,,,,,,,,, -41800,0.20545794,1.8325387,,,,,,,,,,,,,,,,, -41900,0.19074881,1.7237263,,,,,,,,,,,,,,,,, -42000,0.2217061,1.83448,,,,,,,,,,,,,,,,, -42100,0.20250066,1.7828422,,,,,,,,,,,,,,,,, -42200,0.27793476,1.7532446,,,,,,,,,,,,,,,,, -42300,0.1961658,1.7202061,,,,,,,,,,,,,,,,, -42301,,,0.6459149718284607,1.6967586278915403,31.73541057923894,0.6646910905838013,1.5630953311920166,28.590118849023177,3000.0,0.6777874827384949,1.4830689430236816,28.08367316038388,3003.0,15151.733073234558,25434.131522655487,15151.733073234558,10280.534869909286,0.5133066177368164,0.0 -42400,0.18144068,1.6976038,,,,,,,,,,,,,,,,, -42500,0.19310197,1.756428,,,,,,,,,,,,,,,,, -42600,0.18877487,1.7050108,,,,,,,,,,,,,,,,, -42700,0.19232108,1.7753164,,,,,,,,,,,,,,,,, -42800,0.21307814,1.6914907,,,,,,,,,,,,,,,,, -42900,0.22246766,1.6774658,,,,,,,,,,,,,,,,, -43000,0.20518537,1.7719005,,,,,,,,,,,,,,,,, -43100,0.8726913,1.7887839,,,,,,,,,,,,,,,,, -43200,0.21051607,1.7913411,,,,,,,,,,,,,,,,, -43300,0.19264095,1.709828,,,,,,,,,,,,,,,,, -43400,0.20604254,1.8184904,,,,,,,,,,,,,,,,, -43500,0.1936654,1.7374482,,,,,,,,,,,,,,,,, -43600,0.18363999,1.7352095,,,,,,,,,,,,,,,,, -43700,0.22516286,1.8084506,,,,,,,,,,,,,,,,, -43800,0.17765358,1.7983559,,,,,,,,,,,,,,,,, -43900,0.1964397,1.7652334,,,,,,,,,,,,,,,,, -44000,0.19250281,1.7701133,,,,,,,,,,,,,,,,, -44100,0.18791224,1.720959,,,,,,,,,,,,,,,,, -44200,0.18962578,1.7802321,,,,,,,,,,,,,,,,, -44300,0.25956273,1.74413,,,,,,,,,,,,,,,,, -44400,0.20536868,1.7105985,,,,,,,,,,,,,,,,, -44500,0.18418565,1.7953624,,,,,,,,,,,,,,,,, -44600,0.18698926,1.671413,,,,,,,,,,,,,,,,, -44650,,,0.6544402837753296,1.6286474466323853,32.2549238963339,0.666154146194458,1.5522340536117554,28.460654397908776,3000.0,0.6779733896255493,1.475687861442566,27.80825715299147,3003.0,15991.678433418274,26793.424834012985,15991.678433418274,10799.775356054306,0.5443167686462402,0.0 -44700,0.19864398,1.7627088,,,,,,,,,,,,,,,,, -44800,0.20389648,1.7741741,,,,,,,,,,,,,,,,, -44900,0.2030673,1.8186536,,,,,,,,,,,,,,,,, -45000,0.20108862,1.7338046,,,,,,,,,,,,,,,,, -45100,0.20296656,1.7663269,,,,,,,,,,,,,,,,, -45200,0.25487348,1.7600151,,,,,,,,,,,,,,,,, -45300,0.21257225,1.8469049,,,,,,,,,,,,,,,,, -45400,0.18005012,1.7787526,,,,,,,,,,,,,,,,, -45500,0.21654728,1.7689991,,,,,,,,,,,,,,,,, -45600,0.21374395,1.763683,,,,,,,,,,,,,,,,, -45700,0.19890945,1.7167064,,,,,,,,,,,,,,,,, -45800,0.18844038,1.7005776,,,,,,,,,,,,,,,,, -45900,0.20458557,1.765023,,,,,,,,,,,,,,,,, -46000,0.19444305,1.7323076,,,,,,,,,,,,,,,,, -46100,0.21590738,1.8015455,,,,,,,,,,,,,,,,, -46200,0.21777055,1.7886132,,,,,,,,,,,,,,,,, -46300,0.19922753,1.7349768,,,,,,,,,,,,,,,,, -46400,0.18710631,1.6976619,,,,,,,,,,,,,,,,, -46500,0.21079756,1.7379324,,,,,,,,,,,,,,,,, -46600,0.1998883,1.6793915,,,,,,,,,,,,,,,,, -46700,0.21465512,1.8790562,,,,,,,,,,,,,,,,, -46800,0.18803993,1.6796595,,,,,,,,,,,,,,,,, -46900,0.20961313,1.7166542,,,,,,,,,,,,,,,,, -46999,,,0.6529893279075623,1.6469743251800537,31.74572989276444,0.6668609380722046,1.5455957651138306,28.47980213878581,3000.0,0.6801813244819641,1.458451747894287,28.38958892681242,3003.0,16831.602340459824,28266.15973854065,16831.602340459824,11432.476187705994,0.577672004699707,0.0 -47000,0.17635703,1.7638527,,,,,,,,,,,,,,,,, -47100,0.18175796,1.714548,,,,,,,,,,,,,,,,, -47200,0.17915507,1.6710528,,,,,,,,,,,,,,,,, -47300,0.18925598,1.6921628,,,,,,,,,,,,,,,,, -47400,0.2179508,1.7238274,,,,,,,,,,,,,,,,, -47500,0.19645065,1.7388343,,,,,,,,,,,,,,,,, -47600,0.20767672,1.743791,,,,,,,,,,,,,,,,, -47700,0.19877435,1.7829031,,,,,,,,,,,,,,,,, -47800,0.19411826,1.6906339,,,,,,,,,,,,,,,,, -47900,0.20333159,1.7432317,,,,,,,,,,,,,,,,, -48000,0.22155459,1.6878157,,,,,,,,,,,,,,,,, -48100,0.23885119,1.776148,,,,,,,,,,,,,,,,, -48200,0.21038865,1.7443333,,,,,,,,,,,,,,,,, -48300,0.22141702,1.7872263,,,,,,,,,,,,,,,,, -48400,0.17749606,1.7339953,,,,,,,,,,,,,,,,, -48500,0.19022763,1.7048069,,,,,,,,,,,,,,,,, -48600,0.20868614,1.805335,,,,,,,,,,,,,,,,, -48700,0.22417817,1.7370726,,,,,,,,,,,,,,,,, -48800,0.19189633,1.6992335,,,,,,,,,,,,,,,,, -48900,0.20263411,1.8390809,,,,,,,,,,,,,,,,, -49000,0.22183439,1.7481439,,,,,,,,,,,,,,,,, -49100,0.20874438,1.7541234,,,,,,,,,,,,,,,,, -49200,0.20762849,1.7910988,,,,,,,,,,,,,,,,, -49300,0.23242772,1.778243,,,,,,,,,,,,,,,,, -49349,,,0.64932781457901,1.6753994226455688,31.92990226128372,0.6689067482948303,1.5383224487304688,28.933472139073533,3000.0,0.6800883412361145,1.457780838012695,28.21943189274275,3003.0,17671.68727684021,29609.51440000534,17671.68727684021,11935.63941526413,0.6097137928009033,0.0 -49400,0.1808578,1.7167672,,,,,,,,,,,,,,,,, -49500,0.20044328,1.8031094,,,,,,,,,,,,,,,,, -49600,0.25763407,1.7675295,,,,,,,,,,,,,,,,, -49700,0.19542347,1.7093613,,,,,,,,,,,,,,,,, -49800,0.20431681,1.7323741,,,,,,,,,,,,,,,,, -49900,0.20768078,1.8137587,,,,,,,,,,,,,,,,, -50000,0.20189488,1.7446437,,,,,,,,,,,,,,,,, -50100,0.19779523,1.7720343,,,,,,,,,,,,,,,,, -50200,0.26477152,1.8204975,,,,,,,,,,,,,,,,, -50300,0.198802,1.6662219,,,,,,,,,,,,,,,,, -50400,0.20249666,1.7180792,,,,,,,,,,,,,,,,, -50500,0.17829041,1.7586099,,,,,,,,,,,,,,,,, -50600,0.28535977,1.7458525,,,,,,,,,,,,,,,,, -50700,0.20334135,1.8373582,,,,,,,,,,,,,,,,, -50800,0.19881147,1.6686323,,,,,,,,,,,,,,,,, -50900,0.21722578,1.7551858,,,,,,,,,,,,,,,,, -51000,0.22301248,1.7246662,,,,,,,,,,,,,,,,, -51100,0.1915772,1.6886483,,,,,,,,,,,,,,,,, -51200,0.19796456,1.7928618,,,,,,,,,,,,,,,,, -51300,0.22861764,1.7698276,,,,,,,,,,,,,,,,, -51400,0.223313,1.8060462,,,,,,,,,,,,,,,,, -51500,0.19910757,1.6348683,,,,,,,,,,,,,,,,, -51600,0.20652886,1.7774221,,,,,,,,,,,,,,,,, -51699,,,0.6542662978172302,1.6272282600402832,32.459203013070265,0.6693159341812134,1.5343632698059082,28.620420565903288,3000.0,0.6813085079193115,1.4497711658477783,28.43325483477639,3003.0,18511.66797375679,31134.161219596863,18511.66797375679,12620.1957821846,0.643932580947876,0.0 -51700,0.21419224,1.7222886,,,,,,,,,,,,,,,,, -51800,0.23320429,1.8013301,,,,,,,,,,,,,,,,, -51900,0.20587176,1.7784268,,,,,,,,,,,,,,,,, -52000,0.21089186,1.7624213,,,,,,,,,,,,,,,,, -52100,0.1942245,1.7742907,,,,,,,,,,,,,,,,, -52200,0.19871736,1.7012504,,,,,,,,,,,,,,,,, -52300,0.23053104,1.7983172,,,,,,,,,,,,,,,,, -52400,0.18785027,1.7483177,,,,,,,,,,,,,,,,, -52500,0.19635198,1.713598,,,,,,,,,,,,,,,,, -52600,0.28461242,1.7018938,,,,,,,,,,,,,,,,, -52700,0.20312676,1.7655853,,,,,,,,,,,,,,,,, -52800,0.23834658,1.7465782,,,,,,,,,,,,,,,,, -52900,0.18367931,1.7030522,,,,,,,,,,,,,,,,, -53000,0.20707859,1.7752802,,,,,,,,,,,,,,,,, -53100,0.25568,1.7401898,,,,,,,,,,,,,,,,, -53200,0.23904218,1.707715,,,,,,,,,,,,,,,,, -53300,0.20408367,1.6833788,,,,,,,,,,,,,,,,, -53400,0.29614872,1.694751,,,,,,,,,,,,,,,,, -53500,0.19586287,1.7297373,,,,,,,,,,,,,,,,, -53600,0.20891817,1.7353779,,,,,,,,,,,,,,,,, -53700,0.22635975,1.6552842,,,,,,,,,,,,,,,,, -53800,0.20572504,1.7059829,,,,,,,,,,,,,,,,, -53900,0.20280878,1.7191838,,,,,,,,,,,,,,,,, -54000,0.22040959,1.7447697,,,,,,,,,,,,,,,,, -54048,,,0.6529883146286011,1.654189944267273,32.510313488174106,0.6688695549964905,1.5301557779312134,28.57319480788474,3000.0,0.68177330493927,1.448590636253357,28.328816470806625,3003.0,19351.66876530648,32636.85775065422,19351.66876530648,13282.779415607452,0.6774072647094727,0.0 -54100,0.19426224,1.6688472,,,,,,,,,,,,,,,,, -54200,0.22753556,1.811235,,,,,,,,,,,,,,,,, -54300,0.20325823,1.6553801,,,,,,,,,,,,,,,,, -54400,0.1880482,1.6817002,,,,,,,,,,,,,,,,, -54500,0.2045466,1.6719428,,,,,,,,,,,,,,,,, -54600,0.21403123,1.7566386,,,,,,,,,,,,,,,,, -54700,0.18915267,1.6769478,,,,,,,,,,,,,,,,, -54800,0.18862961,1.7028598,,,,,,,,,,,,,,,,, -54900,0.19512752,1.705458,,,,,,,,,,,,,,,,, -55000,0.18079965,1.7077304,,,,,,,,,,,,,,,,, -55100,0.22148201,1.698481,,,,,,,,,,,,,,,,, -55200,0.25315738,1.6632328,,,,,,,,,,,,,,,,, -55300,0.20942107,1.7844532,,,,,,,,,,,,,,,,, -55400,0.18722974,1.7536095,,,,,,,,,,,,,,,,, -55500,0.18939345,1.7449157,,,,,,,,,,,,,,,,, -55600,0.4254064,1.8338213,,,,,,,,,,,,,,,,, -55700,0.19041593,1.7631952,,,,,,,,,,,,,,,,, -55800,0.22672379,1.7488917,,,,,,,,,,,,,,,,, -55900,0.22290948,1.6651618,,,,,,,,,,,,,,,,, -56000,0.21785116,1.6740876,,,,,,,,,,,,,,,,, -56100,0.19318902,1.6497388,,,,,,,,,,,,,,,,, -56200,0.19538744,1.7374402,,,,,,,,,,,,,,,,, -56300,0.22052015,1.711163,,,,,,,,,,,,,,,,, -56397,,,0.68346107006073,1.446035623550415,34.018130430164184,0.6704070568084717,1.5175199508666992,29.086346659127816,3000.0,0.682877242565155,1.4382997751235962,28.169233789147054,3003.0,20191.756650686264,34078.80337572098,20191.756650686264,13884.517220497131,0.7161824703216553,0.0 -56400,0.19754532,1.7484508,,,,,,,,,,,,,,,,, -56500,0.2040485,1.6796635,,,,,,,,,,,,,,,,, -56600,0.21021189,1.7974803,,,,,,,,,,,,,,,,, -56700,0.19133541,1.6406318,,,,,,,,,,,,,,,,, -56800,0.1965658,1.786745,,,,,,,,,,,,,,,,, -56900,0.19784826,1.6948116,,,,,,,,,,,,,,,,, -57000,0.19562483,1.7364014,,,,,,,,,,,,,,,,, -57100,0.19936645,1.7029746,,,,,,,,,,,,,,,,, -57200,0.19034901,1.738897,,,,,,,,,,,,,,,,, -57300,0.20488445,1.8183501,,,,,,,,,,,,,,,,, -57400,0.27316016,1.7184998,,,,,,,,,,,,,,,,, -57500,0.20763434,1.6644279,,,,,,,,,,,,,,,,, -57600,0.18529242,1.6520861,,,,,,,,,,,,,,,,, -57700,0.21278018,1.7228434,,,,,,,,,,,,,,,,, -57800,0.20083618,1.6941086,,,,,,,,,,,,,,,,, -57900,0.20790634,1.6776619,,,,,,,,,,,,,,,,, -58000,0.18246,1.6622572,,,,,,,,,,,,,,,,, -58100,0.18831249,1.6279937,,,,,,,,,,,,,,,,, -58200,0.19081204,1.6574833,,,,,,,,,,,,,,,,, -58300,0.20026381,1.702662,,,,,,,,,,,,,,,,, -58400,0.2163773,1.6702095,,,,,,,,,,,,,,,,, -58500,0.22050443,1.6603805,,,,,,,,,,,,,,,,, -58600,0.18915045,1.715485,,,,,,,,,,,,,,,,, -58700,0.20234102,1.7306402,,,,,,,,,,,,,,,,, -58747,,,0.6575668454170227,1.6168078184127808,32.5052435486537,0.6729984879493713,1.5107536315917969,29.125088035511126,3000.0,0.6855267286300659,1.4299596548080444,28.621032891839445,3003.0,21031.882925748825,35508.15443897247,21031.882925748825,14473.63262438774,0.7506313323974609,0.0 -58800,0.20782545,1.7422942,,,,,,,,,,,,,,,,, -58900,0.21974629,1.7380426,,,,,,,,,,,,,,,,, -59000,0.19497544,1.6986475,,,,,,,,,,,,,,,,, -59100,0.1898852,1.7222537,,,,,,,,,,,,,,,,, -59200,0.20491058,1.7096994,,,,,,,,,,,,,,,,, -59300,0.22437878,1.7236034,,,,,,,,,,,,,,,,, -59400,0.20774545,1.7353278,,,,,,,,,,,,,,,,, -59500,0.22055072,1.7019852,,,,,,,,,,,,,,,,, -59600,0.20479691,1.6891334,,,,,,,,,,,,,,,,, -59700,0.19681701,1.6579221,,,,,,,,,,,,,,,,, -59800,0.19090201,1.6979129,,,,,,,,,,,,,,,,, -59900,0.29633966,1.5890337,,,,,,,,,,,,,,,,, -60000,0.17927691,1.6728923,,,,,,,,,,,,,,,,, -60100,0.18615998,1.7185162,,,,,,,,,,,,,,,,, -60200,0.1938815,1.7447677,,,,,,,,,,,,,,,,, -60300,0.22705752,1.7182622,,,,,,,,,,,,,,,,, -60400,0.2263096,1.6703206,,,,,,,,,,,,,,,,, -60500,0.2311336,1.6458422,,,,,,,,,,,,,,,,, -60600,0.19323297,1.7229164,,,,,,,,,,,,,,,,, -60700,0.20895809,1.6539274,,,,,,,,,,,,,,,,, -60800,0.2062202,1.7613451,,,,,,,,,,,,,,,,, -60900,0.2116878,1.8127556,,,,,,,,,,,,,,,,, -61000,0.2548554,1.7386643,,,,,,,,,,,,,,,,, -61096,,,0.6561620235443115,1.6366279125213623,32.63350373980873,0.6729736924171448,1.512277603149414,29.140785546387807,3000.0,0.6879902482032776,1.4211199283599854,29.14197211825932,3003.0,21871.777238607407,36873.4613969326,21871.777238607407,14998.93830871582,0.7828867435455322,0.0 -61100,0.23802918,1.6796947,,,,,,,,,,,,,,,,, -61200,0.1880577,1.687149,,,,,,,,,,,,,,,,, -61300,0.19649099,1.6848197,,,,,,,,,,,,,,,,, -61400,0.19632348,1.6946334,,,,,,,,,,,,,,,,, -61500,0.21276437,1.6783345,,,,,,,,,,,,,,,,, -61600,0.20013659,1.6272321,,,,,,,,,,,,,,,,, -61700,0.18678403,1.6763314,,,,,,,,,,,,,,,,, -61800,0.22064643,1.7331431,,,,,,,,,,,,,,,,, -61900,0.18389612,1.6362174,,,,,,,,,,,,,,,,, -62000,0.20141554,1.8105155,,,,,,,,,,,,,,,,, -62100,0.19629149,1.617051,,,,,,,,,,,,,,,,, -62200,0.19758064,1.6883487,,,,,,,,,,,,,,,,, -62300,0.21501161,1.6842582,,,,,,,,,,,,,,,,, -62400,0.20250176,1.6777074,,,,,,,,,,,,,,,,, -62500,0.20995693,1.728643,,,,,,,,,,,,,,,,, -62600,0.2089634,1.6934618,,,,,,,,,,,,,,,,, -62700,0.19902547,1.7268555,,,,,,,,,,,,,,,,, -62800,0.19504818,1.6488131,,,,,,,,,,,,,,,,, -62900,0.19417363,1.6757672,,,,,,,,,,,,,,,,, -63000,0.24333425,1.727981,,,,,,,,,,,,,,,,, -63100,0.19549073,1.6345593,,,,,,,,,,,,,,,,, -63200,0.18754272,1.5975558,,,,,,,,,,,,,,,,, -63300,0.17998977,1.6727784,,,,,,,,,,,,,,,,, -63400,0.18954788,1.7262852,,,,,,,,,,,,,,,,, -63444,,,0.6674286723136902,1.547412633895874,33.49294815381641,0.6748335361480713,1.4961423873901367,29.477780759940647,3000.0,0.6889896392822266,1.410127878189087,28.9947816834239,3003.0,22711.711676359177,38273.14426493645,22711.711676359177,15558.572283506392,0.8171112537384033,0.0 -63500,0.43373254,1.7421684,,,,,,,,,,,,,,,,, -63600,0.20819482,1.7455782,,,,,,,,,,,,,,,,, -63700,0.20497242,1.7233539,,,,,,,,,,,,,,,,, -63800,0.19288221,1.729829,,,,,,,,,,,,,,,,, -63900,0.1903781,1.6963494,,,,,,,,,,,,,,,,, -64000,0.19215752,1.706357,,,,,,,,,,,,,,,,, -64100,0.22035874,1.5769565,,,,,,,,,,,,,,,,, -64200,0.19089898,1.7821598,,,,,,,,,,,,,,,,, -64300,0.19995606,1.6932847,,,,,,,,,,,,,,,,, -64400,0.20946372,1.8169048,,,,,,,,,,,,,,,,, -64500,0.2059057,1.7387037,,,,,,,,,,,,,,,,, -64600,0.19810802,1.7168237,,,,,,,,,,,,,,,,, -64700,0.19443312,1.6121471,,,,,,,,,,,,,,,,, -64800,0.18975976,1.7204715,,,,,,,,,,,,,,,,, -64900,0.19521956,1.7222658,,,,,,,,,,,,,,,,, -65000,0.19735746,1.7067869,,,,,,,,,,,,,,,,, -65100,0.19277935,1.6568033,,,,,,,,,,,,,,,,, -65200,0.2322108,1.7081605,,,,,,,,,,,,,,,,, -65300,0.21478291,1.5793221,,,,,,,,,,,,,,,,, -65400,0.20555183,1.7757591,,,,,,,,,,,,,,,,, -65500,0.19105692,1.7188793,,,,,,,,,,,,,,,,, -65600,0.22800116,1.7274234,,,,,,,,,,,,,,,,, -65700,0.20967497,1.6081374,,,,,,,,,,,,,,,,, -65794,,,0.6565250754356384,1.6203813552856443,32.66615467451501,0.676073431968689,1.4927388429641724,29.329756231686886,3000.0,0.6903724670410156,1.4074007272720337,29.56408904455485,3003.0,23551.849817037582,39674.71835613251,23551.849817037582,16119.899282455444,0.8502511978149414,0.0 -65800,0.19220886,1.6557156,,,,,,,,,,,,,,,,, -65900,0.19191633,1.6919765,,,,,,,,,,,,,,,,, -66000,0.19443294,1.6218818,,,,,,,,,,,,,,,,, -66100,0.20310953,1.6579713,,,,,,,,,,,,,,,,, -66200,0.18840303,1.7178856,,,,,,,,,,,,,,,,, -66300,0.20007205,1.6474222,,,,,,,,,,,,,,,,, -66400,0.19715105,1.6361704,,,,,,,,,,,,,,,,, -66500,0.20257106,1.699554,,,,,,,,,,,,,,,,, -66600,0.18943703,1.7364215,,,,,,,,,,,,,,,,, -66700,0.25425947,1.6845727,,,,,,,,,,,,,,,,, -66800,0.20562817,1.747021,,,,,,,,,,,,,,,,, -66900,0.18951726,1.6027572,,,,,,,,,,,,,,,,, -67000,0.22471525,1.7268802,,,,,,,,,,,,,,,,, -67100,0.1912148,1.7153312,,,,,,,,,,,,,,,,, -67200,0.19908008,1.6443367,,,,,,,,,,,,,,,,, -67300,0.18100457,1.6273115,,,,,,,,,,,,,,,,, -67400,0.22082284,1.6746533,,,,,,,,,,,,,,,,, -67500,0.19636978,1.6473513,,,,,,,,,,,,,,,,, -67600,0.19864376,1.6486745,,,,,,,,,,,,,,,,, -67700,0.19535021,1.6393267,,,,,,,,,,,,,,,,, -67800,0.21022938,1.6765434,,,,,,,,,,,,,,,,, -67900,0.20921056,1.7457643,,,,,,,,,,,,,,,,, -68000,0.22000767,1.5847561,,,,,,,,,,,,,,,,, -68100,0.25087202,1.7127848,,,,,,,,,,,,,,,,, -68144,,,0.6547269821166992,1.634839653968811,32.47219353442573,0.677561342716217,1.4846450090408323,29.4492775140218,3000.0,0.6897798180580139,1.4021321535110474,29.068497099009715,3003.0,24391.99292993545,41078.09699392319,24391.99292993545,16683.019134521484,0.8892166614532471,0.0 -68200,0.19877698,1.8113999,,,,,,,,,,,,,,,,, -68300,0.20284544,1.664434,,,,,,,,,,,,,,,,, -68400,0.18301646,1.6205943,,,,,,,,,,,,,,,,, -68500,0.23153992,1.6614534,,,,,,,,,,,,,,,,, -68600,0.18592723,1.6774747,,,,,,,,,,,,,,,,, -68700,0.19385815,1.6755635,,,,,,,,,,,,,,,,, -68800,0.21024173,1.7588247,,,,,,,,,,,,,,,,, -68900,0.18594882,1.6488173,,,,,,,,,,,,,,,,, -69000,0.21937212,1.6063838,,,,,,,,,,,,,,,,, -69100,0.19742638,1.7324722,,,,,,,,,,,,,,,,, -69200,0.20969634,1.658044,,,,,,,,,,,,,,,,, -69300,0.19592531,1.6366316,,,,,,,,,,,,,,,,, -69400,0.19987236,1.6790276,,,,,,,,,,,,,,,,, -69500,0.20668282,1.6967797,,,,,,,,,,,,,,,,, -69600,0.19080621,1.6046321,,,,,,,,,,,,,,,,, -69700,0.20396048,1.7642599,,,,,,,,,,,,,,,,, -69800,0.19156799,1.6210485,,,,,,,,,,,,,,,,, -69900,0.2180502,1.6911217,,,,,,,,,,,,,,,,, -70000,0.18809332,1.5810438,,,,,,,,,,,,,,,,, -70100,0.18547994,1.6346098,,,,,,,,,,,,,,,,, -70200,0.21236302,1.6459424,,,,,,,,,,,,,,,,, -70300,0.19537845,1.6287936,,,,,,,,,,,,,,,,, -70400,0.22674197,1.6086295,,,,,,,,,,,,,,,,, -70494,,,0.6648585796356201,1.5671465396881104,32.83949660411032,0.6763090491294861,1.4787070751190186,29.078669687795948,3000.0,0.6899541020393372,1.3922793865203855,29.03229432623568,3003.0,25232.03781723976,42432.101380348206,25232.03781723976,17196.86828827858,0.9239933490753174,0.0 -70500,0.32523525,1.735992,,,,,,,,,,,,,,,,, -70600,0.20284386,1.6029232,,,,,,,,,,,,,,,,, -70700,0.18852618,1.6308316,,,,,,,,,,,,,,,,, -70800,0.21110472,1.6676539,,,,,,,,,,,,,,,,, -70900,0.1931609,1.6556647,,,,,,,,,,,,,,,,, -71000,0.20227808,1.6385583,,,,,,,,,,,,,,,,, -71100,0.22223863,1.6444552,,,,,,,,,,,,,,,,, -71200,0.20870408,1.6733663,,,,,,,,,,,,,,,,, -71300,0.21055149,1.6169614,,,,,,,,,,,,,,,,, -71400,0.205785,1.674353,,,,,,,,,,,,,,,,, -71500,0.20194493,1.6892307,,,,,,,,,,,,,,,,, -71600,0.19548652,1.653259,,,,,,,,,,,,,,,,, -71700,0.25116545,1.614791,,,,,,,,,,,,,,,,, -71800,0.20136423,1.649167,,,,,,,,,,,,,,,,, -71900,0.20587139,1.5388337,,,,,,,,,,,,,,,,, -72000,0.19693564,1.6951864,,,,,,,,,,,,,,,,, -72100,0.18948478,1.6437312,,,,,,,,,,,,,,,,, -72200,0.20716122,1.6828839,,,,,,,,,,,,,,,,, -72300,0.20288666,1.5997,,,,,,,,,,,,,,,,, -72400,0.20806006,1.6293696,,,,,,,,,,,,,,,,, -72500,0.19731298,1.6310815,,,,,,,,,,,,,,,,, -72600,0.224606,1.5803757,,,,,,,,,,,,,,,,, -72700,0.20115459,1.6522806,,,,,,,,,,,,,,,,, -72800,0.49471918,1.6789938,,,,,,,,,,,,,,,,, -72843,,,0.6621139049530029,1.5815826654434204,33.390701136287404,0.6767553687095642,1.473919153213501,29.35268774776388,3000.0,0.6910812854766846,1.3824280500411987,29.18977920484513,3003.0,26072.20747256279,43788.344264507294,26072.20747256279,17712.821888685226,0.9654061794281006,0.0 -72900,0.2140258,1.6641834,,,,,,,,,,,,,,,,, -73000,0.20630683,1.6753082,,,,,,,,,,,,,,,,, -73100,0.21553485,1.6227139,,,,,,,,,,,,,,,,, -73200,0.1891639,1.6542661,,,,,,,,,,,,,,,,, -73300,0.22633547,1.6361248,,,,,,,,,,,,,,,,, -73400,0.19896288,1.7039633,,,,,,,,,,,,,,,,, -73500,0.18608351,1.5984373,,,,,,,,,,,,,,,,, -73600,0.1936739,1.5885906,,,,,,,,,,,,,,,,, -73700,0.19373588,1.665336,,,,,,,,,,,,,,,,, -73800,0.19666003,1.6935544,,,,,,,,,,,,,,,,, -73900,0.21002436,1.66454,,,,,,,,,,,,,,,,, -74000,0.22526611,1.6912317,,,,,,,,,,,,,,,,, -74100,0.19703975,1.6452417,,,,,,,,,,,,,,,,, -74200,0.20414698,1.6469923,,,,,,,,,,,,,,,,, -74300,0.19228834,1.6601906,,,,,,,,,,,,,,,,, -74400,0.20971955,1.6409342,,,,,,,,,,,,,,,,, -74500,0.1972287,1.6403569,,,,,,,,,,,,,,,,, -74600,0.22535509,1.5705163,,,,,,,,,,,,,,,,, -74700,0.19358999,1.649187,,,,,,,,,,,,,,,,, -74800,0.20491494,1.6954062,,,,,,,,,,,,,,,,, -74900,0.19806337,1.5794903,,,,,,,,,,,,,,,,, -75000,0.19885463,1.6791949,,,,,,,,,,,,,,,,, -75100,0.20867774,1.6218636,,,,,,,,,,,,,,,,, -75193,,,0.6858174800872803,1.426164984703064,34.366781063424106,0.6785904765129089,1.4634785652160645,29.51441599734365,3000.0,0.693544864654541,1.378287672996521,29.07100048590121,3003.0,26912.381454229355,45217.25480270386,26912.381454229355,18301.44424295425,1.0022211074829102,0.0 -75200,0.20557226,1.6050836,,,,,,,,,,,,,,,,, -75300,0.20956999,1.6381501,,,,,,,,,,,,,,,,, -75400,0.19082971,1.6003381,,,,,,,,,,,,,,,,, -75500,0.1938781,1.5855848,,,,,,,,,,,,,,,,, -75600,0.20696002,1.6639569,,,,,,,,,,,,,,,,, -75700,0.19077003,1.5829132,,,,,,,,,,,,,,,,, -75800,0.21265154,1.638692,,,,,,,,,,,,,,,,, -75900,0.19663584,1.7250979,,,,,,,,,,,,,,,,, -76000,0.21950294,1.6997772,,,,,,,,,,,,,,,,, -76100,0.18208066,1.5498008,,,,,,,,,,,,,,,,, -76200,0.19195065,1.6250662,,,,,,,,,,,,,,,,, -76300,0.19586869,1.6768397,,,,,,,,,,,,,,,,, -76400,0.19777586,1.7325244,,,,,,,,,,,,,,,,, -76500,0.20459287,1.6123763,,,,,,,,,,,,,,,,, -76600,0.2185769,1.6153029,,,,,,,,,,,,,,,,, -76700,0.20741434,1.6079003,,,,,,,,,,,,,,,,, -76800,0.19014885,1.6581675,,,,,,,,,,,,,,,,, -76900,0.20784815,1.7050835,,,,,,,,,,,,,,,,, -77000,0.19207755,1.6028422,,,,,,,,,,,,,,,,, -77100,0.19903418,1.6497163,,,,,,,,,,,,,,,,, -77200,0.19523709,1.6439146,,,,,,,,,,,,,,,,, -77300,0.21752016,1.6406991,,,,,,,,,,,,,,,,, -77400,0.20083603,1.6648582,,,,,,,,,,,,,,,,, -77500,0.1942515,1.5986799,,,,,,,,,,,,,,,,, -77543,,,0.6669697165489197,1.5403841733932495,33.21220426533364,0.6799419522285461,1.4549518823623655,29.43968950403044,3000.0,0.6951484680175781,1.3676152229309082,29.485973265345606,3003.0,27752.45276355744,46761.59459590912,27752.45276355744,19005.60043978691,1.0388991832733154,0.0 -77600,0.18731782,1.5532243,,,,,,,,,,,,,,,,, -77700,0.19334246,1.6116407,,,,,,,,,,,,,,,,, -77800,0.2073968,1.6410391,,,,,,,,,,,,,,,,, -77900,0.21992272,1.7663293,,,,,,,,,,,,,,,,, -78000,0.2033835,1.6801507,,,,,,,,,,,,,,,,, -78100,0.201531,1.6016359,,,,,,,,,,,,,,,,, -78200,0.22255032,1.6378381,,,,,,,,,,,,,,,,, -78300,0.87626916,1.6760069,,,,,,,,,,,,,,,,, -78400,0.20828237,1.646869,,,,,,,,,,,,,,,,, -78500,0.20291162,1.6783533,,,,,,,,,,,,,,,,, -78600,0.21475454,1.6913481,,,,,,,,,,,,,,,,, -78700,0.20743038,1.6474818,,,,,,,,,,,,,,,,, -78800,0.20790684,1.6251297,,,,,,,,,,,,,,,,, -78900,0.18893081,1.625088,,,,,,,,,,,,,,,,, -79000,0.19868033,1.6210717,,,,,,,,,,,,,,,,, -79100,0.20646903,1.6307243,,,,,,,,,,,,,,,,, -79200,0.19270717,1.586073,,,,,,,,,,,,,,,,, -79300,0.19293883,1.5932642,,,,,,,,,,,,,,,,, -79400,0.20246445,1.5912284,,,,,,,,,,,,,,,,, -79500,0.49846175,1.6185523,,,,,,,,,,,,,,,,, -79600,0.2823235,1.6214789,,,,,,,,,,,,,,,,, -79700,0.20405309,1.7066662,,,,,,,,,,,,,,,,, -79800,0.2146395,1.7011102,,,,,,,,,,,,,,,,, -79893,,,0.6644576787948608,1.573724389076233,32.65712602644397,0.6827317476272583,1.4516927003860474,29.785093513403424,3000.0,0.6960781216621399,1.3559688329696655,29.612354085956323,3003.0,28592.432891607285,48130.00425624848,28592.432891607285,19533.910708904263,1.0820260047912598,0.0 -79900,0.1997701,1.6198771,,,,,,,,,,,,,,,,, -80000,0.21063623,1.5799382,,,,,,,,,,,,,,,,, -80100,0.21404393,1.7172297,,,,,,,,,,,,,,,,, -80200,0.20204185,1.6061894,,,,,,,,,,,,,,,,, -80300,0.21221057,1.6871878,,,,,,,,,,,,,,,,, -80400,0.2036918,1.6226557,,,,,,,,,,,,,,,,, -80500,0.22433005,1.6456228,,,,,,,,,,,,,,,,, -80600,0.21961568,1.619938,,,,,,,,,,,,,,,,, -80700,0.20583749,1.7084341,,,,,,,,,,,,,,,,, -80800,0.20812021,1.661218,,,,,,,,,,,,,,,,, -80900,0.21495272,1.6221246,,,,,,,,,,,,,,,,, -81000,0.20263089,1.5637084,,,,,,,,,,,,,,,,, -81100,0.20121348,1.609316,,,,,,,,,,,,,,,,, -81200,0.20868298,1.6401782,,,,,,,,,,,,,,,,, -81300,0.19544068,1.583445,,,,,,,,,,,,,,,,, -81400,0.20626047,1.5961578,,,,,,,,,,,,,,,,, -81500,0.18933332,1.527949,,,,,,,,,,,,,,,,, -81600,0.20771778,1.606608,,,,,,,,,,,,,,,,, -81700,0.19550937,1.6318946,,,,,,,,,,,,,,,,, -81800,0.19236171,1.6339788,,,,,,,,,,,,,,,,, -81900,0.20735838,1.6193666,,,,,,,,,,,,,,,,, -82000,0.21201898,1.5862304,,,,,,,,,,,,,,,,, -82100,0.22360979,1.6692104,,,,,,,,,,,,,,,,, -82200,0.19280304,1.5646111,,,,,,,,,,,,,,,,, -82243,,,0.6750671863555908,1.4929388761520386,33.87395627976873,0.6826077699661255,1.4436779022216797,29.84936934505664,3000.0,0.6966823935508728,1.3540998697280884,29.61672950952151,3003.0,29432.561252594,49679.17920422554,29432.561252594,20242.84255671501,1.1201717853546145,0.0 -82300,0.20103416,1.6946929,,,,,,,,,,,,,,,,, -82400,0.21165992,1.6386433,,,,,,,,,,,,,,,,, -82500,0.20884171,1.6615325,,,,,,,,,,,,,,,,, -82600,0.19555789,1.5591743,,,,,,,,,,,,,,,,, -82700,0.1968671,1.6212974,,,,,,,,,,,,,,,,, -82800,0.20215055,1.6152816,,,,,,,,,,,,,,,,, -82900,0.21049874,1.6429372,,,,,,,,,,,,,,,,, -83000,0.19805689,1.6085064,,,,,,,,,,,,,,,,, -83100,0.2177962,1.5376377,,,,,,,,,,,,,,,,, -83200,0.23595911,1.6029297,,,,,,,,,,,,,,,,, -83300,0.20253466,1.6438454,,,,,,,,,,,,,,,,, -83400,0.49265462,1.5501676,,,,,,,,,,,,,,,,, -83500,0.32916066,1.604271,,,,,,,,,,,,,,,,, -83600,0.19662364,1.5983925,,,,,,,,,,,,,,,,, -83700,0.19508488,1.604209,,,,,,,,,,,,,,,,, -83800,0.32499552,1.6752819,,,,,,,,,,,,,,,,, -83900,0.19651516,1.5751681,,,,,,,,,,,,,,,,, -84000,0.21562655,1.684142,,,,,,,,,,,,,,,,, -84100,0.20339137,1.6000015,,,,,,,,,,,,,,,,, -84200,0.18938735,1.5594994,,,,,,,,,,,,,,,,, -84300,0.20362966,1.6289183,,,,,,,,,,,,,,,,, -84400,0.2132954,1.6352992,,,,,,,,,,,,,,,,, -84500,0.21380922,1.6151202,,,,,,,,,,,,,,,,, -84593,,,0.6679847240447998,1.5398011207580566,33.47505030763876,0.6833269000053406,1.4395651817321775,29.752018414615385,3000.0,0.6992040276527405,1.3429124355316162,29.948791249609343,3003.0,30272.66680049896,51220.04585957527,30272.66680049896,20943.49081230164,1.1579155921936035,0.0 -84600,0.19906032,1.523549,,,,,,,,,,,,,,,,, -84700,0.21733867,1.5766351,,,,,,,,,,,,,,,,, -84800,0.22654343,1.6685935,,,,,,,,,,,,,,,,, -84900,0.21053202,1.6064622,,,,,,,,,,,,,,,,, -85000,0.19159254,1.5587124,,,,,,,,,,,,,,,,, -85100,0.21104246,1.5215092,,,,,,,,,,,,,,,,, -85200,0.21064031,1.6302358,,,,,,,,,,,,,,,,, -85300,0.1973036,1.6439385,,,,,,,,,,,,,,,,, -85400,0.20154262,1.5858834,,,,,,,,,,,,,,,,, -85500,0.20443018,1.5843627,,,,,,,,,,,,,,,,, -85600,0.21318695,1.5884022,,,,,,,,,,,,,,,,, -85700,0.21673068,1.6150461,,,,,,,,,,,,,,,,, -85800,0.21672837,1.5808135,,,,,,,,,,,,,,,,, -85900,0.20749582,1.661939,,,,,,,,,,,,,,,,, -86000,0.20395161,1.7057972,,,,,,,,,,,,,,,,, -86100,0.21264385,1.6762087,,,,,,,,,,,,,,,,, -86200,0.19616894,1.6070114,,,,,,,,,,,,,,,,, -86300,0.20352975,1.5994838,,,,,,,,,,,,,,,,, -86400,0.19996057,1.5519317,,,,,,,,,,,,,,,,, -86500,0.20234302,1.6045154,,,,,,,,,,,,,,,,, -86600,0.23300233,1.6171918,,,,,,,,,,,,,,,,, -86700,0.20793048,1.595625,,,,,,,,,,,,,,,,, -86800,0.21683481,1.5633317,,,,,,,,,,,,,,,,, -86900,0.22055604,1.5989377,,,,,,,,,,,,,,,,, -86944,,,0.66844642162323,1.5463353395462036,33.3239685944154,0.6863275170326233,1.4329047203063965,30.17483410773,3000.0,0.7015630006790161,1.336978793144226,30.01117652283077,3003.0,31112.70547223091,52629.39412069321,31112.70547223091,21512.685092926025,1.1974620819091797,0.0 -87000,0.2022444,1.5452324,,,,,,,,,,,,,,,,, -87100,0.19990848,1.636011,,,,,,,,,,,,,,,,, -87200,0.20052847,1.5775567,,,,,,,,,,,,,,,,, -87300,0.20090282,1.5769919,,,,,,,,,,,,,,,,, -87400,0.21163864,1.6358705,,,,,,,,,,,,,,,,, -87500,0.20815763,1.589266,,,,,,,,,,,,,,,,, -87600,0.2042594,1.6322705,,,,,,,,,,,,,,,,, -87700,0.20094228,1.625828,,,,,,,,,,,,,,,,, -87800,0.7441621,1.592181,,,,,,,,,,,,,,,,, -87900,0.23155116,1.6703906,,,,,,,,,,,,,,,,, -88000,0.20923902,1.5580127,,,,,,,,,,,,,,,,, -88100,0.2640392,1.6104518,,,,,,,,,,,,,,,,, -88200,0.21332109,1.5925944,,,,,,,,,,,,,,,,, -88300,0.21924517,1.5283258,,,,,,,,,,,,,,,,, -88400,0.21319821,1.6651156,,,,,,,,,,,,,,,,, -88500,0.23179291,1.5620499,,,,,,,,,,,,,,,,, -88600,0.21250741,1.6014571,,,,,,,,,,,,,,,,, -88700,0.1994753,1.617301,,,,,,,,,,,,,,,,, -88800,0.2047087,1.6503006,,,,,,,,,,,,,,,,, -88900,0.20899661,1.474265,,,,,,,,,,,,,,,,, -89000,0.20220263,1.5501306,,,,,,,,,,,,,,,,, -89100,0.21274826,1.6311252,,,,,,,,,,,,,,,,, -89200,0.214914,1.6377646,,,,,,,,,,,,,,,,, -89293,,,0.6765536069869995,1.4846397638320925,33.80841067188661,0.6877037882804871,1.4231077432632446,30.144694232267003,3000.0,0.7022950649261475,1.3293834924697876,30.109991566006315,3003.0,31952.67540025711,54071.60535812378,31952.67540025711,22114.81109213829,1.2344374656677246,0.0 -89300,0.20180924,1.5852075,,,,,,,,,,,,,,,,, -89400,0.19912712,1.4769974,,,,,,,,,,,,,,,,, -89500,0.2007902,1.5610113,,,,,,,,,,,,,,,,, -89600,0.22932833,1.592849,,,,,,,,,,,,,,,,, -89700,0.21222866,1.534849,,,,,,,,,,,,,,,,, -89800,0.21537583,1.5786835,,,,,,,,,,,,,,,,, -89900,0.20976791,1.5377105,,,,,,,,,,,,,,,,, -90000,0.20526616,1.6107049,,,,,,,,,,,,,,,,, -90100,0.20916973,1.6288049,,,,,,,,,,,,,,,,, -90200,0.2159646,1.6053059,,,,,,,,,,,,,,,,, -90300,0.2120127,1.5787233,,,,,,,,,,,,,,,,, -90400,0.22269782,1.6285813,,,,,,,,,,,,,,,,, -90500,0.2064132,1.5703264,,,,,,,,,,,,,,,,, -90600,0.21940486,1.6188065,,,,,,,,,,,,,,,,, -90700,0.20747434,1.5034091,,,,,,,,,,,,,,,,, -90800,0.23975003,1.5768416,,,,,,,,,,,,,,,,, -90900,0.22198443,1.5850011,,,,,,,,,,,,,,,,, -91000,0.20124756,1.6260356,,,,,,,,,,,,,,,,, -91100,0.21319349,1.5685642,,,,,,,,,,,,,,,,, -91200,0.22339614,1.5682362,,,,,,,,,,,,,,,,, -91300,0.21138807,1.5935265,,,,,,,,,,,,,,,,, -91400,0.2993174,1.5242958,,,,,,,,,,,,,,,,, -91500,0.21119058,1.5754801,,,,,,,,,,,,,,,,, -91600,0.19061592,1.5024816,,,,,,,,,,,,,,,,, -91642,,,0.6729382872581482,1.5041873455047607,33.947059749229915,0.6867242455482483,1.4200741052627563,30.009697354141604,3000.0,0.7022950649261475,1.3214539289474487,30.012051023290432,3003.0,32792.68555688858,55434.03802442551,32792.68555688858,22637.1190366745,1.2718265056610107,0.0 -91700,0.20860018,1.5973009,,,,,,,,,,,,,,,,, -91800,0.20744848,1.5296094,,,,,,,,,,,,,,,,, -91900,0.20484824,1.5339279,,,,,,,,,,,,,,,,, -92000,0.2307567,1.6271685,,,,,,,,,,,,,,,,, -92100,0.21079487,1.6271697,,,,,,,,,,,,,,,,, -92200,0.20951226,1.5656363,,,,,,,,,,,,,,,,, -92300,0.21298175,1.5386881,,,,,,,,,,,,,,,,, -92400,0.19704281,1.5178452,,,,,,,,,,,,,,,,, -92500,0.28687206,1.6257948,,,,,,,,,,,,,,,,, -92600,0.22522807,1.6596884,,,,,,,,,,,,,,,,, -92700,0.21424791,1.6086822,,,,,,,,,,,,,,,,, -92800,0.22569801,1.5612154,,,,,,,,,,,,,,,,, -92900,0.21797767,1.5199648,,,,,,,,,,,,,,,,, -93000,0.20009986,1.595194,,,,,,,,,,,,,,,,, -93100,0.21048269,1.5749913,,,,,,,,,,,,,,,,, -93200,0.21832635,1.6295933,,,,,,,,,,,,,,,,, -93300,0.19951427,1.5590589,,,,,,,,,,,,,,,,, -93400,0.21566801,1.5565279,,,,,,,,,,,,,,,,, -93500,0.20977306,1.5416437,,,,,,,,,,,,,,,,, -93600,0.21718597,1.5996971,,,,,,,,,,,,,,,,, -93700,0.19842902,1.5662148,,,,,,,,,,,,,,,,, -93800,0.20469427,1.4739072,,,,,,,,,,,,,,,,, -93900,0.20155087,1.5632426,,,,,,,,,,,,,,,,, -93991,,,0.6925990581512451,1.38995623588562,34.75326842015933,0.6878277659416199,1.4126427173614502,30.225550510148068,3000.0,0.7032014727592468,1.3106350898742676,29.828300853864768,3003.0,33632.6453063488,56803.9655148983,33632.6453063488,23166.96990251541,1.3109593391418457,0.0 -94000,0.2052963,1.5551686,,,,,,,,,,,,,,,,, -94100,0.20585074,1.5505534,,,,,,,,,,,,,,,,, -94200,0.21960427,1.6298496,,,,,,,,,,,,,,,,, -94300,0.21164297,1.5937605,,,,,,,,,,,,,,,,, -94400,0.23811907,1.6340696,,,,,,,,,,,,,,,,, -94500,0.21042825,1.5670866,,,,,,,,,,,,,,,,, -94600,0.22113526,1.5619403,,,,,,,,,,,,,,,,, -94700,0.21591818,1.5437297,,,,,,,,,,,,,,,,, -94800,1.8169519,1.488834,,,,,,,,,,,,,,,,, -94900,0.20268402,1.5380073,,,,,,,,,,,,,,,,, -95000,0.2125692,1.6164411,,,,,,,,,,,,,,,,, -95100,0.21356757,1.55692,,,,,,,,,,,,,,,,, -95200,0.22058861,1.5157484,,,,,,,,,,,,,,,,, -95300,0.21103024,1.5483408,,,,,,,,,,,,,,,,, -95400,0.2048749,1.4826729,,,,,,,,,,,,,,,,, -95500,0.2227898,1.5541527,,,,,,,,,,,,,,,,, -95600,0.2128179,1.5853747,,,,,,,,,,,,,,,,, -95700,0.21424057,1.6522309,,,,,,,,,,,,,,,,, -95800,0.21688895,1.6029,,,,,,,,,,,,,,,,, -95900,0.21371086,1.6403838,,,,,,,,,,,,,,,,, -96000,0.20855898,1.5117143,,,,,,,,,,,,,,,,, -96100,0.21696898,1.6248542,,,,,,,,,,,,,,,,, -96200,0.21832493,1.5425388,,,,,,,,,,,,,,,,, -96300,0.201716,1.5775634,,,,,,,,,,,,,,,,, -96341,,,0.6803255081176758,1.4658100605010986,33.819762322880095,0.6885097622871399,1.407971978187561,29.990235745955843,3000.0,0.7055023312568665,1.3078140020370483,30.228751807442475,3003.0,34472.82131195068,58340.36324286461,34472.82131195068,23863.072603464127,1.351414918899536,0.0 -96400,0.22403696,1.633256,,,,,,,,,,,,,,,,, -96500,0.20604537,1.54455,,,,,,,,,,,,,,,,, -96600,0.20871715,1.5263715,,,,,,,,,,,,,,,,, -96700,0.23185669,1.5356383,,,,,,,,,,,,,,,,, -96800,0.2183003,1.5497638,,,,,,,,,,,,,,,,, -96900,0.20262979,1.5468711,,,,,,,,,,,,,,,,, -97000,0.21420182,1.5835342,,,,,,,,,,,,,,,,, -97100,0.22431344,1.4615548,,,,,,,,,,,,,,,,, -97200,0.22353004,1.5825357,,,,,,,,,,,,,,,,, -97300,0.21216413,1.5957606,,,,,,,,,,,,,,,,, -97400,0.21789184,1.5819478,,,,,,,,,,,,,,,,, -97500,0.22007005,1.5202459,,,,,,,,,,,,,,,,, -97600,0.20487553,1.5083506,,,,,,,,,,,,,,,,, -97700,0.21712151,1.5441525,,,,,,,,,,,,,,,,, -97800,0.2124737,1.4896947,,,,,,,,,,,,,,,,, -97900,0.21530987,1.5303205,,,,,,,,,,,,,,,,, -98000,0.21248254,1.4879355,,,,,,,,,,,,,,,,, -98100,0.21743837,1.454799,,,,,,,,,,,,,,,,, -98200,0.2076643,1.5142788,,,,,,,,,,,,,,,,, -98300,0.21816768,1.6009365,,,,,,,,,,,,,,,,, -98400,0.21875006,1.4693396,,,,,,,,,,,,,,,,, -98500,0.20703234,1.5329579,,,,,,,,,,,,,,,,, -98600,0.20610493,1.5319394,,,,,,,,,,,,,,,,, -98690,,,0.678549587726593,1.4740723371505735,34.745071742947765,0.689811646938324,1.4059818983078003,30.306885565787923,3000.0,0.7070013284683228,1.3015625476837158,30.51100426873089,3003.0,35312.71530985832,59696.38206171989,35312.71530985832,24379.07930803299,1.391371726989746,0.0 -98700,0.22053157,1.559509,,,,,,,,,,,,,,,,, -98800,0.21580322,1.5352907,,,,,,,,,,,,,,,,, -98900,0.22256012,1.5422391,,,,,,,,,,,,,,,,, -99000,0.20768163,1.5936776,,,,,,,,,,,,,,,,, -99100,0.21288979,1.4854215,,,,,,,,,,,,,,,,, -99200,0.20898105,1.5110974,,,,,,,,,,,,,,,,, -99300,0.21183215,1.5748581,,,,,,,,,,,,,,,,, -99400,0.22388706,1.5301818,,,,,,,,,,,,,,,,, -99500,0.2131476,1.5537479,,,,,,,,,,,,,,,,, -99600,0.220662,1.4634837,,,,,,,,,,,,,,,,, -99700,0.22022225,1.5593038,,,,,,,,,,,,,,,,, -99800,0.21417764,1.5188093,,,,,,,,,,,,,,,,, -99900,0.21620393,1.5324059,,,,,,,,,,,,,,,,, -100000,0.23477958,1.5269092,,,,,,,,,,,,,,,,, -100100,0.22367692,1.601593,,,,,,,,,,,,,,,,, -100200,0.21160218,1.4861549,,,,,,,,,,,,,,,,, -100300,0.20815341,1.5320857,,,,,,,,,,,,,,,,, -100400,0.22718881,1.5685143,,,,,,,,,,,,,,,,, -100500,0.2273615,1.5436428,,,,,,,,,,,,,,,,, -100600,0.20941347,1.5182446,,,,,,,,,,,,,,,,, -100700,0.20526432,1.4154735,,,,,,,,,,,,,,,,, -100800,0.21747157,1.5574766,,,,,,,,,,,,,,,,, -100900,0.30679557,1.5766399,,,,,,,,,,,,,,,,, -101000,0.21823913,1.4545107,,,,,,,,,,,,,,,,, -101040,,,0.6873783469200134,1.4144078493118286,34.822233159168285,0.689452052116394,1.394218683242798,30.43498417126523,3000.0,0.7060601115226746,1.2959932088851929,30.27591467570254,3003.0,36152.8135163784,61109.87213683128,36152.8135163784,24952.356118440628,1.4311096668243408,0.0 -101100,0.22061029,1.5257874,,,,,,,,,,,,,,,,, -101200,0.20921956,1.4419723,,,,,,,,,,,,,,,,, -101300,0.21118475,1.5098741,,,,,,,,,,,,,,,,, -101400,0.22338521,1.506337,,,,,,,,,,,,,,,,, -101500,0.21690181,1.5729245,,,,,,,,,,,,,,,,, -101600,0.22611229,1.5579871,,,,,,,,,,,,,,,,, -101700,0.23079649,1.5030698,,,,,,,,,,,,,,,,, -101800,0.20446864,1.4922594,,,,,,,,,,,,,,,,, -101900,0.21127401,1.5592618,,,,,,,,,,,,,,,,, -102000,0.21258204,1.5154196,,,,,,,,,,,,,,,,, -102100,0.21153523,1.4821246,,,,,,,,,,,,,,,,, -102200,0.23208724,1.5873861,,,,,,,,,,,,,,,,, -102300,0.22053461,1.5612628,,,,,,,,,,,,,,,,, -102400,0.21435946,1.4817427,,,,,,,,,,,,,,,,, -102500,0.21091634,1.4905366,,,,,,,,,,,,,,,,, -102600,0.2246142,1.4894133,,,,,,,,,,,,,,,,, -102700,0.23246552,1.4955603,,,,,,,,,,,,,,,,, -102800,0.21808998,1.5007232,,,,,,,,,,,,,,,,, -102900,0.21805567,1.591437,,,,,,,,,,,,,,,,, -103000,0.21984081,1.5546249,,,,,,,,,,,,,,,,, -103100,0.223223,1.560791,,,,,,,,,,,,,,,,, -103200,0.20671107,1.4658462,,,,,,,,,,,,,,,,, -103300,0.23189725,1.561627,,,,,,,,,,,,,,,,, -103390,,,0.6787706017494202,1.472381591796875,34.936464710336374,0.6904067993164062,1.394535779953003,30.4547279180483,3000.0,0.7068851590156555,1.2904762029647827,30.54459233862504,3003.0,36992.80270028114,62544.18134522438,36992.80270028114,25546.56099653244,1.4709479808807373,0.0 -103400,0.25081035,1.5998738,,,,,,,,,,,,,,,,, -103500,0.21629702,1.435731,,,,,,,,,,,,,,,,, -103600,0.21229534,1.5066724,,,,,,,,,,,,,,,,, -103700,0.21974172,1.4894266,,,,,,,,,,,,,,,,, -103800,0.22858113,1.5739375,,,,,,,,,,,,,,,,, -103900,0.21793312,1.468584,,,,,,,,,,,,,,,,, -104000,0.20866093,1.5222881,,,,,,,,,,,,,,,,, -104100,0.2086951,1.5165089,,,,,,,,,,,,,,,,, -104200,0.22634979,1.6165972,,,,,,,,,,,,,,,,, -104300,0.21810974,1.4429468,,,,,,,,,,,,,,,,, -104400,0.22979201,1.5364366,,,,,,,,,,,,,,,,, -104500,0.22659111,1.5209346,,,,,,,,,,,,,,,,, -104600,0.223414,1.4782361,,,,,,,,,,,,,,,,, -104700,0.22155528,1.5046369,,,,,,,,,,,,,,,,, -104800,0.23240663,1.4402916,,,,,,,,,,,,,,,,, -104900,0.23024127,1.5096718,,,,,,,,,,,,,,,,, -105000,0.22404341,1.5136961,,,,,,,,,,,,,,,,, -105100,0.21504684,1.5705265,,,,,,,,,,,,,,,,, -105200,0.22069298,1.4826064,,,,,,,,,,,,,,,,, -105300,0.23263022,1.560161,,,,,,,,,,,,,,,,, -105400,0.21298829,1.4674155,,,,,,,,,,,,,,,,, -105500,0.22853447,1.5431271,,,,,,,,,,,,,,,,, -105600,0.22215372,1.5242286,,,,,,,,,,,,,,,,, -105700,0.23353624,1.5216401,,,,,,,,,,,,,,,,, -105740,,,0.6793479919433594,1.474918246269226,34.524389490258514,0.6916218996047974,1.387118220329285,30.763161599157005,3000.0,0.7088722586631775,1.2860301733016968,30.58754004769873,3003.0,37832.90144467354,63912.06481075287,37832.90144467354,26074.223981142044,1.5178804397583008,0.0 -105800,0.24621479,1.5585161,,,,,,,,,,,,,,,,, -105900,0.21576573,1.4711925,,,,,,,,,,,,,,,,, -106000,0.23130105,1.5310194,,,,,,,,,,,,,,,,, -106100,0.21984744,1.484138,,,,,,,,,,,,,,,,, -106200,0.21514694,1.4486716,,,,,,,,,,,,,,,,, -106300,0.22769892,1.4368745,,,,,,,,,,,,,,,,, -106400,0.22187385,1.435773,,,,,,,,,,,,,,,,, -106500,0.22053231,1.4085354,,,,,,,,,,,,,,,,, -106600,0.22763188,1.5361449,,,,,,,,,,,,,,,,, -106700,0.21517669,1.4946853,,,,,,,,,,,,,,,,, -106800,0.23221393,1.5168875,,,,,,,,,,,,,,,,, -106900,0.22498484,1.5391171,,,,,,,,,,,,,,,,, -107000,0.23031043,1.4824452,,,,,,,,,,,,,,,,, -107100,0.21233907,1.4960121,,,,,,,,,,,,,,,,, -107200,0.22876795,1.5697719,,,,,,,,,,,,,,,,, -107300,0.2112994,1.4296317,,,,,,,,,,,,,,,,, -107400,0.23656203,1.5054231,,,,,,,,,,,,,,,,, -107500,0.24506262,1.5277227,,,,,,,,,,,,,,,,, -107600,0.21576539,1.4933549,,,,,,,,,,,,,,,,, -107700,0.22134405,1.4450427,,,,,,,,,,,,,,,,, -107800,0.23868467,1.5525724,,,,,,,,,,,,,,,,, -107900,0.23628357,1.5655553,,,,,,,,,,,,,,,,, -108000,0.22455207,1.5213858,,,,,,,,,,,,,,,,, -108090,,,0.6895817518234253,1.4022433757781982,34.956725807169725,0.6924774646759033,1.3789490461349487,30.50776881508136,3000.0,0.7091395258903503,1.276688575744629,30.70939975137136,3003.0,38672.90054440498,65321.97705411911,38672.90054440498,26644.022025108337,1.5571041107177734,0.0 -108100,0.22243622,1.53519,,,,,,,,,,,,,,,,, -108200,0.2265177,1.4362468,,,,,,,,,,,,,,,,, -108300,0.21774322,1.4532942,,,,,,,,,,,,,,,,, -108400,0.22818947,1.4965835,,,,,,,,,,,,,,,,, -108500,0.2350847,1.5211415,,,,,,,,,,,,,,,,, -108600,0.23820171,1.5233587,,,,,,,,,,,,,,,,, -108700,0.22661154,1.5823174,,,,,,,,,,,,,,,,, -108800,0.23468024,1.4929008,,,,,,,,,,,,,,,,, -108900,0.21810149,1.5382999,,,,,,,,,,,,,,,,, -109000,0.21786977,1.4702568,,,,,,,,,,,,,,,,, -109100,0.24612607,1.434437,,,,,,,,,,,,,,,,, -109200,0.22339247,1.4995539,,,,,,,,,,,,,,,,, -109300,0.23203139,1.4929,,,,,,,,,,,,,,,,, -109400,0.2250927,1.545076,,,,,,,,,,,,,,,,, -109500,0.22264236,1.5387814,,,,,,,,,,,,,,,,, -109600,0.24205105,1.4584873,,,,,,,,,,,,,,,,, -109700,0.21900873,1.4164468,,,,,,,,,,,,,,,,, -109800,0.224271,1.5172828,,,,,,,,,,,,,,,,, -109900,0.22573426,1.5708494,,,,,,,,,,,,,,,,, -110000,0.22636724,1.5461392,,,,,,,,,,,,,,,,, -110100,0.33649617,1.4666406,,,,,,,,,,,,,,,,, -110200,0.2273896,1.5170809,,,,,,,,,,,,,,,,, -110300,0.22516581,1.4725666,,,,,,,,,,,,,,,,, -110400,0.23168735,1.5021353,,,,,,,,,,,,,,,,, -110440,,,0.6895023584365845,1.4150652885437012,35.16845176199788,0.6943125128746033,1.378193974494934,30.72178388616816,3000.0,0.7114055156707764,1.2721518278121948,30.80145430874005,3003.0,39513.07421565056,66752.35574197769,39513.07421565056,27234.110072135925,1.5973656177520752,0.0 -110500,0.23520118,1.5323485,,,,,,,,,,,,,,,,, -110600,0.232785,1.5045991,,,,,,,,,,,,,,,,, -110700,0.22653723,1.5208138,,,,,,,,,,,,,,,,, -110800,0.22394425,1.4886606,,,,,,,,,,,,,,,,, -110900,0.2220399,1.4877988,,,,,,,,,,,,,,,,, -111000,0.2296461,1.5069177,,,,,,,,,,,,,,,,, -111100,0.2358675,1.4671838,,,,,,,,,,,,,,,,, -111200,0.2307614,1.4660975,,,,,,,,,,,,,,,,, -111300,0.22867721,1.4249847,,,,,,,,,,,,,,,,, -111400,0.22760928,1.5196265,,,,,,,,,,,,,,,,, -111500,0.23069671,1.5096068,,,,,,,,,,,,,,,,, -111600,0.21773332,1.5146227,,,,,,,,,,,,,,,,, -111700,0.23349312,1.3854082,,,,,,,,,,,,,,,,, -111800,0.2414722,1.5509448,,,,,,,,,,,,,,,,, -111900,0.25012296,1.554607,,,,,,,,,,,,,,,,, -112000,0.23224658,1.5068679,,,,,,,,,,,,,,,,, -112100,0.224598,1.4883404,,,,,,,,,,,,,,,,, -112200,0.25356662,1.5544813,,,,,,,,,,,,,,,,, -112300,0.23724054,1.4990164,,,,,,,,,,,,,,,,, -112400,0.22801259,1.4064384,,,,,,,,,,,,,,,,, -112500,0.22563688,1.5362061,,,,,,,,,,,,,,,,, -112600,0.2239393,1.4079942,,,,,,,,,,,,,,,,, -112700,0.23930664,1.4819261,,,,,,,,,,,,,,,,, -112791,,,0.6971345543861389,1.3650158643722534,35.54626248357229,0.6938785314559937,1.3746583461761477,30.809153601006447,3000.0,0.7100808024406433,1.2701268196105957,30.55024025701441,3003.0,40353.12133765221,68157.70009493828,40353.12133765221,27799.28997278213,1.638994216918945,0.0 -112800,0.22690962,1.4860618,,,,,,,,,,,,,,,,, -112900,0.23053338,1.4775194,,,,,,,,,,,,,,,,, -113000,0.23184618,1.5176519,,,,,,,,,,,,,,,,, -113100,0.22881484,1.5084713,,,,,,,,,,,,,,,,, -113200,0.22466332,1.4592406,,,,,,,,,,,,,,,,, -113300,0.23586363,1.4984627,,,,,,,,,,,,,,,,, -113400,0.23718527,1.5982258,,,,,,,,,,,,,,,,, -113500,0.23847914,1.4101992,,,,,,,,,,,,,,,,, -113600,0.23492138,1.4636546,,,,,,,,,,,,,,,,, -113700,0.24196179,1.5729676,,,,,,,,,,,,,,,,, -113800,0.23001981,1.4533932,,,,,,,,,,,,,,,,, -113900,0.23656537,1.5880204,,,,,,,,,,,,,,,,, -114000,0.22594637,1.3983488,,,,,,,,,,,,,,,,, -114100,0.23147973,1.4628451,,,,,,,,,,,,,,,,, -114200,0.23174462,1.4981978,,,,,,,,,,,,,,,,, -114300,0.23096503,1.5320915,,,,,,,,,,,,,,,,, -114400,0.23690732,1.4474366,,,,,,,,,,,,,,,,, -114500,0.23868021,1.5334374,,,,,,,,,,,,,,,,, -114600,0.22531603,1.4583148,,,,,,,,,,,,,,,,, -114700,0.23569475,1.4856547,,,,,,,,,,,,,,,,, -114800,0.22775348,1.5193745,,,,,,,,,,,,,,,,, -114900,0.23272054,1.5260534,,,,,,,,,,,,,,,,, -115000,0.22850537,1.4763589,,,,,,,,,,,,,,,,, -115100,0.22108997,1.3927855,,,,,,,,,,,,,,,,, -115141,,,0.6917914748191833,1.399938702583313,35.197873516392164,0.6941389441490173,1.3688514232635498,30.793833684223024,3000.0,0.7127999663352966,1.2630198001861572,30.843691765002152,3003.0,41193.01627731323,69688.62987804413,41193.01627731323,28490.20787382126,1.6810777187347412,0.0 -115200,0.23071846,1.4236767,,,,,,,,,,,,,,,,, -115300,0.23079038,1.5183493,,,,,,,,,,,,,,,,, -115400,0.22097665,1.3884809,,,,,,,,,,,,,,,,, -115500,0.23668939,1.531727,,,,,,,,,,,,,,,,, -115600,0.23257992,1.4581113,,,,,,,,,,,,,,,,, -115700,0.23261108,1.5192091,,,,,,,,,,,,,,,,, -115800,0.23359583,1.518697,,,,,,,,,,,,,,,,, -115900,0.24529591,1.5048376,,,,,,,,,,,,,,,,, -116000,0.2332033,1.4613433,,,,,,,,,,,,,,,,, -116100,0.22751367,1.4758263,,,,,,,,,,,,,,,,, -116200,0.22790718,1.4474171,,,,,,,,,,,,,,,,, -116300,0.23013891,1.5145606,,,,,,,,,,,,,,,,, -116400,0.24853726,1.5026871,,,,,,,,,,,,,,,,, -116500,0.25128156,1.4865283,,,,,,,,,,,,,,,,, -116600,0.23334679,1.4598151,,,,,,,,,,,,,,,,, -116700,0.244343,1.4493246,,,,,,,,,,,,,,,,, -116800,0.23735934,1.5037947,,,,,,,,,,,,,,,,, -116900,0.23757675,1.4569174,,,,,,,,,,,,,,,,, -117000,0.24139301,1.4478031,,,,,,,,,,,,,,,,, -117100,0.23625304,1.4817902,,,,,,,,,,,,,,,,, -117200,0.22786856,1.4504635,,,,,,,,,,,,,,,,, -117300,0.23143862,1.3818902,,,,,,,,,,,,,,,,, -117400,0.2251658,1.4394138,,,,,,,,,,,,,,,,, -117491,,,0.6893499493598938,1.4066481590270996,35.249029116488906,0.6947464942932129,1.3684260845184326,30.907103979729687,3000.0,0.7123932838439941,1.262866735458374,30.727742222734168,3003.0,42032.95151424408,71083.43123865128,42032.95151424408,29044.949915885925,1.7298576831817627,0.0 -117500,0.24165536,1.5657161,,,,,,,,,,,,,,,,, -117600,0.22773188,1.4165391,,,,,,,,,,,,,,,,, -117700,0.22329596,1.4692614,,,,,,,,,,,,,,,,, -117800,0.23404665,1.4049163,,,,,,,,,,,,,,,,, -117900,0.23181559,1.4626682,,,,,,,,,,,,,,,,, -118000,0.24652852,1.4708041,,,,,,,,,,,,,,,,, -118100,0.23459783,1.5167093,,,,,,,,,,,,,,,,, -118200,0.22840036,1.442047,,,,,,,,,,,,,,,,, -118300,0.2384978,1.4622235,,,,,,,,,,,,,,,,, -118400,0.23006153,1.4642082,,,,,,,,,,,,,,,,, -118500,0.23962556,1.4791378,,,,,,,,,,,,,,,,, -118600,0.24702346,1.482881,,,,,,,,,,,,,,,,, -118700,0.24384555,1.5584304,,,,,,,,,,,,,,,,, -118800,0.22661887,1.4948004,,,,,,,,,,,,,,,,, -118900,0.23749551,1.4571211,,,,,,,,,,,,,,,,, -119000,0.23824832,1.4974085,,,,,,,,,,,,,,,,, -119100,0.22507617,1.436594,,,,,,,,,,,,,,,,, -119200,0.2422726,1.4721743,,,,,,,,,,,,,,,,, -119300,0.2326592,1.4612157,,,,,,,,,,,,,,,,, -119400,0.24265695,1.5138392,,,,,,,,,,,,,,,,, -119500,0.23646626,1.4702357,,,,,,,,,,,,,,,,, -119600,0.23542854,1.390713,,,,,,,,,,,,,,,,, -119700,0.2510049,1.5508075,,,,,,,,,,,,,,,,, -119800,0.23776141,1.4975132,,,,,,,,,,,,,,,,, -119841,,,0.6982017755508423,1.360303521156311,35.42124147020956,0.696234405040741,1.3644174337387085,30.81000255265572,3000.0,0.7137296199798584,1.258353590965271,30.89031023851516,3003.0,42873.06673336029,72488.55532360077,42873.06673336029,29609.83997654915,1.772599458694458,0.0 -119900,0.2330619,1.4323255,,,,,,,,,,,,,,,,, -120000,0.24007188,1.4014895,,,,,,,,,,,,,,,,, -120100,0.24778105,1.42449,,,,,,,,,,,,,,,,, -120200,0.23972613,1.4802701,,,,,,,,,,,,,,,,, -120300,0.2398779,1.4421186,,,,,,,,,,,,,,,,, -120400,0.24004532,1.4791356,,,,,,,,,,,,,,,,, -120500,0.2360802,1.3893201,,,,,,,,,,,,,,,,, -120600,0.23774067,1.4848268,,,,,,,,,,,,,,,,, -120700,0.24688952,1.4413494,,,,,,,,,,,,,,,,, -120800,0.23939733,1.4239037,,,,,,,,,,,,,,,,, -120900,0.23251462,1.445214,,,,,,,,,,,,,,,,, -121000,0.23863024,1.4339567,,,,,,,,,,,,,,,,, -121100,0.23447444,1.4994352,,,,,,,,,,,,,,,,, -121200,0.2327545,1.426213,,,,,,,,,,,,,,,,, -121300,0.23340476,1.5100449,,,,,,,,,,,,,,,,, -121400,0.23084731,1.4055107,,,,,,,,,,,,,,,,, -121500,0.23225954,1.4282091,,,,,,,,,,,,,,,,, -121600,0.2395753,1.4310344,,,,,,,,,,,,,,,,, -121700,0.25384414,1.4626094,,,,,,,,,,,,,,,,, -121800,0.23942475,1.4501652,,,,,,,,,,,,,,,,, -121900,0.23567489,1.4756219,,,,,,,,,,,,,,,,, -122000,0.23463169,1.4776607,,,,,,,,,,,,,,,,, -122100,0.22858195,1.416294,,,,,,,,,,,,,,,,, -122190,,,0.6944260001182556,1.3774758577346802,35.73075424767851,0.6958252191543579,1.363082766532898,30.87195959719736,3000.0,0.7138806581497192,1.2562905550003052,31.07766958165444,3003.0,43712.94831061363,73994.31863379478,43712.94831061363,30275.60227298737,1.814788818359375,0.0 -122200,0.24358957,1.4745448,,,,,,,,,,,,,,,,, -122300,0.24138921,1.4944475,,,,,,,,,,,,,,,,, -122400,0.2312791,1.3652165,,,,,,,,,,,,,,,,, -122500,0.23413563,1.4552888,,,,,,,,,,,,,,,,, -122600,0.23006998,1.4097608,,,,,,,,,,,,,,,,, -122700,0.24026991,1.532617,,,,,,,,,,,,,,,,, -122800,0.23967536,1.4451797,,,,,,,,,,,,,,,,, -122900,0.23783843,1.4340999,,,,,,,,,,,,,,,,, -123000,0.23824014,1.4723669,,,,,,,,,,,,,,,,, -123100,0.23449224,1.4485115,,,,,,,,,,,,,,,,, -123200,0.24624363,1.433242,,,,,,,,,,,,,,,,, -123300,0.23730211,1.470574,,,,,,,,,,,,,,,,, -123400,0.23271364,1.4844807,,,,,,,,,,,,,,,,, -123500,0.2379756,1.4949706,,,,,,,,,,,,,,,,, -123600,0.22356716,1.4293729,,,,,,,,,,,,,,,,, -123700,0.22508238,1.5010918,,,,,,,,,,,,,,,,, -123800,0.22629805,1.4252172,,,,,,,,,,,,,,,,, -123900,0.23737733,1.4648293,,,,,,,,,,,,,,,,, -124000,0.2424707,1.5065349,,,,,,,,,,,,,,,,, -124100,0.2280438,1.4279523,,,,,,,,,,,,,,,,, -124200,0.23570363,1.450029,,,,,,,,,,,,,,,,, -124300,0.24687563,1.5327266,,,,,,,,,,,,,,,,, -124400,0.23895514,1.4405452,,,,,,,,,,,,,,,,, -124500,0.2290861,1.5300903,,,,,,,,,,,,,,,,, -124540,,,0.6927077174186707,1.3844177722930908,35.60168167605079,0.6960111856460571,1.361225128173828,31.036330806615823,3000.0,0.7133693695068359,1.2556757926940918,30.94660623697596,3003.0,44553.16763043404,75423.9879014492,44553.16763043404,30864.93314909935,1.8580679893493648,0.0 -124600,0.23632947,1.4590495,,,,,,,,,,,,,,,,, -124700,0.23396403,1.4788868,,,,,,,,,,,,,,,,, -124800,0.23212261,1.4341778,,,,,,,,,,,,,,,,, -124900,0.23902178,1.4784335,,,,,,,,,,,,,,,,, -125000,0.22077295,1.4352634,,,,,,,,,,,,,,,,, -125100,0.23982035,1.4906023,,,,,,,,,,,,,,,,, -125200,0.23172052,1.3797904,,,,,,,,,,,,,,,,, -125300,0.24153869,1.5780754,,,,,,,,,,,,,,,,, -125400,0.2430157,1.5142131,,,,,,,,,,,,,,,,, -125500,0.2429049,1.4776195,,,,,,,,,,,,,,,,, -125600,0.23730808,1.4360211,,,,,,,,,,,,,,,,, -125700,0.23546349,1.4480928,,,,,,,,,,,,,,,,, -125800,0.23226757,1.4371165,,,,,,,,,,,,,,,,, -125900,0.2299005,1.4823396,,,,,,,,,,,,,,,,, -126000,0.24382083,1.490585,,,,,,,,,,,,,,,,, -126100,0.23016918,1.4639182,,,,,,,,,,,,,,,,, -126200,0.26603153,1.4300197,,,,,,,,,,,,,,,,, -126300,0.23163678,1.417418,,,,,,,,,,,,,,,,, -126400,0.22757064,1.3316643,,,,,,,,,,,,,,,,, -126500,0.23863262,1.4580262,,,,,,,,,,,,,,,,, -126600,0.2249556,1.4568123,,,,,,,,,,,,,,,,, -126700,0.24008444,1.4984419,,,,,,,,,,,,,,,,, -126800,0.23899183,1.5262023,,,,,,,,,,,,,,,,, -126890,,,0.6990454196929932,1.3546652793884275,35.55317402084076,0.6965071558952332,1.360289216041565,30.923319946605933,3000.0,0.7141130566596985,1.2537082433700562,30.891123087192003,3003.0,45393.23704409599,76871.89690589905,45393.23704409599,31472.653499126434,1.9009826183319087,0.0 -126900,0.23740323,1.3967909,,,,,,,,,,,,,,,,, -127000,0.23429877,1.4390662,,,,,,,,,,,,,,,,, -127100,0.22631525,1.3981086,,,,,,,,,,,,,,,,, -127200,0.24198057,1.5055773,,,,,,,,,,,,,,,,, -127300,0.24173972,1.513973,,,,,,,,,,,,,,,,, -127400,0.22980495,1.4773108,,,,,,,,,,,,,,,,, -127500,0.24554107,1.4146625,,,,,,,,,,,,,,,,, -127600,0.23368634,1.4384395,,,,,,,,,,,,,,,,, -127700,0.22664446,1.4469209,,,,,,,,,,,,,,,,, -127800,0.229449,1.4402618,,,,,,,,,,,,,,,,, -127900,0.23993254,1.4318534,,,,,,,,,,,,,,,,, -128000,0.23684105,1.5246049,,,,,,,,,,,,,,,,, -128100,0.23802328,1.4569885,,,,,,,,,,,,,,,,, -128200,0.22945544,1.442318,,,,,,,,,,,,,,,,, -128300,0.22984777,1.4060475,,,,,,,,,,,,,,,,, -128400,0.22877987,1.4376117,,,,,,,,,,,,,,,,, -128500,0.22475693,1.3764932,,,,,,,,,,,,,,,,, -128600,0.2255104,1.4290897,,,,,,,,,,,,,,,,, -128700,0.22599213,1.3519635,,,,,,,,,,,,,,,,, -128800,0.26376206,1.4232846,,,,,,,,,,,,,,,,, -128900,0.23274276,1.4440168,,,,,,,,,,,,,,,,, -129000,0.24828945,1.496377,,,,,,,,,,,,,,,,, -129100,0.2377393,1.5016943,,,,,,,,,,,,,,,,, -129200,0.24229863,1.457997,,,,,,,,,,,,,,,,, -129239,,,0.6980040073394775,1.359439492225647,35.294106469705945,0.6968419551849365,1.3597447872161863,30.871485970728653,3000.0,0.714043378829956,1.252797245979309,30.942558677133007,3003.0,46233.19505262375,78352.19459056854,46233.19505262375,32112.87235379219,1.943530559539795,0.0 -129300,0.24254641,1.4657075,,,,,,,,,,,,,,,,, -129400,0.23206961,1.4588715,,,,,,,,,,,,,,,,, -129500,0.2494909,1.5122516,,,,,,,,,,,,,,,,, -129600,0.23926955,1.4470055,,,,,,,,,,,,,,,,, -129700,0.23948082,1.4137052,,,,,,,,,,,,,,,,, -129800,0.22830875,1.4295036,,,,,,,,,,,,,,,,, -129900,0.23267618,1.4785963,,,,,,,,,,,,,,,,, -130000,0.23224479,1.4473698,,,,,,,,,,,,,,,,, -130100,0.2368766,1.4497042,,,,,,,,,,,,,,,,, -130200,0.23586291,1.4964329,,,,,,,,,,,,,,,,, -130300,0.23180868,1.4259447,,,,,,,,,,,,,,,,, -130400,0.23110746,1.3425748,,,,,,,,,,,,,,,,, -130500,0.23872499,1.474772,,,,,,,,,,,,,,,,, -130600,0.2266066,1.4610932,,,,,,,,,,,,,,,,, -130700,0.24615863,1.4722555,,,,,,,,,,,,,,,,, -130800,0.22416852,1.4183673,,,,,,,,,,,,,,,,, -130900,0.24090184,1.556702,,,,,,,,,,,,,,,,, -131000,0.23507977,1.4195848,,,,,,,,,,,,,,,,, -131100,0.23796403,1.4462254,,,,,,,,,,,,,,,,, -131200,0.22775322,1.4478406,,,,,,,,,,,,,,,,, -131300,0.230252,1.4906089,,,,,,,,,,,,,,,,, -131400,0.23670024,1.4164778,,,,,,,,,,,,,,,,, -131500,0.2338066,1.4424381,,,,,,,,,,,,,,,,, -131588,,,0.6947080492973328,1.3761208057403564,35.40285800418826,0.6967551708221436,1.3596832752227783,30.906944090047656,3000.0,0.714043378829956,1.252558946609497,30.929931785958644,3003.0,47073.37712454796,79778.21843957901,47073.37712454796,32698.58621668816,1.9933269023895264,0.0 -131600,0.23050009,1.4803784,,,,,,,,,,,,,,,,, -131700,0.25253382,1.4808745,,,,,,,,,,,,,,,,, -131800,0.24098802,1.5097022,,,,,,,,,,,,,,,,, -131900,0.24296008,1.4307816,,,,,,,,,,,,,,,,, -132000,0.23223737,1.4180076,,,,,,,,,,,,,,,,, -132100,0.23578198,1.42206,,,,,,,,,,,,,,,,, -132200,0.23791881,1.4373264,,,,,,,,,,,,,,,,, -132300,0.22763929,1.414293,,,,,,,,,,,,,,,,, -132400,0.23157996,1.4141052,,,,,,,,,,,,,,,,, -132500,0.23395734,1.5039573,,,,,,,,,,,,,,,,, -132600,0.23667325,1.4760239,,,,,,,,,,,,,,,,, -132700,0.24532944,1.4669217,,,,,,,,,,,,,,,,, -132800,0.22798851,1.4162321,,,,,,,,,,,,,,,,, -132900,0.22971867,1.46208,,,,,,,,,,,,,,,,, -133000,0.23157631,1.4421784,,,,,,,,,,,,,,,,, -133100,0.22880234,1.4445933,,,,,,,,,,,,,,,,, -133200,0.23288459,1.3616681,,,,,,,,,,,,,,,,, -133300,0.24283859,1.558481,,,,,,,,,,,,,,,,, -133333,,,0.6980851888656616,1.3578859567642212,35.45882877194227,0.6968915462493896,1.359694004058838,30.90651925945721,3000.0,0.7140317559242249,1.2525193691253662,30.921156057591386,3003.0,47697.16169524193,80995.71348690987,47697.16169524193,33292.195563316345,2.0389645099639893,0.0 -133333,,,,,,,,,,,,,,47697.16169524193,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 51bc4be51..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -789.6232960224152,0.0,19.821783542633057,1,0,19.821783542633057,0.5930051776315789,95000000,809.4451253414154,0.5934471255578335,0.5934494414247642,83274637 -1430.9373450279236,0.0283305644989013,1220.1847472190857,1559,0,1220.1847472190857,0.128100659087171,95000000,2651.2012639045715,0.1240559720299528,0.1257393250761783,83274637 -1983.853398323059,0.0619218349456787,2420.6119287014008,3122,0,2420.6119287014008,0.1275880861328125,95000000,4404.628431797028,0.1231940922872075,0.1251292627970266,83274637 -2523.788309574127,0.0908405780792236,3620.54736661911,4693,0,3620.54736661911,0.1267332276110197,95000000,6144.577065706253,0.1228165935326672,0.1243790355418331,83274637 -3033.384358644485,0.1131591796875,4820.525413274765,6272,0,4820.525413274765,0.1267882915296052,95000000,7854.223849773407,0.1232032984401444,0.1243668215794804,83274637 -3491.767944574356,0.1436803340911865,6020.769790649414,7845,0,6020.769790649414,0.1263900748766447,95000000,9512.933719873428,0.1219598606416264,0.1240499002602977,83274637 -3818.5553402900696,0.16919875144958496,7221.337500095367,9407,0,7221.337500095367,0.12598504163240132,95000000,11040.364230632782,0.12371866953260494,0.12366220080686152,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 63d918bad..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,110 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.8689604,0.593364,,,,,,,,,,, -1,,,0.5934471255578335,0.5934494414247642,83274637.0,0.5930051776315789,95000000.0,19.821783542633057,809.4451253414154,19.821783542633057,789.6232960224152,0.0,0.0 -100,0.14305909,0.14145963,,,,,,,,,,, -200,0.033346143,0.12514693,,,,,,,,,,, -300,0.008956414,0.1263695,,,,,,,,,,, -400,0.02517664,0.14033863,,,,,,,,,,, -500,0.014895337,0.12597677,,,,,,,,,,, -600,0.04565084,0.12640028,,,,,,,,,,, -700,0.023740606,0.12675717,,,,,,,,,,, -800,0.011380516,0.122107334,,,,,,,,,,, -900,0.007991499,0.12125404,,,,,,,,,,, -1000,0.008053569,0.12121835,,,,,,,,,,, -1100,0.015970686,0.121832184,,,,,,,,,,, -1200,0.010091568,0.12980062,,,,,,,,,,, -1300,0.0071467157,0.12019891,,,,,,,,,,, -1400,0.012736411,0.12779571,,,,,,,,,,, -1500,0.012707091,0.123718925,,,,,,,,,,, -1559,,,0.1240559720299528,0.1257393250761783,83274637.0,0.128100659087171,95000000.0,1220.1847472190857,2651.2012639045715,1220.1847472190857,1430.9373450279236,0.0283305644989013,0.0 -1600,0.016967593,0.11758965,,,,,,,,,,, -1700,0.011799624,0.12265742,,,,,,,,,,, -1800,0.033904765,0.12825565,,,,,,,,,,, -1900,0.0053947065,0.12964842,,,,,,,,,,, -2000,0.0073069767,0.11676799,,,,,,,,,,, -2100,0.01623624,0.12937099,,,,,,,,,,, -2200,0.016061014,0.1323927,,,,,,,,,,, -2300,0.011632912,0.116615415,,,,,,,,,,, -2400,0.0112238,0.122150615,,,,,,,,,,, -2500,0.016274905,0.12643734,,,,,,,,,,, -2600,0.0065546674,0.1279931,,,,,,,,,,, -2700,0.015876768,0.12027149,,,,,,,,,,, -2800,0.0128721995,0.12216821,,,,,,,,,,, -2900,0.019914541,0.11697235,,,,,,,,,,, -3000,0.0057323263,0.12713195,,,,,,,,,,, -3100,0.005671126,0.114027604,,,,,,,,,,, -3122,,,0.1231940922872075,0.1251292627970266,83274637.0,0.1275880861328125,95000000.0,2420.6119287014008,4404.628431797028,2420.6119287014008,1983.853398323059,0.0619218349456787,0.0 -3200,0.014266486,0.12205626,,,,,,,,,,, -3300,0.016996391,0.12580714,,,,,,,,,,, -3400,0.0067512114,0.13084096,,,,,,,,,,, -3500,0.038645808,0.12730508,,,,,,,,,,, -3600,0.0060208254,0.13855574,,,,,,,,,,, -3700,0.006338771,0.123784296,,,,,,,,,,, -3800,0.0060985032,0.120332204,,,,,,,,,,, -3900,0.012685545,0.12235949,,,,,,,,,,, -4000,0.009015765,0.12053936,,,,,,,,,,, -4100,0.006032789,0.12225617,,,,,,,,,,, -4200,0.010819113,0.12112528,,,,,,,,,,, -4300,0.01643131,0.12945539,,,,,,,,,,, -4400,0.010778077,0.12568982,,,,,,,,,,, -4500,0.008603881,0.1260732,,,,,,,,,,, -4600,0.0126145845,0.12090535,,,,,,,,,,, -4693,,,0.1228165935326672,0.1243790355418331,83274637.0,0.1267332276110197,95000000.0,3620.54736661911,6144.577065706253,3620.54736661911,2523.788309574127,0.0908405780792236,0.0 -4700,0.012279029,0.121979475,,,,,,,,,,, -4800,0.006969995,0.11829758,,,,,,,,,,, -4900,0.006999944,0.114997946,,,,,,,,,,, -5000,0.0074397586,0.11632217,,,,,,,,,,, -5100,0.006790245,0.119045764,,,,,,,,,,, -5200,0.009241233,0.13000454,,,,,,,,,,, -5300,0.016954511,0.12075726,,,,,,,,,,, -5400,0.0072269877,0.12769166,,,,,,,,,,, -5500,0.009152146,0.12828392,,,,,,,,,,, -5600,0.014200248,0.1250207,,,,,,,,,,, -5700,0.0075712986,0.12187411,,,,,,,,,,, -5800,0.013425044,0.122390896,,,,,,,,,,, -5900,0.009113464,0.116561934,,,,,,,,,,, -6000,0.0074636154,0.1334394,,,,,,,,,,, -6100,0.005798182,0.13819191,,,,,,,,,,, -6200,0.0082300315,0.1262469,,,,,,,,,,, -6272,,,0.1232032984401444,0.1243668215794804,83274637.0,0.1267882915296052,95000000.0,4820.525413274765,7854.223849773407,4820.525413274765,3033.384358644485,0.1131591796875,0.0 -6300,0.065899156,0.13405621,,,,,,,,,,, -6400,0.0077040605,0.126286,,,,,,,,,,, -6500,0.006017109,0.119542025,,,,,,,,,,, -6600,0.009132438,0.12569177,,,,,,,,,,, -6700,0.011442973,0.117048085,,,,,,,,,,, -6800,0.0062482664,0.12838005,,,,,,,,,,, -6900,0.009887905,0.122501865,,,,,,,,,,, -7000,0.009077637,0.12644573,,,,,,,,,,, -7100,0.008953393,0.12599179,,,,,,,,,,, -7200,0.006685711,0.12087943,,,,,,,,,,, -7300,0.008083028,0.117734075,,,,,,,,,,, -7400,0.014961515,0.1271986,,,,,,,,,,, -7500,0.011283823,0.12253359,,,,,,,,,,, -7600,0.0067242947,0.11878133,,,,,,,,,,, -7700,0.010582604,0.12431135,,,,,,,,,,, -7800,0.010317053,0.12086873,,,,,,,,,,, -7845,,,0.1219598606416264,0.1240499002602977,83274637.0,0.1263900748766447,95000000.0,6020.769790649414,9512.933719873428,6020.769790649414,3491.767944574356,0.1436803340911865,0.0 -7900,0.009627467,0.11779968,,,,,,,,,,, -8000,0.007799819,0.12233286,,,,,,,,,,, -8100,0.009845127,0.12736928,,,,,,,,,,, -8200,0.012219266,0.12467092,,,,,,,,,,, -8300,0.008954754,0.118540175,,,,,,,,,,, -8400,0.010083014,0.12171955,,,,,,,,,,, -8500,0.009326627,0.13270353,,,,,,,,,,, -8600,0.009755687,0.12772913,,,,,,,,,,, -8700,0.012995502,0.12438075,,,,,,,,,,, -8800,0.01024156,0.1256401,,,,,,,,,,, -8900,0.010947768,0.12003619,,,,,,,,,,, -9000,0.016872626,0.1239472,,,,,,,,,,, -9100,0.012538108,0.120382436,,,,,,,,,,, -9200,0.0072747087,0.12947734,,,,,,,,,,, -9300,0.030503169,0.119596414,,,,,,,,,,, -9400,0.009772068,0.12796953,,,,,,,,,,, -9407,,,0.1237186695326049,0.1236622008068615,83274637.0,0.1259850416324013,95000000.0,7221.337500095367,11040.364230632782,7221.337500095367,3818.5553402900696,0.1691987514495849,0.0 -9500,0.010848333,0.12696044,,,,,,,,,,, -9600,0.013648354,0.12921098,,,,,,,,,,, -9700,0.009301061,0.12115734,,,,,,,,,,, -9800,0.03084622,0.12702872,,,,,,,,,,, -9900,0.010296184,0.12046442,,,,,,,,,,, -10000,0.014009819,0.12217678,,,,,,,,,,, -10040,,,,,,,,7703.609013557434,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/eval_measurements.csv deleted file mode 100644 index b35650696..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -174.805522441864,0.0,5.898716688156128,1,0,5.898716688156128,0.5930051776315789,95000000,180.70427775383,0.5933512944095539,0.5934494414247642,83274637 -198.0180559158325,0.0196254253387451,1205.983119249344,1531,0,1205.983119249344,0.1279737744449013,95000000,1404.0712733268738,0.1239531634832328,0.1258019287808483,83274637 -221.8493254184723,0.0473461151123046,2406.7319440841675,3054,0,2406.7319440841675,0.1272648790810032,95000000,2628.7291634082794,0.1253895798475487,0.1249785106294579,83274637 -245.3142669200897,0.0717847347259521,3606.915803670883,4590,0,3606.915803670883,0.1271673913548519,95000000,3852.453292131424,0.1221171261848143,0.1247986461126813,83274637 -268.6245036125183,0.1031219959259033,4807.1970937252045,6111,0,4807.1970937252045,0.1264952478412828,95000000,5076.125498533249,0.1246137860342391,0.1240950822251107,83274637 -292.01829075813293,0.1278812885284423,6007.624234676361,7631,0,6007.624234676361,0.1263192885587993,95000000,6300.0208122730255,0.1202213231872462,0.1239485521564026,83274637 -315.3039107322693,0.1507117748260498,7208.275316238403,9157,0,7208.275316238403,0.12599986039268093,95000000,7524.030328273773,0.12338495680933478,0.12366559019419682,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/measurements.csv deleted file mode 100644 index bd598d92a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.914909,0.59277475,,,,,,,,,,, -1,,,0.5933512944095539,0.5934494414247642,83274637.0,0.5930051776315789,95000000.0,5.898716688156128,180.70427775383,5.898716688156128,174.805522441864,0.0,0.0 -100,0.035585262,0.13042593,,,,,,,,,,, -200,0.038325682,0.12771219,,,,,,,,,,, -300,0.1536527,0.13090235,,,,,,,,,,, -400,0.06838267,0.12723756,,,,,,,,,,, -500,0.123228125,0.14195623,,,,,,,,,,, -600,0.0090460535,0.1193864,,,,,,,,,,, -700,0.019095574,0.1224008,,,,,,,,,,, -800,0.108041465,0.13673751,,,,,,,,,,, -900,0.037155125,0.13018928,,,,,,,,,,, -1000,0.093444616,0.13236484,,,,,,,,,,, -1100,0.014354824,0.12827104,,,,,,,,,,, -1200,0.010635572,0.12745418,,,,,,,,,,, -1300,0.06825677,0.12933087,,,,,,,,,,, -1400,0.046608318,0.12176815,,,,,,,,,,, -1500,0.007399584,0.13195142,,,,,,,,,,, -1531,,,0.1239531634832328,0.1258019287808483,83274637.0,0.1279737744449013,95000000.0,1205.983119249344,1404.0712733268738,1205.983119249344,198.0180559158325,0.0196254253387451,0.0 -1600,0.012717008,0.12676838,,,,,,,,,,, -1700,0.010494387,0.122260585,,,,,,,,,,, -1800,0.011165302,0.12857276,,,,,,,,,,, -1900,0.06543867,0.1260693,,,,,,,,,,, -2000,0.028851451,0.12157731,,,,,,,,,,, -2100,0.008672287,0.12029041,,,,,,,,,,, -2200,0.035681937,0.12697989,,,,,,,,,,, -2300,0.012967102,0.11844101,,,,,,,,,,, -2400,0.007209301,0.1216699,,,,,,,,,,, -2500,0.012357993,0.122197464,,,,,,,,,,, -2600,0.011433967,0.12858953,,,,,,,,,,, -2700,0.012455847,0.12437947,,,,,,,,,,, -2800,0.023287503,0.1219029,,,,,,,,,,, -2900,0.030160068,0.124660745,,,,,,,,,,, -3000,0.019657806,0.12064891,,,,,,,,,,, -3054,,,0.1253895798475487,0.1249785106294579,83274637.0,0.1272648790810032,95000000.0,2406.7319440841675,2628.7291634082794,2406.7319440841675,221.8493254184723,0.0473461151123046,0.0 -3100,0.0053508845,0.12125613,,,,,,,,,,, -3200,0.014917976,0.122469336,,,,,,,,,,, -3300,0.010104794,0.11698622,,,,,,,,,,, -3400,0.013784006,0.11879362,,,,,,,,,,, -3500,0.0074100657,0.12800623,,,,,,,,,,, -3600,0.017627517,0.11940335,,,,,,,,,,, -3700,0.02442179,0.13817334,,,,,,,,,,, -3800,0.012635296,0.122586645,,,,,,,,,,, -3900,0.037493967,0.12274308,,,,,,,,,,, -4000,0.0070193848,0.13355662,,,,,,,,,,, -4100,0.00741195,0.12301154,,,,,,,,,,, -4200,0.0130141815,0.12961441,,,,,,,,,,, -4300,0.021304572,0.13367344,,,,,,,,,,, -4400,0.014581209,0.12692337,,,,,,,,,,, -4500,0.0064866072,0.117241405,,,,,,,,,,, -4590,,,0.1221171261848143,0.1247986461126813,83274637.0,0.1271673913548519,95000000.0,3606.915803670883,3852.453292131424,3606.915803670883,245.3142669200897,0.0717847347259521,0.0 -4600,0.0061339596,0.12131926,,,,,,,,,,, -4700,0.03316738,0.12564002,,,,,,,,,,, -4800,0.007855878,0.12056483,,,,,,,,,,, -4900,0.005954286,0.11819187,,,,,,,,,,, -5000,0.00615738,0.116717726,,,,,,,,,,, -5100,0.0072085676,0.12507814,,,,,,,,,,, -5200,0.021872897,0.12474288,,,,,,,,,,, -5300,0.022114426,0.12227405,,,,,,,,,,, -5400,0.022410512,0.11717409,,,,,,,,,,, -5500,0.0063719437,0.13320859,,,,,,,,,,, -5600,0.013241503,0.12154569,,,,,,,,,,, -5700,0.005969502,0.12886854,,,,,,,,,,, -5800,0.011060309,0.11490949,,,,,,,,,,, -5900,0.013017315,0.13076125,,,,,,,,,,, -6000,0.009655029,0.122752234,,,,,,,,,,, -6100,0.013531063,0.119407386,,,,,,,,,,, -6111,,,0.1246137860342391,0.1240950822251107,83274637.0,0.1264952478412828,95000000.0,4807.1970937252045,5076.125498533249,4807.1970937252045,268.6245036125183,0.1031219959259033,0.0 -6200,0.02275796,0.12352772,,,,,,,,,,, -6300,0.005884124,0.12867364,,,,,,,,,,, -6400,0.010931259,0.11987159,,,,,,,,,,, -6500,0.010161537,0.11963152,,,,,,,,,,, -6600,0.012125693,0.13406762,,,,,,,,,,, -6700,0.0077737113,0.12324856,,,,,,,,,,, -6800,0.010502648,0.1194881,,,,,,,,,,, -6900,0.007609584,0.1246344,,,,,,,,,,, -7000,0.016178595,0.11932308,,,,,,,,,,, -7100,0.007455762,0.12563105,,,,,,,,,,, -7200,0.015846385,0.12055562,,,,,,,,,,, -7300,0.008747261,0.12766647,,,,,,,,,,, -7400,0.008582173,0.12492029,,,,,,,,,,, -7500,0.007120853,0.12265374,,,,,,,,,,, -7600,0.014348556,0.123910934,,,,,,,,,,, -7631,,,0.1202213231872462,0.1239485521564026,83274637.0,0.1263192885587993,95000000.0,6007.624234676361,6300.0208122730255,6007.624234676361,292.01829075813293,0.1278812885284423,0.0 -7700,0.011019513,0.1208985,,,,,,,,,,, -7800,0.0064207427,0.12430106,,,,,,,,,,, -7900,0.005934659,0.116357684,,,,,,,,,,, -8000,0.0058872984,0.11803908,,,,,,,,,,, -8100,0.007909869,0.1263311,,,,,,,,,,, -8200,0.006259214,0.125338,,,,,,,,,,, -8300,0.011303947,0.1185018,,,,,,,,,,, -8400,0.0064301314,0.11679474,,,,,,,,,,, -8500,0.006120118,0.12110862,,,,,,,,,,, -8600,0.009357987,0.1343663,,,,,,,,,,, -8700,0.006421581,0.13103485,,,,,,,,,,, -8800,0.006764949,0.122280955,,,,,,,,,,, -8900,0.010019633,0.1174753,,,,,,,,,,, -9000,0.0067053842,0.122513376,,,,,,,,,,, -9100,0.007962086,0.12763345,,,,,,,,,,, -9157,,,0.1233849568093347,0.1236655901941968,83274637.0,0.1259998603926809,95000000.0,7208.275316238403,7524.030328273773,7208.275316238403,315.3039107322693,0.1507117748260498,0.0 -9200,0.0068640723,0.121078335,,,,,,,,,,, -9300,0.007144509,0.12378404,,,,,,,,,,, -9400,0.006804951,0.12589239,,,,,,,,,,, -9500,0.01004989,0.122317314,,,,,,,,,,, -9600,0.0076734675,0.12142384,,,,,,,,,,, -9700,0.012096426,0.1340645,,,,,,,,,,, -9790,,,,,,,,7703.32422208786,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 2ef3a2d70..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -23.342114448547363,0.0,6.02808952331543,1,0,6.02808952331543,0.5930051776315789,95000000,29.370241165161133,0.5932942280229533,0.5934494414247642,83274637 -47.00682520866394,0.0167949199676513,1206.50141620636,1409,0,1206.50141620636,0.1284075707339638,95000000,1253.5709066390991,0.1253464801131554,0.1260380975339766,83274637 -72.8450014591217,0.0519769191741943,2406.9263048172,2832,0,2406.9263048172,0.1275689274362664,95000000,2479.915580034256,0.1256817717503451,0.1252486379890344,83274637 -96.83144330978394,0.0794475078582763,3607.616421222687,4262,0,3607.616421222687,0.1276711111739309,95000000,3704.666927576065,0.1236076376067017,0.1250838564710585,83274637 -120.74389028549194,0.1047101020812988,4808.179257631302,5685,0,4808.179257631302,0.126843351223273,95000000,4929.213754653931,0.1240097874271794,0.1244795237744476,83274637 -146.62033653259277,0.1263008117675781,6008.445064067841,7097,0,6008.445064067841,0.1263845815275493,95000000,6155.423595905304,0.1244147480377611,0.1240670128761714,83274637 -170.11274075508118,0.1586766242980957,7208.413768529892,8533,0,7208.413768529892,0.1260424281044408,95000000,7378.963323831558,0.12057999901051791,0.12371181727014012,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/measurements.csv deleted file mode 100644 index 6799a6925..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/measurements.csv +++ /dev/null @@ -1,100 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.873732,0.59355295,,,,,,,,,,, -1,,,0.5932942280229533,0.5934494414247642,83274637.0,0.5930051776315789,95000000.0,6.02808952331543,29.370241165161133,6.02808952331543,23.342114448547363,0.0,0.0 -100,0.25395226,0.1408786,,,,,,,,,,, -200,0.010124657,0.12807864,,,,,,,,,,, -300,0.016994115,0.13609675,,,,,,,,,,, -400,0.012184594,0.12207297,,,,,,,,,,, -500,0.011106229,0.12314696,,,,,,,,,,, -600,0.008041091,0.12635213,,,,,,,,,,, -700,0.016031127,0.12661436,,,,,,,,,,, -800,0.01807122,0.12861386,,,,,,,,,,, -900,0.018772678,0.11935973,,,,,,,,,,, -1000,0.041119404,0.12396681,,,,,,,,,,, -1100,0.009898819,0.12489991,,,,,,,,,,, -1200,0.02119652,0.121727735,,,,,,,,,,, -1300,0.016457219,0.11906231,,,,,,,,,,, -1400,0.024532428,0.1276725,,,,,,,,,,, -1409,,,0.1253464801131554,0.1260380975339766,83274637.0,0.1284075707339638,95000000.0,1206.50141620636,1253.5709066390991,1206.50141620636,47.00682520866394,0.0167949199676513,0.0 -1500,0.018689176,0.117896006,,,,,,,,,,, -1600,0.005682446,0.13321796,,,,,,,,,,, -1700,0.007728111,0.12621945,,,,,,,,,,, -1800,0.008801062,0.12968217,,,,,,,,,,, -1900,0.02295074,0.12754491,,,,,,,,,,, -2000,0.02060194,0.11915259,,,,,,,,,,, -2100,0.025264982,0.1300824,,,,,,,,,,, -2200,0.0073836623,0.13198209,,,,,,,,,,, -2300,0.019799558,0.13044976,,,,,,,,,,, -2400,0.016468037,0.11992418,,,,,,,,,,, -2500,0.0088135395,0.13891584,,,,,,,,,,, -2600,0.0065700705,0.12710197,,,,,,,,,,, -2700,0.008501959,0.12274535,,,,,,,,,,, -2800,0.008608259,0.1281683,,,,,,,,,,, -2832,,,0.1256817717503451,0.1252486379890344,83274637.0,0.1275689274362664,95000000.0,2406.9263048172,2479.915580034256,2406.9263048172,72.8450014591217,0.0519769191741943,0.0 -2900,0.01055797,0.11624173,,,,,,,,,,, -3000,0.016655372,0.1245722,,,,,,,,,,, -3100,0.012009731,0.12170377,,,,,,,,,,, -3200,0.020335007,0.13205361,,,,,,,,,,, -3300,0.010208777,0.11956537,,,,,,,,,,, -3400,0.028946366,0.12381387,,,,,,,,,,, -3500,0.008495304,0.12308019,,,,,,,,,,, -3600,0.0214716,0.12180945,,,,,,,,,,, -3700,0.008750973,0.12344028,,,,,,,,,,, -3800,0.0068461313,0.11943189,,,,,,,,,,, -3900,0.0059930016,0.12153973,,,,,,,,,,, -4000,0.0077864425,0.120625205,,,,,,,,,,, -4100,0.007798162,0.13018718,,,,,,,,,,, -4200,0.00820208,0.11743281,,,,,,,,,,, -4262,,,0.1236076376067017,0.1250838564710585,83274637.0,0.1276711111739309,95000000.0,3607.616421222687,3704.666927576065,3607.616421222687,96.83144330978394,0.0794475078582763,0.0 -4300,0.010731634,0.13271537,,,,,,,,,,, -4400,0.018631442,0.123892,,,,,,,,,,, -4500,0.011546873,0.1179961,,,,,,,,,,, -4600,0.014674871,0.12230605,,,,,,,,,,, -4700,0.0067299465,0.120535284,,,,,,,,,,, -4800,0.0060041556,0.11469706,,,,,,,,,,, -4900,0.017248064,0.12754162,,,,,,,,,,, -5000,0.014183971,0.1234348,,,,,,,,,,, -5100,0.010803265,0.11993433,,,,,,,,,,, -5200,0.008172727,0.13091622,,,,,,,,,,, -5300,0.018895637,0.12039246,,,,,,,,,,, -5400,0.011165249,0.13082927,,,,,,,,,,, -5500,0.011258742,0.11628915,,,,,,,,,,, -5600,0.008654639,0.12828615,,,,,,,,,,, -5685,,,0.1240097874271794,0.1244795237744476,83274637.0,0.126843351223273,95000000.0,4808.179257631302,4929.213754653931,4808.179257631302,120.74389028549194,0.1047101020812988,0.0 -5700,0.013059211,0.120806634,,,,,,,,,,, -5800,0.006398701,0.12164092,,,,,,,,,,, -5900,0.012899922,0.12569949,,,,,,,,,,, -6000,0.006355007,0.1192963,,,,,,,,,,, -6100,0.0055272896,0.11859021,,,,,,,,,,, -6200,0.011924344,0.12497997,,,,,,,,,,, -6300,0.006891661,0.12354508,,,,,,,,,,, -6400,0.00819856,0.12045469,,,,,,,,,,, -6500,0.014208573,0.12019224,,,,,,,,,,, -6600,0.008598413,0.11673829,,,,,,,,,,, -6700,0.0051671714,0.12530577,,,,,,,,,,, -6800,0.009382465,0.1222594,,,,,,,,,,, -6900,0.011887474,0.12930696,,,,,,,,,,, -7000,0.007001202,0.13055089,,,,,,,,,,, -7097,,,0.1244147480377611,0.1240670128761714,83274637.0,0.1263845815275493,95000000.0,6008.445064067841,6155.423595905304,6008.445064067841,146.62033653259277,0.1263008117675781,0.0 -7100,0.010064375,0.1273942,,,,,,,,,,, -7200,0.011291899,0.1208338,,,,,,,,,,, -7300,0.018254519,0.1250281,,,,,,,,,,, -7400,0.010374415,0.12060057,,,,,,,,,,, -7500,0.015413148,0.12012482,,,,,,,,,,, -7600,0.009908596,0.13026617,,,,,,,,,,, -7700,0.0090893395,0.11652069,,,,,,,,,,, -7800,0.007880267,0.12032688,,,,,,,,,,, -7900,0.008594543,0.12549949,,,,,,,,,,, -8000,0.023940096,0.12540232,,,,,,,,,,, -8100,0.0073465398,0.11880749,,,,,,,,,,, -8200,0.008815162,0.12186946,,,,,,,,,,, -8300,0.0068310406,0.118892565,,,,,,,,,,, -8400,0.00730676,0.119618855,,,,,,,,,,, -8500,0.00590336,0.115306735,,,,,,,,,,, -8533,,,0.1205799990105179,0.1237118172701401,83274637.0,0.1260424281044408,95000000.0,7208.413768529892,7378.963323831558,7208.413768529892,170.11274075508118,0.1586766242980957,0.0 -8600,0.007281335,0.12854357,,,,,,,,,,, -8700,0.0070180907,0.122536406,,,,,,,,,,, -8800,0.00996294,0.1325045,,,,,,,,,,, -8900,0.008183284,0.11868498,,,,,,,,,,, -9000,0.009208336,0.11506684,,,,,,,,,,, -9051,,,,,,,,7703.254302978516,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 62927b9e6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.948208332061768,0.0,9.029136180877686,1,0,9.029136180877686,0.5930051776315789,95000000,31.97738528251648,0.5933495756215269,0.5934494414247642,83274637 -44.91099667549133,0.0170416831970214,1209.289989709854,1524,0,1209.289989709854,0.1289717205078125,95000000,1254.2688052654266,0.1254836830702967,0.1263740532445221,83274637 -66.89414644241333,0.0477163791656494,2409.679847240448,3041,0,2409.679847240448,0.1296342404913651,95000000,2476.72232413292,0.1241661334094011,0.1265863525131607,83274637 -88.92537879943848,0.0753521919250488,3609.713275909424,4576,0,3609.713275909424,0.1281889597039473,95000000,3698.865172863007,0.1279426673017208,0.1258326709029965,83274637 -111.1356008052826,0.1027441024780273,4809.681446075439,6120,0,4809.681446075439,0.1283866040501644,95000000,4921.120703935623,0.1255744771106438,0.1258819443746869,83274637 -133.25801420211792,0.12852144241333,6010.143446445465,7648,0,6010.143446445465,0.1280765943976151,95000000,6143.780212402344,0.1243427811459925,0.1257455138415223,83274637 -155.45485138893127,0.15249276161193848,7210.4379069805145,9180,0,7210.4379069805145,0.12752799109786184,95000000,7366.344910621643,0.12473163842780036,0.125333197705428,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/measurements.csv deleted file mode 100644 index 4ecf2c75e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.845667,0.59345335,,,,,,,,,,, -1,,,0.5933495756215269,0.5934494414247642,83274637.0,0.5930051776315789,95000000.0,9.029136180877686,31.97738528251648,9.029136180877686,22.948208332061768,0.0,0.0 -100,0.05743124,0.1266821,,,,,,,,,,, -200,0.041869335,0.12816177,,,,,,,,,,, -300,0.06883074,0.13286163,,,,,,,,,,, -400,0.083166525,0.13539891,,,,,,,,,,, -500,0.009892915,0.119244024,,,,,,,,,,, -600,0.029702952,0.12501942,,,,,,,,,,, -700,0.011357653,0.12633541,,,,,,,,,,, -800,0.050160404,0.13230437,,,,,,,,,,, -900,0.055211782,0.12864017,,,,,,,,,,, -1000,0.0048211906,0.13456614,,,,,,,,,,, -1100,0.04581717,0.12233243,,,,,,,,,,, -1200,0.017387742,0.12815289,,,,,,,,,,, -1300,0.032440748,0.12281245,,,,,,,,,,, -1400,0.03238756,0.122843,,,,,,,,,,, -1500,0.056595292,0.12962422,,,,,,,,,,, -1524,,,0.1254836830702967,0.1263740532445221,83274637.0,0.1289717205078125,95000000.0,1209.289989709854,1254.2688052654266,1209.289989709854,44.91099667549133,0.0170416831970214,0.0 -1600,0.00782569,0.13058911,,,,,,,,,,, -1700,0.01717303,0.12994598,,,,,,,,,,, -1800,0.0037682792,0.11551578,,,,,,,,,,, -1900,0.054935567,0.13160855,,,,,,,,,,, -2000,0.02001081,0.124277525,,,,,,,,,,, -2100,0.057599936,0.12329046,,,,,,,,,,, -2200,0.017704211,0.13922234,,,,,,,,,,, -2300,0.0410198,0.13337639,,,,,,,,,,, -2400,0.0075502563,0.12365502,,,,,,,,,,, -2500,0.050578326,0.12523378,,,,,,,,,,, -2600,0.022786738,0.123649985,,,,,,,,,,, -2700,0.003688438,0.12076392,,,,,,,,,,, -2800,0.025883988,0.13338341,,,,,,,,,,, -2900,0.030015698,0.13131219,,,,,,,,,,, -3000,0.007042267,0.1218565,,,,,,,,,,, -3041,,,0.1241661334094011,0.1265863525131607,83274637.0,0.1296342404913651,95000000.0,2409.679847240448,2476.72232413292,2409.679847240448,66.89414644241333,0.0477163791656494,0.0 -3100,0.031767305,0.11983143,,,,,,,,,,, -3200,0.048554216,0.1353339,,,,,,,,,,, -3300,0.0072070123,0.12048098,,,,,,,,,,, -3400,0.024993517,0.12387927,,,,,,,,,,, -3500,0.025503049,0.12756754,,,,,,,,,,, -3600,0.028932774,0.12012692,,,,,,,,,,, -3700,0.03146845,0.118941076,,,,,,,,,,, -3800,0.050061513,0.124015644,,,,,,,,,,, -3900,0.040034514,0.12726222,,,,,,,,,,, -4000,0.005836443,0.121048726,,,,,,,,,,, -4100,0.024984118,0.123593375,,,,,,,,,,, -4200,0.027611608,0.13517399,,,,,,,,,,, -4300,0.03596832,0.12604603,,,,,,,,,,, -4400,0.014495261,0.12361494,,,,,,,,,,, -4500,0.0523875,0.12302452,,,,,,,,,,, -4576,,,0.1279426673017208,0.1258326709029965,83274637.0,0.1281889597039473,95000000.0,3609.713275909424,3698.865172863007,3609.713275909424,88.92537879943848,0.0753521919250488,0.0 -4600,0.014998914,0.122955255,,,,,,,,,,, -4700,0.012966924,0.12254276,,,,,,,,,,, -4800,0.013869367,0.121149905,,,,,,,,,,, -4900,0.034132916,0.123319164,,,,,,,,,,, -5000,0.03006323,0.12483159,,,,,,,,,,, -5100,0.03385884,0.13173713,,,,,,,,,,, -5200,0.063082986,0.13731183,,,,,,,,,,, -5300,0.03994119,0.12390219,,,,,,,,,,, -5400,0.02929136,0.12695438,,,,,,,,,,, -5500,0.018164003,0.1327116,,,,,,,,,,, -5600,0.0229113,0.12053318,,,,,,,,,,, -5700,0.04380983,0.11952726,,,,,,,,,,, -5800,0.02368462,0.12468834,,,,,,,,,,, -5900,0.046610456,0.122369125,,,,,,,,,,, -6000,0.008876641,0.12382304,,,,,,,,,,, -6100,0.06523617,0.13113764,,,,,,,,,,, -6120,,,0.1255744771106438,0.1258819443746869,83274637.0,0.1283866040501644,95000000.0,4809.681446075439,4921.120703935623,4809.681446075439,111.1356008052826,0.1027441024780273,0.0 -6200,0.026613563,0.11951559,,,,,,,,,,, -6300,0.008068639,0.12169326,,,,,,,,,,, -6400,0.017495608,0.12281083,,,,,,,,,,, -6500,0.025020907,0.124823436,,,,,,,,,,, -6600,0.033333857,0.1263564,,,,,,,,,,, -6700,0.030333685,0.1275728,,,,,,,,,,, -6800,0.02047847,0.12644152,,,,,,,,,,, -6900,0.036107447,0.12703872,,,,,,,,,,, -7000,0.023355784,0.120779134,,,,,,,,,,, -7100,0.014695691,0.12447137,,,,,,,,,,, -7200,0.019391306,0.13151048,,,,,,,,,,, -7300,0.019079268,0.12355882,,,,,,,,,,, -7400,0.042236418,0.13699006,,,,,,,,,,, -7500,0.013074093,0.123486325,,,,,,,,,,, -7600,0.020368593,0.12572527,,,,,,,,,,, -7648,,,0.1243427811459925,0.1257455138415223,83274637.0,0.1280765943976151,95000000.0,6010.143446445465,6143.780212402344,6010.143446445465,133.25801420211792,0.12852144241333,0.0 -7700,0.007146137,0.120583214,,,,,,,,,,, -7800,0.008436137,0.118814915,,,,,,,,,,, -7900,0.016101664,0.12740539,,,,,,,,,,, -8000,0.023243682,0.1311157,,,,,,,,,,, -8100,0.025446007,0.13100976,,,,,,,,,,, -8200,0.011691874,0.1256075,,,,,,,,,,, -8300,0.009558637,0.121397816,,,,,,,,,,, -8400,0.007607832,0.12694703,,,,,,,,,,, -8500,0.014891397,0.12207882,,,,,,,,,,, -8600,0.012049356,0.12128803,,,,,,,,,,, -8700,0.009427277,0.12838463,,,,,,,,,,, -8800,0.010789712,0.13051012,,,,,,,,,,, -8900,0.01505493,0.122614704,,,,,,,,,,, -9000,0.011050527,0.1253872,,,,,,,,,,, -9100,0.020840582,0.121701725,,,,,,,,,,, -9180,,,0.1247316384278003,0.125333197705428,83274637.0,0.1275279910978618,95000000.0,7210.437906980514,7366.344910621643,7210.437906980514,155.45485138893127,0.1524927616119384,0.0 -9200,0.011808751,0.12181353,,,,,,,,,,, -9300,0.012744673,0.122300744,,,,,,,,,,, -9400,0.011783707,0.12854856,,,,,,,,,,, -9500,0.0090225,0.12636255,,,,,,,,,,, -9600,0.013927718,0.118918926,,,,,,,,,,, -9700,0.010877438,0.12908308,,,,,,,,,,, -9800,0.009531075,0.13543187,,,,,,,,,,, -9805,,,,,,,,7703.342462778091,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 6b7fd89d3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.243659734725952,0.0,6.33348536491394,1,0,6.33348536491394,0.5930051776315789,95000000,28.57718396186829,0.5933873810858097,0.5934494414247642,83274637 -44.658483028411865,0.0171129703521728,1206.740442276001,1400,0,1206.740442276001,0.1280286769942434,95000000,1251.46262717247,0.1248062971815373,0.1256744583869365,83274637 -66.82098436355591,0.0410530567169189,2406.969470500946,2807,0,2406.969470500946,0.1272585674650493,95000000,2473.924459457397,0.1213179879218527,0.1250336486850282,83274637 -89.17287230491638,0.0666711330413818,3607.6907799243927,4214,0,3607.6907799243927,0.1268623212993421,95000000,3697.0699832439423,0.1216190793791657,0.1245090394858911,83274637 -111.49109768867493,0.0916898250579834,4807.679771661758,5614,0,4807.679771661758,0.126648775051398,95000000,4919.4477915763855,0.121329590490779,0.1242824161436302,83274637 -133.719393491745,0.1192834377288818,6007.800831317902,7014,0,6007.800831317902,0.1262987828022204,95000000,6141.869731664658,0.1249938235155441,0.1239446981449315,83274637 -155.94320702552795,0.14342546463012695,7207.755153179169,8407,0,7207.755153179169,0.12602336692023025,95000000,7364.11737537384,0.12131002934286429,0.12371043292006484,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/measurements.csv deleted file mode 100644 index c8bd9561f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/measurements.csv +++ /dev/null @@ -1,99 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.871594,0.5947284,,,,,,,,,,, -1,,,0.5933873810858097,0.5934494414247642,83274637.0,0.5930051776315789,95000000.0,6.33348536491394,28.57718396186829,6.33348536491394,22.243659734725952,0.0,0.0 -100,0.028779574,0.13014811,,,,,,,,,,, -200,0.100017466,0.13098386,,,,,,,,,,, -300,0.007917808,0.13210727,,,,,,,,,,, -400,0.02360417,0.12493366,,,,,,,,,,, -500,0.008332814,0.13066931,,,,,,,,,,, -600,0.04048108,0.12682036,,,,,,,,,,, -700,0.03823329,0.12128695,,,,,,,,,,, -800,0.02009284,0.12560292,,,,,,,,,,, -900,0.059780613,0.123136446,,,,,,,,,,, -1000,0.072597675,0.1273559,,,,,,,,,,, -1100,0.02784012,0.13111618,,,,,,,,,,, -1200,0.014507963,0.12230516,,,,,,,,,,, -1300,0.017871656,0.12331464,,,,,,,,,,, -1400,,,0.1248062971815373,0.1256744583869365,83274637.0,0.1280286769942434,95000000.0,1206.740442276001,1251.46262717247,1206.740442276001,44.658483028411865,0.0171129703521728,0.0 -1400,0.013589326,0.13454707,,,,,,,,,,, -1500,0.026255926,0.1139804,,,,,,,,,,, -1600,0.025022952,0.12598485,,,,,,,,,,, -1700,0.011737939,0.12107222,,,,,,,,,,, -1800,0.018194545,0.11931758,,,,,,,,,,, -1900,0.022770146,0.11409718,,,,,,,,,,, -2000,0.084341116,0.13845351,,,,,,,,,,, -2100,0.008544769,0.12191425,,,,,,,,,,, -2200,0.026063647,0.122912705,,,,,,,,,,, -2300,0.017973594,0.12817414,,,,,,,,,,, -2400,0.010252729,0.12361212,,,,,,,,,,, -2500,0.008914752,0.12933712,,,,,,,,,,, -2600,0.020441653,0.12120091,,,,,,,,,,, -2700,0.006869001,0.13112916,,,,,,,,,,, -2800,0.020573903,0.12279901,,,,,,,,,,, -2807,,,0.1213179879218527,0.1250336486850282,83274637.0,0.1272585674650493,95000000.0,2406.969470500946,2473.924459457397,2406.969470500946,66.82098436355591,0.0410530567169189,0.0 -2900,0.025189018,0.119222775,,,,,,,,,,, -3000,0.0057116547,0.12510422,,,,,,,,,,, -3100,0.018301677,0.115468524,,,,,,,,,,, -3200,0.017650463,0.12193138,,,,,,,,,,, -3300,0.006244952,0.11758729,,,,,,,,,,, -3400,0.006999409,0.12707454,,,,,,,,,,, -3500,0.0078973975,0.120540574,,,,,,,,,,, -3600,0.01703777,0.12362445,,,,,,,,,,, -3700,0.02881737,0.1262163,,,,,,,,,,, -3800,0.0062961974,0.12061258,,,,,,,,,,, -3900,0.027553586,0.12257093,,,,,,,,,,, -4000,0.022258913,0.1220726,,,,,,,,,,, -4100,0.030255327,0.12309597,,,,,,,,,,, -4200,0.013551188,0.12293242,,,,,,,,,,, -4214,,,0.1216190793791657,0.1245090394858911,83274637.0,0.1268623212993421,95000000.0,3607.6907799243927,3697.0699832439423,3607.6907799243927,89.17287230491638,0.0666711330413818,0.0 -4300,0.005704558,0.117092475,,,,,,,,,,, -4400,0.005427836,0.12564254,,,,,,,,,,, -4500,0.0146389995,0.12578905,,,,,,,,,,, -4600,0.00866583,0.1224213,,,,,,,,,,, -4700,0.005216215,0.116235815,,,,,,,,,,, -4800,0.0062964964,0.120444536,,,,,,,,,,, -4900,0.02736931,0.12039913,,,,,,,,,,, -5000,0.008606965,0.11474995,,,,,,,,,,, -5100,0.012897996,0.12540855,,,,,,,,,,, -5200,0.009024596,0.12238693,,,,,,,,,,, -5300,0.0072751897,0.12605529,,,,,,,,,,, -5400,0.004641899,0.114183456,,,,,,,,,,, -5500,0.013344496,0.12461446,,,,,,,,,,, -5600,0.014267033,0.11848843,,,,,,,,,,, -5614,,,0.121329590490779,0.1242824161436302,83274637.0,0.126648775051398,95000000.0,4807.679771661758,4919.4477915763855,4807.679771661758,111.49109768867493,0.0916898250579834,0.0 -5700,0.012901145,0.12284366,,,,,,,,,,, -5800,0.0076943226,0.1228536,,,,,,,,,,, -5900,0.008050752,0.12264983,,,,,,,,,,, -6000,0.008158665,0.12346746,,,,,,,,,,, -6100,0.0055644475,0.118378714,,,,,,,,,,, -6200,0.007889835,0.12511064,,,,,,,,,,, -6300,0.009402786,0.12552646,,,,,,,,,,, -6400,0.0062739775,0.123054095,,,,,,,,,,, -6500,0.016126271,0.12700468,,,,,,,,,,, -6600,0.005817744,0.115803726,,,,,,,,,,, -6700,0.007863312,0.12197952,,,,,,,,,,, -6800,0.007603108,0.12500659,,,,,,,,,,, -6900,0.006220428,0.12690581,,,,,,,,,,, -7000,0.010347856,0.11541876,,,,,,,,,,, -7014,,,0.1249938235155441,0.1239446981449315,83274637.0,0.1262987828022204,95000000.0,6007.800831317902,6141.869731664658,6007.800831317902,133.719393491745,0.1192834377288818,0.0 -7100,0.009512276,0.1250416,,,,,,,,,,, -7200,0.013385177,0.11909772,,,,,,,,,,, -7300,0.0051893303,0.12184197,,,,,,,,,,, -7400,0.011385346,0.14013235,,,,,,,,,,, -7500,0.006522644,0.12525696,,,,,,,,,,, -7600,0.0062642545,0.117150754,,,,,,,,,,, -7700,0.0113440305,0.124072686,,,,,,,,,,, -7800,0.011299395,0.12408972,,,,,,,,,,, -7900,0.004999797,0.11867404,,,,,,,,,,, -8000,0.006393371,0.12989137,,,,,,,,,,, -8100,0.0059442697,0.12837735,,,,,,,,,,, -8200,0.0063086394,0.12310727,,,,,,,,,,, -8300,0.0053978576,0.122440495,,,,,,,,,,, -8400,0.0060058925,0.12765025,,,,,,,,,,, -8407,,,0.1213100293428642,0.1237104329200648,83274637.0,0.1260233669202302,95000000.0,7207.755153179169,7364.11737537384,7207.755153179169,155.94320702552795,0.1434254646301269,0.0 -8500,0.008048144,0.11732805,,,,,,,,,,, -8600,0.007854761,0.12594064,,,,,,,,,,, -8700,0.010085905,0.12122847,,,,,,,,,,, -8800,0.006912114,0.1279737,,,,,,,,,,, -8900,0.0068079783,0.121996015,,,,,,,,,,, -8904,,,,,,,,7703.589025020599,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index c82ccdf44..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,109 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -204.0422370433808,0.0,57.97748017311096,1,0,57.97748017311096,1.0888061012112538,3581,0.2660334975118245,262.02021503448486,1.0853406361171178,0.246516398021153,1.0927659497880908,3554,0.2452213628679832 -208.51142573356628,0.0327417850494384,138.13936042785645,334,0,138.13936042785645,0.3475053409596656,3581,0.6770980140411198,346.6962506771088,0.3249119349888393,0.6825276102338519,0.3466054973269555,3554,0.6577410850889842 -212.57180762290955,0.0704364776611328,218.2856752872467,572,0,218.2856752872467,0.3260053194158405,3581,0.6999024937386554,430.9490616321564,0.3037167617252895,0.7051773752485003,0.3242165174253482,3554,0.6816797122300576 -216.63490796089167,0.1094112396240234,298.414998292923,815,0,298.414998292923,0.3135829862686749,3581,0.7135750505052709,515.1890535354614,0.2911108561924526,0.719536576952253,0.3113940855594049,3554,0.6960570808200267 -220.69820737838745,0.1396596431732177,378.5005056858063,1094,0,378.5005056858063,0.307564657382278,3581,0.7193366601726124,599.3780615329742,0.2851648841585432,0.7254536492483956,0.3053462125641882,3554,0.7022134225960186 -224.76212549209595,0.1666257381439209,458.5756525993347,1440,0,458.5756525993347,0.3017688230308398,3581,0.7263367670736177,683.556557893753,0.2791761670793806,0.7331483023507255,0.2996527350344682,3554,0.7092399189293753 -228.82746934890747,0.1932723522186279,538.5728995800018,1784,0,538.5728995800018,0.2988863819572919,3581,0.7292040047516406,767.6581978797913,0.2767836877277919,0.735295023236956,0.2969039204285664,3554,0.7124547577333639 -232.8936219215393,0.2190008163452148,618.6690142154694,2125,0,618.6690142154694,0.2971061529709753,3581,0.731143971655962,851.858500957489,0.2746272768293108,0.7378713744027274,0.2951256918920582,3554,0.7141197088667698 -236.95533347129825,0.2474796772003173,698.6350164413452,2468,0,698.6350164413452,0.2955118076523667,3581,0.7325169814428582,935.9269535541534,0.2728226866040911,0.7398177555629185,0.2936773348471792,3554,0.7154810298475309 -241.0200252532959,0.2753784656524658,778.713544845581,2812,0,778.713544845581,0.2939528119327702,3581,0.7338114517156521,1020.110639810562,0.2718547412327358,0.7401653017316546,0.292136274433297,3554,0.7169991119161508 -245.08817720413208,0.3041086196899414,858.7817091941833,3156,0,858.7817091941833,0.2940611787362992,3581,0.7323439490758518,1104.2880628108978,0.2720067501068115,0.7388830184936523,0.2924436484418964,3554,0.7153528457152153 -249.155867099762,0.3296148777008056,938.8269300460817,3500,0,938.8269300460817,0.2924660834744135,3581,0.7362849691732407,1188.4386489391327,0.2697582415172032,0.7438075883047921,0.2907793499072348,3554,0.7191269959904333 -253.22288179397583,0.3548049926757812,1018.9434487819672,3846,0,1018.9434487819672,0.2917956341847598,3581,0.7364441616779531,1272.6597929000854,0.2694469520023891,0.7432947158813477,0.2902159167597249,3554,0.7191311863613182 -257.2856154441833,0.3827457427978515,1099.1777846813202,4189,0,1099.1777846813202,0.2911862711838522,3581,0.7380193152663362,1356.9970092773438,0.2687106643404279,0.7452216829572406,0.2896371647153735,3554,0.7207494938581528 -261.35140013694763,0.4086706638336181,1179.3721747398376,4533,0,1179.3721747398376,0.2911826578207903,3581,0.7377660389686889,1441.2953598499298,0.2684806244713919,0.7452092170715332,0.2896702068202201,3554,0.7205184052080402 -265.4180963039398,0.4344563484191894,1259.479147195816,4881,0,1259.479147195816,0.2910055007657602,3581,0.7387648952370148,1525.5069935321808,0.2685016393661499,0.745835712977818,0.289429432230849,3554,0.7216284413249155 -269.48070764541626,0.4603831768035888,1339.4585328102112,5226,0,1339.4585328102112,0.2902943840974937,3581,0.7380662208094806,1609.586873292923,0.2679123026984079,0.7452772004263741,0.2888458372168683,3554,0.7207955879378869 -273.5457835197449,0.4872004985809326,1419.520034074783,5566,0,1419.520034074783,0.2902249802560388,3581,0.7375023316418249,1693.7521421909332,0.2673601422991071,0.7456020627702985,0.2887105088456668,3554,0.7201848242077589 -277.6095860004425,0.5148611068725586,1499.587522983551,5911,0,1499.587522983551,0.2898647688647549,3581,0.7391205047036442,1777.9230041503906,0.2671688624790737,0.7467336654663086,0.2883714666243142,3554,0.7218456536648846 -281.67469453811646,0.540931224822998,1579.6640601158142,6252,0,1579.6640601158142,0.2898444522196139,3581,0.7387071496046844,1862.1026401519773,0.2673352786472865,0.7460424559456962,0.2884108629800928,3554,0.7214108168173186 -285.7416567802429,0.5702078342437744,1659.788988828659,6595,0,1659.788988828659,0.2901490996317369,3581,0.7393427606202876,1946.3356404304504,0.2672669036047799,0.7469039644513812,0.288696701230128,3554,0.7222953285470597 -289.8089179992676,0.5982434749603271,1739.779695034027,6938,0,1739.779695034027,0.2896087314123149,3581,0.7393149445423765,2030.43337893486,0.2670338494437081,0.7465673855372837,0.288116661161939,3554,0.722022336188098 -293.8766875267029,0.6245615482330322,1819.796627521515,7283,0,1819.796627521515,0.2893108334896851,3581,0.7390221257810319,2114.556656360626,0.2668186085564749,0.746274607522147,0.2878892648468275,3554,0.7219009528216446 -297.93975472450256,0.6534247398376465,1899.944031238556,7624,0,1899.944031238556,0.2896880549580249,3581,0.7385133233559061,2198.8077549934387,0.266804644039699,0.7464902741568429,0.2883452596326322,3554,0.7211011415394626 -302.0066475868225,0.6825459003448486,1980.0877783298488,7970,0,1980.0877783298488,0.2889701547119345,3581,0.7405279437046216,2283.059469461441,0.2662150178636823,0.7481408800397601,0.2875332722317723,3554,0.7232896829496693 -306.07082080841064,0.7084383964538574,2060.240626811981,8316,0,2060.240626811981,0.2890171284317404,3581,0.7396587594465582,2367.314504146576,0.2663652215685163,0.747328553880964,0.2876317802948702,3554,0.7223760447075478 -310.1355743408203,0.7345848083496094,2140.375987768173,8659,0,2140.375987768173,0.2889147270860968,3581,0.7407127706340757,2451.5527720451355,0.2662083080836704,0.748100825718471,0.2874923130737373,3554,0.7235134212770822 -314.2036237716675,0.7611832618713379,2220.5630145072937,9003,0,2220.5630145072937,0.2899190033990156,3581,0.7374599939350042,2535.846463680268,0.2672805956431797,0.7448654856000628,0.2884747146151343,3554,0.720109260142621 -318.2686553001404,0.7889750003814697,2300.723839521408,9348,0,2300.723839521408,0.2884272639560353,3581,0.7397366171940449,2620.1121048927307,0.2657300063541957,0.7472290992736816,0.2869929032977806,3554,0.7225334927414533 -322.3337795734405,0.8150200843811035,2380.7071380615234,9690,0,2380.7071380615234,0.2881913386187517,3581,0.741373334307805,2704.198437929153,0.2654880796160017,0.7489068167550224,0.2868176118403119,3554,0.7240464227147229 -326.4004166126251,0.8413591384887695,2460.8392124176025,10034,0,2460.8392124176025,0.2881762033998883,3581,0.7419083165709648,2788.4354214668274,0.2650915724890573,0.7498970031738281,0.2868405730119231,3554,0.7247131725476575 -330.4654278755188,0.8701965808868408,2540.927223443985,10380,0,2540.927223443985,0.2879667306072849,3581,0.741463804737678,2872.6292066574097,0.265133125441415,0.7492196900503976,0.286560779887099,3554,0.7242343024584975 -334.5306432247162,0.8970785140991211,2621.015692472458,10722,0,2621.015692472458,0.2882319719090337,3581,0.7410609488445965,2956.8216876983643,0.2656360183443342,0.7487141064235142,0.2868639120038601,3554,0.7237369535206809 -338.6022090911865,0.923598051071167,2701.032628774643,11068,0,2701.032628774643,0.2879757981032707,3581,0.741399241439193,3040.948775529861,0.2647242035184587,0.7497291564941406,0.2865727155746606,3554,0.7241230858935355 -342.6689305305481,0.9507875442504884,2781.2527372837067,11413,0,2781.2527372837067,0.2877792106996998,3581,0.7419121344640115,3125.2749075889587,0.2648477043424334,0.7497448240007673,0.2863667863235087,3554,0.7246891981306275 -346.7381579875946,0.9784178733825684,2861.2369871139526,11756,0,2861.2369871139526,0.2881267071436051,3581,0.7417808943905334,3209.367850065232,0.2651637281690325,0.7496111733572823,0.2867564908158061,3554,0.7245175990081598 -350.7999076843262,1.005598783493042,2941.4175691604614,12099,0,2941.4175691604614,0.2876435732315519,3581,0.7413425866334473,3293.6491615772247,0.2644143274852207,0.7495346069335938,0.286263383769828,3554,0.724081250879291 -354.8683977127075,1.0340232849121094,3021.535113096237,12444,0,3021.535113096237,0.2880808242503839,3581,0.7408034455939333,3377.875727415085,0.2650868552071707,0.7488523891993931,0.2867186057413302,3554,0.7235066205112197 -358.9352297782898,1.0614628791809082,3101.513483762741,12787,0,3101.513483762741,0.2880176926617914,3581,0.7415428214884111,3461.960406780243,0.2650796856198992,0.7494787488664899,0.2866628085486863,3554,0.7243308183780599 -362.99676966667175,1.0889112949371338,3181.577430486679,13129,0,3181.577430486679,0.2884689198962405,3581,0.7402498511021712,3546.125237464905,0.26520950453622,0.7483366557529995,0.2870787372063168,3554,0.7230694480470948 -367.06433033943176,1.1174826622009275,3261.7306559085846,13473,0,3261.7306559085846,0.2889175905058817,3581,0.739767023985095,3630.386870384216,0.2661695991243635,0.7472009658813477,0.2874675143214512,3554,0.7228545026290799 -371.1292831897736,1.147545337677002,3341.8958337306976,13817,0,3341.8958337306976,0.2879727642418319,3581,0.7417157175020944,3714.6591458320618,0.2649522338594709,0.7497094018118722,0.2865731449159397,3554,0.7245324370427687 -375.1954782009125,1.1753830909729004,3422.105791330337,14158,0,3422.105791330337,0.2877275327902471,3581,0.7407984686976403,3798.9750142097473,0.2646031379699707,0.7489772524152484,0.286383101292118,3554,0.7235373269995076 -379.2632808685303,1.2069926261901855,3502.180704832077,14503,0,3502.180704832077,0.2876645034666469,3581,0.7420518284435214,3883.161416769028,0.2646099499293736,0.7499461855207171,0.2862713008230163,3554,0.7248605911692811 -383.3338749408722,1.2347795963287354,3582.177492141724,14846,0,3582.177492141724,0.2880744156441985,3581,0.7409713647113236,3967.268642663956,0.2652239799499511,0.7488870620727539,0.2867498789601065,3554,0.7236846769265265 -387.40187311172485,1.2628545761108398,3662.319780349731,15188,0,3662.319780349731,0.287569806083758,3581,0.7422340646598367,4051.518739938736,0.2646545852933611,0.7501388277326312,0.2861393041401414,3554,0.7251050065726998 -391.464732170105,1.2954604625701904,3742.445062160492,15531,0,3742.445062160492,0.2876024627046391,3581,0.7413597671521572,4135.751101493835,0.2645735570362636,0.7493679864065987,0.2862349442034943,3554,0.7241955587014631 -395.5321831703186,1.3235111236572266,3822.540432453156,15879,0,3822.540432453156,0.2872469895913327,3581,0.7424361402846621,4219.954213857651,0.2641646180834089,0.7504794938223702,0.2858967263173976,3554,0.7251092656381893 -399.6045315265656,1.3510067462921145,3902.564522981644,16221,0,3902.564522981644,0.2873541292149539,3581,0.7414070817552709,4304.090502262116,0.264483094215393,0.7492396490914481,0.2860015199368229,3554,0.7241737825117825 -403.6683895587921,1.3806986808776855,3982.604186058045,16565,0,3982.604186058045,0.2871096136183154,3581,0.743032208836568,4388.235741376877,0.2637863499777658,0.751457759312221,0.2857622051078011,3554,0.7258935519397158 -407.73914527893066,1.4130854606628418,4062.732031106949,16909,0,4062.732031106949,0.2870845246068661,3581,0.7424294589718305,4472.478684663773,0.263904469353812,0.750687871660505,0.285708640489809,3554,0.7252478226558103 -411.8066053390503,1.4408173561096191,4142.8806121349335,17250,0,4142.8806121349335,0.2871348730714186,3581,0.742367827269792,4556.73437833786,0.2641006708145141,0.7503536769321987,0.2857754631665025,3554,0.7251186767990293 -415.8730938434601,1.470808506011963,4222.94013261795,17592,0,4222.94013261795,0.2869742488568137,3581,0.7416765159217048,4640.90233540535,0.2634621517998831,0.7503302437918526,0.2856861945277328,3554,0.7243447633828081 -419.94170808792114,1.4998552799224854,4303.021485567093,17937,0,4303.021485567093,0.2871031709237992,3581,0.7427179826034976,4725.093312740326,0.2638522556849888,0.7509931155613491,0.2857127106451357,3554,0.7255100299618388 -424.01124811172485,1.5275304317474363,4383.116291999817,18280,0,4383.116291999817,0.2869682833989284,3581,0.7418415034426487,4809.297451496124,0.2638928038733346,0.7499188014439174,0.2856596955839811,3554,0.7245200033193233 -428.0824813842773,1.5559325218200684,4463.175409078598,18624,0,4463.175409078598,0.2876706393661861,3581,0.7412632971760681,4893.468147993088,0.2639822449002947,0.7499551773071289,0.2862691712902715,3554,0.7240284934229038 -432.1502904891968,1.584796667098999,4543.15851688385,18969,0,4543.15851688385,0.2870536405791853,3581,0.7427049608611421,4977.560157299042,0.26383147920881,0.7511331013270787,0.2856988171613411,3554,0.7255755646146947 -436.2197251319885,1.612985372543335,4623.259985208511,19312,0,4623.259985208511,0.2868566100273143,3581,0.7427993855373848,5061.771022558212,0.2636727605547224,0.7509918212890625,0.2854848506414427,3554,0.725606614576006 -440.2881488800049,1.6470685005187988,4703.453898668289,19656,0,4703.453898668289,0.2871876758957868,3581,0.7422250653405125,5146.079728126526,0.2636232546397618,0.7508409363882882,0.285777146184317,3554,0.7250958701902785 -444.3014948368073,1.6759638786315918,4783.4698967933655,20001,0,4783.4698967933655,0.2871160563128316,3581,0.741523868376501,5230.150224685669,0.2639295033046177,0.7498019763401577,0.2857674774187095,3554,0.7243107595534961 -448.3680605888367,1.7046496868133545,4863.60972738266,20346,0,4863.60972738266,0.2866581818538816,3581,0.7424123466297822,5314.397246599197,0.2634925842285156,0.7507320812770298,0.2854115019773055,3554,0.725203239857379 -452.4350354671478,1.7338902950286863,4943.681494951248,20689,0,4943.681494951248,0.2869315702666853,3581,0.7424970220434236,5398.577282190323,0.2634158304759434,0.7509799684797015,0.2855863297461927,3554,0.7253446133537915 -456.5036413669586,1.7630236148834229,5023.702661037445,21032,0,5023.702661037445,0.2871429520058119,3581,0.742887265254119,5482.707812309265,0.2638038567134312,0.751258373260498,0.2857964665418806,3554,0.7257534149461874 -460.57174134254456,1.7934448719024658,5103.776798248291,21375,0,5103.776798248291,0.2866348313473017,3581,0.7430574342013404,5566.892321109772,0.2633882079805646,0.7514749254499163,0.2853645663886642,3554,0.7258524725661227 -464.63701915740967,1.8233022689819336,5183.902359724045,21716,0,5183.902359724045,0.2869842026493996,3581,0.7415020518448059,5651.124752044678,0.2636804580688476,0.7499002729143415,0.2856560375962823,3554,0.7242064811436058 -468.7076172828674,1.85288667678833,5264.08660030365,22059,0,5264.08660030365,0.2868711998328854,3581,0.742929262077632,5735.421021938324,0.2632166828427996,0.7515812601361956,0.2854894703536068,3554,0.725703405273987 -472.7145164012909,1.8831133842468264,5344.117122173309,22401,0,5344.117122173309,0.2867023944188948,3581,0.7425564720922927,5819.500441074371,0.2633006232125418,0.751082215990339,0.2853966124217431,3554,0.7252957714898706 -476.7797954082489,1.914899826049805,5424.106284618378,22743,0,5424.106284618378,0.2866686469714291,3581,0.7431424504982895,5903.598691701889,0.2631919384002685,0.751730033329555,0.2853215635661402,3554,0.7259724133458779 -480.8402616977692,1.9453659057617188,5504.153155326843,23087,0,5504.153155326843,0.2866985083491867,3581,0.7430985447282533,5987.74852848053,0.2628579650606428,0.7519538061959403,0.2852925229220157,3554,0.7259164959376758 -484.9044568538666,1.9780759811401367,5584.316662073135,23432,0,5584.316662073135,0.2866007430165282,3581,0.7433880910098436,6072.021075725555,0.2630608592714582,0.7519017628261021,0.2852483007702588,3554,0.7262908815331317 -488.9714064598084,2.009888172149658,5664.35845375061,23773,0,5664.35845375061,0.2866360585272096,3581,0.7433591159286861,6156.173418521881,0.2631400823593139,0.7518328939165387,0.2853126676148353,3554,0.7261721772562606 -493.0383477210999,2.040656805038452,5744.427692174912,24116,0,5744.427692174912,0.2865653252408545,3581,0.7434972418449804,6240.352616786957,0.2625663450786045,0.7526019641331264,0.2851527980961152,3554,0.7263241297217924 -497.10008668899536,2.070549726486206,5824.588913679123,24462,0,5824.588913679123,0.2865596665779461,3581,0.7435039913344736,6324.617612838745,0.2628721169063023,0.7522748538425991,0.2851877464762415,3554,0.7263377312535172 -501.1667170524597,2.1010448932647705,5904.574389696121,24805,0,5904.574389696121,0.2864862403134599,3581,0.7429024686496439,6408.712107658386,0.262893659727914,0.7515347344534737,0.2851589806105357,3554,0.7256277725142445 -505.2287847995758,2.1322381496429443,5984.683702468872,25147,0,5984.683702468872,0.2865853691793493,3581,0.7434330876064646,6492.926589488983,0.2624268020902361,0.7526866367885044,0.2852081831211311,3554,0.7262972014367614 -509.29019594192505,2.163586139678955,6064.761931180954,25493,0,6064.761931180954,0.2866510233044191,3581,0.7434775387897934,6577.1097593307495,0.2628815514700753,0.7523791449410575,0.2852760533905458,3554,0.7263473484981711 -513.3572182655334,2.195711135864258,6144.78192949295,25838,0,6144.78192949295,0.2867412551159767,3581,0.7429859850600391,6661.241092920303,0.2630417006356375,0.7518025806971959,0.2854469999142691,3554,0.7257995777205262 -517.4261784553528,2.225970506668091,6224.93829870224,26179,0,6224.93829870224,0.2867445616840617,3581,0.7430417535691846,6745.50874876976,0.2625166688646589,0.7523094585963658,0.2853016764780881,3554,0.725935249564751 -521.4913175106049,2.2567601203918457,6304.959088087082,26523,0,6304.959088087082,0.2864469023797473,3581,0.7429809399870846,6829.637432098389,0.262533494404384,0.751971926007952,0.2850752590610931,3554,0.7258363293340251 -525.5572319030762,2.2879343032836914,6384.982036590576,26869,0,6384.982036590576,0.2863624655844212,3581,0.7430339814297682,6913.769613027573,0.2626591580254691,0.7519010135105678,0.2850800505097689,3554,0.7259174576621412 -529.6197819709778,2.318350076675415,6465.129308223724,27211,0,6465.129308223724,0.286503386743839,3581,0.743049662061924,6998.021443128586,0.2623308045523507,0.7523035321916852,0.2851284630324107,3554,0.7259570944490363 -533.6854808330536,2.350006580352783,6545.123202323914,27557,0,6545.123202323914,0.2864705255929733,3581,0.7436318907515359,7082.125019311905,0.2624163287026541,0.7527543476649693,0.28511815884171,3554,0.7264644727991347 -537.7504897117615,2.381258487701416,6625.1555235385895,27904,0,6625.1555235385895,0.2863888158641092,3581,0.7438905530054454,7166.265917301178,0.2625382627759661,0.7528565270560128,0.2850966230831457,3554,0.7267664542812676 -541.8201594352722,2.4140686988830566,6705.3267295360565,28249,0,6705.3267295360565,0.2866401150385716,3581,0.743525603336184,7250.551569223404,0.2624980040958949,0.7527240344456264,0.2852377561484419,3554,0.7263930991048818 -545.8914339542389,2.4484989643096924,6785.361228466034,28593,0,6785.361228466034,0.2864463910547857,3581,0.7430519800684167,7334.703771352768,0.2622591597693307,0.7522410665239606,0.2850595279966235,3554,0.725907909112092 -549.9570317268372,2.481710910797119,6865.41695022583,28936,0,6865.41695022583,0.2864603672704028,3581,0.7436990447631597,7418.8703582286835,0.2624342782156808,0.7527742385864258,0.2851516302878359,3554,0.7265822153515406 -554.0184555053711,2.514178514480591,6945.470186471939,29279,0,6945.470186471939,0.2867397552294226,3581,0.7435703954028204,7503.029404401779,0.2625941719327654,0.7527771677289691,0.2853723460526431,3554,0.7264663275534609 -558.0815465450287,2.545555830001831,7025.437329292297,29622,0,7025.437329292297,0.2864134276389277,3581,0.7435144905403519,7587.103115797043,0.2620632989065988,0.7529736246381488,0.2850943218138893,3554,0.7263612248083146 -562.1518692970276,2.579129934310913,7105.544379711151,29965,0,7105.544379711151,0.2863203664959159,3581,0.7437412461166574,7671.325980186462,0.2621970346995762,0.7529262134007045,0.2849808898479266,3554,0.7266506351777926 -566.2167932987213,2.61244797706604,7185.599816322327,30306,0,7185.599816322327,0.2863320247050405,3581,0.7438927346586149,7755.4916203022,0.2622241633278983,0.7530655860900879,0.28504800447669,3554,0.7267884365547622 -570.28289103508,2.6435678005218506,7265.729054450989,30650,0,7265.729054450989,0.2863270137204168,3581,0.743650980216769,7839.730357885361,0.2618012939180646,0.7533008711678642,0.2849833113327413,3554,0.7265442272351575 -574.3471963405609,2.674584150314331,7345.938908815384,30994,0,7345.938908815384,0.2863733056736072,3581,0.7437419278832729,7924.047634363174,0.2620712518692016,0.7531344549996513,0.2850146017251688,3554,0.7266233634197383 -578.4171321392059,2.707150936126709,7426.023921489716,31336,0,7426.023921489716,0.2863085719334683,3581,0.7436792053546495,8008.246747255325,0.2621201447078159,0.7530013493129185,0.2849921385894415,3554,0.7265300761465954 -582.4814457893372,2.7388293743133545,7506.242448568344,31680,0,7506.242448568344,0.2863454555073653,3581,0.7437314968540562,8092.573554754257,0.2617100136620657,0.7534921509878976,0.2849501661859876,3554,0.726634904113323 -586.5494747161865,2.771655321121216,7586.352063417435,32026,0,7586.352063417435,0.2862539965158999,3581,0.743563236853358,8176.795695543289,0.261877179145813,0.7530773707798549,0.2848989715318567,3554,0.7264495660699212 -590.6169307231903,2.8068225383758545,7666.316948890686,32371,0,7666.316948890686,0.2863116057949071,3581,0.743465403344038,8260.874934196472,0.2620010035378592,0.752845014844622,0.2849781764110421,3554,0.7263357391099817 -594.6864111423492,2.842201471328736,7746.45441865921,32711,0,7746.45441865921,0.2863427284409033,3581,0.743592893701131,8345.129185199738,0.2616416556494577,0.753363949911935,0.28493270058275,3554,0.7265177798123593 -598.7506074905396,2.87981915473938,7826.421699523926,33056,0,7826.421699523926,0.286236509202213,3581,0.7435417612049707,8429.210087537766,0.2617289338793073,0.7531705583844867,0.2848614814513576,3554,0.7264386436277785 -602.82102227211,2.9121415615081787,7906.595588207245,33399,0,7906.595588207245,0.2862080454460172,3581,0.7436923634503281,8513.498555660248,0.2618094001497541,0.7532474654061454,0.2849090868123944,3554,0.7265992516134989 -606.8904416561127,2.949477195739746,7986.753165006638,33741,0,7986.753165006638,0.2862449290199141,3581,0.7439161192535255,8597.774811983109,0.261636563709804,0.753598690032959,0.2848777448990134,3554,0.7268514295072454 -610.9565181732178,2.982344388961792,8066.87997674942,34084,0,8066.87997674942,0.2860837934803477,3581,0.7439040519844318,8682.012324333191,0.2615481785365513,0.753511905670166,0.2847301888881806,3554,0.7268128918340251 -615.0226359367371,3.015657901763916,8147.030732393265,34427,0,8147.030732393265,0.2861510497569638,3581,0.7440554723497277,8766.274324178696,0.26164140020098,0.7536603382655552,0.2848227376943233,3554,0.7269527540491347 -619.0873327255249,3.049402475357056,8227.022418498993,34769,0,8227.022418498993,0.2861293354902611,3581,0.7438475335320092,8850.376275539398,0.2615990468433925,0.7534738268171038,0.2847944698645012,3554,0.726740693804516 -623.1559307575226,3.083103895187378,8307.055827379227,35111,0,8307.055827379227,0.2861419140843165,3581,0.7439033702178163,8934.523864030838,0.2615196364266531,0.7536313874380929,0.2847824826559862,3554,0.7267891921954136 -627.22323179245,3.116241693496704,8387.092158794403,35456,0,8387.092158794403,0.286086452370148,3581,0.7439780236622103,9018.672763586044,0.2615129436765398,0.7536501203264508,0.2847393080969506,3554,0.7268746482836241 -631.2888576984406,3.150402069091797,8467.19000673294,35798,0,8467.19000673294,0.2861236768273527,3581,0.7440725846917761,9102.882082939148,0.2615251200539725,0.7537720543997628,0.2847793227041713,3554,0.7269665616646737 -635.3528606891632,3.185049057006836,8547.280812740326,36142,0,8547.280812740326,0.2860745896310388,3581,0.7439823187918877,9187.083525657654,0.2614941937582833,0.7536488941737584,0.2847275097985984,3554,0.7268760908703221 -639.4234480857849,3.2191946506500244,8556.043129205704,36189,0,8556.043129205704,0.28607445327771575,3581,0.7439811597886414,9199.952863454819,0.2614939212799072,0.75364807673863,0.28472732088843555,3554,0.7268746482836241 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 1571c7811..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,472 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.668263,1.0711248,,,,,,,,,,,,,, -1,,,0.246516398021153,1.0853406361171178,0.2452213628679832,1.0927659497880908,3554.0,0.2660334975118245,1.0888061012112538,3581.0,57.97748017311096,262.02021503448486,57.97748017311096,204.0422370433808,0.0,0.0 -100,0.98460853,0.4277502,,,,,,,,,,,,,, -200,0.14292225,0.34520626,,,,,,,,,,,,,, -300,0.1784972,0.3774542,,,,,,,,,,,,,, -334,,,0.6825276102338519,0.3249119349888393,0.6577410850889842,0.3466054973269555,3554.0,0.6770980140411198,0.3475053409596656,3581.0,138.13936042785645,346.6962506771088,138.13936042785645,208.51142573356628,0.0327417850494384,0.0 -400,0.07658572,0.3721679,,,,,,,,,,,,,, -500,0.08933114,0.31396264,,,,,,,,,,,,,, -572,,,0.7051773752485003,0.3037167617252895,0.6816797122300576,0.3242165174253482,3554.0,0.6999024937386554,0.3260053194158405,3581.0,218.2856752872467,430.9490616321564,218.2856752872467,212.57180762290955,0.0704364776611328,0.0 -600,0.13515356,0.28204864,,,,,,,,,,,,,, -700,0.13088782,0.25488746,,,,,,,,,,,,,, -800,0.118063994,0.22574578,,,,,,,,,,,,,, -815,,,0.719536576952253,0.2911108561924526,0.6960570808200267,0.3113940855594049,3554.0,0.7135750505052709,0.3135829862686749,3581.0,298.414998292923,515.1890535354614,298.414998292923,216.63490796089167,0.1094112396240234,0.0 -900,0.14314614,0.3589097,,,,,,,,,,,,,, -1000,0.23408645,0.2803129,,,,,,,,,,,,,, -1094,,,0.7254536492483956,0.2851648841585432,0.7022134225960186,0.3053462125641882,3554.0,0.7193366601726124,0.307564657382278,3581.0,378.5005056858063,599.3780615329742,378.5005056858063,220.69820737838745,0.1396596431732177,0.0 -1100,0.21811865,0.3054975,,,,,,,,,,,,,, -1200,0.16594706,0.23779812,,,,,,,,,,,,,, -1300,0.22971198,0.27683303,,,,,,,,,,,,,, -1400,0.13468516,0.2625783,,,,,,,,,,,,,, -1440,,,0.7331483023507255,0.2791761670793806,0.7092399189293753,0.2996527350344682,3554.0,0.7263367670736177,0.3017688230308398,3581.0,458.5756525993347,683.556557893753,458.5756525993347,224.76212549209595,0.1666257381439209,0.0 -1500,0.0719669,0.28918988,,,,,,,,,,,,,, -1600,0.1335424,0.2762107,,,,,,,,,,,,,, -1700,0.0690214,0.3323819,,,,,,,,,,,,,, -1784,,,0.735295023236956,0.2767836877277919,0.7124547577333639,0.2969039204285664,3554.0,0.7292040047516406,0.2988863819572919,3581.0,538.5728995800018,767.6581978797913,538.5728995800018,228.82746934890747,0.1932723522186279,0.0 -1800,0.12354151,0.26653495,,,,,,,,,,,,,, -1900,0.10059588,0.3900827,,,,,,,,,,,,,, -2000,0.081938155,0.30435807,,,,,,,,,,,,,, -2100,0.06793637,0.26855344,,,,,,,,,,,,,, -2125,,,0.7378713744027274,0.2746272768293108,0.7141197088667698,0.2951256918920582,3554.0,0.731143971655962,0.2971061529709753,3581.0,618.6690142154694,851.858500957489,618.6690142154694,232.8936219215393,0.2190008163452148,0.0 -2200,0.118991956,0.27837783,,,,,,,,,,,,,, -2300,0.06482739,0.2698666,,,,,,,,,,,,,, -2400,0.059294686,0.28211355,,,,,,,,,,,,,, -2468,,,0.7398177555629185,0.2728226866040911,0.7154810298475309,0.2936773348471792,3554.0,0.7325169814428582,0.2955118076523667,3581.0,698.6350164413452,935.9269535541534,698.6350164413452,236.95533347129825,0.2474796772003173,0.0 -2500,0.1982781,0.1853136,,,,,,,,,,,,,, -2600,0.16320594,0.27385733,,,,,,,,,,,,,, -2700,0.0921938,0.28406128,,,,,,,,,,,,,, -2800,0.09784449,0.43127328,,,,,,,,,,,,,, -2812,,,0.7401653017316546,0.2718547412327358,0.7169991119161508,0.292136274433297,3554.0,0.7338114517156521,0.2939528119327702,3581.0,778.713544845581,1020.110639810562,778.713544845581,241.0200252532959,0.2753784656524658,0.0 -2900,0.09749267,0.2769702,,,,,,,,,,,,,, -3000,0.2470899,0.35087162,,,,,,,,,,,,,, -3100,0.11241343,0.27555212,,,,,,,,,,,,,, -3156,,,0.7388830184936523,0.2720067501068115,0.7153528457152153,0.2924436484418964,3554.0,0.7323439490758518,0.2940611787362992,3581.0,858.7817091941833,1104.2880628108978,858.7817091941833,245.08817720413208,0.3041086196899414,0.0 -3200,0.2149095,0.36438155,,,,,,,,,,,,,, -3300,0.13685301,0.27068537,,,,,,,,,,,,,, -3400,0.10375092,0.3029824,,,,,,,,,,,,,, -3500,,,0.7438075883047921,0.2697582415172032,0.7191269959904333,0.2907793499072348,3554.0,0.7362849691732407,0.2924660834744135,3581.0,938.8269300460817,1188.4386489391327,938.8269300460817,249.155867099762,0.3296148777008056,0.0 -3500,0.15286523,0.3270638,,,,,,,,,,,,,, -3600,0.1031725,0.33006805,,,,,,,,,,,,,, -3700,0.21940787,0.2544431,,,,,,,,,,,,,, -3800,0.063344784,0.31055373,,,,,,,,,,,,,, -3846,,,0.7432947158813477,0.2694469520023891,0.7191311863613182,0.2902159167597249,3554.0,0.7364441616779531,0.2917956341847598,3581.0,1018.9434487819672,1272.6597929000854,1018.9434487819672,253.22288179397583,0.3548049926757812,0.0 -3900,0.16450778,0.31237257,,,,,,,,,,,,,, -4000,0.2272729,0.27036864,,,,,,,,,,,,,, -4100,0.151465,0.24510455,,,,,,,,,,,,,, -4189,,,0.7452216829572406,0.2687106643404279,0.7207494938581528,0.2896371647153735,3554.0,0.7380193152663362,0.2911862711838522,3581.0,1099.1777846813202,1356.9970092773438,1099.1777846813202,257.2856154441833,0.3827457427978515,0.0 -4200,0.21753538,0.22316512,,,,,,,,,,,,,, -4300,0.29871184,0.28794232,,,,,,,,,,,,,, -4400,0.05377089,0.25896153,,,,,,,,,,,,,, -4500,0.20383722,0.33018976,,,,,,,,,,,,,, -4533,,,0.7452092170715332,0.2684806244713919,0.7205184052080402,0.2896702068202201,3554.0,0.7377660389686889,0.2911826578207903,3581.0,1179.3721747398376,1441.2953598499298,1179.3721747398376,261.35140013694763,0.4086706638336181,0.0 -4600,0.29786563,0.28882504,,,,,,,,,,,,,, -4700,0.09848899,0.31985676,,,,,,,,,,,,,, -4800,0.22510569,0.25319064,,,,,,,,,,,,,, -4881,,,0.745835712977818,0.2685016393661499,0.7216284413249155,0.289429432230849,3554.0,0.7387648952370148,0.2910055007657602,3581.0,1259.479147195816,1525.5069935321808,1259.479147195816,265.4180963039398,0.4344563484191894,0.0 -4900,0.039953865,0.31450534,,,,,,,,,,,,,, -5000,0.12763968,0.39214906,,,,,,,,,,,,,, -5100,0.05302118,0.29337138,,,,,,,,,,,,,, -5200,0.031822383,0.26514798,,,,,,,,,,,,,, -5226,,,0.7452772004263741,0.2679123026984079,0.7207955879378869,0.2888458372168683,3554.0,0.7380662208094806,0.2902943840974937,3581.0,1339.4585328102112,1609.586873292923,1339.4585328102112,269.48070764541626,0.4603831768035888,0.0 -5300,0.23951176,0.2625174,,,,,,,,,,,,,, -5400,0.13550964,0.3532498,,,,,,,,,,,,,, -5500,0.07090783,0.31366745,,,,,,,,,,,,,, -5566,,,0.7456020627702985,0.2673601422991071,0.7201848242077589,0.2887105088456668,3554.0,0.7375023316418249,0.2902249802560388,3581.0,1419.520034074783,1693.7521421909332,1419.520034074783,273.5457835197449,0.4872004985809326,0.0 -5600,0.06839984,0.25023732,,,,,,,,,,,,,, -5700,0.20015514,0.29618302,,,,,,,,,,,,,, -5800,0.13577086,0.27235007,,,,,,,,,,,,,, -5900,0.21109185,0.25376707,,,,,,,,,,,,,, -5911,,,0.7467336654663086,0.2671688624790737,0.7218456536648846,0.2883714666243142,3554.0,0.7391205047036442,0.2898647688647549,3581.0,1499.587522983551,1777.9230041503906,1499.587522983551,277.6095860004425,0.5148611068725586,0.0 -6000,0.23664774,0.23986506,,,,,,,,,,,,,, -6100,0.10579787,0.2923072,,,,,,,,,,,,,, -6200,0.13698119,0.27432135,,,,,,,,,,,,,, -6252,,,0.7460424559456962,0.2673352786472865,0.7214108168173186,0.2884108629800928,3554.0,0.7387071496046844,0.2898444522196139,3581.0,1579.6640601158142,1862.1026401519773,1579.6640601158142,281.67469453811646,0.540931224822998,0.0 -6300,0.80756515,0.22750866,,,,,,,,,,,,,, -6400,0.06541117,0.3001472,,,,,,,,,,,,,, -6500,0.17809898,0.3001792,,,,,,,,,,,,,, -6595,,,0.7469039644513812,0.2672669036047799,0.7222953285470597,0.288696701230128,3554.0,0.7393427606202876,0.2901490996317369,3581.0,1659.788988828659,1946.3356404304504,1659.788988828659,285.7416567802429,0.5702078342437744,0.0 -6600,0.30128363,0.24720508,,,,,,,,,,,,,, -6700,0.074606165,0.27367142,,,,,,,,,,,,,, -6800,0.114005536,0.26850313,,,,,,,,,,,,,, -6900,0.042023428,0.34349176,,,,,,,,,,,,,, -6938,,,0.7465673855372837,0.2670338494437081,0.722022336188098,0.288116661161939,3554.0,0.7393149445423765,0.2896087314123149,3581.0,1739.779695034027,2030.43337893486,1739.779695034027,289.8089179992676,0.5982434749603271,0.0 -7000,0.06535122,0.24818143,,,,,,,,,,,,,, -7100,0.43079257,0.2746741,,,,,,,,,,,,,, -7200,0.09295077,0.2725417,,,,,,,,,,,,,, -7283,,,0.746274607522147,0.2668186085564749,0.7219009528216446,0.2878892648468275,3554.0,0.7390221257810319,0.2893108334896851,3581.0,1819.796627521515,2114.556656360626,1819.796627521515,293.8766875267029,0.6245615482330322,0.0 -7300,0.11272649,0.3067193,,,,,,,,,,,,,, -7400,0.10412209,0.27667424,,,,,,,,,,,,,, -7500,0.06435029,0.26898766,,,,,,,,,,,,,, -7600,0.31296483,0.26701245,,,,,,,,,,,,,, -7624,,,0.7464902741568429,0.266804644039699,0.7211011415394626,0.2883452596326322,3554.0,0.7385133233559061,0.2896880549580249,3581.0,1899.944031238556,2198.8077549934387,1899.944031238556,297.93975472450256,0.6534247398376465,0.0 -7700,0.16188578,0.21848863,,,,,,,,,,,,,, -7800,0.12166052,0.288478,,,,,,,,,,,,,, -7900,0.11727424,0.2891818,,,,,,,,,,,,,, -7970,,,0.7481408800397601,0.2662150178636823,0.7232896829496693,0.2875332722317723,3554.0,0.7405279437046216,0.2889701547119345,3581.0,1980.0877783298488,2283.059469461441,1980.0877783298488,302.0066475868225,0.6825459003448486,0.0 -8000,0.14059368,0.29776883,,,,,,,,,,,,,, -8100,0.48113614,0.19696021,,,,,,,,,,,,,, -8200,0.16062497,0.3272916,,,,,,,,,,,,,, -8300,0.3620797,0.19809449,,,,,,,,,,,,,, -8316,,,0.747328553880964,0.2663652215685163,0.7223760447075478,0.2876317802948702,3554.0,0.7396587594465582,0.2890171284317404,3581.0,2060.240626811981,2367.314504146576,2060.240626811981,306.07082080841064,0.7084383964538574,0.0 -8400,0.17206697,0.19880532,,,,,,,,,,,,,, -8500,0.055654723,0.25351033,,,,,,,,,,,,,, -8600,0.047796004,0.17391127,,,,,,,,,,,,,, -8659,,,0.748100825718471,0.2662083080836704,0.7235134212770822,0.2874923130737373,3554.0,0.7407127706340757,0.2889147270860968,3581.0,2140.375987768173,2451.5527720451355,2140.375987768173,310.1355743408203,0.7345848083496094,0.0 -8700,0.3781902,0.26755264,,,,,,,,,,,,,, -8800,0.0952551,0.2570473,,,,,,,,,,,,,, -8900,0.15318146,0.25175062,,,,,,,,,,,,,, -9000,0.4063958,0.3411498,,,,,,,,,,,,,, -9003,,,0.7448654856000628,0.2672805956431797,0.720109260142621,0.2884747146151343,3554.0,0.7374599939350042,0.2899190033990156,3581.0,2220.5630145072937,2535.846463680268,2220.5630145072937,314.2036237716675,0.7611832618713379,0.0 -9100,0.29051298,0.28850582,,,,,,,,,,,,,, -9200,0.118042484,0.24204662,,,,,,,,,,,,,, -9300,0.08833073,0.41676974,,,,,,,,,,,,,, -9348,,,0.7472290992736816,0.2657300063541957,0.7225334927414533,0.2869929032977806,3554.0,0.7397366171940449,0.2884272639560353,3581.0,2300.723839521408,2620.1121048927307,2300.723839521408,318.2686553001404,0.7889750003814697,0.0 -9400,0.133701,0.28096512,,,,,,,,,,,,,, -9500,0.09348357,0.3204098,,,,,,,,,,,,,, -9600,0.11940685,0.29023314,,,,,,,,,,,,,, -9690,,,0.7489068167550224,0.2654880796160017,0.7240464227147229,0.2868176118403119,3554.0,0.741373334307805,0.2881913386187517,3581.0,2380.7071380615234,2704.198437929153,2380.7071380615234,322.3337795734405,0.8150200843811035,0.0 -9700,0.11744338,0.2862021,,,,,,,,,,,,,, -9800,0.5330628,0.22702669,,,,,,,,,,,,,, -9900,0.13398157,0.28957945,,,,,,,,,,,,,, -10000,0.1907504,0.22989915,,,,,,,,,,,,,, -10034,,,0.7498970031738281,0.2650915724890573,0.7247131725476575,0.2868405730119231,3554.0,0.7419083165709648,0.2881762033998883,3581.0,2460.8392124176025,2788.4354214668274,2460.8392124176025,326.4004166126251,0.8413591384887695,0.0 -10100,0.58139634,0.22894725,,,,,,,,,,,,,, -10200,0.12857792,0.34064728,,,,,,,,,,,,,, -10300,0.07209396,0.26803064,,,,,,,,,,,,,, -10380,,,0.7492196900503976,0.265133125441415,0.7242343024584975,0.286560779887099,3554.0,0.741463804737678,0.2879667306072849,3581.0,2540.927223443985,2872.6292066574097,2540.927223443985,330.4654278755188,0.8701965808868408,0.0 -10400,0.044646755,0.23409113,,,,,,,,,,,,,, -10500,0.51312155,0.25872076,,,,,,,,,,,,,, -10600,0.14029104,0.24964523,,,,,,,,,,,,,, -10700,0.24090089,0.2209867,,,,,,,,,,,,,, -10722,,,0.7487141064235142,0.2656360183443342,0.7237369535206809,0.2868639120038601,3554.0,0.7410609488445965,0.2882319719090337,3581.0,2621.015692472458,2956.8216876983643,2621.015692472458,334.5306432247162,0.8970785140991211,0.0 -10800,0.09994086,0.2588406,,,,,,,,,,,,,, -10900,0.25991786,0.28292033,,,,,,,,,,,,,, -11000,0.08838544,0.33136532,,,,,,,,,,,,,, -11068,,,0.7497291564941406,0.2647242035184587,0.7241230858935355,0.2865727155746606,3554.0,0.741399241439193,0.2879757981032707,3581.0,2701.032628774643,3040.948775529861,2701.032628774643,338.6022090911865,0.923598051071167,0.0 -11100,0.09881534,0.32903147,,,,,,,,,,,,,, -11200,0.07877813,0.23972243,,,,,,,,,,,,,, -11300,0.32998282,0.28409255,,,,,,,,,,,,,, -11400,0.12879851,0.28228953,,,,,,,,,,,,,, -11413,,,0.7497448240007673,0.2648477043424334,0.7246891981306275,0.2863667863235087,3554.0,0.7419121344640115,0.2877792106996998,3581.0,2781.2527372837067,3125.2749075889587,2781.2527372837067,342.6689305305481,0.9507875442504884,0.0 -11500,0.1729342,0.25499934,,,,,,,,,,,,,, -11600,0.2687302,0.2286168,,,,,,,,,,,,,, -11700,0.1925166,0.3010094,,,,,,,,,,,,,, -11756,,,0.7496111733572823,0.2651637281690325,0.7245175990081598,0.2867564908158061,3554.0,0.7417808943905334,0.2881267071436051,3581.0,2861.2369871139526,3209.367850065232,2861.2369871139526,346.7381579875946,0.9784178733825684,0.0 -11800,0.2397945,0.28722057,,,,,,,,,,,,,, -11900,0.2896068,0.29646888,,,,,,,,,,,,,, -12000,0.08011049,0.33149457,,,,,,,,,,,,,, -12099,,,0.7495346069335938,0.2644143274852207,0.724081250879291,0.286263383769828,3554.0,0.7413425866334473,0.2876435732315519,3581.0,2941.4175691604614,3293.6491615772247,2941.4175691604614,350.7999076843262,1.005598783493042,0.0 -12100,0.11777931,0.2895161,,,,,,,,,,,,,, -12200,0.11386526,0.2504673,,,,,,,,,,,,,, -12300,0.32787266,0.2101766,,,,,,,,,,,,,, -12400,0.14527312,0.27587605,,,,,,,,,,,,,, -12444,,,0.7488523891993931,0.2650868552071707,0.7235066205112197,0.2867186057413302,3554.0,0.7408034455939333,0.2880808242503839,3581.0,3021.535113096237,3377.875727415085,3021.535113096237,354.8683977127075,1.0340232849121094,0.0 -12500,0.20109983,0.3299528,,,,,,,,,,,,,, -12600,0.16342127,0.2652495,,,,,,,,,,,,,, -12700,0.06525378,0.20886748,,,,,,,,,,,,,, -12787,,,0.7494787488664899,0.2650796856198992,0.7243308183780599,0.2866628085486863,3554.0,0.7415428214884111,0.2880176926617914,3581.0,3101.513483762741,3461.960406780243,3101.513483762741,358.9352297782898,1.0614628791809082,0.0 -12800,0.13393585,0.27115834,,,,,,,,,,,,,, -12900,0.22494717,0.22085811,,,,,,,,,,,,,, -13000,0.3050992,0.3258139,,,,,,,,,,,,,, -13100,0.12675363,0.23812295,,,,,,,,,,,,,, -13129,,,0.7483366557529995,0.26520950453622,0.7230694480470948,0.2870787372063168,3554.0,0.7402498511021712,0.2884689198962405,3581.0,3181.577430486679,3546.125237464905,3181.577430486679,362.99676966667175,1.0889112949371338,0.0 -13200,0.12417768,0.24101207,,,,,,,,,,,,,, -13300,0.10869783,0.35506517,,,,,,,,,,,,,, -13400,0.23609486,0.3624936,,,,,,,,,,,,,, -13473,,,0.7472009658813477,0.2661695991243635,0.7228545026290799,0.2874675143214512,3554.0,0.739767023985095,0.2889175905058817,3581.0,3261.7306559085846,3630.386870384216,3261.7306559085846,367.06433033943176,1.1174826622009275,0.0 -13500,0.074725136,0.3414285,,,,,,,,,,,,,, -13600,0.15808094,0.24103439,,,,,,,,,,,,,, -13700,0.17955877,0.3738071,,,,,,,,,,,,,, -13800,0.14166741,0.21553838,,,,,,,,,,,,,, -13817,,,0.7497094018118722,0.2649522338594709,0.7245324370427687,0.2865731449159397,3554.0,0.7417157175020944,0.2879727642418319,3581.0,3341.8958337306976,3714.6591458320618,3341.8958337306976,371.1292831897736,1.147545337677002,0.0 -13900,0.15634093,0.2384671,,,,,,,,,,,,,, -14000,0.13248613,0.2636766,,,,,,,,,,,,,, -14100,0.15584616,0.27848098,,,,,,,,,,,,,, -14158,,,0.7489772524152484,0.2646031379699707,0.7235373269995076,0.286383101292118,3554.0,0.7407984686976403,0.2877275327902471,3581.0,3422.105791330337,3798.9750142097473,3422.105791330337,375.1954782009125,1.1753830909729004,0.0 -14200,0.21778171,0.23080757,,,,,,,,,,,,,, -14300,0.14068584,0.23024225,,,,,,,,,,,,,, -14400,0.18609744,0.30702332,,,,,,,,,,,,,, -14500,0.14204547,0.35290104,,,,,,,,,,,,,, -14503,,,0.7499461855207171,0.2646099499293736,0.7248605911692811,0.2862713008230163,3554.0,0.7420518284435214,0.2876645034666469,3581.0,3502.180704832077,3883.161416769028,3502.180704832077,379.2632808685303,1.2069926261901855,0.0 -14600,0.1867905,0.21470803,,,,,,,,,,,,,, -14700,0.14075783,0.27325633,,,,,,,,,,,,,, -14800,0.32545042,0.27834165,,,,,,,,,,,,,, -14846,,,0.7488870620727539,0.2652239799499511,0.7236846769265265,0.2867498789601065,3554.0,0.7409713647113236,0.2880744156441985,3581.0,3582.177492141724,3967.268642663956,3582.177492141724,383.3338749408722,1.2347795963287354,0.0 -14900,0.20679823,0.2069993,,,,,,,,,,,,,, -15000,0.18863536,0.209982,,,,,,,,,,,,,, -15100,0.11039803,0.26595476,,,,,,,,,,,,,, -15188,,,0.7501388277326312,0.2646545852933611,0.7251050065726998,0.2861393041401414,3554.0,0.7422340646598367,0.287569806083758,3581.0,3662.319780349731,4051.518739938736,3662.319780349731,387.40187311172485,1.2628545761108398,0.0 -15200,0.144259,0.2583617,,,,,,,,,,,,,, -15300,0.21378425,0.17451644,,,,,,,,,,,,,, -15400,0.13635848,0.3535466,,,,,,,,,,,,,, -15500,0.18190712,0.27749065,,,,,,,,,,,,,, -15531,,,0.7493679864065987,0.2645735570362636,0.7241955587014631,0.2862349442034943,3554.0,0.7413597671521572,0.2876024627046391,3581.0,3742.445062160492,4135.751101493835,3742.445062160492,391.464732170105,1.2954604625701904,0.0 -15600,0.07268983,0.29548997,,,,,,,,,,,,,, -15700,0.3092938,0.21024986,,,,,,,,,,,,,, -15800,0.0839247,0.38618922,,,,,,,,,,,,,, -15879,,,0.7504794938223702,0.2641646180834089,0.7251092656381893,0.2858967263173976,3554.0,0.7424361402846621,0.2872469895913327,3581.0,3822.540432453156,4219.954213857651,3822.540432453156,395.5321831703186,1.3235111236572266,0.0 -15900,0.15316297,0.20788923,,,,,,,,,,,,,, -16000,0.11641561,0.23177043,,,,,,,,,,,,,, -16100,0.080693744,0.30218974,,,,,,,,,,,,,, -16200,0.06028247,0.26474136,,,,,,,,,,,,,, -16221,,,0.7492396490914481,0.264483094215393,0.7241737825117825,0.2860015199368229,3554.0,0.7414070817552709,0.2873541292149539,3581.0,3902.564522981644,4304.090502262116,3902.564522981644,399.6045315265656,1.3510067462921145,0.0 -16300,0.20390017,0.24398929,,,,,,,,,,,,,, -16400,0.22120479,0.24091995,,,,,,,,,,,,,, -16500,0.15158789,0.32050255,,,,,,,,,,,,,, -16565,,,0.751457759312221,0.2637863499777658,0.7258935519397158,0.2857622051078011,3554.0,0.743032208836568,0.2871096136183154,3581.0,3982.604186058045,4388.235741376877,3982.604186058045,403.6683895587921,1.3806986808776855,0.0 -16600,0.16038755,0.27576345,,,,,,,,,,,,,, -16700,0.162506,0.2818053,,,,,,,,,,,,,, -16800,0.19498476,0.17559332,,,,,,,,,,,,,, -16900,0.13802192,0.31777117,,,,,,,,,,,,,, -16909,,,0.750687871660505,0.263904469353812,0.7252478226558103,0.285708640489809,3554.0,0.7424294589718305,0.2870845246068661,3581.0,4062.732031106949,4472.478684663773,4062.732031106949,407.73914527893066,1.4130854606628418,0.0 -17000,0.21836774,0.20648506,,,,,,,,,,,,,, -17100,0.14941649,0.25500426,,,,,,,,,,,,,, -17200,0.16128224,0.24060468,,,,,,,,,,,,,, -17250,,,0.7503536769321987,0.2641006708145141,0.7251186767990293,0.2857754631665025,3554.0,0.742367827269792,0.2871348730714186,3581.0,4142.8806121349335,4556.73437833786,4142.8806121349335,411.8066053390503,1.4408173561096191,0.0 -17300,0.073864095,0.33960572,,,,,,,,,,,,,, -17400,0.17809322,0.28528363,,,,,,,,,,,,,, -17500,0.13728075,0.22558683,,,,,,,,,,,,,, -17592,,,0.7503302437918526,0.2634621517998831,0.7243447633828081,0.2856861945277328,3554.0,0.7416765159217048,0.2869742488568137,3581.0,4222.94013261795,4640.90233540535,4222.94013261795,415.8730938434601,1.470808506011963,0.0 -17600,0.20562716,0.29223734,,,,,,,,,,,,,, -17700,0.2813917,0.2104806,,,,,,,,,,,,,, -17800,0.14248073,0.2736715,,,,,,,,,,,,,, -17900,0.096608214,0.2879805,,,,,,,,,,,,,, -17937,,,0.7509931155613491,0.2638522556849888,0.7255100299618388,0.2857127106451357,3554.0,0.7427179826034976,0.2871031709237992,3581.0,4303.021485567093,4725.093312740326,4303.021485567093,419.94170808792114,1.4998552799224854,0.0 -18000,0.18289071,0.32224196,,,,,,,,,,,,,, -18100,0.21115434,0.29241717,,,,,,,,,,,,,, -18200,0.14960569,0.36797994,,,,,,,,,,,,,, -18280,,,0.7499188014439174,0.2638928038733346,0.7245200033193233,0.2856596955839811,3554.0,0.7418415034426487,0.2869682833989284,3581.0,4383.116291999817,4809.297451496124,4383.116291999817,424.01124811172485,1.5275304317474363,0.0 -18300,0.15001057,0.26672783,,,,,,,,,,,,,, -18400,0.19170368,0.27052382,,,,,,,,,,,,,, -18500,0.37477294,0.2788967,,,,,,,,,,,,,, -18600,0.110017814,0.39344552,,,,,,,,,,,,,, -18624,,,0.7499551773071289,0.2639822449002947,0.7240284934229038,0.2862691712902715,3554.0,0.7412632971760681,0.2876706393661861,3581.0,4463.175409078598,4893.468147993088,4463.175409078598,428.0824813842773,1.5559325218200684,0.0 -18700,0.32474205,0.26198393,,,,,,,,,,,,,, -18800,0.084325865,0.2714342,,,,,,,,,,,,,, -18900,0.11326107,0.24262084,,,,,,,,,,,,,, -18969,,,0.7511331013270787,0.26383147920881,0.7255755646146947,0.2856988171613411,3554.0,0.7427049608611421,0.2870536405791853,3581.0,4543.15851688385,4977.560157299042,4543.15851688385,432.1502904891968,1.584796667098999,0.0 -19000,0.067054816,0.27963176,,,,,,,,,,,,,, -19100,0.17480768,0.19829687,,,,,,,,,,,,,, -19200,0.25597328,0.2163238,,,,,,,,,,,,,, -19300,0.20217787,0.27869627,,,,,,,,,,,,,, -19312,,,0.7509918212890625,0.2636727605547224,0.725606614576006,0.2854848506414427,3554.0,0.7427993855373848,0.2868566100273143,3581.0,4623.259985208511,5061.771022558212,4623.259985208511,436.2197251319885,1.612985372543335,0.0 -19400,0.106892005,0.23062058,,,,,,,,,,,,,, -19500,0.7739497,0.31364575,,,,,,,,,,,,,, -19600,0.118750885,0.20489986,,,,,,,,,,,,,, -19656,,,0.7508409363882882,0.2636232546397618,0.7250958701902785,0.285777146184317,3554.0,0.7422250653405125,0.2871876758957868,3581.0,4703.453898668289,5146.079728126526,4703.453898668289,440.2881488800049,1.6470685005187988,0.0 -19700,0.123714834,0.19847697,,,,,,,,,,,,,, -19800,0.07182691,0.23159817,,,,,,,,,,,,,, -19900,0.42261308,0.25431094,,,,,,,,,,,,,, -20000,0.18253016,0.27957517,,,,,,,,,,,,,, -20001,,,0.7498019763401577,0.2639295033046177,0.7243107595534961,0.2857674774187095,3554.0,0.741523868376501,0.2871160563128316,3581.0,4783.4698967933655,5230.150224685669,4783.4698967933655,444.3014948368073,1.6759638786315918,0.0 -20100,0.11411307,0.29962885,,,,,,,,,,,,,, -20200,0.1538195,0.3130577,,,,,,,,,,,,,, -20300,0.18043014,0.28050506,,,,,,,,,,,,,, -20346,,,0.7507320812770298,0.2634925842285156,0.725203239857379,0.2854115019773055,3554.0,0.7424123466297822,0.2866581818538816,3581.0,4863.60972738266,5314.397246599197,4863.60972738266,448.3680605888367,1.7046496868133545,0.0 -20400,0.21436752,0.29280192,,,,,,,,,,,,,, -20500,0.7151855,0.25736588,,,,,,,,,,,,,, -20600,0.1360422,0.3740169,,,,,,,,,,,,,, -20689,,,0.7509799684797015,0.2634158304759434,0.7253446133537915,0.2855863297461927,3554.0,0.7424970220434236,0.2869315702666853,3581.0,4943.681494951248,5398.577282190323,4943.681494951248,452.4350354671478,1.7338902950286863,0.0 -20700,0.18316919,0.25242075,,,,,,,,,,,,,, -20800,0.11422422,0.31450668,,,,,,,,,,,,,, -20900,0.19688295,0.28764692,,,,,,,,,,,,,, -21000,0.18659861,0.33045116,,,,,,,,,,,,,, -21032,,,0.751258373260498,0.2638038567134312,0.7257534149461874,0.2857964665418806,3554.0,0.742887265254119,0.2871429520058119,3581.0,5023.702661037445,5482.707812309265,5023.702661037445,456.5036413669586,1.7630236148834229,0.0 -21100,0.1188508,0.29390928,,,,,,,,,,,,,, -21200,0.18018165,0.19307521,,,,,,,,,,,,,, -21300,0.1315347,0.2622348,,,,,,,,,,,,,, -21375,,,0.7514749254499163,0.2633882079805646,0.7258524725661227,0.2853645663886642,3554.0,0.7430574342013404,0.2866348313473017,3581.0,5103.776798248291,5566.892321109772,5103.776798248291,460.57174134254456,1.7934448719024658,0.0 -21400,0.31416088,0.2099047,,,,,,,,,,,,,, -21500,0.30264503,0.19184871,,,,,,,,,,,,,, -21600,0.11492683,0.30149388,,,,,,,,,,,,,, -21700,0.16093013,0.29544166,,,,,,,,,,,,,, -21716,,,0.7499002729143415,0.2636804580688476,0.7242064811436058,0.2856560375962823,3554.0,0.7415020518448059,0.2869842026493996,3581.0,5183.902359724045,5651.124752044678,5183.902359724045,464.63701915740967,1.8233022689819336,0.0 -21800,0.17640749,0.22324859,,,,,,,,,,,,,, -21900,0.1230001,0.31149426,,,,,,,,,,,,,, -22000,0.1256126,0.38580284,,,,,,,,,,,,,, -22059,,,0.7515812601361956,0.2632166828427996,0.725703405273987,0.2854894703536068,3554.0,0.742929262077632,0.2868711998328854,3581.0,5264.08660030365,5735.421021938324,5264.08660030365,468.7076172828674,1.85288667678833,0.0 -22100,0.15654153,0.34742957,,,,,,,,,,,,,, -22200,0.30475843,0.3000026,,,,,,,,,,,,,, -22300,0.5799377,0.25413433,,,,,,,,,,,,,, -22400,0.08837092,0.30837607,,,,,,,,,,,,,, -22401,,,0.751082215990339,0.2633006232125418,0.7252957714898706,0.2853966124217431,3554.0,0.7425564720922927,0.2867023944188948,3581.0,5344.117122173309,5819.500441074371,5344.117122173309,472.7145164012909,1.8831133842468264,0.0 -22500,0.14667617,0.3051047,,,,,,,,,,,,,, -22600,0.18191962,0.27393204,,,,,,,,,,,,,, -22700,0.14357594,0.2246536,,,,,,,,,,,,,, -22743,,,0.751730033329555,0.2631919384002685,0.7259724133458779,0.2853215635661402,3554.0,0.7431424504982895,0.2866686469714291,3581.0,5424.106284618378,5903.598691701889,5424.106284618378,476.7797954082489,1.914899826049805,0.0 -22800,0.25498053,0.30755174,,,,,,,,,,,,,, -22900,0.4438806,0.19559929,,,,,,,,,,,,,, -23000,0.10934973,0.2580505,,,,,,,,,,,,,, -23087,,,0.7519538061959403,0.2628579650606428,0.7259164959376758,0.2852925229220157,3554.0,0.7430985447282533,0.2866985083491867,3581.0,5504.153155326843,5987.74852848053,5504.153155326843,480.8402616977692,1.9453659057617188,0.0 -23100,0.24580848,0.35487014,,,,,,,,,,,,,, -23200,0.18402915,0.26161727,,,,,,,,,,,,,, -23300,0.108770534,0.17675263,,,,,,,,,,,,,, -23400,0.14282833,0.27430016,,,,,,,,,,,,,, -23432,,,0.7519017628261021,0.2630608592714582,0.7262908815331317,0.2852483007702588,3554.0,0.7433880910098436,0.2866007430165282,3581.0,5584.316662073135,6072.021075725555,5584.316662073135,484.9044568538666,1.9780759811401367,0.0 -23500,0.20812273,0.3376173,,,,,,,,,,,,,, -23600,0.24881473,0.19864205,,,,,,,,,,,,,, -23700,0.15581608,0.25483555,,,,,,,,,,,,,, -23773,,,0.7518328939165387,0.2631400823593139,0.7261721772562606,0.2853126676148353,3554.0,0.7433591159286861,0.2866360585272096,3581.0,5664.35845375061,6156.173418521881,5664.35845375061,488.9714064598084,2.009888172149658,0.0 -23800,0.1655192,0.19124708,,,,,,,,,,,,,, -23900,0.7148794,0.23387113,,,,,,,,,,,,,, -24000,0.113538995,0.24218583,,,,,,,,,,,,,, -24100,0.16794424,0.30267727,,,,,,,,,,,,,, -24116,,,0.7526019641331264,0.2625663450786045,0.7263241297217924,0.2851527980961152,3554.0,0.7434972418449804,0.2865653252408545,3581.0,5744.427692174912,6240.352616786957,5744.427692174912,493.0383477210999,2.040656805038452,0.0 -24200,0.31387243,0.22131594,,,,,,,,,,,,,, -24300,0.15540378,0.26307493,,,,,,,,,,,,,, -24400,0.93992126,0.2013394,,,,,,,,,,,,,, -24462,,,0.7522748538425991,0.2628721169063023,0.7263377312535172,0.2851877464762415,3554.0,0.7435039913344736,0.2865596665779461,3581.0,5824.588913679123,6324.617612838745,5824.588913679123,497.10008668899536,2.070549726486206,0.0 -24500,0.08785751,0.3783299,,,,,,,,,,,,,, -24600,0.07324888,0.24970026,,,,,,,,,,,,,, -24700,0.12331679,0.3003421,,,,,,,,,,,,,, -24800,0.15902314,0.24334091,,,,,,,,,,,,,, -24805,,,0.7515347344534737,0.262893659727914,0.7256277725142445,0.2851589806105357,3554.0,0.7429024686496439,0.2864862403134599,3581.0,5904.574389696121,6408.712107658386,5904.574389696121,501.1667170524597,2.1010448932647705,0.0 -24900,0.14295232,0.28979772,,,,,,,,,,,,,, -25000,0.17613684,0.19549155,,,,,,,,,,,,,, -25100,0.1058182,0.3202656,,,,,,,,,,,,,, -25147,,,0.7526866367885044,0.2624268020902361,0.7262972014367614,0.2852081831211311,3554.0,0.7434330876064646,0.2865853691793493,3581.0,5984.683702468872,6492.926589488983,5984.683702468872,505.2287847995758,2.1322381496429443,0.0 -25200,0.108994186,0.22654346,,,,,,,,,,,,,, -25300,0.10712417,0.2376264,,,,,,,,,,,,,, -25400,0.2955635,0.26161903,,,,,,,,,,,,,, -25493,,,0.7523791449410575,0.2628815514700753,0.7263473484981711,0.2852760533905458,3554.0,0.7434775387897934,0.2866510233044191,3581.0,6064.761931180954,6577.1097593307495,6064.761931180954,509.29019594192505,2.163586139678955,0.0 -25500,0.19339591,0.24198982,,,,,,,,,,,,,, -25600,0.11416481,0.2648694,,,,,,,,,,,,,, -25700,0.28747717,0.2876181,,,,,,,,,,,,,, -25800,0.05530801,0.2657256,,,,,,,,,,,,,, -25838,,,0.7518025806971959,0.2630417006356375,0.7257995777205262,0.2854469999142691,3554.0,0.7429859850600391,0.2867412551159767,3581.0,6144.78192949295,6661.241092920303,6144.78192949295,513.3572182655334,2.195711135864258,0.0 -25900,0.13629347,0.21021262,,,,,,,,,,,,,, -26000,0.17035788,0.2334006,,,,,,,,,,,,,, -26100,0.17150776,0.3227347,,,,,,,,,,,,,, -26179,,,0.7523094585963658,0.2625166688646589,0.725935249564751,0.2853016764780881,3554.0,0.7430417535691846,0.2867445616840617,3581.0,6224.93829870224,6745.50874876976,6224.93829870224,517.4261784553528,2.225970506668091,0.0 -26200,0.11406376,0.23510805,,,,,,,,,,,,,, -26300,0.1394366,0.21833608,,,,,,,,,,,,,, -26400,0.17356156,0.33170614,,,,,,,,,,,,,, -26500,0.277483,0.22137538,,,,,,,,,,,,,, -26523,,,0.751971926007952,0.262533494404384,0.7258363293340251,0.2850752590610931,3554.0,0.7429809399870846,0.2864469023797473,3581.0,6304.959088087082,6829.637432098389,6304.959088087082,521.4913175106049,2.2567601203918457,0.0 -26600,0.116844565,0.26025715,,,,,,,,,,,,,, -26700,0.13124917,0.24487638,,,,,,,,,,,,,, -26800,0.18925542,0.21981329,,,,,,,,,,,,,, -26869,,,0.7519010135105678,0.2626591580254691,0.7259174576621412,0.2850800505097689,3554.0,0.7430339814297682,0.2863624655844212,3581.0,6384.982036590576,6913.769613027573,6384.982036590576,525.5572319030762,2.2879343032836914,0.0 -26900,0.15361609,0.27108687,,,,,,,,,,,,,, -27000,0.10012289,0.2726102,,,,,,,,,,,,,, -27100,0.19934902,0.33567733,,,,,,,,,,,,,, -27200,0.12252026,0.32827717,,,,,,,,,,,,,, -27211,,,0.7523035321916852,0.2623308045523507,0.7259570944490363,0.2851284630324107,3554.0,0.743049662061924,0.286503386743839,3581.0,6465.129308223724,6998.021443128586,6465.129308223724,529.6197819709778,2.318350076675415,0.0 -27300,0.14989734,0.21855049,,,,,,,,,,,,,, -27400,0.14168784,0.26697278,,,,,,,,,,,,,, -27500,0.14005736,0.22970548,,,,,,,,,,,,,, -27557,,,0.7527543476649693,0.2624163287026541,0.7264644727991347,0.28511815884171,3554.0,0.7436318907515359,0.2864705255929733,3581.0,6545.123202323914,7082.125019311905,6545.123202323914,533.6854808330536,2.350006580352783,0.0 -27600,0.071899004,0.2495693,,,,,,,,,,,,,, -27700,0.11990655,0.3323408,,,,,,,,,,,,,, -27800,0.08989956,0.27002248,,,,,,,,,,,,,, -27900,0.11866681,0.23939537,,,,,,,,,,,,,, -27904,,,0.7528565270560128,0.2625382627759661,0.7267664542812676,0.2850966230831457,3554.0,0.7438905530054454,0.2863888158641092,3581.0,6625.1555235385895,7166.265917301178,6625.1555235385895,537.7504897117615,2.381258487701416,0.0 -28000,0.09032305,0.23624283,,,,,,,,,,,,,, -28100,0.23413242,0.25011596,,,,,,,,,,,,,, -28200,0.13375919,0.29033765,,,,,,,,,,,,,, -28249,,,0.7527240344456264,0.2624980040958949,0.7263930991048818,0.2852377561484419,3554.0,0.743525603336184,0.2866401150385716,3581.0,6705.3267295360565,7250.551569223404,6705.3267295360565,541.8201594352722,2.4140686988830566,0.0 -28300,0.15643641,0.19223651,,,,,,,,,,,,,, -28400,0.09512884,0.22470416,,,,,,,,,,,,,, -28500,0.15443553,0.25093663,,,,,,,,,,,,,, -28593,,,0.7522410665239606,0.2622591597693307,0.725907909112092,0.2850595279966235,3554.0,0.7430519800684167,0.2864463910547857,3581.0,6785.361228466034,7334.703771352768,6785.361228466034,545.8914339542389,2.4484989643096924,0.0 -28600,0.1043688,0.33704966,,,,,,,,,,,,,, -28700,0.18007009,0.29903814,,,,,,,,,,,,,, -28800,0.08734198,0.25736946,,,,,,,,,,,,,, -28900,0.11699637,0.2760343,,,,,,,,,,,,,, -28936,,,0.7527742385864258,0.2624342782156808,0.7265822153515406,0.2851516302878359,3554.0,0.7436990447631597,0.2864603672704028,3581.0,6865.41695022583,7418.8703582286835,6865.41695022583,549.9570317268372,2.481710910797119,0.0 -29000,0.1858206,0.29477063,,,,,,,,,,,,,, -29100,0.118679546,0.19905931,,,,,,,,,,,,,, -29200,0.12601785,0.24431233,,,,,,,,,,,,,, -29279,,,0.7527771677289691,0.2625941719327654,0.7264663275534609,0.2853723460526431,3554.0,0.7435703954028204,0.2867397552294226,3581.0,6945.470186471939,7503.029404401779,6945.470186471939,554.0184555053711,2.514178514480591,0.0 -29300,0.13571814,0.22420086,,,,,,,,,,,,,, -29400,0.09243034,0.22667122,,,,,,,,,,,,,, -29500,0.15500422,0.29288432,,,,,,,,,,,,,, -29600,0.07938581,0.2699506,,,,,,,,,,,,,, -29622,,,0.7529736246381488,0.2620632989065988,0.7263612248083146,0.2850943218138893,3554.0,0.7435144905403519,0.2864134276389277,3581.0,7025.437329292297,7587.103115797043,7025.437329292297,558.0815465450287,2.545555830001831,0.0 -29700,0.10215713,0.34124506,,,,,,,,,,,,,, -29800,0.15285341,0.29994118,,,,,,,,,,,,,, -29900,0.15219025,0.20605485,,,,,,,,,,,,,, -29965,,,0.7529262134007045,0.2621970346995762,0.7266506351777926,0.2849808898479266,3554.0,0.7437412461166574,0.2863203664959159,3581.0,7105.544379711151,7671.325980186462,7105.544379711151,562.1518692970276,2.579129934310913,0.0 -30000,0.0759877,0.27528918,,,,,,,,,,,,,, -30100,0.087736405,0.2715732,,,,,,,,,,,,,, -30200,0.14026976,0.24416651,,,,,,,,,,,,,, -30300,0.12817077,0.16405462,,,,,,,,,,,,,, -30306,,,0.7530655860900879,0.2622241633278983,0.7267884365547622,0.28504800447669,3554.0,0.7438927346586149,0.2863320247050405,3581.0,7185.599816322327,7755.4916203022,7185.599816322327,566.2167932987213,2.61244797706604,0.0 -30400,0.17755616,0.23799837,,,,,,,,,,,,,, -30500,0.22413486,0.28247648,,,,,,,,,,,,,, -30600,0.091177404,0.26439345,,,,,,,,,,,,,, -30650,,,0.7533008711678642,0.2618012939180646,0.7265442272351575,0.2849833113327413,3554.0,0.743650980216769,0.2863270137204168,3581.0,7265.729054450989,7839.730357885361,7265.729054450989,570.28289103508,2.6435678005218506,0.0 -30700,0.09493977,0.28813967,,,,,,,,,,,,,, -30800,0.11494938,0.28362247,,,,,,,,,,,,,, -30900,0.10498519,0.22109911,,,,,,,,,,,,,, -30994,,,0.7531344549996513,0.2620712518692016,0.7266233634197383,0.2850146017251688,3554.0,0.7437419278832729,0.2863733056736072,3581.0,7345.938908815384,7924.047634363174,7345.938908815384,574.3471963405609,2.674584150314331,0.0 -31000,0.07378274,0.26313433,,,,,,,,,,,,,, -31100,0.14750822,0.2812026,,,,,,,,,,,,,, -31200,0.1594539,0.18901509,,,,,,,,,,,,,, -31300,0.11845353,0.2936557,,,,,,,,,,,,,, -31336,,,0.7530013493129185,0.2621201447078159,0.7265300761465954,0.2849921385894415,3554.0,0.7436792053546495,0.2863085719334683,3581.0,7426.023921489716,8008.246747255325,7426.023921489716,578.4171321392059,2.707150936126709,0.0 -31400,0.07032891,0.20885642,,,,,,,,,,,,,, -31500,0.124247745,0.30420795,,,,,,,,,,,,,, -31600,0.08470524,0.20302372,,,,,,,,,,,,,, -31680,,,0.7534921509878976,0.2617100136620657,0.726634904113323,0.2849501661859876,3554.0,0.7437314968540562,0.2863454555073653,3581.0,7506.242448568344,8092.573554754257,7506.242448568344,582.4814457893372,2.7388293743133545,0.0 -31700,0.11029924,0.25439113,,,,,,,,,,,,,, -31800,0.06727905,0.27132532,,,,,,,,,,,,,, -31900,0.086570516,0.2553701,,,,,,,,,,,,,, -32000,0.14417875,0.21870404,,,,,,,,,,,,,, -32026,,,0.7530773707798549,0.261877179145813,0.7264495660699212,0.2848989715318567,3554.0,0.743563236853358,0.2862539965158999,3581.0,7586.352063417435,8176.795695543289,7586.352063417435,586.5494747161865,2.771655321121216,0.0 -32100,0.06595624,0.42473093,,,,,,,,,,,,,, -32200,0.09424631,0.2772055,,,,,,,,,,,,,, -32300,0.07395014,0.2007119,,,,,,,,,,,,,, -32371,,,0.752845014844622,0.2620010035378592,0.7263357391099817,0.2849781764110421,3554.0,0.743465403344038,0.2863116057949071,3581.0,7666.316948890686,8260.874934196472,7666.316948890686,590.6169307231903,2.8068225383758545,0.0 -32400,0.06639399,0.33481258,,,,,,,,,,,,,, -32500,0.14240596,0.27556267,,,,,,,,,,,,,, -32600,0.0807504,0.26152655,,,,,,,,,,,,,, -32700,0.12149241,0.19281799,,,,,,,,,,,,,, -32711,,,0.753363949911935,0.2616416556494577,0.7265177798123593,0.28493270058275,3554.0,0.743592893701131,0.2863427284409033,3581.0,7746.45441865921,8345.129185199738,7746.45441865921,594.6864111423492,2.842201471328736,0.0 -32800,0.0549959,0.31974554,,,,,,,,,,,,,, -32900,0.087105565,0.23204485,,,,,,,,,,,,,, -33000,0.05602863,0.34335652,,,,,,,,,,,,,, -33056,,,0.7531705583844867,0.2617289338793073,0.7264386436277785,0.2848614814513576,3554.0,0.7435417612049707,0.286236509202213,3581.0,7826.421699523926,8429.210087537766,7826.421699523926,598.7506074905396,2.87981915473938,0.0 -33100,0.072894596,0.28609574,,,,,,,,,,,,,, -33200,0.07462018,0.34305134,,,,,,,,,,,,,, -33300,0.10304083,0.30495736,,,,,,,,,,,,,, -33399,,,0.7532474654061454,0.2618094001497541,0.7265992516134989,0.2849090868123944,3554.0,0.7436923634503281,0.2862080454460172,3581.0,7906.595588207245,8513.498555660248,7906.595588207245,602.82102227211,2.9121415615081787,0.0 -33400,0.08915778,0.2648629,,,,,,,,,,,,,, -33500,0.08925549,0.1949809,,,,,,,,,,,,,, -33600,0.08531962,0.23566528,,,,,,,,,,,,,, -33700,0.09390128,0.34187478,,,,,,,,,,,,,, -33741,,,0.753598690032959,0.261636563709804,0.7268514295072454,0.2848777448990134,3554.0,0.7439161192535255,0.2862449290199141,3581.0,7986.753165006638,8597.774811983109,7986.753165006638,606.8904416561127,2.949477195739746,0.0 -33800,0.05473418,0.20633432,,,,,,,,,,,,,, -33900,0.07370547,0.31802008,,,,,,,,,,,,,, -34000,0.061333783,0.26745608,,,,,,,,,,,,,, -34084,,,0.753511905670166,0.2615481785365513,0.7268128918340251,0.2847301888881806,3554.0,0.7439040519844318,0.2860837934803477,3581.0,8066.87997674942,8682.012324333191,8066.87997674942,610.9565181732178,2.982344388961792,0.0 -34100,0.07351555,0.21054247,,,,,,,,,,,,,, -34200,0.05460368,0.27854916,,,,,,,,,,,,,, -34300,0.08311916,0.25089705,,,,,,,,,,,,,, -34400,0.06178535,0.19779019,,,,,,,,,,,,,, -34427,,,0.7536603382655552,0.26164140020098,0.7269527540491347,0.2848227376943233,3554.0,0.7440554723497277,0.2861510497569638,3581.0,8147.030732393265,8766.274324178696,8147.030732393265,615.0226359367371,3.015657901763916,0.0 -34500,0.08337696,0.22745636,,,,,,,,,,,,,, -34600,0.07504466,0.21463454,,,,,,,,,,,,,, -34700,0.16809735,0.3073065,,,,,,,,,,,,,, -34769,,,0.7534738268171038,0.2615990468433925,0.726740693804516,0.2847944698645012,3554.0,0.7438475335320092,0.2861293354902611,3581.0,8227.022418498993,8850.376275539398,8227.022418498993,619.0873327255249,3.049402475357056,0.0 -34800,0.064029306,0.25680542,,,,,,,,,,,,,, -34900,0.06493915,0.17489077,,,,,,,,,,,,,, -35000,0.060462255,0.3080855,,,,,,,,,,,,,, -35100,0.07069145,0.25142342,,,,,,,,,,,,,, -35111,,,0.7536313874380929,0.2615196364266531,0.7267891921954136,0.2847824826559862,3554.0,0.7439033702178163,0.2861419140843165,3581.0,8307.055827379227,8934.523864030838,8307.055827379227,623.1559307575226,3.083103895187378,0.0 -35200,0.051986825,0.24874312,,,,,,,,,,,,,, -35300,0.068730384,0.33323073,,,,,,,,,,,,,, -35400,0.07654779,0.31676883,,,,,,,,,,,,,, -35456,,,0.7536501203264508,0.2615129436765398,0.7268746482836241,0.2847393080969506,3554.0,0.7439780236622103,0.286086452370148,3581.0,8387.092158794403,9018.672763586044,8387.092158794403,627.22323179245,3.116241693496704,0.0 -35500,0.08758826,0.2751231,,,,,,,,,,,,,, -35600,0.05406681,0.24917652,,,,,,,,,,,,,, -35700,0.09311105,0.1655212,,,,,,,,,,,,,, -35798,,,0.7537720543997628,0.2615251200539725,0.7269665616646737,0.2847793227041713,3554.0,0.7440725846917761,0.2861236768273527,3581.0,8467.19000673294,9102.882082939148,8467.19000673294,631.2888576984406,3.150402069091797,0.0 -35800,0.17746715,0.25779277,,,,,,,,,,,,,, -35900,0.059973035,0.2795889,,,,,,,,,,,,,, -36000,0.059831858,0.21543103,,,,,,,,,,,,,, -36100,0.05126633,0.26321563,,,,,,,,,,,,,, -36142,,,0.7536488941737584,0.2614941937582833,0.7268760908703221,0.2847275097985984,3554.0,0.7439823187918877,0.2860745896310388,3581.0,8547.280812740326,9187.083525657654,8547.280812740326,635.3528606891632,3.185049057006836,0.0 -36189,,,0.75364807673863,0.2614939212799072,0.7268746482836241,0.2847273208884355,3554.0,0.7439811597886414,0.2860744532777157,3581.0,8556.043129205704,9199.95286345482,8556.043129205704,639.4234480857849,3.2191946506500244,0.0 -36189,,,,,,,,,,,8556.043129205704,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/eval_measurements.csv deleted file mode 100644 index bbce07937..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -4.010667562484741,0.0,29.09799885749817,1,0,29.09799885749817,1.0888061012112538,3581,0.2660334975118245,33.10879158973694,1.0853406361171178,0.246516398021153,1.0927659497880908,3554,0.2452213628679832 -8.07419729232788,0.0180323123931884,109.085777759552,344,0,109.085777759552,0.3331046233593968,3581,0.6908669726071628,117.18999433517456,0.310058695929391,0.696904046194894,0.3311653888554269,3554,0.6722392187939645 -12.142432689666748,0.0424883365631103,189.21145105361936,689,0,189.21145105361936,0.3144970307700363,3581,0.7133751565336149,201.42073678970337,0.2921078545706613,0.7192108290536063,0.312307174244689,3554,0.6960894359788267 -22.118088483810425,0.0664534568786621,269.2266945838928,1028,0,269.2266945838928,0.3067101311064472,3581,0.7213561892409592,291.4474046230316,0.2841302156448364,0.7280421938214984,0.3044019708756858,3554,0.7043286668630416 -26.18432569503784,0.0907158851623535,349.2425203323364,1374,0,349.2425203323364,0.3027950863171076,3581,0.724923805763055,375.56581711769104,0.2803813900266375,0.7316979680742536,0.3007247830349606,3554,0.7079886467052968 -30.24979519844055,0.1158676147460937,429.3981335163117,1721,0,429.3981335163117,0.299889124295326,3581,0.7287686285910011,459.82407784461975,0.2773428303854806,0.7355102130344936,0.297957146107379,3554,0.7117270757860861 -34.31474590301514,0.1400947570800781,509.3901512622833,2066,0,509.3901512622833,0.3002220309336603,3581,0.7280282982232616,543.9175727367401,0.2779646260397775,0.7349446160452706,0.2982596771463492,3554,0.7111211893728897 -38.37882971763611,0.1682534217834472,589.5030369758606,2412,0,589.5030369758606,0.2956712728637252,3581,0.732865296006702,628.1349098682404,0.273242746080671,0.7399503844124931,0.2938508574185776,3554,0.7157946207178532 -42.44567584991455,0.1926784515380859,669.475870847702,2760,0,669.475870847702,0.2953393547869833,3581,0.7333316925483454,712.2110702991486,0.2727430718285696,0.7402798788888114,0.2935481202957934,3554,0.7162341974931415 -46.51434302330017,0.2209067344665527,749.4381399154663,3103,0,749.4381399154663,0.2941503538096027,3581,0.7337215948757331,796.2820212841034,0.272240366254534,0.7401915277753558,0.2925725538675612,3554,0.7165038238164744 -50.57876443862915,0.2453043460845947,829.4498481750488,3450,0,829.4498481750488,0.2928982553319603,3581,0.7354444872896886,880.3948223590851,0.2701899324144636,0.7429808889116559,0.2912714093604917,3554,0.7183132397035031 -54.64941358566284,0.2731256484985351,909.4199450016022,3797,0,909.4199450016022,0.2946227498974623,3581,0.7336543045107861,964.4751698970796,0.2723832811628069,0.7405669348580497,0.2928737453617402,3554,0.7165122045582443 -58.71142244338989,0.2989468574523926,989.4974684715272,4144,0,989.4974684715272,0.2945829347271188,3581,0.7350272461210207,1048.6526260375977,0.27250451701028,0.7413287843976702,0.2927918613929727,3554,0.7183819343081739 -62.77670836448669,0.3269517421722412,1069.665913581848,4489,0,1069.665913581848,0.2921583340241901,3581,0.736975666931374,1132.9262413978577,0.2691952841622488,0.7448859214782715,0.2904703615754255,3554,0.7199599867666714 -66.84105086326599,0.3562269210815429,1149.8356471061709,4838,0,1149.8356471061709,0.2923560122543459,3581,0.7358253902977521,1217.201759338379,0.2701189517974853,0.7423827988760812,0.2907897227925401,3554,0.7188070165218767 -70.9094934463501,0.3851003646850586,1229.944278717041,5184,0,1229.944278717041,0.2911693633717886,3581,0.7376395712615191,1301.4194130897522,0.2686702353613717,0.7448692321777344,0.2895360806046004,3554,0.7204929195097074 -74.97279810905457,0.4111568927764892,1309.9406440258026,5529,0,1309.9406440258026,0.2918261091524713,3581,0.7374340868036163,1385.516949415207,0.2685231992176601,0.7455554008483887,0.2902117607361423,3554,0.7202796227622046 -79.03741836547852,0.4429998397827148,1389.9110429286957,5876,0,1389.9110429286957,0.2916755750837755,3581,0.7384640316296076,1469.595920562744,0.2688674245561872,0.745835712977818,0.2901526146815208,3554,0.7213378631471581 -83.10330438613892,0.4701013565063476,1469.9522788524628,6225,0,1469.9522788524628,0.2911691588418039,3581,0.7372454419811156,1553.7422938346865,0.2689230612346104,0.7440728460039411,0.2896457858882597,3554,0.7203090927476083 -87.16951441764832,0.4982421398162842,1549.914607048035,6568,0,1549.914607048035,0.2909232115352729,3581,0.7361113232162804,1637.810804605484,0.2678774084363665,0.7444088799612862,0.289412430316193,3554,0.7188717268394766 -91.23540997505188,0.5247664451599121,1629.9265213012695,6916,0,1629.9265213012695,0.2905541030896572,3581,0.7378540550387461,1721.9272499084473,0.2679372344698225,0.7451337405613491,0.2890318965536191,3554,0.7206505049328221 -95.30541563034058,0.5512409210205078,1709.958627462387,7262,0,1709.958627462387,0.2910777339386693,3581,0.7383208606403588,1806.067808151245,0.2686879805156162,0.7455902780805316,0.2896817475138049,3554,0.7210564213518219 -99.37048292160034,0.5778157711029053,1790.0470299720764,7605,0,1790.0470299720764,0.2898603714700852,3581,0.7399462604283021,1890.2598686218264,0.2671380043029785,0.7475149972098214,0.2884691846994583,3554,0.7227935705147369 -103.4391713142395,0.6066532135009766,1870.235821723938,7952,0,1870.235821723938,0.289658841258552,3581,0.7400685693591176,1974.5582411289213,0.2665707724434988,0.7480261666434151,0.288152399530019,3554,0.7228617155625704 -107.50533604621889,0.6335721015930176,1950.461632490158,8300,0,1950.461632490158,0.2904865059297333,3581,0.7392413819245671,2058.8887956142426,0.2678480148315429,0.7468312127249581,0.2889935306169105,3554,0.7222216192362478 -111.5710666179657,0.6612002849578857,2030.442727804184,8643,0,2030.442727804184,0.289656284633744,3581,0.7408080134302569,2142.975204706192,0.2670548132487705,0.7479840006147113,0.2881419922974114,3554,0.7238505057022018 -115.63677453994752,0.6878864765167236,2110.4105145931244,8989,0,2110.4105145931244,0.2908245258176836,3581,0.7387133536808852,2227.0471861362457,0.2675431455884661,0.7469439506530762,0.2893986570479565,3554,0.7216550261369232 -119.70294833183289,0.7147269248962402,2190.373128414154,9334,0,2190.373128414154,0.2884846346167272,3581,0.7400444348209299,2311.114481449127,0.2656364100319998,0.7478771890912738,0.2870566518909151,3554,0.7228479079470315 -123.77129817008972,0.7413980960845947,2270.471269130707,9680,0,2270.471269130707,0.2888194502015847,3581,0.7403361627556897,2395.319464445114,0.2662521089826311,0.747654846736363,0.2873728703298659,3554,0.7231064744390123 -127.836492061615,0.7694323062896729,2350.5713725090027,10028,0,2350.5713725090027,0.2880712454294366,3581,0.740131564594387,2479.52467250824,0.2649153130395071,0.7482317515781948,0.2866361378684229,3554,0.7228617155625704 -131.9038667678833,0.7975926399230957,2430.7196304798126,10375,0,2430.7196304798126,0.2883899372338383,3581,0.7417603732154077,2563.780470609665,0.2653673716953822,0.7496601513453892,0.2869357322130434,3554,0.7246820538917417 -135.97573614120483,0.8289098739624023,2510.73530125618,10721,0,2510.73530125618,0.2889256694402751,3581,0.7399122402741902,2647.9112889766693,0.2658732107707432,0.7480405398777553,0.2874855810024796,3554,0.722738683525605 -140.04378747940063,0.8607263565063477,2590.721424818039,11069,0,2590.721424818039,0.2879434823656974,3581,0.7416544948600251,2732.009222984314,0.2646873508180891,0.7497758184160505,0.2865176225017146,3554,0.7245000131893641 -144.11624264717102,0.8889884948730469,2670.772175550461,11416,0,2670.772175550461,0.2881520688617006,3581,0.7416660167158265,2816.172864437104,0.2649575301579067,0.7497409411839077,0.2867852910288143,3554,0.7244434088351154 -148.1839623451233,0.9151699542999268,2750.9215309619904,11760,0,2750.9215309619904,0.2879775706964709,3581,0.7414381703129364,2900.4281883239746,0.2650561843599592,0.7492256845746722,0.2865726297064047,3554,0.7242193270346793 -152.245596408844,0.9438767433166504,2831.0232417583466,12106,0,2831.0232417583466,0.2880726089626675,3581,0.7412864772409942,2984.6321744918823,0.2645014354160854,0.7497265679495675,0.2866236011030705,3554,0.7241296118809791 -156.31152319908142,0.9713306427001952,2911.0521461963654,12453,0,2911.0521461963654,0.2881726923018186,3581,0.7413394505070162,3068.76672244072,0.2649199962615967,0.7496450287955148,0.2867595133784116,3554,0.7241797589423888 -160.32822513580322,1.001964807510376,2991.1137623786926,12801,0,2991.1137623786926,0.2879016218955075,3581,0.7406805912498254,3152.8874640464783,0.2650728055409023,0.7484604971749442,0.286540600846977,3554,0.723449810073157 -164.3973925113678,1.030181884765625,3071.2110888957977,13146,0,3071.2110888957977,0.2877257942853777,3581,0.7409998625558504,3237.094136953354,0.2642490012305123,0.7495675086975098,0.2863297599315911,3554,0.7237743920802265 -168.46466898918152,1.060338020324707,3151.210128545761,13494,0,3151.210128545761,0.2876294947509424,3581,0.7415456849081961,3321.202522754669,0.2644489322389875,0.7496738433837891,0.2861986734522281,3554,0.7243514954540659 -172.53335785865784,1.0881555080413818,3231.4134809970856,13842,0,3231.4134809970856,0.2876833884018954,3581,0.7415598656537978,3405.514585494995,0.2644280535834176,0.7498070171901158,0.2862306336170512,3554,0.7244232126213421 -176.60153603553772,1.1152000427246094,3311.5125164985657,14187,0,3311.5125164985657,0.2883037960219736,3581,0.7418224139774156,3489.7209939956665,0.2647754464830671,0.7504711151123047,0.2869762276824968,3554,0.7246985405968627 -180.6669204235077,1.1448464393615725,3391.5887458324432,14533,0,3391.5887458324432,0.2874723816344073,3581,0.7422204293275272,3573.904182910919,0.2641300303595407,0.750499997820173,0.286035317682321,3554,0.7250736131383653 -184.73077940940857,1.1741950511932373,3471.7626218795776,14879,0,3471.7626218795776,0.2876913650712964,3581,0.7414732812936331,3658.1833543777466,0.2644610064370291,0.7497051102774483,0.2863332977037317,3554,0.724206000281373 -188.79840087890625,1.2018353939056396,3551.921292543412,15225,0,3551.921292543412,0.2879501295901983,3581,0.7423924390446105,3742.449414014816,0.2643768957683018,0.7510156631469727,0.2865569673365398,3554,0.7252654771692107 -192.8629801273346,1.22971510887146,3632.060835123062,15571,0,3632.060835123062,0.2875700447020734,3581,0.7415143236438844,3826.6935093402863,0.2642054557800293,0.7497109685625348,0.2861932809257614,3554,0.7242641846115293 -196.92791748046875,1.258352279663086,3712.266945838928,15919,0,3712.266945838928,0.2877211582723925,3581,0.7413944008962231,3911.0051939487457,0.2644821405410766,0.749645437513079,0.2863569629950408,3554,0.7241334587788407 -200.9972834587097,1.286785364151001,3792.3345487117767,16264,0,3792.3345487117767,0.2874283395110479,3581,0.7418102103549986,3995.182499408722,0.2641898904527937,0.7499145099094936,0.2860444025437887,3554,0.7246541638822454 -205.06292033195496,1.3158259391784668,3872.3533470630646,16611,0,3872.3533470630646,0.2888027469195057,3581,0.7419807883621893,4079.307808637619,0.2646711553846086,0.7509322847638812,0.2872992297136589,3554,0.7248380593389491 -209.13350009918213,1.3445637226104736,3952.4735357761374,16959,0,3952.4735357761374,0.2872786235622905,3581,0.7411933479213209,4163.5395884513855,0.2639563594545637,0.7492737088884626,0.2859136423637978,3554,0.723998817353686 -213.2035298347473,1.3733389377593994,4032.434810400009,17303,0,4032.434810400009,0.2872969630842467,3581,0.7424263910220609,4247.611432790756,0.26378219468253,0.7509043557303292,0.2859099156814944,3554,0.7252153988024057 -217.2685823440552,1.4080066680908203,4112.587472200394,17648,0,4112.587472200394,0.2873458116622452,3581,0.7421531389625803,4331.875928163528,0.2633920737675258,0.7511190005711147,0.2859708134485351,3554,0.724937803904931 -221.33321714401245,1.436694622039795,4192.576465368271,17995,0,4192.576465368271,0.287149565141982,3581,0.7418655698041748,4415.970258712769,0.2636561734335763,0.7503595352172852,0.2857852006267146,3554,0.7246912589687676 -225.4005281925201,1.4654943943023682,4272.693914175034,18335,0,4272.693914175034,0.2876740141109327,3581,0.742102756409697,4500.195482492447,0.2641763687133789,0.7505239759172712,0.2862748042478545,3554,0.7249439864193514 -229.46516680717468,1.49409818649292,4352.8466629982,18682,0,4352.8466629982,0.2871079773784383,3581,0.7423240578530788,4584.453443288803,0.2632061072758266,0.7512984957013812,0.2856864521325004,3554,0.7251478720060144 -233.5327014923096,1.524968147277832,4432.961780309677,19029,0,4432.961780309677,0.2872036974112503,3581,0.7422262243437587,4668.678829908371,0.2636205468858991,0.7505662781851632,0.2857843762914586,3554,0.7250470970209623 -237.60093474388125,1.555567741394043,4513.029278039932,19374,0,4513.029278039932,0.287047504679646,3581,0.7426140131946384,4752.857128858566,0.2634677035467965,0.7511434555053711,0.2856579782188643,3554,0.7254661341094542 -241.6682722568512,1.5856091976165771,4593.14227104187,19718,0,4593.14227104187,0.2871675637806304,3581,0.742857608406346,4837.079212188721,0.2629783153533935,0.752007007598877,0.2857248695901624,3554,0.7257539645030248 -245.73634243011475,1.6148817539215088,4673.121878147125,20064,0,4673.121878147125,0.2874566669139207,3581,0.7411866666084892,4921.168244838715,0.2639386143003191,0.7493659428187779,0.2860666767693532,3554,0.7241226737259074 -249.80592560768127,1.6456749439239502,4753.126408100128,20412,0,4753.126408100128,0.2876055647427394,3581,0.7412260045422019,5005.285262107849,0.2641740356172834,0.7491849490574428,0.2862578195068497,3554,0.7240449114334201 -253.87785720825195,1.6757760047912598,4833.162514925003,20758,0,4833.162514925003,0.2870925694529286,3581,0.7423890983881947,5089.435174465179,0.2628540141241891,0.7516568728855678,0.2856588884223762,3554,0.7252162231376618 -257.94165658950806,1.7051050662994385,4913.222190141678,21107,0,4913.222190141678,0.2874962434659487,3581,0.7421710694245671,5173.600088834763,0.2635410002299717,0.7511718613760812,0.2860896722882667,3554,0.7249950265106219 -262.0073826313019,1.7356181144714355,4993.34165430069,21454,0,4993.34165430069,0.2871669501906765,3581,0.7413911965931305,5257.827834844589,0.2635111468178885,0.7499760900224958,0.2857918468297165,3554,0.7241845675647158 -266.0779194831848,1.765653133392334,5073.316593170166,21798,0,5073.316593170166,0.2879440959556513,3581,0.7419022488480871,5341.9150631427765,0.2637288229806082,0.7510729517255511,0.2865836895377567,3554,0.7247480007122257 -270.1434907913208,1.799447774887085,5153.433739185333,22146,0,5153.433739185333,0.287016347945319,3581,0.7419124071706577,5426.143560171127,0.2630865403584072,0.750969409942627,0.2856206598748769,3554,0.7246992962375141 -274.20904064178467,1.8289234638214111,5233.556937217712,22492,0,5233.556937217712,0.2869991333382784,3581,0.7428656532524085,5510.373544454575,0.263117824281965,0.7517211096627372,0.2855805422257491,3554,0.7257709320703785 -278.2751727104187,1.863584280014038,5313.555603504181,22835,0,5313.555603504181,0.28713978179105,3581,0.7428976281066741,5594.484750509262,0.2629103319985525,0.7522268976484027,0.2857270849911631,3554,0.7257775267524268 -282.3405730724335,1.893626928329468,5393.517927885056,23182,0,5393.517927885056,0.287239728776878,3581,0.7429226489414619,5678.554426193237,0.2629103830882481,0.7522249221801758,0.2857639568202201,3554,0.7257977229662 -286.4092450141907,1.9252281188964844,5473.649057149887,23530,0,5473.649057149887,0.287000394606517,3581,0.7425341783239667,5762.797847509384,0.2631608928952898,0.7511571475437709,0.2856311701493915,3554,0.7253869979248734 -290.47550415992737,1.956531286239624,5553.753219604492,23877,0,5553.753219604492,0.2869906794322466,3581,0.7428848790709648,5847.011462926865,0.263093147959028,0.7517889567783901,0.285614219755689,3554,0.7257241510445976 -294.5453701019287,1.9875144958496087,5633.818361282349,24223,0,5633.818361282349,0.2871886644573792,3581,0.7426174902043773,5931.189614534378,0.2628435237067086,0.7519252640860421,0.2856898353417804,3554,0.725507900429094 -298.6102910041809,2.0182888507843018,5713.8310968875885,24569,0,5713.8310968875885,0.2869305135284313,3581,0.7429443291198339,6015.310056209564,0.262832828930446,0.7520390238080706,0.2855415580375984,3554,0.725781236261079 -302.67626667022705,2.049443483352661,5793.904387712479,24914,0,5793.904387712479,0.2869498757003106,3581,0.7425912421896816,6099.492146253586,0.2628613199506487,0.7516685894557408,0.2855545928388348,3554,0.7254507465180079 -306.7443902492523,2.080031633377075,5873.891070127487,25262,0,5873.891070127487,0.2871179652593549,3581,0.7426024231621754,6183.589738368988,0.2625758647918701,0.752216134752546,0.2857069746456457,3554,0.72544092318954 -310.8091578483581,2.1123173236846924,5953.967887401581,25609,0,5953.967887401581,0.2869750669767523,3581,0.7428054532602625,6267.775489091873,0.2627349751336233,0.7520290102277484,0.2855440653906689,3554,0.7257105495128728 -314.874475479126,2.1450178623199463,6033.956508636475,25953,0,6033.956508636475,0.2868983341441811,3581,0.742740344548485,6351.873990535736,0.2627204656600952,0.7518808501107352,0.2855149388782885,3554,0.7255820906021384 -318.94093799591064,2.177800178527832,6113.944706916809,26300,0,6113.944706916809,0.2870122232572955,3581,0.7427877955049218,6435.973375320435,0.2623820304870605,0.7524644306727818,0.2856000171461733,3554,0.7256400001538759 -323.00680804252625,2.215407609939575,6193.946717262268,26646,0,6193.946717262268,0.2869487848737259,3581,0.7425352009738899,6520.0908126831055,0.2625506094523838,0.7519332340785435,0.2855287636674785,3554,0.7253897457090602 -327.07042241096497,2.2482399940490723,6274.016363620758,26988,0,6274.016363620758,0.2870836723985968,3581,0.7427188689000978,6604.268972635269,0.2627148628234863,0.7521262850080218,0.285691638575153,3554,0.7255871740028841 -331.14239048957825,2.2848925590515137,6354.049415111542,27334,0,6354.049415111542,0.2869639882692509,3581,0.7427528208775481,6688.422640323639,0.2621868678501674,0.7526041439601353,0.2855241954762679,3554,0.725647281781971 -335.2097702026367,2.317322254180908,6434.165660142899,27684,0,6434.165660142899,0.2869821914378839,3581,0.7428624489493159,6772.650623321533,0.2624724251883371,0.752434526171003,0.2855600368862549,3554,0.7257336995946468 -339.2761797904968,2.3487956523895264,6514.367031574249,28030,0,6514.367031574249,0.286982634586184,3581,0.7429992113323792,6856.961793422699,0.2626483099801199,0.7522539411272321,0.285618221216411,3554,0.7259011083462296 -343.3381669521332,2.380867481231689,6594.354041337967,28374,0,6594.354041337967,0.2870561972039933,3581,0.742843768544052,6941.054801940918,0.2622614758355276,0.7526415416172573,0.2856316510116242,3554,0.7256958488674733 -347.3995735645294,2.413620710372925,6674.474102973938,28720,0,6674.474102973938,0.2869618747927429,3581,0.7427183234868053,7025.281142234802,0.2624066046306065,0.7523200852530343,0.2855549878328116,3554,0.7255589405203644 -351.46641421318054,2.447869300842285,6754.4682993888855,29066,0,6754.4682993888855,0.2870614808952632,3581,0.7427424580249931,7109.388455867767,0.2625972884041922,0.7522360937935966,0.2856577206140968,3554,0.7256316881067107 -355.53622817993164,2.4804461002349854,6834.473086357117,29410,0,6834.473086357117,0.2870381644770141,3581,0.7428698802054244,7193.50737786293,0.2621924536568777,0.7528103419712612,0.28564556166907,3554,0.7257646808613534 -359.60470604896545,2.513329029083252,6914.607104301453,29757,0,6914.607104301453,0.2868713361862084,3581,0.7428138389896328,7277.75493144989,0.2622502020427159,0.7524969237191337,0.2854467938304551,3554,0.725686918568866 -363.6717460155487,2.547752857208252,6994.64807009697,30105,0,6994.64807009697,0.2869789530464605,3581,0.743100249144792,7361.90936923027,0.2624603680201939,0.7526117733546666,0.2855814696029122,3554,0.7260031885287704 -367.73788118362427,2.5821104049682617,7074.760458707809,30448,0,7074.760458707809,0.2869994401332554,3581,0.7429354661538328,7446.134158611298,0.2622091088976179,0.7527153832571847,0.2855970632781724,3554,0.7258541212366347 -371.8072714805603,2.6142380237579346,7154.866800308228,30796,0,7154.866800308228,0.2869368539579552,3581,0.7429225807648003,7530.353799581528,0.2621920279094151,0.7526858874729702,0.2854876155992807,3554,0.7258120801385762 -375.8769974708557,2.6478710174560547,7234.82731628418,31141,0,7234.82731628418,0.2870004968715093,3581,0.7431579947771223,7614.42972612381,0.262384397642953,0.7527549607413155,0.2855822080699124,3554,0.7260390471124085 -379.9435257911682,2.680339097976685,7314.991788864136,31487,0,7314.991788864136,0.286932354298293,3581,0.7427517982276249,7698.705207109451,0.2623383658272879,0.7522657939365932,0.2855801987527258,3554,0.7256374584535031 -384.0136480331421,2.718026161193848,7395.034659147263,31835,0,7395.034659147263,0.2869863502142383,3581,0.7428622444193312,7782.868022441864,0.2621617998395647,0.7526765550885882,0.2855543352340672,3554,0.725747644599395 -388.08017563819885,2.7506425380706787,7475.214136600494,32182,0,7475.214136600494,0.2869042314254049,3581,0.7431491318111212,7867.158411026001,0.262218884059361,0.7528634071350098,0.2855046690348902,3554,0.7260420696750141 -392.1504149436951,2.783066987991333,7555.255586385727,32526,0,7555.255586385727,0.2873262790487119,3581,0.7430746147200503,7951.3141577243805,0.262486457824707,0.7528461047581264,0.2858738853613446,3554,0.7259932965056978 -396.2140109539032,2.818079233169556,7635.29710650444,32873,0,7635.29710650444,0.2869643632408894,3581,0.7429696908379293,8035.4664006233215,0.2620720863342285,0.7528681755065918,0.2855324388288284,3554,0.7258712261931978 -400.2839741706848,2.851149559020996,7715.372609138489,33220,0,7715.372609138489,0.2869187530543144,3581,0.7430804097362818,8119.656742095947,0.2621460471834455,0.7528653144836426,0.2855008393106799,3554,0.7259631395742473 -404.3473062515259,2.8837192058563232,7795.492541074753,33565,0,7795.492541074753,0.2868944480744729,3581,0.7431213839098716,8203.884775876999,0.2621467964989798,0.7528884070260184,0.2854834080547446,3554,0.7260305976760341 -408.41320180892944,2.919473171234131,7875.469173669815,33912,0,7875.469173669815,0.2869322861216315,3581,0.7430133239013195,8287.974950790405,0.2619972058704921,0.7530028479439872,0.2855185109977314,3554,0.7259071534714406 -412.47713685035706,2.95334792137146,7955.515779495239,34257,0,7955.515779495239,0.2869017088889277,3581,0.7430517073617705,8372.131846427917,0.2620432717459542,0.7529234204973493,0.2854790459473481,3554,0.7259287235773073 -416.4928929805756,2.98781418800354,8035.698899507523,34601,0,8035.698899507523,0.2868458040264591,3581,0.7431550631806758,8456.37722826004,0.2620537621634347,0.7529683794294085,0.2854372968013594,3554,0.7260481834948298 -420.558625459671,3.02548623085022,8115.888223648071,34948,0,8115.888223648071,0.2868427360766895,3581,0.7431837655551871,8540.681955337524,0.2619693449565342,0.7530630656651088,0.2854283149817986,3554,0.7260910489281444 -424.6205706596375,3.0597083568573,8195.908843517303,35295,0,8195.908843517303,0.286832270959142,3581,0.7430881136990366,8624.81097126007,0.2619802270616804,0.7529725347246442,0.2854171349348885,3554,0.7259849844585327 -428.6862292289734,3.0939202308654785,8275.888614416122,35642,0,8275.888614416122,0.286830157482634,3581,0.7431127936505166,8708.90253329277,0.2619913646153041,0.7529713766915458,0.2854261854490539,3554,0.7260084093187253 -432.7504985332489,3.1271345615386963,8355.901663303375,35985,0,8355.901663303375,0.2868288280377338,3581,0.7430827959194359,8793.024923086166,0.2619766848427908,0.7529499190194267,0.285419127078424,3554,0.7259724133458779 -436.82040643692017,3.1628973484039307,8401.65019273758,36189,0,8401.65019273758,0.28683199825249583,3581,0.7431058396310388,8842.886168718338,0.26198225361960276,0.7529680388314384,0.2854213768267269,3554,0.7259976242657921 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/measurements.csv deleted file mode 100644 index 4c1d93a5e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/measurements.csv +++ /dev/null @@ -1,470 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.668263,1.0711248,,,,,,,,,,,,,, -1,,,0.246516398021153,1.0853406361171178,0.2452213628679832,1.0927659497880908,3554.0,0.2660334975118245,1.0888061012112538,3581.0,29.09799885749817,33.10879158973694,29.09799885749817,4.010667562484741,0.0,0.0 -100,0.63617754,0.37195343,,,,,,,,,,,,,, -200,0.18125482,0.33390185,,,,,,,,,,,,,, -300,0.14957076,0.36245355,,,,,,,,,,,,,, -344,,,0.696904046194894,0.310058695929391,0.6722392187939645,0.3311653888554269,3554.0,0.6908669726071628,0.3331046233593968,3581.0,109.085777759552,117.18999433517456,109.085777759552,8.07419729232788,0.0180323123931884,0.0 -400,0.2719756,0.36049622,,,,,,,,,,,,,, -500,0.25494158,0.30208653,,,,,,,,,,,,,, -600,0.19312198,0.27588382,,,,,,,,,,,,,, -689,,,0.7192108290536063,0.2921078545706613,0.6960894359788267,0.312307174244689,3554.0,0.7133751565336149,0.3144970307700363,3581.0,189.21145105361936,201.42073678970337,189.21145105361936,12.142432689666748,0.0424883365631103,0.0 -700,0.4860722,0.25031605,,,,,,,,,,,,,, -800,0.26572427,0.22404236,,,,,,,,,,,,,, -900,0.17388712,0.35641205,,,,,,,,,,,,,, -1000,0.40254116,0.27932984,,,,,,,,,,,,,, -1028,,,0.7280421938214984,0.2841302156448364,0.7043286668630416,0.3044019708756858,3554.0,0.7213561892409592,0.3067101311064472,3581.0,269.2266945838928,291.4474046230316,269.2266945838928,22.118088483810425,0.0664534568786621,0.0 -1100,0.4169776,0.30629987,,,,,,,,,,,,,, -1200,0.1631782,0.23641063,,,,,,,,,,,,,, -1300,0.3701606,0.27692422,,,,,,,,,,,,,, -1374,,,0.7316979680742536,0.2803813900266375,0.7079886467052968,0.3007247830349606,3554.0,0.724923805763055,0.3027950863171076,3581.0,349.2425203323364,375.56581711769104,349.2425203323364,26.18432569503784,0.0907158851623535,0.0 -1400,0.2988075,0.26268402,,,,,,,,,,,,,, -1500,0.21322107,0.28872886,,,,,,,,,,,,,, -1600,0.52911264,0.27934152,,,,,,,,,,,,,, -1700,0.16383022,0.33257738,,,,,,,,,,,,,, -1721,,,0.7355102130344936,0.2773428303854806,0.7117270757860861,0.297957146107379,3554.0,0.7287686285910011,0.299889124295326,3581.0,429.3981335163117,459.82407784461975,429.3981335163117,30.24979519844055,0.1158676147460937,0.0 -1800,0.16720167,0.26670757,,,,,,,,,,,,,, -1900,0.13911566,0.39070144,,,,,,,,,,,,,, -2000,0.056964874,0.30474815,,,,,,,,,,,,,, -2066,,,0.7349446160452706,0.2779646260397775,0.7111211893728897,0.2982596771463492,3554.0,0.7280282982232616,0.3002220309336603,3581.0,509.3901512622833,543.9175727367401,509.3901512622833,34.31474590301514,0.1400947570800781,0.0 -2100,0.08010713,0.26796266,,,,,,,,,,,,,, -2200,0.17616223,0.27959234,,,,,,,,,,,,,, -2300,0.08089788,0.2697921,,,,,,,,,,,,,, -2400,0.22996055,0.28295678,,,,,,,,,,,,,, -2412,,,0.7399503844124931,0.273242746080671,0.7157946207178532,0.2938508574185776,3554.0,0.732865296006702,0.2956712728637252,3581.0,589.5030369758606,628.1349098682404,589.5030369758606,38.37882971763611,0.1682534217834472,0.0 -2500,0.22973196,0.18599394,,,,,,,,,,,,,, -2600,0.0815365,0.2740707,,,,,,,,,,,,,, -2700,0.16408068,0.28423014,,,,,,,,,,,,,, -2760,,,0.7402798788888114,0.2727430718285696,0.7162341974931415,0.2935481202957934,3554.0,0.7333316925483454,0.2953393547869833,3581.0,669.475870847702,712.2110702991486,669.475870847702,42.44567584991455,0.1926784515380859,0.0 -2800,0.12109575,0.43322346,,,,,,,,,,,,,, -2900,0.21932757,0.27754095,,,,,,,,,,,,,, -3000,0.15428855,0.3511174,,,,,,,,,,,,,, -3100,0.24352725,0.2767358,,,,,,,,,,,,,, -3103,,,0.7401915277753558,0.272240366254534,0.7165038238164744,0.2925725538675612,3554.0,0.7337215948757331,0.2941503538096027,3581.0,749.4381399154663,796.2820212841034,749.4381399154663,46.51434302330017,0.2209067344665527,0.0 -3200,0.36427253,0.36888015,,,,,,,,,,,,,, -3300,0.08472884,0.27083895,,,,,,,,,,,,,, -3400,0.11968263,0.303489,,,,,,,,,,,,,, -3450,,,0.7429808889116559,0.2701899324144636,0.7183132397035031,0.2912714093604917,3554.0,0.7354444872896886,0.2928982553319603,3581.0,829.4498481750488,880.3948223590851,829.4498481750488,50.57876443862915,0.2453043460845947,0.0 -3500,0.16358525,0.32802218,,,,,,,,,,,,,, -3600,0.07352145,0.32928094,,,,,,,,,,,,,, -3700,0.21985732,0.2604274,,,,,,,,,,,,,, -3797,,,0.7405669348580497,0.2723832811628069,0.7165122045582443,0.2928737453617402,3554.0,0.7336543045107861,0.2946227498974623,3581.0,909.4199450016022,964.4751698970796,909.4199450016022,54.64941358566284,0.2731256484985351,0.0 -3800,0.16166127,0.31378394,,,,,,,,,,,,,, -3900,0.2516409,0.31614688,,,,,,,,,,,,,, -4000,0.3678351,0.27262956,,,,,,,,,,,,,, -4100,0.15003589,0.24638413,,,,,,,,,,,,,, -4144,,,0.7413287843976702,0.27250451701028,0.7183819343081739,0.2927918613929727,3554.0,0.7350272461210207,0.2945829347271188,3581.0,989.4974684715272,1048.6526260375977,989.4974684715272,58.71142244338989,0.2989468574523926,0.0 -4200,0.2758947,0.22506464,,,,,,,,,,,,,, -4300,0.18112361,0.29381463,,,,,,,,,,,,,, -4400,0.16806757,0.2600183,,,,,,,,,,,,,, -4489,,,0.7448859214782715,0.2691952841622488,0.7199599867666714,0.2904703615754255,3554.0,0.736975666931374,0.2921583340241901,3581.0,1069.665913581848,1132.9262413978577,1069.665913581848,62.77670836448669,0.3269517421722412,0.0 -4500,0.23107049,0.33213878,,,,,,,,,,,,,, -4600,0.1877831,0.2905074,,,,,,,,,,,,,, -4700,0.1851867,0.3217705,,,,,,,,,,,,,, -4800,0.3476366,0.25615925,,,,,,,,,,,,,, -4838,,,0.7423827988760812,0.2701189517974853,0.7188070165218767,0.2907897227925401,3554.0,0.7358253902977521,0.2923560122543459,3581.0,1149.8356471061709,1217.201759338379,1149.8356471061709,66.84105086326599,0.3562269210815429,0.0 -4900,0.082928896,0.31535307,,,,,,,,,,,,,, -5000,0.2342502,0.3951227,,,,,,,,,,,,,, -5100,0.15293625,0.29388276,,,,,,,,,,,,,, -5184,,,0.7448692321777344,0.2686702353613717,0.7204929195097074,0.2895360806046004,3554.0,0.7376395712615191,0.2911693633717886,3581.0,1229.944278717041,1301.4194130897522,1229.944278717041,70.9094934463501,0.3851003646850586,0.0 -5200,0.22778668,0.26726604,,,,,,,,,,,,,, -5300,0.31782693,0.26474965,,,,,,,,,,,,,, -5400,0.08699378,0.354022,,,,,,,,,,,,,, -5500,0.0830598,0.31386295,,,,,,,,,,,,,, -5529,,,0.7455554008483887,0.2685231992176601,0.7202796227622046,0.2902117607361423,3554.0,0.7374340868036163,0.2918261091524713,3581.0,1309.9406440258026,1385.516949415207,1309.9406440258026,74.97279810905457,0.4111568927764892,0.0 -5600,0.056280825,0.2505979,,,,,,,,,,,,,, -5700,0.22907698,0.2980038,,,,,,,,,,,,,, -5800,0.067146316,0.273058,,,,,,,,,,,,,, -5876,,,0.745835712977818,0.2688674245561872,0.7213378631471581,0.2901526146815208,3554.0,0.7384640316296076,0.2916755750837755,3581.0,1389.9110429286957,1469.595920562744,1389.9110429286957,79.03741836547852,0.4429998397827148,0.0 -5900,0.13352098,0.2537218,,,,,,,,,,,,,, -6000,0.18172127,0.24040002,,,,,,,,,,,,,, -6100,0.16788426,0.29347673,,,,,,,,,,,,,, -6200,0.11552591,0.2740421,,,,,,,,,,,,,, -6225,,,0.7440728460039411,0.2689230612346104,0.7203090927476083,0.2896457858882597,3554.0,0.7372454419811156,0.2911691588418039,3581.0,1469.9522788524628,1553.7422938346865,1469.9522788524628,83.10330438613892,0.4701013565063476,0.0 -6300,0.18488216,0.22220501,,,,,,,,,,,,,, -6400,0.08610904,0.30031002,,,,,,,,,,,,,, -6500,0.10354298,0.300222,,,,,,,,,,,,,, -6568,,,0.7444088799612862,0.2678774084363665,0.7188717268394766,0.289412430316193,3554.0,0.7361113232162804,0.2909232115352729,3581.0,1549.914607048035,1637.810804605484,1549.914607048035,87.16951441764832,0.4982421398162842,0.0 -6600,0.14165592,0.24702291,,,,,,,,,,,,,, -6700,0.20359981,0.27459517,,,,,,,,,,,,,, -6800,0.1546417,0.27025646,,,,,,,,,,,,,, -6900,0.069448136,0.3444992,,,,,,,,,,,,,, -6916,,,0.7451337405613491,0.2679372344698225,0.7206505049328221,0.2890318965536191,3554.0,0.7378540550387461,0.2905541030896572,3581.0,1629.9265213012695,1721.9272499084473,1629.9265213012695,91.23540997505188,0.5247664451599121,0.0 -7000,0.1078889,0.24856803,,,,,,,,,,,,,, -7100,0.2709581,0.27483004,,,,,,,,,,,,,, -7200,0.12235661,0.27321652,,,,,,,,,,,,,, -7262,,,0.7455902780805316,0.2686879805156162,0.7210564213518219,0.2896817475138049,3554.0,0.7383208606403588,0.2910777339386693,3581.0,1709.958627462387,1806.067808151245,1709.958627462387,95.30541563034058,0.5512409210205078,0.0 -7300,0.08701395,0.3065886,,,,,,,,,,,,,, -7400,0.4260845,0.28050086,,,,,,,,,,,,,, -7500,0.05721874,0.2695001,,,,,,,,,,,,,, -7600,0.17649336,0.26677707,,,,,,,,,,,,,, -7605,,,0.7475149972098214,0.2671380043029785,0.7227935705147369,0.2884691846994583,3554.0,0.7399462604283021,0.2898603714700852,3581.0,1790.0470299720764,1890.2598686218264,1790.0470299720764,99.37048292160034,0.5778157711029053,0.0 -7700,0.26794055,0.22055262,,,,,,,,,,,,,, -7800,0.052461956,0.2887064,,,,,,,,,,,,,, -7900,0.11195721,0.2895236,,,,,,,,,,,,,, -7952,,,0.7480261666434151,0.2665707724434988,0.7228617155625704,0.288152399530019,3554.0,0.7400685693591176,0.289658841258552,3581.0,1870.235821723938,1974.5582411289213,1870.235821723938,103.4391713142395,0.6066532135009766,0.0 -8000,0.1077008,0.2984329,,,,,,,,,,,,,, -8100,0.26025897,0.19676718,,,,,,,,,,,,,, -8200,0.057816956,0.32719222,,,,,,,,,,,,,, -8300,,,0.7468312127249581,0.2678480148315429,0.7222216192362478,0.2889935306169105,3554.0,0.7392413819245671,0.2904865059297333,3581.0,1950.461632490158,2058.8887956142426,1950.461632490158,107.50533604621889,0.6335721015930176,0.0 -8300,0.3771605,0.19884725,,,,,,,,,,,,,, -8400,0.33881542,0.20088787,,,,,,,,,,,,,, -8500,0.06917392,0.25364447,,,,,,,,,,,,,, -8600,0.07602312,0.17427048,,,,,,,,,,,,,, -8643,,,0.7479840006147113,0.2670548132487705,0.7238505057022018,0.2881419922974114,3554.0,0.7408080134302569,0.289656284633744,3581.0,2030.442727804184,2142.975204706192,2030.442727804184,111.5710666179657,0.6612002849578857,0.0 -8700,0.2237709,0.26742703,,,,,,,,,,,,,, -8800,0.08655942,0.2571951,,,,,,,,,,,,,, -8900,0.11698906,0.25174046,,,,,,,,,,,,,, -8989,,,0.7469439506530762,0.2675431455884661,0.7216550261369232,0.2893986570479565,3554.0,0.7387133536808852,0.2908245258176836,3581.0,2110.4105145931244,2227.0471861362457,2110.4105145931244,115.63677453994752,0.6878864765167236,0.0 -9000,0.17695493,0.34202647,,,,,,,,,,,,,, -9100,0.29524967,0.28927562,,,,,,,,,,,,,, -9200,0.09946776,0.24220306,,,,,,,,,,,,,, -9300,0.08451781,0.41661173,,,,,,,,,,,,,, -9334,,,0.7478771890912738,0.2656364100319998,0.7228479079470315,0.2870566518909151,3554.0,0.7400444348209299,0.2884846346167272,3581.0,2190.373128414154,2311.114481449127,2190.373128414154,119.70294833183289,0.7147269248962402,0.0 -9400,0.087082535,0.28070202,,,,,,,,,,,,,, -9500,0.061387457,0.32040358,,,,,,,,,,,,,, -9600,0.13167669,0.29005858,,,,,,,,,,,,,, -9680,,,0.747654846736363,0.2662521089826311,0.7231064744390123,0.2873728703298659,3554.0,0.7403361627556897,0.2888194502015847,3581.0,2270.471269130707,2395.319464445114,2270.471269130707,123.77129817008972,0.7413980960845947,0.0 -9700,0.021539286,0.2862075,,,,,,,,,,,,,, -9800,0.40566912,0.22635219,,,,,,,,,,,,,, -9900,0.06428845,0.28913102,,,,,,,,,,,,,, -10000,0.08734793,0.2296056,,,,,,,,,,,,,, -10028,,,0.7482317515781948,0.2649153130395071,0.7228617155625704,0.2866361378684229,3554.0,0.740131564594387,0.2880712454294366,3581.0,2350.5713725090027,2479.52467250824,2350.5713725090027,127.836492061615,0.7694323062896729,0.0 -10100,0.23413853,0.2264473,,,,,,,,,,,,,, -10200,0.10321075,0.3405385,,,,,,,,,,,,,, -10300,0.038274474,0.2680075,,,,,,,,,,,,,, -10375,,,0.7496601513453892,0.2653673716953822,0.7246820538917417,0.2869357322130434,3554.0,0.7417603732154077,0.2883899372338383,3581.0,2430.7196304798126,2563.780470609665,2430.7196304798126,131.9038667678833,0.7975926399230957,0.0 -10400,0.19551106,0.23483147,,,,,,,,,,,,,, -10500,0.21552621,0.2567733,,,,,,,,,,,,,, -10600,0.0793304,0.2495139,,,,,,,,,,,,,, -10700,0.08665875,0.22050233,,,,,,,,,,,,,, -10721,,,0.7480405398777553,0.2658732107707432,0.722738683525605,0.2874855810024796,3554.0,0.7399122402741902,0.2889256694402751,3581.0,2510.73530125618,2647.9112889766693,2510.73530125618,135.97573614120483,0.8289098739624023,0.0 -10800,0.1903756,0.25923204,,,,,,,,,,,,,, -10900,0.069967926,0.28228003,,,,,,,,,,,,,, -11000,0.03659596,0.33053163,,,,,,,,,,,,,, -11069,,,0.7497758184160505,0.2646873508180891,0.7245000131893641,0.2865176225017146,3554.0,0.7416544948600251,0.2879434823656974,3581.0,2590.721424818039,2732.009222984314,2590.721424818039,140.04378747940063,0.8607263565063477,0.0 -11100,0.096307226,0.32881385,,,,,,,,,,,,,, -11200,0.1370365,0.23997827,,,,,,,,,,,,,, -11300,0.12883238,0.28373235,,,,,,,,,,,,,, -11400,0.058804497,0.28219032,,,,,,,,,,,,,, -11416,,,0.7497409411839077,0.2649575301579067,0.7244434088351154,0.2867852910288143,3554.0,0.7416660167158265,0.2881520688617006,3581.0,2670.772175550461,2816.172864437104,2670.772175550461,144.11624264717102,0.8889884948730469,0.0 -11500,0.10975057,0.25475696,,,,,,,,,,,,,, -11600,0.08042246,0.2279077,,,,,,,,,,,,,, -11700,0.08046747,0.30027822,,,,,,,,,,,,,, -11760,,,0.7492256845746722,0.2650561843599592,0.7242193270346793,0.2865726297064047,3554.0,0.7414381703129364,0.2879775706964709,3581.0,2750.9215309619904,2900.4281883239746,2750.9215309619904,148.1839623451233,0.9151699542999268,0.0 -11800,0.25143287,0.28905344,,,,,,,,,,,,,, -11900,0.14461984,0.29645768,,,,,,,,,,,,,, -12000,0.09426667,0.33122793,,,,,,,,,,,,,, -12100,0.075253725,0.2886997,,,,,,,,,,,,,, -12106,,,0.7497265679495675,0.2645014354160854,0.7241296118809791,0.2866236011030705,3554.0,0.7412864772409942,0.2880726089626675,3581.0,2831.0232417583466,2984.6321744918823,2831.0232417583466,152.245596408844,0.9438767433166504,0.0 -12200,0.058968984,0.25049117,,,,,,,,,,,,,, -12300,0.12830599,0.20955895,,,,,,,,,,,,,, -12400,0.20924942,0.27636135,,,,,,,,,,,,,, -12453,,,0.7496450287955148,0.2649199962615967,0.7241797589423888,0.2867595133784116,3554.0,0.7413394505070162,0.2881726923018186,3581.0,2911.0521461963654,3068.76672244072,2911.0521461963654,156.31152319908142,0.9713306427001952,0.0 -12500,0.14211288,0.33035356,,,,,,,,,,,,,, -12600,0.09494312,0.2651709,,,,,,,,,,,,,, -12700,0.11317049,0.20849246,,,,,,,,,,,,,, -12800,0.104884125,0.27062833,,,,,,,,,,,,,, -12801,,,0.7484604971749442,0.2650728055409023,0.723449810073157,0.286540600846977,3554.0,0.7406805912498254,0.2879016218955075,3581.0,2991.1137623786926,3152.8874640464783,2991.1137623786926,160.32822513580322,1.001964807510376,0.0 -12900,0.12432134,0.22064017,,,,,,,,,,,,,, -13000,0.10737031,0.3249984,,,,,,,,,,,,,, -13100,0.104987755,0.23755398,,,,,,,,,,,,,, -13146,,,0.7495675086975098,0.2642490012305123,0.7237743920802265,0.2863297599315911,3554.0,0.7409998625558504,0.2877257942853777,3581.0,3071.2110888957977,3237.094136953354,3071.2110888957977,164.3973925113678,1.030181884765625,0.0 -13200,0.10463247,0.2404501,,,,,,,,,,,,,, -13300,0.042030822,0.354137,,,,,,,,,,,,,, -13400,0.08708553,0.36226118,,,,,,,,,,,,,, -13494,,,0.7496738433837891,0.2644489322389875,0.7243514954540659,0.2861986734522281,3554.0,0.7415456849081961,0.2876294947509424,3581.0,3151.210128545761,3321.202522754669,3151.210128545761,168.46466898918152,1.060338020324707,0.0 -13500,0.12757233,0.3410573,,,,,,,,,,,,,, -13600,0.06520631,0.24067836,,,,,,,,,,,,,, -13700,0.11901469,0.37378693,,,,,,,,,,,,,, -13800,0.13375673,0.21547893,,,,,,,,,,,,,, -13842,,,0.7498070171901158,0.2644280535834176,0.7244232126213421,0.2862306336170512,3554.0,0.7415598656537978,0.2876833884018954,3581.0,3231.4134809970856,3405.514585494995,3231.4134809970856,172.53335785865784,1.0881555080413818,0.0 -13900,0.13328278,0.23805347,,,,,,,,,,,,,, -14000,0.060847297,0.2629039,,,,,,,,,,,,,, -14100,0.110814676,0.2779418,,,,,,,,,,,,,, -14187,,,0.7504711151123047,0.2647754464830671,0.7246985405968627,0.2869762276824968,3554.0,0.7418224139774156,0.2883037960219736,3581.0,3311.5125164985657,3489.7209939956665,3311.5125164985657,176.60153603553772,1.1152000427246094,0.0 -14200,0.15530613,0.23103261,,,,,,,,,,,,,, -14300,0.15117247,0.22988631,,,,,,,,,,,,,, -14400,0.11163192,0.30695128,,,,,,,,,,,,,, -14500,0.07829643,0.35256878,,,,,,,,,,,,,, -14533,,,0.750499997820173,0.2641300303595407,0.7250736131383653,0.286035317682321,3554.0,0.7422204293275272,0.2874723816344073,3581.0,3391.5887458324432,3573.904182910919,3391.5887458324432,180.6669204235077,1.1448464393615725,0.0 -14600,0.12059893,0.21414773,,,,,,,,,,,,,, -14700,0.11103552,0.27321383,,,,,,,,,,,,,, -14800,0.084368564,0.27729553,,,,,,,,,,,,,, -14879,,,0.7497051102774483,0.2644610064370291,0.724206000281373,0.2863332977037317,3554.0,0.7414732812936331,0.2876913650712964,3581.0,3471.7626218795776,3658.1833543777466,3471.7626218795776,184.73077940940857,1.1741950511932373,0.0 -14900,0.09980755,0.20642616,,,,,,,,,,,,,, -15000,0.0783274,0.20949134,,,,,,,,,,,,,, -15100,0.13523567,0.26531824,,,,,,,,,,,,,, -15200,0.08844927,0.25765666,,,,,,,,,,,,,, -15225,,,0.7510156631469727,0.2643768957683018,0.7252654771692107,0.2865569673365398,3554.0,0.7423924390446105,0.2879501295901983,3581.0,3551.921292543412,3742.449414014816,3551.921292543412,188.79840087890625,1.2018353939056396,0.0 -15300,0.26164898,0.17486545,,,,,,,,,,,,,, -15400,0.06903908,0.3530507,,,,,,,,,,,,,, -15500,0.102912314,0.27717882,,,,,,,,,,,,,, -15571,,,0.7497109685625348,0.2642054557800293,0.7242641846115293,0.2861932809257614,3554.0,0.7415143236438844,0.2875700447020734,3581.0,3632.060835123062,3826.6935093402863,3632.060835123062,192.8629801273346,1.22971510887146,0.0 -15600,0.054680422,0.29489774,,,,,,,,,,,,,, -15700,0.2853597,0.21146339,,,,,,,,,,,,,, -15800,0.03768992,0.3856531,,,,,,,,,,,,,, -15900,0.068681695,0.20755793,,,,,,,,,,,,,, -15919,,,0.749645437513079,0.2644821405410766,0.7241334587788407,0.2863569629950408,3554.0,0.7413944008962231,0.2877211582723925,3581.0,3712.266945838928,3911.0051939487457,3712.266945838928,196.92791748046875,1.258352279663086,0.0 -16000,0.1747704,0.23201804,,,,,,,,,,,,,, -16100,0.09430862,0.3016551,,,,,,,,,,,,,, -16200,0.09793946,0.2647742,,,,,,,,,,,,,, -16264,,,0.7499145099094936,0.2641898904527937,0.7246541638822454,0.2860444025437887,3554.0,0.7418102103549986,0.2874283395110479,3581.0,3792.3345487117767,3995.182499408722,3792.3345487117767,200.9972834587097,1.286785364151001,0.0 -16300,0.12328706,0.24310715,,,,,,,,,,,,,, -16400,0.061328735,0.24021837,,,,,,,,,,,,,, -16500,0.083317116,0.3198439,,,,,,,,,,,,,, -16600,0.097176194,0.27505764,,,,,,,,,,,,,, -16611,,,0.7509322847638812,0.2646711553846086,0.7248380593389491,0.2872992297136589,3554.0,0.7419807883621893,0.2888027469195057,3581.0,3872.3533470630646,4079.307808637619,3872.3533470630646,205.06292033195496,1.3158259391784668,0.0 -16700,0.036283277,0.28129345,,,,,,,,,,,,,, -16800,0.11765247,0.1751317,,,,,,,,,,,,,, -16900,0.085471086,0.31756583,,,,,,,,,,,,,, -16959,,,0.7492737088884626,0.2639563594545637,0.723998817353686,0.2859136423637978,3554.0,0.7411933479213209,0.2872786235622905,3581.0,3952.4735357761374,4163.5395884513855,3952.4735357761374,209.13350009918213,1.3445637226104736,0.0 -17000,0.107754275,0.20597345,,,,,,,,,,,,,, -17100,0.114475824,0.25477728,,,,,,,,,,,,,, -17200,0.094474584,0.2400727,,,,,,,,,,,,,, -17300,0.035124697,0.33919933,,,,,,,,,,,,,, -17303,,,0.7509043557303292,0.26378219468253,0.7252153988024057,0.2859099156814944,3554.0,0.7424263910220609,0.2872969630842467,3581.0,4032.434810400009,4247.611432790756,4032.434810400009,213.2035298347473,1.3733389377593994,0.0 -17400,0.13644245,0.28422412,,,,,,,,,,,,,, -17500,0.050540484,0.22515613,,,,,,,,,,,,,, -17600,0.08556056,0.29186857,,,,,,,,,,,,,, -17648,,,0.7511190005711147,0.2633920737675258,0.724937803904931,0.2859708134485351,3554.0,0.7421531389625803,0.2873458116622452,3581.0,4112.587472200394,4331.875928163528,4112.587472200394,217.2685823440552,1.4080066680908203,0.0 -17700,0.12707318,0.21000393,,,,,,,,,,,,,, -17800,0.08292583,0.27303448,,,,,,,,,,,,,, -17900,0.06407839,0.2872492,,,,,,,,,,,,,, -17995,,,0.7503595352172852,0.2636561734335763,0.7246912589687676,0.2857852006267146,3554.0,0.7418655698041748,0.287149565141982,3581.0,4192.576465368271,4415.970258712769,4192.576465368271,221.33321714401245,1.436694622039795,0.0 -18000,0.08116458,0.3220883,,,,,,,,,,,,,, -18100,0.149893,0.29099485,,,,,,,,,,,,,, -18200,0.05576857,0.36683762,,,,,,,,,,,,,, -18300,0.04633539,0.26587927,,,,,,,,,,,,,, -18335,,,0.7505239759172712,0.2641763687133789,0.7249439864193514,0.2862748042478545,3554.0,0.742102756409697,0.2876740141109327,3581.0,4272.693914175034,4500.195482492447,4272.693914175034,225.4005281925201,1.4654943943023682,0.0 -18400,0.12186308,0.27007094,,,,,,,,,,,,,, -18500,0.13154946,0.27754536,,,,,,,,,,,,,, -18600,0.07103783,0.3929966,,,,,,,,,,,,,, -18682,,,0.7512984957013812,0.2632061072758266,0.7251478720060144,0.2856864521325004,3554.0,0.7423240578530788,0.2871079773784383,3581.0,4352.8466629982,4584.453443288803,4352.8466629982,229.46516680717468,1.49409818649292,0.0 -18700,0.18969053,0.26177168,,,,,,,,,,,,,, -18800,0.046286475,0.2706967,,,,,,,,,,,,,, -18900,0.09507612,0.24229684,,,,,,,,,,,,,, -19000,0.046536714,0.27912453,,,,,,,,,,,,,, -19029,,,0.7505662781851632,0.2636205468858991,0.7250470970209623,0.2857843762914586,3554.0,0.7422262243437587,0.2872036974112503,3581.0,4432.961780309677,4668.678829908371,4432.961780309677,233.5327014923096,1.524968147277832,0.0 -19100,0.18327868,0.19802962,,,,,,,,,,,,,, -19200,0.16136004,0.21557373,,,,,,,,,,,,,, -19300,0.13156013,0.27822903,,,,,,,,,,,,,, -19374,,,0.7511434555053711,0.2634677035467965,0.7254661341094542,0.2856579782188643,3554.0,0.7426140131946384,0.287047504679646,3581.0,4513.029278039932,4752.857128858566,4513.029278039932,237.60093474388125,1.555567741394043,0.0 -19400,0.239779,0.2307708,,,,,,,,,,,,,, -19500,0.27744758,0.31520596,,,,,,,,,,,,,, -19600,0.1119398,0.203998,,,,,,,,,,,,,, -19700,0.057775047,0.19723155,,,,,,,,,,,,,, -19718,,,0.752007007598877,0.2629783153533935,0.7257539645030248,0.2857248695901624,3554.0,0.742857608406346,0.2871675637806304,3581.0,4593.14227104187,4837.079212188721,4593.14227104187,241.6682722568512,1.5856091976165771,0.0 -19800,0.05677628,0.23078403,,,,,,,,,,,,,, -19900,0.06804068,0.25076395,,,,,,,,,,,,,, -20000,0.09105565,0.2784373,,,,,,,,,,,,,, -20064,,,0.7493659428187779,0.2639386143003191,0.7241226737259074,0.2860666767693532,3554.0,0.7411866666084892,0.2874566669139207,3581.0,4673.121878147125,4921.168244838715,4673.121878147125,245.73634243011475,1.6148817539215088,0.0 -20100,0.045265272,0.2984853,,,,,,,,,,,,,, -20200,0.053268176,0.3119312,,,,,,,,,,,,,, -20300,0.07858732,0.27969223,,,,,,,,,,,,,, -20400,0.09456219,0.29218623,,,,,,,,,,,,,, -20412,,,0.7491849490574428,0.2641740356172834,0.7240449114334201,0.2862578195068497,3554.0,0.7412260045422019,0.2876055647427394,3581.0,4753.126408100128,5005.285262107849,4753.126408100128,249.80592560768127,1.6456749439239502,0.0 -20500,0.35593084,0.2571686,,,,,,,,,,,,,, -20600,0.09782623,0.37370855,,,,,,,,,,,,,, -20700,0.12965631,0.25189826,,,,,,,,,,,,,, -20758,,,0.7516568728855678,0.2628540141241891,0.7252162231376618,0.2856588884223762,3554.0,0.7423890983881947,0.2870925694529286,3581.0,4833.162514925003,5089.435174465179,4833.162514925003,253.87785720825195,1.6757760047912598,0.0 -20800,0.03371992,0.31328642,,,,,,,,,,,,,, -20900,0.14303426,0.28727856,,,,,,,,,,,,,, -21000,0.043519307,0.32916945,,,,,,,,,,,,,, -21100,0.0423663,0.29278514,,,,,,,,,,,,,, -21107,,,0.7511718613760812,0.2635410002299717,0.7249950265106219,0.2860896722882667,3554.0,0.7421710694245671,0.2874962434659487,3581.0,4913.222190141678,5173.600088834763,4913.222190141678,257.94165658950806,1.7051050662994385,0.0 -21200,0.11193392,0.19271956,,,,,,,,,,,,,, -21300,0.036249336,0.2613602,,,,,,,,,,,,,, -21400,0.22778581,0.20952173,,,,,,,,,,,,,, -21454,,,0.7499760900224958,0.2635111468178885,0.7241845675647158,0.2857918468297165,3554.0,0.7413911965931305,0.2871669501906765,3581.0,4993.34165430069,5257.827834844589,4993.34165430069,262.0073826313019,1.7356181144714355,0.0 -21500,0.15185237,0.19112045,,,,,,,,,,,,,, -21600,0.097063035,0.3006777,,,,,,,,,,,,,, -21700,0.06417282,0.29504248,,,,,,,,,,,,,, -21798,,,0.7510729517255511,0.2637288229806082,0.7247480007122257,0.2865836895377567,3554.0,0.7419022488480871,0.2879440959556513,3581.0,5073.316593170166,5341.9150631427765,5073.316593170166,266.0779194831848,1.765653133392334,0.0 -21800,0.09813448,0.22234748,,,,,,,,,,,,,, -21900,0.07417483,0.31061587,,,,,,,,,,,,,, -22000,0.06366494,0.3852054,,,,,,,,,,,,,, -22100,0.04176736,0.34660918,,,,,,,,,,,,,, -22146,,,0.750969409942627,0.2630865403584072,0.7246992962375141,0.2856206598748769,3554.0,0.7419124071706577,0.287016347945319,3581.0,5153.433739185333,5426.143560171127,5153.433739185333,270.1434907913208,1.799447774887085,0.0 -22200,0.08726469,0.2985455,,,,,,,,,,,,,, -22300,0.2517891,0.25308546,,,,,,,,,,,,,, -22400,0.05707133,0.3076415,,,,,,,,,,,,,, -22492,,,0.7517211096627372,0.263117824281965,0.7257709320703785,0.2855805422257491,3554.0,0.7428656532524085,0.2869991333382784,3581.0,5233.556937217712,5510.373544454575,5233.556937217712,274.20904064178467,1.8289234638214111,0.0 -22500,0.12184417,0.30437762,,,,,,,,,,,,,, -22600,0.053868555,0.2730912,,,,,,,,,,,,,, -22700,0.062446114,0.22371686,,,,,,,,,,,,,, -22800,0.08060685,0.3063494,,,,,,,,,,,,,, -22835,,,0.7522268976484027,0.2629103319985525,0.7257775267524268,0.2857270849911631,3554.0,0.7428976281066741,0.28713978179105,3581.0,5313.555603504181,5594.484750509262,5313.555603504181,278.2751727104187,1.863584280014038,0.0 -22900,0.2867675,0.19448557,,,,,,,,,,,,,, -23000,0.03817955,0.2572167,,,,,,,,,,,,,, -23100,0.1476985,0.35438335,,,,,,,,,,,,,, -23182,,,0.7522249221801758,0.2629103830882481,0.7257977229662,0.2857639568202201,3554.0,0.7429226489414619,0.287239728776878,3581.0,5393.517927885056,5678.554426193237,5393.517927885056,282.3405730724335,1.893626928329468,0.0 -23200,0.068562485,0.26043808,,,,,,,,,,,,,, -23300,0.12282389,0.17547014,,,,,,,,,,,,,, -23400,0.08579702,0.27354288,,,,,,,,,,,,,, -23500,0.09146203,0.33630618,,,,,,,,,,,,,, -23530,,,0.7511571475437709,0.2631608928952898,0.7253869979248734,0.2856311701493915,3554.0,0.7425341783239667,0.287000394606517,3581.0,5473.649057149887,5762.797847509384,5473.649057149887,286.4092450141907,1.9252281188964844,0.0 -23600,0.0336546,0.19771749,,,,,,,,,,,,,, -23700,0.06659981,0.25360167,,,,,,,,,,,,,, -23800,0.05434733,0.19016398,,,,,,,,,,,,,, -23877,,,0.7517889567783901,0.263093147959028,0.7257241510445976,0.285614219755689,3554.0,0.7428848790709648,0.2869906794322466,3581.0,5553.753219604492,5847.011462926865,5553.753219604492,290.47550415992737,1.956531286239624,0.0 -23900,0.16401926,0.2321508,,,,,,,,,,,,,, -24000,0.030113574,0.24136245,,,,,,,,,,,,,, -24100,0.07306284,0.30202505,,,,,,,,,,,,,, -24200,0.080078505,0.22056061,,,,,,,,,,,,,, -24223,,,0.7519252640860421,0.2628435237067086,0.725507900429094,0.2856898353417804,3554.0,0.7426174902043773,0.2871886644573792,3581.0,5633.818361282349,5931.189614534378,5633.818361282349,294.5453701019287,1.9875144958496087,0.0 -24300,0.09469438,0.26246977,,,,,,,,,,,,,, -24400,0.5037161,0.20046891,,,,,,,,,,,,,, -24500,0.038377937,0.3777632,,,,,,,,,,,,,, -24569,,,0.7520390238080706,0.262832828930446,0.725781236261079,0.2855415580375984,3554.0,0.7429443291198339,0.2869305135284313,3581.0,5713.8310968875885,6015.310056209564,5713.8310968875885,298.6102910041809,2.0182888507843018,0.0 -24600,0.039713763,0.24898964,,,,,,,,,,,,,, -24700,0.042404983,0.29980156,,,,,,,,,,,,,, -24800,0.042103708,0.24264401,,,,,,,,,,,,,, -24900,0.04825705,0.28910032,,,,,,,,,,,,,, -24914,,,0.7516685894557408,0.2628613199506487,0.7254507465180079,0.2855545928388348,3554.0,0.7425912421896816,0.2869498757003106,3581.0,5793.904387712479,6099.492146253586,5793.904387712479,302.67626667022705,2.049443483352661,0.0 -25000,0.061747905,0.19463158,,,,,,,,,,,,,, -25100,0.028836604,0.31957108,,,,,,,,,,,,,, -25200,0.050513677,0.22582835,,,,,,,,,,,,,, -25262,,,0.752216134752546,0.2625758647918701,0.72544092318954,0.2857069746456457,3554.0,0.7426024231621754,0.2871179652593549,3581.0,5873.891070127487,6183.589738368988,5873.891070127487,306.7443902492523,2.080031633377075,0.0 -25300,0.024554664,0.23691128,,,,,,,,,,,,,, -25400,0.046310488,0.26086587,,,,,,,,,,,,,, -25500,0.025255522,0.24090725,,,,,,,,,,,,,, -25600,0.0933058,0.2642578,,,,,,,,,,,,,, -25609,,,0.7520290102277484,0.2627349751336233,0.7257105495128728,0.2855440653906689,3554.0,0.7428054532602625,0.2869750669767523,3581.0,5953.967887401581,6267.775489091873,5953.967887401581,310.8091578483581,2.1123173236846924,0.0 -25700,0.055752497,0.2867242,,,,,,,,,,,,,, -25800,0.01979,0.2652344,,,,,,,,,,,,,, -25900,0.06221302,0.20966992,,,,,,,,,,,,,, -25953,,,0.7518808501107352,0.2627204656600952,0.7255820906021384,0.2855149388782885,3554.0,0.742740344548485,0.2868983341441811,3581.0,6033.956508636475,6351.873990535736,6033.956508636475,314.874475479126,2.1450178623199463,0.0 -26000,0.041070476,0.23288915,,,,,,,,,,,,,, -26100,0.10087962,0.3219137,,,,,,,,,,,,,, -26200,0.045651063,0.2344034,,,,,,,,,,,,,, -26300,,,0.7524644306727818,0.2623820304870605,0.7256400001538759,0.2856000171461733,3554.0,0.7427877955049218,0.2870122232572955,3581.0,6113.944706916809,6435.973375320435,6113.944706916809,318.94093799591064,2.177800178527832,0.0 -26300,0.068578005,0.21791856,,,,,,,,,,,,,, -26400,0.061539162,0.33099872,,,,,,,,,,,,,, -26500,0.06790694,0.22075576,,,,,,,,,,,,,, -26600,0.05501678,0.25967005,,,,,,,,,,,,,, -26646,,,0.7519332340785435,0.2625506094523838,0.7253897457090602,0.2855287636674785,3554.0,0.7425352009738899,0.2869487848737259,3581.0,6193.946717262268,6520.0908126831055,6193.946717262268,323.00680804252625,2.215407609939575,0.0 -26700,0.033636164,0.24431808,,,,,,,,,,,,,, -26800,0.06771834,0.21918848,,,,,,,,,,,,,, -26900,0.047119703,0.27011052,,,,,,,,,,,,,, -26988,,,0.7521262850080218,0.2627148628234863,0.7255871740028841,0.285691638575153,3554.0,0.7427188689000978,0.2870836723985968,3581.0,6274.016363620758,6604.268972635269,6274.016363620758,327.07042241096497,2.2482399940490723,0.0 -27000,0.037930176,0.272118,,,,,,,,,,,,,, -27100,0.07700525,0.33517092,,,,,,,,,,,,,, -27200,0.04893769,0.3272786,,,,,,,,,,,,,, -27300,0.059336018,0.21772116,,,,,,,,,,,,,, -27334,,,0.7526041439601353,0.2621868678501674,0.725647281781971,0.2855241954762679,3554.0,0.7427528208775481,0.2869639882692509,3581.0,6354.049415111542,6688.422640323639,6354.049415111542,331.14239048957825,2.2848925590515137,0.0 -27400,0.042238932,0.2662777,,,,,,,,,,,,,, -27500,0.049707912,0.22927228,,,,,,,,,,,,,, -27600,0.030359238,0.24895346,,,,,,,,,,,,,, -27684,,,0.752434526171003,0.2624724251883371,0.7257336995946468,0.2855600368862549,3554.0,0.7428624489493159,0.2869821914378839,3581.0,6434.165660142899,6772.650623321533,6434.165660142899,335.2097702026367,2.317322254180908,0.0 -27700,0.02651067,0.3319739,,,,,,,,,,,,,, -27800,0.040453475,0.26942033,,,,,,,,,,,,,, -27900,0.03810173,0.23904546,,,,,,,,,,,,,, -28000,0.030701917,0.23580639,,,,,,,,,,,,,, -28030,,,0.7522539411272321,0.2626483099801199,0.7259011083462296,0.285618221216411,3554.0,0.7429992113323792,0.286982634586184,3581.0,6514.367031574249,6856.961793422699,6514.367031574249,339.2761797904968,2.3487956523895264,0.0 -28100,0.06664809,0.24979943,,,,,,,,,,,,,, -28200,0.05187049,0.290888,,,,,,,,,,,,,, -28300,0.033116467,0.19184887,,,,,,,,,,,,,, -28374,,,0.7526415416172573,0.2622614758355276,0.7256958488674733,0.2856316510116242,3554.0,0.742843768544052,0.2870561972039933,3581.0,6594.354041337967,6941.054801940918,6594.354041337967,343.3381669521332,2.380867481231689,0.0 -28400,0.039257463,0.22436798,,,,,,,,,,,,,, -28500,0.03004272,0.2505528,,,,,,,,,,,,,, -28600,0.03661547,0.33668694,,,,,,,,,,,,,, -28700,0.049088072,0.2985262,,,,,,,,,,,,,, -28720,,,0.7523200852530343,0.2624066046306065,0.7255589405203644,0.2855549878328116,3554.0,0.7427183234868053,0.2869618747927429,3581.0,6674.474102973938,7025.281142234802,6674.474102973938,347.3995735645294,2.413620710372925,0.0 -28800,0.034155298,0.25702742,,,,,,,,,,,,,, -28900,0.035252508,0.27573845,,,,,,,,,,,,,, -29000,0.0320485,0.29425693,,,,,,,,,,,,,, -29066,,,0.7522360937935966,0.2625972884041922,0.7256316881067107,0.2856577206140968,3554.0,0.7427424580249931,0.2870614808952632,3581.0,6754.4682993888855,7109.388455867767,6754.4682993888855,351.46641421318054,2.447869300842285,0.0 -29100,0.051702894,0.19868603,,,,,,,,,,,,,, -29200,0.040985048,0.24399391,,,,,,,,,,,,,, -29300,0.045015104,0.2240863,,,,,,,,,,,,,, -29400,0.03200459,0.22617114,,,,,,,,,,,,,, -29410,,,0.7528103419712612,0.2621924536568777,0.7257646808613534,0.28564556166907,3554.0,0.7428698802054244,0.2870381644770141,3581.0,6834.473086357117,7193.50737786293,6834.473086357117,355.53622817993164,2.4804461002349854,0.0 -29500,0.023977017,0.2923906,,,,,,,,,,,,,, -29600,0.04389133,0.2697394,,,,,,,,,,,,,, -29700,0.031059245,0.3406924,,,,,,,,,,,,,, -29757,,,0.7524969237191337,0.2622502020427159,0.725686918568866,0.2854467938304551,3554.0,0.7428138389896328,0.2868713361862084,3581.0,6914.607104301453,7277.75493144989,6914.607104301453,359.60470604896545,2.513329029083252,0.0 -29800,0.04440384,0.29956555,,,,,,,,,,,,,, -29900,0.038065337,0.20572284,,,,,,,,,,,,,, -30000,0.031191858,0.27505064,,,,,,,,,,,,,, -30100,0.028251553,0.27154306,,,,,,,,,,,,,, -30105,,,0.7526117733546666,0.2624603680201939,0.7260031885287704,0.2855814696029122,3554.0,0.743100249144792,0.2869789530464605,3581.0,6994.64807009697,7361.90936923027,6994.64807009697,363.6717460155487,2.547752857208252,0.0 -30200,0.01977635,0.2437498,,,,,,,,,,,,,, -30300,0.072798565,0.16370976,,,,,,,,,,,,,, -30400,0.04981102,0.23757958,,,,,,,,,,,,,, -30448,,,0.7527153832571847,0.2622091088976179,0.7258541212366347,0.2855970632781724,3554.0,0.7429354661538328,0.2869994401332554,3581.0,7074.760458707809,7446.134158611298,7074.760458707809,367.73788118362427,2.5821104049682617,0.0 -30500,0.0534702,0.28213996,,,,,,,,,,,,,, -30600,0.026656339,0.2641766,,,,,,,,,,,,,, -30700,0.023677917,0.28804672,,,,,,,,,,,,,, -30796,,,0.7526858874729702,0.2621920279094151,0.7258120801385762,0.2854876155992807,3554.0,0.7429225807648003,0.2869368539579552,3581.0,7154.866800308228,7530.353799581528,7154.866800308228,371.8072714805603,2.6142380237579346,0.0 -30800,0.02138751,0.2834631,,,,,,,,,,,,,, -30900,0.032332286,0.22084671,,,,,,,,,,,,,, -31000,0.02165911,0.26271918,,,,,,,,,,,,,, -31100,0.04638457,0.28111476,,,,,,,,,,,,,, -31141,,,0.7527549607413155,0.262384397642953,0.7260390471124085,0.2855822080699124,3554.0,0.7431579947771223,0.2870004968715093,3581.0,7234.82731628418,7614.42972612381,7234.82731628418,375.8769974708557,2.6478710174560547,0.0 -31200,0.027111258,0.18896839,,,,,,,,,,,,,, -31300,0.026077468,0.29330721,,,,,,,,,,,,,, -31400,0.020179551,0.20862283,,,,,,,,,,,,,, -31487,,,0.7522657939365932,0.2623383658272879,0.7256374584535031,0.2855801987527258,3554.0,0.7427517982276249,0.286932354298293,3581.0,7314.991788864136,7698.705207109451,7314.991788864136,379.9435257911682,2.680339097976685,0.0 -31500,0.022047453,0.30413526,,,,,,,,,,,,,, -31600,0.028078476,0.20292488,,,,,,,,,,,,,, -31700,0.024763566,0.25428772,,,,,,,,,,,,,, -31800,0.029544994,0.27130225,,,,,,,,,,,,,, -31835,,,0.7526765550885882,0.2621617998395647,0.725747644599395,0.2855543352340672,3554.0,0.7428622444193312,0.2869863502142383,3581.0,7395.034659147263,7782.868022441864,7395.034659147263,384.0136480331421,2.718026161193848,0.0 -31900,0.026061505,0.25523928,,,,,,,,,,,,,, -32000,0.028641282,0.2186299,,,,,,,,,,,,,, -32100,0.019944953,0.42460003,,,,,,,,,,,,,, -32182,,,0.7528634071350098,0.262218884059361,0.7260420696750141,0.2855046690348902,3554.0,0.7431491318111212,0.2869042314254049,3581.0,7475.214136600494,7867.158411026001,7475.214136600494,388.08017563819885,2.7506425380706787,0.0 -32200,0.023258885,0.27727306,,,,,,,,,,,,,, -32300,0.034935277,0.2006554,,,,,,,,,,,,,, -32400,0.020860987,0.33495608,,,,,,,,,,,,,, -32500,0.028947152,0.27549326,,,,,,,,,,,,,, -32526,,,0.7528461047581264,0.262486457824707,0.7259932965056978,0.2858738853613446,3554.0,0.7430746147200503,0.2873262790487119,3581.0,7555.255586385727,7951.3141577243805,7555.255586385727,392.1504149436951,2.783066987991333,0.0 -32600,0.020251118,0.26157895,,,,,,,,,,,,,, -32700,0.043286122,0.1928366,,,,,,,,,,,,,, -32800,0.01851801,0.31952146,,,,,,,,,,,,,, -32873,,,0.7528681755065918,0.2620720863342285,0.7258712261931978,0.2855324388288284,3554.0,0.7429696908379293,0.2869643632408894,3581.0,7635.29710650444,8035.4664006233215,7635.29710650444,396.2140109539032,2.818079233169556,0.0 -32900,0.02400496,0.2320377,,,,,,,,,,,,,, -33000,0.020811006,0.34323734,,,,,,,,,,,,,, -33100,0.02126694,0.28603944,,,,,,,,,,,,,, -33200,0.033135127,0.34288195,,,,,,,,,,,,,, -33220,,,0.7528653144836426,0.2621460471834455,0.7259631395742473,0.2855008393106799,3554.0,0.7430804097362818,0.2869187530543144,3581.0,7715.372609138489,8119.656742095947,7715.372609138489,400.2839741706848,2.851149559020996,0.0 -33300,0.02105022,0.30484176,,,,,,,,,,,,,, -33400,0.028136652,0.2647494,,,,,,,,,,,,,, -33500,0.03318997,0.19496873,,,,,,,,,,,,,, -33565,,,0.7528884070260184,0.2621467964989798,0.7260305976760341,0.2854834080547446,3554.0,0.7431213839098716,0.2868944480744729,3581.0,7795.492541074753,8203.884775876999,7795.492541074753,404.3473062515259,2.8837192058563232,0.0 -33600,0.017951053,0.23557778,,,,,,,,,,,,,, -33700,0.016578613,0.34188455,,,,,,,,,,,,,, -33800,0.02234883,0.20631953,,,,,,,,,,,,,, -33900,0.020228013,0.31814328,,,,,,,,,,,,,, -33912,,,0.7530028479439872,0.2619972058704921,0.7259071534714406,0.2855185109977314,3554.0,0.7430133239013195,0.2869322861216315,3581.0,7875.469173669815,8287.974950790405,7875.469173669815,408.41320180892944,2.919473171234131,0.0 -34000,0.021166159,0.26745793,,,,,,,,,,,,,, -34100,0.026463887,0.21051794,,,,,,,,,,,,,, -34200,0.016462281,0.2785387,,,,,,,,,,,,,, -34257,,,0.7529234204973493,0.2620432717459542,0.7259287235773073,0.2854790459473481,3554.0,0.7430517073617705,0.2869017088889277,3581.0,7955.515779495239,8372.131846427917,7955.515779495239,412.47713685035706,2.95334792137146,0.0 -34300,0.014894932,0.2508824,,,,,,,,,,,,,, -34400,0.032773476,0.19793747,,,,,,,,,,,,,, -34500,0.020724451,0.22742733,,,,,,,,,,,,,, -34600,0.039249055,0.21482667,,,,,,,,,,,,,, -34601,,,0.7529683794294085,0.2620537621634347,0.7260481834948298,0.2854372968013594,3554.0,0.7431550631806758,0.2868458040264591,3581.0,8035.698899507523,8456.37722826004,8035.698899507523,416.4928929805756,2.98781418800354,0.0 -34700,0.023586135,0.30721936,,,,,,,,,,,,,, -34800,0.020535724,0.2568051,,,,,,,,,,,,,, -34900,0.017312434,0.17518541,,,,,,,,,,,,,, -34948,,,0.7530630656651088,0.2619693449565342,0.7260910489281444,0.2854283149817986,3554.0,0.7431837655551871,0.2868427360766895,3581.0,8115.888223648071,8540.681955337524,8115.888223648071,420.558625459671,3.02548623085022,0.0 -35000,0.024856627,0.30820802,,,,,,,,,,,,,, -35100,0.016818425,0.2516317,,,,,,,,,,,,,, -35200,0.020422483,0.24898125,,,,,,,,,,,,,, -35295,,,0.7529725347246442,0.2619802270616804,0.7259849844585327,0.2854171349348885,3554.0,0.7430881136990366,0.286832270959142,3581.0,8195.908843517303,8624.81097126007,8195.908843517303,424.6205706596375,3.0597083568573,0.0 -35300,0.022394393,0.33331847,,,,,,,,,,,,,, -35400,0.018098412,0.3168604,,,,,,,,,,,,,, -35500,0.030555915,0.27511728,,,,,,,,,,,,,, -35600,0.015892636,0.24915604,,,,,,,,,,,,,, -35642,,,0.7529713766915458,0.2619913646153041,0.7260084093187253,0.2854261854490539,3554.0,0.7431127936505166,0.286830157482634,3581.0,8275.888614416122,8708.90253329277,8275.888614416122,428.6862292289734,3.0939202308654785,0.0 -35700,0.023827977,0.16569167,,,,,,,,,,,,,, -35800,0.02381084,0.25800103,,,,,,,,,,,,,, -35900,0.01940483,0.27986857,,,,,,,,,,,,,, -35985,,,0.7529499190194267,0.2619766848427908,0.7259724133458779,0.285419127078424,3554.0,0.7430827959194359,0.2868288280377338,3581.0,8355.901663303375,8793.024923086166,8355.901663303375,432.7504985332489,3.1271345615386963,0.0 -36000,0.021138242,0.21557821,,,,,,,,,,,,,, -36100,0.017434163,0.2635222,,,,,,,,,,,,,, -36189,,,0.7529680388314384,0.2619822536196027,0.7259976242657921,0.2854213768267269,3554.0,0.7431058396310388,0.2868319982524958,3581.0,8401.65019273758,8842.886168718338,8401.65019273758,436.82040643692017,3.1628973484039307,0.0 -36189,,,,,,,,,,,8401.65019273758,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 18949ab2c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,109 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -4.015159368515015,0.0,45.728089332580566,1,0,45.728089332580566,1.0888061012112538,3581,0.2660334975118245,49.74335145950317,1.0853406361171178,0.246516398021153,1.0927659497880908,3554,0.2452213628679832 -8.081871747970581,0.0230159759521484,125.74299168586732,260,0,125.74299168586732,0.3573848209844492,3581,0.663606125427604,133.8574230670929,0.3349068164825439,0.6689420427594867,0.3565587616396138,3554,0.6439759915104459 -12.144757747650146,0.0539445877075195,205.8846409320832,499,0,205.8846409320832,0.3316295531101647,3581,0.6938741087946803,218.1014540195465,0.3091085297720773,0.6992927278791156,0.3299610006990363,3554,0.6752175420740715 -16.21077609062195,0.0866448879241943,285.99159264564514,738,0,285.99159264564514,0.3170625867207134,3581,0.7101002905962022,302.3159143924713,0.2945798805781773,0.7161976950509208,0.3149693306067986,3554,0.6924763745515616 -20.27893114089965,0.1195316314697265,366.0610370635986,1052,0,366.0610370635986,0.3081722137016545,3581,0.7198837778815275,386.4980580806732,0.2856094326291765,0.7262726511274066,0.3057588267071434,3554,0.7028399173906162 -24.34521126747132,0.1439287662506103,446.2296071052551,1398,0,446.2296071052551,0.3022956581829447,3581,0.7254137914295937,470.7702660560608,0.2798751252038138,0.7319378171648298,0.3001379937218627,3554,0.7083462021226083 -28.41634178161621,0.1698086261749267,526.3007123470306,1741,0,526.3007123470306,0.3014767542127722,3581,0.726199595630585,554.9513428211212,0.2790592227663312,0.732933589390346,0.2994219555100766,3554,0.7093250315445625 -32.47808003425598,0.1965868473052978,606.3087928295135,2085,0,606.3087928295135,0.2967336697806129,3581,0.7314931043397445,639.0613944530487,0.2745100430079868,0.7381625175476074,0.2949065561031584,3554,0.7143384324880416 -36.54174590110779,0.2217030525207519,686.3964641094208,2428,0,686.3964641094208,0.295681840246265,3581,0.7324901198382086,723.2513434886932,0.2728924070085798,0.7399488857814244,0.2938475944248558,3554,0.7154231202957935 -40.60655450820923,0.2466726303100586,766.532796382904,2773,0,766.532796382904,0.2940835747696174,3581,0.7338204510349763,807.490364074707,0.2716763360159738,0.7410245622907367,0.2923744042803883,3554,0.7166780333339196 -44.67102932929993,0.2755327224731445,846.7297871112823,3119,0,846.7297871112823,0.2942355746365366,3581,0.7330739847676976,891.793835401535,0.2724208831787109,0.7391742978777204,0.2926273378147861,3554,0.7159810578749296 -48.73697257041931,0.3005204200744629,926.8994541168212,3466,0,926.8994541168212,0.2924607656948129,3581,0.7364933852275901,976.0675466060638,0.2697564533778599,0.7439769336155483,0.2907622449506717,3554,0.7193141200935566 -52.80117058753967,3.2227320671081543,1004.2445499897004,3799,0,1004.2445499897004,0.2917038683983175,3581,0.7361285719116518,1060.412445783615,0.2692778621401105,0.7434635162353516,0.2901401122634707,3554,0.7188673303847777 -56.872363567352295,3.248110055923462,1084.4557919502258,4144,0,1084.4557919502258,0.2928426913527995,3581,0.7354656220547682,1144.7332124710083,0.2704278060368129,0.7428201266697475,0.2911202468829136,3554,0.7184732294377814 -60.93706798553467,3.2746899127960205,1164.5836515426636,4487,0,1164.5836515426636,0.2912233592877338,3581,0.7379175275106464,1228.9657146930697,0.2683043990816389,0.7458201817103794,0.2896644021261255,3554,0.7207160395856781 -65.0037591457367,3.3040175437927246,1244.6789507865906,4832,0,1244.6789507865906,0.2904094322138543,3581,0.7380615166198339,1313.1699743270874,0.2678250585283552,0.7455734525408063,0.2888621178381753,3554,0.7208194936603123 -69.06503939628601,3.3317201137542725,1324.7413284778595,5175,0,1324.7413284778595,0.2902906343811086,3581,0.7379987940912106,1397.3338668346405,0.2677472148622785,0.7453188896179199,0.288773604840057,3554,0.7208296604618036 -73.12762188911438,3.3577311038970947,1404.8958258628843,5518,0,1404.8958258628843,0.2918743441405159,3581,0.738994855116413,1481.5906162261963,0.2686936003821237,0.7465689522879464,0.2903028497819358,3554,0.7218854965355938 -77.1900954246521,3.3832943439483643,1484.859869003296,5864,0,1484.859869003296,0.2900212002146746,3581,0.7382644785412594,1565.6560204029083,0.2675286020551409,0.7454586029052734,0.2885642580323227,3554,0.7209607984621201 -81.2557954788208,3.411367416381836,1565.0343101024628,6210,0,1565.0343101024628,0.29393297252426,3581,0.7349887263072465,1649.9378747940063,0.2713038921356201,0.742236818586077,0.2926282308446469,3554,0.7176969804050014 -85.31988334655762,3.442832946777344,1645.018116235733,6551,0,1645.018116235733,0.2900438348663083,3581,0.7383212697003281,1734.030422449112,0.2674429076058524,0.7458180700029645,0.2885521334345983,3554,0.7210834183314575 -89.38604640960693,3.468510150909424,1725.143525838852,6896,0,1725.143525838852,0.2896101972105382,3581,0.739134685449246,1818.260770559311,0.2670606545039585,0.7465430668422154,0.2881888076704945,3554,0.7218675672437747 -93.45235657691956,3.498748540878296,1805.1661853790283,7242,0,1805.1661853790283,0.2889781313813355,3581,0.7398020667891302,1902.393094301224,0.2663912432534354,0.7473367963518415,0.287675384195185,3554,0.7224701563159468 -97.51888036727904,3.5248067378997803,1885.309853553772,7584,0,1885.309853553772,0.2897599813359571,3581,0.736945396493647,1986.6425902843475,0.2673313447407314,0.7440915788922992,0.288352300829611,3554,0.7194599587392726 -101.58399534225464,3.5519094467163086,1965.398687839508,7930,0,1965.398687839508,0.2890621250283615,3581,0.7405779853741972,2070.8367421627045,0.2660749299185617,0.7484252112252372,0.2875628967800366,3554,0.7234536569710186 -105.65423774719238,3.579378604888916,2045.390664577484,8272,0,2045.390664577484,0.2885723438918074,3581,0.7402424198460625,2154.939321041107,0.265859808240618,0.7478161539350238,0.2872010479499332,3554,0.7229330892568233 -109.7238380908966,3.606642246246338,2125.487339735031,8616,0,2125.487339735031,0.2881952246884599,3581,0.7409527524827213,2239.1462643146515,0.2655953679765974,0.7485686029706683,0.2868307153361529,3554,0.7237002019071821 -113.7906687259674,3.633155345916748,2205.603874444961,8962,0,2205.603874444961,0.2887139127295099,3581,0.7410736978803057,2323.369287729264,0.2655679157802036,0.7491508211408343,0.2873204906938045,3554,0.7238994162607274 -117.85758423805235,3.660656452178955,2285.7724466323853,9307,0,2285.7724466323853,0.2882544361190135,3581,0.7404068619537141,2407.645413160324,0.2655667236873081,0.7480177879333496,0.2868091967512398,3554,0.7231749629598692 -121.92598867416382,3.68761682510376,2365.7505254745483,9648,0,2365.7505254745483,0.2886442702697396,3581,0.7401621759154217,2491.731642961502,0.2661188500268118,0.7472796440124512,0.2871454396674521,3554,0.7231537363270258 -125.99807476997375,3.715875148773194,2445.8985407352448,9994,0,2445.8985407352448,0.2886798925753979,3581,0.7414724631736945,2575.993050098419,0.2653477191925049,0.749699320111956,0.287246506604574,3554,0.7242661767550647 -130.0650417804718,3.743248224258423,2526.036517381668,10339,0,2526.036517381668,0.2882380737202422,3581,0.741316611325398,2660.2385306358337,0.2654490981783186,0.7490485736301967,0.2868410367005047,3554,0.724125696288513 -134.1287350654602,3.7703020572662354,2606.084993124008,10683,0,2606.084993124008,0.2878581251854405,3581,0.7408440788842153,2744.390327692032,0.2652280500956944,0.7484066826956612,0.28647137385912,3554,0.7236086319991559 -138.194429397583,3.802428722381592,2686.2606456279755,11025,0,2686.2606456279755,0.2886463837462475,3581,0.7404084300169296,2828.676636695862,0.2652821029935564,0.7488865852355957,0.287192426777047,3554,0.7232043642506683 -142.26156210899353,3.828813552856445,2766.471347808838,11372,0,2766.471347808838,0.288206951074246,3581,0.7402820986630829,2912.99440908432,0.2654268571308681,0.7482709203447614,0.2867865618790007,3554,0.7230311851522931 -146.33048748970032,3.8615622520446777,2846.460176229477,11718,0,2846.460176229477,0.2878480350395315,3581,0.7413314738376152,2997.098103761673,0.2650367191859654,0.74904693875994,0.2864352576707143,3554,0.7241317414137239 -150.39404916763306,3.889822244644165,2926.4845530986786,12059,0,2926.4845530986786,0.288100527305571,3581,0.7402561233550335,3081.2274656295776,0.2649466480527605,0.7484876087733677,0.2866702275659908,3554,0.7230258956677336 -154.4581105709076,3.916551351547241,3006.467269420624,12404,0,3006.467269420624,0.2878106742290037,3581,0.7411182854169576,3165.3140320777893,0.2648206268038068,0.7491763659885952,0.2864063544157991,3554,0.7238350494161508 -158.52502942085266,3.9443979263305664,3086.66849565506,12749,0,3086.66849565506,0.2876451072064367,3581,0.7410273377504538,3249.6230750083923,0.2648651259286063,0.748763016292027,0.286322169177775,3554,0.7238626646472285 -162.59536600112915,3.9718716144561768,3166.8827443122864,13091,0,3166.8827443122864,0.2878322521423834,3581,0.7406824320196872,3333.9484016895294,0.2645631006785801,0.7491682597569057,0.286429796449643,3554,0.7234744714362338 -166.6649980545044,3.9991109371185294,3247.015244960785,13437,0,3247.015244960785,0.287715806404461,3581,0.7417780309707483,3418.19149851799,0.2646686179297311,0.7497834478105817,0.2863316318595684,3554,0.7245596401062183 -170.73585772514343,4.028689861297607,3327.116854429245,13782,0,3327.116854429245,0.2886482586044401,3581,0.7406301405202806,3502.4071373939514,0.2658377545220511,0.7486034802028111,0.2873000368752638,3554,0.7232885838359947 -174.8053970336914,4.05685567855835,3407.216105222702,14123,0,3407.216105222702,0.2883330097214465,3581,0.7391924310815764,3586.617163658142,0.2655568974358694,0.7467929295131138,0.2870196941936023,3554,0.7217877441131472 -178.87139344215393,4.088459253311157,3487.2479860782623,14468,0,3487.2479860782623,0.2875891000789758,3581,0.742357123533929,3670.760345458984,0.2643441472734724,0.7505397796630859,0.286208908948324,3554,0.7251111203925155 -182.9370470046997,4.116533517837524,3567.3677020072937,14813,0,3567.3677020072937,0.2896819190584858,3581,0.7358483658326934,3754.986970663071,0.2671321630477905,0.7425896780831474,0.2882850488116383,3554,0.7190899009039111 -187.00343370437625,4.144828796386719,3647.348000049591,15155,0,3647.348000049591,0.287650186367722,3581,0.7413178385053057,3839.075215816498,0.2646791424070085,0.7493658065795898,0.2862461414240556,3554,0.7241403969339125 -191.01669716835025,4.177236795425415,3727.449832677841,15501,0,3727.449832677841,0.287321302152419,3581,0.7425276333644583,3923.2359850406647,0.2639762333461216,0.7509126663208008,0.285933735535664,3554,0.7253296379299733 -195.0844044685364,4.205004215240479,3807.651174545288,15847,0,3807.651174545288,0.287013484525534,3581,0.743033708723122,4007.545920372009,0.2639362130846296,0.7511205673217773,0.2856074361634778,3554,0.7259452102824282 -199.1486778259277,4.234395503997803,3887.811373949051,16190,0,3887.811373949051,0.2872470577679942,3581,0.7431809703120636,4091.81326007843,0.2642698287963867,0.7509918894086566,0.2859001095266777,3554,0.7260166526712859 -203.21332693099976,4.2629711627960205,3967.889343976976,16535,0,3967.889343976976,0.2870917854213208,3581,0.7422342010131597,4175.997426271439,0.2636375086648123,0.7507597378322056,0.2857542193600081,3554,0.7249813562842924 -207.28168106079104,4.291972875595093,4048.056460380554,16881,0,4048.056460380554,0.2870818998053965,3581,0.7419927874546216,4260.275257110596,0.2640057461602347,0.7500498635428292,0.2857108902381119,3554,0.7248276864536438 -211.35364246368408,4.321918964385986,4128.021166086197,17222,0,4128.021166086197,0.2876822634869799,3581,0.7408937796704831,4344.355574369431,0.2648025410515921,0.7488697596958706,0.2863705988740679,3554,0.7235759333673326 -215.41570520401,4.34968638420105,4208.193595170975,17570,0,4208.193595170975,0.2870471978846691,3581,0.7430041882286722,4428.63065123558,0.2635806117738996,0.7515374592372349,0.2856503187704435,3554,0.7258187435152293 -219.47929334640503,4.378240585327148,4288.185995578766,17916,0,4288.185995578766,0.2869160941645141,3581,0.7422369962562831,4512.728018522263,0.2637154204504831,0.7505055155072894,0.2855568254134865,3554,0.7249672051957302 -223.54808259010315,4.410733461380005,4368.278284549713,18262,0,4368.278284549713,0.2869031746871509,3581,0.7425056804794401,4596.934916496277,0.2638850552695138,0.7505534717014858,0.2856249189403665,3554,0.7251971947321679 -227.6206016540528,4.440626382827759,4448.341143369675,18605,0,4448.341143369675,0.2870902173581053,3581,0.7427279363960835,4681.113558769226,0.2634532962526594,0.7513707705906459,0.2857273769432329,3554,0.725492650226857 -231.68354749679563,4.469612121582031,4528.50818157196,18952,0,4528.50818157196,0.2870987053524679,3581,0.7412419578810039,4765.386015415192,0.2640271357127598,0.7493149893624442,0.2857467831690525,3554,0.7240767170353827 -235.7559115886688,4.49866247177124,4608.514596700668,19298,0,4608.514596700668,0.2868421224867356,3581,0.7421114148457135,4849.507258892059,0.263763632093157,0.7502527918134417,0.2855173260158009,3554,0.7248332507166221 -239.829217672348,4.528627872467041,4688.5540153980255,19641,0,4688.5540153980255,0.2883027051953888,3581,0.7414645546809551,4933.662929773331,0.2645717178072248,0.7502414839608329,0.2868963873782182,3554,0.7243462746641108 -243.896320104599,4.558478355407715,4768.637889385223,19987,0,4768.637889385223,0.2870272562111666,3581,0.7424568659897725,5017.856953859329,0.2637774603707449,0.7504786082676479,0.2857269819492561,3554,0.7253230432479248 -247.96650791168213,4.588074684143066,4848.670360326767,20331,0,4848.670360326767,0.2868187038034941,3581,0.7427785234789515,5102.002462863922,0.2635509627205984,0.7510354178292411,0.2854417791243142,3554,0.7256886359339828 -252.0331449508667,4.625176191329956,4928.62136054039,20674,0,4928.62136054039,0.2869164691361526,3581,0.7417379430937587,5186.070428848267,0.2634921755109514,0.7502099445887974,0.285569877388374,3554,0.7245509158914252 -256.1032905578613,4.656060457229614,5008.828974246979,21021,0,5008.828974246979,0.2872547958190798,3581,0.742762024726857,5270.392125844955,0.263641391481672,0.7512658664158413,0.2858189640249103,3554,0.7255388130011958 -260.1727867126465,4.685746908187866,5088.926310777664,21366,0,5088.926310777664,0.286783047409505,3581,0.7427312770524993,5354.601647377014,0.2635410853794643,0.7510338510785785,0.2854907240301421,3554,0.7255325617921707 -264.2384469509125,4.715479373931885,5168.926723718643,21708,0,5168.926723718643,0.2869437738891022,3581,0.7430581159679559,5438.710728645325,0.263600264276777,0.7513729504176548,0.2855925981288689,3554,0.7258484882790518 -268.303653717041,4.745387554168701,5248.983458995819,22054,0,5248.983458995819,0.2867228133290282,3581,0.7430818414461743,5522.876201629639,0.263126083782741,0.7516428402491978,0.2853465512285892,3554,0.7258781643482696 -272.3733148574829,4.775863885879517,5329.01037311554,22399,0,5329.01037311554,0.2868253169396642,3581,0.7435324891790003,5607.016135931015,0.2633925335747855,0.7519129344395229,0.2853977458827201,3554,0.7263766123997608 -276.4347996711731,4.80604100227356,5409.231924772263,22742,0,5409.231924772263,0.2866503756261344,3581,0.7425101801391022,5691.342561483383,0.2632285526820591,0.7509878022330148,0.2852576432364941,3554,0.7253110216921075 -280.50565695762634,4.836712121963501,5489.367866516113,23085,0,5489.367866516113,0.2867670258940414,3581,0.7431136799471167,5775.593355178833,0.2628756761550903,0.7520301001412528,0.285353489383661,3554,0.7259378599597285 -284.56813645362854,4.866225004196167,5569.326278924942,23431,0,5569.326278924942,0.2865393840211359,3581,0.7431680849230313,5859.657479524612,0.2629827431270054,0.7517572130475726,0.2851981880561515,3554,0.7259823053689505 -288.6337020397186,4.896897554397583,5649.284064054489,23775,0,5649.284064054489,0.2867870357442055,3581,0.7431376781319813,5943.725197553635,0.2632080146244594,0.7516813278198242,0.2854376231007315,3554,0.7259476145935917 -292.6982800960541,4.928308963775635,5729.313740730286,24120,0,5729.313740730286,0.2865086363467781,3581,0.7430952040718375,6027.86398601532,0.2625001668930053,0.7521965163094657,0.2851261102422007,3554,0.7258818738569218 -296.76456236839294,4.957936525344849,5809.358879804611,24466,0,5809.358879804611,0.2865454517440135,3581,0.742909081785814,6112.018396377564,0.2629297460828508,0.751568181174142,0.2852220079103211,3554,0.7256388323455966 -300.82558012008667,4.990025758743286,5889.45965218544,24810,0,5889.45965218544,0.2865225102974029,3581,0.7432031277270664,6196.225397586823,0.262939385005406,0.751816953931536,0.2851854795542874,3554,0.7259641012987127 -304.89204001426697,5.020129442214966,5969.472599029541,25157,0,5969.472599029541,0.2866050722345364,3581,0.7439072562875244,6280.348185300827,0.2624353340693882,0.7531287329537528,0.2852321231908589,3554,0.7268263559765406 -308.9539279937744,5.050216436386108,6049.431377649307,25504,0,6049.431377649307,0.2863970311518256,3581,0.7429824398736387,6364.412181138992,0.2626841919762747,0.7517944063459124,0.2850880190839107,3554,0.7257352795705543 -313.02527832984924,5.085922002792358,6129.473692893982,25850,0,6129.473692893982,0.2865169538994869,3581,0.7426864849858629,6448.574826478958,0.2628457375935146,0.7514446803501674,0.285239816986582,3554,0.7254274590470244 -317.09102606773376,5.121199369430542,6209.568758249283,26192,0,6209.568758249283,0.2865926299938041,3581,0.7433600704019477,6532.783788204193,0.2623382295880999,0.7526767594473702,0.2851828004647053,3554,0.7262026776607344 -321.1544461250305,5.158927917480469,6289.681404352188,26539,0,6289.681404352188,0.286464253340111,3581,0.7433585705153938,6617.011211872101,0.2625647783279419,0.7522519656590053,0.2851176951531285,3554,0.726147447198579 -325.2188422679901,5.190818786621094,6369.843276500702,26885,0,6369.843276500702,0.2863344449765254,3581,0.7436085743332868,6701.28259563446,0.2626039300646101,0.7523502622331891,0.2850532596139473,3554,0.7264758074089055 -329.28604793548584,5.221648454666138,6449.991070270538,27225,0,6449.991070270538,0.2870730027510646,3581,0.7439125740671251,6785.541536808014,0.2628480195999145,0.7528362955365863,0.2855947963562183,3554,0.7270255703300859 -333.3548312187195,5.253356695175171,6530.014122247696,27570,0,6530.014122247696,0.2863756236800998,3581,0.7434651988140534,6869.67798948288,0.2624267680304391,0.752467427934919,0.2850071311869109,3554,0.7263111464415095 -337.4182233810425,5.285075664520264,6610.146442651749,27915,0,6610.146442651749,0.2863726920836533,3581,0.7429035594762287,6953.919053077698,0.262539301599775,0.7517823491777692,0.2850680461276027,3554,0.7257233267093416 -341.4839758872986,5.319840431213379,6690.115103960037,28258,0,6690.115103960037,0.2864671849365575,3581,0.7430074607084264,7038.001373529434,0.2624285902295794,0.7520993777683803,0.2851729943098885,3554,0.7257936012899198 -345.5509777069092,5.350916862487793,6770.211782455444,28605,0,6770.211782455444,0.2865967205934969,3581,0.7439264821060807,7122.209839820862,0.2624111686434064,0.7531392233712333,0.2852265245805782,3554,0.7267762776097355 -349.6141200065613,5.382358074188232,6850.18680357933,28950,0,6850.18680357933,0.2863386719295413,3581,0.7434527224849903,7206.29279589653,0.2623393024717058,0.7524910654340472,0.2850574499848322,3554,0.7262524812491207 -353.68150186538696,5.413628816604614,6930.2322034835815,29290,0,6930.2322034835815,0.2867405733493612,3581,0.7437895151930327,7290.450036048889,0.2626663787024362,0.7528343200683594,0.2854241074372626,3554,0.7266698009724958 -357.7428359985352,5.4454345703125,7010.228944301605,29634,0,7010.228944301605,0.2864734912777506,3581,0.7435791220154985,7374.552933931351,0.2621491125651768,0.7528693335396903,0.2850999032505187,3554,0.7264843942344893 -361.8078374862671,5.477893114089966,7090.363373994827,29980,0,7090.363373994827,0.2863717376103916,3581,0.743540056788432,7458.798035383224,0.2622225625174386,0.7527691296168736,0.285038490273943,3554,0.7263874661472988 -365.8724312782288,5.510989665985107,7170.468336820602,30322,0,7170.468336820602,0.2863525458801662,3581,0.7435047412777507,7543.014036178589,0.2622182880129133,0.7526637486049107,0.2850195477367051,3554,0.7263018039752743 -369.9424302577973,5.542418241500855,7250.497773885727,30665,0,7250.497773885727,0.2863883386274783,3581,0.7435993023073164,7627.158119440079,0.2618430852890014,0.7531729425702777,0.2849867460629748,3554,0.7264931871438872 -374.01244950294495,5.842834711074829,7330.379662036896,31008,0,7330.379662036896,0.2863955653536023,3581,0.7435995750139626,7711.4234845638275,0.2621064186096191,0.7529027802603585,0.2850429554232467,3554,0.7264368575680571 -378.07388615608215,5.875456809997559,7410.376347541809,31351,0,7410.376347541809,0.286300186204098,3581,0.743733473977241,7795.527099370956,0.2620646102087838,0.7530208315168109,0.2849707573937377,3554,0.7265999385595456 -382.1353690624237,5.907495021820068,7490.46945476532,31697,0,7490.46945476532,0.2863908270756248,3581,0.7439781600155334,7879.72705745697,0.2617484842027937,0.753650392804827,0.2849589762690366,3554,0.7268968366409327 -386.2018015384674,5.940229654312134,7570.541733264923,32039,0,7570.541733264923,0.2862719610662175,3581,0.7437529043257819,7963.911533355713,0.2619015148707798,0.7531657900129046,0.2849250239606781,3554,0.7266006255055923 -390.2697324752808,5.974608659744263,7650.532155036926,32380,0,7650.532155036926,0.2862588711472005,3581,0.7434078622416923,8048.016820669174,0.2619706051690237,0.7526820727757045,0.2849329410138664,3554,0.7262389484120005 -394.3321797847748,6.007458448410034,7730.736515045166,32725,0,7730.736515045166,0.2863338654749022,3581,0.7436514574533999,8132.329214334488,0.2616528442927769,0.7534105437142509,0.2849374576841235,3554,0.7265398994750633 -398.39750695228577,6.040522813796997,7810.769964694977,33070,0,7810.769964694977,0.2862581212039234,3581,0.7438506014817788,8216.474393606186,0.2617532696042742,0.7533956936427525,0.2848618764453345,3554,0.7267253062130697 -402.461065530777,6.074344158172607,7890.764701366425,33413,0,7890.764701366425,0.2861579015114493,3581,0.7435203537332449,8300.579656124115,0.2617817606244768,0.7529990332467216,0.2848488759914005,3554,0.726368231657991 -406.5280122756958,6.112351179122925,7970.896803617477,33754,0,7970.896803617477,0.2862019436348087,3581,0.7437582902820441,8384.830063343048,0.2616239275251116,0.7533809798104423,0.2848602449484735,3554,0.7266520777644907 -410.5884718894959,6.1458728313446045,8051.041610479355,34098,0,8051.041610479355,0.2861940692304,3581,0.7438228535805291,8469.08253955841,0.2616209132330758,0.7534258706229073,0.2848003947741541,3554,0.7267079951726927 -414.6534061431885,6.180485725402832,8131.060321331024,34442,0,8131.060321331024,0.2861305626701689,3581,0.743851624131702,8553.214111566544,0.2616205726351057,0.7533926963806152,0.2847812976740556,3554,0.7267149333277645 -418.7173886299133,6.215153217315674,8211.12027812004,34783,0,8211.12027812004,0.2861496180470713,3581,0.7437214748848087,8637.385941267014,0.2616024698529924,0.7533390181405204,0.2848096170248311,3554,0.726566690370885 -422.782103061676,6.249707460403442,8291.133207082748,35127,0,8291.133207082748,0.2861220405874756,3581,0.7439961586541818,8721.511438369751,0.261509827205113,0.7536563192095075,0.284757271736072,3554,0.7268637258414814 -426.8485105037689,6.288463354110718,8371.14042711258,35473,0,8371.14042711258,0.2861458683306863,3581,0.7440684940920832,8805.637065887451,0.2615468502044678,0.7537084306989398,0.2847832382966376,3554,0.7269505558217854 -430.91059255599976,6.323931455612183,8451.220481395721,35816,0,8451.220481395721,0.2861297786385611,3581,0.7440939239868403,8889.82793712616,0.261535746710641,0.7537142208644322,0.2847685033039357,3554,0.7269712328977912 -434.9767861366272,6.358710527420044,8531.392409324646,36161,0,8531.392409324646,0.2860935768312797,3581,0.7440289516283859,8974.114171504974,0.2615071535110473,0.753655093056815,0.2847359420613217,3554,0.7268996531197243 -439.0459842681885,6.394179105758667,8535.762496232986,36189,0,8535.762496232986,0.2860935768312797,3581,0.7440284062150936,8982.592011928558,0.26150708539145334,0.7536547524588448,0.28473594206132175,3554,0.7268993783413056 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/measurements.csv deleted file mode 100644 index dce9067da..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/measurements.csv +++ /dev/null @@ -1,472 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.668263,1.0711248,,,,,,,,,,,,,, -1,,,0.246516398021153,1.0853406361171178,0.2452213628679832,1.0927659497880908,3554.0,0.2660334975118245,1.0888061012112538,3581.0,45.728089332580566,49.74335145950317,45.728089332580566,4.015159368515015,0.0,0.0 -100,0.98455656,0.4277538,,,,,,,,,,,,,, -200,0.14291877,0.34520656,,,,,,,,,,,,,, -260,,,0.6689420427594867,0.3349068164825439,0.6439759915104459,0.3565587616396138,3554.0,0.663606125427604,0.3573848209844492,3581.0,125.74299168586732,133.8574230670929,125.74299168586732,8.081871747970581,0.0230159759521484,0.0 -300,0.1784539,0.37745482,,,,,,,,,,,,,, -400,0.076686114,0.37216967,,,,,,,,,,,,,, -499,,,0.6992927278791156,0.3091085297720773,0.6752175420740715,0.3299610006990363,3554.0,0.6938741087946803,0.3316295531101647,3581.0,205.8846409320832,218.1014540195465,205.8846409320832,12.144757747650146,0.0539445877075195,0.0 -500,0.09880565,0.3140308,,,,,,,,,,,,,, -600,0.12548687,0.28192455,,,,,,,,,,,,,, -700,0.105911694,0.2555012,,,,,,,,,,,,,, -738,,,0.7161976950509208,0.2945798805781773,0.6924763745515616,0.3149693306067986,3554.0,0.7101002905962022,0.3170625867207134,3581.0,285.99159264564514,302.3159143924713,285.99159264564514,16.21077609062195,0.0866448879241943,0.0 -800,0.1554309,0.22620068,,,,,,,,,,,,,, -900,0.110051565,0.3590724,,,,,,,,,,,,,, -1000,0.0722804,0.28108588,,,,,,,,,,,,,, -1052,,,0.7262726511274066,0.2856094326291765,0.7028399173906162,0.3057588267071434,3554.0,0.7198837778815275,0.3081722137016545,3581.0,366.0610370635986,386.4980580806732,366.0610370635986,20.27893114089965,0.1195316314697265,0.0 -1100,0.28844073,0.30548805,,,,,,,,,,,,,, -1200,0.14989589,0.23776855,,,,,,,,,,,,,, -1300,0.23765738,0.27691406,,,,,,,,,,,,,, -1398,,,0.7319378171648298,0.2798751252038138,0.7083462021226083,0.3001379937218627,3554.0,0.7254137914295937,0.3022956581829447,3581.0,446.2296071052551,470.7702660560608,446.2296071052551,24.34521126747132,0.1439287662506103,0.0 -1400,0.17181416,0.26268426,,,,,,,,,,,,,, -1500,0.121291846,0.28905958,,,,,,,,,,,,,, -1600,0.07029355,0.27616128,,,,,,,,,,,,,, -1700,0.06513453,0.3324601,,,,,,,,,,,,,, -1741,,,0.732933589390346,0.2790592227663312,0.7093250315445625,0.2994219555100766,3554.0,0.726199595630585,0.3014767542127722,3581.0,526.3007123470306,554.9513428211212,526.3007123470306,28.41634178161621,0.1698086261749267,0.0 -1800,0.12398466,0.2663774,,,,,,,,,,,,,, -1900,0.074978165,0.3900019,,,,,,,,,,,,,, -2000,0.1080408,0.30440217,,,,,,,,,,,,,, -2085,,,0.7381625175476074,0.2745100430079868,0.7143384324880416,0.2949065561031584,3554.0,0.7314931043397445,0.2967336697806129,3581.0,606.3087928295135,639.0613944530487,606.3087928295135,32.47808003425598,0.1965868473052978,0.0 -2100,0.06374257,0.2684809,,,,,,,,,,,,,, -2200,0.092135295,0.27814043,,,,,,,,,,,,,, -2300,0.073529184,0.26999256,,,,,,,,,,,,,, -2400,0.049914,0.2820391,,,,,,,,,,,,,, -2428,,,0.7399488857814244,0.2728924070085798,0.7154231202957935,0.2938475944248558,3554.0,0.7324901198382086,0.295681840246265,3581.0,686.3964641094208,723.2513434886932,686.3964641094208,36.54174590110779,0.2217030525207519,0.0 -2500,0.22147252,0.18570463,,,,,,,,,,,,,, -2600,0.09777072,0.27407795,,,,,,,,,,,,,, -2700,0.09013032,0.28392094,,,,,,,,,,,,,, -2773,,,0.7410245622907367,0.2716763360159738,0.7166780333339196,0.2923744042803883,3554.0,0.7338204510349763,0.2940835747696174,3581.0,766.532796382904,807.490364074707,766.532796382904,40.60655450820923,0.2466726303100586,0.0 -2800,0.12168391,0.4313155,,,,,,,,,,,,,, -2900,0.074948736,0.27672842,,,,,,,,,,,,,, -3000,0.31186363,0.35099763,,,,,,,,,,,,,, -3100,0.102371864,0.27590004,,,,,,,,,,,,,, -3119,,,0.7391742978777204,0.2724208831787109,0.7159810578749296,0.2926273378147861,3554.0,0.7330739847676976,0.2942355746365366,3581.0,846.7297871112823,891.793835401535,846.7297871112823,44.67102932929993,0.2755327224731445,0.0 -3200,0.2812573,0.36564124,,,,,,,,,,,,,, -3300,0.15226501,0.2707171,,,,,,,,,,,,,, -3400,0.10186057,0.30339453,,,,,,,,,,,,,, -3466,,,0.7439769336155483,0.2697564533778599,0.7193141200935566,0.2907622449506717,3554.0,0.7364933852275901,0.2924607656948129,3581.0,926.8994541168212,976.0675466060638,926.8994541168212,48.73697257041931,0.3005204200744629,0.0 -3500,0.1514177,0.32727015,,,,,,,,,,,,,, -3600,0.078720026,0.3288271,,,,,,,,,,,,,, -3700,0.12349272,0.2540967,,,,,,,,,,,,,, -3799,,,0.7434635162353516,0.2692778621401105,0.7188673303847777,0.2901401122634707,3554.0,0.7361285719116518,0.2917038683983175,3581.0,1004.2445499897004,1060.412445783615,1004.2445499897004,52.80117058753967,3.2227320671081543,0.0 -3800,0.07869508,0.31048104,,,,,,,,,,,,,, -3900,0.12756631,0.31185293,,,,,,,,,,,,,, -4000,0.16368505,0.2699703,,,,,,,,,,,,,, -4100,0.16819018,0.24514405,,,,,,,,,,,,,, -4144,,,0.7428201266697475,0.2704278060368129,0.7184732294377814,0.2911202468829136,3554.0,0.7354656220547682,0.2928426913527995,3581.0,1084.4557919502258,1144.7332124710083,1084.4557919502258,56.872363567352295,3.248110055923462,0.0 -4200,0.5022938,0.22466506,,,,,,,,,,,,,, -4300,0.24614333,0.28661686,,,,,,,,,,,,,, -4400,0.06996018,0.25913203,,,,,,,,,,,,,, -4487,,,0.7458201817103794,0.2683043990816389,0.7207160395856781,0.2896644021261255,3554.0,0.7379175275106464,0.2912233592877338,3581.0,1164.5836515426636,1228.9657146930697,1164.5836515426636,60.93706798553467,3.2746899127960205,0.0 -4500,0.1858863,0.33039457,,,,,,,,,,,,,, -4600,0.26178437,0.28859323,,,,,,,,,,,,,, -4700,0.10588473,0.319917,,,,,,,,,,,,,, -4800,0.23049693,0.25295484,,,,,,,,,,,,,, -4832,,,0.7455734525408063,0.2678250585283552,0.7208194936603123,0.2888621178381753,3554.0,0.7380615166198339,0.2904094322138543,3581.0,1244.6789507865906,1313.1699743270874,1244.6789507865906,65.0037591457367,3.3040175437927246,0.0 -4900,0.039970927,0.31465957,,,,,,,,,,,,,, -5000,0.13916285,0.39201337,,,,,,,,,,,,,, -5100,0.12924045,0.29277733,,,,,,,,,,,,,, -5175,,,0.7453188896179199,0.2677472148622785,0.7208296604618036,0.288773604840057,3554.0,0.7379987940912106,0.2902906343811086,3581.0,1324.7413284778595,1397.3338668346405,1324.7413284778595,69.06503939628601,3.3317201137542725,0.0 -5200,0.07356882,0.26530346,,,,,,,,,,,,,, -5300,0.18968263,0.26250818,,,,,,,,,,,,,, -5400,0.15747645,0.3532693,,,,,,,,,,,,,, -5500,0.087174304,0.31357765,,,,,,,,,,,,,, -5518,,,0.7465689522879464,0.2686936003821237,0.7218854965355938,0.2903028497819358,3554.0,0.738994855116413,0.2918743441405159,3581.0,1404.8958258628843,1481.5906162261963,1404.8958258628843,73.12762188911438,3.3577311038970947,0.0 -5600,0.11176947,0.2502225,,,,,,,,,,,,,, -5700,0.27955604,0.2970239,,,,,,,,,,,,,, -5800,0.09507778,0.27253658,,,,,,,,,,,,,, -5864,,,0.7454586029052734,0.2675286020551409,0.7209607984621201,0.2885642580323227,3554.0,0.7382644785412594,0.2900212002146746,3581.0,1484.859869003296,1565.6560204029083,1484.859869003296,77.1900954246521,3.3832943439483643,0.0 -5900,0.2707866,0.2540259,,,,,,,,,,,,,, -6000,0.18870735,0.24002732,,,,,,,,,,,,,, -6100,0.117941625,0.29254773,,,,,,,,,,,,,, -6200,0.105825,0.27345806,,,,,,,,,,,,,, -6210,,,0.742236818586077,0.2713038921356201,0.7176969804050014,0.2926282308446469,3554.0,0.7349887263072465,0.29393297252426,3581.0,1565.0343101024628,1649.9378747940063,1565.0343101024628,81.2557954788208,3.411367416381836,0.0 -6300,0.49155915,0.22304074,,,,,,,,,,,,,, -6400,0.106539406,0.3000044,,,,,,,,,,,,,, -6500,0.20549557,0.30100077,,,,,,,,,,,,,, -6551,,,0.7458180700029645,0.2674429076058524,0.7210834183314575,0.2885521334345983,3554.0,0.7383212697003281,0.2900438348663083,3581.0,1645.018116235733,1734.030422449112,1645.018116235733,85.31988334655762,3.442832946777344,0.0 -6600,0.26230374,0.24748787,,,,,,,,,,,,,, -6700,0.26201797,0.27427465,,,,,,,,,,,,,, -6800,0.10593401,0.26871645,,,,,,,,,,,,,, -6896,,,0.7465430668422154,0.2670606545039585,0.7218675672437747,0.2881888076704945,3554.0,0.739134685449246,0.2896101972105382,3581.0,1725.143525838852,1818.260770559311,1725.143525838852,89.38604640960693,3.468510150909424,0.0 -6900,0.17042275,0.34349132,,,,,,,,,,,,,, -7000,0.07392616,0.24809682,,,,,,,,,,,,,, -7100,0.4296686,0.27448037,,,,,,,,,,,,,, -7200,0.10122109,0.27269253,,,,,,,,,,,,,, -7242,,,0.7473367963518415,0.2663912432534354,0.7224701563159468,0.287675384195185,3554.0,0.7398020667891302,0.2889781313813355,3581.0,1805.1661853790283,1902.393094301224,1805.1661853790283,93.45235657691956,3.498748540878296,0.0 -7300,0.0874117,0.30631664,,,,,,,,,,,,,, -7400,0.1831308,0.27667627,,,,,,,,,,,,,, -7500,0.09571534,0.26886415,,,,,,,,,,,,,, -7584,,,0.7440915788922992,0.2673313447407314,0.7194599587392726,0.288352300829611,3554.0,0.736945396493647,0.2897599813359571,3581.0,1885.309853553772,1986.6425902843475,1885.309853553772,97.51888036727904,3.5248067378997803,0.0 -7600,0.25389153,0.2662344,,,,,,,,,,,,,, -7700,0.48814481,0.22056809,,,,,,,,,,,,,, -7800,0.110770024,0.28867403,,,,,,,,,,,,,, -7900,0.2057535,0.2894628,,,,,,,,,,,,,, -7930,,,0.7484252112252372,0.2660749299185617,0.7234536569710186,0.2875628967800366,3554.0,0.7405779853741972,0.2890621250283615,3581.0,1965.398687839508,2070.8367421627045,1965.398687839508,101.58399534225464,3.5519094467163086,0.0 -8000,0.13913184,0.29783806,,,,,,,,,,,,,, -8100,0.424984,0.19641115,,,,,,,,,,,,,, -8200,0.15009926,0.3270442,,,,,,,,,,,,,, -8272,,,0.7478161539350238,0.265859808240618,0.7229330892568233,0.2872010479499332,3554.0,0.7402424198460625,0.2885723438918074,3581.0,2045.390664577484,2154.939321041107,2045.390664577484,105.65423774719238,3.579378604888916,0.0 -8300,0.4261997,0.1983581,,,,,,,,,,,,,, -8400,0.2540284,0.1988835,,,,,,,,,,,,,, -8500,0.07258254,0.25337824,,,,,,,,,,,,,, -8600,0.07590694,0.17387408,,,,,,,,,,,,,, -8616,,,0.7485686029706683,0.2655953679765974,0.7237002019071821,0.2868307153361529,3554.0,0.7409527524827213,0.2881952246884599,3581.0,2125.487339735031,2239.1462643146515,2125.487339735031,109.7238380908966,3.606642246246338,0.0 -8700,0.6042687,0.26961282,,,,,,,,,,,,,, -8800,0.09994479,0.25715324,,,,,,,,,,,,,, -8900,0.11762703,0.25165516,,,,,,,,,,,,,, -8962,,,0.7491508211408343,0.2655679157802036,0.7238994162607274,0.2873204906938045,3554.0,0.7410736978803057,0.2887139127295099,3581.0,2205.603874444961,2323.369287729264,2205.603874444961,113.7906687259674,3.633155345916748,0.0 -9000,0.16072974,0.34618443,,,,,,,,,,,,,, -9100,0.36240196,0.28870535,,,,,,,,,,,,,, -9200,0.06767902,0.24224715,,,,,,,,,,,,,, -9300,0.056212116,0.41656,,,,,,,,,,,,,, -9307,,,0.7480177879333496,0.2655667236873081,0.7231749629598692,0.2868091967512398,3554.0,0.7404068619537141,0.2882544361190135,3581.0,2285.7724466323853,2407.645413160324,2285.7724466323853,117.85758423805235,3.660656452178955,0.0 -9400,0.18868914,0.28161126,,,,,,,,,,,,,, -9500,0.094176844,0.32078615,,,,,,,,,,,,,, -9600,0.13865688,0.2903047,,,,,,,,,,,,,, -9648,,,0.7472796440124512,0.2661188500268118,0.7231537363270258,0.2871454396674521,3554.0,0.7401621759154217,0.2886442702697396,3581.0,2365.7505254745483,2491.731642961502,2365.7505254745483,121.92598867416382,3.68761682510376,0.0 -9700,0.1183546,0.28630733,,,,,,,,,,,,,, -9800,0.61419415,0.22783467,,,,,,,,,,,,,, -9900,0.13917394,0.28969845,,,,,,,,,,,,,, -9994,,,0.749699320111956,0.2653477191925049,0.7242661767550647,0.287246506604574,3554.0,0.7414724631736945,0.2886798925753979,3581.0,2445.8985407352448,2575.993050098419,2445.8985407352448,125.99807476997375,3.715875148773194,0.0 -10000,0.13381483,0.22981739,,,,,,,,,,,,,, -10100,0.49734017,0.22779918,,,,,,,,,,,,,, -10200,0.1877008,0.34162498,,,,,,,,,,,,,, -10300,0.068483815,0.26819453,,,,,,,,,,,,,, -10339,,,0.7490485736301967,0.2654490981783186,0.724125696288513,0.2868410367005047,3554.0,0.741316611325398,0.2882380737202422,3581.0,2526.036517381668,2660.2385306358337,2526.036517381668,130.0650417804718,3.743248224258423,0.0 -10400,0.056250922,0.23410834,,,,,,,,,,,,,, -10500,0.42994848,0.25824827,,,,,,,,,,,,,, -10600,0.09448807,0.24958834,,,,,,,,,,,,,, -10683,,,0.7484066826956612,0.2652280500956944,0.7236086319991559,0.28647137385912,3554.0,0.7408440788842153,0.2878581251854405,3581.0,2606.084993124008,2744.390327692032,2606.084993124008,134.1287350654602,3.7703020572662354,0.0 -10700,0.08561201,0.22043833,,,,,,,,,,,,,, -10800,0.15541431,0.25926876,,,,,,,,,,,,,, -10900,0.06041419,0.282262,,,,,,,,,,,,,, -11000,0.0906788,0.3310698,,,,,,,,,,,,,, -11025,,,0.7488865852355957,0.2652821029935564,0.7232043642506683,0.287192426777047,3554.0,0.7404084300169296,0.2886463837462475,3581.0,2686.2606456279755,2828.676636695862,2686.2606456279755,138.194429397583,3.802428722381592,0.0 -11100,0.069610424,0.32918748,,,,,,,,,,,,,, -11200,0.07736797,0.23976876,,,,,,,,,,,,,, -11300,0.18480189,0.28446665,,,,,,,,,,,,,, -11372,,,0.7482709203447614,0.2654268571308681,0.7230311851522931,0.2867865618790007,3554.0,0.7402820986630829,0.288206951074246,3581.0,2766.471347808838,2912.99440908432,2766.471347808838,142.26156210899353,3.828813552856445,0.0 -11400,0.12226393,0.28250557,,,,,,,,,,,,,, -11500,0.11005472,0.2549957,,,,,,,,,,,,,, -11600,0.13174522,0.22825044,,,,,,,,,,,,,, -11700,0.12369825,0.3009015,,,,,,,,,,,,,, -11718,,,0.74904693875994,0.2650367191859654,0.7241317414137239,0.2864352576707143,3554.0,0.7413314738376152,0.2878480350395315,3581.0,2846.460176229477,2997.098103761673,2846.460176229477,146.33048748970032,3.8615622520446777,0.0 -11800,0.20192416,0.28701526,,,,,,,,,,,,,, -11900,0.31406856,0.29664147,,,,,,,,,,,,,, -12000,0.14353703,0.33155626,,,,,,,,,,,,,, -12059,,,0.7484876087733677,0.2649466480527605,0.7230258956677336,0.2866702275659908,3554.0,0.7402561233550335,0.288100527305571,3581.0,2926.4845530986786,3081.2274656295776,2926.4845530986786,150.39404916763306,3.889822244644165,0.0 -12100,0.11228427,0.28956163,,,,,,,,,,,,,, -12200,0.104262315,0.25060773,,,,,,,,,,,,,, -12300,0.24038072,0.20953086,,,,,,,,,,,,,, -12400,0.060606916,0.27590927,,,,,,,,,,,,,, -12404,,,0.7491763659885952,0.2648206268038068,0.7238350494161508,0.2864063544157991,3554.0,0.7411182854169576,0.2878106742290037,3581.0,3006.467269420624,3165.3140320777893,3006.467269420624,154.4581105709076,3.916551351547241,0.0 -12500,0.13271238,0.32979256,,,,,,,,,,,,,, -12600,0.18137564,0.26523998,,,,,,,,,,,,,, -12700,0.06007764,0.20856272,,,,,,,,,,,,,, -12749,,,0.748763016292027,0.2648651259286063,0.7238626646472285,0.286322169177775,3554.0,0.7410273377504538,0.2876451072064367,3581.0,3086.66849565506,3249.6230750083923,3086.66849565506,158.52502942085266,3.9443979263305664,0.0 -12800,0.09300079,0.27068618,,,,,,,,,,,,,, -12900,0.3521744,0.22151637,,,,,,,,,,,,,, -13000,0.28602374,0.32499212,,,,,,,,,,,,,, -13091,,,0.7491682597569057,0.2645631006785801,0.7234744714362338,0.286429796449643,3554.0,0.7406824320196872,0.2878322521423834,3581.0,3166.8827443122864,3333.9484016895294,3166.8827443122864,162.59536600112915,3.9718716144561768,0.0 -13100,0.11739887,0.23773262,,,,,,,,,,,,,, -13200,0.1746435,0.24104401,,,,,,,,,,,,,, -13300,0.122559294,0.35509494,,,,,,,,,,,,,, -13400,0.25209442,0.3630119,,,,,,,,,,,,,, -13437,,,0.7497834478105817,0.2646686179297311,0.7245596401062183,0.2863316318595684,3554.0,0.7417780309707483,0.287715806404461,3581.0,3247.015244960785,3418.19149851799,3247.015244960785,166.6649980545044,3.9991109371185294,0.0 -13500,0.08873259,0.3415823,,,,,,,,,,,,,, -13600,0.15365568,0.24107982,,,,,,,,,,,,,, -13700,0.17581275,0.373895,,,,,,,,,,,,,, -13782,,,0.7486034802028111,0.2658377545220511,0.7232885838359947,0.2873000368752638,3554.0,0.7406301405202806,0.2886482586044401,3581.0,3327.116854429245,3502.4071373939514,3327.116854429245,170.73585772514343,4.028689861297607,0.0 -13800,0.13700758,0.21609034,,,,,,,,,,,,,, -13900,0.14759193,0.2388286,,,,,,,,,,,,,, -14000,0.114210054,0.26366454,,,,,,,,,,,,,, -14100,0.15379961,0.27845109,,,,,,,,,,,,,, -14123,,,0.7467929295131138,0.2655568974358694,0.7217877441131472,0.2870196941936023,3554.0,0.7391924310815764,0.2883330097214465,3581.0,3407.216105222702,3586.617163658142,3407.216105222702,174.8053970336914,4.05685567855835,0.0 -14200,0.17123856,0.23059548,,,,,,,,,,,,,, -14300,0.15628397,0.23013812,,,,,,,,,,,,,, -14400,0.13301575,0.3067696,,,,,,,,,,,,,, -14468,,,0.7505397796630859,0.2643441472734724,0.7251111203925155,0.286208908948324,3554.0,0.742357123533929,0.2875891000789758,3581.0,3487.2479860782623,3670.760345458984,3487.2479860782623,178.87139344215393,4.088459253311157,0.0 -14500,0.26640195,0.35336894,,,,,,,,,,,,,, -14600,0.15357111,0.21467586,,,,,,,,,,,,,, -14700,0.12754697,0.27346528,,,,,,,,,,,,,, -14800,0.2905595,0.2782181,,,,,,,,,,,,,, -14813,,,0.7425896780831474,0.2671321630477905,0.7190899009039111,0.2882850488116383,3554.0,0.7358483658326934,0.2896819190584858,3581.0,3567.3677020072937,3754.986970663071,3567.3677020072937,182.9370470046997,4.116533517837524,0.0 -14900,0.15343033,0.20681131,,,,,,,,,,,,,, -15000,0.24614127,0.2101519,,,,,,,,,,,,,, -15100,0.23887844,0.2662841,,,,,,,,,,,,,, -15155,,,0.7493658065795898,0.2646791424070085,0.7241403969339125,0.2862461414240556,3554.0,0.7413178385053057,0.287650186367722,3581.0,3647.348000049591,3839.075215816498,3647.348000049591,187.00343370437625,4.144828796386719,0.0 -15200,0.15134522,0.25846174,,,,,,,,,,,,,, -15300,0.33274418,0.17472367,,,,,,,,,,,,,, -15400,0.18495175,0.35363778,,,,,,,,,,,,,, -15500,0.093239896,0.27721423,,,,,,,,,,,,,, -15501,,,0.7509126663208008,0.2639762333461216,0.7253296379299733,0.285933735535664,3554.0,0.7425276333644583,0.287321302152419,3581.0,3727.449832677841,3923.2359850406647,3727.449832677841,191.01669716835025,4.177236795425415,0.0 -15600,0.07892264,0.29550746,,,,,,,,,,,,,, -15700,0.2983919,0.21029267,,,,,,,,,,,,,, -15800,0.06489657,0.38593608,,,,,,,,,,,,,, -15847,,,0.7511205673217773,0.2639362130846296,0.7259452102824282,0.2856074361634778,3554.0,0.743033708723122,0.287013484525534,3581.0,3807.651174545288,4007.545920372009,3807.651174545288,195.0844044685364,4.205004215240479,0.0 -15900,0.16652265,0.20806351,,,,,,,,,,,,,, -16000,0.20826848,0.23219919,,,,,,,,,,,,,, -16100,0.1118655,0.3020956,,,,,,,,,,,,,, -16190,,,0.7509918894086566,0.2642698287963867,0.7260166526712859,0.2859001095266777,3554.0,0.7431809703120636,0.2872470577679942,3581.0,3887.811373949051,4091.81326007843,3887.811373949051,199.1486778259277,4.234395503997803,0.0 -16200,0.15840307,0.26499036,,,,,,,,,,,,,, -16300,0.263879,0.24399218,,,,,,,,,,,,,, -16400,0.14848137,0.24058084,,,,,,,,,,,,,, -16500,0.120000616,0.3204666,,,,,,,,,,,,,, -16535,,,0.7507597378322056,0.2636375086648123,0.7249813562842924,0.2857542193600081,3554.0,0.7422342010131597,0.2870917854213208,3581.0,3967.889343976976,4175.997426271439,3967.889343976976,203.21332693099976,4.2629711627960205,0.0 -16600,0.110405385,0.27569747,,,,,,,,,,,,,, -16700,0.2711191,0.28205356,,,,,,,,,,,,,, -16800,0.10674503,0.17545739,,,,,,,,,,,,,, -16881,,,0.7500498635428292,0.2640057461602347,0.7248276864536438,0.2857108902381119,3554.0,0.7419927874546216,0.2870818998053965,3581.0,4048.056460380554,4260.275257110596,4048.056460380554,207.28168106079104,4.291972875595093,0.0 -16900,0.27239987,0.31800997,,,,,,,,,,,,,, -17000,0.26276544,0.2067944,,,,,,,,,,,,,, -17100,0.21892428,0.2550886,,,,,,,,,,,,,, -17200,0.1491365,0.24043107,,,,,,,,,,,,,, -17222,,,0.7488697596958706,0.2648025410515921,0.7235759333673326,0.2863705988740679,3554.0,0.7408937796704831,0.2876822634869799,3581.0,4128.021166086197,4344.355574369431,4128.021166086197,211.35364246368408,4.321918964385986,0.0 -17300,0.17504397,0.34001344,,,,,,,,,,,,,, -17400,0.2052711,0.28486162,,,,,,,,,,,,,, -17500,0.10801699,0.22547726,,,,,,,,,,,,,, -17570,,,0.7515374592372349,0.2635806117738996,0.7258187435152293,0.2856503187704435,3554.0,0.7430041882286722,0.2870471978846691,3581.0,4208.193595170975,4428.63065123558,4208.193595170975,215.41570520401,4.34968638420105,0.0 -17600,0.20787863,0.2923659,,,,,,,,,,,,,, -17700,0.36134535,0.21042973,,,,,,,,,,,,,, -17800,0.11375805,0.27365735,,,,,,,,,,,,,, -17900,0.09451199,0.28815514,,,,,,,,,,,,,, -17916,,,0.7505055155072894,0.2637154204504831,0.7249672051957302,0.2855568254134865,3554.0,0.7422369962562831,0.2869160941645141,3581.0,4288.185995578766,4512.728018522263,4288.185995578766,219.47929334640503,4.378240585327148,0.0 -18000,0.13302495,0.3221974,,,,,,,,,,,,,, -18100,0.08974158,0.2915611,,,,,,,,,,,,,, -18200,0.19337638,0.36745694,,,,,,,,,,,,,, -18262,,,0.7505534717014858,0.2638850552695138,0.7251971947321679,0.2856249189403665,3554.0,0.7425056804794401,0.2869031746871509,3581.0,4368.278284549713,4596.934916496277,4368.278284549713,223.54808259010315,4.410733461380005,0.0 -18300,0.08979721,0.26652756,,,,,,,,,,,,,, -18400,0.11745091,0.27027285,,,,,,,,,,,,,, -18500,0.3535036,0.2788725,,,,,,,,,,,,,, -18600,0.13491525,0.39335173,,,,,,,,,,,,,, -18605,,,0.7513707705906459,0.2634532962526594,0.725492650226857,0.2857273769432329,3554.0,0.7427279363960835,0.2870902173581053,3581.0,4448.341143369675,4681.113558769226,4448.341143369675,227.6206016540528,4.440626382827759,0.0 -18700,0.3430802,0.26203865,,,,,,,,,,,,,, -18800,0.09915935,0.27138403,,,,,,,,,,,,,, -18900,0.16931933,0.24268898,,,,,,,,,,,,,, -18952,,,0.7493149893624442,0.2640271357127598,0.7240767170353827,0.2857467831690525,3554.0,0.7412419578810039,0.2870987053524679,3581.0,4528.50818157196,4765.386015415192,4528.50818157196,231.68354749679563,4.469612121582031,0.0 -19000,0.09090098,0.27976888,,,,,,,,,,,,,, -19100,0.22571078,0.19828476,,,,,,,,,,,,,, -19200,0.2668014,0.21622777,,,,,,,,,,,,,, -19298,,,0.7502527918134417,0.263763632093157,0.7248332507166221,0.2855173260158009,3554.0,0.7421114148457135,0.2868421224867356,3581.0,4608.514596700668,4849.507258892059,4608.514596700668,235.7559115886688,4.49866247177124,0.0 -19300,0.24732645,0.27886817,,,,,,,,,,,,,, -19400,0.19674858,0.23073919,,,,,,,,,,,,,, -19500,0.63669866,0.3136152,,,,,,,,,,,,,, -19600,0.22547612,0.2048299,,,,,,,,,,,,,, -19641,,,0.7502414839608329,0.2645717178072248,0.7243462746641108,0.2868963873782182,3554.0,0.7414645546809551,0.2883027051953888,3581.0,4688.5540153980255,4933.662929773331,4688.5540153980255,239.829217672348,4.528627872467041,0.0 -19700,0.063708484,0.198339,,,,,,,,,,,,,, -19800,0.11239181,0.23150317,,,,,,,,,,,,,, -19900,0.13907841,0.2515989,,,,,,,,,,,,,, -19987,,,0.7504786082676479,0.2637774603707449,0.7253230432479248,0.2857269819492561,3554.0,0.7424568659897725,0.2870272562111666,3581.0,4768.637889385223,5017.856953859329,4768.637889385223,243.896320104599,4.558478355407715,0.0 -20000,0.18667741,0.27893993,,,,,,,,,,,,,, -20100,0.15618639,0.29950005,,,,,,,,,,,,,, -20200,0.1498296,0.31282023,,,,,,,,,,,,,, -20300,0.27033868,0.28058568,,,,,,,,,,,,,, -20331,,,0.7510354178292411,0.2635509627205984,0.7256886359339828,0.2854417791243142,3554.0,0.7427785234789515,0.2868187038034941,3581.0,4848.670360326767,5102.002462863922,4848.670360326767,247.96650791168213,4.588074684143066,0.0 -20400,0.16238339,0.29257616,,,,,,,,,,,,,, -20500,0.67936724,0.25715768,,,,,,,,,,,,,, -20600,0.2804251,0.3746135,,,,,,,,,,,,,, -20674,,,0.7502099445887974,0.2634921755109514,0.7245509158914252,0.285569877388374,3554.0,0.7417379430937587,0.2869164691361526,3581.0,4928.62136054039,5186.070428848267,4928.62136054039,252.0331449508667,4.625176191329956,0.0 -20700,0.28804025,0.25261456,,,,,,,,,,,,,, -20800,0.12558457,0.31429893,,,,,,,,,,,,,, -20900,0.17154528,0.28759396,,,,,,,,,,,,,, -21000,0.15770124,0.33029938,,,,,,,,,,,,,, -21021,,,0.7512658664158413,0.263641391481672,0.7255388130011958,0.2858189640249103,3554.0,0.742762024726857,0.2872547958190798,3581.0,5008.828974246979,5270.392125844955,5008.828974246979,256.1032905578613,4.656060457229614,0.0 -21100,0.14720194,0.29330778,,,,,,,,,,,,,, -21200,0.15784235,0.19282466,,,,,,,,,,,,,, -21300,0.15873823,0.2624285,,,,,,,,,,,,,, -21366,,,0.7510338510785785,0.2635410853794643,0.7255325617921707,0.2854907240301421,3554.0,0.7427312770524993,0.286783047409505,3581.0,5088.926310777664,5354.601647377014,5088.926310777664,260.1727867126465,4.685746908187866,0.0 -21400,0.32557482,0.2097289,,,,,,,,,,,,,, -21500,0.3289519,0.19170801,,,,,,,,,,,,,, -21600,0.16212204,0.3014887,,,,,,,,,,,,,, -21700,0.15594147,0.29538015,,,,,,,,,,,,,, -21708,,,0.7513729504176548,0.263600264276777,0.7258484882790518,0.2855925981288689,3554.0,0.7430581159679559,0.2869437738891022,3581.0,5168.926723718643,5438.710728645325,5168.926723718643,264.2384469509125,4.715479373931885,0.0 -21800,0.18509421,0.22316901,,,,,,,,,,,,,, -21900,0.22962713,0.31156227,,,,,,,,,,,,,, -22000,0.12761213,0.3858299,,,,,,,,,,,,,, -22054,,,0.7516428402491978,0.263126083782741,0.7258781643482696,0.2853465512285892,3554.0,0.7430818414461743,0.2867228133290282,3581.0,5248.983458995819,5522.876201629639,5248.983458995819,268.303653717041,4.745387554168701,0.0 -22100,0.16409527,0.34745103,,,,,,,,,,,,,, -22200,0.15049161,0.2994157,,,,,,,,,,,,,, -22300,0.46806616,0.25347918,,,,,,,,,,,,,, -22399,,,0.7519129344395229,0.2633925335747855,0.7263766123997608,0.2853977458827201,3554.0,0.7435324891790003,0.2868253169396642,3581.0,5329.01037311554,5607.016135931015,5329.01037311554,272.3733148574829,4.775863885879517,0.0 -22400,0.11744859,0.3084714,,,,,,,,,,,,,, -22500,0.10547585,0.30469328,,,,,,,,,,,,,, -22600,0.17701578,0.27364028,,,,,,,,,,,,,, -22700,0.116645455,0.22448942,,,,,,,,,,,,,, -22742,,,0.7509878022330148,0.2632285526820591,0.7253110216921075,0.2852576432364941,3554.0,0.7425101801391022,0.2866503756261344,3581.0,5409.231924772263,5691.342561483383,5409.231924772263,276.4347996711731,4.80604100227356,0.0 -22800,0.17040315,0.3073444,,,,,,,,,,,,,, -22900,0.6191364,0.19644006,,,,,,,,,,,,,, -23000,0.11161247,0.2582768,,,,,,,,,,,,,, -23085,,,0.7520301001412528,0.2628756761550903,0.7259378599597285,0.285353489383661,3554.0,0.7431136799471167,0.2867670258940414,3581.0,5489.367866516113,5775.593355178833,5489.367866516113,280.50565695762634,4.836712121963501,0.0 -23100,0.38395348,0.3551159,,,,,,,,,,,,,, -23200,0.12999606,0.26122087,,,,,,,,,,,,,, -23300,0.07139376,0.1759918,,,,,,,,,,,,,, -23400,0.1498892,0.27387488,,,,,,,,,,,,,, -23431,,,0.7517572130475726,0.2629827431270054,0.7259823053689505,0.2851981880561515,3554.0,0.7431680849230313,0.2865393840211359,3581.0,5569.326278924942,5859.657479524612,5569.326278924942,284.56813645362854,4.866225004196167,0.0 -23500,0.24483255,0.3417344,,,,,,,,,,,,,, -23600,0.14447382,0.1986109,,,,,,,,,,,,,, -23700,0.12101527,0.25481188,,,,,,,,,,,,,, -23775,,,0.7516813278198242,0.2632080146244594,0.7259476145935917,0.2854376231007315,3554.0,0.7431376781319813,0.2867870357442055,3581.0,5649.284064054489,5943.725197553635,5649.284064054489,288.6337020397186,4.896897554397583,0.0 -23800,0.15261225,0.19116634,,,,,,,,,,,,,, -23900,0.5172617,0.23329383,,,,,,,,,,,,,, -24000,0.1532122,0.24223101,,,,,,,,,,,,,, -24100,0.20251766,0.30285117,,,,,,,,,,,,,, -24120,,,0.7521965163094657,0.2625001668930053,0.7258818738569218,0.2851261102422007,3554.0,0.7430952040718375,0.2865086363467781,3581.0,5729.313740730286,6027.86398601532,5729.313740730286,292.6982800960541,4.928308963775635,0.0 -24200,0.17472802,0.22116196,,,,,,,,,,,,,, -24300,0.16411291,0.26365203,,,,,,,,,,,,,, -24400,0.9859578,0.2015094,,,,,,,,,,,,,, -24466,,,0.751568181174142,0.2629297460828508,0.7256388323455966,0.2852220079103211,3554.0,0.742909081785814,0.2865454517440135,3581.0,5809.358879804611,6112.018396377564,5809.358879804611,296.76456236839294,4.957936525344849,0.0 -24500,0.092849866,0.37845117,,,,,,,,,,,,,, -24600,0.085161954,0.24977398,,,,,,,,,,,,,, -24700,0.08429137,0.30041593,,,,,,,,,,,,,, -24800,0.15628368,0.24341775,,,,,,,,,,,,,, -24810,,,0.751816953931536,0.262939385005406,0.7259641012987127,0.2851854795542874,3554.0,0.7432031277270664,0.2865225102974029,3581.0,5889.45965218544,6196.225397586823,5889.45965218544,300.82558012008667,4.990025758743286,0.0 -24900,0.10846073,0.28972068,,,,,,,,,,,,,, -25000,0.18219584,0.19542721,,,,,,,,,,,,,, -25100,0.06524832,0.3201562,,,,,,,,,,,,,, -25157,,,0.7531287329537528,0.2624353340693882,0.7268263559765406,0.2852321231908589,3554.0,0.7439072562875244,0.2866050722345364,3581.0,5969.472599029541,6280.348185300827,5969.472599029541,304.89204001426697,5.020129442214966,0.0 -25200,0.09744377,0.22648959,,,,,,,,,,,,,, -25300,0.12787388,0.23760696,,,,,,,,,,,,,, -25400,0.17459223,0.26133293,,,,,,,,,,,,,, -25500,0.110719636,0.24163343,,,,,,,,,,,,,, -25504,,,0.7517944063459124,0.2626841919762747,0.7257352795705543,0.2850880190839107,3554.0,0.7429824398736387,0.2863970311518256,3581.0,6049.431377649307,6364.412181138992,6049.431377649307,308.9539279937744,5.050216436386108,0.0 -25600,0.06703143,0.2647352,,,,,,,,,,,,,, -25700,0.25463945,0.28746343,,,,,,,,,,,,,, -25800,0.1148671,0.26587442,,,,,,,,,,,,,, -25850,,,0.7514446803501674,0.2628457375935146,0.7254274590470244,0.285239816986582,3554.0,0.7426864849858629,0.2865169538994869,3581.0,6129.473692893982,6448.574826478958,6129.473692893982,313.02527832984924,5.085922002792358,0.0 -25900,0.2445576,0.2102404,,,,,,,,,,,,,, -26000,0.10342064,0.2332688,,,,,,,,,,,,,, -26100,0.13910781,0.32272604,,,,,,,,,,,,,, -26192,,,0.7526767594473702,0.2623382295880999,0.7262026776607344,0.2851828004647053,3554.0,0.7433600704019477,0.2865926299938041,3581.0,6209.568758249283,6532.783788204193,6209.568758249283,317.09102606773376,5.121199369430542,0.0 -26200,0.08753341,0.23504668,,,,,,,,,,,,,, -26300,0.21235052,0.2184433,,,,,,,,,,,,,, -26400,0.11821633,0.33154643,,,,,,,,,,,,,, -26500,0.23578848,0.22124368,,,,,,,,,,,,,, -26539,,,0.7522519656590053,0.2625647783279419,0.726147447198579,0.2851176951531285,3554.0,0.7433585705153938,0.286464253340111,3581.0,6289.681404352188,6617.011211872101,6289.681404352188,321.1544461250305,5.158927917480469,0.0 -26600,0.13012737,0.26003692,,,,,,,,,,,,,, -26700,0.08652841,0.24481529,,,,,,,,,,,,,, -26800,0.24192789,0.21979082,,,,,,,,,,,,,, -26885,,,0.7523502622331891,0.2626039300646101,0.7264758074089055,0.2850532596139473,3554.0,0.7436085743332868,0.2863344449765254,3581.0,6369.843276500702,6701.28259563446,6369.843276500702,325.2188422679901,5.190818786621094,0.0 -26900,0.15650885,0.27103537,,,,,,,,,,,,,, -27000,0.07868012,0.2724521,,,,,,,,,,,,,, -27100,0.1933596,0.33553845,,,,,,,,,,,,,, -27200,0.108133696,0.32822725,,,,,,,,,,,,,, -27225,,,0.7528362955365863,0.2628480195999145,0.7270255703300859,0.2855947963562183,3554.0,0.7439125740671251,0.2870730027510646,3581.0,6449.991070270538,6785.541536808014,6449.991070270538,329.28604793548584,5.221648454666138,0.0 -27300,0.11173062,0.21831483,,,,,,,,,,,,,, -27400,0.14316453,0.26699075,,,,,,,,,,,,,, -27500,0.15332173,0.22966228,,,,,,,,,,,,,, -27570,,,0.752467427934919,0.2624267680304391,0.7263111464415095,0.2850071311869109,3554.0,0.7434651988140534,0.2863756236800998,3581.0,6530.014122247696,6869.67798948288,6530.014122247696,333.3548312187195,5.253356695175171,0.0 -27600,0.07910081,0.24955675,,,,,,,,,,,,,, -27700,0.09294644,0.33253264,,,,,,,,,,,,,, -27800,0.11622867,0.26997155,,,,,,,,,,,,,, -27900,0.09222654,0.23932609,,,,,,,,,,,,,, -27915,,,0.7517823491777692,0.262539301599775,0.7257233267093416,0.2850680461276027,3554.0,0.7429035594762287,0.2863726920836533,3581.0,6610.146442651749,6953.919053077698,6610.146442651749,337.4182233810425,5.285075664520264,0.0 -28000,0.1009798,0.23612142,,,,,,,,,,,,,, -28100,0.26118687,0.25008115,,,,,,,,,,,,,, -28200,0.12505554,0.2901137,,,,,,,,,,,,,, -28258,,,0.7520993777683803,0.2624285902295794,0.7257936012899198,0.2851729943098885,3554.0,0.7430074607084264,0.2864671849365575,3581.0,6690.115103960037,7038.001373529434,6690.115103960037,341.4839758872986,5.319840431213379,0.0 -28300,0.08208366,0.1922799,,,,,,,,,,,,,, -28400,0.14429109,0.22468793,,,,,,,,,,,,,, -28500,0.08396057,0.2507328,,,,,,,,,,,,,, -28600,0.11484048,0.3370731,,,,,,,,,,,,,, -28605,,,0.7531392233712333,0.2624111686434064,0.7267762776097355,0.2852265245805782,3554.0,0.7439264821060807,0.2865967205934969,3581.0,6770.211782455444,7122.209839820862,6770.211782455444,345.5509777069092,5.350916862487793,0.0 -28700,0.17377646,0.29897705,,,,,,,,,,,,,, -28800,0.11863819,0.2572542,,,,,,,,,,,,,, -28900,0.13415587,0.2759601,,,,,,,,,,,,,, -28950,,,0.7524910654340472,0.2623393024717058,0.7262524812491207,0.2850574499848322,3554.0,0.7434527224849903,0.2863386719295413,3581.0,6850.18680357933,7206.29279589653,6850.18680357933,349.6141200065613,5.382358074188232,0.0 -29000,0.10798198,0.29459536,,,,,,,,,,,,,, -29100,0.090550505,0.19905333,,,,,,,,,,,,,, -29200,0.12314472,0.24426314,,,,,,,,,,,,,, -29290,,,0.7528343200683594,0.2626663787024362,0.7266698009724958,0.2854241074372626,3554.0,0.7437895151930327,0.2867405733493612,3581.0,6930.2322034835815,7290.450036048889,6930.2322034835815,353.68150186538696,5.413628816604614,0.0 -29300,0.11920651,0.22425507,,,,,,,,,,,,,, -29400,0.14361309,0.22658366,,,,,,,,,,,,,, -29500,0.15477164,0.29283133,,,,,,,,,,,,,, -29600,0.16823144,0.26996207,,,,,,,,,,,,,, -29634,,,0.7528693335396903,0.2621491125651768,0.7264843942344893,0.2850999032505187,3554.0,0.7435791220154985,0.2864734912777506,3581.0,7010.228944301605,7374.552933931351,7010.228944301605,357.7428359985352,5.4454345703125,0.0 -29700,0.068483666,0.34114823,,,,,,,,,,,,,, -29800,0.14641857,0.29979715,,,,,,,,,,,,,, -29900,0.14226146,0.20605765,,,,,,,,,,,,,, -29980,,,0.7527691296168736,0.2622225625174386,0.7263874661472988,0.285038490273943,3554.0,0.743540056788432,0.2863717376103916,3581.0,7090.363373994827,7458.798035383224,7090.363373994827,361.8078374862671,5.477893114089966,0.0 -30000,0.0692183,0.27528208,,,,,,,,,,,,,, -30100,0.09627299,0.27169955,,,,,,,,,,,,,, -30200,0.12905502,0.24403793,,,,,,,,,,,,,, -30300,0.10294933,0.16398934,,,,,,,,,,,,,, -30322,,,0.7526637486049107,0.2622182880129133,0.7263018039752743,0.2850195477367051,3554.0,0.7435047412777507,0.2863525458801662,3581.0,7170.468336820602,7543.014036178589,7170.468336820602,365.8724312782288,5.510989665985107,0.0 -30400,0.13689499,0.23785338,,,,,,,,,,,,,, -30500,0.24810362,0.28252217,,,,,,,,,,,,,, -30600,0.08206068,0.2643607,,,,,,,,,,,,,, -30665,,,0.7531729425702777,0.2618430852890014,0.7264931871438872,0.2849867460629748,3554.0,0.7435993023073164,0.2863883386274783,3581.0,7250.497773885727,7627.158119440079,7250.497773885727,369.9424302577973,5.542418241500855,0.0 -30700,0.08942063,0.28808305,,,,,,,,,,,,,, -30800,0.0929675,0.28354397,,,,,,,,,,,,,, -30900,0.08212001,0.22103019,,,,,,,,,,,,,, -31000,0.06748356,0.26303077,,,,,,,,,,,,,, -31008,,,0.7529027802603585,0.2621064186096191,0.7264368575680571,0.2850429554232467,3554.0,0.7435995750139626,0.2863955653536023,3581.0,7330.379662036896,7711.4234845638275,7330.379662036896,374.01244950294495,5.842834711074829,0.0 -31100,0.13605471,0.2812281,,,,,,,,,,,,,, -31200,0.1553022,0.1890331,,,,,,,,,,,,,, -31300,0.08861885,0.2935522,,,,,,,,,,,,,, -31351,,,0.7530208315168109,0.2620646102087838,0.7265999385595456,0.2849707573937377,3554.0,0.743733473977241,0.286300186204098,3581.0,7410.376347541809,7795.527099370956,7410.376347541809,378.07388615608215,5.875456809997559,0.0 -31400,0.09842726,0.20883656,,,,,,,,,,,,,, -31500,0.09900382,0.3038549,,,,,,,,,,,,,, -31600,0.1313737,0.20309144,,,,,,,,,,,,,, -31697,,,0.753650392804827,0.2617484842027937,0.7268968366409327,0.2849589762690366,3554.0,0.7439781600155334,0.2863908270756248,3581.0,7490.46945476532,7879.72705745697,7490.46945476532,382.1353690624237,5.907495021820068,0.0 -31700,0.105678104,0.25430977,,,,,,,,,,,,,, -31800,0.108262695,0.27129576,,,,,,,,,,,,,, -31900,0.06383153,0.25533983,,,,,,,,,,,,,, -32000,0.15936837,0.2186281,,,,,,,,,,,,,, -32039,,,0.7531657900129046,0.2619015148707798,0.7266006255055923,0.2849250239606781,3554.0,0.7437529043257819,0.2862719610662175,3581.0,7570.541733264923,7963.911533355713,7570.541733264923,386.2018015384674,5.940229654312134,0.0 -32100,0.07290696,0.42483437,,,,,,,,,,,,,, -32200,0.09966556,0.27720508,,,,,,,,,,,,,, -32300,0.072978444,0.20083202,,,,,,,,,,,,,, -32380,,,0.7526820727757045,0.2619706051690237,0.7262389484120005,0.2849329410138664,3554.0,0.7434078622416923,0.2862588711472005,3581.0,7650.532155036926,8048.016820669174,7650.532155036926,390.2697324752808,5.974608659744263,0.0 -32400,0.049402308,0.33485952,,,,,,,,,,,,,, -32500,0.08384874,0.27547356,,,,,,,,,,,,,, -32600,0.06064002,0.26149732,,,,,,,,,,,,,, -32700,0.09935546,0.1929067,,,,,,,,,,,,,, -32725,,,0.7534105437142509,0.2616528442927769,0.7265398994750633,0.2849374576841235,3554.0,0.7436514574533999,0.2863338654749022,3581.0,7730.736515045166,8132.329214334488,7730.736515045166,394.3321797847748,6.007458448410034,0.0 -32800,0.046214547,0.3197526,,,,,,,,,,,,,, -32900,0.08261178,0.2320137,,,,,,,,,,,,,, -33000,0.062042795,0.3434059,,,,,,,,,,,,,, -33070,,,0.7533956936427525,0.2617532696042742,0.7267253062130697,0.2848618764453345,3554.0,0.7438506014817788,0.2862581212039234,3581.0,7810.769964694977,8216.474393606186,7810.769964694977,398.39750695228577,6.040522813796997,0.0 -33100,0.05276686,0.28607848,,,,,,,,,,,,,, -33200,0.07903318,0.3429837,,,,,,,,,,,,,, -33300,0.073079355,0.30488002,,,,,,,,,,,,,, -33400,0.086334586,0.2648785,,,,,,,,,,,,,, -33413,,,0.7529990332467216,0.2617817606244768,0.726368231657991,0.2848488759914005,3554.0,0.7435203537332449,0.2861579015114493,3581.0,7890.764701366425,8300.579656124115,7890.764701366425,402.461065530777,6.074344158172607,0.0 -33500,0.08677307,0.19478406,,,,,,,,,,,,,, -33600,0.094153,0.23570576,,,,,,,,,,,,,, -33700,0.05636058,0.34188247,,,,,,,,,,,,,, -33754,,,0.7533809798104423,0.2616239275251116,0.7266520777644907,0.2848602449484735,3554.0,0.7437582902820441,0.2862019436348087,3581.0,7970.896803617477,8384.830063343048,7970.896803617477,406.5280122756958,6.112351179122925,0.0 -33800,0.063189425,0.20640507,,,,,,,,,,,,,, -33900,0.0688106,0.31797928,,,,,,,,,,,,,, -34000,0.07077207,0.26732317,,,,,,,,,,,,,, -34098,,,0.7534258706229073,0.2616209132330758,0.7267079951726927,0.2848003947741541,3554.0,0.7438228535805291,0.2861940692304,3581.0,8051.041610479355,8469.08253955841,8051.041610479355,410.5884718894959,6.1458728313446045,0.0 -34100,0.07231735,0.21049635,,,,,,,,,,,,,, -34200,0.0734844,0.27856585,,,,,,,,,,,,,, -34300,0.072592415,0.25078386,,,,,,,,,,,,,, -34400,0.07120686,0.19773066,,,,,,,,,,,,,, -34442,,,0.7533926963806152,0.2616205726351057,0.7267149333277645,0.2847812976740556,3554.0,0.743851624131702,0.2861305626701689,3581.0,8131.060321331024,8553.214111566544,8131.060321331024,414.6534061431885,6.180485725402832,0.0 -34500,0.07811438,0.22743091,,,,,,,,,,,,,, -34600,0.07362585,0.21460903,,,,,,,,,,,,,, -34700,0.12531316,0.30729103,,,,,,,,,,,,,, -34783,,,0.7533390181405204,0.2616024698529924,0.726566690370885,0.2848096170248311,3554.0,0.7437214748848087,0.2861496180470713,3581.0,8211.12027812004,8637.385941267014,8211.12027812004,418.7173886299133,6.215153217315674,0.0 -34800,0.06534952,0.2567484,,,,,,,,,,,,,, -34900,0.08021844,0.17488417,,,,,,,,,,,,,, -35000,0.063017525,0.30784938,,,,,,,,,,,,,, -35100,0.041254245,0.251337,,,,,,,,,,,,,, -35127,,,0.7536563192095075,0.261509827205113,0.7268637258414814,0.284757271736072,3554.0,0.7439961586541818,0.2861220405874756,3581.0,8291.133207082748,8721.511438369751,8291.133207082748,422.782103061676,6.249707460403442,0.0 -35200,0.04978085,0.24876112,,,,,,,,,,,,,, -35300,0.056681868,0.33323646,,,,,,,,,,,,,, -35400,0.08185614,0.31663665,,,,,,,,,,,,,, -35473,,,0.7537084306989398,0.2615468502044678,0.7269505558217854,0.2847832382966376,3554.0,0.7440684940920832,0.2861458683306863,3581.0,8371.14042711258,8805.637065887451,8371.14042711258,426.8485105037689,6.288463354110718,0.0 -35500,0.06706356,0.27509567,,,,,,,,,,,,,, -35600,0.0634645,0.24916281,,,,,,,,,,,,,, -35700,0.088894255,0.16553514,,,,,,,,,,,,,, -35800,0.116396606,0.25773397,,,,,,,,,,,,,, -35816,,,0.7537142208644322,0.261535746710641,0.7269712328977912,0.2847685033039357,3554.0,0.7440939239868403,0.2861297786385611,3581.0,8451.220481395721,8889.82793712616,8451.220481395721,430.91059255599976,6.323931455612183,0.0 -35900,0.05642856,0.27955827,,,,,,,,,,,,,, -36000,0.049111195,0.21542455,,,,,,,,,,,,,, -36100,0.0508712,0.2633106,,,,,,,,,,,,,, -36161,,,0.753655093056815,0.2615071535110473,0.7268996531197243,0.2847359420613217,3554.0,0.7440289516283859,0.2860935768312797,3581.0,8531.392409324646,8974.114171504974,8531.392409324646,434.9767861366272,6.358710527420044,0.0 -36189,,,0.7536547524588448,0.2615070853914533,0.7268993783413056,0.2847359420613217,3554.0,0.7440284062150936,0.2860935768312797,3581.0,8535.762496232986,8982.592011928558,8535.762496232986,439.0459842681885,6.394179105758667,0.0 -36189,,,,,,,,,,,8535.762496232986,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/eval_measurements.csv deleted file mode 100644 index fbb570f99..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -4.018085956573486,0.0,30.921306133270264,1,0,30.921306133270264,1.0888061012112538,3581,0.2660334975118245,34.93953561782837,1.0853406361171178,0.246516398021153,1.0927659497880908,3554,0.2452213628679832 -8.083146572113037,0.0205724239349365,110.98090124130248,343,0,110.98090124130248,0.314064927089151,3581,0.7135120552700014,119.09683203697205,0.2926885059901646,0.7190193448747907,0.3120596675840602,3554,0.6958582099395048 -12.156377792358398,0.0475318431854248,191.13330054283145,682,0,191.13330054283145,0.3084947233991029,3581,0.7233772181958601,203.36154174804688,0.2871142455509731,0.7287539754595075,0.3065337705424346,3554,0.7065624780177265 -16.221641063690186,0.0735299587249755,271.3507161140442,1021,0,271.3507161140442,0.2976419192657428,3581,0.7285011715477521,287.68252325057983,0.2757679224014282,0.7351130758013044,0.2959042421940946,3554,0.710854448222953 -20.29216480255127,0.0981926918029785,351.4316899776459,1367,0,351.4316899776459,0.2956445817007295,3581,0.7317456306941148,371.8707139492035,0.2734766347067697,0.7387027059282575,0.29405591081352,3554,0.7141811218433455 -24.362091302871704,0.1249408721923828,431.6135165691376,1717,0,431.6135165691376,0.2967758711341106,3581,0.7316911575415387,456.1615972518921,0.2745262043816702,0.7386926923479352,0.2950654467237619,3554,0.7143179614958497 -28.435359239578247,0.149343729019165,511.7001166343689,2063,0,511.7001166343689,0.2973489300627443,3581,0.728711155665317,540.3580513000488,0.2759300640651158,0.732980591910226,0.2958190952316052,3554,0.7118857603228756 -32.50367093086243,0.1732890605926513,591.8303995132446,2408,0,591.8303995132446,0.2951985358925579,3581,0.7338002707431583,624.5928885936737,0.2731419120516096,0.7401093755449567,0.2937547193193409,3554,0.7167132736661156 -36.57264566421509,0.1989805698394775,672.050882101059,2755,0,672.050882101059,0.2934341920683817,3581,0.7364833632583426,708.9201610088348,0.2711653198514666,0.7433654921395438,0.2919886840751617,3554,0.7190582326911579 -40.64362549781799,0.2251121997833252,752.0927169322968,3098,0,752.0927169322968,0.2930829799964221,3581,0.7322224582649749,793.0711913108826,0.2710207189832415,0.7391443252563477,0.2916609077689751,3554,0.7145846339511818 -44.71409749984741,0.2506833076477051,832.1244361400604,3444,0,832.1244361400604,0.2921300066213174,3581,0.7370181409915177,877.211119890213,0.2696787970406668,0.744307381766183,0.2905521081549838,3554,0.7197982796672763 -48.78200364112854,0.2791216373443603,912.100329875946,3790,0,912.100329875946,0.2941143906206367,3581,0.7324526226743577,961.2952928543092,0.2722888674054827,0.738464491707938,0.2926425536697207,3554,0.7155386646208497 -52.84801506996155,0.3050611019134521,992.0747768878936,4135,0,992.0747768878936,0.2919348168393081,3581,0.7397047105164409,1045.3737080097198,0.2697600637163435,0.7464638437543597,0.2903878937025183,3554,0.7227100378754572 -56.91461539268494,0.3306634426116943,1072.153960943222,4479,0,1072.153960943222,0.2929800673258168,3581,0.7403491844980452,1129.5571291446686,0.2706511190959385,0.7473081861223493,0.2913767525367544,3554,0.7233477298906162 -60.98463320732117,0.3593010902404785,1152.2814135551453,4828,0,1152.2814135551453,0.2928126254450572,3581,0.739230269128735,1213.7953803539276,0.2700659206935337,0.7465886388506208,0.2913121796083638,3554,0.7221604123434862 -65.05647873878479,0.3860032558441162,1232.244782447815,5175,0,1232.244782447815,0.2925935397431758,3581,0.730916534313041,1297.869435787201,0.2703052588871547,0.7383684430803571,0.2912168314970808,3554,0.7130640101821891 -69.12384462356567,0.4116835594177246,1312.2321465015411,5519,0,1312.2321465015411,0.2919312375645769,3581,0.7368998544837336,1381.9618797302246,0.2692771298544748,0.7446727752685547,0.2904981828903172,3554,0.7195834716384707 -73.19486999511719,0.4371984004974365,1392.328580379486,5865,0,1392.328580379486,0.2922777795352206,3581,0.7348363514686889,1466.1667184829712,0.2701320307595389,0.7418381146022252,0.2907738199915588,3554,0.7172035470596511 -77.25946640968323,0.4633595943450928,1472.3439490795135,6212,0,1472.3439490795135,0.2911363658675998,3581,0.7384315113620497,1550.2850694656372,0.2688169138772147,0.7456263133457729,0.2895488234537668,3554,0.7211000424257878 -81.32579112052917,0.4954812526702881,1552.5109388828278,6555,0,1552.5109388828278,0.2919402368839011,3581,0.7346102094823374,1634.5626347064972,0.2695868696485247,0.7413718359810966,0.2903499055861353,3554,0.7175827412774338 -85.39356756210327,0.5252249240875244,1632.5720751285553,6904,0,1632.5720751285553,0.2937172274787943,3581,0.7325803175614354,1718.733685016632,0.2716102770396641,0.7396819250924247,0.2920786396599782,3554,0.7154917462058596 -89.4588303565979,0.5517885684967041,1712.6343562602997,7250,0,1712.6343562602997,0.2922027511191881,3581,0.734957365042935,1802.8997659683228,0.2699088539396013,0.742391858782087,0.2908533339964652,3554,0.7173541256330894 -93.52749109268188,0.5771715641021729,1792.7491919994354,7595,0,1792.7491919994354,0.2917915094967362,3581,0.7391194138770595,1887.1206386089325,0.2696233136313302,0.7462922504970005,0.2901129435473234,3554,0.7221008541212366 -97.5896496772766,0.6047909259796143,1872.717239141464,7941,0,1872.717239141464,0.2932556714801207,3581,0.7348287156825957,1971.190478086472,0.2708460944039481,0.7423366819109235,0.2916510500932048,3554,0.7175806117446891 -101.66148519515993,0.6309757232666016,1952.7435121536253,8288,0,1952.7435121536253,0.2910643713130061,3581,0.7391162777506283,2055.3272376060486,0.2686344725745065,0.7464926583426339,0.2896130872564364,3554,0.7218817870269415 -105.73176598548888,0.6581485271453857,2032.869675397873,8634,0,2032.869675397873,0.295108065462685,3581,0.7369340791678303,2139.5630140304565,0.2723206111363002,0.7441981860569545,0.2937800332811621,3554,0.7199123814056345 -109.80195808410645,0.6855344772338867,2113.0663661956787,8980,0,2113.0663661956787,0.2899546597930047,3581,0.7375566002644164,2223.8691778182983,0.2677069221224104,0.7444562230791364,0.2886012157296356,3554,0.7200566400754431 -113.86398601531982,0.7114531993865967,2193.08740067482,9327,0,2193.08740067482,0.2921331086594178,3581,0.7373822725408405,2307.990259170532,0.2699027742658342,0.744776862008231,0.2906218331787247,3554,0.7201405161877462 -117.93263745307922,0.7382164001464844,2273.16263628006,9671,0,2273.16263628006,0.2912903428577038,3581,0.7349934986735549,2392.173000574112,0.2686062029429844,0.7430753707885742,0.2899052454101013,3554,0.7174473442116277 -121.99877452850342,0.7642619609832764,2353.270979166031,10018,0,2353.270979166031,0.2935151518539688,3581,0.7365767652846621,2476.385600566864,0.2707388401031494,0.7443200520106724,0.2921717895439117,3554,0.7192811466833146 -126.0719668865204,0.7919392585754395,2433.446925401688,10365,0,2433.446925401688,0.2911413427638927,3581,0.7376453662777507,2560.674476146698,0.2684168475014822,0.7452825137547084,0.2898022378503974,3554,0.7201388675172341 -130.13805532455444,0.8192059993743896,2513.66655087471,10711,0,2513.66655087471,0.2900784686103742,3581,0.7377283372748534,2644.9995489120483,0.2677556276321411,0.7451065608433315,0.2886788406329136,3554,0.7203397305412915 -134.20095443725586,0.8461177349090576,2593.805549621582,11057,0,2593.805549621582,0.2898371573168284,3581,0.7402480103323094,2729.240618467331,0.2672877141407558,0.7475382941109794,0.2884536940661051,3554,0.7230717836636537 -138.2629885673523,0.8738346099853516,2673.933702230453,11404,0,2673.933702230453,0.2903264271284208,3581,0.7352287763325538,2813.4708971977234,0.2679642949785505,0.7428219658987862,0.2890443302770645,3554,0.7175823291098059 -142.33258748054504,0.9034388065338136,2753.910780191421,11751,0,2753.910780191421,0.290267249786198,3581,0.7377622892523038,2897.559303998947,0.2679216861724853,0.7452921867370605,0.2889721322475556,3554,0.7202087299301843 -146.40227818489075,0.9323925971984864,2833.876010656357,12095,0,2833.876010656357,0.2911458765118856,3581,0.7378974835721517,2981.635125398636,0.2683262654713222,0.7459325109209333,0.2895289020184123,3554,0.7207554015941545 -150.47212028503418,0.960963010787964,2913.926083803177,12442,0,2913.926083803177,0.2897804002460905,3581,0.738725284596656,3065.7957208156586,0.2672440324510847,0.7460260391235352,0.2883688562293366,3554,0.7213144382869654 -154.54147481918335,0.9875788688659668,2993.89020228386,12787,0,2993.89020228386,0.2897285859833147,3581,0.7387887570685563,3149.8680698871613,0.267539484160287,0.7459701129368373,0.2883787825997116,3554,0.7213828581132175 -158.60821294784546,1.0157010555267334,3073.952570438385,13129,0,3073.952570438385,0.2914820215416434,3581,0.7373984985862887,3234.037271499634,0.2689516884940011,0.7447130339486259,0.2901015058956457,3554,0.7199925480092854 -162.68055200576782,1.0426347255706787,3154.042021512985,13476,0,3154.042021512985,0.28967902155037,3581,0.7384223756894024,3318.2383949756622,0.2672208036695208,0.7458803313119071,0.2882889300568022,3554,0.7209985804946891 -166.74572896957395,1.071657419204712,3234.0972611904144,13822,0,3234.0972611904144,0.2893738969016162,3581,0.7383779245060738,3402.4001002311707,0.2666973727090018,0.7462607792445591,0.2880001894597196,3554,0.7209504255768149 -170.81399512290955,1.099256992340088,3314.093006849289,14165,0,3314.093006849289,0.2902781580520455,3581,0.7433448670064228,3486.50393986702,0.2673636163984026,0.7511077608380999,0.2888397233970526,3554,0.7264354836759637 -174.87861514091492,1.128664493560791,3394.106276988983,14512,0,3394.106276988983,0.289621071388055,3581,0.7409909995898492,3570.623516082764,0.2671572991779872,0.7483531406947544,0.2885451265849219,3554,0.7235218707134566 -178.94793677330017,1.155869722366333,3474.10543012619,14857,0,3474.10543012619,0.2893824871609711,3581,0.7379840679323164,3654.7311673164368,0.2669458218983241,0.7454121453421456,0.288061774172807,3554,0.7204186606420583 -183.0156137943268,1.1840167045593262,3554.072303533554,15193,0,3554.072303533554,0.2890145718069324,3581,0.7402746674069743,3738.805694103241,0.2667359624590192,0.7471641813005719,0.2876480952634795,3554,0.7230010282208427 -187.08305025100708,1.214017391204834,3634.212679386139,15538,0,3634.212679386139,0.289791001716961,3581,0.7416667666591036,3823.0554959774017,0.2667510339191982,0.7494399888174874,0.2883546364461698,3554,0.7244877855497327 -191.1493947505951,1.2421448230743408,3714.300877094269,15884,0,3714.300877094269,0.2901403730190589,3581,0.7372491235208392,3907.250316858292,0.2677041462489536,0.7447804042271206,0.2888085016992297,3554,0.7197174948121835 -195.2165243625641,1.2717201709747314,3794.307913780213,16229,0,3794.307913780213,0.2892953232991831,3581,0.7380651981595574,3991.366268634796,0.2670665298189436,0.7450590133666992,0.2879636782773371,3554,0.720541830068233 -199.2868800163269,1.3034842014312744,3874.412040710449,16575,0,3874.412040710449,0.2887590797677848,3581,0.7400445029975915,4075.5846552848816,0.2658184255872454,0.7477623394557408,0.2874061356921778,3554,0.7226664168014912 -203.36089968681333,1.3331928253173828,3954.473661661148,16923,0,3954.473661661148,0.2885726506867844,3581,0.7386974003420832,4159.762113332748,0.265954852104187,0.7463718141828265,0.2872358245935477,3554,0.7211565093908272 -207.42804789543152,1.3615479469299316,4034.495836257935,17265,0,4034.495836257935,0.289015083131894,3581,0.7388150732599135,4243.891760110855,0.2664083412715367,0.7463821683611188,0.2878108499555958,3554,0.7211328784468205 -211.501356124878,1.390101432800293,4114.508437395096,17608,0,4114.508437395096,0.2888032923327981,3581,0.7425831291669576,4328.01832151413,0.2660165173666818,0.7503792217799595,0.2874407921202342,3554,0.7255287148943093 -215.56890416145325,1.4220576286315918,4194.599543809891,17953,0,4194.599543809891,0.2886968344757923,3581,0.7408580550998325,4412.221243619919,0.2660280976976667,0.748267650604248,0.2873225343582934,3554,0.7235039414216375 -219.6377148628235,1.4514093399047852,4274.669712781906,18299,0,4274.669712781906,0.2881057769085102,3581,0.7409523434227521,4496.401560783386,0.2655689546040126,0.7484581129891532,0.2867548936662475,3554,0.7236135093160875 -223.7048351764679,1.4820961952209473,4354.866664648056,18646,0,4354.866664648056,0.2888511523492041,3581,0.737948002478358,4580.708430051804,0.2659528255462646,0.7454803330557687,0.2874833656014789,3554,0.7206421241910523 -227.77167344093323,1.5122063159942627,4434.899230241776,18989,0,4434.899230241776,0.2887143899661407,3581,0.7397250953382435,4664.850015163422,0.2660172326224191,0.7470806666782924,0.2872723529495815,3554,0.7226263678469682 -231.8424940109253,1.5426294803619385,4514.962484121323,19332,0,4514.962484121323,0.2883070344133971,3581,0.7417560099090686,4749.026458740234,0.2656295129231044,0.7494047028677804,0.2869826334543824,3554,0.7244741840180079 -235.91216826438904,1.573349952697754,4595.0616002082825,19676,0,4595.0616002082825,0.2881644088374406,3581,0.740140154853742,4833.237877607346,0.2648649385997227,0.7484260967799595,0.2867769446343469,3554,0.7227509111652364 -239.9841315746308,1.6079411506652832,4675.134254932404,20023,0,4675.134254932404,0.2881558185780857,3581,0.738479030495148,4917.429103136063,0.2652723789215088,0.7465226990836007,0.2867201341962841,3554,0.7210318973779544 -244.053169965744,1.639510154724121,4755.205282211304,20368,0,4755.205282211304,0.2882353125654496,3581,0.7434962191950573,5001.61284160614,0.2655698912484305,0.7510388919285366,0.2868956832585203,3554,0.7263813523274831 -248.11840963363647,1.670926809310913,4835.244169473648,20714,0,4835.244169473648,0.2880398500767941,3581,0.7412412761143884,5085.760347127914,0.2648200818470546,0.7495647839137486,0.2867387676078011,3554,0.7239776594154473 -252.1876802444458,1.7009556293487549,4915.309281587601,21059,0,4915.309281587601,0.2880960958225705,3581,0.7430783644364354,5169.936907529831,0.2651166234697614,0.7508315358843122,0.286702204904465,3554,0.7259722759566686 -256.25526690483093,1.7352640628814695,4995.265493392944,21405,0,4995.265493392944,0.2880237944729998,3581,0.7425363599771363,5254.006997823715,0.2652876717703683,0.7501467296055385,0.2866950778392304,3554,0.7254517082424733 -260.32216477394104,1.7701237201690674,5075.27999830246,21750,0,5075.27999830246,0.2883760632832134,3581,0.7422327693032672,5338.135623931885,0.2651770285197666,0.7506645747593471,0.2870192476786719,3554,0.7251320035523354 -264.38602781295776,1.8033950328826904,5155.243874788284,22097,0,5155.243874788284,0.2875862707475216,3581,0.7391977488611771,5422.208884477615,0.2646193844931466,0.7473197664533343,0.2862578366805008,3554,0.7216715815366489 -268.45214796066284,1.8332464694976809,5235.346307039261,22440,0,5235.346307039261,0.287656799503892,3581,0.7404796746282463,5506.419248342514,0.2648459843226841,0.7482331139700753,0.2863896444532129,3554,0.7230001351909819 -272.5250914096832,1.8631410598754885,5315.468108892441,22783,0,5315.468108892441,0.2875015953338802,3581,0.7431066577509774,5590.655894994736,0.2646455764770508,0.7507668903895787,0.2861737373107326,3554,0.7259625213228053 -276.5922944545746,1.8937137126922607,5395.63080406189,23131,0,5395.63080406189,0.2874370320353951,3581,0.7412096421434307,5674.928566217423,0.2643979617527553,0.7489895820617676,0.2860227637433173,3554,0.7240669624015195 -280.66194915771484,1.9245622158050537,5475.673554420471,23478,0,5475.673554420471,0.2886014894146188,3581,0.7412838183511938,5759.084021091461,0.2652779647282192,0.7495903968811035,0.2874707944888242,3554,0.7237818110975309 -284.73019194602966,1.9562304019927976,5555.82008767128,23824,0,5555.82008767128,0.287390842347197,3581,0.7429002869964745,5843.342591047287,0.2644648041043962,0.7507916178022113,0.2861515661270751,3554,0.7256301081308033 -288.7960410118103,1.9867405891418457,5635.91052532196,24168,0,5635.91052532196,0.2874138178821384,3581,0.7394907039758447,5927.541179180145,0.2642674275806972,0.7475453104291644,0.2861658546048466,3554,0.7220405402583356 -292.8611936569214,2.0173099040985107,5715.999706029892,24514,0,5715.999706029892,0.2871191242626012,3581,0.7424466394905403,6011.738223552704,0.2639252458299909,0.7507332393101284,0.2857912629255768,3554,0.7251709533931837 -296.9268276691437,2.047306060791016,5796.07946395874,24858,0,5796.07946395874,0.2872798166538676,3581,0.7425150888587336,6095.925489425659,0.2644594056265695,0.7502329690115792,0.2860405212986248,3554,0.7251993242649127 -300.9990952014923,2.07837176322937,5876.218941450119,25203,0,5876.218941450119,0.2870587538288013,3581,0.7400631834028554,6180.180327177048,0.2636398417609079,0.748668943132673,0.2858235837370744,3554,0.7225095183244232 -305.06549167633057,2.109112024307251,5956.200823068619,25549,0,5956.200823068619,0.2870998984440449,3581,0.7440258155019548,6264.271352767944,0.2639256715774536,0.7521319389343262,0.2857768370585959,3554,0.7268184560970034 -309.1350910663605,2.1402125358581543,6036.274323225021,25895,0,6036.274323225021,0.2869088674383901,3581,0.7425444729998604,6348.457298994064,0.2637628316879272,0.7507930483136859,0.2856170534081317,3554,0.7252840934070766 -313.1980721950531,2.1756951808929443,6116.422718524933,26241,0,6116.422718524933,0.2866761464041992,3581,0.7435862805649609,6432.716036319733,0.2629812274660383,0.7523622512817383,0.2854140780249806,3554,0.7263478980550084 -317.26215529441833,2.207003831863404,6196.501030445099,26588,0,6196.501030445099,0.2865871076842188,3581,0.7424559796931723,6516.901560783386,0.263419577053615,0.7505794933864048,0.2853495050965901,3554,0.7252331220104108 -321.3304181098938,2.2384111881256104,6276.576962709427,26935,0,6276.576962709427,0.2867723436736421,3581,0.7430407990959229,6601.088958978653,0.2635314805167062,0.7514126641409737,0.2855885279735421,3554,0.7257549262274902 -325.4008071422577,2.2764458656311035,6356.704101800919,27278,0,6356.704101800919,0.2867761615666888,3581,0.7421731147244136,6685.336197853088,0.2628715208598545,0.7512620517185756,0.2853974882779526,3554,0.7248926715496623 -329.46841621398926,2.3095593452453613,6436.780147314072,27624,0,6436.780147314072,0.286496466812692,3581,0.7441440338330774,6769.524851560593,0.2629967076437814,0.7526355470929827,0.2851656955081422,3554,0.7269753545740715 -333.53331780433655,2.341780662536621,6516.768382072449,27970,0,6516.768382072449,0.2865856759743263,3581,0.7426191264442544,6853.62221121788,0.2631453445979527,0.7511420931134906,0.2853274197811885,3554,0.7253446133537915 -337.60342359542847,2.3730151653289795,6596.887026309967,28312,0,6596.887026309967,0.2868086818342467,3581,0.7436688425020944,6937.854245662689,0.2629969120025635,0.7525931085859027,0.285445660369478,3554,0.7264981331554234 -341.6689395904541,2.404748678207397,6676.900361776352,28659,0,6676.900361776352,0.2864548449608175,3581,0.742927898544401,7021.976664066315,0.2628911563328334,0.7515886851719448,0.2851351092354126,3554,0.7257243571284117 -345.7351453304291,2.437550067901612,6757.101865053177,29005,0,6757.101865053177,0.2864383121203923,3581,0.7437804476970469,7106.2891047000885,0.2629059553146362,0.7523136820111956,0.2851231220268975,3554,0.7265699877119092 -349.80698704719543,2.470907688140869,6837.253606081009,29349,0,6837.253606081009,0.2862742790727101,3581,0.7425458365330914,7190.557967662811,0.2624988215310233,0.7514803750174386,0.285028031520382,3554,0.7252164979160804 -353.8756756782532,2.50755262374878,6917.243940591812,29696,0,6917.243940591812,0.2861342101215617,3581,0.7433965449158755,7274.665694713592,0.2623601811272757,0.7522429057529995,0.2848330075377215,3554,0.7262109897078995 -357.9442069530487,2.5392658710479736,6997.3993401527405,30042,0,6997.3993401527405,0.2860980764909417,3581,0.7436579342362468,7358.933331489563,0.2624179295131138,0.7523172923496791,0.2848615329723111,3554,0.7264418722741981 -362.0090301036835,2.570589065551758,7077.524636507034,30384,0,7077.524636507034,0.2862335094291049,3581,0.7439412764416364,7443.16653752327,0.2624817064830235,0.752615247453962,0.2849113709079998,3554,0.7267547761984735 -366.0766587257385,2.6026461124420166,7157.639096975327,30730,0,7157.639096975327,0.2860827708304244,3581,0.7433814778736736,7527.392614603043,0.2620935099465506,0.7524106161934989,0.2847992269658747,3554,0.7261035513461944 -370.1431245803833,2.635183095932007,7237.776307582855,31078,0,7237.776307582855,0.2860759531642697,3581,0.7438499197151633,7611.640864372253,0.2622142178671701,0.7527110236031669,0.2848172077786473,3554,0.7266209591085748 -374.21287846565247,2.667577743530273,7317.927659273148,31422,0,7317.927659273148,0.2860836571270245,3581,0.7436331861081052,7695.90665435791,0.2622088704790388,0.7525147029331752,0.2847631794720737,3554,0.7264274464072172 -378.281388759613,2.6994807720184326,7398.061017036438,31769,0,7398.061017036438,0.2860863841934865,3581,0.7435474198678791,7780.152475118637,0.2618723937443324,0.752835886819022,0.2847628359990504,3554,0.7263305870146314 -382.349130153656,2.732588291168213,7478.068758964539,32114,0,7478.068758964539,0.2859717451370951,3581,0.7438437838156241,7864.272994995117,0.2619447537830898,0.752873148236956,0.2846750614679322,3554,0.7266755026246835 -386.4174609184265,2.765906810760498,7558.151543617248,32457,0,7558.151543617248,0.2859257088463854,3581,0.7433731603209648,7948.469160079956,0.2619777406964983,0.7523930413382394,0.2846769333959095,3554,0.7261111764473129 -390.4858248233795,2.7979736328125,7638.266149044037,32802,0,7638.266149044037,0.2859749835285186,3581,0.7439243686295728,8032.69620680809,0.2616222926548549,0.7533761433192662,0.2846399413512943,3554,0.7267281226918613 -394.5550725460053,2.830050706863404,7718.342501401901,33147,0,7718.342501401901,0.2860920769447256,3581,0.743834920849623,8116.885622739792,0.2618422337940761,0.7531262125287738,0.2847223920505504,3554,0.726640124903278 -398.6193375587464,2.863105297088623,7798.298315048218,33490,0,7798.298315048218,0.2859360546547752,3581,0.7436743648116797,8200.950634002686,0.2618260894502912,0.7528594561985561,0.2846458319136448,3554,0.7264542373030388 -402.6862740516663,2.8966965675354004,7878.393220424652,33835,0,7878.393220424652,0.2859982658584369,3581,0.7441031960128107,8285.157857656479,0.2615469523838588,0.7536260059901646,0.2846270611129185,3554,0.7269739119873734 -406.7518992424011,2.929734468460083,7958.514358758926,34182,0,7958.514358758926,0.2859778128599727,3581,0.7439697061095015,8369.389568090439,0.2616581235613142,0.7533986227852958,0.2846151597726593,3554,0.7268093197145822 -410.8144009113312,2.966977119445801,8038.634386777878,34529,0,8038.634386777878,0.2859718814904182,3581,0.7439403219683748,8453.621324062347,0.2616626535143171,0.7533296176365444,0.2845997206602595,3554,0.7267763463043402 -414.8794231414795,3.003573179244995,8118.746688127518,34873,0,8118.746688127518,0.2859627117294401,3581,0.7440671987355139,8537.84702205658,0.2615305015019008,0.7535209655761719,0.2845629862204119,3554,0.7269215666986142 -418.9490053653717,3.041345357894897,8198.770928144455,35217,0,8198.770928144455,0.2859043525071558,3581,0.7439063699909243,8621.990622997284,0.2615130628858294,0.753373282296317,0.2845305623670072,3554,0.7267270235781865 -423.0212299823761,3.075789213180542,8278.949311256409,35562,0,8278.949311256409,0.2859195729468462,3581,0.7439984084840129,8706.287668466568,0.2615209477288382,0.7534516879490444,0.2845391148452887,3554,0.7268488191122678 -427.092668056488,3.1104507446289062,8359.058719158173,35906,0,8359.058719158173,0.285884973291111,3581,0.7438668957038885,8790.514942407608,0.2614938191005161,0.7533174923488072,0.2845197601404227,3554,0.7266854633423607 -431.1592288017273,3.1441428661346436,8423.860284805298,36189,0,8423.860284805298,0.28589658036773946,3581,0.7441039459560876,8859.426425457,0.2615004437310355,0.7535576820373535,0.28452429398433104,3554,0.7269425185530388 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/measurements.csv deleted file mode 100644 index f92ed588c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/measurements.csv +++ /dev/null @@ -1,470 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.668263,1.0711248,,,,,,,,,,,,,, -1,,,0.246516398021153,1.0853406361171178,0.2452213628679832,1.0927659497880908,3554.0,0.2660334975118245,1.0888061012112538,3581.0,30.921306133270264,34.93953561782837,30.921306133270264,4.018085956573486,0.0,0.0 -100,0.37125352,0.2824171,,,,,,,,,,,,,, -200,0.3005005,0.29679072,,,,,,,,,,,,,, -300,0.6221414,0.34453422,,,,,,,,,,,,,, -343,,,0.7190193448747907,0.2926885059901646,0.6958582099395048,0.3120596675840602,3554.0,0.7135120552700014,0.314064927089151,3581.0,110.98090124130248,119.09683203697205,110.98090124130248,8.083146572113037,0.0205724239349365,0.0 -400,0.3026434,0.34200683,,,,,,,,,,,,,, -500,0.2365207,0.28122163,,,,,,,,,,,,,, -600,0.14732997,0.2636921,,,,,,,,,,,,,, -682,,,0.7287539754595075,0.2871142455509731,0.7065624780177265,0.3065337705424346,3554.0,0.7233772181958601,0.3084947233991029,3581.0,191.13330054283145,203.36154174804688,191.13330054283145,12.156377792358398,0.0475318431854248,0.0 -700,0.16563047,0.24371421,,,,,,,,,,,,,, -800,0.2943948,0.22394517,,,,,,,,,,,,,, -900,0.2098713,0.34537268,,,,,,,,,,,,,, -1000,0.26081136,0.27014703,,,,,,,,,,,,,, -1021,,,0.7351130758013044,0.2757679224014282,0.710854448222953,0.2959042421940946,3554.0,0.7285011715477521,0.2976419192657428,3581.0,271.3507161140442,287.68252325057983,271.3507161140442,16.221641063690186,0.0735299587249755,0.0 -1100,0.4206351,0.29988453,,,,,,,,,,,,,, -1200,0.22812673,0.23332924,,,,,,,,,,,,,, -1300,0.16620336,0.2692418,,,,,,,,,,,,,, -1367,,,0.7387027059282575,0.2734766347067697,0.7141811218433455,0.29405591081352,3554.0,0.7317456306941148,0.2956445817007295,3581.0,351.4316899776459,371.8707139492035,351.4316899776459,20.29216480255127,0.0981926918029785,0.0 -1400,0.22901176,0.2615131,,,,,,,,,,,,,, -1500,0.107053556,0.2848257,,,,,,,,,,,,,, -1600,0.30564925,0.27042452,,,,,,,,,,,,,, -1700,0.15974982,0.32923132,,,,,,,,,,,,,, -1717,,,0.7386926923479352,0.2745262043816702,0.7143179614958497,0.2950654467237619,3554.0,0.7316911575415387,0.2967758711341106,3581.0,431.6135165691376,456.1615972518921,431.6135165691376,24.362091302871704,0.1249408721923828,0.0 -1800,0.28474578,0.26250562,,,,,,,,,,,,,, -1900,0.055954155,0.383483,,,,,,,,,,,,,, -2000,0.14678535,0.30193385,,,,,,,,,,,,,, -2063,,,0.732980591910226,0.2759300640651158,0.7118857603228756,0.2958190952316052,3554.0,0.728711155665317,0.2973489300627443,3581.0,511.7001166343689,540.3580513000488,511.7001166343689,28.435359239578247,0.149343729019165,0.0 -2100,0.081327155,0.26572797,,,,,,,,,,,,,, -2200,0.17360753,0.2744382,,,,,,,,,,,,,, -2300,0.37820974,0.27071553,,,,,,,,,,,,,, -2400,0.07135125,0.2810663,,,,,,,,,,,,,, -2408,,,0.7401093755449567,0.2731419120516096,0.7167132736661156,0.2937547193193409,3554.0,0.7338002707431583,0.2951985358925579,3581.0,591.8303995132446,624.5928885936737,591.8303995132446,32.50367093086243,0.1732890605926513,0.0 -2500,0.32039297,0.18723817,,,,,,,,,,,,,, -2600,0.18910748,0.2726809,,,,,,,,,,,,,, -2700,0.30902094,0.28545612,,,,,,,,,,,,,, -2755,,,0.7433654921395438,0.2711653198514666,0.7190582326911579,0.2919886840751617,3554.0,0.7364833632583426,0.2934341920683817,3581.0,672.050882101059,708.9201610088348,672.050882101059,36.57264566421509,0.1989805698394775,0.0 -2800,0.17283702,0.42980066,,,,,,,,,,,,,, -2900,0.09059563,0.27658147,,,,,,,,,,,,,, -3000,0.30593172,0.35261017,,,,,,,,,,,,,, -3098,,,0.7391443252563477,0.2710207189832415,0.7145846339511818,0.2916609077689751,3554.0,0.7322224582649749,0.2930829799964221,3581.0,752.0927169322968,793.0711913108826,752.0927169322968,40.64362549781799,0.2251121997833252,0.0 -3100,0.2717111,0.2772707,,,,,,,,,,,,,, -3200,0.18101759,0.3634211,,,,,,,,,,,,,, -3300,0.08782327,0.2701805,,,,,,,,,,,,,, -3400,0.14910631,0.3040413,,,,,,,,,,,,,, -3444,,,0.744307381766183,0.2696787970406668,0.7197982796672763,0.2905521081549838,3554.0,0.7370181409915177,0.2921300066213174,3581.0,832.1244361400604,877.211119890213,832.1244361400604,44.71409749984741,0.2506833076477051,0.0 -3500,0.19873974,0.32763794,,,,,,,,,,,,,, -3600,0.08789088,0.32866237,,,,,,,,,,,,,, -3700,0.3297413,0.25614646,,,,,,,,,,,,,, -3790,,,0.738464491707938,0.2722888674054827,0.7155386646208497,0.2926425536697207,3554.0,0.7324526226743577,0.2941143906206367,3581.0,912.100329875946,961.2952928543092,912.100329875946,48.78200364112854,0.2791216373443603,0.0 -3800,0.0991142,0.31076798,,,,,,,,,,,,,, -3900,0.19431353,0.31385607,,,,,,,,,,,,,, -4000,0.25191355,0.27032816,,,,,,,,,,,,,, -4100,0.10248632,0.24589294,,,,,,,,,,,,,, -4135,,,0.7464638437543597,0.2697600637163435,0.7227100378754572,0.2903878937025183,3554.0,0.7397047105164409,0.2919348168393081,3581.0,992.0747768878936,1045.3737080097198,992.0747768878936,52.84801506996155,0.3050611019134521,0.0 -4200,0.2692283,0.22557485,,,,,,,,,,,,,, -4300,0.16907784,0.2872391,,,,,,,,,,,,,, -4400,0.27962556,0.26259407,,,,,,,,,,,,,, -4479,,,0.7473081861223493,0.2706511190959385,0.7233477298906162,0.2913767525367544,3554.0,0.7403491844980452,0.2929800673258168,3581.0,1072.153960943222,1129.5571291446686,1072.153960943222,56.91461539268494,0.3306634426116943,0.0 -4500,0.14701454,0.3307448,,,,,,,,,,,,,, -4600,0.26229087,0.29076344,,,,,,,,,,,,,, -4700,0.2736896,0.3223075,,,,,,,,,,,,,, -4800,0.16872673,0.25563318,,,,,,,,,,,,,, -4828,,,0.7465886388506208,0.2700659206935337,0.7221604123434862,0.2913121796083638,3554.0,0.739230269128735,0.2928126254450572,3581.0,1152.2814135551453,1213.7953803539276,1152.2814135551453,60.98463320732117,0.3593010902404785,0.0 -4900,0.12249901,0.31554547,,,,,,,,,,,,,, -5000,0.2810619,0.3951137,,,,,,,,,,,,,, -5100,0.08058958,0.29351687,,,,,,,,,,,,,, -5175,,,0.7383684430803571,0.2703052588871547,0.7130640101821891,0.2912168314970808,3554.0,0.730916534313041,0.2925935397431758,3581.0,1232.244782447815,1297.869435787201,1232.244782447815,65.05647873878479,0.3860032558441162,0.0 -5200,0.09903727,0.26700562,,,,,,,,,,,,,, -5300,0.23239417,0.26537046,,,,,,,,,,,,,, -5400,0.09762867,0.3549126,,,,,,,,,,,,,, -5500,0.34031832,0.31706566,,,,,,,,,,,,,, -5519,,,0.7446727752685547,0.2692771298544748,0.7195834716384707,0.2904981828903172,3554.0,0.7368998544837336,0.2919312375645769,3581.0,1312.2321465015411,1381.9618797302246,1312.2321465015411,69.12384462356567,0.4116835594177246,0.0 -5600,0.28790042,0.2531293,,,,,,,,,,,,,, -5700,0.055365343,0.29812223,,,,,,,,,,,,,, -5800,0.10975805,0.27301878,,,,,,,,,,,,,, -5865,,,0.7418381146022252,0.2701320307595389,0.7172035470596511,0.2907738199915588,3554.0,0.7348363514686889,0.2922777795352206,3581.0,1392.328580379486,1466.1667184829712,1392.328580379486,73.19486999511719,0.4371984004974365,0.0 -5900,0.3395913,0.25733063,,,,,,,,,,,,,, -6000,0.23155962,0.2430994,,,,,,,,,,,,,, -6100,0.18188953,0.29540405,,,,,,,,,,,,,, -6200,0.25117025,0.27774805,,,,,,,,,,,,,, -6212,,,0.7456263133457729,0.2688169138772147,0.7211000424257878,0.2895488234537668,3554.0,0.7384315113620497,0.2911363658675998,3581.0,1472.3439490795135,1550.2850694656372,1472.3439490795135,77.25946640968323,0.4633595943450928,0.0 -6300,0.33880225,0.2260747,,,,,,,,,,,,,, -6400,0.1240884,0.3022476,,,,,,,,,,,,,, -6500,0.1792424,0.30211806,,,,,,,,,,,,,, -6555,,,0.7413718359810966,0.2695868696485247,0.7175827412774338,0.2903499055861353,3554.0,0.7346102094823374,0.2919402368839011,3581.0,1552.5109388828278,1634.5626347064972,1552.5109388828278,81.32579112052917,0.4954812526702881,0.0 -6600,0.34702253,0.25292554,,,,,,,,,,,,,, -6700,0.25109425,0.27766106,,,,,,,,,,,,,, -6800,0.15197901,0.27173576,,,,,,,,,,,,,, -6900,0.11261718,0.3449242,,,,,,,,,,,,,, -6904,,,0.7396819250924247,0.2716102770396641,0.7154917462058596,0.2920786396599782,3554.0,0.7325803175614354,0.2937172274787943,3581.0,1632.5720751285553,1718.733685016632,1632.5720751285553,85.39356756210327,0.5252249240875244,0.0 -7000,0.14662637,0.25031084,,,,,,,,,,,,,, -7100,0.069393285,0.27425748,,,,,,,,,,,,,, -7200,0.24028115,0.27512953,,,,,,,,,,,,,, -7250,,,0.742391858782087,0.2699088539396013,0.7173541256330894,0.2908533339964652,3554.0,0.734957365042935,0.2922027511191881,3581.0,1712.6343562602997,1802.8997659683228,1712.6343562602997,89.4588303565979,0.5517885684967041,0.0 -7300,0.19624281,0.30941057,,,,,,,,,,,,,, -7400,0.24205138,0.28126302,,,,,,,,,,,,,, -7500,0.26615855,0.27376044,,,,,,,,,,,,,, -7595,,,0.7462922504970005,0.2696233136313302,0.7221008541212366,0.2901129435473234,3554.0,0.7391194138770595,0.2917915094967362,3581.0,1792.7491919994354,1887.1206386089325,1792.7491919994354,93.52749109268188,0.5771715641021729,0.0 -7600,0.08127473,0.26877406,,,,,,,,,,,,,, -7700,0.17889638,0.22094357,,,,,,,,,,,,,, -7800,0.29796806,0.2917863,,,,,,,,,,,,,, -7900,0.28061348,0.29273528,,,,,,,,,,,,,, -7941,,,0.7423366819109235,0.2708460944039481,0.7175806117446891,0.2916510500932048,3554.0,0.7348287156825957,0.2932556714801207,3581.0,1872.717239141464,1971.190478086472,1872.717239141464,97.5896496772766,0.6047909259796143,0.0 -8000,0.19279619,0.30035007,,,,,,,,,,,,,, -8100,0.09684072,0.19737579,,,,,,,,,,,,,, -8200,0.109252766,0.32889748,,,,,,,,,,,,,, -8288,,,0.7464926583426339,0.2686344725745065,0.7218817870269415,0.2896130872564364,3554.0,0.7391162777506283,0.2910643713130061,3581.0,1952.7435121536253,2055.3272376060486,1952.7435121536253,101.66148519515993,0.6309757232666016,0.0 -8300,0.21748616,0.20196831,,,,,,,,,,,,,, -8400,0.08636936,0.20169969,,,,,,,,,,,,,, -8500,0.19372809,0.25779593,,,,,,,,,,,,,, -8600,0.36192325,0.17781731,,,,,,,,,,,,,, -8634,,,0.7441981860569545,0.2723206111363002,0.7199123814056345,0.2937800332811621,3554.0,0.7369340791678303,0.295108065462685,3581.0,2032.869675397873,2139.5630140304565,2032.869675397873,105.73176598548888,0.6581485271453857,0.0 -8700,0.23812827,0.27086163,,,,,,,,,,,,,, -8800,0.056075923,0.2594843,,,,,,,,,,,,,, -8900,0.13967614,0.25568464,,,,,,,,,,,,,, -8980,,,0.7444562230791364,0.2677069221224104,0.7200566400754431,0.2886012157296356,3554.0,0.7375566002644164,0.2899546597930047,3581.0,2113.0663661956787,2223.8691778182983,2113.0663661956787,109.80195808410645,0.6855344772338867,0.0 -9000,0.13409676,0.34319568,,,,,,,,,,,,,, -9100,0.0704554,0.28795275,,,,,,,,,,,,,, -9200,0.1192035,0.2448439,,,,,,,,,,,,,, -9300,0.14077367,0.41920507,,,,,,,,,,,,,, -9327,,,0.744776862008231,0.2699027742658342,0.7201405161877462,0.2906218331787247,3554.0,0.7373822725408405,0.2921331086594178,3581.0,2193.08740067482,2307.990259170532,2193.08740067482,113.86398601531982,0.7114531993865967,0.0 -9400,0.15565367,0.28348777,,,,,,,,,,,,,, -9500,0.16521476,0.32447666,,,,,,,,,,,,,, -9600,0.17870925,0.29336393,,,,,,,,,,,,,, -9671,,,0.7430753707885742,0.2686062029429844,0.7174473442116277,0.2899052454101013,3554.0,0.7349934986735549,0.2912903428577038,3581.0,2273.16263628006,2392.173000574112,2273.16263628006,117.93263745307922,0.7382164001464844,0.0 -9700,0.09565864,0.28943318,,,,,,,,,,,,,, -9800,0.2577873,0.2285631,,,,,,,,,,,,,, -9900,0.1438051,0.2934282,,,,,,,,,,,,,, -10000,0.13940263,0.23240717,,,,,,,,,,,,,, -10018,,,0.7443200520106724,0.2707388401031494,0.7192811466833146,0.2921717895439117,3554.0,0.7365767652846621,0.2935151518539688,3581.0,2353.270979166031,2476.385600566864,2353.270979166031,121.99877452850342,0.7642619609832764,0.0 -10100,0.17770755,0.22867994,,,,,,,,,,,,,, -10200,0.27914888,0.34482238,,,,,,,,,,,,,, -10300,0.26233923,0.2721156,,,,,,,,,,,,,, -10365,,,0.7452825137547084,0.2684168475014822,0.7201388675172341,0.2898022378503974,3554.0,0.7376453662777507,0.2911413427638927,3581.0,2433.446925401688,2560.674476146698,2433.446925401688,126.0719668865204,0.7919392585754395,0.0 -10400,0.103212036,0.23735559,,,,,,,,,,,,,, -10500,0.089742996,0.259174,,,,,,,,,,,,,, -10600,0.23108768,0.25273067,,,,,,,,,,,,,, -10700,0.072602734,0.22304522,,,,,,,,,,,,,, -10711,,,0.7451065608433315,0.2677556276321411,0.7203397305412915,0.2886788406329136,3554.0,0.7377283372748534,0.2900784686103742,3581.0,2513.66655087471,2644.9995489120483,2513.66655087471,130.13805532455444,0.8192059993743896,0.0 -10800,0.26825386,0.26234335,,,,,,,,,,,,,, -10900,0.15634868,0.28493333,,,,,,,,,,,,,, -11000,0.22008657,0.3336856,,,,,,,,,,,,,, -11057,,,0.7475382941109794,0.2672877141407558,0.7230717836636537,0.2884536940661051,3554.0,0.7402480103323094,0.2898371573168284,3581.0,2593.805549621582,2729.240618467331,2593.805549621582,134.20095443725586,0.8461177349090576,0.0 -11100,0.20334259,0.33370662,,,,,,,,,,,,,, -11200,0.09429886,0.24202125,,,,,,,,,,,,,, -11300,0.20717283,0.2878765,,,,,,,,,,,,,, -11400,0.14754319,0.28501,,,,,,,,,,,,,, -11404,,,0.7428219658987862,0.2679642949785505,0.7175823291098059,0.2890443302770645,3554.0,0.7352287763325538,0.2903264271284208,3581.0,2673.933702230453,2813.4708971977234,2673.933702230453,138.2629885673523,0.8738346099853516,0.0 -11500,0.16081095,0.2575646,,,,,,,,,,,,,, -11600,0.11283866,0.23054971,,,,,,,,,,,,,, -11700,0.2629473,0.30455157,,,,,,,,,,,,,, -11751,,,0.7452921867370605,0.2679216861724853,0.7202087299301843,0.2889721322475556,3554.0,0.7377622892523038,0.290267249786198,3581.0,2753.910780191421,2897.559303998947,2753.910780191421,142.33258748054504,0.9034388065338136,0.0 -11800,0.09515699,0.28854638,,,,,,,,,,,,,, -11900,0.19477507,0.30026495,,,,,,,,,,,,,, -12000,0.14555329,0.33393943,,,,,,,,,,,,,, -12095,,,0.7459325109209333,0.2683262654713222,0.7207554015941545,0.2895289020184123,3554.0,0.7378974835721517,0.2911458765118856,3581.0,2833.876010656357,2981.635125398636,2833.876010656357,146.40227818489075,0.9323925971984864,0.0 -12100,0.1580329,0.29265487,,,,,,,,,,,,,, -12200,0.18369214,0.25380546,,,,,,,,,,,,,, -12300,0.3107535,0.21350214,,,,,,,,,,,,,, -12400,0.12447972,0.27905497,,,,,,,,,,,,,, -12442,,,0.7460260391235352,0.2672440324510847,0.7213144382869654,0.2883688562293366,3554.0,0.738725284596656,0.2897804002460905,3581.0,2913.926083803177,3065.7957208156586,2913.926083803177,150.47212028503418,0.960963010787964,0.0 -12500,0.14211582,0.33160758,,,,,,,,,,,,,, -12600,0.15043667,0.2677297,,,,,,,,,,,,,, -12700,0.45070195,0.21373151,,,,,,,,,,,,,, -12787,,,0.7459701129368373,0.267539484160287,0.7213828581132175,0.2883787825997116,3554.0,0.7387887570685563,0.2897285859833147,3581.0,2993.89020228386,3149.8680698871613,2993.89020228386,154.54147481918335,0.9875788688659668,0.0 -12800,0.073637895,0.27402118,,,,,,,,,,,,,, -12900,0.18964808,0.22381763,,,,,,,,,,,,,, -13000,0.24991012,0.32894686,,,,,,,,,,,,,, -13100,0.22487865,0.24132526,,,,,,,,,,,,,, -13129,,,0.7447130339486259,0.2689516884940011,0.7199925480092854,0.2901015058956457,3554.0,0.7373984985862887,0.2914820215416434,3581.0,3073.952570438385,3234.037271499634,3073.952570438385,158.60821294784546,1.0157010555267334,0.0 -13200,0.21202794,0.24495807,,,,,,,,,,,,,, -13300,0.12305536,0.35750085,,,,,,,,,,,,,, -13400,0.19760758,0.3683623,,,,,,,,,,,,,, -13476,,,0.7458803313119071,0.2672208036695208,0.7209985804946891,0.2882889300568022,3554.0,0.7384223756894024,0.28967902155037,3581.0,3154.042021512985,3318.2383949756622,3154.042021512985,162.68055200576782,1.0426347255706787,0.0 -13500,0.1965921,0.34536034,,,,,,,,,,,,,, -13600,0.11148954,0.24391152,,,,,,,,,,,,,, -13700,0.15474223,0.3768348,,,,,,,,,,,,,, -13800,0.28339416,0.21951155,,,,,,,,,,,,,, -13822,,,0.7462607792445591,0.2666973727090018,0.7209504255768149,0.2880001894597196,3554.0,0.7383779245060738,0.2893738969016162,3581.0,3234.0972611904144,3402.4001002311707,3234.0972611904144,166.74572896957395,1.071657419204712,0.0 -13900,0.31529063,0.24288484,,,,,,,,,,,,,, -14000,0.13527186,0.2657569,,,,,,,,,,,,,, -14100,0.069294944,0.28031918,,,,,,,,,,,,,, -14165,,,0.7511077608380999,0.2673636163984026,0.7264354836759637,0.2888397233970526,3554.0,0.7433448670064228,0.2902781580520455,3581.0,3314.093006849289,3486.50393986702,3314.093006849289,170.81399512290955,1.099256992340088,0.0 -14200,0.30650353,0.2361398,,,,,,,,,,,,,, -14300,0.08351701,0.23186442,,,,,,,,,,,,,, -14400,0.12668438,0.31028634,,,,,,,,,,,,,, -14500,0.14161272,0.35601497,,,,,,,,,,,,,, -14512,,,0.7483531406947544,0.2671572991779872,0.7235218707134566,0.2885451265849219,3554.0,0.7409909995898492,0.289621071388055,3581.0,3394.106276988983,3570.623516082764,3394.106276988983,174.87861514091492,1.128664493560791,0.0 -14600,0.3365062,0.21882527,,,,,,,,,,,,,, -14700,0.05223781,0.27627277,,,,,,,,,,,,,, -14800,0.15524098,0.28067094,,,,,,,,,,,,,, -14857,,,0.7454121453421456,0.2669458218983241,0.7204186606420583,0.288061774172807,3554.0,0.7379840679323164,0.2893824871609711,3581.0,3474.10543012619,3654.7311673164368,3474.10543012619,178.94793677330017,1.155869722366333,0.0 -14900,0.23113722,0.20996219,,,,,,,,,,,,,, -15000,0.14904651,0.21237588,,,,,,,,,,,,,, -15100,0.31672233,0.27098727,,,,,,,,,,,,,, -15193,,,0.7471641813005719,0.2667359624590192,0.7230010282208427,0.2876480952634795,3554.0,0.7402746674069743,0.2890145718069324,3581.0,3554.072303533554,3738.805694103241,3554.072303533554,183.0156137943268,1.1840167045593262,0.0 -15200,0.18466876,0.26232225,,,,,,,,,,,,,, -15300,0.23668218,0.17696589,,,,,,,,,,,,,, -15400,0.21675184,0.35744244,,,,,,,,,,,,,, -15500,0.13533016,0.27972388,,,,,,,,,,,,,, -15538,,,0.7494399888174874,0.2667510339191982,0.7244877855497327,0.2883546364461698,3554.0,0.7416667666591036,0.289791001716961,3581.0,3634.212679386139,3823.0554959774017,3634.212679386139,187.08305025100708,1.214017391204834,0.0 -15600,0.083669424,0.2983039,,,,,,,,,,,,,, -15700,0.09657456,0.21255264,,,,,,,,,,,,,, -15800,0.12157934,0.3888849,,,,,,,,,,,,,, -15884,,,0.7447804042271206,0.2677041462489536,0.7197174948121835,0.2888085016992297,3554.0,0.7372491235208392,0.2901403730190589,3581.0,3714.300877094269,3907.250316858292,3714.300877094269,191.1493947505951,1.2421448230743408,0.0 -15900,0.18119493,0.21069866,,,,,,,,,,,,,, -16000,0.10394316,0.23434234,,,,,,,,,,,,,, -16100,0.09378338,0.30522323,,,,,,,,,,,,,, -16200,0.24954523,0.2691356,,,,,,,,,,,,,, -16229,,,0.7450590133666992,0.2670665298189436,0.720541830068233,0.2879636782773371,3554.0,0.7380651981595574,0.2892953232991831,3581.0,3794.307913780213,3991.366268634796,3794.307913780213,195.2165243625641,1.2717201709747314,0.0 -16300,0.19377248,0.24725412,,,,,,,,,,,,,, -16400,0.1445664,0.24312383,,,,,,,,,,,,,, -16500,0.14325784,0.32319295,,,,,,,,,,,,,, -16575,,,0.7477623394557408,0.2658184255872454,0.7226664168014912,0.2874061356921778,3554.0,0.7400445029975915,0.2887590797677848,3581.0,3874.412040710449,4075.5846552848816,3874.412040710449,199.2868800163269,1.3034842014312744,0.0 -16600,0.07766837,0.27944398,,,,,,,,,,,,,, -16700,0.12956758,0.28411475,,,,,,,,,,,,,, -16800,0.08060531,0.17840403,,,,,,,,,,,,,, -16900,0.10811156,0.32020965,,,,,,,,,,,,,, -16923,,,0.7463718141828265,0.265954852104187,0.7211565093908272,0.2872358245935477,3554.0,0.7386974003420832,0.2885726506867844,3581.0,3954.473661661148,4159.762113332748,3954.473661661148,203.36089968681333,1.3331928253173828,0.0 -17000,0.1591061,0.21041238,,,,,,,,,,,,,, -17100,0.14530434,0.2580457,,,,,,,,,,,,,, -17200,0.14372922,0.24489936,,,,,,,,,,,,,, -17265,,,0.7463821683611188,0.2664083412715367,0.7211328784468205,0.2878108499555958,3554.0,0.7388150732599135,0.289015083131894,3581.0,4034.495836257935,4243.891760110855,4034.495836257935,207.42804789543152,1.3615479469299316,0.0 -17300,0.08049678,0.34179324,,,,,,,,,,,,,, -17400,0.1135486,0.2879674,,,,,,,,,,,,,, -17500,0.16438152,0.22783136,,,,,,,,,,,,,, -17600,0.1410425,0.29546365,,,,,,,,,,,,,, -17608,,,0.7503792217799595,0.2660165173666818,0.7255287148943093,0.2874407921202342,3554.0,0.7425831291669576,0.2888032923327981,3581.0,4114.508437395096,4328.01832151413,4114.508437395096,211.501356124878,1.390101432800293,0.0 -17700,0.17946175,0.21231812,,,,,,,,,,,,,, -17800,0.1800996,0.27678868,,,,,,,,,,,,,, -17900,0.17668942,0.29109812,,,,,,,,,,,,,, -17953,,,0.748267650604248,0.2660280976976667,0.7235039414216375,0.2873225343582934,3554.0,0.7408580550998325,0.2886968344757923,3581.0,4194.599543809891,4412.221243619919,4194.599543809891,215.56890416145325,1.4220576286315918,0.0 -18000,0.14373384,0.32504624,,,,,,,,,,,,,, -18100,0.14114447,0.29428872,,,,,,,,,,,,,, -18200,0.11251797,0.36985272,,,,,,,,,,,,,, -18299,,,0.7484581129891532,0.2655689546040126,0.7236135093160875,0.2867548936662475,3554.0,0.7409523434227521,0.2881057769085102,3581.0,4274.669712781906,4496.401560783386,4274.669712781906,219.6377148628235,1.4514093399047852,0.0 -18300,0.09872772,0.26828212,,,,,,,,,,,,,, -18400,0.091700725,0.27274168,,,,,,,,,,,,,, -18500,0.116095856,0.28235456,,,,,,,,,,,,,, -18600,0.1725265,0.39576313,,,,,,,,,,,,,, -18646,,,0.7454803330557687,0.2659528255462646,0.7206421241910523,0.2874833656014789,3554.0,0.737948002478358,0.2888511523492041,3581.0,4354.866664648056,4580.708430051804,4354.866664648056,223.7048351764679,1.4820961952209473,0.0 -18700,0.18970154,0.26478332,,,,,,,,,,,,,, -18800,0.09430636,0.27323547,,,,,,,,,,,,,, -18900,0.26181465,0.24596882,,,,,,,,,,,,,, -18989,,,0.7470806666782924,0.2660172326224191,0.7226263678469682,0.2872723529495815,3554.0,0.7397250953382435,0.2887143899661407,3581.0,4434.899230241776,4664.850015163422,4434.899230241776,227.77167344093323,1.5122063159942627,0.0 -19000,0.12598369,0.2823386,,,,,,,,,,,,,, -19100,0.09204788,0.20070851,,,,,,,,,,,,,, -19200,0.16851753,0.21974213,,,,,,,,,,,,,, -19300,0.101771384,0.2820208,,,,,,,,,,,,,, -19332,,,0.7494047028677804,0.2656295129231044,0.7244741840180079,0.2869826334543824,3554.0,0.7417560099090686,0.2883070344133971,3581.0,4514.962484121323,4749.026458740234,4514.962484121323,231.8424940109253,1.5426294803619385,0.0 -19400,0.130966,0.23358873,,,,,,,,,,,,,, -19500,0.16786335,0.31516945,,,,,,,,,,,,,, -19600,0.14028196,0.20687884,,,,,,,,,,,,,, -19676,,,0.7484260967799595,0.2648649385997227,0.7227509111652364,0.2867769446343469,3554.0,0.740140154853742,0.2881644088374406,3581.0,4595.0616002082825,4833.237877607346,4595.0616002082825,235.91216826438904,1.573349952697754,0.0 -19700,0.28137773,0.20251222,,,,,,,,,,,,,, -19800,0.1395356,0.23418401,,,,,,,,,,,,,, -19900,0.17924799,0.25666976,,,,,,,,,,,,,, -20000,0.09573561,0.28114682,,,,,,,,,,,,,, -20023,,,0.7465226990836007,0.2652723789215088,0.7210318973779544,0.2867201341962841,3554.0,0.738479030495148,0.2881558185780857,3581.0,4675.134254932404,4917.429103136063,4675.134254932404,239.9841315746308,1.6079411506652832,0.0 -20100,0.10348089,0.30183083,,,,,,,,,,,,,, -20200,0.12968573,0.315084,,,,,,,,,,,,,, -20300,0.16334492,0.28317696,,,,,,,,,,,,,, -20368,,,0.7510388919285366,0.2655698912484305,0.7263813523274831,0.2868956832585203,3554.0,0.7434962191950573,0.2882353125654496,3581.0,4755.205282211304,5001.61284160614,4755.205282211304,244.053169965744,1.639510154724121,0.0 -20400,0.07029086,0.2953376,,,,,,,,,,,,,, -20500,0.25603995,0.2575715,,,,,,,,,,,,,, -20600,0.07707786,0.3769452,,,,,,,,,,,,,, -20700,0.1391161,0.2548352,,,,,,,,,,,,,, -20714,,,0.7495647839137486,0.2648200818470546,0.7239776594154473,0.2867387676078011,3554.0,0.7412412761143884,0.2880398500767941,3581.0,4835.244169473648,5085.760347127914,4835.244169473648,248.11840963363647,1.670926809310913,0.0 -20800,0.15937905,0.3173945,,,,,,,,,,,,,, -20900,0.074839085,0.2894102,,,,,,,,,,,,,, -21000,0.07877754,0.331962,,,,,,,,,,,,,, -21059,,,0.7508315358843122,0.2651166234697614,0.7259722759566686,0.286702204904465,3554.0,0.7430783644364354,0.2880960958225705,3581.0,4915.309281587601,5169.936907529831,4915.309281587601,252.1876802444458,1.7009556293487549,0.0 -21100,0.17440212,0.29684812,,,,,,,,,,,,,, -21200,0.16533373,0.19548686,,,,,,,,,,,,,, -21300,0.08003622,0.26465967,,,,,,,,,,,,,, -21400,0.16963138,0.21184325,,,,,,,,,,,,,, -21405,,,0.7501467296055385,0.2652876717703683,0.7254517082424733,0.2866950778392304,3554.0,0.7425363599771363,0.2880237944729998,3581.0,4995.265493392944,5254.006997823715,4995.265493392944,256.25526690483093,1.7352640628814695,0.0 -21500,0.21898022,0.19520015,,,,,,,,,,,,,, -21600,0.09618106,0.3044303,,,,,,,,,,,,,, -21700,0.1565569,0.29761907,,,,,,,,,,,,,, -21750,,,0.7506645747593471,0.2651770285197666,0.7251320035523354,0.2870192476786719,3554.0,0.7422327693032672,0.2883760632832134,3581.0,5075.27999830246,5338.135623931885,5075.27999830246,260.32216477394104,1.7701237201690674,0.0 -21800,0.10387325,0.2269901,,,,,,,,,,,,,, -21900,0.068642184,0.31354338,,,,,,,,,,,,,, -22000,0.07210358,0.38811162,,,,,,,,,,,,,, -22097,,,0.7473197664533343,0.2646193844931466,0.7216715815366489,0.2862578366805008,3554.0,0.7391977488611771,0.2875862707475216,3581.0,5155.243874788284,5422.208884477615,5155.243874788284,264.38602781295776,1.8033950328826904,0.0 -22100,0.15780479,0.3499411,,,,,,,,,,,,,, -22200,0.1417595,0.30220792,,,,,,,,,,,,,, -22300,0.12069482,0.25527883,,,,,,,,,,,,,, -22400,0.04543074,0.31146756,,,,,,,,,,,,,, -22440,,,0.7482331139700753,0.2648459843226841,0.7230001351909819,0.2863896444532129,3554.0,0.7404796746282463,0.287656799503892,3581.0,5235.346307039261,5506.419248342514,5235.346307039261,268.45214796066284,1.8332464694976809,0.0 -22500,0.09063578,0.30669752,,,,,,,,,,,,,, -22600,0.055125907,0.27587608,,,,,,,,,,,,,, -22700,0.10934849,0.22727413,,,,,,,,,,,,,, -22783,,,0.7507668903895787,0.2646455764770508,0.7259625213228053,0.2861737373107326,3554.0,0.7431066577509774,0.2875015953338802,3581.0,5315.468108892441,5590.655894994736,5315.468108892441,272.5250914096832,1.8631410598754885,0.0 -22800,0.14941373,0.30993176,,,,,,,,,,,,,, -22900,0.16300373,0.19717959,,,,,,,,,,,,,, -23000,0.08154357,0.26026735,,,,,,,,,,,,,, -23100,0.15831712,0.35757133,,,,,,,,,,,,,, -23131,,,0.7489895820617676,0.2643979617527553,0.7240669624015195,0.2860227637433173,3554.0,0.7412096421434307,0.2874370320353951,3581.0,5395.63080406189,5674.928566217423,5395.63080406189,276.5922944545746,1.8937137126922607,0.0 -23200,0.12847939,0.26355073,,,,,,,,,,,,,, -23300,0.110150516,0.1786092,,,,,,,,,,,,,, -23400,0.1337289,0.27671528,,,,,,,,,,,,,, -23478,,,0.7495903968811035,0.2652779647282192,0.7237818110975309,0.2874707944888242,3554.0,0.7412838183511938,0.2886014894146188,3581.0,5475.673554420471,5759.084021091461,5475.673554420471,280.66194915771484,1.9245622158050537,0.0 -23500,0.15511604,0.34158522,,,,,,,,,,,,,, -23600,0.087337546,0.19974396,,,,,,,,,,,,,, -23700,0.110725276,0.2571088,,,,,,,,,,,,,, -23800,0.094143845,0.19326992,,,,,,,,,,,,,, -23824,,,0.7507916178022113,0.2644648041043962,0.7256301081308033,0.2861515661270751,3554.0,0.7429002869964745,0.287390842347197,3581.0,5555.82008767128,5843.342591047287,5555.82008767128,284.73019194602966,1.9562304019927976,0.0 -23900,0.09769069,0.23388563,,,,,,,,,,,,,, -24000,0.10324745,0.2438751,,,,,,,,,,,,,, -24100,0.10057023,0.30508566,,,,,,,,,,,,,, -24168,,,0.7475453104291644,0.2642674275806972,0.7220405402583356,0.2861658546048466,3554.0,0.7394907039758447,0.2874138178821384,3581.0,5635.91052532196,5927.541179180145,5635.91052532196,288.7960410118103,1.9867405891418457,0.0 -24200,0.14262116,0.22422017,,,,,,,,,,,,,, -24300,0.048035365,0.26421344,,,,,,,,,,,,,, -24400,0.082240134,0.20101744,,,,,,,,,,,,,, -24500,0.10730647,0.3808023,,,,,,,,,,,,,, -24514,,,0.7507332393101284,0.2639252458299909,0.7251709533931837,0.2857912629255768,3554.0,0.7424466394905403,0.2871191242626012,3581.0,5715.999706029892,6011.738223552704,5715.999706029892,292.8611936569214,2.0173099040985107,0.0 -24600,0.04926697,0.25173256,,,,,,,,,,,,,, -24700,0.03226614,0.30151027,,,,,,,,,,,,,, -24800,0.1396088,0.2471903,,,,,,,,,,,,,, -24858,,,0.7502329690115792,0.2644594056265695,0.7251993242649127,0.2860405212986248,3554.0,0.7425150888587336,0.2872798166538676,3581.0,5796.07946395874,6095.925489425659,5796.07946395874,296.9268276691437,2.047306060791016,0.0 -24900,0.041994244,0.29151332,,,,,,,,,,,,,, -25000,0.09495415,0.19737166,,,,,,,,,,,,,, -25100,0.156454,0.32240725,,,,,,,,,,,,,, -25200,0.0863521,0.22842225,,,,,,,,,,,,,, -25203,,,0.748668943132673,0.2636398417609079,0.7225095183244232,0.2858235837370744,3554.0,0.7400631834028554,0.2870587538288013,3581.0,5876.218941450119,6180.180327177048,5876.218941450119,300.9990952014923,2.07837176322937,0.0 -25300,0.07589375,0.24011767,,,,,,,,,,,,,, -25400,0.10402243,0.2636067,,,,,,,,,,,,,, -25500,0.081156746,0.24448425,,,,,,,,,,,,,, -25549,,,0.7521319389343262,0.2639256715774536,0.7268184560970034,0.2857768370585959,3554.0,0.7440258155019548,0.2870998984440449,3581.0,5956.200823068619,6264.271352767944,5956.200823068619,305.06549167633057,2.109112024307251,0.0 -25600,0.10576982,0.26660457,,,,,,,,,,,,,, -25700,0.09760954,0.28958493,,,,,,,,,,,,,, -25800,0.042275377,0.26727307,,,,,,,,,,,,,, -25895,,,0.7507930483136859,0.2637628316879272,0.7252840934070766,0.2856170534081317,3554.0,0.7425444729998604,0.2869088674383901,3581.0,6036.274323225021,6348.457298994064,6036.274323225021,309.1350910663605,2.1402125358581543,0.0 -25900,0.061527945,0.21161602,,,,,,,,,,,,,, -26000,0.057533413,0.23506977,,,,,,,,,,,,,, -26100,0.068694696,0.32505608,,,,,,,,,,,,,, -26200,0.13615762,0.23697832,,,,,,,,,,,,,, -26241,,,0.7523622512817383,0.2629812274660383,0.7263478980550084,0.2854140780249806,3554.0,0.7435862805649609,0.2866761464041992,3581.0,6116.422718524933,6432.716036319733,6116.422718524933,313.1980721950531,2.1756951808929443,0.0 -26300,0.09361962,0.21987632,,,,,,,,,,,,,, -26400,0.07977333,0.33335453,,,,,,,,,,,,,, -26500,0.07673672,0.22295748,,,,,,,,,,,,,, -26588,,,0.7505794933864048,0.263419577053615,0.7252331220104108,0.2853495050965901,3554.0,0.7424559796931723,0.2865871076842188,3581.0,6196.501030445099,6516.901560783386,6196.501030445099,317.26215529441833,2.207003831863404,0.0 -26600,0.053633645,0.26200122,,,,,,,,,,,,,, -26700,0.095762715,0.24655677,,,,,,,,,,,,,, -26800,0.07874628,0.22131374,,,,,,,,,,,,,, -26900,0.15843904,0.27424628,,,,,,,,,,,,,, -26935,,,0.7514126641409737,0.2635314805167062,0.7257549262274902,0.2855885279735421,3554.0,0.7430407990959229,0.2867723436736421,3581.0,6276.576962709427,6601.088958978653,6276.576962709427,321.3304181098938,2.2384111881256104,0.0 -27000,0.05299329,0.27408952,,,,,,,,,,,,,, -27100,0.07660125,0.33750257,,,,,,,,,,,,,, -27200,0.10826302,0.3300854,,,,,,,,,,,,,, -27278,,,0.7512620517185756,0.2628715208598545,0.7248926715496623,0.2853974882779526,3554.0,0.7421731147244136,0.2867761615666888,3581.0,6356.704101800919,6685.336197853088,6356.704101800919,325.4008071422577,2.2764458656311035,0.0 -27300,0.068480834,0.22032805,,,,,,,,,,,,,, -27400,0.05401634,0.26861516,,,,,,,,,,,,,, -27500,0.074348845,0.2308443,,,,,,,,,,,,,, -27600,0.097014755,0.2511072,,,,,,,,,,,,,, -27624,,,0.7526355470929827,0.2629967076437814,0.7269753545740715,0.2851656955081422,3554.0,0.7441440338330774,0.286496466812692,3581.0,6436.780147314072,6769.524851560593,6436.780147314072,329.46841621398926,2.3095593452453613,0.0 -27700,0.06446064,0.33383322,,,,,,,,,,,,,, -27800,0.08748588,0.27139345,,,,,,,,,,,,,, -27900,0.06972781,0.241105,,,,,,,,,,,,,, -27970,,,0.7511420931134906,0.2631453445979527,0.7253446133537915,0.2853274197811885,3554.0,0.7426191264442544,0.2865856759743263,3581.0,6516.768382072449,6853.62221121788,6516.768382072449,333.53331780433655,2.341780662536621,0.0 -28000,0.09642289,0.23805453,,,,,,,,,,,,,, -28100,0.079400346,0.2513668,,,,,,,,,,,,,, -28200,0.03976493,0.29136047,,,,,,,,,,,,,, -28300,0.06754216,0.19338226,,,,,,,,,,,,,, -28312,,,0.7525931085859027,0.2629969120025635,0.7264981331554234,0.285445660369478,3554.0,0.7436688425020944,0.2868086818342467,3581.0,6596.887026309967,6937.854245662689,6596.887026309967,337.60342359542847,2.3730151653289795,0.0 -28400,0.0828505,0.22604465,,,,,,,,,,,,,, -28500,0.107413,0.25236273,,,,,,,,,,,,,, -28600,0.05106005,0.33831173,,,,,,,,,,,,,, -28659,,,0.7515886851719448,0.2628911563328334,0.7257243571284117,0.2851351092354126,3554.0,0.742927898544401,0.2864548449608175,3581.0,6676.900361776352,7021.976664066315,6676.900361776352,341.6689395904541,2.404748678207397,0.0 -28700,0.0941246,0.30105543,,,,,,,,,,,,,, -28800,0.060894027,0.25873917,,,,,,,,,,,,,, -28900,0.088938355,0.27745265,,,,,,,,,,,,,, -29000,0.07780213,0.2958951,,,,,,,,,,,,,, -29005,,,0.7523136820111956,0.2629059553146362,0.7265699877119092,0.2851231220268975,3554.0,0.7437804476970469,0.2864383121203923,3581.0,6757.101865053177,7106.2891047000885,6757.101865053177,345.7351453304291,2.437550067901612,0.0 -29100,0.05196794,0.20030817,,,,,,,,,,,,,, -29200,0.09797859,0.24594314,,,,,,,,,,,,,, -29300,0.0826793,0.22559245,,,,,,,,,,,,,, -29349,,,0.7514803750174386,0.2624988215310233,0.7252164979160804,0.285028031520382,3554.0,0.7425458365330914,0.2862742790727101,3581.0,6837.253606081009,7190.557967662811,6837.253606081009,349.80698704719543,2.470907688140869,0.0 -29400,0.09825527,0.22784291,,,,,,,,,,,,,, -29500,0.05335675,0.2936643,,,,,,,,,,,,,, -29600,0.07515894,0.27107283,,,,,,,,,,,,,, -29696,,,0.7522429057529995,0.2623601811272757,0.7262109897078995,0.2848330075377215,3554.0,0.7433965449158755,0.2861342101215617,3581.0,6917.243940591812,7274.665694713592,6917.243940591812,353.8756756782532,2.50755262374878,0.0 -29700,0.044900227,0.34222937,,,,,,,,,,,,,, -29800,0.05391691,0.3010139,,,,,,,,,,,,,, -29900,0.029197788,0.20730346,,,,,,,,,,,,,, -30000,0.044184465,0.2766902,,,,,,,,,,,,,, -30042,,,0.7523172923496791,0.2624179295131138,0.7264418722741981,0.2848615329723111,3554.0,0.7436579342362468,0.2860980764909417,3581.0,6997.3993401527405,7358.933331489563,6997.3993401527405,357.9442069530487,2.5392658710479736,0.0 -30100,0.049672693,0.27242324,,,,,,,,,,,,,, -30200,0.074443154,0.24551252,,,,,,,,,,,,,, -30300,0.07984388,0.1656807,,,,,,,,,,,,,, -30384,,,0.752615247453962,0.2624817064830235,0.7267547761984735,0.2849113709079998,3554.0,0.7439412764416364,0.2862335094291049,3581.0,7077.524636507034,7443.16653752327,7077.524636507034,362.0090301036835,2.570589065551758,0.0 -30400,0.056577887,0.23950626,,,,,,,,,,,,,, -30500,0.059596915,0.28350237,,,,,,,,,,,,,, -30600,0.04336753,0.26515734,,,,,,,,,,,,,, -30700,0.03378631,0.28909582,,,,,,,,,,,,,, -30730,,,0.7524106161934989,0.2620935099465506,0.7261035513461944,0.2847992269658747,3554.0,0.7433814778736736,0.2860827708304244,3581.0,7157.639096975327,7527.392614603043,7157.639096975327,366.0766587257385,2.6026461124420166,0.0 -30800,0.051765986,0.2843476,,,,,,,,,,,,,, -30900,0.04499415,0.22206777,,,,,,,,,,,,,, -31000,0.04074386,0.26409248,,,,,,,,,,,,,, -31078,,,0.7527110236031669,0.2622142178671701,0.7266209591085748,0.2848172077786473,3554.0,0.7438499197151633,0.2860759531642697,3581.0,7237.776307582855,7611.640864372253,7237.776307582855,370.1431245803833,2.635183095932007,0.0 -31100,0.045185063,0.28216505,,,,,,,,,,,,,, -31200,0.055310544,0.18971659,,,,,,,,,,,,,, -31300,0.033985253,0.29469928,,,,,,,,,,,,,, -31400,0.028239772,0.20976363,,,,,,,,,,,,,, -31422,,,0.7525147029331752,0.2622088704790388,0.7264274464072172,0.2847631794720737,3554.0,0.7436331861081052,0.2860836571270245,3581.0,7317.927659273148,7695.90665435791,7317.927659273148,374.21287846565247,2.667577743530273,0.0 -31500,0.029765341,0.3047004,,,,,,,,,,,,,, -31600,0.021776516,0.20388362,,,,,,,,,,,,,, -31700,0.05679083,0.2553463,,,,,,,,,,,,,, -31769,,,0.752835886819022,0.2618723937443324,0.7263305870146314,0.2847628359990504,3554.0,0.7435474198678791,0.2860863841934865,3581.0,7398.061017036438,7780.152475118637,7398.061017036438,378.281388759613,2.6994807720184326,0.0 -31800,0.052418504,0.2720397,,,,,,,,,,,,,, -31900,0.04094827,0.25602883,,,,,,,,,,,,,, -32000,0.025017478,0.21961169,,,,,,,,,,,,,, -32100,0.031862155,0.4255159,,,,,,,,,,,,,, -32114,,,0.752873148236956,0.2619447537830898,0.7266755026246835,0.2846750614679322,3554.0,0.7438437838156241,0.2859717451370951,3581.0,7478.068758964539,7864.272994995117,7478.068758964539,382.349130153656,2.732588291168213,0.0 -32200,0.02456492,0.27777004,,,,,,,,,,,,,, -32300,0.023224037,0.20142752,,,,,,,,,,,,,, -32400,0.028813707,0.33556995,,,,,,,,,,,,,, -32457,,,0.7523930413382394,0.2619777406964983,0.7261111764473129,0.2846769333959095,3554.0,0.7433731603209648,0.2859257088463854,3581.0,7558.151543617248,7948.469160079956,7558.151543617248,386.4174609184265,2.765906810760498,0.0 -32500,0.03991317,0.27629519,,,,,,,,,,,,,, -32600,0.030260513,0.26194093,,,,,,,,,,,,,, -32700,0.033484854,0.19344808,,,,,,,,,,,,,, -32800,0.019302538,0.3202116,,,,,,,,,,,,,, -32802,,,0.7533761433192662,0.2616222926548549,0.7267281226918613,0.2846399413512943,3554.0,0.7439243686295728,0.2859749835285186,3581.0,7638.266149044037,8032.69620680809,7638.266149044037,390.4858248233795,2.7979736328125,0.0 -32900,0.043314017,0.23281664,,,,,,,,,,,,,, -33000,0.032550078,0.3440344,,,,,,,,,,,,,, -33100,0.022417087,0.28649873,,,,,,,,,,,,,, -33147,,,0.7531262125287738,0.2618422337940761,0.726640124903278,0.2847223920505504,3554.0,0.743834920849623,0.2860920769447256,3581.0,7718.342501401901,8116.885622739792,7718.342501401901,394.5550725460053,2.830050706863404,0.0 -33200,0.024450896,0.34372643,,,,,,,,,,,,,, -33300,0.029669387,0.30561775,,,,,,,,,,,,,, -33400,0.020332199,0.2653442,,,,,,,,,,,,,, -33490,,,0.7528594561985561,0.2618260894502912,0.7264542373030388,0.2846458319136448,3554.0,0.7436743648116797,0.2859360546547752,3581.0,7798.298315048218,8200.950634002686,7798.298315048218,398.6193375587464,2.863105297088623,0.0 -33500,0.022377236,0.19564512,,,,,,,,,,,,,, -33600,0.026726378,0.23645435,,,,,,,,,,,,,, -33700,0.015722971,0.3429544,,,,,,,,,,,,,, -33800,0.022041582,0.20696628,,,,,,,,,,,,,, -33835,,,0.7536260059901646,0.2615469523838588,0.7269739119873734,0.2846270611129185,3554.0,0.7441031960128107,0.2859982658584369,3581.0,7878.393220424652,8285.157857656479,7878.393220424652,402.6862740516663,2.8966965675354004,0.0 -33900,0.038510703,0.31855434,,,,,,,,,,,,,, -34000,0.02093356,0.26778165,,,,,,,,,,,,,, -34100,0.020963943,0.21089333,,,,,,,,,,,,,, -34182,,,0.7533986227852958,0.2616581235613142,0.7268093197145822,0.2846151597726593,3554.0,0.7439697061095015,0.2859778128599727,3581.0,7958.514358758926,8369.389568090439,7958.514358758926,406.7518992424011,2.929734468460083,0.0 -34200,0.02465666,0.27913538,,,,,,,,,,,,,, -34300,0.01825037,0.25145456,,,,,,,,,,,,,, -34400,0.029849207,0.19816649,,,,,,,,,,,,,, -34500,0.029571058,0.2284358,,,,,,,,,,,,,, -34529,,,0.7533296176365444,0.2616626535143171,0.7267763463043402,0.2845997206602595,3554.0,0.7439403219683748,0.2859718814904182,3581.0,8038.634386777878,8453.621324062347,8038.634386777878,410.8144009113312,2.966977119445801,0.0 -34600,0.03628392,0.21537614,,,,,,,,,,,,,, -34700,0.03648051,0.30774903,,,,,,,,,,,,,, -34800,0.017352862,0.25720024,,,,,,,,,,,,,, -34873,,,0.7535209655761719,0.2615305015019008,0.7269215666986142,0.2845629862204119,3554.0,0.7440671987355139,0.2859627117294401,3581.0,8118.746688127518,8537.84702205658,8118.746688127518,414.8794231414795,3.003573179244995,0.0 -34900,0.023632696,0.17538425,,,,,,,,,,,,,, -35000,0.03418291,0.30855903,,,,,,,,,,,,,, -35100,0.013498094,0.251723,,,,,,,,,,,,,, -35200,0.014568154,0.2491421,,,,,,,,,,,,,, -35217,,,0.753373282296317,0.2615130628858294,0.7267270235781865,0.2845305623670072,3554.0,0.7439063699909243,0.2859043525071558,3581.0,8198.770928144455,8621.990622997284,8198.770928144455,418.9490053653717,3.041345357894897,0.0 -35300,0.015900232,0.3337732,,,,,,,,,,,,,, -35400,0.01956263,0.31727254,,,,,,,,,,,,,, -35500,0.025686992,0.2756055,,,,,,,,,,,,,, -35562,,,0.7534516879490444,0.2615209477288382,0.7268488191122678,0.2845391148452887,3554.0,0.7439984084840129,0.2859195729468462,3581.0,8278.949311256409,8706.287668466568,8278.949311256409,423.0212299823761,3.075789213180542,0.0 -35600,0.015165707,0.24980916,,,,,,,,,,,,,, -35700,0.020791436,0.1660886,,,,,,,,,,,,,, -35800,0.029142193,0.25818548,,,,,,,,,,,,,, -35900,0.0201883,0.28001735,,,,,,,,,,,,,, -35906,,,0.7533174923488072,0.2614938191005161,0.7266854633423607,0.2845197601404227,3554.0,0.7438668957038885,0.285884973291111,3581.0,8359.058719158173,8790.514942407608,8359.058719158173,427.092668056488,3.1104507446289062,0.0 -36000,0.020417223,0.21602494,,,,,,,,,,,,,, -36100,0.014220389,0.263988,,,,,,,,,,,,,, -36189,,,0.7535576820373535,0.2615004437310355,0.7269425185530388,0.284524293984331,3554.0,0.7441039459560876,0.2858965803677394,3581.0,8423.860284805298,8859.426425457,8423.860284805298,431.1592288017273,3.144142866134644,0.0 -36189,,,,,,,,,,,8423.860284805298,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/eval_measurements.csv deleted file mode 100644 index e92c57aee..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,79 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -199.95662784576416,0.0,55.89515209197998,1,0,55.89515209197998,1.150831046233594,3581,0.233558584477974,255.85205793380737,1.1540372031075614,0.210459760257176,1.1542266691140264,3554,0.2105898881267146 -204.3298182487488,0.026456594467163,136.12256789207458,337,0,136.12256789207458,0.3304063274486875,3581,0.6975823737695476,340.49064683914185,0.2987578936985561,0.701190676007952,0.3286026679335783,3554,0.679899079382386 -208.3515384197235,0.0636670589447021,216.5056059360504,576,0,216.5056059360504,0.3179138746051207,3581,0.7110114716777786,424.94063425064087,0.287034205027989,0.7153627531869071,0.315769691445818,3554,0.6935581084913126 -212.3737070560456,0.0957064628601074,297.05216240882874,818,0,297.05216240882874,0.3113394627788153,3581,0.7185868532489179,509.5492379665375,0.2811103548322405,0.7224571364266532,0.3088610750540764,3554,0.7014776346853897 -216.3376326560974,0.12656831741333,377.3095192909241,1101,0,377.3095192909241,0.3093823835324455,3581,0.7187400462074142,593.8110988140106,0.279527153287615,0.722386087690081,0.3073478359550858,3554,0.7012314332222496 -220.3560636043549,0.1503820419311523,457.4085175991058,1450,0,457.4085175991058,0.3038241107581681,3581,0.7228750289069045,677.9646747112274,0.273552622113909,0.7277840886797223,0.3014803206444323,3554,0.7057781917162 -224.37543368339536,0.1734664440155029,537.4023633003235,1798,0,537.4023633003235,0.3024637818303197,3581,0.7263438574464186,762.0131983757019,0.272597210747855,0.7309179987226214,0.3001107219638084,3554,0.7098639407182048 -228.39755821228027,0.1965482234954834,617.4026784896851,2144,0,617.4026784896851,0.2993826398766929,3581,0.7295739994938565,846.0708227157593,0.2699415854045323,0.7346102169581822,0.2972823933530001,3554,0.7122909211012239 -232.41604614257807,0.223621129989624,697.4331579208374,2491,0,697.4331579208374,0.2964027743537943,3581,0.7318134664723541,930.158771276474,0.2671689135687692,0.7364674976893834,0.2947580727151624,3554,0.7145024065093908 -236.43962216377253,0.2478590011596679,777.4768579006195,2840,0,777.4768579006195,0.2962220721123638,3581,0.7317461079307456,1014.2621154785156,0.2669699362346104,0.7368014880589077,0.2946955262776097,3554,0.7142880793428179 -240.45988512039185,0.2709853649139404,857.5939862728119,3185,0,857.5939862728119,0.2960640386108978,3581,0.7315329195100879,1098.4345676898956,0.2667936256953648,0.736929076058524,0.2942849386254924,3554,0.7142186290974958 -244.4809172153473,0.2939739227294922,937.7757697105408,3535,0,937.7757697105408,0.2933858889036757,3581,0.7348998921172508,1182.672307014465,0.2642932108470371,0.740006855555943,0.291770681747239,3554,0.7175935263303672 -248.5042879581452,0.3168442249298095,1017.7614195346832,3881,0,1017.7614195346832,0.2932634777078679,3581,0.7359896960520804,1266.7160177230835,0.264447672026498,0.7408522878374372,0.2916430128244583,3554,0.718755220789955 -252.5282917022705,0.3407328128814697,1097.7921483516693,4225,0,1097.7921483516693,0.2928563948617704,3581,0.7337723864885856,1350.8066008090973,0.2640284981046404,0.738696779523577,0.2914711732708743,3554,0.7159462297103616 -256.5510618686676,0.3653392791748047,1177.9653112888336,4573,0,1177.9653112888336,0.2937777342659173,3581,0.7327200115409452,1435.0387947559357,0.264590927532741,0.7383772305079869,0.2919031249450443,3554,0.7156449351742754 -260.5774459838867,0.3891665935516357,1258.125807762146,4921,0,1258.125807762146,0.295212375754852,3581,0.7332136105705459,1519.2613484859469,0.2662299530846732,0.7383550235203334,0.2935848719092923,3554,0.7162272593380697 -264.5962345600128,0.4131014347076416,1338.1535222530365,5267,0,1338.1535222530365,0.2921203937120392,3581,0.7363770076663292,1603.3435831069946,0.2632380894252232,0.7415508542742048,0.2906155476223973,3554,0.7190130316412845 -268.6173930168152,0.43715500831604,1418.1331946849823,5615,0,1418.1331946849823,0.2922144434166434,3581,0.7338091337091595,1687.3803277015686,0.2629983765738351,0.7396905762808663,0.2906085407727208,3554,0.716353039159222 -272.638201713562,0.4610710144042969,1498.1605167388916,5962,0,1498.1605167388916,0.292230908080407,3581,0.7355647509206577,1771.4643297195437,0.2633684022086007,0.741100719996861,0.2906225888193761,3554,0.7180967830041854 -276.6635181903839,0.4882900714874267,1578.220741033554,6307,0,1578.220741033554,0.2916041600308049,3581,0.7355887491055222,1855.5889139175413,0.2626251493181501,0.7410097122192383,0.2901365401440278,3554,0.7181193835291221 -280.685994386673,0.5136318206787109,1658.302259206772,6652,0,1658.302259206772,0.2926179128996788,3581,0.735859410451864,1939.7298274040224,0.2632386854716709,0.7417586190359933,0.2910806444433209,3554,0.7186538962480655 -284.7060179710388,0.539787769317627,1738.3994042873385,7000,0,1738.3994042873385,0.2913242607468235,3581,0.7368869009180397,2023.885073661804,0.262102552822658,0.7424896103995187,0.2897551820461979,3554,0.7194457389561058 -288.7262351512909,0.566457986831665,1818.5374870300293,7348,0,1818.5374870300293,0.2914589096533789,3581,0.7334932030595505,2108.081840991974,0.26255578654153,0.7386776379176548,0.29005053449898,3554,0.7159020590795583 -292.74238300323486,0.5925304889678955,1898.6189014911647,7694,0,1898.6189014911647,0.2932632050012217,3581,0.7344401087117775,2192.217269182205,0.2637131043842861,0.7403997693743024,0.2915032536512556,3554,0.717441024307998 -296.76133704185486,0.6188676357269287,1978.6849682331083,8041,0,1978.6849682331083,0.291029771657271,3581,0.7373533656363446,2276.340225696564,0.2620921645845686,0.742549010685512,0.2896134650767621,3554,0.7199878767761677 -300.7832546234131,0.644014835357666,2058.670554161072,8388,0,2058.670554161072,0.2922346237084613,3581,0.7366206710546984,2360.384894132614,0.2629777533667428,0.7422918592180524,0.2908316608486916,3554,0.7192087425699916 -304.80595803260803,0.6700015068054199,2138.7477850914,8731,0,2138.7477850914,0.2911163219291049,3581,0.7364160047167342,2444.522290945053,0.2619698728833879,0.7420447894505092,0.2895826898938696,3554,0.7189635715259215 -308.8210089206696,0.6955749988555908,2218.78799200058,9080,0,2218.78799200058,0.29208429416975,3581,0.7354463280595505,2528.614986181259,0.2628451585769653,0.7410753113882882,0.2903377122938063,3554,0.7184303640044668 -312.8451218605041,0.7218742370605469,2298.791113376617,9425,0,2298.791113376617,0.2904704844142697,3581,0.7384188986796635,2612.6801404953003,0.2612859351294381,0.7440524101257324,0.2889386092804762,3554,0.721223967492614 -316.8660409450531,0.7474322319030762,2378.839912891388,9769,0,2378.839912891388,0.2925327943377374,3581,0.7389364958941287,2696.7871565818787,0.2629133633204868,0.7446088109697614,0.2909829950627813,3554,0.7217951631304516 -320.8879749774933,0.776827335357666,2458.948125362396,10117,0,2458.948125362396,0.2907282262832484,3581,0.7365600620025831,2780.958632707596,0.2617013113839285,0.7419192450387138,0.2892021904785981,3554,0.7191811273389139 -324.90787982940674,0.8022313117980957,2539.0965452194214,10467,0,2539.0965452194214,0.2908981566121544,3581,0.7381637134354929,2865.1641743183136,0.2619059426443917,0.7434990065438407,0.2893591919975731,3554,0.7210130750562747 -328.9296774864197,0.8319780826568604,2619.157484292984,10813,0,2619.157484292984,0.2908987020254468,3581,0.7383843331122592,2949.288330078125,0.2616296325411115,0.744257995060512,0.2892786475735966,3554,0.7210670003209412 -332.9528057575226,0.8582251071929932,2699.1604931354523,11160,0,2699.1604931354523,0.2899860892339779,3581,0.7384683267592851,3033.352387428284,0.26066940171378,0.7442623547145298,0.2885349941307329,3554,0.7210964703063449 -336.9744844436645,0.8830990791320801,2779.2176752090454,11507,0,2779.2176752090454,0.2936014294191567,3581,0.7342443735164759,3117.467721223831,0.2642649071557181,0.7400140081133161,0.2920154749709834,3554,0.7170874531777575 -340.9918022155762,0.909146785736084,2859.2424590587616,11854,0,2859.2424590587616,0.2904373846450886,3581,0.735805823595888,3201.547816991806,0.2613157033920288,0.7415238107953753,0.2890872644049838,3554,0.7181548986397369 -344.9590666294098,0.9371671676635742,2939.2277522087097,12201,0,2939.2277522087097,0.2900164278483664,3581,0.7371981273780019,3285.540221691132,0.2606162514005388,0.7432973044259208,0.2884695968670864,3554,0.7198399085977069 -348.986310005188,0.9636261463165284,3019.357433795929,12551,0,3019.357433795929,0.2904927440942648,3581,0.7386292236805362,3369.735442638397,0.2611629281725202,0.7444561549595424,0.2888864357282287,3554,0.721402573464758 -353.00896859169006,0.9894242286682128,3099.3548357486725,12899,0,3099.3548357486725,0.2913513950581192,3581,0.7363361698460625,3453.79327750206,0.2620973927634103,0.7422619547162738,0.2898500492952483,3554,0.7189331398160523 -357.0331120491028,1.020883560180664,3179.518202066421,13245,0,3179.518202066421,0.2914022207593025,3581,0.7387102857311156,3538.023914575577,0.2619391339165823,0.7446674619402204,0.289828719620498,3554,0.7216041921294668 -361.0515315532684,1.0500917434692385,3259.6106762886047,13594,0,3259.6106762886047,0.2912512435423066,3581,0.7342123304855487,3622.175755262375,0.2620710645403181,0.7400798797607422,0.2898678755451603,3554,0.716524363503271 -365.0702559947968,1.0759599208831787,3339.786530971527,13944,0,3339.786530971527,0.2912381877116203,3581,0.7380589940833566,3706.407961368561,0.2618956565856933,0.743708746773856,0.2896753245682681,3554,0.7209935657885481 -369.08724784851074,1.1019816398620603,3419.961035490036,14291,0,3419.961035490036,0.2900729122124581,3581,0.738984764970504,3790.6374514102936,0.2604974678584507,0.7451628957475934,0.28851891959324,3554,0.7217649375043964 -373.1056852340698,1.128594160079956,3499.925144672394,14639,0,3499.925144672394,0.2903789572461428,3581,0.7378574638718235,3874.658413887024,0.2609571899686541,0.7438312258039202,0.2889053267445132,3554,0.7205989152847144 -377.1321818828583,1.1578266620635986,3580.1369519233704,14987,0,3580.1369519233704,0.2895177837458112,3581,0.7386629711280019,3958.937545061112,0.2602474348885672,0.7444625582013812,0.2880851303383951,3554,0.7212619556089969 -381.1527545452118,1.184983491897583,3660.195437669754,15335,0,3660.195437669754,0.2901615759608,3581,0.7375525778413851,4043.0554831028,0.260447655405317,0.7439757755824498,0.2886782910760762,3554,0.7202468554357766 -385.1748712062836,1.210752248764038,3740.1879727840414,15682,0,3740.1879727840414,0.2897529932281485,3581,0.7375366926792446,4127.107687950134,0.2603464467184884,0.7433926037379673,0.2883091262705754,3554,0.720238612083216 -389.2005672454834,1.2374587059020996,3820.219638109207,16031,0,3820.219638109207,0.2894294267924462,3581,0.7384739854221936,4211.203547954559,0.2601829086031232,0.74422516141619,0.2879862272813203,3554,0.7211628979890616 -393.2281312942505,1.2644691467285156,3900.333717823029,16377,0,3900.333717823029,0.28971559832929,3581,0.739368599575014,4295.38410115242,0.2603557280131748,0.745075157710484,0.2882707946811691,3554,0.7221244163706387 -397.2499804496765,1.2904927730560305,3980.330514192581,16726,0,3980.330514192581,0.2894622197666504,3581,0.7377609938957345,4379.440765142441,0.260012013571603,0.7437037059238979,0.28797430876741,3554,0.720381359471722 -401.2716295719147,1.3176743984222412,4060.434873819351,17075,0,4060.434873819351,0.289546417943661,3581,0.7381862117338034,4463.605730295181,0.2600981848580496,0.7442690304347447,0.2881094997494021,3554,0.7208010148116559 -405.2911117076874,1.3442814350128174,4140.43146109581,17417,0,4140.43146109581,0.2892599055235095,3581,0.7391449801251396,4547.659974336624,0.259650605065482,0.7453197070530483,0.2878206904577149,3554,0.7217897362566826 -409.3108265399933,1.3714284896850586,4220.44064950943,17767,0,4220.44064950943,0.2909241319202038,3581,0.7371324732529322,4631.727960109711,0.2613401753561837,0.743206364767892,0.2893077740859771,3554,0.7201920371412492 -413.335652589798,1.4000122547149658,4300.534623622894,18118,0,4300.534623622894,0.2895894033287664,3581,0.7390620091280369,4715.887307167053,0.2603205101830618,0.7448388508387974,0.288125162119267,3554,0.7218726506445202 -417.3572266101837,1.4266109466552734,4380.675292730331,18463,0,4380.675292730331,0.2892399297616762,3581,0.7387730764364004,4800.087876796722,0.2599183320999145,0.7445802688598633,0.2878617183103545,3554,0.7213314058543191 -421.3814389705658,1.4544930458068848,4460.688131093979,18812,0,4460.688131093979,0.289521124402227,3581,0.7380600849099413,4884.164512872696,0.2601046221596854,0.7440087454659599,0.2881414770878763,3554,0.7206378651255627 -425.4045426845551,1.481736421585083,4540.853692531586,19160,0,4540.853692531586,0.2896156172551312,3581,0.7394270951506213,4968.392388820648,0.2602435520717076,0.7451982498168945,0.2882147055364554,3554,0.7221053192705402 -429.4290335178375,1.5102970600128174,4620.966006994247,19506,0,4620.966006994247,0.2889937779251605,3581,0.739859130654845,5052.569364309311,0.2597181286130632,0.7456203869410923,0.2874624996153102,3554,0.7225618636131823 -433.454980134964,1.5415527820587158,4701.128844022751,19853,0,4701.128844022751,0.2890187646716176,3581,0.7397810002007121,5136.8013026714325,0.2593709911618914,0.7458940914699009,0.2875328085431907,3554,0.7224794300875774 -437.4760012626648,1.568915605545044,4781.280727386475,20199,0,4781.280727386475,0.288969029797019,3581,0.7392424045744904,5221.013342380524,0.2595722845622471,0.7451731136866978,0.2875190524486054,3554,0.7218680481060074 -441.496239900589,1.5969295501708984,4861.3911328315735,20545,0,4861.3911328315735,0.2892529174157009,3581,0.7392852876946034,5305.1836223602295,0.25978684425354,0.7453296525137765,0.2878251212597161,3554,0.7219164091076955 -445.5145680904389,1.6300604343414309,4941.5508685112,20883,0,4941.5508685112,0.2891802410944917,3581,0.7395689026066392,5389.406239509583,0.2595101594924927,0.7458878244672503,0.2876796947816281,3554,0.7223037779834341 -449.53950595855713,1.6624484062194824,5021.575328111649,21230,0,5021.575328111649,0.2896083905290073,3581,0.7384011727476613,5473.500054836273,0.2599371671676636,0.7447165080479213,0.288131001160664,3554,0.7209138113525253 -453.5573410987854,1.689453363418579,5101.56090760231,21578,0,5101.56090760231,0.2895140340294261,3581,0.7385272995715233,5557.542322397232,0.259907671383449,0.7447715486798968,0.2880903339546989,3554,0.7211905819147439 -457.5829334259033,1.7177293300628662,5181.675065755844,21926,0,5181.675065755844,0.2895594737743472,3581,0.7386807652366657,5641.722361803055,0.2598998546600342,0.7449408258710589,0.2881070782645874,3554,0.7213057140721723 -461.6102552413941,1.7452080249786377,5261.723633766174,22274,0,5261.723633766174,0.289036456515289,3581,0.7404866968243856,5725.83787894249,0.2593822990145002,0.7466844831194196,0.2875330489743071,3554,0.7233081617983258 -465.6337096691132,1.7741332054138184,5341.895644664764,22622,0,5341.895644664764,0.2886994933655927,3581,0.739449729802255,5810.07394862175,0.2592719623020717,0.7454221589224679,0.2873049828868,3554,0.7220717963034609 -469.65551710128784,1.8019211292266848,5421.932682275772,22970,0,5421.932682275772,0.2897626061374267,3581,0.7376963624205878,5894.17240691185,0.259899514062064,0.744358880179269,0.2882900978650816,3554,0.7203458443611072 -473.6802542209625,1.8305857181549072,5501.911800146103,23319,0,5501.911800146103,0.2887416947190903,3581,0.7392803107983106,5978.216735839844,0.2591825383050101,0.7454833303179059,0.2873254882262943,3554,0.7220026208365574 -477.7016234397888,1.8585996627807613,5581.881755113602,23668,0,5581.881755113602,0.2891561065563041,3581,0.7398573580616448,6062.247819662094,0.259478007044111,0.7462118012564523,0.2877843853591464,3554,0.7225752590610931 -481.7267339229584,1.8905537128448489,5661.870182514191,24011,0,5661.870182514191,0.2884309795840896,3581,0.7396680996491902,6146.304933309555,0.2587435586111886,0.7459972926548549,0.2870924933009021,3554,0.7221812268087014 -485.7525725364685,1.9191038608551023,5741.839620113373,24358,0,5741.839620113373,0.2889409069241308,3581,0.7385344581209857,6230.340355634689,0.2594368287495204,0.7446889877319336,0.2875983431960467,3554,0.7210218679656725 -489.77895855903625,1.9476697444915767,5821.816328525543,24706,0,5821.816328525543,0.2882839225251326,3581,0.7397184822020735,6314.383749961853,0.2587285212108067,0.7457452501569476,0.2868627957165341,3554,0.7223903331853193 -493.80056977272034,1.9763381481170648,5901.898707389832,25050,0,5901.898707389832,0.2880520196108803,3581,0.7405085815327422,6398.527945518494,0.258384244782584,0.746821403503418,0.2866580342736617,3554,0.723183961953081 -497.8165833950043,2.005166530609131,5981.937157869339,25398,0,5981.937157869339,0.2889792903845818,3581,0.740772493389591,6482.623092412949,0.2592196294239589,0.7470160893031529,0.2875559414513136,3554,0.7234911642251688 -501.8346812725067,2.0372416973114014,6062.071349143982,25747,0,6062.071349143982,0.2892967209207449,3581,0.7385599561924043,6566.819430589676,0.2595665114266531,0.7450372150966099,0.2879249001730005,3554,0.7211857045978123 -505.8603518009186,2.0665762424468994,6142.251312255859,26092,0,6142.251312255859,0.2886063640459194,3581,0.7407934918013473,6651.06591629982,0.2586601461683001,0.7474192210606166,0.2871058887488129,3554,0.723627385626231 -509.8812828063965,2.0960710048675537,6222.414223909378,26443,0,6222.414223909378,0.2882083827841385,3581,0.7414119222982407,6735.290981769562,0.2583139113017491,0.7478544371468681,0.28681055346968204,3554,0.724158875782569 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/measurements.csv deleted file mode 100644 index aab8402c0..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/measurements.csv +++ /dev/null @@ -1,345 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.4210806,1.0849575,,,,,,,,,,,,,, -1,,,0.210459760257176,1.1540372031075614,0.2105898881267146,1.1542266691140264,3554.0,0.233558584477974,1.150831046233594,3581.0,55.89515209197998,255.85205793380737,55.89515209197998,199.95662784576416,0.0,0.0 -100,0.2399431,0.33590412,,,,,,,,,,,,,, -200,0.16434585,0.48278666,,,,,,,,,,,,,, -300,0.19008891,0.41215897,,,,,,,,,,,,,, -337,,,0.701190676007952,0.2987578936985561,0.679899079382386,0.3286026679335783,3554.0,0.6975823737695476,0.3304063274486875,3581.0,136.12256789207458,340.49064683914185,136.12256789207458,204.3298182487488,0.026456594467163,0.0 -400,0.10894664,0.33010972,,,,,,,,,,,,,, -500,0.3105247,0.2766124,,,,,,,,,,,,,, -576,,,0.7153627531869071,0.287034205027989,0.6935581084913126,0.315769691445818,3554.0,0.7110114716777786,0.3179138746051207,3581.0,216.5056059360504,424.94063425064087,216.5056059360504,208.3515384197235,0.0636670589447021,0.0 -600,0.13099287,0.36523917,,,,,,,,,,,,,, -700,0.39246282,0.25738737,,,,,,,,,,,,,, -800,0.19242251,0.33679974,,,,,,,,,,,,,, -818,,,0.7224571364266532,0.2811103548322405,0.7014776346853897,0.3088610750540764,3554.0,0.7185868532489179,0.3113394627788153,3581.0,297.05216240882874,509.5492379665375,297.05216240882874,212.3737070560456,0.0957064628601074,0.0 -900,0.49283075,0.32519957,,,,,,,,,,,,,, -1000,0.13652164,0.23639685,,,,,,,,,,,,,, -1100,0.36039943,0.3124118,,,,,,,,,,,,,, -1101,,,0.722386087690081,0.279527153287615,0.7012314332222496,0.3073478359550858,3554.0,0.7187400462074142,0.3093823835324455,3581.0,377.3095192909241,593.8110988140106,377.3095192909241,216.3376326560974,0.12656831741333,0.0 -1200,0.24798478,0.3370981,,,,,,,,,,,,,, -1300,0.13398913,0.28114426,,,,,,,,,,,,,, -1400,0.35858044,0.21642116,,,,,,,,,,,,,, -1450,,,0.7277840886797223,0.273552622113909,0.7057781917162,0.3014803206444323,3554.0,0.7228750289069045,0.3038241107581681,3581.0,457.4085175991058,677.9646747112274,457.4085175991058,220.3560636043549,0.1503820419311523,0.0 -1500,0.16574581,0.31471035,,,,,,,,,,,,,, -1600,0.42286256,0.26887083,,,,,,,,,,,,,, -1700,0.07372279,0.39416897,,,,,,,,,,,,,, -1798,,,0.7309179987226214,0.272597210747855,0.7098639407182048,0.3001107219638084,3554.0,0.7263438574464186,0.3024637818303197,3581.0,537.4023633003235,762.0131983757019,537.4023633003235,224.37543368339536,0.1734664440155029,0.0 -1800,0.16570815,0.3367334,,,,,,,,,,,,,, -1900,0.14233685,0.34502292,,,,,,,,,,,,,, -2000,0.19054794,0.28741124,,,,,,,,,,,,,, -2100,0.08151068,0.28000894,,,,,,,,,,,,,, -2144,,,0.7346102169581822,0.2699415854045323,0.7122909211012239,0.2972823933530001,3554.0,0.7295739994938565,0.2993826398766929,3581.0,617.4026784896851,846.0708227157593,617.4026784896851,228.39755821228027,0.1965482234954834,0.0 -2200,0.13256676,0.25034744,,,,,,,,,,,,,, -2300,0.08959343,0.2871824,,,,,,,,,,,,,, -2400,0.110075004,0.23927635,,,,,,,,,,,,,, -2491,,,0.7364674976893834,0.2671689135687692,0.7145024065093908,0.2947580727151624,3554.0,0.7318134664723541,0.2964027743537943,3581.0,697.4331579208374,930.158771276474,697.4331579208374,232.41604614257807,0.223621129989624,0.0 -2500,0.14196844,0.31468195,,,,,,,,,,,,,, -2600,0.11259343,0.2787968,,,,,,,,,,,,,, -2700,0.17183347,0.30487674,,,,,,,,,,,,,, -2800,0.0782601,0.2787603,,,,,,,,,,,,,, -2840,,,0.7368014880589077,0.2669699362346104,0.7142880793428179,0.2946955262776097,3554.0,0.7317461079307456,0.2962220721123638,3581.0,777.4768579006195,1014.2621154785156,777.4768579006195,236.43962216377253,0.2478590011596679,0.0 -2900,0.12994745,0.26003152,,,,,,,,,,,,,, -3000,0.16219456,0.30622503,,,,,,,,,,,,,, -3100,0.21527687,0.2403001,,,,,,,,,,,,,, -3185,,,0.736929076058524,0.2667936256953648,0.7142186290974958,0.2942849386254924,3554.0,0.7315329195100879,0.2960640386108978,3581.0,857.5939862728119,1098.4345676898956,857.5939862728119,240.45988512039185,0.2709853649139404,0.0 -3200,0.15621725,0.23723334,,,,,,,,,,,,,, -3300,0.13309756,0.3054,,,,,,,,,,,,,, -3400,0.07645847,0.37125003,,,,,,,,,,,,,, -3500,0.28388974,0.25646496,,,,,,,,,,,,,, -3535,,,0.740006855555943,0.2642932108470371,0.7175935263303672,0.291770681747239,3554.0,0.7348998921172508,0.2933858889036757,3581.0,937.7757697105408,1182.672307014465,937.7757697105408,244.4809172153473,0.2939739227294922,0.0 -3600,0.23307882,0.3509268,,,,,,,,,,,,,, -3700,0.10483403,0.34773484,,,,,,,,,,,,,, -3800,0.07027519,0.2572378,,,,,,,,,,,,,, -3881,,,0.7408522878374372,0.264447672026498,0.718755220789955,0.2916430128244583,3554.0,0.7359896960520804,0.2932634777078679,3581.0,1017.7614195346832,1266.7160177230835,1017.7614195346832,248.5042879581452,0.3168442249298095,0.0 -3900,0.22582898,0.28376433,,,,,,,,,,,,,, -4000,0.29724047,0.2888298,,,,,,,,,,,,,, -4100,0.08655737,0.2753749,,,,,,,,,,,,,, -4200,0.3018006,0.28325346,,,,,,,,,,,,,, -4225,,,0.738696779523577,0.2640284981046404,0.7159462297103616,0.2914711732708743,3554.0,0.7337723864885856,0.2928563948617704,3581.0,1097.7921483516693,1350.8066008090973,1097.7921483516693,252.5282917022705,0.3407328128814697,0.0 -4300,0.1773042,0.4000774,,,,,,,,,,,,,, -4400,0.16578002,0.35892665,,,,,,,,,,,,,, -4500,0.18857211,0.27168086,,,,,,,,,,,,,, -4573,,,0.7383772305079869,0.264590927532741,0.7156449351742754,0.2919031249450443,3554.0,0.7327200115409452,0.2937777342659173,3581.0,1177.9653112888336,1435.0387947559357,1177.9653112888336,256.5510618686676,0.3653392791748047,0.0 -4600,0.37343103,0.26013592,,,,,,,,,,,,,, -4700,0.24166213,0.2257221,,,,,,,,,,,,,, -4800,0.19031456,0.22274315,,,,,,,,,,,,,, -4900,0.0970585,0.3945529,,,,,,,,,,,,,, -4921,,,0.7383550235203334,0.2662299530846732,0.7162272593380697,0.2935848719092923,3554.0,0.7332136105705459,0.295212375754852,3581.0,1258.125807762146,1519.2613484859469,1258.125807762146,260.5774459838867,0.3891665935516357,0.0 -5000,0.23621127,0.3382774,,,,,,,,,,,,,, -5100,0.08634088,0.26416025,,,,,,,,,,,,,, -5200,0.1426004,0.24142914,,,,,,,,,,,,,, -5267,,,0.7415508542742048,0.2632380894252232,0.7190130316412845,0.2906155476223973,3554.0,0.7363770076663292,0.2921203937120392,3581.0,1338.1535222530365,1603.3435831069946,1338.1535222530365,264.5962345600128,0.4131014347076416,0.0 -5300,0.33182767,0.19448516,,,,,,,,,,,,,, -5400,0.1687457,0.28946996,,,,,,,,,,,,,, -5500,0.19513644,0.3070838,,,,,,,,,,,,,, -5600,0.19982879,0.23186828,,,,,,,,,,,,,, -5615,,,0.7396905762808663,0.2629983765738351,0.716353039159222,0.2906085407727208,3554.0,0.7338091337091595,0.2922144434166434,3581.0,1418.1331946849823,1687.3803277015686,1418.1331946849823,268.6173930168152,0.43715500831604,0.0 -5700,0.10383309,0.30763584,,,,,,,,,,,,,, -5800,0.14164536,0.3703593,,,,,,,,,,,,,, -5900,0.09988102,0.30155385,,,,,,,,,,,,,, -5962,,,0.741100719996861,0.2633684022086007,0.7180967830041854,0.2906225888193761,3554.0,0.7355647509206577,0.292230908080407,3581.0,1498.1605167388916,1771.4643297195437,1498.1605167388916,272.638201713562,0.4610710144042969,0.0 -6000,0.19578187,0.2827755,,,,,,,,,,,,,, -6100,0.19565146,0.2797544,,,,,,,,,,,,,, -6200,0.18928406,0.26475197,,,,,,,,,,,,,, -6300,0.2904836,0.26588583,,,,,,,,,,,,,, -6307,,,0.7410097122192383,0.2626251493181501,0.7181193835291221,0.2901365401440278,3554.0,0.7355887491055222,0.2916041600308049,3581.0,1578.220741033554,1855.5889139175413,1578.220741033554,276.6635181903839,0.4882900714874267,0.0 -6400,0.15349649,0.24903502,,,,,,,,,,,,,, -6500,0.07008614,0.30008027,,,,,,,,,,,,,, -6600,0.27653074,0.38966495,,,,,,,,,,,,,, -6652,,,0.7417586190359933,0.2632386854716709,0.7186538962480655,0.2910806444433209,3554.0,0.735859410451864,0.2926179128996788,3581.0,1658.302259206772,1939.7298274040224,1658.302259206772,280.685994386673,0.5136318206787109,0.0 -6700,0.086690724,0.33472064,,,,,,,,,,,,,, -6800,0.18994163,0.3208755,,,,,,,,,,,,,, -6900,0.14941138,0.32804614,,,,,,,,,,,,,, -7000,,,0.7424896103995187,0.262102552822658,0.7194457389561058,0.2897551820461979,3554.0,0.7368869009180397,0.2913242607468235,3581.0,1738.3994042873385,2023.885073661804,1738.3994042873385,284.7060179710388,0.539787769317627,0.0 -7000,0.2855275,0.22786786,,,,,,,,,,,,,, -7100,0.28415814,0.24203682,,,,,,,,,,,,,, -7200,0.12948951,0.30582175,,,,,,,,,,,,,, -7300,0.24072826,0.27774346,,,,,,,,,,,,,, -7348,,,0.7386776379176548,0.26255578654153,0.7159020590795583,0.29005053449898,3554.0,0.7334932030595505,0.2914589096533789,3581.0,1818.5374870300293,2108.081840991974,1818.5374870300293,288.7262351512909,0.566457986831665,0.0 -7400,0.12900196,0.30287296,,,,,,,,,,,,,, -7500,0.14722289,0.25063285,,,,,,,,,,,,,, -7600,0.32534438,0.2545946,,,,,,,,,,,,,, -7694,,,0.7403997693743024,0.2637131043842861,0.717441024307998,0.2915032536512556,3554.0,0.7344401087117775,0.2932632050012217,3581.0,1898.6189014911647,2192.217269182205,1898.6189014911647,292.74238300323486,0.5925304889678955,0.0 -7700,0.22601229,0.2565541,,,,,,,,,,,,,, -7800,0.09266451,0.25631332,,,,,,,,,,,,,, -7900,0.0988933,0.31056494,,,,,,,,,,,,,, -8000,0.085912816,0.29647186,,,,,,,,,,,,,, -8041,,,0.742549010685512,0.2620921645845686,0.7199878767761677,0.2896134650767621,3554.0,0.7373533656363446,0.291029771657271,3581.0,1978.6849682331083,2276.340225696564,1978.6849682331083,296.76133704185486,0.6188676357269287,0.0 -8100,0.11468547,0.22232284,,,,,,,,,,,,,, -8200,0.11560994,0.3593302,,,,,,,,,,,,,, -8300,0.19172122,0.27727497,,,,,,,,,,,,,, -8388,,,0.7422918592180524,0.2629777533667428,0.7192087425699916,0.2908316608486916,3554.0,0.7366206710546984,0.2922346237084613,3581.0,2058.670554161072,2360.384894132614,2058.670554161072,300.7832546234131,0.644014835357666,0.0 -8400,0.09926631,0.26331684,,,,,,,,,,,,,, -8500,0.17494513,0.3179474,,,,,,,,,,,,,, -8600,0.12826678,0.21403244,,,,,,,,,,,,,, -8700,0.29157585,0.29634356,,,,,,,,,,,,,, -8731,,,0.7420447894505092,0.2619698728833879,0.7189635715259215,0.2895826898938696,3554.0,0.7364160047167342,0.2911163219291049,3581.0,2138.7477850914,2444.522290945053,2138.7477850914,304.80595803260803,0.6700015068054199,0.0 -8800,0.07669888,0.2596552,,,,,,,,,,,,,, -8900,0.17936286,0.30696014,,,,,,,,,,,,,, -9000,0.24310678,0.31200397,,,,,,,,,,,,,, -9080,,,0.7410753113882882,0.2628451585769653,0.7184303640044668,0.2903377122938063,3554.0,0.7354463280595505,0.29208429416975,3581.0,2218.78799200058,2528.614986181259,2218.78799200058,308.8210089206696,0.6955749988555908,0.0 -9100,0.1931695,0.33724374,,,,,,,,,,,,,, -9200,0.15873128,0.27863583,,,,,,,,,,,,,, -9300,0.10929852,0.3021149,,,,,,,,,,,,,, -9400,0.08013012,0.2396112,,,,,,,,,,,,,, -9425,,,0.7440524101257324,0.2612859351294381,0.721223967492614,0.2889386092804762,3554.0,0.7384188986796635,0.2904704844142697,3581.0,2298.791113376617,2612.6801404953003,2298.791113376617,312.8451218605041,0.7218742370605469,0.0 -9500,0.07706163,0.26728788,,,,,,,,,,,,,, -9600,0.06395672,0.37107605,,,,,,,,,,,,,, -9700,0.047302544,0.3500932,,,,,,,,,,,,,, -9769,,,0.7446088109697614,0.2629133633204868,0.7217951631304516,0.2909829950627813,3554.0,0.7389364958941287,0.2925327943377374,3581.0,2378.839912891388,2696.7871565818787,2378.839912891388,316.8660409450531,0.7474322319030762,0.0 -9800,0.16426635,0.3000286,,,,,,,,,,,,,, -9900,0.14563105,0.27447098,,,,,,,,,,,,,, -10000,0.3008789,0.4083894,,,,,,,,,,,,,, -10100,0.12863967,0.27750453,,,,,,,,,,,,,, -10117,,,0.7419192450387138,0.2617013113839285,0.7191811273389139,0.2892021904785981,3554.0,0.7365600620025831,0.2907282262832484,3581.0,2458.948125362396,2780.958632707596,2458.948125362396,320.8879749774933,0.776827335357666,0.0 -10200,0.20103283,0.24982819,,,,,,,,,,,,,, -10300,0.09751584,0.27772024,,,,,,,,,,,,,, -10400,0.24212104,0.3186146,,,,,,,,,,,,,, -10467,,,0.7434990065438407,0.2619059426443917,0.7210130750562747,0.2893591919975731,3554.0,0.7381637134354929,0.2908981566121544,3581.0,2539.0965452194214,2865.1641743183136,2539.0965452194214,324.90787982940674,0.8022313117980957,0.0 -10500,0.23748651,0.30205145,,,,,,,,,,,,,, -10600,0.11330946,0.27445513,,,,,,,,,,,,,, -10700,0.12636328,0.2674493,,,,,,,,,,,,,, -10800,0.2652751,0.33505815,,,,,,,,,,,,,, -10813,,,0.744257995060512,0.2616296325411115,0.7210670003209412,0.2892786475735966,3554.0,0.7383843331122592,0.2908987020254468,3581.0,2619.157484292984,2949.288330078125,2619.157484292984,328.9296774864197,0.8319780826568604,0.0 -10900,0.29928812,0.23900987,,,,,,,,,,,,,, -11000,0.16589388,0.29246902,,,,,,,,,,,,,, -11100,0.16924961,0.26512343,,,,,,,,,,,,,, -11160,,,0.7442623547145298,0.26066940171378,0.7210964703063449,0.2885349941307329,3554.0,0.7384683267592851,0.2899860892339779,3581.0,2699.1604931354523,3033.352387428284,2699.1604931354523,332.9528057575226,0.8582251071929932,0.0 -11200,0.18959454,0.29832506,,,,,,,,,,,,,, -11300,0.26339042,0.2823352,,,,,,,,,,,,,, -11400,0.21278928,0.2965999,,,,,,,,,,,,,, -11500,0.17500299,0.2920966,,,,,,,,,,,,,, -11507,,,0.7400140081133161,0.2642649071557181,0.7170874531777575,0.2920154749709834,3554.0,0.7342443735164759,0.2936014294191567,3581.0,2779.2176752090454,3117.467721223831,2779.2176752090454,336.9744844436645,0.8830990791320801,0.0 -11600,0.12026389,0.24180007,,,,,,,,,,,,,, -11700,0.07975544,0.30697453,,,,,,,,,,,,,, -11800,0.216229,0.28491783,,,,,,,,,,,,,, -11854,,,0.7415238107953753,0.2613157033920288,0.7181548986397369,0.2890872644049838,3554.0,0.735805823595888,0.2904373846450886,3581.0,2859.2424590587616,3201.547816991806,2859.2424590587616,340.9918022155762,0.909146785736084,0.0 -11900,0.24277468,0.32128233,,,,,,,,,,,,,, -12000,0.18802762,0.31079438,,,,,,,,,,,,,, -12100,0.17198887,0.22727934,,,,,,,,,,,,,, -12200,0.1417741,0.28972626,,,,,,,,,,,,,, -12201,,,0.7432973044259208,0.2606162514005388,0.7198399085977069,0.2884695968670864,3554.0,0.7371981273780019,0.2900164278483664,3581.0,2939.2277522087097,3285.540221691132,2939.2277522087097,344.9590666294098,0.9371671676635742,0.0 -12300,0.18130249,0.35203525,,,,,,,,,,,,,, -12400,0.13394605,0.3291643,,,,,,,,,,,,,, -12500,0.067973845,0.2946281,,,,,,,,,,,,,, -12551,,,0.7444561549595424,0.2611629281725202,0.721402573464758,0.2888864357282287,3554.0,0.7386292236805362,0.2904927440942648,3581.0,3019.357433795929,3369.735442638397,3019.357433795929,348.986310005188,0.9636261463165284,0.0 -12600,0.11384746,0.3133821,,,,,,,,,,,,,, -12700,0.23925304,0.34147376,,,,,,,,,,,,,, -12800,0.31483278,0.23096873,,,,,,,,,,,,,, -12899,,,0.7422619547162738,0.2620973927634103,0.7189331398160523,0.2898500492952483,3554.0,0.7363361698460625,0.2913513950581192,3581.0,3099.3548357486725,3453.79327750206,3099.3548357486725,353.00896859169006,0.9894242286682128,0.0 -12900,0.15158029,0.26172692,,,,,,,,,,,,,, -13000,0.17566495,0.2624408,,,,,,,,,,,,,, -13100,0.1333262,0.23963027,,,,,,,,,,,,,, -13200,0.18459754,0.29050076,,,,,,,,,,,,,, -13245,,,0.7446674619402204,0.2619391339165823,0.7216041921294668,0.289828719620498,3554.0,0.7387102857311156,0.2914022207593025,3581.0,3179.518202066421,3538.023914575577,3179.518202066421,357.0331120491028,1.020883560180664,0.0 -13300,0.11315003,0.3415078,,,,,,,,,,,,,, -13400,0.1488878,0.293574,,,,,,,,,,,,,, -13500,0.17765403,0.32469732,,,,,,,,,,,,,, -13594,,,0.7400798797607422,0.2620710645403181,0.716524363503271,0.2898678755451603,3554.0,0.7342123304855487,0.2912512435423066,3581.0,3259.6106762886047,3622.175755262375,3259.6106762886047,361.0515315532684,1.0500917434692385,0.0 -13600,0.13655671,0.23129562,,,,,,,,,,,,,, -13700,0.13835134,0.2828701,,,,,,,,,,,,,, -13800,0.123179205,0.25449708,,,,,,,,,,,,,, -13900,0.46267098,0.20625341,,,,,,,,,,,,,, -13944,,,0.743708746773856,0.2618956565856933,0.7209935657885481,0.2896753245682681,3554.0,0.7380589940833566,0.2912381877116203,3581.0,3339.786530971527,3706.407961368561,3339.786530971527,365.0702559947968,1.0759599208831787,0.0 -14000,0.16127466,0.26736066,,,,,,,,,,,,,, -14100,0.1249413,0.28007817,,,,,,,,,,,,,, -14200,0.2329087,0.22940812,,,,,,,,,,,,,, -14291,,,0.7451628957475934,0.2604974678584507,0.7217649375043964,0.28851891959324,3554.0,0.738984764970504,0.2900729122124581,3581.0,3419.961035490036,3790.6374514102936,3419.961035490036,369.08724784851074,1.1019816398620603,0.0 -14300,0.22058207,0.22198884,,,,,,,,,,,,,, -14400,0.11281902,0.31075448,,,,,,,,,,,,,, -14500,0.17067496,0.31786698,,,,,,,,,,,,,, -14600,0.104906015,0.22762218,,,,,,,,,,,,,, -14639,,,0.7438312258039202,0.2609571899686541,0.7205989152847144,0.2889053267445132,3554.0,0.7378574638718235,0.2903789572461428,3581.0,3499.925144672394,3874.658413887024,3499.925144672394,373.1056852340698,1.128594160079956,0.0 -14700,0.16812663,0.26342747,,,,,,,,,,,,,, -14800,0.12792496,0.25290006,,,,,,,,,,,,,, -14900,0.18480754,0.26243168,,,,,,,,,,,,,, -14987,,,0.7444625582013812,0.2602474348885672,0.7212619556089969,0.2880851303383951,3554.0,0.7386629711280019,0.2895177837458112,3581.0,3580.1369519233704,3958.937545061112,3580.1369519233704,377.1321818828583,1.1578266620635986,0.0 -15000,0.13138604,0.26286924,,,,,,,,,,,,,, -15100,0.08123517,0.24499574,,,,,,,,,,,,,, -15200,0.12047123,0.2320569,,,,,,,,,,,,,, -15300,0.15087312,0.29743403,,,,,,,,,,,,,, -15335,,,0.7439757755824498,0.260447655405317,0.7202468554357766,0.2886782910760762,3554.0,0.7375525778413851,0.2901615759608,3581.0,3660.195437669754,4043.0554831028,3660.195437669754,381.1527545452118,1.184983491897583,0.0 -15400,0.28295895,0.27120045,,,,,,,,,,,,,, -15500,0.23001109,0.29935575,,,,,,,,,,,,,, -15600,0.118166246,0.3499276,,,,,,,,,,,,,, -15682,,,0.7433926037379673,0.2603464467184884,0.720238612083216,0.2883091262705754,3554.0,0.7375366926792446,0.2897529932281485,3581.0,3740.1879727840414,4127.107687950134,3740.1879727840414,385.1748712062836,1.210752248764038,0.0 -15700,0.33028477,0.26001257,,,,,,,,,,,,,, -15800,0.23751254,0.29444197,,,,,,,,,,,,,, -15900,0.21354717,0.2654656,,,,,,,,,,,,,, -16000,0.34335852,0.30018803,,,,,,,,,,,,,, -16031,,,0.74422516141619,0.2601829086031232,0.7211628979890616,0.2879862272813203,3554.0,0.7384739854221936,0.2894294267924462,3581.0,3820.219638109207,4211.203547954559,3820.219638109207,389.2005672454834,1.2374587059020996,0.0 -16100,0.11563159,0.304487,,,,,,,,,,,,,, -16200,0.26027492,0.25283003,,,,,,,,,,,,,, -16300,0.36076146,0.21305561,,,,,,,,,,,,,, -16377,,,0.745075157710484,0.2603557280131748,0.7221244163706387,0.2882707946811691,3554.0,0.739368599575014,0.28971559832929,3581.0,3900.333717823029,4295.38410115242,3900.333717823029,393.2281312942505,1.2644691467285156,0.0 -16400,0.33562195,0.24232288,,,,,,,,,,,,,, -16500,0.12027423,0.19445488,,,,,,,,,,,,,, -16600,0.1594101,0.36574125,,,,,,,,,,,,,, -16700,0.078423955,0.32743827,,,,,,,,,,,,,, -16726,,,0.7437037059238979,0.260012013571603,0.720381359471722,0.28797430876741,3554.0,0.7377609938957345,0.2894622197666504,3581.0,3980.330514192581,4379.440765142441,3980.330514192581,397.2499804496765,1.2904927730560305,0.0 -16800,0.3496279,0.23205927,,,,,,,,,,,,,, -16900,0.088670485,0.41694582,,,,,,,,,,,,,, -17000,0.31356508,0.29867202,,,,,,,,,,,,,, -17075,,,0.7442690304347447,0.2600981848580496,0.7208010148116559,0.2881094997494021,3554.0,0.7381862117338034,0.289546417943661,3581.0,4060.434873819351,4463.605730295181,4060.434873819351,401.2716295719147,1.3176743984222412,0.0 -17100,0.14886418,0.27416268,,,,,,,,,,,,,, -17200,0.07909513,0.2948548,,,,,,,,,,,,,, -17300,0.09791688,0.27288336,,,,,,,,,,,,,, -17400,0.12708367,0.2902847,,,,,,,,,,,,,, -17417,,,0.7453197070530483,0.259650605065482,0.7217897362566826,0.2878206904577149,3554.0,0.7391449801251396,0.2892599055235095,3581.0,4140.43146109581,4547.659974336624,4140.43146109581,405.2911117076874,1.3442814350128174,0.0 -17500,0.24295758,0.31701723,,,,,,,,,,,,,, -17600,0.18057238,0.30139926,,,,,,,,,,,,,, -17700,0.16715044,0.22990298,,,,,,,,,,,,,, -17767,,,0.743206364767892,0.2613401753561837,0.7201920371412492,0.2893077740859771,3554.0,0.7371324732529322,0.2909241319202038,3581.0,4220.44064950943,4631.727960109711,4220.44064950943,409.3108265399933,1.3714284896850586,0.0 -17800,0.41534132,0.3058355,,,,,,,,,,,,,, -17900,0.21694227,0.26292178,,,,,,,,,,,,,, -18000,0.07909359,0.33942094,,,,,,,,,,,,,, -18100,0.13092895,0.28316072,,,,,,,,,,,,,, -18118,,,0.7448388508387974,0.2603205101830618,0.7218726506445202,0.288125162119267,3554.0,0.7390620091280369,0.2895894033287664,3581.0,4300.534623622894,4715.887307167053,4300.534623622894,413.335652589798,1.4000122547149658,0.0 -18200,0.10498818,0.27447057,,,,,,,,,,,,,, -18300,0.09327976,0.20554093,,,,,,,,,,,,,, -18400,0.08374764,0.2868206,,,,,,,,,,,,,, -18463,,,0.7445802688598633,0.2599183320999145,0.7213314058543191,0.2878617183103545,3554.0,0.7387730764364004,0.2892399297616762,3581.0,4380.675292730331,4800.087876796722,4380.675292730331,417.3572266101837,1.4266109466552734,0.0 -18500,0.23203273,0.28624144,,,,,,,,,,,,,, -18600,0.0929152,0.3424744,,,,,,,,,,,,,, -18700,0.1685609,0.37393087,,,,,,,,,,,,,, -18800,0.23991957,0.2915552,,,,,,,,,,,,,, -18812,,,0.7440087454659599,0.2601046221596854,0.7206378651255627,0.2881414770878763,3554.0,0.7380600849099413,0.289521124402227,3581.0,4460.688131093979,4884.164512872696,4460.688131093979,421.3814389705658,1.4544930458068848,0.0 -18900,0.13814647,0.25155094,,,,,,,,,,,,,, -19000,0.20874667,0.28683597,,,,,,,,,,,,,, -19100,0.18551958,0.26843867,,,,,,,,,,,,,, -19160,,,0.7451982498168945,0.2602435520717076,0.7221053192705402,0.2882147055364554,3554.0,0.7394270951506213,0.2896156172551312,3581.0,4540.853692531586,4968.392388820648,4540.853692531586,425.4045426845551,1.481736421585083,0.0 -19200,0.22229187,0.2275355,,,,,,,,,,,,,, -19300,0.2094212,0.28781682,,,,,,,,,,,,,, -19400,0.09465431,0.24163213,,,,,,,,,,,,,, -19500,0.067650124,0.29747126,,,,,,,,,,,,,, -19506,,,0.7456203869410923,0.2597181286130632,0.7225618636131823,0.2874624996153102,3554.0,0.739859130654845,0.2889937779251605,3581.0,4620.966006994247,5052.569364309311,4620.966006994247,429.4290335178375,1.5102970600128174,0.0 -19600,0.14287992,0.33784226,,,,,,,,,,,,,, -19700,0.20784059,0.22217992,,,,,,,,,,,,,, -19800,0.1522107,0.2861787,,,,,,,,,,,,,, -19853,,,0.7458940914699009,0.2593709911618914,0.7224794300875774,0.2875328085431907,3554.0,0.7397810002007121,0.2890187646716176,3581.0,4701.128844022751,5136.8013026714325,4701.128844022751,433.454980134964,1.5415527820587158,0.0 -19900,0.16306585,0.30166325,,,,,,,,,,,,,, -20000,0.08762396,0.23853545,,,,,,,,,,,,,, -20100,0.16545598,0.35626343,,,,,,,,,,,,,, -20199,,,0.7451731136866978,0.2595722845622471,0.7218680481060074,0.2875190524486054,3554.0,0.7392424045744904,0.288969029797019,3581.0,4781.280727386475,5221.013342380524,4781.280727386475,437.4760012626648,1.568915605545044,0.0 -20200,0.33252552,0.3462565,,,,,,,,,,,,,, -20300,0.19574158,0.23568165,,,,,,,,,,,,,, -20400,0.19540375,0.21368738,,,,,,,,,,,,,, -20500,0.08500764,0.24006152,,,,,,,,,,,,,, -20545,,,0.7453296525137765,0.25978684425354,0.7219164091076955,0.2878251212597161,3554.0,0.7392852876946034,0.2892529174157009,3581.0,4861.3911328315735,5305.1836223602295,4861.3911328315735,441.496239900589,1.5969295501708984,0.0 -20600,0.2652213,0.28748375,,,,,,,,,,,,,, -20700,0.16167334,0.24799699,,,,,,,,,,,,,, -20800,0.19468999,0.3232913,,,,,,,,,,,,,, -20883,,,0.7458878244672503,0.2595101594924927,0.7223037779834341,0.2876796947816281,3554.0,0.7395689026066392,0.2891802410944917,3581.0,4941.5508685112,5389.406239509583,4941.5508685112,445.5145680904389,1.6300604343414309,0.0 -20900,0.14894621,0.341617,,,,,,,,,,,,,, -21000,0.15519986,0.29520935,,,,,,,,,,,,,, -21100,0.10671085,0.31126294,,,,,,,,,,,,,, -21200,0.3592474,0.2685673,,,,,,,,,,,,,, -21230,,,0.7447165080479213,0.2599371671676636,0.7209138113525253,0.288131001160664,3554.0,0.7384011727476613,0.2896083905290073,3581.0,5021.575328111649,5473.500054836273,5021.575328111649,449.53950595855713,1.6624484062194824,0.0 -21300,0.19810839,0.30154535,,,,,,,,,,,,,, -21400,0.117390335,0.2767827,,,,,,,,,,,,,, -21500,0.32284206,0.23473558,,,,,,,,,,,,,, -21578,,,0.7447715486798968,0.259907671383449,0.7211905819147439,0.2880903339546989,3554.0,0.7385272995715233,0.2895140340294261,3581.0,5101.56090760231,5557.542322397232,5101.56090760231,453.5573410987854,1.689453363418579,0.0 -21600,0.1904961,0.3812329,,,,,,,,,,,,,, -21700,0.1669185,0.2699458,,,,,,,,,,,,,, -21800,0.36782724,0.24212663,,,,,,,,,,,,,, -21900,0.21968885,0.22641212,,,,,,,,,,,,,, -21926,,,0.7449408258710589,0.2598998546600342,0.7213057140721723,0.2881070782645874,3554.0,0.7386807652366657,0.2895594737743472,3581.0,5181.675065755844,5641.722361803055,5181.675065755844,457.5829334259033,1.7177293300628662,0.0 -22000,0.3343571,0.31311244,,,,,,,,,,,,,, -22100,0.08526451,0.29862663,,,,,,,,,,,,,, -22200,0.09969154,0.28603047,,,,,,,,,,,,,, -22274,,,0.7466844831194196,0.2593822990145002,0.7233081617983258,0.2875330489743071,3554.0,0.7404866968243856,0.289036456515289,3581.0,5261.723633766174,5725.83787894249,5261.723633766174,461.6102552413941,1.7452080249786377,0.0 -22300,0.16027395,0.2602321,,,,,,,,,,,,,, -22400,0.15736195,0.3091493,,,,,,,,,,,,,, -22500,0.0693572,0.217811,,,,,,,,,,,,,, -22600,0.15903042,0.2940609,,,,,,,,,,,,,, -22622,,,0.7454221589224679,0.2592719623020717,0.7220717963034609,0.2873049828868,3554.0,0.739449729802255,0.2886994933655927,3581.0,5341.895644664764,5810.07394862175,5341.895644664764,465.6337096691132,1.7741332054138184,0.0 -22700,0.24833553,0.25114417,,,,,,,,,,,,,, -22800,0.11853339,0.26424542,,,,,,,,,,,,,, -22900,0.18694532,0.24335416,,,,,,,,,,,,,, -22970,,,0.744358880179269,0.259899514062064,0.7203458443611072,0.2882900978650816,3554.0,0.7376963624205878,0.2897626061374267,3581.0,5421.932682275772,5894.17240691185,5421.932682275772,469.65551710128784,1.8019211292266848,0.0 -23000,0.20395152,0.30222404,,,,,,,,,,,,,, -23100,0.22693916,0.26581252,,,,,,,,,,,,,, -23200,0.07114672,0.3046645,,,,,,,,,,,,,, -23300,0.31389254,0.24166712,,,,,,,,,,,,,, -23319,,,0.7454833303179059,0.2591825383050101,0.7220026208365574,0.2873254882262943,3554.0,0.7392803107983106,0.2887416947190903,3581.0,5501.911800146103,5978.216735839844,5501.911800146103,473.6802542209625,1.8305857181549072,0.0 -23400,0.09830573,0.33142045,,,,,,,,,,,,,, -23500,0.09794275,0.21189556,,,,,,,,,,,,,, -23600,0.266449,0.23322108,,,,,,,,,,,,,, -23668,,,0.7462118012564523,0.259478007044111,0.7225752590610931,0.2877843853591464,3554.0,0.7398573580616448,0.2891561065563041,3581.0,5581.881755113602,6062.247819662094,5581.881755113602,477.7016234397888,1.8585996627807613,0.0 -23700,0.06513558,0.31694165,,,,,,,,,,,,,, -23800,0.18554914,0.2417578,,,,,,,,,,,,,, -23900,0.20505919,0.24930635,,,,,,,,,,,,,, -24000,0.11644036,0.28887743,,,,,,,,,,,,,, -24011,,,0.7459972926548549,0.2587435586111886,0.7221812268087014,0.2870924933009021,3554.0,0.7396680996491902,0.2884309795840896,3581.0,5661.870182514191,6146.304933309555,5661.870182514191,481.7267339229584,1.8905537128448489,0.0 -24100,0.11729857,0.23820853,,,,,,,,,,,,,, -24200,0.38359654,0.31375778,,,,,,,,,,,,,, -24300,0.17217064,0.2693602,,,,,,,,,,,,,, -24358,,,0.7446889877319336,0.2594368287495204,0.7210218679656725,0.2875983431960467,3554.0,0.7385344581209857,0.2889409069241308,3581.0,5741.839620113373,6230.340355634689,5741.839620113373,485.7525725364685,1.9191038608551023,0.0 -24400,0.22106187,0.2170951,,,,,,,,,,,,,, -24500,0.18665938,0.3887807,,,,,,,,,,,,,, -24600,0.23035601,0.24542615,,,,,,,,,,,,,, -24700,0.1284604,0.28900394,,,,,,,,,,,,,, -24706,,,0.7457452501569476,0.2587285212108067,0.7223903331853193,0.2868627957165341,3554.0,0.7397184822020735,0.2882839225251326,3581.0,5821.816328525543,6314.383749961853,5821.816328525543,489.77895855903625,1.9476697444915767,0.0 -24800,0.24757877,0.2859557,,,,,,,,,,,,,, -24900,0.268689,0.2616569,,,,,,,,,,,,,, -25000,0.16480231,0.2740733,,,,,,,,,,,,,, -25050,,,0.746821403503418,0.258384244782584,0.723183961953081,0.2866580342736617,3554.0,0.7405085815327422,0.2880520196108803,3581.0,5901.898707389832,6398.527945518494,5901.898707389832,493.80056977272034,1.9763381481170648,0.0 -25100,0.23361818,0.28525916,,,,,,,,,,,,,, -25200,0.18958725,0.30691138,,,,,,,,,,,,,, -25300,0.21084008,0.30590633,,,,,,,,,,,,,, -25398,,,0.7470160893031529,0.2592196294239589,0.7234911642251688,0.2875559414513136,3554.0,0.740772493389591,0.2889792903845818,3581.0,5981.937157869339,6482.623092412949,5981.937157869339,497.8165833950043,2.005166530609131,0.0 -25400,0.06304659,0.30361193,,,,,,,,,,,,,, -25500,0.23549469,0.28364024,,,,,,,,,,,,,, -25600,0.20321245,0.3164493,,,,,,,,,,,,,, -25700,0.08452516,0.31934035,,,,,,,,,,,,,, -25747,,,0.7450372150966099,0.2595665114266531,0.7211857045978123,0.2879249001730005,3554.0,0.7385599561924043,0.2892967209207449,3581.0,6062.071349143982,6566.819430589676,6062.071349143982,501.8346812725067,2.0372416973114014,0.0 -25800,0.09260072,0.27069786,,,,,,,,,,,,,, -25900,0.1486313,0.22249642,,,,,,,,,,,,,, -26000,0.25713328,0.3102435,,,,,,,,,,,,,, -26092,,,0.7474192210606166,0.2586601461683001,0.723627385626231,0.2871058887488129,3554.0,0.7407934918013473,0.2886063640459194,3581.0,6142.251312255859,6651.06591629982,6142.251312255859,505.8603518009186,2.0665762424468994,0.0 -26100,0.07784476,0.34500548,,,,,,,,,,,,,, -26200,0.116965815,0.30362588,,,,,,,,,,,,,, -26300,0.11074896,0.2789892,,,,,,,,,,,,,, -26400,0.09308785,0.24542326,,,,,,,,,,,,,, -26443,,,0.7478544371468681,0.2583139113017491,0.724158875782569,0.286810553469682,3554.0,0.7414119222982407,0.2882083827841385,3581.0,6222.414223909378,6735.290981769562,6222.414223909378,509.8812828063965,2.0960710048675537,0.0 -26443,,,,,,,,,,,6222.414223909378,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index d8ebd30b3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -36.84709548950195,0.0,56.31122303009033,1,0,56.31122303009033,0.0012000000569969,6.910790920257568,10000,93.1584370136261,0.0009964923374354,6.9111857414245605,0.0010199999669566,6.91091251373291,50000 -54.637348890304565,0.0286989212036132,566.3269417285919,1491,0,566.3269417285919,0.0431000031530857,5.679708957672119,10000,621.0424935817719,0.0701729878783226,5.387164115905762,0.065420001745224,5.45823860168457,50000 -72.63711643218994,0.0562620162963867,1076.2741117477417,2979,0,1076.2741117477417,0.1184000074863433,4.798498630523682,10000,1149.0670676231384,0.1756218075752258,4.2726922035217285,0.157379999756813,4.3756303787231445,50000 -90.69883179664612,0.0825309753417968,1586.424509525299,4468,0,1586.424509525299,0.1955000162124633,4.1657586097717285,10000,1677.3552947044373,0.2839604616165161,3.509944200515747,0.2582399845123291,3.6567907333374015,50000 -108.77395868301392,0.1093926429748535,2096.635529756546,5958,0,2096.635529756546,0.2608000040054321,3.706625461578369,10000,2205.71702504158,0.3703563511371612,2.9588210582733154,0.3423799872398376,3.106684684753418,50000 -126.81416702270508,0.1372864246368408,2606.764097929001,7447,0,2606.764097929001,0.314300000667572,3.368595123291016,10000,2733.9635617733,0.437519907951355,2.5694799423217773,0.407260000705719,2.7285168170928955,50000 -144.91092801094055,0.1650645732879638,3116.7672028541565,8937,0,3116.7672028541565,0.3574000298976898,3.0720314979553223,10000,3262.138736724853,0.4955755770206451,2.2581822872161865,0.4651599824428558,2.403724908828736,50000 -164.654226064682,0.1945652961730957,3626.985067844391,10428,0,3626.985067844391,0.3729000091552734,2.977101802825928,10000,3792.1785800457,0.5627790093421936,1.9515559673309328,0.4867599904537201,2.320361614227295,50000 -187.7570457458496,0.2273647785186767,4137.04333615303,11920,0,4137.04333615303,0.4123000204563141,2.767740249633789,10000,4325.421687364578,0.5782246589660645,1.846057415008545,0.52947998046875,2.106097936630249,50000 -209.34475588798523,0.2544946670532226,4647.025430679321,13413,0,4647.025430679321,0.4132000207901001,2.7203760147094727,10000,4857.068476676941,0.5894252061843872,1.8045969009399407,0.5416799783706665,2.0362939834594727,50000 -231.6502606868744,0.2813937664031982,5157.151046037674,14907,0,5157.151046037674,0.428600013256073,2.705935001373291,10000,5389.574882030487,0.5977359414100647,1.8043289184570312,0.5502399802207947,2.0121490955352783,50000 -252.6213595867157,0.3276073932647705,5667.149292945862,16400,0,5667.149292945862,0.4308000206947326,2.7112877368927,10000,5920.640057086945,0.5902822017669678,1.8064643144607544,0.5517199635505676,1.990447163581848,50000 -276.202490568161,0.3575654029846191,6177.358016967773,17895,0,6177.358016967773,0.4360000193119049,2.634429931640625,10000,6454.508526086807,0.5936902165412903,1.7567952871322632,0.5552600026130676,1.939975380897522,50000 -298.885041475296,0.3975296020507812,6687.38344168663,19389,0,6687.38344168663,0.4310000240802765,2.670055866241455,10000,6987.305696725845,0.6022600531578064,1.7681397199630735,0.5593400001525879,1.9684282541275024,50000 -322.6722848415375,0.4260029792785644,7197.566004276276,20884,0,7197.566004276276,0.4415000081062317,2.5937812328338623,10000,7521.352801322937,0.6381337642669678,1.587932586669922,0.5699599981307983,1.8987358808517456,50000 -345.7115654945373,0.4558007717132568,7707.807578086853,22379,0,7707.807578086853,0.4531000256538391,2.565328359603882,10000,8054.712727546692,0.6238042116165161,1.6296091079711914,0.5750600099563599,1.864007472991944,50000 -367.9157955646515,0.4832274913787842,8217.97511434555,23874,0,8217.97511434555,0.4646000266075134,2.491010904312134,10000,8587.158815145493,0.6375358700752258,1.5796483755111694,0.586899995803833,1.815063118934632,50000 -390.196186542511,0.512209415435791,8727.92223238945,25369,0,8727.92223238945,0.4585000276565552,2.4793083667755127,10000,9119.464904785156,0.6294642686843872,1.597659945487976,0.584119975566864,1.8133976459503167,50000 -411.78082633018494,0.5465049743652344,9237.966356039047,26864,0,9237.966356039047,0.464100033044815,2.510469913482666,10000,9651.175999879835,0.6300422549247742,1.63661527633667,0.5836600065231323,1.8408207893371584,50000 -434.0755660533905,1.5038611888885498,9747.298719406128,28358,0,9747.298719406128,0.4733000099658966,2.4336440563201904,10000,10183.808728456495,0.6334103941917419,1.5800617933273315,0.5934999585151672,1.7827062606811523,50000 -458.8446247577667,1.535851240158081,10257.261818170547,29854,0,10257.261818170547,0.4643000364303589,2.4862213134765625,10000,10718.622612953186,0.6729910373687744,1.422000288963318,0.5934999585151672,1.7916085720062256,50000 -483.3538534641266,1.5658612251281738,10767.481489896774,31351,0,10767.481489896774,0.4695000350475311,2.46608304977417,10000,11253.432497739792,0.6510881781578064,1.5065069198608398,0.5972599983215332,1.757551193237305,50000 -506.2302823066712,1.5994904041290283,11277.71798467636,32848,0,11277.71798467636,0.4761000275611877,2.4272708892822266,10000,11786.628136873243,0.6462252736091614,1.5267188549041748,0.6002399921417236,1.758679986000061,50000 -528.2995517253876,1.6302180290222168,11787.83618426323,34344,0,11787.83618426323,0.4770000278949737,2.4422802925109863,10000,12318.89468216896,0.6496133208274841,1.536085605621338,0.6037399768829346,1.748511791229248,50000 -551.7431771755219,1.6666617393493652,12297.890921831133,35840,0,12297.890921831133,0.4759000241756439,2.4189414978027344,10000,12852.47915005684,0.645926296710968,1.52766752243042,0.596560001373291,1.7534399032592771,50000 -573.2801125049591,1.6983842849731443,12807.821355819702,37336,0,12807.821355819702,0.4829000234603882,2.418141841888428,10000,13384.028679132462,0.6403061151504517,1.5559779405593872,0.6036199927330017,1.746623992919922,50000 -596.6457276344299,1.734304666519165,13317.77873301506,38832,0,13317.77873301506,0.4917000234127044,2.3655688762664795,10000,13917.43752002716,0.6729711294174194,1.439097881317139,0.6108199954032898,1.7267426252365112,50000 -617.3422248363495,1.76465106010437,13827.9108877182,40329,0,13827.9108877182,0.4941000342369079,2.351206064224243,10000,14448.347055912018,0.6744459271430969,1.4112752676010132,0.610260009765625,1.698612928390503,50000 -637.8876869678497,1.7998626232147217,14337.827102661133,41825,0,14337.827102661133,0.4783000349998474,2.419309377670288,10000,14978.894652605057,0.6544762253761292,1.509508728981018,0.6057199835777283,1.743627667427063,50000 -657.9168131351471,1.832062005996704,14847.95288991928,43322,0,14847.95288991928,0.4749000370502472,2.452561378479004,10000,15509.13153076172,0.6491350531578064,1.5379236936569214,0.6043199896812439,1.748526930809021,50000 -676.1758737564087,1.8695974349975584,15358.028193473816,44819,0,15358.028193473816,0.4844000339508056,2.402400016784668,10000,16037.553413391111,0.6519451141357422,1.5086034536361694,0.6087799668312073,1.7206631898880005,50000 -693.9103631973267,1.903987169265747,15867.962439060211,46315,0,15867.962439060211,0.4705000221729278,2.4317517280578613,10000,16565.30499601364,0.6457070708274841,1.543419361114502,0.6055799722671509,1.7364466190338137,50000 -712.2113356590271,1.949737787246704,16378.023522853851,47812,0,16378.023522853851,0.4796000123023987,2.4064650535583496,10000,17093.762252807617,0.6532405614852905,1.498455286026001,0.6109600067138672,1.7112398147583008,50000 -729.52081990242,1.992770195007324,16887.94068956375,49308,0,16887.94068956375,0.4860000312328338,2.362870216369629,10000,17621.080164194107,0.6916453838348389,1.3557281494140625,0.6181399822235107,1.6922627687454224,50000 -746.666835308075,2.0300121307373047,17398.12221980095,50807,0,17398.12221980095,0.4921000301837921,2.351040840148926,10000,18148.496007680893,0.671297013759613,1.4186804294586182,0.6159999966621399,1.6773052215576172,50000 -763.9545395374298,2.067458152770996,17908.237648248672,52304,0,17908.237648248672,0.4946000277996063,2.344132661819458,10000,18675.98785209656,0.6713567972183228,1.4391323328018188,0.6169399619102478,1.6796952486038208,50000 -781.1611652374268,2.107489109039306,18418.46981525421,53801,0,18418.46981525421,0.4966000318527221,2.325680494308472,10000,19203.51691007614,0.6638432741165161,1.4525445699691772,0.6154999732971191,1.6743123531341553,50000 -798.3581395149231,2.149481773376465,18928.68518781662,55299,0,18928.68518781662,0.4979000091552734,2.3356235027313232,10000,19731.023154973984,0.6633649468421936,1.4799798727035522,0.6179400086402893,1.6885169744491575,50000 -815.757345199585,2.195958137512207,19438.76026201248,56797,0,19438.76026201248,0.508400022983551,2.2914958000183105,10000,20258.59286975861,0.6719148755073547,1.425616979598999,0.6231799721717834,1.6395264863967896,50000 -833.042439699173,2.232260465621948,19948.85229182244,58294,0,19948.85229182244,0.5078000426292419,2.266110897064209,10000,20786.05773258209,0.7206233739852905,1.2154070138931274,0.6320399641990662,1.6034917831420898,50000 -850.6091375350952,2.270244836807251,20458.78334474564,59792,0,20458.78334474564,0.4915000200271606,2.3457841873168945,10000,21313.644002199173,0.6831552982330322,1.3804254531860352,0.6195999979972839,1.6663074493408203,50000 -867.795756816864,2.307173728942871,20968.790618896484,61289,0,20968.790618896484,0.5031999945640564,2.3163297176361084,10000,21840.92485141754,0.6767179369926453,1.392195224761963,0.6308199763298035,1.6244053840637207,50000 -884.985392332077,2.344024419784546,21478.93054676056,62787,0,21478.93054676056,0.5047000050544739,2.289322853088379,10000,22368.341687202454,0.682637095451355,1.37348473072052,0.6302399635314941,1.626953125,50000 -902.2461650371552,2.3862104415893555,21988.93044018745,64284,0,21988.93044018745,0.5065000057220459,2.2937440872192383,10000,22895.69519519806,0.6745256781578064,1.4175406694412231,0.6255800127983093,1.6421574354171753,50000 -919.5094072818756,2.438976287841797,22499.075800418854,65782,0,22499.075800418854,0.504800021648407,2.2867746353149414,10000,23423.206660985947,0.6729512214660645,1.4147884845733645,0.6275599598884583,1.630405068397522,50000 -936.7078473567964,2.4785714149475098,23009.149444818497,67279,0,23009.149444818497,0.5004000067710876,2.3088037967681885,10000,23950.570457935333,0.6883569955825806,1.3431607484817505,0.6266799569129944,1.618585467338562,50000 -953.9149236679076,2.5172877311706543,23519.17913007736,68777,0,23519.17913007736,0.5062000155448914,2.276012897491455,10000,24477.8970515728,0.7035036683082581,1.274147391319275,0.6340799927711487,1.591722011566162,50000 -971.5965132713318,2.5620453357696533,24029.25793027877,70275,0,24029.25793027877,0.5059000253677368,2.2884953022003174,10000,25005.75264811516,0.6885961294174194,1.343567132949829,0.6320799589157104,1.6132029294967651,50000 -988.8352327346802,2.6068990230560303,24539.311646461487,71773,0,24539.311646461487,0.5006000399589539,2.317592144012451,10000,25533.141258955,0.6815010905265808,1.3888893127441406,0.626259982585907,1.645130634307861,50000 -1006.3702020645142,2.651460647583008,25049.316918611526,73271,0,25049.316918611526,0.5053000450134277,2.248891592025757,10000,26060.77575063705,0.6846500039100647,1.3535865545272827,0.6345799565315247,1.5897456407546997,50000 -1024.0254747867584,2.6931560039520264,25559.289101839066,74769,0,25559.289101839066,0.5216000080108643,2.2021830081939697,10000,26588.49595093727,0.6935586333274841,1.3076993227005005,0.6421399712562561,1.5479406118392944,50000 -1041.3056933879852,2.7349095344543457,26069.32606625557,76267,0,26069.32606625557,0.5103999972343445,2.274660348892212,10000,27115.904922246933,0.6836535334587097,1.3648991584777832,0.6360799670219421,1.5922093391418457,50000 -1058.6034083366394,2.780672311782837,26579.532474040985,77765,0,26579.532474040985,0.5115000009536743,2.2503161430358887,10000,27643.506959199905,0.7120934128761292,1.2422109842300415,0.638700008392334,1.562541961669922,50000 -1075.7997291088104,2.820523977279663,27089.686541080475,79263,0,27089.686541080475,0.5182999968528748,2.221109390258789,10000,28170.947011709213,0.7058752775192261,1.28544819355011,0.6415599584579468,1.569705605506897,50000 -1092.9376814365387,2.8669273853302,27599.76722025872,80762,0,27599.76722025872,0.5182999968528748,2.2196383476257324,10000,28698.262020349503,0.7010523080825806,1.3063050508499146,0.648140013217926,1.5532654523849487,50000 -1110.088816165924,2.9066665172576904,28109.76361966133,82260,0,28109.76361966133,0.5174000263214111,2.2001779079437256,10000,29225.50002479553,0.6965281963348389,1.2849880456924438,0.64656001329422,1.5245449542999268,50000 -1127.1571650505066,2.946876287460327,28619.767454862595,83758,0,28619.767454862595,0.5253000259399414,2.1855990886688232,10000,29752.6642165184,0.7009526491165161,1.2806954383850098,0.6505799889564514,1.5104814767837524,50000 -1144.5813839435575,2.9912562370300293,29129.908967256542,85257,0,29129.908967256542,0.5213000178337097,2.2066726684570312,10000,30280.324359178543,0.7001753449440002,1.2842198610305786,0.6523399949073792,1.5107489824295044,50000 -1161.4845032691956,3.0346953868865967,29639.830031633377,86755,0,29639.830031633377,0.5329000353813171,2.135961532592773,10000,30807.24199461937,0.7459940910339355,1.084036946296692,0.6534799933433533,1.4911538362503052,50000 -1178.5891389846802,3.0897841453552246,30150.005770921707,88253,0,30150.005770921707,0.5353000164031982,2.1265015602111816,10000,31334.62922167778,0.7254464030265808,1.1623930931091309,0.6578199863433838,1.4649604558944702,50000 -1195.8721101284027,3.136313915252685,30660.09109258652,89751,0,30660.09109258652,0.5382000207901001,2.107398748397827,10000,31862.09580183029,0.7215999364852905,1.1721854209899902,0.6617000102996826,1.4494667053222656,50000 -1213.0246744155884,3.181325912475586,31170.067274332047,91249,0,31170.067274332047,0.5379000306129456,2.1280786991119385,10000,32389.319952964783,0.7166972160339355,1.2218282222747805,0.6592199802398682,1.4793634414672852,50000 -1230.3129467964172,3.225436210632324,31679.97949957848,92747,0,31679.97949957848,0.5285000205039978,2.168752908706665,10000,32916.61490535736,0.7115951776504517,1.2422090768814087,0.6555399894714355,1.498248815536499,50000 -1247.537139415741,3.262047529220581,32189.99800372124,94245,0,32189.99800372124,0.5364000201225281,2.136128187179565,10000,33443.94494795799,0.7129902839660645,1.244142770767212,0.6612399816513062,1.4770818948745728,50000 -1264.8968999385834,3.308539628982544,32699.949419021606,95743,0,32699.949419021606,0.5428000092506409,2.1002559661865234,10000,33971.354709625244,0.742586076259613,1.1307847499847412,0.669219970703125,1.4510703086853027,50000 -1281.9905200004578,3.355043172836304,33210.114094257355,97241,0,33210.114094257355,0.5376999974250793,2.125854253768921,10000,34498.70940423012,0.7416892647743225,1.120581030845642,0.6673399806022644,1.463154911994934,50000 -1299.1062581539154,3.401531219482422,33720.06629896164,98739,0,33720.06629896164,0.5424000024795532,2.1051735877990723,10000,35025.87611365318,0.7276586294174194,1.1531447172164917,0.666920006275177,1.4372801780700684,50000 -1316.049479007721,3.4516913890838623,34229.99035668373,100237,0,34229.99035668373,0.5361000299453735,2.1154842376708984,10000,35552.84500479698,0.7302096486091614,1.1675571203231812,0.6653199791908264,1.4513657093048096,50000 -1333.1427223682404,3.4981353282928467,34739.97491002083,101735,0,34739.97491002083,0.5398000478744507,2.113425970077514,10000,36080.01940727234,0.7242107391357422,1.1790367364883425,0.6686800122261047,1.4343260526657104,50000 -1350.248485803604,3.549124956130981,35250.05642032623,103233,0,35250.05642032623,0.5520000457763672,2.0522189140319824,10000,36607.30830931664,0.7329798936843872,1.1520893573760986,0.676859974861145,1.4166473150253296,50000 -1367.418380498886,3.5954058170318604,35759.99800825119,104731,0,35759.99800825119,0.5373000502586365,2.095768928527832,10000,37134.51687049866,0.7315847873687744,1.141116499900818,0.6702199578285217,1.417891502380371,50000 -1384.609278678894,3.6411328315734863,36270.085739851,106229,0,36270.085739851,0.5469000339508057,2.085392951965332,10000,37661.89115190506,0.7540457248687744,1.072227954864502,0.6693599820137024,1.442069172859192,50000 -1401.521045923233,3.6874375343322754,36780.10208392143,107727,0,36780.10208392143,0.5508000254631042,2.0545578002929688,10000,38188.91479277611,0.751973032951355,1.068299651145935,0.6773999929428101,1.4052557945251465,50000 -1418.6850700378418,3.740187406539917,37290.17399263382,109225,0,37290.17399263382,0.5574000477790833,2.0311355590820312,10000,38716.25345778465,0.7469507455825806,1.086888074874878,0.6785999536514282,1.395654797554016,50000 -1435.787403345108,3.788019895553589,37800.17114710808,110723,0,37800.17114710808,0.5574000477790833,2.017634391784668,10000,39243.452078580856,0.7480069994926453,1.080083966255188,0.6832799911499023,1.3639898300170898,50000 -1452.9339735507965,3.8331050872802734,38310.2585234642,112221,0,38310.2585234642,0.5624000430107117,2.006924152374268,10000,39770.78035974503,0.7472297549247742,1.0725810527801514,0.6823199987411499,1.363050103187561,50000 -1470.0948634147644,3.881237268447876,38820.35129117966,113719,0,38820.35129117966,0.5601000189781189,2.016330003738404,10000,40298.13475847244,0.7479073405265808,1.0708001852035522,0.6868199706077576,1.3397966623306274,50000 -1487.3175942897797,3.92743730545044,39330.478276491165,115217,0,39330.478276491165,0.5639000535011292,1.989737153053284,10000,40825.58081579208,0.7872887253761292,0.9008607268333436,0.6908800005912781,1.3157756328582764,50000 -1504.6061329841614,3.976889371871948,39840.57607722282,116716,0,39840.57607722282,0.5701000094413757,1.9820610284805296,10000,41353.06807017326,0.7737165093421936,0.961524486541748,0.6925399899482727,1.3126429319381714,50000 -1521.836299419403,4.030749559402466,40350.64465737343,118214,0,40350.64465737343,0.5715000033378601,1.9586877822875977,10000,41880.47035765648,0.7669204473495483,1.010676383972168,0.6954599618911743,1.3212053775787354,50000 -1539.2372624874115,4.080804824829102,40860.65657663345,119712,0,40860.65657663345,0.5733000040054321,1.962609052658081,10000,42407.98349690437,0.7684550285339355,1.0029422044754028,0.6985799670219421,1.310947299003601,50000 -1556.559273481369,4.127295732498169,41370.574806690216,121210,0,41370.574806690216,0.5770000219345093,1.9499543905258176,10000,42935.32121896744,0.7650868892669678,0.9947850704193116,0.6988399624824524,1.2936506271362305,50000 -1573.6708698272705,4.196326732635498,41880.55111408234,122708,0,41880.55111408234,0.5733000040054321,1.966423749923706,10000,43462.52871155739,0.7655652165412903,1.012738823890686,0.6997199654579163,1.3031195402145386,50000 -1590.7367997169497,4.2473227977752686,42390.75530362129,124206,0,42390.75530362129,0.5795000195503235,1.902595043182373,10000,43989.89999890328,0.7978515625,0.8683147430419922,0.7073799967765808,1.261612057685852,50000 -1607.9538469314575,4.293523788452148,42900.89259982109,125705,0,42900.89259982109,0.5808000564575195,1.9092512130737305,10000,44517.35197234154,0.7971938848495483,0.8802825808525085,0.7064599990844727,1.266579031944275,50000 -1625.3317823410034,4.347227096557617,43410.79142928124,127202,0,43410.79142928124,0.5782000422477722,1.8972203731536863,10000,45044.73135638237,0.7934072017669678,0.8894890546798706,0.7084199786186218,1.2445156574249268,50000 -1642.3397538661957,4.400299549102783,43920.76516246796,128699,0,43920.76516246796,0.5886000394821167,1.8822438716888428,10000,45571.81634020805,0.7883649468421936,0.9085213541984558,0.7102800011634827,1.2549796104431152,50000 -1659.612549781799,4.4501917362213135,44430.832174539566,130197,0,44430.832174539566,0.5909000039100647,1.8738075494766235,10000,46099.25558972359,0.7904177308082581,0.8950925469398499,0.7142800092697144,1.2327156066894531,50000 -1676.6434774398804,4.499432563781738,44940.92831659317,131695,0,44940.92831659317,0.5855000019073486,1.8798346519470213,10000,46626.481568574905,0.7917131781578064,0.8998433351516724,0.7148999571800232,1.22651207447052,50000 -1693.8448660373688,4.552513122558594,45450.86598086357,133193,0,45450.86598086357,0.5919000506401062,1.845463752746582,10000,47153.72575092316,0.7970942258834839,0.8763123750686646,0.7187199592590332,1.2121546268463137,50000 -1711.0298657417295,4.6041131019592285,45961.00407743454,134692,0,45961.00407743454,0.5961000323295593,1.841491937637329,10000,47681.152082681656,0.8205516338348389,0.7716889381408691,0.7160199880599976,1.2104519605636597,50000 -1728.2518351078031,4.656829357147217,46470.92022943497,136190,0,46470.92022943497,0.5961000323295593,1.8225889205932613,10000,48208.3938832283,0.8130978941917419,0.8022621870040894,0.7203399538993835,1.197677493095398,50000 -1745.192828655243,4.712372779846191,46980.9376642704,137689,0,46980.9376642704,0.5981000065803528,1.852874755859375,10000,48735.45760130882,0.810566782951355,0.8220297694206238,0.722819983959198,1.207771062850952,50000 -1762.3447728157043,4.771524906158447,47490.93303322792,139187,0,47490.93303322792,0.5973000526428223,1.825844049453736,10000,49262.71445965767,0.8121013641357422,0.8041132688522339,0.7247599959373474,1.185956954956055,50000 -1779.533354997635,4.824806928634644,48000.869512319565,140684,0,48000.869512319565,0.6026000380516052,1.7988563776016235,10000,49789.94197225571,0.8148915767669678,0.7869499921798706,0.7278199791908264,1.1672886610031128,50000 -1797.0130088329315,4.874320030212402,48510.98008084297,142182,0,48510.98008084297,0.6078000068664551,1.7888727188110352,10000,50317.63361525536,0.8167450428009033,0.7779141664505005,0.7300800085067749,1.1511822938919067,50000 -1814.199674367905,4.921878337860107,49021.22454404831,143681,0,49021.22454404831,0.6038000583648682,1.7933541536331177,10000,50845.16202926636,0.8469786047935486,0.6639379262924194,0.7300399541854858,1.1502034664154053,50000 -1831.4566979408264,4.974995851516724,49531.14088320732,145179,0,49531.14088320732,0.6086000204086304,1.7797400951385498,10000,51372.43950200081,0.8422552347183228,0.703779399394989,0.7349399924278259,1.1473641395568848,50000 -1848.4796833992004,5.031417369842529,50041.08595466614,146677,0,50041.08595466614,0.6091000437736511,1.7796084880828855,10000,51899.51325583458,0.8409199714660645,0.7093674540519714,0.7373799681663513,1.132510542869568,50000 -1865.7171568870544,5.085627317428589,50551.017679452896,148175,0,50551.017679452896,0.6115000247955322,1.7595330476760864,10000,52426.786516428,0.8373923897743225,0.7035523056983948,0.737280011177063,1.1355136632919312,50000 -1882.9315605163567,5.151835203170776,51060.90836381912,149673,0,51060.90836381912,0.6195000410079956,1.7411073446273804,10000,52954.00826811791,0.8453842401504517,0.686999499797821,0.7429400086402893,1.107838749885559,50000 -1900.640768289566,5.20801043510437,51571.01009392738,151171,0,51571.01009392738,0.6212000250816345,1.727745532989502,10000,53481.925209999084,0.8451052308082581,0.675420880317688,0.7450799942016602,1.1006088256835938,50000 -1917.6977479457853,5.260669708251953,52081.04586791992,152669,0,52081.04586791992,0.6201000213623047,1.735266089439392,10000,54009.12109827995,0.8661710619926453,0.6043965220451355,0.7461599707603455,1.0976543426513672,50000 -1934.90061545372,5.321510553359985,52591.11575245857,154167,0,52591.11575245857,0.6278000473976135,1.6983375549316406,10000,54536.50642514229,0.8671875,0.5794281959533691,0.7481399774551392,1.0786703824996948,50000 -1952.2021520137787,5.373865604400635,53101.14227557182,155665,0,53101.14227557182,0.6246000528335571,1.6996747255325315,10000,55063.9355969429,0.8687419891357422,0.5882298946380615,0.7505999803543091,1.0769861936569214,50000 -1969.1301229000087,5.426816463470459,53611.06381011009,157162,0,53611.06381011009,0.6239000558853149,1.6967737674713137,10000,55590.8880238533,0.8694794178009033,0.580014169216156,0.7535199522972107,1.0628749132156372,50000 -1986.1422533988955,5.484708070755005,54121.251658678055,158660,0,54121.251658678055,0.6336000561714172,1.6853803396224976,10000,56118.19445109368,0.8698182106018066,0.5817106366157532,0.7541399598121643,1.064910888671875,50000 -2003.3706483840945,5.54168963432312,54631.442410707474,160158,0,54631.442410707474,0.6309000253677368,1.6826121807098389,10000,56645.72098970413,0.875418484210968,0.5687929391860962,0.756879985332489,1.0507851839065552,50000 -2020.6487910747528,5.598784685134888,55141.38109588623,161656,0,55141.38109588623,0.6301000118255615,1.673742413520813,10000,57173.04496335983,0.8792450428009033,0.5455688238143921,0.7571799755096436,1.0448676347732544,50000 -2037.681653022766,5.65578556060791,55651.28565096855,163154,0,55651.28565096855,0.6349000334739685,1.6725289821624756,10000,57700.09034585953,0.8960259556770325,0.491013616323471,0.7576000094413757,1.0471340417861938,50000 -2054.770972251892,5.715636491775513,56161.19146823883,164652,0,56161.19146823883,0.6367000341415405,1.661550521850586,10000,58227.19503569603,0.8957070708274841,0.4908312857151031,0.7618599534034729,1.0314052104949951,50000 -2071.8406381607056,5.775366306304932,56671.27508664131,166150,0,56671.27508664131,0.6398000121116638,1.6573363542556765,10000,58754.45888543129,0.8951490521430969,0.4891964495182037,0.7637199759483337,1.026726245880127,50000 -2088.97993516922,5.834702968597412,57181.31681752205,167648,0,57181.31681752205,0.6388000249862671,1.6397991180419922,10000,59281.74941539765,0.8947305083274841,0.4829561114311218,0.7655199766159058,1.0176388025283811,50000 -2106.1391632556915,5.890669107437134,57691.400824546814,169146,0,57691.400824546814,0.6415000557899475,1.655003786087036,10000,59809.09949350357,0.9001514315605164,0.4777239859104156,0.7644000053405762,1.024844527244568,50000 -2123.243331670761,5.947777032852173,58201.40622735024,170644,0,58201.40622735024,0.6467000246047974,1.6335299015045166,10000,60336.31550955773,0.9061303734779358,0.4618020951747894,0.7671999931335449,1.0098974704742432,50000 -2140.4393548965454,6.021910905838013,58711.40957093239,172142,0,58711.40957093239,0.6457000374794006,1.6340036392211914,10000,60863.64063549042,0.9163145422935486,0.4251190423965454,0.7689799666404724,1.0125985145568848,50000 -2157.631984949112,6.077967643737793,59221.40473794937,173640,0,59221.40473794937,0.6492000222206116,1.6199817657470703,10000,61390.93344545365,0.9162746667861938,0.415965586900711,0.7691599726676941,1.0013891458511353,50000 -2174.5137605667114,6.134857654571533,59731.5888364315,175138,0,59731.5888364315,0.6484000086784363,1.6202576160430908,10000,61918.1064593792,0.9125677347183228,0.4280580282211303,0.7706199884414673,1.000607967376709,50000 -2191.785396575928,6.190088748931885,60241.7090446949,176636,0,60241.7090446949,0.6476000547409058,1.617110252380371,10000,62445.60426354408,0.916214883327484,0.4126388728618622,0.7721799612045288,0.9947059154510498,50000 -2208.687886953354,6.247087478637695,60751.60144472122,178133,0,60751.60144472122,0.6497000455856323,1.6151281595230105,10000,62972.50779604912,0.9172711968421936,0.4106919467449188,0.7718799710273743,0.9955499172210692,50000 -2225.9638142585754,6.308993816375732,61261.76955938339,179631,0,61261.76955938339,0.65010005235672,1.6078531742095947,10000,63500.064401865005,0.918965220451355,0.4076812863349914,0.7737799882888794,0.989608645439148,50000 -2243.2266731262207,6.367881774902344,61771.71801686287,181128,0,61771.71801686287,0.6499000191688538,1.609398603439331,10000,64027.38457632065,0.9202008843421936,0.4020408689975738,0.7730000019073486,0.992912530899048,50000 -2260.4189026355743,6.43097186088562,62281.93049740791,182626,0,62281.93049740791,0.6509000062942505,1.606026530265808,10000,64554.90435361862,0.9197225570678712,0.4042039811611175,0.7734400033950806,0.9904934167861938,50000 -2277.7590713500977,6.492360353469849,62791.98581242561,184124,0,62791.98581242561,0.6513000130653381,1.6081324815750122,10000,65082.41243624687,0.9210976958274841,0.3999505639076233,0.7736999988555908,0.9918732643127441,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index 871805a3e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1974 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.5895074,6.9289746,,,,,,,,,,,,,, -1,,,0.0009964923374354,6.9111857414245605,0.0010199999669566,6.91091251373291,50000.0,0.0012000000569969,6.910790920257568,10000.0,56.31122303009033,93.1584370136261,56.31122303009033,36.84709548950195,0.0,0.0 -100,0.5772568,6.899452,,,,,,,,,,,,,, -200,0.6113818,6.865883,,,,,,,,,,,,,, -300,0.62578416,6.7905064,,,,,,,,,,,,,, -400,0.67745554,6.702958,,,,,,,,,,,,,, -500,0.7048504,6.610506,,,,,,,,,,,,,, -600,0.71161443,6.5352387,,,,,,,,,,,,,, -700,0.75959605,6.4473276,,,,,,,,,,,,,, -800,0.837795,6.326313,,,,,,,,,,,,,, -900,1.3525826,6.246131,,,,,,,,,,,,,, -1000,1.2931962,6.1710443,,,,,,,,,,,,,, -1100,1.9297827,6.088641,,,,,,,,,,,,,, -1200,1.3950429,5.9900813,,,,,,,,,,,,,, -1300,1.7049736,5.9459696,,,,,,,,,,,,,, -1400,2.5060513,5.887013,,,,,,,,,,,,,, -1491,,,0.0701729878783226,5.387164115905762,0.065420001745224,5.45823860168457,50000.0,0.0431000031530857,5.679708957672119,10000.0,566.3269417285919,621.0424935817719,566.3269417285919,54.637348890304565,0.0286989212036132,0.0 -1500,2.3757594,5.8064814,,,,,,,,,,,,,, -1600,3.0248077,5.6401324,,,,,,,,,,,,,, -1700,2.7701116,5.6074715,,,,,,,,,,,,,, -1800,2.6698632,5.6147394,,,,,,,,,,,,,, -1900,3.0676565,5.4826274,,,,,,,,,,,,,, -2000,3.7299736,5.566607,,,,,,,,,,,,,, -2100,3.445632,5.4515476,,,,,,,,,,,,,, -2200,4.4386425,5.4381604,,,,,,,,,,,,,, -2300,4.120721,5.3542843,,,,,,,,,,,,,, -2400,4.999115,5.3305836,,,,,,,,,,,,,, -2500,4.6112485,5.2303023,,,,,,,,,,,,,, -2600,4.4293613,5.2050657,,,,,,,,,,,,,, -2700,4.9884267,5.166606,,,,,,,,,,,,,, -2800,3.5241106,5.087974,,,,,,,,,,,,,, -2900,3.9979315,5.1210103,,,,,,,,,,,,,, -2979,,,0.1756218075752258,4.2726922035217285,0.157379999756813,4.3756303787231445,50000.0,0.1184000074863433,4.798498630523682,10000.0,1076.2741117477417,1149.0670676231384,1076.2741117477417,72.63711643218994,0.0562620162963867,0.0 -3000,3.8317893,5.0123963,,,,,,,,,,,,,, -3100,4.362658,4.919945,,,,,,,,,,,,,, -3200,4.951066,4.9887514,,,,,,,,,,,,,, -3300,6.48279,4.8522587,,,,,,,,,,,,,, -3400,4.8726583,4.822034,,,,,,,,,,,,,, -3500,3.394552,4.8008423,,,,,,,,,,,,,, -3600,6.855064,4.811044,,,,,,,,,,,,,, -3700,4.4317327,4.686103,,,,,,,,,,,,,, -3800,4.8633823,4.66678,,,,,,,,,,,,,, -3900,5.929997,4.6586666,,,,,,,,,,,,,, -4000,3.9148645,4.730207,,,,,,,,,,,,,, -4100,6.7919655,4.562337,,,,,,,,,,,,,, -4200,3.2613764,4.599891,,,,,,,,,,,,,, -4300,4.483529,4.475386,,,,,,,,,,,,,, -4400,9.046588,4.375804,,,,,,,,,,,,,, -4468,,,0.2839604616165161,3.509944200515747,0.2582399845123291,3.6567907333374015,50000.0,0.1955000162124633,4.1657586097717285,10000.0,1586.424509525299,1677.3552947044373,1586.424509525299,90.69883179664612,0.0825309753417968,0.0 -4500,4.2707286,4.4645095,,,,,,,,,,,,,, -4600,8.197527,4.384858,,,,,,,,,,,,,, -4700,4.1373696,4.2663283,,,,,,,,,,,,,, -4800,4.442736,4.489892,,,,,,,,,,,,,, -4900,4.5336175,4.2477064,,,,,,,,,,,,,, -5000,4.6120095,4.2666106,,,,,,,,,,,,,, -5100,8.081155,4.142897,,,,,,,,,,,,,, -5200,6.059262,4.134467,,,,,,,,,,,,,, -5300,3.2196877,4.120341,,,,,,,,,,,,,, -5400,4.821855,4.091835,,,,,,,,,,,,,, -5500,5.791486,4.0673695,,,,,,,,,,,,,, -5600,10.132583,4.1694655,,,,,,,,,,,,,, -5700,7.987113,3.9863884,,,,,,,,,,,,,, -5800,5.1427655,4.0154376,,,,,,,,,,,,,, -5900,5.1984587,4.100184,,,,,,,,,,,,,, -5958,,,0.3703563511371612,2.9588210582733154,0.3423799872398376,3.106684684753418,50000.0,0.2608000040054321,3.706625461578369,10000.0,2096.635529756546,2205.71702504158,2096.635529756546,108.77395868301392,0.1093926429748535,0.0 -6000,6.5554724,4.005924,,,,,,,,,,,,,, -6100,3.6967092,3.9733691,,,,,,,,,,,,,, -6200,8.859809,3.940619,,,,,,,,,,,,,, -6300,5.9244857,3.874807,,,,,,,,,,,,,, -6400,6.5749717,3.950835,,,,,,,,,,,,,, -6500,4.7364063,3.8130717,,,,,,,,,,,,,, -6600,6.0530133,3.8931727,,,,,,,,,,,,,, -6700,8.60379,3.8447862,,,,,,,,,,,,,, -6800,8.308514,3.731717,,,,,,,,,,,,,, -6900,4.939874,3.773272,,,,,,,,,,,,,, -7000,7.1322355,3.7167072,,,,,,,,,,,,,, -7100,3.890442,3.7151241,,,,,,,,,,,,,, -7200,4.249107,3.7288947,,,,,,,,,,,,,, -7300,5.3571424,3.727903,,,,,,,,,,,,,, -7400,3.671941,3.626471,,,,,,,,,,,,,, -7447,,,0.437519907951355,2.5694799423217773,0.407260000705719,2.7285168170928955,50000.0,0.314300000667572,3.368595123291016,10000.0,2606.764097929001,2733.9635617733,2606.764097929001,126.81416702270508,0.1372864246368408,0.0 -7500,4.9154477,3.6810465,,,,,,,,,,,,,, -7600,7.3424683,3.7402809,,,,,,,,,,,,,, -7700,4.4830155,3.7330887,,,,,,,,,,,,,, -7800,6.1561313,3.5267584,,,,,,,,,,,,,, -7900,6.3159614,3.6843097,,,,,,,,,,,,,, -8000,4.924933,3.6138246,,,,,,,,,,,,,, -8100,4.100943,3.5411882,,,,,,,,,,,,,, -8200,6.226868,3.6080506,,,,,,,,,,,,,, -8300,7.3953023,3.5445573,,,,,,,,,,,,,, -8400,4.8652015,3.4610384,,,,,,,,,,,,,, -8500,5.6237884,3.5927982,,,,,,,,,,,,,, -8600,6.0635037,3.5028272,,,,,,,,,,,,,, -8700,4.7598996,3.3935895,,,,,,,,,,,,,, -8800,4.820744,3.5024118,,,,,,,,,,,,,, -8900,4.6408167,3.4457085,,,,,,,,,,,,,, -8937,,,0.4955755770206451,2.2581822872161865,0.4651599824428558,2.403724908828736,50000.0,0.3574000298976898,3.0720314979553223,10000.0,3116.7672028541565,3262.138736724853,3116.7672028541565,144.91092801094055,0.1650645732879638,0.0 -9000,4.502882,3.485939,,,,,,,,,,,,,, -9100,4.242973,3.408711,,,,,,,,,,,,,, -9200,5.031585,3.442112,,,,,,,,,,,,,, -9300,5.1467104,3.4010546,,,,,,,,,,,,,, -9400,4.777847,3.3029256,,,,,,,,,,,,,, -9500,5.2585664,3.442873,,,,,,,,,,,,,, -9600,6.7564287,3.372344,,,,,,,,,,,,,, -9700,3.6445184,3.3211005,,,,,,,,,,,,,, -9800,4.5707717,3.3833916,,,,,,,,,,,,,, -9900,4.6050463,3.389097,,,,,,,,,,,,,, -10000,6.500695,3.4162364,,,,,,,,,,,,,, -10100,3.5985122,3.3524704,,,,,,,,,,,,,, -10200,6.7607236,3.3207228,,,,,,,,,,,,,, -10300,5.937307,3.2602453,,,,,,,,,,,,,, -10400,5.0937777,3.33147,,,,,,,,,,,,,, -10428,,,0.5627790093421936,1.9515559673309328,0.4867599904537201,2.320361614227295,50000.0,0.3729000091552734,2.977101802825928,10000.0,3626.985067844391,3792.1785800457,3626.985067844391,164.654226064682,0.1945652961730957,0.0 -10500,4.1678233,3.2665474,,,,,,,,,,,,,, -10600,6.493261,3.215622,,,,,,,,,,,,,, -10700,8.524902,3.213326,,,,,,,,,,,,,, -10800,5.0095615,3.3054433,,,,,,,,,,,,,, -10900,4.390612,3.1586702,,,,,,,,,,,,,, -11000,4.7637024,3.1563666,,,,,,,,,,,,,, -11100,7.3321185,3.3299658,,,,,,,,,,,,,, -11200,5.129617,3.2933846,,,,,,,,,,,,,, -11300,6.0283,3.2185726,,,,,,,,,,,,,, -11400,5.9141874,3.1850371,,,,,,,,,,,,,, -11500,7.655546,3.2082512,,,,,,,,,,,,,, -11600,6.7086678,3.1953082,,,,,,,,,,,,,, -11700,5.5837893,3.141582,,,,,,,,,,,,,, -11800,5.712434,3.1835313,,,,,,,,,,,,,, -11900,6.001021,3.1482944,,,,,,,,,,,,,, -11920,,,0.5782246589660645,1.846057415008545,0.52947998046875,2.106097936630249,50000.0,0.4123000204563141,2.767740249633789,10000.0,4137.04333615303,4325.421687364578,4137.04333615303,187.7570457458496,0.2273647785186767,0.0 -12000,6.381893,3.1829438,,,,,,,,,,,,,, -12100,6.389663,3.2206392,,,,,,,,,,,,,, -12200,4.4535913,3.194776,,,,,,,,,,,,,, -12300,4.8685403,3.0761318,,,,,,,,,,,,,, -12400,4.91603,3.1708713,,,,,,,,,,,,,, -12500,5.5177355,3.181957,,,,,,,,,,,,,, -12600,3.6453395,3.232417,,,,,,,,,,,,,, -12700,3.6493561,3.109824,,,,,,,,,,,,,, -12800,5.706775,3.1393633,,,,,,,,,,,,,, -12900,7.5330925,3.1816456,,,,,,,,,,,,,, -13000,4.946209,3.107184,,,,,,,,,,,,,, -13100,5.3434772,3.164422,,,,,,,,,,,,,, -13200,4.136095,3.1215682,,,,,,,,,,,,,, -13300,2.5819101,3.0934265,,,,,,,,,,,,,, -13400,5.280014,3.2263427,,,,,,,,,,,,,, -13413,,,0.5894252061843872,1.8045969009399407,0.5416799783706665,2.0362939834594727,50000.0,0.4132000207901001,2.7203760147094727,10000.0,4647.025430679321,4857.068476676941,4647.025430679321,209.34475588798523,0.2544946670532226,0.0 -13500,6.340549,3.195387,,,,,,,,,,,,,, -13600,5.5403023,3.1057167,,,,,,,,,,,,,, -13700,4.409441,2.9903462,,,,,,,,,,,,,, -13800,6.4186,3.141829,,,,,,,,,,,,,, -13900,7.014582,3.054094,,,,,,,,,,,,,, -14000,4.7573023,3.2167094,,,,,,,,,,,,,, -14100,4.133001,3.085534,,,,,,,,,,,,,, -14200,4.596253,3.0214655,,,,,,,,,,,,,, -14300,4.5811057,3.0833507,,,,,,,,,,,,,, -14400,7.8433304,3.0923004,,,,,,,,,,,,,, -14500,5.114779,3.0809762,,,,,,,,,,,,,, -14600,6.177337,3.0728233,,,,,,,,,,,,,, -14700,6.4345145,3.1396623,,,,,,,,,,,,,, -14800,5.089832,3.047435,,,,,,,,,,,,,, -14900,5.432012,2.990369,,,,,,,,,,,,,, -14907,,,0.5977359414100647,1.8043289184570312,0.5502399802207947,2.0121490955352783,50000.0,0.428600013256073,2.705935001373291,10000.0,5157.151046037674,5389.574882030487,5157.151046037674,231.6502606868744,0.2813937664031982,0.0 -15000,6.6139355,2.9839191,,,,,,,,,,,,,, -15100,4.36064,3.0761576,,,,,,,,,,,,,, -15200,3.8692458,3.074664,,,,,,,,,,,,,, -15300,5.6894917,3.0459034,,,,,,,,,,,,,, -15400,6.7587023,3.0826294,,,,,,,,,,,,,, -15500,7.6156006,3.0223107,,,,,,,,,,,,,, -15600,4.674738,3.1976252,,,,,,,,,,,,,, -15700,4.922649,3.1554096,,,,,,,,,,,,,, -15800,3.8682983,3.009718,,,,,,,,,,,,,, -15900,5.1480656,3.0146766,,,,,,,,,,,,,, -16000,9.116356,3.1389763,,,,,,,,,,,,,, -16100,5.713602,2.9768045,,,,,,,,,,,,,, -16200,4.851358,3.0231416,,,,,,,,,,,,,, -16300,3.7751486,2.907516,,,,,,,,,,,,,, -16400,,,0.5902822017669678,1.8064643144607544,0.5517199635505676,1.990447163581848,50000.0,0.4308000206947326,2.7112877368927,10000.0,5667.149292945862,5920.640057086945,5667.149292945862,252.6213595867157,0.3276073932647705,0.0 -16400,3.455562,3.0061688,,,,,,,,,,,,,, -16500,4.6083074,2.9619122,,,,,,,,,,,,,, -16600,4.6085825,2.971184,,,,,,,,,,,,,, -16700,5.2251654,3.073553,,,,,,,,,,,,,, -16800,3.618029,2.9900393,,,,,,,,,,,,,, -16900,4.574426,2.980155,,,,,,,,,,,,,, -17000,4.7850456,3.0246763,,,,,,,,,,,,,, -17100,2.4683843,3.000083,,,,,,,,,,,,,, -17200,5.929402,2.9868062,,,,,,,,,,,,,, -17300,3.881436,2.9232621,,,,,,,,,,,,,, -17400,3.9210644,2.9689791,,,,,,,,,,,,,, -17500,3.4740133,3.0550363,,,,,,,,,,,,,, -17600,4.339325,3.0171375,,,,,,,,,,,,,, -17700,4.425171,3.102982,,,,,,,,,,,,,, -17800,4.626693,3.0032477,,,,,,,,,,,,,, -17895,,,0.5936902165412903,1.7567952871322632,0.5552600026130676,1.939975380897522,50000.0,0.4360000193119049,2.634429931640625,10000.0,6177.358016967773,6454.508526086807,6177.358016967773,276.202490568161,0.3575654029846191,0.0 -17900,3.342442,3.0672102,,,,,,,,,,,,,, -18000,3.4705389,3.1327207,,,,,,,,,,,,,, -18100,2.5597208,2.9775636,,,,,,,,,,,,,, -18200,3.8288212,2.918789,,,,,,,,,,,,,, -18300,4.1714997,2.9551685,,,,,,,,,,,,,, -18400,3.994953,3.095118,,,,,,,,,,,,,, -18500,3.3508303,2.9405475,,,,,,,,,,,,,, -18600,4.4779663,2.9966326,,,,,,,,,,,,,, -18700,4.0706377,2.9345927,,,,,,,,,,,,,, -18800,2.7947607,3.0150526,,,,,,,,,,,,,, -18900,2.9715292,3.0434167,,,,,,,,,,,,,, -19000,3.7969923,3.0844066,,,,,,,,,,,,,, -19100,3.9584897,2.9729779,,,,,,,,,,,,,, -19200,3.930369,3.095722,,,,,,,,,,,,,, -19300,2.9679515,2.951459,,,,,,,,,,,,,, -19389,,,0.6022600531578064,1.7681397199630735,0.5593400001525879,1.9684282541275024,50000.0,0.4310000240802765,2.670055866241455,10000.0,6687.38344168663,6987.305696725845,6687.38344168663,298.885041475296,0.3975296020507812,0.0 -19400,3.0705476,3.1473982,,,,,,,,,,,,,, -19500,3.053816,3.0202436,,,,,,,,,,,,,, -19600,3.4713533,2.9665895,,,,,,,,,,,,,, -19700,4.130599,3.0076876,,,,,,,,,,,,,, -19800,3.6564739,2.9512851,,,,,,,,,,,,,, -19900,3.201387,3.029679,,,,,,,,,,,,,, -20000,3.4867096,2.9269426,,,,,,,,,,,,,, -20100,3.9934669,2.8745205,,,,,,,,,,,,,, -20200,3.0772796,2.9630198,,,,,,,,,,,,,, -20300,3.0003834,3.0460432,,,,,,,,,,,,,, -20400,3.3976228,2.9549174,,,,,,,,,,,,,, -20500,3.8580115,2.9025805,,,,,,,,,,,,,, -20600,3.1908603,2.8608093,,,,,,,,,,,,,, -20700,2.966269,2.917957,,,,,,,,,,,,,, -20800,4.8564835,2.8963444,,,,,,,,,,,,,, -20884,,,0.6381337642669678,1.587932586669922,0.5699599981307983,1.8987358808517456,50000.0,0.4415000081062317,2.5937812328338623,10000.0,7197.566004276276,7521.352801322937,7197.566004276276,322.6722848415375,0.4260029792785644,0.0 -20900,3.9177375,3.1112094,,,,,,,,,,,,,, -21000,3.1504443,3.009395,,,,,,,,,,,,,, -21100,3.9897277,3.0114276,,,,,,,,,,,,,, -21200,3.2963936,2.9768288,,,,,,,,,,,,,, -21300,4.0035944,2.9709723,,,,,,,,,,,,,, -21400,2.6700907,2.8934386,,,,,,,,,,,,,, -21500,3.338604,2.9666603,,,,,,,,,,,,,, -21600,3.7851796,3.0330904,,,,,,,,,,,,,, -21700,2.7174304,2.8933778,,,,,,,,,,,,,, -21800,3.3959591,2.9517896,,,,,,,,,,,,,, -21900,3.1433935,3.0138934,,,,,,,,,,,,,, -22000,2.7971559,2.8941264,,,,,,,,,,,,,, -22100,2.7870417,2.8562071,,,,,,,,,,,,,, -22200,3.0768626,2.865262,,,,,,,,,,,,,, -22300,2.9711835,2.9316416,,,,,,,,,,,,,, -22379,,,0.6238042116165161,1.6296091079711914,0.5750600099563599,1.864007472991944,50000.0,0.4531000256538391,2.565328359603882,10000.0,7707.807578086853,8054.712727546692,7707.807578086853,345.7115654945373,0.4558007717132568,0.0 -22400,3.1613128,2.890312,,,,,,,,,,,,,, -22500,3.1197639,2.9021933,,,,,,,,,,,,,, -22600,3.4196825,2.844256,,,,,,,,,,,,,, -22700,3.2246995,2.8324137,,,,,,,,,,,,,, -22800,2.4129372,2.8546798,,,,,,,,,,,,,, -22900,2.8645432,2.8389425,,,,,,,,,,,,,, -23000,3.0430393,2.9341245,,,,,,,,,,,,,, -23100,3.2010388,2.969802,,,,,,,,,,,,,, -23200,3.5876966,2.8997424,,,,,,,,,,,,,, -23300,3.3578446,2.9019785,,,,,,,,,,,,,, -23400,3.1168115,3.0068629,,,,,,,,,,,,,, -23500,3.0427341,2.965856,,,,,,,,,,,,,, -23600,2.623696,2.9145546,,,,,,,,,,,,,, -23700,2.8257918,2.8753703,,,,,,,,,,,,,, -23800,2.6538742,2.8920443,,,,,,,,,,,,,, -23874,,,0.6375358700752258,1.5796483755111694,0.586899995803833,1.815063118934632,50000.0,0.4646000266075134,2.491010904312134,10000.0,8217.97511434555,8587.158815145493,8217.97511434555,367.9157955646515,0.4832274913787842,0.0 -23900,3.1369045,2.8396275,,,,,,,,,,,,,, -24000,2.7743413,2.9682004,,,,,,,,,,,,,, -24100,2.9582233,2.9262292,,,,,,,,,,,,,, -24200,2.8667693,2.879624,,,,,,,,,,,,,, -24300,3.79569,2.895205,,,,,,,,,,,,,, -24400,2.8799822,2.9260974,,,,,,,,,,,,,, -24500,3.2377098,2.9053829,,,,,,,,,,,,,, -24600,3.2237844,2.8754735,,,,,,,,,,,,,, -24700,2.6032965,2.8355286,,,,,,,,,,,,,, -24800,2.9702709,2.865117,,,,,,,,,,,,,, -24900,3.039397,2.9631588,,,,,,,,,,,,,, -25000,3.456172,2.8621237,,,,,,,,,,,,,, -25100,3.028482,2.8675046,,,,,,,,,,,,,, -25200,2.8016202,2.8218536,,,,,,,,,,,,,, -25300,3.4732516,2.8722532,,,,,,,,,,,,,, -25369,,,0.6294642686843872,1.597659945487976,0.584119975566864,1.8133976459503167,50000.0,0.4585000276565552,2.4793083667755127,10000.0,8727.92223238945,9119.464904785156,8727.92223238945,390.196186542511,0.512209415435791,0.0 -25400,2.5518787,2.9992573,,,,,,,,,,,,,, -25500,3.0027938,2.8484178,,,,,,,,,,,,,, -25600,4.0567465,2.9049263,,,,,,,,,,,,,, -25700,3.1762676,2.9227865,,,,,,,,,,,,,, -25800,2.863089,2.8673916,,,,,,,,,,,,,, -25900,2.6075697,2.9806802,,,,,,,,,,,,,, -26000,3.618236,2.831902,,,,,,,,,,,,,, -26100,3.5630114,2.8637161,,,,,,,,,,,,,, -26200,3.0731676,2.9313252,,,,,,,,,,,,,, -26300,2.57565,2.8918748,,,,,,,,,,,,,, -26400,3.4324536,2.9276075,,,,,,,,,,,,,, -26500,2.9725468,2.830861,,,,,,,,,,,,,, -26600,2.742927,2.7587156,,,,,,,,,,,,,, -26700,3.7855775,2.7571316,,,,,,,,,,,,,, -26800,2.9301925,2.892899,,,,,,,,,,,,,, -26864,,,0.6300422549247742,1.63661527633667,0.5836600065231323,1.8408207893371584,50000.0,0.464100033044815,2.510469913482666,10000.0,9237.966356039047,9651.175999879835,9237.966356039047,411.78082633018494,0.5465049743652344,0.0 -26900,2.9868217,2.882936,,,,,,,,,,,,,, -27000,2.6907153,2.9841275,,,,,,,,,,,,,, -27100,2.748768,2.9552915,,,,,,,,,,,,,, -27200,2.607993,2.9048843,,,,,,,,,,,,,, -27300,3.017277,2.7634547,,,,,,,,,,,,,, -27400,3.3103878,2.8524923,,,,,,,,,,,,,, -27500,2.886156,2.8021166,,,,,,,,,,,,,, -27600,2.8153257,2.8629284,,,,,,,,,,,,,, -27700,2.6451356,2.8192527,,,,,,,,,,,,,, -27800,2.6322794,2.8019276,,,,,,,,,,,,,, -27900,3.392852,2.9125037,,,,,,,,,,,,,, -28000,3.1294532,2.798827,,,,,,,,,,,,,, -28100,3.0277517,2.8577428,,,,,,,,,,,,,, -28200,3.4385989,2.8107467,,,,,,,,,,,,,, -28300,3.434373,2.7471704,,,,,,,,,,,,,, -28358,,,0.6334103941917419,1.5800617933273315,0.5934999585151672,1.7827062606811523,50000.0,0.4733000099658966,2.4336440563201904,10000.0,9747.298719406128,10183.808728456495,9747.298719406128,434.0755660533905,1.5038611888885498,0.0 -28400,2.8834176,2.864355,,,,,,,,,,,,,, -28500,3.2258685,2.82095,,,,,,,,,,,,,, -28600,3.21987,2.8144448,,,,,,,,,,,,,, -28700,2.9703496,2.8552852,,,,,,,,,,,,,, -28800,3.485923,2.8085947,,,,,,,,,,,,,, -28900,3.1889405,2.9184034,,,,,,,,,,,,,, -29000,3.123487,2.9011314,,,,,,,,,,,,,, -29100,3.005081,2.7029593,,,,,,,,,,,,,, -29200,3.008213,2.8081174,,,,,,,,,,,,,, -29300,2.8173485,2.8815646,,,,,,,,,,,,,, -29400,3.4983199,2.914936,,,,,,,,,,,,,, -29500,3.063312,2.88921,,,,,,,,,,,,,, -29600,3.4888523,2.8822105,,,,,,,,,,,,,, -29700,2.9483583,2.7743433,,,,,,,,,,,,,, -29800,3.3102882,2.8038619,,,,,,,,,,,,,, -29854,,,0.6729910373687744,1.422000288963318,0.5934999585151672,1.7916085720062256,50000.0,0.4643000364303589,2.4862213134765625,10000.0,10257.261818170547,10718.622612953186,10257.261818170547,458.8446247577667,1.535851240158081,0.0 -29900,3.3672442,2.8678734,,,,,,,,,,,,,, -30000,2.7086287,2.8109138,,,,,,,,,,,,,, -30100,3.8007662,2.8380296,,,,,,,,,,,,,, -30200,3.0956516,2.8659942,,,,,,,,,,,,,, -30300,3.2469025,2.7370656,,,,,,,,,,,,,, -30400,2.418109,2.8180559,,,,,,,,,,,,,, -30500,3.350694,2.8767035,,,,,,,,,,,,,, -30600,3.3107786,2.8654191,,,,,,,,,,,,,, -30700,2.6675866,2.868961,,,,,,,,,,,,,, -30800,3.3669922,2.8380892,,,,,,,,,,,,,, -30900,2.9016378,2.7325156,,,,,,,,,,,,,, -31000,3.1178904,2.8264396,,,,,,,,,,,,,, -31100,2.7539089,2.8386343,,,,,,,,,,,,,, -31200,3.514039,2.808787,,,,,,,,,,,,,, -31300,3.0076993,2.8655572,,,,,,,,,,,,,, -31351,,,0.6510881781578064,1.5065069198608398,0.5972599983215332,1.757551193237305,50000.0,0.4695000350475311,2.46608304977417,10000.0,10767.481489896774,11253.432497739792,10767.481489896774,483.3538534641266,1.5658612251281738,0.0 -31400,2.9995403,2.8804207,,,,,,,,,,,,,, -31500,3.6439805,2.9274263,,,,,,,,,,,,,, -31600,3.3964918,2.9077044,,,,,,,,,,,,,, -31700,3.065477,2.8206484,,,,,,,,,,,,,, -31800,3.1398568,2.8757396,,,,,,,,,,,,,, -31900,3.495733,2.8432424,,,,,,,,,,,,,, -32000,3.7606385,2.898484,,,,,,,,,,,,,, -32100,4.5005007,2.8743417,,,,,,,,,,,,,, -32200,3.1876812,2.8570282,,,,,,,,,,,,,, -32300,2.7891376,2.8574638,,,,,,,,,,,,,, -32400,2.7778182,2.8460393,,,,,,,,,,,,,, -32500,3.2626266,2.7880602,,,,,,,,,,,,,, -32600,3.2152753,2.7708573,,,,,,,,,,,,,, -32700,2.7475317,2.8278403,,,,,,,,,,,,,, -32800,2.806305,2.783218,,,,,,,,,,,,,, -32848,,,0.6462252736091614,1.5267188549041748,0.6002399921417236,1.758679986000061,50000.0,0.4761000275611877,2.4272708892822266,10000.0,11277.71798467636,11786.628136873243,11277.71798467636,506.2302823066712,1.5994904041290283,0.0 -32900,2.875621,2.8498578,,,,,,,,,,,,,, -33000,3.2327218,2.7891052,,,,,,,,,,,,,, -33100,2.7658722,2.6958244,,,,,,,,,,,,,, -33200,3.482995,2.8202126,,,,,,,,,,,,,, -33300,2.8028867,2.8796842,,,,,,,,,,,,,, -33400,2.8915725,2.7817469,,,,,,,,,,,,,, -33500,3.3588693,2.8555515,,,,,,,,,,,,,, -33600,3.0775387,2.925966,,,,,,,,,,,,,, -33700,2.7870107,2.80776,,,,,,,,,,,,,, -33800,3.2072163,2.8669038,,,,,,,,,,,,,, -33900,2.8608809,2.9045672,,,,,,,,,,,,,, -34000,2.6355705,2.8323607,,,,,,,,,,,,,, -34100,2.924445,2.8972661,,,,,,,,,,,,,, -34200,3.8041646,2.8766122,,,,,,,,,,,,,, -34300,2.7606204,2.770054,,,,,,,,,,,,,, -34344,,,0.6496133208274841,1.536085605621338,0.6037399768829346,1.748511791229248,50000.0,0.4770000278949737,2.4422802925109863,10000.0,11787.83618426323,12318.89468216896,11787.83618426323,528.2995517253876,1.6302180290222168,0.0 -34400,2.7937312,2.7509434,,,,,,,,,,,,,, -34500,2.777008,2.8135977,,,,,,,,,,,,,, -34600,2.902621,2.755237,,,,,,,,,,,,,, -34700,2.8773837,2.8224583,,,,,,,,,,,,,, -34800,2.9153256,2.8953912,,,,,,,,,,,,,, -34900,3.4878678,2.8494582,,,,,,,,,,,,,, -35000,2.7652473,2.7772803,,,,,,,,,,,,,, -35100,2.909022,2.7809284,,,,,,,,,,,,,, -35200,2.7963586,2.7654874,,,,,,,,,,,,,, -35300,2.8938127,2.7479587,,,,,,,,,,,,,, -35400,4.4360456,2.7740788,,,,,,,,,,,,,, -35500,3.253333,2.802876,,,,,,,,,,,,,, -35600,3.2064705,2.8350866,,,,,,,,,,,,,, -35700,2.9535902,2.818058,,,,,,,,,,,,,, -35800,3.152738,2.876364,,,,,,,,,,,,,, -35840,,,0.645926296710968,1.52766752243042,0.596560001373291,1.7534399032592771,50000.0,0.4759000241756439,2.4189414978027344,10000.0,12297.890921831133,12852.47915005684,12297.890921831133,551.7431771755219,1.6666617393493652,0.0 -35900,3.0301116,2.8400352,,,,,,,,,,,,,, -36000,3.439973,2.8925858,,,,,,,,,,,,,, -36100,2.9135714,2.8245516,,,,,,,,,,,,,, -36200,2.7612896,2.6837626,,,,,,,,,,,,,, -36300,2.7285075,2.765042,,,,,,,,,,,,,, -36400,3.3438592,2.8539028,,,,,,,,,,,,,, -36500,3.8527448,2.872261,,,,,,,,,,,,,, -36600,3.008128,2.9088132,,,,,,,,,,,,,, -36700,3.2097933,2.737313,,,,,,,,,,,,,, -36800,3.4427538,2.6018283,,,,,,,,,,,,,, -36900,3.0066874,2.7610068,,,,,,,,,,,,,, -37000,2.9382226,2.782052,,,,,,,,,,,,,, -37100,3.5863876,2.7945685,,,,,,,,,,,,,, -37200,2.9147046,2.6972759,,,,,,,,,,,,,, -37300,3.6406,2.8576446,,,,,,,,,,,,,, -37336,,,0.6403061151504517,1.5559779405593872,0.6036199927330017,1.746623992919922,50000.0,0.4829000234603882,2.418141841888428,10000.0,12807.821355819702,13384.028679132462,12807.821355819702,573.2801125049591,1.6983842849731443,0.0 -37400,3.6183457,2.726819,,,,,,,,,,,,,, -37500,3.2376835,2.840212,,,,,,,,,,,,,, -37600,2.9616637,2.7884846,,,,,,,,,,,,,, -37700,2.9373329,2.8376548,,,,,,,,,,,,,, -37800,3.1098778,2.8044367,,,,,,,,,,,,,, -37900,3.308525,2.7751575,,,,,,,,,,,,,, -38000,2.795887,2.7792902,,,,,,,,,,,,,, -38100,2.9275846,2.782152,,,,,,,,,,,,,, -38200,3.088213,2.8299136,,,,,,,,,,,,,, -38300,3.1283498,2.7529883,,,,,,,,,,,,,, -38400,2.8146968,2.7368019,,,,,,,,,,,,,, -38500,3.0379515,2.7256434,,,,,,,,,,,,,, -38600,3.4734116,2.7699955,,,,,,,,,,,,,, -38700,2.7973335,2.7507212,,,,,,,,,,,,,, -38800,3.052828,2.7806942,,,,,,,,,,,,,, -38832,,,0.6729711294174194,1.439097881317139,0.6108199954032898,1.7267426252365112,50000.0,0.4917000234127044,2.3655688762664795,10000.0,13317.77873301506,13917.43752002716,13317.77873301506,596.6457276344299,1.734304666519165,0.0 -38900,2.903426,2.7534266,,,,,,,,,,,,,, -39000,3.1344845,2.8151257,,,,,,,,,,,,,, -39100,3.082072,2.6836343,,,,,,,,,,,,,, -39200,3.8173983,2.8399534,,,,,,,,,,,,,, -39300,3.1733124,2.846567,,,,,,,,,,,,,, -39400,3.0043685,2.8308463,,,,,,,,,,,,,, -39500,3.2529051,2.72815,,,,,,,,,,,,,, -39600,2.9166052,2.771601,,,,,,,,,,,,,, -39700,2.9126146,2.7171645,,,,,,,,,,,,,, -39800,2.8111658,2.796397,,,,,,,,,,,,,, -39900,2.8075085,2.7567372,,,,,,,,,,,,,, -40000,3.0752332,2.716949,,,,,,,,,,,,,, -40100,2.981286,2.7886572,,,,,,,,,,,,,, -40200,3.4269142,2.7945323,,,,,,,,,,,,,, -40300,3.6114957,2.7906096,,,,,,,,,,,,,, -40329,,,0.6744459271430969,1.4112752676010132,0.610260009765625,1.698612928390503,50000.0,0.4941000342369079,2.351206064224243,10000.0,13827.9108877182,14448.347055912018,13827.9108877182,617.3422248363495,1.76465106010437,0.0 -40400,3.2119808,2.7501054,,,,,,,,,,,,,, -40500,3.2507539,2.7928894,,,,,,,,,,,,,, -40600,2.827345,2.8906772,,,,,,,,,,,,,, -40700,3.0199988,2.7357416,,,,,,,,,,,,,, -40800,3.2297368,2.701263,,,,,,,,,,,,,, -40900,2.8741188,2.757013,,,,,,,,,,,,,, -41000,2.612519,2.638964,,,,,,,,,,,,,, -41100,3.3783066,2.7418232,,,,,,,,,,,,,, -41200,3.1608465,2.763505,,,,,,,,,,,,,, -41300,2.920503,2.8565965,,,,,,,,,,,,,, -41400,2.8185396,2.7284832,,,,,,,,,,,,,, -41500,3.230693,2.919279,,,,,,,,,,,,,, -41600,3.3058698,2.8729165,,,,,,,,,,,,,, -41700,3.362994,2.8658476,,,,,,,,,,,,,, -41800,3.2516153,2.7610495,,,,,,,,,,,,,, -41825,,,0.6544762253761292,1.509508728981018,0.6057199835777283,1.743627667427063,50000.0,0.4783000349998474,2.419309377670288,10000.0,14337.827102661133,14978.894652605057,14337.827102661133,637.8876869678497,1.7998626232147217,0.0 -41900,3.8784854,2.8538191,,,,,,,,,,,,,, -42000,3.02508,2.7249584,,,,,,,,,,,,,, -42100,3.4061556,2.6795952,,,,,,,,,,,,,, -42200,3.2905655,2.8812637,,,,,,,,,,,,,, -42300,2.9227715,2.784542,,,,,,,,,,,,,, -42400,2.7565281,2.833279,,,,,,,,,,,,,, -42500,3.2442596,2.7933764,,,,,,,,,,,,,, -42600,3.3158622,2.7461748,,,,,,,,,,,,,, -42700,2.9712756,2.7775445,,,,,,,,,,,,,, -42800,3.0251017,2.7420285,,,,,,,,,,,,,, -42900,3.2713468,2.8314304,,,,,,,,,,,,,, -43000,3.5261407,2.7896624,,,,,,,,,,,,,, -43100,2.9830632,2.8614588,,,,,,,,,,,,,, -43200,3.3404794,2.753631,,,,,,,,,,,,,, -43300,3.3608801,2.783825,,,,,,,,,,,,,, -43322,,,0.6491350531578064,1.5379236936569214,0.6043199896812439,1.748526930809021,50000.0,0.4749000370502472,2.452561378479004,10000.0,14847.95288991928,15509.13153076172,14847.95288991928,657.9168131351471,1.832062005996704,0.0 -43400,3.1205685,2.8574836,,,,,,,,,,,,,, -43500,3.0042384,2.6959748,,,,,,,,,,,,,, -43600,3.4641783,2.7866044,,,,,,,,,,,,,, -43700,2.6770542,2.6927881,,,,,,,,,,,,,, -43800,2.8504872,2.7682824,,,,,,,,,,,,,, -43900,2.9174333,2.7028737,,,,,,,,,,,,,, -44000,3.123931,2.803926,,,,,,,,,,,,,, -44100,2.6238961,2.7967305,,,,,,,,,,,,,, -44200,3.223876,2.8631673,,,,,,,,,,,,,, -44300,3.0185158,2.6733685,,,,,,,,,,,,,, -44400,2.9394805,2.7394688,,,,,,,,,,,,,, -44500,3.3970716,2.7971363,,,,,,,,,,,,,, -44600,3.5112214,2.718691,,,,,,,,,,,,,, -44700,2.8348918,2.737289,,,,,,,,,,,,,, -44800,2.5153642,2.7458248,,,,,,,,,,,,,, -44819,,,0.6519451141357422,1.5086034536361694,0.6087799668312073,1.7206631898880005,50000.0,0.4844000339508056,2.402400016784668,10000.0,15358.028193473816,16037.553413391111,15358.028193473816,676.1758737564087,1.8695974349975584,0.0 -44900,3.86678,2.7287114,,,,,,,,,,,,,, -45000,3.12373,2.685528,,,,,,,,,,,,,, -45100,4.1444597,2.7304094,,,,,,,,,,,,,, -45200,2.832792,2.7582202,,,,,,,,,,,,,, -45300,2.9113634,2.8231475,,,,,,,,,,,,,, -45400,2.6692803,2.6488063,,,,,,,,,,,,,, -45500,3.1665702,2.8352728,,,,,,,,,,,,,, -45600,2.8355148,2.7611039,,,,,,,,,,,,,, -45700,2.9672966,2.7367015,,,,,,,,,,,,,, -45800,2.9342606,2.7810879,,,,,,,,,,,,,, -45900,2.8227568,2.8274498,,,,,,,,,,,,,, -46000,2.8281693,2.7566757,,,,,,,,,,,,,, -46100,2.7998369,2.6965497,,,,,,,,,,,,,, -46200,3.5136657,2.8093853,,,,,,,,,,,,,, -46300,3.4977157,2.6865742,,,,,,,,,,,,,, -46315,,,0.6457070708274841,1.543419361114502,0.6055799722671509,1.7364466190338137,50000.0,0.4705000221729278,2.4317517280578613,10000.0,15867.962439060211,16565.30499601364,15867.962439060211,693.9103631973267,1.903987169265747,0.0 -46400,2.5078838,2.6526208,,,,,,,,,,,,,, -46500,3.2492316,2.7038884,,,,,,,,,,,,,, -46600,3.5491273,2.76308,,,,,,,,,,,,,, -46700,3.2682881,2.7519283,,,,,,,,,,,,,, -46800,2.7296543,2.711875,,,,,,,,,,,,,, -46900,3.7968137,2.7316227,,,,,,,,,,,,,, -47000,3.1239483,2.8154812,,,,,,,,,,,,,, -47100,2.9963517,2.6736436,,,,,,,,,,,,,, -47200,3.126826,2.713922,,,,,,,,,,,,,, -47300,2.9208417,2.7342114,,,,,,,,,,,,,, -47400,2.866103,2.6960878,,,,,,,,,,,,,, -47500,3.1197736,2.8360167,,,,,,,,,,,,,, -47600,2.9554615,2.8362591,,,,,,,,,,,,,, -47700,3.0299459,2.798276,,,,,,,,,,,,,, -47800,3.5899975,2.7952704,,,,,,,,,,,,,, -47812,,,0.6532405614852905,1.498455286026001,0.6109600067138672,1.7112398147583008,50000.0,0.4796000123023987,2.4064650535583496,10000.0,16378.023522853851,17093.762252807617,16378.023522853851,712.2113356590271,1.949737787246704,0.0 -47900,3.1219144,2.7423668,,,,,,,,,,,,,, -48000,3.1972532,2.7282825,,,,,,,,,,,,,, -48100,3.1296155,2.6636646,,,,,,,,,,,,,, -48200,2.641608,2.689465,,,,,,,,,,,,,, -48300,2.9703782,2.7495408,,,,,,,,,,,,,, -48400,3.5839121,2.7088685,,,,,,,,,,,,,, -48500,2.945224,2.7818613,,,,,,,,,,,,,, -48600,2.8766253,2.8165956,,,,,,,,,,,,,, -48700,3.3330576,2.6756117,,,,,,,,,,,,,, -48800,3.269945,2.749885,,,,,,,,,,,,,, -48900,3.3237703,2.7116492,,,,,,,,,,,,,, -49000,3.1603286,2.7459989,,,,,,,,,,,,,, -49100,3.10745,2.7833266,,,,,,,,,,,,,, -49200,3.0379715,2.7357779,,,,,,,,,,,,,, -49300,2.9184606,2.6464586,,,,,,,,,,,,,, -49308,,,0.6916453838348389,1.3557281494140625,0.6181399822235107,1.6922627687454224,50000.0,0.4860000312328338,2.362870216369629,10000.0,16887.94068956375,17621.080164194107,16887.94068956375,729.52081990242,1.992770195007324,0.0 -49400,2.903936,2.769189,,,,,,,,,,,,,, -49500,2.9425282,2.7320497,,,,,,,,,,,,,, -49600,3.200371,2.7352338,,,,,,,,,,,,,, -49700,3.056834,2.8363712,,,,,,,,,,,,,, -49800,3.2087154,2.6378975,,,,,,,,,,,,,, -49900,3.3552563,2.801907,,,,,,,,,,,,,, -50000,3.9774396,2.6835697,,,,,,,,,,,,,, -50100,3.0412142,2.81748,,,,,,,,,,,,,, -50200,3.2973688,2.734727,,,,,,,,,,,,,, -50300,3.280349,2.709904,,,,,,,,,,,,,, -50400,3.0850286,2.5398884,,,,,,,,,,,,,, -50500,3.9654415,2.7770894,,,,,,,,,,,,,, -50600,2.9517553,2.729313,,,,,,,,,,,,,, -50700,2.8308735,2.7364683,,,,,,,,,,,,,, -50800,3.0292397,2.6828053,,,,,,,,,,,,,, -50807,,,0.671297013759613,1.4186804294586182,0.6159999966621399,1.6773052215576172,50000.0,0.4921000301837921,2.351040840148926,10000.0,17398.12221980095,18148.496007680893,17398.12221980095,746.666835308075,2.0300121307373047,0.0 -50900,2.677102,2.792563,,,,,,,,,,,,,, -51000,3.2123048,2.729262,,,,,,,,,,,,,, -51100,3.1004217,2.6544003,,,,,,,,,,,,,, -51200,2.8334832,2.7656257,,,,,,,,,,,,,, -51300,2.9311006,2.709344,,,,,,,,,,,,,, -51400,3.0927312,2.7117753,,,,,,,,,,,,,, -51500,3.4825542,2.7731278,,,,,,,,,,,,,, -51600,2.6922696,2.7204452,,,,,,,,,,,,,, -51700,2.8648694,2.7789166,,,,,,,,,,,,,, -51800,2.9008832,2.7706432,,,,,,,,,,,,,, -51900,3.3089874,2.7003584,,,,,,,,,,,,,, -52000,3.2097242,2.5943549,,,,,,,,,,,,,, -52100,3.6027403,2.7153823,,,,,,,,,,,,,, -52200,3.485321,2.8101726,,,,,,,,,,,,,, -52300,3.0747242,2.725754,,,,,,,,,,,,,, -52304,,,0.6713567972183228,1.4391323328018188,0.6169399619102478,1.6796952486038208,50000.0,0.4946000277996063,2.344132661819458,10000.0,17908.237648248672,18675.98785209656,17908.237648248672,763.9545395374298,2.067458152770996,0.0 -52400,3.083706,2.7319446,,,,,,,,,,,,,, -52500,3.163939,2.6310315,,,,,,,,,,,,,, -52600,2.7789116,2.8445797,,,,,,,,,,,,,, -52700,3.3486931,2.6239622,,,,,,,,,,,,,, -52800,3.298986,2.7028022,,,,,,,,,,,,,, -52900,3.1772754,2.680466,,,,,,,,,,,,,, -53000,2.825765,2.658502,,,,,,,,,,,,,, -53100,3.0606463,2.7423651,,,,,,,,,,,,,, -53200,3.1405904,2.7070913,,,,,,,,,,,,,, -53300,2.6633582,2.809187,,,,,,,,,,,,,, -53400,3.3088982,2.7708998,,,,,,,,,,,,,, -53500,2.9268308,2.67936,,,,,,,,,,,,,, -53600,3.2742782,2.756691,,,,,,,,,,,,,, -53700,2.684847,2.7146976,,,,,,,,,,,,,, -53800,3.0870867,2.7054913,,,,,,,,,,,,,, -53801,,,0.6638432741165161,1.4525445699691772,0.6154999732971191,1.6743123531341553,50000.0,0.4966000318527221,2.325680494308472,10000.0,18418.46981525421,19203.51691007614,18418.46981525421,781.1611652374268,2.107489109039306,0.0 -53900,3.0520988,2.7960274,,,,,,,,,,,,,, -54000,2.9448934,2.7468352,,,,,,,,,,,,,, -54100,3.069392,2.8216462,,,,,,,,,,,,,, -54200,2.879124,2.803808,,,,,,,,,,,,,, -54300,3.5775492,2.721562,,,,,,,,,,,,,, -54400,3.3073187,2.7548428,,,,,,,,,,,,,, -54500,3.1332173,2.754846,,,,,,,,,,,,,, -54600,3.0523407,2.6994014,,,,,,,,,,,,,, -54700,3.4617107,2.8260977,,,,,,,,,,,,,, -54800,3.3737235,2.7544155,,,,,,,,,,,,,, -54900,2.9957974,2.6746416,,,,,,,,,,,,,, -55000,3.7509935,2.7762074,,,,,,,,,,,,,, -55100,3.0483654,2.6831684,,,,,,,,,,,,,, -55200,2.9697912,2.7233958,,,,,,,,,,,,,, -55299,,,0.6633649468421936,1.4799798727035522,0.6179400086402893,1.6885169744491575,50000.0,0.4979000091552734,2.3356235027313232,10000.0,18928.68518781662,19731.023154973984,18928.68518781662,798.3581395149231,2.149481773376465,0.0 -55300,3.0946112,2.77404,,,,,,,,,,,,,, -55400,3.0059903,2.730447,,,,,,,,,,,,,, -55500,2.844702,2.7461257,,,,,,,,,,,,,, -55600,2.9235594,2.6729288,,,,,,,,,,,,,, -55700,3.1986706,2.7823582,,,,,,,,,,,,,, -55800,3.218698,2.7163792,,,,,,,,,,,,,, -55900,3.070992,2.772447,,,,,,,,,,,,,, -56000,2.9558544,2.8001442,,,,,,,,,,,,,, -56100,3.0366616,2.6612246,,,,,,,,,,,,,, -56200,2.9256299,2.7067409,,,,,,,,,,,,,, -56300,2.7668211,2.703617,,,,,,,,,,,,,, -56400,3.0174923,2.707953,,,,,,,,,,,,,, -56500,2.99343,2.658287,,,,,,,,,,,,,, -56600,2.7289352,2.6213272,,,,,,,,,,,,,, -56700,3.1551843,2.853824,,,,,,,,,,,,,, -56797,,,0.6719148755073547,1.425616979598999,0.6231799721717834,1.6395264863967896,50000.0,0.508400022983551,2.2914958000183105,10000.0,19438.76026201248,20258.59286975861,19438.76026201248,815.757345199585,2.195958137512207,0.0 -56800,3.5280778,2.6880543,,,,,,,,,,,,,, -56900,3.0082977,2.7181542,,,,,,,,,,,,,, -57000,3.3861682,2.7468917,,,,,,,,,,,,,, -57100,3.262967,2.536167,,,,,,,,,,,,,, -57200,3.189963,2.7003775,,,,,,,,,,,,,, -57300,2.808747,2.7495468,,,,,,,,,,,,,, -57400,3.2068467,2.6308684,,,,,,,,,,,,,, -57500,3.1505902,2.6453524,,,,,,,,,,,,,, -57600,3.182976,2.7235124,,,,,,,,,,,,,, -57700,3.2121024,2.6890519,,,,,,,,,,,,,, -57800,3.3970203,2.7777143,,,,,,,,,,,,,, -57900,3.1321948,2.6766648,,,,,,,,,,,,,, -58000,2.9101276,2.7188601,,,,,,,,,,,,,, -58100,3.0686326,2.7454782,,,,,,,,,,,,,, -58200,3.0490947,2.6855154,,,,,,,,,,,,,, -58294,,,0.7206233739852905,1.2154070138931274,0.6320399641990662,1.6034917831420898,50000.0,0.5078000426292419,2.266110897064209,10000.0,19948.85229182244,20786.05773258209,19948.85229182244,833.042439699173,2.232260465621948,0.0 -58300,3.1755028,2.6574926,,,,,,,,,,,,,, -58400,3.3498504,2.8045483,,,,,,,,,,,,,, -58500,2.9695694,2.6322625,,,,,,,,,,,,,, -58600,3.0833857,2.765048,,,,,,,,,,,,,, -58700,2.8378935,2.7942913,,,,,,,,,,,,,, -58800,3.3910763,2.6753924,,,,,,,,,,,,,, -58900,3.0809393,2.5922487,,,,,,,,,,,,,, -59000,3.270312,2.7082555,,,,,,,,,,,,,, -59100,2.828506,2.6004748,,,,,,,,,,,,,, -59200,3.0198178,2.774668,,,,,,,,,,,,,, -59300,3.1382127,2.626851,,,,,,,,,,,,,, -59400,3.1726048,2.568885,,,,,,,,,,,,,, -59500,3.5233347,2.6968074,,,,,,,,,,,,,, -59600,3.2279205,2.680358,,,,,,,,,,,,,, -59700,3.0498774,2.666058,,,,,,,,,,,,,, -59792,,,0.6831552982330322,1.3804254531860352,0.6195999979972839,1.6663074493408203,50000.0,0.4915000200271606,2.3457841873168945,10000.0,20458.78334474564,21313.644002199173,20458.78334474564,850.6091375350952,2.270244836807251,0.0 -59800,3.3132634,2.6616864,,,,,,,,,,,,,, -59900,3.4279103,2.6758137,,,,,,,,,,,,,, -60000,2.8026605,2.6188216,,,,,,,,,,,,,, -60100,3.3048253,2.6972153,,,,,,,,,,,,,, -60200,3.3485413,2.632891,,,,,,,,,,,,,, -60300,3.3625267,2.7723162,,,,,,,,,,,,,, -60400,3.435159,2.531726,,,,,,,,,,,,,, -60500,3.4076104,2.7266326,,,,,,,,,,,,,, -60600,3.3171506,2.6873472,,,,,,,,,,,,,, -60700,3.0665638,2.65439,,,,,,,,,,,,,, -60800,3.1923952,2.6852448,,,,,,,,,,,,,, -60900,3.766224,2.7348378,,,,,,,,,,,,,, -61000,3.2408993,2.6629338,,,,,,,,,,,,,, -61100,2.8863893,2.7189748,,,,,,,,,,,,,, -61200,3.2829037,2.8300643,,,,,,,,,,,,,, -61289,,,0.6767179369926453,1.392195224761963,0.6308199763298035,1.6244053840637207,50000.0,0.5031999945640564,2.3163297176361084,10000.0,20968.790618896484,21840.92485141754,20968.790618896484,867.795756816864,2.307173728942871,0.0 -61300,2.88972,2.6864927,,,,,,,,,,,,,, -61400,3.2043,2.7167737,,,,,,,,,,,,,, -61500,3.044132,2.6713367,,,,,,,,,,,,,, -61600,3.4400895,2.6496577,,,,,,,,,,,,,, -61700,3.1622546,2.7079291,,,,,,,,,,,,,, -61800,3.741231,2.6021497,,,,,,,,,,,,,, -61900,2.818433,2.7205367,,,,,,,,,,,,,, -62000,2.9706876,2.6210706,,,,,,,,,,,,,, -62100,2.9453998,2.6764073,,,,,,,,,,,,,, -62200,3.0280788,2.69849,,,,,,,,,,,,,, -62300,3.086959,2.6940277,,,,,,,,,,,,,, -62400,3.288451,2.7042532,,,,,,,,,,,,,, -62500,2.918942,2.7598748,,,,,,,,,,,,,, -62600,3.1972523,2.6516042,,,,,,,,,,,,,, -62700,3.3838873,2.7764318,,,,,,,,,,,,,, -62787,,,0.682637095451355,1.37348473072052,0.6302399635314941,1.626953125,50000.0,0.5047000050544739,2.289322853088379,10000.0,21478.93054676056,22368.341687202454,21478.93054676056,884.985392332077,2.344024419784546,0.0 -62800,2.893342,2.6587694,,,,,,,,,,,,,, -62900,2.9429576,2.6536245,,,,,,,,,,,,,, -63000,3.2751195,2.6128979,,,,,,,,,,,,,, -63100,3.4843035,2.710946,,,,,,,,,,,,,, -63200,3.7313137,2.6094217,,,,,,,,,,,,,, -63300,3.1230915,2.8279471,,,,,,,,,,,,,, -63400,3.2781644,2.7514467,,,,,,,,,,,,,, -63500,2.912023,2.7236853,,,,,,,,,,,,,, -63600,3.117823,2.6874988,,,,,,,,,,,,,, -63700,3.0858912,2.7260342,,,,,,,,,,,,,, -63800,2.7719269,2.6777298,,,,,,,,,,,,,, -63900,3.1681855,2.6438587,,,,,,,,,,,,,, -64000,3.6442726,2.6590528,,,,,,,,,,,,,, -64100,3.3024325,2.633836,,,,,,,,,,,,,, -64200,3.1040664,2.734667,,,,,,,,,,,,,, -64284,,,0.6745256781578064,1.4175406694412231,0.6255800127983093,1.6421574354171753,50000.0,0.5065000057220459,2.2937440872192383,10000.0,21988.93044018745,22895.69519519806,21988.93044018745,902.2461650371552,2.3862104415893555,0.0 -64300,3.7991903,2.6448717,,,,,,,,,,,,,, -64400,2.988798,2.67295,,,,,,,,,,,,,, -64500,3.569422,2.6561368,,,,,,,,,,,,,, -64600,3.6359506,2.6494467,,,,,,,,,,,,,, -64700,3.3011131,2.6804183,,,,,,,,,,,,,, -64800,3.178111,2.7041626,,,,,,,,,,,,,, -64900,3.1506114,2.6563687,,,,,,,,,,,,,, -65000,3.2081888,2.6619189,,,,,,,,,,,,,, -65100,3.3213813,2.6245673,,,,,,,,,,,,,, -65200,3.1332958,2.5962183,,,,,,,,,,,,,, -65300,3.1375916,2.6150277,,,,,,,,,,,,,, -65400,3.0282283,2.6446605,,,,,,,,,,,,,, -65500,3.1330671,2.5829098,,,,,,,,,,,,,, -65600,3.3556826,2.6493962,,,,,,,,,,,,,, -65700,3.1956422,2.7822433,,,,,,,,,,,,,, -65782,,,0.6729512214660645,1.4147884845733645,0.6275599598884583,1.630405068397522,50000.0,0.504800021648407,2.2867746353149414,10000.0,22499.075800418854,23423.206660985947,22499.075800418854,919.5094072818756,2.438976287841797,0.0 -65800,3.2874067,2.564142,,,,,,,,,,,,,, -65900,3.652226,2.7643726,,,,,,,,,,,,,, -66000,3.5484905,2.7491806,,,,,,,,,,,,,, -66100,2.9776013,2.6800842,,,,,,,,,,,,,, -66200,2.8669055,2.7277927,,,,,,,,,,,,,, -66300,3.1985316,2.7169437,,,,,,,,,,,,,, -66400,2.9339526,2.7152548,,,,,,,,,,,,,, -66500,3.2244902,2.676598,,,,,,,,,,,,,, -66600,3.3730588,2.6060615,,,,,,,,,,,,,, -66700,3.1882656,2.7019186,,,,,,,,,,,,,, -66800,3.1882179,2.6227326,,,,,,,,,,,,,, -66900,2.9409463,2.692227,,,,,,,,,,,,,, -67000,3.1362998,2.7338018,,,,,,,,,,,,,, -67100,3.3208368,2.7524362,,,,,,,,,,,,,, -67200,3.148228,2.5913625,,,,,,,,,,,,,, -67279,,,0.6883569955825806,1.3431607484817505,0.6266799569129944,1.618585467338562,50000.0,0.5004000067710876,2.3088037967681885,10000.0,23009.149444818497,23950.570457935333,23009.149444818497,936.7078473567964,2.4785714149475098,0.0 -67300,2.8746626,2.6092455,,,,,,,,,,,,,, -67400,3.1594074,2.7171664,,,,,,,,,,,,,, -67500,2.8518112,2.5763278,,,,,,,,,,,,,, -67600,2.908618,2.712134,,,,,,,,,,,,,, -67700,3.2268233,2.692701,,,,,,,,,,,,,, -67800,3.5034935,2.6482263,,,,,,,,,,,,,, -67900,4.0994635,2.58562,,,,,,,,,,,,,, -68000,3.7342987,2.644109,,,,,,,,,,,,,, -68100,3.2386732,2.712067,,,,,,,,,,,,,, -68200,3.7294636,2.7363486,,,,,,,,,,,,,, -68300,2.9780781,2.7044804,,,,,,,,,,,,,, -68400,2.938955,2.6513007,,,,,,,,,,,,,, -68500,2.9961488,2.5889547,,,,,,,,,,,,,, -68600,3.185877,2.5954213,,,,,,,,,,,,,, -68700,2.8386748,2.5808945,,,,,,,,,,,,,, -68777,,,0.7035036683082581,1.274147391319275,0.6340799927711487,1.591722011566162,50000.0,0.5062000155448914,2.276012897491455,10000.0,23519.17913007736,24477.8970515728,23519.17913007736,953.9149236679076,2.5172877311706543,0.0 -68800,3.3967783,2.7009094,,,,,,,,,,,,,, -68900,2.9740815,2.6180274,,,,,,,,,,,,,, -69000,3.5935366,2.6630728,,,,,,,,,,,,,, -69100,3.0319495,2.6188407,,,,,,,,,,,,,, -69200,2.9501224,2.682007,,,,,,,,,,,,,, -69300,3.105622,2.6146512,,,,,,,,,,,,,, -69400,3.836483,2.6483517,,,,,,,,,,,,,, -69500,3.3350904,2.7273142,,,,,,,,,,,,,, -69600,2.9087653,2.5938113,,,,,,,,,,,,,, -69700,3.0331523,2.6024733,,,,,,,,,,,,,, -69800,3.2258635,2.5160818,,,,,,,,,,,,,, -69900,3.5480695,2.7094622,,,,,,,,,,,,,, -70000,3.3821843,2.641912,,,,,,,,,,,,,, -70100,3.3370113,2.6875327,,,,,,,,,,,,,, -70200,3.1761022,2.666216,,,,,,,,,,,,,, -70275,,,0.6885961294174194,1.343567132949829,0.6320799589157104,1.6132029294967651,50000.0,0.5059000253677368,2.2884953022003174,10000.0,24029.25793027877,25005.75264811516,24029.25793027877,971.5965132713318,2.5620453357696533,0.0 -70300,2.870858,2.6159317,,,,,,,,,,,,,, -70400,3.1855056,2.6305997,,,,,,,,,,,,,, -70500,3.3384366,2.5643907,,,,,,,,,,,,,, -70600,3.252624,2.6196315,,,,,,,,,,,,,, -70700,3.1211836,2.7168546,,,,,,,,,,,,,, -70800,3.5015275,2.6896665,,,,,,,,,,,,,, -70900,2.988283,2.5828745,,,,,,,,,,,,,, -71000,3.2680502,2.74965,,,,,,,,,,,,,, -71100,3.5498044,2.5383976,,,,,,,,,,,,,, -71200,3.5445952,2.6649258,,,,,,,,,,,,,, -71300,3.4680462,2.6589344,,,,,,,,,,,,,, -71400,3.0568228,2.6665354,,,,,,,,,,,,,, -71500,3.205227,2.7222595,,,,,,,,,,,,,, -71600,3.080204,2.5750666,,,,,,,,,,,,,, -71700,2.8475425,2.6911094,,,,,,,,,,,,,, -71773,,,0.6815010905265808,1.3888893127441406,0.626259982585907,1.645130634307861,50000.0,0.5006000399589539,2.317592144012451,10000.0,24539.311646461487,25533.141258955,24539.311646461487,988.8352327346802,2.6068990230560303,0.0 -71800,3.4847074,2.5848663,,,,,,,,,,,,,, -71900,3.1414638,2.5756173,,,,,,,,,,,,,, -72000,3.8396974,2.6186454,,,,,,,,,,,,,, -72100,3.2773669,2.5848715,,,,,,,,,,,,,, -72200,3.3420339,2.5907114,,,,,,,,,,,,,, -72300,3.0746264,2.7226899,,,,,,,,,,,,,, -72400,3.401657,2.7246315,,,,,,,,,,,,,, -72500,3.144977,2.653998,,,,,,,,,,,,,, -72600,3.29584,2.5991755,,,,,,,,,,,,,, -72700,3.3186145,2.654907,,,,,,,,,,,,,, -72800,3.141932,2.6048803,,,,,,,,,,,,,, -72900,3.2671306,2.6552029,,,,,,,,,,,,,, -73000,3.3714757,2.6396508,,,,,,,,,,,,,, -73100,3.4361184,2.7013388,,,,,,,,,,,,,, -73200,3.7772334,2.7296298,,,,,,,,,,,,,, -73271,,,0.6846500039100647,1.3535865545272827,0.6345799565315247,1.5897456407546997,50000.0,0.5053000450134277,2.248891592025757,10000.0,25049.316918611526,26060.77575063705,25049.316918611526,1006.3702020645142,2.651460647583008,0.0 -73300,3.6475635,2.741942,,,,,,,,,,,,,, -73400,3.4067974,2.551282,,,,,,,,,,,,,, -73500,3.245751,2.6989522,,,,,,,,,,,,,, -73600,2.9155207,2.5434098,,,,,,,,,,,,,, -73700,3.591951,2.6159656,,,,,,,,,,,,,, -73800,3.0198553,2.5626922,,,,,,,,,,,,,, -73900,3.6060154,2.6170776,,,,,,,,,,,,,, -74000,4.0043335,2.6589022,,,,,,,,,,,,,, -74100,2.7860138,2.5314744,,,,,,,,,,,,,, -74200,2.8566191,2.5064723,,,,,,,,,,,,,, -74300,3.1375422,2.6331465,,,,,,,,,,,,,, -74400,3.4271493,2.56454,,,,,,,,,,,,,, -74500,3.7165558,2.6153111,,,,,,,,,,,,,, -74600,3.5753112,2.6586719,,,,,,,,,,,,,, -74700,3.38714,2.5710907,,,,,,,,,,,,,, -74769,,,0.6935586333274841,1.3076993227005005,0.6421399712562561,1.5479406118392944,50000.0,0.5216000080108643,2.2021830081939697,10000.0,25559.289101839066,26588.49595093727,25559.289101839066,1024.0254747867584,2.6931560039520264,0.0 -74800,3.2524755,2.690129,,,,,,,,,,,,,, -74900,3.4239101,2.5883558,,,,,,,,,,,,,, -75000,3.0684474,2.565896,,,,,,,,,,,,,, -75100,3.3635037,2.6266427,,,,,,,,,,,,,, -75200,3.7592566,2.640019,,,,,,,,,,,,,, -75300,3.6313593,2.5941556,,,,,,,,,,,,,, -75400,2.915866,2.615806,,,,,,,,,,,,,, -75500,3.4688075,2.6431496,,,,,,,,,,,,,, -75600,3.2780504,2.7058282,,,,,,,,,,,,,, -75700,3.4224021,2.668443,,,,,,,,,,,,,, -75800,3.5518882,2.655336,,,,,,,,,,,,,, -75900,3.7186642,2.6170866,,,,,,,,,,,,,, -76000,3.2629268,2.6569188,,,,,,,,,,,,,, -76100,3.715735,2.5998278,,,,,,,,,,,,,, -76200,4.0759277,2.5953274,,,,,,,,,,,,,, -76267,,,0.6836535334587097,1.3648991584777832,0.6360799670219421,1.5922093391418457,50000.0,0.5103999972343445,2.274660348892212,10000.0,26069.32606625557,27115.904922246933,26069.32606625557,1041.3056933879852,2.7349095344543457,0.0 -76300,3.1420164,2.6020756,,,,,,,,,,,,,, -76400,3.4967968,2.6994064,,,,,,,,,,,,,, -76500,3.61668,2.535667,,,,,,,,,,,,,, -76600,3.6599963,2.6136894,,,,,,,,,,,,,, -76700,3.736697,2.654033,,,,,,,,,,,,,, -76800,3.707429,2.5965457,,,,,,,,,,,,,, -76900,3.4455256,2.6916702,,,,,,,,,,,,,, -77000,3.215203,2.6022806,,,,,,,,,,,,,, -77100,2.9445677,2.6190145,,,,,,,,,,,,,, -77200,4.038818,2.684803,,,,,,,,,,,,,, -77300,3.900901,2.6400037,,,,,,,,,,,,,, -77400,3.1804245,2.658213,,,,,,,,,,,,,, -77500,3.1913319,2.5944483,,,,,,,,,,,,,, -77600,3.2476728,2.560785,,,,,,,,,,,,,, -77700,3.8733096,2.6796548,,,,,,,,,,,,,, -77765,,,0.7120934128761292,1.2422109842300415,0.638700008392334,1.562541961669922,50000.0,0.5115000009536743,2.2503161430358887,10000.0,26579.532474040985,27643.506959199905,26579.532474040985,1058.6034083366394,2.780672311782837,0.0 -77800,3.5377793,2.6357212,,,,,,,,,,,,,, -77900,3.2965808,2.6337667,,,,,,,,,,,,,, -78000,3.538697,2.616064,,,,,,,,,,,,,, -78100,3.8547935,2.6421137,,,,,,,,,,,,,, -78200,3.0455666,2.5777965,,,,,,,,,,,,,, -78300,3.123042,2.5569885,,,,,,,,,,,,,, -78400,3.9075367,2.6152163,,,,,,,,,,,,,, -78500,3.1633296,2.587667,,,,,,,,,,,,,, -78600,3.4245076,2.6523612,,,,,,,,,,,,,, -78700,3.080539,2.493623,,,,,,,,,,,,,, -78800,3.396883,2.5987449,,,,,,,,,,,,,, -78900,3.3167453,2.608082,,,,,,,,,,,,,, -79000,3.5397224,2.684856,,,,,,,,,,,,,, -79100,3.5183656,2.550535,,,,,,,,,,,,,, -79200,3.6009932,2.7613056,,,,,,,,,,,,,, -79263,,,0.7058752775192261,1.28544819355011,0.6415599584579468,1.569705605506897,50000.0,0.5182999968528748,2.221109390258789,10000.0,27089.686541080475,28170.947011709213,27089.686541080475,1075.7997291088104,2.820523977279663,0.0 -79300,3.7653344,2.6520646,,,,,,,,,,,,,, -79400,3.6122525,2.5357268,,,,,,,,,,,,,, -79500,3.5045583,2.615176,,,,,,,,,,,,,, -79600,3.8757997,2.5508928,,,,,,,,,,,,,, -79700,3.294933,2.5001545,,,,,,,,,,,,,, -79800,3.4274719,2.5462437,,,,,,,,,,,,,, -79900,4.004253,2.5280752,,,,,,,,,,,,,, -80000,3.1984386,2.6536994,,,,,,,,,,,,,, -80100,3.585839,2.7169015,,,,,,,,,,,,,, -80200,2.904095,2.5629659,,,,,,,,,,,,,, -80300,3.508413,2.679244,,,,,,,,,,,,,, -80400,3.6383092,2.564246,,,,,,,,,,,,,, -80500,4.2950487,2.625844,,,,,,,,,,,,,, -80600,3.3486419,2.6065614,,,,,,,,,,,,,, -80700,4.1485286,2.5420022,,,,,,,,,,,,,, -80762,,,0.7010523080825806,1.3063050508499146,0.648140013217926,1.5532654523849487,50000.0,0.5182999968528748,2.2196383476257324,10000.0,27599.76722025872,28698.262020349503,27599.76722025872,1092.9376814365387,2.8669273853302,0.0 -80800,3.2649648,2.526475,,,,,,,,,,,,,, -80900,2.885378,2.5698826,,,,,,,,,,,,,, -81000,3.265397,2.616033,,,,,,,,,,,,,, -81100,3.1060169,2.496035,,,,,,,,,,,,,, -81200,3.370407,2.4743245,,,,,,,,,,,,,, -81300,3.0808547,2.623061,,,,,,,,,,,,,, -81400,3.559712,2.6736298,,,,,,,,,,,,,, -81500,3.2306554,2.7112246,,,,,,,,,,,,,, -81600,3.2398505,2.6012666,,,,,,,,,,,,,, -81700,3.2546601,2.5755658,,,,,,,,,,,,,, -81800,4.3871584,2.6427748,,,,,,,,,,,,,, -81900,3.7203982,2.626729,,,,,,,,,,,,,, -82000,3.1907136,2.5755954,,,,,,,,,,,,,, -82100,3.9737663,2.534657,,,,,,,,,,,,,, -82200,3.6946352,2.588623,,,,,,,,,,,,,, -82260,,,0.6965281963348389,1.2849880456924438,0.64656001329422,1.5245449542999268,50000.0,0.5174000263214111,2.2001779079437256,10000.0,28109.76361966133,29225.50002479553,28109.76361966133,1110.088816165924,2.9066665172576904,0.0 -82300,3.3648646,2.6107235,,,,,,,,,,,,,, -82400,3.2448246,2.5977054,,,,,,,,,,,,,, -82500,3.1954079,2.675427,,,,,,,,,,,,,, -82600,3.3573358,2.5800376,,,,,,,,,,,,,, -82700,3.4763646,2.6322026,,,,,,,,,,,,,, -82800,4.075628,2.5301619,,,,,,,,,,,,,, -82900,3.7940543,2.6494749,,,,,,,,,,,,,, -83000,3.169124,2.6032753,,,,,,,,,,,,,, -83100,3.5245953,2.6016788,,,,,,,,,,,,,, -83200,3.312117,2.4967427,,,,,,,,,,,,,, -83300,3.967401,2.6730556,,,,,,,,,,,,,, -83400,3.7260506,2.599233,,,,,,,,,,,,,, -83500,3.1086402,2.5714679,,,,,,,,,,,,,, -83600,3.4778197,2.5760102,,,,,,,,,,,,,, -83700,3.6174061,2.7187371,,,,,,,,,,,,,, -83758,,,0.7009526491165161,1.2806954383850098,0.6505799889564514,1.5104814767837524,50000.0,0.5253000259399414,2.1855990886688232,10000.0,28619.767454862595,29752.6642165184,28619.767454862595,1127.1571650505066,2.946876287460327,0.0 -83800,3.25884,2.6622272,,,,,,,,,,,,,, -83900,3.7295468,2.6417968,,,,,,,,,,,,,, -84000,3.4836874,2.646504,,,,,,,,,,,,,, -84100,3.529524,2.6226969,,,,,,,,,,,,,, -84200,3.8151402,2.6222968,,,,,,,,,,,,,, -84300,3.5757036,2.6022415,,,,,,,,,,,,,, -84400,3.5710886,2.6186647,,,,,,,,,,,,,, -84500,3.6320384,2.6069725,,,,,,,,,,,,,, -84600,4.1469636,2.5883923,,,,,,,,,,,,,, -84700,3.3588662,2.6182058,,,,,,,,,,,,,, -84800,3.3975868,2.5258982,,,,,,,,,,,,,, -84900,3.3742907,2.6171246,,,,,,,,,,,,,, -85000,3.3773339,2.5892534,,,,,,,,,,,,,, -85100,3.2505078,2.6667252,,,,,,,,,,,,,, -85200,3.4079127,2.580149,,,,,,,,,,,,,, -85257,,,0.7001753449440002,1.2842198610305786,0.6523399949073792,1.5107489824295044,50000.0,0.5213000178337097,2.2066726684570312,10000.0,29129.908967256542,30280.324359178543,29129.908967256542,1144.5813839435575,2.9912562370300293,0.0 -85300,3.4166076,2.55632,,,,,,,,,,,,,, -85400,3.7727005,2.5288763,,,,,,,,,,,,,, -85500,3.2501657,2.5344646,,,,,,,,,,,,,, -85600,3.1196976,2.5132523,,,,,,,,,,,,,, -85700,4.3387938,2.5945535,,,,,,,,,,,,,, -85800,3.65514,2.565853,,,,,,,,,,,,,, -85900,3.3689008,2.575963,,,,,,,,,,,,,, -86000,3.5930774,2.557194,,,,,,,,,,,,,, -86100,3.3535414,2.6331065,,,,,,,,,,,,,, -86200,4.4769897,2.6680121,,,,,,,,,,,,,, -86300,4.1951494,2.5006335,,,,,,,,,,,,,, -86400,3.2341151,2.5826554,,,,,,,,,,,,,, -86500,3.7900383,2.5907161,,,,,,,,,,,,,, -86600,3.665479,2.594403,,,,,,,,,,,,,, -86700,3.5542305,2.6381261,,,,,,,,,,,,,, -86755,,,0.7459940910339355,1.084036946296692,0.6534799933433533,1.4911538362503052,50000.0,0.5329000353813171,2.135961532592773,10000.0,29639.830031633377,30807.24199461937,29639.830031633377,1161.4845032691956,3.0346953868865967,0.0 -86800,3.4865034,2.5320358,,,,,,,,,,,,,, -86900,3.4968958,2.6198032,,,,,,,,,,,,,, -87000,3.556732,2.5655456,,,,,,,,,,,,,, -87100,3.5912986,2.559127,,,,,,,,,,,,,, -87200,3.529425,2.5332665,,,,,,,,,,,,,, -87300,4.3310165,2.5462763,,,,,,,,,,,,,, -87400,3.305593,2.60918,,,,,,,,,,,,,, -87500,3.354409,2.5049586,,,,,,,,,,,,,, -87600,3.780889,2.5879734,,,,,,,,,,,,,, -87700,4.098467,2.5733466,,,,,,,,,,,,,, -87800,3.6805005,2.5425465,,,,,,,,,,,,,, -87900,3.7915475,2.484661,,,,,,,,,,,,,, -88000,3.6306393,2.4429305,,,,,,,,,,,,,, -88100,3.6046374,2.5159702,,,,,,,,,,,,,, -88200,3.319458,2.5650616,,,,,,,,,,,,,, -88253,,,0.7254464030265808,1.1623930931091309,0.6578199863433838,1.4649604558944702,50000.0,0.5353000164031982,2.1265015602111816,10000.0,30150.005770921707,31334.62922167778,30150.005770921707,1178.5891389846802,3.0897841453552246,0.0 -88300,4.0319214,2.5392034,,,,,,,,,,,,,, -88400,3.36676,2.591098,,,,,,,,,,,,,, -88500,3.6531754,2.5605578,,,,,,,,,,,,,, -88600,3.456368,2.5425894,,,,,,,,,,,,,, -88700,3.5900493,2.5748158,,,,,,,,,,,,,, -88800,3.6484656,2.452353,,,,,,,,,,,,,, -88900,3.5261838,2.60644,,,,,,,,,,,,,, -89000,3.6199574,2.5313942,,,,,,,,,,,,,, -89100,3.4824557,2.453479,,,,,,,,,,,,,, -89200,4.113929,2.539363,,,,,,,,,,,,,, -89300,3.491589,2.5353084,,,,,,,,,,,,,, -89400,3.8281133,2.5067542,,,,,,,,,,,,,, -89500,3.5225625,2.4758909,,,,,,,,,,,,,, -89600,3.517563,2.5618367,,,,,,,,,,,,,, -89700,3.4035318,2.4558406,,,,,,,,,,,,,, -89751,,,0.7215999364852905,1.1721854209899902,0.6617000102996826,1.4494667053222656,50000.0,0.5382000207901001,2.107398748397827,10000.0,30660.09109258652,31862.09580183029,30660.09109258652,1195.8721101284027,3.136313915252685,0.0 -89800,4.227507,2.4827356,,,,,,,,,,,,,, -89900,3.9507134,2.6311731,,,,,,,,,,,,,, -90000,3.3609264,2.4896007,,,,,,,,,,,,,, -90100,4.0353303,2.609331,,,,,,,,,,,,,, -90200,3.3924594,2.5399623,,,,,,,,,,,,,, -90300,3.896539,2.6083767,,,,,,,,,,,,,, -90400,3.3278763,2.5156927,,,,,,,,,,,,,, -90500,3.798355,2.481568,,,,,,,,,,,,,, -90600,3.6368375,2.4939482,,,,,,,,,,,,,, -90700,3.493658,2.6367185,,,,,,,,,,,,,, -90800,3.4439247,2.544973,,,,,,,,,,,,,, -90900,3.8119495,2.493099,,,,,,,,,,,,,, -91000,3.758817,2.638841,,,,,,,,,,,,,, -91100,4.106591,2.5560534,,,,,,,,,,,,,, -91200,3.8545954,2.493985,,,,,,,,,,,,,, -91249,,,0.7166972160339355,1.2218282222747805,0.6592199802398682,1.4793634414672852,50000.0,0.5379000306129456,2.1280786991119385,10000.0,31170.067274332047,32389.319952964783,31170.067274332047,1213.0246744155884,3.181325912475586,0.0 -91300,4.1799145,2.5072005,,,,,,,,,,,,,, -91400,3.5930257,2.5226686,,,,,,,,,,,,,, -91500,3.6073375,2.459887,,,,,,,,,,,,,, -91600,3.791826,2.5308692,,,,,,,,,,,,,, -91700,3.7651725,2.5299273,,,,,,,,,,,,,, -91800,3.580981,2.595308,,,,,,,,,,,,,, -91900,3.5686612,2.5232582,,,,,,,,,,,,,, -92000,3.3023963,2.585378,,,,,,,,,,,,,, -92100,3.852619,2.6995857,,,,,,,,,,,,,, -92200,3.3781643,2.520347,,,,,,,,,,,,,, -92300,3.6550713,2.4756265,,,,,,,,,,,,,, -92400,3.4405608,2.5127795,,,,,,,,,,,,,, -92500,4.0829525,2.5087273,,,,,,,,,,,,,, -92600,3.7756712,2.5515435,,,,,,,,,,,,,, -92700,3.3383358,2.4685075,,,,,,,,,,,,,, -92747,,,0.7115951776504517,1.2422090768814087,0.6555399894714355,1.498248815536499,50000.0,0.5285000205039978,2.168752908706665,10000.0,31679.97949957848,32916.61490535736,31679.97949957848,1230.3129467964172,3.225436210632324,0.0 -92800,3.5604002,2.5635545,,,,,,,,,,,,,, -92900,4.1843457,2.5342019,,,,,,,,,,,,,, -93000,3.928497,2.5386386,,,,,,,,,,,,,, -93100,3.6770363,2.6040807,,,,,,,,,,,,,, -93200,3.9502695,2.5025082,,,,,,,,,,,,,, -93300,3.695036,2.5609808,,,,,,,,,,,,,, -93400,3.7873225,2.4447868,,,,,,,,,,,,,, -93500,3.4838047,2.5296872,,,,,,,,,,,,,, -93600,3.5886343,2.490178,,,,,,,,,,,,,, -93700,3.6196208,2.5693102,,,,,,,,,,,,,, -93800,3.5562582,2.583605,,,,,,,,,,,,,, -93900,3.7972796,2.479479,,,,,,,,,,,,,, -94000,3.5519857,2.5662982,,,,,,,,,,,,,, -94100,3.554559,2.5228333,,,,,,,,,,,,,, -94200,4.1590505,2.5948482,,,,,,,,,,,,,, -94245,,,0.7129902839660645,1.244142770767212,0.6612399816513062,1.4770818948745728,50000.0,0.5364000201225281,2.136128187179565,10000.0,32189.99800372124,33443.94494795799,32189.99800372124,1247.537139415741,3.262047529220581,0.0 -94300,4.0175886,2.5601342,,,,,,,,,,,,,, -94400,3.5066762,2.4292643,,,,,,,,,,,,,, -94500,4.0578218,2.6176095,,,,,,,,,,,,,, -94600,4.0344415,2.5349271,,,,,,,,,,,,,, -94700,3.2597291,2.4942439,,,,,,,,,,,,,, -94800,3.620897,2.4879148,,,,,,,,,,,,,, -94900,3.9393134,2.5744116,,,,,,,,,,,,,, -95000,3.6596472,2.4973981,,,,,,,,,,,,,, -95100,4.180866,2.4351094,,,,,,,,,,,,,, -95200,3.794833,2.4899175,,,,,,,,,,,,,, -95300,3.5414026,2.4738598,,,,,,,,,,,,,, -95400,3.9832292,2.4993007,,,,,,,,,,,,,, -95500,3.9386392,2.619154,,,,,,,,,,,,,, -95600,3.7852857,2.5195749,,,,,,,,,,,,,, -95700,3.6908677,2.507448,,,,,,,,,,,,,, -95743,,,0.742586076259613,1.1307847499847412,0.669219970703125,1.4510703086853027,50000.0,0.5428000092506409,2.1002559661865234,10000.0,32699.949419021606,33971.354709625244,32699.949419021606,1264.8968999385834,3.308539628982544,0.0 -95800,3.4206424,2.481874,,,,,,,,,,,,,, -95900,4.3089314,2.6360111,,,,,,,,,,,,,, -96000,4.2955594,2.5731902,,,,,,,,,,,,,, -96100,4.2837405,2.469309,,,,,,,,,,,,,, -96200,4.028129,2.5639071,,,,,,,,,,,,,, -96300,3.5254107,2.4753363,,,,,,,,,,,,,, -96400,3.5486228,2.4687982,,,,,,,,,,,,,, -96500,3.9107575,2.5227368,,,,,,,,,,,,,, -96600,3.6675453,2.498609,,,,,,,,,,,,,, -96700,3.7262697,2.4999516,,,,,,,,,,,,,, -96800,3.786572,2.4711246,,,,,,,,,,,,,, -96900,3.4858932,2.5624495,,,,,,,,,,,,,, -97000,3.8261604,2.513249,,,,,,,,,,,,,, -97100,3.867444,2.5241368,,,,,,,,,,,,,, -97200,4.7355776,2.6342134,,,,,,,,,,,,,, -97241,,,0.7416892647743225,1.120581030845642,0.6673399806022644,1.463154911994934,50000.0,0.5376999974250793,2.125854253768921,10000.0,33210.114094257355,34498.70940423012,33210.114094257355,1281.9905200004578,3.355043172836304,0.0 -97300,4.2317657,2.4924984,,,,,,,,,,,,,, -97400,3.7060616,2.5404496,,,,,,,,,,,,,, -97500,3.5666418,2.4447014,,,,,,,,,,,,,, -97600,3.748178,2.4713237,,,,,,,,,,,,,, -97700,4.50033,2.4853723,,,,,,,,,,,,,, -97800,3.7818692,2.589375,,,,,,,,,,,,,, -97900,3.6692708,2.507791,,,,,,,,,,,,,, -98000,3.8162742,2.5425718,,,,,,,,,,,,,, -98100,3.7944489,2.508881,,,,,,,,,,,,,, -98200,3.727445,2.4646068,,,,,,,,,,,,,, -98300,3.8151968,2.4611263,,,,,,,,,,,,,, -98400,3.720037,2.579615,,,,,,,,,,,,,, -98500,4.210934,2.4877656,,,,,,,,,,,,,, -98600,3.7343686,2.5084236,,,,,,,,,,,,,, -98700,3.7658265,2.5820134,,,,,,,,,,,,,, -98739,,,0.7276586294174194,1.1531447172164917,0.666920006275177,1.4372801780700684,50000.0,0.5424000024795532,2.1051735877990723,10000.0,33720.06629896164,35025.87611365318,33720.06629896164,1299.1062581539154,3.401531219482422,0.0 -98800,4.206744,2.497863,,,,,,,,,,,,,, -98900,3.5002868,2.4865484,,,,,,,,,,,,,, -99000,3.5572295,2.5331721,,,,,,,,,,,,,, -99100,3.526984,2.3830645,,,,,,,,,,,,,, -99200,3.8997731,2.573385,,,,,,,,,,,,,, -99300,4.084818,2.503231,,,,,,,,,,,,,, -99400,3.5962315,2.4592218,,,,,,,,,,,,,, -99500,3.7992196,2.5054579,,,,,,,,,,,,,, -99600,3.9318554,2.4931438,,,,,,,,,,,,,, -99700,3.9112546,2.4737315,,,,,,,,,,,,,, -99800,3.6654778,2.3849878,,,,,,,,,,,,,, -99900,3.6337965,2.5545745,,,,,,,,,,,,,, -100000,3.7313342,2.4691129,,,,,,,,,,,,,, -100100,4.191308,2.5178893,,,,,,,,,,,,,, -100200,3.6839237,2.3622427,,,,,,,,,,,,,, -100237,,,0.7302096486091614,1.1675571203231812,0.6653199791908264,1.4513657093048096,50000.0,0.5361000299453735,2.1154842376708984,10000.0,34229.99035668373,35552.84500479698,34229.99035668373,1316.049479007721,3.4516913890838623,0.0 -100300,4.15539,2.601591,,,,,,,,,,,,,, -100400,4.2282333,2.4786448,,,,,,,,,,,,,, -100500,3.8860283,2.5240157,,,,,,,,,,,,,, -100600,3.3448322,2.4131181,,,,,,,,,,,,,, -100700,3.9453075,2.5935383,,,,,,,,,,,,,, -100800,4.2331204,2.421222,,,,,,,,,,,,,, -100900,3.758955,2.5495553,,,,,,,,,,,,,, -101000,4.0726724,2.6205215,,,,,,,,,,,,,, -101100,3.703353,2.4487634,,,,,,,,,,,,,, -101200,3.8373375,2.491975,,,,,,,,,,,,,, -101300,4.128204,2.4490852,,,,,,,,,,,,,, -101400,3.8256044,2.5231082,,,,,,,,,,,,,, -101500,3.8619719,2.4530232,,,,,,,,,,,,,, -101600,3.8024387,2.4538164,,,,,,,,,,,,,, -101700,3.6810057,2.4081044,,,,,,,,,,,,,, -101735,,,0.7242107391357422,1.1790367364883425,0.6686800122261047,1.4343260526657104,50000.0,0.5398000478744507,2.113425970077514,10000.0,34739.97491002083,36080.01940727234,34739.97491002083,1333.1427223682404,3.4981353282928467,0.0 -101800,4.029446,2.476125,,,,,,,,,,,,,, -101900,3.9373357,2.6009016,,,,,,,,,,,,,, -102000,4.6321282,2.543567,,,,,,,,,,,,,, -102100,4.242277,2.4725215,,,,,,,,,,,,,, -102200,3.5441167,2.4415712,,,,,,,,,,,,,, -102300,3.7070358,2.4008722,,,,,,,,,,,,,, -102400,3.884517,2.4094667,,,,,,,,,,,,,, -102500,4.179984,2.5040727,,,,,,,,,,,,,, -102600,4.103744,2.5245671,,,,,,,,,,,,,, -102700,3.5369396,2.351575,,,,,,,,,,,,,, -102800,3.585403,2.4565613,,,,,,,,,,,,,, -102900,3.5716481,2.5178742,,,,,,,,,,,,,, -103000,4.1022453,2.5544045,,,,,,,,,,,,,, -103100,4.081295,2.4449248,,,,,,,,,,,,,, -103200,4.8205357,2.5099308,,,,,,,,,,,,,, -103233,,,0.7329798936843872,1.1520893573760986,0.676859974861145,1.4166473150253296,50000.0,0.5520000457763672,2.0522189140319824,10000.0,35250.05642032623,36607.30830931664,35250.05642032623,1350.248485803604,3.549124956130981,0.0 -103300,3.8673432,2.5191765,,,,,,,,,,,,,, -103400,3.7835288,2.4253106,,,,,,,,,,,,,, -103500,4.26477,2.4420717,,,,,,,,,,,,,, -103600,4.643629,2.5871441,,,,,,,,,,,,,, -103700,3.7418888,2.4610777,,,,,,,,,,,,,, -103800,3.6345217,2.4092884,,,,,,,,,,,,,, -103900,4.268362,2.433514,,,,,,,,,,,,,, -104000,4.4415007,2.4573982,,,,,,,,,,,,,, -104100,3.6924515,2.5357134,,,,,,,,,,,,,, -104200,3.584972,2.427476,,,,,,,,,,,,,, -104300,4.50327,2.4271827,,,,,,,,,,,,,, -104400,4.4094515,2.4478912,,,,,,,,,,,,,, -104500,3.8041782,2.512682,,,,,,,,,,,,,, -104600,3.7500422,2.4637692,,,,,,,,,,,,,, -104700,4.1609344,2.4974775,,,,,,,,,,,,,, -104731,,,0.7315847873687744,1.141116499900818,0.6702199578285217,1.417891502380371,50000.0,0.5373000502586365,2.095768928527832,10000.0,35759.99800825119,37134.51687049866,35759.99800825119,1367.418380498886,3.5954058170318604,0.0 -104800,4.296969,2.5144026,,,,,,,,,,,,,, -104900,4.5804434,2.4708562,,,,,,,,,,,,,, -105000,5.0007787,2.4398348,,,,,,,,,,,,,, -105100,3.8907106,2.4603882,,,,,,,,,,,,,, -105200,3.9789572,2.3659248,,,,,,,,,,,,,, -105300,3.71746,2.4375827,,,,,,,,,,,,,, -105400,4.2033105,2.5358386,,,,,,,,,,,,,, -105500,3.971474,2.4934242,,,,,,,,,,,,,, -105600,3.8687503,2.4180548,,,,,,,,,,,,,, -105700,4.696437,2.5978599,,,,,,,,,,,,,, -105800,4.1483064,2.5039713,,,,,,,,,,,,,, -105900,3.7668097,2.4279027,,,,,,,,,,,,,, -106000,4.5866013,2.4489508,,,,,,,,,,,,,, -106100,4.26761,2.4652503,,,,,,,,,,,,,, -106200,4.2848516,2.548004,,,,,,,,,,,,,, -106229,,,0.7540457248687744,1.072227954864502,0.6693599820137024,1.442069172859192,50000.0,0.5469000339508057,2.085392951965332,10000.0,36270.085739851,37661.89115190506,36270.085739851,1384.609278678894,3.6411328315734863,0.0 -106300,4.0985265,2.4643664,,,,,,,,,,,,,, -106400,4.688437,2.5188446,,,,,,,,,,,,,, -106500,4.283196,2.4622717,,,,,,,,,,,,,, -106600,4.28493,2.5204053,,,,,,,,,,,,,, -106700,4.6933503,2.3726175,,,,,,,,,,,,,, -106800,3.9864292,2.4676514,,,,,,,,,,,,,, -106900,3.6456707,2.3754225,,,,,,,,,,,,,, -107000,4.0339665,2.5472615,,,,,,,,,,,,,, -107100,4.2037706,2.487712,,,,,,,,,,,,,, -107200,3.8908129,2.3985023,,,,,,,,,,,,,, -107300,3.9288108,2.5341141,,,,,,,,,,,,,, -107400,4.0484395,2.4670513,,,,,,,,,,,,,, -107500,3.8997247,2.3944054,,,,,,,,,,,,,, -107600,4.068571,2.573569,,,,,,,,,,,,,, -107700,3.815566,2.3492827,,,,,,,,,,,,,, -107727,,,0.751973032951355,1.068299651145935,0.6773999929428101,1.4052557945251465,50000.0,0.5508000254631042,2.0545578002929688,10000.0,36780.10208392143,38188.91479277611,36780.10208392143,1401.521045923233,3.6874375343322754,0.0 -107800,3.8335032,2.439652,,,,,,,,,,,,,, -107900,3.9504273,2.449082,,,,,,,,,,,,,, -108000,4.535797,2.5541513,,,,,,,,,,,,,, -108100,4.1779413,2.401799,,,,,,,,,,,,,, -108200,4.5417776,2.4750538,,,,,,,,,,,,,, -108300,3.9854105,2.433934,,,,,,,,,,,,,, -108400,4.423403,2.486475,,,,,,,,,,,,,, -108500,4.226835,2.4943123,,,,,,,,,,,,,, -108600,4.3779783,2.4101,,,,,,,,,,,,,, -108700,3.7550662,2.4134243,,,,,,,,,,,,,, -108800,4.179275,2.441666,,,,,,,,,,,,,, -108900,4.0926313,2.4729555,,,,,,,,,,,,,, -109000,3.6592681,2.4342065,,,,,,,,,,,,,, -109100,4.286437,2.4041345,,,,,,,,,,,,,, -109200,3.8341115,2.3509483,,,,,,,,,,,,,, -109225,,,0.7469507455825806,1.086888074874878,0.6785999536514282,1.395654797554016,50000.0,0.5574000477790833,2.0311355590820312,10000.0,37290.17399263382,38716.25345778465,37290.17399263382,1418.6850700378418,3.740187406539917,0.0 -109300,4.294342,2.3612356,,,,,,,,,,,,,, -109400,4.1998076,2.4272738,,,,,,,,,,,,,, -109500,4.3861985,2.4811668,,,,,,,,,,,,,, -109600,4.1492715,2.4432852,,,,,,,,,,,,,, -109700,4.5464377,2.4589186,,,,,,,,,,,,,, -109800,4.115286,2.3826869,,,,,,,,,,,,,, -109900,3.9969673,2.4701536,,,,,,,,,,,,,, -110000,4.5865126,2.465899,,,,,,,,,,,,,, -110100,3.9432182,2.4320679,,,,,,,,,,,,,, -110200,3.9628053,2.4370012,,,,,,,,,,,,,, -110300,4.21677,2.5131912,,,,,,,,,,,,,, -110400,4.462194,2.4676125,,,,,,,,,,,,,, -110500,4.137765,2.4427097,,,,,,,,,,,,,, -110600,3.798871,2.3720775,,,,,,,,,,,,,, -110700,4.630963,2.468279,,,,,,,,,,,,,, -110723,,,0.7480069994926453,1.080083966255188,0.6832799911499023,1.3639898300170898,50000.0,0.5574000477790833,2.017634391784668,10000.0,37800.17114710808,39243.452078580856,37800.17114710808,1435.787403345108,3.788019895553589,0.0 -110800,4.5756655,2.3987036,,,,,,,,,,,,,, -110900,4.4450173,2.4433715,,,,,,,,,,,,,, -111000,4.3944187,2.5052557,,,,,,,,,,,,,, -111100,3.9448645,2.4132407,,,,,,,,,,,,,, -111200,4.068756,2.4035668,,,,,,,,,,,,,, -111300,4.050548,2.506405,,,,,,,,,,,,,, -111400,4.448647,2.3714836,,,,,,,,,,,,,, -111500,4.109907,2.3667793,,,,,,,,,,,,,, -111600,4.0245233,2.421273,,,,,,,,,,,,,, -111700,4.2797637,2.42481,,,,,,,,,,,,,, -111800,4.180863,2.368239,,,,,,,,,,,,,, -111900,4.1529293,2.3529062,,,,,,,,,,,,,, -112000,3.9948099,2.4182694,,,,,,,,,,,,,, -112100,4.489428,2.3667798,,,,,,,,,,,,,, -112200,3.9661956,2.3320014,,,,,,,,,,,,,, -112221,,,0.7472297549247742,1.0725810527801514,0.6823199987411499,1.363050103187561,50000.0,0.5624000430107117,2.006924152374268,10000.0,38310.2585234642,39770.78035974503,38310.2585234642,1452.9339735507965,3.8331050872802734,0.0 -112300,4.523848,2.4211612,,,,,,,,,,,,,, -112400,4.1895704,2.4481425,,,,,,,,,,,,,, -112500,4.79361,2.4404337,,,,,,,,,,,,,, -112600,4.5024567,2.4310813,,,,,,,,,,,,,, -112700,4.3161764,2.4369144,,,,,,,,,,,,,, -112800,3.9387956,2.4205158,,,,,,,,,,,,,, -112900,4.408013,2.4044583,,,,,,,,,,,,,, -113000,4.299967,2.371782,,,,,,,,,,,,,, -113100,4.031404,2.3826976,,,,,,,,,,,,,, -113200,4.215103,2.327657,,,,,,,,,,,,,, -113300,3.98394,2.3347676,,,,,,,,,,,,,, -113400,4.584251,2.413117,,,,,,,,,,,,,, -113500,4.345895,2.4510105,,,,,,,,,,,,,, -113600,4.031479,2.377233,,,,,,,,,,,,,, -113700,4.9071317,2.3378701,,,,,,,,,,,,,, -113719,,,0.7479073405265808,1.0708001852035522,0.6868199706077576,1.3397966623306274,50000.0,0.5601000189781189,2.016330003738404,10000.0,38820.35129117966,40298.13475847244,38820.35129117966,1470.0948634147644,3.881237268447876,0.0 -113800,4.122278,2.3999853,,,,,,,,,,,,,, -113900,4.3414297,2.3914886,,,,,,,,,,,,,, -114000,3.94749,2.4563196,,,,,,,,,,,,,, -114100,4.645937,2.4088478,,,,,,,,,,,,,, -114200,3.8022149,2.3655722,,,,,,,,,,,,,, -114300,3.752739,2.350468,,,,,,,,,,,,,, -114400,4.624102,2.417339,,,,,,,,,,,,,, -114500,4.03428,2.4243362,,,,,,,,,,,,,, -114600,4.251709,2.4976115,,,,,,,,,,,,,, -114700,4.621498,2.458713,,,,,,,,,,,,,, -114800,4.3702,2.4469042,,,,,,,,,,,,,, -114900,4.1983333,2.4462152,,,,,,,,,,,,,, -115000,4.135072,2.4154048,,,,,,,,,,,,,, -115100,4.166737,2.2952003,,,,,,,,,,,,,, -115200,4.041244,2.3591197,,,,,,,,,,,,,, -115217,,,0.7872887253761292,0.9008607268333436,0.6908800005912781,1.3157756328582764,50000.0,0.5639000535011292,1.989737153053284,10000.0,39330.478276491165,40825.58081579208,39330.478276491165,1487.3175942897797,3.92743730545044,0.0 -115300,4.7884583,2.3613431,,,,,,,,,,,,,, -115400,4.507634,2.3493333,,,,,,,,,,,,,, -115500,4.0897717,2.3028035,,,,,,,,,,,,,, -115600,4.845076,2.4377742,,,,,,,,,,,,,, -115700,4.454658,2.3098278,,,,,,,,,,,,,, -115800,3.9331255,2.358732,,,,,,,,,,,,,, -115900,4.633487,2.2993698,,,,,,,,,,,,,, -116000,4.296332,2.3717225,,,,,,,,,,,,,, -116100,3.9429715,2.3970873,,,,,,,,,,,,,, -116200,4.3655396,2.3601594,,,,,,,,,,,,,, -116300,4.073463,2.4129052,,,,,,,,,,,,,, -116400,4.056383,2.36663,,,,,,,,,,,,,, -116500,4.982603,2.3405285,,,,,,,,,,,,,, -116600,4.370494,2.420289,,,,,,,,,,,,,, -116700,4.901781,2.3526993,,,,,,,,,,,,,, -116716,,,0.7737165093421936,0.961524486541748,0.6925399899482727,1.3126429319381714,50000.0,0.5701000094413757,1.9820610284805296,10000.0,39840.57607722282,41353.06807017326,39840.57607722282,1504.6061329841614,3.976889371871948,0.0 -116800,4.7467575,2.4081075,,,,,,,,,,,,,, -116900,4.5324416,2.4316034,,,,,,,,,,,,,, -117000,4.199419,2.3326151,,,,,,,,,,,,,, -117100,4.806034,2.2840161,,,,,,,,,,,,,, -117200,4.5005703,2.3328795,,,,,,,,,,,,,, -117300,4.2779846,2.2854676,,,,,,,,,,,,,, -117400,4.6177673,2.34079,,,,,,,,,,,,,, -117500,4.4606695,2.3785748,,,,,,,,,,,,,, -117600,4.403537,2.410348,,,,,,,,,,,,,, -117700,4.2187605,2.382256,,,,,,,,,,,,,, -117800,4.6697764,2.358869,,,,,,,,,,,,,, -117900,4.336558,2.335685,,,,,,,,,,,,,, -118000,4.4034615,2.3265123,,,,,,,,,,,,,, -118100,4.477179,2.3031251,,,,,,,,,,,,,, -118200,4.039165,2.3774178,,,,,,,,,,,,,, -118214,,,0.7669204473495483,1.010676383972168,0.6954599618911743,1.3212053775787354,50000.0,0.5715000033378601,1.9586877822875977,10000.0,40350.64465737343,41880.47035765648,40350.64465737343,1521.836299419403,4.030749559402466,0.0 -118300,4.908731,2.4484773,,,,,,,,,,,,,, -118400,4.324087,2.3384984,,,,,,,,,,,,,, -118500,4.3165493,2.338848,,,,,,,,,,,,,, -118600,4.272822,2.400784,,,,,,,,,,,,,, -118700,5.287309,2.3823767,,,,,,,,,,,,,, -118800,4.6422706,2.394958,,,,,,,,,,,,,, -118900,4.361475,2.3713388,,,,,,,,,,,,,, -119000,4.1477284,2.3392448,,,,,,,,,,,,,, -119100,4.783783,2.3889165,,,,,,,,,,,,,, -119200,4.7517543,2.3484359,,,,,,,,,,,,,, -119300,4.517587,2.432844,,,,,,,,,,,,,, -119400,4.6339207,2.4408562,,,,,,,,,,,,,, -119500,4.624273,2.3453765,,,,,,,,,,,,,, -119600,4.4752793,2.3089244,,,,,,,,,,,,,, -119700,4.5134373,2.2963932,,,,,,,,,,,,,, -119712,,,0.7684550285339355,1.0029422044754028,0.6985799670219421,1.310947299003601,50000.0,0.5733000040054321,1.962609052658081,10000.0,40860.65657663345,42407.98349690437,40860.65657663345,1539.2372624874115,4.080804824829102,0.0 -119800,4.843474,2.4268856,,,,,,,,,,,,,, -119900,4.8573456,2.3690488,,,,,,,,,,,,,, -120000,4.352327,2.3472576,,,,,,,,,,,,,, -120100,4.047625,2.3580873,,,,,,,,,,,,,, -120200,4.4243608,2.3020093,,,,,,,,,,,,,, -120300,4.474461,2.3653188,,,,,,,,,,,,,, -120400,4.582506,2.4156253,,,,,,,,,,,,,, -120500,4.686265,2.4803534,,,,,,,,,,,,,, -120600,4.2949185,2.3931248,,,,,,,,,,,,,, -120700,4.2471514,2.323388,,,,,,,,,,,,,, -120800,4.3121805,2.3383002,,,,,,,,,,,,,, -120900,4.5162287,2.319245,,,,,,,,,,,,,, -121000,4.5516076,2.3031301,,,,,,,,,,,,,, -121100,5.2627606,2.315154,,,,,,,,,,,,,, -121200,4.7863393,2.3294861,,,,,,,,,,,,,, -121210,,,0.7650868892669678,0.9947850704193116,0.6988399624824524,1.2936506271362305,50000.0,0.5770000219345093,1.9499543905258176,10000.0,41370.574806690216,42935.32121896744,41370.574806690216,1556.559273481369,4.127295732498169,0.0 -121300,3.9610353,2.2583995,,,,,,,,,,,,,, -121400,5.285324,2.455407,,,,,,,,,,,,,, -121500,4.653036,2.3059678,,,,,,,,,,,,,, -121600,4.3799763,2.3083684,,,,,,,,,,,,,, -121700,4.1560397,2.375051,,,,,,,,,,,,,, -121800,4.472558,2.411985,,,,,,,,,,,,,, -121900,4.981989,2.354663,,,,,,,,,,,,,, -122000,4.3597126,2.3532162,,,,,,,,,,,,,, -122100,4.533781,2.2868228,,,,,,,,,,,,,, -122200,4.6511745,2.2954931,,,,,,,,,,,,,, -122300,4.9339356,2.4040842,,,,,,,,,,,,,, -122400,4.5199666,2.3102736,,,,,,,,,,,,,, -122500,5.6033897,2.363463,,,,,,,,,,,,,, -122600,4.975396,2.3539042,,,,,,,,,,,,,, -122700,4.6107125,2.3931932,,,,,,,,,,,,,, -122708,,,0.7655652165412903,1.012738823890686,0.6997199654579163,1.3031195402145386,50000.0,0.5733000040054321,1.966423749923706,10000.0,41880.55111408234,43462.52871155739,41880.55111408234,1573.6708698272705,4.196326732635498,0.0 -122800,4.9661956,2.287738,,,,,,,,,,,,,, -122900,4.3005376,2.3607523,,,,,,,,,,,,,, -123000,4.828794,2.4045584,,,,,,,,,,,,,, -123100,4.621441,2.383388,,,,,,,,,,,,,, -123200,4.943218,2.32366,,,,,,,,,,,,,, -123300,4.5510926,2.2320516,,,,,,,,,,,,,, -123400,4.6156077,2.341472,,,,,,,,,,,,,, -123500,4.418789,2.3007295,,,,,,,,,,,,,, -123600,4.4579134,2.3401508,,,,,,,,,,,,,, -123700,4.8098764,2.3784018,,,,,,,,,,,,,, -123800,4.8314695,2.3747044,,,,,,,,,,,,,, -123900,4.4605074,2.3199704,,,,,,,,,,,,,, -124000,5.1004963,2.311842,,,,,,,,,,,,,, -124100,4.571673,2.3470883,,,,,,,,,,,,,, -124200,4.5554013,2.2926083,,,,,,,,,,,,,, -124206,,,0.7978515625,0.8683147430419922,0.7073799967765808,1.261612057685852,50000.0,0.5795000195503235,1.902595043182373,10000.0,42390.75530362129,43989.89999890328,42390.75530362129,1590.7367997169497,4.2473227977752686,0.0 -124300,4.4287915,2.269279,,,,,,,,,,,,,, -124400,4.570364,2.3103144,,,,,,,,,,,,,, -124500,4.4850984,2.257758,,,,,,,,,,,,,, -124600,4.6245465,2.2653463,,,,,,,,,,,,,, -124700,4.879073,2.2738426,,,,,,,,,,,,,, -124800,4.7360125,2.177335,,,,,,,,,,,,,, -124900,4.587414,2.2608225,,,,,,,,,,,,,, -125000,4.6608896,2.2272549,,,,,,,,,,,,,, -125100,4.5166864,2.2933035,,,,,,,,,,,,,, -125200,4.5411115,2.366405,,,,,,,,,,,,,, -125300,5.195542,2.2735689,,,,,,,,,,,,,, -125400,4.592691,2.3777518,,,,,,,,,,,,,, -125500,4.532978,2.4250937,,,,,,,,,,,,,, -125600,5.5689735,2.3440874,,,,,,,,,,,,,, -125700,4.6821823,2.2219462,,,,,,,,,,,,,, -125705,,,0.7971938848495483,0.8802825808525085,0.7064599990844727,1.266579031944275,50000.0,0.5808000564575195,1.9092512130737305,10000.0,42900.89259982109,44517.35197234154,42900.89259982109,1607.9538469314575,4.293523788452148,0.0 -125800,4.9897766,2.391163,,,,,,,,,,,,,, -125900,4.9229965,2.3070333,,,,,,,,,,,,,, -126000,5.053068,2.3719049,,,,,,,,,,,,,, -126100,5.27184,2.2653608,,,,,,,,,,,,,, -126200,4.7285995,2.3307917,,,,,,,,,,,,,, -126300,5.021414,2.378243,,,,,,,,,,,,,, -126400,5.5132155,2.3001883,,,,,,,,,,,,,, -126500,4.698857,2.215332,,,,,,,,,,,,,, -126600,4.7569385,2.3555975,,,,,,,,,,,,,, -126700,5.112481,2.3073115,,,,,,,,,,,,,, -126800,6.2598815,2.3150187,,,,,,,,,,,,,, -126900,4.820547,2.3821046,,,,,,,,,,,,,, -127000,4.5564837,2.2579136,,,,,,,,,,,,,, -127100,4.8077717,2.3369148,,,,,,,,,,,,,, -127200,5.3198056,2.2650661,,,,,,,,,,,,,, -127202,,,0.7934072017669678,0.8894890546798706,0.7084199786186218,1.2445156574249268,50000.0,0.5782000422477722,1.8972203731536863,10000.0,43410.79142928124,45044.73135638237,43410.79142928124,1625.3317823410034,4.347227096557617,0.0 -127300,4.880734,2.2998188,,,,,,,,,,,,,, -127400,4.941797,2.2492697,,,,,,,,,,,,,, -127500,4.776217,2.1578345,,,,,,,,,,,,,, -127600,4.668805,2.3128011,,,,,,,,,,,,,, -127700,5.101534,2.3381014,,,,,,,,,,,,,, -127800,4.7728977,2.38621,,,,,,,,,,,,,, -127900,4.823097,2.2403688,,,,,,,,,,,,,, -128000,4.7681346,2.3315184,,,,,,,,,,,,,, -128100,4.6697955,2.288768,,,,,,,,,,,,,, -128200,4.8508906,2.2269056,,,,,,,,,,,,,, -128300,5.146144,2.279193,,,,,,,,,,,,,, -128400,4.9870815,2.2523394,,,,,,,,,,,,,, -128500,4.784512,2.2541454,,,,,,,,,,,,,, -128600,4.9550543,2.246623,,,,,,,,,,,,,, -128699,,,0.7883649468421936,0.9085213541984558,0.7102800011634827,1.2549796104431152,50000.0,0.5886000394821167,1.8822438716888428,10000.0,43920.76516246796,45571.81634020805,43920.76516246796,1642.3397538661957,4.400299549102783,0.0 -128700,4.8316555,2.216116,,,,,,,,,,,,,, -128800,5.575389,2.3182318,,,,,,,,,,,,,, -128900,4.56269,2.18761,,,,,,,,,,,,,, -129000,4.8138933,2.2322211,,,,,,,,,,,,,, -129100,5.7127113,2.3418949,,,,,,,,,,,,,, -129200,4.879325,2.2674286,,,,,,,,,,,,,, -129300,5.1656904,2.2634382,,,,,,,,,,,,,, -129400,5.7579465,2.3191965,,,,,,,,,,,,,, -129500,5.2245374,2.2265668,,,,,,,,,,,,,, -129600,5.161492,2.3018956,,,,,,,,,,,,,, -129700,5.0046163,2.2210717,,,,,,,,,,,,,, -129800,4.821076,2.2934675,,,,,,,,,,,,,, -129900,4.833915,2.223333,,,,,,,,,,,,,, -130000,5.529192,2.303172,,,,,,,,,,,,,, -130100,4.941044,2.3015344,,,,,,,,,,,,,, -130197,,,0.7904177308082581,0.8950925469398499,0.7142800092697144,1.2327156066894531,50000.0,0.5909000039100647,1.8738075494766235,10000.0,44430.832174539566,46099.25558972359,44430.832174539566,1659.612549781799,4.4501917362213135,0.0 -130200,4.558204,2.2173798,,,,,,,,,,,,,, -130300,5.035272,2.1832945,,,,,,,,,,,,,, -130400,4.861363,2.264633,,,,,,,,,,,,,, -130500,5.0561333,2.2174506,,,,,,,,,,,,,, -130600,4.4853473,2.2015674,,,,,,,,,,,,,, -130700,4.960248,2.2969587,,,,,,,,,,,,,, -130800,4.9134774,2.312488,,,,,,,,,,,,,, -130900,4.660577,2.2799323,,,,,,,,,,,,,, -131000,4.876357,2.19569,,,,,,,,,,,,,, -131100,5.2066326,2.1910932,,,,,,,,,,,,,, -131200,5.6543827,2.2759192,,,,,,,,,,,,,, -131300,4.9405713,2.2637877,,,,,,,,,,,,,, -131400,5.452592,2.2560756,,,,,,,,,,,,,, -131500,5.1715074,2.225421,,,,,,,,,,,,,, -131600,5.1250944,2.171462,,,,,,,,,,,,,, -131695,,,0.7917131781578064,0.8998433351516724,0.7148999571800232,1.22651207447052,50000.0,0.5855000019073486,1.8798346519470213,10000.0,44940.92831659317,46626.481568574905,44940.92831659317,1676.6434774398804,4.499432563781738,0.0 -131700,5.227515,2.3377173,,,,,,,,,,,,,, -131800,5.0036798,2.1985276,,,,,,,,,,,,,, -131900,5.197238,2.2237742,,,,,,,,,,,,,, -132000,5.1270456,2.302154,,,,,,,,,,,,,, -132100,5.337079,2.2264314,,,,,,,,,,,,,, -132200,5.1062627,2.1179628,,,,,,,,,,,,,, -132300,5.1134195,2.1333432,,,,,,,,,,,,,, -132400,5.104919,2.1724274,,,,,,,,,,,,,, -132500,5.470213,2.2398686,,,,,,,,,,,,,, -132600,5.035412,2.3014898,,,,,,,,,,,,,, -132700,5.5269885,2.276601,,,,,,,,,,,,,, -132800,5.308633,2.3013394,,,,,,,,,,,,,, -132900,5.544625,2.1947682,,,,,,,,,,,,,, -133000,6.092518,2.3953295,,,,,,,,,,,,,, -133100,5.0603952,2.2334118,,,,,,,,,,,,,, -133193,,,0.7970942258834839,0.8763123750686646,0.7187199592590332,1.2121546268463137,50000.0,0.5919000506401062,1.845463752746582,10000.0,45450.86598086357,47153.72575092316,45450.86598086357,1693.8448660373688,4.552513122558594,0.0 -133200,4.8322144,2.2559597,,,,,,,,,,,,,, -133300,5.1457725,2.2701397,,,,,,,,,,,,,, -133400,5.367403,2.1209989,,,,,,,,,,,,,, -133500,4.84368,2.1895394,,,,,,,,,,,,,, -133600,5.315275,2.2334647,,,,,,,,,,,,,, -133700,4.598202,2.193906,,,,,,,,,,,,,, -133800,4.792842,2.1519156,,,,,,,,,,,,,, -133900,5.282729,2.2940853,,,,,,,,,,,,,, -134000,5.138461,2.2040865,,,,,,,,,,,,,, -134100,5.4121356,2.2330546,,,,,,,,,,,,,, -134200,5.13373,2.1453326,,,,,,,,,,,,,, -134300,5.688992,2.204373,,,,,,,,,,,,,, -134400,4.901976,2.2461588,,,,,,,,,,,,,, -134500,5.0600667,2.2585669,,,,,,,,,,,,,, -134600,6.099106,2.2310712,,,,,,,,,,,,,, -134692,,,0.8205516338348389,0.7716889381408691,0.7160199880599976,1.2104519605636597,50000.0,0.5961000323295593,1.841491937637329,10000.0,45961.00407743454,47681.152082681656,45961.00407743454,1711.0298657417295,4.6041131019592285,0.0 -134700,5.262111,2.301639,,,,,,,,,,,,,, -134800,4.8729124,2.2035112,,,,,,,,,,,,,, -134900,5.70103,2.2053144,,,,,,,,,,,,,, -135000,5.9938607,2.2299526,,,,,,,,,,,,,, -135100,6.08263,2.1516085,,,,,,,,,,,,,, -135200,5.1886945,2.1696343,,,,,,,,,,,,,, -135300,5.0726295,2.2001216,,,,,,,,,,,,,, -135400,4.7900105,2.2126493,,,,,,,,,,,,,, -135500,5.933803,2.1669564,,,,,,,,,,,,,, -135600,5.1952477,2.1241891,,,,,,,,,,,,,, -135700,5.5011573,2.1985645,,,,,,,,,,,,,, -135800,5.352469,2.1639223,,,,,,,,,,,,,, -135900,4.9566383,2.1951866,,,,,,,,,,,,,, -136000,5.1607127,2.069086,,,,,,,,,,,,,, -136100,5.574253,2.2340293,,,,,,,,,,,,,, -136190,,,0.8130978941917419,0.8022621870040894,0.7203399538993835,1.197677493095398,50000.0,0.5961000323295593,1.8225889205932613,10000.0,46470.92022943497,48208.3938832283,46470.92022943497,1728.2518351078031,4.656829357147217,0.0 -136200,5.560393,2.1952128,,,,,,,,,,,,,, -136300,5.292867,2.1623769,,,,,,,,,,,,,, -136400,5.1221533,2.1879222,,,,,,,,,,,,,, -136500,5.305113,2.1766758,,,,,,,,,,,,,, -136600,5.422003,2.281547,,,,,,,,,,,,,, -136700,5.3828435,2.162623,,,,,,,,,,,,,, -136800,5.595368,2.1819725,,,,,,,,,,,,,, -136900,5.472495,2.2184296,,,,,,,,,,,,,, -137000,5.010658,2.1853192,,,,,,,,,,,,,, -137100,5.337949,2.201415,,,,,,,,,,,,,, -137200,5.175022,2.1437979,,,,,,,,,,,,,, -137300,5.794987,2.2000961,,,,,,,,,,,,,, -137400,5.128595,2.1635299,,,,,,,,,,,,,, -137500,5.543665,2.1179326,,,,,,,,,,,,,, -137600,5.0264006,2.1381392,,,,,,,,,,,,,, -137689,,,0.810566782951355,0.8220297694206238,0.722819983959198,1.207771062850952,50000.0,0.5981000065803528,1.852874755859375,10000.0,46980.9376642704,48735.45760130882,46980.9376642704,1745.192828655243,4.712372779846191,0.0 -137700,5.6172404,2.2835033,,,,,,,,,,,,,, -137800,5.337102,2.3004823,,,,,,,,,,,,,, -137900,5.5307713,2.324887,,,,,,,,,,,,,, -138000,5.817081,2.201914,,,,,,,,,,,,,, -138100,5.578945,2.1674175,,,,,,,,,,,,,, -138200,5.235253,2.2027056,,,,,,,,,,,,,, -138300,5.3230896,2.171933,,,,,,,,,,,,,, -138400,5.517541,2.2567954,,,,,,,,,,,,,, -138500,5.85075,2.2498348,,,,,,,,,,,,,, -138600,5.0257034,2.1387227,,,,,,,,,,,,,, -138700,5.533752,2.1793602,,,,,,,,,,,,,, -138800,5.464427,2.184978,,,,,,,,,,,,,, -138900,5.376611,2.234624,,,,,,,,,,,,,, -139000,5.3638263,2.136014,,,,,,,,,,,,,, -139100,5.976863,2.2321799,,,,,,,,,,,,,, -139187,,,0.8121013641357422,0.8041132688522339,0.7247599959373474,1.185956954956055,50000.0,0.5973000526428223,1.825844049453736,10000.0,47490.93303322792,49262.71445965767,47490.93303322792,1762.3447728157043,4.771524906158447,0.0 -139200,5.4259677,2.1441536,,,,,,,,,,,,,, -139300,5.4081717,2.208796,,,,,,,,,,,,,, -139400,5.733328,2.2546313,,,,,,,,,,,,,, -139500,6.130597,2.1466022,,,,,,,,,,,,,, -139600,5.132674,2.1140628,,,,,,,,,,,,,, -139700,5.405246,2.1775792,,,,,,,,,,,,,, -139800,5.442494,2.127387,,,,,,,,,,,,,, -139900,5.4598165,2.1094804,,,,,,,,,,,,,, -140000,5.935909,2.1564882,,,,,,,,,,,,,, -140100,5.2698636,2.1140804,,,,,,,,,,,,,, -140200,5.457488,2.149349,,,,,,,,,,,,,, -140300,6.169686,2.1632192,,,,,,,,,,,,,, -140400,5.2892585,2.1147523,,,,,,,,,,,,,, -140500,5.684648,2.1944184,,,,,,,,,,,,,, -140600,5.476223,2.1671915,,,,,,,,,,,,,, -140684,,,0.8148915767669678,0.7869499921798706,0.7278199791908264,1.1672886610031128,50000.0,0.6026000380516052,1.7988563776016235,10000.0,48000.869512319565,49789.94197225571,48000.869512319565,1779.533354997635,4.824806928634644,0.0 -140700,5.565162,2.1479018,,,,,,,,,,,,,, -140800,5.770327,2.1795979,,,,,,,,,,,,,, -140900,5.472463,2.1290078,,,,,,,,,,,,,, -141000,6.042528,2.1276157,,,,,,,,,,,,,, -141100,5.5521474,2.1309538,,,,,,,,,,,,,, -141200,5.8779235,2.1343298,,,,,,,,,,,,,, -141300,5.4477587,2.1727948,,,,,,,,,,,,,, -141400,6.0790763,2.1161685,,,,,,,,,,,,,, -141500,5.717397,2.1789193,,,,,,,,,,,,,, -141600,5.8599505,2.1868396,,,,,,,,,,,,,, -141700,5.3585544,2.1101968,,,,,,,,,,,,,, -141800,5.504389,2.0940447,,,,,,,,,,,,,, -141900,5.702259,2.136008,,,,,,,,,,,,,, -142000,5.382731,2.115592,,,,,,,,,,,,,, -142100,6.5360456,2.212957,,,,,,,,,,,,,, -142182,,,0.8167450428009033,0.7779141664505005,0.7300800085067749,1.1511822938919067,50000.0,0.6078000068664551,1.7888727188110352,10000.0,48510.98008084297,50317.63361525536,48510.98008084297,1797.0130088329315,4.874320030212402,0.0 -142200,5.9049225,2.2164812,,,,,,,,,,,,,, -142300,5.8302197,2.2071118,,,,,,,,,,,,,, -142400,5.713236,2.1646583,,,,,,,,,,,,,, -142500,6.063742,2.121251,,,,,,,,,,,,,, -142600,6.0680585,2.1533766,,,,,,,,,,,,,, -142700,5.435763,2.1670246,,,,,,,,,,,,,, -142800,5.768033,2.1300502,,,,,,,,,,,,,, -142900,5.7208138,2.0712764,,,,,,,,,,,,,, -143000,6.242872,2.2075732,,,,,,,,,,,,,, -143100,5.5732713,2.125431,,,,,,,,,,,,,, -143200,5.8022285,2.1371017,,,,,,,,,,,,,, -143300,5.4836836,2.131098,,,,,,,,,,,,,, -143400,5.838568,2.0572536,,,,,,,,,,,,,, -143500,6.356613,2.0973759,,,,,,,,,,,,,, -143600,6.0525236,2.18667,,,,,,,,,,,,,, -143681,,,0.8469786047935486,0.6639379262924194,0.7300399541854858,1.1502034664154053,50000.0,0.6038000583648682,1.7933541536331177,10000.0,49021.22454404831,50845.16202926636,49021.22454404831,1814.199674367905,4.921878337860107,0.0 -143700,6.182267,2.1733656,,,,,,,,,,,,,, -143800,5.9454947,2.1075969,,,,,,,,,,,,,, -143900,5.6675196,2.1341302,,,,,,,,,,,,,, -144000,5.749284,2.1464808,,,,,,,,,,,,,, -144100,5.3539877,2.0530357,,,,,,,,,,,,,, -144200,5.7692103,2.126952,,,,,,,,,,,,,, -144300,6.2893586,2.0708625,,,,,,,,,,,,,, -144400,5.516259,2.0344927,,,,,,,,,,,,,, -144500,6.618501,2.1315513,,,,,,,,,,,,,, -144600,6.178775,2.1173558,,,,,,,,,,,,,, -144700,6.022197,2.2018564,,,,,,,,,,,,,, -144800,6.2084765,2.1717594,,,,,,,,,,,,,, -144900,6.098962,2.0755434,,,,,,,,,,,,,, -145000,5.2639155,2.0597193,,,,,,,,,,,,,, -145100,6.359673,2.060948,,,,,,,,,,,,,, -145179,,,0.8422552347183228,0.703779399394989,0.7349399924278259,1.1473641395568848,50000.0,0.6086000204086304,1.7797400951385498,10000.0,49531.14088320732,51372.43950200081,49531.14088320732,1831.4566979408264,4.974995851516724,0.0 -145200,6.15902,2.0929704,,,,,,,,,,,,,, -145300,6.350396,2.088581,,,,,,,,,,,,,, -145400,5.965935,2.052752,,,,,,,,,,,,,, -145500,5.601359,2.1147707,,,,,,,,,,,,,, -145600,5.4426565,2.0727563,,,,,,,,,,,,,, -145700,5.9761744,2.0768287,,,,,,,,,,,,,, -145800,5.5262527,2.013445,,,,,,,,,,,,,, -145900,6.375509,2.2106915,,,,,,,,,,,,,, -146000,6.03894,2.0855896,,,,,,,,,,,,,, -146100,5.9213877,2.1708813,,,,,,,,,,,,,, -146200,5.978552,2.0984979,,,,,,,,,,,,,, -146300,6.3251057,2.0993853,,,,,,,,,,,,,, -146400,6.038743,2.181863,,,,,,,,,,,,,, -146500,6.330296,2.0667589,,,,,,,,,,,,,, -146600,5.578219,2.1046536,,,,,,,,,,,,,, -146677,,,0.8409199714660645,0.7093674540519714,0.7373799681663513,1.132510542869568,50000.0,0.6091000437736511,1.7796084880828855,10000.0,50041.08595466614,51899.51325583458,50041.08595466614,1848.4796833992004,5.031417369842529,0.0 -146700,5.694161,2.1007543,,,,,,,,,,,,,, -146800,6.046243,2.1129332,,,,,,,,,,,,,, -146900,5.8907924,2.0361485,,,,,,,,,,,,,, -147000,6.2115774,2.1290498,,,,,,,,,,,,,, -147100,5.813403,2.1300826,,,,,,,,,,,,,, -147200,5.4800186,2.0416157,,,,,,,,,,,,,, -147300,6.61155,2.1584308,,,,,,,,,,,,,, -147400,5.4807377,2.0365572,,,,,,,,,,,,,, -147500,6.361913,2.0107753,,,,,,,,,,,,,, -147600,6.706104,2.1370335,,,,,,,,,,,,,, -147700,6.2986784,2.1108358,,,,,,,,,,,,,, -147800,5.9514265,2.0836616,,,,,,,,,,,,,, -147900,6.4515285,2.143538,,,,,,,,,,,,,, -148000,6.164423,2.063567,,,,,,,,,,,,,, -148100,7.1754518,2.1388903,,,,,,,,,,,,,, -148175,,,0.8373923897743225,0.7035523056983948,0.737280011177063,1.1355136632919312,50000.0,0.6115000247955322,1.7595330476760864,10000.0,50551.017679452896,52426.786516428,50551.017679452896,1865.7171568870544,5.085627317428589,0.0 -148200,6.2453876,2.0132413,,,,,,,,,,,,,, -148300,6.0940433,2.0188534,,,,,,,,,,,,,, -148400,6.2065635,2.0836735,,,,,,,,,,,,,, -148500,5.8771963,2.0539503,,,,,,,,,,,,,, -148600,6.2344165,2.056155,,,,,,,,,,,,,, -148700,6.186754,2.1096623,,,,,,,,,,,,,, -148800,6.015292,2.0225031,,,,,,,,,,,,,, -148900,6.3428864,2.0992124,,,,,,,,,,,,,, -149000,6.0677333,2.045031,,,,,,,,,,,,,, -149100,6.825209,2.1027682,,,,,,,,,,,,,, -149200,6.6475616,2.089754,,,,,,,,,,,,,, -149300,6.5065823,2.059052,,,,,,,,,,,,,, -149400,5.962479,2.0227163,,,,,,,,,,,,,, -149500,6.1804214,2.1469693,,,,,,,,,,,,,, -149600,5.696654,2.0032609,,,,,,,,,,,,,, -149673,,,0.8453842401504517,0.686999499797821,0.7429400086402893,1.107838749885559,50000.0,0.6195000410079956,1.7411073446273804,10000.0,51060.90836381912,52954.00826811791,51060.90836381912,1882.9315605163567,5.151835203170776,0.0 -149700,7.06896,2.0823967,,,,,,,,,,,,,, -149800,6.556229,2.067988,,,,,,,,,,,,,, -149900,5.458934,1.9885653,,,,,,,,,,,,,, -150000,6.5698957,2.0470061,,,,,,,,,,,,,, -150100,5.907695,2.0049367,,,,,,,,,,,,,, -150200,6.223444,2.0774279,,,,,,,,,,,,,, -150300,6.2894382,2.0913095,,,,,,,,,,,,,, -150400,6.398135,2.0654063,,,,,,,,,,,,,, -150500,6.8497343,2.0557122,,,,,,,,,,,,,, -150600,6.036269,2.0118759,,,,,,,,,,,,,, -150700,6.4818387,2.00583,,,,,,,,,,,,,, -150800,6.38222,2.0054486,,,,,,,,,,,,,, -150900,6.643212,2.011936,,,,,,,,,,,,,, -151000,6.6258836,2.0385032,,,,,,,,,,,,,, -151100,6.2505774,2.0437393,,,,,,,,,,,,,, -151171,,,0.8451052308082581,0.675420880317688,0.7450799942016602,1.1006088256835938,50000.0,0.6212000250816345,1.727745532989502,10000.0,51571.01009392738,53481.925209999084,51571.01009392738,1900.640768289566,5.20801043510437,0.0 -151200,6.563231,2.0318882,,,,,,,,,,,,,, -151300,6.559092,2.0617375,,,,,,,,,,,,,, -151400,6.758108,2.0369556,,,,,,,,,,,,,, -151500,6.6454663,2.04647,,,,,,,,,,,,,, -151600,5.762533,1.9912363,,,,,,,,,,,,,, -151700,6.9823546,2.1081972,,,,,,,,,,,,,, -151800,6.4182677,2.1341038,,,,,,,,,,,,,, -151900,7.061034,2.0400982,,,,,,,,,,,,,, -152000,6.487306,1.9312184,,,,,,,,,,,,,, -152100,6.248158,2.0754473,,,,,,,,,,,,,, -152200,6.222494,2.0641673,,,,,,,,,,,,,, -152300,6.508445,2.014422,,,,,,,,,,,,,, -152400,7.3988676,2.0804024,,,,,,,,,,,,,, -152500,6.0075974,1.9589808,,,,,,,,,,,,,, -152600,6.5240192,1.9942493,,,,,,,,,,,,,, -152669,,,0.8661710619926453,0.6043965220451355,0.7461599707603455,1.0976543426513672,50000.0,0.6201000213623047,1.735266089439392,10000.0,52081.04586791992,54009.12109827995,52081.04586791992,1917.6977479457853,5.260669708251953,0.0 -152700,6.81375,1.9961045,,,,,,,,,,,,,, -152800,6.321934,1.930075,,,,,,,,,,,,,, -152900,6.3399024,1.982964,,,,,,,,,,,,,, -153000,6.249255,2.054456,,,,,,,,,,,,,, -153100,6.807649,1.9456377,,,,,,,,,,,,,, -153200,6.598354,1.941402,,,,,,,,,,,,,, -153300,6.39061,1.9778168,,,,,,,,,,,,,, -153400,6.201834,2.0174892,,,,,,,,,,,,,, -153500,6.4044747,2.076886,,,,,,,,,,,,,, -153600,7.2254634,2.035382,,,,,,,,,,,,,, -153700,6.669706,2.124329,,,,,,,,,,,,,, -153800,6.629127,1.97157,,,,,,,,,,,,,, -153900,6.6245847,1.9721513,,,,,,,,,,,,,, -154000,6.6116214,2.0548112,,,,,,,,,,,,,, -154100,6.3432574,2.0146859,,,,,,,,,,,,,, -154167,,,0.8671875,0.5794281959533691,0.7481399774551392,1.0786703824996948,50000.0,0.6278000473976135,1.6983375549316406,10000.0,52591.11575245857,54536.50642514229,52591.11575245857,1934.90061545372,5.321510553359985,0.0 -154200,6.819833,2.0411105,,,,,,,,,,,,,, -154300,6.46855,2.0120761,,,,,,,,,,,,,, -154400,7.048882,2.039364,,,,,,,,,,,,,, -154500,6.788445,2.0078473,,,,,,,,,,,,,, -154600,6.542411,2.0538409,,,,,,,,,,,,,, -154700,5.820357,1.9491725,,,,,,,,,,,,,, -154800,6.027293,1.9961579,,,,,,,,,,,,,, -154900,6.8503036,1.9662228,,,,,,,,,,,,,, -155000,6.836639,2.0658913,,,,,,,,,,,,,, -155100,6.6998873,1.9974885,,,,,,,,,,,,,, -155200,6.838086,1.9320217,,,,,,,,,,,,,, -155300,6.3585567,2.0365026,,,,,,,,,,,,,, -155400,6.530911,2.008601,,,,,,,,,,,,,, -155500,6.4469166,1.9534751,,,,,,,,,,,,,, -155600,6.9442854,2.0163522,,,,,,,,,,,,,, -155665,,,0.8687419891357422,0.5882298946380615,0.7505999803543091,1.0769861936569214,50000.0,0.6246000528335571,1.6996747255325315,10000.0,53101.14227557182,55063.9355969429,53101.14227557182,1952.2021520137787,5.373865604400635,0.0 -155700,6.019149,1.9599721,,,,,,,,,,,,,, -155800,6.761764,1.9093304,,,,,,,,,,,,,, -155900,7.401125,2.0071182,,,,,,,,,,,,,, -156000,5.908836,1.9961973,,,,,,,,,,,,,, -156100,6.383271,1.9604932,,,,,,,,,,,,,, -156200,6.326725,2.029506,,,,,,,,,,,,,, -156300,7.126922,1.9923139,,,,,,,,,,,,,, -156400,6.618187,1.9693348,,,,,,,,,,,,,, -156500,7.443127,1.9520041,,,,,,,,,,,,,, -156600,6.755917,1.911546,,,,,,,,,,,,,, -156700,6.5505385,1.9798933,,,,,,,,,,,,,, -156800,7.073777,1.9395626,,,,,,,,,,,,,, -156900,6.748881,1.9524802,,,,,,,,,,,,,, -157000,6.237751,1.8760395,,,,,,,,,,,,,, -157100,6.8003883,1.9917402,,,,,,,,,,,,,, -157162,,,0.8694794178009033,0.580014169216156,0.7535199522972107,1.0628749132156372,50000.0,0.6239000558853149,1.6967737674713137,10000.0,53611.06381011009,55590.8880238533,53611.06381011009,1969.1301229000087,5.426816463470459,0.0 -157200,6.8662167,1.9715024,,,,,,,,,,,,,, -157300,6.4166527,1.8820015,,,,,,,,,,,,,, -157400,7.605221,1.9836174,,,,,,,,,,,,,, -157500,6.8750186,1.9190302,,,,,,,,,,,,,, -157600,6.28401,1.9176658,,,,,,,,,,,,,, -157700,6.798329,1.8905077,,,,,,,,,,,,,, -157800,6.027482,1.8933123,,,,,,,,,,,,,, -157900,6.761316,1.9854631,,,,,,,,,,,,,, -158000,6.4881115,1.9596734,,,,,,,,,,,,,, -158100,6.4570665,1.9838421,,,,,,,,,,,,,, -158200,6.49299,1.9620774,,,,,,,,,,,,,, -158300,7.625867,2.052944,,,,,,,,,,,,,, -158400,6.929214,2.0177627,,,,,,,,,,,,,, -158500,7.102456,1.9522629,,,,,,,,,,,,,, -158600,6.569952,1.903954,,,,,,,,,,,,,, -158660,,,0.8698182106018066,0.5817106366157532,0.7541399598121643,1.064910888671875,50000.0,0.6336000561714172,1.6853803396224976,10000.0,54121.251658678055,56118.19445109368,54121.251658678055,1986.1422533988955,5.484708070755005,0.0 -158700,7.0710683,1.9873668,,,,,,,,,,,,,, -158800,7.151721,1.9861372,,,,,,,,,,,,,, -158900,6.897852,1.9235848,,,,,,,,,,,,,, -159000,7.2402678,1.978504,,,,,,,,,,,,,, -159100,6.7181606,2.0166724,,,,,,,,,,,,,, -159200,6.8609567,1.964848,,,,,,,,,,,,,, -159300,6.683439,1.9407865,,,,,,,,,,,,,, -159400,6.6581235,2.0529099,,,,,,,,,,,,,, -159500,6.8221445,1.9544867,,,,,,,,,,,,,, -159600,6.9311624,1.8707976,,,,,,,,,,,,,, -159700,7.0515194,1.940134,,,,,,,,,,,,,, -159800,6.7386312,1.9057533,,,,,,,,,,,,,, -159900,6.9150863,1.9743525,,,,,,,,,,,,,, -160000,7.0300055,1.8630737,,,,,,,,,,,,,, -160100,7.008747,1.9876585,,,,,,,,,,,,,, -160158,,,0.875418484210968,0.5687929391860962,0.756879985332489,1.0507851839065552,50000.0,0.6309000253677368,1.6826121807098389,10000.0,54631.442410707474,56645.72098970413,54631.442410707474,2003.3706483840945,5.54168963432312,0.0 -160200,6.7676396,1.981288,,,,,,,,,,,,,, -160300,7.1909323,1.958662,,,,,,,,,,,,,, -160400,7.6075306,1.9253576,,,,,,,,,,,,,, -160500,6.713435,1.9660834,,,,,,,,,,,,,, -160600,7.791474,2.0241318,,,,,,,,,,,,,, -160700,7.8465247,1.9857653,,,,,,,,,,,,,, -160800,7.4109416,1.982238,,,,,,,,,,,,,, -160900,7.114417,1.9823279,,,,,,,,,,,,,, -161000,6.76255,1.8751577,,,,,,,,,,,,,, -161100,6.912176,1.9418616,,,,,,,,,,,,,, -161200,6.7500734,1.9091558,,,,,,,,,,,,,, -161300,7.112112,2.0154037,,,,,,,,,,,,,, -161400,6.944118,1.8761383,,,,,,,,,,,,,, -161500,7.8147254,1.9788842,,,,,,,,,,,,,, -161600,7.5896583,1.9160419,,,,,,,,,,,,,, -161656,,,0.8792450428009033,0.5455688238143921,0.7571799755096436,1.0448676347732544,50000.0,0.6301000118255615,1.673742413520813,10000.0,55141.38109588623,57173.04496335983,55141.38109588623,2020.6487910747528,5.598784685134888,0.0 -161700,6.863358,1.8473792,,,,,,,,,,,,,, -161800,6.941334,1.8652058,,,,,,,,,,,,,, -161900,6.881479,1.9090662,,,,,,,,,,,,,, -162000,8.423663,1.9575946,,,,,,,,,,,,,, -162100,7.7848577,1.9020007,,,,,,,,,,,,,, -162200,7.11567,1.9323983,,,,,,,,,,,,,, -162300,7.4120073,1.9759886,,,,,,,,,,,,,, -162400,7.254138,1.8939217,,,,,,,,,,,,,, -162500,7.5791116,1.9252973,,,,,,,,,,,,,, -162600,7.1501613,1.8842044,,,,,,,,,,,,,, -162700,6.846484,1.8951904,,,,,,,,,,,,,, -162800,7.0397944,1.8979787,,,,,,,,,,,,,, -162900,7.0302167,1.846067,,,,,,,,,,,,,, -163000,7.546666,1.890049,,,,,,,,,,,,,, -163100,7.369758,1.9165124,,,,,,,,,,,,,, -163154,,,0.8960259556770325,0.491013616323471,0.7576000094413757,1.0471340417861938,50000.0,0.6349000334739685,1.6725289821624756,10000.0,55651.28565096855,57700.09034585953,55651.28565096855,2037.681653022766,5.65578556060791,0.0 -163200,6.9708014,1.9640453,,,,,,,,,,,,,, -163300,7.6353936,1.937542,,,,,,,,,,,,,, -163400,7.152024,1.9123571,,,,,,,,,,,,,, -163500,7.0620885,1.8988732,,,,,,,,,,,,,, -163600,7.284323,1.8575882,,,,,,,,,,,,,, -163700,7.331197,1.901773,,,,,,,,,,,,,, -163800,6.8509946,1.8465428,,,,,,,,,,,,,, -163900,6.9760685,1.8741095,,,,,,,,,,,,,, -164000,7.2021065,1.9319677,,,,,,,,,,,,,, -164100,7.884265,1.9086764,,,,,,,,,,,,,, -164200,6.998528,1.8637292,,,,,,,,,,,,,, -164300,7.4660845,1.9189234,,,,,,,,,,,,,, -164400,7.5344977,1.9604297,,,,,,,,,,,,,, -164500,6.8854704,1.8657616,,,,,,,,,,,,,, -164600,7.31279,1.9188029,,,,,,,,,,,,,, -164652,,,0.8957070708274841,0.4908312857151031,0.7618599534034729,1.0314052104949951,50000.0,0.6367000341415405,1.661550521850586,10000.0,56161.19146823883,58227.19503569603,56161.19146823883,2054.770972251892,5.715636491775513,0.0 -164700,7.61302,1.8779494,,,,,,,,,,,,,, -164800,7.1646557,1.8982073,,,,,,,,,,,,,, -164900,7.7721753,1.9007244,,,,,,,,,,,,,, -165000,7.3915353,1.8929648,,,,,,,,,,,,,, -165100,7.898934,1.9007008,,,,,,,,,,,,,, -165200,7.592468,1.8980172,,,,,,,,,,,,,, -165300,7.469331,1.8673004,,,,,,,,,,,,,, -165400,7.665046,1.8941236,,,,,,,,,,,,,, -165500,7.4479303,1.9034461,,,,,,,,,,,,,, -165600,7.65783,1.9055296,,,,,,,,,,,,,, -165700,6.8585415,1.8737514,,,,,,,,,,,,,, -165800,7.222423,1.8989162,,,,,,,,,,,,,, -165900,6.976032,1.8414679,,,,,,,,,,,,,, -166000,8.522653,1.8938897,,,,,,,,,,,,,, -166100,7.5974636,1.8642365,,,,,,,,,,,,,, -166150,,,0.8951490521430969,0.4891964495182037,0.7637199759483337,1.026726245880127,50000.0,0.6398000121116638,1.6573363542556765,10000.0,56671.27508664131,58754.45888543129,56671.27508664131,2071.8406381607056,5.775366306304932,0.0 -166200,7.1668673,1.8568721,,,,,,,,,,,,,, -166300,7.5612087,1.8665442,,,,,,,,,,,,,, -166400,8.197269,1.9154174,,,,,,,,,,,,,, -166500,7.6572304,1.8664547,,,,,,,,,,,,,, -166600,7.1270185,1.8304596,,,,,,,,,,,,,, -166700,7.3194585,1.9306655,,,,,,,,,,,,,, -166800,7.3994074,1.8318208,,,,,,,,,,,,,, -166900,7.595518,1.8594568,,,,,,,,,,,,,, -167000,6.684193,1.8623223,,,,,,,,,,,,,, -167100,8.106978,1.9365305,,,,,,,,,,,,,, -167200,6.8460846,1.8171787,,,,,,,,,,,,,, -167300,7.172441,1.8486438,,,,,,,,,,,,,, -167400,7.4114656,1.8716673,,,,,,,,,,,,,, -167500,7.6711674,1.854525,,,,,,,,,,,,,, -167600,7.268354,1.8334048,,,,,,,,,,,,,, -167648,,,0.8947305083274841,0.4829561114311218,0.7655199766159058,1.0176388025283811,50000.0,0.6388000249862671,1.6397991180419922,10000.0,57181.31681752205,59281.74941539765,57181.31681752205,2088.97993516922,5.834702968597412,0.0 -167700,7.472301,1.8219247,,,,,,,,,,,,,, -167800,7.7418838,1.8812265,,,,,,,,,,,,,, -167900,7.513864,1.8169751,,,,,,,,,,,,,, -168000,6.7896214,1.7863799,,,,,,,,,,,,,, -168100,8.298691,1.8733258,,,,,,,,,,,,,, -168200,7.7360735,1.8951231,,,,,,,,,,,,,, -168300,6.8486238,1.817291,,,,,,,,,,,,,, -168400,7.7650685,1.7909977,,,,,,,,,,,,,, -168500,7.951152,1.8934658,,,,,,,,,,,,,, -168600,7.9546924,1.8894314,,,,,,,,,,,,,, -168700,7.8314686,1.8828,,,,,,,,,,,,,, -168800,7.3818874,1.8240045,,,,,,,,,,,,,, -168900,7.550444,1.8891096,,,,,,,,,,,,,, -169000,7.602762,1.8775693,,,,,,,,,,,,,, -169100,8.357992,1.8588942,,,,,,,,,,,,,, -169146,,,0.9001514315605164,0.4777239859104156,0.7644000053405762,1.024844527244568,50000.0,0.6415000557899475,1.655003786087036,10000.0,57691.400824546814,59809.09949350357,57691.400824546814,2106.1391632556915,5.890669107437134,0.0 -169200,7.9130177,1.7957671,,,,,,,,,,,,,, -169300,7.799386,1.9264741,,,,,,,,,,,,,, -169400,7.30428,1.8152043,,,,,,,,,,,,,, -169500,7.689045,1.8423562,,,,,,,,,,,,,, -169600,7.0553417,1.8376912,,,,,,,,,,,,,, -169700,7.4302664,1.8454849,,,,,,,,,,,,,, -169800,7.4353704,1.8735566,,,,,,,,,,,,,, -169900,7.5030003,1.81339,,,,,,,,,,,,,, -170000,8.304471,1.9196708,,,,,,,,,,,,,, -170100,7.1313167,1.7963966,,,,,,,,,,,,,, -170200,8.111696,1.8374974,,,,,,,,,,,,,, -170300,6.79523,1.8140079,,,,,,,,,,,,,, -170400,7.2869916,1.8505431,,,,,,,,,,,,,, -170500,7.841014,1.8569205,,,,,,,,,,,,,, -170600,7.7037945,1.8842645,,,,,,,,,,,,,, -170644,,,0.9061303734779358,0.4618020951747894,0.7671999931335449,1.0098974704742432,50000.0,0.6467000246047974,1.6335299015045166,10000.0,58201.40622735024,60336.31550955773,58201.40622735024,2123.243331670761,5.947777032852173,0.0 -170700,7.7952423,1.8623762,,,,,,,,,,,,,, -170800,7.6818104,1.8568124,,,,,,,,,,,,,, -170900,7.8473,1.9145896,,,,,,,,,,,,,, -171000,8.17061,1.8488528,,,,,,,,,,,,,, -171100,7.170144,1.831005,,,,,,,,,,,,,, -171200,7.4137363,1.8000152,,,,,,,,,,,,,, -171300,7.3893905,1.8463594,,,,,,,,,,,,,, -171400,8.088552,1.8707076,,,,,,,,,,,,,, -171500,7.259269,1.7663984,,,,,,,,,,,,,, -171600,7.3024683,1.8382341,,,,,,,,,,,,,, -171700,7.924942,1.9058503,,,,,,,,,,,,,, -171800,7.9608226,1.7565123,,,,,,,,,,,,,, -171900,7.797467,1.7667636,,,,,,,,,,,,,, -172000,7.3437624,1.8294485,,,,,,,,,,,,,, -172100,7.910439,1.8983649,,,,,,,,,,,,,, -172142,,,0.9163145422935486,0.4251190423965454,0.7689799666404724,1.0125985145568848,50000.0,0.6457000374794006,1.6340036392211914,10000.0,58711.40957093239,60863.64063549042,58711.40957093239,2140.4393548965454,6.021910905838013,0.0 -172200,7.8401966,1.8084965,,,,,,,,,,,,,, -172300,8.248984,1.8817202,,,,,,,,,,,,,, -172400,8.400515,1.8457974,,,,,,,,,,,,,, -172500,7.5784845,1.8176267,,,,,,,,,,,,,, -172600,7.2665014,1.8129561,,,,,,,,,,,,,, -172700,7.9809628,1.8524616,,,,,,,,,,,,,, -172800,8.233977,1.8514078,,,,,,,,,,,,,, -172900,7.3142185,1.8022791,,,,,,,,,,,,,, -173000,7.584178,1.7921357,,,,,,,,,,,,,, -173100,8.354088,1.7931621,,,,,,,,,,,,,, -173200,7.227056,1.8208549,,,,,,,,,,,,,, -173300,8.324791,1.7990624,,,,,,,,,,,,,, -173400,6.981098,1.733908,,,,,,,,,,,,,, -173500,7.517347,1.7844781,,,,,,,,,,,,,, -173600,7.3516645,1.7835454,,,,,,,,,,,,,, -173640,,,0.9162746667861938,0.415965586900711,0.7691599726676941,1.0013891458511353,50000.0,0.6492000222206116,1.6199817657470703,10000.0,59221.40473794937,61390.93344545365,59221.40473794937,2157.631984949112,6.077967643737793,0.0 -173700,8.342728,1.860466,,,,,,,,,,,,,, -173800,7.3353148,1.8003924,,,,,,,,,,,,,, -173900,7.757616,1.7735829,,,,,,,,,,,,,, -174000,7.2763104,1.7507905,,,,,,,,,,,,,, -174100,7.882649,1.8602259,,,,,,,,,,,,,, -174200,8.290374,1.9004998,,,,,,,,,,,,,, -174300,7.678131,1.8335083,,,,,,,,,,,,,, -174400,7.4319625,1.7835827,,,,,,,,,,,,,, -174500,8.13232,1.8563914,,,,,,,,,,,,,, -174600,8.048708,1.775537,,,,,,,,,,,,,, -174700,8.543827,1.8309698,,,,,,,,,,,,,, -174800,8.02141,1.8092842,,,,,,,,,,,,,, -174900,8.235479,1.8198416,,,,,,,,,,,,,, -175000,7.713244,1.8619639,,,,,,,,,,,,,, -175100,8.35822,1.87219,,,,,,,,,,,,,, -175138,,,0.9125677347183228,0.4280580282211303,0.7706199884414673,1.000607967376709,50000.0,0.6484000086784363,1.6202576160430908,10000.0,59731.5888364315,61918.1064593792,59731.5888364315,2174.5137605667114,6.134857654571533,0.0 -175200,7.346027,1.8782222,,,,,,,,,,,,,, -175300,8.107488,1.7760489,,,,,,,,,,,,,, -175400,7.7986917,1.8074826,,,,,,,,,,,,,, -175500,7.3235183,1.8156924,,,,,,,,,,,,,, -175600,8.299314,1.7843312,,,,,,,,,,,,,, -175700,7.423405,1.7736869,,,,,,,,,,,,,, -175800,7.6673617,1.8319261,,,,,,,,,,,,,, -175900,7.574192,1.8020204,,,,,,,,,,,,,, -176000,7.5827737,1.9056515,,,,,,,,,,,,,, -176100,7.4615602,1.7665291,,,,,,,,,,,,,, -176200,7.679833,1.7318188,,,,,,,,,,,,,, -176300,8.446914,1.8590963,,,,,,,,,,,,,, -176400,8.5349655,1.8054938,,,,,,,,,,,,,, -176500,7.899834,1.7822571,,,,,,,,,,,,,, -176600,7.078993,1.7840557,,,,,,,,,,,,,, -176636,,,0.916214883327484,0.4126388728618622,0.7721799612045288,0.9947059154510498,50000.0,0.6476000547409058,1.617110252380371,10000.0,60241.7090446949,62445.60426354408,60241.7090446949,2191.785396575928,6.190088748931885,0.0 -176700,8.605567,1.721211,,,,,,,,,,,,,, -176800,8.374929,1.8414445,,,,,,,,,,,,,, -176900,8.04592,1.7332456,,,,,,,,,,,,,, -177000,7.3903427,1.7627604,,,,,,,,,,,,,, -177100,7.3602033,1.739006,,,,,,,,,,,,,, -177200,9.401292,1.8594993,,,,,,,,,,,,,, -177300,8.3528,1.8116398,,,,,,,,,,,,,, -177400,7.100556,1.7078509,,,,,,,,,,,,,, -177500,7.8231487,1.7927115,,,,,,,,,,,,,, -177600,7.9921308,1.8321878,,,,,,,,,,,,,, -177700,7.6486354,1.7751745,,,,,,,,,,,,,, -177800,8.146598,1.7738185,,,,,,,,,,,,,, -177900,7.438827,1.7130827,,,,,,,,,,,,,, -178000,7.5563073,1.762458,,,,,,,,,,,,,, -178100,7.4370494,1.7071469,,,,,,,,,,,,,, -178133,,,0.9172711968421936,0.4106919467449188,0.7718799710273743,0.9955499172210692,50000.0,0.6497000455856323,1.6151281595230105,10000.0,60751.60144472122,62972.50779604912,60751.60144472122,2208.687886953354,6.247087478637695,0.0 -178200,7.0562515,1.7146423,,,,,,,,,,,,,, -178300,7.9075956,1.8208321,,,,,,,,,,,,,, -178400,8.335959,1.8027973,,,,,,,,,,,,,, -178500,7.6892915,1.7869172,,,,,,,,,,,,,, -178600,8.310741,1.8203769,,,,,,,,,,,,,, -178700,8.61104,1.769747,,,,,,,,,,,,,, -178800,8.016062,1.7818352,,,,,,,,,,,,,, -178900,7.306567,1.7771697,,,,,,,,,,,,,, -179000,7.9503603,1.7687876,,,,,,,,,,,,,, -179100,8.020508,1.7087672,,,,,,,,,,,,,, -179200,8.542672,1.8209782,,,,,,,,,,,,,, -179300,7.4693456,1.7460048,,,,,,,,,,,,,, -179400,8.270999,1.7813172,,,,,,,,,,,,,, -179500,7.736552,1.8220203,,,,,,,,,,,,,, -179600,8.174454,1.8089099,,,,,,,,,,,,,, -179631,,,0.918965220451355,0.4076812863349914,0.7737799882888794,0.989608645439148,50000.0,0.65010005235672,1.6078531742095947,10000.0,61261.76955938339,63500.064401865005,61261.76955938339,2225.9638142585754,6.308993816375732,0.0 -179700,7.0215063,1.7449523,,,,,,,,,,,,,, -179800,8.18749,1.7888587,,,,,,,,,,,,,, -179900,8.485644,1.7819878,,,,,,,,,,,,,, -180000,8.568746,1.8374047,,,,,,,,,,,,,, -180100,7.8358054,1.8353599,,,,,,,,,,,,,, -180200,8.186034,1.7603887,,,,,,,,,,,,,, -180300,8.120799,1.7778952,,,,,,,,,,,,,, -180400,8.3731985,1.7770405,,,,,,,,,,,,,, -180500,8.179558,1.79104,,,,,,,,,,,,,, -180600,8.703113,1.7523699,,,,,,,,,,,,,, -180700,7.841159,1.762378,,,,,,,,,,,,,, -180800,7.6170692,1.753554,,,,,,,,,,,,,, -180900,7.367071,1.7263304,,,,,,,,,,,,,, -181000,8.128121,1.7975718,,,,,,,,,,,,,, -181100,7.7795134,1.7695715,,,,,,,,,,,,,, -181128,,,0.9202008843421936,0.4020408689975738,0.7730000019073486,0.992912530899048,50000.0,0.6499000191688538,1.609398603439331,10000.0,61771.71801686287,64027.38457632065,61771.71801686287,2243.2266731262207,6.367881774902344,0.0 -181200,7.4917274,1.7443312,,,,,,,,,,,,,, -181300,7.7705016,1.7746129,,,,,,,,,,,,,, -181400,7.455124,1.8057142,,,,,,,,,,,,,, -181500,8.346162,1.7965654,,,,,,,,,,,,,, -181600,7.8426833,1.7508398,,,,,,,,,,,,,, -181700,7.312555,1.6872804,,,,,,,,,,,,,, -181800,7.640635,1.794139,,,,,,,,,,,,,, -181900,7.8978925,1.783536,,,,,,,,,,,,,, -182000,8.227751,1.8133224,,,,,,,,,,,,,, -182100,8.066653,1.7884588,,,,,,,,,,,,,, -182200,6.96392,1.7296846,,,,,,,,,,,,,, -182300,8.980992,1.8347615,,,,,,,,,,,,,, -182400,7.74303,1.7797418,,,,,,,,,,,,,, -182500,7.5675187,1.754406,,,,,,,,,,,,,, -182600,8.047703,1.7600791,,,,,,,,,,,,,, -182626,,,0.9197225570678712,0.4042039811611175,0.7734400033950806,0.9904934167861938,50000.0,0.6509000062942505,1.606026530265808,10000.0,62281.93049740791,64554.90435361862,62281.93049740791,2260.4189026355743,6.43097186088562,0.0 -182700,8.254622,1.7934897,,,,,,,,,,,,,, -182800,7.634106,1.7737286,,,,,,,,,,,,,, -182900,7.39901,1.7248368,,,,,,,,,,,,,, -183000,8.030331,1.7699282,,,,,,,,,,,,,, -183100,7.7605925,1.7349671,,,,,,,,,,,,,, -183200,7.790339,1.7562814,,,,,,,,,,,,,, -183300,8.042446,1.816022,,,,,,,,,,,,,, -183400,8.670111,1.6959472,,,,,,,,,,,,,, -183500,7.6525946,1.7252889,,,,,,,,,,,,,, -183600,7.4828486,1.7524537,,,,,,,,,,,,,, -183700,7.8164907,1.7614403,,,,,,,,,,,,,, -183800,7.936824,1.8566813,,,,,,,,,,,,,, -183900,7.508185,1.8018358,,,,,,,,,,,,,, -184000,7.418323,1.7432028,,,,,,,,,,,,,, -184100,8.434907,1.7560903,,,,,,,,,,,,,, -184124,,,0.921097695827484,0.3999505639076233,0.7736999988555908,0.991873264312744,50000.0,0.6513000130653381,1.6081324815750122,10000.0,62791.98581242561,65082.41243624687,62791.98581242561,2277.7590713500977,6.492360353469849,0.0 -184200,8.133047,1.791073,,,,,,,,,,,,,, -184300,7.2101398,1.7144209,,,,,,,,,,,,,, -184400,7.7182255,1.7867999,,,,,,,,,,,,,, -184500,8.152205,1.7915604,,,,,,,,,,,,,, -184600,7.369583,1.7786417,,,,,,,,,,,,,, -184700,8.087138,1.7929249,,,,,,,,,,,,,, -184760,,,,,,,,,,,63008.28060030937,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/eval_measurements.csv deleted file mode 100644 index adc75cd02..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.40215516090393,0.0,32.273086071014404,1,0,32.273086071014404,0.0012000000569969,6.910790920257568,10000,49.67532181739807,0.0010363520123064,6.910816669464111,0.0010199999669566,6.91091251373291,50000 -34.842326641082764,0.0192625522613525,542.2085037231445,1491,0,542.2085037231445,0.0553000010550022,5.544482231140137,10000,577.119710445404,0.0832270383834838,5.25524377822876,0.0771199986338615,5.3119893074035645,50000 -52.28858208656311,0.0503122806549072,1052.3712661266327,2981,0,1052.3712661266327,0.1312000006437301,4.706563472747803,10000,1104.8108565807345,0.1989795863628387,4.21004581451416,0.1790999919176101,4.317634105682373,50000 -70.44600534439087,0.0788969993591308,1562.2961611747742,4470,0,1562.2961611747742,0.19200000166893,4.257549285888672,10000,1632.9713141918182,0.2876076102256775,3.585117816925049,0.2639600038528442,3.7350914478302,50000 -87.99014830589294,0.1064472198486328,2072.471899271012,5960,0,2072.471899271012,0.2747000157833099,3.713503360748291,10000,2160.769720315933,0.388073980808258,2.992854356765747,0.3614400029182434,3.130856990814209,50000 -105.64619159698486,0.1386871337890625,2582.5007004737854,7450,0,2582.5007004737854,0.3164000213146209,3.388169765472412,10000,2688.537534236908,0.4523876011371612,2.617628335952759,0.4184799790382385,2.7827277183532715,50000 -122.9762909412384,0.1666934490203857,3092.7613253593445,8942,0,3092.7613253593445,0.3680000305175781,3.117877244949341,10000,3216.206754922867,0.5526944994926453,2.124098539352417,0.4767199754714966,2.479666233062744,50000 -140.33562183380127,0.1949865818023681,3602.929986476898,10434,0,3602.929986476898,0.4036000072956085,2.9123005867004395,10000,3743.8133985996246,0.5689373016357422,2.022469520568848,0.522159993648529,2.2517900466918945,50000 -157.598552942276,0.2236585617065429,4112.95587015152,11926,0,4112.95587015152,0.4262000322341919,2.801759719848633,10000,4271.180375099182,0.5983737111091614,1.9081037044525144,0.5486199855804443,2.138624906539917,50000 -175.05454540252686,0.2537183761596679,4623.103399991989,13419,0,4623.103399991989,0.4473000168800354,2.6989235877990723,10000,4798.86471414566,0.6163105964660645,1.7988837957382202,0.5671799778938293,2.038691282272339,50000 -192.3369791507721,0.2829594612121582,5133.311586380005,14912,0,5133.311586380005,0.4590000212192535,2.626842975616455,10000,5326.436780452728,0.6312180757522583,1.7623796463012695,0.5798799991607666,1.994551658630371,50000 -209.60658407211304,0.3215117454528808,5643.3654227256775,16405,0,5643.3654227256775,0.4595000147819519,2.652688980102539,10000,5853.849092960358,0.6322146058082581,1.7425355911254885,0.5858799815177917,1.9565809965133667,50000 -227.14486122131348,0.3570168018341064,6153.40634727478,17899,0,6153.40634727478,0.4756000339984894,2.5317037105560303,10000,6381.513226747513,0.6486168503761292,1.6495143175125122,0.5984199643135071,1.885051250457764,50000 -244.6494562625885,0.3890948295593261,6663.387640237808,19392,0,6663.387640237808,0.4908000230789184,2.465847969055176,10000,6909.0832579135895,0.6943359375,1.483161449432373,0.615619957447052,1.8391530513763428,50000 -262.23272013664246,0.420799970626831,7173.381569385529,20885,0,7173.381569385529,0.4933000206947326,2.453991174697876,10000,7436.744100093842,0.6927216053009033,1.4884588718414309,0.6250799894332886,1.790726661682129,50000 -279.69251894950867,0.4559054374694824,7683.411033153534,22378,0,7683.411033153534,0.4984000325202942,2.4333972930908203,10000,7964.318654060364,0.6850087642669678,1.496760010719299,0.6200599670410156,1.7901036739349363,50000 -296.98755836486816,0.4871697425842285,8193.398950099945,23872,0,8193.398950099945,0.4937000274658203,2.4656965732574463,10000,8491.685319185257,0.6826769709587097,1.535045146942139,0.6202600002288818,1.8089746236801147,50000 -314.5280523300171,0.5178220272064209,8703.517453432083,25367,0,8703.517453432083,0.5014000535011292,2.360618352890014,10000,9019.426125764849,0.7011120915412903,1.4206559658050537,0.6384599804878235,1.7001821994781494,50000 -332.03754591941833,0.5506632328033447,9213.776574373243,26862,0,9213.776574373243,0.5125000476837158,2.392122507095337,10000,9547.277846574783,0.6947743892669678,1.4827384948730469,0.6357600092887878,1.7430270910263062,50000 -349.4531116485596,0.5828866958618164,9723.97894001007,28357,0,9723.97894001007,0.515500009059906,2.3387906551361084,10000,10074.980818033218,0.7403140664100647,1.251942753791809,0.6389600038528442,1.6875686645507812,50000 -366.90966606140137,0.6168107986450195,10234.157279968262,29851,0,10234.157279968262,0.5200999975204468,2.3046329021453857,10000,10602.70266866684,0.725984513759613,1.2846310138702393,0.6454600095748901,1.647182583808899,50000 -384.3558855056762,0.6501708030700684,10744.193604707718,31345,0,10744.193604707718,0.5267000198364258,2.2946465015411377,10000,11130.2699239254,0.7231544852256775,1.3217514753341677,0.6491000056266785,1.644545078277588,50000 -401.97428369522095,0.6830503940582275,11254.132838010788,32838,0,11254.132838010788,0.5159000158309937,2.34363865852356,10000,11657.91011095047,0.7084861397743225,1.379503846168518,0.64055997133255,1.693148136138916,50000 -419.3604607582092,0.7189726829528809,11764.288525104525,34333,0,11764.288525104525,0.524399995803833,2.2950656414031982,10000,12185.538932323456,0.7168765664100647,1.3567454814910889,0.6507399678230286,1.6542654037475586,50000 -436.809175491333,0.7555732727050781,12274.358264684675,35828,0,12274.358264684675,0.5097000002861023,2.392539501190185,10000,12713.143894910812,0.6993981003761292,1.4522290229797363,0.6362599730491638,1.7302823066711426,50000 -454.1600670814514,0.7909941673278809,12784.553744077682,37323,0,12784.553744077682,0.5236000418663025,2.282211303710937,10000,13240.775558948517,0.7249680757522583,1.295501708984375,0.6560800075531006,1.6071767807006836,50000 -471.69002628326416,0.8257701396942139,13294.51098227501,38818,0,13294.51098227501,0.5186000466346741,2.347128868103028,10000,13768.348722219467,0.7355508208274841,1.286131739616394,0.6459800004959106,1.6857054233551023,50000 -489.20293831825256,0.8625240325927734,13804.579131126404,40313,0,13804.579131126404,0.5261000394821167,2.311683416366577,10000,14296.019186019896,0.7302096486091614,1.310068964958191,0.6537799835205078,1.6593209505081177,50000 -507.0605084896088,0.8970699310302734,14314.632142066956,41808,0,14314.632142066956,0.5362000465393066,2.260235071182251,10000,14824.014763116837,0.7320830821990967,1.2917817831039429,0.6575599908828735,1.6185765266418457,50000 -524.4154839515686,0.935309648513794,14824.831728935242,43304,0,14824.831728935242,0.5380000472068787,2.22286057472229,10000,15351.65698647499,0.7353116869926453,1.2636276483535769,0.6688199639320374,1.5687376260757446,50000 -541.8120038509369,0.9696609973907472,15334.955714941025,44799,0,15334.955714941025,0.5301000475883484,2.2401185035705566,10000,15879.262532949448,0.7269411683082581,1.2829989194869995,0.6580599546432495,1.5886094570159912,50000 -559.0715284347534,1.003936529159546,15845.028705358503,46294,0,15845.028705358503,0.5383000373840332,2.21299695968628,10000,16406.679964780807,0.7327606678009033,1.280967116355896,0.6660400032997131,1.5779304504394531,50000 -576.2020993232727,1.0376520156860352,16355.009400129318,47789,0,16355.009400129318,0.5379000306129456,2.240457057952881,10000,16933.876542568207,0.7629344463348389,1.151097536087036,0.6612600088119507,1.5798641443252563,50000 -593.5875813961029,1.0741991996765137,16865.14865398407,49284,0,16865.14865398407,0.5371000170707703,2.2499301433563232,10000,17461.489314556122,0.7473692297935486,1.239794373512268,0.6687399744987488,1.597258687019348,50000 -611.8048617839813,1.1133863925933838,17375.061665296555,50779,0,17375.061665296555,0.5362000465393066,2.243440628051758,10000,17989.708940029144,0.7400350570678711,1.2261372804641724,0.6668199896812439,1.554896354675293,50000 -629.0761721134186,1.1479730606079102,17885.29244852066,52275,0,17885.29244852066,0.5449000000953674,2.2030839920043945,10000,18517.295283079147,0.7399553656578064,1.244918942451477,0.6628599762916565,1.5738775730133057,50000 -646.256591796875,1.186426877975464,18395.48710179329,53771,0,18395.48710179329,0.5424000024795532,2.229390382766724,10000,19044.75901412964,0.7357900142669678,1.252927303314209,0.6644200086593628,1.56941819190979,50000 -663.5889291763306,1.2300894260406494,18905.508989572525,55266,0,18905.508989572525,0.5421000123023987,2.2132768630981445,10000,19572.208678483963,0.7410913705825806,1.232458233833313,0.6733999848365784,1.5392425060272217,50000 -680.8869771957397,1.26725435256958,19415.54997587204,56761,0,19415.54997587204,0.5403000116348267,2.260906457901001,10000,20099.63463830948,0.7707669138908386,1.1692932844161987,0.6709399819374084,1.5907434225082395,50000 -698.3031196594238,1.3048710823059082,19925.609867811203,58256,0,19925.609867811203,0.5485000014305115,2.1709296703338623,10000,20627.20070528984,0.7694514989852905,1.1050457954406738,0.6767799854278564,1.4937835931777954,50000 -715.6581652164459,1.3460206985473633,20435.80241703987,59752,0,20435.80241703987,0.5371000170707703,2.2205097675323486,10000,21154.84163069725,0.7449377775192261,1.1936075687408447,0.666979968547821,1.5506370067596436,50000 -732.6220688819885,1.3891007900238037,20946.0156428814,61248,0,20946.0156428814,0.5407000184059143,2.2069180011749268,10000,21682.11145663261,0.7509366869926453,1.1979851722717283,0.6717399954795837,1.540035605430603,50000 -749.6996276378632,1.426433801651001,21456.26297426224,62744,0,21456.26297426224,0.5323000550270081,2.2445785999298096,10000,22209.52569293976,0.7419483065605164,1.2212562561035156,0.6708199977874756,1.5463603734970093,50000 -767.0168855190277,1.464949369430542,21966.35502266884,64239,0,21966.35502266884,0.5462000370025635,2.255506992340088,10000,22737.02538251877,0.7436822056770325,1.2639172077178955,0.6767199635505676,1.5742335319519043,50000 -784.1112470626831,1.5077550411224363,22476.510162115097,65735,0,22476.510162115097,0.5468000173568726,2.2133982181549072,10000,23264.367428541183,0.7479472160339355,1.2383085489273071,0.6757599711418152,1.5526641607284546,50000 -800.9966917037964,1.5449728965759275,22986.69782590866,67230,0,22986.69782590866,0.5466000437736511,2.168231725692749,10000,23791.52751684189,0.7716238498687744,1.069993019104004,0.6693199872970581,1.5192725658416748,50000 -818.1228411197662,1.5834143161773682,23496.80472302437,68726,0,23496.80472302437,0.5540000200271606,2.1980412006378174,10000,24318.850444078445,0.7636120915412903,1.1651519536972046,0.6775799989700317,1.5437074899673462,50000 -835.3253185749054,1.6262106895446775,24006.80624294281,70221,0,24006.80624294281,0.5615000128746033,2.142663717269897,10000,24846.14938569069,0.7656847834587097,1.155431866645813,0.68367999792099,1.5203362703323364,50000 -852.434784412384,1.664741277694702,24516.862579345703,71717,0,24516.862579345703,0.55840003490448,2.1221847534179688,10000,25373.404136896133,0.770527720451355,1.1102607250213623,0.6875999569892883,1.4715756177902222,50000 -869.5368921756744,1.7121210098266602,25026.77587389946,73212,0,25026.77587389946,0.5588000416755676,2.1481716632843018,10000,25900.518624067307,0.7600645422935486,1.1490471363067627,0.6841599941253662,1.4841029644012451,50000 -886.6352317333221,1.7510671615600586,25536.89725470543,74708,0,25536.89725470543,0.5622000098228455,2.1055116653442383,10000,26427.82671189308,0.7657246589660645,1.1162415742874146,0.6844599843025208,1.458303928375244,50000 -903.6346333026886,1.7926886081695557,26046.908839941025,76203,0,26046.908839941025,0.5706000328063965,2.115853786468506,10000,26954.93002486229,0.8078364133834839,1.0016721487045288,0.6889399886131287,1.5002516508102417,50000 -920.7117989063264,1.8335063457489007,26557.14165997505,77699,0,26557.14165997505,0.5588000416755676,2.096848964691162,10000,27482.33225798607,0.786551296710968,1.0268144607543943,0.6879599690437317,1.454639196395874,50000 -938.072808265686,1.8744986057281487,27067.081208705906,79194,0,27067.081208705906,0.5654000043869019,2.102102756500244,10000,28009.725796461105,0.7811902165412903,1.058087706565857,0.6916999816894531,1.4431703090667725,50000 -955.9951276779176,1.9179785251617432,27577.212859630585,80690,0,27577.212859630585,0.5612000226974487,2.126382827758789,10000,28537.874395132065,0.7735171914100647,1.107993483543396,0.6871599555015564,1.4867323637008667,50000 -973.1926457881927,1.9603497982025144,28087.34799814224,82186,0,28087.34799814224,0.5517000555992126,2.1907894611358643,10000,29065.30001354217,0.7631337642669678,1.1587564945220947,0.6825199723243713,1.5233628749847412,50000 -990.360454082489,2.00658917427063,28597.258448839188,83680,0,28597.258448839188,0.5542000532150269,2.152329683303833,10000,29592.47763466835,0.7662228941917419,1.1387059688568115,0.6846799850463867,1.4886494874954224,50000 -1007.7248740196228,2.047886610031128,29107.465238809586,85176,0,29107.465238809586,0.5660000443458557,2.104079246520996,10000,30120.14112353325,0.7777224183082581,1.1008814573287964,0.6924399733543396,1.4702638387680054,50000 -1024.9641880989077,2.0917553901672363,29617.56979894638,86672,0,29617.56979894638,0.5685000419616699,2.0924246311187744,10000,30647.5791118145,0.7954002022743225,1.0093271732330322,0.6930800080299377,1.4499201774597168,50000 -1042.3706114292145,2.15392541885376,30127.56591463089,88167,0,30127.56591463089,0.5648000240325928,2.1210591793060303,10000,31175.094157218933,0.7859733700752258,1.0578300952911377,0.693399965763092,1.460898995399475,50000 -1059.5920572280884,2.1985137462615967,30637.575980186462,89662,0,30637.575980186462,0.5674000382423401,2.1174557209014893,10000,31702.41963338852,0.7878866195678711,1.08149516582489,0.6979999542236328,1.4614171981811523,50000 -1076.8009662628174,2.2499001026153564,31147.80820083618,91158,0,31147.80820083618,0.5745000243186951,2.0515968799591064,10000,32229.96341466904,0.7898397445678711,1.0304040908813477,0.6977599859237671,1.4194663763046265,50000 -1094.0021858215332,2.295071840286255,31657.907905578613,92654,0,31657.907905578613,0.5738000273704529,2.067585229873657,10000,32757.361146211624,0.7855349183082581,1.0370293855667114,0.6988399624824524,1.4155194759368896,50000 -1111.0696558952332,2.34027099609375,32168.14310240745,94150,0,32168.14310240745,0.5834000110626221,2.04312801361084,10000,33284.76044559479,0.7915935516357422,1.011091709136963,0.7041199803352356,1.3878302574157717,50000 -1128.321323633194,2.410131454467773,32678.294413089752,95646,0,32678.294413089752,0.570900022983551,2.061049699783325,10000,33812.28715801239,0.8151108026504517,0.9238508343696594,0.7001399993896484,1.4183787107467651,50000 -1145.3379135131836,2.4528286457061768,33188.26053261757,97141,0,33188.26053261757,0.5785000324249268,2.047945022583008,10000,34339.3632774353,0.8090322017669678,0.9489533305168152,0.7054199576377869,1.388757824897766,50000 -1162.741782665253,2.5040063858032227,33698.45007276535,98637,0,33698.45007276535,0.5819000005722046,2.044516563415528,10000,34867.06263709068,0.8087332248687744,0.9831731915473938,0.7090199589729309,1.404129147529602,50000 -1180.1750495433807,2.553105354309082,34208.44733428955,100132,0,34208.44733428955,0.5609000325202942,2.109252691268921,10000,35394.59343075752,0.7820471525192261,1.040468454360962,0.687720000743866,1.4538909196853638,50000 -1197.4660267829895,2.60019588470459,34718.620924949646,101628,0,34718.620924949646,0.5788000226020813,2.02254056930542,10000,35922.155831575394,0.8019172549247742,0.9545334577560424,0.7048199772834778,1.3737319707870483,50000 -1214.7932357788086,2.648162364959717,35228.76759457588,103124,0,35228.76759457588,0.5879000425338745,2.022136926651001,10000,36449.7290225029,0.8040497303009033,0.9926905035972596,0.7096999883651733,1.3941349983215332,50000 -1232.5322148799896,2.693121671676636,35738.865753889084,104620,0,35738.865753889084,0.5788000226020813,2.058584213256836,10000,36977.66103959084,0.8269292116165161,0.9147905707359314,0.7056399583816528,1.4107757806777954,50000 -1249.7009994983673,2.7423999309539795,36249.07039260864,106116,0,36249.07039260864,0.5855000019073486,2.002440929412842,10000,37505.13443374634,0.8301976919174194,0.8830342292785645,0.711359977722168,1.3829030990600586,50000 -1266.9804100990295,2.787186861038208,36759.14828467369,107612,0,36759.14828467369,0.5803000330924988,2.026761054992676,10000,38032.58826184273,0.8184390664100647,0.9219950437545776,0.7115199565887451,1.380744695663452,50000 -1284.1733181476593,2.833984613418579,37269.2513358593,109107,0,37269.2513358593,0.5889000296592712,1.9659696817398071,10000,38559.98116540909,0.8204121589660645,0.8763977885246277,0.7163999676704407,1.3261253833770752,50000 -1301.520435810089,2.8837478160858154,37779.28477668762,110603,0,37779.28477668762,0.5920000076293945,1.9924134016036987,10000,39087.46141552925,0.8202527165412903,0.907325804233551,0.7150799632072449,1.3554683923721311,50000 -1318.7552318572998,2.933572769165039,38289.51672291756,112099,0,38289.51672291756,0.5906000137329102,2.007988452911377,10000,39615.02745246887,0.8202327489852905,0.9242339134216307,0.7185199856758118,1.3499921560287476,50000 -1335.9850897789,2.984329462051392,38799.4973552227,113594,0,38799.4973552227,0.5892000198364258,2.026318550109864,10000,40142.33880519867,0.8160673975944519,0.9513839483261108,0.7136200070381165,1.3845340013504028,50000 -1353.4041435718536,3.0307705402374268,39309.72198271752,115090,0,39309.72198271752,0.5927000045776367,1.9718732833862305,10000,40670.07956337929,0.8474569320678711,0.7873832583427429,0.7181199789047241,1.3268829584121704,50000 -1370.7485435009005,3.077587127685547,39819.89853024483,116586,0,39819.89853024483,0.5944000482559204,1.96568763256073,10000,41197.69870424271,0.8415776491165161,0.8181595206260681,0.7233999967575073,1.3208779096603394,50000 -1388.5198497772217,3.126605749130249,40329.83609485626,118081,0,40329.83609485626,0.5999000072479248,1.953519582748413,10000,41725.5061340332,0.8370934128761292,0.8302208185195923,0.7227999567985535,1.316156268119812,50000 -1405.8627269268036,3.179180860519409,40839.775539159775,119577,0,40839.775539159775,0.5933000445365906,1.958518624305725,10000,42252.89182281494,0.8384885191917419,0.8248199224472046,0.7274799942970276,1.3024871349334717,50000 -1423.086817741394,3.2281734943389893,41350.00688409805,121073,0,41350.00688409805,0.6027000546455383,1.9129961729049685,10000,42780.44771409035,0.8438097834587097,0.8003804683685303,0.7287399768829346,1.283615231513977,50000 -1440.245083808899,3.281466007232666,41860.03619909287,122569,0,41860.03619909287,0.5974000096321106,1.9827086925506592,10000,43307.74064588547,0.832051157951355,0.8912011384963989,0.7217599749565125,1.3543952703475952,50000 -1457.4404020309448,3.34682035446167,42370.23780179024,124065,0,42370.23780179024,0.6080000400543213,1.9164098501205444,10000,43835.2552819252,0.8768334984779358,0.6922958493232727,0.7325199842453003,1.277902126312256,50000 -1474.8576436042786,3.398554563522339,42880.30663514137,125561,0,42880.30663514137,0.6013000011444092,1.927111268043518,10000,44362.843577861786,0.857421875,0.7503759264945984,0.7276999950408936,1.2880154848098757,50000 -1492.1135189533234,3.44738507270813,43390.468705654144,127057,0,43390.468705654144,0.6014000177383423,1.9597022533416748,10000,44890.360988378525,0.8565648794174194,0.7832421064376831,0.7305200099945068,1.3154152631759644,50000 -1509.6416466236117,3.495047092437744,43900.60306334496,128553,0,43900.60306334496,0.6163000464439392,1.8852897882461548,10000,45418.12388706207,0.8598732352256775,0.7385032773017883,0.7359399795532227,1.2580225467681885,50000 -1526.9550709724426,3.543776512145996,44410.62669610977,130048,0,44410.62669610977,0.6109000444412231,1.88594388961792,10000,45945.561729192734,0.8598732352256775,0.7397308349609375,0.7367599606513977,1.2535072565078735,50000 -1544.0944755077362,3.5926764011383057,44920.61176490784,131543,0,44920.61176490784,0.613800048828125,1.8877145051956177,10000,46472.78757691383,0.8572624325752258,0.7386617660522461,0.7386400103569031,1.2487965822219849,50000 -1561.2148640155792,3.644584894180298,45430.52412056923,133037,0,45430.52412056923,0.613800048828125,1.889812588691712,10000,46999.9243516922,0.866609513759613,0.7195022106170654,0.73881995677948,1.249531865119934,50000 -1578.4732689857483,3.704412937164306,45940.66190671921,134533,0,45940.66190671921,0.6176000237464905,1.8791606426239007,10000,47527.43191242218,0.8819156289100647,0.64863520860672,0.7436599731445312,1.2403019666671753,50000 -1595.7042515277865,3.7541310787200928,46450.88255214691,136029,0,46450.88255214691,0.6157000064849854,1.88289487361908,10000,48054.984679460526,0.8819355964660645,0.6802816987037659,0.7430999875068665,1.2458999156951904,50000 -1612.8796932697296,3.811321020126343,46961.05339837074,137525,0,46961.05339837074,0.6220000386238098,1.8697481155395508,10000,48582.43872284889,0.8787667155265808,0.682250440120697,0.7416799664497375,1.249807357788086,50000 -1630.022828578949,3.863240003585816,47471.02515506744,139020,0,47471.02515506744,0.6232000589370728,1.8762387037277224,10000,49109.65690302849,0.8798628449440002,0.6878957748413086,0.7449600100517273,1.2519656419754028,50000 -1647.2985713481903,3.913106203079224,47981.19206619263,140516,0,47981.19206619263,0.6203000545501709,1.8631223440170288,10000,49637.203291893005,0.8801219463348389,0.6773303151130676,0.742859959602356,1.2454878091812134,50000 -1664.4843318462372,3.967681646347046,48491.136139154434,142011,0,48491.136139154434,0.6175000071525574,1.883688807487488,10000,50164.43881726265,0.8788862824440002,0.6838304996490479,0.7432599663734436,1.2508834600448608,50000 -1681.6815202236176,4.021223306655884,49001.24908995628,143507,0,49001.24908995628,0.6228000521659851,1.848327279090881,10000,50691.854425907135,0.9057118892669678,0.5698558688163757,0.7466599941253662,1.2103779315948486,50000 -1699.1049826145172,4.07612681388855,49511.237023592,145002,0,49511.237023592,0.6213000416755676,1.8598668575286863,10000,51219.37195444107,0.8981983065605164,0.6036162972450256,0.7472599744796753,1.228808045387268,50000 -1716.2525057792664,4.129936933517456,50021.47194981575,146498,0,50021.47194981575,0.6213000416755676,1.8459560871124268,10000,51746.86004567146,0.8981983065605164,0.5911901593208313,0.7471599578857422,1.2115702629089355,50000 -1733.6255042552948,4.185802459716797,50531.48566865921,147993,0,50531.48566865921,0.6234000325202942,1.849199652671814,10000,52274.35407853127,0.8964046239852905,0.6049222350120544,0.7515599727630615,1.208533525466919,50000 -1751.1035561561584,4.243193626403809,51041.4313583374,149489,0,51041.4313583374,0.6240000128746033,1.8540517091751096,10000,52801.8858203888,0.899832546710968,0.6207277774810791,0.7516599893569946,1.2279235124588013,50000 -1769.2537994384766,4.300515413284302,51551.43528985977,150984,0,51551.43528985977,0.6312000155448914,1.8340696096420288,10000,53330.150327920914,0.9048349857330322,0.5874025821685791,0.7528600096702576,1.2086260318756104,50000 -1786.5181086063385,4.345062255859375,52061.50852203369,152480,0,52061.50852203369,0.6295000314712524,1.83812952041626,10000,53857.58398604393,0.9223333597183228,0.5322718024253845,0.7513200044631958,1.214342713356018,50000 -1803.822417974472,4.4028167724609375,52571.40790128708,153975,0,52571.40790128708,0.6319000124931335,1.841834664344788,10000,54384.898052454,0.9167131781578064,0.5371339917182922,0.7554799914360046,1.202966809272766,50000 -1821.2124042510984,4.457838535308838,53081.45244860649,155471,0,53081.45244860649,0.6337000131607056,1.82937240600586,10000,54912.43840265274,0.9189253449440002,0.5368630886077881,0.7572599649429321,1.1976443529129028,50000 -1839.1349787712093,4.514558553695679,53591.627490758896,156967,0,53591.627490758896,0.6295000314712524,1.8315805196762085,10000,55440.64496970177,0.9168327450752258,0.5451948642730713,0.7570799589157104,1.195743203163147,50000 -1856.390554189682,4.57092547416687,54101.57953572273,158462,0,54101.57953572273,0.6314000487327576,1.8226438760757449,10000,55967.96126079559,0.9191047549247742,0.5344241857528687,0.7557399868965149,1.1946972608566284,50000 -1873.7629334926603,4.634932041168213,54611.48469758034,159957,0,54611.48469758034,0.6325000524520874,1.8268934488296509,10000,56495.354562044144,0.919144570827484,0.5355068445205688,0.7576599717140198,1.2013722658157349,50000 -1891.2067058086395,4.697674751281738,55121.617911338806,161453,0,55121.617911338806,0.6340000033378601,1.8122520446777344,10000,57023.04491233826,0.9230110049247742,0.5109298229217529,0.7595599889755249,1.1808478832244873,50000 -1908.4826707839968,4.7523229122161865,55631.54829573631,162948,0,55631.54829573631,0.6338000297546387,1.813969612121582,10000,57550.35655713081,0.9344706535339355,0.4809099733829498,0.7597000002861023,1.189298152923584,50000 -1925.8598954677584,4.809988737106323,56141.64248919487,164444,0,56141.64248919487,0.6353000402450562,1.8088539838790887,10000,58077.93835878372,0.9301458597183228,0.4850182235240936,0.7604199647903442,1.181099534034729,50000 -1943.168488740921,4.8650195598602295,56651.543511390686,165939,0,56651.543511390686,0.6385000348091125,1.806591510772705,10000,58605.25307846069,0.9299465417861938,0.4834108352661133,0.7608599662780762,1.1775710582733154,50000 -1960.2930953502653,4.926531791687012,57161.636585474014,167435,0,57161.636585474014,0.6402000188827515,1.803207874298096,10000,59132.583621263504,0.9329559803009032,0.4767115414142608,0.7617599964141846,1.1800211668014526,50000 -1977.6313047409053,4.987208604812622,57671.803205251694,168931,0,57671.803205251694,0.6378000378608704,1.8050771951675413,10000,59660.20059251785,0.9340720176696776,0.4837055206298828,0.7615999579429626,1.1844817399978638,50000 -1994.7446548938751,5.04884672164917,58181.90040636063,170427,0,58181.90040636063,0.6420000195503235,1.7998629808425903,10000,60187.52374267578,0.9329758882522584,0.4784607589244842,0.7625399827957153,1.1756603717803955,50000 -2011.99293589592,5.10937762260437,58692.09399271011,171923,0,58692.09399271011,0.6399000287055969,1.7999354600906372,10000,60715.07775473595,0.9389548301696776,0.4606906175613403,0.7628600001335144,1.1755813360214231,50000 -2029.256034374237,5.169875860214233,59202.0762693882,173418,0,59202.0762693882,0.64000004529953,1.8056416511535645,10000,61242.43579864502,0.9387754797935486,0.4628245830535888,0.761900007724762,1.1791095733642578,50000 -2046.667736530304,5.231820344924927,59712.00906252861,174913,0,59712.00906252861,0.6394000053405762,1.7951186895370483,10000,61769.894496679306,0.9377790093421936,0.4617302119731903,0.7626799941062927,1.1732220649719238,50000 -2063.812755346298,5.288837909698486,60222.23619198799,176409,0,60222.23619198799,0.64410001039505,1.790612816810608,10000,62297.375115156174,0.93949294090271,0.4578624963760376,0.76419997215271,1.1710017919540403,50000 -2080.94895029068,5.348676443099976,60732.223915576935,177904,0,60732.223915576935,0.6421000361442566,1.7899819612503052,10000,62824.61188578606,0.9384167790412904,0.4547231495380401,0.7644400000572205,1.167048454284668,50000 -2098.2528672218323,5.410964012145996,61242.4484193325,179400,0,61242.4484193325,0.6438000202178955,1.7903746366500854,10000,63352.25589585304,0.9395527839660645,0.4537563621997833,0.7645599842071533,1.1670894622802734,50000 -2115.509332180023,5.475275278091431,61752.42921423912,180895,0,61752.42921423912,0.6429000496864319,1.7909166812896729,10000,63879.60906338692,0.9403898119926452,0.4528295397758484,0.7649999856948853,1.168945074081421,50000 -2132.8234283924103,5.542778015136719,62262.553658008575,182391,0,62262.553658008575,0.6430000066757202,1.7924065589904783,10000,64407.16589832306,0.9399114847183228,0.4492985010147095,0.7644000053405762,1.1697618961334229,50000 -2150.1170587539673,6.747524976730347,62771.477128744125,183883,0,62771.477128744125,0.6425000429153442,1.794602870941162,10000,64934.63848924637,0.9393534660339355,0.4564688205718994,0.7648400068283081,1.1722710132598877,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/measurements.csv deleted file mode 100644 index 550facb75..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1972 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.52604043,6.9287834,,,,,,,,,,,,,, -1,,,0.0010363520123064,6.910816669464111,0.0010199999669566,6.91091251373291,50000.0,0.0012000000569969,6.910790920257568,10000.0,32.273086071014404,49.67532181739807,32.273086071014404,17.40215516090393,0.0,0.0 -100,0.5132727,6.8975954,,,,,,,,,,,,,, -200,0.54992706,6.8603363,,,,,,,,,,,,,, -300,0.5763813,6.7615232,,,,,,,,,,,,,, -400,0.62886685,6.686418,,,,,,,,,,,,,, -500,0.6432526,6.612051,,,,,,,,,,,,,, -600,0.65572786,6.529647,,,,,,,,,,,,,, -700,0.9977077,6.4364595,,,,,,,,,,,,,, -800,1.2202073,6.32804,,,,,,,,,,,,,, -900,1.1839918,6.275005,,,,,,,,,,,,,, -1000,2.4436398,6.221039,,,,,,,,,,,,,, -1100,1.6328164,6.1786933,,,,,,,,,,,,,, -1200,1.4590516,6.078861,,,,,,,,,,,,,, -1300,1.7700081,6.032664,,,,,,,,,,,,,, -1400,1.5763371,5.99924,,,,,,,,,,,,,, -1491,,,0.0832270383834838,5.25524377822876,0.0771199986338615,5.3119893074035645,50000.0,0.0553000010550022,5.544482231140137,10000.0,542.2085037231445,577.119710445404,542.2085037231445,34.842326641082764,0.0192625522613525,0.0 -1500,2.3492813,5.9396834,,,,,,,,,,,,,, -1600,3.098443,5.776701,,,,,,,,,,,,,, -1700,2.8277385,5.7617674,,,,,,,,,,,,,, -1800,2.6708627,5.768613,,,,,,,,,,,,,, -1900,3.1922178,5.660063,,,,,,,,,,,,,, -2000,2.0481493,5.7112093,,,,,,,,,,,,,, -2100,2.8098357,5.646862,,,,,,,,,,,,,, -2200,2.2730007,5.599942,,,,,,,,,,,,,, -2300,3.1922066,5.5399294,,,,,,,,,,,,,, -2400,2.8404627,5.515602,,,,,,,,,,,,,, -2500,3.0483162,5.4065228,,,,,,,,,,,,,, -2600,2.7332075,5.4106293,,,,,,,,,,,,,, -2700,5.6989083,5.4051456,,,,,,,,,,,,,, -2800,2.42146,5.299096,,,,,,,,,,,,,, -2900,2.9945822,5.3330717,,,,,,,,,,,,,, -2981,,,0.1989795863628387,4.21004581451416,0.1790999919176101,4.317634105682373,50000.0,0.1312000006437301,4.706563472747803,10000.0,1052.3712661266327,1104.8108565807345,1052.3712661266327,52.28858208656311,0.0503122806549072,0.0 -3000,4.6148214,5.203945,,,,,,,,,,,,,, -3100,3.6489673,5.182989,,,,,,,,,,,,,, -3200,2.3546028,5.1932893,,,,,,,,,,,,,, -3300,3.2097037,5.0747504,,,,,,,,,,,,,, -3400,4.7958455,5.071396,,,,,,,,,,,,,, -3500,3.3314672,5.0424075,,,,,,,,,,,,,, -3600,3.6293738,5.0411468,,,,,,,,,,,,,, -3700,3.2358031,4.9513264,,,,,,,,,,,,,, -3800,3.972486,4.9459295,,,,,,,,,,,,,, -3900,5.5437837,4.9301596,,,,,,,,,,,,,, -4000,4.977533,5.004155,,,,,,,,,,,,,, -4100,3.9627318,4.8288007,,,,,,,,,,,,,, -4200,3.5189373,4.8799486,,,,,,,,,,,,,, -4300,4.8530445,4.7631235,,,,,,,,,,,,,, -4400,3.1651895,4.6867414,,,,,,,,,,,,,, -4470,,,0.2876076102256775,3.585117816925049,0.2639600038528442,3.7350914478302,50000.0,0.19200000166893,4.257549285888672,10000.0,1562.2961611747742,1632.9713141918182,1562.2961611747742,70.44600534439087,0.0788969993591308,0.0 -4500,3.9264638,4.7499313,,,,,,,,,,,,,, -4600,4.1325226,4.694254,,,,,,,,,,,,,, -4700,3.6340816,4.5889373,,,,,,,,,,,,,, -4800,4.4641643,4.814889,,,,,,,,,,,,,, -4900,4.1612687,4.5681877,,,,,,,,,,,,,, -5000,3.6950407,4.6233,,,,,,,,,,,,,, -5100,5.304863,4.5023108,,,,,,,,,,,,,, -5200,4.2696686,4.4459705,,,,,,,,,,,,,, -5300,4.3809257,4.475561,,,,,,,,,,,,,, -5400,4.8222084,4.445606,,,,,,,,,,,,,, -5500,2.4932094,4.4113984,,,,,,,,,,,,,, -5600,3.0648224,4.5056973,,,,,,,,,,,,,, -5700,3.6474373,4.3500366,,,,,,,,,,,,,, -5800,2.2693295,4.386189,,,,,,,,,,,,,, -5900,2.615666,4.4282446,,,,,,,,,,,,,, -5960,,,0.388073980808258,2.992854356765747,0.3614400029182434,3.130856990814209,50000.0,0.2747000157833099,3.713503360748291,10000.0,2072.471899271012,2160.769720315933,2072.471899271012,87.99014830589294,0.1064472198486328,0.0 -6000,3.394574,4.369294,,,,,,,,,,,,,, -6100,3.1815603,4.37648,,,,,,,,,,,,,, -6200,4.397026,4.369485,,,,,,,,,,,,,, -6300,2.6565535,4.2724085,,,,,,,,,,,,,, -6400,2.6842594,4.3273563,,,,,,,,,,,,,, -6500,3.6978292,4.27097,,,,,,,,,,,,,, -6600,2.7904406,4.2915974,,,,,,,,,,,,,, -6700,3.0354402,4.2081027,,,,,,,,,,,,,, -6800,2.9186122,4.1441946,,,,,,,,,,,,,, -6900,2.5946782,4.2156353,,,,,,,,,,,,,, -7000,5.691961,4.1972322,,,,,,,,,,,,,, -7100,3.0842752,4.192257,,,,,,,,,,,,,, -7200,2.7176366,4.1912556,,,,,,,,,,,,,, -7300,2.4093204,4.19752,,,,,,,,,,,,,, -7400,2.6255667,4.092894,,,,,,,,,,,,,, -7450,,,0.4523876011371612,2.617628335952759,0.4184799790382385,2.7827277183532715,50000.0,0.3164000213146209,3.388169765472412,10000.0,2582.5007004737854,2688.537534236908,2582.5007004737854,105.64619159698486,0.1386871337890625,0.0 -7500,2.761042,4.1496806,,,,,,,,,,,,,, -7600,3.301669,4.181832,,,,,,,,,,,,,, -7700,2.9193285,4.1627645,,,,,,,,,,,,,, -7800,2.3164258,4.014605,,,,,,,,,,,,,, -7900,2.8525624,4.1450224,,,,,,,,,,,,,, -8000,2.821714,4.0683365,,,,,,,,,,,,,, -8100,2.8109543,4.03735,,,,,,,,,,,,,, -8200,3.7176714,4.0739746,,,,,,,,,,,,,, -8300,2.5780592,4.027705,,,,,,,,,,,,,, -8400,2.3810725,3.9690266,,,,,,,,,,,,,, -8500,2.3440394,4.076456,,,,,,,,,,,,,, -8600,2.2941935,4.023225,,,,,,,,,,,,,, -8700,2.5349007,3.9213955,,,,,,,,,,,,,, -8800,2.6973062,4.0017257,,,,,,,,,,,,,, -8900,2.360788,3.9739947,,,,,,,,,,,,,, -8942,,,0.5526944994926453,2.124098539352417,0.4767199754714966,2.479666233062744,50000.0,0.3680000305175781,3.117877244949341,10000.0,3092.7613253593445,3216.206754922867,3092.7613253593445,122.9762909412384,0.1666934490203857,0.0 -9000,1.8861032,3.9868464,,,,,,,,,,,,,, -9100,2.7171361,3.9276822,,,,,,,,,,,,,, -9200,2.0662153,3.9314094,,,,,,,,,,,,,, -9300,2.1298788,3.9038157,,,,,,,,,,,,,, -9400,2.1367884,3.851864,,,,,,,,,,,,,, -9500,1.8775679,3.9588141,,,,,,,,,,,,,, -9600,1.7085259,3.8965552,,,,,,,,,,,,,, -9700,1.9505674,3.8708978,,,,,,,,,,,,,, -9800,1.6441973,3.8577938,,,,,,,,,,,,,, -9900,3.3862622,3.898024,,,,,,,,,,,,,, -10000,2.908264,3.903823,,,,,,,,,,,,,, -10100,2.3785648,3.83133,,,,,,,,,,,,,, -10200,2.257267,3.822773,,,,,,,,,,,,,, -10300,1.5767171,3.7632923,,,,,,,,,,,,,, -10400,1.6474746,3.815244,,,,,,,,,,,,,, -10434,,,0.5689373016357422,2.022469520568848,0.522159993648529,2.2517900466918945,50000.0,0.4036000072956085,2.9123005867004395,10000.0,3602.929986476898,3743.8133985996246,3602.929986476898,140.33562183380127,0.1949865818023681,0.0 -10500,2.0135272,3.7800026,,,,,,,,,,,,,, -10600,2.5066192,3.7429247,,,,,,,,,,,,,, -10700,2.5489779,3.7523928,,,,,,,,,,,,,, -10800,1.5452917,3.772097,,,,,,,,,,,,,, -10900,1.6769338,3.708373,,,,,,,,,,,,,, -11000,1.25515,3.680966,,,,,,,,,,,,,, -11100,1.7842301,3.7994082,,,,,,,,,,,,,, -11200,1.797121,3.7733104,,,,,,,,,,,,,, -11300,1.9642707,3.71143,,,,,,,,,,,,,, -11400,2.0475433,3.7088652,,,,,,,,,,,,,, -11500,2.2026818,3.7105935,,,,,,,,,,,,,, -11600,2.0440845,3.717077,,,,,,,,,,,,,, -11700,2.7116756,3.5993195,,,,,,,,,,,,,, -11800,1.7417134,3.6514344,,,,,,,,,,,,,, -11900,1.4012059,3.654393,,,,,,,,,,,,,, -11926,,,0.5983737111091614,1.9081037044525144,0.5486199855804443,2.138624906539917,50000.0,0.4262000322341919,2.801759719848633,10000.0,4112.95587015152,4271.180375099182,4112.95587015152,157.598552942276,0.2236585617065429,0.0 -12000,1.9777753,3.686514,,,,,,,,,,,,,, -12100,2.1708293,3.6948695,,,,,,,,,,,,,, -12200,1.8644549,3.6923685,,,,,,,,,,,,,, -12300,1.4750901,3.6070113,,,,,,,,,,,,,, -12400,1.5242966,3.6464653,,,,,,,,,,,,,, -12500,2.0239673,3.6548352,,,,,,,,,,,,,, -12600,1.4339596,3.707909,,,,,,,,,,,,,, -12700,1.5371718,3.59032,,,,,,,,,,,,,, -12800,1.8614614,3.6393778,,,,,,,,,,,,,, -12900,1.8664848,3.634951,,,,,,,,,,,,,, -13000,1.5827912,3.6198978,,,,,,,,,,,,,, -13100,1.7732857,3.6743426,,,,,,,,,,,,,, -13200,2.401636,3.6599038,,,,,,,,,,,,,, -13300,1.5576053,3.6040688,,,,,,,,,,,,,, -13400,1.5519108,3.7022452,,,,,,,,,,,,,, -13419,,,0.6163105964660645,1.7988837957382202,0.5671799778938293,2.038691282272339,50000.0,0.4473000168800354,2.6989235877990723,10000.0,4623.103399991989,4798.86471414566,4623.103399991989,175.05454540252686,0.2537183761596679,0.0 -13500,1.9962001,3.6561577,,,,,,,,,,,,,, -13600,1.8292577,3.6020174,,,,,,,,,,,,,, -13700,1.8049037,3.5344071,,,,,,,,,,,,,, -13800,1.7144449,3.655117,,,,,,,,,,,,,, -13900,1.9235636,3.5352366,,,,,,,,,,,,,, -14000,1.3138016,3.6563146,,,,,,,,,,,,,, -14100,1.6726885,3.5157866,,,,,,,,,,,,,, -14200,1.4524466,3.5397854,,,,,,,,,,,,,, -14300,1.5512129,3.5666726,,,,,,,,,,,,,, -14400,1.4467055,3.5342402,,,,,,,,,,,,,, -14500,1.8467348,3.5754375,,,,,,,,,,,,,, -14600,1.6871942,3.5281153,,,,,,,,,,,,,, -14700,1.7528248,3.5873675,,,,,,,,,,,,,, -14800,1.407281,3.5194628,,,,,,,,,,,,,, -14900,1.6648222,3.4844682,,,,,,,,,,,,,, -14912,,,0.6312180757522583,1.7623796463012695,0.5798799991607666,1.994551658630371,50000.0,0.4590000212192535,2.626842975616455,10000.0,5133.311586380005,5326.436780452728,5133.311586380005,192.3369791507721,0.2829594612121582,0.0 -15000,2.083211,3.482032,,,,,,,,,,,,,, -15100,1.6609802,3.5560527,,,,,,,,,,,,,, -15200,2.0711093,3.574633,,,,,,,,,,,,,, -15300,1.3657968,3.4808216,,,,,,,,,,,,,, -15400,1.8903617,3.527814,,,,,,,,,,,,,, -15500,1.4435221,3.4935918,,,,,,,,,,,,,, -15600,1.3558968,3.5894682,,,,,,,,,,,,,, -15700,1.6818888,3.5987837,,,,,,,,,,,,,, -15800,1.1881012,3.4386003,,,,,,,,,,,,,, -15900,1.297229,3.4884841,,,,,,,,,,,,,, -16000,1.3290033,3.5597036,,,,,,,,,,,,,, -16100,1.2100083,3.41075,,,,,,,,,,,,,, -16200,2.1326857,3.4872332,,,,,,,,,,,,,, -16300,1.3004656,3.3975363,,,,,,,,,,,,,, -16400,1.4321554,3.4533622,,,,,,,,,,,,,, -16405,,,0.6322146058082581,1.7425355911254885,0.5858799815177917,1.9565809965133667,50000.0,0.4595000147819519,2.652688980102539,10000.0,5643.3654227256775,5853.849092960358,5643.3654227256775,209.60658407211304,0.3215117454528808,0.0 -16500,1.710975,3.4629574,,,,,,,,,,,,,, -16600,1.6099799,3.4315035,,,,,,,,,,,,,, -16700,1.6398785,3.5107276,,,,,,,,,,,,,, -16800,1.3786061,3.4508717,,,,,,,,,,,,,, -16900,1.5608302,3.4755812,,,,,,,,,,,,,, -17000,1.2835732,3.4379907,,,,,,,,,,,,,, -17100,1.4657571,3.45553,,,,,,,,,,,,,, -17200,1.601437,3.4261732,,,,,,,,,,,,,, -17300,1.3000892,3.4031672,,,,,,,,,,,,,, -17400,1.9597474,3.4082365,,,,,,,,,,,,,, -17500,1.3774037,3.4421759,,,,,,,,,,,,,, -17600,1.385517,3.4748294,,,,,,,,,,,,,, -17700,1.6447854,3.5226183,,,,,,,,,,,,,, -17800,1.5599107,3.4605472,,,,,,,,,,,,,, -17899,,,0.6486168503761292,1.6495143175125122,0.5984199643135071,1.885051250457764,50000.0,0.4756000339984894,2.5317037105560303,10000.0,6153.40634727478,6381.513226747513,6153.40634727478,227.14486122131348,0.3570168018341064,0.0 -17900,1.5565715,3.4791136,,,,,,,,,,,,,, -18000,1.645308,3.5347552,,,,,,,,,,,,,, -18100,1.5005445,3.4053211,,,,,,,,,,,,,, -18200,1.606543,3.4217021,,,,,,,,,,,,,, -18300,1.3804404,3.4131525,,,,,,,,,,,,,, -18400,1.4195439,3.4729605,,,,,,,,,,,,,, -18500,1.1540015,3.3804114,,,,,,,,,,,,,, -18600,1.1430731,3.4145322,,,,,,,,,,,,,, -18700,1.4373175,3.3824227,,,,,,,,,,,,,, -18800,1.5070195,3.4513338,,,,,,,,,,,,,, -18900,1.2451577,3.4212868,,,,,,,,,,,,,, -19000,1.7549337,3.5016046,,,,,,,,,,,,,, -19100,1.4828451,3.4578643,,,,,,,,,,,,,, -19200,1.7790526,3.4966624,,,,,,,,,,,,,, -19300,1.669355,3.4000344,,,,,,,,,,,,,, -19392,,,0.6943359375,1.483161449432373,0.615619957447052,1.8391530513763428,50000.0,0.4908000230789184,2.465847969055176,10000.0,6663.387640237808,6909.0832579135895,6663.387640237808,244.6494562625885,0.3890948295593261,0.0 -19400,1.197369,3.5160694,,,,,,,,,,,,,, -19500,1.8456296,3.4478736,,,,,,,,,,,,,, -19600,1.490853,3.4147162,,,,,,,,,,,,,, -19700,1.5207739,3.440584,,,,,,,,,,,,,, -19800,1.3347734,3.3889537,,,,,,,,,,,,,, -19900,1.3515108,3.462482,,,,,,,,,,,,,, -20000,1.2648532,3.3849194,,,,,,,,,,,,,, -20100,1.269572,3.3529084,,,,,,,,,,,,,, -20200,1.5281059,3.3682184,,,,,,,,,,,,,, -20300,1.3202645,3.4744215,,,,,,,,,,,,,, -20400,1.1321639,3.3514934,,,,,,,,,,,,,, -20500,1.5474819,3.373765,,,,,,,,,,,,,, -20600,1.3315121,3.3285928,,,,,,,,,,,,,, -20700,1.2881422,3.3789597,,,,,,,,,,,,,, -20800,1.2493774,3.3646405,,,,,,,,,,,,,, -20885,,,0.6927216053009033,1.4884588718414309,0.6250799894332886,1.790726661682129,50000.0,0.4933000206947326,2.453991174697876,10000.0,7173.381569385529,7436.744100093842,7173.381569385529,262.23272013664246,0.420799970626831,0.0 -20900,1.3334371,3.489345,,,,,,,,,,,,,, -21000,1.2202258,3.4113617,,,,,,,,,,,,,, -21100,1.7937739,3.4375918,,,,,,,,,,,,,, -21200,1.2602837,3.4063866,,,,,,,,,,,,,, -21300,1.5180097,3.4094348,,,,,,,,,,,,,, -21400,1.1743021,3.2939277,,,,,,,,,,,,,, -21500,1.1296952,3.397254,,,,,,,,,,,,,, -21600,1.2411655,3.420697,,,,,,,,,,,,,, -21700,1.2127929,3.316807,,,,,,,,,,,,,, -21800,1.6072459,3.3778813,,,,,,,,,,,,,, -21900,1.1681685,3.4353504,,,,,,,,,,,,,, -22000,1.1818383,3.3509684,,,,,,,,,,,,,, -22100,1.2740396,3.3304899,,,,,,,,,,,,,, -22200,1.5177191,3.3082411,,,,,,,,,,,,,, -22300,1.6417571,3.3634121,,,,,,,,,,,,,, -22378,,,0.6850087642669678,1.496760010719299,0.6200599670410156,1.7901036739349363,50000.0,0.4984000325202942,2.4333972930908203,10000.0,7683.411033153534,7964.318654060364,7683.411033153534,279.69251894950867,0.4559054374694824,0.0 -22400,1.3992921,3.334021,,,,,,,,,,,,,, -22500,1.3769116,3.3733969,,,,,,,,,,,,,, -22600,1.3297116,3.313531,,,,,,,,,,,,,, -22700,1.3780077,3.3138533,,,,,,,,,,,,,, -22800,1.4545768,3.3393817,,,,,,,,,,,,,, -22900,1.2737696,3.29635,,,,,,,,,,,,,, -23000,1.3978614,3.3542335,,,,,,,,,,,,,, -23100,1.3517092,3.3673444,,,,,,,,,,,,,, -23200,1.3530818,3.3489773,,,,,,,,,,,,,, -23300,1.2355926,3.3322773,,,,,,,,,,,,,, -23400,1.7545285,3.4098353,,,,,,,,,,,,,, -23500,1.3447417,3.3727462,,,,,,,,,,,,,, -23600,1.2758754,3.3659399,,,,,,,,,,,,,, -23700,1.571544,3.3303306,,,,,,,,,,,,,, -23800,1.062575,3.3217807,,,,,,,,,,,,,, -23872,,,0.6826769709587097,1.535045146942139,0.6202600002288818,1.8089746236801147,50000.0,0.4937000274658203,2.4656965732574463,10000.0,8193.398950099945,8491.685319185257,8193.398950099945,296.98755836486816,0.4871697425842285,0.0 -23900,1.2980396,3.304303,,,,,,,,,,,,,, -24000,1.4061942,3.4148407,,,,,,,,,,,,,, -24100,1.6145341,3.3642097,,,,,,,,,,,,,, -24200,1.2524949,3.3477607,,,,,,,,,,,,,, -24300,1.2241527,3.3407362,,,,,,,,,,,,,, -24400,1.2690636,3.3499467,,,,,,,,,,,,,, -24500,1.2204918,3.302281,,,,,,,,,,,,,, -24600,1.256431,3.3095136,,,,,,,,,,,,,, -24700,1.1077812,3.2738693,,,,,,,,,,,,,, -24800,1.1801085,3.3032353,,,,,,,,,,,,,, -24900,1.2083386,3.4435363,,,,,,,,,,,,,, -25000,1.4587917,3.3081331,,,,,,,,,,,,,, -25100,1.4729995,3.306713,,,,,,,,,,,,,, -25200,1.236328,3.2509859,,,,,,,,,,,,,, -25300,1.3341528,3.309268,,,,,,,,,,,,,, -25367,,,0.7011120915412903,1.4206559658050537,0.6384599804878235,1.7001821994781494,50000.0,0.5014000535011292,2.360618352890014,10000.0,8703.517453432083,9019.426125764849,8703.517453432083,314.5280523300171,0.5178220272064209,0.0 -25400,1.6003319,3.4161398,,,,,,,,,,,,,, -25500,1.6252416,3.2964716,,,,,,,,,,,,,, -25600,1.5501766,3.3241775,,,,,,,,,,,,,, -25700,1.1895077,3.3599668,,,,,,,,,,,,,, -25800,1.1531407,3.299532,,,,,,,,,,,,,, -25900,1.2547375,3.410922,,,,,,,,,,,,,, -26000,1.2106982,3.2345648,,,,,,,,,,,,,, -26100,1.5124038,3.321769,,,,,,,,,,,,,, -26200,1.446198,3.349401,,,,,,,,,,,,,, -26300,1.2508254,3.2672515,,,,,,,,,,,,,, -26400,1.2821755,3.3725219,,,,,,,,,,,,,, -26500,1.1149403,3.2779436,,,,,,,,,,,,,, -26600,1.343527,3.2295313,,,,,,,,,,,,,, -26700,1.4522853,3.2396812,,,,,,,,,,,,,, -26800,1.3023225,3.3313394,,,,,,,,,,,,,, -26862,,,0.6947743892669678,1.4827384948730469,0.6357600092887878,1.7430270910263062,50000.0,0.5125000476837158,2.392122507095337,10000.0,9213.776574373243,9547.277846574783,9213.776574373243,332.03754591941833,0.5506632328033447,0.0 -26900,1.2104517,3.2771032,,,,,,,,,,,,,, -27000,1.1364368,3.404171,,,,,,,,,,,,,, -27100,1.2725766,3.4014108,,,,,,,,,,,,,, -27200,1.1795317,3.3314147,,,,,,,,,,,,,, -27300,1.7179431,3.2426634,,,,,,,,,,,,,, -27400,1.3259382,3.2313561,,,,,,,,,,,,,, -27500,1.4468455,3.2642677,,,,,,,,,,,,,, -27600,1.3662386,3.2753584,,,,,,,,,,,,,, -27700,1.2922866,3.2561376,,,,,,,,,,,,,, -27800,1.3418843,3.2818499,,,,,,,,,,,,,, -27900,1.4042823,3.3420298,,,,,,,,,,,,,, -28000,1.3739111,3.216642,,,,,,,,,,,,,, -28100,1.377508,3.2852304,,,,,,,,,,,,,, -28200,1.4766324,3.25557,,,,,,,,,,,,,, -28300,1.8297734,3.2144723,,,,,,,,,,,,,, -28357,,,0.7403140664100647,1.251942753791809,0.6389600038528442,1.6875686645507812,50000.0,0.515500009059906,2.3387906551361084,10000.0,9723.97894001007,10074.980818033218,9723.97894001007,349.4531116485596,0.5828866958618164,0.0 -28400,1.5668918,3.2980459,,,,,,,,,,,,,, -28500,1.6098366,3.256988,,,,,,,,,,,,,, -28600,1.2691905,3.2522388,,,,,,,,,,,,,, -28700,1.379167,3.26927,,,,,,,,,,,,,, -28800,1.1777447,3.2115917,,,,,,,,,,,,,, -28900,1.492086,3.3463225,,,,,,,,,,,,,, -29000,1.510336,3.340687,,,,,,,,,,,,,, -29100,1.1991415,3.1625626,,,,,,,,,,,,,, -29200,1.176307,3.2543032,,,,,,,,,,,,,, -29300,1.1453644,3.2742069,,,,,,,,,,,,,, -29400,1.3958269,3.2926564,,,,,,,,,,,,,, -29500,1.2593949,3.3292513,,,,,,,,,,,,,, -29600,1.2669514,3.2850816,,,,,,,,,,,,,, -29700,1.3322779,3.2022722,,,,,,,,,,,,,, -29800,1.4846616,3.2532754,,,,,,,,,,,,,, -29851,,,0.725984513759613,1.2846310138702393,0.6454600095748901,1.647182583808899,50000.0,0.5200999975204468,2.3046329021453857,10000.0,10234.157279968262,10602.70266866684,10234.157279968262,366.90966606140137,0.6168107986450195,0.0 -29900,1.6181884,3.322631,,,,,,,,,,,,,, -30000,1.3021017,3.2379322,,,,,,,,,,,,,, -30100,1.4342027,3.2418578,,,,,,,,,,,,,, -30200,1.2653979,3.2866197,,,,,,,,,,,,,, -30300,1.325359,3.1768208,,,,,,,,,,,,,, -30400,1.2618303,3.282712,,,,,,,,,,,,,, -30500,1.2444879,3.2542253,,,,,,,,,,,,,, -30600,1.4092797,3.298723,,,,,,,,,,,,,, -30700,1.3076743,3.2919776,,,,,,,,,,,,,, -30800,1.5310128,3.2804646,,,,,,,,,,,,,, -30900,1.2382643,3.1901164,,,,,,,,,,,,,, -31000,1.5332437,3.2490947,,,,,,,,,,,,,, -31100,1.6563562,3.2763667,,,,,,,,,,,,,, -31200,1.3463933,3.23498,,,,,,,,,,,,,, -31300,1.278772,3.2965765,,,,,,,,,,,,,, -31345,,,0.7231544852256775,1.3217514753341677,0.6491000056266785,1.644545078277588,50000.0,0.5267000198364258,2.2946465015411377,10000.0,10744.193604707718,11130.2699239254,10744.193604707718,384.3558855056762,0.6501708030700684,0.0 -31400,1.2589538,3.2727818,,,,,,,,,,,,,, -31500,1.4031484,3.3323116,,,,,,,,,,,,,, -31600,1.3198745,3.3227222,,,,,,,,,,,,,, -31700,1.3695209,3.2415943,,,,,,,,,,,,,, -31800,1.3939478,3.262298,,,,,,,,,,,,,, -31900,1.2880065,3.232995,,,,,,,,,,,,,, -32000,1.5369128,3.2841105,,,,,,,,,,,,,, -32100,1.2993068,3.237747,,,,,,,,,,,,,, -32200,1.39984,3.257546,,,,,,,,,,,,,, -32300,1.2887461,3.2451057,,,,,,,,,,,,,, -32400,1.4289235,3.265139,,,,,,,,,,,,,, -32500,1.3823206,3.2334175,,,,,,,,,,,,,, -32600,1.2675211,3.219848,,,,,,,,,,,,,, -32700,1.3769524,3.2502308,,,,,,,,,,,,,, -32800,1.5003475,3.2113295,,,,,,,,,,,,,, -32838,,,0.7084861397743225,1.379503846168518,0.64055997133255,1.693148136138916,50000.0,0.5159000158309937,2.34363865852356,10000.0,11254.132838010788,11657.91011095047,11254.132838010788,401.97428369522095,0.6830503940582275,0.0 -32900,1.4314054,3.2466412,,,,,,,,,,,,,, -33000,1.371249,3.2278404,,,,,,,,,,,,,, -33100,1.3211876,3.1487117,,,,,,,,,,,,,, -33200,1.4738017,3.231955,,,,,,,,,,,,,, -33300,1.3394686,3.2503514,,,,,,,,,,,,,, -33400,1.5157082,3.1815646,,,,,,,,,,,,,, -33500,1.465668,3.267971,,,,,,,,,,,,,, -33600,1.5106422,3.2926629,,,,,,,,,,,,,, -33700,1.4347494,3.2496784,,,,,,,,,,,,,, -33800,1.4286826,3.274297,,,,,,,,,,,,,, -33900,1.2928174,3.3086152,,,,,,,,,,,,,, -34000,1.3468441,3.2350388,,,,,,,,,,,,,, -34100,1.3494016,3.323135,,,,,,,,,,,,,, -34200,1.450021,3.2717276,,,,,,,,,,,,,, -34300,1.3976268,3.2389722,,,,,,,,,,,,,, -34333,,,0.7168765664100647,1.3567454814910889,0.6507399678230286,1.6542654037475586,50000.0,0.524399995803833,2.2950656414031982,10000.0,11764.288525104525,12185.538932323456,11764.288525104525,419.3604607582092,0.7189726829528809,0.0 -34400,1.407877,3.2116678,,,,,,,,,,,,,, -34500,1.3231379,3.2133324,,,,,,,,,,,,,, -34600,1.2448906,3.1738906,,,,,,,,,,,,,, -34700,1.7460483,3.2781646,,,,,,,,,,,,,, -34800,1.38943,3.283615,,,,,,,,,,,,,, -34900,1.4648196,3.2391572,,,,,,,,,,,,,, -35000,1.378158,3.1304343,,,,,,,,,,,,,, -35100,1.3012983,3.1778753,,,,,,,,,,,,,, -35200,1.2480156,3.2146254,,,,,,,,,,,,,, -35300,1.5157093,3.1812055,,,,,,,,,,,,,, -35400,1.3824315,3.2097867,,,,,,,,,,,,,, -35500,1.5691484,3.2352707,,,,,,,,,,,,,, -35600,1.4011917,3.247718,,,,,,,,,,,,,, -35700,1.3074375,3.249601,,,,,,,,,,,,,, -35800,1.5796112,3.2895405,,,,,,,,,,,,,, -35828,,,0.6993981003761292,1.4522290229797363,0.6362599730491638,1.7302823066711426,50000.0,0.5097000002861023,2.392539501190185,10000.0,12274.358264684675,12713.143894910812,12274.358264684675,436.809175491333,0.7555732727050781,0.0 -35900,1.3434781,3.2084408,,,,,,,,,,,,,, -36000,1.3709211,3.26916,,,,,,,,,,,,,, -36100,1.3300927,3.2383668,,,,,,,,,,,,,, -36200,1.5176845,3.1198494,,,,,,,,,,,,,, -36300,1.4303597,3.1726954,,,,,,,,,,,,,, -36400,1.463298,3.2549691,,,,,,,,,,,,,, -36500,1.4438049,3.2384572,,,,,,,,,,,,,, -36600,1.4525933,3.3214202,,,,,,,,,,,,,, -36700,1.3316209,3.161669,,,,,,,,,,,,,, -36800,1.4276853,3.1091306,,,,,,,,,,,,,, -36900,1.620364,3.168882,,,,,,,,,,,,,, -37000,1.3962731,3.2307813,,,,,,,,,,,,,, -37100,1.5148884,3.2349474,,,,,,,,,,,,,, -37200,1.4193096,3.1623163,,,,,,,,,,,,,, -37300,1.5527657,3.2628214,,,,,,,,,,,,,, -37323,,,0.7249680757522583,1.295501708984375,0.6560800075531006,1.6071767807006836,50000.0,0.5236000418663025,2.282211303710937,10000.0,12784.553744077682,13240.775558948517,12784.553744077682,454.1600670814514,0.7909941673278809,0.0 -37400,1.5480728,3.1315508,,,,,,,,,,,,,, -37500,1.3559225,3.2634761,,,,,,,,,,,,,, -37600,1.4762336,3.1958976,,,,,,,,,,,,,, -37700,1.8913774,3.2775173,,,,,,,,,,,,,, -37800,1.5693982,3.2122319,,,,,,,,,,,,,, -37900,1.5420396,3.2161448,,,,,,,,,,,,,, -38000,1.7272574,3.190417,,,,,,,,,,,,,, -38100,1.5534394,3.188055,,,,,,,,,,,,,, -38200,1.780285,3.2173784,,,,,,,,,,,,,, -38300,1.6755955,3.1777408,,,,,,,,,,,,,, -38400,1.4856423,3.1569805,,,,,,,,,,,,,, -38500,1.9273074,3.1746244,,,,,,,,,,,,,, -38600,1.531819,3.214456,,,,,,,,,,,,,, -38700,1.36138,3.1978974,,,,,,,,,,,,,, -38800,1.5844177,3.1946483,,,,,,,,,,,,,, -38818,,,0.7355508208274841,1.286131739616394,0.6459800004959106,1.6857054233551023,50000.0,0.5186000466346741,2.347128868103028,10000.0,13294.51098227501,13768.348722219467,13294.51098227501,471.69002628326416,0.8257701396942139,0.0 -38900,1.5439727,3.1598468,,,,,,,,,,,,,, -39000,1.4002359,3.2420146,,,,,,,,,,,,,, -39100,1.5429097,3.1353605,,,,,,,,,,,,,, -39200,1.5904499,3.2191994,,,,,,,,,,,,,, -39300,1.5910214,3.2429457,,,,,,,,,,,,,, -39400,1.469587,3.2463813,,,,,,,,,,,,,, -39500,1.543441,3.1412146,,,,,,,,,,,,,, -39600,1.3894149,3.1897483,,,,,,,,,,,,,, -39700,1.5635599,3.154858,,,,,,,,,,,,,, -39800,1.5279983,3.2168217,,,,,,,,,,,,,, -39900,1.4802926,3.1736095,,,,,,,,,,,,,, -40000,1.5224459,3.1719127,,,,,,,,,,,,,, -40100,1.5518719,3.2159061,,,,,,,,,,,,,, -40200,1.5761151,3.1790884,,,,,,,,,,,,,, -40300,1.4684937,3.19383,,,,,,,,,,,,,, -40313,,,0.7302096486091614,1.310068964958191,0.6537799835205078,1.6593209505081177,50000.0,0.5261000394821167,2.311683416366577,10000.0,13804.579131126404,14296.019186019896,13804.579131126404,489.20293831825256,0.8625240325927734,0.0 -40400,1.5928531,3.1989684,,,,,,,,,,,,,, -40500,1.5976732,3.2331264,,,,,,,,,,,,,, -40600,1.6464571,3.2828882,,,,,,,,,,,,,, -40700,1.5455147,3.193397,,,,,,,,,,,,,, -40800,1.725735,3.142513,,,,,,,,,,,,,, -40900,1.6014328,3.1881971,,,,,,,,,,,,,, -41000,1.5205693,3.1073084,,,,,,,,,,,,,, -41100,1.627625,3.165894,,,,,,,,,,,,,, -41200,1.5961908,3.1945071,,,,,,,,,,,,,, -41300,1.8023931,3.2262394,,,,,,,,,,,,,, -41400,1.4224718,3.154045,,,,,,,,,,,,,, -41500,1.6471491,3.2841218,,,,,,,,,,,,,, -41600,1.6347506,3.285482,,,,,,,,,,,,,, -41700,1.60094,3.2613678,,,,,,,,,,,,,, -41800,1.5268676,3.1961615,,,,,,,,,,,,,, -41808,,,0.7320830821990967,1.2917817831039429,0.6575599908828735,1.6185765266418457,50000.0,0.5362000465393066,2.260235071182251,10000.0,14314.632142066956,14824.014763116837,14314.632142066956,507.0605084896088,0.8970699310302734,0.0 -41900,1.6981236,3.266427,,,,,,,,,,,,,, -42000,1.5535897,3.174255,,,,,,,,,,,,,, -42100,1.5848391,3.1096106,,,,,,,,,,,,,, -42200,1.6021262,3.2479918,,,,,,,,,,,,,, -42300,1.7071083,3.1638818,,,,,,,,,,,,,, -42400,1.5708824,3.2173784,,,,,,,,,,,,,, -42500,1.5313698,3.167003,,,,,,,,,,,,,, -42600,1.7263427,3.1496787,,,,,,,,,,,,,, -42700,1.5973235,3.1532695,,,,,,,,,,,,,, -42800,1.548867,3.1617057,,,,,,,,,,,,,, -42900,1.6574619,3.2399447,,,,,,,,,,,,,, -43000,1.6362537,3.1698654,,,,,,,,,,,,,, -43100,1.7816961,3.3185394,,,,,,,,,,,,,, -43200,1.5797392,3.1557596,,,,,,,,,,,,,, -43300,1.7044346,3.2033734,,,,,,,,,,,,,, -43304,,,0.7353116869926453,1.2636276483535769,0.6688199639320374,1.5687376260757446,50000.0,0.5380000472068787,2.22286057472229,10000.0,14824.831728935242,15351.65698647499,14824.831728935242,524.4154839515686,0.935309648513794,0.0 -43400,1.5705475,3.2288542,,,,,,,,,,,,,, -43500,1.5817423,3.1299596,,,,,,,,,,,,,, -43600,1.9837029,3.2354965,,,,,,,,,,,,,, -43700,1.4838749,3.1002872,,,,,,,,,,,,,, -43800,1.6397142,3.2028217,,,,,,,,,,,,,, -43900,1.509627,3.1272001,,,,,,,,,,,,,, -44000,1.6837593,3.1962278,,,,,,,,,,,,,, -44100,1.600939,3.1970248,,,,,,,,,,,,,, -44200,1.7260826,3.2858963,,,,,,,,,,,,,, -44300,1.6051638,3.0783396,,,,,,,,,,,,,, -44400,1.5235195,3.1680584,,,,,,,,,,,,,, -44500,1.6682149,3.1842215,,,,,,,,,,,,,, -44600,1.62707,3.1562138,,,,,,,,,,,,,, -44700,1.6371377,3.1503537,,,,,,,,,,,,,, -44799,,,0.7269411683082581,1.2829989194869995,0.6580599546432495,1.5886094570159912,50000.0,0.5301000475883484,2.2401185035705566,10000.0,15334.955714941025,15879.262532949448,15334.955714941025,541.8120038509369,0.9696609973907472,0.0 -44800,1.553147,3.1522222,,,,,,,,,,,,,, -44900,2.0018718,3.15079,,,,,,,,,,,,,, -45000,1.6555239,3.1501136,,,,,,,,,,,,,, -45100,1.8272961,3.1624138,,,,,,,,,,,,,, -45200,1.6893265,3.1562777,,,,,,,,,,,,,, -45300,1.6349581,3.2438006,,,,,,,,,,,,,, -45400,1.7822132,3.1570988,,,,,,,,,,,,,, -45500,1.7621696,3.2385006,,,,,,,,,,,,,, -45600,1.5759532,3.1548083,,,,,,,,,,,,,, -45700,1.5874654,3.1476908,,,,,,,,,,,,,, -45800,1.612381,3.1871092,,,,,,,,,,,,,, -45900,1.7938931,3.2392106,,,,,,,,,,,,,, -46000,1.750406,3.1753345,,,,,,,,,,,,,, -46100,1.6355515,3.1462758,,,,,,,,,,,,,, -46200,1.5989149,3.2085738,,,,,,,,,,,,,, -46294,,,0.7327606678009033,1.280967116355896,0.6660400032997131,1.5779304504394531,50000.0,0.5383000373840332,2.21299695968628,10000.0,15845.028705358503,16406.679964780807,15845.028705358503,559.0715284347534,1.003936529159546,0.0 -46300,1.9025518,3.1005077,,,,,,,,,,,,,, -46400,1.5495695,3.0961256,,,,,,,,,,,,,, -46500,1.6453819,3.1281486,,,,,,,,,,,,,, -46600,1.832796,3.1548722,,,,,,,,,,,,,, -46700,1.8630822,3.1841013,,,,,,,,,,,,,, -46800,1.5902637,3.116073,,,,,,,,,,,,,, -46900,1.6108373,3.1425502,,,,,,,,,,,,,, -47000,1.8433757,3.2220094,,,,,,,,,,,,,, -47100,1.7149285,3.0868735,,,,,,,,,,,,,, -47200,1.7720302,3.16392,,,,,,,,,,,,,, -47300,1.6873131,3.10339,,,,,,,,,,,,,, -47400,1.7350332,3.1486244,,,,,,,,,,,,,, -47500,1.6998168,3.1978562,,,,,,,,,,,,,, -47600,1.731179,3.209177,,,,,,,,,,,,,, -47700,1.6313642,3.1819618,,,,,,,,,,,,,, -47789,,,0.7629344463348389,1.151097536087036,0.6612600088119507,1.5798641443252563,50000.0,0.5379000306129456,2.240457057952881,10000.0,16355.009400129318,16933.876542568207,16355.009400129318,576.2020993232727,1.0376520156860352,0.0 -47800,1.9912832,3.1903362,,,,,,,,,,,,,, -47900,1.7774354,3.1759381,,,,,,,,,,,,,, -48000,1.5945191,3.111467,,,,,,,,,,,,,, -48100,1.8976009,3.1063318,,,,,,,,,,,,,, -48200,1.6645958,3.135174,,,,,,,,,,,,,, -48300,1.7266496,3.1982906,,,,,,,,,,,,,, -48400,1.6249192,3.1286912,,,,,,,,,,,,,, -48500,1.7178035,3.152603,,,,,,,,,,,,,, -48600,1.7492558,3.1972682,,,,,,,,,,,,,, -48700,1.7428056,3.0803103,,,,,,,,,,,,,, -48800,1.8528955,3.1836405,,,,,,,,,,,,,, -48900,1.6376015,3.1414576,,,,,,,,,,,,,, -49000,1.734618,3.1687765,,,,,,,,,,,,,, -49100,1.7466154,3.216021,,,,,,,,,,,,,, -49200,1.6106457,3.1421888,,,,,,,,,,,,,, -49284,,,0.7473692297935486,1.239794373512268,0.6687399744987488,1.597258687019348,50000.0,0.5371000170707703,2.2499301433563232,10000.0,16865.14865398407,17461.489314556122,16865.14865398407,593.5875813961029,1.0741991996765137,0.0 -49300,1.6827797,3.0728269,,,,,,,,,,,,,, -49400,1.8334808,3.1949148,,,,,,,,,,,,,, -49500,1.6484714,3.1616132,,,,,,,,,,,,,, -49600,1.6833038,3.1847608,,,,,,,,,,,,,, -49700,1.6259966,3.2074416,,,,,,,,,,,,,, -49800,1.5932981,3.0698383,,,,,,,,,,,,,, -49900,1.7168036,3.1674821,,,,,,,,,,,,,, -50000,1.767176,3.0863106,,,,,,,,,,,,,, -50100,1.6814264,3.2121508,,,,,,,,,,,,,, -50200,1.7403029,3.1577737,,,,,,,,,,,,,, -50300,1.6861721,3.1775737,,,,,,,,,,,,,, -50400,1.651916,3.024245,,,,,,,,,,,,,, -50500,2.0088959,3.198843,,,,,,,,,,,,,, -50600,1.8050729,3.144535,,,,,,,,,,,,,, -50700,1.6948118,3.1426191,,,,,,,,,,,,,, -50779,,,0.7400350570678711,1.2261372804641724,0.6668199896812439,1.554896354675293,50000.0,0.5362000465393066,2.243440628051758,10000.0,17375.061665296555,17989.708940029144,17375.061665296555,611.8048617839813,1.1133863925933838,0.0 -50800,1.6873256,3.1243174,,,,,,,,,,,,,, -50900,1.7820295,3.1668024,,,,,,,,,,,,,, -51000,1.7964553,3.1192055,,,,,,,,,,,,,, -51100,1.7756172,3.0968316,,,,,,,,,,,,,, -51200,1.8094819,3.152508,,,,,,,,,,,,,, -51300,1.8610005,3.1416745,,,,,,,,,,,,,, -51400,1.6585242,3.1185498,,,,,,,,,,,,,, -51500,1.7639668,3.2073596,,,,,,,,,,,,,, -51600,1.6889327,3.1418736,,,,,,,,,,,,,, -51700,1.8461928,3.1763628,,,,,,,,,,,,,, -51800,1.6854084,3.1879678,,,,,,,,,,,,,, -51900,1.7922093,3.102438,,,,,,,,,,,,,, -52000,1.7064717,3.0831704,,,,,,,,,,,,,, -52100,1.7765073,3.167445,,,,,,,,,,,,,, -52200,1.7253255,3.1818712,,,,,,,,,,,,,, -52275,,,0.7399553656578064,1.244918942451477,0.6628599762916565,1.5738775730133057,50000.0,0.5449000000953674,2.2030839920043945,10000.0,17885.29244852066,18517.295283079147,17885.29244852066,629.0761721134186,1.1479730606079102,0.0 -52300,1.9376388,3.146279,,,,,,,,,,,,,, -52400,1.9225702,3.150193,,,,,,,,,,,,,, -52500,1.8624904,3.1045353,,,,,,,,,,,,,, -52600,1.703412,3.2300544,,,,,,,,,,,,,, -52700,1.7667726,3.056593,,,,,,,,,,,,,, -52800,1.8673112,3.1577406,,,,,,,,,,,,,, -52900,1.7793261,3.1272166,,,,,,,,,,,,,, -53000,1.6651747,3.1354868,,,,,,,,,,,,,, -53100,1.7076473,3.1489804,,,,,,,,,,,,,, -53200,1.7742897,3.112498,,,,,,,,,,,,,, -53300,1.7998407,3.1799753,,,,,,,,,,,,,, -53400,1.863782,3.1766913,,,,,,,,,,,,,, -53500,1.799159,3.0657957,,,,,,,,,,,,,, -53600,1.9444557,3.1512222,,,,,,,,,,,,,, -53700,1.7778518,3.1545744,,,,,,,,,,,,,, -53771,,,0.7357900142669678,1.252927303314209,0.6644200086593628,1.56941819190979,50000.0,0.5424000024795532,2.229390382766724,10000.0,18395.48710179329,19044.75901412964,18395.48710179329,646.256591796875,1.186426877975464,0.0 -53800,1.7995545,3.1167274,,,,,,,,,,,,,, -53900,1.8752229,3.1549468,,,,,,,,,,,,,, -54000,1.7905025,3.1069593,,,,,,,,,,,,,, -54100,1.8677802,3.2169619,,,,,,,,,,,,,, -54200,1.9468353,3.1687553,,,,,,,,,,,,,, -54300,1.7636007,3.1337867,,,,,,,,,,,,,, -54400,2.1185203,3.1681652,,,,,,,,,,,,,, -54500,1.881616,3.2211084,,,,,,,,,,,,,, -54600,1.911178,3.1485627,,,,,,,,,,,,,, -54700,2.1241584,3.1749468,,,,,,,,,,,,,, -54800,1.9331895,3.1567075,,,,,,,,,,,,,, -54900,1.801036,3.1153436,,,,,,,,,,,,,, -55000,1.9161845,3.1646652,,,,,,,,,,,,,, -55100,1.7908466,3.1211731,,,,,,,,,,,,,, -55200,2.1030498,3.1570542,,,,,,,,,,,,,, -55266,,,0.7410913705825806,1.232458233833313,0.6733999848365784,1.5392425060272217,50000.0,0.5421000123023987,2.2132768630981445,10000.0,18905.508989572525,19572.208678483963,18905.508989572525,663.5889291763306,1.2300894260406494,0.0 -55300,2.0520248,3.1570752,,,,,,,,,,,,,, -55400,1.8035626,3.1502128,,,,,,,,,,,,,, -55500,1.8776902,3.1636825,,,,,,,,,,,,,, -55600,1.8228273,3.1347077,,,,,,,,,,,,,, -55700,1.8774579,3.1823602,,,,,,,,,,,,,, -55800,1.8017136,3.1622808,,,,,,,,,,,,,, -55900,1.924385,3.1940923,,,,,,,,,,,,,, -56000,1.9127526,3.196735,,,,,,,,,,,,,, -56100,1.8931397,3.0863664,,,,,,,,,,,,,, -56200,1.7658484,3.1339824,,,,,,,,,,,,,, -56300,1.7372609,3.153772,,,,,,,,,,,,,, -56400,1.859984,3.1098206,,,,,,,,,,,,,, -56500,1.8092049,3.1046584,,,,,,,,,,,,,, -56600,1.8327317,3.0730197,,,,,,,,,,,,,, -56700,1.9575367,3.2066102,,,,,,,,,,,,,, -56761,,,0.7707669138908386,1.1692932844161987,0.6709399819374084,1.5907434225082395,50000.0,0.5403000116348267,2.260906457901001,10000.0,19415.54997587204,20099.63463830948,19415.54997587204,680.8869771957397,1.26725435256958,0.0 -56800,1.9610239,3.1004753,,,,,,,,,,,,,, -56900,1.8773729,3.1235082,,,,,,,,,,,,,, -57000,1.8855531,3.1623676,,,,,,,,,,,,,, -57100,1.830135,3.020269,,,,,,,,,,,,,, -57200,1.8203634,3.0969887,,,,,,,,,,,,,, -57300,1.818526,3.1436722,,,,,,,,,,,,,, -57400,1.8330964,3.0383573,,,,,,,,,,,,,, -57500,1.7366284,3.0941694,,,,,,,,,,,,,, -57600,1.8945938,3.1201947,,,,,,,,,,,,,, -57700,1.8102313,3.1055455,,,,,,,,,,,,,, -57800,1.8901867,3.1694326,,,,,,,,,,,,,, -57900,1.9304721,3.0749683,,,,,,,,,,,,,, -58000,1.8228447,3.1094365,,,,,,,,,,,,,, -58100,1.9213223,3.170196,,,,,,,,,,,,,, -58200,1.9862728,3.0882316,,,,,,,,,,,,,, -58256,,,0.7694514989852905,1.1050457954406738,0.6767799854278564,1.4937835931777954,50000.0,0.5485000014305115,2.1709296703338623,10000.0,19925.609867811203,20627.20070528984,19925.609867811203,698.3031196594238,1.3048710823059082,0.0 -58300,1.8638047,3.0712616,,,,,,,,,,,,,, -58400,2.0046196,3.1634498,,,,,,,,,,,,,, -58500,2.020604,3.0757601,,,,,,,,,,,,,, -58600,1.8344187,3.1282907,,,,,,,,,,,,,, -58700,1.8275683,3.180287,,,,,,,,,,,,,, -58800,1.9803611,3.0682473,,,,,,,,,,,,,, -58900,1.8406703,3.0437646,,,,,,,,,,,,,, -59000,1.916915,3.1156533,,,,,,,,,,,,,, -59100,1.9222962,2.999711,,,,,,,,,,,,,, -59200,1.9015722,3.1595488,,,,,,,,,,,,,, -59300,1.8018913,3.0496035,,,,,,,,,,,,,, -59400,1.8841164,2.9997854,,,,,,,,,,,,,, -59500,1.9939737,3.139377,,,,,,,,,,,,,, -59600,1.8665648,3.0911171,,,,,,,,,,,,,, -59700,1.8738234,3.0742884,,,,,,,,,,,,,, -59752,,,0.7449377775192261,1.1936075687408447,0.666979968547821,1.5506370067596436,50000.0,0.5371000170707703,2.2205097675323486,10000.0,20435.80241703987,21154.84163069725,20435.80241703987,715.6581652164459,1.3460206985473633,0.0 -59800,2.0033529,3.0743341,,,,,,,,,,,,,, -59900,1.996876,3.1071048,,,,,,,,,,,,,, -60000,1.8332299,3.074003,,,,,,,,,,,,,, -60100,1.9112164,3.1134727,,,,,,,,,,,,,, -60200,1.8601524,3.0978837,,,,,,,,,,,,,, -60300,1.9641033,3.172988,,,,,,,,,,,,,, -60400,1.9706224,3.0084608,,,,,,,,,,,,,, -60500,1.8672855,3.1381495,,,,,,,,,,,,,, -60600,1.9345304,3.0950696,,,,,,,,,,,,,, -60700,2.0017042,3.056561,,,,,,,,,,,,,, -60800,2.013974,3.1103354,,,,,,,,,,,,,, -60900,1.9488552,3.102701,,,,,,,,,,,,,, -61000,1.8946924,3.1230662,,,,,,,,,,,,,, -61100,2.0018659,3.144887,,,,,,,,,,,,,, -61200,1.9730988,3.2431264,,,,,,,,,,,,,, -61248,,,0.7509366869926453,1.1979851722717283,0.6717399954795837,1.540035605430603,50000.0,0.5407000184059143,2.2069180011749268,10000.0,20946.0156428814,21682.11145663261,20946.0156428814,732.6220688819885,1.3891007900238037,0.0 -61300,1.9035642,3.117502,,,,,,,,,,,,,, -61400,1.9233931,3.1132903,,,,,,,,,,,,,, -61500,1.8809658,3.0898077,,,,,,,,,,,,,, -61600,2.033181,3.048919,,,,,,,,,,,,,, -61700,2.0155098,3.1127653,,,,,,,,,,,,,, -61800,1.9257904,3.0180469,,,,,,,,,,,,,, -61900,1.9943908,3.1352487,,,,,,,,,,,,,, -62000,2.0774403,3.055262,,,,,,,,,,,,,, -62100,2.073977,3.1266072,,,,,,,,,,,,,, -62200,1.843069,3.0969982,,,,,,,,,,,,,, -62300,1.9679371,3.072121,,,,,,,,,,,,,, -62400,2.1276422,3.1120052,,,,,,,,,,,,,, -62500,2.0508902,3.1587362,,,,,,,,,,,,,, -62600,1.9704581,3.0806453,,,,,,,,,,,,,, -62700,2.0981832,3.1777887,,,,,,,,,,,,,, -62744,,,0.7419483065605164,1.2212562561035156,0.6708199977874756,1.5463603734970093,50000.0,0.5323000550270081,2.2445785999298096,10000.0,21456.26297426224,22209.52569293976,21456.26297426224,749.6996276378632,1.426433801651001,0.0 -62800,2.2183394,3.0869415,,,,,,,,,,,,,, -62900,2.0210416,3.0793977,,,,,,,,,,,,,, -63000,1.9158279,3.0464454,,,,,,,,,,,,,, -63100,1.922656,3.130393,,,,,,,,,,,,,, -63200,1.9689525,3.0245295,,,,,,,,,,,,,, -63300,2.040238,3.2095108,,,,,,,,,,,,,, -63400,2.0428433,3.1329014,,,,,,,,,,,,,, -63500,2.034332,3.1310568,,,,,,,,,,,,,, -63600,1.996697,3.0933268,,,,,,,,,,,,,, -63700,1.9330301,3.1126313,,,,,,,,,,,,,, -63800,1.9472296,3.093877,,,,,,,,,,,,,, -63900,1.9048808,3.0424783,,,,,,,,,,,,,, -64000,2.085194,3.0870895,,,,,,,,,,,,,, -64100,1.9008393,3.042614,,,,,,,,,,,,,, -64200,1.8393781,3.0829308,,,,,,,,,,,,,, -64239,,,0.7436822056770325,1.2639172077178955,0.6767199635505676,1.5742335319519043,50000.0,0.5462000370025635,2.255506992340088,10000.0,21966.35502266884,22737.02538251877,21966.35502266884,767.0168855190277,1.464949369430542,0.0 -64300,1.9437613,3.0813448,,,,,,,,,,,,,, -64400,2.0033875,3.1212952,,,,,,,,,,,,,, -64500,1.9659199,3.082587,,,,,,,,,,,,,, -64600,1.8886846,3.059709,,,,,,,,,,,,,, -64700,1.9168024,3.0804431,,,,,,,,,,,,,, -64800,1.8721966,3.1170502,,,,,,,,,,,,,, -64900,1.9981807,3.0534997,,,,,,,,,,,,,, -65000,2.0182824,3.0470276,,,,,,,,,,,,,, -65100,2.030544,3.0414348,,,,,,,,,,,,,, -65200,1.9324847,3.0301824,,,,,,,,,,,,,, -65300,2.0603068,3.0727465,,,,,,,,,,,,,, -65400,2.0341353,3.1040633,,,,,,,,,,,,,, -65500,1.998132,3.029717,,,,,,,,,,,,,, -65600,2.1123323,3.0770183,,,,,,,,,,,,,, -65700,2.1187365,3.1062546,,,,,,,,,,,,,, -65735,,,0.7479472160339355,1.2383085489273071,0.6757599711418152,1.5526641607284546,50000.0,0.5468000173568726,2.2133982181549072,10000.0,22476.510162115097,23264.367428541183,22476.510162115097,784.1112470626831,1.5077550411224363,0.0 -65800,2.0032425,2.988133,,,,,,,,,,,,,, -65900,1.8568114,3.1732666,,,,,,,,,,,,,, -66000,1.9368098,3.1347394,,,,,,,,,,,,,, -66100,2.0652292,3.1338332,,,,,,,,,,,,,, -66200,2.016111,3.115146,,,,,,,,,,,,,, -66300,2.063629,3.111772,,,,,,,,,,,,,, -66400,2.1047897,3.1498086,,,,,,,,,,,,,, -66500,1.9516749,3.0762522,,,,,,,,,,,,,, -66600,2.2846973,3.065834,,,,,,,,,,,,,, -66700,2.187953,3.1156645,,,,,,,,,,,,,, -66800,1.963155,3.0399287,,,,,,,,,,,,,, -66900,2.1425524,3.1188629,,,,,,,,,,,,,, -67000,2.0733273,3.1005442,,,,,,,,,,,,,, -67100,2.0258064,3.1727276,,,,,,,,,,,,,, -67200,1.9311322,3.05435,,,,,,,,,,,,,, -67230,,,0.7716238498687744,1.069993019104004,0.6693199872970581,1.5192725658416748,50000.0,0.5466000437736511,2.168231725692749,10000.0,22986.69782590866,23791.52751684189,22986.69782590866,800.9966917037964,1.5449728965759275,0.0 -67300,1.961029,3.050912,,,,,,,,,,,,,, -67400,2.1707404,3.1103845,,,,,,,,,,,,,, -67500,1.9839945,3.018424,,,,,,,,,,,,,, -67600,2.266991,3.1192186,,,,,,,,,,,,,, -67700,2.094704,3.074509,,,,,,,,,,,,,, -67800,1.986611,3.0410676,,,,,,,,,,,,,, -67900,1.8825859,2.9844162,,,,,,,,,,,,,, -68000,1.9898909,3.0688024,,,,,,,,,,,,,, -68100,1.9677143,3.1351695,,,,,,,,,,,,,, -68200,2.0139806,3.1122649,,,,,,,,,,,,,, -68300,2.0317676,3.0897336,,,,,,,,,,,,,, -68400,2.0232959,3.0781085,,,,,,,,,,,,,, -68500,1.9880888,3.026189,,,,,,,,,,,,,, -68600,1.9967313,3.02296,,,,,,,,,,,,,, -68700,1.9188337,3.0394583,,,,,,,,,,,,,, -68726,,,0.7636120915412903,1.1651519536972046,0.6775799989700317,1.5437074899673462,50000.0,0.5540000200271606,2.1980412006378174,10000.0,23496.80472302437,24318.850444078445,23496.80472302437,818.1228411197662,1.5834143161773682,0.0 -68800,2.3431847,3.1054559,,,,,,,,,,,,,, -68900,2.053014,3.0149655,,,,,,,,,,,,,, -69000,2.1629918,3.083018,,,,,,,,,,,,,, -69100,2.1110787,3.053471,,,,,,,,,,,,,, -69200,2.0497007,3.098174,,,,,,,,,,,,,, -69300,1.9647903,3.0315595,,,,,,,,,,,,,, -69400,2.0850894,3.0515742,,,,,,,,,,,,,, -69500,2.2648914,3.142242,,,,,,,,,,,,,, -69600,2.0389671,3.0391133,,,,,,,,,,,,,, -69700,2.0957682,3.0449586,,,,,,,,,,,,,, -69800,2.0785203,3.0006707,,,,,,,,,,,,,, -69900,2.219808,3.0842793,,,,,,,,,,,,,, -70000,2.0287087,3.109344,,,,,,,,,,,,,, -70100,2.1449425,3.105905,,,,,,,,,,,,,, -70200,2.253665,3.0879223,,,,,,,,,,,,,, -70221,,,0.7656847834587097,1.155431866645813,0.68367999792099,1.5203362703323364,50000.0,0.5615000128746033,2.142663717269897,10000.0,24006.80624294281,24846.14938569069,24006.80624294281,835.3253185749054,1.6262106895446775,0.0 -70300,2.076046,3.0597498,,,,,,,,,,,,,, -70400,2.0004601,3.0734446,,,,,,,,,,,,,, -70500,2.168124,3.0261984,,,,,,,,,,,,,, -70600,2.0048282,3.0517604,,,,,,,,,,,,,, -70700,2.0459108,3.1334605,,,,,,,,,,,,,, -70800,2.0709262,3.1196504,,,,,,,,,,,,,, -70900,2.023125,3.0450482,,,,,,,,,,,,,, -71000,2.0511777,3.1199274,,,,,,,,,,,,,, -71100,2.1724238,3.019641,,,,,,,,,,,,,, -71200,2.0347788,3.0708358,,,,,,,,,,,,,, -71300,2.2995179,3.0699422,,,,,,,,,,,,,, -71400,2.1150005,3.0375562,,,,,,,,,,,,,, -71500,2.1273003,3.0486035,,,,,,,,,,,,,, -71600,2.0697274,3.0438972,,,,,,,,,,,,,, -71700,1.999799,3.0833695,,,,,,,,,,,,,, -71717,,,0.770527720451355,1.1102607250213623,0.6875999569892883,1.4715756177902222,50000.0,0.55840003490448,2.1221847534179688,10000.0,24516.862579345703,25373.404136896133,24516.862579345703,852.434784412384,1.664741277694702,0.0 -71800,2.0973392,3.0056345,,,,,,,,,,,,,, -71900,2.040963,2.9855282,,,,,,,,,,,,,, -72000,2.1941385,3.0765455,,,,,,,,,,,,,, -72100,1.9700677,3.048451,,,,,,,,,,,,,, -72200,2.0846138,3.013961,,,,,,,,,,,,,, -72300,2.2354648,3.1287653,,,,,,,,,,,,,, -72400,2.1878066,3.1284463,,,,,,,,,,,,,, -72500,2.0473423,3.0932984,,,,,,,,,,,,,, -72600,1.9975652,3.0474434,,,,,,,,,,,,,, -72700,2.1778274,3.1036706,,,,,,,,,,,,,, -72800,2.170129,3.0371878,,,,,,,,,,,,,, -72900,2.078187,3.074873,,,,,,,,,,,,,, -73000,2.1614654,3.0445619,,,,,,,,,,,,,, -73100,2.1577086,3.1347976,,,,,,,,,,,,,, -73200,2.2839034,3.122358,,,,,,,,,,,,,, -73212,,,0.7600645422935486,1.1490471363067627,0.6841599941253662,1.4841029644012451,50000.0,0.5588000416755676,2.1481716632843018,10000.0,25026.77587389946,25900.518624067307,25026.77587389946,869.5368921756744,1.7121210098266602,0.0 -73300,2.3424096,3.097276,,,,,,,,,,,,,, -73400,2.1048458,3.0014417,,,,,,,,,,,,,, -73500,2.169986,3.0957568,,,,,,,,,,,,,, -73600,2.1954799,2.9625633,,,,,,,,,,,,,, -73700,2.0381851,3.0192065,,,,,,,,,,,,,, -73800,1.9248012,3.0317779,,,,,,,,,,,,,, -73900,2.1685674,3.0684516,,,,,,,,,,,,,, -74000,2.505514,3.0288272,,,,,,,,,,,,,, -74100,2.1109874,3.006556,,,,,,,,,,,,,, -74200,2.039469,2.9477365,,,,,,,,,,,,,, -74300,2.1743188,3.0496724,,,,,,,,,,,,,, -74400,2.2065682,3.0076225,,,,,,,,,,,,,, -74500,2.174257,3.0372763,,,,,,,,,,,,,, -74600,2.1175244,3.0836308,,,,,,,,,,,,,, -74700,2.1757095,3.0352092,,,,,,,,,,,,,, -74708,,,0.7657246589660645,1.1162415742874146,0.6844599843025208,1.458303928375244,50000.0,0.5622000098228455,2.1055116653442383,10000.0,25536.89725470543,26427.82671189308,25536.89725470543,886.6352317333221,1.7510671615600586,0.0 -74800,2.0709713,3.0874095,,,,,,,,,,,,,, -74900,2.23244,3.0093143,,,,,,,,,,,,,, -75000,2.1881495,3.0117908,,,,,,,,,,,,,, -75100,2.229868,3.0577693,,,,,,,,,,,,,, -75200,2.208947,3.058206,,,,,,,,,,,,,, -75300,2.215353,3.0580828,,,,,,,,,,,,,, -75400,2.0055807,3.0429158,,,,,,,,,,,,,, -75500,2.0793335,3.091191,,,,,,,,,,,,,, -75600,2.235689,3.1525922,,,,,,,,,,,,,, -75700,2.305245,3.0640345,,,,,,,,,,,,,, -75800,2.2165685,3.067512,,,,,,,,,,,,,, -75900,2.0376196,3.060209,,,,,,,,,,,,,, -76000,2.1227741,3.056083,,,,,,,,,,,,,, -76100,2.0138295,3.0136669,,,,,,,,,,,,,, -76200,1.9794755,3.025141,,,,,,,,,,,,,, -76203,,,0.8078364133834839,1.0016721487045288,0.6889399886131287,1.5002516508102417,50000.0,0.5706000328063965,2.115853786468506,10000.0,26046.908839941025,26954.93002486229,26046.908839941025,903.6346333026886,1.7926886081695557,0.0 -76300,2.0439272,3.0217476,,,,,,,,,,,,,, -76400,2.1295493,3.0804615,,,,,,,,,,,,,, -76500,2.1344879,2.992643,,,,,,,,,,,,,, -76600,2.2201195,3.037327,,,,,,,,,,,,,, -76700,2.2566936,3.08255,,,,,,,,,,,,,, -76800,1.9582933,3.0062,,,,,,,,,,,,,, -76900,2.316379,3.0780401,,,,,,,,,,,,,, -77000,2.1867752,3.0350533,,,,,,,,,,,,,, -77100,2.157163,3.0712795,,,,,,,,,,,,,, -77200,2.2618341,3.0962515,,,,,,,,,,,,,, -77300,2.53176,3.0432062,,,,,,,,,,,,,, -77400,2.1827,3.0786543,,,,,,,,,,,,,, -77500,2.1355727,3.0010803,,,,,,,,,,,,,, -77600,2.056896,3.013748,,,,,,,,,,,,,, -77699,,,0.786551296710968,1.0268144607543943,0.6879599690437317,1.454639196395874,50000.0,0.5588000416755676,2.096848964691162,10000.0,26557.14165997505,27482.33225798607,26557.14165997505,920.7117989063264,1.8335063457489007,0.0 -77700,2.272251,3.0902987,,,,,,,,,,,,,, -77800,2.2762337,3.057695,,,,,,,,,,,,,, -77900,2.118301,3.0554745,,,,,,,,,,,,,, -78000,2.1671226,3.0678742,,,,,,,,,,,,,, -78100,2.1324883,3.004603,,,,,,,,,,,,,, -78200,2.159903,2.9855723,,,,,,,,,,,,,, -78300,2.0548408,2.9804616,,,,,,,,,,,,,, -78400,2.2754648,3.053876,,,,,,,,,,,,,, -78500,2.2134855,3.0754123,,,,,,,,,,,,,, -78600,2.2461774,3.0416455,,,,,,,,,,,,,, -78700,2.1361935,2.939993,,,,,,,,,,,,,, -78800,2.257727,3.0381637,,,,,,,,,,,,,, -78900,2.410056,3.037433,,,,,,,,,,,,,, -79000,2.1960304,3.1036,,,,,,,,,,,,,, -79100,2.3981717,3.0341046,,,,,,,,,,,,,, -79194,,,0.7811902165412903,1.058087706565857,0.6916999816894531,1.4431703090667725,50000.0,0.5654000043869019,2.102102756500244,10000.0,27067.081208705906,28009.725796461105,27067.081208705906,938.072808265686,1.8744986057281487,0.0 -79200,2.2563157,3.109568,,,,,,,,,,,,,, -79300,2.336558,3.0626123,,,,,,,,,,,,,, -79400,2.247033,2.99467,,,,,,,,,,,,,, -79500,2.2262228,3.0164104,,,,,,,,,,,,,, -79600,2.1586366,3.0112758,,,,,,,,,,,,,, -79700,2.0832195,2.967262,,,,,,,,,,,,,, -79800,2.1897147,2.9960613,,,,,,,,,,,,,, -79900,2.064491,2.9516425,,,,,,,,,,,,,, -80000,2.361308,3.062065,,,,,,,,,,,,,, -80100,2.220877,3.1028466,,,,,,,,,,,,,, -80200,2.139377,3.011768,,,,,,,,,,,,,, -80300,2.3496535,3.0583017,,,,,,,,,,,,,, -80400,2.1858883,2.9739146,,,,,,,,,,,,,, -80500,2.1960943,3.0619226,,,,,,,,,,,,,, -80600,2.179563,3.0310931,,,,,,,,,,,,,, -80690,,,0.7735171914100647,1.107993483543396,0.6871599555015564,1.4867323637008667,50000.0,0.5612000226974487,2.126382827758789,10000.0,27577.212859630585,28537.874395132065,27577.212859630585,955.9951276779176,1.9179785251617432,0.0 -80700,2.22211,2.9413652,,,,,,,,,,,,,, -80800,2.2979238,2.9419749,,,,,,,,,,,,,, -80900,2.1724174,2.994793,,,,,,,,,,,,,, -81000,2.2780988,3.07559,,,,,,,,,,,,,, -81100,2.2146895,2.96349,,,,,,,,,,,,,, -81200,2.0868936,2.9469912,,,,,,,,,,,,,, -81300,2.2784202,3.0482035,,,,,,,,,,,,,, -81400,2.240457,3.060297,,,,,,,,,,,,,, -81500,2.3801594,3.1011903,,,,,,,,,,,,,, -81600,2.1089523,3.009703,,,,,,,,,,,,,, -81700,2.3379176,3.0166342,,,,,,,,,,,,,, -81800,2.378489,3.081335,,,,,,,,,,,,,, -81900,2.2649481,3.070281,,,,,,,,,,,,,, -82000,2.314021,2.9819338,,,,,,,,,,,,,, -82100,2.3017097,2.9971797,,,,,,,,,,,,,, -82186,,,0.7631337642669678,1.1587564945220947,0.6825199723243713,1.5233628749847412,50000.0,0.5517000555992126,2.1907894611358643,10000.0,28087.34799814224,29065.30001354217,28087.34799814224,973.1926457881927,1.9603497982025144,0.0 -82200,2.4542797,3.0220602,,,,,,,,,,,,,, -82300,2.1951687,2.9958968,,,,,,,,,,,,,, -82400,2.148292,2.9962568,,,,,,,,,,,,,, -82500,2.239259,3.040851,,,,,,,,,,,,,, -82600,2.2734919,3.005203,,,,,,,,,,,,,, -82700,2.2612162,3.0501695,,,,,,,,,,,,,, -82800,2.19354,2.9551063,,,,,,,,,,,,,, -82900,2.2107477,3.052335,,,,,,,,,,,,,, -83000,2.2448235,3.0331793,,,,,,,,,,,,,, -83100,2.1521592,3.0067897,,,,,,,,,,,,,, -83200,2.134927,2.9198523,,,,,,,,,,,,,, -83300,2.431199,3.078742,,,,,,,,,,,,,, -83400,2.4162462,3.0206847,,,,,,,,,,,,,, -83500,2.354989,2.9986389,,,,,,,,,,,,,, -83600,2.420276,3.015593,,,,,,,,,,,,,, -83680,,,0.7662228941917419,1.1387059688568115,0.6846799850463867,1.4886494874954224,50000.0,0.5542000532150269,2.152329683303833,10000.0,28597.258448839188,29592.47763466835,28597.258448839188,990.360454082489,2.00658917427063,0.0 -83700,2.2413855,3.1118476,,,,,,,,,,,,,, -83800,2.3226588,3.1078587,,,,,,,,,,,,,, -83900,2.320249,3.010801,,,,,,,,,,,,,, -84000,2.3074143,3.0560787,,,,,,,,,,,,,, -84100,2.179559,3.1010623,,,,,,,,,,,,,, -84200,2.411667,3.023601,,,,,,,,,,,,,, -84300,2.332265,2.9820404,,,,,,,,,,,,,, -84400,2.3409703,3.0396786,,,,,,,,,,,,,, -84500,2.3121843,3.0219777,,,,,,,,,,,,,, -84600,2.4530463,3.0029073,,,,,,,,,,,,,, -84700,2.3284051,3.0455034,,,,,,,,,,,,,, -84800,2.3325136,2.9967594,,,,,,,,,,,,,, -84900,2.2562404,3.0263608,,,,,,,,,,,,,, -85000,2.255089,3.0248344,,,,,,,,,,,,,, -85100,2.3560147,3.071263,,,,,,,,,,,,,, -85176,,,0.7777224183082581,1.1008814573287964,0.6924399733543396,1.4702638387680054,50000.0,0.5660000443458557,2.104079246520996,10000.0,29107.465238809586,30120.14112353325,29107.465238809586,1007.7248740196228,2.047886610031128,0.0 -85200,2.243152,2.9985132,,,,,,,,,,,,,, -85300,2.3135076,3.010852,,,,,,,,,,,,,, -85400,2.3476613,2.977546,,,,,,,,,,,,,, -85500,2.2394328,2.9386778,,,,,,,,,,,,,, -85600,2.2145011,2.9255543,,,,,,,,,,,,,, -85700,2.222967,2.9747202,,,,,,,,,,,,,, -85800,2.2080002,3.023815,,,,,,,,,,,,,, -85900,2.2987993,3.014194,,,,,,,,,,,,,, -86000,2.2981071,2.9920688,,,,,,,,,,,,,, -86100,2.2015452,3.0296814,,,,,,,,,,,,,, -86200,2.3685496,3.0566454,,,,,,,,,,,,,, -86300,2.3383014,2.955796,,,,,,,,,,,,,, -86400,2.2444465,2.9892917,,,,,,,,,,,,,, -86500,2.1225502,2.9573946,,,,,,,,,,,,,, -86600,2.5425172,3.0172462,,,,,,,,,,,,,, -86672,,,0.7954002022743225,1.0093271732330322,0.6930800080299377,1.4499201774597168,50000.0,0.5685000419616699,2.0924246311187744,10000.0,29617.56979894638,30647.5791118145,29617.56979894638,1024.9641880989077,2.0917553901672363,0.0 -86700,2.422392,3.0881662,,,,,,,,,,,,,, -86800,2.1801927,2.9430733,,,,,,,,,,,,,, -86900,2.2242446,3.002358,,,,,,,,,,,,,, -87000,2.249415,2.9930973,,,,,,,,,,,,,, -87100,2.5508797,2.9867382,,,,,,,,,,,,,, -87200,2.3786821,2.953474,,,,,,,,,,,,,, -87300,2.526879,2.9318964,,,,,,,,,,,,,, -87400,2.3367405,3.0442872,,,,,,,,,,,,,, -87500,2.301753,2.9572685,,,,,,,,,,,,,, -87600,2.303093,3.0060112,,,,,,,,,,,,,, -87700,2.5217435,3.0000937,,,,,,,,,,,,,, -87800,2.2507184,2.9465814,,,,,,,,,,,,,, -87900,2.2440166,2.935376,,,,,,,,,,,,,, -88000,2.3519666,2.9135075,,,,,,,,,,,,,, -88100,2.327366,2.9511166,,,,,,,,,,,,,, -88167,,,0.7859733700752258,1.0578300952911377,0.693399965763092,1.460898995399475,50000.0,0.5648000240325928,2.1210591793060303,10000.0,30127.56591463089,31175.094157218933,30127.56591463089,1042.3706114292145,2.15392541885376,0.0 -88200,2.4266803,3.0015535,,,,,,,,,,,,,, -88300,2.3337882,2.9773996,,,,,,,,,,,,,, -88400,2.1897242,3.020004,,,,,,,,,,,,,, -88500,2.3779702,3.0041025,,,,,,,,,,,,,, -88600,2.3771513,2.9912977,,,,,,,,,,,,,, -88700,2.3683598,2.9808702,,,,,,,,,,,,,, -88800,2.2657511,2.8708973,,,,,,,,,,,,,, -88900,2.2473624,3.0313156,,,,,,,,,,,,,, -89000,2.4852376,2.9789493,,,,,,,,,,,,,, -89100,2.343983,2.943527,,,,,,,,,,,,,, -89200,2.2197473,2.979932,,,,,,,,,,,,,, -89300,2.4122322,2.9988124,,,,,,,,,,,,,, -89400,2.190628,2.8839977,,,,,,,,,,,,,, -89500,2.165442,2.9292946,,,,,,,,,,,,,, -89600,2.3064902,2.9690244,,,,,,,,,,,,,, -89662,,,0.7878866195678711,1.08149516582489,0.6979999542236328,1.4614171981811523,50000.0,0.5674000382423401,2.1174557209014893,10000.0,30637.575980186462,31702.41963338852,30637.575980186462,1059.5920572280884,2.1985137462615967,0.0 -89700,2.4175055,2.937295,,,,,,,,,,,,,, -89800,2.3074944,2.9547772,,,,,,,,,,,,,, -89900,2.5521057,3.0435934,,,,,,,,,,,,,, -90000,2.2608685,2.9474137,,,,,,,,,,,,,, -90100,2.303923,3.0385897,,,,,,,,,,,,,, -90200,2.4159172,2.98441,,,,,,,,,,,,,, -90300,2.461082,2.9894485,,,,,,,,,,,,,, -90400,2.5591428,2.9771767,,,,,,,,,,,,,, -90500,2.3158853,2.9378567,,,,,,,,,,,,,, -90600,2.2963243,2.942484,,,,,,,,,,,,,, -90700,2.3764088,3.0755644,,,,,,,,,,,,,, -90800,2.558398,3.0011365,,,,,,,,,,,,,, -90900,2.3230433,2.9592426,,,,,,,,,,,,,, -91000,2.6043437,3.058033,,,,,,,,,,,,,, -91100,2.263196,2.9825683,,,,,,,,,,,,,, -91158,,,0.7898397445678711,1.0304040908813477,0.6977599859237671,1.4194663763046265,50000.0,0.5745000243186951,2.0515968799591064,10000.0,31147.80820083618,32229.96341466904,31147.80820083618,1076.8009662628174,2.2499001026153564,0.0 -91200,2.4196732,2.9514894,,,,,,,,,,,,,, -91300,2.345998,2.9618716,,,,,,,,,,,,,, -91400,2.4882941,2.9813685,,,,,,,,,,,,,, -91500,2.2226496,2.8992379,,,,,,,,,,,,,, -91600,2.3311543,2.976891,,,,,,,,,,,,,, -91700,2.5898452,2.963969,,,,,,,,,,,,,, -91800,2.4116874,3.0329602,,,,,,,,,,,,,, -91900,2.3294036,2.9532614,,,,,,,,,,,,,, -92000,2.3462055,2.993045,,,,,,,,,,,,,, -92100,2.581357,3.0694332,,,,,,,,,,,,,, -92200,2.2879002,2.9678755,,,,,,,,,,,,,, -92300,2.3997376,2.9016216,,,,,,,,,,,,,, -92400,2.2925189,2.9549599,,,,,,,,,,,,,, -92500,2.312805,2.962247,,,,,,,,,,,,,, -92600,2.3888774,2.9853358,,,,,,,,,,,,,, -92654,,,0.7855349183082581,1.0370293855667114,0.6988399624824524,1.4155194759368896,50000.0,0.5738000273704529,2.067585229873657,10000.0,31657.907905578613,32757.361146211624,31657.907905578613,1094.0021858215332,2.295071840286255,0.0 -92700,2.4000936,2.9291184,,,,,,,,,,,,,, -92800,2.3494105,2.9976754,,,,,,,,,,,,,, -92900,2.3778927,2.9559016,,,,,,,,,,,,,, -93000,2.4374568,3.0038838,,,,,,,,,,,,,, -93100,2.3717005,3.0287757,,,,,,,,,,,,,, -93200,2.4466994,2.938799,,,,,,,,,,,,,, -93300,2.2398372,2.9589846,,,,,,,,,,,,,, -93400,2.2857182,2.8839433,,,,,,,,,,,,,, -93500,2.4803884,2.9790528,,,,,,,,,,,,,, -93600,2.347861,2.9496586,,,,,,,,,,,,,, -93700,2.545062,2.9929748,,,,,,,,,,,,,, -93800,2.44815,3.0387223,,,,,,,,,,,,,, -93900,2.4859908,2.9272387,,,,,,,,,,,,,, -94000,2.5372152,3.005538,,,,,,,,,,,,,, -94100,2.3226755,2.9556584,,,,,,,,,,,,,, -94150,,,0.7915935516357422,1.011091709136963,0.7041199803352356,1.3878302574157717,50000.0,0.5834000110626221,2.04312801361084,10000.0,32168.14310240745,33284.76044559479,32168.14310240745,1111.0696558952332,2.34027099609375,0.0 -94200,2.4272199,3.0404534,,,,,,,,,,,,,, -94300,2.4800456,3.0232635,,,,,,,,,,,,,, -94400,2.1767771,2.8812141,,,,,,,,,,,,,, -94500,2.501276,3.0194855,,,,,,,,,,,,,, -94600,2.438298,2.9623585,,,,,,,,,,,,,, -94700,2.4869325,2.928365,,,,,,,,,,,,,, -94800,2.571926,2.9719896,,,,,,,,,,,,,, -94900,2.6333306,3.0132322,,,,,,,,,,,,,, -95000,2.3478367,2.9743204,,,,,,,,,,,,,, -95100,2.6167114,2.9155118,,,,,,,,,,,,,, -95200,2.7908573,2.9437766,,,,,,,,,,,,,, -95300,2.5960178,2.9205458,,,,,,,,,,,,,, -95400,2.3481388,2.9620633,,,,,,,,,,,,,, -95500,2.490853,3.0452983,,,,,,,,,,,,,, -95600,2.3498127,2.9592202,,,,,,,,,,,,,, -95646,,,0.8151108026504517,0.9238508343696594,0.7001399993896484,1.4183787107467651,50000.0,0.570900022983551,2.061049699783325,10000.0,32678.294413089752,33812.28715801239,32678.294413089752,1128.321323633194,2.410131454467773,0.0 -95700,2.3392298,2.9419312,,,,,,,,,,,,,, -95800,2.47942,2.9318666,,,,,,,,,,,,,, -95900,2.5455227,3.006628,,,,,,,,,,,,,, -96000,2.6746657,2.9810443,,,,,,,,,,,,,, -96100,2.5109177,2.893846,,,,,,,,,,,,,, -96200,2.6348796,2.9664078,,,,,,,,,,,,,, -96300,2.4389224,2.9768457,,,,,,,,,,,,,, -96400,2.2876132,2.8921185,,,,,,,,,,,,,, -96500,2.469266,2.9270942,,,,,,,,,,,,,, -96600,2.6955407,3.0024898,,,,,,,,,,,,,, -96700,2.3983514,2.9334927,,,,,,,,,,,,,, -96800,2.5050716,2.952136,,,,,,,,,,,,,, -96900,2.3765574,2.9613252,,,,,,,,,,,,,, -97000,2.4323652,2.9305844,,,,,,,,,,,,,, -97100,2.519895,2.9804034,,,,,,,,,,,,,, -97141,,,0.8090322017669678,0.9489533305168152,0.7054199576377869,1.388757824897766,50000.0,0.5785000324249268,2.047945022583008,10000.0,33188.26053261757,34339.3632774353,33188.26053261757,1145.3379135131836,2.4528286457061768,0.0 -97200,2.589512,3.0217607,,,,,,,,,,,,,, -97300,2.3907964,2.9418907,,,,,,,,,,,,,, -97400,2.7276263,2.9514322,,,,,,,,,,,,,, -97500,2.4772367,2.877186,,,,,,,,,,,,,, -97600,2.324715,2.8851852,,,,,,,,,,,,,, -97700,2.415495,2.8838906,,,,,,,,,,,,,, -97800,2.5977852,3.0232453,,,,,,,,,,,,,, -97900,2.3164635,2.939127,,,,,,,,,,,,,, -98000,2.5087562,2.9639654,,,,,,,,,,,,,, -98100,2.4382946,2.9535172,,,,,,,,,,,,,, -98200,2.4338589,2.8934395,,,,,,,,,,,,,, -98300,2.3439152,2.906188,,,,,,,,,,,,,, -98400,2.5411894,3.021315,,,,,,,,,,,,,, -98500,2.3804996,2.934136,,,,,,,,,,,,,, -98600,2.4402957,2.9597282,,,,,,,,,,,,,, -98637,,,0.8087332248687744,0.9831731915473938,0.7090199589729309,1.404129147529602,50000.0,0.5819000005722046,2.044516563415528,10000.0,33698.45007276535,34867.06263709068,33698.45007276535,1162.741782665253,2.5040063858032227,0.0 -98700,2.6240928,2.9989693,,,,,,,,,,,,,, -98800,2.6109414,2.9687443,,,,,,,,,,,,,, -98900,2.459713,2.9268565,,,,,,,,,,,,,, -99000,2.5520957,3.0032747,,,,,,,,,,,,,, -99100,2.4250793,2.858971,,,,,,,,,,,,,, -99200,2.585123,3.0078638,,,,,,,,,,,,,, -99300,2.6191254,2.9528666,,,,,,,,,,,,,, -99400,2.488697,2.9046228,,,,,,,,,,,,,, -99500,2.3639438,2.9732926,,,,,,,,,,,,,, -99600,2.57101,2.951905,,,,,,,,,,,,,, -99700,2.521826,2.9101098,,,,,,,,,,,,,, -99800,2.4453738,2.8503149,,,,,,,,,,,,,, -99900,2.598954,3.00557,,,,,,,,,,,,,, -100000,2.3823397,2.8820944,,,,,,,,,,,,,, -100100,2.4865856,2.9458728,,,,,,,,,,,,,, -100132,,,0.7820471525192261,1.040468454360962,0.687720000743866,1.4538909196853638,50000.0,0.5609000325202942,2.109252691268921,10000.0,34208.44733428955,35394.59343075752,34208.44733428955,1180.1750495433807,2.553105354309082,0.0 -100200,2.404794,2.8412259,,,,,,,,,,,,,, -100300,2.6141667,3.025289,,,,,,,,,,,,,, -100400,2.434975,2.9199843,,,,,,,,,,,,,, -100500,2.548085,2.9544592,,,,,,,,,,,,,, -100600,2.5132153,2.8899484,,,,,,,,,,,,,, -100700,2.6411269,2.9932528,,,,,,,,,,,,,, -100800,2.5529416,2.8695052,,,,,,,,,,,,,, -100900,2.669415,2.9622312,,,,,,,,,,,,,, -101000,2.572135,3.0160952,,,,,,,,,,,,,, -101100,2.6282122,2.910373,,,,,,,,,,,,,, -101200,2.4459248,2.9132533,,,,,,,,,,,,,, -101300,2.6848848,2.8900018,,,,,,,,,,,,,, -101400,2.5300357,2.9348557,,,,,,,,,,,,,, -101500,2.5254955,2.9032886,,,,,,,,,,,,,, -101600,2.5379405,2.9257636,,,,,,,,,,,,,, -101628,,,0.8019172549247742,0.9545334577560424,0.7048199772834778,1.3737319707870483,50000.0,0.5788000226020813,2.02254056930542,10000.0,34718.620924949646,35922.155831575394,34718.620924949646,1197.4660267829895,2.60019588470459,0.0 -101700,2.6888099,2.9128873,,,,,,,,,,,,,, -101800,2.5176861,2.9065957,,,,,,,,,,,,,, -101900,2.703166,3.0164902,,,,,,,,,,,,,, -102000,2.6471612,2.9520257,,,,,,,,,,,,,, -102100,2.6147482,2.9165647,,,,,,,,,,,,,, -102200,2.3602357,2.8835776,,,,,,,,,,,,,, -102300,2.5534236,2.8576682,,,,,,,,,,,,,, -102400,2.62558,2.9150317,,,,,,,,,,,,,, -102500,2.5977652,2.9436738,,,,,,,,,,,,,, -102600,2.5152402,2.9770296,,,,,,,,,,,,,, -102700,2.421562,2.8479362,,,,,,,,,,,,,, -102800,2.5653255,2.8946579,,,,,,,,,,,,,, -102900,2.660785,2.990483,,,,,,,,,,,,,, -103000,2.7031734,2.9943728,,,,,,,,,,,,,, -103100,2.7329617,2.9154902,,,,,,,,,,,,,, -103124,,,0.8040497303009033,0.9926905035972596,0.7096999883651733,1.3941349983215332,50000.0,0.5879000425338745,2.022136926651001,10000.0,35228.76759457588,36449.7290225029,35228.76759457588,1214.7932357788086,2.648162364959717,0.0 -103200,2.661185,2.9455276,,,,,,,,,,,,,, -103300,2.6039042,2.9816227,,,,,,,,,,,,,, -103400,2.5678737,2.9132383,,,,,,,,,,,,,, -103500,2.534915,2.902996,,,,,,,,,,,,,, -103600,2.779865,3.0063248,,,,,,,,,,,,,, -103700,2.4357462,2.8947449,,,,,,,,,,,,,, -103800,2.590459,2.9053323,,,,,,,,,,,,,, -103900,2.633137,2.8766234,,,,,,,,,,,,,, -104000,2.6844854,2.9117713,,,,,,,,,,,,,, -104100,2.5398066,2.925499,,,,,,,,,,,,,, -104200,2.6653364,2.8844988,,,,,,,,,,,,,, -104300,2.5342946,2.884862,,,,,,,,,,,,,, -104400,2.8946187,2.9125943,,,,,,,,,,,,,, -104500,2.5732632,2.9622822,,,,,,,,,,,,,, -104600,2.6235073,2.9516602,,,,,,,,,,,,,, -104620,,,0.8269292116165161,0.9147905707359314,0.7056399583816528,1.4107757806777954,50000.0,0.5788000226020813,2.058584213256836,10000.0,35738.865753889084,36977.66103959084,35738.865753889084,1232.5322148799896,2.693121671676636,0.0 -104700,2.6414533,2.9481473,,,,,,,,,,,,,, -104800,2.604721,2.9492526,,,,,,,,,,,,,, -104900,2.7171283,2.9043953,,,,,,,,,,,,,, -105000,2.736629,2.889416,,,,,,,,,,,,,, -105100,2.6688313,2.918801,,,,,,,,,,,,,, -105200,2.5238047,2.8419232,,,,,,,,,,,,,, -105300,2.446338,2.8873088,,,,,,,,,,,,,, -105400,2.7347052,2.9446564,,,,,,,,,,,,,, -105500,2.8255885,2.9368916,,,,,,,,,,,,,, -105600,2.361961,2.8599238,,,,,,,,,,,,,, -105700,2.5821302,2.9879632,,,,,,,,,,,,,, -105800,2.7746058,2.9106836,,,,,,,,,,,,,, -105900,2.7308228,2.8984413,,,,,,,,,,,,,, -106000,2.531827,2.8829296,,,,,,,,,,,,,, -106100,2.4889226,2.9341192,,,,,,,,,,,,,, -106116,,,0.8301976919174194,0.8830342292785645,0.711359977722168,1.3829030990600586,50000.0,0.5855000019073486,2.002440929412842,10000.0,36249.07039260864,37505.13443374634,36249.07039260864,1249.7009994983673,2.7423999309539795,0.0 -106200,2.6353676,2.9757745,,,,,,,,,,,,,, -106300,2.6757886,2.9466403,,,,,,,,,,,,,, -106400,2.6919703,2.9635856,,,,,,,,,,,,,, -106500,2.6671932,2.9137483,,,,,,,,,,,,,, -106600,2.6905272,2.953084,,,,,,,,,,,,,, -106700,2.6373684,2.8574276,,,,,,,,,,,,,, -106800,2.63089,2.9473662,,,,,,,,,,,,,, -106900,2.5673401,2.8300278,,,,,,,,,,,,,, -107000,2.7961345,2.9627893,,,,,,,,,,,,,, -107100,2.7672157,2.8998358,,,,,,,,,,,,,, -107200,2.5228155,2.86791,,,,,,,,,,,,,, -107300,2.6571267,2.947937,,,,,,,,,,,,,, -107400,2.5771399,2.876624,,,,,,,,,,,,,, -107500,2.4786184,2.881729,,,,,,,,,,,,,, -107600,2.8013914,2.9779096,,,,,,,,,,,,,, -107612,,,0.8184390664100647,0.9219950437545776,0.7115199565887451,1.380744695663452,50000.0,0.5803000330924988,2.026761054992676,10000.0,36759.14828467369,38032.58826184273,36759.14828467369,1266.9804100990295,2.787186861038208,0.0 -107700,2.6152713,2.844965,,,,,,,,,,,,,, -107800,2.558501,2.8835917,,,,,,,,,,,,,, -107900,2.53365,2.8829737,,,,,,,,,,,,,, -108000,2.8094096,2.9860208,,,,,,,,,,,,,, -108100,2.5216024,2.841419,,,,,,,,,,,,,, -108200,2.8608994,2.9299936,,,,,,,,,,,,,, -108300,2.8144817,2.9024305,,,,,,,,,,,,,, -108400,2.8122396,2.9140143,,,,,,,,,,,,,, -108500,2.9267642,2.953333,,,,,,,,,,,,,, -108600,2.8051856,2.8567493,,,,,,,,,,,,,, -108700,2.7219405,2.8875442,,,,,,,,,,,,,, -108800,2.69576,2.9086206,,,,,,,,,,,,,, -108900,2.79921,2.9408891,,,,,,,,,,,,,, -109000,2.7611585,2.8920882,,,,,,,,,,,,,, -109100,2.7010071,2.8712683,,,,,,,,,,,,,, -109107,,,0.8204121589660645,0.8763977885246277,0.7163999676704407,1.3261253833770752,50000.0,0.5889000296592712,1.9659696817398071,10000.0,37269.2513358593,38559.98116540909,37269.2513358593,1284.1733181476593,2.833984613418579,0.0 -109200,2.6533008,2.8468113,,,,,,,,,,,,,, -109300,2.6804934,2.8223038,,,,,,,,,,,,,, -109400,2.7651377,2.8510628,,,,,,,,,,,,,, -109500,2.801354,2.9085755,,,,,,,,,,,,,, -109600,2.982308,2.8906853,,,,,,,,,,,,,, -109700,2.768363,2.9147224,,,,,,,,,,,,,, -109800,2.714649,2.8478003,,,,,,,,,,,,,, -109900,2.7108686,2.884427,,,,,,,,,,,,,, -110000,2.6370716,2.8675187,,,,,,,,,,,,,, -110100,2.8769035,2.8879316,,,,,,,,,,,,,, -110200,2.71988,2.9104712,,,,,,,,,,,,,, -110300,2.7400393,2.9587145,,,,,,,,,,,,,, -110400,2.8094344,2.9399376,,,,,,,,,,,,,, -110500,3.0036848,2.911114,,,,,,,,,,,,,, -110600,2.6975126,2.8226273,,,,,,,,,,,,,, -110603,,,0.8202527165412903,0.907325804233551,0.7150799632072449,1.3554683923721311,50000.0,0.5920000076293945,1.9924134016036987,10000.0,37779.28477668762,39087.46141552925,37779.28477668762,1301.520435810089,2.8837478160858154,0.0 -110700,2.5956771,2.9098132,,,,,,,,,,,,,, -110800,2.7891579,2.852338,,,,,,,,,,,,,, -110900,2.6980326,2.9033442,,,,,,,,,,,,,, -111000,3.0220084,2.9784896,,,,,,,,,,,,,, -111100,2.7670162,2.9008899,,,,,,,,,,,,,, -111200,2.673927,2.8552554,,,,,,,,,,,,,, -111300,2.751276,2.941904,,,,,,,,,,,,,, -111400,2.7281163,2.8451076,,,,,,,,,,,,,, -111500,2.7781122,2.8625722,,,,,,,,,,,,,, -111600,2.791108,2.9096038,,,,,,,,,,,,,, -111700,2.7683303,2.8706412,,,,,,,,,,,,,, -111800,2.730632,2.8319345,,,,,,,,,,,,,, -111900,2.8642457,2.8333468,,,,,,,,,,,,,, -112000,2.69503,2.8759086,,,,,,,,,,,,,, -112099,,,0.8202327489852905,0.9242339134216307,0.7185199856758118,1.3499921560287476,50000.0,0.5906000137329102,2.007988452911377,10000.0,38289.51672291756,39615.02745246887,38289.51672291756,1318.7552318572998,2.933572769165039,0.0 -112100,2.7466314,2.8653636,,,,,,,,,,,,,, -112200,2.7554162,2.8226862,,,,,,,,,,,,,, -112300,2.7804804,2.9032304,,,,,,,,,,,,,, -112400,2.736908,2.9311018,,,,,,,,,,,,,, -112500,2.7723916,2.883096,,,,,,,,,,,,,, -112600,2.8657053,2.9059596,,,,,,,,,,,,,, -112700,2.8417623,2.8801627,,,,,,,,,,,,,, -112800,2.823685,2.8584633,,,,,,,,,,,,,, -112900,2.869931,2.888982,,,,,,,,,,,,,, -113000,2.905247,2.8721886,,,,,,,,,,,,,, -113100,2.6840005,2.8543904,,,,,,,,,,,,,, -113200,2.6716774,2.8125155,,,,,,,,,,,,,, -113300,2.6832132,2.812941,,,,,,,,,,,,,, -113400,2.7516894,2.8565428,,,,,,,,,,,,,, -113500,2.790129,2.9111853,,,,,,,,,,,,,, -113594,,,0.8160673975944519,0.9513839483261108,0.7136200070381165,1.3845340013504028,50000.0,0.5892000198364258,2.026318550109864,10000.0,38799.4973552227,40142.33880519867,38799.4973552227,1335.9850897789,2.984329462051392,0.0 -113600,2.7182317,2.8375626,,,,,,,,,,,,,, -113700,2.8062608,2.786142,,,,,,,,,,,,,, -113800,2.6985002,2.8236358,,,,,,,,,,,,,, -113900,2.786525,2.8420062,,,,,,,,,,,,,, -114000,3.0327764,2.9222293,,,,,,,,,,,,,, -114100,3.0342984,2.8593235,,,,,,,,,,,,,, -114200,2.7179067,2.80723,,,,,,,,,,,,,, -114300,2.769291,2.837943,,,,,,,,,,,,,, -114400,2.8543606,2.8425188,,,,,,,,,,,,,, -114500,2.817388,2.898856,,,,,,,,,,,,,, -114600,2.9630084,2.9299612,,,,,,,,,,,,,, -114700,2.8720207,2.897037,,,,,,,,,,,,,, -114800,2.6873057,2.9218457,,,,,,,,,,,,,, -114900,2.8244839,2.8730237,,,,,,,,,,,,,, -115000,2.694514,2.919574,,,,,,,,,,,,,, -115090,,,0.8474569320678711,0.7873832583427429,0.7181199789047241,1.3268829584121704,50000.0,0.5927000045776367,1.9718732833862305,10000.0,39309.72198271752,40670.07956337929,39309.72198271752,1353.4041435718536,3.0307705402374268,0.0 -115100,2.8978844,2.808113,,,,,,,,,,,,,, -115200,2.8647265,2.8424718,,,,,,,,,,,,,, -115300,2.8394198,2.8313658,,,,,,,,,,,,,, -115400,2.8540206,2.8469217,,,,,,,,,,,,,, -115500,2.9382772,2.8388252,,,,,,,,,,,,,, -115600,2.9359083,2.8430202,,,,,,,,,,,,,, -115700,2.72658,2.8202038,,,,,,,,,,,,,, -115800,2.7332537,2.8357008,,,,,,,,,,,,,, -115900,2.9258146,2.811371,,,,,,,,,,,,,, -116000,2.8799233,2.8177404,,,,,,,,,,,,,, -116100,2.7942934,2.8403666,,,,,,,,,,,,,, -116200,3.1023738,2.8396404,,,,,,,,,,,,,, -116300,2.7892296,2.8467028,,,,,,,,,,,,,, -116400,2.679433,2.8023071,,,,,,,,,,,,,, -116500,2.6527958,2.8236415,,,,,,,,,,,,,, -116586,,,0.8415776491165161,0.8181595206260681,0.7233999967575073,1.3208779096603394,50000.0,0.5944000482559204,1.96568763256073,10000.0,39819.89853024483,41197.69870424271,39819.89853024483,1370.7485435009005,3.077587127685547,0.0 -116600,2.9045064,2.8898606,,,,,,,,,,,,,, -116700,2.9892995,2.8404782,,,,,,,,,,,,,, -116800,3.2326074,2.8843625,,,,,,,,,,,,,, -116900,2.9792204,2.881415,,,,,,,,,,,,,, -117000,2.7338526,2.7917542,,,,,,,,,,,,,, -117100,3.0321279,2.8140607,,,,,,,,,,,,,, -117200,2.8069422,2.7929578,,,,,,,,,,,,,, -117300,2.9555576,2.8096356,,,,,,,,,,,,,, -117400,2.8465562,2.8622727,,,,,,,,,,,,,, -117500,2.953489,2.8425121,,,,,,,,,,,,,, -117600,3.3185477,2.8931165,,,,,,,,,,,,,, -117700,2.7218678,2.8347507,,,,,,,,,,,,,, -117800,2.8418634,2.840435,,,,,,,,,,,,,, -117900,2.8710365,2.8778954,,,,,,,,,,,,,, -118000,2.8630874,2.8078864,,,,,,,,,,,,,, -118081,,,0.8370934128761292,0.8302208185195923,0.7227999567985535,1.316156268119812,50000.0,0.5999000072479248,1.953519582748413,10000.0,40329.83609485626,41725.5061340332,40329.83609485626,1388.5198497772217,3.126605749130249,0.0 -118100,2.788877,2.7823794,,,,,,,,,,,,,, -118200,3.0589187,2.8741891,,,,,,,,,,,,,, -118300,2.79857,2.8962789,,,,,,,,,,,,,, -118400,3.1393695,2.8088164,,,,,,,,,,,,,, -118500,3.0229912,2.8213978,,,,,,,,,,,,,, -118600,2.7948427,2.8389769,,,,,,,,,,,,,, -118700,2.8855736,2.8423805,,,,,,,,,,,,,, -118800,2.999262,2.8345768,,,,,,,,,,,,,, -118900,3.2262144,2.898532,,,,,,,,,,,,,, -119000,2.8344216,2.8112838,,,,,,,,,,,,,, -119100,2.861307,2.856618,,,,,,,,,,,,,, -119200,2.7493193,2.812379,,,,,,,,,,,,,, -119300,2.9865901,2.8902762,,,,,,,,,,,,,, -119400,3.0050447,2.9173725,,,,,,,,,,,,,, -119500,2.8708665,2.8291926,,,,,,,,,,,,,, -119577,,,0.8384885191917419,0.8248199224472046,0.7274799942970276,1.3024871349334717,50000.0,0.5933000445365906,1.958518624305725,10000.0,40839.775539159775,42252.89182281494,40839.775539159775,1405.8627269268036,3.179180860519409,0.0 -119600,2.862915,2.7713127,,,,,,,,,,,,,, -119700,2.8510616,2.77387,,,,,,,,,,,,,, -119800,2.9889612,2.889784,,,,,,,,,,,,,, -119900,3.504896,2.8169446,,,,,,,,,,,,,, -120000,3.095152,2.82071,,,,,,,,,,,,,, -120100,2.9457068,2.8314939,,,,,,,,,,,,,, -120200,2.9762144,2.792712,,,,,,,,,,,,,, -120300,2.9557621,2.8219051,,,,,,,,,,,,,, -120400,3.0708685,2.881804,,,,,,,,,,,,,, -120500,3.035453,2.906585,,,,,,,,,,,,,, -120600,3.1152065,2.836372,,,,,,,,,,,,,, -120700,2.991259,2.8186884,,,,,,,,,,,,,, -120800,2.9229403,2.7993538,,,,,,,,,,,,,, -120900,2.9967902,2.7966511,,,,,,,,,,,,,, -121000,2.9069102,2.7851574,,,,,,,,,,,,,, -121073,,,0.8438097834587097,0.8003804683685303,0.7287399768829346,1.283615231513977,50000.0,0.6027000546455383,1.9129961729049685,10000.0,41350.00688409805,42780.44771409035,41350.00688409805,1423.086817741394,3.2281734943389893,0.0 -121100,3.2613494,2.8028383,,,,,,,,,,,,,, -121200,2.962293,2.8012686,,,,,,,,,,,,,, -121300,2.9936614,2.756569,,,,,,,,,,,,,, -121400,2.8222563,2.877101,,,,,,,,,,,,,, -121500,3.1120305,2.8139262,,,,,,,,,,,,,, -121600,2.9523504,2.8179178,,,,,,,,,,,,,, -121700,2.8927953,2.8393154,,,,,,,,,,,,,, -121800,3.0057592,2.8975382,,,,,,,,,,,,,, -121900,3.0300162,2.8443823,,,,,,,,,,,,,, -122000,3.13141,2.839056,,,,,,,,,,,,,, -122100,3.0130386,2.7984443,,,,,,,,,,,,,, -122200,3.048695,2.784327,,,,,,,,,,,,,, -122300,3.025816,2.8567023,,,,,,,,,,,,,, -122400,3.057916,2.7832801,,,,,,,,,,,,,, -122500,3.210171,2.8199,,,,,,,,,,,,,, -122569,,,0.832051157951355,0.8912011384963989,0.7217599749565125,1.3543952703475952,50000.0,0.5974000096321106,1.9827086925506592,10000.0,41860.03619909287,43307.74064588547,41860.03619909287,1440.245083808899,3.281466007232666,0.0 -122600,2.7885003,2.8149586,,,,,,,,,,,,,, -122700,3.017115,2.8591883,,,,,,,,,,,,,, -122800,2.9568124,2.797756,,,,,,,,,,,,,, -122900,3.057273,2.8327973,,,,,,,,,,,,,, -123000,3.0075784,2.8171415,,,,,,,,,,,,,, -123100,2.984426,2.8533173,,,,,,,,,,,,,, -123200,3.0879648,2.7650628,,,,,,,,,,,,,, -123300,2.919644,2.7528973,,,,,,,,,,,,,, -123400,3.0149229,2.8265514,,,,,,,,,,,,,, -123500,3.1485536,2.7980525,,,,,,,,,,,,,, -123600,3.030506,2.8037124,,,,,,,,,,,,,, -123700,2.9616418,2.8398857,,,,,,,,,,,,,, -123800,3.2101545,2.8309405,,,,,,,,,,,,,, -123900,3.0550046,2.818499,,,,,,,,,,,,,, -124000,2.966341,2.799097,,,,,,,,,,,,,, -124065,,,0.8768334984779358,0.6922958493232727,0.7325199842453003,1.277902126312256,50000.0,0.6080000400543213,1.9164098501205444,10000.0,42370.23780179024,43835.2552819252,42370.23780179024,1457.4404020309448,3.34682035446167,0.0 -124100,3.0483894,2.8384724,,,,,,,,,,,,,, -124200,2.870059,2.7447278,,,,,,,,,,,,,, -124300,3.0220394,2.746305,,,,,,,,,,,,,, -124400,2.9046955,2.8107448,,,,,,,,,,,,,, -124500,3.0249596,2.745998,,,,,,,,,,,,,, -124600,3.0450995,2.7744615,,,,,,,,,,,,,, -124700,3.0019662,2.7479944,,,,,,,,,,,,,, -124800,2.9371145,2.724066,,,,,,,,,,,,,, -124900,2.8745503,2.7129078,,,,,,,,,,,,,, -125000,2.8209684,2.7368104,,,,,,,,,,,,,, -125100,3.0083475,2.809282,,,,,,,,,,,,,, -125200,3.0796132,2.8084478,,,,,,,,,,,,,, -125300,3.0931,2.742719,,,,,,,,,,,,,, -125400,3.098808,2.8076844,,,,,,,,,,,,,, -125500,3.1928012,2.8783088,,,,,,,,,,,,,, -125561,,,0.857421875,0.7503759264945984,0.7276999950408936,1.2880154848098757,50000.0,0.6013000011444092,1.927111268043518,10000.0,42880.30663514137,44362.843577861786,42880.30663514137,1474.8576436042786,3.398554563522339,0.0 -125600,3.130815,2.8236003,,,,,,,,,,,,,, -125700,2.980015,2.7228584,,,,,,,,,,,,,, -125800,3.226832,2.840319,,,,,,,,,,,,,, -125900,3.2830708,2.7835612,,,,,,,,,,,,,, -126000,3.1262207,2.8052444,,,,,,,,,,,,,, -126100,3.0685594,2.7559495,,,,,,,,,,,,,, -126200,3.2981265,2.8372416,,,,,,,,,,,,,, -126300,3.145511,2.8456297,,,,,,,,,,,,,, -126400,3.4241714,2.8059063,,,,,,,,,,,,,, -126500,3.0005953,2.7043009,,,,,,,,,,,,,, -126600,3.0335584,2.8123004,,,,,,,,,,,,,, -126700,3.041966,2.7683926,,,,,,,,,,,,,, -126800,3.1832688,2.789899,,,,,,,,,,,,,, -126900,3.1650703,2.8411276,,,,,,,,,,,,,, -127000,3.0586808,2.7724078,,,,,,,,,,,,,, -127057,,,0.8565648794174194,0.7832421064376831,0.7305200099945068,1.3154152631759644,50000.0,0.6014000177383423,1.9597022533416748,10000.0,43390.468705654144,44890.360988378525,43390.468705654144,1492.1135189533234,3.44738507270813,0.0 -127100,3.097267,2.8183355,,,,,,,,,,,,,, -127200,3.2555754,2.779198,,,,,,,,,,,,,, -127300,3.287798,2.773777,,,,,,,,,,,,,, -127400,3.1577702,2.7570305,,,,,,,,,,,,,, -127500,3.0123498,2.6772823,,,,,,,,,,,,,, -127600,3.1203213,2.779698,,,,,,,,,,,,,, -127700,3.2902393,2.8398604,,,,,,,,,,,,,, -127800,3.2018254,2.827971,,,,,,,,,,,,,, -127900,3.2766078,2.7142093,,,,,,,,,,,,,, -128000,3.433888,2.8234823,,,,,,,,,,,,,, -128100,3.2432117,2.8070066,,,,,,,,,,,,,, -128200,3.1424534,2.7202349,,,,,,,,,,,,,, -128300,3.4613175,2.7949104,,,,,,,,,,,,,, -128400,3.2422554,2.7497635,,,,,,,,,,,,,, -128500,3.1897237,2.7749329,,,,,,,,,,,,,, -128553,,,0.8598732352256775,0.7385032773017883,0.7359399795532227,1.2580225467681885,50000.0,0.6163000464439392,1.8852897882461548,10000.0,43900.60306334496,45418.12388706207,43900.60306334496,1509.6416466236117,3.495047092437744,0.0 -128600,3.3196883,2.7595994,,,,,,,,,,,,,, -128700,2.96504,2.7132344,,,,,,,,,,,,,, -128800,3.0988107,2.781216,,,,,,,,,,,,,, -128900,3.1058326,2.6756508,,,,,,,,,,,,,, -129000,3.24244,2.7374084,,,,,,,,,,,,,, -129100,3.1777754,2.8229425,,,,,,,,,,,,,, -129200,2.9422112,2.7658901,,,,,,,,,,,,,, -129300,3.2856193,2.7840483,,,,,,,,,,,,,, -129400,3.140599,2.8278193,,,,,,,,,,,,,, -129500,3.1955996,2.7428598,,,,,,,,,,,,,, -129600,3.1110148,2.7988458,,,,,,,,,,,,,, -129700,3.149899,2.7221196,,,,,,,,,,,,,, -129800,3.1007888,2.802176,,,,,,,,,,,,,, -129900,3.1792455,2.7395885,,,,,,,,,,,,,, -130000,3.1027553,2.8032231,,,,,,,,,,,,,, -130048,,,0.8598732352256775,0.7397308349609375,0.7367599606513977,1.2535072565078735,50000.0,0.6109000444412231,1.88594388961792,10000.0,44410.62669610977,45945.561729192734,44410.62669610977,1526.9550709724426,3.543776512145996,0.0 -130100,3.2556376,2.7900305,,,,,,,,,,,,,, -130200,3.2570906,2.7386456,,,,,,,,,,,,,, -130300,2.976583,2.6977444,,,,,,,,,,,,,, -130400,2.9900682,2.7559133,,,,,,,,,,,,,, -130500,3.3523247,2.7500892,,,,,,,,,,,,,, -130600,3.1728466,2.7058272,,,,,,,,,,,,,, -130700,3.2911615,2.769107,,,,,,,,,,,,,, -130800,3.360157,2.8269901,,,,,,,,,,,,,, -130900,3.3236964,2.7776868,,,,,,,,,,,,,, -131000,3.1509976,2.6971905,,,,,,,,,,,,,, -131100,3.5516222,2.7132368,,,,,,,,,,,,,, -131200,3.3855634,2.7516174,,,,,,,,,,,,,, -131300,3.19052,2.7539818,,,,,,,,,,,,,, -131400,3.2659307,2.7469258,,,,,,,,,,,,,, -131500,3.1846802,2.7330136,,,,,,,,,,,,,, -131543,,,0.8572624325752258,0.7386617660522461,0.7386400103569031,1.2487965822219849,50000.0,0.613800048828125,1.8877145051956177,10000.0,44920.61176490784,46472.78757691383,44920.61176490784,1544.0944755077362,3.5926764011383057,0.0 -131600,3.1920497,2.6929057,,,,,,,,,,,,,, -131700,3.3708274,2.8020887,,,,,,,,,,,,,, -131800,3.347999,2.749381,,,,,,,,,,,,,, -131900,3.1055202,2.7227552,,,,,,,,,,,,,, -132000,3.3504465,2.7501402,,,,,,,,,,,,,, -132100,3.6215878,2.7567801,,,,,,,,,,,,,, -132200,3.49915,2.6993139,,,,,,,,,,,,,, -132300,3.1685655,2.6949174,,,,,,,,,,,,,, -132400,3.149975,2.7051754,,,,,,,,,,,,,, -132500,3.3929667,2.772639,,,,,,,,,,,,,, -132600,3.2813334,2.7827477,,,,,,,,,,,,,, -132700,3.3492093,2.7672732,,,,,,,,,,,,,, -132800,3.3593616,2.7878118,,,,,,,,,,,,,, -132900,3.464885,2.718027,,,,,,,,,,,,,, -133000,3.4838586,2.8250663,,,,,,,,,,,,,, -133037,,,0.866609513759613,0.7195022106170654,0.73881995677948,1.249531865119934,50000.0,0.613800048828125,1.889812588691712,10000.0,45430.52412056923,46999.9243516922,45430.52412056923,1561.2148640155792,3.644584894180298,0.0 -133100,3.378495,2.7313576,,,,,,,,,,,,,, -133200,3.6545362,2.7742662,,,,,,,,,,,,,, -133300,3.3657649,2.7618606,,,,,,,,,,,,,, -133400,3.539343,2.7109756,,,,,,,,,,,,,, -133500,3.3440661,2.7216444,,,,,,,,,,,,,, -133600,3.3355389,2.7121751,,,,,,,,,,,,,, -133700,3.3646746,2.736279,,,,,,,,,,,,,, -133800,3.1154547,2.6798449,,,,,,,,,,,,,, -133900,3.3043697,2.8006225,,,,,,,,,,,,,, -134000,3.2807598,2.7522213,,,,,,,,,,,,,, -134100,3.6571236,2.766457,,,,,,,,,,,,,, -134200,3.3694155,2.7024238,,,,,,,,,,,,,, -134300,3.3462842,2.733819,,,,,,,,,,,,,, -134400,3.1059418,2.7386951,,,,,,,,,,,,,, -134500,3.2594619,2.7697523,,,,,,,,,,,,,, -134533,,,0.8819156289100647,0.64863520860672,0.7436599731445312,1.2403019666671753,50000.0,0.6176000237464905,1.8791606426239007,10000.0,45940.66190671921,47527.43191242218,45940.66190671921,1578.4732689857483,3.704412937164306,0.0 -134600,3.1863177,2.7247715,,,,,,,,,,,,,, -134700,3.3530576,2.7964098,,,,,,,,,,,,,, -134800,3.2827723,2.738417,,,,,,,,,,,,,, -134900,3.1469457,2.7127461,,,,,,,,,,,,,, -135000,3.7210937,2.7131534,,,,,,,,,,,,,, -135100,3.4599183,2.7089045,,,,,,,,,,,,,, -135200,3.3801112,2.6870937,,,,,,,,,,,,,, -135300,3.238962,2.7156878,,,,,,,,,,,,,, -135400,3.1953373,2.7088852,,,,,,,,,,,,,, -135500,3.1417642,2.6539936,,,,,,,,,,,,,, -135600,3.379742,2.6400855,,,,,,,,,,,,,, -135700,3.376285,2.7265837,,,,,,,,,,,,,, -135800,3.4177194,2.6960387,,,,,,,,,,,,,, -135900,3.4101944,2.7265787,,,,,,,,,,,,,, -136000,3.091802,2.6325955,,,,,,,,,,,,,, -136029,,,0.8819355964660645,0.6802816987037659,0.7430999875068665,1.2458999156951904,50000.0,0.6157000064849854,1.88289487361908,10000.0,46450.88255214691,48054.984679460526,46450.88255214691,1595.7042515277865,3.7541310787200928,0.0 -136100,3.1707218,2.7267678,,,,,,,,,,,,,, -136200,3.202426,2.6606822,,,,,,,,,,,,,, -136300,3.3454099,2.689316,,,,,,,,,,,,,, -136400,3.508166,2.7102337,,,,,,,,,,,,,, -136500,3.543012,2.6861932,,,,,,,,,,,,,, -136600,3.2327929,2.7539139,,,,,,,,,,,,,, -136700,3.3682272,2.7083478,,,,,,,,,,,,,, -136800,3.164771,2.6961524,,,,,,,,,,,,,, -136900,3.5345685,2.7275047,,,,,,,,,,,,,, -137000,3.1525035,2.674613,,,,,,,,,,,,,, -137100,3.463095,2.7391715,,,,,,,,,,,,,, -137200,3.2257185,2.671367,,,,,,,,,,,,,, -137300,3.4861462,2.7345626,,,,,,,,,,,,,, -137400,3.5216522,2.7109609,,,,,,,,,,,,,, -137500,3.4267735,2.6711946,,,,,,,,,,,,,, -137525,,,0.8787667155265808,0.682250440120697,0.7416799664497375,1.249807357788086,50000.0,0.6220000386238098,1.8697481155395508,10000.0,46961.05339837074,48582.43872284889,46961.05339837074,1612.8796932697296,3.811321020126343,0.0 -137600,3.3435528,2.695475,,,,,,,,,,,,,, -137700,3.3159883,2.7990868,,,,,,,,,,,,,, -137800,3.55791,2.8360496,,,,,,,,,,,,,, -137900,3.67022,2.7870479,,,,,,,,,,,,,, -138000,3.4746854,2.6986022,,,,,,,,,,,,,, -138100,3.436444,2.6974068,,,,,,,,,,,,,, -138200,3.2293801,2.6896572,,,,,,,,,,,,,, -138300,3.517392,2.7478483,,,,,,,,,,,,,, -138400,3.7586076,2.7974668,,,,,,,,,,,,,, -138500,3.388097,2.761957,,,,,,,,,,,,,, -138600,3.1366692,2.6933029,,,,,,,,,,,,,, -138700,3.3447194,2.6905222,,,,,,,,,,,,,, -138800,3.4519417,2.7026258,,,,,,,,,,,,,, -138900,3.4142478,2.7505922,,,,,,,,,,,,,, -139000,3.2717767,2.6374283,,,,,,,,,,,,,, -139020,,,0.8798628449440002,0.6878957748413086,0.7449600100517273,1.2519656419754028,50000.0,0.6232000589370728,1.8762387037277224,10000.0,47471.02515506744,49109.65690302849,47471.02515506744,1630.022828578949,3.863240003585816,0.0 -139100,3.6490097,2.7682054,,,,,,,,,,,,,, -139200,3.328128,2.6697896,,,,,,,,,,,,,, -139300,3.5859985,2.7609284,,,,,,,,,,,,,, -139400,3.6839857,2.7661748,,,,,,,,,,,,,, -139500,3.410419,2.6867907,,,,,,,,,,,,,, -139600,3.1115103,2.6560774,,,,,,,,,,,,,, -139700,3.3562133,2.7240334,,,,,,,,,,,,,, -139800,3.4478376,2.7018862,,,,,,,,,,,,,, -139900,3.6884866,2.6780758,,,,,,,,,,,,,, -140000,3.6483014,2.6923444,,,,,,,,,,,,,, -140100,3.3446276,2.653157,,,,,,,,,,,,,, -140200,3.3272674,2.6829748,,,,,,,,,,,,,, -140300,3.4235132,2.671727,,,,,,,,,,,,,, -140400,3.4759235,2.6643384,,,,,,,,,,,,,, -140500,3.4040763,2.7368124,,,,,,,,,,,,,, -140516,,,0.8801219463348389,0.6773303151130676,0.742859959602356,1.2454878091812134,50000.0,0.6203000545501709,1.8631223440170288,10000.0,47981.19206619263,49637.203291893005,47981.19206619263,1647.2985713481903,3.913106203079224,0.0 -140600,3.7051206,2.719705,,,,,,,,,,,,,, -140700,3.452014,2.675627,,,,,,,,,,,,,, -140800,3.9580486,2.6904685,,,,,,,,,,,,,, -140900,3.360701,2.6955001,,,,,,,,,,,,,, -141000,3.424647,2.6588955,,,,,,,,,,,,,, -141100,3.5410492,2.7109952,,,,,,,,,,,,,, -141200,3.3735042,2.6680555,,,,,,,,,,,,,, -141300,3.4174047,2.73321,,,,,,,,,,,,,, -141400,3.5183883,2.6451056,,,,,,,,,,,,,, -141500,3.9006336,2.694942,,,,,,,,,,,,,, -141600,3.513084,2.7163389,,,,,,,,,,,,,, -141700,3.4042113,2.6526124,,,,,,,,,,,,,, -141800,3.48003,2.6810222,,,,,,,,,,,,,, -141900,3.54598,2.6718116,,,,,,,,,,,,,, -142000,3.259911,2.6751208,,,,,,,,,,,,,, -142011,,,0.8788862824440002,0.6838304996490479,0.7432599663734436,1.2508834600448608,50000.0,0.6175000071525574,1.883688807487488,10000.0,48491.136139154434,50164.43881726265,48491.136139154434,1664.4843318462372,3.967681646347046,0.0 -142100,3.4942472,2.7433343,,,,,,,,,,,,,, -142200,3.733059,2.7261739,,,,,,,,,,,,,, -142300,3.576425,2.7377086,,,,,,,,,,,,,, -142400,3.6519778,2.7040167,,,,,,,,,,,,,, -142500,3.51571,2.6887662,,,,,,,,,,,,,, -142600,3.5493658,2.694244,,,,,,,,,,,,,, -142700,3.4646847,2.7037659,,,,,,,,,,,,,, -142800,3.48369,2.6436517,,,,,,,,,,,,,, -142900,3.3885016,2.6300352,,,,,,,,,,,,,, -143000,3.783509,2.7316434,,,,,,,,,,,,,, -143100,3.5647864,2.6842296,,,,,,,,,,,,,, -143200,3.6074436,2.720814,,,,,,,,,,,,,, -143300,3.5959115,2.7040403,,,,,,,,,,,,,, -143400,3.501688,2.6170263,,,,,,,,,,,,,, -143500,3.5065153,2.6262255,,,,,,,,,,,,,, -143507,,,0.9057118892669678,0.5698558688163757,0.7466599941253662,1.2103779315948486,50000.0,0.6228000521659851,1.848327279090881,10000.0,49001.24908995628,50691.854425907135,49001.24908995628,1681.6815202236176,4.021223306655884,0.0 -143600,3.5842001,2.7173681,,,,,,,,,,,,,, -143700,3.8418314,2.7148337,,,,,,,,,,,,,, -143800,3.4390547,2.653799,,,,,,,,,,,,,, -143900,3.6759932,2.690281,,,,,,,,,,,,,, -144000,3.5303888,2.6747506,,,,,,,,,,,,,, -144100,3.5262728,2.6281557,,,,,,,,,,,,,, -144200,3.6709185,2.68103,,,,,,,,,,,,,, -144300,3.8154452,2.619171,,,,,,,,,,,,,, -144400,3.4095964,2.6101568,,,,,,,,,,,,,, -144500,3.4976964,2.6571767,,,,,,,,,,,,,, -144600,3.5935016,2.664266,,,,,,,,,,,,,, -144700,3.9765866,2.7496555,,,,,,,,,,,,,, -144800,3.5217948,2.6885633,,,,,,,,,,,,,, -144900,3.518724,2.6380253,,,,,,,,,,,,,, -145000,3.4058688,2.642079,,,,,,,,,,,,,, -145002,,,0.8981983065605164,0.6036162972450256,0.7472599744796753,1.228808045387268,50000.0,0.6213000416755676,1.8598668575286863,10000.0,49511.237023592,51219.37195444107,49511.237023592,1699.1049826145172,4.07612681388855,0.0 -145100,3.6362143,2.6125944,,,,,,,,,,,,,, -145200,3.4746127,2.6738224,,,,,,,,,,,,,, -145300,3.3428957,2.6189578,,,,,,,,,,,,,, -145400,3.4852579,2.5981073,,,,,,,,,,,,,, -145500,3.7484326,2.6535285,,,,,,,,,,,,,, -145600,3.616129,2.637256,,,,,,,,,,,,,, -145700,3.6183805,2.6392848,,,,,,,,,,,,,, -145800,3.8216255,2.6896703,,,,,,,,,,,,,, -145900,3.444076,2.7203274,,,,,,,,,,,,,, -146000,3.7578583,2.6622071,,,,,,,,,,,,,, -146100,3.4211814,2.7184396,,,,,,,,,,,,,, -146200,3.762017,2.693205,,,,,,,,,,,,,, -146300,3.5199807,2.66205,,,,,,,,,,,,,, -146400,3.9288764,2.6927652,,,,,,,,,,,,,, -146498,,,0.8981983065605164,0.5911901593208313,0.7471599578857422,1.2115702629089355,50000.0,0.6213000416755676,1.8459560871124268,10000.0,50021.47194981575,51746.86004567146,50021.47194981575,1716.2525057792664,4.129936933517456,0.0 -146500,3.670904,2.6612241,,,,,,,,,,,,,, -146600,3.7409708,2.6804225,,,,,,,,,,,,,, -146700,3.5708368,2.6706867,,,,,,,,,,,,,, -146800,3.8598607,2.6368213,,,,,,,,,,,,,, -146900,3.7677886,2.6151679,,,,,,,,,,,,,, -147000,3.9531522,2.6548238,,,,,,,,,,,,,, -147100,3.754005,2.7085276,,,,,,,,,,,,,, -147200,3.6764905,2.6495352,,,,,,,,,,,,,, -147300,3.7714257,2.6807485,,,,,,,,,,,,,, -147400,3.5136096,2.6224933,,,,,,,,,,,,,, -147500,3.4548974,2.5744114,,,,,,,,,,,,,, -147600,3.5207148,2.6730657,,,,,,,,,,,,,, -147700,3.7237647,2.6764374,,,,,,,,,,,,,, -147800,3.5496664,2.6128783,,,,,,,,,,,,,, -147900,3.7190115,2.6738853,,,,,,,,,,,,,, -147993,,,0.8964046239852905,0.6049222350120544,0.7515599727630615,1.208533525466919,50000.0,0.6234000325202942,1.849199652671814,10000.0,50531.48566865921,52274.35407853127,50531.48566865921,1733.6255042552948,4.185802459716797,0.0 -148000,3.8277147,2.6593478,,,,,,,,,,,,,, -148100,3.611182,2.6756985,,,,,,,,,,,,,, -148200,3.4782085,2.613889,,,,,,,,,,,,,, -148300,3.5399094,2.6184244,,,,,,,,,,,,,, -148400,3.4508774,2.6305857,,,,,,,,,,,,,, -148500,3.5550334,2.6368124,,,,,,,,,,,,,, -148600,3.5847116,2.6684651,,,,,,,,,,,,,, -148700,3.5356224,2.677588,,,,,,,,,,,,,, -148800,3.6837413,2.6135576,,,,,,,,,,,,,, -148900,3.7123024,2.6776087,,,,,,,,,,,,,, -149000,3.6337595,2.6334193,,,,,,,,,,,,,, -149100,3.7208667,2.6522152,,,,,,,,,,,,,, -149200,3.8038676,2.6745784,,,,,,,,,,,,,, -149300,3.759364,2.6351066,,,,,,,,,,,,,, -149400,3.5750258,2.5813012,,,,,,,,,,,,,, -149489,,,0.899832546710968,0.6207277774810791,0.7516599893569946,1.2279235124588013,50000.0,0.6240000128746033,1.8540517091751096,10000.0,51041.4313583374,52801.8858203888,51041.4313583374,1751.1035561561584,4.243193626403809,0.0 -149500,3.6981034,2.6931314,,,,,,,,,,,,,, -149600,3.7032304,2.6201215,,,,,,,,,,,,,, -149700,3.875347,2.6270494,,,,,,,,,,,,,, -149800,3.546278,2.622981,,,,,,,,,,,,,, -149900,3.498498,2.601986,,,,,,,,,,,,,, -150000,3.7400205,2.6213958,,,,,,,,,,,,,, -150100,3.750856,2.6110966,,,,,,,,,,,,,, -150200,3.5656383,2.6263604,,,,,,,,,,,,,, -150300,3.6455708,2.6276174,,,,,,,,,,,,,, -150400,3.7629602,2.6344416,,,,,,,,,,,,,, -150500,3.8722277,2.6405015,,,,,,,,,,,,,, -150600,3.6730175,2.6484156,,,,,,,,,,,,,, -150700,3.8249981,2.5876708,,,,,,,,,,,,,, -150800,3.7836795,2.5960093,,,,,,,,,,,,,, -150900,3.6456494,2.6128938,,,,,,,,,,,,,, -150984,,,0.9048349857330322,0.5874025821685791,0.7528600096702576,1.2086260318756104,50000.0,0.6312000155448914,1.8340696096420288,10000.0,51551.43528985977,53330.150327920914,51551.43528985977,1769.2537994384766,4.300515413284302,0.0 -151000,3.8710864,2.626595,,,,,,,,,,,,,, -151100,3.6088417,2.6525307,,,,,,,,,,,,,, -151200,3.7982695,2.6415687,,,,,,,,,,,,,, -151300,3.8544705,2.6474252,,,,,,,,,,,,,, -151400,3.885736,2.6010447,,,,,,,,,,,,,, -151500,3.6919775,2.6171167,,,,,,,,,,,,,, -151600,3.7118907,2.606093,,,,,,,,,,,,,, -151700,3.8846865,2.675908,,,,,,,,,,,,,, -151800,3.78854,2.6973476,,,,,,,,,,,,,, -151900,4.0042887,2.627361,,,,,,,,,,,,,, -152000,3.6887398,2.571444,,,,,,,,,,,,,, -152100,3.6572711,2.6446383,,,,,,,,,,,,,, -152200,3.6673064,2.657012,,,,,,,,,,,,,, -152300,3.7288597,2.600736,,,,,,,,,,,,,, -152400,3.8494449,2.6373758,,,,,,,,,,,,,, -152480,,,0.9223333597183228,0.5322718024253845,0.7513200044631958,1.214342713356018,50000.0,0.6295000314712524,1.83812952041626,10000.0,52061.50852203369,53857.58398604393,52061.50852203369,1786.5181086063385,4.345062255859375,0.0 -152500,3.449355,2.5679517,,,,,,,,,,,,,, -152600,3.683264,2.5910704,,,,,,,,,,,,,, -152700,3.682417,2.599218,,,,,,,,,,,,,, -152800,3.697942,2.535713,,,,,,,,,,,,,, -152900,3.6618297,2.580888,,,,,,,,,,,,,, -153000,4.085297,2.6361306,,,,,,,,,,,,,, -153100,3.678386,2.5597324,,,,,,,,,,,,,, -153200,3.8438823,2.5768106,,,,,,,,,,,,,, -153300,3.702498,2.630015,,,,,,,,,,,,,, -153400,3.7455354,2.6473649,,,,,,,,,,,,,, -153500,3.9861875,2.645641,,,,,,,,,,,,,, -153600,4.0072637,2.6362803,,,,,,,,,,,,,, -153700,3.994266,2.7226443,,,,,,,,,,,,,, -153800,3.732828,2.5990071,,,,,,,,,,,,,, -153900,3.6475182,2.5766845,,,,,,,,,,,,,, -153975,,,0.9167131781578064,0.5371339917182922,0.7554799914360046,1.202966809272766,50000.0,0.6319000124931335,1.841834664344788,10000.0,52571.40790128708,54384.898052454,52571.40790128708,1803.822417974472,4.4028167724609375,0.0 -154000,3.787462,2.6291244,,,,,,,,,,,,,, -154100,3.9246674,2.5934892,,,,,,,,,,,,,, -154200,3.8042173,2.6405363,,,,,,,,,,,,,, -154300,3.6717873,2.6327069,,,,,,,,,,,,,, -154400,4.1313596,2.640174,,,,,,,,,,,,,, -154500,3.641603,2.591239,,,,,,,,,,,,,, -154600,4.0880227,2.6623564,,,,,,,,,,,,,, -154700,3.6383784,2.563777,,,,,,,,,,,,,, -154800,3.8571105,2.6105342,,,,,,,,,,,,,, -154900,3.7797885,2.5866568,,,,,,,,,,,,,, -155000,4.0370073,2.656653,,,,,,,,,,,,,, -155100,3.815509,2.599027,,,,,,,,,,,,,, -155200,4.3374486,2.5598578,,,,,,,,,,,,,, -155300,3.978761,2.654893,,,,,,,,,,,,,, -155400,3.9145076,2.6552334,,,,,,,,,,,,,, -155471,,,0.9189253449440002,0.5368630886077881,0.7572599649429321,1.1976443529129028,50000.0,0.6337000131607056,1.82937240600586,10000.0,53081.45244860649,54912.43840265274,53081.45244860649,1821.2124042510984,4.457838535308838,0.0 -155500,4.093766,2.5769546,,,,,,,,,,,,,, -155600,3.983828,2.636055,,,,,,,,,,,,,, -155700,3.9185328,2.5670254,,,,,,,,,,,,,, -155800,3.581011,2.5461164,,,,,,,,,,,,,, -155900,3.7838938,2.6184158,,,,,,,,,,,,,, -156000,3.559234,2.5835485,,,,,,,,,,,,,, -156100,4.0558357,2.5936408,,,,,,,,,,,,,, -156200,3.7259047,2.6248388,,,,,,,,,,,,,, -156300,4.021242,2.6314,,,,,,,,,,,,,, -156400,3.7596312,2.5778875,,,,,,,,,,,,,, -156500,3.865038,2.5722973,,,,,,,,,,,,,, -156600,3.586552,2.5519464,,,,,,,,,,,,,, -156700,3.6891246,2.6000195,,,,,,,,,,,,,, -156800,3.637279,2.551366,,,,,,,,,,,,,, -156900,3.6914287,2.5824375,,,,,,,,,,,,,, -156967,,,0.9168327450752258,0.5451948642730713,0.7570799589157104,1.195743203163147,50000.0,0.6295000314712524,1.8315805196762085,10000.0,53591.627490758896,55440.64496970177,53591.627490758896,1839.1349787712093,4.514558553695679,0.0 -157000,3.7443461,2.549309,,,,,,,,,,,,,, -157100,3.7589202,2.6231976,,,,,,,,,,,,,, -157200,4.158996,2.5972614,,,,,,,,,,,,,, -157300,3.791649,2.5317147,,,,,,,,,,,,,, -157400,3.9976306,2.590564,,,,,,,,,,,,,, -157500,3.861828,2.5459318,,,,,,,,,,,,,, -157600,3.653276,2.5675588,,,,,,,,,,,,,, -157700,3.8584423,2.5666523,,,,,,,,,,,,,, -157800,3.4993694,2.545891,,,,,,,,,,,,,, -157900,4.0158587,2.5839527,,,,,,,,,,,,,, -158000,3.824447,2.5643313,,,,,,,,,,,,,, -158100,3.9483113,2.6444583,,,,,,,,,,,,,, -158200,3.6327443,2.5855982,,,,,,,,,,,,,, -158300,3.977144,2.6576726,,,,,,,,,,,,,, -158400,4.215748,2.6168938,,,,,,,,,,,,,, -158462,,,0.9191047549247742,0.5344241857528687,0.7557399868965149,1.1946972608566284,50000.0,0.6314000487327576,1.8226438760757449,10000.0,54101.57953572273,55967.96126079559,54101.57953572273,1856.390554189682,4.57092547416687,0.0 -158500,3.5647426,2.552161,,,,,,,,,,,,,, -158600,3.8060114,2.5348265,,,,,,,,,,,,,, -158700,3.9914222,2.6128838,,,,,,,,,,,,,, -158800,3.74375,2.58508,,,,,,,,,,,,,, -158900,3.928158,2.5871959,,,,,,,,,,,,,, -159000,3.838547,2.5938435,,,,,,,,,,,,,, -159100,4.0347304,2.6449075,,,,,,,,,,,,,, -159200,3.9709156,2.61333,,,,,,,,,,,,,, -159300,3.779087,2.584024,,,,,,,,,,,,,, -159400,3.995307,2.6462526,,,,,,,,,,,,,, -159500,3.7095685,2.5687735,,,,,,,,,,,,,, -159600,3.8405519,2.5404866,,,,,,,,,,,,,, -159700,3.77335,2.5726457,,,,,,,,,,,,,, -159800,3.7441669,2.5661254,,,,,,,,,,,,,, -159900,4.209159,2.6094131,,,,,,,,,,,,,, -159957,,,0.919144570827484,0.5355068445205688,0.7576599717140198,1.2013722658157349,50000.0,0.6325000524520874,1.8268934488296509,10000.0,54611.48469758034,56495.354562044144,54611.48469758034,1873.7629334926603,4.634932041168213,0.0 -160000,3.7023313,2.5295196,,,,,,,,,,,,,, -160100,3.860358,2.624714,,,,,,,,,,,,,, -160200,4.024345,2.6184816,,,,,,,,,,,,,, -160300,4.0079803,2.5785031,,,,,,,,,,,,,, -160400,4.1608095,2.5937,,,,,,,,,,,,,, -160500,3.699715,2.591912,,,,,,,,,,,,,, -160600,4.0190635,2.619067,,,,,,,,,,,,,, -160700,3.97253,2.6352046,,,,,,,,,,,,,, -160800,4.1092787,2.582242,,,,,,,,,,,,,, -160900,3.7580354,2.61377,,,,,,,,,,,,,, -161000,4.162163,2.509214,,,,,,,,,,,,,, -161100,3.8426645,2.5813131,,,,,,,,,,,,,, -161200,4.094522,2.534421,,,,,,,,,,,,,, -161300,4.072258,2.637075,,,,,,,,,,,,,, -161400,3.7880218,2.5289726,,,,,,,,,,,,,, -161453,,,0.9230110049247742,0.5109298229217529,0.7595599889755249,1.1808478832244873,50000.0,0.6340000033378601,1.8122520446777344,10000.0,55121.617911338806,57023.04491233826,55121.617911338806,1891.2067058086395,4.697674751281738,0.0 -161500,4.4368653,2.611896,,,,,,,,,,,,,, -161600,3.9985378,2.5951679,,,,,,,,,,,,,, -161700,3.9330525,2.5157428,,,,,,,,,,,,,, -161800,3.7747774,2.5241628,,,,,,,,,,,,,, -161900,3.9659834,2.5604825,,,,,,,,,,,,,, -162000,4.1151686,2.58662,,,,,,,,,,,,,, -162100,3.9048831,2.559655,,,,,,,,,,,,,, -162200,4.0059566,2.581246,,,,,,,,,,,,,, -162300,3.9401262,2.6231046,,,,,,,,,,,,,, -162400,3.6335883,2.5299797,,,,,,,,,,,,,, -162500,3.978045,2.566204,,,,,,,,,,,,,, -162600,3.8999636,2.524293,,,,,,,,,,,,,, -162700,3.9965432,2.567211,,,,,,,,,,,,,, -162800,3.9292152,2.549914,,,,,,,,,,,,,, -162900,3.8842046,2.5208688,,,,,,,,,,,,,, -162948,,,0.9344706535339355,0.4809099733829498,0.7597000002861023,1.189298152923584,50000.0,0.6338000297546387,1.813969612121582,10000.0,55631.54829573631,57550.35655713081,55631.54829573631,1908.4826707839968,4.7523229122161865,0.0 -163000,3.8029723,2.5609274,,,,,,,,,,,,,, -163100,4.0155425,2.570723,,,,,,,,,,,,,, -163200,3.877286,2.6135774,,,,,,,,,,,,,, -163300,4.3310585,2.605566,,,,,,,,,,,,,, -163400,3.8257964,2.5629125,,,,,,,,,,,,,, -163500,3.9390538,2.5697684,,,,,,,,,,,,,, -163600,4.3317823,2.5400174,,,,,,,,,,,,,, -163700,3.7842026,2.5477757,,,,,,,,,,,,,, -163800,4.0750337,2.5206275,,,,,,,,,,,,,, -163900,3.9311352,2.5340319,,,,,,,,,,,,,, -164000,4.0466976,2.579873,,,,,,,,,,,,,, -164100,3.8476107,2.5376198,,,,,,,,,,,,,, -164200,4.1839905,2.532318,,,,,,,,,,,,,, -164300,4.0169873,2.577826,,,,,,,,,,,,,, -164400,3.8930557,2.614697,,,,,,,,,,,,,, -164444,,,0.9301458597183228,0.4850182235240936,0.7604199647903442,1.181099534034729,50000.0,0.6353000402450562,1.8088539838790887,10000.0,56141.64248919487,58077.93835878372,56141.64248919487,1925.8598954677584,4.809988737106323,0.0 -164500,3.8145936,2.5383847,,,,,,,,,,,,,, -164600,3.8624318,2.5653727,,,,,,,,,,,,,, -164700,4.0332007,2.5624194,,,,,,,,,,,,,, -164800,3.9310749,2.5616288,,,,,,,,,,,,,, -164900,4.0045037,2.5796838,,,,,,,,,,,,,, -165000,3.6672354,2.5579467,,,,,,,,,,,,,, -165100,3.8892338,2.5743587,,,,,,,,,,,,,, -165200,3.836169,2.5770438,,,,,,,,,,,,,, -165300,4.293868,2.5630682,,,,,,,,,,,,,, -165400,4.010149,2.569331,,,,,,,,,,,,,, -165500,4.151148,2.5612223,,,,,,,,,,,,,, -165600,3.8777437,2.5559633,,,,,,,,,,,,,, -165700,4.0049725,2.5564015,,,,,,,,,,,,,, -165800,4.1838746,2.5701642,,,,,,,,,,,,,, -165900,3.9244292,2.55611,,,,,,,,,,,,,, -165939,,,0.9299465417861938,0.4834108352661133,0.7608599662780762,1.1775710582733154,50000.0,0.6385000348091125,1.806591510772705,10000.0,56651.543511390686,58605.25307846069,56651.543511390686,1943.168488740921,4.8650195598602295,0.0 -166000,3.7686536,2.5326176,,,,,,,,,,,,,, -166100,3.7108283,2.5301423,,,,,,,,,,,,,, -166200,3.875974,2.5700812,,,,,,,,,,,,,, -166300,4.151319,2.570736,,,,,,,,,,,,,, -166400,4.1902127,2.5707495,,,,,,,,,,,,,, -166500,4.1155815,2.5478487,,,,,,,,,,,,,, -166600,3.9471025,2.5079732,,,,,,,,,,,,,, -166700,4.101851,2.6019194,,,,,,,,,,,,,, -166800,4.3693967,2.5297525,,,,,,,,,,,,,, -166900,4.106139,2.5316362,,,,,,,,,,,,,, -167000,3.9474103,2.54881,,,,,,,,,,,,,, -167100,4.282256,2.6108842,,,,,,,,,,,,,, -167200,3.6865227,2.5305986,,,,,,,,,,,,,, -167300,3.690477,2.539565,,,,,,,,,,,,,, -167400,3.8530076,2.5415998,,,,,,,,,,,,,, -167435,,,0.9329559803009032,0.4767115414142608,0.7617599964141846,1.1800211668014526,50000.0,0.6402000188827515,1.803207874298096,10000.0,57161.636585474014,59132.583621263504,57161.636585474014,1960.2930953502653,4.926531791687012,0.0 -167500,3.7485926,2.5146432,,,,,,,,,,,,,, -167600,3.999877,2.530369,,,,,,,,,,,,,, -167700,4.121007,2.5379915,,,,,,,,,,,,,, -167800,4.009285,2.5693073,,,,,,,,,,,,,, -167900,3.8355997,2.5322123,,,,,,,,,,,,,, -168000,3.674487,2.4819484,,,,,,,,,,,,,, -168100,4.234211,2.5662043,,,,,,,,,,,,,, -168200,4.2326694,2.574509,,,,,,,,,,,,,, -168300,3.85256,2.5300345,,,,,,,,,,,,,, -168400,3.5909047,2.5218487,,,,,,,,,,,,,, -168500,3.8355696,2.565901,,,,,,,,,,,,,, -168600,4.1851993,2.5858428,,,,,,,,,,,,,, -168700,3.8962772,2.568635,,,,,,,,,,,,,, -168800,4.1007943,2.5148578,,,,,,,,,,,,,, -168900,4.142076,2.5890958,,,,,,,,,,,,,, -168931,,,0.9340720176696776,0.4837055206298828,0.7615999579429626,1.1844817399978638,50000.0,0.6378000378608704,1.8050771951675413,10000.0,57671.803205251694,59660.20059251785,57671.803205251694,1977.6313047409053,4.987208604812622,0.0 -169000,4.1363163,2.56827,,,,,,,,,,,,,, -169100,3.8414614,2.5380483,,,,,,,,,,,,,, -169200,4.090688,2.502523,,,,,,,,,,,,,, -169300,4.0682516,2.6096466,,,,,,,,,,,,,, -169400,4.004904,2.5139425,,,,,,,,,,,,,, -169500,4.1404357,2.5187986,,,,,,,,,,,,,, -169600,3.8590815,2.5455189,,,,,,,,,,,,,, -169700,4.107006,2.5360332,,,,,,,,,,,,,, -169800,4.0697284,2.5735624,,,,,,,,,,,,,, -169900,3.9833615,2.4973502,,,,,,,,,,,,,, -170000,4.404216,2.6115746,,,,,,,,,,,,,, -170100,3.9478986,2.5156302,,,,,,,,,,,,,, -170200,4.090961,2.5153625,,,,,,,,,,,,,, -170300,3.7894845,2.5182505,,,,,,,,,,,,,, -170400,3.8781471,2.5528073,,,,,,,,,,,,,, -170427,,,0.9329758882522584,0.4784607589244842,0.7625399827957153,1.1756603717803955,50000.0,0.6420000195503235,1.7998629808425903,10000.0,58181.90040636063,60187.52374267578,58181.90040636063,1994.7446548938751,5.04884672164917,0.0 -170500,4.2217064,2.5391238,,,,,,,,,,,,,, -170600,4.204429,2.561545,,,,,,,,,,,,,, -170700,4.094582,2.5636694,,,,,,,,,,,,,, -170800,4.1452627,2.5511234,,,,,,,,,,,,,, -170900,4.231926,2.5981693,,,,,,,,,,,,,, -171000,4.009327,2.5662532,,,,,,,,,,,,,, -171100,4.1544456,2.5240498,,,,,,,,,,,,,, -171200,3.9650805,2.517924,,,,,,,,,,,,,, -171300,4.084906,2.5461779,,,,,,,,,,,,,, -171400,4.1626725,2.5684767,,,,,,,,,,,,,, -171500,3.9137669,2.4974458,,,,,,,,,,,,,, -171600,4.3014035,2.4972575,,,,,,,,,,,,,, -171700,4.0069766,2.5843725,,,,,,,,,,,,,, -171800,3.9953756,2.501423,,,,,,,,,,,,,, -171900,3.8129947,2.4970977,,,,,,,,,,,,,, -171923,,,0.9389548301696776,0.4606906175613403,0.7628600001335144,1.1755813360214231,50000.0,0.6399000287055969,1.7999354600906372,10000.0,58692.09399271011,60715.07775473595,58692.09399271011,2011.99293589592,5.10937762260437,0.0 -172000,3.9535265,2.5324135,,,,,,,,,,,,,, -172100,3.9098315,2.5602207,,,,,,,,,,,,,, -172200,4.1285844,2.5359411,,,,,,,,,,,,,, -172300,4.1632824,2.581176,,,,,,,,,,,,,, -172400,4.2616553,2.5662794,,,,,,,,,,,,,, -172500,4.072199,2.5416474,,,,,,,,,,,,,, -172600,4.0614333,2.5227213,,,,,,,,,,,,,, -172700,4.2333145,2.56549,,,,,,,,,,,,,, -172800,3.8926802,2.550499,,,,,,,,,,,,,, -172900,3.981163,2.5361533,,,,,,,,,,,,,, -173000,4.008811,2.5076866,,,,,,,,,,,,,, -173100,4.0229983,2.5168576,,,,,,,,,,,,,, -173200,4.0591726,2.526485,,,,,,,,,,,,,, -173300,3.914256,2.521642,,,,,,,,,,,,,, -173400,3.9924295,2.4683676,,,,,,,,,,,,,, -173418,,,0.9387754797935486,0.4628245830535888,0.761900007724762,1.1791095733642578,50000.0,0.64000004529953,1.8056416511535645,10000.0,59202.0762693882,61242.43579864502,59202.0762693882,2029.256034374237,5.169875860214233,0.0 -173500,4.1261396,2.522703,,,,,,,,,,,,,, -173600,3.9878576,2.5294151,,,,,,,,,,,,,, -173700,4.4235387,2.581799,,,,,,,,,,,,,, -173800,3.9934576,2.5225818,,,,,,,,,,,,,, -173900,3.863104,2.480256,,,,,,,,,,,,,, -174000,4.022063,2.4859653,,,,,,,,,,,,,, -174100,4.199226,2.5415766,,,,,,,,,,,,,, -174200,3.9884179,2.5966659,,,,,,,,,,,,,, -174300,4.04698,2.5186496,,,,,,,,,,,,,, -174400,4.2653885,2.4961877,,,,,,,,,,,,,, -174500,4.0710645,2.5728362,,,,,,,,,,,,,, -174600,3.9442017,2.4853811,,,,,,,,,,,,,, -174700,4.0435443,2.5478292,,,,,,,,,,,,,, -174800,3.868466,2.550007,,,,,,,,,,,,,, -174900,3.9762743,2.518443,,,,,,,,,,,,,, -174913,,,0.9377790093421936,0.4617302119731903,0.7626799941062927,1.1732220649719238,50000.0,0.6394000053405762,1.7951186895370483,10000.0,59712.00906252861,61769.894496679306,59712.00906252861,2046.667736530304,5.231820344924927,0.0 -175000,4.0133586,2.5457723,,,,,,,,,,,,,, -175100,4.238617,2.5835881,,,,,,,,,,,,,, -175200,4.015615,2.5807984,,,,,,,,,,,,,, -175300,4.1007743,2.5204568,,,,,,,,,,,,,, -175400,3.9304821,2.550138,,,,,,,,,,,,,, -175500,3.8991287,2.5379577,,,,,,,,,,,,,, -175600,4.0296087,2.5138545,,,,,,,,,,,,,, -175700,4.028932,2.5116215,,,,,,,,,,,,,, -175800,3.936973,2.5433042,,,,,,,,,,,,,, -175900,4.088377,2.5260825,,,,,,,,,,,,,, -176000,4.106951,2.615747,,,,,,,,,,,,,, -176100,4.090584,2.5063097,,,,,,,,,,,,,, -176200,3.9528022,2.4754024,,,,,,,,,,,,,, -176300,4.250314,2.5736108,,,,,,,,,,,,,, -176400,4.1943564,2.519204,,,,,,,,,,,,,, -176409,,,0.93949294090271,0.4578624963760376,0.76419997215271,1.1710017919540403,50000.0,0.64410001039505,1.790612816810608,10000.0,60222.23619198799,62297.375115156174,60222.23619198799,2063.812755346298,5.288837909698486,0.0 -176500,4.0617166,2.4894986,,,,,,,,,,,,,, -176600,3.9355228,2.4970844,,,,,,,,,,,,,, -176700,4.0637746,2.4622295,,,,,,,,,,,,,, -176800,3.9948227,2.5471492,,,,,,,,,,,,,, -176900,3.925123,2.4496994,,,,,,,,,,,,,, -177000,4.000763,2.5004253,,,,,,,,,,,,,, -177100,4.111804,2.4549983,,,,,,,,,,,,,, -177200,4.4029946,2.556044,,,,,,,,,,,,,, -177300,4.248279,2.527315,,,,,,,,,,,,,, -177400,4.0003424,2.4635983,,,,,,,,,,,,,, -177500,4.1734796,2.5173423,,,,,,,,,,,,,, -177600,4.1983914,2.550582,,,,,,,,,,,,,, -177700,4.066331,2.5183256,,,,,,,,,,,,,, -177800,4.201087,2.5322976,,,,,,,,,,,,,, -177900,3.8909683,2.4612925,,,,,,,,,,,,,, -177904,,,0.9384167790412904,0.4547231495380401,0.7644400000572205,1.167048454284668,50000.0,0.6421000361442566,1.7899819612503052,10000.0,60732.223915576935,62824.61188578606,60732.223915576935,2080.94895029068,5.348676443099976,0.0 -178000,3.9221408,2.4880042,,,,,,,,,,,,,, -178100,3.690136,2.4754624,,,,,,,,,,,,,, -178200,3.7527122,2.4606276,,,,,,,,,,,,,, -178300,4.3566513,2.5777178,,,,,,,,,,,,,, -178400,4.0683928,2.5206325,,,,,,,,,,,,,, -178500,4.1267486,2.5142937,,,,,,,,,,,,,, -178600,4.216478,2.5199156,,,,,,,,,,,,,, -178700,3.9498289,2.4872231,,,,,,,,,,,,,, -178800,4.141542,2.531301,,,,,,,,,,,,,, -178900,4.105282,2.502685,,,,,,,,,,,,,, -179000,4.085905,2.4879718,,,,,,,,,,,,,, -179100,3.83694,2.482161,,,,,,,,,,,,,, -179200,3.9615242,2.540631,,,,,,,,,,,,,, -179300,3.8424954,2.486246,,,,,,,,,,,,,, -179400,,,0.9395527839660645,0.4537563621997833,0.7645599842071533,1.1670894622802734,50000.0,0.6438000202178955,1.7903746366500854,10000.0,61242.4484193325,63352.25589585304,61242.4484193325,2098.2528672218323,5.410964012145996,0.0 -179400,4.2586284,2.5398438,,,,,,,,,,,,,, -179500,4.0312552,2.5511162,,,,,,,,,,,,,, -179600,3.8319366,2.5139894,,,,,,,,,,,,,, -179700,3.9859793,2.4828684,,,,,,,,,,,,,, -179800,4.255048,2.5288277,,,,,,,,,,,,,, -179900,4.0057425,2.514669,,,,,,,,,,,,,, -180000,4.0152936,2.5274172,,,,,,,,,,,,,, -180100,3.9604595,2.5259805,,,,,,,,,,,,,, -180200,4.265509,2.4887478,,,,,,,,,,,,,, -180300,3.910728,2.509476,,,,,,,,,,,,,, -180400,4.010951,2.5131888,,,,,,,,,,,,,, -180500,3.997635,2.5570047,,,,,,,,,,,,,, -180600,4.136631,2.5022929,,,,,,,,,,,,,, -180700,4.156547,2.487737,,,,,,,,,,,,,, -180800,3.656561,2.4944787,,,,,,,,,,,,,, -180895,,,0.9403898119926452,0.4528295397758484,0.7649999856948853,1.168945074081421,50000.0,0.6429000496864319,1.7909166812896729,10000.0,61752.42921423912,63879.60906338692,61752.42921423912,2115.509332180023,5.475275278091431,0.0 -180900,4.214675,2.4889235,,,,,,,,,,,,,, -181000,4.391697,2.5468292,,,,,,,,,,,,,, -181100,4.017286,2.5160577,,,,,,,,,,,,,, -181200,3.9421968,2.4885373,,,,,,,,,,,,,, -181300,3.9067814,2.5160828,,,,,,,,,,,,,, -181400,3.829815,2.5070891,,,,,,,,,,,,,, -181500,4.058096,2.5292861,,,,,,,,,,,,,, -181600,4.25356,2.4970224,,,,,,,,,,,,,, -181700,4.2153835,2.4407537,,,,,,,,,,,,,, -181800,3.9576027,2.5316515,,,,,,,,,,,,,, -181900,4.077906,2.537658,,,,,,,,,,,,,, -182000,4.1282616,2.5554183,,,,,,,,,,,,,, -182100,4.140591,2.5192566,,,,,,,,,,,,,, -182200,3.8056922,2.4887066,,,,,,,,,,,,,, -182300,4.271711,2.571149,,,,,,,,,,,,,, -182391,,,0.9399114847183228,0.4492985010147095,0.7644000053405762,1.1697618961334229,50000.0,0.6430000066757202,1.7924065589904783,10000.0,62262.553658008575,64407.16589832306,62262.553658008575,2132.8234283924103,5.542778015136719,0.0 -182400,4.240525,2.5166345,,,,,,,,,,,,,, -182500,4.3059406,2.49631,,,,,,,,,,,,,, -182600,3.9322164,2.4913049,,,,,,,,,,,,,, -182700,4.0052767,2.505733,,,,,,,,,,,,,, -182800,3.9589298,2.5160055,,,,,,,,,,,,,, -182900,3.9580903,2.4823642,,,,,,,,,,,,,, -183000,4.0981064,2.51748,,,,,,,,,,,,,, -183100,4.5805106,2.4868124,,,,,,,,,,,,,, -183200,3.9131987,2.4925323,,,,,,,,,,,,,, -183300,4.049259,2.5504937,,,,,,,,,,,,,, -183400,3.7839994,2.4669993,,,,,,,,,,,,,, -183500,3.824924,2.4712036,,,,,,,,,,,,,, -183600,3.87745,2.4756649,,,,,,,,,,,,,, -183700,3.995107,2.524057,,,,,,,,,,,,,, -183800,4.241805,2.5756388,,,,,,,,,,,,,, -183883,,,0.9393534660339355,0.4564688205718994,0.7648400068283081,1.1722710132598877,50000.0,0.6425000429153442,1.794602870941162,10000.0,62771.477128744125,64934.63848924637,62771.477128744125,2150.117058753968,6.747524976730347,0.0 -183900,4.0077844,2.5225396,,,,,,,,,,,,,, -184000,3.8488522,2.494617,,,,,,,,,,,,,, -184100,4.292134,2.4912102,,,,,,,,,,,,,, -184200,3.9732015,2.5175648,,,,,,,,,,,,,, -184300,4.0144734,2.4696703,,,,,,,,,,,,,, -184400,3.907752,2.5255833,,,,,,,,,,,,,, -184500,4.0767574,2.5515046,,,,,,,,,,,,,, -184577,,,,,,,,,,,63008.11992740631,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/eval_measurements.csv deleted file mode 100644 index f099c0711..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.43062710762024,0.0,31.49718737602234,1,0,31.49718737602234,0.0012000000569969,6.910790920257568,10000,48.92790198326111,0.0011758609907701,6.911500930786133,0.0010199999669566,6.91091251373291,50000 -34.72050094604492,0.0184686183929443,541.692113161087,1492,0,541.692113161087,0.0446000024676322,5.659720420837402,10000,576.4821372032166,0.0675422474741935,5.367956638336182,0.064699999988079,5.413669109344482,50000 -52.13346600532532,0.0493288040161132,1051.8486168384552,2982,0,1051.8486168384552,0.1075000017881393,4.850800514221191,10000,1104.1325709819794,0.168207898736,4.278195858001709,0.1507599949836731,4.392276287078857,50000 -69.77436900138855,0.0782089233398437,1561.8328936100006,4471,0,1561.8328936100006,0.1806000024080276,4.246772289276123,10000,1631.8379180431366,0.2690728604793548,3.525336265563965,0.2464799880981445,3.675424337387085,50000 -87.29565978050232,0.1079814434051513,2072.0533249378204,5961,0,2072.0533249378204,0.2507000267505646,3.733696699142456,10000,2159.659814119339,0.3585578799247741,2.934255361557007,0.3324399888515472,3.0907390117645264,50000 -104.87834978103638,0.136627197265625,2582.084417819977,7451,0,2582.084417819977,0.2959000170230865,3.4668314456939697,10000,2687.353482246399,0.4449138939380646,2.4179675579071045,0.3871199786663055,2.7826380729675293,50000 -123.03875970840454,0.1637771129608154,3092.1619415283203,8941,0,3092.1619415283203,0.3477000296115875,3.115972757339477,10000,3215.6690373420715,0.5053611397743225,2.105006217956543,0.4567599892616272,2.381918430328369,50000 -140.51118230819702,0.1939032077789306,3602.4153928756714,10433,0,3602.4153928756714,0.3731000125408172,2.946731805801392,10000,3743.475313663482,0.5231783986091614,2.0281193256378174,0.4831199944019317,2.261015653610229,50000 -158.01940941810608,0.2233092784881591,4112.552248477936,11925,0,4112.552248477936,0.4001000225543976,2.7808094024658203,10000,4271.202058792114,0.5595105290412903,1.8536113500595093,0.5171599984169006,2.09015154838562,50000 -175.39147543907166,0.2514383792877197,4622.760374307632,13418,0,4622.760374307632,0.4075000286102295,2.762702703475952,10000,4798.860710382462,0.5712292790412903,1.7940986156463623,0.5276399850845337,2.01565170288086,50000 -192.6115975379944,0.2807481288909912,5132.832676887512,14911,0,5132.832676887512,0.426800012588501,2.671276807785034,10000,5326.234619617462,0.5833864808082581,1.730528473854065,0.5450999736785889,1.929821848869324,50000 -209.93391489982605,0.3114583492279053,5643.064661502838,16405,0,5643.064661502838,0.4225000143051147,2.6825528144836426,10000,5853.869685411453,0.5859175324440002,1.707316517829895,0.549780011177063,1.903155207633972,50000 -227.40043759346008,0.3407416343688965,6153.015790462494,17899,0,6153.015790462494,0.4250000119209289,2.651263236999512,10000,6381.367884874344,0.6173269748687744,1.5603954792022705,0.5484799742698669,1.9227455854415887,50000 -244.9995155334473,0.3710279464721679,6663.139073133469,19394,0,6663.139073133469,0.4438000321388244,2.586523532867432,10000,6909.171550512314,0.6105906963348389,1.588725447654724,0.5607399940490723,1.864424467086792,50000 -262.37942600250244,0.4020321369171142,7173.313796281815,20889,0,7173.313796281815,0.4298000335693359,2.619914054870605,10000,7436.808223009109,0.6055883169174194,1.5978363752365112,0.5553199648857117,1.85867977142334,50000 -279.904082775116,0.4324047565460205,7683.348705768585,22384,0,7683.348705768585,0.4462000131607055,2.553001165390014,10000,7964.449161529541,0.6145368218421936,1.5811792612075806,0.5697999596595764,1.8240339756011963,50000 -297.7186679840088,0.4656527042388916,8193.355751514435,23879,0,8193.355751514435,0.4519000351428985,2.501939296722412,10000,8492.355654716492,0.6197385191917419,1.565221071243286,0.5730199813842773,1.7761461734771729,50000 -315.1154990196228,0.4975943565368652,8703.507536411285,25375,0,8703.507536411285,0.4537000358104706,2.5031628608703613,10000,9019.985451698303,0.6237244606018066,1.5361076593399048,0.5832799673080444,1.745568037033081,50000 -332.7810888290405,0.5300009250640869,9213.550162315369,26871,0,9213.550162315369,0.4626000225543976,2.476860761642456,10000,9547.775417804718,0.6704002022743225,1.311537265777588,0.5862399935722351,1.7210910320281982,50000 -350.4288082122803,0.5664412975311279,9723.562857627869,28366,0,9723.562857627869,0.453900009393692,2.494658708572388,10000,10075.523859977722,0.6376155614852905,1.4694420099258425,0.57669997215271,1.7748409509658811,50000 -367.77251267433167,0.6059026718139648,10233.790237903597,29863,0,10233.790237903597,0.4571000337600708,2.4968318939208984,10000,10603.18566417694,0.6395886540412903,1.4568294286727903,0.586080014705658,1.7302289009094238,50000 -385.0090951919556,0.6394505500793457,10743.71028470993,31359,0,10743.71028470993,0.4750000238418579,2.414470672607422,10000,11130.426788568497,0.638671875,1.4534977674484253,0.5884400010108948,1.7131558656692505,50000 -402.6841251850128,0.6713178157806396,11253.648388624191,32855,0,11253.648388624191,0.4702000319957733,2.409991502761841,10000,11658.121633052826,0.6384924650192261,1.4499725103378296,0.5962799787521362,1.6743437051773071,50000 -420.37781167030334,0.7036874294281006,11763.811690092089,34352,0,11763.811690092089,0.4675000309944153,2.4392247200012207,10000,12186.06164097786,0.6353236436843872,1.4725189208984375,0.5895400047302246,1.7027602195739746,50000 -437.6548173427582,0.7323830127716064,12273.92913389206,35849,0,12273.92913389206,0.4771000146865845,2.394937753677368,10000,12713.53648853302,0.652762234210968,1.4022929668426514,0.5992199778556824,1.6665433645248413,50000 -455.2939562797546,0.7698986530303955,12784.103419065475,37346,0,12784.103419065475,0.4744000136852264,2.4055161476135254,10000,13241.438656330109,0.6575254797935486,1.3663238286972046,0.5941999554634094,1.6988669633865356,50000 -472.52238607406616,0.8064682483673096,13294.079423427582,38841,0,13294.079423427582,0.4828000366687774,2.4071831703186035,10000,13768.731862545012,0.6528818607330322,1.384704828262329,0.5984399914741516,1.6746433973312378,50000 -490.08971118927,0.845099687576294,13804.172444343569,40338,0,13804.172444343569,0.4669000208377838,2.4362292289733887,10000,14296.480261325836,0.6393694281578064,1.4571224451065063,0.5904799699783325,1.712005615234375,50000 -507.6242277622223,0.8813223838806152,14314.146545886992,41835,0,14314.146545886992,0.4811000227928161,2.351424217224121,10000,14824.075912237167,0.6407644748687744,1.4308786392211914,0.5977599620819092,1.671799659729004,50000 -524.9377455711365,0.915226936340332,14824.182245731354,43332,0,14824.182245731354,0.4744000136852264,2.401122093200684,10000,15351.50887298584,0.6420599222183228,1.4258434772491455,0.5988199710845947,1.6653971672058103,50000 -542.2236630916595,0.9501383304595948,15334.278590202332,44829,0,15334.278590202332,0.4788000285625458,2.33974552154541,10000,15878.976461172104,0.6495137214660645,1.4016178846359253,0.6055200099945068,1.6298383474349976,50000 -559.746725320816,0.9882421493530272,15844.478893518448,46327,0,15844.478893518448,0.4900000095367431,2.314314126968384,10000,16406.788619995117,0.6932198405265808,1.2089905738830566,0.6118999719619751,1.6064671277999878,50000 -577.9383962154388,1.026573657989502,16354.602724313736,47825,0,16354.602724313736,0.4797000288963318,2.354118585586548,10000,16935.195270061493,0.6660555005073547,1.3233352899551392,0.604699969291687,1.6360735893249512,50000 -595.4793081283569,1.063713788986206,16864.671887874603,49323,0,16864.671887874603,0.4821000099182129,2.3614590167999268,10000,17462.89333844185,0.6553133130073547,1.379987716674805,0.6030399799346924,1.6445624828338623,50000 -612.9747793674469,1.0999596118927002,17374.922873973846,50822,0,17374.922873973846,0.4883000254631042,2.343726396560669,10000,17990.727430582047,0.6630061864852905,1.3523014783859253,0.6115399599075317,1.6137793064117432,50000 -630.2206964492798,1.1384408473968506,17885.14226746559,52320,0,17885.14226746559,0.4897000193595886,2.339770793914795,10000,18518.28278398513,0.6526426672935486,1.3899118900299072,0.6061800122261047,1.6280734539031982,50000 -647.4825391769409,1.1758079528808594,18395.222403526303,53818,0,18395.222403526303,0.4889000356197357,2.3146893978118896,10000,19045.71305155754,0.6602359414100647,1.361589789390564,0.613379955291748,1.606788992881775,50000 -664.8474590778351,1.2139925956726074,18905.270992279053,55316,0,18905.270992279053,0.4910000264644623,2.2842953205108643,10000,19573.216374635696,0.699238657951355,1.1795507669448853,0.6133399605751038,1.5974823236465454,50000 -682.3041090965271,1.252126932144165,19415.29590034485,56814,0,19415.29590034485,0.4950000345706939,2.303964853286743,10000,20100.7879998684,0.685965359210968,1.2393864393234253,0.6195799708366394,1.5674865245819092,50000 -699.7403359413147,1.298933506011963,19925.24132657051,58312,0,19925.24132657051,0.484000027179718,2.3704142570495605,10000,20628.266762018204,0.6573660373687744,1.369436264038086,0.6055799722671509,1.6424763202667236,50000 -717.0170676708221,1.3389678001403809,20435.426945209503,59811,0,20435.426945209503,0.4971000254154205,2.3210856914520264,10000,21155.81867671013,0.6717155575752258,1.299277424812317,0.6213600039482117,1.5658512115478516,50000 -734.4156177043915,1.3787786960601809,20945.637128591537,61310,0,20945.637128591537,0.4895000159740448,2.3405680656433105,10000,21683.5158367157,0.6558912396430969,1.3650057315826416,0.6165199875831604,1.608451247215271,50000 -751.8346419334412,1.415900468826294,21455.61483645439,62808,0,21455.61483645439,0.4837000370025635,2.364459991455078,10000,22211.000241994858,0.6518056392669678,1.3800477981567385,0.6077399849891663,1.6119192838668823,50000 -769.2622895240784,1.4578070640563965,21965.65988755226,64306,0,21965.65988755226,0.4803000092506408,2.3926892280578613,10000,22738.56585907936,0.6535594463348389,1.3952399492263794,0.6021400094032288,1.6519254446029663,50000 -786.6829445362091,1.5008704662322998,22475.830349206924,65805,0,22475.830349206924,0.4997000098228454,2.25830078125,10000,23266.250294208527,0.6921834945678711,1.207922101020813,0.6225999593734741,1.554673671722412,50000 -803.9027199745178,1.547590732574463,22985.990831136703,67304,0,22985.990831136703,0.5111000537872314,2.2321712970733643,10000,23793.728055477142,0.6910474896430969,1.2102947235107422,0.6284999847412109,1.5212945938110352,50000 -821.4244170188904,1.585910081863403,23496.12866091728,68803,0,23496.12866091728,0.4935000240802765,2.3209471702575684,10000,24321.4762878418,0.6715561151504517,1.3061954975128174,0.614579975605011,1.594861626625061,50000 -839.0864970684052,1.6295363903045654,24006.130012512207,70301,0,24006.130012512207,0.5094000101089478,2.2113044261932373,10000,24849.235661029816,0.6846699714660645,1.244409203529358,0.6282599568367004,1.5239821672439575,50000 -856.4387171268463,1.6738028526306152,24516.21312022209,71800,0,24516.21312022209,0.510200023651123,2.217399835586548,10000,25376.767111063004,0.6765784025192261,1.2757912874221802,0.6304799914360046,1.5133068561553955,50000 -873.810319185257,1.7138514518737793,25026.155697584152,73298,0,25026.155697584152,0.5108000040054321,2.190986156463623,10000,25904.17239308357,0.6873405575752258,1.2287601232528689,0.6338199973106384,1.491397738456726,50000 -891.2437620162964,1.7546777725219729,25536.339210748672,74797,0,25536.339210748672,0.5003000497817993,2.303654909133911,10000,26431.880017757416,0.6872010231018066,1.2051925659179688,0.6161800026893616,1.587920069694519,50000 -908.53577709198,1.7942428588867188,26046.54245686531,76296,0,26046.54245686531,0.4935000240802765,2.3314461708068848,10000,26959.465401887894,0.6835737824440002,1.253710389137268,0.6217799782752991,1.5734446048736572,50000 -926.8507552146912,1.83475923538208,26556.63892173767,77783,0,26556.63892173767,0.5133000016212463,2.162278652191162,10000,27487.96911430359,0.6978435516357422,1.2021422386169434,0.6412799954414368,1.4752492904663086,50000 -944.3251445293428,1.8716328144073489,27066.782158851624,79282,0,27066.782158851624,0.5126000046730042,2.2020113468170166,10000,28015.67524456978,0.6960698366165161,1.1828243732452393,0.6363999843597412,1.4793654680252075,50000 -961.7491703033448,1.9183545112609863,27576.946058273315,80781,0,27576.946058273315,0.5117000341415405,2.207427263259888,10000,28543.36324763298,0.6879384517669678,1.2187047004699707,0.6388999819755554,1.481725573539734,50000 -979.0919954776764,1.960249662399292,28087.02109003067,82280,0,28087.02109003067,0.5145000219345093,2.178096532821656,10000,29070.87357234955,0.6951131820678711,1.184423327445984,0.6406799554824829,1.4681005477905271,50000 -996.4454569816588,2.001596212387085,28596.984345436096,83779,0,28596.984345436096,0.5115000009536743,2.1780617237091064,10000,29598.28356528282,0.7388990521430969,1.0035078525543213,0.6438999772071838,1.44945228099823,50000 -1014.3660583496094,2.0440704822540283,29107.010613918304,85277,0,29107.010613918304,0.522599995136261,2.130420923233032,10000,30126.32354569435,0.7179328799247742,1.081685185432434,0.6489399671554565,1.4340031147003174,50000 -1031.8577575683594,2.080768346786499,29617.09086871147,86776,0,29617.09086871147,0.5206000208854675,2.138397455215454,10000,30653.98149752617,0.7063137888908386,1.1358715295791626,0.6468200087547302,1.4428404569625854,50000 -1049.4971315860748,2.121476888656616,30127.182915449142,88275,0,30127.182915449142,0.5166000127792358,2.154176950454712,10000,31181.80454421044,0.7068120241165161,1.13710355758667,0.6526600122451782,1.4295099973678589,50000 -1066.8635828495026,3.207822322845459,30636.21947956085,89771,0,30636.21947956085,0.5118000507354736,2.200810432434082,10000,31709.34457540512,0.6945750713348389,1.1966772079467771,0.6418799757957458,1.463532567024231,50000 -1084.156126499176,3.255371332168579,31146.34582209587,91270,0,31146.34582209587,0.5267000198364258,2.1272940635681152,10000,32236.86432123184,0.7092235088348389,1.134615778923035,0.6538000106811523,1.4101835489273071,50000 -1101.4758217334747,3.3038907051086426,31656.25477242469,92768,0,31656.25477242469,0.5174000263214111,2.1670353412628174,10000,32764.192311525345,0.7113958597183228,1.1191548109054563,0.6482200026512146,1.4321719408035278,50000 -1118.9492797851562,3.351661443710327,32166.24189376831,94267,0,32166.24189376831,0.5303000211715698,2.0873241424560547,10000,33291.75237035751,0.7361487150192261,1.0138471126556396,0.659500002861023,1.3887938261032104,50000 -1136.129715681076,3.401322364807129,32676.229960918427,95766,0,32676.229960918427,0.5320000052452087,2.1097888946533203,10000,33819.02078318596,0.7229153513908386,1.0641008615493774,0.6566199660301208,1.4057345390319824,50000 -1153.3248386383057,3.446504831314087,33186.39605593681,97265,0,33186.39605593681,0.5344000458717346,2.089726209640503,10000,34346.4784386158,0.7237324714660645,1.0708374977111816,0.6580399870872498,1.3982630968093872,50000 -1170.7596073150637,3.4904332160949707,33696.623683452606,98765,0,33696.623683452606,0.536300003528595,2.0971498489379883,10000,34874.23455262184,0.7150430083274841,1.1016902923583984,0.6593799591064453,1.3938559293746948,50000 -1188.2499401569366,3.535299777984619,34206.6414039135,100264,0,34206.6414039135,0.5412000417709351,2.056500196456909,10000,35401.838272333145,0.713309109210968,1.1122252941131592,0.6600399613380432,1.3887194395065308,50000 -1205.2929441928864,3.586189031600952,34716.84902739525,101764,0,34716.84902739525,0.5402000546455383,2.057695150375366,10000,35929.189604759216,0.7233139276504517,1.072434902191162,0.6643799543380737,1.362235426902771,50000 -1222.604502916336,3.63908052444458,35226.76603245735,103263,0,35226.76603245735,0.5405000448226929,2.041477918624878,10000,36456.52231359482,0.75394606590271,0.9281212091445924,0.6668999791145325,1.3567508459091189,50000 -1240.0294904708862,3.6891348361968994,35736.759813547134,104762,0,35736.759813547134,0.5410000085830688,2.045719623565674,10000,36984.04020619392,0.7365872263908386,1.0082699060440063,0.6635199785232544,1.3672338724136353,50000 -1257.5226662158966,3.7353804111480713,36246.74147129059,106261,0,36246.74147129059,0.5432000160217285,2.066902637481689,10000,37511.61378097534,0.7372249364852905,0.999062955379486,0.6670599579811096,1.3503457307815552,50000 -1274.914494514465,3.7825989723205566,36756.6770863533,107760,0,36756.6770863533,0.5291000008583069,2.1044600009918213,10000,38039.03833150864,0.7306082248687744,1.0350559949874878,0.6622999906539917,1.3632971048355105,50000 -1292.4509418010712,3.830709218978882,37266.8054394722,109259,0,37266.8054394722,0.549500048160553,2.004624605178833,10000,38566.80239248276,0.7350525856018066,1.0115010738372805,0.6735000014305115,1.3214294910430908,50000 -1309.994685173035,3.881550550460816,37776.72888278961,110758,0,37776.72888278961,0.5402000546455383,2.009732484817505,10000,39094.371799230576,0.7382214665412903,0.9993380308151244,0.6735000014305115,1.3205397129058838,50000 -1327.4420084953308,3.929419755935669,38286.65507602692,112257,0,38286.65507602692,0.5503000020980835,1.996065974235535,10000,39621.84500050545,0.7806122303009033,0.819338858127594,0.6816200017929077,1.2890325784683228,50000 -1345.113474369049,3.976061820983887,38796.73586535454,113756,0,38796.73586535454,0.5544000267982483,1.976928472518921,10000,40149.694214344025,0.7663623690605164,0.8877369165420532,0.6804400086402893,1.2903692722320557,50000 -1362.6518981456757,4.028971195220947,39306.689522743225,115255,0,39306.689522743225,0.5565000176429749,1.98634934425354,10000,40677.29035973549,0.7538464665412903,0.9340843558311462,0.6775199770927429,1.3047815561294556,50000 -1379.9599359035492,4.075598001480103,39816.85724949837,116754,0,39816.85724949837,0.5436000227928162,2.046970129013061,10000,41204.86368608475,0.7443000674247742,0.957797348499298,0.6729399561882019,1.3240751028060913,50000 -1397.556854724884,4.121983051300049,40326.82902789116,118253,0,40326.82902789116,0.5649000406265259,1.9236938953399656,10000,41732.53103065491,0.7626155614852905,0.9064222574234008,0.6890999674797058,1.2482517957687378,50000 -1414.801174402237,4.171825647354126,40836.754336595535,119752,0,40836.754336595535,0.5614000558853149,1.951021194458008,10000,42259.80155444145,0.7578921914100647,0.9132976531982422,0.6890599727630615,1.2581051588058472,50000 -1432.4339890480042,4.225170850753784,41346.87264943123,121251,0,41346.87264943123,0.5688000321388245,1.932717442512512,10000,42787.65633249283,0.7889827489852905,0.7927571535110474,0.6894400119781494,1.2526060342788696,50000 -1449.9362061023712,4.2682952880859375,41857.08833384514,122751,0,41857.08833384514,0.5613000392913818,1.9879837036132808,10000,43315.46732521057,0.778738796710968,0.8235031962394714,0.6881600022315979,1.264472484588623,50000 -1467.9139926433563,4.320462465286255,42367.10335683823,124250,0,42367.10335683823,0.5701000094413757,1.924551010131836,10000,43843.5637190342,0.782246470451355,0.8037508726119995,0.6977399587631226,1.2117669582366943,50000 -1485.1218111515043,4.3753581047058105,42877.30141162872,125749,0,42877.30141162872,0.5735000371932983,1.909732699394226,10000,44371.0747590065,0.7780413031578064,0.8233484625816345,0.6994400024414062,1.2192716598510742,50000 -1502.3710179328918,4.427471876144409,43387.38800287247,127248,0,43387.38800287247,0.5724000334739685,1.918750882148743,10000,44898.51348400116,0.7848173975944519,0.798477292060852,0.6995799541473389,1.2017760276794434,50000 -1519.7068076133728,4.479499578475952,43897.49772691727,128747,0,43897.49772691727,0.5781000256538391,1.8842612504959104,10000,45426.061324596405,0.7816087007522583,0.8127210736274719,0.7039200067520142,1.1963602304458618,50000 -1536.9744882583618,4.538180828094482,44407.48382616043,130246,0,44407.48382616043,0.5733000040054321,1.893583297729492,10000,45953.42463064194,0.7840800285339355,0.8188005685806274,0.7031799554824829,1.1907312870025637,50000 -1554.476620197296,4.591028213500977,44917.434248924255,131745,0,44917.434248924255,0.5805000066757202,1.8970988988876345,10000,46480.9824347496,0.8127192258834839,0.6832287907600403,0.7074599862098694,1.1807633638381958,50000 -1571.6903052330017,4.63919734954834,45427.39062857628,133243,0,45427.39062857628,0.5821000337600708,1.8820722103118896,10000,47008.251879930496,0.79984450340271,0.7381070852279663,0.7058199644088745,1.1831684112548828,50000 -1589.1792786121368,4.692151308059692,45937.45642733574,134742,0,45937.45642733574,0.5893000364303589,1.825054168701172,10000,47535.90947675705,0.8057836294174194,0.7143407464027405,0.7128799557685852,1.146412968635559,50000 -1606.857929468155,4.741678476333618,46447.50892996788,136241,0,46447.50892996788,0.5920000076293945,1.826322317123413,10000,48063.74123668671,0.8033322691917419,0.7214902639389038,0.7112399935722351,1.1542009115219116,50000 -1624.2676224708557,4.790990352630615,46957.61950492859,137740,0,46957.61950492859,0.5933000445365906,1.8347601890563965,10000,48591.36189293861,0.805683970451355,0.7093267440795898,0.7160399556159973,1.1419600248336792,50000 -1641.4246740341189,4.841644525527954,47467.85808753967,139238,0,47467.85808753967,0.5916000008583069,1.830636978149414,10000,49118.85888767242,0.8067402839660645,0.7009314298629761,0.7160199880599976,1.1319137811660769,50000 -1658.625111579895,4.895170450210571,47978.06637239456,140737,0,47978.06637239456,0.5914000272750854,1.810274839401245,10000,49646.37145638466,0.8413185477256775,0.5797902345657349,0.7160599827766418,1.138371467590332,50000 -1676.1800088882446,4.949621677398682,48488.23559617996,142236,0,48488.23559617996,0.5998000502586365,1.7817325592041016,10000,50174.20073246956,0.8332070708274841,0.6023600101470947,0.7219199538230896,1.1166437864303589,50000 -1693.400636434555,5.0050599575042725,48998.14870905876,143735,0,48998.14870905876,0.6032000184059143,1.785300612449646,10000,50701.440212488174,0.8356783986091614,0.5958223938941956,0.7263199687004089,1.1001946926116943,50000 -1710.852013349533,5.087496757507324,49508.11609601975,145233,0,49508.11609601975,0.6122000217437744,1.7577836513519287,10000,51228.993624448776,0.8364556431770325,0.5910767316818237,0.7298399806022644,1.085693120956421,50000 -1728.0061275959015,5.139847993850708,50018.29201626778,146732,0,50018.29201626778,0.6041000485420227,1.7766298055648804,10000,51756.42676591873,0.83402419090271,0.6022161841392517,0.729919970035553,1.0840203762054443,50000 -1745.2415256500244,5.190353631973267,50528.38182926178,148231,0,50528.38182926178,0.6070000529289246,1.7447444200515747,10000,52283.85293364525,0.8361965417861938,0.5804234743118286,0.7318599820137024,1.0679874420166016,50000 -1762.763778924942,5.244019508361816,51038.4784321785,149730,0,51038.4784321785,0.6065000295639038,1.750016450881958,10000,52811.57619476318,0.8720304369926453,0.4616715610027313,0.7359199523925781,1.06617271900177,50000 -1780.0837841033936,5.296934366226196,51548.38336634636,151228,0,51548.38336634636,0.6137000322341919,1.7262399196624756,10000,53338.903485774994,0.8650151491165161,0.4839383959770202,0.7351799607276917,1.0641911029815674,50000 -1797.5025045871737,5.351749420166016,52058.31352877617,152726,0,52058.31352877617,0.6089000105857849,1.7674813270568848,10000,53866.35694384575,0.8591955900192261,0.492692083120346,0.7340599894523621,1.070358157157898,50000 -1814.846135139465,5.408911228179932,52568.48841428757,154225,0,52568.48841428757,0.6195000410079956,1.7304673194885254,10000,54393.9851975441,0.8656927347183228,0.478325217962265,0.7418199777603149,1.0443856716156006,50000 -1832.3190059661863,5.4648637771606445,53078.62742900848,155724,0,53078.62742900848,0.6204000115394592,1.7236957550048828,10000,54921.7041721344,0.8673469424247742,0.4798938035964966,0.7419399619102478,1.0362374782562256,50000 -1849.9641468524933,5.523918867111206,53588.541823387146,157222,0,53588.541823387146,0.6229000091552734,1.7110683917999268,10000,55449.373153209686,0.8683633208274841,0.4606120586395263,0.7424399852752686,1.0408451557159424,50000 -1867.2719218730929,5.576862573623657,54098.47330284119,158721,0,54098.47330284119,0.6211000084877014,1.7047979831695557,10000,55976.7153468132,0.8824138641357422,0.4179320037364959,0.7461999654769897,1.0250033140182495,50000 -1884.594167470932,5.631555557250977,54608.64770102501,160220,0,54608.64770102501,0.6242000460624695,1.7131024599075315,10000,56504.31801342964,0.8951291441917419,0.3667570948600769,0.7464399933815002,1.0255107879638672,50000 -1902.1840479373927,5.686068296432495,55118.76144886017,161719,0,55118.76144886017,0.6291000247001648,1.6922518014907837,10000,57032.12668228149,0.8964046239852905,0.3655628263950348,0.7499399781227112,1.0128942728042605,50000 -1919.415715456009,5.746610403060913,55628.83065366745,163218,0,55628.83065366745,0.6283000111579895,1.6882383823394775,10000,57559.539657115936,0.8958067297935486,0.3668033480644226,0.7515599727630615,1.0076985359191897,50000 -1937.3459751605988,5.80523681640625,56138.98541164398,164717,0,56138.98541164398,0.6323000192642212,1.692557692527771,10000,58087.73439216614,0.89652419090271,0.3590350449085235,0.7538999915122986,0.995878040790558,50000 -1954.7560713291168,5.8507981300354,56649.082023859024,166216,0,56649.082023859024,0.6308000087738037,1.68972647190094,10000,58615.33793616295,0.9032605290412904,0.3413633108139038,0.7542799711227417,0.9981317520141602,50000 -1972.07869887352,5.907650232315064,57159.2556912899,167715,0,57159.2556912899,0.6345000267028809,1.6859267950057983,10000,59142.94081425667,0.9057517647743224,0.3304113447666168,0.755620002746582,0.9929838180541992,50000 -1989.2317397594447,5.962466239929199,57669.26935172081,169213,0,57669.26935172081,0.6319000124931335,1.6728354692459106,10000,59670.21319055557,0.920340359210968,0.2846298217773437,0.7571199536323547,0.9897215962409972,50000 -2006.736686944961,6.018509387969971,58179.45411133766,170712,0,58179.45411133766,0.6370000243186951,1.660903811454773,10000,60198.01121592522,0.920719027519226,0.2819788455963135,0.7594199776649475,0.9810996651649476,50000 -2024.2342946529388,6.075444459915161,58689.38737034798,172210,0,58689.38737034798,0.6363000273704529,1.666724443435669,10000,60725.54909420013,0.9206393361091614,0.2826157808303833,0.7601400017738342,0.9788408875465392,50000 -2041.5964758396149,6.139222145080566,59199.27856183052,173708,0,59199.27856183052,0.6378000378608704,1.6611794233322144,10000,61252.9173913002,0.9241868257522584,0.267222911119461,0.7613199949264526,0.9772453904151917,50000 -2058.824327230453,6.205310344696045,59709.26147675514,175206,0,59709.26147675514,0.6360000371932983,1.6560876369476318,10000,61780.24515080452,0.9243263602256776,0.2708885669708252,0.7618199586868286,0.9700082540512084,50000 -2076.4793939590454,6.263585090637207,60219.2170612812,176704,0,60219.2170612812,0.6390000581741333,1.6528555154800415,10000,62307.96381902695,0.9294283986091614,0.2561113834381103,0.7628200054168701,0.9697470664978028,50000 -2093.848112821579,6.324261903762817,60729.36058998108,178203,0,60729.36058998108,0.6402000188827515,1.652724266052246,10000,62835.58871150017,0.9300462007522584,0.2525694072246551,0.763759970664978,0.9669691920280457,50000 -2111.0798873901367,6.385619401931763,61239.35639166832,179701,0,61239.35639166832,0.64410001039505,1.6483752727508545,10000,63362.927340745926,0.9334741234779358,0.2409894466400146,0.7638599872589111,0.9650241136550904,50000 -2128.31769990921,6.443438529968262,61749.25453186035,181199,0,61749.25453186035,0.6421000361442566,1.6437848806381226,10000,63890.17459869385,0.933254897594452,0.2420443147420883,0.7646999955177307,0.9629248380661012,50000 -2145.533955574036,6.500751495361328,62259.378823280334,182698,0,62259.378823280334,0.6443000435829163,1.6459656953811646,10000,64417.62543559074,0.9331353306770324,0.2427098006010055,0.7647599577903748,0.9616653919219972,50000 -2162.7728395462036,6.56220269203186,62769.5444791317,184197,0,62769.5444791317,0.6438000202178955,1.6469836235046387,10000,64945.14319252968,0.9338129758834839,0.24053217470645905,0.7651599645614624,0.9626170992851257,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/measurements.csv deleted file mode 100644 index 09fcf8542..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1975 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6531614,6.929166,,,,,,,,,,,,,, -1,,,0.0011758609907701,6.911500930786133,0.0010199999669566,6.91091251373291,50000.0,0.0012000000569969,6.910790920257568,10000.0,31.49718737602234,48.92790198326111,31.49718737602234,17.43062710762024,0.0,0.0 -100,0.64127946,6.8972306,,,,,,,,,,,,,, -200,0.67995983,6.859826,,,,,,,,,,,,,, -300,0.6973275,6.773916,,,,,,,,,,,,,, -400,0.75662786,6.669539,,,,,,,,,,,,,, -500,0.7913645,6.555879,,,,,,,,,,,,,, -600,0.80036825,6.460474,,,,,,,,,,,,,, -700,0.85710937,6.348482,,,,,,,,,,,,,, -800,0.9600871,6.19419,,,,,,,,,,,,,, -900,1.702767,6.090267,,,,,,,,,,,,,, -1000,1.5950238,5.9900117,,,,,,,,,,,,,, -1100,2.3983133,5.880382,,,,,,,,,,,,,, -1200,1.7885987,5.762966,,,,,,,,,,,,,, -1300,2.2370322,5.700748,,,,,,,,,,,,,, -1400,3.0854425,5.633295,,,,,,,,,,,,,, -1492,,,0.0675422474741935,5.367956638336182,0.064699999988079,5.413669109344482,50000.0,0.0446000024676322,5.659720420837402,10000.0,541.692113161087,576.4821372032166,541.692113161087,34.72050094604492,0.0184686183929443,0.0 -1500,2.9605603,5.5207047,,,,,,,,,,,,,, -1600,3.8266504,5.328137,,,,,,,,,,,,,, -1700,3.3383522,5.289833,,,,,,,,,,,,,, -1800,3.2079284,5.2990456,,,,,,,,,,,,,, -1900,3.6149292,5.1287794,,,,,,,,,,,,,, -2000,6.255208,5.2383957,,,,,,,,,,,,,, -2100,4.6587334,5.093015,,,,,,,,,,,,,, -2200,5.846603,5.0722017,,,,,,,,,,,,,, -2300,5.2656555,4.9687986,,,,,,,,,,,,,, -2400,6.7763786,4.9486737,,,,,,,,,,,,,, -2500,6.9231563,4.8479853,,,,,,,,,,,,,, -2600,6.1550927,4.800861,,,,,,,,,,,,,, -2700,6.4776173,4.7461677,,,,,,,,,,,,,, -2800,4.6527834,4.6593575,,,,,,,,,,,,,, -2900,4.80045,4.700909,,,,,,,,,,,,,, -2982,,,0.168207898736,4.278195858001709,0.1507599949836731,4.392276287078857,50000.0,0.1075000017881393,4.850800514221191,10000.0,1051.8486168384552,1104.1325709819794,1051.8486168384552,52.13346600532532,0.0493288040161132,0.0 -3000,4.697735,4.5768075,,,,,,,,,,,,,, -3100,6.2394614,4.4422846,,,,,,,,,,,,,, -3200,6.4606586,4.567251,,,,,,,,,,,,,, -3300,8.499981,4.3946333,,,,,,,,,,,,,, -3400,5.9171486,4.34876,,,,,,,,,,,,,, -3500,4.1579843,4.3388305,,,,,,,,,,,,,, -3600,7.482451,4.3388104,,,,,,,,,,,,,, -3700,6.746224,4.1861596,,,,,,,,,,,,,, -3800,6.489109,4.1602917,,,,,,,,,,,,,, -3900,9.935474,4.1545515,,,,,,,,,,,,,, -4000,4.7999787,4.2610817,,,,,,,,,,,,,, -4100,9.969456,4.0476265,,,,,,,,,,,,,, -4200,4.496143,4.091651,,,,,,,,,,,,,, -4300,5.526952,3.9434683,,,,,,,,,,,,,, -4400,11.474292,3.8425164,,,,,,,,,,,,,, -4471,,,0.2690728604793548,3.525336265563965,0.2464799880981445,3.675424337387085,50000.0,0.1806000024080276,4.246772289276123,10000.0,1561.8328936100006,1631.8379180431366,1561.8328936100006,69.77436900138855,0.0782089233398437,0.0 -4500,5.9880366,3.9365547,,,,,,,,,,,,,, -4600,8.631846,3.840261,,,,,,,,,,,,,, -4700,6.0709853,3.7075672,,,,,,,,,,,,,, -4800,5.9327164,3.9591393,,,,,,,,,,,,,, -4900,6.4583464,3.682198,,,,,,,,,,,,,, -5000,6.2863307,3.695685,,,,,,,,,,,,,, -5100,9.771049,3.5586824,,,,,,,,,,,,,, -5200,7.7305174,3.5463417,,,,,,,,,,,,,, -5300,4.490062,3.508177,,,,,,,,,,,,,, -5400,5.375097,3.4844337,,,,,,,,,,,,,, -5500,8.266056,3.4681811,,,,,,,,,,,,,, -5600,11.157162,3.604453,,,,,,,,,,,,,, -5700,11.814644,3.3920262,,,,,,,,,,,,,, -5800,6.9877295,3.3984222,,,,,,,,,,,,,, -5900,7.4173045,3.5087156,,,,,,,,,,,,,, -5961,,,0.3585578799247741,2.934255361557007,0.3324399888515472,3.0907390117645264,50000.0,0.2507000267505646,3.733696699142456,10000.0,2072.0533249378204,2159.659814119339,2072.0533249378204,87.29565978050232,0.1079814434051513,0.0 -6000,7.4860864,3.399908,,,,,,,,,,,,,, -6100,4.6676373,3.356272,,,,,,,,,,,,,, -6200,13.588626,3.317605,,,,,,,,,,,,,, -6300,8.585403,3.2212315,,,,,,,,,,,,,, -6400,11.541688,3.3259935,,,,,,,,,,,,,, -6500,5.545324,3.1652303,,,,,,,,,,,,,, -6600,7.616776,3.2708442,,,,,,,,,,,,,, -6700,9.47938,3.2098808,,,,,,,,,,,,,, -6800,9.855965,3.05862,,,,,,,,,,,,,, -6900,7.515577,3.1307347,,,,,,,,,,,,,, -7000,7.136223,3.0476701,,,,,,,,,,,,,, -7100,8.6003685,3.085645,,,,,,,,,,,,,, -7200,6.8550205,3.0855052,,,,,,,,,,,,,, -7300,7.9913354,3.0711844,,,,,,,,,,,,,, -7400,3.3652184,2.9838924,,,,,,,,,,,,,, -7451,,,0.4449138939380646,2.4179675579071045,0.3871199786663055,2.7826380729675293,50000.0,0.2959000170230865,3.4668314456939697,10000.0,2582.084417819977,2687.353482246399,2582.084417819977,104.87834978103638,0.136627197265625,0.0 -7500,6.520137,2.9913843,,,,,,,,,,,,,, -7600,13.615356,3.0907705,,,,,,,,,,,,,, -7700,6.765888,3.08773,,,,,,,,,,,,,, -7800,7.950795,2.826744,,,,,,,,,,,,,, -7900,8.027349,2.9943368,,,,,,,,,,,,,, -8000,8.891745,2.9182944,,,,,,,,,,,,,, -8100,5.710428,2.8285036,,,,,,,,,,,,,, -8200,6.6983647,2.8900704,,,,,,,,,,,,,, -8300,6.6520042,2.8385744,,,,,,,,,,,,,, -8400,5.772019,2.7856686,,,,,,,,,,,,,, -8500,7.062276,2.9168587,,,,,,,,,,,,,, -8600,8.909641,2.8210745,,,,,,,,,,,,,, -8700,7.4092817,2.6941473,,,,,,,,,,,,,, -8800,8.341493,2.7868445,,,,,,,,,,,,,, -8900,5.922514,2.730505,,,,,,,,,,,,,, -8941,,,0.5053611397743225,2.105006217956543,0.4567599892616272,2.381918430328369,50000.0,0.3477000296115875,3.115972757339477,10000.0,3092.1619415283203,3215.6690373420715,3092.1619415283203,123.03875970840454,0.1637771129608154,0.0 -9000,6.155279,2.7679315,,,,,,,,,,,,,, -9100,5.7961364,2.6972327,,,,,,,,,,,,,, -9200,4.6044283,2.7195106,,,,,,,,,,,,,, -9300,7.4437723,2.7167385,,,,,,,,,,,,,, -9400,7.6553483,2.5625792,,,,,,,,,,,,,, -9500,6.2698183,2.738227,,,,,,,,,,,,,, -9600,8.230375,2.654365,,,,,,,,,,,,,, -9700,4.631346,2.5776136,,,,,,,,,,,,,, -9800,6.5033784,2.6872594,,,,,,,,,,,,,, -9900,8.103903,2.6765494,,,,,,,,,,,,,, -10000,7.3611484,2.6772711,,,,,,,,,,,,,, -10100,6.7022533,2.602409,,,,,,,,,,,,,, -10200,5.763085,2.5564766,,,,,,,,,,,,,, -10300,6.020455,2.487839,,,,,,,,,,,,,, -10400,6.7638206,2.6335158,,,,,,,,,,,,,, -10433,,,0.5231783986091614,2.0281193256378174,0.4831199944019317,2.261015653610229,50000.0,0.3731000125408172,2.946731805801392,10000.0,3602.4153928756714,3743.475313663482,3602.4153928756714,140.51118230819702,0.1939032077789306,0.0 -10500,4.6675386,2.4953754,,,,,,,,,,,,,, -10600,7.2573156,2.4306211,,,,,,,,,,,,,, -10700,6.336898,2.4725971,,,,,,,,,,,,,, -10800,5.6192045,2.5465093,,,,,,,,,,,,,, -10900,7.062539,2.3520205,,,,,,,,,,,,,, -11000,5.1018,2.360961,,,,,,,,,,,,,, -11100,10.43212,2.5848098,,,,,,,,,,,,,, -11200,8.329623,2.5434659,,,,,,,,,,,,,, -11300,5.0050592,2.4493046,,,,,,,,,,,,,, -11400,7.830518,2.4265924,,,,,,,,,,,,,, -11500,11.222109,2.4745216,,,,,,,,,,,,,, -11600,7.38528,2.4085095,,,,,,,,,,,,,, -11700,5.1459103,2.371295,,,,,,,,,,,,,, -11800,11.736947,2.4461966,,,,,,,,,,,,,, -11900,6.8512745,2.4037137,,,,,,,,,,,,,, -11925,,,0.5595105290412903,1.8536113500595093,0.5171599984169006,2.09015154838562,50000.0,0.4001000225543976,2.7808094024658203,10000.0,4112.552248477936,4271.202058792114,4112.552248477936,158.01940941810608,0.2233092784881591,0.0 -12000,6.8984337,2.3879495,,,,,,,,,,,,,, -12100,5.8407483,2.4170995,,,,,,,,,,,,,, -12200,6.483694,2.4201932,,,,,,,,,,,,,, -12300,4.95639,2.295721,,,,,,,,,,,,,, -12400,6.790916,2.3767836,,,,,,,,,,,,,, -12500,5.9154096,2.416644,,,,,,,,,,,,,, -12600,5.845142,2.437266,,,,,,,,,,,,,, -12700,3.6055448,2.3032928,,,,,,,,,,,,,, -12800,9.087492,2.3664637,,,,,,,,,,,,,, -12900,7.61741,2.3940616,,,,,,,,,,,,,, -13000,4.235059,2.331127,,,,,,,,,,,,,, -13100,6.1194263,2.348337,,,,,,,,,,,,,, -13200,5.2365055,2.3541646,,,,,,,,,,,,,, -13300,6.9354115,2.3368554,,,,,,,,,,,,,, -13400,5.910556,2.4127278,,,,,,,,,,,,,, -13418,,,0.5712292790412903,1.7940986156463623,0.5276399850845337,2.01565170288086,50000.0,0.4075000286102295,2.762702703475952,10000.0,4622.760374307632,4798.860710382462,4622.760374307632,175.39147543907166,0.2514383792877197,0.0 -13500,8.474977,2.3957894,,,,,,,,,,,,,, -13600,5.4218717,2.2833834,,,,,,,,,,,,,, -13700,5.5136266,2.1732209,,,,,,,,,,,,,, -13800,5.5255804,2.3549473,,,,,,,,,,,,,, -13900,5.965481,2.211281,,,,,,,,,,,,,, -14000,4.606673,2.3964663,,,,,,,,,,,,,, -14100,5.4186625,2.2970085,,,,,,,,,,,,,, -14200,8.955366,2.194339,,,,,,,,,,,,,, -14300,5.039338,2.3108325,,,,,,,,,,,,,, -14400,5.623256,2.305035,,,,,,,,,,,,,, -14500,5.3580256,2.29785,,,,,,,,,,,,,, -14600,5.2721224,2.181643,,,,,,,,,,,,,, -14700,5.895338,2.3704743,,,,,,,,,,,,,, -14800,4.5317497,2.2366052,,,,,,,,,,,,,, -14900,5.206332,2.1305764,,,,,,,,,,,,,, -14911,,,0.5833864808082581,1.730528473854065,0.5450999736785889,1.929821848869324,50000.0,0.426800012588501,2.671276807785034,10000.0,5132.832676887512,5326.234619617462,5132.832676887512,192.6115975379944,0.2807481288909912,0.0 -15000,7.583055,2.1161742,,,,,,,,,,,,,, -15100,5.098239,2.3294427,,,,,,,,,,,,,, -15200,6.5731764,2.2283764,,,,,,,,,,,,,, -15300,4.848228,2.184328,,,,,,,,,,,,,, -15400,4.8119845,2.2718067,,,,,,,,,,,,,, -15500,7.0500445,2.2070384,,,,,,,,,,,,,, -15600,6.4424367,2.369598,,,,,,,,,,,,,, -15700,8.362508,2.3462422,,,,,,,,,,,,,, -15800,5.097404,2.212739,,,,,,,,,,,,,, -15900,5.440891,2.205329,,,,,,,,,,,,,, -16000,6.6394453,2.2935658,,,,,,,,,,,,,, -16100,6.7089505,2.1149032,,,,,,,,,,,,,, -16200,5.5150595,2.2210822,,,,,,,,,,,,,, -16300,5.072677,2.0552466,,,,,,,,,,,,,, -16400,3.2604883,2.1729863,,,,,,,,,,,,,, -16405,,,0.5859175324440002,1.707316517829895,0.549780011177063,1.903155207633972,50000.0,0.4225000143051147,2.6825528144836426,10000.0,5643.064661502838,5853.869685411453,5643.064661502838,209.93391489982605,0.3114583492279053,0.0 -16500,7.4805155,2.2117505,,,,,,,,,,,,,, -16600,5.6422424,2.1582382,,,,,,,,,,,,,, -16700,6.607876,2.2806876,,,,,,,,,,,,,, -16800,4.4571466,2.180027,,,,,,,,,,,,,, -16900,4.8535414,2.1990933,,,,,,,,,,,,,, -17000,7.324683,2.2395377,,,,,,,,,,,,,, -17100,5.4669976,2.2032104,,,,,,,,,,,,,, -17200,4.8575873,2.1605759,,,,,,,,,,,,,, -17300,4.9371934,2.1121469,,,,,,,,,,,,,, -17400,5.5100503,2.1415915,,,,,,,,,,,,,, -17500,5.610598,2.208445,,,,,,,,,,,,,, -17600,5.932989,2.2133229,,,,,,,,,,,,,, -17700,5.8900104,2.288952,,,,,,,,,,,,,, -17800,4.6979175,2.2490437,,,,,,,,,,,,,, -17899,,,0.6173269748687744,1.5603954792022705,0.5484799742698669,1.9227455854415887,50000.0,0.4250000119209289,2.651263236999512,10000.0,6153.015790462494,6381.367884874344,6153.015790462494,227.40043759346008,0.3407416343688965,0.0 -17900,6.7903585,2.2431364,,,,,,,,,,,,,, -18000,5.2507224,2.3420086,,,,,,,,,,,,,, -18100,4.7108927,2.180774,,,,,,,,,,,,,, -18200,4.1168,2.0595458,,,,,,,,,,,,,, -18300,5.419303,2.1264455,,,,,,,,,,,,,, -18400,3.7966003,2.3020027,,,,,,,,,,,,,, -18500,3.5668437,2.164449,,,,,,,,,,,,,, -18600,4.4374332,2.1695855,,,,,,,,,,,,,, -18700,4.144028,2.1463532,,,,,,,,,,,,,, -18800,4.0591254,2.2547376,,,,,,,,,,,,,, -18900,4.704334,2.2180202,,,,,,,,,,,,,, -19000,3.912241,2.2157931,,,,,,,,,,,,,, -19100,4.3220153,2.1613958,,,,,,,,,,,,,, -19200,4.893621,2.306308,,,,,,,,,,,,,, -19300,4.4428487,2.122498,,,,,,,,,,,,,, -19394,,,0.6105906963348389,1.588725447654724,0.5607399940490723,1.864424467086792,50000.0,0.4438000321388244,2.586523532867432,10000.0,6663.139073133469,6909.171550512314,6663.139073133469,244.9995155334473,0.3710279464721679,0.0 -19400,4.2897816,2.345162,,,,,,,,,,,,,, -19500,4.9923816,2.2352862,,,,,,,,,,,,,, -19600,3.7280538,2.2179613,,,,,,,,,,,,,, -19700,4.2577267,2.1860743,,,,,,,,,,,,,, -19800,5.3492646,2.1456182,,,,,,,,,,,,,, -19900,4.099113,2.2087135,,,,,,,,,,,,,, -20000,3.4730978,2.0817032,,,,,,,,,,,,,, -20100,3.8708358,2.0986853,,,,,,,,,,,,,, -20200,4.011689,2.1342537,,,,,,,,,,,,,, -20300,3.2998667,2.2201262,,,,,,,,,,,,,, -20400,4.6903534,2.0903785,,,,,,,,,,,,,, -20500,4.239024,2.1118898,,,,,,,,,,,,,, -20600,3.5983915,2.0446105,,,,,,,,,,,,,, -20700,3.5448482,2.0800965,,,,,,,,,,,,,, -20800,3.5857015,2.0911498,,,,,,,,,,,,,, -20889,,,0.6055883169174194,1.5978363752365112,0.5553199648857117,1.85867977142334,50000.0,0.4298000335693359,2.619914054870605,10000.0,7173.313796281815,7436.808223009109,7173.313796281815,262.37942600250244,0.4020321369171142,0.0 -20900,3.9808958,2.3309226,,,,,,,,,,,,,, -21000,3.2780614,2.1394894,,,,,,,,,,,,,, -21100,3.9363773,2.1863544,,,,,,,,,,,,,, -21200,3.22266,2.135749,,,,,,,,,,,,,, -21300,3.1464033,2.1280527,,,,,,,,,,,,,, -21400,3.7156034,2.0068996,,,,,,,,,,,,,, -21500,3.7050514,2.1362522,,,,,,,,,,,,,, -21600,3.688896,2.234771,,,,,,,,,,,,,, -21700,3.5928094,2.0096428,,,,,,,,,,,,,, -21800,3.4898493,2.1696262,,,,,,,,,,,,,, -21900,3.9258254,2.181287,,,,,,,,,,,,,, -22000,3.3199296,2.0831082,,,,,,,,,,,,,, -22100,3.5421326,2.0504208,,,,,,,,,,,,,, -22200,3.98155,2.041881,,,,,,,,,,,,,, -22300,4.3074284,2.1096914,,,,,,,,,,,,,, -22384,,,0.6145368218421936,1.5811792612075806,0.5697999596595764,1.8240339756011963,50000.0,0.4462000131607055,2.553001165390014,10000.0,7683.348705768585,7964.449161529541,7683.348705768585,279.904082775116,0.4324047565460205,0.0 -22400,3.6286533,2.0709236,,,,,,,,,,,,,, -22500,3.3623505,2.1194468,,,,,,,,,,,,,, -22600,2.9203014,1.9985145,,,,,,,,,,,,,, -22700,3.209,1.990104,,,,,,,,,,,,,, -22800,3.9091024,2.0453649,,,,,,,,,,,,,, -22900,3.5614967,1.9897927,,,,,,,,,,,,,, -23000,4.589352,2.09968,,,,,,,,,,,,,, -23100,4.3227024,2.124751,,,,,,,,,,,,,, -23200,3.3385274,2.0849082,,,,,,,,,,,,,, -23300,3.5845408,2.036351,,,,,,,,,,,,,, -23400,4.7033606,2.220351,,,,,,,,,,,,,, -23500,3.2424757,2.1807044,,,,,,,,,,,,,, -23600,5.00143,2.0583093,,,,,,,,,,,,,, -23700,5.1049485,2.0349677,,,,,,,,,,,,,, -23800,3.5474994,2.0552623,,,,,,,,,,,,,, -23879,,,0.6197385191917419,1.565221071243286,0.5730199813842773,1.7761461734771729,50000.0,0.4519000351428985,2.501939296722412,10000.0,8193.355751514435,8492.355654716492,8193.355751514435,297.7186679840088,0.4656527042388916,0.0 -23900,3.6969144,1.9887557,,,,,,,,,,,,,, -24000,4.4240365,2.1325774,,,,,,,,,,,,,, -24100,3.7572448,2.089129,,,,,,,,,,,,,, -24200,3.0130212,2.0253391,,,,,,,,,,,,,, -24300,3.5879283,2.0659156,,,,,,,,,,,,,, -24400,3.734496,2.0994282,,,,,,,,,,,,,, -24500,3.514572,2.0426412,,,,,,,,,,,,,, -24600,3.0555801,2.000138,,,,,,,,,,,,,, -24700,3.710123,1.9946243,,,,,,,,,,,,,, -24800,3.7271903,1.9924216,,,,,,,,,,,,,, -24900,3.7647247,2.1622107,,,,,,,,,,,,,, -25000,3.5221224,2.0115175,,,,,,,,,,,,,, -25100,4.6320033,2.0391529,,,,,,,,,,,,,, -25200,3.3687143,1.9632349,,,,,,,,,,,,,, -25300,3.5425832,2.0260873,,,,,,,,,,,,,, -25375,,,0.6237244606018066,1.5361076593399048,0.5832799673080444,1.745568037033081,50000.0,0.4537000358104706,2.5031628608703613,10000.0,8703.507536411285,9019.985451698303,8703.507536411285,315.1154990196228,0.4975943565368652,0.0 -25400,4.2317276,2.1267915,,,,,,,,,,,,,, -25500,4.1683855,1.9848338,,,,,,,,,,,,,, -25600,3.5635486,2.0904255,,,,,,,,,,,,,, -25700,3.4002767,2.0798767,,,,,,,,,,,,,, -25800,4.0815783,2.0270057,,,,,,,,,,,,,, -25900,3.4388273,2.1844006,,,,,,,,,,,,,, -26000,3.6943011,1.9346882,,,,,,,,,,,,,, -26100,3.3683841,2.0075116,,,,,,,,,,,,,, -26200,3.6331131,2.098056,,,,,,,,,,,,,, -26300,3.3656254,2.0346236,,,,,,,,,,,,,, -26400,4.1670017,2.1031632,,,,,,,,,,,,,, -26500,3.7367308,1.976134,,,,,,,,,,,,,, -26600,3.751691,1.9426975,,,,,,,,,,,,,, -26700,3.0721774,1.8910434,,,,,,,,,,,,,, -26800,3.6739204,2.0506864,,,,,,,,,,,,,, -26871,,,0.6704002022743225,1.311537265777588,0.5862399935722351,1.7210910320281982,50000.0,0.4626000225543976,2.476860761642456,10000.0,9213.550162315369,9547.775417804718,9213.550162315369,332.7810888290405,0.5300009250640869,0.0 -26900,2.9297588,1.9940362,,,,,,,,,,,,,, -27000,3.6235368,2.1982584,,,,,,,,,,,,,, -27100,4.1986947,2.1348515,,,,,,,,,,,,,, -27200,3.3445919,2.1056056,,,,,,,,,,,,,, -27300,3.385272,1.9217198,,,,,,,,,,,,,, -27400,3.6722455,1.9819996,,,,,,,,,,,,,, -27500,3.5153275,1.9413172,,,,,,,,,,,,,, -27600,3.1000936,2.0199907,,,,,,,,,,,,,, -27700,3.882503,1.9560543,,,,,,,,,,,,,, -27800,3.4839678,1.9955127,,,,,,,,,,,,,, -27900,3.65965,2.0374475,,,,,,,,,,,,,, -28000,3.2777977,1.9337918,,,,,,,,,,,,,, -28100,3.6736956,2.0024595,,,,,,,,,,,,,, -28200,3.3196373,1.9157032,,,,,,,,,,,,,, -28300,4.482511,1.9140663,,,,,,,,,,,,,, -28366,,,0.6376155614852905,1.4694420099258425,0.57669997215271,1.7748409509658811,50000.0,0.453900009393692,2.494658708572388,10000.0,9723.562857627869,10075.523859977722,9723.562857627869,350.4288082122803,0.5664412975311279,0.0 -28400,2.9584875,2.0311725,,,,,,,,,,,,,, -28500,4.113639,1.9704851,,,,,,,,,,,,,, -28600,3.7451801,1.9844458,,,,,,,,,,,,,, -28700,3.579721,2.0047822,,,,,,,,,,,,,, -28800,3.6092308,1.9262617,,,,,,,,,,,,,, -28900,3.688826,2.0922215,,,,,,,,,,,,,, -29000,3.82576,2.0654504,,,,,,,,,,,,,, -29100,3.4443655,1.7649845,,,,,,,,,,,,,, -29200,3.6837714,2.0189977,,,,,,,,,,,,,, -29300,3.6474211,2.040229,,,,,,,,,,,,,, -29400,3.291065,2.0567102,,,,,,,,,,,,,, -29500,3.873579,2.0382104,,,,,,,,,,,,,, -29600,3.3491364,2.0268726,,,,,,,,,,,,,, -29700,4.072754,1.924821,,,,,,,,,,,,,, -29800,3.133103,1.9330105,,,,,,,,,,,,,, -29863,,,0.6395886540412903,1.4568294286727903,0.586080014705658,1.7302289009094238,50000.0,0.4571000337600708,2.4968318939208984,10000.0,10233.790237903597,10603.18566417694,10233.790237903597,367.77251267433167,0.6059026718139648,0.0 -29900,5.3562613,2.0452855,,,,,,,,,,,,,, -30000,3.6931229,1.9871202,,,,,,,,,,,,,, -30100,4.1311183,1.9922979,,,,,,,,,,,,,, -30200,3.6085753,2.0290086,,,,,,,,,,,,,, -30300,3.359269,1.8916187,,,,,,,,,,,,,, -30400,4.2884793,1.9881446,,,,,,,,,,,,,, -30500,3.9350429,1.9771687,,,,,,,,,,,,,, -30600,3.5592506,2.0250225,,,,,,,,,,,,,, -30700,3.764963,2.0010722,,,,,,,,,,,,,, -30800,3.773147,1.9499213,,,,,,,,,,,,,, -30900,3.5632925,1.8931557,,,,,,,,,,,,,, -31000,3.3725274,1.9884233,,,,,,,,,,,,,, -31100,3.503551,1.9863784,,,,,,,,,,,,,, -31200,3.0986652,1.9535456,,,,,,,,,,,,,, -31300,3.7016203,1.9926865,,,,,,,,,,,,,, -31359,,,0.638671875,1.4534977674484253,0.5884400010108948,1.7131558656692505,50000.0,0.4750000238418579,2.414470672607422,10000.0,10743.71028470993,11130.426788568497,10743.71028470993,385.0090951919556,0.6394505500793457,0.0 -31400,3.8074884,2.011794,,,,,,,,,,,,,, -31500,3.4829764,2.0621078,,,,,,,,,,,,,, -31600,3.7206573,2.032648,,,,,,,,,,,,,, -31700,3.4887009,1.929497,,,,,,,,,,,,,, -31800,3.5929487,2.048201,,,,,,,,,,,,,, -31900,3.292896,1.961832,,,,,,,,,,,,,, -32000,3.5363512,2.0305183,,,,,,,,,,,,,, -32100,3.6358392,1.9515173,,,,,,,,,,,,,, -32200,3.7624383,2.015564,,,,,,,,,,,,,, -32300,4.337088,2.0012412,,,,,,,,,,,,,, -32400,3.5115702,1.9614192,,,,,,,,,,,,,, -32500,3.4422145,1.9565744,,,,,,,,,,,,,, -32600,3.1971972,1.9169277,,,,,,,,,,,,,, -32700,3.4722424,1.9806198,,,,,,,,,,,,,, -32800,3.928179,1.9374965,,,,,,,,,,,,,, -32855,,,0.6384924650192261,1.4499725103378296,0.5962799787521362,1.6743437051773071,50000.0,0.4702000319957733,2.409991502761841,10000.0,11253.648388624191,11658.121633052826,11253.648388624191,402.6841251850128,0.6713178157806396,0.0 -32900,3.351868,1.988491,,,,,,,,,,,,,, -33000,3.3175914,1.9522592,,,,,,,,,,,,,, -33100,2.9506698,1.8441024,,,,,,,,,,,,,, -33200,3.7396731,1.9606922,,,,,,,,,,,,,, -33300,3.6067667,1.9751825,,,,,,,,,,,,,, -33400,4.0315175,1.9271967,,,,,,,,,,,,,, -33500,4.15613,2.0636108,,,,,,,,,,,,,, -33600,3.3217158,2.068013,,,,,,,,,,,,,, -33700,3.9218411,1.9566032,,,,,,,,,,,,,, -33800,3.7475224,1.9837555,,,,,,,,,,,,,, -33900,3.910681,2.0952568,,,,,,,,,,,,,, -34000,3.671143,1.9753046,,,,,,,,,,,,,, -34100,3.6820076,2.0561547,,,,,,,,,,,,,, -34200,3.8644838,1.9856229,,,,,,,,,,,,,, -34300,3.3187475,1.9523184,,,,,,,,,,,,,, -34352,,,0.6353236436843872,1.4725189208984375,0.5895400047302246,1.7027602195739746,50000.0,0.4675000309944153,2.4392247200012207,10000.0,11763.811690092089,12186.06164097786,11763.811690092089,420.37781167030334,0.7036874294281006,0.0 -34400,3.1004179,1.9190058,,,,,,,,,,,,,, -34500,3.5231955,1.9678206,,,,,,,,,,,,,, -34600,3.3340425,1.9197733,,,,,,,,,,,,,, -34700,3.8487997,2.0131154,,,,,,,,,,,,,, -34800,3.7937636,2.049048,,,,,,,,,,,,,, -34900,3.7560112,1.9377074,,,,,,,,,,,,,, -35000,3.3509579,1.8643336,,,,,,,,,,,,,, -35100,3.252781,1.9233222,,,,,,,,,,,,,, -35200,3.0540435,1.8797252,,,,,,,,,,,,,, -35300,3.3269563,1.8590293,,,,,,,,,,,,,, -35400,3.5960844,1.9183078,,,,,,,,,,,,,, -35500,3.5178733,1.9478413,,,,,,,,,,,,,, -35600,3.5269272,1.9544277,,,,,,,,,,,,,, -35700,3.2279134,2.0127606,,,,,,,,,,,,,, -35800,3.9723973,1.988119,,,,,,,,,,,,,, -35849,,,0.652762234210968,1.4022929668426514,0.5992199778556824,1.6665433645248413,50000.0,0.4771000146865845,2.394937753677368,10000.0,12273.92913389206,12713.53648853302,12273.92913389206,437.6548173427582,0.7323830127716064,0.0 -35900,3.9673557,1.940697,,,,,,,,,,,,,, -36000,3.778924,2.0241077,,,,,,,,,,,,,, -36100,3.226535,1.9551933,,,,,,,,,,,,,, -36200,3.4521687,1.8072222,,,,,,,,,,,,,, -36300,3.370393,1.8536664,,,,,,,,,,,,,, -36400,3.9994662,2.0032835,,,,,,,,,,,,,, -36500,3.6522512,1.983183,,,,,,,,,,,,,, -36600,3.822183,2.1066759,,,,,,,,,,,,,, -36700,3.212859,1.8185743,,,,,,,,,,,,,, -36800,3.8706133,1.7291918,,,,,,,,,,,,,, -36900,3.679196,1.8701984,,,,,,,,,,,,,, -37000,3.242696,1.9256135,,,,,,,,,,,,,, -37100,3.621364,1.9526966,,,,,,,,,,,,,, -37200,3.0890605,1.8579657,,,,,,,,,,,,,, -37300,3.454162,1.9667981,,,,,,,,,,,,,, -37346,,,0.6575254797935486,1.3663238286972046,0.5941999554634094,1.6988669633865356,50000.0,0.4744000136852264,2.4055161476135254,10000.0,12784.103419065475,13241.438656330109,12784.103419065475,455.2939562797546,0.7698986530303955,0.0 -37400,3.259774,1.8521068,,,,,,,,,,,,,, -37500,3.7056706,2.010724,,,,,,,,,,,,,, -37600,3.5944326,1.9122334,,,,,,,,,,,,,, -37700,3.9262996,2.0616107,,,,,,,,,,,,,, -37800,4.0029783,1.9370348,,,,,,,,,,,,,, -37900,3.279328,1.9498181,,,,,,,,,,,,,, -38000,3.3386486,1.8767636,,,,,,,,,,,,,, -38100,4.2225447,1.9260283,,,,,,,,,,,,,, -38200,3.7337902,1.9856668,,,,,,,,,,,,,, -38300,3.3381772,1.9042435,,,,,,,,,,,,,, -38400,3.5897865,1.8760823,,,,,,,,,,,,,, -38500,3.6751814,1.8998625,,,,,,,,,,,,,, -38600,3.5185921,1.9824295,,,,,,,,,,,,,, -38700,2.8767622,1.874528,,,,,,,,,,,,,, -38800,3.1643426,1.9119649,,,,,,,,,,,,,, -38841,,,0.6528818607330322,1.384704828262329,0.5984399914741516,1.6746433973312378,50000.0,0.4828000366687774,2.4071831703186035,10000.0,13294.079423427582,13768.731862545012,13294.079423427582,472.52238607406616,0.8064682483673096,0.0 -38900,3.8545222,1.9321663,,,,,,,,,,,,,, -39000,3.190828,1.9668801,,,,,,,,,,,,,, -39100,3.5611756,1.8145818,,,,,,,,,,,,,, -39200,3.592553,1.989471,,,,,,,,,,,,,, -39300,3.9429736,1.9977337,,,,,,,,,,,,,, -39400,3.412649,2.0011284,,,,,,,,,,,,,, -39500,3.8516939,1.8195139,,,,,,,,,,,,,, -39600,3.0647423,1.8874799,,,,,,,,,,,,,, -39700,3.4382014,1.868882,,,,,,,,,,,,,, -39800,3.6382005,1.9626102,,,,,,,,,,,,,, -39900,3.9119642,1.9094387,,,,,,,,,,,,,, -40000,3.0787618,1.8354299,,,,,,,,,,,,,, -40100,4.7313037,1.8905113,,,,,,,,,,,,,, -40200,3.373418,1.9167928,,,,,,,,,,,,,, -40300,3.4908392,1.9416966,,,,,,,,,,,,,, -40338,,,0.6393694281578064,1.4571224451065063,0.5904799699783325,1.712005615234375,50000.0,0.4669000208377838,2.4362292289733887,10000.0,13804.172444343569,14296.480261325836,13804.172444343569,490.08971118927,0.845099687576294,0.0 -40400,3.3468058,1.916579,,,,,,,,,,,,,, -40500,3.4836216,1.9683094,,,,,,,,,,,,,, -40600,4.0306726,2.0854363,,,,,,,,,,,,,, -40700,3.5564933,1.8656371,,,,,,,,,,,,,, -40800,3.6574771,1.8322498,,,,,,,,,,,,,, -40900,3.5848331,1.8523123,,,,,,,,,,,,,, -41000,3.1871777,1.790171,,,,,,,,,,,,,, -41100,3.1688116,1.853138,,,,,,,,,,,,,, -41200,3.304346,1.8629564,,,,,,,,,,,,,, -41300,3.8043327,1.9473841,,,,,,,,,,,,,, -41400,3.591663,1.832566,,,,,,,,,,,,,, -41500,3.8738484,2.0734105,,,,,,,,,,,,,, -41600,3.4760108,2.0186427,,,,,,,,,,,,,, -41700,3.7308187,2.0577404,,,,,,,,,,,,,, -41800,3.5273094,1.8630688,,,,,,,,,,,,,, -41835,,,0.6407644748687744,1.4308786392211914,0.5977599620819092,1.671799659729004,50000.0,0.4811000227928161,2.351424217224121,10000.0,14314.146545886992,14824.075912237167,14314.146545886992,507.6242277622223,0.8813223838806152,0.0 -41900,4.235614,2.0293796,,,,,,,,,,,,,, -42000,4.0656567,1.8904538,,,,,,,,,,,,,, -42100,3.6030252,1.8187181,,,,,,,,,,,,,, -42200,4.476641,2.0494134,,,,,,,,,,,,,, -42300,4.133125,1.896266,,,,,,,,,,,,,, -42400,3.3432255,1.9640534,,,,,,,,,,,,,, -42500,3.3240118,1.997047,,,,,,,,,,,,,, -42600,3.3867369,1.8615702,,,,,,,,,,,,,, -42700,3.2303548,1.8778756,,,,,,,,,,,,,, -42800,3.6363466,1.8498963,,,,,,,,,,,,,, -42900,3.1945062,1.9837755,,,,,,,,,,,,,, -43000,3.5406187,1.8871598,,,,,,,,,,,,,, -43100,2.9358113,2.0356815,,,,,,,,,,,,,, -43200,3.1793194,1.8875841,,,,,,,,,,,,,, -43300,3.637212,1.9667069,,,,,,,,,,,,,, -43332,,,0.6420599222183228,1.4258434772491455,0.5988199710845947,1.6653971672058103,50000.0,0.4744000136852264,2.401122093200684,10000.0,14824.182245731354,15351.50887298584,14824.182245731354,524.9377455711365,0.915226936340332,0.0 -43400,3.3602302,1.9997301,,,,,,,,,,,,,, -43500,3.93101,1.8607658,,,,,,,,,,,,,, -43600,3.2222893,1.9228609,,,,,,,,,,,,,, -43700,3.3179462,1.8146274,,,,,,,,,,,,,, -43800,3.3769016,1.9162642,,,,,,,,,,,,,, -43900,3.4469612,1.7941158,,,,,,,,,,,,,, -44000,3.6462185,1.9445662,,,,,,,,,,,,,, -44100,3.6884365,1.9083287,,,,,,,,,,,,,, -44200,3.8189065,2.035593,,,,,,,,,,,,,, -44300,3.057477,1.7638967,,,,,,,,,,,,,, -44400,3.6558816,1.884372,,,,,,,,,,,,,, -44500,3.9598043,1.9563141,,,,,,,,,,,,,, -44600,3.327409,1.8668766,,,,,,,,,,,,,, -44700,3.638659,1.9060462,,,,,,,,,,,,,, -44800,3.3801587,1.9181709,,,,,,,,,,,,,, -44829,,,0.6495137214660645,1.4016178846359253,0.6055200099945068,1.6298383474349976,50000.0,0.4788000285625458,2.33974552154541,10000.0,15334.278590202332,15878.976461172104,15334.278590202332,542.2236630916595,0.9501383304595948,0.0 -44900,3.6887348,1.8910719,,,,,,,,,,,,,, -45000,3.3482594,1.7782211,,,,,,,,,,,,,, -45100,3.8220067,1.8690026,,,,,,,,,,,,,, -45200,3.0118613,1.8759391,,,,,,,,,,,,,, -45300,3.6331465,1.9375162,,,,,,,,,,,,,, -45400,3.4064965,1.7848358,,,,,,,,,,,,,, -45500,3.635716,2.0193474,,,,,,,,,,,,,, -45600,3.09654,1.8975011,,,,,,,,,,,,,, -45700,3.4542272,1.8744693,,,,,,,,,,,,,, -45800,3.7630057,1.9475601,,,,,,,,,,,,,, -45900,3.6490645,2.0246038,,,,,,,,,,,,,, -46000,3.2442691,1.9049484,,,,,,,,,,,,,, -46100,3.3513858,1.8975312,,,,,,,,,,,,,, -46200,4.5107837,1.9526606,,,,,,,,,,,,,, -46300,4.4002833,1.8434908,,,,,,,,,,,,,, -46327,,,0.6932198405265808,1.2089905738830566,0.6118999719619751,1.6064671277999878,50000.0,0.4900000095367431,2.314314126968384,10000.0,15844.478893518448,16406.788619995117,15844.478893518448,559.746725320816,0.9882421493530272,0.0 -46400,3.344906,1.8079308,,,,,,,,,,,,,, -46500,3.3479514,1.8850822,,,,,,,,,,,,,, -46600,3.5272784,1.8818216,,,,,,,,,,,,,, -46700,3.471938,1.9191022,,,,,,,,,,,,,, -46800,3.7442613,1.9135908,,,,,,,,,,,,,, -46900,3.2430892,1.8824625,,,,,,,,,,,,,, -47000,3.2183444,1.9738603,,,,,,,,,,,,,, -47100,3.7985523,1.8091149,,,,,,,,,,,,,, -47200,3.8825753,1.8764136,,,,,,,,,,,,,, -47300,3.4987733,1.8662078,,,,,,,,,,,,,, -47400,3.9991329,1.8935136,,,,,,,,,,,,,, -47500,4.0089064,1.980862,,,,,,,,,,,,,, -47600,3.4438875,1.9615581,,,,,,,,,,,,,, -47700,3.317361,1.8996619,,,,,,,,,,,,,, -47800,3.6622114,1.9324768,,,,,,,,,,,,,, -47825,,,0.6660555005073547,1.3233352899551392,0.604699969291687,1.6360735893249512,50000.0,0.4797000288963318,2.354118585586548,10000.0,16354.602724313736,16935.195270061493,16354.602724313736,577.9383962154388,1.026573657989502,0.0 -47900,3.8377957,1.8677809,,,,,,,,,,,,,, -48000,3.1198683,1.8345996,,,,,,,,,,,,,, -48100,3.801602,1.8152928,,,,,,,,,,,,,, -48200,3.9958816,1.816496,,,,,,,,,,,,,, -48300,3.2084854,1.9263875,,,,,,,,,,,,,, -48400,3.6844645,1.8310101,,,,,,,,,,,,,, -48500,3.7118137,1.9130747,,,,,,,,,,,,,, -48600,3.7324507,1.9834177,,,,,,,,,,,,,, -48700,4.356683,1.8423198,,,,,,,,,,,,,, -48800,4.0483756,1.8878874,,,,,,,,,,,,,, -48900,4.0099754,1.8536706,,,,,,,,,,,,,, -49000,3.780149,1.8563474,,,,,,,,,,,,,, -49100,3.4120538,1.9271094,,,,,,,,,,,,,, -49200,3.3542354,1.9198021,,,,,,,,,,,,,, -49300,3.3182228,1.7893559,,,,,,,,,,,,,, -49323,,,0.6553133130073547,1.379987716674805,0.6030399799346924,1.6445624828338623,50000.0,0.4821000099182129,2.3614590167999268,10000.0,16864.671887874603,17462.89333844185,16864.671887874603,595.4793081283569,1.063713788986206,0.0 -49400,3.6299744,1.9194524,,,,,,,,,,,,,, -49500,3.0992968,1.8785414,,,,,,,,,,,,,, -49600,3.5075703,1.8878293,,,,,,,,,,,,,, -49700,4.024905,2.0483303,,,,,,,,,,,,,, -49800,3.3588622,1.7326782,,,,,,,,,,,,,, -49900,3.3528843,2.000299,,,,,,,,,,,,,, -50000,4.08399,1.826339,,,,,,,,,,,,,, -50100,3.2827752,2.0051486,,,,,,,,,,,,,, -50200,4.157866,1.9114393,,,,,,,,,,,,,, -50300,4.2083354,1.8980415,,,,,,,,,,,,,, -50400,3.660872,1.6831657,,,,,,,,,,,,,, -50500,3.7101903,1.9477698,,,,,,,,,,,,,, -50600,3.7288952,1.8850287,,,,,,,,,,,,,, -50700,3.383408,1.856833,,,,,,,,,,,,,, -50800,3.0567214,1.8329818,,,,,,,,,,,,,, -50822,,,0.6630061864852905,1.3523014783859253,0.6115399599075317,1.6137793064117432,50000.0,0.4883000254631042,2.343726396560669,10000.0,17374.922873973846,17990.727430582047,17374.922873973846,612.9747793674469,1.0999596118927002,0.0 -50900,3.6794164,1.9457035,,,,,,,,,,,,,, -51000,3.9766984,1.8864424,,,,,,,,,,,,,, -51100,3.603023,1.7926348,,,,,,,,,,,,,, -51200,4.166807,1.8418357,,,,,,,,,,,,,, -51300,3.9898958,1.849232,,,,,,,,,,,,,, -51400,3.6478052,1.8397979,,,,,,,,,,,,,, -51500,3.6442494,1.9591396,,,,,,,,,,,,,, -51600,3.379639,1.8919271,,,,,,,,,,,,,, -51700,3.5320315,1.9353174,,,,,,,,,,,,,, -51800,3.6422617,1.9197137,,,,,,,,,,,,,, -51900,3.6993413,1.8186194,,,,,,,,,,,,,, -52000,3.442854,1.804005,,,,,,,,,,,,,, -52100,3.5345297,1.8850504,,,,,,,,,,,,,, -52200,3.4826076,2.003169,,,,,,,,,,,,,, -52300,3.7486808,1.8657606,,,,,,,,,,,,,, -52320,,,0.6526426672935486,1.3899118900299072,0.6061800122261047,1.6280734539031982,50000.0,0.4897000193595886,2.339770793914795,10000.0,17885.14226746559,18518.28278398513,17885.14226746559,630.2206964492798,1.1384408473968506,0.0 -52400,3.4090445,1.8526214,,,,,,,,,,,,,, -52500,3.893427,1.7689922,,,,,,,,,,,,,, -52600,3.803196,2.0083685,,,,,,,,,,,,,, -52700,4.1245947,1.7584062,,,,,,,,,,,,,, -52800,3.4085524,1.81946,,,,,,,,,,,,,, -52900,4.268755,1.7896987,,,,,,,,,,,,,, -53000,3.4526508,1.8087149,,,,,,,,,,,,,, -53100,3.5433273,1.9287393,,,,,,,,,,,,,, -53200,3.8439946,1.8408053,,,,,,,,,,,,,, -53300,3.5731041,1.9833026,,,,,,,,,,,,,, -53400,4.0756145,1.9176457,,,,,,,,,,,,,, -53500,3.7767131,1.7694457,,,,,,,,,,,,,, -53600,3.075298,1.8833332,,,,,,,,,,,,,, -53700,3.2322762,1.8773867,,,,,,,,,,,,,, -53800,3.3305852,1.8292396,,,,,,,,,,,,,, -53818,,,0.6602359414100647,1.361589789390564,0.613379955291748,1.606788992881775,50000.0,0.4889000356197357,2.3146893978118896,10000.0,18395.222403526303,19045.71305155754,18395.222403526303,647.4825391769409,1.1758079528808594,0.0 -53900,3.8343987,1.9660707,,,,,,,,,,,,,, -54000,3.3217902,1.8564632,,,,,,,,,,,,,, -54100,3.6756306,1.9848777,,,,,,,,,,,,,, -54200,3.3745613,1.9653313,,,,,,,,,,,,,, -54300,4.082955,1.85913,,,,,,,,,,,,,, -54400,3.5074635,1.9175382,,,,,,,,,,,,,, -54500,4.5538883,1.9352807,,,,,,,,,,,,,, -54600,3.8015795,1.8969011,,,,,,,,,,,,,, -54700,3.6621826,1.9256451,,,,,,,,,,,,,, -54800,4.3354983,1.9284499,,,,,,,,,,,,,, -54900,4.2875295,1.8261596,,,,,,,,,,,,,, -55000,3.5293622,1.888903,,,,,,,,,,,,,, -55100,3.7283866,1.8697389,,,,,,,,,,,,,, -55200,3.2884908,1.9076703,,,,,,,,,,,,,, -55300,3.9958386,1.857449,,,,,,,,,,,,,, -55316,,,0.699238657951355,1.1795507669448853,0.6133399605751038,1.5974823236465454,50000.0,0.4910000264644623,2.2842953205108643,10000.0,18905.270992279053,19573.216374635696,18905.270992279053,664.8474590778351,1.2139925956726074,0.0 -55400,3.4570253,1.8309236,,,,,,,,,,,,,, -55500,3.7459414,1.9338784,,,,,,,,,,,,,, -55600,3.1868262,1.8536217,,,,,,,,,,,,,, -55700,4.114999,1.8997147,,,,,,,,,,,,,, -55800,3.4212859,1.8748043,,,,,,,,,,,,,, -55900,3.850668,1.9879634,,,,,,,,,,,,,, -56000,3.9544482,1.9570279,,,,,,,,,,,,,, -56100,3.6095297,1.807598,,,,,,,,,,,,,, -56200,3.5124128,1.9178761,,,,,,,,,,,,,, -56300,3.4484947,1.9083552,,,,,,,,,,,,,, -56400,3.569372,1.8212123,,,,,,,,,,,,,, -56500,3.4409359,1.8043422,,,,,,,,,,,,,, -56600,3.947694,1.8017647,,,,,,,,,,,,,, -56700,3.439595,2.008627,,,,,,,,,,,,,, -56800,4.4088006,1.8162615,,,,,,,,,,,,,, -56814,,,0.685965359210968,1.2393864393234253,0.6195799708366394,1.5674865245819092,50000.0,0.4950000345706939,2.303964853286743,10000.0,19415.29590034485,20100.7879998684,19415.29590034485,682.3041090965271,1.252126932144165,0.0 -56900,4.0451736,1.8169818,,,,,,,,,,,,,, -57000,3.7489321,1.9158611,,,,,,,,,,,,,, -57100,3.0204685,1.6231135,,,,,,,,,,,,,, -57200,3.4142087,1.8074934,,,,,,,,,,,,,, -57300,3.8281627,1.8615206,,,,,,,,,,,,,, -57400,3.618502,1.7154185,,,,,,,,,,,,,, -57500,4.169082,1.8057017,,,,,,,,,,,,,, -57600,3.9268138,1.8973577,,,,,,,,,,,,,, -57700,3.990287,1.7997164,,,,,,,,,,,,,, -57800,3.589498,1.9544106,,,,,,,,,,,,,, -57900,4.530187,1.8247082,,,,,,,,,,,,,, -58000,3.3847466,1.7892766,,,,,,,,,,,,,, -58100,3.8275077,1.9367788,,,,,,,,,,,,,, -58200,3.6730847,1.8345937,,,,,,,,,,,,,, -58300,4.111893,1.7814182,,,,,,,,,,,,,, -58312,,,0.6573660373687744,1.369436264038086,0.6055799722671509,1.6424763202667236,50000.0,0.484000027179718,2.3704142570495605,10000.0,19925.24132657051,20628.266762018204,19925.24132657051,699.7403359413147,1.298933506011963,0.0 -58400,3.6016755,1.9179095,,,,,,,,,,,,,, -58500,3.329471,1.7443521,,,,,,,,,,,,,, -58600,3.5725484,1.8952163,,,,,,,,,,,,,, -58700,3.1728346,1.8844272,,,,,,,,,,,,,, -58800,4.039165,1.7777922,,,,,,,,,,,,,, -58900,3.734725,1.7558216,,,,,,,,,,,,,, -59000,3.344581,1.8390664,,,,,,,,,,,,,, -59100,3.357027,1.7316362,,,,,,,,,,,,,, -59200,3.4628057,1.9163239,,,,,,,,,,,,,, -59300,3.4312904,1.7630261,,,,,,,,,,,,,, -59400,3.7820544,1.7247247,,,,,,,,,,,,,, -59500,3.7946773,1.860286,,,,,,,,,,,,,, -59600,3.3255706,1.8222142,,,,,,,,,,,,,, -59700,3.1877155,1.7945093,,,,,,,,,,,,,, -59800,4.1070237,1.8573079,,,,,,,,,,,,,, -59811,,,0.6717155575752258,1.299277424812317,0.6213600039482117,1.5658512115478516,50000.0,0.4971000254154205,2.3210856914520264,10000.0,20435.426945209503,21155.81867671013,20435.426945209503,717.0170676708221,1.3389678001403809,0.0 -59900,3.408357,1.803,,,,,,,,,,,,,, -60000,3.6687546,1.7645655,,,,,,,,,,,,,, -60100,3.7388892,1.8787017,,,,,,,,,,,,,, -60200,3.6796737,1.7295228,,,,,,,,,,,,,, -60300,3.8952549,1.8846767,,,,,,,,,,,,,, -60400,3.9109917,1.6290367,,,,,,,,,,,,,, -60500,3.71713,1.8355693,,,,,,,,,,,,,, -60600,3.5847983,1.8105856,,,,,,,,,,,,,, -60700,3.9905663,1.7379313,,,,,,,,,,,,,, -60800,3.1947536,1.7990562,,,,,,,,,,,,,, -60900,3.9743454,1.872495,,,,,,,,,,,,,, -61000,3.9160953,1.8349589,,,,,,,,,,,,,, -61100,3.480617,1.8245568,,,,,,,,,,,,,, -61200,4.1012487,1.9908965,,,,,,,,,,,,,, -61300,3.3918576,1.8689976,,,,,,,,,,,,,, -61310,,,0.6558912396430969,1.3650057315826416,0.6165199875831604,1.608451247215271,50000.0,0.4895000159740448,2.3405680656433105,10000.0,20945.637128591537,21683.5158367157,20945.637128591537,734.4156177043915,1.3787786960601809,0.0 -61400,3.5611913,1.8779268,,,,,,,,,,,,,, -61500,3.93492,1.7690388,,,,,,,,,,,,,, -61600,4.1365647,1.762205,,,,,,,,,,,,,, -61700,3.7189841,1.861207,,,,,,,,,,,,,, -61800,3.6032753,1.7561008,,,,,,,,,,,,,, -61900,3.9119463,1.919235,,,,,,,,,,,,,, -62000,3.7365222,1.7424582,,,,,,,,,,,,,, -62100,3.772363,1.8358749,,,,,,,,,,,,,, -62200,3.3269367,1.7849084,,,,,,,,,,,,,, -62300,3.745419,1.7667884,,,,,,,,,,,,,, -62400,3.859655,1.8270924,,,,,,,,,,,,,, -62500,3.6499405,1.9256097,,,,,,,,,,,,,, -62600,3.6696932,1.7150317,,,,,,,,,,,,,, -62700,3.7004364,1.9034761,,,,,,,,,,,,,, -62800,4.1237144,1.8591111,,,,,,,,,,,,,, -62808,,,0.6518056392669678,1.3800477981567385,0.6077399849891663,1.6119192838668823,50000.0,0.4837000370025635,2.364459991455078,10000.0,21455.61483645439,22211.000241994858,21455.61483645439,751.8346419334412,1.415900468826294,0.0 -62900,3.973144,1.7739319,,,,,,,,,,,,,, -63000,4.1518517,1.7524871,,,,,,,,,,,,,, -63100,3.7702835,1.8589416,,,,,,,,,,,,,, -63200,3.9592671,1.7815356,,,,,,,,,,,,,, -63300,4.3057857,2.0003917,,,,,,,,,,,,,, -63400,4.800727,1.9175699,,,,,,,,,,,,,, -63500,3.8775446,1.8782387,,,,,,,,,,,,,, -63600,4.0744867,1.7999868,,,,,,,,,,,,,, -63700,4.2745295,1.8891948,,,,,,,,,,,,,, -63800,4.1114416,1.7962267,,,,,,,,,,,,,, -63900,3.7709277,1.7754945,,,,,,,,,,,,,, -64000,3.6966786,1.7873333,,,,,,,,,,,,,, -64100,3.6938417,1.7651237,,,,,,,,,,,,,, -64200,3.5370312,1.864399,,,,,,,,,,,,,, -64300,3.4284062,1.7738167,,,,,,,,,,,,,, -64306,,,0.6535594463348389,1.3952399492263794,0.6021400094032288,1.6519254446029663,50000.0,0.4803000092506408,2.3926892280578613,10000.0,21965.65988755226,22738.56585907936,21965.65988755226,769.2622895240784,1.4578070640563965,0.0 -64400,3.5561674,1.8126723,,,,,,,,,,,,,, -64500,3.5743432,1.8019595,,,,,,,,,,,,,, -64600,3.89122,1.6853731,,,,,,,,,,,,,, -64700,4.2508926,1.7936363,,,,,,,,,,,,,, -64800,3.6842961,1.8379587,,,,,,,,,,,,,, -64900,4.3033714,1.6980213,,,,,,,,,,,,,, -65000,3.7148378,1.7818292,,,,,,,,,,,,,, -65100,3.8097718,1.7514209,,,,,,,,,,,,,, -65200,3.7550476,1.7243044,,,,,,,,,,,,,, -65300,3.6985154,1.8197811,,,,,,,,,,,,,, -65400,3.5442753,1.7386055,,,,,,,,,,,,,, -65500,3.3568454,1.7138278,,,,,,,,,,,,,, -65600,3.3708663,1.8021605,,,,,,,,,,,,,, -65700,3.6438673,1.8665855,,,,,,,,,,,,,, -65800,3.6237557,1.6647379,,,,,,,,,,,,,, -65805,,,0.6921834945678711,1.207922101020813,0.6225999593734741,1.554673671722412,50000.0,0.4997000098228454,2.25830078125,10000.0,22475.830349206924,23266.250294208527,22475.830349206924,786.6829445362091,1.5008704662322998,0.0 -65900,3.7917154,1.9395659,,,,,,,,,,,,,, -66000,3.7422955,1.8789692,,,,,,,,,,,,,, -66100,4.195257,1.8323387,,,,,,,,,,,,,, -66200,4.4317455,1.9205116,,,,,,,,,,,,,, -66300,3.5659063,1.855426,,,,,,,,,,,,,, -66400,3.5594246,1.8679323,,,,,,,,,,,,,, -66500,4.339646,1.8352745,,,,,,,,,,,,,, -66600,4.0005713,1.7101028,,,,,,,,,,,,,, -66700,4.0973473,1.8063697,,,,,,,,,,,,,, -66800,3.5212214,1.7224598,,,,,,,,,,,,,, -66900,3.552911,1.8497058,,,,,,,,,,,,,, -67000,3.9234612,1.8872317,,,,,,,,,,,,,, -67100,3.683493,1.9426183,,,,,,,,,,,,,, -67200,3.6625912,1.6872505,,,,,,,,,,,,,, -67300,4.058361,1.7243083,,,,,,,,,,,,,, -67304,,,0.6910474896430969,1.2102947235107422,0.6284999847412109,1.5212945938110352,50000.0,0.5111000537872314,2.2321712970733643,10000.0,22985.990831136703,23793.728055477142,22985.990831136703,803.9027199745178,1.547590732574463,0.0 -67400,4.482993,1.9011426,,,,,,,,,,,,,, -67500,3.0208719,1.719907,,,,,,,,,,,,,, -67600,3.4333098,1.8951746,,,,,,,,,,,,,, -67700,3.696491,1.8144516,,,,,,,,,,,,,, -67800,3.7178452,1.7787983,,,,,,,,,,,,,, -67900,3.6795137,1.6171137,,,,,,,,,,,,,, -68000,4.319189,1.7894403,,,,,,,,,,,,,, -68100,3.564667,1.858872,,,,,,,,,,,,,, -68200,4.80133,1.9028347,,,,,,,,,,,,,, -68300,3.747431,1.8213015,,,,,,,,,,,,,, -68400,4.316654,1.7589248,,,,,,,,,,,,,, -68500,3.6602478,1.7699841,,,,,,,,,,,,,, -68600,3.284846,1.7100292,,,,,,,,,,,,,, -68700,3.493615,1.7101787,,,,,,,,,,,,,, -68800,4.650037,1.8160017,,,,,,,,,,,,,, -68803,,,0.6715561151504517,1.3061954975128174,0.614579975605011,1.594861626625061,50000.0,0.4935000240802765,2.3209471702575684,10000.0,23496.12866091728,24321.4762878418,23496.12866091728,821.4244170188904,1.585910081863403,0.0 -68900,3.7726643,1.7312005,,,,,,,,,,,,,, -69000,4.2651463,1.8353062,,,,,,,,,,,,,, -69100,4.0461974,1.7687958,,,,,,,,,,,,,, -69200,3.7800162,1.8451164,,,,,,,,,,,,,, -69300,3.4581616,1.7757927,,,,,,,,,,,,,, -69400,3.73264,1.7946005,,,,,,,,,,,,,, -69500,5.9363074,1.8511269,,,,,,,,,,,,,, -69600,3.849583,1.7137762,,,,,,,,,,,,,, -69700,4.1040516,1.7397873,,,,,,,,,,,,,, -69800,4.096769,1.6618397,,,,,,,,,,,,,, -69900,3.8873057,1.8065698,,,,,,,,,,,,,, -70000,3.7581565,1.8016553,,,,,,,,,,,,,, -70100,3.9639816,1.8182639,,,,,,,,,,,,,, -70200,3.9772096,1.7603384,,,,,,,,,,,,,, -70300,4.8030863,1.7394121,,,,,,,,,,,,,, -70301,,,0.6846699714660645,1.244409203529358,0.6282599568367004,1.5239821672439575,50000.0,0.5094000101089478,2.2113044261932373,10000.0,24006.130012512207,24849.235661029816,24006.130012512207,839.0864970684052,1.6295363903045654,0.0 -70400,3.5671244,1.774756,,,,,,,,,,,,,, -70500,3.6009045,1.6860178,,,,,,,,,,,,,, -70600,3.9961903,1.788797,,,,,,,,,,,,,, -70700,3.9234054,1.863434,,,,,,,,,,,,,, -70800,4.0439825,1.8591536,,,,,,,,,,,,,, -70900,3.8161652,1.7148752,,,,,,,,,,,,,, -71000,4.704803,1.8563607,,,,,,,,,,,,,, -71100,3.4795444,1.6574614,,,,,,,,,,,,,, -71200,4.5293193,1.7508556,,,,,,,,,,,,,, -71300,4.458849,1.7646863,,,,,,,,,,,,,, -71400,3.7094905,1.7447827,,,,,,,,,,,,,, -71500,3.6278589,1.8601813,,,,,,,,,,,,,, -71600,4.9713,1.7152648,,,,,,,,,,,,,, -71700,4.03179,1.8320293,,,,,,,,,,,,,, -71800,,,0.6765784025192261,1.2757912874221802,0.6304799914360046,1.5133068561553955,50000.0,0.510200023651123,2.217399835586548,10000.0,24516.21312022209,25376.767111063004,24516.21312022209,856.4387171268463,1.6738028526306152,0.0 -71800,3.9714866,1.7049328,,,,,,,,,,,,,, -71900,3.9371862,1.6925344,,,,,,,,,,,,,, -72000,4.429511,1.7828616,,,,,,,,,,,,,, -72100,4.0826697,1.7354865,,,,,,,,,,,,,, -72200,3.260025,1.7076722,,,,,,,,,,,,,, -72300,4.162722,1.818775,,,,,,,,,,,,,, -72400,3.8560395,1.9223377,,,,,,,,,,,,,, -72500,4.059992,1.7823709,,,,,,,,,,,,,, -72600,3.8538933,1.7498755,,,,,,,,,,,,,, -72700,3.6270225,1.7861857,,,,,,,,,,,,,, -72800,3.9666357,1.7286417,,,,,,,,,,,,,, -72900,4.6560345,1.7849622,,,,,,,,,,,,,, -73000,4.165321,1.7454991,,,,,,,,,,,,,, -73100,3.7293615,1.8735472,,,,,,,,,,,,,, -73200,3.994293,1.8694352,,,,,,,,,,,,,, -73298,,,0.6873405575752258,1.2287601232528689,0.6338199973106384,1.491397738456726,50000.0,0.5108000040054321,2.190986156463623,10000.0,25026.155697584152,25904.17239308357,25026.155697584152,873.810319185257,1.7138514518737793,0.0 -73300,3.8271482,1.8960993,,,,,,,,,,,,,, -73400,4.1456947,1.6629552,,,,,,,,,,,,,, -73500,4.350477,1.8434669,,,,,,,,,,,,,, -73600,3.950012,1.6027668,,,,,,,,,,,,,, -73700,4.7800045,1.7190825,,,,,,,,,,,,,, -73800,3.5619059,1.7111795,,,,,,,,,,,,,, -73900,3.7989643,1.7851369,,,,,,,,,,,,,, -74000,3.6647446,1.7491584,,,,,,,,,,,,,, -74100,3.4590886,1.641322,,,,,,,,,,,,,, -74200,4.115981,1.6033088,,,,,,,,,,,,,, -74300,3.9790404,1.766601,,,,,,,,,,,,,, -74400,3.9620047,1.7129331,,,,,,,,,,,,,, -74500,4.271871,1.7056593,,,,,,,,,,,,,, -74600,3.811089,1.8381152,,,,,,,,,,,,,, -74700,4.4451227,1.7215339,,,,,,,,,,,,,, -74797,,,0.6872010231018066,1.2051925659179688,0.6161800026893616,1.587920069694519,50000.0,0.5003000497817993,2.303654909133911,10000.0,25536.339210748672,26431.880017757416,25536.339210748672,891.2437620162964,1.7546777725219729,0.0 -74800,4.0308924,1.8593352,,,,,,,,,,,,,, -74900,4.1444936,1.708038,,,,,,,,,,,,,, -75000,3.620585,1.7077978,,,,,,,,,,,,,, -75100,3.749766,1.7524937,,,,,,,,,,,,,, -75200,4.7292285,1.747026,,,,,,,,,,,,,, -75300,3.9194121,1.7887367,,,,,,,,,,,,,, -75400,3.8835468,1.763776,,,,,,,,,,,,,, -75500,3.846371,1.8100318,,,,,,,,,,,,,, -75600,3.7208073,1.8706408,,,,,,,,,,,,,, -75700,3.5679808,1.7629663,,,,,,,,,,,,,, -75800,4.6172075,1.764785,,,,,,,,,,,,,, -75900,3.7015512,1.7401162,,,,,,,,,,,,,, -76000,4.161664,1.7671144,,,,,,,,,,,,,, -76100,3.8226714,1.7255384,,,,,,,,,,,,,, -76200,3.4084678,1.725164,,,,,,,,,,,,,, -76296,,,0.6835737824440002,1.253710389137268,0.6217799782752991,1.5734446048736572,50000.0,0.4935000240802765,2.3314461708068848,10000.0,26046.54245686531,26959.465401887894,26046.54245686531,908.53577709198,1.7942428588867188,0.0 -76300,5.109128,1.7127494,,,,,,,,,,,,,, -76400,3.8319936,1.8284603,,,,,,,,,,,,,, -76500,4.1586084,1.6806818,,,,,,,,,,,,,, -76600,3.738774,1.754918,,,,,,,,,,,,,, -76700,3.917578,1.7949212,,,,,,,,,,,,,, -76800,3.7539444,1.6998448,,,,,,,,,,,,,, -76900,3.9352365,1.7986444,,,,,,,,,,,,,, -77000,4.141638,1.7547154,,,,,,,,,,,,,, -77100,3.554528,1.7891543,,,,,,,,,,,,,, -77200,3.9139874,1.845594,,,,,,,,,,,,,, -77300,4.2836914,1.748018,,,,,,,,,,,,,, -77400,4.2858477,1.754312,,,,,,,,,,,,,, -77500,3.5179918,1.7267027,,,,,,,,,,,,,, -77600,3.617353,1.663157,,,,,,,,,,,,,, -77700,4.4276476,1.8320427,,,,,,,,,,,,,, -77783,,,0.6978435516357422,1.2021422386169434,0.6412799954414368,1.4752492904663086,50000.0,0.5133000016212463,2.162278652191162,10000.0,26556.63892173767,27487.96911430359,26556.63892173767,926.8507552146912,1.83475923538208,0.0 -77800,4.057336,1.7975785,,,,,,,,,,,,,, -77900,4.1098332,1.8195391,,,,,,,,,,,,,, -78000,3.524752,1.7721741,,,,,,,,,,,,,, -78100,3.8731446,1.7263031,,,,,,,,,,,,,, -78200,3.7867138,1.6461477,,,,,,,,,,,,,, -78300,3.672285,1.6491452,,,,,,,,,,,,,, -78400,4.2434473,1.8133078,,,,,,,,,,,,,, -78500,3.7862453,1.7492757,,,,,,,,,,,,,, -78600,3.5754557,1.7553644,,,,,,,,,,,,,, -78700,3.929666,1.5487629,,,,,,,,,,,,,, -78800,4.447517,1.7647934,,,,,,,,,,,,,, -78900,4.113376,1.8094423,,,,,,,,,,,,,, -79000,4.1067696,1.7831354,,,,,,,,,,,,,, -79100,3.8519874,1.688813,,,,,,,,,,,,,, -79200,4.7468643,1.9482337,,,,,,,,,,,,,, -79282,,,0.6960698366165161,1.1828243732452393,0.6363999843597412,1.4793654680252075,50000.0,0.5126000046730042,2.2020113468170166,10000.0,27066.782158851624,28015.67524456978,27066.782158851624,944.3251445293428,1.8716328144073489,0.0 -79300,4.077689,1.8026905,,,,,,,,,,,,,, -79400,4.1521006,1.620614,,,,,,,,,,,,,, -79500,4.39633,1.7445731,,,,,,,,,,,,,, -79600,3.957355,1.658972,,,,,,,,,,,,,, -79700,3.8372269,1.613769,,,,,,,,,,,,,, -79800,4.561296,1.702955,,,,,,,,,,,,,, -79900,4.353876,1.6224337,,,,,,,,,,,,,, -80000,3.8904173,1.6917655,,,,,,,,,,,,,, -80100,3.8574438,1.8351719,,,,,,,,,,,,,, -80200,4.0037265,1.7383591,,,,,,,,,,,,,, -80300,3.9210687,1.7959194,,,,,,,,,,,,,, -80400,4.092198,1.6380358,,,,,,,,,,,,,, -80500,4.3509464,1.7578741,,,,,,,,,,,,,, -80600,3.9474869,1.7271852,,,,,,,,,,,,,, -80700,4.2877192,1.6654953,,,,,,,,,,,,,, -80781,,,0.6879384517669678,1.2187047004699707,0.6388999819755554,1.481725573539734,50000.0,0.5117000341415405,2.207427263259888,10000.0,27576.946058273315,28543.36324763298,27576.946058273315,961.7491703033448,1.9183545112609863,0.0 -80800,4.3798985,1.6146877,,,,,,,,,,,,,, -80900,4.6123805,1.7193414,,,,,,,,,,,,,, -81000,3.8280554,1.7414613,,,,,,,,,,,,,, -81100,4.113565,1.5633817,,,,,,,,,,,,,, -81200,4.365696,1.5977833,,,,,,,,,,,,,, -81300,3.9345179,1.7175481,,,,,,,,,,,,,, -81400,3.7849061,1.7783892,,,,,,,,,,,,,, -81500,3.6656075,1.879997,,,,,,,,,,,,,, -81600,4.0324388,1.747807,,,,,,,,,,,,,, -81700,4.4493685,1.6763284,,,,,,,,,,,,,, -81800,4.0351815,1.7778221,,,,,,,,,,,,,, -81900,3.912889,1.7346172,,,,,,,,,,,,,, -82000,4.1861033,1.7072277,,,,,,,,,,,,,, -82100,4.210862,1.7147509,,,,,,,,,,,,,, -82200,3.9428573,1.7214024,,,,,,,,,,,,,, -82280,,,0.6951131820678711,1.184423327445984,0.6406799554824829,1.4681005477905271,50000.0,0.5145000219345093,2.178096532821656,10000.0,28087.02109003067,29070.87357234955,28087.02109003067,979.0919954776764,1.960249662399292,0.0 -82300,3.823158,1.7275414,,,,,,,,,,,,,, -82400,3.939571,1.6705836,,,,,,,,,,,,,, -82500,3.8926482,1.7602637,,,,,,,,,,,,,, -82600,4.168802,1.6920571,,,,,,,,,,,,,, -82700,3.6957479,1.7665836,,,,,,,,,,,,,, -82800,5.5781827,1.603716,,,,,,,,,,,,,, -82900,4.6020627,1.7900002,,,,,,,,,,,,,, -83000,4.328263,1.7353609,,,,,,,,,,,,,, -83100,4.057532,1.6919975,,,,,,,,,,,,,, -83200,3.7972834,1.6128807,,,,,,,,,,,,,, -83300,4.3934016,1.8270116,,,,,,,,,,,,,, -83400,4.3878603,1.6723425,,,,,,,,,,,,,, -83500,4.150898,1.6595871,,,,,,,,,,,,,, -83600,3.9615927,1.7058183,,,,,,,,,,,,,, -83700,5.789331,1.88662,,,,,,,,,,,,,, -83779,,,0.7388990521430969,1.0035078525543213,0.6438999772071838,1.44945228099823,50000.0,0.5115000009536743,2.1780617237091064,10000.0,28596.984345436096,29598.28356528282,28596.984345436096,996.4454569816588,2.001596212387085,0.0 -83800,3.62053,1.8174046,,,,,,,,,,,,,, -83900,4.038114,1.7534983,,,,,,,,,,,,,, -84000,4.2602577,1.7914689,,,,,,,,,,,,,, -84100,4.3193817,1.7653372,,,,,,,,,,,,,, -84200,3.9410648,1.7282149,,,,,,,,,,,,,, -84300,4.298847,1.7379858,,,,,,,,,,,,,, -84400,4.0757523,1.7623739,,,,,,,,,,,,,, -84500,4.2061057,1.7952392,,,,,,,,,,,,,, -84600,4.05503,1.7564386,,,,,,,,,,,,,, -84700,4.095515,1.7745261,,,,,,,,,,,,,, -84800,3.9950879,1.6513801,,,,,,,,,,,,,, -84900,3.7024906,1.703642,,,,,,,,,,,,,, -85000,3.9255106,1.7276987,,,,,,,,,,,,,, -85100,3.8848526,1.7476484,,,,,,,,,,,,,, -85200,4.7319603,1.7468145,,,,,,,,,,,,,, -85277,,,0.7179328799247742,1.081685185432434,0.6489399671554565,1.4340031147003174,50000.0,0.522599995136261,2.130420923233032,10000.0,29107.010613918304,30126.32354569435,29107.010613918304,1014.3660583496094,2.0440704822540283,0.0 -85300,3.6678321,1.682308,,,,,,,,,,,,,, -85400,4.010569,1.6024427,,,,,,,,,,,,,, -85500,4.1419444,1.5874674,,,,,,,,,,,,,, -85600,4.510952,1.5989138,,,,,,,,,,,,,, -85700,4.2348804,1.6855707,,,,,,,,,,,,,, -85800,3.6886177,1.7031076,,,,,,,,,,,,,, -85900,3.9752162,1.7181334,,,,,,,,,,,,,, -86000,4.7478175,1.7205805,,,,,,,,,,,,,, -86100,3.9255178,1.757966,,,,,,,,,,,,,, -86200,4.0431066,1.770384,,,,,,,,,,,,,, -86300,4.2563753,1.5834652,,,,,,,,,,,,,, -86400,4.3878946,1.7193444,,,,,,,,,,,,,, -86500,4.0942,1.7144204,,,,,,,,,,,,,, -86600,4.304084,1.6871209,,,,,,,,,,,,,, -86700,4.425259,1.7952698,,,,,,,,,,,,,, -86776,,,0.7063137888908386,1.1358715295791626,0.6468200087547302,1.4428404569625854,50000.0,0.5206000208854675,2.138397455215454,10000.0,29617.09086871147,30653.98149752617,29617.09086871147,1031.8577575683594,2.080768346786499,0.0 -86800,4.5969515,1.6185353,,,,,,,,,,,,,, -86900,3.880128,1.7380948,,,,,,,,,,,,,, -87000,3.889102,1.6540838,,,,,,,,,,,,,, -87100,3.9656463,1.6465697,,,,,,,,,,,,,, -87200,4.221691,1.6135564,,,,,,,,,,,,,, -87300,4.4840007,1.5948886,,,,,,,,,,,,,, -87400,5.113938,1.7057306,,,,,,,,,,,,,, -87500,4.1557736,1.5976454,,,,,,,,,,,,,, -87600,4.140342,1.7036804,,,,,,,,,,,,,, -87700,5.0962043,1.6740593,,,,,,,,,,,,,, -87800,4.3051176,1.6090782,,,,,,,,,,,,,, -87900,4.04097,1.615168,,,,,,,,,,,,,, -88000,3.9859111,1.5540293,,,,,,,,,,,,,, -88100,3.9119868,1.625848,,,,,,,,,,,,,, -88200,4.247386,1.6678964,,,,,,,,,,,,,, -88275,,,0.7068120241165161,1.13710355758667,0.6526600122451782,1.4295099973678589,50000.0,0.5166000127792358,2.154176950454712,10000.0,30127.182915449142,31181.80454421044,30127.182915449142,1049.4971315860748,2.121476888656616,0.0 -88300,4.1080456,1.624501,,,,,,,,,,,,,, -88400,4.289387,1.7571218,,,,,,,,,,,,,, -88500,4.0814795,1.6936189,,,,,,,,,,,,,, -88600,4.411356,1.7422731,,,,,,,,,,,,,, -88700,4.301964,1.7036589,,,,,,,,,,,,,, -88800,4.33799,1.5092196,,,,,,,,,,,,,, -88900,4.435481,1.7211108,,,,,,,,,,,,,, -89000,4.23141,1.6065924,,,,,,,,,,,,,, -89100,3.706017,1.5990218,,,,,,,,,,,,,, -89200,3.9778547,1.6826383,,,,,,,,,,,,,, -89300,3.9664786,1.718559,,,,,,,,,,,,,, -89400,3.9639344,1.5816661,,,,,,,,,,,,,, -89500,4.4660454,1.5804021,,,,,,,,,,,,,, -89600,4.2790704,1.6350965,,,,,,,,,,,,,, -89700,3.8714445,1.5352491,,,,,,,,,,,,,, -89771,,,0.6945750713348389,1.1966772079467771,0.6418799757957458,1.463532567024231,50000.0,0.5118000507354736,2.200810432434082,10000.0,30636.21947956085,31709.34457540512,30636.21947956085,1066.8635828495026,3.207822322845459,0.0 -89800,5.005991,1.5897038,,,,,,,,,,,,,, -89900,5.1058683,1.7685441,,,,,,,,,,,,,, -90000,4.079558,1.6426543,,,,,,,,,,,,,, -90100,4.2062154,1.7425996,,,,,,,,,,,,,, -90200,4.1373425,1.6574104,,,,,,,,,,,,,, -90300,4.1462054,1.6721315,,,,,,,,,,,,,, -90400,4.606571,1.6527938,,,,,,,,,,,,,, -90500,4.1334486,1.6525648,,,,,,,,,,,,,, -90600,4.6950073,1.6410208,,,,,,,,,,,,,, -90700,3.95456,1.8208741,,,,,,,,,,,,,, -90800,4.4742584,1.7004429,,,,,,,,,,,,,, -90900,3.9408002,1.6136261,,,,,,,,,,,,,, -91000,4.1275234,1.7426624,,,,,,,,,,,,,, -91100,4.0175037,1.7004988,,,,,,,,,,,,,, -91200,3.8914843,1.6231285,,,,,,,,,,,,,, -91270,,,0.7092235088348389,1.134615778923035,0.6538000106811523,1.4101835489273071,50000.0,0.5267000198364258,2.1272940635681152,10000.0,31146.34582209587,32236.86432123184,31146.34582209587,1084.156126499176,3.255371332168579,0.0 -91300,3.7510583,1.5828385,,,,,,,,,,,,,, -91400,4.1905313,1.6234689,,,,,,,,,,,,,, -91500,3.7008243,1.5438199,,,,,,,,,,,,,, -91600,4.1948867,1.69053,,,,,,,,,,,,,, -91700,4.811162,1.6194,,,,,,,,,,,,,, -91800,4.1675954,1.7482768,,,,,,,,,,,,,, -91900,4.0885324,1.6459103,,,,,,,,,,,,,, -92000,4.1102133,1.6581572,,,,,,,,,,,,,, -92100,4.3149424,1.8353117,,,,,,,,,,,,,, -92200,3.9364893,1.6318066,,,,,,,,,,,,,, -92300,4.32121,1.5239513,,,,,,,,,,,,,, -92400,4.749727,1.6387417,,,,,,,,,,,,,, -92500,3.8494468,1.6316,,,,,,,,,,,,,, -92600,3.8250616,1.6602712,,,,,,,,,,,,,, -92700,4.019221,1.5658337,,,,,,,,,,,,,, -92768,,,0.7113958597183228,1.1191548109054563,0.6482200026512146,1.4321719408035278,50000.0,0.5174000263214111,2.1670353412628174,10000.0,31656.25477242469,32764.192311525345,31656.25477242469,1101.4758217334747,3.3038907051086426,0.0 -92800,4.3324246,1.6823335,,,,,,,,,,,,,, -92900,4.9146976,1.602892,,,,,,,,,,,,,, -93000,4.353906,1.6577203,,,,,,,,,,,,,, -93100,4.5886264,1.6795325,,,,,,,,,,,,,, -93200,4.1848054,1.6263328,,,,,,,,,,,,,, -93300,4.4152026,1.6465684,,,,,,,,,,,,,, -93400,3.9757688,1.5126255,,,,,,,,,,,,,, -93500,4.032372,1.705944,,,,,,,,,,,,,, -93600,4.2682815,1.5750278,,,,,,,,,,,,,, -93700,4.3071375,1.6910932,,,,,,,,,,,,,, -93800,4.5734043,1.740466,,,,,,,,,,,,,, -93900,3.8450546,1.5585515,,,,,,,,,,,,,, -94000,4.750589,1.6503887,,,,,,,,,,,,,, -94100,3.9320045,1.6208854,,,,,,,,,,,,,, -94200,4.8938494,1.6927342,,,,,,,,,,,,,, -94267,,,0.7361487150192261,1.0138471126556396,0.659500002861023,1.3887938261032104,50000.0,0.5303000211715698,2.0873241424560547,10000.0,32166.24189376831,33291.75237035751,32166.24189376831,1118.9492797851562,3.351661443710327,0.0 -94300,4.2597795,1.6616673,,,,,,,,,,,,,, -94400,4.153965,1.5348617,,,,,,,,,,,,,, -94500,4.093407,1.7343925,,,,,,,,,,,,,, -94600,4.223356,1.6566575,,,,,,,,,,,,,, -94700,4.6714745,1.5776957,,,,,,,,,,,,,, -94800,4.207379,1.6408175,,,,,,,,,,,,,, -94900,4.022901,1.6462585,,,,,,,,,,,,,, -95000,4.4060373,1.6491486,,,,,,,,,,,,,, -95100,4.3341393,1.6012421,,,,,,,,,,,,,, -95200,4.2303514,1.569241,,,,,,,,,,,,,, -95300,4.4321523,1.5438386,,,,,,,,,,,,,, -95400,4.379081,1.6802851,,,,,,,,,,,,,, -95500,4.561418,1.7779357,,,,,,,,,,,,,, -95600,4.7779403,1.6266866,,,,,,,,,,,,,, -95700,4.1524024,1.6097232,,,,,,,,,,,,,, -95766,,,0.7229153513908386,1.0641008615493774,0.6566199660301208,1.4057345390319824,50000.0,0.5320000052452087,2.1097888946533203,10000.0,32676.229960918427,33819.02078318596,32676.229960918427,1136.129715681076,3.401322364807129,0.0 -95800,4.375744,1.6014748,,,,,,,,,,,,,, -95900,4.4631042,1.725224,,,,,,,,,,,,,, -96000,4.6760545,1.6620094,,,,,,,,,,,,,, -96100,4.874606,1.5380673,,,,,,,,,,,,,, -96200,4.4095287,1.6922858,,,,,,,,,,,,,, -96300,4.5144377,1.6211637,,,,,,,,,,,,,, -96400,4.1710386,1.547193,,,,,,,,,,,,,, -96500,4.763638,1.6286796,,,,,,,,,,,,,, -96600,4.1042542,1.5974344,,,,,,,,,,,,,, -96700,4.394446,1.6022133,,,,,,,,,,,,,, -96800,4.0074434,1.5715907,,,,,,,,,,,,,, -96900,4.5982027,1.668265,,,,,,,,,,,,,, -97000,4.7533674,1.606149,,,,,,,,,,,,,, -97100,3.9375358,1.6459309,,,,,,,,,,,,,, -97200,4.62574,1.7521658,,,,,,,,,,,,,, -97265,,,0.7237324714660645,1.0708374977111816,0.6580399870872498,1.3982630968093872,50000.0,0.5344000458717346,2.089726209640503,10000.0,33186.39605593681,34346.4784386158,33186.39605593681,1153.3248386383057,3.446504831314087,0.0 -97300,4.190321,1.5906528,,,,,,,,,,,,,, -97400,4.394496,1.5751878,,,,,,,,,,,,,, -97500,4.112526,1.5674953,,,,,,,,,,,,,, -97600,4.0460496,1.5673932,,,,,,,,,,,,,, -97700,4.360734,1.6037489,,,,,,,,,,,,,, -97800,4.9554996,1.695405,,,,,,,,,,,,,, -97900,4.350665,1.606262,,,,,,,,,,,,,, -98000,4.3545012,1.6613909,,,,,,,,,,,,,, -98100,4.2286134,1.6244602,,,,,,,,,,,,,, -98200,5.1777196,1.535479,,,,,,,,,,,,,, -98300,4.318589,1.5660126,,,,,,,,,,,,,, -98400,4.7350397,1.7124467,,,,,,,,,,,,,, -98500,4.1096888,1.5513289,,,,,,,,,,,,,, -98600,4.351834,1.6128169,,,,,,,,,,,,,, -98700,4.6084356,1.6794083,,,,,,,,,,,,,, -98765,,,0.7150430083274841,1.1016902923583984,0.6593799591064453,1.3938559293746948,50000.0,0.536300003528595,2.0971498489379883,10000.0,33696.623683452606,34874.23455262184,33696.623683452606,1170.7596073150637,3.4904332160949707,0.0 -98800,5.121598,1.598809,,,,,,,,,,,,,, -98900,4.2707505,1.5796463,,,,,,,,,,,,,, -99000,3.8376846,1.6382314,,,,,,,,,,,,,, -99100,4.2862043,1.4484199,,,,,,,,,,,,,, -99200,4.480778,1.6860018,,,,,,,,,,,,,, -99300,4.963086,1.6121302,,,,,,,,,,,,,, -99400,5.0383425,1.539049,,,,,,,,,,,,,, -99500,4.866221,1.6351991,,,,,,,,,,,,,, -99600,4.9276533,1.6237943,,,,,,,,,,,,,, -99700,4.4377174,1.5443442,,,,,,,,,,,,,, -99800,5.1199484,1.4647835,,,,,,,,,,,,,, -99900,4.1048646,1.6963652,,,,,,,,,,,,,, -100000,4.0444922,1.526609,,,,,,,,,,,,,, -100100,4.57806,1.612848,,,,,,,,,,,,,, -100200,4.3678007,1.4263093,,,,,,,,,,,,,, -100264,,,0.713309109210968,1.1122252941131592,0.6600399613380432,1.3887194395065308,50000.0,0.5412000417709351,2.056500196456909,10000.0,34206.6414039135,35401.838272333145,34206.6414039135,1188.2499401569366,3.535299777984619,0.0 -100300,4.337639,1.6962783,,,,,,,,,,,,,, -100400,4.4591336,1.5641363,,,,,,,,,,,,,, -100500,4.1372905,1.6226062,,,,,,,,,,,,,, -100600,4.2847605,1.5071934,,,,,,,,,,,,,, -100700,4.7093296,1.6803092,,,,,,,,,,,,,, -100800,4.0104995,1.426338,,,,,,,,,,,,,, -100900,4.51483,1.6860764,,,,,,,,,,,,,, -101000,4.674889,1.7786319,,,,,,,,,,,,,, -101100,4.1519046,1.4845273,,,,,,,,,,,,,, -101200,4.76819,1.555244,,,,,,,,,,,,,, -101300,4.2695527,1.5314139,,,,,,,,,,,,,, -101400,4.295862,1.598673,,,,,,,,,,,,,, -101500,4.479156,1.6127765,,,,,,,,,,,,,, -101600,4.43846,1.554949,,,,,,,,,,,,,, -101700,4.336948,1.4641054,,,,,,,,,,,,,, -101764,,,0.7233139276504517,1.072434902191162,0.6643799543380737,1.362235426902771,50000.0,0.5402000546455383,2.057695150375366,10000.0,34716.84902739525,35929.189604759216,34716.84902739525,1205.2929441928864,3.586189031600952,0.0 -101800,4.2437744,1.5222394,,,,,,,,,,,,,, -101900,5.0705175,1.7210104,,,,,,,,,,,,,, -102000,4.7958565,1.6182469,,,,,,,,,,,,,, -102100,4.493094,1.5877625,,,,,,,,,,,,,, -102200,4.0330243,1.514575,,,,,,,,,,,,,, -102300,4.313003,1.4491876,,,,,,,,,,,,,, -102400,4.1279182,1.524013,,,,,,,,,,,,,, -102500,4.3288093,1.631715,,,,,,,,,,,,,, -102600,4.4524307,1.6277126,,,,,,,,,,,,,, -102700,4.62236,1.4496164,,,,,,,,,,,,,, -102800,4.9150906,1.5387781,,,,,,,,,,,,,, -102900,4.544458,1.6181529,,,,,,,,,,,,,, -103000,4.6238284,1.6545837,,,,,,,,,,,,,, -103100,4.791272,1.5660324,,,,,,,,,,,,,, -103200,4.877057,1.6569241,,,,,,,,,,,,,, -103263,,,0.75394606590271,0.9281212091445924,0.6668999791145325,1.3567508459091189,50000.0,0.5405000448226929,2.041477918624878,10000.0,35226.76603245735,36456.52231359482,35226.76603245735,1222.604502916336,3.63908052444458,0.0 -103300,5.028842,1.6248715,,,,,,,,,,,,,, -103400,4.459053,1.5164083,,,,,,,,,,,,,, -103500,4.7325163,1.5665207,,,,,,,,,,,,,, -103600,5.042705,1.6941965,,,,,,,,,,,,,, -103700,4.4637923,1.5785611,,,,,,,,,,,,,, -103800,4.786386,1.5320168,,,,,,,,,,,,,, -103900,4.3626404,1.4683319,,,,,,,,,,,,,, -104000,4.8075566,1.6212509,,,,,,,,,,,,,, -104100,4.6950593,1.6699598,,,,,,,,,,,,,, -104200,4.0882826,1.4942008,,,,,,,,,,,,,, -104300,4.99594,1.5168079,,,,,,,,,,,,,, -104400,4.6741014,1.5184568,,,,,,,,,,,,,, -104500,4.477258,1.6020215,,,,,,,,,,,,,, -104600,4.577296,1.6060846,,,,,,,,,,,,,, -104700,4.4333706,1.6194555,,,,,,,,,,,,,, -104762,,,0.7365872263908386,1.0082699060440063,0.6635199785232544,1.3672338724136353,50000.0,0.5410000085830688,2.045719623565674,10000.0,35736.759813547134,36984.04020619392,35736.759813547134,1240.0294904708862,3.6891348361968994,0.0 -104800,4.6479053,1.640948,,,,,,,,,,,,,, -104900,4.979074,1.4990884,,,,,,,,,,,,,, -105000,4.327565,1.5558829,,,,,,,,,,,,,, -105100,4.808986,1.5581492,,,,,,,,,,,,,, -105200,4.3851337,1.4705753,,,,,,,,,,,,,, -105300,4.4887233,1.5826837,,,,,,,,,,,,,, -105400,5.5484357,1.6374285,,,,,,,,,,,,,, -105500,4.7233386,1.5742729,,,,,,,,,,,,,, -105600,4.8053226,1.5291793,,,,,,,,,,,,,, -105700,5.0086985,1.7039112,,,,,,,,,,,,,, -105800,5.149924,1.5609912,,,,,,,,,,,,,, -105900,4.39104,1.5501347,,,,,,,,,,,,,, -106000,5.3889694,1.5276808,,,,,,,,,,,,,, -106100,4.38166,1.5496597,,,,,,,,,,,,,, -106200,4.8723993,1.6489687,,,,,,,,,,,,,, -106261,,,0.7372249364852905,0.999062955379486,0.6670599579811096,1.3503457307815552,50000.0,0.5432000160217285,2.066902637481689,10000.0,36246.74147129059,37511.61378097534,36246.74147129059,1257.5226662158966,3.7353804111480713,0.0 -106300,4.700468,1.6406808,,,,,,,,,,,,,, -106400,5.297009,1.6280012,,,,,,,,,,,,,, -106500,5.6985664,1.5525305,,,,,,,,,,,,,, -106600,5.443657,1.6540405,,,,,,,,,,,,,, -106700,4.0553117,1.4446079,,,,,,,,,,,,,, -106800,5.243873,1.5454897,,,,,,,,,,,,,, -106900,5.143489,1.484785,,,,,,,,,,,,,, -107000,4.386395,1.6582392,,,,,,,,,,,,,, -107100,4.978689,1.6120551,,,,,,,,,,,,,, -107200,4.633291,1.5070266,,,,,,,,,,,,,, -107300,5.1796412,1.620894,,,,,,,,,,,,,, -107400,4.743374,1.5701733,,,,,,,,,,,,,, -107500,4.557388,1.4966959,,,,,,,,,,,,,, -107600,5.083144,1.684386,,,,,,,,,,,,,, -107700,4.952808,1.4681916,,,,,,,,,,,,,, -107760,,,0.7306082248687744,1.0350559949874878,0.6622999906539917,1.3632971048355105,50000.0,0.5291000008583069,2.1044600009918213,10000.0,36756.6770863533,38039.03833150864,36756.6770863533,1274.914494514465,3.7825989723205566,0.0 -107800,4.652146,1.5283949,,,,,,,,,,,,,, -107900,4.559628,1.5056394,,,,,,,,,,,,,, -108000,4.984641,1.6827502,,,,,,,,,,,,,, -108100,4.6310716,1.4645277,,,,,,,,,,,,,, -108200,5.4321785,1.5284495,,,,,,,,,,,,,, -108300,5.044913,1.503396,,,,,,,,,,,,,, -108400,5.2694454,1.5990527,,,,,,,,,,,,,, -108500,4.743456,1.6305882,,,,,,,,,,,,,, -108600,4.4218273,1.489075,,,,,,,,,,,,,, -108700,4.782585,1.5432315,,,,,,,,,,,,,, -108800,5.315977,1.5189208,,,,,,,,,,,,,, -108900,5.8692446,1.6071979,,,,,,,,,,,,,, -109000,5.2206187,1.5434983,,,,,,,,,,,,,, -109100,4.9158125,1.5330552,,,,,,,,,,,,,, -109200,4.336618,1.4546335,,,,,,,,,,,,,, -109259,,,0.7350525856018066,1.0115010738372805,0.6735000014305115,1.3214294910430908,50000.0,0.549500048160553,2.004624605178833,10000.0,37266.8054394722,38566.80239248276,37266.8054394722,1292.4509418010712,3.830709218978882,0.0 -109300,4.6942315,1.4273957,,,,,,,,,,,,,, -109400,5.025783,1.4857154,,,,,,,,,,,,,, -109500,4.35834,1.5659946,,,,,,,,,,,,,, -109600,5.450424,1.5691293,,,,,,,,,,,,,, -109700,4.492526,1.4516658,,,,,,,,,,,,,, -109800,5.2009873,1.5308043,,,,,,,,,,,,,, -109900,4.777786,1.5510424,,,,,,,,,,,,,, -110000,4.7768393,1.5592053,,,,,,,,,,,,,, -110100,4.52097,1.5441443,,,,,,,,,,,,,, -110200,4.595527,1.5327039,,,,,,,,,,,,,, -110300,5.2253637,1.5874363,,,,,,,,,,,,,, -110400,6.454836,1.5970982,,,,,,,,,,,,,, -110500,5.383016,1.5635194,,,,,,,,,,,,,, -110600,4.5178437,1.4577867,,,,,,,,,,,,,, -110700,5.0193434,1.5734673,,,,,,,,,,,,,, -110758,,,0.7382214665412903,0.9993380308151244,0.6735000014305115,1.3205397129058838,50000.0,0.5402000546455383,2.009732484817505,10000.0,37776.72888278961,39094.371799230576,37776.72888278961,1309.994685173035,3.881550550460816,0.0 -110800,5.0963902,1.5067449,,,,,,,,,,,,,, -110900,5.288843,1.4960421,,,,,,,,,,,,,, -111000,5.5721965,1.637466,,,,,,,,,,,,,, -111100,5.220701,1.493069,,,,,,,,,,,,,, -111200,4.8119426,1.4803436,,,,,,,,,,,,,, -111300,4.7537775,1.6667823,,,,,,,,,,,,,, -111400,4.8594575,1.519547,,,,,,,,,,,,,, -111500,4.444563,1.4398434,,,,,,,,,,,,,, -111600,4.8584294,1.4972442,,,,,,,,,,,,,, -111700,4.9178576,1.5301514,,,,,,,,,,,,,, -111800,5.2194214,1.4759486,,,,,,,,,,,,,, -111900,5.3528886,1.4548426,,,,,,,,,,,,,, -112000,4.595026,1.5106533,,,,,,,,,,,,,, -112100,4.996021,1.5081935,,,,,,,,,,,,,, -112200,5.219018,1.4266704,,,,,,,,,,,,,, -112257,,,0.7806122303009033,0.819338858127594,0.6816200017929077,1.2890325784683228,50000.0,0.5503000020980835,1.996065974235535,10000.0,38286.65507602692,39621.84500050545,38286.65507602692,1327.4420084953308,3.929419755935669,0.0 -112300,5.568893,1.5063674,,,,,,,,,,,,,, -112400,4.790523,1.5226629,,,,,,,,,,,,,, -112500,5.5616875,1.5195708,,,,,,,,,,,,,, -112600,4.8073535,1.5521089,,,,,,,,,,,,,, -112700,4.79168,1.4987004,,,,,,,,,,,,,, -112800,4.7871895,1.5011578,,,,,,,,,,,,,, -112900,5.187003,1.4588445,,,,,,,,,,,,,, -113000,5.0599375,1.4405823,,,,,,,,,,,,,, -113100,6.1784377,1.5086093,,,,,,,,,,,,,, -113200,4.797812,1.4176773,,,,,,,,,,,,,, -113300,4.447545,1.38663,,,,,,,,,,,,,, -113400,5.223977,1.518999,,,,,,,,,,,,,, -113500,4.9932985,1.5176189,,,,,,,,,,,,,, -113600,5.159762,1.4335886,,,,,,,,,,,,,, -113700,5.080688,1.4066163,,,,,,,,,,,,,, -113756,,,0.7663623690605164,0.8877369165420532,0.6804400086402893,1.2903692722320557,50000.0,0.5544000267982483,1.976928472518921,10000.0,38796.73586535454,40149.694214344025,38796.73586535454,1345.113474369049,3.976061820983887,0.0 -113800,5.1165576,1.4877739,,,,,,,,,,,,,, -113900,5.14648,1.4668008,,,,,,,,,,,,,, -114000,4.9368324,1.5420276,,,,,,,,,,,,,, -114100,5.3090854,1.4910004,,,,,,,,,,,,,, -114200,5.0031834,1.444756,,,,,,,,,,,,,, -114300,4.8340235,1.4871209,,,,,,,,,,,,,, -114400,5.4052014,1.4649025,,,,,,,,,,,,,, -114500,4.601928,1.504014,,,,,,,,,,,,,, -114600,4.7542796,1.6000674,,,,,,,,,,,,,, -114700,5.2442274,1.5496411,,,,,,,,,,,,,, -114800,4.7843275,1.4997857,,,,,,,,,,,,,, -114900,5.0185657,1.5136498,,,,,,,,,,,,,, -115000,5.075167,1.5594621,,,,,,,,,,,,,, -115100,5.839085,1.3936274,,,,,,,,,,,,,, -115200,5.8751826,1.450865,,,,,,,,,,,,,, -115255,,,0.7538464665412903,0.9340843558311462,0.6775199770927429,1.3047815561294556,50000.0,0.5565000176429749,1.98634934425354,10000.0,39306.689522743225,40677.29035973549,39306.689522743225,1362.6518981456757,4.028971195220947,0.0 -115300,4.785165,1.4652313,,,,,,,,,,,,,, -115400,5.0174727,1.4171001,,,,,,,,,,,,,, -115500,4.517855,1.386194,,,,,,,,,,,,,, -115600,4.8908076,1.4943093,,,,,,,,,,,,,, -115700,5.2916784,1.36477,,,,,,,,,,,,,, -115800,4.430636,1.4701309,,,,,,,,,,,,,, -115900,4.4524217,1.3656102,,,,,,,,,,,,,, -116000,5.994136,1.4886847,,,,,,,,,,,,,, -116100,5.4700117,1.4979951,,,,,,,,,,,,,, -116200,5.6097045,1.4395164,,,,,,,,,,,,,, -116300,5.050853,1.4968457,,,,,,,,,,,,,, -116400,4.5830526,1.4042734,,,,,,,,,,,,,, -116500,5.2240195,1.4219468,,,,,,,,,,,,,, -116600,5.6036406,1.5281686,,,,,,,,,,,,,, -116700,4.920284,1.4221716,,,,,,,,,,,,,, -116754,,,0.7443000674247742,0.957797348499298,0.6729399561882019,1.3240751028060913,50000.0,0.5436000227928162,2.046970129013061,10000.0,39816.85724949837,41204.86368608475,39816.85724949837,1379.9599359035492,4.075598001480103,0.0 -116800,5.033793,1.4635885,,,,,,,,,,,,,, -116900,5.5232606,1.503924,,,,,,,,,,,,,, -117000,5.3488183,1.3567959,,,,,,,,,,,,,, -117100,5.8438196,1.409112,,,,,,,,,,,,,, -117200,4.973286,1.379517,,,,,,,,,,,,,, -117300,5.076929,1.385741,,,,,,,,,,,,,, -117400,4.755605,1.4569243,,,,,,,,,,,,,, -117500,4.9051433,1.4583097,,,,,,,,,,,,,, -117600,5.254992,1.5307921,,,,,,,,,,,,,, -117700,5.163012,1.480778,,,,,,,,,,,,,, -117800,5.0843678,1.4630195,,,,,,,,,,,,,, -117900,4.9555836,1.4435402,,,,,,,,,,,,,, -118000,5.4791327,1.4028922,,,,,,,,,,,,,, -118100,4.7222195,1.3737485,,,,,,,,,,,,,, -118200,4.9723306,1.4497173,,,,,,,,,,,,,, -118253,,,0.7626155614852905,0.9064222574234008,0.6890999674797058,1.2482517957687378,50000.0,0.5649000406265259,1.9236938953399656,10000.0,40326.82902789116,41732.53103065491,40326.82902789116,1397.556854724884,4.121983051300049,0.0 -118300,4.586151,1.5000523,,,,,,,,,,,,,, -118400,5.205132,1.4633163,,,,,,,,,,,,,, -118500,4.955036,1.3644893,,,,,,,,,,,,,, -118600,5.513144,1.4705553,,,,,,,,,,,,,, -118700,5.328004,1.4126551,,,,,,,,,,,,,, -118800,5.20458,1.4892884,,,,,,,,,,,,,, -118900,5.590018,1.5241762,,,,,,,,,,,,,, -119000,5.2363586,1.4424615,,,,,,,,,,,,,, -119100,5.374686,1.4728003,,,,,,,,,,,,,, -119200,5.239765,1.4270701,,,,,,,,,,,,,, -119300,5.1564803,1.5091383,,,,,,,,,,,,,, -119400,5.766445,1.5467502,,,,,,,,,,,,,, -119500,5.6512084,1.4483826,,,,,,,,,,,,,, -119600,4.8844233,1.3411287,,,,,,,,,,,,,, -119700,5.076148,1.3665214,,,,,,,,,,,,,, -119752,,,0.7578921914100647,0.9132976531982422,0.6890599727630615,1.2581051588058472,50000.0,0.5614000558853149,1.951021194458008,10000.0,40836.754336595535,42259.80155444145,40836.754336595535,1414.801174402237,4.171825647354126,0.0 -119800,5.4371033,1.4985738,,,,,,,,,,,,,, -119900,5.400019,1.4784089,,,,,,,,,,,,,, -120000,5.5108047,1.4168793,,,,,,,,,,,,,, -120100,5.2565374,1.4124113,,,,,,,,,,,,,, -120200,5.3159122,1.4122977,,,,,,,,,,,,,, -120300,5.661841,1.4267869,,,,,,,,,,,,,, -120400,5.8956866,1.546046,,,,,,,,,,,,,, -120500,5.0060377,1.5500921,,,,,,,,,,,,,, -120600,5.2657027,1.4924841,,,,,,,,,,,,,, -120700,5.0933876,1.3919754,,,,,,,,,,,,,, -120800,4.804447,1.3959324,,,,,,,,,,,,,, -120900,5.414472,1.3770361,,,,,,,,,,,,,, -121000,5.183382,1.3872417,,,,,,,,,,,,,, -121100,5.1505337,1.356774,,,,,,,,,,,,,, -121200,4.849283,1.3701681,,,,,,,,,,,,,, -121251,,,0.7889827489852905,0.7927571535110474,0.6894400119781494,1.2526060342788696,50000.0,0.5688000321388245,1.932717442512512,10000.0,41346.87264943123,42787.65633249283,41346.87264943123,1432.4339890480042,4.225170850753784,0.0 -121300,5.0679393,1.3239534,,,,,,,,,,,,,, -121400,5.9228745,1.5501885,,,,,,,,,,,,,, -121500,5.6375475,1.4191134,,,,,,,,,,,,,, -121600,5.1571336,1.4044417,,,,,,,,,,,,,, -121700,5.1205473,1.5091041,,,,,,,,,,,,,, -121800,5.4660673,1.5152774,,,,,,,,,,,,,, -121900,5.488563,1.4723252,,,,,,,,,,,,,, -122000,5.6284776,1.443027,,,,,,,,,,,,,, -122100,6.0202656,1.3472465,,,,,,,,,,,,,, -122200,6.1737685,1.3501213,,,,,,,,,,,,,, -122300,5.1747375,1.4235198,,,,,,,,,,,,,, -122400,4.7673864,1.3843236,,,,,,,,,,,,,, -122500,5.9471703,1.4925618,,,,,,,,,,,,,, -122600,5.9265614,1.4267304,,,,,,,,,,,,,, -122700,5.5520573,1.4489043,,,,,,,,,,,,,, -122751,,,0.778738796710968,0.8235031962394714,0.6881600022315979,1.264472484588623,50000.0,0.5613000392913818,1.9879837036132808,10000.0,41857.08833384514,43315.46732521057,41857.08833384514,1449.9362061023712,4.2682952880859375,0.0 -122800,5.3956194,1.3707083,,,,,,,,,,,,,, -122900,5.01776,1.4017059,,,,,,,,,,,,,, -123000,5.717127,1.4385769,,,,,,,,,,,,,, -123100,5.598775,1.4677454,,,,,,,,,,,,,, -123200,5.616306,1.3454362,,,,,,,,,,,,,, -123300,5.2494445,1.289243,,,,,,,,,,,,,, -123400,5.2591705,1.4314871,,,,,,,,,,,,,, -123500,5.6184273,1.3720703,,,,,,,,,,,,,, -123600,5.7435913,1.3999225,,,,,,,,,,,,,, -123700,5.409554,1.4715358,,,,,,,,,,,,,, -123800,5.325403,1.4098611,,,,,,,,,,,,,, -123900,5.731843,1.4102148,,,,,,,,,,,,,, -124000,5.6999574,1.421094,,,,,,,,,,,,,, -124100,5.9995985,1.4475446,,,,,,,,,,,,,, -124200,5.9438596,1.3464086,,,,,,,,,,,,,, -124250,,,0.782246470451355,0.8037508726119995,0.6977399587631226,1.2117669582366943,50000.0,0.5701000094413757,1.924551010131836,10000.0,42367.10335683823,43843.5637190342,42367.10335683823,1467.9139926433563,4.320462465286255,0.0 -124300,4.9022465,1.3070343,,,,,,,,,,,,,, -124400,5.1221905,1.3643154,,,,,,,,,,,,,, -124500,5.4191833,1.3523151,,,,,,,,,,,,,, -124600,5.8094473,1.3428053,,,,,,,,,,,,,, -124700,5.665876,1.3204515,,,,,,,,,,,,,, -124800,5.8588085,1.2624986,,,,,,,,,,,,,, -124900,5.6584806,1.3447795,,,,,,,,,,,,,, -125000,5.3734655,1.3069441,,,,,,,,,,,,,, -125100,5.8329477,1.4227713,,,,,,,,,,,,,, -125200,5.7819147,1.4084729,,,,,,,,,,,,,, -125300,5.66631,1.2914114,,,,,,,,,,,,,, -125400,6.754126,1.4584069,,,,,,,,,,,,,, -125500,5.666439,1.4990585,,,,,,,,,,,,,, -125600,5.6795807,1.4131023,,,,,,,,,,,,,, -125700,5.4375553,1.2545762,,,,,,,,,,,,,, -125749,,,0.7780413031578064,0.8233484625816345,0.6994400024414062,1.2192716598510742,50000.0,0.5735000371932983,1.909732699394226,10000.0,42877.30141162872,44371.0747590065,42877.30141162872,1485.1218111515043,4.3753581047058105,0.0 -125800,5.6126485,1.4698379,,,,,,,,,,,,,, -125900,5.7350016,1.3886232,,,,,,,,,,,,,, -126000,5.657561,1.4746104,,,,,,,,,,,,,, -126100,6.0945945,1.3179259,,,,,,,,,,,,,, -126200,5.301483,1.3698616,,,,,,,,,,,,,, -126300,6.501175,1.4611678,,,,,,,,,,,,,, -126400,5.4723334,1.3561956,,,,,,,,,,,,,, -126500,5.2277455,1.2545223,,,,,,,,,,,,,, -126600,5.8637395,1.4147995,,,,,,,,,,,,,, -126700,5.65616,1.3681678,,,,,,,,,,,,,, -126800,5.7467966,1.3891633,,,,,,,,,,,,,, -126900,6.0074115,1.4866204,,,,,,,,,,,,,, -127000,5.239055,1.2945508,,,,,,,,,,,,,, -127100,6.452236,1.427896,,,,,,,,,,,,,, -127200,6.3896575,1.3555787,,,,,,,,,,,,,, -127248,,,0.7848173975944519,0.798477292060852,0.6995799541473389,1.2017760276794434,50000.0,0.5724000334739685,1.918750882148743,10000.0,43387.38800287247,44898.51348400116,43387.38800287247,1502.3710179328918,4.427471876144409,0.0 -127300,5.929318,1.3795524,,,,,,,,,,,,,, -127400,5.374868,1.319922,,,,,,,,,,,,,, -127500,5.5383873,1.1969783,,,,,,,,,,,,,, -127600,5.6738434,1.4086131,,,,,,,,,,,,,, -127700,5.744748,1.4131082,,,,,,,,,,,,,, -127800,5.471384,1.4731059,,,,,,,,,,,,,, -127900,5.9484353,1.2779152,,,,,,,,,,,,,, -128000,6.0914392,1.4186798,,,,,,,,,,,,,, -128100,5.4910703,1.405093,,,,,,,,,,,,,, -128200,5.7669263,1.2954057,,,,,,,,,,,,,, -128300,5.5791826,1.39797,,,,,,,,,,,,,, -128400,5.909684,1.2744784,,,,,,,,,,,,,, -128500,5.3607454,1.3241415,,,,,,,,,,,,,, -128600,5.9831653,1.3724687,,,,,,,,,,,,,, -128700,5.502111,1.2573011,,,,,,,,,,,,,, -128747,,,0.7816087007522583,0.8127210736274719,0.7039200067520142,1.1963602304458618,50000.0,0.5781000256538391,1.8842612504959104,10000.0,43897.49772691727,45426.061324596405,43897.49772691727,1519.7068076133728,4.479499578475952,0.0 -128800,5.567416,1.3388776,,,,,,,,,,,,,, -128900,5.342006,1.2465887,,,,,,,,,,,,,, -129000,6.1690784,1.2757732,,,,,,,,,,,,,, -129100,5.5300584,1.4162127,,,,,,,,,,,,,, -129200,5.698071,1.3169879,,,,,,,,,,,,,, -129300,5.644675,1.3462504,,,,,,,,,,,,,, -129400,5.796867,1.3908774,,,,,,,,,,,,,, -129500,6.0692997,1.302797,,,,,,,,,,,,,, -129600,5.8542027,1.4127097,,,,,,,,,,,,,, -129700,5.535502,1.2446499,,,,,,,,,,,,,, -129800,6.405326,1.3965261,,,,,,,,,,,,,, -129900,5.2151384,1.251789,,,,,,,,,,,,,, -130000,6.3064485,1.3702493,,,,,,,,,,,,,, -130100,6.0593286,1.4014773,,,,,,,,,,,,,, -130200,6.201459,1.2598839,,,,,,,,,,,,,, -130246,,,0.7840800285339355,0.8188005685806274,0.7031799554824829,1.1907312870025637,50000.0,0.5733000040054321,1.893583297729492,10000.0,44407.48382616043,45953.42463064194,44407.48382616043,1536.9744882583618,4.538180828094482,0.0 -130300,6.073007,1.2437607,,,,,,,,,,,,,, -130400,5.9699163,1.3459319,,,,,,,,,,,,,, -130500,5.3199463,1.2890476,,,,,,,,,,,,,, -130600,4.8967633,1.2575341,,,,,,,,,,,,,, -130700,5.4976506,1.3796793,,,,,,,,,,,,,, -130800,6.2686315,1.3983393,,,,,,,,,,,,,, -130900,5.8391914,1.3780271,,,,,,,,,,,,,, -131000,5.772426,1.2498782,,,,,,,,,,,,,, -131100,6.7074556,1.28024,,,,,,,,,,,,,, -131200,5.772223,1.2826285,,,,,,,,,,,,,, -131300,5.5643816,1.2830877,,,,,,,,,,,,,, -131400,5.9190903,1.308939,,,,,,,,,,,,,, -131500,5.774735,1.2140032,,,,,,,,,,,,,, -131600,5.3462973,1.2123195,,,,,,,,,,,,,, -131700,6.768967,1.4244114,,,,,,,,,,,,,, -131745,,,0.8127192258834839,0.6832287907600403,0.7074599862098694,1.1807633638381958,50000.0,0.5805000066757202,1.8970988988876345,10000.0,44917.434248924255,46480.9824347496,44917.434248924255,1554.476620197296,4.591028213500977,0.0 -131800,6.2109766,1.271242,,,,,,,,,,,,,, -131900,5.457628,1.2588078,,,,,,,,,,,,,, -132000,5.484449,1.2971667,,,,,,,,,,,,,, -132100,5.73692,1.2854749,,,,,,,,,,,,,, -132200,5.800091,1.2026526,,,,,,,,,,,,,, -132300,6.0396686,1.1840672,,,,,,,,,,,,,, -132400,5.8213325,1.225632,,,,,,,,,,,,,, -132500,6.537839,1.3601086,,,,,,,,,,,,,, -132600,5.5362697,1.3370464,,,,,,,,,,,,,, -132700,6.671744,1.3618175,,,,,,,,,,,,,, -132800,5.5792603,1.3469924,,,,,,,,,,,,,, -132900,7.0322185,1.2159033,,,,,,,,,,,,,, -133000,6.1176376,1.4771767,,,,,,,,,,,,,, -133100,5.446235,1.2224755,,,,,,,,,,,,,, -133200,6.087166,1.316019,,,,,,,,,,,,,, -133243,,,0.79984450340271,0.7381070852279663,0.7058199644088745,1.1831684112548828,50000.0,0.5821000337600708,1.8820722103118896,10000.0,45427.39062857628,47008.251879930496,45427.39062857628,1571.6903052330017,4.63919734954834,0.0 -133300,5.7028003,1.3159829,,,,,,,,,,,,,, -133400,6.7135015,1.2106069,,,,,,,,,,,,,, -133500,5.736091,1.2298224,,,,,,,,,,,,,, -133600,5.9155746,1.2747073,,,,,,,,,,,,,, -133700,5.5067296,1.2717128,,,,,,,,,,,,,, -133800,5.8722014,1.1813027,,,,,,,,,,,,,, -133900,6.08938,1.3583032,,,,,,,,,,,,,, -134000,5.590215,1.2674685,,,,,,,,,,,,,, -134100,6.266578,1.3185009,,,,,,,,,,,,,, -134200,5.898716,1.2339886,,,,,,,,,,,,,, -134300,6.3427925,1.2523329,,,,,,,,,,,,,, -134400,5.826841,1.2990477,,,,,,,,,,,,,, -134500,6.4748616,1.3424152,,,,,,,,,,,,,, -134600,5.6744695,1.2318162,,,,,,,,,,,,,, -134700,6.381693,1.3512764,,,,,,,,,,,,,, -134742,,,0.8057836294174194,0.7143407464027405,0.7128799557685852,1.146412968635559,50000.0,0.5893000364303589,1.825054168701172,10000.0,45937.45642733574,47535.90947675705,45937.45642733574,1589.1792786121368,4.692151308059692,0.0 -134800,5.7965527,1.2213947,,,,,,,,,,,,,, -134900,6.1139817,1.2538484,,,,,,,,,,,,,, -135000,7.118715,1.260494,,,,,,,,,,,,,, -135100,6.1226873,1.2221138,,,,,,,,,,,,,, -135200,6.153674,1.1832591,,,,,,,,,,,,,, -135300,5.862118,1.2761813,,,,,,,,,,,,,, -135400,6.0599732,1.239136,,,,,,,,,,,,,, -135500,5.8662252,1.1749074,,,,,,,,,,,,,, -135600,7.1361785,1.21463,,,,,,,,,,,,,, -135700,7.117112,1.259656,,,,,,,,,,,,,, -135800,7.0504994,1.2601019,,,,,,,,,,,,,, -135900,6.1626334,1.2706097,,,,,,,,,,,,,, -136000,5.421012,1.1231527,,,,,,,,,,,,,, -136100,6.700701,1.2696766,,,,,,,,,,,,,, -136200,6.548259,1.2560285,,,,,,,,,,,,,, -136241,,,0.8033322691917419,0.7214902639389038,0.7112399935722351,1.1542009115219116,50000.0,0.5920000076293945,1.826322317123413,10000.0,46447.50892996788,48063.74123668671,46447.50892996788,1606.857929468155,4.741678476333618,0.0 -136300,5.875438,1.1805941,,,,,,,,,,,,,, -136400,6.1112366,1.2161788,,,,,,,,,,,,,, -136500,6.315309,1.2074211,,,,,,,,,,,,,, -136600,6.057691,1.3004062,,,,,,,,,,,,,, -136700,6.3843856,1.1916779,,,,,,,,,,,,,, -136800,5.7067013,1.2113204,,,,,,,,,,,,,, -136900,5.86685,1.1926025,,,,,,,,,,,,,, -137000,5.9364505,1.238876,,,,,,,,,,,,,, -137100,6.0978637,1.2672915,,,,,,,,,,,,,, -137200,5.9631605,1.1886665,,,,,,,,,,,,,, -137300,6.6917377,1.245271,,,,,,,,,,,,,, -137400,6.1517873,1.2455562,,,,,,,,,,,,,, -137500,6.174665,1.1312765,,,,,,,,,,,,,, -137600,6.238767,1.1995156,,,,,,,,,,,,,, -137700,6.7344475,1.3659871,,,,,,,,,,,,,, -137740,,,0.805683970451355,0.7093267440795898,0.7160399556159973,1.1419600248336792,50000.0,0.5933000445365906,1.8347601890563965,10000.0,46957.61950492859,48591.36189293861,46957.61950492859,1624.2676224708557,4.790990352630615,0.0 -137800,6.7067003,1.3938891,,,,,,,,,,,,,, -137900,6.733498,1.37271,,,,,,,,,,,,,, -138000,6.2022347,1.2088561,,,,,,,,,,,,,, -138100,6.872552,1.2422097,,,,,,,,,,,,,, -138200,5.84134,1.2163122,,,,,,,,,,,,,, -138300,6.4429255,1.2381983,,,,,,,,,,,,,, -138400,6.841902,1.3547856,,,,,,,,,,,,,, -138500,7.485295,1.3285352,,,,,,,,,,,,,, -138600,6.4235206,1.2468768,,,,,,,,,,,,,, -138700,6.705366,1.2348832,,,,,,,,,,,,,, -138800,6.423607,1.2389371,,,,,,,,,,,,,, -138900,6.5880613,1.2902203,,,,,,,,,,,,,, -139000,6.6976924,1.1938869,,,,,,,,,,,,,, -139100,6.910719,1.2983268,,,,,,,,,,,,,, -139200,7.1934333,1.1897483,,,,,,,,,,,,,, -139238,,,0.8067402839660645,0.7009314298629761,0.7160199880599976,1.1319137811660769,50000.0,0.5916000008583069,1.830636978149414,10000.0,47467.85808753967,49118.85888767242,47467.85808753967,1641.4246740341189,4.841644525527954,0.0 -139300,6.0957303,1.3138676,,,,,,,,,,,,,, -139400,6.8881197,1.3226227,,,,,,,,,,,,,, -139500,6.2430315,1.2092055,,,,,,,,,,,,,, -139600,5.870772,1.174438,,,,,,,,,,,,,, -139700,6.053491,1.2270833,,,,,,,,,,,,,, -139800,6.496675,1.1358896,,,,,,,,,,,,,, -139900,6.4600534,1.1928257,,,,,,,,,,,,,, -140000,6.4061427,1.145828,,,,,,,,,,,,,, -140100,6.4570913,1.1536157,,,,,,,,,,,,,, -140200,6.5549054,1.2136629,,,,,,,,,,,,,, -140300,6.1444836,1.1858472,,,,,,,,,,,,,, -140400,6.5873804,1.1913791,,,,,,,,,,,,,, -140500,6.6767406,1.2605443,,,,,,,,,,,,,, -140600,6.8591948,1.2328701,,,,,,,,,,,,,, -140700,6.5927286,1.1414689,,,,,,,,,,,,,, -140737,,,0.8413185477256775,0.5797902345657349,0.7160599827766418,1.138371467590332,50000.0,0.5914000272750854,1.810274839401245,10000.0,47978.06637239456,49646.37145638466,47978.06637239456,1658.625111579895,4.895170450210571,0.0 -140800,6.91213,1.205752,,,,,,,,,,,,,, -140900,6.357919,1.1912326,,,,,,,,,,,,,, -141000,6.639814,1.1627669,,,,,,,,,,,,,, -141100,6.3973856,1.1684043,,,,,,,,,,,,,, -141200,6.702794,1.1925517,,,,,,,,,,,,,, -141300,6.6232977,1.2575151,,,,,,,,,,,,,, -141400,7.288866,1.1362097,,,,,,,,,,,,,, -141500,7.028526,1.2346292,,,,,,,,,,,,,, -141600,6.423446,1.2367238,,,,,,,,,,,,,, -141700,6.0780077,1.1438878,,,,,,,,,,,,,, -141800,6.96218,1.1245612,,,,,,,,,,,,,, -141900,6.985243,1.1812676,,,,,,,,,,,,,, -142000,6.407845,1.1414428,,,,,,,,,,,,,, -142100,6.8396397,1.2476346,,,,,,,,,,,,,, -142200,7.4797025,1.2307956,,,,,,,,,,,,,, -142236,,,0.8332070708274841,0.6023600101470947,0.7219199538230896,1.1166437864303589,50000.0,0.5998000502586365,1.7817325592041016,10000.0,48488.23559617996,50174.20073246956,48488.23559617996,1676.1800088882446,4.949621677398682,0.0 -142300,7.6939297,1.2396146,,,,,,,,,,,,,, -142400,6.4055686,1.1604357,,,,,,,,,,,,,, -142500,6.73001,1.1851873,,,,,,,,,,,,,, -142600,6.3791823,1.1995331,,,,,,,,,,,,,, -142700,6.8706517,1.1977842,,,,,,,,,,,,,, -142800,7.3021994,1.147492,,,,,,,,,,,,,, -142900,6.1842194,1.088697,,,,,,,,,,,,,, -143000,6.849264,1.2358375,,,,,,,,,,,,,, -143100,7.0285344,1.1857477,,,,,,,,,,,,,, -143200,6.776231,1.1923977,,,,,,,,,,,,,, -143300,6.739951,1.2248013,,,,,,,,,,,,,, -143400,6.3551726,1.0450001,,,,,,,,,,,,,, -143500,6.2560635,1.0915918,,,,,,,,,,,,,, -143600,6.550121,1.2667024,,,,,,,,,,,,,, -143700,6.8502793,1.2164856,,,,,,,,,,,,,, -143735,,,0.8356783986091614,0.5958223938941956,0.7263199687004089,1.1001946926116943,50000.0,0.6032000184059143,1.785300612449646,10000.0,48998.14870905876,50701.440212488174,48998.14870905876,1693.400636434555,5.0050599575042725,0.0 -143800,6.3893495,1.1154915,,,,,,,,,,,,,, -143900,7.2911277,1.1624563,,,,,,,,,,,,,, -144000,6.760169,1.1955779,,,,,,,,,,,,,, -144100,6.3844156,1.06283,,,,,,,,,,,,,, -144200,6.5622225,1.15957,,,,,,,,,,,,,, -144300,7.8780136,1.1114428,,,,,,,,,,,,,, -144400,6.350197,1.1019702,,,,,,,,,,,,,, -144500,6.8238893,1.1791728,,,,,,,,,,,,,, -144600,6.104464,1.102055,,,,,,,,,,,,,, -144700,6.967614,1.2814088,,,,,,,,,,,,,, -144800,6.712404,1.188316,,,,,,,,,,,,,, -144900,7.7828274,1.1282524,,,,,,,,,,,,,, -145000,6.375181,1.0681355,,,,,,,,,,,,,, -145100,6.4921474,1.047241,,,,,,,,,,,,,, -145200,6.834887,1.1322167,,,,,,,,,,,,,, -145233,,,0.8364556431770325,0.5910767316818237,0.7298399806022644,1.085693120956421,50000.0,0.6122000217437744,1.7577836513519287,10000.0,49508.11609601975,51228.993624448776,49508.11609601975,1710.852013349533,5.087496757507324,0.0 -145300,6.6352015,1.1073923,,,,,,,,,,,,,, -145400,6.856825,1.038704,,,,,,,,,,,,,, -145500,7.4229965,1.1880834,,,,,,,,,,,,,, -145600,6.9682355,1.103983,,,,,,,,,,,,,, -145700,6.943321,1.0867187,,,,,,,,,,,,,, -145800,6.412381,1.1033314,,,,,,,,,,,,,, -145900,7.2941084,1.2447959,,,,,,,,,,,,,, -146000,6.632347,1.0967767,,,,,,,,,,,,,, -146100,6.6960273,1.2084923,,,,,,,,,,,,,, -146200,7.0009456,1.1712012,,,,,,,,,,,,,, -146300,6.4878917,1.1175474,,,,,,,,,,,,,, -146400,8.152831,1.2266574,,,,,,,,,,,,,, -146500,7.099554,1.1246443,,,,,,,,,,,,,, -146600,7.4961734,1.1811011,,,,,,,,,,,,,, -146700,6.5502944,1.1211689,,,,,,,,,,,,,, -146732,,,0.83402419090271,0.6022161841392517,0.729919970035553,1.0840203762054443,50000.0,0.6041000485420227,1.7766298055648804,10000.0,50018.29201626778,51756.42676591873,50018.29201626778,1728.0061275959015,5.139847993850708,0.0 -146800,6.773161,1.1294179,,,,,,,,,,,,,, -146900,6.7120857,1.0678985,,,,,,,,,,,,,, -147000,7.1546755,1.1339613,,,,,,,,,,,,,, -147100,7.215452,1.1744655,,,,,,,,,,,,,, -147200,6.5328298,1.0854404,,,,,,,,,,,,,, -147300,6.92913,1.1624849,,,,,,,,,,,,,, -147400,6.5442204,1.077312,,,,,,,,,,,,,, -147500,6.3065424,1.0016292,,,,,,,,,,,,,, -147600,7.006615,1.1703253,,,,,,,,,,,,,, -147700,7.1289816,1.1271969,,,,,,,,,,,,,, -147800,6.82629,1.0841327,,,,,,,,,,,,,, -147900,7.034286,1.1874021,,,,,,,,,,,,,, -148000,7.129029,1.1256199,,,,,,,,,,,,,, -148100,7.7482247,1.091469,,,,,,,,,,,,,, -148200,7.2447724,1.0774206,,,,,,,,,,,,,, -148231,,,0.8361965417861938,0.5804234743118286,0.7318599820137024,1.0679874420166016,50000.0,0.6070000529289246,1.7447444200515747,10000.0,50528.38182926178,52283.85293364525,50528.38182926178,1745.2415256500244,5.190353631973267,0.0 -148300,7.222058,1.0589337,,,,,,,,,,,,,, -148400,7.361692,1.1145358,,,,,,,,,,,,,, -148500,6.9633913,1.0695642,,,,,,,,,,,,,, -148600,6.2288404,1.0729235,,,,,,,,,,,,,, -148700,7.2461314,1.1140871,,,,,,,,,,,,,, -148800,6.8651385,1.062526,,,,,,,,,,,,,, -148900,7.313769,1.1248871,,,,,,,,,,,,,, -149000,7.030279,1.062118,,,,,,,,,,,,,, -149100,7.1003017,1.1070379,,,,,,,,,,,,,, -149200,7.683491,1.1752563,,,,,,,,,,,,,, -149300,7.5108433,1.0836846,,,,,,,,,,,,,, -149400,7.103618,1.0439682,,,,,,,,,,,,,, -149500,6.9446945,1.0954494,,,,,,,,,,,,,, -149600,7.96428,1.0393554,,,,,,,,,,,,,, -149700,7.719065,1.0854223,,,,,,,,,,,,,, -149730,,,0.8720304369926453,0.4616715610027313,0.7359199523925781,1.06617271900177,50000.0,0.6065000295639038,1.750016450881958,10000.0,51038.4784321785,52811.57619476318,51038.4784321785,1762.763778924942,5.244019508361816,0.0 -149800,7.276053,1.0715241,,,,,,,,,,,,,, -149900,6.718938,1.033616,,,,,,,,,,,,,, -150000,7.3660765,1.0562216,,,,,,,,,,,,,, -150100,7.634733,1.035449,,,,,,,,,,,,,, -150200,7.3160634,1.0840169,,,,,,,,,,,,,, -150300,7.05704,1.0873268,,,,,,,,,,,,,, -150400,7.472329,1.1093409,,,,,,,,,,,,,, -150500,7.4993057,1.0709983,,,,,,,,,,,,,, -150600,7.550611,1.0724527,,,,,,,,,,,,,, -150700,7.335207,0.9977231,,,,,,,,,,,,,, -150800,6.24372,1.0001242,,,,,,,,,,,,,, -150900,7.6885643,1.0312192,,,,,,,,,,,,,, -151000,7.373028,1.0229727,,,,,,,,,,,,,, -151100,7.4466863,1.0685927,,,,,,,,,,,,,, -151200,7.1646767,1.0653076,,,,,,,,,,,,,, -151228,,,0.8650151491165161,0.4839383959770202,0.7351799607276917,1.0641911029815674,50000.0,0.6137000322341919,1.7262399196624756,10000.0,51548.38336634636,53338.903485774994,51548.38336634636,1780.0837841033936,5.296934366226196,0.0 -151300,7.9421687,1.1040406,,,,,,,,,,,,,, -151400,7.0767393,1.0113113,,,,,,,,,,,,,, -151500,7.6244116,1.0622807,,,,,,,,,,,,,, -151600,7.1730943,1.0254666,,,,,,,,,,,,,, -151700,7.9180884,1.1613162,,,,,,,,,,,,,, -151800,7.400778,1.1334095,,,,,,,,,,,,,, -151900,7.89178,1.0624547,,,,,,,,,,,,,, -152000,7.1912184,0.9637896,,,,,,,,,,,,,, -152100,7.2440467,1.0836738,,,,,,,,,,,,,, -152200,7.1958313,1.1232727,,,,,,,,,,,,,, -152300,7.78984,1.0116923,,,,,,,,,,,,,, -152400,8.423128,1.0857351,,,,,,,,,,,,,, -152500,7.370634,0.9594843,,,,,,,,,,,,,, -152600,7.0288277,1.0051732,,,,,,,,,,,,,, -152700,7.6122546,0.99949217,,,,,,,,,,,,,, -152726,,,0.8591955900192261,0.492692083120346,0.7340599894523621,1.070358157157898,50000.0,0.6089000105857849,1.7674813270568848,10000.0,52058.31352877617,53866.35694384575,52058.31352877617,1797.5025045871737,5.351749420166016,0.0 -152800,7.4005346,0.92320985,,,,,,,,,,,,,, -152900,7.118778,0.98069334,,,,,,,,,,,,,, -153000,7.8068423,1.0909759,,,,,,,,,,,,,, -153100,7.227092,0.96863914,,,,,,,,,,,,,, -153200,6.8677115,0.90256613,,,,,,,,,,,,,, -153300,8.763039,0.9967403,,,,,,,,,,,,,, -153400,7.913491,1.0313789,,,,,,,,,,,,,, -153500,7.690897,1.1050197,,,,,,,,,,,,,, -153600,8.33396,1.0670239,,,,,,,,,,,,,, -153700,7.5975294,1.1762466,,,,,,,,,,,,,, -153800,7.179478,0.9908793,,,,,,,,,,,,,, -153900,7.713982,0.99407065,,,,,,,,,,,,,, -154000,8.179248,1.0401475,,,,,,,,,,,,,, -154100,7.343979,0.9992273,,,,,,,,,,,,,, -154200,8.553646,1.070685,,,,,,,,,,,,,, -154225,,,0.8656927347183228,0.478325217962265,0.7418199777603149,1.0443856716156006,50000.0,0.6195000410079956,1.7304673194885254,10000.0,52568.48841428757,54393.9851975441,52568.48841428757,1814.846135139465,5.408911228179932,0.0 -154300,7.5920663,1.0168769,,,,,,,,,,,,,, -154400,9.473683,1.0475433,,,,,,,,,,,,,, -154500,7.558048,0.99283814,,,,,,,,,,,,,, -154600,6.867911,1.0588403,,,,,,,,,,,,,, -154700,7.2925706,0.9683844,,,,,,,,,,,,,, -154800,7.0930395,1.0420527,,,,,,,,,,,,,, -154900,7.5241747,0.97016907,,,,,,,,,,,,,, -155000,9.271832,1.0693066,,,,,,,,,,,,,, -155100,7.772076,0.9869599,,,,,,,,,,,,,, -155200,7.601066,0.93997717,,,,,,,,,,,,,, -155300,8.475344,1.0406125,,,,,,,,,,,,,, -155400,7.507851,1.0278395,,,,,,,,,,,,,, -155500,7.7647004,0.9506236,,,,,,,,,,,,,, -155600,9.011335,1.0542654,,,,,,,,,,,,,, -155700,8.302551,0.93691516,,,,,,,,,,,,,, -155724,,,0.8673469424247742,0.4798938035964966,0.7419399619102478,1.0362374782562256,50000.0,0.6204000115394592,1.7236957550048828,10000.0,53078.62742900848,54921.7041721344,53078.62742900848,1832.3190059661863,5.4648637771606445,0.0 -155800,7.285368,0.8829932,,,,,,,,,,,,,, -155900,7.8440757,1.0046006,,,,,,,,,,,,,, -156000,8.36468,1.001946,,,,,,,,,,,,,, -156100,7.484087,0.9849771,,,,,,,,,,,,,, -156200,7.574686,1.0309676,,,,,,,,,,,,,, -156300,8.513796,1.0136557,,,,,,,,,,,,,, -156400,8.039751,0.95523727,,,,,,,,,,,,,, -156500,8.064753,0.93083274,,,,,,,,,,,,,, -156600,8.258106,0.9379416,,,,,,,,,,,,,, -156700,7.7975655,1.014364,,,,,,,,,,,,,, -156800,7.6484237,0.94734657,,,,,,,,,,,,,, -156900,7.6378846,0.9459567,,,,,,,,,,,,,, -157000,7.05281,0.8594379,,,,,,,,,,,,,, -157100,7.5631485,0.983793,,,,,,,,,,,,,, -157200,8.74338,0.991323,,,,,,,,,,,,,, -157222,,,0.8683633208274841,0.4606120586395263,0.7424399852752686,1.0408451557159424,50000.0,0.6229000091552734,1.7110683917999268,10000.0,53588.541823387146,55449.373153209686,53588.541823387146,1849.9641468524933,5.523918867111206,0.0 -157300,7.027052,0.86955357,,,,,,,,,,,,,, -157400,8.188292,1.0208688,,,,,,,,,,,,,, -157500,7.5181994,0.8964821,,,,,,,,,,,,,, -157600,7.305616,0.8915159,,,,,,,,,,,,,, -157700,7.376613,0.94531804,,,,,,,,,,,,,, -157800,7.2050033,0.89984965,,,,,,,,,,,,,, -157900,8.956558,0.98149,,,,,,,,,,,,,, -158000,8.240751,0.96815974,,,,,,,,,,,,,, -158100,7.764754,0.97125983,,,,,,,,,,,,,, -158200,7.3556705,0.959427,,,,,,,,,,,,,, -158300,8.972332,1.0837855,,,,,,,,,,,,,, -158400,7.801869,1.0144746,,,,,,,,,,,,,, -158500,8.218081,0.95314837,,,,,,,,,,,,,, -158600,7.9236913,0.9080793,,,,,,,,,,,,,, -158700,7.9714665,1.0284297,,,,,,,,,,,,,, -158721,,,0.8824138641357422,0.4179320037364959,0.7461999654769897,1.0250033140182495,50000.0,0.6211000084877014,1.7047979831695557,10000.0,54098.47330284119,55976.7153468132,54098.47330284119,1867.2719218730929,5.576862573623657,0.0 -158800,7.458909,0.9670992,,,,,,,,,,,,,, -158900,8.7099695,0.9569133,,,,,,,,,,,,,, -159000,8.208395,0.9722373,,,,,,,,,,,,,, -159100,7.9301286,1.0009899,,,,,,,,,,,,,, -159200,7.874344,0.98436195,,,,,,,,,,,,,, -159300,7.7637043,0.9634206,,,,,,,,,,,,,, -159400,8.644875,1.0769777,,,,,,,,,,,,,, -159500,7.2956133,0.92561704,,,,,,,,,,,,,, -159600,8.380659,0.8750585,,,,,,,,,,,,,, -159700,9.449441,0.94915736,,,,,,,,,,,,,, -159800,7.3959236,0.90100574,,,,,,,,,,,,,, -159900,7.7469125,0.98458207,,,,,,,,,,,,,, -160000,7.771627,0.8349067,,,,,,,,,,,,,, -160100,8.289272,0.98843545,,,,,,,,,,,,,, -160200,7.6540203,0.9869176,,,,,,,,,,,,,, -160220,,,0.8951291441917419,0.3667570948600769,0.7464399933815002,1.0255107879638672,50000.0,0.6242000460624695,1.7131024599075315,10000.0,54608.64770102501,56504.31801342964,54608.64770102501,1884.594167470932,5.631555557250977,0.0 -160300,8.82672,0.986509,,,,,,,,,,,,,, -160400,9.744026,0.9179712,,,,,,,,,,,,,, -160500,8.183252,0.92487156,,,,,,,,,,,,,, -160600,7.8867183,1.0339427,,,,,,,,,,,,,, -160700,8.90386,1.0383396,,,,,,,,,,,,,, -160800,9.345196,0.94462556,,,,,,,,,,,,,, -160900,7.672408,0.9932388,,,,,,,,,,,,,, -161000,8.700591,0.8868185,,,,,,,,,,,,,, -161100,8.200882,0.92696154,,,,,,,,,,,,,, -161200,7.650394,0.90109646,,,,,,,,,,,,,, -161300,8.688401,1.0120311,,,,,,,,,,,,,, -161400,8.0887575,0.86476773,,,,,,,,,,,,,, -161500,7.9543915,0.980561,,,,,,,,,,,,,, -161600,7.852679,0.91454804,,,,,,,,,,,,,, -161700,8.807149,0.8302407,,,,,,,,,,,,,, -161719,,,0.8964046239852905,0.3655628263950348,0.7499399781227112,1.0128942728042605,50000.0,0.6291000247001648,1.6922518014907837,10000.0,55118.76144886017,57032.12668228149,55118.76144886017,1902.1840479373927,5.686068296432495,0.0 -161800,8.460509,0.86599594,,,,,,,,,,,,,, -161900,7.735275,0.8761579,,,,,,,,,,,,,, -162000,8.461574,0.95446754,,,,,,,,,,,,,, -162100,7.579867,0.86362803,,,,,,,,,,,,,, -162200,8.439942,0.94530994,,,,,,,,,,,,,, -162300,8.381175,0.9803603,,,,,,,,,,,,,, -162400,7.6827383,0.84302187,,,,,,,,,,,,,, -162500,8.070216,0.89271414,,,,,,,,,,,,,, -162600,7.800783,0.8639634,,,,,,,,,,,,,, -162700,7.7389297,0.8881754,,,,,,,,,,,,,, -162800,8.600793,0.915001,,,,,,,,,,,,,, -162900,8.211556,0.8446321,,,,,,,,,,,,,, -163000,7.4235306,0.8806163,,,,,,,,,,,,,, -163100,8.467832,0.8824769,,,,,,,,,,,,,, -163200,7.764518,0.9261159,,,,,,,,,,,,,, -163218,,,0.8958067297935486,0.3668033480644226,0.7515599727630615,1.0076985359191897,50000.0,0.6283000111579895,1.6882383823394775,10000.0,55628.83065366745,57559.539657115936,55628.83065366745,1919.415715456009,5.746610403060913,0.0 -163300,8.575139,0.8939794,,,,,,,,,,,,,, -163400,8.257541,0.87513,,,,,,,,,,,,,, -163500,7.929429,0.86542237,,,,,,,,,,,,,, -163600,8.2001095,0.8569538,,,,,,,,,,,,,, -163700,8.076396,0.8962854,,,,,,,,,,,,,, -163800,7.811239,0.8156459,,,,,,,,,,,,,, -163900,8.849252,0.89566493,,,,,,,,,,,,,, -164000,9.26601,0.9215427,,,,,,,,,,,,,, -164100,8.531962,0.8647563,,,,,,,,,,,,,, -164200,8.042761,0.8363556,,,,,,,,,,,,,, -164300,8.418349,0.89060235,,,,,,,,,,,,,, -164400,8.006371,0.96487767,,,,,,,,,,,,,, -164500,8.575213,0.8709156,,,,,,,,,,,,,, -164600,9.00475,0.90720785,,,,,,,,,,,,,, -164700,9.55248,0.89715266,,,,,,,,,,,,,, -164717,,,0.89652419090271,0.3590350449085235,0.7538999915122986,0.995878040790558,50000.0,0.6323000192642212,1.692557692527771,10000.0,56138.98541164398,58087.73439216614,56138.98541164398,1937.3459751605988,5.80523681640625,0.0 -164800,8.544817,0.89752007,,,,,,,,,,,,,, -164900,8.770723,0.8644842,,,,,,,,,,,,,, -165000,8.898483,0.8592945,,,,,,,,,,,,,, -165100,8.10546,0.87658954,,,,,,,,,,,,,, -165200,9.48342,0.88594806,,,,,,,,,,,,,, -165300,8.555269,0.8817508,,,,,,,,,,,,,, -165400,8.200938,0.8965656,,,,,,,,,,,,,, -165500,8.215217,0.8764196,,,,,,,,,,,,,, -165600,8.585511,0.8827107,,,,,,,,,,,,,, -165700,8.346413,0.8299186,,,,,,,,,,,,,, -165800,8.573983,0.87531316,,,,,,,,,,,,,, -165900,7.7960625,0.8317867,,,,,,,,,,,,,, -166000,8.874569,0.83740866,,,,,,,,,,,,,, -166100,7.7197075,0.83156604,,,,,,,,,,,,,, -166200,7.7019577,0.8145556,,,,,,,,,,,,,, -166216,,,0.9032605290412904,0.3413633108139038,0.7542799711227417,0.9981317520141602,50000.0,0.6308000087738037,1.68972647190094,10000.0,56649.082023859024,58615.33793616295,56649.082023859024,1954.7560713291168,5.8507981300354,0.0 -166300,8.395438,0.8831137,,,,,,,,,,,,,, -166400,8.7691555,0.92953783,,,,,,,,,,,,,, -166500,8.599291,0.8305664,,,,,,,,,,,,,, -166600,8.412411,0.8022285,,,,,,,,,,,,,, -166700,8.653149,0.964592,,,,,,,,,,,,,, -166800,8.135183,0.83184457,,,,,,,,,,,,,, -166900,9.48231,0.84277713,,,,,,,,,,,,,, -167000,7.808739,0.83854944,,,,,,,,,,,,,, -167100,8.667425,0.9704212,,,,,,,,,,,,,, -167200,8.157848,0.8048697,,,,,,,,,,,,,, -167300,8.45282,0.8327012,,,,,,,,,,,,,, -167400,8.302095,0.84427226,,,,,,,,,,,,,, -167500,8.407638,0.7919447,,,,,,,,,,,,,, -167600,8.129396,0.83881843,,,,,,,,,,,,,, -167700,8.029886,0.8072578,,,,,,,,,,,,,, -167715,,,0.9057517647743224,0.3304113447666168,0.755620002746582,0.9929838180541992,50000.0,0.6345000267028809,1.6859267950057983,10000.0,57159.2556912899,59142.94081425667,57159.2556912899,1972.07869887352,5.907650232315064,0.0 -167800,8.753641,0.8135275,,,,,,,,,,,,,, -167900,7.877998,0.79136604,,,,,,,,,,,,,, -168000,7.9360485,0.7603898,,,,,,,,,,,,,, -168100,9.132786,0.85744524,,,,,,,,,,,,,, -168200,9.489167,0.880105,,,,,,,,,,,,,, -168300,9.127093,0.8389121,,,,,,,,,,,,,, -168400,7.9447618,0.7982073,,,,,,,,,,,,,, -168500,9.0648,0.90187985,,,,,,,,,,,,,, -168600,8.276892,0.8745858,,,,,,,,,,,,,, -168700,8.732087,0.83283186,,,,,,,,,,,,,, -168800,9.794562,0.81360024,,,,,,,,,,,,,, -168900,8.3893795,0.84523183,,,,,,,,,,,,,, -169000,8.813634,0.8730192,,,,,,,,,,,,,, -169100,9.07758,0.84021044,,,,,,,,,,,,,, -169200,9.012257,0.74572283,,,,,,,,,,,,,, -169213,,,0.920340359210968,0.2846298217773437,0.7571199536323547,0.9897215962409972,50000.0,0.6319000124931335,1.6728354692459106,10000.0,57669.26935172081,59670.21319055557,57669.26935172081,1989.2317397594447,5.962466239929199,0.0 -169300,9.251334,0.9170715,,,,,,,,,,,,,, -169400,8.277419,0.75660455,,,,,,,,,,,,,, -169500,10.302307,0.791891,,,,,,,,,,,,,, -169600,8.099682,0.80466384,,,,,,,,,,,,,, -169700,8.531056,0.7917804,,,,,,,,,,,,,, -169800,8.527049,0.86193514,,,,,,,,,,,,,, -169900,8.554066,0.79190814,,,,,,,,,,,,,, -170000,9.668275,0.90858024,,,,,,,,,,,,,, -170100,8.778906,0.7705512,,,,,,,,,,,,,, -170200,8.408553,0.7998975,,,,,,,,,,,,,, -170300,8.066212,0.79246217,,,,,,,,,,,,,, -170400,8.712956,0.84105897,,,,,,,,,,,,,, -170500,8.479929,0.7915313,,,,,,,,,,,,,, -170600,9.201553,0.8456619,,,,,,,,,,,,,, -170700,9.379209,0.8846535,,,,,,,,,,,,,, -170712,,,0.920719027519226,0.2819788455963135,0.7594199776649475,0.9810996651649476,50000.0,0.6370000243186951,1.660903811454773,10000.0,58179.45411133766,60198.01121592522,58179.45411133766,2006.736686944961,6.018509387969971,0.0 -170800,8.548129,0.7962868,,,,,,,,,,,,,, -170900,9.468511,0.906091,,,,,,,,,,,,,, -171000,8.112067,0.81466925,,,,,,,,,,,,,, -171100,10.371698,0.7633344,,,,,,,,,,,,,, -171200,8.888291,0.80586946,,,,,,,,,,,,,, -171300,9.479396,0.8002184,,,,,,,,,,,,,, -171400,8.730913,0.814984,,,,,,,,,,,,,, -171500,8.24723,0.738917,,,,,,,,,,,,,, -171600,8.4988785,0.7662096,,,,,,,,,,,,,, -171700,9.414423,0.8724315,,,,,,,,,,,,,, -171800,8.343008,0.7372027,,,,,,,,,,,,,, -171900,9.385607,0.74823725,,,,,,,,,,,,,, -172000,8.211173,0.7836609,,,,,,,,,,,,,, -172100,8.799623,0.8623233,,,,,,,,,,,,,, -172200,8.884263,0.77505463,,,,,,,,,,,,,, -172210,,,0.9206393361091614,0.2826157808303833,0.7601400017738342,0.9788408875465392,50000.0,0.6363000273704529,1.666724443435669,10000.0,58689.38737034798,60725.54909420013,58689.38737034798,2024.2342946529388,6.075444459915161,0.0 -172300,8.997915,0.8484893,,,,,,,,,,,,,, -172400,8.5184355,0.815605,,,,,,,,,,,,,, -172500,9.509407,0.81744635,,,,,,,,,,,,,, -172600,8.378229,0.77301985,,,,,,,,,,,,,, -172700,8.872828,0.8097657,,,,,,,,,,,,,, -172800,8.779933,0.8224195,,,,,,,,,,,,,, -172900,9.099239,0.784547,,,,,,,,,,,,,, -173000,8.000042,0.7258106,,,,,,,,,,,,,, -173100,8.451569,0.7658133,,,,,,,,,,,,,, -173200,8.887639,0.7537579,,,,,,,,,,,,,, -173300,8.511221,0.7670003,,,,,,,,,,,,,, -173400,7.831377,0.7237978,,,,,,,,,,,,,, -173500,8.175512,0.7441222,,,,,,,,,,,,,, -173600,8.852735,0.7346454,,,,,,,,,,,,,, -173700,9.076179,0.85111487,,,,,,,,,,,,,, -173708,,,0.9241868257522584,0.267222911119461,0.7613199949264526,0.9772453904151917,50000.0,0.6378000378608704,1.6611794233322144,10000.0,59199.27856183052,61252.9173913002,59199.27856183052,2041.5964758396149,6.139222145080566,0.0 -173800,8.884248,0.75782347,,,,,,,,,,,,,, -173900,8.080846,0.7385775,,,,,,,,,,,,,, -174000,7.8854585,0.68404496,,,,,,,,,,,,,, -174100,8.700096,0.8160299,,,,,,,,,,,,,, -174200,9.968134,0.85342723,,,,,,,,,,,,,, -174300,9.1317215,0.804921,,,,,,,,,,,,,, -174400,8.974736,0.7394866,,,,,,,,,,,,,, -174500,9.113736,0.8115225,,,,,,,,,,,,,, -174600,8.422087,0.73378736,,,,,,,,,,,,,, -174700,8.551968,0.7684617,,,,,,,,,,,,,, -174800,10.026045,0.779698,,,,,,,,,,,,,, -174900,9.100625,0.74345416,,,,,,,,,,,,,, -175000,8.6655655,0.76791865,,,,,,,,,,,,,, -175100,8.865154,0.8112906,,,,,,,,,,,,,, -175200,9.247648,0.86429316,,,,,,,,,,,,,, -175206,,,0.9243263602256776,0.2708885669708252,0.7618199586868286,0.9700082540512084,50000.0,0.6360000371932983,1.6560876369476318,10000.0,59709.26147675514,61780.24515080452,59709.26147675514,2058.824327230453,6.205310344696045,0.0 -175300,9.326573,0.7254615,,,,,,,,,,,,,, -175400,9.265034,0.7822944,,,,,,,,,,,,,, -175500,8.783952,0.7722601,,,,,,,,,,,,,, -175600,8.908085,0.75956935,,,,,,,,,,,,,, -175700,8.98597,0.80682766,,,,,,,,,,,,,, -175800,8.497406,0.79602754,,,,,,,,,,,,,, -175900,8.409782,0.7545075,,,,,,,,,,,,,, -176000,9.156457,0.86595523,,,,,,,,,,,,,, -176100,9.422137,0.71771276,,,,,,,,,,,,,, -176200,8.604384,0.68668556,,,,,,,,,,,,,, -176300,9.297608,0.8415467,,,,,,,,,,,,,, -176400,9.739092,0.7747295,,,,,,,,,,,,,, -176500,8.718113,0.75686723,,,,,,,,,,,,,, -176600,9.5243845,0.73513234,,,,,,,,,,,,,, -176700,9.269796,0.69925106,,,,,,,,,,,,,, -176704,,,0.9294283986091614,0.2561113834381103,0.7628200054168701,0.9697470664978028,50000.0,0.6390000581741333,1.6528555154800415,10000.0,60219.2170612812,62307.96381902695,60219.2170612812,2076.4793939590454,6.263585090637207,0.0 -176800,9.040303,0.79489785,,,,,,,,,,,,,, -176900,8.91412,0.6756232,,,,,,,,,,,,,, -177000,8.690329,0.7265104,,,,,,,,,,,,,, -177100,8.770789,0.74686897,,,,,,,,,,,,,, -177200,9.442709,0.8088708,,,,,,,,,,,,,, -177300,9.296095,0.7766306,,,,,,,,,,,,,, -177400,8.532333,0.6800839,,,,,,,,,,,,,, -177500,9.083875,0.7514719,,,,,,,,,,,,,, -177600,9.12697,0.79837215,,,,,,,,,,,,,, -177700,8.30184,0.69823647,,,,,,,,,,,,,, -177800,8.484597,0.76550215,,,,,,,,,,,,,, -177900,8.654341,0.68245625,,,,,,,,,,,,,, -178000,8.906122,0.7388403,,,,,,,,,,,,,, -178100,8.048671,0.66923803,,,,,,,,,,,,,, -178200,7.6053495,0.6643734,,,,,,,,,,,,,, -178203,,,0.9300462007522584,0.2525694072246551,0.763759970664978,0.9669691920280457,50000.0,0.6402000188827515,1.652724266052246,10000.0,60729.36058998108,62835.58871150017,60729.36058998108,2093.848112821579,6.324261903762817,0.0 -178300,8.624274,0.80092597,,,,,,,,,,,,,, -178400,8.858892,0.72721994,,,,,,,,,,,,,, -178500,8.870951,0.7408763,,,,,,,,,,,,,, -178600,9.969896,0.7660237,,,,,,,,,,,,,, -178700,8.904218,0.7119339,,,,,,,,,,,,,, -178800,9.588486,0.7699235,,,,,,,,,,,,,, -178900,8.848175,0.7197183,,,,,,,,,,,,,, -179000,8.069539,0.70181334,,,,,,,,,,,,,, -179100,8.84691,0.64667714,,,,,,,,,,,,,, -179200,8.780324,0.7835075,,,,,,,,,,,,,, -179300,8.912033,0.67533696,,,,,,,,,,,,,, -179400,8.83365,0.78066206,,,,,,,,,,,,,, -179500,9.545635,0.76178443,,,,,,,,,,,,,, -179600,8.945109,0.7367604,,,,,,,,,,,,,, -179700,8.618101,0.7154745,,,,,,,,,,,,,, -179701,,,0.9334741234779358,0.2409894466400146,0.7638599872589111,0.9650241136550904,50000.0,0.64410001039505,1.6483752727508545,10000.0,61239.35639166832,63362.927340745926,61239.35639166832,2111.0798873901367,6.385619401931763,0.0 -179800,8.908749,0.7395215,,,,,,,,,,,,,, -179900,9.007567,0.7389593,,,,,,,,,,,,,, -180000,9.785458,0.8133497,,,,,,,,,,,,,, -180100,8.914038,0.7437841,,,,,,,,,,,,,, -180200,9.754379,0.69422513,,,,,,,,,,,,,, -180300,9.326227,0.7227214,,,,,,,,,,,,,, -180400,9.219756,0.71287346,,,,,,,,,,,,,, -180500,9.513994,0.8006146,,,,,,,,,,,,,, -180600,9.754804,0.7318226,,,,,,,,,,,,,, -180700,9.3879795,0.6882292,,,,,,,,,,,,,, -180800,8.3554535,0.73514885,,,,,,,,,,,,,, -180900,9.506264,0.69951934,,,,,,,,,,,,,, -181000,9.697563,0.7797566,,,,,,,,,,,,,, -181100,8.827834,0.72545564,,,,,,,,,,,,,, -181199,,,0.933254897594452,0.2420443147420883,0.7646999955177307,0.9629248380661012,50000.0,0.6421000361442566,1.6437848806381226,10000.0,61749.25453186035,63890.17459869385,61749.25453186035,2128.31769990921,6.443438529968262,0.0 -181200,9.020594,0.7246904,,,,,,,,,,,,,, -181300,8.961964,0.7284921,,,,,,,,,,,,,, -181400,9.929523,0.7665589,,,,,,,,,,,,,, -181500,8.983811,0.7622708,,,,,,,,,,,,,, -181600,8.638952,0.70260525,,,,,,,,,,,,,, -181700,8.612564,0.65844333,,,,,,,,,,,,,, -181800,9.137725,0.7704524,,,,,,,,,,,,,, -181900,9.369949,0.7651013,,,,,,,,,,,,,, -182000,9.271892,0.79354304,,,,,,,,,,,,,, -182100,9.772946,0.76256543,,,,,,,,,,,,,, -182200,7.8191056,0.69037676,,,,,,,,,,,,,, -182300,9.550582,0.7851046,,,,,,,,,,,,,, -182400,8.941453,0.7175107,,,,,,,,,,,,,, -182500,8.975197,0.69524765,,,,,,,,,,,,,, -182600,8.685888,0.69899416,,,,,,,,,,,,,, -182698,,,0.9331353306770324,0.2427098006010055,0.7647599577903748,0.9616653919219972,50000.0,0.6443000435829163,1.6459656953811646,10000.0,62259.378823280334,64417.62543559074,62259.378823280334,2145.533955574036,6.500751495361328,0.0 -182700,8.658873,0.75267756,,,,,,,,,,,,,, -182800,9.628446,0.75328803,,,,,,,,,,,,,, -182900,8.817741,0.6747673,,,,,,,,,,,,,, -183000,8.83674,0.76195246,,,,,,,,,,,,,, -183100,8.42681,0.6717283,,,,,,,,,,,,,, -183200,8.7888155,0.72200525,,,,,,,,,,,,,, -183300,8.757359,0.7820151,,,,,,,,,,,,,, -183400,8.991259,0.6838709,,,,,,,,,,,,,, -183500,8.700002,0.6989206,,,,,,,,,,,,,, -183600,8.711701,0.70731163,,,,,,,,,,,,,, -183700,9.172904,0.7119514,,,,,,,,,,,,,, -183800,8.822271,0.812628,,,,,,,,,,,,,, -183900,8.813384,0.7580905,,,,,,,,,,,,,, -184000,8.424011,0.7045737,,,,,,,,,,,,,, -184100,9.683324,0.6798862,,,,,,,,,,,,,, -184197,,,0.933812975883484,0.240532174706459,0.7651599645614624,0.9626170992851256,50000.0,0.6438000202178955,1.646983623504639,10000.0,62769.5444791317,64945.14319252968,62769.5444791317,2162.772839546204,6.56220269203186,0.0 -184200,9.425276,0.76241654,,,,,,,,,,,,,, -184300,8.3109045,0.6919412,,,,,,,,,,,,,, -184400,8.572885,0.78039545,,,,,,,,,,,,,, -184500,10.229056,0.7844906,,,,,,,,,,,,,, -184600,10.210766,0.77496,,,,,,,,,,,,,, -184700,9.293252,0.769884,,,,,,,,,,,,,, -184800,8.69773,0.7325858,,,,,,,,,,,,,, -184898,,,,,,,,,,,63008.04301953316,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 3121d7186..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.358678817749023,0.0,33.559484243392944,1,0,33.559484243392944,0.0012000000569969,6.910790920257568,10000,50.918386936187744,0.001335299690254,6.91108512878418,0.0010199999669566,6.91091251373291,50000 -34.713916540145874,0.0192787647247314,543.7184791564941,1492,0,543.7184791564941,0.1371000111103058,4.80158805847168,10000,578.5034921169281,0.1883968412876129,4.257904052734375,0.1703599989414215,4.3842315673828125,50000 -51.93747568130493,0.0497360229492187,1053.8195695877075,2985,0,1053.8195695877075,0.2454000115394592,3.8982651233673096,10000,1105.9118909835815,0.3517418503761291,3.024991035461426,0.3260399997234344,3.193049192428589,50000 -69.63504076004028,0.0763628482818603,1563.8707466125488,4479,0,1563.8707466125488,0.2946000099182129,3.515520572662353,10000,1633.739863872528,0.409877210855484,2.6678028106689453,0.3876599967479706,2.80322003364563,50000 -87.07250952720642,0.1038854122161865,2073.9059176445007,5974,0,2073.9059176445007,0.2701000273227691,3.761731147766113,10000,2161.291288852692,0.4144810140132904,2.693990468978882,0.3527399897575378,3.0620031356811523,50000 -104.60787105560304,0.1320948600769043,2584.0340342521667,7470,0,2584.0340342521667,0.2379000186920166,4.055367946624756,10000,2689.035103559494,0.3179607689380646,3.236488103866577,0.2916599810123443,3.4889159202575684,50000 -121.91842865943909,0.1617681980133056,3094.1442699432373,8967,0,3094.1442699432373,0.2201000154018402,4.388706684112549,10000,3216.537830352783,0.3252750337123871,3.3958663940429688,0.3066799938678741,3.4880292415618896,50000 -139.31764602661133,0.1891367435455322,3604.070210456848,10464,0,3604.070210456848,0.1573000103235244,5.09550666809082,10000,3743.9415435791016,0.2239516824483871,4.2264933586120605,0.2140799909830093,4.341734409332275,50000 -157.3548982143402,0.2187190055847168,4114.044556617737,11962,0,4114.044556617737,0.2417000085115432,4.087606906890869,10000,4272.033766746521,0.3482939898967743,3.1312360763549805,0.3265599906444549,3.3243465423583984,50000 -174.88282465934753,0.248997688293457,4624.091902256012,13460,0,4624.091902256012,0.1425000131130218,5.075582027435303,10000,4799.691284656525,0.2033442258834839,4.297685623168945,0.1912999898195266,4.41984748840332,50000 -192.96386647224423,0.2813014984130859,5134.104543209076,14959,0,5134.104543209076,0.1761000156402588,4.722900867462158,10000,5327.868116140366,0.2525111436843872,3.852193355560303,0.2320999950170517,4.034037113189697,50000 -210.40357780456543,0.3103921413421631,5644.231992006302,16459,0,5644.231992006302,0.0839000046253204,6.400252819061279,10000,5855.515547513962,0.1354631632566452,5.622068405151367,0.1199799999594688,5.786570072174072,50000 -227.8885309696197,0.340724229812622,6154.353933811188,17959,0,6154.353933811188,0.254800021648407,3.834587574005127,10000,6383.206515073776,0.3703164756298065,2.9259555339813232,0.3415800034999847,3.128161191940308,50000 -245.5549459457397,0.3720059394836426,6664.5083796978,19460,0,6664.5083796978,0.1698000133037567,5.243460178375244,10000,6911.110089302063,0.248425543308258,4.295435428619385,0.2261399924755096,4.466485023498535,50000 -262.85289573669434,0.3986082077026367,7174.447612285614,20961,0,7174.447612285614,0.1830000132322311,4.831855773925781,10000,7438.4239411354065,0.2622568607330322,3.897775888442993,0.255400002002716,3.955528974533081,50000 -280.15942215919495,0.4302568435668945,7684.4579641819,22462,0,7684.4579641819,0.0861000046133995,6.424266338348389,10000,7965.822619438171,0.1203364133834838,5.798086643218994,0.1094199940562248,5.968713760375977,50000 -297.41808342933655,0.4613358974456787,8194.56183886528,23963,0,8194.56183886528,0.055500004440546,6.987226963043213,10000,8493.26726603508,0.0801777690649032,6.423879146575928,0.0747199952602386,6.47040605545044,50000 -315.2127788066864,0.494720458984375,8704.487569570541,25464,0,8704.487569570541,0.0408000014722347,9.22226333618164,10000,9021.073375463486,0.0676219686865806,8.211507797241211,0.0568999983370304,8.37997817993164,50000 -332.8436703681946,0.5292325019836426,9214.40810918808,26965,0,9214.40810918808,0.1159000024199485,5.787625312805176,10000,9548.710106372831,0.1829360574483871,4.75714921951294,0.1735799908638,4.842849254608154,50000 -350.31236839294434,0.5601718425750732,9724.49069738388,28467,0,9724.49069738388,0.1175000071525573,5.716974258422852,10000,10076.34494829178,0.1595384180545807,5.077648639678955,0.1496199965476989,5.205980777740479,50000 -367.536589384079,0.5917963981628418,10234.423271417618,29969,0,10234.423271417618,0.180400013923645,4.745062351226807,10000,10603.58622789383,0.2542450428009033,3.912951707839966,0.2386399954557418,4.062727928161621,50000 -384.77102971076965,0.6250383853912354,10744.5173869133,31471,0,10744.5173869133,0.2486000061035156,3.894976377487183,10000,11131.000636100767,0.342873066663742,3.108630418777466,0.3323999941349029,3.178079843521118,50000 -402.3353538513184,0.6580126285552979,11254.584214448929,32974,0,11254.584214448929,0.1951000094413757,5.094240665435791,10000,11658.715369939804,0.273138552904129,4.065606594085693,0.2622599899768829,4.184144973754883,50000 -419.9206705093384,0.6911814212799072,11764.796533107758,34477,0,11764.796533107758,0.1181000024080276,5.938706398010254,10000,12186.59768295288,0.1878587305545807,4.851592540740967,0.1707399934530258,5.058218955993652,50000 -437.3967123031616,0.727304220199585,12274.719363689424,35980,0,12274.719363689424,0.1688000112771988,4.942447185516357,10000,12714.083587408066,0.2466717064380645,4.081879615783691,0.231019988656044,4.22337007522583,50000 -454.6478660106659,0.764479398727417,12784.916873216627,37484,0,12784.916873216627,0.1565000116825103,5.578945159912109,10000,13241.621287107468,0.2384207546710968,4.461237907409668,0.2242599874734878,4.545414447784424,50000 -472.1811301708221,0.797905683517456,13295.052711725237,38987,0,13295.052711725237,0.1946000158786773,4.6907501220703125,10000,13769.374715805054,0.2852160334587097,3.727014541625977,0.2743600010871887,3.853545427322388,50000 -489.544287443161,0.8351831436157227,13804.977420091627,40490,0,13804.977420091627,0.2426000088453292,4.134172439575195,10000,14296.749955415726,0.3294602930545807,3.2785985469818115,0.3068799972534179,3.4649877548217773,50000 -506.81652092933655,0.8721778392791748,14315.08835530281,41994,0,14315.08835530281,0.0749000012874603,6.715982437133789,10000,14824.222242355348,0.1082987859845161,6.09940767288208,0.1052399948239326,6.120722770690918,50000 -524.241007566452,0.9064865112304688,14825.161313533785,43497,0,14825.161313533785,0.0809000059962272,5.751832008361816,10000,15351.805294513702,0.1314971297979354,4.971019744873047,0.1196599975228309,5.1093292236328125,50000 -541.479898929596,0.9436588287353516,15335.105654001236,44999,0,15335.105654001236,0.1255000084638595,5.8240647315979,10000,15879.07846212387,0.1791892498731613,5.077031135559082,0.1635999977588653,5.281920433044434,50000 -558.8642518520355,0.979116916656494,15845.155004501345,46503,0,15845.155004501345,0.2361000180244445,4.294202327728272,10000,16406.597977876663,0.3250558078289032,3.3820459842681885,0.3132599890232086,3.5227527618408203,50000 -576.4580047130585,1.020794153213501,16355.190060853958,48007,0,16355.190060853958,0.2668000161647796,3.835937261581421,10000,16934.3196310997,0.3681640625,3.034077405929565,0.3450599908828735,3.1679186820983887,50000 -593.8302228450775,1.0560753345489502,16865.38143491745,49511,0,16865.38143491745,0.203900009393692,4.691739082336426,10000,17461.97078728676,0.2969347834587097,3.669082164764404,0.2820599973201751,3.842806100845337,50000 -611.2783124446869,1.094315767288208,17375.29843711853,51014,0,17375.29843711853,0.1058000028133392,5.460964679718018,10000,17989.424177646637,0.1538384854793548,4.814224243164063,0.1469199955463409,4.868548393249512,50000 -629.0414938926697,1.1306431293487549,17885.518147945404,52519,0,17885.518147945404,0.3015000224113464,3.607529878616333,10000,18517.496727705,0.4587452113628387,2.4128456115722656,0.3999799787998199,2.8159847259521484,50000 -646.2846746444702,1.1675300598144531,18395.49843478203,54023,0,18395.49843478203,0.169400006532669,4.901164531707764,10000,19044.81125664711,0.2306481152772903,4.139097213745117,0.2137999981641769,4.321166038513184,50000 -664.4693939685822,1.2078561782836914,18905.47820210457,55527,0,18905.47820210457,0.172200009226799,4.754857063293457,10000,19573.068316936493,0.2505779564380646,3.95234489440918,0.2338999956846237,4.098372936248779,50000 -681.7302577495575,1.2400367259979248,19415.634298086166,57032,0,19415.634298086166,0.0820000022649765,6.516331672668457,10000,20100.56913280487,0.1284478604793548,5.659578323364258,0.1220599934458732,5.704127788543701,50000 -698.9143342971802,1.2803306579589844,19925.873183965683,58537,0,19925.873183965683,0.1523000001907348,5.230589866638184,10000,20628.08365273476,0.2176538556814193,4.364476680755615,0.2053599953651428,4.498104095458984,50000 -716.4034960269928,1.3220326900482178,20435.8396422863,60041,0,20435.8396422863,0.1358000040054321,6.057143211364746,10000,21155.63431572914,0.1741868555545807,5.3501811027526855,0.1663800030946731,5.571073055267334,50000 -734.0819170475006,1.3605809211730957,20945.94157910347,61546,0,20945.94157910347,0.1471000015735626,5.390626907348633,10000,21683.503759384155,0.2065329998731613,4.479652404785156,0.1864800006151199,4.748819351196289,50000 -751.711775302887,1.405496597290039,21456.14753627777,63052,0,21456.14753627777,0.2154000103473663,4.478518962860107,10000,22211.436054468155,0.3102877736091614,3.508485078811645,0.2889399826526642,3.657221794128418,50000 -768.9775350093842,1.437954664230347,21966.07893872261,64556,0,21966.07893872261,0.2107000052928924,4.761738300323486,10000,22738.717492341995,0.3029137253761291,3.726241111755371,0.2832199931144714,3.931478977203369,50000 -786.4678730964661,1.4794397354125977,22476.09085536003,66061,0,22476.09085536003,0.2424000054597854,4.018667221069336,10000,23266.313386917114,0.343849629163742,3.137306928634644,0.3238599896430969,3.2995736598968506,50000 -803.9587597846985,1.5198431015014648,22986.0070104599,67566,0,22986.0070104599,0.2331000119447708,4.2524847984313965,10000,23793.81243133545,0.32421875,3.3596620559692383,0.3198599815368652,3.3958446979522705,50000 -821.3490543365479,1.5737159252166748,23496.11438536644,69071,0,23496.11438536644,0.2605000138282776,3.976634979248047,10000,24321.4155292511,0.3446069657802582,3.2224009037017822,0.3281199932098388,3.301442623138428,50000 -838.9119355678558,1.6166942119598389,24006.213018655777,70576,0,24006.213018655777,0.2357000112533569,3.938102960586548,10000,24849.170696020126,0.3404615819454193,3.1158764362335205,0.3131199777126312,3.278167486190796,50000 -856.3283641338348,1.6619768142700195,24516.36786627769,72081,0,24516.36786627769,0.3187000155448913,3.359588146209717,10000,25376.83913421631,0.4693478941917419,2.3059208393096924,0.4316200017929077,2.5475010871887207,50000 -873.6862845420837,1.7103638648986816,25026.566119670868,73586,0,25026.566119670868,0.3611000180244446,3.1056861877441406,10000,25904.495133399963,0.5002591013908386,2.164903163909912,0.45933997631073,2.4115843772888184,50000 -891.2308826446533,1.750760555267334,25536.524069309235,75091,0,25536.524069309235,0.2302000075578689,4.258174419403076,10000,26432.089807510376,0.3360172212123871,3.280001163482666,0.3144199848175049,3.4575531482696533,50000 -908.7737927436827,1.7912750244140625,26046.467567443848,76595,0,26046.467567443848,0.2645000219345093,4.023133277893066,10000,26959.668189525604,0.3689811825752258,3.100395917892456,0.3448599874973297,3.260609865188598,50000 -926.1273169517516,1.8315112590789795,26556.542511701584,78100,0,26556.542511701584,0.1477000117301941,5.23762321472168,10000,27487.187576293945,0.2097417116165161,4.448638916015625,0.1947399973869323,4.624502658843994,50000 -943.505437374115,1.871402740478516,27066.480364322662,79604,0,27066.480364322662,0.1488000005483627,5.156861782073975,10000,28014.59451031685,0.2213209420442581,4.309572219848633,0.207639992237091,4.423527240753174,50000 -961.0810222625732,1.9099252223968504,27576.63488388061,81109,0,27576.63488388061,0.2300000041723251,4.20773458480835,10000,28542.41470336914,0.3338049948215484,3.2794435024261475,0.3070800006389618,3.5193474292755127,50000 -978.6555554866792,1.955564022064209,28086.69702768325,82614,0,28086.69702768325,0.2820000052452087,3.757661819458008,10000,29070.151161670685,0.395228773355484,2.817749261856079,0.367499977350235,3.013368606567383,50000 -996.0850801467896,1.9987788200378416,28596.871851682663,84119,0,28596.871851682663,0.2729000151157379,3.788983106613159,10000,29597.85049009323,0.3988759517669678,2.8061952590942383,0.3729199767112732,3.0000438690185547,50000 -1013.550271511078,2.044531345367432,29107.02249646187,85624,0,29107.02249646187,0.3519000113010406,3.282273292541504,10000,30125.56484889984,0.4944395720958709,2.248067140579224,0.4607999920845032,2.458147525787353,50000 -1031.0666897296906,2.0885136127471924,29617.0174241066,87129,0,29617.0174241066,0.0892000049352645,5.882399559020996,10000,30653.171503305435,0.1268136203289032,5.291484355926514,0.1218999996781349,5.354971885681152,50000 -1048.3781082630155,2.131108283996582,30126.98759460449,88633,0,30126.98759460449,0.3120000064373016,3.441679000854492,10000,31180.547289133072,0.4404296875,2.522449254989624,0.4178799986839294,2.665285348892212,50000 -1065.8279626369476,2.1842033863067627,30637.00756168365,90138,0,30637.00756168365,0.3660000264644623,3.175899744033813,10000,31708.12091970444,0.5174585580825806,2.07250714302063,0.4711000025272369,2.372112512588501,50000 -1083.9716200828552,2.2286715507507324,31147.1129257679,91643,0,31147.1129257679,0.3335000276565552,3.2831525802612305,10000,32236.46764421463,0.4838368892669678,2.2562386989593506,0.4466199874877929,2.4910802841186523,50000 -1101.381390094757,2.27234148979187,31657.05765509605,93148,0,31657.05765509605,0.3620000183582306,3.1680712699890137,10000,32763.91767191887,0.5083904266357422,2.1743996143341064,0.4763799905776977,2.3738784790039062,50000 -1118.651871919632,2.31717848777771,32167.045583724976,94650,0,32167.045583724976,0.3964000046253204,2.8958516120910645,10000,33291.27138733864,0.545918345451355,1.944828987121582,0.5107399821281433,2.1304807662963867,50000 -1136.1447837352753,2.3657150268554688,32676.99285697937,96154,0,32676.99285697937,0.2696000039577484,3.772510766983032,10000,33818.811498880386,0.3833506107330322,2.867633581161499,0.359959989786148,3.0436670780181885,50000 -1153.468049287796,2.4127352237701416,33187.0912899971,97659,0,33187.0912899971,0.3591000139713287,3.0999302864074707,10000,34346.33107614517,0.5059789419174194,2.157892942428589,0.475160002708435,2.326932668685913,50000 -1170.9882607460022,2.460010766983032,33697.07434248924,99164,0,33697.07434248924,0.3920000195503235,2.879778385162353,10000,34873.93353009224,0.5677016973495483,1.840865135192871,0.5112400054931641,2.122875928878784,50000 -1188.3502779006958,2.5082507133483887,34207.1025724411,100669,0,34207.1025724411,0.3048000037670135,3.538769245147705,10000,35401.42611813545,0.4465082883834839,2.48357892036438,0.4145599901676178,2.690641403198242,50000 -1205.7857220172882,2.556874513626098,34717.1273932457,102174,0,34717.1273932457,0.3846000134944916,2.99954605102539,10000,35928.986558914185,0.5240353941917419,2.0571160316467285,0.4918999969959259,2.25252103805542,50000 -1223.3412280082705,2.6039724349975586,35227.254885435104,103679,0,35227.254885435104,0.3644000291824341,3.1780288219451904,10000,36456.76866769791,0.5104631781578064,2.159630060195923,0.471919983625412,2.3893167972564697,50000 -1240.666074514389,2.65444564819336,35737.33686733246,105184,0,35737.33686733246,0.388700008392334,2.9926905632019043,10000,36984.278197050095,0.5213249325752258,2.0907678604125977,0.4876799881458282,2.284471035003662,50000 -1258.0668041706083,2.7028579711914062,36247.34074640274,106689,0,36247.34074640274,0.3607000112533569,3.1824216842651367,10000,37511.78411793709,0.5049425959587097,2.1846439838409424,0.4718999862670898,2.3848423957824707,50000 -1275.2838730812073,2.7409493923187256,36757.33936190605,108194,0,36757.33936190605,0.3131000101566314,3.456578731536865,10000,38039.08936858177,0.4869260191917419,2.2414798736572266,0.4351199865341186,2.565911769866944,50000 -1292.6387770175934,2.7895431518554688,37267.33650159836,109699,0,37267.33650159836,0.399800032377243,2.946348190307617,10000,38566.54166126251,0.5515983700752258,1.900874376296997,0.5029199719429016,2.1730642318725586,50000 -1310.1845960617063,2.837157964706421,37777.31170344353,111203,0,37777.31170344353,0.3591000139713287,3.1881215572357178,10000,39094.16153287888,0.5053411722183228,2.157813310623169,0.4675799906253814,2.3963406085968018,50000 -1327.739814043045,2.886847972869873,38287.42334794998,112708,0,38287.42334794998,0.3433000147342682,3.334402084350586,10000,39621.92883038521,0.482800543308258,2.318225383758545,0.4461199939250946,2.545124292373657,50000 -1345.3750817775726,2.9340431690216064,38797.589587688446,114214,0,38797.589587688446,0.3836000263690948,3.029585599899292,10000,40149.82967543602,0.5449019074440002,1.9557873010635376,0.5061599612236023,2.1710429191589355,50000 -1362.5382542610168,2.980867624282837,39307.54098248482,115718,0,39307.54098248482,0.4082000255584717,2.8958003520965576,10000,40677.041610240936,0.5570591688156128,1.8826926946640008,0.527999997138977,2.093355178833008,50000 -1379.7358441352844,3.027287483215332,39817.586441755295,117223,0,39817.586441755295,0.4335000216960907,2.682591438293457,10000,41204.382727622986,0.6294443607330322,1.476343035697937,0.5562199950218201,1.8933178186416624,50000 -1396.9002561569214,3.074965238571167,40327.71361851692,118728,0,40327.71361851692,0.3305000066757202,3.4887149333953857,10000,41731.77513575554,0.4579878747463226,2.500563621520996,0.4254599809646606,2.721994638442993,50000 -1414.213954925537,3.124530076980591,40837.7944047451,120233,0,40837.7944047451,0.4288000166416168,2.7302029132843018,10000,42259.27208185196,0.6004663705825806,1.6669267416000366,0.5487599968910217,1.962204933166504,50000 -1431.6174721717834,3.175598382949829,41347.86622738838,121738,0,41347.86622738838,0.4227000176906585,2.80234169960022,10000,42786.84964418411,0.5869738459587097,1.7295794486999512,0.5414199829101562,1.9991382360458367,50000 -1449.155464887619,3.227324724197388,41858.08864212036,123244,0,41858.08864212036,0.3977000117301941,2.858058214187622,10000,43314.71329832077,0.5721460580825806,1.803208351135254,0.5319399833679199,2.029165267944336,50000 -1466.4475784301758,3.278285026550293,42368.04260277748,124749,0,42368.04260277748,0.4696000218391418,2.445818901062012,10000,43842.06088280678,0.6416613459587097,1.4549739360809326,0.5970799922943115,1.6888469457626345,50000 -1483.7346332073212,3.327343463897705,42877.98620200157,126254,0,42877.98620200157,0.4337000250816345,2.768241167068481,10000,44369.39072751999,0.6292649507522583,1.4908218383789062,0.5594800114631653,1.9337760210037231,50000 -1501.349282026291,3.3759353160858154,43387.98578906059,127759,0,43387.98578906059,0.4629000127315521,2.50555157661438,10000,44897.10552716255,0.6521045565605164,1.3957295417785645,0.5881999731063843,1.7238211631774902,50000 -1518.861918926239,3.42822790145874,43897.913105010986,129265,0,43897.913105010986,0.4722000360488891,2.4591338634490967,10000,45424.64857959747,0.6463648080825806,1.4183142185211182,0.5917400121688843,1.7195110321044922,50000 -1536.1293251514437,3.4792776107788086,44408.10325551033,130771,0,44408.10325551033,0.4835000336170196,2.3731515407562256,10000,45952.20838069916,0.6597377061843872,1.3550430536270142,0.6078000068664551,1.6432515382766724,50000 -1553.2879321575165,3.5328681468963623,44918.220878601074,132277,0,44918.220878601074,0.4361000061035156,2.7067646980285645,10000,46479.58945250511,0.5997488498687744,1.664632797241211,0.5585799813270569,1.9178231954574585,50000 -1570.7316064834597,3.585947036743164,45428.21768307686,133783,0,45428.21768307686,0.4651000201702118,2.53909683227539,10000,47007.13358902931,0.623445451259613,1.5336153507232666,0.577239990234375,1.8043313026428225,50000 -1588.1084327697754,3.637380599975586,45938.33998990059,135289,0,45938.33998990059,0.4607000350952148,2.485707998275757,10000,47534.735048532486,0.6740872263908386,1.2957236766815186,0.5938000082969666,1.7162760496139526,50000 -1605.5273563861847,3.687337875366211,46448.27879500389,136794,0,46448.27879500389,0.497700035572052,2.348714590072632,10000,48062.19457030296,0.6825972199440002,1.2317051887512207,0.6187599897384644,1.594221830368042,50000 -1622.827439069748,3.740651845932007,46958.36615371704,138300,0,46958.36615371704,0.5135000348091125,2.21742844581604,10000,48589.68688797951,0.6966079473495483,1.193784475326538,0.6342799663543701,1.5047619342803955,50000 -1640.2983448505402,3.79308032989502,47468.524106025696,139806,0,47468.524106025696,0.4682000279426574,2.447305917739868,10000,49117.42114567757,0.6668726205825806,1.3343795537948608,0.6074599623680115,1.647443413734436,50000 -1657.447308063507,3.841919660568237,47978.61784863472,141312,0,47978.61784863472,0.5249000191688538,2.221546173095703,10000,49644.76369333267,0.7024075388908386,1.1639950275421145,0.6421200037002563,1.4964680671691897,50000 -1675.4333300590515,3.895407199859619,48488.61668562889,142818,0,48488.61668562889,0.5152000188827515,2.194200038909912,10000,50172.85512781143,0.7105787396430969,1.1328246593475342,0.649399995803833,1.4419379234313965,50000 -1693.17382478714,3.937635183334351,48998.60055851936,144323,0,48998.60055851936,0.532800018787384,2.105572462081909,10000,50700.674156188965,0.749422013759613,0.9607897996902466,0.6571199893951416,1.4050936698913574,50000 -1710.3399860858915,3.988811254501343,49508.71371245384,145829,0,49508.71371245384,0.5408000349998474,2.057956457138061,10000,51228.05661511421,0.7521324753761292,0.9460116028785706,0.6695399880409241,1.3460698127746582,50000 -1727.4768795967102,4.0461931228637695,50018.80451273918,147335,0,50018.80451273918,0.5372000336647034,2.097731828689575,10000,51755.39443898201,0.7348732352256775,1.0262718200683594,0.6617599725723267,1.3817018270492554,50000 -1744.73273396492,4.108054876327515,50528.79414916039,148840,0,50528.79414916039,0.5409000515937805,2.0931143760681152,10000,52282.75450348854,0.7347337007522583,1.017677903175354,0.6627399921417236,1.392814040184021,50000 -1762.139191389084,4.162874937057495,51038.696164131165,150346,0,51038.696164131165,0.556600034236908,2.0261521339416504,10000,52810.16929721832,0.7524314522743225,0.9449632167816162,0.682159960269928,1.306384563446045,50000 -1779.5146894454956,4.226062297821045,51548.87925410271,151852,0,51548.87925410271,0.5592000484466553,2.0278122425079346,10000,53337.84234023094,0.7565967440605164,0.9240195155143738,0.6843799948692322,1.2921584844589231,50000 -1796.8985974788666,4.281770467758179,52058.87175607681,153357,0,52058.87175607681,0.5514000058174133,2.04368543624878,10000,53865.32513237,0.7704081535339355,0.8769794702529907,0.6804400086402893,1.3156960010528564,50000 -1814.3410923480988,4.336676836013794,52569.02870512009,154863,0,52569.02870512009,0.5733000040054321,1.9361377954483032,10000,54393.03028893471,0.7912148833274841,0.7727125287055969,0.6976999640464783,1.230203628540039,50000 -1831.7030427455904,4.396013021469116,53079.0686750412,156369,0,53079.0686750412,0.5689000487327576,1.9560490846633911,10000,54920.545063734055,0.7887436151504517,0.7775185704231262,0.6951199769973755,1.2335890531539917,50000 -1849.0799005031583,4.450916528701782,53589.254455804825,157875,0,53589.254455804825,0.5729000568389893,1.9439477920532229,10000,55448.21320796013,0.7881656289100647,0.7825496196746826,0.6997399926185608,1.2260318994522097,50000 -1866.303447008133,4.505049228668213,54099.491090774536,159382,0,54099.491090774536,0.5769000053405762,1.903723120689392,10000,55975.7786552906,0.7869299650192261,0.7920331358909607,0.7021799683570862,1.2132388353347778,50000 -1883.7833399772644,4.564141511917114,54609.7164978981,160888,0,54609.7164978981,0.5863000154495239,1.8755098581314087,10000,56503.59543037415,0.8048867583274841,0.7252900004386902,0.7111799716949463,1.1790040731430054,50000 -1900.9561693668363,4.627662658691406,55119.78022289276,162393,0,55119.78022289276,0.5861999988555908,1.867242455482483,10000,57030.947724580765,0.8147919178009033,0.689312756061554,0.7153199911117554,1.1588680744171145,50000 -1918.2251298427584,4.691629648208618,55629.80427956581,163898,0,55629.80427956581,0.5986000299453735,1.7955706119537354,10000,57558.35658311844,0.8363161683082581,0.6028015613555908,0.7212600111961365,1.1305629014968872,50000 -1935.3796126842497,4.7491774559021,56139.99762535095,165403,0,56139.99762535095,0.6026000380516052,1.800337553024292,10000,58085.81350302696,0.8367944955825806,0.594135046005249,0.7243399620056152,1.119240164756775,50000 -1952.791288852692,4.805515766143799,56649.96901440621,166907,0,56649.96901440621,0.6011000275611877,1.818732500076294,10000,58613.30370092392,0.8360371589660645,0.5923864245414734,0.7259199619293213,1.118510603904724,50000 -1970.818876504898,4.8666205406188965,57159.98609948158,168412,0,57159.98609948158,0.6110000014305115,1.7652047872543335,10000,59141.460246801376,0.8469985723495483,0.5509080290794373,0.7348999977111816,1.0821901559829712,50000 -1988.379658460617,4.923551321029663,57670.21110057831,169918,0,57670.21110057831,0.610200047492981,1.7671650648117063,10000,59669.35487747192,0.8470583558082581,0.5498757362365723,0.7370799779891968,1.0780162811279297,50000 -2005.89358830452,4.983810424804688,58180.16239523888,171423,0,58180.16239523888,0.6117000579833984,1.7470028400421145,10000,60196.933161735535,0.8548508882522583,0.5203155279159546,0.738599956035614,1.06183123588562,50000 -2023.2214815616608,5.041860818862915,58690.09550428391,172928,0,58690.09550428391,0.6200000047683716,1.7335784435272217,10000,60724.30287194252,0.870515763759613,0.4607931077480316,0.7422800064086914,1.0552629232406616,50000 -2040.50359749794,5.102504730224609,59200.02728843689,174433,0,59200.02728843689,0.6190000176429749,1.7485793828964231,10000,61251.63083457947,0.8722695708274841,0.4543417692184448,0.7435599565505981,1.0551124811172483,50000 -2057.974116802216,5.163911104202271,59710.2247235775,175938,0,59710.2247235775,0.6223000288009644,1.732593059539795,10000,61779.41142606735,0.8727478981018066,0.4520764946937561,0.7457000017166138,1.043248414993286,50000 -2075.270829439163,5.224120616912842,60220.32627439499,177443,0,60220.32627439499,0.6240000128746033,1.7179020643234253,10000,62306.92303276062,0.875996470451355,0.442630410194397,0.7466399669647217,1.0333547592163086,50000 -2092.762161254883,5.284364223480225,60730.23235201836,178947,0,60730.23235201836,0.6215000152587891,1.717150092124939,10000,62834.43216466904,0.8785873651504517,0.4270265698432922,0.7484599947929382,1.0259389877319336,50000 -2110.0475058555603,5.346153259277344,61240.38385462761,180453,0,61240.38385462761,0.6262000203132629,1.7159655094146729,10000,63361.98218679428,0.8818159699440002,0.412717342376709,0.7495399713516235,1.025421977043152,50000 -2127.516664505005,5.409894943237305,61750.597148656845,181959,0,61750.597148656845,0.6243000030517578,1.715208888053894,10000,63889.77865052223,0.8836495280265808,0.4111799895763397,0.7505399584770203,1.0216362476348877,50000 -2144.8360488414764,5.469226121902466,62260.746554374695,183465,0,62260.746554374695,0.6269000172615051,1.7149865627288818,10000,64417.35872173309,0.884785532951355,0.4123013913631439,0.7511999607086182,1.020043134689331,50000 -2162.1414761543274,5.530078649520874,62770.95883798599,184971,0,62770.95883798599,0.6272000074386597,1.7114449739456177,10000,64944.98903656006,0.8868184089660645,0.4034609794616699,0.7515400052070618,1.0192183256149292,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/measurements.csv deleted file mode 100644 index 8e736b065..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1983 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6531614,6.929166,,,,,,,,,,,,,, -1,,,0.001335299690254,6.91108512878418,0.0010199999669566,6.91091251373291,50000.0,0.0012000000569969,6.910790920257568,10000.0,33.559484243392944,50.918386936187744,33.559484243392944,17.358678817749023,0.0,0.0 -100,0.7672736,6.6424255,,,,,,,,,,,,,, -200,0.9619094,6.3292284,,,,,,,,,,,,,, -300,2.1507497,5.9401283,,,,,,,,,,,,,, -400,2.4506922,5.690971,,,,,,,,,,,,,, -500,3.0506642,5.570857,,,,,,,,,,,,,, -600,2.9295955,5.328566,,,,,,,,,,,,,, -700,4.853057,5.2638144,,,,,,,,,,,,,, -800,2.3795197,5.022225,,,,,,,,,,,,,, -900,3.2171402,4.8368683,,,,,,,,,,,,,, -1000,3.6711733,4.664507,,,,,,,,,,,,,, -1100,2.6973855,4.6019726,,,,,,,,,,,,,, -1200,3.6170826,4.3426514,,,,,,,,,,,,,, -1300,3.189373,4.2512784,,,,,,,,,,,,,, -1400,2.6716344,4.3425035,,,,,,,,,,,,,, -1492,,,0.1883968412876129,4.257904052734375,0.1703599989414215,4.3842315673828125,50000.0,0.1371000111103058,4.80158805847168,10000.0,543.7184791564941,578.5034921169281,543.7184791564941,34.713916540145874,0.0192787647247314,0.0 -1500,2.4763634,4.028164,,,,,,,,,,,,,, -1600,3.123665,3.7789989,,,,,,,,,,,,,, -1700,3.5173118,3.8563914,,,,,,,,,,,,,, -1800,2.0055864,3.7633653,,,,,,,,,,,,,, -1900,2.5799983,3.5252137,,,,,,,,,,,,,, -2000,1.2712091,3.6393354,,,,,,,,,,,,,, -2100,1.5412259,3.590814,,,,,,,,,,,,,, -2200,1.765729,3.4079437,,,,,,,,,,,,,, -2300,1.7299296,3.317143,,,,,,,,,,,,,, -2400,1.3506253,3.4638493,,,,,,,,,,,,,, -2500,1.4858583,3.2783766,,,,,,,,,,,,,, -2600,1.1777363,3.2358367,,,,,,,,,,,,,, -2700,1.3671924,3.2920458,,,,,,,,,,,,,, -2800,1.2872809,3.1455054,,,,,,,,,,,,,, -2900,1.0795953,3.197222,,,,,,,,,,,,,, -2985,,,0.3517418503761291,3.024991035461426,0.3260399997234344,3.193049192428589,50000.0,0.2454000115394592,3.8982651233673096,10000.0,1053.8195695877075,1105.9118909835815,1053.8195695877075,51.93747568130493,0.0497360229492187,0.0 -3000,1.2225429,3.1042798,,,,,,,,,,,,,, -3100,1.1064203,2.9783363,,,,,,,,,,,,,, -3200,1.4637954,3.1736813,,,,,,,,,,,,,, -3300,0.983143,2.9504056,,,,,,,,,,,,,, -3400,0.86144125,2.9597003,,,,,,,,,,,,,, -3500,1.0837718,2.9704478,,,,,,,,,,,,,, -3600,1.2643753,3.1491857,,,,,,,,,,,,,, -3700,0.83308464,2.869773,,,,,,,,,,,,,, -3800,0.8952084,2.9641547,,,,,,,,,,,,,, -3900,0.8604539,3.022597,,,,,,,,,,,,,, -4000,0.93444204,3.0720775,,,,,,,,,,,,,, -4100,0.87145543,2.8348444,,,,,,,,,,,,,, -4200,0.7847407,2.8678436,,,,,,,,,,,,,, -4300,0.7676903,2.8090982,,,,,,,,,,,,,, -4400,0.888631,2.6745234,,,,,,,,,,,,,, -4479,,,0.409877210855484,2.6678028106689453,0.3876599967479706,2.80322003364563,50000.0,0.2946000099182129,3.515520572662353,10000.0,1563.8707466125488,1633.739863872528,1563.8707466125488,69.63504076004028,0.0763628482818603,0.0 -4500,0.80054826,2.917738,,,,,,,,,,,,,, -4600,0.8445046,2.7644434,,,,,,,,,,,,,, -4700,0.7999933,2.6250896,,,,,,,,,,,,,, -4800,0.83881176,3.095702,,,,,,,,,,,,,, -4900,0.7327355,2.5869374,,,,,,,,,,,,,, -5000,0.87914044,2.7048376,,,,,,,,,,,,,, -5100,1.2046188,2.6923285,,,,,,,,,,,,,, -5200,0.8189364,2.6396284,,,,,,,,,,,,,, -5300,0.8335248,2.6447875,,,,,,,,,,,,,, -5400,0.8569059,2.600019,,,,,,,,,,,,,, -5500,0.9120422,2.5686326,,,,,,,,,,,,,, -5600,1.1774696,2.7458358,,,,,,,,,,,,,, -5700,0.8829196,2.5345829,,,,,,,,,,,,,, -5800,0.7978887,2.605811,,,,,,,,,,,,,, -5900,0.80835414,2.7050722,,,,,,,,,,,,,, -5974,,,0.4144810140132904,2.693990468978882,0.3527399897575378,3.0620031356811523,50000.0,0.2701000273227691,3.761731147766113,10000.0,2073.9059176445007,2161.291288852692,2073.9059176445007,87.07250952720642,0.1038854122161865,0.0 -6000,1.2197915,2.6445732,,,,,,,,,,,,,, -6100,0.967302,2.6515353,,,,,,,,,,,,,, -6200,0.9506726,2.609985,,,,,,,,,,,,,, -6300,0.8169945,2.5640774,,,,,,,,,,,,,, -6400,0.913004,2.6594427,,,,,,,,,,,,,, -6500,1.022871,2.6053815,,,,,,,,,,,,,, -6600,0.94332385,2.6134987,,,,,,,,,,,,,, -6700,0.89612997,2.5704477,,,,,,,,,,,,,, -6800,1.0850705,2.4734228,,,,,,,,,,,,,, -6900,0.9202174,2.5919929,,,,,,,,,,,,,, -7000,0.98100084,2.5216038,,,,,,,,,,,,,, -7100,0.91553324,2.605489,,,,,,,,,,,,,, -7200,0.9353387,2.6080585,,,,,,,,,,,,,, -7300,0.9809895,2.599258,,,,,,,,,,,,,, -7400,0.8069615,2.4749491,,,,,,,,,,,,,, -7470,,,0.3179607689380646,3.236488103866577,0.2916599810123443,3.4889159202575684,50000.0,0.2379000186920166,4.055367946624756,10000.0,2584.0340342521667,2689.035103559494,2584.0340342521667,104.60787105560304,0.1320948600769043,0.0 -7500,0.97159654,2.567184,,,,,,,,,,,,,, -7600,1.0324872,2.6230702,,,,,,,,,,,,,, -7700,1.0532495,2.63356,,,,,,,,,,,,,, -7800,0.93305516,2.3914182,,,,,,,,,,,,,, -7900,1.2146416,2.6135166,,,,,,,,,,,,,, -8000,0.9066014,2.5654964,,,,,,,,,,,,,, -8100,0.9149792,2.5258226,,,,,,,,,,,,,, -8200,1.1076831,2.5354552,,,,,,,,,,,,,, -8300,1.0641744,2.4826436,,,,,,,,,,,,,, -8400,0.85520613,2.4054775,,,,,,,,,,,,,, -8500,1.0009866,2.5888073,,,,,,,,,,,,,, -8600,0.9973144,2.5029132,,,,,,,,,,,,,, -8700,0.83546615,2.4103746,,,,,,,,,,,,,, -8800,0.9320057,2.5232918,,,,,,,,,,,,,, -8900,1.025157,2.4585006,,,,,,,,,,,,,, -8967,,,0.3252750337123871,3.3958663940429688,0.3066799938678741,3.4880292415618896,50000.0,0.2201000154018402,4.388706684112549,10000.0,3094.1442699432373,3216.537830352783,3094.1442699432373,121.91842865943909,0.1617681980133056,0.0 -9000,1.0528735,2.6053727,,,,,,,,,,,,,, -9100,1.0032912,2.4487853,,,,,,,,,,,,,, -9200,0.9162326,2.4454055,,,,,,,,,,,,,, -9300,0.95463896,2.4373531,,,,,,,,,,,,,, -9400,1.0700043,2.4126828,,,,,,,,,,,,,, -9500,1.0540864,2.5611641,,,,,,,,,,,,,, -9600,1.0754498,2.4592571,,,,,,,,,,,,,, -9700,1.064342,2.4800766,,,,,,,,,,,,,, -9800,0.88253266,2.4512393,,,,,,,,,,,,,, -9900,0.8696155,2.4777784,,,,,,,,,,,,,, -10000,0.9844834,2.4731045,,,,,,,,,,,,,, -10100,1.1342425,2.5354066,,,,,,,,,,,,,, -10200,1.016848,2.455665,,,,,,,,,,,,,, -10300,1.0724498,2.4363413,,,,,,,,,,,,,, -10400,0.9315239,2.4898643,,,,,,,,,,,,,, -10464,,,0.2239516824483871,4.2264933586120605,0.2140799909830093,4.341734409332275,50000.0,0.1573000103235244,5.09550666809082,10000.0,3604.070210456848,3743.9415435791016,3604.070210456848,139.31764602661133,0.1891367435455322,0.0 -10500,1.1965742,2.4058366,,,,,,,,,,,,,, -10600,1.034492,2.3867605,,,,,,,,,,,,,, -10700,1.1487445,2.4014795,,,,,,,,,,,,,, -10800,1.0011661,2.4957697,,,,,,,,,,,,,, -10900,0.9147613,2.2729683,,,,,,,,,,,,,, -11000,0.89272606,2.368489,,,,,,,,,,,,,, -11100,0.9912311,2.526351,,,,,,,,,,,,,, -11200,1.0085499,2.4568348,,,,,,,,,,,,,, -11300,0.97024304,2.3973012,,,,,,,,,,,,,, -11400,0.9992649,2.4047766,,,,,,,,,,,,,, -11500,1.2988391,2.4587963,,,,,,,,,,,,,, -11600,1.0107858,2.4289608,,,,,,,,,,,,,, -11700,0.9093749,2.2780147,,,,,,,,,,,,,, -11800,1.0113859,2.3645108,,,,,,,,,,,,,, -11900,1.0279081,2.4065623,,,,,,,,,,,,,, -11962,,,0.3482939898967743,3.1312360763549805,0.3265599906444549,3.3243465423583984,50000.0,0.2417000085115432,4.087606906890869,10000.0,4114.044556617737,4272.033766746521,4114.044556617737,157.3548982143402,0.2187190055847168,0.0 -12000,1.0592892,2.434352,,,,,,,,,,,,,, -12100,1.0370669,2.426211,,,,,,,,,,,,,, -12200,0.94783455,2.40209,,,,,,,,,,,,,, -12300,0.94474274,2.2974377,,,,,,,,,,,,,, -12400,1.1287715,2.4212928,,,,,,,,,,,,,, -12500,0.9678049,2.3939145,,,,,,,,,,,,,, -12600,0.999465,2.5259175,,,,,,,,,,,,,, -12700,1.0671076,2.4091353,,,,,,,,,,,,,, -12800,1.0698425,2.4448664,,,,,,,,,,,,,, -12900,1.0720456,2.4024434,,,,,,,,,,,,,, -13000,0.88731045,2.3755717,,,,,,,,,,,,,, -13100,1.0001521,2.4895327,,,,,,,,,,,,,, -13200,0.93177855,2.4379404,,,,,,,,,,,,,, -13300,0.97856474,2.4127166,,,,,,,,,,,,,, -13400,1.0037271,2.549368,,,,,,,,,,,,,, -13460,,,0.2033442258834839,4.297685623168945,0.1912999898195266,4.41984748840332,50000.0,0.1425000131130218,5.075582027435303,10000.0,4624.091902256012,4799.691284656525,4624.091902256012,174.88282465934753,0.248997688293457,0.0 -13500,1.0806843,2.4560242,,,,,,,,,,,,,, -13600,1.0009224,2.4134734,,,,,,,,,,,,,, -13700,0.98140687,2.282486,,,,,,,,,,,,,, -13800,1.0038565,2.5824347,,,,,,,,,,,,,, -13900,1.0071911,2.3878386,,,,,,,,,,,,,, -14000,1.1890943,2.59445,,,,,,,,,,,,,, -14100,1.1118933,2.3351476,,,,,,,,,,,,,, -14200,0.97324437,2.298792,,,,,,,,,,,,,, -14300,1.0722722,2.4365685,,,,,,,,,,,,,, -14400,1.0747099,2.4039302,,,,,,,,,,,,,, -14500,1.0625736,2.390959,,,,,,,,,,,,,, -14600,1.0258265,2.354742,,,,,,,,,,,,,, -14700,1.045776,2.4445212,,,,,,,,,,,,,, -14800,1.0250435,2.3767195,,,,,,,,,,,,,, -14900,0.9719578,2.2411907,,,,,,,,,,,,,, -14959,,,0.2525111436843872,3.852193355560303,0.2320999950170517,4.034037113189697,50000.0,0.1761000156402588,4.722900867462158,10000.0,5134.104543209076,5327.868116140366,5134.104543209076,192.96386647224423,0.2813014984130859,0.0 -15000,1.0044214,2.3047307,,,,,,,,,,,,,, -15100,0.99111736,2.457884,,,,,,,,,,,,,, -15200,1.0332984,2.3742101,,,,,,,,,,,,,, -15300,1.2072846,2.310071,,,,,,,,,,,,,, -15400,1.0340393,2.4702325,,,,,,,,,,,,,, -15500,0.9565481,2.3581378,,,,,,,,,,,,,, -15600,1.0496479,2.5377297,,,,,,,,,,,,,, -15700,1.0753275,2.4842799,,,,,,,,,,,,,, -15800,1.0877749,2.3387005,,,,,,,,,,,,,, -15900,1.0704283,2.3652713,,,,,,,,,,,,,, -16000,1.0435965,2.468936,,,,,,,,,,,,,, -16100,1.0229008,2.2731109,,,,,,,,,,,,,, -16200,1.0083568,2.3592794,,,,,,,,,,,,,, -16300,0.9368559,2.2079911,,,,,,,,,,,,,, -16400,1.0245024,2.3782716,,,,,,,,,,,,,, -16459,,,0.1354631632566452,5.622068405151367,0.1199799999594688,5.786570072174072,50000.0,0.0839000046253204,6.400252819061279,10000.0,5644.231992006302,5855.515547513962,5644.231992006302,210.40357780456543,0.3103921413421631,0.0 -16500,1.055395,2.3758225,,,,,,,,,,,,,, -16600,1.0829221,2.3259523,,,,,,,,,,,,,, -16700,0.9667781,2.3858416,,,,,,,,,,,,,, -16800,1.0407605,2.3867917,,,,,,,,,,,,,, -16900,1.0122159,2.294303,,,,,,,,,,,,,, -17000,0.9903906,2.3675947,,,,,,,,,,,,,, -17100,1.1077825,2.4027796,,,,,,,,,,,,,, -17200,1.0960785,2.334458,,,,,,,,,,,,,, -17300,1.0767305,2.2812593,,,,,,,,,,,,,, -17400,1.1421883,2.3353198,,,,,,,,,,,,,, -17500,1.0288837,2.3783426,,,,,,,,,,,,,, -17600,1.0030233,2.32002,,,,,,,,,,,,,, -17700,1.2920172,2.450674,,,,,,,,,,,,,, -17800,1.0251209,2.387851,,,,,,,,,,,,,, -17900,1.0975407,2.4138517,,,,,,,,,,,,,, -17959,,,0.3703164756298065,2.9259555339813232,0.3415800034999847,3.128161191940308,50000.0,0.254800021648407,3.834587574005127,10000.0,6154.353933811188,6383.206515073776,6154.353933811188,227.8885309696197,0.340724229812622,0.0 -18000,1.0235829,2.4943974,,,,,,,,,,,,,, -18100,0.9583505,2.30091,,,,,,,,,,,,,, -18200,1.0700735,2.3143501,,,,,,,,,,,,,, -18300,1.2153666,2.3392813,,,,,,,,,,,,,, -18400,1.0994347,2.4205482,,,,,,,,,,,,,, -18500,1.1846712,2.3489952,,,,,,,,,,,,,, -18600,1.2123094,2.3504584,,,,,,,,,,,,,, -18700,1.0074127,2.2672439,,,,,,,,,,,,,, -18800,1.0114105,2.4785495,,,,,,,,,,,,,, -18900,1.0030303,2.3643222,,,,,,,,,,,,,, -19000,1.0881717,2.4609566,,,,,,,,,,,,,, -19100,0.92403406,2.3599362,,,,,,,,,,,,,, -19200,0.98316205,2.4436011,,,,,,,,,,,,,, -19300,1.0484244,2.401418,,,,,,,,,,,,,, -19400,1.0559896,2.5640862,,,,,,,,,,,,,, -19460,,,0.248425543308258,4.295435428619385,0.2261399924755096,4.466485023498535,50000.0,0.1698000133037567,5.243460178375244,10000.0,6664.5083796978,6911.110089302063,6664.5083796978,245.5549459457397,0.3720059394836426,0.0 -19500,1.095819,2.417488,,,,,,,,,,,,,, -19600,1.0848385,2.3471684,,,,,,,,,,,,,, -19700,1.0292523,2.3539324,,,,,,,,,,,,,, -19800,1.0556091,2.3300915,,,,,,,,,,,,,, -19900,0.9803505,2.351237,,,,,,,,,,,,,, -20000,1.1103721,2.3591597,,,,,,,,,,,,,, -20100,1.0187924,2.2758505,,,,,,,,,,,,,, -20200,1.0408918,2.3322725,,,,,,,,,,,,,, -20300,1.1018445,2.4572814,,,,,,,,,,,,,, -20400,1.0460482,2.3196878,,,,,,,,,,,,,, -20500,1.0076284,2.2576187,,,,,,,,,,,,,, -20600,1.1202339,2.2684226,,,,,,,,,,,,,, -20700,1.0183448,2.2760262,,,,,,,,,,,,,, -20800,0.99246544,2.2655804,,,,,,,,,,,,,, -20900,1.223481,2.5669181,,,,,,,,,,,,,, -20961,,,0.2622568607330322,3.897775888442993,0.255400002002716,3.955528974533081,50000.0,0.1830000132322311,4.831855773925781,10000.0,7174.447612285614,7438.4239411354065,7174.447612285614,262.85289573669434,0.3986082077026367,0.0 -21000,1.0441684,2.3402963,,,,,,,,,,,,,, -21100,1.0869299,2.3962464,,,,,,,,,,,,,, -21200,1.0795344,2.3338485,,,,,,,,,,,,,, -21300,1.0573488,2.3406043,,,,,,,,,,,,,, -21400,0.9691244,2.310839,,,,,,,,,,,,,, -21500,1.0627607,2.35842,,,,,,,,,,,,,, -21600,1.114693,2.4816165,,,,,,,,,,,,,, -21700,0.8777721,2.2577548,,,,,,,,,,,,,, -21800,1.1197262,2.3942873,,,,,,,,,,,,,, -21900,1.1426243,2.445471,,,,,,,,,,,,,, -22000,1.0616854,2.3096943,,,,,,,,,,,,,, -22100,0.96091217,2.2773798,,,,,,,,,,,,,, -22200,1.1015391,2.323492,,,,,,,,,,,,,, -22300,1.0673031,2.2845585,,,,,,,,,,,,,, -22400,1.0177022,2.3661098,,,,,,,,,,,,,, -22462,,,0.1203364133834838,5.798086643218994,0.1094199940562248,5.968713760375977,50000.0,0.0861000046133995,6.424266338348389,10000.0,7684.4579641819,7965.822619438171,7684.4579641819,280.15942215919495,0.4302568435668945,0.0 -22500,1.1969266,2.4195275,,,,,,,,,,,,,, -22600,0.99209166,2.2850332,,,,,,,,,,,,,, -22700,1.08749,2.2524755,,,,,,,,,,,,,, -22800,1.0784378,2.3422596,,,,,,,,,,,,,, -22900,1.0239258,2.2144341,,,,,,,,,,,,,, -23000,0.9899042,2.3199005,,,,,,,,,,,,,, -23100,1.1089664,2.3488145,,,,,,,,,,,,,, -23200,1.0782187,2.392839,,,,,,,,,,,,,, -23300,1.1327994,2.3384194,,,,,,,,,,,,,, -23400,1.0671921,2.4189613,,,,,,,,,,,,,, -23500,1.0180404,2.37964,,,,,,,,,,,,,, -23600,1.0864592,2.3171473,,,,,,,,,,,,,, -23700,1.0610716,2.3193774,,,,,,,,,,,,,, -23800,1.141523,2.3494537,,,,,,,,,,,,,, -23900,1.034215,2.296654,,,,,,,,,,,,,, -23963,,,0.0801777690649032,6.423879146575928,0.0747199952602386,6.47040605545044,50000.0,0.055500004440546,6.987226963043213,10000.0,8194.56183886528,8493.26726603508,8194.56183886528,297.41808342933655,0.4613358974456787,0.0 -24000,1.2011037,2.5252235,,,,,,,,,,,,,, -24100,1.0756179,2.3493803,,,,,,,,,,,,,, -24200,1.0473636,2.3331954,,,,,,,,,,,,,, -24300,1.0482988,2.321639,,,,,,,,,,,,,, -24400,1.0462558,2.306865,,,,,,,,,,,,,, -24500,1.0739896,2.2838755,,,,,,,,,,,,,, -24600,1.1148837,2.2848454,,,,,,,,,,,,,, -24700,1.0609183,2.2872014,,,,,,,,,,,,,, -24800,1.0390424,2.2613518,,,,,,,,,,,,,, -24900,1.0272499,2.4661207,,,,,,,,,,,,,, -25000,1.0683968,2.3171768,,,,,,,,,,,,,, -25100,1.1079594,2.3088493,,,,,,,,,,,,,, -25200,1.1252234,2.2589526,,,,,,,,,,,,,, -25300,1.0455253,2.3376,,,,,,,,,,,,,, -25400,1.0896215,2.4404116,,,,,,,,,,,,,, -25464,,,0.0676219686865806,8.211507797241211,0.0568999983370304,8.37997817993164,50000.0,0.0408000014722347,9.22226333618164,10000.0,8704.487569570541,9021.073375463486,8704.487569570541,315.2127788066864,0.494720458984375,0.0 -25500,1.0429707,2.239365,,,,,,,,,,,,,, -25600,1.0338161,2.3827715,,,,,,,,,,,,,, -25700,1.0657951,2.3631413,,,,,,,,,,,,,, -25800,1.138355,2.3083978,,,,,,,,,,,,,, -25900,1.1272019,2.5028884,,,,,,,,,,,,,, -26000,1.1274525,2.2079372,,,,,,,,,,,,,, -26100,1.0108885,2.3485284,,,,,,,,,,,,,, -26200,1.0901039,2.372394,,,,,,,,,,,,,, -26300,1.1955364,2.348518,,,,,,,,,,,,,, -26400,1.0522305,2.4067354,,,,,,,,,,,,,, -26500,1.0446239,2.2671592,,,,,,,,,,,,,, -26600,1.1777837,2.2046943,,,,,,,,,,,,,, -26700,1.06169,2.2123404,,,,,,,,,,,,,, -26800,1.204788,2.389175,,,,,,,,,,,,,, -26900,1.1343434,2.2966006,,,,,,,,,,,,,, -26965,,,0.1829360574483871,4.75714921951294,0.1735799908638,4.842849254608154,50000.0,0.1159000024199485,5.787625312805176,10000.0,9214.40810918808,9548.710106372831,9214.40810918808,332.8436703681946,0.5292325019836426,0.0 -27000,1.1255238,2.5115833,,,,,,,,,,,,,, -27100,1.084219,2.4149778,,,,,,,,,,,,,, -27200,1.1452478,2.4354692,,,,,,,,,,,,,, -27300,1.088129,2.2247481,,,,,,,,,,,,,, -27400,1.1701163,2.2228317,,,,,,,,,,,,,, -27500,1.041835,2.2422957,,,,,,,,,,,,,, -27600,1.1107798,2.3031082,,,,,,,,,,,,,, -27700,1.2341089,2.2759364,,,,,,,,,,,,,, -27800,1.1121134,2.3522453,,,,,,,,,,,,,, -27900,1.0395012,2.3693504,,,,,,,,,,,,,, -28000,1.0961151,2.2818851,,,,,,,,,,,,,, -28100,1.1493508,2.297175,,,,,,,,,,,,,, -28200,1.3397979,2.2352595,,,,,,,,,,,,,, -28300,1.1164963,2.2415187,,,,,,,,,,,,,, -28400,1.2094973,2.3606188,,,,,,,,,,,,,, -28467,,,0.1595384180545807,5.077648639678955,0.1496199965476989,5.205980777740479,50000.0,0.1175000071525573,5.716974258422852,10000.0,9724.49069738388,10076.34494829178,9724.49069738388,350.31236839294434,0.5601718425750732,0.0 -28500,1.172932,2.3475785,,,,,,,,,,,,,, -28600,1.1692448,2.2219815,,,,,,,,,,,,,, -28700,1.007828,2.2884579,,,,,,,,,,,,,, -28800,1.0887965,2.2847147,,,,,,,,,,,,,, -28900,1.0713553,2.363156,,,,,,,,,,,,,, -29000,1.4010901,2.425818,,,,,,,,,,,,,, -29100,1.0095912,2.1249468,,,,,,,,,,,,,, -29200,1.1331974,2.2567945,,,,,,,,,,,,,, -29300,1.0740883,2.3598044,,,,,,,,,,,,,, -29400,1.1478283,2.3690891,,,,,,,,,,,,,, -29500,1.1017369,2.3604045,,,,,,,,,,,,,, -29600,1.1917809,2.2889519,,,,,,,,,,,,,, -29700,1.3242781,2.282839,,,,,,,,,,,,,, -29800,1.2364552,2.2301717,,,,,,,,,,,,,, -29900,1.1356739,2.3158884,,,,,,,,,,,,,, -29969,,,0.2542450428009033,3.912951707839966,0.2386399954557418,4.062727928161621,50000.0,0.180400013923645,4.745062351226807,10000.0,10234.423271417618,10603.58622789383,10234.423271417618,367.536589384079,0.5917963981628418,0.0 -30000,1.1323067,2.2738924,,,,,,,,,,,,,, -30100,1.0881811,2.2912571,,,,,,,,,,,,,, -30200,1.064703,2.2937737,,,,,,,,,,,,,, -30300,1.1248708,2.2275324,,,,,,,,,,,,,, -30400,1.0814004,2.2484288,,,,,,,,,,,,,, -30500,1.0965766,2.3326821,,,,,,,,,,,,,, -30600,1.0685058,2.351005,,,,,,,,,,,,,, -30700,1.022489,2.305633,,,,,,,,,,,,,, -30800,1.1891448,2.2874286,,,,,,,,,,,,,, -30900,1.2230538,2.182258,,,,,,,,,,,,,, -31000,1.0610377,2.29152,,,,,,,,,,,,,, -31100,1.0747536,2.2953997,,,,,,,,,,,,,, -31200,1.1259403,2.2318294,,,,,,,,,,,,,, -31300,1.0851415,2.343546,,,,,,,,,,,,,, -31400,1.0801519,2.3286912,,,,,,,,,,,,,, -31471,,,0.342873066663742,3.108630418777466,0.3323999941349029,3.178079843521118,50000.0,0.2486000061035156,3.894976377487183,10000.0,10744.5173869133,11131.000636100767,10744.5173869133,384.77102971076965,0.6250383853912354,0.0 -31500,1.1949558,2.4207091,,,,,,,,,,,,,, -31600,1.2713922,2.4324238,,,,,,,,,,,,,, -31700,1.1720616,2.3193483,,,,,,,,,,,,,, -31800,1.222486,2.3728082,,,,,,,,,,,,,, -31900,0.96865344,2.282852,,,,,,,,,,,,,, -32000,1.1365274,2.391018,,,,,,,,,,,,,, -32100,1.0620472,2.2519307,,,,,,,,,,,,,, -32200,1.206452,2.3559558,,,,,,,,,,,,,, -32300,1.215521,2.286058,,,,,,,,,,,,,, -32400,1.2585018,2.3353045,,,,,,,,,,,,,, -32500,1.1398805,2.2618523,,,,,,,,,,,,,, -32600,1.2111108,2.263267,,,,,,,,,,,,,, -32700,1.224472,2.3241389,,,,,,,,,,,,,, -32800,1.0493637,2.2400157,,,,,,,,,,,,,, -32900,1.1153541,2.3243434,,,,,,,,,,,,,, -32974,,,0.273138552904129,4.065606594085693,0.2622599899768829,4.184144973754883,50000.0,0.1951000094413757,5.094240665435791,10000.0,11254.584214448929,11658.715369939804,11254.584214448929,402.3353538513184,0.6580126285552979,0.0 -33000,1.1065704,2.310762,,,,,,,,,,,,,, -33100,1.0737103,2.166367,,,,,,,,,,,,,, -33200,1.2168489,2.2557516,,,,,,,,,,,,,, -33300,1.1714587,2.329493,,,,,,,,,,,,,, -33400,1.3667824,2.2904687,,,,,,,,,,,,,, -33500,1.1177465,2.3743901,,,,,,,,,,,,,, -33600,1.201615,2.4573793,,,,,,,,,,,,,, -33700,1.2835895,2.3371584,,,,,,,,,,,,,, -33800,1.1096272,2.3137248,,,,,,,,,,,,,, -33900,1.1288291,2.3621602,,,,,,,,,,,,,, -34000,1.1113435,2.3160634,,,,,,,,,,,,,, -34100,1.1063379,2.4057302,,,,,,,,,,,,,, -34200,1.1510197,2.3184626,,,,,,,,,,,,,, -34300,1.0558794,2.2760763,,,,,,,,,,,,,, -34400,1.1195533,2.2180665,,,,,,,,,,,,,, -34477,,,0.1878587305545807,4.851592540740967,0.1707399934530258,5.058218955993652,50000.0,0.1181000024080276,5.938706398010254,10000.0,11764.796533107758,12186.59768295288,11764.796533107758,419.9206705093384,0.6911814212799072,0.0 -34500,1.2349207,2.2939658,,,,,,,,,,,,,, -34600,1.0807679,2.2342641,,,,,,,,,,,,,, -34700,1.2391355,2.4203424,,,,,,,,,,,,,, -34800,1.1398219,2.3958623,,,,,,,,,,,,,, -34900,1.0702415,2.2546706,,,,,,,,,,,,,, -35000,1.1178788,2.2233982,,,,,,,,,,,,,, -35100,1.0851684,2.2394273,,,,,,,,,,,,,, -35200,1.1717918,2.232998,,,,,,,,,,,,,, -35300,1.1320148,2.205995,,,,,,,,,,,,,, -35400,1.164165,2.2094834,,,,,,,,,,,,,, -35500,1.2002851,2.318031,,,,,,,,,,,,,, -35600,1.3631239,2.357198,,,,,,,,,,,,,, -35700,1.1645706,2.274633,,,,,,,,,,,,,, -35800,1.1014584,2.3338668,,,,,,,,,,,,,, -35900,1.0807358,2.284774,,,,,,,,,,,,,, -35980,,,0.2466717064380645,4.081879615783691,0.231019988656044,4.22337007522583,50000.0,0.1688000112771988,4.942447185516357,10000.0,12274.719363689424,12714.083587408066,12274.719363689424,437.3967123031616,0.727304220199585,0.0 -36000,1.2249494,2.3814874,,,,,,,,,,,,,, -36100,1.0725149,2.2955241,,,,,,,,,,,,,, -36200,1.1077067,2.111589,,,,,,,,,,,,,, -36300,1.1769612,2.2155216,,,,,,,,,,,,,, -36400,1.3114071,2.3364844,,,,,,,,,,,,,, -36500,1.3198737,2.306014,,,,,,,,,,,,,, -36600,1.2778319,2.4906013,,,,,,,,,,,,,, -36700,1.2053713,2.2089415,,,,,,,,,,,,,, -36800,1.1391689,2.0756521,,,,,,,,,,,,,, -36900,1.1630255,2.2541947,,,,,,,,,,,,,, -37000,1.1171175,2.315301,,,,,,,,,,,,,, -37100,1.197546,2.2806127,,,,,,,,,,,,,, -37200,1.2846553,2.235505,,,,,,,,,,,,,, -37300,1.0706116,2.30831,,,,,,,,,,,,,, -37400,1.0833935,2.1478512,,,,,,,,,,,,,, -37484,,,0.2384207546710968,4.461237907409668,0.2242599874734878,4.545414447784424,50000.0,0.1565000116825103,5.578945159912109,10000.0,12784.916873216627,13241.621287107468,12784.916873216627,454.6478660106659,0.764479398727417,0.0 -37500,1.1248854,2.3751197,,,,,,,,,,,,,, -37600,1.0760725,2.2092102,,,,,,,,,,,,,, -37700,1.1905719,2.336049,,,,,,,,,,,,,, -37800,1.172954,2.3151233,,,,,,,,,,,,,, -37900,1.238046,2.2899213,,,,,,,,,,,,,, -38000,1.2050289,2.2407458,,,,,,,,,,,,,, -38100,1.297979,2.2705286,,,,,,,,,,,,,, -38200,1.1921521,2.3449879,,,,,,,,,,,,,, -38300,1.2265873,2.2872167,,,,,,,,,,,,,, -38400,1.2802889,2.1712441,,,,,,,,,,,,,, -38500,1.262944,2.259672,,,,,,,,,,,,,, -38600,1.181403,2.2844868,,,,,,,,,,,,,, -38700,1.0884131,2.228138,,,,,,,,,,,,,, -38800,1.1083236,2.2495332,,,,,,,,,,,,,, -38900,1.1632215,2.227063,,,,,,,,,,,,,, -38987,,,0.2852160334587097,3.727014541625977,0.2743600010871887,3.853545427322388,50000.0,0.1946000158786773,4.6907501220703125,10000.0,13295.052711725237,13769.374715805054,13295.052711725237,472.1811301708221,0.797905683517456,0.0 -39000,1.1572146,2.2945237,,,,,,,,,,,,,, -39100,1.1513695,2.1796412,,,,,,,,,,,,,, -39200,1.1885469,2.3200293,,,,,,,,,,,,,, -39300,1.1512922,2.339484,,,,,,,,,,,,,, -39400,1.1189137,2.3493104,,,,,,,,,,,,,, -39500,1.1445036,2.138999,,,,,,,,,,,,,, -39600,1.1528045,2.1942844,,,,,,,,,,,,,, -39700,1.0966742,2.2029757,,,,,,,,,,,,,, -39800,1.1130583,2.2883584,,,,,,,,,,,,,, -39900,1.0966756,2.2089386,,,,,,,,,,,,,, -40000,1.2585673,2.19702,,,,,,,,,,,,,, -40100,1.1933609,2.2145276,,,,,,,,,,,,,, -40200,1.2320465,2.2360456,,,,,,,,,,,,,, -40300,1.2382087,2.287314,,,,,,,,,,,,,, -40400,1.183454,2.2165618,,,,,,,,,,,,,, -40490,,,0.3294602930545807,3.2785985469818115,0.3068799972534179,3.4649877548217773,50000.0,0.2426000088453292,4.134172439575195,10000.0,13804.977420091627,14296.749955415726,13804.977420091627,489.544287443161,0.8351831436157227,0.0 -40500,1.1161667,2.318319,,,,,,,,,,,,,, -40600,1.21668,2.4220972,,,,,,,,,,,,,, -40700,1.1695545,2.2056599,,,,,,,,,,,,,, -40800,1.1337193,2.1706903,,,,,,,,,,,,,, -40900,1.2019702,2.2729416,,,,,,,,,,,,,, -41000,1.2283275,2.0663524,,,,,,,,,,,,,, -41100,1.2344023,2.2575681,,,,,,,,,,,,,, -41200,1.1351503,2.2163863,,,,,,,,,,,,,, -41300,1.2493125,2.322175,,,,,,,,,,,,,, -41400,1.2953583,2.199,,,,,,,,,,,,,, -41500,1.4504156,2.4526618,,,,,,,,,,,,,, -41600,1.1206758,2.3808517,,,,,,,,,,,,,, -41700,1.1996918,2.394478,,,,,,,,,,,,,, -41800,1.0573226,2.200788,,,,,,,,,,,,,, -41900,1.2565634,2.3424425,,,,,,,,,,,,,, -41994,,,0.1082987859845161,6.09940767288208,0.1052399948239326,6.120722770690918,50000.0,0.0749000012874603,6.715982437133789,10000.0,14315.08835530281,14824.222242355348,14315.08835530281,506.81652092933655,0.8721778392791748,0.0 -42000,1.2442521,2.1657767,,,,,,,,,,,,,, -42100,1.1723799,2.1507661,,,,,,,,,,,,,, -42200,1.2330611,2.3820105,,,,,,,,,,,,,, -42300,1.191201,2.237899,,,,,,,,,,,,,, -42400,1.2266121,2.3283334,,,,,,,,,,,,,, -42500,1.2276157,2.2504597,,,,,,,,,,,,,, -42600,1.0627562,2.2107556,,,,,,,,,,,,,, -42700,1.2371806,2.2156658,,,,,,,,,,,,,, -42800,1.2830433,2.246928,,,,,,,,,,,,,, -42900,1.2750455,2.3446107,,,,,,,,,,,,,, -43000,1.3188812,2.2359555,,,,,,,,,,,,,, -43100,1.1561434,2.4259489,,,,,,,,,,,,,, -43200,1.1415246,2.1778028,,,,,,,,,,,,,, -43300,1.0992069,2.2665982,,,,,,,,,,,,,, -43400,1.1955708,2.337473,,,,,,,,,,,,,, -43497,,,0.1314971297979354,4.971019744873047,0.1196599975228309,5.1093292236328125,50000.0,0.0809000059962272,5.751832008361816,10000.0,14825.161313533785,15351.805294513702,14825.161313533785,524.241007566452,0.9064865112304688,0.0 -43500,1.24018,2.1722968,,,,,,,,,,,,,, -43600,1.282495,2.325818,,,,,,,,,,,,,, -43700,1.0717062,2.133052,,,,,,,,,,,,,, -43800,1.2774099,2.2991996,,,,,,,,,,,,,, -43900,1.1922542,2.1215465,,,,,,,,,,,,,, -44000,1.2543882,2.3161924,,,,,,,,,,,,,, -44100,1.2390189,2.2570136,,,,,,,,,,,,,, -44200,1.2972883,2.3860643,,,,,,,,,,,,,, -44300,1.3100935,2.1739495,,,,,,,,,,,,,, -44400,1.1650052,2.184021,,,,,,,,,,,,,, -44500,1.0734046,2.3029668,,,,,,,,,,,,,, -44600,1.2321478,2.2264392,,,,,,,,,,,,,, -44700,1.3545163,2.3262641,,,,,,,,,,,,,, -44800,1.2129924,2.2654586,,,,,,,,,,,,,, -44900,1.2010353,2.1721654,,,,,,,,,,,,,, -44999,,,0.1791892498731613,5.077031135559082,0.1635999977588653,5.281920433044434,50000.0,0.1255000084638595,5.8240647315979,10000.0,15335.105654001236,15879.07846212387,15335.105654001236,541.479898929596,0.9436588287353516,0.0 -45000,1.2577838,2.133626,,,,,,,,,,,,,, -45100,1.2253385,2.1708384,,,,,,,,,,,,,, -45200,1.2731683,2.2602515,,,,,,,,,,,,,, -45300,1.2342645,2.3130221,,,,,,,,,,,,,, -45400,1.2829354,2.2264037,,,,,,,,,,,,,, -45500,1.3858123,2.3561327,,,,,,,,,,,,,, -45600,1.230343,2.2454054,,,,,,,,,,,,,, -45700,1.2083957,2.2169352,,,,,,,,,,,,,, -45800,1.197797,2.2761896,,,,,,,,,,,,,, -45900,1.1564304,2.3646686,,,,,,,,,,,,,, -46000,1.1521664,2.2108552,,,,,,,,,,,,,, -46100,1.2242194,2.2504785,,,,,,,,,,,,,, -46200,1.2544755,2.3106844,,,,,,,,,,,,,, -46300,1.1028289,2.1642132,,,,,,,,,,,,,, -46400,1.1687866,2.158197,,,,,,,,,,,,,, -46500,1.1207889,2.1929493,,,,,,,,,,,,,, -46503,,,0.3250558078289032,3.3820459842681885,0.3132599890232086,3.5227527618408203,50000.0,0.2361000180244445,4.294202327728272,10000.0,15845.155004501345,16406.597977876663,15845.155004501345,558.8642518520355,0.979116916656494,0.0 -46600,1.2575216,2.2485375,,,,,,,,,,,,,, -46700,1.3362181,2.3030577,,,,,,,,,,,,,, -46800,1.1247554,2.2129018,,,,,,,,,,,,,, -46900,1.3776342,2.2263222,,,,,,,,,,,,,, -47000,1.3942437,2.4127965,,,,,,,,,,,,,, -47100,1.292142,2.1742768,,,,,,,,,,,,,, -47200,1.3094063,2.271906,,,,,,,,,,,,,, -47300,1.1896766,2.171377,,,,,,,,,,,,,, -47400,1.3160114,2.3043838,,,,,,,,,,,,,, -47500,1.2494477,2.2964432,,,,,,,,,,,,,, -47600,1.1719707,2.3194199,,,,,,,,,,,,,, -47700,1.1991681,2.29467,,,,,,,,,,,,,, -47800,1.2727414,2.2683995,,,,,,,,,,,,,, -47900,1.2558287,2.2474794,,,,,,,,,,,,,, -48000,1.3017994,2.2668478,,,,,,,,,,,,,, -48007,,,0.3681640625,3.034077405929565,0.3450599908828735,3.1679186820983887,50000.0,0.2668000161647796,3.835937261581421,10000.0,16355.190060853958,16934.3196310997,16355.190060853958,576.4580047130585,1.020794153213501,0.0 -48100,1.3040719,2.2192392,,,,,,,,,,,,,, -48200,1.1467247,2.1917706,,,,,,,,,,,,,, -48300,1.1543573,2.266755,,,,,,,,,,,,,, -48400,1.1327627,2.1505089,,,,,,,,,,,,,, -48500,1.218349,2.3127718,,,,,,,,,,,,,, -48600,1.2433803,2.29246,,,,,,,,,,,,,, -48700,1.1391888,2.1787298,,,,,,,,,,,,,, -48800,1.2022147,2.2879832,,,,,,,,,,,,,, -48900,1.281625,2.19196,,,,,,,,,,,,,, -49000,1.3299942,2.279639,,,,,,,,,,,,,, -49100,1.332459,2.277614,,,,,,,,,,,,,, -49200,1.2030452,2.2634463,,,,,,,,,,,,,, -49300,1.2948649,2.140151,,,,,,,,,,,,,, -49400,1.1728619,2.244391,,,,,,,,,,,,,, -49500,1.1296383,2.2266004,,,,,,,,,,,,,, -49511,,,0.2969347834587097,3.669082164764404,0.2820599973201751,3.842806100845337,50000.0,0.203900009393692,4.691739082336426,10000.0,16865.38143491745,17461.97078728676,16865.38143491745,593.8302228450775,1.0560753345489502,0.0 -49600,1.2344309,2.1591935,,,,,,,,,,,,,, -49700,1.2399503,2.3842077,,,,,,,,,,,,,, -49800,1.1879884,2.127977,,,,,,,,,,,,,, -49900,1.3555677,2.2427409,,,,,,,,,,,,,, -50000,1.118884,2.130292,,,,,,,,,,,,,, -50100,1.2312944,2.3389091,,,,,,,,,,,,,, -50200,1.2678337,2.2744732,,,,,,,,,,,,,, -50300,1.4326783,2.1946013,,,,,,,,,,,,,, -50400,1.1525497,1.9794582,,,,,,,,,,,,,, -50500,1.2265189,2.3130865,,,,,,,,,,,,,, -50600,1.1815252,2.2245805,,,,,,,,,,,,,, -50700,1.2625629,2.2690914,,,,,,,,,,,,,, -50800,1.1881152,2.1407733,,,,,,,,,,,,,, -50900,1.2103627,2.2481077,,,,,,,,,,,,,, -51000,1.2862573,2.165159,,,,,,,,,,,,,, -51014,,,0.1538384854793548,4.814224243164063,0.1469199955463409,4.868548393249512,50000.0,0.1058000028133392,5.460964679718018,10000.0,17375.29843711853,17989.424177646637,17375.29843711853,611.2783124446869,1.094315767288208,0.0 -51100,1.3175454,2.0954213,,,,,,,,,,,,,, -51200,1.1938648,2.2144897,,,,,,,,,,,,,, -51300,1.3169627,2.1887226,,,,,,,,,,,,,, -51400,1.0921662,2.125635,,,,,,,,,,,,,, -51500,1.1402998,2.2898743,,,,,,,,,,,,,, -51600,1.2388716,2.2023153,,,,,,,,,,,,,, -51700,1.2980384,2.3440497,,,,,,,,,,,,,, -51800,1.1959081,2.2679145,,,,,,,,,,,,,, -51900,1.2788517,2.2100768,,,,,,,,,,,,,, -52000,1.1278085,2.0988357,,,,,,,,,,,,,, -52100,1.1808825,2.1961386,,,,,,,,,,,,,, -52200,1.260239,2.3709924,,,,,,,,,,,,,, -52300,1.2250806,2.2003808,,,,,,,,,,,,,, -52400,1.3169287,2.318669,,,,,,,,,,,,,, -52500,1.1996982,2.0852277,,,,,,,,,,,,,, -52519,,,0.4587452113628387,2.4128456115722656,0.3999799787998199,2.8159847259521484,50000.0,0.3015000224113464,3.607529878616333,10000.0,17885.518147945404,18517.496727705,17885.518147945404,629.0414938926697,1.1306431293487549,0.0 -52600,1.1777499,2.326663,,,,,,,,,,,,,, -52700,1.2041129,2.0501409,,,,,,,,,,,,,, -52800,1.3410928,2.1954618,,,,,,,,,,,,,, -52900,1.2484387,2.1507568,,,,,,,,,,,,,, -53000,1.3184645,2.1670086,,,,,,,,,,,,,, -53100,1.2153314,2.213098,,,,,,,,,,,,,, -53200,1.3236226,2.2070754,,,,,,,,,,,,,, -53300,1.2379086,2.3338566,,,,,,,,,,,,,, -53400,1.333633,2.2444594,,,,,,,,,,,,,, -53500,1.2316192,2.1394517,,,,,,,,,,,,,, -53600,1.1973176,2.2162578,,,,,,,,,,,,,, -53700,1.1154274,2.2024164,,,,,,,,,,,,,, -53800,1.3293071,2.2269583,,,,,,,,,,,,,, -53900,1.3682948,2.326486,,,,,,,,,,,,,, -54000,1.1833864,2.2143676,,,,,,,,,,,,,, -54023,,,0.2306481152772903,4.139097213745117,0.2137999981641769,4.321166038513184,50000.0,0.169400006532669,4.901164531707764,10000.0,18395.49843478203,19044.81125664711,18395.49843478203,646.2846746444702,1.1675300598144531,0.0 -54100,1.2731287,2.3493805,,,,,,,,,,,,,, -54200,1.3644873,2.382656,,,,,,,,,,,,,, -54300,1.2661972,2.2140908,,,,,,,,,,,,,, -54400,1.3272964,2.2014904,,,,,,,,,,,,,, -54500,1.302705,2.25251,,,,,,,,,,,,,, -54600,1.2622079,2.2220733,,,,,,,,,,,,,, -54700,1.4186802,2.2508864,,,,,,,,,,,,,, -54800,1.3208374,2.2747543,,,,,,,,,,,,,, -54900,1.1712705,2.2085474,,,,,,,,,,,,,, -55000,1.2000387,2.321428,,,,,,,,,,,,,, -55100,1.2519697,2.2229257,,,,,,,,,,,,,, -55200,1.3248945,2.2248688,,,,,,,,,,,,,, -55300,1.2960075,2.2708514,,,,,,,,,,,,,, -55400,1.3699019,2.1241717,,,,,,,,,,,,,, -55500,1.2266653,2.3031108,,,,,,,,,,,,,, -55527,,,0.2505779564380646,3.95234489440918,0.2338999956846237,4.098372936248779,50000.0,0.172200009226799,4.754857063293457,10000.0,18905.47820210457,19573.068316936493,18905.47820210457,664.4693939685822,1.2078561782836914,0.0 -55600,1.2193613,2.1546655,,,,,,,,,,,,,, -55700,1.4798195,2.3025286,,,,,,,,,,,,,, -55800,1.2726997,2.2330306,,,,,,,,,,,,,, -55900,1.2458235,2.242713,,,,,,,,,,,,,, -56000,1.1602833,2.3188436,,,,,,,,,,,,,, -56100,1.2466983,2.1687012,,,,,,,,,,,,,, -56200,1.3237343,2.258916,,,,,,,,,,,,,, -56300,1.1150576,2.2212455,,,,,,,,,,,,,, -56400,1.2963179,2.1847706,,,,,,,,,,,,,, -56500,1.1438004,2.1448622,,,,,,,,,,,,,, -56600,1.2941921,2.0642254,,,,,,,,,,,,,, -56700,1.3085804,2.3906286,,,,,,,,,,,,,, -56800,1.2342995,2.195476,,,,,,,,,,,,,, -56900,1.3921657,2.2118478,,,,,,,,,,,,,, -57000,1.2116312,2.2794569,,,,,,,,,,,,,, -57032,,,0.1284478604793548,5.659578323364258,0.1220599934458732,5.704127788543701,50000.0,0.0820000022649765,6.516331672668457,10000.0,19415.634298086166,20100.56913280487,19415.634298086166,681.7302577495575,1.2400367259979248,0.0 -57100,1.2115147,2.0281959,,,,,,,,,,,,,, -57200,1.2667551,2.1265163,,,,,,,,,,,,,, -57300,1.2641772,2.2750807,,,,,,,,,,,,,, -57400,1.302427,2.0667355,,,,,,,,,,,,,, -57500,1.4399986,2.200827,,,,,,,,,,,,,, -57600,1.2810558,2.2056649,,,,,,,,,,,,,, -57700,1.4602189,2.1535873,,,,,,,,,,,,,, -57800,1.4341521,2.2958217,,,,,,,,,,,,,, -57900,1.3605833,2.10095,,,,,,,,,,,,,, -58000,1.2602534,2.1683636,,,,,,,,,,,,,, -58100,1.3228576,2.3118181,,,,,,,,,,,,,, -58200,1.2349824,2.1059725,,,,,,,,,,,,,, -58300,1.2436905,2.138388,,,,,,,,,,,,,, -58400,1.3068428,2.3055668,,,,,,,,,,,,,, -58500,1.1391983,2.0817986,,,,,,,,,,,,,, -58537,,,0.2176538556814193,4.364476680755615,0.2053599953651428,4.498104095458984,50000.0,0.1523000001907348,5.230589866638184,10000.0,19925.873183965683,20628.08365273476,19925.873183965683,698.9143342971802,1.2803306579589844,0.0 -58600,1.2988647,2.2604008,,,,,,,,,,,,,, -58700,1.1845871,2.2839527,,,,,,,,,,,,,, -58800,1.3427463,2.194824,,,,,,,,,,,,,, -58900,1.1938266,2.0450397,,,,,,,,,,,,,, -59000,1.3005813,2.1936114,,,,,,,,,,,,,, -59100,1.4654006,2.1348267,,,,,,,,,,,,,, -59200,1.25152,2.2836497,,,,,,,,,,,,,, -59300,1.1375711,2.0932353,,,,,,,,,,,,,, -59400,1.3960096,2.1255116,,,,,,,,,,,,,, -59500,1.3226901,2.2171335,,,,,,,,,,,,,, -59600,1.3027002,2.2163663,,,,,,,,,,,,,, -59700,1.2489148,2.1494677,,,,,,,,,,,,,, -59800,1.1971096,2.1343265,,,,,,,,,,,,,, -59900,1.1893866,2.1851535,,,,,,,,,,,,,, -60000,1.3888694,2.1764166,,,,,,,,,,,,,, -60041,,,0.1741868555545807,5.3501811027526855,0.1663800030946731,5.571073055267334,50000.0,0.1358000040054321,6.057143211364746,10000.0,20435.8396422863,21155.63431572914,20435.8396422863,716.4034960269928,1.3220326900482178,0.0 -60100,1.4257342,2.2695942,,,,,,,,,,,,,, -60200,1.4808013,2.1636438,,,,,,,,,,,,,, -60300,1.3772027,2.3041372,,,,,,,,,,,,,, -60400,1.3993956,2.0751662,,,,,,,,,,,,,, -60500,1.3291899,2.2042158,,,,,,,,,,,,,, -60600,1.2586032,2.1553519,,,,,,,,,,,,,, -60700,1.445008,2.1356637,,,,,,,,,,,,,, -60800,1.2885355,2.198225,,,,,,,,,,,,,, -60900,1.3533298,2.2344885,,,,,,,,,,,,,, -61000,1.3765887,2.1861234,,,,,,,,,,,,,, -61100,1.2683136,2.1771138,,,,,,,,,,,,,, -61200,1.3870794,2.3283405,,,,,,,,,,,,,, -61300,1.346983,2.201648,,,,,,,,,,,,,, -61400,1.3359907,2.2635908,,,,,,,,,,,,,, -61500,1.2129232,2.0932016,,,,,,,,,,,,,, -61546,,,0.2065329998731613,4.479652404785156,0.1864800006151199,4.748819351196289,50000.0,0.1471000015735626,5.390626907348633,10000.0,20945.94157910347,21683.503759384155,20945.94157910347,734.0819170475006,1.3605809211730957,0.0 -61600,1.3659465,2.0493522,,,,,,,,,,,,,, -61700,1.4634601,2.1860368,,,,,,,,,,,,,, -61800,1.3036268,2.10733,,,,,,,,,,,,,, -61900,1.3003044,2.2803354,,,,,,,,,,,,,, -62000,1.3014746,2.093985,,,,,,,,,,,,,, -62100,1.2933621,2.1814032,,,,,,,,,,,,,, -62200,1.2733299,2.1619244,,,,,,,,,,,,,, -62300,1.3281138,2.089189,,,,,,,,,,,,,, -62400,1.242789,2.1576471,,,,,,,,,,,,,, -62500,1.268932,2.2777615,,,,,,,,,,,,,, -62600,1.4092785,2.10822,,,,,,,,,,,,,, -62700,1.2703664,2.2835093,,,,,,,,,,,,,, -62800,1.4014101,2.2075658,,,,,,,,,,,,,, -62900,1.4034356,2.1776354,,,,,,,,,,,,,, -63000,1.2953956,2.156268,,,,,,,,,,,,,, -63052,,,0.3102877736091614,3.508485078811645,0.2889399826526642,3.657221794128418,50000.0,0.2154000103473663,4.478518962860107,10000.0,21456.14753627777,22211.436054468155,21456.14753627777,751.711775302887,1.405496597290039,0.0 -63100,1.2829351,2.224906,,,,,,,,,,,,,, -63200,1.4525836,2.1881084,,,,,,,,,,,,,, -63300,1.3471569,2.3501427,,,,,,,,,,,,,, -63400,1.3131871,2.253471,,,,,,,,,,,,,, -63500,1.2340038,2.2655725,,,,,,,,,,,,,, -63600,1.2960854,2.1466393,,,,,,,,,,,,,, -63700,1.4147272,2.2357836,,,,,,,,,,,,,, -63800,1.2826939,2.105387,,,,,,,,,,,,,, -63900,1.2087046,2.1140347,,,,,,,,,,,,,, -64000,1.3458474,2.1398659,,,,,,,,,,,,,, -64100,1.2475613,2.1444352,,,,,,,,,,,,,, -64200,1.2650266,2.178992,,,,,,,,,,,,,, -64300,1.3587561,2.2002873,,,,,,,,,,,,,, -64400,1.3895019,2.1735606,,,,,,,,,,,,,, -64500,1.3580643,2.1526954,,,,,,,,,,,,,, -64556,,,0.3029137253761291,3.726241111755371,0.2832199931144714,3.931478977203369,50000.0,0.2107000052928924,4.761738300323486,10000.0,21966.07893872261,22738.717492341995,21966.07893872261,768.9775350093842,1.437954664230347,0.0 -64600,1.2428458,2.1295116,,,,,,,,,,,,,, -64700,1.3081689,2.143387,,,,,,,,,,,,,, -64800,1.3306686,2.1719558,,,,,,,,,,,,,, -64900,1.2432212,2.0736928,,,,,,,,,,,,,, -65000,1.236559,2.132492,,,,,,,,,,,,,, -65100,1.3192363,2.1772363,,,,,,,,,,,,,, -65200,1.3383038,2.1057754,,,,,,,,,,,,,, -65300,1.3129363,2.1383886,,,,,,,,,,,,,, -65400,1.3288386,2.156682,,,,,,,,,,,,,, -65500,1.303685,2.1402152,,,,,,,,,,,,,, -65600,1.5434836,2.1670184,,,,,,,,,,,,,, -65700,1.3357608,2.21787,,,,,,,,,,,,,, -65800,1.4145304,1.9873614,,,,,,,,,,,,,, -65900,1.2467517,2.2713037,,,,,,,,,,,,,, -66000,1.4924372,2.2657344,,,,,,,,,,,,,, -66061,,,0.343849629163742,3.137306928634644,0.3238599896430969,3.2995736598968506,50000.0,0.2424000054597854,4.018667221069336,10000.0,22476.09085536003,23266.313386917114,22476.09085536003,786.4678730964661,1.4794397354125977,0.0 -66100,1.2748978,2.170893,,,,,,,,,,,,,, -66200,1.2960246,2.196037,,,,,,,,,,,,,, -66300,1.2661139,2.1762493,,,,,,,,,,,,,, -66400,1.2453718,2.2095835,,,,,,,,,,,,,, -66500,1.3440882,2.1574569,,,,,,,,,,,,,, -66600,1.3566232,2.0812001,,,,,,,,,,,,,, -66700,1.4275198,2.152829,,,,,,,,,,,,,, -66800,1.438049,2.1129673,,,,,,,,,,,,,, -66900,1.312834,2.1759453,,,,,,,,,,,,,, -67000,1.2853442,2.169687,,,,,,,,,,,,,, -67100,1.3675331,2.31862,,,,,,,,,,,,,, -67200,1.3003459,2.0469356,,,,,,,,,,,,,, -67300,1.2810291,2.0999367,,,,,,,,,,,,,, -67400,1.3916361,2.2964103,,,,,,,,,,,,,, -67500,1.4685838,2.0325396,,,,,,,,,,,,,, -67566,,,0.32421875,3.3596620559692383,0.3198599815368652,3.3958446979522705,50000.0,0.2331000119447708,4.2524847984313965,10000.0,22986.0070104599,23793.81243133545,22986.0070104599,803.9587597846985,1.5198431015014648,0.0 -67600,1.4777807,2.2503314,,,,,,,,,,,,,, -67700,1.5272461,2.234913,,,,,,,,,,,,,, -67800,1.3362902,2.1504488,,,,,,,,,,,,,, -67900,1.2517312,1.987716,,,,,,,,,,,,,, -68000,1.3544497,2.190773,,,,,,,,,,,,,, -68100,1.2714448,2.2104423,,,,,,,,,,,,,, -68200,1.2896221,2.184524,,,,,,,,,,,,,, -68300,1.3519081,2.1994872,,,,,,,,,,,,,, -68400,1.3929356,2.1680522,,,,,,,,,,,,,, -68500,1.3668271,2.131671,,,,,,,,,,,,,, -68600,1.3806101,2.0287583,,,,,,,,,,,,,, -68700,1.2950484,2.0619042,,,,,,,,,,,,,, -68800,1.3243964,2.186273,,,,,,,,,,,,,, -68900,1.2819908,2.0735717,,,,,,,,,,,,,, -69000,1.32679,2.1641226,,,,,,,,,,,,,, -69071,,,0.3446069657802582,3.2224009037017822,0.3281199932098388,3.301442623138428,50000.0,0.2605000138282776,3.976634979248047,10000.0,23496.11438536644,24321.4155292511,23496.11438536644,821.3490543365479,1.5737159252166748,0.0 -69100,1.3641791,2.1447823,,,,,,,,,,,,,, -69200,1.4710588,2.1920478,,,,,,,,,,,,,, -69300,1.3672476,2.1356664,,,,,,,,,,,,,, -69400,1.3294797,2.087462,,,,,,,,,,,,,, -69500,1.4871615,2.2466044,,,,,,,,,,,,,, -69600,1.2947453,2.1165457,,,,,,,,,,,,,, -69700,1.4194683,2.1159303,,,,,,,,,,,,,, -69800,1.4940231,2.0744696,,,,,,,,,,,,,, -69900,1.3185079,2.1586118,,,,,,,,,,,,,, -70000,1.4779495,2.1716144,,,,,,,,,,,,,, -70100,1.4843067,2.2606711,,,,,,,,,,,,,, -70200,1.4414139,2.175967,,,,,,,,,,,,,, -70300,1.3073672,2.119194,,,,,,,,,,,,,, -70400,1.3371928,2.1542315,,,,,,,,,,,,,, -70500,1.4418508,2.086827,,,,,,,,,,,,,, -70576,,,0.3404615819454193,3.1158764362335205,0.3131199777126312,3.278167486190796,50000.0,0.2357000112533569,3.938102960586548,10000.0,24006.213018655777,24849.170696020126,24006.213018655777,838.9119355678558,1.6166942119598389,0.0 -70600,1.4514174,2.1550899,,,,,,,,,,,,,, -70700,1.5636315,2.278412,,,,,,,,,,,,,, -70800,1.4035807,2.2509253,,,,,,,,,,,,,, -70900,1.4153267,2.04912,,,,,,,,,,,,,, -71000,1.3878798,2.1873796,,,,,,,,,,,,,, -71100,1.3101082,2.0938878,,,,,,,,,,,,,, -71200,1.4260678,2.0568843,,,,,,,,,,,,,, -71300,1.365428,2.144974,,,,,,,,,,,,,, -71400,1.3816938,2.1024365,,,,,,,,,,,,,, -71500,1.380247,2.2120667,,,,,,,,,,,,,, -71600,1.44254,2.0325012,,,,,,,,,,,,,, -71700,1.3399051,2.220854,,,,,,,,,,,,,, -71800,1.3272632,2.0889826,,,,,,,,,,,,,, -71900,1.4082617,1.9939336,,,,,,,,,,,,,, -72000,1.4578576,2.1557682,,,,,,,,,,,,,, -72081,,,0.4693478941917419,2.3059208393096924,0.4316200017929077,2.5475010871887207,50000.0,0.3187000155448913,3.359588146209717,10000.0,24516.36786627769,25376.83913421631,24516.36786627769,856.3283641338348,1.6619768142700195,0.0 -72100,1.3760297,2.0766788,,,,,,,,,,,,,, -72200,1.3290749,2.0678525,,,,,,,,,,,,,, -72300,1.452183,2.2060275,,,,,,,,,,,,,, -72400,1.6896813,2.2334683,,,,,,,,,,,,,, -72500,1.4376738,2.1911325,,,,,,,,,,,,,, -72600,1.2832249,2.0505764,,,,,,,,,,,,,, -72700,1.4453437,2.1999576,,,,,,,,,,,,,, -72800,1.2907277,2.0811727,,,,,,,,,,,,,, -72900,1.4481301,2.121933,,,,,,,,,,,,,, -73000,1.3335218,2.08592,,,,,,,,,,,,,, -73100,1.615399,2.193481,,,,,,,,,,,,,, -73200,1.3811692,2.1558933,,,,,,,,,,,,,, -73300,1.3799107,2.1855786,,,,,,,,,,,,,, -73400,1.3856363,2.0150418,,,,,,,,,,,,,, -73500,1.7408918,2.2764304,,,,,,,,,,,,,, -73586,,,0.5002591013908386,2.164903163909912,0.45933997631073,2.4115843772888184,50000.0,0.3611000180244446,3.1056861877441406,10000.0,25026.566119670868,25904.495133399963,25026.566119670868,873.6862845420837,1.7103638648986816,0.0 -73600,1.3071624,1.9440684,,,,,,,,,,,,,, -73700,1.4604071,2.1315997,,,,,,,,,,,,,, -73800,1.3312231,2.0736763,,,,,,,,,,,,,, -73900,1.4740683,2.1443315,,,,,,,,,,,,,, -74000,1.4045138,2.0788226,,,,,,,,,,,,,, -74100,1.4053038,1.9955099,,,,,,,,,,,,,, -74200,1.4610082,2.054802,,,,,,,,,,,,,, -74300,1.5005069,2.1872797,,,,,,,,,,,,,, -74400,1.4970095,2.0241475,,,,,,,,,,,,,, -74500,1.4140079,2.1253734,,,,,,,,,,,,,, -74600,1.4529091,2.1677737,,,,,,,,,,,,,, -74700,1.4539297,2.1160216,,,,,,,,,,,,,, -74800,1.3954939,2.252457,,,,,,,,,,,,,, -74900,1.5073376,2.079699,,,,,,,,,,,,,, -75000,1.560669,2.0191388,,,,,,,,,,,,,, -75091,,,0.3360172212123871,3.280001163482666,0.3144199848175049,3.4575531482696533,50000.0,0.2302000075578689,4.258174419403076,10000.0,25536.524069309235,26432.089807510376,25536.524069309235,891.2308826446533,1.750760555267334,0.0 -75100,1.3569576,2.113364,,,,,,,,,,,,,, -75200,1.4241365,2.0915852,,,,,,,,,,,,,, -75300,1.4477254,2.1053534,,,,,,,,,,,,,, -75400,1.4365543,2.0397234,,,,,,,,,,,,,, -75500,1.5134017,2.1575527,,,,,,,,,,,,,, -75600,1.4691385,2.2155366,,,,,,,,,,,,,, -75700,1.4355452,2.1476984,,,,,,,,,,,,,, -75800,1.6415638,2.1254773,,,,,,,,,,,,,, -75900,1.4302197,2.0944014,,,,,,,,,,,,,, -76000,1.2879714,2.0999706,,,,,,,,,,,,,, -76100,1.4174992,2.083129,,,,,,,,,,,,,, -76200,1.4376245,2.0429666,,,,,,,,,,,,,, -76300,1.3534037,2.0212212,,,,,,,,,,,,,, -76400,1.5032961,2.1584895,,,,,,,,,,,,,, -76500,1.3996238,2.03102,,,,,,,,,,,,,, -76595,,,0.3689811825752258,3.100395917892456,0.3448599874973297,3.260609865188598,50000.0,0.2645000219345093,4.023133277893066,10000.0,26046.467567443848,26959.668189525604,26046.467567443848,908.7737927436827,1.7912750244140625,0.0 -76600,1.6763159,2.1839705,,,,,,,,,,,,,, -76700,1.4548671,2.1621168,,,,,,,,,,,,,, -76800,1.5079645,2.068864,,,,,,,,,,,,,, -76900,1.4465624,2.1255713,,,,,,,,,,,,,, -77000,1.2958835,2.0784721,,,,,,,,,,,,,, -77100,1.394527,2.1170428,,,,,,,,,,,,,, -77200,1.5207895,2.2419884,,,,,,,,,,,,,, -77300,1.5963155,2.127694,,,,,,,,,,,,,, -77400,1.4408495,2.0737414,,,,,,,,,,,,,, -77500,1.3679457,2.0559592,,,,,,,,,,,,,, -77600,1.3349736,2.068304,,,,,,,,,,,,,, -77700,1.3683802,2.1969469,,,,,,,,,,,,,, -77800,1.4418467,2.1674898,,,,,,,,,,,,,, -77900,1.5934912,2.1682687,,,,,,,,,,,,,, -78000,1.4093257,2.1243863,,,,,,,,,,,,,, -78100,,,0.2097417116165161,4.448638916015625,0.1947399973869323,4.624502658843994,50000.0,0.1477000117301941,5.23762321472168,10000.0,26556.542511701584,27487.187576293945,26556.542511701584,926.1273169517516,1.8315112590789795,0.0 -78100,1.621904,2.1215308,,,,,,,,,,,,,, -78200,1.36564,2.0010593,,,,,,,,,,,,,, -78300,1.3702315,2.0655236,,,,,,,,,,,,,, -78400,1.4534736,2.140831,,,,,,,,,,,,,, -78500,1.4367706,2.1108282,,,,,,,,,,,,,, -78600,1.5252067,2.064927,,,,,,,,,,,,,, -78700,1.2685621,1.8647319,,,,,,,,,,,,,, -78800,1.5381488,2.1070793,,,,,,,,,,,,,, -78900,1.4452486,2.1086454,,,,,,,,,,,,,, -79000,1.4984449,2.1614435,,,,,,,,,,,,,, -79100,1.4708333,2.1091323,,,,,,,,,,,,,, -79200,1.5013089,2.2554064,,,,,,,,,,,,,, -79300,1.5464549,2.1506228,,,,,,,,,,,,,, -79400,1.3699182,1.9834108,,,,,,,,,,,,,, -79500,1.3312694,2.0896075,,,,,,,,,,,,,, -79600,1.319499,2.0006154,,,,,,,,,,,,,, -79604,,,0.2213209420442581,4.309572219848633,0.207639992237091,4.423527240753174,50000.0,0.1488000005483627,5.156861782073975,10000.0,27066.480364322662,28014.59451031685,27066.480364322662,943.505437374115,1.871402740478516,0.0 -79700,1.4353719,1.9544587,,,,,,,,,,,,,, -79800,1.6651179,2.0572612,,,,,,,,,,,,,, -79900,1.4817982,2.0676336,,,,,,,,,,,,,, -80000,1.4237137,2.0535781,,,,,,,,,,,,,, -80100,1.346065,2.2093728,,,,,,,,,,,,,, -80200,1.5485106,2.0603452,,,,,,,,,,,,,, -80300,1.3732877,2.1126637,,,,,,,,,,,,,, -80400,1.4750307,2.0242343,,,,,,,,,,,,,, -80500,1.4021256,2.0662165,,,,,,,,,,,,,, -80600,1.5239234,2.1037364,,,,,,,,,,,,,, -80700,1.4321667,1.9919629,,,,,,,,,,,,,, -80800,1.4746175,1.9927149,,,,,,,,,,,,,, -80900,1.5199491,2.1099799,,,,,,,,,,,,,, -81000,1.463188,2.053903,,,,,,,,,,,,,, -81100,1.50006,1.9904294,,,,,,,,,,,,,, -81109,,,0.3338049948215484,3.2794435024261475,0.3070800006389618,3.5193474292755127,50000.0,0.2300000041723251,4.20773458480835,10000.0,27576.63488388061,28542.41470336914,27576.63488388061,961.0810222625732,1.9099252223968504,0.0 -81200,1.5306255,1.916691,,,,,,,,,,,,,, -81300,1.4855037,2.143062,,,,,,,,,,,,,, -81400,1.4492986,2.2193823,,,,,,,,,,,,,, -81500,1.4508207,2.1839094,,,,,,,,,,,,,, -81600,1.7964988,2.0934644,,,,,,,,,,,,,, -81700,1.5985278,2.0186424,,,,,,,,,,,,,, -81800,1.5510163,2.153679,,,,,,,,,,,,,, -81900,1.6340985,2.1519518,,,,,,,,,,,,,, -82000,1.4252454,2.0334918,,,,,,,,,,,,,, -82100,1.4398575,2.012053,,,,,,,,,,,,,, -82200,1.482932,2.0690343,,,,,,,,,,,,,, -82300,1.7179399,2.0484037,,,,,,,,,,,,,, -82400,1.4178969,2.0549893,,,,,,,,,,,,,, -82500,1.5619065,2.1570885,,,,,,,,,,,,,, -82600,1.6650599,2.030561,,,,,,,,,,,,,, -82614,,,0.395228773355484,2.817749261856079,0.367499977350235,3.013368606567383,50000.0,0.2820000052452087,3.757661819458008,10000.0,28086.69702768325,29070.151161670685,28086.69702768325,978.6555554866792,1.955564022064209,0.0 -82700,1.4813313,2.091229,,,,,,,,,,,,,, -82800,1.6483896,1.9430717,,,,,,,,,,,,,, -82900,1.7076654,2.170251,,,,,,,,,,,,,, -83000,1.5391121,2.050579,,,,,,,,,,,,,, -83100,1.4491351,2.0902653,,,,,,,,,,,,,, -83200,1.37502,1.9608864,,,,,,,,,,,,,, -83300,1.6758826,2.1543744,,,,,,,,,,,,,, -83400,1.550592,2.044552,,,,,,,,,,,,,, -83500,1.4771868,2.1082356,,,,,,,,,,,,,, -83600,1.5581381,2.09807,,,,,,,,,,,,,, -83700,1.6552074,2.2272744,,,,,,,,,,,,,, -83800,1.3962506,2.1433854,,,,,,,,,,,,,, -83900,1.3036649,2.0975087,,,,,,,,,,,,,, -84000,1.4511997,2.1468964,,,,,,,,,,,,,, -84100,1.4365927,2.1418877,,,,,,,,,,,,,, -84119,,,0.3988759517669678,2.8061952590942383,0.3729199767112732,3.0000438690185547,50000.0,0.2729000151157379,3.788983106613159,10000.0,28596.871851682663,29597.85049009323,28596.871851682663,996.0850801467896,1.9987788200378416,0.0 -84200,1.3853145,2.1189237,,,,,,,,,,,,,, -84300,1.5414324,2.066516,,,,,,,,,,,,,, -84400,1.6103494,2.1708245,,,,,,,,,,,,,, -84500,1.544378,2.055171,,,,,,,,,,,,,, -84600,1.4722872,2.0807781,,,,,,,,,,,,,, -84700,1.6243243,2.1461709,,,,,,,,,,,,,, -84800,1.5960068,2.0298483,,,,,,,,,,,,,, -84900,1.6086001,2.0896063,,,,,,,,,,,,,, -85000,1.584464,2.081896,,,,,,,,,,,,,, -85100,1.6083292,2.1409075,,,,,,,,,,,,,, -85200,1.5162058,2.0658922,,,,,,,,,,,,,, -85300,1.4806803,2.06322,,,,,,,,,,,,,, -85400,1.6175742,1.9566245,,,,,,,,,,,,,, -85500,1.4898857,1.9504642,,,,,,,,,,,,,, -85600,1.4613069,1.9402136,,,,,,,,,,,,,, -85624,,,0.4944395720958709,2.248067140579224,0.4607999920845032,2.458147525787353,50000.0,0.3519000113010406,3.282273292541504,10000.0,29107.02249646187,30125.56484889984,29107.02249646187,1013.550271511078,2.044531345367432,0.0 -85700,1.4901508,2.0956907,,,,,,,,,,,,,, -85800,1.5698658,2.0080757,,,,,,,,,,,,,, -85900,1.5589343,2.0507278,,,,,,,,,,,,,, -86000,1.5859299,2.0382786,,,,,,,,,,,,,, -86100,1.5391799,2.1409945,,,,,,,,,,,,,, -86200,1.4934751,2.10345,,,,,,,,,,,,,, -86300,1.5547044,1.9382243,,,,,,,,,,,,,, -86400,1.4348832,2.0590858,,,,,,,,,,,,,, -86500,1.5762874,2.0399966,,,,,,,,,,,,,, -86600,1.6960118,2.0632598,,,,,,,,,,,,,, -86700,1.4272748,2.1219761,,,,,,,,,,,,,, -86800,1.4698384,2.0408041,,,,,,,,,,,,,, -86900,1.534901,2.0914593,,,,,,,,,,,,,, -87000,1.3817002,2.024828,,,,,,,,,,,,,, -87100,1.6540056,1.9840875,,,,,,,,,,,,,, -87129,,,0.1268136203289032,5.291484355926514,0.1218999996781349,5.354971885681152,50000.0,0.0892000049352645,5.882399559020996,10000.0,29617.0174241066,30653.171503305435,29617.0174241066,1031.0666897296906,2.0885136127471924,0.0 -87200,1.4538447,2.0317662,,,,,,,,,,,,,, -87300,1.510815,2.029675,,,,,,,,,,,,,, -87400,1.4461298,2.1063228,,,,,,,,,,,,,, -87500,1.4714932,2.0156887,,,,,,,,,,,,,, -87600,1.4392577,2.0365038,,,,,,,,,,,,,, -87700,1.595113,2.135272,,,,,,,,,,,,,, -87800,1.5597692,1.9729654,,,,,,,,,,,,,, -87900,1.5148977,1.9545716,,,,,,,,,,,,,, -88000,1.5534234,1.9607909,,,,,,,,,,,,,, -88100,1.6251597,1.9633065,,,,,,,,,,,,,, -88200,1.6142442,2.101646,,,,,,,,,,,,,, -88300,1.933518,2.0445685,,,,,,,,,,,,,, -88400,1.5522841,2.0646572,,,,,,,,,,,,,, -88500,1.6547507,2.147449,,,,,,,,,,,,,, -88600,1.5456232,2.0280695,,,,,,,,,,,,,, -88633,,,0.4404296875,2.522449254989624,0.4178799986839294,2.665285348892212,50000.0,0.3120000064373016,3.441679000854492,10000.0,30126.98759460449,31180.547289133072,30126.98759460449,1048.3781082630155,2.131108283996582,0.0 -88700,1.5337797,2.0560422,,,,,,,,,,,,,, -88800,1.4746528,1.9044225,,,,,,,,,,,,,, -88900,1.4835955,2.0326712,,,,,,,,,,,,,, -89000,1.5390357,1.9970042,,,,,,,,,,,,,, -89100,1.4907004,1.9117692,,,,,,,,,,,,,, -89200,1.5836214,2.0092907,,,,,,,,,,,,,, -89300,1.5992002,1.9834245,,,,,,,,,,,,,, -89400,1.5981075,1.9298736,,,,,,,,,,,,,, -89500,1.6301461,1.9393845,,,,,,,,,,,,,, -89600,1.624147,2.0657341,,,,,,,,,,,,,, -89700,1.5630772,1.9050488,,,,,,,,,,,,,, -89800,1.5134443,1.9964654,,,,,,,,,,,,,, -89900,1.6187121,2.177436,,,,,,,,,,,,,, -90000,1.5303402,1.9456941,,,,,,,,,,,,,, -90100,1.6730025,2.1340256,,,,,,,,,,,,,, -90138,,,0.5174585580825806,2.07250714302063,0.4711000025272369,2.372112512588501,50000.0,0.3660000264644623,3.175899744033813,10000.0,30637.00756168365,31708.12091970444,30637.00756168365,1065.8279626369476,2.1842033863067627,0.0 -90200,1.6650698,2.0386395,,,,,,,,,,,,,, -90300,1.6370865,2.09439,,,,,,,,,,,,,, -90400,1.8765336,1.9810351,,,,,,,,,,,,,, -90500,1.52086,1.9867861,,,,,,,,,,,,,, -90600,1.6252416,1.9979419,,,,,,,,,,,,,, -90700,1.5441957,2.127158,,,,,,,,,,,,,, -90800,1.6005534,2.0545719,,,,,,,,,,,,,, -90900,1.5856769,1.943577,,,,,,,,,,,,,, -91000,1.672625,2.1591265,,,,,,,,,,,,,, -91100,1.5059636,1.9935039,,,,,,,,,,,,,, -91200,1.5827355,2.0034173,,,,,,,,,,,,,, -91300,1.570838,1.9472364,,,,,,,,,,,,,, -91400,1.537514,1.9606206,,,,,,,,,,,,,, -91500,1.423578,1.8779962,,,,,,,,,,,,,, -91600,1.7334137,2.040473,,,,,,,,,,,,,, -91643,,,0.4838368892669678,2.2562386989593506,0.4466199874877929,2.4910802841186523,50000.0,0.3335000276565552,3.2831525802612305,10000.0,31147.1129257679,32236.46764421463,31147.1129257679,1083.9716200828552,2.2286715507507324,0.0 -91700,1.6014671,1.9691877,,,,,,,,,,,,,, -91800,1.558619,2.078886,,,,,,,,,,,,,, -91900,1.7697686,2.0727615,,,,,,,,,,,,,, -92000,1.5404906,2.0368447,,,,,,,,,,,,,, -92100,1.6401805,2.2319288,,,,,,,,,,,,,, -92200,1.544001,2.0084555,,,,,,,,,,,,,, -92300,1.6386702,1.9028629,,,,,,,,,,,,,, -92400,1.7856141,1.949917,,,,,,,,,,,,,, -92500,1.4398907,2.0257113,,,,,,,,,,,,,, -92600,1.5489532,1.9912463,,,,,,,,,,,,,, -92700,1.5854824,1.9113088,,,,,,,,,,,,,, -92800,1.5406575,2.0524056,,,,,,,,,,,,,, -92900,1.6123575,1.9903992,,,,,,,,,,,,,, -93000,1.6994791,1.9317963,,,,,,,,,,,,,, -93100,1.6228113,2.0223446,,,,,,,,,,,,,, -93148,,,0.5083904266357422,2.1743996143341064,0.4763799905776977,2.3738784790039062,50000.0,0.3620000183582306,3.1680712699890137,10000.0,31657.05765509605,32763.91767191887,31657.05765509605,1101.381390094757,2.27234148979187,0.0 -93200,1.588282,1.9437923,,,,,,,,,,,,,, -93300,1.5846773,2.0136244,,,,,,,,,,,,,, -93400,1.686805,1.8824118,,,,,,,,,,,,,, -93500,1.7431053,2.0114074,,,,,,,,,,,,,, -93600,1.5155941,1.9499013,,,,,,,,,,,,,, -93700,1.7605616,1.9598576,,,,,,,,,,,,,, -93800,1.6506,2.0980182,,,,,,,,,,,,,, -93900,1.7342597,1.9629041,,,,,,,,,,,,,, -94000,1.5181376,2.0132766,,,,,,,,,,,,,, -94100,1.6207657,1.989807,,,,,,,,,,,,,, -94200,1.5317484,2.0478482,,,,,,,,,,,,,, -94300,1.5462956,2.0138953,,,,,,,,,,,,,, -94400,1.7094553,1.9206791,,,,,,,,,,,,,, -94500,1.6490222,2.0634675,,,,,,,,,,,,,, -94600,1.7106501,2.056633,,,,,,,,,,,,,, -94650,,,0.545918345451355,1.944828987121582,0.5107399821281433,2.1304807662963867,50000.0,0.3964000046253204,2.8958516120910645,10000.0,32167.045583724976,33291.27138733864,32167.045583724976,1118.651871919632,2.31717848777771,0.0 -94700,1.6717192,1.970155,,,,,,,,,,,,,, -94800,1.477162,1.8974129,,,,,,,,,,,,,, -94900,1.8428737,2.0774183,,,,,,,,,,,,,, -95000,1.4970102,1.9639093,,,,,,,,,,,,,, -95100,1.6306857,1.9334584,,,,,,,,,,,,,, -95200,1.6817044,1.8981583,,,,,,,,,,,,,, -95300,1.5885212,1.9059595,,,,,,,,,,,,,, -95400,1.6289803,1.9915237,,,,,,,,,,,,,, -95500,1.7692246,2.154437,,,,,,,,,,,,,, -95600,1.668185,2.0042775,,,,,,,,,,,,,, -95700,1.6533056,2.0077987,,,,,,,,,,,,,, -95800,1.6451433,1.9625759,,,,,,,,,,,,,, -95900,1.5734766,2.0902975,,,,,,,,,,,,,, -96000,1.7247764,2.0823054,,,,,,,,,,,,,, -96100,1.5595344,1.8545836,,,,,,,,,,,,,, -96154,,,0.3833506107330322,2.867633581161499,0.359959989786148,3.0436670780181885,50000.0,0.2696000039577484,3.772510766983032,10000.0,32676.99285697937,33818.811498880386,32676.99285697937,1136.1447837352753,2.3657150268554688,0.0 -96200,1.7533132,2.0568614,,,,,,,,,,,,,, -96300,2.1477246,2.0230827,,,,,,,,,,,,,, -96400,1.48781,1.9008065,,,,,,,,,,,,,, -96500,1.7580063,1.9636933,,,,,,,,,,,,,, -96600,1.6910329,1.9972632,,,,,,,,,,,,,, -96700,1.6407614,1.9572866,,,,,,,,,,,,,, -96800,1.6047847,1.8754692,,,,,,,,,,,,,, -96900,1.766899,2.0656853,,,,,,,,,,,,,, -97000,1.5658017,1.9549901,,,,,,,,,,,,,, -97100,1.6333562,1.970022,,,,,,,,,,,,,, -97200,1.9178386,2.083649,,,,,,,,,,,,,, -97300,1.7741578,1.931596,,,,,,,,,,,,,, -97400,1.7722573,1.970256,,,,,,,,,,,,,, -97500,1.8420948,1.9386947,,,,,,,,,,,,,, -97600,1.8319014,1.9386936,,,,,,,,,,,,,, -97659,,,0.5059789419174194,2.157892942428589,0.475160002708435,2.326932668685913,50000.0,0.3591000139713287,3.0999302864074707,10000.0,33187.0912899971,34346.33107614517,33187.0912899971,1153.468049287796,2.4127352237701416,0.0 -97700,1.9064142,1.9346817,,,,,,,,,,,,,, -97800,1.6815234,2.0868068,,,,,,,,,,,,,, -97900,1.6430885,1.9022831,,,,,,,,,,,,,, -98000,1.8760325,2.0151203,,,,,,,,,,,,,, -98100,1.773602,1.9310721,,,,,,,,,,,,,, -98200,1.7482984,1.8823695,,,,,,,,,,,,,, -98300,1.7071612,1.9453707,,,,,,,,,,,,,, -98400,1.6825011,2.0165167,,,,,,,,,,,,,, -98500,1.6496177,1.9316267,,,,,,,,,,,,,, -98600,1.7887855,2.0112,,,,,,,,,,,,,, -98700,1.7321105,2.0575063,,,,,,,,,,,,,, -98800,1.8228011,1.9580077,,,,,,,,,,,,,, -98900,1.7219278,1.9386088,,,,,,,,,,,,,, -99000,1.6042148,1.962557,,,,,,,,,,,,,, -99100,1.6021073,1.8335232,,,,,,,,,,,,,, -99164,,,0.5677016973495483,1.840865135192871,0.5112400054931641,2.122875928878784,50000.0,0.3920000195503235,2.879778385162353,10000.0,33697.07434248924,34873.93353009224,33697.07434248924,1170.9882607460022,2.460010766983032,0.0 -99200,1.6637818,2.0356796,,,,,,,,,,,,,, -99300,1.908334,2.035703,,,,,,,,,,,,,, -99400,1.7887444,1.8904352,,,,,,,,,,,,,, -99500,1.626993,1.9509652,,,,,,,,,,,,,, -99600,1.7245777,1.9559026,,,,,,,,,,,,,, -99700,1.5975627,1.8850608,,,,,,,,,,,,,, -99800,1.7423315,1.808742,,,,,,,,,,,,,, -99900,1.6315162,2.1049519,,,,,,,,,,,,,, -100000,1.7990198,1.9183453,,,,,,,,,,,,,, -100100,1.7800796,2.0113113,,,,,,,,,,,,,, -100200,1.553664,1.7821146,,,,,,,,,,,,,, -100300,1.7351733,2.0667057,,,,,,,,,,,,,, -100400,1.8452687,1.9169867,,,,,,,,,,,,,, -100500,1.7866347,1.9813579,,,,,,,,,,,,,, -100600,1.8359114,1.8861945,,,,,,,,,,,,,, -100669,,,0.4465082883834839,2.48357892036438,0.4145599901676178,2.690641403198242,50000.0,0.3048000037670135,3.538769245147705,10000.0,34207.1025724411,35401.42611813545,34207.1025724411,1188.3502779006958,2.5082507133483887,0.0 -100700,1.844404,2.0789976,,,,,,,,,,,,,, -100800,1.8247596,1.7984408,,,,,,,,,,,,,, -100900,1.8483917,2.0435114,,,,,,,,,,,,,, -101000,1.8083336,2.1848965,,,,,,,,,,,,,, -101100,1.7277908,1.8448409,,,,,,,,,,,,,, -101200,1.879303,1.9618106,,,,,,,,,,,,,, -101300,1.7444391,1.8989248,,,,,,,,,,,,,, -101400,1.854729,1.9922597,,,,,,,,,,,,,, -101500,1.6124511,1.9150056,,,,,,,,,,,,,, -101600,1.8580072,1.9200556,,,,,,,,,,,,,, -101700,1.9880669,1.8455575,,,,,,,,,,,,,, -101800,1.7048773,1.9129667,,,,,,,,,,,,,, -101900,1.8009394,2.1116173,,,,,,,,,,,,,, -102000,1.8267965,1.9600859,,,,,,,,,,,,,, -102100,1.7987864,1.9621619,,,,,,,,,,,,,, -102174,,,0.5240353941917419,2.0571160316467285,0.4918999969959259,2.25252103805542,50000.0,0.3846000134944916,2.99954605102539,10000.0,34717.1273932457,35928.986558914185,34717.1273932457,1205.7857220172882,2.556874513626098,0.0 -102200,1.9028052,1.8767711,,,,,,,,,,,,,, -102300,1.8214159,1.7714661,,,,,,,,,,,,,, -102400,1.8418663,1.9279656,,,,,,,,,,,,,, -102500,1.7008494,1.9557395,,,,,,,,,,,,,, -102600,1.6009053,1.9697696,,,,,,,,,,,,,, -102700,1.6798018,1.8017485,,,,,,,,,,,,,, -102800,1.724907,1.8554397,,,,,,,,,,,,,, -102900,1.8894469,1.9963019,,,,,,,,,,,,,, -103000,1.7601076,2.0514083,,,,,,,,,,,,,, -103100,1.74859,1.8800387,,,,,,,,,,,,,, -103200,1.9404665,1.963623,,,,,,,,,,,,,, -103300,1.8892957,2.0011835,,,,,,,,,,,,,, -103400,1.71408,1.9001032,,,,,,,,,,,,,, -103500,1.7637889,1.9823247,,,,,,,,,,,,,, -103600,1.8852599,2.0907784,,,,,,,,,,,,,, -103679,,,0.5104631781578064,2.159630060195923,0.471919983625412,2.3893167972564697,50000.0,0.3644000291824341,3.1780288219451904,10000.0,35227.254885435104,36456.76866769791,35227.254885435104,1223.3412280082705,2.6039724349975586,0.0 -103700,1.6905814,1.8724384,,,,,,,,,,,,,, -103800,1.8692831,1.8911095,,,,,,,,,,,,,, -103900,1.832006,1.8620881,,,,,,,,,,,,,, -104000,1.913376,1.9979513,,,,,,,,,,,,,, -104100,1.7809132,1.9859039,,,,,,,,,,,,,, -104200,1.70787,1.8719969,,,,,,,,,,,,,, -104300,1.7688226,1.866747,,,,,,,,,,,,,, -104400,1.6894629,1.8818852,,,,,,,,,,,,,, -104500,1.999119,1.9679384,,,,,,,,,,,,,, -104600,1.7315149,1.9415487,,,,,,,,,,,,,, -104700,2.045936,1.9852731,,,,,,,,,,,,,, -104800,1.7464226,1.9732068,,,,,,,,,,,,,, -104900,1.7583681,1.8388374,,,,,,,,,,,,,, -105000,1.7203553,1.8589171,,,,,,,,,,,,,, -105100,1.7515469,1.8818045,,,,,,,,,,,,,, -105184,,,0.5213249325752258,2.0907678604125977,0.4876799881458282,2.284471035003662,50000.0,0.388700008392334,2.9926905632019043,10000.0,35737.33686733246,36984.278197050095,35737.33686733246,1240.666074514389,2.65444564819336,0.0 -105200,1.6650481,1.7762716,,,,,,,,,,,,,, -105300,1.7042665,1.9321034,,,,,,,,,,,,,, -105400,1.8238374,2.0050488,,,,,,,,,,,,,, -105500,1.9060738,1.9366094,,,,,,,,,,,,,, -105600,1.8997691,1.8856735,,,,,,,,,,,,,, -105700,1.9318558,2.0873232,,,,,,,,,,,,,, -105800,1.844223,1.9436874,,,,,,,,,,,,,, -105900,1.8164198,1.8544275,,,,,,,,,,,,,, -106000,1.8434073,1.8972749,,,,,,,,,,,,,, -106100,1.8315309,1.9138191,,,,,,,,,,,,,, -106200,1.7692581,2.0010474,,,,,,,,,,,,,, -106300,1.891807,1.9344218,,,,,,,,,,,,,, -106400,1.9102172,1.9913876,,,,,,,,,,,,,, -106500,1.9650055,1.9267422,,,,,,,,,,,,,, -106600,1.7893108,1.9927943,,,,,,,,,,,,,, -106689,,,0.5049425959587097,2.1846439838409424,0.4718999862670898,2.3848423957824707,50000.0,0.3607000112533569,3.1824216842651367,10000.0,36247.34074640274,37511.78411793709,36247.34074640274,1258.0668041706083,2.7028579711914062,0.0 -106700,2.0342176,1.8540875,,,,,,,,,,,,,, -106800,2.0186696,1.9138005,,,,,,,,,,,,,, -106900,2.0213702,1.8672262,,,,,,,,,,,,,, -107000,1.750536,1.9858447,,,,,,,,,,,,,, -107100,1.832806,1.9000762,,,,,,,,,,,,,, -107200,1.8547117,1.83189,,,,,,,,,,,,,, -107300,1.8155148,1.9560027,,,,,,,,,,,,,, -107400,1.8100913,1.8626274,,,,,,,,,,,,,, -107500,1.7251972,1.846716,,,,,,,,,,,,,, -107600,1.9596663,2.0960798,,,,,,,,,,,,,, -107700,1.8979294,1.7992635,,,,,,,,,,,,,, -107800,1.794255,1.8690422,,,,,,,,,,,,,, -107900,1.957543,1.8942536,,,,,,,,,,,,,, -108000,1.7438164,2.0125115,,,,,,,,,,,,,, -108100,2.0459344,1.8454674,,,,,,,,,,,,,, -108194,,,0.4869260191917419,2.2414798736572266,0.4351199865341186,2.565911769866944,50000.0,0.3131000101566314,3.456578731536865,10000.0,36757.33936190605,38039.08936858177,36757.33936190605,1275.2838730812073,2.7409493923187256,0.0 -108200,1.9261216,1.8981801,,,,,,,,,,,,,, -108300,1.9088502,1.8924814,,,,,,,,,,,,,, -108400,2.259809,2.0107052,,,,,,,,,,,,,, -108500,1.9353482,2.0102997,,,,,,,,,,,,,, -108600,1.9675475,1.8735915,,,,,,,,,,,,,, -108700,1.8291957,1.899207,,,,,,,,,,,,,, -108800,1.951681,1.8548073,,,,,,,,,,,,,, -108900,2.323732,1.9918882,,,,,,,,,,,,,, -109000,1.7639384,1.8512377,,,,,,,,,,,,,, -109100,1.8260831,1.8768498,,,,,,,,,,,,,, -109200,2.0560272,1.758173,,,,,,,,,,,,,, -109300,1.9490767,1.8261281,,,,,,,,,,,,,, -109400,2.108025,1.8653964,,,,,,,,,,,,,, -109500,1.8810775,1.9925514,,,,,,,,,,,,,, -109600,1.9317286,1.9373627,,,,,,,,,,,,,, -109699,,,0.5515983700752258,1.900874376296997,0.5029199719429016,2.1730642318725586,50000.0,0.399800032377243,2.946348190307617,10000.0,37267.33650159836,38566.54166126251,37267.33650159836,1292.6387770175934,2.7895431518554688,0.0 -109700,1.8642972,1.8974612,,,,,,,,,,,,,, -109800,2.0331824,1.8548512,,,,,,,,,,,,,, -109900,1.8401899,1.9067943,,,,,,,,,,,,,, -110000,2.07743,1.9327604,,,,,,,,,,,,,, -110100,1.7690756,1.896735,,,,,,,,,,,,,, -110200,2.0864356,1.9292709,,,,,,,,,,,,,, -110300,1.9026809,1.9734739,,,,,,,,,,,,,, -110400,1.919236,1.9880795,,,,,,,,,,,,,, -110500,1.9286113,1.8792984,,,,,,,,,,,,,, -110600,1.9091284,1.8320466,,,,,,,,,,,,,, -110700,2.0164754,1.923931,,,,,,,,,,,,,, -110800,1.8699193,1.8623278,,,,,,,,,,,,,, -110900,2.0817158,1.911613,,,,,,,,,,,,,, -111000,2.0458267,2.0021908,,,,,,,,,,,,,, -111100,1.8779464,1.850872,,,,,,,,,,,,,, -111200,2.0549057,1.8326677,,,,,,,,,,,,,, -111203,,,0.5053411722183228,2.157813310623169,0.4675799906253814,2.3963406085968018,50000.0,0.3591000139713287,3.1881215572357178,10000.0,37777.31170344353,39094.16153287888,37777.31170344353,1310.1845960617063,2.837157964706421,0.0 -111300,1.8991251,2.0268404,,,,,,,,,,,,,, -111400,1.87055,1.846668,,,,,,,,,,,,,, -111500,1.9620291,1.8537891,,,,,,,,,,,,,, -111600,1.937999,1.9213765,,,,,,,,,,,,,, -111700,1.8896056,1.904934,,,,,,,,,,,,,, -111800,1.8676221,1.8930683,,,,,,,,,,,,,, -111900,1.9519377,1.7484021,,,,,,,,,,,,,, -112000,1.8468243,1.827837,,,,,,,,,,,,,, -112100,1.9056534,1.8241534,,,,,,,,,,,,,, -112200,1.7823846,1.7681569,,,,,,,,,,,,,, -112300,2.1106515,1.8284922,,,,,,,,,,,,,, -112400,1.9190487,1.8937469,,,,,,,,,,,,,, -112500,1.8937815,1.8777235,,,,,,,,,,,,,, -112600,1.8703974,1.8969398,,,,,,,,,,,,,, -112700,1.9914061,1.8384919,,,,,,,,,,,,,, -112708,,,0.482800543308258,2.318225383758545,0.4461199939250946,2.545124292373657,50000.0,0.3433000147342682,3.334402084350586,10000.0,38287.42334794998,39621.92883038521,38287.42334794998,1327.739814043045,2.886847972869873,0.0 -112800,1.9167417,1.8312126,,,,,,,,,,,,,, -112900,1.9922793,1.8006549,,,,,,,,,,,,,, -113000,1.9964767,1.8254743,,,,,,,,,,,,,, -113100,2.0810897,1.8639446,,,,,,,,,,,,,, -113200,2.0598633,1.7325462,,,,,,,,,,,,,, -113300,1.9554837,1.7545086,,,,,,,,,,,,,, -113400,1.989064,1.8690683,,,,,,,,,,,,,, -113500,2.0017025,1.9318469,,,,,,,,,,,,,, -113600,2.1312692,1.8354092,,,,,,,,,,,,,, -113700,2.0447454,1.7348359,,,,,,,,,,,,,, -113800,2.0143864,1.864622,,,,,,,,,,,,,, -113900,2.149751,1.8632861,,,,,,,,,,,,,, -114000,2.1199634,1.9661452,,,,,,,,,,,,,, -114100,2.1188097,1.8424749,,,,,,,,,,,,,, -114200,2.0682323,1.7740841,,,,,,,,,,,,,, -114214,,,0.5449019074440002,1.9557873010635376,0.5061599612236023,2.1710429191589355,50000.0,0.3836000263690948,3.029585599899292,10000.0,38797.589587688446,40149.82967543602,38797.589587688446,1345.3750817775726,2.9340431690216064,0.0 -114300,1.8764813,1.8158131,,,,,,,,,,,,,, -114400,1.9982511,1.815676,,,,,,,,,,,,,, -114500,2.0552595,1.9300985,,,,,,,,,,,,,, -114600,2.01031,1.9868829,,,,,,,,,,,,,, -114700,1.8852208,1.933507,,,,,,,,,,,,,, -114800,1.9375212,1.8870116,,,,,,,,,,,,,, -114900,2.0031462,1.891867,,,,,,,,,,,,,, -115000,1.9858376,1.9261553,,,,,,,,,,,,,, -115100,1.9359328,1.7299407,,,,,,,,,,,,,, -115200,2.0007608,1.7839414,,,,,,,,,,,,,, -115300,1.9934802,1.8150518,,,,,,,,,,,,,, -115400,2.0309253,1.8079792,,,,,,,,,,,,,, -115500,2.0491648,1.7889211,,,,,,,,,,,,,, -115600,2.1128526,1.9111834,,,,,,,,,,,,,, -115700,2.3356423,1.8126363,,,,,,,,,,,,,, -115718,,,0.5570591688156128,1.8826926946640008,0.527999997138977,2.093355178833008,50000.0,0.4082000255584717,2.8958003520965576,10000.0,39307.54098248482,40677.041610240936,39307.54098248482,1362.5382542610168,2.980867624282837,0.0 -115800,2.1110938,1.8629168,,,,,,,,,,,,,, -115900,2.096613,1.7695366,,,,,,,,,,,,,, -116000,2.2595098,1.7888613,,,,,,,,,,,,,, -116100,1.9372913,1.8371779,,,,,,,,,,,,,, -116200,2.1905272,1.7776779,,,,,,,,,,,,,, -116300,2.0636983,1.822922,,,,,,,,,,,,,, -116400,2.1280293,1.7922025,,,,,,,,,,,,,, -116500,2.0564926,1.7261128,,,,,,,,,,,,,, -116600,2.3283975,1.8842645,,,,,,,,,,,,,, -116700,2.0047429,1.8136178,,,,,,,,,,,,,, -116800,2.2600594,1.8568783,,,,,,,,,,,,,, -116900,2.1410491,1.8722876,,,,,,,,,,,,,, -117000,2.0603867,1.7658677,,,,,,,,,,,,,, -117100,2.0276318,1.7395167,,,,,,,,,,,,,, -117200,2.3746917,1.7955714,,,,,,,,,,,,,, -117223,,,0.6294443607330322,1.476343035697937,0.5562199950218201,1.8933178186416624,50000.0,0.4335000216960907,2.682591438293457,10000.0,39817.586441755295,41204.382727622986,39817.586441755295,1379.7358441352844,3.027287483215332,0.0 -117300,2.186888,1.6758652,,,,,,,,,,,,,, -117400,2.00918,1.8048437,,,,,,,,,,,,,, -117500,2.1569703,1.8168634,,,,,,,,,,,,,, -117600,2.1292393,1.8650405,,,,,,,,,,,,,, -117700,2.1622503,1.8561823,,,,,,,,,,,,,, -117800,2.1613696,1.8054397,,,,,,,,,,,,,, -117900,2.1239936,1.7956426,,,,,,,,,,,,,, -118000,2.0720901,1.720968,,,,,,,,,,,,,, -118100,2.3004377,1.7462769,,,,,,,,,,,,,, -118200,2.099629,1.8109822,,,,,,,,,,,,,, -118300,2.245631,1.9110813,,,,,,,,,,,,,, -118400,2.0030713,1.7976456,,,,,,,,,,,,,, -118500,1.9743907,1.7173114,,,,,,,,,,,,,, -118600,2.1731858,1.8917881,,,,,,,,,,,,,, -118700,2.2307336,1.7979794,,,,,,,,,,,,,, -118728,,,0.4579878747463226,2.500563621520996,0.4254599809646606,2.721994638442993,50000.0,0.3305000066757202,3.4887149333953857,10000.0,40327.71361851692,41731.77513575554,40327.71361851692,1396.9002561569214,3.074965238571167,0.0 -118800,2.2675757,1.8960323,,,,,,,,,,,,,, -118900,2.134521,1.8672531,,,,,,,,,,,,,, -119000,2.18658,1.7559432,,,,,,,,,,,,,, -119100,2.2822156,1.8354951,,,,,,,,,,,,,, -119200,2.2657995,1.7806172,,,,,,,,,,,,,, -119300,2.1381845,1.8940504,,,,,,,,,,,,,, -119400,2.1477847,1.8865992,,,,,,,,,,,,,, -119500,2.0976856,1.8253995,,,,,,,,,,,,,, -119600,2.200083,1.6940356,,,,,,,,,,,,,, -119700,2.0274885,1.7114818,,,,,,,,,,,,,, -119800,2.2783844,1.9056368,,,,,,,,,,,,,, -119900,2.2263978,1.8368348,,,,,,,,,,,,,, -120000,2.1144257,1.7599264,,,,,,,,,,,,,, -120100,1.9406781,1.7287934,,,,,,,,,,,,,, -120200,2.1560173,1.7616253,,,,,,,,,,,,,, -120233,,,0.6004663705825806,1.6669267416000366,0.5487599968910217,1.962204933166504,50000.0,0.4288000166416168,2.7302029132843018,10000.0,40837.7944047451,42259.27208185196,40837.7944047451,1414.213954925537,3.124530076980591,0.0 -120300,2.1013253,1.7719543,,,,,,,,,,,,,, -120400,2.1757984,1.9056276,,,,,,,,,,,,,, -120500,2.1634336,1.9192727,,,,,,,,,,,,,, -120600,2.280079,1.9009968,,,,,,,,,,,,,, -120700,2.248928,1.7781181,,,,,,,,,,,,,, -120800,2.2527695,1.7637607,,,,,,,,,,,,,, -120900,2.3297598,1.7710334,,,,,,,,,,,,,, -121000,2.2189043,1.706445,,,,,,,,,,,,,, -121100,2.2955828,1.7069277,,,,,,,,,,,,,, -121200,2.177266,1.7250193,,,,,,,,,,,,,, -121300,2.0361025,1.691228,,,,,,,,,,,,,, -121400,2.1481557,1.885884,,,,,,,,,,,,,, -121500,2.080047,1.736372,,,,,,,,,,,,,, -121600,2.0930989,1.7454253,,,,,,,,,,,,,, -121700,2.0810585,1.8397074,,,,,,,,,,,,,, -121738,,,0.5869738459587097,1.7295794486999512,0.5414199829101562,1.9991382360458367,50000.0,0.4227000176906585,2.80234169960022,10000.0,41347.86622738838,42786.84964418411,41347.86622738838,1431.6174721717834,3.175598382949829,0.0 -121800,2.237994,1.9037565,,,,,,,,,,,,,, -121900,2.2083516,1.8775815,,,,,,,,,,,,,, -122000,2.1787357,1.7418088,,,,,,,,,,,,,, -122100,2.0615814,1.6834348,,,,,,,,,,,,,, -122200,2.1607425,1.701302,,,,,,,,,,,,,, -122300,2.089214,1.8014836,,,,,,,,,,,,,, -122400,2.258236,1.72854,,,,,,,,,,,,,, -122500,2.281314,1.7915981,,,,,,,,,,,,,, -122600,2.267304,1.7851034,,,,,,,,,,,,,, -122700,2.3309448,1.8235009,,,,,,,,,,,,,, -122800,2.5166304,1.7197497,,,,,,,,,,,,,, -122900,2.3311043,1.7982378,,,,,,,,,,,,,, -123000,2.1780174,1.8366498,,,,,,,,,,,,,, -123100,2.290427,1.8222895,,,,,,,,,,,,,, -123200,2.354673,1.7739507,,,,,,,,,,,,,, -123244,,,0.5721460580825806,1.803208351135254,0.5319399833679199,2.029165267944336,50000.0,0.3977000117301941,2.858058214187622,10000.0,41858.08864212036,43314.71329832077,41858.08864212036,1449.155464887619,3.227324724197388,0.0 -123300,2.1755612,1.6140887,,,,,,,,,,,,,, -123400,2.181605,1.7318828,,,,,,,,,,,,,, -123500,2.2230995,1.7305875,,,,,,,,,,,,,, -123600,2.2080932,1.8135977,,,,,,,,,,,,,, -123700,2.4088395,1.8260235,,,,,,,,,,,,,, -123800,2.4123876,1.8093305,,,,,,,,,,,,,, -123900,2.3417096,1.8084165,,,,,,,,,,,,,, -124000,2.2867212,1.7314405,,,,,,,,,,,,,, -124100,2.132416,1.7478282,,,,,,,,,,,,,, -124200,2.1099572,1.6903968,,,,,,,,,,,,,, -124300,2.2197623,1.7387787,,,,,,,,,,,,,, -124400,2.5402257,1.7328601,,,,,,,,,,,,,, -124500,2.3728395,1.7024233,,,,,,,,,,,,,, -124600,2.3785453,1.7015221,,,,,,,,,,,,,, -124700,2.2757936,1.7229168,,,,,,,,,,,,,, -124749,,,0.6416613459587097,1.4549739360809326,0.5970799922943115,1.6888469457626345,50000.0,0.4696000218391418,2.445818901062012,10000.0,42368.04260277748,43842.06088280678,42368.04260277748,1466.4475784301758,3.278285026550293,0.0 -124800,2.2698398,1.656458,,,,,,,,,,,,,, -124900,2.2556026,1.7012208,,,,,,,,,,,,,, -125000,2.4421556,1.6677423,,,,,,,,,,,,,, -125100,2.1198587,1.6814445,,,,,,,,,,,,,, -125200,2.1820452,1.7921046,,,,,,,,,,,,,, -125300,2.4493506,1.6895254,,,,,,,,,,,,,, -125400,2.4772098,1.8056169,,,,,,,,,,,,,, -125500,2.4025755,1.8481077,,,,,,,,,,,,,, -125600,2.3555794,1.7906593,,,,,,,,,,,,,, -125700,2.2069283,1.6123611,,,,,,,,,,,,,, -125800,2.4799676,1.8520645,,,,,,,,,,,,,, -125900,2.3947773,1.7586746,,,,,,,,,,,,,, -126000,2.5298622,1.7695636,,,,,,,,,,,,,, -126100,2.37895,1.7633207,,,,,,,,,,,,,, -126200,2.3049812,1.749925,,,,,,,,,,,,,, -126254,,,0.6292649507522583,1.4908218383789062,0.5594800114631653,1.9337760210037231,50000.0,0.4337000250816345,2.768241167068481,10000.0,42877.98620200157,44369.39072751999,42877.98620200157,1483.7346332073212,3.327343463897705,0.0 -126300,2.3891826,1.8532639,,,,,,,,,,,,,, -126400,2.4925914,1.7445741,,,,,,,,,,,,,, -126500,2.2212002,1.5411825,,,,,,,,,,,,,, -126600,2.2833283,1.7710178,,,,,,,,,,,,,, -126700,2.5202792,1.7452794,,,,,,,,,,,,,, -126800,2.4398618,1.7488676,,,,,,,,,,,,,, -126900,2.4451306,1.8668287,,,,,,,,,,,,,, -127000,2.5071223,1.7143755,,,,,,,,,,,,,, -127100,2.3961508,1.8009768,,,,,,,,,,,,,, -127200,2.3903089,1.6933867,,,,,,,,,,,,,, -127300,2.2222652,1.6387793,,,,,,,,,,,,,, -127400,2.4094403,1.6848547,,,,,,,,,,,,,, -127500,2.3044364,1.5676212,,,,,,,,,,,,,, -127600,2.2816947,1.7172494,,,,,,,,,,,,,, -127700,2.547079,1.8085997,,,,,,,,,,,,,, -127759,,,0.6521045565605164,1.3957295417785645,0.5881999731063843,1.7238211631774902,50000.0,0.4629000127315521,2.50555157661438,10000.0,43387.98578906059,44897.10552716255,43387.98578906059,1501.349282026291,3.3759353160858154,0.0 -127800,2.401377,1.8439232,,,,,,,,,,,,,, -127900,2.6213737,1.6542327,,,,,,,,,,,,,, -128000,2.4858766,1.82402,,,,,,,,,,,,,, -128100,2.288494,1.7445574,,,,,,,,,,,,,, -128200,2.3599484,1.6526589,,,,,,,,,,,,,, -128300,2.5900307,1.790329,,,,,,,,,,,,,, -128400,2.4561675,1.5709355,,,,,,,,,,,,,, -128500,2.7459888,1.7303061,,,,,,,,,,,,,, -128600,2.6518342,1.7145088,,,,,,,,,,,,,, -128700,2.307547,1.6382099,,,,,,,,,,,,,, -128800,2.3471463,1.713515,,,,,,,,,,,,,, -128900,2.5814936,1.6343625,,,,,,,,,,,,,, -129000,2.409475,1.6845677,,,,,,,,,,,,,, -129100,2.330737,1.7885728,,,,,,,,,,,,,, -129200,2.4416883,1.6914921,,,,,,,,,,,,,, -129265,,,0.6463648080825806,1.4183142185211182,0.5917400121688843,1.7195110321044922,50000.0,0.4722000360488891,2.4591338634490967,10000.0,43897.913105010986,45424.64857959747,43897.913105010986,1518.861918926239,3.42822790145874,0.0 -129300,2.4360816,1.722411,,,,,,,,,,,,,, -129400,2.4889112,1.8337629,,,,,,,,,,,,,, -129500,2.3036597,1.6216948,,,,,,,,,,,,,, -129600,2.5233839,1.869072,,,,,,,,,,,,,, -129700,2.695049,1.5882921,,,,,,,,,,,,,, -129800,2.2471914,1.7405382,,,,,,,,,,,,,, -129900,2.4148018,1.5860322,,,,,,,,,,,,,, -130000,2.650554,1.7159725,,,,,,,,,,,,,, -130100,2.6096623,1.7529111,,,,,,,,,,,,,, -130200,2.613967,1.5990384,,,,,,,,,,,,,, -130300,2.5112267,1.6054811,,,,,,,,,,,,,, -130400,2.5503528,1.6842228,,,,,,,,,,,,,, -130500,2.6387095,1.6373464,,,,,,,,,,,,,, -130600,2.4275627,1.6417015,,,,,,,,,,,,,, -130700,2.4993446,1.6763546,,,,,,,,,,,,,, -130771,,,0.6597377061843872,1.3550430536270142,0.6078000068664551,1.6432515382766724,50000.0,0.4835000336170196,2.3731515407562256,10000.0,44408.10325551033,45952.20838069916,44408.10325551033,1536.1293251514437,3.4792776107788086,0.0 -130800,2.324779,1.7084078,,,,,,,,,,,,,, -130900,2.5091922,1.7109749,,,,,,,,,,,,,, -131000,2.6924226,1.6241865,,,,,,,,,,,,,, -131100,2.70298,1.6097963,,,,,,,,,,,,,, -131200,2.5916872,1.6188551,,,,,,,,,,,,,, -131300,2.5934398,1.6940086,,,,,,,,,,,,,, -131400,2.552373,1.623318,,,,,,,,,,,,,, -131500,2.6111379,1.6384432,,,,,,,,,,,,,, -131600,2.5052853,1.6094688,,,,,,,,,,,,,, -131700,3.0806115,1.8182007,,,,,,,,,,,,,, -131800,2.5887039,1.6286323,,,,,,,,,,,,,, -131900,2.601875,1.6433116,,,,,,,,,,,,,, -132000,2.4741418,1.6507754,,,,,,,,,,,,,, -132100,2.5504549,1.6713403,,,,,,,,,,,,,, -132200,2.7909582,1.5642259,,,,,,,,,,,,,, -132277,,,0.5997488498687744,1.664632797241211,0.5585799813270569,1.9178231954574585,50000.0,0.4361000061035156,2.7067646980285645,10000.0,44918.220878601074,46479.58945250511,44918.220878601074,1553.2879321575165,3.5328681468963623,0.0 -132300,2.5573287,1.5525614,,,,,,,,,,,,,, -132400,2.6572123,1.5859625,,,,,,,,,,,,,, -132500,2.8158493,1.7273421,,,,,,,,,,,,,, -132600,2.63919,1.7304229,,,,,,,,,,,,,, -132700,2.97137,1.6823102,,,,,,,,,,,,,, -132800,2.7541845,1.6798167,,,,,,,,,,,,,, -132900,2.6582353,1.6477025,,,,,,,,,,,,,, -133000,2.8136868,1.8094797,,,,,,,,,,,,,, -133100,2.5878146,1.6380467,,,,,,,,,,,,,, -133200,2.8845696,1.7098777,,,,,,,,,,,,,, -133300,2.662949,1.672346,,,,,,,,,,,,,, -133400,2.7207794,1.5455797,,,,,,,,,,,,,, -133500,2.59636,1.5973865,,,,,,,,,,,,,, -133600,2.4715338,1.6348883,,,,,,,,,,,,,, -133700,2.5268433,1.6886816,,,,,,,,,,,,,, -133783,,,0.623445451259613,1.5336153507232666,0.577239990234375,1.8043313026428225,50000.0,0.4651000201702118,2.53909683227539,10000.0,45428.21768307686,47007.13358902931,45428.21768307686,1570.7316064834597,3.585947036743164,0.0 -133800,2.5577366,1.502279,,,,,,,,,,,,,, -133900,2.809816,1.808946,,,,,,,,,,,,,, -134000,2.7363315,1.6201372,,,,,,,,,,,,,, -134100,2.9447417,1.7652165,,,,,,,,,,,,,, -134200,2.5162084,1.5642341,,,,,,,,,,,,,, -134300,2.7430236,1.6264282,,,,,,,,,,,,,, -134400,2.7310884,1.6055372,,,,,,,,,,,,,, -134500,2.767063,1.659725,,,,,,,,,,,,,, -134600,2.5532691,1.6152217,,,,,,,,,,,,,, -134700,2.710516,1.630976,,,,,,,,,,,,,, -134800,2.6556995,1.5857844,,,,,,,,,,,,,, -134900,2.6207135,1.5785706,,,,,,,,,,,,,, -135000,2.7318668,1.6176381,,,,,,,,,,,,,, -135100,2.7341664,1.520682,,,,,,,,,,,,,, -135200,2.4964921,1.5316387,,,,,,,,,,,,,, -135289,,,0.6740872263908386,1.2957236766815186,0.5938000082969666,1.7162760496139526,50000.0,0.4607000350952148,2.485707998275757,10000.0,45938.33998990059,47534.735048532486,45938.33998990059,1588.1084327697754,3.637380599975586,0.0 -135300,3.096556,1.6466877,,,,,,,,,,,,,, -135400,2.6423547,1.5954404,,,,,,,,,,,,,, -135500,2.6725879,1.5734711,,,,,,,,,,,,,, -135600,2.8831913,1.5352168,,,,,,,,,,,,,, -135700,2.6990345,1.560799,,,,,,,,,,,,,, -135800,2.7406175,1.5949098,,,,,,,,,,,,,, -135900,2.8359385,1.6364229,,,,,,,,,,,,,, -136000,2.5536013,1.5096141,,,,,,,,,,,,,, -136100,2.6722155,1.6504169,,,,,,,,,,,,,, -136200,2.8961809,1.5966585,,,,,,,,,,,,,, -136300,2.676584,1.5476577,,,,,,,,,,,,,, -136400,2.6365263,1.6353693,,,,,,,,,,,,,, -136500,2.700429,1.5624887,,,,,,,,,,,,,, -136600,2.7585473,1.6377724,,,,,,,,,,,,,, -136700,2.7913978,1.5996729,,,,,,,,,,,,,, -136794,,,0.6825972199440002,1.2317051887512207,0.6187599897384644,1.594221830368042,50000.0,0.497700035572052,2.348714590072632,10000.0,46448.27879500389,48062.19457030296,46448.27879500389,1605.5273563861847,3.687337875366211,0.0 -136800,2.72512,1.5754946,,,,,,,,,,,,,, -136900,3.166474,1.6360309,,,,,,,,,,,,,, -137000,2.6056206,1.584058,,,,,,,,,,,,,, -137100,2.7976682,1.615546,,,,,,,,,,,,,, -137200,2.8137968,1.5465609,,,,,,,,,,,,,, -137300,2.8449821,1.6046576,,,,,,,,,,,,,, -137400,2.8720877,1.572765,,,,,,,,,,,,,, -137500,2.8724358,1.557663,,,,,,,,,,,,,, -137600,2.8705275,1.5730817,,,,,,,,,,,,,, -137700,2.8480217,1.7556995,,,,,,,,,,,,,, -137800,2.9061527,1.802882,,,,,,,,,,,,,, -137900,2.7914765,1.7249218,,,,,,,,,,,,,, -138000,2.9539099,1.6088297,,,,,,,,,,,,,, -138100,2.8823762,1.5878832,,,,,,,,,,,,,, -138200,2.6684763,1.5807841,,,,,,,,,,,,,, -138300,,,0.6966079473495483,1.193784475326538,0.6342799663543701,1.5047619342803955,50000.0,0.5135000348091125,2.21742844581604,10000.0,46958.36615371704,48589.68688797951,46958.36615371704,1622.827439069748,3.740651845932007,0.0 -138300,3.0336354,1.5799667,,,,,,,,,,,,,, -138400,2.8480618,1.6901298,,,,,,,,,,,,,, -138500,2.9209285,1.6618181,,,,,,,,,,,,,, -138600,2.556507,1.5291233,,,,,,,,,,,,,, -138700,2.6129358,1.529176,,,,,,,,,,,,,, -138800,2.9857535,1.6083223,,,,,,,,,,,,,, -138900,2.992632,1.6901293,,,,,,,,,,,,,, -139000,2.9267797,1.5201031,,,,,,,,,,,,,, -139100,2.9054074,1.6391519,,,,,,,,,,,,,, -139200,2.7770476,1.5271871,,,,,,,,,,,,,, -139300,3.0728283,1.6606048,,,,,,,,,,,,,, -139400,2.9641438,1.6996114,,,,,,,,,,,,,, -139500,2.840308,1.5805871,,,,,,,,,,,,,, -139600,2.7732856,1.57299,,,,,,,,,,,,,, -139700,3.0322473,1.5695803,,,,,,,,,,,,,, -139800,2.8918467,1.4795067,,,,,,,,,,,,,, -139806,,,0.6668726205825806,1.3343795537948608,0.6074599623680115,1.647443413734436,50000.0,0.4682000279426574,2.447305917739868,10000.0,47468.524106025696,49117.42114567757,47468.524106025696,1640.2983448505402,3.79308032989502,0.0 -139900,2.9854126,1.5435379,,,,,,,,,,,,,, -140000,3.2645006,1.5262518,,,,,,,,,,,,,, -140100,2.9730918,1.495843,,,,,,,,,,,,,, -140200,2.8524168,1.5079705,,,,,,,,,,,,,, -140300,2.9915798,1.5368255,,,,,,,,,,,,,, -140400,2.8564634,1.5061178,,,,,,,,,,,,,, -140500,3.115635,1.6406405,,,,,,,,,,,,,, -140600,2.9631903,1.5670037,,,,,,,,,,,,,, -140700,3.02983,1.5241983,,,,,,,,,,,,,, -140800,3.0352495,1.5641452,,,,,,,,,,,,,, -140900,2.8007867,1.5305961,,,,,,,,,,,,,, -141000,3.1089113,1.516589,,,,,,,,,,,,,, -141100,3.136099,1.5332687,,,,,,,,,,,,,, -141200,2.825075,1.5245583,,,,,,,,,,,,,, -141300,2.9102328,1.6011708,,,,,,,,,,,,,, -141312,,,0.7024075388908386,1.1639950275421145,0.6421200037002563,1.4964680671691897,50000.0,0.5249000191688538,2.221546173095703,10000.0,47978.61784863472,49644.76369333267,47978.61784863472,1657.447308063507,3.841919660568237,0.0 -141400,3.1183457,1.5405223,,,,,,,,,,,,,, -141500,3.1660514,1.5474463,,,,,,,,,,,,,, -141600,3.042119,1.569298,,,,,,,,,,,,,, -141700,2.9928255,1.4737256,,,,,,,,,,,,,, -141800,3.1962414,1.502337,,,,,,,,,,,,,, -141900,2.9183133,1.542232,,,,,,,,,,,,,, -142000,2.9644628,1.5042307,,,,,,,,,,,,,, -142100,3.262022,1.6738895,,,,,,,,,,,,,, -142200,3.3353703,1.6601813,,,,,,,,,,,,,, -142300,3.2553663,1.5923389,,,,,,,,,,,,,, -142400,3.1739805,1.5752835,,,,,,,,,,,,,, -142500,2.826448,1.5174683,,,,,,,,,,,,,, -142600,3.0482116,1.535871,,,,,,,,,,,,,, -142700,3.1534543,1.53863,,,,,,,,,,,,,, -142800,3.1083236,1.4864609,,,,,,,,,,,,,, -142818,,,0.7105787396430969,1.1328246593475342,0.649399995803833,1.4419379234313965,50000.0,0.5152000188827515,2.194200038909912,10000.0,48488.61668562889,50172.85512781143,48488.61668562889,1675.4333300590515,3.895407199859619,0.0 -142900,3.0454845,1.4372777,,,,,,,,,,,,,, -143000,3.1329677,1.6356481,,,,,,,,,,,,,, -143100,3.0633922,1.5068345,,,,,,,,,,,,,, -143200,3.1430998,1.5783987,,,,,,,,,,,,,, -143300,3.2272606,1.5476792,,,,,,,,,,,,,, -143400,3.0209978,1.3825408,,,,,,,,,,,,,, -143500,3.0724049,1.4422188,,,,,,,,,,,,,, -143600,3.0463066,1.5684141,,,,,,,,,,,,,, -143700,3.1681812,1.5798947,,,,,,,,,,,,,, -143800,3.0872169,1.4881536,,,,,,,,,,,,,, -143900,3.057759,1.4714826,,,,,,,,,,,,,, -144000,3.2606885,1.4848322,,,,,,,,,,,,,, -144100,3.2124362,1.4684033,,,,,,,,,,,,,, -144200,3.1606233,1.4712045,,,,,,,,,,,,,, -144300,3.2400322,1.4608774,,,,,,,,,,,,,, -144323,,,0.749422013759613,0.9607897996902466,0.6571199893951416,1.4050936698913574,50000.0,0.532800018787384,2.105572462081909,10000.0,48998.60055851936,50700.674156188965,48998.60055851936,1693.17382478714,3.937635183334351,0.0 -144400,3.1516962,1.3908343,,,,,,,,,,,,,, -144500,3.1904106,1.5203065,,,,,,,,,,,,,, -144600,3.1882005,1.4767932,,,,,,,,,,,,,, -144700,3.2442172,1.6668905,,,,,,,,,,,,,, -144800,3.5642738,1.5935512,,,,,,,,,,,,,, -144900,3.3602815,1.4249979,,,,,,,,,,,,,, -145000,3.0995955,1.4436101,,,,,,,,,,,,,, -145100,3.3035855,1.4492629,,,,,,,,,,,,,, -145200,3.2868836,1.4921246,,,,,,,,,,,,,, -145300,3.1216145,1.4782217,,,,,,,,,,,,,, -145400,3.5059114,1.4426781,,,,,,,,,,,,,, -145500,3.4254396,1.544168,,,,,,,,,,,,,, -145600,3.196413,1.4553137,,,,,,,,,,,,,, -145700,3.2268496,1.5038401,,,,,,,,,,,,,, -145800,2.9721334,1.4163507,,,,,,,,,,,,,, -145829,,,0.7521324753761292,0.9460116028785706,0.6695399880409241,1.3460698127746582,50000.0,0.5408000349998474,2.057956457138061,10000.0,49508.71371245384,51228.05661511421,49508.71371245384,1710.3399860858915,3.988811254501343,0.0 -145900,3.5325594,1.6543989,,,,,,,,,,,,,, -146000,3.0983856,1.4218472,,,,,,,,,,,,,, -146100,3.2247655,1.5722759,,,,,,,,,,,,,, -146200,3.1187868,1.5095831,,,,,,,,,,,,,, -146300,3.4877079,1.4524002,,,,,,,,,,,,,, -146400,3.322308,1.5797163,,,,,,,,,,,,,, -146500,3.407244,1.4592897,,,,,,,,,,,,,, -146600,3.42441,1.5164891,,,,,,,,,,,,,, -146700,3.4059362,1.5027028,,,,,,,,,,,,,, -146800,3.1441982,1.4574649,,,,,,,,,,,,,, -146900,3.306672,1.4273294,,,,,,,,,,,,,, -147000,3.4377291,1.4951099,,,,,,,,,,,,,, -147100,3.3629882,1.5175376,,,,,,,,,,,,,, -147200,3.3208225,1.422644,,,,,,,,,,,,,, -147300,3.3183913,1.4979424,,,,,,,,,,,,,, -147335,,,0.7348732352256775,1.0262718200683594,0.6617599725723267,1.3817018270492554,50000.0,0.5372000336647034,2.097731828689575,10000.0,50018.80451273918,51755.39443898201,50018.80451273918,1727.4768795967102,4.0461931228637695,0.0 -147400,3.1491408,1.4332023,,,,,,,,,,,,,, -147500,3.6061158,1.3553408,,,,,,,,,,,,,, -147600,3.5344708,1.5630383,,,,,,,,,,,,,, -147700,3.4033692,1.4892223,,,,,,,,,,,,,, -147800,3.1837468,1.363339,,,,,,,,,,,,,, -147900,3.3782163,1.5127858,,,,,,,,,,,,,, -148000,3.5955539,1.4686491,,,,,,,,,,,,,, -148100,3.5435808,1.4892582,,,,,,,,,,,,,, -148200,3.4657524,1.3750025,,,,,,,,,,,,,, -148300,3.3887012,1.4183728,,,,,,,,,,,,,, -148400,3.4243476,1.4440893,,,,,,,,,,,,,, -148500,3.5073469,1.4900236,,,,,,,,,,,,,, -148600,3.1367495,1.4303117,,,,,,,,,,,,,, -148700,3.5336757,1.4973319,,,,,,,,,,,,,, -148800,3.2822282,1.3607022,,,,,,,,,,,,,, -148840,,,0.7347337007522583,1.017677903175354,0.6627399921417236,1.392814040184021,50000.0,0.5409000515937805,2.0931143760681152,10000.0,50528.79414916039,52282.75450348854,50528.79414916039,1744.73273396492,4.108054876327515,0.0 -148900,3.5868037,1.4647386,,,,,,,,,,,,,, -149000,3.460588,1.4338248,,,,,,,,,,,,,, -149100,3.8568027,1.4638777,,,,,,,,,,,,,, -149200,3.3864913,1.4618682,,,,,,,,,,,,,, -149300,3.5401657,1.4720813,,,,,,,,,,,,,, -149400,3.2958293,1.331598,,,,,,,,,,,,,, -149500,3.4542382,1.483885,,,,,,,,,,,,,, -149600,3.7034123,1.386551,,,,,,,,,,,,,, -149700,3.653431,1.448157,,,,,,,,,,,,,, -149800,3.5816915,1.4105537,,,,,,,,,,,,,, -149900,3.0783083,1.3636495,,,,,,,,,,,,,, -150000,3.3544204,1.3851594,,,,,,,,,,,,,, -150100,3.419991,1.3365375,,,,,,,,,,,,,, -150200,3.5074415,1.4292848,,,,,,,,,,,,,, -150300,3.6122344,1.4521656,,,,,,,,,,,,,, -150346,,,0.7524314522743225,0.9449632167816162,0.682159960269928,1.306384563446045,50000.0,0.556600034236908,2.0261521339416504,10000.0,51038.696164131165,52810.16929721832,51038.696164131165,1762.139191389084,4.162874937057495,0.0 -150400,3.3088367,1.4081683,,,,,,,,,,,,,, -150500,3.4011426,1.4297597,,,,,,,,,,,,,, -150600,3.3299894,1.3772047,,,,,,,,,,,,,, -150700,3.5450447,1.3333645,,,,,,,,,,,,,, -150800,3.4972978,1.3466477,,,,,,,,,,,,,, -150900,3.8815978,1.3742336,,,,,,,,,,,,,, -151000,3.6614578,1.3714067,,,,,,,,,,,,,, -151100,3.6028001,1.4123309,,,,,,,,,,,,,, -151200,3.6481867,1.386589,,,,,,,,,,,,,, -151300,3.8219151,1.4362159,,,,,,,,,,,,,, -151400,3.8759334,1.3676279,,,,,,,,,,,,,, -151500,3.5912013,1.4081429,,,,,,,,,,,,,, -151600,3.5291014,1.3480074,,,,,,,,,,,,,, -151700,3.6701138,1.5053507,,,,,,,,,,,,,, -151800,3.549926,1.4662302,,,,,,,,,,,,,, -151852,,,0.7565967440605164,0.9240195155143738,0.6843799948692322,1.2921584844589231,50000.0,0.5592000484466553,2.0278122425079346,10000.0,51548.87925410271,53337.84234023094,51548.87925410271,1779.5146894454956,4.226062297821045,0.0 -151900,3.555927,1.355423,,,,,,,,,,,,,, -152000,3.5198631,1.2830503,,,,,,,,,,,,,, -152100,3.7647474,1.4511391,,,,,,,,,,,,,, -152200,3.4617696,1.4523858,,,,,,,,,,,,,, -152300,3.5952945,1.333877,,,,,,,,,,,,,, -152400,4.009929,1.429918,,,,,,,,,,,,,, -152500,3.1956463,1.2719966,,,,,,,,,,,,,, -152600,3.6371274,1.3407962,,,,,,,,,,,,,, -152700,3.619299,1.3704494,,,,,,,,,,,,,, -152800,3.76877,1.2300797,,,,,,,,,,,,,, -152900,3.579138,1.3069402,,,,,,,,,,,,,, -153000,4.042112,1.4209267,,,,,,,,,,,,,, -153100,3.7542512,1.3147663,,,,,,,,,,,,,, -153200,3.4770133,1.1916577,,,,,,,,,,,,,, -153300,3.6675403,1.3193704,,,,,,,,,,,,,, -153357,,,0.7704081535339355,0.8769794702529907,0.6804400086402893,1.3156960010528564,50000.0,0.5514000058174133,2.04368543624878,10000.0,52058.87175607681,53865.32513237,52058.87175607681,1796.8985974788666,4.281770467758179,0.0 -153400,3.866258,1.3637446,,,,,,,,,,,,,, -153500,3.7100813,1.4340703,,,,,,,,,,,,,, -153600,4.0584784,1.3680255,,,,,,,,,,,,,, -153700,3.9177818,1.52437,,,,,,,,,,,,,, -153800,3.6889212,1.3149018,,,,,,,,,,,,,, -153900,3.6261027,1.3142931,,,,,,,,,,,,,, -154000,4.1470966,1.3946186,,,,,,,,,,,,,, -154100,3.7843342,1.2921282,,,,,,,,,,,,,, -154200,3.9795132,1.3659486,,,,,,,,,,,,,, -154300,3.6569276,1.326745,,,,,,,,,,,,,, -154400,4.0926533,1.3481711,,,,,,,,,,,,,, -154500,3.7410252,1.3794713,,,,,,,,,,,,,, -154600,3.642253,1.4008762,,,,,,,,,,,,,, -154700,3.6685534,1.2800964,,,,,,,,,,,,,, -154800,3.5178156,1.3471434,,,,,,,,,,,,,, -154863,,,0.7912148833274841,0.7727125287055969,0.6976999640464783,1.230203628540039,50000.0,0.5733000040054321,1.9361377954483032,10000.0,52569.02870512009,54393.03028893471,52569.02870512009,1814.3410923480988,4.336676836013794,0.0 -154900,3.6839135,1.3068783,,,,,,,,,,,,,, -155000,3.8441176,1.4192107,,,,,,,,,,,,,, -155100,3.8879957,1.3276556,,,,,,,,,,,,,, -155200,4.0311775,1.2799273,,,,,,,,,,,,,, -155300,4.497233,1.4178932,,,,,,,,,,,,,, -155400,4.199435,1.3982534,,,,,,,,,,,,,, -155500,4.235341,1.2727355,,,,,,,,,,,,,, -155600,4.111707,1.3466568,,,,,,,,,,,,,, -155700,3.8646972,1.2686614,,,,,,,,,,,,,, -155800,3.7874792,1.2218978,,,,,,,,,,,,,, -155900,3.8771806,1.3724357,,,,,,,,,,,,,, -156000,3.723116,1.3331935,,,,,,,,,,,,,, -156100,4.1598754,1.2630113,,,,,,,,,,,,,, -156200,4.0572724,1.3576844,,,,,,,,,,,,,, -156300,4.1410966,1.3969182,,,,,,,,,,,,,, -156369,,,0.7887436151504517,0.7775185704231262,0.6951199769973755,1.2335890531539917,50000.0,0.5689000487327576,1.9560490846633911,10000.0,53079.0686750412,54920.545063734055,53079.0686750412,1831.7030427455904,4.396013021469116,0.0 -156400,4.000057,1.2640657,,,,,,,,,,,,,, -156500,3.8086674,1.2029889,,,,,,,,,,,,,, -156600,3.9008353,1.2443655,,,,,,,,,,,,,, -156700,3.826443,1.3447427,,,,,,,,,,,,,, -156800,3.9335554,1.2111236,,,,,,,,,,,,,, -156900,3.62705,1.2447702,,,,,,,,,,,,,, -157000,3.7400591,1.1650695,,,,,,,,,,,,,, -157100,3.8801482,1.3139769,,,,,,,,,,,,,, -157200,3.8347178,1.3155562,,,,,,,,,,,,,, -157300,3.8230119,1.1492428,,,,,,,,,,,,,, -157400,4.245036,1.3218352,,,,,,,,,,,,,, -157500,4.1293564,1.2094375,,,,,,,,,,,,,, -157600,4.172128,1.2416492,,,,,,,,,,,,,, -157700,3.8595276,1.2018718,,,,,,,,,,,,,, -157800,3.9514756,1.2324572,,,,,,,,,,,,,, -157875,,,0.7881656289100647,0.7825496196746826,0.6997399926185608,1.2260318994522097,50000.0,0.5729000568389893,1.9439477920532229,10000.0,53589.254455804825,55448.21320796013,53589.254455804825,1849.0799005031583,4.450916528701782,0.0 -157900,4.06042,1.3092569,,,,,,,,,,,,,, -158000,4.1297107,1.2610571,,,,,,,,,,,,,, -158100,4.079006,1.335415,,,,,,,,,,,,,, -158200,3.9917486,1.2514375,,,,,,,,,,,,,, -158300,4.4081645,1.3809555,,,,,,,,,,,,,, -158400,4.37957,1.3831303,,,,,,,,,,,,,, -158500,4.0336127,1.2638556,,,,,,,,,,,,,, -158600,4.0659556,1.1958964,,,,,,,,,,,,,, -158700,4.101681,1.3184426,,,,,,,,,,,,,, -158800,3.8639352,1.3068485,,,,,,,,,,,,,, -158900,4.206317,1.2339476,,,,,,,,,,,,,, -159000,4.5435367,1.3237376,,,,,,,,,,,,,, -159100,4.087699,1.312481,,,,,,,,,,,,,, -159200,4.1304464,1.2604343,,,,,,,,,,,,,, -159300,4.3900537,1.2738363,,,,,,,,,,,,,, -159382,,,0.7869299650192261,0.7920331358909607,0.7021799683570862,1.2132388353347778,50000.0,0.5769000053405762,1.903723120689392,10000.0,54099.491090774536,55975.7786552906,54099.491090774536,1866.303447008133,4.505049228668213,0.0 -159400,4.120103,1.4022975,,,,,,,,,,,,,, -159500,4.1706533,1.2645305,,,,,,,,,,,,,, -159600,4.173918,1.1926867,,,,,,,,,,,,,, -159700,4.210714,1.2018116,,,,,,,,,,,,,, -159800,4.2644606,1.2300906,,,,,,,,,,,,,, -159900,4.2469707,1.2834033,,,,,,,,,,,,,, -160000,4.1856737,1.157244,,,,,,,,,,,,,, -160100,3.96404,1.2689598,,,,,,,,,,,,,, -160200,4.1803703,1.3286045,,,,,,,,,,,,,, -160300,4.453998,1.3235551,,,,,,,,,,,,,, -160400,4.455596,1.2662383,,,,,,,,,,,,,, -160500,4.3852415,1.2648559,,,,,,,,,,,,,, -160600,4.3753,1.3345621,,,,,,,,,,,,,, -160700,4.441362,1.2999195,,,,,,,,,,,,,, -160800,4.321473,1.3103769,,,,,,,,,,,,,, -160888,,,0.8048867583274841,0.7252900004386902,0.7111799716949463,1.1790040731430054,50000.0,0.5863000154495239,1.8755098581314087,10000.0,54609.7164978981,56503.59543037415,54609.7164978981,1883.7833399772644,4.564141511917114,0.0 -160900,4.010364,1.2728761,,,,,,,,,,,,,, -161000,4.370873,1.203336,,,,,,,,,,,,,, -161100,4.1098905,1.2165104,,,,,,,,,,,,,, -161200,4.254268,1.2134501,,,,,,,,,,,,,, -161300,4.249337,1.3229711,,,,,,,,,,,,,, -161400,4.204646,1.1626484,,,,,,,,,,,,,, -161500,4.790111,1.3049726,,,,,,,,,,,,,, -161600,4.3369737,1.2110945,,,,,,,,,,,,,, -161700,4.235387,1.0994574,,,,,,,,,,,,,, -161800,4.0037923,1.1481841,,,,,,,,,,,,,, -161900,4.069725,1.1459389,,,,,,,,,,,,,, -162000,4.4255757,1.2779685,,,,,,,,,,,,,, -162100,4.353476,1.2055544,,,,,,,,,,,,,, -162200,4.526634,1.2508866,,,,,,,,,,,,,, -162300,4.4092474,1.3130128,,,,,,,,,,,,,, -162393,,,0.8147919178009033,0.689312756061554,0.7153199911117554,1.1588680744171145,50000.0,0.5861999988555908,1.867242455482483,10000.0,55119.78022289276,57030.947724580765,55119.78022289276,1900.9561693668363,4.627662658691406,0.0 -162400,4.350413,1.1597446,,,,,,,,,,,,,, -162500,4.501353,1.2675782,,,,,,,,,,,,,, -162600,4.3375616,1.1525323,,,,,,,,,,,,,, -162700,4.5839005,1.2099445,,,,,,,,,,,,,, -162800,4.7472906,1.2403331,,,,,,,,,,,,,, -162900,4.503149,1.1404624,,,,,,,,,,,,,, -163000,4.3218737,1.1614234,,,,,,,,,,,,,, -163100,4.6853766,1.2285788,,,,,,,,,,,,,, -163200,4.4443936,1.2168273,,,,,,,,,,,,,, -163300,4.4900575,1.2651961,,,,,,,,,,,,,, -163400,4.292388,1.1856223,,,,,,,,,,,,,, -163500,4.5417285,1.1692804,,,,,,,,,,,,,, -163600,4.522938,1.1703537,,,,,,,,,,,,,, -163700,4.3551054,1.1715581,,,,,,,,,,,,,, -163800,4.1641493,1.0793386,,,,,,,,,,,,,, -163898,,,0.8363161683082581,0.6028015613555908,0.7212600111961365,1.1305629014968872,50000.0,0.5986000299453735,1.7955706119537354,10000.0,55629.80427956581,57558.35658311844,55629.80427956581,1918.2251298427584,4.691629648208618,0.0 -163900,4.2702255,1.1739862,,,,,,,,,,,,,, -164000,4.5805893,1.2012563,,,,,,,,,,,,,, -164100,4.5415893,1.1312133,,,,,,,,,,,,,, -164200,4.6257014,1.1241441,,,,,,,,,,,,,, -164300,4.711272,1.2196746,,,,,,,,,,,,,, -164400,4.7170897,1.2887974,,,,,,,,,,,,,, -164500,4.758689,1.1490928,,,,,,,,,,,,,, -164600,4.4011784,1.2031376,,,,,,,,,,,,,, -164700,4.6744566,1.1864543,,,,,,,,,,,,,, -164800,4.6213775,1.1981274,,,,,,,,,,,,,, -164900,4.926095,1.217404,,,,,,,,,,,,,, -165000,4.701894,1.1383028,,,,,,,,,,,,,, -165100,4.277561,1.1249489,,,,,,,,,,,,,, -165200,4.85104,1.1619422,,,,,,,,,,,,,, -165300,4.6382327,1.1410052,,,,,,,,,,,,,, -165400,4.9581285,1.1308885,,,,,,,,,,,,,, -165403,,,0.8367944955825806,0.594135046005249,0.7243399620056152,1.119240164756775,50000.0,0.6026000380516052,1.800337553024292,10000.0,56139.99762535095,58085.81350302696,56139.99762535095,1935.3796126842497,4.7491774559021,0.0 -165500,4.493449,1.1523681,,,,,,,,,,,,,, -165600,4.4717984,1.1554582,,,,,,,,,,,,,, -165700,4.6584563,1.1406199,,,,,,,,,,,,,, -165800,4.669438,1.1713263,,,,,,,,,,,,,, -165900,4.263961,1.1339937,,,,,,,,,,,,,, -166000,4.518224,1.1549711,,,,,,,,,,,,,, -166100,4.249394,1.0899891,,,,,,,,,,,,,, -166200,4.410007,1.0968527,,,,,,,,,,,,,, -166300,4.7303753,1.1741811,,,,,,,,,,,,,, -166400,4.6683908,1.2083641,,,,,,,,,,,,,, -166500,4.3625894,1.0418115,,,,,,,,,,,,,, -166600,4.6032133,1.0719298,,,,,,,,,,,,,, -166700,4.6841974,1.2335441,,,,,,,,,,,,,, -166800,4.7225027,1.0872263,,,,,,,,,,,,,, -166900,4.809954,1.1350826,,,,,,,,,,,,,, -166907,,,0.8360371589660645,0.5923864245414734,0.7259199619293213,1.118510603904724,50000.0,0.6011000275611877,1.818732500076294,10000.0,56649.96901440621,58613.30370092392,56649.96901440621,1952.791288852692,4.805515766143799,0.0 -167000,4.499627,1.1344614,,,,,,,,,,,,,, -167100,5.099942,1.2120376,,,,,,,,,,,,,, -167200,4.514587,1.0656646,,,,,,,,,,,,,, -167300,4.478696,1.1162441,,,,,,,,,,,,,, -167400,4.607969,1.1043736,,,,,,,,,,,,,, -167500,4.7317104,1.097367,,,,,,,,,,,,,, -167600,4.6627345,1.117305,,,,,,,,,,,,,, -167700,5.2647243,1.1247963,,,,,,,,,,,,,, -167800,4.9934134,1.1608146,,,,,,,,,,,,,, -167900,5.140703,1.1087934,,,,,,,,,,,,,, -168000,4.7203007,1.0393612,,,,,,,,,,,,,, -168100,5.3652544,1.1336129,,,,,,,,,,,,,, -168200,4.8469687,1.183123,,,,,,,,,,,,,, -168300,4.6453834,1.0577334,,,,,,,,,,,,,, -168400,4.4173293,1.0332264,,,,,,,,,,,,,, -168412,,,0.8469985723495483,0.5509080290794373,0.7348999977111816,1.0821901559829712,50000.0,0.6110000014305115,1.7652047872543335,10000.0,57159.98609948158,59141.460246801376,57159.98609948158,1970.818876504898,4.8666205406188965,0.0 -168500,4.732218,1.1336907,,,,,,,,,,,,,, -168600,5.183653,1.1399032,,,,,,,,,,,,,, -168700,5.2308903,1.1519816,,,,,,,,,,,,,, -168800,5.5596657,1.0846255,,,,,,,,,,,,,, -168900,4.8894014,1.1228645,,,,,,,,,,,,,, -169000,4.8824854,1.1634115,,,,,,,,,,,,,, -169100,4.782941,1.0911554,,,,,,,,,,,,,, -169200,4.8196154,1.0305575,,,,,,,,,,,,,, -169300,5.254992,1.2196858,,,,,,,,,,,,,, -169400,4.605938,1.0424478,,,,,,,,,,,,,, -169500,4.7991853,1.0338961,,,,,,,,,,,,,, -169600,4.6648283,1.079567,,,,,,,,,,,,,, -169700,5.1070347,1.0690893,,,,,,,,,,,,,, -169800,4.579163,1.1124208,,,,,,,,,,,,,, -169900,4.9285297,1.0115614,,,,,,,,,,,,,, -169918,,,0.8470583558082581,0.5498757362365723,0.7370799779891968,1.0780162811279297,50000.0,0.610200047492981,1.7671650648117063,10000.0,57670.21110057831,59669.35487747192,57670.21110057831,1988.379658460617,4.923551321029663,0.0 -170000,5.221745,1.2125936,,,,,,,,,,,,,, -170100,4.7208915,0.98437005,,,,,,,,,,,,,, -170200,4.8809466,1.0437874,,,,,,,,,,,,,, -170300,4.8749375,1.0339768,,,,,,,,,,,,,, -170400,5.2966213,1.104096,,,,,,,,,,,,,, -170500,5.077588,1.0522355,,,,,,,,,,,,,, -170600,4.9639907,1.12135,,,,,,,,,,,,,, -170700,5.1693244,1.1140659,,,,,,,,,,,,,, -170800,4.8832197,1.0136018,,,,,,,,,,,,,, -170900,5.7678947,1.1904069,,,,,,,,,,,,,, -171000,5.000684,1.0589962,,,,,,,,,,,,,, -171100,4.9509897,1.0008422,,,,,,,,,,,,,, -171200,4.9342017,1.0463868,,,,,,,,,,,,,, -171300,4.834103,1.0572044,,,,,,,,,,,,,, -171400,4.984068,1.1007508,,,,,,,,,,,,,, -171423,,,0.8548508882522583,0.5203155279159546,0.738599956035614,1.06183123588562,50000.0,0.6117000579833984,1.7470028400421145,10000.0,58180.16239523888,60196.933161735535,58180.16239523888,2005.89358830452,4.983810424804688,0.0 -171500,4.8903937,0.95552146,,,,,,,,,,,,,, -171600,4.4921913,1.0173576,,,,,,,,,,,,,, -171700,5.178202,1.1390043,,,,,,,,,,,,,, -171800,4.743505,0.98484385,,,,,,,,,,,,,, -171900,4.5737853,0.955991,,,,,,,,,,,,,, -172000,4.8870544,1.0044968,,,,,,,,,,,,,, -172100,5.3737903,1.1401551,,,,,,,,,,,,,, -172200,4.578885,1.0236896,,,,,,,,,,,,,, -172300,5.1367774,1.1446701,,,,,,,,,,,,,, -172400,5.375228,1.0687896,,,,,,,,,,,,,, -172500,4.937671,1.0339693,,,,,,,,,,,,,, -172600,5.07808,1.0266496,,,,,,,,,,,,,, -172700,5.2847867,1.0632308,,,,,,,,,,,,,, -172800,5.206591,1.0654652,,,,,,,,,,,,,, -172900,4.9675283,1.0144057,,,,,,,,,,,,,, -172928,,,0.870515763759613,0.4607931077480316,0.7422800064086914,1.0552629232406616,50000.0,0.6200000047683716,1.7335784435272217,10000.0,58690.09550428391,60724.30287194252,58690.09550428391,2023.2214815616608,5.041860818862915,0.0 -173000,4.683286,0.96454465,,,,,,,,,,,,,, -173100,4.894842,1.0117112,,,,,,,,,,,,,, -173200,5.116409,1.027142,,,,,,,,,,,,,, -173300,4.8299766,0.96645784,,,,,,,,,,,,,, -173400,4.8176017,0.9348831,,,,,,,,,,,,,, -173500,4.951265,0.98283046,,,,,,,,,,,,,, -173600,5.0560884,1.0001646,,,,,,,,,,,,,, -173700,5.2631617,1.0996313,,,,,,,,,,,,,, -173800,4.91878,0.9819271,,,,,,,,,,,,,, -173900,4.574842,0.9260841,,,,,,,,,,,,,, -174000,4.817581,0.94795585,,,,,,,,,,,,,, -174100,5.5495996,1.0795766,,,,,,,,,,,,,, -174200,5.573495,1.1285392,,,,,,,,,,,,,, -174300,5.134447,1.0632577,,,,,,,,,,,,,, -174400,5.2807636,1.0126671,,,,,,,,,,,,,, -174433,,,0.8722695708274841,0.4543417692184448,0.7435599565505981,1.0551124811172483,50000.0,0.6190000176429749,1.7485793828964231,10000.0,59200.02728843689,61251.63083457947,59200.02728843689,2040.50359749794,5.102504730224609,0.0 -174500,5.2568274,1.0687898,,,,,,,,,,,,,, -174600,5.284718,0.96131665,,,,,,,,,,,,,, -174700,5.1108136,1.0177578,,,,,,,,,,,,,, -174800,5.1493273,0.98146707,,,,,,,,,,,,,, -174900,4.966596,0.9933841,,,,,,,,,,,,,, -175000,5.382347,1.0374312,,,,,,,,,,,,,, -175100,5.444181,1.106002,,,,,,,,,,,,,, -175200,5.2709184,1.0942336,,,,,,,,,,,,,, -175300,5.0344014,0.9752405,,,,,,,,,,,,,, -175400,4.99457,0.9969562,,,,,,,,,,,,,, -175500,5.068654,1.0271714,,,,,,,,,,,,,, -175600,5.139769,0.99608266,,,,,,,,,,,,,, -175700,4.8266063,0.98173714,,,,,,,,,,,,,, -175800,5.1686926,1.0856191,,,,,,,,,,,,,, -175900,5.162533,0.99287045,,,,,,,,,,,,,, -175938,,,0.8727478981018066,0.4520764946937561,0.7457000017166138,1.043248414993286,50000.0,0.6223000288009644,1.732593059539795,10000.0,59710.2247235775,61779.41142606735,59710.2247235775,2057.974116802216,5.163911104202271,0.0 -176000,5.360081,1.1019636,,,,,,,,,,,,,, -176100,5.049285,0.90150386,,,,,,,,,,,,,, -176200,4.955264,0.8758835,,,,,,,,,,,,,, -176300,5.2087483,1.0856863,,,,,,,,,,,,,, -176400,5.068012,0.9746707,,,,,,,,,,,,,, -176500,5.5833635,0.9547438,,,,,,,,,,,,,, -176600,5.2264194,0.9987975,,,,,,,,,,,,,, -176700,5.4948936,0.985939,,,,,,,,,,,,,, -176800,5.2919035,1.0462885,,,,,,,,,,,,,, -176900,4.745943,0.88861,,,,,,,,,,,,,, -177000,5.1242323,0.95437473,,,,,,,,,,,,,, -177100,4.906967,0.9401814,,,,,,,,,,,,,, -177200,5.1750093,1.0208236,,,,,,,,,,,,,, -177300,5.624727,1.0168735,,,,,,,,,,,,,, -177400,5.440314,0.8893328,,,,,,,,,,,,,, -177443,,,0.875996470451355,0.442630410194397,0.7466399669647217,1.0333547592163086,50000.0,0.6240000128746033,1.7179020643234253,10000.0,60220.32627439499,62306.92303276062,60220.32627439499,2075.270829439163,5.224120616912842,0.0 -177500,5.3363824,1.0038244,,,,,,,,,,,,,, -177600,5.183362,1.0138048,,,,,,,,,,,,,, -177700,5.1827836,0.9798353,,,,,,,,,,,,,, -177800,5.0827622,0.98851,,,,,,,,,,,,,, -177900,5.106245,0.8822094,,,,,,,,,,,,,, -178000,4.766925,0.9124814,,,,,,,,,,,,,, -178100,4.964596,0.9128666,,,,,,,,,,,,,, -178200,5.5064635,0.8844451,,,,,,,,,,,,,, -178300,5.3666477,1.0220397,,,,,,,,,,,,,, -178400,5.116106,0.9294608,,,,,,,,,,,,,, -178500,5.2210817,0.9436648,,,,,,,,,,,,,, -178600,5.699088,0.9860946,,,,,,,,,,,,,, -178700,5.0203195,0.88420165,,,,,,,,,,,,,, -178800,5.1323667,0.953655,,,,,,,,,,,,,, -178900,5.265737,0.94003797,,,,,,,,,,,,,, -178947,,,0.8785873651504517,0.4270265698432922,0.7484599947929382,1.0259389877319336,50000.0,0.6215000152587891,1.717150092124939,10000.0,60730.23235201836,62834.43216466904,60730.23235201836,2092.762161254883,5.284364223480225,0.0 -179000,5.026581,0.87034416,,,,,,,,,,,,,, -179100,4.7147136,0.8332569,,,,,,,,,,,,,, -179200,5.176135,0.9861962,,,,,,,,,,,,,, -179300,5.149054,0.8915421,,,,,,,,,,,,,, -179400,5.309927,0.9500295,,,,,,,,,,,,,, -179500,5.2275867,1.0029497,,,,,,,,,,,,,, -179600,5.1164904,0.96097887,,,,,,,,,,,,,, -179700,5.0241165,0.9261359,,,,,,,,,,,,,, -179800,5.086928,0.9689661,,,,,,,,,,,,,, -179900,5.490121,1.0008863,,,,,,,,,,,,,, -180000,5.5635695,1.0229466,,,,,,,,,,,,,, -180100,5.203721,0.9801722,,,,,,,,,,,,,, -180200,5.228023,0.9129863,,,,,,,,,,,,,, -180300,5.2927465,0.96922857,,,,,,,,,,,,,, -180400,5.6508403,0.9703621,,,,,,,,,,,,,, -180453,,,0.8818159699440002,0.412717342376709,0.7495399713516235,1.025421977043152,50000.0,0.6262000203132629,1.7159655094146729,10000.0,61240.38385462761,63361.98218679428,61240.38385462761,2110.0475058555603,5.346153259277344,0.0 -180500,5.33765,0.98446417,,,,,,,,,,,,,, -180600,5.2502213,0.931221,,,,,,,,,,,,,, -180700,5.0382476,0.89371365,,,,,,,,,,,,,, -180800,4.99784,0.9592421,,,,,,,,,,,,,, -180900,5.779017,0.90844464,,,,,,,,,,,,,, -181000,5.102966,0.9719903,,,,,,,,,,,,,, -181100,5.4696274,0.94112456,,,,,,,,,,,,,, -181200,4.868287,0.912818,,,,,,,,,,,,,, -181300,5.2202625,0.92906654,,,,,,,,,,,,,, -181400,5.1542845,0.9664856,,,,,,,,,,,,,, -181500,5.5683146,0.9607618,,,,,,,,,,,,,, -181600,5.3399477,0.89049494,,,,,,,,,,,,,, -181700,4.8162875,0.8330686,,,,,,,,,,,,,, -181800,5.261429,0.93396086,,,,,,,,,,,,,, -181900,5.2818184,0.9492558,,,,,,,,,,,,,, -181959,,,0.8836495280265808,0.4111799895763397,0.7505399584770203,1.0216362476348877,50000.0,0.6243000030517578,1.715208888053894,10000.0,61750.597148656845,63889.77865052223,61750.597148656845,2127.516664505005,5.409894943237305,0.0 -182000,5.290526,0.98405194,,,,,,,,,,,,,, -182100,5.3938026,0.92903465,,,,,,,,,,,,,, -182200,5.1999917,0.89897037,,,,,,,,,,,,,, -182300,5.369145,1.0193992,,,,,,,,,,,,,, -182400,5.552199,0.9089053,,,,,,,,,,,,,, -182500,5.2738204,0.94252634,,,,,,,,,,,,,, -182600,5.033254,0.9000032,,,,,,,,,,,,,, -182700,5.4031377,0.9776203,,,,,,,,,,,,,, -182800,5.498868,0.98022085,,,,,,,,,,,,,, -182900,5.116648,0.8412213,,,,,,,,,,,,,, -183000,5.384098,0.9388537,,,,,,,,,,,,,, -183100,5.174068,0.9250475,,,,,,,,,,,,,, -183200,5.032075,0.90757585,,,,,,,,,,,,,, -183300,5.382914,1.0209875,,,,,,,,,,,,,, -183400,5.109159,0.8045098,,,,,,,,,,,,,, -183465,,,0.884785532951355,0.4123013913631439,0.7511999607086182,1.020043134689331,50000.0,0.6269000172615051,1.7149865627288818,10000.0,62260.746554374695,64417.35872173309,62260.746554374695,2144.8360488414764,5.469226121902466,0.0 -183500,4.9131646,0.86076677,,,,,,,,,,,,,, -183600,5.356259,0.9151376,,,,,,,,,,,,,, -183700,5.0901613,0.8951177,,,,,,,,,,,,,, -183800,5.026398,1.0318692,,,,,,,,,,,,,, -183900,5.1706905,0.9669059,,,,,,,,,,,,,, -184000,5.2674885,0.92345613,,,,,,,,,,,,,, -184100,5.2341475,0.93641394,,,,,,,,,,,,,, -184200,5.535393,0.96645296,,,,,,,,,,,,,, -184300,5.318607,0.885807,,,,,,,,,,,,,, -184400,5.3697033,0.95217705,,,,,,,,,,,,,, -184500,5.4932323,0.95847535,,,,,,,,,,,,,, -184600,5.1643896,0.94224095,,,,,,,,,,,,,, -184700,5.1036477,0.96519935,,,,,,,,,,,,,, -184800,5.0380955,0.90303916,,,,,,,,,,,,,, -184900,5.164923,0.90532583,,,,,,,,,,,,,, -184971,,,0.8868184089660645,0.4034609794616699,0.7515400052070618,1.0192183256149292,50000.0,0.6272000074386597,1.7114449739456177,10000.0,62770.95883798599,64944.98903656006,62770.95883798599,2162.1414761543274,5.530078649520874,0.0 -185000,5.134821,0.8901613,,,,,,,,,,,,,, -185100,5.384822,0.9423262,,,,,,,,,,,,,, -185200,5.012449,0.9463469,,,,,,,,,,,,,, -185300,4.943231,0.9199063,,,,,,,,,,,,,, -185400,5.6246543,0.90926534,,,,,,,,,,,,,, -185500,5.1980033,0.99124455,,,,,,,,,,,,,, -185600,5.4727244,1.1226985,,,,,,,,,,,,,, -185672,,,,,,,,,,,63008.26549935341,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 2cb560159..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.4037983417511,0.0,35.2137176990509,1,0,35.2137176990509,0.0012000000569969,6.910790920257568,10000,52.61761689186096,0.0011160713620483,6.910408973693848,0.0010199999669566,6.91091251373291,50000 -34.85956573486328,0.019899845123291,545.2359344959259,1491,0,545.2359344959259,0.1160000041127204,4.982067584991455,10000,580.1692199707031,0.1638831347227096,4.366621017456055,0.1497000008821487,4.4936041831970215,50000 -52.29863715171814,0.0464069843292236,1055.3937442302704,2982,0,1055.3937442302704,0.2299000173807144,3.933526277542114,10000,1107.8423628807068,0.3359175622463226,3.0918776988983154,0.3116599917411804,3.26209020614624,50000 -69.99219536781311,0.0735189914703369,1565.4191055297852,4474,0,1565.4191055297852,0.3278000056743622,3.2873244285583496,10000,1635.6378026008606,0.5045041441917419,2.1400482654571533,0.4286199808120727,2.554784059524536,50000 -87.56485724449158,0.0971796512603759,2075.5087237358093,5968,0,2075.5087237358093,0.3862000107765198,2.930450439453125,10000,2163.373923540116,0.5469945669174194,1.9179505109786987,0.499019980430603,2.203223943710327,50000 -105.21260619163512,0.1253807544708252,2585.5674998760223,7462,0,2585.5674998760223,0.4089000225067138,2.7833967208862305,10000,2691.1594779491425,0.5733617544174194,1.780825972557068,0.5251399874687195,2.044733047485352,50000 -122.51022338867188,0.1549932956695556,3095.8154296875,8957,0,3095.8154296875,0.4390000104904175,2.5789635181427,10000,3218.784925699234,0.6165497303009033,1.578463792800903,0.5652799606323242,1.845713257789612,50000 -139.8828718662262,0.1847729682922363,3605.894075870514,10451,0,3605.894075870514,0.4404000341892242,2.6110658645629883,10000,3746.31577205658,0.6073620915412903,1.6080734729766846,0.5593799948692322,1.851853370666504,50000 -157.3446958065033,0.2137207984924316,4115.927650213242,11945,0,4115.927650213242,0.4595000147819519,2.491586923599243,10000,4273.891491889954,0.6295639276504517,1.503599762916565,0.5816599726676941,1.7506680488586426,50000 -174.75857639312744,0.24312424659729,4626.153445720673,13440,0,4626.153445720673,0.447700023651123,2.5825207233428955,10000,4801.611012220383,0.6180644035339355,1.5705331563949585,0.5727199912071228,1.823119044303894,50000 -192.2714822292328,0.2714576721191406,5136.090697288513,14935,0,5136.090697288513,0.4702000319957733,2.433048725128174,10000,5329.14058303833,0.68363356590271,1.2555805444717407,0.6050199866294861,1.6677595376968384,50000 -209.6851592063904,0.3002409934997558,5646.080711364746,16430,0,5646.080711364746,0.4785000085830688,2.401660203933716,10000,5856.62567782402,0.6679487824440002,1.3212976455688477,0.6027599573135376,1.6629142761230469,50000 -227.21400547027588,0.3311781883239746,6156.199065208435,17926,0,6156.199065208435,0.4845000207424164,2.360335111618042,10000,6384.354900598526,0.6713767647743225,1.295975685119629,0.6087599992752075,1.628176212310791,50000 -244.56793308258057,0.361548900604248,6666.292007684708,19421,0,6666.292007684708,0.4795000255107879,2.3939144611358643,10000,6911.881479263306,0.6640226244926453,1.3366535902023315,0.6086399555206299,1.63140869140625,50000 -262.7482252120972,0.3934106826782226,7176.418109893799,20917,0,7176.418109893799,0.4636000096797943,2.501309394836426,10000,7440.27081990242,0.6470623016357422,1.4147604703903198,0.5916999578475952,1.7275773286819458,50000 -280.29457664489746,0.4244773387908935,7686.367953538895,22412,0,7686.367953538895,0.4834000170230865,2.372260332107544,10000,7967.849324464798,0.6626673936843872,1.3498613834381104,0.6109600067138672,1.6298362016677856,50000 -297.5520570278168,0.4546489715576172,8196.371249198914,23907,0,8196.371249198914,0.4861000180244446,2.371723175048828,10000,8495.189532518387,0.7183513641357422,1.091708064079285,0.6167199611663818,1.610385775566101,50000 -315.2184538841248,0.4857900142669678,8706.593649148941,25403,0,8706.593649148941,0.4731000363826751,2.415712833404541,10000,9023.160031318665,0.6804647445678711,1.2763067483901978,0.6107999682426453,1.6338390111923218,50000 -332.7638320922852,0.5164616107940674,9216.7435195446,26899,0,9216.7435195446,0.4951000213623047,2.306924104690552,10000,9550.937099695206,0.6908880472183228,1.216625094413757,0.6270999908447266,1.541603446006775,50000 -350.1199746131897,0.5478644371032715,9726.73595571518,28395,0,9726.73595571518,0.4909000098705292,2.328392505645752,10000,10078.36783695221,0.6801259517669678,1.2542836666107178,0.6233199834823608,1.5583561658859253,50000 -367.5783729553223,0.5819680690765381,10236.778829574583,29891,0,10236.778829574583,0.4980000257492065,2.307891368865967,10000,10605.955340862274,0.6824178695678711,1.2555991411209106,0.625499963760376,1.545114517211914,50000 -385.03712701797485,0.6160974502563477,10747.021076440811,31387,0,10747.021076440811,0.5041000247001648,2.2895352840423584,10000,11133.741675138474,0.6882373690605164,1.2267718315124512,0.6324999928474426,1.5249176025390625,50000 -402.7224214076996,0.6510787010192871,11257.115474939346,32883,0,11257.115474939346,0.4873000085353851,2.34653377532959,10000,11661.608564853668,0.6759008169174194,1.282272458076477,0.6207599639892578,1.5945985317230225,50000 -420.22812843322754,0.6878213882446289,11767.26364517212,34379,0,11767.26364517212,0.4999000132083893,2.280648231506348,10000,12189.351138353348,0.7140066623687744,1.1095143556594849,0.6321399807929993,1.5328267812728882,50000 -437.8310143947601,0.7197163105010986,12277.20189833641,35875,0,12277.20189833641,0.4993000328540802,2.2765913009643555,10000,12716.973129034042,0.6957509517669678,1.1831389665603638,0.6268399953842163,1.5437438488006592,50000 -455.1110055446625,0.7531900405883789,12787.222305297852,37371,0,12787.222305297852,0.5092000365257263,2.2315361499786377,10000,13244.35971903801,0.7004544138908386,1.1742907762527466,0.6349799633026123,1.5201865434646606,50000 -472.4876246452332,0.7889235019683838,13297.305636644363,38868,0,13297.305636644363,0.5082000494003296,2.2616703510284424,10000,13771.90760755539,0.6949737071990967,1.1847723722457886,0.6327599883079529,1.518254041671753,50000 -490.2734615802765,0.8251686096191406,13807.231940984726,40364,0,13807.231940984726,0.4972000122070312,2.3025379180908203,10000,14299.705271959305,0.6914859414100647,1.2236030101776123,0.6318399906158447,1.5228060483932495,50000 -507.6738419532776,0.8653068542480469,14317.350157022476,41861,0,14317.350157022476,0.5012000203132629,2.2280113697052,10000,14827.315202951431,0.6909677982330322,1.2019314765930176,0.6317600011825562,1.5205481052398682,50000 -525.1450872421265,0.903350830078125,14827.423000097277,43358,0,14827.423000097277,0.5042999982833862,2.259047508239746,10000,15354.948499202728,0.726980984210968,1.0523135662078855,0.6317399740219116,1.5127232074737549,50000 -542.4807982444763,0.9419848918914796,15337.656279802322,44855,0,15337.656279802322,0.5175000429153442,2.1700892448425293,10000,15882.606467962263,0.7177534699440002,1.0975770950317385,0.6431800127029419,1.4693011045455933,50000 -559.6336979866028,0.9760580062866212,15847.795320510864,46351,0,15847.795320510864,0.5148000121116638,2.235180854797364,10000,16409.9840195179,0.7119937539100647,1.1093826293945312,0.6420199871063232,1.4766788482666016,50000 -576.9686605930328,1.016944408416748,16358.012321472168,47848,0,16358.012321472168,0.5045000314712524,2.241513967514038,10000,16937.62736606598,0.7031847834587097,1.1564139127731323,0.6360999941825867,1.489286184310913,50000 -594.6073455810547,1.0527050495147705,16868.260396003723,49344,0,16868.260396003723,0.5211000442504883,2.169034481048584,10000,17465.60064959526,0.7031847834587097,1.1441307067871094,0.6427199840545654,1.467440843582153,50000 -612.1412818431854,1.091554880142212,17378.219309806824,50841,0,17378.219309806824,0.5149000287055969,2.193108081817627,10000,17993.183161497116,0.7056361436843872,1.1398847103118896,0.645039975643158,1.4592891931533811,50000 -629.606897354126,1.1325068473815918,17888.419243574142,52338,0,17888.419243574142,0.520300030708313,2.1776282787323,10000,18520.939923524857,0.7645886540412903,0.90327787399292,0.6543799638748169,1.4254825115203855,50000 -646.8635385036469,1.1695225238800049,18398.483984470367,53835,0,18398.483984470367,0.5029000043869019,2.2745490074157715,10000,19048.349376678467,0.7227758169174194,1.0640006065368652,0.6343599557876587,1.4988964796066284,50000 -664.4518418312073,1.2079482078552246,18908.700251817703,55332,0,18908.700251817703,0.5034000277519226,2.26247239112854,10000,19576.242521762848,0.7080675959587097,1.1316912174224854,0.6315799951553345,1.5241401195526123,50000 -682.0746810436249,1.2414581775665283,19418.939255476,56829,0,19418.939255476,0.5152000188827515,2.216600179672241,10000,20104.18901371956,0.7178930044174194,1.0814168453216553,0.6469199657440186,1.4597644805908203,50000 -700.2618687152863,1.2832703590393066,19929.052931785583,58326,0,19929.052931785583,0.5131000280380249,2.1649599075317383,10000,20632.583993673325,0.7149434089660645,1.1157106161117554,0.6482999920845032,1.441026210784912,50000 -717.730926990509,1.3240761756896973,20439.141530275345,59823,0,20439.141530275345,0.5279000401496887,2.151419162750244,10000,21160.23457360268,0.7216398119926453,1.0750752687454224,0.6587199568748474,1.4014229774475098,50000 -734.9866235256195,1.3634462356567385,20949.247843265533,61320,0,20949.247843265533,0.5362000465393066,2.0705716609954834,10000,21687.68998861313,0.7239716053009033,1.0552133321762085,0.6634399890899658,1.3708513975143433,50000 -752.3298766613007,1.4020438194274902,21459.278631210327,62817,0,21459.278631210327,0.5330000519752502,2.138416290283203,10000,22215.15426421165,0.7576530575752258,0.9276350736618042,0.6600199937820435,1.3925979137420654,50000 -769.6031460762024,1.4653689861297607,21969.453302383423,64314,0,21969.453302383423,0.5248000025749207,2.172521114349365,10000,22742.71888899803,0.7278180718421936,1.0450934171676636,0.6518799662590027,1.438127517700195,50000 -787.1383633613586,1.5099318027496338,22479.49555540085,65811,0,22479.49555540085,0.5311000347137451,2.125349998474121,10000,23270.3928463459,0.7356903553009033,1.0099648237228394,0.6637600064277649,1.3708646297454834,50000 -804.4793326854706,1.55267596244812,22989.733140707016,67308,0,22989.733140707016,0.5324000120162964,2.098825216293335,10000,23798.065600156784,0.7371651530265808,0.9949511885643004,0.664359986782074,1.357032299041748,50000 -821.963175535202,1.5933401584625244,23499.958671808243,68806,0,23499.958671808243,0.5281000137329102,2.129694700241089,10000,24325.86668920517,0.7223772406578064,1.0757921934127808,0.6567599773406982,1.39765727519989,50000 -839.4227304458618,1.6327154636383057,24009.941289901733,70302,0,24009.941289901733,0.5242000222206116,2.144479513168335,10000,24853.39986562729,0.7234334945678711,1.0573880672454834,0.656719982624054,1.40852689743042,50000 -856.6956856250763,1.6729493141174316,24520.09197568893,71800,0,24520.09197568893,0.5440000295639038,2.086559772491455,10000,25380.915219783783,0.7745535373687744,0.8367967009544373,0.6677199602127075,1.3619836568832395,50000 -874.2368021011353,1.720144510269165,25030.211496591568,73297,0,25030.211496591568,0.5418000221252441,2.0513622760772705,10000,25908.675062417984,0.7525908350944519,0.937186062335968,0.6672799587249756,1.3505654335021973,50000 -891.7931699752808,1.7655150890350342,25540.13827967644,74793,0,25540.13827967644,0.5467000007629395,2.070542812347412,10000,26436.25676727295,0.7455755472183228,0.9578787088394164,0.6665999889373779,1.366343379020691,50000 -909.11328291893,1.809826374053955,26050.24849486351,76290,0,26050.24849486351,0.5372000336647034,2.145915985107422,10000,26963.783942461014,0.73441481590271,1.0014700889587402,0.6627599596977234,1.396458864212036,50000 -926.391303539276,1.854209899902344,26560.44718551636,77787,0,26560.44718551636,0.5320000052452087,2.0675227642059326,10000,27491.354991436005,0.7419283986091614,0.9948947429656982,0.668179988861084,1.3509660959243774,50000 -943.8302927017212,1.894223690032959,27070.60922384262,79284,0,27070.60922384262,0.5475000143051147,2.050487756729126,10000,28019.048672914505,0.7411909699440002,0.993799090385437,0.6715399622917175,1.3349510431289673,50000 -961.185555934906,1.935410022735596,27580.572791814804,80781,0,27580.572791814804,0.5409000515937805,2.0658974647521973,10000,28546.45977640152,0.7743940949440002,0.8552932739257812,0.6669399738311768,1.3668314218521118,50000 -978.5394492149352,1.9794235229492188,28090.6806986332,82278,0,28090.6806986332,0.5436000227928162,2.066192865371704,10000,29074.016382932663,0.7655652165412903,0.8683584332466125,0.6740999817848206,1.3325233459472656,50000 -996.1524906158448,2.0239923000335693,28600.69172692299,83775,0,28600.69172692299,0.5362000465393066,2.076328754425049,10000,29601.736602783203,0.7556401491165161,0.9285488128662108,0.6721199750900269,1.338887095451355,50000 -1013.6288385391236,2.0686020851135254,29110.69561982155,85272,0,29110.69561982155,0.5489000082015991,2.041975259780884,10000,30129.31273341179,0.7602040767669678,0.8968546390533447,0.6789399981498718,1.307782769203186,50000 -1031.1397771835327,2.111499547958374,29620.89577460289,86770,0,29620.89577460289,0.5494000315666199,2.010836839675904,10000,30657.11715722084,0.7505978941917419,0.9393353462219238,0.6771399974822998,1.3065553903579712,50000 -1048.6075825691223,2.1597437858581543,30130.933844089508,88267,0,30130.933844089508,0.5533000230789185,2.0117292404174805,10000,31184.721259593964,0.7596260905265808,0.9097063541412354,0.6830799579620361,1.28650164604187,50000 -1065.8815150260923,2.206571578979492,30641.0526971817,89764,0,30641.0526971817,0.5552000403404236,2.016308069229126,10000,31712.21232533455,0.7589086294174194,0.9112529158592224,0.682379961013794,1.2872778177261353,50000 -1083.19038939476,2.25357723236084,31151.199570178986,91261,0,31151.199570178986,0.5560000538825989,2.0203568935394287,10000,32239.76553440094,0.7894411683082581,0.784350574016571,0.6829000115394592,1.2923929691314695,50000 -1100.471552848816,2.3035874366760254,31661.18835258484,92757,0,31661.18835258484,0.5391000509262085,2.09073805809021,10000,32767.13542819023,0.7609016299247742,0.8901998996734619,0.6728399991989136,1.3282952308654783,50000 -1117.723935842514,2.349766731262207,32171.235827207565,94254,0,32171.235827207565,0.5612000226974487,1.9653502702713013,10000,33294.53087544441,0.7751315236091614,0.8435813784599304,0.6869999766349792,1.2639058828353882,50000 -1135.1968188285828,2.4002671241760254,32681.18448400497,95751,0,32681.18448400497,0.555400013923645,2.016524314880371,10000,33822.053425073624,0.7679567933082581,0.8627640008926392,0.681659996509552,1.2954697608947754,50000 -1153.3456366062164,2.44881272315979,33191.21137213707,97248,0,33191.21137213707,0.5631999969482422,1.985821008682251,10000,34350.32893657684,0.7716238498687744,0.8549726605415344,0.691540002822876,1.2567367553710938,50000 -1171.3598392009735,2.493917226791382,33701.27909350395,98745,0,33701.27909350395,0.5597000122070312,1.9715975522994995,10000,34878.50706458092,0.7628945708274841,0.8821855187416077,0.6861000061035156,1.2782268524169922,50000 -1188.65078830719,3.0731360912323,34210.90730881691,100241,0,34210.90730881691,0.5509000420570374,2.0477240085601807,10000,35406.05687189102,0.8028938174247742,0.7278432250022888,0.6812599897384644,1.3054099082946775,50000 -1206.1760022640228,3.119248867034912,34721.09689593315,101739,0,34721.09689593315,0.5597000122070312,1.9839948415756223,10000,35933.86906027794,0.788504421710968,0.7852997183799744,0.6868799924850464,1.2690773010253906,50000 -1223.5552270412445,3.163418292999268,35231.05541777611,103236,0,35231.05541777611,0.5722000002861023,1.93617844581604,10000,36461.30139732361,0.7882851958274841,0.7907478213310242,0.6976199746131897,1.247995376586914,50000 -1240.7657074928284,3.2133994102478027,35741.056241989136,104733,0,35741.056241989136,0.5595000386238098,1.9859092235565183,10000,36988.61328577995,0.7801936864852905,0.8114662766456604,0.6913599967956543,1.254683017730713,50000 -1258.3189299106598,3.259010076522827,36250.98030781746,106230,0,36250.98030781746,0.5678000450134277,1.9703844785690308,10000,37516.186690330505,0.7833226919174194,0.8032361268997192,0.6902399659156799,1.247302532196045,50000 -1275.867201089859,3.323164701461792,36760.9456949234,107727,0,36760.9456949234,0.5678000450134277,1.957873106002808,10000,38043.81551671028,0.7869698405265808,0.769432544708252,0.6983199715614319,1.2229222059249878,50000 -1293.1777880191803,3.372274875640869,37271.04302740097,109225,0,37271.04302740097,0.5749000310897827,1.95384418964386,10000,38571.32220196724,0.8082947731018066,0.6996628046035767,0.6953799724578857,1.2410041093826294,50000 -1310.7889490127563,3.420320987701416,37781.17570281029,110722,0,37781.17570281029,0.5769000053405762,1.93671452999115,10000,39099.16570162773,0.813875138759613,0.6721604466438293,0.7019599676132202,1.222309947013855,50000 -1328.0241174697876,3.4695346355438232,38291.41183042526,112219,0,38291.41183042526,0.5773000121116638,1.9144089221954343,10000,39626.73736596108,0.8077367544174194,0.6889608502388,0.7064599990844727,1.1906296014785769,50000 -1345.405649185181,3.519113063812256,38801.31273698807,113716,0,38801.31273698807,0.5742000341415405,1.97724187374115,10000,40154.12018656731,0.7943239808082581,0.7536799311637878,0.6958999633789062,1.2339434623718262,50000 -1362.8529393672943,3.5819525718688965,39311.23474597931,115212,0,39311.23474597931,0.5767000317573547,1.938300848007202,10000,40681.605519771576,0.7994260191917419,0.7419718503952026,0.6971399784088135,1.2301865816116333,50000 -1380.4033725261688,3.630246162414551,39821.20298480988,116709,0,39821.20298480988,0.5850000381469727,1.901528239250183,10000,41209.22456550598,0.8053252100944519,0.7091084718704224,0.7084199786186218,1.1827819347381592,50000 -1397.9720528125763,3.681635856628418,40331.18197154999,118206,0,40331.18197154999,0.5771000385284424,1.9114309549331665,10000,41736.874911785126,0.8056241869926453,0.7094724178314209,0.7069999575614929,1.1886725425720217,50000 -1415.2570950984957,3.732133626937866,40841.25449848175,119703,0,40841.25449848175,0.5821000337600708,1.9020187854766848,10000,42264.33503699303,0.8384486436843872,0.5811982154846191,0.7049599885940552,1.1874204874038696,50000 -1432.6516358852386,3.784732341766357,41351.66305828095,121201,0,41351.66305828095,0.5819000005722046,1.924055814743042,10000,42792.24165916443,0.8234614133834839,0.6339573860168457,0.7074999809265137,1.1968789100646973,50000 -1450.1179354190826,3.839577436447144,41861.57820510864,122698,0,41861.57820510864,0.5800000429153442,1.9360084533691408,10000,43319.72976899147,0.8205317258834839,0.6411925554275513,0.707040011882782,1.196536898612976,50000 -1467.3553059101105,3.890242338180542,42371.47960424423,124194,0,42371.47960424423,0.5819000005722046,1.9214253425598145,10000,43846.969765901566,0.8153499364852905,0.6613895297050476,0.7084000110626221,1.1882474422454834,50000 -1484.7452561855316,3.943494319915772,42881.63268017769,125692,0,42881.63268017769,0.5859000086784363,1.90839946269989,10000,44374.61629462242,0.8228236436843872,0.6361862421035767,0.7134000062942505,1.174013614654541,50000 -1502.1120376586914,3.9931235313415527,43391.758655786514,127190,0,43391.758655786514,0.5815000534057617,1.9391067028045648,10000,44902.20900511742,0.8191167116165161,0.6429269909858704,0.7112399935722351,1.1914821863174438,50000 -1519.7223196029663,4.045359134674072,43901.96710586548,128687,0,43901.96710586548,0.5940000414848328,1.870269656181336,10000,45430.13155961037,0.8657525181770325,0.4807564318180084,0.7155599594116211,1.1618179082870483,50000 -1536.9437124729156,4.095425844192505,44412.060628175735,130184,0,44412.060628175735,0.5944000482559204,1.8714016675949097,10000,45957.546800136566,0.8509646058082581,0.52117919921875,0.7196599841117859,1.150366187095642,50000 -1554.3216423988342,4.150796890258789,44922.07017183304,131681,0,44922.07017183304,0.5985000133514404,1.852310061454773,10000,46485.04060125351,0.8581393361091614,0.5003353953361511,0.721340000629425,1.128917932510376,50000 -1571.6787357330322,4.2078258991241455,45432.17251110077,133178,0,45432.17251110077,0.5966000556945801,1.8334498405456543,10000,47012.60780596733,0.8520009517669678,0.5243096351623535,0.7253999710083008,1.120864987373352,50000 -1589.7894802093506,4.2574567794799805,45942.078404426575,134674,0,45942.078404426575,0.6020000576972961,1.8749887943267824,10000,47540.725546360016,0.8475764989852905,0.5381889343261719,0.7217999696731567,1.1407110691070557,50000 -1607.0494379997251,4.309852600097656,46452.042345047,136171,0,46452.042345047,0.5981000065803528,1.8592684268951416,10000,48068.0526099205,0.8484932780265808,0.5277658104896545,0.7238399982452393,1.1313154697418213,50000 -1624.3033895492554,4.361900568008423,46961.99141907692,137668,0,46961.99141907692,0.5975000262260437,1.896036386489868,10000,48595.35859775543,0.865652859210968,0.4712446928024292,0.7237399816513062,1.1422827243804932,50000 -1641.7458896636963,4.415813446044922,47472.01818680763,139165,0,47472.01818680763,0.612500011920929,1.8532575368881223,10000,49122.931619644165,0.8843072056770325,0.402018278837204,0.7286999821662903,1.1118732690811155,50000 -1659.0822570323944,4.472891807556152,47982.18217277527,140663,0,47982.18217277527,0.6103000044822693,1.8223932981491089,10000,49650.53896570206,0.87890625,0.4263043105602264,0.7313599586486816,1.112805724143982,50000 -1676.9486873149872,4.525469541549683,48492.161371946335,142160,0,48492.161371946335,0.6021000146865845,1.843847751617432,10000,50178.48799037933,0.8776904940605164,0.4207326173782348,0.7316799759864807,1.1086206436157229,50000 -1694.4568555355072,4.5699193477630615,49002.2856194973,143657,0,49002.2856194973,0.6050000190734863,1.839349269866944,10000,50706.21445202828,0.8729472160339355,0.4409449398517608,0.7316799759864807,1.1113321781158447,50000 -1711.908765077591,4.625460624694824,49512.30458474159,145154,0,49512.30458474159,0.6053000092506409,1.8584744930267327,10000,51233.79352784157,0.8752989172935486,0.4280344545841217,0.7317799925804138,1.1089249849319458,50000 -1729.127257347107,4.677754163742065,50022.35403752327,146652,0,50022.35403752327,0.6035000085830688,1.850600242614746,10000,51761.16456365585,0.8880739808082581,0.3907930254936218,0.734279990196228,1.1064696311950684,50000 -1746.6031787395475,4.732522487640381,50532.34589600563,148149,0,50532.34589600563,0.6094000339508057,1.84394109249115,10000,52288.738582372665,0.91019606590271,0.3179280161857605,0.7368199825286865,1.0996601581573486,50000 -1764.0998284816742,4.791525602340698,51042.26584935188,149644,0,51042.26584935188,0.6175000071525574,1.8288391828536987,10000,52816.26387476921,0.9052136540412904,0.3252921402454376,0.7401399612426758,1.093822717666626,50000 -1781.5769836902618,4.874719858169556,51552.15370512009,151141,0,51552.15370512009,0.6135000586509705,1.828898549079895,10000,53343.76315808296,0.906668484210968,0.3213539719581604,0.7433599829673767,1.087909460067749,50000 -1799.1908648014069,4.927303552627564,52062.18849515915,152637,0,52062.18849515915,0.6148000359535217,1.844693660736084,10000,53871.51702570915,0.9010881781578064,0.3470576405525207,0.7376999855041504,1.1096341609954834,50000 -1816.406150817871,4.986354112625122,52572.27054524422,154134,0,52572.27054524422,0.6132000088691711,1.846460461616516,10000,54398.92409610748,0.9094188213348388,0.3132840991020202,0.7424799799919128,1.0863618850708008,50000 -1833.7602362632751,5.041287660598755,53082.467106580734,155632,0,53082.467106580734,0.6140000224113464,1.8501039743423464,10000,54926.58153581619,0.9090999364852904,0.3144540786743164,0.7421799898147583,1.0850337743759155,50000 -1851.127019643784,5.0969154834747314,53592.5631840229,157129,0,53592.5631840229,0.6165000200271606,1.826570749282837,10000,55454.15119338036,0.938257336616516,0.2249719649553299,0.7435799837112427,1.074958086013794,50000 -1868.317990541458,5.159096002578735,54102.53592252731,158626,0,54102.53592252731,0.6187000274658203,1.83165442943573,10000,55981.42848825455,0.9301857352256776,0.2437449693679809,0.7449399828910828,1.078855276107788,50000 -1885.6868290901184,5.215553045272827,54612.44617629051,160123,0,54612.44617629051,0.622700035572052,1.8242336511611936,10000,56508.81551861763,0.9317004084587096,0.2384717613458633,0.7465999722480774,1.0716962814331057,50000 -1902.8871002197263,5.275780916213989,55122.4037668705,161619,0,55122.4037668705,0.6202000379562378,1.837498545646668,10000,57036.08502626419,0.9328961968421936,0.2328613251447677,0.7475000023841858,1.067963480949402,50000 -1920.4235010147088,5.330445289611816,55632.31921625137,163115,0,55632.31921625137,0.6237000226974487,1.835316061973572,10000,57563.64182114601,0.9344307780265808,0.2265849560499191,0.7470999956130981,1.070637583732605,50000 -1937.6732861995697,5.390793085098267,56142.43895721436,164613,0,56142.43895721436,0.622700035572052,1.835475325584412,10000,58091.12245035172,0.9340322017669678,0.2242102921009063,0.7490800023078918,1.070961356163025,50000 -1955.1959567070007,5.446483373641968,56652.48220562935,166109,0,56652.48220562935,0.624500036239624,1.8405344486236568,10000,58618.7962179184,0.9427216053009032,0.2048528641462326,0.7487599849700928,1.0646597146987915,50000 -1972.493607759476,5.505061149597168,57162.58238697052,167606,0,57162.58238697052,0.6233000159263611,1.841591119766236,10000,59146.30214428902,0.9520288109779358,0.1748383939266204,0.750499963760376,1.0643622875213623,50000 -1989.808545589447,5.561115026473999,57672.57480549812,169103,0,57672.57480549812,0.6274000406265259,1.835326075553894,10000,59673.71718621254,0.9522680044174194,0.1750165224075317,0.750819981098175,1.0584666728973389,50000 -2007.2889399528503,5.620239496231079,58182.56483054161,170600,0,58182.56483054161,0.6294000148773193,1.825202226638794,10000,60201.29833936691,0.9530253410339355,0.1685916483402252,0.7524200081825256,1.058369517326355,50000 -2024.735008716584,5.683067798614502,58692.67744731903,172097,0,58692.67744731903,0.628600001335144,1.8379465341567995,10000,60728.97080159187,0.951809585094452,0.1718765199184417,0.7515000104904175,1.0592249631881714,50000 -2042.6753525733948,5.773506164550781,59202.68797492981,173594,0,59202.68797492981,0.6281000375747681,1.8330553770065308,10000,61257.06341433525,0.9528858065605164,0.1682156473398208,0.754040002822876,1.060206651687622,50000 -2060.2560591697693,5.830533027648926,59712.79919052124,175090,0,59712.79919052124,0.6284000277519226,1.82981026172638,10000,61784.8630464077,0.9554169178009032,0.1644298136234283,0.754040002822876,1.0577527284622192,50000 -2077.571899175644,5.891931772232056,60222.87142467499,176587,0,60222.87142467499,0.626800000667572,1.835967302322388,10000,62312.36273407936,0.960558831691742,0.1476965844631195,0.7540599703788757,1.0551767349243164,50000 -2094.967592716217,5.955878973007202,60733.02505970001,178084,0,60733.02505970001,0.6317000389099121,1.831668496131897,10000,62840.026733636856,0.959582269191742,0.1489991694688797,0.7542799711227417,1.0563178062438965,50000 -2112.402285337448,6.015355348587036,61242.96651220322,179581,0,61242.96651220322,0.6299000382423401,1.83180034160614,10000,63367.51571488381,0.9616549611091614,0.144578143954277,0.7544199824333191,1.0557457208633425,50000 -2129.514147043228,6.073501348495483,61753.01442456245,181078,0,61753.01442456245,0.629800021648407,1.8307338953018188,10000,63894.78460788727,0.9592434167861938,0.1504308581352234,0.7547599673271179,1.052348494529724,50000 -2146.771764278412,6.137564897537232,62263.06065821648,182575,0,62263.06065821648,0.6289000511169434,1.8282490968704224,10000,64422.20252633095,0.9605388641357422,0.1476653218269348,0.7547999620437622,1.0509791374206543,50000 -2164.0824568271637,6.202368974685669,62773.27559399605,184073,0,62773.27559399605,0.6287000179290771,1.829309105873108,10000,64949.84421133995,0.9612962007522583,0.14372682571411133,0.7549600005149841,1.0522350072860718,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/measurements.csv deleted file mode 100644 index 60446a8fb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1974 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6531614,6.929166,,,,,,,,,,,,,, -1,,,0.0011160713620483,6.910408973693848,0.0010199999669566,6.91091251373291,50000.0,0.0012000000569969,6.910790920257568,10000.0,35.2137176990509,52.61761689186096,35.2137176990509,17.4037983417511,0.0,0.0 -100,0.6630403,6.8140035,,,,,,,,,,,,,, -200,0.8227039,6.5890102,,,,,,,,,,,,,, -300,0.9289721,6.2478576,,,,,,,,,,,,,, -400,2.1543655,6.0050435,,,,,,,,,,,,,, -500,2.6169074,5.830378,,,,,,,,,,,,,, -600,2.4998019,5.599595,,,,,,,,,,,,,, -700,2.9808533,5.444866,,,,,,,,,,,,,, -800,3.1509216,5.2607207,,,,,,,,,,,,,, -900,4.428672,5.1540527,,,,,,,,,,,,,, -1000,5.6860337,5.021341,,,,,,,,,,,,,, -1100,4.92774,5.0364575,,,,,,,,,,,,,, -1200,4.266155,4.740297,,,,,,,,,,,,,, -1300,7.3782797,4.703537,,,,,,,,,,,,,, -1400,3.4980986,4.7378006,,,,,,,,,,,,,, -1491,,,0.1638831347227096,4.366621017456055,0.1497000008821487,4.4936041831970215,50000.0,0.1160000041127204,4.982067584991455,10000.0,545.2359344959259,580.1692199707031,545.2359344959259,34.85956573486328,0.019899845123291,0.0 -1500,8.455964,4.490657,,,,,,,,,,,,,, -1600,5.8182836,4.1999497,,,,,,,,,,,,,, -1700,6.348964,4.261139,,,,,,,,,,,,,, -1800,7.798219,4.2230225,,,,,,,,,,,,,, -1900,9.411899,3.9792755,,,,,,,,,,,,,, -2000,5.5290866,4.043638,,,,,,,,,,,,,, -2100,3.7657535,3.9794943,,,,,,,,,,,,,, -2200,11.020764,3.8824859,,,,,,,,,,,,,, -2300,4.304227,3.75112,,,,,,,,,,,,,, -2400,3.608837,3.8090518,,,,,,,,,,,,,, -2500,5.025325,3.5693593,,,,,,,,,,,,,, -2600,3.0366356,3.5958557,,,,,,,,,,,,,, -2700,4.636557,3.5665302,,,,,,,,,,,,,, -2800,3.429832,3.4382074,,,,,,,,,,,,,, -2900,2.9080338,3.4983573,,,,,,,,,,,,,, -2982,,,0.3359175622463226,3.0918776988983154,0.3116599917411804,3.26209020614624,50000.0,0.2299000173807144,3.933526277542114,10000.0,1055.3937442302704,1107.8423628807068,1055.3937442302704,52.29863715171814,0.0464069843292236,0.0 -3000,3.6235704,3.342705,,,,,,,,,,,,,, -3100,2.6099076,3.279894,,,,,,,,,,,,,, -3200,4.5829496,3.4084902,,,,,,,,,,,,,, -3300,4.7207584,3.199921,,,,,,,,,,,,,, -3400,3.7447295,3.1543465,,,,,,,,,,,,,, -3500,3.0292046,3.1348643,,,,,,,,,,,,,, -3600,4.179173,3.2390647,,,,,,,,,,,,,, -3700,3.0236034,3.0259528,,,,,,,,,,,,,, -3800,2.8445802,3.0771027,,,,,,,,,,,,,, -3900,3.373919,3.0759916,,,,,,,,,,,,,, -4000,2.6403205,3.1833198,,,,,,,,,,,,,, -4100,3.1399143,2.9088311,,,,,,,,,,,,,, -4200,3.7165937,2.9216347,,,,,,,,,,,,,, -4300,4.08592,2.8567348,,,,,,,,,,,,,, -4400,3.494207,2.7048178,,,,,,,,,,,,,, -4474,,,0.5045041441917419,2.1400482654571533,0.4286199808120727,2.554784059524536,50000.0,0.3278000056743622,3.2873244285583496,10000.0,1565.4191055297852,1635.6378026008606,1565.4191055297852,69.99219536781311,0.0735189914703369,0.0 -4500,3.00781,2.876083,,,,,,,,,,,,,, -4600,3.084065,2.756735,,,,,,,,,,,,,, -4700,3.4611132,2.64235,,,,,,,,,,,,,, -4800,2.2493145,3.0212598,,,,,,,,,,,,,, -4900,3.3029826,2.5702603,,,,,,,,,,,,,, -5000,3.2302022,2.6348865,,,,,,,,,,,,,, -5100,2.5621226,2.5445223,,,,,,,,,,,,,, -5200,1.9211032,2.5241704,,,,,,,,,,,,,, -5300,1.837455,2.527914,,,,,,,,,,,,,, -5400,2.6043293,2.5194006,,,,,,,,,,,,,, -5500,2.7198153,2.4495826,,,,,,,,,,,,,, -5600,2.4545693,2.6520374,,,,,,,,,,,,,, -5700,2.9064443,2.3874445,,,,,,,,,,,,,, -5800,1.738317,2.4923577,,,,,,,,,,,,,, -5900,2.6515098,2.509531,,,,,,,,,,,,,, -5968,,,0.5469945669174194,1.9179505109786987,0.499019980430603,2.203223943710327,50000.0,0.3862000107765198,2.930450439453125,10000.0,2075.5087237358093,2163.373923540116,2075.5087237358093,87.56485724449158,0.0971796512603759,0.0 -6000,2.8497162,2.477576,,,,,,,,,,,,,, -6100,1.6373339,2.4649065,,,,,,,,,,,,,, -6200,2.6538713,2.443359,,,,,,,,,,,,,, -6300,2.1348417,2.417902,,,,,,,,,,,,,, -6400,1.9919044,2.5043402,,,,,,,,,,,,,, -6500,2.409383,2.3700123,,,,,,,,,,,,,, -6600,2.514054,2.3962862,,,,,,,,,,,,,, -6700,1.6079115,2.3549016,,,,,,,,,,,,,, -6800,2.1299565,2.243574,,,,,,,,,,,,,, -6900,2.3891666,2.330082,,,,,,,,,,,,,, -7000,2.1048326,2.2984123,,,,,,,,,,,,,, -7100,2.423318,2.3800962,,,,,,,,,,,,,, -7200,1.9879441,2.3554273,,,,,,,,,,,,,, -7300,1.5916417,2.375623,,,,,,,,,,,,,, -7400,1.9391376,2.2005465,,,,,,,,,,,,,, -7462,,,0.5733617544174194,1.780825972557068,0.5251399874687195,2.044733047485352,50000.0,0.4089000225067138,2.7833967208862305,10000.0,2585.5674998760223,2691.1594779491425,2585.5674998760223,105.21260619163512,0.1253807544708252,0.0 -7500,2.014609,2.2979703,,,,,,,,,,,,,, -7600,1.9998722,2.3215826,,,,,,,,,,,,,, -7700,2.3349578,2.363868,,,,,,,,,,,,,, -7800,1.5153995,2.127717,,,,,,,,,,,,,, -7900,2.0055819,2.2993762,,,,,,,,,,,,,, -8000,2.3673072,2.3758404,,,,,,,,,,,,,, -8100,2.0292492,2.2071128,,,,,,,,,,,,,, -8200,2.0202863,2.2740712,,,,,,,,,,,,,, -8300,1.8997283,2.1550064,,,,,,,,,,,,,, -8400,1.7045754,2.1076245,,,,,,,,,,,,,, -8500,2.0667481,2.2483814,,,,,,,,,,,,,, -8600,1.7936156,2.1981318,,,,,,,,,,,,,, -8700,1.5539461,2.0955348,,,,,,,,,,,,,, -8800,1.9295733,2.222718,,,,,,,,,,,,,, -8900,1.5584334,2.1339502,,,,,,,,,,,,,, -8957,,,0.6165497303009033,1.578463792800903,0.5652799606323242,1.845713257789612,50000.0,0.4390000104904175,2.5789635181427,10000.0,3095.8154296875,3218.784925699234,3095.8154296875,122.51022338867188,0.1549932956695556,0.0 -9000,1.3094823,2.1791613,,,,,,,,,,,,,, -9100,1.6918794,2.159874,,,,,,,,,,,,,, -9200,1.7227325,2.1181912,,,,,,,,,,,,,, -9300,1.5656078,2.035725,,,,,,,,,,,,,, -9400,1.4691292,2.0519092,,,,,,,,,,,,,, -9500,1.5120971,2.1388068,,,,,,,,,,,,,, -9600,1.8089814,2.110492,,,,,,,,,,,,,, -9700,1.5697281,2.0857108,,,,,,,,,,,,,, -9800,1.5953453,2.1204906,,,,,,,,,,,,,, -9900,1.7043027,2.088224,,,,,,,,,,,,,, -10000,1.7488713,2.130824,,,,,,,,,,,,,, -10100,2.0870829,2.1319957,,,,,,,,,,,,,, -10200,1.7803322,2.1123173,,,,,,,,,,,,,, -10300,1.5385677,1.9962003,,,,,,,,,,,,,, -10400,1.411245,2.0908537,,,,,,,,,,,,,, -10451,,,0.6073620915412903,1.6080734729766846,0.5593799948692322,1.851853370666504,50000.0,0.4404000341892242,2.6110658645629883,10000.0,3605.894075870514,3746.31577205658,3605.894075870514,139.8828718662262,0.1847729682922363,0.0 -10500,2.2842674,2.0062351,,,,,,,,,,,,,, -10600,1.9908627,1.9291768,,,,,,,,,,,,,, -10700,1.5185779,1.9642905,,,,,,,,,,,,,, -10800,1.6245978,2.0234032,,,,,,,,,,,,,, -10900,1.3329247,1.9070578,,,,,,,,,,,,,, -11000,2.0400567,1.9553193,,,,,,,,,,,,,, -11100,1.9412115,2.0896688,,,,,,,,,,,,,, -11200,1.2793827,2.0523322,,,,,,,,,,,,,, -11300,1.5873425,2.0169222,,,,,,,,,,,,,, -11400,1.9546797,1.998673,,,,,,,,,,,,,, -11500,2.1389253,1.9538052,,,,,,,,,,,,,, -11600,1.5791321,1.9989219,,,,,,,,,,,,,, -11700,1.9250065,1.8807449,,,,,,,,,,,,,, -11800,1.7818905,2.0026991,,,,,,,,,,,,,, -11900,1.3509833,1.9491699,,,,,,,,,,,,,, -11945,,,0.6295639276504517,1.503599762916565,0.5816599726676941,1.7506680488586426,50000.0,0.4595000147819519,2.491586923599243,10000.0,4115.927650213242,4273.891491889954,4115.927650213242,157.3446958065033,0.2137207984924316,0.0 -12000,1.8839813,1.9919826,,,,,,,,,,,,,, -12100,1.707797,2.0073304,,,,,,,,,,,,,, -12200,1.6003151,2.0235827,,,,,,,,,,,,,, -12300,1.7448565,1.9013462,,,,,,,,,,,,,, -12400,1.2768084,2.0078242,,,,,,,,,,,,,, -12500,1.5515206,1.9625916,,,,,,,,,,,,,, -12600,1.3033502,2.0569637,,,,,,,,,,,,,, -12700,1.4962436,1.9246577,,,,,,,,,,,,,, -12800,1.6594007,1.9744344,,,,,,,,,,,,,, -12900,1.6567867,1.9864451,,,,,,,,,,,,,, -13000,1.7060059,1.9433209,,,,,,,,,,,,,, -13100,1.8516208,1.9973528,,,,,,,,,,,,,, -13200,1.4309844,1.9635773,,,,,,,,,,,,,, -13300,1.2104268,1.939172,,,,,,,,,,,,,, -13400,1.455501,2.0821342,,,,,,,,,,,,,, -13440,,,0.6180644035339355,1.5705331563949585,0.5727199912071228,1.823119044303894,50000.0,0.447700023651123,2.5825207233428955,10000.0,4626.153445720673,4801.611012220383,4626.153445720673,174.75857639312744,0.24312424659729,0.0 -13500,1.6466328,2.0570922,,,,,,,,,,,,,, -13600,1.5022738,1.9354348,,,,,,,,,,,,,, -13700,1.5348752,1.8552312,,,,,,,,,,,,,, -13800,1.4889538,2.0077786,,,,,,,,,,,,,, -13900,1.4172554,1.8618481,,,,,,,,,,,,,, -14000,1.5246643,2.0789304,,,,,,,,,,,,,, -14100,1.7846674,1.8592832,,,,,,,,,,,,,, -14200,1.5961037,1.8481942,,,,,,,,,,,,,, -14300,1.5872512,1.9530072,,,,,,,,,,,,,, -14400,1.777334,1.9129837,,,,,,,,,,,,,, -14500,1.4943496,1.9352697,,,,,,,,,,,,,, -14600,1.7347311,1.864281,,,,,,,,,,,,,, -14700,1.3036293,1.9437499,,,,,,,,,,,,,, -14800,1.8867062,1.8659719,,,,,,,,,,,,,, -14900,1.5237095,1.7873589,,,,,,,,,,,,,, -14935,,,0.68363356590271,1.2555805444717407,0.6050199866294861,1.6677595376968384,50000.0,0.4702000319957733,2.433048725128174,10000.0,5136.090697288513,5329.14058303833,5136.090697288513,192.2714822292328,0.2714576721191406,0.0 -15000,1.934681,1.8046988,,,,,,,,,,,,,, -15100,1.3849913,2.0174766,,,,,,,,,,,,,, -15200,1.8550626,1.9385794,,,,,,,,,,,,,, -15300,1.5782319,1.8452079,,,,,,,,,,,,,, -15400,1.5094632,1.9467539,,,,,,,,,,,,,, -15500,1.8484823,1.8841832,,,,,,,,,,,,,, -15600,1.6349154,2.0386481,,,,,,,,,,,,,, -15700,1.4025974,2.0221906,,,,,,,,,,,,,, -15800,1.7916089,1.8300972,,,,,,,,,,,,,, -15900,1.47882,1.8991606,,,,,,,,,,,,,, -16000,1.686953,1.9684649,,,,,,,,,,,,,, -16100,1.887794,1.777545,,,,,,,,,,,,,, -16200,1.5212219,1.8938464,,,,,,,,,,,,,, -16300,1.4894612,1.731633,,,,,,,,,,,,,, -16400,1.4014593,1.813277,,,,,,,,,,,,,, -16430,,,0.6679487824440002,1.3212976455688477,0.6027599573135376,1.6629142761230469,50000.0,0.4785000085830688,2.401660203933716,10000.0,5646.080711364746,5856.62567782402,5646.080711364746,209.6851592063904,0.3002409934997558,0.0 -16500,1.6942255,1.8460808,,,,,,,,,,,,,, -16600,1.7872384,1.8246361,,,,,,,,,,,,,, -16700,1.6460818,1.9140782,,,,,,,,,,,,,, -16800,1.7945353,1.8261973,,,,,,,,,,,,,, -16900,1.6442279,1.8254542,,,,,,,,,,,,,, -17000,1.6808298,1.8681712,,,,,,,,,,,,,, -17100,1.5380032,1.834666,,,,,,,,,,,,,, -17200,1.4592639,1.8248069,,,,,,,,,,,,,, -17300,1.4303398,1.7545198,,,,,,,,,,,,,, -17400,1.6205529,1.8132131,,,,,,,,,,,,,, -17500,1.4014329,1.8573916,,,,,,,,,,,,,, -17600,1.6148765,1.8138748,,,,,,,,,,,,,, -17700,1.5353141,1.9394137,,,,,,,,,,,,,, -17800,1.4454399,1.9168704,,,,,,,,,,,,,, -17900,1.5582768,1.8795804,,,,,,,,,,,,,, -17926,,,0.6713767647743225,1.295975685119629,0.6087599992752075,1.628176212310791,50000.0,0.4845000207424164,2.360335111618042,10000.0,6156.199065208435,6384.354900598526,6156.199065208435,227.21400547027588,0.3311781883239746,0.0 -18000,1.5921687,1.95055,,,,,,,,,,,,,, -18100,1.5067027,1.8174,,,,,,,,,,,,,, -18200,1.4302386,1.7630143,,,,,,,,,,,,,, -18300,1.6214317,1.8221478,,,,,,,,,,,,,, -18400,1.5968087,1.9523423,,,,,,,,,,,,,, -18500,1.4648912,1.7601838,,,,,,,,,,,,,, -18600,1.6444836,1.8678387,,,,,,,,,,,,,, -18700,1.4051907,1.7690966,,,,,,,,,,,,,, -18800,1.8255748,1.9145986,,,,,,,,,,,,,, -18900,1.4762307,1.834902,,,,,,,,,,,,,, -19000,1.8261442,1.9652796,,,,,,,,,,,,,, -19100,1.7141924,1.8445252,,,,,,,,,,,,,, -19200,1.7826829,1.9622126,,,,,,,,,,,,,, -19300,1.715591,1.8349361,,,,,,,,,,,,,, -19400,1.6362789,1.9918737,,,,,,,,,,,,,, -19421,,,0.6640226244926453,1.3366535902023315,0.6086399555206299,1.63140869140625,50000.0,0.4795000255107879,2.3939144611358643,10000.0,6666.292007684708,6911.881479263306,6666.292007684708,244.56793308258057,0.361548900604248,0.0 -19500,1.6701549,1.8546628,,,,,,,,,,,,,, -19600,1.7475084,1.8656805,,,,,,,,,,,,,, -19700,1.8336388,1.8661168,,,,,,,,,,,,,, -19800,1.6285957,1.8629644,,,,,,,,,,,,,, -19900,1.5367668,1.9073147,,,,,,,,,,,,,, -20000,1.5496414,1.7855172,,,,,,,,,,,,,, -20100,1.3878946,1.7271829,,,,,,,,,,,,,, -20200,1.4407066,1.787288,,,,,,,,,,,,,, -20300,1.769482,1.9264331,,,,,,,,,,,,,, -20400,1.6821485,1.800587,,,,,,,,,,,,,, -20500,1.904317,1.7537805,,,,,,,,,,,,,, -20600,1.7275186,1.7555956,,,,,,,,,,,,,, -20700,1.452388,1.7490933,,,,,,,,,,,,,, -20800,1.8212699,1.7397132,,,,,,,,,,,,,, -20900,1.6513658,1.9732153,,,,,,,,,,,,,, -20917,,,0.6470623016357422,1.4147604703903198,0.5916999578475952,1.7275773286819458,50000.0,0.4636000096797943,2.501309394836426,10000.0,7176.418109893799,7440.27081990242,7176.418109893799,262.7482252120972,0.3934106826782226,0.0 -21000,1.6701567,1.835166,,,,,,,,,,,,,, -21100,1.7248698,1.9241916,,,,,,,,,,,,,, -21200,1.6136857,1.8606467,,,,,,,,,,,,,, -21300,1.6697955,1.8463355,,,,,,,,,,,,,, -21400,1.5123637,1.7829804,,,,,,,,,,,,,, -21500,1.5861206,1.8290975,,,,,,,,,,,,,, -21600,1.9241515,1.919646,,,,,,,,,,,,,, -21700,1.5039316,1.6924375,,,,,,,,,,,,,, -21800,1.9257376,1.8939263,,,,,,,,,,,,,, -21900,1.5209308,1.8805543,,,,,,,,,,,,,, -22000,1.5321722,1.818083,,,,,,,,,,,,,, -22100,1.6086978,1.7831621,,,,,,,,,,,,,, -22200,1.6622701,1.7145402,,,,,,,,,,,,,, -22300,1.7375679,1.7798634,,,,,,,,,,,,,, -22400,1.8604279,1.7631619,,,,,,,,,,,,,, -22412,,,0.6626673936843872,1.3498613834381104,0.6109600067138672,1.6298362016677856,50000.0,0.4834000170230865,2.372260332107544,10000.0,7686.367953538895,7967.849324464798,7686.367953538895,280.29457664489746,0.4244773387908935,0.0 -22500,1.7254167,1.805847,,,,,,,,,,,,,, -22600,1.5613605,1.7285544,,,,,,,,,,,,,, -22700,1.6488286,1.6911666,,,,,,,,,,,,,, -22800,1.5917542,1.7935715,,,,,,,,,,,,,, -22900,1.6537086,1.688049,,,,,,,,,,,,,, -23000,1.900372,1.8154441,,,,,,,,,,,,,, -23100,1.7909126,1.8160506,,,,,,,,,,,,,, -23200,1.6718332,1.7979696,,,,,,,,,,,,,, -23300,1.6472224,1.7753216,,,,,,,,,,,,,, -23400,1.631181,1.8462675,,,,,,,,,,,,,, -23500,1.9349577,1.9120716,,,,,,,,,,,,,, -23600,1.6089698,1.7915914,,,,,,,,,,,,,, -23700,1.7002685,1.7766418,,,,,,,,,,,,,, -23800,1.5689448,1.7525579,,,,,,,,,,,,,, -23900,1.9585863,1.7412994,,,,,,,,,,,,,, -23907,,,0.7183513641357422,1.091708064079285,0.6167199611663818,1.610385775566101,50000.0,0.4861000180244446,2.371723175048828,10000.0,8196.371249198914,8495.189532518387,8196.371249198914,297.5520570278168,0.4546489715576172,0.0 -24000,1.776529,1.8685293,,,,,,,,,,,,,, -24100,1.9791679,1.8855878,,,,,,,,,,,,,, -24200,1.544134,1.7566645,,,,,,,,,,,,,, -24300,1.6367319,1.754522,,,,,,,,,,,,,, -24400,1.8507022,1.8081024,,,,,,,,,,,,,, -24500,1.6737696,1.7451866,,,,,,,,,,,,,, -24600,1.9335291,1.7982643,,,,,,,,,,,,,, -24700,1.6480017,1.7273755,,,,,,,,,,,,,, -24800,1.5846201,1.7163392,,,,,,,,,,,,,, -24900,1.6143124,1.8779026,,,,,,,,,,,,,, -25000,1.7755822,1.7428708,,,,,,,,,,,,,, -25100,1.6876801,1.7589362,,,,,,,,,,,,,, -25200,1.7152753,1.6878353,,,,,,,,,,,,,, -25300,1.6510803,1.7358947,,,,,,,,,,,,,, -25400,1.9519517,1.8908974,,,,,,,,,,,,,, -25403,,,0.6804647445678711,1.2763067483901978,0.6107999682426453,1.6338390111923218,50000.0,0.4731000363826751,2.415712833404541,10000.0,8706.593649148941,9023.160031318665,8706.593649148941,315.2184538841248,0.4857900142669678,0.0 -25500,1.949543,1.7083657,,,,,,,,,,,,,, -25600,1.7530614,1.8765467,,,,,,,,,,,,,, -25700,1.6285976,1.8052249,,,,,,,,,,,,,, -25800,1.6488886,1.7408516,,,,,,,,,,,,,, -25900,1.6444852,1.9184266,,,,,,,,,,,,,, -26000,1.8202643,1.6872303,,,,,,,,,,,,,, -26100,1.588761,1.7431756,,,,,,,,,,,,,, -26200,1.7880336,1.823592,,,,,,,,,,,,,, -26300,1.7793535,1.7362608,,,,,,,,,,,,,, -26400,1.7970464,1.8025094,,,,,,,,,,,,,, -26500,1.6680857,1.7193536,,,,,,,,,,,,,, -26600,1.6990544,1.662519,,,,,,,,,,,,,, -26700,1.7659166,1.6385272,,,,,,,,,,,,,, -26800,1.8167695,1.8118482,,,,,,,,,,,,,, -26899,,,0.6908880472183228,1.216625094413757,0.6270999908447266,1.541603446006775,50000.0,0.4951000213623047,2.306924104690552,10000.0,9216.7435195446,9550.937099695206,9216.7435195446,332.7638320922852,0.5164616107940674,0.0 -26900,1.7248431,1.7425169,,,,,,,,,,,,,, -27000,1.8012965,1.9380287,,,,,,,,,,,,,, -27100,1.7101103,1.8092637,,,,,,,,,,,,,, -27200,1.6603796,1.8329256,,,,,,,,,,,,,, -27300,1.6481904,1.6826931,,,,,,,,,,,,,, -27400,1.6837251,1.6509542,,,,,,,,,,,,,, -27500,1.5263535,1.6669472,,,,,,,,,,,,,, -27600,1.7525455,1.7337,,,,,,,,,,,,,, -27700,1.6168418,1.6912683,,,,,,,,,,,,,, -27800,1.7781957,1.7938627,,,,,,,,,,,,,, -27900,1.9265969,1.8297122,,,,,,,,,,,,,, -28000,1.827156,1.6625618,,,,,,,,,,,,,, -28100,1.7730778,1.7345096,,,,,,,,,,,,,, -28200,1.5265232,1.6986506,,,,,,,,,,,,,, -28300,1.944748,1.6437463,,,,,,,,,,,,,, -28395,,,0.6801259517669678,1.2542836666107178,0.6233199834823608,1.5583561658859253,50000.0,0.4909000098705292,2.328392505645752,10000.0,9726.73595571518,10078.36783695221,9726.73595571518,350.1199746131897,0.5478644371032715,0.0 -28400,1.7565914,1.7519904,,,,,,,,,,,,,, -28500,1.9201647,1.6965611,,,,,,,,,,,,,, -28600,1.6874193,1.6970419,,,,,,,,,,,,,, -28700,1.6388656,1.7434245,,,,,,,,,,,,,, -28800,1.4769542,1.6331693,,,,,,,,,,,,,, -28900,1.6593909,1.8295177,,,,,,,,,,,,,, -29000,1.9394003,1.8048211,,,,,,,,,,,,,, -29100,1.7423996,1.5640968,,,,,,,,,,,,,, -29200,1.641504,1.6874162,,,,,,,,,,,,,, -29300,1.6542895,1.776829,,,,,,,,,,,,,, -29400,1.6924312,1.7307522,,,,,,,,,,,,,, -29500,1.6593896,1.8149635,,,,,,,,,,,,,, -29600,1.6039509,1.7226694,,,,,,,,,,,,,, -29700,1.7026334,1.6002215,,,,,,,,,,,,,, -29800,1.7153459,1.6690958,,,,,,,,,,,,,, -29891,,,0.6824178695678711,1.2555991411209106,0.625499963760376,1.545114517211914,50000.0,0.4980000257492065,2.307891368865967,10000.0,10236.778829574583,10605.955340862274,10236.778829574583,367.5783729553223,0.5819680690765381,0.0 -29900,1.6917095,1.775113,,,,,,,,,,,,,, -30000,1.7388879,1.7065144,,,,,,,,,,,,,, -30100,1.5963359,1.7150851,,,,,,,,,,,,,, -30200,1.7234577,1.7403655,,,,,,,,,,,,,, -30300,1.733052,1.6074864,,,,,,,,,,,,,, -30400,1.5227773,1.7268485,,,,,,,,,,,,,, -30500,1.8046123,1.7551347,,,,,,,,,,,,,, -30600,2.3069372,1.7977015,,,,,,,,,,,,,, -30700,1.6729863,1.7960223,,,,,,,,,,,,,, -30800,1.8928666,1.6880121,,,,,,,,,,,,,, -30900,1.7852982,1.6068254,,,,,,,,,,,,,, -31000,1.6662767,1.6765693,,,,,,,,,,,,,, -31100,1.8385952,1.6857665,,,,,,,,,,,,,, -31200,1.8403211,1.6688449,,,,,,,,,,,,,, -31300,1.8421972,1.7585076,,,,,,,,,,,,,, -31387,,,0.6882373690605164,1.2267718315124512,0.6324999928474426,1.5249176025390625,50000.0,0.5041000247001648,2.2895352840423584,10000.0,10747.021076440811,11133.741675138474,10747.021076440811,385.03712701797485,0.6160974502563477,0.0 -31400,1.724608,1.8039253,,,,,,,,,,,,,, -31500,1.9592651,1.8417217,,,,,,,,,,,,,, -31600,1.7059251,1.7733654,,,,,,,,,,,,,, -31700,1.7935544,1.707382,,,,,,,,,,,,,, -31800,1.7596651,1.7534695,,,,,,,,,,,,,, -31900,1.766655,1.7119099,,,,,,,,,,,,,, -32000,1.6725897,1.7493657,,,,,,,,,,,,,, -32100,1.9151664,1.708777,,,,,,,,,,,,,, -32200,1.7526188,1.7177569,,,,,,,,,,,,,, -32300,1.7457212,1.7234303,,,,,,,,,,,,,, -32400,1.8827441,1.7487177,,,,,,,,,,,,,, -32500,1.9059426,1.7170942,,,,,,,,,,,,,, -32600,1.6960034,1.6484314,,,,,,,,,,,,,, -32700,1.7978618,1.6936516,,,,,,,,,,,,,, -32800,1.7193983,1.6533245,,,,,,,,,,,,,, -32883,,,0.6759008169174194,1.282272458076477,0.6207599639892578,1.5945985317230225,50000.0,0.4873000085353851,2.34653377532959,10000.0,11257.115474939346,11661.608564853668,11257.115474939346,402.7224214076996,0.6510787010192871,0.0 -32900,1.6405036,1.7124109,,,,,,,,,,,,,, -33000,1.7486293,1.7390101,,,,,,,,,,,,,, -33100,1.6724181,1.5989343,,,,,,,,,,,,,, -33200,1.615312,1.6517471,,,,,,,,,,,,,, -33300,1.8131869,1.7337402,,,,,,,,,,,,,, -33400,1.7831017,1.6441069,,,,,,,,,,,,,, -33500,1.7694218,1.7930439,,,,,,,,,,,,,, -33600,1.7965262,1.7952168,,,,,,,,,,,,,, -33700,1.8489437,1.7342706,,,,,,,,,,,,,, -33800,1.73845,1.7496347,,,,,,,,,,,,,, -33900,2.0212667,1.8262054,,,,,,,,,,,,,, -34000,1.615562,1.6860135,,,,,,,,,,,,,, -34100,1.8367193,1.8046842,,,,,,,,,,,,,, -34200,2.073257,1.7019978,,,,,,,,,,,,,, -34300,1.8046184,1.7141855,,,,,,,,,,,,,, -34379,,,0.7140066623687744,1.1095143556594849,0.6321399807929993,1.5328267812728882,50000.0,0.4999000132083893,2.280648231506348,10000.0,11767.26364517212,12189.351138353348,11767.26364517212,420.22812843322754,0.6878213882446289,0.0 -34400,1.7382255,1.6332991,,,,,,,,,,,,,, -34500,1.8084395,1.6990764,,,,,,,,,,,,,, -34600,1.8389522,1.6577787,,,,,,,,,,,,,, -34700,1.7117573,1.7325335,,,,,,,,,,,,,, -34800,1.8962272,1.7950526,,,,,,,,,,,,,, -34900,1.8346807,1.7332311,,,,,,,,,,,,,, -35000,1.7575485,1.5958042,,,,,,,,,,,,,, -35100,1.7092398,1.6505847,,,,,,,,,,,,,, -35200,1.6610762,1.6759627,,,,,,,,,,,,,, -35300,1.6402967,1.6478955,,,,,,,,,,,,,, -35400,1.7799318,1.6632425,,,,,,,,,,,,,, -35500,1.7706115,1.6663139,,,,,,,,,,,,,, -35600,1.7470714,1.7466325,,,,,,,,,,,,,, -35700,1.8311226,1.7032132,,,,,,,,,,,,,, -35800,1.9455773,1.7428497,,,,,,,,,,,,,, -35875,,,0.6957509517669678,1.1831389665603638,0.6268399953842163,1.5437438488006592,50000.0,0.4993000328540802,2.2765913009643555,10000.0,12277.20189833641,12716.973129034042,12277.20189833641,437.8310143947601,0.7197163105010986,0.0 -35900,1.9614972,1.680314,,,,,,,,,,,,,, -36000,1.8648362,1.7097208,,,,,,,,,,,,,, -36100,1.7608777,1.7169893,,,,,,,,,,,,,, -36200,1.535034,1.5046966,,,,,,,,,,,,,, -36300,1.76186,1.5868413,,,,,,,,,,,,,, -36400,1.7107759,1.6656318,,,,,,,,,,,,,, -36500,1.867836,1.7816169,,,,,,,,,,,,,, -36600,1.8604481,1.8330059,,,,,,,,,,,,,, -36700,1.7815531,1.5702552,,,,,,,,,,,,,, -36800,1.7258053,1.4864535,,,,,,,,,,,,,, -36900,1.7853379,1.6213839,,,,,,,,,,,,,, -37000,1.7954752,1.6942011,,,,,,,,,,,,,, -37100,2.0316212,1.7063984,,,,,,,,,,,,,, -37200,1.6704689,1.5900942,,,,,,,,,,,,,, -37300,1.868192,1.7633278,,,,,,,,,,,,,, -37371,,,0.7004544138908386,1.1742907762527466,0.6349799633026123,1.5201865434646606,50000.0,0.5092000365257263,2.2315361499786377,10000.0,12787.222305297852,13244.35971903801,12787.222305297852,455.1110055446625,0.7531900405883789,0.0 -37400,2.021186,1.6216402,,,,,,,,,,,,,, -37500,1.8424792,1.7696906,,,,,,,,,,,,,, -37600,1.6327984,1.6400371,,,,,,,,,,,,,, -37700,1.8623521,1.7716534,,,,,,,,,,,,,, -37800,1.9096637,1.6309606,,,,,,,,,,,,,, -37900,1.7118562,1.7167332,,,,,,,,,,,,,, -38000,2.0268548,1.6633596,,,,,,,,,,,,,, -38100,2.015073,1.6557696,,,,,,,,,,,,,, -38200,1.7405989,1.6854088,,,,,,,,,,,,,, -38300,1.7794322,1.6044825,,,,,,,,,,,,,, -38400,1.6350322,1.5725853,,,,,,,,,,,,,, -38500,1.9136405,1.6367254,,,,,,,,,,,,,, -38600,1.829108,1.6474262,,,,,,,,,,,,,, -38700,1.7684653,1.6616943,,,,,,,,,,,,,, -38800,1.6448444,1.7122587,,,,,,,,,,,,,, -38868,,,0.6949737071990967,1.1847723722457886,0.6327599883079529,1.518254041671753,50000.0,0.5082000494003296,2.2616703510284424,10000.0,13297.305636644363,13771.90760755539,13297.305636644363,472.4876246452332,0.7889235019683838,0.0 -38900,1.9500208,1.6199961,,,,,,,,,,,,,, -39000,1.7943451,1.6716967,,,,,,,,,,,,,, -39100,1.7015233,1.578513,,,,,,,,,,,,,, -39200,1.9987854,1.6783512,,,,,,,,,,,,,, -39300,1.691924,1.7312353,,,,,,,,,,,,,, -39400,1.649079,1.7414969,,,,,,,,,,,,,, -39500,1.7477936,1.5720723,,,,,,,,,,,,,, -39600,1.9897288,1.6490221,,,,,,,,,,,,,, -39700,1.5645442,1.595945,,,,,,,,,,,,,, -39800,1.6455935,1.6608129,,,,,,,,,,,,,, -39900,1.8976452,1.6733519,,,,,,,,,,,,,, -40000,1.7541125,1.6242614,,,,,,,,,,,,,, -40100,1.867865,1.5875902,,,,,,,,,,,,,, -40200,2.0685346,1.6567938,,,,,,,,,,,,,, -40300,1.9283834,1.6774806,,,,,,,,,,,,,, -40364,,,0.6914859414100647,1.2236030101776123,0.6318399906158447,1.5228060483932495,50000.0,0.4972000122070312,2.3025379180908203,10000.0,13807.231940984726,14299.705271959305,13807.231940984726,490.2734615802765,0.8251686096191406,0.0 -40400,1.8252155,1.6208086,,,,,,,,,,,,,, -40500,1.770036,1.6897871,,,,,,,,,,,,,, -40600,2.0121498,1.8012323,,,,,,,,,,,,,, -40700,1.8303579,1.6524699,,,,,,,,,,,,,, -40800,1.8690943,1.5520867,,,,,,,,,,,,,, -40900,1.9969528,1.6471486,,,,,,,,,,,,,, -41000,1.843815,1.5162717,,,,,,,,,,,,,, -41100,1.6820335,1.6201892,,,,,,,,,,,,,, -41200,1.873693,1.6534442,,,,,,,,,,,,,, -41300,1.640354,1.7001586,,,,,,,,,,,,,, -41400,1.8796546,1.6139758,,,,,,,,,,,,,, -41500,1.8260474,1.8103428,,,,,,,,,,,,,, -41600,1.9969143,1.7892399,,,,,,,,,,,,,, -41700,1.8499383,1.7588379,,,,,,,,,,,,,, -41800,1.6994085,1.6458429,,,,,,,,,,,,,, -41861,,,0.6909677982330322,1.2019314765930176,0.6317600011825562,1.5205481052398682,50000.0,0.5012000203132629,2.2280113697052,10000.0,14317.350157022476,14827.315202951431,14317.350157022476,507.6738419532776,0.8653068542480469,0.0 -41900,1.8579702,1.7493442,,,,,,,,,,,,,, -42000,1.784612,1.6046244,,,,,,,,,,,,,, -42100,1.8803631,1.5492308,,,,,,,,,,,,,, -42200,2.05457,1.7129291,,,,,,,,,,,,,, -42300,1.983723,1.6238283,,,,,,,,,,,,,, -42400,1.8114549,1.7059686,,,,,,,,,,,,,, -42500,1.9268807,1.6566412,,,,,,,,,,,,,, -42600,1.8663789,1.5797497,,,,,,,,,,,,,, -42700,1.8610599,1.601565,,,,,,,,,,,,,, -42800,1.8292598,1.6406312,,,,,,,,,,,,,, -42900,1.8935528,1.74963,,,,,,,,,,,,,, -43000,1.9615928,1.6417971,,,,,,,,,,,,,, -43100,1.836994,1.7810547,,,,,,,,,,,,,, -43200,1.766613,1.5797179,,,,,,,,,,,,,, -43300,1.9528196,1.6125518,,,,,,,,,,,,,, -43358,,,0.726980984210968,1.0523135662078855,0.6317399740219116,1.5127232074737549,50000.0,0.5042999982833862,2.259047508239746,10000.0,14827.423000097277,15354.948499202728,14827.423000097277,525.1450872421265,0.903350830078125,0.0 -43400,1.9705876,1.6888243,,,,,,,,,,,,,, -43500,1.7219089,1.5430889,,,,,,,,,,,,,, -43600,1.9206496,1.7425554,,,,,,,,,,,,,, -43700,1.6041094,1.549221,,,,,,,,,,,,,, -43800,1.7383828,1.6839335,,,,,,,,,,,,,, -43900,1.7852893,1.5678185,,,,,,,,,,,,,, -44000,1.8271985,1.7007344,,,,,,,,,,,,,, -44100,1.9994358,1.6414216,,,,,,,,,,,,,, -44200,1.9916372,1.7276925,,,,,,,,,,,,,, -44300,1.7819207,1.4735193,,,,,,,,,,,,,, -44400,1.7151998,1.5999867,,,,,,,,,,,,,, -44500,2.0113897,1.6856669,,,,,,,,,,,,,, -44600,1.726717,1.5949321,,,,,,,,,,,,,, -44700,1.8428878,1.6582139,,,,,,,,,,,,,, -44800,1.8895606,1.6185045,,,,,,,,,,,,,, -44855,,,0.7177534699440002,1.0975770950317385,0.6431800127029419,1.4693011045455933,50000.0,0.5175000429153442,2.1700892448425293,10000.0,15337.656279802322,15882.606467962263,15337.656279802322,542.4807982444763,0.9419848918914796,0.0 -44900,1.831707,1.5903049,,,,,,,,,,,,,, -45000,1.8945991,1.5849957,,,,,,,,,,,,,, -45100,2.0166578,1.6196051,,,,,,,,,,,,,, -45200,1.7800277,1.5854626,,,,,,,,,,,,,, -45300,1.8373897,1.7206157,,,,,,,,,,,,,, -45400,1.9008969,1.576634,,,,,,,,,,,,,, -45500,1.7430786,1.7101121,,,,,,,,,,,,,, -45600,1.7598535,1.673812,,,,,,,,,,,,,, -45700,1.9327685,1.6433314,,,,,,,,,,,,,, -45800,2.066965,1.7105503,,,,,,,,,,,,,, -45900,1.7632524,1.7059059,,,,,,,,,,,,,, -46000,1.8021845,1.601941,,,,,,,,,,,,,, -46100,1.8953874,1.6165044,,,,,,,,,,,,,, -46200,1.9078962,1.6730351,,,,,,,,,,,,,, -46300,1.8260216,1.5810457,,,,,,,,,,,,,, -46351,,,0.7119937539100647,1.1093826293945312,0.6420199871063232,1.4766788482666016,50000.0,0.5148000121116638,2.235180854797364,10000.0,15847.795320510864,16409.9840195179,15847.795320510864,559.6336979866028,0.9760580062866212,0.0 -46400,1.7798779,1.5489768,,,,,,,,,,,,,, -46500,1.7358521,1.5510023,,,,,,,,,,,,,, -46600,1.9191747,1.6152794,,,,,,,,,,,,,, -46700,1.831165,1.6692106,,,,,,,,,,,,,, -46800,1.9044611,1.5728961,,,,,,,,,,,,,, -46900,1.7569007,1.6029308,,,,,,,,,,,,,, -47000,2.4244184,1.6755149,,,,,,,,,,,,,, -47100,1.7038963,1.4876614,,,,,,,,,,,,,, -47200,1.8898436,1.6156789,,,,,,,,,,,,,, -47300,1.8426852,1.6056602,,,,,,,,,,,,,, -47400,1.9392976,1.6148334,,,,,,,,,,,,,, -47500,1.8340253,1.648462,,,,,,,,,,,,,, -47600,1.8878349,1.6834245,,,,,,,,,,,,,, -47700,1.7519393,1.6522177,,,,,,,,,,,,,, -47800,2.1204083,1.6550602,,,,,,,,,,,,,, -47848,,,0.7031847834587097,1.1564139127731323,0.6360999941825867,1.489286184310913,50000.0,0.5045000314712524,2.241513967514038,10000.0,16358.012321472168,16937.62736606598,16358.012321472168,576.9686605930328,1.016944408416748,0.0 -47900,1.7496215,1.6007953,,,,,,,,,,,,,, -48000,1.8531458,1.5994484,,,,,,,,,,,,,, -48100,1.7845061,1.5488536,,,,,,,,,,,,,, -48200,1.6227171,1.5656109,,,,,,,,,,,,,, -48300,1.8529321,1.6717334,,,,,,,,,,,,,, -48400,1.7111801,1.5799378,,,,,,,,,,,,,, -48500,2.1588676,1.657549,,,,,,,,,,,,,, -48600,1.8668201,1.6783062,,,,,,,,,,,,,, -48700,1.6488485,1.5299239,,,,,,,,,,,,,, -48800,1.9237247,1.6605262,,,,,,,,,,,,,, -48900,1.8360022,1.6020617,,,,,,,,,,,,,, -49000,1.929222,1.5976663,,,,,,,,,,,,,, -49100,2.0165844,1.6516154,,,,,,,,,,,,,, -49200,1.7237254,1.6209369,,,,,,,,,,,,,, -49300,1.8336343,1.5026633,,,,,,,,,,,,,, -49344,,,0.7031847834587097,1.1441307067871094,0.6427199840545654,1.467440843582153,50000.0,0.5211000442504883,2.169034481048584,10000.0,16868.260396003723,17465.60064959526,16868.260396003723,594.6073455810547,1.0527050495147705,0.0 -49400,2.0534387,1.6869408,,,,,,,,,,,,,, -49500,1.7179888,1.5903306,,,,,,,,,,,,,, -49600,1.6994683,1.6296589,,,,,,,,,,,,,, -49700,1.8574718,1.7135041,,,,,,,,,,,,,, -49800,1.6982256,1.4868641,,,,,,,,,,,,,, -49900,1.9606004,1.7031393,,,,,,,,,,,,,, -50000,2.0016758,1.5067705,,,,,,,,,,,,,, -50100,1.9501971,1.6957054,,,,,,,,,,,,,, -50200,1.7823541,1.5942404,,,,,,,,,,,,,, -50300,2.1247766,1.6073191,,,,,,,,,,,,,, -50400,1.7633018,1.3942701,,,,,,,,,,,,,, -50500,1.8213155,1.6810832,,,,,,,,,,,,,, -50600,1.8789966,1.5795256,,,,,,,,,,,,,, -50700,1.6949166,1.5740227,,,,,,,,,,,,,, -50800,1.7807875,1.585545,,,,,,,,,,,,,, -50841,,,0.7056361436843872,1.1398847103118896,0.645039975643158,1.4592891931533811,50000.0,0.5149000287055969,2.193108081817627,10000.0,17378.219309806824,17993.183161497116,17378.219309806824,612.1412818431854,1.091554880142212,0.0 -50900,1.8180813,1.664727,,,,,,,,,,,,,, -51000,1.8003443,1.5960882,,,,,,,,,,,,,, -51100,1.7121415,1.5190068,,,,,,,,,,,,,, -51200,1.8766834,1.6119933,,,,,,,,,,,,,, -51300,1.818144,1.5955873,,,,,,,,,,,,,, -51400,1.7727263,1.5409672,,,,,,,,,,,,,, -51500,1.9545773,1.6841553,,,,,,,,,,,,,, -51600,1.7837077,1.6226442,,,,,,,,,,,,,, -51700,1.8077595,1.6697617,,,,,,,,,,,,,, -51800,1.8316269,1.6377114,,,,,,,,,,,,,, -51900,2.347586,1.5862929,,,,,,,,,,,,,, -52000,1.9315126,1.4974056,,,,,,,,,,,,,, -52100,2.164857,1.6230104,,,,,,,,,,,,,, -52200,1.7932782,1.6752667,,,,,,,,,,,,,, -52300,1.9375644,1.5797296,,,,,,,,,,,,,, -52338,,,0.7645886540412903,0.90327787399292,0.6543799638748169,1.4254825115203855,50000.0,0.520300030708313,2.1776282787323,10000.0,17888.419243574142,18520.939923524857,17888.419243574142,629.606897354126,1.1325068473815918,0.0 -52400,1.9471637,1.6075665,,,,,,,,,,,,,, -52500,1.7643983,1.5134913,,,,,,,,,,,,,, -52600,1.9598134,1.7428975,,,,,,,,,,,,,, -52700,1.8165591,1.4705727,,,,,,,,,,,,,, -52800,1.9433883,1.6154839,,,,,,,,,,,,,, -52900,1.6931503,1.5613153,,,,,,,,,,,,,, -53000,1.7628999,1.5396005,,,,,,,,,,,,,, -53100,1.738071,1.6554168,,,,,,,,,,,,,, -53200,1.7520624,1.5474243,,,,,,,,,,,,,, -53300,1.8186275,1.7039862,,,,,,,,,,,,,, -53400,2.020135,1.6679233,,,,,,,,,,,,,, -53500,1.8962109,1.5186436,,,,,,,,,,,,,, -53600,1.9030625,1.6228046,,,,,,,,,,,,,, -53700,1.8721337,1.5845433,,,,,,,,,,,,,, -53800,1.894055,1.5547112,,,,,,,,,,,,,, -53835,,,0.7227758169174194,1.0640006065368652,0.6343599557876587,1.4988964796066284,50000.0,0.5029000043869019,2.2745490074157715,10000.0,18398.483984470367,19048.349376678467,18398.483984470367,646.8635385036469,1.1695225238800049,0.0 -53900,1.9756762,1.6874709,,,,,,,,,,,,,, -54000,1.7330006,1.5809819,,,,,,,,,,,,,, -54100,2.141283,1.6836226,,,,,,,,,,,,,, -54200,1.8577489,1.6717105,,,,,,,,,,,,,, -54300,1.848365,1.5853474,,,,,,,,,,,,,, -54400,1.8707243,1.5800188,,,,,,,,,,,,,, -54500,1.9728863,1.6730655,,,,,,,,,,,,,, -54600,1.9549078,1.6191757,,,,,,,,,,,,,, -54700,2.0106683,1.6433289,,,,,,,,,,,,,, -54800,2.20533,1.6126201,,,,,,,,,,,,,, -54900,1.8821948,1.5678184,,,,,,,,,,,,,, -55000,1.8165338,1.6232857,,,,,,,,,,,,,, -55100,1.8359424,1.5942253,,,,,,,,,,,,,, -55200,1.9009081,1.6134989,,,,,,,,,,,,,, -55300,2.127793,1.617602,,,,,,,,,,,,,, -55332,,,0.7080675959587097,1.1316912174224854,0.6315799951553345,1.5241401195526123,50000.0,0.5034000277519226,2.26247239112854,10000.0,18908.700251817703,19576.242521762848,18908.700251817703,664.4518418312073,1.2079482078552246,0.0 -55400,1.8534099,1.5458623,,,,,,,,,,,,,, -55500,1.996996,1.6863704,,,,,,,,,,,,,, -55600,1.9371346,1.6049194,,,,,,,,,,,,,, -55700,2.0142798,1.6922895,,,,,,,,,,,,,, -55800,1.9128878,1.595694,,,,,,,,,,,,,, -55900,1.9008943,1.7130204,,,,,,,,,,,,,, -56000,1.9359633,1.6897167,,,,,,,,,,,,,, -56100,1.9045763,1.5329357,,,,,,,,,,,,,, -56200,2.0700083,1.6093619,,,,,,,,,,,,,, -56300,1.8823013,1.5925158,,,,,,,,,,,,,, -56400,1.7818567,1.5698234,,,,,,,,,,,,,, -56500,1.802783,1.530549,,,,,,,,,,,,,, -56600,1.8163902,1.5093656,,,,,,,,,,,,,, -56700,1.9781561,1.7333518,,,,,,,,,,,,,, -56800,1.9385519,1.559585,,,,,,,,,,,,,, -56829,,,0.7178930044174194,1.0814168453216553,0.6469199657440186,1.4597644805908203,50000.0,0.5152000188827515,2.216600179672241,10000.0,19418.939255476,20104.18901371956,19418.939255476,682.0746810436249,1.2414581775665283,0.0 -56900,1.858342,1.5670066,,,,,,,,,,,,,, -57000,1.8390266,1.6045341,,,,,,,,,,,,,, -57100,1.7437935,1.4023144,,,,,,,,,,,,,, -57200,1.7491188,1.5420645,,,,,,,,,,,,,, -57300,1.8481092,1.6293842,,,,,,,,,,,,,, -57400,1.9118849,1.4651382,,,,,,,,,,,,,, -57500,1.8278803,1.5264413,,,,,,,,,,,,,, -57600,1.9312351,1.597547,,,,,,,,,,,,,, -57700,1.8814168,1.5110373,,,,,,,,,,,,,, -57800,1.8334374,1.6024898,,,,,,,,,,,,,, -57900,1.9115078,1.5767797,,,,,,,,,,,,,, -58000,1.74292,1.543633,,,,,,,,,,,,,, -58100,1.8172245,1.6546987,,,,,,,,,,,,,, -58200,2.1693208,1.549252,,,,,,,,,,,,,, -58300,1.7956423,1.4970908,,,,,,,,,,,,,, -58326,,,0.7149434089660645,1.1157106161117554,0.6482999920845032,1.441026210784912,50000.0,0.5131000280380249,2.1649599075317383,10000.0,19929.052931785583,20632.583993673325,19929.052931785583,700.2618687152863,1.2832703590393066,0.0 -58400,1.8770499,1.6532236,,,,,,,,,,,,,, -58500,1.8461851,1.485802,,,,,,,,,,,,,, -58600,1.881754,1.6130894,,,,,,,,,,,,,, -58700,2.099824,1.6673803,,,,,,,,,,,,,, -58800,1.9010837,1.5425143,,,,,,,,,,,,,, -58900,1.797441,1.4466563,,,,,,,,,,,,,, -59000,2.175681,1.5543909,,,,,,,,,,,,,, -59100,1.8922107,1.4263883,,,,,,,,,,,,,, -59200,1.9844984,1.6545742,,,,,,,,,,,,,, -59300,1.9062775,1.510883,,,,,,,,,,,,,, -59400,1.9497112,1.4759976,,,,,,,,,,,,,, -59500,1.9135342,1.6267112,,,,,,,,,,,,,, -59600,2.020078,1.5496112,,,,,,,,,,,,,, -59700,2.1884027,1.5191699,,,,,,,,,,,,,, -59800,1.9348706,1.512139,,,,,,,,,,,,,, -59823,,,0.7216398119926453,1.0750752687454224,0.6587199568748474,1.4014229774475098,50000.0,0.5279000401496887,2.151419162750244,10000.0,20439.141530275345,21160.23457360268,20439.141530275345,717.730926990509,1.3240761756896973,0.0 -59900,1.8077184,1.5434349,,,,,,,,,,,,,, -60000,1.7671155,1.5099874,,,,,,,,,,,,,, -60100,2.0826552,1.6096723,,,,,,,,,,,,,, -60200,1.9386338,1.4717004,,,,,,,,,,,,,, -60300,2.1671903,1.6619288,,,,,,,,,,,,,, -60400,1.8539491,1.4240535,,,,,,,,,,,,,, -60500,1.94661,1.5894558,,,,,,,,,,,,,, -60600,1.7900221,1.5179434,,,,,,,,,,,,,, -60700,1.7651333,1.4872998,,,,,,,,,,,,,, -60800,2.3061445,1.5605633,,,,,,,,,,,,,, -60900,1.8753401,1.552629,,,,,,,,,,,,,, -61000,1.9673452,1.6019772,,,,,,,,,,,,,, -61100,2.003182,1.5962555,,,,,,,,,,,,,, -61200,2.0127232,1.7061694,,,,,,,,,,,,,, -61300,1.718768,1.5630052,,,,,,,,,,,,,, -61320,,,0.7239716053009033,1.0552133321762085,0.6634399890899658,1.3708513975143433,50000.0,0.5362000465393066,2.0705716609954834,10000.0,20949.247843265533,21687.68998861313,20949.247843265533,734.9866235256195,1.3634462356567385,0.0 -61400,1.8486909,1.5721047,,,,,,,,,,,,,, -61500,1.9946404,1.542972,,,,,,,,,,,,,, -61600,1.9783545,1.4757082,,,,,,,,,,,,,, -61700,2.052024,1.5743642,,,,,,,,,,,,,, -61800,1.789942,1.4455403,,,,,,,,,,,,,, -61900,1.9379303,1.6188786,,,,,,,,,,,,,, -62000,2.0862727,1.4717921,,,,,,,,,,,,,, -62100,2.2562082,1.5586264,,,,,,,,,,,,,, -62200,2.1158688,1.5199916,,,,,,,,,,,,,, -62300,2.042376,1.5174992,,,,,,,,,,,,,, -62400,1.9656767,1.6132134,,,,,,,,,,,,,, -62500,1.9566498,1.63573,,,,,,,,,,,,,, -62600,1.9745044,1.4760149,,,,,,,,,,,,,, -62700,2.058497,1.6999412,,,,,,,,,,,,,, -62800,2.1243114,1.5392642,,,,,,,,,,,,,, -62817,,,0.7576530575752258,0.9276350736618042,0.6600199937820435,1.3925979137420654,50000.0,0.5330000519752502,2.138416290283203,10000.0,21459.278631210327,22215.15426421165,21459.278631210327,752.3298766613007,1.4020438194274902,0.0 -62900,2.0937867,1.5175611,,,,,,,,,,,,,, -63000,1.8704734,1.4734801,,,,,,,,,,,,,, -63100,2.1656456,1.5666022,,,,,,,,,,,,,, -63200,2.00755,1.5263674,,,,,,,,,,,,,, -63300,1.9338478,1.7038095,,,,,,,,,,,,,, -63400,2.3161855,1.6101067,,,,,,,,,,,,,, -63500,1.8923696,1.6015657,,,,,,,,,,,,,, -63600,1.8848777,1.5668452,,,,,,,,,,,,,, -63700,2.2298388,1.557657,,,,,,,,,,,,,, -63800,2.0752122,1.534636,,,,,,,,,,,,,, -63900,2.0569327,1.4925396,,,,,,,,,,,,,, -64000,1.9487246,1.5222428,,,,,,,,,,,,,, -64100,1.9507267,1.5023968,,,,,,,,,,,,,, -64200,1.9746872,1.5742288,,,,,,,,,,,,,, -64300,2.0768473,1.5540426,,,,,,,,,,,,,, -64314,,,0.7278180718421936,1.0450934171676636,0.6518799662590027,1.438127517700195,50000.0,0.5248000025749207,2.172521114349365,10000.0,21969.453302383423,22742.71888899803,21969.453302383423,769.6031460762024,1.4653689861297607,0.0 -64400,1.9060152,1.5280538,,,,,,,,,,,,,, -64500,1.8572079,1.5122831,,,,,,,,,,,,,, -64600,2.1166508,1.4350516,,,,,,,,,,,,,, -64700,1.9996219,1.5192218,,,,,,,,,,,,,, -64800,2.0652661,1.5691091,,,,,,,,,,,,,, -64900,1.9644313,1.4821529,,,,,,,,,,,,,, -65000,1.9114568,1.492652,,,,,,,,,,,,,, -65100,1.9191965,1.4541674,,,,,,,,,,,,,, -65200,1.8824955,1.442147,,,,,,,,,,,,,, -65300,1.9650099,1.4904953,,,,,,,,,,,,,, -65400,1.9011333,1.5347852,,,,,,,,,,,,,, -65500,1.8673018,1.4690273,,,,,,,,,,,,,, -65600,1.9818351,1.5359215,,,,,,,,,,,,,, -65700,2.1614792,1.508091,,,,,,,,,,,,,, -65800,2.0582376,1.4120165,,,,,,,,,,,,,, -65811,,,0.7356903553009033,1.0099648237228394,0.6637600064277649,1.3708646297454834,50000.0,0.5311000347137451,2.125349998474121,10000.0,22479.49555540085,23270.3928463459,22479.49555540085,787.1383633613586,1.5099318027496338,0.0 -65900,2.083596,1.7118818,,,,,,,,,,,,,, -66000,2.20579,1.6068076,,,,,,,,,,,,,, -66100,2.037989,1.5600703,,,,,,,,,,,,,, -66200,1.9148865,1.5740583,,,,,,,,,,,,,, -66300,1.8242918,1.5305059,,,,,,,,,,,,,, -66400,1.9351567,1.5613,,,,,,,,,,,,,, -66500,1.9411124,1.5090938,,,,,,,,,,,,,, -66600,2.0204458,1.4618134,,,,,,,,,,,,,, -66700,2.0614944,1.5196654,,,,,,,,,,,,,, -66800,1.9051088,1.4030501,,,,,,,,,,,,,, -66900,1.8628613,1.5351212,,,,,,,,,,,,,, -67000,1.9119763,1.566742,,,,,,,,,,,,,, -67100,1.9662852,1.6565487,,,,,,,,,,,,,, -67200,1.8961087,1.469376,,,,,,,,,,,,,, -67300,1.9812275,1.4736558,,,,,,,,,,,,,, -67308,,,0.7371651530265808,0.9949511885643004,0.664359986782074,1.357032299041748,50000.0,0.5324000120162964,2.098825216293335,10000.0,22989.733140707016,23798.065600156784,22989.733140707016,804.4793326854706,1.55267596244812,0.0 -67400,2.0943537,1.5981079,,,,,,,,,,,,,, -67500,1.8527863,1.4622966,,,,,,,,,,,,,, -67600,1.9981214,1.5647601,,,,,,,,,,,,,, -67700,1.9497746,1.5323002,,,,,,,,,,,,,, -67800,2.127968,1.5169088,,,,,,,,,,,,,, -67900,2.0125031,1.4173962,,,,,,,,,,,,,, -68000,2.0448878,1.5179499,,,,,,,,,,,,,, -68100,1.9342074,1.6229281,,,,,,,,,,,,,, -68200,2.3426197,1.5538439,,,,,,,,,,,,,, -68300,2.3103843,1.5633808,,,,,,,,,,,,,, -68400,2.266706,1.5242496,,,,,,,,,,,,,, -68500,1.9599142,1.4930159,,,,,,,,,,,,,, -68600,1.9686154,1.440303,,,,,,,,,,,,,, -68700,2.1267452,1.430686,,,,,,,,,,,,,, -68800,2.4097905,1.5560749,,,,,,,,,,,,,, -68806,,,0.7223772406578064,1.0757921934127808,0.6567599773406982,1.39765727519989,50000.0,0.5281000137329102,2.129694700241089,10000.0,23499.958671808243,24325.86668920517,23499.958671808243,821.963175535202,1.5933401584625244,0.0 -68900,2.1775,1.4760184,,,,,,,,,,,,,, -69000,2.1281245,1.5512962,,,,,,,,,,,,,, -69100,2.2225072,1.4730941,,,,,,,,,,,,,, -69200,2.2206976,1.5639695,,,,,,,,,,,,,, -69300,1.9360842,1.5032033,,,,,,,,,,,,,, -69400,2.1812863,1.5077959,,,,,,,,,,,,,, -69500,2.2202632,1.6040572,,,,,,,,,,,,,, -69600,2.1650314,1.4977431,,,,,,,,,,,,,, -69700,1.9446971,1.4653854,,,,,,,,,,,,,, -69800,2.0085967,1.3780341,,,,,,,,,,,,,, -69900,1.9933816,1.5553694,,,,,,,,,,,,,, -70000,1.8564253,1.5347955,,,,,,,,,,,,,, -70100,2.0183516,1.5861691,,,,,,,,,,,,,, -70200,2.14106,1.5139098,,,,,,,,,,,,,, -70300,2.0443838,1.4770433,,,,,,,,,,,,,, -70302,,,0.7234334945678711,1.0573880672454834,0.656719982624054,1.40852689743042,50000.0,0.5242000222206116,2.144479513168335,10000.0,24009.941289901733,24853.39986562729,24009.941289901733,839.4227304458618,1.6327154636383057,0.0 -70400,1.9851989,1.4963721,,,,,,,,,,,,,, -70500,1.9207042,1.4493161,,,,,,,,,,,,,, -70600,1.7616003,1.4681274,,,,,,,,,,,,,, -70700,2.1890178,1.6233743,,,,,,,,,,,,,, -70800,1.872715,1.5518963,,,,,,,,,,,,,, -70900,2.0714693,1.4605664,,,,,,,,,,,,,, -71000,2.07874,1.5630685,,,,,,,,,,,,,, -71100,1.9909302,1.4119897,,,,,,,,,,,,,, -71200,2.1038373,1.501331,,,,,,,,,,,,,, -71300,2.0279596,1.5322505,,,,,,,,,,,,,, -71400,1.896142,1.4570274,,,,,,,,,,,,,, -71500,2.1109242,1.6427436,,,,,,,,,,,,,, -71600,2.2367134,1.424115,,,,,,,,,,,,,, -71700,2.255976,1.6090176,,,,,,,,,,,,,, -71800,,,0.7745535373687744,0.8367967009544373,0.6677199602127075,1.3619836568832395,50000.0,0.5440000295639038,2.086559772491455,10000.0,24520.09197568893,25380.915219783783,24520.09197568893,856.6956856250763,1.6729493141174316,0.0 -71800,1.9429387,1.4423077,,,,,,,,,,,,,, -71900,2.0850391,1.3929739,,,,,,,,,,,,,, -72000,2.1199515,1.4877913,,,,,,,,,,,,,, -72100,2.2237487,1.5081668,,,,,,,,,,,,,, -72200,2.0029955,1.4348143,,,,,,,,,,,,,, -72300,2.0506945,1.5383703,,,,,,,,,,,,,, -72400,2.1445558,1.6317693,,,,,,,,,,,,,, -72500,2.0911448,1.5360212,,,,,,,,,,,,,, -72600,1.9613351,1.4939452,,,,,,,,,,,,,, -72700,2.196158,1.5347003,,,,,,,,,,,,,, -72800,1.9827067,1.4346489,,,,,,,,,,,,,, -72900,2.0435035,1.4694754,,,,,,,,,,,,,, -73000,2.002283,1.5022179,,,,,,,,,,,,,, -73100,1.9529574,1.6446729,,,,,,,,,,,,,, -73200,2.6695082,1.6174984,,,,,,,,,,,,,, -73297,,,0.7525908350944519,0.937186062335968,0.6672799587249756,1.3505654335021973,50000.0,0.5418000221252441,2.0513622760772705,10000.0,25030.211496591568,25908.675062417984,25030.211496591568,874.2368021011353,1.720144510269165,0.0 -73300,2.1869612,1.5894618,,,,,,,,,,,,,, -73400,1.9176718,1.4109824,,,,,,,,,,,,,, -73500,2.4913254,1.6020411,,,,,,,,,,,,,, -73600,2.1905398,1.401332,,,,,,,,,,,,,, -73700,2.0761456,1.4998137,,,,,,,,,,,,,, -73800,1.9038228,1.43986,,,,,,,,,,,,,, -73900,1.9660413,1.5485134,,,,,,,,,,,,,, -74000,2.1131222,1.418233,,,,,,,,,,,,,, -74100,1.8211176,1.3973664,,,,,,,,,,,,,, -74200,2.063294,1.3470829,,,,,,,,,,,,,, -74300,2.1184268,1.4868623,,,,,,,,,,,,,, -74400,2.2849133,1.4395291,,,,,,,,,,,,,, -74500,2.0630484,1.4195921,,,,,,,,,,,,,, -74600,2.0456362,1.5452847,,,,,,,,,,,,,, -74700,2.1053572,1.4468666,,,,,,,,,,,,,, -74793,,,0.7455755472183228,0.9578787088394164,0.6665999889373779,1.366343379020691,50000.0,0.5467000007629395,2.070542812347412,10000.0,25540.13827967644,26436.25676727295,25540.13827967644,891.7931699752808,1.7655150890350342,0.0 -74800,2.183634,1.60568,,,,,,,,,,,,,, -74900,1.9874948,1.4057359,,,,,,,,,,,,,, -75000,2.0130112,1.4166143,,,,,,,,,,,,,, -75100,2.2182293,1.492024,,,,,,,,,,,,,, -75200,2.1585317,1.4942472,,,,,,,,,,,,,, -75300,2.126754,1.5246165,,,,,,,,,,,,,, -75400,1.9875342,1.4813002,,,,,,,,,,,,,, -75500,2.2332742,1.498865,,,,,,,,,,,,,, -75600,2.0650663,1.6450644,,,,,,,,,,,,,, -75700,2.0145621,1.502109,,,,,,,,,,,,,, -75800,1.9675874,1.4886878,,,,,,,,,,,,,, -75900,2.3254406,1.4773467,,,,,,,,,,,,,, -76000,2.1087074,1.5399566,,,,,,,,,,,,,, -76100,1.9559106,1.4500287,,,,,,,,,,,,,, -76200,1.997635,1.432404,,,,,,,,,,,,,, -76290,,,0.73441481590271,1.0014700889587402,0.6627599596977234,1.396458864212036,50000.0,0.5372000336647034,2.145915985107422,10000.0,26050.24849486351,26963.783942461014,26050.24849486351,909.11328291893,1.809826374053955,0.0 -76300,2.2309968,1.4273671,,,,,,,,,,,,,, -76400,2.0022159,1.5312463,,,,,,,,,,,,,, -76500,2.0413978,1.4230206,,,,,,,,,,,,,, -76600,2.1281118,1.4981538,,,,,,,,,,,,,, -76700,2.1330545,1.5299382,,,,,,,,,,,,,, -76800,2.0707161,1.3954196,,,,,,,,,,,,,, -76900,2.0279756,1.4946802,,,,,,,,,,,,,, -77000,2.1554513,1.4636562,,,,,,,,,,,,,, -77100,2.0683033,1.5130603,,,,,,,,,,,,,, -77200,2.0064056,1.5597793,,,,,,,,,,,,,, -77300,2.3108535,1.4926513,,,,,,,,,,,,,, -77400,2.110514,1.5314107,,,,,,,,,,,,,, -77500,2.1925898,1.4810586,,,,,,,,,,,,,, -77600,2.1195533,1.4475617,,,,,,,,,,,,,, -77700,2.2634692,1.5634912,,,,,,,,,,,,,, -77787,,,0.7419283986091614,0.9948947429656982,0.668179988861084,1.3509660959243774,50000.0,0.5320000052452087,2.0675227642059326,10000.0,26560.44718551636,27491.354991436005,26560.44718551636,926.391303539276,1.854209899902344,0.0 -77800,2.071219,1.4778005,,,,,,,,,,,,,, -77900,2.2366872,1.4804829,,,,,,,,,,,,,, -78000,1.968222,1.5003827,,,,,,,,,,,,,, -78100,2.1242843,1.4964442,,,,,,,,,,,,,, -78200,2.2542763,1.4044652,,,,,,,,,,,,,, -78300,2.0229814,1.4042343,,,,,,,,,,,,,, -78400,2.0214207,1.4737206,,,,,,,,,,,,,, -78500,2.032786,1.4711282,,,,,,,,,,,,,, -78600,2.1917496,1.5026582,,,,,,,,,,,,,, -78700,2.0547383,1.3650495,,,,,,,,,,,,,, -78800,2.1907868,1.4384629,,,,,,,,,,,,,, -78900,2.3183846,1.4875841,,,,,,,,,,,,,, -79000,2.2003007,1.5030681,,,,,,,,,,,,,, -79100,2.1955702,1.4600579,,,,,,,,,,,,,, -79200,2.2180295,1.597026,,,,,,,,,,,,,, -79284,,,0.7411909699440002,0.993799090385437,0.6715399622917175,1.3349510431289673,50000.0,0.5475000143051147,2.050487756729126,10000.0,27070.60922384262,28019.048672914505,27070.60922384262,943.8302927017212,1.894223690032959,0.0 -79300,2.2253952,1.5227648,,,,,,,,,,,,,, -79400,1.9930607,1.4154384,,,,,,,,,,,,,, -79500,2.0991902,1.4457507,,,,,,,,,,,,,, -79600,2.0845096,1.4436705,,,,,,,,,,,,,, -79700,1.9833142,1.3543096,,,,,,,,,,,,,, -79800,2.183635,1.4225494,,,,,,,,,,,,,, -79900,2.0499284,1.3754745,,,,,,,,,,,,,, -80000,1.9296495,1.4506862,,,,,,,,,,,,,, -80100,2.2879708,1.5719092,,,,,,,,,,,,,, -80200,2.2405305,1.4609653,,,,,,,,,,,,,, -80300,2.2924373,1.5398344,,,,,,,,,,,,,, -80400,2.1892593,1.4344462,,,,,,,,,,,,,, -80500,2.141237,1.4969797,,,,,,,,,,,,,, -80600,2.159406,1.4811409,,,,,,,,,,,,,, -80700,2.282371,1.3585644,,,,,,,,,,,,,, -80781,,,0.7743940949440002,0.8552932739257812,0.6669399738311768,1.3668314218521118,50000.0,0.5409000515937805,2.0658974647521973,10000.0,27580.572791814804,28546.45977640152,27580.572791814804,961.185555934906,1.935410022735596,0.0 -80800,2.0223706,1.3677803,,,,,,,,,,,,,, -80900,2.1183956,1.4584316,,,,,,,,,,,,,, -81000,2.2984428,1.4766462,,,,,,,,,,,,,, -81100,2.115548,1.3653076,,,,,,,,,,,,,, -81200,2.1128514,1.3187058,,,,,,,,,,,,,, -81300,2.3043458,1.495511,,,,,,,,,,,,,, -81400,1.9930459,1.4963882,,,,,,,,,,,,,, -81500,2.254325,1.5609821,,,,,,,,,,,,,, -81600,2.2540095,1.4612429,,,,,,,,,,,,,, -81700,2.4008338,1.4473317,,,,,,,,,,,,,, -81800,2.2539027,1.5232148,,,,,,,,,,,,,, -81900,2.3975751,1.4823776,,,,,,,,,,,,,, -82000,2.179996,1.4016455,,,,,,,,,,,,,, -82100,1.9776249,1.3754611,,,,,,,,,,,,,, -82200,2.1802473,1.4729878,,,,,,,,,,,,,, -82278,,,0.7655652165412903,0.8683584332466125,0.6740999817848206,1.3325233459472656,50000.0,0.5436000227928162,2.066192865371704,10000.0,28090.6806986332,29074.016382932663,28090.6806986332,978.5394492149352,1.9794235229492188,0.0 -82300,2.2312903,1.3984835,,,,,,,,,,,,,, -82400,2.3240185,1.4679009,,,,,,,,,,,,,, -82500,2.065935,1.5123694,,,,,,,,,,,,,, -82600,2.0710144,1.4605212,,,,,,,,,,,,,, -82700,2.4285686,1.5018905,,,,,,,,,,,,,, -82800,2.3334022,1.4193741,,,,,,,,,,,,,, -82900,2.3174927,1.5367315,,,,,,,,,,,,,, -83000,2.2367246,1.462658,,,,,,,,,,,,,, -83100,2.2012587,1.4656742,,,,,,,,,,,,,, -83200,2.0117955,1.3073231,,,,,,,,,,,,,, -83300,2.5314,1.5385857,,,,,,,,,,,,,, -83400,2.1511872,1.4057474,,,,,,,,,,,,,, -83500,2.3350523,1.4854372,,,,,,,,,,,,,, -83600,2.3188221,1.4380457,,,,,,,,,,,,,, -83700,2.1981728,1.5883691,,,,,,,,,,,,,, -83775,,,0.7556401491165161,0.9285488128662108,0.6721199750900269,1.338887095451355,50000.0,0.5362000465393066,2.076328754425049,10000.0,28600.69172692299,29601.736602783203,28600.69172692299,996.1524906158448,2.0239923000335693,0.0 -83800,2.2577834,1.5739383,,,,,,,,,,,,,, -83900,2.2351797,1.4637878,,,,,,,,,,,,,, -84000,2.4616473,1.5383338,,,,,,,,,,,,,, -84100,2.1788878,1.5395824,,,,,,,,,,,,,, -84200,2.159029,1.4678284,,,,,,,,,,,,,, -84300,2.379359,1.4596026,,,,,,,,,,,,,, -84400,2.2006066,1.4941933,,,,,,,,,,,,,, -84500,2.1974523,1.4662933,,,,,,,,,,,,,, -84600,2.5283432,1.4559242,,,,,,,,,,,,,, -84700,2.2472043,1.5028129,,,,,,,,,,,,,, -84800,2.2179892,1.4061663,,,,,,,,,,,,,, -84900,2.1675384,1.4714696,,,,,,,,,,,,,, -85000,2.2208385,1.4518461,,,,,,,,,,,,,, -85100,2.1844301,1.531974,,,,,,,,,,,,,, -85200,2.070314,1.3998923,,,,,,,,,,,,,, -85272,,,0.7602040767669678,0.8968546390533447,0.6789399981498718,1.307782769203186,50000.0,0.5489000082015991,2.041975259780884,10000.0,29110.69561982155,30129.31273341179,29110.69561982155,1013.6288385391236,2.0686020851135254,0.0 -85300,2.246467,1.4420371,,,,,,,,,,,,,, -85400,2.1996582,1.4013896,,,,,,,,,,,,,, -85500,2.0474372,1.4088682,,,,,,,,,,,,,, -85600,2.2498386,1.3331566,,,,,,,,,,,,,, -85700,2.319712,1.4222529,,,,,,,,,,,,,, -85800,2.2098002,1.4835604,,,,,,,,,,,,,, -85900,2.4230423,1.4444876,,,,,,,,,,,,,, -86000,2.3551195,1.4097925,,,,,,,,,,,,,, -86100,2.065652,1.4995947,,,,,,,,,,,,,, -86200,2.174228,1.4453611,,,,,,,,,,,,,, -86300,2.1905618,1.3143002,,,,,,,,,,,,,, -86400,2.3458788,1.4696828,,,,,,,,,,,,,, -86500,2.0647821,1.3811316,,,,,,,,,,,,,, -86600,2.1605852,1.4122167,,,,,,,,,,,,,, -86700,2.3435159,1.5280293,,,,,,,,,,,,,, -86770,,,0.7505978941917419,0.9393353462219238,0.6771399974822998,1.3065553903579712,50000.0,0.5494000315666199,2.010836839675904,10000.0,29620.89577460289,30657.11715722084,29620.89577460289,1031.1397771835327,2.111499547958374,0.0 -86800,2.2688847,1.3393016,,,,,,,,,,,,,, -86900,2.0420277,1.457673,,,,,,,,,,,,,, -87000,2.215257,1.4214282,,,,,,,,,,,,,, -87100,2.2172282,1.4117604,,,,,,,,,,,,,, -87200,2.2741368,1.388003,,,,,,,,,,,,,, -87300,2.0930972,1.3296934,,,,,,,,,,,,,, -87400,2.2374244,1.5037221,,,,,,,,,,,,,, -87500,2.2051039,1.3649129,,,,,,,,,,,,,, -87600,2.269769,1.4627488,,,,,,,,,,,,,, -87700,2.1946838,1.4036739,,,,,,,,,,,,,, -87800,2.2034223,1.355433,,,,,,,,,,,,,, -87900,2.3182755,1.3409712,,,,,,,,,,,,,, -88000,2.1746995,1.3311976,,,,,,,,,,,,,, -88100,2.1218486,1.3240798,,,,,,,,,,,,,, -88200,2.3250186,1.3867586,,,,,,,,,,,,,, -88267,,,0.7596260905265808,0.9097063541412354,0.6830799579620361,1.28650164604187,50000.0,0.5533000230789185,2.0117292404174805,10000.0,30130.933844089508,31184.721259593964,30130.933844089508,1048.6075825691223,2.1597437858581543,0.0 -88300,2.0714762,1.3706591,,,,,,,,,,,,,, -88400,2.2220862,1.4485668,,,,,,,,,,,,,, -88500,2.170118,1.4258893,,,,,,,,,,,,,, -88600,2.203355,1.4323406,,,,,,,,,,,,,, -88700,2.2271361,1.4074278,,,,,,,,,,,,,, -88800,2.0810971,1.2745866,,,,,,,,,,,,,, -88900,2.4096315,1.4609768,,,,,,,,,,,,,, -89000,2.2420664,1.3787491,,,,,,,,,,,,,, -89100,2.2223823,1.331192,,,,,,,,,,,,,, -89200,2.2710907,1.4042194,,,,,,,,,,,,,, -89300,2.1604161,1.3852253,,,,,,,,,,,,,, -89400,2.0804722,1.2966425,,,,,,,,,,,,,, -89500,2.4418375,1.3051134,,,,,,,,,,,,,, -89600,2.2683146,1.3553861,,,,,,,,,,,,,, -89700,2.2524455,1.2985847,,,,,,,,,,,,,, -89764,,,0.7589086294174194,0.9112529158592224,0.682379961013794,1.2872778177261353,50000.0,0.5552000403404236,2.016308069229126,10000.0,30641.0526971817,31712.21232533455,30641.0526971817,1065.8815150260923,2.206571578979492,0.0 -89800,2.4999244,1.3872287,,,,,,,,,,,,,, -89900,2.4602604,1.4915546,,,,,,,,,,,,,, -90000,2.130297,1.3705269,,,,,,,,,,,,,, -90100,2.1732178,1.4641016,,,,,,,,,,,,,, -90200,2.1222215,1.3721148,,,,,,,,,,,,,, -90300,2.2690933,1.4254788,,,,,,,,,,,,,, -90400,2.4172876,1.4048781,,,,,,,,,,,,,, -90500,2.1509914,1.348692,,,,,,,,,,,,,, -90600,2.086583,1.3934983,,,,,,,,,,,,,, -90700,2.3156953,1.5273956,,,,,,,,,,,,,, -90800,2.381174,1.444913,,,,,,,,,,,,,, -90900,2.473032,1.3453348,,,,,,,,,,,,,, -91000,2.3726032,1.4689109,,,,,,,,,,,,,, -91100,2.2983327,1.440345,,,,,,,,,,,,,, -91200,2.5861015,1.354731,,,,,,,,,,,,,, -91261,,,0.7894411683082581,0.784350574016571,0.6829000115394592,1.2923929691314695,50000.0,0.5560000538825989,2.0203568935394287,10000.0,31151.199570178986,32239.76553440094,31151.199570178986,1083.19038939476,2.25357723236084,0.0 -91300,2.261855,1.353862,,,,,,,,,,,,,, -91400,2.2852356,1.3589396,,,,,,,,,,,,,, -91500,2.2273555,1.328849,,,,,,,,,,,,,, -91600,2.4374685,1.3862078,,,,,,,,,,,,,, -91700,2.3049953,1.3626949,,,,,,,,,,,,,, -91800,2.3039188,1.4699818,,,,,,,,,,,,,, -91900,2.268217,1.3817906,,,,,,,,,,,,,, -92000,2.2930317,1.4348809,,,,,,,,,,,,,, -92100,2.497424,1.5391786,,,,,,,,,,,,,, -92200,2.1497958,1.3733565,,,,,,,,,,,,,, -92300,2.1860535,1.2865814,,,,,,,,,,,,,, -92400,2.3766615,1.3982476,,,,,,,,,,,,,, -92500,2.1147864,1.3397063,,,,,,,,,,,,,, -92600,2.1580245,1.3504777,,,,,,,,,,,,,, -92700,2.7026103,1.3562812,,,,,,,,,,,,,, -92757,,,0.7609016299247742,0.8901998996734619,0.6728399991989136,1.3282952308654783,50000.0,0.5391000509262085,2.09073805809021,10000.0,31661.18835258484,32767.13542819023,31661.18835258484,1100.471552848816,2.3035874366760254,0.0 -92800,2.2793195,1.4109986,,,,,,,,,,,,,, -92900,2.360867,1.4439099,,,,,,,,,,,,,, -93000,2.4381793,1.3690825,,,,,,,,,,,,,, -93100,2.2563298,1.4096608,,,,,,,,,,,,,, -93200,2.361913,1.3556416,,,,,,,,,,,,,, -93300,2.2757902,1.4502207,,,,,,,,,,,,,, -93400,2.4754527,1.2786049,,,,,,,,,,,,,, -93500,2.300456,1.4039505,,,,,,,,,,,,,, -93600,2.233841,1.3488286,,,,,,,,,,,,,, -93700,2.4467144,1.4545641,,,,,,,,,,,,,, -93800,2.2699642,1.4786956,,,,,,,,,,,,,, -93900,2.288387,1.3544945,,,,,,,,,,,,,, -94000,2.8092265,1.3945172,,,,,,,,,,,,,, -94100,2.2412586,1.3591375,,,,,,,,,,,,,, -94200,2.4272413,1.4978112,,,,,,,,,,,,,, -94254,,,0.7751315236091614,0.8435813784599304,0.6869999766349792,1.2639058828353882,50000.0,0.5612000226974487,1.9653502702713013,10000.0,32171.235827207565,33294.53087544441,32171.235827207565,1117.723935842514,2.349766731262207,0.0 -94300,2.6438649,1.4327677,,,,,,,,,,,,,, -94400,2.3340518,1.2777529,,,,,,,,,,,,,, -94500,2.5004592,1.4809468,,,,,,,,,,,,,, -94600,2.288811,1.3341877,,,,,,,,,,,,,, -94700,2.2671702,1.2965866,,,,,,,,,,,,,, -94800,2.2275188,1.3666787,,,,,,,,,,,,,, -94900,2.5402608,1.4028281,,,,,,,,,,,,,, -95000,2.1339517,1.3368549,,,,,,,,,,,,,, -95100,2.2295856,1.3524795,,,,,,,,,,,,,, -95200,2.420306,1.3031671,,,,,,,,,,,,,, -95300,2.2983963,1.3266697,,,,,,,,,,,,,, -95400,2.3492713,1.3971074,,,,,,,,,,,,,, -95500,2.4554489,1.4852409,,,,,,,,,,,,,, -95600,2.4048152,1.3434522,,,,,,,,,,,,,, -95700,2.4186857,1.4163215,,,,,,,,,,,,,, -95751,,,0.7679567933082581,0.8627640008926392,0.681659996509552,1.2954697608947754,50000.0,0.555400013923645,2.016524314880371,10000.0,32681.18448400497,33822.053425073624,32681.18448400497,1135.1968188285828,2.4002671241760254,0.0 -95800,2.4506688,1.3816736,,,,,,,,,,,,,, -95900,2.4099424,1.4867828,,,,,,,,,,,,,, -96000,2.6037605,1.3897939,,,,,,,,,,,,,, -96100,2.6608593,1.2928066,,,,,,,,,,,,,, -96200,2.3663332,1.386271,,,,,,,,,,,,,, -96300,2.4685266,1.3821111,,,,,,,,,,,,,, -96400,2.1578271,1.280638,,,,,,,,,,,,,, -96500,2.50895,1.3549894,,,,,,,,,,,,,, -96600,2.4366674,1.3774334,,,,,,,,,,,,,, -96700,2.1505282,1.3245243,,,,,,,,,,,,,, -96800,2.39855,1.3608184,,,,,,,,,,,,,, -96900,2.3411896,1.3719083,,,,,,,,,,,,,, -97000,2.3192039,1.3132601,,,,,,,,,,,,,, -97100,2.5529053,1.3962097,,,,,,,,,,,,,, -97200,2.5564024,1.4467688,,,,,,,,,,,,,, -97248,,,0.7716238498687744,0.8549726605415344,0.691540002822876,1.2567367553710938,50000.0,0.5631999969482422,1.985821008682251,10000.0,33191.21137213707,34350.32893657684,33191.21137213707,1153.3456366062164,2.44881272315979,0.0 -97300,2.3299956,1.3290638,,,,,,,,,,,,,, -97400,2.3679733,1.3195417,,,,,,,,,,,,,, -97500,2.3565955,1.2989975,,,,,,,,,,,,,, -97600,2.2159588,1.2698641,,,,,,,,,,,,,, -97700,2.4829814,1.33788,,,,,,,,,,,,,, -97800,2.531623,1.4311249,,,,,,,,,,,,,, -97900,2.2047405,1.3413599,,,,,,,,,,,,,, -98000,2.6643307,1.4594829,,,,,,,,,,,,,, -98100,2.31288,1.3784171,,,,,,,,,,,,,, -98200,2.3153682,1.2932708,,,,,,,,,,,,,, -98300,2.5411396,1.3235754,,,,,,,,,,,,,, -98400,2.8020122,1.4838499,,,,,,,,,,,,,, -98500,2.4025142,1.3285866,,,,,,,,,,,,,, -98600,2.4007442,1.369706,,,,,,,,,,,,,, -98700,2.5060942,1.4392887,,,,,,,,,,,,,, -98745,,,0.7628945708274841,0.8821855187416077,0.6861000061035156,1.2782268524169922,50000.0,0.5597000122070312,1.9715975522994995,10000.0,33701.27909350395,34878.50706458092,33701.27909350395,1171.3598392009735,2.493917226791382,0.0 -98800,2.3931417,1.3261505,,,,,,,,,,,,,, -98900,2.4475713,1.3204124,,,,,,,,,,,,,, -99000,2.279269,1.4073331,,,,,,,,,,,,,, -99100,2.2614813,1.2719423,,,,,,,,,,,,,, -99200,2.4848278,1.4195725,,,,,,,,,,,,,, -99300,2.3311207,1.3428574,,,,,,,,,,,,,, -99400,2.2478597,1.33055,,,,,,,,,,,,,, -99500,2.4565344,1.3697213,,,,,,,,,,,,,, -99600,2.8299997,1.3493655,,,,,,,,,,,,,, -99700,2.4206996,1.3166826,,,,,,,,,,,,,, -99800,2.4822457,1.2106837,,,,,,,,,,,,,, -99900,2.6148322,1.4359423,,,,,,,,,,,,,, -100000,2.5584533,1.3171272,,,,,,,,,,,,,, -100100,2.5790837,1.4003094,,,,,,,,,,,,,, -100200,2.1159515,1.2070241,,,,,,,,,,,,,, -100241,,,0.8028938174247742,0.7278432250022888,0.6812599897384644,1.3054099082946775,50000.0,0.5509000420570374,2.0477240085601807,10000.0,34210.90730881691,35406.05687189102,34210.90730881691,1188.65078830719,3.0731360912323,0.0 -100300,2.435275,1.4658813,,,,,,,,,,,,,, -100400,2.2977848,1.3002179,,,,,,,,,,,,,, -100500,2.332131,1.3737714,,,,,,,,,,,,,, -100600,2.4411213,1.2322071,,,,,,,,,,,,,, -100700,2.3200264,1.3823968,,,,,,,,,,,,,, -100800,2.675115,1.2501353,,,,,,,,,,,,,, -100900,2.84774,1.450986,,,,,,,,,,,,,, -101000,2.650297,1.5070372,,,,,,,,,,,,,, -101100,2.4259517,1.255952,,,,,,,,,,,,,, -101200,2.5533628,1.3506999,,,,,,,,,,,,,, -101300,2.547827,1.3089445,,,,,,,,,,,,,, -101400,2.7412577,1.3264376,,,,,,,,,,,,,, -101500,2.371876,1.2745552,,,,,,,,,,,,,, -101600,2.413081,1.3164121,,,,,,,,,,,,,, -101700,2.4940639,1.286087,,,,,,,,,,,,,, -101739,,,0.788504421710968,0.7852997183799744,0.6868799924850464,1.2690773010253906,50000.0,0.5597000122070312,1.9839948415756223,10000.0,34721.09689593315,35933.86906027794,34721.09689593315,1206.1760022640228,3.119248867034912,0.0 -101800,2.4059572,1.2962224,,,,,,,,,,,,,, -101900,2.9151351,1.4559565,,,,,,,,,,,,,, -102000,2.5905428,1.3606645,,,,,,,,,,,,,, -102100,2.4565122,1.3336097,,,,,,,,,,,,,, -102200,2.3700974,1.2864844,,,,,,,,,,,,,, -102300,2.2676775,1.2405488,,,,,,,,,,,,,, -102400,2.2670941,1.3236816,,,,,,,,,,,,,, -102500,2.3753662,1.3537748,,,,,,,,,,,,,, -102600,2.4821372,1.3883106,,,,,,,,,,,,,, -102700,2.2490866,1.2119768,,,,,,,,,,,,,, -102800,2.5633159,1.3037077,,,,,,,,,,,,,, -102900,2.7271743,1.3722576,,,,,,,,,,,,,, -103000,2.7356641,1.4207804,,,,,,,,,,,,,, -103100,2.647755,1.3224299,,,,,,,,,,,,,, -103200,2.447712,1.3713449,,,,,,,,,,,,,, -103236,,,0.7882851958274841,0.7907478213310242,0.6976199746131897,1.247995376586914,50000.0,0.5722000002861023,1.93617844581604,10000.0,35231.05541777611,36461.30139732361,35231.05541777611,1223.5552270412445,3.163418292999268,0.0 -103300,2.6061702,1.3576553,,,,,,,,,,,,,, -103400,2.4382832,1.2954193,,,,,,,,,,,,,, -103500,2.543067,1.311583,,,,,,,,,,,,,, -103600,2.595702,1.4437524,,,,,,,,,,,,,, -103700,2.3954499,1.2839957,,,,,,,,,,,,,, -103800,2.5851786,1.24238,,,,,,,,,,,,,, -103900,2.5963054,1.2735621,,,,,,,,,,,,,, -104000,2.5152955,1.3575683,,,,,,,,,,,,,, -104100,2.5663393,1.3954225,,,,,,,,,,,,,, -104200,2.4494414,1.2035013,,,,,,,,,,,,,, -104300,2.6389782,1.2840159,,,,,,,,,,,,,, -104400,2.5070834,1.2513632,,,,,,,,,,,,,, -104500,2.6985083,1.3633996,,,,,,,,,,,,,, -104600,2.4354806,1.3272417,,,,,,,,,,,,,, -104700,2.6095445,1.3422343,,,,,,,,,,,,,, -104733,,,0.7801936864852905,0.8114662766456604,0.6913599967956543,1.254683017730713,50000.0,0.5595000386238098,1.9859092235565183,10000.0,35741.056241989136,36988.61328577995,35741.056241989136,1240.7657074928284,3.2133994102478027,0.0 -104800,2.8149524,1.4199513,,,,,,,,,,,,,, -104900,2.6068165,1.259413,,,,,,,,,,,,,, -105000,2.5224562,1.2919778,,,,,,,,,,,,,, -105100,2.5692887,1.281021,,,,,,,,,,,,,, -105200,2.5888493,1.1921489,,,,,,,,,,,,,, -105300,2.378656,1.2540516,,,,,,,,,,,,,, -105400,2.6186163,1.3506233,,,,,,,,,,,,,, -105500,2.6265154,1.3268253,,,,,,,,,,,,,, -105600,2.3480554,1.2904652,,,,,,,,,,,,,, -105700,2.6072807,1.4466623,,,,,,,,,,,,,, -105800,2.8079858,1.295081,,,,,,,,,,,,,, -105900,2.8096604,1.321784,,,,,,,,,,,,,, -106000,2.6059053,1.3081919,,,,,,,,,,,,,, -106100,2.3030448,1.2963705,,,,,,,,,,,,,, -106200,2.5444615,1.366085,,,,,,,,,,,,,, -106230,,,0.7833226919174194,0.8032361268997192,0.6902399659156799,1.247302532196045,50000.0,0.5678000450134277,1.9703844785690308,10000.0,36250.98030781746,37516.186690330505,36250.98030781746,1258.3189299106598,3.259010076522827,0.0 -106300,2.6123536,1.3650849,,,,,,,,,,,,,, -106400,2.7597644,1.3952569,,,,,,,,,,,,,, -106500,2.7359662,1.3308558,,,,,,,,,,,,,, -106600,2.7335706,1.3890196,,,,,,,,,,,,,, -106700,2.631486,1.2155566,,,,,,,,,,,,,, -106800,2.6857495,1.339282,,,,,,,,,,,,,, -106900,2.480127,1.213306,,,,,,,,,,,,,, -107000,2.82763,1.3750982,,,,,,,,,,,,,, -107100,2.605592,1.2810442,,,,,,,,,,,,,, -107200,2.4949455,1.2226055,,,,,,,,,,,,,, -107300,2.6589966,1.41118,,,,,,,,,,,,,, -107400,2.505604,1.2726768,,,,,,,,,,,,,, -107500,2.657694,1.2107937,,,,,,,,,,,,,, -107600,2.8314404,1.4142652,,,,,,,,,,,,,, -107700,2.4671793,1.1880366,,,,,,,,,,,,,, -107727,,,0.7869698405265808,0.769432544708252,0.6983199715614319,1.2229222059249878,50000.0,0.5678000450134277,1.957873106002808,10000.0,36760.9456949234,38043.81551671028,36760.9456949234,1275.867201089859,3.323164701461792,0.0 -107800,2.373541,1.2679696,,,,,,,,,,,,,, -107900,2.6139774,1.2660503,,,,,,,,,,,,,, -108000,2.8299234,1.4065678,,,,,,,,,,,,,, -108100,2.4540644,1.224055,,,,,,,,,,,,,, -108200,2.6537354,1.3391556,,,,,,,,,,,,,, -108300,2.480052,1.2461371,,,,,,,,,,,,,, -108400,2.858827,1.3416158,,,,,,,,,,,,,, -108500,2.8372054,1.3859057,,,,,,,,,,,,,, -108600,2.9617558,1.2461987,,,,,,,,,,,,,, -108700,2.599636,1.2826244,,,,,,,,,,,,,, -108800,2.6379614,1.2690678,,,,,,,,,,,,,, -108900,2.6110594,1.3377959,,,,,,,,,,,,,, -109000,2.6025653,1.3238246,,,,,,,,,,,,,, -109100,2.5394409,1.2199193,,,,,,,,,,,,,, -109200,2.6519732,1.2239304,,,,,,,,,,,,,, -109225,,,0.8082947731018066,0.6996628046035767,0.6953799724578857,1.2410041093826294,50000.0,0.5749000310897827,1.95384418964386,10000.0,37271.04302740097,38571.32220196724,37271.04302740097,1293.1777880191803,3.372274875640869,0.0 -109300,2.372012,1.1520207,,,,,,,,,,,,,, -109400,2.6496828,1.2653006,,,,,,,,,,,,,, -109500,2.6285605,1.3184114,,,,,,,,,,,,,, -109600,2.8050296,1.3199028,,,,,,,,,,,,,, -109700,2.8983238,1.296281,,,,,,,,,,,,,, -109800,2.9832902,1.2318764,,,,,,,,,,,,,, -109900,2.5843945,1.2951331,,,,,,,,,,,,,, -110000,2.606749,1.2663836,,,,,,,,,,,,,, -110100,2.6366441,1.2824981,,,,,,,,,,,,,, -110200,2.5763133,1.2856944,,,,,,,,,,,,,, -110300,2.7235053,1.3163689,,,,,,,,,,,,,, -110400,2.801456,1.3381335,,,,,,,,,,,,,, -110500,2.7223268,1.2836082,,,,,,,,,,,,,, -110600,2.6263578,1.2318804,,,,,,,,,,,,,, -110700,2.6835945,1.32212,,,,,,,,,,,,,, -110722,,,0.813875138759613,0.6721604466438293,0.7019599676132202,1.222309947013855,50000.0,0.5769000053405762,1.93671452999115,10000.0,37781.17570281029,39099.16570162773,37781.17570281029,1310.7889490127563,3.420320987701416,0.0 -110800,2.7963083,1.216908,,,,,,,,,,,,,, -110900,2.599443,1.2764455,,,,,,,,,,,,,, -111000,2.9747882,1.3859036,,,,,,,,,,,,,, -111100,2.5060666,1.254312,,,,,,,,,,,,,, -111200,2.6701515,1.2489198,,,,,,,,,,,,,, -111300,3.0857732,1.3749751,,,,,,,,,,,,,, -111400,2.8323767,1.2291557,,,,,,,,,,,,,, -111500,2.6675088,1.2266706,,,,,,,,,,,,,, -111600,2.716811,1.2872028,,,,,,,,,,,,,, -111700,2.6882186,1.2776593,,,,,,,,,,,,,, -111800,2.5293293,1.1923459,,,,,,,,,,,,,, -111900,2.7027762,1.1824069,,,,,,,,,,,,,, -112000,2.5809848,1.2254713,,,,,,,,,,,,,, -112100,2.6823046,1.254326,,,,,,,,,,,,,, -112200,2.8952363,1.213482,,,,,,,,,,,,,, -112219,,,0.8077367544174194,0.6889608502388,0.7064599990844727,1.1906296014785769,50000.0,0.5773000121116638,1.9144089221954343,10000.0,38291.41183042526,39626.73736596108,38291.41183042526,1328.0241174697876,3.4695346355438232,0.0 -112300,2.724001,1.2539916,,,,,,,,,,,,,, -112400,2.838438,1.2731814,,,,,,,,,,,,,, -112500,2.9644594,1.2415439,,,,,,,,,,,,,, -112600,2.6635158,1.2636884,,,,,,,,,,,,,, -112700,2.8601196,1.2530766,,,,,,,,,,,,,, -112800,2.5936985,1.2358837,,,,,,,,,,,,,, -112900,2.7143312,1.2511863,,,,,,,,,,,,,, -113000,2.5782516,1.2083216,,,,,,,,,,,,,, -113100,2.8470445,1.2171979,,,,,,,,,,,,,, -113200,2.9262912,1.1470983,,,,,,,,,,,,,, -113300,2.705631,1.1699759,,,,,,,,,,,,,, -113400,2.760667,1.2105742,,,,,,,,,,,,,, -113500,2.8403273,1.3316936,,,,,,,,,,,,,, -113600,2.776426,1.1830728,,,,,,,,,,,,,, -113700,2.7854092,1.1564736,,,,,,,,,,,,,, -113716,,,0.7943239808082581,0.7536799311637878,0.6958999633789062,1.2339434623718262,50000.0,0.5742000341415405,1.97724187374115,10000.0,38801.31273698807,40154.12018656731,38801.31273698807,1345.405649185181,3.519113063812256,0.0 -113800,2.78708,1.226156,,,,,,,,,,,,,, -113900,2.7598097,1.2309513,,,,,,,,,,,,,, -114000,2.797218,1.2698485,,,,,,,,,,,,,, -114100,2.935724,1.2346997,,,,,,,,,,,,,, -114200,2.6194263,1.1927862,,,,,,,,,,,,,, -114300,2.7457922,1.2191696,,,,,,,,,,,,,, -114400,3.0968583,1.2407385,,,,,,,,,,,,,, -114500,3.0864496,1.3053081,,,,,,,,,,,,,, -114600,2.8031104,1.3220844,,,,,,,,,,,,,, -114700,2.8387125,1.298368,,,,,,,,,,,,,, -114800,2.7727768,1.2604088,,,,,,,,,,,,,, -114900,2.826313,1.3068353,,,,,,,,,,,,,, -115000,2.6372986,1.3351645,,,,,,,,,,,,,, -115100,2.81233,1.1260103,,,,,,,,,,,,,, -115200,2.7260237,1.2121729,,,,,,,,,,,,,, -115212,,,0.7994260191917419,0.7419718503952026,0.6971399784088135,1.2301865816116333,50000.0,0.5767000317573547,1.938300848007202,10000.0,39311.23474597931,40681.605519771576,39311.23474597931,1362.8529393672943,3.5819525718688965,0.0 -115300,2.8703685,1.196215,,,,,,,,,,,,,, -115400,2.7305467,1.2005465,,,,,,,,,,,,,, -115500,2.5558069,1.1924154,,,,,,,,,,,,,, -115600,3.0774045,1.263349,,,,,,,,,,,,,, -115700,2.7572298,1.1475489,,,,,,,,,,,,,, -115800,2.6466198,1.2114075,,,,,,,,,,,,,, -115900,2.762673,1.1862398,,,,,,,,,,,,,, -116000,2.8212335,1.1978363,,,,,,,,,,,,,, -116100,2.999981,1.2812326,,,,,,,,,,,,,, -116200,2.9161115,1.1674765,,,,,,,,,,,,,, -116300,2.9759643,1.2525471,,,,,,,,,,,,,, -116400,2.6444175,1.1761012,,,,,,,,,,,,,, -116500,2.7822285,1.1670349,,,,,,,,,,,,,, -116600,2.9469821,1.2459824,,,,,,,,,,,,,, -116700,2.78077,1.1571583,,,,,,,,,,,,,, -116709,,,0.8053252100944519,0.7091084718704224,0.7084199786186218,1.1827819347381592,50000.0,0.5850000381469727,1.901528239250183,10000.0,39821.20298480988,41209.22456550598,39821.20298480988,1380.4033725261688,3.630246162414551,0.0 -116800,2.8658783,1.2464204,,,,,,,,,,,,,, -116900,3.0134075,1.250205,,,,,,,,,,,,,, -117000,2.635456,1.1506523,,,,,,,,,,,,,, -117100,3.3824406,1.1702894,,,,,,,,,,,,,, -117200,2.8268352,1.148139,,,,,,,,,,,,,, -117300,2.8004787,1.1507128,,,,,,,,,,,,,, -117400,2.6669526,1.2270352,,,,,,,,,,,,,, -117500,3.0141387,1.2451434,,,,,,,,,,,,,, -117600,3.0771413,1.2579396,,,,,,,,,,,,,, -117700,2.9901848,1.2145464,,,,,,,,,,,,,, -117800,2.7963042,1.2170094,,,,,,,,,,,,,, -117900,2.7290053,1.2290188,,,,,,,,,,,,,, -118000,2.7069628,1.1561477,,,,,,,,,,,,,, -118100,2.902687,1.1697404,,,,,,,,,,,,,, -118200,2.858932,1.2300825,,,,,,,,,,,,,, -118206,,,0.8056241869926453,0.7094724178314209,0.7069999575614929,1.1886725425720217,50000.0,0.5771000385284424,1.9114309549331665,10000.0,40331.18197154999,41736.874911785126,40331.18197154999,1397.9720528125763,3.681635856628418,0.0 -118300,2.6067336,1.2663773,,,,,,,,,,,,,, -118400,2.941905,1.1697708,,,,,,,,,,,,,, -118500,2.7539597,1.1629399,,,,,,,,,,,,,, -118600,3.3462696,1.2718351,,,,,,,,,,,,,, -118700,2.9232733,1.1777388,,,,,,,,,,,,,, -118800,2.9604766,1.2319244,,,,,,,,,,,,,, -118900,3.0926101,1.279735,,,,,,,,,,,,,, -119000,2.642971,1.1492174,,,,,,,,,,,,,, -119100,2.7026887,1.2180938,,,,,,,,,,,,,, -119200,2.8639715,1.1980759,,,,,,,,,,,,,, -119300,3.2189543,1.2304977,,,,,,,,,,,,,, -119400,3.0901098,1.3217914,,,,,,,,,,,,,, -119500,2.8267148,1.174111,,,,,,,,,,,,,, -119600,2.8933065,1.0947841,,,,,,,,,,,,,, -119700,2.83593,1.1125662,,,,,,,,,,,,,, -119703,,,0.8384486436843872,0.5811982154846191,0.7049599885940552,1.1874204874038696,50000.0,0.5821000337600708,1.9020187854766848,10000.0,40841.25449848175,42264.33503699303,40841.25449848175,1415.2570950984957,3.732133626937866,0.0 -119800,3.1731791,1.2810224,,,,,,,,,,,,,, -119900,2.9948342,1.1885084,,,,,,,,,,,,,, -120000,3.1578581,1.2332909,,,,,,,,,,,,,, -120100,2.7019234,1.180538,,,,,,,,,,,,,, -120200,2.9313445,1.1507051,,,,,,,,,,,,,, -120300,2.7841125,1.1321768,,,,,,,,,,,,,, -120400,3.011826,1.272836,,,,,,,,,,,,,, -120500,3.182793,1.2699004,,,,,,,,,,,,,, -120600,3.1815453,1.261555,,,,,,,,,,,,,, -120700,3.195396,1.1475924,,,,,,,,,,,,,, -120800,3.0449357,1.1772406,,,,,,,,,,,,,, -120900,2.9142048,1.1291983,,,,,,,,,,,,,, -121000,3.2191832,1.1215775,,,,,,,,,,,,,, -121100,2.664826,1.1469693,,,,,,,,,,,,,, -121200,2.9100313,1.1526399,,,,,,,,,,,,,, -121201,,,0.8234614133834839,0.6339573860168457,0.7074999809265137,1.1968789100646973,50000.0,0.5819000005722046,1.924055814743042,10000.0,41351.66305828095,42792.24165916443,41351.66305828095,1432.6516358852386,3.784732341766357,0.0 -121300,3.025409,1.1271905,,,,,,,,,,,,,, -121400,3.1020305,1.3213127,,,,,,,,,,,,,, -121500,2.7693317,1.181767,,,,,,,,,,,,,, -121600,2.9116216,1.1958019,,,,,,,,,,,,,, -121700,2.9873135,1.2104186,,,,,,,,,,,,,, -121800,2.8386428,1.2714474,,,,,,,,,,,,,, -121900,2.9200323,1.1606524,,,,,,,,,,,,,, -122000,2.832802,1.1433703,,,,,,,,,,,,,, -122100,3.170025,1.1123834,,,,,,,,,,,,,, -122200,3.0130975,1.1335548,,,,,,,,,,,,,, -122300,3.2498438,1.21538,,,,,,,,,,,,,, -122400,3.0563672,1.149361,,,,,,,,,,,,,, -122500,3.0632377,1.1825676,,,,,,,,,,,,,, -122600,3.048386,1.1986729,,,,,,,,,,,,,, -122698,,,0.8205317258834839,0.6411925554275513,0.707040011882782,1.196536898612976,50000.0,0.5800000429153442,1.9360084533691408,10000.0,41861.57820510864,43319.72976899147,41861.57820510864,1450.1179354190826,3.839577436447144,0.0 -122700,2.8687534,1.2170119,,,,,,,,,,,,,, -122800,3.3265164,1.1725038,,,,,,,,,,,,,, -122900,3.0318806,1.2150779,,,,,,,,,,,,,, -123000,3.2499866,1.1797979,,,,,,,,,,,,,, -123100,3.0797746,1.2266802,,,,,,,,,,,,,, -123200,3.2381594,1.1413255,,,,,,,,,,,,,, -123300,3.065358,1.0691497,,,,,,,,,,,,,, -123400,3.1100376,1.1810393,,,,,,,,,,,,,, -123500,2.879358,1.1609678,,,,,,,,,,,,,, -123600,3.226153,1.2317414,,,,,,,,,,,,,, -123700,3.102024,1.2038736,,,,,,,,,,,,,, -123800,3.164,1.1825833,,,,,,,,,,,,,, -123900,3.2470157,1.1617224,,,,,,,,,,,,,, -124000,3.1863296,1.1562741,,,,,,,,,,,,,, -124100,2.918518,1.1873375,,,,,,,,,,,,,, -124194,,,0.8153499364852905,0.6613895297050476,0.7084000110626221,1.1882474422454834,50000.0,0.5819000005722046,1.9214253425598145,10000.0,42371.47960424423,43846.969765901566,42371.47960424423,1467.3553059101105,3.890242338180542,0.0 -124200,2.9046793,1.1088692,,,,,,,,,,,,,, -124300,2.9100387,1.1023266,,,,,,,,,,,,,, -124400,3.3914177,1.211583,,,,,,,,,,,,,, -124500,2.9429479,1.1162949,,,,,,,,,,,,,, -124600,3.1545238,1.1030252,,,,,,,,,,,,,, -124700,2.972926,1.1318247,,,,,,,,,,,,,, -124800,3.123965,1.0174496,,,,,,,,,,,,,, -124900,3.1613631,1.0665894,,,,,,,,,,,,,, -125000,2.8200066,1.0167605,,,,,,,,,,,,,, -125100,3.0240216,1.135397,,,,,,,,,,,,,, -125200,2.9966211,1.1408204,,,,,,,,,,,,,, -125300,3.2528145,1.0702437,,,,,,,,,,,,,, -125400,3.111277,1.1841768,,,,,,,,,,,,,, -125500,3.2558994,1.262096,,,,,,,,,,,,,, -125600,3.2786844,1.2007823,,,,,,,,,,,,,, -125692,,,0.8228236436843872,0.6361862421035767,0.7134000062942505,1.174013614654541,50000.0,0.5859000086784363,1.90839946269989,10000.0,42881.63268017769,44374.61629462242,42881.63268017769,1484.7452561855316,3.943494319915772,0.0 -125700,3.1278832,1.0399052,,,,,,,,,,,,,, -125800,3.147454,1.216656,,,,,,,,,,,,,, -125900,3.144232,1.1466888,,,,,,,,,,,,,, -126000,3.0129955,1.1525376,,,,,,,,,,,,,, -126100,3.2469954,1.1128353,,,,,,,,,,,,,, -126200,3.1747308,1.1846602,,,,,,,,,,,,,, -126300,3.0796735,1.1822119,,,,,,,,,,,,,, -126400,3.4339342,1.1189666,,,,,,,,,,,,,, -126500,2.7299755,0.99915373,,,,,,,,,,,,,, -126600,3.2906594,1.179656,,,,,,,,,,,,,, -126700,3.1133676,1.1155653,,,,,,,,,,,,,, -126800,3.5633223,1.1299319,,,,,,,,,,,,,, -126900,3.1793768,1.2314411,,,,,,,,,,,,,, -127000,3.4210715,1.1390966,,,,,,,,,,,,,, -127100,3.3889537,1.1975926,,,,,,,,,,,,,, -127190,,,0.8191167116165161,0.6429269909858704,0.7112399935722351,1.1914821863174438,50000.0,0.5815000534057617,1.9391067028045648,10000.0,43391.758655786514,44902.20900511742,43391.758655786514,1502.1120376586914,3.9931235313415527,0.0 -127200,3.304075,1.1039517,,,,,,,,,,,,,, -127300,3.015718,1.1219487,,,,,,,,,,,,,, -127400,2.9601734,1.063729,,,,,,,,,,,,,, -127500,2.8824854,0.9844289,,,,,,,,,,,,,, -127600,3.0445228,1.1549897,,,,,,,,,,,,,, -127700,3.443571,1.194415,,,,,,,,,,,,,, -127800,3.5369518,1.2052827,,,,,,,,,,,,,, -127900,3.2515483,1.1118627,,,,,,,,,,,,,, -128000,3.3146684,1.2348633,,,,,,,,,,,,,, -128100,3.156258,1.1199446,,,,,,,,,,,,,, -128200,3.0964336,1.0536945,,,,,,,,,,,,,, -128300,3.4421902,1.0981786,,,,,,,,,,,,,, -128400,3.228771,1.0729359,,,,,,,,,,,,,, -128500,3.1442096,1.0957296,,,,,,,,,,,,,, -128600,3.0334506,1.078498,,,,,,,,,,,,,, -128687,,,0.8657525181770325,0.4807564318180084,0.7155599594116211,1.1618179082870483,50000.0,0.5940000414848328,1.870269656181336,10000.0,43901.96710586548,45430.13155961037,43901.96710586548,1519.7223196029663,4.045359134674072,0.0 -128700,2.9966948,1.0577782,,,,,,,,,,,,,, -128800,3.1195853,1.1251421,,,,,,,,,,,,,, -128900,3.1523306,1.0555469,,,,,,,,,,,,,, -129000,3.0409522,1.0217785,,,,,,,,,,,,,, -129100,3.1994584,1.153541,,,,,,,,,,,,,, -129200,3.1274176,1.1529648,,,,,,,,,,,,,, -129300,3.519296,1.0654736,,,,,,,,,,,,,, -129400,3.4551349,1.1921331,,,,,,,,,,,,,, -129500,3.1506875,1.0349722,,,,,,,,,,,,,, -129600,3.2842786,1.1793118,,,,,,,,,,,,,, -129700,3.108038,1.0200504,,,,,,,,,,,,,, -129800,3.3229907,1.1527214,,,,,,,,,,,,,, -129900,3.144345,1.0593245,,,,,,,,,,,,,, -130000,3.6210666,1.1293799,,,,,,,,,,,,,, -130100,3.2618895,1.1400234,,,,,,,,,,,,,, -130184,,,0.8509646058082581,0.52117919921875,0.7196599841117859,1.150366187095642,50000.0,0.5944000482559204,1.8714016675949097,10000.0,44412.060628175735,45957.546800136566,44412.060628175735,1536.9437124729156,4.095425844192505,0.0 -130200,3.194523,1.0564002,,,,,,,,,,,,,, -130300,3.3396099,1.0529156,,,,,,,,,,,,,, -130400,3.3186018,1.1076281,,,,,,,,,,,,,, -130500,3.3176446,1.092588,,,,,,,,,,,,,, -130600,2.9537833,1.0258027,,,,,,,,,,,,,, -130700,3.517932,1.1635957,,,,,,,,,,,,,, -130800,3.2573943,1.165647,,,,,,,,,,,,,, -130900,3.2493105,1.082185,,,,,,,,,,,,,, -131000,3.103669,1.0263065,,,,,,,,,,,,,, -131100,3.6283054,1.0456979,,,,,,,,,,,,,, -131200,3.742667,1.0631086,,,,,,,,,,,,,, -131300,3.229387,1.0839694,,,,,,,,,,,,,, -131400,3.7319791,1.07892,,,,,,,,,,,,,, -131500,3.2097266,1.0054194,,,,,,,,,,,,,, -131600,3.0756497,1.0100143,,,,,,,,,,,,,, -131681,,,0.8581393361091614,0.5003353953361511,0.721340000629425,1.128917932510376,50000.0,0.5985000133514404,1.852310061454773,10000.0,44922.07017183304,46485.04060125351,44922.07017183304,1554.3216423988342,4.150796890258789,0.0 -131700,3.4463413,1.1750442,,,,,,,,,,,,,, -131800,3.6053293,1.1030478,,,,,,,,,,,,,, -131900,3.5715892,1.040762,,,,,,,,,,,,,, -132000,3.3244996,1.0860524,,,,,,,,,,,,,, -132100,3.3920205,1.040605,,,,,,,,,,,,,, -132200,3.2564821,0.9763067,,,,,,,,,,,,,, -132300,3.0612314,0.93810546,,,,,,,,,,,,,, -132400,3.2488887,1.0030978,,,,,,,,,,,,,, -132500,3.6775906,1.0988808,,,,,,,,,,,,,, -132600,3.1845884,1.1279656,,,,,,,,,,,,,, -132700,3.5450873,1.0832899,,,,,,,,,,,,,, -132800,3.5008707,1.1391883,,,,,,,,,,,,,, -132900,3.4560764,1.0104847,,,,,,,,,,,,,, -133000,3.6140733,1.2035297,,,,,,,,,,,,,, -133100,3.14557,1.0222961,,,,,,,,,,,,,, -133178,,,0.8520009517669678,0.5243096351623535,0.7253999710083008,1.120864987373352,50000.0,0.5966000556945801,1.8334498405456543,10000.0,45432.17251110077,47012.60780596733,45432.17251110077,1571.6787357330322,4.2078258991241455,0.0 -133200,3.564938,1.0706711,,,,,,,,,,,,,, -133300,3.7078264,1.0650582,,,,,,,,,,,,,, -133400,3.5255196,0.9916733,,,,,,,,,,,,,, -133500,3.5140197,1.0392785,,,,,,,,,,,,,, -133600,3.19512,0.9988599,,,,,,,,,,,,,, -133700,3.215623,1.1097229,,,,,,,,,,,,,, -133800,3.0974543,0.95474076,,,,,,,,,,,,,, -133900,3.4876947,1.1448197,,,,,,,,,,,,,, -134000,3.3141472,1.0378205,,,,,,,,,,,,,, -134100,3.699643,1.0842996,,,,,,,,,,,,,, -134200,3.3775647,1.0067203,,,,,,,,,,,,,, -134300,3.2306113,1.031606,,,,,,,,,,,,,, -134400,3.1704047,1.0107241,,,,,,,,,,,,,, -134500,3.158417,1.0702603,,,,,,,,,,,,,, -134600,3.2428224,1.0499562,,,,,,,,,,,,,, -134674,,,0.8475764989852905,0.5381889343261719,0.7217999696731567,1.1407110691070557,50000.0,0.6020000576972961,1.8749887943267824,10000.0,45942.078404426575,47540.725546360016,45942.078404426575,1589.7894802093506,4.2574567794799805,0.0 -134700,3.3857472,1.0896178,,,,,,,,,,,,,, -134800,3.4706137,1.035669,,,,,,,,,,,,,, -134900,3.6538558,1.0394757,,,,,,,,,,,,,, -135000,3.9349642,1.0294377,,,,,,,,,,,,,, -135100,3.1297321,0.9741092,,,,,,,,,,,,,, -135200,3.3572342,0.9843364,,,,,,,,,,,,,, -135300,3.7193372,1.0955393,,,,,,,,,,,,,, -135400,3.30121,1.0310612,,,,,,,,,,,,,, -135500,3.2717028,0.92713267,,,,,,,,,,,,,, -135600,3.3651283,0.9686129,,,,,,,,,,,,,, -135700,3.3411028,1.0255752,,,,,,,,,,,,,, -135800,3.590238,1.0270733,,,,,,,,,,,,,, -135900,3.4650245,1.060673,,,,,,,,,,,,,, -136000,3.2876296,0.92884636,,,,,,,,,,,,,, -136100,3.3092115,1.0414089,,,,,,,,,,,,,, -136171,,,0.8484932780265808,0.5277658104896545,0.7238399982452393,1.1313154697418213,50000.0,0.5981000065803528,1.8592684268951416,10000.0,46452.042345047,48068.0526099205,46452.042345047,1607.0494379997251,4.309852600097656,0.0 -136200,3.5438023,0.9919723,,,,,,,,,,,,,, -136300,3.2144294,0.9420348,,,,,,,,,,,,,, -136400,3.3740854,0.98684597,,,,,,,,,,,,,, -136500,3.6310568,1.0380415,,,,,,,,,,,,,, -136600,3.6166306,1.1570698,,,,,,,,,,,,,, -136700,3.534601,1.0113846,,,,,,,,,,,,,, -136800,3.3743176,0.9999013,,,,,,,,,,,,,, -136900,3.4886684,1.0193596,,,,,,,,,,,,,, -137000,3.3953125,0.96931314,,,,,,,,,,,,,, -137100,3.7589462,1.0468471,,,,,,,,,,,,,, -137200,3.2827768,0.9475857,,,,,,,,,,,,,, -137300,3.5512738,1.0189526,,,,,,,,,,,,,, -137400,3.399923,0.9931118,,,,,,,,,,,,,, -137500,3.737546,0.96638525,,,,,,,,,,,,,, -137600,4.0066733,0.9950354,,,,,,,,,,,,,, -137668,,,0.865652859210968,0.4712446928024292,0.7237399816513062,1.1422827243804932,50000.0,0.5975000262260437,1.896036386489868,10000.0,46961.99141907692,48595.35859775543,46961.99141907692,1624.3033895492554,4.361900568008423,0.0 -137700,3.9400346,1.137856,,,,,,,,,,,,,, -137800,3.7724879,1.1664075,,,,,,,,,,,,,, -137900,3.8726356,1.0827602,,,,,,,,,,,,,, -138000,3.9522994,0.98128724,,,,,,,,,,,,,, -138100,3.8120894,0.97011584,,,,,,,,,,,,,, -138200,3.72476,1.0180957,,,,,,,,,,,,,, -138300,3.3440957,1.0306287,,,,,,,,,,,,,, -138400,3.7807806,1.1100383,,,,,,,,,,,,,, -138500,3.6722128,1.0855722,,,,,,,,,,,,,, -138600,3.476831,1.0089769,,,,,,,,,,,,,, -138700,3.817386,1.0101444,,,,,,,,,,,,,, -138800,4.214208,1.0457734,,,,,,,,,,,,,, -138900,3.8134205,1.0527934,,,,,,,,,,,,,, -139000,3.6228037,0.9450421,,,,,,,,,,,,,, -139100,3.7858467,1.0444874,,,,,,,,,,,,,, -139165,,,0.8843072056770325,0.402018278837204,0.7286999821662903,1.1118732690811155,50000.0,0.612500011920929,1.8532575368881223,10000.0,47472.01818680763,49122.931619644165,47472.01818680763,1641.7458896636963,4.415813446044922,0.0 -139200,3.505941,0.9499508,,,,,,,,,,,,,, -139300,3.674322,1.0682741,,,,,,,,,,,,,, -139400,3.8126516,1.0737882,,,,,,,,,,,,,, -139500,3.6106787,0.9533289,,,,,,,,,,,,,, -139600,3.4033306,0.9613422,,,,,,,,,,,,,, -139700,3.7245712,1.0154561,,,,,,,,,,,,,, -139800,3.4547706,0.9566981,,,,,,,,,,,,,, -139900,3.9479651,0.96830654,,,,,,,,,,,,,, -140000,3.6782894,0.96754915,,,,,,,,,,,,,, -140100,3.8030798,0.94241613,,,,,,,,,,,,,, -140200,3.9278822,0.984563,,,,,,,,,,,,,, -140300,3.5716383,0.9744535,,,,,,,,,,,,,, -140400,3.4868245,0.9404674,,,,,,,,,,,,,, -140500,3.7275996,1.0074413,,,,,,,,,,,,,, -140600,4.0733237,1.0068539,,,,,,,,,,,,,, -140663,,,0.87890625,0.4263043105602264,0.7313599586486816,1.112805724143982,50000.0,0.6103000044822693,1.8223932981491089,10000.0,47982.18217277527,49650.53896570206,47982.18217277527,1659.0822570323944,4.472891807556152,0.0 -140700,3.5844247,0.9758621,,,,,,,,,,,,,, -140800,3.9513454,0.9843794,,,,,,,,,,,,,, -140900,3.655174,0.979631,,,,,,,,,,,,,, -141000,3.3903627,0.9741724,,,,,,,,,,,,,, -141100,3.531453,0.9951796,,,,,,,,,,,,,, -141200,3.639102,0.9775017,,,,,,,,,,,,,, -141300,3.6464684,1.0120047,,,,,,,,,,,,,, -141400,3.5075228,0.9071498,,,,,,,,,,,,,, -141500,3.8240547,1.0110127,,,,,,,,,,,,,, -141600,3.6062973,0.9967426,,,,,,,,,,,,,, -141700,3.7448292,0.9163548,,,,,,,,,,,,,, -141800,3.6653101,0.96513665,,,,,,,,,,,,,, -141900,3.7370775,0.96235603,,,,,,,,,,,,,, -142000,3.3504364,0.9319668,,,,,,,,,,,,,, -142100,4.118897,1.0320666,,,,,,,,,,,,,, -142160,,,0.8776904940605164,0.4207326173782348,0.7316799759864807,1.1086206436157229,50000.0,0.6021000146865845,1.843847751617432,10000.0,48492.161371946335,50178.48799037933,48492.161371946335,1676.9486873149872,4.525469541549683,0.0 -142200,4.3130913,1.0404056,,,,,,,,,,,,,, -142300,3.8401427,0.991566,,,,,,,,,,,,,, -142400,3.6922922,1.001408,,,,,,,,,,,,,, -142500,4.1392107,0.9940578,,,,,,,,,,,,,, -142600,3.638213,0.96725667,,,,,,,,,,,,,, -142700,3.560027,1.0044606,,,,,,,,,,,,,, -142800,3.5365787,0.9007712,,,,,,,,,,,,,, -142900,3.8117561,0.88680494,,,,,,,,,,,,,, -143000,3.663263,1.0617352,,,,,,,,,,,,,, -143100,4.0845127,0.9499989,,,,,,,,,,,,,, -143200,3.5414348,0.9833539,,,,,,,,,,,,,, -143300,3.637455,0.9653195,,,,,,,,,,,,,, -143400,3.7716243,0.8708873,,,,,,,,,,,,,, -143500,3.6065054,0.87673825,,,,,,,,,,,,,, -143600,3.659963,0.9985691,,,,,,,,,,,,,, -143657,,,0.8729472160339355,0.4409449398517608,0.7316799759864807,1.1113321781158447,50000.0,0.6050000190734863,1.839349269866944,10000.0,49002.2856194973,50706.21445202828,49002.2856194973,1694.4568555355072,4.5699193477630615,0.0 -143700,3.7367675,0.9654778,,,,,,,,,,,,,, -143800,3.6731274,0.9088097,,,,,,,,,,,,,, -143900,3.7470427,0.9708385,,,,,,,,,,,,,, -144000,4.3013086,0.965671,,,,,,,,,,,,,, -144100,3.6607769,0.8855077,,,,,,,,,,,,,, -144200,3.8088803,0.95693874,,,,,,,,,,,,,, -144300,4.2377453,0.8603301,,,,,,,,,,,,,, -144400,3.6236558,0.83053434,,,,,,,,,,,,,, -144500,4.012858,0.93238425,,,,,,,,,,,,,, -144600,3.8167033,0.9336014,,,,,,,,,,,,,, -144700,3.881222,1.0544395,,,,,,,,,,,,,, -144800,4.0925508,0.9891342,,,,,,,,,,,,,, -144900,3.6840022,0.87654424,,,,,,,,,,,,,, -145000,3.4619334,0.8535518,,,,,,,,,,,,,, -145100,3.6523802,0.87244236,,,,,,,,,,,,,, -145154,,,0.8752989172935486,0.4280344545841217,0.7317799925804138,1.1089249849319458,50000.0,0.6053000092506409,1.8584744930267327,10000.0,49512.30458474159,51233.79352784157,49512.30458474159,1711.908765077591,4.625460624694824,0.0 -145200,4.2948523,0.9242867,,,,,,,,,,,,,, -145300,3.9027472,0.9052613,,,,,,,,,,,,,, -145400,3.5396922,0.8570324,,,,,,,,,,,,,, -145500,3.9003546,0.94192433,,,,,,,,,,,,,, -145600,3.6072798,0.88260734,,,,,,,,,,,,,, -145700,3.4130123,0.890203,,,,,,,,,,,,,, -145800,3.7217271,0.9044844,,,,,,,,,,,,,, -145900,4.00931,1.0166435,,,,,,,,,,,,,, -146000,4.2224574,0.9333901,,,,,,,,,,,,,, -146100,3.8438754,0.99108696,,,,,,,,,,,,,, -146200,4.4136653,0.95181364,,,,,,,,,,,,,, -146300,3.9433975,0.9042947,,,,,,,,,,,,,, -146400,3.676801,0.9243071,,,,,,,,,,,,,, -146500,3.712323,0.9204071,,,,,,,,,,,,,, -146600,4.1325307,0.9304824,,,,,,,,,,,,,, -146652,,,0.8880739808082581,0.3907930254936218,0.734279990196228,1.1064696311950684,50000.0,0.6035000085830688,1.850600242614746,10000.0,50022.35403752327,51761.16456365585,50022.35403752327,1729.127257347107,4.677754163742065,0.0 -146700,3.800831,0.97145444,,,,,,,,,,,,,, -146800,4.0082264,0.91198915,,,,,,,,,,,,,, -146900,3.9931178,0.843493,,,,,,,,,,,,,, -147000,3.8677187,0.9021566,,,,,,,,,,,,,, -147100,3.926332,0.94365454,,,,,,,,,,,,,, -147200,3.7863889,0.86327493,,,,,,,,,,,,,, -147300,3.7036173,0.93647546,,,,,,,,,,,,,, -147400,3.6707118,0.87305665,,,,,,,,,,,,,, -147500,3.7787886,0.78156316,,,,,,,,,,,,,, -147600,3.9304016,0.9170508,,,,,,,,,,,,,, -147700,4.042979,0.91341805,,,,,,,,,,,,,, -147800,3.6321666,0.8720412,,,,,,,,,,,,,, -147900,4.326065,0.9440747,,,,,,,,,,,,,, -148000,4.049543,0.919623,,,,,,,,,,,,,, -148100,3.9589272,0.90644705,,,,,,,,,,,,,, -148149,,,0.91019606590271,0.3179280161857605,0.7368199825286865,1.0996601581573486,50000.0,0.6094000339508057,1.84394109249115,10000.0,50532.34589600563,52288.738582372665,50532.34589600563,1746.6031787395475,4.732522487640381,0.0 -148200,3.7565484,0.8299259,,,,,,,,,,,,,, -148300,3.9829361,0.8350084,,,,,,,,,,,,,, -148400,4.297009,0.87325394,,,,,,,,,,,,,, -148500,3.7733889,0.8894722,,,,,,,,,,,,,, -148600,3.8705914,0.9085818,,,,,,,,,,,,,, -148700,3.9579651,0.9023409,,,,,,,,,,,,,, -148800,4.207268,0.8374436,,,,,,,,,,,,,, -148900,3.9241216,0.87673753,,,,,,,,,,,,,, -149000,3.8293066,0.87825686,,,,,,,,,,,,,, -149100,4.04095,0.93537176,,,,,,,,,,,,,, -149200,4.1487527,0.90582,,,,,,,,,,,,,, -149300,4.227079,0.833918,,,,,,,,,,,,,, -149400,3.818864,0.79499626,,,,,,,,,,,,,, -149500,4.189526,0.95465034,,,,,,,,,,,,,, -149600,3.7988086,0.8197812,,,,,,,,,,,,,, -149644,,,0.9052136540412904,0.3252921402454376,0.7401399612426758,1.093822717666626,50000.0,0.6175000071525574,1.8288391828536987,10000.0,51042.26584935188,52816.26387476921,51042.26584935188,1764.0998284816742,4.791525602340698,0.0 -149700,3.9618285,0.8987439,,,,,,,,,,,,,, -149800,4.29771,0.87376595,,,,,,,,,,,,,, -149900,4.228979,0.8533268,,,,,,,,,,,,,, -150000,3.8990705,0.84601474,,,,,,,,,,,,,, -150100,3.8418605,0.8519583,,,,,,,,,,,,,, -150200,3.7310948,0.8816612,,,,,,,,,,,,,, -150300,3.8092408,0.8709595,,,,,,,,,,,,,, -150400,4.0110154,0.8932278,,,,,,,,,,,,,, -150500,4.1505003,0.8585501,,,,,,,,,,,,,, -150600,4.0252585,0.8905357,,,,,,,,,,,,,, -150700,3.977112,0.7898412,,,,,,,,,,,,,, -150800,3.696499,0.8179253,,,,,,,,,,,,,, -150900,4.0491567,0.8211807,,,,,,,,,,,,,, -151000,3.946688,0.8589214,,,,,,,,,,,,,, -151100,4.0662503,0.852011,,,,,,,,,,,,,, -151141,,,0.906668484210968,0.3213539719581604,0.7433599829673767,1.087909460067749,50000.0,0.6135000586509705,1.828898549079895,10000.0,51552.15370512009,53343.76315808296,51552.15370512009,1781.5769836902618,4.874719858169556,0.0 -151200,4.174312,0.86143494,,,,,,,,,,,,,, -151300,4.218735,0.8916983,,,,,,,,,,,,,, -151400,4.1358757,0.8058597,,,,,,,,,,,,,, -151500,4.1243315,0.8670067,,,,,,,,,,,,,, -151600,4.1314893,0.8313041,,,,,,,,,,,,,, -151700,4.1738114,0.9357191,,,,,,,,,,,,,, -151800,4.4066997,0.9179017,,,,,,,,,,,,,, -151900,4.1631427,0.8750374,,,,,,,,,,,,,, -152000,4.019605,0.7774946,,,,,,,,,,,,,, -152100,3.8370874,0.8561543,,,,,,,,,,,,,, -152200,4.014764,0.90866494,,,,,,,,,,,,,, -152300,3.6849794,0.78098935,,,,,,,,,,,,,, -152400,4.217096,0.85289454,,,,,,,,,,,,,, -152500,3.637352,0.75932306,,,,,,,,,,,,,, -152600,3.8358028,0.8006141,,,,,,,,,,,,,, -152637,,,0.9010881781578064,0.3470576405525207,0.7376999855041504,1.1096341609954834,50000.0,0.6148000359535217,1.844693660736084,10000.0,52062.18849515915,53871.51702570915,52062.18849515915,1799.1908648014069,4.927303552627564,0.0 -152700,4.300736,0.793916,,,,,,,,,,,,,, -152800,3.6916208,0.75638855,,,,,,,,,,,,,, -152900,4.2086167,0.76994145,,,,,,,,,,,,,, -153000,4.1020575,0.8501176,,,,,,,,,,,,,, -153100,4.0396857,0.775658,,,,,,,,,,,,,, -153200,4.207517,0.72209555,,,,,,,,,,,,,, -153300,3.7772686,0.7899767,,,,,,,,,,,,,, -153400,4.0058002,0.87295514,,,,,,,,,,,,,, -153500,4.124529,0.84815526,,,,,,,,,,,,,, -153600,5.200186,0.8831912,,,,,,,,,,,,,, -153700,4.202195,0.9850246,,,,,,,,,,,,,, -153800,3.8770058,0.7880427,,,,,,,,,,,,,, -153900,4.115392,0.78703,,,,,,,,,,,,,, -154000,3.9836476,0.8474846,,,,,,,,,,,,,, -154100,4.3550515,0.85563254,,,,,,,,,,,,,, -154134,,,0.9094188213348388,0.3132840991020202,0.7424799799919128,1.0863618850708008,50000.0,0.6132000088691711,1.846460461616516,10000.0,52572.27054524422,54398.92409610748,52572.27054524422,1816.406150817871,4.986354112625122,0.0 -154200,4.588338,0.8928118,,,,,,,,,,,,,, -154300,4.4885283,0.8609849,,,,,,,,,,,,,, -154400,4.442909,0.8584989,,,,,,,,,,,,,, -154500,3.97657,0.76975685,,,,,,,,,,,,,, -154600,4.7003465,0.9248173,,,,,,,,,,,,,, -154700,4.083647,0.81305265,,,,,,,,,,,,,, -154800,4.0034904,0.84737927,,,,,,,,,,,,,, -154900,4.1576447,0.81705236,,,,,,,,,,,,,, -155000,4.438791,0.8601864,,,,,,,,,,,,,, -155100,3.9717028,0.8285119,,,,,,,,,,,,,, -155200,4.353615,0.80912226,,,,,,,,,,,,,, -155300,4.1146255,0.87304485,,,,,,,,,,,,,, -155400,3.9808621,0.8631217,,,,,,,,,,,,,, -155500,4.2785907,0.7556163,,,,,,,,,,,,,, -155600,3.9897308,0.8446511,,,,,,,,,,,,,, -155632,,,0.9090999364852904,0.3144540786743164,0.7421799898147583,1.0850337743759155,50000.0,0.6140000224113464,1.8501039743423464,10000.0,53082.467106580734,54926.58153581619,53082.467106580734,1833.7602362632751,5.041287660598755,0.0 -155700,4.0847454,0.781983,,,,,,,,,,,,,, -155800,3.883737,0.70724,,,,,,,,,,,,,, -155900,3.9925168,0.7930005,,,,,,,,,,,,,, -156000,4.082914,0.8346767,,,,,,,,,,,,,, -156100,4.1339664,0.78304815,,,,,,,,,,,,,, -156200,4.20773,0.90230775,,,,,,,,,,,,,, -156300,4.3793135,0.8108915,,,,,,,,,,,,,, -156400,3.90236,0.76615304,,,,,,,,,,,,,, -156500,4.0652056,0.7506853,,,,,,,,,,,,,, -156600,4.0133758,0.7595734,,,,,,,,,,,,,, -156700,4.325717,0.8274076,,,,,,,,,,,,,, -156800,4.0206027,0.7638926,,,,,,,,,,,,,, -156900,3.8579729,0.7648171,,,,,,,,,,,,,, -157000,4.1631446,0.7191709,,,,,,,,,,,,,, -157100,3.9381173,0.8055858,,,,,,,,,,,,,, -157129,,,0.938257336616516,0.2249719649553299,0.7435799837112427,1.074958086013794,50000.0,0.6165000200271606,1.826570749282837,10000.0,53592.5631840229,55454.15119338036,53592.5631840229,1851.127019643784,5.0969154834747314,0.0 -157200,4.5011387,0.79746616,,,,,,,,,,,,,, -157300,3.8833928,0.71690655,,,,,,,,,,,,,, -157400,4.070217,0.8064951,,,,,,,,,,,,,, -157500,4.7344565,0.7497354,,,,,,,,,,,,,, -157600,4.2840033,0.75057524,,,,,,,,,,,,,, -157700,4.504856,0.78454065,,,,,,,,,,,,,, -157800,3.7581146,0.7255819,,,,,,,,,,,,,, -157900,4.398634,0.7676157,,,,,,,,,,,,,, -158000,4.243075,0.78485703,,,,,,,,,,,,,, -158100,4.1042924,0.8209023,,,,,,,,,,,,,, -158200,4.390494,0.8100034,,,,,,,,,,,,,, -158300,4.4900074,0.9001746,,,,,,,,,,,,,, -158400,4.147064,0.8437604,,,,,,,,,,,,,, -158500,4.1393447,0.71950114,,,,,,,,,,,,,, -158600,4.292942,0.7051019,,,,,,,,,,,,,, -158626,,,0.9301857352256776,0.2437449693679809,0.7449399828910828,1.078855276107788,50000.0,0.6187000274658203,1.83165442943573,10000.0,54102.53592252731,55981.42848825455,54102.53592252731,1868.317990541458,5.159096002578735,0.0 -158700,4.346642,0.7981058,,,,,,,,,,,,,, -158800,4.129042,0.7972657,,,,,,,,,,,,,, -158900,4.606745,0.7758579,,,,,,,,,,,,,, -159000,4.219712,0.7730088,,,,,,,,,,,,,, -159100,4.2743783,0.8440826,,,,,,,,,,,,,, -159200,4.0905805,0.8345995,,,,,,,,,,,,,, -159300,4.2299433,0.78598,,,,,,,,,,,,,, -159400,4.5118766,0.8607979,,,,,,,,,,,,,, -159500,4.3260283,0.75170845,,,,,,,,,,,,,, -159600,4.2598424,0.70421034,,,,,,,,,,,,,, -159700,4.641332,0.7840891,,,,,,,,,,,,,, -159800,4.1653366,0.75993896,,,,,,,,,,,,,, -159900,4.522774,0.7956832,,,,,,,,,,,,,, -160000,4.0915275,0.69204247,,,,,,,,,,,,,, -160100,4.445397,0.79246676,,,,,,,,,,,,,, -160123,,,0.9317004084587096,0.2384717613458633,0.7465999722480774,1.0716962814331057,50000.0,0.622700035572052,1.8242336511611936,10000.0,54612.44617629051,56508.81551861763,54612.44617629051,1885.6868290901184,5.215553045272827,0.0 -160200,4.1314125,0.83110243,,,,,,,,,,,,,, -160300,4.4811487,0.7553853,,,,,,,,,,,,,, -160400,4.56272,0.751708,,,,,,,,,,,,,, -160500,4.603901,0.801677,,,,,,,,,,,,,, -160600,4.2152147,0.82339185,,,,,,,,,,,,,, -160700,4.558986,0.8186666,,,,,,,,,,,,,, -160800,4.814636,0.7888733,,,,,,,,,,,,,, -160900,4.3703923,0.8458353,,,,,,,,,,,,,, -161000,4.5318546,0.69605434,,,,,,,,,,,,,, -161100,4.516921,0.7999483,,,,,,,,,,,,,, -161200,4.1559296,0.69868386,,,,,,,,,,,,,, -161300,4.568876,0.81759113,,,,,,,,,,,,,, -161400,4.7607903,0.7184261,,,,,,,,,,,,,, -161500,4.7200937,0.80501175,,,,,,,,,,,,,, -161600,4.415658,0.77932286,,,,,,,,,,,,,, -161619,,,0.9328961968421936,0.2328613251447677,0.7475000023841858,1.067963480949402,50000.0,0.6202000379562378,1.837498545646668,10000.0,55122.4037668705,57036.08502626419,55122.4037668705,1902.8871002197263,5.275780916213989,0.0 -161700,4.1883183,0.6643797,,,,,,,,,,,,,, -161800,4.287649,0.672707,,,,,,,,,,,,,, -161900,4.227466,0.6917946,,,,,,,,,,,,,, -162000,4.6237383,0.7590244,,,,,,,,,,,,,, -162100,4.538111,0.71277773,,,,,,,,,,,,,, -162200,4.414987,0.75174963,,,,,,,,,,,,,, -162300,4.483796,0.82244456,,,,,,,,,,,,,, -162400,4.4272647,0.6892611,,,,,,,,,,,,,, -162500,4.2260256,0.7165599,,,,,,,,,,,,,, -162600,4.486923,0.68775856,,,,,,,,,,,,,, -162700,4.621868,0.72888005,,,,,,,,,,,,,, -162800,5.068894,0.78371185,,,,,,,,,,,,,, -162900,3.9926858,0.6619686,,,,,,,,,,,,,, -163000,4.05305,0.7669276,,,,,,,,,,,,,, -163100,4.6194677,0.7434881,,,,,,,,,,,,,, -163115,,,0.9344307780265808,0.2265849560499191,0.7470999956130981,1.070637583732605,50000.0,0.6237000226974487,1.835316061973572,10000.0,55632.31921625137,57563.64182114601,55632.31921625137,1920.4235010147088,5.330445289611816,0.0 -163200,4.2242837,0.7888558,,,,,,,,,,,,,, -163300,4.4502263,0.7752535,,,,,,,,,,,,,, -163400,4.5444956,0.71794707,,,,,,,,,,,,,, -163500,4.4279895,0.72232264,,,,,,,,,,,,,, -163600,4.1125855,0.65284735,,,,,,,,,,,,,, -163700,4.781735,0.7653565,,,,,,,,,,,,,, -163800,4.3230734,0.68268776,,,,,,,,,,,,,, -163900,4.1819916,0.70377177,,,,,,,,,,,,,, -164000,4.3042245,0.7529171,,,,,,,,,,,,,, -164100,4.6106057,0.72119725,,,,,,,,,,,,,, -164200,4.122389,0.65957123,,,,,,,,,,,,,, -164300,4.4029264,0.73260844,,,,,,,,,,,,,, -164400,4.971691,0.80292135,,,,,,,,,,,,,, -164500,4.3158813,0.71310407,,,,,,,,,,,,,, -164600,4.5275364,0.7053787,,,,,,,,,,,,,, -164613,,,0.9340322017669678,0.2242102921009063,0.7490800023078918,1.070961356163025,50000.0,0.622700035572052,1.835475325584412,10000.0,56142.43895721436,58091.12245035172,56142.43895721436,1937.6732861995697,5.390793085098267,0.0 -164700,4.402791,0.71378267,,,,,,,,,,,,,, -164800,4.529501,0.73205125,,,,,,,,,,,,,, -164900,4.4427004,0.7139983,,,,,,,,,,,,,, -165000,4.292343,0.7343657,,,,,,,,,,,,,, -165100,4.2709503,0.7216724,,,,,,,,,,,,,, -165200,5.084896,0.7089912,,,,,,,,,,,,,, -165300,4.3934703,0.6835663,,,,,,,,,,,,,, -165400,4.3680844,0.74615467,,,,,,,,,,,,,, -165500,4.3466005,0.7331326,,,,,,,,,,,,,, -165600,4.4742947,0.73192835,,,,,,,,,,,,,, -165700,4.4359326,0.6934875,,,,,,,,,,,,,, -165800,4.573456,0.7198682,,,,,,,,,,,,,, -165900,4.1730747,0.6926231,,,,,,,,,,,,,, -166000,4.717559,0.6760255,,,,,,,,,,,,,, -166100,4.3863664,0.65254796,,,,,,,,,,,,,, -166109,,,0.9427216053009032,0.2048528641462326,0.7487599849700928,1.0646597146987915,50000.0,0.624500036239624,1.8405344486236568,10000.0,56652.48220562935,58618.7962179184,56652.48220562935,1955.1959567070007,5.446483373641968,0.0 -166200,4.176003,0.71644914,,,,,,,,,,,,,, -166300,4.4600654,0.74332947,,,,,,,,,,,,,, -166400,4.7646937,0.75058436,,,,,,,,,,,,,, -166500,4.7843494,0.68458265,,,,,,,,,,,,,, -166600,4.2892394,0.6626175,,,,,,,,,,,,,, -166700,4.401239,0.7685129,,,,,,,,,,,,,, -166800,4.4839363,0.6762547,,,,,,,,,,,,,, -166900,4.5307527,0.66092324,,,,,,,,,,,,,, -167000,4.455664,0.6971523,,,,,,,,,,,,,, -167100,4.420135,0.7660048,,,,,,,,,,,,,, -167200,3.9816165,0.67172354,,,,,,,,,,,,,, -167300,4.4806333,0.7181811,,,,,,,,,,,,,, -167400,4.5105133,0.6949434,,,,,,,,,,,,,, -167500,4.4061494,0.6789198,,,,,,,,,,,,,, -167600,4.31692,0.70098484,,,,,,,,,,,,,, -167606,,,0.9520288109779358,0.1748383939266204,0.750499963760376,1.0643622875213623,50000.0,0.6233000159263611,1.841591119766236,10000.0,57162.58238697052,59146.30214428902,57162.58238697052,1972.493607759476,5.505061149597168,0.0 -167700,4.488573,0.6927943,,,,,,,,,,,,,, -167800,4.6002164,0.6715544,,,,,,,,,,,,,, -167900,4.6544023,0.66735923,,,,,,,,,,,,,, -168000,4.046949,0.6159772,,,,,,,,,,,,,, -168100,4.7411256,0.67362237,,,,,,,,,,,,,, -168200,4.656208,0.7398782,,,,,,,,,,,,,, -168300,4.463768,0.68187886,,,,,,,,,,,,,, -168400,4.4912143,0.6603078,,,,,,,,,,,,,, -168500,4.445964,0.7173048,,,,,,,,,,,,,, -168600,4.657412,0.7239059,,,,,,,,,,,,,, -168700,4.6104927,0.6803927,,,,,,,,,,,,,, -168800,4.682063,0.6863189,,,,,,,,,,,,,, -168900,4.769768,0.7046894,,,,,,,,,,,,,, -169000,4.825666,0.7450137,,,,,,,,,,,,,, -169100,4.497211,0.7118938,,,,,,,,,,,,,, -169103,,,0.9522680044174194,0.1750165224075317,0.750819981098175,1.0584666728973389,50000.0,0.6274000406265259,1.835326075553894,10000.0,57672.57480549812,59673.71718621254,57672.57480549812,1989.808545589447,5.561115026473999,0.0 -169200,4.7812276,0.6354982,,,,,,,,,,,,,, -169300,4.4287105,0.7668587,,,,,,,,,,,,,, -169400,4.609788,0.6363957,,,,,,,,,,,,,, -169500,4.601023,0.6663935,,,,,,,,,,,,,, -169600,4.251597,0.6703556,,,,,,,,,,,,,, -169700,4.0483046,0.6933723,,,,,,,,,,,,,, -169800,4.2817984,0.72867525,,,,,,,,,,,,,, -169900,4.434033,0.67904526,,,,,,,,,,,,,, -170000,5.030926,0.7642735,,,,,,,,,,,,,, -170100,4.2117906,0.6596967,,,,,,,,,,,,,, -170200,4.6899667,0.67453784,,,,,,,,,,,,,, -170300,4.480538,0.66240054,,,,,,,,,,,,,, -170400,4.409077,0.65347683,,,,,,,,,,,,,, -170500,4.5895452,0.67162496,,,,,,,,,,,,,, -170600,,,0.9530253410339355,0.1685916483402252,0.7524200081825256,1.058369517326355,50000.0,0.6294000148773193,1.825202226638794,10000.0,58182.56483054161,60201.29833936691,58182.56483054161,2007.2889399528503,5.620239496231079,0.0 -170600,5.162607,0.69560826,,,,,,,,,,,,,, -170700,4.7870173,0.71928775,,,,,,,,,,,,,, -170800,4.4374943,0.7105377,,,,,,,,,,,,,, -170900,4.5620036,0.75211424,,,,,,,,,,,,,, -171000,4.367043,0.7282802,,,,,,,,,,,,,, -171100,4.5778008,0.6516659,,,,,,,,,,,,,, -171200,4.483098,0.6728104,,,,,,,,,,,,,, -171300,4.751003,0.6776528,,,,,,,,,,,,,, -171400,4.3577495,0.7050687,,,,,,,,,,,,,, -171500,4.430912,0.60172904,,,,,,,,,,,,,, -171600,4.3244247,0.6131255,,,,,,,,,,,,,, -171700,4.671634,0.74967927,,,,,,,,,,,,,, -171800,4.724899,0.63778746,,,,,,,,,,,,,, -171900,4.524671,0.5954843,,,,,,,,,,,,,, -172000,4.549208,0.635388,,,,,,,,,,,,,, -172097,,,0.951809585094452,0.1718765199184417,0.7515000104904175,1.0592249631881714,50000.0,0.628600001335144,1.8379465341567995,10000.0,58692.67744731903,60728.97080159187,58692.67744731903,2024.735008716584,5.683067798614502,0.0 -172100,4.353446,0.7105482,,,,,,,,,,,,,, -172200,4.9794016,0.67579997,,,,,,,,,,,,,, -172300,4.7902846,0.7078486,,,,,,,,,,,,,, -172400,4.6752567,0.6526461,,,,,,,,,,,,,, -172500,4.9055505,0.6597793,,,,,,,,,,,,,, -172600,4.3944383,0.66195583,,,,,,,,,,,,,, -172700,4.3533573,0.68240315,,,,,,,,,,,,,, -172800,4.487922,0.6649285,,,,,,,,,,,,,, -172900,4.5347643,0.6479068,,,,,,,,,,,,,, -173000,4.674502,0.6176219,,,,,,,,,,,,,, -173100,4.3262906,0.6676362,,,,,,,,,,,,,, -173200,4.6870785,0.64062864,,,,,,,,,,,,,, -173300,4.488788,0.637255,,,,,,,,,,,,,, -173400,4.564117,0.6256132,,,,,,,,,,,,,, -173500,4.278651,0.59345365,,,,,,,,,,,,,, -173594,,,0.9528858065605164,0.1682156473398208,0.754040002822876,1.060206651687622,50000.0,0.6281000375747681,1.8330553770065308,10000.0,59202.68797492981,61257.06341433525,59202.68797492981,2042.6753525733948,5.773506164550781,0.0 -173600,4.5090275,0.65111053,,,,,,,,,,,,,, -173700,4.6376705,0.69765466,,,,,,,,,,,,,, -173800,4.2204294,0.63279486,,,,,,,,,,,,,, -173900,4.3224845,0.5812583,,,,,,,,,,,,,, -174000,4.213974,0.5692927,,,,,,,,,,,,,, -174100,5.671104,0.67709255,,,,,,,,,,,,,, -174200,4.7572083,0.7757499,,,,,,,,,,,,,, -174300,4.7937107,0.65043646,,,,,,,,,,,,,, -174400,4.5721145,0.62835264,,,,,,,,,,,,,, -174500,4.6260376,0.7020123,,,,,,,,,,,,,, -174600,4.381333,0.59634614,,,,,,,,,,,,,, -174700,5.0119176,0.6671532,,,,,,,,,,,,,, -174800,5.1413116,0.66072476,,,,,,,,,,,,,, -174900,4.626657,0.6525945,,,,,,,,,,,,,, -175000,4.7150736,0.6458354,,,,,,,,,,,,,, -175090,,,0.9554169178009032,0.1644298136234283,0.754040002822876,1.0577527284622192,50000.0,0.6284000277519226,1.82981026172638,10000.0,59712.79919052124,61784.8630464077,59712.79919052124,2060.2560591697693,5.830533027648926,0.0 -175100,4.6773424,0.6589639,,,,,,,,,,,,,, -175200,4.5843306,0.74764097,,,,,,,,,,,,,, -175300,4.8775854,0.62999237,,,,,,,,,,,,,, -175400,4.811733,0.67723435,,,,,,,,,,,,,, -175500,4.3309345,0.68482375,,,,,,,,,,,,,, -175600,4.4601483,0.63237345,,,,,,,,,,,,,, -175700,5.0363445,0.6719001,,,,,,,,,,,,,, -175800,4.7000256,0.65501404,,,,,,,,,,,,,, -175900,4.273007,0.6195793,,,,,,,,,,,,,, -176000,4.6366096,0.7717829,,,,,,,,,,,,,, -176100,4.620857,0.60023093,,,,,,,,,,,,,, -176200,4.6048117,0.5936533,,,,,,,,,,,,,, -176300,4.7042317,0.73050284,,,,,,,,,,,,,, -176400,4.7439966,0.6369678,,,,,,,,,,,,,, -176500,4.6328273,0.6166836,,,,,,,,,,,,,, -176587,,,0.960558831691742,0.1476965844631195,0.7540599703788757,1.0551767349243164,50000.0,0.626800000667572,1.835967302322388,10000.0,60222.87142467499,62312.36273407936,60222.87142467499,2077.571899175644,5.891931772232056,0.0 -176600,4.7188916,0.6242534,,,,,,,,,,,,,, -176700,4.499896,0.5918674,,,,,,,,,,,,,, -176800,4.8668213,0.6962367,,,,,,,,,,,,,, -176900,4.7810645,0.58939266,,,,,,,,,,,,,, -177000,4.7004957,0.60363525,,,,,,,,,,,,,, -177100,4.3760386,0.58429986,,,,,,,,,,,,,, -177200,4.468282,0.68953675,,,,,,,,,,,,,, -177300,4.5107684,0.6511357,,,,,,,,,,,,,, -177400,4.554007,0.5577088,,,,,,,,,,,,,, -177500,4.7057886,0.63874936,,,,,,,,,,,,,, -177600,4.7861276,0.6373237,,,,,,,,,,,,,, -177700,4.567109,0.5868473,,,,,,,,,,,,,, -177800,4.260703,0.6129661,,,,,,,,,,,,,, -177900,4.1298127,0.57573134,,,,,,,,,,,,,, -178000,4.1382685,0.5941548,,,,,,,,,,,,,, -178084,,,0.959582269191742,0.1489991694688797,0.7542799711227417,1.0563178062438965,50000.0,0.6317000389099121,1.831668496131897,10000.0,60733.02505970001,62840.026733636856,60733.02505970001,2094.967592716217,5.955878973007202,0.0 -178100,3.942622,0.5438282,,,,,,,,,,,,,, -178200,4.295761,0.57352823,,,,,,,,,,,,,, -178300,4.539767,0.6622931,,,,,,,,,,,,,, -178400,4.361281,0.5857015,,,,,,,,,,,,,, -178500,4.3856893,0.60097766,,,,,,,,,,,,,, -178600,4.495563,0.6380555,,,,,,,,,,,,,, -178700,4.164352,0.5713927,,,,,,,,,,,,,, -178800,4.316634,0.64832103,,,,,,,,,,,,,, -178900,5.0726323,0.61872524,,,,,,,,,,,,,, -179000,4.0158653,0.60397583,,,,,,,,,,,,,, -179100,3.949405,0.52206683,,,,,,,,,,,,,, -179200,4.5483584,0.6532801,,,,,,,,,,,,,, -179300,4.6080375,0.5632093,,,,,,,,,,,,,, -179400,4.944098,0.6511069,,,,,,,,,,,,,, -179500,5.0885725,0.6617246,,,,,,,,,,,,,, -179581,,,0.9616549611091614,0.144578143954277,0.7544199824333191,1.0557457208633425,50000.0,0.6299000382423401,1.83180034160614,10000.0,61242.96651220322,63367.51571488381,61242.96651220322,2112.402285337448,6.015355348587036,0.0 -179600,4.9992766,0.61458147,,,,,,,,,,,,,, -179700,3.976417,0.60586965,,,,,,,,,,,,,, -179800,4.4034204,0.65490586,,,,,,,,,,,,,, -179900,4.917289,0.6295172,,,,,,,,,,,,,, -180000,4.9468355,0.65373486,,,,,,,,,,,,,, -180100,4.791168,0.6923215,,,,,,,,,,,,,, -180200,4.7320385,0.62717795,,,,,,,,,,,,,, -180300,4.331787,0.5867236,,,,,,,,,,,,,, -180400,4.854414,0.60933983,,,,,,,,,,,,,, -180500,4.6135273,0.6685805,,,,,,,,,,,,,, -180600,4.9219317,0.60287285,,,,,,,,,,,,,, -180700,4.3875055,0.5979059,,,,,,,,,,,,,, -180800,4.6714683,0.63244057,,,,,,,,,,,,,, -180900,4.5686398,0.5872632,,,,,,,,,,,,,, -181000,5.5599604,0.65527105,,,,,,,,,,,,,, -181078,,,0.9592434167861938,0.1504308581352234,0.7547599673271179,1.052348494529724,50000.0,0.629800021648407,1.8307338953018188,10000.0,61753.01442456245,63894.78460788727,61753.01442456245,2129.514147043228,6.073501348495483,0.0 -181100,4.45743,0.6150969,,,,,,,,,,,,,, -181200,4.810137,0.6250636,,,,,,,,,,,,,, -181300,4.534045,0.6299401,,,,,,,,,,,,,, -181400,4.486145,0.6672805,,,,,,,,,,,,,, -181500,4.7319655,0.6735058,,,,,,,,,,,,,, -181600,4.264748,0.5994017,,,,,,,,,,,,,, -181700,4.2102156,0.5750423,,,,,,,,,,,,,, -181800,4.7396235,0.6874394,,,,,,,,,,,,,, -181900,4.3580065,0.63467103,,,,,,,,,,,,,, -182000,4.521533,0.6596458,,,,,,,,,,,,,, -182100,4.6554565,0.62936556,,,,,,,,,,,,,, -182200,4.177913,0.62532276,,,,,,,,,,,,,, -182300,5.059301,0.66008747,,,,,,,,,,,,,, -182400,4.5090137,0.60405016,,,,,,,,,,,,,, -182500,4.605483,0.5868476,,,,,,,,,,,,,, -182575,,,0.9605388641357422,0.1476653218269348,0.7547999620437622,1.0509791374206543,50000.0,0.6289000511169434,1.8282490968704224,10000.0,62263.06065821648,64422.20252633095,62263.06065821648,2146.771764278412,6.137564897537232,0.0 -182600,4.236797,0.569309,,,,,,,,,,,,,, -182700,4.439651,0.6452432,,,,,,,,,,,,,, -182800,4.519424,0.6121291,,,,,,,,,,,,,, -182900,4.2431717,0.56266403,,,,,,,,,,,,,, -183000,4.6766953,0.6462455,,,,,,,,,,,,,, -183100,4.487564,0.5877344,,,,,,,,,,,,,, -183200,4.44255,0.5960672,,,,,,,,,,,,,, -183300,4.8165264,0.67639154,,,,,,,,,,,,,, -183400,4.524183,0.57042134,,,,,,,,,,,,,, -183500,4.171738,0.5802269,,,,,,,,,,,,,, -183600,4.410167,0.59285235,,,,,,,,,,,,,, -183700,4.4320126,0.59805995,,,,,,,,,,,,,, -183800,4.399201,0.7135885,,,,,,,,,,,,,, -183900,4.2033844,0.63344437,,,,,,,,,,,,,, -184000,4.4141626,0.6123608,,,,,,,,,,,,,, -184073,,,0.9612962007522584,0.1437268257141113,0.7549600005149841,1.0522350072860718,50000.0,0.6287000179290771,1.829309105873108,10000.0,62773.27559399605,64949.84421133995,62773.27559399605,2164.082456827164,6.202368974685669,0.0 -184100,4.3516574,0.5869482,,,,,,,,,,,,,, -184200,4.2211967,0.6215924,,,,,,,,,,,,,, -184300,4.243239,0.5865504,,,,,,,,,,,,,, -184400,4.505526,0.6685689,,,,,,,,,,,,,, -184500,4.882505,0.67983365,,,,,,,,,,,,,, -184600,4.3359,0.63186646,,,,,,,,,,,,,, -184700,4.175481,0.6218714,,,,,,,,,,,,,, -184763,,,,,,,,,,,63008.12818527222,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index fb2032b76..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.2264838218689,0.0,42.74321413040161,1,0,42.74321413040161,0.0010000000474974,6.907756805419922,10000,82.96978783607483,0.0009179687476716,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -61.97929525375366,0.0289614200592041,463.12793254852295,897,0,463.12793254852295,0.009400000795722,6.486231803894043,10000,525.1827142238617,0.0131445312872529,6.433816909790039,0.0123199997469782,6.444516181945801,50000 -83.38200998306274,0.0567998886108398,883.3025920391083,1840,0,883.3025920391083,0.0334000028669834,5.988006591796875,10000,966.8372595310212,0.0418945290148258,5.843362808227539,0.0398399978876113,5.8705878257751465,50000 -105.17406916618349,0.0844464302062988,1303.642668247223,2785,0,1303.642668247223,0.0494000017642974,5.625874996185303,10000,1409.0454559326172,0.072910152375698,5.394949436187744,0.066040001809597,5.446426868438721,50000 -126.95962452888487,0.1133489608764648,1723.5683901309967,3730,0,1723.5683901309967,0.0754000023007392,5.3332037925720215,10000,1850.8338613510127,0.1057617142796516,5.051905155181885,0.0959599986672401,5.101688385009766,50000 -148.88687324523926,0.1449859142303466,2143.667026281357,4667,0,2143.667026281357,0.0970000028610229,5.055785179138184,10000,2292.9393298625946,0.1399023383855819,4.692224025726318,0.1293399930000305,4.748983860015869,50000 -170.83421778678894,0.1749160289764404,2563.898540019989,5607,0,2563.898540019989,0.1337999999523162,4.728091716766357,10000,2735.196244239807,0.1944335848093032,4.300422191619873,0.1761399954557418,4.393296718597412,50000 -194.42507314682007,0.2019352912902832,2984.3140094280243,6544,0,2984.3140094280243,0.1663000136613845,4.450675010681152,10000,3179.2775950431824,0.2354101538658142,3.982255458831787,0.2187999933958053,4.063016891479492,50000 -225.34209632873527,0.2349696159362793,3404.61616230011,7483,0,3404.61616230011,0.1911000162363052,4.25723123550415,10000,3630.577404737473,0.2684960961341858,3.741316556930542,0.2501599788665771,3.8385329246521,50000 -250.3235805034637,0.2688426971435547,3824.698988676071,8423,0,3824.698988676071,0.2204000055789947,4.016154766082764,10000,4075.722831964493,0.3101952970027923,3.4238476753234863,0.2865000069141388,3.551166534423828,50000 -277.6685001850128,0.2980084419250488,4244.6240670681,9362,0,4244.6240670681,0.2503000199794769,3.823859453201294,10000,4523.069453001022,0.3661327958106994,3.119126558303833,0.3231000006198883,3.3199212551116943,50000 -308.2666883468628,0.3484940528869629,4664.77720451355,10299,0,4664.77720451355,0.2657000124454498,3.697940826416016,10000,4973.917723178864,0.3646875023841858,3.0610296726226807,0.3439199924468994,3.1753134727478027,50000 -336.04639506340027,0.3790206909179687,5085.150776386261,11238,0,5085.150776386261,0.2824000120162964,3.593581438064575,10000,5422.149684429169,0.3969140648841858,2.903132915496826,0.3653799891471863,3.0522048473358154,50000 -363.8884010314941,0.4129703044891357,5505.0947265625,12177,0,5505.0947265625,0.3043000102043152,3.429219007492065,10000,5870.017268896103,0.4346679449081421,2.682563304901123,0.3948400020599365,2.876403570175171,50000 -393.6926965713501,0.4578158855438232,5925.405028104782,13114,0,5925.405028104782,0.3204000294208526,3.36568021774292,10000,6320.223189115524,0.4341992139816284,2.6847686767578125,0.4049800038337707,2.8233370780944824,50000 -422.5653555393219,0.495150089263916,6345.643961429596,14045,0,6345.643961429596,0.3206000030040741,3.385582685470581,10000,6769.418182611465,0.4446093738079071,2.6638307571411133,0.4107399880886078,2.8218674659729004,50000 -451.4034585952759,0.5285844802856445,6765.623802185059,14974,0,6765.623802185059,0.3350000083446502,3.2387516498565674,10000,7218.315635204315,0.4664062261581421,2.479521036148072,0.4297599792480469,2.666552782058716,50000 -480.7165772914887,0.5597198009490967,7185.726698637009,15908,0,7185.726698637009,0.3364000022411346,3.2438104152679443,10000,7667.809241294861,0.4827929437160492,2.4374914169311523,0.4357599914073944,2.665578365325928,50000 -509.6156742572784,0.5903451442718506,7605.864891529083,16837,0,7605.864891529083,0.3495000302791595,3.1816678047180176,10000,8116.923457622528,0.4757421910762787,2.441497802734375,0.4432799816131592,2.600684881210327,50000 -540.8512902259827,0.6215465068817139,8025.931247711182,17761,0,8025.931247711182,0.3516000211238861,3.1301140785217285,10000,8568.302577495575,0.4897070229053497,2.3479857444763184,0.4527199864387512,2.532827854156494,50000 -571.6251528263092,0.6539442539215088,8445.988456726074,18693,0,8445.988456726074,0.3596000075340271,3.1122446060180664,10000,9019.213403463364,0.5077343583106995,2.280947208404541,0.4601799845695495,2.506450653076172,50000 -602.7466416358948,0.687971830368042,8866.089724779129,19624,0,8866.089724779129,0.3666000068187713,3.0513882637023926,10000,9470.51671743393,0.5035741925239563,2.2893991470336914,0.4703999757766723,2.448329448699951,50000 -633.781985282898,0.7244741916656494,9286.414787769318,20553,0,9286.414787769318,0.3762000203132629,2.9945783615112305,10000,9921.960839271544,0.5167187452316284,2.226701259613037,0.4785199761390686,2.39848256111145,50000 -666.2311322689056,0.7581882476806641,9706.38515138626,21480,0,9706.38515138626,0.3792000114917755,2.9629838466644287,10000,10374.46075630188,0.5309960842132568,2.138967514038086,0.4841599762439728,2.3563876152038574,50000 -696.6076049804688,0.7909915447235107,10126.384377241136,22413,0,10126.384377241136,0.3786000311374664,2.9851624965667725,10000,10824.917355537416,0.5240234136581421,2.189439535140991,0.4843199849128723,2.373043298721313,50000 -727.5011661052704,0.8213632106781006,10546.718402862549,23351,0,10546.718402862549,0.3881000280380249,2.9460103511810303,10000,11276.222237586975,0.5264257788658142,2.1538989543914795,0.4958399832248688,2.312279224395752,50000 -757.1414499282837,0.8539583683013916,10966.702552080154,24285,0,10966.702552080154,0.3977000117301941,2.882901906967163,10000,11725.926349878311,0.5478906035423279,2.078741788864136,0.5089600086212158,2.2755093574523926,50000 -789.5124342441559,0.8832418918609619,11386.934196472168,25218,0,11386.934196472168,0.4005000293254852,2.87434720993042,10000,12178.604896306992,0.5692187547683716,1.966854214668274,0.5087199807167053,2.25063705444336,50000 -820.4644253253937,0.912727117538452,11806.891574144363,26143,0,11806.891574144363,0.4005000293254852,2.883815288543701,10000,12629.590085029602,0.546679675579071,2.098712205886841,0.5127400159835815,2.264395236968994,50000 -853.4955842494965,0.941619634628296,12227.0956864357,27074,0,12227.0956864357,0.4057000279426574,2.811198234558105,10000,13082.901546955109,0.561718761920929,2.005185127258301,0.5216599702835083,2.184654235839844,50000 -887.1125380992889,0.973057985305786,12647.427459478378,28005,0,12647.427459478378,0.4076000154018402,2.814857482910156,10000,13536.928488254547,0.5656445026397705,1.974370360374451,0.5188599824905396,2.1966774463653564,50000 -919.5988109111786,1.001277208328247,13067.464753627775,28935,0,13067.464753627775,0.4113000333309173,2.815878391265869,10000,13989.526335477827,0.560351550579071,2.0256459712982178,0.5274800062179565,2.1879539489746094,50000 -952.1639549732208,1.035024881362915,13487.82652425766,29863,0,13487.82652425766,0.4169000089168548,2.7868216037750244,10000,14442.533663749697,0.5651562213897705,1.994471549987793,0.5298799872398376,2.168131113052368,50000 -985.389402627945,1.070270538330078,13908.166835308077,30794,0,13908.166835308077,0.4204000234603882,2.7451517581939697,10000,14896.18096923828,0.5801953077316284,1.903186917304993,0.5334599614143372,2.116943836212158,50000 -1018.364322900772,1.1095900535583496,14328.138955116272,31726,0,14328.138955116272,0.4228000342845917,2.7345998287200928,10000,15349.214524507524,0.5763476490974426,1.9141044616699217,0.5375999808311462,2.1075029373168945,50000 -1051.026505947113,1.14223051071167,14748.117035627363,32657,0,14748.117035627363,0.421500027179718,2.763683319091797,10000,15801.935980558395,0.5758984088897705,1.9533084630966189,0.5400399565696716,2.1200501918792725,50000 -1084.6919131278992,1.1724753379821775,15168.483533620834,33589,0,15168.483533620834,0.4267000257968902,2.706271171569824,10000,16256.044842720032,0.5841405987739563,1.894763708114624,0.5448200106620789,2.079097509384156,50000 -1117.0305380821228,1.202235460281372,15588.672712087631,34518,0,15588.672712087631,0.431300014257431,2.6870129108428955,10000,16708.648688793182,0.6061328053474426,1.7760976552963257,0.545799970626831,2.062562704086304,50000 -1150.9917540550232,1.2357745170593262,16009.01386475563,35448,0,16009.01386475563,0.4293000102043152,2.7089946269989014,10000,17163.030723571777,0.584667980670929,1.894122004508972,0.5425599813461304,2.086238145828247,50000 -1184.2244091033936,1.2700214385986328,16429.294507026672,36379,0,16429.294507026672,0.4314000308513641,2.709493398666382,10000,17616.62527346611,0.5904101133346558,1.8927258253097528,0.5471999645233154,2.0837810039520264,50000 -1217.0834302902222,1.3062067031860352,16849.5801718235,37309,0,16849.5801718235,0.4397000074386596,2.661262989044189,10000,18069.852862596512,0.6047655940055847,1.7837172746658323,0.556119978427887,2.017679214477539,50000 -1250.4606716632843,1.3389804363250732,17269.676019191742,38237,0,17269.676019191742,0.4394000172615051,2.677217721939087,10000,18523.4055621624,0.592089831829071,1.862553954124451,0.5530399680137634,2.0462992191314697,50000 -1283.460999250412,1.376353740692139,17689.861362457275,39165,0,17689.861362457275,0.4356000125408172,2.6622133255004883,10000,18976.674981355667,0.5917382836341858,1.8414897918701167,0.5532400012016296,2.0242159366607666,50000 -1316.411651134491,1.409106969833374,18110.095131635662,40095,0,18110.095131635662,0.4385000169277191,2.6310582160949707,10000,19429.93856573105,0.6004687547683716,1.7826272249221802,0.5520200133323669,2.0079150199890137,50000 -1348.9026863574982,1.444082498550415,18530.15157198906,41025,0,18530.15157198906,0.4479000270366668,2.604980707168579,10000,19882.56823849678,0.6124023199081421,1.7426855564117432,0.5575399994850159,1.9958285093307493,50000 -1382.1475772857666,1.474935531616211,18950.41338968277,41953,0,18950.41338968277,0.4416000247001648,2.6486172676086426,10000,20336.153192281723,0.6009570360183716,1.83956265449524,0.5607799887657166,2.0273380279541016,50000 -1413.6585688591003,1.506608247756958,19370.686596870422,42883,0,19370.686596870422,0.4479000270366668,2.569371700286865,10000,20788.01582312584,0.60546875,1.7532869577407837,0.5643599629402161,1.944724678993225,50000 -1447.244621515274,1.5436184406280518,19790.67938184738,43812,0,19790.67938184738,0.4571000337600708,2.540505886077881,10000,21241.678597688675,0.6259570121765137,1.6537854671478271,0.5655800104141235,1.9352000951766968,50000 -1480.019334077835,1.577929973602295,20210.85508942604,44744,0,20210.85508942604,0.454800009727478,2.562917232513428,10000,21694.710247278214,0.60791015625,1.7625596523284912,0.565779983997345,1.950334429740905,50000 -1513.5652458667755,1.608425855636597,20630.80467486381,45675,0,20630.80467486381,0.4541000127792358,2.59224271774292,10000,22148.28287220001,0.6089062094688416,1.7846567630767822,0.568619966506958,1.9842305183410645,50000 -1546.635691165924,1.6452360153198242,21051.00689411164,46604,0,21051.00689411164,0.4522000253200531,2.561401605606079,10000,22601.638983488083,0.619433581829071,1.69148051738739,0.5701199769973755,1.932340741157532,50000 -1577.0918953418732,1.681839942932129,21470.93300724029,47533,0,21470.93300724029,0.4607000350952148,2.502983331680298,10000,23052.105201005936,0.6141015291213989,1.700886845588684,0.5776199698448181,1.8775568008422847,50000 -1609.9773399829865,1.718390703201294,21891.24653053284,48459,0,21891.24653053284,0.4547000229358673,2.565771579742432,10000,23505.38790154457,0.6114843487739563,1.7474989891052246,0.570580005645752,1.9303126335144043,50000 -1643.8127937316897,1.7531659603118896,22311.626733779907,49388,0,22311.626733779907,0.4626000225543976,2.526845216751098,10000,23959.685015916824,0.6169726252555847,1.6989887952804563,0.5762199759483337,1.9017544984817505,50000 -1677.461537361145,1.786334991455078,22731.99612236023,50318,0,22731.99612236023,0.4619000256061554,2.520209550857544,10000,24413.78429436684,0.6458203196525574,1.5972343683242798,0.5758799910545349,1.9063783884048464,50000 -1710.0534682273865,1.8191730976104736,23151.934993743896,51246,0,23151.934993743896,0.4612000286579132,2.528799057006836,10000,24866.39386487007,0.613476574420929,1.7316758632659912,0.5759199857711792,1.9049936532974243,50000 -1743.183295249939,1.854111909866333,23571.889559984207,52175,0,23571.889559984207,0.465800017118454,2.502259492874145,10000,25319.559384584427,0.6240624785423279,1.667878031730652,0.5776399970054626,1.88678514957428,50000 -1774.4506244659424,1.8986506462097168,23992.10686635971,53104,0,23992.10686635971,0.4680000245571136,2.480672836303711,10000,25771.13542485237,0.6388866901397705,1.6128058433532717,0.5816599726676941,1.8688938617706297,50000 -1807.9144456386568,1.9343111515045168,24412.24143385887,54034,0,24412.24143385887,0.460500031709671,2.5516300201416016,10000,26224.815851688385,0.6194140315055847,1.7486610412597656,0.5802599787712097,1.9259588718414309,50000 -1841.114578008652,1.9702081680297847,24832.460919380188,54963,0,24832.460919380188,0.4665000140666961,2.5163378715515137,10000,26678.31806921959,0.6254101395606995,1.6814249753952026,0.5814599990844727,1.8866448402404783,50000 -1873.9938821792605,2.003046751022339,25252.844719171524,55891,0,25252.844719171524,0.4720000326633453,2.4755938053131104,10000,27131.659603118896,0.6341210603713989,1.6259918212890625,0.5867399573326111,1.858447790145874,50000 -1907.6127750873568,2.042140007019043,25672.96810555458,56820,0,25672.96810555458,0.4678000211715698,2.5007386207580566,10000,27585.487575769424,0.6281445026397705,1.6840720176696775,0.5877199769020081,1.8711938858032229,50000 -1940.665373325348,2.074948310852051,26093.019316911697,57752,0,26093.019316911697,0.4648000299930572,2.495413303375244,10000,28038.671048879623,0.6265624761581421,1.685849905014038,0.5827000141143799,1.870428204536438,50000 -1973.6229138374329,2.110410690307617,26513.309331178665,58682,0,26513.309331178665,0.4708000123500824,2.459522008895874,10000,28492.00569462776,0.6369921565055847,1.610442280769348,0.5866000056266785,1.8379977941513064,50000 -2007.113491773605,2.146523952484131,26933.473502397537,59613,0,26933.473502397537,0.4741000235080719,2.4697728157043457,10000,28945.743367433548,0.6574023365974426,1.540103316307068,0.5890399813652039,1.8447201251983645,50000 -2039.523863077164,2.1801421642303467,27353.815348386765,60543,0,27353.815348386765,0.4699000120162964,2.503729820251465,10000,29398.575788736343,0.6248632669448853,1.6834291219711304,0.589419960975647,1.868920922279358,50000 -2070.4793763160706,2.2234983444213867,27774.116877794266,61472,0,27774.116877794266,0.4755000174045563,2.495166063308716,10000,29849.922494888306,0.6370507478713989,1.664278268814087,0.5917400121688843,1.8678573369979856,50000 -2104.190548181534,2.257309675216675,28194.66119718552,62401,0,28194.66119718552,0.4774000346660614,2.467177152633667,10000,30304.25809168816,0.6454882621765137,1.614796757698059,0.5929200053215027,1.8566802740097048,50000 -2138.450407981873,2.2979867458343506,28614.79931879044,63332,0,28614.79931879044,0.4755000174045563,2.443674802780152,10000,30758.743307828903,0.63623046875,1.6195902824401855,0.5971999764442444,1.8102082014083865,50000 -2171.592320203781,2.3315210342407227,29035.028126716614,64262,0,29035.028126716614,0.4784000217914581,2.4457361698150635,10000,31212.194100141525,0.6394140720367432,1.6281222105026243,0.5967199802398682,1.819630146026612,50000 -2205.211054801941,2.3653695583343506,29454.9858648777,65192,0,29454.9858648777,0.4735000133514404,2.464690685272217,10000,31665.85106754303,0.6471484303474426,1.5985087156295776,0.5951799750328064,1.824656248092652,50000 -2238.4791843891144,2.404351234436035,29875.15716052056,66119,0,29875.15716052056,0.4864000082015991,2.4119138717651367,10000,32119.37609243393,0.6446288824081421,1.5758085250854492,0.6003199815750122,1.7816158533096311,50000 -2272.9096236228943,2.440931558609009,30295.264727830887,67048,0,30295.264727830887,0.4790000319480896,2.427741765975952,10000,32573.9971203804,0.6425976157188416,1.5925278663635254,0.6013799905776978,1.785370945930481,50000 -2304.2412803173065,2.480074882507324,30715.514212608337,67979,0,30715.514212608337,0.4832000136375427,2.4544215202331543,10000,33025.66438269615,0.6479296684265137,1.6087180376052856,0.5970799922943115,1.8276971578598025,50000 -2337.104706287384,2.5181806087493896,31135.559599637985,68905,0,31135.559599637985,0.4932000339031219,2.402475595474243,10000,33478.6574652195,0.6702734231948853,1.4906214475631714,0.6045799851417542,1.7795348167419434,50000 -2369.4204025268555,2.779877185821533,31555.447756052017,69834,0,31555.447756052017,0.4812000095844269,2.3967788219451904,10000,33931.170246362686,0.6407226324081421,1.5906224250793457,0.6022399663925171,1.764672040939331,50000 -2402.840269088745,2.8157126903533936,31975.477875947952,70761,0,31975.477875947952,0.4848000109195709,2.404025077819824,10000,34384.702178001404,0.6514452695846558,1.564841866493225,0.6045599579811096,1.7766313552856443,50000 -2436.2950756549835,2.854088068008423,32395.61031579972,71688,0,32395.61031579972,0.484000027179718,2.400637149810791,10000,34838.37434220314,0.6606835722923279,1.527460694313049,0.6034600138664246,1.7711652517318726,50000 -2467.9322276115417,2.894427299499512,32815.5456404686,72615,0,32815.5456404686,0.4868000149726867,2.4277710914611816,10000,35290.03398799896,0.6470116972923279,1.589882493019104,0.6062799692153931,1.779320240020752,50000 -2501.2704651355743,2.93189001083374,33235.80899262428,73541,0,33235.80899262428,0.4923000335693359,2.3877408504486084,10000,35743.71882414818,0.6502929329872131,1.571126103401184,0.6121199727058411,1.750990629196167,50000 -2534.7862520217896,2.968616485595703,33655.84841275215,74470,0,33655.84841275215,0.4879000186920166,2.3940000534057617,10000,36197.356711387634,0.6590234041213989,1.541833758354187,0.6091399788856506,1.7622010707855225,50000 -2568.215485572815,3.006901502609253,34076.092334747314,75400,0,34076.092334747314,0.4926000237464905,2.362853527069092,10000,36651.11418533325,0.6696093678474426,1.4805610179901123,0.6134999990463257,1.732527732849121,50000 -2602.766112804413,3.0473546981811523,34496.37813591957,76330,0,34496.37813591957,0.4974000155925751,2.353220701217652,10000,37106.03781723976,0.65869140625,1.5325862169265747,0.614300012588501,1.7345703840255735,50000 -2635.522970676422,3.086964130401612,34916.623254299164,77258,0,34916.623254299164,0.4937000274658203,2.351444721221924,10000,37559.12559890747,0.6606054306030273,1.511372208595276,0.6152399778366089,1.7202774286270142,50000 -2666.8984639644623,3.1255970001220703,35336.7923810482,78184,0,35336.7923810482,0.5019000172615051,2.3303937911987305,10000,38010.754996299744,0.6808788776397705,1.4191409349441528,0.6181600093841553,1.7033716440200806,50000 -2700.217987060547,3.1663384437561035,35757.1461520195,79112,0,35757.1461520195,0.4924000203609466,2.3449583053588867,10000,38464.51617407799,0.661816418170929,1.5143336057662964,0.6148399710655212,1.7183010578155518,50000 -2733.6324348449707,3.207908391952514,36177.24590039253,80043,0,36177.24590039253,0.4970000088214874,2.373055934906006,10000,38918.11818599701,0.664746105670929,1.510551691055298,0.615339994430542,1.7205344438552856,50000 -2767.0463218688965,3.247930765151977,36597.27389025688,80972,0,36597.27389025688,0.4974000155925751,2.3454010486602783,10000,39371.64637541771,0.6741796731948853,1.4529361724853516,0.6193599700927734,1.702526330947876,50000 -2800.32315158844,3.2903239727020264,37017.20155906677,81899,0,37017.20155906677,0.5027000308036804,2.336276054382324,10000,39824.93992829323,0.667773425579071,1.502530217170715,0.622759997844696,1.6995172500610352,50000 -2832.334250688553,3.334193706512451,37437.302568912506,82829,0,37437.302568912506,0.5026000142097473,2.331036329269409,10000,40277.142501831055,0.6694530844688416,1.4893957376480105,0.6229199767112732,1.6935948133468628,50000 -2865.438676595688,3.372976303100586,37857.52192592621,83758,0,37857.52192592621,0.4980000257492065,2.28191351890564,10000,40730.55080986023,0.6728710532188416,1.4168633222579956,0.626800000667572,1.644349455833435,50000 -2900.2883038520813,3.4122555255889893,38277.640315294266,84687,0,38277.640315294266,0.508400022983551,2.274543285369873,10000,41185.6047809124,0.695019543170929,1.3510807752609253,0.6265199780464172,1.650166630744934,50000 -2934.331175804138,3.451533555984497,38697.576645851135,85617,0,38697.576645851135,0.5103999972343445,2.277604818344116,10000,41639.67412424088,0.6769140362739563,1.4378875494003296,0.6297799944877625,1.655285120010376,50000 -2966.1738238334656,3.496487617492676,39117.53278756142,86545,0,39117.53278756142,0.5065000057220459,2.2886831760406494,10000,42091.563891649246,0.676074206829071,1.442645788192749,0.6266199946403503,1.6573635339736938,50000 -3000.3928265571594,3.5384793281555176,39537.75274658203,87473,0,39537.75274658203,0.5064000487327576,2.292130470275879,10000,42546.0910179615,0.6855859160423279,1.3951412439346311,0.6272000074386597,1.6594972610473633,50000 -3034.1299324035645,3.575552463531494,39957.88946032524,88401,0,39957.88946032524,0.5100000500679016,2.2937583923339844,10000,43000.04769778252,0.6751171946525574,1.4573235511779783,0.630299985408783,1.667418122291565,50000 -3067.7917091846466,3.61670994758606,40377.935428380966,89327,0,40377.935428380966,0.5110000371932983,2.2482283115386963,10000,43453.84268569946,0.6839843392372131,1.394673466682434,0.6331200003623962,1.625125527381897,50000 -3103.089093208313,3.654865503311157,40798.026288986206,90254,0,40798.026288986206,0.5161000490188599,2.281223773956299,10000,43909.31564474106,0.6922265291213989,1.403250217437744,0.6358599662780762,1.6580744981765747,50000 -3135.961694717407,3.69734001159668,41218.17198085785,91182,0,41218.17198085785,0.5085000395774841,2.2850115299224854,10000,44362.4227976799,0.6755273342132568,1.446442723274231,0.6324999928474426,1.6372050046920776,50000 -3170.2124574184418,3.7383203506469727,41638.28119826317,92112,0,41638.28119826317,0.5118000507354736,2.2473669052124023,10000,44816.87034630776,0.6849218606948853,1.3958221673965454,0.6351799964904785,1.6146154403686523,50000 -3202.508903503418,3.777261018753052,42058.32339930534,93038,0,42058.32339930534,0.5232000350952148,2.201004266738892,10000,45269.29402279854,0.689160168170929,1.3483566045761108,0.6417999863624573,1.5728561878204346,50000 -3234.89250922203,3.814180850982666,42478.304805994034,93965,0,42478.304805994034,0.5166000127792358,2.255450487136841,10000,45721.74182486534,0.7109375,1.3136073350906372,0.6380199790000916,1.624780297279358,50000 -3268.6627497673035,3.8543221950531006,42898.46243238449,94887,0,42898.46243238449,0.5225000381469727,2.2491910457611084,10000,46175.75550246239,0.6878905892372131,1.4085044860839844,0.6405199766159058,1.6214958429336548,50000 -3302.586932182312,3.902494430541992,43318.54101729393,95814,0,43318.54101729393,0.5213000178337097,2.232096910476685,10000,46629.85206055641,0.6914257407188416,1.367840051651001,0.6401599645614624,1.597143292427063,50000 -3336.776979207993,3.949820041656494,43738.83052825928,96744,0,43738.83052825928,0.5175999999046326,2.258472681045532,10000,47084.42543315888,0.7025195360183716,1.342665433883667,0.6408999562263489,1.617620587348938,50000 -3370.167640209198,3.994457244873047,44159.07249808312,97673,0,44159.07249808312,0.5253000259399414,2.2020792961120605,10000,47538.14868545532,0.68896484375,1.3808518648147583,0.6461799740791321,1.570910930633545,50000 -3403.740065574646,4.033020973205566,44579.15256071091,98598,0,44579.15256071091,0.5278000235557556,2.181706428527832,10000,47991.88709282875,0.69544917345047,1.3356281518936155,0.6452599763870239,1.56348717212677,50000 -3437.5442354679108,4.075830459594727,44999.109080553055,99525,0,44999.109080553055,0.5200999975204468,2.2290115356445312,10000,48445.73757982254,0.6994921565055847,1.3409777879714966,0.6436600089073181,1.5910606384277344,50000 -3471.073740005493,4.120898962020874,45419.194039821625,100453,0,45419.194039821625,0.5254000425338745,2.1810965538024902,10000,48899.44379377365,0.6985937356948853,1.3283151388168335,0.6507599949836731,1.5489925146102903,50000 -3502.795699119568,4.166559219360352,45839.292186021805,101381,0,45839.292186021805,0.5294000506401062,2.1788530349731445,10000,49351.35535264015,0.7010741829872131,1.3405998945236206,0.6499399542808533,1.5560424327850342,50000 -3536.4254937171936,4.207376718521118,46259.55302786827,102308,0,46259.55302786827,0.5324000120162964,2.1580193042755127,10000,49805.33374285698,0.7044335603713989,1.303029179573059,0.6525200009346008,1.5312974452972412,50000 -3570.0530354976654,4.251614093780518,46679.52591824532,103236,0,46679.52591824532,0.5303000211715698,2.1878247261047363,10000,50259.024639844894,0.721386730670929,1.2593860626220703,0.6481999754905701,1.5756916999816897,50000 -3604.563158750534,4.293308973312378,47099.60453367233,104166,0,47099.60453367233,0.5261000394821167,2.147125005722046,10000,50713.70186185837,0.6998632550239563,1.3183801174163818,0.6520000100135803,1.5319174528121948,50000 -3637.766434669495,4.338505029678345,47519.71596264839,105096,0,47519.71596264839,0.5349000096321106,2.143164873123169,10000,51167.10824346542,0.7125781178474426,1.287738800048828,0.6566999554634094,1.5229451656341553,50000 -3671.302904844284,4.382236242294312,47939.64816617966,106025,0,47939.64816617966,0.5263000130653381,2.194082260131836,10000,51620.6674015522,0.71156245470047,1.2922778129577637,0.6466999650001526,1.5727142095565796,50000 -3704.454334497452,4.4254231452941895,48359.77759027481,106952,0,48359.77759027481,0.5338000059127808,2.1496691703796387,10000,52074.03806447983,0.7048046588897705,1.3123520612716677,0.6563400030136108,1.5270174741744995,50000 -3737.262728214264,4.467191934585571,48779.73000144959,107880,0,48779.73000144959,0.5354000329971313,2.1473042964935303,10000,52526.886921167374,0.7080664038658142,1.289051175117493,0.6545799970626831,1.529636263847351,50000 -3768.2807273864746,4.513633966445923,49199.8685810566,108810,0,49199.8685810566,0.5387000441551208,2.107815742492676,10000,52978.13624429703,0.7175390720367432,1.221261978149414,0.6603599786758423,1.4771640300750732,50000 -3799.7456452846527,4.565981388092041,49619.92288994789,109738,0,49619.92288994789,0.5404000282287598,2.142338991165161,10000,53429.75419712067,0.7140429615974426,1.2802772521972656,0.6586199998855591,1.5189369916915894,50000 -3833.513250827789,4.615039587020874,50039.94888544083,110669,0,50039.94888544083,0.5420000553131104,2.146238327026367,10000,53883.64413046837,0.7124218344688416,1.2984763383865356,0.6633599996566772,1.5199729204177856,50000 -3866.6692838668814,4.662085294723511,50460.5451362133,111601,0,50460.5451362133,0.5434000492095947,2.1115095615386963,10000,54337.49081420898,0.7237108945846558,1.2157682180404663,0.6658399701118469,1.476227641105652,50000 -3900.708383321762,4.705749273300171,50880.741792202,112529,0,50880.741792202,0.550000011920929,2.085638523101806,10000,54791.81612062454,0.73876953125,1.147725224494934,0.6629999876022339,1.471642017364502,50000 -3935.346556425095,4.757417678833008,51301.066935777664,113460,0,51301.066935777664,0.5469000339508057,2.0994019508361816,10000,55246.87820911408,0.7142187356948853,1.256130933761597,0.6647799611091614,1.4717066287994385,50000 -3969.720571756363,4.803069829940796,51721.110796928406,114389,0,51721.110796928406,0.5499000549316406,2.0966713428497314,10000,55701.402671575546,0.7233593463897705,1.2339673042297363,0.6669600009918213,1.4803937673568726,50000 -4003.2828526496887,4.8567054271698,52141.49275612831,115318,0,52141.49275612831,0.549500048160553,2.0902318954467773,10000,56155.447914361954,0.7351366877555847,1.1889073848724363,0.6723600029945374,1.4659477472305298,50000 -4035.5553166866302,4.898540258407593,52561.77123785019,116249,0,52561.77123785019,0.5455000400543213,2.114979982376098,10000,56608.08748936653,0.7199413776397705,1.2665797472000122,0.67221999168396,1.4920510053634644,50000 -4070.2912969589233,4.944955587387085,52981.80159282684,117180,0,52981.80159282684,0.5527000427246094,2.0657169818878174,10000,57062.94676208496,0.7294726371765137,1.2057080268859863,0.677299976348877,1.432081937789917,50000 -4105.278161764145,4.990721702575684,53401.86762642861,118109,0,53401.86762642861,0.5592000484466553,2.0560216903686523,10000,57518.09233784676,0.7354491949081421,1.1728439331054688,0.6750800013542175,1.43381667137146,50000 -4137.934015035629,5.03800368309021,53822.07191777229,119039,0,53822.07191777229,0.5534999966621399,2.041276454925537,10000,57971.04748129845,0.7400195002555847,1.14028000831604,0.6732800006866455,1.417976975440979,50000 -4171.868766546249,5.085484981536865,54242.142776966095,119970,0,54242.142776966095,0.5523000359535217,2.0710713863372803,10000,58425.14707708359,0.7264648079872131,1.2079328298568726,0.6744799613952637,1.4414364099502563,50000 -4205.415090322495,5.132269144058228,54662.12339830399,120901,0,54662.12339830399,0.5651000142097473,2.001354694366455,10000,58878.76770377159,0.7350195050239563,1.1525589227676392,0.6802799701690674,1.4023057222366333,50000 -4237.944813966751,5.175921440124512,55082.16000986099,121830,0,55082.16000986099,0.5628000497817993,1.9966557025909424,10000,59331.42533326149,0.7547070384025574,1.0838943719863892,0.683459997177124,1.395894169807434,50000 -4271.503207921982,5.226181507110596,55502.36118531227,122757,0,55502.36118531227,0.5541000366210938,2.036602735519409,10000,59785.28081989288,0.7364257574081421,1.1707990169525146,0.6805199980735779,1.4080075025558472,50000 -4306.207540750504,5.269519805908203,55922.51798272133,123685,0,55922.51798272133,0.5642000436782837,2.0096795558929443,10000,60240.23173499107,0.7400780916213989,1.1385509967803955,0.6856399774551392,1.3829452991485596,50000 -4338.141446352005,5.316270351409912,56342.67802453041,124614,0,56342.67802453041,0.5678000450134277,2.01263165473938,10000,60692.419605493546,0.7510937452316284,1.11006760597229,0.6840999722480774,1.394182205200195,50000 -4371.525834798813,5.361638784408569,56762.89463853836,125544,0,56762.89463853836,0.566100001335144,1.991938591003418,10000,61146.11239314079,0.7424218654632568,1.1373136043548584,0.6873799562454224,1.3734878301620483,50000 -4404.65695309639,5.406764268875122,57183.18724656105,126472,0,57183.18724656105,0.5692000389099121,1.9838557243347168,10000,61599.6272623539,0.74867182970047,1.1079342365264893,0.6914599537849426,1.3647878170013428,50000 -4438.493448495865,5.450902700424194,57603.10753917694,127402,0,57603.10753917694,0.5679000020027161,1.9800525903701784,10000,62053.47536659241,0.7528125047683716,1.0917577743530271,0.6893599629402161,1.3566315174102783,50000 -4472.449482917786,5.49484133720398,58023.76083302498,128301,0,58023.76083302498,0.5696000456809998,1.9985442161560056,10000,62508.17419576645,0.7537695169448853,1.098175287246704,0.6912999749183655,1.3770054578781128,50000 -4506.356766462326,5.540136098861694,58443.69563674927,129231,0,58443.69563674927,0.570900022983551,1.974564552307129,10000,62962.10774159432,0.7473828196525574,1.1209533214569092,0.6952799558639526,1.3586920499801636,50000 -4538.41277050972,5.585582971572876,58863.781465768814,130161,0,58863.781465768814,0.5753000378608704,1.9617393016815183,10000,63414.34195232391,0.7574023008346558,1.075601577758789,0.6932399868965149,1.3420850038528442,50000 -4572.432414054871,5.636627674102783,59283.73093700409,131088,0,59283.73093700409,0.5745000243186951,1.9722157716751096,10000,63868.40799832344,0.7674218416213989,1.0314123630523682,0.6955599784851074,1.343638300895691,50000 -4605.888027429581,5.6878721714019775,59704.04973602295,132017,0,59704.04973602295,0.5759000182151794,1.91862154006958,10000,64322.28016543389,0.7551367282867432,1.0511291027069092,0.6979999542236328,1.3008636236190796,50000 -4640.043148756027,5.734477758407593,60124.05582642555,132946,0,60124.05582642555,0.5769000053405762,1.952175498008728,10000,64776.534338235855,0.7594921588897705,1.0723209381103516,0.6966399550437927,1.3334400653839111,50000 -4673.519508600235,5.778345823287964,60544.09117388725,133876,0,60544.09117388725,0.58160001039505,1.904213547706604,10000,65230.13675904274,0.7744726538658142,0.9877074360847472,0.7017599940299988,1.2961947917938232,50000 -4708.190718412399,5.826894760131836,60964.34392094612,134807,0,60964.34392094612,0.5774000287055969,1.9307143688201904,10000,65685.15578842163,0.7617382407188416,1.0603655576705933,0.701200008392334,1.320079684257507,50000 -4739.263605117798,5.874547243118286,61384.33895516396,135735,0,61384.33895516396,0.5831000208854675,1.8825321197509768,10000,66136.31755590439,0.768847644329071,1.01193368434906,0.7056999802589417,1.2756710052490234,50000 -4772.92241859436,5.9302287101745605,61804.52350068092,136662,0,61804.52350068092,0.5806000232696533,1.923264741897583,10000,66590.26353669167,0.7686523199081421,1.024351954460144,0.7022199630737305,1.317602038383484,50000 -4807.293427467346,5.978266716003418,62224.45097088814,137592,0,62224.45097088814,0.5868000388145447,1.8915445804595947,10000,67044.65680789948,0.786425769329071,0.9460193514823914,0.7095999717712402,1.2811696529388428,50000 -4842.306634902954,6.0289857387542725,62644.38854908943,138521,0,62644.38854908943,0.584600031375885,1.8975168466567995,10000,67499.70457077026,0.76917964220047,1.0142436027526855,0.7066400051116943,1.2850269079208374,50000 -4875.023307323456,6.075516939163208,63064.298337221146,139448,0,63064.298337221146,0.5867000222206116,1.8783414363861084,10000,67952.42379832268,0.7764062285423279,0.9926196932792664,0.7118200063705444,1.268027663230896,50000 -4909.52064538002,6.123907089233398,63484.35860395432,140376,0,63484.35860395432,0.5902000069618225,1.8925907611846924,10000,68407.07604622841,0.7857421636581421,0.9634817838668824,0.7120800018310547,1.285139560699463,50000 -4944.235973596573,6.169575452804565,63904.452016592026,141307,0,63904.452016592026,0.5962000489234924,1.866379141807556,10000,68861.97762393951,0.7754882574081421,0.99040687084198,0.7101199626922607,1.26579749584198,50000 -4978.624108314514,6.229996204376221,64324.7577764988,142234,0,64324.7577764988,0.5933000445365906,1.8614035844802856,10000,69316.77753448486,0.7803906202316284,0.9680672883987428,0.7150200009346008,1.2434526681900024,50000 -5012.432188272476,6.281265020370483,64745.00335741043,143162,0,64745.00335741043,0.6005000472068787,1.8527811765670776,10000,69770.92920541763,0.7883398532867432,0.9409700632095336,0.7188999652862549,1.2394484281539917,50000 -5046.197269678116,6.327600479125977,65165.10458111763,144090,0,65165.10458111763,0.6010000109672546,1.8121860027313232,10000,70224.88837790489,0.7860937118530273,0.9350039958953856,0.7210599780082703,1.209059238433838,50000 -5081.971020698547,6.380660057067871,65585.06423592567,145020,0,65585.06423592567,0.6005000472068787,1.830407738685608,10000,70680.72092986107,0.7858593463897705,0.9343098402023317,0.7197799682617188,1.222838282585144,50000 -5116.485925197601,6.430635690689087,66005.23581600189,145951,0,66005.23581600189,0.6039000153541565,1.830265760421753,10000,71135.50416016579,0.7895702719688416,0.9390873312950134,0.7216399908065796,1.2361648082733154,50000 -5151.747972011566,6.480493545532227,66425.18833255768,146880,0,66425.18833255768,0.6035000085830688,1.808951020240784,10000,71590.81511712074,0.8043749928474426,0.8740522861480713,0.7262799739837646,1.2046759128570557,50000 -5185.241056442261,6.546149730682373,66845.46230769157,147810,0,66845.46230769157,0.6008000373840332,1.8167974948883057,10000,72044.69520163536,0.7898827791213989,0.9271060228347778,0.7245599627494812,1.2024517059326172,50000 -5219.695116758347,6.595580577850342,67265.56567597389,148741,0,67265.56567597389,0.6121000051498413,1.7757694721221924,10000,72499.34892606735,0.79749995470047,0.8806569576263428,0.7253400087356567,1.184898853302002,50000 -5253.777672767639,6.64896035194397,67685.5972161293,149669,0,67685.5972161293,0.6064000129699707,1.7887310981750488,10000,72953.5630671978,0.8055273294448853,0.8653415441513062,0.7287799715995789,1.1887600421905518,50000 -5287.894725084305,6.696053504943848,68105.81286787987,150600,0,68105.81286787987,0.6121000051498413,1.759349703788757,10000,73407.98974180222,0.7973241806030273,0.8826409578323364,0.7310000061988831,1.1645584106445312,50000 -5322.22726726532,6.748147487640381,68526.03539800644,151531,0,68526.03539800644,0.6106000542640686,1.780478596687317,10000,73862.64343810081,0.8052929639816284,0.8601254224777222,0.733959972858429,1.1651591062545776,50000 -5357.129096031189,6.798153877258301,68946.24187660217,152460,0,68946.24187660217,0.610200047492981,1.789298176765442,10000,74317.84789443016,0.8087890148162842,0.8668062686920166,0.7332599759101868,1.18711519241333,50000 -5391.624893188477,6.85608983039856,69366.1802983284,153391,0,69366.1802983284,0.6176000237464905,1.7518985271453855,10000,74772.3870472908,0.8062109351158142,0.8395585417747498,0.7360999584197998,1.139323353767395,50000 -5426.559454441071,6.911020278930664,69786.34556651115,154319,0,69786.34556651115,0.6152999997138977,1.7600574493408203,10000,75227.58785867691,0.8050194978713989,0.8566577434539795,0.7374399900436401,1.1484848260879517,50000 -5461.275590896606,6.962958574295044,70206.66819000244,155245,0,70206.66819000244,0.6205000281333923,1.7459418773651123,10000,75682.72457933426,0.8122656345367432,0.8312785625457764,0.7390999794006348,1.1496641635894775,50000 -5495.359393119812,7.020386457443237,70626.57415676117,156170,0,70626.57415676117,0.6170000433921814,1.7455700635910034,10000,76136.81805038452,0.8208202719688416,0.808429479598999,0.7415599822998047,1.147451639175415,50000 -5529.437774181366,7.078781843185425,71046.6845676899,157100,0,71046.6845676899,0.6218000054359436,1.7313672304153442,10000,76591.11240267754,0.8145312070846558,0.8163199424743652,0.7399199604988098,1.1327764987945557,50000 -5563.826683521271,7.130944013595581,71466.77744364738,158029,0,71466.77744364738,0.6199000477790833,1.7298482656478882,10000,77045.694283247,0.8168163895606995,0.8214466571807861,0.7429199814796448,1.1342238187789917,50000 -5597.954945802689,7.182467699050903,71886.9072842598,158959,0,71886.9072842598,0.6255000233650208,1.7001981735229492,10000,77500.05093717575,0.823535144329071,0.7854181528091431,0.7455799579620361,1.114540696144104,50000 -5632.141172647476,7.231976747512817,72306.84834432602,159890,0,72306.84834432602,0.6300000548362732,1.702431082725525,10000,77954.27435541153,0.8223828077316284,0.7984592914581299,0.7449600100517273,1.1174359321594238,50000 -5666.60213804245,7.287751197814941,72726.83613228798,160819,0,72726.83613228798,0.6273000240325928,1.690682888031006,10000,78408.82541036606,0.82386714220047,0.7823898196220398,0.7475000023841858,1.1016745567321775,50000 -5701.25897192955,7.346428871154785,73147.12069392204,161748,0,73147.12069392204,0.6290000081062317,1.686492681503296,10000,78863.87222862244,0.8251367211341858,0.7654830813407898,0.7488999962806702,1.0918009281158447,50000 -5734.852890491486,7.401033163070679,73567.05645251274,162676,0,73567.05645251274,0.633400022983551,1.6805967092514038,10000,79317.50340890884,0.8285351395606995,0.7643269896507263,0.7503599524497986,1.0934828519821167,50000 -5767.8042142391205,7.45538067817688,73987.05826282501,163603,0,73987.05826282501,0.6317000389099121,1.6685441732406616,10000,79770.55769276619,0.8294531106948853,0.7535117268562317,0.7529199719429016,1.0765010118484497,50000 -5802.109175443649,7.517110109329224,74407.01922512054,164530,0,74407.01922512054,0.6342000365257263,1.6714333295822144,10000,80224.93182849884,0.82923823595047,0.7566958665847778,0.751039981842041,1.0820822715759275,50000 -5837.49352812767,7.568860530853271,74827.14331531525,165458,0,74827.14331531525,0.6332000494003296,1.685017704963684,10000,80680.53797078133,0.8375976085662842,0.7356262803077698,0.754859983921051,1.0842537879943848,50000 -5871.969348192215,7.618942499160767,75247.37900185585,166389,0,75247.37900185585,0.6389000415802002,1.6692404747009275,10000,81135.34643173218,0.8311718702316284,0.7485712170600891,0.7549999952316284,1.0757715702056885,50000 -5906.669553279877,7.670528650283813,75667.48089122772,167317,0,75667.48089122772,0.6389000415802002,1.6732136011123655,10000,81590.24653863907,0.8334569931030273,0.7626773715019226,0.7558599710464478,1.091723918914795,50000 -5941.224766492844,7.7321436405181885,76087.52441835403,168246,0,76087.52441835403,0.6396000385284424,1.6480201482772827,10000,82044.95395517349,0.8374999761581421,0.7262058258056641,0.7564399838447571,1.0680700540542605,50000 -5975.560669898987,7.782688856124878,76507.65330696106,169175,0,76507.65330696106,0.6421000361442566,1.6457515954971311,10000,82499.51538085938,0.83607417345047,0.7408673167228699,0.7595399618148804,1.0654107332229614,50000 -6007.491415500641,7.844382524490356,76927.95672249794,170103,0,76927.95672249794,0.6447000503540039,1.6202740669250488,10000,82951.85792160034,0.8408398032188416,0.7134418487548828,0.7605199813842773,1.0481165647506714,50000 -6040.283388137817,7.898629903793335,77348.14902710915,171033,0,77348.14902710915,0.6460000276565552,1.633678674697876,10000,83404.94340229034,0.8424023389816284,0.7111977934837341,0.7627399563789368,1.051758050918579,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 1cf7beaf3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1902 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.3291411,6.907757,,,,,,,,,,,,,, -1,,,0.0009179687476716,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.74321413040161,82.96978783607483,42.74321413040161,40.2264838218689,0.0,0.0 -100,0.3865211,6.906055,,,,,,,,,,,,,, -200,0.3772917,6.898754,,,,,,,,,,,,,, -300,0.5410942,6.8608527,,,,,,,,,,,,,, -400,0.64235455,6.8146787,,,,,,,,,,,,,, -500,1.0609103,6.8208876,,,,,,,,,,,,,, -600,1.0887295,6.748292,,,,,,,,,,,,,, -700,0.9023302,6.6883297,,,,,,,,,,,,,, -800,1.1986238,6.6232667,,,,,,,,,,,,,, -897,,,0.0131445312872529,6.433816909790039,0.0123199997469782,6.444516181945801,50000.0,0.009400000795722,6.486231803894043,10000.0,463.12793254852295,525.1827142238617,463.12793254852295,61.97929525375366,0.0289614200592041,0.0 -900,0.9506203,6.5685315,,,,,,,,,,,,,, -1000,2.2170103,6.588651,,,,,,,,,,,,,, -1100,1.4745036,6.6172624,,,,,,,,,,,,,, -1200,1.3849,6.4598856,,,,,,,,,,,,,, -1300,1.2306503,6.666187,,,,,,,,,,,,,, -1400,1.568969,6.4165945,,,,,,,,,,,,,, -1500,1.6667418,6.6665907,,,,,,,,,,,,,, -1600,1.7794758,6.257121,,,,,,,,,,,,,, -1700,1.4973371,6.2520866,,,,,,,,,,,,,, -1800,1.6543092,6.2620106,,,,,,,,,,,,,, -1840,,,0.0418945290148258,5.843362808227539,0.0398399978876113,5.8705878257751465,50000.0,0.0334000028669834,5.988006591796875,10000.0,883.3025920391083,966.8372595310212,883.3025920391083,83.38200998306274,0.0567998886108398,0.0 -1900,3.0175462,6.210449,,,,,,,,,,,,,, -2000,1.7718686,6.4530325,,,,,,,,,,,,,, -2100,1.9056746,6.1201024,,,,,,,,,,,,,, -2200,1.6572099,6.1351604,,,,,,,,,,,,,, -2300,1.8283368,6.090671,,,,,,,,,,,,,, -2400,1.6611089,6.027038,,,,,,,,,,,,,, -2500,1.7603594,6.5513573,,,,,,,,,,,,,, -2600,2.230328,6.0312405,,,,,,,,,,,,,, -2700,1.5463232,6.296197,,,,,,,,,,,,,, -2785,,,0.072910152375698,5.394949436187744,0.066040001809597,5.446426868438721,50000.0,0.0494000017642974,5.625874996185303,10000.0,1303.642668247223,1409.0454559326172,1303.642668247223,105.17406916618349,0.0844464302062988,0.0 -2800,2.2406538,6.029985,,,,,,,,,,,,,, -2900,1.8309412,5.9745336,,,,,,,,,,,,,, -3000,1.4174001,6.3792477,,,,,,,,,,,,,, -3100,1.8938872,5.9241652,,,,,,,,,,,,,, -3200,1.5646814,6.579898,,,,,,,,,,,,,, -3300,2.4663439,5.883176,,,,,,,,,,,,,, -3400,1.8414762,5.899254,,,,,,,,,,,,,, -3500,1.6443193,5.8341312,,,,,,,,,,,,,, -3600,1.7070825,5.911439,,,,,,,,,,,,,, -3700,1.570376,6.162325,,,,,,,,,,,,,, -3730,,,0.1057617142796516,5.051905155181885,0.0959599986672401,5.101688385009766,50000.0,0.0754000023007392,5.3332037925720215,10000.0,1723.5683901309967,1850.8338613510127,1723.5683901309967,126.95962452888487,0.1133489608764648,0.0 -3800,1.5859218,5.9833007,,,,,,,,,,,,,, -3900,1.6097652,5.692972,,,,,,,,,,,,,, -4000,1.3656222,6.3482037,,,,,,,,,,,,,, -4100,1.7011266,5.8897963,,,,,,,,,,,,,, -4200,1.540205,5.645842,,,,,,,,,,,,,, -4300,1.4306157,5.5976505,,,,,,,,,,,,,, -4400,1.305668,6.18518,,,,,,,,,,,,,, -4500,1.5547941,5.569181,,,,,,,,,,,,,, -4600,1.0891997,6.5504427,,,,,,,,,,,,,, -4667,,,0.1399023383855819,4.692224025726318,0.1293399930000305,4.748983860015869,50000.0,0.0970000028610229,5.055785179138184,10000.0,2143.667026281357,2292.9393298625946,2143.667026281357,148.88687324523926,0.1449859142303466,0.0 -4700,1.6689212,5.6011744,,,,,,,,,,,,,, -4800,1.6071416,5.7596245,,,,,,,,,,,,,, -4900,1.1182468,6.500463,,,,,,,,,,,,,, -5000,1.2335497,6.2996573,,,,,,,,,,,,,, -5100,1.7901734,5.32448,,,,,,,,,,,,,, -5200,1.8064015,5.5770216,,,,,,,,,,,,,, -5300,1.7485671,5.3818264,,,,,,,,,,,,,, -5400,1.3401061,5.88961,,,,,,,,,,,,,, -5500,2.1739082,5.451758,,,,,,,,,,,,,, -5600,1.6030189,5.8695707,,,,,,,,,,,,,, -5607,,,0.1944335848093032,4.300422191619873,0.1761399954557418,4.393296718597412,50000.0,0.1337999999523162,4.728091716766357,10000.0,2563.898540019989,2735.196244239807,2563.898540019989,170.83421778678894,0.1749160289764404,0.0 -5700,1.6568408,5.469359,,,,,,,,,,,,,, -5800,1.5756669,5.227613,,,,,,,,,,,,,, -5900,1.4440485,5.6186924,,,,,,,,,,,,,, -6000,1.5180811,5.0395083,,,,,,,,,,,,,, -6100,1.4119378,6.423209,,,,,,,,,,,,,, -6200,1.478631,5.88603,,,,,,,,,,,,,, -6300,1.1539203,6.2175384,,,,,,,,,,,,,, -6400,1.6855793,5.0320854,,,,,,,,,,,,,, -6500,1.6488353,6.2881794,,,,,,,,,,,,,, -6544,,,0.2354101538658142,3.982255458831787,0.2187999933958053,4.063016891479492,50000.0,0.1663000136613845,4.450675010681152,10000.0,2984.3140094280243,3179.2775950431824,2984.3140094280243,194.42507314682007,0.2019352912902832,0.0 -6600,1.7616173,5.0389724,,,,,,,,,,,,,, -6700,1.5688859,4.9624043,,,,,,,,,,,,,, -6800,1.1398162,6.3116565,,,,,,,,,,,,,, -6900,1.5026441,5.0666084,,,,,,,,,,,,,, -7000,1.7703525,4.998901,,,,,,,,,,,,,, -7100,1.8418914,4.843635,,,,,,,,,,,,,, -7200,1.3002572,5.97262,,,,,,,,,,,,,, -7300,1.8652778,4.8855057,,,,,,,,,,,,,, -7400,1.5086485,6.2675333,,,,,,,,,,,,,, -7483,,,0.2684960961341858,3.741316556930542,0.2501599788665771,3.8385329246521,50000.0,0.1911000162363052,4.25723123550415,10000.0,3404.61616230011,3630.577404737473,3404.61616230011,225.34209632873527,0.2349696159362793,0.0 -7500,1.7215838,4.8238683,,,,,,,,,,,,,, -7600,1.7119689,4.730626,,,,,,,,,,,,,, -7700,1.6416105,4.914403,,,,,,,,,,,,,, -7800,1.5059643,4.772788,,,,,,,,,,,,,, -7900,1.1783408,5.7666254,,,,,,,,,,,,,, -8000,1.9468118,4.816466,,,,,,,,,,,,,, -8100,1.5175257,4.876981,,,,,,,,,,,,,, -8200,1.5949908,4.9831758,,,,,,,,,,,,,, -8300,1.5380441,4.7676716,,,,,,,,,,,,,, -8400,1.421269,5.963483,,,,,,,,,,,,,, -8423,,,0.3101952970027923,3.4238476753234863,0.2865000069141388,3.551166534423828,50000.0,0.2204000055789947,4.016154766082764,10000.0,3824.698988676071,4075.722831964493,3824.698988676071,250.3235805034637,0.2688426971435547,0.0 -8500,1.5768167,4.9648747,,,,,,,,,,,,,, -8600,1.6908015,4.588176,,,,,,,,,,,,,, -8700,1.3443338,5.7930584,,,,,,,,,,,,,, -8800,1.7369378,4.5624194,,,,,,,,,,,,,, -8900,1.2461559,5.766653,,,,,,,,,,,,,, -9000,1.2039938,6.2065115,,,,,,,,,,,,,, -9100,1.0742594,6.216818,,,,,,,,,,,,,, -9200,1.6162848,4.468289,,,,,,,,,,,,,, -9300,1.1930479,5.396079,,,,,,,,,,,,,, -9362,,,0.3661327958106994,3.119126558303833,0.3231000006198883,3.3199212551116943,50000.0,0.2503000199794769,3.823859453201294,10000.0,4244.6240670681,4523.069453001022,4244.6240670681,277.6685001850128,0.2980084419250488,0.0 -9400,1.8013846,4.772912,,,,,,,,,,,,,, -9500,1.3697008,5.024935,,,,,,,,,,,,,, -9600,1.512989,4.463256,,,,,,,,,,,,,, -9700,1.3071892,5.3165135,,,,,,,,,,,,,, -9800,1.6287689,4.433425,,,,,,,,,,,,,, -9900,1.6824418,4.666973,,,,,,,,,,,,,, -10000,1.2442213,5.6182766,,,,,,,,,,,,,, -10100,1.8271302,4.4341583,,,,,,,,,,,,,, -10200,1.5039989,4.6075325,,,,,,,,,,,,,, -10299,,,0.3646875023841858,3.0610296726226807,0.3439199924468994,3.1753134727478027,50000.0,0.2657000124454498,3.697940826416016,10000.0,4664.77720451355,4973.917723178864,4664.77720451355,308.2666883468628,0.3484940528869629,0.0 -10300,1.6352398,4.4100637,,,,,,,,,,,,,, -10400,1.8912672,4.2357073,,,,,,,,,,,,,, -10500,1.8231927,4.313989,,,,,,,,,,,,,, -10600,1.0559459,5.672284,,,,,,,,,,,,,, -10700,1.11025,5.811826,,,,,,,,,,,,,, -10800,1.511742,4.309699,,,,,,,,,,,,,, -10900,1.4092563,4.656362,,,,,,,,,,,,,, -11000,1.7270974,4.487981,,,,,,,,,,,,,, -11100,1.3506613,4.4599876,,,,,,,,,,,,,, -11200,1.2337388,5.477514,,,,,,,,,,,,,, -11238,,,0.3969140648841858,2.903132915496826,0.3653799891471863,3.0522048473358154,50000.0,0.2824000120162964,3.593581438064575,10000.0,5085.150776386261,5422.149684429169,5085.150776386261,336.04639506340027,0.3790206909179687,0.0 -11300,1.4619602,4.178303,,,,,,,,,,,,,, -11400,1.4362717,5.0772986,,,,,,,,,,,,,, -11500,1.3588887,4.4449487,,,,,,,,,,,,,, -11600,1.6400268,4.4458294,,,,,,,,,,,,,, -11700,1.4262183,4.2531896,,,,,,,,,,,,,, -11800,1.7124263,4.255506,,,,,,,,,,,,,, -11900,1.6147722,4.159179,,,,,,,,,,,,,, -12000,1.3558158,4.6611013,,,,,,,,,,,,,, -12100,1.4037709,4.271715,,,,,,,,,,,,,, -12177,,,0.4346679449081421,2.682563304901123,0.3948400020599365,2.876403570175171,50000.0,0.3043000102043152,3.429219007492065,10000.0,5505.0947265625,5870.017268896103,5505.0947265625,363.8884010314941,0.4129703044891357,0.0 -12200,1.7914388,4.3677416,,,,,,,,,,,,,, -12300,1.0961722,6.067524,,,,,,,,,,,,,, -12400,1.0946519,5.209118,,,,,,,,,,,,,, -12500,1.0706376,5.7847614,,,,,,,,,,,,,, -12600,1.6100992,4.130071,,,,,,,,,,,,,, -12700,1.25222,5.2811995,,,,,,,,,,,,,, -12800,1.3460504,4.5439997,,,,,,,,,,,,,, -12900,1.1771042,5.7420626,,,,,,,,,,,,,, -13000,1.245495,5.245078,,,,,,,,,,,,,, -13100,1.5202601,4.191992,,,,,,,,,,,,,, -13114,,,0.4341992139816284,2.6847686767578125,0.4049800038337707,2.8233370780944824,50000.0,0.3204000294208526,3.36568021774292,10000.0,5925.405028104782,6320.223189115524,5925.405028104782,393.6926965713501,0.4578158855438232,0.0 -13200,0.99565256,6.068511,,,,,,,,,,,,,, -13300,1.6473358,4.0821686,,,,,,,,,,,,,, -13400,1.4816638,4.1610627,,,,,,,,,,,,,, -13500,1.8746643,4.0867763,,,,,,,,,,,,,, -13600,1.4722614,4.0882316,,,,,,,,,,,,,, -13700,1.5021842,4.077174,,,,,,,,,,,,,, -13800,1.1629491,5.641824,,,,,,,,,,,,,, -13900,1.4029714,4.0610666,,,,,,,,,,,,,, -14000,0.9028776,5.9645457,,,,,,,,,,,,,, -14045,,,0.4446093738079071,2.6638307571411133,0.4107399880886078,2.8218674659729004,50000.0,0.3206000030040741,3.385582685470581,10000.0,6345.643961429596,6769.418182611465,6345.643961429596,422.5653555393219,0.495150089263916,0.0 -14100,1.4328705,4.0763826,,,,,,,,,,,,,, -14200,1.5900918,4.0432415,,,,,,,,,,,,,, -14300,1.5362021,4.084932,,,,,,,,,,,,,, -14400,1.3548963,4.1521997,,,,,,,,,,,,,, -14500,1.3752089,4.0445557,,,,,,,,,,,,,, -14600,1.1663735,4.3806763,,,,,,,,,,,,,, -14700,0.93810457,5.6204863,,,,,,,,,,,,,, -14800,1.1366359,5.544763,,,,,,,,,,,,,, -14900,1.1665503,4.8607674,,,,,,,,,,,,,, -14974,,,0.4664062261581421,2.479521036148072,0.4297599792480469,2.666552782058716,50000.0,0.3350000083446502,3.2387516498565674,10000.0,6765.623802185059,7218.315635204315,6765.623802185059,451.4034585952759,0.5285844802856445,0.0 -15000,0.9825131,5.0720325,,,,,,,,,,,,,, -15100,1.4970547,4.0707374,,,,,,,,,,,,,, -15200,1.0155033,5.9434204,,,,,,,,,,,,,, -15300,1.3331108,4.0442004,,,,,,,,,,,,,, -15400,1.6053345,4.259178,,,,,,,,,,,,,, -15500,1.4205092,4.1365604,,,,,,,,,,,,,, -15600,1.0868995,5.037866,,,,,,,,,,,,,, -15700,1.8385627,3.947741,,,,,,,,,,,,,, -15800,1.3663663,3.904938,,,,,,,,,,,,,, -15900,1.5122867,4.0313854,,,,,,,,,,,,,, -15908,,,0.4827929437160492,2.4374914169311523,0.4357599914073944,2.665578365325928,50000.0,0.3364000022411346,3.2438104152679443,10000.0,7185.726698637009,7667.809241294861,7185.726698637009,480.7165772914887,0.5597198009490967,0.0 -16000,1.3408363,3.9585283,,,,,,,,,,,,,, -16100,0.95362854,5.5676017,,,,,,,,,,,,,, -16200,1.069336,5.87036,,,,,,,,,,,,,, -16300,1.325597,4.431448,,,,,,,,,,,,,, -16400,0.9977749,5.2494645,,,,,,,,,,,,,, -16500,1.5901095,3.7587347,,,,,,,,,,,,,, -16600,1.505357,3.8989568,,,,,,,,,,,,,, -16700,1.5783165,4.0084696,,,,,,,,,,,,,, -16800,0.94633406,5.484852,,,,,,,,,,,,,, -16837,,,0.4757421910762787,2.441497802734375,0.4432799816131592,2.600684881210327,50000.0,0.3495000302791595,3.1816678047180176,10000.0,7605.864891529083,8116.923457622528,7605.864891529083,509.6156742572784,0.5903451442718506,0.0 -16900,1.3663243,3.8082347,,,,,,,,,,,,,, -17000,1.4019134,3.9050248,,,,,,,,,,,,,, -17100,1.3100822,4.039314,,,,,,,,,,,,,, -17200,1.5615025,4.0398707,,,,,,,,,,,,,, -17300,1.1014156,4.4606404,,,,,,,,,,,,,, -17400,0.9539688,5.064855,,,,,,,,,,,,,, -17500,0.94789267,5.7792807,,,,,,,,,,,,,, -17600,1.34638,4.002158,,,,,,,,,,,,,, -17700,1.2234136,3.7780166,,,,,,,,,,,,,, -17761,,,0.4897070229053497,2.3479857444763184,0.4527199864387512,2.532827854156494,50000.0,0.3516000211238861,3.1301140785217285,10000.0,8025.931247711182,8568.302577495575,8025.931247711182,540.8512902259827,0.6215465068817139,0.0 -17800,1.2031647,4.070902,,,,,,,,,,,,,, -17900,1.2660397,4.857977,,,,,,,,,,,,,, -18000,1.2963086,4.55154,,,,,,,,,,,,,, -18100,1.1161716,5.1050854,,,,,,,,,,,,,, -18200,1.1452502,4.4279838,,,,,,,,,,,,,, -18300,1.4835339,4.0994782,,,,,,,,,,,,,, -18400,1.1412157,5.998169,,,,,,,,,,,,,, -18500,1.2941471,4.2164106,,,,,,,,,,,,,, -18600,1.4025217,4.076776,,,,,,,,,,,,,, -18693,,,0.5077343583106995,2.280947208404541,0.4601799845695495,2.506450653076172,50000.0,0.3596000075340271,3.1122446060180664,10000.0,8445.988456726074,9019.213403463364,8445.988456726074,571.6251528263092,0.6539442539215088,0.0 -18700,1.4096096,3.8425846,,,,,,,,,,,,,, -18800,1.3600295,4.0223303,,,,,,,,,,,,,, -18900,1.2771261,4.0231876,,,,,,,,,,,,,, -19000,1.4773527,3.8589168,,,,,,,,,,,,,, -19100,1.3349521,4.0171537,,,,,,,,,,,,,, -19200,1.1581769,5.649651,,,,,,,,,,,,,, -19300,1.4244938,3.9238026,,,,,,,,,,,,,, -19400,1.4578692,3.659963,,,,,,,,,,,,,, -19500,1.4483395,3.8513258,,,,,,,,,,,,,, -19600,1.1685388,5.486492,,,,,,,,,,,,,, -19624,,,0.5035741925239563,2.2893991470336914,0.4703999757766723,2.448329448699951,50000.0,0.3666000068187713,3.0513882637023926,10000.0,8866.089724779129,9470.51671743393,8866.089724779129,602.7466416358948,0.687971830368042,0.0 -19700,1.8043061,3.7503095,,,,,,,,,,,,,, -19800,1.5427719,3.9264889,,,,,,,,,,,,,, -19900,1.0637294,5.4549637,,,,,,,,,,,,,, -20000,0.95392495,5.7670326,,,,,,,,,,,,,, -20100,1.4079002,3.9235127,,,,,,,,,,,,,, -20200,1.0608035,4.88385,,,,,,,,,,,,,, -20300,1.3887295,4.2913423,,,,,,,,,,,,,, -20400,1.244263,3.829946,,,,,,,,,,,,,, -20500,1.2570338,4.350074,,,,,,,,,,,,,, -20553,,,0.5167187452316284,2.226701259613037,0.4785199761390686,2.39848256111145,50000.0,0.3762000203132629,2.9945783615112305,10000.0,9286.414787769318,9921.960839271544,9286.414787769318,633.781985282898,0.7244741916656494,0.0 -20600,0.90744185,5.260525,,,,,,,,,,,,,, -20700,1.4542272,3.7385945,,,,,,,,,,,,,, -20800,1.347664,3.804254,,,,,,,,,,,,,, -20900,1.0139236,4.898305,,,,,,,,,,,,,, -21000,1.2518165,3.7751002,,,,,,,,,,,,,, -21100,1.2995398,3.6466768,,,,,,,,,,,,,, -21200,1.4809899,3.652616,,,,,,,,,,,,,, -21300,1.3187054,3.670098,,,,,,,,,,,,,, -21400,1.086831,5.4201765,,,,,,,,,,,,,, -21480,,,0.5309960842132568,2.138967514038086,0.4841599762439728,2.3563876152038574,50000.0,0.3792000114917755,2.9629838466644287,10000.0,9706.38515138626,10374.46075630188,9706.38515138626,666.2311322689056,0.7581882476806641,0.0 -21500,1.3344711,3.7703538,,,,,,,,,,,,,, -21600,0.8614662,5.298666,,,,,,,,,,,,,, -21700,1.3863635,3.7933812,,,,,,,,,,,,,, -21800,1.2098157,3.9706576,,,,,,,,,,,,,, -21900,1.4404283,3.824079,,,,,,,,,,,,,, -22000,1.2806121,3.980441,,,,,,,,,,,,,, -22100,1.4052558,3.846157,,,,,,,,,,,,,, -22200,1.2935511,3.863061,,,,,,,,,,,,,, -22300,1.1737373,4.0690107,,,,,,,,,,,,,, -22400,1.4451884,3.723927,,,,,,,,,,,,,, -22413,,,0.5240234136581421,2.189439535140991,0.4843199849128723,2.373043298721313,50000.0,0.3786000311374664,2.9851624965667725,10000.0,10126.384377241136,10824.917355537416,10126.384377241136,696.6076049804688,0.7909915447235107,0.0 -22500,1.4763397,3.587244,,,,,,,,,,,,,, -22600,1.0309184,4.699204,,,,,,,,,,,,,, -22700,1.4347516,3.6839442,,,,,,,,,,,,,, -22800,1.0051925,5.7933173,,,,,,,,,,,,,, -22900,1.0670524,5.3714166,,,,,,,,,,,,,, -23000,1.0389824,4.8695054,,,,,,,,,,,,,, -23100,1.2378591,4.0464854,,,,,,,,,,,,,, -23200,1.052339,4.350938,,,,,,,,,,,,,, -23300,1.3966879,3.6964836,,,,,,,,,,,,,, -23351,,,0.5264257788658142,2.1538989543914795,0.4958399832248688,2.312279224395752,50000.0,0.3881000280380249,2.9460103511810303,10000.0,10546.718402862549,11276.222237586975,10546.718402862549,727.5011661052704,0.8213632106781006,0.0 -23400,1.0039854,5.599104,,,,,,,,,,,,,, -23500,1.5201913,3.9577308,,,,,,,,,,,,,, -23600,1.3559763,3.5772161,,,,,,,,,,,,,, -23700,1.3592949,3.8317637,,,,,,,,,,,,,, -23800,1.1869726,3.9030786,,,,,,,,,,,,,, -23900,1.1437604,4.446167,,,,,,,,,,,,,, -24000,1.3459959,3.834746,,,,,,,,,,,,,, -24100,1.3547769,3.761136,,,,,,,,,,,,,, -24200,1.2782842,3.6207027,,,,,,,,,,,,,, -24285,,,0.5478906035423279,2.078741788864136,0.5089600086212158,2.2755093574523926,50000.0,0.3977000117301941,2.882901906967163,10000.0,10966.702552080154,11725.926349878311,10966.702552080154,757.1414499282837,0.8539583683013916,0.0 -24300,1.2651983,4.3556333,,,,,,,,,,,,,, -24400,0.96908414,4.6647134,,,,,,,,,,,,,, -24500,1.3721877,4.0293326,,,,,,,,,,,,,, -24600,1.2128441,4.1777496,,,,,,,,,,,,,, -24700,1.4455761,3.7246277,,,,,,,,,,,,,, -24800,1.3043923,3.5377128,,,,,,,,,,,,,, -24900,1.0529932,4.9749126,,,,,,,,,,,,,, -25000,1.4246951,3.674743,,,,,,,,,,,,,, -25100,1.0606983,5.311233,,,,,,,,,,,,,, -25200,1.3375577,3.7799933,,,,,,,,,,,,,, -25218,,,0.5692187547683716,1.966854214668274,0.5087199807167053,2.25063705444336,50000.0,0.4005000293254852,2.87434720993042,10000.0,11386.934196472168,12178.604896306992,11386.934196472168,789.5124342441559,0.8832418918609619,0.0 -25300,1.3403747,4.107239,,,,,,,,,,,,,, -25400,1.6501575,3.5883813,,,,,,,,,,,,,, -25500,1.5436606,3.8072865,,,,,,,,,,,,,, -25600,1.2163255,3.6982872,,,,,,,,,,,,,, -25700,1.132324,4.6062407,,,,,,,,,,,,,, -25800,1.1780877,5.873487,,,,,,,,,,,,,, -25900,1.2665073,3.7192678,,,,,,,,,,,,,, -26000,1.085574,4.055969,,,,,,,,,,,,,, -26100,1.2381959,3.9358752,,,,,,,,,,,,,, -26143,,,0.546679675579071,2.098712205886841,0.5127400159835815,2.264395236968994,50000.0,0.4005000293254852,2.883815288543701,10000.0,11806.891574144363,12629.590085029602,11806.891574144363,820.4644253253937,0.912727117538452,0.0 -26200,1.3397146,3.7342074,,,,,,,,,,,,,, -26300,1.3172599,3.5912194,,,,,,,,,,,,,, -26400,1.1061096,5.3411393,,,,,,,,,,,,,, -26500,1.117269,5.683624,,,,,,,,,,,,,, -26600,1.3730594,3.6922226,,,,,,,,,,,,,, -26700,1.3713888,3.5837893,,,,,,,,,,,,,, -26800,1.1150057,5.8099184,,,,,,,,,,,,,, -26900,1.3239893,3.6419864,,,,,,,,,,,,,, -27000,1.404977,3.5766482,,,,,,,,,,,,,, -27074,,,0.561718761920929,2.005185127258301,0.5216599702835083,2.184654235839844,50000.0,0.4057000279426574,2.811198234558105,10000.0,12227.0956864357,13082.901546955109,12227.0956864357,853.4955842494965,0.941619634628296,0.0 -27100,1.4464804,3.5266716,,,,,,,,,,,,,, -27200,1.2185056,4.080562,,,,,,,,,,,,,, -27300,1.2090188,4.4566126,,,,,,,,,,,,,, -27400,1.2467206,3.5421276,,,,,,,,,,,,,, -27500,1.3868935,3.7576818,,,,,,,,,,,,,, -27600,0.92449427,5.318943,,,,,,,,,,,,,, -27700,1.3023692,3.5117216,,,,,,,,,,,,,, -27800,1.4709144,3.6116967,,,,,,,,,,,,,, -27900,1.4616526,3.6558704,,,,,,,,,,,,,, -28000,1.584897,3.894056,,,,,,,,,,,,,, -28005,,,0.5656445026397705,1.974370360374451,0.5188599824905396,2.1966774463653564,50000.0,0.4076000154018402,2.814857482910156,10000.0,12647.427459478378,13536.928488254547,12647.427459478378,887.1125380992889,0.973057985305786,0.0 -28100,1.5926008,3.7476854,,,,,,,,,,,,,, -28200,1.3352635,3.7991476,,,,,,,,,,,,,, -28300,1.3020455,3.6788092,,,,,,,,,,,,,, -28400,1.0747603,5.120579,,,,,,,,,,,,,, -28500,1.3077221,3.6041358,,,,,,,,,,,,,, -28600,1.2430543,4.359005,,,,,,,,,,,,,, -28700,1.4773879,3.5560207,,,,,,,,,,,,,, -28800,1.1184515,5.001051,,,,,,,,,,,,,, -28900,1.4269122,3.5821,,,,,,,,,,,,,, -28935,,,0.560351550579071,2.0256459712982178,0.5274800062179565,2.1879539489746094,50000.0,0.4113000333309173,2.815878391265869,10000.0,13067.464753627775,13989.526335477827,13067.464753627775,919.5988109111786,1.001277208328247,0.0 -29000,1.1386979,4.9869165,,,,,,,,,,,,,, -29100,1.4229407,3.722393,,,,,,,,,,,,,, -29200,1.1237475,5.634825,,,,,,,,,,,,,, -29300,1.3618248,3.5901632,,,,,,,,,,,,,, -29400,1.0622414,5.7913775,,,,,,,,,,,,,, -29500,1.341509,3.5712857,,,,,,,,,,,,,, -29600,1.4295868,3.5747395,,,,,,,,,,,,,, -29700,1.5079291,3.5264409,,,,,,,,,,,,,, -29800,1.4382436,3.5290496,,,,,,,,,,,,,, -29863,,,0.5651562213897705,1.994471549987793,0.5298799872398376,2.168131113052368,50000.0,0.4169000089168548,2.7868216037750244,10000.0,13487.82652425766,14442.533663749697,13487.82652425766,952.1639549732208,1.035024881362915,0.0 -29900,1.2578036,5.469873,,,,,,,,,,,,,, -30000,1.331177,3.54667,,,,,,,,,,,,,, -30100,1.3875132,3.437374,,,,,,,,,,,,,, -30200,1.269488,3.454446,,,,,,,,,,,,,, -30300,1.3312293,3.6943712,,,,,,,,,,,,,, -30400,1.4453257,3.5098631,,,,,,,,,,,,,, -30500,1.3238193,4.193453,,,,,,,,,,,,,, -30600,1.3764822,3.4652462,,,,,,,,,,,,,, -30700,1.4876127,3.709479,,,,,,,,,,,,,, -30794,,,0.5801953077316284,1.903186917304993,0.5334599614143372,2.116943836212158,50000.0,0.4204000234603882,2.7451517581939697,10000.0,13908.166835308077,14896.18096923828,13908.166835308077,985.389402627945,1.070270538330078,0.0 -30800,1.4154781,3.574019,,,,,,,,,,,,,, -30900,1.3455966,3.6079602,,,,,,,,,,,,,, -31000,1.1082463,5.434694,,,,,,,,,,,,,, -31100,1.2204928,5.4706235,,,,,,,,,,,,,, -31200,1.1759378,5.084169,,,,,,,,,,,,,, -31300,1.4811404,3.5607455,,,,,,,,,,,,,, -31400,1.1986426,4.5896387,,,,,,,,,,,,,, -31500,1.363158,3.5407732,,,,,,,,,,,,,, -31600,1.6736337,3.5658324,,,,,,,,,,,,,, -31700,1.1602621,5.6873727,,,,,,,,,,,,,, -31726,,,0.5763476490974426,1.9141044616699217,0.5375999808311462,2.1075029373168945,50000.0,0.4228000342845917,2.7345998287200928,10000.0,14328.138955116272,15349.214524507524,14328.138955116272,1018.364322900772,1.1095900535583496,0.0 -31800,1.4020098,3.5506344,,,,,,,,,,,,,, -31900,1.0327508,5.46269,,,,,,,,,,,,,, -32000,1.2794499,4.2083054,,,,,,,,,,,,,, -32100,1.6092371,3.5632455,,,,,,,,,,,,,, -32200,1.4074547,3.616869,,,,,,,,,,,,,, -32300,1.0490167,4.6886544,,,,,,,,,,,,,, -32400,1.33956,3.7471678,,,,,,,,,,,,,, -32500,1.2894589,4.128153,,,,,,,,,,,,,, -32600,1.4566813,3.670543,,,,,,,,,,,,,, -32657,,,0.5758984088897705,1.9533084630966189,0.5400399565696716,2.1200501918792725,50000.0,0.421500027179718,2.763683319091797,10000.0,14748.117035627363,15801.935980558395,14748.117035627363,1051.026505947113,1.14223051071167,0.0 -32700,1.1905068,4.108163,,,,,,,,,,,,,, -32800,1.1737977,4.6748905,,,,,,,,,,,,,, -32900,1.1170897,4.618264,,,,,,,,,,,,,, -33000,1.4163338,3.6293015,,,,,,,,,,,,,, -33100,1.182796,4.1435328,,,,,,,,,,,,,, -33200,1.3900598,3.5297272,,,,,,,,,,,,,, -33300,1.1818898,5.489215,,,,,,,,,,,,,, -33400,1.3224905,4.0833726,,,,,,,,,,,,,, -33500,1.4880779,3.490248,,,,,,,,,,,,,, -33589,,,0.5841405987739563,1.894763708114624,0.5448200106620789,2.079097509384156,50000.0,0.4267000257968902,2.706271171569824,10000.0,15168.483533620834,16256.044842720032,15168.483533620834,1084.6919131278992,1.1724753379821775,0.0 -33600,1.3824793,3.5693314,,,,,,,,,,,,,, -33700,1.4806384,3.4551826,,,,,,,,,,,,,, -33800,1.3741243,3.5482633,,,,,,,,,,,,,, -33900,1.2879544,3.9224358,,,,,,,,,,,,,, -34000,1.2833508,3.3530784,,,,,,,,,,,,,, -34100,1.093548,4.235422,,,,,,,,,,,,,, -34200,1.304014,5.4240837,,,,,,,,,,,,,, -34300,1.2747935,3.8521411,,,,,,,,,,,,,, -34400,1.2043728,4.017389,,,,,,,,,,,,,, -34500,1.5148423,3.5710456,,,,,,,,,,,,,, -34518,,,0.6061328053474426,1.7760976552963257,0.545799970626831,2.062562704086304,50000.0,0.431300014257431,2.6870129108428955,10000.0,15588.672712087631,16708.648688793182,15588.672712087631,1117.0305380821228,1.202235460281372,0.0 -34600,1.0877872,4.5451593,,,,,,,,,,,,,, -34700,1.4204493,3.50596,,,,,,,,,,,,,, -34800,1.1052152,5.115506,,,,,,,,,,,,,, -34900,1.4249929,3.5467546,,,,,,,,,,,,,, -35000,1.8947793,3.8565848,,,,,,,,,,,,,, -35100,1.4176726,3.491548,,,,,,,,,,,,,, -35200,1.4274102,3.6901836,,,,,,,,,,,,,, -35300,1.6330881,3.4903436,,,,,,,,,,,,,, -35400,1.1379012,5.221512,,,,,,,,,,,,,, -35448,,,0.584667980670929,1.894122004508972,0.5425599813461304,2.086238145828247,50000.0,0.4293000102043152,2.7089946269989014,10000.0,16009.01386475563,17163.030723571777,16009.01386475563,1150.9917540550232,1.2357745170593262,0.0 -35500,1.4175467,3.6331952,,,,,,,,,,,,,, -35600,1.1340132,4.367344,,,,,,,,,,,,,, -35700,1.3643917,3.4629972,,,,,,,,,,,,,, -35800,1.5041753,3.493704,,,,,,,,,,,,,, -35900,1.4362013,3.4227414,,,,,,,,,,,,,, -36000,1.425145,3.432533,,,,,,,,,,,,,, -36100,1.2513143,3.6933804,,,,,,,,,,,,,, -36200,1.0142915,5.263977,,,,,,,,,,,,,, -36300,1.4958885,3.5697436,,,,,,,,,,,,,, -36379,,,0.5904101133346558,1.8927258253097528,0.5471999645233154,2.0837810039520264,50000.0,0.4314000308513641,2.709493398666382,10000.0,16429.294507026672,17616.62527346611,16429.294507026672,1184.2244091033936,1.2700214385986328,0.0 -36400,1.0204102,5.251382,,,,,,,,,,,,,, -36500,1.5148567,3.4947433,,,,,,,,,,,,,, -36600,1.340383,3.4690356,,,,,,,,,,,,,, -36700,1.4613122,3.3884315,,,,,,,,,,,,,, -36800,1.3402507,3.3406203,,,,,,,,,,,,,, -36900,1.4454776,3.8251665,,,,,,,,,,,,,, -37000,1.4001654,3.4776363,,,,,,,,,,,,,, -37100,1.4195107,3.6261678,,,,,,,,,,,,,, -37200,1.5059544,3.573957,,,,,,,,,,,,,, -37300,1.4140129,3.4732487,,,,,,,,,,,,,, -37309,,,0.6047655940055847,1.7837172746658323,0.556119978427887,2.017679214477539,50000.0,0.4397000074386596,2.661262989044189,10000.0,16849.5801718235,18069.852862596512,16849.5801718235,1217.0834302902222,1.3062067031860352,0.0 -37400,1.5727195,3.505485,,,,,,,,,,,,,, -37500,1.4314759,3.8233612,,,,,,,,,,,,,, -37600,1.099151,5.6687307,,,,,,,,,,,,,, -37700,1.3514634,3.8842926,,,,,,,,,,,,,, -37800,1.4853218,3.419615,,,,,,,,,,,,,, -37900,1.4170727,3.4535837,,,,,,,,,,,,,, -38000,1.4548144,3.7115402,,,,,,,,,,,,,, -38100,1.5579252,3.4535222,,,,,,,,,,,,,, -38200,1.2785475,4.0011044,,,,,,,,,,,,,, -38237,,,0.592089831829071,1.862553954124451,0.5530399680137634,2.0462992191314697,50000.0,0.4394000172615051,2.677217721939087,10000.0,17269.676019191742,18523.4055621624,17269.676019191742,1250.4606716632843,1.3389804363250732,0.0 -38300,1.5002131,3.4256153,,,,,,,,,,,,,, -38400,1.3107177,4.057417,,,,,,,,,,,,,, -38500,1.1423687,4.577868,,,,,,,,,,,,,, -38600,1.3227828,3.4558105,,,,,,,,,,,,,, -38700,1.2195108,4.222101,,,,,,,,,,,,,, -38800,1.5256559,3.3803744,,,,,,,,,,,,,, -38900,1.4866716,3.4281712,,,,,,,,,,,,,, -39000,1.119231,4.34216,,,,,,,,,,,,,, -39100,1.559165,3.377999,,,,,,,,,,,,,, -39165,,,0.5917382836341858,1.8414897918701167,0.5532400012016296,2.0242159366607666,50000.0,0.4356000125408172,2.6622133255004883,10000.0,17689.861362457275,18976.674981355667,17689.861362457275,1283.460999250412,1.376353740692139,0.0 -39200,1.4292065,3.404348,,,,,,,,,,,,,, -39300,1.4852544,3.3713696,,,,,,,,,,,,,, -39400,1.1959544,4.741022,,,,,,,,,,,,,, -39500,1.2419099,5.430008,,,,,,,,,,,,,, -39600,1.4288042,3.5870667,,,,,,,,,,,,,, -39700,1.463591,3.657346,,,,,,,,,,,,,, -39800,1.3298683,3.3436553,,,,,,,,,,,,,, -39900,1.1369892,5.3506455,,,,,,,,,,,,,, -40000,1.313337,3.9026432,,,,,,,,,,,,,, -40095,,,0.6004687547683716,1.7826272249221802,0.5520200133323669,2.0079150199890137,50000.0,0.4385000169277191,2.6310582160949707,10000.0,18110.095131635662,19429.93856573105,18110.095131635662,1316.411651134491,1.409106969833374,0.0 -40100,1.5245177,3.438066,,,,,,,,,,,,,, -40200,1.4453597,3.890563,,,,,,,,,,,,,, -40300,1.5716115,3.3798528,,,,,,,,,,,,,, -40400,1.4634935,3.574855,,,,,,,,,,,,,, -40500,1.15789,5.3195124,,,,,,,,,,,,,, -40600,1.5002437,3.4788072,,,,,,,,,,,,,, -40700,1.461834,3.508359,,,,,,,,,,,,,, -40800,1.7043889,3.525539,,,,,,,,,,,,,, -40900,1.2336658,4.7081933,,,,,,,,,,,,,, -41000,1.5600725,3.32375,,,,,,,,,,,,,, -41025,,,0.6124023199081421,1.7426855564117432,0.5575399994850159,1.9958285093307493,50000.0,0.4479000270366668,2.604980707168579,10000.0,18530.15157198906,19882.56823849678,18530.15157198906,1348.9026863574982,1.444082498550415,0.0 -41100,1.387808,3.315206,,,,,,,,,,,,,, -41200,1.177537,4.38357,,,,,,,,,,,,,, -41300,1.8958839,3.3896644,,,,,,,,,,,,,, -41400,1.7174104,3.5101404,,,,,,,,,,,,,, -41500,1.3734245,3.3524108,,,,,,,,,,,,,, -41600,1.6338884,3.4890988,,,,,,,,,,,,,, -41700,1.1394802,5.2692533,,,,,,,,,,,,,, -41800,1.6430631,3.3965533,,,,,,,,,,,,,, -41900,1.1488315,4.8977942,,,,,,,,,,,,,, -41953,,,0.6009570360183716,1.83956265449524,0.5607799887657166,2.0273380279541016,50000.0,0.4416000247001648,2.6486172676086426,10000.0,18950.41338968277,20336.153192281723,18950.41338968277,1382.1475772857666,1.474935531616211,0.0 -42000,1.4287033,4.440485,,,,,,,,,,,,,, -42100,1.2586639,3.8008342,,,,,,,,,,,,,, -42200,1.3272407,3.5053215,,,,,,,,,,,,,, -42300,1.3786099,3.7615037,,,,,,,,,,,,,, -42400,1.3608743,3.79955,,,,,,,,,,,,,, -42500,1.2434847,4.171105,,,,,,,,,,,,,, -42600,1.4481132,3.3215284,,,,,,,,,,,,,, -42700,1.4698675,3.6541286,,,,,,,,,,,,,, -42800,1.4303632,3.5107584,,,,,,,,,,,,,, -42883,,,0.60546875,1.7532869577407837,0.5643599629402161,1.944724678993225,50000.0,0.4479000270366668,2.569371700286865,10000.0,19370.686596870422,20788.01582312584,19370.686596870422,1413.6585688591003,1.506608247756958,0.0 -42900,1.4553401,3.3533401,,,,,,,,,,,,,, -43000,1.6869198,3.4911804,,,,,,,,,,,,,, -43100,1.5840055,3.3690376,,,,,,,,,,,,,, -43200,1.4060043,3.3802695,,,,,,,,,,,,,, -43300,1.439455,3.4293613,,,,,,,,,,,,,, -43400,1.2067946,4.126319,,,,,,,,,,,,,, -43500,1.4502661,3.336243,,,,,,,,,,,,,, -43600,1.0971963,5.4099684,,,,,,,,,,,,,, -43700,1.4027914,3.5154722,,,,,,,,,,,,,, -43800,1.5305084,3.4452288,,,,,,,,,,,,,, -43812,,,0.6259570121765137,1.6537854671478271,0.5655800104141235,1.9352000951766968,50000.0,0.4571000337600708,2.540505886077881,10000.0,19790.67938184738,21241.678597688675,19790.67938184738,1447.244621515274,1.5436184406280518,0.0 -43900,1.0539494,5.6140933,,,,,,,,,,,,,, -44000,1.4942518,3.387824,,,,,,,,,,,,,, -44100,1.7637658,3.4608097,,,,,,,,,,,,,, -44200,1.5325757,3.3167872,,,,,,,,,,,,,, -44300,1.1492518,4.526476,,,,,,,,,,,,,, -44400,1.4276017,4.4345045,,,,,,,,,,,,,, -44500,1.7702751,3.3590908,,,,,,,,,,,,,, -44600,1.6089551,3.4202223,,,,,,,,,,,,,, -44700,1.1892744,4.479843,,,,,,,,,,,,,, -44744,,,0.60791015625,1.7625596523284912,0.565779983997345,1.950334429740905,50000.0,0.454800009727478,2.562917232513428,10000.0,20210.85508942604,21694.710247278214,20210.85508942604,1480.019334077835,1.577929973602295,0.0 -44800,1.5848583,3.375271,,,,,,,,,,,,,, -44900,1.3923843,3.5364826,,,,,,,,,,,,,, -45000,1.3547127,5.5181313,,,,,,,,,,,,,, -45100,1.4805154,3.3849452,,,,,,,,,,,,,, -45200,1.1611187,5.556333,,,,,,,,,,,,,, -45300,1.4210842,3.6958768,,,,,,,,,,,,,, -45400,1.5047631,3.4461794,,,,,,,,,,,,,, -45500,1.3354576,4.257968,,,,,,,,,,,,,, -45600,1.2796011,4.033114,,,,,,,,,,,,,, -45675,,,0.6089062094688416,1.7846567630767822,0.568619966506958,1.9842305183410645,50000.0,0.4541000127792358,2.59224271774292,10000.0,20630.80467486381,22148.28287220001,20630.80467486381,1513.5652458667755,1.608425855636597,0.0 -45700,1.5183994,3.5945175,,,,,,,,,,,,,, -45800,1.4897141,3.3457987,,,,,,,,,,,,,, -45900,1.5080701,3.4371986,,,,,,,,,,,,,, -46000,1.554578,3.4554353,,,,,,,,,,,,,, -46100,1.4620407,3.3308074,,,,,,,,,,,,,, -46200,1.6506741,3.6746807,,,,,,,,,,,,,, -46300,1.5763121,3.4013739,,,,,,,,,,,,,, -46400,1.126762,5.43896,,,,,,,,,,,,,, -46500,1.4774572,3.376029,,,,,,,,,,,,,, -46600,1.4115365,3.3433228,,,,,,,,,,,,,, -46604,,,0.619433581829071,1.69148051738739,0.5701199769973755,1.932340741157532,50000.0,0.4522000253200531,2.561401605606079,10000.0,21051.00689411164,22601.638983488083,21051.00689411164,1546.635691165924,1.6452360153198242,0.0 -46700,1.5449744,3.3630855,,,,,,,,,,,,,, -46800,1.1025742,4.9892297,,,,,,,,,,,,,, -46900,1.3527492,4.1089687,,,,,,,,,,,,,, -47000,1.118281,5.0452676,,,,,,,,,,,,,, -47100,1.6079872,3.5033865,,,,,,,,,,,,,, -47200,1.3903711,5.4596677,,,,,,,,,,,,,, -47300,1.1608381,4.9453,,,,,,,,,,,,,, -47400,1.7810519,3.429709,,,,,,,,,,,,,, -47500,1.16977,5.5614676,,,,,,,,,,,,,, -47533,,,0.6141015291213989,1.700886845588684,0.5776199698448181,1.8775568008422847,50000.0,0.4607000350952148,2.502983331680298,10000.0,21470.93300724029,23052.105201005936,21470.93300724029,1577.0918953418732,1.681839942932129,0.0 -47600,1.3398762,3.7310066,,,,,,,,,,,,,, -47700,1.2952414,5.411546,,,,,,,,,,,,,, -47800,1.6673961,3.463544,,,,,,,,,,,,,, -47900,1.5325861,3.4736683,,,,,,,,,,,,,, -48000,1.0425923,5.385129,,,,,,,,,,,,,, -48100,1.4289305,4.0737844,,,,,,,,,,,,,, -48200,1.5105004,3.6388733,,,,,,,,,,,,,, -48300,1.1881602,4.9044037,,,,,,,,,,,,,, -48400,1.3991656,3.490676,,,,,,,,,,,,,, -48459,,,0.6114843487739563,1.7474989891052246,0.570580005645752,1.9303126335144043,50000.0,0.4547000229358673,2.565771579742432,10000.0,21891.24653053284,23505.38790154457,21891.24653053284,1609.9773399829865,1.718390703201294,0.0 -48500,1.5492944,3.3265905,,,,,,,,,,,,,, -48600,1.4033948,3.3411312,,,,,,,,,,,,,, -48700,1.4705865,3.3240333,,,,,,,,,,,,,, -48800,1.2157649,5.163315,,,,,,,,,,,,,, -48900,1.4349469,3.4185576,,,,,,,,,,,,,, -49000,1.4469514,3.3382974,,,,,,,,,,,,,, -49100,1.6358638,3.4407556,,,,,,,,,,,,,, -49200,1.4610754,3.8397377,,,,,,,,,,,,,, -49300,1.2257738,5.072008,,,,,,,,,,,,,, -49388,,,0.6169726252555847,1.6989887952804563,0.5762199759483337,1.9017544984817505,50000.0,0.4626000225543976,2.526845216751098,10000.0,22311.626733779907,23959.685015916824,22311.626733779907,1643.8127937316897,1.7531659603118896,0.0 -49400,1.5136553,3.39428,,,,,,,,,,,,,, -49500,1.4290792,3.394433,,,,,,,,,,,,,, -49600,1.8176649,3.4375687,,,,,,,,,,,,,, -49700,1.4889165,3.403417,,,,,,,,,,,,,, -49800,1.5380545,3.23837,,,,,,,,,,,,,, -49900,1.5000043,3.2893047,,,,,,,,,,,,,, -50000,1.4587537,3.358345,,,,,,,,,,,,,, -50100,1.2461137,4.8344746,,,,,,,,,,,,,, -50200,1.6771919,3.2898974,,,,,,,,,,,,,, -50300,1.583646,3.403559,,,,,,,,,,,,,, -50318,,,0.6458203196525574,1.5972343683242798,0.5758799910545349,1.9063783884048464,50000.0,0.4619000256061554,2.520209550857544,10000.0,22731.99612236023,24413.78429436684,22731.99612236023,1677.461537361145,1.786334991455078,0.0 -50400,1.6280615,3.4326966,,,,,,,,,,,,,, -50500,1.8729182,3.5332377,,,,,,,,,,,,,, -50600,1.4850074,3.5635936,,,,,,,,,,,,,, -50700,1.4145579,3.5591724,,,,,,,,,,,,,, -50800,1.5991133,3.3990755,,,,,,,,,,,,,, -50900,1.516229,3.3138998,,,,,,,,,,,,,, -51000,1.2276684,5.5308905,,,,,,,,,,,,,, -51100,1.590357,3.3216786,,,,,,,,,,,,,, -51200,1.3884101,3.814199,,,,,,,,,,,,,, -51246,,,0.613476574420929,1.7316758632659912,0.5759199857711792,1.9049936532974243,50000.0,0.4612000286579132,2.528799057006836,10000.0,23151.934993743896,24866.39386487007,23151.934993743896,1710.0534682273865,1.8191730976104736,0.0 -51300,1.7605476,3.39189,,,,,,,,,,,,,, -51400,1.716011,3.306998,,,,,,,,,,,,,, -51500,1.4972161,3.3174756,,,,,,,,,,,,,, -51600,1.2870432,4.070635,,,,,,,,,,,,,, -51700,1.7075392,3.3683028,,,,,,,,,,,,,, -51800,1.597864,3.3512533,,,,,,,,,,,,,, -51900,1.5593532,3.2107508,,,,,,,,,,,,,, -52000,1.4623592,3.2762983,,,,,,,,,,,,,, -52100,1.4389143,5.0182295,,,,,,,,,,,,,, -52175,,,0.6240624785423279,1.667878031730652,0.5776399970054626,1.88678514957428,50000.0,0.465800017118454,2.502259492874145,10000.0,23571.889559984207,25319.559384584427,23571.889559984207,1743.183295249939,1.854111909866333,0.0 -52200,1.4214526,3.3467739,,,,,,,,,,,,,, -52300,1.3103256,3.893237,,,,,,,,,,,,,, -52400,1.6342776,3.3181117,,,,,,,,,,,,,, -52500,1.594913,3.2439706,,,,,,,,,,,,,, -52600,1.5911653,3.2942357,,,,,,,,,,,,,, -52700,1.4548286,3.3393967,,,,,,,,,,,,,, -52800,1.4478018,3.924406,,,,,,,,,,,,,, -52900,1.4969407,3.332124,,,,,,,,,,,,,, -53000,1.4444377,3.325203,,,,,,,,,,,,,, -53100,1.6128597,3.2576942,,,,,,,,,,,,,, -53104,,,0.6388866901397705,1.6128058433532717,0.5816599726676941,1.8688938617706297,50000.0,0.4680000245571136,2.480672836303711,10000.0,23992.10686635971,25771.13542485237,23992.10686635971,1774.4506244659424,1.8986506462097168,0.0 -53200,1.633606,3.412941,,,,,,,,,,,,,, -53300,1.4454278,5.5699534,,,,,,,,,,,,,, -53400,1.5806456,3.3579283,,,,,,,,,,,,,, -53500,1.2879889,3.7876656,,,,,,,,,,,,,, -53600,1.1127051,4.880122,,,,,,,,,,,,,, -53700,1.5367419,3.236663,,,,,,,,,,,,,, -53800,1.473068,3.2401276,,,,,,,,,,,,,, -53900,1.6585515,3.4758468,,,,,,,,,,,,,, -54000,1.4661461,5.0854897,,,,,,,,,,,,,, -54034,,,0.6194140315055847,1.7486610412597656,0.5802599787712097,1.9259588718414309,50000.0,0.460500031709671,2.5516300201416016,10000.0,24412.24143385887,26224.815851688385,24412.24143385887,1807.9144456386568,1.9343111515045168,0.0 -54100,1.1474575,5.1459966,,,,,,,,,,,,,, -54200,1.5927054,3.414129,,,,,,,,,,,,,, -54300,1.0830957,5.4727736,,,,,,,,,,,,,, -54400,1.6404752,3.3592107,,,,,,,,,,,,,, -54500,1.5904379,3.3433995,,,,,,,,,,,,,, -54600,1.6996552,3.1777983,,,,,,,,,,,,,, -54700,1.3949658,4.2593217,,,,,,,,,,,,,, -54800,1.193542,5.5585938,,,,,,,,,,,,,, -54900,1.6565423,3.3634434,,,,,,,,,,,,,, -54963,,,0.6254101395606995,1.6814249753952026,0.5814599990844727,1.8866448402404783,50000.0,0.4665000140666961,2.5163378715515137,10000.0,24832.460919380188,26678.31806921959,24832.460919380188,1841.114578008652,1.9702081680297847,0.0 -55000,1.466287,3.2393146,,,,,,,,,,,,,, -55100,1.544162,4.1621504,,,,,,,,,,,,,, -55200,1.570473,3.9187427,,,,,,,,,,,,,, -55300,1.2713845,5.4563913,,,,,,,,,,,,,, -55400,1.6275862,3.3220797,,,,,,,,,,,,,, -55500,1.5624608,3.8225248,,,,,,,,,,,,,, -55600,1.6983917,3.275096,,,,,,,,,,,,,, -55700,1.4684843,3.5947244,,,,,,,,,,,,,, -55800,1.3221428,5.2710643,,,,,,,,,,,,,, -55891,,,0.6341210603713989,1.6259918212890625,0.5867399573326111,1.858447790145874,50000.0,0.4720000326633453,2.4755938053131104,10000.0,25252.844719171524,27131.659603118896,25252.844719171524,1873.9938821792605,2.003046751022339,0.0 -55900,1.6186309,3.284651,,,,,,,,,,,,,, -56000,1.2927103,4.887601,,,,,,,,,,,,,, -56100,1.524312,3.3592825,,,,,,,,,,,,,, -56200,1.5393701,3.1775517,,,,,,,,,,,,,, -56300,1.530463,3.3327022,,,,,,,,,,,,,, -56400,1.4741974,3.2428856,,,,,,,,,,,,,, -56500,1.2478086,5.2956066,,,,,,,,,,,,,, -56600,1.5584697,3.6693661,,,,,,,,,,,,,, -56700,1.5296992,3.3560143,,,,,,,,,,,,,, -56800,1.3949692,3.5220418,,,,,,,,,,,,,, -56820,,,0.6281445026397705,1.6840720176696775,0.5877199769020081,1.8711938858032229,50000.0,0.4678000211715698,2.5007386207580566,10000.0,25672.96810555458,27585.487575769424,25672.96810555458,1907.6127750873568,2.042140007019043,0.0 -56900,1.2457209,4.565387,,,,,,,,,,,,,, -57000,1.6210195,4.545428,,,,,,,,,,,,,, -57100,1.4034221,3.703669,,,,,,,,,,,,,, -57200,1.2304327,5.1412306,,,,,,,,,,,,,, -57300,1.3815626,3.7859707,,,,,,,,,,,,,, -57400,1.4897227,3.4372022,,,,,,,,,,,,,, -57500,1.5388633,3.5433838,,,,,,,,,,,,,, -57600,1.3416172,4.0752273,,,,,,,,,,,,,, -57700,1.3274527,4.4514236,,,,,,,,,,,,,, -57752,,,0.6265624761581421,1.685849905014038,0.5827000141143799,1.870428204536438,50000.0,0.4648000299930572,2.495413303375244,10000.0,26093.019316911697,28038.671048879623,26093.019316911697,1940.665373325348,2.074948310852051,0.0 -57800,1.4732401,3.3425696,,,,,,,,,,,,,, -57900,1.6026951,3.3739784,,,,,,,,,,,,,, -58000,1.3084294,4.5093217,,,,,,,,,,,,,, -58100,1.7138193,3.4185822,,,,,,,,,,,,,, -58200,1.3194606,4.9709864,,,,,,,,,,,,,, -58300,1.4006933,4.0146666,,,,,,,,,,,,,, -58400,1.727324,3.4424834,,,,,,,,,,,,,, -58500,1.5358521,3.6358302,,,,,,,,,,,,,, -58600,1.2303649,4.6724834,,,,,,,,,,,,,, -58682,,,0.6369921565055847,1.610442280769348,0.5866000056266785,1.8379977941513064,50000.0,0.4708000123500824,2.459522008895874,10000.0,26513.309331178665,28492.00569462776,26513.309331178665,1973.6229138374329,2.110410690307617,0.0 -58700,1.7106512,3.2960887,,,,,,,,,,,,,, -58800,1.2717227,4.502632,,,,,,,,,,,,,, -58900,1.5424879,3.317587,,,,,,,,,,,,,, -59000,1.511528,3.4755104,,,,,,,,,,,,,, -59100,1.6548245,3.417698,,,,,,,,,,,,,, -59200,1.5771497,3.3192942,,,,,,,,,,,,,, -59300,1.7651021,3.3884597,,,,,,,,,,,,,, -59400,1.5127798,3.4574523,,,,,,,,,,,,,, -59500,1.4913204,3.632424,,,,,,,,,,,,,, -59600,1.404361,4.29772,,,,,,,,,,,,,, -59613,,,0.6574023365974426,1.540103316307068,0.5890399813652039,1.8447201251983645,50000.0,0.4741000235080719,2.4697728157043457,10000.0,26933.473502397537,28945.743367433548,26933.473502397537,2007.113491773605,2.146523952484131,0.0 -59700,1.7333475,3.276532,,,,,,,,,,,,,, -59800,1.2198275,5.4454923,,,,,,,,,,,,,, -59900,1.5347668,3.2300375,,,,,,,,,,,,,, -60000,1.5646878,3.3343587,,,,,,,,,,,,,, -60100,1.5539925,3.3896043,,,,,,,,,,,,,, -60200,1.5907265,3.223927,,,,,,,,,,,,,, -60300,1.5186199,3.2295966,,,,,,,,,,,,,, -60400,1.4483762,3.1919036,,,,,,,,,,,,,, -60500,1.2973977,5.2114906,,,,,,,,,,,,,, -60543,,,0.6248632669448853,1.6834291219711304,0.589419960975647,1.868920922279358,50000.0,0.4699000120162964,2.503729820251465,10000.0,27353.815348386765,29398.575788736343,27353.815348386765,2039.523863077164,2.1801421642303467,0.0 -60600,1.6712247,3.6148689,,,,,,,,,,,,,, -60700,1.5961783,3.3264868,,,,,,,,,,,,,, -60800,1.7008924,3.3185508,,,,,,,,,,,,,, -60900,1.6048292,3.4787753,,,,,,,,,,,,,, -61000,1.6690444,3.2809875,,,,,,,,,,,,,, -61100,1.3093863,5.37579,,,,,,,,,,,,,, -61200,1.6000197,3.3170805,,,,,,,,,,,,,, -61300,1.6398002,3.5996418,,,,,,,,,,,,,, -61400,1.4708973,3.4965022,,,,,,,,,,,,,, -61472,,,0.6370507478713989,1.664278268814087,0.5917400121688843,1.8678573369979856,50000.0,0.4755000174045563,2.495166063308716,10000.0,27774.116877794266,29849.922494888306,27774.116877794266,2070.4793763160706,2.2234983444213867,0.0 -61500,1.1971319,5.3072715,,,,,,,,,,,,,, -61600,1.5658984,3.2866232,,,,,,,,,,,,,, -61700,1.3855885,3.6501276,,,,,,,,,,,,,, -61800,1.7093127,3.356424,,,,,,,,,,,,,, -61900,1.4606965,3.6802585,,,,,,,,,,,,,, -62000,1.6060573,3.2625382,,,,,,,,,,,,,, -62100,1.3728687,4.8564076,,,,,,,,,,,,,, -62200,1.3604032,4.2696257,,,,,,,,,,,,,, -62300,1.4453868,3.610932,,,,,,,,,,,,,, -62400,1.280266,5.3466244,,,,,,,,,,,,,, -62401,,,0.6454882621765137,1.614796757698059,0.5929200053215027,1.8566802740097048,50000.0,0.4774000346660614,2.467177152633667,10000.0,28194.66119718552,30304.25809168816,28194.66119718552,2104.190548181534,2.257309675216675,0.0 -62500,1.5518398,3.2751791,,,,,,,,,,,,,, -62600,1.4246131,3.5291123,,,,,,,,,,,,,, -62700,1.3520863,4.182805,,,,,,,,,,,,,, -62800,1.585879,3.463317,,,,,,,,,,,,,, -62900,1.7494965,3.2075531,,,,,,,,,,,,,, -63000,1.4069262,5.110101,,,,,,,,,,,,,, -63100,1.5857664,3.2999127,,,,,,,,,,,,,, -63200,1.4360417,4.6878633,,,,,,,,,,,,,, -63300,1.4802214,3.32704,,,,,,,,,,,,,, -63332,,,0.63623046875,1.6195902824401855,0.5971999764442444,1.8102082014083865,50000.0,0.4755000174045563,2.443674802780152,10000.0,28614.79931879044,30758.743307828903,28614.79931879044,2138.450407981873,2.2979867458343506,0.0 -63400,1.5510802,3.3362944,,,,,,,,,,,,,, -63500,1.6119449,3.227422,,,,,,,,,,,,,, -63600,1.5463982,3.1676483,,,,,,,,,,,,,, -63700,1.6063641,3.2980978,,,,,,,,,,,,,, -63800,1.5105834,3.4666467,,,,,,,,,,,,,, -63900,1.6753135,3.3138952,,,,,,,,,,,,,, -64000,1.8265452,3.2751422,,,,,,,,,,,,,, -64100,1.3243738,4.053977,,,,,,,,,,,,,, -64200,1.3172097,4.489782,,,,,,,,,,,,,, -64262,,,0.6394140720367432,1.6281222105026243,0.5967199802398682,1.819630146026612,50000.0,0.4784000217914581,2.4457361698150635,10000.0,29035.028126716614,31212.194100141525,29035.028126716614,2171.592320203781,2.3315210342407227,0.0 -64300,1.5460272,3.774381,,,,,,,,,,,,,, -64400,1.500945,3.3355203,,,,,,,,,,,,,, -64500,1.851148,3.2180886,,,,,,,,,,,,,, -64600,1.3408651,4.78368,,,,,,,,,,,,,, -64700,1.4464903,4.04099,,,,,,,,,,,,,, -64800,1.521911,3.4135334,,,,,,,,,,,,,, -64900,1.6117345,3.2571075,,,,,,,,,,,,,, -65000,1.5430782,3.2272463,,,,,,,,,,,,,, -65100,1.5071616,3.88214,,,,,,,,,,,,,, -65192,,,0.6471484303474426,1.5985087156295776,0.5951799750328064,1.824656248092652,50000.0,0.4735000133514404,2.464690685272217,10000.0,29454.9858648777,31665.85106754303,29454.9858648777,2205.211054801941,2.3653695583343506,0.0 -65200,1.6372185,3.0890539,,,,,,,,,,,,,, -65300,1.4749875,4.038753,,,,,,,,,,,,,, -65400,1.5705001,3.3190093,,,,,,,,,,,,,, -65500,1.2119148,4.64528,,,,,,,,,,,,,, -65600,1.614201,3.3685482,,,,,,,,,,,,,, -65700,1.7560233,3.3247216,,,,,,,,,,,,,, -65800,1.8439959,3.255378,,,,,,,,,,,,,, -65900,1.4566224,3.512159,,,,,,,,,,,,,, -66000,1.5149843,3.748622,,,,,,,,,,,,,, -66100,1.383101,5.443328,,,,,,,,,,,,,, -66119,,,0.6446288824081421,1.5758085250854492,0.6003199815750122,1.7816158533096311,50000.0,0.4864000082015991,2.4119138717651367,10000.0,29875.15716052056,32119.37609243393,29875.15716052056,2238.4791843891144,2.404351234436035,0.0 -66200,1.3923783,4.9064655,,,,,,,,,,,,,, -66300,1.6109937,3.4472408,,,,,,,,,,,,,, -66400,1.2960572,4.530588,,,,,,,,,,,,,, -66500,1.4925354,4.800213,,,,,,,,,,,,,, -66600,1.7468122,3.2502487,,,,,,,,,,,,,, -66700,1.5268018,3.4639065,,,,,,,,,,,,,, -66800,1.6193361,3.5454278,,,,,,,,,,,,,, -66900,1.3246198,4.905059,,,,,,,,,,,,,, -67000,1.6230942,3.376996,,,,,,,,,,,,,, -67048,,,0.6425976157188416,1.5925278663635254,0.6013799905776978,1.785370945930481,50000.0,0.4790000319480896,2.427741765975952,10000.0,30295.264727830887,32573.9971203804,30295.264727830887,2272.9096236228943,2.440931558609009,0.0 -67100,1.3707389,4.1988864,,,,,,,,,,,,,, -67200,1.48074,3.666788,,,,,,,,,,,,,, -67300,1.3821392,5.0626993,,,,,,,,,,,,,, -67400,1.3930005,4.7600775,,,,,,,,,,,,,, -67500,1.8487306,3.2490902,,,,,,,,,,,,,, -67600,1.5783875,3.7940726,,,,,,,,,,,,,, -67700,1.5975358,3.1741781,,,,,,,,,,,,,, -67800,1.6127995,3.3048775,,,,,,,,,,,,,, -67900,1.4310274,5.1812,,,,,,,,,,,,,, -67979,,,0.6479296684265137,1.6087180376052856,0.5970799922943115,1.8276971578598025,50000.0,0.4832000136375427,2.4544215202331543,10000.0,30715.514212608337,33025.66438269615,30715.514212608337,2304.2412803173065,2.480074882507324,0.0 -68000,1.6148231,3.2007213,,,,,,,,,,,,,, -68100,1.5369452,3.3135386,,,,,,,,,,,,,, -68200,1.5274849,3.4302545,,,,,,,,,,,,,, -68300,1.5880595,3.3191383,,,,,,,,,,,,,, -68400,1.7159281,3.239869,,,,,,,,,,,,,, -68500,1.6272382,3.3878908,,,,,,,,,,,,,, -68600,1.4372176,3.7649958,,,,,,,,,,,,,, -68700,1.5655645,3.189285,,,,,,,,,,,,,, -68800,1.6242944,3.256446,,,,,,,,,,,,,, -68900,1.5608466,3.1745715,,,,,,,,,,,,,, -68905,,,0.6702734231948853,1.4906214475631714,0.6045799851417542,1.7795348167419434,50000.0,0.4932000339031219,2.402475595474243,10000.0,31135.559599637985,33478.6574652195,31135.559599637985,2337.104706287384,2.5181806087493896,0.0 -69000,1.3903779,4.606887,,,,,,,,,,,,,, -69100,1.6528778,3.2972803,,,,,,,,,,,,,, -69200,1.439726,3.439841,,,,,,,,,,,,,, -69300,1.2006937,4.3700533,,,,,,,,,,,,,, -69400,1.5830716,3.196512,,,,,,,,,,,,,, -69500,1.6197866,3.345324,,,,,,,,,,,,,, -69600,1.563385,4.564304,,,,,,,,,,,,,, -69700,1.5792555,3.2932878,,,,,,,,,,,,,, -69800,1.6714313,3.303814,,,,,,,,,,,,,, -69834,,,0.6407226324081421,1.5906224250793457,0.6022399663925171,1.764672040939331,50000.0,0.4812000095844269,2.3967788219451904,10000.0,31555.447756052017,33931.170246362686,31555.447756052017,2369.4204025268555,2.779877185821533,0.0 -69900,1.4442532,3.385446,,,,,,,,,,,,,, -70000,1.448508,3.5910912,,,,,,,,,,,,,, -70100,1.6007843,3.2848728,,,,,,,,,,,,,, -70200,1.838414,3.402664,,,,,,,,,,,,,, -70300,1.5043806,3.5207863,,,,,,,,,,,,,, -70400,1.3813827,4.3267603,,,,,,,,,,,,,, -70500,1.5056862,3.2320735,,,,,,,,,,,,,, -70600,1.385728,3.9413056,,,,,,,,,,,,,, -70700,2.0951982,3.1762745,,,,,,,,,,,,,, -70761,,,0.6514452695846558,1.564841866493225,0.6045599579811096,1.7766313552856443,50000.0,0.4848000109195709,2.404025077819824,10000.0,31975.477875947952,34384.702178001404,31975.477875947952,2402.840269088745,2.8157126903533936,0.0 -70800,1.5936773,3.234801,,,,,,,,,,,,,, -70900,1.7506487,3.2382107,,,,,,,,,,,,,, -71000,1.7640917,3.2366457,,,,,,,,,,,,,, -71100,1.657313,3.1464522,,,,,,,,,,,,,, -71200,1.4210569,3.940122,,,,,,,,,,,,,, -71300,1.6936204,3.252139,,,,,,,,,,,,,, -71400,1.6777401,3.2024493,,,,,,,,,,,,,, -71500,1.451242,3.7494607,,,,,,,,,,,,,, -71600,1.7187905,3.254894,,,,,,,,,,,,,, -71688,,,0.6606835722923279,1.527460694313049,0.6034600138664246,1.7711652517318726,50000.0,0.484000027179718,2.400637149810791,10000.0,32395.61031579972,34838.37434220314,32395.61031579972,2436.2950756549835,2.854088068008423,0.0 -71700,1.7913072,3.1899087,,,,,,,,,,,,,, -71800,1.4532273,3.7219796,,,,,,,,,,,,,, -71900,1.631445,3.4369152,,,,,,,,,,,,,, -72000,1.2416964,4.9441595,,,,,,,,,,,,,, -72100,1.7419086,3.250551,,,,,,,,,,,,,, -72200,1.6736057,3.173228,,,,,,,,,,,,,, -72300,1.4532895,3.8281417,,,,,,,,,,,,,, -72400,1.4523897,3.811698,,,,,,,,,,,,,, -72500,1.4506003,4.5044117,,,,,,,,,,,,,, -72600,1.9699742,3.3077102,,,,,,,,,,,,,, -72615,,,0.6470116972923279,1.589882493019104,0.6062799692153931,1.779320240020752,50000.0,0.4868000149726867,2.4277710914611816,10000.0,32815.5456404686,35290.03398799896,32815.5456404686,2467.9322276115417,2.894427299499512,0.0 -72700,1.4749649,3.616222,,,,,,,,,,,,,, -72800,1.6038942,3.1326878,,,,,,,,,,,,,, -72900,1.6549307,3.1884608,,,,,,,,,,,,,, -73000,1.6519722,3.1757696,,,,,,,,,,,,,, -73100,1.5007628,3.5636034,,,,,,,,,,,,,, -73200,1.3204706,5.0611467,,,,,,,,,,,,,, -73300,1.5272584,4.228923,,,,,,,,,,,,,, -73400,1.4967612,4.340715,,,,,,,,,,,,,, -73500,1.4140313,3.6504025,,,,,,,,,,,,,, -73541,,,0.6502929329872131,1.571126103401184,0.6121199727058411,1.750990629196167,50000.0,0.4923000335693359,2.3877408504486084,10000.0,33235.80899262428,35743.71882414818,33235.80899262428,2501.2704651355743,2.93189001083374,0.0 -73600,1.3719411,4.6851797,,,,,,,,,,,,,, -73700,1.7635225,3.2041862,,,,,,,,,,,,,, -73800,1.6749934,3.3052497,,,,,,,,,,,,,, -73900,1.219942,4.710918,,,,,,,,,,,,,, -74000,1.8259461,3.1943355,,,,,,,,,,,,,, -74100,1.2808956,4.3322515,,,,,,,,,,,,,, -74200,1.7016622,3.4004238,,,,,,,,,,,,,, -74300,1.2914051,4.8895025,,,,,,,,,,,,,, -74400,1.5114261,4.6713076,,,,,,,,,,,,,, -74470,,,0.6590234041213989,1.541833758354187,0.6091399788856506,1.7622010707855225,50000.0,0.4879000186920166,2.3940000534057617,10000.0,33655.84841275215,36197.356711387634,33655.84841275215,2534.7862520217896,2.968616485595703,0.0 -74500,1.9069431,3.2425773,,,,,,,,,,,,,, -74600,1.6921805,3.2763405,,,,,,,,,,,,,, -74700,1.760311,3.2436924,,,,,,,,,,,,,, -74800,1.3555686,4.940134,,,,,,,,,,,,,, -74900,1.559022,3.4955623,,,,,,,,,,,,,, -75000,1.3474078,5.414615,,,,,,,,,,,,,, -75100,1.7852012,3.0841923,,,,,,,,,,,,,, -75200,1.6969706,3.0956845,,,,,,,,,,,,,, -75300,1.9136901,3.256002,,,,,,,,,,,,,, -75400,,,0.6696093678474426,1.4805610179901123,0.6134999990463257,1.732527732849121,50000.0,0.4926000237464905,2.362853527069092,10000.0,34076.092334747314,36651.11418533325,34076.092334747314,2568.215485572815,3.006901502609253,0.0 -75400,1.6466106,3.445284,,,,,,,,,,,,,, -75500,1.5155796,3.4126754,,,,,,,,,,,,,, -75600,1.5657555,3.3142931,,,,,,,,,,,,,, -75700,1.5359346,3.799448,,,,,,,,,,,,,, -75800,1.6187593,3.1789863,,,,,,,,,,,,,, -75900,1.4478862,4.1988244,,,,,,,,,,,,,, -76000,1.3584415,4.949308,,,,,,,,,,,,,, -76100,1.5320574,4.1327214,,,,,,,,,,,,,, -76200,1.5623406,3.1275542,,,,,,,,,,,,,, -76300,1.6147112,5.0500584,,,,,,,,,,,,,, -76330,,,0.65869140625,1.5325862169265747,0.614300012588501,1.7345703840255735,50000.0,0.4974000155925751,2.353220701217652,10000.0,34496.37813591957,37106.03781723976,34496.37813591957,2602.766112804413,3.0473546981811523,0.0 -76400,1.549557,5.383623,,,,,,,,,,,,,, -76500,1.656561,3.4753735,,,,,,,,,,,,,, -76600,1.4812971,5.338188,,,,,,,,,,,,,, -76700,1.4522182,4.1549516,,,,,,,,,,,,,, -76800,1.6850245,3.202434,,,,,,,,,,,,,, -76900,1.4600761,3.4783762,,,,,,,,,,,,,, -77000,1.587429,3.1696677,,,,,,,,,,,,,, -77100,1.6373155,3.0575323,,,,,,,,,,,,,, -77200,1.819743,3.0992584,,,,,,,,,,,,,, -77258,,,0.6606054306030273,1.511372208595276,0.6152399778366089,1.7202774286270142,50000.0,0.4937000274658203,2.351444721221924,10000.0,34916.623254299164,37559.12559890747,34916.623254299164,2635.522970676422,3.086964130401612,0.0 -77300,1.8877928,3.290936,,,,,,,,,,,,,, -77400,1.4008225,4.8419104,,,,,,,,,,,,,, -77500,1.6534113,3.2519565,,,,,,,,,,,,,, -77600,1.7911097,3.247864,,,,,,,,,,,,,, -77700,1.5751973,5.1512585,,,,,,,,,,,,,, -77800,1.5881946,5.238654,,,,,,,,,,,,,, -77900,1.5427307,4.782575,,,,,,,,,,,,,, -78000,1.4051714,5.458396,,,,,,,,,,,,,, -78100,1.8262848,3.0812893,,,,,,,,,,,,,, -78184,,,0.6808788776397705,1.4191409349441528,0.6181600093841553,1.7033716440200806,50000.0,0.5019000172615051,2.3303937911987305,10000.0,35336.7923810482,38010.754996299744,35336.7923810482,2666.8984639644623,3.1255970001220703,0.0 -78200,1.6579331,3.1747675,,,,,,,,,,,,,, -78300,1.3452419,4.074297,,,,,,,,,,,,,, -78400,1.3059865,5.205759,,,,,,,,,,,,,, -78500,1.6232909,3.1769722,,,,,,,,,,,,,, -78600,1.6222404,3.4397173,,,,,,,,,,,,,, -78700,1.6593573,3.534674,,,,,,,,,,,,,, -78800,1.8825696,3.0849552,,,,,,,,,,,,,, -78900,2.09979,3.2862525,,,,,,,,,,,,,, -79000,1.6253668,3.1340532,,,,,,,,,,,,,, -79100,1.7473594,3.2428021,,,,,,,,,,,,,, -79112,,,0.661816418170929,1.5143336057662964,0.6148399710655212,1.7183010578155518,50000.0,0.4924000203609466,2.3449583053588867,10000.0,35757.1461520195,38464.51617407799,35757.1461520195,2700.217987060547,3.1663384437561035,0.0 -79200,1.6289942,3.2947152,,,,,,,,,,,,,, -79300,1.8546233,3.144015,,,,,,,,,,,,,, -79400,1.6059825,4.265485,,,,,,,,,,,,,, -79500,1.6344967,4.161994,,,,,,,,,,,,,, -79600,1.8374072,3.1286743,,,,,,,,,,,,,, -79700,1.7390741,3.2000618,,,,,,,,,,,,,, -79800,1.3757579,5.255627,,,,,,,,,,,,,, -79900,1.7851475,3.2444737,,,,,,,,,,,,,, -80000,1.4903262,5.28271,,,,,,,,,,,,,, -80043,,,0.664746105670929,1.510551691055298,0.615339994430542,1.7205344438552856,50000.0,0.4970000088214874,2.373055934906006,10000.0,36177.24590039253,38918.11818599701,36177.24590039253,2733.6324348449707,3.207908391952514,0.0 -80100,1.6651452,3.0998762,,,,,,,,,,,,,, -80200,1.3631684,5.3160415,,,,,,,,,,,,,, -80300,1.8297608,3.236248,,,,,,,,,,,,,, -80400,1.7094059,3.1510901,,,,,,,,,,,,,, -80500,1.6876953,3.0763986,,,,,,,,,,,,,, -80600,1.5740571,4.405403,,,,,,,,,,,,,, -80700,1.7491082,3.343801,,,,,,,,,,,,,, -80800,1.834536,3.2352777,,,,,,,,,,,,,, -80900,1.8045783,3.2186093,,,,,,,,,,,,,, -80972,,,0.6741796731948853,1.4529361724853516,0.6193599700927734,1.702526330947876,50000.0,0.4974000155925751,2.3454010486602783,10000.0,36597.27389025688,39371.64637541771,36597.27389025688,2767.0463218688965,3.247930765151977,0.0 -81000,1.615451,3.3848433,,,,,,,,,,,,,, -81100,1.7501436,3.151277,,,,,,,,,,,,,, -81200,1.9802092,3.121912,,,,,,,,,,,,,, -81300,1.8484777,3.0350695,,,,,,,,,,,,,, -81400,1.3585562,4.056299,,,,,,,,,,,,,, -81500,1.6764839,3.2895565,,,,,,,,,,,,,, -81600,1.7358135,3.186118,,,,,,,,,,,,,, -81700,1.7495129,3.404297,,,,,,,,,,,,,, -81800,1.4898127,4.8873906,,,,,,,,,,,,,, -81899,,,0.667773425579071,1.502530217170715,0.622759997844696,1.6995172500610352,50000.0,0.5027000308036804,2.336276054382324,10000.0,37017.20155906677,39824.93992829323,37017.20155906677,2800.32315158844,3.2903239727020264,0.0 -81900,1.7903955,3.1299076,,,,,,,,,,,,,, -82000,1.6636994,3.1665049,,,,,,,,,,,,,, -82100,1.8436683,3.1272454,,,,,,,,,,,,,, -82200,1.6144907,3.489172,,,,,,,,,,,,,, -82300,1.8343062,3.1704311,,,,,,,,,,,,,, -82400,1.6578794,3.1764638,,,,,,,,,,,,,, -82500,2.0411167,3.184759,,,,,,,,,,,,,, -82600,1.7834558,3.1329527,,,,,,,,,,,,,, -82700,1.7933286,3.1554196,,,,,,,,,,,,,, -82800,1.4506258,4.330114,,,,,,,,,,,,,, -82829,,,0.6694530844688416,1.4893957376480105,0.6229199767112732,1.6935948133468628,50000.0,0.5026000142097473,2.331036329269409,10000.0,37437.302568912506,40277.142501831055,37437.302568912506,2832.334250688553,3.334193706512451,0.0 -82900,1.6393771,3.222591,,,,,,,,,,,,,, -83000,1.6167864,3.4100902,,,,,,,,,,,,,, -83100,1.7216924,3.1015036,,,,,,,,,,,,,, -83200,1.7896149,3.1147463,,,,,,,,,,,,,, -83300,1.8718532,3.1141324,,,,,,,,,,,,,, -83400,1.4884706,3.9965167,,,,,,,,,,,,,, -83500,1.6702341,2.9488575,,,,,,,,,,,,,, -83600,1.9706806,3.4183612,,,,,,,,,,,,,, -83700,1.6529686,3.0586948,,,,,,,,,,,,,, -83758,,,0.6728710532188416,1.4168633222579956,0.626800000667572,1.644349455833435,50000.0,0.4980000257492065,2.28191351890564,10000.0,37857.52192592621,40730.55080986023,37857.52192592621,2865.438676595688,3.372976303100586,0.0 -83800,1.9149866,3.1426985,,,,,,,,,,,,,, -83900,1.6569302,3.4876823,,,,,,,,,,,,,, -84000,1.6937968,3.1012545,,,,,,,,,,,,,, -84100,1.527559,5.097152,,,,,,,,,,,,,, -84200,1.7135084,3.2102184,,,,,,,,,,,,,, -84300,1.4157379,5.2961283,,,,,,,,,,,,,, -84400,1.5580955,3.560701,,,,,,,,,,,,,, -84500,1.3668655,5.0482936,,,,,,,,,,,,,, -84600,1.421932,5.233435,,,,,,,,,,,,,, -84687,,,0.695019543170929,1.3510807752609253,0.6265199780464172,1.650166630744934,50000.0,0.508400022983551,2.274543285369873,10000.0,38277.640315294266,41185.6047809124,38277.640315294266,2900.2883038520813,3.4122555255889893,0.0 -84700,1.535591,3.9704344,,,,,,,,,,,,,, -84800,1.7113283,3.55724,,,,,,,,,,,,,, -84900,1.4672258,3.9425676,,,,,,,,,,,,,, -85000,1.9341813,3.0520957,,,,,,,,,,,,,, -85100,1.5132719,5.3417625,,,,,,,,,,,,,, -85200,1.8778862,3.1686208,,,,,,,,,,,,,, -85300,1.6496454,3.562648,,,,,,,,,,,,,, -85400,1.8992847,3.1384797,,,,,,,,,,,,,, -85500,1.7442905,3.0861757,,,,,,,,,,,,,, -85600,1.6702963,4.1611304,,,,,,,,,,,,,, -85617,,,0.6769140362739563,1.4378875494003296,0.6297799944877625,1.655285120010376,50000.0,0.5103999972343445,2.277604818344116,10000.0,38697.576645851135,41639.67412424088,38697.576645851135,2934.331175804138,3.451533555984497,0.0 -85700,1.6100192,3.0412745,,,,,,,,,,,,,, -85800,1.5728196,3.5960732,,,,,,,,,,,,,, -85900,1.744521,3.100101,,,,,,,,,,,,,, -86000,1.8137882,3.0860898,,,,,,,,,,,,,, -86100,2.0418026,3.232749,,,,,,,,,,,,,, -86200,1.6075552,3.5182986,,,,,,,,,,,,,, -86300,1.7053493,3.2032146,,,,,,,,,,,,,, -86400,1.5450846,4.945646,,,,,,,,,,,,,, -86500,1.5269201,4.331631,,,,,,,,,,,,,, -86545,,,0.676074206829071,1.442645788192749,0.6266199946403503,1.6573635339736938,50000.0,0.5065000057220459,2.2886831760406494,10000.0,39117.53278756142,42091.563891649246,39117.53278756142,2966.1738238334656,3.496487617492676,0.0 -86600,1.697952,2.9985833,,,,,,,,,,,,,, -86700,1.8723627,3.0971937,,,,,,,,,,,,,, -86800,1.8170841,3.2282162,,,,,,,,,,,,,, -86900,1.7495631,3.4161603,,,,,,,,,,,,,, -87000,1.5109338,4.94132,,,,,,,,,,,,,, -87100,1.8200208,3.1798863,,,,,,,,,,,,,, -87200,1.8457725,3.2767594,,,,,,,,,,,,,, -87300,1.4311744,5.0712695,,,,,,,,,,,,,, -87400,1.5981779,3.7647686,,,,,,,,,,,,,, -87473,,,0.6855859160423279,1.3951412439346311,0.6272000074386597,1.6594972610473633,50000.0,0.5064000487327576,2.292130470275879,10000.0,39537.75274658203,42546.0910179615,39537.75274658203,3000.3928265571594,3.5384793281555176,0.0 -87500,1.464446,4.3162003,,,,,,,,,,,,,, -87600,1.8498636,3.099661,,,,,,,,,,,,,, -87700,1.9055924,3.156164,,,,,,,,,,,,,, -87800,1.8505518,3.1044323,,,,,,,,,,,,,, -87900,1.76341,3.021379,,,,,,,,,,,,,, -88000,1.9426088,3.1387079,,,,,,,,,,,,,, -88100,1.6781685,3.2492943,,,,,,,,,,,,,, -88200,1.8213048,3.0949438,,,,,,,,,,,,,, -88300,1.6606661,4.835484,,,,,,,,,,,,,, -88400,1.6744258,5.2628794,,,,,,,,,,,,,, -88401,,,0.6751171946525574,1.4573235511779783,0.630299985408783,1.667418122291565,50000.0,0.5100000500679016,2.2937583923339844,10000.0,39957.88946032524,43000.04769778252,39957.88946032524,3034.1299324035645,3.575552463531494,0.0 -88500,1.7645395,3.0792186,,,,,,,,,,,,,, -88600,1.771405,3.01797,,,,,,,,,,,,,, -88700,1.463184,4.0142097,,,,,,,,,,,,,, -88800,1.7509909,3.0268376,,,,,,,,,,,,,, -88900,1.8843076,3.089791,,,,,,,,,,,,,, -89000,1.7166225,3.462524,,,,,,,,,,,,,, -89100,1.761858,3.0872335,,,,,,,,,,,,,, -89200,1.9143051,3.1593342,,,,,,,,,,,,,, -89300,1.3766509,4.267751,,,,,,,,,,,,,, -89327,,,0.6839843392372131,1.394673466682434,0.6331200003623962,1.625125527381897,50000.0,0.5110000371932983,2.2482283115386963,10000.0,40377.935428380966,43453.84268569946,40377.935428380966,3067.7917091846466,3.61670994758606,0.0 -89400,1.6028111,5.0672636,,,,,,,,,,,,,, -89500,1.4808027,4.4686394,,,,,,,,,,,,,, -89600,1.5969114,3.6596777,,,,,,,,,,,,,, -89700,1.4905381,4.7295885,,,,,,,,,,,,,, -89800,1.8199836,3.143184,,,,,,,,,,,,,, -89900,1.6121998,5.1770487,,,,,,,,,,,,,, -90000,1.561992,3.7684205,,,,,,,,,,,,,, -90100,1.5945971,3.6414723,,,,,,,,,,,,,, -90200,1.5336614,3.753573,,,,,,,,,,,,,, -90254,,,0.6922265291213989,1.403250217437744,0.6358599662780762,1.6580744981765747,50000.0,0.5161000490188599,2.281223773956299,10000.0,40798.026288986206,43909.31564474106,40798.026288986206,3103.089093208313,3.654865503311157,0.0 -90300,1.5114598,4.060041,,,,,,,,,,,,,, -90400,1.6108366,4.3967957,,,,,,,,,,,,,, -90500,1.9156661,3.052137,,,,,,,,,,,,,, -90600,1.4132689,4.8180814,,,,,,,,,,,,,, -90700,1.7655369,3.0515528,,,,,,,,,,,,,, -90800,1.5987384,4.6086845,,,,,,,,,,,,,, -90900,1.4889348,4.745558,,,,,,,,,,,,,, -91000,1.9153663,3.1496637,,,,,,,,,,,,,, -91100,1.604513,3.5890675,,,,,,,,,,,,,, -91182,,,0.6755273342132568,1.446442723274231,0.6324999928474426,1.6372050046920776,50000.0,0.5085000395774841,2.2850115299224854,10000.0,41218.17198085785,44362.4227976799,41218.17198085785,3135.961694717407,3.69734001159668,0.0 -91200,2.0461788,3.0218132,,,,,,,,,,,,,, -91300,1.6834015,3.5177824,,,,,,,,,,,,,, -91400,1.928367,2.9622295,,,,,,,,,,,,,, -91500,1.7462767,5.223916,,,,,,,,,,,,,, -91600,1.8287956,2.989306,,,,,,,,,,,,,, -91700,2.0862894,3.1248896,,,,,,,,,,,,,, -91800,1.8345875,3.136034,,,,,,,,,,,,,, -91900,1.8168576,2.9159658,,,,,,,,,,,,,, -92000,1.775284,3.1382203,,,,,,,,,,,,,, -92100,1.7457381,5.026242,,,,,,,,,,,,,, -92112,,,0.6849218606948853,1.3958221673965454,0.6351799964904785,1.6146154403686523,50000.0,0.5118000507354736,2.2473669052124023,10000.0,41638.28119826317,44816.87034630776,41638.28119826317,3170.2124574184418,3.7383203506469727,0.0 -92200,1.899724,3.3627179,,,,,,,,,,,,,, -92300,1.5740267,5.136345,,,,,,,,,,,,,, -92400,1.6077472,4.115576,,,,,,,,,,,,,, -92500,1.9549358,3.0288346,,,,,,,,,,,,,, -92600,1.6578176,3.5526063,,,,,,,,,,,,,, -92700,1.8095626,3.0857942,,,,,,,,,,,,,, -92800,1.9033885,3.049979,,,,,,,,,,,,,, -92900,1.6747737,5.3151865,,,,,,,,,,,,,, -93000,1.7776567,3.2070155,,,,,,,,,,,,,, -93038,,,0.689160168170929,1.3483566045761108,0.6417999863624573,1.5728561878204346,50000.0,0.5232000350952148,2.201004266738892,10000.0,42058.32339930534,45269.29402279854,42058.32339930534,3202.508903503418,3.777261018753052,0.0 -93100,1.7604275,4.2950983,,,,,,,,,,,,,, -93200,1.9924282,2.9913526,,,,,,,,,,,,,, -93300,1.5432438,4.6707196,,,,,,,,,,,,,, -93400,2.1978865,3.1342592,,,,,,,,,,,,,, -93500,1.6571206,3.144568,,,,,,,,,,,,,, -93600,1.7707794,3.927522,,,,,,,,,,,,,, -93700,1.9172703,3.0827818,,,,,,,,,,,,,, -93800,1.61112,5.2163224,,,,,,,,,,,,,, -93900,1.6834991,4.143173,,,,,,,,,,,,,, -93965,,,0.7109375,1.3136073350906372,0.6380199790000916,1.624780297279358,50000.0,0.5166000127792358,2.255450487136841,10000.0,42478.304805994034,45721.74182486534,42478.304805994034,3234.89250922203,3.814180850982666,0.0 -94000,1.653096,3.4237516,,,,,,,,,,,,,, -94100,1.8665165,3.081632,,,,,,,,,,,,,, -94200,1.8336656,3.150565,,,,,,,,,,,,,, -94300,1.6384664,4.7210875,,,,,,,,,,,,,, -94400,1.9535793,2.9184635,,,,,,,,,,,,,, -94500,1.9909322,3.324405,,,,,,,,,,,,,, -94600,1.8821145,2.9971983,,,,,,,,,,,,,, -94700,1.7786249,3.3802629,,,,,,,,,,,,,, -94800,1.6190147,4.8927045,,,,,,,,,,,,,, -94887,,,0.6878905892372131,1.4085044860839844,0.6405199766159058,1.6214958429336548,50000.0,0.5225000381469727,2.2491910457611084,10000.0,42898.46243238449,46175.75550246239,42898.46243238449,3268.6627497673035,3.8543221950531006,0.0 -94900,1.9332582,3.0462475,,,,,,,,,,,,,, -95000,2.037427,3.0064263,,,,,,,,,,,,,, -95100,1.8959684,2.96322,,,,,,,,,,,,,, -95200,2.037424,2.93298,,,,,,,,,,,,,, -95300,1.617914,3.8029065,,,,,,,,,,,,,, -95400,1.4858314,5.213866,,,,,,,,,,,,,, -95500,1.6926892,4.364794,,,,,,,,,,,,,, -95600,1.7625297,3.2524767,,,,,,,,,,,,,, -95700,1.9329638,3.0211105,,,,,,,,,,,,,, -95800,1.7525618,3.3918455,,,,,,,,,,,,,, -95814,,,0.6914257407188416,1.367840051651001,0.6401599645614624,1.597143292427063,50000.0,0.5213000178337097,2.232096910476685,10000.0,43318.54101729393,46629.85206055641,43318.54101729393,3302.586932182312,3.902494430541992,0.0 -95900,1.9418097,3.0324857,,,,,,,,,,,,,, -96000,1.7890157,3.795572,,,,,,,,,,,,,, -96100,1.9949707,5.2457576,,,,,,,,,,,,,, -96200,1.983389,2.9986155,,,,,,,,,,,,,, -96300,1.8070998,2.9013839,,,,,,,,,,,,,, -96400,1.9223261,3.0299149,,,,,,,,,,,,,, -96500,1.7708007,3.744983,,,,,,,,,,,,,, -96600,1.8936634,3.2984798,,,,,,,,,,,,,, -96700,1.9344401,2.9577818,,,,,,,,,,,,,, -96744,,,0.7025195360183716,1.342665433883667,0.6408999562263489,1.617620587348938,50000.0,0.5175999999046326,2.258472681045532,10000.0,43738.83052825928,47084.42543315888,43738.83052825928,3336.776979207993,3.949820041656494,0.0 -96800,2.0279522,3.0646627,,,,,,,,,,,,,, -96900,1.777666,5.1252174,,,,,,,,,,,,,, -97000,1.8582414,2.940416,,,,,,,,,,,,,, -97100,1.5857278,5.1066675,,,,,,,,,,,,,, -97200,1.7489933,5.052432,,,,,,,,,,,,,, -97300,2.0041134,2.9998243,,,,,,,,,,,,,, -97400,1.7414099,4.660385,,,,,,,,,,,,,, -97500,2.0513513,3.0044549,,,,,,,,,,,,,, -97600,2.281904,3.0327134,,,,,,,,,,,,,, -97673,,,0.68896484375,1.3808518648147583,0.6461799740791321,1.570910930633545,50000.0,0.5253000259399414,2.2020792961120605,10000.0,44159.07249808312,47538.14868545532,44159.07249808312,3370.167640209198,3.994457244873047,0.0 -97700,1.6748506,4.2321763,,,,,,,,,,,,,, -97800,1.7981118,3.0105867,,,,,,,,,,,,,, -97900,1.8804735,3.2420754,,,,,,,,,,,,,, -98000,1.6880509,4.353167,,,,,,,,,,,,,, -98100,1.6483063,3.6922226,,,,,,,,,,,,,, -98200,2.0561237,2.9760146,,,,,,,,,,,,,, -98300,1.8130957,4.4133096,,,,,,,,,,,,,, -98400,1.8338147,4.7374043,,,,,,,,,,,,,, -98500,1.9626751,2.9854329,,,,,,,,,,,,,, -98598,,,0.69544917345047,1.3356281518936155,0.6452599763870239,1.56348717212677,50000.0,0.5278000235557556,2.181706428527832,10000.0,44579.15256071091,47991.88709282875,44579.15256071091,3403.740065574646,4.033020973205566,0.0 -98600,1.5960447,4.945903,,,,,,,,,,,,,, -98700,1.936668,2.9605007,,,,,,,,,,,,,, -98800,1.9764801,3.0311844,,,,,,,,,,,,,, -98900,1.8750552,2.9269314,,,,,,,,,,,,,, -99000,1.9450161,2.9590456,,,,,,,,,,,,,, -99100,1.6138191,4.1719046,,,,,,,,,,,,,, -99200,1.8642955,2.9857047,,,,,,,,,,,,,, -99300,1.7384996,3.0487223,,,,,,,,,,,,,, -99400,2.2948916,3.0091336,,,,,,,,,,,,,, -99500,1.7342798,2.9251685,,,,,,,,,,,,,, -99525,,,0.6994921565055847,1.3409777879714966,0.6436600089073181,1.5910606384277344,50000.0,0.5200999975204468,2.2290115356445312,10000.0,44999.109080553055,48445.73757982254,44999.109080553055,3437.5442354679108,4.075830459594727,0.0 -99600,1.9865814,2.9807146,,,,,,,,,,,,,, -99700,1.9464121,3.0971828,,,,,,,,,,,,,, -99800,1.7493724,3.0308843,,,,,,,,,,,,,, -99900,1.9846774,2.9576917,,,,,,,,,,,,,, -100000,1.5882406,4.162073,,,,,,,,,,,,,, -100100,2.0667403,3.0706966,,,,,,,,,,,,,, -100200,1.8213071,3.0167959,,,,,,,,,,,,,, -100300,1.9199923,2.9817753,,,,,,,,,,,,,, -100400,1.9596647,5.044921,,,,,,,,,,,,,, -100453,,,0.6985937356948853,1.3283151388168335,0.6507599949836731,1.5489925146102903,50000.0,0.5254000425338745,2.1810965538024902,10000.0,45419.194039821625,48899.44379377365,45419.194039821625,3471.073740005493,4.120898962020874,0.0 -100500,1.9787602,3.4149086,,,,,,,,,,,,,, -100600,2.2234342,3.1062648,,,,,,,,,,,,,, -100700,1.8394439,2.8852985,,,,,,,,,,,,,, -100800,1.829303,2.9390464,,,,,,,,,,,,,, -100900,1.696145,4.005119,,,,,,,,,,,,,, -101000,2.1675344,2.956339,,,,,,,,,,,,,, -101100,2.0050814,3.0331783,,,,,,,,,,,,,, -101200,2.05441,3.0328374,,,,,,,,,,,,,, -101300,1.9721026,3.005045,,,,,,,,,,,,,, -101381,,,0.7010741829872131,1.3405998945236206,0.6499399542808533,1.5560424327850342,50000.0,0.5294000506401062,2.1788530349731445,10000.0,45839.292186021805,49351.35535264015,45839.292186021805,3502.795699119568,4.166559219360352,0.0 -101400,1.9297451,3.0456717,,,,,,,,,,,,,, -101500,2.021657,2.9491603,,,,,,,,,,,,,, -101600,1.8968742,3.085042,,,,,,,,,,,,,, -101700,1.8939486,3.1811776,,,,,,,,,,,,,, -101800,1.677161,3.7371833,,,,,,,,,,,,,, -101900,2.0375366,3.3394666,,,,,,,,,,,,,, -102000,1.7638358,3.8753448,,,,,,,,,,,,,, -102100,1.9374547,3.0605025,,,,,,,,,,,,,, -102200,1.9793493,3.0351827,,,,,,,,,,,,,, -102300,1.7973744,3.2190397,,,,,,,,,,,,,, -102308,,,0.7044335603713989,1.303029179573059,0.6525200009346008,1.5312974452972412,50000.0,0.5324000120162964,2.1580193042755127,10000.0,46259.55302786827,49805.33374285698,46259.55302786827,3536.4254937171936,4.207376718521118,0.0 -102400,1.8946174,2.9446077,,,,,,,,,,,,,, -102500,1.925964,2.9163375,,,,,,,,,,,,,, -102600,2.3276005,4.6642213,,,,,,,,,,,,,, -102700,2.0360153,2.9252086,,,,,,,,,,,,,, -102800,1.9279839,2.9574966,,,,,,,,,,,,,, -102900,2.074107,3.0145023,,,,,,,,,,,,,, -103000,1.8538659,4.4495125,,,,,,,,,,,,,, -103100,2.0775812,2.9509583,,,,,,,,,,,,,, -103200,1.7683324,4.4388146,,,,,,,,,,,,,, -103236,,,0.721386730670929,1.2593860626220703,0.6481999754905701,1.5756916999816897,50000.0,0.5303000211715698,2.1878247261047363,10000.0,46679.52591824532,50259.024639844894,46679.52591824532,3570.0530354976654,4.251614093780518,0.0 -103300,2.0558813,2.878816,,,,,,,,,,,,,, -103400,1.9730661,5.20181,,,,,,,,,,,,,, -103500,1.7356033,3.5373168,,,,,,,,,,,,,, -103600,2.013007,2.897812,,,,,,,,,,,,,, -103700,1.9616697,2.9606884,,,,,,,,,,,,,, -103800,2.0088985,4.5294566,,,,,,,,,,,,,, -103900,2.1242018,3.0035915,,,,,,,,,,,,,, -104000,1.9218054,2.8650334,,,,,,,,,,,,,, -104100,2.2282863,2.9863179,,,,,,,,,,,,,, -104166,,,0.6998632550239563,1.3183801174163818,0.6520000100135803,1.5319174528121948,50000.0,0.5261000394821167,2.147125005722046,10000.0,47099.60453367233,50713.70186185837,47099.60453367233,3604.563158750534,4.293308973312378,0.0 -104200,2.022621,2.928767,,,,,,,,,,,,,, -104300,2.21822,2.9588494,,,,,,,,,,,,,, -104400,2.129064,2.9366226,,,,,,,,,,,,,, -104500,1.943049,2.9200723,,,,,,,,,,,,,, -104600,1.9820805,2.8565414,,,,,,,,,,,,,, -104700,1.9084535,2.9113178,,,,,,,,,,,,,, -104800,2.0189414,5.058874,,,,,,,,,,,,,, -104900,1.9527774,2.963399,,,,,,,,,,,,,, -105000,1.824135,5.0833874,,,,,,,,,,,,,, -105096,,,0.7125781178474426,1.287738800048828,0.6566999554634094,1.5229451656341553,50000.0,0.5349000096321106,2.143164873123169,10000.0,47519.71596264839,51167.10824346542,47519.71596264839,3637.766434669495,4.338505029678345,0.0 -105100,2.1072857,2.9851987,,,,,,,,,,,,,, -105200,2.0832856,3.012807,,,,,,,,,,,,,, -105300,2.0971184,3.1688304,,,,,,,,,,,,,, -105400,1.8850768,3.8439775,,,,,,,,,,,,,, -105500,2.1352468,3.0059686,,,,,,,,,,,,,, -105600,1.903842,3.2839713,,,,,,,,,,,,,, -105700,2.0630827,2.9069247,,,,,,,,,,,,,, -105800,1.7139673,4.148252,,,,,,,,,,,,,, -105900,2.1566608,3.3316686,,,,,,,,,,,,,, -106000,2.0502384,2.9787464,,,,,,,,,,,,,, -106025,,,0.71156245470047,1.2922778129577637,0.6466999650001526,1.5727142095565796,50000.0,0.5263000130653381,2.194082260131836,10000.0,47939.64816617966,51620.6674015522,47939.64816617966,3671.302904844284,4.382236242294312,0.0 -106100,2.0678916,2.9570494,,,,,,,,,,,,,, -106200,2.2327132,4.9971256,,,,,,,,,,,,,, -106300,1.9923445,2.8818445,,,,,,,,,,,,,, -106400,2.04734,3.0479572,,,,,,,,,,,,,, -106500,2.036531,3.0116568,,,,,,,,,,,,,, -106600,2.0666497,4.540369,,,,,,,,,,,,,, -106700,1.9340914,4.929753,,,,,,,,,,,,,, -106800,1.9200515,3.1203508,,,,,,,,,,,,,, -106900,1.8228621,4.0633445,,,,,,,,,,,,,, -106952,,,0.7048046588897705,1.3123520612716677,0.6563400030136108,1.5270174741744995,50000.0,0.5338000059127808,2.1496691703796387,10000.0,48359.77759027481,52074.03806447983,48359.77759027481,3704.454334497452,4.4254231452941895,0.0 -107000,2.1879466,2.8451102,,,,,,,,,,,,,, -107100,2.2033906,2.973091,,,,,,,,,,,,,, -107200,1.9689503,2.8777187,,,,,,,,,,,,,, -107300,2.4938169,4.3585167,,,,,,,,,,,,,, -107400,2.3373365,3.0073607,,,,,,,,,,,,,, -107500,2.0292656,4.371393,,,,,,,,,,,,,, -107600,1.8622377,4.996272,,,,,,,,,,,,,, -107700,1.9288267,3.185325,,,,,,,,,,,,,, -107800,2.1812477,2.980904,,,,,,,,,,,,,, -107880,,,0.7080664038658142,1.289051175117493,0.6545799970626831,1.529636263847351,50000.0,0.5354000329971313,2.1473042964935303,10000.0,48779.73000144959,52526.886921167374,48779.73000144959,3737.262728214264,4.467191934585571,0.0 -107900,2.0180457,2.914519,,,,,,,,,,,,,, -108000,2.0545254,2.913827,,,,,,,,,,,,,, -108100,1.8004199,3.8903165,,,,,,,,,,,,,, -108200,2.1136796,5.2047,,,,,,,,,,,,,, -108300,1.9577373,3.0724888,,,,,,,,,,,,,, -108400,1.9105259,4.793239,,,,,,,,,,,,,, -108500,1.8830452,3.9147964,,,,,,,,,,,,,, -108600,2.1258028,2.8697324,,,,,,,,,,,,,, -108700,2.0847328,5.071905,,,,,,,,,,,,,, -108800,2.4094281,2.927265,,,,,,,,,,,,,, -108810,,,0.7175390720367432,1.221261978149414,0.6603599786758423,1.4771640300750732,50000.0,0.5387000441551208,2.107815742492676,10000.0,49199.8685810566,52978.13624429703,49199.8685810566,3768.2807273864746,4.513633966445923,0.0 -108900,2.0105445,3.158157,,,,,,,,,,,,,, -109000,2.2915452,3.4107854,,,,,,,,,,,,,, -109100,1.9159497,2.8199153,,,,,,,,,,,,,, -109200,1.9262356,3.0567255,,,,,,,,,,,,,, -109300,2.0152485,2.9043982,,,,,,,,,,,,,, -109400,1.9822589,3.0660641,,,,,,,,,,,,,, -109500,2.0740335,3.1646428,,,,,,,,,,,,,, -109600,1.9488508,3.2393515,,,,,,,,,,,,,, -109700,1.9691318,4.1185436,,,,,,,,,,,,,, -109738,,,0.7140429615974426,1.2802772521972656,0.6586199998855591,1.5189369916915894,50000.0,0.5404000282287598,2.142338991165161,10000.0,49619.92288994789,53429.75419712067,49619.92288994789,3799.7456452846527,4.565981388092041,0.0 -109800,2.0820305,5.1382008,,,,,,,,,,,,,, -109900,2.116718,2.8921735,,,,,,,,,,,,,, -110000,2.059717,2.8761902,,,,,,,,,,,,,, -110100,2.0974655,2.9872956,,,,,,,,,,,,,, -110200,2.2748208,5.1042166,,,,,,,,,,,,,, -110300,2.3403866,3.007617,,,,,,,,,,,,,, -110400,2.1879945,3.2101192,,,,,,,,,,,,,, -110500,2.0613453,3.6220322,,,,,,,,,,,,,, -110600,1.8430065,4.5954437,,,,,,,,,,,,,, -110669,,,0.7124218344688416,1.2984763383865356,0.6633599996566772,1.5199729204177856,50000.0,0.5420000553131104,2.146238327026367,10000.0,50039.94888544083,53883.64413046837,50039.94888544083,3833.513250827789,4.615039587020874,0.0 -110700,2.0942981,2.8901954,,,,,,,,,,,,,, -110800,1.9677835,2.9707074,,,,,,,,,,,,,, -110900,1.9617957,3.3408668,,,,,,,,,,,,,, -111000,2.3368883,3.1143937,,,,,,,,,,,,,, -111100,2.2327158,2.853796,,,,,,,,,,,,,, -111200,2.1320603,3.032581,,,,,,,,,,,,,, -111300,1.9834427,3.1814508,,,,,,,,,,,,,, -111400,1.9803163,3.311691,,,,,,,,,,,,,, -111500,2.1445,2.8338559,,,,,,,,,,,,,, -111600,2.2089117,2.9284818,,,,,,,,,,,,,, -111601,,,0.7237108945846558,1.2157682180404663,0.6658399701118469,1.476227641105652,50000.0,0.5434000492095947,2.1115095615386963,10000.0,50460.5451362133,54337.49081420898,50460.5451362133,3866.6692838668814,4.662085294723511,0.0 -111700,2.0534537,2.885396,,,,,,,,,,,,,, -111800,2.078238,5.111148,,,,,,,,,,,,,, -111900,2.04858,3.45776,,,,,,,,,,,,,, -112000,1.9507921,4.8761716,,,,,,,,,,,,,, -112100,2.0714016,2.9272969,,,,,,,,,,,,,, -112200,2.2353709,2.7834764,,,,,,,,,,,,,, -112300,2.3735416,2.8742008,,,,,,,,,,,,,, -112400,2.1843772,2.9676309,,,,,,,,,,,,,, -112500,2.1438873,2.88064,,,,,,,,,,,,,, -112529,,,0.73876953125,1.147725224494934,0.6629999876022339,1.471642017364502,50000.0,0.550000011920929,2.085638523101806,10000.0,50880.741792202,54791.81612062454,50880.741792202,3900.708383321762,4.705749273300171,0.0 -112600,1.9362739,3.14155,,,,,,,,,,,,,, -112700,2.1289117,2.9101555,,,,,,,,,,,,,, -112800,2.2288232,2.7736368,,,,,,,,,,,,,, -112900,2.1108358,2.9383545,,,,,,,,,,,,,, -113000,2.400448,2.8407645,,,,,,,,,,,,,, -113100,1.8470429,3.4569457,,,,,,,,,,,,,, -113200,2.075744,2.919633,,,,,,,,,,,,,, -113300,2.007615,4.3533897,,,,,,,,,,,,,, -113400,2.1237948,3.9748158,,,,,,,,,,,,,, -113460,,,0.7142187356948853,1.256130933761597,0.6647799611091614,1.4717066287994385,50000.0,0.5469000339508057,2.0994019508361816,10000.0,51301.066935777664,55246.87820911408,51301.066935777664,3935.346556425095,4.757417678833008,0.0 -113500,2.090855,3.4823442,,,,,,,,,,,,,, -113600,2.2207117,2.8246555,,,,,,,,,,,,,, -113700,2.215615,2.9128375,,,,,,,,,,,,,, -113800,1.9193183,3.28549,,,,,,,,,,,,,, -113900,2.227927,4.0882053,,,,,,,,,,,,,, -114000,2.3414106,3.0685833,,,,,,,,,,,,,, -114100,2.1169739,2.8616142,,,,,,,,,,,,,, -114200,2.3128295,2.8640537,,,,,,,,,,,,,, -114300,2.141791,2.837438,,,,,,,,,,,,,, -114389,,,0.7233593463897705,1.2339673042297363,0.6669600009918213,1.4803937673568726,50000.0,0.5499000549316406,2.0966713428497314,10000.0,51721.110796928406,55701.402671575546,51721.110796928406,3969.720571756363,4.803069829940796,0.0 -114400,2.0335605,3.8636608,,,,,,,,,,,,,, -114500,2.0424104,3.6707458,,,,,,,,,,,,,, -114600,2.2833025,2.9164877,,,,,,,,,,,,,, -114700,2.1757827,2.8398015,,,,,,,,,,,,,, -114800,2.1906781,2.8359075,,,,,,,,,,,,,, -114900,2.3264108,2.8637655,,,,,,,,,,,,,, -115000,2.3241134,3.0021098,,,,,,,,,,,,,, -115100,2.2839127,2.884688,,,,,,,,,,,,,, -115200,2.3315842,2.8628228,,,,,,,,,,,,,, -115300,2.509034,2.9467988,,,,,,,,,,,,,, -115318,,,0.7351366877555847,1.1889073848724363,0.6723600029945374,1.4659477472305298,50000.0,0.549500048160553,2.0902318954467773,10000.0,52141.49275612831,56155.447914361954,52141.49275612831,4003.2828526496887,4.8567054271698,0.0 -115400,2.40021,2.8051398,,,,,,,,,,,,,, -115500,2.0093758,4.458002,,,,,,,,,,,,,, -115600,2.4218333,2.8833323,,,,,,,,,,,,,, -115700,2.2117515,3.2506592,,,,,,,,,,,,,, -115800,2.4743564,2.8229349,,,,,,,,,,,,,, -115900,2.1576257,3.1431377,,,,,,,,,,,,,, -116000,2.1180687,4.407543,,,,,,,,,,,,,, -116100,2.0521262,3.0076473,,,,,,,,,,,,,, -116200,2.2085156,3.8890584,,,,,,,,,,,,,, -116249,,,0.7199413776397705,1.2665797472000122,0.67221999168396,1.4920510053634644,50000.0,0.5455000400543213,2.114979982376098,10000.0,52561.77123785019,56608.08748936653,52561.77123785019,4035.5553166866302,4.898540258407593,0.0 -116300,2.3075516,3.3152163,,,,,,,,,,,,,, -116400,2.0867643,4.6871123,,,,,,,,,,,,,, -116500,2.2848668,2.8811193,,,,,,,,,,,,,, -116600,2.2519388,2.977342,,,,,,,,,,,,,, -116700,1.8930293,3.8062944,,,,,,,,,,,,,, -116800,2.491427,2.8970253,,,,,,,,,,,,,, -116900,2.3448155,4.5439177,,,,,,,,,,,,,, -117000,2.3363633,2.8969285,,,,,,,,,,,,,, -117100,1.8898638,4.0860157,,,,,,,,,,,,,, -117180,,,0.7294726371765137,1.2057080268859863,0.677299976348877,1.432081937789917,50000.0,0.5527000427246094,2.0657169818878174,10000.0,52981.80159282684,57062.94676208496,52981.80159282684,4070.2912969589233,4.944955587387085,0.0 -117200,2.045768,2.910003,,,,,,,,,,,,,, -117300,2.3928196,2.896982,,,,,,,,,,,,,, -117400,2.3244905,2.8081846,,,,,,,,,,,,,, -117500,2.2217882,2.7302103,,,,,,,,,,,,,, -117600,2.2783382,2.7516065,,,,,,,,,,,,,, -117700,2.0024507,3.5587273,,,,,,,,,,,,,, -117800,2.5238426,4.7275534,,,,,,,,,,,,,, -117900,1.891423,4.0786633,,,,,,,,,,,,,, -118000,2.0869398,3.9010472,,,,,,,,,,,,,, -118100,2.0358126,4.89894,,,,,,,,,,,,,, -118109,,,0.7354491949081421,1.1728439331054688,0.6750800013542175,1.43381667137146,50000.0,0.5592000484466553,2.0560216903686523,10000.0,53401.86762642861,57518.09233784676,53401.86762642861,4105.278161764145,4.990721702575684,0.0 -118200,2.0713496,4.8936834,,,,,,,,,,,,,, -118300,2.337056,2.882047,,,,,,,,,,,,,, -118400,1.9467123,4.4519124,,,,,,,,,,,,,, -118500,2.0895634,4.2093043,,,,,,,,,,,,,, -118600,1.9598795,3.561059,,,,,,,,,,,,,, -118700,2.067656,3.766079,,,,,,,,,,,,,, -118800,2.4391563,2.8608563,,,,,,,,,,,,,, -118900,2.1460593,4.0483246,,,,,,,,,,,,,, -119000,2.1247427,4.6730604,,,,,,,,,,,,,, -119039,,,0.7400195002555847,1.14028000831604,0.6732800006866455,1.417976975440979,50000.0,0.5534999966621399,2.041276454925537,10000.0,53822.07191777229,57971.04748129845,53822.07191777229,4137.934015035629,5.03800368309021,0.0 -119100,2.356588,2.8110695,,,,,,,,,,,,,, -119200,2.344081,2.7882254,,,,,,,,,,,,,, -119300,2.559252,4.5660486,,,,,,,,,,,,,, -119400,2.2670953,2.8228538,,,,,,,,,,,,,, -119500,2.235933,2.82058,,,,,,,,,,,,,, -119600,2.2342744,3.1052139,,,,,,,,,,,,,, -119700,2.2441783,3.1261477,,,,,,,,,,,,,, -119800,1.9594263,4.625757,,,,,,,,,,,,,, -119900,2.2565184,4.966783,,,,,,,,,,,,,, -119970,,,0.7264648079872131,1.2079328298568726,0.6744799613952637,1.4414364099502563,50000.0,0.5523000359535217,2.0710713863372803,10000.0,54242.142776966095,58425.14707708359,54242.142776966095,4171.868766546249,5.085484981536865,0.0 -120000,2.4597208,2.9285362,,,,,,,,,,,,,, -120100,2.211676,4.748322,,,,,,,,,,,,,, -120200,2.3201811,2.9759738,,,,,,,,,,,,,, -120300,2.0914335,4.2669797,,,,,,,,,,,,,, -120400,2.5385418,2.7990565,,,,,,,,,,,,,, -120500,2.3323777,2.825079,,,,,,,,,,,,,, -120600,2.184156,4.360178,,,,,,,,,,,,,, -120700,2.3576448,3.416903,,,,,,,,,,,,,, -120800,2.3993664,4.6149316,,,,,,,,,,,,,, -120900,2.3457236,2.8785043,,,,,,,,,,,,,, -120901,,,0.7350195050239563,1.1525589227676392,0.6802799701690674,1.4023057222366333,50000.0,0.5651000142097473,2.001354694366455,10000.0,54662.12339830399,58878.76770377159,54662.12339830399,4205.415090322495,5.132269144058228,0.0 -121000,2.5585675,2.7613468,,,,,,,,,,,,,, -121100,2.384286,3.2167199,,,,,,,,,,,,,, -121200,2.3417068,3.6958833,,,,,,,,,,,,,, -121300,2.117114,2.9579773,,,,,,,,,,,,,, -121400,2.2835906,3.053638,,,,,,,,,,,,,, -121500,2.3462012,3.1160562,,,,,,,,,,,,,, -121600,2.046124,4.297537,,,,,,,,,,,,,, -121700,2.0493023,3.8959222,,,,,,,,,,,,,, -121800,2.4392529,2.9204438,,,,,,,,,,,,,, -121830,,,0.7547070384025574,1.0838943719863892,0.683459997177124,1.395894169807434,50000.0,0.5628000497817993,1.9966557025909424,10000.0,55082.16000986099,59331.42533326149,55082.16000986099,4237.944813966751,5.175921440124512,0.0 -121900,2.1408033,2.9981766,,,,,,,,,,,,,, -122000,2.3731048,2.8778067,,,,,,,,,,,,,, -122100,2.362212,4.4767823,,,,,,,,,,,,,, -122200,2.499405,2.7381792,,,,,,,,,,,,,, -122300,2.289258,2.8128207,,,,,,,,,,,,,, -122400,2.498929,2.7670424,,,,,,,,,,,,,, -122500,2.180009,4.8685713,,,,,,,,,,,,,, -122600,2.203605,3.5269806,,,,,,,,,,,,,, -122700,2.3715272,3.5996366,,,,,,,,,,,,,, -122757,,,0.7364257574081421,1.1707990169525146,0.6805199980735779,1.4080075025558472,50000.0,0.5541000366210938,2.036602735519409,10000.0,55502.36118531227,59785.28081989288,55502.36118531227,4271.503207921982,5.226181507110596,0.0 -122800,2.4744039,4.884121,,,,,,,,,,,,,, -122900,2.3778844,4.264833,,,,,,,,,,,,,, -123000,2.2059028,4.844739,,,,,,,,,,,,,, -123100,2.3947992,2.829848,,,,,,,,,,,,,, -123200,2.1890073,3.0259817,,,,,,,,,,,,,, -123300,2.5141823,2.8170974,,,,,,,,,,,,,, -123400,2.3808572,3.9966514,,,,,,,,,,,,,, -123500,2.4573147,3.63172,,,,,,,,,,,,,, -123600,2.5265,2.70926,,,,,,,,,,,,,, -123685,,,0.7400780916213989,1.1385509967803955,0.6856399774551392,1.3829452991485596,50000.0,0.5642000436782837,2.0096795558929443,10000.0,55922.51798272133,60240.23173499107,55922.51798272133,4306.207540750504,5.269519805908203,0.0 -123700,2.4856265,4.744954,,,,,,,,,,,,,, -123800,2.3629396,2.9132843,,,,,,,,,,,,,, -123900,2.5989475,3.3080869,,,,,,,,,,,,,, -124000,2.41726,2.7544765,,,,,,,,,,,,,, -124100,2.5131528,4.922818,,,,,,,,,,,,,, -124200,2.5791671,2.711521,,,,,,,,,,,,,, -124300,2.3645263,4.932871,,,,,,,,,,,,,, -124400,2.3462481,2.9171803,,,,,,,,,,,,,, -124500,2.1595361,3.5579972,,,,,,,,,,,,,, -124600,2.4244905,2.7928379,,,,,,,,,,,,,, -124614,,,0.7510937452316284,1.11006760597229,0.6840999722480774,1.394182205200195,50000.0,0.5678000450134277,2.01263165473938,10000.0,56342.67802453041,60692.419605493546,56342.67802453041,4338.141446352005,5.316270351409912,0.0 -124700,2.2247307,3.7352958,,,,,,,,,,,,,, -124800,2.2899528,4.0662127,,,,,,,,,,,,,, -124900,2.681091,4.894538,,,,,,,,,,,,,, -125000,2.304963,3.8548257,,,,,,,,,,,,,, -125100,2.6256986,4.4570026,,,,,,,,,,,,,, -125200,2.1807058,3.386252,,,,,,,,,,,,,, -125300,2.2869859,2.7202723,,,,,,,,,,,,,, -125400,2.331294,3.01625,,,,,,,,,,,,,, -125500,2.4631524,3.6192954,,,,,,,,,,,,,, -125544,,,0.7424218654632568,1.1373136043548584,0.6873799562454224,1.3734878301620483,50000.0,0.566100001335144,1.991938591003418,10000.0,56762.89463853836,61146.11239314079,56762.89463853836,4371.525834798813,5.361638784408569,0.0 -125600,2.668445,2.7131212,,,,,,,,,,,,,, -125700,2.5178685,2.647658,,,,,,,,,,,,,, -125800,2.2238605,3.643332,,,,,,,,,,,,,, -125900,2.4943655,2.8207798,,,,,,,,,,,,,, -126000,2.5210009,2.7265842,,,,,,,,,,,,,, -126100,2.474468,2.5783293,,,,,,,,,,,,,, -126200,2.2913985,3.0854533,,,,,,,,,,,,,, -126300,2.7294154,2.7107096,,,,,,,,,,,,,, -126400,2.588126,2.7143219,,,,,,,,,,,,,, -126472,,,0.74867182970047,1.1079342365264893,0.6914599537849426,1.3647878170013428,50000.0,0.5692000389099121,1.9838557243347168,10000.0,57183.18724656105,61599.6272623539,57183.18724656105,4404.65695309639,5.406764268875122,0.0 -126500,2.4820602,2.6689594,,,,,,,,,,,,,, -126600,2.809218,2.8271284,,,,,,,,,,,,,, -126700,2.3326685,3.198869,,,,,,,,,,,,,, -126800,2.5780196,2.737576,,,,,,,,,,,,,, -126900,2.5738688,2.6711802,,,,,,,,,,,,,, -127000,2.6041148,2.9187965,,,,,,,,,,,,,, -127100,2.4286203,2.8569107,,,,,,,,,,,,,, -127200,2.2627654,3.9859202,,,,,,,,,,,,,, -127300,2.213864,3.2639782,,,,,,,,,,,,,, -127400,2.4102921,2.7073512,,,,,,,,,,,,,, -127402,,,0.7528125047683716,1.0917577743530271,0.6893599629402161,1.3566315174102783,50000.0,0.5679000020027161,1.9800525903701784,10000.0,57603.10753917694,62053.47536659241,57603.10753917694,4438.493448495865,5.450902700424194,0.0 -127500,2.549312,2.7756288,,,,,,,,,,,,,, -127600,2.7097478,2.8573728,,,,,,,,,,,,,, -127700,2.561259,2.6637468,,,,,,,,,,,,,, -127800,2.6110172,2.9247258,,,,,,,,,,,,,, -127900,2.8805416,4.4811544,,,,,,,,,,,,,, -128000,3.190881,2.7038198,,,,,,,,,,,,,, -128100,2.9479976,2.685547,,,,,,,,,,,,,, -128200,2.5600092,3.0407534,,,,,,,,,,,,,, -128300,2.4536595,4.0192437,,,,,,,,,,,,,, -128301,,,0.7537695169448853,1.098175287246704,0.6912999749183655,1.3770054578781128,50000.0,0.5696000456809998,1.9985442161560056,10000.0,58023.76083302498,62508.17419576645,58023.76083302498,4472.449482917786,5.49484133720398,0.0 -128400,2.6855025,4.800378,,,,,,,,,,,,,, -128500,2.3358135,3.6793675,,,,,,,,,,,,,, -128600,2.54042,4.707513,,,,,,,,,,,,,, -128700,2.426856,2.853636,,,,,,,,,,,,,, -128800,2.6117256,2.8693945,,,,,,,,,,,,,, -128900,2.717253,2.7269273,,,,,,,,,,,,,, -129000,2.3987358,3.0314813,,,,,,,,,,,,,, -129100,2.2814379,3.064272,,,,,,,,,,,,,, -129200,2.630885,2.6752691,,,,,,,,,,,,,, -129231,,,0.7473828196525574,1.1209533214569092,0.6952799558639526,1.3586920499801636,50000.0,0.570900022983551,1.974564552307129,10000.0,58443.69563674927,62962.10774159432,58443.69563674927,4506.356766462326,5.540136098861694,0.0 -129300,2.7863562,2.7072008,,,,,,,,,,,,,, -129400,2.9626832,4.7755995,,,,,,,,,,,,,, -129500,2.4693575,4.5268254,,,,,,,,,,,,,, -129600,2.5817924,2.7619176,,,,,,,,,,,,,, -129700,2.6009164,2.7563982,,,,,,,,,,,,,, -129800,2.3834038,3.5215807,,,,,,,,,,,,,, -129900,2.7075317,2.754598,,,,,,,,,,,,,, -130000,2.615642,2.890503,,,,,,,,,,,,,, -130100,2.8454218,2.6613579,,,,,,,,,,,,,, -130161,,,0.7574023008346558,1.075601577758789,0.6932399868965149,1.3420850038528442,50000.0,0.5753000378608704,1.9617393016815183,10000.0,58863.781465768814,63414.34195232391,58863.781465768814,4538.41277050972,5.585582971572876,0.0 -130200,2.8797739,2.768717,,,,,,,,,,,,,, -130300,2.2950265,3.6147375,,,,,,,,,,,,,, -130400,2.5676322,3.3387449,,,,,,,,,,,,,, -130500,2.636903,3.024506,,,,,,,,,,,,,, -130600,2.6907113,2.8547864,,,,,,,,,,,,,, -130700,2.7145088,4.1503463,,,,,,,,,,,,,, -130800,2.443864,3.2429476,,,,,,,,,,,,,, -130900,2.97916,2.853962,,,,,,,,,,,,,, -131000,2.7583647,2.763113,,,,,,,,,,,,,, -131088,,,0.7674218416213989,1.0314123630523682,0.6955599784851074,1.343638300895691,50000.0,0.5745000243186951,1.9722157716751096,10000.0,59283.73093700409,63868.40799832344,59283.73093700409,4572.432414054871,5.636627674102783,0.0 -131100,2.7213135,2.7550516,,,,,,,,,,,,,, -131200,2.8511662,4.880421,,,,,,,,,,,,,, -131300,2.5890448,3.3157997,,,,,,,,,,,,,, -131400,3.9297926,4.1612306,,,,,,,,,,,,,, -131500,2.455845,4.567159,,,,,,,,,,,,,, -131600,2.8122625,2.7028055,,,,,,,,,,,,,, -131700,2.682526,2.7522154,,,,,,,,,,,,,, -131800,3.161215,4.738963,,,,,,,,,,,,,, -131900,2.5457506,4.6825905,,,,,,,,,,,,,, -132000,2.698295,3.6676495,,,,,,,,,,,,,, -132017,,,0.7551367282867432,1.0511291027069092,0.6979999542236328,1.3008636236190796,50000.0,0.5759000182151794,1.91862154006958,10000.0,59704.04973602295,64322.28016543389,59704.04973602295,4605.888027429581,5.6878721714019775,0.0 -132100,2.5439684,4.0459437,,,,,,,,,,,,,, -132200,2.615637,2.7962565,,,,,,,,,,,,,, -132300,2.6895716,2.874218,,,,,,,,,,,,,, -132400,2.4497523,3.379957,,,,,,,,,,,,,, -132500,2.836558,2.750169,,,,,,,,,,,,,, -132600,2.675052,2.8551054,,,,,,,,,,,,,, -132700,2.6515176,3.3178964,,,,,,,,,,,,,, -132800,2.606201,2.926127,,,,,,,,,,,,,, -132900,2.5774424,2.5213096,,,,,,,,,,,,,, -132946,,,0.7594921588897705,1.0723209381103516,0.6966399550437927,1.3334400653839111,50000.0,0.5769000053405762,1.952175498008728,10000.0,60124.05582642555,64776.534338235855,60124.05582642555,4640.043148756027,5.734477758407593,0.0 -133000,2.415887,2.7834647,,,,,,,,,,,,,, -133100,2.5703905,3.6398911,,,,,,,,,,,,,, -133200,2.418141,3.162383,,,,,,,,,,,,,, -133300,2.7370443,3.3438628,,,,,,,,,,,,,, -133400,2.7208743,2.693668,,,,,,,,,,,,,, -133500,2.7523706,2.5924687,,,,,,,,,,,,,, -133600,2.79835,2.899544,,,,,,,,,,,,,, -133700,2.9280872,2.8151007,,,,,,,,,,,,,, -133800,2.7060916,2.6482868,,,,,,,,,,,,,, -133876,,,0.7744726538658142,0.9877074360847472,0.7017599940299988,1.2961947917938232,50000.0,0.58160001039505,1.904213547706604,10000.0,60544.09117388725,65230.13675904274,60544.09117388725,4673.519508600235,5.778345823287964,0.0 -133900,2.9062688,2.616961,,,,,,,,,,,,,, -134000,2.836849,3.482313,,,,,,,,,,,,,, -134100,2.441979,3.193414,,,,,,,,,,,,,, -134200,2.9249585,2.6087537,,,,,,,,,,,,,, -134300,2.5267541,3.8799388,,,,,,,,,,,,,, -134400,2.7647753,4.307081,,,,,,,,,,,,,, -134500,2.6946099,4.1610775,,,,,,,,,,,,,, -134600,3.0095663,4.5241127,,,,,,,,,,,,,, -134700,3.064211,4.6812925,,,,,,,,,,,,,, -134800,2.8463185,2.6117568,,,,,,,,,,,,,, -134807,,,0.7617382407188416,1.0603655576705933,0.701200008392334,1.320079684257507,50000.0,0.5774000287055969,1.9307143688201904,10000.0,60964.34392094612,65685.15578842163,60964.34392094612,4708.190718412399,5.826894760131836,0.0 -134900,2.6516864,2.9615774,,,,,,,,,,,,,, -135000,2.738616,2.531734,,,,,,,,,,,,,, -135100,2.816203,2.6994836,,,,,,,,,,,,,, -135200,2.766275,2.9891262,,,,,,,,,,,,,, -135300,3.3391182,2.618739,,,,,,,,,,,,,, -135400,2.726064,3.9356773,,,,,,,,,,,,,, -135500,2.8011928,2.9625292,,,,,,,,,,,,,, -135600,3.0667202,2.6530378,,,,,,,,,,,,,, -135700,3.0008783,3.2953417,,,,,,,,,,,,,, -135735,,,0.768847644329071,1.01193368434906,0.7056999802589417,1.2756710052490234,50000.0,0.5831000208854675,1.8825321197509768,10000.0,61384.33895516396,66136.31755590439,61384.33895516396,4739.263605117798,5.874547243118286,0.0 -135800,2.8363464,2.7036934,,,,,,,,,,,,,, -135900,2.739825,2.6006029,,,,,,,,,,,,,, -136000,2.8506908,2.8897157,,,,,,,,,,,,,, -136100,2.630325,3.0857983,,,,,,,,,,,,,, -136200,2.8432946,4.1750226,,,,,,,,,,,,,, -136300,3.2714474,2.5622,,,,,,,,,,,,,, -136400,3.196927,4.332265,,,,,,,,,,,,,, -136500,2.6260633,2.890203,,,,,,,,,,,,,, -136600,2.866674,3.07952,,,,,,,,,,,,,, -136662,,,0.7686523199081421,1.024351954460144,0.7022199630737305,1.317602038383484,50000.0,0.5806000232696533,1.923264741897583,10000.0,61804.52350068092,66590.26353669167,61804.52350068092,4772.92241859436,5.9302287101745605,0.0 -136700,2.7869859,2.6461294,,,,,,,,,,,,,, -136800,3.0728338,2.6363692,,,,,,,,,,,,,, -136900,2.7255366,4.226264,,,,,,,,,,,,,, -137000,3.062656,2.6417036,,,,,,,,,,,,,, -137100,2.8813531,2.5547185,,,,,,,,,,,,,, -137200,2.94692,3.1302195,,,,,,,,,,,,,, -137300,2.8024054,3.8547745,,,,,,,,,,,,,, -137400,2.9778364,2.827491,,,,,,,,,,,,,, -137500,2.901506,2.588532,,,,,,,,,,,,,, -137592,,,0.786425769329071,0.9460193514823914,0.7095999717712402,1.2811696529388428,50000.0,0.5868000388145447,1.8915445804595947,10000.0,62224.45097088814,67044.65680789948,62224.45097088814,4807.293427467346,5.978266716003418,0.0 -137600,2.841989,3.0727258,,,,,,,,,,,,,, -137700,3.3134255,4.607121,,,,,,,,,,,,,, -137800,3.2792308,2.6531563,,,,,,,,,,,,,, -137900,3.2212474,4.746707,,,,,,,,,,,,,, -138000,2.570537,3.4318476,,,,,,,,,,,,,, -138100,3.4136162,2.8418012,,,,,,,,,,,,,, -138200,3.2669942,4.582999,,,,,,,,,,,,,, -138300,2.6035647,3.5631526,,,,,,,,,,,,,, -138400,3.0571506,3.2763908,,,,,,,,,,,,,, -138500,2.8831613,2.8512588,,,,,,,,,,,,,, -138521,,,0.76917964220047,1.0142436027526855,0.7066400051116943,1.2850269079208374,50000.0,0.584600031375885,1.8975168466567995,10000.0,62644.38854908943,67499.70457077026,62644.38854908943,4842.306634902954,6.0289857387542725,0.0 -138600,2.74208,2.9616358,,,,,,,,,,,,,, -138700,2.7895882,3.6403422,,,,,,,,,,,,,, -138800,2.9076817,2.8459694,,,,,,,,,,,,,, -138900,3.0424535,4.1335196,,,,,,,,,,,,,, -139000,2.9544241,2.6086416,,,,,,,,,,,,,, -139100,3.3452926,3.1742146,,,,,,,,,,,,,, -139200,3.0328937,2.7207294,,,,,,,,,,,,,, -139300,2.9254932,2.8060207,,,,,,,,,,,,,, -139400,3.0595064,2.6677504,,,,,,,,,,,,,, -139448,,,0.7764062285423279,0.9926196932792664,0.7118200063705444,1.268027663230896,50000.0,0.5867000222206116,1.8783414363861084,10000.0,63064.298337221146,67952.42379832268,63064.298337221146,4875.023307323456,6.075516939163208,0.0 -139500,2.9846056,4.3830075,,,,,,,,,,,,,, -139600,2.9336941,2.6959336,,,,,,,,,,,,,, -139700,3.0392745,2.588644,,,,,,,,,,,,,, -139800,2.6638107,3.224855,,,,,,,,,,,,,, -139900,3.11864,2.643219,,,,,,,,,,,,,, -140000,3.0701697,2.5521781,,,,,,,,,,,,,, -140100,3.0264122,2.5731897,,,,,,,,,,,,,, -140200,3.0177288,2.6955194,,,,,,,,,,,,,, -140300,3.4559262,2.5439484,,,,,,,,,,,,,, -140376,,,0.7857421636581421,0.9634817838668824,0.7120800018310547,1.285139560699463,50000.0,0.5902000069618225,1.8925907611846924,10000.0,63484.35860395432,68407.07604622841,63484.35860395432,4909.52064538002,6.123907089233398,0.0 -140400,3.2418478,2.7025084,,,,,,,,,,,,,, -140500,3.110987,2.6328905,,,,,,,,,,,,,, -140600,3.0044327,3.135818,,,,,,,,,,,,,, -140700,3.9751294,2.5808902,,,,,,,,,,,,,, -140800,3.094869,2.7415867,,,,,,,,,,,,,, -140900,3.0900538,3.4149728,,,,,,,,,,,,,, -141000,2.9810584,2.5874162,,,,,,,,,,,,,, -141100,3.1129851,2.454346,,,,,,,,,,,,,, -141200,3.069433,2.9965148,,,,,,,,,,,,,, -141300,3.133479,2.656536,,,,,,,,,,,,,, -141307,,,0.7754882574081421,0.99040687084198,0.7101199626922607,1.26579749584198,50000.0,0.5962000489234924,1.866379141807556,10000.0,63904.452016592026,68861.97762393951,63904.452016592026,4944.235973596573,6.169575452804565,0.0 -141400,2.8618937,3.6171703,,,,,,,,,,,,,, -141500,2.9971833,4.048112,,,,,,,,,,,,,, -141600,2.8704288,2.5930393,,,,,,,,,,,,,, -141700,3.036503,2.6680532,,,,,,,,,,,,,, -141800,3.0355418,3.1955986,,,,,,,,,,,,,, -141900,3.444148,2.700994,,,,,,,,,,,,,, -142000,2.9424903,3.0704718,,,,,,,,,,,,,, -142100,3.356089,2.689472,,,,,,,,,,,,,, -142200,3.047834,2.834891,,,,,,,,,,,,,, -142234,,,0.7803906202316284,0.9680672883987428,0.7150200009346008,1.2434526681900024,50000.0,0.5933000445365906,1.8614035844802856,10000.0,64324.7577764988,69316.77753448486,64324.7577764988,4978.624108314514,6.229996204376221,0.0 -142300,3.2435355,2.7252307,,,,,,,,,,,,,, -142400,3.4783528,4.7031336,,,,,,,,,,,,,, -142500,3.0304291,3.3470478,,,,,,,,,,,,,, -142600,2.991813,3.4286003,,,,,,,,,,,,,, -142700,3.1685643,2.5274284,,,,,,,,,,,,,, -142800,3.6556616,2.5973456,,,,,,,,,,,,,, -142900,3.246926,2.6049373,,,,,,,,,,,,,, -143000,3.1020017,3.6914463,,,,,,,,,,,,,, -143100,3.4305491,4.6767883,,,,,,,,,,,,,, -143162,,,0.7883398532867432,0.9409700632095336,0.7188999652862549,1.2394484281539917,50000.0,0.6005000472068787,1.8527811765670776,10000.0,64745.00335741043,69770.92920541763,64745.00335741043,5012.432188272476,6.281265020370483,0.0 -143200,2.9367797,3.2759538,,,,,,,,,,,,,, -143300,3.5546737,4.1944757,,,,,,,,,,,,,, -143400,3.1949348,2.6173005,,,,,,,,,,,,,, -143500,3.2310126,3.0213761,,,,,,,,,,,,,, -143600,3.0453582,2.991087,,,,,,,,,,,,,, -143700,2.844242,2.9083667,,,,,,,,,,,,,, -143800,3.557338,4.601413,,,,,,,,,,,,,, -143900,3.6326857,4.0085497,,,,,,,,,,,,,, -144000,3.3062916,2.530199,,,,,,,,,,,,,, -144090,,,0.7860937118530273,0.9350039958953856,0.7210599780082703,1.209059238433838,50000.0,0.6010000109672546,1.8121860027313232,10000.0,65165.10458111763,70224.88837790489,65165.10458111763,5046.197269678116,6.327600479125977,0.0 -144100,3.305676,4.6263256,,,,,,,,,,,,,, -144200,3.236269,3.4557052,,,,,,,,,,,,,, -144300,3.4474149,2.5536277,,,,,,,,,,,,,, -144400,3.647659,2.8413537,,,,,,,,,,,,,, -144500,3.1797276,2.5866857,,,,,,,,,,,,,, -144600,3.1891117,2.481111,,,,,,,,,,,,,, -144700,3.2516832,2.5967896,,,,,,,,,,,,,, -144800,3.2056644,4.2737565,,,,,,,,,,,,,, -144900,3.2352703,2.5139947,,,,,,,,,,,,,, -145000,4.068598,4.2515283,,,,,,,,,,,,,, -145020,,,0.7858593463897705,0.9343098402023317,0.7197799682617188,1.222838282585144,50000.0,0.6005000472068787,1.830407738685608,10000.0,65585.06423592567,70680.72092986107,65585.06423592567,5081.971020698547,6.380660057067871,0.0 -145100,3.2088304,3.9809175,,,,,,,,,,,,,, -145200,3.8283408,3.428755,,,,,,,,,,,,,, -145300,3.4799647,4.513673,,,,,,,,,,,,,, -145400,3.1423895,2.8252585,,,,,,,,,,,,,, -145500,3.1174948,3.512151,,,,,,,,,,,,,, -145600,3.3032868,2.4394798,,,,,,,,,,,,,, -145700,3.4351938,3.52869,,,,,,,,,,,,,, -145800,3.0731685,2.9682488,,,,,,,,,,,,,, -145900,3.5786679,2.5027459,,,,,,,,,,,,,, -145951,,,0.7895702719688416,0.9390873312950134,0.7216399908065796,1.2361648082733154,50000.0,0.6039000153541565,1.830265760421753,10000.0,66005.23581600189,71135.50416016579,66005.23581600189,5116.485925197601,6.430635690689087,0.0 -146000,3.57451,4.1248465,,,,,,,,,,,,,, -146100,3.4365497,2.7728724,,,,,,,,,,,,,, -146200,3.276216,2.508834,,,,,,,,,,,,,, -146300,3.2572556,3.5944562,,,,,,,,,,,,,, -146400,3.3766904,2.8450482,,,,,,,,,,,,,, -146500,3.3215616,4.161886,,,,,,,,,,,,,, -146600,3.684932,3.791683,,,,,,,,,,,,,, -146700,3.4915905,2.689105,,,,,,,,,,,,,, -146800,3.626745,4.4309154,,,,,,,,,,,,,, -146880,,,0.8043749928474426,0.8740522861480713,0.7262799739837646,1.2046759128570557,50000.0,0.6035000085830688,1.808951020240784,10000.0,66425.18833255768,71590.81511712074,66425.18833255768,5151.747972011566,6.480493545532227,0.0 -146900,3.5589392,3.7515514,,,,,,,,,,,,,, -147000,3.6110804,2.567145,,,,,,,,,,,,,, -147100,3.2898495,2.5185633,,,,,,,,,,,,,, -147200,3.3116775,2.6315045,,,,,,,,,,,,,, -147300,3.654319,2.5985327,,,,,,,,,,,,,, -147400,3.280195,2.726481,,,,,,,,,,,,,, -147500,3.2386773,2.5595794,,,,,,,,,,,,,, -147600,3.4330049,2.7196198,,,,,,,,,,,,,, -147700,3.4993038,2.5178814,,,,,,,,,,,,,, -147800,2.903884,3.1229348,,,,,,,,,,,,,, -147810,,,0.7898827791213989,0.9271060228347778,0.7245599627494812,1.2024517059326172,50000.0,0.6008000373840332,1.8167974948883057,10000.0,66845.46230769157,72044.69520163536,66845.46230769157,5185.241056442261,6.546149730682373,0.0 -147900,3.142468,2.5421987,,,,,,,,,,,,,, -148000,3.5681841,3.4444191,,,,,,,,,,,,,, -148100,3.639555,2.6164172,,,,,,,,,,,,,, -148200,3.3961055,2.4741285,,,,,,,,,,,,,, -148300,3.4716194,2.636127,,,,,,,,,,,,,, -148400,3.7167199,2.6132195,,,,,,,,,,,,,, -148500,3.4250429,2.5874596,,,,,,,,,,,,,, -148600,3.6783504,3.859824,,,,,,,,,,,,,, -148700,3.668569,2.5188663,,,,,,,,,,,,,, -148741,,,0.79749995470047,0.8806569576263428,0.7253400087356567,1.184898853302002,50000.0,0.6121000051498413,1.7757694721221924,10000.0,67265.56567597389,72499.34892606735,67265.56567597389,5219.695116758347,6.595580577850342,0.0 -148800,3.6700952,2.5480485,,,,,,,,,,,,,, -148900,3.34259,2.8999836,,,,,,,,,,,,,, -149000,3.8468845,4.5039682,,,,,,,,,,,,,, -149100,3.7718284,2.5132706,,,,,,,,,,,,,, -149200,3.5181003,2.4647493,,,,,,,,,,,,,, -149300,3.8709426,2.504578,,,,,,,,,,,,,, -149400,3.3462183,2.421762,,,,,,,,,,,,,, -149500,3.3590162,2.5798125,,,,,,,,,,,,,, -149600,3.1507792,2.5213847,,,,,,,,,,,,,, -149669,,,0.8055273294448853,0.8653415441513062,0.7287799715995789,1.1887600421905518,50000.0,0.6064000129699707,1.7887310981750488,10000.0,67685.5972161293,72953.5630671978,67685.5972161293,5253.777672767639,6.64896035194397,0.0 -149700,3.7918432,2.6483154,,,,,,,,,,,,,, -149800,3.840447,4.5981565,,,,,,,,,,,,,, -149900,3.3557181,4.004049,,,,,,,,,,,,,, -150000,3.692756,3.3616433,,,,,,,,,,,,,, -150100,3.6430469,2.5636358,,,,,,,,,,,,,, -150200,3.6381378,3.9653654,,,,,,,,,,,,,, -150300,3.5596716,3.102366,,,,,,,,,,,,,, -150400,3.9677808,2.7199397,,,,,,,,,,,,,, -150500,3.542152,2.5467114,,,,,,,,,,,,,, -150600,,,0.7973241806030273,0.8826409578323364,0.7310000061988831,1.1645584106445312,50000.0,0.6121000051498413,1.759349703788757,10000.0,68105.81286787987,73407.98974180222,68105.81286787987,5287.894725084305,6.696053504943848,0.0 -150600,3.5463476,2.7988482,,,,,,,,,,,,,, -150700,3.5025141,2.4866261,,,,,,,,,,,,,, -150800,3.9631777,2.5680943,,,,,,,,,,,,,, -150900,3.8347814,2.5993783,,,,,,,,,,,,,, -151000,3.701652,2.8019176,,,,,,,,,,,,,, -151100,3.9026256,2.7269268,,,,,,,,,,,,,, -151200,4.5108414,4.4513817,,,,,,,,,,,,,, -151300,3.8919568,3.7060137,,,,,,,,,,,,,, -151400,4.1021585,4.39092,,,,,,,,,,,,,, -151500,4.5497007,4.457158,,,,,,,,,,,,,, -151531,,,0.8052929639816284,0.8601254224777222,0.733959972858429,1.1651591062545776,50000.0,0.6106000542640686,1.780478596687317,10000.0,68526.03539800644,73862.64343810081,68526.03539800644,5322.22726726532,6.748147487640381,0.0 -151600,3.8341634,2.9795065,,,,,,,,,,,,,, -151700,3.8533556,2.5887547,,,,,,,,,,,,,, -151800,4.017664,2.4874465,,,,,,,,,,,,,, -151900,3.4803832,2.6299942,,,,,,,,,,,,,, -152000,4.1110997,2.5535533,,,,,,,,,,,,,, -152100,3.7048807,2.3843684,,,,,,,,,,,,,, -152200,3.852421,2.74442,,,,,,,,,,,,,, -152300,3.8821394,2.394372,,,,,,,,,,,,,, -152400,3.971435,3.697247,,,,,,,,,,,,,, -152460,,,0.8087890148162842,0.8668062686920166,0.7332599759101868,1.18711519241333,50000.0,0.610200047492981,1.789298176765442,10000.0,68946.24187660217,74317.84789443016,68946.24187660217,5357.129096031189,6.798153877258301,0.0 -152500,3.8821404,2.6678896,,,,,,,,,,,,,, -152600,3.7204766,2.4172056,,,,,,,,,,,,,, -152700,4.216308,4.1791186,,,,,,,,,,,,,, -152800,3.9039605,4.051613,,,,,,,,,,,,,, -152900,4.187135,2.4876685,,,,,,,,,,,,,, -153000,3.7732167,2.4205241,,,,,,,,,,,,,, -153100,3.9559996,2.4006674,,,,,,,,,,,,,, -153200,4.4291515,2.4339232,,,,,,,,,,,,,, -153300,4.41693,4.0918455,,,,,,,,,,,,,, -153391,,,0.8062109351158142,0.8395585417747498,0.7360999584197998,1.139323353767395,50000.0,0.6176000237464905,1.7518985271453855,10000.0,69366.1802983284,74772.3870472908,69366.1802983284,5391.624893188477,6.85608983039856,0.0 -153400,3.8344288,2.5603504,,,,,,,,,,,,,, -153500,3.997486,2.9422495,,,,,,,,,,,,,, -153600,3.7987864,2.4315393,,,,,,,,,,,,,, -153700,4.031997,2.4405403,,,,,,,,,,,,,, -153800,3.6666744,2.8048365,,,,,,,,,,,,,, -153900,4.2426195,2.4550624,,,,,,,,,,,,,, -154000,3.7825963,2.4483826,,,,,,,,,,,,,, -154100,3.8313365,3.4311097,,,,,,,,,,,,,, -154200,3.850097,2.3369803,,,,,,,,,,,,,, -154300,4.3530684,2.532544,,,,,,,,,,,,,, -154319,,,0.8050194978713989,0.8566577434539795,0.7374399900436401,1.1484848260879517,50000.0,0.6152999997138977,1.7600574493408203,10000.0,69786.34556651115,75227.58785867691,69786.34556651115,5426.559454441071,6.911020278930664,0.0 -154400,4.3533516,4.035383,,,,,,,,,,,,,, -154500,3.839353,3.8890305,,,,,,,,,,,,,, -154600,3.8580902,2.978634,,,,,,,,,,,,,, -154700,4.1563864,3.7449598,,,,,,,,,,,,,, -154800,3.9643292,2.881771,,,,,,,,,,,,,, -154900,4.3861537,4.469999,,,,,,,,,,,,,, -155000,4.0226874,2.3676925,,,,,,,,,,,,,, -155100,3.7356644,3.253679,,,,,,,,,,,,,, -155200,3.8564718,2.4813342,,,,,,,,,,,,,, -155245,,,0.8122656345367432,0.8312785625457764,0.7390999794006348,1.1496641635894775,50000.0,0.6205000281333923,1.7459418773651123,10000.0,70206.66819000244,75682.72457933426,70206.66819000244,5461.275590896606,6.962958574295044,0.0 -155300,4.6808424,4.4550734,,,,,,,,,,,,,, -155400,4.550512,3.7370033,,,,,,,,,,,,,, -155500,4.0758157,2.4240527,,,,,,,,,,,,,, -155600,3.8789053,2.4687343,,,,,,,,,,,,,, -155700,4.0607777,2.5086265,,,,,,,,,,,,,, -155800,4.1479106,3.8085399,,,,,,,,,,,,,, -155900,3.7901866,2.809668,,,,,,,,,,,,,, -156000,4.4697266,4.091635,,,,,,,,,,,,,, -156100,3.788255,2.5864823,,,,,,,,,,,,,, -156170,,,0.8208202719688416,0.808429479598999,0.7415599822998047,1.147451639175415,50000.0,0.6170000433921814,1.7455700635910034,10000.0,70626.57415676117,76136.81805038452,70626.57415676117,5495.359393119812,7.020386457443237,0.0 -156200,4.102911,2.452802,,,,,,,,,,,,,, -156300,4.136429,2.993068,,,,,,,,,,,,,, -156400,3.5977411,2.3141985,,,,,,,,,,,,,, -156500,4.1576686,3.905612,,,,,,,,,,,,,, -156600,4.437581,3.6022067,,,,,,,,,,,,,, -156700,5.229682,4.412339,,,,,,,,,,,,,, -156800,3.9033818,2.3714495,,,,,,,,,,,,,, -156900,3.984405,3.3734221,,,,,,,,,,,,,, -157000,4.7041903,4.380461,,,,,,,,,,,,,, -157100,,,0.8145312070846558,0.8163199424743652,0.7399199604988098,1.1327764987945557,50000.0,0.6218000054359436,1.7313672304153442,10000.0,71046.6845676899,76591.11240267754,71046.6845676899,5529.437774181366,7.078781843185425,0.0 -157100,4.222475,2.3870273,,,,,,,,,,,,,, -157200,4.126521,2.3922164,,,,,,,,,,,,,, -157300,3.9978306,3.6877906,,,,,,,,,,,,,, -157400,4.0265117,3.1107945,,,,,,,,,,,,,, -157500,3.9344606,3.23603,,,,,,,,,,,,,, -157600,4.1266994,2.7360268,,,,,,,,,,,,,, -157700,4.2192707,3.6123786,,,,,,,,,,,,,, -157800,4.1553626,2.3719587,,,,,,,,,,,,,, -157900,5.1744905,4.366755,,,,,,,,,,,,,, -158000,4.134871,3.3186774,,,,,,,,,,,,,, -158029,,,0.8168163895606995,0.8214466571807861,0.7429199814796448,1.1342238187789917,50000.0,0.6199000477790833,1.7298482656478882,10000.0,71466.77744364738,77045.694283247,71466.77744364738,5563.826683521271,7.130944013595581,0.0 -158100,4.6684175,2.5030286,,,,,,,,,,,,,, -158200,5.0042725,4.4165287,,,,,,,,,,,,,, -158300,4.486215,2.4040642,,,,,,,,,,,,,, -158400,3.974846,3.6112847,,,,,,,,,,,,,, -158500,5.250808,4.3347783,,,,,,,,,,,,,, -158600,4.547585,2.6251016,,,,,,,,,,,,,, -158700,3.8457298,2.4856658,,,,,,,,,,,,,, -158800,3.9004648,2.38579,,,,,,,,,,,,,, -158900,4.455666,2.4012444,,,,,,,,,,,,,, -158959,,,0.823535144329071,0.7854181528091431,0.7455799579620361,1.114540696144104,50000.0,0.6255000233650208,1.7001981735229492,10000.0,71886.9072842598,77500.05093717575,71886.9072842598,5597.954945802689,7.182467699050903,0.0 -159000,3.9594119,2.2736092,,,,,,,,,,,,,, -159100,4.183597,3.8059168,,,,,,,,,,,,,, -159200,4.258071,4.1456027,,,,,,,,,,,,,, -159300,4.5251226,2.4445584,,,,,,,,,,,,,, -159400,4.3533483,2.494079,,,,,,,,,,,,,, -159500,4.580502,2.5228791,,,,,,,,,,,,,, -159600,4.390029,2.4147382,,,,,,,,,,,,,, -159700,4.0180335,3.2757576,,,,,,,,,,,,,, -159800,4.794394,3.315115,,,,,,,,,,,,,, -159890,,,0.8223828077316284,0.7984592914581299,0.7449600100517273,1.1174359321594238,50000.0,0.6300000548362732,1.702431082725525,10000.0,72306.84834432602,77954.27435541153,72306.84834432602,5632.141172647476,7.231976747512817,0.0 -159900,4.633045,2.367497,,,,,,,,,,,,,, -160000,5.114899,4.0892696,,,,,,,,,,,,,, -160100,4.5282803,2.415012,,,,,,,,,,,,,, -160200,4.8669844,2.3738186,,,,,,,,,,,,,, -160300,4.4523735,2.3901153,,,,,,,,,,,,,, -160400,4.1195507,2.3941023,,,,,,,,,,,,,, -160500,4.3605328,2.3826709,,,,,,,,,,,,,, -160600,4.2398887,2.3774412,,,,,,,,,,,,,, -160700,4.4999633,2.3991277,,,,,,,,,,,,,, -160800,4.1575384,3.436685,,,,,,,,,,,,,, -160819,,,0.82386714220047,0.7823898196220398,0.7475000023841858,1.1016745567321775,50000.0,0.6273000240325928,1.690682888031006,10000.0,72726.83613228798,78408.82541036606,72726.83613228798,5666.60213804245,7.287751197814941,0.0 -160900,4.412084,2.8477585,,,,,,,,,,,,,, -161000,4.5785203,3.4393387,,,,,,,,,,,,,, -161100,4.679264,2.5800939,,,,,,,,,,,,,, -161200,4.22109,2.807622,,,,,,,,,,,,,, -161300,4.5926943,2.3462071,,,,,,,,,,,,,, -161400,4.946259,3.2892046,,,,,,,,,,,,,, -161500,4.4783406,3.9059691,,,,,,,,,,,,,, -161600,4.4792857,2.34312,,,,,,,,,,,,,, -161700,4.8811283,3.0548205,,,,,,,,,,,,,, -161748,,,0.8251367211341858,0.7654830813407898,0.7488999962806702,1.0918009281158447,50000.0,0.6290000081062317,1.686492681503296,10000.0,73147.12069392204,78863.87222862244,73147.12069392204,5701.25897192955,7.346428871154785,0.0 -161800,5.8553553,4.4048862,,,,,,,,,,,,,, -161900,4.0201426,2.7883074,,,,,,,,,,,,,, -162000,4.8522415,3.7325773,,,,,,,,,,,,,, -162100,4.3458953,2.7057683,,,,,,,,,,,,,, -162200,4.187782,2.4537597,,,,,,,,,,,,,, -162300,4.349236,2.7107174,,,,,,,,,,,,,, -162400,4.449388,2.282131,,,,,,,,,,,,,, -162500,5.824171,4.345536,,,,,,,,,,,,,, -162600,4.110722,2.9181905,,,,,,,,,,,,,, -162676,,,0.8285351395606995,0.7643269896507263,0.7503599524497986,1.0934828519821167,50000.0,0.633400022983551,1.6805967092514038,10000.0,73567.05645251274,79317.50340890884,73567.05645251274,5734.852890491486,7.401033163070679,0.0 -162700,5.414965,3.9646597,,,,,,,,,,,,,, -162800,4.7674828,2.9248564,,,,,,,,,,,,,, -162900,4.9145856,2.375106,,,,,,,,,,,,,, -163000,4.7051435,3.5702753,,,,,,,,,,,,,, -163100,6.094577,4.43919,,,,,,,,,,,,,, -163200,4.5409756,2.333128,,,,,,,,,,,,,, -163300,4.5540595,3.1724029,,,,,,,,,,,,,, -163400,4.6491976,2.6595578,,,,,,,,,,,,,, -163500,4.274933,2.2984552,,,,,,,,,,,,,, -163600,4.8668756,2.2531514,,,,,,,,,,,,,, -163603,,,0.8294531106948853,0.7535117268562317,0.7529199719429016,1.0765010118484497,50000.0,0.6317000389099121,1.6685441732406616,10000.0,73987.05826282501,79770.55769276619,73987.05826282501,5767.8042142391205,7.45538067817688,0.0 -163700,4.753978,2.5219302,,,,,,,,,,,,,, -163800,5.77867,4.3288856,,,,,,,,,,,,,, -163900,5.181986,4.1201954,,,,,,,,,,,,,, -164000,4.514477,2.3010893,,,,,,,,,,,,,, -164100,4.1895933,2.2255223,,,,,,,,,,,,,, -164200,5.011338,2.5446162,,,,,,,,,,,,,, -164300,5.341061,3.6525826,,,,,,,,,,,,,, -164400,4.8817883,2.6881247,,,,,,,,,,,,,, -164500,4.7581716,3.560215,,,,,,,,,,,,,, -164530,,,0.82923823595047,0.7566958665847778,0.751039981842041,1.0820822715759275,50000.0,0.6342000365257263,1.6714333295822144,10000.0,74407.01922512054,80224.93182849884,74407.01922512054,5802.109175443649,7.517110109329224,0.0 -164600,4.8586698,2.1648247,,,,,,,,,,,,,, -164700,4.5330744,3.1086137,,,,,,,,,,,,,, -164800,4.5331707,2.2791307,,,,,,,,,,,,,, -164900,4.3494196,2.8481138,,,,,,,,,,,,,, -165000,4.6123743,2.419182,,,,,,,,,,,,,, -165100,4.486238,2.388486,,,,,,,,,,,,,, -165200,4.833196,2.376817,,,,,,,,,,,,,, -165300,5.140459,2.369638,,,,,,,,,,,,,, -165400,5.3033595,3.5005586,,,,,,,,,,,,,, -165458,,,0.8375976085662842,0.7356262803077698,0.754859983921051,1.0842537879943848,50000.0,0.6332000494003296,1.685017704963684,10000.0,74827.14331531525,80680.53797078133,74827.14331531525,5837.49352812767,7.568860530853271,0.0 -165500,5.2586145,2.3909326,,,,,,,,,,,,,, -165600,5.227793,3.403738,,,,,,,,,,,,,, -165700,5.7424436,4.138406,,,,,,,,,,,,,, -165800,4.9881353,2.7126672,,,,,,,,,,,,,, -165900,5.144749,2.2780252,,,,,,,,,,,,,, -166000,4.93344,2.2956734,,,,,,,,,,,,,, -166100,4.754572,2.366304,,,,,,,,,,,,,, -166200,4.5567074,3.0305734,,,,,,,,,,,,,, -166300,4.9827685,3.1521878,,,,,,,,,,,,,, -166389,,,0.8311718702316284,0.7485712170600891,0.7549999952316284,1.0757715702056885,50000.0,0.6389000415802002,1.6692404747009275,10000.0,75247.37900185585,81135.34643173218,75247.37900185585,5871.969348192215,7.618942499160767,0.0 -166400,4.903982,2.5919106,,,,,,,,,,,,,, -166500,5.3362827,2.2314785,,,,,,,,,,,,,, -166600,5.3967266,4.113095,,,,,,,,,,,,,, -166700,5.125962,2.375799,,,,,,,,,,,,,, -166800,4.501397,2.3535845,,,,,,,,,,,,,, -166900,5.7289166,2.281618,,,,,,,,,,,,,, -167000,4.9462705,2.2855012,,,,,,,,,,,,,, -167100,5.1084847,2.4853826,,,,,,,,,,,,,, -167200,5.747127,2.8551805,,,,,,,,,,,,,, -167300,4.8387055,3.525186,,,,,,,,,,,,,, -167317,,,0.8334569931030273,0.7626773715019226,0.7558599710464478,1.091723918914795,50000.0,0.6389000415802002,1.6732136011123655,10000.0,75667.48089122772,81590.24653863907,75667.48089122772,5906.669553279877,7.670528650283813,0.0 -167400,4.8909597,2.3106463,,,,,,,,,,,,,, -167500,5.4559035,3.46759,,,,,,,,,,,,,, -167600,5.236008,2.2398925,,,,,,,,,,,,,, -167700,4.9586353,2.8890553,,,,,,,,,,,,,, -167800,5.3205237,2.3138118,,,,,,,,,,,,,, -167900,5.1676188,2.2188435,,,,,,,,,,,,,, -168000,5.025832,2.1622932,,,,,,,,,,,,,, -168100,4.622839,2.319065,,,,,,,,,,,,,, -168200,5.1898894,2.2680519,,,,,,,,,,,,,, -168246,,,0.8374999761581421,0.7262058258056641,0.7564399838447571,1.0680700540542605,50000.0,0.6396000385284424,1.6480201482772827,10000.0,76087.52441835403,82044.95395517349,76087.52441835403,5941.224766492844,7.7321436405181885,0.0 -168300,4.9706206,2.7691422,,,,,,,,,,,,,, -168400,5.0595937,3.008132,,,,,,,,,,,,,, -168500,4.6058874,2.9901104,,,,,,,,,,,,,, -168600,5.3414927,2.240219,,,,,,,,,,,,,, -168700,5.2006893,2.2733533,,,,,,,,,,,,,, -168800,5.299084,2.2793431,,,,,,,,,,,,,, -168900,5.0729723,2.319713,,,,,,,,,,,,,, -169000,4.995938,2.2824793,,,,,,,,,,,,,, -169100,5.751478,3.9394932,,,,,,,,,,,,,, -169175,,,0.83607417345047,0.7408673167228699,0.7595399618148804,1.0654107332229614,50000.0,0.6421000361442566,1.6457515954971311,10000.0,76507.65330696106,82499.51538085938,76507.65330696106,5975.560669898987,7.782688856124878,0.0 -169200,5.383334,2.464941,,,,,,,,,,,,,, -169300,5.4447217,3.085865,,,,,,,,,,,,,, -169400,5.494101,3.1967616,,,,,,,,,,,,,, -169500,5.182959,2.3596904,,,,,,,,,,,,,, -169600,5.5427356,2.2665257,,,,,,,,,,,,,, -169700,5.13719,2.253827,,,,,,,,,,,,,, -169800,4.998607,2.2321,,,,,,,,,,,,,, -169900,5.2380733,2.251092,,,,,,,,,,,,,, -170000,6.1382155,3.8072922,,,,,,,,,,,,,, -170100,5.0629344,2.3633153,,,,,,,,,,,,,, -170103,,,0.8408398032188416,0.7134418487548828,0.7605199813842773,1.0481165647506714,50000.0,0.6447000503540039,1.6202740669250488,10000.0,76927.95672249794,82951.85792160034,76927.95672249794,6007.491415500641,7.844382524490356,0.0 -170200,5.432248,2.2473702,,,,,,,,,,,,,, -170300,5.129003,2.3505418,,,,,,,,,,,,,, -170400,4.7634935,3.1422389,,,,,,,,,,,,,, -170500,5.7413282,2.5784564,,,,,,,,,,,,,, -170600,6.460103,3.838564,,,,,,,,,,,,,, -170700,4.9668784,2.2644472,,,,,,,,,,,,,, -170800,4.686823,2.770429,,,,,,,,,,,,,, -170900,4.8489037,2.7905788,,,,,,,,,,,,,, -171000,5.451108,2.3271375,,,,,,,,,,,,,, -171033,,,0.8424023389816284,0.7111977934837341,0.7627399563789368,1.051758050918579,50000.0,0.6460000276565552,1.633678674697876,10000.0,77348.14902710915,83404.94340229034,77348.14902710915,6040.283388137817,7.898629903793335,0.0 -171100,6.1914763,4.1877885,,,,,,,,,,,,,, -171200,5.4222984,2.3475192,,,,,,,,,,,,,, -171300,5.0696797,2.3757253,,,,,,,,,,,,,, -171400,4.8050704,2.1225119,,,,,,,,,,,,,, -171419,,,,,,,,,,,77520.33168768883,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/eval_measurements.csv deleted file mode 100644 index dad28ab8c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -30.13853859901428,0.0,34.00022578239441,1,0,34.00022578239441,0.0010000000474974,6.907756805419922,10000,64.13885116577148,0.0008593749953433,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -65.81416869163513,0.0183992385864257,453.9772663116455,864,0,453.9772663116455,0.0127000007778406,6.433145999908447,10000,519.8541326522827,0.01708984375,6.374186038970947,0.0169200003147125,6.388251781463623,50000 -104.78045845031738,0.050161600112915,874.0901634693146,1785,0,874.0901634693146,0.0368000008165836,5.947690486907959,10000,979.012511730194,0.0503515601158142,5.795556545257568,0.048539999872446,5.824850559234619,50000 -144.04840278625488,0.0794043540954589,1294.3986456394196,2708,0,1294.3986456394196,0.0502000041306018,5.638961315155029,10000,1438.6663777828217,0.07275390625,5.425717830657959,0.0682599991559982,5.4650092124938965,50000 -183.41586256027225,0.1033854484558105,1714.635325908661,3630,0,1714.635325908661,0.0756999999284744,5.355987071990967,10000,1898.3418984413147,0.1091406196355819,5.063491344451904,0.0984599962830543,5.12637186050415,50000 -218.49655675888064,0.1298916339874267,2134.942580461502,4552,0,2134.942580461502,0.106500007212162,5.0176777839660645,10000,2353.803330421448,0.1644726544618606,4.604984760284424,0.143119990825653,4.721293449401856,50000 -252.4030647277832,0.1591732501983642,2555.2755966186523,5473,0,2555.2755966186523,0.1381999999284744,4.75654411315918,10000,2808.118916273117,0.1963085830211639,4.357907772064209,0.1792799979448318,4.431841850280762,50000 -288.8565876483917,0.1901607513427734,2975.428004980088,6396,0,2975.428004980088,0.1730000078678131,4.418845176696777,10000,3264.80242729187,0.2514062523841858,3.91365122795105,0.2306199967861175,4.019598484039307,50000 -322.7090003490448,0.2164924144744873,3395.463133573532,7320,0,3395.463133573532,0.203900009393692,4.221038818359375,10000,3718.7628943920135,0.2916015684604645,3.662829160690308,0.2649999856948852,3.806907415390015,50000 -357.6776604652405,0.2442131042480468,3815.831436395645,8243,0,3815.831436395645,0.2307000160217285,4.009162425994873,10000,4174.174175024033,0.3251757621765136,3.437488555908203,0.3050200045108795,3.53970718383789,50000 -391.93957924842834,0.2702317237854004,4235.847739696503,9164,0,4235.847739696503,0.254800021648407,3.8746755123138414,10000,4628.525053501129,0.3589257597923279,3.239076375961304,0.3297999799251556,3.377380609512329,50000 -430.49109721183777,0.2951619625091553,4655.996357917786,10086,0,4655.996357917786,0.281900018453598,3.6769216060638414,10000,5087.298045635223,0.3991015553474426,3.000910997390747,0.3657599985599518,3.1630349159240723,50000 -465.17704224586487,0.3227870464324951,5075.949893951416,11009,0,5075.949893951416,0.2939999997615814,3.614280223846436,10000,5542.012524604797,0.4095703065395355,2.957759618759156,0.3806999921798706,3.086174726486206,50000 -499.54923462867737,0.3511896133422851,5496.126168727875,11932,0,5496.126168727875,0.3217000067234039,3.442434549331665,10000,5996.636655807495,0.4536913931369781,2.704256534576416,0.4152999818325043,2.878851890563965,50000 -532.9983906745911,0.3771047592163086,5916.12104177475,12852,0,5916.12104177475,0.332500010728836,3.3532426357269287,10000,6450.153452396393,0.4732812345027923,2.620152473449707,0.4334799945354461,2.793245553970337,50000 -568.8881900310516,0.4073190689086914,6336.528871059418,13774,0,6336.528871059418,0.35630002617836,3.2585787773132324,10000,6906.528246641159,0.4964648187160492,2.481999397277832,0.4541199803352356,2.691492795944214,50000 -599.7011110782623,0.4356505870819092,6756.774499177933,14697,0,6756.774499177933,0.3745000064373016,3.150657892227173,10000,7357.662885427475,0.5178515315055847,2.393428325653076,0.4753799736499786,2.57056212425232,50000 -634.7569868564606,0.4672560691833496,7176.8117735385895,15618,0,7176.8117735385895,0.3872000277042389,3.020982027053833,10000,7812.835678100586,0.5367968678474426,2.243785858154297,0.49685999751091,2.430511713027954,50000 -669.5322709083557,0.4950077533721924,7597.053501844406,16540,0,7597.053501844406,0.3929000198841095,2.9950408935546875,10000,8267.926774263382,0.562207043170929,2.14985466003418,0.507319986820221,2.4012911319732666,50000 -702.9022083282471,0.5227203369140625,8017.169707298279,17463,0,8017.169707298279,0.398900032043457,2.9856181144714355,10000,8721.486993074417,0.553515613079071,2.2031924724578857,0.5131999850273132,2.382346868515014,50000 -736.4718625545502,0.549354076385498,8437.620900630951,18381,0,8437.620900630951,0.4169000089168548,2.871544122695923,10000,9175.58203101158,0.5694921612739563,2.0806596279144287,0.5267999768257141,2.2643842697143555,50000 -772.117954492569,0.5785338878631592,8857.677038431168,19298,0,8857.677038431168,0.4219000339508056,2.838536739349365,10000,9631.36013674736,0.5946288704872131,1.9711012840271,0.5412999987602234,2.2097156047821045,50000 -807.2917928695679,0.6079757213592529,9277.779623508452,20217,0,9277.779623508452,0.4321000277996063,2.757359027862549,10000,10086.7133436203,0.5886132717132568,1.966526865959168,0.5514000058174133,2.1377973556518555,50000 -841.8244321346283,0.6349647045135498,9697.771949529648,21138,0,9697.771949529648,0.4409000277519226,2.7603559494018555,10000,10541.31220149994,0.6024804711341858,1.935488104820252,0.5561999678611755,2.1444432735443115,50000 -876.3112514019012,0.6652498245239258,10117.83263373375,22059,0,10117.83263373375,0.4448000192642212,2.710822343826294,10000,10995.936970949171,0.6153905987739563,1.863569259643555,0.5629599690437317,2.098595142364502,50000 -908.642866373062,0.6947681903839111,10537.896867275238,22978,0,10537.896867275238,0.451200008392334,2.6695117950439453,10000,11448.413235902786,0.6119140386581421,1.860370397567749,0.568619966506958,2.0590929985046387,50000 -943.3075177669524,0.7228202819824219,10958.168203353882,23901,0,10958.168203353882,0.4563000202178955,2.6497387886047363,10000,11903.424647808077,0.619140625,1.842691659927368,0.5765399932861328,2.036699056625366,50000 -974.7620024681092,0.7505502700805664,11378.107141256332,24825,0,11378.107141256332,0.455700010061264,2.65883469581604,10000,12354.893264770508,0.6344335675239563,1.8031526803970337,0.5821200013160706,2.0396535396575928,50000 -1009.1432175636292,0.782789945602417,11798.067595005035,25745,0,11798.067595005035,0.4617000222206116,2.632565975189209,10000,12809.31416130066,0.6540429592132568,1.6929877996444702,0.5794999599456787,2.018251895904541,50000 -1040.0437922477722,0.8123970031738281,12218.119409561155,26668,0,12218.119409561155,0.4686000347137451,2.565860748291016,10000,13260.34297466278,0.6333398222923279,1.7327735424041748,0.5909799933433533,1.9266051054000848,50000 -1074.250785589218,0.8448200225830078,12638.257354021072,27589,0,12638.257354021072,0.4768000245094299,2.580284595489502,10000,13714.76744222641,0.6485351324081421,1.7430307865142822,0.5958600044250488,1.972648024559021,50000 -1104.9193880558014,0.8732032775878906,13058.69831609726,28513,0,13058.69831609726,0.4702000319957733,2.5671634674072266,10000,14165.95332980156,0.6622265577316284,1.6519564390182495,0.5966399908065796,1.929701805114746,50000 -1138.7440557479858,0.9072282314300536,13478.946665525436,29436,0,13478.946665525436,0.480400025844574,2.534058094024658,10000,14620.107602596285,0.6512890458106995,1.7113102674484253,0.6077600121498108,1.9051331281661987,50000 -1173.688907146454,0.9404423236846924,13898.911831855774,30358,0,13898.911831855774,0.4905000329017639,2.532050609588623,10000,15075.097905397415,0.6586328148841858,1.6941852569580078,0.6078000068664551,1.9113799333572388,50000 -1206.0904257297516,0.9788103103637696,14319.18147277832,31283,0,14319.18147277832,0.4874000251293182,2.487379312515259,10000,15527.85389828682,0.6657617092132568,1.6142001152038574,0.6137599945068359,1.8565632104873653,50000 -1240.2513403892517,1.008094310760498,14739.35320687294,32206,0,14739.35320687294,0.4907000362873077,2.5017411708831787,10000,15982.26224064827,0.6634374856948853,1.6532877683639526,0.6112799644470215,1.866252422332764,50000 -1273.814103603363,1.0417673587799072,15159.512373924255,33130,0,15159.512373924255,0.5005000233650208,2.448014974594116,10000,16436.065540075302,0.6699804663658142,1.617461323738098,0.6247999668121338,1.8267481327056885,50000 -1307.9982221126556,1.0729899406433103,15579.846864700316,34053,0,15579.846864700316,0.4978000223636627,2.4420933723449707,10000,16890.66156053543,0.6822265386581421,1.577724575996399,0.6250799894332886,1.833191156387329,50000 -1344.7656226158142,1.116785764694214,15999.829869508743,34975,0,15999.829869508743,0.4970000088214874,2.4435598850250244,10000,17347.502505779266,0.7008007764816284,1.4769867658615112,0.6222400069236755,1.811497449874878,50000 -1374.418939828873,1.147679090499878,16420.072093486786,35898,0,16420.072093486786,0.5062000155448914,2.371567964553833,10000,17797.47596859932,0.6775780916213989,1.5361953973770142,0.6276800036430359,1.7578741312026978,50000 -1409.6424214839935,1.1869757175445557,16840.190497398376,36818,0,16840.190497398376,0.5077000260353088,2.443421125411988,10000,18252.90421462059,0.6763281226158142,1.6137640476226809,0.6283800005912781,1.842094898223877,50000 -1442.1937334537506,1.2198548316955566,17260.28178691864,37740,0,17260.28178691864,0.5083000063896179,2.3802409172058105,10000,18705.626772880554,0.6975976228713989,1.450102925300598,0.631060004234314,1.7535679340362549,50000 -1475.909290790558,1.252131700515747,17680.43620157242,38663,0,17680.43620157242,0.5202000141143799,2.346991777420044,10000,19159.576062202454,0.6893359422683716,1.514500379562378,0.6391400098800659,1.7250676155090332,50000 -1509.32852268219,1.2823808193206787,18100.48945403099,39584,0,18100.48945403099,0.5121999979019165,2.3584721088409424,10000,19613.12621331215,0.6898046731948853,1.4997754096984863,0.6385999917984009,1.728279948234558,50000 -1540.6759810447693,1.316420316696167,18520.47150492668,40507,0,18520.47150492668,0.5217000246047974,2.341832637786865,10000,20064.541659355164,0.70068359375,1.4647341966629028,0.6398599743843079,1.738932967185974,50000 -1576.082843542099,1.3496661186218262,18940.805812358856,41427,0,18940.805812358856,0.5194000005722046,2.328978538513184,10000,20520.362723588943,0.6917382478713989,1.4947209358215332,0.6426999568939209,1.7017552852630615,50000 -1608.2144901752472,1.3862807750701904,19361.10916209221,42353,0,19361.10916209221,0.5272000432014465,2.3286025524139404,10000,20972.88152360916,0.70068359375,1.465484619140625,0.6469199657440186,1.71008563041687,50000 -1642.51211977005,1.4228754043579102,19781.25616669655,43275,0,19781.25616669655,0.5193000435829163,2.333106279373169,10000,21427.411111593246,0.7005859017372131,1.4444526433944702,0.6436799764633179,1.7053855657577517,50000 -1675.5485351085665,1.4599545001983645,20201.509860515594,44199,0,20201.509860515594,0.522599995136261,2.296396255493164,10000,21880.785943746567,0.7070116996765137,1.4227125644683838,0.6459999680519104,1.6852363348007202,50000 -1711.402881860733,1.4905712604522705,20621.46268749237,45121,0,20621.46268749237,0.5268000364303589,2.3047330379486084,10000,22336.67161679268,0.7047460675239563,1.4595746994018557,0.6520599722862244,1.6880558729171753,50000 -1742.409699678421,1.524674892425537,21041.53035831452,46045,0,21041.53035831452,0.5236999988555908,2.286877155303955,10000,22787.82704544068,0.703808605670929,1.4133350849151611,0.649899959564209,1.6546680927276611,50000 -1776.6765806674955,1.556483030319214,21461.85862827301,46969,0,21461.85862827301,0.5320000052452087,2.2345681190490723,10000,23242.50052928925,0.7266796827316284,1.295320987701416,0.6565399765968323,1.61535382270813,50000 -1810.9130997657776,1.5944783687591553,21882.07282042504,47891,0,21882.07282042504,0.5309000015258789,2.2502574920654297,10000,23697.036667346954,0.7119921445846558,1.4000214338302612,0.6568399667739868,1.6305649280548096,50000 -1841.5184531211853,1.6299116611480713,22302.034429311752,48814,0,22302.034429311752,0.5335000157356262,2.232346534729004,10000,24147.68631052971,0.7128710746765137,1.3655827045440674,0.6576200127601624,1.6067535877227783,50000 -1874.509628534317,1.6707801818847656,22722.3923227787,49734,0,22722.3923227787,0.5344000458717346,2.236469984054565,10000,24601.1224694252,0.728808581829071,1.3120293617248535,0.6609199643135071,1.6072094440460205,50000 -1909.011991024017,1.70641827583313,23142.32176733017,50652,0,23142.32176733017,0.5375000238418579,2.2071101665496826,10000,25055.637226343155,0.7116601467132568,1.3645031452178955,0.6611599922180176,1.5872159004211426,50000 -1940.8010630607605,1.7460317611694336,23562.47232246399,51574,0,23562.47232246399,0.5333000421524048,2.240694761276245,10000,25507.66274857521,0.717968761920929,1.3651286363601685,0.6620799899101257,1.6173219680786133,50000 -1975.382539987564,1.786369800567627,23982.808941841125,52496,0,23982.808941841125,0.5406000018119812,2.2341880798339844,10000,25962.668434381485,0.7241796851158142,1.34527850151062,0.6620000004768372,1.6125682592391968,50000 -2006.3908026218407,1.824455976486206,24402.90833449364,53418,0,24402.90833449364,0.5430999994277954,2.2015929222106934,10000,26413.86163020134,0.7199804782867432,1.3372256755828855,0.6657999753952026,1.5774351358413696,50000 -2039.2105541229248,1.866105318069458,24823.192568540573,54339,0,24823.192568540573,0.5449000000953674,2.186253309249878,10000,26867.0536248684,0.7254882454872131,1.335657000541687,0.670699954032898,1.579382300376892,50000 -2070.5169591903687,1.901301383972168,25243.527433633804,55261,0,25243.527433633804,0.5430000424385071,2.1910340785980225,10000,27318.778652668,0.7294921875,1.3043168783187866,0.667419970035553,1.5707368850708008,50000 -2102.191232919693,1.934786796569824,25663.721523284912,56182,0,25663.721523284912,0.5533000230789185,2.1446516513824463,10000,27770.72711038589,0.7528125047683716,1.1971156597137451,0.6747999787330627,1.5370960235595703,50000 -2136.8173022270203,1.9684455394744875,26084.066102027893,57106,0,26084.066102027893,0.546500027179718,2.2251136302948,10000,28225.778723955154,0.72314453125,1.3708430528640747,0.6675199866294861,1.6150413751602173,50000 -2169.2758326530457,2.005272388458252,26504.048363685608,58026,0,26504.048363685608,0.5457000136375427,2.212608337402344,10000,28678.30360651016,0.7324023246765137,1.3181718587875366,0.6703199744224548,1.5841935873031616,50000 -2200.3360035419464,2.045383214950561,26924.052193164825,58948,0,26924.052193164825,0.5502000451087952,2.179090738296509,10000,29129.454854011536,0.7443945407867432,1.2518020868301392,0.6735599637031555,1.557487726211548,50000 -2235.119631052017,2.088909387588501,27344.28582406044,59869,0,27344.28582406044,0.5525000095367432,2.166768789291382,10000,29584.562819719315,0.7306249737739563,1.318999409675598,0.6790599822998047,1.5507181882858276,50000 -2268.4358253479004,2.123471736907959,27764.46465349197,60792,0,27764.46465349197,0.5435000061988831,2.1946136951446533,10000,30038.139572381973,0.7305468320846558,1.3091228008270264,0.6714000105857849,1.5599348545074463,50000 -2302.441154003144,2.160538911819458,28184.4619910717,61707,0,28184.4619910717,0.5573000311851501,2.1491177082061768,10000,30492.225292921063,0.7434960603713989,1.2444376945495603,0.6765999794006348,1.5327082872390747,50000 -2337.75715136528,2.1949493885040283,28604.43643712997,62627,0,28604.43643712997,0.5503000020980835,2.195638656616211,10000,30947.598170518875,0.7280663847923279,1.3498351573944092,0.672540009021759,1.5801585912704468,50000 -2373.211641073227,2.230957508087158,29024.70809817314,63549,0,29024.70809817314,0.5556000471115112,2.1663198471069336,10000,31403.40734767914,0.7319726347923279,1.305245280265808,0.6799799799919128,1.548778414726257,50000 -2407.421452522278,2.273590564727783,29444.98510026932,64471,0,29444.98510026932,0.5586000084877014,2.1107430458068848,10000,31857.98429250717,0.7493359446525574,1.2207766771316528,0.6810799837112427,1.495862603187561,50000 -2442.6564412117004,2.313045024871826,29864.908487081528,65392,0,29864.908487081528,0.5560000538825989,2.167978048324585,10000,32313.229038715363,0.7576562166213989,1.2009063959121704,0.6784200072288513,1.5377628803253174,50000 -2477.143489599228,2.3523316383361816,30285.114988565445,66313,0,30285.114988565445,0.5582000017166138,2.1454310417175293,10000,32768.00912475586,0.7408398389816284,1.2733092308044434,0.6830599904060364,1.5249103307724,50000 -2511.189737558365,2.390868186950684,30705.05445933342,67235,0,30705.05445933342,0.5560000538825989,2.115541219711304,10000,33222.08050394058,0.7444140315055847,1.232791304588318,0.6813200116157532,1.510599970817566,50000 -2546.431643724441,2.431232452392578,31125.21816754341,68155,0,31125.21816754341,0.5618000030517578,2.129547595977783,10000,33677.57394862175,0.761523425579071,1.198433756828308,0.6867199540138245,1.5193663835525513,50000 -2578.39408159256,2.4752867221832275,31545.46352577209,69076,0,31545.46352577209,0.5611000061035156,2.0891273021698,10000,34129.872649908066,0.7444921731948853,1.2327656745910645,0.6861199736595154,1.484375238418579,50000 -2612.8748741149902,2.515571594238281,31965.51737833023,69998,0,31965.51737833023,0.5593000054359436,2.113833427429199,10000,34584.49463844299,0.7477343678474426,1.226297378540039,0.6846599578857422,1.4979448318481443,50000 -2644.0857014656067,2.554614782333374,32385.57051825524,70919,0,32385.57051825524,0.5624000430107117,2.088679552078247,10000,35035.8462703228,0.761035144329071,1.179047465324402,0.6904599666595459,1.4743375778198242,50000 -2677.212243080139,2.599778652191162,32805.80994772911,71841,0,32805.80994772911,0.5604000091552734,2.14981746673584,10000,35489.30454015732,0.7437499761581421,1.298214554786682,0.6857199668884277,1.5482041835784912,50000 -2709.371276378632,2.638214349746704,33225.91061067581,72764,0,33225.91061067581,0.5694000124931335,2.0661840438842773,10000,35941.64887547493,0.7504296898841858,1.206867218017578,0.6949999928474426,1.4593466520309448,50000 -2740.13196849823,2.6735827922821045,33646.25889086723,73686,0,33646.25889086723,0.5634000301361084,2.13584303855896,10000,36392.84107017517,0.7568163871765137,1.2269736528396606,0.6904199719429016,1.5098013877868652,50000 -2777.113403081894,2.73940110206604,34066.51602935791,74607,0,34066.51602935791,0.5653000473976135,2.070711374282837,10000,36850.19239878655,0.7570703029632568,1.1745266914367676,0.6887800097465515,1.460245132446289,50000 -2811.2738075256348,2.781567335128784,34486.733615875244,75530,0,34486.733615875244,0.5676000118255615,2.0816633701324463,10000,37304.65952205658,0.75341796875,1.207563400268555,0.6945199966430664,1.4618675708770752,50000 -2843.222591161728,2.818364381790161,34906.928926706314,76453,0,34906.928926706314,0.5722000002861023,2.0874454975128174,10000,37756.88763618469,0.7587304711341858,1.1997249126434326,0.6966999769210815,1.4641448259353638,50000 -2874.2982473373413,2.860573530197144,35327.165003061295,77374,0,35327.165003061295,0.5695000290870667,2.073258399963379,10000,38208.289085149765,0.7772070169448853,1.1090781688690186,0.694599986076355,1.465245246887207,50000 -2908.8261551856995,2.899253368377685,35747.40086340904,78298,0,35747.40086340904,0.5665000081062317,2.110428810119629,10000,38663.13839507103,0.7567187547683716,1.2263389825820925,0.6919599771499634,1.4922540187835691,50000 -2937.9259643554688,2.9404571056365967,36167.35218763352,79219,0,36167.35218763352,0.5734000205993652,2.026544570922852,10000,39112.27828145027,0.76429682970047,1.1475809812545776,0.6978799700737,1.4310686588287354,50000 -2972.902609348297,2.984685182571411,36587.51485085488,80140,0,36587.51485085488,0.570900022983551,2.05631685256958,10000,39567.50933790207,0.7719921469688416,1.1283113956451416,0.6975399851799011,1.4419785737991333,50000 -3005.9743795394897,3.023676633834839,37007.73335170746,81062,0,37007.73335170746,0.5752000212669373,2.033855438232422,10000,40020.88571333885,0.7625781297683716,1.1647394895553589,0.7009999752044678,1.4354666471481323,50000 -3037.7150337696075,3.0635695457458496,37427.75093221664,81984,0,37427.75093221664,0.5760000348091125,2.041268348693848,10000,40472.73112511635,0.7649609446525574,1.1457265615463257,0.7003799676895142,1.4246442317962646,50000 -3071.735149860382,3.1044604778289795,37847.68333148956,82906,0,37847.68333148956,0.5736000537872314,2.0298521518707275,10000,40926.77200841904,0.7728124856948853,1.1158013343811035,0.7010799646377563,1.41843843460083,50000 -3107.9714460372925,3.151188611984253,38267.85958909989,83831,0,38267.85958909989,0.5822000503540039,2.009699821472168,10000,41383.278517484665,0.7667187452316284,1.1418139934539795,0.7049999833106995,1.4089363813400269,50000 -3139.4142003059387,3.1884803771972656,38688.02298378944,84754,0,38688.02298378944,0.5760000348091125,2.029409170150757,10000,41834.969650030136,0.7683984041213989,1.1525145769119265,0.7024199962615967,1.4284731149673462,50000 -3173.858931779861,3.232621431350708,39108.00399470329,85675,0,39108.00399470329,0.5767000317573547,2.0247151851654053,10000,42289.48701906204,0.7705664038658142,1.13634991645813,0.7039200067520142,1.421459078788757,50000 -3211.1007244586945,3.273395776748657,39528.32387185097,86597,0,39528.32387185097,0.5759000182151794,2.075772523880005,10000,42747.13645219803,0.7929882407188416,1.103911519050598,0.7073799967765808,1.468793511390686,50000 -3247.392087459564,3.3191680908203125,39948.40259766579,87520,0,39948.40259766579,0.575700044631958,2.062497854232788,10000,43203.59956860542,0.7638280987739563,1.196157455444336,0.7017599940299988,1.4531490802764893,50000 -3281.204854249954,3.3600547313690186,40368.75314331055,88441,0,40368.75314331055,0.5833000540733337,1.993497610092163,10000,43657.85054802895,0.7773827910423279,1.0995545387268066,0.7068799734115601,1.3993632793426514,50000 -3313.9709889888763,3.399597406387329,40788.70879864693,89362,0,40788.70879864693,0.5820000171661377,2.0068376064300537,10000,44110.658349752426,0.7879687547683716,1.051804780960083,0.7084199786186218,1.3937915563583374,50000 -3346.769567489624,3.440016031265259,41208.68663263321,90282,0,41208.68663263321,0.5849000215530396,2.013418197631836,10000,44563.522152900696,0.7703515291213989,1.137851357460022,0.7090799808502197,1.4091787338256836,50000 -3380.757195711136,3.477928876876831,41628.99574398994,91203,0,41628.99574398994,0.5835000276565552,1.9904738664627075,10000,45017.9036052227,0.779980480670929,1.1019827127456665,0.7110399603843689,1.3903183937072754,50000 -3413.170263528824,3.516024589538574,42049.15962576866,92127,0,42049.15962576866,0.5919000506401062,1.961077690124512,10000,45470.5659160614,0.7870507836341858,1.0477782487869265,0.7113999724388123,1.3689666986465454,50000 -3447.9696865081787,3.5564661026000977,42469.19526147842,93051,0,42469.19526147842,0.5963000059127808,1.9534903764724727,10000,45925.48884344101,0.7793359160423279,1.0956302881240845,0.7135199904441833,1.3716667890548706,50000 -3480.8335807323456,3.598971605300904,42889.19483089447,93974,0,42889.19483089447,0.5918000340461731,1.9728366136550903,10000,46378.44157075882,0.7795116901397705,1.0931429862976074,0.7133600115776062,1.3716752529144287,50000 -3516.163388967514,3.6427321434021,43309.4011952877,94893,0,43309.4011952877,0.5925000309944153,1.949466347694397,10000,46834.06846022606,0.7892773151397705,1.0392942428588867,0.7130999565124512,1.3569936752319336,50000 -3547.8907368183136,3.684923648834229,43729.76480174065,95815,0,43729.76480174065,0.5929000377655029,1.937144875526428,10000,47286.249116420746,0.7995898127555847,0.9954485297203064,0.7177199721336365,1.337254524230957,50000 -3581.8719758987427,3.726364850997925,44150.00039052963,96738,0,44150.00039052963,0.5933000445365906,1.9811594486236568,10000,47740.57042002678,0.77943354845047,1.1019657850265503,0.7138400077819824,1.384581208229065,50000 -3617.057484388352,3.7728912830352783,44570.135746240616,97662,0,44570.135746240616,0.5907000303268433,1.9830719232559204,10000,48195.98546934128,0.7866015434265137,1.0639996528625488,0.7142999768257141,1.3826018571853638,50000 -3648.015291452408,3.817314863204956,44990.07020807266,98587,0,44990.07020807266,0.5903000235557556,1.9637293815612795,10000,48646.9700319767,0.7963476181030273,1.0071494579315186,0.7149400115013123,1.3564001321792605,50000 -3683.7723546028137,3.8653643131256095,45410.03365278244,99509,0,45410.03365278244,0.597100019454956,1.9461467266082764,10000,49102.78510594368,0.78919917345047,1.0567806959152222,0.7200799584388733,1.3418904542922974,50000 -3720.710344791413,3.908934354782105,45830.30660319328,100432,0,45830.30660319328,0.5997000336647034,1.9448856115341189,10000,49560.086330890656,0.7888476252555847,1.059237003326416,0.7181999683380127,1.3572092056274414,50000 -3758.933340787888,3.949855089187622,46250.637231349945,101357,0,46250.637231349945,0.597100019454956,1.923500657081604,10000,50018.72805285454,0.8024413585662842,0.987043559551239,0.7212799787521362,1.3317164182662964,50000 -3797.81196808815,3.9987268447875977,46670.904515028,102279,0,46670.904515028,0.5967000126838684,1.939675450325012,10000,50477.97007155418,0.7923827767372131,1.034652590751648,0.7217400074005127,1.337119460105896,50000 -3832.254055023194,4.04592490196228,47091.08183169365,103199,0,47091.08183169365,0.5997000336647034,1.9212018251419067,10000,50932.68405771256,0.7949413657188416,1.0258285999298096,0.7251799702644348,1.329066514968872,50000 -3865.475003957749,4.091261148452759,47511.4247674942,104119,0,47511.4247674942,0.6007000207901001,1.8960027694702148,10000,51386.34076428413,0.8043944835662842,0.9767816662788392,0.7267799973487854,1.3111376762390137,50000 -3902.273670196533,4.136998176574707,47931.44809389114,105026,0,47931.44809389114,0.6068000197410583,1.9124714136123653,10000,51843.25613093376,0.8005077838897705,0.9992862343788148,0.7254199981689453,1.3189843893051147,50000 -3938.51873588562,4.1818811893463135,48351.60951066017,105950,0,48351.60951066017,0.6034000515937805,1.899614930152893,10000,52299.75429058075,0.7973827719688416,0.9926227927207948,0.7240599989891052,1.3056074380874634,50000 -3972.452043533325,4.225016117095947,48771.873259305954,106871,0,48771.873259305954,0.6069000363349915,1.9079723358154297,10000,52754.04180908203,0.8055273294448853,0.9919548630714417,0.7270799875259399,1.3199224472045898,50000 -4008.251214504242,4.272565364837647,49191.80590748787,107792,0,49191.80590748787,0.6049000024795532,1.9311014413833616,10000,53209.868277311325,0.8181054592132568,0.96428245306015,0.7266599535942078,1.341126561164856,50000 -4042.738936901093,4.315826892852783,49611.74863290787,108714,0,49611.74863290787,0.6097000241279602,1.8781938552856443,10000,53664.38921499252,0.8000195026397705,0.9913946390151978,0.7280600070953369,1.2932459115982056,50000 -4079.958817481994,4.35719108581543,50031.87528467178,109635,0,50031.87528467178,0.6095000505447388,1.916000843048096,10000,54121.824350357056,0.8060351610183716,0.9905210733413696,0.7256999611854553,1.3248872756958008,50000 -4115.946965456009,4.404077291488648,50451.978063583374,110557,0,50451.978063583374,0.6116000413894653,1.922496199607849,10000,54578.00987672806,0.8092968463897705,0.9955326914787292,0.7274199724197388,1.3348983526229858,50000 -4150.26503443718,4.449316024780273,50871.95857954025,111475,0,50871.95857954025,0.6123000383377075,1.863784909248352,10000,55032.40033340454,0.8045703172683716,0.9676870703697203,0.7338599562644958,1.2687467336654663,50000 -4185.976754188538,4.499913215637207,51292.22880363464,112397,0,51292.22880363464,0.6091000437736511,1.8859690427780151,10000,55488.480078697205,0.8095507621765137,0.967756688594818,0.7329199910163879,1.299895405769348,50000 -4220.654976129532,4.545567512512207,51712.16885471344,113317,0,51712.16885471344,0.615600049495697,1.876335978507996,10000,55943.190348148346,0.8145117163658142,0.9527512192726136,0.7319999933242798,1.2896184921264648,50000 -4257.473162174225,4.599227428436279,52132.5546040535,114240,0,52132.5546040535,0.6132000088691711,1.860758304595948,10000,56400.49531674385,0.8056835532188416,0.9622820615768432,0.7326399683952332,1.2778819799423218,50000 -4292.337894201279,4.644402265548706,52552.574162483215,115161,0,52552.574162483215,0.6152000427246094,1.872671604156494,10000,56855.47197389603,0.8109765648841858,0.9565854668617249,0.7332599759101868,1.2811367511749268,50000 -4327.119237422943,4.6875598430633545,52972.78386044502,116082,0,52972.78386044502,0.6181000471115112,1.8877934217453003,10000,57310.55331468582,0.8161913752555847,0.9610245823860168,0.7358399629592896,1.299484133720398,50000 -4365.301455259323,4.738725185394287,53392.99377536774,117005,0,53392.99377536774,0.613800048828125,1.8561517000198364,10000,57769.04441308975,0.8302538990974426,0.8830304741859436,0.7373999953269958,1.2599748373031616,50000 -4405.786436319351,4.782990217208862,53813.30598425865,117924,0,53813.30598425865,0.6217000484466553,1.835737228393555,10000,58229.93274021149,0.8176171779632568,0.9192262291908264,0.7388799786567688,1.245247483253479,50000 -4440.815090417862,4.830764055252075,54233.51160264015,118847,0,54233.51160264015,0.6162000298500061,1.840397596359253,10000,58685.261761426926,0.8188671469688416,0.9099279642105104,0.7407199740409851,1.2410085201263428,50000 -4477.259788036346,4.876175165176392,54653.70579338074,119772,0,54653.70579338074,0.6149000525474548,1.845686197280884,10000,59141.99428868294,0.8258007764816284,0.8815270662307739,0.7394799590110779,1.2439427375793457,50000 -4513.142883777618,4.922154426574707,55073.763964653015,120694,0,55073.763964653015,0.6217000484466553,1.8443632125854488,10000,59598.02819681168,0.81947261095047,0.9173012971878052,0.7415800094604492,1.251138687133789,50000 -4546.011871576309,4.970205307006836,55493.97620844841,121618,0,55493.97620844841,0.6201000213623047,1.8253264427185056,10000,60051.2044301033,0.8227343559265137,0.90450918674469,0.7416200041770935,1.2397983074188232,50000 -4583.217739343643,5.016493797302246,55914.22261214256,122542,0,55914.22261214256,0.6277000308036804,1.8244028091430664,10000,60508.75056910515,0.828417956829071,0.8838768601417542,0.7440599799156189,1.24087393283844,50000 -4619.772355079651,5.061162710189819,56334.52269983292,123465,0,56334.52269983292,0.6285000443458557,1.8137975931167605,10000,60965.69726276398,0.8238476514816284,0.915237843990326,0.746399998664856,1.238528609275818,50000 -4658.850819826126,5.109199523925781,56754.61694145203,124385,0,56754.61694145203,0.6270000338554382,1.8222373723983765,10000,61424.96513009071,0.8286913633346558,0.886573851108551,0.7448999881744385,1.2313250303268433,50000 -4692.450619220734,5.156161308288574,57174.62446951866,125305,0,57174.62446951866,0.6314000487327576,1.801285982131958,10000,61878.66628551483,0.8369140625,0.8582568168640137,0.7482799887657166,1.2224678993225098,50000 -4725.944480419159,5.202749729156494,57594.86729764938,126227,0,57594.86729764938,0.6234000325202942,1.8007813692092896,10000,62332.49602270126,0.8360546827316284,0.8389551639556885,0.7470200061798096,1.2079110145568848,50000 -4762.485838413239,5.24896764755249,58015.21964287758,127149,0,58015.21964287758,0.6300000548362732,1.7968522310256958,10000,62789.48238158226,0.8318945169448853,0.8691099882125854,0.7507199645042419,1.216996669769287,50000 -4800.266010761261,5.297834873199463,58435.50190925598,128072,0,58435.50190925598,0.6292000412940979,1.807980179786682,10000,63247.64047813416,0.8342577815055847,0.880061686038971,0.7497999668121338,1.2383127212524414,50000 -4843.690578460693,5.350980758666992,58855.55880713463,128994,0,58855.55880713463,0.6323000192642212,1.784820795059204,10000,63711.22246456146,0.8449023365974426,0.8151799440383911,0.7497599720954895,1.2033263444900513,50000 -4876.382496595383,5.40146017074585,59275.69723725319,129916,0,59275.69723725319,0.6305000185966492,1.800089955329895,10000,64164.15061426163,0.8347851634025574,0.8594391942024231,0.7525799870491028,1.205952286720276,50000 -4914.366876363754,5.4463982582092285,59695.74828839302,130838,0,59695.74828839302,0.6305000185966492,1.7893548011779783,10000,64622.2776722908,0.8397656083106995,0.8464886546134949,0.753879964351654,1.204972743988037,50000 -4947.561500549316,5.493636131286621,60115.70867657661,131759,0,60115.70867657661,0.6345000267028809,1.7539469003677368,10000,65075.52677679062,0.8443359136581421,0.7996962666511536,0.751039981842041,1.184333324432373,50000 -4983.380564451218,5.537555456161499,60535.7517850399,132681,0,60535.7517850399,0.6374000310897827,1.769114375114441,10000,65531.48008418083,0.83753901720047,0.8382362127304077,0.7551400065422058,1.1913368701934814,50000 -5017.414732217789,5.583152770996094,60955.73228693008,133603,0,60955.73228693008,0.638200044631958,1.7628462314605713,10000,65985.58762574196,0.84046870470047,0.8086775541305542,0.7542600035667419,1.1765190362930298,50000 -5055.079945325851,5.631529569625855,61375.7694671154,134525,0,61375.7694671154,0.6324000358581543,1.7707762718200684,10000,66443.38515496254,0.8450585603713989,0.8082579374313354,0.7573599815368652,1.179269313812256,50000 -5094.165184736252,5.689408540725708,61795.82298064232,135415,0,61795.82298064232,0.6370000243186951,1.7401326894760132,10000,66902.62777686119,0.8431445360183716,0.7971628308296204,0.7569400072097778,1.1679935455322266,50000 -5131.206265687943,5.739239692687988,62215.76799035072,136337,0,62215.76799035072,0.6324000358581543,1.7639225721359253,10000,67359.71117639542,0.8441015481948853,0.8239647746086121,0.7590799927711487,1.176174521446228,50000 -5164.7442235946655,5.78521990776062,62636.00538110733,137259,0,62636.00538110733,0.6370000243186951,1.7493723630905151,10000,67813.579870224,0.8472656011581421,0.8076205253601074,0.7579599618911743,1.180842041969299,50000 -5199.060210227966,5.831464767456055,63055.95963048935,138180,0,63055.95963048935,0.6336000561714172,1.7638990879058838,10000,68267.94333863258,0.8597851395606995,0.7563513517379761,0.7586399912834167,1.1718906164169312,50000 -5234.952075719833,5.88202166557312,63476.03571605682,139101,0,63476.03571605682,0.6370000243186951,1.7559922933578491,10000,68724.00987887383,0.8470116853713989,0.8125196099281311,0.7589399814605713,1.179947018623352,50000 -5270.915474653244,5.92934775352478,63895.9807267189,140024,0,63895.9807267189,0.6468000411987305,1.713780164718628,10000,69180.01283836365,0.8527734279632568,0.758125364780426,0.762499988079071,1.136866569519043,50000 -5304.495651721954,5.978683710098267,64316.03085923195,140945,0,64316.03085923195,0.6384000182151794,1.7504186630249023,10000,69633.73905014992,0.8581054210662842,0.7595937848091125,0.7590199708938599,1.1628683805465698,50000 -5338.206914901733,6.028635501861572,64736.08189225197,141867,0,64736.08189225197,0.641800045967102,1.7470555305480957,10000,70087.59886169434,0.8505077958106995,0.8011537790298462,0.7633000016212463,1.1656519174575806,50000 -5372.0228152275085,6.0786073207855225,65156.10931110382,142788,0,65156.10931110382,0.6448000073432922,1.7246521711349487,10000,70541.53916501999,0.8565039038658142,0.7577779293060303,0.7645599842071533,1.1408846378326416,50000 -5406.831308364868,6.132803916931152,65576.03127932549,143709,0,65576.03127932549,0.6426000595092773,1.7453765869140625,10000,70996.37029480934,0.85986328125,0.7596256732940674,0.7628600001335144,1.166117787361145,50000 -5440.239306926727,6.184715032577515,65996.08980154991,144631,0,65996.08980154991,0.6456000208854675,1.7437084913253784,10000,71449.93562602997,0.8544726371765137,0.784826934337616,0.7655799984931946,1.1578985452651978,50000 -5476.885877370834,6.231912851333618,66416.06876826286,145554,0,66416.06876826286,0.6479000449180603,1.710391640663147,10000,71906.65629696846,0.8596875071525574,0.7511771321296692,0.765999972820282,1.1414402723312378,50000 -5510.088440179825,6.2781524658203125,66836.00802612305,146476,0,66836.00802612305,0.650600016117096,1.7234017848968506,10000,72359.89109659195,0.8626366853713989,0.7605917453765869,0.7663599848747253,1.1581281423568726,50000 -5545.940866231918,6.324589729309082,67255.94326424599,147400,0,67255.94326424599,0.6456000208854675,1.714258074760437,10000,72815.77249288559,0.8684960603713989,0.7181340456008911,0.7680400013923645,1.130260705947876,50000 -5585.072806835175,6.381242990493774,67676.07908463478,148314,0,67676.07908463478,0.6473000049591064,1.71936297416687,10000,73275.14289355278,0.8642382621765137,0.7342635989189148,0.7673199772834778,1.126471996307373,50000 -5622.308450460434,6.429650783538818,68096.35448336601,149235,0,68096.35448336601,0.6455000042915344,1.7286360263824463,10000,73732.74937844276,0.8697851300239563,0.7271069884300232,0.7689200043678284,1.1391234397888184,50000 -5666.171459674835,6.480257034301758,68516.60844302177,150156,0,68516.60844302177,0.650600016117096,1.7134649753570557,10000,74196.96331691742,0.8738671541213989,0.7149843573570251,0.7697599530220032,1.141154170036316,50000 -5699.519348621368,6.532433748245239,68936.61577987671,151078,0,68936.61577987671,0.6499000191688538,1.7185348272323608,10000,74650.41727089882,0.8681445121765137,0.7366388440132141,0.7703999876976013,1.1382390260696411,50000 -5736.198989152908,6.583809614181519,69356.5645096302,151997,0,69356.5645096302,0.6539000272750854,1.6868129968643188,10000,75107.14363598824,0.8676952719688416,0.7144691348075867,0.7712999582290649,1.1138060092926023,50000 -5776.524785995483,6.637896299362183,69776.55529689789,152918,0,69776.55529689789,0.6516000032424927,1.699005126953125,10000,75567.5617249012,0.8729101419448853,0.6911082863807678,0.772879958152771,1.115715265274048,50000 -5814.613633155823,6.696936845779419,70196.77774477005,153837,0,70196.77774477005,0.656000018119812,1.6826592683792114,10000,76025.97893810272,0.8698632717132568,0.7052730321884155,0.7731999754905701,1.1103754043579102,50000 -5849.63468337059,6.748908281326294,70616.80341100693,154759,0,70616.80341100693,0.6591000556945801,1.7006280422210691,10000,76481.1249115467,0.8730077743530273,0.7196352481842041,0.7730000019073486,1.1262603998184204,50000 -5886.051543951035,6.808360576629639,71037.05806207657,155678,0,71037.05806207657,0.6586000323295593,1.6903423070907593,10000,76937.91496515274,0.8753125071525574,0.6879866719245911,0.773419976234436,1.109251618385315,50000 -5918.482615470886,6.863985776901245,71457.10474681854,156597,0,71457.10474681854,0.6598000526428223,1.6865030527114868,10000,77390.49496340752,0.87353515625,0.7132282257080078,0.7741999626159668,1.1228692531585691,50000 -5951.625692844391,6.918470144271851,71877.31110310555,157518,0,71877.31110310555,0.6622000336647034,1.6971155405044556,10000,77843.94621014595,0.87548828125,0.706649124622345,0.774619996547699,1.1308605670928955,50000 -5985.92941904068,6.971019268035889,72297.48312687874,158438,0,72297.48312687874,0.6605000495910645,1.6729251146316528,10000,78298.52132201195,0.8758202791213989,0.6879441738128662,0.7753399610519409,1.1054383516311646,50000 -6019.583042383194,7.0212483406066895,72717.58166861534,159357,0,72717.58166861534,0.6556000113487244,1.6883338689804075,10000,78752.3706395626,0.8858202695846558,0.6633342504501343,0.7765199542045593,1.1071991920471191,50000 -6055.962126255035,7.0706892013549805,73137.55405020714,160277,0,73137.55405020714,0.6580000519752502,1.683222770690918,10000,79208.81894946098,0.8758593797683716,0.6875196099281311,0.776479959487915,1.1003152132034302,50000 -6093.915406227112,7.124929904937744,73557.74100780487,161199,0,73557.74100780487,0.6577000021934509,1.6750051975250244,10000,79667.06095504761,0.8811913728713989,0.671556293964386,0.7778199911117554,1.1007194519042969,50000 -6123.2500858306885,7.175060510635376,73978.00402450562,162117,0,73978.00402450562,0.6593000292778015,1.6683248281478882,10000,80116.75618052483,0.8856250047683716,0.6517072319984436,0.7775999903678894,1.0972752571105957,50000 -6159.228252887726,7.233767747879028,74398.24773645401,163035,0,74398.24773645401,0.6591000556945801,1.6852294206619265,10000,80573.08365154266,0.8808007836341858,0.6864597797393799,0.7800799608230591,1.105568766593933,50000 -6199.534796953201,7.287166118621826,74818.47779989243,163955,0,74818.47779989243,0.6650000214576721,1.66836416721344,10000,81033.72079610825,0.8822265267372131,0.6664526462554932,0.7804200053215027,1.0926640033721924,50000 -6232.852535486221,7.340670824050903,75238.48033547401,164877,0,75238.48033547401,0.6598000526428223,1.6655395030975342,10000,81487.14139032364,0.88636714220047,0.653562068939209,0.7805599570274353,1.0905064344406128,50000 -6272.963407754898,7.399810791015625,75658.44529294968,165798,0,75658.44529294968,0.6610000133514404,1.6883653402328491,10000,81947.32384061813,0.8838866949081421,0.6820747256278992,0.7806800007820129,1.1070207357406616,50000 -6303.624835968018,7.464045286178589,76078.42911958694,166719,0,76078.42911958694,0.6648000478744507,1.6466233730316162,10000,82398.08100652695,0.8874218463897705,0.640555739402771,0.7835800051689148,1.0742700099945068,50000 -6337.89848613739,7.528738975524902,76498.71252512932,167641,0,76498.71252512932,0.6659000515937805,1.6449536085128784,10000,82852.74976229668,0.8880468606948853,0.6393408179283142,0.7837599515914917,1.0733433961868286,50000 -6373.534379959106,7.579162836074829,76918.80696439743,168563,0,76918.80696439743,0.6633000373840332,1.6611697673797607,10000,83308.578540802,0.8868163824081421,0.655565083026886,0.7830599546432495,1.089125394821167,50000 -6412.271682739258,7.63238000869751,77338.80617928505,169484,0,77338.80617928505,0.6653000116348267,1.646395206451416,10000,83767.41523122787,0.8880468606948853,0.6424294710159302,0.7831000089645386,1.0735559463500977,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/measurements.csv deleted file mode 100644 index e11f4b0cd..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1886 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.2919553,6.9077535,,,,,,,,,,,,,, -1,,,0.0008593749953433,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,34.00022578239441,64.13885116577148,34.00022578239441,30.13853859901428,0.0,0.0 -100,0.37659526,6.904312,,,,,,,,,,,,,, -200,0.3509707,6.892668,,,,,,,,,,,,,, -300,0.5187577,6.8431487,,,,,,,,,,,,,, -400,0.7402589,6.8125906,,,,,,,,,,,,,, -500,0.53772455,6.8132725,,,,,,,,,,,,,, -600,1.0234776,6.7523937,,,,,,,,,,,,,, -700,0.70563626,6.6762586,,,,,,,,,,,,,, -800,0.8845073,6.6316,,,,,,,,,,,,,, -864,,,0.01708984375,6.374186038970947,0.0169200003147125,6.388251781463623,50000.0,0.0127000007778406,6.433145999908447,10000.0,453.9772663116455,519.8541326522827,453.9772663116455,65.81416869163513,0.0183992385864257,0.0 -900,0.7145282,6.583419,,,,,,,,,,,,,, -1000,1.0284473,6.602406,,,,,,,,,,,,,, -1100,0.8624968,6.6336412,,,,,,,,,,,,,, -1200,1.6164496,6.500366,,,,,,,,,,,,,, -1300,0.92653483,6.677701,,,,,,,,,,,,,, -1400,0.99646914,6.467768,,,,,,,,,,,,,, -1500,1.2736408,6.7039127,,,,,,,,,,,,,, -1600,1.4053991,6.3390036,,,,,,,,,,,,,, -1700,1.0985726,6.3297687,,,,,,,,,,,,,, -1785,,,0.0503515601158142,5.795556545257568,0.048539999872446,5.824850559234619,50000.0,0.0368000008165836,5.947690486907959,10000.0,874.0901634693146,979.012511730194,874.0901634693146,104.78045845031738,0.050161600112915,0.0 -1800,1.087115,6.313601,,,,,,,,,,,,,, -1900,1.19471,6.277874,,,,,,,,,,,,,, -2000,1.3580103,6.5524373,,,,,,,,,,,,,, -2100,1.1504956,6.203087,,,,,,,,,,,,,, -2200,1.5609761,6.2499094,,,,,,,,,,,,,, -2300,1.1550694,6.2338,,,,,,,,,,,,,, -2400,1.3042578,6.160651,,,,,,,,,,,,,, -2500,1.0850233,6.5909066,,,,,,,,,,,,,, -2600,1.3903081,6.167694,,,,,,,,,,,,,, -2700,1.1040297,6.3824224,,,,,,,,,,,,,, -2708,,,0.07275390625,5.425717830657959,0.0682599991559982,5.4650092124938965,50000.0,0.0502000041306018,5.638961315155029,10000.0,1294.3986456394196,1438.6663777828217,1294.3986456394196,144.04840278625488,0.0794043540954589,0.0 -2800,1.1820748,6.1543183,,,,,,,,,,,,,, -2900,1.260597,6.131097,,,,,,,,,,,,,, -3000,0.8041457,6.4438286,,,,,,,,,,,,,, -3100,1.2167356,6.027967,,,,,,,,,,,,,, -3200,0.86570853,6.6371994,,,,,,,,,,,,,, -3300,1.1460798,6.0302,,,,,,,,,,,,,, -3400,1.197944,6.061206,,,,,,,,,,,,,, -3500,1.0631756,5.9942713,,,,,,,,,,,,,, -3600,1.1721869,6.055197,,,,,,,,,,,,,, -3630,,,0.1091406196355819,5.063491344451904,0.0984599962830543,5.12637186050415,50000.0,0.0756999999284744,5.355987071990967,10000.0,1714.635325908661,1898.3418984413147,1714.635325908661,183.41586256027225,0.1033854484558105,0.0 -3700,0.9261065,6.278217,,,,,,,,,,,,,, -3800,0.88390005,6.1177444,,,,,,,,,,,,,, -3900,2.1769443,5.985756,,,,,,,,,,,,,, -4000,1.0503231,6.4436917,,,,,,,,,,,,,, -4100,1.1038028,6.0486207,,,,,,,,,,,,,, -4200,1.069911,5.8232923,,,,,,,,,,,,,, -4300,1.0877109,5.8144736,,,,,,,,,,,,,, -4400,0.87822807,6.2684865,,,,,,,,,,,,,, -4500,1.1116664,5.7596006,,,,,,,,,,,,,, -4552,,,0.1644726544618606,4.604984760284424,0.143119990825653,4.721293449401856,50000.0,0.106500007212162,5.0176777839660645,10000.0,2134.942580461502,2353.803330421448,2134.942580461502,218.49655675888064,0.1298916339874267,0.0 -4600,0.80922574,6.6201262,,,,,,,,,,,,,, -4700,1.1231447,5.814674,,,,,,,,,,,,,, -4800,1.0549357,5.928412,,,,,,,,,,,,,, -4900,0.8754273,6.546876,,,,,,,,,,,,,, -5000,0.9114704,6.35841,,,,,,,,,,,,,, -5100,1.3197992,5.6071568,,,,,,,,,,,,,, -5200,1.1260432,5.788674,,,,,,,,,,,,,, -5300,1.154767,5.5896416,,,,,,,,,,,,,, -5400,1.0770575,6.0156198,,,,,,,,,,,,,, -5473,,,0.1963085830211639,4.357907772064209,0.1792799979448318,4.431841850280762,50000.0,0.1381999999284744,4.75654411315918,10000.0,2555.2755966186523,2808.118916273117,2555.2755966186523,252.4030647277832,0.1591732501983642,0.0 -5500,1.332841,5.6926284,,,,,,,,,,,,,, -5600,0.945159,6.0127115,,,,,,,,,,,,,, -5700,1.0956879,5.7036767,,,,,,,,,,,,,, -5800,1.1351709,5.505778,,,,,,,,,,,,,, -5900,0.90810925,5.7795763,,,,,,,,,,,,,, -6000,1.081264,5.3550467,,,,,,,,,,,,,, -6100,1.0388342,6.5422506,,,,,,,,,,,,,, -6200,0.8380175,6.0163555,,,,,,,,,,,,,, -6300,0.8516576,6.336608,,,,,,,,,,,,,, -6396,,,0.2514062523841858,3.91365122795105,0.2306199967861175,4.019598484039307,50000.0,0.1730000078678131,4.418845176696777,10000.0,2975.428004980088,3264.80242729187,2975.428004980088,288.8565876483917,0.1901607513427734,0.0 -6400,1.0463549,5.3353424,,,,,,,,,,,,,, -6500,1.1356792,6.429258,,,,,,,,,,,,,, -6600,1.170448,5.330953,,,,,,,,,,,,,, -6700,1.087127,5.244077,,,,,,,,,,,,,, -6800,0.72437257,6.3956666,,,,,,,,,,,,,, -6900,1.0137622,5.361326,,,,,,,,,,,,,, -7000,1.0365323,5.2645326,,,,,,,,,,,,,, -7100,1.2397124,5.199209,,,,,,,,,,,,,, -7200,0.94411695,6.1297474,,,,,,,,,,,,,, -7300,1.323659,5.194558,,,,,,,,,,,,,, -7320,,,0.2916015684604645,3.662829160690308,0.2649999856948852,3.806907415390015,50000.0,0.203900009393692,4.221038818359375,10000.0,3395.463133573532,3718.7628943920135,3395.463133573532,322.7090003490448,0.2164924144744873,0.0 -7400,0.79485565,6.3386784,,,,,,,,,,,,,, -7500,1.034121,5.1515417,,,,,,,,,,,,,, -7600,1.2494866,5.0778046,,,,,,,,,,,,,, -7700,1.1574705,5.221729,,,,,,,,,,,,,, -7800,0.9630364,5.1502743,,,,,,,,,,,,,, -7900,0.7295756,5.879191,,,,,,,,,,,,,, -8000,1.1314645,5.1195765,,,,,,,,,,,,,, -8100,1.0183882,5.1796556,,,,,,,,,,,,,, -8200,0.98963714,5.229722,,,,,,,,,,,,,, -8243,,,0.3251757621765136,3.437488555908203,0.3050200045108795,3.53970718383789,50000.0,0.2307000160217285,4.009162425994873,10000.0,3815.831436395645,4174.174175024033,3815.831436395645,357.6776604652405,0.2442131042480468,0.0 -8300,1.0680134,5.0216494,,,,,,,,,,,,,, -8400,0.86807275,6.0642176,,,,,,,,,,,,,, -8500,0.88077176,5.2439065,,,,,,,,,,,,,, -8600,1.1406748,4.9384093,,,,,,,,,,,,,, -8700,0.83089375,5.963119,,,,,,,,,,,,,, -8800,0.9643995,4.884458,,,,,,,,,,,,,, -8900,0.89104694,5.932562,,,,,,,,,,,,,, -9000,0.7126813,6.2953725,,,,,,,,,,,,,, -9100,0.8284871,6.3434505,,,,,,,,,,,,,, -9164,,,0.3589257597923279,3.239076375961304,0.3297999799251556,3.377380609512329,50000.0,0.254800021648407,3.8746755123138414,10000.0,4235.847739696503,4628.525053501129,4235.847739696503,391.93957924842834,0.2702317237854004,0.0 -9200,0.9899707,4.871069,,,,,,,,,,,,,, -9300,0.8738852,5.640869,,,,,,,,,,,,,, -9400,0.99099237,5.104598,,,,,,,,,,,,,, -9500,0.97087336,5.326378,,,,,,,,,,,,,, -9600,1.0195047,4.853391,,,,,,,,,,,,,, -9700,0.755561,5.5388975,,,,,,,,,,,,,, -9800,0.92666274,4.7431607,,,,,,,,,,,,,, -9900,0.8967715,4.9582634,,,,,,,,,,,,,, -10000,0.8026642,5.7984314,,,,,,,,,,,,,, -10086,,,0.3991015553474426,3.000910997390747,0.3657599985599518,3.1630349159240723,50000.0,0.281900018453598,3.6769216060638414,10000.0,4655.996357917786,5087.298045635223,4655.996357917786,430.49109721183777,0.2951619625091553,0.0 -10100,0.918515,4.742156,,,,,,,,,,,,,, -10200,1.0411271,4.930179,,,,,,,,,,,,,, -10300,0.9331661,4.76144,,,,,,,,,,,,,, -10400,0.9245242,4.5965605,,,,,,,,,,,,,, -10500,0.99251086,4.6698914,,,,,,,,,,,,,, -10600,0.6134409,5.776593,,,,,,,,,,,,,, -10700,0.78558487,5.932048,,,,,,,,,,,,,, -10800,1.0558656,4.655591,,,,,,,,,,,,,, -10900,0.8855275,4.9465227,,,,,,,,,,,,,, -11000,0.8713772,4.797397,,,,,,,,,,,,,, -11009,,,0.4095703065395355,2.957759618759156,0.3806999921798706,3.086174726486206,50000.0,0.2939999997615814,3.614280223846436,10000.0,5075.949893951416,5542.012524604797,5075.949893951416,465.17704224586487,0.3227870464324951,0.0 -11100,0.96336275,4.71762,,,,,,,,,,,,,, -11200,0.7357022,5.611048,,,,,,,,,,,,,, -11300,0.92453736,4.5691094,,,,,,,,,,,,,, -11400,0.84587985,5.2968073,,,,,,,,,,,,,, -11500,0.9728796,4.7140636,,,,,,,,,,,,,, -11600,0.9195223,4.6643925,,,,,,,,,,,,,, -11700,0.98094785,4.617242,,,,,,,,,,,,,, -11800,0.8757799,4.553348,,,,,,,,,,,,,, -11900,1.0508312,4.5581985,,,,,,,,,,,,,, -11932,,,0.4536913931369781,2.704256534576416,0.4152999818325043,2.878851890563965,50000.0,0.3217000067234039,3.442434549331665,10000.0,5496.126168727875,5996.636655807495,5496.126168727875,499.54923462867737,0.3511896133422851,0.0 -12000,0.9333936,4.9500976,,,,,,,,,,,,,, -12100,0.9307103,4.5287056,,,,,,,,,,,,,, -12200,0.8964753,4.654054,,,,,,,,,,,,,, -12300,0.67517316,6.102918,,,,,,,,,,,,,, -12400,0.687347,5.375107,,,,,,,,,,,,,, -12500,0.6381921,5.867324,,,,,,,,,,,,,, -12600,0.9790401,4.4440956,,,,,,,,,,,,,, -12700,0.79131544,5.403439,,,,,,,,,,,,,, -12800,0.8698023,4.779932,,,,,,,,,,,,,, -12852,,,0.4732812345027923,2.620152473449707,0.4334799945354461,2.793245553970337,50000.0,0.332500010728836,3.3532426357269287,10000.0,5916.12104177475,6450.153452396393,5916.12104177475,532.9983906745911,0.3771047592163086,0.0 -12900,0.7566316,5.876338,,,,,,,,,,,,,, -13000,0.82721967,5.4147706,,,,,,,,,,,,,, -13100,0.9457384,4.4940486,,,,,,,,,,,,,, -13200,0.630266,6.0885296,,,,,,,,,,,,,, -13300,1.0455534,4.4189835,,,,,,,,,,,,,, -13400,0.9169455,4.399697,,,,,,,,,,,,,, -13500,0.95277005,4.3740597,,,,,,,,,,,,,, -13600,1.087411,4.420727,,,,,,,,,,,,,, -13700,0.9660407,4.4041977,,,,,,,,,,,,,, -13774,,,0.4964648187160492,2.481999397277832,0.4541199803352356,2.691492795944214,50000.0,0.35630002617836,3.2585787773132324,10000.0,6336.528871059418,6906.528246641159,6336.528871059418,568.8881900310516,0.4073190689086914,0.0 -13800,0.67728734,5.749548,,,,,,,,,,,,,, -13900,1.1032205,4.3669233,,,,,,,,,,,,,, -14000,0.59906673,6.0073743,,,,,,,,,,,,,, -14100,0.85602814,4.347744,,,,,,,,,,,,,, -14200,0.9634938,4.3659286,,,,,,,,,,,,,, -14300,0.87226295,4.3673687,,,,,,,,,,,,,, -14400,0.8524234,4.3926144,,,,,,,,,,,,,, -14500,0.9351778,4.318496,,,,,,,,,,,,,, -14600,0.84267825,4.5734487,,,,,,,,,,,,,, -14697,,,0.5178515315055847,2.393428325653076,0.4753799736499786,2.57056212425232,50000.0,0.3745000064373016,3.150657892227173,10000.0,6756.774499177933,7357.662885427475,6756.774499177933,599.7011110782623,0.4356505870819092,0.0 -14700,0.71388745,5.6809735,,,,,,,,,,,,,, -14800,0.7796211,5.617694,,,,,,,,,,,,,, -14900,0.73698634,5.0170946,,,,,,,,,,,,,, -15000,0.70435023,5.1940384,,,,,,,,,,,,,, -15100,0.95529914,4.2806892,,,,,,,,,,,,,, -15200,0.6973598,5.9954805,,,,,,,,,,,,,, -15300,0.9064243,4.2641315,,,,,,,,,,,,,, -15400,0.8571712,4.4650426,,,,,,,,,,,,,, -15500,0.8817059,4.349507,,,,,,,,,,,,,, -15600,0.6982482,5.163825,,,,,,,,,,,,,, -15618,,,0.5367968678474426,2.243785858154297,0.49685999751091,2.430511713027954,50000.0,0.3872000277042389,3.020982027053833,10000.0,7176.8117735385895,7812.835678100586,7176.8117735385895,634.7569868564606,0.4672560691833496,0.0 -15700,0.98601174,4.1880126,,,,,,,,,,,,,, -15800,0.96616656,4.160508,,,,,,,,,,,,,, -15900,0.9049595,4.265627,,,,,,,,,,,,,, -16000,1.0068066,4.243218,,,,,,,,,,,,,, -16100,0.7362326,5.640515,,,,,,,,,,,,,, -16200,0.6265028,5.8974066,,,,,,,,,,,,,, -16300,0.87044656,4.621498,,,,,,,,,,,,,, -16400,0.7727037,5.3708487,,,,,,,,,,,,,, -16500,0.9990496,4.0524178,,,,,,,,,,,,,, -16540,,,0.562207043170929,2.14985466003418,0.507319986820221,2.4012911319732666,50000.0,0.3929000198841095,2.9950408935546875,10000.0,7597.053501844406,8267.926774263382,7597.053501844406,669.5322709083557,0.4950077533721924,0.0 -16600,0.97979957,4.1376176,,,,,,,,,,,,,, -16700,0.94777524,4.244416,,,,,,,,,,,,,, -16800,0.67755425,5.5526595,,,,,,,,,,,,,, -16900,0.8666794,4.08652,,,,,,,,,,,,,, -17000,0.9365709,4.153856,,,,,,,,,,,,,, -17100,0.8642768,4.2456455,,,,,,,,,,,,,, -17200,0.8785513,4.3056145,,,,,,,,,,,,,, -17300,0.79577565,4.610665,,,,,,,,,,,,,, -17400,0.7193329,5.1280327,,,,,,,,,,,,,, -17463,,,0.553515613079071,2.2031924724578857,0.5131999850273132,2.382346868515014,50000.0,0.398900032043457,2.9856181144714355,10000.0,8017.169707298279,8721.486993074417,8017.169707298279,702.9022083282471,0.5227203369140625,0.0 -17500,0.77671266,5.8335977,,,,,,,,,,,,,, -17600,0.9325902,4.2229295,,,,,,,,,,,,,, -17700,0.9171856,4.038625,,,,,,,,,,,,,, -17800,0.8012898,4.279474,,,,,,,,,,,,,, -17900,0.6869376,4.965243,,,,,,,,,,,,,, -18000,0.76774865,4.658698,,,,,,,,,,,,,, -18100,0.76383823,5.1576066,,,,,,,,,,,,,, -18200,0.8279145,4.5891156,,,,,,,,,,,,,, -18300,0.9182772,4.1941986,,,,,,,,,,,,,, -18381,,,0.5694921612739563,2.0806596279144287,0.5267999768257141,2.2643842697143555,50000.0,0.4169000089168548,2.871544122695923,10000.0,8437.620900630951,9175.58203101158,8437.620900630951,736.4718625545502,0.549354076385498,0.0 -18400,0.6877149,5.9227276,,,,,,,,,,,,,, -18500,0.8411238,4.3141246,,,,,,,,,,,,,, -18600,0.9422291,4.3147426,,,,,,,,,,,,,, -18700,0.87886155,4.1143007,,,,,,,,,,,,,, -18800,0.9005705,4.1886888,,,,,,,,,,,,,, -18900,0.9033544,4.2082076,,,,,,,,,,,,,, -19000,0.90570354,3.9900212,,,,,,,,,,,,,, -19100,0.84647053,4.2352624,,,,,,,,,,,,,, -19200,0.7286492,5.614108,,,,,,,,,,,,,, -19298,,,0.5946288704872131,1.9711012840271,0.5412999987602234,2.2097156047821045,50000.0,0.4219000339508056,2.838536739349365,10000.0,8857.677038431168,9631.36013674736,8857.677038431168,772.117954492569,0.5785338878631592,0.0 -19300,0.90616316,4.0626483,,,,,,,,,,,,,, -19400,0.9480623,3.8803754,,,,,,,,,,,,,, -19500,0.93941236,4.095715,,,,,,,,,,,,,, -19600,0.8403751,5.45032,,,,,,,,,,,,,, -19700,0.873022,4.0112567,,,,,,,,,,,,,, -19800,0.90600145,4.1123457,,,,,,,,,,,,,, -19900,0.7887542,5.433602,,,,,,,,,,,,,, -20000,0.78186685,5.7716103,,,,,,,,,,,,,, -20100,1.0163059,4.160527,,,,,,,,,,,,,, -20200,0.8145724,4.944845,,,,,,,,,,,,,, -20217,,,0.5886132717132568,1.966526865959168,0.5514000058174133,2.1377973556518555,50000.0,0.4321000277996063,2.757359027862549,10000.0,9277.779623508452,10086.7133436203,9277.779623508452,807.2917928695679,0.6079757213592529,0.0 -20300,0.80207664,4.426675,,,,,,,,,,,,,, -20400,0.93438876,4.0660343,,,,,,,,,,,,,, -20500,0.8893857,4.4641037,,,,,,,,,,,,,, -20600,0.7305509,5.2878685,,,,,,,,,,,,,, -20700,0.94351494,3.9779186,,,,,,,,,,,,,, -20800,1.0625829,4.0376563,,,,,,,,,,,,,, -20900,0.8439469,4.9539676,,,,,,,,,,,,,, -21000,0.96498066,4.0212235,,,,,,,,,,,,,, -21100,0.8774776,3.863089,,,,,,,,,,,,,, -21138,,,0.6024804711341858,1.935488104820252,0.5561999678611755,2.1444432735443115,50000.0,0.4409000277519226,2.7603559494018555,10000.0,9697.771949529648,10541.31220149994,9697.771949529648,841.8244321346283,0.6349647045135498,0.0 -21200,1.0235306,3.9418397,,,,,,,,,,,,,, -21300,0.8982383,3.933818,,,,,,,,,,,,,, -21400,0.69550776,5.409296,,,,,,,,,,,,,, -21500,0.91622543,3.9834168,,,,,,,,,,,,,, -21600,0.6752481,5.3144336,,,,,,,,,,,,,, -21700,0.9060565,4.0302854,,,,,,,,,,,,,, -21800,0.8955712,4.1284294,,,,,,,,,,,,,, -21900,0.86368823,4.0542426,,,,,,,,,,,,,, -22000,0.91261274,4.1833735,,,,,,,,,,,,,, -22059,,,0.6153905987739563,1.863569259643555,0.5629599690437317,2.098595142364502,50000.0,0.4448000192642212,2.710822343826294,10000.0,10117.83263373375,10995.936970949171,10117.83263373375,876.3112514019012,0.6652498245239258,0.0 -22100,0.9847958,4.009213,,,,,,,,,,,,,, -22200,0.9672073,4.0315895,,,,,,,,,,,,,, -22300,0.8916628,4.251997,,,,,,,,,,,,,, -22400,0.974029,3.9160635,,,,,,,,,,,,,, -22500,0.9230734,3.8389058,,,,,,,,,,,,,, -22600,0.760719,4.818745,,,,,,,,,,,,,, -22700,0.9660525,3.942885,,,,,,,,,,,,,, -22800,0.7946001,5.771125,,,,,,,,,,,,,, -22900,0.7299762,5.3800325,,,,,,,,,,,,,, -22978,,,0.6119140386581421,1.860370397567749,0.568619966506958,2.0590929985046387,50000.0,0.451200008392334,2.6695117950439453,10000.0,10537.896867275238,11448.413235902786,10537.896867275238,908.642866373062,0.6947681903839111,0.0 -23000,0.79566944,4.888535,,,,,,,,,,,,,, -23100,0.8887842,4.268422,,,,,,,,,,,,,, -23200,0.8209897,4.4959383,,,,,,,,,,,,,, -23300,0.9219298,3.968257,,,,,,,,,,,,,, -23400,0.7283256,5.540422,,,,,,,,,,,,,, -23500,0.8965853,4.12897,,,,,,,,,,,,,, -23600,0.9866568,3.786697,,,,,,,,,,,,,, -23700,0.97650796,3.98961,,,,,,,,,,,,,, -23800,0.9074548,4.1119013,,,,,,,,,,,,,, -23900,0.8608662,4.5504546,,,,,,,,,,,,,, -23901,,,0.619140625,1.842691659927368,0.5765399932861328,2.036699056625366,50000.0,0.4563000202178955,2.6497387886047363,10000.0,10958.168203353882,11903.424647808077,10958.168203353882,943.3075177669524,0.7228202819824219,0.0 -24000,0.9764653,4.0506735,,,,,,,,,,,,,, -24100,0.9716837,3.9329991,,,,,,,,,,,,,, -24200,0.89980274,3.83237,,,,,,,,,,,,,, -24300,0.8539586,4.5110598,,,,,,,,,,,,,, -24400,0.82929385,4.738932,,,,,,,,,,,,,, -24500,0.8323049,4.251837,,,,,,,,,,,,,, -24600,0.87512004,4.332034,,,,,,,,,,,,,, -24700,0.93253464,3.968293,,,,,,,,,,,,,, -24800,0.9484761,3.7486908,,,,,,,,,,,,,, -24825,,,0.6344335675239563,1.8031526803970337,0.5821200013160706,2.0396535396575928,50000.0,0.455700010061264,2.65883469581604,10000.0,11378.107141256332,12354.893264770508,11378.107141256332,974.7620024681092,0.7505502700805664,0.0 -24900,0.7528194,5.0274506,,,,,,,,,,,,,, -25000,0.9423441,3.9051797,,,,,,,,,,,,,, -25100,0.9036048,5.2819395,,,,,,,,,,,,,, -25200,1.0802499,3.969688,,,,,,,,,,,,,, -25300,0.8473127,4.2868237,,,,,,,,,,,,,, -25400,1.009407,3.8138742,,,,,,,,,,,,,, -25500,0.9702029,3.9693794,,,,,,,,,,,,,, -25600,0.94910294,3.9372387,,,,,,,,,,,,,, -25700,0.8933313,4.6608267,,,,,,,,,,,,,, -25745,,,0.6540429592132568,1.6929877996444702,0.5794999599456787,2.018251895904541,50000.0,0.4617000222206116,2.632565975189209,10000.0,11798.067595005035,12809.31416130066,11798.067595005035,1009.1432175636292,0.782789945602417,0.0 -25800,0.85159934,5.798152,,,,,,,,,,,,,, -25900,0.9933637,3.936332,,,,,,,,,,,,,, -26000,0.8908841,4.2357183,,,,,,,,,,,,,, -26100,0.8528444,4.143178,,,,,,,,,,,,,, -26200,0.99899244,3.9229622,,,,,,,,,,,,,, -26300,0.8920692,3.8639116,,,,,,,,,,,,,, -26400,0.85753053,5.3287063,,,,,,,,,,,,,, -26500,0.76649576,5.637409,,,,,,,,,,,,,, -26600,0.9975515,3.9141707,,,,,,,,,,,,,, -26668,,,0.6333398222923279,1.7327735424041748,0.5909799933433533,1.9266051054000848,50000.0,0.4686000347137451,2.565860748291016,10000.0,12218.119409561155,13260.34297466278,12218.119409561155,1040.0437922477722,0.8123970031738281,0.0 -26700,0.9785491,3.8011346,,,,,,,,,,,,,, -26800,0.76679575,5.694995,,,,,,,,,,,,,, -26900,0.9653517,3.83884,,,,,,,,,,,,,, -27000,1.0963022,3.8440843,,,,,,,,,,,,,, -27100,0.9483148,3.7708642,,,,,,,,,,,,,, -27200,0.9271001,4.212255,,,,,,,,,,,,,, -27300,0.78323466,4.581154,,,,,,,,,,,,,, -27400,0.99004924,3.775039,,,,,,,,,,,,,, -27500,1.0124508,3.8790457,,,,,,,,,,,,,, -27589,,,0.6485351324081421,1.7430307865142822,0.5958600044250488,1.972648024559021,50000.0,0.4768000245094299,2.580284595489502,10000.0,12638.257354021072,13714.76744222641,12638.257354021072,1074.250785589218,0.8448200225830078,0.0 -27600,0.74369437,5.336982,,,,,,,,,,,,,, -27700,0.9692302,3.7625544,,,,,,,,,,,,,, -27800,0.9478612,3.8595958,,,,,,,,,,,,,, -27900,0.93131,3.8621855,,,,,,,,,,,,,, -28000,0.96134883,4.0497684,,,,,,,,,,,,,, -28100,0.9225636,3.8328743,,,,,,,,,,,,,, -28200,0.8914054,3.991331,,,,,,,,,,,,,, -28300,0.9365382,3.965746,,,,,,,,,,,,,, -28400,0.7921332,5.095827,,,,,,,,,,,,,, -28500,1.0141084,3.8008428,,,,,,,,,,,,,, -28513,,,0.6622265577316284,1.6519564390182495,0.5966399908065796,1.929701805114746,50000.0,0.4702000319957733,2.5671634674072266,10000.0,13058.69831609726,14165.95332980156,13058.69831609726,1104.9193880558014,0.8732032775878906,0.0 -28600,0.85261554,4.497584,,,,,,,,,,,,,, -28700,1.0751613,3.7648253,,,,,,,,,,,,,, -28800,0.81383264,4.983162,,,,,,,,,,,,,, -28900,1.0180076,3.8148627,,,,,,,,,,,,,, -29000,0.85067624,5.0235596,,,,,,,,,,,,,, -29100,0.998634,3.8850806,,,,,,,,,,,,,, -29200,0.8127319,5.5714474,,,,,,,,,,,,,, -29300,1.0132699,3.7853506,,,,,,,,,,,,,, -29400,0.82245564,5.700489,,,,,,,,,,,,,, -29436,,,0.6512890458106995,1.7113102674484253,0.6077600121498108,1.9051331281661987,50000.0,0.480400025844574,2.534058094024658,10000.0,13478.946665525436,14620.107602596285,13478.946665525436,1138.7440557479858,0.9072282314300536,0.0 -29500,0.95115507,3.7902095,,,,,,,,,,,,,, -29600,1.1291822,3.8156114,,,,,,,,,,,,,, -29700,1.0100733,3.745331,,,,,,,,,,,,,, -29800,0.99409205,3.783735,,,,,,,,,,,,,, -29900,0.82356113,5.407349,,,,,,,,,,,,,, -30000,0.93202645,3.78617,,,,,,,,,,,,,, -30100,1.0388532,3.7210245,,,,,,,,,,,,,, -30200,0.92984164,3.6915402,,,,,,,,,,,,,, -30300,1.0862247,3.810345,,,,,,,,,,,,,, -30358,,,0.6586328148841858,1.6941852569580078,0.6078000068664551,1.9113799333572388,50000.0,0.4905000329017639,2.532050609588623,10000.0,13898.911831855774,15075.097905397415,13898.911831855774,1173.688907146454,0.9404423236846924,0.0 -30400,0.9896845,3.714048,,,,,,,,,,,,,, -30500,0.8665404,4.299871,,,,,,,,,,,,,, -30600,1.0178163,3.691092,,,,,,,,,,,,,, -30700,1.0051795,3.8952456,,,,,,,,,,,,,, -30800,1.0553894,3.7764044,,,,,,,,,,,,,, -30900,0.9736281,3.8332675,,,,,,,,,,,,,, -31000,0.80778265,5.340073,,,,,,,,,,,,,, -31100,0.86088526,5.3963923,,,,,,,,,,,,,, -31200,0.85007536,5.073598,,,,,,,,,,,,,, -31283,,,0.6657617092132568,1.6142001152038574,0.6137599945068359,1.8565632104873653,50000.0,0.4874000251293182,2.487379312515259,10000.0,14319.18147277832,15527.85389828682,14319.18147277832,1206.0904257297516,0.9788103103637696,0.0 -31300,1.093036,3.7537863,,,,,,,,,,,,,, -31400,0.8570768,4.635314,,,,,,,,,,,,,, -31500,1.0962592,3.7005265,,,,,,,,,,,,,, -31600,0.9359064,3.7180457,,,,,,,,,,,,,, -31700,0.8071053,5.5628667,,,,,,,,,,,,,, -31800,1.0222193,3.7248447,,,,,,,,,,,,,, -31900,0.8525091,5.416268,,,,,,,,,,,,,, -32000,0.8339154,4.304206,,,,,,,,,,,,,, -32100,1.0359353,3.78256,,,,,,,,,,,,,, -32200,0.9700317,3.839234,,,,,,,,,,,,,, -32206,,,0.6634374856948853,1.6532877683639526,0.6112799644470215,1.866252422332764,50000.0,0.4907000362873077,2.5017411708831787,10000.0,14739.35320687294,15982.26224064827,14739.35320687294,1240.2513403892517,1.008094310760498,0.0 -32300,0.7955597,4.746283,,,,,,,,,,,,,, -32400,0.94325966,3.9340072,,,,,,,,,,,,,, -32500,0.87781906,4.2748384,,,,,,,,,,,,,, -32600,1.0473194,3.8133352,,,,,,,,,,,,,, -32700,0.905105,4.2662635,,,,,,,,,,,,,, -32800,0.88946563,4.744584,,,,,,,,,,,,,, -32900,0.8584528,4.665579,,,,,,,,,,,,,, -33000,0.95524603,3.8132932,,,,,,,,,,,,,, -33100,0.9164209,4.233294,,,,,,,,,,,,,, -33130,,,0.6699804663658142,1.617461323738098,0.6247999668121338,1.8267481327056885,50000.0,0.5005000233650208,2.448014974594116,10000.0,15159.512373924255,16436.065540075302,15159.512373924255,1273.814103603363,1.0417673587799072,0.0 -33200,1.0616542,3.7538223,,,,,,,,,,,,,, -33300,0.86843944,5.3789415,,,,,,,,,,,,,, -33400,0.91852754,4.263383,,,,,,,,,,,,,, -33500,1.0752637,3.6938596,,,,,,,,,,,,,, -33600,0.95050627,3.7393334,,,,,,,,,,,,,, -33700,0.98517644,3.655054,,,,,,,,,,,,,, -33800,1.0710534,3.792612,,,,,,,,,,,,,, -33900,0.9543319,4.1038485,,,,,,,,,,,,,, -34000,1.0035291,3.6217747,,,,,,,,,,,,,, -34053,,,0.6822265386581421,1.577724575996399,0.6250799894332886,1.833191156387329,50000.0,0.4978000223636627,2.4420933723449707,10000.0,15579.846864700316,16890.66156053543,15579.846864700316,1307.9982221126556,1.0729899406433103,0.0 -34100,0.85457146,4.3398356,,,,,,,,,,,,,, -34200,0.86836153,5.2996025,,,,,,,,,,,,,, -34300,0.9287989,4.000804,,,,,,,,,,,,,, -34400,0.82957155,4.1848936,,,,,,,,,,,,,, -34500,1.00416,3.7094593,,,,,,,,,,,,,, -34600,0.8473501,4.60102,,,,,,,,,,,,,, -34700,0.9493951,3.6972272,,,,,,,,,,,,,, -34800,0.8231222,5.0629654,,,,,,,,,,,,,, -34900,1.0028569,3.7320752,,,,,,,,,,,,,, -34975,,,0.7008007764816284,1.4769867658615112,0.6222400069236755,1.811497449874878,50000.0,0.4970000088214874,2.4435598850250244,10000.0,15999.829869508743,17347.502505779266,15999.829869508743,1344.7656226158142,1.116785764694214,0.0 -35000,0.9399505,4.0374,,,,,,,,,,,,,, -35100,0.99862367,3.7426724,,,,,,,,,,,,,, -35200,1.077277,3.8681211,,,,,,,,,,,,,, -35300,1.0131813,3.6882768,,,,,,,,,,,,,, -35400,0.9188622,5.1595454,,,,,,,,,,,,,, -35500,1.0556381,3.7696834,,,,,,,,,,,,,, -35600,0.8838116,4.4699316,,,,,,,,,,,,,, -35700,1.0149729,3.6671076,,,,,,,,,,,,,, -35800,1.0377194,3.6865556,,,,,,,,,,,,,, -35898,,,0.6775780916213989,1.5361953973770142,0.6276800036430359,1.7578741312026978,50000.0,0.5062000155448914,2.371567964553833,10000.0,16420.072093486786,17797.47596859932,16420.072093486786,1374.418939828873,1.147679090499878,0.0 -35900,0.97164387,3.6519623,,,,,,,,,,,,,, -36000,0.97275543,3.6138084,,,,,,,,,,,,,, -36100,1.0323437,3.8750699,,,,,,,,,,,,,, -36200,0.8711677,5.219004,,,,,,,,,,,,,, -36300,1.0436406,3.7457113,,,,,,,,,,,,,, -36400,0.9362679,5.2246666,,,,,,,,,,,,,, -36500,1.127458,3.7691443,,,,,,,,,,,,,, -36600,0.95256495,3.6341968,,,,,,,,,,,,,, -36700,0.9886393,3.6239274,,,,,,,,,,,,,, -36800,1.0167223,3.578269,,,,,,,,,,,,,, -36818,,,0.6763281226158142,1.6137640476226809,0.6283800005912781,1.842094898223877,50000.0,0.5077000260353088,2.443421125411988,10000.0,16840.190497398376,18252.90421462059,16840.190497398376,1409.6424214839935,1.1869757175445557,0.0 -36900,0.9862275,3.967841,,,,,,,,,,,,,, -37000,1.0434364,3.7093544,,,,,,,,,,,,,, -37100,0.9077364,3.8356774,,,,,,,,,,,,,, -37200,1.0389268,3.7785192,,,,,,,,,,,,,, -37300,0.9609769,3.689135,,,,,,,,,,,,,, -37400,1.0735815,3.7026248,,,,,,,,,,,,,, -37500,0.9263669,3.9896946,,,,,,,,,,,,,, -37600,0.8797741,5.5419016,,,,,,,,,,,,,, -37700,0.9629912,4.006117,,,,,,,,,,,,,, -37740,,,0.6975976228713989,1.450102925300598,0.631060004234314,1.7535679340362549,50000.0,0.5083000063896179,2.3802409172058105,10000.0,17260.28178691864,18705.626772880554,17260.28178691864,1442.1937334537506,1.2198548316955566,0.0 -37800,1.0007144,3.6728222,,,,,,,,,,,,,, -37900,1.0474317,3.7030044,,,,,,,,,,,,,, -38000,1.0232195,3.9328308,,,,,,,,,,,,,, -38100,0.9506374,3.6084747,,,,,,,,,,,,,, -38200,0.9556934,4.1852746,,,,,,,,,,,,,, -38300,1.0531045,3.6574023,,,,,,,,,,,,,, -38400,0.9505411,4.1466866,,,,,,,,,,,,,, -38500,0.83479685,4.611985,,,,,,,,,,,,,, -38600,0.95045835,3.6467605,,,,,,,,,,,,,, -38663,,,0.6893359422683716,1.514500379562378,0.6391400098800659,1.7250676155090332,50000.0,0.5202000141143799,2.346991777420044,10000.0,17680.43620157242,19159.576062202454,17680.43620157242,1475.909290790558,1.252131700515747,0.0 -38700,0.8705801,4.3600206,,,,,,,,,,,,,, -38800,1.0117749,3.5636535,,,,,,,,,,,,,, -38900,1.0456696,3.6079388,,,,,,,,,,,,,, -39000,0.8990947,4.424861,,,,,,,,,,,,,, -39100,1.0685384,3.6102886,,,,,,,,,,,,,, -39200,0.9755602,3.6277452,,,,,,,,,,,,,, -39300,1.0649904,3.5594687,,,,,,,,,,,,,, -39400,0.860042,4.6882467,,,,,,,,,,,,,, -39500,1.0343453,5.3076835,,,,,,,,,,,,,, -39584,,,0.6898046731948853,1.4997754096984863,0.6385999917984009,1.728279948234558,50000.0,0.5121999979019165,2.3584721088409424,10000.0,18100.48945403099,19613.12621331215,18100.48945403099,1509.32852268219,1.2823808193206787,0.0 -39600,0.9338739,3.7776403,,,,,,,,,,,,,, -39700,0.9689617,3.860242,,,,,,,,,,,,,, -39800,0.9731024,3.5994706,,,,,,,,,,,,,, -39900,0.89007753,5.2480955,,,,,,,,,,,,,, -40000,1.1284597,4.149734,,,,,,,,,,,,,, -40100,1.0178909,3.6518967,,,,,,,,,,,,,, -40200,1.0314522,4.0220757,,,,,,,,,,,,,, -40300,1.068289,3.6486073,,,,,,,,,,,,,, -40400,1.0182652,3.7844195,,,,,,,,,,,,,, -40500,0.92528576,5.2152224,,,,,,,,,,,,,, -40507,,,0.70068359375,1.4647341966629028,0.6398599743843079,1.738932967185974,50000.0,0.5217000246047974,2.341832637786865,10000.0,18520.47150492668,20064.541659355164,18520.47150492668,1540.6759810447693,1.316420316696167,0.0 -40600,0.98973805,3.7175138,,,,,,,,,,,,,, -40700,0.98973656,3.7055721,,,,,,,,,,,,,, -40800,1.1896367,3.7598934,,,,,,,,,,,,,, -40900,0.9054351,4.7014456,,,,,,,,,,,,,, -41000,1.0048217,3.6257772,,,,,,,,,,,,,, -41100,1.0320526,3.5699596,,,,,,,,,,,,,, -41200,0.949991,4.4538765,,,,,,,,,,,,,, -41300,1.1413031,3.6547308,,,,,,,,,,,,,, -41400,1.0977077,3.6483219,,,,,,,,,,,,,, -41427,,,0.6917382478713989,1.4947209358215332,0.6426999568939209,1.7017552852630615,50000.0,0.5194000005722046,2.328978538513184,10000.0,18940.805812358856,20520.362723588943,18940.805812358856,1576.082843542099,1.3496661186218262,0.0 -41500,1.0104401,3.5954978,,,,,,,,,,,,,, -41600,1.0560042,3.6478763,,,,,,,,,,,,,, -41700,1.0149286,5.2083664,,,,,,,,,,,,,, -41800,1.0490488,3.6107998,,,,,,,,,,,,,, -41900,0.9140059,4.8422303,,,,,,,,,,,,,, -42000,0.8507957,4.517227,,,,,,,,,,,,,, -42100,0.8814625,3.883977,,,,,,,,,,,,,, -42200,1.0032883,3.7246654,,,,,,,,,,,,,, -42300,1.0149239,3.9040077,,,,,,,,,,,,,, -42353,,,0.70068359375,1.465484619140625,0.6469199657440186,1.71008563041687,50000.0,0.5272000432014465,2.3286025524139404,10000.0,19361.10916209221,20972.88152360916,19361.10916209221,1608.2144901752472,1.3862807750701904,0.0 -42400,0.9968195,3.92545,,,,,,,,,,,,,, -42500,0.9870267,4.2775683,,,,,,,,,,,,,, -42600,1.0166979,3.525135,,,,,,,,,,,,,, -42700,1.0112063,3.8380117,,,,,,,,,,,,,, -42800,0.9925763,3.6673217,,,,,,,,,,,,,, -42900,1.063416,3.5361857,,,,,,,,,,,,,, -43000,1.0210942,3.6732564,,,,,,,,,,,,,, -43100,1.0483624,3.5430784,,,,,,,,,,,,,, -43200,1.0493045,3.5799716,,,,,,,,,,,,,, -43275,,,0.7005859017372131,1.4444526433944702,0.6436799764633179,1.7053855657577517,50000.0,0.5193000435829163,2.333106279373169,10000.0,19781.25616669655,21427.411111593246,19781.25616669655,1642.51211977005,1.4228754043579102,0.0 -43300,1.1695586,3.6976366,,,,,,,,,,,,,, -43400,0.8958577,4.254368,,,,,,,,,,,,,, -43500,1.0327008,3.5488918,,,,,,,,,,,,,, -43600,0.9379797,5.320066,,,,,,,,,,,,,, -43700,1.0347524,3.7131262,,,,,,,,,,,,,, -43800,1.0543219,3.6548421,,,,,,,,,,,,,, -43900,0.94316316,5.419023,,,,,,,,,,,,,, -44000,0.99190676,3.5961907,,,,,,,,,,,,,, -44100,1.0127593,3.5980577,,,,,,,,,,,,,, -44199,,,0.7070116996765137,1.4227125644683838,0.6459999680519104,1.6852363348007202,50000.0,0.522599995136261,2.296396255493164,10000.0,20201.509860515594,21880.785943746567,20201.509860515594,1675.5485351085665,1.4599545001983645,0.0 -44200,0.960403,3.509264,,,,,,,,,,,,,, -44300,0.9050046,4.5943584,,,,,,,,,,,,,, -44400,0.9579112,4.475122,,,,,,,,,,,,,, -44500,1.1295244,3.5979855,,,,,,,,,,,,,, -44600,1.0962604,3.5814023,,,,,,,,,,,,,, -44700,0.935013,4.520934,,,,,,,,,,,,,, -44800,1.1189668,3.5630155,,,,,,,,,,,,,, -44900,1.033481,3.740236,,,,,,,,,,,,,, -45000,1.0132776,5.342588,,,,,,,,,,,,,, -45100,1.080235,3.5973425,,,,,,,,,,,,,, -45121,,,0.7047460675239563,1.4595746994018557,0.6520599722862244,1.6880558729171753,50000.0,0.5268000364303589,2.3047330379486084,10000.0,20621.46268749237,22336.67161679268,20621.46268749237,1711.402881860733,1.4905712604522705,0.0 -45200,0.954243,5.396023,,,,,,,,,,,,,, -45300,1.0040611,3.8461857,,,,,,,,,,,,,, -45400,1.006126,3.6419277,,,,,,,,,,,,,, -45500,0.9496719,4.3318176,,,,,,,,,,,,,, -45600,0.9503613,4.2046394,,,,,,,,,,,,,, -45700,1.0103581,3.766582,,,,,,,,,,,,,, -45800,1.0815006,3.5478399,,,,,,,,,,,,,, -45900,1.0782769,3.6435037,,,,,,,,,,,,,, -46000,0.99091685,3.6430092,,,,,,,,,,,,,, -46045,,,0.703808605670929,1.4133350849151611,0.649899959564209,1.6546680927276611,50000.0,0.5236999988555908,2.286877155303955,10000.0,21041.53035831452,22787.82704544068,21041.53035831452,1742.409699678421,1.524674892425537,0.0 -46100,1.1261573,3.5029116,,,,,,,,,,,,,, -46200,1.0556768,3.8300967,,,,,,,,,,,,,, -46300,1.0135901,3.6387193,,,,,,,,,,,,,, -46400,0.89721316,5.2796907,,,,,,,,,,,,,, -46500,1.1752369,3.6291077,,,,,,,,,,,,,, -46600,1.0520265,3.6000845,,,,,,,,,,,,,, -46700,1.0419983,3.5174642,,,,,,,,,,,,,, -46800,0.9159422,4.9679127,,,,,,,,,,,,,, -46900,1.0113248,4.181155,,,,,,,,,,,,,, -46969,,,0.7266796827316284,1.295320987701416,0.6565399765968323,1.61535382270813,50000.0,0.5320000052452087,2.2345681190490723,10000.0,21461.85862827301,23242.50052928925,21461.85862827301,1776.6765806674955,1.556483030319214,0.0 -47000,0.948594,4.95636,,,,,,,,,,,,,, -47100,1.0871394,3.6963809,,,,,,,,,,,,,, -47200,1.0312814,5.2372384,,,,,,,,,,,,,, -47300,0.8650553,4.878848,,,,,,,,,,,,,, -47400,1.1117449,3.611358,,,,,,,,,,,,,, -47500,1.2051454,5.402504,,,,,,,,,,,,,, -47600,0.990733,3.8600779,,,,,,,,,,,,,, -47700,0.9983785,5.2204905,,,,,,,,,,,,,, -47800,1.0435805,3.6438222,,,,,,,,,,,,,, -47891,,,0.7119921445846558,1.4000214338302612,0.6568399667739868,1.6305649280548096,50000.0,0.5309000015258789,2.2502574920654297,10000.0,21882.07282042504,23697.036667346954,21882.07282042504,1810.9130997657776,1.5944783687591553,0.0 -47900,1.0273136,3.6992872,,,,,,,,,,,,,, -48000,0.90664446,5.291619,,,,,,,,,,,,,, -48100,0.95433766,4.190644,,,,,,,,,,,,,, -48200,0.991843,3.7917805,,,,,,,,,,,,,, -48300,0.90710336,4.872183,,,,,,,,,,,,,, -48400,1.0333031,3.6796627,,,,,,,,,,,,,, -48500,1.0089433,3.5123632,,,,,,,,,,,,,, -48600,1.0377405,3.5486586,,,,,,,,,,,,,, -48700,1.0368224,3.556633,,,,,,,,,,,,,, -48800,1.0545028,5.1107464,,,,,,,,,,,,,, -48814,,,0.7128710746765137,1.3655827045440674,0.6576200127601624,1.6067535877227783,50000.0,0.5335000157356262,2.232346534729004,10000.0,22302.034429311752,24147.68631052971,22302.034429311752,1841.5184531211853,1.6299116611480713,0.0 -48900,1.0477997,3.589296,,,,,,,,,,,,,, -49000,1.0096493,3.5000074,,,,,,,,,,,,,, -49100,1.082225,3.5877638,,,,,,,,,,,,,, -49200,0.9538394,3.9722643,,,,,,,,,,,,,, -49300,0.91819394,4.9507313,,,,,,,,,,,,,, -49400,1.1710178,3.563245,,,,,,,,,,,,,, -49500,1.0684752,3.6330807,,,,,,,,,,,,,, -49600,1.1074783,3.6050313,,,,,,,,,,,,,, -49700,1.0799911,3.598531,,,,,,,,,,,,,, -49734,,,0.728808581829071,1.3120293617248535,0.6609199643135071,1.6072094440460205,50000.0,0.5344000458717346,2.236469984054565,10000.0,22722.3923227787,24601.1224694252,22722.3923227787,1874.509628534317,1.6707801818847656,0.0 -49800,1.0904317,3.5040157,,,,,,,,,,,,,, -49900,1.159978,3.5555143,,,,,,,,,,,,,, -50000,0.98131055,3.4924724,,,,,,,,,,,,,, -50100,0.886716,4.7957506,,,,,,,,,,,,,, -50200,1.0522394,3.4715571,,,,,,,,,,,,,, -50300,1.0601058,3.5500944,,,,,,,,,,,,,, -50400,1.1100147,3.54055,,,,,,,,,,,,,, -50500,1.0096078,3.7321236,,,,,,,,,,,,,, -50600,1.0589708,3.6977952,,,,,,,,,,,,,, -50652,,,0.7116601467132568,1.3645031452178955,0.6611599922180176,1.5872159004211426,50000.0,0.5375000238418579,2.2071101665496826,10000.0,23142.32176733017,25055.637226343155,23142.32176733017,1909.011991024017,1.70641827583313,0.0 -50700,1.0379533,3.7617614,,,,,,,,,,,,,, -50800,1.1619054,3.5826097,,,,,,,,,,,,,, -50900,1.0921956,3.494949,,,,,,,,,,,,,, -51000,1.081721,5.3435373,,,,,,,,,,,,,, -51100,1.0797356,3.5084584,,,,,,,,,,,,,, -51200,1.0155803,3.9488666,,,,,,,,,,,,,, -51300,1.1140802,3.522613,,,,,,,,,,,,,, -51400,1.1674455,3.5528197,,,,,,,,,,,,,, -51500,1.0677793,3.5298128,,,,,,,,,,,,,, -51574,,,0.717968761920929,1.3651286363601685,0.6620799899101257,1.6173219680786133,50000.0,0.5333000421524048,2.240694761276245,10000.0,23562.47232246399,25507.66274857521,23562.47232246399,1940.8010630607605,1.7460317611694336,0.0 -51600,0.9141488,4.150055,,,,,,,,,,,,,, -51700,1.0270207,3.5243015,,,,,,,,,,,,,, -51800,1.0701185,3.6099024,,,,,,,,,,,,,, -51900,1.0291213,3.436224,,,,,,,,,,,,,, -52000,1.0646505,3.5296068,,,,,,,,,,,,,, -52100,0.98427665,4.9212523,,,,,,,,,,,,,, -52200,1.0515513,3.5603583,,,,,,,,,,,,,, -52300,0.9165573,4.049132,,,,,,,,,,,,,, -52400,1.0431191,3.4875057,,,,,,,,,,,,,, -52496,,,0.7241796851158142,1.34527850151062,0.6620000004768372,1.6125682592391968,50000.0,0.5406000018119812,2.2341880798339844,10000.0,23982.808941841125,25962.668434381485,23982.808941841125,1975.382539987564,1.786369800567627,0.0 -52500,1.0170976,3.4714577,,,,,,,,,,,,,, -52600,1.1063135,3.4715273,,,,,,,,,,,,,, -52700,1.1181239,3.5006044,,,,,,,,,,,,,, -52800,1.0108185,4.066586,,,,,,,,,,,,,, -52900,1.0446337,3.4870183,,,,,,,,,,,,,, -53000,1.0231314,3.5573516,,,,,,,,,,,,,, -53100,1.2019101,3.4654717,,,,,,,,,,,,,, -53200,1.0593346,3.5921452,,,,,,,,,,,,,, -53300,0.99886864,5.3269825,,,,,,,,,,,,,, -53400,0.99655414,3.5477502,,,,,,,,,,,,,, -53418,,,0.7199804782867432,1.3372256755828855,0.6657999753952026,1.5774351358413696,50000.0,0.5430999994277954,2.2015929222106934,10000.0,24402.90833449364,26413.86163020134,24402.90833449364,2006.3908026218407,1.824455976486206,0.0 -53500,0.9961452,3.89226,,,,,,,,,,,,,, -53600,0.9929855,4.8087587,,,,,,,,,,,,,, -53700,1.1133957,3.4401891,,,,,,,,,,,,,, -53800,0.9973292,3.4788175,,,,,,,,,,,,,, -53900,1.0581026,3.626694,,,,,,,,,,,,,, -54000,1.0385652,4.9777207,,,,,,,,,,,,,, -54100,0.9668258,4.983318,,,,,,,,,,,,,, -54200,1.0717497,3.5554862,,,,,,,,,,,,,, -54300,1.0190812,5.284568,,,,,,,,,,,,,, -54339,,,0.7254882454872131,1.335657000541687,0.670699954032898,1.579382300376892,50000.0,0.5449000000953674,2.186253309249878,10000.0,24823.192568540573,26867.0536248684,24823.192568540573,2039.2105541229248,1.866105318069458,0.0 -54400,1.1161064,3.4908266,,,,,,,,,,,,,, -54500,1.0818856,3.535078,,,,,,,,,,,,,, -54600,1.0581262,3.4106326,,,,,,,,,,,,,, -54700,1.0241911,4.2940483,,,,,,,,,,,,,, -54800,1.0146993,5.370873,,,,,,,,,,,,,, -54900,1.1028925,3.5850725,,,,,,,,,,,,,, -55000,1.0815244,3.4428415,,,,,,,,,,,,,, -55100,1.0710567,4.1617174,,,,,,,,,,,,,, -55200,1.1119317,4.0133185,,,,,,,,,,,,,, -55261,,,0.7294921875,1.3043168783187866,0.667419970035553,1.5707368850708008,50000.0,0.5430000424385071,2.1910340785980225,10000.0,25243.527433633804,27318.778652668,25243.527433633804,2070.5169591903687,1.901301383972168,0.0 -55300,1.0644096,5.3123,,,,,,,,,,,,,, -55400,1.1225507,3.5212219,,,,,,,,,,,,,, -55500,1.0511036,4.025168,,,,,,,,,,,,,, -55600,1.1882522,3.4937484,,,,,,,,,,,,,, -55700,1.0401908,3.774532,,,,,,,,,,,,,, -55800,1.0326762,5.092415,,,,,,,,,,,,,, -55900,1.0766898,3.4802773,,,,,,,,,,,,,, -56000,0.9871223,4.7903185,,,,,,,,,,,,,, -56100,1.0569239,3.5712185,,,,,,,,,,,,,, -56182,,,0.7528125047683716,1.1971156597137451,0.6747999787330627,1.5370960235595703,50000.0,0.5533000230789185,2.1446516513824463,10000.0,25663.721523284912,27770.72711038589,25663.721523284912,2102.191232919693,1.934786796569824,0.0 -56200,1.0632509,3.3995929,,,,,,,,,,,,,, -56300,1.0952153,3.504774,,,,,,,,,,,,,, -56400,1.058931,3.4703932,,,,,,,,,,,,,, -56500,1.022453,5.114988,,,,,,,,,,,,,, -56600,1.1126977,3.7753825,,,,,,,,,,,,,, -56700,1.1032468,3.4943042,,,,,,,,,,,,,, -56800,1.0019249,3.6538596,,,,,,,,,,,,,, -56900,0.9529258,4.536788,,,,,,,,,,,,,, -57000,1.0171888,4.527849,,,,,,,,,,,,,, -57100,1.0281506,3.8289862,,,,,,,,,,,,,, -57106,,,0.72314453125,1.3708430528640747,0.6675199866294861,1.6150413751602173,50000.0,0.546500027179718,2.2251136302948,10000.0,26084.066102027893,28225.778723955154,26084.066102027893,2136.8173022270203,1.9684455394744875,0.0 -57200,0.9999834,4.9746175,,,,,,,,,,,,,, -57300,1.0405672,3.8670137,,,,,,,,,,,,,, -57400,1.1726336,3.6657963,,,,,,,,,,,,,, -57500,1.0863703,3.6425967,,,,,,,,,,,,,, -57600,0.9753714,4.1682587,,,,,,,,,,,,,, -57700,0.955406,4.4515905,,,,,,,,,,,,,, -57800,1.121794,3.5205443,,,,,,,,,,,,,, -57900,1.0499935,3.5821385,,,,,,,,,,,,,, -58000,1.0556324,4.529348,,,,,,,,,,,,,, -58026,,,0.7324023246765137,1.3181718587875366,0.6703199744224548,1.5841935873031616,50000.0,0.5457000136375427,2.212608337402344,10000.0,26504.048363685608,28678.30360651016,26504.048363685608,2169.2758326530457,2.005272388458252,0.0 -58100,1.1750231,3.5405703,,,,,,,,,,,,,, -58200,1.0245992,4.830801,,,,,,,,,,,,,, -58300,0.9539623,4.0997934,,,,,,,,,,,,,, -58400,1.2120208,3.6145272,,,,,,,,,,,,,, -58500,1.0814699,3.7809439,,,,,,,,,,,,,, -58600,1.0355123,4.6340833,,,,,,,,,,,,,, -58700,1.1880375,3.472463,,,,,,,,,,,,,, -58800,0.978973,4.4864383,,,,,,,,,,,,,, -58900,1.0901775,3.456285,,,,,,,,,,,,,, -58948,,,0.7443945407867432,1.2518020868301392,0.6735599637031555,1.557487726211548,50000.0,0.5502000451087952,2.179090738296509,10000.0,26924.052193164825,29129.454854011536,26924.052193164825,2200.3360035419464,2.045383214950561,0.0 -59000,1.0495251,3.6688428,,,,,,,,,,,,,, -59100,1.0770321,3.5464292,,,,,,,,,,,,,, -59200,1.1081125,3.4712424,,,,,,,,,,,,,, -59300,1.1703354,3.5363994,,,,,,,,,,,,,, -59400,1.0823392,3.6469185,,,,,,,,,,,,,, -59500,1.1155714,3.815182,,,,,,,,,,,,,, -59600,1.1172695,4.3005238,,,,,,,,,,,,,, -59700,1.1113358,3.4485765,,,,,,,,,,,,,, -59800,0.995757,5.2469006,,,,,,,,,,,,,, -59869,,,0.7306249737739563,1.318999409675598,0.6790599822998047,1.5507181882858276,50000.0,0.5525000095367432,2.166768789291382,10000.0,27344.28582406044,29584.562819719315,27344.28582406044,2235.119631052017,2.088909387588501,0.0 -59900,1.1781359,3.4573765,,,,,,,,,,,,,, -60000,1.1747067,3.5026379,,,,,,,,,,,,,, -60100,1.1753172,3.5548644,,,,,,,,,,,,,, -60200,1.1122267,3.461886,,,,,,,,,,,,,, -60300,1.1322876,3.4369307,,,,,,,,,,,,,, -60400,1.123524,3.4065037,,,,,,,,,,,,,, -60500,1.0163838,5.0949793,,,,,,,,,,,,,, -60600,1.042816,3.716715,,,,,,,,,,,,,, -60700,1.124217,3.518559,,,,,,,,,,,,,, -60792,,,0.7305468320846558,1.3091228008270264,0.6714000105857849,1.5599348545074463,50000.0,0.5435000061988831,2.1946136951446533,10000.0,27764.46465349197,30038.139572381973,27764.46465349197,2268.4358253479004,2.123471736907959,0.0 -60800,1.1750833,3.5063014,,,,,,,,,,,,,, -60900,1.1716374,3.6511831,,,,,,,,,,,,,, -61000,1.0920638,3.4254396,,,,,,,,,,,,,, -61100,1.0114313,5.116713,,,,,,,,,,,,,, -61200,1.2103169,3.5056007,,,,,,,,,,,,,, -61300,1.0700092,3.7421365,,,,,,,,,,,,,, -61400,1.0782856,3.6518407,,,,,,,,,,,,,, -61500,0.9999503,5.0754805,,,,,,,,,,,,,, -61600,1.1097564,3.5019243,,,,,,,,,,,,,, -61700,1.0028774,3.8155441,,,,,,,,,,,,,, -61707,,,0.7434960603713989,1.2444376945495603,0.6765999794006348,1.5327082872390747,50000.0,0.5573000311851501,2.1491177082061768,10000.0,28184.4619910717,30492.225292921063,28184.4619910717,2302.441154003144,2.160538911819458,0.0 -61800,1.103573,3.5635316,,,,,,,,,,,,,, -61900,1.2508233,3.771801,,,,,,,,,,,,,, -62000,1.1076146,3.4507906,,,,,,,,,,,,,, -62100,1.0556026,4.7504725,,,,,,,,,,,,,, -62200,0.982553,4.3129745,,,,,,,,,,,,,, -62300,1.0202478,3.7922952,,,,,,,,,,,,,, -62400,1.0966252,5.1436543,,,,,,,,,,,,,, -62500,1.1748972,3.424655,,,,,,,,,,,,,, -62600,1.0853715,3.7076683,,,,,,,,,,,,,, -62627,,,0.7280663847923279,1.3498351573944092,0.672540009021759,1.5801585912704468,50000.0,0.5503000020980835,2.195638656616211,10000.0,28604.43643712997,30947.598170518875,28604.43643712997,2337.75715136528,2.1949493885040283,0.0 -62700,1.0186871,4.215682,,,,,,,,,,,,,, -62800,1.1935736,3.6709013,,,,,,,,,,,,,, -62900,1.1811658,3.4529555,,,,,,,,,,,,,, -63000,1.0869242,4.925485,,,,,,,,,,,,,, -63100,1.1467589,3.4584043,,,,,,,,,,,,,, -63200,0.9786994,4.608179,,,,,,,,,,,,,, -63300,1.0865124,3.503769,,,,,,,,,,,,,, -63400,1.1940781,3.468776,,,,,,,,,,,,,, -63500,1.127585,3.4239678,,,,,,,,,,,,,, -63549,,,0.7319726347923279,1.305245280265808,0.6799799799919128,1.548778414726257,50000.0,0.5556000471115112,2.1663198471069336,10000.0,29024.70809817314,31403.40734767914,29024.70809817314,2373.211641073227,2.230957508087158,0.0 -63600,1.1779051,3.3732271,,,,,,,,,,,,,, -63700,1.1308198,3.5153747,,,,,,,,,,,,,, -63800,1.0824109,3.6538665,,,,,,,,,,,,,, -63900,1.1788176,3.5206866,,,,,,,,,,,,,, -64000,1.1550732,3.4802473,,,,,,,,,,,,,, -64100,0.99496377,4.1844788,,,,,,,,,,,,,, -64200,1.1191716,4.4681377,,,,,,,,,,,,,, -64300,1.0342855,3.9127288,,,,,,,,,,,,,, -64400,1.2163384,3.552289,,,,,,,,,,,,,, -64471,,,0.7493359446525574,1.2207766771316528,0.6810799837112427,1.495862603187561,50000.0,0.5586000084877014,2.1107430458068848,10000.0,29444.98510026932,31857.98429250717,29444.98510026932,2407.421452522278,2.273590564727783,0.0 -64500,1.093133,3.3516505,,,,,,,,,,,,,, -64600,1.1147426,4.704186,,,,,,,,,,,,,, -64700,1.0217469,4.0827622,,,,,,,,,,,,,, -64800,1.1551671,3.5671358,,,,,,,,,,,,,, -64900,1.2653528,3.484743,,,,,,,,,,,,,, -65000,1.1129193,3.425886,,,,,,,,,,,,,, -65100,1.0301964,3.9687967,,,,,,,,,,,,,, -65200,1.1322616,3.3150036,,,,,,,,,,,,,, -65300,0.9770171,4.1069536,,,,,,,,,,,,,, -65392,,,0.7576562166213989,1.2009063959121704,0.6784200072288513,1.5377628803253174,50000.0,0.5560000538825989,2.167978048324585,10000.0,29864.908487081528,32313.229038715363,29864.908487081528,2442.6564412117004,2.313045024871826,0.0 -65400,1.1020808,3.4961257,,,,,,,,,,,,,, -65500,1.0561628,4.593961,,,,,,,,,,,,,, -65600,1.159037,3.515807,,,,,,,,,,,,,, -65700,1.1272676,3.457377,,,,,,,,,,,,,, -65800,1.1727028,3.477817,,,,,,,,,,,,,, -65900,1.0773259,3.6756756,,,,,,,,,,,,,, -66000,1.0355177,3.8733063,,,,,,,,,,,,,, -66100,1.1956995,5.202333,,,,,,,,,,,,,, -66200,1.1304451,4.797374,,,,,,,,,,,,,, -66300,1.1281822,3.591321,,,,,,,,,,,,,, -66313,,,0.7408398389816284,1.2733092308044434,0.6830599904060364,1.5249103307724,50000.0,0.5582000017166138,2.1454310417175293,10000.0,30285.114988565445,32768.00912475586,30285.114988565445,2477.143489599228,2.3523316383361816,0.0 -66400,0.9935753,4.5117044,,,,,,,,,,,,,, -66500,1.0871027,4.714652,,,,,,,,,,,,,, -66600,1.2191792,3.432564,,,,,,,,,,,,,, -66700,1.126501,3.6540048,,,,,,,,,,,,,, -66800,1.1160216,3.7376297,,,,,,,,,,,,,, -66900,1.0549483,4.814654,,,,,,,,,,,,,, -67000,1.1608645,3.5587544,,,,,,,,,,,,,, -67100,1.045105,4.2167187,,,,,,,,,,,,,, -67200,1.0950528,3.7741547,,,,,,,,,,,,,, -67235,,,0.7444140315055847,1.232791304588318,0.6813200116157532,1.510599970817566,50000.0,0.5560000538825989,2.115541219711304,10000.0,30705.05445933342,33222.08050394058,30705.05445933342,2511.189737558365,2.390868186950684,0.0 -67300,1.074018,4.930888,,,,,,,,,,,,,, -67400,1.0751274,4.6308813,,,,,,,,,,,,,, -67500,1.176915,3.3898935,,,,,,,,,,,,,, -67600,1.0091604,3.9385765,,,,,,,,,,,,,, -67700,1.1902761,3.3706453,,,,,,,,,,,,,, -67800,1.1686162,3.420568,,,,,,,,,,,,,, -67900,1.1200674,4.967981,,,,,,,,,,,,,, -68000,1.101788,3.4328957,,,,,,,,,,,,,, -68100,1.1442196,3.4377575,,,,,,,,,,,,,, -68155,,,0.761523425579071,1.198433756828308,0.6867199540138245,1.5193663835525513,50000.0,0.5618000030517578,2.129547595977783,10000.0,31125.21816754341,33677.57394862175,31125.21816754341,2546.431643724441,2.431232452392578,0.0 -68200,1.1270533,3.692525,,,,,,,,,,,,,, -68300,1.0554675,3.5465238,,,,,,,,,,,,,, -68400,1.1965059,3.4932969,,,,,,,,,,,,,, -68500,1.2575296,3.5790274,,,,,,,,,,,,,, -68600,1.065807,3.889593,,,,,,,,,,,,,, -68700,1.1839782,3.3753846,,,,,,,,,,,,,, -68800,1.2122281,3.4274411,,,,,,,,,,,,,, -68900,1.1269685,3.464457,,,,,,,,,,,,,, -69000,1.0683973,4.5359926,,,,,,,,,,,,,, -69076,,,0.7444921731948853,1.2327656745910645,0.6861199736595154,1.484375238418579,50000.0,0.5611000061035156,2.0891273021698,10000.0,31545.46352577209,34129.872649908066,31545.46352577209,2578.39408159256,2.4752867221832275,0.0 -69100,1.0926156,3.5070395,,,,,,,,,,,,,, -69200,1.0520502,3.645942,,,,,,,,,,,,,, -69300,0.98523194,4.397399,,,,,,,,,,,,,, -69400,1.1347053,3.405057,,,,,,,,,,,,,, -69500,1.086775,3.4772484,,,,,,,,,,,,,, -69600,1.1437311,4.491622,,,,,,,,,,,,,, -69700,1.2250918,3.4591012,,,,,,,,,,,,,, -69800,1.1359106,3.4581704,,,,,,,,,,,,,, -69900,1.0941889,3.574152,,,,,,,,,,,,,, -69998,,,0.7477343678474426,1.226297378540039,0.6846599578857422,1.4979448318481443,50000.0,0.5593000054359436,2.113833427429199,10000.0,31965.51737833023,34584.49463844299,31965.51737833023,2612.8748741149902,2.515571594238281,0.0 -70000,1.1266155,3.7192066,,,,,,,,,,,,,, -70100,1.1169044,3.4693842,,,,,,,,,,,,,, -70200,1.2425302,3.578865,,,,,,,,,,,,,, -70300,1.0535812,3.6649828,,,,,,,,,,,,,, -70400,1.0001622,4.357178,,,,,,,,,,,,,, -70500,1.1367071,3.4864037,,,,,,,,,,,,,, -70600,1.0592666,4.0137396,,,,,,,,,,,,,, -70700,1.1633625,3.3441272,,,,,,,,,,,,,, -70800,1.1850046,3.4137976,,,,,,,,,,,,,, -70900,1.1550078,3.379877,,,,,,,,,,,,,, -70919,,,0.761035144329071,1.179047465324402,0.6904599666595459,1.4743375778198242,50000.0,0.5624000430107117,2.088679552078247,10000.0,32385.57051825524,35035.8462703228,32385.57051825524,2644.0857014656067,2.554614782333374,0.0 -71000,1.1829482,3.437956,,,,,,,,,,,,,, -71100,1.180815,3.3573456,,,,,,,,,,,,,, -71200,1.1340778,4.030436,,,,,,,,,,,,,, -71300,1.1787647,3.414922,,,,,,,,,,,,,, -71400,1.075506,3.4178216,,,,,,,,,,,,,, -71500,1.0346293,3.8538105,,,,,,,,,,,,,, -71600,1.1783844,3.4277372,,,,,,,,,,,,,, -71700,1.2095121,3.3735914,,,,,,,,,,,,,, -71800,1.0929018,3.8573065,,,,,,,,,,,,,, -71841,,,0.7437499761581421,1.298214554786682,0.6857199668884277,1.5482041835784912,50000.0,0.5604000091552734,2.14981746673584,10000.0,32805.80994772911,35489.30454015732,32805.80994772911,2677.212243080139,2.599778652191162,0.0 -71900,1.0961539,3.5505333,,,,,,,,,,,,,, -72000,1.0737219,4.79346,,,,,,,,,,,,,, -72100,1.1692916,3.3796582,,,,,,,,,,,,,, -72200,1.2128603,3.4171185,,,,,,,,,,,,,, -72300,1.0872642,3.9179688,,,,,,,,,,,,,, -72400,1.0297499,3.9268045,,,,,,,,,,,,,, -72500,1.0836976,4.4725842,,,,,,,,,,,,,, -72600,1.1855286,3.4546905,,,,,,,,,,,,,, -72700,1.1209134,3.8086815,,,,,,,,,,,,,, -72764,,,0.7504296898841858,1.206867218017578,0.6949999928474426,1.4593466520309448,50000.0,0.5694000124931335,2.0661840438842773,10000.0,33225.91061067581,35941.64887547493,33225.91061067581,2709.371276378632,2.638214349746704,0.0 -72800,1.1842501,3.3002098,,,,,,,,,,,,,, -72900,1.1756401,3.37254,,,,,,,,,,,,,, -73000,1.1486928,3.3780284,,,,,,,,,,,,,, -73100,1.1254005,3.711776,,,,,,,,,,,,,, -73200,1.159999,4.947061,,,,,,,,,,,,,, -73300,1.0847349,4.265219,,,,,,,,,,,,,, -73400,1.1165996,4.310092,,,,,,,,,,,,,, -73500,1.0121275,3.8589635,,,,,,,,,,,,,, -73600,1.2050213,4.629807,,,,,,,,,,,,,, -73686,,,0.7568163871765137,1.2269736528396606,0.6904199719429016,1.5098013877868652,50000.0,0.5634000301361084,2.13584303855896,10000.0,33646.25889086723,36392.84107017517,33646.25889086723,2740.13196849823,2.6735827922821045,0.0 -73700,1.151326,3.442065,,,,,,,,,,,,,, -73800,1.3087292,3.4415598,,,,,,,,,,,,,, -73900,1.0350069,4.6187186,,,,,,,,,,,,,, -74000,1.1521132,3.383162,,,,,,,,,,,,,, -74100,1.096449,4.350633,,,,,,,,,,,,,, -74200,1.2095723,3.5263963,,,,,,,,,,,,,, -74300,1.0994222,4.7730017,,,,,,,,,,,,,, -74400,1.1416284,4.5554495,,,,,,,,,,,,,, -74500,1.21523,3.4093819,,,,,,,,,,,,,, -74600,1.2256157,3.446073,,,,,,,,,,,,,, -74607,,,0.7570703029632568,1.1745266914367676,0.6887800097465515,1.460245132446289,50000.0,0.5653000473976135,2.070711374282837,10000.0,34066.51602935791,36850.19239878655,34066.51602935791,2777.113403081894,2.73940110206604,0.0 -74700,1.3106554,3.5345168,,,,,,,,,,,,,, -74800,1.2144843,4.802308,,,,,,,,,,,,,, -74900,1.1462005,3.646657,,,,,,,,,,,,,, -75000,1.2280234,5.1725636,,,,,,,,,,,,,, -75100,1.2062484,3.3651533,,,,,,,,,,,,,, -75200,1.2172443,3.3287945,,,,,,,,,,,,,, -75300,1.2248586,3.4400828,,,,,,,,,,,,,, -75400,1.1558591,3.5904858,,,,,,,,,,,,,, -75500,1.2607327,3.5473027,,,,,,,,,,,,,, -75530,,,0.75341796875,1.207563400268555,0.6945199966430664,1.4618675708770752,50000.0,0.5676000118255615,2.0816633701324463,10000.0,34486.733615875244,37304.65952205658,34486.733615875244,2811.2738075256348,2.781567335128784,0.0 -75600,1.1856437,3.460699,,,,,,,,,,,,,, -75700,1.1232704,3.917651,,,,,,,,,,,,,, -75800,1.2255116,3.3426607,,,,,,,,,,,,,, -75900,1.0790837,4.2646456,,,,,,,,,,,,,, -76000,1.126025,4.8236184,,,,,,,,,,,,,, -76100,1.0865278,4.1632257,,,,,,,,,,,,,, -76200,1.1686186,3.3508048,,,,,,,,,,,,,, -76300,1.2347138,4.8519,,,,,,,,,,,,,, -76400,1.1860838,5.1405816,,,,,,,,,,,,,, -76453,,,0.7587304711341858,1.1997249126434326,0.6966999769210815,1.4641448259353638,50000.0,0.5722000002861023,2.0874454975128174,10000.0,34906.928926706314,37756.88763618469,34906.928926706314,2843.222591161728,2.818364381790161,0.0 -76500,1.322528,3.6664886,,,,,,,,,,,,,, -76600,1.2233605,5.102216,,,,,,,,,,,,,, -76700,1.1318662,4.215717,,,,,,,,,,,,,, -76800,1.2152686,3.4092727,,,,,,,,,,,,,, -76900,1.1238203,3.6518261,,,,,,,,,,,,,, -77000,1.2340264,3.3784246,,,,,,,,,,,,,, -77100,1.1898165,3.302752,,,,,,,,,,,,,, -77200,1.1715338,3.332984,,,,,,,,,,,,,, -77300,1.3063873,3.4812853,,,,,,,,,,,,,, -77374,,,0.7772070169448853,1.1090781688690186,0.694599986076355,1.465245246887207,50000.0,0.5695000290870667,2.073258399963379,10000.0,35327.165003061295,38208.289085149765,35327.165003061295,2874.2982473373413,2.860573530197144,0.0 -77400,1.2291851,4.708461,,,,,,,,,,,,,, -77500,1.1601266,3.4592783,,,,,,,,,,,,,, -77600,1.197759,3.4290202,,,,,,,,,,,,,, -77700,1.1809745,4.956531,,,,,,,,,,,,,, -77800,1.2761405,5.0485115,,,,,,,,,,,,,, -77900,1.2245456,4.6896596,,,,,,,,,,,,,, -78000,1.1388407,5.1505003,,,,,,,,,,,,,, -78100,1.249931,3.3090198,,,,,,,,,,,,,, -78200,1.194813,3.3985343,,,,,,,,,,,,,, -78298,,,0.7567187547683716,1.2263389825820925,0.6919599771499634,1.4922540187835691,50000.0,0.5665000081062317,2.110428810119629,10000.0,35747.40086340904,38663.13839507103,35747.40086340904,2908.8261551856995,2.899253368377685,0.0 -78300,1.168585,4.163303,,,,,,,,,,,,,, -78400,1.2601154,5.033257,,,,,,,,,,,,,, -78500,1.2621847,3.4527872,,,,,,,,,,,,,, -78600,1.1304537,3.5733871,,,,,,,,,,,,,, -78700,1.2071053,3.7278526,,,,,,,,,,,,,, -78800,1.2397878,3.3246717,,,,,,,,,,,,,, -78900,1.2572486,3.4572563,,,,,,,,,,,,,, -79000,1.1401252,3.4033656,,,,,,,,,,,,,, -79100,1.2789271,3.4501948,,,,,,,,,,,,,, -79200,1.1630815,3.5016906,,,,,,,,,,,,,, -79219,,,0.76429682970047,1.1475809812545776,0.6978799700737,1.4310686588287354,50000.0,0.5734000205993652,2.026544570922852,10000.0,36167.35218763352,39112.27828145027,36167.35218763352,2937.9259643554688,2.9404571056365967,0.0 -79300,1.3093915,3.3880146,,,,,,,,,,,,,, -79400,1.1327432,4.255255,,,,,,,,,,,,,, -79500,1.1350472,4.1791983,,,,,,,,,,,,,, -79600,1.199202,3.299662,,,,,,,,,,,,,, -79700,1.2540694,3.3883255,,,,,,,,,,,,,, -79800,1.20751,5.0014386,,,,,,,,,,,,,, -79900,1.2321831,3.3810601,,,,,,,,,,,,,, -80000,1.3017365,5.0605707,,,,,,,,,,,,,, -80100,1.1317341,3.2970343,,,,,,,,,,,,,, -80140,,,0.7719921469688416,1.1283113956451416,0.6975399851799011,1.4419785737991333,50000.0,0.570900022983551,2.05631685256958,10000.0,36587.51485085488,39567.50933790207,36587.51485085488,2972.902609348297,2.984685182571411,0.0 -80200,1.2558556,5.061873,,,,,,,,,,,,,, -80300,1.404931,3.3588908,,,,,,,,,,,,,, -80400,1.2776854,3.385284,,,,,,,,,,,,,, -80500,1.2666914,3.3076038,,,,,,,,,,,,,, -80600,1.1385432,4.3500123,,,,,,,,,,,,,, -80700,1.168394,3.5086174,,,,,,,,,,,,,, -80800,1.206608,3.435648,,,,,,,,,,,,,, -80900,1.2316195,3.4772084,,,,,,,,,,,,,, -81000,1.2127315,3.5317676,,,,,,,,,,,,,, -81062,,,0.7625781297683716,1.1647394895553589,0.7009999752044678,1.4354666471481323,50000.0,0.5752000212669373,2.033855438232422,10000.0,37007.73335170746,40020.88571333885,37007.73335170746,3005.9743795394897,3.023676633834839,0.0 -81100,1.3673015,3.3829143,,,,,,,,,,,,,, -81200,1.2270496,3.2354693,,,,,,,,,,,,,, -81300,1.213779,3.2912092,,,,,,,,,,,,,, -81400,1.1056027,4.102541,,,,,,,,,,,,,, -81500,1.1513748,3.4747653,,,,,,,,,,,,,, -81600,1.2731302,3.3402135,,,,,,,,,,,,,, -81700,1.1444626,3.5040681,,,,,,,,,,,,,, -81800,1.2143856,4.7423244,,,,,,,,,,,,,, -81900,1.2204581,3.3496828,,,,,,,,,,,,,, -81984,,,0.7649609446525574,1.1457265615463257,0.7003799676895142,1.4246442317962646,50000.0,0.5760000348091125,2.041268348693848,10000.0,37427.75093221664,40472.73112511635,37427.75093221664,3037.7150337696075,3.0635695457458496,0.0 -82000,1.4129976,3.3929293,,,,,,,,,,,,,, -82100,1.299656,3.414117,,,,,,,,,,,,,, -82200,1.1936946,3.6781287,,,,,,,,,,,,,, -82300,1.324616,3.4053018,,,,,,,,,,,,,, -82400,1.3410046,3.4138265,,,,,,,,,,,,,, -82500,1.2476927,3.4019191,,,,,,,,,,,,,, -82600,1.251811,3.3117285,,,,,,,,,,,,,, -82700,1.2586526,3.389338,,,,,,,,,,,,,, -82800,1.1641108,4.3521347,,,,,,,,,,,,,, -82900,1.2829046,3.4241107,,,,,,,,,,,,,, -82906,,,0.7728124856948853,1.1158013343811035,0.7010799646377563,1.41843843460083,50000.0,0.5736000537872314,2.0298521518707275,10000.0,37847.68333148956,40926.77200841904,37847.68333148956,3071.735149860382,3.1044604778289795,0.0 -83000,1.2021064,3.594801,,,,,,,,,,,,,, -83100,1.3082061,3.3255124,,,,,,,,,,,,,, -83200,1.2103869,3.329965,,,,,,,,,,,,,, -83300,1.2664744,3.3485942,,,,,,,,,,,,,, -83400,1.2534463,4.0749836,,,,,,,,,,,,,, -83500,1.2424649,3.2577436,,,,,,,,,,,,,, -83600,1.2410648,3.5276585,,,,,,,,,,,,,, -83700,1.2019452,3.264804,,,,,,,,,,,,,, -83800,1.3281487,3.3395736,,,,,,,,,,,,,, -83831,,,0.7667187452316284,1.1418139934539795,0.7049999833106995,1.4089363813400269,50000.0,0.5822000503540039,2.009699821472168,10000.0,38267.85958909989,41383.278517484665,38267.85958909989,3107.9714460372925,3.151188611984253,0.0 -83900,1.2459598,3.6273475,,,,,,,,,,,,,, -84000,1.2970616,3.314723,,,,,,,,,,,,,, -84100,1.1780934,4.8856673,,,,,,,,,,,,,, -84200,1.2047737,3.350788,,,,,,,,,,,,,, -84300,1.32416,5.065468,,,,,,,,,,,,,, -84400,1.1396599,3.6982875,,,,,,,,,,,,,, -84500,1.3936135,4.8781157,,,,,,,,,,,,,, -84600,1.3542202,4.9986305,,,,,,,,,,,,,, -84700,1.1765575,4.0454865,,,,,,,,,,,,,, -84754,,,0.7683984041213989,1.1525145769119265,0.7024199962615967,1.4284731149673462,50000.0,0.5760000348091125,2.029409170150757,10000.0,38688.02298378944,41834.969650030136,38688.02298378944,3139.4142003059387,3.1884803771972656,0.0 -84800,1.1952379,3.6807284,,,,,,,,,,,,,, -84900,1.1775074,3.9996953,,,,,,,,,,,,,, -85000,1.3062718,3.297217,,,,,,,,,,,,,, -85100,1.4650134,5.106555,,,,,,,,,,,,,, -85200,1.3354889,3.4022233,,,,,,,,,,,,,, -85300,1.2894899,3.7466187,,,,,,,,,,,,,, -85400,1.2910453,3.3769042,,,,,,,,,,,,,, -85500,1.2196326,3.281735,,,,,,,,,,,,,, -85600,1.2732633,4.1573696,,,,,,,,,,,,,, -85675,,,0.7705664038658142,1.13634991645813,0.7039200067520142,1.421459078788757,50000.0,0.5767000317573547,2.0247151851654053,10000.0,39108.00399470329,42289.48701906204,39108.00399470329,3173.858931779861,3.232621431350708,0.0 -85700,1.3564713,3.3181665,,,,,,,,,,,,,, -85800,1.1464499,3.7310274,,,,,,,,,,,,,, -85900,1.3009806,3.3375773,,,,,,,,,,,,,, -86000,1.2605033,3.308066,,,,,,,,,,,,,, -86100,1.2468059,3.4229052,,,,,,,,,,,,,, -86200,1.3505996,3.6605463,,,,,,,,,,,,,, -86300,1.2427467,3.3710027,,,,,,,,,,,,,, -86400,1.2724802,4.7418594,,,,,,,,,,,,,, -86500,1.1524513,4.261179,,,,,,,,,,,,,, -86597,,,0.7929882407188416,1.103911519050598,0.7073799967765808,1.468793511390686,50000.0,0.5759000182151794,2.075772523880005,10000.0,39528.32387185097,42747.13645219803,39528.32387185097,3211.1007244586945,3.273395776748657,0.0 -86600,1.2461622,3.2180214,,,,,,,,,,,,,, -86700,1.3173845,3.2734194,,,,,,,,,,,,,, -86800,1.2324991,3.3747518,,,,,,,,,,,,,, -86900,1.1612506,3.6206806,,,,,,,,,,,,,, -87000,1.2722156,4.776666,,,,,,,,,,,,,, -87100,1.2850795,3.3697736,,,,,,,,,,,,,, -87200,1.2847134,3.4560077,,,,,,,,,,,,,, -87300,1.2475691,4.8496943,,,,,,,,,,,,,, -87400,1.1895019,3.8992164,,,,,,,,,,,,,, -87500,1.2946113,4.268652,,,,,,,,,,,,,, -87520,,,0.7638280987739563,1.196157455444336,0.7017599940299988,1.4531490802764893,50000.0,0.575700044631958,2.062497854232788,10000.0,39948.40259766579,43203.59956860542,39948.40259766579,3247.392087459564,3.3191680908203125,0.0 -87600,1.4315062,3.3409593,,,,,,,,,,,,,, -87700,1.202964,3.3610656,,,,,,,,,,,,,, -87800,1.2359365,3.3051753,,,,,,,,,,,,,, -87900,1.3260368,3.233071,,,,,,,,,,,,,, -88000,1.3267733,3.3100004,,,,,,,,,,,,,, -88100,1.2084082,3.4292655,,,,,,,,,,,,,, -88200,1.2977021,3.2878857,,,,,,,,,,,,,, -88300,1.3839834,4.685006,,,,,,,,,,,,,, -88400,1.3544436,5.012844,,,,,,,,,,,,,, -88441,,,0.7773827910423279,1.0995545387268066,0.7068799734115601,1.3993632793426514,50000.0,0.5833000540733337,1.993497610092163,10000.0,40368.75314331055,43657.85054802895,40368.75314331055,3281.204854249954,3.3600547313690186,0.0 -88500,1.3406419,3.3117497,,,,,,,,,,,,,, -88600,1.271475,3.2531664,,,,,,,,,,,,,, -88700,1.1452087,4.1046395,,,,,,,,,,,,,, -88800,1.4200749,3.285506,,,,,,,,,,,,,, -88900,1.3145425,3.297291,,,,,,,,,,,,,, -89000,1.2610452,3.6103468,,,,,,,,,,,,,, -89100,1.3482502,3.2759929,,,,,,,,,,,,,, -89200,1.3862293,3.3686807,,,,,,,,,,,,,, -89300,1.2077373,4.216964,,,,,,,,,,,,,, -89362,,,0.7879687547683716,1.051804780960083,0.7084199786186218,1.3937915563583374,50000.0,0.5820000171661377,2.0068376064300537,10000.0,40788.70879864693,44110.658349752426,40788.70879864693,3313.9709889888763,3.399597406387329,0.0 -89400,1.4204178,4.832316,,,,,,,,,,,,,, -89500,1.1805182,4.419757,,,,,,,,,,,,,, -89600,1.2225083,3.8094544,,,,,,,,,,,,,, -89700,1.2201782,4.630609,,,,,,,,,,,,,, -89800,1.329136,3.3791614,,,,,,,,,,,,,, -89900,1.439758,4.922241,,,,,,,,,,,,,, -90000,1.260114,3.8643804,,,,,,,,,,,,,, -90100,1.12159,3.7860355,,,,,,,,,,,,,, -90200,1.2445964,3.9097006,,,,,,,,,,,,,, -90282,,,0.7703515291213989,1.137851357460022,0.7090799808502197,1.4091787338256836,50000.0,0.5849000215530396,2.013418197631836,10000.0,41208.68663263321,44563.522152900696,41208.68663263321,3346.769567489624,3.440016031265259,0.0 -90300,1.2610184,4.162902,,,,,,,,,,,,,, -90400,1.1869124,4.398281,,,,,,,,,,,,,, -90500,1.3029858,3.24147,,,,,,,,,,,,,, -90600,1.2877641,4.61693,,,,,,,,,,,,,, -90700,1.2041091,3.2728565,,,,,,,,,,,,,, -90800,1.2704163,4.5196047,,,,,,,,,,,,,, -90900,1.3051645,4.616604,,,,,,,,,,,,,, -91000,1.3662664,3.3210018,,,,,,,,,,,,,, -91100,1.206859,3.7254074,,,,,,,,,,,,,, -91200,1.3244065,3.243153,,,,,,,,,,,,,, -91203,,,0.779980480670929,1.1019827127456665,0.7110399603843689,1.3903183937072754,50000.0,0.5835000276565552,1.9904738664627075,10000.0,41628.99574398994,45017.9036052227,41628.99574398994,3380.757195711136,3.477928876876831,0.0 -91300,1.2740201,3.662433,,,,,,,,,,,,,, -91400,1.3704392,3.2555242,,,,,,,,,,,,,, -91500,1.3617165,4.998319,,,,,,,,,,,,,, -91600,1.2538217,3.247071,,,,,,,,,,,,,, -91700,1.3673314,3.359791,,,,,,,,,,,,,, -91800,1.3283154,3.3271227,,,,,,,,,,,,,, -91900,1.2442561,3.118351,,,,,,,,,,,,,, -92000,1.1984606,3.362358,,,,,,,,,,,,,, -92100,1.3417553,4.8120584,,,,,,,,,,,,,, -92127,,,0.7870507836341858,1.0477782487869265,0.7113999724388123,1.3689666986465454,50000.0,0.5919000506401062,1.961077690124512,10000.0,42049.15962576866,45470.5659160614,42049.15962576866,3413.170263528824,3.516024589538574,0.0 -92200,1.3022044,3.5737822,,,,,,,,,,,,,, -92300,1.3430431,4.878213,,,,,,,,,,,,,, -92400,1.1495733,4.1473346,,,,,,,,,,,,,, -92500,1.3331636,3.3005617,,,,,,,,,,,,,, -92600,1.2148826,3.6617818,,,,,,,,,,,,,, -92700,1.3975995,3.3354805,,,,,,,,,,,,,, -92800,1.3411866,3.2577343,,,,,,,,,,,,,, -92900,1.4613724,4.9944396,,,,,,,,,,,,,, -93000,1.2507367,3.445597,,,,,,,,,,,,,, -93051,,,0.7793359160423279,1.0956302881240845,0.7135199904441833,1.3716667890548706,50000.0,0.5963000059127808,1.9534903764724727,10000.0,42469.19526147842,45925.48884344101,42469.19526147842,3447.9696865081787,3.5564661026000977,0.0 -93100,1.2720318,4.2434025,,,,,,,,,,,,,, -93200,1.4142022,3.239519,,,,,,,,,,,,,, -93300,1.2734947,4.531387,,,,,,,,,,,,,, -93400,1.2764816,3.3895433,,,,,,,,,,,,,, -93500,1.2420788,3.3736153,,,,,,,,,,,,,, -93600,1.289604,4.0658393,,,,,,,,,,,,,, -93700,1.4083759,3.2328963,,,,,,,,,,,,,, -93800,1.4440658,4.983445,,,,,,,,,,,,,, -93900,1.2353579,4.1717515,,,,,,,,,,,,,, -93974,,,0.7795116901397705,1.0931429862976074,0.7133600115776062,1.3716752529144287,50000.0,0.5918000340461731,1.9728366136550903,10000.0,42889.19483089447,46378.44157075882,42889.19483089447,3480.8335807323456,3.598971605300904,0.0 -94000,1.1882293,3.6170442,,,,,,,,,,,,,, -94100,1.4165087,3.2489667,,,,,,,,,,,,,, -94200,1.4130456,3.343027,,,,,,,,,,,,,, -94300,1.4031603,4.587747,,,,,,,,,,,,,, -94400,1.2436203,3.1832476,,,,,,,,,,,,,, -94500,1.2618592,3.532207,,,,,,,,,,,,,, -94600,1.3848519,3.2931237,,,,,,,,,,,,,, -94700,1.4374783,3.6008806,,,,,,,,,,,,,, -94800,1.2393702,4.7201314,,,,,,,,,,,,,, -94893,,,0.7892773151397705,1.0392942428588867,0.7130999565124512,1.3569936752319336,50000.0,0.5925000309944153,1.949466347694397,10000.0,43309.4011952877,46834.06846022606,43309.4011952877,3516.163388967514,3.6427321434021,0.0 -94900,1.2805603,3.2745647,,,,,,,,,,,,,, -95000,1.2736852,3.20463,,,,,,,,,,,,,, -95100,1.3446562,3.1489286,,,,,,,,,,,,,, -95200,1.2964374,3.1672823,,,,,,,,,,,,,, -95300,1.3183545,3.932047,,,,,,,,,,,,,, -95400,1.462391,4.927211,,,,,,,,,,,,,, -95500,1.324781,4.3175664,,,,,,,,,,,,,, -95600,1.1992265,3.4173717,,,,,,,,,,,,,, -95700,1.3900819,3.2582462,,,,,,,,,,,,,, -95800,1.2590971,3.5559552,,,,,,,,,,,,,, -95815,,,0.7995898127555847,0.9954485297203064,0.7177199721336365,1.337254524230957,50000.0,0.5929000377655029,1.937144875526428,10000.0,43729.76480174065,47286.249116420746,43729.76480174065,3547.8907368183136,3.684923648834229,0.0 -95900,1.3097671,3.2243223,,,,,,,,,,,,,, -96000,1.2486482,3.9217162,,,,,,,,,,,,,, -96100,1.4862027,4.942393,,,,,,,,,,,,,, -96200,1.3373936,3.2619214,,,,,,,,,,,,,, -96300,1.329721,3.1730444,,,,,,,,,,,,,, -96400,1.3350377,3.247881,,,,,,,,,,,,,, -96500,1.2014396,3.8057177,,,,,,,,,,,,,, -96600,1.3654717,3.5244114,,,,,,,,,,,,,, -96700,1.3713928,3.2673335,,,,,,,,,,,,,, -96738,,,0.77943354845047,1.1019657850265503,0.7138400077819824,1.384581208229065,50000.0,0.5933000445365906,1.9811594486236568,10000.0,44150.00039052963,47740.57042002678,44150.00039052963,3581.8719758987427,3.726364850997925,0.0 -96800,1.4799128,3.2891567,,,,,,,,,,,,,, -96900,1.3753327,4.9165716,,,,,,,,,,,,,, -97000,1.2690887,3.2128875,,,,,,,,,,,,,, -97100,1.4135704,4.8760595,,,,,,,,,,,,,, -97200,1.468493,4.812477,,,,,,,,,,,,,, -97300,1.3821272,3.201561,,,,,,,,,,,,,, -97400,1.3450612,4.558447,,,,,,,,,,,,,, -97500,1.2922987,3.2003303,,,,,,,,,,,,,, -97600,1.3520058,3.2336974,,,,,,,,,,,,,, -97662,,,0.7866015434265137,1.0639996528625488,0.7142999768257141,1.3826018571853638,50000.0,0.5907000303268433,1.9830719232559204,10000.0,44570.135746240616,48195.98546934128,44570.135746240616,3617.057484388352,3.7728912830352783,0.0 -97700,1.272726,4.23314,,,,,,,,,,,,,, -97800,1.3065449,3.2857172,,,,,,,,,,,,,, -97900,1.3902334,3.3946776,,,,,,,,,,,,,, -98000,1.253907,4.341673,,,,,,,,,,,,,, -98100,1.1999143,3.824374,,,,,,,,,,,,,, -98200,1.4196954,3.2370276,,,,,,,,,,,,,, -98300,1.4580048,4.398239,,,,,,,,,,,,,, -98400,1.3564464,4.548629,,,,,,,,,,,,,, -98500,1.4075106,3.2501197,,,,,,,,,,,,,, -98587,,,0.7963476181030273,1.0071494579315186,0.7149400115013123,1.3564001321792605,50000.0,0.5903000235557556,1.9637293815612795,10000.0,44990.07020807266,48646.9700319767,44990.07020807266,3648.015291452408,3.817314863204956,0.0 -98600,1.3983997,4.7447,,,,,,,,,,,,,, -98700,1.2975721,3.1960864,,,,,,,,,,,,,, -98800,1.3597028,3.2520351,,,,,,,,,,,,,, -98900,1.361646,3.1760283,,,,,,,,,,,,,, -99000,1.3638828,3.1858428,,,,,,,,,,,,,, -99100,1.2460192,4.2039433,,,,,,,,,,,,,, -99200,1.486828,3.2250652,,,,,,,,,,,,,, -99300,1.3027945,3.3040922,,,,,,,,,,,,,, -99400,1.3377191,3.1608224,,,,,,,,,,,,,, -99500,1.3084437,3.2056715,,,,,,,,,,,,,, -99509,,,0.78919917345047,1.0567806959152222,0.7200799584388733,1.3418904542922974,50000.0,0.597100019454956,1.9461467266082764,10000.0,45410.03365278244,49102.78510594368,45410.03365278244,3683.7723546028137,3.8653643131256095,0.0 -99600,1.5025492,3.2537274,,,,,,,,,,,,,, -99700,1.3034699,3.346782,,,,,,,,,,,,,, -99800,1.4151007,3.2840848,,,,,,,,,,,,,, -99900,1.5709784,3.2163594,,,,,,,,,,,,,, -100000,1.2650746,4.1737924,,,,,,,,,,,,,, -100100,1.4100572,3.2713602,,,,,,,,,,,,,, -100200,1.2773731,3.241695,,,,,,,,,,,,,, -100300,1.5707183,3.2267826,,,,,,,,,,,,,, -100400,1.5422627,4.836713,,,,,,,,,,,,,, -100432,,,0.7888476252555847,1.059237003326416,0.7181999683380127,1.3572092056274414,50000.0,0.5997000336647034,1.9448856115341189,10000.0,45830.30660319328,49560.086330890656,45830.30660319328,3720.710344791413,3.908934354782105,0.0 -100500,1.3265624,3.6204138,,,,,,,,,,,,,, -100600,1.3683163,3.3062854,,,,,,,,,,,,,, -100700,1.4482306,3.2133172,,,,,,,,,,,,,, -100800,1.306748,3.2271655,,,,,,,,,,,,,, -100900,1.3404845,4.069162,,,,,,,,,,,,,, -101000,1.4438416,3.2141843,,,,,,,,,,,,,, -101100,1.4634609,3.2964296,,,,,,,,,,,,,, -101200,1.4266436,3.259267,,,,,,,,,,,,,, -101300,1.5231109,3.23298,,,,,,,,,,,,,, -101357,,,0.8024413585662842,0.987043559551239,0.7212799787521362,1.3317164182662964,50000.0,0.597100019454956,1.923500657081604,10000.0,46250.637231349945,50018.72805285454,46250.637231349945,3758.933340787888,3.949855089187622,0.0 -101400,1.4566817,3.3074074,,,,,,,,,,,,,, -101500,1.3949791,3.1426878,,,,,,,,,,,,,, -101600,1.3862826,3.2789066,,,,,,,,,,,,,, -101700,1.2964137,3.3506408,,,,,,,,,,,,,, -101800,1.3040116,3.857936,,,,,,,,,,,,,, -101900,1.4777277,3.566829,,,,,,,,,,,,,, -102000,1.3843342,3.967892,,,,,,,,,,,,,, -102100,1.4187249,3.2731347,,,,,,,,,,,,,, -102200,1.357101,3.278099,,,,,,,,,,,,,, -102279,,,0.7923827767372131,1.034652590751648,0.7217400074005127,1.337119460105896,50000.0,0.5967000126838684,1.939675450325012,10000.0,46670.904515028,50477.97007155418,46670.904515028,3797.81196808815,3.9987268447875977,0.0 -102300,1.2723932,3.4263823,,,,,,,,,,,,,, -102400,1.4206339,3.214991,,,,,,,,,,,,,, -102500,1.34483,3.0842872,,,,,,,,,,,,,, -102600,1.3744702,4.5543084,,,,,,,,,,,,,, -102700,1.4090462,3.1545448,,,,,,,,,,,,,, -102800,1.4105203,3.2025952,,,,,,,,,,,,,, -102900,1.4615377,3.2140903,,,,,,,,,,,,,, -103000,1.468359,4.376425,,,,,,,,,,,,,, -103100,1.4366916,3.219715,,,,,,,,,,,,,, -103199,,,0.7949413657188416,1.0258285999298096,0.7251799702644348,1.329066514968872,50000.0,0.5997000336647034,1.9212018251419067,10000.0,47091.08183169365,50932.68405771256,47091.08183169365,3832.254055023194,4.04592490196228,0.0 -103200,1.3600357,4.3179607,,,,,,,,,,,,,, -103300,1.5361114,3.1628075,,,,,,,,,,,,,, -103400,1.572971,4.913235,,,,,,,,,,,,,, -103500,1.3080764,3.6478782,,,,,,,,,,,,,, -103600,1.3313887,3.18139,,,,,,,,,,,,,, -103700,1.3859048,3.1633103,,,,,,,,,,,,,, -103800,1.3401265,4.3768983,,,,,,,,,,,,,, -103900,1.3947877,3.1953492,,,,,,,,,,,,,, -104000,1.3527448,3.18871,,,,,,,,,,,,,, -104100,1.339444,3.2218876,,,,,,,,,,,,,, -104119,,,0.8043944835662842,0.9767816662788392,0.7267799973487854,1.3111376762390137,50000.0,0.6007000207901001,1.8960027694702148,10000.0,47511.4247674942,51386.34076428413,47511.4247674942,3865.475003957749,4.091261148452759,0.0 -104200,1.4834443,3.2224174,,,,,,,,,,,,,, -104300,1.4575112,3.202291,,,,,,,,,,,,,, -104400,1.3608602,3.1254363,,,,,,,,,,,,,, -104500,1.4377726,3.1601207,,,,,,,,,,,,,, -104600,1.3614364,3.1358502,,,,,,,,,,,,,, -104700,1.4198073,3.1797698,,,,,,,,,,,,,, -104800,1.4240648,4.7762322,,,,,,,,,,,,,, -104900,1.4562738,3.2416039,,,,,,,,,,,,,, -105000,1.5909449,4.851205,,,,,,,,,,,,,, -105026,,,0.8005077838897705,0.9992862343788148,0.7254199981689453,1.3189843893051147,50000.0,0.6068000197410583,1.9124714136123653,10000.0,47931.44809389114,51843.25613093376,47931.44809389114,3902.273670196533,4.136998176574707,0.0 -105100,1.4065706,3.1753008,,,,,,,,,,,,,, -105200,1.4722471,3.202788,,,,,,,,,,,,,, -105300,1.3954412,3.4225364,,,,,,,,,,,,,, -105400,1.332928,3.865053,,,,,,,,,,,,,, -105500,1.3741817,3.2435482,,,,,,,,,,,,,, -105600,1.399847,3.5058446,,,,,,,,,,,,,, -105700,1.3862413,3.165719,,,,,,,,,,,,,, -105800,1.3500144,4.164329,,,,,,,,,,,,,, -105900,1.408061,3.5561814,,,,,,,,,,,,,, -105950,,,0.7973827719688416,0.9926227927207948,0.7240599989891052,1.3056074380874634,50000.0,0.6034000515937805,1.899614930152893,10000.0,48351.60951066017,52299.75429058075,48351.60951066017,3938.51873588562,4.1818811893463135,0.0 -106000,1.593454,3.2113564,,,,,,,,,,,,,, -106100,1.5552677,3.1911418,,,,,,,,,,,,,, -106200,1.5382582,4.723667,,,,,,,,,,,,,, -106300,1.318981,3.184229,,,,,,,,,,,,,, -106400,1.5595737,3.2699428,,,,,,,,,,,,,, -106500,1.4933871,3.2987525,,,,,,,,,,,,,, -106600,1.3906976,4.4631824,,,,,,,,,,,,,, -106700,1.5299366,4.7110105,,,,,,,,,,,,,, -106800,1.3083185,3.3348775,,,,,,,,,,,,,, -106871,,,0.8055273294448853,0.9919548630714417,0.7270799875259399,1.3199224472045898,50000.0,0.6069000363349915,1.9079723358154297,10000.0,48771.873259305954,52754.04180908203,48771.873259305954,3972.452043533325,4.225016117095947,0.0 -106900,1.5334345,4.1153355,,,,,,,,,,,,,, -107000,1.5074037,3.1159432,,,,,,,,,,,,,, -107100,1.4561257,3.2344465,,,,,,,,,,,,,, -107200,1.4091604,3.1792855,,,,,,,,,,,,,, -107300,1.4619355,4.3252635,,,,,,,,,,,,,, -107400,1.408518,3.1790078,,,,,,,,,,,,,, -107500,1.413824,4.319354,,,,,,,,,,,,,, -107600,1.609077,4.7679167,,,,,,,,,,,,,, -107700,1.411135,3.3616219,,,,,,,,,,,,,, -107792,,,0.8181054592132568,0.96428245306015,0.7266599535942078,1.341126561164856,50000.0,0.6049000024795532,1.9311014413833616,10000.0,49191.80590748787,53209.868277311325,49191.80590748787,4008.251214504242,4.272565364837647,0.0 -107800,1.4987766,3.1592493,,,,,,,,,,,,,, -107900,1.4970474,3.1592932,,,,,,,,,,,,,, -108000,1.4957017,3.1845167,,,,,,,,,,,,,, -108100,1.4781848,3.9475422,,,,,,,,,,,,,, -108200,1.6660231,4.9479675,,,,,,,,,,,,,, -108300,1.4261369,3.3349233,,,,,,,,,,,,,, -108400,1.5955737,4.575924,,,,,,,,,,,,,, -108500,1.4679996,3.9910855,,,,,,,,,,,,,, -108600,1.400342,3.1160011,,,,,,,,,,,,,, -108700,1.5563481,4.8442035,,,,,,,,,,,,,, -108714,,,0.8000195026397705,0.9913946390151978,0.7280600070953369,1.2932459115982056,50000.0,0.6097000241279602,1.8781938552856443,10000.0,49611.74863290787,53664.38921499252,49611.74863290787,4042.738936901093,4.315826892852783,0.0 -108800,1.5763755,3.1535919,,,,,,,,,,,,,, -108900,1.4899439,3.3485382,,,,,,,,,,,,,, -109000,1.5046861,3.5527585,,,,,,,,,,,,,, -109100,1.4631971,3.1293082,,,,,,,,,,,,,, -109200,1.4105786,3.3149645,,,,,,,,,,,,,, -109300,1.4687877,3.14781,,,,,,,,,,,,,, -109400,1.4000312,3.2745621,,,,,,,,,,,,,, -109500,1.5439136,3.4379299,,,,,,,,,,,,,, -109600,1.3310876,3.4463642,,,,,,,,,,,,,, -109635,,,0.8060351610183716,0.9905210733413696,0.7256999611854553,1.3248872756958008,50000.0,0.6095000505447388,1.916000843048096,10000.0,50031.87528467178,54121.824350357056,50031.87528467178,4079.958817481994,4.35719108581543,0.0 -109700,1.5102642,4.1234164,,,,,,,,,,,,,, -109800,1.7435305,4.861759,,,,,,,,,,,,,, -109900,1.5464002,3.1076665,,,,,,,,,,,,,, -110000,1.57939,3.145313,,,,,,,,,,,,,, -110100,1.5585903,3.2064133,,,,,,,,,,,,,, -110200,1.8184593,4.8061814,,,,,,,,,,,,,, -110300,1.5114921,3.2725315,,,,,,,,,,,,,, -110400,1.4605117,3.362778,,,,,,,,,,,,,, -110500,1.3236586,3.7341344,,,,,,,,,,,,,, -110557,,,0.8092968463897705,0.9955326914787292,0.7274199724197388,1.3348983526229858,50000.0,0.6116000413894653,1.922496199607849,10000.0,50451.978063583374,54578.00987672806,50451.978063583374,4115.946965456009,4.404077291488648,0.0 -110600,1.5928276,4.48358,,,,,,,,,,,,,, -110700,1.4812602,3.1626747,,,,,,,,,,,,,, -110800,1.4235388,3.2277434,,,,,,,,,,,,,, -110900,1.3599341,3.5289583,,,,,,,,,,,,,, -111000,1.388944,3.3315892,,,,,,,,,,,,,, -111100,1.439879,3.1075783,,,,,,,,,,,,,, -111200,1.4626194,3.190384,,,,,,,,,,,,,, -111300,1.4814212,3.406108,,,,,,,,,,,,,, -111400,1.4237517,3.4857867,,,,,,,,,,,,,, -111475,,,0.8045703172683716,0.9676870703697203,0.7338599562644958,1.2687467336654663,50000.0,0.6123000383377075,1.863784909248352,10000.0,50871.95857954025,55032.40033340454,50871.95857954025,4150.26503443718,4.449316024780273,0.0 -111500,1.486964,3.1471398,,,,,,,,,,,,,, -111600,1.6194844,3.1683724,,,,,,,,,,,,,, -111700,1.4846002,3.0864518,,,,,,,,,,,,,, -111800,1.7086744,4.8316994,,,,,,,,,,,,,, -111900,1.3591646,3.6558993,,,,,,,,,,,,,, -112000,1.6404337,4.6642685,,,,,,,,,,,,,, -112100,1.5791916,3.1375709,,,,,,,,,,,,,, -112200,1.6166439,3.1625245,,,,,,,,,,,,,, -112300,1.4342357,3.1016612,,,,,,,,,,,,,, -112397,,,0.8095507621765137,0.967756688594818,0.7329199910163879,1.299895405769348,50000.0,0.6091000437736511,1.8859690427780151,10000.0,51292.22880363464,55488.480078697205,51292.22880363464,4185.976754188538,4.499913215637207,0.0 -112400,1.4504943,3.1702533,,,,,,,,,,,,,, -112500,1.6130561,3.133166,,,,,,,,,,,,,, -112600,1.3531818,3.412917,,,,,,,,,,,,,, -112700,1.5745298,3.2014716,,,,,,,,,,,,,, -112800,1.462357,3.0548964,,,,,,,,,,,,,, -112900,1.5524075,3.2546442,,,,,,,,,,,,,, -113000,1.4991465,3.1218321,,,,,,,,,,,,,, -113100,1.2934997,3.6552014,,,,,,,,,,,,,, -113200,1.5386446,3.1281476,,,,,,,,,,,,,, -113300,1.5131733,4.246306,,,,,,,,,,,,,, -113317,,,0.8145117163658142,0.9527512192726136,0.7319999933242798,1.2896184921264648,50000.0,0.615600049495697,1.876335978507996,10000.0,51712.16885471344,55943.190348148346,51712.16885471344,4220.654976129532,4.545567512512207,0.0 -113400,1.4414892,4.0295973,,,,,,,,,,,,,, -113500,1.51224,3.654875,,,,,,,,,,,,,, -113600,1.5195411,3.1200316,,,,,,,,,,,,,, -113700,1.5174468,3.1923902,,,,,,,,,,,,,, -113800,1.3731619,3.5070999,,,,,,,,,,,,,, -113900,1.3687502,4.135688,,,,,,,,,,,,,, -114000,1.6407034,3.2919302,,,,,,,,,,,,,, -114100,1.5017002,3.1835268,,,,,,,,,,,,,, -114200,1.602122,3.1337245,,,,,,,,,,,,,, -114240,,,0.8056835532188416,0.9622820615768432,0.7326399683952332,1.2778819799423218,50000.0,0.6132000088691711,1.860758304595948,10000.0,52132.5546040535,56400.49531674385,52132.5546040535,4257.473162174225,4.599227428436279,0.0 -114300,1.5670536,3.1064887,,,,,,,,,,,,,, -114400,1.5224476,3.9427447,,,,,,,,,,,,,, -114500,1.424985,3.7440882,,,,,,,,,,,,,, -114600,1.5369178,3.1253917,,,,,,,,,,,,,, -114700,1.4125165,3.1131454,,,,,,,,,,,,,, -114800,1.4838065,3.124108,,,,,,,,,,,,,, -114900,1.4602879,3.1064227,,,,,,,,,,,,,, -115000,1.4965769,3.1899762,,,,,,,,,,,,,, -115100,1.5475082,3.1599264,,,,,,,,,,,,,, -115161,,,0.8109765648841858,0.9565854668617249,0.7332599759101868,1.2811367511749268,50000.0,0.6152000427246094,1.872671604156494,10000.0,52552.574162483215,56855.47197389603,52552.574162483215,4292.337894201279,4.644402265548706,0.0 -115200,1.5545844,3.1478992,,,,,,,,,,,,,, -115300,1.5454773,3.2234256,,,,,,,,,,,,,, -115400,1.6598995,3.0747972,,,,,,,,,,,,,, -115500,1.5361925,4.3734946,,,,,,,,,,,,,, -115600,1.5563463,3.1312273,,,,,,,,,,,,,, -115700,1.5484542,3.474884,,,,,,,,,,,,,, -115800,1.6170442,3.092875,,,,,,,,,,,,,, -115900,1.4456718,3.3586311,,,,,,,,,,,,,, -116000,1.5601552,4.328184,,,,,,,,,,,,,, -116082,,,0.8161913752555847,0.9610245823860168,0.7358399629592896,1.299484133720398,50000.0,0.6181000471115112,1.8877934217453003,10000.0,52972.78386044502,57310.55331468582,52972.78386044502,4327.119237422943,4.6875598430633545,0.0 -116100,1.6279455,3.2920842,,,,,,,,,,,,,, -116200,1.6069695,3.9587355,,,,,,,,,,,,,, -116300,1.5429963,3.4941096,,,,,,,,,,,,,, -116400,1.6851496,4.5021453,,,,,,,,,,,,,, -116500,1.4559281,3.1674418,,,,,,,,,,,,,, -116600,1.6180043,3.2335975,,,,,,,,,,,,,, -116700,1.4720168,3.9024658,,,,,,,,,,,,,, -116800,1.6051072,3.1090744,,,,,,,,,,,,,, -116900,1.6667128,4.428535,,,,,,,,,,,,,, -117000,1.5736108,3.1308544,,,,,,,,,,,,,, -117005,,,0.8302538990974426,0.8830304741859436,0.7373999953269958,1.2599748373031616,50000.0,0.613800048828125,1.8561517000198364,10000.0,53392.99377536774,57769.04441308975,53392.99377536774,4365.301455259323,4.738725185394287,0.0 -117100,1.6936722,4.0692477,,,,,,,,,,,,,, -117200,1.5542595,3.1823788,,,,,,,,,,,,,, -117300,1.6700575,3.1171668,,,,,,,,,,,,,, -117400,1.548895,3.0955052,,,,,,,,,,,,,, -117500,1.4636961,3.0380359,,,,,,,,,,,,,, -117600,1.5341396,3.056194,,,,,,,,,,,,,, -117700,1.5105157,3.700643,,,,,,,,,,,,,, -117800,1.7492113,4.5769453,,,,,,,,,,,,,, -117900,1.614379,4.098425,,,,,,,,,,,,,, -117924,,,0.8176171779632568,0.9192262291908264,0.7388799786567688,1.245247483253479,50000.0,0.6217000484466553,1.835737228393555,10000.0,53813.30598425865,58229.93274021149,53813.30598425865,4405.786436319351,4.782990217208862,0.0 -118000,1.6007133,3.9881225,,,,,,,,,,,,,, -118100,1.721039,4.7090454,,,,,,,,,,,,,, -118200,1.7382171,4.7116632,,,,,,,,,,,,,, -118300,1.707373,3.2103987,,,,,,,,,,,,,, -118400,1.6640303,4.354335,,,,,,,,,,,,,, -118500,1.5781612,4.224739,,,,,,,,,,,,,, -118600,1.5506134,3.7110171,,,,,,,,,,,,,, -118700,1.5349324,3.8542833,,,,,,,,,,,,,, -118800,1.6596007,3.2118478,,,,,,,,,,,,,, -118847,,,0.8188671469688416,0.9099279642105104,0.7407199740409851,1.2410085201263428,50000.0,0.6162000298500061,1.840397596359253,10000.0,54233.51160264015,58685.261761426926,54233.51160264015,4440.815090417862,4.830764055252075,0.0 -118900,1.522501,4.09352,,,,,,,,,,,,,, -119000,1.7431523,4.540595,,,,,,,,,,,,,, -119100,1.6861657,3.0795503,,,,,,,,,,,,,, -119200,1.6196548,3.0614684,,,,,,,,,,,,,, -119300,1.722341,4.411455,,,,,,,,,,,,,, -119400,1.5738277,3.1027606,,,,,,,,,,,,,, -119500,1.5759435,3.1074862,,,,,,,,,,,,,, -119600,1.5754458,3.3268151,,,,,,,,,,,,,, -119700,1.6231927,3.3894668,,,,,,,,,,,,,, -119772,,,0.8258007764816284,0.8815270662307739,0.7394799590110779,1.2439427375793457,50000.0,0.6149000525474548,1.845686197280884,10000.0,54653.70579338074,59141.99428868294,54653.70579338074,4477.259788036346,4.876175165176392,0.0 -119800,1.7267767,4.4723563,,,,,,,,,,,,,, -119900,1.8258138,4.693211,,,,,,,,,,,,,, -120000,1.5296344,3.2262738,,,,,,,,,,,,,, -120100,1.7363764,4.588852,,,,,,,,,,,,,, -120200,1.545062,3.2314184,,,,,,,,,,,,,, -120300,1.6077014,4.215634,,,,,,,,,,,,,, -120400,1.6598471,3.1043386,,,,,,,,,,,,,, -120500,1.674102,3.1236854,,,,,,,,,,,,,, -120600,1.6395428,4.319793,,,,,,,,,,,,,, -120694,,,0.81947261095047,0.9173012971878052,0.7415800094604492,1.251138687133789,50000.0,0.6217000484466553,1.8443632125854488,10000.0,55073.763964653015,59598.02819681168,55073.763964653015,4513.142883777618,4.922154426574707,0.0 -120700,1.5428382,3.591463,,,,,,,,,,,,,, -120800,1.7541606,4.4879827,,,,,,,,,,,,,, -120900,1.6061697,3.1994028,,,,,,,,,,,,,, -121000,1.6474872,3.0404682,,,,,,,,,,,,,, -121100,1.63071,3.4552424,,,,,,,,,,,,,, -121200,1.6743361,3.831251,,,,,,,,,,,,,, -121300,1.502,3.2613006,,,,,,,,,,,,,, -121400,1.60273,3.3628957,,,,,,,,,,,,,, -121500,1.5923513,3.4138024,,,,,,,,,,,,,, -121600,1.6965436,4.278521,,,,,,,,,,,,,, -121618,,,0.8227343559265137,0.90450918674469,0.7416200041770935,1.2397983074188232,50000.0,0.6201000213623047,1.8253264427185056,10000.0,55493.97620844841,60051.2044301033,55493.97620844841,4546.011871576309,4.970205307006836,0.0 -121700,1.5638504,3.938483,,,,,,,,,,,,,, -121800,1.8264232,3.2167017,,,,,,,,,,,,,, -121900,1.5310764,3.296349,,,,,,,,,,,,,, -122000,1.722107,3.1275194,,,,,,,,,,,,,, -122100,1.7673975,4.346209,,,,,,,,,,,,,, -122200,1.6432111,3.089512,,,,,,,,,,,,,, -122300,1.6292659,3.1093268,,,,,,,,,,,,,, -122400,1.6942807,3.026963,,,,,,,,,,,,,, -122500,1.8018011,4.623193,,,,,,,,,,,,,, -122542,,,0.828417956829071,0.8838768601417542,0.7440599799156189,1.24087393283844,50000.0,0.6277000308036804,1.8244028091430664,10000.0,55914.22261214256,60508.75056910515,55914.22261214256,4583.217739343643,5.016493797302246,0.0 -122600,1.6231678,3.6542585,,,,,,,,,,,,,, -122700,1.7094554,3.6760933,,,,,,,,,,,,,, -122800,1.7589976,4.6689262,,,,,,,,,,,,,, -122900,1.6686149,4.2750483,,,,,,,,,,,,,, -123000,1.9415714,4.6276455,,,,,,,,,,,,,, -123100,1.6501346,3.0698864,,,,,,,,,,,,,, -123200,1.5599576,3.3098664,,,,,,,,,,,,,, -123300,1.5465956,3.1100197,,,,,,,,,,,,,, -123400,1.6631451,4.044796,,,,,,,,,,,,,, -123465,,,0.8238476514816284,0.915237843990326,0.746399998664856,1.238528609275818,50000.0,0.6285000443458557,1.8137975931167605,10000.0,56334.52269983292,60965.69726276398,56334.52269983292,4619.772355079651,5.061162710189819,0.0 -123500,1.6372913,3.736506,,,,,,,,,,,,,, -123600,1.770843,3.055946,,,,,,,,,,,,,, -123700,1.7918699,4.5592537,,,,,,,,,,,,,, -123800,1.6581285,3.1993878,,,,,,,,,,,,,, -123900,1.6760329,3.4851587,,,,,,,,,,,,,, -124000,1.7831621,3.1038694,,,,,,,,,,,,,, -124100,2.0576055,4.7587347,,,,,,,,,,,,,, -124200,1.8189249,2.9916112,,,,,,,,,,,,,, -124300,1.9007936,4.6916723,,,,,,,,,,,,,, -124385,,,0.8286913633346558,0.886573851108551,0.7448999881744385,1.2313250303268433,50000.0,0.6270000338554382,1.8222373723983765,10000.0,56754.61694145203,61424.96513009071,56754.61694145203,4658.850819826126,5.109199523925781,0.0 -124400,1.5529538,3.2013278,,,,,,,,,,,,,, -124500,1.6121854,3.696527,,,,,,,,,,,,,, -124600,1.770241,3.046208,,,,,,,,,,,,,, -124700,1.5697082,3.8526073,,,,,,,,,,,,,, -124800,1.7155384,4.0713577,,,,,,,,,,,,,, -124900,1.975523,4.6587195,,,,,,,,,,,,,, -125000,1.5986153,3.9320955,,,,,,,,,,,,,, -125100,1.8787329,4.395163,,,,,,,,,,,,,, -125200,1.6506135,3.5870829,,,,,,,,,,,,,, -125300,1.643549,3.0847025,,,,,,,,,,,,,, -125305,,,0.8369140625,0.8582568168640137,0.7482799887657166,1.2224678993225098,50000.0,0.6314000487327576,1.801285982131958,10000.0,57174.62446951866,61878.66628551483,57174.62446951866,4692.450619220734,5.156161308288574,0.0 -125400,1.633889,3.3303552,,,,,,,,,,,,,, -125500,1.5954002,3.7050576,,,,,,,,,,,,,, -125600,1.6025505,3.0447803,,,,,,,,,,,,,, -125700,1.5476805,2.9919927,,,,,,,,,,,,,, -125800,1.5883315,3.762354,,,,,,,,,,,,,, -125900,1.666318,3.102449,,,,,,,,,,,,,, -126000,1.7337599,3.0746856,,,,,,,,,,,,,, -126100,1.7473599,2.967285,,,,,,,,,,,,,, -126200,1.6650813,3.3103256,,,,,,,,,,,,,, -126227,,,0.8360546827316284,0.8389551639556885,0.7470200061798096,1.2079110145568848,50000.0,0.6234000325202942,1.8007813692092896,10000.0,57594.86729764938,62332.49602270126,57594.86729764938,4725.944480419159,5.202749729156494,0.0 -126300,1.6677806,3.0193088,,,,,,,,,,,,,, -126400,1.6806021,3.0559893,,,,,,,,,,,,,, -126500,1.7211438,3.009273,,,,,,,,,,,,,, -126600,1.8559878,3.038641,,,,,,,,,,,,,, -126700,1.8559629,3.429801,,,,,,,,,,,,,, -126800,1.851881,3.038858,,,,,,,,,,,,,, -126900,1.7386675,3.0230532,,,,,,,,,,,,,, -127000,1.6504059,3.1909049,,,,,,,,,,,,,, -127100,1.7912728,3.1248856,,,,,,,,,,,,,, -127149,,,0.8318945169448853,0.8691099882125854,0.7507199645042419,1.216996669769287,50000.0,0.6300000548362732,1.7968522310256958,10000.0,58015.21964287758,62789.48238158226,58015.21964287758,4762.485838413239,5.24896764755249,0.0 -127200,1.766273,4.0280414,,,,,,,,,,,,,, -127300,1.765443,3.5304465,,,,,,,,,,,,,, -127400,1.7301997,3.0805519,,,,,,,,,,,,,, -127500,1.8644437,3.0632267,,,,,,,,,,,,,, -127600,1.8603038,3.1209824,,,,,,,,,,,,,, -127700,1.5955317,3.012334,,,,,,,,,,,,,, -127800,1.8410567,3.1714063,,,,,,,,,,,,,, -127900,1.9899395,4.3949327,,,,,,,,,,,,,, -128000,1.7646258,2.9972382,,,,,,,,,,,,,, -128072,,,0.8342577815055847,0.880061686038971,0.7497999668121338,1.2383127212524414,50000.0,0.6292000412940979,1.807980179786682,10000.0,58435.50190925598,63247.64047813416,58435.50190925598,4800.266010761261,5.297834873199463,0.0 -128100,1.7485539,3.0510633,,,,,,,,,,,,,, -128200,1.7702949,3.323426,,,,,,,,,,,,,, -128300,1.8130709,4.053529,,,,,,,,,,,,,, -128400,1.9638227,4.5942025,,,,,,,,,,,,,, -128500,1.6193764,3.7654386,,,,,,,,,,,,,, -128600,2.0270307,4.4964724,,,,,,,,,,,,,, -128700,1.6831409,3.1404123,,,,,,,,,,,,,, -128800,1.720758,3.1248457,,,,,,,,,,,,,, -128900,1.7960279,3.0352135,,,,,,,,,,,,,, -128994,,,0.8449023365974426,0.8151799440383911,0.7497599720954895,1.2033263444900513,50000.0,0.6323000192642212,1.784820795059204,10000.0,58855.55880713463,63711.22246456146,58855.55880713463,4843.690578460693,5.350980758666992,0.0 -129000,1.5757306,3.2593567,,,,,,,,,,,,,, -129100,1.7525845,3.3336647,,,,,,,,,,,,,, -129200,1.6516205,3.0059125,,,,,,,,,,,,,, -129300,1.6964972,2.9929638,,,,,,,,,,,,,, -129400,2.2866082,4.6110783,,,,,,,,,,,,,, -129500,1.958957,4.458524,,,,,,,,,,,,,, -129600,1.8042488,3.120391,,,,,,,,,,,,,, -129700,1.8961747,3.1160848,,,,,,,,,,,,,, -129800,1.7477382,3.645688,,,,,,,,,,,,,, -129900,1.6904665,3.01687,,,,,,,,,,,,,, -129916,,,0.8347851634025574,0.8594391942024231,0.7525799870491028,1.205952286720276,50000.0,0.6305000185966492,1.800089955329895,10000.0,59275.69723725319,64164.15061426163,59275.69723725319,4876.382496595383,5.40146017074585,0.0 -130000,1.8568621,3.1567786,,,,,,,,,,,,,, -130100,1.8789147,2.9583888,,,,,,,,,,,,,, -130200,1.856435,3.0456839,,,,,,,,,,,,,, -130300,1.6246952,3.698765,,,,,,,,,,,,,, -130400,1.8656081,3.5122259,,,,,,,,,,,,,, -130500,1.8445518,3.2868514,,,,,,,,,,,,,, -130600,1.7250609,3.179312,,,,,,,,,,,,,, -130700,1.8316466,4.1486034,,,,,,,,,,,,,, -130800,1.6482217,3.4747403,,,,,,,,,,,,,, -130838,,,0.8397656083106995,0.8464886546134949,0.753879964351654,1.204972743988037,50000.0,0.6305000185966492,1.7893548011779783,10000.0,59695.74828839302,64622.2776722908,59695.74828839302,4914.366876363754,5.4463982582092285,0.0 -130900,1.9534916,3.0648327,,,,,,,,,,,,,, -131000,1.8963509,3.0872564,,,,,,,,,,,,,, -131100,1.9750277,3.0869818,,,,,,,,,,,,,, -131200,2.1249468,4.658723,,,,,,,,,,,,,, -131300,1.6899279,3.510062,,,,,,,,,,,,,, -131400,1.9085039,4.1511297,,,,,,,,,,,,,, -131500,1.9187708,4.4473543,,,,,,,,,,,,,, -131600,1.8413348,3.0122633,,,,,,,,,,,,,, -131700,1.8605971,3.0147552,,,,,,,,,,,,,, -131759,,,0.8443359136581421,0.7996962666511536,0.751039981842041,1.184333324432373,50000.0,0.6345000267028809,1.7539469003677368,10000.0,60115.70867657661,65075.52677679062,60115.70867657661,4947.561500549316,5.493636131286621,0.0 -131800,2.4085433,4.522771,,,,,,,,,,,,,, -131900,2.1051536,4.552372,,,,,,,,,,,,,, -132000,1.8285013,3.7959588,,,,,,,,,,,,,, -132100,1.9124352,4.0802402,,,,,,,,,,,,,, -132200,2.0051484,3.1174872,,,,,,,,,,,,,, -132300,1.854903,3.1742303,,,,,,,,,,,,,, -132400,1.7529542,3.598244,,,,,,,,,,,,,, -132500,1.8547713,3.0668225,,,,,,,,,,,,,, -132600,1.8011595,3.1537158,,,,,,,,,,,,,, -132681,,,0.83753901720047,0.8382362127304077,0.7551400065422058,1.1913368701934814,50000.0,0.6374000310897827,1.769114375114441,10000.0,60535.7517850399,65531.48008418083,60535.7517850399,4983.380564451218,5.537555456161499,0.0 -132700,1.695389,3.5677278,,,,,,,,,,,,,, -132800,1.9791003,3.1891599,,,,,,,,,,,,,, -132900,1.7262173,2.9418964,,,,,,,,,,,,,, -133000,1.7602531,3.0997849,,,,,,,,,,,,,, -133100,1.9347231,3.7325401,,,,,,,,,,,,,, -133200,1.6206448,3.4210806,,,,,,,,,,,,,, -133300,1.8071965,3.5306373,,,,,,,,,,,,,, -133400,1.8644373,2.9871738,,,,,,,,,,,,,, -133500,1.8724561,2.9815052,,,,,,,,,,,,,, -133600,1.9067183,3.1721115,,,,,,,,,,,,,, -133603,,,0.84046870470047,0.8086775541305542,0.7542600035667419,1.1765190362930298,50000.0,0.638200044631958,1.7628462314605713,10000.0,60955.73228693008,65985.58762574196,60955.73228693008,5017.414732217789,5.583152770996094,0.0 -133700,1.8590033,3.100356,,,,,,,,,,,,,, -133800,1.8014991,2.945743,,,,,,,,,,,,,, -133900,1.8778727,3.0096562,,,,,,,,,,,,,, -134000,1.8041096,3.63139,,,,,,,,,,,,,, -134100,1.838211,3.442215,,,,,,,,,,,,,, -134200,1.9203671,2.983675,,,,,,,,,,,,,, -134300,1.7543048,3.9120736,,,,,,,,,,,,,, -134400,2.085171,4.2637315,,,,,,,,,,,,,, -134500,2.0270116,4.1723795,,,,,,,,,,,,,, -134525,,,0.8450585603713989,0.8082579374313354,0.7573599815368652,1.179269313812256,50000.0,0.6324000358581543,1.7707762718200684,10000.0,61375.7694671154,66443.38515496254,61375.7694671154,5055.079945325851,5.631529569625855,0.0 -134600,2.0047717,4.4200916,,,,,,,,,,,,,, -134700,2.3484957,4.499301,,,,,,,,,,,,,, -134800,2.060513,3.0149627,,,,,,,,,,,,,, -134900,1.7677484,3.3059533,,,,,,,,,,,,,, -135000,1.8839799,2.932055,,,,,,,,,,,,,, -135100,1.9008838,3.005765,,,,,,,,,,,,,, -135200,1.7145746,3.2630458,,,,,,,,,,,,,, -135300,1.887207,2.94878,,,,,,,,,,,,,, -135400,1.7843976,3.9889746,,,,,,,,,,,,,, -135415,,,0.8431445360183716,0.7971628308296204,0.7569400072097778,1.1679935455322266,50000.0,0.6370000243186951,1.7401326894760132,10000.0,61795.82298064232,66902.62777686119,61795.82298064232,5094.165184736252,5.689408540725708,0.0 -135500,1.8196243,3.2255075,,,,,,,,,,,,,, -135600,1.8256047,3.0295577,,,,,,,,,,,,,, -135700,1.9243873,3.495418,,,,,,,,,,,,,, -135800,1.8894607,3.035101,,,,,,,,,,,,,, -135900,1.8449244,2.9472377,,,,,,,,,,,,,, -136000,1.8151065,3.2021656,,,,,,,,,,,,,, -136100,1.845603,3.333908,,,,,,,,,,,,,, -136200,1.8985847,4.157062,,,,,,,,,,,,,, -136300,1.9748956,2.8680341,,,,,,,,,,,,,, -136337,,,0.8441015481948853,0.8239647746086121,0.7590799927711487,1.176174521446228,50000.0,0.6324000358581543,1.7639225721359253,10000.0,62215.76799035072,67359.71117639542,62215.76799035072,5131.206265687943,5.739239692687988,0.0 -136400,2.0325947,4.270748,,,,,,,,,,,,,, -136500,1.715589,3.214603,,,,,,,,,,,,,, -136600,1.7961228,3.2982898,,,,,,,,,,,,,, -136700,1.8706561,2.970972,,,,,,,,,,,,,, -136800,1.8890938,3.0063987,,,,,,,,,,,,,, -136900,2.1100698,4.1694646,,,,,,,,,,,,,, -137000,1.9931734,2.9788399,,,,,,,,,,,,,, -137100,1.9853717,2.9038134,,,,,,,,,,,,,, -137200,1.83045,3.3949962,,,,,,,,,,,,,, -137259,,,0.8472656011581421,0.8076205253601074,0.7579599618911743,1.180842041969299,50000.0,0.6370000243186951,1.7493723630905151,10000.0,62636.00538110733,67813.579870224,62636.00538110733,5164.7442235946655,5.78521990776062,0.0 -137300,1.9359723,3.8938832,,,,,,,,,,,,,, -137400,1.8373559,3.1296759,,,,,,,,,,,,,, -137500,1.8325521,2.9332793,,,,,,,,,,,,,, -137600,1.9895214,3.331432,,,,,,,,,,,,,, -137700,2.1749938,4.4470663,,,,,,,,,,,,,, -137800,2.0251637,3.0357943,,,,,,,,,,,,,, -137900,2.2500563,4.5529575,,,,,,,,,,,,,, -138000,1.7933382,3.6158004,,,,,,,,,,,,,, -138100,1.8393534,3.1655896,,,,,,,,,,,,,, -138180,,,0.8597851395606995,0.7563513517379761,0.7586399912834167,1.1718906164169312,50000.0,0.6336000561714172,1.7638990879058838,10000.0,63055.95963048935,68267.94333863258,63055.95963048935,5199.060210227966,5.831464767456055,0.0 -138200,2.5481102,4.420736,,,,,,,,,,,,,, -138300,1.7797939,3.7315984,,,,,,,,,,,,,, -138400,1.9404231,3.478787,,,,,,,,,,,,,, -138500,1.8931693,3.1404471,,,,,,,,,,,,,, -138600,1.831047,3.228733,,,,,,,,,,,,,, -138700,1.9248139,3.7818236,,,,,,,,,,,,,, -138800,1.7997168,3.1894894,,,,,,,,,,,,,, -138900,2.1092873,4.0826592,,,,,,,,,,,,,, -139000,1.9493921,2.984774,,,,,,,,,,,,,, -139100,1.9804091,3.4362726,,,,,,,,,,,,,, -139101,,,0.8470116853713989,0.8125196099281311,0.7589399814605713,1.179947018623352,50000.0,0.6370000243186951,1.7559922933578491,10000.0,63476.03571605682,68724.00987887383,63476.03571605682,5234.952075719833,5.88202166557312,0.0 -139200,1.9014235,3.0496888,,,,,,,,,,,,,, -139300,1.7389396,3.0980036,,,,,,,,,,,,,, -139400,1.7572854,2.996873,,,,,,,,,,,,,, -139500,2.5564022,4.276977,,,,,,,,,,,,,, -139600,1.9088213,3.0241005,,,,,,,,,,,,,, -139700,2.0627825,2.8980482,,,,,,,,,,,,,, -139800,1.7781183,3.4650922,,,,,,,,,,,,,, -139900,2.0065887,3.0417144,,,,,,,,,,,,,, -140000,1.9117699,2.9280953,,,,,,,,,,,,,, -140024,,,0.8527734279632568,0.758125364780426,0.762499988079071,1.136866569519043,50000.0,0.6468000411987305,1.713780164718628,10000.0,63895.9807267189,69180.01283836365,63895.9807267189,5270.915474653244,5.92934775352478,0.0 -140100,1.8959454,2.9766078,,,,,,,,,,,,,, -140200,2.1536293,3.0746155,,,,,,,,,,,,,, -140300,1.9707397,2.9380612,,,,,,,,,,,,,, -140400,1.9034071,3.0813322,,,,,,,,,,,,,, -140500,2.0117276,3.0092938,,,,,,,,,,,,,, -140600,1.9679717,3.3912992,,,,,,,,,,,,,, -140700,2.1651714,2.9276147,,,,,,,,,,,,,, -140800,2.0807202,3.1475224,,,,,,,,,,,,,, -140900,1.9879318,3.6623075,,,,,,,,,,,,,, -140945,,,0.8581054210662842,0.7595937848091125,0.7590199708938599,1.1628683805465698,50000.0,0.6384000182151794,1.7504186630249023,10000.0,64316.03085923195,69633.73905014992,64316.03085923195,5304.495651721954,5.978683710098267,0.0 -141000,2.0152218,2.9620354,,,,,,,,,,,,,, -141100,1.9776136,2.8732438,,,,,,,,,,,,,, -141200,1.9037154,3.2709389,,,,,,,,,,,,,, -141300,1.7934941,2.9306962,,,,,,,,,,,,,, -141400,1.9277718,3.7998288,,,,,,,,,,,,,, -141500,2.2138562,4.063796,,,,,,,,,,,,,, -141600,1.87964,2.964285,,,,,,,,,,,,,, -141700,2.0015981,2.9880042,,,,,,,,,,,,,, -141800,1.9060515,3.4582152,,,,,,,,,,,,,, -141867,,,0.8505077958106995,0.8011537790298462,0.7633000016212463,1.1656519174575806,50000.0,0.641800045967102,1.7470555305480957,10000.0,64736.08189225197,70087.59886169434,64736.08189225197,5338.206914901733,6.028635501861572,0.0 -141900,2.0463185,3.0346198,,,,,,,,,,,,,, -142000,1.9621521,3.3666136,,,,,,,,,,,,,, -142100,2.013887,3.0483043,,,,,,,,,,,,,, -142200,1.8821054,3.191496,,,,,,,,,,,,,, -142300,1.9336692,3.0974803,,,,,,,,,,,,,, -142400,2.5337133,4.4904304,,,,,,,,,,,,,, -142500,2.0952687,3.5779326,,,,,,,,,,,,,, -142600,1.904066,3.6123815,,,,,,,,,,,,,, -142700,2.0847836,2.934173,,,,,,,,,,,,,, -142788,,,0.8565039038658142,0.7577779293060303,0.7645599842071533,1.1408846378326416,50000.0,0.6448000073432922,1.7246521711349487,10000.0,65156.10931110382,70541.53916501999,65156.10931110382,5372.0228152275085,6.0786073207855225,0.0 -142800,1.9615163,2.9550405,,,,,,,,,,,,,, -142900,2.1613975,2.942612,,,,,,,,,,,,,, -143000,2.0078604,3.7847931,,,,,,,,,,,,,, -143100,2.3998811,4.4695187,,,,,,,,,,,,,, -143200,2.0690167,3.5135107,,,,,,,,,,,,,, -143300,4.4542613,4.2106457,,,,,,,,,,,,,, -143400,2.0469484,2.98196,,,,,,,,,,,,,, -143500,1.8963761,3.3034496,,,,,,,,,,,,,, -143600,1.8793839,3.295299,,,,,,,,,,,,,, -143700,1.9031544,3.2176254,,,,,,,,,,,,,, -143709,,,0.85986328125,0.7596256732940674,0.7628600001335144,1.166117787361145,50000.0,0.6426000595092773,1.7453765869140625,10000.0,65576.03127932549,70996.37029480934,65576.03127932549,5406.831308364868,6.132803916931152,0.0 -143800,2.80926,4.4604225,,,,,,,,,,,,,, -143900,2.1999626,4.026923,,,,,,,,,,,,,, -144000,2.1192083,2.9048653,,,,,,,,,,,,,, -144100,2.5077438,4.4823036,,,,,,,,,,,,,, -144200,2.2962224,3.6192644,,,,,,,,,,,,,, -144300,1.9555409,2.9178145,,,,,,,,,,,,,, -144400,1.9617703,3.1519284,,,,,,,,,,,,,, -144500,2.0159395,2.9143605,,,,,,,,,,,,,, -144600,1.997441,2.9187458,,,,,,,,,,,,,, -144631,,,0.8544726371765137,0.784826934337616,0.7655799984931946,1.1578985452651978,50000.0,0.6456000208854675,1.7437084913253784,10000.0,65996.08980154991,71449.93562602997,65996.08980154991,5440.239306926727,6.184715032577515,0.0 -144700,2.1464355,2.9925265,,,,,,,,,,,,,, -144800,2.1666274,4.1926184,,,,,,,,,,,,,, -144900,2.0756478,2.8755372,,,,,,,,,,,,,, -145000,2.3821092,4.1918254,,,,,,,,,,,,,, -145100,2.169463,4.057634,,,,,,,,,,,,,, -145200,2.1158717,3.6079988,,,,,,,,,,,,,, -145300,2.3665087,4.3588796,,,,,,,,,,,,,, -145400,2.0325685,3.1626375,,,,,,,,,,,,,, -145500,2.002168,3.7265568,,,,,,,,,,,,,, -145554,,,0.8596875071525574,0.7511771321296692,0.765999972820282,1.1414402723312378,50000.0,0.6479000449180603,1.710391640663147,10000.0,66416.06876826286,71906.65629696846,66416.06876826286,5476.885877370834,6.231912851333618,0.0 -145600,2.044241,2.848987,,,,,,,,,,,,,, -145700,2.262603,3.6713426,,,,,,,,,,,,,, -145800,1.9675871,3.2628436,,,,,,,,,,,,,, -145900,2.154926,2.879898,,,,,,,,,,,,,, -146000,2.1799896,4.1050673,,,,,,,,,,,,,, -146100,2.0113442,3.1132731,,,,,,,,,,,,,, -146200,2.1933994,2.9258442,,,,,,,,,,,,,, -146300,1.9586369,3.7560537,,,,,,,,,,,,,, -146400,2.0118632,3.1946006,,,,,,,,,,,,,, -146476,,,0.8626366853713989,0.7605917453765869,0.7663599848747253,1.1581281423568726,50000.0,0.650600016117096,1.7234017848968506,10000.0,66836.00802612305,72359.89109659195,66836.00802612305,5510.088440179825,6.2781524658203125,0.0 -146500,2.2512572,4.158165,,,,,,,,,,,,,, -146600,2.1902342,3.880086,,,,,,,,,,,,,, -146700,2.2785914,3.0900552,,,,,,,,,,,,,, -146800,2.4635274,4.3172765,,,,,,,,,,,,,, -146900,2.1045537,3.8203325,,,,,,,,,,,,,, -147000,2.0831485,2.9419434,,,,,,,,,,,,,, -147100,2.1150024,2.9001508,,,,,,,,,,,,,, -147200,2.3162737,3.0206892,,,,,,,,,,,,,, -147300,2.1572495,2.963922,,,,,,,,,,,,,, -147400,,,0.8684960603713989,0.7181340456008911,0.7680400013923645,1.130260705947876,50000.0,0.6456000208854675,1.714258074760437,10000.0,67255.94326424599,72815.77249288559,67255.94326424599,5545.940866231918,6.324589729309082,0.0 -147400,2.0794044,3.1063304,,,,,,,,,,,,,, -147500,2.4050064,2.9022725,,,,,,,,,,,,,, -147600,2.0305593,3.0486774,,,,,,,,,,,,,, -147700,2.3129468,2.9206145,,,,,,,,,,,,,, -147800,1.985787,3.3773158,,,,,,,,,,,,,, -147900,2.0225227,2.9204779,,,,,,,,,,,,,, -148000,2.2263186,3.6646063,,,,,,,,,,,,,, -148100,2.0648286,2.9696476,,,,,,,,,,,,,, -148200,2.1686354,2.9100616,,,,,,,,,,,,,, -148300,2.1876457,2.9832187,,,,,,,,,,,,,, -148314,,,0.8642382621765137,0.7342635989189148,0.7673199772834778,1.126471996307373,50000.0,0.6473000049591064,1.71936297416687,10000.0,67676.07908463478,73275.14289355278,67676.07908463478,5585.072806835175,6.381242990493774,0.0 -148400,2.1296325,2.935587,,,,,,,,,,,,,, -148500,2.1667051,2.941157,,,,,,,,,,,,,, -148600,2.324313,3.9638636,,,,,,,,,,,,,, -148700,2.2319624,2.8508344,,,,,,,,,,,,,, -148800,2.1093087,2.9393613,,,,,,,,,,,,,, -148900,2.07007,3.2497137,,,,,,,,,,,,,, -149000,2.873458,4.374428,,,,,,,,,,,,,, -149100,2.1822996,2.9198468,,,,,,,,,,,,,, -149200,2.2322812,2.9595413,,,,,,,,,,,,,, -149235,,,0.8697851300239563,0.7271069884300232,0.7689200043678284,1.1391234397888184,50000.0,0.6455000042915344,1.7286360263824463,10000.0,68096.35448336601,73732.74937844276,68096.35448336601,5622.308450460434,6.429650783538818,0.0 -149300,2.141792,2.8347845,,,,,,,,,,,,,, -149400,2.1368651,2.8563833,,,,,,,,,,,,,, -149500,2.1432045,2.9547894,,,,,,,,,,,,,, -149600,2.0236697,2.950423,,,,,,,,,,,,,, -149700,2.137198,2.9655674,,,,,,,,,,,,,, -149800,3.0477474,4.452085,,,,,,,,,,,,,, -149900,2.170561,4.025277,,,,,,,,,,,,,, -150000,2.224557,3.5959134,,,,,,,,,,,,,, -150100,2.171632,2.924532,,,,,,,,,,,,,, -150156,,,0.8738671541213989,0.7149843573570251,0.7697599530220032,1.141154170036316,50000.0,0.650600016117096,1.7134649753570557,10000.0,68516.60844302177,74196.96331691742,68516.60844302177,5666.171459674835,6.480257034301758,0.0 -150200,2.324991,4.0258245,,,,,,,,,,,,,, -150300,2.0404446,3.4210932,,,,,,,,,,,,,, -150400,2.0507221,3.064747,,,,,,,,,,,,,, -150500,2.272453,2.9727123,,,,,,,,,,,,,, -150600,2.1492755,3.1371455,,,,,,,,,,,,,, -150700,2.0929842,2.881487,,,,,,,,,,,,,, -150800,2.1711526,2.969695,,,,,,,,,,,,,, -150900,2.111899,2.989692,,,,,,,,,,,,,, -151000,2.2356992,3.130926,,,,,,,,,,,,,, -151078,,,0.8681445121765137,0.7366388440132141,0.7703999876976013,1.1382390260696411,50000.0,0.6499000191688538,1.7185348272323608,10000.0,68936.61577987671,74650.41727089882,68936.61577987671,5699.519348621368,6.532433748245239,0.0 -151100,1.99478,3.1016798,,,,,,,,,,,,,, -151200,2.8450255,4.372375,,,,,,,,,,,,,, -151300,2.4439335,3.8003678,,,,,,,,,,,,,, -151400,2.7052917,4.3166804,,,,,,,,,,,,,, -151500,2.6371446,4.3662815,,,,,,,,,,,,,, -151600,2.1003523,3.2960618,,,,,,,,,,,,,, -151700,2.2024522,2.9436038,,,,,,,,,,,,,, -151800,2.410607,2.9352186,,,,,,,,,,,,,, -151900,2.1205342,2.9966357,,,,,,,,,,,,,, -151997,,,0.8676952719688416,0.7144691348075867,0.7712999582290649,1.1138060092926023,50000.0,0.6539000272750854,1.6868129968643188,10000.0,69356.5645096302,75107.14363598824,69356.5645096302,5736.198989152908,6.583809614181519,0.0 -152000,2.2339566,2.963067,,,,,,,,,,,,,, -152100,2.1928532,2.8404021,,,,,,,,,,,,,, -152200,2.1884835,3.0721865,,,,,,,,,,,,,, -152300,2.1472821,2.8535225,,,,,,,,,,,,,, -152400,2.1661382,3.8444765,,,,,,,,,,,,,, -152500,2.266358,3.0141711,,,,,,,,,,,,,, -152600,2.2573419,2.900405,,,,,,,,,,,,,, -152700,2.454281,4.165727,,,,,,,,,,,,,, -152800,2.4064646,4.0785084,,,,,,,,,,,,,, -152900,2.2924552,2.8682632,,,,,,,,,,,,,, -152918,,,0.8729101419448853,0.6911082863807678,0.772879958152771,1.115715265274048,50000.0,0.6516000032424927,1.699005126953125,10000.0,69776.55529689789,75567.5617249012,69776.55529689789,5776.524785995483,6.637896299362183,0.0 -153000,2.313009,2.834438,,,,,,,,,,,,,, -153100,2.1133432,2.7984347,,,,,,,,,,,,,, -153200,2.2414534,2.8648057,,,,,,,,,,,,,, -153300,2.5637283,4.114462,,,,,,,,,,,,,, -153400,2.1845284,2.9034755,,,,,,,,,,,,,, -153500,2.2246814,3.281651,,,,,,,,,,,,,, -153600,2.204376,2.8697975,,,,,,,,,,,,,, -153700,2.3601289,2.8682363,,,,,,,,,,,,,, -153800,2.2179356,3.1653187,,,,,,,,,,,,,, -153837,,,0.8698632717132568,0.7052730321884155,0.7731999754905701,1.1103754043579102,50000.0,0.656000018119812,1.6826592683792114,10000.0,70196.77774477005,76025.97893810272,70196.77774477005,5814.613633155823,6.696936845779419,0.0 -153900,2.2272203,2.8602974,,,,,,,,,,,,,, -154000,2.3243723,2.865813,,,,,,,,,,,,,, -154100,2.171217,3.6171274,,,,,,,,,,,,,, -154200,2.284718,2.7779267,,,,,,,,,,,,,, -154300,2.247267,2.8925185,,,,,,,,,,,,,, -154400,2.5966763,4.0223255,,,,,,,,,,,,,, -154500,2.592504,3.9748337,,,,,,,,,,,,,, -154600,2.1783621,3.262608,,,,,,,,,,,,,, -154700,2.3479686,3.8806398,,,,,,,,,,,,,, -154759,,,0.8730077743530273,0.7196352481842041,0.7730000019073486,1.1262603998184204,50000.0,0.6591000556945801,1.7006280422210691,10000.0,70616.80341100693,76481.1249115467,70616.80341100693,5849.63468337059,6.748908281326294,0.0 -154800,2.1994796,3.2457995,,,,,,,,,,,,,, -154900,2.767321,4.400837,,,,,,,,,,,,,, -155000,2.2223654,2.7978501,,,,,,,,,,,,,, -155100,2.4674842,3.5198197,,,,,,,,,,,,,, -155200,2.2266572,2.9429824,,,,,,,,,,,,,, -155300,2.9451494,4.371481,,,,,,,,,,,,,, -155400,2.6177795,3.8578334,,,,,,,,,,,,,, -155500,2.268915,2.866261,,,,,,,,,,,,,, -155600,2.1448052,2.9409697,,,,,,,,,,,,,, -155678,,,0.8753125071525574,0.6879866719245911,0.773419976234436,1.109251618385315,50000.0,0.6586000323295593,1.6903423070907593,10000.0,71037.05806207657,76937.91496515274,71037.05806207657,5886.051543951035,6.808360576629639,0.0 -155700,2.1186461,2.9269066,,,,,,,,,,,,,, -155800,2.5142162,3.9414303,,,,,,,,,,,,,, -155900,2.2756023,3.1743524,,,,,,,,,,,,,, -156000,2.5206017,4.0747194,,,,,,,,,,,,,, -156100,2.1100497,3.044887,,,,,,,,,,,,,, -156200,2.4422772,2.8785992,,,,,,,,,,,,,, -156300,2.2825353,3.3122544,,,,,,,,,,,,,, -156400,2.2594745,2.7986844,,,,,,,,,,,,,, -156500,2.5958853,3.985038,,,,,,,,,,,,,, -156597,,,0.87353515625,0.7132282257080078,0.7741999626159668,1.1228692531585691,50000.0,0.6598000526428223,1.6865030527114868,10000.0,71457.10474681854,77390.49496340752,71457.10474681854,5918.482615470886,6.863985776901245,0.0 -156600,2.356819,3.7817538,,,,,,,,,,,,,, -156700,3.177384,4.373261,,,,,,,,,,,,,, -156800,2.242382,2.8456295,,,,,,,,,,,,,, -156900,2.2467208,3.5971243,,,,,,,,,,,,,, -157000,2.7928529,4.3659706,,,,,,,,,,,,,, -157100,2.1341677,2.7824008,,,,,,,,,,,,,, -157200,2.283542,2.851235,,,,,,,,,,,,,, -157300,2.4506025,3.8214698,,,,,,,,,,,,,, -157400,2.306774,3.3906376,,,,,,,,,,,,,, -157500,2.418132,3.4956918,,,,,,,,,,,,,, -157518,,,0.87548828125,0.706649124622345,0.774619996547699,1.1308605670928955,50000.0,0.6622000336647034,1.6971155405044556,10000.0,71877.31110310555,77843.94621014595,71877.31110310555,5951.625692844391,6.918470144271851,0.0 -157600,2.2307727,3.1227217,,,,,,,,,,,,,, -157700,2.3617144,3.7999754,,,,,,,,,,,,,, -157800,2.3015747,2.802449,,,,,,,,,,,,,, -157900,3.074599,4.3446093,,,,,,,,,,,,,, -158000,2.1312237,3.5473177,,,,,,,,,,,,,, -158100,2.3212912,2.9176934,,,,,,,,,,,,,, -158200,3.1683378,4.3202786,,,,,,,,,,,,,, -158300,2.502612,2.8631945,,,,,,,,,,,,,, -158400,2.604618,3.8244486,,,,,,,,,,,,,, -158438,,,0.8758202791213989,0.6879441738128662,0.7753399610519409,1.1054383516311646,50000.0,0.6605000495910645,1.6729251146316528,10000.0,72297.48312687874,78298.52132201195,72297.48312687874,5985.92941904068,6.971019268035889,0.0 -158500,2.9761062,4.3008165,,,,,,,,,,,,,, -158600,2.3091545,3.0504866,,,,,,,,,,,,,, -158700,2.2649152,2.9161992,,,,,,,,,,,,,, -158800,2.2934818,2.8067656,,,,,,,,,,,,,, -158900,2.4550807,2.8404942,,,,,,,,,,,,,, -159000,2.2117693,2.7878022,,,,,,,,,,,,,, -159100,2.612978,3.910553,,,,,,,,,,,,,, -159200,2.8042533,4.166006,,,,,,,,,,,,,, -159300,2.2867894,2.8527098,,,,,,,,,,,,,, -159357,,,0.8858202695846558,0.6633342504501343,0.7765199542045593,1.1071991920471191,50000.0,0.6556000113487244,1.6883338689804075,10000.0,72717.58166861534,78752.3706395626,72717.58166861534,6019.583042383194,7.0212483406066895,0.0 -159400,2.3974237,2.8901973,,,,,,,,,,,,,, -159500,2.356401,2.9091604,,,,,,,,,,,,,, -159600,2.3820343,2.8500829,,,,,,,,,,,,,, -159700,2.3971543,3.542486,,,,,,,,,,,,,, -159800,2.3407032,3.5682185,,,,,,,,,,,,,, -159900,2.2969635,2.8020859,,,,,,,,,,,,,, -160000,2.9527056,4.1309886,,,,,,,,,,,,,, -160100,2.3956997,2.8371692,,,,,,,,,,,,,, -160200,2.4763,2.8456497,,,,,,,,,,,,,, -160277,,,0.8758593797683716,0.6875196099281311,0.776479959487915,1.1003152132034302,50000.0,0.6580000519752502,1.683222770690918,10000.0,73137.55405020714,79208.81894946098,73137.55405020714,6055.962126255035,7.0706892013549805,0.0 -160300,2.4040678,2.8606758,,,,,,,,,,,,,, -160400,2.3666906,2.8391118,,,,,,,,,,,,,, -160500,2.3575907,2.8325963,,,,,,,,,,,,,, -160600,2.3915322,2.8773549,,,,,,,,,,,,,, -160700,2.4503636,2.8355033,,,,,,,,,,,,,, -160800,2.3591983,3.6666007,,,,,,,,,,,,,, -160900,2.2795482,3.2334602,,,,,,,,,,,,,, -161000,2.4800625,3.6805205,,,,,,,,,,,,,, -161100,2.3433735,2.9760952,,,,,,,,,,,,,, -161199,,,0.8811913728713989,0.671556293964386,0.7778199911117554,1.1007194519042969,50000.0,0.6577000021934509,1.6750051975250244,10000.0,73557.74100780487,79667.06095504761,73557.74100780487,6093.915406227112,7.124929904937744,0.0 -161200,2.246907,3.150389,,,,,,,,,,,,,, -161300,2.4885883,2.8317277,,,,,,,,,,,,,, -161400,2.5027213,3.56785,,,,,,,,,,,,,, -161500,3.6338468,3.9965525,,,,,,,,,,,,,, -161600,2.189453,2.795701,,,,,,,,,,,,,, -161700,2.3811316,3.3838906,,,,,,,,,,,,,, -161800,3.330427,4.3605604,,,,,,,,,,,,,, -161900,2.295885,3.2125006,,,,,,,,,,,,,, -162000,2.713604,3.9228997,,,,,,,,,,,,,, -162100,2.5428483,3.118927,,,,,,,,,,,,,, -162117,,,0.8856250047683716,0.6517072319984436,0.7775999903678894,1.0972752571105957,50000.0,0.6593000292778015,1.6683248281478882,10000.0,73978.00402450562,80116.75618052483,73978.00402450562,6123.2500858306885,7.175060510635376,0.0 -162200,2.2611055,2.8830776,,,,,,,,,,,,,, -162300,2.2206814,3.0629451,,,,,,,,,,,,,, -162400,2.2293217,2.7558415,,,,,,,,,,,,,, -162500,3.2995884,4.2818527,,,,,,,,,,,,,, -162600,2.168817,3.2295392,,,,,,,,,,,,,, -162700,3.0462334,4.039996,,,,,,,,,,,,,, -162800,2.2333891,3.2638214,,,,,,,,,,,,,, -162900,2.467107,2.8070574,,,,,,,,,,,,,, -163000,2.4489353,3.7979348,,,,,,,,,,,,,, -163035,,,0.8808007836341858,0.6864597797393799,0.7800799608230591,1.105568766593933,50000.0,0.6591000556945801,1.6852294206619265,10000.0,74398.24773645401,80573.08365154266,74398.24773645401,6159.228252887726,7.233767747879028,0.0 -163100,3.14136,4.375021,,,,,,,,,,,,,, -163200,2.4163198,2.835815,,,,,,,,,,,,,, -163300,2.398491,3.4755437,,,,,,,,,,,,,, -163400,2.4738343,3.104891,,,,,,,,,,,,,, -163500,2.4555721,2.7832785,,,,,,,,,,,,,, -163600,2.458449,2.765289,,,,,,,,,,,,,, -163700,2.3528197,2.9690523,,,,,,,,,,,,,, -163800,3.0804884,4.266973,,,,,,,,,,,,,, -163900,3.0483155,4.1705647,,,,,,,,,,,,,, -163955,,,0.8822265267372131,0.6664526462554932,0.7804200053215027,1.0926640033721924,50000.0,0.6650000214576721,1.66836416721344,10000.0,74818.47779989243,81033.72079610825,74818.47779989243,6199.534796953201,7.287166118621826,0.0 -164000,2.4025767,2.77366,,,,,,,,,,,,,, -164100,2.3081286,2.7727206,,,,,,,,,,,,,, -164200,2.8564508,3.0136347,,,,,,,,,,,,,, -164300,2.733731,3.8172243,,,,,,,,,,,,,, -164400,2.5093572,3.08218,,,,,,,,,,,,,, -164500,2.898437,3.7333677,,,,,,,,,,,,,, -164600,2.436965,2.7170198,,,,,,,,,,,,,, -164700,2.5557172,3.400181,,,,,,,,,,,,,, -164800,2.4817026,2.8296444,,,,,,,,,,,,,, -164877,,,0.88636714220047,0.653562068939209,0.7805599570274353,1.0905064344406128,50000.0,0.6598000526428223,1.6655395030975342,10000.0,75238.48033547401,81487.14139032364,75238.48033547401,6232.852535486221,7.340670824050903,0.0 -164900,2.4285643,3.217319,,,,,,,,,,,,,, -165000,2.3653595,2.8873339,,,,,,,,,,,,,, -165100,2.2967973,2.8917713,,,,,,,,,,,,,, -165200,2.663171,2.891579,,,,,,,,,,,,,, -165300,2.4728544,2.8114498,,,,,,,,,,,,,, -165400,2.864365,3.7089012,,,,,,,,,,,,,, -165500,2.666289,2.88446,,,,,,,,,,,,,, -165600,2.4942563,3.644628,,,,,,,,,,,,,, -165700,3.0020795,4.1678996,,,,,,,,,,,,,, -165798,,,0.8838866949081421,0.6820747256278992,0.7806800007820129,1.1070207357406616,50000.0,0.6610000133514404,1.6883653402328491,10000.0,75658.44529294968,81947.32384061813,75658.44529294968,6272.963407754898,7.399810791015625,0.0 -165800,2.6038191,3.134109,,,,,,,,,,,,,, -165900,2.5131946,2.7857685,,,,,,,,,,,,,, -166000,2.4383605,2.817039,,,,,,,,,,,,,, -166100,2.6511533,2.8287468,,,,,,,,,,,,,, -166200,2.3984323,3.4251027,,,,,,,,,,,,,, -166300,2.887312,3.477199,,,,,,,,,,,,,, -166400,2.4105296,3.0614033,,,,,,,,,,,,,, -166500,2.5698462,2.7171912,,,,,,,,,,,,,, -166600,2.9558434,4.194937,,,,,,,,,,,,,, -166700,2.3497875,2.8290238,,,,,,,,,,,,,, -166719,,,0.8874218463897705,0.640555739402771,0.7835800051689148,1.0742700099945068,50000.0,0.6648000478744507,1.6466233730316162,10000.0,76078.42911958694,82398.08100652695,76078.42911958694,6303.624835968018,7.464045286178589,0.0 -166800,2.5177906,2.8631756,,,,,,,,,,,,,, -166900,2.7023563,2.8032439,,,,,,,,,,,,,, -167000,2.528022,2.777349,,,,,,,,,,,,,, -167100,2.6195388,2.93769,,,,,,,,,,,,,, -167200,2.4185808,3.2156851,,,,,,,,,,,,,, -167300,2.879835,3.776742,,,,,,,,,,,,,, -167400,2.5341034,2.8074713,,,,,,,,,,,,,, -167500,2.7674017,3.6536355,,,,,,,,,,,,,, -167600,2.4177382,2.748902,,,,,,,,,,,,,, -167641,,,0.8880468606948853,0.6393408179283142,0.7837599515914917,1.0733433961868286,50000.0,0.6659000515937805,1.6449536085128784,10000.0,76498.71252512932,82852.74976229668,76498.71252512932,6337.89848613739,7.528738975524902,0.0 -167700,2.2890012,3.2663846,,,,,,,,,,,,,, -167800,2.5673165,2.819137,,,,,,,,,,,,,, -167900,2.272786,2.7144356,,,,,,,,,,,,,, -168000,2.4452236,2.671703,,,,,,,,,,,,,, -168100,2.380363,2.7827346,,,,,,,,,,,,,, -168200,2.5444312,2.7897518,,,,,,,,,,,,,, -168300,2.503409,3.2173643,,,,,,,,,,,,,, -168400,2.5697863,3.3461423,,,,,,,,,,,,,, -168500,2.3666403,3.360612,,,,,,,,,,,,,, -168563,,,0.8868163824081421,0.655565083026886,0.7830599546432495,1.089125394821167,50000.0,0.6633000373840332,1.6611697673797607,10000.0,76918.80696439743,83308.578540802,76918.80696439743,6373.534379959106,7.579162836074829,0.0 -168600,2.4907074,2.7456722,,,,,,,,,,,,,, -168700,2.4871442,2.7573,,,,,,,,,,,,,, -168800,2.6055784,2.763549,,,,,,,,,,,,,, -168900,2.6346893,2.7788315,,,,,,,,,,,,,, -169000,2.5243068,2.7769291,,,,,,,,,,,,,, -169100,2.733948,4.0721955,,,,,,,,,,,,,, -169200,2.6403983,2.9256608,,,,,,,,,,,,,, -169300,2.6943572,3.463873,,,,,,,,,,,,,, -169400,2.5149906,3.523783,,,,,,,,,,,,,, -169484,,,0.8880468606948853,0.6424294710159302,0.7831000089645386,1.0735559463500977,50000.0,0.6653000116348267,1.646395206451416,10000.0,77338.80617928505,83767.41523122787,77338.80617928505,6412.271682739258,7.63238000869751,0.0 -169500,2.6494994,2.83997,,,,,,,,,,,,,, -169600,2.6122143,2.81615,,,,,,,,,,,,,, -169700,2.4339097,2.7638793,,,,,,,,,,,,,, -169800,2.4877863,2.743116,,,,,,,,,,,,,, -169887,,,,,,,,,,,77520.2169020176,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/eval_measurements.csv deleted file mode 100644 index c6448588a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -27.38794589042664,0.0,36.40782308578491,1,0,36.40782308578491,0.0010000000474974,6.907756805419922,10000,63.7958824634552,0.0011914062779396,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -65.63604092597961,0.0173430442810058,456.3771412372589,862,0,456.3771412372589,0.0081000002101063,6.500799655914307,10000,522.0745017528534,0.0114257810637354,6.446063041687012,0.0111599992960691,6.462528705596924,50000 -103.42611718177795,0.0514211654663085,876.6863760948181,1781,0,876.6863760948181,0.0306000020354986,6.009640693664551,10000,980.255558013916,0.0377734377980232,5.867616653442383,0.0365999974310398,5.891701698303223,50000 -147.62753534317017,0.0841724872589111,1296.615249156952,2700,0,1296.615249156952,0.0462000034749507,5.637139320373535,10000,1444.4656417369845,0.0656054690480232,5.406094074249268,0.0608999989926815,5.445602416992188,50000 -184.84893035888672,0.1116495132446289,1716.6178047657013,3620,0,1716.6178047657013,0.0667000040411949,5.365293502807617,10000,1901.7651226520536,0.0945898443460464,5.085729122161865,0.0872199982404708,5.123086929321289,50000 -222.7896881103516,0.1356348991394043,2136.5751678943634,4540,0,2136.5751678943634,0.0945000052452087,5.031833648681641,10000,2359.7350487709045,0.1378320306539535,4.668841361999512,0.1275399923324585,4.732216835021973,50000 -257.5574834346771,0.1605908870697021,2556.6016159057617,5457,0,2556.6016159057617,0.1197000071406364,4.782550811767578,10000,2814.6012556552887,0.1705273389816284,4.367717742919922,0.1566199958324432,4.442732334136963,50000 -291.22222685813904,0.1891450881958007,2976.937946081161,6378,0,2976.937946081161,0.1504000127315521,4.456921100616455,10000,3268.678102016449,0.2243359386920929,3.935076713562012,0.201339989900589,4.064193248748779,50000 -331.77545142173767,0.2208545207977295,3396.901723384857,7295,0,3396.901723384857,0.1825000047683715,4.207826137542725,10000,3729.273754119873,0.25927734375,3.695412397384644,0.2415599972009658,3.787732601165772,50000 -367.687614440918,0.247636079788208,3817.1914489269257,8216,0,3817.1914489269257,0.2082000076770782,4.0174102783203125,10000,4185.551072597504,0.2934960722923279,3.44327449798584,0.2733799815177917,3.552814483642578,50000 -402.15560007095337,0.2722275257110595,4237.426783323288,9136,0,4237.426783323288,0.2312000095844268,3.8833281993865967,10000,4640.326997518539,0.3303906321525574,3.2213683128356934,0.2978399991989136,3.398451089859009,50000 -439.33963537216187,0.2994673252105713,4657.610203266144,10054,0,4657.610203266144,0.2563000023365021,3.715104103088379,10000,5097.769255399704,0.3563671708106994,3.065946578979492,0.3330000042915344,3.185922622680664,50000 -477.90121936798096,0.3332977294921875,5077.543553113937,10973,0,5077.543553113937,0.2735000252723694,3.545906782150269,10000,5556.3450927734375,0.3862695097923279,2.848928213119507,0.3571199774742126,2.993165969848633,50000 -512.5262434482574,0.3625199794769287,5497.630401134491,11892,0,5497.630401134491,0.287200003862381,3.504602193832397,10000,6011.133774995804,0.4043749868869781,2.7795512676239014,0.367279976606369,2.9624674320220947,50000 -549.9961397647858,0.389383316040039,5917.629370927811,12809,0,5917.629370927811,0.3069000244140625,3.374930620193481,10000,6468.676458835602,0.4223828017711639,2.665945291519165,0.3883399963378906,2.8275606632232666,50000 -588.2510459423065,0.4151310920715332,6337.654529333115,13726,0,6337.654529333115,0.3149000108242035,3.2909696102142334,10000,6927.02899646759,0.4422460794448852,2.522275447845459,0.4086399972438812,2.700598955154419,50000 -627.6393864154816,0.4406590461730957,6758.110645294189,14648,0,6758.110645294189,0.3217000067234039,3.2877392768859863,10000,7386.947074890137,0.445136696100235,2.541533708572388,0.4125999808311462,2.715203762054444,50000 -669.4438850879669,0.4714105129241943,7178.091734886169,15568,0,7178.091734886169,0.3345000147819519,3.1643946170806885,10000,7848.810657024383,0.4643945097923279,2.417165279388428,0.4288599789142608,2.588149070739746,50000 -708.1572668552399,0.4968502521514892,7598.284974575043,16487,0,7598.284974575043,0.3418000042438507,3.12949013710022,10000,8307.79007267952,0.4732031226158142,2.3529422283172607,0.4392599761486053,2.526226758956909,50000 -745.3528969287872,0.5254116058349609,8018.596554040909,17409,0,8018.596554040909,0.3455000221729278,3.113865852355957,10000,8765.373228549957,0.4809765517711639,2.311243772506714,0.4396799802780151,2.5017337799072266,50000 -783.8180267810822,0.551140308380127,8438.65920996666,18330,0,8438.65920996666,0.3544000089168548,3.041668176651001,10000,9223.974129915236,0.5154687166213989,2.13545823097229,0.4569000005722046,2.423603773117065,50000 -821.6687545776367,0.5784051418304443,8858.939247369766,19249,0,8858.939247369766,0.3618000149726867,3.014835834503174,10000,9682.179508447647,0.495410144329071,2.2226991653442383,0.4612599909305572,2.3966081142425537,50000 -854.8419258594513,0.6058785915374756,9279.185322284698,20170,0,9279.185322284698,0.359000027179718,3.015586376190185,10000,10135.67358827591,0.4988085925579071,2.2068932056427,0.4660399854183197,2.378439426422119,50000 -889.8293704986572,0.6350181102752686,9699.388586997986,21092,0,9699.388586997986,0.3703000247478485,2.963587045669556,10000,10590.940872192385,0.5242773294448853,2.095875024795532,0.4736199975013733,2.349881649017334,50000 -923.1459929943084,0.6662139892578125,10119.395952701569,22012,0,10119.395952701569,0.3778000175952911,2.893505096435547,10000,11044.343485832214,0.5155078172683716,2.124331474304199,0.4827999770641327,2.288004398345948,50000 -960.2646844387054,0.6987216472625732,10539.358990907667,22927,0,10539.358990907667,0.3838000297546386,2.8525102138519287,10000,11501.50481057167,0.5287500023841858,2.041822910308838,0.4889199733734131,2.242918968200684,50000 -999.9240214824677,0.7256793975830078,10959.397310972214,23845,0,10959.397310972214,0.387800008058548,2.8514115810394287,10000,11961.27671599388,0.5420507788658142,2.0036168098449707,0.5009599924087524,2.2223894596099854,50000 -1033.0191838741302,0.761568546295166,11379.342381954191,24770,0,11379.342381954191,0.3936000168323517,2.818950891494751,10000,12414.400235176086,0.5383593440055847,2.017052173614502,0.5021600127220154,2.189324855804444,50000 -1070.0217413902285,0.7992439270019531,11799.4931101799,25691,0,11799.4931101799,0.393200010061264,2.8182153701782227,10000,12871.638572454453,0.5432421565055847,1.992870092391968,0.5070199966430664,2.167407989501953,50000 -1107.7912635803225,0.8276638984680176,12219.802811145782,26613,0,12219.802811145782,0.3975000083446502,2.823049783706665,10000,13329.79423236847,0.549023449420929,1.963714599609375,0.5048800110816956,2.183801889419556,50000 -1145.3771076202393,0.859968900680542,12639.845764398577,27533,0,12639.845764398577,0.4047000110149383,2.7460947036743164,10000,13787.502056837082,0.564648449420929,1.8632910251617432,0.5128799676895142,2.0994911193847656,50000 -1181.059654712677,0.8942258358001709,13060.279699325562,28453,0,13060.279699325562,0.4094000160694122,2.7339088916778564,10000,14243.699850559236,0.5565234422683716,1.9159458875656128,0.5190799832344055,2.100757360458374,50000 -1215.350920677185,0.921989917755127,13480.380759000778,29344,0,13480.380759000778,0.4134000241756439,2.7119228839874268,10000,14698.166138410568,0.5629491806030273,1.8562678098678589,0.523580014705658,2.056525230407715,50000 -1253.140303850174,0.950770616531372,13900.379534959791,30266,0,13900.379534959791,0.4092000126838684,2.7224607467651367,10000,15156.030605793,0.5864452719688416,1.7774485349655151,0.5238800048828125,2.0740952491760254,50000 -1291.0224838256836,0.9816055297851562,14320.638298034668,31188,0,14320.638298034668,0.4148000180721283,2.696887731552124,10000,15614.249346971512,0.5709179639816284,1.872394561767578,0.5306400060653687,2.058422327041626,50000 -1328.2112641334534,1.0133063793182373,14740.856187820436,32109,0,14740.856187820436,0.4218000173568725,2.658211231231689,10000,16071.735245227814,0.5722070336341858,1.8319276571273804,0.537339985370636,2.008451223373413,50000 -1368.0250644683838,1.0439252853393557,15160.993295431135,33030,0,15160.993295431135,0.4265000224113464,2.6062676906585693,10000,16531.774827718735,0.591601550579071,1.7154150009155271,0.5344399809837341,1.969157099723816,50000 -1408.3668491840365,1.0768020153045654,15580.957021474838,33951,0,15580.957021474838,0.4296000301837921,2.6152634620666504,10000,16992.161110639572,0.580078125,1.7822751998901367,0.5445799827575684,1.9617338180541992,50000 -1445.8618338108065,1.105936050415039,16001.323278665544,34872,0,16001.323278665544,0.4246000349521637,2.62258243560791,10000,17450.09957075119,0.5824218392372131,1.7896679639816284,0.5446400046348572,1.968542456626892,50000 -1483.3568606376648,1.1460247039794922,16421.62156009674,35793,0,16421.62156009674,0.4243000149726867,2.617381811141968,10000,17907.980507850647,0.5892577767372131,1.7423663139343262,0.5414800047874451,1.9773989915847776,50000 -1522.7079238891602,1.1861801147460938,16841.98622250557,36716,0,16841.98622250557,0.4310000240802765,2.5739073753356934,10000,18367.78451514244,0.5894140601158142,1.7541520595550537,0.5504999756813049,1.9289462566375728,50000 -1557.9151480197906,1.2215921878814695,17262.115788459778,37638,0,17262.115788459778,0.4353000223636627,2.56376576423645,10000,18823.204505443573,0.5914062261581421,1.7134793996810913,0.5567799806594849,1.893664002418518,50000 -1594.8616213798523,1.2564432621002195,17682.302928209305,38560,0,17682.302928209305,0.4353000223636627,2.5917446613311768,10000,19280.42056465149,0.5970898270606995,1.721079707145691,0.5518199801445007,1.9465895891189573,50000 -1634.6089255809784,1.2855734825134275,18102.482609033585,39481,0,18102.482609033585,0.4412000179290771,2.564246654510498,10000,19740.42428445816,0.6230077743530273,1.6257997751235962,0.5562999844551086,1.9278334379196167,50000 -1672.908019542694,1.3188085556030271,18522.80424499512,40401,0,18522.80424499512,0.4462000131607055,2.5406343936920166,10000,20199.12549352646,0.5972656011581421,1.7070727348327637,0.5578399896621704,1.892319798469544,50000 -1711.6933376789093,1.3502240180969238,18942.92242288589,41322,0,18942.92242288589,0.4448000192642212,2.5214507579803467,10000,20658.107868433,0.6015819907188416,1.6829158067703247,0.557379961013794,1.892937421798706,50000 -1749.5593955516815,1.3858461380004885,19362.84196233749,42241,0,19362.84196233749,0.4435000121593475,2.519059419631958,10000,21115.976779937744,0.6206249594688416,1.5839260816574097,0.5624200105667114,1.8680510520935056,50000 -1789.3050591945648,1.420839786529541,19782.84850549698,43163,0,19782.84850549698,0.4525000154972076,2.4949381351470947,10000,21575.81156134605,0.6048827767372131,1.6705154180526731,0.5646600127220154,1.8540127277374268,50000 -1824.1077728271484,1.4556865692138672,20203.19352889061,44084,0,20203.19352889061,0.4494000077247619,2.488030433654785,10000,22031.04239463806,0.6102148294448853,1.6581255197525024,0.5626199841499329,1.85548996925354,50000 -1861.775664567948,1.4894487857818604,20623.2698571682,45006,0,20623.2698571682,0.4508000314235687,2.4984617233276367,10000,22488.867975711823,0.6237499713897705,1.5927554368972778,0.568619966506958,1.843924641609192,50000 -1897.01748919487,1.5201151371002195,21043.675313472748,45926,0,21043.675313472748,0.4462000131607055,2.526916027069092,10000,22944.59355187416,0.605175793170929,1.694035887718201,0.5661999583244324,1.873929023742676,50000 -1929.7773234844208,1.5505378246307373,21463.90951180458,46847,0,21463.90951180458,0.4472000300884247,2.5412986278533936,10000,23397.665759801865,0.6008007526397705,1.7134603261947632,0.5633000135421753,1.8950095176696773,50000 -1968.531141042709,1.584803342819214,21884.20440530777,47769,0,21884.20440530777,0.4527000188827514,2.5124051570892334,10000,23856.79626774788,0.6170703172683716,1.6229615211486816,0.5671600103378296,1.8515721559524536,50000 -2007.020524263382,1.6213884353637695,22304.507378816605,48691,0,22304.507378816605,0.4574000239372253,2.444520711898804,10000,24315.673028230667,0.6227929592132568,1.575631856918335,0.5725600123405457,1.808995008468628,50000 -2046.3978426456447,1.657334804534912,22724.5529756546,49612,0,22724.5529756546,0.4617000222206116,2.452442646026612,10000,24775.178783893585,0.6174609065055847,1.6112401485443115,0.573639988899231,1.8005505800247192,50000 -2080.936208486557,1.6951377391815186,23144.769639730453,50533,0,23144.769639730453,0.4663000106811523,2.434957981109619,10000,25230.020445346832,0.62060546875,1.6004736423492432,0.5769400000572205,1.7959643602371216,50000 -2120.237253427505,1.7319800853729248,23564.7326142788,51453,0,23564.7326142788,0.4559000134468078,2.4530246257781982,10000,25689.36802005768,0.6434960961341858,1.4826300144195557,0.5772199630737305,1.7871181964874268,50000 -2160.178407430649,1.7693891525268557,23984.924451828003,52375,0,23984.924451828003,0.4628000259399414,2.418676376342773,10000,26149.58564400673,0.6228905916213989,1.5904203653335571,0.5809999704360962,1.7867530584335327,50000 -2195.0370230674744,1.8044648170471191,24404.92098903656,53295,0,24404.92098903656,0.4629000127315521,2.41633677482605,10000,26604.52316379547,0.6255077719688416,1.5740492343902588,0.5805599689483643,1.7782132625579834,50000 -2230.295460224152,1.8441081047058103,24825.276747226715,54218,0,24825.276747226715,0.4669000208377838,2.433056354522705,10000,27060.22543120384,0.6314257383346558,1.5377583503723145,0.5790599584579468,1.7804241180419922,50000 -2268.2810649871826,1.8811235427856443,25245.21862053871,55138,0,25245.21862053871,0.4554000198841095,2.452510356903076,10000,27518.237367630005,0.6166015267372131,1.6154673099517822,0.5781399607658386,1.7982147932052612,50000 -2302.746966123581,2.2839317321777344,25665.102162361145,56056,0,25665.102162361145,0.4698000252246856,2.410242795944214,10000,27973.03744482994,0.6266992092132568,1.5533452033996582,0.5848999619483948,1.7521530389785769,50000 -2341.369970321656,2.32083511352539,26085.34273672104,56978,0,26085.34273672104,0.463200032711029,2.456258773803711,10000,28431.98593950272,0.6295312643051147,1.5653393268585205,0.5822399854660034,1.799446702003479,50000 -2380.102923631668,2.3555450439453125,26505.34909033776,57898,0,26505.34909033776,0.4683000147342682,2.4005134105682373,10000,28890.808502435684,0.6297656297683716,1.5668786764144895,0.5859599709510803,1.7582910060882568,50000 -2414.5152776241302,2.388858795166016,26925.5447409153,58817,0,26925.5447409153,0.4696000218391418,2.3985865116119385,10000,29345.4977645874,0.6303515434265137,1.5301530361175537,0.590399980545044,1.723675012588501,50000 -2453.3868992328644,2.431230306625366,27345.576851844788,59735,0,27345.576851844788,0.4668000340461731,2.381608486175537,10000,29804.49097251892,0.6366991996765137,1.510748267173767,0.590719997882843,1.7339236736297607,50000 -2494.3704164028168,2.467085599899292,27765.77965736389,60655,0,27765.77965736389,0.4716000258922577,2.3704307079315186,10000,30265.760466575623,0.661816418170929,1.413130760192871,0.5918799638748169,1.732790231704712,50000 -2527.9952044487,2.5055477619171143,28186.06973552704,61575,0,28186.06973552704,0.4678000211715698,2.399973630905152,10000,30719.761566877365,0.6309961080551147,1.5477447509765625,0.5902799963951111,1.7401505708694458,50000 -2567.8129115104675,2.54514741897583,28606.19539070129,62494,0,28606.19539070129,0.4692000150680542,2.3889834880828857,10000,31179.79212284088,0.6387304663658142,1.5181124210357666,0.588979959487915,1.735752820968628,50000 -2608.2360076904297,2.584966897964477,29026.439562797543,63414,0,29026.439562797543,0.4750000238418579,2.362293004989624,10000,31640.54657483101,0.6572265625,1.420607089996338,0.5945599675178528,1.7118746042251587,50000 -2647.1877439022064,2.622751951217652,29446.779339313507,64334,0,29446.779339313507,0.4806000292301178,2.3335988521575928,10000,32099.92324185372,0.6352148056030273,1.5222920179367063,0.5979799628257751,1.6968483924865725,50000 -2684.573234319687,2.658973693847656,29867.059475898743,65257,0,29867.059475898743,0.4757000207901001,2.3748252391815186,10000,32557.672868967056,0.6451562643051147,1.4754486083984375,0.5941999554634094,1.7001410722732544,50000 -2722.811019182205,2.700793743133545,30287.294951438904,66178,0,30287.294951438904,0.4808000326156616,2.354539155960083,10000,33016.23513197899,0.6478906273841858,1.469024658203125,0.5952199697494507,1.713046669960022,50000 -2764.954957485199,2.7407922744750977,30707.29386544228,67099,0,30707.29386544228,0.4843000173568725,2.3312244415283203,10000,33478.465988874435,0.6410937309265137,1.491320013999939,0.5977799892425537,1.6954699754714966,50000 -2800.633416414261,2.7796566486358643,31127.23282170296,68020,0,31127.23282170296,0.4786000251770019,2.350700855255127,10000,33934.169929265976,0.6380273103713989,1.4931342601776123,0.5981599688529968,1.6907784938812256,50000 -2841.955982208252,2.814785957336426,31547.49372458458,68942,0,31547.49372458458,0.4767000079154968,2.359503746032715,10000,34395.838116168976,0.6507812142372131,1.4588117599487305,0.5990399718284607,1.6879316568374634,50000 -2881.1043269634247,2.851184606552124,31967.64450263977,69864,0,31967.64450263977,0.482200026512146,2.3192484378814697,10000,34855.22176671028,0.654589831829071,1.434368371963501,0.6043199896812439,1.666023850440979,50000 -2911.928690671921,2.8854973316192627,32387.930511713028,70784,0,32387.930511713028,0.490200012922287,2.3093273639678955,10000,35306.4135351181,0.6522656083106995,1.430999517440796,0.6041399836540222,1.6534889936447144,50000 -2947.54905629158,2.9329519271850586,32808.145381212234,71705,0,32808.145381212234,0.4886000156402588,2.3075990676879883,10000,35762.34394454956,0.657031238079071,1.4209951162338257,0.6064199805259705,1.661571025848389,50000 -2984.2034389972687,2.973165988922119,33228.30744147301,72625,0,33228.30744147301,0.4834000170230865,2.3566348552703857,10000,36219.24794006348,0.6685937643051147,1.38480544090271,0.6049599647521973,1.6850557327270508,50000 -3023.486344099045,3.0142745971679688,33648.37844085693,73544,0,33648.37844085693,0.4855000376701355,2.3125555515289307,10000,36678.69011569023,0.6450976133346558,1.466753005981445,0.6068800091743469,1.662022590637207,50000 -3064.1061642169952,3.052863836288452,34068.383913517,74464,0,34068.383913517,0.4845000207424164,2.3070414066314697,10000,37139.40186858177,0.6503515243530273,1.4331283569335938,0.6075199842453003,1.6410701274871826,50000 -3099.742102622986,3.097740411758423,34488.60297369957,75386,0,34488.60297369957,0.4878000319004059,2.278662919998169,10000,37595.3496427536,0.6625195145606995,1.3859827518463137,0.6066799759864807,1.6424832344055176,50000 -3141.430042743683,3.137625217437744,34908.645033836365,76306,0,34908.645033836365,0.4958000183105469,2.2743124961853027,10000,38057.16771769524,0.6537500023841858,1.4322412014007568,0.6104399561882019,1.6254669427871704,50000 -3178.3272848129272,3.1766517162323,35328.58205103874,77227,0,35328.58205103874,0.484900027513504,2.277125358581543,10000,38514.08833122253,0.6592382788658142,1.3970091342926023,0.615399956703186,1.60317063331604,50000 -3220.056258201599,3.212582588195801,35748.85032916069,78150,0,35748.85032916069,0.4933000206947326,2.281214952468872,10000,38976.1696164608,0.6634374856948853,1.390366792678833,0.6144999861717224,1.6266072988510132,50000 -3259.250541448593,3.260847330093384,36169.0088224411,79071,0,36169.0088224411,0.488500028848648,2.2790701389312744,10000,39435.61829423904,0.6562694907188416,1.445751428604126,0.614139974117279,1.6321967840194702,50000 -3294.8051438331604,3.3025565147399902,36589.23210167885,79992,0,36589.23210167885,0.492900013923645,2.28087854385376,10000,39891.485827207565,0.6576757431030273,1.4273301362991333,0.6126799583435059,1.6328654289245603,50000 -3334.398421525955,3.347494840621948,37009.35448670387,80913,0,37009.35448670387,0.4927000105381012,2.2475414276123047,10000,40351.29439616203,0.674023449420929,1.3578218221664429,0.6184200048446655,1.6004390716552734,50000 -3372.996278524399,3.3874971866607666,37429.34026837349,81833,0,37429.34026837349,0.4976000189781189,2.232736349105835,10000,40809.9657497406,0.6919335722923279,1.269242525100708,0.6193400025367737,1.5890880823135376,50000 -3410.873600244522,3.429560661315918,37849.2662627697,82750,0,37849.2662627697,0.4955000281333923,2.2633657455444336,10000,41267.85855412483,0.6674609184265137,1.4016146659851074,0.6218599677085876,1.6021045446395874,50000 -3449.7180716991425,3.4665777683258057,38269.307027578354,83672,0,38269.307027578354,0.5044000148773193,2.198859214782715,10000,41726.82884001732,0.6716406345367432,1.3420121669769287,0.6239799857139587,1.5581839084625244,50000 -3482.425269842148,3.50958251953125,38689.40639066696,84591,0,38689.40639066696,0.5059000253677368,2.2042882442474365,10000,42179.727041482925,0.6886523365974426,1.2715094089508057,0.6290000081062317,1.5482810735702517,50000 -3523.0634427070618,3.547616958618164,39109.69956469536,85512,0,39109.69956469536,0.5028000473976135,2.2121453285217285,10000,42640.74392175674,0.6723827719688416,1.3585656881332395,0.6247400045394897,1.566789627075195,50000 -3560.707655906677,3.589759588241577,39529.97727918625,86433,0,39529.97727918625,0.4973000288009643,2.2271392345428467,10000,43098.75600481033,0.6749609112739563,1.341333031654358,0.6245799660682678,1.5732306241989136,50000 -3598.127161026001,3.6342861652374254,39949.916645765305,87354,0,39949.916645765305,0.4999000132083893,2.2489187717437744,10000,43556.20781803131,0.6795117259025574,1.3415911197662354,0.6220999956130981,1.5983110666275024,50000 -3636.217653512954,3.6755335330963135,40369.94181585312,88275,0,40369.94181585312,0.5071000456809998,2.2031280994415283,10000,44014.411709070206,0.6721289157867432,1.3381081819534302,0.62909996509552,1.537670016288757,50000 -3675.636269807816,3.71611762046814,40790.010445833206,89194,0,40790.010445833206,0.5053000450134277,2.1833741664886475,10000,44473.98690462112,0.6786718368530273,1.3135088682174685,0.6307799816131592,1.5366266965866089,50000 -3714.726364850998,3.755250453948975,41210.27362036705,90112,0,41210.27362036705,0.510200023651123,2.1635007858276367,10000,44933.42651605606,0.6904687285423279,1.2547414302825928,0.6344999670982361,1.5105417966842651,50000 -3755.1138138771057,3.793323993682861,41630.5404984951,91032,0,41630.5404984951,0.5189000368118286,2.1651880741119385,10000,45394.16607880592,0.6882616877555847,1.284019112586975,0.637939989566803,1.5183864831924438,50000 -3789.587110757828,3.8358139991760254,42050.500351428986,91951,0,42050.500351428986,0.5070000290870667,2.2034645080566406,10000,45848.689403772354,0.6794726252555847,1.3359767198562622,0.6306799650192261,1.5554149150848389,50000 -3826.5279400348654,3.875086069107056,42470.74870181084,92871,0,42470.74870181084,0.5109000205993652,2.15106463432312,10000,46305.96616792679,0.6898437142372131,1.2812461853027344,0.6362199783325195,1.5131189823150637,50000 -3865.9597566127777,3.9189999103546143,42890.90657186508,93790,0,42890.90657186508,0.5225000381469727,2.150918960571289,10000,46765.64679956436,0.7091601490974426,1.1998614072799685,0.6394400000572205,1.507826566696167,50000 -3902.16099691391,3.960517644882202,43310.93693423271,94711,0,43310.93693423271,0.5175999999046326,2.1582961082458496,10000,47221.96811914444,0.6896093487739563,1.2857972383499146,0.6369999647140503,1.5179483890533447,50000 -3942.607335329056,4.002686023712158,43731.27998661995,95632,0,43731.27998661995,0.5159000158309937,2.168625831604004,10000,47682.84707713127,0.6902929544448853,1.3088555335998535,0.6377800107002258,1.5345081090927124,50000 -3981.583114147186,4.047634124755859,44151.53822731972,96552,0,44151.53822731972,0.5157999992370605,2.130087614059448,10000,48142.17348623276,0.7011132836341858,1.2075631618499756,0.6449399590492249,1.4762593507766724,50000 -4019.17895770073,4.093117952346802,44571.59751033783,97473,0,44571.59751033783,0.5208000540733337,2.130515813827514,10000,48599.9216735363,0.6884570121765137,1.2699761390686035,0.6421399712562561,1.4902182817459106,50000 -4060.374292612076,4.132027626037598,44991.82571578026,98393,0,44991.82571578026,0.5187000036239624,2.157788753509521,10000,49061.43175268173,0.691699206829071,1.2806599140167236,0.6419199705123901,1.4992659091949463,50000 -4100.273753166199,4.570847034454346,45411.43276429176,99312,0,45411.43276429176,0.5258000493049622,2.101590633392334,10000,49521.42532157898,0.7009570002555847,1.2121543884277344,0.6444999575614929,1.4744832515716553,50000 -4135.854706764221,4.6116437911987305,45831.65687251091,100232,0,45831.65687251091,0.5182000398635864,2.115975856781006,10000,49977.31903076172,0.6933984160423279,1.2418538331985474,0.6492399573326111,1.4507449865341189,50000 -4173.125863075256,4.663263559341431,46251.6263062954,101152,0,46251.6263062954,0.5205000042915344,2.1192281246185303,10000,50434.65885519981,0.6918163895606995,1.2562538385391235,0.6451799869537354,1.4713010787963867,50000 -4212.254663944244,4.7148377895355225,46672.01550674439,102069,0,46672.01550674439,0.5216000080108643,2.0956971645355225,10000,50894.275889635086,0.7005273103713989,1.223625421524048,0.64656001329422,1.4615575075149536,50000 -4252.351578950882,4.760955810546875,47092.12380337715,102986,0,47092.12380337715,0.5217000246047974,2.1090574264526367,10000,51354.57412791252,0.7196484208106995,1.1426352262496948,0.6477000117301941,1.4675030708312988,50000 -4292.62885594368,4.802370309829712,47512.29468727112,103907,0,47512.29468727112,0.5212000012397766,2.102874994277954,10000,51815.11089801788,0.7007421851158142,1.230646014213562,0.6519399881362915,1.450840711593628,50000 -4329.373802185059,4.848757028579712,47932.29152727127,104826,0,47932.29152727127,0.5300000309944153,2.06282377243042,10000,52271.9458630085,0.7058984041213989,1.1845290660858154,0.6528800129890442,1.4384924173355105,50000 -4368.735318899155,4.893977880477905,48352.20537209511,105745,0,48352.20537209511,0.528700053691864,2.088734865188598,10000,52731.31432533264,0.71728515625,1.1596193313598633,0.6541799902915955,1.4440714120864868,50000 -4404.980627298355,4.941775798797607,48772.65300059319,106667,0,48772.65300059319,0.5254999995231628,2.090418100357056,10000,53188.10340118408,0.7004296779632568,1.221923828125,0.6500799655914307,1.4510246515274048,50000 -4440.104451656342,4.982898235321045,49192.90219426155,107589,0,49192.90219426155,0.5357000231742859,2.077857255935669,10000,53643.56610870361,0.7095702886581421,1.1786792278289795,0.655239999294281,1.4211721420288086,50000 -4478.299695491791,5.023855924606323,49612.81418466568,108510,0,49612.81418466568,0.5354000329971313,2.0454037189483643,10000,54101.761887550354,0.7232226133346558,1.1331206560134888,0.6595199704170227,1.4153560400009155,50000 -4517.4173810482025,5.06796669960022,50032.897715091705,109430,0,50032.897715091705,0.5306000113487244,2.0763421058654785,10000,54561.05555176735,0.7046093344688416,1.207485914230347,0.6543200016021729,1.4327102899551392,50000 -4555.2443215847015,5.109800100326538,50452.97357225418,110350,0,50452.97357225418,0.5335000157356262,2.0360000133514404,10000,55019.047716379166,0.711621105670929,1.1574349403381348,0.6602199673652649,1.392660140991211,50000 -4594.720537662506,5.160179615020752,50873.10785269737,111270,0,50873.10785269737,0.5386000275611877,2.0380427837371826,10000,55478.7564125061,0.7218554615974426,1.139596462249756,0.6591399908065796,1.4095954895019531,50000 -4634.864460945129,5.204523324966431,51293.37358379364,112189,0,51293.37358379364,0.5388000011444092,2.0063302516937256,10000,55939.25805449486,0.72132807970047,1.1169127225875854,0.6618199944496155,1.3724278211593628,50000 -4674.825638771057,5.251617193222046,51713.63455915451,113109,0,51713.63455915451,0.5427000522613525,2.011446714401245,10000,56399.57521724701,0.7181445360183716,1.1483697891235352,0.6665199995040894,1.376614332199097,50000 -4715.29369020462,5.293383836746216,52133.7013399601,114029,0,52133.7013399601,0.5432000160217285,2.00452733039856,10000,56860.199070215225,0.7268164157867432,1.0959932804107666,0.6672999858856201,1.353961706161499,50000 -4752.462374687195,5.3358001708984375,52553.64809894562,114950,0,52553.64809894562,0.5387000441551208,2.023356914520264,10000,57317.40486860275,0.7395898103713989,1.0591590404510498,0.6657999753952026,1.3845274448394775,50000 -4790.233314990997,5.381194114685059,52973.94215083122,115872,0,52973.94215083122,0.5461000204086304,1.9749292135238647,10000,57775.563047885895,0.72279292345047,1.1121561527252195,0.670799970626831,1.3483924865722656,50000 -4828.128947257996,5.423955678939819,53393.94225502014,116791,0,53393.94225502014,0.542900025844574,2.02569580078125,10000,58233.54987645149,0.7229296565055847,1.1437608003616333,0.6652799844741821,1.4038317203521729,50000 -4865.206468343735,5.476790428161621,53813.971108198166,117710,0,53813.971108198166,0.5506000518798828,1.951604962348938,10000,58690.75763726234,0.741406261920929,1.021639108657837,0.6748600006103516,1.3188023567199707,50000 -4902.387039661408,5.5232415199279785,54233.92829823494,118631,0,54233.92829823494,0.5515000224113464,1.998087406158448,10000,59147.98989415169,0.7268164157867432,1.1247150897979736,0.6675800085067749,1.379623532295227,50000 -4941.857255935669,5.566040754318237,54653.83897304535,119551,0,54653.83897304535,0.5533000230789185,1.9669947624206543,10000,59607.46091222763,0.7312304377555847,1.0820475816726685,0.6734399795532227,1.340196967124939,50000 -4979.9430141448975,5.610635757446289,55074.14640974999,120472,0,55074.14640974999,0.5540000200271606,1.9518202543258667,10000,60065.947801828384,0.7416015267372131,1.0313720703125,0.6787199974060059,1.3119388818740845,50000 -5019.839179754257,5.657719373703003,55494.27367591858,121393,0,55494.27367591858,0.5552000403404236,1.934516668319702,10000,60526.06591033936,0.7378906011581421,1.0619568824768066,0.6817599534988403,1.3020442724227903,50000 -5058.548875808716,5.700714349746704,55914.204786777496,122313,0,55914.204786777496,0.54830002784729,1.960036039352417,10000,60984.79734659195,0.73011714220047,1.0760235786437988,0.6768199801445007,1.3135195970535278,50000 -5094.679251432419,5.7470362186431885,56334.31851649284,123234,0,56334.31851649284,0.5596000552177429,1.9329463243484497,10000,61441.13616466522,0.7426952719688416,1.0379911661148071,0.6814799904823303,1.3065279722213743,50000 -5134.01913523674,5.790019989013672,56754.34809017181,124154,0,56754.34809017181,0.562000036239624,1.938615322113037,10000,61900.59717607498,0.7546288967132568,0.995815634727478,0.6815999746322632,1.3149043321609497,50000 -5174.269753456116,5.836784839630127,57174.4171538353,125075,0,57174.4171538353,0.5525000095367432,1.9468294382095337,10000,62361.01248407364,0.7390820384025574,1.068745493888855,0.6852999925613403,1.3048429489135742,50000 -5209.269439458847,5.885339975357056,57594.61695051193,125998,0,57594.61695051193,0.5623000264167786,1.9115575551986688,10000,62816.308730363846,0.7458398342132568,1.0124694108963013,0.6858400106430054,1.2887572050094604,50000 -5247.848149061203,5.937408208847046,58014.5237903595,126918,0,58014.5237903595,0.5681000351905823,1.88283109664917,10000,63274.89465546608,0.7591796517372131,0.9492820501327516,0.6929000020027161,1.2491334676742554,50000 -5287.201727390289,5.98337721824646,58434.72530388832,127838,0,58434.72530388832,0.5680000185966492,1.907129049301148,10000,63734.5433690548,0.7441796660423279,1.0422630310058594,0.6882599592208862,1.2880897521972656,50000 -5327.26294875145,6.034414768218994,58854.81752896309,128759,0,58854.81752896309,0.5669000148773193,1.8810703754425049,10000,64194.79561638832,0.74853515625,1.0025569200515747,0.6908800005912781,1.2619134187698364,50000 -5363.8732233047485,6.085147142410278,59275.05393505096,129680,0,59275.05393505096,0.5715000033378601,1.8652161359786987,10000,64651.74168539047,0.7611327767372131,0.9462226629257202,0.6943999528884888,1.2438348531723022,50000 -5400.780389070511,6.132253170013428,59695.36253666878,130601,0,59695.36253666878,0.5756000280380249,1.8345760107040403,10000,65109.05203604698,0.7528710961341858,0.9790438413619996,0.6981199979782104,1.2228922843933103,50000 -5438.998530864716,6.179803848266602,60115.69743990898,131524,0,60115.69743990898,0.5743000507354736,1.8481074571609497,10000,65567.70094275475,0.7593945264816284,0.9528237581253052,0.695639967918396,1.2331786155700684,50000 -5478.9129366874695,6.230070352554321,60535.69753456116,132444,0,60535.69753456116,0.5740000009536743,1.8517664670944207,10000,66027.71335840225,0.7639452815055847,0.9388325214385986,0.6972399950027466,1.2304061651229858,50000 -5516.357885599136,6.277270317077637,60955.82214999199,133364,0,60955.82214999199,0.5734000205993652,1.842565536499024,10000,66485.37777233124,0.7642187476158142,0.9470322132110596,0.7019000053405762,1.2145510911941528,50000 -5555.144359827042,6.326719284057617,61376.02754378319,134284,0,61376.02754378319,0.5781000256538391,1.8447034358978271,10000,66944.4675412178,0.7623632550239563,0.9518302083015442,0.7013999819755554,1.2288814783096311,50000 -5592.194238185883,6.378962516784668,61796.07824969292,135201,0,61796.07824969292,0.5804000496864319,1.818007469177246,10000,67401.66700196266,0.7685546875,0.9147396087646484,0.7045199871063232,1.204154372215271,50000 -5630.725626945496,6.431999683380127,62216.07617998123,136122,0,62216.07617998123,0.5815000534057617,1.8282631635665887,10000,67860.29677629471,0.7792773246765137,0.8840520977973938,0.7030199766159058,1.2064038515090942,50000 -5671.676826477051,6.483324766159058,62636.41758394241,137042,0,62636.41758394241,0.5777000188827515,1.8337697982788088,10000,68321.68812680244,0.765625,0.9307951331138612,0.7046799659729004,1.205407738685608,50000 -5708.94930100441,6.528689861297607,63056.63118767738,137962,0,63056.63118767738,0.5809000134468079,1.822838544845581,10000,68779.26647734642,0.7728710770606995,0.9098615050315856,0.7051399946212769,1.195683240890503,50000 -5749.549212932587,6.578404903411865,63476.66384673119,138884,0,63476.66384673119,0.5842000246047974,1.8043678998947144,10000,69239.99764561653,0.7806445360183716,0.8815276026725769,0.7076799869537354,1.1955801248550415,50000 -5790.326649427414,6.624654054641724,63896.58831167221,139805,0,63896.58831167221,0.5848000049591064,1.822811365127564,10000,69700.79360175133,0.771289050579071,0.9243224859237672,0.7076199650764465,1.2005985975265503,50000 -5828.982895612717,6.679133176803589,64316.83931660652,140726,0,64316.83931660652,0.5871000289916992,1.7759684324264526,10000,70159.80298304558,0.777148425579071,0.8691675662994385,0.7116000056266785,1.1573071479797363,50000 -5869.657928228378,6.726979732513428,64737.20595860481,141648,0,64737.20595860481,0.5850000381469727,1.806455373764038,10000,70620.941873312,0.7853710651397705,0.8723379969596863,0.7133600115776062,1.185519099235535,50000 -5908.061641693115,7.164468288421631,65157.24151563645,142567,0,65157.24151563645,0.5960000157356262,1.7515387535095217,10000,71079.86623930931,0.7803124785423279,0.8587034344673157,0.7161999940872192,1.133145093917847,50000 -5945.656690597534,7.214048862457275,65577.55893421173,143486,0,65577.55893421173,0.5928000211715698,1.7593907117843628,10000,71537.87617897987,0.7825976610183716,0.8492751121520996,0.7155599594116211,1.143829345703125,50000 -5986.8613493442535,7.263937711715698,65997.90977883339,144404,0,65997.90977883339,0.5980000495910645,1.7610589265823364,10000,71999.52923107147,0.7892382740974426,0.8360196352005005,0.7186799645423889,1.148780107498169,50000 -6024.559996366501,7.30958104133606,66417.887434721,145324,0,66417.887434721,0.5958000421524048,1.755782127380371,10000,72457.2994647026,0.796191394329071,0.796824038028717,0.7212199568748474,1.1199946403503418,50000 -6063.822338104248,7.380070686340332,66837.9214565754,146245,0,66837.9214565754,0.6011000275611877,1.7424756288528442,10000,72916.71459913254,0.7898046970367432,0.825894832611084,0.7240999937057495,1.1160138845443726,50000 -6104.866254329681,7.430980443954468,67257.93815946579,147163,0,67257.93815946579,0.5966000556945801,1.7341212034225464,10000,73377.8737001419,0.7924023270606995,0.8204847574234009,0.7229200005531311,1.114357590675354,50000 -6146.32323050499,7.481757164001465,67677.93710446358,148082,0,67677.93710446358,0.6007000207901001,1.7311818599700928,10000,73839.43036198616,0.7983788847923279,0.792547345161438,0.7223399877548218,1.120306372642517,50000 -6187.508858203888,7.528716564178467,68097.8760201931,149000,0,68097.8760201931,0.6040000319480896,1.7259600162506104,10000,74300.64919734001,0.7928906083106995,0.8123624920845032,0.7275399565696716,1.103825926780701,50000 -6222.238587379456,7.587629318237305,68518.12698411942,149923,0,68518.12698411942,0.6043000221252441,1.7013198137283323,10000,74755.73614430428,0.7988671660423279,0.78203284740448,0.7282399535179138,1.0843335390090942,50000 -6265.426444530487,7.636116981506348,68938.50674057007,150845,0,68938.50674057007,0.6041000485420227,1.6799514293670654,10000,75219.40028524399,0.8016406297683716,0.7594988346099854,0.7286399602890015,1.0782972574234009,50000 -6309.68102478981,7.684762239456177,69358.68147015572,151767,0,69358.68147015572,0.6041000485420227,1.696513533592224,10000,75683.92640280724,0.7989062070846558,0.7856873273849487,0.7301799654960632,1.0794763565063477,50000 -6348.403820991516,7.73581075668335,69778.69618415833,152688,0,69778.69618415833,0.6077000498771667,1.7170206308364868,10000,76142.76267409325,0.8006835579872131,0.7903160452842712,0.7302599549293518,1.0970721244812012,50000 -6387.8042669296265,7.789788961410522,70198.95512890816,153610,0,70198.95512890816,0.6107000112533569,1.6835938692092896,10000,76602.52445936203,0.8105077743530273,0.7433952689170837,0.7332199811935425,1.0680303573608398,50000 -6428.671512365341,7.83747386932373,70619.00755596161,154531,0,70619.00755596161,0.6134000420570374,1.6804908514022827,10000,77063.54085183144,0.8120312094688416,0.7386379241943359,0.7362799644470215,1.065176486968994,50000 -6466.462705373764,7.888170719146728,71039.42457556725,155453,0,71039.42457556725,0.6100000143051147,1.664387345314026,10000,77521.84715628624,0.8109374642372131,0.7333440780639648,0.7377199530601501,1.0530829429626465,50000 -6507.477152347565,7.9470250606536865,71459.36435127258,156371,0,71459.36435127258,0.6124000549316406,1.6711875200271606,10000,77982.90768504143,0.81103515625,0.74477618932724,0.737060010433197,1.0577281713485718,50000 -6545.526882886887,7.999570369720459,71879.43958759308,157290,0,71879.43958759308,0.6168000102043152,1.6577025651931765,10000,78441.13249969482,0.82044917345047,0.6939221620559692,0.7407599687576294,1.043515920639038,50000 -6585.272683382034,8.054925918579102,72299.37812900543,158209,0,72299.37812900543,0.6183000206947327,1.6337950229644775,10000,78900.9204583168,0.8133788704872131,0.7271475791931152,0.7434799671173096,1.0289294719696045,50000 -6626.9921362400055,8.114330053329468,72719.4393901825,159129,0,72719.4393901825,0.6193000078201294,1.6151468753814695,10000,79362.80851197243,0.8163671493530273,0.6988460421562195,0.745419979095459,1.0122644901275637,50000 -6668.070779085159,8.16342830657959,73139.74932837486,160052,0,73139.74932837486,0.6236000061035156,1.628363013267517,10000,79824.29507088661,0.8247265219688416,0.6779054403305054,0.7456199526786804,1.016263723373413,50000 -6705.505133152008,8.21670150756836,73559.80233550072,160972,0,73559.80233550072,0.6212000250816345,1.6267614364624023,10000,80281.88353705406,0.8173242211341858,0.7082533240318298,0.745199978351593,1.023060321807861,50000 -6747.892434120178,8.278788328170776,73979.69502210617,161891,0,73979.69502210617,0.6193000078201294,1.627989649772644,10000,80744.27344155312,0.8249804377555847,0.6753319501876831,0.7469199895858765,1.015852451324463,50000 -6786.871521949768,8.332364797592163,74399.87388181686,162812,0,74399.87388181686,0.6238000392913818,1.611265778541565,10000,81203.53269195557,0.8253905773162842,0.6590047478675842,0.7476199865341187,1.0020933151245115,50000 -6825.848484277725,8.381687641143799,74820.23160123825,163733,0,74820.23160123825,0.6267000436782837,1.6085433959960938,10000,81662.96415829659,0.8265624642372131,0.6701897382736206,0.7502399682998657,0.9970006942749025,50000 -6866.622145652771,8.43875503540039,75240.45762014389,164650,0,75240.45762014389,0.6266000270843506,1.5907145738601685,10000,82124.06794404984,0.8306249976158142,0.6445387601852417,0.7535799741744995,0.9883765578269958,50000 -6905.6733481884,8.490631818771362,75660.44420909882,165572,0,75660.44420909882,0.6296000480651855,1.5999563932418823,10000,82583.20623064041,0.8291601538658142,0.6529331207275391,0.7529000043869019,0.9910786747932434,50000 -6944.787847995758,8.540288209915161,76080.50569367409,166493,0,76080.50569367409,0.6313000321388245,1.5768295526504517,10000,83042.480260849,0.8376562595367432,0.6226816177368164,0.7539199590682983,0.9782753586769104,50000 -6985.094571828842,8.595833539962769,76500.47437024117,167415,0,76500.47437024117,0.6365000009536743,1.5672297477722168,10000,83502.85957407951,0.8343359231948853,0.6367462277412415,0.7556799650192261,0.9818673729896544,50000 -7026.618679523468,8.646256685256958,76920.42654371262,168333,0,76920.42654371262,0.631600022315979,1.5703439712524414,10000,83964.43345236778,0.8360351324081421,0.6291006207466125,0.756659984588623,0.974940836429596,50000 -7063.42170381546,8.696255445480347,77340.37170767784,169253,0,77340.37170767784,0.6389000415802002,1.5526543855667114,10000,84421.27886939049,0.8400976657867432,0.60579913854599,0.7610599994659424,0.9543364644050598,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/measurements.csv deleted file mode 100644 index 325c16d22..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1884 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36488757,6.907756,,,,,,,,,,,,,, -1,,,0.0011914062779396,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,36.40782308578491,63.7958824634552,36.40782308578491,27.38794589042664,0.0,0.0 -100,0.42844564,6.905877,,,,,,,,,,,,,, -200,0.42369482,6.8975124,,,,,,,,,,,,,, -300,0.5885602,6.853509,,,,,,,,,,,,,, -400,0.814811,6.800469,,,,,,,,,,,,,, -500,1.2914318,6.8055167,,,,,,,,,,,,,, -600,1.1203699,6.7263336,,,,,,,,,,,,,, -700,1.0554678,6.6464024,,,,,,,,,,,,,, -800,1.5437915,6.587701,,,,,,,,,,,,,, -862,,,0.0114257810637354,6.446063041687012,0.0111599992960691,6.462528705596924,50000.0,0.0081000002101063,6.500799655914307,10000.0,456.3771412372589,522.0745017528534,456.3771412372589,65.63604092597961,0.0173430442810058,0.0 -900,1.0034686,6.5030437,,,,,,,,,,,,,, -1000,2.2712648,6.5310163,,,,,,,,,,,,,, -1100,2.0698843,6.5552464,,,,,,,,,,,,,, -1200,1.671938,6.3760724,,,,,,,,,,,,,, -1300,1.5023774,6.6035905,,,,,,,,,,,,,, -1400,2.0512426,6.3122334,,,,,,,,,,,,,, -1500,1.9628958,6.6241426,,,,,,,,,,,,,, -1600,2.0302296,6.127065,,,,,,,,,,,,,, -1700,1.6598198,6.1169944,,,,,,,,,,,,,, -1781,,,0.0377734377980232,5.867616653442383,0.0365999974310398,5.891701698303223,50000.0,0.0306000020354986,6.009640693664551,10000.0,876.6863760948181,980.255558013916,876.6863760948181,103.42611718177795,0.0514211654663085,0.0 -1800,2.166696,6.1170797,,,,,,,,,,,,,, -1900,3.041412,6.043242,,,,,,,,,,,,,, -2000,2.0469866,6.364162,,,,,,,,,,,,,, -2100,2.2217,5.949814,,,,,,,,,,,,,, -2200,1.8309466,5.960882,,,,,,,,,,,,,, -2300,2.5928624,5.9155145,,,,,,,,,,,,,, -2400,2.333645,5.8093524,,,,,,,,,,,,,, -2500,1.9969585,6.47247,,,,,,,,,,,,,, -2600,2.1205764,5.8354535,,,,,,,,,,,,,, -2700,,,0.0656054690480232,5.406094074249268,0.0608999989926815,5.445602416992188,50000.0,0.0462000034749507,5.637139320373535,10000.0,1296.615249156952,1444.4656417369845,1296.615249156952,147.62753534317017,0.0841724872589111,0.0 -2700,1.864677,6.163457,,,,,,,,,,,,,, -2800,2.667129,5.836322,,,,,,,,,,,,,, -2900,2.3676293,5.7672377,,,,,,,,,,,,,, -3000,1.9830049,6.2767487,,,,,,,,,,,,,, -3100,2.8657525,5.666056,,,,,,,,,,,,,, -3200,1.9263164,6.5176353,,,,,,,,,,,,,, -3300,2.6438906,5.6502337,,,,,,,,,,,,,, -3400,2.2969618,5.6704273,,,,,,,,,,,,,, -3500,2.2628956,5.6089787,,,,,,,,,,,,,, -3600,2.2082033,5.705845,,,,,,,,,,,,,, -3620,,,0.0945898443460464,5.085729122161865,0.0872199982404708,5.123086929321289,50000.0,0.0667000040411949,5.365293502807617,10000.0,1716.6178047657013,1901.7651226520536,1716.6178047657013,184.84893035888672,0.1116495132446289,0.0 -3700,1.634988,6.0178356,,,,,,,,,,,,,, -3800,1.9587463,5.803315,,,,,,,,,,,,,, -3900,1.8790209,5.422095,,,,,,,,,,,,,, -4000,1.8643892,6.2452464,,,,,,,,,,,,,, -4100,2.2426171,5.6947055,,,,,,,,,,,,,, -4200,1.8822175,5.416585,,,,,,,,,,,,,, -4300,1.6424308,5.341921,,,,,,,,,,,,,, -4400,1.459821,6.0131383,,,,,,,,,,,,,, -4500,1.9468112,5.304676,,,,,,,,,,,,,, -4540,,,0.1378320306539535,4.668841361999512,0.1275399923324585,4.732216835021973,50000.0,0.0945000052452087,5.031833648681641,10000.0,2136.5751678943634,2359.7350487709045,2136.5751678943634,222.7896881103516,0.1356348991394043,0.0 -4600,1.5169747,6.49801,,,,,,,,,,,,,, -4700,1.7780573,5.359398,,,,,,,,,,,,,, -4800,1.9298401,5.533621,,,,,,,,,,,,,, -4900,1.4513241,6.4050856,,,,,,,,,,,,,, -5000,1.4610356,6.1852365,,,,,,,,,,,,,, -5100,2.0263836,5.0610795,,,,,,,,,,,,,, -5200,2.0672238,5.3894334,,,,,,,,,,,,,, -5300,2.1811259,5.084919,,,,,,,,,,,,,, -5400,1.7220749,5.7197714,,,,,,,,,,,,,, -5457,,,0.1705273389816284,4.367717742919922,0.1566199958324432,4.442732334136963,50000.0,0.1197000071406364,4.782550811767578,10000.0,2556.6016159057617,2814.6012556552887,2556.6016159057617,257.5574834346771,0.1605908870697021,0.0 -5500,2.3029296,5.151623,,,,,,,,,,,,,, -5600,1.5697832,5.6871567,,,,,,,,,,,,,, -5700,1.815429,5.2058163,,,,,,,,,,,,,, -5800,1.7885712,4.8863964,,,,,,,,,,,,,, -5900,1.7109226,5.3705053,,,,,,,,,,,,,, -6000,1.8065794,4.6975965,,,,,,,,,,,,,, -6100,2.008911,6.3291283,,,,,,,,,,,,,, -6200,1.8587618,5.7501173,,,,,,,,,,,,,, -6300,1.4939364,6.1034408,,,,,,,,,,,,,, -6378,,,0.2243359386920929,3.935076713562012,0.201339989900589,4.064193248748779,50000.0,0.1504000127315521,4.456921100616455,10000.0,2976.937946081161,3268.678102016449,2976.937946081161,291.22222685813904,0.1891450881958007,0.0 -6400,1.790456,4.690585,,,,,,,,,,,,,, -6500,1.7715696,6.216513,,,,,,,,,,,,,, -6600,1.9706594,4.6468163,,,,,,,,,,,,,, -6700,1.9859761,4.626078,,,,,,,,,,,,,, -6800,1.3596368,6.2073555,,,,,,,,,,,,,, -6900,1.8689998,4.6988654,,,,,,,,,,,,,, -7000,2.1461453,4.631405,,,,,,,,,,,,,, -7100,2.421956,4.475419,,,,,,,,,,,,,, -7200,1.6856347,5.827184,,,,,,,,,,,,,, -7295,,,0.25927734375,3.695412397384644,0.2415599972009658,3.787732601165772,50000.0,0.1825000047683715,4.207826137542725,10000.0,3396.901723384857,3729.273754119873,3396.901723384857,331.77545142173767,0.2208545207977295,0.0 -7300,2.1097858,4.519277,,,,,,,,,,,,,, -7400,1.5854932,6.1690693,,,,,,,,,,,,,, -7500,1.9004412,4.403146,,,,,,,,,,,,,, -7600,2.269602,4.3670597,,,,,,,,,,,,,, -7700,1.8943136,4.507589,,,,,,,,,,,,,, -7800,1.9185724,4.391677,,,,,,,,,,,,,, -7900,1.4635043,5.552172,,,,,,,,,,,,,, -8000,2.517918,4.3561964,,,,,,,,,,,,,, -8100,1.9341309,4.5161915,,,,,,,,,,,,,, -8200,1.6636512,4.5981436,,,,,,,,,,,,,, -8216,,,0.2934960722923279,3.44327449798584,0.2733799815177917,3.552814483642578,50000.0,0.2082000076770782,4.0174102783203125,10000.0,3817.1914489269257,4185.551072597504,3817.1914489269257,367.687614440918,0.247636079788208,0.0 -8300,1.7850281,4.29619,,,,,,,,,,,,,, -8400,1.5548173,5.7704134,,,,,,,,,,,,,, -8500,1.7346793,4.642346,,,,,,,,,,,,,, -8600,2.3092623,4.1509013,,,,,,,,,,,,,, -8700,1.5465096,5.597967,,,,,,,,,,,,,, -8800,2.0949643,4.114829,,,,,,,,,,,,,, -8900,1.4143417,5.5591574,,,,,,,,,,,,,, -9000,1.4410869,6.06589,,,,,,,,,,,,,, -9100,1.3760668,6.0944357,,,,,,,,,,,,,, -9136,,,0.3303906321525574,3.2213683128356934,0.2978399991989136,3.398451089859009,50000.0,0.2312000095844268,3.8833281993865967,10000.0,4237.426783323288,4640.326997518539,4237.426783323288,402.15560007095337,0.2722275257110595,0.0 -9200,1.9970185,4.010526,,,,,,,,,,,,,, -9300,1.4225252,5.2467213,,,,,,,,,,,,,, -9400,1.7758998,4.3700876,,,,,,,,,,,,,, -9500,1.6850948,4.7133045,,,,,,,,,,,,,, -9600,1.7821635,3.971266,,,,,,,,,,,,,, -9700,1.4010934,5.046053,,,,,,,,,,,,,, -9800,1.8685778,3.934369,,,,,,,,,,,,,, -9900,1.8511181,4.211657,,,,,,,,,,,,,, -10000,1.4329854,5.4365683,,,,,,,,,,,,,, -10054,,,0.3563671708106994,3.065946578979492,0.3330000042915344,3.185922622680664,50000.0,0.2563000023365021,3.715104103088379,10000.0,4657.610203266144,5097.769255399704,4657.610203266144,439.33963537216187,0.2994673252105713,0.0 -10100,1.9229703,3.9084685,,,,,,,,,,,,,, -10200,1.5509287,4.1818943,,,,,,,,,,,,,, -10300,2.3373506,3.8591084,,,,,,,,,,,,,, -10400,2.073602,3.7530901,,,,,,,,,,,,,, -10500,1.8733044,3.8506021,,,,,,,,,,,,,, -10600,1.145324,5.456328,,,,,,,,,,,,,, -10700,1.4007738,5.6477785,,,,,,,,,,,,,, -10800,2.0019755,3.83234,,,,,,,,,,,,,, -10900,1.3701919,4.288796,,,,,,,,,,,,,, -10973,,,0.3862695097923279,2.848928213119507,0.3571199774742126,2.993165969848633,50000.0,0.2735000252723694,3.545906782150269,10000.0,5077.543553113937,5556.3450927734375,5077.543553113937,477.90121936798096,0.3332977294921875,0.0 -11000,1.7048309,3.9883308,,,,,,,,,,,,,, -11100,1.6430485,3.9904776,,,,,,,,,,,,,, -11200,1.5579187,5.2915187,,,,,,,,,,,,,, -11300,1.680663,3.6927676,,,,,,,,,,,,,, -11400,1.5508143,4.7867765,,,,,,,,,,,,,, -11500,1.6466321,4.007243,,,,,,,,,,,,,, -11600,1.5874667,3.9571934,,,,,,,,,,,,,, -11700,1.6646309,3.7726102,,,,,,,,,,,,,, -11800,1.8767889,3.7248921,,,,,,,,,,,,,, -11892,,,0.4043749868869781,2.7795512676239014,0.367279976606369,2.9624674320220947,50000.0,0.287200003862381,3.504602193832397,10000.0,5497.630401134491,6011.133774995804,5497.630401134491,512.5262434482574,0.3625199794769287,0.0 -11900,1.7342494,3.7349596,,,,,,,,,,,,,, -12000,1.4951099,4.2862263,,,,,,,,,,,,,, -12100,1.7706535,3.736533,,,,,,,,,,,,,, -12200,1.7830071,3.9601574,,,,,,,,,,,,,, -12300,1.2152781,5.8847938,,,,,,,,,,,,,, -12400,1.312873,4.9780426,,,,,,,,,,,,,, -12500,1.2491906,5.6222744,,,,,,,,,,,,,, -12600,2.0143733,3.6187115,,,,,,,,,,,,,, -12700,1.2560014,4.99718,,,,,,,,,,,,,, -12800,1.593517,4.1890697,,,,,,,,,,,,,, -12809,,,0.4223828017711639,2.665945291519165,0.3883399963378906,2.8275606632232666,50000.0,0.3069000244140625,3.374930620193481,10000.0,5917.629370927811,6468.676458835602,5917.629370927811,549.9961397647858,0.389383316040039,0.0 -12900,1.1451762,5.5709686,,,,,,,,,,,,,, -13000,1.4622271,5.02016,,,,,,,,,,,,,, -13100,1.7927222,3.690188,,,,,,,,,,,,,, -13200,1.1647933,5.920319,,,,,,,,,,,,,, -13300,1.8152878,3.5833707,,,,,,,,,,,,,, -13400,1.6786104,3.639712,,,,,,,,,,,,,, -13500,2.0769615,3.5246506,,,,,,,,,,,,,, -13600,1.9399805,3.5017982,,,,,,,,,,,,,, -13700,1.6989232,3.59565,,,,,,,,,,,,,, -13726,,,0.4422460794448852,2.522275447845459,0.4086399972438812,2.700598955154419,50000.0,0.3149000108242035,3.2909696102142334,10000.0,6337.654529333115,6927.02899646759,6337.654529333115,588.2510459423065,0.4151310920715332,0.0 -13800,1.3008937,5.461483,,,,,,,,,,,,,, -13900,1.8254794,3.5584908,,,,,,,,,,,,,, -14000,1.0316144,5.833988,,,,,,,,,,,,,, -14100,1.6510171,3.5442367,,,,,,,,,,,,,, -14200,1.7796952,3.5468876,,,,,,,,,,,,,, -14300,1.7078322,3.5120802,,,,,,,,,,,,,, -14400,1.6990627,3.6137383,,,,,,,,,,,,,, -14500,1.7526891,3.445033,,,,,,,,,,,,,, -14600,1.43487,3.9062412,,,,,,,,,,,,,, -14648,,,0.445136696100235,2.541533708572388,0.4125999808311462,2.715203762054444,50000.0,0.3217000067234039,3.2877392768859863,10000.0,6758.110645294189,7386.947074890137,6758.110645294189,627.6393864154816,0.4406590461730957,0.0 -14700,1.1561967,5.419718,,,,,,,,,,,,,, -14800,1.2180529,5.3214793,,,,,,,,,,,,,, -14900,1.4985783,4.507943,,,,,,,,,,,,,, -15000,1.113436,4.7729287,,,,,,,,,,,,,, -15100,1.8980536,3.4899917,,,,,,,,,,,,,, -15200,1.2150935,5.783888,,,,,,,,,,,,,, -15300,1.6387666,3.4692736,,,,,,,,,,,,,, -15400,1.7092738,3.7875085,,,,,,,,,,,,,, -15500,1.5891911,3.583775,,,,,,,,,,,,,, -15568,,,0.4643945097923279,2.417165279388428,0.4288599789142608,2.588149070739746,50000.0,0.3345000147819519,3.1643946170806885,10000.0,7178.091734886169,7848.810657024383,7178.091734886169,669.4438850879669,0.4714105129241943,0.0 -15600,1.2361072,4.7844467,,,,,,,,,,,,,, -15700,1.9883581,3.3845568,,,,,,,,,,,,,, -15800,1.5225358,3.284987,,,,,,,,,,,,,, -15900,1.7088119,3.4870164,,,,,,,,,,,,,, -16000,1.7759746,3.3983436,,,,,,,,,,,,,, -16100,1.2224808,5.3790593,,,,,,,,,,,,,, -16200,1.1205237,5.6966186,,,,,,,,,,,,,, -16300,1.4619985,4.011345,,,,,,,,,,,,,, -16400,1.1034122,5.0036616,,,,,,,,,,,,,, -16487,,,0.4732031226158142,2.3529422283172607,0.4392599761486053,2.526226758956909,50000.0,0.3418000042438507,3.12949013710022,10000.0,7598.284974575043,8307.79007267952,7598.284974575043,708.1572668552399,0.4968502521514892,0.0 -16500,2.0555494,3.2427826,,,,,,,,,,,,,, -16600,1.6536198,3.2486184,,,,,,,,,,,,,, -16700,1.717081,3.4434364,,,,,,,,,,,,,, -16800,1.2339877,5.2655535,,,,,,,,,,,,,, -16900,1.7685499,3.2055297,,,,,,,,,,,,,, -17000,1.8315319,3.3754628,,,,,,,,,,,,,, -17100,1.4930862,3.4878993,,,,,,,,,,,,,, -17200,1.8374047,3.5395794,,,,,,,,,,,,,, -17300,1.1668624,3.9858115,,,,,,,,,,,,,, -17400,1.2463346,4.767714,,,,,,,,,,,,,, -17409,,,0.4809765517711639,2.311243772506714,0.4396799802780151,2.5017337799072266,50000.0,0.3455000221729278,3.113865852355957,10000.0,8018.596554040909,8765.373228549957,8018.596554040909,745.3528969287872,0.5254116058349609,0.0 -17500,1.1676798,5.6081767,,,,,,,,,,,,,, -17600,1.5943067,3.4341788,,,,,,,,,,,,,, -17700,1.4734586,3.1902094,,,,,,,,,,,,,, -17800,1.3864743,3.537814,,,,,,,,,,,,,, -17900,1.561523,4.582114,,,,,,,,,,,,,, -18000,1.3585527,4.170854,,,,,,,,,,,,,, -18100,1.3609743,4.872556,,,,,,,,,,,,,, -18200,1.2725044,4.020527,,,,,,,,,,,,,, -18300,1.7274076,3.481092,,,,,,,,,,,,,, -18330,,,0.5154687166213989,2.13545823097229,0.4569000005722046,2.423603773117065,50000.0,0.3544000089168548,3.041668176651001,10000.0,8438.65920996666,9223.974129915236,8438.65920996666,783.8180267810822,0.551140308380127,0.0 -18400,1.287785,5.808053,,,,,,,,,,,,,, -18500,1.4805564,3.7223153,,,,,,,,,,,,,, -18600,2.0892172,3.6057205,,,,,,,,,,,,,, -18700,1.4481591,3.281217,,,,,,,,,,,,,, -18800,1.5191456,3.4070392,,,,,,,,,,,,,, -18900,1.5653989,3.5254502,,,,,,,,,,,,,, -19000,1.5347472,3.20291,,,,,,,,,,,,,, -19100,1.3976418,3.4613817,,,,,,,,,,,,,, -19200,1.127003,5.42227,,,,,,,,,,,,,, -19249,,,0.495410144329071,2.2226991653442383,0.4612599909305572,2.3966081142425537,50000.0,0.3618000149726867,3.014835834503174,10000.0,8858.939247369766,9682.179508447647,8858.939247369766,821.6687545776367,0.5784051418304443,0.0 -19300,1.5087668,3.2937453,,,,,,,,,,,,,, -19400,1.4623989,2.998651,,,,,,,,,,,,,, -19500,1.6604335,3.2293913,,,,,,,,,,,,,, -19600,1.1839933,5.210069,,,,,,,,,,,,,, -19700,1.6439172,3.1663368,,,,,,,,,,,,,, -19800,1.6036493,3.34825,,,,,,,,,,,,,, -19900,1.1499456,5.2214885,,,,,,,,,,,,,, -20000,1.2981825,5.579206,,,,,,,,,,,,,, -20100,1.7213068,3.260432,,,,,,,,,,,,,, -20170,,,0.4988085925579071,2.2068932056427,0.4660399854183197,2.378439426422119,50000.0,0.359000027179718,3.015586376190185,10000.0,9279.185322284698,10135.67358827591,9279.185322284698,854.8419258594513,0.6058785915374756,0.0 -20200,1.2762845,4.5471444,,,,,,,,,,,,,, -20300,1.5243243,3.8039095,,,,,,,,,,,,,, -20400,1.7160482,3.291722,,,,,,,,,,,,,, -20500,1.4040508,3.8835015,,,,,,,,,,,,,, -20600,1.0791724,4.9849396,,,,,,,,,,,,,, -20700,1.611092,3.1620646,,,,,,,,,,,,,, -20800,1.5858057,3.193881,,,,,,,,,,,,,, -20900,1.0911185,4.598334,,,,,,,,,,,,,, -21000,1.674271,3.1199315,,,,,,,,,,,,,, -21092,,,0.5242773294448853,2.095875024795532,0.4736199975013733,2.349881649017334,50000.0,0.3703000247478485,2.963587045669556,10000.0,9699.388586997986,10590.940872192385,9699.388586997986,889.8293704986572,0.6350181102752686,0.0 -21100,1.7471015,3.0429523,,,,,,,,,,,,,, -21200,1.6338518,3.0352576,,,,,,,,,,,,,, -21300,1.4201007,3.100226,,,,,,,,,,,,,, -21400,1.3272501,5.1956525,,,,,,,,,,,,,, -21500,1.6213223,3.1379204,,,,,,,,,,,,,, -21600,1.0875318,5.0717597,,,,,,,,,,,,,, -21700,1.6379718,3.219653,,,,,,,,,,,,,, -21800,1.297788,3.3927293,,,,,,,,,,,,,, -21900,1.4205239,3.274095,,,,,,,,,,,,,, -22000,1.7424514,3.4297996,,,,,,,,,,,,,, -22012,,,0.5155078172683716,2.124331474304199,0.4827999770641327,2.288004398345948,50000.0,0.3778000175952911,2.893505096435547,10000.0,10119.395952701569,11044.343485832214,10119.395952701569,923.1459929943084,0.6662139892578125,0.0 -22100,1.766399,3.3118672,,,,,,,,,,,,,, -22200,1.5367428,3.2089748,,,,,,,,,,,,,, -22300,1.4014819,3.5186214,,,,,,,,,,,,,, -22400,1.584277,3.0880873,,,,,,,,,,,,,, -22500,1.7201467,2.9473047,,,,,,,,,,,,,, -22600,1.1219671,4.3586392,,,,,,,,,,,,,, -22700,1.5296712,3.0443325,,,,,,,,,,,,,, -22800,1.1586039,5.601246,,,,,,,,,,,,,, -22900,1.1512074,5.1239014,,,,,,,,,,,,,, -22927,,,0.5287500023841858,2.041822910308838,0.4889199733734131,2.242918968200684,50000.0,0.3838000297546386,2.8525102138519287,10000.0,10539.358990907667,11501.50481057167,10539.358990907667,960.2646844387054,0.6987216472625732,0.0 -23000,1.1520337,4.509694,,,,,,,,,,,,,, -23100,1.4641684,3.5132473,,,,,,,,,,,,,, -23200,1.2983946,3.8947315,,,,,,,,,,,,,, -23300,1.6271192,3.1294737,,,,,,,,,,,,,, -23400,1.0788505,5.365367,,,,,,,,,,,,,, -23500,1.764719,3.3643932,,,,,,,,,,,,,, -23600,1.6181912,2.901511,,,,,,,,,,,,,, -23700,1.404709,3.1764572,,,,,,,,,,,,,, -23800,1.4494958,3.3760262,,,,,,,,,,,,,, -23845,,,0.5420507788658142,2.0036168098449707,0.5009599924087524,2.2223894596099854,50000.0,0.387800008058548,2.8514115810394287,10000.0,10959.397310972214,11961.27671599388,10959.397310972214,999.9240214824677,0.7256793975830078,0.0 -23900,1.3631008,4.0491767,,,,,,,,,,,,,, -24000,1.6577946,3.3067749,,,,,,,,,,,,,, -24100,1.5020629,3.1354146,,,,,,,,,,,,,, -24200,1.4848319,2.9588337,,,,,,,,,,,,,, -24300,1.4174857,3.9623966,,,,,,,,,,,,,, -24400,1.219076,4.3274956,,,,,,,,,,,,,, -24500,1.4597116,3.538175,,,,,,,,,,,,,, -24600,1.2733297,3.7273655,,,,,,,,,,,,,, -24700,1.5865932,3.1017022,,,,,,,,,,,,,, -24770,,,0.5383593440055847,2.017052173614502,0.5021600127220154,2.189324855804444,50000.0,0.3936000168323517,2.818950891494751,10000.0,11379.342381954191,12414.400235176086,11379.342381954191,1033.0191838741302,0.761568546295166,0.0 -24800,1.6269499,2.90846,,,,,,,,,,,,,, -24900,1.1495365,4.66228,,,,,,,,,,,,,, -25000,1.6193985,3.007915,,,,,,,,,,,,,, -25100,1.3464917,5.0915422,,,,,,,,,,,,,, -25200,1.6280245,3.1771607,,,,,,,,,,,,,, -25300,1.4353836,3.6278229,,,,,,,,,,,,,, -25400,1.5757627,2.9616666,,,,,,,,,,,,,, -25500,1.5749941,3.1901479,,,,,,,,,,,,,, -25600,1.4829916,3.0943398,,,,,,,,,,,,,, -25691,,,0.5432421565055847,1.992870092391968,0.5070199966430664,2.167407989501953,50000.0,0.393200010061264,2.8182153701782227,10000.0,11799.4931101799,12871.638572454453,11799.4931101799,1070.0217413902285,0.7992439270019531,0.0 -25700,1.4057821,4.2331424,,,,,,,,,,,,,, -25800,1.21742,5.634195,,,,,,,,,,,,,, -25900,1.5662904,3.1281488,,,,,,,,,,,,,, -26000,1.3312335,3.6325898,,,,,,,,,,,,,, -26100,1.5366932,3.4207668,,,,,,,,,,,,,, -26200,1.5673348,3.1013558,,,,,,,,,,,,,, -26300,1.7076244,2.9680138,,,,,,,,,,,,,, -26400,1.2318746,5.0806713,,,,,,,,,,,,,, -26500,1.1512573,5.491079,,,,,,,,,,,,,, -26600,1.5458361,3.1028073,,,,,,,,,,,,,, -26613,,,0.549023449420929,1.963714599609375,0.5048800110816956,2.183801889419556,50000.0,0.3975000083446502,2.823049783706665,10000.0,12219.802811145782,13329.79423236847,12219.802811145782,1107.7912635803225,0.8276638984680176,0.0 -26700,1.7552633,2.968845,,,,,,,,,,,,,, -26800,1.3510544,5.631646,,,,,,,,,,,,,, -26900,1.5818284,2.9557323,,,,,,,,,,,,,, -27000,1.8860046,2.8967621,,,,,,,,,,,,,, -27100,1.6209804,2.8774276,,,,,,,,,,,,,, -27200,1.427418,3.608819,,,,,,,,,,,,,, -27300,1.2898903,4.0846457,,,,,,,,,,,,,, -27400,1.5296268,2.833458,,,,,,,,,,,,,, -27500,1.6767911,3.051909,,,,,,,,,,,,,, -27533,,,0.564648449420929,1.8632910251617432,0.5128799676895142,2.0994911193847656,50000.0,0.4047000110149383,2.7460947036743164,10000.0,12639.845764398577,13787.502056837082,12639.845764398577,1145.3771076202393,0.859968900680542,0.0 -27600,1.2085816,5.1459703,,,,,,,,,,,,,, -27700,1.7249719,2.8483891,,,,,,,,,,,,,, -27800,1.6516578,3.0569315,,,,,,,,,,,,,, -27900,1.6737465,3.0624366,,,,,,,,,,,,,, -28000,1.499973,3.2103605,,,,,,,,,,,,,, -28100,1.5191146,3.0349798,,,,,,,,,,,,,, -28200,1.586214,3.18644,,,,,,,,,,,,,, -28300,1.6208483,3.1251326,,,,,,,,,,,,,, -28400,1.4054399,4.7946157,,,,,,,,,,,,,, -28453,,,0.5565234422683716,1.9159458875656128,0.5190799832344055,2.100757360458374,50000.0,0.4094000160694122,2.7339088916778564,10000.0,13060.279699325562,14243.699850559236,13060.279699325562,1181.059654712677,0.8942258358001709,0.0 -28500,1.5999014,2.9047663,,,,,,,,,,,,,, -28600,1.2758377,3.9220126,,,,,,,,,,,,,, -28700,1.9733995,2.9405246,,,,,,,,,,,,,, -28800,1.297732,4.6406007,,,,,,,,,,,,,, -28900,1.7965282,2.927448,,,,,,,,,,,,,, -29000,1.2699653,4.7047253,,,,,,,,,,,,,, -29100,1.6242748,3.1076,,,,,,,,,,,,,, -29200,1.3663878,5.4076815,,,,,,,,,,,,,, -29300,1.7508723,2.978634,,,,,,,,,,,,,, -29344,,,0.5629491806030273,1.8562678098678589,0.523580014705658,2.056525230407715,50000.0,0.4134000241756439,2.7119228839874268,10000.0,13480.380759000778,14698.166138410568,13480.380759000778,1215.350920677185,0.921989917755127,0.0 -29400,1.3772258,5.6256,,,,,,,,,,,,,, -29500,1.4894108,2.8904245,,,,,,,,,,,,,, -29600,1.7002354,2.9037352,,,,,,,,,,,,,, -29700,1.6875448,2.86587,,,,,,,,,,,,,, -29800,1.589957,2.8655787,,,,,,,,,,,,,, -29900,1.554611,5.293775,,,,,,,,,,,,,, -30000,1.5437541,2.9069757,,,,,,,,,,,,,, -30100,1.5730312,2.7844334,,,,,,,,,,,,,, -30200,1.6430289,2.7641745,,,,,,,,,,,,,, -30266,,,0.5864452719688416,1.7774485349655151,0.5238800048828125,2.0740952491760254,50000.0,0.4092000126838684,2.7224607467651367,10000.0,13900.379534959791,15156.030605793,13900.379534959791,1253.140303850174,0.950770616531372,0.0 -30300,1.6306784,3.0486596,,,,,,,,,,,,,, -30400,1.7351345,2.846162,,,,,,,,,,,,,, -30500,1.3688246,3.6495008,,,,,,,,,,,,,, -30600,1.6711315,2.7674923,,,,,,,,,,,,,, -30700,1.4998758,3.086719,,,,,,,,,,,,,, -30800,1.7499226,2.9335377,,,,,,,,,,,,,, -30900,1.5488433,2.9969962,,,,,,,,,,,,,, -31000,1.4353306,5.1757326,,,,,,,,,,,,,, -31100,1.1989771,5.2122564,,,,,,,,,,,,,, -31188,,,0.5709179639816284,1.872394561767578,0.5306400060653687,2.058422327041626,50000.0,0.4148000180721283,2.696887731552124,10000.0,14320.638298034668,15614.249346971512,14320.638298034668,1291.0224838256836,0.9816055297851562,0.0 -31200,1.425569,4.806019,,,,,,,,,,,,,, -31300,1.725058,2.9376206,,,,,,,,,,,,,, -31400,1.3269589,4.1938004,,,,,,,,,,,,,, -31500,1.7489318,2.8733735,,,,,,,,,,,,,, -31600,1.9578143,2.8717391,,,,,,,,,,,,,, -31700,1.2177305,5.4361067,,,,,,,,,,,,,, -31800,1.6695089,2.8942323,,,,,,,,,,,,,, -31900,1.1950086,5.258017,,,,,,,,,,,,,, -32000,1.4047132,3.7340944,,,,,,,,,,,,,, -32100,1.7653399,2.9076447,,,,,,,,,,,,,, -32109,,,0.5722070336341858,1.8319276571273804,0.537339985370636,2.008451223373413,50000.0,0.4218000173568725,2.658211231231689,10000.0,14740.856187820436,16071.735245227814,14740.856187820436,1328.2112641334534,1.0133063793182373,0.0 -32200,1.5762258,2.9663181,,,,,,,,,,,,,, -32300,1.2156429,4.3869476,,,,,,,,,,,,,, -32400,1.7011234,3.1545691,,,,,,,,,,,,,, -32500,1.3918927,3.647435,,,,,,,,,,,,,, -32600,1.6512605,3.0420444,,,,,,,,,,,,,, -32700,1.4481313,3.6157384,,,,,,,,,,,,,, -32800,1.5838313,4.2951775,,,,,,,,,,,,,, -32900,1.2544156,4.2441516,,,,,,,,,,,,,, -33000,1.5737287,2.9987462,,,,,,,,,,,,,, -33030,,,0.591601550579071,1.7154150009155271,0.5344399809837341,1.969157099723816,50000.0,0.4265000224113464,2.6062676906585693,10000.0,15160.993295431135,16531.774827718735,15160.993295431135,1368.0250644683838,1.0439252853393557,0.0 -33100,1.2589555,3.6185565,,,,,,,,,,,,,, -33200,1.5819008,2.8793626,,,,,,,,,,,,,, -33300,1.3813521,5.1937838,,,,,,,,,,,,,, -33400,1.3987486,3.606595,,,,,,,,,,,,,, -33500,1.7339401,2.843845,,,,,,,,,,,,,, -33600,1.6642516,2.8171732,,,,,,,,,,,,,, -33700,1.854914,2.7870114,,,,,,,,,,,,,, -33800,1.6938047,2.843102,,,,,,,,,,,,,, -33900,1.3893448,3.3888834,,,,,,,,,,,,,, -33951,,,0.580078125,1.7822751998901367,0.5445799827575684,1.9617338180541992,50000.0,0.4296000301837921,2.6152634620666504,10000.0,15580.957021474838,16992.161110639572,15580.957021474838,1408.3668491840365,1.0768020153045654,0.0 -34000,1.6226991,2.703648,,,,,,,,,,,,,, -34100,1.5370065,3.8123283,,,,,,,,,,,,,, -34200,1.3929176,5.171818,,,,,,,,,,,,,, -34300,1.6727966,3.2959208,,,,,,,,,,,,,, -34400,1.3900783,3.5541608,,,,,,,,,,,,,, -34500,1.8632846,2.8689606,,,,,,,,,,,,,, -34600,1.3223382,4.1727247,,,,,,,,,,,,,, -34700,1.6779546,2.8213162,,,,,,,,,,,,,, -34800,1.3239297,4.7841115,,,,,,,,,,,,,, -34872,,,0.5824218392372131,1.7896679639816284,0.5446400046348572,1.968542456626892,50000.0,0.4246000349521637,2.62258243560791,10000.0,16001.323278665544,17450.09957075119,16001.323278665544,1445.8618338108065,1.105936050415039,0.0 -34900,1.5847418,2.8579974,,,,,,,,,,,,,, -35000,1.5466865,3.3446302,,,,,,,,,,,,,, -35100,1.5928752,2.8074589,,,,,,,,,,,,,, -35200,1.6192757,3.0389094,,,,,,,,,,,,,, -35300,1.7978364,2.8525505,,,,,,,,,,,,,, -35400,1.221686,4.9560137,,,,,,,,,,,,,, -35500,1.5919741,2.9042284,,,,,,,,,,,,,, -35600,1.3232424,3.903049,,,,,,,,,,,,,, -35700,1.6136733,2.8009176,,,,,,,,,,,,,, -35793,,,0.5892577767372131,1.7423663139343262,0.5414800047874451,1.9773989915847776,50000.0,0.4243000149726867,2.617381811141968,10000.0,16421.62156009674,17907.980507850647,16421.62156009674,1483.3568606376648,1.1460247039794922,0.0 -35800,1.6213933,2.7836084,,,,,,,,,,,,,, -35900,1.547258,2.6783414,,,,,,,,,,,,,, -36000,1.8707143,2.7582204,,,,,,,,,,,,,, -36100,1.5240623,3.0744615,,,,,,,,,,,,,, -36200,1.1129827,5.0157886,,,,,,,,,,,,,, -36300,1.7387048,2.9612014,,,,,,,,,,,,,, -36400,1.3143679,4.971417,,,,,,,,,,,,,, -36500,1.7391627,2.8231566,,,,,,,,,,,,,, -36600,1.7217984,2.7797167,,,,,,,,,,,,,, -36700,1.6893332,2.7569914,,,,,,,,,,,,,, -36716,,,0.5894140601158142,1.7541520595550537,0.5504999756813049,1.9289462566375728,50000.0,0.4310000240802765,2.5739073753356934,10000.0,16841.98622250557,18367.78451514244,16841.98622250557,1522.7079238891602,1.1861801147460938,0.0 -36800,1.6013445,2.6848092,,,,,,,,,,,,,, -36900,1.649693,3.2639527,,,,,,,,,,,,,, -37000,1.677013,2.7581081,,,,,,,,,,,,,, -37100,1.5516663,3.047241,,,,,,,,,,,,,, -37200,1.7566756,2.903548,,,,,,,,,,,,,, -37300,1.6016709,2.7914855,,,,,,,,,,,,,, -37400,1.6621287,2.8079243,,,,,,,,,,,,,, -37500,1.7937912,3.3189242,,,,,,,,,,,,,, -37600,1.1667085,5.464835,,,,,,,,,,,,,, -37638,,,0.5914062261581421,1.7134793996810913,0.5567799806594849,1.893664002418518,50000.0,0.4353000223636627,2.56376576423645,10000.0,17262.115788459778,18823.204505443573,17262.115788459778,1557.9151480197906,1.2215921878814695,0.0 -37700,1.4522533,3.295954,,,,,,,,,,,,,, -37800,1.5937173,2.7267976,,,,,,,,,,,,,, -37900,1.7178701,2.8077435,,,,,,,,,,,,,, -38000,1.6165873,3.1579065,,,,,,,,,,,,,, -38100,1.8384457,2.7157936,,,,,,,,,,,,,, -38200,1.4215608,3.5574603,,,,,,,,,,,,,, -38300,1.8066524,2.8066497,,,,,,,,,,,,,, -38400,1.6015173,3.5568655,,,,,,,,,,,,,, -38500,1.3083556,4.20852,,,,,,,,,,,,,, -38560,,,0.5970898270606995,1.721079707145691,0.5518199801445007,1.9465895891189573,50000.0,0.4353000223636627,2.5917446613311768,10000.0,17682.302928209305,19280.42056465149,17682.302928209305,1594.8616213798523,1.2564432621002195,0.0 -38600,1.5765561,2.763685,,,,,,,,,,,,,, -38700,1.5397584,3.8064122,,,,,,,,,,,,,, -38800,1.7499841,2.6751304,,,,,,,,,,,,,, -38900,1.707501,2.6939645,,,,,,,,,,,,,, -39000,1.3512293,3.9223962,,,,,,,,,,,,,, -39100,2.0507305,2.6820757,,,,,,,,,,,,,, -39200,1.6201565,2.7601936,,,,,,,,,,,,,, -39300,1.7750276,2.6512086,,,,,,,,,,,,,, -39400,1.3657112,4.3342686,,,,,,,,,,,,,, -39481,,,0.6230077743530273,1.6257997751235962,0.5562999844551086,1.9278334379196167,50000.0,0.4412000179290771,2.564246654510498,10000.0,18102.482609033585,19740.42428445816,18102.482609033585,1634.6089255809784,1.2855734825134275,0.0 -39500,1.4622078,5.192574,,,,,,,,,,,,,, -39600,1.9670197,2.924715,,,,,,,,,,,,,, -39700,1.925297,3.0302043,,,,,,,,,,,,,, -39800,1.7357186,2.6971338,,,,,,,,,,,,,, -39900,1.3016592,5.0674853,,,,,,,,,,,,,, -40000,1.4981998,3.4394107,,,,,,,,,,,,,, -40100,1.9249471,2.7685516,,,,,,,,,,,,,, -40200,1.5907176,3.3647792,,,,,,,,,,,,,, -40300,1.7648336,2.7349095,,,,,,,,,,,,,, -40400,1.6740817,2.8681526,,,,,,,,,,,,,, -40401,,,0.5972656011581421,1.7070727348327637,0.5578399896621704,1.892319798469544,50000.0,0.4462000131607055,2.5406343936920166,10000.0,18522.80424499512,20199.12549352646,18522.80424499512,1672.908019542694,1.3188085556030271,0.0 -40500,1.3198712,5.025637,,,,,,,,,,,,,, -40600,1.5041901,2.7790802,,,,,,,,,,,,,, -40700,1.6631216,2.814431,,,,,,,,,,,,,, -40800,1.7072984,2.9631472,,,,,,,,,,,,,, -40900,1.3234577,4.317838,,,,,,,,,,,,,, -41000,1.7434617,2.639407,,,,,,,,,,,,,, -41100,1.63245,2.6140213,,,,,,,,,,,,,, -41200,1.5417664,3.9717402,,,,,,,,,,,,,, -41300,2.1164947,2.6989818,,,,,,,,,,,,,, -41322,,,0.6015819907188416,1.6829158067703247,0.557379961013794,1.892937421798706,50000.0,0.4448000192642212,2.5214507579803467,10000.0,18942.92242288589,20658.107868433,18942.92242288589,1711.6933376789093,1.3502240180969238,0.0 -41400,2.0916178,2.8036547,,,,,,,,,,,,,, -41500,1.6150632,2.6712692,,,,,,,,,,,,,, -41600,1.6886241,2.8006277,,,,,,,,,,,,,, -41700,1.2074978,4.975372,,,,,,,,,,,,,, -41800,1.6658295,2.796037,,,,,,,,,,,,,, -41900,1.3934649,4.5631638,,,,,,,,,,,,,, -42000,1.6260971,4.0283823,,,,,,,,,,,,,, -42100,1.4716927,3.1921532,,,,,,,,,,,,,, -42200,1.628554,2.869714,,,,,,,,,,,,,, -42241,,,0.6206249594688416,1.5839260816574097,0.5624200105667114,1.8680510520935056,50000.0,0.4435000121593475,2.519059419631958,10000.0,19362.84196233749,21115.976779937744,19362.84196233749,1749.5593955516815,1.3858461380004885,0.0 -42300,1.5897331,3.162879,,,,,,,,,,,,,, -42400,1.5636429,3.2517738,,,,,,,,,,,,,, -42500,1.5078987,3.6821306,,,,,,,,,,,,,, -42600,1.7679383,2.6047177,,,,,,,,,,,,,, -42700,1.5919716,3.0388079,,,,,,,,,,,,,, -42800,1.5223315,2.764936,,,,,,,,,,,,,, -42900,1.7661865,2.6707988,,,,,,,,,,,,,, -43000,1.7703186,2.825297,,,,,,,,,,,,,, -43100,1.7862889,2.6405244,,,,,,,,,,,,,, -43163,,,0.6048827767372131,1.6705154180526731,0.5646600127220154,1.8540127277374268,50000.0,0.4525000154972076,2.4949381351470947,10000.0,19782.84850549698,21575.81156134605,19782.84850549698,1789.3050591945648,1.420839786529541,0.0 -43200,1.71353,2.7675872,,,,,,,,,,,,,, -43300,1.6734061,2.802693,,,,,,,,,,,,,, -43400,1.5624189,3.6485784,,,,,,,,,,,,,, -43500,1.7030641,2.6381755,,,,,,,,,,,,,, -43600,1.4405236,5.166271,,,,,,,,,,,,,, -43700,1.8136774,2.8785706,,,,,,,,,,,,,, -43800,1.6565919,2.7216845,,,,,,,,,,,,,, -43900,1.2069906,5.360141,,,,,,,,,,,,,, -44000,1.7691145,2.68919,,,,,,,,,,,,,, -44084,,,0.6102148294448853,1.6581255197525024,0.5626199841499329,1.85548996925354,50000.0,0.4494000077247619,2.488030433654785,10000.0,20203.19352889061,22031.04239463806,20203.19352889061,1824.1077728271484,1.4556865692138672,0.0 -44100,1.6771678,2.672813,,,,,,,,,,,,,, -44200,1.8349141,2.6200862,,,,,,,,,,,,,, -44300,1.312107,4.1739364,,,,,,,,,,,,,, -44400,1.340236,3.9967132,,,,,,,,,,,,,, -44500,1.8712232,2.705751,,,,,,,,,,,,,, -44600,1.8143125,2.7385225,,,,,,,,,,,,,, -44700,1.4101673,4.128668,,,,,,,,,,,,,, -44800,1.7311177,2.6120915,,,,,,,,,,,,,, -44900,1.7535814,2.9122882,,,,,,,,,,,,,, -45000,1.3809211,5.2097855,,,,,,,,,,,,,, -45006,,,0.6237499713897705,1.5927554368972778,0.568619966506958,1.843924641609192,50000.0,0.4508000314235687,2.4984617233276367,10000.0,20623.2698571682,22488.867975711823,20623.2698571682,1861.775664567948,1.4894487857818604,0.0 -45100,1.7643344,2.685584,,,,,,,,,,,,,, -45200,1.2929708,5.306467,,,,,,,,,,,,,, -45300,1.7211897,3.1586993,,,,,,,,,,,,,, -45400,1.7601789,2.695308,,,,,,,,,,,,,, -45500,1.5168434,3.8420286,,,,,,,,,,,,,, -45600,1.4016409,3.563939,,,,,,,,,,,,,, -45700,1.6992812,2.9026427,,,,,,,,,,,,,, -45800,1.9318997,2.6761818,,,,,,,,,,,,,, -45900,1.7121389,2.679317,,,,,,,,,,,,,, -45926,,,0.605175793170929,1.694035887718201,0.5661999583244324,1.873929023742676,50000.0,0.4462000131607055,2.526916027069092,10000.0,21043.675313472748,22944.59355187416,21043.675313472748,1897.01748919487,1.5201151371002195,0.0 -46000,1.8812886,2.826785,,,,,,,,,,,,,, -46100,1.748824,2.6023633,,,,,,,,,,,,,, -46200,1.7428523,3.1079493,,,,,,,,,,,,,, -46300,1.697633,2.763166,,,,,,,,,,,,,, -46400,1.357653,5.1892776,,,,,,,,,,,,,, -46500,1.7753474,2.7212276,,,,,,,,,,,,,, -46600,1.9002596,2.662034,,,,,,,,,,,,,, -46700,1.7311958,2.5955276,,,,,,,,,,,,,, -46800,1.2697247,4.694628,,,,,,,,,,,,,, -46847,,,0.6008007526397705,1.7134603261947632,0.5633000135421753,1.8950095176696773,50000.0,0.4472000300884247,2.5412986278533936,10000.0,21463.90951180458,23397.665759801865,21463.90951180458,1929.7773234844208,1.5505378246307373,0.0 -46900,1.6403211,3.6331525,,,,,,,,,,,,,, -47000,1.2420422,4.762374,,,,,,,,,,,,,, -47100,1.8587717,2.8350677,,,,,,,,,,,,,, -47200,1.8174956,5.1990366,,,,,,,,,,,,,, -47300,1.3595085,4.5918894,,,,,,,,,,,,,, -47400,2.1995776,2.7456295,,,,,,,,,,,,,, -47500,1.2516439,5.3406324,,,,,,,,,,,,,, -47600,1.554926,3.0725553,,,,,,,,,,,,,, -47700,1.4404734,5.121511,,,,,,,,,,,,,, -47769,,,0.6170703172683716,1.6229615211486816,0.5671600103378296,1.8515721559524536,50000.0,0.4527000188827514,2.5124051570892334,10000.0,21884.20440530777,23856.79626774788,21884.20440530777,1968.531141042709,1.584803342819214,0.0 -47800,1.7499622,2.7826772,,,,,,,,,,,,,, -47900,1.7490466,2.7712855,,,,,,,,,,,,,, -48000,1.3621988,5.1083355,,,,,,,,,,,,,, -48100,1.5381067,3.649894,,,,,,,,,,,,,, -48200,1.8653184,3.0598032,,,,,,,,,,,,,, -48300,1.2379045,4.5517917,,,,,,,,,,,,,, -48400,1.5259895,2.8598926,,,,,,,,,,,,,, -48500,1.8424655,2.6228118,,,,,,,,,,,,,, -48600,1.7894216,2.704744,,,,,,,,,,,,,, -48691,,,0.6227929592132568,1.575631856918335,0.5725600123405457,1.808995008468628,50000.0,0.4574000239372253,2.444520711898804,10000.0,22304.507378816605,24315.673028230667,22304.507378816605,2007.020524263382,1.6213884353637695,0.0 -48700,1.9520469,2.649075,,,,,,,,,,,,,, -48800,1.4652082,4.9023614,,,,,,,,,,,,,, -48900,1.6319119,2.7879705,,,,,,,,,,,,,, -49000,1.7615613,2.6253734,,,,,,,,,,,,,, -49100,1.7723187,2.744822,,,,,,,,,,,,,, -49200,1.7691026,3.3376107,,,,,,,,,,,,,, -49300,1.5169401,4.763294,,,,,,,,,,,,,, -49400,1.7539883,2.772753,,,,,,,,,,,,,, -49500,1.8249904,2.7576818,,,,,,,,,,,,,, -49600,2.0418832,2.7803361,,,,,,,,,,,,,, -49612,,,0.6174609065055847,1.6112401485443115,0.573639988899231,1.8005505800247192,50000.0,0.4617000222206116,2.452442646026612,10000.0,22724.5529756546,24775.178783893585,22724.5529756546,2046.3978426456447,1.657334804534912,0.0 -49700,1.6900215,2.7175786,,,,,,,,,,,,,, -49800,1.6240007,2.522776,,,,,,,,,,,,,, -49900,1.8318832,2.6364255,,,,,,,,,,,,,, -50000,1.6994748,2.6059728,,,,,,,,,,,,,, -50100,1.2471495,4.455407,,,,,,,,,,,,,, -50200,1.7289069,2.5246973,,,,,,,,,,,,,, -50300,1.7953156,2.6804452,,,,,,,,,,,,,, -50400,1.8871411,2.7006445,,,,,,,,,,,,,, -50500,1.6336236,2.8709898,,,,,,,,,,,,,, -50533,,,0.62060546875,1.6004736423492432,0.5769400000572205,1.7959643602371216,50000.0,0.4663000106811523,2.434957981109619,10000.0,23144.769639730453,25230.020445346832,23144.769639730453,2080.936208486557,1.6951377391815186,0.0 -50600,1.802958,2.9476316,,,,,,,,,,,,,, -50700,1.7362013,3.0233366,,,,,,,,,,,,,, -50800,1.8333849,2.7508817,,,,,,,,,,,,,, -50900,1.8142272,2.5869257,,,,,,,,,,,,,, -51000,1.3447552,5.215791,,,,,,,,,,,,,, -51100,1.7336866,2.5723214,,,,,,,,,,,,,, -51200,1.6297939,3.1778538,,,,,,,,,,,,,, -51300,1.9491234,2.6658168,,,,,,,,,,,,,, -51400,1.7962449,2.576867,,,,,,,,,,,,,, -51453,,,0.6434960961341858,1.4826300144195557,0.5772199630737305,1.7871181964874268,50000.0,0.4559000134468078,2.4530246257781982,10000.0,23564.7326142788,25689.36802005768,23564.7326142788,2120.237253427505,1.7319800853729248,0.0 -51500,1.9344012,2.6629474,,,,,,,,,,,,,, -51600,1.4947398,3.529926,,,,,,,,,,,,,, -51700,1.7540852,2.612679,,,,,,,,,,,,,, -51800,1.8190314,2.7276833,,,,,,,,,,,,,, -51900,2.065338,2.491509,,,,,,,,,,,,,, -52000,1.674654,2.583868,,,,,,,,,,,,,, -52100,1.5050836,4.7092304,,,,,,,,,,,,,, -52200,1.8720781,2.6640234,,,,,,,,,,,,,, -52300,1.3702543,3.4016156,,,,,,,,,,,,,, -52375,,,0.6228905916213989,1.5904203653335571,0.5809999704360962,1.7867530584335327,50000.0,0.4628000259399414,2.418676376342773,10000.0,23984.924451828003,26149.58564400673,23984.924451828003,2160.178407430649,1.7693891525268557,0.0 -52400,1.8341156,2.664635,,,,,,,,,,,,,, -52500,1.884484,2.5738544,,,,,,,,,,,,,, -52600,1.674323,2.5573552,,,,,,,,,,,,,, -52700,1.8357595,2.6686113,,,,,,,,,,,,,, -52800,1.705389,3.3891966,,,,,,,,,,,,,, -52900,1.75393,2.6267838,,,,,,,,,,,,,, -53000,1.8252562,2.6608155,,,,,,,,,,,,,, -53100,2.0286505,2.5946307,,,,,,,,,,,,,, -53200,2.037279,2.741539,,,,,,,,,,,,,, -53295,,,0.6255077719688416,1.5740492343902588,0.5805599689483643,1.7782132625579834,50000.0,0.4629000127315521,2.41633677482605,10000.0,24404.92098903656,26604.52316379547,24404.92098903656,2195.0370230674744,1.8044648170471191,0.0 -53300,1.4176277,5.2941747,,,,,,,,,,,,,, -53400,1.6245639,2.719456,,,,,,,,,,,,,, -53500,1.5481308,3.279276,,,,,,,,,,,,,, -53600,1.3332711,4.548147,,,,,,,,,,,,,, -53700,1.6191839,2.4463744,,,,,,,,,,,,,, -53800,1.7635785,2.5249932,,,,,,,,,,,,,, -53900,1.8152267,2.7779913,,,,,,,,,,,,,, -54000,1.7192594,4.815122,,,,,,,,,,,,,, -54100,1.3580732,4.8329115,,,,,,,,,,,,,, -54200,1.6750209,2.7238464,,,,,,,,,,,,,, -54218,,,0.6314257383346558,1.5377583503723145,0.5790599584579468,1.7804241180419922,50000.0,0.4669000208377838,2.433056354522705,10000.0,24825.276747226715,27060.22543120384,24825.276747226715,2230.295460224152,1.8441081047058103,0.0 -54300,1.2941904,5.243781,,,,,,,,,,,,,, -54400,2.0640326,2.6510208,,,,,,,,,,,,,, -54500,1.7417411,2.6388757,,,,,,,,,,,,,, -54600,1.6241279,2.4212365,,,,,,,,,,,,,, -54700,1.6747293,3.768959,,,,,,,,,,,,,, -54800,1.5546799,5.3869023,,,,,,,,,,,,,, -54900,1.7946122,2.661759,,,,,,,,,,,,,, -55000,1.7773666,2.5701761,,,,,,,,,,,,,, -55100,1.5657176,3.6548228,,,,,,,,,,,,,, -55138,,,0.6166015267372131,1.6154673099517822,0.5781399607658386,1.7982147932052612,50000.0,0.4554000198841095,2.452510356903076,10000.0,25245.21862053871,27518.237367630005,25245.21862053871,2268.2810649871826,1.8811235427856443,0.0 -55200,1.7298162,3.3849118,,,,,,,,,,,,,, -55300,1.5441601,5.2161226,,,,,,,,,,,,,, -55400,1.8906318,2.6129007,,,,,,,,,,,,,, -55500,1.6416328,3.3204744,,,,,,,,,,,,,, -55600,1.8584083,2.588633,,,,,,,,,,,,,, -55700,1.7093773,3.0359864,,,,,,,,,,,,,, -55800,1.5524894,4.954461,,,,,,,,,,,,,, -55900,1.7249298,2.5721555,,,,,,,,,,,,,, -56000,1.4776195,4.545588,,,,,,,,,,,,,, -56056,,,0.6266992092132568,1.5533452033996582,0.5848999619483948,1.7521530389785769,50000.0,0.4698000252246856,2.410242795944214,10000.0,25665.102162361145,27973.03744482994,25665.102162361145,2302.746966123581,2.2839317321777344,0.0 -56100,1.7835402,2.7205088,,,,,,,,,,,,,, -56200,1.7937279,2.4630103,,,,,,,,,,,,,, -56300,1.8743659,2.646914,,,,,,,,,,,,,, -56400,1.899875,2.5819616,,,,,,,,,,,,,, -56500,1.4865056,5.01917,,,,,,,,,,,,,, -56600,1.8677706,3.090227,,,,,,,,,,,,,, -56700,1.7346113,2.6328447,,,,,,,,,,,,,, -56800,1.5671197,2.8956625,,,,,,,,,,,,,, -56900,1.4843754,4.160383,,,,,,,,,,,,,, -56978,,,0.6295312643051147,1.5653393268585205,0.5822399854660034,1.799446702003479,50000.0,0.463200032711029,2.456258773803711,10000.0,26085.34273672104,28431.98593950272,26085.34273672104,2341.369970321656,2.32083511352539,0.0 -57000,1.5301974,4.1235933,,,,,,,,,,,,,, -57100,1.5412791,3.1069715,,,,,,,,,,,,,, -57200,1.3554642,4.8070884,,,,,,,,,,,,,, -57300,1.6058639,3.1660585,,,,,,,,,,,,,, -57400,1.701415,2.780918,,,,,,,,,,,,,, -57500,1.8703585,2.8477616,,,,,,,,,,,,,, -57600,1.3996536,3.5830863,,,,,,,,,,,,,, -57700,1.4024065,4.010513,,,,,,,,,,,,,, -57800,1.661684,2.6461177,,,,,,,,,,,,,, -57898,,,0.6297656297683716,1.5668786764144895,0.5859599709510803,1.7582910060882568,50000.0,0.4683000147342682,2.4005134105682373,10000.0,26505.34909033776,28890.808502435684,26505.34909033776,2380.102923631668,2.3555450439453125,0.0 -57900,1.8337505,2.7228525,,,,,,,,,,,,,, -58000,1.4314976,4.1499825,,,,,,,,,,,,,, -58100,1.7883471,2.667278,,,,,,,,,,,,,, -58200,1.6487827,4.655243,,,,,,,,,,,,,, -58300,1.4823501,3.5258398,,,,,,,,,,,,,, -58400,2.0324342,2.7888336,,,,,,,,,,,,,, -58500,1.8230752,3.1050272,,,,,,,,,,,,,, -58600,1.4162898,4.300866,,,,,,,,,,,,,, -58700,1.9352666,2.595848,,,,,,,,,,,,,, -58800,1.526626,4.0760336,,,,,,,,,,,,,, -58817,,,0.6303515434265137,1.5301530361175537,0.590399980545044,1.723675012588501,50000.0,0.4696000218391418,2.3985865116119385,10000.0,26925.5447409153,29345.4977645874,26925.5447409153,2414.5152776241302,2.388858795166016,0.0 -58900,1.8606955,2.5956166,,,,,,,,,,,,,, -59000,1.7058078,2.779985,,,,,,,,,,,,,, -59100,1.7071161,2.7213998,,,,,,,,,,,,,, -59200,2.0323617,2.590609,,,,,,,,,,,,,, -59300,2.0599244,2.754059,,,,,,,,,,,,,, -59400,1.6657468,2.7863765,,,,,,,,,,,,,, -59500,1.9212996,3.0780144,,,,,,,,,,,,,, -59600,1.7279739,3.8106947,,,,,,,,,,,,,, -59700,2.0360157,2.5396698,,,,,,,,,,,,,, -59735,,,0.6366991996765137,1.510748267173767,0.590719997882843,1.7339236736297607,50000.0,0.4668000340461731,2.381608486175537,10000.0,27345.576851844788,29804.49097251892,27345.576851844788,2453.3868992328644,2.431230306625366,0.0 -59800,1.5921535,5.2273197,,,,,,,,,,,,,, -59900,1.880382,2.5169077,,,,,,,,,,,,,, -60000,2.018896,2.5567136,,,,,,,,,,,,,, -60100,1.7507463,2.6039872,,,,,,,,,,,,,, -60200,1.7390727,2.4873178,,,,,,,,,,,,,, -60300,1.8418514,2.4811287,,,,,,,,,,,,,, -60400,1.9473469,2.5037298,,,,,,,,,,,,,, -60500,1.6058666,4.9600706,,,,,,,,,,,,,, -60600,1.5921644,3.0846064,,,,,,,,,,,,,, -60655,,,0.661816418170929,1.413130760192871,0.5918799638748169,1.732790231704712,50000.0,0.4716000258922577,2.3704307079315186,10000.0,27765.77965736389,30265.760466575623,27765.77965736389,2494.3704164028168,2.467085599899292,0.0 -60700,1.6336781,2.6388693,,,,,,,,,,,,,, -60800,2.1459758,2.665377,,,,,,,,,,,,,, -60900,1.8608335,2.853309,,,,,,,,,,,,,, -61000,1.8071749,2.5425072,,,,,,,,,,,,,, -61100,1.4222441,5.100455,,,,,,,,,,,,,, -61200,1.9788948,2.5330641,,,,,,,,,,,,,, -61300,1.7699323,3.0321703,,,,,,,,,,,,,, -61400,1.6878299,2.8958068,,,,,,,,,,,,,, -61500,1.4481946,5.0051346,,,,,,,,,,,,,, -61575,,,0.6309961080551147,1.5477447509765625,0.5902799963951111,1.7401505708694458,50000.0,0.4678000211715698,2.399973630905152,10000.0,28186.06973552704,30719.761566877365,28186.06973552704,2527.9952044487,2.5055477619171143,0.0 -61600,2.0301692,2.6106334,,,,,,,,,,,,,, -61700,1.6616168,3.069552,,,,,,,,,,,,,, -61800,1.6477762,2.741683,,,,,,,,,,,,,, -61900,1.7493131,3.0906017,,,,,,,,,,,,,, -62000,1.7809364,2.5038133,,,,,,,,,,,,,, -62100,1.4533305,4.497801,,,,,,,,,,,,,, -62200,1.4275887,3.841686,,,,,,,,,,,,,, -62300,1.6454489,2.985225,,,,,,,,,,,,,, -62400,1.4721056,5.1374598,,,,,,,,,,,,,, -62494,,,0.6387304663658142,1.5181124210357666,0.588979959487915,1.735752820968628,50000.0,0.4692000150680542,2.3889834880828857,10000.0,28606.19539070129,31179.79212284088,28606.19539070129,2567.8129115104675,2.54514741897583,0.0 -62500,1.6906908,2.5018659,,,,,,,,,,,,,, -62600,1.7119275,2.9192872,,,,,,,,,,,,,, -62700,1.5393269,3.7966952,,,,,,,,,,,,,, -62800,1.8070043,2.8059192,,,,,,,,,,,,,, -62900,1.7624447,2.4902034,,,,,,,,,,,,,, -63000,1.4419867,4.8157825,,,,,,,,,,,,,, -63100,2.0297005,2.6265583,,,,,,,,,,,,,, -63200,1.6220062,4.2983522,,,,,,,,,,,,,, -63300,1.7290068,2.7402537,,,,,,,,,,,,,, -63400,1.8156002,2.579958,,,,,,,,,,,,,, -63414,,,0.6572265625,1.420607089996338,0.5945599675178528,1.7118746042251587,50000.0,0.4750000238418579,2.362293004989624,10000.0,29026.439562797543,31640.54657483101,29026.439562797543,2608.2360076904297,2.584966897964477,0.0 -63500,1.8683169,2.5587544,,,,,,,,,,,,,, -63600,1.782669,2.4193332,,,,,,,,,,,,,, -63700,1.9246993,2.5569706,,,,,,,,,,,,,, -63800,1.8748428,2.7717519,,,,,,,,,,,,,, -63900,2.0854876,2.5642948,,,,,,,,,,,,,, -64000,1.8926383,2.5790486,,,,,,,,,,,,,, -64100,1.5927252,3.6016593,,,,,,,,,,,,,, -64200,1.7649666,4.0351167,,,,,,,,,,,,,, -64300,1.7517914,3.2374802,,,,,,,,,,,,,, -64334,,,0.6352148056030273,1.5222920179367063,0.5979799628257751,1.6968483924865725,50000.0,0.4806000292301178,2.3335988521575928,10000.0,29446.779339313507,32099.92324185372,29446.779339313507,2647.1877439022064,2.622751951217652,0.0 -64400,1.8891374,2.6202276,,,,,,,,,,,,,, -64500,1.8440022,2.4671965,,,,,,,,,,,,,, -64600,1.5651673,4.4567957,,,,,,,,,,,,,, -64700,1.4571029,3.526403,,,,,,,,,,,,,, -64800,1.8177075,2.7562358,,,,,,,,,,,,,, -64900,1.8085886,2.5289054,,,,,,,,,,,,,, -65000,1.8638542,2.556211,,,,,,,,,,,,,, -65100,1.7175392,3.3153248,,,,,,,,,,,,,, -65200,1.816653,2.2699494,,,,,,,,,,,,,, -65257,,,0.6451562643051147,1.4754486083984375,0.5941999554634094,1.7001410722732544,50000.0,0.4757000207901001,2.3748252391815186,10000.0,29867.059475898743,32557.672868967056,29867.059475898743,2684.573234319687,2.658973693847656,0.0 -65300,1.6677842,3.5544975,,,,,,,,,,,,,, -65400,2.004676,2.6495833,,,,,,,,,,,,,, -65500,1.4536623,4.2246523,,,,,,,,,,,,,, -65600,2.0952857,2.6751864,,,,,,,,,,,,,, -65700,2.0687494,2.6568482,,,,,,,,,,,,,, -65800,2.0856328,2.6183043,,,,,,,,,,,,,, -65900,1.5775362,2.878665,,,,,,,,,,,,,, -66000,1.6508365,3.1866503,,,,,,,,,,,,,, -66100,1.7054453,5.185795,,,,,,,,,,,,,, -66178,,,0.6478906273841858,1.469024658203125,0.5952199697494507,1.713046669960022,50000.0,0.4808000326156616,2.354539155960083,10000.0,30287.294951438904,33016.23513197899,30287.294951438904,2722.811019182205,2.700793743133545,0.0 -66200,1.5618302,4.526864,,,,,,,,,,,,,, -66300,1.8679386,2.810002,,,,,,,,,,,,,, -66400,1.6072773,4.1737766,,,,,,,,,,,,,, -66500,1.7725817,4.4136257,,,,,,,,,,,,,, -66600,1.7825869,2.50702,,,,,,,,,,,,,, -66700,1.751404,2.8712857,,,,,,,,,,,,,, -66800,1.684444,2.9551146,,,,,,,,,,,,,, -66900,1.5699741,4.587482,,,,,,,,,,,,,, -67000,1.8599735,2.6917164,,,,,,,,,,,,,, -67099,,,0.6410937309265137,1.491320013999939,0.5977799892425537,1.6954699754714966,50000.0,0.4843000173568725,2.3312244415283203,10000.0,30707.29386544228,33478.465988874435,30707.29386544228,2764.954957485199,2.7407922744750977,0.0 -67100,1.5696813,3.7386835,,,,,,,,,,,,,, -67200,1.6688517,3.150622,,,,,,,,,,,,,, -67300,1.5022895,4.7153506,,,,,,,,,,,,,, -67400,1.472503,4.362477,,,,,,,,,,,,,, -67500,1.8519949,2.533012,,,,,,,,,,,,,, -67600,1.6279486,3.2229266,,,,,,,,,,,,,, -67700,2.0839176,2.4558744,,,,,,,,,,,,,, -67800,2.06943,2.5502155,,,,,,,,,,,,,, -67900,1.814714,4.875414,,,,,,,,,,,,,, -68000,1.917701,2.4798188,,,,,,,,,,,,,, -68020,,,0.6380273103713989,1.4931342601776123,0.5981599688529968,1.6907784938812256,50000.0,0.4786000251770019,2.350700855255127,10000.0,31127.23282170296,33934.169929265976,31127.23282170296,2800.633416414261,2.7796566486358643,0.0 -68100,2.0069559,2.5712867,,,,,,,,,,,,,, -68200,1.8214071,2.8221893,,,,,,,,,,,,,, -68300,1.9780395,2.638453,,,,,,,,,,,,,, -68400,1.9669147,2.571377,,,,,,,,,,,,,, -68500,1.9874707,2.7405171,,,,,,,,,,,,,, -68600,1.7152529,3.1951458,,,,,,,,,,,,,, -68700,1.8280756,2.4545674,,,,,,,,,,,,,, -68800,2.0807834,2.5445144,,,,,,,,,,,,,, -68900,1.9116834,2.5092463,,,,,,,,,,,,,, -68942,,,0.6507812142372131,1.4588117599487305,0.5990399718284607,1.6879316568374634,50000.0,0.4767000079154968,2.359503746032715,10000.0,31547.49372458458,34395.838116168976,31547.49372458458,2841.955982208252,2.814785957336426,0.0 -69000,1.542604,4.202791,,,,,,,,,,,,,, -69100,1.9703671,2.6521678,,,,,,,,,,,,,, -69200,1.6774068,2.8036945,,,,,,,,,,,,,, -69300,1.3618968,3.993413,,,,,,,,,,,,,, -69400,1.9549978,2.5183268,,,,,,,,,,,,,, -69500,1.7044089,2.654102,,,,,,,,,,,,,, -69600,1.5215425,4.151017,,,,,,,,,,,,,, -69700,2.1727512,2.6129272,,,,,,,,,,,,,, -69800,2.0272405,2.616295,,,,,,,,,,,,,, -69864,,,0.654589831829071,1.434368371963501,0.6043199896812439,1.666023850440979,50000.0,0.482200026512146,2.3192484378814697,10000.0,31967.64450263977,34855.22176671028,31967.64450263977,2881.1043269634247,2.851184606552124,0.0 -69900,1.6566535,2.714491,,,,,,,,,,,,,, -70000,1.7709914,2.9691768,,,,,,,,,,,,,, -70100,2.037851,2.6643,,,,,,,,,,,,,, -70200,2.115419,2.712908,,,,,,,,,,,,,, -70300,1.7530663,2.8946364,,,,,,,,,,,,,, -70400,1.4808242,3.913106,,,,,,,,,,,,,, -70500,1.7929323,2.572102,,,,,,,,,,,,,, -70600,1.6178106,3.4424667,,,,,,,,,,,,,, -70700,1.9347616,2.4258509,,,,,,,,,,,,,, -70784,,,0.6522656083106995,1.430999517440796,0.6041399836540222,1.6534889936447144,50000.0,0.490200012922287,2.3093273639678955,10000.0,32387.930511713028,35306.4135351181,32387.930511713028,2911.928690671921,2.8854973316192627,0.0 -70800,1.8862225,2.5170033,,,,,,,,,,,,,, -70900,1.7395391,2.483156,,,,,,,,,,,,,, -71000,1.8598434,2.4890144,,,,,,,,,,,,,, -71100,2.0369043,2.3844075,,,,,,,,,,,,,, -71200,1.6638073,3.4910345,,,,,,,,,,,,,, -71300,1.8513062,2.5145714,,,,,,,,,,,,,, -71400,1.8347075,2.5216317,,,,,,,,,,,,,, -71500,1.6166178,3.147541,,,,,,,,,,,,,, -71600,2.2084837,2.530546,,,,,,,,,,,,,, -71700,1.9462242,2.4337769,,,,,,,,,,,,,, -71705,,,0.657031238079071,1.4209951162338257,0.6064199805259705,1.661571025848389,50000.0,0.4886000156402588,2.3075990676879883,10000.0,32808.145381212234,35762.34394454956,32808.145381212234,2947.54905629158,2.9329519271850586,0.0 -71800,1.706659,3.121519,,,,,,,,,,,,,, -71900,1.7957631,2.725106,,,,,,,,,,,,,, -72000,1.468355,4.5963144,,,,,,,,,,,,,, -72100,1.8261753,2.5116816,,,,,,,,,,,,,, -72200,1.9639188,2.520232,,,,,,,,,,,,,, -72300,1.6630805,3.3056617,,,,,,,,,,,,,, -72400,1.7736815,3.321743,,,,,,,,,,,,,, -72500,1.6494317,4.113865,,,,,,,,,,,,,, -72600,2.1213348,2.612999,,,,,,,,,,,,,, -72625,,,0.6685937643051147,1.38480544090271,0.6049599647521973,1.6850557327270508,50000.0,0.4834000170230865,2.3566348552703857,10000.0,33228.30744147301,36219.24794006348,33228.30744147301,2984.2034389972687,2.973165988922119,0.0 -72700,1.8775172,3.0893083,,,,,,,,,,,,,, -72800,1.8928893,2.305183,,,,,,,,,,,,,, -72900,1.9969386,2.423386,,,,,,,,,,,,,, -73000,2.1127717,2.424318,,,,,,,,,,,,,, -73100,1.7989494,2.9792488,,,,,,,,,,,,,, -73200,1.4856545,4.7517,,,,,,,,,,,,,, -73300,1.5011687,3.8075395,,,,,,,,,,,,,, -73400,1.6503149,3.9519134,,,,,,,,,,,,,, -73500,1.6817098,3.1143937,,,,,,,,,,,,,, -73544,,,0.6450976133346558,1.466753005981445,0.6068800091743469,1.662022590637207,50000.0,0.4855000376701355,2.3125555515289307,10000.0,33648.37844085693,36678.69011569023,33648.37844085693,3023.486344099045,3.0142745971679688,0.0 -73600,1.7165041,4.352779,,,,,,,,,,,,,, -73700,2.0452373,2.5836525,,,,,,,,,,,,,, -73800,2.0831869,2.586163,,,,,,,,,,,,,, -73900,1.5053294,4.3179245,,,,,,,,,,,,,, -74000,2.03994,2.494548,,,,,,,,,,,,,, -74100,1.49196,3.9409666,,,,,,,,,,,,,, -74200,1.8612908,2.7352462,,,,,,,,,,,,,, -74300,1.5896488,4.583774,,,,,,,,,,,,,, -74400,1.609502,4.2417336,,,,,,,,,,,,,, -74464,,,0.6503515243530273,1.4331283569335938,0.6075199842453003,1.6410701274871826,50000.0,0.4845000207424164,2.3070414066314697,10000.0,34068.383913517,37139.40186858177,34068.383913517,3064.1061642169952,3.052863836288452,0.0 -74500,2.2577646,2.5497658,,,,,,,,,,,,,, -74600,1.967713,2.588743,,,,,,,,,,,,,, -74700,2.288565,2.6270785,,,,,,,,,,,,,, -74800,1.397335,4.5781803,,,,,,,,,,,,,, -74900,1.7215093,2.8330882,,,,,,,,,,,,,, -75000,1.5081252,5.1686106,,,,,,,,,,,,,, -75100,2.008546,2.3771615,,,,,,,,,,,,,, -75200,2.0065835,2.3330317,,,,,,,,,,,,,, -75300,1.9939791,2.518271,,,,,,,,,,,,,, -75386,,,0.6625195145606995,1.3859827518463137,0.6066799759864807,1.6424832344055176,50000.0,0.4878000319004059,2.278662919998169,10000.0,34488.60297369957,37595.3496427536,34488.60297369957,3099.742102622986,3.097740411758423,0.0 -75400,1.7590744,2.8587618,,,,,,,,,,,,,, -75500,1.915951,2.7728415,,,,,,,,,,,,,, -75600,1.8240256,2.509134,,,,,,,,,,,,,, -75700,1.7089778,3.2561898,,,,,,,,,,,,,, -75800,1.7650621,2.4140077,,,,,,,,,,,,,, -75900,1.7733837,3.7638917,,,,,,,,,,,,,, -76000,1.7736378,4.590227,,,,,,,,,,,,,, -76100,1.5613253,3.6350145,,,,,,,,,,,,,, -76200,1.9070667,2.3659568,,,,,,,,,,,,,, -76300,1.9533796,4.6952643,,,,,,,,,,,,,, -76306,,,0.6537500023841858,1.4322412014007568,0.6104399561882019,1.6254669427871704,50000.0,0.4958000183105469,2.2743124961853027,10000.0,34908.645033836365,38057.16771769524,34908.645033836365,3141.430042743683,3.137625217437744,0.0 -76400,1.6752393,5.134469,,,,,,,,,,,,,, -76500,1.920134,2.862034,,,,,,,,,,,,,, -76600,1.89162,5.086731,,,,,,,,,,,,,, -76700,1.7046846,3.7448082,,,,,,,,,,,,,, -76800,2.2174325,2.5140796,,,,,,,,,,,,,, -76900,1.9572712,2.8874848,,,,,,,,,,,,,, -77000,1.9266709,2.460301,,,,,,,,,,,,,, -77100,1.900803,2.3279943,,,,,,,,,,,,,, -77200,1.9942498,2.3915112,,,,,,,,,,,,,, -77227,,,0.6592382788658142,1.3970091342926023,0.615399956703186,1.60317063331604,50000.0,0.484900027513504,2.277125358581543,10000.0,35328.58205103874,38514.08833122253,35328.58205103874,3178.3272848129272,3.1766517162323,0.0 -77300,2.1579716,2.56081,,,,,,,,,,,,,, -77400,1.690212,4.566803,,,,,,,,,,,,,, -77500,2.0071049,2.5401456,,,,,,,,,,,,,, -77600,1.9768524,2.5738537,,,,,,,,,,,,,, -77700,1.6100695,4.813218,,,,,,,,,,,,,, -77800,1.7578743,4.9150324,,,,,,,,,,,,,, -77900,1.7375301,4.4185696,,,,,,,,,,,,,, -78000,1.5389014,5.1393776,,,,,,,,,,,,,, -78100,2.1638033,2.3559327,,,,,,,,,,,,,, -78150,,,0.6634374856948853,1.390366792678833,0.6144999861717224,1.6266072988510132,50000.0,0.4933000206947326,2.281214952468872,10000.0,35748.85032916069,38976.1696164608,35748.85032916069,3220.056258201599,3.212582588195801,0.0 -78200,1.9137958,2.4654226,,,,,,,,,,,,,, -78300,1.5704964,3.614494,,,,,,,,,,,,,, -78400,1.811446,4.933664,,,,,,,,,,,,,, -78500,1.9410955,2.5064397,,,,,,,,,,,,,, -78600,1.8058106,2.7899318,,,,,,,,,,,,,, -78700,1.6429762,2.9653633,,,,,,,,,,,,,, -78800,2.0908675,2.3239238,,,,,,,,,,,,,, -78900,2.2113614,2.5462651,,,,,,,,,,,,,, -79000,1.8858343,2.4191463,,,,,,,,,,,,,, -79071,,,0.6562694907188416,1.445751428604126,0.614139974117279,1.6321967840194702,50000.0,0.488500028848648,2.2790701389312744,10000.0,36169.0088224411,39435.61829423904,36169.0088224411,3259.250541448593,3.260847330093384,0.0 -79100,1.8873912,2.521522,,,,,,,,,,,,,, -79200,1.9302106,2.6020732,,,,,,,,,,,,,, -79300,1.7950977,2.404815,,,,,,,,,,,,,, -79400,1.574951,3.74734,,,,,,,,,,,,,, -79500,1.8945379,3.6985967,,,,,,,,,,,,,, -79600,1.8884245,2.4623625,,,,,,,,,,,,,, -79700,2.1672254,2.4667723,,,,,,,,,,,,,, -79800,1.4537332,4.9626465,,,,,,,,,,,,,, -79900,2.0588105,2.5164316,,,,,,,,,,,,,, -79992,,,0.6576757431030273,1.4273301362991333,0.6126799583435059,1.6328654289245603,50000.0,0.492900013923645,2.28087854385376,10000.0,36589.23210167885,39891.485827207565,36589.23210167885,3294.8051438331604,3.3025565147399902,0.0 -80000,1.7315704,5.0140696,,,,,,,,,,,,,, -80100,1.8658794,2.3083954,,,,,,,,,,,,,, -80200,1.61633,5.02656,,,,,,,,,,,,,, -80300,2.1006157,2.5180156,,,,,,,,,,,,,, -80400,1.906671,2.393994,,,,,,,,,,,,,, -80500,2.0619297,2.4234443,,,,,,,,,,,,,, -80600,1.6181628,4.057302,,,,,,,,,,,,,, -80700,1.9628634,2.69051,,,,,,,,,,,,,, -80800,2.074679,2.490455,,,,,,,,,,,,,, -80900,1.9399091,2.5216846,,,,,,,,,,,,,, -80913,,,0.674023449420929,1.3578218221664429,0.6184200048446655,1.6004390716552734,50000.0,0.4927000105381012,2.2475414276123047,10000.0,37009.35448670387,40351.29439616203,37009.35448670387,3334.398421525955,3.347494840621948,0.0 -81000,1.9332173,2.7168076,,,,,,,,,,,,,, -81100,2.0864613,2.4019108,,,,,,,,,,,,,, -81200,2.360724,2.3855488,,,,,,,,,,,,,, -81300,2.1244504,2.302267,,,,,,,,,,,,,, -81400,1.5562012,3.6116579,,,,,,,,,,,,,, -81500,1.8604755,2.6113198,,,,,,,,,,,,,, -81600,1.9310683,2.345955,,,,,,,,,,,,,, -81700,1.9388416,2.7309902,,,,,,,,,,,,,, -81800,1.596118,4.5346174,,,,,,,,,,,,,, -81833,,,0.6919335722923279,1.269242525100708,0.6193400025367737,1.5890880823135376,50000.0,0.4976000189781189,2.232736349105835,10000.0,37429.34026837349,40809.9657497406,37429.34026837349,3372.996278524399,3.3874971866607666,0.0 -81900,1.9214938,2.3879676,,,,,,,,,,,,,, -82000,2.0623636,2.4564252,,,,,,,,,,,,,, -82100,2.1319118,2.4564924,,,,,,,,,,,,,, -82200,2.02519,2.9099202,,,,,,,,,,,,,, -82300,2.2129917,2.4309878,,,,,,,,,,,,,, -82400,1.9659946,2.4709568,,,,,,,,,,,,,, -82500,1.9392589,2.4989321,,,,,,,,,,,,,, -82600,2.011915,2.3475564,,,,,,,,,,,,,, -82700,2.0806277,2.4186292,,,,,,,,,,,,,, -82750,,,0.6674609184265137,1.4016146659851074,0.6218599677085876,1.6021045446395874,50000.0,0.4955000281333923,2.2633657455444336,10000.0,37849.2662627697,41267.85855412483,37849.2662627697,3410.873600244522,3.429560661315918,0.0 -82800,1.8455689,3.9080203,,,,,,,,,,,,,, -82900,2.034308,2.5496519,,,,,,,,,,,,,, -83000,1.8016534,2.7168832,,,,,,,,,,,,,, -83100,2.1825144,2.3393705,,,,,,,,,,,,,, -83200,1.9934307,2.4064653,,,,,,,,,,,,,, -83300,2.229335,2.3826752,,,,,,,,,,,,,, -83400,1.7482866,3.4728298,,,,,,,,,,,,,, -83500,2.047534,2.1598175,,,,,,,,,,,,,, -83600,1.8957065,2.6783545,,,,,,,,,,,,,, -83672,,,0.6716406345367432,1.3420121669769287,0.6239799857139587,1.5581839084625244,50000.0,0.5044000148773193,2.198859214782715,10000.0,38269.307027578354,41726.82884001732,38269.307027578354,3449.7180716991425,3.4665777683258057,0.0 -83700,1.7822416,2.265322,,,,,,,,,,,,,, -83800,2.0793686,2.3896337,,,,,,,,,,,,,, -83900,1.843157,2.935215,,,,,,,,,,,,,, -84000,2.374528,2.4107938,,,,,,,,,,,,,, -84100,1.7416253,4.8032656,,,,,,,,,,,,,, -84200,1.8888731,2.4898512,,,,,,,,,,,,,, -84300,1.6463556,4.996553,,,,,,,,,,,,,, -84400,1.9167006,3.027862,,,,,,,,,,,,,, -84500,1.5811845,4.7542667,,,,,,,,,,,,,, -84591,,,0.6886523365974426,1.2715094089508057,0.6290000081062317,1.5482810735702517,50000.0,0.5059000253677368,2.2042882442474365,10000.0,38689.40639066696,42179.727041482925,38689.40639066696,3482.425269842148,3.50958251953125,0.0 -84600,1.5568199,4.931179,,,,,,,,,,,,,, -84700,1.7090185,3.4616513,,,,,,,,,,,,,, -84800,1.7879968,2.9340088,,,,,,,,,,,,,, -84900,1.7388479,3.4478745,,,,,,,,,,,,,, -85000,1.9624805,2.282858,,,,,,,,,,,,,, -85100,1.7757069,5.048961,,,,,,,,,,,,,, -85200,2.2475984,2.4897134,,,,,,,,,,,,,, -85300,1.8726103,2.9405856,,,,,,,,,,,,,, -85400,2.2630997,2.4308016,,,,,,,,,,,,,, -85500,1.994763,2.3502116,,,,,,,,,,,,,, -85512,,,0.6723827719688416,1.3585656881332395,0.6247400045394897,1.566789627075195,50000.0,0.5028000473976135,2.2121453285217285,10000.0,39109.69956469536,42640.74392175674,39109.69956469536,3523.0634427070618,3.547616958618164,0.0 -85600,1.9070743,3.7034645,,,,,,,,,,,,,, -85700,2.2218745,2.3515527,,,,,,,,,,,,,, -85800,1.8467454,2.981917,,,,,,,,,,,,,, -85900,2.2066212,2.387896,,,,,,,,,,,,,, -86000,2.0550423,2.3720105,,,,,,,,,,,,,, -86100,2.0371752,2.5257916,,,,,,,,,,,,,, -86200,1.9209255,2.922389,,,,,,,,,,,,,, -86300,1.8452294,2.4752007,,,,,,,,,,,,,, -86400,1.8593034,4.6331134,,,,,,,,,,,,,, -86433,,,0.6749609112739563,1.341333031654358,0.6245799660682678,1.5732306241989136,50000.0,0.4973000288009643,2.2271392345428467,10000.0,39529.97727918625,43098.75600481033,39529.97727918625,3560.707655906677,3.589759588241577,0.0 -86500,1.7567519,3.8525295,,,,,,,,,,,,,, -86600,2.0724466,2.1960735,,,,,,,,,,,,,, -86700,2.0453825,2.3691583,,,,,,,,,,,,,, -86800,2.0039036,2.4809318,,,,,,,,,,,,,, -86900,1.9129786,2.8028665,,,,,,,,,,,,,, -87000,1.8922949,4.612268,,,,,,,,,,,,,, -87100,1.933737,2.4595885,,,,,,,,,,,,,, -87200,2.14955,2.632504,,,,,,,,,,,,,, -87300,1.7396837,4.7150397,,,,,,,,,,,,,, -87354,,,0.6795117259025574,1.3415911197662354,0.6220999956130981,1.5983110666275024,50000.0,0.4999000132083893,2.2489187717437744,10000.0,39949.916645765305,43556.20781803131,39949.916645765305,3598.127161026001,3.6342861652374254,0.0 -87400,1.779539,3.2675738,,,,,,,,,,,,,, -87500,1.7158482,3.8736053,,,,,,,,,,,,,, -87600,2.1450522,2.3327017,,,,,,,,,,,,,, -87700,2.2140648,2.480561,,,,,,,,,,,,,, -87800,2.0559123,2.3657784,,,,,,,,,,,,,, -87900,2.0195706,2.2632756,,,,,,,,,,,,,, -88000,2.1891122,2.330458,,,,,,,,,,,,,, -88100,1.9746891,2.5608978,,,,,,,,,,,,,, -88200,2.1143641,2.331324,,,,,,,,,,,,,, -88275,,,0.6721289157867432,1.3381081819534302,0.62909996509552,1.537670016288757,50000.0,0.5071000456809998,2.2031280994415283,10000.0,40369.94181585312,44014.411709070206,40369.94181585312,3636.217653512954,3.6755335330963135,0.0 -88300,1.764491,4.49954,,,,,,,,,,,,,, -88400,2.0067754,4.98139,,,,,,,,,,,,,, -88500,2.0670283,2.3377578,,,,,,,,,,,,,, -88600,2.127299,2.2422798,,,,,,,,,,,,,, -88700,1.9563171,3.567944,,,,,,,,,,,,,, -88800,2.1600823,2.2931166,,,,,,,,,,,,,, -88900,1.9655638,2.3336408,,,,,,,,,,,,,, -89000,1.9176599,2.8310628,,,,,,,,,,,,,, -89100,2.149328,2.3271322,,,,,,,,,,,,,, -89194,,,0.6786718368530273,1.3135088682174685,0.6307799816131592,1.5366266965866089,50000.0,0.5053000450134277,2.1833741664886475,10000.0,40790.010445833206,44473.98690462112,40790.010445833206,3675.636269807816,3.71611762046814,0.0 -89200,2.0387769,2.3461168,,,,,,,,,,,,,, -89300,1.6673203,3.767912,,,,,,,,,,,,,, -89400,2.0025835,4.6670923,,,,,,,,,,,,,, -89500,1.7570741,4.077583,,,,,,,,,,,,,, -89600,1.7327185,3.0125163,,,,,,,,,,,,,, -89700,1.9294844,4.330848,,,,,,,,,,,,,, -89800,2.244802,2.4102597,,,,,,,,,,,,,, -89900,2.0512562,4.8455315,,,,,,,,,,,,,, -90000,1.9316401,3.3162766,,,,,,,,,,,,,, -90100,1.7443442,3.1362746,,,,,,,,,,,,,, -90112,,,0.6904687285423279,1.2547414302825928,0.6344999670982361,1.5105417966842651,50000.0,0.510200023651123,2.1635007858276367,10000.0,41210.27362036705,44933.42651605606,41210.27362036705,3714.726364850998,3.755250453948975,0.0 -90200,1.9137069,3.269003,,,,,,,,,,,,,, -90300,1.8316003,3.603333,,,,,,,,,,,,,, -90400,1.9870921,3.9973936,,,,,,,,,,,,,, -90500,2.30797,2.2429366,,,,,,,,,,,,,, -90600,1.6956534,4.4116135,,,,,,,,,,,,,, -90700,2.1588268,2.3240745,,,,,,,,,,,,,, -90800,1.7394332,4.249901,,,,,,,,,,,,,, -90900,1.8239602,4.41004,,,,,,,,,,,,,, -91000,2.1805196,2.3910635,,,,,,,,,,,,,, -91032,,,0.6882616877555847,1.284019112586975,0.637939989566803,1.5183864831924438,50000.0,0.5189000368118286,2.1651880741119385,10000.0,41630.5404984951,45394.16607880592,41630.5404984951,3755.1138138771057,3.793323993682861,0.0 -91100,1.8006622,3.0598679,,,,,,,,,,,,,, -91200,2.4512098,2.2626927,,,,,,,,,,,,,, -91300,2.09596,2.8655202,,,,,,,,,,,,,, -91400,1.9012926,2.2783508,,,,,,,,,,,,,, -91500,1.8955688,4.900774,,,,,,,,,,,,,, -91600,2.1141117,2.1903496,,,,,,,,,,,,,, -91700,2.2333875,2.4644604,,,,,,,,,,,,,, -91800,2.5403748,2.4152052,,,,,,,,,,,,,, -91900,2.1802897,2.1230347,,,,,,,,,,,,,, -91951,,,0.6794726252555847,1.3359767198562622,0.6306799650192261,1.5554149150848389,50000.0,0.5070000290870667,2.2034645080566406,10000.0,42050.500351428986,45848.689403772354,42050.500351428986,3789.587110757828,3.8358139991760254,0.0 -92000,2.1786613,2.4736323,,,,,,,,,,,,,, -92100,1.9715798,4.730377,,,,,,,,,,,,,, -92200,2.211949,2.7541013,,,,,,,,,,,,,, -92300,1.7790289,4.811392,,,,,,,,,,,,,, -92400,1.6709894,3.642861,,,,,,,,,,,,,, -92500,2.362135,2.266546,,,,,,,,,,,,,, -92600,1.8608942,2.875725,,,,,,,,,,,,,, -92700,2.1281173,2.3372617,,,,,,,,,,,,,, -92800,1.8681532,2.237748,,,,,,,,,,,,,, -92871,,,0.6898437142372131,1.2812461853027344,0.6362199783325195,1.5131189823150637,50000.0,0.5109000205993652,2.15106463432312,10000.0,42470.74870181084,46305.96616792679,42470.74870181084,3826.5279400348654,3.875086069107056,0.0 -92900,1.7748513,4.9805775,,,,,,,,,,,,,, -93000,2.1083903,2.5391304,,,,,,,,,,,,,, -93100,1.6844485,3.9213865,,,,,,,,,,,,,, -93200,2.3982573,2.219567,,,,,,,,,,,,,, -93300,1.71424,4.263421,,,,,,,,,,,,,, -93400,2.1341145,2.4186127,,,,,,,,,,,,,, -93500,2.0253756,2.4337125,,,,,,,,,,,,,, -93600,2.0463932,3.4261017,,,,,,,,,,,,,, -93700,2.081626,2.2761,,,,,,,,,,,,,, -93790,,,0.7091601490974426,1.1998614072799685,0.6394400000572205,1.507826566696167,50000.0,0.5225000381469727,2.150918960571289,10000.0,42890.90657186508,46765.64679956436,42890.90657186508,3865.9597566127777,3.9189999103546143,0.0 -93800,1.8332845,4.9407086,,,,,,,,,,,,,, -93900,2.1786013,3.7150574,,,,,,,,,,,,,, -94000,1.9546511,2.7805471,,,,,,,,,,,,,, -94100,2.2767186,2.2452083,,,,,,,,,,,,,, -94200,2.2731242,2.4137635,,,,,,,,,,,,,, -94300,1.9162016,4.373964,,,,,,,,,,,,,, -94400,2.1222823,2.1677983,,,,,,,,,,,,,, -94500,2.2052636,2.634018,,,,,,,,,,,,,, -94600,2.0807288,2.257539,,,,,,,,,,,,,, -94700,1.9051433,2.789736,,,,,,,,,,,,,, -94711,,,0.6896093487739563,1.2857972383499146,0.6369999647140503,1.5179483890533447,50000.0,0.5175999999046326,2.1582961082458496,10000.0,43310.93693423271,47221.96811914444,43310.93693423271,3902.16099691391,3.960517644882202,0.0 -94800,1.9030845,4.542184,,,,,,,,,,,,,, -94900,2.1151366,2.3199625,,,,,,,,,,,,,, -95000,2.0645688,2.2683063,,,,,,,,,,,,,, -95100,2.0966887,2.2052941,,,,,,,,,,,,,, -95200,2.4195585,2.1248095,,,,,,,,,,,,,, -95300,1.799548,3.240608,,,,,,,,,,,,,, -95400,1.8531289,4.83302,,,,,,,,,,,,,, -95500,1.9141315,3.942153,,,,,,,,,,,,,, -95600,2.1310039,2.5686722,,,,,,,,,,,,,, -95632,,,0.6902929544448853,1.3088555335998535,0.6377800107002258,1.5345081090927124,50000.0,0.5159000158309937,2.168625831604004,10000.0,43731.27998661995,47682.84707713127,43731.27998661995,3942.607335329056,4.002686023712158,0.0 -95700,2.2894032,2.2080386,,,,,,,,,,,,,, -95800,2.0525832,2.756347,,,,,,,,,,,,,, -95900,2.1620314,2.2695625,,,,,,,,,,,,,, -96000,1.867865,3.2700088,,,,,,,,,,,,,, -96100,2.0022678,4.898067,,,,,,,,,,,,,, -96200,2.1160834,2.2337732,,,,,,,,,,,,,, -96300,2.3575408,2.121888,,,,,,,,,,,,,, -96400,2.1667397,2.2926745,,,,,,,,,,,,,, -96500,1.930091,3.1981094,,,,,,,,,,,,,, -96552,,,0.7011132836341858,1.2075631618499756,0.6449399590492249,1.4762593507766724,50000.0,0.5157999992370605,2.130087614059448,10000.0,44151.53822731972,48142.17348623276,44151.53822731972,3981.583114147186,4.047634124755859,0.0 -96600,1.9605002,2.7079306,,,,,,,,,,,,,, -96700,2.5885592,2.2414756,,,,,,,,,,,,,, -96800,2.4086611,2.3355553,,,,,,,,,,,,,, -96900,2.1741073,4.8287096,,,,,,,,,,,,,, -97000,2.2715893,2.245955,,,,,,,,,,,,,, -97100,1.91395,4.759693,,,,,,,,,,,,,, -97200,1.9347451,4.7755604,,,,,,,,,,,,,, -97300,2.497991,2.207375,,,,,,,,,,,,,, -97400,2.076601,4.2796535,,,,,,,,,,,,,, -97473,,,0.6884570121765137,1.2699761390686035,0.6421399712562561,1.4902182817459106,50000.0,0.5208000540733337,2.130515813827514,10000.0,44571.59751033783,48599.9216735363,44571.59751033783,4019.17895770073,4.093117952346802,0.0 -97500,2.1546416,2.2617579,,,,,,,,,,,,,, -97600,2.6298828,2.299581,,,,,,,,,,,,,, -97700,2.0001554,3.7740295,,,,,,,,,,,,,, -97800,2.1743488,2.3089252,,,,,,,,,,,,,, -97900,2.2610059,2.5540104,,,,,,,,,,,,,, -98000,1.8647913,3.9221976,,,,,,,,,,,,,, -98100,1.9391673,3.1480448,,,,,,,,,,,,,, -98200,2.3730745,2.1763363,,,,,,,,,,,,,, -98300,2.3913972,4.0867176,,,,,,,,,,,,,, -98393,,,0.691699206829071,1.2806599140167236,0.6419199705123901,1.4992659091949463,50000.0,0.5187000036239624,2.157788753509521,10000.0,44991.82571578026,49061.43175268173,44991.82571578026,4060.374292612076,4.132027626037598,0.0 -98400,1.9392715,4.2874823,,,,,,,,,,,,,, -98500,3.2848442,2.2552094,,,,,,,,,,,,,, -98600,1.886179,4.59556,,,,,,,,,,,,,, -98700,2.2032647,2.1871476,,,,,,,,,,,,,, -98800,2.1905918,2.3194005,,,,,,,,,,,,,, -98900,2.2517776,2.1115491,,,,,,,,,,,,,, -99000,2.1794631,2.1849518,,,,,,,,,,,,,, -99100,1.879698,3.6904283,,,,,,,,,,,,,, -99200,2.267683,2.2303076,,,,,,,,,,,,,, -99300,2.0718434,2.2499907,,,,,,,,,,,,,, -99312,,,0.7009570002555847,1.2121543884277344,0.6444999575614929,1.4744832515716553,50000.0,0.5258000493049622,2.101590633392334,10000.0,45411.43276429176,49521.42532157898,45411.43276429176,4100.273753166199,4.570847034454346,0.0 -99400,2.3771813,2.1472025,,,,,,,,,,,,,, -99500,2.1264467,2.1702063,,,,,,,,,,,,,, -99600,2.3995354,2.289348,,,,,,,,,,,,,, -99700,1.9692385,2.3294148,,,,,,,,,,,,,, -99800,2.0108142,2.3108492,,,,,,,,,,,,,, -99900,2.5427806,2.1908422,,,,,,,,,,,,,, -100000,1.8044722,3.7264454,,,,,,,,,,,,,, -100100,2.6127498,2.3708386,,,,,,,,,,,,,, -100200,2.1778076,2.2426298,,,,,,,,,,,,,, -100232,,,0.6933984160423279,1.2418538331985474,0.6492399573326111,1.4507449865341189,50000.0,0.5182000398635864,2.115975856781006,10000.0,45831.65687251091,49977.31903076172,45831.65687251091,4135.854706764221,4.6116437911987305,0.0 -100300,2.4853437,2.2435312,,,,,,,,,,,,,, -100400,2.0212734,4.7044315,,,,,,,,,,,,,, -100500,2.1052215,2.8204708,,,,,,,,,,,,,, -100600,2.3647187,2.3933897,,,,,,,,,,,,,, -100700,2.2930288,2.1368032,,,,,,,,,,,,,, -100800,2.168611,2.2119522,,,,,,,,,,,,,, -100900,1.9012744,3.5369115,,,,,,,,,,,,,, -101000,2.4168143,2.1898522,,,,,,,,,,,,,, -101100,2.2465649,2.2691395,,,,,,,,,,,,,, -101152,,,0.6918163895606995,1.2562538385391235,0.6451799869537354,1.4713010787963867,50000.0,0.5205000042915344,2.1192281246185303,10000.0,46251.6263062954,50434.65885519981,46251.6263062954,4173.125863075256,4.663263559341431,0.0 -101200,2.3896856,2.2656054,,,,,,,,,,,,,, -101300,2.3643968,2.219659,,,,,,,,,,,,,, -101400,2.25435,2.2981234,,,,,,,,,,,,,, -101500,2.745525,2.210405,,,,,,,,,,,,,, -101600,2.2181356,2.3650584,,,,,,,,,,,,,, -101700,2.2433572,2.4350002,,,,,,,,,,,,,, -101800,2.029123,3.311939,,,,,,,,,,,,,, -101900,2.307567,2.7179956,,,,,,,,,,,,,, -102000,2.2233946,3.4276533,,,,,,,,,,,,,, -102069,,,0.7005273103713989,1.223625421524048,0.64656001329422,1.4615575075149536,50000.0,0.5216000080108643,2.0956971645355225,10000.0,46672.01550674439,50894.275889635086,46672.01550674439,4212.254663944244,4.7148377895355225,0.0 -102100,2.313139,2.3389554,,,,,,,,,,,,,, -102200,2.5526106,2.2746813,,,,,,,,,,,,,, -102300,2.451653,2.596652,,,,,,,,,,,,,, -102400,2.2859714,2.2098105,,,,,,,,,,,,,, -102500,2.415705,2.0907893,,,,,,,,,,,,,, -102600,2.184945,4.3577876,,,,,,,,,,,,,, -102700,2.2449296,2.110386,,,,,,,,,,,,,, -102800,2.4081912,2.1921148,,,,,,,,,,,,,, -102900,2.4070046,2.2322416,,,,,,,,,,,,,, -102986,,,0.7196484208106995,1.1426352262496948,0.6477000117301941,1.4675030708312988,50000.0,0.5217000246047974,2.1090574264526367,10000.0,47092.12380337715,51354.57412791252,47092.12380337715,4252.351578950882,4.760955810546875,0.0 -103000,1.994944,4.0144315,,,,,,,,,,,,,, -103100,2.5386786,2.1933196,,,,,,,,,,,,,, -103200,1.9163994,3.9542983,,,,,,,,,,,,,, -103300,2.228879,2.0391583,,,,,,,,,,,,,, -103400,2.2538035,4.8830833,,,,,,,,,,,,,, -103500,1.9630203,2.9075944,,,,,,,,,,,,,, -103600,2.335645,2.145823,,,,,,,,,,,,,, -103700,2.2265565,2.1942446,,,,,,,,,,,,,, -103800,2.2535455,4.120908,,,,,,,,,,,,,, -103900,2.1766849,2.2665908,,,,,,,,,,,,,, -103907,,,0.7007421851158142,1.230646014213562,0.6519399881362915,1.450840711593628,50000.0,0.5212000012397766,2.102874994277954,10000.0,47512.29468727112,51815.11089801788,47512.29468727112,4292.62885594368,4.802370309829712,0.0 -104000,2.245571,2.1200778,,,,,,,,,,,,,, -104100,2.8757887,2.311378,,,,,,,,,,,,,, -104200,2.4105012,2.1543517,,,,,,,,,,,,,, -104300,2.4436817,2.2541242,,,,,,,,,,,,,, -104400,2.437496,2.1824484,,,,,,,,,,,,,, -104500,2.5458894,2.094551,,,,,,,,,,,,,, -104600,2.2985198,2.119265,,,,,,,,,,,,,, -104700,2.3550549,2.1314282,,,,,,,,,,,,,, -104800,2.298614,4.6657257,,,,,,,,,,,,,, -104826,,,0.7058984041213989,1.1845290660858154,0.6528800129890442,1.4384924173355105,50000.0,0.5300000309944153,2.06282377243042,10000.0,47932.29152727127,52271.9458630085,47932.29152727127,4329.373802185059,4.848757028579712,0.0 -104900,2.7662005,2.1753898,,,,,,,,,,,,,, -105000,1.9969854,4.764709,,,,,,,,,,,,,, -105100,2.265139,2.200269,,,,,,,,,,,,,, -105200,2.381873,2.1857238,,,,,,,,,,,,,, -105300,2.2778263,2.4758728,,,,,,,,,,,,,, -105400,2.0188015,3.259389,,,,,,,,,,,,,, -105500,2.2381108,2.232237,,,,,,,,,,,,,, -105600,2.2017686,2.6175694,,,,,,,,,,,,,, -105700,2.2778282,2.1542408,,,,,,,,,,,,,, -105745,,,0.71728515625,1.1596193313598633,0.6541799902915955,1.4440714120864868,50000.0,0.528700053691864,2.088734865188598,10000.0,48352.20537209511,52731.31432533264,48352.20537209511,4368.735318899155,4.893977880477905,0.0 -105800,2.588124,3.689745,,,,,,,,,,,,,, -105900,2.284672,2.7453122,,,,,,,,,,,,,, -106000,2.7253559,2.2725766,,,,,,,,,,,,,, -106100,2.4021282,2.2056348,,,,,,,,,,,,,, -106200,2.1867223,4.6080236,,,,,,,,,,,,,, -106300,2.3005676,2.1703324,,,,,,,,,,,,,, -106400,2.5254512,2.3743765,,,,,,,,,,,,,, -106500,2.7623177,2.2494497,,,,,,,,,,,,,, -106600,2.1365638,4.130002,,,,,,,,,,,,,, -106667,,,0.7004296779632568,1.221923828125,0.6500799655914307,1.4510246515274048,50000.0,0.5254999995231628,2.090418100357056,10000.0,48772.65300059319,53188.10340118408,48772.65300059319,4404.980627298355,4.941775798797607,0.0 -106700,2.2021472,4.5662756,,,,,,,,,,,,,, -106800,2.3213756,2.4515054,,,,,,,,,,,,,, -106900,2.2977867,3.6367006,,,,,,,,,,,,,, -107000,2.6610525,2.1000571,,,,,,,,,,,,,, -107100,2.4991817,2.2908704,,,,,,,,,,,,,, -107200,2.2536,2.1946127,,,,,,,,,,,,,, -107300,2.1028452,4.008024,,,,,,,,,,,,,, -107400,2.8145735,2.2486343,,,,,,,,,,,,,, -107500,2.092504,3.9123275,,,,,,,,,,,,,, -107589,,,0.7095702886581421,1.1786792278289795,0.655239999294281,1.4211721420288086,50000.0,0.5357000231742859,2.077857255935669,10000.0,49192.90219426155,53643.56610870361,49192.90219426155,4440.104451656342,4.982898235321045,0.0 -107600,2.322643,4.670494,,,,,,,,,,,,,, -107700,2.1519833,2.5092216,,,,,,,,,,,,,, -107800,2.6528146,2.2575274,,,,,,,,,,,,,, -107900,2.447266,2.1567984,,,,,,,,,,,,,, -108000,2.3392084,2.2028456,,,,,,,,,,,,,, -108100,2.1058872,3.4126632,,,,,,,,,,,,,, -108200,2.450504,4.895567,,,,,,,,,,,,,, -108300,2.6725914,2.439042,,,,,,,,,,,,,, -108400,2.0876782,4.411614,,,,,,,,,,,,,, -108500,2.1470137,3.44613,,,,,,,,,,,,,, -108510,,,0.7232226133346558,1.1331206560134888,0.6595199704170227,1.4153560400009155,50000.0,0.5354000329971313,2.0454037189483643,10000.0,49612.81418466568,54101.761887550354,49612.81418466568,4478.299695491791,5.023855924606323,0.0 -108600,2.4549038,2.120796,,,,,,,,,,,,,, -108700,2.1734374,4.76168,,,,,,,,,,,,,, -108800,2.6446161,2.1691115,,,,,,,,,,,,,, -108900,2.5954275,2.493652,,,,,,,,,,,,,, -109000,2.3422256,2.8706014,,,,,,,,,,,,,, -109100,2.6354883,2.0509212,,,,,,,,,,,,,, -109200,2.722587,2.4059625,,,,,,,,,,,,,, -109300,2.4702032,2.1818697,,,,,,,,,,,,,, -109400,2.2475212,2.39138,,,,,,,,,,,,,, -109430,,,0.7046093344688416,1.207485914230347,0.6543200016021729,1.4327102899551392,50000.0,0.5306000113487244,2.0763421058654785,10000.0,50032.897715091705,54561.05555176735,50032.897715091705,4517.4173810482025,5.06796669960022,0.0 -109500,2.2478898,2.46838,,,,,,,,,,,,,, -109600,2.2428842,2.6169407,,,,,,,,,,,,,, -109700,2.1226382,3.649757,,,,,,,,,,,,,, -109800,2.2350254,4.7710624,,,,,,,,,,,,,, -109900,2.466256,2.0189457,,,,,,,,,,,,,, -110000,2.477068,2.058637,,,,,,,,,,,,,, -110100,2.634783,2.2067418,,,,,,,,,,,,,, -110200,2.4384348,4.7128425,,,,,,,,,,,,,, -110300,2.637072,2.3171713,,,,,,,,,,,,,, -110350,,,0.711621105670929,1.1574349403381348,0.6602199673652649,1.392660140991211,50000.0,0.5335000157356262,2.0360000133514404,10000.0,50452.97357225418,55019.047716379166,50452.97357225418,4555.2443215847015,5.109800100326538,0.0 -110400,2.523724,2.514956,,,,,,,,,,,,,, -110500,2.1367202,3.0785184,,,,,,,,,,,,,, -110600,2.370452,4.1494665,,,,,,,,,,,,,, -110700,2.286318,2.155412,,,,,,,,,,,,,, -110800,2.2727432,2.3089879,,,,,,,,,,,,,, -110900,2.238761,2.7827072,,,,,,,,,,,,,, -111000,2.3248067,2.459611,,,,,,,,,,,,,, -111100,2.555235,2.0586123,,,,,,,,,,,,,, -111200,2.7122223,2.2945974,,,,,,,,,,,,,, -111270,,,0.7218554615974426,1.139596462249756,0.6591399908065796,1.4095954895019531,50000.0,0.5386000275611877,2.0380427837371826,10000.0,50873.10785269737,55478.7564125061,50873.10785269737,4594.720537662506,5.160179615020752,0.0 -111300,2.533899,2.5702736,,,,,,,,,,,,,, -111400,2.4521651,2.6647675,,,,,,,,,,,,,, -111500,2.3336964,2.072886,,,,,,,,,,,,,, -111600,2.5145023,2.1595712,,,,,,,,,,,,,, -111700,2.3616102,2.0813208,,,,,,,,,,,,,, -111800,2.2801907,4.7969184,,,,,,,,,,,,,, -111900,2.4348419,2.914906,,,,,,,,,,,,,, -112000,2.2983942,4.498448,,,,,,,,,,,,,, -112100,2.755703,2.1872735,,,,,,,,,,,,,, -112189,,,0.72132807970047,1.1169127225875854,0.6618199944496155,1.3724278211593628,50000.0,0.5388000011444092,2.0063302516937256,10000.0,51293.37358379364,55939.25805449486,51293.37358379364,4634.864460945129,5.204523324966431,0.0 -112200,2.6301186,2.0814862,,,,,,,,,,,,,, -112300,2.713248,2.0946603,,,,,,,,,,,,,, -112400,2.5620759,2.2156107,,,,,,,,,,,,,, -112500,2.5502799,2.163432,,,,,,,,,,,,,, -112600,2.2783027,2.4769554,,,,,,,,,,,,,, -112700,2.4214258,2.1727443,,,,,,,,,,,,,, -112800,2.6743002,1.9707028,,,,,,,,,,,,,, -112900,2.5071383,2.2294898,,,,,,,,,,,,,, -113000,2.5857139,2.115809,,,,,,,,,,,,,, -113100,2.409159,2.9533675,,,,,,,,,,,,,, -113109,,,0.7181445360183716,1.1483697891235352,0.6665199995040894,1.376614332199097,50000.0,0.5427000522613525,2.011446714401245,10000.0,51713.63455915451,56399.57521724701,51713.63455915451,4674.825638771057,5.251617193222046,0.0 -113200,2.7561655,2.1414342,,,,,,,,,,,,,, -113300,2.0388024,3.868299,,,,,,,,,,,,,, -113400,2.1341403,3.4669793,,,,,,,,,,,,,, -113500,2.799961,2.9034338,,,,,,,,,,,,,, -113600,2.8701632,2.0836737,,,,,,,,,,,,,, -113700,2.3665593,2.1175687,,,,,,,,,,,,,, -113800,2.289343,2.7315295,,,,,,,,,,,,,, -113900,2.2803476,3.649431,,,,,,,,,,,,,, -114000,2.6691878,2.3268251,,,,,,,,,,,,,, -114029,,,0.7268164157867432,1.0959932804107666,0.6672999858856201,1.353961706161499,50000.0,0.5432000160217285,2.00452733039856,10000.0,52133.7013399601,56860.199070215225,52133.7013399601,4715.29369020462,5.293383836746216,0.0 -114100,2.5065672,2.139392,,,,,,,,,,,,,, -114200,2.5832036,2.0697536,,,,,,,,,,,,,, -114300,2.4785364,1.9800272,,,,,,,,,,,,,, -114400,2.3020692,3.344236,,,,,,,,,,,,,, -114500,2.33517,3.0285811,,,,,,,,,,,,,, -114600,2.3806512,2.1442971,,,,,,,,,,,,,, -114700,2.782814,2.1217036,,,,,,,,,,,,,, -114800,2.6291933,2.0695107,,,,,,,,,,,,,, -114900,2.6845973,2.140511,,,,,,,,,,,,,, -114950,,,0.7395898103713989,1.0591590404510498,0.6657999753952026,1.3845274448394775,50000.0,0.5387000441551208,2.023356914520264,10000.0,52553.64809894562,57317.40486860275,52553.64809894562,4752.462374687195,5.3358001708984375,0.0 -115000,2.4353588,2.2182295,,,,,,,,,,,,,, -115100,2.4156165,2.136066,,,,,,,,,,,,,, -115200,2.926504,2.0633366,,,,,,,,,,,,,, -115300,2.7435625,2.2660182,,,,,,,,,,,,,, -115400,2.4890416,2.0079339,,,,,,,,,,,,,, -115500,2.4389184,4.0098,,,,,,,,,,,,,, -115600,2.7266784,2.1691225,,,,,,,,,,,,,, -115700,2.392992,2.578036,,,,,,,,,,,,,, -115800,2.6509333,2.0391488,,,,,,,,,,,,,, -115872,,,0.72279292345047,1.1121561527252195,0.670799970626831,1.3483924865722656,50000.0,0.5461000204086304,1.9749292135238647,10000.0,52973.94215083122,57775.563047885895,52973.94215083122,4790.233314990997,5.381194114685059,0.0 -115900,2.4874296,2.428061,,,,,,,,,,,,,, -116000,2.402941,3.926579,,,,,,,,,,,,,, -116100,2.3204455,2.3368192,,,,,,,,,,,,,, -116200,2.333505,3.4122622,,,,,,,,,,,,,, -116300,2.3568635,2.6524944,,,,,,,,,,,,,, -116400,2.194473,4.3008046,,,,,,,,,,,,,, -116500,2.6878803,2.090739,,,,,,,,,,,,,, -116600,2.531694,2.2071443,,,,,,,,,,,,,, -116700,2.2726963,3.302568,,,,,,,,,,,,,, -116791,,,0.7229296565055847,1.1437608003616333,0.6652799844741821,1.4038317203521729,50000.0,0.542900025844574,2.02569580078125,10000.0,53393.94225502014,58233.54987645149,53393.94225502014,4828.128947257996,5.423955678939819,0.0 -116800,2.5707808,2.1032324,,,,,,,,,,,,,, -116900,2.6897407,4.172255,,,,,,,,,,,,,, -117000,2.6285686,2.1058881,,,,,,,,,,,,,, -117100,2.175027,3.6082737,,,,,,,,,,,,,, -117200,2.5535908,2.2234857,,,,,,,,,,,,,, -117300,2.6121097,2.0308142,,,,,,,,,,,,,, -117400,2.589274,2.057856,,,,,,,,,,,,,, -117500,2.5068192,1.8923311,,,,,,,,,,,,,, -117600,2.6480997,1.9300046,,,,,,,,,,,,,, -117700,2.292973,3.0227833,,,,,,,,,,,,,, -117710,,,0.741406261920929,1.021639108657837,0.6748600006103516,1.3188023567199707,50000.0,0.5506000518798828,1.951604962348938,10000.0,53813.971108198166,58690.75763726234,53813.971108198166,4865.206468343735,5.476790428161621,0.0 -117800,2.846429,4.33278,,,,,,,,,,,,,, -117900,2.4238029,3.6444736,,,,,,,,,,,,,, -118000,2.371453,3.4044106,,,,,,,,,,,,,, -118100,2.3445992,4.5686984,,,,,,,,,,,,,, -118200,2.4538548,4.5027347,,,,,,,,,,,,,, -118300,2.5409567,2.1780133,,,,,,,,,,,,,, -118400,2.6515875,4.0654845,,,,,,,,,,,,,, -118500,2.2382255,3.8310673,,,,,,,,,,,,,, -118600,2.5179822,3.064198,,,,,,,,,,,,,, -118631,,,0.7268164157867432,1.1247150897979736,0.6675800085067749,1.379623532295227,50000.0,0.5515000224113464,1.998087406158448,10000.0,54233.92829823494,59147.98989415169,54233.92829823494,4902.387039661408,5.5232415199279785,0.0 -118700,2.702497,3.26114,,,,,,,,,,,,,, -118800,2.5024204,2.1264384,,,,,,,,,,,,,, -118900,2.506029,3.5829003,,,,,,,,,,,,,, -119000,2.5001593,4.250554,,,,,,,,,,,,,, -119100,2.855113,2.069953,,,,,,,,,,,,,, -119200,2.5998294,1.9860191,,,,,,,,,,,,,, -119300,3.074425,4.1455593,,,,,,,,,,,,,, -119400,2.7648418,2.0049434,,,,,,,,,,,,,, -119500,2.758645,1.9814622,,,,,,,,,,,,,, -119551,,,0.7312304377555847,1.0820475816726685,0.6734399795532227,1.340196967124939,50000.0,0.5533000230789185,1.9669947624206543,10000.0,54653.83897304535,59607.46091222763,54653.83897304535,4941.857255935669,5.566040754318237,0.0 -119600,2.406598,2.384354,,,,,,,,,,,,,, -119700,2.665456,2.5019622,,,,,,,,,,,,,, -119800,2.4651933,4.2879934,,,,,,,,,,,,,, -119900,2.750119,4.535949,,,,,,,,,,,,,, -120000,2.5009394,2.133819,,,,,,,,,,,,,, -120100,2.3567224,4.3850155,,,,,,,,,,,,,, -120200,2.4152353,2.1982956,,,,,,,,,,,,,, -120300,2.4843695,3.8323872,,,,,,,,,,,,,, -120400,2.6102133,1.9904581,,,,,,,,,,,,,, -120472,,,0.7416015267372131,1.0313720703125,0.6787199974060059,1.3119388818740845,50000.0,0.5540000200271606,1.9518202543258667,10000.0,55074.14640974999,60065.947801828384,55074.14640974999,4979.9430141448975,5.610635757446289,0.0 -120500,2.7910137,1.9883808,,,,,,,,,,,,,, -120600,2.5140576,3.9880035,,,,,,,,,,,,,, -120700,2.5108953,2.8261666,,,,,,,,,,,,,, -120800,2.4497082,4.2125425,,,,,,,,,,,,,, -120900,2.5423183,2.1624024,,,,,,,,,,,,,, -121000,2.7787168,2.0156457,,,,,,,,,,,,,, -121100,2.2762494,2.545887,,,,,,,,,,,,,, -121200,2.4054196,3.2197943,,,,,,,,,,,,,, -121300,2.4568076,2.2183,,,,,,,,,,,,,, -121393,,,0.7378906011581421,1.0619568824768066,0.6817599534988403,1.3020442724227903,50000.0,0.5552000403404236,1.934516668319702,10000.0,55494.27367591858,60526.06591033936,55494.27367591858,5019.839179754257,5.657719373703003,0.0 -121400,2.4561102,2.3668756,,,,,,,,,,,,,, -121500,2.5040433,2.4542117,,,,,,,,,,,,,, -121600,2.2673388,3.8543735,,,,,,,,,,,,,, -121700,2.511721,3.3720453,,,,,,,,,,,,,, -121800,2.6485715,2.1817205,,,,,,,,,,,,,, -121900,2.6363263,2.3597143,,,,,,,,,,,,,, -122000,2.8809705,2.0841768,,,,,,,,,,,,,, -122100,2.8248122,4.083464,,,,,,,,,,,,,, -122200,2.9574583,1.9654994,,,,,,,,,,,,,, -122300,3.0177948,2.046419,,,,,,,,,,,,,, -122313,,,0.73011714220047,1.0760235786437988,0.6768199801445007,1.3135195970535278,50000.0,0.54830002784729,1.960036039352417,10000.0,55914.204786777496,60984.79734659195,55914.204786777496,5058.548875808716,5.700714349746704,0.0 -122400,3.2765727,1.9247775,,,,,,,,,,,,,, -122500,2.8075807,4.3536015,,,,,,,,,,,,,, -122600,2.4483266,2.8948073,,,,,,,,,,,,,, -122700,2.5375109,2.990962,,,,,,,,,,,,,, -122800,2.8701324,4.4822464,,,,,,,,,,,,,, -122900,2.4601223,3.829895,,,,,,,,,,,,,, -123000,2.7290978,4.4402246,,,,,,,,,,,,,, -123100,2.9226668,2.0944104,,,,,,,,,,,,,, -123200,2.4868789,2.2950118,,,,,,,,,,,,,, -123234,,,0.7426952719688416,1.0379911661148071,0.6814799904823303,1.3065279722213743,50000.0,0.5596000552177429,1.9329463243484497,10000.0,56334.31851649284,61441.13616466522,56334.31851649284,5094.679251432419,5.7470362186431885,0.0 -123300,2.7673748,2.045809,,,,,,,,,,,,,, -123400,2.7238362,3.517574,,,,,,,,,,,,,, -123500,2.6877432,3.0249062,,,,,,,,,,,,,, -123600,3.262599,1.8860092,,,,,,,,,,,,,, -123700,2.4835198,4.376871,,,,,,,,,,,,,, -123800,2.6507914,2.1402383,,,,,,,,,,,,,, -123900,3.0552478,2.6021001,,,,,,,,,,,,,, -124000,3.0925562,1.9370856,,,,,,,,,,,,,, -124100,2.7660294,4.5977893,,,,,,,,,,,,,, -124154,,,0.7546288967132568,0.995815634727478,0.6815999746322632,1.3149043321609497,50000.0,0.562000036239624,1.938615322113037,10000.0,56754.34809017181,61900.59717607498,56754.34809017181,5134.01913523674,5.790019989013672,0.0 -124200,2.6614652,1.8645558,,,,,,,,,,,,,, -124300,2.8731084,4.5772476,,,,,,,,,,,,,, -124400,2.9110405,2.1906104,,,,,,,,,,,,,, -124500,2.9051914,3.034302,,,,,,,,,,,,,, -124600,3.0011673,2.0142555,,,,,,,,,,,,,, -124700,2.6240203,3.2421145,,,,,,,,,,,,,, -124800,2.6932397,3.5583344,,,,,,,,,,,,,, -124900,2.8023784,4.562994,,,,,,,,,,,,,, -125000,2.5605845,3.3552597,,,,,,,,,,,,,, -125075,,,0.7390820384025574,1.068745493888855,0.6852999925613403,1.3048429489135742,50000.0,0.5525000095367432,1.9468294382095337,10000.0,57174.4171538353,62361.01248407364,57174.4171538353,5174.269753456116,5.836784839630127,0.0 -125100,2.805966,4.0536437,,,,,,,,,,,,,, -125200,2.7042112,2.7939763,,,,,,,,,,,,,, -125300,2.8893538,1.9501617,,,,,,,,,,,,,, -125400,2.580264,2.3623705,,,,,,,,,,,,,, -125500,2.6260502,3.0595412,,,,,,,,,,,,,, -125600,2.8766208,1.9438365,,,,,,,,,,,,,, -125700,2.897307,1.8246806,,,,,,,,,,,,,, -125800,2.716249,3.1532054,,,,,,,,,,,,,, -125900,2.9473662,2.1323044,,,,,,,,,,,,,, -125998,,,0.7458398342132568,1.0124694108963013,0.6858400106430054,1.2887572050094604,50000.0,0.5623000264167786,1.9115575551986688,10000.0,57594.61695051193,62816.308730363846,57594.61695051193,5209.269439458847,5.885339975357056,0.0 -126000,2.8795376,1.9324988,,,,,,,,,,,,,, -126100,3.0051618,1.8186642,,,,,,,,,,,,,, -126200,2.7124372,2.355867,,,,,,,,,,,,,, -126300,3.2445102,1.9847243,,,,,,,,,,,,,, -126400,2.8430448,1.9924726,,,,,,,,,,,,,, -126500,3.0799172,1.9457061,,,,,,,,,,,,,, -126600,2.9847043,1.9924122,,,,,,,,,,,,,, -126700,2.7113335,2.5778449,,,,,,,,,,,,,, -126800,3.0348184,1.9722216,,,,,,,,,,,,,, -126900,3.2593992,1.966713,,,,,,,,,,,,,, -126918,,,0.7591796517372131,0.9492820501327516,0.6929000020027161,1.2491334676742554,50000.0,0.5681000351905823,1.88283109664917,10000.0,58014.5237903595,63274.89465546608,58014.5237903595,5247.848149061203,5.937408208847046,0.0 -127000,2.8282642,2.1992316,,,,,,,,,,,,,, -127100,2.701853,2.0776336,,,,,,,,,,,,,, -127200,2.559233,3.5474105,,,,,,,,,,,,,, -127300,2.7615736,2.7188761,,,,,,,,,,,,,, -127400,3.1344783,1.9219731,,,,,,,,,,,,,, -127500,2.8787544,1.9492192,,,,,,,,,,,,,, -127600,3.387017,2.0705884,,,,,,,,,,,,,, -127700,2.8007355,1.9028237,,,,,,,,,,,,,, -127800,2.9705265,2.1523294,,,,,,,,,,,,,, -127838,,,0.7441796660423279,1.0422630310058594,0.6882599592208862,1.2880897521972656,50000.0,0.5680000185966492,1.907129049301148,10000.0,58434.72530388832,63734.5433690548,58434.72530388832,5287.201727390289,5.98337721824646,0.0 -127900,2.865694,4.1382565,,,,,,,,,,,,,, -128000,3.0581074,1.8507792,,,,,,,,,,,,,, -128100,2.8493376,1.8901703,,,,,,,,,,,,,, -128200,2.7492697,2.3236985,,,,,,,,,,,,,, -128300,2.5988326,3.5601745,,,,,,,,,,,,,, -128400,3.328558,4.457538,,,,,,,,,,,,,, -128500,2.5468607,3.0881398,,,,,,,,,,,,,, -128600,2.8042338,4.3239307,,,,,,,,,,,,,, -128700,2.7560465,2.0804012,,,,,,,,,,,,,, -128759,,,0.74853515625,1.0025569200515747,0.6908800005912781,1.2619134187698364,50000.0,0.5669000148773193,1.8810703754425049,10000.0,58854.81752896309,64194.79561638832,58854.81752896309,5327.26294875145,6.034414768218994,0.0 -128800,3.1069322,2.0978487,,,,,,,,,,,,,, -128900,3.0464323,1.9549822,,,,,,,,,,,,,, -129000,2.8317556,2.310109,,,,,,,,,,,,,, -129100,2.8329072,2.3655486,,,,,,,,,,,,,, -129200,2.8953342,1.8426135,,,,,,,,,,,,,, -129300,3.1171043,1.8712573,,,,,,,,,,,,,, -129400,3.2131567,4.4304743,,,,,,,,,,,,,, -129500,2.7571688,4.153066,,,,,,,,,,,,,, -129600,3.011016,2.002678,,,,,,,,,,,,,, -129680,,,0.7611327767372131,0.9462226629257202,0.6943999528884888,1.2438348531723022,50000.0,0.5715000033378601,1.8652161359786987,10000.0,59275.05393505096,64651.74168539047,59275.05393505096,5363.8732233047485,6.085147142410278,0.0 -129700,3.006769,1.9492317,,,,,,,,,,,,,, -129800,2.8640969,2.9538612,,,,,,,,,,,,,, -129900,3.0316112,1.9339223,,,,,,,,,,,,,, -130000,2.9958615,2.096498,,,,,,,,,,,,,, -130100,3.3875942,1.9184371,,,,,,,,,,,,,, -130200,3.0948796,2.0044465,,,,,,,,,,,,,, -130300,2.9552245,3.0947788,,,,,,,,,,,,,, -130400,3.1585612,2.6964068,,,,,,,,,,,,,, -130500,2.778228,2.2866025,,,,,,,,,,,,,, -130600,3.0298254,2.092657,,,,,,,,,,,,,, -130601,,,0.7528710961341858,0.9790438413619996,0.6981199979782104,1.2228922843933103,50000.0,0.5756000280380249,1.8345760107040403,10000.0,59695.36253666878,65109.05203604698,59695.36253666878,5400.780389070511,6.132253170013428,0.0 -130700,2.9230273,3.6909084,,,,,,,,,,,,,, -130800,2.5683494,2.6155846,,,,,,,,,,,,,, -130900,2.9969518,1.9955415,,,,,,,,,,,,,, -131000,3.274004,2.0367362,,,,,,,,,,,,,, -131100,3.256842,1.9984071,,,,,,,,,,,,,, -131200,3.5548964,4.527364,,,,,,,,,,,,,, -131300,2.9530318,2.6931212,,,,,,,,,,,,,, -131400,3.4520667,3.719262,,,,,,,,,,,,,, -131500,2.9326737,4.2078915,,,,,,,,,,,,,, -131524,,,0.7593945264816284,0.9528237581253052,0.695639967918396,1.2331786155700684,50000.0,0.5743000507354736,1.8481074571609497,10000.0,60115.69743990898,65567.70094275475,60115.69743990898,5438.998530864716,6.179803848266602,0.0 -131600,4.2092094,1.9150543,,,,,,,,,,,,,, -131700,2.9574797,1.9160702,,,,,,,,,,,,,, -131800,3.483787,4.361546,,,,,,,,,,,,,, -131900,3.0599165,4.3294687,,,,,,,,,,,,,, -132000,2.9470832,3.0976253,,,,,,,,,,,,,, -132100,3.096256,3.498319,,,,,,,,,,,,,, -132200,3.2041776,2.0482554,,,,,,,,,,,,,, -132300,3.263533,2.0455904,,,,,,,,,,,,,, -132400,2.9187565,2.8840668,,,,,,,,,,,,,, -132444,,,0.7639452815055847,0.9388325214385986,0.6972399950027466,1.2304061651229858,50000.0,0.5740000009536743,1.8517664670944207,10000.0,60535.69753456116,66027.71335840225,60535.69753456116,5478.9129366874695,6.230070352554321,0.0 -132500,3.1822357,1.9663842,,,,,,,,,,,,,, -132600,3.4205348,2.1015792,,,,,,,,,,,,,, -132700,2.7597272,2.7432663,,,,,,,,,,,,,, -132800,3.3568282,2.197287,,,,,,,,,,,,,, -132900,3.1216788,1.7922368,,,,,,,,,,,,,, -133000,3.313916,2.0660176,,,,,,,,,,,,,, -133100,3.126745,3.0421047,,,,,,,,,,,,,, -133200,2.8225963,2.5134447,,,,,,,,,,,,,, -133300,3.065102,2.7777371,,,,,,,,,,,,,, -133364,,,0.7642187476158142,0.9470322132110596,0.7019000053405762,1.2145510911941528,50000.0,0.5734000205993652,1.842565536499024,10000.0,60955.82214999199,66485.37777233124,60955.82214999199,5516.357885599136,6.277270317077637,0.0 -133400,3.0489573,1.9573963,,,,,,,,,,,,,, -133500,3.3072023,1.7489524,,,,,,,,,,,,,, -133600,3.21943,2.1955426,,,,,,,,,,,,,, -133700,3.2327263,2.063633,,,,,,,,,,,,,, -133800,3.0982862,1.8351845,,,,,,,,,,,,,, -133900,3.4057078,1.8255358,,,,,,,,,,,,,, -134000,2.9783134,2.816054,,,,,,,,,,,,,, -134100,3.009897,2.5701485,,,,,,,,,,,,,, -134200,3.2398286,1.8006073,,,,,,,,,,,,,, -134284,,,0.7623632550239563,0.9518302083015442,0.7013999819755554,1.2288814783096311,50000.0,0.5781000256538391,1.8447034358978271,10000.0,61376.02754378319,66944.4675412178,61376.02754378319,5555.144359827042,6.326719284057617,0.0 -134300,2.9061055,3.3545446,,,,,,,,,,,,,, -134400,3.5553753,3.8846877,,,,,,,,,,,,,, -134500,2.959922,3.779257,,,,,,,,,,,,,, -134600,3.020606,4.1163945,,,,,,,,,,,,,, -134700,4.1218553,4.3351345,,,,,,,,,,,,,, -134800,3.3033028,1.7947886,,,,,,,,,,,,,, -134900,2.8234985,2.32936,,,,,,,,,,,,,, -135000,3.074082,1.6998752,,,,,,,,,,,,,, -135100,3.1798298,1.8915269,,,,,,,,,,,,,, -135200,2.9940367,2.2858262,,,,,,,,,,,,,, -135201,,,0.7685546875,0.9147396087646484,0.7045199871063232,1.204154372215271,50000.0,0.5804000496864319,1.818007469177246,10000.0,61796.07824969292,67401.66700196266,61796.07824969292,5592.194238185883,6.378962516784668,0.0 -135300,3.3230777,1.7951436,,,,,,,,,,,,,, -135400,2.9860969,3.4715848,,,,,,,,,,,,,, -135500,2.9808373,2.2871006,,,,,,,,,,,,,, -135600,3.5558028,1.9301589,,,,,,,,,,,,,, -135700,2.950847,2.718856,,,,,,,,,,,,,, -135800,3.3153043,1.9103606,,,,,,,,,,,,,, -135900,3.3174808,1.7875613,,,,,,,,,,,,,, -136000,3.5068624,2.2252078,,,,,,,,,,,,,, -136100,3.2318697,2.449897,,,,,,,,,,,,,, -136122,,,0.7792773246765137,0.8840520977973938,0.7030199766159058,1.2064038515090942,50000.0,0.5815000534057617,1.8282631635665887,10000.0,62216.07617998123,67860.29677629471,62216.07617998123,5630.725626945496,6.431999683380127,0.0 -136200,3.0107095,3.7154808,,,,,,,,,,,,,, -136300,3.6620562,1.7791767,,,,,,,,,,,,,, -136400,4.0894213,3.9064333,,,,,,,,,,,,,, -136500,3.2290058,2.2021797,,,,,,,,,,,,,, -136600,3.2918074,2.3843913,,,,,,,,,,,,,, -136700,3.299811,1.8534952,,,,,,,,,,,,,, -136800,3.67738,1.8511932,,,,,,,,,,,,,, -136900,3.4007218,3.7714353,,,,,,,,,,,,,, -137000,3.5261543,1.8089021,,,,,,,,,,,,,, -137042,,,0.765625,0.9307951331138612,0.7046799659729004,1.205407738685608,50000.0,0.5777000188827515,1.8337697982788088,10000.0,62636.41758394241,68321.68812680244,62636.41758394241,5671.676826477051,6.483324766159058,0.0 -137100,3.5514567,1.6827571,,,,,,,,,,,,,, -137200,3.328375,2.4945776,,,,,,,,,,,,,, -137300,3.275561,3.394465,,,,,,,,,,,,,, -137400,3.2172608,2.0787623,,,,,,,,,,,,,, -137500,3.3533354,1.7908986,,,,,,,,,,,,,, -137600,3.288897,2.4279149,,,,,,,,,,,,,, -137700,3.6062734,4.232234,,,,,,,,,,,,,, -137800,3.4870954,1.9322886,,,,,,,,,,,,,, -137900,3.5749967,4.3498583,,,,,,,,,,,,,, -137962,,,0.7728710770606995,0.9098615050315856,0.7051399946212769,1.195683240890503,50000.0,0.5809000134468079,1.822838544845581,10000.0,63056.63118767738,68779.26647734642,63056.63118767738,5708.94930100441,6.528689861297607,0.0 -138000,3.265025,2.856932,,,,,,,,,,,,,, -138100,3.343182,2.1557255,,,,,,,,,,,,,, -138200,3.9749715,4.2302265,,,,,,,,,,,,,, -138300,3.1189725,3.0106099,,,,,,,,,,,,,, -138400,2.9920852,2.6906745,,,,,,,,,,,,,, -138500,3.40858,2.1381803,,,,,,,,,,,,,, -138600,4.3933067,2.2231426,,,,,,,,,,,,,, -138700,3.2995524,3.1801772,,,,,,,,,,,,,, -138800,3.3063045,2.142651,,,,,,,,,,,,,, -138884,,,0.7806445360183716,0.8815276026725769,0.7076799869537354,1.1955801248550415,50000.0,0.5842000246047974,1.8043678998947144,10000.0,63476.66384673119,69239.99764561653,63476.66384673119,5749.549212932587,6.578404903411865,0.0 -138900,3.2258565,3.631722,,,,,,,,,,,,,, -139000,3.4230533,1.7824246,,,,,,,,,,,,,, -139100,3.5403225,2.5218549,,,,,,,,,,,,,, -139200,3.399753,1.9307562,,,,,,,,,,,,,, -139300,3.0954733,2.0083246,,,,,,,,,,,,,, -139400,3.2959108,1.821243,,,,,,,,,,,,,, -139500,3.566231,3.9262478,,,,,,,,,,,,,, -139600,3.3810925,1.9384581,,,,,,,,,,,,,, -139700,3.5100374,1.7533739,,,,,,,,,,,,,, -139800,3.2824392,2.6296868,,,,,,,,,,,,,, -139805,,,0.771289050579071,0.9243224859237672,0.7076199650764465,1.2005985975265503,50000.0,0.5848000049591064,1.822811365127564,10000.0,63896.58831167221,69700.79360175133,63896.58831167221,5790.326649427414,6.624654054641724,0.0 -139900,3.6253269,1.9096975,,,,,,,,,,,,,, -140000,3.224442,1.7030357,,,,,,,,,,,,,, -140100,3.4807873,1.7920686,,,,,,,,,,,,,, -140200,3.669327,1.9557594,,,,,,,,,,,,,, -140300,3.4274194,1.8019878,,,,,,,,,,,,,, -140400,3.3340533,1.9060137,,,,,,,,,,,,,, -140500,3.8045115,1.887537,,,,,,,,,,,,,, -140600,3.5743277,2.5568135,,,,,,,,,,,,,, -140700,3.830754,1.742646,,,,,,,,,,,,,, -140726,,,0.777148425579071,0.8691675662994385,0.7116000056266785,1.1573071479797363,50000.0,0.5871000289916992,1.7759684324264526,10000.0,64316.83931660652,70159.80298304558,64316.83931660652,5828.982895612717,6.679133176803589,0.0 -140800,3.4285336,2.0433745,,,,,,,,,,,,,, -140900,3.5370152,2.8820906,,,,,,,,,,,,,, -141000,3.6070209,1.816385,,,,,,,,,,,,,, -141100,3.5746107,1.6288989,,,,,,,,,,,,,, -141200,3.3407247,2.3137426,,,,,,,,,,,,,, -141300,3.8848433,1.7713207,,,,,,,,,,,,,, -141400,3.6037424,3.0895908,,,,,,,,,,,,,, -141500,3.9539928,3.5798795,,,,,,,,,,,,,, -141600,3.5032592,1.8172154,,,,,,,,,,,,,, -141648,,,0.7853710651397705,0.8723379969596863,0.7133600115776062,1.185519099235535,50000.0,0.5850000381469727,1.806455373764038,10000.0,64737.20595860481,70620.941873312,64737.20595860481,5869.657928228378,6.726979732513428,0.0 -141700,3.7578018,1.8414936,,,,,,,,,,,,,, -141800,3.5436149,2.5562947,,,,,,,,,,,,,, -141900,4.0456586,1.8074585,,,,,,,,,,,,,, -142000,3.3833826,2.4420114,,,,,,,,,,,,,, -142100,3.4433613,1.8566247,,,,,,,,,,,,,, -142200,3.3658078,2.1451688,,,,,,,,,,,,,, -142300,3.8375072,1.9835643,,,,,,,,,,,,,, -142400,4.0728445,4.3166094,,,,,,,,,,,,,, -142500,3.5903468,2.7887197,,,,,,,,,,,,,, -142567,,,0.7803124785423279,0.8587034344673157,0.7161999940872192,1.133145093917847,50000.0,0.5960000157356262,1.7515387535095217,10000.0,65157.24151563645,71079.86623930931,65157.24151563645,5908.061641693115,7.164468288421631,0.0 -142600,3.1409888,2.7978075,,,,,,,,,,,,,, -142700,3.6801355,1.7206292,,,,,,,,,,,,,, -142800,4.4629054,1.790652,,,,,,,,,,,,,, -142900,3.7740662,1.7224339,,,,,,,,,,,,,, -143000,3.472392,3.153185,,,,,,,,,,,,,, -143100,4.600232,4.248193,,,,,,,,,,,,,, -143200,3.3728104,2.6922836,,,,,,,,,,,,,, -143300,3.9682624,3.7661526,,,,,,,,,,,,,, -143400,3.9866734,1.8148239,,,,,,,,,,,,,, -143486,,,0.7825976610183716,0.8492751121520996,0.7155599594116211,1.143829345703125,50000.0,0.5928000211715698,1.7593907117843628,10000.0,65577.55893421173,71537.87617897987,65577.55893421173,5945.656690597534,7.214048862457275,0.0 -143500,3.738633,2.370994,,,,,,,,,,,,,, -143600,3.6037643,2.3148963,,,,,,,,,,,,,, -143700,3.625643,2.2488902,,,,,,,,,,,,,, -143800,4.4875617,4.246709,,,,,,,,,,,,,, -143900,3.93443,3.573099,,,,,,,,,,,,,, -144000,3.8872578,1.7371283,,,,,,,,,,,,,, -144100,4.051441,4.213641,,,,,,,,,,,,,, -144200,3.811216,2.8844845,,,,,,,,,,,,,, -144300,4.2049894,1.7739308,,,,,,,,,,,,,, -144400,3.702867,2.1688945,,,,,,,,,,,,,, -144404,,,0.7892382740974426,0.8360196352005005,0.7186799645423889,1.148780107498169,50000.0,0.5980000495910645,1.7610589265823364,10000.0,65997.90977883339,71999.52923107147,65997.90977883339,5986.8613493442535,7.263937711715698,0.0 -144500,3.847027,1.7884153,,,,,,,,,,,,,, -144600,3.929344,1.7049855,,,,,,,,,,,,,, -144700,3.430389,1.8540316,,,,,,,,,,,,,, -144800,3.4049468,3.8318155,,,,,,,,,,,,,, -144900,3.4986408,1.7185445,,,,,,,,,,,,,, -145000,4.3370776,3.760407,,,,,,,,,,,,,, -145100,3.950452,3.5181334,,,,,,,,,,,,,, -145200,4.0734034,2.8641505,,,,,,,,,,,,,, -145300,4.519762,4.122857,,,,,,,,,,,,,, -145324,,,0.796191394329071,0.796824038028717,0.7212199568748474,1.1199946403503418,50000.0,0.5958000421524048,1.755782127380371,10000.0,66417.887434721,72457.2994647026,66417.887434721,6024.559996366501,7.30958104133606,0.0 -145400,3.5868244,2.071014,,,,,,,,,,,,,, -145500,3.860039,2.974577,,,,,,,,,,,,,, -145600,4.5233383,1.599649,,,,,,,,,,,,,, -145700,3.8317537,3.002934,,,,,,,,,,,,,, -145800,3.949458,2.2814746,,,,,,,,,,,,,, -145900,3.7003694,1.6781147,,,,,,,,,,,,,, -146000,4.4023347,3.6951737,,,,,,,,,,,,,, -146100,3.8287032,2.0749083,,,,,,,,,,,,,, -146200,4.1165924,1.7394152,,,,,,,,,,,,,, -146245,,,0.7898046970367432,0.825894832611084,0.7240999937057495,1.1160138845443726,50000.0,0.6011000275611877,1.7424756288528442,10000.0,66837.9214565754,72916.71459913254,66837.9214565754,6063.822338104248,7.380070686340332,0.0 -146300,4.18311,3.0785515,,,,,,,,,,,,,, -146400,4.287749,2.184649,,,,,,,,,,,,,, -146500,3.7360346,3.7477796,,,,,,,,,,,,,, -146600,4.317474,3.3010843,,,,,,,,,,,,,, -146700,3.8874571,1.9654961,,,,,,,,,,,,,, -146800,4.192549,4.033771,,,,,,,,,,,,,, -146900,3.838865,3.2375803,,,,,,,,,,,,,, -147000,4.0049076,1.8109683,,,,,,,,,,,,,, -147100,3.953466,1.7388666,,,,,,,,,,,,,, -147163,,,0.7924023270606995,0.8204847574234009,0.7229200005531311,1.114357590675354,50000.0,0.5966000556945801,1.7341212034225464,10000.0,67257.93815946579,73377.8737001419,67257.93815946579,6104.866254329681,7.430980443954468,0.0 -147200,4.1002064,1.884622,,,,,,,,,,,,,, -147300,4.0743113,1.8096702,,,,,,,,,,,,,, -147400,4.0726027,2.0000215,,,,,,,,,,,,,, -147500,3.9064507,1.73671,,,,,,,,,,,,,, -147600,3.7546875,1.9188155,,,,,,,,,,,,,, -147700,4.4425364,1.7339253,,,,,,,,,,,,,, -147800,3.6624415,2.5391085,,,,,,,,,,,,,, -147900,3.7814937,1.7160635,,,,,,,,,,,,,, -148000,3.9330828,2.9127617,,,,,,,,,,,,,, -148082,,,0.7983788847923279,0.792547345161438,0.7223399877548218,1.120306372642517,50000.0,0.6007000207901001,1.7311818599700928,10000.0,67677.93710446358,73839.43036198616,67677.93710446358,6146.32323050499,7.481757164001465,0.0 -148100,3.959629,1.8048081,,,,,,,,,,,,,, -148200,4.188024,1.6531373,,,,,,,,,,,,,, -148300,3.9700933,1.8319153,,,,,,,,,,,,,, -148400,4.488134,1.8366065,,,,,,,,,,,,,, -148500,4.026237,1.8026335,,,,,,,,,,,,,, -148600,3.7837734,3.3328714,,,,,,,,,,,,,, -148700,4.0582237,1.6858267,,,,,,,,,,,,,, -148800,4.4760604,1.7826402,,,,,,,,,,,,,, -148900,4.0923624,2.215555,,,,,,,,,,,,,, -149000,,,0.7928906083106995,0.8123624920845032,0.7275399565696716,1.103825926780701,50000.0,0.6040000319480896,1.7259600162506104,10000.0,68097.8760201931,74300.64919734001,68097.8760201931,6187.508858203888,7.528716564178467,0.0 -149000,4.591581,4.103639,,,,,,,,,,,,,, -149100,4.2147193,1.688875,,,,,,,,,,,,,, -149200,4.2955117,1.6881806,,,,,,,,,,,,,, -149300,3.9640584,1.6710713,,,,,,,,,,,,,, -149400,3.919529,1.5684347,,,,,,,,,,,,,, -149500,4.3749423,1.7600536,,,,,,,,,,,,,, -149600,4.143764,1.79484,,,,,,,,,,,,,, -149700,4.485236,1.8155732,,,,,,,,,,,,,, -149800,4.727943,4.18392,,,,,,,,,,,,,, -149900,4.03795,3.4835339,,,,,,,,,,,,,, -149923,,,0.7988671660423279,0.78203284740448,0.7282399535179138,1.0843335390090942,50000.0,0.6043000221252441,1.7013198137283323,10000.0,68518.12698411942,74755.73614430428,68518.12698411942,6222.238587379456,7.587629318237305,0.0 -150000,4.00448,2.762413,,,,,,,,,,,,,, -150100,4.6001825,1.7541718,,,,,,,,,,,,,, -150200,4.271413,3.49839,,,,,,,,,,,,,, -150300,4.131317,2.548918,,,,,,,,,,,,,, -150400,4.2352676,1.9597085,,,,,,,,,,,,,, -150500,4.492259,1.7340498,,,,,,,,,,,,,, -150600,4.498889,2.0734851,,,,,,,,,,,,,, -150700,4.217382,1.6729031,,,,,,,,,,,,,, -150800,4.619823,1.7770212,,,,,,,,,,,,,, -150845,,,0.8016406297683716,0.7594988346099854,0.7286399602890015,1.0782972574234009,50000.0,0.6041000485420227,1.6799514293670654,10000.0,68938.50674057007,75219.40028524399,68938.50674057007,6265.426444530487,7.636116981506348,0.0 -150900,4.1916313,1.8354223,,,,,,,,,,,,,, -151000,4.676916,2.0307345,,,,,,,,,,,,,, -151100,3.8764951,2.0425854,,,,,,,,,,,,,, -151200,4.6487207,3.9639463,,,,,,,,,,,,,, -151300,4.1961484,3.213544,,,,,,,,,,,,,, -151400,4.6053243,3.9697099,,,,,,,,,,,,,, -151500,4.4969044,4.080545,,,,,,,,,,,,,, -151600,4.031825,2.331336,,,,,,,,,,,,,, -151700,4.2864113,1.7716216,,,,,,,,,,,,,, -151767,,,0.7989062070846558,0.7856873273849487,0.7301799654960632,1.0794763565063477,50000.0,0.6041000485420227,1.696513533592224,10000.0,69358.68147015572,75683.92640280724,69358.68147015572,6309.68102478981,7.684762239456177,0.0 -151800,5.1865377,1.6730655,,,,,,,,,,,,,, -151900,4.259785,1.8488033,,,,,,,,,,,,,, -152000,4.6382737,1.7729651,,,,,,,,,,,,,, -152100,4.2895713,1.5659122,,,,,,,,,,,,,, -152200,4.214574,1.9807835,,,,,,,,,,,,,, -152300,4.536845,1.5434482,,,,,,,,,,,,,, -152400,4.5296063,3.2379405,,,,,,,,,,,,,, -152500,4.3800645,1.8665826,,,,,,,,,,,,,, -152600,4.407249,1.6154069,,,,,,,,,,,,,, -152688,,,0.8006835579872131,0.7903160452842712,0.7302599549293518,1.0970721244812012,50000.0,0.6077000498771667,1.7170206308364868,10000.0,69778.69618415833,76142.76267409325,69778.69618415833,6348.403820991516,7.73581075668335,0.0 -152700,4.808617,3.8030348,,,,,,,,,,,,,, -152800,4.3921223,3.565123,,,,,,,,,,,,,, -152900,4.385203,1.610255,,,,,,,,,,,,,, -153000,4.276395,1.51781,,,,,,,,,,,,,, -153100,4.488126,1.5105878,,,,,,,,,,,,,, -153200,4.771284,1.6568053,,,,,,,,,,,,,, -153300,4.748705,3.6323917,,,,,,,,,,,,,, -153400,4.349521,1.7784079,,,,,,,,,,,,,, -153500,4.320283,2.228474,,,,,,,,,,,,,, -153600,4.528769,1.668211,,,,,,,,,,,,,, -153610,,,0.8105077743530273,0.7433952689170837,0.7332199811935425,1.0680303573608398,50000.0,0.6107000112533569,1.6835938692092896,10000.0,70198.95512890816,76602.52445936203,70198.95512890816,6387.8042669296265,7.789788961410522,0.0 -153700,5.115653,1.6545197,,,,,,,,,,,,,, -153800,4.6449356,2.1260154,,,,,,,,,,,,,, -153900,4.6374598,1.6780038,,,,,,,,,,,,,, -154000,4.5777855,1.5810738,,,,,,,,,,,,,, -154100,4.6499534,2.8496222,,,,,,,,,,,,,, -154200,4.635721,1.4989718,,,,,,,,,,,,,, -154300,4.8185015,1.6663901,,,,,,,,,,,,,, -154400,4.6438622,3.5515559,,,,,,,,,,,,,, -154500,4.8069983,3.4196126,,,,,,,,,,,,,, -154531,,,0.8120312094688416,0.7386379241943359,0.7362799644470215,1.065176486968994,50000.0,0.6134000420570374,1.6804908514022827,10000.0,70619.00755596161,77063.54085183144,70619.00755596161,6428.671512365341,7.83747386932373,0.0 -154600,4.5890794,2.2696254,,,,,,,,,,,,,, -154700,4.4241157,3.260969,,,,,,,,,,,,,, -154800,4.1621256,2.1889462,,,,,,,,,,,,,, -154900,5.1252303,4.13458,,,,,,,,,,,,,, -155000,4.512636,1.4670796,,,,,,,,,,,,,, -155100,4.9612665,2.7141654,,,,,,,,,,,,,, -155200,4.5663033,1.6760471,,,,,,,,,,,,,, -155300,4.9795904,4.0560093,,,,,,,,,,,,,, -155400,4.614015,3.2342842,,,,,,,,,,,,,, -155453,,,0.8109374642372131,0.7333440780639648,0.7377199530601501,1.0530829429626465,50000.0,0.6100000143051147,1.664387345314026,10000.0,71039.42457556725,77521.84715628624,71039.42457556725,6466.462705373764,7.888170719146728,0.0 -155500,4.974203,1.5855836,,,,,,,,,,,,,, -155600,4.7203083,1.6943965,,,,,,,,,,,,,, -155700,4.683048,1.7431978,,,,,,,,,,,,,, -155800,5.1625185,3.3675103,,,,,,,,,,,,,, -155900,4.478182,2.064907,,,,,,,,,,,,,, -156000,5.1465006,3.5958216,,,,,,,,,,,,,, -156100,4.615215,1.8647329,,,,,,,,,,,,,, -156200,4.9304895,1.6560763,,,,,,,,,,,,,, -156300,4.5199466,2.3343978,,,,,,,,,,,,,, -156371,,,0.81103515625,0.74477618932724,0.737060010433197,1.0577281713485718,50000.0,0.6124000549316406,1.6711875200271606,10000.0,71459.36435127258,77982.90768504143,71459.36435127258,6507.477152347565,7.9470250606536865,0.0 -156400,4.9013376,1.5255831,,,,,,,,,,,,,, -156500,4.908048,3.4226325,,,,,,,,,,,,,, -156600,4.741331,3.1598542,,,,,,,,,,,,,, -156700,6.2416224,4.0131855,,,,,,,,,,,,,, -156800,4.468327,1.5930761,,,,,,,,,,,,,, -156900,4.7563987,2.7891955,,,,,,,,,,,,,, -157000,5.0490146,4.0191183,,,,,,,,,,,,,, -157100,5.135355,1.5144775,,,,,,,,,,,,,, -157200,4.973935,1.5838882,,,,,,,,,,,,,, -157290,,,0.82044917345047,0.6939221620559692,0.7407599687576294,1.043515920639038,50000.0,0.6168000102043152,1.6577025651931765,10000.0,71879.43958759308,78441.13249969482,71879.43958759308,6545.526882886887,7.999570369720459,0.0 -157300,4.9586644,3.1434703,,,,,,,,,,,,,, -157400,4.473293,2.4561183,,,,,,,,,,,,,, -157500,4.831864,2.6141841,,,,,,,,,,,,,, -157600,4.2056503,1.9961436,,,,,,,,,,,,,, -157700,4.795693,3.0749817,,,,,,,,,,,,,, -157800,4.7739515,1.5178325,,,,,,,,,,,,,, -157900,6.3519936,4.0333796,,,,,,,,,,,,,, -158000,4.6067348,2.718253,,,,,,,,,,,,,, -158100,5.102044,1.6808977,,,,,,,,,,,,,, -158200,6.118121,4.004696,,,,,,,,,,,,,, -158209,,,0.8133788704872131,0.7271475791931152,0.7434799671173096,1.0289294719696045,50000.0,0.6183000206947327,1.6337950229644775,10000.0,72299.37812900543,78900.9204583168,72299.37812900543,6585.272683382034,8.054925918579102,0.0 -158300,5.27526,1.6065509,,,,,,,,,,,,,, -158400,4.6488895,3.1019382,,,,,,,,,,,,,, -158500,6.1459804,3.9214168,,,,,,,,,,,,,, -158600,4.6092424,1.9095813,,,,,,,,,,,,,, -158700,4.6013207,1.6680219,,,,,,,,,,,,,, -158800,4.7337723,1.4912316,,,,,,,,,,,,,, -158900,5.005406,1.5668638,,,,,,,,,,,,,, -159000,4.8075833,1.4851322,,,,,,,,,,,,,, -159100,5.516959,3.3064203,,,,,,,,,,,,,, -159129,,,0.8163671493530273,0.6988460421562195,0.745419979095459,1.0122644901275637,50000.0,0.6193000078201294,1.6151468753814695,10000.0,72719.4393901825,79362.80851197243,72719.4393901825,6626.9921362400055,8.114330053329468,0.0 -159200,5.4144197,3.6523252,,,,,,,,,,,,,, -159300,4.994037,1.6709247,,,,,,,,,,,,,, -159400,5.12008,1.6922932,,,,,,,,,,,,,, -159500,5.411773,1.7696196,,,,,,,,,,,,,, -159600,5.037551,1.5378351,,,,,,,,,,,,,, -159700,4.997058,2.7663164,,,,,,,,,,,,,, -159800,4.689405,2.6945088,,,,,,,,,,,,,, -159900,5.3041835,1.4891506,,,,,,,,,,,,,, -160000,6.1314836,3.6327758,,,,,,,,,,,,,, -160052,,,0.8247265219688416,0.6779054403305054,0.7456199526786804,1.016263723373413,50000.0,0.6236000061035156,1.628363013267517,10000.0,73139.74932837486,79824.29507088661,73139.74932837486,6668.070779085159,8.16342830657959,0.0 -160100,6.1116977,1.6460289,,,,,,,,,,,,,, -160200,5.3648653,1.5515522,,,,,,,,,,,,,, -160300,4.856235,1.5155256,,,,,,,,,,,,,, -160400,4.9195976,1.5166228,,,,,,,,,,,,,, -160500,5.4511323,1.507396,,,,,,,,,,,,,, -160600,5.2550178,1.6041132,,,,,,,,,,,,,, -160700,5.1461654,1.5357854,,,,,,,,,,,,,, -160800,5.1411147,2.867905,,,,,,,,,,,,,, -160900,4.8176293,2.1712217,,,,,,,,,,,,,, -160972,,,0.8173242211341858,0.7082533240318298,0.745199978351593,1.023060321807861,50000.0,0.6212000250816345,1.6267614364624023,10000.0,73559.80233550072,80281.88353705406,73559.80233550072,6705.505133152008,8.21670150756836,0.0 -161000,5.015244,2.9369297,,,,,,,,,,,,,, -161100,4.927444,1.7541587,,,,,,,,,,,,,, -161200,4.8908296,2.1343231,,,,,,,,,,,,,, -161300,5.5451818,1.55314,,,,,,,,,,,,,, -161400,5.1732173,2.6949081,,,,,,,,,,,,,, -161500,5.3262525,3.373444,,,,,,,,,,,,,, -161600,5.597375,1.5339286,,,,,,,,,,,,,, -161700,4.533393,2.4230828,,,,,,,,,,,,,, -161800,6.531079,3.9707189,,,,,,,,,,,,,, -161891,,,0.8249804377555847,0.6753319501876831,0.7469199895858765,1.015852451324463,50000.0,0.6193000078201294,1.627989649772644,10000.0,73979.69502210617,80744.27344155312,73979.69502210617,6747.892434120178,8.278788328170776,0.0 -161900,5.3514614,2.1225786,,,,,,,,,,,,,, -162000,5.5743713,3.2506497,,,,,,,,,,,,,, -162100,5.307767,2.0769677,,,,,,,,,,,,,, -162200,5.126891,1.6614903,,,,,,,,,,,,,, -162300,5.53888,1.9486005,,,,,,,,,,,,,, -162400,5.2783422,1.3710803,,,,,,,,,,,,,, -162500,6.1838036,3.9043324,,,,,,,,,,,,,, -162600,5.076686,2.2237122,,,,,,,,,,,,,, -162700,6.0669127,3.4513173,,,,,,,,,,,,,, -162800,5.4035163,2.268243,,,,,,,,,,,,,, -162812,,,0.8253905773162842,0.6590047478675842,0.7476199865341187,1.0020933151245115,50000.0,0.6238000392913818,1.611265778541565,10000.0,74399.87388181686,81203.53269195557,74399.87388181686,6786.871521949768,8.332364797592163,0.0 -162900,5.2766705,1.5031614,,,,,,,,,,,,,, -163000,5.182979,3.0154252,,,,,,,,,,,,,, -163100,7.1490397,3.9918745,,,,,,,,,,,,,, -163200,5.728311,1.5288211,,,,,,,,,,,,,, -163300,5.0885587,2.5798247,,,,,,,,,,,,,, -163400,5.0937324,1.9294069,,,,,,,,,,,,,, -163500,5.470737,1.5195706,,,,,,,,,,,,,, -163600,5.528411,1.4266752,,,,,,,,,,,,,, -163700,5.2739096,1.7493055,,,,,,,,,,,,,, -163733,,,0.8265624642372131,0.6701897382736206,0.7502399682998657,0.9970006942749025,50000.0,0.6267000436782837,1.6085433959960938,10000.0,74820.23160123825,81662.96415829659,74820.23160123825,6825.848484277725,8.381687641143799,0.0 -163800,6.158342,3.8607597,,,,,,,,,,,,,, -163900,6.4367347,3.6695664,,,,,,,,,,,,,, -164000,5.6691093,1.4715161,,,,,,,,,,,,,, -164100,5.4740777,1.4421076,,,,,,,,,,,,,, -164200,5.5023,1.7836742,,,,,,,,,,,,,, -164300,6.4372325,3.1524713,,,,,,,,,,,,,, -164400,5.6212564,1.9391991,,,,,,,,,,,,,, -164500,5.6672773,2.990661,,,,,,,,,,,,,, -164600,5.277829,1.2779509,,,,,,,,,,,,,, -164650,,,0.8306249976158142,0.6445387601852417,0.7535799741744995,0.9883765578269958,50000.0,0.6266000270843506,1.5907145738601685,10000.0,75240.45762014389,82124.06794404984,75240.45762014389,6866.622145652771,8.43875503540039,0.0 -164700,5.7694936,2.4821577,,,,,,,,,,,,,, -164800,5.6495934,1.4743509,,,,,,,,,,,,,, -164900,5.056056,2.170929,,,,,,,,,,,,,, -165000,5.61228,1.6359239,,,,,,,,,,,,,, -165100,5.435202,1.6413857,,,,,,,,,,,,,, -165200,5.5485444,1.6104096,,,,,,,,,,,,,, -165300,5.859115,1.5656688,,,,,,,,,,,,,, -165400,5.505248,2.97891,,,,,,,,,,,,,, -165500,5.842414,1.6149821,,,,,,,,,,,,,, -165572,,,0.8291601538658142,0.6529331207275391,0.7529000043869019,0.9910786747932434,50000.0,0.6296000480651855,1.5999563932418823,10000.0,75660.44420909882,82583.20623064041,75660.44420909882,6905.6733481884,8.490631818771362,0.0 -165600,5.6018806,2.8621862,,,,,,,,,,,,,, -165700,6.006352,3.6618886,,,,,,,,,,,,,, -165800,5.789296,2.021218,,,,,,,,,,,,,, -165900,5.842295,1.4118106,,,,,,,,,,,,,, -166000,5.2167788,1.49767,,,,,,,,,,,,,, -166100,5.2102294,1.5249964,,,,,,,,,,,,,, -166200,6.011717,2.4212365,,,,,,,,,,,,,, -166300,5.27266,2.5966597,,,,,,,,,,,,,, -166400,5.3796577,1.8528887,,,,,,,,,,,,,, -166493,,,0.8376562595367432,0.6226816177368164,0.7539199590682983,0.9782753586769104,50000.0,0.6313000321388245,1.5768295526504517,10000.0,76080.50569367409,83042.480260849,76080.50569367409,6944.787847995758,8.540288209915161,0.0 -166500,5.683836,1.3248987,,,,,,,,,,,,,, -166600,6.497204,3.6627448,,,,,,,,,,,,,, -166700,5.5805745,1.466314,,,,,,,,,,,,,, -166800,5.391396,1.5453762,,,,,,,,,,,,,, -166900,5.730395,1.4199603,,,,,,,,,,,,,, -167000,6.33018,1.5096979,,,,,,,,,,,,,, -167100,5.497583,1.7034501,,,,,,,,,,,,,, -167200,6.38162,2.2126927,,,,,,,,,,,,,, -167300,5.525241,2.9758227,,,,,,,,,,,,,, -167400,5.6802325,1.4692357,,,,,,,,,,,,,, -167415,,,0.8343359231948853,0.6367462277412415,0.7556799650192261,0.9818673729896544,50000.0,0.6365000009536743,1.5672297477722168,10000.0,76500.47437024117,83502.85957407951,76500.47437024117,6985.094571828842,8.595833539962769,0.0 -167500,5.7346244,2.9308338,,,,,,,,,,,,,, -167600,5.6183963,1.4715071,,,,,,,,,,,,,, -167700,5.9725327,2.246973,,,,,,,,,,,,,, -167800,5.8721385,1.4747107,,,,,,,,,,,,,, -167900,5.8898835,1.3687906,,,,,,,,,,,,,, -168000,6.6967196,1.3312168,,,,,,,,,,,,,, -168100,5.858699,1.4971044,,,,,,,,,,,,,, -168200,5.9250817,1.4236311,,,,,,,,,,,,,, -168300,6.7478957,2.1485744,,,,,,,,,,,,,, -168333,,,0.8360351324081421,0.6291006207466125,0.756659984588623,0.974940836429596,50000.0,0.631600022315979,1.5703439712524414,10000.0,76920.42654371262,83964.43345236778,76920.42654371262,7026.618679523468,8.646256685256958,0.0 -168400,5.7566395,2.3326259,,,,,,,,,,,,,, -168500,5.4853616,2.3698385,,,,,,,,,,,,,, -168600,6.1450353,1.3961105,,,,,,,,,,,,,, -168700,6.007783,1.4366077,,,,,,,,,,,,,, -168800,5.981992,1.4234802,,,,,,,,,,,,,, -168900,6.2120223,1.4961917,,,,,,,,,,,,,, -169000,6.17253,1.4279152,,,,,,,,,,,,,, -169100,6.765363,3.4546897,,,,,,,,,,,,,, -169200,5.91596,1.7269332,,,,,,,,,,,,,, -169253,,,0.8400976657867432,0.60579913854599,0.7610599994659424,0.9543364644050598,50000.0,0.6389000415802002,1.5526543855667114,10000.0,77340.37170767784,84421.27886939049,77340.37170767784,7063.42170381546,8.696255445480347,0.0 -169300,5.9361362,2.4707284,,,,,,,,,,,,,, -169400,5.8456306,2.6312952,,,,,,,,,,,,,, -169500,5.8021693,1.4079075,,,,,,,,,,,,,, -169600,6.0520763,1.4860364,,,,,,,,,,,,,, -169654,,,,,,,,,,,77520.5035700798,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/eval_measurements.csv deleted file mode 100644 index b4e308d48..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -27.742810249328613,0.0,36.79507851600647,1,0,36.79507851600647,0.0010000000474974,6.907756805419922,10000,64.5379867553711,0.0010546874254941,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -68.91855096817017,0.0192360877990722,456.8279526233673,857,0,456.8279526233673,0.02730000205338,6.052794933319092,10000,525.8108751773834,0.0340429693460464,5.877542018890381,0.0329199992120265,5.916387557983398,50000 -107.68769526481628,0.0563008785247802,876.8367249965668,1774,0,876.8367249965668,0.0483000017702579,5.741621971130371,10000,984.6733191013336,0.0689257830381393,5.4872636795043945,0.062899999320507,5.551773071289063,50000 -145.71309757232666,0.0808191299438476,1296.802814245224,2694,0,1296.802814245224,0.0746000036597251,5.3333587646484375,10000,1442.7371714115145,0.1042382791638374,4.991183280944824,0.0974399968981742,5.053380012512207,50000 -189.81067943573,0.1053242683410644,1716.8630304336548,3612,0,1716.8630304336548,0.0956000015139579,5.0893425941467285,10000,1906.967833995819,0.1368554681539535,4.678333759307861,0.1284399926662445,4.747917175292969,50000 -236.6235709190369,0.1309516429901123,2136.9159696102142,4531,0,2136.9159696102142,0.1284000128507614,4.720849514007568,10000,2373.9058408737183,0.1862304657697677,4.207225322723389,0.170319989323616,4.318536758422852,50000 -274.4062368869781,0.1563773155212402,2557.1325867176056,5447,0,2557.1325867176056,0.1448000073432922,4.589404106140137,10000,2831.977271318436,0.201210930943489,4.060230731964111,0.1880399882793426,4.152525424957275,50000 -318.80083322525024,0.1846585273742675,2977.314273118973,6363,0,2977.314273118973,0.1531000137329101,4.517242431640625,10000,3296.628982782364,0.2166601568460464,3.984548330307007,0.2004199922084808,4.092609405517578,50000 -365.0135953426361,0.2192466259002685,3397.362140417099,7283,0,3397.362140417099,0.1666000038385391,4.42499589920044,10000,3762.9711685180655,0.2351367175579071,3.851510524749756,0.2165599912405014,3.9796018600463854,50000 -410.8689727783203,0.2473421096801757,3817.682046175003,8203,0,3817.682046175003,0.1744000017642974,4.296341419219971,10000,4229.221621990204,0.2627343833446502,3.629719734191896,0.2322999984025955,3.829036712646485,50000 -456.0230774879456,0.2729377746582031,4237.6647737026215,9121,0,4237.6647737026215,0.1869000047445297,4.236430168151856,10000,4694.430584430695,0.2595312297344208,3.638303518295288,0.2396799921989441,3.753018617630005,50000 -499.5223741531372,0.3000564575195312,4657.653786182404,10039,0,4657.653786182404,0.2008000165224075,4.119417667388916,10000,5157.994423866272,0.2796484231948852,3.492593288421631,0.2573399841785431,3.628937005996704,50000 -539.6884536743164,0.3250484466552734,5077.807441949844,10957,0,5077.807441949844,0.1965000033378601,4.154642581939697,10000,5618.391860961914,0.2962499856948852,3.400749683380127,0.2607399821281433,3.646686553955078,50000 -578.3044393062592,0.3541853427886963,5497.79346203804,11874,0,5497.79346203804,0.2082000076770782,4.0429534912109375,10000,6077.069966077805,0.2930664122104645,3.4094316959381104,0.2763800024986267,3.517740488052368,50000 -619.1004593372345,0.3812034130096435,5918.157521486282,12795,0,5918.157521486282,0.2123000174760818,4.034548282623291,10000,6538.304161548615,0.2978906035423279,3.414585590362549,0.2738800048828125,3.54900860786438,50000 -663.8211107254028,0.4091596603393554,6338.3181848526,13715,0,6338.3181848526,0.2126000076532364,4.019969940185547,10000,7003.260776519775,0.3108203113079071,3.344714641571045,0.2833400070667267,3.511706829071045,50000 -704.5228517055511,0.435704231262207,6758.649676322937,14635,0,6758.649676322937,0.2205000072717666,3.981424808502197,10000,7464.368344545364,0.3113867044448852,3.300966024398804,0.2888199985027313,3.432950496673584,50000 -749.0826478004456,0.462766170501709,7178.820302963257,15553,0,7178.820302963257,0.2326000183820724,3.875176668167114,10000,7929.172488689423,0.322558581829071,3.2005932331085205,0.3014000058174133,3.3318538665771484,50000 -789.4964265823364,0.4894022941589355,7599.0502026081085,16472,0,7599.0502026081085,0.2360000163316726,3.8591465950012207,10000,8389.891204357147,0.3410351574420929,3.152328491210937,0.310259997844696,3.32299542427063,50000 -828.5178320407867,0.5192873477935791,8019.264442682266,17392,0,8019.264442682266,0.2270000129938125,3.917538166046143,10000,8849.203982591629,0.3284960985183716,3.2336535453796387,0.2993399798870086,3.389450788497925,50000 -871.812756061554,0.5486938953399658,8439.656270503998,18310,0,8439.656270503998,0.2187000066041946,4.029071807861328,10000,9312.967787742617,0.3080468773841858,3.3751299381256104,0.2909199893474579,3.5047309398651123,50000 -917.766725063324,0.5775036811828613,8859.802419900894,19230,0,8859.802419900894,0.2408000081777572,3.847225189208984,10000,9779.14450263977,0.3353515565395355,3.115917921066284,0.3131999969482422,3.278340578079224,50000 -959.7255573272704,0.6077971458435059,9280.090950250626,20149,0,9280.090950250626,0.2436000108718872,3.796807527542114,10000,10241.47033715248,0.381640613079071,2.8914709091186523,0.3214599788188934,3.214576482772827,50000 -1004.2897145748138,0.6361017227172852,9700.320873737335,21070,0,9700.320873737335,0.2556000053882599,3.724861860275269,10000,10706.340615034103,0.3535351455211639,2.995561361312866,0.3300800025463104,3.140366554260254,50000 -1052.8985612392426,0.662848949432373,10120.517220973969,21990,0,10120.517220973969,0.2473000138998031,3.81091570854187,10000,11175.220227003098,0.3414843678474426,3.077301025390625,0.3179799914360046,3.2187304496765137,50000 -1094.0847754478457,0.6952741146087646,10540.53809094429,22908,0,10540.53809094429,0.2469000071287155,3.816385269165039,10000,11636.506449222565,0.3515820205211639,3.0414631366729736,0.3215000033378601,3.234511137008667,50000 -1141.8376290798187,0.7260003089904785,10960.580304384232,23825,0,10960.580304384232,0.2626000046730041,3.7027549743652335,10000,12104.379473924637,0.3559374809265136,3.025038719177246,0.3379800021648407,3.137194871902466,50000 -1183.7002770900726,0.754091739654541,11380.95326280594,24746,0,11380.95326280594,0.2560999989509582,3.7653188705444336,10000,12566.690264701843,0.3567968606948852,3.0097639560699463,0.3316799998283386,3.17022705078125,50000 -1228.4372079372406,0.7881994247436523,11801.243542194366,25665,0,11801.243542194366,0.2555000185966491,3.74962329864502,10000,13031.799020767212,0.3615429699420929,3.016035795211792,0.3340199887752533,3.1701815128326416,50000 -1275.3340392112732,0.8167154788970947,12221.656418085098,26585,0,12221.656418085098,0.247400015592575,3.812144756317138,10000,13499.184258461,0.346992164850235,3.0902798175811768,0.3211599886417389,3.2521159648895264,50000 -1314.660932779312,0.8498325347900391,12641.995535612106,27501,0,12641.995535612106,0.2553000152111053,3.7012126445770255,10000,13958.93049645424,0.3630273342132568,2.980881929397583,0.3391999900341034,3.110235452651977,50000 -1352.705206155777,0.8814163208007812,13062.072076559069,28418,0,13062.072076559069,0.2589000165462494,3.6658432483673096,10000,14417.129843950272,0.3716992139816284,2.924288034439087,0.3419399857521057,3.0865392684936523,50000 -1396.1704378128052,0.9180092811584472,13482.143792390823,29335,0,13482.143792390823,0.2434000074863433,3.895664930343628,10000,14880.750834703444,0.3422265648841858,3.173394203186035,0.316100001335144,3.3170182704925537,50000 -1443.007826089859,0.9461681842803956,13902.161823272703,30255,0,13902.161823272703,0.2698000073432922,3.616143226623535,10000,15347.681805610657,0.3800390660762787,2.861891746520996,0.3527399897575378,3.0179009437561035,50000 -1482.2746975421906,0.9801223278045654,14322.417489528656,31176,0,14322.417489528656,0.2560999989509582,3.68998384475708,10000,15807.286099910736,0.3661523461341858,2.9577431678771973,0.3359200060367584,3.120572328567505,50000 -1522.337045431137,1.0161523818969729,14742.738109588625,32094,0,14742.738109588625,0.2699000239372253,3.6354501247406006,10000,16267.7527885437,0.4045312404632568,2.735917806625366,0.3504199981689453,3.0416879653930664,50000 -1568.1595079898834,1.0463981628417969,15162.715245962145,33013,0,15162.715245962145,0.2725000083446502,3.637152194976807,10000,16733.630709409714,0.371406227350235,2.9292590618133545,0.3455399870872497,3.0680296421051025,50000 -1610.6403777599337,1.0798285007476809,15583.04086136818,33933,0,15583.04086136818,0.2675000131130218,3.668169736862183,10000,17196.518713235855,0.3720703125,2.9328975677490234,0.3447999954223633,3.0781383514404297,50000 -1653.1295936107635,1.1132240295410156,16003.265714883804,34852,0,16003.265714883804,0.270900011062622,3.5934412479400635,10000,17659.312923192978,0.3985156118869781,2.755237579345703,0.3614400029182434,2.958653926849365,50000 -1700.3952662944794,1.143989086151123,16423.598749876022,35772,0,16423.598749876022,0.2740000188350677,3.591972827911377,10000,18126.9892642498,0.3848632872104645,2.830144166946411,0.3606799840927124,2.9782936573028564,50000 -1744.3594100475311,1.1727678775787354,16843.694502830505,36693,0,16843.694502830505,0.2692000269889831,3.625471591949463,10000,18591.12539958954,0.3790820240974426,2.873992919921875,0.3569999933242798,3.0040409564971924,50000 -1791.871794462204,1.2091057300567627,17263.87330675125,37613,0,17263.87330675125,0.2735000252723694,3.5844669342041016,10000,19058.90007901192,0.3933398425579071,2.799267768859864,0.3593399822711944,2.989340543746948,50000 -1834.498204946518,1.2384934425354004,17683.964739084244,38530,0,17683.964739084244,0.2800000011920929,3.546329498291016,10000,19521.694784641262,0.3886132836341858,2.7967445850372314,0.3657599985599518,2.934712648391724,50000 -1879.0127630233765,1.272782802581787,18104.35695052147,39451,0,18104.35695052147,0.2826000154018402,3.571634531021118,10000,19986.68339204788,0.3868554532527923,2.8592474460601807,0.3622399866580963,2.997967004776001,50000 -1927.0022106170647,1.308027267456055,18524.528084754944,40369,0,18524.528084754944,0.2865000069141388,3.507483720779419,10000,20454.92676472664,0.4025976359844208,2.729934930801392,0.3704600036144256,2.898860216140747,50000 -1969.225081205368,1.340603590011597,18944.64669013024,41288,0,18944.64669013024,0.2812000215053558,3.5632994174957275,10000,20917.34801101685,0.4137109220027923,2.72369384765625,0.3606199920177459,2.9944801330566406,50000 -2007.7696409225464,1.3750572204589844,19364.88948059082,42206,0,19364.88948059082,0.2829000055789947,3.547908067703247,10000,21376.21737074852,0.388964831829071,2.8523201942443848,0.367279976606369,2.9817123413085938,50000 -2053.67297244072,1.4091379642486572,19785.169471025467,43125,0,19785.169471025467,0.2886000275611877,3.5031073093414307,10000,21842.482117176056,0.4010742008686065,2.7512407302856445,0.3735399842262268,2.8998680114746094,50000 -2099.034994125366,1.4414191246032717,20205.232219696045,44045,0,20205.232219696045,0.2928000092506408,3.449684858322144,10000,22307.98716187477,0.429003894329071,2.601717233657837,0.3832799792289734,2.852113723754883,50000 -2144.228601694107,1.4792120456695557,20625.31547307968,44962,0,20625.31547307968,0.2863000035285949,3.5115420818328857,10000,22773.348356485367,0.4001562297344208,2.801137208938598,0.3734200000762939,2.930062055587769,50000 -2191.300128698349,1.5172581672668457,21045.717796325684,45882,0,21045.717796325684,0.2107000052928924,4.131436347961426,10000,23240.907628774643,0.2884374856948852,3.4999566078186035,0.2685799896717071,3.610232830047608,50000 -2232.58002948761,1.55698823928833,21465.82685112953,46801,0,21465.82685112953,0.2831000089645386,3.567754030227661,10000,23702.38392972946,0.3951171934604645,2.819166421890259,0.3638199865818023,3.004495143890381,50000 -2275.397953271866,1.593160629272461,21886.131243228912,47723,0,21886.131243228912,0.2997000217437744,3.443194627761841,10000,24165.59008526802,0.4143163859844208,2.6614110469818115,0.3860200047492981,2.80824875831604,50000 -2322.138329029084,1.6305792331695557,22306.22730517388,48642,0,22306.22730517388,0.2904000282287597,3.4814419746398926,10000,24632.510838747025,0.4010742008686065,2.7443270683288574,0.3771199882030487,2.8815295696258545,50000 -2366.9990010261536,1.662933588027954,22726.256851911545,49559,0,22726.256851911545,0.2870000004768371,3.4870080947875977,10000,25097.48213648796,0.4129492044448852,2.7169084548950195,0.3819800019264221,2.885247707366944,50000 -2412.7322528362274,1.6987645626068115,23146.49708509445,50478,0,23146.49708509445,0.3005000054836273,3.3917243480682373,10000,25563.53849577904,0.4268554449081421,2.6102540493011475,0.39751997590065,2.780481100082397,50000 -2461.1722581386566,1.734628677368164,23566.498861551285,51397,0,23566.498861551285,0.3053000271320343,3.3814618587493896,10000,26032.06483864784,0.429492175579071,2.56564998626709,0.4009999930858612,2.729902505874634,50000 -2505.948692560196,1.7748017311096191,23986.7883477211,52315,0,23986.7883477211,0.3107000291347503,3.378159523010254,10000,26497.21853017807,0.4299023449420929,2.5807902812957764,0.398059993982315,2.752584934234619,50000 -2552.8182249069214,1.8105263710021973,24406.975200414658,53234,0,24406.975200414658,0.3115000128746032,3.353266954421997,10000,26964.357704639435,0.4723632633686065,2.374985694885254,0.4041999876499176,2.73072361946106,50000 -2603.2553062438965,1.846671342849732,24827.1175699234,54153,0,24827.1175699234,0.3023000061511993,3.398660659790039,10000,27435.02015376091,0.4276367127895355,2.59726619720459,0.3992999792098999,2.763692140579224,50000 -2646.8072805404663,1.887636423110962,25247.319982767105,55074,0,25247.319982767105,0.3185000121593475,3.2860732078552246,10000,27898.86329269409,0.4376757740974426,2.500518798828125,0.4105999767780304,2.6731464862823486,50000 -2690.1619765758514,1.9212877750396729,25667.46424293518,55995,0,25667.46424293518,0.302700012922287,3.414133310317993,10000,28362.4433825016,0.4432226419448852,2.532923936843872,0.402319997549057,2.7546207904815674,50000 -2733.7377874851227,1.959881067276001,26087.7837445736,56916,0,26087.7837445736,0.3193000257015228,3.296461582183838,10000,28826.425322294235,0.4407812356948852,2.5251173973083496,0.4111399948596954,2.6807515621185303,50000 -2778.325141429901,1.99947476387024,26507.91575574875,57836,0,26507.91575574875,0.3222000300884247,3.286940097808838,10000,29291.23188686371,0.4381054639816284,2.522041320800781,0.4074999988079071,2.682382345199585,50000 -2824.6300699710846,2.0336482524871826,26927.859172344208,58754,0,26927.859172344208,0.3222000300884247,3.29914927482605,10000,29757.56119155884,0.4493750035762787,2.502846002578736,0.4158999919891357,2.6864118576049805,50000 -2870.763954639435,2.066673040390014,27347.924685001373,59675,0,27347.924685001373,0.3230000138282776,3.316578388214112,10000,30223.841745615005,0.4404882788658142,2.5459964275360107,0.4094399809837341,2.710213661193848,50000 -2915.4320845603943,2.101196527481079,27768.17107129097,60594,0,27768.17107129097,0.3331000208854675,3.216318130493164,10000,30688.83791780472,0.453437477350235,2.441464424133301,0.4258799850940704,2.599683284759521,50000 -2963.3469684124,2.1426825523376465,28188.53337931633,61513,0,28188.53337931633,0.3228000104427337,3.2577946186065674,10000,31157.204171419144,0.454902321100235,2.44576096534729,0.4205799996852875,2.644345283508301,50000 -3008.009134531021,2.1812844276428223,28608.84869074821,62433,0,28608.84869074821,0.3226000070571899,3.3033504486083984,10000,31622.267642736435,0.4523828029632568,2.508619546890259,0.4140399992465973,2.6991686820983887,50000 -3056.772082090378,2.215562343597412,29028.77939391136,63353,0,29028.77939391136,0.3113000094890594,3.3364946842193604,10000,32091.04265642166,0.4460742175579071,2.4987034797668457,0.4134199917316437,2.6845717430114746,50000 -3104.6899168491364,2.250218152999878,29448.9363090992,64271,0,29448.9363090992,0.3242000043392181,3.2666985988616943,10000,32559.199233531952,0.451464831829071,2.464915037155152,0.4184199869632721,2.648669719696045,50000 -3148.381046772003,2.2855420112609863,29869.14924716949,65189,0,29869.14924716949,0.33160001039505,3.247788906097412,10000,33023.18640756607,0.4797070324420929,2.2820422649383545,0.4228200018405914,2.5929930210113525,50000 -3187.614446878433,2.323702335357666,30289.283529758453,66108,0,30289.283529758453,0.3427000045776367,3.1849935054779053,10000,33482.63961672783,0.4628124833106994,2.4145824909210205,0.4338599741458893,2.566502094268799,50000 -3232.7051644325256,2.362233638763428,30709.503110408783,67028,0,30709.503110408783,0.3293000161647796,3.2058255672454834,10000,33948.036386966705,0.4659765660762787,2.3797836303710938,0.4309599995613098,2.57188081741333,50000 -3279.18199968338,2.3966870307922363,31129.867817878723,67948,0,31129.867817878723,0.3377000093460083,3.203951835632324,10000,34414.95987582207,0.4738866984844208,2.352745532989502,0.4340199828147888,2.5681827068328857,50000 -3324.0429894924164,2.434640884399414,31550.06188440323,68868,0,31550.06188440323,0.3359000086784363,3.207452774047852,10000,34880.10146713257,0.4649609327316284,2.447537422180176,0.4329399764537811,2.587732553482056,50000 -3370.650318622589,2.472820997238159,31970.368980884552,69789,0,31970.368980884552,0.3391000032424927,3.190361976623535,10000,35347.10148000717,0.4654882848262787,2.4043633937835693,0.4331599771976471,2.563265800476074,50000 -3412.957806110382,2.5098109245300293,32390.55333662033,70706,0,32390.55333662033,0.3436000049114227,3.145872116088867,10000,35809.67774987221,0.4787109196186065,2.317901611328125,0.439520001411438,2.5207505226135254,50000 -3460.0244784355164,2.551076889038086,32810.782964229584,71625,0,32810.782964229584,0.3367000222206116,3.232973098754883,10000,36277.062376499176,0.4616210758686065,2.4654088020324707,0.4322199821472168,2.6213440895080566,50000 -3505.9291064739227,2.587003231048584,33230.88841557503,72547,0,33230.88841557503,0.3489000201225281,3.1029651165008545,10000,36743.15631175041,0.4809960722923279,2.3176419734954834,0.4526999890804291,2.466217041015625,50000 -3551.381100177765,2.621609687805176,33650.96217060089,73466,0,33650.96217060089,0.3473000228404999,3.121434211730957,10000,37208.76399350166,0.4806054532527923,2.2896952629089355,0.4437399804592132,2.492424249649048,50000 -3598.5108897686005,2.663499355316162,34071.35045528412,74386,0,34071.35045528412,0.3522000312805176,3.1258039474487305,10000,37676.37067079544,0.5169921517372131,2.1353821754455566,0.4467199742794037,2.486428737640381,50000 -3638.625265598297,2.711342096328736,34491.62985539436,75304,0,34491.62985539436,0.3513000309467315,3.071357011795044,10000,38136.85993814469,0.4877148270606994,2.240132570266724,0.4559399783611297,2.4114251136779785,50000 -3681.6170043945312,2.75275993347168,34911.86199808121,76223,0,34911.86199808121,0.3555000126361847,3.057042121887207,10000,38600.17257928848,0.4906249940395355,2.2219138145446777,0.4575199782848358,2.4190688133239746,50000 -3727.388257026672,2.79075288772583,35332.00664615631,77142,0,35332.00664615631,0.3538000285625458,3.080111503601074,10000,39066.17379283905,0.5056250095367432,2.1703872680664062,0.4499799907207489,2.44785213470459,50000 -3771.449672698975,2.83611798286438,35752.21084976196,78062,0,35752.21084976196,0.3588000237941742,3.084175109863281,10000,39530.53333187103,0.4858007729053497,2.2651078701019287,0.4527799785137176,2.438197374343872,50000 -3816.614407777786,2.875884532928467,36172.48115777969,78982,0,36172.48115777969,0.34620001912117,3.136114835739136,10000,39996.05541801453,0.4868945181369781,2.2959415912628174,0.4545799791812897,2.4721052646636963,50000 -3865.289860725403,2.9212820529937744,36592.814427137375,79903,0,36592.814427137375,0.3619000315666199,3.0350892543792725,10000,40465.15758180618,0.4998827874660492,2.2071332931518555,0.4567599892616272,2.4251766204833984,50000 -3906.299390554428,2.9639267921447754,37013.1857984066,80823,0,37013.1857984066,0.367900013923645,2.989685773849488,10000,40926.628831624985,0.4993554651737213,2.188851356506348,0.471560001373291,2.3428831100463867,50000 -3950.904098033905,3.0028018951416016,37433.493015527725,81739,0,37433.493015527725,0.368800014257431,2.9926578998565674,10000,41391.62658786774,0.5030664205551147,2.173933744430542,0.4730799794197082,2.3356215953826904,50000 -3997.634565114975,3.0432190895080566,37853.55633306503,82657,0,37853.55633306503,0.3644000291824341,3.0233078002929688,10000,41858.507304906845,0.5142968893051147,2.146437168121338,0.472379982471466,2.358328104019165,50000 -4034.875780582428,3.087830781936645,38273.708990097046,83577,0,38273.708990097046,0.3646000027656555,2.994065284729004,10000,42315.99324274063,0.5132030844688416,2.140690803527832,0.4748199880123138,2.3312788009643555,50000 -4080.1282589435577,3.1285057067871094,38693.9015994072,84497,0,38693.9015994072,0.3582000136375427,3.0756309032440186,10000,42781.52650523186,0.5048632621765137,2.2373056411743164,0.4640399813652038,2.4101569652557373,50000 -4122.668626070023,3.1721973419189453,39114.07578897476,85417,0,39114.07578897476,0.3726000189781189,2.975393772125244,10000,43244.331899404526,0.5148242115974426,2.1396749019622803,0.4755999743938446,2.3394620418548584,50000 -4169.881982803345,3.214378833770752,39534.43977665901,86337,0,39534.43977665901,0.3779000043869018,2.9314382076263428,10000,43711.998577833176,0.5555468797683716,1.912348389625549,0.4817200005054474,2.272268295288086,50000 -4213.158932924271,3.2534549236297607,39954.65167856216,87257,0,39954.65167856216,0.3801000118255615,2.9268369674682617,10000,44175.57442951202,0.52099609375,2.096512079238892,0.4859399795532226,2.274020195007324,50000 -4260.838416099548,3.29453992843628,40374.94868397713,88175,0,40374.94868397713,0.3754000067710876,2.916119337081909,10000,44643.63994884491,0.5269726514816284,2.0405514240264893,0.4899399876594543,2.239298105239868,50000 -4305.934587717056,3.3338735103607178,40795.08370661736,89095,0,40795.08370661736,0.3803000152111053,2.940920829772949,10000,45108.95787191391,0.5325976610183716,2.0417697429656982,0.4813999831676483,2.3064215183258057,50000 -4351.538655757904,3.375528812408448,41215.38111019135,90014,0,41215.38111019135,0.3824000060558319,2.936632871627808,10000,45574.949031591415,0.5163866877555847,2.1249258518218994,0.4807799756526947,2.301429033279419,50000 -4396.368488788605,3.417140483856201,41635.74610328674,90933,0,41635.74610328674,0.3911000192165375,2.831852436065674,10000,46040.23263311386,0.5344530940055847,2.001258134841919,0.4990399777889251,2.194621324539185,50000 -4444.340796709061,3.457115888595581,42056.00476717949,91852,0,42056.00476717949,0.388700008392334,2.87404203414917,10000,46508.55019235611,0.5383593440055847,2.007728815078736,0.4928599894046783,2.2256059646606445,50000 -4493.034183979034,3.495344877243042,42476.11767077446,92769,0,42476.11767077446,0.39410001039505,2.844502687454224,10000,46977.44144821167,0.5292577743530273,2.0604593753814697,0.4968799948692322,2.218428134918213,50000 -4539.431235074997,3.540785312652588,42896.11478877068,93690,0,42896.11478877068,0.3966000080108642,2.847208023071289,10000,47443.92788076401,0.5306054353713989,2.020060777664185,0.497979998588562,2.209547281265259,50000 -4587.36496925354,3.5851848125457764,43316.36288237572,94609,0,43316.36288237572,0.3892000317573547,2.920865297317505,10000,47912.20147418976,0.5318945050239563,2.0958914756774902,0.4930999875068664,2.292987108230591,50000 -4636.007810115814,3.625463962554932,43736.64577579498,95526,0,43736.64577579498,0.3992000222206116,2.828856468200684,10000,48381.214473724365,0.5541210770606995,1.9405015707015991,0.504859983921051,2.179647922515869,50000 -4677.541831970215,3.672870397567749,44156.965695381165,96446,0,44156.965695381165,0.3907000124454498,2.8464391231536865,10000,48843.16408109665,0.5423046946525574,2.020184278488159,0.5077199935913086,2.1947858333587646,50000 -4725.14670586586,3.7132389545440674,44577.15295219421,97366,0,44577.15295219421,0.395300030708313,2.858224391937256,10000,49311.0440621376,0.5404687523841858,2.009705781936645,0.5008000135421753,2.208786725997925,50000 -4773.47830247879,3.75544548034668,44997.3252222538,98286,0,44997.3252222538,0.3963000178337097,2.810023069381714,10000,49779.63790535927,0.5715429782867432,1.849209904670716,0.5137400031089783,2.139185667037964,50000 -4815.0956864357,3.802260398864746,45417.43864130974,99206,0,45417.43864130974,0.401600033044815,2.7917513847351074,10000,50241.46690893173,0.5505663752555847,1.9494009017944336,0.5107200145721436,2.1389076709747314,50000 -4863.8861446380615,3.843090057373047,45837.39217829704,100126,0,45837.39217829704,0.4137000143527984,2.7351529598236084,10000,50710.29930901528,0.5661718845367432,1.871565222740173,0.5231800079345703,2.0831315517425537,50000 -4910.821338653564,3.887923240661621,46257.652686834335,101047,0,46257.652686834335,0.4107000231742859,2.7359278202056885,10000,51177.58735990524,0.5694140791893005,1.840009093284607,0.5202800035476685,2.0846471786499023,50000 -4956.939247131348,3.932757616043091,46677.740731954575,101966,0,46677.740731954575,0.4108000099658966,2.763245105743408,10000,51643.885112285614,0.5561913847923279,1.9271278381347656,0.5159400105476379,2.123799085617065,50000 -5004.784141540527,3.972791910171509,47097.76620817184,102885,0,47097.76620817184,0.4234000146389007,2.709082841873169,10000,52111.843133449554,0.5678125023841858,1.848743557929993,0.5289799571037292,2.049794435501098,50000 -5054.160320997238,4.023122072219849,47517.937237262726,103806,0,47517.937237262726,0.4193000197410583,2.696749448776245,10000,52581.48854184151,0.5733789205551147,1.824051141738892,0.5295799970626831,2.0383520126342773,50000 -5101.870931148529,4.066912651062012,47937.92220687866,104725,0,47937.92220687866,0.4147000312805176,2.714286088943481,10000,53049.274918079376,0.567187488079071,1.864588022232056,0.5290799736976624,2.053553581237793,50000 -5152.026482105255,4.110436916351318,48358.190257549286,105644,0,48358.190257549286,0.4223000109195709,2.6632344722747803,10000,53519.78946995735,0.5702148079872131,1.839844822883606,0.5308600068092346,2.028477191925049,50000 -5196.08935046196,4.1553053855896,48778.23556470871,106563,0,48778.23556470871,0.4170000255107879,2.710397720336914,10000,53983.98978614807,0.5782421827316284,1.855036258697509,0.5310800075531006,2.070237874984741,50000 -5243.95282626152,4.197072982788086,49198.48667836189,107482,0,49198.48667836189,0.4292000234127044,2.670915603637696,10000,54452.19374775887,0.6131640672683716,1.644121527671814,0.536899983882904,2.003920078277588,50000 -5287.324712753296,4.244433879852295,49619.11729288101,108401,0,49619.11729288101,0.4228000342845917,2.666642427444458,10000,54916.2911529541,0.5776953101158142,1.8307543992996216,0.5416600108146667,2.010225296020508,50000 -5334.740091085434,4.290728807449341,50039.2133743763,109321,0,50039.2133743763,0.429500013589859,2.6396257877349854,10000,55383.89688038826,0.5839257836341858,1.781617283821106,0.5431399941444397,1.9895758628845213,50000 -5380.403715848923,4.336857795715332,50459.24368357658,110242,0,50459.24368357658,0.4374000132083893,2.595966100692749,10000,55849.68517708778,0.6037304401397705,1.6702401638031006,0.5479999780654907,1.953895568847656,50000 -5422.70272564888,4.410296440124512,50879.184905052185,111161,0,50879.184905052185,0.4337000250816345,2.628989696502685,10000,56312.04675197601,0.5822851657867432,1.7880244255065918,0.5461999773979187,1.9647815227508545,50000 -5471.136475563049,4.45197343826294,51299.41596865654,112079,0,51299.41596865654,0.4411000311374664,2.573200464248657,10000,56780.80129766464,0.5941405892372131,1.7171932458877563,0.5562199950218201,1.896606683731079,50000 -5518.939175367355,4.49985933303833,51719.630373477936,112999,0,51719.630373477936,0.4442000091075897,2.5536797046661377,10000,57248.9131758213,0.6062109470367432,1.663967847824097,0.555899977684021,1.902474045753479,50000 -5568.422976732254,4.548093795776367,52139.66808462143,113919,0,52139.66808462143,0.4449000358581543,2.562439203262329,10000,57718.53077292442,0.594042956829071,1.7198957204818726,0.5557599663734436,1.9055322408676147,50000 -5614.12481713295,4.5912158489227295,52559.89040374756,114839,0,52559.89040374756,0.4402000308036804,2.587394952774048,10000,58184.54452776909,0.6001757383346558,1.7225245237350464,0.5551599860191345,1.9314780235290527,50000 -5660.633631229401,4.6373395919799805,52980.26725935936,115757,0,52980.26725935936,0.4474000334739685,2.5390818119049072,10000,58651.52395033837,0.6089257597923279,1.6470164060592651,0.5604599714279175,1.8817566633224487,50000 -5707.661694765091,4.6801183223724365,53400.54282641411,116677,0,53400.54282641411,0.4478000104427337,2.53030014038086,10000,59118.91864323616,0.613476574420929,1.6504943370819092,0.5635200142860413,1.876099228858948,50000 -5753.440539121628,4.725946426391602,53820.59474277496,117596,0,53820.59474277496,0.4523000121116638,2.5318267345428467,10000,59584.84307575226,0.6048437356948853,1.6839405298233032,0.5666399598121643,1.874297022819519,50000 -5801.996683120728,4.77754282951355,54240.80654287338,118516,0,54240.80654287338,0.4612000286579132,2.4730336666107178,10000,60053.70967531204,0.617871105670929,1.593552827835083,0.5720199942588806,1.8123306035995483,50000 -5845.860307693481,4.820557355880737,54660.86072707176,119434,0,54660.86072707176,0.4488000273704529,2.510444164276123,10000,60517.71708583832,0.6364452838897705,1.5332109928131104,0.5678399801254272,1.8640707731246948,50000 -5892.424654722214,4.865030288696289,55081.26514601708,120352,0,55081.26514601708,0.4611000120639801,2.518050193786621,10000,60984.77663850784,0.6158202886581421,1.6390470266342163,0.5762199759483337,1.8413300514221191,50000 -5937.952259302139,4.911031007766724,55501.56821870804,121273,0,55501.56821870804,0.4592000246047973,2.49786114692688,10000,61450.70113730431,0.6201562285423279,1.615761399269104,0.5776799917221069,1.8188271522521973,50000 -5988.703502893448,4.95950722694397,55921.78595089912,122194,0,55921.78595089912,0.4688000082969665,2.44468355178833,10000,61921.7659611702,0.6444921493530273,1.5002819299697876,0.583139955997467,1.7808510065078735,50000 -6033.838894367218,5.008333683013916,56341.82621026039,123114,0,56341.82621026039,0.463200032711029,2.4657235145568848,10000,62387.038128614426,0.6229882836341858,1.5936886072158811,0.5823000073432922,1.7864725589752195,50000 -6082.30672454834,5.055681467056274,56761.87673950195,124031,0,56761.87673950195,0.4722000360488891,2.412415742874145,10000,62855.65078186989,0.6308007836341858,1.54508376121521,0.5869199633598328,1.7575558423995972,50000 -6123.739371538162,5.101463317871094,57181.95409989357,124950,0,57181.95409989357,0.4794000089168548,2.372661590576172,10000,63317.253702163696,0.6486718654632568,1.460990309715271,0.5928199887275696,1.732962131500244,50000 -6170.100564241409,5.144175052642822,57601.99688029289,125869,0,57601.99688029289,0.4811000227928161,2.383357286453247,10000,63783.74777674675,0.636523425579071,1.5250691175460815,0.5936599969863892,1.7224498987197876,50000 -6214.136778354645,5.194725275039673,58022.25514602661,126790,0,58022.25514602661,0.4742000102996826,2.401140689849853,10000,64248.13990950584,0.6359765529632568,1.5459275245666504,0.588979959487915,1.755147933959961,50000 -6261.555802345276,5.246778249740601,58442.36844062805,127710,0,58442.36844062805,0.4843000173568725,2.344245195388794,10000,64715.771369457245,0.6518163681030273,1.4520666599273682,0.600059986114502,1.690841794013977,50000 -6306.545022726059,5.292807579040527,58862.76211476326,128630,0,58862.76211476326,0.4792000353336334,2.3771519660949707,10000,65181.24786877632,0.6659374833106995,1.415810465812683,0.5946199893951416,1.7279531955718994,50000 -6350.070232391357,5.337663650512695,59283.08157444,129550,0,59283.08157444,0.4895000159740448,2.340092182159424,10000,65645.18449640274,0.6496288776397705,1.4651296138763428,0.6078599691390991,1.6644974946975708,50000 -6395.686897754669,5.386809349060059,59703.20712137222,130468,0,59703.20712137222,0.4879000186920166,2.341444730758667,10000,66111.03629493713,0.6561523079872131,1.4382413625717163,0.6042400002479553,1.6798323392868042,50000 -6441.799973964691,5.43379020690918,60123.49499297142,131389,0,60123.49499297142,0.4835000336170196,2.3345508575439453,10000,66577.5312511921,0.6699999570846558,1.370593547821045,0.6080399751663208,1.6591289043426514,50000 -6486.980547428131,5.481791257858276,60543.795784950256,132307,0,60543.795784950256,0.4922000169754028,2.302058219909668,10000,67043.1083316803,0.6594530940055847,1.4391076564788818,0.6137199997901917,1.645622968673706,50000 -6534.7361924648285,5.529500961303711,60963.91841840744,133224,0,60963.91841840744,0.4946000277996063,2.308748245239258,10000,67511.08133149147,0.6636718511581421,1.4158997535705566,0.6115800142288208,1.6641314029693604,50000 -6583.84783744812,5.574039936065674,61384.02610969544,134142,0,61384.02610969544,0.4912000298500061,2.289562702178955,10000,67980.39304852486,0.6744531393051147,1.3474888801574707,0.6168599724769592,1.62572181224823,50000 -6631.8838946819305,5.623882532119751,61804.2365424633,135061,0,61804.2365424633,0.493800014257431,2.297254800796509,10000,68448.73585152626,0.6664648056030273,1.4257535934448242,0.6180999875068665,1.6305259466171265,50000 -6677.498216867447,5.672069072723389,62224.20766711235,135975,0,62224.20766711235,0.5,2.278455972671509,10000,68914.41670441628,0.6676562428474426,1.3916078805923462,0.6185399889945984,1.6243489980697632,50000 -6725.043194055557,5.722181797027588,62644.19136214256,136889,0,62644.19136214256,0.5021000504493713,2.23734450340271,10000,69382.04277920723,0.6796679496765137,1.3299959897994995,0.6280199885368347,1.5712664127349854,50000 -6772.678442955017,5.771515846252441,63064.48580121994,137808,0,63064.48580121994,0.5060000419616699,2.2221529483795166,10000,69850.06992220879,0.6735937595367432,1.3464230298995972,0.6296399831771851,1.558292031288147,50000 -6820.7501039505005,5.830377578735352,63484.58924078941,138727,0,63484.58924078941,0.509600043296814,2.221627950668335,10000,70318.35100626945,0.6751171946525574,1.3565669059753418,0.628600001335144,1.5680512189865112,50000 -6867.18776845932,5.888091802597046,63904.95752739906,139644,0,63904.95752739906,0.5097000002861023,2.198205947875977,10000,70785.26147270203,0.6910156011581421,1.2635481357574463,0.6337400078773499,1.5286023616790771,50000 -6913.686674833298,5.933101177215576,64325.088150024414,140560,0,64325.088150024414,0.5170000195503235,2.171229839324951,10000,71251.98325324059,0.7136914134025574,1.185178518295288,0.6374599933624268,1.5168884992599487,50000 -6960.035108566284,5.978979587554932,64745.282984018326,141479,0,64745.282984018326,0.5174000263214111,2.166109800338745,10000,71718.62048172951,0.6903125047683716,1.27427077293396,0.6414600014686584,1.5009756088256836,50000 -7004.11513710022,6.038940191268921,65165.53315329552,142395,0,65165.53315329552,0.5193000435829163,2.146132469177246,10000,72183.05815005302,0.6958788633346558,1.2445529699325562,0.6466599702835083,1.4790695905685425,50000 -7051.486275434494,6.0886406898498535,65585.63843154907,143312,0,65585.63843154907,0.5238000154495239,2.134950399398804,10000,72650.63191390038,0.7118749618530273,1.1865516901016235,0.6448000073432922,1.483980417251587,50000 -7096.882295846939,6.138426780700684,66005.99985575676,144228,0,66005.99985575676,0.5264000296592712,2.108189821243286,10000,73116.48621463776,0.6990429759025574,1.2320014238357544,0.6526199579238892,1.4438382387161257,50000 -7144.226463794708,6.186929225921631,66425.925065279,145147,0,66425.925065279,0.5321000218391418,2.10010838508606,10000,73583.85144233704,0.7080078125,1.2053229808807373,0.6554399728775024,1.4488352537155151,50000 -7192.638860940933,6.239060163497925,66846.09749627113,146066,0,66846.09749627113,0.5300000309944153,2.0999159812927246,10000,74052.53604912758,0.7203124761581421,1.154877781867981,0.6532799601554871,1.4429513216018677,50000 -7239.931909561157,6.285334825515747,67266.39178800583,146984,0,67266.39178800583,0.5374000072479248,2.074611186981201,10000,74520.21671199799,0.708789050579071,1.1932083368301392,0.6591599583625793,1.4202784299850464,50000 -7286.852725744247,6.344372987747192,67686.7302069664,147901,0,67686.7302069664,0.5418000221252441,2.045015811920166,10000,74987.58259272575,0.7182226181030273,1.1583360433578491,0.6647399663925171,1.4021943807601929,50000 -7334.890476465225,6.396288871765137,68107.07111549377,148818,0,68107.07111549377,0.5412000417709351,2.0354931354522705,10000,75456.06113243103,0.7275585532188416,1.1043344736099243,0.6648199558258057,1.3845757246017456,50000 -7378.8897659778595,6.444957733154297,68527.06919646263,149733,0,68527.06919646263,0.5511000156402588,2.020998001098633,10000,75920.15477275848,0.7219140529632568,1.1348050832748413,0.6680399775505066,1.3798613548278809,50000 -7428.695145845413,6.494152307510376,68947.41572165489,150651,0,68947.41572165489,0.5471000075340271,2.0204763412475586,10000,76390.40303850174,0.7286132574081421,1.118777871131897,0.6726199984550476,1.3689404726028442,50000 -7475.611652612686,6.545987844467163,69367.38053894043,151569,0,69367.38053894043,0.5521000027656555,1.9848430156707764,10000,76857.38411259651,0.7353515625,1.0825318098068235,0.6772199869155884,1.343247890472412,50000 -7522.364356279373,6.606665372848511,69787.35775566101,152486,0,69787.35775566101,0.5507000088691711,1.9955780506134035,10000,77324.2224123478,0.7512304782867432,1.0053696632385254,0.6732400059700012,1.3476532697677612,50000 -7569.859260797501,6.660884857177734,70207.6869328022,153405,0,70207.6869328022,0.54830002784729,1.9810631275177,10000,77792.1483130455,0.7347655892372131,1.0813597440719604,0.6788199543952942,1.333372950553894,50000 -7617.13103890419,6.713019609451294,70627.66973996162,154324,0,70627.66973996162,0.5578000545501709,1.96280300617218,10000,78259.50230240822,0.7437695264816284,1.0343624353408811,0.6834799647331238,1.3101791143417358,50000 -7664.961159944534,6.767343282699585,71047.77196288109,155243,0,71047.77196288109,0.554900050163269,1.97342312335968,10000,78727.53687787056,0.7542773485183716,1.0230635404586792,0.6821399927139282,1.3291398286819458,50000 -7709.447716712952,6.819774866104126,71467.86416912079,156162,0,71467.86416912079,0.5647000074386597,1.930980205535889,10000,79192.2151761055,0.7413867115974426,1.0416297912597656,0.6865999698638916,1.2964109182357788,50000 -7756.980162143707,6.879199504852295,71887.97418117523,157080,0,71887.97418117523,0.5675000548362732,1.918201804161072,10000,79659.96469688416,0.7487695217132568,1.0085744857788086,0.6894800066947937,1.2726929187774658,50000 -7801.326666116714,6.932545900344849,72307.93619275093,157995,0,72307.93619275093,0.5671000480651855,1.9162464141845703,10000,80124.37371206284,0.7559570074081421,0.9801447987556458,0.6914199590682983,1.2745487689971924,50000 -7847.6235818862915,6.992994785308838,72728.1712462902,158915,0,72728.1712462902,0.569100022315979,1.889735460281372,10000,80591.012966156,0.7525194883346558,1.00764000415802,0.6936399936676025,1.2615045309066772,50000 -7896.201631784439,7.052062034606934,73148.25042939186,159832,0,73148.25042939186,0.5731000304222107,1.876285195350647,10000,81059.77536559105,0.7587695121765137,0.973080575466156,0.6987999677658081,1.2378002405166626,50000 -7941.881090164185,7.1026670932769775,73568.35084533691,160750,0,73568.35084533691,0.5751000046730042,1.8870264291763303,10000,81525.65334272385,0.7638476490974426,0.9550341367721558,0.7001399993896484,1.2401357889175415,50000 -7986.558121442795,7.155567407608032,73988.73741221428,161669,0,73988.73741221428,0.5799000263214111,1.860588908195496,10000,81990.81713628769,0.7666015625,0.9483379125595092,0.7057600021362305,1.2242356538772583,50000 -8036.057965755463,7.209298849105835,74408.8857395649,162587,0,74408.8857395649,0.5809000134468079,1.8519184589385984,10000,82460.56594634056,0.7708789110183716,0.9288029074668884,0.7047399878501892,1.2123870849609375,50000 -8081.126464605331,7.264871120452881,74829.21450185776,163508,0,74829.21450185776,0.5873000025749207,1.826006293296814,10000,82926.06588101387,0.7753124833106995,0.9027910232543944,0.7118799686431885,1.1837961673736572,50000 -8128.030569314957,7.324113607406616,75249.25192141533,164424,0,75249.25192141533,0.5870000123977661,1.8180652856826784,10000,83393.11365270615,0.7840234041213989,0.861545979976654,0.7120199799537659,1.1774080991744995,50000 -8175.50325012207,7.37749457359314,75669.19865012169,165344,0,75669.19865012169,0.5918000340461731,1.806179404258728,10000,83860.63411188126,0.78076171875,0.8917228579521179,0.7150999903678894,1.1734659671783447,50000 -8220.088045358658,7.440832138061523,76089.25742650032,166265,0,76089.25742650032,0.5913000106811523,1.798190951347351,10000,84325.38896870613,0.7815039157867432,0.8719916939735413,0.7160800099372864,1.1593176126480105,50000 -8266.85658288002,7.495449066162109,76509.34638428688,167181,0,76509.34638428688,0.5931000113487244,1.7854821681976318,10000,84792.34748697281,0.789843738079071,0.8536946177482605,0.7184199690818787,1.1556397676467896,50000 -8314.475160121918,7.552535772323608,76929.57842612267,168099,0,76929.57842612267,0.5954000353813171,1.7669183015823364,10000,85260.30331659317,0.7865429520606995,0.8519309759140015,0.7182199954986572,1.1389554738998413,50000 -8359.18827176094,7.6130571365356445,77349.57205319405,169019,0,77349.57205319405,0.5969000458717346,1.7649180889129639,10000,85725.11896824837,0.7899804711341858,0.8419513702392578,0.7230599522590637,1.1312774419784546,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/measurements.csv deleted file mode 100644 index 2e34e1b31..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1881 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36472693,6.907756,,,,,,,,,,,,,, -1,,,0.0010546874254941,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,36.79507851600647,64.5379867553711,36.79507851600647,27.742810249328613,0.0,0.0 -100,0.5627554,6.8144197,,,,,,,,,,,,,, -200,0.705497,6.8115516,,,,,,,,,,,,,, -300,1.3117896,6.5412836,,,,,,,,,,,,,, -400,1.5205872,6.4915724,,,,,,,,,,,,,, -500,0.7366597,6.606179,,,,,,,,,,,,,, -600,1.0433236,6.368125,,,,,,,,,,,,,, -700,0.8857277,6.3438463,,,,,,,,,,,,,, -800,1.0854686,6.2481647,,,,,,,,,,,,,, -857,,,0.0340429693460464,5.877542018890381,0.0329199992120265,5.916387557983398,50000.0,0.02730000205338,6.052794933319092,10000.0,456.8279526233673,525.8108751773834,456.8279526233673,68.91855096817017,0.0192360877990722,0.0 -900,0.76999956,6.17892,,,,,,,,,,,,,, -1000,0.72942793,6.234711,,,,,,,,,,,,,, -1100,1.2184175,6.553703,,,,,,,,,,,,,, -1200,0.64025897,6.0244446,,,,,,,,,,,,,, -1300,0.5693675,6.544106,,,,,,,,,,,,,, -1400,0.5923701,6.0648427,,,,,,,,,,,,,, -1500,0.5805814,6.5796432,,,,,,,,,,,,,, -1600,0.56197804,5.778795,,,,,,,,,,,,,, -1700,0.73778707,5.955011,,,,,,,,,,,,,, -1774,,,0.0689257830381393,5.4872636795043945,0.062899999320507,5.551773071289063,50000.0,0.0483000017702579,5.741621971130371,10000.0,876.8367249965668,984.6733191013336,876.8367249965668,107.68769526481628,0.0563008785247802,0.0 -1800,0.662715,5.9027395,,,,,,,,,,,,,, -1900,0.64274263,5.831435,,,,,,,,,,,,,, -2000,0.64286804,6.327641,,,,,,,,,,,,,, -2100,0.4822713,5.688785,,,,,,,,,,,,,, -2200,0.66541064,5.7360897,,,,,,,,,,,,,, -2300,0.52302307,5.686928,,,,,,,,,,,,,, -2400,0.5734452,5.6023664,,,,,,,,,,,,,, -2500,0.48403943,6.4255595,,,,,,,,,,,,,, -2600,0.57498187,5.5601063,,,,,,,,,,,,,, -2694,,,0.1042382791638374,4.991183280944824,0.0974399968981742,5.053380012512207,50000.0,0.0746000036597251,5.3333587646484375,10000.0,1296.802814245224,1442.7371714115145,1296.802814245224,145.71309757232666,0.0808191299438476,0.0 -2700,0.36949357,5.9788685,,,,,,,,,,,,,, -2800,0.5815197,5.6011095,,,,,,,,,,,,,, -2900,0.57039315,5.493022,,,,,,,,,,,,,, -3000,0.4539903,6.1520705,,,,,,,,,,,,,, -3100,0.7026512,5.5276694,,,,,,,,,,,,,, -3200,0.5565749,6.4908304,,,,,,,,,,,,,, -3300,0.5400921,5.38653,,,,,,,,,,,,,, -3400,0.82941794,5.542297,,,,,,,,,,,,,, -3500,0.4914247,5.344983,,,,,,,,,,,,,, -3600,0.47830382,5.404106,,,,,,,,,,,,,, -3612,,,0.1368554681539535,4.678333759307861,0.1284399926662445,4.747917175292969,50000.0,0.0956000015139579,5.0893425941467285,10000.0,1716.8630304336548,1906.967833995819,1716.8630304336548,189.81067943573,0.1053242683410644,0.0 -3700,0.42733645,5.831993,,,,,,,,,,,,,, -3800,0.546798,5.5682487,,,,,,,,,,,,,, -3900,0.77705693,5.1966906,,,,,,,,,,,,,, -4000,0.5873163,6.146512,,,,,,,,,,,,,, -4100,0.7450739,5.517932,,,,,,,,,,,,,, -4200,0.633565,5.1245785,,,,,,,,,,,,,, -4300,0.55589515,5.1049266,,,,,,,,,,,,,, -4400,0.5585599,5.915453,,,,,,,,,,,,,, -4500,0.7168802,5.052782,,,,,,,,,,,,,, -4531,,,0.1862304657697677,4.207225322723389,0.170319989323616,4.318536758422852,50000.0,0.1284000128507614,4.720849514007568,10000.0,2136.9159696102142,2373.9058408737183,2136.9159696102142,236.6235709190369,0.1309516429901123,0.0 -4600,0.49633306,6.415806,,,,,,,,,,,,,, -4700,0.70446503,5.1383233,,,,,,,,,,,,,, -4800,0.8774619,5.42448,,,,,,,,,,,,,, -4900,0.6492806,6.348214,,,,,,,,,,,,,, -5000,0.59557146,6.05902,,,,,,,,,,,,,, -5100,0.7448314,4.869713,,,,,,,,,,,,,, -5200,0.5489806,5.2218757,,,,,,,,,,,,,, -5300,0.7165824,4.911006,,,,,,,,,,,,,, -5400,0.70397836,5.635292,,,,,,,,,,,,,, -5447,,,0.201210930943489,4.060230731964111,0.1880399882793426,4.152525424957275,50000.0,0.1448000073432922,4.589404106140137,10000.0,2557.1325867176056,2831.977271318436,2557.1325867176056,274.4062368869781,0.1563773155212402,0.0 -5500,0.8548816,5.119095,,,,,,,,,,,,,, -5600,0.6436126,5.599205,,,,,,,,,,,,,, -5700,0.65167177,5.1634817,,,,,,,,,,,,,, -5800,0.7134519,4.8135295,,,,,,,,,,,,,, -5900,0.5902551,5.386789,,,,,,,,,,,,,, -6000,0.80844194,4.6870356,,,,,,,,,,,,,, -6100,0.9097263,6.3646336,,,,,,,,,,,,,, -6200,0.5987021,5.7158036,,,,,,,,,,,,,, -6300,0.73025006,6.1074486,,,,,,,,,,,,,, -6363,,,0.2166601568460464,3.984548330307007,0.2004199922084808,4.092609405517578,50000.0,0.1531000137329101,4.517242431640625,10000.0,2977.314273118973,3296.628982782364,2977.314273118973,318.80083322525024,0.1846585273742675,0.0 -6400,0.787678,4.7421265,,,,,,,,,,,,,, -6500,0.8103517,6.2327776,,,,,,,,,,,,,, -6600,0.7225499,4.7150517,,,,,,,,,,,,,, -6700,0.73343754,4.6780224,,,,,,,,,,,,,, -6800,0.6481285,6.2236876,,,,,,,,,,,,,, -6900,0.8183423,4.8982625,,,,,,,,,,,,,, -7000,0.72345567,4.772881,,,,,,,,,,,,,, -7100,0.9767349,4.6963015,,,,,,,,,,,,,, -7200,0.8325658,5.9280834,,,,,,,,,,,,,, -7283,,,0.2351367175579071,3.851510524749756,0.2165599912405014,3.9796018600463854,50000.0,0.1666000038385391,4.42499589920044,10000.0,3397.362140417099,3762.9711685180655,3397.362140417099,365.0135953426361,0.2192466259002685,0.0 -7300,0.679044,4.5425663,,,,,,,,,,,,,, -7400,0.77596414,6.228465,,,,,,,,,,,,,, -7500,0.8407165,4.6219697,,,,,,,,,,,,,, -7600,0.8601063,4.5829387,,,,,,,,,,,,,, -7700,0.8569208,4.789126,,,,,,,,,,,,,, -7800,0.7948568,4.646235,,,,,,,,,,,,,, -7900,0.7865219,5.7042704,,,,,,,,,,,,,, -8000,1.0382078,4.714123,,,,,,,,,,,,,, -8100,0.8199321,4.7480135,,,,,,,,,,,,,, -8200,0.73187983,4.8780594,,,,,,,,,,,,,, -8203,,,0.2627343833446502,3.629719734191896,0.2322999984025955,3.829036712646485,50000.0,0.1744000017642974,4.296341419219971,10000.0,3817.682046175003,4229.221621990204,3817.682046175003,410.8689727783203,0.2473421096801757,0.0 -8300,0.88218766,4.644953,,,,,,,,,,,,,, -8400,0.7787541,5.9004364,,,,,,,,,,,,,, -8500,0.80316836,4.9018445,,,,,,,,,,,,,, -8600,1.0304723,4.576035,,,,,,,,,,,,,, -8700,0.7663795,5.7709675,,,,,,,,,,,,,, -8800,0.927043,4.4706674,,,,,,,,,,,,,, -8900,0.6445471,5.740564,,,,,,,,,,,,,, -9000,0.6758703,6.194179,,,,,,,,,,,,,, -9100,0.7953122,6.299217,,,,,,,,,,,,,, -9121,,,0.2595312297344208,3.638303518295288,0.2396799921989441,3.753018617630005,50000.0,0.1869000047445297,4.236430168151856,10000.0,4237.6647737026215,4694.430584430695,4237.6647737026215,456.0230774879456,0.2729377746582031,0.0 -9200,0.84868395,4.380267,,,,,,,,,,,,,, -9300,0.70903516,5.41928,,,,,,,,,,,,,, -9400,0.8567464,4.7835684,,,,,,,,,,,,,, -9500,0.85494834,4.9708114,,,,,,,,,,,,,, -9600,0.87677026,4.482435,,,,,,,,,,,,,, -9700,0.65695983,5.3565745,,,,,,,,,,,,,, -9800,0.8493451,4.3973985,,,,,,,,,,,,,, -9900,0.74658245,4.619902,,,,,,,,,,,,,, -10000,0.77566504,5.6927505,,,,,,,,,,,,,, -10039,,,0.2796484231948852,3.492593288421631,0.2573399841785431,3.628937005996704,50000.0,0.2008000165224075,4.119417667388916,10000.0,4657.653786182404,5157.994423866272,4657.653786182404,499.5223741531372,0.3000564575195312,0.0 -10100,0.7796448,4.3788614,,,,,,,,,,,,,, -10200,0.92356837,4.6826835,,,,,,,,,,,,,, -10300,0.85949427,4.4310436,,,,,,,,,,,,,, -10400,1.0098915,4.2507086,,,,,,,,,,,,,, -10500,0.7909661,4.3293033,,,,,,,,,,,,,, -10600,0.7702766,5.7445354,,,,,,,,,,,,,, -10700,0.8207916,5.9185143,,,,,,,,,,,,,, -10800,0.9778715,4.26232,,,,,,,,,,,,,, -10900,0.84245276,4.7452354,,,,,,,,,,,,,, -10957,,,0.2962499856948852,3.400749683380127,0.2607399821281433,3.646686553955078,50000.0,0.1965000033378601,4.154642581939697,10000.0,5077.807441949844,5618.391860961914,5077.807441949844,539.6884536743164,0.3250484466552734,0.0 -11000,0.8958429,4.582363,,,,,,,,,,,,,, -11100,0.79823446,4.5375166,,,,,,,,,,,,,, -11200,0.6836418,5.5032597,,,,,,,,,,,,,, -11300,0.756907,4.267768,,,,,,,,,,,,,, -11400,0.8934971,5.231907,,,,,,,,,,,,,, -11500,0.8620104,4.606867,,,,,,,,,,,,,, -11600,0.94005805,4.5018625,,,,,,,,,,,,,, -11700,0.7546855,4.4055185,,,,,,,,,,,,,, -11800,0.86195344,4.313469,,,,,,,,,,,,,, -11874,,,0.2930664122104645,3.4094316959381104,0.2763800024986267,3.517740488052368,50000.0,0.2082000076770782,4.0429534912109375,10000.0,5497.79346203804,6077.069966077805,5497.79346203804,578.3044393062592,0.3541853427886963,0.0 -11900,0.9088511,4.295473,,,,,,,,,,,,,, -12000,0.91444,4.7952304,,,,,,,,,,,,,, -12100,1.250755,4.4324226,,,,,,,,,,,,,, -12200,0.85107094,4.4757614,,,,,,,,,,,,,, -12300,0.75118095,6.180607,,,,,,,,,,,,,, -12400,0.8135358,5.3740187,,,,,,,,,,,,,, -12500,0.6727339,5.858624,,,,,,,,,,,,,, -12600,0.865686,4.212356,,,,,,,,,,,,,, -12700,0.8630879,5.4241886,,,,,,,,,,,,,, -12795,,,0.2978906035423279,3.414585590362549,0.2738800048828125,3.54900860786438,50000.0,0.2123000174760818,4.034548282623291,10000.0,5918.157521486282,6538.304161548615,5918.157521486282,619.1004593372345,0.3812034130096435,0.0 -12800,0.7862712,4.644681,,,,,,,,,,,,,, -12900,0.73163795,5.91851,,,,,,,,,,,,,, -13000,0.67413044,5.4033194,,,,,,,,,,,,,, -13100,0.96390885,4.31857,,,,,,,,,,,,,, -13200,0.61237955,6.1701255,,,,,,,,,,,,,, -13300,0.96922666,4.377283,,,,,,,,,,,,,, -13400,0.94993347,4.2960057,,,,,,,,,,,,,, -13500,0.9612965,4.2411656,,,,,,,,,,,,,, -13600,0.86776775,4.2405257,,,,,,,,,,,,,, -13700,0.95252466,4.213069,,,,,,,,,,,,,, -13715,,,0.3108203113079071,3.344714641571045,0.2833400070667267,3.511706829071045,50000.0,0.2126000076532364,4.019969940185547,10000.0,6338.3181848526,7003.260776519775,6338.3181848526,663.8211107254028,0.4091596603393554,0.0 -13800,0.86219186,5.8277273,,,,,,,,,,,,,, -13900,1.1395388,4.2360806,,,,,,,,,,,,,, -14000,0.61945325,6.1177187,,,,,,,,,,,,,, -14100,0.8369025,4.248911,,,,,,,,,,,,,, -14200,0.818315,4.2027016,,,,,,,,,,,,,, -14300,0.849439,4.250916,,,,,,,,,,,,,, -14400,0.7699525,4.2954736,,,,,,,,,,,,,, -14500,0.8099179,4.1212964,,,,,,,,,,,,,, -14600,0.9214022,4.4954443,,,,,,,,,,,,,, -14635,,,0.3113867044448852,3.300966024398804,0.2888199985027313,3.432950496673584,50000.0,0.2205000072717666,3.981424808502197,10000.0,6758.649676322937,7464.368344545364,6758.649676322937,704.5228517055511,0.435704231262207,0.0 -14700,0.7811176,5.801953,,,,,,,,,,,,,, -14800,0.87573236,5.714575,,,,,,,,,,,,,, -14900,0.89975816,5.0546823,,,,,,,,,,,,,, -15000,0.68249345,5.218642,,,,,,,,,,,,,, -15100,0.8895592,4.172224,,,,,,,,,,,,,, -15200,0.7304497,6.0834465,,,,,,,,,,,,,, -15300,0.97608215,4.225033,,,,,,,,,,,,,, -15400,1.0199349,4.508821,,,,,,,,,,,,,, -15500,0.8274332,4.286849,,,,,,,,,,,,,, -15553,,,0.322558581829071,3.2005932331085205,0.3014000058174133,3.3318538665771484,50000.0,0.2326000183820724,3.875176668167114,10000.0,7178.820302963257,7929.172488689423,7178.820302963257,749.0826478004456,0.462766170501709,0.0 -15600,0.7877606,5.271439,,,,,,,,,,,,,, -15700,0.89070064,4.137834,,,,,,,,,,,,,, -15800,0.89507425,4.117602,,,,,,,,,,,,,, -15900,0.9790036,4.2177515,,,,,,,,,,,,,, -16000,0.93099487,4.1321907,,,,,,,,,,,,,, -16100,0.7698395,5.796771,,,,,,,,,,,,,, -16200,0.7086699,5.994072,,,,,,,,,,,,,, -16300,0.82344234,4.605895,,,,,,,,,,,,,, -16400,0.95552963,5.519059,,,,,,,,,,,,,, -16472,,,0.3410351574420929,3.152328491210937,0.310259997844696,3.32299542427063,50000.0,0.2360000163316726,3.8591465950012207,10000.0,7599.0502026081085,8389.891204357147,7599.0502026081085,789.4964265823364,0.4894022941589355,0.0 -16500,1.1460671,4.0632787,,,,,,,,,,,,,, -16600,0.9739134,4.111531,,,,,,,,,,,,,, -16700,0.94823974,4.3290315,,,,,,,,,,,,,, -16800,0.66152316,5.720789,,,,,,,,,,,,,, -16900,0.9154351,4.065495,,,,,,,,,,,,,, -17000,0.9757151,4.183246,,,,,,,,,,,,,, -17100,0.83378565,4.260674,,,,,,,,,,,,,, -17200,0.9879334,4.3172293,,,,,,,,,,,,,, -17300,0.78278655,4.673955,,,,,,,,,,,,,, -17392,,,0.3284960985183716,3.2336535453796387,0.2993399798870086,3.389450788497925,50000.0,0.2270000129938125,3.917538166046143,10000.0,8019.264442682266,8849.203982591629,8019.264442682266,828.5178320407867,0.5192873477935791,0.0 -17400,0.7851051,5.268024,,,,,,,,,,,,,, -17500,0.7246439,5.965317,,,,,,,,,,,,,, -17600,1.0551589,4.256782,,,,,,,,,,,,,, -17700,0.85694575,3.9918509,,,,,,,,,,,,,, -17800,0.90839505,4.355256,,,,,,,,,,,,,, -17900,0.81230485,5.065569,,,,,,,,,,,,,, -18000,0.7306942,4.767064,,,,,,,,,,,,,, -18100,0.8513537,5.311862,,,,,,,,,,,,,, -18200,0.9177392,4.722719,,,,,,,,,,,,,, -18300,0.921483,4.281657,,,,,,,,,,,,,, -18310,,,0.3080468773841858,3.3751299381256104,0.2909199893474579,3.5047309398651123,50000.0,0.2187000066041946,4.029071807861328,10000.0,8439.656270503998,9312.967787742617,8439.656270503998,871.812756061554,0.5486938953399658,0.0 -18400,0.8118628,6.1759977,,,,,,,,,,,,,, -18500,0.98051125,4.334296,,,,,,,,,,,,,, -18600,0.7629209,4.2672815,,,,,,,,,,,,,, -18700,1.0438411,4.082855,,,,,,,,,,,,,, -18800,1.1935805,4.2395296,,,,,,,,,,,,,, -18900,0.89592767,4.2483454,,,,,,,,,,,,,, -19000,0.9781894,4.043281,,,,,,,,,,,,,, -19100,0.7953795,4.2250795,,,,,,,,,,,,,, -19200,0.92001635,5.9087315,,,,,,,,,,,,,, -19230,,,0.3353515565395355,3.115917921066284,0.3131999969482422,3.278340578079224,50000.0,0.2408000081777572,3.847225189208984,10000.0,8859.802419900894,9779.14450263977,8859.802419900894,917.766725063324,0.5775036811828613,0.0 -19300,1.0610132,4.1918635,,,,,,,,,,,,,, -19400,1.0087651,3.9180543,,,,,,,,,,,,,, -19500,0.85371804,4.0341606,,,,,,,,,,,,,, -19600,0.7543545,5.6857214,,,,,,,,,,,,,, -19700,0.9049474,4.04372,,,,,,,,,,,,,, -19800,0.8402589,4.1631236,,,,,,,,,,,,,, -19900,0.7185415,5.676026,,,,,,,,,,,,,, -20000,0.8887098,6.053211,,,,,,,,,,,,,, -20100,1.0030468,4.1809454,,,,,,,,,,,,,, -20149,,,0.381640613079071,2.8914709091186523,0.3214599788188934,3.214576482772827,50000.0,0.2436000108718872,3.796807527542114,10000.0,9280.090950250626,10241.47033715248,9280.090950250626,959.7255573272704,0.6077971458435059,0.0 -20200,0.83115363,5.083151,,,,,,,,,,,,,, -20300,0.80103564,4.560795,,,,,,,,,,,,,, -20400,0.8538795,4.141664,,,,,,,,,,,,,, -20500,0.90696883,4.5818777,,,,,,,,,,,,,, -20600,0.70138186,5.50656,,,,,,,,,,,,,, -20700,1.2072525,4.1095586,,,,,,,,,,,,,, -20800,1.0978657,4.1331425,,,,,,,,,,,,,, -20900,0.76406586,5.1567473,,,,,,,,,,,,,, -21000,0.9575952,4.01519,,,,,,,,,,,,,, -21070,,,0.3535351455211639,2.995561361312866,0.3300800025463104,3.140366554260254,50000.0,0.2556000053882599,3.724861860275269,10000.0,9700.320873737335,10706.340615034103,9700.320873737335,1004.2897145748138,0.6361017227172852,0.0 -21100,1.5796969,3.9828868,,,,,,,,,,,,,, -21200,0.97329515,4.029385,,,,,,,,,,,,,, -21300,1.0053967,4.0314507,,,,,,,,,,,,,, -21400,0.71508473,5.633813,,,,,,,,,,,,,, -21500,0.8033204,4.0165415,,,,,,,,,,,,,, -21600,0.71688014,5.5789504,,,,,,,,,,,,,, -21700,1.0834924,4.103032,,,,,,,,,,,,,, -21800,0.9616671,4.2354517,,,,,,,,,,,,,, -21900,1.0114502,4.3168488,,,,,,,,,,,,,, -21990,,,0.3414843678474426,3.077301025390625,0.3179799914360046,3.2187304496765137,50000.0,0.2473000138998031,3.81091570854187,10000.0,10120.517220973969,11175.220227003098,10120.517220973969,1052.8985612392426,0.662848949432373,0.0 -22000,1.0439265,4.3310833,,,,,,,,,,,,,, -22100,0.9351377,4.125847,,,,,,,,,,,,,, -22200,0.79857785,4.126152,,,,,,,,,,,,,, -22300,0.86098975,4.3360353,,,,,,,,,,,,,, -22400,0.9084173,4.1115212,,,,,,,,,,,,,, -22500,1.0496715,3.9208477,,,,,,,,,,,,,, -22600,0.77505696,4.988554,,,,,,,,,,,,,, -22700,0.89766204,3.9914734,,,,,,,,,,,,,, -22800,0.8682863,6.099746,,,,,,,,,,,,,, -22900,0.93106276,5.6886477,,,,,,,,,,,,,, -22908,,,0.3515820205211639,3.0414631366729736,0.3215000033378601,3.234511137008667,50000.0,0.2469000071287155,3.816385269165039,10000.0,10540.53809094429,11636.506449222565,10540.53809094429,1094.0847754478457,0.6952741146087646,0.0 -23000,0.80502814,5.187289,,,,,,,,,,,,,, -23100,1.0556211,4.3713865,,,,,,,,,,,,,, -23200,0.8477377,4.6118593,,,,,,,,,,,,,, -23300,0.9752779,4.032173,,,,,,,,,,,,,, -23400,0.6393005,5.81252,,,,,,,,,,,,,, -23500,0.77511674,4.17034,,,,,,,,,,,,,, -23600,1.1363388,3.847182,,,,,,,,,,,,,, -23700,0.88331115,4.024164,,,,,,,,,,,,,, -23800,0.86323285,4.1313715,,,,,,,,,,,,,, -23825,,,0.3559374809265136,3.025038719177246,0.3379800021648407,3.137194871902466,50000.0,0.2626000046730041,3.7027549743652335,10000.0,10960.580304384232,12104.379473924637,10960.580304384232,1141.8376290798187,0.7260003089904785,0.0 -23900,0.9334689,4.7965717,,,,,,,,,,,,,, -24000,1.313296,4.254182,,,,,,,,,,,,,, -24100,1.0792427,4.170093,,,,,,,,,,,,,, -24200,0.9359988,3.9798658,,,,,,,,,,,,,, -24300,0.8215899,4.598054,,,,,,,,,,,,,, -24400,0.8997999,5.0082483,,,,,,,,,,,,,, -24500,0.9013758,4.4096575,,,,,,,,,,,,,, -24600,0.9324969,4.4660378,,,,,,,,,,,,,, -24700,1.0710077,4.1135006,,,,,,,,,,,,,, -24746,,,0.3567968606948852,3.0097639560699463,0.3316799998283386,3.17022705078125,50000.0,0.2560999989509582,3.7653188705444336,10000.0,11380.95326280594,12566.690264701843,11380.95326280594,1183.7002770900726,0.754091739654541,0.0 -24800,1.1569601,3.904915,,,,,,,,,,,,,, -24900,0.7334498,5.3369837,,,,,,,,,,,,,, -25000,1.0138489,3.9318042,,,,,,,,,,,,,, -25100,0.625544,5.585371,,,,,,,,,,,,,, -25200,0.91386807,4.1740265,,,,,,,,,,,,,, -25300,0.79624987,4.3997893,,,,,,,,,,,,,, -25400,1.1084256,3.977674,,,,,,,,,,,,,, -25500,0.9024623,4.117313,,,,,,,,,,,,,, -25600,0.8724909,4.021941,,,,,,,,,,,,,, -25665,,,0.3615429699420929,3.016035795211792,0.3340199887752533,3.1701815128326416,50000.0,0.2555000185966491,3.74962329864502,10000.0,11801.243542194366,13031.799020767212,11801.243542194366,1228.4372079372406,0.7881994247436523,0.0 -25700,0.82780963,4.921285,,,,,,,,,,,,,, -25800,0.9467499,6.19148,,,,,,,,,,,,,, -25900,0.86208206,4.107219,,,,,,,,,,,,,, -26000,0.8792497,4.4405966,,,,,,,,,,,,,, -26100,0.93772304,4.3029423,,,,,,,,,,,,,, -26200,0.991531,4.0540276,,,,,,,,,,,,,, -26300,0.90344006,3.9204133,,,,,,,,,,,,,, -26400,0.95680153,5.709478,,,,,,,,,,,,,, -26500,0.783138,6.0003533,,,,,,,,,,,,,, -26585,,,0.346992164850235,3.0902798175811768,0.3211599886417389,3.2521159648895264,50000.0,0.247400015592575,3.812144756317138,10000.0,12221.656418085098,13499.184258461,12221.656418085098,1275.3340392112732,0.8167154788970947,0.0 -26600,1.0415733,4.1472993,,,,,,,,,,,,,, -26700,1.1180478,4.0117407,,,,,,,,,,,,,, -26800,0.7934114,6.1048117,,,,,,,,,,,,,, -26900,0.8910432,3.9408407,,,,,,,,,,,,,, -27000,1.348366,4.0520225,,,,,,,,,,,,,, -27100,1.0065176,3.8312325,,,,,,,,,,,,,, -27200,0.9042951,4.3947153,,,,,,,,,,,,,, -27300,0.68779373,4.764461,,,,,,,,,,,,,, -27400,0.90714514,3.8799424,,,,,,,,,,,,,, -27500,0.88814616,4.1267605,,,,,,,,,,,,,, -27501,,,0.3630273342132568,2.980881929397583,0.3391999900341034,3.110235452651977,50000.0,0.2553000152111053,3.7012126445770255,10000.0,12641.995535612106,13958.93049645424,12641.995535612106,1314.660932779312,0.8498325347900391,0.0 -27600,0.6764577,5.7704535,,,,,,,,,,,,,, -27700,1.099651,3.9646804,,,,,,,,,,,,,, -27800,0.82927406,3.92761,,,,,,,,,,,,,, -27900,0.98974615,4.141818,,,,,,,,,,,,,, -28000,1.0411522,4.1638007,,,,,,,,,,,,,, -28100,0.96125156,4.0718813,,,,,,,,,,,,,, -28200,0.84429187,4.1103706,,,,,,,,,,,,,, -28300,0.90604925,4.0479636,,,,,,,,,,,,,, -28400,0.6451227,5.4138813,,,,,,,,,,,,,, -28418,,,0.3716992139816284,2.924288034439087,0.3419399857521057,3.0865392684936523,50000.0,0.2589000165462494,3.6658432483673096,10000.0,13062.072076559069,14417.129843950272,13062.072076559069,1352.705206155777,0.8814163208007812,0.0 -28500,0.89668494,3.9259949,,,,,,,,,,,,,, -28600,0.98495096,4.738299,,,,,,,,,,,,,, -28700,0.9453109,3.9772825,,,,,,,,,,,,,, -28800,0.71047074,5.317533,,,,,,,,,,,,,, -28900,0.9101955,3.9381418,,,,,,,,,,,,,, -29000,0.87672377,5.397587,,,,,,,,,,,,,, -29100,0.946589,4.0835404,,,,,,,,,,,,,, -29200,0.9083304,6.0066724,,,,,,,,,,,,,, -29300,0.9589056,3.8910742,,,,,,,,,,,,,, -29335,,,0.3422265648841858,3.173394203186035,0.316100001335144,3.3170182704925537,50000.0,0.2434000074863433,3.895664930343628,10000.0,13482.143792390823,14880.750834703444,13482.143792390823,1396.1704378128052,0.9180092811584472,0.0 -29400,0.80808544,6.13897,,,,,,,,,,,,,, -29500,1.1181203,3.9009275,,,,,,,,,,,,,, -29600,1.016876,3.9742224,,,,,,,,,,,,,, -29700,1.0055202,4.012881,,,,,,,,,,,,,, -29800,0.9252058,3.8171175,,,,,,,,,,,,,, -29900,0.7439295,5.8208733,,,,,,,,,,,,,, -30000,0.9114922,3.85595,,,,,,,,,,,,,, -30100,0.91160727,3.8249896,,,,,,,,,,,,,, -30200,1.3062435,3.9381433,,,,,,,,,,,,,, -30255,,,0.3800390660762787,2.861891746520996,0.3527399897575378,3.0179009437561035,50000.0,0.2698000073432922,3.616143226623535,10000.0,13902.161823272703,15347.681805610657,13902.161823272703,1443.007826089859,0.9461681842803956,0.0 -30300,1.0361661,4.0307975,,,,,,,,,,,,,, -30400,0.9284324,3.768763,,,,,,,,,,,,,, -30500,0.882951,4.4799194,,,,,,,,,,,,,, -30600,1.0635197,3.8532844,,,,,,,,,,,,,, -30700,1.0214041,4.050052,,,,,,,,,,,,,, -30800,1.0401452,3.9893832,,,,,,,,,,,,,, -30900,1.0273609,4.0222864,,,,,,,,,,,,,, -31000,0.8370606,5.8040295,,,,,,,,,,,,,, -31100,0.73886865,5.832781,,,,,,,,,,,,,, -31176,,,0.3661523461341858,2.9577431678771973,0.3359200060367584,3.120572328567505,50000.0,0.2560999989509582,3.68998384475708,10000.0,14322.417489528656,15807.286099910736,14322.417489528656,1482.2746975421906,0.9801223278045654,0.0 -31200,0.82235324,5.413018,,,,,,,,,,,,,, -31300,0.9762372,3.961757,,,,,,,,,,,,,, -31400,0.95530015,4.967515,,,,,,,,,,,,,, -31500,0.92139184,3.885917,,,,,,,,,,,,,, -31600,1.0267968,3.9506025,,,,,,,,,,,,,, -31700,0.6781235,6.0042963,,,,,,,,,,,,,, -31800,0.9537951,3.8731472,,,,,,,,,,,,,, -31900,0.62373036,5.828975,,,,,,,,,,,,,, -32000,1.0691022,4.5562315,,,,,,,,,,,,,, -32094,,,0.4045312404632568,2.735917806625366,0.3504199981689453,3.0416879653930664,50000.0,0.2699000239372253,3.6354501247406006,10000.0,14742.738109588625,16267.7527885437,14742.738109588625,1522.337045431137,1.0161523818969729,0.0 -32100,0.89936924,4.0561295,,,,,,,,,,,,,, -32200,1.264031,4.0655127,,,,,,,,,,,,,, -32300,0.7735973,5.107916,,,,,,,,,,,,,, -32400,0.9058319,4.117985,,,,,,,,,,,,,, -32500,1.0081489,4.5169287,,,,,,,,,,,,,, -32600,1.2412056,4.148976,,,,,,,,,,,,,, -32700,0.78390485,4.4695387,,,,,,,,,,,,,, -32800,0.74794817,5.071909,,,,,,,,,,,,,, -32900,0.7501766,4.957224,,,,,,,,,,,,,, -33000,0.9719838,4.0034847,,,,,,,,,,,,,, -33013,,,0.371406227350235,2.9292590618133545,0.3455399870872497,3.0680296421051025,50000.0,0.2725000083446502,3.637152194976807,10000.0,15162.715245962145,16733.630709409714,15162.715245962145,1568.1595079898834,1.0463981628417969,0.0 -33100,0.81954163,4.42654,,,,,,,,,,,,,, -33200,1.1771314,3.8682444,,,,,,,,,,,,,, -33300,0.8542427,5.902512,,,,,,,,,,,,,, -33400,0.87081885,4.471984,,,,,,,,,,,,,, -33500,0.9837961,3.992925,,,,,,,,,,,,,, -33600,1.0971104,4.0053434,,,,,,,,,,,,,, -33700,0.9251642,3.8378053,,,,,,,,,,,,,, -33800,0.82427466,3.8850193,,,,,,,,,,,,,, -33900,0.98882014,4.347524,,,,,,,,,,,,,, -33933,,,0.3720703125,2.9328975677490234,0.3447999954223633,3.0781383514404297,50000.0,0.2675000131130218,3.668169736862183,10000.0,15583.04086136818,17196.518713235855,15583.04086136818,1610.6403777599337,1.0798285007476809,0.0 -34000,1.0040478,3.7583735,,,,,,,,,,,,,, -34100,0.7600476,4.619526,,,,,,,,,,,,,, -34200,0.825265,5.785304,,,,,,,,,,,,,, -34300,0.8673674,4.1122713,,,,,,,,,,,,,, -34400,0.8640519,4.4085155,,,,,,,,,,,,,, -34500,0.93103915,3.9608684,,,,,,,,,,,,,, -34600,0.78647685,4.9396834,,,,,,,,,,,,,, -34700,1.2790977,3.9846969,,,,,,,,,,,,,, -34800,0.85742146,5.5770082,,,,,,,,,,,,,, -34852,,,0.3985156118869781,2.755237579345703,0.3614400029182434,2.958653926849365,50000.0,0.270900011062622,3.5934412479400635,10000.0,16003.265714883804,17659.312923192978,16003.265714883804,1653.1295936107635,1.1132240295410156,0.0 -34900,0.8909288,3.9381676,,,,,,,,,,,,,, -35000,0.988185,4.271842,,,,,,,,,,,,,, -35100,0.9022391,3.8416708,,,,,,,,,,,,,, -35200,0.85235953,3.8956258,,,,,,,,,,,,,, -35300,1.0184422,3.8438396,,,,,,,,,,,,,, -35400,0.7104371,5.6210527,,,,,,,,,,,,,, -35500,1.219969,3.9924896,,,,,,,,,,,,,, -35600,1.023937,4.7654986,,,,,,,,,,,,,, -35700,0.968798,3.9648364,,,,,,,,,,,,,, -35772,,,0.3848632872104645,2.830144166946411,0.3606799840927124,2.9782936573028564,50000.0,0.2740000188350677,3.591972827911377,10000.0,16423.598749876022,18126.9892642498,16423.598749876022,1700.3952662944794,1.143989086151123,0.0 -35800,0.96045715,3.8347366,,,,,,,,,,,,,, -35900,1.0984205,3.8877306,,,,,,,,,,,,,, -36000,0.97437775,3.7956471,,,,,,,,,,,,,, -36100,0.87378305,4.033334,,,,,,,,,,,,,, -36200,0.64645433,5.665887,,,,,,,,,,,,,, -36300,1.0199587,3.9320362,,,,,,,,,,,,,, -36400,0.7723046,5.742717,,,,,,,,,,,,,, -36500,1.1817435,3.8753116,,,,,,,,,,,,,, -36600,1.0314116,3.8893528,,,,,,,,,,,,,, -36693,,,0.3790820240974426,2.873992919921875,0.3569999933242798,3.0040409564971924,50000.0,0.2692000269889831,3.625471591949463,10000.0,16843.694502830505,18591.12539958954,16843.694502830505,1744.3594100475311,1.1727678775787354,0.0 -36700,0.9858932,3.6429603,,,,,,,,,,,,,, -36800,0.97440696,3.761557,,,,,,,,,,,,,, -36900,0.9635996,4.131099,,,,,,,,,,,,,, -37000,0.92421925,3.7099206,,,,,,,,,,,,,, -37100,0.9698837,3.9943278,,,,,,,,,,,,,, -37200,0.9251245,3.9402926,,,,,,,,,,,,,, -37300,0.96317214,3.884214,,,,,,,,,,,,,, -37400,1.1766399,3.92319,,,,,,,,,,,,,, -37500,0.9533266,4.236629,,,,,,,,,,,,,, -37600,0.7261055,6.0631557,,,,,,,,,,,,,, -37613,,,0.3933398425579071,2.799267768859864,0.3593399822711944,2.989340543746948,50000.0,0.2735000252723694,3.5844669342041016,10000.0,17263.87330675125,19058.90007901192,17263.87330675125,1791.871794462204,1.2091057300567627,0.0 -37700,0.9190773,4.2462378,,,,,,,,,,,,,, -37800,1.1545993,3.8106797,,,,,,,,,,,,,, -37900,0.9666649,3.758467,,,,,,,,,,,,,, -38000,1.1116811,4.2515187,,,,,,,,,,,,,, -38100,1.0384053,3.8380024,,,,,,,,,,,,,, -38200,0.887908,4.39284,,,,,,,,,,,,,, -38300,1.0711073,3.7889135,,,,,,,,,,,,,, -38400,0.7799798,4.3617496,,,,,,,,,,,,,, -38500,0.76105106,4.9681106,,,,,,,,,,,,,, -38530,,,0.3886132836341858,2.7967445850372314,0.3657599985599518,2.934712648391724,50000.0,0.2800000011920929,3.546329498291016,10000.0,17683.964739084244,19521.694784641262,17683.964739084244,1834.498204946518,1.2384934425354004,0.0 -38600,1.2484945,3.903127,,,,,,,,,,,,,, -38700,1.0489326,4.7017007,,,,,,,,,,,,,, -38800,1.2488674,3.848908,,,,,,,,,,,,,, -38900,1.0352756,3.7580633,,,,,,,,,,,,,, -39000,0.95051247,4.79413,,,,,,,,,,,,,, -39100,1.2448902,3.7612076,,,,,,,,,,,,,, -39200,0.9705769,3.8201973,,,,,,,,,,,,,, -39300,0.94721115,3.7040782,,,,,,,,,,,,,, -39400,0.78449345,5.062535,,,,,,,,,,,,,, -39451,,,0.3868554532527923,2.8592474460601807,0.3622399866580963,2.997967004776001,50000.0,0.2826000154018402,3.571634531021118,10000.0,18104.35695052147,19986.68339204788,18104.35695052147,1879.0127630233765,1.272782802581787,0.0 -39500,0.8192078,5.7545905,,,,,,,,,,,,,, -39600,0.9853301,3.9906642,,,,,,,,,,,,,, -39700,0.95242935,3.9566216,,,,,,,,,,,,,, -39800,0.95862514,3.7537775,,,,,,,,,,,,,, -39900,0.8898779,5.765568,,,,,,,,,,,,,, -40000,1.202882,4.412466,,,,,,,,,,,,,, -40100,1.0024586,3.836806,,,,,,,,,,,,,, -40200,0.89213365,4.206555,,,,,,,,,,,,,, -40300,1.114091,3.7855997,,,,,,,,,,,,,, -40369,,,0.4025976359844208,2.729934930801392,0.3704600036144256,2.898860216140747,50000.0,0.2865000069141388,3.507483720779419,10000.0,18524.528084754944,20454.92676472664,18524.528084754944,1927.0022106170647,1.308027267456055,0.0 -40400,1.0916965,3.94461,,,,,,,,,,,,,, -40500,0.7492921,5.7247562,,,,,,,,,,,,,, -40600,0.8820851,3.7829626,,,,,,,,,,,,,, -40700,1.0023859,3.9323978,,,,,,,,,,,,,, -40800,1.0209397,4.0196395,,,,,,,,,,,,,, -40900,0.8961637,5.08835,,,,,,,,,,,,,, -41000,0.98522586,3.6496475,,,,,,,,,,,,,, -41100,0.9660931,3.6186378,,,,,,,,,,,,,, -41200,0.7217014,4.7190757,,,,,,,,,,,,,, -41288,,,0.4137109220027923,2.72369384765625,0.3606199920177459,2.9944801330566406,50000.0,0.2812000215053558,3.5632994174957275,10000.0,18944.64669013024,20917.34801101685,18944.64669013024,1969.225081205368,1.340603590011597,0.0 -41300,1.1515936,3.774344,,,,,,,,,,,,,, -41400,1.1580288,3.8994627,,,,,,,,,,,,,, -41500,1.0585678,3.7257001,,,,,,,,,,,,,, -41600,1.0356592,3.8176003,,,,,,,,,,,,,, -41700,0.8066666,5.673546,,,,,,,,,,,,,, -41800,0.9435803,3.7486434,,,,,,,,,,,,,, -41900,0.8890598,5.354223,,,,,,,,,,,,,, -42000,0.9313665,4.852465,,,,,,,,,,,,,, -42100,1.0841138,4.202253,,,,,,,,,,,,,, -42200,0.9792198,4.001619,,,,,,,,,,,,,, -42206,,,0.388964831829071,2.8523201942443848,0.367279976606369,2.9817123413085938,50000.0,0.2829000055789947,3.547908067703247,10000.0,19364.88948059082,21376.21737074852,19364.88948059082,2007.7696409225464,1.3750572204589844,0.0 -42300,1.0393802,4.0851,,,,,,,,,,,,,, -42400,1.0116808,4.139707,,,,,,,,,,,,,, -42500,0.8975458,4.559026,,,,,,,,,,,,,, -42600,1.018269,3.6719494,,,,,,,,,,,,,, -42700,1.291788,4.099411,,,,,,,,,,,,,, -42800,1.0364747,3.9000807,,,,,,,,,,,,,, -42900,1.1042765,3.7432547,,,,,,,,,,,,,, -43000,1.1771703,3.8772855,,,,,,,,,,,,,, -43100,0.98486125,3.7781692,,,,,,,,,,,,,, -43125,,,0.4010742008686065,2.7512407302856445,0.3735399842262268,2.8998680114746094,50000.0,0.2886000275611877,3.5031073093414307,10000.0,19785.169471025467,21842.482117176056,19785.169471025467,2053.67297244072,1.4091379642486572,0.0 -43200,0.99767286,3.732444,,,,,,,,,,,,,, -43300,1.0599562,3.7994885,,,,,,,,,,,,,, -43400,0.85744816,4.486918,,,,,,,,,,,,,, -43500,0.95651263,3.6667783,,,,,,,,,,,,,, -43600,0.7850527,5.7869673,,,,,,,,,,,,,, -43700,1.0301223,3.9139783,,,,,,,,,,,,,, -43800,1.047926,3.8823774,,,,,,,,,,,,,, -43900,0.9145678,5.9447494,,,,,,,,,,,,,, -44000,1.1170586,3.7948308,,,,,,,,,,,,,, -44045,,,0.429003894329071,2.601717233657837,0.3832799792289734,2.852113723754883,50000.0,0.2928000092506408,3.449684858322144,10000.0,20205.232219696045,22307.98716187477,20205.232219696045,2099.034994125366,1.4414191246032717,0.0 -44100,1.3569176,3.8064482,,,,,,,,,,,,,, -44200,1.0695548,3.6944754,,,,,,,,,,,,,, -44300,0.77626175,4.940068,,,,,,,,,,,,,, -44400,0.8367157,4.7845163,,,,,,,,,,,,,, -44500,1.0601954,3.676112,,,,,,,,,,,,,, -44600,1.1208131,3.8202038,,,,,,,,,,,,,, -44700,0.812089,4.8245177,,,,,,,,,,,,,, -44800,0.9594854,3.562063,,,,,,,,,,,,,, -44900,0.99670935,3.9259534,,,,,,,,,,,,,, -44962,,,0.4001562297344208,2.801137208938598,0.3734200000762939,2.930062055587769,50000.0,0.2863000035285949,3.5115420818328857,10000.0,20625.31547307968,22773.348356485367,20625.31547307968,2144.228601694107,1.4792120456695557,0.0 -45000,0.85805863,5.8993483,,,,,,,,,,,,,, -45100,0.99038774,3.6667583,,,,,,,,,,,,,, -45200,0.68417996,5.9050393,,,,,,,,,,,,,, -45300,1.0495931,3.9893217,,,,,,,,,,,,,, -45400,1.2782849,3.87121,,,,,,,,,,,,,, -45500,0.9749847,4.5739174,,,,,,,,,,,,,, -45600,0.78344697,4.405575,,,,,,,,,,,,,, -45700,1.0198561,3.952795,,,,,,,,,,,,,, -45800,1.1251559,3.8557956,,,,,,,,,,,,,, -45882,,,0.2884374856948852,3.4999566078186035,0.2685799896717071,3.610232830047608,50000.0,0.2107000052928924,4.131436347961426,10000.0,21045.717796325684,23240.907628774643,21045.717796325684,2191.300128698349,1.5172581672668457,0.0 -45900,1.0122895,4.250681,,,,,,,,,,,,,, -46000,1.0175384,3.9376712,,,,,,,,,,,,,, -46100,0.9021627,3.7534819,,,,,,,,,,,,,, -46200,0.88943505,4.1496305,,,,,,,,,,,,,, -46300,0.9791887,3.9015927,,,,,,,,,,,,,, -46400,0.8207132,5.853152,,,,,,,,,,,,,, -46500,0.8691825,3.733623,,,,,,,,,,,,,, -46600,1.2825663,3.8108077,,,,,,,,,,,,,, -46700,1.0654815,3.7632132,,,,,,,,,,,,,, -46800,0.83296704,5.487559,,,,,,,,,,,,,, -46801,,,0.3951171934604645,2.819166421890259,0.3638199865818023,3.004495143890381,50000.0,0.2831000089645386,3.567754030227661,10000.0,21465.82685112953,23702.38392972946,21465.82685112953,2232.58002948761,1.55698823928833,0.0 -46900,0.89141184,4.4695597,,,,,,,,,,,,,, -47000,0.8043015,5.4462595,,,,,,,,,,,,,, -47100,0.91634715,3.8587666,,,,,,,,,,,,,, -47200,0.952225,5.7889233,,,,,,,,,,,,,, -47300,0.7799118,5.3051863,,,,,,,,,,,,,, -47400,1.0128517,3.88359,,,,,,,,,,,,,, -47500,0.59723777,5.905406,,,,,,,,,,,,,, -47600,1.1254079,4.109893,,,,,,,,,,,,,, -47700,0.95641255,5.8472986,,,,,,,,,,,,,, -47723,,,0.4143163859844208,2.6614110469818115,0.3860200047492981,2.80824875831604,50000.0,0.2997000217437744,3.443194627761841,10000.0,21886.131243228912,24165.59008526802,21886.131243228912,2275.397953271866,1.593160629272461,0.0 -47800,1.0161462,3.8725123,,,,,,,,,,,,,, -47900,0.8683292,3.7575688,,,,,,,,,,,,,, -48000,0.665723,5.787859,,,,,,,,,,,,,, -48100,0.87202173,4.4011593,,,,,,,,,,,,,, -48200,1.0199342,4.040841,,,,,,,,,,,,,, -48300,0.8468163,5.310167,,,,,,,,,,,,,, -48400,1.1622299,3.864942,,,,,,,,,,,,,, -48500,1.0946821,3.7393367,,,,,,,,,,,,,, -48600,1.0287142,3.6567724,,,,,,,,,,,,,, -48642,,,0.4010742008686065,2.7443270683288574,0.3771199882030487,2.8815295696258545,50000.0,0.2904000282287597,3.4814419746398926,10000.0,22306.22730517388,24632.510838747025,22306.22730517388,2322.138329029084,1.6305792331695557,0.0 -48700,1.0358809,3.6430404,,,,,,,,,,,,,, -48800,0.8342535,5.560191,,,,,,,,,,,,,, -48900,0.9591721,3.801348,,,,,,,,,,,,,, -49000,0.9695489,3.6683652,,,,,,,,,,,,,, -49100,0.97895557,3.736112,,,,,,,,,,,,,, -49200,0.9491681,4.1527543,,,,,,,,,,,,,, -49300,0.84570324,5.5113916,,,,,,,,,,,,,, -49400,1.0209293,3.7010105,,,,,,,,,,,,,, -49500,1.0527844,3.7584486,,,,,,,,,,,,,, -49559,,,0.4129492044448852,2.7169084548950195,0.3819800019264221,2.885247707366944,50000.0,0.2870000004768371,3.4870080947875977,10000.0,22726.256851911545,25097.48213648796,22726.256851911545,2366.9990010261536,1.662933588027954,0.0 -49600,0.95854104,3.8923163,,,,,,,,,,,,,, -49700,1.0644542,3.7752228,,,,,,,,,,,,,, -49800,0.9407046,3.5560162,,,,,,,,,,,,,, -49900,1.0496286,3.5720954,,,,,,,,,,,,,, -50000,0.91683,3.6042252,,,,,,,,,,,,,, -50100,0.8261873,5.223291,,,,,,,,,,,,,, -50200,1.0282726,3.598797,,,,,,,,,,,,,, -50300,1.1052209,3.6678452,,,,,,,,,,,,,, -50400,1.0003645,3.7326326,,,,,,,,,,,,,, -50478,,,0.4268554449081421,2.6102540493011475,0.39751997590065,2.780481100082397,50000.0,0.3005000054836273,3.3917243480682373,10000.0,23146.49708509445,25563.53849577904,23146.49708509445,2412.7322528362274,1.6987645626068115,0.0 -50500,0.8685027,3.8386292,,,,,,,,,,,,,, -50600,0.9945338,3.8833628,,,,,,,,,,,,,, -50700,0.9922625,3.8199134,,,,,,,,,,,,,, -50800,1.0151056,3.6924486,,,,,,,,,,,,,, -50900,1.1091954,3.7132313,,,,,,,,,,,,,, -51000,0.93640316,5.914507,,,,,,,,,,,,,, -51100,1.1718574,3.7200933,,,,,,,,,,,,,, -51200,0.85988045,4.0607085,,,,,,,,,,,,,, -51300,1.0811366,3.665586,,,,,,,,,,,,,, -51397,,,0.429492175579071,2.56564998626709,0.4009999930858612,2.729902505874634,50000.0,0.3053000271320343,3.3814618587493896,10000.0,23566.498861551285,26032.06483864784,23566.498861551285,2461.1722581386566,1.734628677368164,0.0 -51400,1.3914592,3.735914,,,,,,,,,,,,,, -51500,0.9355929,3.6923718,,,,,,,,,,,,,, -51600,0.8623011,4.4223914,,,,,,,,,,,,,, -51700,1.4075857,3.8272529,,,,,,,,,,,,,, -51800,1.1107696,3.750696,,,,,,,,,,,,,, -51900,1.0976391,3.5898151,,,,,,,,,,,,,, -52000,1.3492203,3.7340288,,,,,,,,,,,,,, -52100,0.81481785,5.472727,,,,,,,,,,,,,, -52200,1.2731454,3.7017725,,,,,,,,,,,,,, -52300,0.9073122,4.2876625,,,,,,,,,,,,,, -52315,,,0.4299023449420929,2.5807902812957764,0.398059993982315,2.752584934234619,50000.0,0.3107000291347503,3.378159523010254,10000.0,23986.7883477211,26497.21853017807,23986.7883477211,2505.948692560196,1.7748017311096191,0.0 -52400,1.1336837,3.6793203,,,,,,,,,,,,,, -52500,0.9581642,3.4889584,,,,,,,,,,,,,, -52600,1.0098557,3.63701,,,,,,,,,,,,,, -52700,1.0475498,3.7050161,,,,,,,,,,,,,, -52800,1.1594704,4.38658,,,,,,,,,,,,,, -52900,1.0045538,3.5684874,,,,,,,,,,,,,, -53000,1.080599,3.6203356,,,,,,,,,,,,,, -53100,0.9352136,3.5523918,,,,,,,,,,,,,, -53200,1.0399268,3.6572552,,,,,,,,,,,,,, -53234,,,0.4723632633686065,2.374985694885254,0.4041999876499176,2.73072361946106,50000.0,0.3115000128746032,3.353266954421997,10000.0,24406.975200414658,26964.357704639435,24406.975200414658,2552.8182249069214,1.8105263710021973,0.0 -53300,0.9222643,5.944924,,,,,,,,,,,,,, -53400,1.0297984,3.6884198,,,,,,,,,,,,,, -53500,1.067313,4.1754913,,,,,,,,,,,,,, -53600,0.9073709,5.2433968,,,,,,,,,,,,,, -53700,0.9925072,3.5590415,,,,,,,,,,,,,, -53800,1.0324739,3.5970008,,,,,,,,,,,,,, -53900,1.0479133,3.7491207,,,,,,,,,,,,,, -54000,0.9005966,5.491331,,,,,,,,,,,,,, -54100,0.66448474,5.5369744,,,,,,,,,,,,,, -54153,,,0.4276367127895355,2.59726619720459,0.3992999792098999,2.763692140579224,50000.0,0.3023000061511993,3.398660659790039,10000.0,24827.1175699234,27435.02015376091,24827.1175699234,2603.2553062438965,1.846671342849732,0.0 -54200,0.9974653,3.7380118,,,,,,,,,,,,,, -54300,0.8162256,5.858655,,,,,,,,,,,,,, -54400,0.9094277,3.6360192,,,,,,,,,,,,,, -54500,1.034917,3.6478207,,,,,,,,,,,,,, -54600,1.1378617,3.5757694,,,,,,,,,,,,,, -54700,0.9428735,4.6007237,,,,,,,,,,,,,, -54800,0.70807695,5.9118137,,,,,,,,,,,,,, -54900,1.1365145,3.726019,,,,,,,,,,,,,, -55000,1.2863115,3.5716116,,,,,,,,,,,,,, -55074,,,0.4376757740974426,2.500518798828125,0.4105999767780304,2.6731464862823486,50000.0,0.3185000121593475,3.2860732078552246,10000.0,25247.319982767105,27898.86329269409,25247.319982767105,2646.8072805404663,1.887636423110962,0.0 -55100,0.86131704,4.48989,,,,,,,,,,,,,, -55200,0.8531045,4.159111,,,,,,,,,,,,,, -55300,1.1278919,5.92465,,,,,,,,,,,,,, -55400,1.3134215,3.761846,,,,,,,,,,,,,, -55500,1.1433805,4.1982474,,,,,,,,,,,,,, -55600,1.091958,3.6755269,,,,,,,,,,,,,, -55700,0.91583204,3.9219453,,,,,,,,,,,,,, -55800,0.75775373,5.6640887,,,,,,,,,,,,,, -55900,1.0447929,3.6157641,,,,,,,,,,,,,, -55995,,,0.4432226419448852,2.532923936843872,0.402319997549057,2.7546207904815674,50000.0,0.302700012922287,3.414133310317993,10000.0,25667.46424293518,28362.4433825016,25667.46424293518,2690.1619765758514,1.9212877750396729,0.0 -56000,0.87167376,5.289716,,,,,,,,,,,,,, -56100,1.0285975,3.7489042,,,,,,,,,,,,,, -56200,1.0693821,3.546772,,,,,,,,,,,,,, -56300,1.2924721,3.8352923,,,,,,,,,,,,,, -56400,1.1726985,3.6131964,,,,,,,,,,,,,, -56500,0.75465643,5.66246,,,,,,,,,,,,,, -56600,1.0109208,3.9641905,,,,,,,,,,,,,, -56700,1.0173519,3.6681273,,,,,,,,,,,,,, -56800,0.83900094,3.6788175,,,,,,,,,,,,,, -56900,0.774677,4.8695307,,,,,,,,,,,,,, -56916,,,0.4407812356948852,2.5251173973083496,0.4111399948596954,2.6807515621185303,50000.0,0.3193000257015228,3.296461582183838,10000.0,26087.7837445736,28826.425322294235,26087.7837445736,2733.7377874851227,1.959881067276001,0.0 -57000,0.9152143,4.908683,,,,,,,,,,,,,, -57100,0.86513925,4.0111737,,,,,,,,,,,,,, -57200,0.7966764,5.567906,,,,,,,,,,,,,, -57300,0.9100636,4.091425,,,,,,,,,,,,,, -57400,1.0994643,3.8128128,,,,,,,,,,,,,, -57500,1.1262643,3.833669,,,,,,,,,,,,,, -57600,0.8706335,4.3683615,,,,,,,,,,,,,, -57700,0.79729205,4.8203473,,,,,,,,,,,,,, -57800,1.2162137,3.6630428,,,,,,,,,,,,,, -57836,,,0.4381054639816284,2.522041320800781,0.4074999988079071,2.682382345199585,50000.0,0.3222000300884247,3.286940097808838,10000.0,26507.91575574875,29291.23188686371,26507.91575574875,2778.325141429901,1.99947476387024,0.0 -57900,1.1610471,3.7341816,,,,,,,,,,,,,, -58000,0.8153934,4.8200088,,,,,,,,,,,,,, -58100,1.2263727,3.6995099,,,,,,,,,,,,,, -58200,0.933616,5.334804,,,,,,,,,,,,,, -58300,0.8562522,4.286619,,,,,,,,,,,,,, -58400,1.0339826,3.6981466,,,,,,,,,,,,,, -58500,1.0364248,4.0576563,,,,,,,,,,,,,, -58600,0.85540813,5.040065,,,,,,,,,,,,,, -58700,1.2296182,3.6367161,,,,,,,,,,,,,, -58754,,,0.4493750035762787,2.502846002578736,0.4158999919891357,2.6864118576049805,50000.0,0.3222000300884247,3.29914927482605,10000.0,26927.859172344208,29757.56119155884,26927.859172344208,2824.6300699710846,2.0336482524871826,0.0 -58800,0.8951379,4.8352413,,,,,,,,,,,,,, -58900,1.0406629,3.5787513,,,,,,,,,,,,,, -59000,0.96402955,3.7840436,,,,,,,,,,,,,, -59100,1.0590636,3.7033498,,,,,,,,,,,,,, -59200,1.184323,3.654271,,,,,,,,,,,,,, -59300,1.1834787,3.7424283,,,,,,,,,,,,,, -59400,0.9128706,3.6279895,,,,,,,,,,,,,, -59500,1.06563,3.9717374,,,,,,,,,,,,,, -59600,0.90111846,4.6604123,,,,,,,,,,,,,, -59675,,,0.4404882788658142,2.5459964275360107,0.4094399809837341,2.710213661193848,50000.0,0.3230000138282776,3.316578388214112,10000.0,27347.924685001373,30223.841745615005,27347.924685001373,2870.763954639435,2.066673040390014,0.0 -59700,1.0746343,3.5445707,,,,,,,,,,,,,, -59800,0.7623334,5.7908196,,,,,,,,,,,,,, -59900,1.2144157,3.5174139,,,,,,,,,,,,,, -60000,1.0970055,3.607757,,,,,,,,,,,,,, -60100,1.1686957,3.6720886,,,,,,,,,,,,,, -60200,1.0410608,3.512803,,,,,,,,,,,,,, -60300,1.2919377,3.5303223,,,,,,,,,,,,,, -60400,1.0410556,3.4075148,,,,,,,,,,,,,, -60500,0.8544098,5.623709,,,,,,,,,,,,,, -60594,,,0.453437477350235,2.441464424133301,0.4258799850940704,2.599683284759521,50000.0,0.3331000208854675,3.216318130493164,10000.0,27768.17107129097,30688.83791780472,27768.17107129097,2915.4320845603943,2.101196527481079,0.0 -60600,0.97954637,3.9342992,,,,,,,,,,,,,, -60700,1.2182504,3.682172,,,,,,,,,,,,,, -60800,0.98925585,3.5829086,,,,,,,,,,,,,, -60900,1.0971706,3.8334422,,,,,,,,,,,,,, -61000,1.2666575,3.6107914,,,,,,,,,,,,,, -61100,0.85584456,5.748642,,,,,,,,,,,,,, -61200,0.9498704,3.5601132,,,,,,,,,,,,,, -61300,1.1362885,3.9064436,,,,,,,,,,,,,, -61400,1.1203681,3.8591778,,,,,,,,,,,,,, -61500,0.78765976,5.7494392,,,,,,,,,,,,,, -61513,,,0.454902321100235,2.44576096534729,0.4205799996852875,2.644345283508301,50000.0,0.3228000104427337,3.2577946186065674,10000.0,28188.53337931633,31157.204171419144,28188.53337931633,2963.3469684124,2.1426825523376465,0.0 -61600,1.2488898,3.579953,,,,,,,,,,,,,, -61700,0.9108796,3.881795,,,,,,,,,,,,,, -61800,1.1405287,3.6196783,,,,,,,,,,,,,, -61900,1.0427989,3.9399114,,,,,,,,,,,,,, -62000,1.097203,3.5143614,,,,,,,,,,,,,, -62100,1.2860881,5.2669263,,,,,,,,,,,,,, -62200,0.83755267,4.550128,,,,,,,,,,,,,, -62300,0.9320157,3.8551962,,,,,,,,,,,,,, -62400,0.83210367,5.794582,,,,,,,,,,,,,, -62433,,,0.4523828029632568,2.508619546890259,0.4140399992465973,2.6991686820983887,50000.0,0.3226000070571899,3.3033504486083984,10000.0,28608.84869074821,31622.267642736435,28608.84869074821,3008.009134531021,2.1812844276428223,0.0 -62500,1.0224521,3.4792519,,,,,,,,,,,,,, -62600,1.1446748,3.825833,,,,,,,,,,,,,, -62700,0.8656991,4.5677853,,,,,,,,,,,,,, -62800,1.0859303,3.789949,,,,,,,,,,,,,, -62900,1.0075997,3.3787684,,,,,,,,,,,,,, -63000,0.8325853,5.474205,,,,,,,,,,,,,, -63100,1.0265265,3.551834,,,,,,,,,,,,,, -63200,0.75794196,4.988345,,,,,,,,,,,,,, -63300,1.0291378,3.5968187,,,,,,,,,,,,,, -63353,,,0.4460742175579071,2.4987034797668457,0.4134199917316437,2.6845717430114746,50000.0,0.3113000094890594,3.3364946842193604,10000.0,29028.77939391136,32091.04265642166,29028.77939391136,3056.772082090378,2.215562343597412,0.0 -63400,1.054669,3.4936156,,,,,,,,,,,,,, -63500,0.88806313,3.508999,,,,,,,,,,,,,, -63600,1.1745187,3.444599,,,,,,,,,,,,,, -63700,1.1215923,3.604186,,,,,,,,,,,,,, -63800,1.067735,3.7526941,,,,,,,,,,,,,, -63900,1.1768419,3.5734558,,,,,,,,,,,,,, -64000,1.1414324,3.4910538,,,,,,,,,,,,,, -64100,1.0557445,4.397244,,,,,,,,,,,,,, -64200,1.0277956,4.7785983,,,,,,,,,,,,,, -64271,,,0.451464831829071,2.464915037155152,0.4184199869632721,2.648669719696045,50000.0,0.3242000043392181,3.2666985988616943,10000.0,29448.9363090992,32559.199233531952,29448.9363090992,3104.6899168491364,2.250218152999878,0.0 -64300,0.90836966,4.006022,,,,,,,,,,,,,, -64400,1.165591,3.5051365,,,,,,,,,,,,,, -64500,1.1280713,3.5468795,,,,,,,,,,,,,, -64600,0.9916807,5.254059,,,,,,,,,,,,,, -64700,1.1292858,4.5484977,,,,,,,,,,,,,, -64800,1.0051416,3.736025,,,,,,,,,,,,,, -64900,1.1486382,3.5840166,,,,,,,,,,,,,, -65000,1.1232178,3.52603,,,,,,,,,,,,,, -65100,0.84417015,4.109751,,,,,,,,,,,,,, -65189,,,0.4797070324420929,2.2820422649383545,0.4228200018405914,2.5929930210113525,50000.0,0.33160001039505,3.247788906097412,10000.0,29869.14924716949,33023.18640756607,29869.14924716949,3148.381046772003,2.2855420112609863,0.0 -65200,1.1880583,3.309476,,,,,,,,,,,,,, -65300,0.86705136,4.3051376,,,,,,,,,,,,,, -65400,1.0186968,3.5181222,,,,,,,,,,,,,, -65500,0.65359807,4.931487,,,,,,,,,,,,,, -65600,1.0473136,3.5485213,,,,,,,,,,,,,, -65700,1.0919294,3.666679,,,,,,,,,,,,,, -65800,1.190541,3.5535429,,,,,,,,,,,,,, -65900,1.0584229,3.7588246,,,,,,,,,,,,,, -66000,0.8907792,3.9802651,,,,,,,,,,,,,, -66100,0.8267093,5.8143744,,,,,,,,,,,,,, -66108,,,0.4628124833106994,2.4145824909210205,0.4338599741458893,2.566502094268799,50000.0,0.3427000045776367,3.1849935054779053,10000.0,30289.283529758453,33482.63961672783,30289.283529758453,3187.614446878433,2.323702335357666,0.0 -66200,0.8179613,5.297495,,,,,,,,,,,,,, -66300,0.95027614,3.7063496,,,,,,,,,,,,,, -66400,0.9479753,4.8501725,,,,,,,,,,,,,, -66500,0.80331135,5.1022406,,,,,,,,,,,,,, -66600,1.0574192,3.414699,,,,,,,,,,,,,, -66700,0.918594,3.6725523,,,,,,,,,,,,,, -66800,1.0784701,3.8496108,,,,,,,,,,,,,, -66900,0.7333324,5.2784677,,,,,,,,,,,,,, -67000,1.207854,3.6561887,,,,,,,,,,,,,, -67028,,,0.4659765660762787,2.3797836303710938,0.4309599995613098,2.57188081741333,50000.0,0.3293000161647796,3.2058255672454834,10000.0,30709.503110408783,33948.036386966705,30709.503110408783,3232.7051644325256,2.362233638763428,0.0 -67100,0.9180046,4.453869,,,,,,,,,,,,,, -67200,1.4991152,4.095689,,,,,,,,,,,,,, -67300,0.86609334,5.4210067,,,,,,,,,,,,,, -67400,1.0235273,5.0933924,,,,,,,,,,,,,, -67500,1.1491125,3.4645507,,,,,,,,,,,,,, -67600,0.9107135,4.0292873,,,,,,,,,,,,,, -67700,1.2475023,3.375395,,,,,,,,,,,,,, -67800,1.001359,3.4447966,,,,,,,,,,,,,, -67900,0.9766581,5.530672,,,,,,,,,,,,,, -67948,,,0.4738866984844208,2.352745532989502,0.4340199828147888,2.5681827068328857,50000.0,0.3377000093460083,3.203951835632324,10000.0,31129.867817878723,34414.95987582207,31129.867817878723,3279.18199968338,2.3966870307922363,0.0 -68000,1.1367775,3.4286811,,,,,,,,,,,,,, -68100,1.044687,3.5674155,,,,,,,,,,,,,, -68200,0.9609032,3.5687814,,,,,,,,,,,,,, -68300,1.2856362,3.543708,,,,,,,,,,,,,, -68400,1.1769725,3.5808716,,,,,,,,,,,,,, -68500,0.998694,3.5434797,,,,,,,,,,,,,, -68600,1.0752242,4.0579348,,,,,,,,,,,,,, -68700,1.0596724,3.3022518,,,,,,,,,,,,,, -68800,1.1041217,3.4970403,,,,,,,,,,,,,, -68868,,,0.4649609327316284,2.447537422180176,0.4329399764537811,2.587732553482056,50000.0,0.3359000086784363,3.207452774047852,10000.0,31550.06188440323,34880.10146713257,31550.06188440323,3324.0429894924164,2.434640884399414,0.0 -68900,1.0190532,3.3111448,,,,,,,,,,,,,, -69000,0.88638943,4.9220486,,,,,,,,,,,,,, -69100,0.96743,3.4576173,,,,,,,,,,,,,, -69200,1.0341166,3.605812,,,,,,,,,,,,,, -69300,0.9405927,4.7069902,,,,,,,,,,,,,, -69400,1.0622323,3.3679433,,,,,,,,,,,,,, -69500,0.98680747,3.415618,,,,,,,,,,,,,, -69600,0.80665123,4.843918,,,,,,,,,,,,,, -69700,1.276417,3.499012,,,,,,,,,,,,,, -69789,,,0.4654882848262787,2.4043633937835693,0.4331599771976471,2.563265800476074,50000.0,0.3391000032424927,3.190361976623535,10000.0,31970.368980884552,35347.10148000717,31970.368980884552,3370.650318622589,2.472820997238159,0.0 -69800,1.043185,3.5100608,,,,,,,,,,,,,, -69900,1.1572345,3.6038053,,,,,,,,,,,,,, -70000,1.4615037,3.8454804,,,,,,,,,,,,,, -70100,1.0259427,3.448138,,,,,,,,,,,,,, -70200,1.0939478,3.6683214,,,,,,,,,,,,,, -70300,1.0091254,3.7170959,,,,,,,,,,,,,, -70400,0.8601654,4.5643435,,,,,,,,,,,,,, -70500,1.0696503,3.4439285,,,,,,,,,,,,,, -70600,0.9595748,4.113659,,,,,,,,,,,,,, -70700,1.0353638,3.3551214,,,,,,,,,,,,,, -70706,,,0.4787109196186065,2.317901611328125,0.439520001411438,2.5207505226135254,50000.0,0.3436000049114227,3.145872116088867,10000.0,32390.55333662033,35809.67774987221,32390.55333662033,3412.957806110382,2.5098109245300293,0.0 -70800,1.0634545,3.476613,,,,,,,,,,,,,, -70900,1.0365037,3.3051188,,,,,,,,,,,,,, -71000,1.2313578,3.5235863,,,,,,,,,,,,,, -71100,1.543012,3.470455,,,,,,,,,,,,,, -71200,0.9625686,4.1723633,,,,,,,,,,,,,, -71300,1.018321,3.440322,,,,,,,,,,,,,, -71400,1.0196022,3.3147976,,,,,,,,,,,,,, -71500,0.9820211,3.9857142,,,,,,,,,,,,,, -71600,1.0805819,3.4848504,,,,,,,,,,,,,, -71625,,,0.4616210758686065,2.4654088020324707,0.4322199821472168,2.6213440895080566,50000.0,0.3367000222206116,3.232973098754883,10000.0,32810.782964229584,36277.062376499176,32810.782964229584,3460.0244784355164,2.551076889038086,0.0 -71700,1.0528648,3.3897889,,,,,,,,,,,,,, -71800,1.0460172,3.9524693,,,,,,,,,,,,,, -71900,1.2135078,3.6241837,,,,,,,,,,,,,, -72000,0.80579174,5.253083,,,,,,,,,,,,,, -72100,0.96849453,3.3650398,,,,,,,,,,,,,, -72200,1.1155992,3.3805184,,,,,,,,,,,,,, -72300,1.1005201,4.0909476,,,,,,,,,,,,,, -72400,1.0266811,4.101287,,,,,,,,,,,,,, -72500,0.9961385,4.9031467,,,,,,,,,,,,,, -72547,,,0.4809960722923279,2.3176419734954834,0.4526999890804291,2.466217041015625,50000.0,0.3489000201225281,3.1029651165008545,10000.0,33230.88841557503,36743.15631175041,33230.88841557503,3505.9291064739227,2.587003231048584,0.0 -72600,1.1516441,3.6523838,,,,,,,,,,,,,, -72700,0.90591145,3.8278618,,,,,,,,,,,,,, -72800,1.1855056,3.295755,,,,,,,,,,,,,, -72900,1.068544,3.3206244,,,,,,,,,,,,,, -73000,1.3059014,3.4616005,,,,,,,,,,,,,, -73100,1.1073387,3.7577052,,,,,,,,,,,,,, -73200,0.9453432,5.48098,,,,,,,,,,,,,, -73300,0.9289148,4.5488954,,,,,,,,,,,,,, -73400,0.8140294,4.6633883,,,,,,,,,,,,,, -73466,,,0.4806054532527923,2.2896952629089355,0.4437399804592132,2.492424249649048,50000.0,0.3473000228404999,3.121434211730957,10000.0,33650.96217060089,37208.76399350166,33650.96217060089,3551.381100177765,2.621609687805176,0.0 -73500,0.93290323,3.8583899,,,,,,,,,,,,,, -73600,1.0310447,5.086023,,,,,,,,,,,,,, -73700,1.1705087,3.491995,,,,,,,,,,,,,, -73800,1.1010946,3.5171542,,,,,,,,,,,,,, -73900,0.92916757,5.0574045,,,,,,,,,,,,,, -74000,1.1368566,3.4383774,,,,,,,,,,,,,, -74100,0.87323487,4.6593695,,,,,,,,,,,,,, -74200,1.0782262,3.6008472,,,,,,,,,,,,,, -74300,0.9355288,5.2406025,,,,,,,,,,,,,, -74386,,,0.5169921517372131,2.1353821754455566,0.4467199742794037,2.486428737640381,50000.0,0.3522000312805176,3.1258039474487305,10000.0,34071.35045528412,37676.37067079544,34071.35045528412,3598.5108897686005,2.663499355316162,0.0 -74400,0.82362705,4.9857373,,,,,,,,,,,,,, -74500,1.2829891,3.5073829,,,,,,,,,,,,,, -74600,1.0217872,3.4924667,,,,,,,,,,,,,, -74700,1.0742034,3.4506705,,,,,,,,,,,,,, -74800,0.8294114,5.2983184,,,,,,,,,,,,,, -74900,1.0128449,3.6321063,,,,,,,,,,,,,, -75000,0.82900673,5.7723074,,,,,,,,,,,,,, -75100,1.4910785,3.3447018,,,,,,,,,,,,,, -75200,1.0200545,3.3049934,,,,,,,,,,,,,, -75300,1.2266656,3.388831,,,,,,,,,,,,,, -75304,,,0.4877148270606994,2.240132570266724,0.4559399783611297,2.4114251136779785,50000.0,0.3513000309467315,3.071357011795044,10000.0,34491.62985539436,38136.85993814469,34491.62985539436,3638.625265598297,2.711342096328736,0.0 -75400,1.1308868,3.7190537,,,,,,,,,,,,,, -75500,1.3621812,3.6104383,,,,,,,,,,,,,, -75600,1.1794804,3.4059262,,,,,,,,,,,,,, -75700,1.0557749,4.075158,,,,,,,,,,,,,, -75800,1.0721813,3.3831325,,,,,,,,,,,,,, -75900,0.8495487,4.4750314,,,,,,,,,,,,,, -76000,0.9203602,5.245351,,,,,,,,,,,,,, -76100,0.9983795,4.4011526,,,,,,,,,,,,,, -76200,1.1470244,3.255114,,,,,,,,,,,,,, -76223,,,0.4906249940395355,2.2219138145446777,0.4575199782848358,2.4190688133239746,50000.0,0.3555000126361847,3.057042121887207,10000.0,34911.86199808121,38600.17257928848,34911.86199808121,3681.6170043945312,2.75275993347168,0.0 -76300,0.9143658,5.399892,,,,,,,,,,,,,, -76400,0.8721705,5.7835026,,,,,,,,,,,,,, -76500,0.9736967,3.6340945,,,,,,,,,,,,,, -76600,0.9407541,5.6771245,,,,,,,,,,,,,, -76700,0.8182796,4.387532,,,,,,,,,,,,,, -76800,1.2868301,3.3964427,,,,,,,,,,,,,, -76900,1.080514,3.6936388,,,,,,,,,,,,,, -77000,1.1714252,3.3186848,,,,,,,,,,,,,, -77100,1.1217713,3.211549,,,,,,,,,,,,,, -77142,,,0.5056250095367432,2.1703872680664062,0.4499799907207489,2.44785213470459,50000.0,0.3538000285625458,3.080111503601074,10000.0,35332.00664615631,39066.17379283905,35332.00664615631,3727.388257026672,2.79075288772583,0.0 -77200,1.616327,3.4269211,,,,,,,,,,,,,, -77300,1.1810662,3.43465,,,,,,,,,,,,,, -77400,0.83395165,5.2351522,,,,,,,,,,,,,, -77500,1.2025051,3.3512542,,,,,,,,,,,,,, -77600,1.0838702,3.351965,,,,,,,,,,,,,, -77700,0.8021151,5.5267277,,,,,,,,,,,,,, -77800,1.0728651,5.608796,,,,,,,,,,,,,, -77900,0.98111564,5.1902294,,,,,,,,,,,,,, -78000,0.8407089,5.79391,,,,,,,,,,,,,, -78062,,,0.4858007729053497,2.2651078701019287,0.4527799785137176,2.438197374343872,50000.0,0.3588000237941742,3.084175109863281,10000.0,35752.21084976196,39530.53333187103,35752.21084976196,3771.449672698975,2.83611798286438,0.0 -78100,1.1835254,3.215469,,,,,,,,,,,,,, -78200,1.0326904,3.2477448,,,,,,,,,,,,,, -78300,0.95589536,4.330399,,,,,,,,,,,,,, -78400,0.91927826,5.637807,,,,,,,,,,,,,, -78500,1.1998161,3.465053,,,,,,,,,,,,,, -78600,1.015529,3.5853481,,,,,,,,,,,,,, -78700,1.1281732,3.7484484,,,,,,,,,,,,,, -78800,1.193572,3.3465562,,,,,,,,,,,,,, -78900,1.2035409,3.5416281,,,,,,,,,,,,,, -78982,,,0.4868945181369781,2.2959415912628174,0.4545799791812897,2.4721052646636963,50000.0,0.34620001912117,3.136114835739136,10000.0,36172.48115777969,39996.05541801453,36172.48115777969,3816.614407777786,2.875884532928467,0.0 -79000,1.1299666,3.2405581,,,,,,,,,,,,,, -79100,1.0343648,3.4117312,,,,,,,,,,,,,, -79200,1.0455501,3.442082,,,,,,,,,,,,,, -79300,1.0925752,3.2709224,,,,,,,,,,,,,, -79400,1.0435148,4.491766,,,,,,,,,,,,,, -79500,1.0485253,4.4555917,,,,,,,,,,,,,, -79600,1.072295,3.2543068,,,,,,,,,,,,,, -79700,1.0818616,3.3992672,,,,,,,,,,,,,, -79800,0.78730977,5.610643,,,,,,,,,,,,,, -79900,1.0645175,3.4592323,,,,,,,,,,,,,, -79903,,,0.4998827874660492,2.2071332931518555,0.4567599892616272,2.4251766204833984,50000.0,0.3619000315666199,3.0350892543792725,10000.0,36592.814427137375,40465.15758180618,36592.814427137375,3865.289860725403,2.9212820529937744,0.0 -80000,0.9526942,5.6179,,,,,,,,,,,,,, -80100,1.0507934,3.1410682,,,,,,,,,,,,,, -80200,0.871799,5.680935,,,,,,,,,,,,,, -80300,1.2043111,3.420799,,,,,,,,,,,,,, -80400,1.1814138,3.281704,,,,,,,,,,,,,, -80500,1.362535,3.2505,,,,,,,,,,,,,, -80600,0.8907042,4.6758347,,,,,,,,,,,,,, -80700,1.0685111,3.4955473,,,,,,,,,,,,,, -80800,1.3762124,3.3842325,,,,,,,,,,,,,, -80823,,,0.4993554651737213,2.188851356506348,0.471560001373291,2.3428831100463867,50000.0,0.367900013923645,2.989685773849488,10000.0,37013.1857984066,40926.628831624985,37013.1857984066,3906.299390554428,2.9639267921447754,0.0 -80900,1.1321975,3.3156595,,,,,,,,,,,,,, -81000,1.0501612,3.4631371,,,,,,,,,,,,,, -81100,1.218082,3.3126566,,,,,,,,,,,,,, -81200,1.2318579,3.2479248,,,,,,,,,,,,,, -81300,1.4840311,3.2855716,,,,,,,,,,,,,, -81400,0.9792697,4.29068,,,,,,,,,,,,,, -81500,1.0985726,3.4685934,,,,,,,,,,,,,, -81600,1.0915831,3.1921003,,,,,,,,,,,,,, -81700,1.099563,3.5032742,,,,,,,,,,,,,, -81739,,,0.5030664205551147,2.173933744430542,0.4730799794197082,2.3356215953826904,50000.0,0.368800014257431,2.9926578998565674,10000.0,37433.493015527725,41391.62658786774,37433.493015527725,3950.904098033905,3.0028018951416016,0.0 -81800,0.8932515,5.226981,,,,,,,,,,,,,, -81900,1.1075208,3.1957798,,,,,,,,,,,,,, -82000,1.0654652,3.2760806,,,,,,,,,,,,,, -82100,1.3205053,3.3401296,,,,,,,,,,,,,, -82200,1.1704469,3.6618528,,,,,,,,,,,,,, -82300,1.3505594,3.2312212,,,,,,,,,,,,,, -82400,1.1832877,3.307853,,,,,,,,,,,,,, -82500,1.0129048,3.2538736,,,,,,,,,,,,,, -82600,1.1891342,3.1907227,,,,,,,,,,,,,, -82657,,,0.5142968893051147,2.146437168121338,0.472379982471466,2.358328104019165,50000.0,0.3644000291824341,3.0233078002929688,10000.0,37853.55633306503,41858.507304906845,37853.55633306503,3997.634565114975,3.0432190895080566,0.0 -82700,1.1267617,3.1675367,,,,,,,,,,,,,, -82800,1.0032946,4.614763,,,,,,,,,,,,,, -82900,1.0418006,3.3052466,,,,,,,,,,,,,, -83000,1.0824797,3.5860248,,,,,,,,,,,,,, -83100,1.140767,3.2524867,,,,,,,,,,,,,, -83200,1.0960025,3.193014,,,,,,,,,,,,,, -83300,0.9898691,3.1541083,,,,,,,,,,,,,, -83400,0.9629015,4.1898303,,,,,,,,,,,,,, -83500,1.0829977,3.1242476,,,,,,,,,,,,,, -83577,,,0.5132030844688416,2.140690803527832,0.4748199880123138,2.3312788009643555,50000.0,0.3646000027656555,2.994065284729004,10000.0,38273.708990097046,42315.99324274063,38273.708990097046,4034.875780582428,3.087830781936645,0.0 -83600,1.24918,3.5940287,,,,,,,,,,,,,, -83700,1.1516719,3.048533,,,,,,,,,,,,,, -83800,1.2566465,3.286467,,,,,,,,,,,,,, -83900,1.0506729,3.7092483,,,,,,,,,,,,,, -84000,1.4581372,3.3276284,,,,,,,,,,,,,, -84100,0.9877416,5.43492,,,,,,,,,,,,,, -84200,1.0127885,3.3318825,,,,,,,,,,,,,, -84300,0.8338623,5.630987,,,,,,,,,,,,,, -84400,1.1197718,3.7897987,,,,,,,,,,,,,, -84497,,,0.5048632621765137,2.2373056411743164,0.4640399813652038,2.4101569652557373,50000.0,0.3582000136375427,3.0756309032440186,10000.0,38693.9015994072,42781.52650523186,38693.9015994072,4080.1282589435577,3.1285057067871094,0.0 -84500,1.0985217,5.409036,,,,,,,,,,,,,, -84600,0.94001156,5.6186385,,,,,,,,,,,,,, -84700,0.9603625,4.172325,,,,,,,,,,,,,, -84800,1.1419984,3.7350934,,,,,,,,,,,,,, -84900,0.90542436,4.1088543,,,,,,,,,,,,,, -85000,1.2766169,3.17267,,,,,,,,,,,,,, -85100,1.0380815,5.7457533,,,,,,,,,,,,,, -85200,1.0211253,3.2922416,,,,,,,,,,,,,, -85300,1.1024075,3.7282875,,,,,,,,,,,,,, -85400,1.1024505,3.3066466,,,,,,,,,,,,,, -85417,,,0.5148242115974426,2.1396749019622803,0.4755999743938446,2.3394620418548584,50000.0,0.3726000189781189,2.975393772125244,10000.0,39114.07578897476,43244.331899404526,39114.07578897476,4122.668626070023,3.1721973419189453,0.0 -85500,1.326004,3.2334294,,,,,,,,,,,,,, -85600,0.94537604,4.3537064,,,,,,,,,,,,,, -85700,1.1025627,3.1676342,,,,,,,,,,,,,, -85800,1.1098778,3.7116196,,,,,,,,,,,,,, -85900,1.2147423,3.2186265,,,,,,,,,,,,,, -86000,1.2563676,3.197076,,,,,,,,,,,,,, -86100,0.96523887,3.3822813,,,,,,,,,,,,,, -86200,1.0723137,3.7043242,,,,,,,,,,,,,, -86300,1.0406194,3.3071263,,,,,,,,,,,,,, -86337,,,0.5555468797683716,1.912348389625549,0.4817200005054474,2.272268295288086,50000.0,0.3779000043869018,2.9314382076263428,10000.0,39534.43977665901,43711.998577833176,39534.43977665901,4169.881982803345,3.214378833770752,0.0 -86400,0.87329495,5.266802,,,,,,,,,,,,,, -86500,0.90261817,4.501302,,,,,,,,,,,,,, -86600,1.2975608,3.132582,,,,,,,,,,,,,, -86700,1.0670207,3.1751606,,,,,,,,,,,,,, -86800,1.0050851,3.32663,,,,,,,,,,,,,, -86900,1.0075697,3.590202,,,,,,,,,,,,,, -87000,0.95319283,5.3066134,,,,,,,,,,,,,, -87100,1.2589033,3.4164457,,,,,,,,,,,,,, -87200,0.9892474,3.4199636,,,,,,,,,,,,,, -87257,,,0.52099609375,2.096512079238892,0.4859399795532226,2.274020195007324,50000.0,0.3801000118255615,2.9268369674682617,10000.0,39954.65167856216,44175.57442951202,39954.65167856216,4213.158932924271,3.2534549236297607,0.0 -87300,0.8608003,5.377972,,,,,,,,,,,,,, -87400,1.085888,3.9350655,,,,,,,,,,,,,, -87500,0.8827306,4.5551605,,,,,,,,,,,,,, -87600,1.2758546,3.1615872,,,,,,,,,,,,,, -87700,1.2023267,3.1997476,,,,,,,,,,,,,, -87800,1.159259,3.1873684,,,,,,,,,,,,,, -87900,1.0937765,3.1604664,,,,,,,,,,,,,, -88000,1.2104245,3.22749,,,,,,,,,,,,,, -88100,1.0017314,3.2988815,,,,,,,,,,,,,, -88175,,,0.5269726514816284,2.0405514240264893,0.4899399876594543,2.239298105239868,50000.0,0.3754000067710876,2.916119337081909,10000.0,40374.94868397713,44643.63994884491,40374.94868397713,4260.838416099548,3.29453992843628,0.0 -88200,1.259591,3.1912029,,,,,,,,,,,,,, -88300,0.8565277,5.121231,,,,,,,,,,,,,, -88400,1.0501409,5.6594415,,,,,,,,,,,,,, -88500,1.270864,3.253829,,,,,,,,,,,,,, -88600,1.2051897,3.1755753,,,,,,,,,,,,,, -88700,0.970701,4.2233686,,,,,,,,,,,,,, -88800,1.2698686,3.086757,,,,,,,,,,,,,, -88900,1.0806426,3.092489,,,,,,,,,,,,,, -89000,1.0559705,3.559174,,,,,,,,,,,,,, -89095,,,0.5325976610183716,2.0417697429656982,0.4813999831676483,2.3064215183258057,50000.0,0.3803000152111053,2.940920829772949,10000.0,40795.08370661736,45108.95787191391,40795.08370661736,4305.934587717056,3.3338735103607178,0.0 -89100,1.2602576,3.13876,,,,,,,,,,,,,, -89200,1.2498573,3.2899346,,,,,,,,,,,,,, -89300,0.89777416,4.479808,,,,,,,,,,,,,, -89400,1.0388359,5.4085364,,,,,,,,,,,,,, -89500,0.8424819,4.737233,,,,,,,,,,,,,, -89600,0.95549965,3.768055,,,,,,,,,,,,,, -89700,0.91088724,5.0434895,,,,,,,,,,,,,, -89800,1.327677,3.3004484,,,,,,,,,,,,,, -89900,0.99932784,5.5515685,,,,,,,,,,,,,, -90000,1.0247306,3.9017653,,,,,,,,,,,,,, -90014,,,0.5163866877555847,2.1249258518218994,0.4807799756526947,2.301429033279419,50000.0,0.3824000060558319,2.936632871627808,10000.0,41215.38111019135,45574.949031591415,41215.38111019135,4351.538655757904,3.375528812408448,0.0 -90100,1.1163151,3.7828302,,,,,,,,,,,,,, -90200,1.0951828,4.025665,,,,,,,,,,,,,, -90300,0.9329358,4.297501,,,,,,,,,,,,,, -90400,0.90438986,4.602558,,,,,,,,,,,,,, -90500,1.3721727,3.1380434,,,,,,,,,,,,,, -90600,0.8337118,5.0856996,,,,,,,,,,,,,, -90700,1.1125858,3.0584662,,,,,,,,,,,,,, -90800,1.0231397,4.9209204,,,,,,,,,,,,,, -90900,0.9288336,5.112807,,,,,,,,,,,,,, -90933,,,0.5344530940055847,2.001258134841919,0.4990399777889251,2.194621324539185,50000.0,0.3911000192165375,2.831852436065674,10000.0,41635.74610328674,46040.23263311386,41635.74610328674,4396.368488788605,3.417140483856201,0.0 -91000,1.0726663,3.1407626,,,,,,,,,,,,,, -91100,1.1164875,3.765733,,,,,,,,,,,,,, -91200,1.2016686,3.0629447,,,,,,,,,,,,,, -91300,1.0008454,3.6421762,,,,,,,,,,,,,, -91400,1.2426494,3.0773664,,,,,,,,,,,,,, -91500,0.9778459,5.51728,,,,,,,,,,,,,, -91600,1.2641939,3.0968106,,,,,,,,,,,,,, -91700,1.0561215,3.1442914,,,,,,,,,,,,,, -91800,1.2247885,3.2769027,,,,,,,,,,,,,, -91852,,,0.5383593440055847,2.007728815078736,0.4928599894046783,2.2256059646606445,50000.0,0.388700008392334,2.87404203414917,10000.0,42056.00476717949,46508.55019235611,42056.00476717949,4444.340796709061,3.457115888595581,0.0 -91900,1.117958,2.9986343,,,,,,,,,,,,,, -92000,1.3612745,3.2340946,,,,,,,,,,,,,, -92100,1.1394056,5.3863697,,,,,,,,,,,,,, -92200,1.1088363,3.4297092,,,,,,,,,,,,,, -92300,0.87953603,5.4394565,,,,,,,,,,,,,, -92400,0.8753765,4.260652,,,,,,,,,,,,,, -92500,1.3784789,3.1452236,,,,,,,,,,,,,, -92600,1.0341308,3.5357208,,,,,,,,,,,,,, -92700,1.2691219,3.1533,,,,,,,,,,,,,, -92769,,,0.5292577743530273,2.0604593753814697,0.4968799948692322,2.218428134918213,50000.0,0.39410001039505,2.844502687454224,10000.0,42476.11767077446,46977.44144821167,42476.11767077446,4493.034183979034,3.495344877243042,0.0 -92800,1.2210066,3.135516,,,,,,,,,,,,,, -92900,1.0023693,5.642367,,,,,,,,,,,,,, -93000,1.117998,3.2922652,,,,,,,,,,,,,, -93100,1.0611064,4.542064,,,,,,,,,,,,,, -93200,1.155442,3.100285,,,,,,,,,,,,,, -93300,0.9879736,4.9434347,,,,,,,,,,,,,, -93400,1.146243,3.2019718,,,,,,,,,,,,,, -93500,1.1162505,3.1927283,,,,,,,,,,,,,, -93600,0.9978697,4.09413,,,,,,,,,,,,,, -93690,,,0.5306054353713989,2.020060777664185,0.497979998588562,2.209547281265259,50000.0,0.3966000080108642,2.847208023071289,10000.0,42896.11478877068,47443.92788076401,42896.11478877068,4539.431235074997,3.540785312652588,0.0 -93700,1.1622554,3.1881828,,,,,,,,,,,,,, -93800,0.93010944,5.6172385,,,,,,,,,,,,,, -93900,1.0751,4.329182,,,,,,,,,,,,,, -94000,0.92920107,3.4148436,,,,,,,,,,,,,, -94100,1.1773866,3.025308,,,,,,,,,,,,,, -94200,1.2476226,3.2765532,,,,,,,,,,,,,, -94300,0.96985596,5.091237,,,,,,,,,,,,,, -94400,1.2697839,2.9752707,,,,,,,,,,,,,, -94500,1.1397111,3.3322945,,,,,,,,,,,,,, -94600,1.2580066,3.0644712,,,,,,,,,,,,,, -94609,,,0.5318945050239563,2.0958914756774902,0.4930999875068664,2.292987108230591,50000.0,0.3892000317573547,2.920865297317505,10000.0,43316.36288237572,47912.20147418976,43316.36288237572,4587.36496925354,3.5851848125457764,0.0 -94700,1.1255505,3.4714603,,,,,,,,,,,,,, -94800,1.144768,5.2168117,,,,,,,,,,,,,, -94900,1.1783231,3.0971103,,,,,,,,,,,,,, -95000,1.2615122,2.997844,,,,,,,,,,,,,, -95100,1.2006943,2.9693148,,,,,,,,,,,,,, -95200,1.0448757,2.9527016,,,,,,,,,,,,,, -95300,1.0090364,3.9165301,,,,,,,,,,,,,, -95400,1.0982372,5.475792,,,,,,,,,,,,,, -95500,0.8963977,4.6148453,,,,,,,,,,,,,, -95526,,,0.5541210770606995,1.9405015707015991,0.504859983921051,2.179647922515869,50000.0,0.3992000222206116,2.828856468200684,10000.0,43736.64577579498,48381.214473724365,43736.64577579498,4636.007810115814,3.625463962554932,0.0 -95600,1.0638794,3.3099613,,,,,,,,,,,,,, -95700,1.2055172,3.040564,,,,,,,,,,,,,, -95800,1.2570409,3.565924,,,,,,,,,,,,,, -95900,1.1659963,2.995617,,,,,,,,,,,,,, -96000,0.99256265,3.9736843,,,,,,,,,,,,,, -96100,1.1667868,5.591067,,,,,,,,,,,,,, -96200,1.1880373,3.0111728,,,,,,,,,,,,,, -96300,1.2233741,2.911981,,,,,,,,,,,,,, -96400,1.3186375,3.1077561,,,,,,,,,,,,,, -96446,,,0.5423046946525574,2.020184278488159,0.5077199935913086,2.1947858333587646,50000.0,0.3907000124454498,2.8464391231536865,10000.0,44156.965695381165,48843.16408109665,44156.965695381165,4677.541831970215,3.672870397567749,0.0 -96500,1.1439192,3.8997726,,,,,,,,,,,,,, -96600,1.4472144,3.4312868,,,,,,,,,,,,,, -96700,1.1174972,3.0297666,,,,,,,,,,,,,, -96800,1.2672687,3.0981622,,,,,,,,,,,,,, -96900,0.9756921,5.5023775,,,,,,,,,,,,,, -97000,1.1424205,2.878803,,,,,,,,,,,,,, -97100,0.8680928,5.4418573,,,,,,,,,,,,,, -97200,0.8563356,5.339343,,,,,,,,,,,,,, -97300,1.2353904,3.0441613,,,,,,,,,,,,,, -97366,,,0.5404687523841858,2.009705781936645,0.5008000135421753,2.208786725997925,50000.0,0.395300030708313,2.858224391937256,10000.0,44577.15295219421,49311.0440621376,44577.15295219421,4725.14670586586,3.7132389545440674,0.0 -97400,0.96465987,4.9260902,,,,,,,,,,,,,, -97500,1.2619737,2.9723663,,,,,,,,,,,,,, -97600,1.19445,3.0730634,,,,,,,,,,,,,, -97700,0.9632222,4.3748183,,,,,,,,,,,,,, -97800,1.153148,3.1238925,,,,,,,,,,,,,, -97900,1.2974113,3.3052642,,,,,,,,,,,,,, -98000,0.8528551,4.5901055,,,,,,,,,,,,,, -98100,1.0190309,3.7105994,,,,,,,,,,,,,, -98200,1.2521012,3.0803084,,,,,,,,,,,,,, -98286,,,0.5715429782867432,1.849209904670716,0.5137400031089783,2.139185667037964,50000.0,0.3963000178337097,2.810023069381714,10000.0,44997.3252222538,49779.63790535927,44997.3252222538,4773.47830247879,3.75544548034668,0.0 -98300,0.98940706,4.669312,,,,,,,,,,,,,, -98400,1.0242186,4.949717,,,,,,,,,,,,,, -98500,1.3016807,3.1722817,,,,,,,,,,,,,, -98600,0.9626403,5.2344513,,,,,,,,,,,,,, -98700,1.0803286,2.9252546,,,,,,,,,,,,,, -98800,1.120796,3.0154212,,,,,,,,,,,,,, -98900,1.219106,2.9846256,,,,,,,,,,,,,, -99000,1.1679637,2.9237847,,,,,,,,,,,,,, -99100,0.8940603,4.334657,,,,,,,,,,,,,, -99200,1.2364287,3.0106544,,,,,,,,,,,,,, -99206,,,0.5505663752555847,1.9494009017944336,0.5107200145721436,2.1389076709747314,50000.0,0.401600033044815,2.7917513847351074,10000.0,45417.43864130974,50241.46690893173,45417.43864130974,4815.0956864357,3.802260398864746,0.0 -99300,1.4037904,3.0761256,,,,,,,,,,,,,, -99400,1.4640741,3.082096,,,,,,,,,,,,,, -99500,1.3312889,2.98204,,,,,,,,,,,,,, -99600,1.206885,2.959001,,,,,,,,,,,,,, -99700,1.1255603,3.0320797,,,,,,,,,,,,,, -99800,1.2315111,3.107036,,,,,,,,,,,,,, -99900,1.4499378,2.9824293,,,,,,,,,,,,,, -100000,0.9737943,4.3378205,,,,,,,,,,,,,, -100100,1.2975343,3.1595879,,,,,,,,,,,,,, -100126,,,0.5661718845367432,1.871565222740173,0.5231800079345703,2.0831315517425537,50000.0,0.4137000143527984,2.7351529598236084,10000.0,45837.39217829704,50710.29930901528,45837.39217829704,4863.8861446380615,3.843090057373047,0.0 -100200,1.1756843,3.0035994,,,,,,,,,,,,,, -100300,1.4339888,3.0172122,,,,,,,,,,,,,, -100400,0.9290052,5.3570833,,,,,,,,,,,,,, -100500,1.1283245,3.4704444,,,,,,,,,,,,,, -100600,1.1560977,3.1754687,,,,,,,,,,,,,, -100700,1.3263398,2.9334176,,,,,,,,,,,,,, -100800,1.3125037,2.92577,,,,,,,,,,,,,, -100900,1.0310105,4.1440015,,,,,,,,,,,,,, -101000,1.2094641,2.9143791,,,,,,,,,,,,,, -101047,,,0.5694140791893005,1.840009093284607,0.5202800035476685,2.0846471786499023,50000.0,0.4107000231742859,2.7359278202056885,10000.0,46257.652686834335,51177.58735990524,46257.652686834335,4910.821338653564,3.887923240661621,0.0 -101100,1.2432996,3.110825,,,,,,,,,,,,,, -101200,1.1675588,3.0412424,,,,,,,,,,,,,, -101300,1.4931298,3.0819197,,,,,,,,,,,,,, -101400,1.1409675,3.1064258,,,,,,,,,,,,,, -101500,1.2983186,2.9961858,,,,,,,,,,,,,, -101600,1.2876567,3.0762992,,,,,,,,,,,,,, -101700,1.1929194,3.1228974,,,,,,,,,,,,,, -101800,1.0739046,3.8793035,,,,,,,,,,,,,, -101900,1.1858051,3.3991385,,,,,,,,,,,,,, -101966,,,0.5561913847923279,1.9271278381347656,0.5159400105476379,2.123799085617065,50000.0,0.4108000099658966,2.763245105743408,10000.0,46677.740731954575,51643.885112285614,46677.740731954575,4956.939247131348,3.932757616043091,0.0 -102000,1.1125618,3.9971132,,,,,,,,,,,,,, -102100,1.1459008,3.0456297,,,,,,,,,,,,,, -102200,1.4057345,3.0381017,,,,,,,,,,,,,, -102300,1.1677233,3.2088523,,,,,,,,,,,,,, -102400,1.173774,2.9105778,,,,,,,,,,,,,, -102500,1.2212428,2.8344567,,,,,,,,,,,,,, -102600,0.9845554,4.9354696,,,,,,,,,,,,,, -102700,1.2835884,2.9023626,,,,,,,,,,,,,, -102800,1.2145785,2.9220214,,,,,,,,,,,,,, -102885,,,0.5678125023841858,1.848743557929993,0.5289799571037292,2.049794435501098,50000.0,0.4234000146389007,2.709082841873169,10000.0,47097.76620817184,52111.843133449554,47097.76620817184,5004.784141540527,3.972791910171509,0.0 -102900,1.3558584,3.0460489,,,,,,,,,,,,,, -103000,1.0626758,4.719903,,,,,,,,,,,,,, -103100,1.214933,2.8751612,,,,,,,,,,,,,, -103200,1.0534799,4.638648,,,,,,,,,,,,,, -103300,1.2032619,2.841954,,,,,,,,,,,,,, -103400,1.0545628,5.434312,,,,,,,,,,,,,, -103500,1.1530665,3.5275364,,,,,,,,,,,,,, -103600,1.2942189,2.9242837,,,,,,,,,,,,,, -103700,1.3627309,2.876524,,,,,,,,,,,,,, -103800,0.9484829,4.784119,,,,,,,,,,,,,, -103806,,,0.5733789205551147,1.824051141738892,0.5295799970626831,2.0383520126342773,50000.0,0.4193000197410583,2.696749448776245,10000.0,47517.937237262726,52581.48854184151,47517.937237262726,5054.160320997238,4.023122072219849,0.0 -103900,1.276192,2.9875953,,,,,,,,,,,,,, -104000,1.1194881,2.9075785,,,,,,,,,,,,,, -104100,1.267236,3.05534,,,,,,,,,,,,,, -104200,1.3095018,2.9502964,,,,,,,,,,,,,, -104300,1.2168298,2.8581996,,,,,,,,,,,,,, -104400,1.251519,2.9214444,,,,,,,,,,,,,, -104500,1.3393679,2.9101305,,,,,,,,,,,,,, -104600,1.3785461,2.8545723,,,,,,,,,,,,,, -104700,1.2151198,2.9307582,,,,,,,,,,,,,, -104725,,,0.567187488079071,1.864588022232056,0.5290799736976624,2.053553581237793,50000.0,0.4147000312805176,2.714286088943481,10000.0,47937.92220687866,53049.274918079376,47937.92220687866,5101.870931148529,4.066912651062012,0.0 -104800,0.94082856,5.2719345,,,,,,,,,,,,,, -104900,1.218662,2.9380257,,,,,,,,,,,,,, -105000,0.9937334,5.3721313,,,,,,,,,,,,,, -105100,1.1465302,2.8317232,,,,,,,,,,,,,, -105200,1.1978012,2.9127457,,,,,,,,,,,,,, -105300,1.3721825,3.2462068,,,,,,,,,,,,,, -105400,1.0441017,3.902047,,,,,,,,,,,,,, -105500,1.3336889,3.037129,,,,,,,,,,,,,, -105600,1.1482061,3.4041085,,,,,,,,,,,,,, -105644,,,0.5702148079872131,1.839844822883606,0.5308600068092346,2.028477191925049,50000.0,0.4223000109195709,2.6632344722747803,10000.0,48358.190257549286,53519.78946995735,48358.190257549286,5152.026482105255,4.110436916351318,0.0 -105700,1.1058528,2.840773,,,,,,,,,,,,,, -105800,1.1129099,4.329137,,,,,,,,,,,,,, -105900,1.073102,3.315499,,,,,,,,,,,,,, -106000,1.2661635,2.9602125,,,,,,,,,,,,,, -106100,1.2745365,2.891253,,,,,,,,,,,,,, -106200,0.88794297,5.3086166,,,,,,,,,,,,,, -106300,1.3065673,2.9060006,,,,,,,,,,,,,, -106400,1.1436771,3.0137608,,,,,,,,,,,,,, -106500,1.2939972,3.008688,,,,,,,,,,,,,, -106563,,,0.5782421827316284,1.855036258697509,0.5310800075531006,2.070237874984741,50000.0,0.4170000255107879,2.710397720336914,10000.0,48778.23556470871,53983.98978614807,48778.23556470871,5196.08935046196,4.1553053855896,0.0 -106600,0.9277867,4.7793174,,,,,,,,,,,,,, -106700,1.0848594,5.278864,,,,,,,,,,,,,, -106800,1.0807171,3.0769944,,,,,,,,,,,,,, -106900,0.9758182,4.2050586,,,,,,,,,,,,,, -107000,1.3078415,2.7814703,,,,,,,,,,,,,, -107100,1.2219467,2.971129,,,,,,,,,,,,,, -107200,1.174968,2.8473177,,,,,,,,,,,,,, -107300,0.9921309,4.625977,,,,,,,,,,,,,, -107400,1.3311051,2.9580374,,,,,,,,,,,,,, -107482,,,0.6131640672683716,1.644121527671814,0.536899983882904,2.003920078277588,50000.0,0.4292000234127044,2.670915603637696,10000.0,49198.48667836189,54452.19374775887,49198.48667836189,5243.95282626152,4.197072982788086,0.0 -107500,1.0747954,4.48985,,,,,,,,,,,,,, -107600,1.0841954,5.3199306,,,,,,,,,,,,,, -107700,1.0560192,3.1392481,,,,,,,,,,,,,, -107800,1.1841773,2.910772,,,,,,,,,,,,,, -107900,1.2403554,2.865233,,,,,,,,,,,,,, -108000,1.3539938,2.880088,,,,,,,,,,,,,, -108100,1.0883855,3.9434383,,,,,,,,,,,,,, -108200,1.228334,5.5168333,,,,,,,,,,,,,, -108300,1.278557,3.0506742,,,,,,,,,,,,,, -108400,1.0049305,5.023299,,,,,,,,,,,,,, -108401,,,0.5776953101158142,1.8307543992996216,0.5416600108146667,2.010225296020508,50000.0,0.4228000342845917,2.666642427444458,10000.0,49619.11729288101,54916.2911529541,49619.11729288101,5287.324712753296,4.244433879852295,0.0 -108500,1.091936,4.025547,,,,,,,,,,,,,, -108600,1.3224809,2.8183155,,,,,,,,,,,,,, -108700,0.9805906,5.3550096,,,,,,,,,,,,,, -108800,1.4332515,2.9195256,,,,,,,,,,,,,, -108900,1.2266834,3.1081758,,,,,,,,,,,,,, -109000,1.2137617,3.4110427,,,,,,,,,,,,,, -109100,1.184898,2.7151017,,,,,,,,,,,,,, -109200,1.1741121,2.996676,,,,,,,,,,,,,, -109300,1.2970991,2.8652039,,,,,,,,,,,,,, -109321,,,0.5839257836341858,1.781617283821106,0.5431399941444397,1.9895758628845213,50000.0,0.429500013589859,2.6396257877349854,10000.0,50039.2133743763,55383.89688038826,50039.2133743763,5334.740091085434,4.290728807449341,0.0 -109400,1.2396706,3.0645516,,,,,,,,,,,,,, -109500,1.3123192,3.1641805,,,,,,,,,,,,,, -109600,1.1019354,3.2387593,,,,,,,,,,,,,, -109700,1.0614182,4.2907014,,,,,,,,,,,,,, -109800,1.0891749,5.397331,,,,,,,,,,,,,, -109900,1.1936804,2.7698643,,,,,,,,,,,,,, -110000,1.2754769,2.8258488,,,,,,,,,,,,,, -110100,1.2494973,2.9198565,,,,,,,,,,,,,, -110200,1.0435714,5.325305,,,,,,,,,,,,,, -110242,,,0.6037304401397705,1.6702401638031006,0.5479999780654907,1.953895568847656,50000.0,0.4374000132083893,2.595966100692749,10000.0,50459.24368357658,55849.68517708778,50459.24368357658,5380.403715848923,4.336857795715332,0.0 -110300,1.2401739,2.9153278,,,,,,,,,,,,,, -110400,1.2313303,3.1086018,,,,,,,,,,,,,, -110500,1.2196571,3.642876,,,,,,,,,,,,,, -110600,0.9601508,4.8128977,,,,,,,,,,,,,, -110700,1.2283684,2.865802,,,,,,,,,,,,,, -110800,1.3135395,2.9442987,,,,,,,,,,,,,, -110900,1.2007364,3.3729632,,,,,,,,,,,,,, -111000,1.3251082,3.0590138,,,,,,,,,,,,,, -111100,1.1718686,2.7072732,,,,,,,,,,,,,, -111161,,,0.5822851657867432,1.7880244255065918,0.5461999773979187,1.9647815227508545,50000.0,0.4337000250816345,2.628989696502685,10000.0,50879.184905052185,56312.04675197601,50879.184905052185,5422.70272564888,4.410296440124512,0.0 -111200,1.3725277,3.0323339,,,,,,,,,,,,,, -111300,1.1798053,3.117937,,,,,,,,,,,,,, -111400,1.1919781,3.2793324,,,,,,,,,,,,,, -111500,1.580728,2.7055502,,,,,,,,,,,,,, -111600,1.386894,2.865489,,,,,,,,,,,,,, -111700,1.181982,2.7793527,,,,,,,,,,,,,, -111800,1.1272354,5.4024415,,,,,,,,,,,,,, -111900,1.034479,3.4647455,,,,,,,,,,,,,, -112000,1.1254728,5.11408,,,,,,,,,,,,,, -112079,,,0.5941405892372131,1.7171932458877563,0.5562199950218201,1.896606683731079,50000.0,0.4411000311374664,2.573200464248657,10000.0,51299.41596865654,56780.80129766464,51299.41596865654,5471.136475563049,4.45197343826294,0.0 -112100,1.2712383,2.8226194,,,,,,,,,,,,,, -112200,1.2739305,2.772532,,,,,,,,,,,,,, -112300,1.4663271,2.7656755,,,,,,,,,,,,,, -112400,1.2493415,2.8388143,,,,,,,,,,,,,, -112500,1.1681451,2.7633827,,,,,,,,,,,,,, -112600,1.272592,3.092722,,,,,,,,,,,,,, -112700,1.2486025,2.7724714,,,,,,,,,,,,,, -112800,1.3048922,2.677987,,,,,,,,,,,,,, -112900,1.3672202,2.8499215,,,,,,,,,,,,,, -112999,,,0.6062109470367432,1.663967847824097,0.555899977684021,1.902474045753479,50000.0,0.4442000091075897,2.5536797046661377,10000.0,51719.630373477936,57248.9131758213,51719.630373477936,5518.939175367355,4.49985933303833,0.0 -113000,1.4196867,2.7416236,,,,,,,,,,,,,, -113100,1.1890672,3.4459069,,,,,,,,,,,,,, -113200,1.540895,2.8932836,,,,,,,,,,,,,, -113300,1.0484539,4.530248,,,,,,,,,,,,,, -113400,1.028975,4.0665164,,,,,,,,,,,,,, -113500,1.2199204,3.4220247,,,,,,,,,,,,,, -113600,1.2168502,2.7055812,,,,,,,,,,,,,, -113700,1.3218035,2.7346067,,,,,,,,,,,,,, -113800,1.094075,3.2542906,,,,,,,,,,,,,, -113900,1.043102,4.2281265,,,,,,,,,,,,,, -113919,,,0.594042956829071,1.7198957204818726,0.5557599663734436,1.9055322408676147,50000.0,0.4449000358581543,2.562439203262329,10000.0,52139.66808462143,57718.53077292442,52139.66808462143,5568.422976732254,4.548093795776367,0.0 -114000,1.2572263,2.9208124,,,,,,,,,,,,,, -114100,1.2703083,2.7791471,,,,,,,,,,,,,, -114200,1.2389957,2.6714032,,,,,,,,,,,,,, -114300,1.26327,2.628156,,,,,,,,,,,,,, -114400,1.0921129,3.905415,,,,,,,,,,,,,, -114500,1.1138688,3.5997057,,,,,,,,,,,,,, -114600,1.4440339,2.7679691,,,,,,,,,,,,,, -114700,1.6224061,2.7712915,,,,,,,,,,,,,, -114800,1.3390006,2.7454183,,,,,,,,,,,,,, -114839,,,0.6001757383346558,1.7225245237350464,0.5551599860191345,1.9314780235290527,50000.0,0.4402000308036804,2.587394952774048,10000.0,52559.89040374756,58184.54452776909,52559.89040374756,5614.12481713295,4.5912158489227295,0.0 -114900,1.3190166,2.757351,,,,,,,,,,,,,, -115000,1.4231483,2.8836374,,,,,,,,,,,,,, -115100,1.3697513,2.7927458,,,,,,,,,,,,,, -115200,1.33763,2.7835422,,,,,,,,,,,,,, -115300,1.3103191,2.8198073,,,,,,,,,,,,,, -115400,1.3381696,2.751584,,,,,,,,,,,,,, -115500,1.087153,4.6099515,,,,,,,,,,,,,, -115600,1.215925,2.6650364,,,,,,,,,,,,,, -115700,1.1928623,3.152288,,,,,,,,,,,,,, -115757,,,0.6089257597923279,1.6470164060592651,0.5604599714279175,1.8817566633224487,50000.0,0.4474000334739685,2.5390818119049072,10000.0,52980.26725935936,58651.52395033837,52980.26725935936,5660.633631229401,4.6373395919799805,0.0 -115800,1.4694167,2.7334168,,,,,,,,,,,,,, -115900,1.3736935,3.020492,,,,,,,,,,,,,, -116000,1.0475987,4.6131988,,,,,,,,,,,,,, -116100,1.3163129,2.9696603,,,,,,,,,,,,,, -116200,1.184498,4.010943,,,,,,,,,,,,,, -116300,1.333096,3.2840662,,,,,,,,,,,,,, -116400,1.0935026,4.8486824,,,,,,,,,,,,,, -116500,1.307485,2.7704904,,,,,,,,,,,,,, -116600,1.2796029,2.8517776,,,,,,,,,,,,,, -116677,,,0.613476574420929,1.6504943370819092,0.5635200142860413,1.876099228858948,50000.0,0.4478000104427337,2.53030014038086,10000.0,53400.54282641411,59118.91864323616,53400.54282641411,5707.661694765091,4.6801183223724365,0.0 -116700,1.1002473,3.818069,,,,,,,,,,,,,, -116800,1.3244188,2.7234561,,,,,,,,,,,,,, -116900,1.2255371,4.744314,,,,,,,,,,,,,, -117000,1.3719552,2.8163743,,,,,,,,,,,,,, -117100,1.1764232,4.1776476,,,,,,,,,,,,,, -117200,1.3300872,2.9031932,,,,,,,,,,,,,, -117300,1.3826644,2.707725,,,,,,,,,,,,,, -117400,1.4876482,2.7219768,,,,,,,,,,,,,, -117500,1.5176105,2.565544,,,,,,,,,,,,,, -117596,,,0.6048437356948853,1.6839405298233032,0.5666399598121643,1.874297022819519,50000.0,0.4523000121116638,2.5318267345428467,10000.0,53820.59474277496,59584.84307575226,53820.59474277496,5753.440539121628,4.725946426391602,0.0 -117600,1.4392583,2.642537,,,,,,,,,,,,,, -117700,1.21755,3.5829623,,,,,,,,,,,,,, -117800,1.1815935,4.9443073,,,,,,,,,,,,,, -117900,1.1078496,4.1430235,,,,,,,,,,,,,, -118000,1.2196987,4.002098,,,,,,,,,,,,,, -118100,1.0561169,5.1583633,,,,,,,,,,,,,, -118200,1.2041284,5.155487,,,,,,,,,,,,,, -118300,1.3016286,2.737076,,,,,,,,,,,,,, -118400,1.1456734,4.6686664,,,,,,,,,,,,,, -118500,1.0782787,4.3498864,,,,,,,,,,,,,, -118516,,,0.617871105670929,1.593552827835083,0.5720199942588806,1.8123306035995483,50000.0,0.4612000286579132,2.4730336666107178,10000.0,54240.80654287338,60053.70967531204,54240.80654287338,5801.996683120728,4.77754282951355,0.0 -118600,1.2258419,3.565341,,,,,,,,,,,,,, -118700,1.2378451,3.754474,,,,,,,,,,,,,, -118800,1.2748593,2.6882954,,,,,,,,,,,,,, -118900,1.2212366,4.1212673,,,,,,,,,,,,,, -119000,1.1554494,4.921511,,,,,,,,,,,,,, -119100,1.6303998,2.7614248,,,,,,,,,,,,,, -119200,1.6480618,2.6861,,,,,,,,,,,,,, -119300,1.1412225,4.7472396,,,,,,,,,,,,,, -119400,1.3542762,2.702033,,,,,,,,,,,,,, -119434,,,0.6364452838897705,1.5332109928131104,0.5678399801254272,1.8640707731246948,50000.0,0.4488000273704529,2.510444164276123,10000.0,54660.86072707176,60517.71708583832,54660.86072707176,5845.860307693481,4.820557355880737,0.0 -119500,1.5207816,2.712691,,,,,,,,,,,,,, -119600,1.3269407,2.996962,,,,,,,,,,,,,, -119700,1.2779328,3.0070302,,,,,,,,,,,,,, -119800,1.1237336,4.892373,,,,,,,,,,,,,, -119900,1.0827558,5.1374903,,,,,,,,,,,,,, -120000,1.3326507,2.8438153,,,,,,,,,,,,,, -120100,1.0405712,5.039383,,,,,,,,,,,,,, -120200,1.2907091,2.7671018,,,,,,,,,,,,,, -120300,1.1399162,4.409767,,,,,,,,,,,,,, -120352,,,0.6158202886581421,1.6390470266342163,0.5762199759483337,1.8413300514221191,50000.0,0.4611000120639801,2.518050193786621,10000.0,55081.26514601708,60984.77663850784,55081.26514601708,5892.424654722214,4.865030288696289,0.0 -120400,1.3339342,2.7187712,,,,,,,,,,,,,, -120500,1.7329884,2.6556337,,,,,,,,,,,,,, -120600,1.0809743,4.543714,,,,,,,,,,,,,, -120700,1.146734,3.2933655,,,,,,,,,,,,,, -120800,1.1587316,4.820593,,,,,,,,,,,,,, -120900,1.3053113,2.7839258,,,,,,,,,,,,,, -121000,1.4444935,2.6371078,,,,,,,,,,,,,, -121100,1.4023452,3.1699786,,,,,,,,,,,,,, -121200,1.1149195,3.6569033,,,,,,,,,,,,,, -121273,,,0.6201562285423279,1.615761399269104,0.5776799917221069,1.8188271522521973,50000.0,0.4592000246047973,2.49786114692688,10000.0,55501.56821870804,61450.70113730431,55501.56821870804,5937.952259302139,4.911031007766724,0.0 -121300,1.3110543,2.8126888,,,,,,,,,,,,,, -121400,1.3780446,2.9725525,,,,,,,,,,,,,, -121500,1.2338105,2.9597874,,,,,,,,,,,,,, -121600,1.1780117,4.457621,,,,,,,,,,,,,, -121700,1.2131529,3.9142013,,,,,,,,,,,,,, -121800,1.6631687,2.7473953,,,,,,,,,,,,,, -121900,1.322618,2.8954673,,,,,,,,,,,,,, -122000,1.3454036,2.675986,,,,,,,,,,,,,, -122100,1.2159356,4.6225753,,,,,,,,,,,,,, -122194,,,0.6444921493530273,1.5002819299697876,0.583139955997467,1.7808510065078735,50000.0,0.4688000082969665,2.44468355178833,10000.0,55921.78595089912,61921.7659611702,55921.78595089912,5988.703502893448,4.95950722694397,0.0 -122200,1.5553432,2.6459842,,,,,,,,,,,,,, -122300,1.3992411,2.7600973,,,,,,,,,,,,,, -122400,1.4463419,2.557971,,,,,,,,,,,,,, -122500,1.1912675,5.000453,,,,,,,,,,,,,, -122600,1.3441137,3.4583602,,,,,,,,,,,,,, -122700,1.3834542,3.532053,,,,,,,,,,,,,, -122800,1.259037,5.1086054,,,,,,,,,,,,,, -122900,1.1107057,4.4643817,,,,,,,,,,,,,, -123000,1.1397204,5.0563636,,,,,,,,,,,,,, -123100,1.3682252,2.7131262,,,,,,,,,,,,,, -123114,,,0.6229882836341858,1.5936886072158811,0.5823000073432922,1.7864725589752195,50000.0,0.463200032711029,2.4657235145568848,10000.0,56341.82621026039,62387.038128614426,56341.82621026039,6033.838894367218,5.008333683013916,0.0 -123200,1.3240714,2.8603692,,,,,,,,,,,,,, -123300,1.324498,2.6572435,,,,,,,,,,,,,, -123400,1.2585825,4.037,,,,,,,,,,,,,, -123500,1.1705915,3.5042634,,,,,,,,,,,,,, -123600,1.5711854,2.5746717,,,,,,,,,,,,,, -123700,1.1545527,4.974187,,,,,,,,,,,,,, -123800,1.2625756,2.7042208,,,,,,,,,,,,,, -123900,1.236097,3.1713204,,,,,,,,,,,,,, -124000,1.6594412,2.6429358,,,,,,,,,,,,,, -124031,,,0.6308007836341858,1.54508376121521,0.5869199633598328,1.7575558423995972,50000.0,0.4722000360488891,2.412415742874145,10000.0,56761.87673950195,62855.65078186989,56761.87673950195,6082.30672454834,5.055681467056274,0.0 -124100,1.2211033,5.1635523,,,,,,,,,,,,,, -124200,1.3493304,2.4904752,,,,,,,,,,,,,, -124300,1.1060779,5.1268187,,,,,,,,,,,,,, -124400,1.3259131,2.8535562,,,,,,,,,,,,,, -124500,1.2484133,3.4608674,,,,,,,,,,,,,, -124600,1.5586177,2.6429722,,,,,,,,,,,,,, -124700,1.2560412,3.765802,,,,,,,,,,,,,, -124800,1.1718323,4.1226215,,,,,,,,,,,,,, -124900,1.3872913,5.2192826,,,,,,,,,,,,,, -124950,,,0.6486718654632568,1.460990309715271,0.5928199887275696,1.732962131500244,50000.0,0.4794000089168548,2.372661590576172,10000.0,57181.95409989357,63317.253702163696,57181.95409989357,6123.739371538162,5.101463317871094,0.0 -125000,1.2013645,3.8298085,,,,,,,,,,,,,, -125100,1.184594,4.593964,,,,,,,,,,,,,, -125200,1.3183243,3.37179,,,,,,,,,,,,,, -125300,1.3930438,2.539943,,,,,,,,,,,,,, -125400,1.2997805,2.95037,,,,,,,,,,,,,, -125500,1.2273208,3.5435305,,,,,,,,,,,,,, -125600,1.416235,2.5365026,,,,,,,,,,,,,, -125700,1.4640396,2.3626616,,,,,,,,,,,,,, -125800,1.1253415,3.6103034,,,,,,,,,,,,,, -125869,,,0.636523425579071,1.5250691175460815,0.5936599969863892,1.7224498987197876,50000.0,0.4811000227928161,2.383357286453247,10000.0,57601.99688029289,63783.74777674675,57601.99688029289,6170.100564241409,5.144175052642822,0.0 -125900,1.5091914,2.6824963,,,,,,,,,,,,,, -126000,1.3833796,2.5312955,,,,,,,,,,,,,, -126100,1.4143633,2.4471912,,,,,,,,,,,,,, -126200,1.410123,3.0273218,,,,,,,,,,,,,, -126300,1.4037005,2.5527048,,,,,,,,,,,,,, -126400,1.4527633,2.5876217,,,,,,,,,,,,,, -126500,1.5458015,2.52762,,,,,,,,,,,,,, -126600,1.3876535,2.538415,,,,,,,,,,,,,, -126700,1.2286365,3.085026,,,,,,,,,,,,,, -126790,,,0.6359765529632568,1.5459275245666504,0.588979959487915,1.755147933959961,50000.0,0.4742000102996826,2.401140689849853,10000.0,58022.25514602661,64248.13990950584,58022.25514602661,6214.136778354645,5.194725275039673,0.0 -126800,1.5173236,2.5017605,,,,,,,,,,,,,, -126900,1.4258957,2.447618,,,,,,,,,,,,,, -127000,1.3130579,2.7659109,,,,,,,,,,,,,, -127100,1.3554531,2.6752408,,,,,,,,,,,,,, -127200,1.2188092,4.112317,,,,,,,,,,,,,, -127300,1.23984,3.1656678,,,,,,,,,,,,,, -127400,1.620316,2.5394292,,,,,,,,,,,,,, -127500,1.564833,2.63959,,,,,,,,,,,,,, -127600,1.7713745,2.6272612,,,,,,,,,,,,,, -127700,1.4842833,2.4305367,,,,,,,,,,,,,, -127710,,,0.6518163681030273,1.4520666599273682,0.600059986114502,1.690841794013977,50000.0,0.4843000173568725,2.344245195388794,10000.0,58442.36844062805,64715.771369457245,58442.36844062805,6261.555802345276,5.246778249740601,0.0 -127800,1.5741068,2.7850194,,,,,,,,,,,,,, -127900,1.226082,4.692438,,,,,,,,,,,,,, -128000,1.539196,2.5253906,,,,,,,,,,,,,, -128100,1.4860888,2.5426161,,,,,,,,,,,,,, -128200,1.4184076,2.9103384,,,,,,,,,,,,,, -128300,1.2489958,4.0709076,,,,,,,,,,,,,, -128400,1.3820634,5.017092,,,,,,,,,,,,,, -128500,1.298603,3.5548983,,,,,,,,,,,,,, -128600,1.4358524,4.921857,,,,,,,,,,,,,, -128630,,,0.6659374833106995,1.415810465812683,0.5946199893951416,1.7279531955718994,50000.0,0.4792000353336334,2.3771519660949707,10000.0,58862.76211476326,65181.24786877632,58862.76211476326,6306.545022726059,5.292807579040527,0.0 -128700,1.5893232,2.6615982,,,,,,,,,,,,,, -128800,1.5947695,2.6221843,,,,,,,,,,,,,, -128900,1.5383377,2.5755663,,,,,,,,,,,,,, -129000,1.3416935,2.7726378,,,,,,,,,,,,,, -129100,1.4859048,2.9130106,,,,,,,,,,,,,, -129200,1.447399,2.4312196,,,,,,,,,,,,,, -129300,1.6697475,2.4747229,,,,,,,,,,,,,, -129400,1.3741362,5.0678263,,,,,,,,,,,,,, -129500,1.2593547,4.7088366,,,,,,,,,,,,,, -129550,,,0.6496288776397705,1.4651296138763428,0.6078599691390991,1.6644974946975708,50000.0,0.4895000159740448,2.340092182159424,10000.0,59283.08157444,65645.18449640274,59283.08157444,6350.070232391357,5.337663650512695,0.0 -129600,1.6056406,2.568981,,,,,,,,,,,,,, -129700,1.537449,2.502575,,,,,,,,,,,,,, -129800,1.3521894,3.496725,,,,,,,,,,,,,, -129900,1.699173,2.4788058,,,,,,,,,,,,,, -130000,1.429069,2.5796027,,,,,,,,,,,,,, -130100,1.3940579,2.4223428,,,,,,,,,,,,,, -130200,1.4984057,2.5407279,,,,,,,,,,,,,, -130300,1.3126287,3.5511618,,,,,,,,,,,,,, -130400,1.4255179,3.1708343,,,,,,,,,,,,,, -130468,,,0.6561523079872131,1.4382413625717163,0.6042400002479553,1.6798323392868042,50000.0,0.4879000186920166,2.341444730758667,10000.0,59703.20712137222,66111.03629493713,59703.20712137222,6395.686897754669,5.386809349060059,0.0 -130500,1.5798063,2.8873856,,,,,,,,,,,,,, -130600,1.3760365,2.5903907,,,,,,,,,,,,,, -130700,1.3084114,4.1722326,,,,,,,,,,,,,, -130800,1.3075275,3.064436,,,,,,,,,,,,,, -130900,1.5847989,2.6087184,,,,,,,,,,,,,, -131000,1.8603085,2.5639505,,,,,,,,,,,,,, -131100,1.5718037,2.595763,,,,,,,,,,,,,, -131200,1.3886595,5.069003,,,,,,,,,,,,,, -131300,1.5345852,3.2443073,,,,,,,,,,,,,, -131389,,,0.6699999570846558,1.370593547821045,0.6080399751663208,1.6591289043426514,50000.0,0.4835000336170196,2.3345508575439453,10000.0,60123.49499297142,66577.5312511921,60123.49499297142,6441.799973964691,5.43379020690918,0.0 -131400,1.340022,4.2587676,,,,,,,,,,,,,, -131500,1.3326566,4.7452946,,,,,,,,,,,,,, -131600,1.6958311,2.4936826,,,,,,,,,,,,,, -131700,1.5649692,2.4864678,,,,,,,,,,,,,, -131800,1.3414096,4.9454546,,,,,,,,,,,,,, -131900,1.3360751,4.9255075,,,,,,,,,,,,,, -132000,1.4503348,3.6063626,,,,,,,,,,,,,, -132100,1.553268,4.0682936,,,,,,,,,,,,,, -132200,1.5649179,2.6022484,,,,,,,,,,,,,, -132300,1.5696692,2.6447377,,,,,,,,,,,,,, -132307,,,0.6594530940055847,1.4391076564788818,0.6137199997901917,1.645622968673706,50000.0,0.4922000169754028,2.302058219909668,10000.0,60543.795784950256,67043.1083316803,60543.795784950256,6486.980547428131,5.481791257858276,0.0 -132400,1.2543508,3.2646556,,,,,,,,,,,,,, -132500,1.4490985,2.480155,,,,,,,,,,,,,, -132600,1.6032721,2.6720634,,,,,,,,,,,,,, -132700,1.4157499,3.2147121,,,,,,,,,,,,,, -132800,1.5181241,2.7060096,,,,,,,,,,,,,, -132900,1.6640614,2.3705888,,,,,,,,,,,,,, -133000,1.7301823,2.4984772,,,,,,,,,,,,,, -133100,1.3526828,3.5117207,,,,,,,,,,,,,, -133200,1.4454803,3.0552902,,,,,,,,,,,,,, -133224,,,0.6636718511581421,1.4158997535705566,0.6115800142288208,1.6641314029693604,50000.0,0.4946000277996063,2.308748245239258,10000.0,60963.91841840744,67511.08133149147,60963.91841840744,6534.7361924648285,5.529500961303711,0.0 -133300,1.4590048,3.1867304,,,,,,,,,,,,,, -133400,1.6322292,2.5019205,,,,,,,,,,,,,, -133500,1.6805495,2.4465258,,,,,,,,,,,,,, -133600,1.7519162,2.7203724,,,,,,,,,,,,,, -133700,1.67889,2.5408607,,,,,,,,,,,,,, -133800,1.6185201,2.405356,,,,,,,,,,,,,, -133900,1.6734447,2.4107206,,,,,,,,,,,,,, -134000,1.258497,3.3235512,,,,,,,,,,,,,, -134100,1.4825435,3.0446494,,,,,,,,,,,,,, -134142,,,0.6744531393051147,1.3474888801574707,0.6168599724769592,1.62572181224823,50000.0,0.4912000298500061,2.289562702178955,10000.0,61384.02610969544,67980.39304852486,61384.02610969544,6583.84783744812,5.574039936065674,0.0 -134200,1.6565846,2.357961,,,,,,,,,,,,,, -134300,1.385304,3.8304956,,,,,,,,,,,,,, -134400,1.4224343,4.371246,,,,,,,,,,,,,, -134500,1.2874125,4.2745266,,,,,,,,,,,,,, -134600,1.2567618,4.64998,,,,,,,,,,,,,, -134700,1.3661379,4.832693,,,,,,,,,,,,,, -134800,1.8011996,2.3596509,,,,,,,,,,,,,, -134900,1.4599816,2.737549,,,,,,,,,,,,,, -135000,1.5121076,2.2074034,,,,,,,,,,,,,, -135061,,,0.6664648056030273,1.4257535934448242,0.6180999875068665,1.6305259466171265,50000.0,0.493800014257431,2.297254800796509,10000.0,61804.2365424633,68448.73585152626,61804.2365424633,6631.8838946819305,5.623882532119751,0.0 -135100,1.6241305,2.4457862,,,,,,,,,,,,,, -135200,1.5398905,2.7719889,,,,,,,,,,,,,, -135300,1.6553428,2.311478,,,,,,,,,,,,,, -135400,1.5172702,3.93056,,,,,,,,,,,,,, -135500,1.4105631,2.6968246,,,,,,,,,,,,,, -135600,1.6861653,2.4773343,,,,,,,,,,,,,, -135700,1.4274626,3.2085202,,,,,,,,,,,,,, -135800,1.5778936,2.4405773,,,,,,,,,,,,,, -135900,1.5666193,2.2971961,,,,,,,,,,,,,, -135975,,,0.6676562428474426,1.3916078805923462,0.6185399889945984,1.6243489980697632,50000.0,0.5,2.278455972671509,10000.0,62224.20766711235,68914.41670441628,62224.20766711235,6677.498216867447,5.672069072723389,0.0 -136000,1.5124524,2.7350516,,,,,,,,,,,,,, -136100,1.4709655,2.9210873,,,,,,,,,,,,,, -136200,1.3181715,4.1727433,,,,,,,,,,,,,, -136300,1.7588792,2.318093,,,,,,,,,,,,,, -136400,1.5872165,4.497219,,,,,,,,,,,,,, -136500,1.5351373,2.653951,,,,,,,,,,,,,, -136600,1.5463474,2.8948915,,,,,,,,,,,,,, -136700,1.6342573,2.355651,,,,,,,,,,,,,, -136800,1.801099,2.395182,,,,,,,,,,,,,, -136889,,,0.6796679496765137,1.3299959897994995,0.6280199885368347,1.5712664127349854,50000.0,0.5021000504493713,2.23734450340271,10000.0,62644.19136214256,69382.04277920723,62644.19136214256,6725.043194055557,5.722181797027588,0.0 -136900,1.3831809,4.3020926,,,,,,,,,,,,,, -137000,1.6967068,2.3840613,,,,,,,,,,,,,, -137100,1.7734692,2.269657,,,,,,,,,,,,,, -137200,1.5967883,2.9694438,,,,,,,,,,,,,, -137300,1.4768025,3.873322,,,,,,,,,,,,,, -137400,1.6451081,2.6127706,,,,,,,,,,,,,, -137500,1.6616559,2.3657026,,,,,,,,,,,,,, -137600,1.6443169,2.9090836,,,,,,,,,,,,,, -137700,1.3786267,4.7310786,,,,,,,,,,,,,, -137800,1.6841229,2.4797714,,,,,,,,,,,,,, -137808,,,0.6735937595367432,1.3464230298995972,0.6296399831771851,1.558292031288147,50000.0,0.5060000419616699,2.2221529483795166,10000.0,63064.48580121994,69850.06992220879,63064.48580121994,6772.678442955017,5.771515846252441,0.0 -137900,1.442984,4.9196267,,,,,,,,,,,,,, -138000,1.4664812,3.4144843,,,,,,,,,,,,,, -138100,1.5844393,2.6123195,,,,,,,,,,,,,, -138200,1.6992869,4.7346168,,,,,,,,,,,,,, -138300,1.4046583,3.461886,,,,,,,,,,,,,, -138400,1.5319601,3.1134071,,,,,,,,,,,,,, -138500,1.6628078,2.5864038,,,,,,,,,,,,,, -138600,1.5492105,2.7068198,,,,,,,,,,,,,, -138700,1.4778334,3.622282,,,,,,,,,,,,,, -138727,,,0.6751171946525574,1.3565669059753418,0.628600001335144,1.5680512189865112,50000.0,0.509600043296814,2.221627950668335,10000.0,63484.58924078941,70318.35100626945,63484.58924078941,6820.7501039505005,5.830377578735352,0.0 -138800,1.6280413,2.524545,,,,,,,,,,,,,, -138900,1.4163572,4.1153727,,,,,,,,,,,,,, -139000,1.5870943,2.2910872,,,,,,,,,,,,,, -139100,1.5751328,2.943845,,,,,,,,,,,,,, -139200,1.6883856,2.4054313,,,,,,,,,,,,,, -139300,1.5489718,2.4572077,,,,,,,,,,,,,, -139400,1.6661748,2.3995347,,,,,,,,,,,,,, -139500,1.6387619,4.4823866,,,,,,,,,,,,,, -139600,1.6865706,2.411702,,,,,,,,,,,,,, -139644,,,0.6910156011581421,1.2635481357574463,0.6337400078773499,1.5286023616790771,50000.0,0.5097000002861023,2.198205947875977,10000.0,63904.95752739906,70785.26147270203,63904.95752739906,6867.18776845932,5.888091802597046,0.0 -139700,1.8063259,2.2460363,,,,,,,,,,,,,, -139800,1.5543337,3.0459387,,,,,,,,,,,,,, -139900,1.7493746,2.3919141,,,,,,,,,,,,,, -140000,1.6783785,2.2656488,,,,,,,,,,,,,, -140100,1.7029229,2.3083367,,,,,,,,,,,,,, -140200,1.7783276,2.4430237,,,,,,,,,,,,,, -140300,1.6554562,2.2589574,,,,,,,,,,,,,, -140400,1.6801375,2.3504393,,,,,,,,,,,,,, -140500,1.8507304,2.3507183,,,,,,,,,,,,,, -140560,,,0.7136914134025574,1.185178518295288,0.6374599933624268,1.5168884992599487,50000.0,0.5170000195503235,2.171229839324951,10000.0,64325.088150024414,71251.98325324059,64325.088150024414,6913.686674833298,5.933101177215576,0.0 -140600,1.626235,3.004953,,,,,,,,,,,,,, -140700,1.9097195,2.306416,,,,,,,,,,,,,, -140800,1.6688937,2.44643,,,,,,,,,,,,,, -140900,1.6273682,3.3424504,,,,,,,,,,,,,, -141000,1.9277675,2.2967255,,,,,,,,,,,,,, -141100,1.720683,2.1544526,,,,,,,,,,,,,, -141200,1.6842732,2.7415879,,,,,,,,,,,,,, -141300,1.7470752,2.2453752,,,,,,,,,,,,,, -141400,1.5210787,3.5646505,,,,,,,,,,,,,, -141479,,,0.6903125047683716,1.27427077293396,0.6414600014686584,1.5009756088256836,50000.0,0.5174000263214111,2.166109800338745,10000.0,64745.282984018326,71718.62048172951,64745.282984018326,6960.035108566284,5.978979587554932,0.0 -141500,1.5016888,4.0958266,,,,,,,,,,,,,, -141600,1.9147862,2.2859955,,,,,,,,,,,,,, -141700,1.741293,2.3056889,,,,,,,,,,,,,, -141800,1.7111952,3.0170128,,,,,,,,,,,,,, -141900,1.901157,2.3462732,,,,,,,,,,,,,, -142000,1.5426949,2.835342,,,,,,,,,,,,,, -142100,1.7488778,2.3437319,,,,,,,,,,,,,, -142200,1.6810083,2.5700626,,,,,,,,,,,,,, -142300,1.8216385,2.4454956,,,,,,,,,,,,,, -142395,,,0.6958788633346558,1.2445529699325562,0.6466599702835083,1.4790695905685425,50000.0,0.5193000435829163,2.146132469177246,10000.0,65165.53315329552,72183.05815005302,65165.53315329552,7004.11513710022,6.038940191268921,0.0 -142400,1.7676798,4.817153,,,,,,,,,,,,,, -142500,1.5165629,3.1535666,,,,,,,,,,,,,, -142600,1.5666754,3.275511,,,,,,,,,,,,,, -142700,1.7523348,2.231467,,,,,,,,,,,,,, -142800,2.07323,2.2945247,,,,,,,,,,,,,, -142900,1.8422215,2.1992567,,,,,,,,,,,,,, -143000,1.6121622,3.5031247,,,,,,,,,,,,,, -143100,1.7713835,4.839401,,,,,,,,,,,,,, -143200,1.498684,3.1135695,,,,,,,,,,,,,, -143300,1.5268204,4.21838,,,,,,,,,,,,,, -143312,,,0.7118749618530273,1.1865516901016235,0.6448000073432922,1.483980417251587,50000.0,0.5238000154495239,2.134950399398804,10000.0,65585.63843154907,72650.63191390038,65585.63843154907,7051.486275434494,6.0886406898498535,0.0 -143400,1.9999924,2.3125076,,,,,,,,,,,,,, -143500,1.608717,2.7867124,,,,,,,,,,,,,, -143600,1.681586,2.7745697,,,,,,,,,,,,,, -143700,1.5820831,2.6660085,,,,,,,,,,,,,, -143800,2.1124136,4.8519244,,,,,,,,,,,,,, -143900,1.5850276,4.0193844,,,,,,,,,,,,,, -144000,2.0094452,2.2351394,,,,,,,,,,,,,, -144100,1.6105847,4.690446,,,,,,,,,,,,,, -144200,1.6017418,3.2263155,,,,,,,,,,,,,, -144228,,,0.6990429759025574,1.2320014238357544,0.6526199579238892,1.4438382387161257,50000.0,0.5264000296592712,2.108189821243286,10000.0,66005.99985575676,73116.48621463776,66005.99985575676,7096.882295846939,6.138426780700684,0.0 -144300,1.8783817,2.2126746,,,,,,,,,,,,,, -144400,1.7934031,2.5208607,,,,,,,,,,,,,, -144500,1.8328445,2.292382,,,,,,,,,,,,,, -144600,1.8002771,2.1019948,,,,,,,,,,,,,, -144700,1.6608213,2.2443452,,,,,,,,,,,,,, -144800,1.7058289,4.3298073,,,,,,,,,,,,,, -144900,1.8806784,2.1821895,,,,,,,,,,,,,, -145000,1.8284018,4.281785,,,,,,,,,,,,,, -145100,1.6620171,3.9996388,,,,,,,,,,,,,, -145147,,,0.7080078125,1.2053229808807373,0.6554399728775024,1.4488352537155151,50000.0,0.5321000218391418,2.10010838508606,10000.0,66425.925065279,73583.85144233704,66425.925065279,7144.226463794708,6.186929225921631,0.0 -145200,1.7644302,3.2683024,,,,,,,,,,,,,, -145300,1.6706982,4.576996,,,,,,,,,,,,,, -145400,1.7751548,2.5056715,,,,,,,,,,,,,, -145500,1.542936,3.3433177,,,,,,,,,,,,,, -145600,2.057492,2.049333,,,,,,,,,,,,,, -145700,1.798095,3.4418046,,,,,,,,,,,,,, -145800,1.9040955,2.7260964,,,,,,,,,,,,,, -145900,1.9696656,2.1779037,,,,,,,,,,,,,, -146000,1.8706032,4.1752954,,,,,,,,,,,,,, -146066,,,0.7203124761581421,1.154877781867981,0.6532799601554871,1.4429513216018677,50000.0,0.5300000309944153,2.0999159812927246,10000.0,66846.09749627113,74052.53604912758,66846.09749627113,7192.638860940933,6.239060163497925,0.0 -146100,1.9273218,2.5020952,,,,,,,,,,,,,, -146200,1.9810832,2.2144547,,,,,,,,,,,,,, -146300,1.6563164,3.5402007,,,,,,,,,,,,,, -146400,1.6741617,2.5487683,,,,,,,,,,,,,, -146500,1.7552336,4.227907,,,,,,,,,,,,,, -146600,1.8898768,3.6972086,,,,,,,,,,,,,, -146700,1.8678312,2.415565,,,,,,,,,,,,,, -146800,1.6571813,4.5123096,,,,,,,,,,,,,, -146900,1.6434852,3.5989377,,,,,,,,,,,,,, -146984,,,0.708789050579071,1.1932083368301392,0.6591599583625793,1.4202784299850464,50000.0,0.5374000072479248,2.074611186981201,10000.0,67266.39178800583,74520.21671199799,67266.39178800583,7239.931909561157,6.285334825515747,0.0 -147000,1.941804,2.190121,,,,,,,,,,,,,, -147100,2.0150535,2.1592278,,,,,,,,,,,,,, -147200,2.175598,2.3664758,,,,,,,,,,,,,, -147300,1.9053718,2.283924,,,,,,,,,,,,,, -147400,1.9195148,2.4654565,,,,,,,,,,,,,, -147500,1.8864479,2.1927142,,,,,,,,,,,,,, -147600,1.7857724,2.3618333,,,,,,,,,,,,,, -147700,1.8970475,2.1825826,,,,,,,,,,,,,, -147800,1.7857045,2.9101284,,,,,,,,,,,,,, -147900,1.7698215,2.1214356,,,,,,,,,,,,,, -147901,,,0.7182226181030273,1.1583360433578491,0.6647399663925171,1.4021943807601929,50000.0,0.5418000221252441,2.045015811920166,10000.0,67686.7302069664,74987.58259272575,67686.7302069664,7286.852725744247,6.344372987747192,0.0 -148000,1.7705065,3.306693,,,,,,,,,,,,,, -148100,2.0378335,2.2996902,,,,,,,,,,,,,, -148200,1.780934,2.0619755,,,,,,,,,,,,,, -148300,1.8944539,2.163702,,,,,,,,,,,,,, -148400,1.9386601,2.2319565,,,,,,,,,,,,,, -148500,1.9088268,2.1739023,,,,,,,,,,,,,, -148600,1.79618,3.7943716,,,,,,,,,,,,,, -148700,2.2939565,2.0901217,,,,,,,,,,,,,, -148800,2.2288017,2.2030287,,,,,,,,,,,,,, -148818,,,0.7275585532188416,1.1043344736099243,0.6648199558258057,1.3845757246017456,50000.0,0.5412000417709351,2.0354931354522705,10000.0,68107.07111549377,75456.06113243103,68107.07111549377,7334.890476465225,6.396288871765137,0.0 -148900,2.0282109,2.6309562,,,,,,,,,,,,,, -149000,1.9144123,4.562434,,,,,,,,,,,,,, -149100,2.0412893,2.110257,,,,,,,,,,,,,, -149200,1.9774574,2.1376622,,,,,,,,,,,,,, -149300,1.9597162,2.0928533,,,,,,,,,,,,,, -149400,2.0846062,1.965497,,,,,,,,,,,,,, -149500,2.1996605,2.2077055,,,,,,,,,,,,,, -149600,2.0471811,2.1692705,,,,,,,,,,,,,, -149700,1.9529718,2.3306353,,,,,,,,,,,,,, -149733,,,0.7219140529632568,1.1348050832748413,0.6680399775505066,1.3798613548278809,50000.0,0.5511000156402588,2.020998001098633,10000.0,68527.06919646263,75920.15477275848,68527.06919646263,7378.8897659778595,6.444957733154297,0.0 -149800,1.8681935,4.6323357,,,,,,,,,,,,,, -149900,1.8698547,3.9988594,,,,,,,,,,,,,, -150000,1.7568725,3.1244001,,,,,,,,,,,,,, -150100,2.188566,2.230291,,,,,,,,,,,,,, -150200,1.8360567,3.8619945,,,,,,,,,,,,,, -150300,2.0142992,2.9240499,,,,,,,,,,,,,, -150400,2.189298,2.3451457,,,,,,,,,,,,,, -150500,2.1263561,2.2187314,,,,,,,,,,,,,, -150600,2.0996797,2.4883847,,,,,,,,,,,,,, -150651,,,0.7286132574081421,1.118777871131897,0.6726199984550476,1.3689404726028442,50000.0,0.5471000075340271,2.0204763412475586,10000.0,68947.41572165489,76390.40303850174,68947.41572165489,7428.695145845413,6.494152307510376,0.0 -150700,2.1083937,2.1087866,,,,,,,,,,,,,, -150800,2.1997738,2.199024,,,,,,,,,,,,,, -150900,2.0133708,2.2588031,,,,,,,,,,,,,, -151000,2.1596415,2.4836648,,,,,,,,,,,,,, -151100,1.8828672,2.3366687,,,,,,,,,,,,,, -151200,2.189814,4.5120106,,,,,,,,,,,,,, -151300,1.8837038,3.6079814,,,,,,,,,,,,,, -151400,2.0564194,4.4475183,,,,,,,,,,,,,, -151500,1.8002752,4.529128,,,,,,,,,,,,,, -151569,,,0.7353515625,1.0825318098068235,0.6772199869155884,1.343247890472412,50000.0,0.5521000027656555,1.9848430156707764,10000.0,69367.38053894043,76857.38411259651,69367.38053894043,7475.611652612686,6.545987844467163,0.0 -151600,1.9691529,2.6949565,,,,,,,,,,,,,, -151700,2.3630188,2.1778514,,,,,,,,,,,,,, -151800,2.1567512,2.0907183,,,,,,,,,,,,,, -151900,1.9006884,2.244368,,,,,,,,,,,,,, -152000,2.3813114,2.133634,,,,,,,,,,,,,, -152100,2.0761018,2.0088673,,,,,,,,,,,,,, -152200,2.0352867,2.3691883,,,,,,,,,,,,,, -152300,2.357433,1.9758899,,,,,,,,,,,,,, -152400,1.9270892,3.6291432,,,,,,,,,,,,,, -152486,,,0.7512304782867432,1.0053696632385254,0.6732400059700012,1.3476532697677612,50000.0,0.5507000088691711,1.9955780506134035,10000.0,69787.35775566101,77324.2224123478,69787.35775566101,7522.364356279373,6.606665372848511,0.0 -152500,2.2759383,2.255582,,,,,,,,,,,,,, -152600,2.1724293,2.0996542,,,,,,,,,,,,,, -152700,1.8783776,4.22618,,,,,,,,,,,,,, -152800,1.9570683,4.0407796,,,,,,,,,,,,,, -152900,2.1616013,2.0660307,,,,,,,,,,,,,, -153000,2.2209268,1.9552269,,,,,,,,,,,,,, -153100,2.123005,1.9367571,,,,,,,,,,,,,, -153200,2.182657,2.0275877,,,,,,,,,,,,,, -153300,1.9139057,4.0627685,,,,,,,,,,,,,, -153400,2.1956375,2.1956904,,,,,,,,,,,,,, -153405,,,0.7347655892372131,1.0813597440719604,0.6788199543952942,1.333372950553894,50000.0,0.54830002784729,1.9810631275177,10000.0,70207.6869328022,77792.1483130455,70207.6869328022,7569.859260797501,6.660884857177734,0.0 -153500,2.143422,2.6692302,,,,,,,,,,,,,, -153600,2.1085563,1.9906793,,,,,,,,,,,,,, -153700,2.1965828,2.081864,,,,,,,,,,,,,, -153800,2.109894,2.4980679,,,,,,,,,,,,,, -153900,2.2495487,2.052672,,,,,,,,,,,,,, -154000,2.2696736,2.0539408,,,,,,,,,,,,,, -154100,2.0125458,3.1383638,,,,,,,,,,,,,, -154200,2.138089,1.9298046,,,,,,,,,,,,,, -154300,2.2121859,2.0460641,,,,,,,,,,,,,, -154324,,,0.7437695264816284,1.0343624353408811,0.6834799647331238,1.3101791143417358,50000.0,0.5578000545501709,1.96280300617218,10000.0,70627.66973996162,78259.50230240822,70627.66973996162,7617.13103890419,6.713019609451294,0.0 -154400,1.96398,3.9109128,,,,,,,,,,,,,, -154500,2.0886607,3.7935014,,,,,,,,,,,,,, -154600,1.990506,2.6210246,,,,,,,,,,,,,, -154700,2.3314738,3.6267166,,,,,,,,,,,,,, -154800,1.9740433,2.550758,,,,,,,,,,,,,, -154900,2.1389194,4.5794616,,,,,,,,,,,,,, -155000,2.2998264,1.9165713,,,,,,,,,,,,,, -155100,2.207261,2.9998074,,,,,,,,,,,,,, -155200,2.30008,2.0759573,,,,,,,,,,,,,, -155243,,,0.7542773485183716,1.0230635404586792,0.6821399927139282,1.3291398286819458,50000.0,0.554900050163269,1.97342312335968,10000.0,71047.77196288109,78727.53687787056,71047.77196288109,7664.961159944534,6.767343282699585,0.0 -155300,2.1700177,4.5132036,,,,,,,,,,,,,, -155400,2.2445455,3.6271133,,,,,,,,,,,,,, -155500,2.1987681,1.9490862,,,,,,,,,,,,,, -155600,2.272458,2.0920286,,,,,,,,,,,,,, -155700,2.3376892,2.1437714,,,,,,,,,,,,,, -155800,2.0622392,3.7038076,,,,,,,,,,,,,, -155900,2.3066354,2.4674122,,,,,,,,,,,,,, -156000,2.112962,4.011956,,,,,,,,,,,,,, -156100,2.0948622,2.2121532,,,,,,,,,,,,,, -156162,,,0.7413867115974426,1.0416297912597656,0.6865999698638916,1.2964109182357788,50000.0,0.5647000074386597,1.930980205535889,10000.0,71467.86416912079,79192.2151761055,71467.86416912079,7709.447716712952,6.819774866104126,0.0 -156200,2.3221037,2.021276,,,,,,,,,,,,,, -156300,2.1915681,2.664349,,,,,,,,,,,,,, -156400,2.1614406,1.8259686,,,,,,,,,,,,,, -156500,2.06933,3.8167186,,,,,,,,,,,,,, -156600,2.0339856,3.5002594,,,,,,,,,,,,,, -156700,2.6321933,4.456592,,,,,,,,,,,,,, -156800,2.2638328,1.9315221,,,,,,,,,,,,,, -156900,2.3557968,3.1525254,,,,,,,,,,,,,, -157000,2.5319505,4.43988,,,,,,,,,,,,,, -157080,,,0.7487695217132568,1.0085744857788086,0.6894800066947937,1.2726929187774658,50000.0,0.5675000548362732,1.918201804161072,10000.0,71887.97418117523,79659.96469688416,71887.97418117523,7756.980162143707,6.879199504852295,0.0 -157100,2.433354,1.893115,,,,,,,,,,,,,, -157200,2.3344045,1.9452409,,,,,,,,,,,,,, -157300,2.0654242,3.508398,,,,,,,,,,,,,, -157400,2.3404484,2.8176348,,,,,,,,,,,,,, -157500,2.2163005,2.9444137,,,,,,,,,,,,,, -157600,2.2665462,2.3759673,,,,,,,,,,,,,, -157700,1.9970783,3.3859622,,,,,,,,,,,,,, -157800,2.3954413,1.8816702,,,,,,,,,,,,,, -157900,2.450881,4.406538,,,,,,,,,,,,,, -157995,,,0.7559570074081421,0.9801447987556458,0.6914199590682983,1.2745487689971924,50000.0,0.5671000480651855,1.9162464141845703,10000.0,72307.93619275093,80124.37371206284,72307.93619275093,7801.326666116714,6.932545900344849,0.0 -158000,1.9783255,3.0776513,,,,,,,,,,,,,, -158100,2.5208237,2.0647917,,,,,,,,,,,,,, -158200,2.4122796,4.39418,,,,,,,,,,,,,, -158300,2.28625,1.9466444,,,,,,,,,,,,,, -158400,2.0022602,3.4443238,,,,,,,,,,,,,, -158500,2.337246,4.3385563,,,,,,,,,,,,,, -158600,2.5921478,2.278992,,,,,,,,,,,,,, -158700,2.5660183,2.0459309,,,,,,,,,,,,,, -158800,2.334017,1.8853736,,,,,,,,,,,,,, -158900,2.394517,1.9118633,,,,,,,,,,,,,, -158915,,,0.7525194883346558,1.00764000415802,0.6936399936676025,1.2615045309066772,50000.0,0.569100022315979,1.889735460281372,10000.0,72728.1712462902,80591.012966156,72728.1712462902,7847.6235818862915,6.992994785308838,0.0 -159000,2.5748398,1.8221626,,,,,,,,,,,,,, -159100,2.3683343,3.6889691,,,,,,,,,,,,,, -159200,2.1504397,4.041788,,,,,,,,,,,,,, -159300,2.439561,1.8983798,,,,,,,,,,,,,, -159400,2.380872,2.0188894,,,,,,,,,,,,,, -159500,2.5495121,2.0636108,,,,,,,,,,,,,, -159600,2.594058,1.9318362,,,,,,,,,,,,,, -159700,2.2827523,3.0662498,,,,,,,,,,,,,, -159800,2.1742487,3.0701723,,,,,,,,,,,,,, -159832,,,0.7587695121765137,0.973080575466156,0.6987999677658081,1.2378002405166626,50000.0,0.5731000304222107,1.876285195350647,10000.0,73148.25042939186,81059.77536559105,73148.25042939186,7896.201631784439,7.052062034606934,0.0 -159900,2.5411277,1.9006233,,,,,,,,,,,,,, -160000,2.6755197,4.0591683,,,,,,,,,,,,,, -160100,2.5616107,1.8671596,,,,,,,,,,,,,, -160200,2.5229135,1.9332961,,,,,,,,,,,,,, -160300,2.4843333,1.9288826,,,,,,,,,,,,,, -160400,2.453848,1.9393232,,,,,,,,,,,,,, -160500,2.530782,1.8903407,,,,,,,,,,,,,, -160600,2.521404,1.8917229,,,,,,,,,,,,,, -160700,2.617598,1.8365766,,,,,,,,,,,,,, -160750,,,0.7638476490974426,0.9550341367721558,0.7001399993896484,1.2401357889175415,50000.0,0.5751000046730042,1.8870264291763303,10000.0,73568.35084533691,81525.65334272385,73568.35084533691,7941.881090164185,7.1026670932769775,0.0 -160800,2.4434807,3.2644424,,,,,,,,,,,,,, -160900,2.2451012,2.5117095,,,,,,,,,,,,,, -161000,2.224232,3.2871907,,,,,,,,,,,,,, -161100,2.3181474,2.1129053,,,,,,,,,,,,,, -161200,2.3726544,2.4020073,,,,,,,,,,,,,, -161300,2.541881,1.9293094,,,,,,,,,,,,,, -161400,2.5354328,3.079342,,,,,,,,,,,,,, -161500,2.5236986,3.8116293,,,,,,,,,,,,,, -161600,2.5434132,1.8684402,,,,,,,,,,,,,, -161669,,,0.7666015625,0.9483379125595092,0.7057600021362305,1.2242356538772583,50000.0,0.5799000263214111,1.860588908195496,10000.0,73988.73741221428,81990.81713628769,73988.73741221428,7986.558121442795,7.155567407608032,0.0 -161700,2.2872748,2.7115788,,,,,,,,,,,,,, -161800,2.553551,4.3341722,,,,,,,,,,,,,, -161900,2.309486,2.4195745,,,,,,,,,,,,,, -162000,2.4570215,3.6319122,,,,,,,,,,,,,, -162100,2.43285,2.3577712,,,,,,,,,,,,,, -162200,2.4132814,1.8851833,,,,,,,,,,,,,, -162300,2.324439,2.2855008,,,,,,,,,,,,,, -162400,2.540097,1.7601383,,,,,,,,,,,,,, -162500,2.6328034,4.2695894,,,,,,,,,,,,,, -162587,,,0.7708789110183716,0.9288029074668884,0.7047399878501892,1.2123870849609375,50000.0,0.5809000134468079,1.8519184589385984,10000.0,74408.8857395649,82460.56594634056,74408.8857395649,8036.057965755463,7.209298849105835,0.0 -162600,2.4954083,2.5320554,,,,,,,,,,,,,, -162700,2.5316124,3.7993505,,,,,,,,,,,,,, -162800,2.299992,2.5456145,,,,,,,,,,,,,, -162900,2.9631453,1.830924,,,,,,,,,,,,,, -163000,2.2352657,3.3608024,,,,,,,,,,,,,, -163100,3.158748,4.3952026,,,,,,,,,,,,,, -163200,2.8669364,1.8873413,,,,,,,,,,,,,, -163300,2.5521436,2.8804243,,,,,,,,,,,,,, -163400,2.815109,2.2174768,,,,,,,,,,,,,, -163500,2.7131088,1.8289559,,,,,,,,,,,,,, -163508,,,0.7753124833106995,0.9027910232543944,0.7118799686431885,1.1837961673736572,50000.0,0.5873000025749207,1.826006293296814,10000.0,74829.21450185776,82926.06588101387,74829.21450185776,8081.126464605331,7.264871120452881,0.0 -163600,2.7157261,1.7639434,,,,,,,,,,,,,, -163700,2.489759,2.062918,,,,,,,,,,,,,, -163800,2.7904308,4.226625,,,,,,,,,,,,,, -163900,2.6810565,4.0896845,,,,,,,,,,,,,, -164000,2.7040622,1.7437451,,,,,,,,,,,,,, -164100,2.7460241,1.7508692,,,,,,,,,,,,,, -164200,2.6802676,2.0724978,,,,,,,,,,,,,, -164300,2.790775,3.469438,,,,,,,,,,,,,, -164400,2.3922882,2.222697,,,,,,,,,,,,,, -164424,,,0.7840234041213989,0.861545979976654,0.7120199799537659,1.1774080991744995,50000.0,0.5870000123977661,1.8180652856826784,10000.0,75249.25192141533,83393.11365270615,75249.25192141533,8128.030569314957,7.324113607406616,0.0 -164500,2.5114136,3.3246353,,,,,,,,,,,,,, -164600,2.657602,1.6424385,,,,,,,,,,,,,, -164700,2.5713375,2.7475915,,,,,,,,,,,,,, -164800,2.6254694,1.7624818,,,,,,,,,,,,,, -164900,2.6374965,2.4560456,,,,,,,,,,,,,, -165000,2.6917884,1.9614449,,,,,,,,,,,,,, -165100,2.4279642,1.8402082,,,,,,,,,,,,,, -165200,3.037385,1.8846608,,,,,,,,,,,,,, -165300,2.6072264,1.8504714,,,,,,,,,,,,,, -165344,,,0.78076171875,0.8917228579521179,0.7150999903678894,1.1734659671783447,50000.0,0.5918000340461731,1.806179404258728,10000.0,75669.19865012169,83860.63411188126,75669.19865012169,8175.50325012207,7.37749457359314,0.0 -165400,2.591984,3.2920327,,,,,,,,,,,,,, -165500,2.8067384,1.8483125,,,,,,,,,,,,,, -165600,2.66021,3.1300645,,,,,,,,,,,,,, -165700,2.9719298,4.046852,,,,,,,,,,,,,, -165800,2.7900224,2.2663727,,,,,,,,,,,,,, -165900,2.6321757,1.7158074,,,,,,,,,,,,,, -166000,2.7886357,1.7688677,,,,,,,,,,,,,, -166100,2.994894,1.8478816,,,,,,,,,,,,,, -166200,2.5065722,2.7252243,,,,,,,,,,,,,, -166265,,,0.7815039157867432,0.8719916939735413,0.7160800099372864,1.1593176126480105,50000.0,0.5913000106811523,1.798190951347351,10000.0,76089.25742650032,84325.38896870613,76089.25742650032,8220.088045358658,7.440832138061523,0.0 -166300,2.664539,2.8610466,,,,,,,,,,,,,, -166400,2.6357086,2.1692593,,,,,,,,,,,,,, -166500,2.7540796,1.6760274,,,,,,,,,,,,,, -166600,2.817394,4.0131197,,,,,,,,,,,,,, -166700,2.8264203,1.7793398,,,,,,,,,,,,,, -166800,2.9966497,1.8886973,,,,,,,,,,,,,, -166900,3.7454433,1.7817049,,,,,,,,,,,,,, -167000,2.8337214,1.739691,,,,,,,,,,,,,, -167100,2.9293532,1.9979751,,,,,,,,,,,,,, -167181,,,0.789843738079071,0.8536946177482605,0.7184199690818787,1.1556397676467896,50000.0,0.5931000113487244,1.7854821681976318,10000.0,76509.34638428688,84792.34748697281,76509.34638428688,8266.85658288002,7.495449066162109,0.0 -167200,2.6006544,2.481557,,,,,,,,,,,,,, -167300,2.6480103,3.3086905,,,,,,,,,,,,,, -167400,2.8182535,1.7716025,,,,,,,,,,,,,, -167500,2.7774568,3.2393248,,,,,,,,,,,,,, -167600,2.786793,1.7716036,,,,,,,,,,,,,, -167700,2.6493196,2.5095885,,,,,,,,,,,,,, -167800,3.1846282,1.8010255,,,,,,,,,,,,,, -167900,3.0997102,1.6870514,,,,,,,,,,,,,, -168000,2.9244509,1.5845431,,,,,,,,,,,,,, -168099,,,0.7865429520606995,0.8519309759140015,0.7182199954986572,1.1389554738998413,50000.0,0.5954000353813171,1.7669183015823364,10000.0,76929.57842612267,85260.30331659317,76929.57842612267,8314.475160121918,7.552535772323608,0.0 -168100,2.7978053,1.8062783,,,,,,,,,,,,,, -168200,3.1851158,1.7297571,,,,,,,,,,,,,, -168300,2.7133543,2.4079053,,,,,,,,,,,,,, -168400,3.1439493,2.5848615,,,,,,,,,,,,,, -168500,2.8017082,2.606505,,,,,,,,,,,,,, -168600,3.077989,1.7217507,,,,,,,,,,,,,, -168700,3.623676,1.7528609,,,,,,,,,,,,,, -168800,3.199271,1.7466909,,,,,,,,,,,,,, -168900,3.0651748,1.6978714,,,,,,,,,,,,,, -169000,3.1127532,1.7374412,,,,,,,,,,,,,, -169019,,,0.7899804711341858,0.8419513702392578,0.7230599522590637,1.1312774419784546,50000.0,0.5969000458717346,1.764918088912964,10000.0,77349.57205319405,85725.11896824837,77349.57205319405,8359.18827176094,7.6130571365356445,0.0 -169100,3.0931633,3.7802699,,,,,,,,,,,,,, -169200,3.1068127,2.016216,,,,,,,,,,,,,, -169300,2.931226,2.7352657,,,,,,,,,,,,,, -169399,,,,,,,,,,,77520.27059221268,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/eval_measurements.csv deleted file mode 100644 index be4460f93..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -28.51941752433777,0.0,40.48073649406433,1,0,40.48073649406433,0.0010000000474974,6.907756805419922,10000,69.00026822090149,0.0007812499534338,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -69.7873125076294,0.0192267894744873,460.5695321559906,848,0,460.5695321559906,0.0215000007301569,6.148784637451172,10000,530.42049741745,0.0295507814735174,5.966935157775879,0.0253199990838766,6.030439376831055,50000 -117.55456948280334,0.0466711521148681,880.5223081111908,1761,0,880.5223081111908,0.054100003093481,5.57318639755249,10000,998.2155048847198,0.0711914077401161,5.351015567779541,0.0685599967837333,5.379694938659668,50000 -164.33563232421875,0.0717246532440185,1300.5238733291626,2678,0,1300.5238733291626,0.0898000076413154,5.094581604003906,10000,1465.0715363025663,0.1262304633855819,4.751564025878906,0.119499996304512,4.812889099121094,50000 -213.6520509719849,0.1013989448547363,1720.9155325889587,3597,0,1720.9155325889587,0.1284000128507614,4.73599910736084,10000,1934.8571481704712,0.1823828071355819,4.224600315093994,0.1660399883985519,4.346518516540527,50000 -259.7257878780365,0.1269831657409668,2141.0058159828186,4516,0,2141.0058159828186,0.1705000102519989,4.336967945098877,10000,2401.094133615494,0.2391015589237213,3.7565841674804688,0.222580000758171,3.8583083152771,50000 -306.6399974822998,0.152900218963623,2561.1740078926086,5431,0,2561.1740078926086,0.1940000057220459,4.203524112701416,10000,2868.2551929950714,0.2705273330211639,3.628831386566162,0.2485599964857101,3.7476320266723633,50000 -353.06326150894165,0.1847636699676513,2981.2767839431763,6348,0,2981.2767839431763,0.2184000164270401,3.982554435729981,10000,3334.8607263565063,0.3121093809604645,3.293789386749268,0.2845599949359894,3.4468889236450195,50000 -401.4614663124085,0.214972972869873,3401.2823746204376,7261,0,3401.2823746204376,0.2422000169754028,3.815594434738159,10000,3803.342059373856,0.3370117247104645,3.1292214393615723,0.3139199912548065,3.2608413696289062,50000 -447.0047791004181,0.2399046421051025,3821.274812936783,8174,0,3821.274812936783,0.2569000124931335,3.737917423248291,10000,4268.951048851013,0.35986328125,3.0289080142974854,0.3312399983406067,3.1762967109680176,50000 -495.13693618774414,0.269237756729126,4241.370996952057,9089,0,4241.370996952057,0.2750000059604645,3.597908735275269,10000,4737.255994081497,0.3876953125,2.8449742794036865,0.3534199893474579,3.0242254734039307,50000 -544.185914516449,0.3009088039398193,4661.446325063705,10003,0,4661.446325063705,0.2870000004768371,3.4828357696533203,10000,5206.459691762924,0.4051952958106994,2.718560218811035,0.3734799921512604,2.8857717514038086,50000 -591.8794357776642,0.3268032073974609,5081.747814178467,10919,0,5081.747814178467,0.3006000220775604,3.390204429626465,10000,5674.528142929077,0.4252538979053497,2.589329719543457,0.3923999965190887,2.761080980300904,50000 -640.3868417739868,0.3585836887359619,5501.858088970184,11832,0,5501.858088970184,0.3098000288009643,3.3450920581817627,10000,6143.224822998047,0.4370703101158142,2.549715757369995,0.398499995470047,2.74328088760376,50000 -688.9637801647186,0.3849976062774658,5921.982423782349,12746,0,5921.982423782349,0.3286000192165375,3.2327558994293213,10000,6611.999529123306,0.482226550579071,2.3162496089935303,0.4213799834251404,2.6229944229125977,50000 -737.8570251464844,0.4148344993591308,6342.175592184067,13662,0,6342.175592184067,0.3323000073432922,3.2250514030456543,10000,7081.163420915604,0.4601757824420929,2.428586721420288,0.4309999942779541,2.577209711074829,50000 -789.2101700305939,0.4447028636932373,6762.429019451141,14578,0,6762.429019451141,0.3360000252723694,3.2194483280181885,10000,7552.847557067871,0.4689843654632568,2.3901267051696777,0.4369599819183349,2.566824197769165,50000 -839.039304971695,0.4735012054443359,7182.634547472,15492,0,7182.634547472,0.3511000275611877,3.085043430328369,10000,8022.958084821701,0.509960949420929,2.1471524238586426,0.4541999995708465,2.4339747428894043,50000 -888.5021407604218,0.5111334323883057,7602.818835020065,16408,0,7602.818835020065,0.3532000184059143,3.1026744842529297,10000,8492.689853906631,0.4842382669448852,2.291804552078247,0.4571599960327148,2.444198608398437,50000 -937.5597274303436,0.5380949974060059,8022.798815488815,17323,0,8022.798815488815,0.3483000099658966,3.109490871429444,10000,8961.801989793777,0.4905664026737213,2.280002355575561,0.4546799957752228,2.458152532577514,50000 -986.0897233486176,0.5650601387023926,8443.062378168106,18239,0,8443.062378168106,0.3587000072002411,3.0502817630767822,10000,9430.669814825058,0.520800769329071,2.135014533996582,0.4687999784946441,2.398402452468872,50000 -1037.8643803596497,0.5954127311706543,8863.191219329834,19155,0,8863.191219329834,0.3728000223636627,2.9553446769714355,10000,9902.65026807785,0.5138086080551147,2.095219135284424,0.4810999929904938,2.275703430175781,50000 -1085.8274364471436,0.6221892833709717,9283.32648420334,20070,0,9283.32648420334,0.3712000250816345,2.999906301498413,10000,10370.82237124443,0.5203906297683716,2.1330671310424805,0.4761599898338318,2.3456084728240967,50000 -1135.6430156230929,0.6531627178192139,9703.52631521225,20987,0,9703.52631521225,0.3790000081062317,2.913668394088745,10000,10840.916332960129,0.5329296588897705,2.013391733169556,0.4882999956607818,2.249237060546875,50000 -1181.5994005203247,0.6843507289886475,10123.515281438828,21901,0,10123.515281438828,0.3831000328063965,2.9034087657928467,10000,11306.940642356873,0.5308398604393005,2.069345474243164,0.4955999851226806,2.256677865982056,50000 -1229.5627224445343,0.7155659198760986,10543.550062179564,22816,0,10543.550062179564,0.3904000222682953,2.871825218200684,10000,11775.017482995989,0.5394726395606995,2.00322699546814,0.5013399720191956,2.20600700378418,50000 -1278.4515182971954,0.7490296363830566,10963.586684465408,23732,0,10963.586684465408,0.398900032043457,2.8280670642852783,10000,12244.024310827255,0.551074206829071,1.951820373535156,0.5073400139808655,2.178410768508911,50000 -1327.3468503952026,0.781226396560669,11383.736985206604,24649,0,11383.736985206604,0.3990000188350677,2.818315029144287,10000,12713.149666309357,0.5817968845367432,1.7977644205093384,0.5110399723052979,2.1437582969665527,50000 -1375.8142375946045,0.8103833198547363,11803.839218378069,25565,0,11803.839218378069,0.4045000076293945,2.760893106460572,10000,13181.796215295792,0.5573828220367432,1.887807011604309,0.5200200080871582,2.082399606704712,50000 -1425.337683916092,0.8424429893493652,12223.86883687973,26481,0,12223.86883687973,0.4079000055789947,2.758775234222412,10000,13651.433849334717,0.5679296851158142,1.8611342906951904,0.5214599967002869,2.0894598960876465,50000 -1475.6100606918335,0.8734757900238037,12644.12185382843,27397,0,12644.12185382843,0.4154000282287597,2.729538917541504,10000,14122.03713798523,0.5859179496765137,1.7362204790115356,0.5262599587440491,2.0381267070770264,50000 -1523.542355298996,0.9049403667449952,13064.183827638626,28313,0,13064.183827638626,0.4115000069141388,2.7372734546661377,10000,14590.11095881462,0.5691796541213989,1.837002515792847,0.5263400077819824,2.0492756366729736,50000 -1572.8207762241364,0.939743995666504,13484.137912511826,29230,0,13484.137912511826,0.4165000319480896,2.705432653427124,10000,15059.425895929337,0.5762890577316284,1.8103874921798704,0.5307799577713013,2.0295510292053223,50000 -1621.401178598404,0.9774858951568604,13904.376311302183,30147,0,13904.376311302183,0.416700005531311,2.7173359394073486,10000,15528.329751729963,0.586718738079071,1.7852870225906372,0.5297799706459045,2.055450677871704,50000 -1669.0584816932678,1.0061118602752686,14324.65103316307,31064,0,14324.65103316307,0.4200000166893005,2.728641986846924,10000,15996.338223218918,0.5789452791213989,1.8326478004455569,0.5329399704933167,2.0538718700408936,50000 -1716.5272045135498,1.0351231098175049,14744.678804397585,31981,0,14744.678804397585,0.4280000329017639,2.649587631225586,10000,16463.911828041077,0.5824804306030273,1.7766609191894531,0.5413399934768677,1.9840885400772093,50000 -1763.92360329628,1.0761663913726809,15164.711621046066,32897,0,15164.711621046066,0.4245000183582306,2.6344099044799805,10000,16931.429044008255,0.5879296660423279,1.7226225137710571,0.5427199602127075,1.9646421670913696,50000 -1811.923983812332,1.1177661418914795,15584.832740068436,33811,0,15584.832740068436,0.4308000206947326,2.6165106296539307,10000,17399.639848709106,0.5936523079872131,1.7037583589553833,0.5494799613952637,1.9261177778244016,50000 -1861.807582378388,1.1519536972045898,16004.785893440248,34727,0,16004.785893440248,0.4324000179767608,2.64608097076416,10000,17869.55842280388,0.5853710770606995,1.762331485748291,0.5461999773979187,1.964341163635254,50000 -1910.7358112335205,1.1812076568603516,16425.0973944664,35643,0,16425.0973944664,0.434000015258789,2.613935708999634,10000,18338.87577271461,0.6008984446525574,1.6964315176010132,0.5519599914550781,1.9376044273376465,50000 -1960.112357378006,1.2128477096557615,16845.29321050644,36555,0,16845.29321050644,0.4384000301361084,2.5619187355041504,10000,18808.52798581124,0.6290624737739563,1.551759958267212,0.5550999641418457,1.8873655796051023,50000 -2010.38028049469,1.2492239475250244,17265.42009329796,37470,0,17265.42009329796,0.4430000185966491,2.575869560241699,10000,19279.00605130196,0.5981054306030273,1.7004003524780271,0.5553799867630005,1.9004456996917725,50000 -2059.4976251125336,1.2838959693908691,17685.49142575264,38384,0,17685.49142575264,0.4426000118255615,2.580732822418213,10000,19748.276788711548,0.5977538824081421,1.6697273254394531,0.5511000156402588,1.906893253326416,50000 -2108.164050579071,1.3266685009002686,18105.790986537933,39298,0,18105.790986537933,0.4438000321388244,2.558645725250244,10000,20217.33277368545,0.6231836080551147,1.5941240787506104,0.5628399848937988,1.888396143913269,50000 -2156.8390045166016,1.3593740463256836,18526.15670633316,40214,0,18526.15670633316,0.4414000213146209,2.5693047046661377,10000,20686.45342946053,0.5969336032867432,1.700569987297058,0.55485999584198,1.911470413208008,50000 -2205.9153735637665,1.389624834060669,18946.13728427887,41130,0,18946.13728427887,0.4456000328063965,2.548941135406494,10000,21155.588027715683,0.6115429401397705,1.6528892517089844,0.5624399781227112,1.895995855331421,50000 -2255.04528427124,1.4197359085083008,19366.324320554733,42044,0,19366.324320554733,0.4542000293731689,2.498311758041382,10000,21624.983068466187,0.6252148151397705,1.5360467433929443,0.5721399784088135,1.8095941543579104,50000 -2304.938737630844,1.450535774230957,19786.38955569268,42957,0,19786.38955569268,0.4485000073909759,2.561943292617798,10000,22095.019829034805,0.6100390553474426,1.6844910383224487,0.5631600022315979,1.9153586626052856,50000 -2352.82309794426,1.4831857681274414,20206.347382307053,43869,0,20206.347382307053,0.4502000212669372,2.551200151443481,10000,22562.94238114357,0.6125780940055847,1.6683449745178225,0.5681599974632263,1.8960041999816888,50000 -2402.4246587753296,1.521988868713379,20626.368554592133,44782,0,20626.368554592133,0.4561000168323517,2.472685098648072,10000,23032.650621652603,0.6306250095367432,1.5366846323013306,0.5754599571228027,1.7981550693511963,50000 -2451.248507976532,1.560218334197998,21046.379149913788,45695,0,21046.379149913788,0.4631000161170959,2.4594175815582275,10000,23501.57047510147,0.6145312190055847,1.6023259162902832,0.5731599926948547,1.807721376419068,50000 -2500.084460258484,1.5997076034545898,21466.459342956543,46611,0,21466.459342956543,0.4531000256538391,2.510585308074951,10000,23970.574808597565,0.6163476705551147,1.598486304283142,0.5700199604034424,1.8413797616958616,50000 -2549.6821115016937,1.631948709487915,21886.38616538048,47521,0,21886.38616538048,0.4595000147819519,2.499596118927002,10000,24440.17867231369,0.629199206829071,1.6050124168395996,0.575760006904602,1.858893871307373,50000 -2597.718656778336,1.6638882160186768,22306.334075450897,48431,0,22306.334075450897,0.4678000211715698,2.401552677154541,10000,24908.24243426323,0.6450585722923279,1.445672631263733,0.5813400149345398,1.7478997707366943,50000 -2646.743688106537,1.7032458782196045,22726.702834129333,49343,0,22726.702834129333,0.4589000344276428,2.4847159385681152,10000,25377.72305703163,0.623046875,1.623903512954712,0.5805599689483643,1.835086703300476,50000 -2696.7211923599243,1.7389581203460691,23146.96598172188,50258,0,23146.96598172188,0.4665000140666961,2.422419309616089,10000,25848.04677391052,0.640917956829071,1.4934781789779663,0.5890600085258484,1.74969220161438,50000 -2744.944550037384,1.774586200714111,23567.253214359283,51173,0,23567.253214359283,0.4625000357627868,2.461416482925415,10000,26316.64104104042,0.654101550579071,1.4827535152435305,0.582260012626648,1.81582260131836,50000 -2792.6781933307648,1.813803672790528,23987.205542325974,52086,0,23987.205542325974,0.4642000198364258,2.44661545753479,10000,26784.412437677383,0.6284374594688416,1.56432044506073,0.5817399621009827,1.7849321365356443,50000 -2841.4261870384216,1.8572685718536377,24407.16034078598,53002,0,24407.16034078598,0.4789000153541565,2.357414960861206,10000,27253.20618200302,0.6423242092132568,1.4505491256713867,0.5980799794197083,1.690264344215393,50000 -2892.447431087494,1.895360231399536,24827.3532075882,53916,0,24827.3532075882,0.4648000299930572,2.465859889984131,10000,27724.505935430527,0.6451367139816284,1.5169159173965454,0.583579957485199,1.813071370124817,50000 -2941.016970396042,1.9336466789245603,25247.272602796555,54832,0,25247.272602796555,0.4722000360488891,2.388838529586792,10000,28193.08148908615,0.6351171731948853,1.4957752227783203,0.5937199592590332,1.7042373418807983,50000 -2991.487948417664,1.971735000610352,25667.20871949196,55740,0,25667.20871949196,0.4680000245571136,2.404808759689331,10000,28663.57295012474,0.64501953125,1.4755982160568235,0.5958200097084045,1.7111175060272217,50000 -3040.205552339554,2.005126714706421,26087.488626241684,56651,0,26087.488626241684,0.4706000089645386,2.3935937881469727,10000,29132.650985717773,0.6513866782188416,1.4567426443099976,0.590999960899353,1.7368868589401243,50000 -3089.9091703891754,2.0419747829437256,26507.458403348923,57560,0,26507.458403348923,0.473000019788742,2.4218482971191406,10000,29602.409747600555,0.6385351419448853,1.554219126701355,0.593239963054657,1.763049840927124,50000 -3140.4162259101868,2.0792157649993896,26927.61052799225,58474,0,26927.61052799225,0.4843000173568725,2.323306083679199,10000,30073.15274357796,0.6503124833106995,1.4144947528839111,0.6011399626731873,1.652499437332153,50000 -3189.0866141319275,2.1189894676208496,27347.89120697975,59389,0,27347.89120697975,0.4761000275611877,2.381978988647461,10000,30542.19105768204,0.6558203101158142,1.472233533859253,0.5989399552345276,1.731550931930542,50000 -3238.5846648216248,2.1544106006622314,27767.98445320129,60301,0,27767.98445320129,0.4716000258922577,2.4132392406463623,10000,31011.86527228356,0.64208984375,1.5133836269378662,0.5927799940109253,1.7447527647018433,50000 -3289.2686598300934,2.187922954559326,28188.03951358795,61215,0,28188.03951358795,0.4771000146865845,2.3871309757232666,10000,31482.68537425995,0.6486132740974426,1.5043102502822876,0.6002399921417236,1.7300899028778076,50000 -3339.475342273712,2.2301175594329834,28607.97818994522,62126,0,28607.97818994522,0.4869000315666199,2.353926658630371,10000,31952.91978526116,0.6597460508346558,1.415103316307068,0.6001200079917908,1.6884820461273191,50000 -3387.2805788517,2.26474666595459,29028.25691366196,63040,0,29028.25691366196,0.4918000102043152,2.2991316318511963,10000,32421.085930347443,0.6842187643051147,1.2707115411758425,0.6073399782180786,1.6327868700027466,50000 -3436.948952436447,2.301134824752808,29448.54456448555,63953,0,29448.54456448555,0.4830000102519989,2.333249092102051,10000,32891.12489771843,0.6486914157867432,1.453892469406128,0.6010400056838989,1.6871942281723022,50000 -3488.2956540584564,2.341980695724488,29868.56832718849,64867,0,29868.56832718849,0.4889000356197357,2.33275842666626,10000,33362.58354306221,0.658203125,1.4299410581588743,0.6074599623680115,1.6725072860717771,50000 -3537.798797369004,2.3904261589050293,30288.65084552765,65780,0,30288.65084552765,0.4874000251293182,2.2988369464874268,10000,33832.26501727104,0.6772069931030273,1.291306734085083,0.6087999939918518,1.6291512250900269,50000 -3586.4557435512543,2.428715467453003,30708.97991228104,66695,0,30708.97991228104,0.4958000183105469,2.307793378829956,10000,34301.335938215256,0.6553320288658142,1.4468097686767578,0.6100599765777588,1.668556571006775,50000 -3635.289438724518,2.4718663692474365,31128.98797106743,67610,0,31128.98797106743,0.486700028181076,2.3031609058380127,10000,34770.267876148224,0.6596874594688416,1.3802225589752195,0.6093400120735168,1.6242049932479858,50000 -3683.48933506012,2.515212059020996,31548.9758489132,68524,0,31548.9758489132,0.4939000308513641,2.2985339164733887,10000,35238.550125837326,0.6707226634025574,1.330780267715454,0.6122999787330627,1.6297461986541748,50000 -3731.7338008880615,2.553818941116333,31968.895731449127,69435,0,31968.895731449127,0.4905000329017639,2.291445732116699,10000,35706.80040049553,0.6629882454872131,1.3963218927383425,0.6138799786567688,1.6306712627410889,50000 -3781.4290795326233,2.589449882507324,32389.219320058823,70350,0,32389.219320058823,0.4910000264644623,2.2989656925201416,10000,36176.90175771713,0.66259765625,1.3925042152404783,0.6094799637794495,1.6426547765731812,50000 -3830.998381853104,2.6295628547668457,32809.368601322174,71265,0,32809.368601322174,0.4864000082015991,2.3102755546569824,10000,36646.70735788345,0.6683593392372131,1.397711157798767,0.6119399666786194,1.6650700569152832,50000 -3880.787905454636,2.671271324157715,33229.415531635284,72179,0,33229.415531635284,0.4981000125408172,2.268369674682617,10000,37116.63294029236,0.6663671731948853,1.3786489963531494,0.6188600063323975,1.6109440326690674,50000 -3929.753345727921,2.7157506942749023,33649.62982439995,73093,0,33649.62982439995,0.4900000095367431,2.293707609176636,10000,37585.90481185913,0.6658984422683716,1.3752154111862185,0.6152799725532532,1.6217827796936035,50000 -3979.105614423752,2.763178586959839,34069.875893354416,74008,0,34069.875893354416,0.4978000223636627,2.266586780548096,10000,38055.59846138954,0.6753124594688416,1.3523850440979004,0.6179999709129333,1.608821153640747,50000 -4028.905200958252,2.799935817718506,34489.79175186157,74919,0,34489.79175186157,0.5057000517845154,2.234017372131348,10000,38525.39829945564,0.695605456829071,1.2618730068206787,0.6238799691200256,1.5817739963531494,50000 -4081.3146035671234,2.8388900756835938,34909.78695511818,75832,0,34909.78695511818,0.4936000108718872,2.2937138080596924,10000,38997.88926529884,0.6705663800239563,1.3746696710586548,0.6190599799156189,1.6183385848999023,50000 -4128.850201368332,2.878529787063598,35329.75003695488,76745,0,35329.75003695488,0.5089000463485718,2.204143762588501,10000,39465.47510480881,0.6803905963897705,1.297563910484314,0.6237199902534485,1.563851237297058,50000 -4179.281141281128,2.9159533977508545,35749.74824357033,77660,0,35749.74824357033,0.5022000074386597,2.245917320251465,10000,39935.9891512394,0.6931836009025574,1.2485311031341553,0.6234999895095825,1.5847680568695068,50000 -4225.807715415955,2.951959609985352,36169.98121476173,78574,0,36169.98121476173,0.508400022983551,2.222673177719116,10000,40402.8313946724,0.6760546565055847,1.3154542446136477,0.6225399971008301,1.5719152688980105,50000 -4273.000194072723,2.9928932189941406,36590.37432670593,79489,0,36590.37432670593,0.5056000351905823,2.23213791847229,10000,40870.50536131859,0.6782616972923279,1.3288729190826416,0.6286799907684326,1.572872281074524,50000 -4321.6484797000885,3.0400874614715576,37010.2838101387,80405,0,37010.2838101387,0.499500036239624,2.244521379470825,10000,41339.1589307785,0.6900194883346558,1.277381420135498,0.6231399774551392,1.5857053995132446,50000 -4371.208312034607,3.0785329341888428,37430.52463960648,81320,0,37430.52463960648,0.5049000382423401,2.2456095218658447,10000,41809.0449860096,0.6727538704872131,1.3480576276779177,0.6234999895095825,1.5841037034988403,50000 -4419.785451173782,3.1283504962921143,37850.4936645031,82232,0,37850.4936645031,0.5116000175476074,2.190577745437622,10000,42277.68804812431,0.6878319978713989,1.2847343683242798,0.6326599717140198,1.532365798950195,50000 -4466.1608464717865,3.169515371322632,38270.45635151863,83145,0,38270.45635151863,0.5128999948501587,2.209243059158325,10000,42744.11470103264,0.6927343606948853,1.2806254625320437,0.6339799761772156,1.558222413063049,50000 -4515.86549949646,3.206382989883423,38690.68273019791,84058,0,38690.68273019791,0.5073000192642212,2.186338186264038,10000,43214.1302447319,0.6813867092132568,1.28969407081604,0.6294599771499634,1.5353686809539795,50000 -4564.33405828476,3.2509865760803223,39111.102942466736,84974,0,39111.102942466736,0.5128000378608704,2.1811633110046387,10000,43683.11135816574,0.688281238079071,1.2811274528503418,0.6375600099563599,1.5175962448120115,50000 -4612.810303688049,3.293998003005981,39531.07405328751,85887,0,39531.07405328751,0.5178000330924988,2.1769161224365234,10000,44151.64874100685,0.6952148079872131,1.2292883396148682,0.6356799602508545,1.519616723060608,50000 -4661.783539533615,3.335458278656006,39951.19725751877,86798,0,39951.19725751877,0.5117000341415405,2.1646652221679688,10000,44620.83441233635,0.6976562142372131,1.220523476600647,0.6376199722290039,1.5006834268569946,50000 -4712.8799839019775,3.3804314136505127,40371.22511386871,87713,0,40371.22511386871,0.5028000473976135,2.264337539672852,10000,45092.05185699463,0.6788867115974426,1.3864643573760986,0.6287400126457214,1.6301475763320925,50000 -4760.05557847023,3.4225516319274902,40791.30937457085,88626,0,40791.30937457085,0.5161000490188599,2.1748321056365967,10000,45559.40066933632,0.696582019329071,1.239540934562683,0.6377399563789368,1.517307162284851,50000 -4809.654201030731,3.460923194885254,41211.301359415054,89538,0,41211.301359415054,0.5152000188827515,2.1656334400177,10000,46029.07649850845,0.71888667345047,1.155856728553772,0.6437000036239624,1.506482481956482,50000 -4858.942010641098,3.504815101623535,41631.52504038811,90449,0,41631.52504038811,0.518500030040741,2.1620571613311768,10000,46498.67860245705,0.6948632597923279,1.2670438289642334,0.6413399577140808,1.5187033414840698,50000 -4907.109016418457,3.544394254684448,42051.71841979027,91365,0,42051.71841979027,0.5216000080108643,2.1140828132629395,10000,46967.12596321106,0.7024804353713989,1.197124719619751,0.6451199650764465,1.4732191562652588,50000 -4955.174675226212,3.591316938400269,42471.94208693504,92280,0,42471.94208693504,0.5205000042915344,2.161663293838501,10000,47435.50911140442,0.7185351252555847,1.1746147871017456,0.6480799913406372,1.501075267791748,50000 -5005.583575248718,3.635996818542481,42892.12707614899,93194,0,42892.12707614899,0.527400016784668,2.1004345417022705,10000,47906.19521903992,0.7062109112739563,1.185657262802124,0.6523799896240234,1.4381099939346311,50000 -5055.447685480118,3.67959189414978,43312.31781864166,94110,0,43312.31781864166,0.5231000185012817,2.142191171646118,10000,48376.341938734055,0.7049023509025574,1.2190285921096802,0.6449199914932251,1.500189185142517,50000 -5105.180654525757,3.722055435180664,43732.4443500042,95025,0,43732.4443500042,0.5199000239372253,2.164458990097046,10000,48846.29187488556,0.7074413895606995,1.209861397743225,0.6448799967765808,1.504482984542847,50000 -5154.415741682053,3.763861179351807,44152.55052447319,95937,0,44152.55052447319,0.5348000526428223,2.0843238830566406,10000,49315.72231268883,0.7051367163658142,1.1937915086746216,0.6563999652862549,1.434855580329895,50000 -5202.185204267502,3.811398506164551,44572.62438511848,96852,0,44572.62438511848,0.5329000353813171,2.126685857772827,10000,49783.660767793655,0.7074413895606995,1.221797227859497,0.6502400040626526,1.4847956895828247,50000 -5253.482377767563,3.8508574962615967,44992.95506215096,97764,0,44992.95506215096,0.5281000137329102,2.0989506244659424,10000,50255.37498688698,0.7198437452316284,1.1401715278625488,0.6556599736213684,1.4324363470077517,50000 -5302.544720649719,3.8982207775115967,45413.15657520294,98676,0,45413.15657520294,0.5307000279426575,2.072648525238037,10000,50724.73337721825,0.7141015529632568,1.160022854804993,0.6577399969100952,1.420377254486084,50000 -5351.394645690918,3.9429714679718018,45833.22182202339,99592,0,45833.22182202339,0.5315000414848328,2.08896541595459,10000,51193.74008560181,0.7101367115974426,1.178592085838318,0.6572999954223633,1.4439600706100464,50000 -5399.48194694519,3.9898531436920166,46253.18224287033,100506,0,46253.18224287033,0.5288000106811523,2.1265082359313965,10000,51661.882147789,0.7159960865974426,1.1765002012252808,0.6542199850082397,1.4638800621032717,50000 -5449.350612878799,4.036839246749878,46673.48130655289,101420,0,46673.48130655289,0.5382000207901001,2.045794010162353,10000,52132.143812179565,0.7417382597923279,1.0204622745513916,0.6606000065803528,1.399212121963501,50000 -5498.46777844429,4.080301761627197,47093.46519160271,102335,0,47093.46519160271,0.5314000248908997,2.095097064971924,10000,52601.3354651928,0.7127734422683716,1.167461633682251,0.6566799879074097,1.4336830377578735,50000 -5547.918229103088,4.1213812828063965,47513.68147063255,103250,0,47513.68147063255,0.5400000214576721,2.0467312335968018,10000,53071.09167742729,0.7226171493530273,1.1213455200195312,0.6646199822425842,1.3939058780670166,50000 -5596.494349956512,4.169813394546509,47933.8608417511,104166,0,47933.8608417511,0.5416000485420227,2.037737846374512,10000,53539.94333600998,0.7356249690055847,1.0597867965698242,0.6630399823188782,1.3899242877960205,50000 -5646.397901058197,4.211780786514282,48353.88178706169,105081,0,48353.88178706169,0.5398000478744507,2.0553183555603027,10000,54009.95732069016,0.7218359112739563,1.1346908807754517,0.6655399799346924,1.3913841247558594,50000 -5698.203535318375,4.257709503173828,48774.124626636505,105995,0,48774.124626636505,0.5421000123023987,2.0250608921051025,10000,54482.09878492355,0.7276366949081421,1.0901107788085938,0.6678199768066406,1.3639365434646606,50000 -5746.058041095734,4.303881406784058,49194.25314736366,106909,0,49194.25314736366,0.5458000302314758,2.0267672538757324,10000,54950.18088555336,0.7373046875,1.0468322038650513,0.6692599654197693,1.3726061582565308,50000 -5794.743673563004,4.347439050674439,49614.54778671265,107821,0,49614.54778671265,0.5473000407218933,2.035310745239258,10000,55419.25144505501,0.7249218821525574,1.1311174631118774,0.6684199571609497,1.385303616523743,50000 -5844.7369022369385,4.391482353210449,50034.81283926964,108735,0,50034.81283926964,0.5497000217437744,2.023720026016236,10000,55889.60102963448,0.7259179353713989,1.0984207391738892,0.6665199995040894,1.3816251754760742,50000 -5893.524285316467,4.434794902801514,50455.170258522034,109650,0,50455.170258522034,0.5544000267982483,1.9914335012435915,10000,56358.83649373055,0.7411913871765137,1.045373558998108,0.6743599772453308,1.3515647649765017,50000 -5943.436125278473,4.477481842041016,50875.44026255608,110565,0,50875.44026255608,0.5525000095367432,2.017176866531372,10000,56829.10870409012,0.7310742139816284,1.099925875663757,0.6731199622154236,1.3631587028503418,50000 -5992.779366493225,4.526298999786377,51295.71018028259,111479,0,51295.71018028259,0.5570000410079956,1.9866653680801392,10000,57298.81796312332,0.7366992235183716,1.0674903392791748,0.6771799921989441,1.3465999364852903,50000 -6040.666308164597,4.571993350982666,51715.62279224396,112394,0,51715.62279224396,0.5523000359535217,1.9630584716796875,10000,57766.71099972725,0.7471093535423279,1.0021445751190186,0.6795799732208252,1.315180420875549,50000 -6088.146703958511,4.616669178009033,52135.81513166428,113310,0,52135.81513166428,0.5523000359535217,2.0030879974365234,10000,58234.475531339645,0.7491015195846558,1.0073623657226562,0.6755200028419495,1.3467522859573364,50000 -6138.448607206345,4.658712863922119,52556.26023578644,114226,0,52556.26023578644,0.5550000071525574,1.981272578239441,10000,58705.31205654144,0.7378710508346558,1.050584316253662,0.6758399605751038,1.3396703004837036,50000 -6189.425545454025,4.704372644424439,52976.29587888718,115140,0,52976.29587888718,0.554900050163269,1.9811760187149048,10000,59176.41819024086,0.7449804544448853,1.0316308736801147,0.6810399889945984,1.3292573690414429,50000 -6240.787206888199,4.749409198760986,53396.448315382,116054,0,53396.448315382,0.5635000467300415,1.9470332860946653,10000,59648.02452993393,0.7593359351158142,0.968052327632904,0.6818599700927734,1.3153448104858398,50000 -6291.655651569367,4.797953367233276,53816.71623015404,116970,0,53816.71623015404,0.5591000318527222,1.9881199598312376,10000,60119.25689291954,0.7439843416213989,1.0527466535568235,0.6826399564743042,1.3293102979660034,50000 -6340.977566003799,4.843794584274292,54236.68958616257,117884,0,54236.68958616257,0.5621000528335571,1.9683139324188232,10000,60588.64545035362,0.748828113079071,1.0334590673446655,0.6802399754524231,1.337364912033081,50000 -6390.563494682312,4.891946077346802,54656.7682967186,118799,0,54656.7682967186,0.5672000050544739,1.9073601961135864,10000,61058.40561413765,0.7620312571525574,0.9354296922683716,0.6877599954605103,1.279233694076538,50000 -6439.0599138736725,4.941753149032593,55076.75383043289,119714,0,55076.75383043289,0.5613000392913818,1.9275920391082764,10000,61526.98484683037,0.7458398342132568,1.0181658267974854,0.6830599904060364,1.300762176513672,50000 -6488.97558426857,4.9889538288116455,55496.69174456597,120629,0,55496.69174456597,0.5610000491142273,1.920379638671875,10000,61996.93345713616,0.7542773485183716,0.980697214603424,0.6851599812507629,1.289842963218689,50000 -6538.339520454407,5.038301229476929,55916.75475072861,121539,0,55916.75475072861,0.5708000063896179,1.9232209920883176,10000,62466.457062006,0.7638280987739563,0.9494880437850952,0.6881799697875977,1.290436625480652,50000 -6588.268809556961,5.083476781845093,56337.04533982277,122454,0,56337.04533982277,0.5659000277519226,1.932502269744873,10000,62936.76914644241,0.753613293170929,1.002234935760498,0.6880599856376648,1.3044852018356323,50000 -6636.103061914444,5.127235651016235,56757.24954080582,123369,0,56757.24954080582,0.5748000144958496,1.90095317363739,10000,63404.89881134033,0.7582421898841858,0.9696675539016724,0.692579984664917,1.262717843055725,50000 -6682.54062128067,5.174166440963745,57177.33669304848,124285,0,57177.33669304848,0.5712000131607056,1.8988127708435056,10000,63871.52213644981,0.7618163824081421,0.9538049101829528,0.6948999762535095,1.2704392671585083,50000 -6733.407159566879,5.218145370483398,57597.292941093445,125198,0,57597.292941093445,0.5733000040054321,1.946867108345032,10000,64342.43574547768,0.7659569978713989,0.96906316280365,0.6935399770736694,1.2955999374389648,50000 -6783.851147174835,5.274590492248535,58017.22869372368,126112,0,58017.22869372368,0.5680000185966492,1.9271174669265747,10000,64812.9195394516,0.7572070360183716,0.9875859618186952,0.6926199793815613,1.279220461845398,50000 -6833.316943883896,5.319833278656006,58437.14878249168,127025,0,58437.14878249168,0.5734000205993652,1.893884301185608,10000,65282.39766430855,0.7670117020606995,0.9423275589942932,0.6972000002861023,1.2566183805465698,50000 -6882.7317888736725,5.368030786514282,58857.462636232376,127939,0,58857.462636232376,0.5772000551223755,1.8826701641082764,10000,65752.2223212719,0.7809179425239563,0.883868932723999,0.6997999548912048,1.2388687133789062,50000 -6929.378044605255,5.412663459777832,59277.4666519165,128854,0,59277.4666519165,0.582800030708313,1.845189332962036,10000,66218.96444773674,0.7710937261581421,0.9130802154541016,0.7034199833869934,1.210724949836731,50000 -6978.882484436035,5.467691898345947,59697.61399292946,129767,0,59697.61399292946,0.5736000537872314,1.8769426345825195,10000,66688.71971225739,0.7718359231948853,0.9090878367424012,0.6992599964141846,1.23419189453125,50000 -7031.3348553180695,5.931267976760864,60117.49275612831,130683,0,60117.49275612831,0.5805000066757202,1.8567651510238647,10000,67161.56207251549,0.7808593511581421,0.8642115592956543,0.7039200067520142,1.218134522438049,50000 -7082.074404239655,5.981815576553345,60537.74565386772,131596,0,60537.74565386772,0.5819000005722046,1.8415201902389529,10000,67632.65162563324,0.7708203196525574,0.8994258642196655,0.7038799524307251,1.2020835876464844,50000 -7132.960758924484,6.033755540847778,60957.78933095932,132511,0,60957.78933095932,0.5835000276565552,1.8380862474441528,10000,68103.68083000183,0.7748632431030273,0.8906111717224121,0.7036799788475037,1.2126256227493286,50000 -7179.321115255356,6.082520008087158,61377.8827316761,133426,0,61377.8827316761,0.5866000056266785,1.816677451133728,10000,68570.23081469536,0.7884570360183716,0.8362293839454651,0.7095800042152405,1.18121600151062,50000 -7229.612575292587,6.1410746574401855,61798.15866804123,134343,0,61798.15866804123,0.5829000473022461,1.8295793533325195,10000,69040.90519189835,0.7745312452316284,0.8990936279296875,0.7057799696922302,1.2014926671981812,50000 -7278.565196990967,6.195879459381104,62218.16717839241,135256,0,62218.16717839241,0.5863000154495239,1.840918064117432,10000,69509.96885418892,0.7787109017372131,0.886769711971283,0.7080000042915344,1.2056972980499268,50000 -7326.843999385834,6.2438578605651855,62638.280061244965,136170,0,62638.280061244965,0.5907000303268433,1.8285341262817385,10000,69978.45606637001,0.7862108945846558,0.8529349565505981,0.7099399566650391,1.196141004562378,50000 -7378.012609243393,6.291548013687134,63058.56229448319,137084,0,63058.56229448319,0.5909000039100647,1.807421922683716,10000,70450.00231456757,0.7843554615974426,0.8620164394378662,0.7106199860572815,1.1773674488067627,50000 -7427.31960606575,6.34608793258667,63478.914189100266,137998,0,63478.914189100266,0.5921000242233276,1.8124998807907104,10000,70919.76289439201,0.7840625047683716,0.8761175274848938,0.7105000019073486,1.1932687759399414,50000 -7477.396207094192,6.399160146713257,63899.1681098938,138913,0,63899.1681098938,0.5958000421524048,1.7992960214614868,10000,71390.19374513626,0.7923241853713989,0.8205782771110535,0.714199960231781,1.164374589920044,50000 -7525.692511558533,6.453774929046631,64319.3515021801,139830,0,64319.3515021801,0.5888000130653381,1.8122637271881104,10000,71858.77561020851,0.8011132478713989,0.8038817644119263,0.7127000093460083,1.1908022165298462,50000 -7575.239279747009,6.506052732467651,64739.52413749695,140740,0,64739.52413749695,0.5986000299453735,1.7915613651275637,10000,72328.59392142296,0.7881835699081421,0.8385329246520996,0.7159799933433533,1.167323350906372,50000 -7622.6402060985565,6.555207014083862,65159.69063377381,141654,0,65159.69063377381,0.5964000225067139,1.7926957607269287,10000,72796.25745105743,0.7936913967132568,0.8237582445144653,0.7181000113487244,1.1573004722595217,50000 -7673.526467323303,6.612675666809082,65579.90997314453,142571,0,65579.90997314453,0.6039000153541565,1.7557858228683472,10000,73267.46773982048,0.8095507621765137,0.75356125831604,0.7208200097084045,1.128256916999817,50000 -7723.234818935394,6.664331912994385,66000.25936365128,143487,0,66000.25936365128,0.6000000238418579,1.7516762018203735,10000,73737.62437415123,0.7930663824081421,0.80852872133255,0.719539999961853,1.132944107055664,50000 -7773.509160041809,6.716897964477539,66420.45049548149,144400,0,66420.45049548149,0.6027000546455383,1.758381962776184,10000,74208.18918466568,0.7961718440055847,0.7973366379737854,0.7209199666976929,1.1386834383010864,50000 -7823.125457286835,6.775168180465698,66840.47810816765,145314,0,66840.47810816765,0.6045000553131104,1.7355223894119265,10000,74677.93926692009,0.80712890625,0.7481192350387573,0.7227999567985535,1.1151829957962036,50000 -7871.883242845535,6.829352855682373,67260.58725810051,146229,0,67260.58725810051,0.5997000336647034,1.7654008865356443,10000,75146.9090578556,0.8002148270606995,0.7895472049713135,0.7212600111961365,1.1343475580215454,50000 -7921.5880670547485,6.884565353393555,67680.65770792961,147140,0,67680.65770792961,0.6002000570297241,1.7653189897537231,10000,75616.78656816483,0.7982617020606995,0.7943770289421082,0.7218599915504456,1.133278250694275,50000 -7970.376499891281,6.934350490570068,68100.57943248749,148051,0,68100.57943248749,0.602400004863739,1.749596118927002,10000,76085.59352064133,0.8085156083106995,0.7679116725921631,0.7251200079917908,1.131921410560608,50000 -8020.678261995316,6.98644495010376,68520.65689635277,148964,0,68520.65689635277,0.6067000031471252,1.739829421043396,10000,76556.07243704796,0.8046875,0.7722216844558716,0.7258399724960327,1.1195939779281616,50000 -8069.858564853668,7.0408594608306885,68940.75487065315,149875,0,68940.75487065315,0.6060000061988831,1.7303481101989746,10000,77025.45209693909,0.8066015243530273,0.7632541656494141,0.7278800010681152,1.111369013786316,50000 -8118.56437587738,7.0947325229644775,69360.97912240028,150790,0,69360.97912240028,0.6104000210762024,1.7081046104431152,10000,77494.48408341408,0.8119726181030273,0.7188385725021362,0.7307599782943726,1.0796045064926147,50000 -8166.598405122757,7.145893335342407,69781.3741095066,151704,0,69781.3741095066,0.6109000444412231,1.7023441791534424,10000,77963.0119497776,0.8213671445846558,0.7049360275268555,0.7328000068664551,1.0796293020248413,50000 -8215.91041135788,7.195410490036011,70201.74805235863,152618,0,70201.74805235863,0.61080002784729,1.709767460823059,10000,78432.79470396042,0.8115234375,0.7339527010917664,0.7320799827575684,1.0806304216384888,50000 -8266.604093313217,7.249720573425293,70622.05784344673,153535,0,70622.05784344673,0.6096000075340271,1.706923484802246,10000,78903.90107440948,0.815722644329071,0.7261427640914917,0.7325199842453003,1.0846915245056152,50000 -8316.71166753769,7.300796031951904,71042.29072260857,154448,0,71042.29072260857,0.61080002784729,1.7289506196975708,10000,79374.33994674683,0.8185155987739563,0.7208164930343628,0.730239987373352,1.097428798675537,50000 -8365.147471904755,7.35153603553772,71462.36037421227,155361,0,71462.36037421227,0.6110000014305115,1.7229865789413452,10000,79842.9433221817,0.8110156059265137,0.7516688108444214,0.731440007686615,1.091639757156372,50000 -8416.536350011826,7.401159286499023,71882.66050243378,156274,0,71882.66050243378,0.6121000051498413,1.7028956413269043,10000,80314.7297205925,0.8199804425239563,0.7003152370452881,0.7363599538803101,1.0666792392730713,50000 -8466.652791976929,7.462125539779663,72303.01607465744,157186,0,72303.01607465744,0.6162000298500061,1.7018332481384275,10000,80785.31089067459,0.82289057970047,0.6994317173957825,0.7371799945831299,1.0725938081741333,50000 -8515.557217359543,7.524913787841797,72723.06442546844,158097,0,72723.06442546844,0.6152000427246094,1.7035168409347534,10000,81254.37323331833,0.821582019329071,0.718077540397644,0.738599956035614,1.0715638399124146,50000 -8565.915897130966,7.585332155227661,73143.2449285984,159010,0,73143.2449285984,0.6173000335693359,1.6965481042861938,10000,81725.02041864395,0.8240429759025574,0.699968159198761,0.7378199696540833,1.0641032457351685,50000 -8613.183167696,7.635556221008301,73563.15134334564,159924,0,73563.15134334564,0.6152999997138977,1.6972639560699463,10000,82192.29157710075,0.822558581829071,0.6985819935798645,0.7378399968147278,1.068437933921814,50000 -8660.566656589508,7.696813583374023,73983.04493808746,160838,0,73983.04493808746,0.6131000518798828,1.6841599941253662,10000,82659.67759394646,0.8229296803474426,0.6872890591621399,0.7397199869155884,1.054733395576477,50000 -8710.053017377853,7.754112958908081,74403.15802598,161752,0,74403.15802598,0.615600049495697,1.6832082271575928,10000,83129.38182520866,0.8261327743530273,0.6801848411560059,0.7404199838638306,1.0489463806152344,50000 -8758.655586957932,7.804761648178101,74823.39261484146,162668,0,74823.39261484146,0.6183000206947327,1.6932504177093506,10000,83598.31682682037,0.8266015648841858,0.6813116073608398,0.7404999732971191,1.0605944395065308,50000 -8809.826647281647,7.860961198806763,75243.69893074036,163582,0,75243.69893074036,0.6205000281333923,1.6695793867111206,10000,84069.89703917503,0.82958984375,0.6572279334068298,0.7425599694252014,1.0434695482254028,50000 -8859.568863630295,7.916836023330688,75663.79537057877,164496,0,75663.79537057877,0.6194000244140625,1.6700612306594849,10000,84539.8385951519,0.8274804353713989,0.6675172448158264,0.7411800026893616,1.0367389917373655,50000 -8909.697760820389,7.97200608253479,76083.82977080345,165411,0,76083.82977080345,0.6206000447273254,1.6760276556015017,10000,85010.10476636887,0.8327734470367432,0.6660227179527283,0.741599977016449,1.0521215200424194,50000 -8959.051559209824,8.026431798934937,76504.16225075722,166329,0,76504.16225075722,0.6205000281333923,1.660801887512207,10000,85479.89331364632,0.8364452719688416,0.6335233449935913,0.7430599927902222,1.0303434133529663,50000 -9010.117893695831,8.077922344207764,76924.07387590408,167241,0,76924.07387590408,0.619100034236908,1.6686664819717407,10000,85950.96992588043,0.8333203196525574,0.6550984978675842,0.7444999814033508,1.033787965774536,50000 -9057.402058124542,8.131011486053467,77344.41281080246,168155,0,77344.41281080246,0.6255000233650208,1.6653497219085693,10000,86418.69306540489,0.8345312476158142,0.6509010195732117,0.7447999715805054,1.033661127090454,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/measurements.csv deleted file mode 100644 index 5ee9888d4..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1873 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.34868586,6.907756,,,,,,,,,,,,,, -1,,,0.0007812499534338,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,40.48073649406433,69.00026822090149,40.48073649406433,28.51941752433777,0.0,0.0 -100,0.5142109,6.8781614,,,,,,,,,,,,,, -200,0.54448754,6.8505554,,,,,,,,,,,,,, -300,0.8027285,6.6719804,,,,,,,,,,,,,, -400,1.0152221,6.6013913,,,,,,,,,,,,,, -500,0.70194215,6.6385107,,,,,,,,,,,,,, -600,1.1332301,6.4263487,,,,,,,,,,,,,, -700,1.1584499,6.31006,,,,,,,,,,,,,, -800,1.5335047,6.25318,,,,,,,,,,,,,, -848,,,0.0295507814735174,5.966935157775879,0.0253199990838766,6.030439376831055,50000.0,0.0215000007301569,6.148784637451172,10000.0,460.5695321559906,530.42049741745,460.5695321559906,69.7873125076294,0.0192267894744873,0.0 -900,1.501002,6.188493,,,,,,,,,,,,,, -1000,1.0214701,6.2164764,,,,,,,,,,,,,, -1100,1.3658009,6.3635216,,,,,,,,,,,,,, -1200,1.1695151,5.981964,,,,,,,,,,,,,, -1300,1.0878919,6.4967093,,,,,,,,,,,,,, -1400,1.2808326,6.0027547,,,,,,,,,,,,,, -1500,0.8914128,6.569243,,,,,,,,,,,,,, -1600,1.1576419,5.666967,,,,,,,,,,,,,, -1700,1.1082815,5.721008,,,,,,,,,,,,,, -1761,,,0.0711914077401161,5.351015567779541,0.0685599967837333,5.379694938659668,50000.0,0.054100003093481,5.57318639755249,10000.0,880.5223081111908,998.2155048847198,880.5223081111908,117.55456948280334,0.0466711521148681,0.0 -1800,1.1635188,5.7660446,,,,,,,,,,,,,, -1900,1.1149937,5.637308,,,,,,,,,,,,,, -2000,1.0045043,6.22192,,,,,,,,,,,,,, -2100,0.9291444,5.573498,,,,,,,,,,,,,, -2200,1.1588283,5.582739,,,,,,,,,,,,,, -2300,0.97052777,5.5202026,,,,,,,,,,,,,, -2400,1.3050632,5.4489336,,,,,,,,,,,,,, -2500,0.9549691,6.3708906,,,,,,,,,,,,,, -2600,1.0871239,5.343405,,,,,,,,,,,,,, -2678,,,0.1262304633855819,4.751564025878906,0.119499996304512,4.812889099121094,50000.0,0.0898000076413154,5.094581604003906,10000.0,1300.5238733291626,1465.0715363025663,1300.5238733291626,164.33563232421875,0.0717246532440185,0.0 -2700,0.8798058,5.912314,,,,,,,,,,,,,, -2800,0.87261266,5.3693066,,,,,,,,,,,,,, -2900,1.4286264,5.3438354,,,,,,,,,,,,,, -3000,0.77798,6.0110054,,,,,,,,,,,,,, -3100,1.1469961,5.1375113,,,,,,,,,,,,,, -3200,1.2401259,6.4545,,,,,,,,,,,,,, -3300,1.2144438,5.1853724,,,,,,,,,,,,,, -3400,0.9785791,5.26513,,,,,,,,,,,,,, -3500,1.0412472,5.0993066,,,,,,,,,,,,,, -3597,,,0.1823828071355819,4.224600315093994,0.1660399883985519,4.346518516540527,50000.0,0.1284000128507614,4.73599910736084,10000.0,1720.9155325889587,1934.8571481704712,1720.9155325889587,213.6520509719849,0.1013989448547363,0.0 -3600,0.8642562,5.1696424,,,,,,,,,,,,,, -3700,0.85191345,5.7086124,,,,,,,,,,,,,, -3800,0.8471729,5.4374704,,,,,,,,,,,,,, -3900,1.2229453,4.9588757,,,,,,,,,,,,,, -4000,0.7435623,6.0169487,,,,,,,,,,,,,, -4100,0.91759366,5.253203,,,,,,,,,,,,,, -4200,0.98979735,4.773473,,,,,,,,,,,,,, -4300,0.746094,4.6976447,,,,,,,,,,,,,, -4400,0.6644101,5.753462,,,,,,,,,,,,,, -4500,0.8694307,4.699179,,,,,,,,,,,,,, -4516,,,0.2391015589237213,3.7565841674804688,0.222580000758171,3.8583083152771,50000.0,0.1705000102519989,4.336967945098877,10000.0,2141.0058159828186,2401.094133615494,2141.0058159828186,259.7257878780365,0.1269831657409668,0.0 -4600,0.5663596,6.3097258,,,,,,,,,,,,,, -4700,0.9047514,4.840025,,,,,,,,,,,,,, -4800,1.0465236,5.1530266,,,,,,,,,,,,,, -4900,0.7137991,6.272828,,,,,,,,,,,,,, -5000,0.66509295,5.864607,,,,,,,,,,,,,, -5100,0.90263444,4.480412,,,,,,,,,,,,,, -5200,0.70317507,4.9067883,,,,,,,,,,,,,, -5300,0.954007,4.4917083,,,,,,,,,,,,,, -5400,0.7588194,5.355624,,,,,,,,,,,,,, -5431,,,0.2705273330211639,3.628831386566162,0.2485599964857101,3.7476320266723633,50000.0,0.1940000057220459,4.203524112701416,10000.0,2561.1740078926086,2868.2551929950714,2561.1740078926086,306.6399974822998,0.152900218963623,0.0 -5500,0.8807531,4.725476,,,,,,,,,,,,,, -5600,0.66690904,5.336332,,,,,,,,,,,,,, -5700,0.770791,4.81079,,,,,,,,,,,,,, -5800,0.7656801,4.3355145,,,,,,,,,,,,,, -5900,0.6883767,5.050718,,,,,,,,,,,,,, -6000,0.8626163,4.200429,,,,,,,,,,,,,, -6100,0.68085575,6.136275,,,,,,,,,,,,,, -6200,0.6686095,5.4633355,,,,,,,,,,,,,, -6300,0.58195114,5.8546286,,,,,,,,,,,,,, -6348,,,0.3121093809604645,3.293789386749268,0.2845599949359894,3.4468889236450195,50000.0,0.2184000164270401,3.982554435729981,10000.0,2981.2767839431763,3334.8607263565063,2981.2767839431763,353.06326150894165,0.1847636699676513,0.0 -6400,0.7239191,4.1762896,,,,,,,,,,,,,, -6500,0.75705934,6.0270896,,,,,,,,,,,,,, -6600,0.8793685,4.212399,,,,,,,,,,,,,, -6700,0.757057,4.084973,,,,,,,,,,,,,, -6800,0.61943436,6.0688553,,,,,,,,,,,,,, -6900,0.8977214,4.351362,,,,,,,,,,,,,, -7000,0.8981833,4.31473,,,,,,,,,,,,,, -7100,0.9236326,4.110879,,,,,,,,,,,,,, -7200,0.88361865,5.660464,,,,,,,,,,,,,, -7261,,,0.3370117247104645,3.1292214393615723,0.3139199912548065,3.2608413696289062,50000.0,0.2422000169754028,3.815594434738159,10000.0,3401.2823746204376,3803.342059373856,3401.2823746204376,401.4614663124085,0.214972972869873,0.0 -7300,0.91012496,4.074015,,,,,,,,,,,,,, -7400,0.7084433,6.042463,,,,,,,,,,,,,, -7500,0.718976,4.033258,,,,,,,,,,,,,, -7600,0.8328487,3.9888005,,,,,,,,,,,,,, -7700,0.9095364,4.1598606,,,,,,,,,,,,,, -7800,0.78446525,4.032931,,,,,,,,,,,,,, -7900,0.5581994,5.385499,,,,,,,,,,,,,, -8000,1.4842159,4.133235,,,,,,,,,,,,,, -8100,0.84249294,4.2865686,,,,,,,,,,,,,, -8174,,,0.35986328125,3.0289080142974854,0.3312399983406067,3.1762967109680176,50000.0,0.2569000124931335,3.737917423248291,10000.0,3821.274812936783,4268.951048851013,3821.274812936783,447.0047791004181,0.2399046421051025,0.0 -8200,0.7950268,4.3350186,,,,,,,,,,,,,, -8300,0.7859442,3.9056582,,,,,,,,,,,,,, -8400,0.78875846,5.668187,,,,,,,,,,,,,, -8500,0.7591422,4.3614674,,,,,,,,,,,,,, -8600,0.82069236,3.8693235,,,,,,,,,,,,,, -8700,0.64267045,5.432068,,,,,,,,,,,,,, -8800,0.8347452,3.788368,,,,,,,,,,,,,, -8900,0.61669815,5.4040766,,,,,,,,,,,,,, -9000,0.7082702,6.018121,,,,,,,,,,,,,, -9089,,,0.3876953125,2.8449742794036865,0.3534199893474579,3.0242254734039307,50000.0,0.2750000059604645,3.597908735275269,10000.0,4241.370996952057,4737.255994081497,4241.370996952057,495.13693618774414,0.269237756729126,0.0 -9100,0.56586224,6.0078025,,,,,,,,,,,,,, -9200,0.8591019,3.775715,,,,,,,,,,,,,, -9300,0.62080085,5.011926,,,,,,,,,,,,,, -9400,0.81323844,4.1921096,,,,,,,,,,,,,, -9500,0.7811438,4.5191054,,,,,,,,,,,,,, -9600,0.8689997,3.8230963,,,,,,,,,,,,,, -9700,0.64448285,4.9124675,,,,,,,,,,,,,, -9800,0.7924632,3.7423072,,,,,,,,,,,,,, -9900,0.7797816,3.9956367,,,,,,,,,,,,,, -10000,0.7591843,5.3180313,,,,,,,,,,,,,, -10003,,,0.4051952958106994,2.718560218811035,0.3734799921512604,2.8857717514038086,50000.0,0.2870000004768371,3.4828357696533203,10000.0,4661.446325063705,5206.459691762924,4661.446325063705,544.185914516449,0.3009088039398193,0.0 -10100,0.9276742,3.7244036,,,,,,,,,,,,,, -10200,0.8314373,4.0021887,,,,,,,,,,,,,, -10300,0.98350674,3.7125664,,,,,,,,,,,,,, -10400,1.1427093,3.576374,,,,,,,,,,,,,, -10500,1.0313337,3.6230345,,,,,,,,,,,,,, -10600,0.58161455,5.3813467,,,,,,,,,,,,,, -10700,0.7936643,5.5521727,,,,,,,,,,,,,, -10800,0.88240904,3.6377454,,,,,,,,,,,,,, -10900,0.6973951,4.108605,,,,,,,,,,,,,, -10919,,,0.4252538979053497,2.589329719543457,0.3923999965190887,2.761080980300904,50000.0,0.3006000220775604,3.390204429626465,10000.0,5081.747814178467,5674.528142929077,5081.747814178467,591.8794357776642,0.3268032073974609,0.0 -11000,0.91773015,3.8910658,,,,,,,,,,,,,, -11100,0.848342,3.7974331,,,,,,,,,,,,,, -11200,0.79375994,5.1687665,,,,,,,,,,,,,, -11300,0.8302853,3.5231848,,,,,,,,,,,,,, -11400,0.7724977,4.6581874,,,,,,,,,,,,,, -11500,0.8166686,3.8188162,,,,,,,,,,,,,, -11600,0.8587903,3.8125205,,,,,,,,,,,,,, -11700,0.847305,3.5978856,,,,,,,,,,,,,, -11800,0.8217619,3.6085525,,,,,,,,,,,,,, -11832,,,0.4370703101158142,2.549715757369995,0.398499995470047,2.74328088760376,50000.0,0.3098000288009643,3.3450920581817627,10000.0,5501.858088970184,6143.224822998047,5501.858088970184,640.3868417739868,0.3585836887359619,0.0 -11900,0.96414596,3.5388713,,,,,,,,,,,,,, -12000,0.93664217,4.194087,,,,,,,,,,,,,, -12100,1.0068679,3.5514164,,,,,,,,,,,,,, -12200,0.98559594,3.8230605,,,,,,,,,,,,,, -12300,0.763022,5.814904,,,,,,,,,,,,,, -12400,0.6943849,4.8911147,,,,,,,,,,,,,, -12500,0.65994245,5.5060925,,,,,,,,,,,,,, -12600,0.94607556,3.4575071,,,,,,,,,,,,,, -12700,0.7964989,4.950388,,,,,,,,,,,,,, -12746,,,0.482226550579071,2.3162496089935303,0.4213799834251404,2.6229944229125977,50000.0,0.3286000192165375,3.2327558994293213,10000.0,5921.982423782349,6611.999529123306,5921.982423782349,688.9637801647186,0.3849976062774658,0.0 -12800,0.8485199,4.0327888,,,,,,,,,,,,,, -12900,0.7739713,5.556835,,,,,,,,,,,,,, -13000,0.7088441,4.891612,,,,,,,,,,,,,, -13100,1.0208197,3.5489025,,,,,,,,,,,,,, -13200,0.6662991,5.8397713,,,,,,,,,,,,,, -13300,1.0775669,3.5566292,,,,,,,,,,,,,, -13400,0.8878315,3.4826877,,,,,,,,,,,,,, -13500,1.1141443,3.4165716,,,,,,,,,,,,,, -13600,0.99383664,3.4561195,,,,,,,,,,,,,, -13662,,,0.4601757824420929,2.428586721420288,0.4309999942779541,2.577209711074829,50000.0,0.3323000073432922,3.2250514030456543,10000.0,6342.175592184067,7081.163420915604,6342.175592184067,737.8570251464844,0.4148344993591308,0.0 -13700,1.0895784,3.4137132,,,,,,,,,,,,,, -13800,0.69826996,5.343953,,,,,,,,,,,,,, -13900,1.0454093,3.4418793,,,,,,,,,,,,,, -14000,0.6251384,5.7514324,,,,,,,,,,,,,, -14100,1.006964,3.4311192,,,,,,,,,,,,,, -14200,1.0841686,3.4152458,,,,,,,,,,,,,, -14300,0.9977787,3.418192,,,,,,,,,,,,,, -14400,1.003894,3.520937,,,,,,,,,,,,,, -14500,0.9012773,3.3840845,,,,,,,,,,,,,, -14578,,,0.4689843654632568,2.3901267051696777,0.4369599819183349,2.566824197769165,50000.0,0.3360000252723694,3.2194483280181885,10000.0,6762.429019451141,7552.847557067871,6762.429019451141,789.2101700305939,0.4447028636932373,0.0 -14600,0.9345571,3.7912588,,,,,,,,,,,,,, -14700,0.7283259,5.3532,,,,,,,,,,,,,, -14800,0.80582345,5.272281,,,,,,,,,,,,,, -14900,0.79951805,4.470217,,,,,,,,,,,,,, -15000,0.72902673,4.6867733,,,,,,,,,,,,,, -15100,1.0950327,3.379594,,,,,,,,,,,,,, -15200,0.74302214,5.718819,,,,,,,,,,,,,, -15300,0.9584476,3.3944921,,,,,,,,,,,,,, -15400,0.90833455,3.7147183,,,,,,,,,,,,,, -15492,,,0.509960949420929,2.1471524238586426,0.4541999995708465,2.4339747428894043,50000.0,0.3511000275611877,3.085043430328369,10000.0,7182.634547472,8022.958084821701,7182.634547472,839.039304971695,0.4735012054443359,0.0 -15500,0.95023304,3.4724243,,,,,,,,,,,,,, -15600,0.73205006,4.6890006,,,,,,,,,,,,,, -15700,1.0464573,3.2146215,,,,,,,,,,,,,, -15800,0.94159484,3.2735202,,,,,,,,,,,,,, -15900,1.0920216,3.4198108,,,,,,,,,,,,,, -16000,1.1117669,3.3599863,,,,,,,,,,,,,, -16100,0.70878077,5.3213506,,,,,,,,,,,,,, -16200,0.6641734,5.580277,,,,,,,,,,,,,, -16300,0.9142445,3.9535706,,,,,,,,,,,,,, -16400,0.76613235,4.9280453,,,,,,,,,,,,,, -16408,,,0.4842382669448852,2.291804552078247,0.4571599960327148,2.444198608398437,50000.0,0.3532000184059143,3.1026744842529297,10000.0,7602.818835020065,8492.689853906631,7602.818835020065,888.5021407604218,0.5111334323883057,0.0 -16500,1.0767252,3.1149948,,,,,,,,,,,,,, -16600,1.1488959,3.2414355,,,,,,,,,,,,,, -16700,1.1031014,3.3631523,,,,,,,,,,,,,, -16800,0.71611565,5.2099266,,,,,,,,,,,,,, -16900,1.0323062,3.1636493,,,,,,,,,,,,,, -17000,1.1059797,3.2802763,,,,,,,,,,,,,, -17100,0.9639805,3.4089994,,,,,,,,,,,,,, -17200,1.2095326,3.4920006,,,,,,,,,,,,,, -17300,0.85740024,3.9396243,,,,,,,,,,,,,, -17323,,,0.4905664026737213,2.280002355575561,0.4546799957752228,2.458152532577514,50000.0,0.3483000099658966,3.109490871429444,10000.0,8022.798815488815,8961.801989793777,8022.798815488815,937.5597274303436,0.5380949974060059,0.0 -17400,0.80171025,4.721029,,,,,,,,,,,,,, -17500,0.8079945,5.585706,,,,,,,,,,,,,, -17600,1.0040367,3.3837135,,,,,,,,,,,,,, -17700,0.9903205,3.097848,,,,,,,,,,,,,, -17800,1.0189391,3.444808,,,,,,,,,,,,,, -17900,0.7798584,4.4717975,,,,,,,,,,,,,, -18000,0.7949244,4.061333,,,,,,,,,,,,,, -18100,0.834249,4.788016,,,,,,,,,,,,,, -18200,0.9034598,3.9619522,,,,,,,,,,,,,, -18239,,,0.520800769329071,2.135014533996582,0.4687999784946441,2.398402452468872,50000.0,0.3587000072002411,3.0502817630767822,10000.0,8443.062378168106,9430.669814825058,8443.062378168106,986.0897233486176,0.5650601387023926,0.0 -18300,1.0172844,3.3837698,,,,,,,,,,,,,, -18400,0.87034297,5.7530904,,,,,,,,,,,,,, -18500,0.947013,3.5498865,,,,,,,,,,,,,, -18600,0.9103862,3.4856083,,,,,,,,,,,,,, -18700,1.0105087,3.2023504,,,,,,,,,,,,,, -18800,1.0452958,3.3707263,,,,,,,,,,,,,, -18900,0.95861614,3.4144106,,,,,,,,,,,,,, -19000,1.0925255,3.100561,,,,,,,,,,,,,, -19100,0.9152599,3.4505858,,,,,,,,,,,,,, -19155,,,0.5138086080551147,2.095219135284424,0.4810999929904938,2.275703430175781,50000.0,0.3728000223636627,2.9553446769714355,10000.0,8863.191219329834,9902.65026807785,8863.191219329834,1037.8643803596497,0.5954127311706543,0.0 -19200,0.9053914,5.37709,,,,,,,,,,,,,, -19300,1.0999486,3.243633,,,,,,,,,,,,,, -19400,1.2662843,2.9595852,,,,,,,,,,,,,, -19500,1.2048612,3.251683,,,,,,,,,,,,,, -19600,0.7343118,5.209622,,,,,,,,,,,,,, -19700,1.1372265,3.0905118,,,,,,,,,,,,,, -19800,1.1082795,3.3335013,,,,,,,,,,,,,, -19900,0.95628965,5.216879,,,,,,,,,,,,,, -20000,0.9036643,5.526471,,,,,,,,,,,,,, -20070,,,0.5203906297683716,2.1330671310424805,0.4761599898338318,2.3456084728240967,50000.0,0.3712000250816345,2.999906301498413,10000.0,9283.32648420334,10370.82237124443,9283.32648420334,1085.8274364471436,0.6221892833709717,0.0 -20100,1.1282288,3.2735453,,,,,,,,,,,,,, -20200,0.8721919,4.447613,,,,,,,,,,,,,, -20300,0.89844507,3.747786,,,,,,,,,,,,,, -20400,1.028166,3.1845677,,,,,,,,,,,,,, -20500,0.93583024,3.8230305,,,,,,,,,,,,,, -20600,0.73336536,4.9025865,,,,,,,,,,,,,, -20700,1.1931152,3.1042356,,,,,,,,,,,,,, -20800,1.1024771,3.2590075,,,,,,,,,,,,,, -20900,0.7982279,4.5477257,,,,,,,,,,,,,, -20987,,,0.5329296588897705,2.013391733169556,0.4882999956607818,2.249237060546875,50000.0,0.3790000081062317,2.913668394088745,10000.0,9703.52631521225,10840.916332960129,9703.52631521225,1135.6430156230929,0.6531627178192139,0.0 -21000,0.9786389,3.112227,,,,,,,,,,,,,, -21100,1.145783,2.9288936,,,,,,,,,,,,,, -21200,1.1512016,2.9853055,,,,,,,,,,,,,, -21300,0.96619827,3.0371785,,,,,,,,,,,,,, -21400,0.8008113,5.1240883,,,,,,,,,,,,,, -21500,0.9639876,3.0882165,,,,,,,,,,,,,, -21600,0.7575155,5.0240903,,,,,,,,,,,,,, -21700,1.0537659,3.20544,,,,,,,,,,,,,, -21800,0.9283966,3.358356,,,,,,,,,,,,,, -21900,1.1961566,3.2659028,,,,,,,,,,,,,, -21901,,,0.5308398604393005,2.069345474243164,0.4955999851226806,2.256677865982056,50000.0,0.3831000328063965,2.9034087657928467,10000.0,10123.515281438828,11306.940642356873,10123.515281438828,1181.5994005203247,0.6843507289886475,0.0 -22000,1.0458728,3.4022155,,,,,,,,,,,,,, -22100,1.1502792,3.1526551,,,,,,,,,,,,,, -22200,1.0693685,3.1747704,,,,,,,,,,,,,, -22300,0.98850816,3.533267,,,,,,,,,,,,,, -22400,1.0101767,3.0312254,,,,,,,,,,,,,, -22500,1.021464,2.9343116,,,,,,,,,,,,,, -22600,0.8034259,4.3057075,,,,,,,,,,,,,, -22700,1.0184427,3.0482085,,,,,,,,,,,,,, -22800,0.8543461,5.585076,,,,,,,,,,,,,, -22816,,,0.5394726395606995,2.00322699546814,0.5013399720191956,2.20600700378418,50000.0,0.3904000222682953,2.871825218200684,10000.0,10543.550062179564,11775.017482995989,10543.550062179564,1229.5627224445343,0.7155659198760986,0.0 -22900,0.82785445,5.070528,,,,,,,,,,,,,, -23000,0.78404915,4.498676,,,,,,,,,,,,,, -23100,1.0565845,3.5401506,,,,,,,,,,,,,, -23200,0.92891407,3.8229241,,,,,,,,,,,,,, -23300,1.0999066,3.0198326,,,,,,,,,,,,,, -23400,0.7848013,5.3417435,,,,,,,,,,,,,, -23500,1.1354728,3.3494241,,,,,,,,,,,,,, -23600,1.0229503,2.8976388,,,,,,,,,,,,,, -23700,1.124108,3.1912427,,,,,,,,,,,,,, -23732,,,0.551074206829071,1.951820373535156,0.5073400139808655,2.178410768508911,50000.0,0.398900032043457,2.8280670642852783,10000.0,10963.586684465408,12244.024310827255,10963.586684465408,1278.4515182971954,0.7490296363830566,0.0 -23800,0.945096,3.322178,,,,,,,,,,,,,, -23900,0.8689277,4.011855,,,,,,,,,,,,,, -24000,1.1053694,3.3440218,,,,,,,,,,,,,, -24100,1.0928222,3.041306,,,,,,,,,,,,,, -24200,1.0213602,2.9735827,,,,,,,,,,,,,, -24300,0.9615598,3.8843727,,,,,,,,,,,,,, -24400,0.8296294,4.275031,,,,,,,,,,,,,, -24500,0.9493636,3.5786567,,,,,,,,,,,,,, -24600,1.0907396,3.7121503,,,,,,,,,,,,,, -24649,,,0.5817968845367432,1.7977644205093384,0.5110399723052979,2.1437582969665527,50000.0,0.3990000188350677,2.818315029144287,10000.0,11383.736985206604,12713.149666309357,11383.736985206604,1327.3468503952026,0.781226396560669,0.0 -24700,1.0296334,3.0438478,,,,,,,,,,,,,, -24800,1.0639814,2.8101387,,,,,,,,,,,,,, -24900,0.8619269,4.69606,,,,,,,,,,,,,, -25000,1.2181437,3.0847995,,,,,,,,,,,,,, -25100,0.79940915,5.0390635,,,,,,,,,,,,,, -25200,1.2514515,3.1446264,,,,,,,,,,,,,, -25300,0.94112945,3.6212637,,,,,,,,,,,,,, -25400,1.1579008,2.9169989,,,,,,,,,,,,,, -25500,1.1420277,3.1743245,,,,,,,,,,,,,, -25565,,,0.5573828220367432,1.887807011604309,0.5200200080871582,2.082399606704712,50000.0,0.4045000076293945,2.760893106460572,10000.0,11803.839218378069,13181.796215295792,11803.839218378069,1375.8142375946045,0.8103833198547363,0.0 -25600,1.068293,3.0470166,,,,,,,,,,,,,, -25700,0.926243,4.214747,,,,,,,,,,,,,, -25800,0.9383424,5.673296,,,,,,,,,,,,,, -25900,1.158129,3.0804071,,,,,,,,,,,,,, -26000,0.9112745,3.6068964,,,,,,,,,,,,,, -26100,0.9448797,3.3724127,,,,,,,,,,,,,, -26200,1.3105698,3.1154153,,,,,,,,,,,,,, -26300,1.0494506,2.930254,,,,,,,,,,,,,, -26400,0.8372124,5.036288,,,,,,,,,,,,,, -26481,,,0.5679296851158142,1.8611342906951904,0.5214599967002869,2.0894598960876465,50000.0,0.4079000055789947,2.758775234222412,10000.0,12223.86883687973,13651.433849334717,12223.86883687973,1425.337683916092,0.8424429893493652,0.0 -26500,0.85077107,5.4456053,,,,,,,,,,,,,, -26600,1.2152799,3.1332781,,,,,,,,,,,,,, -26700,1.1981373,2.9240067,,,,,,,,,,,,,, -26800,0.8393666,5.5755672,,,,,,,,,,,,,, -26900,1.1023546,2.9693935,,,,,,,,,,,,,, -27000,1.3086401,2.9196215,,,,,,,,,,,,,, -27100,1.0779897,2.8514214,,,,,,,,,,,,,, -27200,1.0574775,3.5817962,,,,,,,,,,,,,, -27300,0.8269641,4.103298,,,,,,,,,,,,,, -27397,,,0.5859179496765137,1.7362204790115356,0.5262599587440491,2.0381267070770264,50000.0,0.4154000282287597,2.729538917541504,10000.0,12644.12185382843,14122.03713798523,12644.12185382843,1475.6100606918335,0.8734757900238037,0.0 -27400,1.0079305,2.7931163,,,,,,,,,,,,,, -27500,1.0596775,3.162123,,,,,,,,,,,,,, -27600,0.8377082,5.145019,,,,,,,,,,,,,, -27700,1.1140846,2.897668,,,,,,,,,,,,,, -27800,1.0680188,2.9584608,,,,,,,,,,,,,, -27900,1.1513841,3.0514402,,,,,,,,,,,,,, -28000,1.0186763,3.2854629,,,,,,,,,,,,,, -28100,1.1669419,3.033244,,,,,,,,,,,,,, -28200,1.0140586,3.1709495,,,,,,,,,,,,,, -28300,1.0571018,3.0958679,,,,,,,,,,,,,, -28313,,,0.5691796541213989,1.837002515792847,0.5263400077819824,2.0492756366729736,50000.0,0.4115000069141388,2.7372734546661377,10000.0,13064.183827638626,14590.11095881462,13064.183827638626,1523.542355298996,0.9049403667449952,0.0 -28400,0.88417405,4.846215,,,,,,,,,,,,,, -28500,1.1329995,2.9110327,,,,,,,,,,,,,, -28600,0.9047492,3.904615,,,,,,,,,,,,,, -28700,1.2276646,2.8932898,,,,,,,,,,,,,, -28800,0.82765216,4.6627207,,,,,,,,,,,,,, -28900,1.1354941,2.9753976,,,,,,,,,,,,,, -29000,1.1062719,4.729865,,,,,,,,,,,,,, -29100,1.1554493,3.0720396,,,,,,,,,,,,,, -29200,0.96795595,5.378979,,,,,,,,,,,,,, -29230,,,0.5762890577316284,1.8103874921798704,0.5307799577713013,2.0295510292053223,50000.0,0.4165000319480896,2.705432653427124,10000.0,13484.137912511826,15059.425895929337,13484.137912511826,1572.8207762241364,0.939743995666504,0.0 -29300,1.0642409,2.9216383,,,,,,,,,,,,,, -29400,0.9334244,5.5847836,,,,,,,,,,,,,, -29500,1.0422078,2.851348,,,,,,,,,,,,,, -29600,1.2848966,2.8992672,,,,,,,,,,,,,, -29700,1.1550044,2.8960962,,,,,,,,,,,,,, -29800,1.0271136,2.8757248,,,,,,,,,,,,,, -29900,1.062123,5.281368,,,,,,,,,,,,,, -30000,1.0338914,2.8988419,,,,,,,,,,,,,, -30100,1.0589267,2.816493,,,,,,,,,,,,,, -30147,,,0.586718738079071,1.7852870225906372,0.5297799706459045,2.055450677871704,50000.0,0.416700005531311,2.7173359394073486,10000.0,13904.376311302183,15528.329751729963,13904.376311302183,1621.401178598404,0.9774858951568604,0.0 -30200,1.1980855,2.8016477,,,,,,,,,,,,,, -30300,1.0607736,2.9458828,,,,,,,,,,,,,, -30400,1.2065084,2.790854,,,,,,,,,,,,,, -30500,1.0252485,3.6696887,,,,,,,,,,,,,, -30600,1.0888867,2.8391411,,,,,,,,,,,,,, -30700,1.0140728,3.1315348,,,,,,,,,,,,,, -30800,1.1852252,2.925062,,,,,,,,,,,,,, -30900,1.1131748,2.9835124,,,,,,,,,,,,,, -31000,0.91424584,5.214085,,,,,,,,,,,,,, -31064,,,0.5789452791213989,1.8326478004455569,0.5329399704933167,2.0538718700408936,50000.0,0.4200000166893005,2.728641986846924,10000.0,14324.65103316307,15996.338223218918,14324.65103316307,1669.0584816932678,1.0061118602752686,0.0 -31100,0.97622895,5.2775044,,,,,,,,,,,,,, -31200,0.9775498,4.7996073,,,,,,,,,,,,,, -31300,1.144543,2.9301155,,,,,,,,,,,,,, -31400,0.9960738,4.252124,,,,,,,,,,,,,, -31500,1.2178956,2.877213,,,,,,,,,,,,,, -31600,1.211657,2.8309083,,,,,,,,,,,,,, -31700,1.0168245,5.4835505,,,,,,,,,,,,,, -31800,1.1683099,2.8753765,,,,,,,,,,,,,, -31900,0.79655945,5.221083,,,,,,,,,,,,,, -31981,,,0.5824804306030273,1.7766609191894531,0.5413399934768677,1.9840885400772093,50000.0,0.4280000329017639,2.649587631225586,10000.0,14744.678804397585,16463.911828041077,14744.678804397585,1716.5272045135498,1.0351231098175049,0.0 -32000,0.90859264,3.725788,,,,,,,,,,,,,, -32100,1.1209304,2.864372,,,,,,,,,,,,,, -32200,1.243446,3.0237098,,,,,,,,,,,,,, -32300,0.8249627,4.384552,,,,,,,,,,,,,, -32400,1.0358139,3.1523757,,,,,,,,,,,,,, -32500,1.0731543,3.6770005,,,,,,,,,,,,,, -32600,1.1425626,2.9645581,,,,,,,,,,,,,, -32700,1.139523,3.6549168,,,,,,,,,,,,,, -32800,1.078619,4.354964,,,,,,,,,,,,,, -32897,,,0.5879296660423279,1.7226225137710571,0.5427199602127075,1.9646421670913696,50000.0,0.4245000183582306,2.6344099044799805,10000.0,15164.711621046066,16931.429044008255,15164.711621046066,1763.92360329628,1.0761663913726809,0.0 -32900,1.0187136,4.302454,,,,,,,,,,,,,, -33000,1.129648,3.0672455,,,,,,,,,,,,,, -33100,1.0091863,3.663317,,,,,,,,,,,,,, -33200,1.21685,2.8205109,,,,,,,,,,,,,, -33300,1.1035118,5.298444,,,,,,,,,,,,,, -33400,1.1681259,3.708932,,,,,,,,,,,,,, -33500,1.1085818,2.808205,,,,,,,,,,,,,, -33600,1.1086553,2.932771,,,,,,,,,,,,,, -33700,1.1895655,2.8098407,,,,,,,,,,,,,, -33800,1.139204,2.9080126,,,,,,,,,,,,,, -33811,,,0.5936523079872131,1.7037583589553833,0.5494799613952637,1.9261177778244016,50000.0,0.4308000206947326,2.6165106296539307,10000.0,15584.832740068436,17399.639848709106,15584.832740068436,1811.923983812332,1.1177661418914795,0.0 -33900,0.96469516,3.4304526,,,,,,,,,,,,,, -34000,1.0653863,2.6926641,,,,,,,,,,,,,, -34100,0.9783526,3.8193972,,,,,,,,,,,,,, -34200,1.0176507,5.1323776,,,,,,,,,,,,,, -34300,1.0748192,3.2313662,,,,,,,,,,,,,, -34400,0.98669744,3.5415437,,,,,,,,,,,,,, -34500,1.2536894,2.8919358,,,,,,,,,,,,,, -34600,0.928854,4.203685,,,,,,,,,,,,,, -34700,1.146911,2.8342757,,,,,,,,,,,,,, -34727,,,0.5853710770606995,1.762331485748291,0.5461999773979187,1.964341163635254,50000.0,0.4324000179767608,2.64608097076416,10000.0,16004.785893440248,17869.55842280388,16004.785893440248,1861.807582378388,1.1519536972045898,0.0 -34800,0.97189385,4.813777,,,,,,,,,,,,,, -34900,1.0977744,2.869969,,,,,,,,,,,,,, -35000,1.039193,3.3442593,,,,,,,,,,,,,, -35100,1.1744921,2.8307147,,,,,,,,,,,,,, -35200,1.1101806,3.0210717,,,,,,,,,,,,,, -35300,1.1782898,2.7875552,,,,,,,,,,,,,, -35400,0.88254595,4.924054,,,,,,,,,,,,,, -35500,1.2140293,2.9154625,,,,,,,,,,,,,, -35600,0.9296542,3.96428,,,,,,,,,,,,,, -35643,,,0.6008984446525574,1.6964315176010132,0.5519599914550781,1.9376044273376465,50000.0,0.434000015258789,2.613935708999634,10000.0,16425.0973944664,18338.87577271461,16425.0973944664,1910.7358112335205,1.1812076568603516,0.0 -35700,1.2330983,2.79458,,,,,,,,,,,,,, -35800,1.1710054,2.851403,,,,,,,,,,,,,, -35900,1.2059977,2.7636285,,,,,,,,,,,,,, -36000,1.2127889,2.7033837,,,,,,,,,,,,,, -36100,1.0291911,3.1907892,,,,,,,,,,,,,, -36200,0.8795668,4.9840665,,,,,,,,,,,,,, -36300,1.1417346,2.9482656,,,,,,,,,,,,,, -36400,0.8835371,5.0140758,,,,,,,,,,,,,, -36500,1.0839033,2.808816,,,,,,,,,,,,,, -36555,,,0.6290624737739563,1.551759958267212,0.5550999641418457,1.8873655796051023,50000.0,0.4384000301361084,2.5619187355041504,10000.0,16845.29321050644,18808.52798581124,16845.29321050644,1960.112357378006,1.2128477096557615,0.0 -36600,1.1901983,2.769054,,,,,,,,,,,,,, -36700,1.2189027,2.7272818,,,,,,,,,,,,,, -36800,1.0655208,2.6592536,,,,,,,,,,,,,, -36900,1.1786358,3.2302341,,,,,,,,,,,,,, -37000,1.1011636,2.7416651,,,,,,,,,,,,,, -37100,1.0797219,3.0689278,,,,,,,,,,,,,, -37200,1.1168717,2.9862397,,,,,,,,,,,,,, -37300,1.196933,2.860883,,,,,,,,,,,,,, -37400,1.1841857,2.7494478,,,,,,,,,,,,,, -37470,,,0.5981054306030273,1.7004003524780271,0.5553799867630005,1.9004456996917725,50000.0,0.4430000185966491,2.575869560241699,10000.0,17265.42009329796,19279.00605130196,17265.42009329796,2010.38028049469,1.2492239475250244,0.0 -37500,1.0539039,3.3196378,,,,,,,,,,,,,, -37600,0.9568472,5.442728,,,,,,,,,,,,,, -37700,1.0472878,3.350304,,,,,,,,,,,,,, -37800,1.0686052,2.7262592,,,,,,,,,,,,,, -37900,1.1605152,2.8006012,,,,,,,,,,,,,, -38000,1.0995651,3.1430397,,,,,,,,,,,,,, -38100,1.1657506,2.7520897,,,,,,,,,,,,,, -38200,0.9512347,3.5203028,,,,,,,,,,,,,, -38300,1.1048648,2.7415767,,,,,,,,,,,,,, -38384,,,0.5977538824081421,1.6697273254394531,0.5511000156402588,1.906893253326416,50000.0,0.4426000118255615,2.580732822418213,10000.0,17685.49142575264,19748.276788711548,17685.49142575264,2059.4976251125336,1.2838959693908691,0.0 -38400,1.0298952,3.5355914,,,,,,,,,,,,,, -38500,0.84216815,4.210069,,,,,,,,,,,,,, -38600,1.1476732,2.8030798,,,,,,,,,,,,,, -38700,1.110467,3.8290374,,,,,,,,,,,,,, -38800,1.171957,2.7647402,,,,,,,,,,,,,, -38900,1.2159057,2.7101798,,,,,,,,,,,,,, -39000,0.99390167,3.9571307,,,,,,,,,,,,,, -39100,1.1983362,2.7467892,,,,,,,,,,,,,, -39200,1.074518,2.6972456,,,,,,,,,,,,,, -39298,,,0.6231836080551147,1.5941240787506104,0.5628399848937988,1.888396143913269,50000.0,0.4438000321388244,2.558645725250244,10000.0,18105.790986537933,20217.33277368545,18105.790986537933,2108.164050579071,1.3266685009002686,0.0 -39300,1.206632,2.7646484,,,,,,,,,,,,,, -39400,0.91131634,4.3364086,,,,,,,,,,,,,, -39500,1.0752317,5.1981254,,,,,,,,,,,,,, -39600,1.0747969,2.9280918,,,,,,,,,,,,,, -39700,1.0305784,3.0318606,,,,,,,,,,,,,, -39800,1.1204281,2.715188,,,,,,,,,,,,,, -39900,0.98674494,5.098997,,,,,,,,,,,,,, -40000,0.92436504,3.4416041,,,,,,,,,,,,,, -40100,1.1368799,2.7632937,,,,,,,,,,,,,, -40200,1.0678607,3.303247,,,,,,,,,,,,,, -40214,,,0.5969336032867432,1.700569987297058,0.55485999584198,1.911470413208008,50000.0,0.4414000213146209,2.5693047046661377,10000.0,18526.15670633316,20686.45342946053,18526.15670633316,2156.8390045166016,1.3593740463256836,0.0 -40300,1.2229341,2.720786,,,,,,,,,,,,,, -40400,1.0892019,2.9220338,,,,,,,,,,,,,, -40500,0.90226567,5.080506,,,,,,,,,,,,,, -40600,1.109378,2.8079538,,,,,,,,,,,,,, -40700,1.0903935,2.8393955,,,,,,,,,,,,,, -40800,1.1808096,2.9215198,,,,,,,,,,,,,, -40900,0.94379467,4.3177137,,,,,,,,,,,,,, -41000,1.1528084,2.6725152,,,,,,,,,,,,,, -41100,1.208693,2.6286168,,,,,,,,,,,,,, -41130,,,0.6115429401397705,1.6528892517089844,0.5624399781227112,1.895995855331421,50000.0,0.4456000328063965,2.548941135406494,10000.0,18946.13728427887,21155.588027715683,18946.13728427887,2205.9153735637665,1.389624834060669,0.0 -41200,0.9400756,4.022317,,,,,,,,,,,,,, -41300,1.4952651,2.7369115,,,,,,,,,,,,,, -41400,1.2938309,2.7831159,,,,,,,,,,,,,, -41500,1.0794078,2.6803708,,,,,,,,,,,,,, -41600,1.3359587,2.7750974,,,,,,,,,,,,,, -41700,0.9526474,5.021924,,,,,,,,,,,,,, -41800,1.2891167,2.7460856,,,,,,,,,,,,,, -41900,0.94176656,4.575263,,,,,,,,,,,,,, -42000,1.104172,4.064043,,,,,,,,,,,,,, -42044,,,0.6252148151397705,1.5360467433929443,0.5721399784088135,1.8095941543579104,50000.0,0.4542000293731689,2.498311758041382,10000.0,19366.324320554733,21624.983068466187,19366.324320554733,2255.04528427124,1.4197359085083008,0.0 -42100,1.0279255,3.1970284,,,,,,,,,,,,,, -42200,1.1645021,2.9008427,,,,,,,,,,,,,, -42300,1.1435666,3.1761525,,,,,,,,,,,,,, -42400,1.4142122,3.263884,,,,,,,,,,,,,, -42500,1.0096406,3.7073724,,,,,,,,,,,,,, -42600,1.1030685,2.629611,,,,,,,,,,,,,, -42700,1.2379855,3.1560369,,,,,,,,,,,,,, -42800,1.0638118,2.829317,,,,,,,,,,,,,, -42900,1.2563015,2.651074,,,,,,,,,,,,,, -42957,,,0.6100390553474426,1.6844910383224487,0.5631600022315979,1.9153586626052856,50000.0,0.4485000073909759,2.561943292617798,10000.0,19786.38955569268,22095.019829034805,19786.38955569268,2304.938737630844,1.450535774230957,0.0 -43000,1.2047054,2.7932749,,,,,,,,,,,,,, -43100,1.394722,2.7165434,,,,,,,,,,,,,, -43200,1.1428765,2.6666923,,,,,,,,,,,,,, -43300,1.1465695,2.8437698,,,,,,,,,,,,,, -43400,0.958703,3.6934583,,,,,,,,,,,,,, -43500,1.1997843,2.5665686,,,,,,,,,,,,,, -43600,0.9342169,5.1484365,,,,,,,,,,,,,, -43700,1.3050655,2.862124,,,,,,,,,,,,,, -43800,1.3091344,2.751299,,,,,,,,,,,,,, -43869,,,0.6125780940055847,1.6683449745178225,0.5681599974632263,1.8960041999816888,50000.0,0.4502000212669372,2.551200151443481,10000.0,20206.347382307053,22562.94238114357,20206.347382307053,2352.82309794426,1.4831857681274414,0.0 -43900,0.900848,5.298203,,,,,,,,,,,,,, -44000,1.3058627,2.672415,,,,,,,,,,,,,, -44100,1.2225393,2.715489,,,,,,,,,,,,,, -44200,1.2823229,2.6418607,,,,,,,,,,,,,, -44300,0.9250033,4.2086444,,,,,,,,,,,,,, -44400,0.98562044,4.0328016,,,,,,,,,,,,,, -44500,1.3064642,2.7256331,,,,,,,,,,,,,, -44600,1.1920111,2.694447,,,,,,,,,,,,,, -44700,0.90918696,4.0577126,,,,,,,,,,,,,, -44782,,,0.6306250095367432,1.5366846323013306,0.5754599571228027,1.7981550693511963,50000.0,0.4561000168323517,2.472685098648072,10000.0,20626.368554592133,23032.650621652603,20626.368554592133,2402.4246587753296,1.521988868713379,0.0 -44800,1.1291429,2.5697374,,,,,,,,,,,,,, -44900,1.1625777,2.9515705,,,,,,,,,,,,,, -45000,1.1873555,5.3041096,,,,,,,,,,,,,, -45100,1.1075108,2.6447716,,,,,,,,,,,,,, -45200,0.83989555,5.2719603,,,,,,,,,,,,,, -45300,1.0349923,3.0970473,,,,,,,,,,,,,, -45400,1.1743698,2.787539,,,,,,,,,,,,,, -45500,1.0339986,3.8611584,,,,,,,,,,,,,, -45600,1.0492519,3.571788,,,,,,,,,,,,,, -45695,,,0.6145312190055847,1.6023259162902832,0.5731599926948547,1.807721376419068,50000.0,0.4631000161170959,2.4594175815582275,10000.0,21046.379149913788,23501.57047510147,21046.379149913788,2451.248507976532,1.560218334197998,0.0 -45700,1.1421785,2.999002,,,,,,,,,,,,,, -45800,1.1664461,2.631354,,,,,,,,,,,,,, -45900,1.1597372,2.7083719,,,,,,,,,,,,,, -46000,1.1863633,2.7465305,,,,,,,,,,,,,, -46100,1.2353702,2.6297603,,,,,,,,,,,,,, -46200,1.0400199,3.064879,,,,,,,,,,,,,, -46300,1.3237498,2.725266,,,,,,,,,,,,,, -46400,0.94269663,5.182876,,,,,,,,,,,,,, -46500,1.2398077,2.7414982,,,,,,,,,,,,,, -46600,1.1457915,2.6221616,,,,,,,,,,,,,, -46611,,,0.6163476705551147,1.598486304283142,0.5700199604034424,1.8413797616958616,50000.0,0.4531000256538391,2.510585308074951,10000.0,21466.459342956543,23970.574808597565,21466.459342956543,2500.084460258484,1.5997076034545898,0.0 -46700,1.2166547,2.6634347,,,,,,,,,,,,,, -46800,0.97483337,4.6953974,,,,,,,,,,,,,, -46900,1.093974,3.6685293,,,,,,,,,,,,,, -47000,0.9658128,4.705395,,,,,,,,,,,,,, -47100,1.1507819,2.7654638,,,,,,,,,,,,,, -47200,1.1787595,5.1355677,,,,,,,,,,,,,, -47300,0.95230883,4.63645,,,,,,,,,,,,,, -47400,1.1738564,2.6945984,,,,,,,,,,,,,, -47500,0.85963994,5.2934604,,,,,,,,,,,,,, -47521,,,0.629199206829071,1.6050124168395996,0.575760006904602,1.858893871307373,50000.0,0.4595000147819519,2.499596118927002,10000.0,21886.38616538048,24440.17867231369,21886.38616538048,2549.6821115016937,1.631948709487915,0.0 -47600,1.1071886,3.0611775,,,,,,,,,,,,,, -47700,1.0852267,5.1415405,,,,,,,,,,,,,, -47800,1.1610845,2.7192209,,,,,,,,,,,,,, -47900,1.03788,2.7894077,,,,,,,,,,,,,, -48000,0.8352534,5.130422,,,,,,,,,,,,,, -48100,1.0960379,3.6448593,,,,,,,,,,,,,, -48200,1.1108481,3.0079756,,,,,,,,,,,,,, -48300,1.0243864,4.6106305,,,,,,,,,,,,,, -48400,1.0503854,2.8541746,,,,,,,,,,,,,, -48431,,,0.6450585722923279,1.445672631263733,0.5813400149345398,1.7478997707366943,50000.0,0.4678000211715698,2.401552677154541,10000.0,22306.334075450897,24908.24243426323,22306.334075450897,2597.718656778336,1.6638882160186768,0.0 -48500,1.2880017,2.6812785,,,,,,,,,,,,,, -48600,1.1135583,2.6780033,,,,,,,,,,,,,, -48700,1.1291331,2.6435592,,,,,,,,,,,,,, -48800,1.0211614,4.936693,,,,,,,,,,,,,, -48900,1.1169968,2.7096653,,,,,,,,,,,,,, -49000,1.2196039,2.6136675,,,,,,,,,,,,,, -49100,1.2265314,2.7357435,,,,,,,,,,,,,, -49200,1.0681814,3.347117,,,,,,,,,,,,,, -49300,0.9885063,4.8053,,,,,,,,,,,,,, -49343,,,0.623046875,1.623903512954712,0.5805599689483643,1.835086703300476,50000.0,0.4589000344276428,2.4847159385681152,10000.0,22726.702834129333,25377.72305703163,22726.702834129333,2646.743688106537,1.7032458782196045,0.0 -49400,1.2257671,2.7192526,,,,,,,,,,,,,, -49500,1.1668187,2.7189758,,,,,,,,,,,,,, -49600,1.3489728,2.7850966,,,,,,,,,,,,,, -49700,1.1071811,2.7360861,,,,,,,,,,,,,, -49800,1.3421518,2.512125,,,,,,,,,,,,,, -49900,1.2627491,2.6680408,,,,,,,,,,,,,, -50000,1.1647753,2.62017,,,,,,,,,,,,,, -50100,0.9012574,4.490702,,,,,,,,,,,,,, -50200,1.2362052,2.5849624,,,,,,,,,,,,,, -50258,,,0.640917956829071,1.4934781789779663,0.5890600085258484,1.74969220161438,50000.0,0.4665000140666961,2.422419309616089,10000.0,23146.96598172188,25848.04677391052,23146.96598172188,2696.7211923599243,1.7389581203460691,0.0 -50300,1.2322272,2.6582084,,,,,,,,,,,,,, -50400,1.2499337,2.6153362,,,,,,,,,,,,,, -50500,1.0983155,2.8388848,,,,,,,,,,,,,, -50600,1.1733023,2.937205,,,,,,,,,,,,,, -50700,1.1064996,2.9827707,,,,,,,,,,,,,, -50800,1.2009304,2.7093487,,,,,,,,,,,,,, -50900,1.3024808,2.654909,,,,,,,,,,,,,, -51000,1.0964562,5.27718,,,,,,,,,,,,,, -51100,1.2534876,2.585731,,,,,,,,,,,,,, -51173,,,0.654101550579071,1.4827535152435305,0.582260012626648,1.81582260131836,50000.0,0.4625000357627868,2.461416482925415,10000.0,23567.253214359283,26316.64104104042,23567.253214359283,2744.944550037384,1.774586200714111,0.0 -51200,1.0276964,3.3146124,,,,,,,,,,,,,, -51300,1.1785479,2.616636,,,,,,,,,,,,,, -51400,1.5893968,2.6052976,,,,,,,,,,,,,, -51500,1.1520743,2.6107154,,,,,,,,,,,,,, -51600,0.9712604,3.5791733,,,,,,,,,,,,,, -51700,1.3013468,2.7343733,,,,,,,,,,,,,, -51800,1.2137368,2.7508965,,,,,,,,,,,,,, -51900,1.2087551,2.4955916,,,,,,,,,,,,,, -52000,1.2412848,2.6113422,,,,,,,,,,,,,, -52086,,,0.6284374594688416,1.56432044506073,0.5817399621009827,1.7849321365356443,50000.0,0.4642000198364258,2.44661545753479,10000.0,23987.205542325974,26784.412437677383,23987.205542325974,2792.6781933307648,1.813803672790528,0.0 -52100,1.0961207,4.6756,,,,,,,,,,,,,, -52200,1.1662321,2.6244898,,,,,,,,,,,,,, -52300,0.9534484,3.3974686,,,,,,,,,,,,,, -52400,1.3969537,2.6231155,,,,,,,,,,,,,, -52500,1.1096568,2.4572852,,,,,,,,,,,,,, -52600,1.231211,2.5557334,,,,,,,,,,,,,, -52700,1.2232659,2.624188,,,,,,,,,,,,,, -52800,1.0378302,3.4003868,,,,,,,,,,,,,, -52900,1.1408098,2.61011,,,,,,,,,,,,,, -53000,1.24139,2.622037,,,,,,,,,,,,,, -53002,,,0.6423242092132568,1.4505491256713867,0.5980799794197083,1.690264344215393,50000.0,0.4789000153541565,2.357414960861206,10000.0,24407.16034078598,27253.20618200302,24407.16034078598,2841.4261870384216,1.8572685718536377,0.0 -53100,1.2060525,2.5143197,,,,,,,,,,,,,, -53200,1.2618444,2.7207723,,,,,,,,,,,,,, -53300,0.9796954,5.303234,,,,,,,,,,,,,, -53400,1.1327046,2.6915827,,,,,,,,,,,,,, -53500,1.0036395,3.249282,,,,,,,,,,,,,, -53600,0.9301647,4.533671,,,,,,,,,,,,,, -53700,1.2622764,2.4786396,,,,,,,,,,,,,, -53800,1.1597664,2.5602922,,,,,,,,,,,,,, -53900,1.1131725,2.7762642,,,,,,,,,,,,,, -53916,,,0.6451367139816284,1.5169159173965454,0.583579957485199,1.813071370124817,50000.0,0.4648000299930572,2.465859889984131,10000.0,24827.3532075882,27724.505935430527,24827.3532075882,2892.447431087494,1.895360231399536,0.0 -54000,1.0877495,4.771156,,,,,,,,,,,,,, -54100,0.94283396,4.813134,,,,,,,,,,,,,, -54200,1.1604527,2.661471,,,,,,,,,,,,,, -54300,0.91030073,5.198498,,,,,,,,,,,,,, -54400,1.2912599,2.647893,,,,,,,,,,,,,, -54500,1.2481449,2.6062996,,,,,,,,,,,,,, -54600,1.3562618,2.4408226,,,,,,,,,,,,,, -54700,1.0672388,3.7728648,,,,,,,,,,,,,, -54800,0.9659182,5.3247347,,,,,,,,,,,,,, -54832,,,0.6351171731948853,1.4957752227783203,0.5937199592590332,1.7042373418807983,50000.0,0.4722000360488891,2.388838529586792,10000.0,25247.272602796555,28193.08148908615,25247.272602796555,2941.016970396042,1.9336466789245603,0.0 -54900,1.1972702,2.676518,,,,,,,,,,,,,, -55000,1.21119,2.526622,,,,,,,,,,,,,, -55100,1.041736,3.567721,,,,,,,,,,,,,, -55200,1.2325408,3.410864,,,,,,,,,,,,,, -55300,1.0470687,5.1957436,,,,,,,,,,,,,, -55400,1.1952453,2.6339118,,,,,,,,,,,,,, -55500,1.174787,3.3645546,,,,,,,,,,,,,, -55600,1.2687882,2.5530903,,,,,,,,,,,,,, -55700,1.0821892,2.9218762,,,,,,,,,,,,,, -55740,,,0.64501953125,1.4755982160568235,0.5958200097084045,1.7111175060272217,50000.0,0.4680000245571136,2.404808759689331,10000.0,25667.20871949196,28663.57295012474,25667.20871949196,2991.487948417664,1.971735000610352,0.0 -55800,0.9684586,4.9194355,,,,,,,,,,,,,, -55900,1.2342999,2.5880835,,,,,,,,,,,,,, -56000,0.970249,4.548047,,,,,,,,,,,,,, -56100,1.2370565,2.6745677,,,,,,,,,,,,,, -56200,1.2224456,2.413204,,,,,,,,,,,,,, -56300,1.1981432,2.6281505,,,,,,,,,,,,,, -56400,1.2090658,2.5535028,,,,,,,,,,,,,, -56500,0.98078394,5.003235,,,,,,,,,,,,,, -56600,1.1874056,3.0682201,,,,,,,,,,,,,, -56651,,,0.6513866782188416,1.4567426443099976,0.590999960899353,1.7368868589401243,50000.0,0.4706000089645386,2.3935937881469727,10000.0,26087.488626241684,29132.650985717773,26087.488626241684,3040.205552339554,2.005126714706421,0.0 -56700,1.2111183,2.5641317,,,,,,,,,,,,,, -56800,1.0571201,2.9146402,,,,,,,,,,,,,, -56900,1.1106256,4.1895437,,,,,,,,,,,,,, -57000,1.1244134,4.167305,,,,,,,,,,,,,, -57100,1.103455,3.130229,,,,,,,,,,,,,, -57200,0.960094,4.8288145,,,,,,,,,,,,,, -57300,1.1548352,3.206692,,,,,,,,,,,,,, -57400,1.2061232,2.8359756,,,,,,,,,,,,,, -57500,1.1799577,2.9196887,,,,,,,,,,,,,, -57560,,,0.6385351419448853,1.554219126701355,0.593239963054657,1.763049840927124,50000.0,0.473000019788742,2.4218482971191406,10000.0,26507.458403348923,29602.409747600555,26507.458403348923,3089.9091703891754,2.0419747829437256,0.0 -57600,0.9986617,3.5973406,,,,,,,,,,,,,, -57700,0.896773,4.022361,,,,,,,,,,,,,, -57800,1.3264626,2.588271,,,,,,,,,,,,,, -57900,1.1376237,2.6455698,,,,,,,,,,,,,, -58000,0.99444145,4.1298633,,,,,,,,,,,,,, -58100,1.209672,2.5810761,,,,,,,,,,,,,, -58200,1.0146769,4.6274786,,,,,,,,,,,,,, -58300,0.9838463,3.5199249,,,,,,,,,,,,,, -58400,1.4368275,2.7386966,,,,,,,,,,,,,, -58474,,,0.6503124833106995,1.4144947528839111,0.6011399626731873,1.652499437332153,50000.0,0.4843000173568725,2.323306083679199,10000.0,26927.61052799225,30073.15274357796,26927.61052799225,3140.4162259101868,2.0792157649993896,0.0 -58500,1.2078896,2.972993,,,,,,,,,,,,,, -58600,0.958965,4.3182917,,,,,,,,,,,,,, -58700,1.2139895,2.5714695,,,,,,,,,,,,,, -58800,0.9682229,4.076796,,,,,,,,,,,,,, -58900,1.1662078,2.56401,,,,,,,,,,,,,, -59000,1.2723628,2.7893019,,,,,,,,,,,,,, -59100,1.1046473,2.7086003,,,,,,,,,,,,,, -59200,1.1983657,2.489253,,,,,,,,,,,,,, -59300,1.32235,2.719492,,,,,,,,,,,,,, -59389,,,0.6558203101158142,1.472233533859253,0.5989399552345276,1.731550931930542,50000.0,0.4761000275611877,2.381978988647461,10000.0,27347.89120697975,30542.19105768204,27347.89120697975,3189.0866141319275,2.1189894676208496,0.0 -59400,1.0917144,2.6776698,,,,,,,,,,,,,, -59500,1.1983181,3.103356,,,,,,,,,,,,,, -59600,1.0383499,3.8474958,,,,,,,,,,,,,, -59700,1.2506206,2.5096383,,,,,,,,,,,,,, -59800,0.95307785,5.1536446,,,,,,,,,,,,,, -59900,1.363967,2.4849484,,,,,,,,,,,,,, -60000,1.2610836,2.5723264,,,,,,,,,,,,,, -60100,1.2371469,2.597839,,,,,,,,,,,,,, -60200,1.3010753,2.5184956,,,,,,,,,,,,,, -60300,1.1327019,2.4356399,,,,,,,,,,,,,, -60301,,,0.64208984375,1.5133836269378662,0.5927799940109253,1.7447527647018433,50000.0,0.4716000258922577,2.4132392406463623,10000.0,27767.98445320129,31011.86527228356,27767.98445320129,3238.5846648216248,2.1544106006622314,0.0 -60400,1.2143272,2.404079,,,,,,,,,,,,,, -60500,0.99357355,4.939925,,,,,,,,,,,,,, -60600,1.0908642,3.0517364,,,,,,,,,,,,,, -60700,1.1599208,2.6011922,,,,,,,,,,,,,, -60800,1.1554474,2.5913377,,,,,,,,,,,,,, -60900,1.2809243,2.8783598,,,,,,,,,,,,,, -61000,1.2179693,2.5603533,,,,,,,,,,,,,, -61100,1.0629178,5.065754,,,,,,,,,,,,,, -61200,1.2186708,2.5887957,,,,,,,,,,,,,, -61215,,,0.6486132740974426,1.5043102502822876,0.6002399921417236,1.7300899028778076,50000.0,0.4771000146865845,2.3871309757232666,10000.0,28188.03951358795,31482.68537425995,28188.03951358795,3289.2686598300934,2.187922954559326,0.0 -61300,1.3199332,2.980071,,,,,,,,,,,,,, -61400,1.0529903,2.826958,,,,,,,,,,,,,, -61500,0.9722106,4.9957547,,,,,,,,,,,,,, -61600,1.3243885,2.543566,,,,,,,,,,,,,, -61700,1.1634606,3.0437825,,,,,,,,,,,,,, -61800,1.2372729,2.663045,,,,,,,,,,,,,, -61900,1.0786107,3.0073442,,,,,,,,,,,,,, -62000,1.2697011,2.637926,,,,,,,,,,,,,, -62100,1.0658927,4.45823,,,,,,,,,,,,,, -62126,,,0.6597460508346558,1.415103316307068,0.6001200079917908,1.6884820461273191,50000.0,0.4869000315666199,2.353926658630371,10000.0,28607.97818994522,31952.91978526116,28607.97818994522,3339.475342273712,2.2301175594329834,0.0 -62200,0.9739974,3.8226814,,,,,,,,,,,,,, -62300,1.1158886,3.0387168,,,,,,,,,,,,,, -62400,0.9949835,5.068683,,,,,,,,,,,,,, -62500,1.2174217,2.4618192,,,,,,,,,,,,,, -62600,1.2587854,2.8988836,,,,,,,,,,,,,, -62700,0.9686679,3.7562966,,,,,,,,,,,,,, -62800,1.1215578,2.7958615,,,,,,,,,,,,,, -62900,1.2667022,2.4689422,,,,,,,,,,,,,, -63000,1.0302192,4.751604,,,,,,,,,,,,,, -63040,,,0.6842187643051147,1.2707115411758425,0.6073399782180786,1.6327868700027466,50000.0,0.4918000102043152,2.2991316318511963,10000.0,29028.25691366196,32421.085930347443,29028.25691366196,3387.2805788517,2.26474666595459,0.0 -63100,1.2990423,2.5605047,,,,,,,,,,,,,, -63200,1.0310003,4.315956,,,,,,,,,,,,,, -63300,1.141582,2.6063359,,,,,,,,,,,,,, -63400,1.3032184,2.5858278,,,,,,,,,,,,,, -63500,1.2017671,2.4618154,,,,,,,,,,,,,, -63600,1.224852,2.404686,,,,,,,,,,,,,, -63700,1.1752369,2.5152445,,,,,,,,,,,,,, -63800,1.1312908,2.7965457,,,,,,,,,,,,,, -63900,1.2698174,2.5275917,,,,,,,,,,,,,, -63953,,,0.6486914157867432,1.453892469406128,0.6010400056838989,1.6871942281723022,50000.0,0.4830000102519989,2.333249092102051,10000.0,29448.54456448555,32891.12489771843,29448.54456448555,3436.948952436447,2.301134824752808,0.0 -64000,1.2934357,2.4785943,,,,,,,,,,,,,, -64100,1.0343443,3.5824547,,,,,,,,,,,,,, -64200,1.0449132,4.0469646,,,,,,,,,,,,,, -64300,1.1271157,3.188404,,,,,,,,,,,,,, -64400,1.1990682,2.5950072,,,,,,,,,,,,,, -64500,1.318481,2.4651394,,,,,,,,,,,,,, -64600,1.0730064,4.425914,,,,,,,,,,,,,, -64700,1.1116079,3.5291378,,,,,,,,,,,,,, -64800,1.1699816,2.7451549,,,,,,,,,,,,,, -64867,,,0.658203125,1.4299410581588743,0.6074599623680115,1.6725072860717771,50000.0,0.4889000356197357,2.33275842666626,10000.0,29868.56832718849,33362.58354306221,29868.56832718849,3488.2956540584564,2.341980695724488,0.0 -64900,1.2993702,2.5185177,,,,,,,,,,,,,, -65000,1.3468798,2.5116405,,,,,,,,,,,,,, -65100,1.1236179,3.311737,,,,,,,,,,,,,, -65200,1.5932659,2.3453054,,,,,,,,,,,,,, -65300,1.0318378,3.5013652,,,,,,,,,,,,,, -65400,1.1851878,2.604611,,,,,,,,,,,,,, -65500,1.0530488,4.2588844,,,,,,,,,,,,,, -65600,1.228812,2.6421466,,,,,,,,,,,,,, -65700,1.3741264,2.5967875,,,,,,,,,,,,,, -65780,,,0.6772069931030273,1.291306734085083,0.6087999939918518,1.6291512250900269,50000.0,0.4874000251293182,2.2988369464874268,10000.0,30288.65084552765,33832.26501727104,30288.65084552765,3537.798797369004,2.3904261589050293,0.0 -65800,1.3510696,2.5609415,,,,,,,,,,,,,, -65900,1.0989043,2.8773935,,,,,,,,,,,,,, -66000,1.1010294,3.2144225,,,,,,,,,,,,,, -66100,1.154264,5.0973463,,,,,,,,,,,,,, -66200,0.9688743,4.571147,,,,,,,,,,,,,, -66300,1.2842835,2.812508,,,,,,,,,,,,,, -66400,1.0935497,4.1801834,,,,,,,,,,,,,, -66500,1.09119,4.4199934,,,,,,,,,,,,,, -66600,1.2725023,2.5233405,,,,,,,,,,,,,, -66695,,,0.6553320288658142,1.4468097686767578,0.6100599765777588,1.668556571006775,50000.0,0.4958000183105469,2.307793378829956,10000.0,30708.97991228104,34301.335938215256,30708.97991228104,3586.4557435512543,2.428715467453003,0.0 -66700,1.1494228,2.9106627,,,,,,,,,,,,,, -66800,1.2275817,3.0263155,,,,,,,,,,,,,, -66900,0.9842338,4.557669,,,,,,,,,,,,,, -67000,1.2150601,2.6476672,,,,,,,,,,,,,, -67100,1.0234607,3.6822908,,,,,,,,,,,,,, -67200,1.2453184,3.0912645,,,,,,,,,,,,,, -67300,1.0762433,4.723422,,,,,,,,,,,,,, -67400,1.109809,4.338197,,,,,,,,,,,,,, -67500,1.3190123,2.4552567,,,,,,,,,,,,,, -67600,1.0846757,3.2023172,,,,,,,,,,,,,, -67610,,,0.6596874594688416,1.3802225589752195,0.6093400120735168,1.6242049932479858,50000.0,0.486700028181076,2.3031609058380127,10000.0,31128.98797106743,34770.267876148224,31128.98797106743,3635.289438724518,2.4718663692474365,0.0 -67700,1.2437025,2.4015372,,,,,,,,,,,,,, -67800,1.3004795,2.5242345,,,,,,,,,,,,,, -67900,1.2554078,4.8758736,,,,,,,,,,,,,, -68000,1.2572451,2.4577935,,,,,,,,,,,,,, -68100,1.2386383,2.5895283,,,,,,,,,,,,,, -68200,1.3714553,2.7948825,,,,,,,,,,,,,, -68300,1.3084148,2.5957513,,,,,,,,,,,,,, -68400,1.3513288,2.5417807,,,,,,,,,,,,,, -68500,1.3211927,2.7106786,,,,,,,,,,,,,, -68524,,,0.6707226634025574,1.330780267715454,0.6122999787330627,1.6297461986541748,50000.0,0.4939000308513641,2.2985339164733887,10000.0,31548.9758489132,35238.550125837326,31548.9758489132,3683.48933506012,2.515212059020996,0.0 -68600,1.1605849,3.2048104,,,,,,,,,,,,,, -68700,1.3029733,2.4597116,,,,,,,,,,,,,, -68800,1.3189452,2.5110788,,,,,,,,,,,,,, -68900,1.1240015,2.4520314,,,,,,,,,,,,,, -69000,1.1010729,4.1729946,,,,,,,,,,,,,, -69100,1.3550441,2.6099324,,,,,,,,,,,,,, -69200,1.2615858,2.855198,,,,,,,,,,,,,, -69300,0.9760683,3.9350872,,,,,,,,,,,,,, -69400,1.2741907,2.5095139,,,,,,,,,,,,,, -69435,,,0.6629882454872131,1.3963218927383425,0.6138799786567688,1.6306712627410889,50000.0,0.4905000329017639,2.291445732116699,10000.0,31968.895731449127,35706.80040049553,31968.895731449127,3731.7338008880615,2.553818941116333,0.0 -69500,1.378734,2.6441317,,,,,,,,,,,,,, -69600,1.1731454,4.131851,,,,,,,,,,,,,, -69700,1.2663134,2.5826352,,,,,,,,,,,,,, -69800,1.4341686,2.5647197,,,,,,,,,,,,,, -69900,1.0786886,2.7163036,,,,,,,,,,,,,, -70000,1.18544,2.9663332,,,,,,,,,,,,,, -70100,1.2760407,2.6177883,,,,,,,,,,,,,, -70200,1.3430722,2.7262478,,,,,,,,,,,,,, -70300,1.1419575,2.9297242,,,,,,,,,,,,,, -70350,,,0.66259765625,1.3925042152404783,0.6094799637794495,1.6426547765731812,50000.0,0.4910000264644623,2.2989656925201416,10000.0,32389.219320058823,36176.90175771713,32389.219320058823,3781.4290795326233,2.589449882507324,0.0 -70400,1.0550185,3.9174364,,,,,,,,,,,,,, -70500,1.2496465,2.5772471,,,,,,,,,,,,,, -70600,1.1671742,3.3826973,,,,,,,,,,,,,, -70700,1.2906404,2.454043,,,,,,,,,,,,,, -70800,1.2268329,2.493831,,,,,,,,,,,,,, -70900,1.4408514,2.4533584,,,,,,,,,,,,,, -71000,1.3843639,2.472087,,,,,,,,,,,,,, -71100,1.3899231,2.4100251,,,,,,,,,,,,,, -71200,1.290242,3.5156946,,,,,,,,,,,,,, -71265,,,0.6683593392372131,1.397711157798767,0.6119399666786194,1.6650700569152832,50000.0,0.4864000082015991,2.3102755546569824,10000.0,32809.368601322174,36646.70735788345,32809.368601322174,3830.998381853104,2.6295628547668457,0.0 -71300,1.4309376,2.5384293,,,,,,,,,,,,,, -71400,1.402256,2.5061724,,,,,,,,,,,,,, -71500,1.15814,3.1878815,,,,,,,,,,,,,, -71600,1.3588452,2.503151,,,,,,,,,,,,,, -71700,1.3207518,2.4853673,,,,,,,,,,,,,, -71800,1.0578567,3.1138241,,,,,,,,,,,,,, -71900,1.3993927,2.7329423,,,,,,,,,,,,,, -72000,1.0023522,4.5793977,,,,,,,,,,,,,, -72100,1.4099597,2.50857,,,,,,,,,,,,,, -72179,,,0.6663671731948853,1.3786489963531494,0.6188600063323975,1.6109440326690674,50000.0,0.4981000125408172,2.268369674682617,10000.0,33229.415531635284,37116.63294029236,33229.415531635284,3880.787905454636,2.671271324157715,0.0 -72200,1.2848246,2.488566,,,,,,,,,,,,,, -72300,1.0990661,3.2618241,,,,,,,,,,,,,, -72400,1.1743027,3.26548,,,,,,,,,,,,,, -72500,1.178857,4.1228385,,,,,,,,,,,,,, -72600,1.2968588,2.5401764,,,,,,,,,,,,,, -72700,1.2341301,3.0480814,,,,,,,,,,,,,, -72800,1.2765787,2.2743177,,,,,,,,,,,,,, -72900,1.2297733,2.3347552,,,,,,,,,,,,,, -73000,1.301132,2.373053,,,,,,,,,,,,,, -73093,,,0.6658984422683716,1.3752154111862185,0.6152799725532532,1.6217827796936035,50000.0,0.4900000095367431,2.293707609176636,10000.0,33649.62982439995,37585.90481185913,33649.62982439995,3929.753345727921,2.7157506942749023,0.0 -73100,1.263429,2.957304,,,,,,,,,,,,,, -73200,1.1628506,4.715209,,,,,,,,,,,,,, -73300,1.1009867,3.8055568,,,,,,,,,,,,,, -73400,1.0975205,3.9617994,,,,,,,,,,,,,, -73500,1.133082,3.0915818,,,,,,,,,,,,,, -73600,1.2119308,4.3172812,,,,,,,,,,,,,, -73700,1.2164608,2.5112207,,,,,,,,,,,,,, -73800,1.3362315,2.58195,,,,,,,,,,,,,, -73900,0.9903401,4.288389,,,,,,,,,,,,,, -74000,1.3834951,2.407666,,,,,,,,,,,,,, -74008,,,0.6753124594688416,1.3523850440979004,0.6179999709129333,1.608821153640747,50000.0,0.4978000223636627,2.266586780548096,10000.0,34069.875893354416,38055.59846138954,34069.875893354416,3979.105614423752,2.763178586959839,0.0 -74100,1.059471,3.9237509,,,,,,,,,,,,,, -74200,1.3019437,2.6707604,,,,,,,,,,,,,, -74300,1.1966794,4.5545654,,,,,,,,,,,,,, -74400,1.058355,4.266961,,,,,,,,,,,,,, -74500,1.3305535,2.512448,,,,,,,,,,,,,, -74600,1.3551898,2.5601954,,,,,,,,,,,,,, -74700,1.417034,2.609533,,,,,,,,,,,,,, -74800,1.0855538,4.602773,,,,,,,,,,,,,, -74900,1.1643732,2.8544393,,,,,,,,,,,,,, -74919,,,0.695605456829071,1.2618730068206787,0.6238799691200256,1.5817739963531494,50000.0,0.5057000517845154,2.234017372131348,10000.0,34489.79175186157,38525.39829945564,34489.79175186157,4028.905200958252,2.799935817718506,0.0 -75000,1.0174261,5.0120935,,,,,,,,,,,,,, -75100,1.4152293,2.3641372,,,,,,,,,,,,,, -75200,1.2980582,2.3654075,,,,,,,,,,,,,, -75300,1.4961045,2.4952407,,,,,,,,,,,,,, -75400,1.1998416,2.7481952,,,,,,,,,,,,,, -75500,1.3010962,2.7030196,,,,,,,,,,,,,, -75600,1.4308895,2.53881,,,,,,,,,,,,,, -75700,1.0599728,3.252659,,,,,,,,,,,,,, -75800,1.3511552,2.3904817,,,,,,,,,,,,,, -75832,,,0.6705663800239563,1.3746696710586548,0.6190599799156189,1.6183385848999023,50000.0,0.4936000108718872,2.2937138080596924,10000.0,34909.78695511818,38997.88926529884,34909.78695511818,4081.3146035671234,2.8388900756835938,0.0 -75900,1.0905684,3.776472,,,,,,,,,,,,,, -76000,1.1218697,4.542335,,,,,,,,,,,,,, -76100,1.167813,3.7018914,,,,,,,,,,,,,, -76200,1.2492881,2.352259,,,,,,,,,,,,,, -76300,1.3400047,4.672373,,,,,,,,,,,,,, -76400,1.0801016,5.06353,,,,,,,,,,,,,, -76500,1.2739465,2.7902942,,,,,,,,,,,,,, -76600,1.3227981,5.0484905,,,,,,,,,,,,,, -76700,1.1458064,3.6848361,,,,,,,,,,,,,, -76745,,,0.6803905963897705,1.297563910484314,0.6237199902534485,1.563851237297058,50000.0,0.5089000463485718,2.204143762588501,10000.0,35329.75003695488,39465.47510480881,35329.75003695488,4128.850201368332,2.878529787063598,0.0 -76800,1.4257799,2.4783738,,,,,,,,,,,,,, -76900,1.1291835,2.845584,,,,,,,,,,,,,, -77000,1.3949976,2.4342546,,,,,,,,,,,,,, -77100,1.2601252,2.3419313,,,,,,,,,,,,,, -77200,1.5161664,2.3427343,,,,,,,,,,,,,, -77300,1.4043177,2.5387688,,,,,,,,,,,,,, -77400,1.1945715,4.5088186,,,,,,,,,,,,,, -77500,1.4070712,2.60253,,,,,,,,,,,,,, -77600,1.3355,2.4960659,,,,,,,,,,,,,, -77660,,,0.6931836009025574,1.2485311031341553,0.6234999895095825,1.5847680568695068,50000.0,0.5022000074386597,2.245917320251465,10000.0,35749.74824357033,39935.9891512394,35749.74824357033,4179.281141281128,2.9159533977508545,0.0 -77700,1.1132907,4.822158,,,,,,,,,,,,,, -77800,1.180469,4.8616295,,,,,,,,,,,,,, -77900,1.2506384,4.401859,,,,,,,,,,,,,, -78000,1.1456302,5.143818,,,,,,,,,,,,,, -78100,1.3829621,2.360599,,,,,,,,,,,,,, -78200,1.3568034,2.3822706,,,,,,,,,,,,,, -78300,1.1127143,3.6295848,,,,,,,,,,,,,, -78400,1.0927033,4.8761578,,,,,,,,,,,,,, -78500,1.3209747,2.4846554,,,,,,,,,,,,,, -78574,,,0.6760546565055847,1.3154542446136477,0.6225399971008301,1.5719152688980105,50000.0,0.508400022983551,2.222673177719116,10000.0,36169.98121476173,40402.8313946724,36169.98121476173,4225.807715415955,2.951959609985352,0.0 -78600,1.2897321,2.830347,,,,,,,,,,,,,, -78700,1.2149405,2.9295301,,,,,,,,,,,,,, -78800,1.3275386,2.3336105,,,,,,,,,,,,,, -78900,1.7285892,2.5279427,,,,,,,,,,,,,, -79000,1.313459,2.400768,,,,,,,,,,,,,, -79100,1.277639,2.515315,,,,,,,,,,,,,, -79200,1.2196568,2.5354302,,,,,,,,,,,,,, -79300,1.2995464,2.3539476,,,,,,,,,,,,,, -79400,1.1269057,3.8041523,,,,,,,,,,,,,, -79489,,,0.6782616972923279,1.3288729190826416,0.6286799907684326,1.572872281074524,50000.0,0.5056000351905823,2.23213791847229,10000.0,36590.37432670593,40870.50536131859,36590.37432670593,4273.000194072723,2.9928932189941406,0.0 -79500,1.098196,3.6660779,,,,,,,,,,,,,, -79600,1.3395189,2.4060106,,,,,,,,,,,,,, -79700,1.3662053,2.4137528,,,,,,,,,,,,,, -79800,1.0740771,4.9985676,,,,,,,,,,,,,, -79900,1.3279121,2.571393,,,,,,,,,,,,,, -80000,1.151484,5.0014095,,,,,,,,,,,,,, -80100,1.3115709,2.3863423,,,,,,,,,,,,,, -80200,1.0634457,5.0040855,,,,,,,,,,,,,, -80300,1.4253153,2.481072,,,,,,,,,,,,,, -80400,1.5203192,2.3457735,,,,,,,,,,,,,, -80405,,,0.6900194883346558,1.277381420135498,0.6231399774551392,1.5857053995132446,50000.0,0.499500036239624,2.244521379470825,10000.0,37010.2838101387,41339.1589307785,37010.2838101387,4321.6484797000885,3.0400874614715576,0.0 -80500,1.4637399,2.2603586,,,,,,,,,,,,,, -80600,1.1941974,3.9830704,,,,,,,,,,,,,, -80700,1.3139255,2.66539,,,,,,,,,,,,,, -80800,1.4561188,2.5082362,,,,,,,,,,,,,, -80900,1.4285406,2.561599,,,,,,,,,,,,,, -81000,1.3225944,2.712605,,,,,,,,,,,,,, -81100,1.5311996,2.36965,,,,,,,,,,,,,, -81200,1.56385,2.3356035,,,,,,,,,,,,,, -81300,1.3934999,2.3362772,,,,,,,,,,,,,, -81320,,,0.6727538704872131,1.3480576276779177,0.6234999895095825,1.5841037034988403,50000.0,0.5049000382423401,2.2456095218658447,10000.0,37430.52463960648,41809.0449860096,37430.52463960648,4371.208312034607,3.0785329341888428,0.0 -81400,1.0918583,3.6036215,,,,,,,,,,,,,, -81500,1.2262636,2.622189,,,,,,,,,,,,,, -81600,1.3269125,2.3505168,,,,,,,,,,,,,, -81700,1.311069,2.7009745,,,,,,,,,,,,,, -81800,1.1708,4.5359573,,,,,,,,,,,,,, -81900,1.4126409,2.409363,,,,,,,,,,,,,, -82000,1.3556817,2.450622,,,,,,,,,,,,,, -82100,1.4279217,2.3875897,,,,,,,,,,,,,, -82200,1.3103524,2.8391995,,,,,,,,,,,,,, -82232,,,0.6878319978713989,1.2847343683242798,0.6326599717140198,1.532365798950195,50000.0,0.5116000175476074,2.190577745437622,10000.0,37850.4936645031,42277.68804812431,37850.4936645031,4419.785451173782,3.1283504962921143,0.0 -82300,1.3057102,2.356251,,,,,,,,,,,,,, -82400,1.3477007,2.470384,,,,,,,,,,,,,, -82500,1.2764972,2.4304,,,,,,,,,,,,,, -82600,1.4138877,2.3264916,,,,,,,,,,,,,, -82700,1.3914741,2.3702924,,,,,,,,,,,,,, -82800,1.1973417,3.874623,,,,,,,,,,,,,, -82900,1.2776501,2.5277405,,,,,,,,,,,,,, -83000,1.3386198,2.7256525,,,,,,,,,,,,,, -83100,1.3458247,2.311856,,,,,,,,,,,,,, -83145,,,0.6927343606948853,1.2806254625320437,0.6339799761772156,1.558222413063049,50000.0,0.5128999948501587,2.209243059158325,10000.0,38270.45635151863,42744.11470103264,38270.45635151863,4466.1608464717865,3.169515371322632,0.0 -83200,1.3015627,2.3687735,,,,,,,,,,,,,, -83300,1.327692,2.3384578,,,,,,,,,,,,,, -83400,1.1665013,3.4921603,,,,,,,,,,,,,, -83500,1.3572707,2.2703724,,,,,,,,,,,,,, -83600,1.2608683,2.6636462,,,,,,,,,,,,,, -83700,1.4231373,2.2830825,,,,,,,,,,,,,, -83800,1.4450414,2.2729797,,,,,,,,,,,,,, -83900,1.3491361,2.889167,,,,,,,,,,,,,, -84000,1.6100191,2.3558292,,,,,,,,,,,,,, -84058,,,0.6813867092132568,1.28969407081604,0.6294599771499634,1.5353686809539795,50000.0,0.5073000192642212,2.186338186264038,10000.0,38690.68273019791,43214.1302447319,38690.68273019791,4515.86549949646,3.206382989883423,0.0 -84100,1.274124,4.697777,,,,,,,,,,,,,, -84200,1.3649997,2.4735954,,,,,,,,,,,,,, -84300,1.1687415,4.9948397,,,,,,,,,,,,,, -84400,1.2356024,2.9588995,,,,,,,,,,,,,, -84500,1.1914649,4.745188,,,,,,,,,,,,,, -84600,1.1198663,4.871234,,,,,,,,,,,,,, -84700,1.2676712,3.506627,,,,,,,,,,,,,, -84800,1.2265376,2.9694996,,,,,,,,,,,,,, -84900,1.0918245,3.3905938,,,,,,,,,,,,,, -84974,,,0.688281238079071,1.2811274528503418,0.6375600099563599,1.5175962448120115,50000.0,0.5128000378608704,2.1811633110046387,10000.0,39111.102942466736,43683.11135816574,39111.102942466736,4564.33405828476,3.2509865760803223,0.0 -85000,1.3467156,2.260582,,,,,,,,,,,,,, -85100,1.1828812,5.046831,,,,,,,,,,,,,, -85200,1.5216544,2.447927,,,,,,,,,,,,,, -85300,1.2054393,2.9977567,,,,,,,,,,,,,, -85400,1.4934244,2.408748,,,,,,,,,,,,,, -85500,1.2744329,2.319223,,,,,,,,,,,,,, -85600,1.1441697,3.6780183,,,,,,,,,,,,,, -85700,1.3499981,2.2035522,,,,,,,,,,,,,, -85800,1.3125399,2.969056,,,,,,,,,,,,,, -85887,,,0.6952148079872131,1.2292883396148682,0.6356799602508545,1.519616723060608,50000.0,0.5178000330924988,2.1769161224365234,10000.0,39531.07405328751,44151.64874100685,39531.07405328751,4612.810303688049,3.293998003005981,0.0 -85900,1.5043385,2.3140397,,,,,,,,,,,,,, -86000,1.3426291,2.3117003,,,,,,,,,,,,,, -86100,1.4117773,2.4923813,,,,,,,,,,,,,, -86200,1.3202443,2.953781,,,,,,,,,,,,,, -86300,1.3181164,2.5189493,,,,,,,,,,,,,, -86400,1.1513892,4.580992,,,,,,,,,,,,,, -86500,1.1003209,3.8495753,,,,,,,,,,,,,, -86600,1.279523,2.2135377,,,,,,,,,,,,,, -86700,1.4726492,2.249181,,,,,,,,,,,,,, -86798,,,0.6976562142372131,1.220523476600647,0.6376199722290039,1.5006834268569946,50000.0,0.5117000341415405,2.1646652221679688,10000.0,39951.19725751877,44620.83441233635,39951.19725751877,4661.783539533615,3.335458278656006,0.0 -86800,1.3684937,2.5214052,,,,,,,,,,,,,, -86900,1.5252541,2.848502,,,,,,,,,,,,,, -87000,1.1603577,4.530692,,,,,,,,,,,,,, -87100,1.413517,2.4520779,,,,,,,,,,,,,, -87200,1.4447432,2.55983,,,,,,,,,,,,,, -87300,1.0874485,4.709712,,,,,,,,,,,,,, -87400,1.2344328,3.2240622,,,,,,,,,,,,,, -87500,1.1998829,3.863855,,,,,,,,,,,,,, -87600,1.3754902,2.2926211,,,,,,,,,,,,,, -87700,1.2936717,2.4004347,,,,,,,,,,,,,, -87713,,,0.6788867115974426,1.3864643573760986,0.6287400126457214,1.6301475763320925,50000.0,0.5028000473976135,2.264337539672852,10000.0,40371.22511386871,45092.05185699463,40371.22511386871,4712.8799839019775,3.3804314136505127,0.0 -87800,1.3912363,2.3437371,,,,,,,,,,,,,, -87900,1.4950274,2.2324266,,,,,,,,,,,,,, -88000,1.3823812,2.3166184,,,,,,,,,,,,,, -88100,1.3560508,2.566197,,,,,,,,,,,,,, -88200,1.3752272,2.284943,,,,,,,,,,,,,, -88300,1.225582,4.445175,,,,,,,,,,,,,, -88400,1.131361,4.8952584,,,,,,,,,,,,,, -88500,1.4233491,2.2966359,,,,,,,,,,,,,, -88600,1.4313053,2.215745,,,,,,,,,,,,,, -88626,,,0.696582019329071,1.239540934562683,0.6377399563789368,1.517307162284851,50000.0,0.5161000490188599,2.1748321056365967,10000.0,40791.30937457085,45559.40066933632,40791.30937457085,4760.05557847023,3.4225516319274902,0.0 -88700,1.2068143,3.552488,,,,,,,,,,,,,, -88800,1.484094,2.2855618,,,,,,,,,,,,,, -88900,1.4859837,2.2763004,,,,,,,,,,,,,, -89000,1.5360461,2.8621175,,,,,,,,,,,,,, -89100,1.7475137,2.2847817,,,,,,,,,,,,,, -89200,1.5684191,2.3722262,,,,,,,,,,,,,, -89300,1.1238848,3.8018532,,,,,,,,,,,,,, -89400,1.2658409,4.6675014,,,,,,,,,,,,,, -89500,1.1602267,4.033084,,,,,,,,,,,,,, -89538,,,0.71888667345047,1.155856728553772,0.6437000036239624,1.506482481956482,50000.0,0.5152000188827515,2.1656334400177,10000.0,41211.301359415054,46029.07649850845,41211.301359415054,4809.654201030731,3.460923194885254,0.0 -89600,1.4291244,3.1030216,,,,,,,,,,,,,, -89700,1.2039807,4.3078766,,,,,,,,,,,,,, -89800,1.417517,2.4056158,,,,,,,,,,,,,, -89900,1.2282223,4.8495083,,,,,,,,,,,,,, -90000,1.2127372,3.2600577,,,,,,,,,,,,,, -90100,1.2647096,3.0511723,,,,,,,,,,,,,, -90200,1.2903446,3.2768345,,,,,,,,,,,,,, -90300,1.1690269,3.5674043,,,,,,,,,,,,,, -90400,1.2423896,3.9784377,,,,,,,,,,,,,, -90449,,,0.6948632597923279,1.2670438289642334,0.6413399577140808,1.5187033414840698,50000.0,0.518500030040741,2.1620571613311768,10000.0,41631.52504038811,46498.67860245705,41631.52504038811,4858.942010641098,3.504815101623535,0.0 -90500,1.4323329,2.1423,,,,,,,,,,,,,, -90600,1.1914682,4.424806,,,,,,,,,,,,,, -90700,1.3597379,2.224171,,,,,,,,,,,,,, -90800,1.1287183,4.1950507,,,,,,,,,,,,,, -90900,1.3289186,4.4101696,,,,,,,,,,,,,, -91000,1.4302918,2.3634953,,,,,,,,,,,,,, -91100,1.2129512,3.0164075,,,,,,,,,,,,,, -91200,1.4726928,2.1870394,,,,,,,,,,,,,, -91300,1.314615,2.8629017,,,,,,,,,,,,,, -91365,,,0.7024804353713989,1.197124719619751,0.6451199650764465,1.4732191562652588,50000.0,0.5216000080108643,2.1140828132629395,10000.0,42051.71841979027,46967.12596321106,42051.71841979027,4907.109016418457,3.544394254684448,0.0 -91400,1.4105123,2.2516468,,,,,,,,,,,,,, -91500,1.2806216,4.872193,,,,,,,,,,,,,, -91600,1.3963764,2.192233,,,,,,,,,,,,,, -91700,1.3757962,2.3769689,,,,,,,,,,,,,, -91800,1.4140072,2.3410876,,,,,,,,,,,,,, -91900,1.3493514,2.1973202,,,,,,,,,,,,,, -92000,1.3522769,2.3475156,,,,,,,,,,,,,, -92100,1.2872134,4.685638,,,,,,,,,,,,,, -92200,1.6507198,2.6878946,,,,,,,,,,,,,, -92280,,,0.7185351252555847,1.1746147871017456,0.6480799913406372,1.501075267791748,50000.0,0.5205000042915344,2.161663293838501,10000.0,42471.94208693504,47435.50911140442,42471.94208693504,4955.174675226212,3.591316938400269,0.0 -92300,1.2350113,4.76097,,,,,,,,,,,,,, -92400,1.1504415,3.60898,,,,,,,,,,,,,, -92500,1.6202489,2.2832098,,,,,,,,,,,,,, -92600,1.3592252,2.9468,,,,,,,,,,,,,, -92700,1.535893,2.4137974,,,,,,,,,,,,,, -92800,1.3109604,2.1946645,,,,,,,,,,,,,, -92900,1.1622916,4.9584894,,,,,,,,,,,,,, -93000,1.2643249,2.5301487,,,,,,,,,,,,,, -93100,1.2268212,3.84625,,,,,,,,,,,,,, -93194,,,0.7062109112739563,1.185657262802124,0.6523799896240234,1.4381099939346311,50000.0,0.527400016784668,2.1004345417022705,10000.0,42892.12707614899,47906.19521903992,42892.12707614899,5005.583575248718,3.635996818542481,0.0 -93200,1.5668278,2.204387,,,,,,,,,,,,,, -93300,1.1625544,4.2326884,,,,,,,,,,,,,, -93400,1.4858619,2.3242083,,,,,,,,,,,,,, -93500,1.3481747,2.420929,,,,,,,,,,,,,, -93600,1.282883,3.474177,,,,,,,,,,,,,, -93700,1.3948925,2.2401161,,,,,,,,,,,,,, -93800,1.1687958,4.925414,,,,,,,,,,,,,, -93900,1.3072902,3.6955256,,,,,,,,,,,,,, -94000,1.3482584,2.7776983,,,,,,,,,,,,,, -94100,1.3977162,2.2457063,,,,,,,,,,,,,, -94110,,,0.7049023509025574,1.2190285921096802,0.6449199914932251,1.500189185142517,50000.0,0.5231000185012817,2.142191171646118,10000.0,43312.31781864166,48376.341938734055,43312.31781864166,5055.447685480118,3.67959189414978,0.0 -94200,1.4309256,2.3783607,,,,,,,,,,,,,, -94300,1.4280578,4.3365583,,,,,,,,,,,,,, -94400,1.49111,2.1484473,,,,,,,,,,,,,, -94500,1.4438087,2.6730273,,,,,,,,,,,,,, -94600,1.4803334,2.2193494,,,,,,,,,,,,,, -94700,1.3291628,2.7418606,,,,,,,,,,,,,, -94800,1.2104492,4.517414,,,,,,,,,,,,,, -94900,1.5880255,2.328904,,,,,,,,,,,,,, -95000,1.3885293,2.171425,,,,,,,,,,,,,, -95025,,,0.7074413895606995,1.209861397743225,0.6448799967765808,1.504482984542847,50000.0,0.5199000239372253,2.164458990097046,10000.0,43732.4443500042,48846.29187488556,43732.4443500042,5105.180654525757,3.722055435180664,0.0 -95100,1.5220343,2.1498995,,,,,,,,,,,,,, -95200,1.4161617,2.1829762,,,,,,,,,,,,,, -95300,1.2993405,3.3060274,,,,,,,,,,,,,, -95400,1.128011,4.789634,,,,,,,,,,,,,, -95500,1.263421,3.9634373,,,,,,,,,,,,,, -95600,1.4036913,2.5544422,,,,,,,,,,,,,, -95700,1.3912991,2.1926932,,,,,,,,,,,,,, -95800,1.3852884,2.743575,,,,,,,,,,,,,, -95900,1.5484096,2.2783873,,,,,,,,,,,,,, -95937,,,0.7051367163658142,1.1937915086746216,0.6563999652862549,1.434855580329895,50000.0,0.5348000526428223,2.0843238830566406,10000.0,44152.55052447319,49315.72231268883,44152.55052447319,5154.415741682053,3.763861179351807,0.0 -96000,1.331539,3.2870426,,,,,,,,,,,,,, -96100,1.2058085,4.8296733,,,,,,,,,,,,,, -96200,1.4883907,2.2227032,,,,,,,,,,,,,, -96300,1.4824424,2.1481225,,,,,,,,,,,,,, -96400,1.5671481,2.1737037,,,,,,,,,,,,,, -96500,1.2538116,3.236564,,,,,,,,,,,,,, -96600,1.5507392,2.6877494,,,,,,,,,,,,,, -96700,1.5290226,2.1813822,,,,,,,,,,,,,, -96800,1.4507892,2.2511811,,,,,,,,,,,,,, -96852,,,0.7074413895606995,1.221797227859497,0.6502400040626526,1.4847956895828247,50000.0,0.5329000353813171,2.126685857772827,10000.0,44572.62438511848,49783.660767793655,44572.62438511848,5202.185204267502,3.811398506164551,0.0 -96900,1.2340902,4.7430553,,,,,,,,,,,,,, -97000,1.3436435,2.153183,,,,,,,,,,,,,, -97100,1.1635853,4.742498,,,,,,,,,,,,,, -97200,1.1405388,4.700865,,,,,,,,,,,,,, -97300,1.5381014,2.1821136,,,,,,,,,,,,,, -97400,1.2051953,4.2330575,,,,,,,,,,,,,, -97500,1.4544389,2.1688275,,,,,,,,,,,,,, -97600,1.5507034,2.2633185,,,,,,,,,,,,,, -97700,1.321057,3.7603087,,,,,,,,,,,,,, -97764,,,0.7198437452316284,1.1401715278625488,0.6556599736213684,1.4324363470077517,50000.0,0.5281000137329102,2.0989506244659424,10000.0,44992.95506215096,50255.37498688698,44992.95506215096,5253.482377767563,3.8508574962615967,0.0 -97800,1.3412722,2.2652662,,,,,,,,,,,,,, -97900,1.5392399,2.5342224,,,,,,,,,,,,,, -98000,1.284868,3.9200258,,,,,,,,,,,,,, -98100,1.2554085,3.1145823,,,,,,,,,,,,,, -98200,1.5287526,2.2587733,,,,,,,,,,,,,, -98300,1.4080173,4.044506,,,,,,,,,,,,,, -98400,1.3523028,4.2958527,,,,,,,,,,,,,, -98500,1.6757879,2.2827878,,,,,,,,,,,,,, -98600,1.2069564,4.5484986,,,,,,,,,,,,,, -98676,,,0.7141015529632568,1.160022854804993,0.6577399969100952,1.420377254486084,50000.0,0.5307000279426575,2.072648525238037,10000.0,45413.15657520294,50724.73337721825,45413.15657520294,5302.544720649719,3.8982207775115967,0.0 -98700,1.5456777,2.1428418,,,,,,,,,,,,,, -98800,1.4602256,2.17724,,,,,,,,,,,,,, -98900,1.47321,2.1157217,,,,,,,,,,,,,, -99000,1.5712904,2.1698742,,,,,,,,,,,,,, -99100,1.2844634,3.7201374,,,,,,,,,,,,,, -99200,1.5249351,2.172752,,,,,,,,,,,,,, -99300,1.4761347,2.245836,,,,,,,,,,,,,, -99400,1.5141593,2.1651492,,,,,,,,,,,,,, -99500,1.414085,2.120862,,,,,,,,,,,,,, -99592,,,0.7101367115974426,1.178592085838318,0.6572999954223633,1.4439600706100464,50000.0,0.5315000414848328,2.08896541595459,10000.0,45833.22182202339,51193.74008560181,45833.22182202339,5351.394645690918,3.9429714679718018,0.0 -99600,1.4069366,2.165883,,,,,,,,,,,,,, -99700,1.5494647,2.309933,,,,,,,,,,,,,, -99800,1.447386,2.2606711,,,,,,,,,,,,,, -99900,1.51005,2.1385877,,,,,,,,,,,,,, -100000,1.2742884,3.6737404,,,,,,,,,,,,,, -100100,1.469668,2.3359103,,,,,,,,,,,,,, -100200,1.4222734,2.2344708,,,,,,,,,,,,,, -100300,1.4932779,2.1995327,,,,,,,,,,,,,, -100400,1.228648,4.641656,,,,,,,,,,,,,, -100500,1.3315017,2.7238743,,,,,,,,,,,,,, -100506,,,0.7159960865974426,1.1765002012252808,0.6542199850082397,1.4638800621032717,50000.0,0.5288000106811523,2.1265082359313965,10000.0,46253.18224287033,51661.882147789,46253.18224287033,5399.48194694519,3.9898531436920166,0.0 -100600,1.5403746,2.412811,,,,,,,,,,,,,, -100700,1.46238,2.1144547,,,,,,,,,,,,,, -100800,1.4424692,2.1963658,,,,,,,,,,,,,, -100900,1.2579157,3.532375,,,,,,,,,,,,,, -101000,1.5436215,2.172441,,,,,,,,,,,,,, -101100,1.494336,2.240227,,,,,,,,,,,,,, -101200,1.4506721,2.214196,,,,,,,,,,,,,, -101300,1.4248824,2.2086399,,,,,,,,,,,,,, -101400,1.6075846,2.3041906,,,,,,,,,,,,,, -101420,,,0.7417382597923279,1.0204622745513916,0.6606000065803528,1.399212121963501,50000.0,0.5382000207901001,2.045794010162353,10000.0,46673.48130655289,52132.143812179565,46673.48130655289,5449.350612878799,4.036839246749878,0.0 -101500,1.4680514,2.190924,,,,,,,,,,,,,, -101600,1.6321892,2.313974,,,,,,,,,,,,,, -101700,1.4859775,2.3942895,,,,,,,,,,,,,, -101800,1.4131263,3.2001486,,,,,,,,,,,,,, -101900,1.4695007,2.6971216,,,,,,,,,,,,,, -102000,1.3312138,3.3709254,,,,,,,,,,,,,, -102100,1.4790573,2.2547977,,,,,,,,,,,,,, -102200,1.404255,2.2427697,,,,,,,,,,,,,, -102300,1.4593999,2.5173326,,,,,,,,,,,,,, -102335,,,0.7127734422683716,1.167461633682251,0.6566799879074097,1.4336830377578735,50000.0,0.5314000248908997,2.095097064971924,10000.0,47093.46519160271,52601.3354651928,47093.46519160271,5498.46777844429,4.080301761627197,0.0 -102400,1.4645814,2.149426,,,,,,,,,,,,,, -102500,1.5533334,2.1075954,,,,,,,,,,,,,, -102600,1.363211,4.290387,,,,,,,,,,,,,, -102700,1.556055,2.0572555,,,,,,,,,,,,,, -102800,1.4109484,2.2013845,,,,,,,,,,,,,, -102900,1.5434774,2.2282076,,,,,,,,,,,,,, -103000,1.508932,4.0457563,,,,,,,,,,,,,, -103100,1.4656489,2.1443617,,,,,,,,,,,,,, -103200,1.3087107,3.954581,,,,,,,,,,,,,, -103250,,,0.7226171493530273,1.1213455200195312,0.6646199822425842,1.3939058780670166,50000.0,0.5400000214576721,2.0467312335968018,10000.0,47513.68147063255,53071.09167742729,47513.68147063255,5547.918229103088,4.1213812828063965,0.0 -103300,1.6034303,2.1052318,,,,,,,,,,,,,, -103400,1.4305212,4.8490257,,,,,,,,,,,,,, -103500,1.4641913,2.8894577,,,,,,,,,,,,,, -103600,1.5697924,2.1253502,,,,,,,,,,,,,, -103700,1.544596,2.1320562,,,,,,,,,,,,,, -103800,1.4112427,4.0346947,,,,,,,,,,,,,, -103900,1.4281937,2.2013347,,,,,,,,,,,,,, -104000,1.4654334,2.044262,,,,,,,,,,,,,, -104100,1.4405613,2.2167645,,,,,,,,,,,,,, -104166,,,0.7356249690055847,1.0597867965698242,0.6630399823188782,1.3899242877960205,50000.0,0.5416000485420227,2.037737846374512,10000.0,47933.8608417511,53539.94333600998,47933.8608417511,5596.494349956512,4.169813394546509,0.0 -104200,1.5219427,2.1924005,,,,,,,,,,,,,, -104300,1.6842035,2.1494853,,,,,,,,,,,,,, -104400,1.7034836,2.0596986,,,,,,,,,,,,,, -104500,1.5632095,2.0890312,,,,,,,,,,,,,, -104600,1.6521524,2.1355734,,,,,,,,,,,,,, -104700,1.6201081,2.1510038,,,,,,,,,,,,,, -104800,1.2602512,4.6357803,,,,,,,,,,,,,, -104900,1.7513903,2.2195382,,,,,,,,,,,,,, -105000,1.431153,4.722731,,,,,,,,,,,,,, -105081,,,0.7218359112739563,1.1346908807754517,0.6655399799346924,1.3913841247558594,50000.0,0.5398000478744507,2.0553183555603027,10000.0,48353.88178706169,54009.95732069016,48353.88178706169,5646.397901058197,4.211780786514282,0.0 -105100,1.5145367,2.1627586,,,,,,,,,,,,,, -105200,1.573355,2.1422617,,,,,,,,,,,,,, -105300,1.6757275,2.4815226,,,,,,,,,,,,,, -105400,1.5164503,3.2914572,,,,,,,,,,,,,, -105500,1.887702,2.1864402,,,,,,,,,,,,,, -105600,1.4591353,2.618979,,,,,,,,,,,,,, -105700,1.4964739,2.0986152,,,,,,,,,,,,,, -105800,1.2433708,3.6993976,,,,,,,,,,,,,, -105900,1.5104626,2.7082362,,,,,,,,,,,,,, -105995,,,0.7276366949081421,1.0901107788085938,0.6678199768066406,1.3639365434646606,50000.0,0.5421000123023987,2.0250608921051025,10000.0,48774.124626636505,54482.09878492355,48774.124626636505,5698.203535318375,4.257709503173828,0.0 -106000,1.55023,2.1863832,,,,,,,,,,,,,, -106100,1.4906985,2.122502,,,,,,,,,,,,,, -106200,1.2930554,4.6452427,,,,,,,,,,,,,, -106300,1.6986182,2.0868495,,,,,,,,,,,,,, -106400,1.5170615,2.2740755,,,,,,,,,,,,,, -106500,1.5113039,2.2769148,,,,,,,,,,,,,, -106600,1.3899332,4.1139402,,,,,,,,,,,,,, -106700,1.3827485,4.5289564,,,,,,,,,,,,,, -106800,1.4198987,2.3360791,,,,,,,,,,,,,, -106900,1.5672853,3.5977845,,,,,,,,,,,,,, -106909,,,0.7373046875,1.0468322038650513,0.6692599654197693,1.3726061582565308,50000.0,0.5458000302314758,2.0267672538757324,10000.0,49194.25314736366,54950.18088555336,49194.25314736366,5746.058041095734,4.303881406784058,0.0 -107000,1.5744004,1.969625,,,,,,,,,,,,,, -107100,1.6220344,2.1637418,,,,,,,,,,,,,, -107200,1.6361973,2.1767502,,,,,,,,,,,,,, -107300,1.3254162,3.9153318,,,,,,,,,,,,,, -107400,1.6002964,2.1817095,,,,,,,,,,,,,, -107500,1.3943071,3.9402132,,,,,,,,,,,,,, -107600,1.4378393,4.631785,,,,,,,,,,,,,, -107700,1.3296218,2.4889188,,,,,,,,,,,,,, -107800,1.5772576,2.1878278,,,,,,,,,,,,,, -107821,,,0.7249218821525574,1.1311174631118774,0.6684199571609497,1.385303616523743,50000.0,0.5473000407218933,2.035310745239258,10000.0,49614.54778671265,55419.25144505501,49614.54778671265,5794.743673563004,4.347439050674439,0.0 -107900,1.7798985,2.1694882,,,,,,,,,,,,,, -108000,1.7827256,2.1629972,,,,,,,,,,,,,, -108100,1.306517,3.3450246,,,,,,,,,,,,,, -108200,1.4409126,4.837244,,,,,,,,,,,,,, -108300,1.4174628,2.3725104,,,,,,,,,,,,,, -108400,1.3983457,4.359259,,,,,,,,,,,,,, -108500,1.4673347,3.3993578,,,,,,,,,,,,,, -108600,1.6910189,2.037229,,,,,,,,,,,,,, -108700,1.4964783,4.751686,,,,,,,,,,,,,, -108735,,,0.7259179353713989,1.0984207391738892,0.6665199995040894,1.3816251754760742,50000.0,0.5497000217437744,2.023720026016236,10000.0,50034.81283926964,55889.60102963448,50034.81283926964,5844.7369022369385,4.391482353210449,0.0 -108800,1.6504676,2.0690355,,,,,,,,,,,,,, -108900,1.5464878,2.3932028,,,,,,,,,,,,,, -109000,1.4056128,2.801787,,,,,,,,,,,,,, -109100,1.5534045,2.003031,,,,,,,,,,,,,, -109200,1.5098709,2.411771,,,,,,,,,,,,,, -109300,1.6065719,2.1030955,,,,,,,,,,,,,, -109400,1.627106,2.3526275,,,,,,,,,,,,,, -109500,1.4941405,2.4828143,,,,,,,,,,,,,, -109600,1.4948701,2.5724206,,,,,,,,,,,,,, -109650,,,0.7411913871765137,1.045373558998108,0.6743599772453308,1.3515647649765017,50000.0,0.5544000267982483,1.9914335012435915,10000.0,50455.170258522034,56358.83649373055,50455.170258522034,5893.524285316467,4.434794902801514,0.0 -109700,1.3121505,3.6350243,,,,,,,,,,,,,, -109800,1.6233032,4.6700535,,,,,,,,,,,,,, -109900,1.5980477,2.0949311,,,,,,,,,,,,,, -110000,1.581609,2.0477223,,,,,,,,,,,,,, -110100,1.4601367,2.18005,,,,,,,,,,,,,, -110200,1.4233108,4.6715765,,,,,,,,,,,,,, -110300,1.6599267,2.2320962,,,,,,,,,,,,,, -110400,1.5919998,2.4054835,,,,,,,,,,,,,, -110500,1.5537034,3.0460215,,,,,,,,,,,,,, -110565,,,0.7310742139816284,1.099925875663757,0.6731199622154236,1.3631587028503418,50000.0,0.5525000095367432,2.017176866531372,10000.0,50875.44026255608,56829.10870409012,50875.44026255608,5943.436125278473,4.477481842041016,0.0 -110600,1.3435242,4.119673,,,,,,,,,,,,,, -110700,1.6259943,2.1037157,,,,,,,,,,,,,, -110800,1.51125,2.2360218,,,,,,,,,,,,,, -110900,1.548762,2.711893,,,,,,,,,,,,,, -111000,1.5346391,2.3920794,,,,,,,,,,,,,, -111100,1.624298,1.9638995,,,,,,,,,,,,,, -111200,1.5937219,2.2360473,,,,,,,,,,,,,, -111300,1.5155653,2.4789107,,,,,,,,,,,,,, -111400,1.5549152,2.6524403,,,,,,,,,,,,,, -111479,,,0.7366992235183716,1.0674903392791748,0.6771799921989441,1.3465999364852903,50000.0,0.5570000410079956,1.9866653680801392,10000.0,51295.71018028259,57298.81796312332,51295.71018028259,5992.779366493225,4.526298999786377,0.0 -111500,1.6973561,2.0365896,,,,,,,,,,,,,, -111600,1.8570744,2.165752,,,,,,,,,,,,,, -111700,1.8836752,2.076406,,,,,,,,,,,,,, -111800,1.5108011,4.7260547,,,,,,,,,,,,,, -111900,1.4797378,2.8865714,,,,,,,,,,,,,, -112000,1.6299244,4.498554,,,,,,,,,,,,,, -112100,1.6129719,2.107775,,,,,,,,,,,,,, -112200,1.5608128,2.064836,,,,,,,,,,,,,, -112300,1.6099008,1.9753999,,,,,,,,,,,,,, -112394,,,0.7471093535423279,1.0021445751190186,0.6795799732208252,1.315180420875549,50000.0,0.5523000359535217,1.9630584716796875,10000.0,51715.62279224396,57766.71099972725,51715.62279224396,6040.666308164597,4.571993350982666,0.0 -112400,1.5973328,2.1063304,,,,,,,,,,,,,, -112500,1.5376065,2.0648704,,,,,,,,,,,,,, -112600,1.5519998,2.4938717,,,,,,,,,,,,,, -112700,1.5886357,2.034062,,,,,,,,,,,,,, -112800,1.7991191,1.9754155,,,,,,,,,,,,,, -112900,2.4149516,2.192284,,,,,,,,,,,,,, -113000,1.6535902,2.082686,,,,,,,,,,,,,, -113100,1.5603065,2.9222705,,,,,,,,,,,,,, -113200,1.9765259,2.0993834,,,,,,,,,,,,,, -113300,1.3168466,3.8607128,,,,,,,,,,,,,, -113310,,,0.7491015195846558,1.0073623657226562,0.6755200028419495,1.3467522859573364,50000.0,0.5523000359535217,2.0030879974365234,10000.0,52135.81513166428,58234.475531339645,52135.81513166428,6088.146703958511,4.616669178009033,0.0 -113400,1.4252435,3.4321742,,,,,,,,,,,,,, -113500,1.5290134,2.9009204,,,,,,,,,,,,,, -113600,1.6563976,1.9893872,,,,,,,,,,,,,, -113700,1.6214528,2.1073635,,,,,,,,,,,,,, -113800,1.7330661,2.6541424,,,,,,,,,,,,,, -113900,1.5325309,3.6097627,,,,,,,,,,,,,, -114000,1.7921191,2.284464,,,,,,,,,,,,,, -114100,1.549813,2.066279,,,,,,,,,,,,,, -114200,1.5779845,2.025866,,,,,,,,,,,,,, -114226,,,0.7378710508346558,1.050584316253662,0.6758399605751038,1.3396703004837036,50000.0,0.5550000071525574,1.981272578239441,10000.0,52556.26023578644,58705.31205654144,52556.26023578644,6138.448607206345,4.658712863922119,0.0 -114300,1.6995095,1.9273093,,,,,,,,,,,,,, -114400,1.5587257,3.3228855,,,,,,,,,,,,,, -114500,1.5037936,3.0479906,,,,,,,,,,,,,, -114600,1.8206959,2.0970802,,,,,,,,,,,,,, -114700,1.6820422,2.0588064,,,,,,,,,,,,,, -114800,1.807224,2.0428863,,,,,,,,,,,,,, -114900,1.6745256,1.9871757,,,,,,,,,,,,,, -115000,1.689314,2.1840932,,,,,,,,,,,,,, -115100,1.4891404,2.137841,,,,,,,,,,,,,, -115140,,,0.7449804544448853,1.0316308736801147,0.6810399889945984,1.3292573690414429,50000.0,0.554900050163269,1.9811760187149048,10000.0,52976.29587888718,59176.41819024086,52976.29587888718,6189.425545454025,4.704372644424439,0.0 -115200,1.6818181,1.9735954,,,,,,,,,,,,,, -115300,1.616469,2.1814919,,,,,,,,,,,,,, -115400,1.708265,2.100059,,,,,,,,,,,,,, -115500,1.6236671,4.0578585,,,,,,,,,,,,,, -115600,1.7631245,2.0334132,,,,,,,,,,,,,, -115700,1.5764312,2.5861936,,,,,,,,,,,,,, -115800,1.6808115,1.9853303,,,,,,,,,,,,,, -115900,1.5963804,2.3875499,,,,,,,,,,,,,, -116000,1.3897192,3.9764895,,,,,,,,,,,,,, -116054,,,0.7593359351158142,0.968052327632904,0.6818599700927734,1.3153448104858398,50000.0,0.5635000467300415,1.9470332860946653,10000.0,53396.448315382,59648.02452993393,53396.448315382,6240.787206888199,4.749409198760986,0.0 -116100,1.5923774,2.2569954,,,,,,,,,,,,,, -116200,1.4930876,3.4076114,,,,,,,,,,,,,, -116300,1.806474,2.6348052,,,,,,,,,,,,,, -116400,1.4897302,4.2531013,,,,,,,,,,,,,, -116500,1.841487,2.0423868,,,,,,,,,,,,,, -116600,1.7799686,2.2332804,,,,,,,,,,,,,, -116700,1.493906,3.333519,,,,,,,,,,,,,, -116800,1.7251998,2.028784,,,,,,,,,,,,,, -116900,1.4972675,4.069503,,,,,,,,,,,,,, -116970,,,0.7439843416213989,1.0527466535568235,0.6826399564743042,1.3293102979660034,50000.0,0.5591000318527222,1.9881199598312376,10000.0,53816.71623015404,60119.25689291954,53816.71623015404,6291.655651569367,4.797953367233276,0.0 -117000,1.725115,2.0804794,,,,,,,,,,,,,, -117100,1.4927028,3.5669212,,,,,,,,,,,,,, -117200,1.7433485,2.1910748,,,,,,,,,,,,,, -117300,2.0078576,2.0685027,,,,,,,,,,,,,, -117400,1.7635888,1.9918513,,,,,,,,,,,,,, -117500,1.6208767,1.873037,,,,,,,,,,,,,, -117600,1.7352506,1.9225421,,,,,,,,,,,,,, -117700,1.4946951,2.984484,,,,,,,,,,,,,, -117800,1.6380503,4.3125834,,,,,,,,,,,,,, -117884,,,0.748828113079071,1.0334590673446655,0.6802399754524231,1.337364912033081,50000.0,0.5621000528335571,1.9683139324188232,10000.0,54236.68958616257,60588.64545035362,54236.68958616257,6340.977566003799,4.843794584274292,0.0 -117900,1.4725739,3.6032982,,,,,,,,,,,,,, -118000,1.5611522,3.4245684,,,,,,,,,,,,,, -118100,1.5430346,4.5640855,,,,,,,,,,,,,, -118200,1.3991811,4.4654994,,,,,,,,,,,,,, -118300,1.7645491,2.0973125,,,,,,,,,,,,,, -118400,1.609449,4.049049,,,,,,,,,,,,,, -118500,1.5480763,3.8502643,,,,,,,,,,,,,, -118600,1.4293699,2.9878695,,,,,,,,,,,,,, -118700,1.5917174,3.1766908,,,,,,,,,,,,,, -118799,,,0.7620312571525574,0.9354296922683716,0.6877599954605103,1.279233694076538,50000.0,0.5672000050544739,1.9073601961135864,10000.0,54656.7682967186,61058.40561413765,54656.7682967186,6390.563494682312,4.891946077346802,0.0 -118800,1.9637659,2.027794,,,,,,,,,,,,,, -118900,1.6563418,3.5463276,,,,,,,,,,,,,, -119000,1.6238525,4.2871604,,,,,,,,,,,,,, -119100,1.9448044,1.9751046,,,,,,,,,,,,,, -119200,1.8131306,2.0147874,,,,,,,,,,,,,, -119300,1.6189483,4.1159325,,,,,,,,,,,,,, -119400,1.7371273,2.01889,,,,,,,,,,,,,, -119500,1.8344673,1.9697726,,,,,,,,,,,,,, -119600,2.0861723,2.365529,,,,,,,,,,,,,, -119700,1.6743916,2.43299,,,,,,,,,,,,,, -119714,,,0.7458398342132568,1.0181658267974854,0.6830599904060364,1.300762176513672,50000.0,0.5613000392913818,1.9275920391082764,10000.0,55076.75383043289,61526.98484683037,55076.75383043289,6439.0599138736725,4.941753149032593,0.0 -119800,1.6130813,4.2810564,,,,,,,,,,,,,, -119900,1.4797218,4.5124345,,,,,,,,,,,,,, -120000,1.6857107,2.153903,,,,,,,,,,,,,, -120100,1.6032207,4.3943667,,,,,,,,,,,,,, -120200,1.5693095,2.200224,,,,,,,,,,,,,, -120300,1.5256256,3.8142235,,,,,,,,,,,,,, -120400,1.7810994,2.0700357,,,,,,,,,,,,,, -120500,1.924585,2.0417573,,,,,,,,,,,,,, -120600,1.579121,3.9722924,,,,,,,,,,,,,, -120629,,,0.7542773485183716,0.980697214603424,0.6851599812507629,1.289842963218689,50000.0,0.5610000491142273,1.920379638671875,10000.0,55496.69174456597,61996.93345713616,55496.69174456597,6488.97558426857,4.9889538288116455,0.0 -120700,1.5790373,2.81595,,,,,,,,,,,,,, -120800,1.5308626,4.2691984,,,,,,,,,,,,,, -120900,1.7705619,2.0653906,,,,,,,,,,,,,, -121000,1.7312751,1.8842653,,,,,,,,,,,,,, -121100,1.6060716,2.5831661,,,,,,,,,,,,,, -121200,1.5031201,3.195912,,,,,,,,,,,,,, -121300,1.6321015,2.2395384,,,,,,,,,,,,,, -121400,1.6502792,2.3503385,,,,,,,,,,,,,, -121500,1.5727535,2.3880847,,,,,,,,,,,,,, -121539,,,0.7638280987739563,0.9494880437850952,0.6881799697875977,1.290436625480652,50000.0,0.5708000063896179,1.9232209920883176,10000.0,55916.75475072861,62466.457062006,55916.75475072861,6538.339520454407,5.038301229476929,0.0 -121600,1.4737244,3.8434138,,,,,,,,,,,,,, -121700,1.5266991,3.3905082,,,,,,,,,,,,,, -121800,1.7119371,2.0714478,,,,,,,,,,,,,, -121900,1.6609548,2.281532,,,,,,,,,,,,,, -122000,1.7039729,2.0535805,,,,,,,,,,,,,, -122100,1.6488185,4.070706,,,,,,,,,,,,,, -122200,1.7956448,1.9034994,,,,,,,,,,,,,, -122300,1.9478956,2.0452495,,,,,,,,,,,,,, -122400,1.7405345,1.9224362,,,,,,,,,,,,,, -122454,,,0.753613293170929,1.002234935760498,0.6880599856376648,1.3044852018356323,50000.0,0.5659000277519226,1.932502269744873,10000.0,56337.04533982277,62936.76914644241,56337.04533982277,6588.268809556961,5.083476781845093,0.0 -122500,1.6458225,4.4955707,,,,,,,,,,,,,, -122600,1.7146529,2.9047346,,,,,,,,,,,,,, -122700,1.8487158,2.9455109,,,,,,,,,,,,,, -122800,1.7396885,4.444957,,,,,,,,,,,,,, -122900,1.5292454,3.8759356,,,,,,,,,,,,,, -123000,1.5110973,4.3839536,,,,,,,,,,,,,, -123100,1.7990265,2.006256,,,,,,,,,,,,,, -123200,1.8499509,2.3291657,,,,,,,,,,,,,, -123300,1.9322438,1.9938806,,,,,,,,,,,,,, -123369,,,0.7582421898841858,0.9696675539016724,0.692579984664917,1.262717843055725,50000.0,0.5748000144958496,1.90095317363739,10000.0,56757.24954080582,63404.89881134033,56757.24954080582,6636.103061914444,5.127235651016235,0.0 -123400,1.7764965,3.5260837,,,,,,,,,,,,,, -123500,1.923405,3.0129604,,,,,,,,,,,,,, -123600,2.0752764,1.9409899,,,,,,,,,,,,,, -123700,1.6245875,4.378442,,,,,,,,,,,,,, -123800,1.920209,2.137454,,,,,,,,,,,,,, -123900,1.920755,2.694937,,,,,,,,,,,,,, -124000,1.806053,1.9940544,,,,,,,,,,,,,, -124100,2.034778,4.616212,,,,,,,,,,,,,, -124200,1.8428359,1.8358241,,,,,,,,,,,,,, -124285,,,0.7618163824081421,0.9538049101829528,0.6948999762535095,1.2704392671585083,50000.0,0.5712000131607056,1.8988127708435056,10000.0,57177.33669304848,63871.52213644981,57177.33669304848,6682.54062128067,5.174166440963745,0.0 -124300,1.5939579,4.538171,,,,,,,,,,,,,, -124400,1.7445394,2.1431372,,,,,,,,,,,,,, -124500,1.6716884,2.9811943,,,,,,,,,,,,,, -124600,1.8891604,2.0047164,,,,,,,,,,,,,, -124700,1.6834538,3.2635741,,,,,,,,,,,,,, -124800,1.7444475,3.5593674,,,,,,,,,,,,,, -124900,1.9083216,4.555889,,,,,,,,,,,,,, -125000,1.926033,3.3464515,,,,,,,,,,,,,, -125100,1.9860541,4.084976,,,,,,,,,,,,,, -125198,,,0.7659569978713989,0.96906316280365,0.6935399770736694,1.2955999374389648,50000.0,0.5733000040054321,1.946867108345032,10000.0,57597.292941093445,64342.43574547768,57597.292941093445,6733.407159566879,5.218145370483398,0.0 -125200,1.4723488,2.800051,,,,,,,,,,,,,, -125300,1.7486119,1.91339,,,,,,,,,,,,,, -125400,1.5613497,2.3849273,,,,,,,,,,,,,, -125500,1.5410267,3.0744236,,,,,,,,,,,,,, -125600,2.1789336,1.8790195,,,,,,,,,,,,,, -125700,1.9553602,1.8136861,,,,,,,,,,,,,, -125800,1.7350626,3.150313,,,,,,,,,,,,,, -125900,1.9413235,2.065018,,,,,,,,,,,,,, -126000,1.8213031,1.8905625,,,,,,,,,,,,,, -126100,1.6771662,1.7698874,,,,,,,,,,,,,, -126112,,,0.7572070360183716,0.9875859618186952,0.6926199793815613,1.279220461845398,50000.0,0.5680000185966492,1.9271174669265747,10000.0,58017.22869372368,64812.9195394516,58017.22869372368,6783.851147174835,5.274590492248535,0.0 -126200,1.7114052,2.3236523,,,,,,,,,,,,,, -126300,1.9736199,1.937851,,,,,,,,,,,,,, -126400,1.8054968,1.9175608,,,,,,,,,,,,,, -126500,2.1074095,1.8805066,,,,,,,,,,,,,, -126600,1.926058,1.9057065,,,,,,,,,,,,,, -126700,1.6109906,2.563337,,,,,,,,,,,,,, -126800,2.2008255,1.90664,,,,,,,,,,,,,, -126900,1.938049,1.8268677,,,,,,,,,,,,,, -127000,1.8175592,2.161122,,,,,,,,,,,,,, -127025,,,0.7670117020606995,0.9423275589942932,0.6972000002861023,1.2566183805465698,50000.0,0.5734000205993652,1.893884301185608,10000.0,58437.14878249168,65282.39766430855,58437.14878249168,6833.316943883896,5.319833278656006,0.0 -127100,1.8459821,2.1233997,,,,,,,,,,,,,, -127200,2.209191,3.5224984,,,,,,,,,,,,,, -127300,1.9398546,2.7025497,,,,,,,,,,,,,, -127400,1.8403529,1.8346409,,,,,,,,,,,,,, -127500,2.0471761,1.8716421,,,,,,,,,,,,,, -127600,2.0534575,2.037884,,,,,,,,,,,,,, -127700,1.9612082,1.8554229,,,,,,,,,,,,,, -127800,2.495421,2.172573,,,,,,,,,,,,,, -127900,1.7588639,4.090966,,,,,,,,,,,,,, -127939,,,0.7809179425239563,0.883868932723999,0.6997999548912048,1.2388687133789062,50000.0,0.5772000551223755,1.8826701641082764,10000.0,58857.462636232376,65752.2223212719,58857.462636232376,6882.7317888736725,5.368030786514282,0.0 -128000,1.827364,1.8348796,,,,,,,,,,,,,, -128100,2.187672,1.8424945,,,,,,,,,,,,,, -128200,1.801985,2.3322067,,,,,,,,,,,,,, -128300,1.667142,3.5453994,,,,,,,,,,,,,, -128400,1.7774063,4.433899,,,,,,,,,,,,,, -128500,1.6618605,3.0900369,,,,,,,,,,,,,, -128600,1.8896369,4.316093,,,,,,,,,,,,,, -128700,1.905649,2.0536451,,,,,,,,,,,,,, -128800,1.7618868,2.0133793,,,,,,,,,,,,,, -128854,,,0.7710937261581421,0.9130802154541016,0.7034199833869934,1.210724949836731,50000.0,0.582800030708313,1.845189332962036,10000.0,59277.4666519165,66218.96444773674,59277.4666519165,6929.378044605255,5.412663459777832,0.0 -128900,1.9609755,1.9790022,,,,,,,,,,,,,, -129000,1.7688501,2.2877502,,,,,,,,,,,,,, -129100,1.9187909,2.3514583,,,,,,,,,,,,,, -129200,2.0564117,1.8017352,,,,,,,,,,,,,, -129300,1.8210299,1.8410676,,,,,,,,,,,,,, -129400,1.8123273,4.436377,,,,,,,,,,,,,, -129500,1.5795462,4.152633,,,,,,,,,,,,,, -129600,1.8821492,1.9809663,,,,,,,,,,,,,, -129700,2.19508,1.9732361,,,,,,,,,,,,,, -129767,,,0.7718359231948853,0.9090878367424012,0.6992599964141846,1.23419189453125,50000.0,0.5736000537872314,1.8769426345825195,10000.0,59697.61399292946,66688.71971225739,59697.61399292946,6978.882484436035,5.467691898345947,0.0 -129800,1.8929453,2.951091,,,,,,,,,,,,,, -129900,2.0213873,1.9110148,,,,,,,,,,,,,, -130000,2.0579734,2.108717,,,,,,,,,,,,,, -130100,2.2154064,1.8722873,,,,,,,,,,,,,, -130200,1.8872877,1.9782556,,,,,,,,,,,,,, -130300,1.7784618,3.048883,,,,,,,,,,,,,, -130400,1.7426815,2.6968875,,,,,,,,,,,,,, -130500,1.7015359,2.3204353,,,,,,,,,,,,,, -130600,2.0232174,2.0942116,,,,,,,,,,,,,, -130683,,,0.7808593511581421,0.8642115592956543,0.7039200067520142,1.218134522438049,50000.0,0.5805000066757202,1.8567651510238647,10000.0,60117.49275612831,67161.56207251549,60117.49275612831,7031.3348553180695,5.931267976760864,0.0 -130700,1.8621216,3.6765292,,,,,,,,,,,,,, -130800,1.7315167,2.6469827,,,,,,,,,,,,,, -130900,2.0107667,1.9389427,,,,,,,,,,,,,, -131000,1.9133829,1.9290389,,,,,,,,,,,,,, -131100,2.4551594,1.9920754,,,,,,,,,,,,,, -131200,2.0345573,4.5504007,,,,,,,,,,,,,, -131300,1.6654347,2.6350775,,,,,,,,,,,,,, -131400,2.0481489,3.7189562,,,,,,,,,,,,,, -131500,1.6299038,4.2111793,,,,,,,,,,,,,, -131596,,,0.7708203196525574,0.8994258642196655,0.7038799524307251,1.2020835876464844,50000.0,0.5819000005722046,1.8415201902389529,10000.0,60537.74565386772,67632.65162563324,60537.74565386772,7082.074404239655,5.981815576553345,0.0 -131600,2.0099592,1.8681991,,,,,,,,,,,,,, -131700,2.0009158,1.9046266,,,,,,,,,,,,,, -131800,1.8160146,4.3605084,,,,,,,,,,,,,, -131900,1.8231742,4.333833,,,,,,,,,,,,,, -132000,1.8953105,3.1766636,,,,,,,,,,,,,, -132100,1.9179507,3.59906,,,,,,,,,,,,,, -132200,1.9166129,2.0141785,,,,,,,,,,,,,, -132300,1.9757843,2.1090598,,,,,,,,,,,,,, -132400,1.8522557,2.8173335,,,,,,,,,,,,,, -132500,2.2924433,1.8968726,,,,,,,,,,,,,, -132511,,,0.7748632431030273,0.8906111717224121,0.7036799788475037,1.2126256227493286,50000.0,0.5835000276565552,1.8380862474441528,10000.0,60957.78933095932,68103.68083000183,60957.78933095932,7132.960758924484,6.033755540847778,0.0 -132600,2.1330144,2.1063716,,,,,,,,,,,,,, -132700,1.9172007,2.781083,,,,,,,,,,,,,, -132800,2.100128,2.206347,,,,,,,,,,,,,, -132900,2.7025123,1.7370967,,,,,,,,,,,,,, -133000,1.9479896,2.0188267,,,,,,,,,,,,,, -133100,1.7273719,3.0742922,,,,,,,,,,,,,, -133200,1.715679,2.4982255,,,,,,,,,,,,,, -133300,1.7874753,2.7463245,,,,,,,,,,,,,, -133400,1.924373,1.8866259,,,,,,,,,,,,,, -133426,,,0.7884570360183716,0.8362293839454651,0.7095800042152405,1.18121600151062,50000.0,0.5866000056266785,1.816677451133728,10000.0,61377.8827316761,68570.23081469536,61377.8827316761,7179.321115255356,6.082520008087158,0.0 -133500,1.891737,1.8270434,,,,,,,,,,,,,, -133600,1.9773688,2.1831439,,,,,,,,,,,,,, -133700,2.0713265,1.998608,,,,,,,,,,,,,, -133800,1.7275769,1.7707587,,,,,,,,,,,,,, -133900,1.8857864,1.8684629,,,,,,,,,,,,,, -134000,1.826308,2.8295946,,,,,,,,,,,,,, -134100,1.8386179,2.5972168,,,,,,,,,,,,,, -134200,2.0361373,1.7830156,,,,,,,,,,,,,, -134300,1.652375,3.373997,,,,,,,,,,,,,, -134343,,,0.7745312452316284,0.8990936279296875,0.7057799696922302,1.2014926671981812,50000.0,0.5829000473022461,1.8295793533325195,10000.0,61798.15866804123,69040.90519189835,61798.15866804123,7229.612575292587,6.1410746574401855,0.0 -134400,1.9049503,3.8121674,,,,,,,,,,,,,, -134500,1.8115993,3.7551165,,,,,,,,,,,,,, -134600,1.8796403,4.1310153,,,,,,,,,,,,,, -134700,2.1328773,4.306842,,,,,,,,,,,,,, -134800,2.0677903,1.7768636,,,,,,,,,,,,,, -134900,1.9527836,2.3295488,,,,,,,,,,,,,, -135000,2.2287958,1.7389045,,,,,,,,,,,,,, -135100,2.0572205,1.8704152,,,,,,,,,,,,,, -135200,1.8374201,2.254871,,,,,,,,,,,,,, -135256,,,0.7787109017372131,0.886769711971283,0.7080000042915344,1.2056972980499268,50000.0,0.5863000154495239,1.840918064117432,10000.0,62218.16717839241,69509.96885418892,62218.16717839241,7278.565196990967,6.195879459381104,0.0 -135300,2.2523139,1.7714584,,,,,,,,,,,,,, -135400,2.5986173,3.4747453,,,,,,,,,,,,,, -135500,1.9645113,2.2146673,,,,,,,,,,,,,, -135600,2.110637,1.8591974,,,,,,,,,,,,,, -135700,1.7938565,2.7473483,,,,,,,,,,,,,, -135800,2.147632,1.9326024,,,,,,,,,,,,,, -135900,1.877706,1.725388,,,,,,,,,,,,,, -136000,2.0948544,2.2136447,,,,,,,,,,,,,, -136100,1.9870968,2.3635824,,,,,,,,,,,,,, -136170,,,0.7862108945846558,0.8529349565505981,0.7099399566650391,1.196141004562378,50000.0,0.5907000303268433,1.8285341262817385,10000.0,62638.280061244965,69978.45606637001,62638.280061244965,7326.843999385834,6.2438578605651855,0.0 -136200,1.8002777,3.76145,,,,,,,,,,,,,, -136300,2.9667523,1.6525791,,,,,,,,,,,,,, -136400,1.886107,3.886273,,,,,,,,,,,,,, -136500,1.799387,2.1770995,,,,,,,,,,,,,, -136600,1.8995876,2.403483,,,,,,,,,,,,,, -136700,2.126512,1.8329456,,,,,,,,,,,,,, -136800,2.1255174,1.8033643,,,,,,,,,,,,,, -136900,1.8801786,3.754036,,,,,,,,,,,,,, -137000,1.9161524,1.8069624,,,,,,,,,,,,,, -137084,,,0.7843554615974426,0.8620164394378662,0.7106199860572815,1.1773674488067627,50000.0,0.5909000039100647,1.807421922683716,10000.0,63058.56229448319,70450.00231456757,63058.56229448319,7378.012609243393,6.291548013687134,0.0 -137100,2.2201545,1.7073548,,,,,,,,,,,,,, -137200,1.9446672,2.5312667,,,,,,,,,,,,,, -137300,2.1919312,3.3955803,,,,,,,,,,,,,, -137400,2.0782037,2.0572865,,,,,,,,,,,,,, -137500,1.9645244,1.712432,,,,,,,,,,,,,, -137600,2.041048,2.377942,,,,,,,,,,,,,, -137700,1.7817068,4.2614822,,,,,,,,,,,,,, -137800,2.1132762,1.992394,,,,,,,,,,,,,, -137900,2.2014964,4.319412,,,,,,,,,,,,,, -137998,,,0.7840625047683716,0.8761175274848938,0.7105000019073486,1.1932687759399414,50000.0,0.5921000242233276,1.8124998807907104,10000.0,63478.914189100266,70919.76289439201,63478.914189100266,7427.31960606575,6.34608793258667,0.0 -138000,1.813429,2.9168935,,,,,,,,,,,,,, -138100,2.0333066,2.1643684,,,,,,,,,,,,,, -138200,1.9584801,4.1849337,,,,,,,,,,,,,, -138300,2.0113025,2.9977028,,,,,,,,,,,,,, -138400,1.8517493,2.662723,,,,,,,,,,,,,, -138500,1.9027929,2.0814555,,,,,,,,,,,,,, -138600,2.1583254,2.2159214,,,,,,,,,,,,,, -138700,1.8285758,3.100926,,,,,,,,,,,,,, -138800,2.3509917,2.0863628,,,,,,,,,,,,,, -138900,1.9140114,3.6724243,,,,,,,,,,,,,, -138913,,,0.7923241853713989,0.8205782771110535,0.714199960231781,1.164374589920044,50000.0,0.5958000421524048,1.7992960214614868,10000.0,63899.1681098938,71390.19374513626,63899.1681098938,7477.396207094192,6.399160146713257,0.0 -139000,2.2270699,1.7985736,,,,,,,,,,,,,, -139100,2.1891859,2.5282307,,,,,,,,,,,,,, -139200,2.089672,1.8876979,,,,,,,,,,,,,, -139300,2.0771966,2.0593681,,,,,,,,,,,,,, -139400,2.1482315,1.807342,,,,,,,,,,,,,, -139500,1.8739707,4.0042963,,,,,,,,,,,,,, -139600,1.9563326,1.9049776,,,,,,,,,,,,,, -139700,2.0515077,1.7487516,,,,,,,,,,,,,, -139800,1.9599202,2.6257095,,,,,,,,,,,,,, -139830,,,0.8011132478713989,0.8038817644119263,0.7127000093460083,1.1908022165298462,50000.0,0.5888000130653381,1.8122637271881104,10000.0,64319.3515021801,71858.77561020851,64319.3515021801,7525.692511558533,6.453774929046631,0.0 -139900,2.2416546,1.9270375,,,,,,,,,,,,,, -140000,2.8235457,1.720548,,,,,,,,,,,,,, -140100,1.9261818,1.7964346,,,,,,,,,,,,,, -140200,2.3247411,1.9757757,,,,,,,,,,,,,, -140300,2.053011,1.7566581,,,,,,,,,,,,,, -140400,2.0245545,1.9600468,,,,,,,,,,,,,, -140500,2.2591515,1.8345015,,,,,,,,,,,,,, -140600,2.1166546,2.5558236,,,,,,,,,,,,,, -140700,2.2213461,1.7147602,,,,,,,,,,,,,, -140740,,,0.7881835699081421,0.8385329246520996,0.7159799933433533,1.167323350906372,50000.0,0.5986000299453735,1.7915613651275637,10000.0,64739.52413749695,72328.59392142296,64739.52413749695,7575.239279747009,6.506052732467651,0.0 -140800,2.1895523,2.0447147,,,,,,,,,,,,,, -140900,2.161484,2.931872,,,,,,,,,,,,,, -141000,2.2145717,1.7903717,,,,,,,,,,,,,, -141100,2.1323347,1.629521,,,,,,,,,,,,,, -141200,2.0807939,2.3158298,,,,,,,,,,,,,, -141300,2.3653994,1.688257,,,,,,,,,,,,,, -141400,1.9279205,3.1301756,,,,,,,,,,,,,, -141500,2.2722552,3.6600282,,,,,,,,,,,,,, -141600,2.5653727,1.8002516,,,,,,,,,,,,,, -141654,,,0.7936913967132568,0.8237582445144653,0.7181000113487244,1.1573004722595217,50000.0,0.5964000225067139,1.7926957607269287,10000.0,65159.69063377381,72796.25745105743,65159.69063377381,7622.6402060985565,6.555207014083862,0.0 -141700,2.2634785,1.9270092,,,,,,,,,,,,,, -141800,2.1465254,2.6144562,,,,,,,,,,,,,, -141900,2.2230523,1.8343618,,,,,,,,,,,,,, -142000,2.2706952,2.3777874,,,,,,,,,,,,,, -142100,2.203069,1.9169325,,,,,,,,,,,,,, -142200,1.9137927,2.1247792,,,,,,,,,,,,,, -142300,2.0417545,1.9852861,,,,,,,,,,,,,, -142400,2.0951319,4.2386093,,,,,,,,,,,,,, -142500,2.1095343,2.7767742,,,,,,,,,,,,,, -142571,,,0.8095507621765137,0.75356125831604,0.7208200097084045,1.128256916999817,50000.0,0.6039000153541565,1.7557858228683472,10000.0,65579.90997314453,73267.46773982048,65579.90997314453,7673.526467323303,6.612675666809082,0.0 -142600,2.135288,2.8516269,,,,,,,,,,,,,, -142700,2.1943796,1.7811744,,,,,,,,,,,,,, -142800,2.1721165,1.8349663,,,,,,,,,,,,,, -142900,2.2819262,1.706272,,,,,,,,,,,,,, -143000,2.0559783,3.1358864,,,,,,,,,,,,,, -143100,2.28987,4.2404613,,,,,,,,,,,,,, -143200,3.1272218,2.6815557,,,,,,,,,,,,,, -143300,2.0359595,3.7824783,,,,,,,,,,,,,, -143400,2.106227,1.8059775,,,,,,,,,,,,,, -143487,,,0.7930663824081421,0.80852872133255,0.719539999961853,1.132944107055664,50000.0,0.6000000238418579,1.7516762018203735,10000.0,66000.25936365128,73737.62437415123,66000.25936365128,7723.234818935394,6.664331912994385,0.0 -143500,2.2900326,2.4279723,,,,,,,,,,,,,, -143600,2.2387102,2.3062344,,,,,,,,,,,,,, -143700,2.033116,2.1779568,,,,,,,,,,,,,, -143800,2.475299,4.272487,,,,,,,,,,,,,, -143900,2.0361724,3.57358,,,,,,,,,,,,,, -144000,2.1939232,1.7601786,,,,,,,,,,,,,, -144100,2.0761511,4.249564,,,,,,,,,,,,,, -144200,2.1030884,2.8926578,,,,,,,,,,,,,, -144300,2.3018968,1.7413372,,,,,,,,,,,,,, -144400,,,0.7961718440055847,0.7973366379737854,0.7209199666976929,1.1386834383010864,50000.0,0.6027000546455383,1.758381962776184,10000.0,66420.45049548149,74208.18918466568,66420.45049548149,7773.509160041809,6.716897964477539,0.0 -144400,2.2823806,2.123879,,,,,,,,,,,,,, -144500,2.1485085,1.7979217,,,,,,,,,,,,,, -144600,2.2123704,1.6842507,,,,,,,,,,,,,, -144700,2.0211804,1.8523349,,,,,,,,,,,,,, -144800,2.0148492,3.8635662,,,,,,,,,,,,,, -144900,2.0073047,1.667954,,,,,,,,,,,,,, -145000,2.0426562,3.8125489,,,,,,,,,,,,,, -145100,2.1413937,3.5266478,,,,,,,,,,,,,, -145200,2.136817,2.9135072,,,,,,,,,,,,,, -145300,2.0910413,4.121146,,,,,,,,,,,,,, -145314,,,0.80712890625,0.7481192350387573,0.7227999567985535,1.1151829957962036,50000.0,0.6045000553131104,1.7355223894119265,10000.0,66840.47810816765,74677.93926692009,66840.47810816765,7823.125457286835,6.775168180465698,0.0 -145400,2.0627017,2.0516841,,,,,,,,,,,,,, -145500,1.8941202,2.9789212,,,,,,,,,,,,,, -145600,2.3392155,1.6304964,,,,,,,,,,,,,, -145700,2.1687791,2.9779453,,,,,,,,,,,,,, -145800,1.9510299,2.2688851,,,,,,,,,,,,,, -145900,2.1536577,1.7231209,,,,,,,,,,,,,, -146000,2.2673798,3.6923683,,,,,,,,,,,,,, -146100,2.5398486,2.0482764,,,,,,,,,,,,,, -146200,2.291686,1.7360023,,,,,,,,,,,,,, -146229,,,0.8002148270606995,0.7895472049713135,0.7212600111961365,1.1343475580215454,50000.0,0.5997000336647034,1.7654008865356443,10000.0,67260.58725810051,75146.9090578556,67260.58725810051,7871.883242845535,6.829352855682373,0.0 -146300,1.9065272,3.133309,,,,,,,,,,,,,, -146400,3.7316391,2.24301,,,,,,,,,,,,,, -146500,2.2944322,3.8300667,,,,,,,,,,,,,, -146600,2.9184654,3.2942743,,,,,,,,,,,,,, -146700,2.4839568,1.9490397,,,,,,,,,,,,,, -146800,2.2391133,4.043631,,,,,,,,,,,,,, -146900,1.9689687,3.2599869,,,,,,,,,,,,,, -147000,2.2822154,1.7308601,,,,,,,,,,,,,, -147100,2.19569,1.7513814,,,,,,,,,,,,,, -147140,,,0.7982617020606995,0.7943770289421082,0.7218599915504456,1.133278250694275,50000.0,0.6002000570297241,1.7653189897537231,10000.0,67680.65770792961,75616.78656816483,67680.65770792961,7921.5880670547485,6.884565353393555,0.0 -147200,2.0925841,1.8861455,,,,,,,,,,,,,, -147300,2.400855,1.7884284,,,,,,,,,,,,,, -147400,2.3452866,2.0015254,,,,,,,,,,,,,, -147500,2.2446163,1.7151569,,,,,,,,,,,,,, -147600,2.2691228,1.9534245,,,,,,,,,,,,,, -147700,2.22007,1.7768924,,,,,,,,,,,,,, -147800,2.3204222,2.5298882,,,,,,,,,,,,,, -147900,2.3988001,1.7479229,,,,,,,,,,,,,, -148000,1.9772176,2.9504304,,,,,,,,,,,,,, -148051,,,0.8085156083106995,0.7679116725921631,0.7251200079917908,1.131921410560608,50000.0,0.602400004863739,1.749596118927002,10000.0,68100.57943248749,76085.59352064133,68100.57943248749,7970.376499891281,6.934350490570068,0.0 -148100,2.1907332,1.8508986,,,,,,,,,,,,,, -148200,2.1997094,1.5966313,,,,,,,,,,,,,, -148300,2.4176457,1.7914519,,,,,,,,,,,,,, -148400,2.2187972,1.7782421,,,,,,,,,,,,,, -148500,2.2976837,1.74876,,,,,,,,,,,,,, -148600,2.2450473,3.403708,,,,,,,,,,,,,, -148700,2.3701904,1.645428,,,,,,,,,,,,,, -148800,2.2614393,1.7494814,,,,,,,,,,,,,, -148900,2.154786,2.2519574,,,,,,,,,,,,,, -148964,,,0.8046875,0.7722216844558716,0.7258399724960327,1.1195939779281616,50000.0,0.6067000031471252,1.739829421043396,10000.0,68520.65689635277,76556.07243704796,68520.65689635277,8020.678261995316,6.98644495010376,0.0 -149000,2.366788,4.111918,,,,,,,,,,,,,, -149100,2.5020704,1.6773458,,,,,,,,,,,,,, -149200,2.293873,1.6751124,,,,,,,,,,,,,, -149300,2.3765557,1.6190654,,,,,,,,,,,,,, -149400,2.208099,1.5741993,,,,,,,,,,,,,, -149500,2.2856314,1.7167947,,,,,,,,,,,,,, -149600,3.418219,1.754642,,,,,,,,,,,,,, -149700,2.3210976,1.8390639,,,,,,,,,,,,,, -149800,2.1711268,4.2468266,,,,,,,,,,,,,, -149875,,,0.8066015243530273,0.7632541656494141,0.7278800010681152,1.111369013786316,50000.0,0.6060000061988831,1.7303481101989746,10000.0,68940.75487065315,77025.45209693909,68940.75487065315,8069.858564853668,7.0408594608306885,0.0 -149900,1.9947698,3.560317,,,,,,,,,,,,,, -150000,2.1448722,2.7947562,,,,,,,,,,,,,, -150100,2.2842875,1.7633529,,,,,,,,,,,,,, -150200,2.2131944,3.473332,,,,,,,,,,,,,, -150300,2.584227,2.5803804,,,,,,,,,,,,,, -150400,2.5305538,1.9656901,,,,,,,,,,,,,, -150500,2.5673983,1.7111399,,,,,,,,,,,,,, -150600,3.1118565,2.0555096,,,,,,,,,,,,,, -150700,2.2086995,1.6914124,,,,,,,,,,,,,, -150790,,,0.8119726181030273,0.7188385725021362,0.7307599782943726,1.0796045064926147,50000.0,0.6104000210762024,1.7081046104431152,10000.0,69360.97912240028,77494.48408341408,69360.97912240028,8118.56437587738,7.0947325229644775,0.0 -150800,2.473481,1.7895745,,,,,,,,,,,,,, -150900,2.3451495,1.8129436,,,,,,,,,,,,,, -151000,2.1866114,2.082495,,,,,,,,,,,,,, -151100,2.1424906,2.0327876,,,,,,,,,,,,,, -151200,2.2742515,4.0639577,,,,,,,,,,,,,, -151300,2.1633668,3.1963289,,,,,,,,,,,,,, -151400,2.350355,4.053628,,,,,,,,,,,,,, -151500,2.183278,4.086027,,,,,,,,,,,,,, -151600,2.2682607,2.3389416,,,,,,,,,,,,,, -151700,2.5127091,1.7997373,,,,,,,,,,,,,, -151704,,,0.8213671445846558,0.7049360275268555,0.7328000068664551,1.0796293020248413,50000.0,0.6109000444412231,1.7023441791534424,10000.0,69781.3741095066,77963.0119497776,69781.3741095066,8166.598405122757,7.145893335342407,0.0 -151800,2.74438,1.6838472,,,,,,,,,,,,,, -151900,2.629194,1.9204133,,,,,,,,,,,,,, -152000,2.6559117,1.7029979,,,,,,,,,,,,,, -152100,2.3735535,1.5708526,,,,,,,,,,,,,, -152200,2.5487282,2.0203173,,,,,,,,,,,,,, -152300,2.2951624,1.5931295,,,,,,,,,,,,,, -152400,2.4571407,3.3045118,,,,,,,,,,,,,, -152500,3.0489442,1.8924495,,,,,,,,,,,,,, -152600,2.4079885,1.612385,,,,,,,,,,,,,, -152618,,,0.8115234375,0.7339527010917664,0.7320799827575684,1.0806304216384888,50000.0,0.61080002784729,1.709767460823059,10000.0,70201.74805235863,78432.79470396042,70201.74805235863,8215.91041135788,7.195410490036011,0.0 -152700,2.5624855,3.8572893,,,,,,,,,,,,,, -152800,2.2060075,3.6384377,,,,,,,,,,,,,, -152900,7.0289235,1.6177033,,,,,,,,,,,,,, -153000,2.2268903,1.5476254,,,,,,,,,,,,,, -153100,2.5961556,1.5445383,,,,,,,,,,,,,, -153200,2.6117113,1.6762658,,,,,,,,,,,,,, -153300,2.3857627,3.8256683,,,,,,,,,,,,,, -153400,2.2719853,1.7813148,,,,,,,,,,,,,, -153500,2.3926563,2.2894542,,,,,,,,,,,,,, -153535,,,0.815722644329071,0.7261427640914917,0.7325199842453003,1.0846915245056152,50000.0,0.6096000075340271,1.706923484802246,10000.0,70622.05784344673,78903.90107440948,70622.05784344673,8266.604093313217,7.249720573425293,0.0 -153600,2.8565428,1.612169,,,,,,,,,,,,,, -153700,2.4625793,1.6213436,,,,,,,,,,,,,, -153800,2.3487997,2.185784,,,,,,,,,,,,,, -153900,2.273941,1.5937836,,,,,,,,,,,,,, -154000,3.5504797,1.6241679,,,,,,,,,,,,,, -154100,2.290918,2.8961985,,,,,,,,,,,,,, -154200,2.422048,1.528523,,,,,,,,,,,,,, -154300,2.8296428,1.7366372,,,,,,,,,,,,,, -154400,2.449479,3.6296186,,,,,,,,,,,,,, -154448,,,0.8185155987739563,0.7208164930343628,0.730239987373352,1.097428798675537,50000.0,0.61080002784729,1.7289506196975708,10000.0,71042.29072260857,79374.33994674683,71042.29072260857,8316.71166753769,7.300796031951904,0.0 -154500,2.4714618,3.4995925,,,,,,,,,,,,,, -154600,2.3583312,2.3262553,,,,,,,,,,,,,, -154700,2.3905952,3.2870612,,,,,,,,,,,,,, -154800,2.2387452,2.230258,,,,,,,,,,,,,, -154900,2.285903,4.179448,,,,,,,,,,,,,, -155000,2.3839016,1.5293624,,,,,,,,,,,,,, -155100,2.246227,2.7131405,,,,,,,,,,,,,, -155200,2.482985,1.7339611,,,,,,,,,,,,,, -155300,2.4949765,4.093935,,,,,,,,,,,,,, -155361,,,0.8110156059265137,0.7516688108444214,0.731440007686615,1.091639757156372,50000.0,0.6110000014305115,1.7229865789413452,10000.0,71462.36037421227,79842.9433221817,71462.36037421227,8365.147471904755,7.35153603553772,0.0 -155400,2.4632626,3.32886,,,,,,,,,,,,,, -155500,2.6866505,1.6408198,,,,,,,,,,,,,, -155600,2.5584154,1.7074771,,,,,,,,,,,,,, -155700,2.4235542,1.7763436,,,,,,,,,,,,,, -155800,2.3617618,3.364598,,,,,,,,,,,,,, -155900,2.3511102,2.1393273,,,,,,,,,,,,,, -156000,2.3606653,3.6830735,,,,,,,,,,,,,, -156100,2.7809656,1.9067172,,,,,,,,,,,,,, -156200,2.2588615,1.6247132,,,,,,,,,,,,,, -156274,,,0.8199804425239563,0.7003152370452881,0.7363599538803101,1.0666792392730713,50000.0,0.6121000051498413,1.7028956413269043,10000.0,71882.66050243378,80314.7297205925,71882.66050243378,8416.536350011826,7.401159286499023,0.0 -156300,2.434148,2.3417656,,,,,,,,,,,,,, -156400,2.5679178,1.476946,,,,,,,,,,,,,, -156500,2.222241,3.4844756,,,,,,,,,,,,,, -156600,2.7956824,3.160487,,,,,,,,,,,,,, -156700,2.6487877,4.066863,,,,,,,,,,,,,, -156800,2.8367913,1.5328101,,,,,,,,,,,,,, -156900,2.5996213,2.883998,,,,,,,,,,,,,, -157000,2.5827699,4.03124,,,,,,,,,,,,,, -157100,2.789932,1.5264076,,,,,,,,,,,,,, -157186,,,0.82289057970047,0.6994317173957825,0.7371799945831299,1.0725938081741333,50000.0,0.6162000298500061,1.7018332481384275,10000.0,72303.01607465744,80785.31089067459,72303.01607465744,8466.652791976929,7.462125539779663,0.0 -157200,2.7026992,1.5403035,,,,,,,,,,,,,, -157300,2.5918753,3.2716985,,,,,,,,,,,,,, -157400,2.6852062,2.507794,,,,,,,,,,,,,, -157500,2.442107,2.6742697,,,,,,,,,,,,,, -157600,2.5484538,2.0771973,,,,,,,,,,,,,, -157700,2.8348312,3.1707926,,,,,,,,,,,,,, -157800,2.4332974,1.5067797,,,,,,,,,,,,,, -157900,2.9383583,4.109537,,,,,,,,,,,,,, -158000,2.4997735,2.8009486,,,,,,,,,,,,,, -158097,,,0.821582019329071,0.718077540397644,0.738599956035614,1.0715638399124146,50000.0,0.6152000427246094,1.7035168409347534,10000.0,72723.06442546844,81254.37323331833,72723.06442546844,8515.557217359543,7.524913787841797,0.0 -158100,2.9422514,1.6890025,,,,,,,,,,,,,, -158200,2.627483,4.0393505,,,,,,,,,,,,,, -158300,2.9427667,1.5783377,,,,,,,,,,,,,, -158400,2.6911857,3.1670303,,,,,,,,,,,,,, -158500,2.380884,3.9846573,,,,,,,,,,,,,, -158600,2.531214,1.9562408,,,,,,,,,,,,,, -158700,2.8785162,1.7194799,,,,,,,,,,,,,, -158800,2.512054,1.5158069,,,,,,,,,,,,,, -158900,2.5582898,1.523528,,,,,,,,,,,,,, -159000,2.6544251,1.4869255,,,,,,,,,,,,,, -159010,,,0.8240429759025574,0.699968159198761,0.7378199696540833,1.0641032457351685,50000.0,0.6173000335693359,1.6965481042861938,10000.0,73143.2449285984,81725.02041864395,73143.2449285984,8565.915897130966,7.585332155227661,0.0 -159100,2.3926291,3.41692,,,,,,,,,,,,,, -159200,2.3299503,3.7427444,,,,,,,,,,,,,, -159300,2.5469902,1.6426957,,,,,,,,,,,,,, -159400,2.3212466,1.731125,,,,,,,,,,,,,, -159500,2.728911,1.7999718,,,,,,,,,,,,,, -159600,2.6538641,1.5955727,,,,,,,,,,,,,, -159700,2.6921346,2.832108,,,,,,,,,,,,,, -159800,2.2406147,2.8062518,,,,,,,,,,,,,, -159900,2.7193096,1.5313983,,,,,,,,,,,,,, -159924,,,0.822558581829071,0.6985819935798645,0.7378399968147278,1.068437933921814,50000.0,0.6152999997138977,1.6972639560699463,10000.0,73563.15134334564,82192.29157710075,73563.15134334564,8613.183167696,7.635556221008301,0.0 -160000,2.8202567,3.7340355,,,,,,,,,,,,,, -160100,3.7007902,1.5744444,,,,,,,,,,,,,, -160200,2.8558671,1.5662866,,,,,,,,,,,,,, -160300,2.9090211,1.604398,,,,,,,,,,,,,, -160400,2.5091515,1.5897094,,,,,,,,,,,,,, -160500,2.7292428,1.563863,,,,,,,,,,,,,, -160600,2.7283363,1.6291113,,,,,,,,,,,,,, -160700,2.8305404,1.5343274,,,,,,,,,,,,,, -160800,2.6276925,2.9348497,,,,,,,,,,,,,, -160838,,,0.8229296803474426,0.6872890591621399,0.7397199869155884,1.054733395576477,50000.0,0.6131000518798828,1.6841599941253662,10000.0,73983.04493808746,82659.67759394646,73983.04493808746,8660.566656589508,7.696813583374023,0.0 -160900,2.4082701,2.271676,,,,,,,,,,,,,, -161000,2.5927842,3.0758934,,,,,,,,,,,,,, -161100,2.724163,1.8174974,,,,,,,,,,,,,, -161200,2.4282484,2.1594002,,,,,,,,,,,,,, -161300,2.6688812,1.5477941,,,,,,,,,,,,,, -161400,2.506119,2.7589853,,,,,,,,,,,,,, -161500,2.512453,3.521865,,,,,,,,,,,,,, -161600,3.0673468,1.533242,,,,,,,,,,,,,, -161700,2.5788305,2.4692633,,,,,,,,,,,,,, -161752,,,0.8261327743530273,0.6801848411560059,0.7404199838638306,1.0489463806152344,50000.0,0.615600049495697,1.6832082271575928,10000.0,74403.15802598,83129.38182520866,74403.15802598,8710.053017377853,7.754112958908081,0.0 -161800,2.657496,4.0811954,,,,,,,,,,,,,, -161900,2.6822913,2.193638,,,,,,,,,,,,,, -162000,2.5598412,3.3697236,,,,,,,,,,,,,, -162100,2.3559735,2.1016932,,,,,,,,,,,,,, -162200,2.3437402,1.6191288,,,,,,,,,,,,,, -162300,2.53573,2.0331435,,,,,,,,,,,,,, -162400,2.6472104,1.4511846,,,,,,,,,,,,,, -162500,3.337155,3.9932694,,,,,,,,,,,,,, -162600,2.726613,2.2683694,,,,,,,,,,,,,, -162668,,,0.8266015648841858,0.6813116073608398,0.7404999732971191,1.0605944395065308,50000.0,0.6183000206947327,1.6932504177093506,10000.0,74823.39261484146,83598.31682682037,74823.39261484146,8758.655586957932,7.804761648178101,0.0 -162700,2.3543706,3.6069362,,,,,,,,,,,,,, -162800,2.605392,2.308408,,,,,,,,,,,,,, -162900,2.8256328,1.5561824,,,,,,,,,,,,,, -163000,2.5143187,3.11068,,,,,,,,,,,,,, -163100,2.8773603,4.127394,,,,,,,,,,,,,, -163200,2.7569735,1.560152,,,,,,,,,,,,,, -163300,2.5452409,2.695476,,,,,,,,,,,,,, -163400,2.6398847,2.0001247,,,,,,,,,,,,,, -163500,2.971863,1.5270561,,,,,,,,,,,,,, -163582,,,0.82958984375,0.6572279334068298,0.7425599694252014,1.0434695482254028,50000.0,0.6205000281333923,1.6695793867111206,10000.0,75243.69893074036,84069.89703917503,75243.69893074036,8809.826647281647,7.860961198806763,0.0 -163600,2.6714253,1.4364185,,,,,,,,,,,,,, -163700,2.5347695,1.7836621,,,,,,,,,,,,,, -163800,2.575445,3.992654,,,,,,,,,,,,,, -163900,2.8718755,3.7883215,,,,,,,,,,,,,, -164000,3.129883,1.4518547,,,,,,,,,,,,,, -164100,3.2149134,1.4347596,,,,,,,,,,,,,, -164200,2.8358767,1.8451322,,,,,,,,,,,,,, -164300,3.0518346,3.2571468,,,,,,,,,,,,,, -164400,2.5573382,2.0053165,,,,,,,,,,,,,, -164496,,,0.8274804353713989,0.6675172448158264,0.7411800026893616,1.0367389917373655,50000.0,0.6194000244140625,1.6700612306594849,10000.0,75663.79537057877,84539.8385951519,75663.79537057877,8859.568863630295,7.916836023330688,0.0 -164500,2.335134,3.0634375,,,,,,,,,,,,,, -164600,2.5841947,1.35489,,,,,,,,,,,,,, -164700,2.5953155,2.5281992,,,,,,,,,,,,,, -164800,3.3497937,1.5426083,,,,,,,,,,,,,, -164900,2.7731988,2.2849164,,,,,,,,,,,,,, -165000,2.6155627,1.6178334,,,,,,,,,,,,,, -165100,7.127208,1.6435124,,,,,,,,,,,,,, -165200,3.1321175,1.6713903,,,,,,,,,,,,,, -165300,2.8449829,1.5768875,,,,,,,,,,,,,, -165400,2.426423,3.0922687,,,,,,,,,,,,,, -165411,,,0.8327734470367432,0.6660227179527283,0.741599977016449,1.0521215200424194,50000.0,0.6206000447273254,1.6760276556015017,10000.0,76083.82977080345,85010.10476636887,76083.82977080345,8909.697760820389,7.97200608253479,0.0 -165500,3.2577076,1.5920396,,,,,,,,,,,,,, -165600,2.7839735,2.9793465,,,,,,,,,,,,,, -165700,2.7189837,3.8718028,,,,,,,,,,,,,, -165800,2.6671255,2.060661,,,,,,,,,,,,,, -165900,3.0237708,1.4951881,,,,,,,,,,,,,, -166000,2.789225,1.498455,,,,,,,,,,,,,, -166100,2.8854907,1.6151427,,,,,,,,,,,,,, -166200,3.2040913,2.552286,,,,,,,,,,,,,, -166300,2.4392655,2.659749,,,,,,,,,,,,,, -166329,,,0.8364452719688416,0.6335233449935913,0.7430599927902222,1.0303434133529663,50000.0,0.6205000281333923,1.660801887512207,10000.0,76504.16225075722,85479.89331364632,76504.16225075722,8959.051559209824,8.026431798934937,0.0 -166400,2.791218,1.9468317,,,,,,,,,,,,,, -166500,2.8684924,1.3728315,,,,,,,,,,,,,, -166600,2.728005,3.8665364,,,,,,,,,,,,,, -166700,2.9370937,1.5148156,,,,,,,,,,,,,, -166800,2.8158896,1.6613938,,,,,,,,,,,,,, -166900,2.7197752,1.4955826,,,,,,,,,,,,,, -167000,2.606429,1.4810734,,,,,,,,,,,,,, -167100,3.1908484,1.6889868,,,,,,,,,,,,,, -167200,2.7182076,2.275215,,,,,,,,,,,,,, -167241,,,0.8333203196525574,0.6550984978675842,0.7444999814033508,1.033787965774536,50000.0,0.619100034236908,1.6686664819717407,10000.0,76924.07387590408,85950.96992588043,76924.07387590408,9010.117893695831,8.077922344207764,0.0 -167300,2.6259153,3.1146984,,,,,,,,,,,,,, -167400,4.981511,1.5102228,,,,,,,,,,,,,, -167500,2.8535132,3.0333505,,,,,,,,,,,,,, -167600,2.8809636,1.4732412,,,,,,,,,,,,,, -167700,4.0291204,2.349555,,,,,,,,,,,,,, -167800,2.8755536,1.4956532,,,,,,,,,,,,,, -167900,2.909161,1.419075,,,,,,,,,,,,,, -168000,3.3080282,1.3950797,,,,,,,,,,,,,, -168100,2.921434,1.5212653,,,,,,,,,,,,,, -168155,,,0.8345312476158142,0.6509010195732117,0.7447999715805054,1.033661127090454,50000.0,0.6255000233650208,1.6653497219085691,10000.0,77344.41281080246,86418.69306540489,77344.41281080246,9057.402058124542,8.131011486053467,0.0 -168200,3.3383105,1.5059235,,,,,,,,,,,,,, -168300,3.358402,2.225858,,,,,,,,,,,,,, -168400,3.016772,2.4544811,,,,,,,,,,,,,, -168500,2.7144632,2.4550703,,,,,,,,,,,,,, -168542,,,,,,,,,,,77520.18580436707,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index b9dcbb68f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -183.698429107666,0.0,62.217230796813965,1,0,62.217230796813965,31.27949,2472,1.0976986980277457,245.91571712493896,32.02916,1.3788965703933531,31.163591,5348,1.043156299178389 -291.5967798233032,0.0429005622863769,1502.7666385173798,1738,0,1502.7666385173798,6.6806073,2472,0.899579550301627,1794.4777307510376,6.6277523,0.944635537887994,6.710369,5348,0.8966179750330672 -419.49665999412537,0.0937092304229736,2943.0448310375214,3514,0,2943.0448310375214,3.023245,2472,0.625312290536835,3362.7826220989227,2.9477355,0.6334869349812215,3.3468678,5348,0.6825936259980497 -553.7924075126648,0.1545546054840088,4383.173229217529,5244,0,4383.173229217529,0.86656547,2472,0.2749172303130014,4937.342960596085,0.7954038,0.26949412527476,1.1628219,5348,0.3323517769388957 -690.0164759159088,0.2065155506134033,5823.355607032776,6967,0,5823.355607032776,0.570568,2472,0.1934881075701257,6513.878301858902,0.52575487,0.1829163568578579,0.8364061,5348,0.2505092829489172 -824.9669606685638,0.2596597671508789,7263.37749004364,8695,0,7263.37749004364,0.49796993,2472,0.1694595088660045,8088.980140447617,0.4505587,0.160277457652917,0.7544828,5348,0.2271257132374948 -961.338276863098,0.3113949298858642,8703.979434251785,10387,0,8703.979434251785,0.4417418,2472,0.1507931671846119,9666.078685045242,0.4173494,0.1483587710921787,0.68799996,5348,0.2082894851173523 -1098.0837564468384,0.3614089488983154,10144.905450105667,12101,0,10144.905450105667,0.40709147,2472,0.1400483415595231,11243.875876426697,0.35887453,0.1294535553995703,0.641523,5348,0.1964142618535003 -1235.3221170902252,0.4123513698577881,11585.165503025057,13827,0,11585.165503025057,0.38234127,2472,0.130075355960433,12821.49998164177,0.30647576,0.1146051639408201,0.61191267,5348,0.1859196539772343 -1371.180969953537,0.4650459289550781,13025.10080051422,15527,0,13025.10080051422,0.35999078,2472,0.122214774643024,14397.421264886856,0.2841638,0.1044799940415706,0.58510756,5348,0.1773559767129768 -1507.310597896576,0.5201177597045898,14466.018681049349,17244,0,14466.018681049349,0.3494229,2472,0.1185993134686084,15974.59987092018,0.27435565,0.1028117575764076,0.569697,5348,0.172161773366674 -1643.9017758369446,0.5702481269836426,15906.14902305603,18961,0,15906.14902305603,0.34298202,2472,0.1168525176202953,17551.447067975998,0.28704607,0.1045823527856368,0.55894655,5348,0.170549446305647 -1783.0626783370972,0.6249892711639404,17346.25297522545,20642,0,17346.25297522545,0.32796475,2472,0.1121605427254077,19130.841157197952,0.2800363,0.1034584641359115,0.54050195,5348,0.1640229008370584 -1918.723219871521,0.6793062686920166,18786.501445770264,22354,0,18786.501445770264,0.3156809,2472,0.109093494201044,20706.88010787964,0.2611929,0.096928696367939,0.5234255,5348,0.1602672407967019 -2053.571811437607,0.733173131942749,20226.675240516663,24076,0,20226.675240516663,0.3084755,2472,0.104239026669104,22282.030920743942,0.24182494,0.0891273594865081,0.5162125,5348,0.1568881122256871 -2189.5265684127808,0.7874441146850586,21666.55222582817,25751,0,21666.55222582817,0.30117932,2472,0.102309426604107,23857.990137577057,0.22988309,0.0859229505203578,0.50804406,5348,0.1546192687565772 -2325.708679676056,0.8463056087493896,23107.052596330643,27470,0,23107.052596330643,0.29376552,2472,0.1009891739280563,25434.80774664879,0.22008015,0.081281770668194,0.49360743,5348,0.1496664317367755 -2464.611401796341,0.904517650604248,24547.576735258102,29189,0,24547.576735258102,0.2872465,2472,0.0984908496333759,27014.36773943901,0.2380161,0.0883168660079736,0.48860037,5348,0.1491933537368334 -2599.4616689682007,0.9569101333618164,25988.7187268734,30881,0,25988.7187268734,0.28259403,2472,0.0946316495033818,28590.486193180084,0.2175279,0.0792353729537328,0.47386247,5348,0.1434874537783484 -2738.427877187729,1.0178043842315674,27428.926631212234,32563,0,27428.926631212234,0.26877353,2472,0.091117746227124,30169.796103715897,0.22219987,0.0804956104978193,0.46052456,5348,0.1377525898606833 -2875.2527759075165,1.0767803192138672,28868.84626793861,34277,0,28868.84626793861,0.27387086,2472,0.091930209412386,31746.674035787582,0.21909791,0.0772190074277737,0.4654405,5348,0.1390752773299091 -3013.4304831027985,1.1376783847808838,30309.341188192368,35964,0,30309.341188192368,0.26163143,2472,0.0876647776897609,33325.48345398903,0.17266439,0.0654295173375152,0.45770597,5348,0.1368643617791594 -3149.3453755378723,1.1959314346313477,31749.26070761681,37649,0,31749.26070761681,0.26120028,2472,0.0886600450917067,34901.451577186584,0.19681297,0.0729393254630314,0.4440631,5348,0.134247950799888 -3283.162401914597,1.2573490142822266,33189.13189792633,39356,0,33189.13189792633,0.2517394,2472,0.0832165417504519,36475.276661872864,0.24031529,0.0871487228999311,0.4375541,5348,0.1294013149637467 -3418.5341413021088,1.3193349838256836,34629.38103199005,41039,0,34629.38103199005,0.24431132,2472,0.0838462007190299,38051.03439188004,0.24005286,0.0886769514314105,0.4300553,5348,0.1283875763924423 -3551.3137764930725,1.3813042640686035,36069.76509022713,42750,0,36069.76509022713,0.24818589,2472,0.0830540491133995,39624.3362300396,0.28034556,0.1032108615919181,0.42529032,5348,0.1256649642295104 -3685.243673801422,1.445460557937622,37509.9409840107,44453,0,37509.9409840107,0.23739682,2472,0.080738529035403,41198.581345796585,0.24660096,0.0885492407962366,0.41894877,5348,0.124660880311266 -3819.056458711624,1.5136487483978271,38950.01946210861,46142,0,38950.01946210861,0.23045741,2472,0.0773871183961976,42772.61559915543,0.22179091,0.0823600146972041,0.40570727,5348,0.1204804155362677 -3954.9089529514313,1.5759549140930176,40389.98576760292,47851,0,40389.98576760292,0.22821124,2472,0.0758231267645684,44348.57243299484,0.18740244,0.0704520984458776,0.4032107,5348,0.1185205209650791 -4090.652411699295,1.6439871788024902,41830.2703332901,49534,0,41830.2703332901,0.21855989,2472,0.0722889119086791,45924.74306154251,0.20381515,0.0767133554557452,0.3923336,5348,0.1146876237002423 -4225.891577243805,1.700354814529419,43270.42231464386,51230,0,43270.42231464386,0.21695636,2472,0.0702374423658928,47500.26622223854,0.17861539,0.0677503865315237,0.3843447,5348,0.1125925639862131 -4360.704236030579,1.857272148132324,44710.28479671478,52957,0,44710.28479671478,0.21086185,2472,0.0709686592326285,49075.17360305786,0.17695369,0.0670163486260428,0.372745,5348,0.1091941261090782 -4496.820524454117,1.9165270328521729,46150.37541222572,54652,0,46150.37541222572,0.19986272,2472,0.067393821217476,50651.51419734955,0.16497089,0.0618523119392684,0.36614746,5348,0.1077169641908918 -4631.75553393364,1.9792718887329104,47590.53766059876,56368,0,47590.53766059876,0.1941069,2472,0.0644080190116385,52226.748883485794,0.16453665,0.0620743335279237,0.3558999,5348,0.1043861088851772 -4767.43133020401,2.0423855781555176,49030.65796470642,58095,0,49030.65796470642,0.19088641,2472,0.0614425283854325,53802.68481111527,0.16163893,0.0600747122055348,0.34983188,5348,0.1020110642324068 -4900.822810411453,2.103106737136841,50470.96928143501,59789,0,50470.96928143501,0.1848671,2472,0.0623159263095891,55376.52167439461,0.13394181,0.0514773466047929,0.34509662,5348,0.1003601185591395 -5035.821505784988,2.164682149887085,51911.2644867897,61491,0,51911.2644867897,0.1836257,2472,0.061462839965064,56951.9517929554,0.13361306,0.0517890557254347,0.3403858,5348,0.0993753439470152 -5168.687787055969,2.230812788009644,53351.83570098877,63210,0,53351.83570098877,0.17569458,2472,0.057827067211017,58525.53119134903,0.11915162,0.04634981922064,0.33299387,5348,0.0958610502331598 -5302.091877937317,2.291220664978028,54792.682903051376,64881,0,54792.682903051376,0.17374022,2472,0.0569130461275973,60099.91721081734,0.115036234,0.0437265496450064,0.32092267,5348,0.0931867113355281 -5436.998205900192,2.361929655075073,56232.99437975884,66568,0,56232.99437975884,0.17099647,2472,0.0549631344829687,61675.28289914131,0.10550904,0.0410768989959974,0.3176728,5348,0.0916226575398013 -5572.620210170746,2.421889066696167,57673.43183207512,68277,0,57673.43183207512,0.16618471,2472,0.0533382081124449,63251.47676539421,0.09899141,0.038625017987005,0.31188878,5348,0.0894213966421116 -5706.620025396347,2.478905200958252,59113.41282272339,69968,0,59113.41282272339,0.16443966,2472,0.0522820059716044,64825.58831310272,0.079612635,0.0305982258456632,0.3068214,5348,0.0878380335402647 -5839.893024682999,2.547360420227051,60553.82369470596,71672,0,60553.82369470596,0.16102934,2472,0.05179452806044726,66399.41762804985,0.07433639,0.028280459691990507,0.30262092,5348,0.0857526284792956 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index 623977daf..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,769 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,46.656998,32.599678,,,,,,,,,,,,,, -1,,,32.02916,1.3788965703933531,31.163591,1.043156299178389,5348.0,31.27949,1.0976986980277457,2472.0,62.217230796813965,245.91571712493896,62.217230796813965,183.698429107666,0.0,0.0 -100,23.548836,8.065066,,,,,,,,,,,,,, -200,2.1804693,6.2075763,,,,,,,,,,,,,, -300,1.2001466,5.863913,,,,,,,,,,,,,, -400,0.3545736,5.8146315,,,,,,,,,,,,,, -500,0.25007635,5.8123407,,,,,,,,,,,,,, -600,0.30275348,5.81114,,,,,,,,,,,,,, -700,0.8503318,5.817859,,,,,,,,,,,,,, -800,0.5862387,5.8034534,,,,,,,,,,,,,, -900,0.2547178,5.771727,,,,,,,,,,,,,, -1000,0.45023257,5.800831,,,,,,,,,,,,,, -1100,0.49985904,5.785528,,,,,,,,,,,,,, -1200,0.36314166,5.770524,,,,,,,,,,,,,, -1300,0.4102992,5.785315,,,,,,,,,,,,,, -1400,0.29685614,5.772067,,,,,,,,,,,,,, -1500,0.4834307,5.776932,,,,,,,,,,,,,, -1600,0.8626517,5.7182226,,,,,,,,,,,,,, -1700,0.7020875,5.5808964,,,,,,,,,,,,,, -1738,,,6.6277523,0.944635537887994,6.710369,0.8966179750330672,5348.0,6.6806073,0.899579550301627,2472.0,1502.7666385173798,1794.4777307510376,1502.7666385173798,291.5967798233032,0.0429005622863769,0.0 -1800,2.1735997,5.529718,,,,,,,,,,,,,, -1900,0.4659097,5.293019,,,,,,,,,,,,,, -2000,1.6944855,4.8540144,,,,,,,,,,,,,, -2100,0.7902208,4.2205257,,,,,,,,,,,,,, -2200,1.165855,3.8382928,,,,,,,,,,,,,, -2300,1.1654595,3.554293,,,,,,,,,,,,,, -2400,1.1425022,3.4257731,,,,,,,,,,,,,, -2500,1.0388069,3.2409708,,,,,,,,,,,,,, -2600,0.99440765,3.1783264,,,,,,,,,,,,,, -2700,1.0505192,3.055125,,,,,,,,,,,,,, -2800,0.93418926,2.9381843,,,,,,,,,,,,,, -2900,0.94370264,2.88715,,,,,,,,,,,,,, -3000,1.1566055,2.8263507,,,,,,,,,,,,,, -3100,0.96228415,2.6967132,,,,,,,,,,,,,, -3200,1.1070373,2.677327,,,,,,,,,,,,,, -3300,0.89924574,2.6485388,,,,,,,,,,,,,, -3400,0.9805911,2.5894585,,,,,,,,,,,,,, -3500,0.9831185,2.4926217,,,,,,,,,,,,,, -3514,,,2.9477355,0.6334869349812215,3.3468678,0.6825936259980497,5348.0,3.023245,0.625312290536835,2472.0,2943.0448310375214,3362.7826220989227,2943.0448310375214,419.49665999412537,0.0937092304229736,0.0 -3600,1.0634602,2.5180745,,,,,,,,,,,,,, -3700,1.0157248,2.4562056,,,,,,,,,,,,,, -3800,1.1646472,2.4087448,,,,,,,,,,,,,, -3900,0.97938764,2.3711936,,,,,,,,,,,,,, -4000,0.8527164,2.3354409,,,,,,,,,,,,,, -4100,1.0324324,2.352014,,,,,,,,,,,,,, -4200,0.82422775,2.2590895,,,,,,,,,,,,,, -4300,0.8164481,2.2601583,,,,,,,,,,,,,, -4400,0.86986834,2.1996124,,,,,,,,,,,,,, -4500,1.0098107,2.1804352,,,,,,,,,,,,,, -4600,1.0828389,2.1094346,,,,,,,,,,,,,, -4700,0.9277333,2.0849087,,,,,,,,,,,,,, -4800,0.88593674,2.0691435,,,,,,,,,,,,,, -4900,1.005391,2.0766442,,,,,,,,,,,,,, -5000,0.9979167,2.0405545,,,,,,,,,,,,,, -5100,0.87382483,1.9977913,,,,,,,,,,,,,, -5200,0.8808874,2.0188167,,,,,,,,,,,,,, -5244,,,0.7954038,0.26949412527476,1.1628219,0.3323517769388957,5348.0,0.86656547,0.2749172303130014,2472.0,4383.173229217529,4937.342960596085,4383.173229217529,553.7924075126648,0.1545546054840088,0.0 -5300,0.9456566,1.9309417,,,,,,,,,,,,,, -5400,0.7732535,1.9237714,,,,,,,,,,,,,, -5500,0.74882853,1.9416391,,,,,,,,,,,,,, -5600,0.8846495,1.9598322,,,,,,,,,,,,,, -5700,0.8924287,1.9251055,,,,,,,,,,,,,, -5800,0.7623401,1.8749583,,,,,,,,,,,,,, -5900,0.86402076,1.8576535,,,,,,,,,,,,,, -6000,1.0296663,1.8836772,,,,,,,,,,,,,, -6100,0.8704034,1.9025551,,,,,,,,,,,,,, -6200,0.7770647,1.7713923,,,,,,,,,,,,,, -6300,0.81875235,1.8574382,,,,,,,,,,,,,, -6400,0.7580315,1.8701979,,,,,,,,,,,,,, -6500,0.7444122,1.8487306,,,,,,,,,,,,,, -6600,0.76966053,1.7808175,,,,,,,,,,,,,, -6700,0.8684463,1.8210816,,,,,,,,,,,,,, -6800,0.7371543,1.7793189,,,,,,,,,,,,,, -6900,0.7922671,1.8491775,,,,,,,,,,,,,, -6967,,,0.52575487,0.1829163568578579,0.8364061,0.2505092829489172,5348.0,0.570568,0.1934881075701257,2472.0,5823.355607032776,6513.878301858902,5823.355607032776,690.0164759159088,0.2065155506134033,0.0 -7000,0.79934156,1.8281717,,,,,,,,,,,,,, -7100,0.6924585,1.8055068,,,,,,,,,,,,,, -7200,0.8903379,1.7367201,,,,,,,,,,,,,, -7300,0.73693603,1.7603211,,,,,,,,,,,,,, -7400,0.8583274,1.7326612,,,,,,,,,,,,,, -7500,0.745073,1.7058889,,,,,,,,,,,,,, -7600,0.7780295,1.7501632,,,,,,,,,,,,,, -7700,0.7993589,1.7549914,,,,,,,,,,,,,, -7800,1.0684489,1.6730757,,,,,,,,,,,,,, -7900,0.8251163,1.7692696,,,,,,,,,,,,,, -8000,0.8624465,1.7250788,,,,,,,,,,,,,, -8100,0.79784375,1.7336799,,,,,,,,,,,,,, -8200,0.8295858,1.7273443,,,,,,,,,,,,,, -8300,0.6864698,1.6806833,,,,,,,,,,,,,, -8400,0.70353276,1.6759274,,,,,,,,,,,,,, -8500,0.7282338,1.7153718,,,,,,,,,,,,,, -8600,0.6827845,1.7282449,,,,,,,,,,,,,, -8695,,,0.4505587,0.160277457652917,0.7544828,0.2271257132374948,5348.0,0.49796993,0.1694595088660045,2472.0,7263.37749004364,8088.980140447617,7263.37749004364,824.9669606685638,0.2596597671508789,0.0 -8700,0.64768744,1.6563181,,,,,,,,,,,,,, -8800,0.8042441,1.684015,,,,,,,,,,,,,, -8900,0.705964,1.6779561,,,,,,,,,,,,,, -9000,0.63630027,1.6943012,,,,,,,,,,,,,, -9100,0.67203516,1.6531202,,,,,,,,,,,,,, -9200,0.73454696,1.608157,,,,,,,,,,,,,, -9300,0.6521489,1.5724243,,,,,,,,,,,,,, -9400,0.75725037,1.6194966,,,,,,,,,,,,,, -9500,0.87692213,1.6387963,,,,,,,,,,,,,, -9600,0.6446065,1.6274964,,,,,,,,,,,,,, -9700,0.6997307,1.6081195,,,,,,,,,,,,,, -9800,0.6333958,1.6745365,,,,,,,,,,,,,, -9900,0.6059873,1.6402415,,,,,,,,,,,,,, -10000,0.711065,1.6650852,,,,,,,,,,,,,, -10100,0.78282076,1.592821,,,,,,,,,,,,,, -10200,0.65987265,1.651448,,,,,,,,,,,,,, -10300,0.72094196,1.6413532,,,,,,,,,,,,,, -10387,,,0.4173494,0.1483587710921787,0.68799996,0.2082894851173523,5348.0,0.4417418,0.1507931671846119,2472.0,8703.979434251785,9666.078685045242,8703.979434251785,961.338276863098,0.3113949298858642,0.0 -10400,0.69393575,1.5502589,,,,,,,,,,,,,, -10500,0.6216661,1.5944418,,,,,,,,,,,,,, -10600,0.84949315,1.5857592,,,,,,,,,,,,,, -10700,0.6385295,1.5828047,,,,,,,,,,,,,, -10800,0.7981868,1.5609479,,,,,,,,,,,,,, -10900,0.73938906,1.6343344,,,,,,,,,,,,,, -11000,0.72190374,1.6202759,,,,,,,,,,,,,, -11100,0.71846455,1.5634265,,,,,,,,,,,,,, -11200,0.64037234,1.5332825,,,,,,,,,,,,,, -11300,0.6698802,1.5615411,,,,,,,,,,,,,, -11400,0.63105685,1.5015067,,,,,,,,,,,,,, -11500,0.6301239,1.537061,,,,,,,,,,,,,, -11600,0.64112246,1.5521117,,,,,,,,,,,,,, -11700,0.6445789,1.5955187,,,,,,,,,,,,,, -11800,0.60725594,1.5114558,,,,,,,,,,,,,, -11900,0.7588467,1.5981224,,,,,,,,,,,,,, -12000,0.69015956,1.4694948,,,,,,,,,,,,,, -12100,0.65987605,1.5753893,,,,,,,,,,,,,, -12101,,,0.35887453,0.1294535553995703,0.641523,0.1964142618535003,5348.0,0.40709147,0.1400483415595231,2472.0,10144.905450105667,11243.875876426697,10144.905450105667,1098.0837564468384,0.3614089488983154,0.0 -12200,0.68328613,1.5214918,,,,,,,,,,,,,, -12300,0.6515946,1.5187136,,,,,,,,,,,,,, -12400,0.6560349,1.5320919,,,,,,,,,,,,,, -12500,0.62453103,1.4885573,,,,,,,,,,,,,, -12600,0.65196025,1.5200074,,,,,,,,,,,,,, -12700,0.6790752,1.5291212,,,,,,,,,,,,,, -12800,0.6397795,1.5274379,,,,,,,,,,,,,, -12900,0.7114218,1.5331764,,,,,,,,,,,,,, -13000,0.74198467,1.5542475,,,,,,,,,,,,,, -13100,0.74249756,1.4529657,,,,,,,,,,,,,, -13200,0.67975885,1.4721359,,,,,,,,,,,,,, -13300,0.7158635,1.552059,,,,,,,,,,,,,, -13400,0.650729,1.479129,,,,,,,,,,,,,, -13500,0.5620922,1.5343003,,,,,,,,,,,,,, -13600,0.9501356,1.4499832,,,,,,,,,,,,,, -13700,0.5635292,1.4920675,,,,,,,,,,,,,, -13800,0.5744808,1.4651064,,,,,,,,,,,,,, -13827,,,0.30647576,0.1146051639408201,0.61191267,0.1859196539772343,5348.0,0.38234127,0.130075355960433,2472.0,11585.165503025057,12821.49998164177,11585.165503025057,1235.3221170902252,0.4123513698577881,0.0 -13900,0.6559535,1.5090206,,,,,,,,,,,,,, -14000,0.7628954,1.4595101,,,,,,,,,,,,,, -14100,0.6413032,1.4881204,,,,,,,,,,,,,, -14200,0.6567458,1.5358716,,,,,,,,,,,,,, -14300,0.92729443,1.4335325,,,,,,,,,,,,,, -14400,0.6147966,1.4286143,,,,,,,,,,,,,, -14500,0.6773686,1.4477313,,,,,,,,,,,,,, -14600,0.67846245,1.4201149,,,,,,,,,,,,,, -14700,0.68167377,1.4687366,,,,,,,,,,,,,, -14800,0.8660353,1.4478533,,,,,,,,,,,,,, -14900,0.62025887,1.4340863,,,,,,,,,,,,,, -15000,0.66247,1.4967519,,,,,,,,,,,,,, -15100,0.60751045,1.4615797,,,,,,,,,,,,,, -15200,0.6630311,1.5192316,,,,,,,,,,,,,, -15300,0.675404,1.4935741,,,,,,,,,,,,,, -15400,0.68552065,1.4000131,,,,,,,,,,,,,, -15500,0.6636734,1.4496853,,,,,,,,,,,,,, -15527,,,0.2841638,0.1044799940415706,0.58510756,0.1773559767129768,5348.0,0.35999078,0.122214774643024,2472.0,13025.10080051422,14397.421264886856,13025.10080051422,1371.180969953537,0.4650459289550781,0.0 -15600,0.73554134,1.5115589,,,,,,,,,,,,,, -15700,0.6579476,1.444411,,,,,,,,,,,,,, -15800,0.7455279,1.4667352,,,,,,,,,,,,,, -15900,0.67034096,1.4634893,,,,,,,,,,,,,, -16000,0.63733023,1.4149854,,,,,,,,,,,,,, -16100,0.6187164,1.4924158,,,,,,,,,,,,,, -16200,0.72333634,1.4156334,,,,,,,,,,,,,, -16300,0.6054996,1.4215707,,,,,,,,,,,,,, -16400,0.7310855,1.4677651,,,,,,,,,,,,,, -16500,0.7030027,1.4439762,,,,,,,,,,,,,, -16600,0.555935,1.4189066,,,,,,,,,,,,,, -16700,0.67022514,1.3849306,,,,,,,,,,,,,, -16800,1.1302574,1.4298592,,,,,,,,,,,,,, -16900,0.7321451,1.44813,,,,,,,,,,,,,, -17000,0.96088654,1.4698617,,,,,,,,,,,,,, -17100,0.6000177,1.4370447,,,,,,,,,,,,,, -17200,0.7038398,1.4562172,,,,,,,,,,,,,, -17244,,,0.27435565,0.1028117575764076,0.569697,0.172161773366674,5348.0,0.3494229,0.1185993134686084,2472.0,14466.018681049349,15974.59987092018,14466.018681049349,1507.310597896576,0.5201177597045898,0.0 -17300,0.66315126,1.4300021,,,,,,,,,,,,,, -17400,0.65051156,1.4339664,,,,,,,,,,,,,, -17500,0.62476933,1.4459366,,,,,,,,,,,,,, -17600,0.6022956,1.4194677,,,,,,,,,,,,,, -17700,0.6147548,1.4697963,,,,,,,,,,,,,, -17800,0.70684236,1.5104058,,,,,,,,,,,,,, -17900,0.6380869,1.5357603,,,,,,,,,,,,,, -18000,0.78071785,1.3648059,,,,,,,,,,,,,, -18100,0.67835116,1.4194534,,,,,,,,,,,,,, -18200,0.7664039,1.4697574,,,,,,,,,,,,,, -18300,0.5983496,1.4086372,,,,,,,,,,,,,, -18400,0.83691347,1.439029,,,,,,,,,,,,,, -18500,0.7651355,1.4339241,,,,,,,,,,,,,, -18600,0.82657075,1.4487562,,,,,,,,,,,,,, -18700,0.60261524,1.3268108,,,,,,,,,,,,,, -18800,0.6365854,1.3955752,,,,,,,,,,,,,, -18900,0.62659746,1.3812962,,,,,,,,,,,,,, -18961,,,0.28704607,0.1045823527856368,0.55894655,0.170549446305647,5348.0,0.34298202,0.1168525176202953,2472.0,15906.14902305603,17551.447067975998,15906.14902305603,1643.9017758369446,0.5702481269836426,0.0 -19000,0.7587826,1.3816824,,,,,,,,,,,,,, -19100,0.6992227,1.4402902,,,,,,,,,,,,,, -19200,0.7441305,1.3830134,,,,,,,,,,,,,, -19300,0.6517629,1.4168357,,,,,,,,,,,,,, -19400,0.61884314,1.3863237,,,,,,,,,,,,,, -19500,0.6043355,1.397794,,,,,,,,,,,,,, -19600,0.79363936,1.3984053,,,,,,,,,,,,,, -19700,0.649858,1.3554873,,,,,,,,,,,,,, -19800,0.8777503,1.4205059,,,,,,,,,,,,,, -19900,0.69406503,1.4363245,,,,,,,,,,,,,, -20000,0.79215056,1.451856,,,,,,,,,,,,,, -20100,0.82793117,1.418374,,,,,,,,,,,,,, -20200,0.762567,1.4813638,,,,,,,,,,,,,, -20300,0.8406809,1.4753274,,,,,,,,,,,,,, -20400,0.71774834,1.4251635,,,,,,,,,,,,,, -20500,0.7354948,1.4218675,,,,,,,,,,,,,, -20600,0.9340161,1.3452909,,,,,,,,,,,,,, -20642,,,0.2800363,0.1034584641359115,0.54050195,0.1640229008370584,5348.0,0.32796475,0.1121605427254077,2472.0,17346.25297522545,19130.841157197952,17346.25297522545,1783.0626783370972,0.6249892711639404,0.0 -20700,0.62366325,1.3836172,,,,,,,,,,,,,, -20800,0.5796571,1.3731681,,,,,,,,,,,,,, -20900,0.67261344,1.433846,,,,,,,,,,,,,, -21000,0.5653207,1.4397328,,,,,,,,,,,,,, -21100,0.6546094,1.4287719,,,,,,,,,,,,,, -21200,0.6161981,1.4105135,,,,,,,,,,,,,, -21300,0.69788665,1.3836563,,,,,,,,,,,,,, -21400,0.6986268,1.375335,,,,,,,,,,,,,, -21500,0.80547386,1.3836566,,,,,,,,,,,,,, -21600,0.82866836,1.4096817,,,,,,,,,,,,,, -21700,0.74424744,1.3495446,,,,,,,,,,,,,, -21800,0.6536734,1.4206704,,,,,,,,,,,,,, -21900,0.798403,1.449847,,,,,,,,,,,,,, -22000,0.79706,1.3613766,,,,,,,,,,,,,, -22100,0.6982113,1.3914559,,,,,,,,,,,,,, -22200,0.6022641,1.3518814,,,,,,,,,,,,,, -22300,0.6710346,1.3952956,,,,,,,,,,,,,, -22354,,,0.2611929,0.096928696367939,0.5234255,0.1602672407967019,5348.0,0.3156809,0.109093494201044,2472.0,18786.501445770264,20706.88010787964,18786.501445770264,1918.723219871521,0.6793062686920166,0.0 -22400,0.7470589,1.3298886,,,,,,,,,,,,,, -22500,0.6726725,1.3465923,,,,,,,,,,,,,, -22600,0.8025694,1.3732011,,,,,,,,,,,,,, -22700,0.7435571,1.3579339,,,,,,,,,,,,,, -22800,0.60986143,1.3893054,,,,,,,,,,,,,, -22900,0.628198,1.3448911,,,,,,,,,,,,,, -23000,0.66550756,1.353439,,,,,,,,,,,,,, -23100,0.8400612,1.4058238,,,,,,,,,,,,,, -23200,0.6296953,1.3410085,,,,,,,,,,,,,, -23300,1.0128386,1.3705705,,,,,,,,,,,,,, -23400,0.9261251,1.327226,,,,,,,,,,,,,, -23500,0.74159247,1.3846698,,,,,,,,,,,,,, -23600,0.7061557,1.4040153,,,,,,,,,,,,,, -23700,0.60507053,1.3411444,,,,,,,,,,,,,, -23800,0.7285927,1.3625139,,,,,,,,,,,,,, -23900,0.63110894,1.3166225,,,,,,,,,,,,,, -24000,0.81612796,1.3853756,,,,,,,,,,,,,, -24076,,,0.24182494,0.0891273594865081,0.5162125,0.1568881122256871,5348.0,0.3084755,0.104239026669104,2472.0,20226.675240516663,22282.030920743942,20226.675240516663,2053.571811437607,0.733173131942749,0.0 -24100,0.6477688,1.3757505,,,,,,,,,,,,,, -24200,0.67578876,1.392052,,,,,,,,,,,,,, -24300,0.7500927,1.2759993,,,,,,,,,,,,,, -24400,0.64367014,1.4004053,,,,,,,,,,,,,, -24500,0.7478749,1.3657207,,,,,,,,,,,,,, -24600,0.6990137,1.3357265,,,,,,,,,,,,,, -24700,0.62802,1.2940209,,,,,,,,,,,,,, -24800,0.5671781,1.2699982,,,,,,,,,,,,,, -24900,0.7051537,1.3856363,,,,,,,,,,,,,, -25000,0.6558521,1.3908867,,,,,,,,,,,,,, -25100,0.5672531,1.3641068,,,,,,,,,,,,,, -25200,0.80245197,1.3216106,,,,,,,,,,,,,, -25300,0.5849475,1.2344513,,,,,,,,,,,,,, -25400,0.6122191,1.3664142,,,,,,,,,,,,,, -25500,0.6530482,1.3280644,,,,,,,,,,,,,, -25600,0.70393515,1.3203281,,,,,,,,,,,,,, -25700,0.72405946,1.3473594,,,,,,,,,,,,,, -25751,,,0.22988309,0.0859229505203578,0.50804406,0.1546192687565772,5348.0,0.30117932,0.102309426604107,2472.0,21666.55222582817,23857.990137577057,21666.55222582817,2189.5265684127808,0.7874441146850586,0.0 -25800,0.6754437,1.3192434,,,,,,,,,,,,,, -25900,0.7269005,1.3125058,,,,,,,,,,,,,, -26000,0.687118,1.3295193,,,,,,,,,,,,,, -26100,0.58105344,1.270664,,,,,,,,,,,,,, -26200,0.8217728,1.2905372,,,,,,,,,,,,,, -26300,0.6056215,1.3344231,,,,,,,,,,,,,, -26400,0.7458436,1.3354728,,,,,,,,,,,,,, -26500,0.77035517,1.3440167,,,,,,,,,,,,,, -26600,0.6526656,1.3226383,,,,,,,,,,,,,, -26700,0.52433705,1.2897996,,,,,,,,,,,,,, -26800,0.6312906,1.3208681,,,,,,,,,,,,,, -26900,0.79161686,1.3278357,,,,,,,,,,,,,, -27000,0.8646977,1.29617,,,,,,,,,,,,,, -27100,0.6077049,1.2937477,,,,,,,,,,,,,, -27200,0.609737,1.2693036,,,,,,,,,,,,,, -27300,0.6639093,1.3371339,,,,,,,,,,,,,, -27400,0.73334676,1.2987107,,,,,,,,,,,,,, -27470,,,0.22008015,0.081281770668194,0.49360743,0.1496664317367755,5348.0,0.29376552,0.1009891739280563,2472.0,23107.052596330643,25434.80774664879,23107.052596330643,2325.708679676056,0.8463056087493896,0.0 -27500,0.6887239,1.336362,,,,,,,,,,,,,, -27600,1.0435338,1.3268017,,,,,,,,,,,,,, -27700,0.66617286,1.3212525,,,,,,,,,,,,,, -27800,0.6990179,1.3003204,,,,,,,,,,,,,, -27900,0.75705355,1.2917191,,,,,,,,,,,,,, -28000,0.56207263,1.3040258,,,,,,,,,,,,,, -28100,0.71390504,1.3796937,,,,,,,,,,,,,, -28200,0.7236176,1.2707784,,,,,,,,,,,,,, -28300,0.74121886,1.3006315,,,,,,,,,,,,,, -28400,0.5628987,1.281238,,,,,,,,,,,,,, -28500,0.7187665,1.3288741,,,,,,,,,,,,,, -28600,0.75485724,1.3602,,,,,,,,,,,,,, -28700,0.91615456,1.3195825,,,,,,,,,,,,,, -28800,0.8545947,1.3176805,,,,,,,,,,,,,, -28900,0.7619603,1.2544037,,,,,,,,,,,,,, -29000,0.6421611,1.3079617,,,,,,,,,,,,,, -29100,0.87256235,1.2838128,,,,,,,,,,,,,, -29189,,,0.2380161,0.0883168660079736,0.48860037,0.1491933537368334,5348.0,0.2872465,0.0984908496333759,2472.0,24547.576735258102,27014.36773943901,24547.576735258102,2464.611401796341,0.904517650604248,0.0 -29200,0.57194614,1.2952436,,,,,,,,,,,,,, -29300,0.6526393,1.2976573,,,,,,,,,,,,,, -29400,0.77434,1.251581,,,,,,,,,,,,,, -29500,0.77251256,1.3187861,,,,,,,,,,,,,, -29600,0.9486395,1.3022366,,,,,,,,,,,,,, -29700,0.70372224,1.3171319,,,,,,,,,,,,,, -29800,0.64291364,1.2971644,,,,,,,,,,,,,, -29900,0.7346223,1.2846422,,,,,,,,,,,,,, -30000,0.69094527,1.2722075,,,,,,,,,,,,,, -30100,0.64590985,1.2580242,,,,,,,,,,,,,, -30200,0.56471187,1.295395,,,,,,,,,,,,,, -30300,0.7913751,1.2674779,,,,,,,,,,,,,, -30400,0.800326,1.2895658,,,,,,,,,,,,,, -30500,0.6348039,1.2857236,,,,,,,,,,,,,, -30600,0.7468616,1.2985436,,,,,,,,,,,,,, -30700,0.6073628,1.2792289,,,,,,,,,,,,,, -30800,0.6752466,1.2857479,,,,,,,,,,,,,, -30881,,,0.2175279,0.0792353729537328,0.47386247,0.1434874537783484,5348.0,0.28259403,0.0946316495033818,2472.0,25988.7187268734,28590.486193180084,25988.7187268734,2599.4616689682007,0.9569101333618164,0.0 -30900,0.78674155,1.2424377,,,,,,,,,,,,,, -31000,0.67707616,1.322627,,,,,,,,,,,,,, -31100,0.86740893,1.2759007,,,,,,,,,,,,,, -31200,0.6597069,1.3107692,,,,,,,,,,,,,, -31300,0.58037037,1.2623191,,,,,,,,,,,,,, -31400,0.64194286,1.2760346,,,,,,,,,,,,,, -31500,0.74933046,1.2144266,,,,,,,,,,,,,, -31600,0.81445086,1.2987128,,,,,,,,,,,,,, -31700,0.75007355,1.2610822,,,,,,,,,,,,,, -31800,0.63572276,1.2354591,,,,,,,,,,,,,, -31900,1.0991787,1.243888,,,,,,,,,,,,,, -32000,0.66666627,1.2729658,,,,,,,,,,,,,, -32100,0.6164185,1.2433016,,,,,,,,,,,,,, -32200,0.7146986,1.262639,,,,,,,,,,,,,, -32300,0.6347242,1.2639403,,,,,,,,,,,,,, -32400,0.7120727,1.2839364,,,,,,,,,,,,,, -32500,0.63275343,1.2574735,,,,,,,,,,,,,, -32563,,,0.22219987,0.0804956104978193,0.46052456,0.1377525898606833,5348.0,0.26877353,0.091117746227124,2472.0,27428.926631212234,30169.796103715897,27428.926631212234,2738.427877187729,1.0178043842315674,0.0 -32600,0.8546767,1.2419361,,,,,,,,,,,,,, -32700,0.644226,1.3235408,,,,,,,,,,,,,, -32800,0.5931894,1.256291,,,,,,,,,,,,,, -32900,0.6766503,1.3243471,,,,,,,,,,,,,, -33000,0.74134964,1.2324303,,,,,,,,,,,,,, -33100,0.6469912,1.2300451,,,,,,,,,,,,,, -33200,0.62175035,1.2335056,,,,,,,,,,,,,, -33300,0.87087727,1.2671919,,,,,,,,,,,,,, -33400,0.68927217,1.2614985,,,,,,,,,,,,,, -33500,0.63317513,1.2481512,,,,,,,,,,,,,, -33600,0.67411995,1.2597804,,,,,,,,,,,,,, -33700,0.6812779,1.2407352,,,,,,,,,,,,,, -33800,0.7125078,1.2446526,,,,,,,,,,,,,, -33900,0.6217163,1.2961664,,,,,,,,,,,,,, -34000,0.5972016,1.204169,,,,,,,,,,,,,, -34100,0.7553127,1.286994,,,,,,,,,,,,,, -34200,0.77021575,1.2986081,,,,,,,,,,,,,, -34277,,,0.21909791,0.0772190074277737,0.4654405,0.1390752773299091,5348.0,0.27387086,0.091930209412386,2472.0,28868.84626793861,31746.674035787582,28868.84626793861,2875.2527759075165,1.0767803192138672,0.0 -34300,0.65593916,1.2211198,,,,,,,,,,,,,, -34400,0.62876403,1.1932827,,,,,,,,,,,,,, -34500,0.814177,1.2851335,,,,,,,,,,,,,, -34600,0.68779445,1.240685,,,,,,,,,,,,,, -34700,0.6582378,1.2807341,,,,,,,,,,,,,, -34800,0.6376793,1.3182257,,,,,,,,,,,,,, -34900,0.6949795,1.2683675,,,,,,,,,,,,,, -35000,0.7541608,1.2274737,,,,,,,,,,,,,, -35100,0.69053984,1.1851968,,,,,,,,,,,,,, -35200,0.690876,1.2804512,,,,,,,,,,,,,, -35300,0.7273471,1.2522532,,,,,,,,,,,,,, -35400,0.6793444,1.2542722,,,,,,,,,,,,,, -35500,0.69264615,1.2733774,,,,,,,,,,,,,, -35600,0.7075763,1.2292138,,,,,,,,,,,,,, -35700,0.6147693,1.291663,,,,,,,,,,,,,, -35800,0.7790469,1.305792,,,,,,,,,,,,,, -35900,0.79504323,1.2494463,,,,,,,,,,,,,, -35964,,,0.17266439,0.0654295173375152,0.45770597,0.1368643617791594,5348.0,0.26163143,0.0876647776897609,2472.0,30309.341188192368,33325.48345398903,30309.341188192368,3013.4304831027985,1.1376783847808838,0.0 -36000,0.6539139,1.2059726,,,,,,,,,,,,,, -36100,0.8531903,1.2498834,,,,,,,,,,,,,, -36200,0.6149767,1.1993484,,,,,,,,,,,,,, -36300,0.79136395,1.2327197,,,,,,,,,,,,,, -36400,0.7518683,1.2160684,,,,,,,,,,,,,, -36500,0.72013533,1.234551,,,,,,,,,,,,,, -36600,0.6562839,1.2567906,,,,,,,,,,,,,, -36700,0.71790373,1.214073,,,,,,,,,,,,,, -36800,0.6551865,1.2501032,,,,,,,,,,,,,, -36900,0.72195894,1.2046266,,,,,,,,,,,,,, -37000,0.6644481,1.2352813,,,,,,,,,,,,,, -37100,0.7316082,1.2478317,,,,,,,,,,,,,, -37200,0.82848465,1.281985,,,,,,,,,,,,,, -37300,0.8177029,1.2707933,,,,,,,,,,,,,, -37400,0.68667597,1.1948159,,,,,,,,,,,,,, -37500,0.83390033,1.1863567,,,,,,,,,,,,,, -37600,0.75335133,1.2361773,,,,,,,,,,,,,, -37649,,,0.19681297,0.0729393254630314,0.4440631,0.134247950799888,5348.0,0.26120028,0.0886600450917067,2472.0,31749.26070761681,34901.451577186584,31749.26070761681,3149.3453755378723,1.1959314346313477,0.0 -37700,0.7031397,1.2500184,,,,,,,,,,,,,, -37800,0.61461055,1.1707859,,,,,,,,,,,,,, -37900,0.752389,1.213176,,,,,,,,,,,,,, -38000,1.0638314,1.3118026,,,,,,,,,,,,,, -38100,0.69297355,1.1799399,,,,,,,,,,,,,, -38200,0.7386371,1.2072239,,,,,,,,,,,,,, -38300,0.65518147,1.2178442,,,,,,,,,,,,,, -38400,0.6604994,1.2042015,,,,,,,,,,,,,, -38500,0.71914613,1.2383995,,,,,,,,,,,,,, -38600,0.7124717,1.2052653,,,,,,,,,,,,,, -38700,0.5953294,1.1714454,,,,,,,,,,,,,, -38800,0.8855479,1.2014538,,,,,,,,,,,,,, -38900,0.75971663,1.2798623,,,,,,,,,,,,,, -39000,0.7120806,1.1981605,,,,,,,,,,,,,, -39100,0.606336,1.2221307,,,,,,,,,,,,,, -39200,0.7251811,1.177333,,,,,,,,,,,,,, -39300,0.70306003,1.1715922,,,,,,,,,,,,,, -39356,,,0.24031529,0.0871487228999311,0.4375541,0.1294013149637467,5348.0,0.2517394,0.0832165417504519,2472.0,33189.13189792633,36475.276661872864,33189.13189792633,3283.162401914597,1.2573490142822266,0.0 -39400,0.90132993,1.222568,,,,,,,,,,,,,, -39500,0.7306863,1.2540857,,,,,,,,,,,,,, -39600,0.76023394,1.2390393,,,,,,,,,,,,,, -39700,0.69895214,1.1619873,,,,,,,,,,,,,, -39800,0.75819254,1.212798,,,,,,,,,,,,,, -39900,0.7057288,1.2165184,,,,,,,,,,,,,, -40000,0.9407239,1.2166992,,,,,,,,,,,,,, -40100,0.6770496,1.1644087,,,,,,,,,,,,,, -40200,0.6446714,1.1833783,,,,,,,,,,,,,, -40300,0.7077271,1.1698993,,,,,,,,,,,,,, -40400,0.71599317,1.1358299,,,,,,,,,,,,,, -40500,0.7109043,1.1553804,,,,,,,,,,,,,, -40600,0.70982605,1.1324085,,,,,,,,,,,,,, -40700,0.6643988,1.24641,,,,,,,,,,,,,, -40800,0.75362146,1.2347317,,,,,,,,,,,,,, -40900,0.64797366,1.2169065,,,,,,,,,,,,,, -41000,0.69567955,1.2109348,,,,,,,,,,,,,, -41039,,,0.24005286,0.0886769514314105,0.4300553,0.1283875763924423,5348.0,0.24431132,0.0838462007190299,2472.0,34629.38103199005,38051.03439188004,34629.38103199005,3418.5341413021088,1.3193349838256836,0.0 -41100,0.67342997,1.1922091,,,,,,,,,,,,,, -41200,0.77845204,1.1413348,,,,,,,,,,,,,, -41300,0.7158208,1.1437448,,,,,,,,,,,,,, -41400,0.7463733,1.1467724,,,,,,,,,,,,,, -41500,0.8305434,1.1885482,,,,,,,,,,,,,, -41600,0.6928373,1.1978648,,,,,,,,,,,,,, -41700,0.72422945,1.1840078,,,,,,,,,,,,,, -41800,0.7349388,1.1668426,,,,,,,,,,,,,, -41900,0.99358577,1.2266519,,,,,,,,,,,,,, -42000,0.76450527,1.1686654,,,,,,,,,,,,,, -42100,0.79226995,1.1899974,,,,,,,,,,,,,, -42200,0.74695796,1.2443663,,,,,,,,,,,,,, -42300,0.7552991,1.1768104,,,,,,,,,,,,,, -42400,0.67067933,1.1758245,,,,,,,,,,,,,, -42500,0.73988295,1.1759131,,,,,,,,,,,,,, -42600,1.0176932,1.1320069,,,,,,,,,,,,,, -42700,0.71265334,1.1305882,,,,,,,,,,,,,, -42750,,,0.28034556,0.1032108615919181,0.42529032,0.1256649642295104,5348.0,0.24818589,0.0830540491133995,2472.0,36069.76509022713,39624.3362300396,36069.76509022713,3551.3137764930725,1.3813042640686035,0.0 -42800,0.86331546,1.1226909,,,,,,,,,,,,,, -42900,0.8365447,1.159668,,,,,,,,,,,,,, -43000,0.62446785,1.2205324,,,,,,,,,,,,,, -43100,0.8447979,1.1470011,,,,,,,,,,,,,, -43200,0.659398,1.1921204,,,,,,,,,,,,,, -43300,0.72082293,1.166043,,,,,,,,,,,,,, -43400,0.713903,1.1363018,,,,,,,,,,,,,, -43500,0.6952341,1.2025033,,,,,,,,,,,,,, -43600,0.9184263,1.1060138,,,,,,,,,,,,,, -43700,0.6692974,1.1626333,,,,,,,,,,,,,, -43800,0.7137855,1.1162854,,,,,,,,,,,,,, -43900,0.80888104,1.1180872,,,,,,,,,,,,,, -44000,0.7358304,1.1812036,,,,,,,,,,,,,, -44100,0.7448329,1.1426178,,,,,,,,,,,,,, -44200,0.7924396,1.1689894,,,,,,,,,,,,,, -44300,0.66086704,1.1228094,,,,,,,,,,,,,, -44400,0.87662,1.1994653,,,,,,,,,,,,,, -44453,,,0.24660096,0.0885492407962366,0.41894877,0.124660880311266,5348.0,0.23739682,0.080738529035403,2472.0,37509.9409840107,41198.581345796585,37509.9409840107,3685.243673801422,1.445460557937622,0.0 -44500,0.79168856,1.1790582,,,,,,,,,,,,,, -44600,0.8037691,1.2148416,,,,,,,,,,,,,, -44700,0.8533854,1.1194285,,,,,,,,,,,,,, -44800,0.80185,1.1437991,,,,,,,,,,,,,, -44900,0.87198424,1.1213038,,,,,,,,,,,,,, -45000,0.64082307,1.1291485,,,,,,,,,,,,,, -45100,0.80487096,1.2252519,,,,,,,,,,,,,, -45200,0.95359266,1.1321455,,,,,,,,,,,,,, -45300,0.69775677,1.1371171,,,,,,,,,,,,,, -45400,0.7689245,1.1642203,,,,,,,,,,,,,, -45500,0.79798204,1.1350112,,,,,,,,,,,,,, -45600,0.91324186,1.1343853,,,,,,,,,,,,,, -45700,0.69538623,1.0936298,,,,,,,,,,,,,, -45800,0.8390462,1.1644479,,,,,,,,,,,,,, -45900,0.6528085,1.1568842,,,,,,,,,,,,,, -46000,0.6813791,1.1690294,,,,,,,,,,,,,, -46100,0.7225102,1.1660936,,,,,,,,,,,,,, -46142,,,0.22179091,0.0823600146972041,0.40570727,0.1204804155362677,5348.0,0.23045741,0.0773871183961976,2472.0,38950.01946210861,42772.61559915543,38950.01946210861,3819.056458711624,1.5136487483978271,0.0 -46200,0.718464,1.1457474,,,,,,,,,,,,,, -46300,0.8455841,1.1579596,,,,,,,,,,,,,, -46400,0.7248313,1.1021103,,,,,,,,,,,,,, -46500,0.82824767,1.1701171,,,,,,,,,,,,,, -46600,0.7207802,1.1365491,,,,,,,,,,,,,, -46700,0.70034486,1.1184516,,,,,,,,,,,,,, -46800,0.7232025,1.1353682,,,,,,,,,,,,,, -46900,0.7582893,1.1122981,,,,,,,,,,,,,, -47000,0.7912437,1.1483929,,,,,,,,,,,,,, -47100,0.7833453,1.1389925,,,,,,,,,,,,,, -47200,0.7807218,1.1765183,,,,,,,,,,,,,, -47300,0.7975825,1.1109152,,,,,,,,,,,,,, -47400,1.1488451,1.0664793,,,,,,,,,,,,,, -47500,0.7248832,1.1606182,,,,,,,,,,,,,, -47600,0.8297002,1.1492778,,,,,,,,,,,,,, -47700,0.9769869,1.1051509,,,,,,,,,,,,,, -47800,0.67585087,1.1369678,,,,,,,,,,,,,, -47851,,,0.18740244,0.0704520984458776,0.4032107,0.1185205209650791,5348.0,0.22821124,0.0758231267645684,2472.0,40389.98576760292,44348.57243299484,40389.98576760292,3954.9089529514313,1.5759549140930176,0.0 -47900,0.6428344,1.0458593,,,,,,,,,,,,,, -48000,1.0803133,1.1224073,,,,,,,,,,,,,, -48100,0.8673199,1.1860496,,,,,,,,,,,,,, -48200,0.77653646,1.1594312,,,,,,,,,,,,,, -48300,0.78725255,1.1819597,,,,,,,,,,,,,, -48400,0.9807081,1.1625035,,,,,,,,,,,,,, -48500,0.882973,1.1428615,,,,,,,,,,,,,, -48600,0.68160915,1.1152817,,,,,,,,,,,,,, -48700,0.74651253,1.1091698,,,,,,,,,,,,,, -48800,1.0123614,1.1303676,,,,,,,,,,,,,, -48900,0.7717175,1.1414831,,,,,,,,,,,,,, -49000,0.77786916,1.1147933,,,,,,,,,,,,,, -49100,0.75066787,1.1344794,,,,,,,,,,,,,, -49200,0.7057997,1.141174,,,,,,,,,,,,,, -49300,0.9142586,1.104987,,,,,,,,,,,,,, -49400,0.9884544,1.1721637,,,,,,,,,,,,,, -49500,1.0689516,1.0815136,,,,,,,,,,,,,, -49534,,,0.20381515,0.0767133554557452,0.3923336,0.1146876237002423,5348.0,0.21855989,0.0722889119086791,2472.0,41830.2703332901,45924.74306154251,41830.2703332901,4090.652411699295,1.6439871788024902,0.0 -49600,0.76888245,1.0952221,,,,,,,,,,,,,, -49700,0.7706219,1.1417465,,,,,,,,,,,,,, -49800,1.1337559,1.1161625,,,,,,,,,,,,,, -49900,0.7472557,1.1615933,,,,,,,,,,,,,, -50000,0.91187,1.0616984,,,,,,,,,,,,,, -50100,0.9531522,1.0911999,,,,,,,,,,,,,, -50200,0.7224176,1.0777048,,,,,,,,,,,,,, -50300,0.7431973,1.1289257,,,,,,,,,,,,,, -50400,0.7989531,1.1074482,,,,,,,,,,,,,, -50500,0.76252127,1.0747013,,,,,,,,,,,,,, -50600,1.0430225,1.127576,,,,,,,,,,,,,, -50700,0.8283536,1.076345,,,,,,,,,,,,,, -50800,0.7285509,1.0915629,,,,,,,,,,,,,, -50900,1.0754353,1.0923473,,,,,,,,,,,,,, -51000,0.77981204,1.1027385,,,,,,,,,,,,,, -51100,0.74310005,1.0847068,,,,,,,,,,,,,, -51200,0.8881556,1.1260073,,,,,,,,,,,,,, -51230,,,0.17861539,0.0677503865315237,0.3843447,0.1125925639862131,5348.0,0.21695636,0.0702374423658928,2472.0,43270.42231464386,47500.26622223854,43270.42231464386,4225.891577243805,1.700354814529419,0.0 -51300,0.7191764,1.1492397,,,,,,,,,,,,,, -51400,0.76079607,1.1579746,,,,,,,,,,,,,, -51500,0.909247,1.0558759,,,,,,,,,,,,,, -51600,1.1588367,1.0714793,,,,,,,,,,,,,, -51700,1.188841,1.0516596,,,,,,,,,,,,,, -51800,1.2503406,1.0976282,,,,,,,,,,,,,, -51900,0.9223369,1.0472381,,,,,,,,,,,,,, -52000,0.845656,1.1466397,,,,,,,,,,,,,, -52100,0.8293502,1.0250248,,,,,,,,,,,,,, -52200,0.8261549,1.0796436,,,,,,,,,,,,,, -52300,1.1047745,1.112994,,,,,,,,,,,,,, -52400,1.0822947,1.1000372,,,,,,,,,,,,,, -52500,0.7724051,1.0585049,,,,,,,,,,,,,, -52600,0.89980054,1.0619038,,,,,,,,,,,,,, -52700,0.7909698,1.1062909,,,,,,,,,,,,,, -52800,0.8317158,1.0703918,,,,,,,,,,,,,, -52900,0.97989994,1.0540675,,,,,,,,,,,,,, -52957,,,0.17695369,0.0670163486260428,0.372745,0.1091941261090782,5348.0,0.21086185,0.0709686592326285,2472.0,44710.28479671478,49075.17360305786,44710.28479671478,4360.704236030579,1.857272148132324,0.0 -53000,0.81596965,1.076597,,,,,,,,,,,,,, -53100,0.9540121,1.0756612,,,,,,,,,,,,,, -53200,1.4432548,1.1215119,,,,,,,,,,,,,, -53300,0.8002079,1.067008,,,,,,,,,,,,,, -53400,0.97397166,1.0846887,,,,,,,,,,,,,, -53500,0.78449464,1.0311509,,,,,,,,,,,,,, -53600,0.8903183,1.0391089,,,,,,,,,,,,,, -53700,1.1101714,1.0362831,,,,,,,,,,,,,, -53800,0.9102076,1.0966475,,,,,,,,,,,,,, -53900,0.8885847,1.0811079,,,,,,,,,,,,,, -54000,0.8369734,1.058962,,,,,,,,,,,,,, -54100,0.8206723,1.0427029,,,,,,,,,,,,,, -54200,0.83552057,1.0619683,,,,,,,,,,,,,, -54300,0.9241311,1.0553476,,,,,,,,,,,,,, -54400,0.8205738,1.0816746,,,,,,,,,,,,,, -54500,0.7947591,1.0435207,,,,,,,,,,,,,, -54600,0.79920983,1.00032,,,,,,,,,,,,,, -54652,,,0.16497089,0.0618523119392684,0.36614746,0.1077169641908918,5348.0,0.19986272,0.067393821217476,2472.0,46150.37541222572,50651.51419734955,46150.37541222572,4496.820524454117,1.9165270328521729,0.0 -54700,1.2829907,1.0517344,,,,,,,,,,,,,, -54800,0.9208271,1.0624799,,,,,,,,,,,,,, -54900,0.9145169,1.0684385,,,,,,,,,,,,,, -55000,0.86736465,1.0155278,,,,,,,,,,,,,, -55100,0.95742685,1.038059,,,,,,,,,,,,,, -55200,0.8972717,1.0655937,,,,,,,,,,,,,, -55300,0.9157743,1.0617697,,,,,,,,,,,,,, -55400,0.9646702,1.0254468,,,,,,,,,,,,,, -55500,0.7882925,1.0545193,,,,,,,,,,,,,, -55600,0.8137432,1.0090344,,,,,,,,,,,,,, -55700,0.8581885,1.0016174,,,,,,,,,,,,,, -55800,0.8917947,1.01624,,,,,,,,,,,,,, -55900,0.8282839,1.0338117,,,,,,,,,,,,,, -56000,0.90252256,1.0824081,,,,,,,,,,,,,, -56100,1.3267953,1.0421996,,,,,,,,,,,,,, -56200,1.0615904,1.0663282,,,,,,,,,,,,,, -56300,0.8557458,1.0673566,,,,,,,,,,,,,, -56368,,,0.16453665,0.0620743335279237,0.3558999,0.1043861088851772,5348.0,0.1941069,0.0644080190116385,2472.0,47590.53766059876,52226.748883485794,47590.53766059876,4631.75553393364,1.9792718887329104,0.0 -56400,0.79828805,0.9718461,,,,,,,,,,,,,, -56500,0.89282995,1.0233878,,,,,,,,,,,,,, -56600,0.92987883,1.0121347,,,,,,,,,,,,,, -56700,0.8123839,0.9740809,,,,,,,,,,,,,, -56800,0.8237442,1.0300121,,,,,,,,,,,,,, -56900,0.9523564,1.0063674,,,,,,,,,,,,,, -57000,1.0906367,1.0397441,,,,,,,,,,,,,, -57100,0.90550494,1.0056157,,,,,,,,,,,,,, -57200,0.88162744,1.0293674,,,,,,,,,,,,,, -57300,0.84844035,1.0553455,,,,,,,,,,,,,, -57400,0.9144842,1.0439744,,,,,,,,,,,,,, -57500,1.0140915,1.0006171,,,,,,,,,,,,,, -57600,0.8571702,0.98247266,,,,,,,,,,,,,, -57700,0.82739,0.9817743,,,,,,,,,,,,,, -57800,1.051692,0.9776865,,,,,,,,,,,,,, -57900,1.0379449,0.9985233,,,,,,,,,,,,,, -58000,0.76958025,0.9913785,,,,,,,,,,,,,, -58095,,,0.16163893,0.0600747122055348,0.34983188,0.1020110642324068,5348.0,0.19088641,0.0614425283854325,2472.0,49030.65796470642,53802.68481111527,49030.65796470642,4767.43133020401,2.0423855781555176,0.0 -58100,0.9799785,0.9929157,,,,,,,,,,,,,, -58200,0.91795766,1.0265533,,,,,,,,,,,,,, -58300,0.957235,1.0135585,,,,,,,,,,,,,, -58400,0.94324654,0.9751253,,,,,,,,,,,,,, -58500,0.8645518,1.0377479,,,,,,,,,,,,,, -58600,0.9343125,0.9940762,,,,,,,,,,,,,, -58700,0.88657916,0.99038005,,,,,,,,,,,,,, -58800,0.99335283,1.0852742,,,,,,,,,,,,,, -58900,0.8105494,0.9658885,,,,,,,,,,,,,, -59000,0.91414934,1.060564,,,,,,,,,,,,,, -59100,0.8996174,0.99451953,,,,,,,,,,,,,, -59200,0.8669581,0.9580242,,,,,,,,,,,,,, -59300,0.9170186,0.97657645,,,,,,,,,,,,,, -59400,1.0395843,1.0009612,,,,,,,,,,,,,, -59500,0.8300908,0.9726072,,,,,,,,,,,,,, -59600,1.007176,0.99728864,,,,,,,,,,,,,, -59700,1.0539073,1.012764,,,,,,,,,,,,,, -59789,,,0.13394181,0.0514773466047929,0.34509662,0.1003601185591395,5348.0,0.1848671,0.0623159263095891,2472.0,50470.96928143501,55376.52167439461,50470.96928143501,4900.822810411453,2.103106737136841,0.0 -59800,0.9130488,0.9585835,,,,,,,,,,,,,, -59900,1.0211264,1.000916,,,,,,,,,,,,,, -60000,0.84792805,0.96790624,,,,,,,,,,,,,, -60100,0.9907639,0.97567827,,,,,,,,,,,,,, -60200,1.1596066,0.9911229,,,,,,,,,,,,,, -60300,1.0632497,0.9798677,,,,,,,,,,,,,, -60400,0.9089629,0.97416395,,,,,,,,,,,,,, -60500,0.93844587,1.041446,,,,,,,,,,,,,, -60600,0.8840064,0.9843681,,,,,,,,,,,,,, -60700,1.3340778,0.9869427,,,,,,,,,,,,,, -60800,1.3106202,1.0007858,,,,,,,,,,,,,, -60900,1.0562254,0.97365594,,,,,,,,,,,,,, -61000,1.1000944,1.022925,,,,,,,,,,,,,, -61100,1.0102487,0.9521085,,,,,,,,,,,,,, -61200,1.0134711,0.9736146,,,,,,,,,,,,,, -61300,1.0573554,0.970828,,,,,,,,,,,,,, -61400,0.973637,1.017852,,,,,,,,,,,,,, -61491,,,0.13361306,0.0517890557254347,0.3403858,0.0993753439470152,5348.0,0.1836257,0.061462839965064,2472.0,51911.2644867897,56951.9517929554,51911.2644867897,5035.821505784988,2.164682149887085,0.0 -61500,0.92607456,0.9865331,,,,,,,,,,,,,, -61600,0.9215709,0.94702965,,,,,,,,,,,,,, -61700,0.9819923,0.9801139,,,,,,,,,,,,,, -61800,1.0506518,0.9902103,,,,,,,,,,,,,, -61900,1.0514848,0.9224027,,,,,,,,,,,,,, -62000,0.8736193,0.9705197,,,,,,,,,,,,,, -62100,1.1516807,0.9778088,,,,,,,,,,,,,, -62200,1.0226808,0.98378205,,,,,,,,,,,,,, -62300,1.0467607,0.97919357,,,,,,,,,,,,,, -62400,1.0013194,1.0108017,,,,,,,,,,,,,, -62500,1.0778912,1.0098083,,,,,,,,,,,,,, -62600,1.0925478,1.0162001,,,,,,,,,,,,,, -62700,1.0309672,0.9506355,,,,,,,,,,,,,, -62800,1.4060622,0.9868763,,,,,,,,,,,,,, -62900,1.3446006,0.9733947,,,,,,,,,,,,,, -63000,1.0111663,0.92959464,,,,,,,,,,,,,, -63100,1.0115895,0.93807507,,,,,,,,,,,,,, -63200,0.9590675,0.9324226,,,,,,,,,,,,,, -63210,,,0.11915162,0.04634981922064,0.33299387,0.0958610502331598,5348.0,0.17569458,0.057827067211017,2472.0,53351.83570098877,58525.53119134903,53351.83570098877,5168.687787055969,2.230812788009644,0.0 -63300,0.94302154,0.9282945,,,,,,,,,,,,,, -63400,1.172562,0.9896501,,,,,,,,,,,,,, -63500,1.0845125,0.91492045,,,,,,,,,,,,,, -63600,0.89437455,0.96103585,,,,,,,,,,,,,, -63700,1.0521973,0.9349739,,,,,,,,,,,,,, -63800,1.0387356,0.96096724,,,,,,,,,,,,,, -63900,0.9237136,0.92687553,,,,,,,,,,,,,, -64000,0.9394518,0.9081585,,,,,,,,,,,,,, -64100,0.9846114,0.92744434,,,,,,,,,,,,,, -64200,1.348291,0.9349809,,,,,,,,,,,,,, -64300,1.1258955,0.9124057,,,,,,,,,,,,,, -64400,1.0828226,0.910532,,,,,,,,,,,,,, -64500,1.3067534,0.93911994,,,,,,,,,,,,,, -64600,1.1019665,0.9044176,,,,,,,,,,,,,, -64700,1.4876845,0.9636584,,,,,,,,,,,,,, -64800,1.2876529,0.94352174,,,,,,,,,,,,,, -64881,,,0.115036234,0.0437265496450064,0.32092267,0.0931867113355281,5348.0,0.17374022,0.0569130461275973,2472.0,54792.682903051376,60099.91721081734,54792.682903051376,5302.091877937317,2.291220664978028,0.0 -64900,0.9691528,0.95013773,,,,,,,,,,,,,, -65000,1.0874132,0.9419263,,,,,,,,,,,,,, -65100,1.2190856,1.0076914,,,,,,,,,,,,,, -65200,1.1344081,0.9132685,,,,,,,,,,,,,, -65300,0.97258055,0.9069473,,,,,,,,,,,,,, -65400,0.9177594,0.93071,,,,,,,,,,,,,, -65500,1.0957309,0.92023224,,,,,,,,,,,,,, -65600,1.231853,0.9410938,,,,,,,,,,,,,, -65700,1.1366808,0.9317708,,,,,,,,,,,,,, -65800,1.3327595,0.9605321,,,,,,,,,,,,,, -65900,1.0809623,0.91815054,,,,,,,,,,,,,, -66000,1.8195245,0.9232751,,,,,,,,,,,,,, -66100,1.2648132,0.96609086,,,,,,,,,,,,,, -66200,1.0289893,0.87341,,,,,,,,,,,,,, -66300,1.3642462,0.94247913,,,,,,,,,,,,,, -66400,1.1345522,0.92186046,,,,,,,,,,,,,, -66500,1.2124418,1.000747,,,,,,,,,,,,,, -66568,,,0.10550904,0.0410768989959974,0.3176728,0.0916226575398013,5348.0,0.17099647,0.0549631344829687,2472.0,56232.99437975884,61675.28289914131,56232.99437975884,5436.998205900192,2.361929655075073,0.0 -66600,1.5585467,0.9548549,,,,,,,,,,,,,, -66700,1.12011,0.92407596,,,,,,,,,,,,,, -66800,1.0491824,0.9246103,,,,,,,,,,,,,, -66900,0.9531638,0.88094324,,,,,,,,,,,,,, -67000,1.0891932,0.91396725,,,,,,,,,,,,,, -67100,1.357376,0.92182,,,,,,,,,,,,,, -67200,1.0945143,0.9255275,,,,,,,,,,,,,, -67300,1.2806607,0.93849456,,,,,,,,,,,,,, -67400,1.2898055,0.92276055,,,,,,,,,,,,,, -67500,1.0845366,0.93908066,,,,,,,,,,,,,, -67600,1.267732,0.90774924,,,,,,,,,,,,,, -67700,2.0131052,0.91792464,,,,,,,,,,,,,, -67800,1.1165054,0.9156619,,,,,,,,,,,,,, -67900,1.2005893,0.92376417,,,,,,,,,,,,,, -68000,1.4642961,0.9057743,,,,,,,,,,,,,, -68100,1.1900144,0.9380937,,,,,,,,,,,,,, -68200,1.1778076,0.90004635,,,,,,,,,,,,,, -68277,,,0.09899141,0.038625017987005,0.31188878,0.0894213966421116,5348.0,0.16618471,0.0533382081124449,2472.0,57673.43183207512,63251.47676539421,57673.43183207512,5572.620210170746,2.421889066696167,0.0 -68300,0.9592378,0.9093014,,,,,,,,,,,,,, -68400,1.2433885,0.91152656,,,,,,,,,,,,,, -68500,1.1297382,0.88657385,,,,,,,,,,,,,, -68600,1.0998448,0.87060106,,,,,,,,,,,,,, -68700,1.2315177,0.92348486,,,,,,,,,,,,,, -68800,1.0556763,0.8894664,,,,,,,,,,,,,, -68900,1.2862984,0.9656636,,,,,,,,,,,,,, -69000,1.2087451,0.90117794,,,,,,,,,,,,,, -69100,1.0890483,0.90205216,,,,,,,,,,,,,, -69200,1.1607637,0.91151524,,,,,,,,,,,,,, -69300,1.4342648,0.8463961,,,,,,,,,,,,,, -69400,1.4154956,0.9451529,,,,,,,,,,,,,, -69500,1.1650268,0.91619396,,,,,,,,,,,,,, -69600,1.117258,0.8848072,,,,,,,,,,,,,, -69700,1.176465,0.83077186,,,,,,,,,,,,,, -69800,1.0964396,0.8501465,,,,,,,,,,,,,, -69900,1.4184628,0.8845569,,,,,,,,,,,,,, -69968,,,0.079612635,0.0305982258456632,0.3068214,0.0878380335402647,5348.0,0.16443966,0.0522820059716044,2472.0,59113.41282272339,64825.58831310272,59113.41282272339,5706.620025396347,2.478905200958252,0.0 -70000,1.123292,0.8834414,,,,,,,,,,,,,, -70100,1.3568658,0.8633344,,,,,,,,,,,,,, -70200,1.0504541,0.91506547,,,,,,,,,,,,,, -70300,2.1066327,0.8834624,,,,,,,,,,,,,, -70400,1.1955225,0.88845766,,,,,,,,,,,,,, -70500,1.2659152,0.8530795,,,,,,,,,,,,,, -70600,1.195212,0.89139414,,,,,,,,,,,,,, -70700,1.1537824,0.85734445,,,,,,,,,,,,,, -70800,1.2136887,0.9008683,,,,,,,,,,,,,, -70900,1.3765895,0.8820105,,,,,,,,,,,,,, -71000,1.130023,0.87229806,,,,,,,,,,,,,, -71100,1.3212253,0.8843961,,,,,,,,,,,,,, -71200,1.2026358,0.8905658,,,,,,,,,,,,,, -71300,1.2858545,0.8411402,,,,,,,,,,,,,, -71400,1.0564463,0.88115156,,,,,,,,,,,,,, -71500,1.1118637,0.8325755,,,,,,,,,,,,,, -71600,1.1100056,0.87034684,,,,,,,,,,,,,, -71672,,,0.07433639,0.0282804596919905,0.30262092,0.0857526284792956,5348.0,0.16102934,0.0517945280604472,2472.0,60553.82369470596,66399.41762804985,60553.82369470596,5839.893024682999,2.547360420227051,0.0 -71700,1.1809124,0.876231,,,,,,,,,,,,,, -71800,1.1892347,0.87000835,,,,,,,,,,,,,, -71900,1.1491839,0.87926936,,,,,,,,,,,,,, -72000,1.1203649,0.84962136,,,,,,,,,,,,,, -72100,1.2911217,0.872528,,,,,,,,,,,,,, -72200,1.4187984,0.839133,,,,,,,,,,,,,, -72300,1.5979658,0.8858582,,,,,,,,,,,,,, -72301,,,,,,,,,,,61068.183177948,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 5bedcbeed..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -138.98080229759216,0.0,35.57791781425476,1,0,35.57791781425476,31.279486,2472,1.0976986980277457,174.5587763786316,32.387848,1.3955485058197952,31.163591,5348,1.043146644525329 -251.99166750907887,0.0357248783111572,1475.9603667259216,1685,0,1475.9603667259216,6.330534,2472,0.8973046533828936,1728.061883211136,6.3702335,0.936353811149033,6.3203735,5348,0.895440107359742 -383.8297390937805,0.0865085124969482,2916.4242436885834,3376,0,2916.4242436885834,2.5039368,2472,0.5188999248471554,3300.489486694336,2.8703103,0.5950224159334732,2.8294756,5348,0.5699334794404163 -521.587308883667,0.1410958766937255,4356.522769451141,5050,0,4356.522769451141,0.66521096,2472,0.2190197631669815,4878.47324180603,0.83288306,0.2686575402958711,0.9431758,5348,0.2808055842513299 -654.1273121833801,0.1892614364624023,5797.023282766342,6740,0,5797.023282766342,0.48961702,2472,0.1671846119472711,6451.637514591217,0.6040004,0.2079973718639594,0.74008816,5348,0.2246444674010639 -786.9121985435486,0.2403411865234375,7236.972342252731,8450,0,7236.972342252731,0.4210605,2472,0.1426279121727296,8024.497707366943,0.49352032,0.1710689262302751,0.6651629,5348,0.2010581499753806 -947.81329536438,0.2935667037963867,8677.129787445068,10149,0,8677.129787445068,0.3722067,2472,0.1262974021489651,9625.68328166008,0.29691428,0.1110461656180921,0.6043157,5348,0.1861996389159755 -1086.3854405879974,0.3410470485687256,10117.744185447693,11875,0,10117.744185447693,0.35247084,2472,0.1216257388337091,11204.992297172546,0.27041242,0.099588036967578,0.57668215,5348,0.175357463529548 -1222.2144269943235,0.3931374549865722,11557.835375070572,13572,0,11557.835375070572,0.33199945,2472,0.1133995490829321,12781.041088342668,0.26092592,0.0966991728111208,0.54714733,5348,0.1656352278980855 -1360.179630279541,0.4491429328918457,12997.719169139862,15238,0,12997.719169139862,0.3189678,2472,0.1110434058456726,14359.018620729446,0.2348248,0.0892655897821187,0.5383248,5348,0.1643704683472199 -1497.6522102355957,0.5024521350860596,14438.15297293663,16950,0,14438.15297293663,0.30607492,2472,0.1061889383137326,15937.055542230606,0.22629037,0.0866338613111835,0.51311344,5348,0.1572646436950288 -1632.5123000144958,0.5541045665740967,15878.68928694725,18641,0,15878.68928694725,0.2961128,2472,0.1002782686409522,17512.57652759552,0.20851934,0.0807911355209591,0.4994007,5348,0.1518097647161049 -1768.872528076172,0.6105937957763672,17319.13134288788,20329,0,17319.13134288788,0.28475207,2472,0.0978002559259033,19089.508982419968,0.22626519,0.0824587947463301,0.4777417,5348,0.1465479787983818 -1903.2603197097776,0.6675965785980225,18759.68219566345,22034,0,18759.68219566345,0.27821675,2472,0.0950378810960128,20664.581153154373,0.20196249,0.078254877014419,0.47147408,5348,0.1439701864313506 -2037.694198846817,0.728062629699707,20199.590856790543,23710,0,20199.590856790543,0.26943246,2472,0.0920723904698068,22239.059130191803,0.16809237,0.0668152726589335,0.46198797,5348,0.1395483553298512 -2173.9809906482697,0.7831501960754395,21639.50814318657,25401,0,21639.50814318657,0.26178136,2472,0.0884163061361282,23815.39300727844,0.1630345,0.063379154304254,0.4471465,5348,0.1367291966363188 -2309.429753303528,0.8377275466918945,23079.609487771988,27108,0,23079.609487771988,0.25654113,2472,0.086669510287815,25391.072680950165,0.15664186,0.0624892287806979,0.43819365,5348,0.1324328760246 -2445.1812682151794,0.892916202545166,24519.481394052505,28772,0,24519.481394052505,0.24707821,2472,0.0838258891393983,26966.82464933396,0.15958466,0.0610066966850886,0.42708495,5348,0.1285034322291628 -2581.9604048728943,0.948739767074585,25959.4847984314,30461,0,25959.4847984314,0.24584797,2472,0.0836227733430828,28543.738730192184,0.15726769,0.0611525223303841,0.43288928,5348,0.1284261950046825 -2718.995190382004,1.0019807815551758,27399.63714241981,32179,0,27399.63714241981,0.23941892,2472,0.0809010216724554,30121.05331158638,0.14783192,0.0590605554319882,0.42411798,5348,0.1266883574538748 -2853.9519302845,1.0594327449798584,28839.877410888672,33870,0,28839.877410888672,0.23393744,2472,0.0776308573517762,31696.38317656517,0.14385405,0.0571283357173921,0.41215566,5348,0.1227589136584376 -2990.9510576725006,1.1162869930267334,30280.39509201049,35559,0,30280.39509201049,0.23113732,2472,0.0759449962423577,33274.03078866005,0.1265491,0.0511833648791855,0.4120602,5348,0.122150670515655 -3128.355121135712,1.1727495193481443,31721.07645821572,37254,0,31721.07645821572,0.22469771,2472,0.0752747141145166,34852.249284267426,0.12382909,0.0480419862737182,0.40429664,5348,0.1200266468424457 -3265.838086128235,1.232497215270996,33161.42921876907,38930,0,33161.42921876907,0.2211299,2472,0.0744013161903601,36430.21849322319,0.13134545,0.051060376541874,0.39354557,5348,0.1150062272512237 -3403.660226583481,1.2956340312957764,34601.42512226105,40629,0,34601.42512226105,0.21496634,2472,0.0726138971827839,38008.17625498772,0.11861537,0.0469573080715363,0.38333464,5348,0.1137414677003581 -3540.007214784622,1.3486223220825195,36041.619644880295,42306,0,36041.619644880295,0.21081606,2472,0.0694452907602624,39584.84460020065,0.10947188,0.045603775520673,0.3845979,5348,0.11188777431283 -3676.626125335693,1.4088342189788818,37481.56066060066,44006,0,37481.56066060066,0.21202654,2472,0.0708671013344707,41161.54116153717,0.13150723,0.0486614122977759,0.38188517,5348,0.1114629695781882 -3814.218501806259,1.4656569957733154,38921.45214056969,45720,0,38921.45214056969,0.20303267,2472,0.0677594296508439,42739.1587574482,0.08845052,0.0354513198104887,0.37528655,5348,0.1092327447213184 -3950.640738248825,1.5266199111938477,40362.0363907814,47396,0,40362.0363907814,0.20532721,2472,0.0674547559563707,44316.29981398583,0.09425022,0.0379273029531541,0.37133262,5348,0.1077169641908918 -4085.706921339035,1.5839388370513916,41801.95665073395,49093,0,41801.95665073395,0.19545701,2472,0.0655251558913736,45891.41930747032,0.11438162,0.0454195704267153,0.36344072,5348,0.1061142917829247 -4221.55396938324,1.6436846256256104,43242.03401255608,50786,0,43242.03401255608,0.19716543,2472,0.0640830337375337,47467.48012375832,0.117886096,0.0461131535607938,0.360941,5348,0.1038454483138148 -4355.304067373276,1.6991455554962158,44682.5635433197,52446,0,44682.5635433197,0.19476175,2472,0.0630065200170617,49041.885318517685,0.12866937,0.0515651104510859,0.3551715,5348,0.1020110642324068 -4489.92017698288,1.7626543045043943,46123.199763059616,54122,0,46123.199763059616,0.19035426,2472,0.0630065200170617,50617.27509188652,0.111043565,0.0426047658175842,0.3560681,5348,0.1020400281915869 -4623.43222784996,1.8321900367736816,47563.62722706795,55801,0,47563.62722706795,0.18869571,2472,0.0615034631243271,52191.360865831375,0.09821409,0.0392730909928691,0.3484534,5348,0.0991629415796943 -4759.919205904007,1.887025594711304,49004.61514925957,57473,0,49004.61514925957,0.18406248,2472,0.0598785367538033,53768.96395373344,0.081379846,0.0325830120631762,0.34315622,5348,0.0966237678249032 -4897.376502037048,1.9495971202850344,50444.962716817856,59177,0,50444.962716817856,0.18465194,2472,0.0590254504092783,55346.90739274025,0.09032433,0.0367084051395039,0.34330922,5348,0.0964886026820626 -5037.317294836044,2.0173354148864746,51885.22932291031,60853,0,51885.22932291031,0.18213382,2472,0.0579083135295431,56927.25577759743,0.076652534,0.0303122660678002,0.33885056,5348,0.0949245488863357 -5172.018817186356,2.085672378540039,53326.08283543587,62543,0,53326.08283543587,0.17903255,2472,0.0567708650701765,58502.95514130592,0.07401633,0.0311466567874868,0.3358002,5348,0.0939011556619712 -5307.635042190552,2.152307748794556,54766.469485759735,64246,0,54766.469485759735,0.17863517,2472,0.056791176649808,60079.09964752197,0.06872757,0.0277908723154675,0.33404937,5348,0.0938721917027911 -5443.378481864929,2.2151148319244385,56206.48468470573,65915,0,56206.48468470573,0.17421615,2472,0.055308431336705,61654.99394798279,0.07059194,0.0279766860949208,0.32888624,5348,0.092018498315263 -5579.349866390228,2.277138948440552,57647.05547738075,67587,0,57647.05547738075,0.17488703,2472,0.0552068734385473,63231.67115473747,0.07291302,0.0287140800271538,0.32705754,5348,0.0905123724378964 -5714.688037395477,2.3428428173065186,59087.30037307739,69280,0,59087.30037307739,0.1747135,2472,0.0554506123941258,64807.39521574974,0.060553554,0.0240735475310397,0.329274,5348,0.0914102551724803 -5849.344304800034,2.4029757976531982,60527.92716932297,70958,0,60527.92716932297,0.17322181,2472,0.054617837629232426,66382.81089735031,0.06268328,0.024154338909627173,0.3256981,5348,0.08989447464205373 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/measurements.csv deleted file mode 100644 index 1bd6b8e54..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/measurements.csv +++ /dev/null @@ -1,762 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,44.14146,31.829971,,,,,,,,,,,,,, -1,,,32.387848,1.3955485058197952,31.163591,1.043146644525329,5348.0,31.279486,1.0976986980277457,2472.0,35.57791781425476,174.5587763786316,35.57791781425476,138.98080229759216,0.0,0.0 -100,2.5681877,6.6499853,,,,,,,,,,,,,, -200,0.6175644,5.9323616,,,,,,,,,,,,,, -300,0.47367212,5.836099,,,,,,,,,,,,,, -400,0.57608837,5.8060217,,,,,,,,,,,,,, -500,0.570951,5.810319,,,,,,,,,,,,,, -600,0.58105797,5.793974,,,,,,,,,,,,,, -700,1.9093612,5.78246,,,,,,,,,,,,,, -800,1.6847188,5.78802,,,,,,,,,,,,,, -900,3.6546555,5.798966,,,,,,,,,,,,,, -1000,1.9925528,5.7740664,,,,,,,,,,,,,, -1100,4.241114,5.7303276,,,,,,,,,,,,,, -1200,1.5411026,5.534539,,,,,,,,,,,,,, -1300,2.2268846,5.491662,,,,,,,,,,,,,, -1400,8.061101,5.5868864,,,,,,,,,,,,,, -1500,1.1307198,4.7472553,,,,,,,,,,,,,, -1600,0.78331465,4.1657104,,,,,,,,,,,,,, -1685,,,6.3702335,0.936353811149033,6.3203735,0.895440107359742,5348.0,6.330534,0.8973046533828936,2472.0,1475.9603667259216,1728.061883211136,1475.9603667259216,251.99166750907887,0.0357248783111572,0.0 -1700,1.2086877,3.8545601,,,,,,,,,,,,,, -1800,0.9246079,3.6348648,,,,,,,,,,,,,, -1900,1.6033845,3.412395,,,,,,,,,,,,,, -2000,1.0228087,3.2941463,,,,,,,,,,,,,, -2100,1.0632932,3.140467,,,,,,,,,,,,,, -2200,1.2216462,3.0859106,,,,,,,,,,,,,, -2300,0.9238248,2.9741583,,,,,,,,,,,,,, -2400,1.3426936,2.9254966,,,,,,,,,,,,,, -2500,1.3516728,2.870893,,,,,,,,,,,,,, -2600,0.97571826,2.7271962,,,,,,,,,,,,,, -2700,1.8047405,2.6853654,,,,,,,,,,,,,, -2800,0.8620672,2.5728855,,,,,,,,,,,,,, -2900,1.1347735,2.6154032,,,,,,,,,,,,,, -3000,1.9072491,2.4839513,,,,,,,,,,,,,, -3100,0.8526655,2.4275742,,,,,,,,,,,,,, -3200,1.0511534,2.3777027,,,,,,,,,,,,,, -3300,0.7725561,2.3354986,,,,,,,,,,,,,, -3376,,,2.8703103,0.5950224159334732,2.8294756,0.5699334794404163,5348.0,2.5039368,0.5188999248471554,2472.0,2916.4242436885834,3300.489486694336,2916.4242436885834,383.8297390937805,0.0865085124969482,0.0 -3400,0.8101665,2.2195213,,,,,,,,,,,,,, -3500,1.0958052,2.243652,,,,,,,,,,,,,, -3600,0.7777608,2.1736102,,,,,,,,,,,,,, -3700,0.79648817,2.1491745,,,,,,,,,,,,,, -3800,0.8138519,2.086443,,,,,,,,,,,,,, -3900,0.74327844,2.1805263,,,,,,,,,,,,,, -4000,1.0745735,2.045133,,,,,,,,,,,,,, -4100,0.7393088,2.0551164,,,,,,,,,,,,,, -4200,0.64688265,1.9634243,,,,,,,,,,,,,, -4300,1.0275608,1.9907805,,,,,,,,,,,,,, -4400,0.7845382,1.9503257,,,,,,,,,,,,,, -4500,0.7749285,1.9262706,,,,,,,,,,,,,, -4600,0.6815733,1.9688025,,,,,,,,,,,,,, -4700,0.8013006,1.8370807,,,,,,,,,,,,,, -4800,0.7469476,1.8580246,,,,,,,,,,,,,, -4900,0.80096316,1.8886014,,,,,,,,,,,,,, -5000,0.8308872,1.8308898,,,,,,,,,,,,,, -5050,,,0.83288306,0.2686575402958711,0.9431758,0.2808055842513299,5348.0,0.66521096,0.2190197631669815,2472.0,4356.522769451141,4878.47324180603,4356.522769451141,521.587308883667,0.1410958766937255,0.0 -5100,0.79773253,1.8254658,,,,,,,,,,,,,, -5200,0.59613746,1.8154494,,,,,,,,,,,,,, -5300,0.7553264,1.7908776,,,,,,,,,,,,,, -5400,0.56864923,1.776155,,,,,,,,,,,,,, -5500,0.9451816,1.8026675,,,,,,,,,,,,,, -5600,0.95368695,1.8096114,,,,,,,,,,,,,, -5700,0.67771256,1.8604559,,,,,,,,,,,,,, -5800,0.62773067,1.7588185,,,,,,,,,,,,,, -5900,0.6930454,1.7801868,,,,,,,,,,,,,, -6000,0.7556577,1.6990595,,,,,,,,,,,,,, -6100,0.9305669,1.7312063,,,,,,,,,,,,,, -6200,0.94173324,1.6220701,,,,,,,,,,,,,, -6300,0.6209583,1.6724524,,,,,,,,,,,,,, -6400,0.689005,1.6993127,,,,,,,,,,,,,, -6500,0.59158593,1.649292,,,,,,,,,,,,,, -6600,1.5282402,1.614341,,,,,,,,,,,,,, -6700,0.6808628,1.70236,,,,,,,,,,,,,, -6740,,,0.6040004,0.2079973718639594,0.74008816,0.2246444674010639,5348.0,0.48961702,0.1671846119472711,2472.0,5797.023282766342,6451.637514591217,5797.023282766342,654.1273121833801,0.1892614364624023,0.0 -6800,0.61469597,1.6188918,,,,,,,,,,,,,, -6900,0.70728874,1.6057281,,,,,,,,,,,,,, -7000,0.8878753,1.6760151,,,,,,,,,,,,,, -7100,0.64738977,1.6414944,,,,,,,,,,,,,, -7200,0.71632475,1.6266865,,,,,,,,,,,,,, -7300,0.6955349,1.5956937,,,,,,,,,,,,,, -7400,0.79316586,1.5727377,,,,,,,,,,,,,, -7500,0.58098483,1.6387917,,,,,,,,,,,,,, -7600,0.64454454,1.5391358,,,,,,,,,,,,,, -7700,0.8437667,1.6082504,,,,,,,,,,,,,, -7800,0.55873585,1.5789899,,,,,,,,,,,,,, -7900,0.5941433,1.5358825,,,,,,,,,,,,,, -8000,0.6434638,1.5579911,,,,,,,,,,,,,, -8100,0.68713826,1.5638163,,,,,,,,,,,,,, -8200,0.6619916,1.6175086,,,,,,,,,,,,,, -8300,0.71588224,1.4916267,,,,,,,,,,,,,, -8400,0.8090252,1.5376118,,,,,,,,,,,,,, -8450,,,0.49352032,0.1710689262302751,0.6651629,0.2010581499753806,5348.0,0.4210605,0.1426279121727296,2472.0,7236.972342252731,8024.497707366943,7236.972342252731,786.9121985435486,0.2403411865234375,0.0 -8500,0.5791571,1.5627813,,,,,,,,,,,,,, -8600,0.6594584,1.5578263,,,,,,,,,,,,,, -8700,0.6456313,1.5360272,,,,,,,,,,,,,, -8800,0.74800754,1.5876312,,,,,,,,,,,,,, -8900,0.6192822,1.532116,,,,,,,,,,,,,, -9000,0.6413133,1.5383703,,,,,,,,,,,,,, -9100,0.5799464,1.5205314,,,,,,,,,,,,,, -9200,0.60447246,1.49934,,,,,,,,,,,,,, -9300,0.5766881,1.480134,,,,,,,,,,,,,, -9400,0.7715793,1.4726605,,,,,,,,,,,,,, -9500,0.7179313,1.5282732,,,,,,,,,,,,,, -9600,0.6554888,1.4908276,,,,,,,,,,,,,, -9700,0.56835425,1.5435114,,,,,,,,,,,,,, -9800,0.60837454,1.5138226,,,,,,,,,,,,,, -9900,0.79028714,1.43802,,,,,,,,,,,,,, -10000,0.6530944,1.4586806,,,,,,,,,,,,,, -10100,0.6256141,1.4871912,,,,,,,,,,,,,, -10149,,,0.29691428,0.1110461656180921,0.6043157,0.1861996389159755,5348.0,0.3722067,0.1262974021489651,2472.0,8677.129787445068,9625.68328166008,8677.129787445068,947.81329536438,0.2935667037963867,0.0 -10200,0.65324455,1.5412415,,,,,,,,,,,,,, -10300,0.6393015,1.4745162,,,,,,,,,,,,,, -10400,0.6540031,1.4609996,,,,,,,,,,,,,, -10500,0.7166896,1.405619,,,,,,,,,,,,,, -10600,0.61184335,1.4414574,,,,,,,,,,,,,, -10700,0.6226299,1.505586,,,,,,,,,,,,,, -10800,0.6062006,1.4393847,,,,,,,,,,,,,, -10900,0.68690467,1.4776208,,,,,,,,,,,,,, -11000,0.61568165,1.3880268,,,,,,,,,,,,,, -11100,0.6046046,1.4202391,,,,,,,,,,,,,, -11200,0.6326718,1.4322032,,,,,,,,,,,,,, -11300,0.73993385,1.4541326,,,,,,,,,,,,,, -11400,0.70225894,1.4272041,,,,,,,,,,,,,, -11500,0.6236221,1.3986022,,,,,,,,,,,,,, -11600,0.6560986,1.4002331,,,,,,,,,,,,,, -11700,0.6523136,1.4994048,,,,,,,,,,,,,, -11800,0.55710644,1.3949062,,,,,,,,,,,,,, -11875,,,0.27041242,0.099588036967578,0.57668215,0.175357463529548,5348.0,0.35247084,0.1216257388337091,2472.0,10117.744185447693,11204.992297172546,10117.744185447693,1086.3854405879974,0.3410470485687256,0.0 -11900,0.5649074,1.4282651,,,,,,,,,,,,,, -12000,0.7619122,1.4007657,,,,,,,,,,,,,, -12100,0.6506184,1.4266342,,,,,,,,,,,,,, -12200,0.6069252,1.4209019,,,,,,,,,,,,,, -12300,0.7160285,1.4044089,,,,,,,,,,,,,, -12400,0.5581915,1.3858738,,,,,,,,,,,,,, -12500,0.6287035,1.4071938,,,,,,,,,,,,,, -12600,0.7285196,1.4084707,,,,,,,,,,,,,, -12700,0.62460804,1.4360863,,,,,,,,,,,,,, -12800,0.68141973,1.4022437,,,,,,,,,,,,,, -12900,0.56521755,1.3622942,,,,,,,,,,,,,, -13000,0.5772731,1.4098709,,,,,,,,,,,,,, -13100,0.6490571,1.3943849,,,,,,,,,,,,,, -13200,0.59016687,1.4157343,,,,,,,,,,,,,, -13300,0.6099916,1.3607028,,,,,,,,,,,,,, -13400,0.5451419,1.3545047,,,,,,,,,,,,,, -13500,0.5584402,1.373686,,,,,,,,,,,,,, -13572,,,0.26092592,0.0966991728111208,0.54714733,0.1656352278980855,5348.0,0.33199945,0.1133995490829321,2472.0,11557.835375070572,12781.041088342668,11557.835375070572,1222.2144269943235,0.3931374549865722,0.0 -13600,0.6274126,1.399471,,,,,,,,,,,,,, -13700,0.6681174,1.3845901,,,,,,,,,,,,,, -13800,0.5459534,1.3785341,,,,,,,,,,,,,, -13900,0.5567205,1.4035288,,,,,,,,,,,,,, -14000,0.7021305,1.317392,,,,,,,,,,,,,, -14100,0.64789605,1.4116461,,,,,,,,,,,,,, -14200,0.6285641,1.3357643,,,,,,,,,,,,,, -14300,0.53397506,1.3950583,,,,,,,,,,,,,, -14400,0.65112424,1.4009285,,,,,,,,,,,,,, -14500,0.8016378,1.3936346,,,,,,,,,,,,,, -14600,0.68871593,1.3364872,,,,,,,,,,,,,, -14700,0.5887089,1.3671767,,,,,,,,,,,,,, -14800,0.7282829,1.3544836,,,,,,,,,,,,,, -14900,0.54729307,1.3281715,,,,,,,,,,,,,, -15000,0.59887356,1.3472784,,,,,,,,,,,,,, -15100,0.64360535,1.3250184,,,,,,,,,,,,,, -15200,0.62975943,1.3997645,,,,,,,,,,,,,, -15238,,,0.2348248,0.0892655897821187,0.5383248,0.1643704683472199,5348.0,0.3189678,0.1110434058456726,2472.0,12997.719169139862,14359.018620729446,12997.719169139862,1360.179630279541,0.4491429328918457,0.0 -15300,0.7665472,1.3755332,,,,,,,,,,,,,, -15400,0.713054,1.3771647,,,,,,,,,,,,,, -15500,0.5453542,1.3220246,,,,,,,,,,,,,, -15600,0.5486115,1.3868619,,,,,,,,,,,,,, -15700,0.5812245,1.353627,,,,,,,,,,,,,, -15800,0.8608643,1.3522239,,,,,,,,,,,,,, -15900,0.50312984,1.3149834,,,,,,,,,,,,,, -16000,0.60115314,1.3330753,,,,,,,,,,,,,, -16100,0.52349913,1.3476723,,,,,,,,,,,,,, -16200,0.68131053,1.328159,,,,,,,,,,,,,, -16300,0.49257648,1.3638486,,,,,,,,,,,,,, -16400,0.5056813,1.3375114,,,,,,,,,,,,,, -16500,0.63464296,1.3216808,,,,,,,,,,,,,, -16600,0.54639244,1.2987279,,,,,,,,,,,,,, -16700,0.5398379,1.328969,,,,,,,,,,,,,, -16800,0.7207307,1.3266453,,,,,,,,,,,,,, -16900,0.59237295,1.3376782,,,,,,,,,,,,,, -16950,,,0.22629037,0.0866338613111835,0.51311344,0.1572646436950288,5348.0,0.30607492,0.1061889383137326,2472.0,14438.15297293663,15937.055542230606,14438.15297293663,1497.6522102355957,0.5024521350860596,0.0 -17000,0.80760646,1.3715914,,,,,,,,,,,,,, -17100,0.60217685,1.3600492,,,,,,,,,,,,,, -17200,0.582601,1.3481451,,,,,,,,,,,,,, -17300,0.68015265,1.3393031,,,,,,,,,,,,,, -17400,0.65458006,1.3718408,,,,,,,,,,,,,, -17500,0.5853485,1.3642486,,,,,,,,,,,,,, -17600,0.5788618,1.2886461,,,,,,,,,,,,,, -17700,0.6192958,1.3574358,,,,,,,,,,,,,, -17800,0.5698546,1.2958022,,,,,,,,,,,,,, -17900,0.61929125,1.3773514,,,,,,,,,,,,,, -18000,0.56786686,1.270207,,,,,,,,,,,,,, -18100,0.5205222,1.2757801,,,,,,,,,,,,,, -18200,0.53421634,1.305323,,,,,,,,,,,,,, -18300,0.6188621,1.3006127,,,,,,,,,,,,,, -18400,0.61420846,1.355497,,,,,,,,,,,,,, -18500,0.7695962,1.3659345,,,,,,,,,,,,,, -18600,0.7727638,1.3018149,,,,,,,,,,,,,, -18641,,,0.20851934,0.0807911355209591,0.4994007,0.1518097647161049,5348.0,0.2961128,0.1002782686409522,2472.0,15878.68928694725,17512.57652759552,15878.68928694725,1632.5123000144958,0.5541045665740967,0.0 -18700,0.6704474,1.2555434,,,,,,,,,,,,,, -18800,0.6259559,1.3227677,,,,,,,,,,,,,, -18900,0.7238377,1.2749603,,,,,,,,,,,,,, -19000,0.61578923,1.2956902,,,,,,,,,,,,,, -19100,0.60667676,1.2998425,,,,,,,,,,,,,, -19200,0.7298597,1.2458866,,,,,,,,,,,,,, -19300,0.5688267,1.3450209,,,,,,,,,,,,,, -19400,0.5793768,1.3078039,,,,,,,,,,,,,, -19500,0.5384698,1.273939,,,,,,,,,,,,,, -19600,0.5538669,1.2699944,,,,,,,,,,,,,, -19700,0.62594134,1.2719083,,,,,,,,,,,,,, -19800,0.5669018,1.2547182,,,,,,,,,,,,,, -19900,0.5205289,1.2516763,,,,,,,,,,,,,, -20000,0.6761799,1.2924631,,,,,,,,,,,,,, -20100,0.56589717,1.2704283,,,,,,,,,,,,,, -20200,0.5586852,1.2961222,,,,,,,,,,,,,, -20300,0.6640006,1.2985489,,,,,,,,,,,,,, -20329,,,0.22626519,0.0824587947463301,0.4777417,0.1465479787983818,5348.0,0.28475207,0.0978002559259033,2472.0,17319.13134288788,19089.508982419968,17319.13134288788,1768.872528076172,0.6105937957763672,0.0 -20400,0.549619,1.2894028,,,,,,,,,,,,,, -20500,0.5486974,1.325713,,,,,,,,,,,,,, -20600,0.61517894,1.3026718,,,,,,,,,,,,,, -20700,0.6362325,1.2388824,,,,,,,,,,,,,, -20800,0.575729,1.2504686,,,,,,,,,,,,,, -20900,0.60904825,1.2642928,,,,,,,,,,,,,, -21000,0.59637535,1.3000385,,,,,,,,,,,,,, -21100,0.59017926,1.2320786,,,,,,,,,,,,,, -21200,0.62377405,1.3070861,,,,,,,,,,,,,, -21300,0.63610077,1.3238043,,,,,,,,,,,,,, -21400,0.55235714,1.257752,,,,,,,,,,,,,, -21500,0.73568386,1.3264353,,,,,,,,,,,,,, -21600,0.612646,1.197909,,,,,,,,,,,,,, -21700,0.4969876,1.2178062,,,,,,,,,,,,,, -21800,0.68148655,1.2463238,,,,,,,,,,,,,, -21900,0.66973436,1.2832487,,,,,,,,,,,,,, -22000,0.65265715,1.2249691,,,,,,,,,,,,,, -22034,,,0.20196249,0.078254877014419,0.47147408,0.1439701864313506,5348.0,0.27821675,0.0950378810960128,2472.0,18759.68219566345,20664.581153154373,18759.68219566345,1903.2603197097776,0.6675965785980225,0.0 -22100,0.5805971,1.2293646,,,,,,,,,,,,,, -22200,0.65928996,1.2613475,,,,,,,,,,,,,, -22300,0.6623837,1.2638634,,,,,,,,,,,,,, -22400,0.48553473,1.2097131,,,,,,,,,,,,,, -22500,0.54018843,1.2380114,,,,,,,,,,,,,, -22600,0.6186895,1.2342718,,,,,,,,,,,,,, -22700,0.547232,1.2122651,,,,,,,,,,,,,, -22800,0.6059437,1.2325778,,,,,,,,,,,,,, -22900,0.76061726,1.2521249,,,,,,,,,,,,,, -23000,0.64875543,1.2149338,,,,,,,,,,,,,, -23100,0.5627293,1.1984445,,,,,,,,,,,,,, -23200,0.6080253,1.1904103,,,,,,,,,,,,,, -23300,0.66455644,1.2379916,,,,,,,,,,,,,, -23400,0.5895132,1.2288809,,,,,,,,,,,,,, -23500,0.51529384,1.2475975,,,,,,,,,,,,,, -23600,0.58130133,1.2829336,,,,,,,,,,,,,, -23700,0.58568084,1.1982374,,,,,,,,,,,,,, -23710,,,0.16809237,0.0668152726589335,0.46198797,0.1395483553298512,5348.0,0.26943246,0.0920723904698068,2472.0,20199.590856790543,22239.059130191803,20199.590856790543,2037.694198846817,0.728062629699707,0.0 -23800,0.69697493,1.2081138,,,,,,,,,,,,,, -23900,0.5174456,1.1583941,,,,,,,,,,,,,, -24000,0.7043509,1.2344223,,,,,,,,,,,,,, -24100,0.78807455,1.2597001,,,,,,,,,,,,,, -24200,0.63635904,1.2532036,,,,,,,,,,,,,, -24300,0.49207833,1.201486,,,,,,,,,,,,,, -24400,0.640905,1.2076856,,,,,,,,,,,,,, -24500,0.5567183,1.2122111,,,,,,,,,,,,,, -24600,0.5148727,1.2285887,,,,,,,,,,,,,, -24700,0.6092707,1.1529406,,,,,,,,,,,,,, -24800,0.5184393,1.1736112,,,,,,,,,,,,,, -24900,0.7512158,1.2416828,,,,,,,,,,,,,, -25000,0.5687259,1.205705,,,,,,,,,,,,,, -25100,0.61179966,1.2200699,,,,,,,,,,,,,, -25200,0.50160456,1.2071117,,,,,,,,,,,,,, -25300,0.5530966,1.185429,,,,,,,,,,,,,, -25400,0.61391634,1.2098209,,,,,,,,,,,,,, -25401,,,0.1630345,0.063379154304254,0.4471465,0.1367291966363188,5348.0,0.26178136,0.0884163061361282,2472.0,21639.50814318657,23815.39300727844,21639.50814318657,2173.9809906482697,0.7831501960754395,0.0 -25500,0.9166493,1.223579,,,,,,,,,,,,,, -25600,0.6337613,1.219455,,,,,,,,,,,,,, -25700,0.8079303,1.2514449,,,,,,,,,,,,,, -25800,0.61793184,1.1342789,,,,,,,,,,,,,, -25900,0.6445991,1.2269031,,,,,,,,,,,,,, -26000,0.7441361,1.218715,,,,,,,,,,,,,, -26100,0.6263691,1.2148691,,,,,,,,,,,,,, -26200,0.53205884,1.1701549,,,,,,,,,,,,,, -26300,0.5953626,1.2562577,,,,,,,,,,,,,, -26400,0.5826111,1.2286121,,,,,,,,,,,,,, -26500,0.65052176,1.1926923,,,,,,,,,,,,,, -26600,0.5224595,1.1947323,,,,,,,,,,,,,, -26700,0.81910634,1.2239329,,,,,,,,,,,,,, -26800,0.5337583,1.178704,,,,,,,,,,,,,, -26900,0.5074399,1.2026995,,,,,,,,,,,,,, -27000,0.664441,1.1399754,,,,,,,,,,,,,, -27100,0.78914654,1.1951346,,,,,,,,,,,,,, -27108,,,0.15664186,0.0624892287806979,0.43819365,0.1324328760246,5348.0,0.25654113,0.086669510287815,2472.0,23079.609487771988,25391.072680950165,23079.609487771988,2309.429753303528,0.8377275466918945,0.0 -27200,0.56590813,1.177243,,,,,,,,,,,,,, -27300,0.673859,1.202368,,,,,,,,,,,,,, -27400,1.0111293,1.1681675,,,,,,,,,,,,,, -27500,0.770441,1.206738,,,,,,,,,,,,,, -27600,0.5122622,1.1551685,,,,,,,,,,,,,, -27700,0.5687718,1.1914198,,,,,,,,,,,,,, -27800,0.57449085,1.1566072,,,,,,,,,,,,,, -27900,0.58059293,1.2583866,,,,,,,,,,,,,, -28000,0.6431004,1.1948287,,,,,,,,,,,,,, -28100,0.59069616,1.1785674,,,,,,,,,,,,,, -28200,0.6205371,1.1583548,,,,,,,,,,,,,, -28300,0.58707553,1.1656395,,,,,,,,,,,,,, -28400,0.59553754,1.1342791,,,,,,,,,,,,,, -28500,0.5762793,1.1292074,,,,,,,,,,,,,, -28600,0.6509549,1.1859452,,,,,,,,,,,,,, -28700,0.88335216,1.1758181,,,,,,,,,,,,,, -28772,,,0.15958466,0.0610066966850886,0.42708495,0.1285034322291628,5348.0,0.24707821,0.0838258891393983,2472.0,24519.481394052505,26966.82464933396,24519.481394052505,2445.1812682151794,0.892916202545166,0.0 -28800,0.7151203,1.160348,,,,,,,,,,,,,, -28900,0.7451649,1.1311616,,,,,,,,,,,,,, -29000,0.62357825,1.193385,,,,,,,,,,,,,, -29100,0.6492575,1.1629801,,,,,,,,,,,,,, -29200,0.5272592,1.1639025,,,,,,,,,,,,,, -29300,0.6746461,1.1353309,,,,,,,,,,,,,, -29400,0.6322565,1.143857,,,,,,,,,,,,,, -29500,0.9013431,1.1513458,,,,,,,,,,,,,, -29600,0.5815827,1.216545,,,,,,,,,,,,,, -29700,0.5478888,1.1869509,,,,,,,,,,,,,, -29800,0.5589346,1.2185172,,,,,,,,,,,,,, -29900,0.6255773,1.1289464,,,,,,,,,,,,,, -30000,0.5170187,1.1530285,,,,,,,,,,,,,, -30100,0.6018233,1.167984,,,,,,,,,,,,,, -30200,0.5793226,1.1620778,,,,,,,,,,,,,, -30300,0.6026954,1.1234152,,,,,,,,,,,,,, -30400,0.61504316,1.1451405,,,,,,,,,,,,,, -30461,,,0.15726769,0.0611525223303841,0.43288928,0.1284261950046825,5348.0,0.24584797,0.0836227733430828,2472.0,25959.4847984314,28543.738730192184,25959.4847984314,2581.9604048728943,0.948739767074585,0.0 -30500,0.6186425,1.1493138,,,,,,,,,,,,,, -30600,0.588402,1.1788855,,,,,,,,,,,,,, -30700,0.61259675,1.165616,,,,,,,,,,,,,, -30800,0.7471274,1.1678578,,,,,,,,,,,,,, -30900,0.5579226,1.0987517,,,,,,,,,,,,,, -31000,0.66767216,1.1333532,,,,,,,,,,,,,, -31100,0.5725103,1.1330919,,,,,,,,,,,,,, -31200,0.57681084,1.1111461,,,,,,,,,,,,,, -31300,0.705403,1.156852,,,,,,,,,,,,,, -31400,0.62640905,1.1471418,,,,,,,,,,,,,, -31500,0.61063653,1.1895814,,,,,,,,,,,,,, -31600,0.7068593,1.1523336,,,,,,,,,,,,,, -31700,0.57678044,1.1745309,,,,,,,,,,,,,, -31800,0.663231,1.1571202,,,,,,,,,,,,,, -31900,0.59865665,1.119485,,,,,,,,,,,,,, -32000,0.6241699,1.1710443,,,,,,,,,,,,,, -32100,0.5689603,1.1676713,,,,,,,,,,,,,, -32179,,,0.14783192,0.0590605554319882,0.42411798,0.1266883574538748,5348.0,0.23941892,0.0809010216724554,2472.0,27399.63714241981,30121.05331158638,27399.63714241981,2718.995190382004,1.0019807815551758,0.0 -32200,0.5602403,1.1009679,,,,,,,,,,,,,, -32300,0.6286982,1.0760571,,,,,,,,,,,,,, -32400,0.71934485,1.1187109,,,,,,,,,,,,,, -32500,0.73187786,1.1998949,,,,,,,,,,,,,, -32600,0.60579085,1.1008986,,,,,,,,,,,,,, -32700,0.73534894,1.1950096,,,,,,,,,,,,,, -32800,0.5975509,1.1603153,,,,,,,,,,,,,, -32900,0.6834257,1.1973754,,,,,,,,,,,,,, -33000,0.56203556,1.1207937,,,,,,,,,,,,,, -33100,0.5895259,1.1142442,,,,,,,,,,,,,, -33200,0.572151,1.0917493,,,,,,,,,,,,,, -33300,0.5859267,1.1702996,,,,,,,,,,,,,, -33400,0.5946757,1.1524959,,,,,,,,,,,,,, -33500,0.54743135,1.1177715,,,,,,,,,,,,,, -33600,0.5956205,1.1733745,,,,,,,,,,,,,, -33700,0.7441468,1.09299,,,,,,,,,,,,,, -33800,0.6428478,1.1558913,,,,,,,,,,,,,, -33870,,,0.14385405,0.0571283357173921,0.41215566,0.1227589136584376,5348.0,0.23393744,0.0776308573517762,2472.0,28839.877410888672,31696.38317656517,28839.877410888672,2853.9519302845,1.0594327449798584,0.0 -33900,0.5874849,1.1540943,,,,,,,,,,,,,, -34000,0.5564159,1.1005697,,,,,,,,,,,,,, -34100,0.7158393,1.1348008,,,,,,,,,,,,,, -34200,0.6251712,1.1557233,,,,,,,,,,,,,, -34300,0.68193054,1.0730417,,,,,,,,,,,,,, -34400,0.70754296,1.1234136,,,,,,,,,,,,,, -34500,0.73848397,1.0718627,,,,,,,,,,,,,, -34600,0.69943315,1.1113293,,,,,,,,,,,,,, -34700,0.5924099,1.1258028,,,,,,,,,,,,,, -34800,0.6302768,1.1796027,,,,,,,,,,,,,, -34900,0.7396065,1.0649722,,,,,,,,,,,,,, -35000,0.67400146,1.0667043,,,,,,,,,,,,,, -35100,0.7006249,1.0962057,,,,,,,,,,,,,, -35200,0.6387809,1.082117,,,,,,,,,,,,,, -35300,0.67335165,1.1248047,,,,,,,,,,,,,, -35400,0.7638164,1.1701247,,,,,,,,,,,,,, -35500,0.72756064,1.1681685,,,,,,,,,,,,,, -35559,,,0.1265491,0.0511833648791855,0.4120602,0.122150670515655,5348.0,0.23113732,0.0759449962423577,2472.0,30280.39509201049,33274.03078866005,30280.39509201049,2990.9510576725006,1.1162869930267334,0.0 -35600,0.64402753,1.0847819,,,,,,,,,,,,,, -35700,0.5681765,1.1676856,,,,,,,,,,,,,, -35800,0.64222306,1.1518257,,,,,,,,,,,,,, -35900,0.6022184,1.0859784,,,,,,,,,,,,,, -36000,0.82853645,1.1510168,,,,,,,,,,,,,, -36100,0.67280996,1.1296067,,,,,,,,,,,,,, -36200,0.5403161,1.0922285,,,,,,,,,,,,,, -36300,0.6269601,1.0977759,,,,,,,,,,,,,, -36400,0.6099674,1.124104,,,,,,,,,,,,,, -36500,0.65723866,1.1299529,,,,,,,,,,,,,, -36600,0.58547133,1.1819791,,,,,,,,,,,,,, -36700,0.642073,1.0541224,,,,,,,,,,,,,, -36800,0.6387783,1.0942925,,,,,,,,,,,,,, -36900,0.6335305,1.1064243,,,,,,,,,,,,,, -37000,0.7026819,1.0942389,,,,,,,,,,,,,, -37100,0.71491003,1.0772761,,,,,,,,,,,,,, -37200,0.72803974,1.0703789,,,,,,,,,,,,,, -37254,,,0.12382909,0.0480419862737182,0.40429664,0.1200266468424457,5348.0,0.22469771,0.0752747141145166,2472.0,31721.07645821572,34852.249284267426,31721.07645821572,3128.355121135712,1.1727495193481443,0.0 -37300,0.6736642,1.1322842,,,,,,,,,,,,,, -37400,0.7124753,1.0797014,,,,,,,,,,,,,, -37500,0.6925102,1.0951654,,,,,,,,,,,,,, -37600,0.6790232,1.1229808,,,,,,,,,,,,,, -37700,0.6412305,1.0861202,,,,,,,,,,,,,, -37800,0.5913737,1.08477,,,,,,,,,,,,,, -37900,0.6070224,1.0638664,,,,,,,,,,,,,, -38000,0.61380494,1.1839646,,,,,,,,,,,,,, -38100,0.633556,1.1020898,,,,,,,,,,,,,, -38200,0.7276349,1.0824883,,,,,,,,,,,,,, -38300,0.6451633,1.0475212,,,,,,,,,,,,,, -38400,0.5655562,1.0851947,,,,,,,,,,,,,, -38500,0.78981483,1.1051381,,,,,,,,,,,,,, -38600,0.6390856,1.0840117,,,,,,,,,,,,,, -38700,0.61284035,1.0736579,,,,,,,,,,,,,, -38800,0.74759734,1.0666614,,,,,,,,,,,,,, -38900,0.590441,1.0969478,,,,,,,,,,,,,, -38930,,,0.13134545,0.051060376541874,0.39354557,0.1150062272512237,5348.0,0.2211299,0.0744013161903601,2472.0,33161.42921876907,36430.21849322319,33161.42921876907,3265.838086128235,1.232497215270996,0.0 -39000,0.7055197,1.0932578,,,,,,,,,,,,,, -39100,0.6466262,1.0961802,,,,,,,,,,,,,, -39200,0.68785137,1.0711496,,,,,,,,,,,,,, -39300,0.63307595,1.0602473,,,,,,,,,,,,,, -39400,0.5716041,1.0205352,,,,,,,,,,,,,, -39500,0.6410467,1.0963382,,,,,,,,,,,,,, -39600,0.8332738,1.125874,,,,,,,,,,,,,, -39700,0.6727653,1.0607053,,,,,,,,,,,,,, -39800,0.6547374,1.0702939,,,,,,,,,,,,,, -39900,0.79006654,1.0700684,,,,,,,,,,,,,, -40000,0.7917741,1.0845922,,,,,,,,,,,,,, -40100,0.8075675,1.0625209,,,,,,,,,,,,,, -40200,0.61705625,1.0504131,,,,,,,,,,,,,, -40300,0.5847125,1.0457606,,,,,,,,,,,,,, -40400,0.7247006,1.0419775,,,,,,,,,,,,,, -40500,0.64833665,1.0667311,,,,,,,,,,,,,, -40600,0.82807344,1.0667032,,,,,,,,,,,,,, -40629,,,0.11861537,0.0469573080715363,0.38333464,0.1137414677003581,5348.0,0.21496634,0.0726138971827839,2472.0,34601.42512226105,38008.17625498772,34601.42512226105,3403.660226583481,1.2956340312957764,0.0 -40700,0.5648111,1.0666698,,,,,,,,,,,,,, -40800,0.6018242,1.0974658,,,,,,,,,,,,,, -40900,0.61999035,1.0833138,,,,,,,,,,,,,, -41000,0.61532414,1.0737501,,,,,,,,,,,,,, -41100,0.6704471,1.0916777,,,,,,,,,,,,,, -41200,0.55105,1.0444491,,,,,,,,,,,,,, -41300,0.6118538,1.0676204,,,,,,,,,,,,,, -41400,0.59649664,1.0289135,,,,,,,,,,,,,, -41500,0.7711579,1.0646805,,,,,,,,,,,,,, -41600,0.69218546,1.1147534,,,,,,,,,,,,,, -41700,0.55663997,1.0594153,,,,,,,,,,,,,, -41800,0.6492468,1.023348,,,,,,,,,,,,,, -41900,0.6617255,1.0307981,,,,,,,,,,,,,, -42000,0.66012216,1.055512,,,,,,,,,,,,,, -42100,0.5705921,1.0841256,,,,,,,,,,,,,, -42200,0.6340205,1.0785524,,,,,,,,,,,,,, -42300,0.6191521,1.024026,,,,,,,,,,,,,, -42306,,,0.10947188,0.045603775520673,0.3845979,0.11188777431283,5348.0,0.21081606,0.0694452907602624,2472.0,36041.619644880295,39584.84460020065,36041.619644880295,3540.007214784622,1.3486223220825195,0.0 -42400,0.6550483,1.0852424,,,,,,,,,,,,,, -42500,0.66557753,1.0725108,,,,,,,,,,,,,, -42600,0.5935478,1.0419191,,,,,,,,,,,,,, -42700,0.65062785,1.0218751,,,,,,,,,,,,,, -42800,0.68277454,1.0443165,,,,,,,,,,,,,, -42900,0.58965504,1.0506982,,,,,,,,,,,,,, -43000,0.67602515,1.0692823,,,,,,,,,,,,,, -43100,0.634008,1.0346928,,,,,,,,,,,,,, -43200,0.9602574,1.1055684,,,,,,,,,,,,,, -43300,0.58803076,1.0334524,,,,,,,,,,,,,, -43400,0.5670265,1.0183972,,,,,,,,,,,,,, -43500,0.80221033,1.0963311,,,,,,,,,,,,,, -43600,0.66565764,0.9814571,,,,,,,,,,,,,, -43700,0.689149,1.0527271,,,,,,,,,,,,,, -43800,0.6176579,1.0720408,,,,,,,,,,,,,, -43900,0.67586094,1.0528585,,,,,,,,,,,,,, -44000,0.76389545,1.0412468,,,,,,,,,,,,,, -44006,,,0.13150723,0.0486614122977759,0.38188517,0.1114629695781882,5348.0,0.21202654,0.0708671013344707,2472.0,37481.56066060066,41161.54116153717,37481.56066060066,3676.626125335693,1.4088342189788818,0.0 -44100,0.6754455,1.0002624,,,,,,,,,,,,,, -44200,0.6104854,1.0359532,,,,,,,,,,,,,, -44300,0.64050096,1.0333246,,,,,,,,,,,,,, -44400,0.73057306,1.0471268,,,,,,,,,,,,,, -44500,0.8392003,1.0429835,,,,,,,,,,,,,, -44600,0.6984159,1.0614992,,,,,,,,,,,,,, -44700,0.6027176,1.049508,,,,,,,,,,,,,, -44800,0.6384572,1.0719911,,,,,,,,,,,,,, -44900,0.99413085,1.0071841,,,,,,,,,,,,,, -45000,0.56595355,1.011815,,,,,,,,,,,,,, -45100,0.6387747,1.0693179,,,,,,,,,,,,,, -45200,0.59842485,1.0649525,,,,,,,,,,,,,, -45300,0.5752027,0.98582375,,,,,,,,,,,,,, -45400,0.64889705,1.0252436,,,,,,,,,,,,,, -45500,0.674694,1.0336334,,,,,,,,,,,,,, -45600,0.8331981,1.0293968,,,,,,,,,,,,,, -45700,0.6871837,1.0258846,,,,,,,,,,,,,, -45720,,,0.08845052,0.0354513198104887,0.37528655,0.1092327447213184,5348.0,0.20303267,0.0677594296508439,2472.0,38921.45214056969,42739.1587574482,38921.45214056969,3814.218501806259,1.4656569957733154,0.0 -45800,0.68145144,1.0594959,,,,,,,,,,,,,, -45900,0.8667367,1.0453011,,,,,,,,,,,,,, -46000,0.7303604,1.0531304,,,,,,,,,,,,,, -46100,0.7329361,1.0169618,,,,,,,,,,,,,, -46200,0.60707545,1.0187594,,,,,,,,,,,,,, -46300,0.6459404,0.98139113,,,,,,,,,,,,,, -46400,0.70275646,1.0184497,,,,,,,,,,,,,, -46500,0.6385241,0.99288565,,,,,,,,,,,,,, -46600,0.6043491,0.9812966,,,,,,,,,,,,,, -46700,1.0369595,1.0259824,,,,,,,,,,,,,, -46800,0.62802577,1.0012672,,,,,,,,,,,,,, -46900,0.6034769,1.0586534,,,,,,,,,,,,,, -47000,1.2185422,1.0353082,,,,,,,,,,,,,, -47100,0.6150561,1.03392,,,,,,,,,,,,,, -47200,0.86107624,1.0527198,,,,,,,,,,,,,, -47300,0.6614577,1.0458881,,,,,,,,,,,,,, -47396,,,0.09425022,0.0379273029531541,0.37133262,0.1077169641908918,5348.0,0.20532721,0.0674547559563707,2472.0,40362.0363907814,44316.29981398583,40362.0363907814,3950.640738248825,1.5266199111938477,0.0 -47400,0.6439028,0.9953566,,,,,,,,,,,,,, -47500,0.72578377,1.0446826,,,,,,,,,,,,,, -47600,0.7720358,1.0005664,,,,,,,,,,,,,, -47700,0.7284093,1.0101016,,,,,,,,,,,,,, -47800,0.63701046,1.0397383,,,,,,,,,,,,,, -47900,0.7197431,0.9958561,,,,,,,,,,,,,, -48000,0.6257081,0.99510837,,,,,,,,,,,,,, -48100,0.8215084,1.0382249,,,,,,,,,,,,,, -48200,0.60410124,0.9889415,,,,,,,,,,,,,, -48300,0.6483383,0.98575014,,,,,,,,,,,,,, -48400,0.7168596,1.0951788,,,,,,,,,,,,,, -48500,0.7190543,1.0371379,,,,,,,,,,,,,, -48600,0.66834503,1.018832,,,,,,,,,,,,,, -48700,0.67608964,0.9961723,,,,,,,,,,,,,, -48800,0.638923,1.0429263,,,,,,,,,,,,,, -48900,0.75637543,1.0458336,,,,,,,,,,,,,, -49000,0.62746185,1.0054874,,,,,,,,,,,,,, -49093,,,0.11438162,0.0454195704267153,0.36344072,0.1061142917829247,5348.0,0.19545701,0.0655251558913736,2472.0,41801.95665073395,45891.41930747032,41801.95665073395,4085.706921339035,1.5839388370513916,0.0 -49100,0.87068236,0.98749,,,,,,,,,,,,,, -49200,0.651473,0.9725657,,,,,,,,,,,,,, -49300,0.5806245,1.0006701,,,,,,,,,,,,,, -49400,0.6655047,1.0165824,,,,,,,,,,,,,, -49500,0.8799348,0.9816451,,,,,,,,,,,,,, -49600,0.67821485,0.9810542,,,,,,,,,,,,,, -49700,0.6564465,0.98688596,,,,,,,,,,,,,, -49800,0.7544104,1.0144047,,,,,,,,,,,,,, -49900,0.7191868,1.0120193,,,,,,,,,,,,,, -50000,0.65364957,0.9619228,,,,,,,,,,,,,, -50100,0.84033793,1.0004622,,,,,,,,,,,,,, -50200,0.6994443,0.94043857,,,,,,,,,,,,,, -50300,0.8394764,1.0323614,,,,,,,,,,,,,, -50400,0.61553097,0.9430158,,,,,,,,,,,,,, -50500,0.6946717,1.024108,,,,,,,,,,,,,, -50600,0.60413074,0.96499187,,,,,,,,,,,,,, -50700,0.79200923,0.964557,,,,,,,,,,,,,, -50786,,,0.117886096,0.0461131535607938,0.360941,0.1038454483138148,5348.0,0.19716543,0.0640830337375337,2472.0,43242.03401255608,47467.48012375832,43242.03401255608,4221.55396938324,1.6436846256256104,0.0 -50800,0.74839085,1.008437,,,,,,,,,,,,,, -50900,0.7152194,0.96693146,,,,,,,,,,,,,, -51000,0.6541396,0.96383977,,,,,,,,,,,,,, -51100,0.67035455,0.90168065,,,,,,,,,,,,,, -51200,0.7173901,0.98834383,,,,,,,,,,,,,, -51300,0.7521891,1.0154581,,,,,,,,,,,,,, -51400,0.6621182,1.0051904,,,,,,,,,,,,,, -51500,0.70203257,0.9709142,,,,,,,,,,,,,, -51600,0.5675042,0.9787669,,,,,,,,,,,,,, -51700,0.701686,1.0182428,,,,,,,,,,,,,, -51800,0.5560744,0.97076505,,,,,,,,,,,,,, -51900,0.679988,0.95155674,,,,,,,,,,,,,, -52000,0.8154773,1.0341588,,,,,,,,,,,,,, -52100,0.6846823,0.9486356,,,,,,,,,,,,,, -52200,0.6491517,0.93860084,,,,,,,,,,,,,, -52300,0.59600455,0.959913,,,,,,,,,,,,,, -52400,0.7364766,0.98298097,,,,,,,,,,,,,, -52446,,,0.12866937,0.0515651104510859,0.3551715,0.1020110642324068,5348.0,0.19476175,0.0630065200170617,2472.0,44682.5635433197,49041.885318517685,44682.5635433197,4355.304067373276,1.6991455554962158,0.0 -52500,0.7141628,0.949273,,,,,,,,,,,,,, -52600,0.81467193,0.9781568,,,,,,,,,,,,,, -52700,0.79078704,0.9828595,,,,,,,,,,,,,, -52800,0.79014343,0.97249544,,,,,,,,,,,,,, -52900,0.9044107,0.93161,,,,,,,,,,,,,, -53000,0.694623,0.9945844,,,,,,,,,,,,,, -53100,0.7494659,0.99459344,,,,,,,,,,,,,, -53200,0.70718,0.9522135,,,,,,,,,,,,,, -53300,0.64431936,0.9797659,,,,,,,,,,,,,, -53400,0.859068,1.018329,,,,,,,,,,,,,, -53500,0.800948,0.9696847,,,,,,,,,,,,,, -53600,0.76109785,0.92568827,,,,,,,,,,,,,, -53700,0.8341487,0.91380686,,,,,,,,,,,,,, -53800,0.7293624,0.98757046,,,,,,,,,,,,,, -53900,0.72645634,0.98945695,,,,,,,,,,,,,, -54000,0.8120272,0.9808908,,,,,,,,,,,,,, -54100,1.0013293,0.93997043,,,,,,,,,,,,,, -54122,,,0.111043565,0.0426047658175842,0.3560681,0.1020400281915869,5348.0,0.19035426,0.0630065200170617,2472.0,46123.199763059616,50617.27509188652,46123.199763059616,4489.92017698288,1.7626543045043943,0.0 -54200,0.6491814,0.948506,,,,,,,,,,,,,, -54300,0.75824064,0.9293857,,,,,,,,,,,,,, -54400,0.5910126,0.93366224,,,,,,,,,,,,,, -54500,0.7435044,0.9880057,,,,,,,,,,,,,, -54600,0.72543055,0.89234865,,,,,,,,,,,,,, -54700,0.89003086,0.91775656,,,,,,,,,,,,,, -54800,0.71722656,0.9926499,,,,,,,,,,,,,, -54900,0.74215674,0.9312801,,,,,,,,,,,,,, -55000,0.61674786,0.93123883,,,,,,,,,,,,,, -55100,0.5923295,0.9369237,,,,,,,,,,,,,, -55200,0.62836033,0.9326364,,,,,,,,,,,,,, -55300,0.71774036,1.0329334,,,,,,,,,,,,,, -55400,0.8643996,0.9348357,,,,,,,,,,,,,, -55500,0.74267656,0.98278165,,,,,,,,,,,,,, -55600,0.7112723,0.9555249,,,,,,,,,,,,,, -55700,0.83547753,0.9731651,,,,,,,,,,,,,, -55800,0.7142565,0.9400157,,,,,,,,,,,,,, -55801,,,0.09821409,0.0392730909928691,0.3484534,0.0991629415796943,5348.0,0.18869571,0.0615034631243271,2472.0,47563.62722706795,52191.360865831375,47563.62722706795,4623.43222784996,1.8321900367736816,0.0 -55900,0.733892,0.9600018,,,,,,,,,,,,,, -56000,0.6327806,0.90188515,,,,,,,,,,,,,, -56100,0.8838162,1.0228825,,,,,,,,,,,,,, -56200,0.8099607,0.93894047,,,,,,,,,,,,,, -56300,0.7454163,0.9248286,,,,,,,,,,,,,, -56400,0.7733307,0.9148472,,,,,,,,,,,,,, -56500,0.619835,0.9341851,,,,,,,,,,,,,, -56600,0.74736947,0.92468953,,,,,,,,,,,,,, -56700,0.65282345,0.88775134,,,,,,,,,,,,,, -56800,0.71358204,0.924757,,,,,,,,,,,,,, -56900,0.6890889,0.9359341,,,,,,,,,,,,,, -57000,0.6799377,0.8857979,,,,,,,,,,,,,, -57100,0.666831,0.9381341,,,,,,,,,,,,,, -57200,0.8646014,0.9493428,,,,,,,,,,,,,, -57300,1.2323813,0.9614601,,,,,,,,,,,,,, -57400,0.62687844,0.9723669,,,,,,,,,,,,,, -57473,,,0.081379846,0.0325830120631762,0.34315622,0.0966237678249032,5348.0,0.18406248,0.0598785367538033,2472.0,49004.61514925957,53768.96395373344,49004.61514925957,4759.919205904007,1.887025594711304,0.0 -57500,0.68404806,0.928276,,,,,,,,,,,,,, -57600,0.7460336,0.93165773,,,,,,,,,,,,,, -57700,0.83782417,0.88081884,,,,,,,,,,,,,, -57800,0.77060574,0.8956276,,,,,,,,,,,,,, -57900,1.2625836,0.9206997,,,,,,,,,,,,,, -58000,0.925746,0.93055373,,,,,,,,,,,,,, -58100,0.89417297,0.92107743,,,,,,,,,,,,,, -58200,1.0961694,0.94582117,,,,,,,,,,,,,, -58300,0.7648321,1.0015417,,,,,,,,,,,,,, -58400,0.75309074,0.92634535,,,,,,,,,,,,,, -58500,0.7535281,0.90582085,,,,,,,,,,,,,, -58600,0.68900895,0.92730695,,,,,,,,,,,,,, -58700,0.8868058,0.9018335,,,,,,,,,,,,,, -58800,0.7575355,0.94667405,,,,,,,,,,,,,, -58900,0.7714675,0.88292915,,,,,,,,,,,,,, -59000,0.7340235,0.9430043,,,,,,,,,,,,,, -59100,0.7282586,0.88582957,,,,,,,,,,,,,, -59177,,,0.09032433,0.0367084051395039,0.34330922,0.0964886026820626,5348.0,0.18465194,0.0590254504092783,2472.0,50444.962716817856,55346.90739274025,50444.962716817856,4897.376502037048,1.9495971202850344,0.0 -59200,0.76743686,0.9196762,,,,,,,,,,,,,, -59300,0.7053401,0.984058,,,,,,,,,,,,,, -59400,0.667693,0.9778828,,,,,,,,,,,,,, -59500,0.6191672,0.9079376,,,,,,,,,,,,,, -59600,0.8216526,0.9409453,,,,,,,,,,,,,, -59700,0.77807647,0.8819812,,,,,,,,,,,,,, -59800,0.6944193,0.87582725,,,,,,,,,,,,,, -59900,0.7173201,0.9499778,,,,,,,,,,,,,, -60000,1.2869773,0.9153295,,,,,,,,,,,,,, -60100,0.90728045,0.899471,,,,,,,,,,,,,, -60200,0.79289585,0.9104863,,,,,,,,,,,,,, -60300,0.70488864,0.9088611,,,,,,,,,,,,,, -60400,0.7391437,0.9410956,,,,,,,,,,,,,, -60500,0.6588365,0.92743313,,,,,,,,,,,,,, -60600,0.7174004,0.91829914,,,,,,,,,,,,,, -60700,0.7438734,0.9374275,,,,,,,,,,,,,, -60800,0.9669827,0.9426464,,,,,,,,,,,,,, -60853,,,0.076652534,0.0303122660678002,0.33885056,0.0949245488863357,5348.0,0.18213382,0.0579083135295431,2472.0,51885.22932291031,56927.25577759743,51885.22932291031,5037.317294836044,2.0173354148864746,0.0 -60900,0.70021343,0.8876658,,,,,,,,,,,,,, -61000,0.65734285,0.93539953,,,,,,,,,,,,,, -61100,0.80570316,0.8888433,,,,,,,,,,,,,, -61200,0.68193656,0.8920339,,,,,,,,,,,,,, -61300,0.8622532,0.88675314,,,,,,,,,,,,,, -61400,0.72866356,0.85069126,,,,,,,,,,,,,, -61500,0.6772958,0.89481384,,,,,,,,,,,,,, -61600,0.76031786,0.9166474,,,,,,,,,,,,,, -61700,0.7055619,0.90620637,,,,,,,,,,,,,, -61800,0.6735383,0.91660047,,,,,,,,,,,,,, -61900,0.70183784,0.8592054,,,,,,,,,,,,,, -62000,0.7851539,0.8718181,,,,,,,,,,,,,, -62100,0.7328035,0.8725341,,,,,,,,,,,,,, -62200,0.7361255,0.9164847,,,,,,,,,,,,,, -62300,1.0747716,0.8836795,,,,,,,,,,,,,, -62400,0.7357288,0.88900214,,,,,,,,,,,,,, -62500,0.7209094,0.89327544,,,,,,,,,,,,,, -62543,,,0.07401633,0.0311466567874868,0.3358002,0.0939011556619712,5348.0,0.17903255,0.0567708650701765,2472.0,53326.08283543587,58502.95514130592,53326.08283543587,5172.018817186356,2.085672378540039,0.0 -62600,0.8787725,0.8914157,,,,,,,,,,,,,, -62700,0.8611964,0.91171044,,,,,,,,,,,,,, -62800,0.7521568,0.87498456,,,,,,,,,,,,,, -62900,0.7542636,0.9128681,,,,,,,,,,,,,, -63000,0.8482347,0.8812664,,,,,,,,,,,,,, -63100,0.7307759,0.8817807,,,,,,,,,,,,,, -63200,0.8783412,0.89979595,,,,,,,,,,,,,, -63300,0.70135343,0.8944932,,,,,,,,,,,,,, -63400,0.7796948,0.91566557,,,,,,,,,,,,,, -63500,0.71412593,0.87521684,,,,,,,,,,,,,, -63600,0.8355614,0.91671145,,,,,,,,,,,,,, -63700,0.80816406,0.85671836,,,,,,,,,,,,,, -63800,0.72039896,0.8387845,,,,,,,,,,,,,, -63900,0.6769848,0.9243061,,,,,,,,,,,,,, -64000,0.70489025,0.8724912,,,,,,,,,,,,,, -64100,0.67169607,0.8781216,,,,,,,,,,,,,, -64200,0.65774494,0.87684715,,,,,,,,,,,,,, -64246,,,0.06872757,0.0277908723154675,0.33404937,0.0938721917027911,5348.0,0.17863517,0.056791176649808,2472.0,54766.469485759735,60079.09964752197,54766.469485759735,5307.635042190552,2.152307748794556,0.0 -64300,0.6633666,0.89605254,,,,,,,,,,,,,, -64400,0.9713062,0.87811893,,,,,,,,,,,,,, -64500,0.71752405,0.8777287,,,,,,,,,,,,,, -64600,1.0949591,0.9105643,,,,,,,,,,,,,, -64700,0.94743234,0.91185486,,,,,,,,,,,,,, -64800,0.7860848,0.8697984,,,,,,,,,,,,,, -64900,0.924873,0.92012566,,,,,,,,,,,,,, -65000,1.0340639,0.9051635,,,,,,,,,,,,,, -65100,0.90245694,0.9465197,,,,,,,,,,,,,, -65200,0.8471844,0.89240444,,,,,,,,,,,,,, -65300,0.64472497,0.8866327,,,,,,,,,,,,,, -65400,0.8029539,0.8610492,,,,,,,,,,,,,, -65500,0.73311937,0.84292054,,,,,,,,,,,,,, -65600,1.118846,0.85999125,,,,,,,,,,,,,, -65700,0.851139,0.88673997,,,,,,,,,,,,,, -65800,1.3565806,0.9514309,,,,,,,,,,,,,, -65900,0.84802634,0.9149795,,,,,,,,,,,,,, -65915,,,0.07059194,0.0279766860949208,0.32888624,0.092018498315263,5348.0,0.17421615,0.055308431336705,2472.0,56206.48468470573,61654.99394798279,56206.48468470573,5443.378481864929,2.2151148319244385,0.0 -66000,0.88820624,0.86654496,,,,,,,,,,,,,, -66100,0.7209201,0.86847305,,,,,,,,,,,,,, -66200,0.97331464,0.8231773,,,,,,,,,,,,,, -66300,1.3257278,0.84258544,,,,,,,,,,,,,, -66400,0.6604135,0.862137,,,,,,,,,,,,,, -66500,0.9200163,0.8924304,,,,,,,,,,,,,, -66600,0.82586336,0.8791917,,,,,,,,,,,,,, -66700,0.86752695,0.8355906,,,,,,,,,,,,,, -66800,0.6979774,0.83644056,,,,,,,,,,,,,, -66900,0.85586447,0.858346,,,,,,,,,,,,,, -67000,0.693603,0.8302991,,,,,,,,,,,,,, -67100,0.88259965,0.83086896,,,,,,,,,,,,,, -67200,0.9183953,0.8572434,,,,,,,,,,,,,, -67300,0.87591827,0.8609035,,,,,,,,,,,,,, -67400,0.9403868,0.88531744,,,,,,,,,,,,,, -67500,0.8703572,0.8539512,,,,,,,,,,,,,, -67587,,,0.07291302,0.0287140800271538,0.32705754,0.0905123724378964,5348.0,0.17488703,0.0552068734385473,2472.0,57647.05547738075,63231.67115473747,57647.05547738075,5579.349866390228,2.277138948440552,0.0 -67600,0.96672404,0.8775531,,,,,,,,,,,,,, -67700,1.0544637,0.8913003,,,,,,,,,,,,,, -67800,0.827777,0.8760582,,,,,,,,,,,,,, -67900,1.1711329,0.8296404,,,,,,,,,,,,,, -68000,1.0182246,0.8348369,,,,,,,,,,,,,, -68100,0.7767883,0.8939506,,,,,,,,,,,,,, -68200,0.97493684,0.8635233,,,,,,,,,,,,,, -68300,0.9533008,0.87024003,,,,,,,,,,,,,, -68400,0.87926596,0.8529648,,,,,,,,,,,,,, -68500,0.76838374,0.8208502,,,,,,,,,,,,,, -68600,0.7636569,0.8620134,,,,,,,,,,,,,, -68700,1.0836555,0.81897056,,,,,,,,,,,,,, -68800,0.9803311,0.8633832,,,,,,,,,,,,,, -68900,0.7916733,0.86623305,,,,,,,,,,,,,, -69000,0.7488301,0.83503634,,,,,,,,,,,,,, -69100,0.67047715,0.85551304,,,,,,,,,,,,,, -69200,1.1284392,0.8836228,,,,,,,,,,,,,, -69280,,,0.060553554,0.0240735475310397,0.329274,0.0914102551724803,5348.0,0.1747135,0.0554506123941258,2472.0,59087.30037307739,64807.39521574974,59087.30037307739,5714.688037395477,2.3428428173065186,0.0 -69300,0.83027,0.8275457,,,,,,,,,,,,,, -69400,0.8021144,0.88415277,,,,,,,,,,,,,, -69500,0.99604875,0.86745226,,,,,,,,,,,,,, -69600,0.7070279,0.85872954,,,,,,,,,,,,,, -69700,0.95793515,0.8268392,,,,,,,,,,,,,, -69800,0.72330725,0.80327415,,,,,,,,,,,,,, -69900,0.8440853,0.83773416,,,,,,,,,,,,,, -70000,0.81281567,0.8129536,,,,,,,,,,,,,, -70100,1.8709161,0.85021067,,,,,,,,,,,,,, -70200,0.8235543,0.83609957,,,,,,,,,,,,,, -70300,1.0553048,0.8416355,,,,,,,,,,,,,, -70400,0.71938413,0.855932,,,,,,,,,,,,,, -70500,0.7650009,0.8862864,,,,,,,,,,,,,, -70600,0.66972667,0.86247456,,,,,,,,,,,,,, -70700,1.0020465,0.8625535,,,,,,,,,,,,,, -70800,0.67762357,0.8824595,,,,,,,,,,,,,, -70900,0.84407884,0.84573597,,,,,,,,,,,,,, -70958,,,0.06268328,0.0241543389096271,0.3256981,0.0898944746420537,5348.0,0.17322181,0.0546178376292324,2472.0,60527.92716932297,66382.81089735031,60527.92716932297,5849.344304800034,2.4029757976531982,0.0 -71000,1.1638241,0.886142,,,,,,,,,,,,,, -71100,0.75604934,0.8408685,,,,,,,,,,,,,, -71200,1.0618784,0.8796219,,,,,,,,,,,,,, -71300,0.8368022,0.85009956,,,,,,,,,,,,,, -71400,0.751754,0.86467546,,,,,,,,,,,,,, -71500,0.7519028,0.86581427,,,,,,,,,,,,,, -71600,1.4136654,0.81443995,,,,,,,,,,,,,, -71611,,,,,,,,,,,61068.90940284729,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 9b04a71ee..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -143.2282350063324,0.0,34.168030738830566,1,0,34.168030738830566,31.279491,2472,1.0976986980277457,177.3963179588318,32.681545,1.395902589872439,31.163591,5348,1.043156299178389 -253.6914649009705,0.0291051864624023,1474.7652645111084,1670,0,1474.7652645111084,6.4492598,2472,0.899579550301627,1728.55646276474,6.3159585,0.9391896477614642,6.4463353,5348,0.8966179750330672 -370.795667886734,0.0799777507781982,2914.80074095726,3370,0,2914.80074095726,3.4893599,2472,0.776105457721447,3285.823149681092,3.8715672,0.8640947888589398,3.8034496,5348,0.8164843546347162 -505.2781677246094,0.1326217651367187,4354.720586776733,5031,0,4354.720586776733,0.91928077,2472,0.2884244307679808,4860.35196018219,1.1874629,0.3598878635673402,1.2110256,5348,0.3440339071415468 -642.7062382698059,0.1861689090728759,5795.25229716301,6703,0,5795.25229716301,0.5802292,2472,0.1959661202851745,6438.438871383667,0.7139498,0.2382428648205244,0.84304017,5348,0.2546994023769756 -779.3713698387146,0.2369072437286377,7235.874894618988,8372,0,7235.874894618988,0.50354314,2472,0.1695204436048991,8015.850210905075,0.5976577,0.1992043259062796,0.7593585,5348,0.2255616594417679 -917.6057934761049,0.2905178070068359,8676.452405929565,10060,0,8676.452405929565,0.44782707,2472,0.1513415798346637,9594.790989875792,0.58726174,0.2015751037064896,0.6942416,5348,0.2111665717292449 -1052.4742050170898,0.3399777412414551,10116.661395788193,11776,0,10116.661395788193,0.40460023,2472,0.137265655150001,11169.995291233065,0.5025373,0.1714755275350288,0.6410851,5348,0.1946764243026927 -1190.592571258545,0.3921723365783691,11556.729298353195,13468,0,11556.729298353195,0.38324997,2472,0.1325330570958503,12748.30952501297,0.49499294,0.1661610499241574,0.6108735,5348,0.1865858250383772 -1325.306384563446,0.4475064277648926,12996.726741552353,15157,0,12996.726741552353,0.3680178,2472,0.1252412000081246,14323.149782896042,0.41451362,0.1453697549614672,0.5853169,5348,0.1794413817739459 -1461.0907878875732,0.5039346218109131,14436.753607988358,16871,0,14436.753607988358,0.34979352,2472,0.1208132756484471,15899.094260454178,0.41015854,0.1461925696453384,0.56820166,5348,0.1730886200604381 -1596.6722741127014,0.5620999336242676,15877.172675848007,18548,0,15877.172675848007,0.33841434,2472,0.1152682144090345,17475.227288007736,0.3982619,0.1411061803467432,0.55043125,5348,0.1676240864284542 -1743.6689743995669,0.6150286197662354,17317.922025680542,20273,0,17317.922025680542,0.32714456,2472,0.1112871448012512,19063.102712631226,0.26937568,0.0986536762934611,0.54442644,5348,0.1646794172451413 -1871.4549486637115,0.674034595489502,18758.11255025864,22022,0,18758.11255025864,0.31976706,2472,0.1091544289399386,20631.21089434624,0.25267893,0.0920306252131056,0.5279306,5348,0.1606244629599235 -1997.80451631546,0.7282936573028564,20198.68664932251,23874,0,20198.68664932251,0.31226346,2472,0.1043405845672618,22198.257422208782,0.24335396,0.0904516271512627,0.51649487,5348,0.1568688029195671 -2125.4988000392914,0.7810320854187012,21638.83458685875,25727,0,21638.83458685875,0.2979179,2472,0.1024516076615278,23766.222700834274,0.22757375,0.0857738831471038,0.5035488,5348,0.1526786834915087 -2253.6334249973297,0.829434871673584,23079.49457478524,27582,0,23079.49457478524,0.2894307,2472,0.0979221254036926,25335.13685107231,0.22518203,0.0836282808468902,0.49105188,5348,0.1481313419002288 -2383.031873941421,0.8874096870422363,24520.111387252808,29437,0,24520.111387252808,0.27920473,2472,0.0958503442812747,26905.28270840645,0.20949112,0.0790651574502185,0.47871414,5348,0.1447425586761539 -2510.4499428272247,0.9405355453491212,25960.74817752838,31289,0,25960.74817752838,0.28004083,2472,0.0949972579367497,28473.462604045868,0.24946217,0.0872451837069823,0.4818028,5348,0.144076387615011 -2640.090026140213,0.9931390285491944,27400.628918409348,33137,0,27400.628918409348,0.27050698,2472,0.0911989925456502,30043.10976457596,0.20576166,0.0775823398719714,0.46390292,5348,0.1377043165953831 -2770.963233947754,1.047239065170288,28841.17727303505,34975,0,28841.17727303505,0.26218984,2472,0.0894521966973371,31614.65806341172,0.18593837,0.0706608123632815,0.44671303,5348,0.1354354731262732 -2901.019725084305,1.099900484085083,30281.21803665161,36813,0,30281.21803665161,0.25743183,2472,0.0867304450267097,33184.88313293457,0.1798886,0.0659792170515853,0.4445087,5348,0.1318246328818174 -3030.02846121788,1.1584982872009275,31721.667903661728,38657,0,31721.667903661728,0.24864013,2472,0.0843946133690817,34754.47449827194,0.17082553,0.0660368217054263,0.43620944,5348,0.1306853838207324 -3159.066329717636,1.2179243564605713,33161.69714832306,40503,0,33161.69714832306,0.2497961,2472,0.0837852659801352,36323.67623138428,0.18386379,0.0680838521902516,0.43003973,5348,0.1276827867190592 -3288.757158517837,1.272374153137207,34602.18103837967,42336,0,34602.18103837967,0.23499578,2472,0.0797432616334572,37893.97904467583,0.17193142,0.0642241447399598,0.41626227,5348,0.1248346640663467 -3419.2407870292664,1.3307726383209229,36043.03625369072,44177,0,36043.03625369072,0.23374555,2472,0.080068246907562,39465.451370716095,0.15760088,0.0605369308460445,0.41124475,5348,0.1217162111279531 -3549.8614530563354,1.3871524333953855,37483.32582259178,46011,0,37483.32582259178,0.22595222,2472,0.075437206751569,41036.49385213852,0.14866316,0.0575584215110063,0.39693525,5348,0.118269499985518 -3679.6286759376526,1.4453561305999756,38923.7791454792,47854,0,38923.7791454792,0.219991,2472,0.072654520342047,42606.84801912308,0.1336802,0.0514337012198341,0.39036846,5348,0.1155951610878863 -3810.182657241821,1.5044760704040527,40363.92171168327,49697,0,40363.92171168327,0.21653153,2472,0.0710092823918916,44177.68133687973,0.13503304,0.0513602419115919,0.38563406,5348,0.1136931944350579 -3940.1016092300415,1.5616471767425537,41803.84714341164,51530,0,41803.84714341164,0.20964031,2472,0.0716389413604696,45747.65846896172,0.13680682,0.0523486669727992,0.37866125,5348,0.1118588103536499 -4069.81272816658,1.6211626529693604,43243.965631484985,53373,0,43243.965631484985,0.2012541,2472,0.0681859728231064,47317.6235909462,0.11914682,0.0452420368083603,0.36934325,5348,0.1089817237417573 -4201.730570077896,1.6765682697296145,44683.96534562111,55209,0,44683.96534562111,0.19781065,2472,0.0661548148599516,48889.67320728302,0.11970915,0.0470404353538028,0.35692453,5348,0.1047819496606389 -4332.178327083588,1.7337017059326172,46123.9932820797,57047,0,46123.9932820797,0.19295095,2472,0.0630268315966932,50460.28098845482,0.13444948,0.0475382545195381,0.3579548,5348,0.1044440368035374 -4464.339418172836,1.7940006256103516,47564.11428070068,58872,0,47564.11428070068,0.18680844,2472,0.0614222168058009,52032.69997930527,0.08747087,0.0343877048170506,0.34598675,5348,0.0991919055388744 -4596.614463329315,1.8496484756469729,49004.76940536499,60697,0,49004.76940536499,0.17877539,2472,0.0604675725631182,53605.760281562805,0.086636014,0.0333267446215305,0.32997072,5348,0.0964210201106423 -4726.184319496155,1.9108588695526123,50444.94024038315,62523,0,50444.94024038315,0.17875016,2472,0.0592691893648569,55175.63874554634,0.10608669,0.0416625411179686,0.32664424,5348,0.0953783175801577 -4856.205128669739,1.9761414527893064,51885.154266119,64357,0,51885.154266119,0.17223309,2472,0.0569942924461235,56746.01544976234,0.10365024,0.0396321834020283,0.31951347,5348,0.0919122971316025 -4984.438056707382,2.034407615661621,53325.56930327416,66182,0,53325.56930327416,0.16681975,2472,0.0547397071070217,58314.800549030304,0.11834669,0.045590728569106,0.3103893,5348,0.0891317570503104 -5113.788911581039,2.122683525085449,54765.563039541245,67997,0,54765.563039541245,0.16329446,2472,0.0533178965328133,59884.31038093567,0.09386895,0.0346327432593597,0.30321878,5348,0.0869111868465006 -5243.618288755417,2.180286169052124,56205.50395298004,69819,0,56205.50395298004,0.1599037,2472,0.0525257449271829,61454.21480035782,0.09016636,0.0354465941015229,0.30152643,5348,0.0854147156221941 -5376.695669412613,2.2511696815490723,57645.81762099266,71644,0,57645.81762099266,0.15629764,2472,0.0508195722381329,63027.75369310379,0.0650331,0.0247357237001371,0.29494342,5348,0.0837830792550469 -5506.537875413895,2.314791440963745,59086.080971241,73472,0,59086.080971241,0.1558389,2472,0.0498649279954502,64598.00075960159,0.07749713,0.0303396862581965,0.29196256,5348,0.0827693406837425 -5638.150643587112,2.375697612762451,60526.34147930145,75286,0,60526.34147930145,0.15373042,2472,0.04909308796945138,66170.01026082039,0.067730136,0.024790915163660655,0.289029,5348,0.08147561717369686 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/measurements.csv deleted file mode 100644 index 9fac7a28a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/measurements.csv +++ /dev/null @@ -1,805 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,46.78699,32.649097,,,,,,,,,,,,,, -1,,,32.681545,1.395902589872439,31.163591,1.043156299178389,5348.0,31.279491,1.0976986980277457,2472.0,34.168030738830566,177.3963179588318,34.168030738830566,143.2282350063324,0.0,0.0 -100,23.842932,8.101322,,,,,,,,,,,,,, -200,1.5977552,6.215599,,,,,,,,,,,,,, -300,0.9557819,5.8622637,,,,,,,,,,,,,, -400,0.76085234,5.853574,,,,,,,,,,,,,, -500,0.6266219,5.818777,,,,,,,,,,,,,, -600,0.429385,5.806005,,,,,,,,,,,,,, -700,0.35267285,5.825949,,,,,,,,,,,,,, -800,0.27050275,5.820455,,,,,,,,,,,,,, -900,0.39917737,5.8169327,,,,,,,,,,,,,, -1000,0.29937616,5.794168,,,,,,,,,,,,,, -1100,0.50965524,5.7841334,,,,,,,,,,,,,, -1200,0.6082611,5.786364,,,,,,,,,,,,,, -1300,0.515797,5.7696285,,,,,,,,,,,,,, -1400,0.29569462,5.772766,,,,,,,,,,,,,, -1500,0.8285872,5.786098,,,,,,,,,,,,,, -1600,0.33940473,5.735037,,,,,,,,,,,,,, -1670,,,6.3159585,0.9391896477614642,6.4463353,0.8966179750330672,5348.0,6.4492598,0.899579550301627,2472.0,1474.7652645111084,1728.55646276474,1474.7652645111084,253.6914649009705,0.0291051864624023,0.0 -1700,0.7465609,5.5942802,,,,,,,,,,,,,, -1800,0.694729,5.520156,,,,,,,,,,,,,, -1900,0.61070883,5.4426394,,,,,,,,,,,,,, -2000,1.3527675,5.212208,,,,,,,,,,,,,, -2100,0.9351999,4.5932293,,,,,,,,,,,,,, -2200,0.7975956,4.0874014,,,,,,,,,,,,,, -2300,1.6879997,3.722607,,,,,,,,,,,,,, -2400,1.22722,3.497296,,,,,,,,,,,,,, -2500,1.2280827,3.3185563,,,,,,,,,,,,,, -2600,0.89729726,3.221447,,,,,,,,,,,,,, -2700,1.2256725,3.0626097,,,,,,,,,,,,,, -2800,1.1212193,2.9777732,,,,,,,,,,,,,, -2900,1.0572062,2.930581,,,,,,,,,,,,,, -3000,1.6415724,2.8380296,,,,,,,,,,,,,, -3100,1.0922638,2.6764257,,,,,,,,,,,,,, -3200,0.92071813,2.6800663,,,,,,,,,,,,,, -3300,1.4875884,2.6679819,,,,,,,,,,,,,, -3370,,,3.8715672,0.8640947888589398,3.8034496,0.8164843546347162,5348.0,3.4893599,0.776105457721447,2472.0,2914.80074095726,3285.823149681092,2914.80074095726,370.795667886734,0.0799777507781982,0.0 -3400,1.0349046,2.593176,,,,,,,,,,,,,, -3500,0.96670955,2.5261133,,,,,,,,,,,,,, -3600,0.936141,2.4748847,,,,,,,,,,,,,, -3700,0.8985722,2.4248078,,,,,,,,,,,,,, -3800,0.9440784,2.3815799,,,,,,,,,,,,,, -3900,0.93365884,2.370891,,,,,,,,,,,,,, -4000,0.8798872,2.374389,,,,,,,,,,,,,, -4100,0.9812135,2.2498279,,,,,,,,,,,,,, -4200,1.1101271,2.311351,,,,,,,,,,,,,, -4300,0.8731418,2.1667957,,,,,,,,,,,,,, -4400,0.8138109,2.189381,,,,,,,,,,,,,, -4500,0.98397046,2.1582599,,,,,,,,,,,,,, -4600,0.8755569,2.0994568,,,,,,,,,,,,,, -4700,0.8403401,2.1124246,,,,,,,,,,,,,, -4800,1.0418512,2.0973272,,,,,,,,,,,,,, -4900,0.88431597,2.0227876,,,,,,,,,,,,,, -5000,0.83343667,2.0502648,,,,,,,,,,,,,, -5031,,,1.1874629,0.3598878635673402,1.2110256,0.3440339071415468,5348.0,0.91928077,0.2884244307679808,2472.0,4354.720586776733,4860.35196018219,4354.720586776733,505.2781677246094,0.1326217651367187,0.0 -5100,0.84168816,2.0518863,,,,,,,,,,,,,, -5200,0.7888951,1.9543777,,,,,,,,,,,,,, -5300,0.9085081,1.946389,,,,,,,,,,,,,, -5400,0.8095585,1.9769697,,,,,,,,,,,,,, -5500,0.78373545,1.884627,,,,,,,,,,,,,, -5600,0.7233463,1.8961855,,,,,,,,,,,,,, -5700,0.8453081,1.8854874,,,,,,,,,,,,,, -5800,0.8949396,1.8803686,,,,,,,,,,,,,, -5900,1.0032357,1.8645335,,,,,,,,,,,,,, -6000,0.80961734,1.8359623,,,,,,,,,,,,,, -6100,0.7631609,1.8318913,,,,,,,,,,,,,, -6200,0.71731585,1.777987,,,,,,,,,,,,,, -6300,0.8935853,1.7956154,,,,,,,,,,,,,, -6400,0.7955782,1.9107195,,,,,,,,,,,,,, -6500,0.78051305,1.7652191,,,,,,,,,,,,,, -6600,0.72692347,1.7873771,,,,,,,,,,,,,, -6700,0.9405532,1.7877635,,,,,,,,,,,,,, -6703,,,0.7139498,0.2382428648205244,0.84304017,0.2546994023769756,5348.0,0.5802292,0.1959661202851745,2472.0,5795.25229716301,6438.438871383667,5795.25229716301,642.7062382698059,0.1861689090728759,0.0 -6800,0.8000329,1.7531193,,,,,,,,,,,,,, -6900,0.8314924,1.8005029,,,,,,,,,,,,,, -7000,0.70989,1.8184711,,,,,,,,,,,,,, -7100,0.7508831,1.7787932,,,,,,,,,,,,,, -7200,0.75504166,1.762128,,,,,,,,,,,,,, -7300,0.78731704,1.762769,,,,,,,,,,,,,, -7400,0.7659558,1.6817874,,,,,,,,,,,,,, -7500,0.81334025,1.6556218,,,,,,,,,,,,,, -7600,0.70403093,1.7689753,,,,,,,,,,,,,, -7700,0.7822553,1.7638524,,,,,,,,,,,,,, -7800,0.79378116,1.685101,,,,,,,,,,,,,, -7900,0.83242923,1.6837705,,,,,,,,,,,,,, -8000,0.6582384,1.7488266,,,,,,,,,,,,,, -8100,0.81619567,1.6843319,,,,,,,,,,,,,, -8200,0.759721,1.7456098,,,,,,,,,,,,,, -8300,0.7330775,1.6698946,,,,,,,,,,,,,, -8372,,,0.5976577,0.1992043259062796,0.7593585,0.2255616594417679,5348.0,0.50354314,0.1695204436048991,2472.0,7235.874894618988,8015.850210905075,7235.874894618988,779.3713698387146,0.2369072437286377,0.0 -8400,0.7188286,1.7042708,,,,,,,,,,,,,, -8500,0.92532223,1.7343745,,,,,,,,,,,,,, -8600,0.6833287,1.6472933,,,,,,,,,,,,,, -8700,0.83596575,1.6629655,,,,,,,,,,,,,, -8800,0.67690164,1.6546854,,,,,,,,,,,,,, -8900,0.66505194,1.6082133,,,,,,,,,,,,,, -9000,0.681519,1.6903801,,,,,,,,,,,,,, -9100,0.7620971,1.6240457,,,,,,,,,,,,,, -9200,0.71857584,1.6159896,,,,,,,,,,,,,, -9300,0.79227483,1.5915264,,,,,,,,,,,,,, -9400,0.6653058,1.6481324,,,,,,,,,,,,,, -9500,0.7292107,1.6394638,,,,,,,,,,,,,, -9600,0.69322044,1.608804,,,,,,,,,,,,,, -9700,0.77563685,1.6543779,,,,,,,,,,,,,, -9800,0.64223444,1.6438564,,,,,,,,,,,,,, -9900,0.63687265,1.5980451,,,,,,,,,,,,,, -10000,0.8176311,1.596282,,,,,,,,,,,,,, -10060,,,0.58726174,0.2015751037064896,0.6942416,0.2111665717292449,5348.0,0.44782707,0.1513415798346637,2472.0,8676.452405929565,9594.790989875792,8676.452405929565,917.6057934761049,0.2905178070068359,0.0 -10100,0.8285124,1.5652617,,,,,,,,,,,,,, -10200,0.6516652,1.5958735,,,,,,,,,,,,,, -10300,0.96345407,1.5656227,,,,,,,,,,,,,, -10400,0.779121,1.5770128,,,,,,,,,,,,,, -10500,0.79173535,1.5734564,,,,,,,,,,,,,, -10600,0.85728395,1.5667489,,,,,,,,,,,,,, -10700,0.5848513,1.5607145,,,,,,,,,,,,,, -10800,0.68343437,1.5789208,,,,,,,,,,,,,, -10900,0.75242865,1.5841974,,,,,,,,,,,,,, -11000,0.66967505,1.5670398,,,,,,,,,,,,,, -11100,0.79760635,1.5305849,,,,,,,,,,,,,, -11200,0.6733386,1.595261,,,,,,,,,,,,,, -11300,0.65545243,1.5724161,,,,,,,,,,,,,, -11400,0.7284443,1.5026131,,,,,,,,,,,,,, -11500,0.72956777,1.5338464,,,,,,,,,,,,,, -11600,0.6954596,1.5205915,,,,,,,,,,,,,, -11700,0.6856737,1.5793496,,,,,,,,,,,,,, -11776,,,0.5025373,0.1714755275350288,0.6410851,0.1946764243026927,5348.0,0.40460023,0.137265655150001,2472.0,10116.661395788193,11169.995291233065,10116.661395788193,1052.4742050170898,0.3399777412414551,0.0 -11800,0.63298833,1.5284966,,,,,,,,,,,,,, -11900,0.7352866,1.5864064,,,,,,,,,,,,,, -12000,0.83011293,1.4933437,,,,,,,,,,,,,, -12100,0.7746292,1.5785412,,,,,,,,,,,,,, -12200,0.67437977,1.494165,,,,,,,,,,,,,, -12300,0.6367988,1.4982178,,,,,,,,,,,,,, -12400,0.6797558,1.5176415,,,,,,,,,,,,,, -12500,0.8760135,1.4811391,,,,,,,,,,,,,, -12600,0.64551425,1.5760759,,,,,,,,,,,,,, -12700,0.5811667,1.5022436,,,,,,,,,,,,,, -12800,0.7511257,1.50532,,,,,,,,,,,,,, -12900,0.7359003,1.4991057,,,,,,,,,,,,,, -13000,0.6963859,1.5150836,,,,,,,,,,,,,, -13100,0.7065182,1.4970831,,,,,,,,,,,,,, -13200,0.68850005,1.4640721,,,,,,,,,,,,,, -13300,0.6030451,1.4616411,,,,,,,,,,,,,, -13400,0.7995356,1.483195,,,,,,,,,,,,,, -13468,,,0.49499294,0.1661610499241574,0.6108735,0.1865858250383772,5348.0,0.38324997,0.1325330570958503,2472.0,11556.729298353195,12748.30952501297,11556.729298353195,1190.592571258545,0.3921723365783691,0.0 -13500,0.7125292,1.5687199,,,,,,,,,,,,,, -13600,0.6945112,1.5347296,,,,,,,,,,,,,, -13700,0.6978768,1.4581381,,,,,,,,,,,,,, -13800,0.6464116,1.5057585,,,,,,,,,,,,,, -13900,0.74918026,1.4690574,,,,,,,,,,,,,, -14000,0.66741437,1.4749064,,,,,,,,,,,,,, -14100,0.66127676,1.5788182,,,,,,,,,,,,,, -14200,0.6512405,1.4991312,,,,,,,,,,,,,, -14300,0.7162598,1.4366457,,,,,,,,,,,,,, -14400,0.79963195,1.4477974,,,,,,,,,,,,,, -14500,0.71824837,1.422137,,,,,,,,,,,,,, -14600,0.807218,1.3961098,,,,,,,,,,,,,, -14700,0.6054241,1.4793859,,,,,,,,,,,,,, -14800,0.75260043,1.4564685,,,,,,,,,,,,,, -14900,0.6503023,1.4446867,,,,,,,,,,,,,, -15000,0.6953244,1.4554074,,,,,,,,,,,,,, -15100,0.68724036,1.4358045,,,,,,,,,,,,,, -15157,,,0.41451362,0.1453697549614672,0.5853169,0.1794413817739459,5348.0,0.3680178,0.1252412000081246,2472.0,12996.726741552353,14323.149782896042,12996.726741552353,1325.306384563446,0.4475064277648926,0.0 -15200,0.65427107,1.4747128,,,,,,,,,,,,,, -15300,0.6338046,1.4670883,,,,,,,,,,,,,, -15400,0.67002636,1.4391979,,,,,,,,,,,,,, -15500,0.67226297,1.4279803,,,,,,,,,,,,,, -15600,0.73222685,1.4502889,,,,,,,,,,,,,, -15700,0.606436,1.4750122,,,,,,,,,,,,,, -15800,0.6697538,1.4610803,,,,,,,,,,,,,, -15900,0.5758814,1.4054322,,,,,,,,,,,,,, -16000,0.6214487,1.4041843,,,,,,,,,,,,,, -16100,1.0029365,1.4638948,,,,,,,,,,,,,, -16200,0.64150614,1.4233695,,,,,,,,,,,,,, -16300,0.74347216,1.412016,,,,,,,,,,,,,, -16400,0.70490015,1.455218,,,,,,,,,,,,,, -16500,0.73155326,1.3845423,,,,,,,,,,,,,, -16600,0.8774453,1.4062846,,,,,,,,,,,,,, -16700,0.66500694,1.4152128,,,,,,,,,,,,,, -16800,0.6832048,1.4285603,,,,,,,,,,,,,, -16871,,,0.41015854,0.1461925696453384,0.56820166,0.1730886200604381,5348.0,0.34979352,0.1208132756484471,2472.0,14436.753607988358,15899.094260454178,14436.753607988358,1461.0907878875732,0.5039346218109131,0.0 -16900,0.8452373,1.454927,,,,,,,,,,,,,, -17000,0.65764207,1.4540216,,,,,,,,,,,,,, -17100,0.6267493,1.4525099,,,,,,,,,,,,,, -17200,0.6524232,1.4094223,,,,,,,,,,,,,, -17300,0.5651972,1.4074703,,,,,,,,,,,,,, -17400,0.67975026,1.3913274,,,,,,,,,,,,,, -17500,0.6775586,1.4446547,,,,,,,,,,,,,, -17600,0.6289499,1.4163668,,,,,,,,,,,,,, -17700,0.69531333,1.3718412,,,,,,,,,,,,,, -17800,0.68224233,1.4517714,,,,,,,,,,,,,, -17900,0.68545485,1.4801303,,,,,,,,,,,,,, -18000,0.59258986,1.4238796,,,,,,,,,,,,,, -18100,0.7227901,1.3823402,,,,,,,,,,,,,, -18200,0.618559,1.4506538,,,,,,,,,,,,,, -18300,0.73461956,1.435401,,,,,,,,,,,,,, -18400,0.6631561,1.458918,,,,,,,,,,,,,, -18500,0.634319,1.4333266,,,,,,,,,,,,,, -18548,,,0.3982619,0.1411061803467432,0.55043125,0.1676240864284542,5348.0,0.33841434,0.1152682144090345,2472.0,15877.172675848007,17475.227288007736,15877.172675848007,1596.6722741127014,0.5620999336242676,0.0 -18600,0.70095456,1.4449507,,,,,,,,,,,,,, -18700,0.59315485,1.3965795,,,,,,,,,,,,,, -18800,0.6521071,1.4890078,,,,,,,,,,,,,, -18900,0.62807363,1.3750421,,,,,,,,,,,,,, -19000,0.67401725,1.3977048,,,,,,,,,,,,,, -19100,0.6363665,1.4226937,,,,,,,,,,,,,, -19200,0.60647213,1.3976104,,,,,,,,,,,,,, -19300,0.6058859,1.4009508,,,,,,,,,,,,,, -19400,0.71784383,1.4160903,,,,,,,,,,,,,, -19500,0.61046785,1.4077991,,,,,,,,,,,,,, -19600,0.67265457,1.3983588,,,,,,,,,,,,,, -19700,0.7014266,1.3908287,,,,,,,,,,,,,, -19800,0.68495995,1.3491027,,,,,,,,,,,,,, -19900,0.6143938,1.4061555,,,,,,,,,,,,,, -20000,0.79255354,1.4457759,,,,,,,,,,,,,, -20100,0.6354329,1.4320724,,,,,,,,,,,,,, -20200,0.6583009,1.4272327,,,,,,,,,,,,,, -20273,,,0.26937568,0.0986536762934611,0.54442644,0.1646794172451413,5348.0,0.32714456,0.1112871448012512,2472.0,17317.922025680542,19063.102712631226,17317.922025680542,1743.6689743995669,0.6150286197662354,0.0 -20300,0.62020767,1.4066949,,,,,,,,,,,,,, -20400,0.6873127,1.4434109,,,,,,,,,,,,,, -20500,0.9269429,1.389416,,,,,,,,,,,,,, -20600,0.76900625,1.3989742,,,,,,,,,,,,,, -20700,0.6260831,1.3804374,,,,,,,,,,,,,, -20800,0.6692454,1.38805,,,,,,,,,,,,,, -20900,0.56722337,1.3863441,,,,,,,,,,,,,, -21000,0.7020585,1.3918645,,,,,,,,,,,,,, -21100,0.65145844,1.4038161,,,,,,,,,,,,,, -21200,0.5992952,1.3747106,,,,,,,,,,,,,, -21300,0.7276587,1.3699927,,,,,,,,,,,,,, -21400,0.6685503,1.3911068,,,,,,,,,,,,,, -21500,0.7100993,1.405463,,,,,,,,,,,,,, -21600,0.69430715,1.41047,,,,,,,,,,,,,, -21700,0.6736904,1.3946466,,,,,,,,,,,,,, -21800,0.7810034,1.3771023,,,,,,,,,,,,,, -21900,0.7610593,1.3943192,,,,,,,,,,,,,, -22000,0.8320461,1.3590732,,,,,,,,,,,,,, -22022,,,0.25267893,0.0920306252131056,0.5279306,0.1606244629599235,5348.0,0.31976706,0.1091544289399386,2472.0,18758.11255025864,20631.21089434624,18758.11255025864,1871.4549486637115,0.674034595489502,0.0 -22100,0.6464915,1.3609563,,,,,,,,,,,,,, -22200,0.81866676,1.3694016,,,,,,,,,,,,,, -22300,0.7041111,1.4245169,,,,,,,,,,,,,, -22400,0.6394613,1.3162013,,,,,,,,,,,,,, -22500,0.7083041,1.3462247,,,,,,,,,,,,,, -22600,0.6578942,1.3217593,,,,,,,,,,,,,, -22700,0.71521705,1.3840493,,,,,,,,,,,,,, -22800,0.6145562,1.4024321,,,,,,,,,,,,,, -22900,0.61634284,1.3450944,,,,,,,,,,,,,, -23000,0.6609083,1.3486596,,,,,,,,,,,,,, -23100,0.5962279,1.3635004,,,,,,,,,,,,,, -23200,0.7097951,1.2987775,,,,,,,,,,,,,, -23300,0.69205236,1.3088638,,,,,,,,,,,,,, -23400,0.5888375,1.3460537,,,,,,,,,,,,,, -23500,0.6716623,1.3456583,,,,,,,,,,,,,, -23600,0.7589436,1.3821306,,,,,,,,,,,,,, -23700,0.6210843,1.3555613,,,,,,,,,,,,,, -23800,0.66281074,1.3382826,,,,,,,,,,,,,, -23874,,,0.24335396,0.0904516271512627,0.51649487,0.1568688029195671,5348.0,0.31226346,0.1043405845672618,2472.0,20198.68664932251,22198.257422208782,20198.68664932251,1997.80451631546,0.7282936573028564,0.0 -23900,0.6583777,1.3065897,,,,,,,,,,,,,, -24000,0.77928305,1.3449721,,,,,,,,,,,,,, -24100,0.76942074,1.3616699,,,,,,,,,,,,,, -24200,0.6680805,1.3267388,,,,,,,,,,,,,, -24300,0.5796101,1.3569657,,,,,,,,,,,,,, -24400,0.69883245,1.318221,,,,,,,,,,,,,, -24500,0.684209,1.2734131,,,,,,,,,,,,,, -24600,0.6600572,1.3457711,,,,,,,,,,,,,, -24700,0.6546674,1.3090888,,,,,,,,,,,,,, -24800,0.8658585,1.3388867,,,,,,,,,,,,,, -24900,0.8633527,1.314119,,,,,,,,,,,,,, -25000,0.6270954,1.3210489,,,,,,,,,,,,,, -25100,0.7247507,1.3640025,,,,,,,,,,,,,, -25200,0.63196677,1.3503383,,,,,,,,,,,,,, -25300,0.6563514,1.3165241,,,,,,,,,,,,,, -25400,0.7435509,1.3262074,,,,,,,,,,,,,, -25500,0.6897022,1.3213975,,,,,,,,,,,,,, -25600,0.7283837,1.3957113,,,,,,,,,,,,,, -25700,0.7545835,1.2842555,,,,,,,,,,,,,, -25727,,,0.22757375,0.0857738831471038,0.5035488,0.1526786834915087,5348.0,0.2979179,0.1024516076615278,2472.0,21638.83458685875,23766.222700834274,21638.83458685875,2125.4988000392914,0.7810320854187012,0.0 -25800,0.6176451,1.2967032,,,,,,,,,,,,,, -25900,0.6600366,1.3448149,,,,,,,,,,,,,, -26000,0.6639296,1.3225076,,,,,,,,,,,,,, -26100,0.78387105,1.2651638,,,,,,,,,,,,,, -26200,0.6354124,1.3282796,,,,,,,,,,,,,, -26300,0.68684274,1.3207111,,,,,,,,,,,,,, -26400,0.66955256,1.3092958,,,,,,,,,,,,,, -26500,0.6938047,1.3580953,,,,,,,,,,,,,, -26600,0.684389,1.2860157,,,,,,,,,,,,,, -26700,0.7112466,1.3628904,,,,,,,,,,,,,, -26800,0.6454711,1.2915952,,,,,,,,,,,,,, -26900,0.76975864,1.3693738,,,,,,,,,,,,,, -27000,0.6107642,1.2891363,,,,,,,,,,,,,, -27100,0.7588974,1.3122417,,,,,,,,,,,,,, -27200,0.7833589,1.3007953,,,,,,,,,,,,,, -27300,0.69737196,1.3111931,,,,,,,,,,,,,, -27400,0.6910973,1.3582294,,,,,,,,,,,,,, -27500,0.75083196,1.3425069,,,,,,,,,,,,,, -27582,,,0.22518203,0.0836282808468902,0.49105188,0.1481313419002288,5348.0,0.2894307,0.0979221254036926,2472.0,23079.49457478524,25335.13685107231,23079.49457478524,2253.6334249973297,0.829434871673584,0.0 -27600,0.6723331,1.2718804,,,,,,,,,,,,,, -27700,0.62544554,1.289551,,,,,,,,,,,,,, -27800,0.6587043,1.3494351,,,,,,,,,,,,,, -27900,0.82896715,1.2918607,,,,,,,,,,,,,, -28000,0.68114847,1.2803468,,,,,,,,,,,,,, -28100,0.7242687,1.2822348,,,,,,,,,,,,,, -28200,0.6001267,1.3189265,,,,,,,,,,,,,, -28300,0.72122675,1.2796783,,,,,,,,,,,,,, -28400,0.7191644,1.3212332,,,,,,,,,,,,,, -28500,0.7093723,1.3254453,,,,,,,,,,,,,, -28600,0.72222596,1.3216094,,,,,,,,,,,,,, -28700,0.69028074,1.3663621,,,,,,,,,,,,,, -28800,0.71319383,1.3042691,,,,,,,,,,,,,, -28900,0.76272255,1.2457571,,,,,,,,,,,,,, -29000,0.66660136,1.3381871,,,,,,,,,,,,,, -29100,0.61874473,1.288043,,,,,,,,,,,,,, -29200,0.69407696,1.2610023,,,,,,,,,,,,,, -29300,0.6735996,1.3005137,,,,,,,,,,,,,, -29400,0.71435815,1.304829,,,,,,,,,,,,,, -29437,,,0.20949112,0.0790651574502185,0.47871414,0.1447425586761539,5348.0,0.27920473,0.0958503442812747,2472.0,24520.111387252808,26905.28270840645,24520.111387252808,2383.031873941421,0.8874096870422363,0.0 -29500,0.7564393,1.2571557,,,,,,,,,,,,,, -29600,0.6790401,1.3175092,,,,,,,,,,,,,, -29700,0.6048758,1.3183496,,,,,,,,,,,,,, -29800,0.67402697,1.2759155,,,,,,,,,,,,,, -29900,0.58479,1.194477,,,,,,,,,,,,,, -30000,0.66089815,1.3041714,,,,,,,,,,,,,, -30100,0.68386,1.2285662,,,,,,,,,,,,,, -30200,0.7326361,1.3276125,,,,,,,,,,,,,, -30300,0.6652138,1.3111906,,,,,,,,,,,,,, -30400,0.7506573,1.2883364,,,,,,,,,,,,,, -30500,0.7040404,1.3135774,,,,,,,,,,,,,, -30600,0.6883068,1.3565625,,,,,,,,,,,,,, -30700,0.7604355,1.2929509,,,,,,,,,,,,,, -30800,0.80884296,1.2525204,,,,,,,,,,,,,, -30900,0.7060029,1.3037537,,,,,,,,,,,,,, -31000,0.7762553,1.2927006,,,,,,,,,,,,,, -31100,0.85698,1.3158882,,,,,,,,,,,,,, -31200,0.67150587,1.2652794,,,,,,,,,,,,,, -31289,,,0.24946217,0.0872451837069823,0.4818028,0.144076387615011,5348.0,0.28004083,0.0949972579367497,2472.0,25960.74817752838,28473.462604045868,25960.74817752838,2510.4499428272247,0.9405355453491212,0.0 -31300,0.6722199,1.3170335,,,,,,,,,,,,,, -31400,0.6913655,1.2947886,,,,,,,,,,,,,, -31500,0.60698,1.2267163,,,,,,,,,,,,,, -31600,0.6722037,1.2715806,,,,,,,,,,,,,, -31700,0.66676384,1.2984724,,,,,,,,,,,,,, -31800,0.6989762,1.255003,,,,,,,,,,,,,, -31900,0.75683165,1.273996,,,,,,,,,,,,,, -32000,0.6661316,1.2700517,,,,,,,,,,,,,, -32100,0.730558,1.2977496,,,,,,,,,,,,,, -32200,0.71252364,1.2861254,,,,,,,,,,,,,, -32300,0.66006345,1.2593757,,,,,,,,,,,,,, -32400,0.6118911,1.258549,,,,,,,,,,,,,, -32500,0.6158874,1.327231,,,,,,,,,,,,,, -32600,0.6900499,1.2270024,,,,,,,,,,,,,, -32700,0.6864884,1.292535,,,,,,,,,,,,,, -32800,0.7312478,1.2207787,,,,,,,,,,,,,, -32900,0.6687702,1.3062863,,,,,,,,,,,,,, -33000,0.6812469,1.2433288,,,,,,,,,,,,,, -33100,0.69927275,1.252845,,,,,,,,,,,,,, -33137,,,0.20576166,0.0775823398719714,0.46390292,0.1377043165953831,5348.0,0.27050698,0.0911989925456502,2472.0,27400.628918409348,30043.10976457596,27400.628918409348,2640.090026140213,0.9931390285491944,0.0 -33200,0.5939655,1.2209185,,,,,,,,,,,,,, -33300,0.62443244,1.2503694,,,,,,,,,,,,,, -33400,0.7929048,1.2424127,,,,,,,,,,,,,, -33500,0.753551,1.2680556,,,,,,,,,,,,,, -33600,0.70606667,1.2981951,,,,,,,,,,,,,, -33700,0.66620344,1.2652862,,,,,,,,,,,,,, -33800,0.71010315,1.2928708,,,,,,,,,,,,,, -33900,0.74494416,1.2844912,,,,,,,,,,,,,, -34000,0.79386264,1.1986834,,,,,,,,,,,,,, -34100,0.61328596,1.2322431,,,,,,,,,,,,,, -34200,0.83464766,1.2290411,,,,,,,,,,,,,, -34300,0.75628483,1.2085493,,,,,,,,,,,,,, -34400,0.65359235,1.2051957,,,,,,,,,,,,,, -34500,0.7871758,1.2578117,,,,,,,,,,,,,, -34600,0.7878171,1.212297,,,,,,,,,,,,,, -34700,0.7497804,1.294726,,,,,,,,,,,,,, -34800,0.6951818,1.2766587,,,,,,,,,,,,,, -34900,0.6386499,1.2488052,,,,,,,,,,,,,, -34975,,,0.18593837,0.0706608123632815,0.44671303,0.1354354731262732,5348.0,0.26218984,0.0894521966973371,2472.0,28841.17727303505,31614.65806341172,28841.17727303505,2770.963233947754,1.047239065170288,0.0 -35000,0.7082469,1.232174,,,,,,,,,,,,,, -35100,0.7567536,1.2413443,,,,,,,,,,,,,, -35200,0.7064572,1.2176421,,,,,,,,,,,,,, -35300,0.78435344,1.2324953,,,,,,,,,,,,,, -35400,0.6376121,1.2312756,,,,,,,,,,,,,, -35500,0.69749856,1.2974461,,,,,,,,,,,,,, -35600,0.65441805,1.2461667,,,,,,,,,,,,,, -35700,0.7237751,1.2980123,,,,,,,,,,,,,, -35800,0.7314149,1.2274092,,,,,,,,,,,,,, -35900,0.8160827,1.2260362,,,,,,,,,,,,,, -36000,0.7024014,1.302285,,,,,,,,,,,,,, -36100,0.62315595,1.2491271,,,,,,,,,,,,,, -36200,0.7807477,1.1676677,,,,,,,,,,,,,, -36300,0.76902777,1.2053473,,,,,,,,,,,,,, -36400,0.65376097,1.1850636,,,,,,,,,,,,,, -36500,0.6796656,1.2301295,,,,,,,,,,,,,, -36600,0.69353455,1.3263084,,,,,,,,,,,,,, -36700,0.8113266,1.2240224,,,,,,,,,,,,,, -36800,0.6423765,1.216593,,,,,,,,,,,,,, -36813,,,0.1798886,0.0659792170515853,0.4445087,0.1318246328818174,5348.0,0.25743183,0.0867304450267097,2472.0,30281.21803665161,33184.88313293457,30281.21803665161,2901.019725084305,1.099900484085083,0.0 -36900,0.6841296,1.2129781,,,,,,,,,,,,,, -37000,0.755939,1.2410102,,,,,,,,,,,,,, -37100,0.75865275,1.170454,,,,,,,,,,,,,, -37200,0.7385384,1.1832175,,,,,,,,,,,,,, -37300,0.7278438,1.2568718,,,,,,,,,,,,,, -37400,0.7973569,1.140966,,,,,,,,,,,,,, -37500,0.698564,1.2139472,,,,,,,,,,,,,, -37600,0.67538714,1.2501712,,,,,,,,,,,,,, -37700,0.7598009,1.2196797,,,,,,,,,,,,,, -37800,0.75112176,1.2005601,,,,,,,,,,,,,, -37900,0.62158126,1.2214024,,,,,,,,,,,,,, -38000,0.80501133,1.2557447,,,,,,,,,,,,,, -38100,0.77529836,1.2304631,,,,,,,,,,,,,, -38200,0.7992035,1.2020912,,,,,,,,,,,,,, -38300,0.72470033,1.1585499,,,,,,,,,,,,,, -38400,0.68652624,1.2190893,,,,,,,,,,,,,, -38500,0.76617545,1.2014045,,,,,,,,,,,,,, -38600,0.7577795,1.1827139,,,,,,,,,,,,,, -38657,,,0.17082553,0.0660368217054263,0.43620944,0.1306853838207324,5348.0,0.24864013,0.0843946133690817,2472.0,31721.667903661728,34754.47449827194,31721.667903661728,3030.02846121788,1.1584982872009275,0.0 -38700,0.69429696,1.2052529,,,,,,,,,,,,,, -38800,0.6684218,1.1865146,,,,,,,,,,,,,, -38900,0.85284936,1.2318555,,,,,,,,,,,,,, -39000,0.68296826,1.1933903,,,,,,,,,,,,,, -39100,0.7818791,1.1832249,,,,,,,,,,,,,, -39200,0.79375345,1.1399322,,,,,,,,,,,,,, -39300,0.654854,1.1673485,,,,,,,,,,,,,, -39400,0.6834239,1.1606903,,,,,,,,,,,,,, -39500,0.71944577,1.2042242,,,,,,,,,,,,,, -39600,0.71886384,1.2069026,,,,,,,,,,,,,, -39700,0.7070655,1.231136,,,,,,,,,,,,,, -39800,0.7681974,1.1976839,,,,,,,,,,,,,, -39900,0.80327016,1.2090828,,,,,,,,,,,,,, -40000,1.0053802,1.174918,,,,,,,,,,,,,, -40100,0.7335169,1.1780351,,,,,,,,,,,,,, -40200,0.72074896,1.1559935,,,,,,,,,,,,,, -40300,0.7561041,1.1468893,,,,,,,,,,,,,, -40400,0.7191099,1.1775675,,,,,,,,,,,,,, -40500,0.762191,1.1677294,,,,,,,,,,,,,, -40503,,,0.18386379,0.0680838521902516,0.43003973,0.1276827867190592,5348.0,0.2497961,0.0837852659801352,2472.0,33161.69714832306,36323.67623138428,33161.69714832306,3159.066329717636,1.2179243564605713,0.0 -40600,0.6909943,1.1841666,,,,,,,,,,,,,, -40700,0.768053,1.2112341,,,,,,,,,,,,,, -40800,0.8198392,1.188504,,,,,,,,,,,,,, -40900,0.7338718,1.2479012,,,,,,,,,,,,,, -41000,0.9702578,1.2146252,,,,,,,,,,,,,, -41100,0.7432695,1.1481642,,,,,,,,,,,,,, -41200,0.72141945,1.1443106,,,,,,,,,,,,,, -41300,0.85313123,1.1708524,,,,,,,,,,,,,, -41400,0.7107513,1.1570679,,,,,,,,,,,,,, -41500,0.90802884,1.1205335,,,,,,,,,,,,,, -41600,0.7139307,1.1921691,,,,,,,,,,,,,, -41700,0.80599624,1.191134,,,,,,,,,,,,,, -41800,0.7090581,1.098135,,,,,,,,,,,,,, -41900,0.7809036,1.2321191,,,,,,,,,,,,,, -42000,0.6958547,1.1745617,,,,,,,,,,,,,, -42100,0.8640108,1.2129904,,,,,,,,,,,,,, -42200,0.7395623,1.2396619,,,,,,,,,,,,,, -42300,0.75909454,1.1692914,,,,,,,,,,,,,, -42336,,,0.17193142,0.0642241447399598,0.41626227,0.1248346640663467,5348.0,0.23499578,0.0797432616334572,2472.0,34602.18103837967,37893.97904467583,34602.18103837967,3288.757158517837,1.272374153137207,0.0 -42400,0.77498794,1.1729295,,,,,,,,,,,,,, -42500,0.6428432,1.1893336,,,,,,,,,,,,,, -42600,0.807635,1.1434778,,,,,,,,,,,,,, -42700,0.7106237,1.161231,,,,,,,,,,,,,, -42800,0.80433315,1.1845366,,,,,,,,,,,,,, -42900,0.7409247,1.1809394,,,,,,,,,,,,,, -43000,0.8064681,1.1976693,,,,,,,,,,,,,, -43100,0.83552676,1.138585,,,,,,,,,,,,,, -43200,0.70731425,1.1591779,,,,,,,,,,,,,, -43300,0.7947714,1.1667376,,,,,,,,,,,,,, -43400,0.74401987,1.155947,,,,,,,,,,,,,, -43500,0.74623656,1.1271801,,,,,,,,,,,,,, -43600,0.6688265,1.1185563,,,,,,,,,,,,,, -43700,0.73691726,1.184613,,,,,,,,,,,,,, -43800,0.72393364,1.1477959,,,,,,,,,,,,,, -43900,0.7190935,1.1583475,,,,,,,,,,,,,, -44000,0.81626195,1.1812284,,,,,,,,,,,,,, -44100,0.72632027,1.1560351,,,,,,,,,,,,,, -44177,,,0.15760088,0.0605369308460445,0.41124475,0.1217162111279531,5348.0,0.23374555,0.080068246907562,2472.0,36043.03625369072,39465.451370716095,36043.03625369072,3419.2407870292664,1.3307726383209229,0.0 -44200,0.7277355,1.1870673,,,,,,,,,,,,,, -44300,0.7216832,1.1124573,,,,,,,,,,,,,, -44400,0.7542034,1.1193527,,,,,,,,,,,,,, -44500,0.8932663,1.1408511,,,,,,,,,,,,,, -44600,0.92393684,1.1299915,,,,,,,,,,,,,, -44700,0.92833453,1.1632644,,,,,,,,,,,,,, -44800,0.79682505,1.111949,,,,,,,,,,,,,, -44900,0.89918447,1.1381345,,,,,,,,,,,,,, -45000,0.6730248,1.1180656,,,,,,,,,,,,,, -45100,0.8506943,1.1965828,,,,,,,,,,,,,, -45200,0.7733685,1.1471986,,,,,,,,,,,,,, -45300,0.7876927,1.1705055,,,,,,,,,,,,,, -45400,0.6828875,1.1550802,,,,,,,,,,,,,, -45500,0.7880399,1.1552215,,,,,,,,,,,,,, -45600,0.8020418,1.1924525,,,,,,,,,,,,,, -45700,0.7414286,1.113296,,,,,,,,,,,,,, -45800,0.8992453,1.1885722,,,,,,,,,,,,,, -45900,0.7281701,1.1510123,,,,,,,,,,,,,, -46000,0.78190804,1.1797171,,,,,,,,,,,,,, -46011,,,0.14866316,0.0575584215110063,0.39693525,0.118269499985518,5348.0,0.22595222,0.075437206751569,2472.0,37483.32582259178,41036.49385213852,37483.32582259178,3549.8614530563354,1.3871524333953855,0.0 -46100,0.7678726,1.1663435,,,,,,,,,,,,,, -46200,0.7854304,1.1056471,,,,,,,,,,,,,, -46300,0.8035037,1.1432345,,,,,,,,,,,,,, -46400,0.7756986,1.0791405,,,,,,,,,,,,,, -46500,0.7452373,1.0989025,,,,,,,,,,,,,, -46600,0.882763,1.1267034,,,,,,,,,,,,,, -46700,0.7962565,1.1640855,,,,,,,,,,,,,, -46800,0.92516685,1.1263133,,,,,,,,,,,,,, -46900,0.8408591,1.111669,,,,,,,,,,,,,, -47000,0.7654665,1.1443952,,,,,,,,,,,,,, -47100,0.9309432,1.174084,,,,,,,,,,,,,, -47200,0.69275445,1.1322217,,,,,,,,,,,,,, -47300,0.8244505,1.17288,,,,,,,,,,,,,, -47400,0.7592051,1.0999409,,,,,,,,,,,,,, -47500,0.8338731,1.1340822,,,,,,,,,,,,,, -47600,0.8638031,1.0838493,,,,,,,,,,,,,, -47700,0.81597,1.0402256,,,,,,,,,,,,,, -47800,0.98197293,1.1004447,,,,,,,,,,,,,, -47854,,,0.1336802,0.0514337012198341,0.39036846,0.1155951610878863,5348.0,0.219991,0.072654520342047,2472.0,38923.7791454792,42606.84801912308,38923.7791454792,3679.6286759376526,1.4453561305999756,0.0 -47900,0.7147006,1.0940672,,,,,,,,,,,,,, -48000,0.97487444,1.1096517,,,,,,,,,,,,,, -48100,0.7882195,1.1325798,,,,,,,,,,,,,, -48200,0.79588324,1.1036304,,,,,,,,,,,,,, -48300,0.791705,1.1361775,,,,,,,,,,,,,, -48400,0.83651227,1.1706096,,,,,,,,,,,,,, -48500,0.88754064,1.1038473,,,,,,,,,,,,,, -48600,0.8812978,1.1265284,,,,,,,,,,,,,, -48700,0.84345174,1.1277249,,,,,,,,,,,,,, -48800,0.9220736,1.1022122,,,,,,,,,,,,,, -48900,0.92211264,1.0866816,,,,,,,,,,,,,, -49000,0.8791305,1.0911428,,,,,,,,,,,,,, -49100,0.990272,1.139125,,,,,,,,,,,,,, -49200,0.8789914,1.1261158,,,,,,,,,,,,,, -49300,0.798496,1.1503564,,,,,,,,,,,,,, -49400,0.7801353,1.0943599,,,,,,,,,,,,,, -49500,0.9729609,1.0767611,,,,,,,,,,,,,, -49600,0.906795,1.0702976,,,,,,,,,,,,,, -49697,,,0.13503304,0.0513602419115919,0.38563406,0.1136931944350579,5348.0,0.21653153,0.0710092823918916,2472.0,40363.92171168327,44177.68133687973,40363.92171168327,3810.182657241821,1.5044760704040527,0.0 -49700,0.86287683,1.0844268,,,,,,,,,,,,,, -49800,0.81294864,1.1065148,,,,,,,,,,,,,, -49900,0.78394413,1.0472448,,,,,,,,,,,,,, -50000,0.879425,1.1037871,,,,,,,,,,,,,, -50100,0.8744217,1.1533021,,,,,,,,,,,,,, -50200,0.7713535,1.0839257,,,,,,,,,,,,,, -50300,0.8995761,1.082461,,,,,,,,,,,,,, -50400,0.7609867,1.0997909,,,,,,,,,,,,,, -50500,0.86412203,1.0458968,,,,,,,,,,,,,, -50600,0.8690352,1.048248,,,,,,,,,,,,,, -50700,0.8432498,1.0691252,,,,,,,,,,,,,, -50800,0.81108737,1.104477,,,,,,,,,,,,,, -50900,0.805316,1.0854416,,,,,,,,,,,,,, -51000,0.88701093,1.1012383,,,,,,,,,,,,,, -51100,0.9579979,1.038431,,,,,,,,,,,,,, -51200,0.92976993,1.1116691,,,,,,,,,,,,,, -51300,0.94993323,1.1316224,,,,,,,,,,,,,, -51400,0.8370691,1.0856816,,,,,,,,,,,,,, -51500,0.82468617,1.061021,,,,,,,,,,,,,, -51530,,,0.13680682,0.0523486669727992,0.37866125,0.1118588103536499,5348.0,0.20964031,0.0716389413604696,2472.0,41803.84714341164,45747.65846896172,41803.84714341164,3940.1016092300415,1.5616471767425537,0.0 -51600,0.8367268,1.0807164,,,,,,,,,,,,,, -51700,0.8576602,1.0887326,,,,,,,,,,,,,, -51800,0.94896597,1.0854541,,,,,,,,,,,,,, -51900,0.9470546,1.0812744,,,,,,,,,,,,,, -52000,0.8684353,1.1074075,,,,,,,,,,,,,, -52100,0.9545454,1.0363611,,,,,,,,,,,,,, -52200,0.9441809,1.0364587,,,,,,,,,,,,,, -52300,0.86104244,1.125077,,,,,,,,,,,,,, -52400,0.9593634,1.1208861,,,,,,,,,,,,,, -52500,0.78411055,1.0609071,,,,,,,,,,,,,, -52600,0.75762403,1.0500497,,,,,,,,,,,,,, -52700,1.0610725,1.0070169,,,,,,,,,,,,,, -52800,0.8769445,1.028552,,,,,,,,,,,,,, -52900,0.78669435,1.114212,,,,,,,,,,,,,, -53000,0.8639577,1.0985469,,,,,,,,,,,,,, -53100,0.8579891,1.0831509,,,,,,,,,,,,,, -53200,0.970013,1.069097,,,,,,,,,,,,,, -53300,0.8059594,1.090124,,,,,,,,,,,,,, -53373,,,0.11914682,0.0452420368083603,0.36934325,0.1089817237417573,5348.0,0.2012541,0.0681859728231064,2472.0,43243.965631484985,47317.6235909462,43243.965631484985,4069.81272816658,1.6211626529693604,0.0 -53400,0.82998437,1.0890788,,,,,,,,,,,,,, -53500,0.99721515,1.0516003,,,,,,,,,,,,,, -53600,1.101983,1.029012,,,,,,,,,,,,,, -53700,0.7813413,1.0192984,,,,,,,,,,,,,, -53800,1.233186,1.0237995,,,,,,,,,,,,,, -53900,0.91747195,1.0563158,,,,,,,,,,,,,, -54000,0.8624023,1.0781628,,,,,,,,,,,,,, -54100,0.8870331,1.0482448,,,,,,,,,,,,,, -54200,0.9374495,1.0693887,,,,,,,,,,,,,, -54300,1.01831,1.0255855,,,,,,,,,,,,,, -54400,0.8172447,1.0708399,,,,,,,,,,,,,, -54500,0.8569599,1.0852789,,,,,,,,,,,,,, -54600,0.84041154,0.9785566,,,,,,,,,,,,,, -54700,0.851102,1.0062034,,,,,,,,,,,,,, -54800,0.8619587,1.0058156,,,,,,,,,,,,,, -54900,0.98195535,1.0365918,,,,,,,,,,,,,, -55000,1.1148043,1.0422384,,,,,,,,,,,,,, -55100,1.04249,1.0393738,,,,,,,,,,,,,, -55200,0.8442502,1.0590754,,,,,,,,,,,,,, -55209,,,0.11970915,0.0470404353538028,0.35692453,0.1047819496606389,5348.0,0.19781065,0.0661548148599516,2472.0,44683.96534562111,48889.67320728302,44683.96534562111,4201.730570077896,1.6765682697296145,0.0 -55300,1.0644037,1.0839888,,,,,,,,,,,,,, -55400,0.77046597,1.0280575,,,,,,,,,,,,,, -55500,1.0734282,1.0403767,,,,,,,,,,,,,, -55600,0.8450288,1.0385376,,,,,,,,,,,,,, -55700,0.95097065,1.02896,,,,,,,,,,,,,, -55800,1.0291301,1.0257939,,,,,,,,,,,,,, -55900,0.8931413,1.0306255,,,,,,,,,,,,,, -56000,0.86092556,1.0318208,,,,,,,,,,,,,, -56100,1.0744343,1.0443966,,,,,,,,,,,,,, -56200,0.93559766,1.0671909,,,,,,,,,,,,,, -56300,0.84015954,1.0388018,,,,,,,,,,,,,, -56400,0.893101,1.0280826,,,,,,,,,,,,,, -56500,1.0813209,1.057904,,,,,,,,,,,,,, -56600,1.3170705,0.9933636,,,,,,,,,,,,,, -56700,1.0381584,0.9676966,,,,,,,,,,,,,, -56800,1.1875188,1.0790229,,,,,,,,,,,,,, -56900,0.91217595,0.97594017,,,,,,,,,,,,,, -57000,1.1568544,1.0044336,,,,,,,,,,,,,, -57047,,,0.13444948,0.0475382545195381,0.3579548,0.1044440368035374,5348.0,0.19295095,0.0630268315966932,2472.0,46123.9932820797,50460.28098845482,46123.9932820797,4332.178327083588,1.7337017059326172,0.0 -57100,0.8897598,1.021929,,,,,,,,,,,,,, -57200,0.9778531,1.0034078,,,,,,,,,,,,,, -57300,1.1383406,1.0189741,,,,,,,,,,,,,, -57400,1.0126563,1.0771458,,,,,,,,,,,,,, -57500,0.8871692,1.0269023,,,,,,,,,,,,,, -57600,0.90222484,1.0126077,,,,,,,,,,,,,, -57700,0.8037542,0.9701109,,,,,,,,,,,,,, -57800,0.9147074,0.96418846,,,,,,,,,,,,,, -57900,0.9897016,0.9992074,,,,,,,,,,,,,, -58000,0.90131676,1.0455959,,,,,,,,,,,,,, -58100,1.0466553,0.9897639,,,,,,,,,,,,,, -58200,0.8799836,1.0174756,,,,,,,,,,,,,, -58300,1.0225195,1.0363446,,,,,,,,,,,,,, -58400,0.9877583,1.016093,,,,,,,,,,,,,, -58500,1.1373775,1.0172809,,,,,,,,,,,,,, -58600,0.9335748,0.99759,,,,,,,,,,,,,, -58700,0.930621,0.9752311,,,,,,,,,,,,,, -58800,0.9076798,1.0089473,,,,,,,,,,,,,, -58872,,,0.08747087,0.0343877048170506,0.34598675,0.0991919055388744,5348.0,0.18680844,0.0614222168058009,2472.0,47564.11428070068,52032.69997930527,47564.11428070068,4464.339418172836,1.7940006256103516,0.0 -58900,0.8647007,0.9563429,,,,,,,,,,,,,, -59000,1.1309457,1.0025907,,,,,,,,,,,,,, -59100,1.1283822,1.0196646,,,,,,,,,,,,,, -59200,1.0005617,1.0277717,,,,,,,,,,,,,, -59300,1.0571762,1.0125674,,,,,,,,,,,,,, -59400,0.9785509,1.0787998,,,,,,,,,,,,,, -59500,1.1328317,1.0256286,,,,,,,,,,,,,, -59600,1.0624359,1.0174272,,,,,,,,,,,,,, -59700,1.2624396,1.026428,,,,,,,,,,,,,, -59800,0.88757825,0.96012974,,,,,,,,,,,,,, -59900,1.1681815,1.0581646,,,,,,,,,,,,,, -60000,1.018462,0.9992336,,,,,,,,,,,,,, -60100,0.9245445,0.97663355,,,,,,,,,,,,,, -60200,1.2872846,0.98702395,,,,,,,,,,,,,, -60300,1.005198,0.98887676,,,,,,,,,,,,,, -60400,0.89956486,0.986134,,,,,,,,,,,,,, -60500,1.0465165,1.0233849,,,,,,,,,,,,,, -60600,1.253712,0.98551416,,,,,,,,,,,,,, -60697,,,0.086636014,0.0333267446215305,0.32997072,0.0964210201106423,5348.0,0.17877539,0.0604675725631182,2472.0,49004.76940536499,53605.760281562805,49004.76940536499,4596.614463329315,1.8496484756469729,0.0 -60700,1.0687605,0.9951134,,,,,,,,,,,,,, -60800,0.94774485,0.9610203,,,,,,,,,,,,,, -60900,1.0228926,0.9651506,,,,,,,,,,,,,, -61000,1.0830591,0.96805155,,,,,,,,,,,,,, -61100,1.1164893,1.0083576,,,,,,,,,,,,,, -61200,0.9748412,0.8865964,,,,,,,,,,,,,, -61300,0.90099186,0.9257217,,,,,,,,,,,,,, -61400,1.3794782,1.0020082,,,,,,,,,,,,,, -61500,1.1800513,1.0149115,,,,,,,,,,,,,, -61600,0.9462869,0.9207483,,,,,,,,,,,,,, -61700,1.1728424,0.9628441,,,,,,,,,,,,,, -61800,0.9516203,0.9591583,,,,,,,,,,,,,, -61900,0.9736809,0.92361647,,,,,,,,,,,,,, -62000,1.0009915,0.9534807,,,,,,,,,,,,,, -62100,1.0259594,0.9888347,,,,,,,,,,,,,, -62200,0.973366,0.9487317,,,,,,,,,,,,,, -62300,1.130359,0.9018502,,,,,,,,,,,,,, -62400,0.9854815,0.9859267,,,,,,,,,,,,,, -62500,1.3477696,0.9681356,,,,,,,,,,,,,, -62523,,,0.10608669,0.0416625411179686,0.32664424,0.0953783175801577,5348.0,0.17875016,0.0592691893648569,2472.0,50444.94024038315,55175.63874554634,50444.94024038315,4726.184319496155,1.9108588695526123,0.0 -62600,1.0429264,0.91595024,,,,,,,,,,,,,, -62700,1.0503373,0.95970124,,,,,,,,,,,,,, -62800,1.0046108,0.9567191,,,,,,,,,,,,,, -62900,1.0541339,0.94986224,,,,,,,,,,,,,, -63000,0.96213067,0.89789647,,,,,,,,,,,,,, -63100,1.194577,0.89658755,,,,,,,,,,,,,, -63200,0.9176839,0.9328134,,,,,,,,,,,,,, -63300,1.0287225,0.93185455,,,,,,,,,,,,,, -63400,1.2583578,0.9926198,,,,,,,,,,,,,, -63500,1.3174697,0.9323396,,,,,,,,,,,,,, -63600,1.4770632,0.9467492,,,,,,,,,,,,,, -63700,0.9841175,0.9070107,,,,,,,,,,,,,, -63800,1.033646,0.9521273,,,,,,,,,,,,,, -63900,1.056492,0.9157765,,,,,,,,,,,,,, -64000,1.156456,0.914018,,,,,,,,,,,,,, -64100,1.006205,0.9472829,,,,,,,,,,,,,, -64200,1.1457243,0.96574473,,,,,,,,,,,,,, -64300,1.0778586,0.9237114,,,,,,,,,,,,,, -64357,,,0.10365024,0.0396321834020283,0.31951347,0.0919122971316025,5348.0,0.17223309,0.0569942924461235,2472.0,51885.154266119,56746.01544976234,51885.154266119,4856.205128669739,1.9761414527893064,0.0 -64400,1.1330801,0.916003,,,,,,,,,,,,,, -64500,1.0605011,0.9210709,,,,,,,,,,,,,, -64600,1.1642717,0.9664252,,,,,,,,,,,,,, -64700,1.1474391,0.9682565,,,,,,,,,,,,,, -64800,1.0443759,0.88587177,,,,,,,,,,,,,, -64900,1.1172347,0.9661949,,,,,,,,,,,,,, -65000,1.0776354,0.9368979,,,,,,,,,,,,,, -65100,1.0470471,0.9451747,,,,,,,,,,,,,, -65200,1.1146116,0.9137643,,,,,,,,,,,,,, -65300,1.1211349,0.9605975,,,,,,,,,,,,,, -65400,1.0603327,0.92270875,,,,,,,,,,,,,, -65500,1.3726741,0.92571455,,,,,,,,,,,,,, -65600,1.1007917,0.92931074,,,,,,,,,,,,,, -65700,1.2811732,0.9427324,,,,,,,,,,,,,, -65800,1.1236587,0.9288577,,,,,,,,,,,,,, -65900,1.4718999,0.931562,,,,,,,,,,,,,, -66000,1.257013,0.9572891,,,,,,,,,,,,,, -66100,1.4048607,0.93458176,,,,,,,,,,,,,, -66182,,,0.11834669,0.045590728569106,0.3103893,0.0891317570503104,5348.0,0.16681975,0.0547397071070217,2472.0,53325.56930327416,58314.800549030304,53325.56930327416,4984.438056707382,2.034407615661621,0.0 -66200,1.2711328,0.8780363,,,,,,,,,,,,,, -66300,1.0557771,0.89259434,,,,,,,,,,,,,, -66400,1.1140301,0.8794669,,,,,,,,,,,,,, -66500,1.2090659,0.94781584,,,,,,,,,,,,,, -66600,1.1631019,0.992253,,,,,,,,,,,,,, -66700,1.1217433,0.9160013,,,,,,,,,,,,,, -66800,1.341419,0.8973509,,,,,,,,,,,,,, -66900,1.1700909,0.9258826,,,,,,,,,,,,,, -67000,1.128328,0.88653594,,,,,,,,,,,,,, -67100,1.4616448,0.90890545,,,,,,,,,,,,,, -67200,1.012503,0.9107167,,,,,,,,,,,,,, -67300,1.1741586,0.9407489,,,,,,,,,,,,,, -67400,1.1944213,0.93178403,,,,,,,,,,,,,, -67500,1.2386098,0.9116645,,,,,,,,,,,,,, -67600,1.2009177,0.8882662,,,,,,,,,,,,,, -67700,1.185192,0.88568383,,,,,,,,,,,,,, -67800,1.3943334,0.8912886,,,,,,,,,,,,,, -67900,1.0959164,0.88507533,,,,,,,,,,,,,, -67997,,,0.09386895,0.0346327432593597,0.30321878,0.0869111868465006,5348.0,0.16329446,0.0533178965328133,2472.0,54765.563039541245,59884.31038093567,54765.563039541245,5113.788911581039,2.122683525085449,0.0 -68000,1.0736114,0.8654731,,,,,,,,,,,,,, -68100,1.2361733,0.8931989,,,,,,,,,,,,,, -68200,1.1668074,0.8517749,,,,,,,,,,,,,, -68300,1.3694372,0.8656327,,,,,,,,,,,,,, -68400,1.2237678,0.9118544,,,,,,,,,,,,,, -68500,1.2001266,0.9231604,,,,,,,,,,,,,, -68600,1.4911815,0.8852969,,,,,,,,,,,,,, -68700,2.4236498,0.9072725,,,,,,,,,,,,,, -68800,1.15527,0.8942986,,,,,,,,,,,,,, -68900,1.2309524,0.96218306,,,,,,,,,,,,,, -69000,1.8767742,0.8786612,,,,,,,,,,,,,, -69100,1.0848227,0.8404083,,,,,,,,,,,,,, -69200,1.1801373,0.9103123,,,,,,,,,,,,,, -69300,1.2355194,0.8852953,,,,,,,,,,,,,, -69400,1.3093548,0.9453759,,,,,,,,,,,,,, -69500,1.1570401,0.9290912,,,,,,,,,,,,,, -69600,1.055487,0.866394,,,,,,,,,,,,,, -69700,1.377862,0.90627277,,,,,,,,,,,,,, -69800,1.1511515,0.8483765,,,,,,,,,,,,,, -69819,,,0.09016636,0.0354465941015229,0.30152643,0.0854147156221941,5348.0,0.1599037,0.0525257449271829,2472.0,56205.50395298004,61454.21480035782,56205.50395298004,5243.618288755417,2.180286169052124,0.0 -69900,1.3867174,0.8825404,,,,,,,,,,,,,, -70000,1.1578937,0.8924409,,,,,,,,,,,,,, -70100,1.18832,0.9049084,,,,,,,,,,,,,, -70200,1.1265547,0.9167453,,,,,,,,,,,,,, -70300,1.2229215,0.84420127,,,,,,,,,,,,,, -70400,1.2288364,0.87606317,,,,,,,,,,,,,, -70500,1.1011237,0.8187019,,,,,,,,,,,,,, -70600,1.2321917,0.90158606,,,,,,,,,,,,,, -70700,1.3610077,0.86296886,,,,,,,,,,,,,, -70800,1.2074302,0.86869603,,,,,,,,,,,,,, -70900,2.0605395,0.88921,,,,,,,,,,,,,, -71000,1.1094348,0.84830284,,,,,,,,,,,,,, -71100,1.3455353,0.8692485,,,,,,,,,,,,,, -71200,1.2650454,0.8494107,,,,,,,,,,,,,, -71300,1.6785719,0.8699133,,,,,,,,,,,,,, -71400,1.4639689,0.8818841,,,,,,,,,,,,,, -71500,1.1662141,0.8749281,,,,,,,,,,,,,, -71600,1.2108538,0.8793676,,,,,,,,,,,,,, -71644,,,0.0650331,0.0247357237001371,0.29494342,0.0837830792550469,5348.0,0.15629764,0.0508195722381329,2472.0,57645.81762099266,63027.75369310379,57645.81762099266,5376.695669412613,2.2511696815490723,0.0 -71700,1.1358216,0.86702776,,,,,,,,,,,,,, -71800,1.0885953,0.8288592,,,,,,,,,,,,,, -71900,1.0668849,0.8909739,,,,,,,,,,,,,, -72000,1.4609689,0.888729,,,,,,,,,,,,,, -72100,1.4075488,0.8440255,,,,,,,,,,,,,, -72200,1.5816329,0.83361715,,,,,,,,,,,,,, -72300,1.380073,0.83071595,,,,,,,,,,,,,, -72400,1.3153199,0.8492636,,,,,,,,,,,,,, -72500,1.215126,0.84982914,,,,,,,,,,,,,, -72600,1.2693998,0.85988903,,,,,,,,,,,,,, -72700,1.8417369,0.8983595,,,,,,,,,,,,,, -72800,1.365255,0.86837184,,,,,,,,,,,,,, -72900,1.3738971,0.8714018,,,,,,,,,,,,,, -73000,1.2763387,0.86428565,,,,,,,,,,,,,, -73100,1.3260229,0.9060934,,,,,,,,,,,,,, -73200,1.3631351,0.89739084,,,,,,,,,,,,,, -73300,1.3907951,0.8491386,,,,,,,,,,,,,, -73400,1.5189582,0.8755449,,,,,,,,,,,,,, -73472,,,0.07749713,0.0303396862581965,0.29196256,0.0827693406837425,5348.0,0.1558389,0.0498649279954502,2472.0,59086.080971241,64598.00075960159,59086.080971241,5506.537875413895,2.314791440963745,0.0 -73500,1.2838975,0.8419794,,,,,,,,,,,,,, -73600,1.1653554,0.86629194,,,,,,,,,,,,,, -73700,1.3678056,0.871193,,,,,,,,,,,,,, -73800,1.4189373,0.87630934,,,,,,,,,,,,,, -73900,1.2556995,0.8753573,,,,,,,,,,,,,, -74000,1.5882758,0.81681246,,,,,,,,,,,,,, -74100,1.1834083,0.86657894,,,,,,,,,,,,,, -74200,1.1879996,0.87578917,,,,,,,,,,,,,, -74300,1.9991736,0.8004097,,,,,,,,,,,,,, -74400,1.4596299,0.89567566,,,,,,,,,,,,,, -74500,1.4539636,0.8672324,,,,,,,,,,,,,, -74600,1.2301117,0.80495596,,,,,,,,,,,,,, -74700,1.8070151,0.85683906,,,,,,,,,,,,,, -74800,1.3737255,0.84321254,,,,,,,,,,,,,, -74900,2.1504602,0.84482825,,,,,,,,,,,,,, -75000,1.3228458,0.8710595,,,,,,,,,,,,,, -75100,1.5146816,0.8410687,,,,,,,,,,,,,, -75200,1.2686465,0.82627386,,,,,,,,,,,,,, -75286,,,0.067730136,0.0247909151636606,0.289029,0.0814756171736968,5348.0,0.15373042,0.0490930879694513,2472.0,60526.34147930145,66170.01026082039,60526.34147930145,5638.150643587112,2.375697612762451,0.0 -75300,1.2475256,0.83966076,,,,,,,,,,,,,, -75400,1.344948,0.8530619,,,,,,,,,,,,,, -75500,1.3094219,0.8545391,,,,,,,,,,,,,, -75600,1.4704874,0.822256,,,,,,,,,,,,,, -75700,1.6100787,0.84093577,,,,,,,,,,,,,, -75800,1.2674452,0.85816896,,,,,,,,,,,,,, -75900,1.3048021,0.8473644,,,,,,,,,,,,,, -75992,,,,,,,,,,,61068.64175653458,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/eval_measurements.csv deleted file mode 100644 index ece91cdbf..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -135.43686819076538,0.0,34.95112609863281,1,0,34.95112609863281,31.27949,2472,1.0976986980277457,170.38803958892822,31.906748,1.3798149413691736,31.163591,5348,1.043146644525329 -242.74953532218933,0.031334638595581,1475.3679308891296,1789,0,1475.3679308891296,6.313715,2472,0.899579550301627,1718.2228519916534,6.4432144,0.9413900245298448,6.3511953,5348,0.8966179750330672 -350.4788267612457,0.0864629745483398,2915.763386487961,3605,0,2915.763386487961,6.4707236,2472,0.898990514492312,3266.4796857833862,6.894968,0.9380587920410928,6.6507244,5348,0.8960097318902845 -470.847095489502,0.1448276042938232,4355.635868310928,5421,0,4355.635868310928,3.293352,2472,0.7108037292060203,4826.856513261795,3.7265964,0.7729606676212637,3.6029704,5348,0.7371327611342287 -601.6357429027557,0.1990370750427246,5795.708523511887,7224,0,5795.708523511887,1.4695274,2472,0.4360693031097028,6397.847925901413,1.7894272,0.50639147729467,1.8419197,5348,0.487907547042297 -731.5442821979523,0.2577028274536133,7236.371410369873,9046,0,7236.371410369873,1.1427629,2472,0.3549042309020372,7968.555783987045,1.5100107,0.4490485031329775,1.4825138,5348,0.4142908174594746 -863.1278550624847,0.3104140758514404,8676.504511356354,10864,0,8676.504511356354,0.98521,2472,0.3151138464038348,9540.403466939926,1.2075412,0.3723685439688758,1.3426595,5348,0.3834924741979397 -993.5067636966704,0.3614308834075928,10116.434507131577,12691,0,10116.434507131577,0.8880885,2472,0.2892775171125058,11110.8425116539,1.1640844,0.3571085155123165,1.218482,5348,0.352742404201705 -1124.2681069374084,0.4147167205810547,11556.422718286514,14500,0,11556.422718286514,0.8417588,2472,0.272825137610952,12681.722338914871,1.1236786,0.3462997832604231,1.16027,5348,0.3341958156733637 -1256.8274652957916,0.464468240737915,12996.567366361618,16323,0,12996.567366361618,0.7645127,2472,0.2534682022220868,14254.55534505844,1.0246935,0.3256618002726353,1.0873188,5348,0.3203896617975033 -1388.124047040939,0.5152781009674072,14436.753766775131,18144,0,14436.753766775131,0.7306505,2472,0.2422968334247354,15826.1676633358,0.91025984,0.294290872165181,1.035664,5348,0.3064097241665621 -1518.3700096607208,0.5735807418823242,15877.275433778765,19977,0,15877.275433778765,0.68959343,2472,0.2303130014421221,17397.07205915451,0.83924824,0.2758357747722438,0.9875356,5348,0.2946600113924906 -1648.6084377765656,0.6313190460205078,17317.22550392151,21798,0,17317.22550392151,0.6596471,2472,0.2209493632319785,18967.39634442329,0.86420834,0.2851232256183706,0.9465712,5348,0.2811821157206716 -1778.9243762493134,0.6872842311859131,18757.19909286499,23608,0,18757.19909286499,0.631796,2472,0.2112810513273617,20537.81848526001,0.8158463,0.2648283332017014,0.9147472,5348,0.2765285729457312 -1911.658009529113,0.73909592628479,20197.799550056458,25422,0,20197.799550056458,0.6027977,2472,0.2033392236914264,22111.28186249733,0.8221751,0.2644739657925494,0.89193165,5348,0.2666228989061278 -2043.7654702663424,0.7973823547363281,21638.32007479668,27255,0,21638.32007479668,0.59717554,2472,0.2019580362764812,23684.04691696167,0.7159185,0.2397716148297229,0.87400496,5348,0.2637458122942352 -2174.5478515625,0.8513691425323486,23079.08738541603,29078,0,23079.08738541603,0.5756875,2472,0.1954989539536489,25255.72882080078,0.7162538,0.2431265672208601,0.8402705,5348,0.2538691022138119 -2303.503675699234,0.9082772731781006,24519.23925757408,30893,0,24519.23925757408,0.57978195,2472,0.1974894887575407,26824.97067308426,0.7572224,0.2548112793816958,0.86291736,5348,0.2615638607026656 -2445.386803150177,0.9621689319610596,25959.860438346863,32729,0,25959.860438346863,0.5389743,2472,0.1829667093209838,28407.60689687729,0.508243,0.1794177239877347,0.7982695,5348,0.2409704857255954 -2578.3474531173706,1.02097749710083,27399.94140648842,34568,0,27399.94140648842,0.51708585,2472,0.1752076859017325,29980.78520011902,0.45416874,0.1624045821292815,0.7872792,5348,0.2371858617260588 -2709.968914270401,1.0840272903442385,28840.46229052544,36405,0,28840.46229052544,0.51699764,2472,0.1756342290739951,31553.070999383926,0.4568602,0.1632162807164774,0.7698689,5348,0.2332853818898018 -2842.078820705414,1.1394822597503662,30280.543273448944,38227,0,30280.543273448944,0.487146,2472,0.1664533950805354,33125.39553499222,0.41335335,0.1520887762228717,0.74211925,5348,0.2256195873601282 -2973.8543276786804,1.196692943572998,31721.368850708008,40049,0,31721.368850708008,0.4805117,2472,0.164401925537749,34698.13117194176,0.41621852,0.1532543069726285,0.7324096,5348,0.2226749181768153 -3107.816128730774,1.2540216445922852,33161.7200114727,41871,0,33161.7200114727,0.45834982,2472,0.1583084516482847,36272.580060482025,0.38685045,0.143294122585181,0.69860196,5348,0.2129140639331125 -3239.886257171631,1.3179125785827637,34601.8858397007,43705,0,34601.8858397007,0.43958446,2472,0.152539963032925,37844.95906472206,0.4322496,0.1538579260060887,0.68366253,5348,0.2103845448313815 -3371.7308938503265,1.3760406970977783,36041.96706032753,45527,0,36041.96706032753,0.43343684,2472,0.1485182702658785,39417.02039551735,0.3753733,0.1400401803932237,0.67537385,5348,0.2057889299748013 -3504.4280862808228,1.436626672744751,37481.884147167206,47343,0,37481.884147167206,0.41392162,2472,0.1407795584262588,40989.77420902252,0.34525725,0.1258080321158192,0.6423175,5348,0.1970900875677032 -3637.337597131729,1.5352272987365725,38922.35270643234,49160,0,38922.35270643234,0.40082654,2472,0.1379765604371051,42563.32991623879,0.31973082,0.120110860171262,0.6218428,5348,0.192697220425384 -3769.957051992416,1.5926856994628906,40362.86497974396,50992,0,40362.86497974396,0.38197964,2472,0.1339751792496902,44136.597618341446,0.30350104,0.1163183008531079,0.6097256,5348,0.1875319810382613 -3900.5444436073294,1.651547908782959,41803.16771149635,52819,0,41803.16771149635,0.3687143,2472,0.1252412000081246,45707.625512599945,0.31450352,0.1154854943438069,0.587471,5348,0.1810054355696728 -4032.609624862671,1.7056090831756592,43243.308198452,54636,0,43243.308198452,0.35127002,2472,0.1210367030243942,47279.961285591125,0.29329264,0.1079460585277596,0.56704783,5348,0.1752126437336474 -4165.615864276886,1.7686245441436768,44683.90547633171,56468,0,44683.90547633171,0.33652857,2472,0.1180102776592935,48853.7058467865,0.2854557,0.104970891640451,0.5525805,5348,0.1698446566322639 -4298.656836748123,1.9026412963867188,46124.3394947052,58296,0,46124.3394947052,0.32527593,2472,0.1107387321511993,50427.39268708229,0.25223732,0.0949610668576031,0.5288609,5348,0.1617154387557083 -4430.442895412445,1.963092565536499,47564.77532219887,60135,0,47564.77532219887,0.30733588,2472,0.104645258261735,51999.75342416763,0.2326758,0.0873209403468769,0.5112117,5348,0.1576701391235506 -4562.946071147919,2.0217151641845703,49004.667399168015,61955,0,49004.667399168015,0.29639518,2472,0.1002173339020575,53572.28482842445,0.21656099,0.0802586862794826,0.49202898,5348,0.1502650202264981 -4695.051886081696,2.079317331314087,50444.830134153366,63784,0,50444.830134153366,0.28154856,2472,0.0963378221924319,55144.689125299454,0.21301003,0.0795733142829309,0.4686298,5348,0.1450225436148952 -4828.119354486465,2.1474575996398926,51885.259873628616,65615,0,51885.259873628616,0.27360567,2472,0.0922551946864907,56718.33281850815,0.18987854,0.0696951776786194,0.45930022,5348,0.1404655473705552 -4961.224667787552,2.211995840072632,53325.22155690193,67452,0,53325.22155690193,0.2592684,2472,0.0881928787601811,58291.54190802574,0.18605858,0.0703033976144076,0.43531832,5348,0.1353389265956728 -5093.022603034973,2.2826132774353027,54765.25285768509,69288,0,54765.25285768509,0.24975313,2472,0.0854711270895537,59863.52037596703,0.1981465,0.0684959758019311,0.42125234,5348,0.1299419755351091 -5226.270207643509,2.345404863357544,56205.272963523865,71102,0,56205.272963523865,0.23893896,2472,0.0804947900798245,61436.927609205246,0.13899502,0.0517284898186257,0.41104484,5348,0.1253173967193488 -5358.731348514557,2.414510488510132,57645.54037809372,72941,0,57645.54037809372,0.2305465,2472,0.0770621331220929,63009.80436420441,0.13666794,0.0507454263274232,0.39726907,5348,0.1221699798217751 -5490.084238290787,2.474437475204468,59085.41255426407,74775,0,59085.41255426407,0.22601831,2472,0.076066865720147,64581.16864061356,0.16922173,0.0638789305164062,0.38635293,5348,0.1179122778222964 -5621.268357515335,2.5391104221343994,60525.80380296707,76616,0,60525.80380296707,0.22135468,2472,0.07438100461072858,66152.88785719872,0.1668892,0.06057212775371304,0.3823363,5348,0.1158654913735675 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/measurements.csv deleted file mode 100644 index c66336a0f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/measurements.csv +++ /dev/null @@ -1,819 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,47.990047,31.799711,,,,,,,,,,,,,, -1,,,31.906748,1.3798149413691736,31.163591,1.043146644525329,5348.0,31.27949,1.0976986980277457,2472.0,34.95112609863281,170.38803958892822,34.95112609863281,135.43686819076538,0.0,0.0 -100,1.7971958,5.8280582,,,,,,,,,,,,,, -200,3.5575018,5.7912164,,,,,,,,,,,,,, -300,2.0501664,5.6134963,,,,,,,,,,,,,, -400,1.8556293,5.602569,,,,,,,,,,,,,, -500,0.48225632,5.5300817,,,,,,,,,,,,,, -600,2.8365173,5.5330954,,,,,,,,,,,,,, -700,3.3659468,5.548333,,,,,,,,,,,,,, -800,2.4768913,5.530944,,,,,,,,,,,,,, -900,0.40034062,5.51743,,,,,,,,,,,,,, -1000,0.9653763,5.4883966,,,,,,,,,,,,,, -1100,0.93691915,5.510887,,,,,,,,,,,,,, -1200,2.2332525,5.4955764,,,,,,,,,,,,,, -1300,2.6158817,5.524514,,,,,,,,,,,,,, -1400,1.1830167,5.516255,,,,,,,,,,,,,, -1500,1.1497766,5.5057216,,,,,,,,,,,,,, -1600,0.41021943,5.4914536,,,,,,,,,,,,,, -1700,0.32719746,5.4840555,,,,,,,,,,,,,, -1789,,,6.4432144,0.9413900245298448,6.3511953,0.8966179750330672,5348.0,6.313715,0.899579550301627,2472.0,1475.3679308891296,1718.2228519916534,1475.3679308891296,242.74953532218933,0.031334638595581,0.0 -1800,2.404208,5.479972,,,,,,,,,,,,,, -1900,1.777378,5.4747744,,,,,,,,,,,,,, -2000,1.4483275,5.4608817,,,,,,,,,,,,,, -2100,1.0409415,5.4293056,,,,,,,,,,,,,, -2200,2.1780448,5.45465,,,,,,,,,,,,,, -2300,1.1742948,5.425668,,,,,,,,,,,,,, -2400,0.2781424,5.39923,,,,,,,,,,,,,, -2500,0.6527989,5.3665705,,,,,,,,,,,,,, -2600,0.97825676,5.0698395,,,,,,,,,,,,,, -2700,1.9467951,4.701734,,,,,,,,,,,,,, -2800,1.7856717,4.405939,,,,,,,,,,,,,, -2900,1.0435396,4.2552695,,,,,,,,,,,,,, -3000,0.76723033,4.175039,,,,,,,,,,,,,, -3100,1.6227279,4.0786667,,,,,,,,,,,,,, -3200,1.2683685,3.9148095,,,,,,,,,,,,,, -3300,1.1427569,3.7961166,,,,,,,,,,,,,, -3400,0.72525436,3.6290324,,,,,,,,,,,,,, -3500,0.8351208,3.5915132,,,,,,,,,,,,,, -3600,1.3710539,3.5333095,,,,,,,,,,,,,, -3605,,,6.894968,0.9380587920410928,6.6507244,0.8960097318902845,5348.0,6.4707236,0.898990514492312,2472.0,2915.763386487961,3266.4796857833862,2915.763386487961,350.4788267612457,0.0864629745483398,0.0 -3700,1.0892756,3.4843316,,,,,,,,,,,,,, -3800,1.0130615,3.4728634,,,,,,,,,,,,,, -3900,1.4717268,3.4231539,,,,,,,,,,,,,, -4000,0.78024465,3.4120114,,,,,,,,,,,,,, -4100,1.7581377,3.386093,,,,,,,,,,,,,, -4200,1.9783083,3.2632587,,,,,,,,,,,,,, -4300,1.1166884,3.1929924,,,,,,,,,,,,,, -4400,1.1559303,3.160637,,,,,,,,,,,,,, -4500,1.8635404,3.0970392,,,,,,,,,,,,,, -4600,1.3862479,3.1140893,,,,,,,,,,,,,, -4700,0.5786533,2.972457,,,,,,,,,,,,,, -4800,1.1006799,3.0100913,,,,,,,,,,,,,, -4900,1.75949,3.0199156,,,,,,,,,,,,,, -5000,4.7302256,2.9967678,,,,,,,,,,,,,, -5100,1.037551,2.9440427,,,,,,,,,,,,,, -5200,1.9334303,2.938408,,,,,,,,,,,,,, -5300,0.6467679,2.9256902,,,,,,,,,,,,,, -5400,1.3192192,2.9303436,,,,,,,,,,,,,, -5421,,,3.7265964,0.7729606676212637,3.6029704,0.7371327611342287,5348.0,3.293352,0.7108037292060203,2472.0,4355.635868310928,4826.856513261795,4355.635868310928,470.847095489502,0.1448276042938232,0.0 -5500,1.2428309,2.867488,,,,,,,,,,,,,, -5600,1.2650974,2.8395705,,,,,,,,,,,,,, -5700,1.371552,2.8384738,,,,,,,,,,,,,, -5800,0.5263241,2.8035383,,,,,,,,,,,,,, -5900,0.96486974,2.7916381,,,,,,,,,,,,,, -6000,1.0487255,2.7889628,,,,,,,,,,,,,, -6100,0.746635,2.747842,,,,,,,,,,,,,, -6200,0.99365234,2.6937664,,,,,,,,,,,,,, -6300,0.991747,2.7593405,,,,,,,,,,,,,, -6400,1.2460473,2.6997414,,,,,,,,,,,,,, -6500,0.79486734,2.575936,,,,,,,,,,,,,, -6600,0.85526705,2.6257885,,,,,,,,,,,,,, -6700,0.848155,2.670514,,,,,,,,,,,,,, -6800,0.7498917,2.6642342,,,,,,,,,,,,,, -6900,0.7069639,2.6557772,,,,,,,,,,,,,, -7000,0.6761559,2.5646126,,,,,,,,,,,,,, -7100,0.68375397,2.6615493,,,,,,,,,,,,,, -7200,1.505433,2.600264,,,,,,,,,,,,,, -7224,,,1.7894272,0.50639147729467,1.8419197,0.487907547042297,5348.0,1.4695274,0.4360693031097028,2472.0,5795.708523511887,6397.847925901413,5795.708523511887,601.6357429027557,0.1990370750427246,0.0 -7300,1.5594072,2.576913,,,,,,,,,,,,,, -7400,0.5272041,2.5492167,,,,,,,,,,,,,, -7500,1.2343407,2.5028172,,,,,,,,,,,,,, -7600,0.9603368,2.5034614,,,,,,,,,,,,,, -7700,0.8676791,2.510707,,,,,,,,,,,,,, -7800,1.2041007,2.508563,,,,,,,,,,,,,, -7900,1.214167,2.478066,,,,,,,,,,,,,, -8000,0.7987601,2.556109,,,,,,,,,,,,,, -8100,0.7761017,2.5257354,,,,,,,,,,,,,, -8200,1.7079666,2.5579402,,,,,,,,,,,,,, -8300,0.6559898,2.471656,,,,,,,,,,,,,, -8400,1.332538,2.6472855,,,,,,,,,,,,,, -8500,0.85487825,2.4829872,,,,,,,,,,,,,, -8600,0.67237747,2.4683268,,,,,,,,,,,,,, -8700,1.436278,2.5589793,,,,,,,,,,,,,, -8800,1.256965,2.502486,,,,,,,,,,,,,, -8900,0.66290987,2.4114451,,,,,,,,,,,,,, -9000,1.6814718,2.4380178,,,,,,,,,,,,,, -9046,,,1.5100107,0.4490485031329775,1.4825138,0.4142908174594746,5348.0,1.1427629,0.3549042309020372,2472.0,7236.371410369873,7968.555783987045,7236.371410369873,731.5442821979523,0.2577028274536133,0.0 -9100,0.60618144,2.3846276,,,,,,,,,,,,,, -9200,1.2492962,2.3778238,,,,,,,,,,,,,, -9300,1.1363478,2.4227693,,,,,,,,,,,,,, -9400,0.7609217,2.3212805,,,,,,,,,,,,,, -9500,0.56022555,2.3717296,,,,,,,,,,,,,, -9600,0.8516286,2.3541822,,,,,,,,,,,,,, -9700,0.6461584,2.4203045,,,,,,,,,,,,,, -9800,2.0426652,2.4617658,,,,,,,,,,,,,, -9900,0.8433135,2.291699,,,,,,,,,,,,,, -10000,1.2334013,2.309126,,,,,,,,,,,,,, -10100,1.1733875,2.3042538,,,,,,,,,,,,,, -10200,0.62071884,2.3286986,,,,,,,,,,,,,, -10300,0.61465013,2.2926068,,,,,,,,,,,,,, -10400,0.76982623,2.2506146,,,,,,,,,,,,,, -10500,0.8118756,2.3212082,,,,,,,,,,,,,, -10600,0.593155,2.2402613,,,,,,,,,,,,,, -10700,0.734303,2.2720306,,,,,,,,,,,,,, -10800,0.61034167,2.2249122,,,,,,,,,,,,,, -10864,,,1.2075412,0.3723685439688758,1.3426595,0.3834924741979397,5348.0,0.98521,0.3151138464038348,2472.0,8676.504511356354,9540.403466939926,8676.504511356354,863.1278550624847,0.3104140758514404,0.0 -10900,0.65645075,2.3003237,,,,,,,,,,,,,, -11000,0.8413944,2.2568803,,,,,,,,,,,,,, -11100,1.6674992,2.2216144,,,,,,,,,,,,,, -11200,0.7911816,2.19277,,,,,,,,,,,,,, -11300,0.902983,2.1707013,,,,,,,,,,,,,, -11400,0.83634305,2.1449573,,,,,,,,,,,,,, -11500,1.1297797,2.1409028,,,,,,,,,,,,,, -11600,1.0969169,2.184453,,,,,,,,,,,,,, -11700,0.8248647,2.2361517,,,,,,,,,,,,,, -11800,1.2916111,2.1706905,,,,,,,,,,,,,, -11900,0.8454277,2.2172222,,,,,,,,,,,,,, -12000,0.7169333,2.2258985,,,,,,,,,,,,,, -12100,1.0116253,2.1818275,,,,,,,,,,,,,, -12200,0.7673554,2.1331406,,,,,,,,,,,,,, -12300,1.0718803,2.1683848,,,,,,,,,,,,,, -12400,1.0826213,2.154261,,,,,,,,,,,,,, -12500,0.827291,2.1897147,,,,,,,,,,,,,, -12600,0.71398354,2.2111588,,,,,,,,,,,,,, -12691,,,1.1640844,0.3571085155123165,1.218482,0.352742404201705,5348.0,0.8880885,0.2892775171125058,2472.0,10116.434507131577,11110.8425116539,10116.434507131577,993.5067636966704,0.3614308834075928,0.0 -12700,0.79090935,2.1975508,,,,,,,,,,,,,, -12800,1.0602902,2.158266,,,,,,,,,,,,,, -12900,0.7473101,2.1202714,,,,,,,,,,,,,, -13000,1.1155919,2.148579,,,,,,,,,,,,,, -13100,1.0072607,2.073832,,,,,,,,,,,,,, -13200,0.8765396,2.1123662,,,,,,,,,,,,,, -13300,0.9533262,2.1330638,,,,,,,,,,,,,, -13400,0.54830205,2.1158717,,,,,,,,,,,,,, -13500,1.0713584,2.1132965,,,,,,,,,,,,,, -13600,0.89799863,2.108791,,,,,,,,,,,,,, -13700,1.3512584,2.109364,,,,,,,,,,,,,, -13800,0.6431105,2.07489,,,,,,,,,,,,,, -13900,0.84185004,2.1035647,,,,,,,,,,,,,, -14000,0.71134734,2.0337987,,,,,,,,,,,,,, -14100,0.69558185,2.1067624,,,,,,,,,,,,,, -14200,0.86137664,2.118346,,,,,,,,,,,,,, -14300,1.1205093,2.0986757,,,,,,,,,,,,,, -14400,0.7717962,2.0290499,,,,,,,,,,,,,, -14500,,,1.1236786,0.3462997832604231,1.16027,0.3341958156733637,5348.0,0.8417588,0.272825137610952,2472.0,11556.422718286514,12681.722338914871,11556.422718286514,1124.2681069374084,0.4147167205810547,0.0 -14500,0.9914752,2.0371199,,,,,,,,,,,,,, -14600,0.6203264,2.0296211,,,,,,,,,,,,,, -14700,0.98772115,2.0536575,,,,,,,,,,,,,, -14800,1.4602634,2.0516167,,,,,,,,,,,,,, -14900,1.2248775,2.0582414,,,,,,,,,,,,,, -15000,0.6351204,2.0846334,,,,,,,,,,,,,, -15100,0.6058076,2.0949917,,,,,,,,,,,,,, -15200,0.8544022,2.0758612,,,,,,,,,,,,,, -15300,0.88968873,2.0393598,,,,,,,,,,,,,, -15400,0.6770584,2.0350604,,,,,,,,,,,,,, -15500,1.1043285,2.0521092,,,,,,,,,,,,,, -15600,0.7367694,2.0202243,,,,,,,,,,,,,, -15700,0.6718746,2.0836277,,,,,,,,,,,,,, -15800,0.92260224,2.0405326,,,,,,,,,,,,,, -15900,0.908542,2.0488467,,,,,,,,,,,,,, -16000,0.6743316,2.054577,,,,,,,,,,,,,, -16100,0.7976262,2.0593908,,,,,,,,,,,,,, -16200,0.75254184,2.0585623,,,,,,,,,,,,,, -16300,0.88978344,2.025442,,,,,,,,,,,,,, -16323,,,1.0246935,0.3256618002726353,1.0873188,0.3203896617975033,5348.0,0.7645127,0.2534682022220868,2472.0,12996.567366361618,14254.55534505844,12996.567366361618,1256.8274652957916,0.464468240737915,0.0 -16400,0.9404994,2.0815947,,,,,,,,,,,,,, -16500,0.764538,1.9779053,,,,,,,,,,,,,, -16600,0.50189,2.0012763,,,,,,,,,,,,,, -16700,0.5505143,2.0334835,,,,,,,,,,,,,, -16800,0.6894729,2.013699,,,,,,,,,,,,,, -16900,0.63393044,2.040346,,,,,,,,,,,,,, -17000,0.6733618,2.0103374,,,,,,,,,,,,,, -17100,1.1454461,2.0395842,,,,,,,,,,,,,, -17200,0.9090186,2.08344,,,,,,,,,,,,,, -17300,0.6206533,1.9406395,,,,,,,,,,,,,, -17400,0.71968734,1.9753789,,,,,,,,,,,,,, -17500,0.7842304,1.9854662,,,,,,,,,,,,,, -17600,0.6388236,1.9689813,,,,,,,,,,,,,, -17700,0.9758303,2.0813959,,,,,,,,,,,,,, -17800,0.98349714,1.9924289,,,,,,,,,,,,,, -17900,0.6179917,2.0501666,,,,,,,,,,,,,, -18000,1.1710331,2.0069163,,,,,,,,,,,,,, -18100,0.83153147,1.9956169,,,,,,,,,,,,,, -18144,,,0.91025984,0.294290872165181,1.035664,0.3064097241665621,5348.0,0.7306505,0.2422968334247354,2472.0,14436.753766775131,15826.1676633358,14436.753766775131,1388.124047040939,0.5152781009674072,0.0 -18200,0.6347141,2.0183768,,,,,,,,,,,,,, -18300,0.75209904,2.0117831,,,,,,,,,,,,,, -18400,0.8158492,1.9954237,,,,,,,,,,,,,, -18500,0.57372504,1.9589293,,,,,,,,,,,,,, -18600,0.5924823,1.980684,,,,,,,,,,,,,, -18700,0.74990094,1.9328274,,,,,,,,,,,,,, -18800,0.9660127,1.9712166,,,,,,,,,,,,,, -18900,0.75725794,1.926169,,,,,,,,,,,,,, -19000,0.62850034,1.8944418,,,,,,,,,,,,,, -19100,0.5241324,1.9630724,,,,,,,,,,,,,, -19200,0.6744916,1.911042,,,,,,,,,,,,,, -19300,0.5854475,1.9280365,,,,,,,,,,,,,, -19400,0.59456116,1.9529346,,,,,,,,,,,,,, -19500,1.0290675,2.0572271,,,,,,,,,,,,,, -19600,0.8254641,1.88705,,,,,,,,,,,,,, -19700,1.0919058,1.8978093,,,,,,,,,,,,,, -19800,0.65582514,1.9040465,,,,,,,,,,,,,, -19900,1.0159259,1.9798217,,,,,,,,,,,,,, -19977,,,0.83924824,0.2758357747722438,0.9875356,0.2946600113924906,5348.0,0.68959343,0.2303130014421221,2472.0,15877.275433778765,17397.07205915451,15877.275433778765,1518.3700096607208,0.5735807418823242,0.0 -20000,0.6225357,2.0096896,,,,,,,,,,,,,, -20100,0.47715315,1.9324362,,,,,,,,,,,,,, -20200,0.6505087,1.9354769,,,,,,,,,,,,,, -20300,0.5566874,1.9123877,,,,,,,,,,,,,, -20400,0.7519626,1.9406718,,,,,,,,,,,,,, -20500,0.8428557,1.8969728,,,,,,,,,,,,,, -20600,0.86553055,1.9048657,,,,,,,,,,,,,, -20700,0.5630475,1.8961098,,,,,,,,,,,,,, -20800,0.8028578,1.917089,,,,,,,,,,,,,, -20900,0.63389575,1.8990494,,,,,,,,,,,,,, -21000,0.6719016,1.9572603,,,,,,,,,,,,,, -21100,0.7041086,1.8779083,,,,,,,,,,,,,, -21200,0.88771546,1.8221765,,,,,,,,,,,,,, -21300,0.72032464,1.9159638,,,,,,,,,,,,,, -21400,0.5406688,1.8546717,,,,,,,,,,,,,, -21500,0.66313183,1.97195,,,,,,,,,,,,,, -21600,0.6189072,1.8837447,,,,,,,,,,,,,, -21700,0.7813801,1.814202,,,,,,,,,,,,,, -21798,,,0.86420834,0.2851232256183706,0.9465712,0.2811821157206716,5348.0,0.6596471,0.2209493632319785,2472.0,17317.22550392151,18967.39634442329,17317.22550392151,1648.6084377765656,0.6313190460205078,0.0 -21800,0.49342307,1.875158,,,,,,,,,,,,,, -21900,0.5014433,1.9057382,,,,,,,,,,,,,, -22000,0.53087395,1.849112,,,,,,,,,,,,,, -22100,0.78510225,1.8654045,,,,,,,,,,,,,, -22200,0.47018886,1.8770846,,,,,,,,,,,,,, -22300,0.63574195,1.8267262,,,,,,,,,,,,,, -22400,0.6268789,1.8145409,,,,,,,,,,,,,, -22500,0.5167618,1.8607268,,,,,,,,,,,,,, -22600,0.91816336,1.765292,,,,,,,,,,,,,, -22700,1.1205854,1.9076985,,,,,,,,,,,,,, -22800,0.80171853,1.9599966,,,,,,,,,,,,,, -22900,0.69888943,1.9070117,,,,,,,,,,,,,, -23000,0.6063162,1.8699965,,,,,,,,,,,,,, -23100,0.5809266,1.8813982,,,,,,,,,,,,,, -23200,0.52487326,1.8008113,,,,,,,,,,,,,, -23300,0.55975795,1.8772074,,,,,,,,,,,,,, -23400,0.6070333,1.8537712,,,,,,,,,,,,,, -23500,0.6700071,1.8404663,,,,,,,,,,,,,, -23600,0.57914615,1.8700631,,,,,,,,,,,,,, -23608,,,0.8158463,0.2648283332017014,0.9147472,0.2765285729457312,5348.0,0.631796,0.2112810513273617,2472.0,18757.19909286499,20537.81848526001,18757.19909286499,1778.9243762493134,0.6872842311859131,0.0 -23700,0.5733056,1.8317208,,,,,,,,,,,,,, -23800,0.84161216,1.7771121,,,,,,,,,,,,,, -23900,0.47022814,1.8155438,,,,,,,,,,,,,, -24000,0.79322225,1.851797,,,,,,,,,,,,,, -24100,0.53174657,1.8203869,,,,,,,,,,,,,, -24200,0.60511,1.9051962,,,,,,,,,,,,,, -24300,0.9222738,1.827603,,,,,,,,,,,,,, -24400,0.7877879,1.8625766,,,,,,,,,,,,,, -24500,0.6756392,1.7950023,,,,,,,,,,,,,, -24600,0.4703413,1.787295,,,,,,,,,,,,,, -24700,0.52680975,1.8133003,,,,,,,,,,,,,, -24800,0.72907317,1.7716638,,,,,,,,,,,,,, -24900,0.8008146,1.8293257,,,,,,,,,,,,,, -25000,0.71825725,1.8428062,,,,,,,,,,,,,, -25100,0.72478193,1.8598187,,,,,,,,,,,,,, -25200,0.73395026,1.8580052,,,,,,,,,,,,,, -25300,1.0018898,1.7440085,,,,,,,,,,,,,, -25400,0.76872426,1.775436,,,,,,,,,,,,,, -25422,,,0.8221751,0.2644739657925494,0.89193165,0.2666228989061278,5348.0,0.6027977,0.2033392236914264,2472.0,20197.799550056458,22111.28186249733,20197.799550056458,1911.658009529113,0.73909592628479,0.0 -25500,0.6685064,1.7644684,,,,,,,,,,,,,, -25600,0.60969925,1.8824679,,,,,,,,,,,,,, -25700,0.5232219,1.8164086,,,,,,,,,,,,,, -25800,1.1834657,1.8047929,,,,,,,,,,,,,, -25900,0.54229546,1.8241634,,,,,,,,,,,,,, -26000,0.8038446,1.8334788,,,,,,,,,,,,,, -26100,0.60536766,1.7678547,,,,,,,,,,,,,, -26200,0.8320865,1.7968357,,,,,,,,,,,,,, -26300,0.7678136,1.823536,,,,,,,,,,,,,, -26400,0.64419335,1.7797889,,,,,,,,,,,,,, -26500,0.59609336,1.9196565,,,,,,,,,,,,,, -26600,0.5783734,1.8151217,,,,,,,,,,,,,, -26700,1.0730942,1.811252,,,,,,,,,,,,,, -26800,0.869179,1.7337059,,,,,,,,,,,,,, -26900,1.1654363,1.8021108,,,,,,,,,,,,,, -27000,0.5367714,1.7355571,,,,,,,,,,,,,, -27100,0.5292297,1.8250873,,,,,,,,,,,,,, -27200,0.9352917,1.7390125,,,,,,,,,,,,,, -27255,,,0.7159185,0.2397716148297229,0.87400496,0.2637458122942352,5348.0,0.59717554,0.2019580362764812,2472.0,21638.32007479668,23684.04691696167,21638.32007479668,2043.7654702663424,0.7973823547363281,0.0 -27300,0.7293449,1.7821144,,,,,,,,,,,,,, -27400,0.48681352,1.7948309,,,,,,,,,,,,,, -27500,0.7172273,1.8320844,,,,,,,,,,,,,, -27600,0.8756737,1.7369043,,,,,,,,,,,,,, -27700,0.6995591,1.8174304,,,,,,,,,,,,,, -27800,0.8057827,1.781666,,,,,,,,,,,,,, -27900,1.0739018,1.8100538,,,,,,,,,,,,,, -28000,0.69974583,1.7458892,,,,,,,,,,,,,, -28100,0.6615025,1.8085558,,,,,,,,,,,,,, -28200,0.53154486,1.7848954,,,,,,,,,,,,,, -28300,0.82328176,1.742548,,,,,,,,,,,,,, -28400,0.56422263,1.7196832,,,,,,,,,,,,,, -28500,0.693245,1.7206606,,,,,,,,,,,,,, -28600,0.69416064,1.8042976,,,,,,,,,,,,,, -28700,0.86576617,1.8319371,,,,,,,,,,,,,, -28800,0.82274204,1.785522,,,,,,,,,,,,,, -28900,0.50392014,1.7389063,,,,,,,,,,,,,, -29000,0.643067,1.8065504,,,,,,,,,,,,,, -29078,,,0.7162538,0.2431265672208601,0.8402705,0.2538691022138119,5348.0,0.5756875,0.1954989539536489,2472.0,23079.08738541603,25255.72882080078,23079.08738541603,2174.5478515625,0.8513691425323486,0.0 -29100,0.593939,1.7591753,,,,,,,,,,,,,, -29200,0.533705,1.7696711,,,,,,,,,,,,,, -29300,0.6843605,1.7424139,,,,,,,,,,,,,, -29400,0.9724768,1.7483835,,,,,,,,,,,,,, -29500,0.97852606,1.740059,,,,,,,,,,,,,, -29600,0.77503246,1.7530954,,,,,,,,,,,,,, -29700,0.6786697,1.7241622,,,,,,,,,,,,,, -29800,0.5377708,1.7197571,,,,,,,,,,,,,, -29900,0.56591374,1.6810981,,,,,,,,,,,,,, -30000,0.94479674,1.7652836,,,,,,,,,,,,,, -30100,0.77594364,1.7478908,,,,,,,,,,,,,, -30200,0.6010521,1.7085387,,,,,,,,,,,,,, -30300,0.7115822,1.7578999,,,,,,,,,,,,,, -30400,0.5728752,1.7711971,,,,,,,,,,,,,, -30500,0.82120913,1.7539989,,,,,,,,,,,,,, -30600,0.7365617,1.803205,,,,,,,,,,,,,, -30700,0.76961976,1.7023059,,,,,,,,,,,,,, -30800,0.57438016,1.7447569,,,,,,,,,,,,,, -30893,,,0.7572224,0.2548112793816958,0.86291736,0.2615638607026656,5348.0,0.57978195,0.1974894887575407,2472.0,24519.23925757408,26824.97067308426,24519.23925757408,2303.503675699234,0.9082772731781006,0.0 -30900,0.68130094,1.7653122,,,,,,,,,,,,,, -31000,0.6502472,1.7563269,,,,,,,,,,,,,, -31100,0.5402239,1.7237159,,,,,,,,,,,,,, -31200,0.5815834,1.745728,,,,,,,,,,,,,, -31300,0.57373285,1.7203715,,,,,,,,,,,,,, -31400,0.60741407,1.7109808,,,,,,,,,,,,,, -31500,0.6026396,1.7305773,,,,,,,,,,,,,, -31600,0.6598179,1.7462517,,,,,,,,,,,,,, -31700,0.5698369,1.6997969,,,,,,,,,,,,,, -31800,0.9240608,1.7748232,,,,,,,,,,,,,, -31900,0.6593016,1.6996349,,,,,,,,,,,,,, -32000,0.59645134,1.7346146,,,,,,,,,,,,,, -32100,0.4849144,1.7361671,,,,,,,,,,,,,, -32200,0.6671906,1.7703818,,,,,,,,,,,,,, -32300,0.6624642,1.7070988,,,,,,,,,,,,,, -32400,0.6211881,1.6950403,,,,,,,,,,,,,, -32500,0.7728652,1.7581031,,,,,,,,,,,,,, -32600,0.8524649,1.711789,,,,,,,,,,,,,, -32700,0.55629724,1.7022179,,,,,,,,,,,,,, -32729,,,0.508243,0.1794177239877347,0.7982695,0.2409704857255954,5348.0,0.5389743,0.1829667093209838,2472.0,25959.860438346863,28407.60689687729,25959.860438346863,2445.386803150177,0.9621689319610596,0.0 -32800,0.59907573,1.7191569,,,,,,,,,,,,,, -32900,0.7755164,1.8157711,,,,,,,,,,,,,, -33000,0.87099963,1.718943,,,,,,,,,,,,,, -33100,0.7115641,1.6715089,,,,,,,,,,,,,, -33200,0.7902925,1.6958288,,,,,,,,,,,,,, -33300,0.5384897,1.714121,,,,,,,,,,,,,, -33400,1.0748212,1.7654767,,,,,,,,,,,,,, -33500,0.73176336,1.7401031,,,,,,,,,,,,,, -33600,0.78583,1.7154844,,,,,,,,,,,,,, -33700,0.71543944,1.688263,,,,,,,,,,,,,, -33800,0.6735952,1.7287326,,,,,,,,,,,,,, -33900,0.77580845,1.7680243,,,,,,,,,,,,,, -34000,0.6288278,1.6996348,,,,,,,,,,,,,, -34100,0.6914698,1.6432853,,,,,,,,,,,,,, -34200,0.61128587,1.7020643,,,,,,,,,,,,,, -34300,0.8674174,1.6849624,,,,,,,,,,,,,, -34400,0.7357327,1.670626,,,,,,,,,,,,,, -34500,0.54330987,1.7103748,,,,,,,,,,,,,, -34568,,,0.45416874,0.1624045821292815,0.7872792,0.2371858617260588,5348.0,0.51708585,0.1752076859017325,2472.0,27399.94140648842,29980.78520011902,27399.94140648842,2578.3474531173706,1.02097749710083,0.0 -34600,0.5825763,1.6753042,,,,,,,,,,,,,, -34700,0.8044265,1.7632419,,,,,,,,,,,,,, -34800,0.7565005,1.7022979,,,,,,,,,,,,,, -34900,0.60921884,1.7238437,,,,,,,,,,,,,, -35000,0.5862924,1.6860056,,,,,,,,,,,,,, -35100,0.66895455,1.6591786,,,,,,,,,,,,,, -35200,0.7179227,1.6808888,,,,,,,,,,,,,, -35300,0.6005929,1.6849546,,,,,,,,,,,,,, -35400,0.66400456,1.6938056,,,,,,,,,,,,,, -35500,0.58007777,1.6478448,,,,,,,,,,,,,, -35600,0.5866277,1.705818,,,,,,,,,,,,,, -35700,0.94949895,1.7841297,,,,,,,,,,,,,, -35800,0.81955886,1.6679486,,,,,,,,,,,,,, -35900,0.5592827,1.6457565,,,,,,,,,,,,,, -36000,0.90122336,1.6778035,,,,,,,,,,,,,, -36100,0.5193734,1.7248777,,,,,,,,,,,,,, -36200,0.58022696,1.6988795,,,,,,,,,,,,,, -36300,0.6137917,1.6227957,,,,,,,,,,,,,, -36400,0.6827549,1.5992962,,,,,,,,,,,,,, -36405,,,0.4568602,0.1632162807164774,0.7698689,0.2332853818898018,5348.0,0.51699764,0.1756342290739951,2472.0,28840.46229052544,31553.070999383926,28840.46229052544,2709.968914270401,1.0840272903442385,0.0 -36500,0.50233454,1.7017206,,,,,,,,,,,,,, -36600,0.69014204,1.6708535,,,,,,,,,,,,,, -36700,0.7538832,1.6966351,,,,,,,,,,,,,, -36800,0.6446851,1.6290957,,,,,,,,,,,,,, -36900,0.83549607,1.7148002,,,,,,,,,,,,,, -37000,0.6550786,1.617703,,,,,,,,,,,,,, -37100,0.5834557,1.6672254,,,,,,,,,,,,,, -37200,0.6879334,1.6442672,,,,,,,,,,,,,, -37300,0.59102166,1.6766188,,,,,,,,,,,,,, -37400,0.5686524,1.6334902,,,,,,,,,,,,,, -37500,0.6110673,1.647438,,,,,,,,,,,,,, -37600,0.70079625,1.6654558,,,,,,,,,,,,,, -37700,0.65981525,1.6783323,,,,,,,,,,,,,, -37800,0.6696852,1.6801305,,,,,,,,,,,,,, -37900,0.5769228,1.651128,,,,,,,,,,,,,, -38000,0.5760729,1.6665962,,,,,,,,,,,,,, -38100,0.7364818,1.6283393,,,,,,,,,,,,,, -38200,0.6465566,1.6631087,,,,,,,,,,,,,, -38227,,,0.41335335,0.1520887762228717,0.74211925,0.2256195873601282,5348.0,0.487146,0.1664533950805354,2472.0,30280.543273448944,33125.39553499222,30280.543273448944,2842.078820705414,1.1394822597503662,0.0 -38300,0.8853566,1.6291736,,,,,,,,,,,,,, -38400,0.81396693,1.5936341,,,,,,,,,,,,,, -38500,0.6480854,1.6719263,,,,,,,,,,,,,, -38600,0.5479812,1.6901876,,,,,,,,,,,,,, -38700,0.6174305,1.6235976,,,,,,,,,,,,,, -38800,0.6291805,1.6075418,,,,,,,,,,,,,, -38900,0.66661674,1.6958237,,,,,,,,,,,,,, -39000,0.6120158,1.612678,,,,,,,,,,,,,, -39100,0.5610724,1.6055373,,,,,,,,,,,,,, -39200,0.63898957,1.5945615,,,,,,,,,,,,,, -39300,0.65674305,1.6276779,,,,,,,,,,,,,, -39400,0.8425633,1.6133932,,,,,,,,,,,,,, -39500,0.66186464,1.6745994,,,,,,,,,,,,,, -39600,0.7017631,1.6403942,,,,,,,,,,,,,, -39700,0.71856654,1.6112128,,,,,,,,,,,,,, -39800,0.5255747,1.6256714,,,,,,,,,,,,,, -39900,0.57911384,1.5799401,,,,,,,,,,,,,, -40000,0.6451242,1.6214695,,,,,,,,,,,,,, -40049,,,0.41621852,0.1532543069726285,0.7324096,0.2226749181768153,5348.0,0.4805117,0.164401925537749,2472.0,31721.368850708008,34698.13117194176,31721.368850708008,2973.8543276786804,1.196692943572998,0.0 -40100,0.8225161,1.566257,,,,,,,,,,,,,, -40200,0.7054793,1.5919905,,,,,,,,,,,,,, -40300,0.97697395,1.6587527,,,,,,,,,,,,,, -40400,0.6387102,1.5845803,,,,,,,,,,,,,, -40500,0.69391245,1.5920749,,,,,,,,,,,,,, -40600,0.53002656,1.6184796,,,,,,,,,,,,,, -40700,0.71105367,1.639114,,,,,,,,,,,,,, -40800,0.71695185,1.5441247,,,,,,,,,,,,,, -40900,0.5548173,1.5609636,,,,,,,,,,,,,, -41000,0.61695683,1.6354957,,,,,,,,,,,,,, -41100,0.89574116,1.5995544,,,,,,,,,,,,,, -41200,0.60494244,1.5695279,,,,,,,,,,,,,, -41300,0.6806119,1.5831649,,,,,,,,,,,,,, -41400,0.64113265,1.6120028,,,,,,,,,,,,,, -41500,0.7225812,1.6620293,,,,,,,,,,,,,, -41600,0.75634545,1.5766029,,,,,,,,,,,,,, -41700,0.658287,1.591398,,,,,,,,,,,,,, -41800,0.6773258,1.5400877,,,,,,,,,,,,,, -41871,,,0.38685045,0.143294122585181,0.69860196,0.2129140639331125,5348.0,0.45834982,0.1583084516482847,2472.0,33161.7200114727,36272.580060482025,33161.7200114727,3107.816128730774,1.2540216445922852,0.0 -41900,0.63293993,1.64652,,,,,,,,,,,,,, -42000,0.6752739,1.6009246,,,,,,,,,,,,,, -42100,0.6103651,1.659274,,,,,,,,,,,,,, -42200,0.86801016,1.656424,,,,,,,,,,,,,, -42300,0.55787104,1.5161105,,,,,,,,,,,,,, -42400,0.5606358,1.611396,,,,,,,,,,,,,, -42500,0.5720862,1.5621283,,,,,,,,,,,,,, -42600,0.7013219,1.58672,,,,,,,,,,,,,, -42700,0.76413244,1.5727352,,,,,,,,,,,,,, -42800,0.80455524,1.5467516,,,,,,,,,,,,,, -42900,0.61602473,1.5891907,,,,,,,,,,,,,, -43000,0.6731836,1.5592307,,,,,,,,,,,,,, -43100,0.6841854,1.5336698,,,,,,,,,,,,,, -43200,0.7181046,1.6234996,,,,,,,,,,,,,, -43300,0.73811716,1.5734265,,,,,,,,,,,,,, -43400,0.59639996,1.545397,,,,,,,,,,,,,, -43500,0.7607953,1.6092288,,,,,,,,,,,,,, -43600,0.5473247,1.5508999,,,,,,,,,,,,,, -43700,0.76143456,1.6715244,,,,,,,,,,,,,, -43705,,,0.4322496,0.1538579260060887,0.68366253,0.2103845448313815,5348.0,0.43958446,0.152539963032925,2472.0,34601.8858397007,37844.95906472206,34601.8858397007,3239.886257171631,1.3179125785827637,0.0 -43800,0.66804594,1.6041749,,,,,,,,,,,,,, -43900,0.77746236,1.5735333,,,,,,,,,,,,,, -44000,0.671585,1.5454912,,,,,,,,,,,,,, -44100,0.6649878,1.5462111,,,,,,,,,,,,,, -44200,0.8755143,1.6026721,,,,,,,,,,,,,, -44300,0.67478883,1.4901216,,,,,,,,,,,,,, -44400,0.6477472,1.5597341,,,,,,,,,,,,,, -44500,0.6443104,1.5607189,,,,,,,,,,,,,, -44600,0.67061263,1.6181599,,,,,,,,,,,,,, -44700,0.6718259,1.5379142,,,,,,,,,,,,,, -44800,0.63312125,1.559132,,,,,,,,,,,,,, -44900,0.7221083,1.5346404,,,,,,,,,,,,,, -45000,0.605685,1.5512315,,,,,,,,,,,,,, -45100,0.7542351,1.5547299,,,,,,,,,,,,,, -45200,0.5988997,1.5452695,,,,,,,,,,,,,, -45300,0.7215485,1.5620365,,,,,,,,,,,,,, -45400,0.7538939,1.5894244,,,,,,,,,,,,,, -45500,0.6274266,1.5337876,,,,,,,,,,,,,, -45527,,,0.3753733,0.1400401803932237,0.67537385,0.2057889299748013,5348.0,0.43343684,0.1485182702658785,2472.0,36041.96706032753,39417.02039551735,36041.96706032753,3371.7308938503265,1.3760406970977783,0.0 -45600,0.6978997,1.5599179,,,,,,,,,,,,,, -45700,0.6815261,1.5337324,,,,,,,,,,,,,, -45800,0.6842943,1.5899708,,,,,,,,,,,,,, -45900,0.6931018,1.5129412,,,,,,,,,,,,,, -46000,0.7875954,1.5485144,,,,,,,,,,,,,, -46100,0.63954365,1.5413086,,,,,,,,,,,,,, -46200,0.75335616,1.5596968,,,,,,,,,,,,,, -46300,0.7696961,1.4718865,,,,,,,,,,,,,, -46400,0.68042713,1.5320835,,,,,,,,,,,,,, -46500,0.62139714,1.5036125,,,,,,,,,,,,,, -46600,0.73233795,1.527774,,,,,,,,,,,,,, -46700,0.62468636,1.5477738,,,,,,,,,,,,,, -46800,0.6891491,1.5018693,,,,,,,,,,,,,, -46900,0.60353065,1.5606762,,,,,,,,,,,,,, -47000,0.5514547,1.5290184,,,,,,,,,,,,,, -47100,0.65550935,1.5486356,,,,,,,,,,,,,, -47200,0.6541527,1.5017186,,,,,,,,,,,,,, -47300,0.7337026,1.5610557,,,,,,,,,,,,,, -47343,,,0.34525725,0.1258080321158192,0.6423175,0.1970900875677032,5348.0,0.41392162,0.1407795584262588,2472.0,37481.884147167206,40989.77420902252,37481.884147167206,3504.4280862808228,1.436626672744751,0.0 -47400,0.6313338,1.4801416,,,,,,,,,,,,,, -47500,0.63032305,1.5007262,,,,,,,,,,,,,, -47600,0.69243634,1.4888349,,,,,,,,,,,,,, -47700,0.58403677,1.5184139,,,,,,,,,,,,,, -47800,0.7267033,1.5219396,,,,,,,,,,,,,, -47900,0.646938,1.4720396,,,,,,,,,,,,,, -48000,0.77493984,1.5419494,,,,,,,,,,,,,, -48100,0.5907136,1.5425595,,,,,,,,,,,,,, -48200,0.7330371,1.5184649,,,,,,,,,,,,,, -48300,0.7197944,1.4915216,,,,,,,,,,,,,, -48400,0.6766783,1.5563955,,,,,,,,,,,,,, -48500,0.5814251,1.4913288,,,,,,,,,,,,,, -48600,0.6915576,1.5284853,,,,,,,,,,,,,, -48700,0.67563474,1.492789,,,,,,,,,,,,,, -48800,0.71668625,1.4722307,,,,,,,,,,,,,, -48900,0.6782323,1.4622843,,,,,,,,,,,,,, -49000,0.6653295,1.5001835,,,,,,,,,,,,,, -49100,0.7079968,1.4984335,,,,,,,,,,,,,, -49160,,,0.31973082,0.120110860171262,0.6218428,0.192697220425384,5348.0,0.40082654,0.1379765604371051,2472.0,38922.35270643234,42563.32991623879,38922.35270643234,3637.337597131729,1.5352272987365725,0.0 -49200,0.61670375,1.426037,,,,,,,,,,,,,, -49300,0.7016167,1.506561,,,,,,,,,,,,,, -49400,0.68669873,1.4920032,,,,,,,,,,,,,, -49500,0.6432609,1.4527804,,,,,,,,,,,,,, -49600,0.6904145,1.469609,,,,,,,,,,,,,, -49700,0.6028577,1.5126352,,,,,,,,,,,,,, -49800,0.7533127,1.4695098,,,,,,,,,,,,,, -49900,0.7300757,1.5264553,,,,,,,,,,,,,, -50000,0.5846775,1.447487,,,,,,,,,,,,,, -50100,0.6706398,1.5476352,,,,,,,,,,,,,, -50200,0.97997296,1.5161746,,,,,,,,,,,,,, -50300,0.66941446,1.5155104,,,,,,,,,,,,,, -50400,0.67126447,1.4805677,,,,,,,,,,,,,, -50500,0.71173584,1.4675653,,,,,,,,,,,,,, -50600,0.62244225,1.4166373,,,,,,,,,,,,,, -50700,0.8196722,1.4526021,,,,,,,,,,,,,, -50800,0.70517915,1.4540199,,,,,,,,,,,,,, -50900,0.64524966,1.4132742,,,,,,,,,,,,,, -50992,,,0.30350104,0.1163183008531079,0.6097256,0.1875319810382613,5348.0,0.38197964,0.1339751792496902,2472.0,40362.86497974396,44136.597618341446,40362.86497974396,3769.957051992416,1.5926856994628906,0.0 -51000,1.0038102,1.5382289,,,,,,,,,,,,,, -51100,0.5696869,1.455752,,,,,,,,,,,,,, -51200,0.5898462,1.4775623,,,,,,,,,,,,,, -51300,0.6902664,1.475994,,,,,,,,,,,,,, -51400,0.6908597,1.4648242,,,,,,,,,,,,,, -51500,0.7616576,1.4449104,,,,,,,,,,,,,, -51600,0.7082787,1.4122854,,,,,,,,,,,,,, -51700,0.6909919,1.4761053,,,,,,,,,,,,,, -51800,0.6346449,1.4923935,,,,,,,,,,,,,, -51900,0.6287149,1.4741179,,,,,,,,,,,,,, -52000,0.6937593,1.4877746,,,,,,,,,,,,,, -52100,0.67809135,1.4329785,,,,,,,,,,,,,, -52200,0.6744104,1.5127203,,,,,,,,,,,,,, -52300,0.87034076,1.4163369,,,,,,,,,,,,,, -52400,0.6465512,1.4740264,,,,,,,,,,,,,, -52500,0.5962335,1.38627,,,,,,,,,,,,,, -52600,0.607738,1.4209331,,,,,,,,,,,,,, -52700,0.61743224,1.3770486,,,,,,,,,,,,,, -52800,0.711647,1.4458872,,,,,,,,,,,,,, -52819,,,0.31450352,0.1154854943438069,0.587471,0.1810054355696728,5348.0,0.3687143,0.1252412000081246,2472.0,41803.16771149635,45707.625512599945,41803.16771149635,3900.5444436073294,1.651547908782959,0.0 -52900,0.6610875,1.4035047,,,,,,,,,,,,,, -53000,0.6801522,1.4000971,,,,,,,,,,,,,, -53100,0.6743716,1.4218717,,,,,,,,,,,,,, -53200,0.6113,1.41102,,,,,,,,,,,,,, -53300,0.63945794,1.3594806,,,,,,,,,,,,,, -53400,0.5882161,1.4365168,,,,,,,,,,,,,, -53500,0.7746914,1.4480169,,,,,,,,,,,,,, -53600,0.81256664,1.4033735,,,,,,,,,,,,,, -53700,0.74469763,1.3976856,,,,,,,,,,,,,, -53800,0.5994197,1.3920316,,,,,,,,,,,,,, -53900,0.75493944,1.4570185,,,,,,,,,,,,,, -54000,0.7125042,1.4325917,,,,,,,,,,,,,, -54100,0.7436382,1.436643,,,,,,,,,,,,,, -54200,0.661311,1.4144965,,,,,,,,,,,,,, -54300,0.6989791,1.3739057,,,,,,,,,,,,,, -54400,0.6896943,1.4124532,,,,,,,,,,,,,, -54500,0.60605,1.446781,,,,,,,,,,,,,, -54600,0.5883606,1.3505918,,,,,,,,,,,,,, -54636,,,0.29329264,0.1079460585277596,0.56704783,0.1752126437336474,5348.0,0.35127002,0.1210367030243942,2472.0,43243.308198452,47279.961285591125,43243.308198452,4032.609624862671,1.7056090831756592,0.0 -54700,0.90115446,1.399949,,,,,,,,,,,,,, -54800,0.72625715,1.3999096,,,,,,,,,,,,,, -54900,0.6211032,1.3775835,,,,,,,,,,,,,, -55000,0.60457766,1.3869263,,,,,,,,,,,,,, -55100,0.8100683,1.3361,,,,,,,,,,,,,, -55200,0.7865427,1.4156581,,,,,,,,,,,,,, -55300,0.69890815,1.3802054,,,,,,,,,,,,,, -55400,0.78765553,1.3919517,,,,,,,,,,,,,, -55500,0.60654444,1.3893075,,,,,,,,,,,,,, -55600,0.6473718,1.3180696,,,,,,,,,,,,,, -55700,0.72452474,1.4129856,,,,,,,,,,,,,, -55800,0.66309065,1.3506831,,,,,,,,,,,,,, -55900,0.6587862,1.4191804,,,,,,,,,,,,,, -56000,0.5957654,1.3743137,,,,,,,,,,,,,, -56100,0.8675461,1.3340713,,,,,,,,,,,,,, -56200,0.6961351,1.4163857,,,,,,,,,,,,,, -56300,0.6343739,1.3868493,,,,,,,,,,,,,, -56400,0.68804175,1.347014,,,,,,,,,,,,,, -56468,,,0.2854557,0.104970891640451,0.5525805,0.1698446566322639,5348.0,0.33652857,0.1180102776592935,2472.0,44683.90547633171,48853.7058467865,44683.90547633171,4165.615864276886,1.7686245441436768,0.0 -56500,0.7122465,1.3617393,,,,,,,,,,,,,, -56600,0.7218478,1.3818696,,,,,,,,,,,,,, -56700,0.60301757,1.312976,,,,,,,,,,,,,, -56800,0.7091848,1.361957,,,,,,,,,,,,,, -56900,0.75668454,1.3604338,,,,,,,,,,,,,, -57000,0.70232534,1.3251172,,,,,,,,,,,,,, -57100,0.7735984,1.3628486,,,,,,,,,,,,,, -57200,0.71422863,1.368932,,,,,,,,,,,,,, -57300,0.6855916,1.3332752,,,,,,,,,,,,,, -57400,0.7511267,1.4162842,,,,,,,,,,,,,, -57500,0.7310927,1.3838999,,,,,,,,,,,,,, -57600,0.6950825,1.3960708,,,,,,,,,,,,,, -57700,0.6315636,1.3064014,,,,,,,,,,,,,, -57800,0.69205505,1.288863,,,,,,,,,,,,,, -57900,0.8746942,1.3205484,,,,,,,,,,,,,, -58000,0.68799716,1.3060325,,,,,,,,,,,,,, -58100,0.8640823,1.3292459,,,,,,,,,,,,,, -58200,0.6348675,1.3692214,,,,,,,,,,,,,, -58296,,,0.25223732,0.0949610668576031,0.5288609,0.1617154387557083,5348.0,0.32527593,0.1107387321511993,2472.0,46124.3394947052,50427.39268708229,46124.3394947052,4298.656836748123,1.9026412963867188,0.0 -58300,0.77000874,1.3882353,,,,,,,,,,,,,, -58400,0.66956174,1.2821906,,,,,,,,,,,,,, -58500,0.6768844,1.3158988,,,,,,,,,,,,,, -58600,0.8134163,1.32024,,,,,,,,,,,,,, -58700,0.71586096,1.3402594,,,,,,,,,,,,,, -58800,0.73082346,1.350599,,,,,,,,,,,,,, -58900,0.71410644,1.3265266,,,,,,,,,,,,,, -59000,0.7387362,1.3295012,,,,,,,,,,,,,, -59100,0.699241,1.3545676,,,,,,,,,,,,,, -59200,0.8910446,1.3506467,,,,,,,,,,,,,, -59300,0.86189926,1.3628997,,,,,,,,,,,,,, -59400,0.6864007,1.371931,,,,,,,,,,,,,, -59500,0.73584306,1.3319869,,,,,,,,,,,,,, -59600,0.82643414,1.381556,,,,,,,,,,,,,, -59700,0.71025205,1.2951105,,,,,,,,,,,,,, -59800,0.6878034,1.262828,,,,,,,,,,,,,, -59900,0.6957874,1.3038542,,,,,,,,,,,,,, -60000,0.7871558,1.2779839,,,,,,,,,,,,,, -60100,0.860116,1.2879398,,,,,,,,,,,,,, -60135,,,0.2326758,0.0873209403468769,0.5112117,0.1576701391235506,5348.0,0.30733588,0.104645258261735,2472.0,47564.77532219887,51999.75342416763,47564.77532219887,4430.442895412445,1.963092565536499,0.0 -60200,0.7414892,1.3113872,,,,,,,,,,,,,, -60300,0.6869248,1.3363204,,,,,,,,,,,,,, -60400,0.8337653,1.3203756,,,,,,,,,,,,,, -60500,0.7177907,1.3048322,,,,,,,,,,,,,, -60600,0.8090582,1.2947254,,,,,,,,,,,,,, -60700,0.68860024,1.2774476,,,,,,,,,,,,,, -60800,0.7008119,1.3141083,,,,,,,,,,,,,, -60900,0.78728616,1.2493674,,,,,,,,,,,,,, -61000,0.8386161,1.3061615,,,,,,,,,,,,,, -61100,0.7897481,1.2935615,,,,,,,,,,,,,, -61200,0.8604633,1.2759347,,,,,,,,,,,,,, -61300,0.77856797,1.3072561,,,,,,,,,,,,,, -61400,0.73691404,1.3240473,,,,,,,,,,,,,, -61500,0.82772225,1.3175384,,,,,,,,,,,,,, -61600,0.78015774,1.2355665,,,,,,,,,,,,,, -61700,0.70552593,1.2905884,,,,,,,,,,,,,, -61800,0.70396674,1.3264378,,,,,,,,,,,,,, -61900,0.66327775,1.2492697,,,,,,,,,,,,,, -61955,,,0.21656099,0.0802586862794826,0.49202898,0.1502650202264981,5348.0,0.29639518,0.1002173339020575,2472.0,49004.667399168015,53572.28482842445,49004.667399168015,4562.946071147919,2.0217151641845703,0.0 -62000,0.75058347,1.2862899,,,,,,,,,,,,,, -62100,0.86506,1.2342225,,,,,,,,,,,,,, -62200,1.0361538,1.2645564,,,,,,,,,,,,,, -62300,1.0494337,1.2299757,,,,,,,,,,,,,, -62400,0.78003424,1.2947929,,,,,,,,,,,,,, -62500,0.77356046,1.2938298,,,,,,,,,,,,,, -62600,0.7744524,1.2877434,,,,,,,,,,,,,, -62700,0.74823934,1.2977115,,,,,,,,,,,,,, -62800,0.8896309,1.267953,,,,,,,,,,,,,, -62900,0.8863283,1.2555559,,,,,,,,,,,,,, -63000,0.7221459,1.2382569,,,,,,,,,,,,,, -63100,0.7959858,1.2351012,,,,,,,,,,,,,, -63200,0.82125366,1.2943794,,,,,,,,,,,,,, -63300,0.82002836,1.2457949,,,,,,,,,,,,,, -63400,0.7666597,1.2728084,,,,,,,,,,,,,, -63500,0.8279949,1.2393421,,,,,,,,,,,,,, -63600,0.81633455,1.2732863,,,,,,,,,,,,,, -63700,0.71325445,1.2434368,,,,,,,,,,,,,, -63784,,,0.21301003,0.0795733142829309,0.4686298,0.1450225436148952,5348.0,0.28154856,0.0963378221924319,2472.0,50444.830134153366,55144.689125299454,50444.830134153366,4695.051886081696,2.079317331314087,0.0 -63800,0.7684535,1.2633102,,,,,,,,,,,,,, -63900,0.8467965,1.2268765,,,,,,,,,,,,,, -64000,0.7420195,1.2433848,,,,,,,,,,,,,, -64100,0.8295,1.2598693,,,,,,,,,,,,,, -64200,0.7987566,1.2598315,,,,,,,,,,,,,, -64300,0.8320935,1.2601141,,,,,,,,,,,,,, -64400,0.90721446,1.2613287,,,,,,,,,,,,,, -64500,0.73812,1.20023,,,,,,,,,,,,,, -64600,0.9170532,1.2130939,,,,,,,,,,,,,, -64700,0.82095915,1.3176451,,,,,,,,,,,,,, -64800,0.8204346,1.211004,,,,,,,,,,,,,, -64900,0.8732252,1.2535555,,,,,,,,,,,,,, -65000,0.8060598,1.2265311,,,,,,,,,,,,,, -65100,0.7053633,1.2537627,,,,,,,,,,,,,, -65200,0.8819041,1.1999384,,,,,,,,,,,,,, -65300,0.9286125,1.2311976,,,,,,,,,,,,,, -65400,0.6875889,1.1939583,,,,,,,,,,,,,, -65500,0.8143145,1.1824385,,,,,,,,,,,,,, -65600,0.82435936,1.2064219,,,,,,,,,,,,,, -65615,,,0.18987854,0.0696951776786194,0.45930022,0.1404655473705552,5348.0,0.27360567,0.0922551946864907,2472.0,51885.259873628616,56718.33281850815,51885.259873628616,4828.119354486465,2.1474575996398926,0.0 -65700,0.7942514,1.2542651,,,,,,,,,,,,,, -65800,0.9941205,1.2186149,,,,,,,,,,,,,, -65900,0.86457473,1.2644155,,,,,,,,,,,,,, -66000,0.83469,1.2182889,,,,,,,,,,,,,, -66100,0.77374697,1.2160386,,,,,,,,,,,,,, -66200,0.79736227,1.2030035,,,,,,,,,,,,,, -66300,0.696149,1.1756853,,,,,,,,,,,,,, -66400,0.7493356,1.2272245,,,,,,,,,,,,,, -66500,0.84340376,1.2003931,,,,,,,,,,,,,, -66600,1.0452592,1.2433478,,,,,,,,,,,,,, -66700,0.82872534,1.1647713,,,,,,,,,,,,,, -66800,0.9874321,1.1816311,,,,,,,,,,,,,, -66900,0.8564304,1.241717,,,,,,,,,,,,,, -67000,0.8430661,1.1905861,,,,,,,,,,,,,, -67100,0.85929704,1.1787484,,,,,,,,,,,,,, -67200,0.81805354,1.191258,,,,,,,,,,,,,, -67300,0.8596142,1.1974113,,,,,,,,,,,,,, -67400,0.96457326,1.1514915,,,,,,,,,,,,,, -67452,,,0.18605858,0.0703033976144076,0.43531832,0.1353389265956728,5348.0,0.2592684,0.0881928787601811,2472.0,53325.22155690193,58291.54190802574,53325.22155690193,4961.224667787552,2.211995840072632,0.0 -67500,0.74140584,1.2016522,,,,,,,,,,,,,, -67600,0.84676516,1.2302494,,,,,,,,,,,,,, -67700,0.9384754,1.1704842,,,,,,,,,,,,,, -67800,0.9001181,1.1709098,,,,,,,,,,,,,, -67900,0.7979869,1.1812414,,,,,,,,,,,,,, -68000,0.8295623,1.1170102,,,,,,,,,,,,,, -68100,0.83350736,1.2113826,,,,,,,,,,,,,, -68200,0.8618705,1.1592165,,,,,,,,,,,,,, -68300,0.8023299,1.1233071,,,,,,,,,,,,,, -68400,0.7724861,1.1291355,,,,,,,,,,,,,, -68500,0.88078064,1.1592156,,,,,,,,,,,,,, -68600,0.7586757,1.1538565,,,,,,,,,,,,,, -68700,0.79165,1.1275192,,,,,,,,,,,,,, -68800,0.7948573,1.1514136,,,,,,,,,,,,,, -68900,0.9744678,1.192395,,,,,,,,,,,,,, -69000,0.74030185,1.142476,,,,,,,,,,,,,, -69100,0.8255774,1.151144,,,,,,,,,,,,,, -69200,0.7284367,1.1377615,,,,,,,,,,,,,, -69288,,,0.1981465,0.0684959758019311,0.42125234,0.1299419755351091,5348.0,0.24975313,0.0854711270895537,2472.0,54765.25285768509,59863.52037596703,54765.25285768509,5093.022603034973,2.2826132774353027,0.0 -69300,0.77308136,1.1199888,,,,,,,,,,,,,, -69400,0.97921985,1.2080991,,,,,,,,,,,,,, -69500,0.8492529,1.1635231,,,,,,,,,,,,,, -69600,0.9727568,1.1660384,,,,,,,,,,,,,, -69700,0.8199181,1.1671427,,,,,,,,,,,,,, -69800,0.83871007,1.1748337,,,,,,,,,,,,,, -69900,0.86891055,1.1488715,,,,,,,,,,,,,, -70000,0.82812715,1.1162452,,,,,,,,,,,,,, -70100,0.8074804,1.1657026,,,,,,,,,,,,,, -70200,0.98428154,1.1232909,,,,,,,,,,,,,, -70300,0.8919305,1.1480291,,,,,,,,,,,,,, -70400,0.84322006,1.1690514,,,,,,,,,,,,,, -70500,0.9410386,1.1116462,,,,,,,,,,,,,, -70600,0.8596024,1.1626686,,,,,,,,,,,,,, -70700,1.0358357,1.1326407,,,,,,,,,,,,,, -70800,0.8719262,1.1103605,,,,,,,,,,,,,, -70900,0.9321149,1.1318438,,,,,,,,,,,,,, -71000,0.9858232,1.150202,,,,,,,,,,,,,, -71100,0.87543637,1.118443,,,,,,,,,,,,,, -71102,,,0.13899502,0.0517284898186257,0.41104484,0.1253173967193488,5348.0,0.23893896,0.0804947900798245,2472.0,56205.272963523865,61436.927609205246,56205.272963523865,5226.270207643509,2.345404863357544,0.0 -71200,0.92318034,1.1390055,,,,,,,,,,,,,, -71300,0.9286639,1.1100347,,,,,,,,,,,,,, -71400,0.82403886,1.1366022,,,,,,,,,,,,,, -71500,0.8942087,1.1398598,,,,,,,,,,,,,, -71600,0.8549559,1.1609577,,,,,,,,,,,,,, -71700,0.8039099,1.1212486,,,,,,,,,,,,,, -71800,0.86799467,1.1101267,,,,,,,,,,,,,, -71900,0.993125,1.1582572,,,,,,,,,,,,,, -72000,0.8440722,1.1026016,,,,,,,,,,,,,, -72100,0.7855169,1.0879778,,,,,,,,,,,,,, -72200,0.9407645,1.0673908,,,,,,,,,,,,,, -72300,0.92161685,1.1007478,,,,,,,,,,,,,, -72400,0.84110427,1.1458231,,,,,,,,,,,,,, -72500,0.82359856,1.098171,,,,,,,,,,,,,, -72600,1.0102868,1.1400443,,,,,,,,,,,,,, -72700,0.76727694,1.1013039,,,,,,,,,,,,,, -72800,0.9694186,1.1098102,,,,,,,,,,,,,, -72900,1.3998389,1.1038568,,,,,,,,,,,,,, -72941,,,0.13666794,0.0507454263274232,0.39726907,0.1221699798217751,5348.0,0.2305465,0.0770621331220929,2472.0,57645.54037809372,63009.80436420441,57645.54037809372,5358.731348514557,2.414510488510132,0.0 -73000,0.81815577,1.0694532,,,,,,,,,,,,,, -73100,1.0484476,1.136665,,,,,,,,,,,,,, -73200,0.954895,1.1431805,,,,,,,,,,,,,, -73300,0.9647169,1.1332189,,,,,,,,,,,,,, -73400,0.97980887,1.1399792,,,,,,,,,,,,,, -73500,1.0270953,1.1007959,,,,,,,,,,,,,, -73600,0.95477176,1.0840187,,,,,,,,,,,,,, -73700,0.90896654,1.0880811,,,,,,,,,,,,,, -73800,0.8477736,1.1071875,,,,,,,,,,,,,, -73900,1.0739344,1.0887495,,,,,,,,,,,,,, -74000,1.176678,1.107484,,,,,,,,,,,,,, -74100,0.9789104,1.101093,,,,,,,,,,,,,, -74200,0.8835876,1.0477486,,,,,,,,,,,,,, -74300,0.9450731,1.0392878,,,,,,,,,,,,,, -74400,0.8525702,1.093463,,,,,,,,,,,,,, -74500,0.9754417,1.108377,,,,,,,,,,,,,, -74600,0.8964645,1.0311047,,,,,,,,,,,,,, -74700,0.8679556,1.0850697,,,,,,,,,,,,,, -74775,,,0.16922173,0.0638789305164062,0.38635293,0.1179122778222964,5348.0,0.22601831,0.076066865720147,2472.0,59085.41255426407,64581.16864061356,59085.41255426407,5490.084238290787,2.474437475204468,0.0 -74800,1.2283564,1.0566734,,,,,,,,,,,,,, -74900,0.8070144,1.0373735,,,,,,,,,,,,,, -75000,0.9939884,1.1203914,,,,,,,,,,,,,, -75100,1.0539005,1.0872577,,,,,,,,,,,,,, -75200,0.9859313,1.076123,,,,,,,,,,,,,, -75300,0.91065884,1.0968997,,,,,,,,,,,,,, -75400,0.94731426,1.1126922,,,,,,,,,,,,,, -75500,1.0513108,1.1112012,,,,,,,,,,,,,, -75600,0.8551008,1.0861177,,,,,,,,,,,,,, -75700,0.9354266,1.0808603,,,,,,,,,,,,,, -75800,0.9093265,1.0755711,,,,,,,,,,,,,, -75900,1.0236769,1.0079939,,,,,,,,,,,,,, -76000,0.97213143,1.0807087,,,,,,,,,,,,,, -76100,1.0783789,1.1080782,,,,,,,,,,,,,, -76200,1.0914472,1.0911489,,,,,,,,,,,,,, -76300,0.9273085,1.0793705,,,,,,,,,,,,,, -76400,0.96384466,1.0497948,,,,,,,,,,,,,, -76500,0.9304237,1.0409421,,,,,,,,,,,,,, -76600,0.8864931,1.0281558,,,,,,,,,,,,,, -76616,,,0.1668892,0.060572127753713,0.3823363,0.1158654913735675,5348.0,0.22135468,0.0743810046107285,2472.0,60525.80380296707,66152.88785719872,60525.80380296707,5621.268357515335,2.5391104221343994,0.0 -76700,0.9568187,1.0774677,,,,,,,,,,,,,, -76800,1.0368989,1.0675988,,,,,,,,,,,,,, -76900,0.8687115,1.0583116,,,,,,,,,,,,,, -77000,0.86684304,1.0579708,,,,,,,,,,,,,, -77100,1.0356873,1.0296886,,,,,,,,,,,,,, -77200,0.9916616,1.0959378,,,,,,,,,,,,,, -77300,0.91827786,1.0889926,,,,,,,,,,,,,, -77313,,,,,,,,,,,61068.18580174446,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 3136fc9e3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -135.82221722602844,0.0,36.68635892868042,1,0,36.68635892868042,31.27949,2472,1.0976986980277457,172.50863242149353,32.477173,1.3851183971898189,31.16359,5348,1.043156299178389 -262.77818775177,0.0273528099060058,1477.175707578659,1795,0,1477.175707578659,2.9762988,2472,0.5722178213799687,1740.054486036301,3.4333935,0.631694270799768,3.3112328,5348,0.6119022562924201 -393.5232543945313,0.0794527530670166,2917.344397068024,3620,0,2917.344397068024,0.64018565,2472,0.2091686470456807,3311.098723173141,0.8712847,0.2754521127493194,0.91919994,5348,0.2688917423752377 -524.8741371631622,0.1293325424194336,4357.893091678619,5441,0,4357.893091678619,0.49422714,2472,0.1648894034489062,4883.128641843796,0.5577522,0.189332902774414,0.7517841,5348,0.224654122054124 -657.5992331504822,0.1791727542877197,5798.3205742836,7243,0,5798.3205742836,0.4213605,2472,0.1424044847967826,6456.4077224731445,0.548139,0.1843557642721557,0.6662565,5348,0.2017243210365235 -791.0809001922607,0.2378764152526855,7238.686999559402,9061,0,7238.686999559402,0.383486,2472,0.1302784717567485,8030.393237113953,0.44564787,0.1570601279453759,0.6208973,5348,0.1878312752831227 -924.5182108879088,0.2955894470214844,8678.63923573494,10883,0,8678.63923573494,0.36785704,2472,0.1252615115877561,9603.92083120346,0.4268927,0.152191394170314,0.58681077,5348,0.1756567577744093 -1056.8081135749817,0.352435827255249,10118.762867212296,12713,0,10118.762867212296,0.3447162,2472,0.1157760038998232,11176.468502998352,0.40124416,0.1386010928961748,0.5638952,5348,0.169188140224181 -1186.2805242538452,0.4058682918548584,11559.307371377943,14529,0,11559.307371377943,0.33356386,2472,0.1123839701013547,12746.616312265396,0.3907849,0.1376007402372643,0.5494791,5348,0.1650462940614229 -1318.4176275730133,0.4609172344207763,12999.90366244316,16345,0,12999.90366244316,0.31609333,2472,0.1079357341620457,14319.482456922531,0.3892411,0.1341954468019204,0.5266665,5348,0.1597748534906398 -1449.492838382721,0.5100352764129639,14440.013365745544,18162,0,14440.013365745544,0.30815408,2472,0.1050718014339975,15890.79479265213,0.37243292,0.1307569937531525,0.5188964,5348,0.1562702144298444 -1578.4773724079132,0.5645277500152588,15880.443354845049,19980,0,15880.443354845049,0.30493766,2472,0.1033656287449475,17460.340955257416,0.37521964,0.1368282036148591,0.50475585,5348,0.1534027824710119 -1708.8097307682035,0.6217460632324219,17320.393052101135,21786,0,17320.393052101135,0.29324007,2472,0.0993642475575325,19030.75717544556,0.31174734,0.1106088718335526,0.49118516,5348,0.1480830686349286 -1839.3516829013824,0.6704885959625244,18760.32162618637,23592,0,18760.32162618637,0.2847809,2472,0.0961753295553795,20601.35322713852,0.32217106,0.1146681135829263,0.4742398,5348,0.1427440454927252 -1970.1682348251345,0.7266736030578613,20200.92087650299,25405,0,20200.92087650299,0.2735938,2472,0.0919505209920175,22172.902831554413,0.31540608,0.1119117622066246,0.46801168,5348,0.1400021240236732 -2101.914893388748,0.7841992378234863,21641.44178009033,27227,0,21641.44178009033,0.26831707,2472,0.0913411736030711,23745.30660867691,0.3136039,0.1112789305447235,0.45962992,5348,0.1382160132075653 -2231.9260606765747,0.8400917053222656,23081.622501134872,29041,0,23081.622501134872,0.2715291,2472,0.0898990514492312,25315.63108062744,0.27701193,0.1009942350132865,0.46328136,5348,0.1380132654933045 -2362.35418009758,0.9039266109466552,24521.81897425652,30847,0,24521.81897425652,0.25223556,2472,0.0851664533950805,26886.398028612137,0.24538893,0.0897736678368581,0.44567695,5348,0.133572125085685 -2492.959435224533,0.9592113494873048,25961.83100414276,32654,0,25961.83100414276,0.25065094,2472,0.0824853248837162,28457.146093845367,0.27409798,0.0995876673510596,0.43247974,5348,0.1298164650453286 -2623.59248495102,1.017035722732544,27402.23986577988,34475,0,27402.23986577988,0.24289979,2472,0.0819166006540328,30028.323790550232,0.25031212,0.0910184107825835,0.42782575,5348,0.1265145736987941 -2754.237830877304,1.070730209350586,28842.823963165283,36284,0,28842.823963165283,0.24147299,2472,0.0803119858631405,31599.68351483345,0.26957318,0.0937386351556249,0.42037314,5348,0.126253898066173 -2885.6785418987274,1.1268370151519775,30283.236042499542,38084,0,30283.236042499542,0.2352042,2472,0.0776308573517762,33171.66923260689,0.21194768,0.0783649849035137,0.41112727,5348,0.1200556108016258 -3016.282743215561,1.1808912754058838,31723.70522546768,39891,0,31723.70522546768,0.22744554,2472,0.0754981414904637,34742.872770786285,0.20022973,0.0736595003163257,0.40568483,5348,0.1178350405978161 -3146.2064123153687,1.2387712001800537,33163.71019721031,41721,0,33163.71019721031,0.22631016,2472,0.0739950845977291,36312.93751120567,0.22058325,0.0810318751248466,0.39636406,5348,0.1166282089653108 -3285.1928341388702,1.3039309978485107,34603.71418786049,43554,0,34603.71418786049,0.21916676,2472,0.0725529624438892,37892.071621418,0.14703715,0.0556202457259144,0.39308006,5348,0.1142435096594803 -3419.3971271514893,1.3684093952178955,36044.21834445,45371,0,36044.21834445,0.21107632,2472,0.0702171307862612,39466.92186617851,0.12569274,0.048289319630545,0.38317782,5348,0.1127566930882338 -3552.0800442695618,1.4292211532592771,37484.79524970055,47193,0,37484.79524970055,0.20164764,2472,0.068023480186054,41040.31870007515,0.12504135,0.0475356763616586,0.3714938,5348,0.1090396516601175 -3684.472150802612,1.4894332885742188,38925.28243923187,49019,0,38925.28243923187,0.199988,2472,0.0672922633193183,42613.33491516113,0.12188379,0.0461781700359377,0.36238962,5348,0.105496393987082 -3815.268126010895,1.547067642211914,40365.47981357575,50843,0,40365.47981357575,0.19808494,2472,0.066093880121057,44184.46557068825,0.110726975,0.0443575473913589,0.36124983,5348,0.1057474149666431 -3946.959716320038,1.6031808853149414,41805.494245529175,52655,0,41805.494245529175,0.19376087,2472,0.065057989559848,45756.3062517643,0.102106266,0.0398258617932727,0.35422248,5348,0.1018758990895662 -4078.56554722786,1.6603331565856934,43245.7050819397,54477,0,43245.7050819397,0.18844098,2472,0.0630065200170617,47328.256929636,0.115759045,0.0439879638476152,0.34668103,5348,0.1012869652529036 -4211.784162521362,1.719299077987671,44686.1141409874,56295,0,44686.1141409874,0.18400313,2472,0.0606909999390652,48902.020424842834,0.10086589,0.0365662264259159,0.34202397,5348,0.097714743620688 -4345.063044786453,1.7812213897705078,46126.11980581284,58115,0,46126.11980581284,0.17934273,2472,0.0583754798610687,50475.445321798325,0.09474793,0.0364319198806733,0.3363182,5348,0.0949535128455159 -4476.3204646110535,1.841925859451294,47566.14169406891,59930,0,47566.14169406891,0.17598493,2472,0.0583754798610687,52046.8627216816,0.07673226,0.0305275637225844,0.33402327,5348,0.0943259603966131 -4608.684624910355,1.90018367767334,49006.05555701256,61747,0,49006.05555701256,0.17221192,2472,0.0561005829423354,53619.27761530876,0.07129591,0.0283815178793507,0.3258931,5348,0.092694324029466 -4741.300411224365,1.9632997512817385,50446.629410505295,63558,0,50446.629410505295,0.16798735,2472,0.0561818292608616,55192.60791397095,0.07846448,0.0303980282360063,0.3151095,5348,0.089740000193093 -4874.458735466003,2.022402763366699,51886.58189582825,65383,0,51886.58189582825,0.1647859,2472,0.0528507302012877,56765.85677433014,0.07131476,0.0278298072157149,0.31282473,5348,0.0885331685605877 -5006.690465927124,2.087104558944702,53327.39374256134,67201,0,53327.39374256134,0.16259553,2472,0.0525866796660776,58339.04134559631,0.06927087,0.0262163524758444,0.30948767,5348,0.0870560066424013 -5139.909769058228,2.1487863063812256,54769.413890361786,69011,0,54769.413890361786,0.15818636,2472,0.0513476733085532,59914.420617342,0.059926614,0.0233735345498937,0.30565014,5348,0.0857429738262355 -5272.942055225372,2.210370779037476,56209.6598546505,70839,0,56209.6598546505,0.15762733,2472,0.0506570796010805,61487.83983325958,0.054493688,0.0206153878214479,0.30036068,5348,0.0833293105612249 -5406.034177541733,2.274899482727051,57649.84510660172,72665,0,57649.84510660172,0.15622462,2472,0.0500477322121341,63061.260159015656,0.059538107,0.0218196704533824,0.29563853,5348,0.0819583498266989 -5537.869856595993,2.33532977104187,59089.76412558556,74496,0,59089.76412558556,0.15535073,2472,0.0497430585176609,64633.15510225296,0.057106175,0.0210327421923502,0.29521686,5348,0.0823541906021607 -5668.746239900589,2.404423713684082,60530.329689741135,76309,0,60530.329689741135,0.15429394,2472,0.04970243535839782,66204.74381828308,0.05786657,0.021383041405538594,0.2951434,5348,0.08191007656139877 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/measurements.csv deleted file mode 100644 index 79696073f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/measurements.csv +++ /dev/null @@ -1,816 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,40.92152,32.769222,,,,,,,,,,,,,, -1,,,32.477173,1.3851183971898189,31.16359,1.043156299178389,5348.0,31.27949,1.0976986980277457,2472.0,36.68635892868042,172.50863242149353,36.68635892868042,135.82221722602844,0.0,0.0 -100,0.6174293,5.9566407,,,,,,,,,,,,,, -200,0.73783255,5.840009,,,,,,,,,,,,,, -300,0.34706378,5.799478,,,,,,,,,,,,,, -400,0.7253602,5.807892,,,,,,,,,,,,,, -500,0.59083945,5.8152137,,,,,,,,,,,,,, -600,3.2302952,5.775944,,,,,,,,,,,,,, -700,0.39171255,5.6107545,,,,,,,,,,,,,, -800,0.6314805,5.494297,,,,,,,,,,,,,, -900,1.5350782,5.39001,,,,,,,,,,,,,, -1000,0.7643105,4.552539,,,,,,,,,,,,,, -1100,1.8247265,3.7731028,,,,,,,,,,,,,, -1200,2.0822587,3.304964,,,,,,,,,,,,,, -1300,1.2084554,3.1169362,,,,,,,,,,,,,, -1400,0.9344533,2.9075525,,,,,,,,,,,,,, -1500,1.003889,2.7078888,,,,,,,,,,,,,, -1600,0.7705406,2.6326618,,,,,,,,,,,,,, -1700,0.6356576,2.5154119,,,,,,,,,,,,,, -1795,,,3.4333935,0.631694270799768,3.3112328,0.6119022562924201,5348.0,2.9762988,0.5722178213799687,2472.0,1477.175707578659,1740.054486036301,1477.175707578659,262.77818775177,0.0273528099060058,0.0 -1800,0.683306,2.4284275,,,,,,,,,,,,,, -1900,0.6339968,2.3260148,,,,,,,,,,,,,, -2000,0.70040315,2.3285475,,,,,,,,,,,,,, -2100,1.3911455,2.1922247,,,,,,,,,,,,,, -2200,0.8706018,2.1899962,,,,,,,,,,,,,, -2300,0.56606466,2.168782,,,,,,,,,,,,,, -2400,0.67041004,2.151704,,,,,,,,,,,,,, -2500,0.5577227,2.0765994,,,,,,,,,,,,,, -2600,0.5907902,2.033521,,,,,,,,,,,,,, -2700,0.6167941,2.0100815,,,,,,,,,,,,,, -2800,0.554313,1.9567475,,,,,,,,,,,,,, -2900,0.65735865,2.0281556,,,,,,,,,,,,,, -3000,0.57896495,1.9621922,,,,,,,,,,,,,, -3100,0.6007559,1.9147712,,,,,,,,,,,,,, -3200,0.6032015,1.934362,,,,,,,,,,,,,, -3300,0.5486841,1.8875183,,,,,,,,,,,,,, -3400,0.5550409,1.8463827,,,,,,,,,,,,,, -3500,0.5639113,1.8914295,,,,,,,,,,,,,, -3600,0.53475404,1.7900176,,,,,,,,,,,,,, -3620,,,0.8712847,0.2754521127493194,0.91919994,0.2688917423752377,5348.0,0.64018565,0.2091686470456807,2472.0,2917.344397068024,3311.098723173141,2917.344397068024,393.5232543945313,0.0794527530670166,0.0 -3700,0.49705625,1.8584031,,,,,,,,,,,,,, -3800,0.49819893,1.8075178,,,,,,,,,,,,,, -3900,0.5873247,1.8030192,,,,,,,,,,,,,, -4000,0.5185176,1.7675631,,,,,,,,,,,,,, -4100,0.53567463,1.7883613,,,,,,,,,,,,,, -4200,0.6285896,1.7556887,,,,,,,,,,,,,, -4300,0.7425842,1.7804984,,,,,,,,,,,,,, -4400,0.4398586,1.6911284,,,,,,,,,,,,,, -4500,0.47801015,1.7544842,,,,,,,,,,,,,, -4600,0.48602605,1.7524198,,,,,,,,,,,,,, -4700,0.51303166,1.7378218,,,,,,,,,,,,,, -4800,0.74251467,1.6789486,,,,,,,,,,,,,, -4900,0.52016366,1.7080884,,,,,,,,,,,,,, -5000,0.904438,1.7248415,,,,,,,,,,,,,, -5100,0.66769224,1.7543066,,,,,,,,,,,,,, -5200,0.5875313,1.6907406,,,,,,,,,,,,,, -5300,0.4629129,1.6584718,,,,,,,,,,,,,, -5400,0.5212445,1.6633371,,,,,,,,,,,,,, -5441,,,0.5577522,0.189332902774414,0.7517841,0.224654122054124,5348.0,0.49422714,0.1648894034489062,2472.0,4357.893091678619,4883.128641843796,4357.893091678619,524.8741371631622,0.1293325424194336,0.0 -5500,0.49743593,1.6885083,,,,,,,,,,,,,, -5600,0.5012114,1.7004052,,,,,,,,,,,,,, -5700,0.5829448,1.7032764,,,,,,,,,,,,,, -5800,0.478348,1.6247966,,,,,,,,,,,,,, -5900,0.5469033,1.6558039,,,,,,,,,,,,,, -6000,0.4820966,1.63343,,,,,,,,,,,,,, -6100,0.6607889,1.7734121,,,,,,,,,,,,,, -6200,0.48116964,1.6161147,,,,,,,,,,,,,, -6300,0.5281223,1.633608,,,,,,,,,,,,,, -6400,0.6071144,1.6635168,,,,,,,,,,,,,, -6500,0.6032866,1.6296276,,,,,,,,,,,,,, -6600,0.62791,1.5790006,,,,,,,,,,,,,, -6700,0.56268793,1.6631346,,,,,,,,,,,,,, -6800,0.5173378,1.6504712,,,,,,,,,,,,,, -6900,0.46133292,1.5775113,,,,,,,,,,,,,, -7000,0.47029492,1.6033634,,,,,,,,,,,,,, -7100,0.5308754,1.6189103,,,,,,,,,,,,,, -7200,0.43289936,1.58313,,,,,,,,,,,,,, -7243,,,0.548139,0.1843557642721557,0.6662565,0.2017243210365235,5348.0,0.4213605,0.1424044847967826,2472.0,5798.3205742836,6456.4077224731445,5798.3205742836,657.5992331504822,0.1791727542877197,0.0 -7300,0.55981195,1.6129233,,,,,,,,,,,,,, -7400,0.5354738,1.5708711,,,,,,,,,,,,,, -7500,0.54705006,1.5838736,,,,,,,,,,,,,, -7600,0.46879378,1.5805967,,,,,,,,,,,,,, -7700,0.56403136,1.6125407,,,,,,,,,,,,,, -7800,0.55742764,1.5590851,,,,,,,,,,,,,, -7900,0.52219343,1.5988364,,,,,,,,,,,,,, -8000,0.49457854,1.6401505,,,,,,,,,,,,,, -8100,0.42577165,1.5464896,,,,,,,,,,,,,, -8200,0.65651566,1.6098796,,,,,,,,,,,,,, -8300,0.5749343,1.581837,,,,,,,,,,,,,, -8400,0.47428244,1.5697436,,,,,,,,,,,,,, -8500,0.5295934,1.5829824,,,,,,,,,,,,,, -8600,0.4658066,1.4649603,,,,,,,,,,,,,, -8700,0.6120036,1.5483733,,,,,,,,,,,,,, -8800,0.45377785,1.5764815,,,,,,,,,,,,,, -8900,0.5559701,1.5275916,,,,,,,,,,,,,, -9000,0.6263997,1.512287,,,,,,,,,,,,,, -9061,,,0.44564787,0.1570601279453759,0.6208973,0.1878312752831227,5348.0,0.383486,0.1302784717567485,2472.0,7238.686999559402,8030.393237113953,7238.686999559402,791.0809001922607,0.2378764152526855,0.0 -9100,0.54601425,1.4891926,,,,,,,,,,,,,, -9200,0.4634711,1.4630525,,,,,,,,,,,,,, -9300,0.67337036,1.5178607,,,,,,,,,,,,,, -9400,0.42013314,1.4713068,,,,,,,,,,,,,, -9500,0.5681378,1.5470191,,,,,,,,,,,,,, -9600,0.5358822,1.4960897,,,,,,,,,,,,,, -9700,0.509467,1.5283409,,,,,,,,,,,,,, -9800,0.46600354,1.5527657,,,,,,,,,,,,,, -9900,0.53309745,1.4605899,,,,,,,,,,,,,, -10000,0.54199165,1.4665813,,,,,,,,,,,,,, -10100,0.43756777,1.5039686,,,,,,,,,,,,,, -10200,0.4230182,1.527541,,,,,,,,,,,,,, -10300,0.646564,1.4638579,,,,,,,,,,,,,, -10400,0.5597641,1.4509565,,,,,,,,,,,,,, -10500,0.47865403,1.481998,,,,,,,,,,,,,, -10600,0.42496678,1.4925671,,,,,,,,,,,,,, -10700,0.49751619,1.5053222,,,,,,,,,,,,,, -10800,0.51868683,1.5425354,,,,,,,,,,,,,, -10883,,,0.4268927,0.152191394170314,0.58681077,0.1756567577744093,5348.0,0.36785704,0.1252615115877561,2472.0,8678.63923573494,9603.92083120346,8678.63923573494,924.5182108879088,0.2955894470214844,0.0 -10900,0.53495675,1.5211622,,,,,,,,,,,,,, -11000,0.62778616,1.5057361,,,,,,,,,,,,,, -11100,0.51586473,1.5386113,,,,,,,,,,,,,, -11200,0.45892185,1.4147282,,,,,,,,,,,,,, -11300,0.49480626,1.5095111,,,,,,,,,,,,,, -11400,0.60576004,1.4161794,,,,,,,,,,,,,, -11500,0.45603484,1.4952673,,,,,,,,,,,,,, -11600,0.47676823,1.4646405,,,,,,,,,,,,,, -11700,0.731271,1.5606256,,,,,,,,,,,,,, -11800,0.46887082,1.4578909,,,,,,,,,,,,,, -11900,0.57896155,1.4655604,,,,,,,,,,,,,, -12000,0.6012147,1.4422088,,,,,,,,,,,,,, -12100,0.5988485,1.4808846,,,,,,,,,,,,,, -12200,0.5370061,1.4515159,,,,,,,,,,,,,, -12300,0.4821637,1.4951707,,,,,,,,,,,,,, -12400,0.5928346,1.4673041,,,,,,,,,,,,,, -12500,0.5183935,1.4617908,,,,,,,,,,,,,, -12600,0.46373335,1.449267,,,,,,,,,,,,,, -12700,0.4956458,1.4860327,,,,,,,,,,,,,, -12713,,,0.40124416,0.1386010928961748,0.5638952,0.169188140224181,5348.0,0.3447162,0.1157760038998232,2472.0,10118.762867212296,11176.468502998352,10118.762867212296,1056.8081135749817,0.352435827255249,0.0 -12800,0.46205708,1.4430473,,,,,,,,,,,,,, -12900,0.4321751,1.4071245,,,,,,,,,,,,,, -13000,0.52341896,1.460521,,,,,,,,,,,,,, -13100,0.49584723,1.4397227,,,,,,,,,,,,,, -13200,0.43676922,1.3662515,,,,,,,,,,,,,, -13300,0.65772545,1.5014868,,,,,,,,,,,,,, -13400,0.49804318,1.4535509,,,,,,,,,,,,,, -13500,0.52439743,1.392217,,,,,,,,,,,,,, -13600,0.5607277,1.4063466,,,,,,,,,,,,,, -13700,0.4464239,1.4009193,,,,,,,,,,,,,, -13800,0.49436444,1.4531778,,,,,,,,,,,,,, -13900,0.4618263,1.4350896,,,,,,,,,,,,,, -14000,0.5994762,1.4312351,,,,,,,,,,,,,, -14100,0.57935166,1.4559318,,,,,,,,,,,,,, -14200,0.51865715,1.4221984,,,,,,,,,,,,,, -14300,0.45420438,1.3778286,,,,,,,,,,,,,, -14400,0.40737212,1.3565911,,,,,,,,,,,,,, -14500,0.45369262,1.419612,,,,,,,,,,,,,, -14529,,,0.3907849,0.1376007402372643,0.5494791,0.1650462940614229,5348.0,0.33356386,0.1123839701013547,2472.0,11559.307371377943,12746.616312265396,11559.307371377943,1186.2805242538452,0.4058682918548584,0.0 -14600,0.44008636,1.3966455,,,,,,,,,,,,,, -14700,0.5631323,1.3835624,,,,,,,,,,,,,, -14800,0.5602694,1.4360944,,,,,,,,,,,,,, -14900,0.47846696,1.3797191,,,,,,,,,,,,,, -15000,0.49135187,1.4108077,,,,,,,,,,,,,, -15100,0.5219578,1.3934609,,,,,,,,,,,,,, -15200,0.6133163,1.4160395,,,,,,,,,,,,,, -15300,0.5742174,1.4398402,,,,,,,,,,,,,, -15400,0.41326892,1.3898041,,,,,,,,,,,,,, -15500,0.4850017,1.4032192,,,,,,,,,,,,,, -15600,0.5670033,1.4576913,,,,,,,,,,,,,, -15700,0.6175222,1.4348158,,,,,,,,,,,,,, -15800,0.55445194,1.3732879,,,,,,,,,,,,,, -15900,0.4570415,1.3881568,,,,,,,,,,,,,, -16000,0.44018954,1.3784212,,,,,,,,,,,,,, -16100,0.5451454,1.4298028,,,,,,,,,,,,,, -16200,0.5223969,1.3819221,,,,,,,,,,,,,, -16300,0.50843346,1.3662049,,,,,,,,,,,,,, -16345,,,0.3892411,0.1341954468019204,0.5266665,0.1597748534906398,5348.0,0.31609333,0.1079357341620457,2472.0,12999.90366244316,14319.482456922531,12999.90366244316,1318.4176275730133,0.4609172344207763,0.0 -16400,0.62009156,1.4275279,,,,,,,,,,,,,, -16500,0.5249334,1.4107404,,,,,,,,,,,,,, -16600,0.48246413,1.3838369,,,,,,,,,,,,,, -16700,0.48190722,1.3239908,,,,,,,,,,,,,, -16800,0.40008155,1.3886585,,,,,,,,,,,,,, -16900,0.56491816,1.3902053,,,,,,,,,,,,,, -17000,0.4866648,1.3727645,,,,,,,,,,,,,, -17100,0.5252251,1.3828753,,,,,,,,,,,,,, -17200,0.561665,1.3994087,,,,,,,,,,,,,, -17300,0.5130772,1.3830831,,,,,,,,,,,,,, -17400,0.5104988,1.4059808,,,,,,,,,,,,,, -17500,0.44276688,1.3509594,,,,,,,,,,,,,, -17600,0.53482556,1.4285886,,,,,,,,,,,,,, -17700,0.62176037,1.4332126,,,,,,,,,,,,,, -17800,0.5242112,1.4248922,,,,,,,,,,,,,, -17900,0.50542235,1.3752433,,,,,,,,,,,,,, -18000,0.4794531,1.3869797,,,,,,,,,,,,,, -18100,0.5336434,1.3371289,,,,,,,,,,,,,, -18162,,,0.37243292,0.1307569937531525,0.5188964,0.1562702144298444,5348.0,0.30815408,0.1050718014339975,2472.0,14440.013365745544,15890.79479265213,14440.013365745544,1449.492838382721,0.5100352764129639,0.0 -18200,0.46710393,1.3875585,,,,,,,,,,,,,, -18300,0.5055173,1.3625813,,,,,,,,,,,,,, -18400,0.54018927,1.4500165,,,,,,,,,,,,,, -18500,0.45025128,1.4292474,,,,,,,,,,,,,, -18600,0.5262981,1.3984675,,,,,,,,,,,,,, -18700,0.47021762,1.3681351,,,,,,,,,,,,,, -18800,0.44852445,1.3480333,,,,,,,,,,,,,, -18900,0.5010057,1.351666,,,,,,,,,,,,,, -19000,0.4525275,1.3524319,,,,,,,,,,,,,, -19100,0.56696844,1.3610889,,,,,,,,,,,,,, -19200,0.52703184,1.367206,,,,,,,,,,,,,, -19300,0.5420354,1.38638,,,,,,,,,,,,,, -19400,0.45467559,1.3753989,,,,,,,,,,,,,, -19500,0.51791096,1.415932,,,,,,,,,,,,,, -19600,0.51179546,1.2992703,,,,,,,,,,,,,, -19700,0.57809675,1.365678,,,,,,,,,,,,,, -19800,0.4708662,1.3304378,,,,,,,,,,,,,, -19900,0.48665416,1.3816588,,,,,,,,,,,,,, -19980,,,0.37521964,0.1368282036148591,0.50475585,0.1534027824710119,5348.0,0.30493766,0.1033656287449475,2472.0,15880.443354845049,17460.340955257416,15880.443354845049,1578.4773724079132,0.5645277500152588,0.0 -20000,0.6218963,1.3765234,,,,,,,,,,,,,, -20100,0.44165614,1.3290097,,,,,,,,,,,,,, -20200,0.49126616,1.3612405,,,,,,,,,,,,,, -20300,0.46741557,1.3802125,,,,,,,,,,,,,, -20400,0.55418324,1.3551117,,,,,,,,,,,,,, -20500,0.5147175,1.340933,,,,,,,,,,,,,, -20600,0.5050126,1.3920604,,,,,,,,,,,,,, -20700,0.7085444,1.2930295,,,,,,,,,,,,,, -20800,0.5039741,1.3256245,,,,,,,,,,,,,, -20900,0.6824082,1.3782322,,,,,,,,,,,,,, -21000,0.49563304,1.38439,,,,,,,,,,,,,, -21100,0.5640525,1.2852085,,,,,,,,,,,,,, -21200,0.56562054,1.3295215,,,,,,,,,,,,,, -21300,0.5076944,1.3347439,,,,,,,,,,,,,, -21400,0.49882585,1.3444163,,,,,,,,,,,,,, -21500,0.5717852,1.392827,,,,,,,,,,,,,, -21600,0.58880997,1.3672707,,,,,,,,,,,,,, -21700,0.4883575,1.3057755,,,,,,,,,,,,,, -21786,,,0.31174734,0.1106088718335526,0.49118516,0.1480830686349286,5348.0,0.29324007,0.0993642475575325,2472.0,17320.393052101135,19030.75717544556,17320.393052101135,1708.8097307682035,0.6217460632324219,0.0 -21800,0.52344453,1.3406723,,,,,,,,,,,,,, -21900,0.5626363,1.350466,,,,,,,,,,,,,, -22000,0.67936915,1.2979976,,,,,,,,,,,,,, -22100,0.48124325,1.3510585,,,,,,,,,,,,,, -22200,0.52786064,1.3513004,,,,,,,,,,,,,, -22300,0.6770263,1.3246586,,,,,,,,,,,,,, -22400,0.5559819,1.2649492,,,,,,,,,,,,,, -22500,0.4867065,1.3443137,,,,,,,,,,,,,, -22600,0.4907216,1.3002567,,,,,,,,,,,,,, -22700,0.52829933,1.347792,,,,,,,,,,,,,, -22800,0.48720545,1.3064239,,,,,,,,,,,,,, -22900,0.5538397,1.3781737,,,,,,,,,,,,,, -23000,0.5114515,1.3199828,,,,,,,,,,,,,, -23100,0.5104155,1.3755342,,,,,,,,,,,,,, -23200,0.45816654,1.3394073,,,,,,,,,,,,,, -23300,0.4831753,1.3139364,,,,,,,,,,,,,, -23400,0.6377648,1.3233433,,,,,,,,,,,,,, -23500,0.45897254,1.3434387,,,,,,,,,,,,,, -23592,,,0.32217106,0.1146681135829263,0.4742398,0.1427440454927252,5348.0,0.2847809,0.0961753295553795,2472.0,18760.32162618637,20601.35322713852,18760.32162618637,1839.3516829013824,0.6704885959625244,0.0 -23600,0.5361482,1.3468975,,,,,,,,,,,,,, -23700,0.46836612,1.2784778,,,,,,,,,,,,,, -23800,0.52686816,1.3489577,,,,,,,,,,,,,, -23900,0.56344026,1.298071,,,,,,,,,,,,,, -24000,0.55229455,1.3223898,,,,,,,,,,,,,, -24100,0.56812197,1.3373141,,,,,,,,,,,,,, -24200,0.5108779,1.2853178,,,,,,,,,,,,,, -24300,0.5313935,1.317687,,,,,,,,,,,,,, -24400,0.46203074,1.3385798,,,,,,,,,,,,,, -24500,0.5315978,1.3009887,,,,,,,,,,,,,, -24600,0.5398705,1.3295774,,,,,,,,,,,,,, -24700,0.51399815,1.2717938,,,,,,,,,,,,,, -24800,0.6272029,1.2911234,,,,,,,,,,,,,, -24900,0.5864884,1.2380877,,,,,,,,,,,,,, -25000,0.5325383,1.3110691,,,,,,,,,,,,,, -25100,0.542444,1.3003546,,,,,,,,,,,,,, -25200,0.47651404,1.3601873,,,,,,,,,,,,,, -25300,0.50290346,1.2830992,,,,,,,,,,,,,, -25400,0.47229066,1.2707301,,,,,,,,,,,,,, -25405,,,0.31540608,0.1119117622066246,0.46801168,0.1400021240236732,5348.0,0.2735938,0.0919505209920175,2472.0,20200.92087650299,22172.902831554413,20200.92087650299,1970.1682348251345,0.7266736030578613,0.0 -25500,0.5270792,1.3052253,,,,,,,,,,,,,, -25600,0.4398475,1.323412,,,,,,,,,,,,,, -25700,0.6303153,1.2765654,,,,,,,,,,,,,, -25800,0.6217856,1.276008,,,,,,,,,,,,,, -25900,0.55235434,1.3568842,,,,,,,,,,,,,, -26000,0.4435334,1.3196589,,,,,,,,,,,,,, -26100,0.5335564,1.2618464,,,,,,,,,,,,,, -26200,0.7459508,1.2648991,,,,,,,,,,,,,, -26300,0.49265805,1.3226827,,,,,,,,,,,,,, -26400,0.6314965,1.3478165,,,,,,,,,,,,,, -26500,0.5960709,1.3355302,,,,,,,,,,,,,, -26600,0.58676946,1.2954801,,,,,,,,,,,,,, -26700,0.43601894,1.316762,,,,,,,,,,,,,, -26800,0.5862036,1.2469718,,,,,,,,,,,,,, -26900,0.5823402,1.314313,,,,,,,,,,,,,, -27000,0.4630214,1.2215704,,,,,,,,,,,,,, -27100,0.5458061,1.2974995,,,,,,,,,,,,,, -27200,0.49463454,1.2884378,,,,,,,,,,,,,, -27227,,,0.3136039,0.1112789305447235,0.45962992,0.1382160132075653,5348.0,0.26831707,0.0913411736030711,2472.0,21641.44178009033,23745.30660867691,21641.44178009033,2101.914893388748,0.7841992378234863,0.0 -27300,0.51807976,1.2751013,,,,,,,,,,,,,, -27400,0.4207763,1.2886884,,,,,,,,,,,,,, -27500,0.5944457,1.3066274,,,,,,,,,,,,,, -27600,0.6453848,1.2792171,,,,,,,,,,,,,, -27700,0.470514,1.2954808,,,,,,,,,,,,,, -27800,0.49434763,1.3075905,,,,,,,,,,,,,, -27900,0.5033373,1.2274364,,,,,,,,,,,,,, -28000,0.6332232,1.2670273,,,,,,,,,,,,,, -28100,0.5277062,1.3152498,,,,,,,,,,,,,, -28200,0.76269114,1.2716843,,,,,,,,,,,,,, -28300,0.53111786,1.2744036,,,,,,,,,,,,,, -28400,0.51185304,1.2763399,,,,,,,,,,,,,, -28500,0.48487267,1.2430924,,,,,,,,,,,,,, -28600,0.51804924,1.2987348,,,,,,,,,,,,,, -28700,0.5892644,1.308796,,,,,,,,,,,,,, -28800,0.6367952,1.2576271,,,,,,,,,,,,,, -28900,0.5619216,1.245156,,,,,,,,,,,,,, -29000,0.545997,1.2895483,,,,,,,,,,,,,, -29041,,,0.27701193,0.1009942350132865,0.46328136,0.1380132654933045,5348.0,0.2715291,0.0898990514492312,2472.0,23081.622501134872,25315.63108062744,23081.622501134872,2231.9260606765747,0.8400917053222656,0.0 -29100,0.7447112,1.299897,,,,,,,,,,,,,, -29200,0.526655,1.3036275,,,,,,,,,,,,,, -29300,0.4882905,1.24493,,,,,,,,,,,,,, -29400,0.5639708,1.2811816,,,,,,,,,,,,,, -29500,0.47932675,1.265586,,,,,,,,,,,,,, -29600,0.61983526,1.2243747,,,,,,,,,,,,,, -29700,0.52488565,1.2777694,,,,,,,,,,,,,, -29800,0.51956743,1.2199129,,,,,,,,,,,,,, -29900,0.54171664,1.2169216,,,,,,,,,,,,,, -30000,0.6231757,1.2410122,,,,,,,,,,,,,, -30100,0.43131897,1.2184036,,,,,,,,,,,,,, -30200,0.511149,1.267137,,,,,,,,,,,,,, -30300,0.6331292,1.2495966,,,,,,,,,,,,,, -30400,0.59019834,1.2061577,,,,,,,,,,,,,, -30500,0.571141,1.2917385,,,,,,,,,,,,,, -30600,0.5116033,1.2795246,,,,,,,,,,,,,, -30700,0.48832318,1.2880515,,,,,,,,,,,,,, -30800,0.5377369,1.3068002,,,,,,,,,,,,,, -30847,,,0.24538893,0.0897736678368581,0.44567695,0.133572125085685,5348.0,0.25223556,0.0851664533950805,2472.0,24521.81897425652,26886.398028612137,24521.81897425652,2362.35418009758,0.9039266109466552,0.0 -30900,0.6859163,1.1895359,,,,,,,,,,,,,, -31000,0.52455884,1.2412058,,,,,,,,,,,,,, -31100,0.56131935,1.2610718,,,,,,,,,,,,,, -31200,0.49760437,1.2353636,,,,,,,,,,,,,, -31300,0.53275305,1.2433523,,,,,,,,,,,,,, -31400,0.59445757,1.242761,,,,,,,,,,,,,, -31500,0.47487894,1.2274692,,,,,,,,,,,,,, -31600,0.5447267,1.2181175,,,,,,,,,,,,,, -31700,0.48973262,1.20394,,,,,,,,,,,,,, -31800,0.5432185,1.2283419,,,,,,,,,,,,,, -31900,0.49668017,1.2092035,,,,,,,,,,,,,, -32000,0.6551946,1.2896651,,,,,,,,,,,,,, -32100,0.52575433,1.2713768,,,,,,,,,,,,,, -32200,0.5279649,1.25042,,,,,,,,,,,,,, -32300,0.53617704,1.2619177,,,,,,,,,,,,,, -32400,0.48410273,1.248981,,,,,,,,,,,,,, -32500,0.48413742,1.2405512,,,,,,,,,,,,,, -32600,0.4822603,1.1859258,,,,,,,,,,,,,, -32654,,,0.27409798,0.0995876673510596,0.43247974,0.1298164650453286,5348.0,0.25065094,0.0824853248837162,2472.0,25961.83100414276,28457.146093845367,25961.83100414276,2492.959435224533,0.9592113494873048,0.0 -32700,0.5925766,1.3121063,,,,,,,,,,,,,, -32800,0.5463106,1.2063453,,,,,,,,,,,,,, -32900,0.6949563,1.3152094,,,,,,,,,,,,,, -33000,0.4871544,1.1833928,,,,,,,,,,,,,, -33100,0.5858909,1.1851803,,,,,,,,,,,,,, -33200,0.48888457,1.1989161,,,,,,,,,,,,,, -33300,0.62059677,1.2757069,,,,,,,,,,,,,, -33400,0.52660626,1.2332457,,,,,,,,,,,,,, -33500,0.50295657,1.263705,,,,,,,,,,,,,, -33600,0.55158794,1.2507063,,,,,,,,,,,,,, -33700,0.48426306,1.2010813,,,,,,,,,,,,,, -33800,0.63056976,1.2550141,,,,,,,,,,,,,, -33900,0.6267137,1.2826685,,,,,,,,,,,,,, -34000,0.5033215,1.21104,,,,,,,,,,,,,, -34100,0.47006568,1.1747812,,,,,,,,,,,,,, -34200,0.5754959,1.2771206,,,,,,,,,,,,,, -34300,0.506596,1.1822717,,,,,,,,,,,,,, -34400,0.52174765,1.2049317,,,,,,,,,,,,,, -34475,,,0.25031212,0.0910184107825835,0.42782575,0.1265145736987941,5348.0,0.24289979,0.0819166006540328,2472.0,27402.23986577988,30028.323790550232,27402.23986577988,2623.59248495102,1.017035722732544,0.0 -34500,0.48073515,1.2287956,,,,,,,,,,,,,, -34600,0.4669714,1.1707658,,,,,,,,,,,,,, -34700,0.5377673,1.2860776,,,,,,,,,,,,,, -34800,0.5339684,1.2582127,,,,,,,,,,,,,, -34900,0.6155853,1.2449765,,,,,,,,,,,,,, -35000,0.65318686,1.2248129,,,,,,,,,,,,,, -35100,0.4859914,1.1968709,,,,,,,,,,,,,, -35200,0.52920556,1.2108388,,,,,,,,,,,,,, -35300,0.5021029,1.2123476,,,,,,,,,,,,,, -35400,0.5138755,1.233226,,,,,,,,,,,,,, -35500,0.51343197,1.1924114,,,,,,,,,,,,,, -35600,0.6248748,1.1558807,,,,,,,,,,,,,, -35700,0.5834315,1.244534,,,,,,,,,,,,,, -35800,0.56878203,1.2144241,,,,,,,,,,,,,, -35900,0.5069947,1.2480391,,,,,,,,,,,,,, -36000,0.565343,1.2667621,,,,,,,,,,,,,, -36100,0.5426172,1.1731617,,,,,,,,,,,,,, -36200,0.62093997,1.1765757,,,,,,,,,,,,,, -36284,,,0.26957318,0.0937386351556249,0.42037314,0.126253898066173,5348.0,0.24147299,0.0803119858631405,2472.0,28842.823963165283,31599.68351483345,28842.823963165283,2754.237830877304,1.070730209350586,0.0 -36300,0.5278951,1.2177963,,,,,,,,,,,,,, -36400,0.5246317,1.245544,,,,,,,,,,,,,, -36500,0.6758836,1.2240518,,,,,,,,,,,,,, -36600,0.61667985,1.2489349,,,,,,,,,,,,,, -36700,0.44260156,1.1864995,,,,,,,,,,,,,, -36800,0.53759235,1.1877126,,,,,,,,,,,,,, -36900,0.5282986,1.2470832,,,,,,,,,,,,,, -37000,0.5109038,1.2041023,,,,,,,,,,,,,, -37100,0.5674363,1.1817038,,,,,,,,,,,,,, -37200,0.5164867,1.190948,,,,,,,,,,,,,, -37300,0.5265241,1.1994898,,,,,,,,,,,,,, -37400,0.5343782,1.1477658,,,,,,,,,,,,,, -37500,0.5486375,1.1577147,,,,,,,,,,,,,, -37600,0.59003353,1.2484969,,,,,,,,,,,,,, -37700,0.50068325,1.2011541,,,,,,,,,,,,,, -37800,0.4561219,1.1764532,,,,,,,,,,,,,, -37900,0.5493229,1.2308545,,,,,,,,,,,,,, -38000,0.54456544,1.2394439,,,,,,,,,,,,,, -38084,,,0.21194768,0.0783649849035137,0.41112727,0.1200556108016258,5348.0,0.2352042,0.0776308573517762,2472.0,30283.236042499542,33171.66923260689,30283.236042499542,2885.6785418987274,1.1268370151519775,0.0 -38100,0.471457,1.181114,,,,,,,,,,,,,, -38200,0.46648923,1.1782367,,,,,,,,,,,,,, -38300,0.5162141,1.1569101,,,,,,,,,,,,,, -38400,0.49714467,1.187406,,,,,,,,,,,,,, -38500,0.5004444,1.2257533,,,,,,,,,,,,,, -38600,0.5835195,1.1389563,,,,,,,,,,,,,, -38700,0.6361365,1.1593859,,,,,,,,,,,,,, -38800,0.5774991,1.1841073,,,,,,,,,,,,,, -38900,0.5697405,1.2019631,,,,,,,,,,,,,, -39000,0.75607294,1.1678697,,,,,,,,,,,,,, -39100,0.5479461,1.134854,,,,,,,,,,,,,, -39200,0.49255773,1.0804793,,,,,,,,,,,,,, -39300,0.55556923,1.1662179,,,,,,,,,,,,,, -39400,0.45717075,1.1548063,,,,,,,,,,,,,, -39500,0.6043149,1.1995043,,,,,,,,,,,,,, -39600,0.60433394,1.2148788,,,,,,,,,,,,,, -39700,0.61038566,1.2199078,,,,,,,,,,,,,, -39800,0.58897984,1.2121071,,,,,,,,,,,,,, -39891,,,0.20022973,0.0736595003163257,0.40568483,0.1178350405978161,5348.0,0.22744554,0.0754981414904637,2472.0,31723.70522546768,34742.872770786285,31723.70522546768,3016.282743215561,1.1808912754058838,0.0 -39900,0.52416795,1.1364058,,,,,,,,,,,,,, -40000,0.5972309,1.2516322,,,,,,,,,,,,,, -40100,0.5592269,1.1466565,,,,,,,,,,,,,, -40200,0.6012255,1.1288887,,,,,,,,,,,,,, -40300,0.53129417,1.1778256,,,,,,,,,,,,,, -40400,0.5943805,1.1387452,,,,,,,,,,,,,, -40500,0.5376928,1.1521931,,,,,,,,,,,,,, -40600,0.6976845,1.1925628,,,,,,,,,,,,,, -40700,0.59727526,1.2014948,,,,,,,,,,,,,, -40800,0.5858783,1.2098314,,,,,,,,,,,,,, -40900,0.5482323,1.1606144,,,,,,,,,,,,,, -41000,0.53468376,1.1838956,,,,,,,,,,,,,, -41100,0.5232463,1.165712,,,,,,,,,,,,,, -41200,0.5299033,1.1378767,,,,,,,,,,,,,, -41300,0.51094425,1.1824033,,,,,,,,,,,,,, -41400,0.523658,1.1067828,,,,,,,,,,,,,, -41500,0.5533567,1.1247528,,,,,,,,,,,,,, -41600,0.5767412,1.155837,,,,,,,,,,,,,, -41700,0.5583735,1.1113471,,,,,,,,,,,,,, -41721,,,0.22058325,0.0810318751248466,0.39636406,0.1166282089653108,5348.0,0.22631016,0.0739950845977291,2472.0,33163.71019721031,36312.93751120567,33163.71019721031,3146.2064123153687,1.2387712001800537,0.0 -41800,0.67404413,1.090434,,,,,,,,,,,,,, -41900,0.76388603,1.1599116,,,,,,,,,,,,,, -42000,0.5228435,1.158399,,,,,,,,,,,,,, -42100,0.51051384,1.1416075,,,,,,,,,,,,,, -42200,0.69836974,1.229111,,,,,,,,,,,,,, -42300,0.5306421,1.1433866,,,,,,,,,,,,,, -42400,0.57731014,1.113557,,,,,,,,,,,,,, -42500,0.4708159,1.150352,,,,,,,,,,,,,, -42600,0.4915684,1.1469487,,,,,,,,,,,,,, -42700,0.54934627,1.0943336,,,,,,,,,,,,,, -42800,0.5478255,1.0947803,,,,,,,,,,,,,, -42900,0.57071096,1.1055193,,,,,,,,,,,,,, -43000,0.55298454,1.1858904,,,,,,,,,,,,,, -43100,0.60925186,1.1491064,,,,,,,,,,,,,, -43200,0.5426759,1.1631765,,,,,,,,,,,,,, -43300,0.51003844,1.1602899,,,,,,,,,,,,,, -43400,0.5510491,1.1486526,,,,,,,,,,,,,, -43500,0.4910976,1.1367955,,,,,,,,,,,,,, -43554,,,0.14703715,0.0556202457259144,0.39308006,0.1142435096594803,5348.0,0.21916676,0.0725529624438892,2472.0,34603.71418786049,37892.071621418,34603.71418786049,3285.1928341388702,1.3039309978485107,0.0 -43600,0.55001575,1.1277853,,,,,,,,,,,,,, -43700,0.6131702,1.1292533,,,,,,,,,,,,,, -43800,0.670693,1.0968513,,,,,,,,,,,,,, -43900,0.81938684,1.1397836,,,,,,,,,,,,,, -44000,0.54885507,1.1440513,,,,,,,,,,,,,, -44100,0.51223224,1.0985729,,,,,,,,,,,,,, -44200,0.57246745,1.1328218,,,,,,,,,,,,,, -44300,0.5456631,1.1264797,,,,,,,,,,,,,, -44400,0.65003985,1.1299549,,,,,,,,,,,,,, -44500,0.57591414,1.1450679,,,,,,,,,,,,,, -44600,0.5413728,1.1610775,,,,,,,,,,,,,, -44700,0.59755725,1.1065439,,,,,,,,,,,,,, -44800,0.52380335,1.1143421,,,,,,,,,,,,,, -44900,0.5362727,1.1372144,,,,,,,,,,,,,, -45000,0.5156104,1.1218461,,,,,,,,,,,,,, -45100,0.5409033,1.1236467,,,,,,,,,,,,,, -45200,0.5726491,1.175038,,,,,,,,,,,,,, -45300,0.5499293,1.0618833,,,,,,,,,,,,,, -45371,,,0.12569274,0.048289319630545,0.38317782,0.1127566930882338,5348.0,0.21107632,0.0702171307862612,2472.0,36044.21834445,39466.92186617851,36044.21834445,3419.3971271514893,1.3684093952178955,0.0 -45400,0.6473234,1.133239,,,,,,,,,,,,,, -45500,0.4741994,1.1426228,,,,,,,,,,,,,, -45600,0.70893294,1.1139979,,,,,,,,,,,,,, -45700,0.4632197,1.0465308,,,,,,,,,,,,,, -45800,0.74422175,1.1316649,,,,,,,,,,,,,, -45900,0.64320546,1.133322,,,,,,,,,,,,,, -46000,0.626705,1.1180775,,,,,,,,,,,,,, -46100,0.5023844,1.148505,,,,,,,,,,,,,, -46200,0.65843254,1.1048594,,,,,,,,,,,,,, -46300,0.58202016,1.0996214,,,,,,,,,,,,,, -46400,0.5398338,1.0729319,,,,,,,,,,,,,, -46500,0.79935235,1.0905929,,,,,,,,,,,,,, -46600,0.58797306,1.1035552,,,,,,,,,,,,,, -46700,0.5302015,1.0895897,,,,,,,,,,,,,, -46800,0.6090149,1.0752866,,,,,,,,,,,,,, -46900,0.559618,1.125025,,,,,,,,,,,,,, -47000,0.7555884,1.0917784,,,,,,,,,,,,,, -47100,0.5924971,1.1129239,,,,,,,,,,,,,, -47193,,,0.12504135,0.0475356763616586,0.3714938,0.1090396516601175,5348.0,0.20164764,0.068023480186054,2472.0,37484.79524970055,41040.31870007515,37484.79524970055,3552.0800442695618,1.4292211532592771,0.0 -47200,0.50760716,1.1234708,,,,,,,,,,,,,, -47300,0.54887563,1.1253694,,,,,,,,,,,,,, -47400,0.6523756,1.0643605,,,,,,,,,,,,,, -47500,0.5991002,1.1131878,,,,,,,,,,,,,, -47600,0.50700593,1.0899345,,,,,,,,,,,,,, -47700,0.58651215,1.1024086,,,,,,,,,,,,,, -47800,0.5586414,1.0845801,,,,,,,,,,,,,, -47900,0.4642795,1.0667843,,,,,,,,,,,,,, -48000,0.60939145,1.1523196,,,,,,,,,,,,,, -48100,0.5632271,1.1441942,,,,,,,,,,,,,, -48200,0.5634189,1.0820543,,,,,,,,,,,,,, -48300,0.5808573,1.0810355,,,,,,,,,,,,,, -48400,0.74226135,1.1921339,,,,,,,,,,,,,, -48500,0.6131914,1.114167,,,,,,,,,,,,,, -48600,0.5943454,1.0737972,,,,,,,,,,,,,, -48700,0.6822238,1.1141744,,,,,,,,,,,,,, -48800,0.7172287,1.0780095,,,,,,,,,,,,,, -48900,0.53277004,1.0915897,,,,,,,,,,,,,, -49000,0.5200871,1.0646224,,,,,,,,,,,,,, -49019,,,0.12188379,0.0461781700359377,0.36238962,0.105496393987082,5348.0,0.199988,0.0672922633193183,2472.0,38925.28243923187,42613.33491516113,38925.28243923187,3684.472150802612,1.4894332885742188,0.0 -49100,0.52511543,1.0604205,,,,,,,,,,,,,, -49200,0.52580905,1.1239599,,,,,,,,,,,,,, -49300,0.59379697,1.1285052,,,,,,,,,,,,,, -49400,0.58298534,1.0888039,,,,,,,,,,,,,, -49500,0.58238006,1.0803119,,,,,,,,,,,,,, -49600,0.53452325,1.0870538,,,,,,,,,,,,,, -49700,0.6188617,1.0428951,,,,,,,,,,,,,, -49800,0.6251307,1.0920111,,,,,,,,,,,,,, -49900,0.5919927,1.1037166,,,,,,,,,,,,,, -50000,0.56896365,1.0645882,,,,,,,,,,,,,, -50100,0.52920485,1.0976232,,,,,,,,,,,,,, -50200,0.5678139,1.0585527,,,,,,,,,,,,,, -50300,0.62560135,1.1146084,,,,,,,,,,,,,, -50400,0.48562658,1.0366927,,,,,,,,,,,,,, -50500,0.6167999,1.0785234,,,,,,,,,,,,,, -50600,0.6338757,1.096336,,,,,,,,,,,,,, -50700,0.5406407,1.0721307,,,,,,,,,,,,,, -50800,0.5095424,1.0298855,,,,,,,,,,,,,, -50843,,,0.110726975,0.0443575473913589,0.36124983,0.1057474149666431,5348.0,0.19808494,0.066093880121057,2472.0,40365.47981357575,44184.46557068825,40365.47981357575,3815.268126010895,1.547067642211914,0.0 -50900,0.5919565,1.0385203,,,,,,,,,,,,,, -51000,0.5959042,1.1156336,,,,,,,,,,,,,, -51100,0.87261605,1.0529449,,,,,,,,,,,,,, -51200,0.6630846,1.0997456,,,,,,,,,,,,,, -51300,0.643153,1.0863912,,,,,,,,,,,,,, -51400,0.5014421,1.0726452,,,,,,,,,,,,,, -51500,0.52491164,1.0667415,,,,,,,,,,,,,, -51600,0.5165098,1.068345,,,,,,,,,,,,,, -51700,0.64560765,1.1176952,,,,,,,,,,,,,, -51800,0.76079017,1.0764652,,,,,,,,,,,,,, -51900,0.9097003,1.0381895,,,,,,,,,,,,,, -52000,0.7210273,1.0923768,,,,,,,,,,,,,, -52100,0.5396714,1.035234,,,,,,,,,,,,,, -52200,0.58587474,1.0550119,,,,,,,,,,,,,, -52300,0.6076275,1.0826924,,,,,,,,,,,,,, -52400,0.57870615,1.110617,,,,,,,,,,,,,, -52500,0.5232681,1.0041988,,,,,,,,,,,,,, -52600,0.7353774,1.0533268,,,,,,,,,,,,,, -52655,,,0.102106266,0.0398258617932727,0.35422248,0.1018758990895662,5348.0,0.19376087,0.065057989559848,2472.0,41805.494245529175,45756.3062517643,41805.494245529175,3946.959716320038,1.6031808853149414,0.0 -52700,0.5810227,1.0284337,,,,,,,,,,,,,, -52800,0.55815285,1.023757,,,,,,,,,,,,,, -52900,0.557863,1.0470243,,,,,,,,,,,,,, -53000,0.57604486,1.0496652,,,,,,,,,,,,,, -53100,0.6333583,1.0602583,,,,,,,,,,,,,, -53200,0.65002465,1.0478605,,,,,,,,,,,,,, -53300,0.54842824,1.0591568,,,,,,,,,,,,,, -53400,0.75854707,1.0570545,,,,,,,,,,,,,, -53500,0.5465929,1.0669681,,,,,,,,,,,,,, -53600,0.6055529,1.0377597,,,,,,,,,,,,,, -53700,0.632229,1.074297,,,,,,,,,,,,,, -53800,0.5219007,0.99919367,,,,,,,,,,,,,, -53900,0.7053604,1.0470585,,,,,,,,,,,,,, -54000,0.57063067,1.0427778,,,,,,,,,,,,,, -54100,0.49877727,1.029859,,,,,,,,,,,,,, -54200,0.5618813,1.0232593,,,,,,,,,,,,,, -54300,0.7293506,1.0047395,,,,,,,,,,,,,, -54400,0.8670422,1.0435859,,,,,,,,,,,,,, -54477,,,0.115759045,0.0439879638476152,0.34668103,0.1012869652529036,5348.0,0.18844098,0.0630065200170617,2472.0,43245.7050819397,47328.256929636,43245.7050819397,4078.56554722786,1.6603331565856934,0.0 -54500,0.565145,0.99903274,,,,,,,,,,,,,, -54600,0.6000329,0.9887439,,,,,,,,,,,,,, -54700,0.5545946,0.9887874,,,,,,,,,,,,,, -54800,0.906122,0.98709357,,,,,,,,,,,,,, -54900,0.5984281,1.0064497,,,,,,,,,,,,,, -55000,0.6007352,1.0028483,,,,,,,,,,,,,, -55100,0.58873826,0.9693067,,,,,,,,,,,,,, -55200,0.62412673,1.0290076,,,,,,,,,,,,,, -55300,0.60575306,1.0312437,,,,,,,,,,,,,, -55400,0.56988484,1.0174834,,,,,,,,,,,,,, -55500,0.60342443,1.0583973,,,,,,,,,,,,,, -55600,0.6556079,1.0154306,,,,,,,,,,,,,, -55700,0.6576982,1.0203639,,,,,,,,,,,,,, -55800,0.56272405,1.0011798,,,,,,,,,,,,,, -55900,0.6355411,1.0062612,,,,,,,,,,,,,, -56000,0.5812252,1.0412577,,,,,,,,,,,,,, -56100,0.6972179,0.992302,,,,,,,,,,,,,, -56200,0.55359745,1.0267073,,,,,,,,,,,,,, -56295,,,0.10086589,0.0365662264259159,0.34202397,0.097714743620688,5348.0,0.18400313,0.0606909999390652,2472.0,44686.1141409874,48902.020424842834,44686.1141409874,4211.784162521362,1.719299077987671,0.0 -56300,0.6035465,1.0204645,,,,,,,,,,,,,, -56400,0.61393607,1.0033231,,,,,,,,,,,,,, -56500,0.574217,1.024257,,,,,,,,,,,,,, -56600,0.55613095,1.0172052,,,,,,,,,,,,,, -56700,0.5516707,1.0039953,,,,,,,,,,,,,, -56800,0.54530156,1.0426455,,,,,,,,,,,,,, -56900,0.71035403,0.982619,,,,,,,,,,,,,, -57000,0.5975592,1.0178822,,,,,,,,,,,,,, -57100,0.68030304,0.98926276,,,,,,,,,,,,,, -57200,0.5980068,1.0217627,,,,,,,,,,,,,, -57300,0.80231136,0.99032813,,,,,,,,,,,,,, -57400,0.62758154,0.98549163,,,,,,,,,,,,,, -57500,0.568056,1.0656145,,,,,,,,,,,,,, -57600,0.6243358,1.0215139,,,,,,,,,,,,,, -57700,0.7620832,0.97140795,,,,,,,,,,,,,, -57800,0.609959,0.98902446,,,,,,,,,,,,,, -57900,0.5595476,1.0159438,,,,,,,,,,,,,, -58000,0.7270679,0.985886,,,,,,,,,,,,,, -58100,0.65648335,0.99797213,,,,,,,,,,,,,, -58115,,,0.09474793,0.0364319198806733,0.3363182,0.0949535128455159,5348.0,0.17934273,0.0583754798610687,2472.0,46126.11980581284,50475.445321798325,46126.11980581284,4345.063044786453,1.7812213897705078,0.0 -58200,0.56219274,1.0337299,,,,,,,,,,,,,, -58300,0.6178281,1.0170863,,,,,,,,,,,,,, -58400,0.59573054,0.98776597,,,,,,,,,,,,,, -58500,0.5726196,0.96185356,,,,,,,,,,,,,, -58600,0.55781084,1.0246916,,,,,,,,,,,,,, -58700,0.7931809,1.0039927,,,,,,,,,,,,,, -58800,0.7857699,0.98380435,,,,,,,,,,,,,, -58900,0.7120431,0.9731472,,,,,,,,,,,,,, -59000,0.59284174,1.0052918,,,,,,,,,,,,,, -59100,0.67066276,1.0154303,,,,,,,,,,,,,, -59200,0.7016122,0.972557,,,,,,,,,,,,,, -59300,0.6478711,1.008499,,,,,,,,,,,,,, -59400,0.6270322,1.0352037,,,,,,,,,,,,,, -59500,0.55092925,1.0023245,,,,,,,,,,,,,, -59600,0.5960101,1.0013579,,,,,,,,,,,,,, -59700,0.626141,0.97190344,,,,,,,,,,,,,, -59800,0.7820491,0.9657586,,,,,,,,,,,,,, -59900,0.6094561,0.99375135,,,,,,,,,,,,,, -59930,,,0.07673226,0.0305275637225844,0.33402327,0.0943259603966131,5348.0,0.17598493,0.0583754798610687,2472.0,47566.14169406891,52046.8627216816,47566.14169406891,4476.3204646110535,1.841925859451294,0.0 -60000,1.0414574,1.0486757,,,,,,,,,,,,,, -60100,0.58558965,0.9498626,,,,,,,,,,,,,, -60200,0.58263963,0.952534,,,,,,,,,,,,,, -60300,0.6178115,0.9356248,,,,,,,,,,,,,, -60400,0.65292203,0.9961825,,,,,,,,,,,,,, -60500,0.6386232,1.0207696,,,,,,,,,,,,,, -60600,0.6508419,0.9970613,,,,,,,,,,,,,, -60700,0.78573143,0.96608865,,,,,,,,,,,,,, -60800,0.5959605,0.9596479,,,,,,,,,,,,,, -60900,0.78644484,0.9687022,,,,,,,,,,,,,, -61000,0.90702343,0.96174604,,,,,,,,,,,,,, -61100,0.8899333,0.97055554,,,,,,,,,,,,,, -61200,0.7172124,0.94730943,,,,,,,,,,,,,, -61300,0.63958776,0.9892249,,,,,,,,,,,,,, -61400,0.6955961,0.9608218,,,,,,,,,,,,,, -61500,0.5757788,0.9559752,,,,,,,,,,,,,, -61600,0.5759027,0.96901524,,,,,,,,,,,,,, -61700,0.6003989,1.0319743,,,,,,,,,,,,,, -61747,,,0.07129591,0.0283815178793507,0.3258931,0.092694324029466,5348.0,0.17221192,0.0561005829423354,2472.0,49006.05555701256,53619.27761530876,49006.05555701256,4608.684624910355,1.90018367767334,0.0 -61800,0.68632627,1.0140285,,,,,,,,,,,,,, -61900,0.6660908,0.95336926,,,,,,,,,,,,,, -62000,0.59290475,0.94549906,,,,,,,,,,,,,, -62100,0.603291,0.95435214,,,,,,,,,,,,,, -62200,0.58525485,0.9140638,,,,,,,,,,,,,, -62300,0.92677987,0.95126855,,,,,,,,,,,,,, -62400,0.64140934,0.9743601,,,,,,,,,,,,,, -62500,0.6282152,0.95146155,,,,,,,,,,,,,, -62600,0.63199943,0.96773314,,,,,,,,,,,,,, -62700,0.7383373,0.9254204,,,,,,,,,,,,,, -62800,0.9409357,0.9451229,,,,,,,,,,,,,, -62900,0.59812444,0.9314299,,,,,,,,,,,,,, -63000,0.526416,0.9252056,,,,,,,,,,,,,, -63100,0.62533927,0.93954766,,,,,,,,,,,,,, -63200,0.6659699,0.9401733,,,,,,,,,,,,,, -63300,0.6757141,1.0179226,,,,,,,,,,,,,, -63400,0.7183435,0.9870654,,,,,,,,,,,,,, -63500,0.62632513,0.9272548,,,,,,,,,,,,,, -63558,,,0.07846448,0.0303980282360063,0.3151095,0.089740000193093,5348.0,0.16798735,0.0561818292608616,2472.0,50446.629410505295,55192.60791397095,50446.629410505295,4741.300411224365,1.9632997512817385,0.0 -63600,0.77935857,0.98174196,,,,,,,,,,,,,, -63700,0.71589494,0.9662747,,,,,,,,,,,,,, -63800,0.5881313,0.91960686,,,,,,,,,,,,,, -63900,0.567528,0.9588586,,,,,,,,,,,,,, -64000,0.6053523,0.9246527,,,,,,,,,,,,,, -64100,0.61913276,0.928433,,,,,,,,,,,,,, -64200,0.64571565,0.9842542,,,,,,,,,,,,,, -64300,0.7401076,0.96965146,,,,,,,,,,,,,, -64400,0.8884618,0.9365218,,,,,,,,,,,,,, -64500,0.5488492,0.9054076,,,,,,,,,,,,,, -64600,0.6850849,0.905268,,,,,,,,,,,,,, -64700,1.0202606,0.94906723,,,,,,,,,,,,,, -64800,0.6746018,0.9102387,,,,,,,,,,,,,, -64900,0.8557003,0.97766,,,,,,,,,,,,,, -65000,0.8608033,0.92603606,,,,,,,,,,,,,, -65100,0.85946256,0.90990424,,,,,,,,,,,,,, -65200,0.58747053,0.9258066,,,,,,,,,,,,,, -65300,0.58288574,0.9639322,,,,,,,,,,,,,, -65383,,,0.07131476,0.0278298072157149,0.31282473,0.0885331685605877,5348.0,0.1647859,0.0528507302012877,2472.0,51886.58189582825,56765.85677433014,51886.58189582825,4874.458735466003,2.022402763366699,0.0 -65400,0.79366684,0.94308823,,,,,,,,,,,,,, -65500,0.622831,0.9328714,,,,,,,,,,,,,, -65600,0.7057582,0.98357993,,,,,,,,,,,,,, -65700,0.69998443,0.94037235,,,,,,,,,,,,,, -65800,0.698286,0.9654236,,,,,,,,,,,,,, -65900,0.7309422,0.8948347,,,,,,,,,,,,,, -66000,0.64238304,0.9248084,,,,,,,,,,,,,, -66100,0.66150784,0.9278707,,,,,,,,,,,,,, -66200,1.0228674,0.88764906,,,,,,,,,,,,,, -66300,0.6606256,0.9462549,,,,,,,,,,,,,, -66400,0.60654485,0.8917699,,,,,,,,,,,,,, -66500,0.8204692,0.9508625,,,,,,,,,,,,,, -66600,0.6777781,0.96750146,,,,,,,,,,,,,, -66700,0.6949521,0.92662036,,,,,,,,,,,,,, -66800,0.7139427,0.91817,,,,,,,,,,,,,, -66900,0.7024428,0.9219523,,,,,,,,,,,,,, -67000,0.5977448,0.9149562,,,,,,,,,,,,,, -67100,0.727579,0.9232688,,,,,,,,,,,,,, -67200,0.7452646,0.9211892,,,,,,,,,,,,,, -67201,,,0.06927087,0.0262163524758444,0.30948767,0.0870560066424013,5348.0,0.16259553,0.0525866796660776,2472.0,53327.39374256134,58339.04134559631,53327.39374256134,5006.690465927124,2.087104558944702,0.0 -67300,0.86585623,0.93484956,,,,,,,,,,,,,, -67400,0.6958084,0.91684383,,,,,,,,,,,,,, -67500,0.5679048,0.89824075,,,,,,,,,,,,,, -67600,0.8416567,0.95698005,,,,,,,,,,,,,, -67700,0.6370981,0.92853874,,,,,,,,,,,,,, -67800,0.665674,0.9481026,,,,,,,,,,,,,, -67900,0.5825165,0.9167534,,,,,,,,,,,,,, -68000,0.8070855,0.89166737,,,,,,,,,,,,,, -68100,0.6921426,0.9605809,,,,,,,,,,,,,, -68200,0.5697778,0.92901605,,,,,,,,,,,,,, -68300,0.62539876,0.9463138,,,,,,,,,,,,,, -68400,0.61514276,0.9091552,,,,,,,,,,,,,, -68500,0.7230424,0.9270726,,,,,,,,,,,,,, -68600,0.82041955,0.9105247,,,,,,,,,,,,,, -68700,0.78517133,0.90288115,,,,,,,,,,,,,, -68800,0.59237367,0.8648905,,,,,,,,,,,,,, -68900,0.67328435,0.9782809,,,,,,,,,,,,,, -69000,0.92106026,0.91953045,,,,,,,,,,,,,, -69011,,,0.059926614,0.0233735345498937,0.30565014,0.0857429738262355,5348.0,0.15818636,0.0513476733085532,2472.0,54769.413890361786,59914.420617342,54769.413890361786,5139.909769058228,2.1487863063812256,0.0 -69100,0.6360706,0.89638066,,,,,,,,,,,,,, -69200,0.67220277,0.9105001,,,,,,,,,,,,,, -69300,0.90668714,0.9100292,,,,,,,,,,,,,, -69400,0.699771,0.9446584,,,,,,,,,,,,,, -69500,0.66491354,0.94250894,,,,,,,,,,,,,, -69600,0.8012079,0.8789243,,,,,,,,,,,,,, -69700,0.64103997,0.88264245,,,,,,,,,,,,,, -69800,0.6987328,0.90302896,,,,,,,,,,,,,, -69900,0.65576905,0.8968713,,,,,,,,,,,,,, -70000,0.7304221,0.91506475,,,,,,,,,,,,,, -70100,0.64951116,0.89598167,,,,,,,,,,,,,, -70200,0.66906965,0.86968315,,,,,,,,,,,,,, -70300,0.61648315,0.90562046,,,,,,,,,,,,,, -70400,0.65662974,0.9045407,,,,,,,,,,,,,, -70500,0.7116999,0.8896465,,,,,,,,,,,,,, -70600,0.84430295,0.9540688,,,,,,,,,,,,,, -70700,0.5935307,0.89121026,,,,,,,,,,,,,, -70800,0.7557401,0.9598256,,,,,,,,,,,,,, -70839,,,0.054493688,0.0206153878214479,0.30036068,0.0833293105612249,5348.0,0.15762733,0.0506570796010805,2472.0,56209.6598546505,61487.83983325958,56209.6598546505,5272.942055225372,2.210370779037476,0.0 -70900,0.88859725,0.8780551,,,,,,,,,,,,,, -71000,0.7146206,0.866787,,,,,,,,,,,,,, -71100,0.61976093,0.88587356,,,,,,,,,,,,,, -71200,0.6587432,0.88880545,,,,,,,,,,,,,, -71300,0.67413616,0.90170854,,,,,,,,,,,,,, -71400,0.7767568,0.9463938,,,,,,,,,,,,,, -71500,0.8284069,0.89939916,,,,,,,,,,,,,, -71600,0.7321051,0.8903295,,,,,,,,,,,,,, -71700,0.64747846,0.90170074,,,,,,,,,,,,,, -71800,0.7281113,0.90693754,,,,,,,,,,,,,, -71900,0.62506497,0.85873675,,,,,,,,,,,,,, -72000,0.64791673,0.90106326,,,,,,,,,,,,,, -72100,0.66737896,0.89579284,,,,,,,,,,,,,, -72200,0.6264464,0.8395501,,,,,,,,,,,,,, -72300,0.5523641,0.890373,,,,,,,,,,,,,, -72400,0.6286796,0.8864412,,,,,,,,,,,,,, -72500,0.63838166,0.8995709,,,,,,,,,,,,,, -72600,0.6997316,0.8786493,,,,,,,,,,,,,, -72665,,,0.059538107,0.0218196704533824,0.29563853,0.0819583498266989,5348.0,0.15622462,0.0500477322121341,2472.0,57649.84510660172,63061.260159015656,57649.84510660172,5406.034177541733,2.274899482727051,0.0 -72700,0.8366987,0.9023525,,,,,,,,,,,,,, -72800,1.3092241,0.8887502,,,,,,,,,,,,,, -72900,0.6764562,0.9033046,,,,,,,,,,,,,, -73000,0.72988534,0.91248435,,,,,,,,,,,,,, -73100,1.0905774,0.9286291,,,,,,,,,,,,,, -73200,1.1023577,0.92539513,,,,,,,,,,,,,, -73300,0.84945786,0.91826296,,,,,,,,,,,,,, -73400,0.692722,0.87975425,,,,,,,,,,,,,, -73500,0.73473394,0.8316016,,,,,,,,,,,,,, -73600,0.78878534,0.92050636,,,,,,,,,,,,,, -73700,0.61034375,0.8953886,,,,,,,,,,,,,, -73800,0.6885865,0.8995337,,,,,,,,,,,,,, -73900,0.67830956,0.89051175,,,,,,,,,,,,,, -74000,0.9132252,0.8849714,,,,,,,,,,,,,, -74100,0.8388396,0.8866158,,,,,,,,,,,,,, -74200,0.85083133,0.9023468,,,,,,,,,,,,,, -74300,0.6465576,0.9113407,,,,,,,,,,,,,, -74400,1.0544336,0.8942191,,,,,,,,,,,,,, -74496,,,0.057106175,0.0210327421923502,0.29521686,0.0823541906021607,5348.0,0.15535073,0.0497430585176609,2472.0,59089.76412558556,64633.15510225296,59089.76412558556,5537.869856595993,2.33532977104187,0.0 -74500,0.7454313,0.85074055,,,,,,,,,,,,,, -74600,0.7271213,0.8800224,,,,,,,,,,,,,, -74700,1.1036541,0.9085041,,,,,,,,,,,,,, -74800,0.70798576,0.81270164,,,,,,,,,,,,,, -74900,0.64050096,0.8883967,,,,,,,,,,,,,, -75000,0.98642486,0.8870176,,,,,,,,,,,,,, -75100,0.83664054,0.8585004,,,,,,,,,,,,,, -75200,0.66622853,0.90122277,,,,,,,,,,,,,, -75300,0.6873727,0.8745431,,,,,,,,,,,,,, -75400,0.7245534,0.9183668,,,,,,,,,,,,,, -75500,0.9117221,0.89751494,,,,,,,,,,,,,, -75600,0.6942061,0.86938864,,,,,,,,,,,,,, -75700,0.6928495,0.86911845,,,,,,,,,,,,,, -75800,1.5305626,0.89411587,,,,,,,,,,,,,, -75900,0.8583374,0.89743716,,,,,,,,,,,,,, -76000,0.62156045,0.9000943,,,,,,,,,,,,,, -76100,0.8627825,0.90694666,,,,,,,,,,,,,, -76200,0.60701203,0.90430933,,,,,,,,,,,,,, -76300,0.75499505,0.9459226,,,,,,,,,,,,,, -76309,,,0.05786657,0.0213830414055385,0.2951434,0.0819100765613987,5348.0,0.15429394,0.0497024353583978,2472.0,60530.329689741135,66204.74381828308,60530.329689741135,5668.746239900589,2.404423713684082,0.0 -76400,0.7104124,0.8940467,,,,,,,,,,,,,, -76500,0.7023669,0.8832167,,,,,,,,,,,,,, -76600,0.81907135,0.9052643,,,,,,,,,,,,,, -76700,0.638176,0.8865453,,,,,,,,,,,,,, -76800,1.2304223,0.86308545,,,,,,,,,,,,,, -76900,0.77442694,0.8642871,,,,,,,,,,,,,, -77000,0.8275916,0.9077929,,,,,,,,,,,,,, -77006,,,,,,,,,,,61068.35397648811,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index c80175410..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,30 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -265.52112197875977,0.0,45.19680690765381,1,0,45.19680690765381,30.939672,2472,4.561290191538196,310.71803855896,31.839237,4.706369864439956,30.812881,5348,4.233381928420402 -371.200421333313,0.0480868816375732,1485.6898527145386,1756,0,1485.6898527145386,6.0450172,2472,0.8979343123514716,1857.0092232227323,6.2416306,0.941288541945583,6.160726,5348,0.8952083956863011 -497.2301824092865,0.0906984806060791,2926.220492839813,3509,0,2926.220492839813,3.663008,2472,0.7430382060812869,3423.682983160019,3.7979012,0.7728577897344379,4.057168,5348,0.784469525087616 -631.6112017631531,0.1279244422912597,4366.24077963829,5251,0,4366.24077963829,0.7044876,2472,0.2236101801637113,4998.193810939789,0.6250547,0.2147107969490435,1.0514781,5348,0.2947179393108509 -772.7243013381958,0.1675834655761718,5806.4535665512085,7003,0,5806.4535665512085,0.522345,2472,0.1705766457457396,6579.632403612137,0.48509657,0.1643487731571175,0.82567525,5348,0.2396477982563696 -908.856529712677,0.2088100910186767,7246.49645781517,8757,0,7246.49645781517,0.4660455,2472,0.1513212682550322,8155.920823574066,0.42493558,0.1448770358129117,0.76556903,5348,0.2198943780955231 -1044.7315225601196,0.2510659694671631,8686.787279844284,10501,0,8686.787279844284,0.43311206,2472,0.1437044258932017,9732.202117919922,0.41917217,0.1410128513584883,0.72048366,5348,0.208675671239754 -1183.076649427414,0.2884364128112793,10126.841429710388,12213,0,10126.841429710388,0.4196876,2472,0.1372859667296325,11310.710699796677,0.38169208,0.1291717184608512,0.70423025,5348,0.2039448912403332 -1315.8529777526855,0.3267304897308349,11567.57804465294,13956,0,11567.57804465294,0.38722602,2472,0.1270286190157008,12884.33522772789,0.3189625,0.1119997919862711,0.6466541,5348,0.1890187976095079 -1449.5785381793976,0.3680403232574463,13007.677670955658,15677,0,13007.677670955658,0.3840812,2472,0.1235756504783377,14458.273012399672,0.30618104,0.1042991131516367,0.6405959,5348,0.1845680025488284 -1593.3436903953552,0.4052996635437011,14447.619881868362,17405,0,14447.619881868362,0.36800343,2472,0.1195945808705543,16042.089720010756,0.306939,0.1079547587508781,0.6359483,5348,0.1824536335286791 -1727.5816078186035,0.4457442760467529,15888.127633094788,19134,0,15888.127633094788,0.35935163,2472,0.1146994901793512,17616.94973897934,0.3205848,0.1075173854411803,0.61774206,5348,0.1791710514882647 -1861.594043970108,0.4855766296386719,17328.907950878143,20886,0,17328.907950878143,0.34600043,2472,0.1121605427254077,19191.85638141632,0.31008214,0.1054897512332814,0.5979301,5348,0.1724417583054153 -1995.896348953247,0.5268707275390625,18769.11666440964,22583,0,18769.11666440964,0.33355513,2472,0.1080982267990981,20766.481579065323,0.28917006,0.0984051857186953,0.57588047,5348,0.1671896270407523 -2140.9842009544373,0.5696592330932617,20209.478974580765,24285,0,20209.478974580765,0.32747313,2472,0.1063920541100481,22352.04579281807,0.27114815,0.0929278746674549,0.5687677,5348,0.1633181111636753 -2279.357107400894,0.6119377613067627,21649.439897060394,26018,0,21649.439897060394,0.31358135,2472,0.1003392033798468,23930.495884418488,0.24425977,0.0852656563812306,0.5455414,5348,0.1575060100215298 -2420.41985297203,0.6570084095001221,23089.615846395493,27733,0,23089.615846395493,0.30254042,2472,0.0986330306907968,25511.851219892505,0.23917654,0.0822074030131634,0.533727,5348,0.1542813558994757 -2562.9808316230774,0.6976122856140137,24530.126829862595,29456,0,24530.126829862595,0.29362866,2472,0.0956878516442223,27095.03487610817,0.25039864,0.0864913763217195,0.51835567,5348,0.1496471224306554 -2702.8475642204285,0.7462007999420166,25970.27191734314,31194,0,25970.27191734314,0.28454825,2472,0.0915646009790181,28675.169151067734,0.22763942,0.0768750779982528,0.5053736,5348,0.1465672881045019 -2838.235775709152,0.7904810905456543,27410.788177251816,32913,0,27410.788177251816,0.27736688,2472,0.0895943777547579,30251.19058847428,0.2309214,0.0789493961544981,0.49034488,5348,0.1433522886355079 -2976.252999305725,0.8371679782867432,28851.9081697464,34613,0,28851.9081697464,0.2623799,2472,0.0847805333820811,31830.44634723664,0.21600437,0.0720131717765913,0.47284636,5348,0.1371153827587205 -3111.6914982795715,0.8780829906463623,30292.434127807617,36352,0,30292.434127807617,0.25495288,2472,0.0809822679909816,33406.52474331856,0.16080004,0.0557383434293738,0.46172354,5348,0.1340065844733869 -3253.490313768387,0.9202170372009276,31732.72907662392,38068,0,31732.72907662392,0.24388216,2472,0.0778542847277232,34988.733575344086,0.17805463,0.0605900614559354,0.4462943,5348,0.1298164650453286 -3387.4322237968445,0.9764108657836914,33173.24671959877,39784,0,33173.24671959877,0.2375296,2472,0.0756606341275161,36563.3217394352,0.22174585,0.0760185069672331,0.43836048,5348,0.1261187329233324 -3523.49757862091,1.033158540725708,34613.45397615433,41515,0,34613.45397615433,0.22996153,2472,0.0736294761643613,38139.72946357727,0.22483462,0.0773355834725792,0.4263075,5348,0.1229906253318787 -3657.8997716903687,1.0785846710205078,36053.44173717499,43212,0,36053.44173717499,0.22553718,2472,0.0716998760993642,39714.23803758621,0.25475532,0.0880018250843809,0.41781008,5348,0.1205383434546279 -3791.0279178619385,1.1358962059020996,37493.750953912735,44916,0,37493.750953912735,0.22203259,2472,0.0716592529401011,41287.81129693985,0.22422256,0.0749018245618166,0.41305763,5348,0.1193025478629425 -3924.312516689301,1.1946141719818115,38933.89561486244,46632,0,38933.89561486244,0.22045316,2472,0.0711108402900493,42861.37547492981,0.20445894,0.0709529800808346,0.410703,5348,0.1183563918630584 -4061.046733379364,1.2506194114685059,40087.15713596344,48000,0,40087.15713596344,0.22013809,2472,0.0707452318566815,44151.489364147186,0.18561524,0.0648674808861434,0.4110458,5348,0.11803778831207701 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index a05b1bef6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,511 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,18.907827,33.328632,,,,,,,,,,,,,, -1,,,31.839237,4.706369864439956,30.812881,4.233381928420402,5348.0,30.939672,4.561290191538196,2472.0,45.19680690765381,310.71803855896,45.19680690765381,265.52112197875977,0.0,0.0 -100,5.1315446,8.3750305,,,,,,,,,,,,,, -200,1.0536582,6.452946,,,,,,,,,,,,,, -300,0.59183025,6.039233,,,,,,,,,,,,,, -400,0.5341817,5.8675017,,,,,,,,,,,,,, -500,0.5897763,5.8342476,,,,,,,,,,,,,, -600,0.5370457,5.8096514,,,,,,,,,,,,,, -700,0.46108228,5.738929,,,,,,,,,,,,,, -800,0.40557557,5.5978355,,,,,,,,,,,,,, -900,0.70261335,5.3704247,,,,,,,,,,,,,, -1000,1.0269855,5.0263076,,,,,,,,,,,,,, -1100,1.1984075,4.5091863,,,,,,,,,,,,,, -1200,0.99162847,4.2462935,,,,,,,,,,,,,, -1300,1.2936592,3.9105842,,,,,,,,,,,,,, -1400,1.7837783,3.6679265,,,,,,,,,,,,,, -1500,1.468298,3.526956,,,,,,,,,,,,,, -1600,4.0934477,3.2863777,,,,,,,,,,,,,, -1700,2.6193726,3.1854153,,,,,,,,,,,,,, -1756,,,6.2416306,0.941288541945583,6.160726,0.8952083956863011,5348.0,6.0450172,0.8979343123514716,2472.0,1485.6898527145386,1857.0092232227323,1485.6898527145386,371.200421333313,0.0480868816375732,0.0 -1800,2.727338,3.0896928,,,,,,,,,,,,,, -1900,2.6139917,2.9858327,,,,,,,,,,,,,, -2000,4.323144,2.8629975,,,,,,,,,,,,,, -2100,1.8185457,2.7678907,,,,,,,,,,,,,, -2200,2.8615885,2.694107,,,,,,,,,,,,,, -2300,3.000121,2.6179686,,,,,,,,,,,,,, -2400,2.7142115,2.5975552,,,,,,,,,,,,,, -2500,2.9538615,2.5972545,,,,,,,,,,,,,, -2600,2.81537,2.4670613,,,,,,,,,,,,,, -2700,2.508174,2.3961797,,,,,,,,,,,,,, -2800,2.2130976,2.4116251,,,,,,,,,,,,,, -2900,2.992073,2.282197,,,,,,,,,,,,,, -3000,4.228219,2.2496953,,,,,,,,,,,,,, -3100,4.3625283,2.2424734,,,,,,,,,,,,,, -3200,3.3008382,2.2098398,,,,,,,,,,,,,, -3300,3.2990117,2.2010376,,,,,,,,,,,,,, -3400,5.6953444,2.1742084,,,,,,,,,,,,,, -3500,3.5988555,2.120544,,,,,,,,,,,,,, -3509,,,3.7979012,0.7728577897344379,4.057168,0.784469525087616,5348.0,3.663008,0.7430382060812869,2472.0,2926.220492839813,3423.682983160019,2926.220492839813,497.2301824092865,0.0906984806060791,0.0 -3600,3.2587664,2.184362,,,,,,,,,,,,,, -3700,4.221708,2.137657,,,,,,,,,,,,,, -3800,10.193707,2.0914571,,,,,,,,,,,,,, -3900,3.4663103,2.0874958,,,,,,,,,,,,,, -4000,4.058238,2.0314536,,,,,,,,,,,,,, -4100,2.8326163,2.0373356,,,,,,,,,,,,,, -4200,3.649523,1.9670014,,,,,,,,,,,,,, -4300,3.1984727,1.9724061,,,,,,,,,,,,,, -4400,3.3286564,1.9907665,,,,,,,,,,,,,, -4500,3.3229327,1.9986714,,,,,,,,,,,,,, -4600,5.302098,2.017605,,,,,,,,,,,,,, -4700,2.9938614,1.9151272,,,,,,,,,,,,,, -4800,4.3783855,2.006206,,,,,,,,,,,,,, -4900,4.627344,2.0002222,,,,,,,,,,,,,, -5000,5.017623,2.0095055,,,,,,,,,,,,,, -5100,3.3223758,1.908209,,,,,,,,,,,,,, -5200,3.0856578,1.8919965,,,,,,,,,,,,,, -5251,,,0.6250547,0.2147107969490435,1.0514781,0.2947179393108509,5348.0,0.7044876,0.2236101801637113,2472.0,4366.24077963829,4998.193810939789,4366.24077963829,631.6112017631531,0.1279244422912597,0.0 -5300,4.1761565,1.8567184,,,,,,,,,,,,,, -5400,4.2653747,1.836185,,,,,,,,,,,,,, -5500,3.1623464,1.8707683,,,,,,,,,,,,,, -5600,3.4739532,1.8536611,,,,,,,,,,,,,, -5700,2.2800686,1.7937882,,,,,,,,,,,,,, -5800,2.8026211,1.8865249,,,,,,,,,,,,,, -5900,2.9538004,1.9099127,,,,,,,,,,,,,, -6000,3.1795144,1.8006537,,,,,,,,,,,,,, -6100,3.703375,1.8428624,,,,,,,,,,,,,, -6200,2.9592617,1.8125619,,,,,,,,,,,,,, -6300,3.6089747,1.8270853,,,,,,,,,,,,,, -6400,2.974819,1.8101366,,,,,,,,,,,,,, -6500,2.6252067,1.805419,,,,,,,,,,,,,, -6600,3.3187494,1.7837516,,,,,,,,,,,,,, -6700,3.6629937,1.7583128,,,,,,,,,,,,,, -6800,3.0085666,1.7738262,,,,,,,,,,,,,, -6900,3.2352927,1.7476445,,,,,,,,,,,,,, -7000,2.938945,1.7727888,,,,,,,,,,,,,, -7003,,,0.48509657,0.1643487731571175,0.82567525,0.2396477982563696,5348.0,0.522345,0.1705766457457396,2472.0,5806.4535665512085,6579.632403612137,5806.4535665512085,772.7243013381958,0.1675834655761718,0.0 -7100,2.4936447,1.7290806,,,,,,,,,,,,,, -7200,3.391585,1.760207,,,,,,,,,,,,,, -7300,4.4906178,1.7122644,,,,,,,,,,,,,, -7400,2.50082,1.7018763,,,,,,,,,,,,,, -7500,3.7231956,1.7133982,,,,,,,,,,,,,, -7600,2.2357905,1.7441866,,,,,,,,,,,,,, -7700,3.422224,1.7513579,,,,,,,,,,,,,, -7800,3.5172417,1.7593836,,,,,,,,,,,,,, -7900,3.5350523,1.7134187,,,,,,,,,,,,,, -8000,2.9781656,1.672087,,,,,,,,,,,,,, -8100,2.5718186,1.6810832,,,,,,,,,,,,,, -8200,2.787795,1.752492,,,,,,,,,,,,,, -8300,2.8851635,1.707621,,,,,,,,,,,,,, -8400,2.9792304,1.6944336,,,,,,,,,,,,,, -8500,3.470474,1.7266784,,,,,,,,,,,,,, -8600,3.2805912,1.6883556,,,,,,,,,,,,,, -8700,1.7813458,1.6521308,,,,,,,,,,,,,, -8757,,,0.42493558,0.1448770358129117,0.76556903,0.2198943780955231,5348.0,0.4660455,0.1513212682550322,2472.0,7246.49645781517,8155.920823574066,7246.49645781517,908.856529712677,0.2088100910186767,0.0 -8800,2.2132912,1.620538,,,,,,,,,,,,,, -8900,4.1262126,1.647629,,,,,,,,,,,,,, -9000,2.9122796,1.7183744,,,,,,,,,,,,,, -9100,4.385452,1.6857576,,,,,,,,,,,,,, -9200,2.628357,1.6536741,,,,,,,,,,,,,, -9300,2.566063,1.6345806,,,,,,,,,,,,,, -9400,3.1177547,1.6756616,,,,,,,,,,,,,, -9500,2.7990215,1.6376615,,,,,,,,,,,,,, -9600,2.9441543,1.7014425,,,,,,,,,,,,,, -9700,3.6530633,1.6918896,,,,,,,,,,,,,, -9800,3.411896,1.6653918,,,,,,,,,,,,,, -9900,4.09438,1.5701013,,,,,,,,,,,,,, -10000,2.6829987,1.5890591,,,,,,,,,,,,,, -10100,2.88226,1.649359,,,,,,,,,,,,,, -10200,3.2300158,1.6398185,,,,,,,,,,,,,, -10300,1.6168786,1.5731937,,,,,,,,,,,,,, -10400,3.2129455,1.637475,,,,,,,,,,,,,, -10500,4.3689017,1.6257578,,,,,,,,,,,,,, -10501,,,0.41917217,0.1410128513584883,0.72048366,0.208675671239754,5348.0,0.43311206,0.1437044258932017,2472.0,8686.787279844284,9732.202117919922,8686.787279844284,1044.7315225601196,0.2510659694671631,0.0 -10600,2.0677462,1.5911031,,,,,,,,,,,,,, -10700,3.6014128,1.6485755,,,,,,,,,,,,,, -10800,3.5842483,1.6551782,,,,,,,,,,,,,, -10900,3.01428,1.6640056,,,,,,,,,,,,,, -11000,2.9522433,1.6016482,,,,,,,,,,,,,, -11100,2.6340952,1.5954702,,,,,,,,,,,,,, -11200,2.837244,1.5735683,,,,,,,,,,,,,, -11300,3.5822191,1.6808814,,,,,,,,,,,,,, -11400,2.3734643,1.6368088,,,,,,,,,,,,,, -11500,3.3255882,1.5847193,,,,,,,,,,,,,, -11600,3.1673281,1.5832883,,,,,,,,,,,,,, -11700,3.053937,1.6048673,,,,,,,,,,,,,, -11800,3.6313426,1.5977858,,,,,,,,,,,,,, -11900,2.6400585,1.6101028,,,,,,,,,,,,,, -12000,3.23551,1.5900259,,,,,,,,,,,,,, -12100,3.003272,1.6647555,,,,,,,,,,,,,, -12200,2.8705647,1.5723845,,,,,,,,,,,,,, -12213,,,0.38169208,0.1291717184608512,0.70423025,0.2039448912403332,5348.0,0.4196876,0.1372859667296325,2472.0,10126.841429710388,11310.710699796677,10126.841429710388,1183.076649427414,0.2884364128112793,0.0 -12300,3.3411677,1.6258415,,,,,,,,,,,,,, -12400,2.0404248,1.5426854,,,,,,,,,,,,,, -12500,3.790128,1.5844653,,,,,,,,,,,,,, -12600,3.7545962,1.6085409,,,,,,,,,,,,,, -12700,3.3547266,1.5382904,,,,,,,,,,,,,, -12800,2.4359012,1.6125852,,,,,,,,,,,,,, -12900,3.7249362,1.5550315,,,,,,,,,,,,,, -13000,3.7157967,1.6392086,,,,,,,,,,,,,, -13100,3.462203,1.5880555,,,,,,,,,,,,,, -13200,2.8928008,1.5878972,,,,,,,,,,,,,, -13300,3.1053505,1.5679606,,,,,,,,,,,,,, -13400,5.298422,1.5693449,,,,,,,,,,,,,, -13500,3.040257,1.593849,,,,,,,,,,,,,, -13600,3.501657,1.547313,,,,,,,,,,,,,, -13700,3.8598495,1.6149117,,,,,,,,,,,,,, -13800,2.6652114,1.615319,,,,,,,,,,,,,, -13900,2.7443347,1.5562568,,,,,,,,,,,,,, -13956,,,0.3189625,0.1119997919862711,0.6466541,0.1890187976095079,5348.0,0.38722602,0.1270286190157008,2472.0,11567.57804465294,12884.33522772789,11567.57804465294,1315.8529777526855,0.3267304897308349,0.0 -14000,4.477115,1.5668529,,,,,,,,,,,,,, -14100,3.951185,1.5716236,,,,,,,,,,,,,, -14200,4.2440825,1.5179622,,,,,,,,,,,,,, -14300,3.1285238,1.516911,,,,,,,,,,,,,, -14400,3.4852605,1.6125517,,,,,,,,,,,,,, -14500,2.2210312,1.4833521,,,,,,,,,,,,,, -14600,2.8984017,1.4979252,,,,,,,,,,,,,, -14700,2.4844298,1.5268412,,,,,,,,,,,,,, -14800,2.1895921,1.5701343,,,,,,,,,,,,,, -14900,2.6809256,1.5583903,,,,,,,,,,,,,, -15000,2.7555976,1.4916303,,,,,,,,,,,,,, -15100,4.39234,1.5303955,,,,,,,,,,,,,, -15200,4.0674877,1.5601119,,,,,,,,,,,,,, -15300,3.634778,1.5396681,,,,,,,,,,,,,, -15400,3.4891107,1.5543356,,,,,,,,,,,,,, -15500,2.971856,1.450003,,,,,,,,,,,,,, -15600,2.3881814,1.5271093,,,,,,,,,,,,,, -15677,,,0.30618104,0.1042991131516367,0.6405959,0.1845680025488284,5348.0,0.3840812,0.1235756504783377,2472.0,13007.677670955658,14458.273012399672,13007.677670955658,1449.5785381793976,0.3680403232574463,0.0 -15700,2.211635,1.5208443,,,,,,,,,,,,,, -15800,4.478663,1.5438426,,,,,,,,,,,,,, -15900,2.7238095,1.4590741,,,,,,,,,,,,,, -16000,3.0260134,1.5846244,,,,,,,,,,,,,, -16100,2.6272283,1.4935315,,,,,,,,,,,,,, -16200,2.8691761,1.5565641,,,,,,,,,,,,,, -16300,3.008323,1.5652473,,,,,,,,,,,,,, -16400,2.240835,1.4782585,,,,,,,,,,,,,, -16500,3.2711132,1.5486456,,,,,,,,,,,,,, -16600,2.5161774,1.5217085,,,,,,,,,,,,,, -16700,2.7338428,1.4333175,,,,,,,,,,,,,, -16800,4.4227715,1.5653245,,,,,,,,,,,,,, -16900,2.7742922,1.5245724,,,,,,,,,,,,,, -17000,2.8885953,1.5923175,,,,,,,,,,,,,, -17100,3.0488524,1.506002,,,,,,,,,,,,,, -17200,2.8362215,1.6000216,,,,,,,,,,,,,, -17300,2.425317,1.5736234,,,,,,,,,,,,,, -17400,3.3700774,1.5549908,,,,,,,,,,,,,, -17405,,,0.306939,0.1079547587508781,0.6359483,0.1824536335286791,5348.0,0.36800343,0.1195945808705543,2472.0,14447.619881868362,16042.089720010756,14447.619881868362,1593.3436903953552,0.4052996635437011,0.0 -17500,2.9971898,1.520827,,,,,,,,,,,,,, -17600,3.2735167,1.4665897,,,,,,,,,,,,,, -17700,2.2705772,1.559656,,,,,,,,,,,,,, -17800,3.2979722,1.5468948,,,,,,,,,,,,,, -17900,4.421806,1.4700826,,,,,,,,,,,,,, -18000,3.0426908,1.5637814,,,,,,,,,,,,,, -18100,2.691981,1.519422,,,,,,,,,,,,,, -18200,2.4144266,1.490258,,,,,,,,,,,,,, -18300,3.1350648,1.5421547,,,,,,,,,,,,,, -18400,2.8925104,1.5405154,,,,,,,,,,,,,, -18500,4.2657776,1.5437769,,,,,,,,,,,,,, -18600,1.8456998,1.5080167,,,,,,,,,,,,,, -18700,3.0571532,1.4992957,,,,,,,,,,,,,, -18800,2.6758096,1.4587471,,,,,,,,,,,,,, -18900,2.7258556,1.4508253,,,,,,,,,,,,,, -19000,2.57074,1.5254366,,,,,,,,,,,,,, -19100,2.8707166,1.5119846,,,,,,,,,,,,,, -19134,,,0.3205848,0.1075173854411803,0.61774206,0.1791710514882647,5348.0,0.35935163,0.1146994901793512,2472.0,15888.127633094788,17616.94973897934,15888.127633094788,1727.5816078186035,0.4457442760467529,0.0 -19200,4.015628,1.5027703,,,,,,,,,,,,,, -19300,3.0151083,1.5214378,,,,,,,,,,,,,, -19400,2.9676316,1.5094639,,,,,,,,,,,,,, -19500,3.0748346,1.4396131,,,,,,,,,,,,,, -19600,2.0954719,1.4905567,,,,,,,,,,,,,, -19700,2.6144435,1.53132,,,,,,,,,,,,,, -19800,2.1095152,1.4836408,,,,,,,,,,,,,, -19900,3.5639193,1.506018,,,,,,,,,,,,,, -20000,2.6912777,1.4832542,,,,,,,,,,,,,, -20100,1.8748618,1.4301496,,,,,,,,,,,,,, -20200,2.4433258,1.5356333,,,,,,,,,,,,,, -20300,3.260252,1.5197077,,,,,,,,,,,,,, -20400,2.9646294,1.481169,,,,,,,,,,,,,, -20500,3.139422,1.521658,,,,,,,,,,,,,, -20600,1.832907,1.4966311,,,,,,,,,,,,,, -20700,2.7410967,1.4489498,,,,,,,,,,,,,, -20800,2.3645167,1.4409633,,,,,,,,,,,,,, -20886,,,0.31008214,0.1054897512332814,0.5979301,0.1724417583054153,5348.0,0.34600043,0.1121605427254077,2472.0,17328.907950878143,19191.85638141632,17328.907950878143,1861.594043970108,0.4855766296386719,0.0 -20900,2.4464693,1.4682562,,,,,,,,,,,,,, -21000,3.0023515,1.4770188,,,,,,,,,,,,,, -21100,2.8982022,1.4546305,,,,,,,,,,,,,, -21200,3.3637989,1.4523389,,,,,,,,,,,,,, -21300,2.311914,1.3912908,,,,,,,,,,,,,, -21400,2.7726567,1.4624779,,,,,,,,,,,,,, -21500,3.6177588,1.4896871,,,,,,,,,,,,,, -21600,3.6109524,1.5139805,,,,,,,,,,,,,, -21700,2.5225446,1.4743087,,,,,,,,,,,,,, -21800,2.6034074,1.4512748,,,,,,,,,,,,,, -21900,2.2070975,1.4388185,,,,,,,,,,,,,, -22000,4.013652,1.4304032,,,,,,,,,,,,,, -22100,2.177917,1.4663991,,,,,,,,,,,,,, -22200,2.790142,1.4819882,,,,,,,,,,,,,, -22300,2.2403145,1.4948094,,,,,,,,,,,,,, -22400,3.957549,1.4315715,,,,,,,,,,,,,, -22500,2.6749468,1.5156472,,,,,,,,,,,,,, -22583,,,0.28917006,0.0984051857186953,0.57588047,0.1671896270407523,5348.0,0.33355513,0.1080982267990981,2472.0,18769.11666440964,20766.481579065323,18769.11666440964,1995.896348953247,0.5268707275390625,0.0 -22600,3.1683383,1.4853595,,,,,,,,,,,,,, -22700,2.0160556,1.3742471,,,,,,,,,,,,,, -22800,3.1387284,1.4141147,,,,,,,,,,,,,, -22900,2.2818508,1.4958313,,,,,,,,,,,,,, -23000,4.430839,1.4740896,,,,,,,,,,,,,, -23100,2.1302376,1.3956141,,,,,,,,,,,,,, -23200,3.0975778,1.4904014,,,,,,,,,,,,,, -23300,2.2168117,1.439373,,,,,,,,,,,,,, -23400,3.3392155,1.4835881,,,,,,,,,,,,,, -23500,2.57613,1.4540223,,,,,,,,,,,,,, -23600,2.609616,1.4010203,,,,,,,,,,,,,, -23700,2.6713722,1.3797332,,,,,,,,,,,,,, -23800,2.93525,1.4190952,,,,,,,,,,,,,, -23900,2.8093178,1.4804341,,,,,,,,,,,,,, -24000,2.2775536,1.4737083,,,,,,,,,,,,,, -24100,3.3914793,1.4131652,,,,,,,,,,,,,, -24200,3.032846,1.4329659,,,,,,,,,,,,,, -24285,,,0.27114815,0.0929278746674549,0.5687677,0.1633181111636753,5348.0,0.32747313,0.1063920541100481,2472.0,20209.478974580765,22352.04579281807,20209.478974580765,2140.9842009544373,0.5696592330932617,0.0 -24300,2.3022459,1.3702278,,,,,,,,,,,,,, -24400,4.7450714,1.4783908,,,,,,,,,,,,,, -24500,2.6518836,1.3578936,,,,,,,,,,,,,, -24600,3.052063,1.4134538,,,,,,,,,,,,,, -24700,2.453921,1.4863857,,,,,,,,,,,,,, -24800,3.3732803,1.4554776,,,,,,,,,,,,,, -24900,2.6789858,1.4373087,,,,,,,,,,,,,, -25000,2.8972292,1.4239788,,,,,,,,,,,,,, -25100,3.1225,1.412563,,,,,,,,,,,,,, -25200,3.2985404,1.4240866,,,,,,,,,,,,,, -25300,2.0857077,1.3704474,,,,,,,,,,,,,, -25400,2.4484577,1.4035363,,,,,,,,,,,,,, -25500,3.836758,1.4104639,,,,,,,,,,,,,, -25600,2.9997272,1.4298955,,,,,,,,,,,,,, -25700,3.1303637,1.4070504,,,,,,,,,,,,,, -25800,3.9401996,1.4162308,,,,,,,,,,,,,, -25900,2.619787,1.4049605,,,,,,,,,,,,,, -26000,2.1472268,1.4115468,,,,,,,,,,,,,, -26018,,,0.24425977,0.0852656563812306,0.5455414,0.1575060100215298,5348.0,0.31358135,0.1003392033798468,2472.0,21649.439897060394,23930.495884418488,21649.439897060394,2279.357107400894,0.6119377613067627,0.0 -26100,4.7233877,1.4209716,,,,,,,,,,,,,, -26200,3.0313041,1.3974208,,,,,,,,,,,,,, -26300,2.6678336,1.316621,,,,,,,,,,,,,, -26400,3.2163281,1.4016258,,,,,,,,,,,,,, -26500,3.1046774,1.3914183,,,,,,,,,,,,,, -26600,4.1619754,1.3917652,,,,,,,,,,,,,, -26700,3.02297,1.3478986,,,,,,,,,,,,,, -26800,3.228734,1.369947,,,,,,,,,,,,,, -26900,4.0206666,1.3905909,,,,,,,,,,,,,, -27000,5.0303698,1.3886693,,,,,,,,,,,,,, -27100,2.0822208,1.374711,,,,,,,,,,,,,, -27200,2.3302824,1.381333,,,,,,,,,,,,,, -27300,2.2965677,1.3862474,,,,,,,,,,,,,, -27400,3.3803873,1.3679242,,,,,,,,,,,,,, -27500,2.5197654,1.3715476,,,,,,,,,,,,,, -27600,2.1832683,1.2974514,,,,,,,,,,,,,, -27700,2.340488,1.3829154,,,,,,,,,,,,,, -27733,,,0.23917654,0.0822074030131634,0.533727,0.1542813558994757,5348.0,0.30254042,0.0986330306907968,2472.0,23089.615846395493,25511.851219892505,23089.615846395493,2420.41985297203,0.6570084095001221,0.0 -27800,2.5711892,1.3503864,,,,,,,,,,,,,, -27900,1.7657114,1.3236988,,,,,,,,,,,,,, -28000,2.2700872,1.3171172,,,,,,,,,,,,,, -28100,3.087555,1.3544122,,,,,,,,,,,,,, -28200,2.2832375,1.3338096,,,,,,,,,,,,,, -28300,2.7353175,1.2986498,,,,,,,,,,,,,, -28400,3.3545954,1.3149126,,,,,,,,,,,,,, -28500,4.3120294,1.3729339,,,,,,,,,,,,,, -28600,3.0428257,1.3458027,,,,,,,,,,,,,, -28700,2.031659,1.3901116,,,,,,,,,,,,,, -28800,3.584622,1.4075301,,,,,,,,,,,,,, -28900,2.2413256,1.3293536,,,,,,,,,,,,,, -29000,3.5273333,1.3511072,,,,,,,,,,,,,, -29100,2.175258,1.3307996,,,,,,,,,,,,,, -29200,2.2294686,1.3194772,,,,,,,,,,,,,, -29300,4.0112276,1.3572751,,,,,,,,,,,,,, -29400,2.2987137,1.3211967,,,,,,,,,,,,,, -29456,,,0.25039864,0.0864913763217195,0.51835567,0.1496471224306554,5348.0,0.29362866,0.0956878516442223,2472.0,24530.126829862595,27095.03487610817,24530.126829862595,2562.9808316230774,0.6976122856140137,0.0 -29500,2.8352501,1.41522,,,,,,,,,,,,,, -29600,2.474674,1.409807,,,,,,,,,,,,,, -29700,4.190944,1.3080767,,,,,,,,,,,,,, -29800,4.3837314,1.316413,,,,,,,,,,,,,, -29900,2.7107062,1.3084306,,,,,,,,,,,,,, -30000,3.400545,1.3597013,,,,,,,,,,,,,, -30100,1.6847614,1.2799854,,,,,,,,,,,,,, -30200,1.9680358,1.3564942,,,,,,,,,,,,,, -30300,3.0851223,1.2710679,,,,,,,,,,,,,, -30400,4.8233986,1.3539307,,,,,,,,,,,,,, -30500,3.255781,1.2751822,,,,,,,,,,,,,, -30600,1.8445373,1.3654974,,,,,,,,,,,,,, -30700,2.3109195,1.3158965,,,,,,,,,,,,,, -30800,2.8567424,1.3408082,,,,,,,,,,,,,, -30900,3.424853,1.305925,,,,,,,,,,,,,, -31000,3.359862,1.3261495,,,,,,,,,,,,,, -31100,2.164987,1.2830354,,,,,,,,,,,,,, -31194,,,0.22763942,0.0768750779982528,0.5053736,0.1465672881045019,5348.0,0.28454825,0.0915646009790181,2472.0,25970.27191734314,28675.169151067734,25970.27191734314,2702.8475642204285,0.7462007999420166,0.0 -31200,2.3461444,1.3536747,,,,,,,,,,,,,, -31300,2.8219774,1.271702,,,,,,,,,,,,,, -31400,2.2333274,1.2782009,,,,,,,,,,,,,, -31500,3.5145552,1.3358467,,,,,,,,,,,,,, -31600,2.241169,1.3500088,,,,,,,,,,,,,, -31700,3.0011666,1.3264014,,,,,,,,,,,,,, -31800,2.631812,1.191532,,,,,,,,,,,,,, -31900,3.7510734,1.335671,,,,,,,,,,,,,, -32000,2.4920647,1.2828286,,,,,,,,,,,,,, -32100,3.6276395,1.2772369,,,,,,,,,,,,,, -32200,2.2324758,1.2373592,,,,,,,,,,,,,, -32300,3.7090085,1.2577168,,,,,,,,,,,,,, -32400,2.5677915,1.2733971,,,,,,,,,,,,,, -32500,2.3763638,1.2443085,,,,,,,,,,,,,, -32600,2.5433302,1.3270189,,,,,,,,,,,,,, -32700,2.7709792,1.3631446,,,,,,,,,,,,,, -32800,3.5626633,1.3028651,,,,,,,,,,,,,, -32900,2.6052654,1.320257,,,,,,,,,,,,,, -32913,,,0.2309214,0.0789493961544981,0.49034488,0.1433522886355079,5348.0,0.27736688,0.0895943777547579,2472.0,27410.788177251816,30251.19058847428,27410.788177251816,2838.235775709152,0.7904810905456543,0.0 -33000,3.3115034,1.3034393,,,,,,,,,,,,,, -33100,3.327431,1.2677295,,,,,,,,,,,,,, -33200,2.8896701,1.2675538,,,,,,,,,,,,,, -33300,3.824125,1.2872663,,,,,,,,,,,,,, -33400,1.8282702,1.2448804,,,,,,,,,,,,,, -33500,3.7094307,1.2970597,,,,,,,,,,,,,, -33600,1.9770488,1.2734039,,,,,,,,,,,,,, -33700,2.811014,1.2653035,,,,,,,,,,,,,, -33800,2.5607364,1.2248391,,,,,,,,,,,,,, -33900,3.0440736,1.3258406,,,,,,,,,,,,,, -34000,2.3410547,1.2878053,,,,,,,,,,,,,, -34100,4.4467373,1.2920451,,,,,,,,,,,,,, -34200,2.203474,1.255845,,,,,,,,,,,,,, -34300,2.3926716,1.2589018,,,,,,,,,,,,,, -34400,2.2678792,1.2759197,,,,,,,,,,,,,, -34500,4.7497296,1.2303038,,,,,,,,,,,,,, -34600,4.309547,1.2760859,,,,,,,,,,,,,, -34613,,,0.21600437,0.0720131717765913,0.47284636,0.1371153827587205,5348.0,0.2623799,0.0847805333820811,2472.0,28851.9081697464,31830.44634723664,28851.9081697464,2976.252999305725,0.8371679782867432,0.0 -34700,2.3562899,1.2211496,,,,,,,,,,,,,, -34800,3.2783263,1.2821732,,,,,,,,,,,,,, -34900,2.768305,1.2453088,,,,,,,,,,,,,, -35000,2.6737945,1.2303065,,,,,,,,,,,,,, -35100,2.1563573,1.2123052,,,,,,,,,,,,,, -35200,2.852273,1.2504743,,,,,,,,,,,,,, -35300,4.1290016,1.3075001,,,,,,,,,,,,,, -35400,2.596908,1.2178917,,,,,,,,,,,,,, -35500,3.1638975,1.2644955,,,,,,,,,,,,,, -35600,2.6639752,1.2634212,,,,,,,,,,,,,, -35700,2.6153271,1.2664067,,,,,,,,,,,,,, -35800,2.569157,1.2390891,,,,,,,,,,,,,, -35900,3.1164606,1.220099,,,,,,,,,,,,,, -36000,2.5445843,1.2613189,,,,,,,,,,,,,, -36100,2.4699183,1.2513524,,,,,,,,,,,,,, -36200,3.1496665,1.190519,,,,,,,,,,,,,, -36300,2.4021118,1.2496667,,,,,,,,,,,,,, -36352,,,0.16080004,0.0557383434293738,0.46172354,0.1340065844733869,5348.0,0.25495288,0.0809822679909816,2472.0,30292.434127807617,33406.52474331856,30292.434127807617,3111.6914982795715,0.8780829906463623,0.0 -36400,3.349783,1.2353542,,,,,,,,,,,,,, -36500,3.5190332,1.2506901,,,,,,,,,,,,,, -36600,5.7452884,1.1917545,,,,,,,,,,,,,, -36700,2.6421885,1.2382829,,,,,,,,,,,,,, -36800,3.0414708,1.2140887,,,,,,,,,,,,,, -36900,2.5919719,1.2483693,,,,,,,,,,,,,, -37000,6.7865763,1.190247,,,,,,,,,,,,,, -37100,2.3416834,1.2204764,,,,,,,,,,,,,, -37200,3.4000242,1.2117269,,,,,,,,,,,,,, -37300,4.0126214,1.219156,,,,,,,,,,,,,, -37400,3.2480116,1.2434394,,,,,,,,,,,,,, -37500,3.068931,1.1894693,,,,,,,,,,,,,, -37600,4.4234095,1.194214,,,,,,,,,,,,,, -37700,2.8313618,1.2071265,,,,,,,,,,,,,, -37800,3.320006,1.2407615,,,,,,,,,,,,,, -37900,3.1058202,1.2366703,,,,,,,,,,,,,, -38000,3.222911,1.2097538,,,,,,,,,,,,,, -38068,,,0.17805463,0.0605900614559354,0.4462943,0.1298164650453286,5348.0,0.24388216,0.0778542847277232,2472.0,31732.72907662392,34988.733575344086,31732.72907662392,3253.490313768387,0.9202170372009276,0.0 -38100,2.5723538,1.2328732,,,,,,,,,,,,,, -38200,2.9405417,1.2518711,,,,,,,,,,,,,, -38300,2.571787,1.1424471,,,,,,,,,,,,,, -38400,3.256955,1.2650788,,,,,,,,,,,,,, -38500,2.8558712,1.225839,,,,,,,,,,,,,, -38600,2.5374355,1.1827363,,,,,,,,,,,,,, -38700,3.2304816,1.2393174,,,,,,,,,,,,,, -38800,2.243348,1.1803668,,,,,,,,,,,,,, -38900,3.5308459,1.2013518,,,,,,,,,,,,,, -39000,2.769773,1.1286522,,,,,,,,,,,,,, -39100,3.1620922,1.1878921,,,,,,,,,,,,,, -39200,3.4830844,1.1882932,,,,,,,,,,,,,, -39300,6.0242085,1.1934053,,,,,,,,,,,,,, -39400,2.8109298,1.2064133,,,,,,,,,,,,,, -39500,3.1216428,1.1734633,,,,,,,,,,,,,, -39600,3.2910473,1.1916721,,,,,,,,,,,,,, -39700,2.932282,1.2152354,,,,,,,,,,,,,, -39784,,,0.22174585,0.0760185069672331,0.43836048,0.1261187329233324,5348.0,0.2375296,0.0756606341275161,2472.0,33173.24671959877,36563.3217394352,33173.24671959877,3387.4322237968445,0.9764108657836914,0.0 -39800,3.6436267,1.1750829,,,,,,,,,,,,,, -39900,2.7656696,1.1566958,,,,,,,,,,,,,, -40000,2.9082909,1.2117007,,,,,,,,,,,,,, -40100,2.9461474,1.2436911,,,,,,,,,,,,,, -40200,3.860479,1.1398598,,,,,,,,,,,,,, -40300,2.4377377,1.1582018,,,,,,,,,,,,,, -40400,2.97514,1.1897111,,,,,,,,,,,,,, -40500,3.39387,1.1439308,,,,,,,,,,,,,, -40600,2.4582653,1.13937,,,,,,,,,,,,,, -40700,2.2887046,1.1963763,,,,,,,,,,,,,, -40800,2.3018456,1.1792321,,,,,,,,,,,,,, -40900,2.9046154,1.1660516,,,,,,,,,,,,,, -41000,4.9113026,1.1420571,,,,,,,,,,,,,, -41100,4.158077,1.1564845,,,,,,,,,,,,,, -41200,3.1247413,1.1552815,,,,,,,,,,,,,, -41300,3.460479,1.1492617,,,,,,,,,,,,,, -41400,2.2874026,1.1673183,,,,,,,,,,,,,, -41500,2.2701743,1.1897098,,,,,,,,,,,,,, -41515,,,0.22483462,0.0773355834725792,0.4263075,0.1229906253318787,5348.0,0.22996153,0.0736294761643613,2472.0,34613.45397615433,38139.72946357727,34613.45397615433,3523.49757862091,1.033158540725708,0.0 -41600,2.3409362,1.0994306,,,,,,,,,,,,,, -41700,2.789221,1.161774,,,,,,,,,,,,,, -41800,4.3040023,1.154319,,,,,,,,,,,,,, -41900,2.4979026,1.145888,,,,,,,,,,,,,, -42000,2.4386501,1.1617819,,,,,,,,,,,,,, -42100,3.038559,1.1595327,,,,,,,,,,,,,, -42200,3.6233728,1.1268998,,,,,,,,,,,,,, -42300,2.730659,1.1111372,,,,,,,,,,,,,, -42400,2.7337775,1.1746235,,,,,,,,,,,,,, -42500,2.4799857,1.1293317,,,,,,,,,,,,,, -42600,2.6294975,1.1550364,,,,,,,,,,,,,, -42700,3.2937326,1.175199,,,,,,,,,,,,,, -42800,2.8510091,1.1113203,,,,,,,,,,,,,, -42900,2.4492059,1.1326616,,,,,,,,,,,,,, -43000,3.280956,1.1914113,,,,,,,,,,,,,, -43100,4.5945234,1.1597464,,,,,,,,,,,,,, -43200,3.5804844,1.1468177,,,,,,,,,,,,,, -43212,,,0.25475532,0.0880018250843809,0.41781008,0.1205383434546279,5348.0,0.22553718,0.0716998760993642,2472.0,36053.44173717499,39714.23803758621,36053.44173717499,3657.8997716903687,1.0785846710205078,0.0 -43300,4.0418897,1.1212367,,,,,,,,,,,,,, -43400,4.4303727,1.1406381,,,,,,,,,,,,,, -43500,2.8885772,1.0807683,,,,,,,,,,,,,, -43600,4.7693143,1.1520118,,,,,,,,,,,,,, -43700,2.2236814,1.1130275,,,,,,,,,,,,,, -43800,2.4516041,1.1914024,,,,,,,,,,,,,, -43900,2.7893894,1.1454728,,,,,,,,,,,,,, -44000,3.679458,1.1506158,,,,,,,,,,,,,, -44100,3.1541755,1.1621931,,,,,,,,,,,,,, -44200,3.021267,1.0779983,,,,,,,,,,,,,, -44300,3.19914,1.0775688,,,,,,,,,,,,,, -44400,3.83751,1.1119106,,,,,,,,,,,,,, -44500,5.803289,1.1098001,,,,,,,,,,,,,, -44600,2.0748992,1.085534,,,,,,,,,,,,,, -44700,2.911786,1.1868345,,,,,,,,,,,,,, -44800,3.3405569,1.1719272,,,,,,,,,,,,,, -44900,3.265331,1.0894059,,,,,,,,,,,,,, -44916,,,0.22422256,0.0749018245618166,0.41305763,0.1193025478629425,5348.0,0.22203259,0.0716592529401011,2472.0,37493.750953912735,41287.81129693985,37493.750953912735,3791.0279178619385,1.1358962059020996,0.0 -45000,4.375649,1.123853,,,,,,,,,,,,,, -45100,2.3531168,1.0771639,,,,,,,,,,,,,, -45200,2.9818013,1.157924,,,,,,,,,,,,,, -45300,4.105555,1.0357034,,,,,,,,,,,,,, -45400,2.9687471,1.074368,,,,,,,,,,,,,, -45500,3.5178754,1.1791197,,,,,,,,,,,,,, -45600,2.6152747,1.0214702,,,,,,,,,,,,,, -45700,4.5882077,1.1266068,,,,,,,,,,,,,, -45800,3.0747268,1.1215293,,,,,,,,,,,,,, -45900,2.8503056,1.0654923,,,,,,,,,,,,,, -46000,2.6202762,1.1673408,,,,,,,,,,,,,, -46100,4.0564623,1.1297948,,,,,,,,,,,,,, -46200,3.450468,1.0640203,,,,,,,,,,,,,, -46300,5.051031,1.0661397,,,,,,,,,,,,,, -46400,2.5002983,1.1130809,,,,,,,,,,,,,, -46500,2.8870797,1.0830568,,,,,,,,,,,,,, -46600,3.4195838,1.1165409,,,,,,,,,,,,,, -46632,,,0.20445894,0.0709529800808346,0.410703,0.1183563918630584,5348.0,0.22045316,0.0711108402900493,2472.0,38933.89561486244,42861.37547492981,38933.89561486244,3924.312516689301,1.1946141719818115,0.0 -46700,2.9899547,1.1138841,,,,,,,,,,,,,, -46800,4.357436,1.1216927,,,,,,,,,,,,,, -46900,2.695979,1.114617,,,,,,,,,,,,,, -47000,2.961605,1.1166263,,,,,,,,,,,,,, -47100,4.384102,1.0871626,,,,,,,,,,,,,, -47200,7.7208304,1.1253306,,,,,,,,,,,,,, -47300,2.9543705,1.1047004,,,,,,,,,,,,,, -47400,2.739116,1.1375762,,,,,,,,,,,,,, -47500,3.031345,1.1137097,,,,,,,,,,,,,, -47600,4.0940905,1.1447504,,,,,,,,,,,,,, -47700,2.3704207,1.1772394,,,,,,,,,,,,,, -47800,2.957508,1.0938786,,,,,,,,,,,,,, -47900,2.409721,1.1159348,,,,,,,,,,,,,, -48000,,,0.18561524,0.0648674808861434,0.4110458,0.118037788312077,5348.0,0.22013809,0.0707452318566815,2472.0,40087.15713596344,44151.48936414719,40087.15713596344,4061.046733379364,1.250619411468506,0.0 -48000,,,,,,,,,,,40087.15713596344,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/eval_measurements.csv deleted file mode 100644 index a43b566d9..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -218.9519054889679,0.0,15.68349552154541,1,0,15.68349552154541,30.939756,2472,4.561208945219669,234.63548374176025,31.934551,4.645983478794528,30.812979,5348,4.233507438910182 -331.06379795074463,0.0316545963287353,1456.0923523902893,1709,0,1456.0923523902893,6.318093,2472,0.895131314362318,1787.2602503299713,6.530197,0.936823569833362,6.3885655,5348,0.8922154532376879 -458.6938469409943,0.0821821689605712,2896.368205070496,3429,0,2896.368205070496,2.2226102,2472,0.5148985436597404,3355.2950081825256,2.8725617,0.6321504898078738,2.7419052,5348,0.5899089566216438 -590.7899439334869,0.1266040802001953,4336.491911649704,5129,0,4336.491911649704,0.5947173,2472,0.194686490768387,4927.635001182556,0.80511504,0.2522180814354727,0.916454,5348,0.2608590710292826 -725.4847946166992,0.1748373508453369,5777.701329469681,6825,0,5777.701329469681,0.49312124,2472,0.1629598033839091,6503.666594266892,0.6916581,0.2225141078030745,0.80308455,5348,0.2303117487473087 -858.9805746078491,0.2307698726654052,7218.413636207581,8557,0,7218.413636207581,0.42479584,2472,0.1396827331261552,8078.009069681168,0.58069324,0.1863339831625262,0.71140873,5348,0.2046207169545362 -995.5229451656342,0.2782824039459228,8659.064038991928,10255,0,8659.064038991928,0.40202525,2472,0.1314565433753783,9655.32750749588,0.5201113,0.1697827695204788,0.67553055,5348,0.1947729708332931 -1130.5369474887848,0.3322751522064209,10099.23685860634,11949,0,10099.23685860634,0.37967756,2472,0.1240834399691264,11230.648217201231,0.50205463,0.1670393866094655,0.64501023,5348,0.1888836324666673 -1267.259752511978,0.3846044540405273,11539.904070854189,13663,0,11539.904070854189,0.3549172,2472,0.1164462860276643,12808.168608665466,0.44674954,0.1488639914297192,0.61529595,5348,0.1775007965088774 -1401.9886286258698,0.4373552799224853,12979.79869055748,15356,0,12979.79869055748,0.3436971,2472,0.1125261511587756,14382.923528432846,0.44082975,0.1450491639005448,0.59986717,5348,0.1732624038155188 -1535.6642746925354,0.4859333038330078,14420.747940063477,17068,0,14420.747940063477,0.3291357,2472,0.1075701257286779,15957.678583145142,0.4190414,0.1400237025111245,0.58253497,5348,0.168261293530417 -1669.7652735710144,0.5466783046722412,15860.9145257473,18794,0,15860.9145257473,0.31767148,2472,0.1038937298153677,17532.088482141495,0.39970213,0.1371107888823705,0.5615143,5348,0.1631250181024744 -1804.435683012009,0.6045889854431152,17301.38058924675,20489,0,17301.38058924675,0.30769292,2472,0.1017407023744236,19107.35986685753,0.3484167,0.1170783556563183,0.5440495,5348,0.1592438475723374 -1939.11581158638,0.6574249267578125,18741.621086359024,22206,0,18741.621086359024,0.2998111,2472,0.0972518432758515,20682.414192676544,0.3176922,0.1081035156581876,0.52992153,5348,0.1538662058178939 -2074.142430305481,0.7140963077545166,20181.759190559387,23906,0,20181.759190559387,0.29122898,2472,0.0952003737330652,22257.7140955925,0.35585138,0.1228253384061325,0.5189788,5348,0.1510470471243616 -2209.3501620292664,0.7669525146484375,21621.865079164505,25609,0,21621.865079164505,0.28193888,2472,0.0931692157699104,23833.15885949135,0.325746,0.109971558589306,0.5139698,5348,0.1487299303899514 -2346.664496421814,0.8250465393066406,23062.529305696487,27335,0,23062.529305696487,0.2696854,2472,0.0883756829768651,25411.27657961845,0.30888176,0.1039969208584733,0.4966396,5348,0.1436902014926093 -2483.1363592147827,0.8873429298400879,24502.98046898842,29061,0,24502.98046898842,0.2685104,2472,0.0859789165803424,26988.34062337876,0.28284952,0.0974000883197173,0.49027404,5348,0.1409965532888575 -2617.004650115967,0.942192316055298,25943.658811092377,30754,0,25943.658811092377,0.2558537,2472,0.0830946722726626,28563.02157688141,0.26444355,0.0926417954318846,0.47725204,5348,0.1377429352076233 -2752.034639120102,0.9958689212799072,27383.986397266388,32487,0,27383.986397266388,0.2501571,2472,0.0817134848577173,30138.511751174927,0.25097135,0.087810355171256,0.46989283,5348,0.136758160595499 -2904.815445184708,1.0548536777496338,28824.25841617584,34220,0,28824.25841617584,0.24460703,2472,0.0788901752889322,31731.70393466949,0.15385501,0.0554912004761095,0.4573218,5348,0.1321046178205586 -3041.3079862594604,1.1017398834228516,30264.718192100525,35919,0,30264.718192100525,0.23849553,2472,0.076066865720147,33308.779076337814,0.14717703,0.050968087151221,0.44906735,5348,0.1288896183515645 -3178.5993349552155,1.1614861488342283,31705.60879182816,37639,0,31705.60879182816,0.23440033,2472,0.073954461438466,34887.10219669342,0.15311022,0.0543354739818705,0.4453374,5348,0.1284261950046825 -3314.5751991271973,1.2173850536346436,33146.07905125618,39347,0,33146.07905125618,0.23139994,2472,0.0734669835273089,36463.68469786644,0.13366854,0.0480371900826446,0.43505618,5348,0.1249601745561273 -3450.242693901062,1.2768633365631104,34586.41890668869,41036,0,34586.41890668869,0.22643007,2472,0.072877947717994,38039.82860040665,0.14206897,0.0508873973055536,0.43153396,5348,0.1245353698214854 -3589.543283224106,1.3324298858642578,36026.87464380264,42749,0,36026.87464380264,0.22455934,2472,0.0707452318566815,39619.72004675865,0.13182025,0.0478538685928601,0.42808613,5348,0.1236857603522017 -3726.88182258606,1.390561580657959,37467.42148900032,44453,0,37467.42148900032,0.22446273,2472,0.0706233623788922,41197.74384975433,0.15589492,0.0517909262597647,0.42629203,5348,0.1228071869237378 -3864.898614168167,1.456418752670288,38907.607313632965,46141,0,38907.607313632965,0.2233816,2472,0.0706030507992606,42776.08789777756,0.13782987,0.0495388040712468,0.4256302,5348,0.1224403101074562 -4002.885982275009,1.5108904838562012,40347.73718690872,47843,0,40347.73718690872,0.22332466,2472,0.070542116060366,44354.33743262291,0.118889384,0.0439292538030257,0.42518348,5348,0.1222665263523755 -4131.328007936478,1.5726885795593262,40467.03829741478,48000,0,40467.03829741478,0.22333445,2472,0.07048118132147137,44602.15502882004,0.12519462,0.04541185510896711,0.42519876,5348,0.12229549031155565 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/measurements.csv deleted file mode 100644 index 668b8eb96..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,18.08153,33.354748,,,,,,,,,,,,,, -1,,,31.934551,4.645983478794528,30.812979,4.233507438910182,5348.0,30.939756,4.561208945219669,2472.0,15.68349552154541,234.63548374176025,15.68349552154541,218.9519054889679,0.0,0.0 -100,2.5242908,7.075327,,,,,,,,,,,,,, -200,1.5059904,6.1131506,,,,,,,,,,,,,, -300,0.53116715,5.8837385,,,,,,,,,,,,,, -400,1.0505542,5.832569,,,,,,,,,,,,,, -500,1.2825003,5.7361107,,,,,,,,,,,,,, -600,1.6838932,5.571105,,,,,,,,,,,,,, -700,0.7035455,5.3010626,,,,,,,,,,,,,, -800,1.2238985,4.820079,,,,,,,,,,,,,, -900,1.4909698,4.4018955,,,,,,,,,,,,,, -1000,1.4764112,4.0144525,,,,,,,,,,,,,, -1100,3.5663123,3.7903988,,,,,,,,,,,,,, -1200,4.740507,3.6618896,,,,,,,,,,,,,, -1300,1.932847,3.3293161,,,,,,,,,,,,,, -1400,2.1489785,3.2291622,,,,,,,,,,,,,, -1500,2.239551,3.1137524,,,,,,,,,,,,,, -1600,3.2106335,3.036523,,,,,,,,,,,,,, -1700,2.050794,2.8983912,,,,,,,,,,,,,, -1709,,,6.530197,0.936823569833362,6.3885655,0.8922154532376879,5348.0,6.318093,0.895131314362318,2472.0,1456.0923523902893,1787.2602503299713,1456.0923523902893,331.06379795074463,0.0316545963287353,0.0 -1800,3.2786582,2.8885055,,,,,,,,,,,,,, -1900,2.9053266,2.725349,,,,,,,,,,,,,, -2000,2.709067,2.6559107,,,,,,,,,,,,,, -2100,2.4501677,2.6327593,,,,,,,,,,,,,, -2200,2.587887,2.6130495,,,,,,,,,,,,,, -2300,3.7161582,2.5117826,,,,,,,,,,,,,, -2400,2.5060987,2.463293,,,,,,,,,,,,,, -2500,3.23494,2.4713504,,,,,,,,,,,,,, -2600,3.2569635,2.3397017,,,,,,,,,,,,,, -2700,2.3527575,2.3475087,,,,,,,,,,,,,, -2800,3.3872855,2.3190105,,,,,,,,,,,,,, -2900,1.9022524,2.2650099,,,,,,,,,,,,,, -3000,3.078404,2.1975603,,,,,,,,,,,,,, -3100,3.6631665,2.1956234,,,,,,,,,,,,,, -3200,2.8721805,2.2130196,,,,,,,,,,,,,, -3300,2.719113,2.235422,,,,,,,,,,,,,, -3400,5.1708465,2.1683013,,,,,,,,,,,,,, -3429,,,2.8725617,0.6321504898078738,2.7419052,0.5899089566216438,5348.0,2.2226102,0.5148985436597404,2472.0,2896.368205070496,3355.2950081825256,2896.368205070496,458.6938469409943,0.0821821689605712,0.0 -3500,2.5189924,2.0398834,,,,,,,,,,,,,, -3600,2.7711563,2.1360784,,,,,,,,,,,,,, -3700,3.3290455,2.0785334,,,,,,,,,,,,,, -3800,3.627456,2.027702,,,,,,,,,,,,,, -3900,2.1644192,1.9024202,,,,,,,,,,,,,, -4000,2.404574,2.0200465,,,,,,,,,,,,,, -4100,1.8839995,2.0115132,,,,,,,,,,,,,, -4200,2.2844748,1.9597594,,,,,,,,,,,,,, -4300,4.921151,1.9122901,,,,,,,,,,,,,, -4400,3.001488,1.9007326,,,,,,,,,,,,,, -4500,2.2946076,1.9312187,,,,,,,,,,,,,, -4600,2.3780499,1.9582615,,,,,,,,,,,,,, -4700,2.1844645,1.880734,,,,,,,,,,,,,, -4800,2.980096,1.8762829,,,,,,,,,,,,,, -4900,2.2082188,1.8535277,,,,,,,,,,,,,, -5000,3.187576,1.8046123,,,,,,,,,,,,,, -5100,6.528563,1.8962034,,,,,,,,,,,,,, -5129,,,0.80511504,0.2522180814354727,0.916454,0.2608590710292826,5348.0,0.5947173,0.194686490768387,2472.0,4336.491911649704,4927.635001182556,4336.491911649704,590.7899439334869,0.1266040802001953,0.0 -5200,2.779154,1.818195,,,,,,,,,,,,,, -5300,3.0150044,1.7794671,,,,,,,,,,,,,, -5400,2.5082357,1.8373363,,,,,,,,,,,,,, -5500,2.3344293,1.8041586,,,,,,,,,,,,,, -5600,4.028638,1.7869513,,,,,,,,,,,,,, -5700,2.7023406,1.6940303,,,,,,,,,,,,,, -5800,3.040507,1.8194152,,,,,,,,,,,,,, -5900,3.0059752,1.7522418,,,,,,,,,,,,,, -6000,4.1908803,1.7349391,,,,,,,,,,,,,, -6100,2.1878593,1.748795,,,,,,,,,,,,,, -6200,2.3186939,1.7042369,,,,,,,,,,,,,, -6300,5.123796,1.7849734,,,,,,,,,,,,,, -6400,3.572588,1.7606924,,,,,,,,,,,,,, -6500,2.0982802,1.7190042,,,,,,,,,,,,,, -6600,2.4303002,1.7517588,,,,,,,,,,,,,, -6700,2.1273592,1.7376802,,,,,,,,,,,,,, -6800,3.194392,1.6875881,,,,,,,,,,,,,, -6825,,,0.6916581,0.2225141078030745,0.80308455,0.2303117487473087,5348.0,0.49312124,0.1629598033839091,2472.0,5777.701329469681,6503.666594266892,5777.701329469681,725.4847946166992,0.1748373508453369,0.0 -6900,1.6073542,1.728542,,,,,,,,,,,,,, -7000,3.353015,1.7181034,,,,,,,,,,,,,, -7100,2.76256,1.7454374,,,,,,,,,,,,,, -7200,2.4710882,1.6615039,,,,,,,,,,,,,, -7300,3.5097914,1.6643325,,,,,,,,,,,,,, -7400,2.0797498,1.670421,,,,,,,,,,,,,, -7500,3.1938672,1.6363436,,,,,,,,,,,,,, -7600,1.6859587,1.6483262,,,,,,,,,,,,,, -7700,2.8201904,1.6745968,,,,,,,,,,,,,, -7800,2.888728,1.6544343,,,,,,,,,,,,,, -7900,2.614851,1.7143726,,,,,,,,,,,,,, -8000,3.655138,1.6477116,,,,,,,,,,,,,, -8100,4.342437,1.6992766,,,,,,,,,,,,,, -8200,2.158033,1.7226958,,,,,,,,,,,,,, -8300,2.0826797,1.6163836,,,,,,,,,,,,,, -8400,1.8231446,1.6009246,,,,,,,,,,,,,, -8500,3.2906005,1.6284132,,,,,,,,,,,,,, -8557,,,0.58069324,0.1863339831625262,0.71140873,0.2046207169545362,5348.0,0.42479584,0.1396827331261552,2472.0,7218.413636207581,8078.009069681168,7218.413636207581,858.9805746078491,0.2307698726654052,0.0 -8600,2.1210067,1.5749758,,,,,,,,,,,,,, -8700,2.2809622,1.5712572,,,,,,,,,,,,,, -8800,1.6022015,1.5491949,,,,,,,,,,,,,, -8900,2.0467865,1.6191871,,,,,,,,,,,,,, -9000,2.1643739,1.6068393,,,,,,,,,,,,,, -9100,2.2167115,1.6157978,,,,,,,,,,,,,, -9200,2.4783497,1.5892469,,,,,,,,,,,,,, -9300,5.0344543,1.6081376,,,,,,,,,,,,,, -9400,1.7017583,1.5861777,,,,,,,,,,,,,, -9500,2.402936,1.5658126,,,,,,,,,,,,,, -9600,3.4348814,1.6578548,,,,,,,,,,,,,, -9700,3.2678144,1.6119858,,,,,,,,,,,,,, -9800,3.582847,1.5505428,,,,,,,,,,,,,, -9900,2.7731657,1.549734,,,,,,,,,,,,,, -10000,2.5671103,1.579551,,,,,,,,,,,,,, -10100,3.805639,1.5807402,,,,,,,,,,,,,, -10200,1.7125989,1.5667157,,,,,,,,,,,,,, -10255,,,0.5201113,0.1697827695204788,0.67553055,0.1947729708332931,5348.0,0.40202525,0.1314565433753783,2472.0,8659.064038991928,9655.32750749588,8659.064038991928,995.5229451656342,0.2782824039459228,0.0 -10300,2.891434,1.5811723,,,,,,,,,,,,,, -10400,2.903111,1.6473216,,,,,,,,,,,,,, -10500,2.134438,1.546379,,,,,,,,,,,,,, -10600,2.1965253,1.570367,,,,,,,,,,,,,, -10700,2.6330297,1.5629393,,,,,,,,,,,,,, -10800,2.9567542,1.5937505,,,,,,,,,,,,,, -10900,2.3300202,1.6034359,,,,,,,,,,,,,, -11000,2.2127008,1.4870808,,,,,,,,,,,,,, -11100,2.9715881,1.6443276,,,,,,,,,,,,,, -11200,3.0492074,1.5639311,,,,,,,,,,,,,, -11300,2.2731156,1.5452373,,,,,,,,,,,,,, -11400,3.7771082,1.4372541,,,,,,,,,,,,,, -11500,3.2893875,1.5681908,,,,,,,,,,,,,, -11600,6.7027493,1.496066,,,,,,,,,,,,,, -11700,2.8948426,1.5059317,,,,,,,,,,,,,, -11800,2.8404071,1.5790612,,,,,,,,,,,,,, -11900,2.8465607,1.5585648,,,,,,,,,,,,,, -11949,,,0.50205463,0.1670393866094655,0.64501023,0.1888836324666673,5348.0,0.37967756,0.1240834399691264,2472.0,10099.23685860634,11230.648217201231,10099.23685860634,1130.5369474887848,0.3322751522064209,0.0 -12000,3.4587224,1.5206506,,,,,,,,,,,,,, -12100,2.1536286,1.4717556,,,,,,,,,,,,,, -12200,2.9042368,1.5087061,,,,,,,,,,,,,, -12300,2.4993517,1.476894,,,,,,,,,,,,,, -12400,2.1811538,1.4529008,,,,,,,,,,,,,, -12500,2.6446304,1.5469897,,,,,,,,,,,,,, -12600,2.4897876,1.5546745,,,,,,,,,,,,,, -12700,2.2738595,1.454433,,,,,,,,,,,,,, -12800,5.0809884,1.5037392,,,,,,,,,,,,,, -12900,4.5593657,1.5267352,,,,,,,,,,,,,, -13000,2.40412,1.5474722,,,,,,,,,,,,,, -13100,1.9774964,1.5072576,,,,,,,,,,,,,, -13200,1.4992143,1.458069,,,,,,,,,,,,,, -13300,2.3170898,1.499816,,,,,,,,,,,,,, -13400,2.7351649,1.4686415,,,,,,,,,,,,,, -13500,4.173064,1.4932091,,,,,,,,,,,,,, -13600,2.7244067,1.5094182,,,,,,,,,,,,,, -13663,,,0.44674954,0.1488639914297192,0.61529595,0.1775007965088774,5348.0,0.3549172,0.1164462860276643,2472.0,11539.904070854189,12808.168608665466,11539.904070854189,1267.259752511978,0.3846044540405273,0.0 -13700,2.8963223,1.4793019,,,,,,,,,,,,,, -13800,3.9810095,1.5018754,,,,,,,,,,,,,, -13900,3.017486,1.460125,,,,,,,,,,,,,, -14000,3.160174,1.5033752,,,,,,,,,,,,,, -14100,2.9815497,1.4808488,,,,,,,,,,,,,, -14200,1.9157389,1.4558971,,,,,,,,,,,,,, -14300,2.182605,1.4389505,,,,,,,,,,,,,, -14400,3.0043726,1.5137628,,,,,,,,,,,,,, -14500,2.78966,1.4679451,,,,,,,,,,,,,, -14600,2.4482172,1.3884012,,,,,,,,,,,,,, -14700,1.9979047,1.47157,,,,,,,,,,,,,, -14800,6.2419086,1.4547352,,,,,,,,,,,,,, -14900,2.815758,1.520055,,,,,,,,,,,,,, -15000,1.6775815,1.4247937,,,,,,,,,,,,,, -15100,2.2200754,1.4762598,,,,,,,,,,,,,, -15200,1.7347342,1.4098818,,,,,,,,,,,,,, -15300,2.702388,1.5199057,,,,,,,,,,,,,, -15356,,,0.44082975,0.1450491639005448,0.59986717,0.1732624038155188,5348.0,0.3436971,0.1125261511587756,2472.0,12979.79869055748,14382.923528432846,12979.79869055748,1401.9886286258698,0.4373552799224853,0.0 -15400,2.2442238,1.5033555,,,,,,,,,,,,,, -15500,3.93083,1.4289166,,,,,,,,,,,,,, -15600,3.5820315,1.4472661,,,,,,,,,,,,,, -15700,4.6067085,1.5087247,,,,,,,,,,,,,, -15800,3.625422,1.4681929,,,,,,,,,,,,,, -15900,2.129671,1.4448711,,,,,,,,,,,,,, -16000,2.5836005,1.450089,,,,,,,,,,,,,, -16100,2.733354,1.4392906,,,,,,,,,,,,,, -16200,2.646993,1.446846,,,,,,,,,,,,,, -16300,2.4669454,1.4421641,,,,,,,,,,,,,, -16400,1.6075392,1.373277,,,,,,,,,,,,,, -16500,2.019363,1.4426268,,,,,,,,,,,,,, -16600,4.533573,1.4317462,,,,,,,,,,,,,, -16700,3.4300148,1.3661864,,,,,,,,,,,,,, -16800,3.149193,1.4786726,,,,,,,,,,,,,, -16900,2.6334965,1.4094017,,,,,,,,,,,,,, -17000,2.4810007,1.4089153,,,,,,,,,,,,,, -17068,,,0.4190414,0.1400237025111245,0.58253497,0.168261293530417,5348.0,0.3291357,0.1075701257286779,2472.0,14420.747940063477,15957.678583145142,14420.747940063477,1535.6642746925354,0.4859333038330078,0.0 -17100,2.491336,1.3966607,,,,,,,,,,,,,, -17200,2.0338442,1.4733211,,,,,,,,,,,,,, -17300,2.7169483,1.4728982,,,,,,,,,,,,,, -17400,2.150553,1.4524049,,,,,,,,,,,,,, -17500,2.3485775,1.3803134,,,,,,,,,,,,,, -17600,1.9458357,1.4395351,,,,,,,,,,,,,, -17700,1.8455502,1.42207,,,,,,,,,,,,,, -17800,3.5419033,1.4561565,,,,,,,,,,,,,, -17900,2.0137112,1.3355535,,,,,,,,,,,,,, -18000,2.273812,1.4482847,,,,,,,,,,,,,, -18100,2.3267236,1.3995506,,,,,,,,,,,,,, -18200,1.7075398,1.4062017,,,,,,,,,,,,,, -18300,2.4214694,1.458339,,,,,,,,,,,,,, -18400,3.3153443,1.409429,,,,,,,,,,,,,, -18500,2.2581391,1.362641,,,,,,,,,,,,,, -18600,3.050014,1.445365,,,,,,,,,,,,,, -18700,1.7131289,1.3737408,,,,,,,,,,,,,, -18794,,,0.39970213,0.1371107888823705,0.5615143,0.1631250181024744,5348.0,0.31767148,0.1038937298153677,2472.0,15860.9145257473,17532.088482141495,15860.9145257473,1669.7652735710144,0.5466783046722412,0.0 -18800,4.139268,1.3588384,,,,,,,,,,,,,, -18900,2.1125968,1.3962731,,,,,,,,,,,,,, -19000,1.8908715,1.403116,,,,,,,,,,,,,, -19100,2.4297245,1.4137818,,,,,,,,,,,,,, -19200,2.922274,1.3559872,,,,,,,,,,,,,, -19300,2.412937,1.3534176,,,,,,,,,,,,,, -19400,1.8725638,1.3878196,,,,,,,,,,,,,, -19500,2.252041,1.3730881,,,,,,,,,,,,,, -19600,2.1467955,1.4520246,,,,,,,,,,,,,, -19700,2.5320783,1.4155526,,,,,,,,,,,,,, -19800,2.2302172,1.4045814,,,,,,,,,,,,,, -19900,2.3897114,1.3606676,,,,,,,,,,,,,, -20000,2.0611832,1.3766301,,,,,,,,,,,,,, -20100,1.9194257,1.3910978,,,,,,,,,,,,,, -20200,1.7287472,1.4046999,,,,,,,,,,,,,, -20300,3.7412448,1.4354411,,,,,,,,,,,,,, -20400,2.5690348,1.4366208,,,,,,,,,,,,,, -20489,,,0.3484167,0.1170783556563183,0.5440495,0.1592438475723374,5348.0,0.30769292,0.1017407023744236,2472.0,17301.38058924675,19107.35986685753,17301.38058924675,1804.435683012009,0.6045889854431152,0.0 -20500,2.6457796,1.4440515,,,,,,,,,,,,,, -20600,2.4172468,1.3524994,,,,,,,,,,,,,, -20700,1.7619729,1.3372325,,,,,,,,,,,,,, -20800,2.0460162,1.3445019,,,,,,,,,,,,,, -20900,2.1274254,1.4081979,,,,,,,,,,,,,, -21000,2.059512,1.3932194,,,,,,,,,,,,,, -21100,2.2747908,1.369183,,,,,,,,,,,,,, -21200,1.549825,1.4215693,,,,,,,,,,,,,, -21300,2.3414564,1.3766494,,,,,,,,,,,,,, -21400,2.5353389,1.4026191,,,,,,,,,,,,,, -21500,2.8072581,1.3607113,,,,,,,,,,,,,, -21600,2.6782515,1.3431892,,,,,,,,,,,,,, -21700,1.8679906,1.4050095,,,,,,,,,,,,,, -21800,2.8129308,1.3292344,,,,,,,,,,,,,, -21900,2.1907372,1.3702948,,,,,,,,,,,,,, -22000,1.6814437,1.363986,,,,,,,,,,,,,, -22100,2.2445776,1.3792446,,,,,,,,,,,,,, -22200,2.3369572,1.3675585,,,,,,,,,,,,,, -22206,,,0.3176922,0.1081035156581876,0.52992153,0.1538662058178939,5348.0,0.2998111,0.0972518432758515,2472.0,18741.621086359024,20682.414192676544,18741.621086359024,1939.11581158638,0.6574249267578125,0.0 -22300,2.1557066,1.3444148,,,,,,,,,,,,,, -22400,2.7402782,1.3982766,,,,,,,,,,,,,, -22500,2.0973763,1.3727031,,,,,,,,,,,,,, -22600,4.339597,1.3683542,,,,,,,,,,,,,, -22700,2.1540189,1.3613247,,,,,,,,,,,,,, -22800,1.8784821,1.3946228,,,,,,,,,,,,,, -22900,2.306067,1.36855,,,,,,,,,,,,,, -23000,1.6974666,1.3147094,,,,,,,,,,,,,, -23100,1.9695355,1.2763137,,,,,,,,,,,,,, -23200,2.5905344,1.3795443,,,,,,,,,,,,,, -23300,2.2515528,1.3361366,,,,,,,,,,,,,, -23400,2.4455264,1.3613886,,,,,,,,,,,,,, -23500,2.4324207,1.3075235,,,,,,,,,,,,,, -23600,2.0820546,1.4072345,,,,,,,,,,,,,, -23700,2.007934,1.3022792,,,,,,,,,,,,,, -23800,2.4357035,1.2834793,,,,,,,,,,,,,, -23900,2.9739196,1.4371141,,,,,,,,,,,,,, -23906,,,0.35585138,0.1228253384061325,0.5189788,0.1510470471243616,5348.0,0.29122898,0.0952003737330652,2472.0,20181.759190559387,22257.7140955925,20181.759190559387,2074.142430305481,0.7140963077545166,0.0 -24000,2.0843425,1.3848248,,,,,,,,,,,,,, -24100,3.2224472,1.3050609,,,,,,,,,,,,,, -24200,2.0591185,1.3296684,,,,,,,,,,,,,, -24300,2.0538602,1.3579923,,,,,,,,,,,,,, -24400,2.089528,1.3513255,,,,,,,,,,,,,, -24500,2.6257458,1.3297296,,,,,,,,,,,,,, -24600,3.2087405,1.3199738,,,,,,,,,,,,,, -24700,2.31223,1.3419505,,,,,,,,,,,,,, -24800,2.0311115,1.3381474,,,,,,,,,,,,,, -24900,2.5601368,1.2804236,,,,,,,,,,,,,, -25000,2.569978,1.2938116,,,,,,,,,,,,,, -25100,2.6368306,1.3018134,,,,,,,,,,,,,, -25200,2.557167,1.3282617,,,,,,,,,,,,,, -25300,2.0813234,1.3140335,,,,,,,,,,,,,, -25400,2.2869534,1.3425078,,,,,,,,,,,,,, -25500,2.828691,1.3267688,,,,,,,,,,,,,, -25600,2.0271175,1.3320583,,,,,,,,,,,,,, -25609,,,0.325746,0.109971558589306,0.5139698,0.1487299303899514,5348.0,0.28193888,0.0931692157699104,2472.0,21621.865079164505,23833.15885949135,21621.865079164505,2209.3501620292664,0.7669525146484375,0.0 -25700,2.4264987,1.3390565,,,,,,,,,,,,,, -25800,2.0520313,1.244492,,,,,,,,,,,,,, -25900,2.4925053,1.2897302,,,,,,,,,,,,,, -26000,2.7605984,1.2689974,,,,,,,,,,,,,, -26100,1.769525,1.2696191,,,,,,,,,,,,,, -26200,3.091647,1.3077376,,,,,,,,,,,,,, -26300,1.5233084,1.2977843,,,,,,,,,,,,,, -26400,1.8477592,1.3186221,,,,,,,,,,,,,, -26500,2.7355418,1.3020401,,,,,,,,,,,,,, -26600,2.0397272,1.3068868,,,,,,,,,,,,,, -26700,1.8825712,1.3061662,,,,,,,,,,,,,, -26800,2.450686,1.267398,,,,,,,,,,,,,, -26900,2.3457596,1.2794499,,,,,,,,,,,,,, -27000,2.8293006,1.3131987,,,,,,,,,,,,,, -27100,2.6564155,1.2787964,,,,,,,,,,,,,, -27200,2.4217272,1.2728798,,,,,,,,,,,,,, -27300,2.587658,1.2905694,,,,,,,,,,,,,, -27335,,,0.30888176,0.1039969208584733,0.4966396,0.1436902014926093,5348.0,0.2696854,0.0883756829768651,2472.0,23062.529305696487,25411.27657961845,23062.529305696487,2346.664496421814,0.8250465393066406,0.0 -27400,1.8940805,1.2781454,,,,,,,,,,,,,, -27500,2.020688,1.3196425,,,,,,,,,,,,,, -27600,1.8374411,1.3204554,,,,,,,,,,,,,, -27700,3.7174802,1.3461235,,,,,,,,,,,,,, -27800,1.9906139,1.3025637,,,,,,,,,,,,,, -27900,2.26955,1.3121998,,,,,,,,,,,,,, -28000,1.6297374,1.2407959,,,,,,,,,,,,,, -28100,2.3986146,1.2558556,,,,,,,,,,,,,, -28200,2.3210042,1.2692529,,,,,,,,,,,,,, -28300,2.584466,1.2794952,,,,,,,,,,,,,, -28400,2.2067757,1.2508159,,,,,,,,,,,,,, -28500,3.232012,1.3012998,,,,,,,,,,,,,, -28600,3.4021604,1.3325565,,,,,,,,,,,,,, -28700,2.4882066,1.2944378,,,,,,,,,,,,,, -28800,4.025385,1.2436557,,,,,,,,,,,,,, -28900,1.8550732,1.2312591,,,,,,,,,,,,,, -29000,2.6426435,1.2579324,,,,,,,,,,,,,, -29061,,,0.28284952,0.0974000883197173,0.49027404,0.1409965532888575,5348.0,0.2685104,0.0859789165803424,2472.0,24502.98046898842,26988.34062337876,24502.98046898842,2483.1363592147827,0.8873429298400879,0.0 -29100,2.1720867,1.2205673,,,,,,,,,,,,,, -29200,2.5437865,1.2692851,,,,,,,,,,,,,, -29300,2.9211328,1.2766445,,,,,,,,,,,,,, -29400,3.1783397,1.2490779,,,,,,,,,,,,,, -29500,2.2530406,1.2821419,,,,,,,,,,,,,, -29600,2.257602,1.3135103,,,,,,,,,,,,,, -29700,5.491691,1.2860608,,,,,,,,,,,,,, -29800,4.907693,1.2436261,,,,,,,,,,,,,, -29900,4.0830355,1.2303348,,,,,,,,,,,,,, -30000,2.0838552,1.246022,,,,,,,,,,,,,, -30100,3.8392541,1.2587366,,,,,,,,,,,,,, -30200,8.197991,1.304534,,,,,,,,,,,,,, -30300,1.8541728,1.2385861,,,,,,,,,,,,,, -30400,2.2736273,1.2839273,,,,,,,,,,,,,, -30500,1.8504791,1.2225862,,,,,,,,,,,,,, -30600,2.0310996,1.3117212,,,,,,,,,,,,,, -30700,2.973184,1.2327952,,,,,,,,,,,,,, -30754,,,0.26444355,0.0926417954318846,0.47725204,0.1377429352076233,5348.0,0.2558537,0.0830946722726626,2472.0,25943.658811092377,28563.02157688141,25943.658811092377,2617.004650115967,0.942192316055298,0.0 -30800,1.6342196,1.2381333,,,,,,,,,,,,,, -30900,2.101632,1.207793,,,,,,,,,,,,,, -31000,1.642497,1.1961313,,,,,,,,,,,,,, -31100,2.8188143,1.2388102,,,,,,,,,,,,,, -31200,1.7794151,1.2563674,,,,,,,,,,,,,, -31300,3.084168,1.1869141,,,,,,,,,,,,,, -31400,3.1322196,1.2197834,,,,,,,,,,,,,, -31500,1.7085564,1.2736568,,,,,,,,,,,,,, -31600,2.4628813,1.2352839,,,,,,,,,,,,,, -31700,1.5759281,1.2859974,,,,,,,,,,,,,, -31800,1.7009287,1.2261975,,,,,,,,,,,,,, -31900,2.5405703,1.239449,,,,,,,,,,,,,, -32000,3.135121,1.2227902,,,,,,,,,,,,,, -32100,4.003054,1.2665395,,,,,,,,,,,,,, -32200,2.0484903,1.2095071,,,,,,,,,,,,,, -32300,2.2819757,1.2295873,,,,,,,,,,,,,, -32400,3.0879235,1.2109957,,,,,,,,,,,,,, -32487,,,0.25097135,0.087810355171256,0.46989283,0.136758160595499,5348.0,0.2501571,0.0817134848577173,2472.0,27383.986397266388,30138.511751174927,27383.986397266388,2752.034639120102,0.9958689212799072,0.0 -32500,2.323787,1.1995565,,,,,,,,,,,,,, -32600,2.3676429,1.2299385,,,,,,,,,,,,,, -32700,1.8027627,1.3023163,,,,,,,,,,,,,, -32800,3.2502205,1.257023,,,,,,,,,,,,,, -32900,2.7572246,1.2232863,,,,,,,,,,,,,, -33000,1.7882113,1.2676938,,,,,,,,,,,,,, -33100,2.3298535,1.226911,,,,,,,,,,,,,, -33200,1.4673072,1.1863394,,,,,,,,,,,,,, -33300,2.2109492,1.271083,,,,,,,,,,,,,, -33400,1.5670849,1.1422615,,,,,,,,,,,,,, -33500,2.2162256,1.2489538,,,,,,,,,,,,,, -33600,2.016902,1.2543186,,,,,,,,,,,,,, -33700,2.4209232,1.2263653,,,,,,,,,,,,,, -33800,3.9201744,1.1974137,,,,,,,,,,,,,, -33900,3.256638,1.2606895,,,,,,,,,,,,,, -34000,2.0775352,1.1877466,,,,,,,,,,,,,, -34100,2.7621212,1.1994838,,,,,,,,,,,,,, -34200,2.703211,1.1557881,,,,,,,,,,,,,, -34220,,,0.15385501,0.0554912004761095,0.4573218,0.1321046178205586,5348.0,0.24460703,0.0788901752889322,2472.0,28824.25841617584,31731.70393466949,28824.25841617584,2904.815445184708,1.0548536777496338,0.0 -34300,1.7590266,1.2375897,,,,,,,,,,,,,, -34400,1.8646528,1.176668,,,,,,,,,,,,,, -34500,2.976028,1.1984104,,,,,,,,,,,,,, -34600,1.5040016,1.2119393,,,,,,,,,,,,,, -34700,3.5649912,1.1947615,,,,,,,,,,,,,, -34800,2.102776,1.2325776,,,,,,,,,,,,,, -34900,2.3066263,1.1652688,,,,,,,,,,,,,, -35000,1.5984484,1.1842412,,,,,,,,,,,,,, -35100,3.358767,1.2299588,,,,,,,,,,,,,, -35200,1.9829904,1.2102575,,,,,,,,,,,,,, -35300,2.8929653,1.1682386,,,,,,,,,,,,,, -35400,2.2916043,1.1629593,,,,,,,,,,,,,, -35500,2.4904304,1.2324741,,,,,,,,,,,,,, -35600,3.0311725,1.2202716,,,,,,,,,,,,,, -35700,2.6302915,1.2034959,,,,,,,,,,,,,, -35800,4.2741337,1.2141018,,,,,,,,,,,,,, -35900,2.7137573,1.2085093,,,,,,,,,,,,,, -35919,,,0.14717703,0.050968087151221,0.44906735,0.1288896183515645,5348.0,0.23849553,0.076066865720147,2472.0,30264.718192100525,33308.779076337814,30264.718192100525,3041.3079862594604,1.1017398834228516,0.0 -36000,1.8807784,1.2065697,,,,,,,,,,,,,, -36100,2.2165904,1.170897,,,,,,,,,,,,,, -36200,2.2594593,1.1856064,,,,,,,,,,,,,, -36300,2.3846557,1.2070998,,,,,,,,,,,,,, -36400,2.7299817,1.160761,,,,,,,,,,,,,, -36500,1.7228136,1.1292278,,,,,,,,,,,,,, -36600,3.3684785,1.2002393,,,,,,,,,,,,,, -36700,2.4088352,1.177164,,,,,,,,,,,,,, -36800,1.7328006,1.1881756,,,,,,,,,,,,,, -36900,1.8729814,1.140432,,,,,,,,,,,,,, -37000,4.4884896,1.1731812,,,,,,,,,,,,,, -37100,2.418974,1.1608582,,,,,,,,,,,,,, -37200,2.3012269,1.201576,,,,,,,,,,,,,, -37300,1.574097,1.1762942,,,,,,,,,,,,,, -37400,2.1498547,1.152108,,,,,,,,,,,,,, -37500,1.8593528,1.1318963,,,,,,,,,,,,,, -37600,3.7029333,1.15457,,,,,,,,,,,,,, -37639,,,0.15311022,0.0543354739818705,0.4453374,0.1284261950046825,5348.0,0.23440033,0.073954461438466,2472.0,31705.60879182816,34887.10219669342,31705.60879182816,3178.5993349552155,1.1614861488342283,0.0 -37700,1.5183586,1.1570411,,,,,,,,,,,,,, -37800,2.5733654,1.1984288,,,,,,,,,,,,,, -37900,3.4551582,1.150074,,,,,,,,,,,,,, -38000,2.6549764,1.1184983,,,,,,,,,,,,,, -38100,2.248352,1.1768819,,,,,,,,,,,,,, -38200,2.3686733,1.2108866,,,,,,,,,,,,,, -38300,2.9377544,1.1020558,,,,,,,,,,,,,, -38400,3.8290727,1.0977209,,,,,,,,,,,,,, -38500,1.5168332,1.1501082,,,,,,,,,,,,,, -38600,2.2800264,1.1079074,,,,,,,,,,,,,, -38700,2.5824714,1.1451664,,,,,,,,,,,,,, -38800,1.9482448,1.0992925,,,,,,,,,,,,,, -38900,2.2853894,1.1710827,,,,,,,,,,,,,, -39000,2.1107836,1.1330494,,,,,,,,,,,,,, -39100,2.2036011,1.183316,,,,,,,,,,,,,, -39200,3.153125,1.1785539,,,,,,,,,,,,,, -39300,3.1597142,1.1566616,,,,,,,,,,,,,, -39347,,,0.13366854,0.0480371900826446,0.43505618,0.1249601745561273,5348.0,0.23139994,0.0734669835273089,2472.0,33146.07905125618,36463.68469786644,33146.07905125618,3314.5751991271973,1.2173850536346436,0.0 -39400,1.7585816,1.1838138,,,,,,,,,,,,,, -39500,2.7436843,1.1790011,,,,,,,,,,,,,, -39600,2.3605673,1.1425233,,,,,,,,,,,,,, -39700,3.6719494,1.17889,,,,,,,,,,,,,, -39800,2.0562286,1.158971,,,,,,,,,,,,,, -39900,2.36032,1.201237,,,,,,,,,,,,,, -40000,3.0095794,1.1234956,,,,,,,,,,,,,, -40100,3.3756518,1.2297959,,,,,,,,,,,,,, -40200,2.4385884,1.092323,,,,,,,,,,,,,, -40300,2.0852122,1.1276691,,,,,,,,,,,,,, -40400,2.55819,1.1070952,,,,,,,,,,,,,, -40500,4.820621,1.141167,,,,,,,,,,,,,, -40600,1.9977063,1.1533943,,,,,,,,,,,,,, -40700,4.159878,1.1673759,,,,,,,,,,,,,, -40800,1.470382,1.1269296,,,,,,,,,,,,,, -40900,2.4345007,1.1578269,,,,,,,,,,,,,, -41000,7.6172185,1.1688288,,,,,,,,,,,,,, -41036,,,0.14206897,0.0508873973055536,0.43153396,0.1245353698214854,5348.0,0.22643007,0.072877947717994,2472.0,34586.41890668869,38039.82860040665,34586.41890668869,3450.242693901062,1.2768633365631104,0.0 -41100,1.7579973,1.0828335,,,,,,,,,,,,,, -41200,2.9618766,1.1316769,,,,,,,,,,,,,, -41300,3.41267,1.0813513,,,,,,,,,,,,,, -41400,2.9627929,1.1396679,,,,,,,,,,,,,, -41500,3.5638137,1.1577878,,,,,,,,,,,,,, -41600,3.2048383,1.1417086,,,,,,,,,,,,,, -41700,2.4894037,1.1709192,,,,,,,,,,,,,, -41800,3.151172,1.1208361,,,,,,,,,,,,,, -41900,3.761238,1.1065028,,,,,,,,,,,,,, -42000,2.2233942,1.1628764,,,,,,,,,,,,,, -42100,2.34946,1.1219319,,,,,,,,,,,,,, -42200,3.0220385,1.1008887,,,,,,,,,,,,,, -42300,1.8655303,1.0657487,,,,,,,,,,,,,, -42400,2.9606516,1.1471251,,,,,,,,,,,,,, -42500,2.1729743,1.1572514,,,,,,,,,,,,,, -42600,2.3790028,1.0908483,,,,,,,,,,,,,, -42700,1.6197025,1.1201056,,,,,,,,,,,,,, -42749,,,0.13182025,0.0478538685928601,0.42808613,0.1236857603522017,5348.0,0.22455934,0.0707452318566815,2472.0,36026.87464380264,39619.72004675865,36026.87464380264,3589.543283224106,1.3324298858642578,0.0 -42800,1.676474,1.0834378,,,,,,,,,,,,,, -42900,4.4405565,1.091524,,,,,,,,,,,,,, -43000,2.393553,1.1340778,,,,,,,,,,,,,, -43100,2.668796,1.1009067,,,,,,,,,,,,,, -43200,2.7580163,1.1053083,,,,,,,,,,,,,, -43300,7.9011574,1.2323097,,,,,,,,,,,,,, -43400,2.2895875,1.0950193,,,,,,,,,,,,,, -43500,2.608325,1.0946888,,,,,,,,,,,,,, -43600,4.2801585,1.1478732,,,,,,,,,,,,,, -43700,2.4341743,1.1517481,,,,,,,,,,,,,, -43800,1.7287563,1.1615566,,,,,,,,,,,,,, -43900,2.3983674,1.1243377,,,,,,,,,,,,,, -44000,2.6505566,1.1046137,,,,,,,,,,,,,, -44100,1.9828455,1.1006912,,,,,,,,,,,,,, -44200,3.4867404,1.1285205,,,,,,,,,,,,,, -44300,2.7226894,1.1297491,,,,,,,,,,,,,, -44400,2.6295695,1.0833974,,,,,,,,,,,,,, -44453,,,0.15589492,0.0517909262597647,0.42629203,0.1228071869237378,5348.0,0.22446273,0.0706233623788922,2472.0,37467.42148900032,41197.74384975433,37467.42148900032,3726.88182258606,1.390561580657959,0.0 -44500,2.8289626,1.1260513,,,,,,,,,,,,,, -44600,5.743124,1.1030469,,,,,,,,,,,,,, -44700,2.2321548,1.1827441,,,,,,,,,,,,,, -44800,2.3793778,1.1869931,,,,,,,,,,,,,, -44900,2.1594296,1.1475419,,,,,,,,,,,,,, -45000,7.298235,1.0615889,,,,,,,,,,,,,, -45100,5.141013,1.0931325,,,,,,,,,,,,,, -45200,2.678318,1.1441205,,,,,,,,,,,,,, -45300,2.1336174,1.0519102,,,,,,,,,,,,,, -45400,2.2279289,1.066632,,,,,,,,,,,,,, -45500,3.7350857,1.1575084,,,,,,,,,,,,,, -45600,2.2412922,1.079031,,,,,,,,,,,,,, -45700,2.271274,1.0999568,,,,,,,,,,,,,, -45800,2.1242654,1.0714219,,,,,,,,,,,,,, -45900,2.8901942,1.1023676,,,,,,,,,,,,,, -46000,2.474901,1.1542121,,,,,,,,,,,,,, -46100,2.143016,1.1121259,,,,,,,,,,,,,, -46141,,,0.13782987,0.0495388040712468,0.4256302,0.1224403101074562,5348.0,0.2233816,0.0706030507992606,2472.0,38907.607313632965,42776.08789777756,38907.607313632965,3864.898614168167,1.456418752670288,0.0 -46200,2.1348257,1.085608,,,,,,,,,,,,,, -46300,2.4062223,1.1065221,,,,,,,,,,,,,, -46400,3.457545,1.1520432,,,,,,,,,,,,,, -46500,2.5186183,1.1194648,,,,,,,,,,,,,, -46600,2.4012728,1.1098145,,,,,,,,,,,,,, -46700,2.7356348,1.1012921,,,,,,,,,,,,,, -46800,2.3531466,1.0816203,,,,,,,,,,,,,, -46900,2.9303174,1.1457657,,,,,,,,,,,,,, -47000,1.9544731,1.0627167,,,,,,,,,,,,,, -47100,2.2601047,1.0703605,,,,,,,,,,,,,, -47200,3.116267,1.0947425,,,,,,,,,,,,,, -47300,3.784694,1.1273395,,,,,,,,,,,,,, -47400,2.1673598,1.1358154,,,,,,,,,,,,,, -47500,1.7593735,1.0819935,,,,,,,,,,,,,, -47600,4.433989,1.1153057,,,,,,,,,,,,,, -47700,2.8029807,1.1553557,,,,,,,,,,,,,, -47800,3.628183,1.1281104,,,,,,,,,,,,,, -47843,,,0.118889384,0.0439292538030257,0.42518348,0.1222665263523755,5348.0,0.22332466,0.070542116060366,2472.0,40347.73718690872,44354.33743262291,40347.73718690872,4002.885982275009,1.5108904838562012,0.0 -47900,6.0704656,1.1320348,,,,,,,,,,,,,, -48000,,,0.12519462,0.0454118551089671,0.42519876,0.1222954903115556,5348.0,0.22333445,0.0704811813214713,2472.0,40467.03829741478,44602.15502882004,40467.03829741478,4131.328007936478,1.5726885795593262,0.0 -48000,,,,,,,,,,,40467.03829741478,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/eval_measurements.csv deleted file mode 100644 index ad1ef6c91..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -219.01320672035217,0.0,15.311807870864868,1,0,15.311807870864868,30.939787,2472,4.560619909410355,234.32508182525635,31.498302,4.839384963377855,30.813007,5348,4.233526748216303 -331.1859018802643,0.0329687595367431,1455.6672337055206,1709,0,1455.6672337055206,6.0562162,2472,0.888387869924644,1786.9586765766144,6.1106343,0.922439641032606,6.18663,5348,0.886992285932205 -455.1211357116699,0.0901768207550048,2895.917924642563,3429,0,2895.917924642563,4.1909637,2472,0.8059228566205594,3351.280608177185,4.5104766,0.8586798150220665,4.5426497,5348,0.8317097425104029 -595.8084809780121,0.140510082244873,4336.514578104019,5110,0,4336.514578104019,0.694737,2472,0.2199134726707696,4932.689397573471,0.7030768,0.22973511201229,1.0333956,5348,0.2909333153113143 -731.6756029129028,0.1918461322784423,5777.360721826553,6817,0,5777.360721826553,0.5300292,2472,0.172668738447789,6509.532005548477,0.5106037,0.170897716019303,0.83905643,5348,0.2401208762563117 -869.3080587387085,0.243248701095581,7217.927873134613,8552,0,7217.927873134613,0.4711684,2472,0.1546320557349745,8087.86146068573,0.4068033,0.1415767302328769,0.7618631,5348,0.2202419456056846 -1006.1574850082396,0.299619197845459,8658.387535095215,10235,0,8658.387535095215,0.4391571,2472,0.1402514573558385,9665.303247451782,0.38560042,0.1304098976405717,0.7276246,5348,0.2111665717292449 -1141.2911870479584,0.3535733222961426,10098.617139339449,11932,0,10098.617139339449,0.42614245,2472,0.1368797351370016,11240.799752235413,0.4022625,0.1347543821683618,0.71573716,5348,0.2057599660156212 -1276.9315330982208,0.4093668460845947,11538.98053908348,13643,0,11538.98053908348,0.42028874,2472,0.133893932931164,12816.938221931458,0.37538254,0.1249162449406496,0.7008779,5348,0.2028732247506685 -1417.6731894016266,0.4588265419006347,12979.097982645037,15334,0,12979.097982645037,0.3942068,2472,0.129222269615908,14397.922987937927,0.3517334,0.1216117779829691,0.66360116,5348,0.1924558540988829 -1553.7552139759064,0.5167844295501709,14419.062114953997,17023,0,14419.062114953997,0.37999377,2472,0.1233725346820222,15974.10542154312,0.3794418,0.1212333939606667,0.6458081,5348,0.1875705996505015 -1689.598258972168,0.5695686340332031,15859.043215751648,18717,0,15859.043215751648,0.36024228,2472,0.1158978733776125,17550.060350894928,0.28394952,0.0967809750394815,0.6158726,5348,0.1778387093659789 -1824.739266872406,0.6228508949279785,17299.182085752487,20414,0,17299.182085752487,0.3576674,2472,0.1166900249832429,19125.46996498108,0.3042081,0.1037550905184356,0.6088648,5348,0.1766029137742935 -1961.207276582718,0.6727504730224609,18739.56921505928,22123,0,18739.56921505928,0.34208703,2472,0.1118152458716714,20702.45235347748,0.3945768,0.1310924000263811,0.5891843,5348,0.1709356324280487 -2098.446216583252,0.7275550365447998,20179.91077518463,23821,0,20179.91077518463,0.33332494,2472,0.1055389677655231,22280.166800022125,0.4185183,0.1368200999229229,0.5801254,5348,0.1673344468366529 -2229.463135004044,0.7792990207672119,21620.78708958625,25522,0,21620.78708958625,0.32449305,2472,0.1053358519692076,23852.19028639793,0.45643923,0.1495846482272136,0.560944,5348,0.1622464446740106 -2360.788998126984,0.8352978229522705,23060.805052042007,27235,0,23060.805052042007,0.3108912,2472,0.0994861170353218,25423.667890787125,0.39219317,0.1276499589153656,0.5473959,5348,0.1596686523069793 -2496.8168222904205,0.8881025314331055,24500.7278380394,28931,0,24500.7278380394,0.2973004,2472,0.0954034895293807,26999.74779629708,0.35523257,0.1191117055052006,0.526833,5348,0.1538758604709539 -2635.162611246109,0.941051721572876,25941.15482234955,30635,0,25941.15482234955,0.28395033,2472,0.0912396157049133,28578.65523004532,0.30507174,0.1035346025532191,0.50986946,5348,0.1472334591656448 -2769.01868224144,1.0033252239227295,27381.489032030106,32348,0,27381.489032030106,0.28116906,2472,0.08896471878618,30152.98643374443,0.3419567,0.1145099365197094,0.5040512,5348,0.1467893451248829 -2906.028322458267,1.0652475357055664,28822.18370938301,34027,0,28822.18370938301,0.2677237,2472,0.0856336197266061,31730.82936573029,0.28990674,0.098016254945995,0.47910196,5348,0.1395676646359713 -3043.079663515091,1.1228668689727783,30262.275601148605,35723,0,30262.275601148605,0.26008534,2472,0.0820994048707167,33308.1079826355,0.27798745,0.0966597839292621,0.4688031,5348,0.1361112988404761 -3178.134963274002,1.1837427616119385,31702.6688849926,37430,0,31702.6688849926,0.25269088,2472,0.0801494932260881,34883.69545674324,0.27347106,0.0936722321497089,0.45819047,5348,0.1324328760246 -3313.579571247101,1.2389507293701172,33143.4619538784,39101,0,33143.4619538784,0.24324615,2472,0.0767168362683565,36460.06577825546,0.26958352,0.0919348818513308,0.44324875,5348,0.1282813752087818 -3448.6984837055206,1.2968308925628662,34584.25316166878,40799,0,34584.25316166878,0.23475613,2472,0.0738732151199398,38036.112648010254,0.25799805,0.0862188839556998,0.43293205,5348,0.1257615107601108 -3582.3909134864807,1.352367639541626,36024.23282289505,42492,0,36024.23282289505,0.22964896,2472,0.0732029329920987,39609.91804718971,0.22737612,0.0797861368033138,0.4232022,5348,0.1218996495360939 -3715.3618457317352,1.4084465503692627,37464.273461818695,44160,0,37464.273461818695,0.22551502,2472,0.0717608108382588,41183.06321454048,0.22378671,0.0777282363242545,0.4190319,5348,0.1214458808422719 -3849.345896005632,1.4675178527832031,38904.74507904053,45856,0,38904.74507904053,0.22425716,2472,0.0708467897548392,42757.65633749962,0.21560626,0.0740736650284389,0.4162607,5348,0.1205190341485078 -3982.570233345032,1.5255703926086426,40345.26863455773,47543,0,40345.26863455773,0.22386399,2472,0.0707046086974184,44331.542063474655,0.21957077,0.0754813021740605,0.41540474,5348,0.1202293945567066 -4119.21386384964,1.5837607383728027,40712.370055913925,48000,0,40712.370055913925,0.22391605,2472,0.0706233623788922,44835.370416641235,0.22205196,0.07654986522911052,0.41554552,5348,0.12041283296484741 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/measurements.csv deleted file mode 100644 index 03731ba41..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,18.487438,32.70639,,,,,,,,,,,,,, -1,,,31.498302,4.839384963377855,30.813007,4.233526748216303,5348.0,30.939787,4.560619909410355,2472.0,15.311807870864868,234.32508182525635,15.311807870864868,219.01320672035217,0.0,0.0 -100,5.0005803,8.5002165,,,,,,,,,,,,,, -200,1.1592822,6.3841434,,,,,,,,,,,,,, -300,0.5609198,6.026816,,,,,,,,,,,,,, -400,0.46473405,5.8857813,,,,,,,,,,,,,, -500,0.3990076,5.8496075,,,,,,,,,,,,,, -600,0.7212219,5.7690043,,,,,,,,,,,,,, -700,0.40575272,5.721276,,,,,,,,,,,,,, -800,0.47203374,5.6250153,,,,,,,,,,,,,, -900,0.76717025,5.383094,,,,,,,,,,,,,, -1000,0.6977035,5.0591297,,,,,,,,,,,,,, -1100,1.0413898,4.574701,,,,,,,,,,,,,, -1200,1.0461639,4.2231836,,,,,,,,,,,,,, -1300,1.7766134,3.8490791,,,,,,,,,,,,,, -1400,2.1875634,3.691352,,,,,,,,,,,,,, -1500,1.880938,3.5012696,,,,,,,,,,,,,, -1600,1.731078,3.3212733,,,,,,,,,,,,,, -1700,4.375469,3.2063892,,,,,,,,,,,,,, -1709,,,6.1106343,0.922439641032606,6.18663,0.886992285932205,5348.0,6.0562162,0.888387869924644,2472.0,1455.6672337055206,1786.9586765766144,1455.6672337055206,331.1859018802643,0.0329687595367431,0.0 -1800,2.4600825,3.0627303,,,,,,,,,,,,,, -1900,2.4114664,2.9562402,,,,,,,,,,,,,, -2000,2.07264,2.8466856,,,,,,,,,,,,,, -2100,2.869634,2.8192513,,,,,,,,,,,,,, -2200,2.5009007,2.747104,,,,,,,,,,,,,, -2300,3.2233846,2.658577,,,,,,,,,,,,,, -2400,3.9214654,2.639945,,,,,,,,,,,,,, -2500,2.5860963,2.5237446,,,,,,,,,,,,,, -2600,2.845735,2.371879,,,,,,,,,,,,,, -2700,2.79521,2.4392533,,,,,,,,,,,,,, -2800,4.4424806,2.395113,,,,,,,,,,,,,, -2900,2.9374878,2.326822,,,,,,,,,,,,,, -3000,2.9303586,2.2808123,,,,,,,,,,,,,, -3100,2.1612127,2.2000937,,,,,,,,,,,,,, -3200,4.1651073,2.2788632,,,,,,,,,,,,,, -3300,4.223356,2.2208753,,,,,,,,,,,,,, -3400,4.2858524,2.2150092,,,,,,,,,,,,,, -3429,,,4.5104766,0.8586798150220665,4.5426497,0.8317097425104029,5348.0,4.1909637,0.8059228566205594,2472.0,2895.917924642563,3351.280608177185,2895.917924642563,455.1211357116699,0.0901768207550048,0.0 -3500,2.6978273,2.1838944,,,,,,,,,,,,,, -3600,3.9436426,2.1633146,,,,,,,,,,,,,, -3700,3.2370193,2.135883,,,,,,,,,,,,,, -3800,4.027213,2.1247427,,,,,,,,,,,,,, -3900,4.2778835,2.0889003,,,,,,,,,,,,,, -4000,4.222257,2.0626814,,,,,,,,,,,,,, -4100,3.4404206,2.0320053,,,,,,,,,,,,,, -4200,4.165499,1.983971,,,,,,,,,,,,,, -4300,3.8082714,2.025269,,,,,,,,,,,,,, -4400,3.273907,1.9898214,,,,,,,,,,,,,, -4500,4.191954,2.0400183,,,,,,,,,,,,,, -4600,4.689402,1.9641247,,,,,,,,,,,,,, -4700,4.443674,1.9422508,,,,,,,,,,,,,, -4800,3.264025,1.9958204,,,,,,,,,,,,,, -4900,3.0839238,1.9018968,,,,,,,,,,,,,, -5000,3.2659862,1.8868322,,,,,,,,,,,,,, -5100,2.8179693,1.9848771,,,,,,,,,,,,,, -5110,,,0.7030768,0.22973511201229,1.0333956,0.2909333153113143,5348.0,0.694737,0.2199134726707696,2472.0,4336.514578104019,4932.689397573471,4336.514578104019,595.8084809780121,0.140510082244873,0.0 -5200,4.2574086,1.8961076,,,,,,,,,,,,,, -5300,3.7634015,1.804796,,,,,,,,,,,,,, -5400,3.8143551,1.8255255,,,,,,,,,,,,,, -5500,2.6777573,1.8942312,,,,,,,,,,,,,, -5600,3.1330688,1.8907207,,,,,,,,,,,,,, -5700,2.8065534,1.8801965,,,,,,,,,,,,,, -5800,3.9215598,1.8550445,,,,,,,,,,,,,, -5900,3.0011735,1.8347552,,,,,,,,,,,,,, -6000,3.1905165,1.7252871,,,,,,,,,,,,,, -6100,2.7682893,1.8312013,,,,,,,,,,,,,, -6200,2.178821,1.7839552,,,,,,,,,,,,,, -6300,2.576068,1.8208559,,,,,,,,,,,,,, -6400,3.471674,1.7608877,,,,,,,,,,,,,, -6500,3.5277698,1.7978003,,,,,,,,,,,,,, -6600,3.2313411,1.8498328,,,,,,,,,,,,,, -6700,2.4181058,1.7903352,,,,,,,,,,,,,, -6800,2.3457916,1.7840517,,,,,,,,,,,,,, -6817,,,0.5106037,0.170897716019303,0.83905643,0.2401208762563117,5348.0,0.5300292,0.172668738447789,2472.0,5777.360721826553,6509.532005548477,5777.360721826553,731.6756029129028,0.1918461322784423,0.0 -6900,2.6102607,1.7516919,,,,,,,,,,,,,, -7000,2.8454537,1.7662851,,,,,,,,,,,,,, -7100,5.2619724,1.7433993,,,,,,,,,,,,,, -7200,2.7687232,1.7632252,,,,,,,,,,,,,, -7300,4.2967796,1.7860546,,,,,,,,,,,,,, -7400,2.8441699,1.7351594,,,,,,,,,,,,,, -7500,2.289072,1.71259,,,,,,,,,,,,,, -7600,2.2775517,1.7954637,,,,,,,,,,,,,, -7700,3.0926259,1.7017146,,,,,,,,,,,,,, -7800,2.398302,1.7341528,,,,,,,,,,,,,, -7900,4.367555,1.7271591,,,,,,,,,,,,,, -8000,3.2061813,1.7053672,,,,,,,,,,,,,, -8100,3.5699725,1.7101537,,,,,,,,,,,,,, -8200,2.7356334,1.7725753,,,,,,,,,,,,,, -8300,2.1066573,1.6333147,,,,,,,,,,,,,, -8400,2.4869983,1.6509897,,,,,,,,,,,,,, -8500,2.4115243,1.6561961,,,,,,,,,,,,,, -8552,,,0.4068033,0.1415767302328769,0.7618631,0.2202419456056846,5348.0,0.4711684,0.1546320557349745,2472.0,7217.927873134613,8087.86146068573,7217.927873134613,869.3080587387085,0.243248701095581,0.0 -8600,2.6859365,1.7026377,,,,,,,,,,,,,, -8700,2.2990596,1.6361328,,,,,,,,,,,,,, -8800,2.9208472,1.6957899,,,,,,,,,,,,,, -8900,4.8146453,1.6655021,,,,,,,,,,,,,, -9000,2.7251413,1.6568018,,,,,,,,,,,,,, -9100,2.8905778,1.7682742,,,,,,,,,,,,,, -9200,3.510537,1.6399149,,,,,,,,,,,,,, -9300,2.6902144,1.6875182,,,,,,,,,,,,,, -9400,3.5838354,1.6445671,,,,,,,,,,,,,, -9500,3.4641428,1.7067695,,,,,,,,,,,,,, -9600,4.986451,1.7197908,,,,,,,,,,,,,, -9700,2.0445397,1.6423266,,,,,,,,,,,,,, -9800,3.4016643,1.6519204,,,,,,,,,,,,,, -9900,3.460481,1.644082,,,,,,,,,,,,,, -10000,2.459788,1.6189436,,,,,,,,,,,,,, -10100,3.2419858,1.6257005,,,,,,,,,,,,,, -10200,3.648414,1.6230437,,,,,,,,,,,,,, -10235,,,0.38560042,0.1304098976405717,0.7276246,0.2111665717292449,5348.0,0.4391571,0.1402514573558385,2472.0,8658.387535095215,9665.303247451782,8658.387535095215,1006.1574850082396,0.299619197845459,0.0 -10300,3.2554288,1.5719264,,,,,,,,,,,,,, -10400,3.2874799,1.6792759,,,,,,,,,,,,,, -10500,2.6625624,1.5794266,,,,,,,,,,,,,, -10600,3.2901335,1.6414471,,,,,,,,,,,,,, -10700,4.3133965,1.6411016,,,,,,,,,,,,,, -10800,5.424246,1.6143384,,,,,,,,,,,,,, -10900,3.2924516,1.660441,,,,,,,,,,,,,, -11000,2.359229,1.637153,,,,,,,,,,,,,, -11100,2.2032406,1.6281614,,,,,,,,,,,,,, -11200,2.7250948,1.6268057,,,,,,,,,,,,,, -11300,2.64582,1.670687,,,,,,,,,,,,,, -11400,1.97053,1.5573187,,,,,,,,,,,,,, -11500,3.455917,1.5875299,,,,,,,,,,,,,, -11600,2.7871706,1.5553098,,,,,,,,,,,,,, -11700,3.555403,1.6263579,,,,,,,,,,,,,, -11800,4.1547403,1.6428264,,,,,,,,,,,,,, -11900,4.189207,1.5975283,,,,,,,,,,,,,, -11932,,,0.4022625,0.1347543821683618,0.71573716,0.2057599660156212,5348.0,0.42614245,0.1368797351370016,2472.0,10098.617139339449,11240.799752235413,10098.617139339449,1141.2911870479584,0.3535733222961426,0.0 -12000,3.607084,1.6034697,,,,,,,,,,,,,, -12100,3.110012,1.6218941,,,,,,,,,,,,,, -12200,2.0579073,1.6011441,,,,,,,,,,,,,, -12300,3.2813861,1.6099603,,,,,,,,,,,,,, -12400,2.642521,1.580072,,,,,,,,,,,,,, -12500,2.2376387,1.5816498,,,,,,,,,,,,,, -12600,2.4195755,1.6700882,,,,,,,,,,,,,, -12700,3.1790326,1.6159209,,,,,,,,,,,,,, -12800,5.22597,1.5781955,,,,,,,,,,,,,, -12900,2.7492523,1.5540476,,,,,,,,,,,,,, -13000,3.6386545,1.7102572,,,,,,,,,,,,,, -13100,3.508628,1.6230069,,,,,,,,,,,,,, -13200,3.454226,1.6317537,,,,,,,,,,,,,, -13300,3.5719917,1.5767769,,,,,,,,,,,,,, -13400,7.45549,1.5717288,,,,,,,,,,,,,, -13500,2.5879033,1.5876771,,,,,,,,,,,,,, -13600,3.9392304,1.6422026,,,,,,,,,,,,,, -13643,,,0.37538254,0.1249162449406496,0.7008779,0.2028732247506685,5348.0,0.42028874,0.133893932931164,2472.0,11538.98053908348,12816.938221931458,11538.98053908348,1276.9315330982208,0.4093668460845947,0.0 -13700,3.355822,1.5943767,,,,,,,,,,,,,, -13800,2.848196,1.56253,,,,,,,,,,,,,, -13900,2.3584628,1.5777327,,,,,,,,,,,,,, -14000,2.23274,1.623078,,,,,,,,,,,,,, -14100,3.289934,1.6324207,,,,,,,,,,,,,, -14200,2.4128757,1.573517,,,,,,,,,,,,,, -14300,3.4441013,1.501353,,,,,,,,,,,,,, -14400,2.9675326,1.5755095,,,,,,,,,,,,,, -14500,3.5834355,1.6269107,,,,,,,,,,,,,, -14600,3.0870922,1.5151203,,,,,,,,,,,,,, -14700,3.5483313,1.5522618,,,,,,,,,,,,,, -14800,2.6586258,1.587207,,,,,,,,,,,,,, -14900,2.5127957,1.649228,,,,,,,,,,,,,, -15000,2.404312,1.542026,,,,,,,,,,,,,, -15100,3.8180034,1.5847182,,,,,,,,,,,,,, -15200,2.8913596,1.5203867,,,,,,,,,,,,,, -15300,3.3806136,1.6098369,,,,,,,,,,,,,, -15334,,,0.3517334,0.1216117779829691,0.66360116,0.1924558540988829,5348.0,0.3942068,0.129222269615908,2472.0,12979.097982645037,14397.922987937927,12979.097982645037,1417.6731894016266,0.4588265419006347,0.0 -15400,2.5748148,1.5285916,,,,,,,,,,,,,, -15500,2.8860402,1.5736432,,,,,,,,,,,,,, -15600,2.4457643,1.6168652,,,,,,,,,,,,,, -15700,3.5776277,1.5978347,,,,,,,,,,,,,, -15800,3.587093,1.5903358,,,,,,,,,,,,,, -15900,2.1261044,1.5712276,,,,,,,,,,,,,, -16000,2.5121305,1.5512786,,,,,,,,,,,,,, -16100,2.7788482,1.6344686,,,,,,,,,,,,,, -16200,3.348154,1.6176744,,,,,,,,,,,,,, -16300,4.374826,1.5678289,,,,,,,,,,,,,, -16400,1.9624221,1.4983487,,,,,,,,,,,,,, -16500,2.7852697,1.5225552,,,,,,,,,,,,,, -16600,2.4863966,1.5617769,,,,,,,,,,,,,, -16700,2.922268,1.5615735,,,,,,,,,,,,,, -16800,2.8495376,1.5544881,,,,,,,,,,,,,, -16900,2.9120352,1.5056694,,,,,,,,,,,,,, -17000,4.9664927,1.6415496,,,,,,,,,,,,,, -17023,,,0.3794418,0.1212333939606667,0.6458081,0.1875705996505015,5348.0,0.37999377,0.1233725346820222,2472.0,14419.062114953997,15974.10542154312,14419.062114953997,1553.7552139759064,0.5167844295501709,0.0 -17100,3.2007282,1.5716394,,,,,,,,,,,,,, -17200,3.105161,1.5331599,,,,,,,,,,,,,, -17300,3.181311,1.509129,,,,,,,,,,,,,, -17400,2.658216,1.5711874,,,,,,,,,,,,,, -17500,2.735865,1.571382,,,,,,,,,,,,,, -17600,4.4960947,1.5025172,,,,,,,,,,,,,, -17700,2.8608537,1.5652313,,,,,,,,,,,,,, -17800,2.2889018,1.5153763,,,,,,,,,,,,,, -17900,3.8833463,1.5030799,,,,,,,,,,,,,, -18000,4.0349193,1.5254308,,,,,,,,,,,,,, -18100,1.9310429,1.510597,,,,,,,,,,,,,, -18200,3.7913167,1.5773422,,,,,,,,,,,,,, -18300,2.0327632,1.5613514,,,,,,,,,,,,,, -18400,4.519606,1.5280571,,,,,,,,,,,,,, -18500,3.3209429,1.553655,,,,,,,,,,,,,, -18600,3.3816786,1.5243679,,,,,,,,,,,,,, -18700,2.922788,1.4852651,,,,,,,,,,,,,, -18717,,,0.28394952,0.0967809750394815,0.6158726,0.1778387093659789,5348.0,0.36024228,0.1158978733776125,2472.0,15859.043215751648,17550.060350894928,15859.043215751648,1689.598258972168,0.5695686340332031,0.0 -18800,4.0136976,1.4579709,,,,,,,,,,,,,, -18900,2.0114691,1.5463157,,,,,,,,,,,,,, -19000,3.497759,1.4890525,,,,,,,,,,,,,, -19100,3.7268188,1.5415407,,,,,,,,,,,,,, -19200,2.4446237,1.4307837,,,,,,,,,,,,,, -19300,3.0890198,1.4674815,,,,,,,,,,,,,, -19400,2.7518384,1.5427393,,,,,,,,,,,,,, -19500,2.55262,1.4397843,,,,,,,,,,,,,, -19600,3.5718555,1.4643818,,,,,,,,,,,,,, -19700,2.654087,1.49873,,,,,,,,,,,,,, -19800,3.6067007,1.5335667,,,,,,,,,,,,,, -19900,2.5045786,1.4403753,,,,,,,,,,,,,, -20000,2.2189877,1.4346243,,,,,,,,,,,,,, -20100,3.2200346,1.5898509,,,,,,,,,,,,,, -20200,2.568045,1.5384406,,,,,,,,,,,,,, -20300,2.4836311,1.4963624,,,,,,,,,,,,,, -20400,2.1388144,1.5011652,,,,,,,,,,,,,, -20414,,,0.3042081,0.1037550905184356,0.6088648,0.1766029137742935,5348.0,0.3576674,0.1166900249832429,2472.0,17299.182085752487,19125.46996498108,17299.182085752487,1824.739266872406,0.6228508949279785,0.0 -20500,2.7832582,1.590348,,,,,,,,,,,,,, -20600,3.0755272,1.5167958,,,,,,,,,,,,,, -20700,2.3551052,1.4198464,,,,,,,,,,,,,, -20800,2.935667,1.4847965,,,,,,,,,,,,,, -20900,2.4793396,1.5133113,,,,,,,,,,,,,, -21000,2.3158457,1.5026436,,,,,,,,,,,,,, -21100,2.2854846,1.4098103,,,,,,,,,,,,,, -21200,3.015674,1.497635,,,,,,,,,,,,,, -21300,2.866358,1.4482514,,,,,,,,,,,,,, -21400,2.254071,1.4769094,,,,,,,,,,,,,, -21500,2.6010044,1.5076951,,,,,,,,,,,,,, -21600,2.8198316,1.4923179,,,,,,,,,,,,,, -21700,3.207907,1.5028375,,,,,,,,,,,,,, -21800,2.505027,1.5030705,,,,,,,,,,,,,, -21900,2.220134,1.5136063,,,,,,,,,,,,,, -22000,2.9713187,1.495287,,,,,,,,,,,,,, -22100,3.537603,1.4687622,,,,,,,,,,,,,, -22123,,,0.3945768,0.1310924000263811,0.5891843,0.1709356324280487,5348.0,0.34208703,0.1118152458716714,2472.0,18739.56921505928,20702.45235347748,18739.56921505928,1961.207276582718,0.6727504730224609,0.0 -22200,4.0309,1.5259932,,,,,,,,,,,,,, -22300,3.8591907,1.4537762,,,,,,,,,,,,,, -22400,3.8141072,1.43685,,,,,,,,,,,,,, -22500,3.0850775,1.459701,,,,,,,,,,,,,, -22600,2.9640365,1.5009165,,,,,,,,,,,,,, -22700,1.6366155,1.4247798,,,,,,,,,,,,,, -22800,2.5892384,1.497737,,,,,,,,,,,,,, -22900,2.7621846,1.4617239,,,,,,,,,,,,,, -23000,2.358132,1.4136168,,,,,,,,,,,,,, -23100,1.8444911,1.4099451,,,,,,,,,,,,,, -23200,2.052072,1.4964198,,,,,,,,,,,,,, -23300,2.5423777,1.4291921,,,,,,,,,,,,,, -23400,2.0757923,1.4873711,,,,,,,,,,,,,, -23500,2.57595,1.467762,,,,,,,,,,,,,, -23600,3.2449272,1.4602746,,,,,,,,,,,,,, -23700,2.89008,1.3838055,,,,,,,,,,,,,, -23800,2.9837296,1.4670401,,,,,,,,,,,,,, -23821,,,0.4185183,0.1368200999229229,0.5801254,0.1673344468366529,5348.0,0.33332494,0.1055389677655231,2472.0,20179.91077518463,22280.166800022125,20179.91077518463,2098.446216583252,0.7275550365447998,0.0 -23900,2.599671,1.5286673,,,,,,,,,,,,,, -24000,3.8309884,1.4596382,,,,,,,,,,,,,, -24100,2.4468381,1.442448,,,,,,,,,,,,,, -24200,2.170604,1.4375044,,,,,,,,,,,,,, -24300,3.2781494,1.3543411,,,,,,,,,,,,,, -24400,3.440276,1.3945862,,,,,,,,,,,,,, -24500,2.706368,1.3539214,,,,,,,,,,,,,, -24600,2.7123106,1.4243016,,,,,,,,,,,,,, -24700,2.4728239,1.3953922,,,,,,,,,,,,,, -24800,3.2705708,1.3793658,,,,,,,,,,,,,, -24900,2.5744684,1.3221766,,,,,,,,,,,,,, -25000,3.4272192,1.4415965,,,,,,,,,,,,,, -25100,2.4457538,1.4213165,,,,,,,,,,,,,, -25200,2.8126693,1.4109353,,,,,,,,,,,,,, -25300,3.4549272,1.4449838,,,,,,,,,,,,,, -25400,2.5021205,1.4507226,,,,,,,,,,,,,, -25500,3.3147988,1.5130919,,,,,,,,,,,,,, -25522,,,0.45643923,0.1495846482272136,0.560944,0.1622464446740106,5348.0,0.32449305,0.1053358519692076,2472.0,21620.78708958625,23852.19028639793,21620.78708958625,2229.463135004044,0.7792990207672119,0.0 -25600,2.2756872,1.3923591,,,,,,,,,,,,,, -25700,4.2016883,1.4057466,,,,,,,,,,,,,, -25800,2.8390427,1.385488,,,,,,,,,,,,,, -25900,3.8590262,1.4364924,,,,,,,,,,,,,, -26000,4.213318,1.3714836,,,,,,,,,,,,,, -26100,2.4916255,1.482004,,,,,,,,,,,,,, -26200,2.8592856,1.3505614,,,,,,,,,,,,,, -26300,3.37683,1.3899342,,,,,,,,,,,,,, -26400,2.9098,1.3485458,,,,,,,,,,,,,, -26500,3.1427853,1.4352033,,,,,,,,,,,,,, -26600,2.735608,1.431586,,,,,,,,,,,,,, -26700,3.0756168,1.4031209,,,,,,,,,,,,,, -26800,2.8738747,1.3954992,,,,,,,,,,,,,, -26900,2.7423577,1.4254457,,,,,,,,,,,,,, -27000,2.097663,1.3262491,,,,,,,,,,,,,, -27100,3.0472193,1.3531199,,,,,,,,,,,,,, -27200,3.0380795,1.3442868,,,,,,,,,,,,,, -27235,,,0.39219317,0.1276499589153656,0.5473959,0.1596686523069793,5348.0,0.3108912,0.0994861170353218,2472.0,23060.805052042007,25423.667890787125,23060.805052042007,2360.788998126984,0.8352978229522705,0.0 -27300,2.8390656,1.4160408,,,,,,,,,,,,,, -27400,2.6303656,1.4588265,,,,,,,,,,,,,, -27500,2.962482,1.3679293,,,,,,,,,,,,,, -27600,2.6709197,1.3799459,,,,,,,,,,,,,, -27700,2.2321322,1.416902,,,,,,,,,,,,,, -27800,1.9099199,1.3739487,,,,,,,,,,,,,, -27900,3.3632948,1.3342324,,,,,,,,,,,,,, -28000,3.1719162,1.3241346,,,,,,,,,,,,,, -28100,3.342795,1.3958722,,,,,,,,,,,,,, -28200,2.7901163,1.3436699,,,,,,,,,,,,,, -28300,2.1759667,1.3590882,,,,,,,,,,,,,, -28400,3.5788105,1.335175,,,,,,,,,,,,,, -28500,2.524202,1.3943812,,,,,,,,,,,,,, -28600,2.2159579,1.3988804,,,,,,,,,,,,,, -28700,3.4185388,1.3823673,,,,,,,,,,,,,, -28800,2.0409384,1.3644753,,,,,,,,,,,,,, -28900,2.6332445,1.3324072,,,,,,,,,,,,,, -28931,,,0.35523257,0.1191117055052006,0.526833,0.1538758604709539,5348.0,0.2973004,0.0954034895293807,2472.0,24500.7278380394,26999.74779629708,24500.7278380394,2496.8168222904205,0.8881025314331055,0.0 -29000,2.4205363,1.4027511,,,,,,,,,,,,,, -29100,2.5694895,1.3814191,,,,,,,,,,,,,, -29200,3.828744,1.3384105,,,,,,,,,,,,,, -29300,2.8034039,1.3853441,,,,,,,,,,,,,, -29400,3.058484,1.3772306,,,,,,,,,,,,,, -29500,3.9380052,1.3318727,,,,,,,,,,,,,, -29600,2.344829,1.3605095,,,,,,,,,,,,,, -29700,2.794355,1.3707497,,,,,,,,,,,,,, -29800,2.4604406,1.3774552,,,,,,,,,,,,,, -29900,3.2012026,1.3359587,,,,,,,,,,,,,, -30000,3.1339507,1.365905,,,,,,,,,,,,,, -30100,5.509609,1.3735464,,,,,,,,,,,,,, -30200,3.0233793,1.3599765,,,,,,,,,,,,,, -30300,2.5576463,1.324036,,,,,,,,,,,,,, -30400,2.5458655,1.3268455,,,,,,,,,,,,,, -30500,4.249021,1.3750813,,,,,,,,,,,,,, -30600,2.405459,1.3116945,,,,,,,,,,,,,, -30635,,,0.30507174,0.1035346025532191,0.50986946,0.1472334591656448,5348.0,0.28395033,0.0912396157049133,2472.0,25941.15482234955,28578.65523004532,25941.15482234955,2635.162611246109,0.941051721572876,0.0 -30700,3.7217705,1.2726517,,,,,,,,,,,,,, -30800,2.1488936,1.3196607,,,,,,,,,,,,,, -30900,2.055701,1.2976999,,,,,,,,,,,,,, -31000,2.7635682,1.359249,,,,,,,,,,,,,, -31100,5.077847,1.3217485,,,,,,,,,,,,,, -31200,2.421792,1.3127955,,,,,,,,,,,,,, -31300,2.5968869,1.2525907,,,,,,,,,,,,,, -31400,2.6837113,1.290119,,,,,,,,,,,,,, -31500,2.9502497,1.3064618,,,,,,,,,,,,,, -31600,2.6338449,1.3108734,,,,,,,,,,,,,, -31700,2.746344,1.3951198,,,,,,,,,,,,,, -31800,3.2612624,1.2789567,,,,,,,,,,,,,, -31900,3.4314687,1.3628474,,,,,,,,,,,,,, -32000,2.9722354,1.3029755,,,,,,,,,,,,,, -32100,2.9180107,1.403588,,,,,,,,,,,,,, -32200,2.967512,1.2924927,,,,,,,,,,,,,, -32300,2.7854013,1.326554,,,,,,,,,,,,,, -32348,,,0.3419567,0.1145099365197094,0.5040512,0.1467893451248829,5348.0,0.28116906,0.08896471878618,2472.0,27381.489032030106,30152.98643374443,27381.489032030106,2769.01868224144,1.0033252239227295,0.0 -32400,3.4479182,1.3199136,,,,,,,,,,,,,, -32500,2.299672,1.2888608,,,,,,,,,,,,,, -32600,2.4770563,1.2819172,,,,,,,,,,,,,, -32700,4.419465,1.3331531,,,,,,,,,,,,,, -32800,3.266632,1.3384728,,,,,,,,,,,,,, -32900,2.3477647,1.3204752,,,,,,,,,,,,,, -33000,2.4562328,1.3331351,,,,,,,,,,,,,, -33100,2.1439853,1.2512486,,,,,,,,,,,,,, -33200,2.5036633,1.3217922,,,,,,,,,,,,,, -33300,3.6525776,1.2755665,,,,,,,,,,,,,, -33400,2.409271,1.2589926,,,,,,,,,,,,,, -33500,2.2451558,1.2578837,,,,,,,,,,,,,, -33600,2.2897754,1.3462087,,,,,,,,,,,,,, -33700,2.6638215,1.2849869,,,,,,,,,,,,,, -33800,2.5071278,1.2718028,,,,,,,,,,,,,, -33900,2.4134865,1.3325698,,,,,,,,,,,,,, -34000,2.351575,1.2831938,,,,,,,,,,,,,, -34027,,,0.28990674,0.098016254945995,0.47910196,0.1395676646359713,5348.0,0.2677237,0.0856336197266061,2472.0,28822.18370938301,31730.82936573029,28822.18370938301,2906.028322458267,1.0652475357055664,0.0 -34100,2.6047382,1.2797507,,,,,,,,,,,,,, -34200,3.9010224,1.2629219,,,,,,,,,,,,,, -34300,2.9038458,1.2704344,,,,,,,,,,,,,, -34400,4.1698704,1.2542899,,,,,,,,,,,,,, -34500,2.5118897,1.226123,,,,,,,,,,,,,, -34600,3.1013303,1.2523575,,,,,,,,,,,,,, -34700,1.9495318,1.2848036,,,,,,,,,,,,,, -34800,3.2426848,1.2911421,,,,,,,,,,,,,, -34900,4.723919,1.2868265,,,,,,,,,,,,,, -35000,2.2623656,1.2961313,,,,,,,,,,,,,, -35100,2.9248073,1.276531,,,,,,,,,,,,,, -35200,2.9364848,1.2333531,,,,,,,,,,,,,, -35300,5.955417,1.2440693,,,,,,,,,,,,,, -35400,2.9348722,1.168016,,,,,,,,,,,,,, -35500,2.155422,1.2577099,,,,,,,,,,,,,, -35600,2.7446795,1.2793897,,,,,,,,,,,,,, -35700,2.6216059,1.2759509,,,,,,,,,,,,,, -35723,,,0.27798745,0.0966597839292621,0.4688031,0.1361112988404761,5348.0,0.26008534,0.0820994048707167,2472.0,30262.275601148605,33308.1079826355,30262.275601148605,3043.079663515091,1.1228668689727783,0.0 -35800,4.170672,1.2131736,,,,,,,,,,,,,, -35900,2.9036925,1.294493,,,,,,,,,,,,,, -36000,3.0758042,1.2781471,,,,,,,,,,,,,, -36100,3.0983953,1.2531347,,,,,,,,,,,,,, -36200,4.150625,1.244077,,,,,,,,,,,,,, -36300,2.7039032,1.2068919,,,,,,,,,,,,,, -36400,2.7877326,1.2263572,,,,,,,,,,,,,, -36500,2.7102063,1.2800863,,,,,,,,,,,,,, -36600,3.1596558,1.2780458,,,,,,,,,,,,,, -36700,3.2015798,1.1777111,,,,,,,,,,,,,, -36800,3.7027988,1.2292106,,,,,,,,,,,,,, -36900,3.054383,1.2532436,,,,,,,,,,,,,, -37000,3.6605914,1.2065319,,,,,,,,,,,,,, -37100,6.103061,1.1838602,,,,,,,,,,,,,, -37200,2.6261127,1.2368202,,,,,,,,,,,,,, -37300,2.7686036,1.2629799,,,,,,,,,,,,,, -37400,2.414546,1.20928,,,,,,,,,,,,,, -37430,,,0.27347106,0.0936722321497089,0.45819047,0.1324328760246,5348.0,0.25269088,0.0801494932260881,2472.0,31702.6688849926,34883.69545674324,31702.6688849926,3178.134963274002,1.1837427616119385,0.0 -37500,4.518216,1.1602256,,,,,,,,,,,,,, -37600,4.989013,1.2046759,,,,,,,,,,,,,, -37700,3.4113073,1.2538829,,,,,,,,,,,,,, -37800,5.605469,1.2105738,,,,,,,,,,,,,, -37900,2.5876267,1.2035316,,,,,,,,,,,,,, -38000,2.7990894,1.2353586,,,,,,,,,,,,,, -38100,2.1447928,1.1650801,,,,,,,,,,,,,, -38200,2.5349402,1.2391646,,,,,,,,,,,,,, -38300,2.9583137,1.2319013,,,,,,,,,,,,,, -38400,4.000343,1.1379552,,,,,,,,,,,,,, -38500,2.5242717,1.1948411,,,,,,,,,,,,,, -38600,2.827934,1.2105113,,,,,,,,,,,,,, -38700,2.7622805,1.2283949,,,,,,,,,,,,,, -38800,3.899833,1.1910056,,,,,,,,,,,,,, -38900,4.410442,1.1871102,,,,,,,,,,,,,, -39000,2.7374449,1.1903424,,,,,,,,,,,,,, -39100,2.2124972,1.2232233,,,,,,,,,,,,,, -39101,,,0.26958352,0.0919348818513308,0.44324875,0.1282813752087818,5348.0,0.24324615,0.0767168362683565,2472.0,33143.4619538784,36460.06577825546,33143.4619538784,3313.579571247101,1.2389507293701172,0.0 -39200,3.756494,1.2012316,,,,,,,,,,,,,, -39300,2.665072,1.2102631,,,,,,,,,,,,,, -39400,3.9771636,1.181906,,,,,,,,,,,,,, -39500,3.0343373,1.2088488,,,,,,,,,,,,,, -39600,2.748872,1.1687361,,,,,,,,,,,,,, -39700,3.3493056,1.1642054,,,,,,,,,,,,,, -39800,2.7156072,1.2458256,,,,,,,,,,,,,, -39900,2.9435744,1.1453539,,,,,,,,,,,,,, -40000,4.180884,1.2175748,,,,,,,,,,,,,, -40100,3.5590215,1.2092576,,,,,,,,,,,,,, -40200,3.3811193,1.1654443,,,,,,,,,,,,,, -40300,2.4950588,1.2019624,,,,,,,,,,,,,, -40400,3.5969949,1.1839072,,,,,,,,,,,,,, -40500,3.1225939,1.2044318,,,,,,,,,,,,,, -40600,3.708546,1.205639,,,,,,,,,,,,,, -40700,4.419181,1.2050838,,,,,,,,,,,,,, -40799,,,0.25799805,0.0862188839556998,0.43293205,0.1257615107601108,5348.0,0.23475613,0.0738732151199398,2472.0,34584.25316166878,38036.112648010254,34584.25316166878,3448.6984837055206,1.2968308925628662,0.0 -40800,3.3691683,1.1652374,,,,,,,,,,,,,, -40900,2.9494445,1.2008481,,,,,,,,,,,,,, -41000,2.1442993,1.2041564,,,,,,,,,,,,,, -41100,3.7366426,1.1318114,,,,,,,,,,,,,, -41200,3.126916,1.1095122,,,,,,,,,,,,,, -41300,2.6870892,1.1891313,,,,,,,,,,,,,, -41400,2.6496456,1.1811393,,,,,,,,,,,,,, -41500,2.567533,1.1289016,,,,,,,,,,,,,, -41600,3.4259157,1.1796781,,,,,,,,,,,,,, -41700,3.2338934,1.1484375,,,,,,,,,,,,,, -41800,2.6218534,1.1073229,,,,,,,,,,,,,, -41900,3.1939187,1.1318703,,,,,,,,,,,,,, -42000,2.8329325,1.1968052,,,,,,,,,,,,,, -42100,3.2779906,1.1168461,,,,,,,,,,,,,, -42200,2.7234228,1.1945373,,,,,,,,,,,,,, -42300,4.2805243,1.1188418,,,,,,,,,,,,,, -42400,2.3398347,1.1523464,,,,,,,,,,,,,, -42492,,,0.22737612,0.0797861368033138,0.4232022,0.1218996495360939,5348.0,0.22964896,0.0732029329920987,2472.0,36024.23282289505,39609.91804718971,36024.23282289505,3582.3909134864807,1.352367639541626,0.0 -42500,2.7025914,1.1125089,,,,,,,,,,,,,, -42600,2.8679914,1.154611,,,,,,,,,,,,,, -42700,2.0677185,1.1552846,,,,,,,,,,,,,, -42800,2.382438,1.1722468,,,,,,,,,,,,,, -42900,2.8899112,1.2141901,,,,,,,,,,,,,, -43000,4.518305,1.1690427,,,,,,,,,,,,,, -43100,3.1384676,1.0968324,,,,,,,,,,,,,, -43200,2.747848,1.1350387,,,,,,,,,,,,,, -43300,2.495235,1.1563282,,,,,,,,,,,,,, -43400,3.0952947,1.1275406,,,,,,,,,,,,,, -43500,2.538274,1.1509832,,,,,,,,,,,,,, -43600,4.2546377,1.1571183,,,,,,,,,,,,,, -43700,3.7933958,1.1055273,,,,,,,,,,,,,, -43800,3.580705,1.2004378,,,,,,,,,,,,,, -43900,2.8270342,1.118313,,,,,,,,,,,,,, -44000,2.6768565,1.1614572,,,,,,,,,,,,,, -44100,3.0803678,1.1533854,,,,,,,,,,,,,, -44160,,,0.22378671,0.0777282363242545,0.4190319,0.1214458808422719,5348.0,0.22551502,0.0717608108382588,2472.0,37464.273461818695,41183.06321454048,37464.273461818695,3715.3618457317352,1.4084465503692627,0.0 -44200,2.3912144,1.1441352,,,,,,,,,,,,,, -44300,3.4266348,1.0966839,,,,,,,,,,,,,, -44400,3.5226974,1.1398137,,,,,,,,,,,,,, -44500,4.882169,1.1326679,,,,,,,,,,,,,, -44600,5.0753894,1.1126939,,,,,,,,,,,,,, -44700,4.2510066,1.1383711,,,,,,,,,,,,,, -44800,2.638841,1.1801512,,,,,,,,,,,,,, -44900,2.9744952,1.1113063,,,,,,,,,,,,,, -45000,5.6809254,1.1064681,,,,,,,,,,,,,, -45100,2.427511,1.1376618,,,,,,,,,,,,,, -45200,3.1907442,1.1989436,,,,,,,,,,,,,, -45300,2.4438133,1.0753173,,,,,,,,,,,,,, -45400,3.4316204,1.0758936,,,,,,,,,,,,,, -45500,2.4066222,1.1621009,,,,,,,,,,,,,, -45600,3.712693,1.0928646,,,,,,,,,,,,,, -45700,3.949252,1.1597618,,,,,,,,,,,,,, -45800,2.4325848,1.1425753,,,,,,,,,,,,,, -45856,,,0.21560626,0.0740736650284389,0.4162607,0.1205190341485078,5348.0,0.22425716,0.0708467897548392,2472.0,38904.74507904053,42757.65633749962,38904.74507904053,3849.345896005632,1.4675178527832031,0.0 -45900,3.393946,1.128224,,,,,,,,,,,,,, -46000,2.7830384,1.1745546,,,,,,,,,,,,,, -46100,3.5796618,1.1629272,,,,,,,,,,,,,, -46200,4.5791426,1.1343299,,,,,,,,,,,,,, -46300,3.2551265,1.1345091,,,,,,,,,,,,,, -46400,3.5168698,1.1551623,,,,,,,,,,,,,, -46500,4.938607,1.1286877,,,,,,,,,,,,,, -46600,2.7925277,1.1440463,,,,,,,,,,,,,, -46700,3.7116046,1.1373253,,,,,,,,,,,,,, -46800,3.0657609,1.1751482,,,,,,,,,,,,,, -46900,2.1900835,1.1495901,,,,,,,,,,,,,, -47000,3.6719801,1.1254637,,,,,,,,,,,,,, -47100,2.7516313,1.110599,,,,,,,,,,,,,, -47200,4.2568974,1.1291796,,,,,,,,,,,,,, -47300,4.1916895,1.1342545,,,,,,,,,,,,,, -47400,4.280519,1.1221999,,,,,,,,,,,,,, -47500,3.5849538,1.1157289,,,,,,,,,,,,,, -47543,,,0.21957077,0.0754813021740605,0.41540474,0.1202293945567066,5348.0,0.22386399,0.0707046086974184,2472.0,40345.26863455773,44331.542063474655,40345.26863455773,3982.570233345032,1.5255703926086426,0.0 -47600,1.9162962,1.1208702,,,,,,,,,,,,,, -47700,3.772709,1.1967767,,,,,,,,,,,,,, -47800,4.3240952,1.1270055,,,,,,,,,,,,,, -47900,3.7606776,1.1278083,,,,,,,,,,,,,, -48000,,,0.22205196,0.0765498652291105,0.41554552,0.1204128329648474,5348.0,0.22391605,0.0706233623788922,2472.0,40712.370055913925,44835.37041664124,40712.370055913925,4119.21386384964,1.5837607383728027,0.0 -48000,,,,,,,,,,,40712.370055913925,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 996ff941d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -213.90761828422544,0.0,16.132816791534424,1,0,16.132816791534424,30.93976,2472,4.561148010480775,230.04052543640137,31.425632,4.572452038869513,30.812973,5348,4.233661913359144 -346.16198992729187,0.0319299697875976,1456.156590938568,1695,0,1456.156590938568,1.3081437,2472,0.3820608128694168,1802.423523426056,1.6061379,0.4365230673493055,1.737651,5348,0.4402811434971084 -480.506609916687,0.09938645362854,2896.9300112724304,3406,0,2896.9300112724304,0.67142403,2472,0.2084780533382081,3377.6903870105743,0.86962605,0.2598287471562653,1.0091406,5348,0.2794732421290441 -614.574684381485,0.1405558586120605,4338.122195243835,5080,0,4338.122195243835,0.6041575,2472,0.1889383137326589,4953.06637597084,0.8732126,0.2650956794108236,0.9360504,5348,0.2612452571516842 -749.1158313751221,0.1950762271881103,5778.4819252491,6770,0,5778.4819252491,0.5918397,2472,0.1863790546990839,6528.09968495369,0.7836887,0.2392782221424484,0.9076595,5348,0.2530484567037083 -885.0470359325409,0.2474684715270996,7219.33767747879,8473,0,7219.33767747879,0.5595584,2472,0.1730343468811569,8105.016730308533,0.7977896,0.238085031546088,0.8691856,5348,0.2439248095619683 -1019.487206697464,0.3779764175415039,8659.357873678207,10157,0,8659.357873678207,0.5371239,2472,0.1710641236568968,9679.683911561966,0.70206696,0.21623922703828,0.8392556,5348,0.236529345317976 -1152.574282169342,0.4317669868469238,10100.01055407524,11856,0,10100.01055407524,0.5218247,2472,0.1637925781488026,11253.55560541153,0.71111774,0.221663515637457,0.8253648,5348,0.232657829440899 -1286.6929433345797,0.4845023155212402,11540.073407173157,13549,0,11540.073407173157,0.49158484,2472,0.1551195336461316,12827.867657899857,0.66694856,0.2069393366942019,0.79508895,5348,0.223292815972658 -1432.8420944213867,0.5385165214538574,12981.39877486229,15263,0,12981.39877486229,0.47875628,2472,0.152722767249609,14415.474444389343,0.44968167,0.148620218944889,0.77568215,5348,0.2198847234424631 -1572.188811302185,0.5987732410430908,14422.54798579216,16982,0,14422.54798579216,0.46607754,2472,0.1485182702658785,15996.109202861786,0.42352477,0.1388208673010756,0.74860513,5348,0.2146615561369802 -1709.9824848175049,0.6509594917297363,15862.62443113327,18668,0,15862.62443113327,0.4562326,2472,0.1430950785042552,17574.10783100128,0.41050774,0.1349360339469657,0.7495671,5348,0.2105776378925823 -1848.256634473801,0.7045087814331055,17302.69507241249,20356,0,17302.69507241249,0.43282682,2472,0.1374281477870534,19152.5838842392,0.380554,0.1286975115825664,0.72028995,5348,0.2035007771995713 -1984.4048948287964,0.7557229995727539,18742.878882169724,22053,0,18742.878882169724,0.4144221,2472,0.1332845855422176,20729.05018377304,0.3735527,0.1240751796336965,0.6839187,5348,0.1949660638944939 -2124.35813331604,0.811715841293335,20183.42955708504,23726,0,20183.42955708504,0.39957988,2472,0.1274957853472264,22309.68611574173,0.34353802,0.1176566725889529,0.66541904,5348,0.1881884974463442 -2259.734280109405,0.8676271438598633,21623.96768260002,25440,0,21623.96768260002,0.38806707,2472,0.1247943452562305,23885.734208345413,0.38745674,0.1244650828951589,0.6517793,5348,0.1864506598955366 -2396.0211186409,0.9317009449005128,23064.55658864975,27158,0,23064.55658864975,0.37232503,2472,0.1169743870980846,25462.754029750824,0.3255008,0.1100644364580529,0.624412,5348,0.1799820423453083 -2534.6256392002106,0.987617254257202,24505.14197874069,28839,0,24505.14197874069,0.35233474,2472,0.112912071171775,27042.07692527771,0.2880375,0.0992398848024172,0.6031776,5348,0.1722486652442144 -2672.2584941387177,1.0459904670715332,25945.45933794976,30528,0,25945.45933794976,0.33330095,2472,0.1068389088619422,28620.165376901627,0.2754457,0.0915725409944986,0.5681978,5348,0.1625843575311121 -2810.763015270233,1.0972692966461182,27385.520001888275,32240,0,27385.520001888275,0.32057795,2472,0.1023500497633701,30198.85960030556,0.25709972,0.0892764857881137,0.5593565,5348,0.1598810546743002 -2949.174058675766,1.1517837047576904,28825.72840666771,33920,0,28825.72840666771,0.29892173,2472,0.0955050474275384,31777.60964488983,0.25858888,0.0858627059857579,0.5291889,5348,0.1517421821446846 -3086.517218351364,1.2039318084716797,30265.98497748375,35595,0,30265.98497748375,0.28692654,2472,0.0921333252087014,33355.33820748329,0.24284497,0.0817005681548919,0.510084,5348,0.1469824381860837 -3223.1055147647858,1.2620007991790771,31705.91008043289,37298,0,31705.91008043289,0.27283636,2472,0.0863851481729734,34931.98729777336,0.2170262,0.0735402387050613,0.48365703,5348,0.1398959228400127 -3363.3636882305145,1.316806077957153,33146.00746154785,38983,0,33146.00746154785,0.2571431,2472,0.082627505941137,36512.47426152229,0.20192882,0.0694591479663942,0.46267018,5348,0.1323266748409396 -3499.5277755260468,1.382036209106445,34587.25555706024,40699,0,34587.25555706024,0.24477644,2472,0.07754961103325,38090.030962228775,0.1680867,0.0581665522522046,0.44180185,5348,0.1268138679436554 -3638.6739313602448,1.4387574195861816,36027.65841794014,42391,0,36027.65841794014,0.23368566,2472,0.0739138382792029,39669.712934970856,0.16477492,0.0555998242871882,0.42381123,5348,0.1216293192504127 -3773.414123773575,1.5009901523590088,37468.0641913414,44091,0,37468.0641913414,0.2251991,2472,0.0706436739585237,41244.99890422821,0.16973253,0.0573566354231382,0.4103643,5348,0.1181150255365573 -3910.608511686325,1.557424783706665,38908.11180472374,45803,0,38908.11180472374,0.22022197,2472,0.0693640444417362,42822.37716197968,0.14928834,0.0501868458373551,0.40272284,5348,0.1158075634552072 -4047.680432319641,1.617690086364746,40348.411856889725,47502,0,40348.411856889725,0.21897092,2472,0.0687953202120529,44399.88564610481,0.16172755,0.0560497250479024,0.4016892,5348,0.1159041099858076 -4183.648021221161,1.6699771881103516,40751.27304506302,48000,0,40751.27304506302,0.21901312,2472,0.0687953202120529,44938.79483628273,0.1899466,0.059771190902950265,0.40191585,5348,0.11582687276132732 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/measurements.csv deleted file mode 100644 index 20866ccc9..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,18.247444,32.95285,,,,,,,,,,,,,, -1,,,31.425632,4.572452038869513,30.812973,4.233661913359144,5348.0,30.93976,4.561148010480775,2472.0,16.132816791534424,230.04052543640137,16.132816791534424,213.90761828422544,0.0,0.0 -100,3.5315177,5.8408465,,,,,,,,,,,,,, -200,2.1146238,4.804647,,,,,,,,,,,,,, -300,3.9947774,3.6787498,,,,,,,,,,,,,, -400,3.1777792,3.2251887,,,,,,,,,,,,,, -500,2.1979895,2.996802,,,,,,,,,,,,,, -600,2.9020936,2.90779,,,,,,,,,,,,,, -700,1.9310025,2.7190657,,,,,,,,,,,,,, -800,3.5060968,2.608016,,,,,,,,,,,,,, -900,3.0353842,2.5369346,,,,,,,,,,,,,, -1000,3.1228735,2.5233629,,,,,,,,,,,,,, -1100,2.8050804,2.4299245,,,,,,,,,,,,,, -1200,3.2302904,2.53015,,,,,,,,,,,,,, -1300,2.641394,2.3672662,,,,,,,,,,,,,, -1400,2.9814446,2.3419142,,,,,,,,,,,,,, -1500,2.7707453,2.299703,,,,,,,,,,,,,, -1600,2.6587732,2.2976916,,,,,,,,,,,,,, -1695,,,1.6061379,0.4365230673493055,1.737651,0.4402811434971084,5348.0,1.3081437,0.3820608128694168,2472.0,1456.156590938568,1802.423523426056,1456.156590938568,346.16198992729187,0.0319299697875976,0.0 -1700,3.3962991,2.2577517,,,,,,,,,,,,,, -1800,3.3218822,2.3230345,,,,,,,,,,,,,, -1900,2.1197393,2.1731262,,,,,,,,,,,,,, -2000,3.8715842,2.2116108,,,,,,,,,,,,,, -2100,2.9306812,2.1886117,,,,,,,,,,,,,, -2200,3.157576,2.1491828,,,,,,,,,,,,,, -2300,2.504947,2.099112,,,,,,,,,,,,,, -2400,3.0561492,2.1571162,,,,,,,,,,,,,, -2500,3.0273893,2.0917144,,,,,,,,,,,,,, -2600,3.306438,2.1035209,,,,,,,,,,,,,, -2700,5.218917,2.0834007,,,,,,,,,,,,,, -2800,3.2598255,2.1113737,,,,,,,,,,,,,, -2900,3.217375,2.072176,,,,,,,,,,,,,, -3000,3.6211526,2.0995345,,,,,,,,,,,,,, -3100,2.7918441,2.0069907,,,,,,,,,,,,,, -3200,2.0760703,2.0424142,,,,,,,,,,,,,, -3300,3.4293208,2.0408747,,,,,,,,,,,,,, -3400,4.6633267,1.9757094,,,,,,,,,,,,,, -3406,,,0.86962605,0.2598287471562653,1.0091406,0.2794732421290441,5348.0,0.67142403,0.2084780533382081,2472.0,2896.9300112724304,3377.6903870105743,2896.9300112724304,480.506609916687,0.09938645362854,0.0 -3500,2.1038659,1.9123777,,,,,,,,,,,,,, -3600,3.3779411,2.055894,,,,,,,,,,,,,, -3700,2.5580087,1.9692959,,,,,,,,,,,,,, -3800,7.5544896,2.0244596,,,,,,,,,,,,,, -3900,2.882531,1.9754226,,,,,,,,,,,,,, -4000,2.1680543,1.9773935,,,,,,,,,,,,,, -4100,2.0195403,1.9956983,,,,,,,,,,,,,, -4200,2.6503947,2.02712,,,,,,,,,,,,,, -4300,2.1623654,2.001658,,,,,,,,,,,,,, -4400,2.4816663,1.8868053,,,,,,,,,,,,,, -4500,2.2183478,1.9367764,,,,,,,,,,,,,, -4600,3.514346,1.9932392,,,,,,,,,,,,,, -4700,2.695078,1.9110246,,,,,,,,,,,,,, -4800,2.6572447,2.0211096,,,,,,,,,,,,,, -4900,4.463111,2.0300217,,,,,,,,,,,,,, -5000,4.0931263,1.9314598,,,,,,,,,,,,,, -5080,,,0.8732126,0.2650956794108236,0.9360504,0.2612452571516842,5348.0,0.6041575,0.1889383137326589,2472.0,4338.122195243835,4953.06637597084,4338.122195243835,614.574684381485,0.1405558586120605,0.0 -5100,2.4369602,1.9691999,,,,,,,,,,,,,, -5200,1.9215866,2.0148575,,,,,,,,,,,,,, -5300,2.5409338,1.9567719,,,,,,,,,,,,,, -5400,4.2906523,1.9524361,,,,,,,,,,,,,, -5500,3.3008378,1.9937553,,,,,,,,,,,,,, -5600,2.5136847,1.8549353,,,,,,,,,,,,,, -5700,3.1703591,1.8779999,,,,,,,,,,,,,, -5800,1.7713581,1.992112,,,,,,,,,,,,,, -5900,1.7663784,1.9318365,,,,,,,,,,,,,, -6000,2.7719097,1.9124804,,,,,,,,,,,,,, -6100,4.2024164,1.9565815,,,,,,,,,,,,,, -6200,5.840642,1.9996557,,,,,,,,,,,,,, -6300,4.7857075,1.930149,,,,,,,,,,,,,, -6400,3.4370282,1.8669109,,,,,,,,,,,,,, -6500,4.2536535,1.8936902,,,,,,,,,,,,,, -6600,5.16156,1.9342905,,,,,,,,,,,,,, -6700,4.4532557,2.1252773,,,,,,,,,,,,,, -6770,,,0.7836887,0.2392782221424484,0.9076595,0.2530484567037083,5348.0,0.5918397,0.1863790546990839,2472.0,5778.4819252491,6528.09968495369,5778.4819252491,749.1158313751221,0.1950762271881103,0.0 -6800,3.1592593,1.8930132,,,,,,,,,,,,,, -6900,1.8556563,1.8126128,,,,,,,,,,,,,, -7000,3.3867962,1.9742916,,,,,,,,,,,,,, -7100,3.2327373,1.8873962,,,,,,,,,,,,,, -7200,2.3232267,1.8816757,,,,,,,,,,,,,, -7300,2.541116,1.888926,,,,,,,,,,,,,, -7400,4.6173286,1.8953842,,,,,,,,,,,,,, -7500,4.908305,1.9249101,,,,,,,,,,,,,, -7600,2.4442563,1.8113812,,,,,,,,,,,,,, -7700,2.9560888,1.8589172,,,,,,,,,,,,,, -7800,2.2525191,1.871445,,,,,,,,,,,,,, -7900,1.9419221,1.8690977,,,,,,,,,,,,,, -8000,3.029141,1.8917184,,,,,,,,,,,,,, -8100,2.1451964,1.9055595,,,,,,,,,,,,,, -8200,2.6840165,1.9270998,,,,,,,,,,,,,, -8300,2.528118,1.8847412,,,,,,,,,,,,,, -8400,1.5770141,1.8608013,,,,,,,,,,,,,, -8473,,,0.7977896,0.238085031546088,0.8691856,0.2439248095619683,5348.0,0.5595584,0.1730343468811569,2472.0,7219.33767747879,8105.016730308533,7219.33767747879,885.0470359325409,0.2474684715270996,0.0 -8500,2.0897896,1.7973853,,,,,,,,,,,,,, -8600,2.8695498,1.8721919,,,,,,,,,,,,,, -8700,3.768176,1.8186543,,,,,,,,,,,,,, -8800,2.7063646,1.8327714,,,,,,,,,,,,,, -8900,4.49928,1.8767148,,,,,,,,,,,,,, -9000,2.3304021,1.8711445,,,,,,,,,,,,,, -9100,2.2568514,1.8847033,,,,,,,,,,,,,, -9200,2.0019712,1.831101,,,,,,,,,,,,,, -9300,2.0857153,1.8201939,,,,,,,,,,,,,, -9400,2.5579731,1.8190526,,,,,,,,,,,,,, -9500,2.7388191,1.8563118,,,,,,,,,,,,,, -9600,3.020664,1.9356625,,,,,,,,,,,,,, -9700,2.5260196,1.821936,,,,,,,,,,,,,, -9800,2.9519708,1.8488518,,,,,,,,,,,,,, -9900,2.2779238,1.8450897,,,,,,,,,,,,,, -10000,3.1981838,1.845114,,,,,,,,,,,,,, -10100,4.3581376,1.8549018,,,,,,,,,,,,,, -10157,,,0.70206696,0.21623922703828,0.8392556,0.236529345317976,5348.0,0.5371239,0.1710641236568968,2472.0,8659.357873678207,9679.683911561966,8659.357873678207,1019.487206697464,0.3779764175415039,0.0 -10200,3.332055,1.8071493,,,,,,,,,,,,,, -10300,2.4123044,1.7631892,,,,,,,,,,,,,, -10400,4.2533517,1.8309994,,,,,,,,,,,,,, -10500,4.337803,1.7954221,,,,,,,,,,,,,, -10600,2.1242614,1.8167148,,,,,,,,,,,,,, -10700,2.3655248,1.7997211,,,,,,,,,,,,,, -10800,2.3284464,1.8221804,,,,,,,,,,,,,, -10900,2.4491973,1.9007764,,,,,,,,,,,,,, -11000,5.048068,1.7796803,,,,,,,,,,,,,, -11100,4.053665,1.8508265,,,,,,,,,,,,,, -11200,2.3374698,1.816672,,,,,,,,,,,,,, -11300,2.4591856,1.8151459,,,,,,,,,,,,,, -11400,4.54345,1.8405942,,,,,,,,,,,,,, -11500,2.1115837,1.828445,,,,,,,,,,,,,, -11600,4.1101093,1.8189671,,,,,,,,,,,,,, -11700,2.153683,1.8063139,,,,,,,,,,,,,, -11800,3.161433,1.8084974,,,,,,,,,,,,,, -11856,,,0.71111774,0.221663515637457,0.8253648,0.232657829440899,5348.0,0.5218247,0.1637925781488026,2472.0,10100.01055407524,11253.55560541153,10100.01055407524,1152.574282169342,0.4317669868469238,0.0 -11900,2.819387,1.806898,,,,,,,,,,,,,, -12000,2.9892337,1.7787341,,,,,,,,,,,,,, -12100,2.3531494,1.7891217,,,,,,,,,,,,,, -12200,2.773684,1.775401,,,,,,,,,,,,,, -12300,2.0352104,1.8008966,,,,,,,,,,,,,, -12400,2.5095272,1.7779921,,,,,,,,,,,,,, -12500,2.590231,1.8343832,,,,,,,,,,,,,, -12600,4.017258,1.8219818,,,,,,,,,,,,,, -12700,3.8225088,1.7424141,,,,,,,,,,,,,, -12800,2.1507747,1.767242,,,,,,,,,,,,,, -12900,2.2647789,1.7676082,,,,,,,,,,,,,, -13000,2.6403491,1.8280185,,,,,,,,,,,,,, -13100,2.231338,1.809472,,,,,,,,,,,,,, -13200,2.6028779,1.8161453,,,,,,,,,,,,,, -13300,2.5783525,1.7676654,,,,,,,,,,,,,, -13400,2.1461945,1.7526236,,,,,,,,,,,,,, -13500,3.523403,1.8111409,,,,,,,,,,,,,, -13549,,,0.66694856,0.2069393366942019,0.79508895,0.223292815972658,5348.0,0.49158484,0.1551195336461316,2472.0,11540.073407173157,12827.867657899857,11540.073407173157,1286.6929433345797,0.4845023155212402,0.0 -13600,2.6519039,1.7164022,,,,,,,,,,,,,, -13700,2.0720484,1.7612331,,,,,,,,,,,,,, -13800,3.4036222,1.7848504,,,,,,,,,,,,,, -13900,1.957311,1.7726269,,,,,,,,,,,,,, -14000,2.8411767,1.8231503,,,,,,,,,,,,,, -14100,2.2386072,1.7543826,,,,,,,,,,,,,, -14200,1.8241696,1.7291105,,,,,,,,,,,,,, -14300,3.112006,1.7344368,,,,,,,,,,,,,, -14400,3.8959887,1.7148309,,,,,,,,,,,,,, -14500,2.9633362,1.7860337,,,,,,,,,,,,,, -14600,3.231373,1.7391241,,,,,,,,,,,,,, -14700,3.0643516,1.7578833,,,,,,,,,,,,,, -14800,4.59127,1.7252855,,,,,,,,,,,,,, -14900,3.893186,1.7533585,,,,,,,,,,,,,, -15000,2.961725,1.8033378,,,,,,,,,,,,,, -15100,3.2802355,1.7967708,,,,,,,,,,,,,, -15200,3.2444603,1.6981051,,,,,,,,,,,,,, -15263,,,0.44968167,0.148620218944889,0.77568215,0.2198847234424631,5348.0,0.47875628,0.152722767249609,2472.0,12981.39877486229,14415.474444389343,12981.39877486229,1432.8420944213867,0.5385165214538574,0.0 -15300,3.4915588,1.768122,,,,,,,,,,,,,, -15400,3.0685937,1.8042971,,,,,,,,,,,,,, -15500,4.1440415,1.693449,,,,,,,,,,,,,, -15600,2.912498,1.664395,,,,,,,,,,,,,, -15700,2.6758945,1.7780653,,,,,,,,,,,,,, -15800,6.1510653,1.7838094,,,,,,,,,,,,,, -15900,2.5304737,1.695547,,,,,,,,,,,,,, -16000,1.7402881,1.7570461,,,,,,,,,,,,,, -16100,3.8882043,1.7291623,,,,,,,,,,,,,, -16200,2.678357,1.6735934,,,,,,,,,,,,,, -16300,5.708896,1.7320143,,,,,,,,,,,,,, -16400,2.9581747,1.728006,,,,,,,,,,,,,, -16500,1.757109,1.7320579,,,,,,,,,,,,,, -16600,2.466834,1.762132,,,,,,,,,,,,,, -16700,2.3116543,1.7247484,,,,,,,,,,,,,, -16800,2.7232885,1.8015761,,,,,,,,,,,,,, -16900,2.4042792,1.7334156,,,,,,,,,,,,,, -16982,,,0.42352477,0.1388208673010756,0.74860513,0.2146615561369802,5348.0,0.46607754,0.1485182702658785,2472.0,14422.54798579216,15996.109202861786,14422.54798579216,1572.188811302185,0.5987732410430908,0.0 -17000,2.5942352,1.6984206,,,,,,,,,,,,,, -17100,3.1081839,1.7719833,,,,,,,,,,,,,, -17200,2.4454525,1.7188979,,,,,,,,,,,,,, -17300,2.778558,1.7124916,,,,,,,,,,,,,, -17400,3.0894449,1.798723,,,,,,,,,,,,,, -17500,2.5188699,1.6867694,,,,,,,,,,,,,, -17600,2.5484903,1.698153,,,,,,,,,,,,,, -17700,1.8413215,1.7072871,,,,,,,,,,,,,, -17800,4.5340605,1.6872706,,,,,,,,,,,,,, -17900,2.7440262,1.6532001,,,,,,,,,,,,,, -18000,3.8840218,1.7635453,,,,,,,,,,,,,, -18100,1.9321303,1.7138861,,,,,,,,,,,,,, -18200,2.1348119,1.7644696,,,,,,,,,,,,,, -18300,3.5966487,1.7549093,,,,,,,,,,,,,, -18400,2.9235368,1.6889422,,,,,,,,,,,,,, -18500,2.3448312,1.652856,,,,,,,,,,,,,, -18600,3.4878843,1.7214494,,,,,,,,,,,,,, -18668,,,0.41050774,0.1349360339469657,0.7495671,0.2105776378925823,5348.0,0.4562326,0.1430950785042552,2472.0,15862.62443113327,17574.10783100128,15862.62443113327,1709.9824848175049,0.6509594917297363,0.0 -18700,2.3996112,1.6118033,,,,,,,,,,,,,, -18800,2.3251784,1.6845995,,,,,,,,,,,,,, -18900,3.2093701,1.7348999,,,,,,,,,,,,,, -19000,3.9433615,1.6923264,,,,,,,,,,,,,, -19100,2.8902678,1.6958157,,,,,,,,,,,,,, -19200,3.029494,1.6112983,,,,,,,,,,,,,, -19300,3.587058,1.6922972,,,,,,,,,,,,,, -19400,2.4188962,1.6569389,,,,,,,,,,,,,, -19500,2.3279886,1.6823515,,,,,,,,,,,,,, -19600,2.499632,1.6722757,,,,,,,,,,,,,, -19700,1.8594,1.681955,,,,,,,,,,,,,, -19800,2.832615,1.7135221,,,,,,,,,,,,,, -19900,3.94851,1.6714543,,,,,,,,,,,,,, -20000,1.9961481,1.6657193,,,,,,,,,,,,,, -20100,1.9247266,1.6606365,,,,,,,,,,,,,, -20200,4.383903,1.7239866,,,,,,,,,,,,,, -20300,1.8030374,1.661417,,,,,,,,,,,,,, -20356,,,0.380554,0.1286975115825664,0.72028995,0.2035007771995713,5348.0,0.43282682,0.1374281477870534,2472.0,17302.69507241249,19152.5838842392,17302.69507241249,1848.256634473801,0.7045087814331055,0.0 -20400,3.149635,1.755057,,,,,,,,,,,,,, -20500,4.0476646,1.7117938,,,,,,,,,,,,,, -20600,1.5811177,1.6346489,,,,,,,,,,,,,, -20700,3.128411,1.6711094,,,,,,,,,,,,,, -20800,5.6306663,1.6573794,,,,,,,,,,,,,, -20900,3.7282426,1.6390913,,,,,,,,,,,,,, -21000,1.8571314,1.739455,,,,,,,,,,,,,, -21100,2.1018279,1.6371826,,,,,,,,,,,,,, -21200,2.2121794,1.6977276,,,,,,,,,,,,,, -21300,3.1324792,1.6586969,,,,,,,,,,,,,, -21400,2.975423,1.6158677,,,,,,,,,,,,,, -21500,2.460034,1.6584812,,,,,,,,,,,,,, -21600,2.8959808,1.7086811,,,,,,,,,,,,,, -21700,3.7970428,1.629832,,,,,,,,,,,,,, -21800,2.0608723,1.5913968,,,,,,,,,,,,,, -21900,2.4967134,1.5916893,,,,,,,,,,,,,, -22000,3.8157415,1.6273068,,,,,,,,,,,,,, -22053,,,0.3735527,0.1240751796336965,0.6839187,0.1949660638944939,5348.0,0.4144221,0.1332845855422176,2472.0,18742.878882169724,20729.05018377304,18742.878882169724,1984.4048948287964,0.7557229995727539,0.0 -22100,2.4169261,1.643593,,,,,,,,,,,,,, -22200,2.1313012,1.6335803,,,,,,,,,,,,,, -22300,2.5641901,1.5707886,,,,,,,,,,,,,, -22400,3.8281517,1.6579094,,,,,,,,,,,,,, -22500,3.6398003,1.6041946,,,,,,,,,,,,,, -22600,2.5088892,1.6142662,,,,,,,,,,,,,, -22700,2.2097478,1.577857,,,,,,,,,,,,,, -22800,2.594825,1.634042,,,,,,,,,,,,,, -22900,2.6209078,1.6592327,,,,,,,,,,,,,, -23000,2.2503269,1.6064298,,,,,,,,,,,,,, -23100,2.4230983,1.6382723,,,,,,,,,,,,,, -23200,2.309001,1.6692125,,,,,,,,,,,,,, -23300,4.1650524,1.647205,,,,,,,,,,,,,, -23400,4.3542013,1.5884123,,,,,,,,,,,,,, -23500,1.7521614,1.6771526,,,,,,,,,,,,,, -23600,4.064581,1.6285565,,,,,,,,,,,,,, -23700,1.9997259,1.6804123,,,,,,,,,,,,,, -23726,,,0.34353802,0.1176566725889529,0.66541904,0.1881884974463442,5348.0,0.39957988,0.1274957853472264,2472.0,20183.42955708504,22309.68611574173,20183.42955708504,2124.35813331604,0.811715841293335,0.0 -23800,1.5127662,1.5457237,,,,,,,,,,,,,, -23900,3.2866862,1.6699529,,,,,,,,,,,,,, -24000,1.7319868,1.6341858,,,,,,,,,,,,,, -24100,2.3431149,1.6365867,,,,,,,,,,,,,, -24200,4.0118146,1.6707187,,,,,,,,,,,,,, -24300,1.6882001,1.5290383,,,,,,,,,,,,,, -24400,2.7663465,1.5895983,,,,,,,,,,,,,, -24500,2.6164398,1.5098542,,,,,,,,,,,,,, -24600,2.9352348,1.6143634,,,,,,,,,,,,,, -24700,2.578639,1.5797042,,,,,,,,,,,,,, -24800,2.394863,1.5267884,,,,,,,,,,,,,, -24900,1.9244382,1.5429162,,,,,,,,,,,,,, -25000,2.4053805,1.5544333,,,,,,,,,,,,,, -25100,2.2039998,1.5785351,,,,,,,,,,,,,, -25200,2.9595685,1.6914463,,,,,,,,,,,,,, -25300,2.2389185,1.607158,,,,,,,,,,,,,, -25400,2.63892,1.5720683,,,,,,,,,,,,,, -25440,,,0.38745674,0.1244650828951589,0.6517793,0.1864506598955366,5348.0,0.38806707,0.1247943452562305,2472.0,21623.96768260002,23885.734208345413,21623.96768260002,2259.734280109405,0.8676271438598633,0.0 -25500,2.673815,1.5790981,,,,,,,,,,,,,, -25600,2.4125192,1.5819491,,,,,,,,,,,,,, -25700,2.893203,1.5645862,,,,,,,,,,,,,, -25800,2.504677,1.4904021,,,,,,,,,,,,,, -25900,2.375558,1.5202312,,,,,,,,,,,,,, -26000,3.2623215,1.5412047,,,,,,,,,,,,,, -26100,2.6288612,1.549158,,,,,,,,,,,,,, -26200,2.4631727,1.5391152,,,,,,,,,,,,,, -26300,2.714874,1.5416135,,,,,,,,,,,,,, -26400,2.4065232,1.5765911,,,,,,,,,,,,,, -26500,3.1091897,1.6105367,,,,,,,,,,,,,, -26600,1.7368131,1.5259628,,,,,,,,,,,,,, -26700,2.417432,1.4604648,,,,,,,,,,,,,, -26800,4.724865,1.5415456,,,,,,,,,,,,,, -26900,1.8374115,1.4703588,,,,,,,,,,,,,, -27000,3.1162663,1.5744699,,,,,,,,,,,,,, -27100,1.3994838,1.5274059,,,,,,,,,,,,,, -27158,,,0.3255008,0.1100644364580529,0.624412,0.1799820423453083,5348.0,0.37232503,0.1169743870980846,2472.0,23064.55658864975,25462.754029750824,23064.55658864975,2396.0211186409,0.9317009449005128,0.0 -27200,2.8547103,1.621658,,,,,,,,,,,,,, -27300,2.9193263,1.4765273,,,,,,,,,,,,,, -27400,3.1795166,1.6272272,,,,,,,,,,,,,, -27500,2.2891338,1.5593195,,,,,,,,,,,,,, -27600,3.1132512,1.4945633,,,,,,,,,,,,,, -27700,2.3645496,1.618049,,,,,,,,,,,,,, -27800,2.6613872,1.5203477,,,,,,,,,,,,,, -27900,3.3515947,1.5443022,,,,,,,,,,,,,, -28000,1.7054017,1.4713618,,,,,,,,,,,,,, -28100,2.621059,1.5170667,,,,,,,,,,,,,, -28200,2.3681104,1.4997667,,,,,,,,,,,,,, -28300,2.0521472,1.5128437,,,,,,,,,,,,,, -28400,2.029702,1.4213213,,,,,,,,,,,,,, -28500,2.3189917,1.5320743,,,,,,,,,,,,,, -28600,2.6161273,1.4991708,,,,,,,,,,,,,, -28700,2.4924018,1.4706587,,,,,,,,,,,,,, -28800,2.308636,1.4803672,,,,,,,,,,,,,, -28839,,,0.2880375,0.0992398848024172,0.6031776,0.1722486652442144,5348.0,0.35233474,0.112912071171775,2472.0,24505.14197874069,27042.07692527771,24505.14197874069,2534.6256392002106,0.987617254257202,0.0 -28900,1.7425457,1.4734496,,,,,,,,,,,,,, -29000,2.7377062,1.4937557,,,,,,,,,,,,,, -29100,2.7147624,1.5112282,,,,,,,,,,,,,, -29200,3.6786482,1.5302932,,,,,,,,,,,,,, -29300,3.487076,1.4345666,,,,,,,,,,,,,, -29400,2.3469598,1.531415,,,,,,,,,,,,,, -29500,1.8800714,1.5286078,,,,,,,,,,,,,, -29600,2.2690506,1.4844563,,,,,,,,,,,,,, -29700,3.9965632,1.497186,,,,,,,,,,,,,, -29800,2.0652125,1.4233853,,,,,,,,,,,,,, -29900,3.0274189,1.4704689,,,,,,,,,,,,,, -30000,2.2271605,1.4891416,,,,,,,,,,,,,, -30100,2.294499,1.4955796,,,,,,,,,,,,,, -30200,4.444398,1.4818051,,,,,,,,,,,,,, -30300,2.228873,1.4988648,,,,,,,,,,,,,, -30400,2.034375,1.4832672,,,,,,,,,,,,,, -30500,2.7675722,1.4616921,,,,,,,,,,,,,, -30528,,,0.2754457,0.0915725409944986,0.5681978,0.1625843575311121,5348.0,0.33330095,0.1068389088619422,2472.0,25945.45933794976,28620.165376901627,25945.45933794976,2672.2584941387177,1.0459904670715332,0.0 -30600,1.9443817,1.4583268,,,,,,,,,,,,,, -30700,3.6476398,1.5440091,,,,,,,,,,,,,, -30800,2.5865035,1.4888704,,,,,,,,,,,,,, -30900,1.9077677,1.4271045,,,,,,,,,,,,,, -31000,2.7209997,1.4392506,,,,,,,,,,,,,, -31100,1.6903799,1.4437488,,,,,,,,,,,,,, -31200,1.9044259,1.3985109,,,,,,,,,,,,,, -31300,1.5266602,1.3719612,,,,,,,,,,,,,, -31400,2.9308898,1.4499118,,,,,,,,,,,,,, -31500,3.1407938,1.4632448,,,,,,,,,,,,,, -31600,2.9443817,1.4318008,,,,,,,,,,,,,, -31700,1.7432835,1.4276782,,,,,,,,,,,,,, -31800,1.9061098,1.4223574,,,,,,,,,,,,,, -31900,1.7063245,1.4004332,,,,,,,,,,,,,, -32000,1.8996546,1.4386758,,,,,,,,,,,,,, -32100,1.550279,1.4169257,,,,,,,,,,,,,, -32200,2.2428408,1.3669933,,,,,,,,,,,,,, -32240,,,0.25709972,0.0892764857881137,0.5593565,0.1598810546743002,5348.0,0.32057795,0.1023500497633701,2472.0,27385.520001888275,30198.85960030556,27385.520001888275,2810.763015270233,1.0972692966461182,0.0 -32300,2.3380485,1.4854771,,,,,,,,,,,,,, -32400,3.206281,1.4351276,,,,,,,,,,,,,, -32500,1.6962299,1.3882432,,,,,,,,,,,,,, -32600,3.7782793,1.4472717,,,,,,,,,,,,,, -32700,2.5019104,1.4735762,,,,,,,,,,,,,, -32800,1.9420483,1.4340723,,,,,,,,,,,,,, -32900,1.7913265,1.3756226,,,,,,,,,,,,,, -33000,1.9130849,1.4763954,,,,,,,,,,,,,, -33100,1.6083633,1.3819803,,,,,,,,,,,,,, -33200,1.5657998,1.3933238,,,,,,,,,,,,,, -33300,2.9190028,1.4126074,,,,,,,,,,,,,, -33400,1.8573273,1.3794198,,,,,,,,,,,,,, -33500,2.264383,1.4246117,,,,,,,,,,,,,, -33600,2.5386033,1.4375837,,,,,,,,,,,,,, -33700,2.4032452,1.4069515,,,,,,,,,,,,,, -33800,2.165796,1.4764074,,,,,,,,,,,,,, -33900,2.5079687,1.4155014,,,,,,,,,,,,,, -33920,,,0.25858888,0.0858627059857579,0.5291889,0.1517421821446846,5348.0,0.29892173,0.0955050474275384,2472.0,28825.72840666771,31777.60964488983,28825.72840666771,2949.174058675766,1.1517837047576904,0.0 -34000,1.803501,1.3365316,,,,,,,,,,,,,, -34100,2.524824,1.4233028,,,,,,,,,,,,,, -34200,3.6854336,1.3775244,,,,,,,,,,,,,, -34300,1.4589208,1.354626,,,,,,,,,,,,,, -34400,1.8405361,1.3391262,,,,,,,,,,,,,, -34500,2.128889,1.3403313,,,,,,,,,,,,,, -34600,2.0808694,1.4218831,,,,,,,,,,,,,, -34700,3.370585,1.3909945,,,,,,,,,,,,,, -34800,3.3271747,1.3887566,,,,,,,,,,,,,, -34900,2.0551105,1.3629031,,,,,,,,,,,,,, -35000,1.5544634,1.363976,,,,,,,,,,,,,, -35100,2.1629808,1.3812579,,,,,,,,,,,,,, -35200,1.6073563,1.3430405,,,,,,,,,,,,,, -35300,1.8318212,1.4071181,,,,,,,,,,,,,, -35400,2.0147467,1.3196517,,,,,,,,,,,,,, -35500,2.0267599,1.4209911,,,,,,,,,,,,,, -35595,,,0.24284497,0.0817005681548919,0.510084,0.1469824381860837,5348.0,0.28692654,0.0921333252087014,2472.0,30265.98497748375,33355.33820748329,30265.98497748375,3086.517218351364,1.2039318084716797,0.0 -35600,4.6479206,1.4290167,,,,,,,,,,,,,, -35700,2.0481243,1.3955529,,,,,,,,,,,,,, -35800,1.5550848,1.3142507,,,,,,,,,,,,,, -35900,3.77275,1.333153,,,,,,,,,,,,,, -36000,2.4659464,1.3652686,,,,,,,,,,,,,, -36100,4.3053136,1.3881959,,,,,,,,,,,,,, -36200,3.0675242,1.3647822,,,,,,,,,,,,,, -36300,3.8856647,1.294168,,,,,,,,,,,,,, -36400,1.8593186,1.3005764,,,,,,,,,,,,,, -36500,1.8042437,1.3448571,,,,,,,,,,,,,, -36600,2.1018965,1.3137536,,,,,,,,,,,,,, -36700,2.4251325,1.3256145,,,,,,,,,,,,,, -36800,2.395402,1.2923002,,,,,,,,,,,,,, -36900,2.2214644,1.3098922,,,,,,,,,,,,,, -37000,2.0132906,1.3433845,,,,,,,,,,,,,, -37100,1.8293493,1.32824,,,,,,,,,,,,,, -37200,2.2175138,1.3124794,,,,,,,,,,,,,, -37298,,,0.2170262,0.0735402387050613,0.48365703,0.1398959228400127,5348.0,0.27283636,0.0863851481729734,2472.0,31705.91008043289,34931.98729777336,31705.91008043289,3223.1055147647858,1.2620007991790771,0.0 -37300,1.6583403,1.3208362,,,,,,,,,,,,,, -37400,1.981871,1.3181989,,,,,,,,,,,,,, -37500,1.9345659,1.3195175,,,,,,,,,,,,,, -37600,1.921817,1.3035173,,,,,,,,,,,,,, -37700,3.2458148,1.3533695,,,,,,,,,,,,,, -37800,2.8535035,1.2851614,,,,,,,,,,,,,, -37900,1.9795971,1.3112708,,,,,,,,,,,,,, -38000,4.048402,1.3161064,,,,,,,,,,,,,, -38100,3.333061,1.3270485,,,,,,,,,,,,,, -38200,4.173828,1.2882112,,,,,,,,,,,,,, -38300,2.2007718,1.2910392,,,,,,,,,,,,,, -38400,2.7000186,1.2628471,,,,,,,,,,,,,, -38500,2.3924992,1.2687542,,,,,,,,,,,,,, -38600,2.1487193,1.2579798,,,,,,,,,,,,,, -38700,2.4361358,1.3376555,,,,,,,,,,,,,, -38800,2.0879662,1.2425103,,,,,,,,,,,,,, -38900,2.9983006,1.2226301,,,,,,,,,,,,,, -38983,,,0.20192882,0.0694591479663942,0.46267018,0.1323266748409396,5348.0,0.2571431,0.082627505941137,2472.0,33146.00746154785,36512.47426152229,33146.00746154785,3363.3636882305145,1.316806077957153,0.0 -39000,2.7052531,1.2776049,,,,,,,,,,,,,, -39100,1.9909062,1.2853742,,,,,,,,,,,,,, -39200,1.6402584,1.2529125,,,,,,,,,,,,,, -39300,1.9136307,1.2795306,,,,,,,,,,,,,, -39400,4.2216964,1.2584823,,,,,,,,,,,,,, -39500,2.3799772,1.2192754,,,,,,,,,,,,,, -39600,1.749375,1.2484398,,,,,,,,,,,,,, -39700,2.0156212,1.2572043,,,,,,,,,,,,,, -39800,2.5676723,1.2388021,,,,,,,,,,,,,, -39900,4.6412196,1.2814853,,,,,,,,,,,,,, -40000,3.0921073,1.25174,,,,,,,,,,,,,, -40100,2.5650713,1.2741573,,,,,,,,,,,,,, -40200,2.1089897,1.2367108,,,,,,,,,,,,,, -40300,2.0584667,1.2354567,,,,,,,,,,,,,, -40400,1.4796705,1.2124944,,,,,,,,,,,,,, -40500,2.706374,1.2351412,,,,,,,,,,,,,, -40600,2.7672317,1.2417854,,,,,,,,,,,,,, -40699,,,0.1680867,0.0581665522522046,0.44180185,0.1268138679436554,5348.0,0.24477644,0.07754961103325,2472.0,34587.25555706024,38090.030962228775,34587.25555706024,3499.5277755260468,1.382036209106445,0.0 -40700,2.4254544,1.2627875,,,,,,,,,,,,,, -40800,1.6230139,1.264276,,,,,,,,,,,,,, -40900,1.3434925,1.2611738,,,,,,,,,,,,,, -41000,2.061148,1.2421429,,,,,,,,,,,,,, -41100,1.8388159,1.1955191,,,,,,,,,,,,,, -41200,2.8151927,1.1915779,,,,,,,,,,,,,, -41300,1.6051546,1.1903775,,,,,,,,,,,,,, -41400,1.7258302,1.161105,,,,,,,,,,,,,, -41500,2.0572011,1.2248152,,,,,,,,,,,,,, -41600,2.88933,1.2063303,,,,,,,,,,,,,, -41700,2.6973126,1.2172699,,,,,,,,,,,,,, -41800,2.2869258,1.1566265,,,,,,,,,,,,,, -41900,1.9202508,1.1723942,,,,,,,,,,,,,, -42000,2.2814357,1.2260612,,,,,,,,,,,,,, -42100,3.070659,1.1699406,,,,,,,,,,,,,, -42200,2.4318035,1.1788834,,,,,,,,,,,,,, -42300,1.6646127,1.1684439,,,,,,,,,,,,,, -42391,,,0.16477492,0.0555998242871882,0.42381123,0.1216293192504127,5348.0,0.23368566,0.0739138382792029,2472.0,36027.65841794014,39669.712934970856,36027.65841794014,3638.6739313602448,1.4387574195861816,0.0 -42400,2.333245,1.2046101,,,,,,,,,,,,,, -42500,2.0002878,1.207506,,,,,,,,,,,,,, -42600,1.828194,1.203848,,,,,,,,,,,,,, -42700,1.7823392,1.2176956,,,,,,,,,,,,,, -42800,4.579115,1.1795963,,,,,,,,,,,,,, -42900,3.506455,1.1846193,,,,,,,,,,,,,, -43000,2.2270374,1.2534238,,,,,,,,,,,,,, -43100,3.774662,1.1397656,,,,,,,,,,,,,, -43200,2.0661201,1.1597446,,,,,,,,,,,,,, -43300,2.3288524,1.1459851,,,,,,,,,,,,,, -43400,3.0310323,1.1694762,,,,,,,,,,,,,, -43500,3.8189058,1.1123825,,,,,,,,,,,,,, -43600,3.3978446,1.1567501,,,,,,,,,,,,,, -43700,2.4953117,1.1918069,,,,,,,,,,,,,, -43800,2.0807514,1.1649138,,,,,,,,,,,,,, -43900,3.1200938,1.1809022,,,,,,,,,,,,,, -44000,2.148786,1.1496104,,,,,,,,,,,,,, -44091,,,0.16973253,0.0573566354231382,0.4103643,0.1181150255365573,5348.0,0.2251991,0.0706436739585237,2472.0,37468.0641913414,41244.99890422821,37468.0641913414,3773.414123773575,1.5009901523590088,0.0 -44100,1.5301374,1.1292702,,,,,,,,,,,,,, -44200,2.1899457,1.1517947,,,,,,,,,,,,,, -44300,1.587516,1.1270127,,,,,,,,,,,,,, -44400,1.7888733,1.1546263,,,,,,,,,,,,,, -44500,1.4359537,1.19162,,,,,,,,,,,,,, -44600,1.7980433,1.1296536,,,,,,,,,,,,,, -44700,2.9697208,1.1458184,,,,,,,,,,,,,, -44800,2.6208072,1.1980053,,,,,,,,,,,,,, -44900,1.7613006,1.1742103,,,,,,,,,,,,,, -45000,1.8883983,1.1380024,,,,,,,,,,,,,, -45100,2.3087592,1.1320076,,,,,,,,,,,,,, -45200,2.212586,1.1663024,,,,,,,,,,,,,, -45300,1.5621121,1.1090479,,,,,,,,,,,,,, -45400,2.968442,1.0455185,,,,,,,,,,,,,, -45500,1.5409539,1.2208551,,,,,,,,,,,,,, -45600,2.1638653,1.0734013,,,,,,,,,,,,,, -45700,2.6828632,1.1531478,,,,,,,,,,,,,, -45800,1.7893642,1.1065985,,,,,,,,,,,,,, -45803,,,0.14928834,0.0501868458373551,0.40272284,0.1158075634552072,5348.0,0.22022197,0.0693640444417362,2472.0,38908.11180472374,42822.37716197968,38908.11180472374,3910.608511686325,1.557424783706665,0.0 -45900,1.5219285,1.134433,,,,,,,,,,,,,, -46000,2.236941,1.1515313,,,,,,,,,,,,,, -46100,2.4547973,1.1242701,,,,,,,,,,,,,, -46200,2.9817553,1.1488643,,,,,,,,,,,,,, -46300,1.6173675,1.1553041,,,,,,,,,,,,,, -46400,3.3538098,1.1566691,,,,,,,,,,,,,, -46500,1.8603265,1.0926927,,,,,,,,,,,,,, -46600,3.5359473,1.1429493,,,,,,,,,,,,,, -46700,2.2508612,1.1663966,,,,,,,,,,,,,, -46800,1.6685958,1.1572747,,,,,,,,,,,,,, -46900,2.4390466,1.1687373,,,,,,,,,,,,,, -47000,1.776143,1.1422603,,,,,,,,,,,,,, -47100,1.7149518,1.1028941,,,,,,,,,,,,,, -47200,7.1024995,1.1286876,,,,,,,,,,,,,, -47300,2.0462484,1.1177047,,,,,,,,,,,,,, -47400,3.1724942,1.1860728,,,,,,,,,,,,,, -47500,1.7971936,1.0728277,,,,,,,,,,,,,, -47502,,,0.16172755,0.0560497250479024,0.4016892,0.1159041099858076,5348.0,0.21897092,0.0687953202120529,2472.0,40348.411856889725,44399.88564610481,40348.411856889725,4047.680432319641,1.617690086364746,0.0 -47600,2.357308,1.1400362,,,,,,,,,,,,,, -47700,1.9470241,1.1499189,,,,,,,,,,,,,, -47800,3.1911318,1.1396846,,,,,,,,,,,,,, -47900,2.583631,1.1391224,,,,,,,,,,,,,, -48000,,,0.1899466,0.0597711909029502,0.40191585,0.1158268727613273,5348.0,0.21901312,0.0687953202120529,2472.0,40751.27304506302,44938.79483628273,40751.27304506302,4183.648021221161,1.6699771881103516,0.0 -48000,,,,,,,,,,,40751.27304506302,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/eval_measurements.csv deleted file mode 100644 index b655511da..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -224.78997611999512,0.0,16.178600549697876,1,0,16.178600549697876,30.939766,2472,4.561066764162249,240.9686577320099,30.994576,4.707897006283109,30.812992,5348,4.233575021481603 -358.0138473510742,0.02854585647583,1456.1483154296875,1692,0,1456.1483154296875,2.0015655,2472,0.4741332033392237,1814.265554189682,2.0471356,0.4961872165120053,2.49202,5348,0.547013333075876 -493.33926486968994,0.0849425792694091,2896.421770811081,3409,0,2896.421770811081,0.5924073,2472,0.1844291430544553,3390.00124835968,0.7372342,0.2299690858884231,0.913854,5348,0.2588991764580939 -628.9697065353394,0.134153127670288,4336.899610280991,5098,0,4336.899610280991,0.48912373,2472,0.159324030629862,4966.236393213272,0.63931066,0.2019689524050605,0.7830972,5348,0.2235824555644593 -762.1963560581207,0.1938681602478027,5777.278906345367,6781,0,5777.278906345367,0.44423807,2472,0.1440903459062011,6539.980969905853,0.7021821,0.2195984392781661,0.74043995,5348,0.2126533883004914 -895.7563931941986,0.2485561370849609,7217.765897512436,8491,0,7217.765897512436,0.44064087,2472,0.141490463713363,8114.162263154983,0.5751879,0.1796062016228983,0.7267631,5348,0.206619230137965 -1029.8109738826752,0.3024752140045166,8658.173372507095,10176,0,8658.173372507095,0.39185888,2472,0.1262770905693336,9688.754879951475,0.53614444,0.1739231810375223,0.6622156,5348,0.1904380316093341 -1165.5521621704102,0.3547070026397705,10098.086151361464,11874,0,10098.086151361464,0.38414696,2472,0.1226210062356549,11264.540845394136,0.4290752,0.1424052505581408,0.65312153,5348,0.1863830773241163 -1304.4409563541412,0.4097585678100586,11538.334949493408,13566,0,11538.334949493408,0.37574464,2472,0.1222960209615501,12843.813130378723,0.5048986,0.1636252175721682,0.636037,5348,0.1803971924268901 -1441.1059172153473,0.4590444564819336,12978.819166183472,15256,0,12978.819166183472,0.36284128,2472,0.1173196839518209,14421.089455366136,0.45856845,0.1502044088176352,0.6351292,5348,0.1813336937737142 -1576.9730966091156,0.5129210948944092,14418.86126089096,16959,0,14418.86126089096,0.3535885,2472,0.1136026648792476,15997.133904457092,0.44045535,0.1457255361385011,0.6102114,5348,0.1746237098969848 -1711.5768485069275,0.5694046020507812,15859.378022432327,18655,0,15859.378022432327,0.33918792,2472,0.1079560457416773,17572.3883395195,0.40981236,0.1340964840556009,0.5838607,5348,0.1677399422651747 -1844.674907445908,0.621281623840332,17299.322156190872,20349,0,17299.322156190872,0.33319625,2472,0.1072654520342047,19145.56060743332,0.42425936,0.14089816744684,0.5756924,5348,0.1647373451635015 -1981.468006372452,0.6765701770782471,18739.717533826828,22051,0,18739.717533826828,0.32314736,2472,0.1023906729226332,20722.88533854485,0.4327474,0.1370816180418829,0.5562388,5348,0.1593017754906977 -2117.053550004959,0.740778923034668,20179.681025505062,23736,0,20179.681025505062,0.3138135,2472,0.100197022322426,22298.57732820511,0.36796176,0.1241607220686649,0.54934376,5348,0.1599776012049007 -2252.23565363884,0.7953715324401855,21619.953918218613,25430,0,21619.953918218613,0.30784765,2472,0.0975362053906932,23874.16535425186,0.39533722,0.1345729867718728,0.53948,5348,0.1550923467565193 -2388.6867899894714,0.8549208641052246,23060.258680820465,27147,0,23060.258680820465,0.2963604,2472,0.0940629252736985,25451.05943512917,0.3278709,0.1105115369186893,0.52832973,5348,0.1517132181855045 -2523.680060863495,0.9110238552093506,24500.23553586006,28814,0,24500.23553586006,0.2814619,2472,0.0896553124936526,27026.162631988525,0.33591753,0.1122769309086795,0.49735403,5348,0.1428309373702656 -2659.6065866947174,0.9619748592376708,25940.209047079086,30496,0,25940.209047079086,0.2728875,2472,0.0881725671805496,28602.19198703766,0.32313046,0.1091377633158311,0.49119592,5348,0.141112409125578 -2796.5278537273407,1.02350115776062,27380.473700761795,32213,0,27380.473700761795,0.26595923,2472,0.0854101923506591,30179.518819332123,0.32220346,0.1095048590651246,0.48185685,5348,0.1390752773299091 -2933.5398259162903,1.0806150436401367,28821.714812994003,33892,0,28821.714812994003,0.25858718,2472,0.081997846972559,31757.905738592148,0.25985754,0.0887791944037007,0.46532407,5348,0.1326742423511011 -3068.84513258934,1.1369514465332031,30261.89965081215,35575,0,30261.89965081215,0.25288898,2472,0.0821197164503483,33333.53277087212,0.23919861,0.0823648799729387,0.45785353,5348,0.1310232966778338 -3203.287886619568,1.198265790939331,31702.451380491257,37279,0,31702.451380491257,0.24081011,2472,0.0775699226128816,34908.66898369789,0.26334742,0.0912923553024017,0.44023106,5348,0.1263987178620736 -3339.672556877136,1.260697364807129,33142.61997103691,38963,0,33142.61997103691,0.23291391,2472,0.0753356488534113,36485.36349272728,0.23935167,0.0811849127280522,0.4321774,5348,0.1247188082296262 -3474.866818666458,1.3295516967773438,34583.59595036507,40672,0,34583.59595036507,0.22748674,2472,0.0735888530050982,38061.68483424187,0.24396911,0.0813538117108663,0.42279685,5348,0.1221410158625949 -3613.8211603164673,1.3897809982299805,36023.78817820549,42350,0,36023.78817820549,0.22420189,2472,0.0727154550809416,39640.9696969986,0.20017296,0.0698609234942493,0.41711703,5348,0.1192253106384622 -3748.6010749340057,1.4484024047851562,37464.233345746994,44022,0,37464.233345746994,0.22245006,2472,0.0720045497938374,41216.33184504509,0.19830093,0.0699378417169759,0.41334906,5348,0.118279154638578 -3884.234937906265,1.505363941192627,38905.48064374924,45723,0,38905.48064374924,0.22004013,2472,0.0710499055511547,42793.34859013557,0.2161998,0.0755508978966493,0.4112349,5348,0.1175164370468347 -4029.571016550064,1.5600640773773191,40345.91604781151,47391,0,40345.91604781151,0.2199478,2472,0.0712733329271017,44379.25354409218,0.14322816,0.0512173187565687,0.4108572,5348,0.1175357463529548 -4165.893843173981,1.620795726776123,40835.422513484955,48000,0,40835.422513484955,0.21999335,2472,0.07125302134747019,45005.17563033104,0.13066882,0.04626666252369263,0.41104025,5348,0.1175357463529548 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/measurements.csv deleted file mode 100644 index ce457c209..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,16.31417,32.8944,,,,,,,,,,,,,, -1,,,30.994576,4.707897006283109,30.812992,4.233575021481603,5348.0,30.939766,4.561066764162249,2472.0,16.178600549697876,240.9686577320099,16.178600549697876,224.78997611999512,0.0,0.0 -100,1.0471311,6.1214423,,,,,,,,,,,,,, -200,0.36795318,5.8258915,,,,,,,,,,,,,, -300,0.9897778,5.686956,,,,,,,,,,,,,, -400,1.230101,5.078728,,,,,,,,,,,,,, -500,0.8466425,4.2187767,,,,,,,,,,,,,, -600,1.4701293,3.733374,,,,,,,,,,,,,, -700,2.8314683,3.3242083,,,,,,,,,,,,,, -800,1.5890806,3.0683641,,,,,,,,,,,,,, -900,3.211487,2.941021,,,,,,,,,,,,,, -1000,2.5095832,2.8112361,,,,,,,,,,,,,, -1100,1.9391,2.62059,,,,,,,,,,,,,, -1200,2.1809657,2.606701,,,,,,,,,,,,,, -1300,1.890928,2.4774826,,,,,,,,,,,,,, -1400,2.1828072,2.3801565,,,,,,,,,,,,,, -1500,2.2975647,2.3753903,,,,,,,,,,,,,, -1600,3.3278208,2.274223,,,,,,,,,,,,,, -1692,,,2.0471356,0.4961872165120053,2.49202,0.547013333075876,5348.0,2.0015655,0.4741332033392237,2472.0,1456.1483154296875,1814.265554189682,1456.1483154296875,358.0138473510742,0.02854585647583,0.0 -1700,1.8867519,2.2328444,,,,,,,,,,,,,, -1800,2.0974677,2.1664374,,,,,,,,,,,,,, -1900,1.9251815,2.187092,,,,,,,,,,,,,, -2000,1.9742491,2.0246546,,,,,,,,,,,,,, -2100,2.5672374,2.112832,,,,,,,,,,,,,, -2200,4.3524795,2.1412299,,,,,,,,,,,,,, -2300,2.5824144,2.0180964,,,,,,,,,,,,,, -2400,3.7164724,2.0464091,,,,,,,,,,,,,, -2500,2.0986419,2.053879,,,,,,,,,,,,,, -2600,2.6859958,1.9233488,,,,,,,,,,,,,, -2700,2.1967301,1.9408286,,,,,,,,,,,,,, -2800,2.3262632,1.959787,,,,,,,,,,,,,, -2900,2.707237,1.9225848,,,,,,,,,,,,,, -3000,2.7526405,1.9146732,,,,,,,,,,,,,, -3100,2.0804315,1.8712814,,,,,,,,,,,,,, -3200,3.2089295,1.9642271,,,,,,,,,,,,,, -3300,2.2794995,1.9080445,,,,,,,,,,,,,, -3400,3.932331,1.9071105,,,,,,,,,,,,,, -3409,,,0.7372342,0.2299690858884231,0.913854,0.2588991764580939,5348.0,0.5924073,0.1844291430544553,2472.0,2896.421770811081,3390.00124835968,2896.421770811081,493.33926486968994,0.0849425792694091,0.0 -3500,2.181764,1.8364542,,,,,,,,,,,,,, -3600,3.3424618,1.882047,,,,,,,,,,,,,, -3700,3.8075042,1.9324452,,,,,,,,,,,,,, -3800,5.212905,1.871425,,,,,,,,,,,,,, -3900,4.240533,1.7575352,,,,,,,,,,,,,, -4000,1.8412821,1.7763221,,,,,,,,,,,,,, -4100,3.6934223,1.8039277,,,,,,,,,,,,,, -4200,3.4931326,1.7469224,,,,,,,,,,,,,, -4300,2.6748173,1.7569327,,,,,,,,,,,,,, -4400,2.4102726,1.8010849,,,,,,,,,,,,,, -4500,3.7816167,1.7759715,,,,,,,,,,,,,, -4600,1.9447241,1.77058,,,,,,,,,,,,,, -4700,2.432565,1.7958398,,,,,,,,,,,,,, -4800,2.7663627,1.7575338,,,,,,,,,,,,,, -4900,3.3780227,1.7424474,,,,,,,,,,,,,, -5000,2.147352,1.7955749,,,,,,,,,,,,,, -5098,,,0.63931066,0.2019689524050605,0.7830972,0.2235824555644593,5348.0,0.48912373,0.159324030629862,2472.0,4336.899610280991,4966.236393213272,4336.899610280991,628.9697065353394,0.134153127670288,0.0 -5100,2.1486053,1.7607906,,,,,,,,,,,,,, -5200,2.3958812,1.7219142,,,,,,,,,,,,,, -5300,2.4783738,1.7344675,,,,,,,,,,,,,, -5400,3.9019928,1.6969795,,,,,,,,,,,,,, -5500,3.353472,1.745616,,,,,,,,,,,,,, -5600,3.7315993,1.7008861,,,,,,,,,,,,,, -5700,3.8582895,1.6507957,,,,,,,,,,,,,, -5800,2.27188,1.7731847,,,,,,,,,,,,,, -5900,3.457038,1.6527536,,,,,,,,,,,,,, -6000,3.6861374,1.7144613,,,,,,,,,,,,,, -6100,2.5451448,1.7557881,,,,,,,,,,,,,, -6200,3.0659113,1.7228035,,,,,,,,,,,,,, -6300,3.6579928,1.7166498,,,,,,,,,,,,,, -6400,2.8223553,1.6283275,,,,,,,,,,,,,, -6500,3.1168296,1.6962228,,,,,,,,,,,,,, -6600,3.710089,1.6413794,,,,,,,,,,,,,, -6700,2.7120962,1.6969702,,,,,,,,,,,,,, -6781,,,0.7021821,0.2195984392781661,0.74043995,0.2126533883004914,5348.0,0.44423807,0.1440903459062011,2472.0,5777.278906345367,6539.980969905853,5777.278906345367,762.1963560581207,0.1938681602478027,0.0 -6800,6.4928226,1.6306158,,,,,,,,,,,,,, -6900,6.1590276,1.7142634,,,,,,,,,,,,,, -7000,3.9290283,1.7226694,,,,,,,,,,,,,, -7100,3.208901,1.629741,,,,,,,,,,,,,, -7200,2.8329396,1.6266048,,,,,,,,,,,,,, -7300,3.546929,1.7136551,,,,,,,,,,,,,, -7400,4.4568033,1.676186,,,,,,,,,,,,,, -7500,2.3135307,1.6043459,,,,,,,,,,,,,, -7600,2.3259676,1.6249456,,,,,,,,,,,,,, -7700,3.087594,1.6409726,,,,,,,,,,,,,, -7800,2.2972674,1.5722674,,,,,,,,,,,,,, -7900,2.2040584,1.6613096,,,,,,,,,,,,,, -8000,3.560403,1.6511621,,,,,,,,,,,,,, -8100,2.4406095,1.6233187,,,,,,,,,,,,,, -8200,2.3076205,1.6443527,,,,,,,,,,,,,, -8300,2.5100815,1.6295959,,,,,,,,,,,,,, -8400,3.025778,1.6495758,,,,,,,,,,,,,, -8491,,,0.5751879,0.1796062016228983,0.7267631,0.206619230137965,5348.0,0.44064087,0.141490463713363,2472.0,7217.765897512436,8114.162263154983,7217.765897512436,895.7563931941986,0.2485561370849609,0.0 -8500,3.2782745,1.6135854,,,,,,,,,,,,,, -8600,2.8799143,1.5933009,,,,,,,,,,,,,, -8700,2.636872,1.6108714,,,,,,,,,,,,,, -8800,2.9713051,1.6141504,,,,,,,,,,,,,, -8900,2.3944142,1.6359527,,,,,,,,,,,,,, -9000,2.925543,1.5965048,,,,,,,,,,,,,, -9100,3.4035351,1.6411772,,,,,,,,,,,,,, -9200,2.1435819,1.586261,,,,,,,,,,,,,, -9300,4.646834,1.6143069,,,,,,,,,,,,,, -9400,2.383327,1.5717212,,,,,,,,,,,,,, -9500,5.2732916,1.5927165,,,,,,,,,,,,,, -9600,2.9965289,1.6146624,,,,,,,,,,,,,, -9700,4.0990224,1.6424,,,,,,,,,,,,,, -9800,2.9266865,1.6200624,,,,,,,,,,,,,, -9900,3.2698185,1.5735012,,,,,,,,,,,,,, -10000,2.0154698,1.5530612,,,,,,,,,,,,,, -10100,4.587024,1.649112,,,,,,,,,,,,,, -10176,,,0.53614444,0.1739231810375223,0.6622156,0.1904380316093341,5348.0,0.39185888,0.1262770905693336,2472.0,8658.173372507095,9688.754879951475,8658.173372507095,1029.8109738826752,0.3024752140045166,0.0 -10200,2.5843692,1.5940518,,,,,,,,,,,,,, -10300,2.605593,1.5291685,,,,,,,,,,,,,, -10400,2.51567,1.5877634,,,,,,,,,,,,,, -10500,2.2896168,1.5227968,,,,,,,,,,,,,, -10600,3.3508468,1.530677,,,,,,,,,,,,,, -10700,2.4348874,1.5933418,,,,,,,,,,,,,, -10800,3.2034318,1.5746877,,,,,,,,,,,,,, -10900,3.1045744,1.6415176,,,,,,,,,,,,,, -11000,2.6640263,1.5975939,,,,,,,,,,,,,, -11100,2.050816,1.5644927,,,,,,,,,,,,,, -11200,2.4962392,1.5915476,,,,,,,,,,,,,, -11300,3.0548098,1.5939826,,,,,,,,,,,,,, -11400,2.6897447,1.4933306,,,,,,,,,,,,,, -11500,2.9961326,1.5309764,,,,,,,,,,,,,, -11600,2.3144238,1.5682018,,,,,,,,,,,,,, -11700,2.984468,1.5483301,,,,,,,,,,,,,, -11800,3.3239777,1.6474358,,,,,,,,,,,,,, -11874,,,0.4290752,0.1424052505581408,0.65312153,0.1863830773241163,5348.0,0.38414696,0.1226210062356549,2472.0,10098.086151361464,11264.540845394136,10098.086151361464,1165.5521621704102,0.3547070026397705,0.0 -11900,1.9119799,1.5805392,,,,,,,,,,,,,, -12000,4.324442,1.5701736,,,,,,,,,,,,,, -12100,5.6921215,1.5568783,,,,,,,,,,,,,, -12200,2.9120538,1.5495946,,,,,,,,,,,,,, -12300,2.0107791,1.5456257,,,,,,,,,,,,,, -12400,2.5848396,1.5245477,,,,,,,,,,,,,, -12500,3.2402675,1.5305629,,,,,,,,,,,,,, -12600,5.2615128,1.5977246,,,,,,,,,,,,,, -12700,3.0692277,1.5389478,,,,,,,,,,,,,, -12800,3.667601,1.5082698,,,,,,,,,,,,,, -12900,3.42237,1.4938257,,,,,,,,,,,,,, -13000,3.4102097,1.6071887,,,,,,,,,,,,,, -13100,3.4043722,1.5765013,,,,,,,,,,,,,, -13200,2.883369,1.5612983,,,,,,,,,,,,,, -13300,2.1491613,1.4960852,,,,,,,,,,,,,, -13400,2.6083686,1.5525436,,,,,,,,,,,,,, -13500,2.8148174,1.5646619,,,,,,,,,,,,,, -13566,,,0.5048986,0.1636252175721682,0.636037,0.1803971924268901,5348.0,0.37574464,0.1222960209615501,2472.0,11538.334949493408,12843.813130378723,11538.334949493408,1304.4409563541412,0.4097585678100586,0.0 -13600,2.7132633,1.5245844,,,,,,,,,,,,,, -13700,2.942852,1.5581691,,,,,,,,,,,,,, -13800,5.414521,1.5162436,,,,,,,,,,,,,, -13900,3.0926409,1.5014344,,,,,,,,,,,,,, -14000,2.584703,1.5873077,,,,,,,,,,,,,, -14100,4.114477,1.5175439,,,,,,,,,,,,,, -14200,3.034713,1.524252,,,,,,,,,,,,,, -14300,2.5367997,1.4847856,,,,,,,,,,,,,, -14400,2.5129764,1.5350944,,,,,,,,,,,,,, -14500,3.2910106,1.5707442,,,,,,,,,,,,,, -14600,3.1426203,1.5490245,,,,,,,,,,,,,, -14700,4.195681,1.5447769,,,,,,,,,,,,,, -14800,3.1169965,1.5821027,,,,,,,,,,,,,, -14900,5.322238,1.5746356,,,,,,,,,,,,,, -15000,3.7139664,1.5147921,,,,,,,,,,,,,, -15100,2.1448178,1.5060351,,,,,,,,,,,,,, -15200,3.171349,1.5064933,,,,,,,,,,,,,, -15256,,,0.45856845,0.1502044088176352,0.6351292,0.1813336937737142,5348.0,0.36284128,0.1173196839518209,2472.0,12978.819166183472,14421.089455366136,12978.819166183472,1441.1059172153473,0.4590444564819336,0.0 -15300,1.8844172,1.5798914,,,,,,,,,,,,,, -15400,4.129794,1.5646174,,,,,,,,,,,,,, -15500,3.1632016,1.462537,,,,,,,,,,,,,, -15600,2.8100796,1.5104218,,,,,,,,,,,,,, -15700,2.4433699,1.5237037,,,,,,,,,,,,,, -15800,2.2339191,1.5113941,,,,,,,,,,,,,, -15900,3.2072031,1.4897159,,,,,,,,,,,,,, -16000,2.2621465,1.5255685,,,,,,,,,,,,,, -16100,3.7659278,1.4834626,,,,,,,,,,,,,, -16200,2.0414805,1.5066193,,,,,,,,,,,,,, -16300,2.687833,1.4835428,,,,,,,,,,,,,, -16400,2.1226816,1.5480425,,,,,,,,,,,,,, -16500,5.474573,1.5027364,,,,,,,,,,,,,, -16600,4.767575,1.5060194,,,,,,,,,,,,,, -16700,2.231397,1.476659,,,,,,,,,,,,,, -16800,2.582296,1.4977409,,,,,,,,,,,,,, -16900,3.9897857,1.457857,,,,,,,,,,,,,, -16959,,,0.44045535,0.1457255361385011,0.6102114,0.1746237098969848,5348.0,0.3535885,0.1136026648792476,2472.0,14418.86126089096,15997.133904457092,14418.86126089096,1576.9730966091156,0.5129210948944092,0.0 -17000,2.5051858,1.5679283,,,,,,,,,,,,,, -17100,4.0122447,1.5015785,,,,,,,,,,,,,, -17200,3.037986,1.502999,,,,,,,,,,,,,, -17300,3.202788,1.5320932,,,,,,,,,,,,,, -17400,5.7885895,1.5984399,,,,,,,,,,,,,, -17500,4.688428,1.4988954,,,,,,,,,,,,,, -17600,3.5521727,1.5206467,,,,,,,,,,,,,, -17700,3.6938195,1.5166124,,,,,,,,,,,,,, -17800,3.968339,1.4370749,,,,,,,,,,,,,, -17900,1.841551,1.5215735,,,,,,,,,,,,,, -18000,2.7971752,1.5134267,,,,,,,,,,,,,, -18100,2.56143,1.4786426,,,,,,,,,,,,,, -18200,3.773123,1.5335681,,,,,,,,,,,,,, -18300,3.3270526,1.4931537,,,,,,,,,,,,,, -18400,5.3832564,1.5358739,,,,,,,,,,,,,, -18500,1.8749276,1.5256488,,,,,,,,,,,,,, -18600,1.6164161,1.4551138,,,,,,,,,,,,,, -18655,,,0.40981236,0.1340964840556009,0.5838607,0.1677399422651747,5348.0,0.33918792,0.1079560457416773,2472.0,15859.378022432327,17572.3883395195,15859.378022432327,1711.5768485069275,0.5694046020507812,0.0 -18700,3.335464,1.4337142,,,,,,,,,,,,,, -18800,2.280028,1.5148419,,,,,,,,,,,,,, -18900,2.9671352,1.5086256,,,,,,,,,,,,,, -19000,2.0132914,1.4089876,,,,,,,,,,,,,, -19100,2.3742297,1.4610612,,,,,,,,,,,,,, -19200,2.8081355,1.4309973,,,,,,,,,,,,,, -19300,3.2979503,1.4777732,,,,,,,,,,,,,, -19400,3.0615253,1.4430506,,,,,,,,,,,,,, -19500,1.7538567,1.5064061,,,,,,,,,,,,,, -19600,3.5319123,1.5488479,,,,,,,,,,,,,, -19700,3.48783,1.4523183,,,,,,,,,,,,,, -19800,5.051174,1.4865078,,,,,,,,,,,,,, -19900,1.9725232,1.4025552,,,,,,,,,,,,,, -20000,2.3537126,1.4584428,,,,,,,,,,,,,, -20100,2.0688899,1.4331331,,,,,,,,,,,,,, -20200,2.385063,1.4677083,,,,,,,,,,,,,, -20300,2.326835,1.4626377,,,,,,,,,,,,,, -20349,,,0.42425936,0.14089816744684,0.5756924,0.1647373451635015,5348.0,0.33319625,0.1072654520342047,2472.0,17299.322156190872,19145.56060743332,17299.322156190872,1844.674907445908,0.621281623840332,0.0 -20400,3.026539,1.5582445,,,,,,,,,,,,,, -20500,2.7701106,1.4925518,,,,,,,,,,,,,, -20600,2.2458365,1.414267,,,,,,,,,,,,,, -20700,2.7372627,1.448244,,,,,,,,,,,,,, -20800,2.6394463,1.4531586,,,,,,,,,,,,,, -20900,2.9268136,1.4512709,,,,,,,,,,,,,, -21000,5.7030144,1.4738674,,,,,,,,,,,,,, -21100,3.2041829,1.4402912,,,,,,,,,,,,,, -21200,2.8719985,1.4710464,,,,,,,,,,,,,, -21300,3.3719506,1.4871677,,,,,,,,,,,,,, -21400,3.2397223,1.4482725,,,,,,,,,,,,,, -21500,3.4316761,1.3691974,,,,,,,,,,,,,, -21600,4.401458,1.4612217,,,,,,,,,,,,,, -21700,3.032019,1.4519726,,,,,,,,,,,,,, -21800,2.1110022,1.4660918,,,,,,,,,,,,,, -21900,3.0600662,1.5236607,,,,,,,,,,,,,, -22000,2.7230487,1.3966644,,,,,,,,,,,,,, -22051,,,0.4327474,0.1370816180418829,0.5562388,0.1593017754906977,5348.0,0.32314736,0.1023906729226332,2472.0,18739.717533826828,20722.88533854485,18739.717533826828,1981.468006372452,0.6765701770782471,0.0 -22100,6.5821157,1.4117317,,,,,,,,,,,,,, -22200,4.011614,1.4373682,,,,,,,,,,,,,, -22300,2.4824164,1.3987665,,,,,,,,,,,,,, -22400,1.8596385,1.4857539,,,,,,,,,,,,,, -22500,2.1499753,1.4695419,,,,,,,,,,,,,, -22600,5.1283746,1.432804,,,,,,,,,,,,,, -22700,2.6956568,1.4051663,,,,,,,,,,,,,, -22800,2.7404263,1.4476523,,,,,,,,,,,,,, -22900,1.6432779,1.4775587,,,,,,,,,,,,,, -23000,2.5053031,1.4326388,,,,,,,,,,,,,, -23100,1.7893767,1.3536801,,,,,,,,,,,,,, -23200,2.526326,1.4252033,,,,,,,,,,,,,, -23300,3.293214,1.4672102,,,,,,,,,,,,,, -23400,2.9972982,1.4587119,,,,,,,,,,,,,, -23500,3.449197,1.4533032,,,,,,,,,,,,,, -23600,2.3313096,1.4858197,,,,,,,,,,,,,, -23700,3.6080499,1.3819658,,,,,,,,,,,,,, -23736,,,0.36796176,0.1241607220686649,0.54934376,0.1599776012049007,5348.0,0.3138135,0.100197022322426,2472.0,20179.681025505062,22298.57732820511,20179.681025505062,2117.053550004959,0.740778923034668,0.0 -23800,2.5577445,1.3731604,,,,,,,,,,,,,, -23900,2.7089436,1.4114127,,,,,,,,,,,,,, -24000,4.7773466,1.4108963,,,,,,,,,,,,,, -24100,2.9650342,1.4721762,,,,,,,,,,,,,, -24200,2.0673096,1.4072291,,,,,,,,,,,,,, -24300,2.8787544,1.3895788,,,,,,,,,,,,,, -24400,2.0352595,1.4222108,,,,,,,,,,,,,, -24500,5.0776124,1.3231326,,,,,,,,,,,,,, -24600,2.251597,1.3884673,,,,,,,,,,,,,, -24700,6.271471,1.4383525,,,,,,,,,,,,,, -24800,1.9555285,1.3524525,,,,,,,,,,,,,, -24900,2.2110128,1.3610836,,,,,,,,,,,,,, -25000,3.5060005,1.4310669,,,,,,,,,,,,,, -25100,1.805806,1.3650216,,,,,,,,,,,,,, -25200,2.0349967,1.435657,,,,,,,,,,,,,, -25300,2.7350116,1.4113219,,,,,,,,,,,,,, -25400,3.3555784,1.371042,,,,,,,,,,,,,, -25430,,,0.39533722,0.1345729867718728,0.53948,0.1550923467565193,5348.0,0.30784765,0.0975362053906932,2472.0,21619.953918218613,23874.16535425186,21619.953918218613,2252.23565363884,0.7953715324401855,0.0 -25500,2.6234589,1.354764,,,,,,,,,,,,,, -25600,2.2684922,1.3602357,,,,,,,,,,,,,, -25700,2.89282,1.4015249,,,,,,,,,,,,,, -25800,2.8268874,1.3336633,,,,,,,,,,,,,, -25900,3.572186,1.3488421,,,,,,,,,,,,,, -26000,2.3570037,1.3790064,,,,,,,,,,,,,, -26100,2.320481,1.3423669,,,,,,,,,,,,,, -26200,3.6095896,1.3775965,,,,,,,,,,,,,, -26300,2.0958016,1.3174427,,,,,,,,,,,,,, -26400,1.8733021,1.3281283,,,,,,,,,,,,,, -26500,1.7163837,1.3621268,,,,,,,,,,,,,, -26600,4.272918,1.3807151,,,,,,,,,,,,,, -26700,2.327392,1.372034,,,,,,,,,,,,,, -26800,4.0430555,1.4028387,,,,,,,,,,,,,, -26900,2.651438,1.3242048,,,,,,,,,,,,,, -27000,2.2253149,1.3295575,,,,,,,,,,,,,, -27100,2.7932878,1.3088775,,,,,,,,,,,,,, -27147,,,0.3278709,0.1105115369186893,0.52832973,0.1517132181855045,5348.0,0.2963604,0.0940629252736985,2472.0,23060.258680820465,25451.05943512917,23060.258680820465,2388.6867899894714,0.8549208641052246,0.0 -27200,2.2895198,1.3521959,,,,,,,,,,,,,, -27300,2.284488,1.3875238,,,,,,,,,,,,,, -27400,3.2016587,1.4314506,,,,,,,,,,,,,, -27500,1.5819545,1.3782578,,,,,,,,,,,,,, -27600,2.313157,1.3623315,,,,,,,,,,,,,, -27700,2.999521,1.4046047,,,,,,,,,,,,,, -27800,1.7102909,1.37134,,,,,,,,,,,,,, -27900,3.3607047,1.3590441,,,,,,,,,,,,,, -28000,3.1119432,1.3670938,,,,,,,,,,,,,, -28100,4.1037006,1.3504436,,,,,,,,,,,,,, -28200,2.5185661,1.3522025,,,,,,,,,,,,,, -28300,1.7517575,1.2758554,,,,,,,,,,,,,, -28400,4.204479,1.3799251,,,,,,,,,,,,,, -28500,2.941478,1.3006504,,,,,,,,,,,,,, -28600,2.2698085,1.3240812,,,,,,,,,,,,,, -28700,2.1100936,1.3387946,,,,,,,,,,,,,, -28800,2.786658,1.3431606,,,,,,,,,,,,,, -28814,,,0.33591753,0.1122769309086795,0.49735403,0.1428309373702656,5348.0,0.2814619,0.0896553124936526,2472.0,24500.23553586006,27026.162631988525,24500.23553586006,2523.680060863495,0.9110238552093506,0.0 -28900,2.8236086,1.2996632,,,,,,,,,,,,,, -29000,5.2123623,1.3265376,,,,,,,,,,,,,, -29100,2.865828,1.334219,,,,,,,,,,,,,, -29200,2.9821393,1.3572104,,,,,,,,,,,,,, -29300,2.445218,1.2963849,,,,,,,,,,,,,, -29400,1.9453067,1.384352,,,,,,,,,,,,,, -29500,3.245479,1.3622084,,,,,,,,,,,,,, -29600,4.558199,1.3813467,,,,,,,,,,,,,, -29700,1.4148899,1.4151666,,,,,,,,,,,,,, -29800,2.5556479,1.317352,,,,,,,,,,,,,, -29900,2.2938867,1.276308,,,,,,,,,,,,,, -30000,3.438849,1.3259026,,,,,,,,,,,,,, -30100,2.2678845,1.3178608,,,,,,,,,,,,,, -30200,2.3978047,1.3303832,,,,,,,,,,,,,, -30300,2.8298404,1.3611923,,,,,,,,,,,,,, -30400,2.6047478,1.3112801,,,,,,,,,,,,,, -30496,,,0.32313046,0.1091377633158311,0.49119592,0.141112409125578,5348.0,0.2728875,0.0881725671805496,2472.0,25940.209047079086,28602.19198703766,25940.209047079086,2659.6065866947174,0.9619748592376708,0.0 -30500,3.2045631,1.298886,,,,,,,,,,,,,, -30600,2.8743713,1.2656333,,,,,,,,,,,,,, -30700,2.193297,1.2871879,,,,,,,,,,,,,, -30800,2.3162155,1.324464,,,,,,,,,,,,,, -30900,4.028298,1.2225046,,,,,,,,,,,,,, -31000,2.6555252,1.2870789,,,,,,,,,,,,,, -31100,1.8136504,1.3288277,,,,,,,,,,,,,, -31200,3.9489632,1.2759985,,,,,,,,,,,,,, -31300,2.7347164,1.2250032,,,,,,,,,,,,,, -31400,2.7485888,1.2761132,,,,,,,,,,,,,, -31500,2.1184504,1.3112823,,,,,,,,,,,,,, -31600,3.9844415,1.2688378,,,,,,,,,,,,,, -31700,2.2013326,1.3408343,,,,,,,,,,,,,, -31800,5.279231,1.2872311,,,,,,,,,,,,,, -31900,2.4904265,1.3182626,,,,,,,,,,,,,, -32000,2.4268205,1.3287553,,,,,,,,,,,,,, -32100,2.4631023,1.3114591,,,,,,,,,,,,,, -32200,1.471607,1.247098,,,,,,,,,,,,,, -32213,,,0.32220346,0.1095048590651246,0.48185685,0.1390752773299091,5348.0,0.26595923,0.0854101923506591,2472.0,27380.473700761795,30179.518819332123,27380.473700761795,2796.5278537273407,1.02350115776062,0.0 -32300,1.9750205,1.2913822,,,,,,,,,,,,,, -32400,5.898061,1.2863369,,,,,,,,,,,,,, -32500,2.4341555,1.2824198,,,,,,,,,,,,,, -32600,1.8386961,1.2879142,,,,,,,,,,,,,, -32700,2.0122397,1.3162197,,,,,,,,,,,,,, -32800,2.8037674,1.3389485,,,,,,,,,,,,,, -32900,3.9222841,1.3022375,,,,,,,,,,,,,, -33000,1.926195,1.3321906,,,,,,,,,,,,,, -33100,3.0525987,1.2498337,,,,,,,,,,,,,, -33200,3.8155649,1.2737514,,,,,,,,,,,,,, -33300,1.8464293,1.270707,,,,,,,,,,,,,, -33400,2.0961251,1.2698185,,,,,,,,,,,,,, -33500,1.6093441,1.2425681,,,,,,,,,,,,,, -33600,2.2386546,1.3202928,,,,,,,,,,,,,, -33700,2.8319108,1.3441054,,,,,,,,,,,,,, -33800,1.8527616,1.3113494,,,,,,,,,,,,,, -33892,,,0.25985754,0.0887791944037007,0.46532407,0.1326742423511011,5348.0,0.25858718,0.081997846972559,2472.0,28821.714812994003,31757.905738592148,28821.714812994003,2933.5398259162903,1.0806150436401367,0.0 -33900,2.3604164,1.3084165,,,,,,,,,,,,,, -34000,1.9029547,1.2783952,,,,,,,,,,,,,, -34100,4.5272393,1.2840889,,,,,,,,,,,,,, -34200,2.050308,1.2077498,,,,,,,,,,,,,, -34300,2.6730037,1.276247,,,,,,,,,,,,,, -34400,3.4584072,1.2512107,,,,,,,,,,,,,, -34500,2.449211,1.2408178,,,,,,,,,,,,,, -34600,1.5101428,1.2740922,,,,,,,,,,,,,, -34700,2.0997155,1.1863873,,,,,,,,,,,,,, -34800,2.8097966,1.2275952,,,,,,,,,,,,,, -34900,2.9583788,1.2238714,,,,,,,,,,,,,, -35000,1.9562306,1.2809528,,,,,,,,,,,,,, -35100,2.1127877,1.2394265,,,,,,,,,,,,,, -35200,2.4193707,1.203413,,,,,,,,,,,,,, -35300,3.5946586,1.2893242,,,,,,,,,,,,,, -35400,2.3389554,1.1823097,,,,,,,,,,,,,, -35500,2.7856236,1.2544918,,,,,,,,,,,,,, -35575,,,0.23919861,0.0823648799729387,0.45785353,0.1310232966778338,5348.0,0.25288898,0.0821197164503483,2472.0,30261.89965081215,33333.53277087212,30261.89965081215,3068.84513258934,1.1369514465332031,0.0 -35600,3.8712804,1.2225404,,,,,,,,,,,,,, -35700,2.8359654,1.224797,,,,,,,,,,,,,, -35800,4.041188,1.2151483,,,,,,,,,,,,,, -35900,2.9985924,1.2791603,,,,,,,,,,,,,, -36000,4.280546,1.2975277,,,,,,,,,,,,,, -36100,1.6085563,1.2395676,,,,,,,,,,,,,, -36200,3.5640879,1.2179894,,,,,,,,,,,,,, -36300,3.719305,1.2073215,,,,,,,,,,,,,, -36400,2.4582522,1.2632649,,,,,,,,,,,,,, -36500,2.2738981,1.2110388,,,,,,,,,,,,,, -36600,6.6831956,1.2148345,,,,,,,,,,,,,, -36700,4.276652,1.1906871,,,,,,,,,,,,,, -36800,1.9433125,1.2305766,,,,,,,,,,,,,, -36900,15.032586,1.2244086,,,,,,,,,,,,,, -37000,3.2473454,1.1983417,,,,,,,,,,,,,, -37100,7.120197,1.2025098,,,,,,,,,,,,,, -37200,2.533535,1.1663972,,,,,,,,,,,,,, -37279,,,0.26334742,0.0912923553024017,0.44023106,0.1263987178620736,5348.0,0.24081011,0.0775699226128816,2472.0,31702.451380491257,34908.66898369789,31702.451380491257,3203.287886619568,1.198265790939331,0.0 -37300,3.8792653,1.2217501,,,,,,,,,,,,,, -37400,1.9833578,1.2116811,,,,,,,,,,,,,, -37500,1.7902292,1.1599354,,,,,,,,,,,,,, -37600,3.287342,1.1881144,,,,,,,,,,,,,, -37700,3.6429656,1.2673746,,,,,,,,,,,,,, -37800,2.1045842,1.2084181,,,,,,,,,,,,,, -37900,2.4972446,1.1846238,,,,,,,,,,,,,, -38000,2.0928972,1.1437984,,,,,,,,,,,,,, -38100,3.4329147,1.223471,,,,,,,,,,,,,, -38200,2.0163388,1.1964936,,,,,,,,,,,,,, -38300,2.7440603,1.2192662,,,,,,,,,,,,,, -38400,4.009854,1.2092674,,,,,,,,,,,,,, -38500,1.9966911,1.1696308,,,,,,,,,,,,,, -38600,2.432157,1.188281,,,,,,,,,,,,,, -38700,2.2335145,1.2264857,,,,,,,,,,,,,, -38800,2.2803402,1.1632731,,,,,,,,,,,,,, -38900,3.9040346,1.2092638,,,,,,,,,,,,,, -38963,,,0.23935167,0.0811849127280522,0.4321774,0.1247188082296262,5348.0,0.23291391,0.0753356488534113,2472.0,33142.61997103691,36485.36349272728,33142.61997103691,3339.672556877136,1.260697364807129,0.0 -39000,3.6018689,1.120906,,,,,,,,,,,,,, -39100,2.1694686,1.1946406,,,,,,,,,,,,,, -39200,1.8743378,1.2012221,,,,,,,,,,,,,, -39300,1.7850088,1.1761192,,,,,,,,,,,,,, -39400,2.2180424,1.1698966,,,,,,,,,,,,,, -39500,1.998405,1.1774765,,,,,,,,,,,,,, -39600,2.7247174,1.160809,,,,,,,,,,,,,, -39700,3.6965957,1.2054768,,,,,,,,,,,,,, -39800,4.337481,1.1584111,,,,,,,,,,,,,, -39900,2.1907623,1.1821064,,,,,,,,,,,,,, -40000,2.69612,1.2182275,,,,,,,,,,,,,, -40100,3.0656083,1.2355014,,,,,,,,,,,,,, -40200,1.943254,1.1701294,,,,,,,,,,,,,, -40300,2.0613375,1.1587045,,,,,,,,,,,,,, -40400,1.6745754,1.1391695,,,,,,,,,,,,,, -40500,3.2540224,1.1978779,,,,,,,,,,,,,, -40600,3.1575098,1.1878449,,,,,,,,,,,,,, -40672,,,0.24396911,0.0813538117108663,0.42279685,0.1221410158625949,5348.0,0.22748674,0.0735888530050982,2472.0,34583.59595036507,38061.68483424187,34583.59595036507,3474.866818666458,1.3295516967773438,0.0 -40700,2.2319882,1.2143589,,,,,,,,,,,,,, -40800,1.6395305,1.1701009,,,,,,,,,,,,,, -40900,2.8683994,1.2032912,,,,,,,,,,,,,, -41000,2.4718592,1.1724521,,,,,,,,,,,,,, -41100,2.206738,1.1684606,,,,,,,,,,,,,, -41200,2.5145004,1.0719022,,,,,,,,,,,,,, -41300,2.7864444,1.1234417,,,,,,,,,,,,,, -41400,4.7181067,1.1926279,,,,,,,,,,,,,, -41500,1.9920835,1.1291891,,,,,,,,,,,,,, -41600,3.8563187,1.1316565,,,,,,,,,,,,,, -41700,3.2261508,1.1488297,,,,,,,,,,,,,, -41800,2.4866889,1.1330239,,,,,,,,,,,,,, -41900,2.8219256,1.106387,,,,,,,,,,,,,, -42000,3.2047813,1.149586,,,,,,,,,,,,,, -42100,1.8976126,1.186561,,,,,,,,,,,,,, -42200,3.373713,1.1975024,,,,,,,,,,,,,, -42300,2.7245073,1.1275779,,,,,,,,,,,,,, -42350,,,0.20017296,0.0698609234942493,0.41711703,0.1192253106384622,5348.0,0.22420189,0.0727154550809416,2472.0,36023.78817820549,39640.9696969986,36023.78817820549,3613.8211603164673,1.3897809982299805,0.0 -42400,2.2215354,1.1583161,,,,,,,,,,,,,, -42500,2.8590946,1.1409763,,,,,,,,,,,,,, -42600,3.0496686,1.1928494,,,,,,,,,,,,,, -42700,2.8190067,1.1920379,,,,,,,,,,,,,, -42800,2.356069,1.1443118,,,,,,,,,,,,,, -42900,3.081242,1.1742765,,,,,,,,,,,,,, -43000,3.6628065,1.1981089,,,,,,,,,,,,,, -43100,2.6733344,1.1354501,,,,,,,,,,,,,, -43200,2.9795566,1.1287651,,,,,,,,,,,,,, -43300,2.6675105,1.1704954,,,,,,,,,,,,,, -43400,2.9704776,1.1279752,,,,,,,,,,,,,, -43500,1.8346503,1.1635411,,,,,,,,,,,,,, -43600,2.7091577,1.1567137,,,,,,,,,,,,,, -43700,1.8431648,1.1233213,,,,,,,,,,,,,, -43800,2.1847458,1.1629155,,,,,,,,,,,,,, -43900,4.701212,1.1511235,,,,,,,,,,,,,, -44000,3.1540587,1.1410754,,,,,,,,,,,,,, -44022,,,0.19830093,0.0699378417169759,0.41334906,0.118279154638578,5348.0,0.22245006,0.0720045497938374,2472.0,37464.233345746994,41216.33184504509,37464.233345746994,3748.6010749340057,1.4484024047851562,0.0 -44100,2.0489175,1.1853703,,,,,,,,,,,,,, -44200,2.0115285,1.1574389,,,,,,,,,,,,,, -44300,4.472919,1.1287215,,,,,,,,,,,,,, -44400,1.6431689,1.1492295,,,,,,,,,,,,,, -44500,3.3213248,1.1773614,,,,,,,,,,,,,, -44600,1.8513664,1.1584442,,,,,,,,,,,,,, -44700,2.0601656,1.1872883,,,,,,,,,,,,,, -44800,4.198932,1.1715709,,,,,,,,,,,,,, -44900,4.4837775,1.1620334,,,,,,,,,,,,,, -45000,2.5396721,1.1354443,,,,,,,,,,,,,, -45100,1.8242971,1.1280817,,,,,,,,,,,,,, -45200,3.6696262,1.1271317,,,,,,,,,,,,,, -45300,1.8408172,1.1024954,,,,,,,,,,,,,, -45400,3.3096633,1.1184931,,,,,,,,,,,,,, -45500,2.3510728,1.1283323,,,,,,,,,,,,,, -45600,2.4853728,1.1049725,,,,,,,,,,,,,, -45700,3.3996012,1.1489276,,,,,,,,,,,,,, -45723,,,0.2161998,0.0755508978966493,0.4112349,0.1175164370468347,5348.0,0.22004013,0.0710499055511547,2472.0,38905.48064374924,42793.34859013557,38905.48064374924,3884.234937906265,1.505363941192627,0.0 -45800,1.7896458,1.0992558,,,,,,,,,,,,,, -45900,14.43599,1.1186322,,,,,,,,,,,,,, -46000,3.419922,1.1635634,,,,,,,,,,,,,, -46100,3.0494308,1.0802515,,,,,,,,,,,,,, -46200,5.3531623,1.1203138,,,,,,,,,,,,,, -46300,1.6528075,1.1096548,,,,,,,,,,,,,, -46400,2.3864362,1.1376343,,,,,,,,,,,,,, -46500,2.7256582,1.1071931,,,,,,,,,,,,,, -46600,8.095403,1.0980917,,,,,,,,,,,,,, -46700,1.9303554,1.2330787,,,,,,,,,,,,,, -46800,2.7016284,1.1515055,,,,,,,,,,,,,, -46900,1.9656774,1.1220188,,,,,,,,,,,,,, -47000,2.4076009,1.0765036,,,,,,,,,,,,,, -47100,3.2775588,1.1147358,,,,,,,,,,,,,, -47200,3.5249147,1.1289059,,,,,,,,,,,,,, -47300,1.6774524,1.1743628,,,,,,,,,,,,,, -47391,,,0.14322816,0.0512173187565687,0.4108572,0.1175357463529548,5348.0,0.2199478,0.0712733329271017,2472.0,40345.91604781151,44379.25354409218,40345.91604781151,4029.571016550064,1.5600640773773191,0.0 -47400,1.9288552,1.1654329,,,,,,,,,,,,,, -47500,1.7453226,1.1230024,,,,,,,,,,,,,, -47600,3.5616639,1.1649204,,,,,,,,,,,,,, -47700,2.3139377,1.1471181,,,,,,,,,,,,,, -47800,2.5616827,1.1233966,,,,,,,,,,,,,, -47900,3.2798262,1.14565,,,,,,,,,,,,,, -48000,,,0.13066882,0.0462666625236926,0.41104025,0.1175357463529548,5348.0,0.21999335,0.0712530213474701,2472.0,40835.42251348496,45005.17563033104,40835.42251348496,4165.893843173981,1.620795726776123,0.0 -48000,,,,,,,,,,,40835.422513484955,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index bd9e8834e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -307.5564877986908,0.0,19.20242476463318,1,0,19.20242476463318,0.5214918851852417,0.7347381114959717,0.0268125009161917,43793,326.7589704990387,0.532486081123352,0.7277461886405945,0.0224462379105656,0.5230699777603149,0.7331878542900085,0.0255389908122996,43793 -426.1334116458893,0.0302038192749023,259.26010942459106,753,0,259.26010942459106,0.983142077922821,0.0804035291075706,0.0345979947513477,43793,685.4427721500397,0.986735999584198,0.0693851262331008,0.0327052996792796,0.9841179251670836,0.0775840729475021,0.033466476626962,43793 -547.5880465507507,0.0656647682189941,499.3460896015167,1504,0,499.3460896015167,0.9832296967506408,0.0641597285866737,0.0706124312681735,43793,1047.0379321575165,0.9868547320365906,0.051052525639534,0.0706327779101072,0.984196662902832,0.0607562810182571,0.0690115714268976,43793 -671.7086169719696,0.0926315784454345,739.3077754974365,2258,0,739.3077754974365,0.9839473962783812,0.0569424144923687,0.1268228510230024,43793,1411.1662957668304,0.9876076579093932,0.0449433661997318,0.1270505432485927,0.9849172234535216,0.0539407506585121,0.122546315511896,43793 -800.0604002475739,0.1198203563690185,979.2893161773682,2998,0,979.2893161773682,0.9841668605804444,0.0545504949986934,0.1515404133887407,43793,1779.545521259308,0.9880425930023192,0.0425593219697475,0.1539329837726181,0.9850654006004332,0.0518801398575305,0.1472381553313431,43793 -924.3276314735411,0.1471123695373535,1219.3918023109436,3740,0,1219.3918023109436,0.9844903349876404,0.0534924790263175,0.166235172026189,43793,2143.96270942688,0.9883670210838318,0.0407157242298126,0.1844254620145033,0.985426664352417,0.0505857840180397,0.1685717042001645,43793 -1055.3013689517977,0.1740467548370361,1459.5320043563845,4482,0,1459.5320043563845,0.9846583604812622,0.0519588366150856,0.181732471899558,43793,2515.1227877140045,0.9883249402046204,0.0400893315672874,0.2131278146848692,0.9855647087097168,0.0492448098957538,0.1808645836231969,43793 -1181.2931609153748,0.2058660984039306,1699.6076345443726,5209,0,1699.6076345443726,0.9847750663757324,0.0513728559017181,0.1873674543732953,43793,2881.247640132904,0.9886398315429688,0.0388185493648052,0.2206747630203164,0.9856414198875428,0.0487050414085388,0.1843567487809262,43793 -1306.7508709430697,0.2331297397613525,1939.6774501800537,5974,0,1939.6774501800537,0.984929621219635,0.0506270751357078,0.2034318188046622,43793,3246.82239818573,0.988920032978058,0.038038682192564,0.2460739881269172,0.9857940673828124,0.0480422861874103,0.1994409548683768,43793 -1432.3741779327393,0.2616612911224365,2179.744330406189,6743,0,2179.744330406189,0.984944760799408,0.0509179197251796,0.2028065330187146,43793,3612.562463760376,0.9888412952423096,0.0381081067025661,0.2423005532237125,0.9858776926994324,0.0480373539030551,0.204004536859809,43793 -1558.2514843940735,0.2884867191314697,2419.9255475997925,7515,0,2419.9255475997925,0.9852067828178406,0.0496997125446796,0.217869939416924,43793,3978.6678981781006,0.988990843296051,0.0374416075646877,0.2543271757587377,0.98616224527359,0.0469825714826583,0.215558121122398,43793 -1685.931292772293,0.3178873062133789,2660.0192317962646,8267,0,2660.0192317962646,0.9852614998817444,0.0492220148444175,0.214766806953296,43793,4346.496450185776,0.9891239404678344,0.0368091352283954,0.2656056616732034,0.9862223267555236,0.0464504174888134,0.2172402529453152,43793 -1814.795142650604,0.3454384803771972,2900.125126838684,9025,0,2900.125126838684,0.9852830171585084,0.0490747392177581,0.233778299310469,43793,4715.513872861862,0.989315629005432,0.0362415947020053,0.2720422042524107,0.9861720204353333,0.0463324151933193,0.2272592642219753,43793 -1943.7513513565063,0.3736414909362793,3140.291603088379,9783,0,3140.291603088379,0.985456109046936,0.0484712347388267,0.2270379080828656,43793,5084.685031175613,0.9895707368850708,0.0354536771774292,0.2919737580040339,0.9863116145133972,0.0457900017499923,0.2333110153993411,43793 -2066.56652712822,0.4013993740081787,3380.29745554924,10547,0,3380.29745554924,0.9855917096138,0.047913584858179,0.2333297489808317,43793,5447.554109573364,0.9897831678390504,0.034679215401411,0.3002367546046378,0.9863839149475098,0.0453971065580844,0.2397068494129548,43793 -2196.096314430237,0.4317901134490967,3620.2503366470337,11313,0,3620.2503366470337,0.9855576157569884,0.0481827370822429,0.2386666619664352,43793,5817.087311029434,0.9898716807365416,0.034160740673542,0.3269157247033578,0.9863802194595336,0.0452970378100872,0.2409442471012714,43793 -2325.3949341773987,0.4614150524139404,3860.471751451492,12067,0,3860.471751451492,0.9856814742088318,0.047646339982748,0.241180494655057,43793,6186.656934499741,0.9901003241539,0.0331801548600196,0.3477610696164991,0.986553966999054,0.0449106805026531,0.2418638094390883,43793 -2454.6309485435486,0.4894673824310303,4100.424854040146,12818,0,4100.424854040146,0.985745906829834,0.0477891266345977,0.2443888640165271,43793,6555.893530607224,0.9901209473609924,0.0326627865433692,0.3559880776898775,0.98660147190094,0.0450294651091098,0.2545477236240031,43793 -2584.913625717163,0.519852876663208,4340.415853261948,13570,0,4340.415853261948,0.9858596324920654,0.0472646281123161,0.2519613253121893,43793,6926.2187123298645,0.990494430065155,0.0316592827439308,0.3861857685105628,0.9867119193077089,0.0446790158748626,0.2552969826284998,43793 -2714.768377542496,0.5493123531341553,4580.423997163773,14326,0,4580.423997163773,0.985908031463623,0.0472202524542808,0.252440568707397,43793,7296.1311457157135,0.9905635118484496,0.0313184484839439,0.3779697651568524,0.9866960644721984,0.0443975627422332,0.2605003860303854,43793 -2843.13925409317,0.5789437294006348,4820.594982147217,15088,0,4820.594982147217,0.9859038591384888,0.0472093522548675,0.249374297031773,43793,7664.723286628723,0.9904943704605104,0.0314498618245124,0.3921774111864924,0.986743986606598,0.0443637333810329,0.2609996468968531,43793 -2970.6443254947662,0.6084580421447754,5060.655467748642,15848,0,5060.655467748642,0.9858655333518982,0.0474898181855678,0.2491323187341557,43793,8032.338857412338,0.990463137626648,0.0316627770662307,0.3854086649106583,0.9866436719894408,0.044806282967329,0.2567880837321566,43793 -3098.395025253296,0.6386611461639404,5300.637059926987,16610,0,5300.637059926987,0.9859662055969238,0.0469194948673248,0.2570484000552654,43793,8400.122673511505,0.9904116988182068,0.0317547656595706,0.3720946199853792,0.9867720007896424,0.0442337729036808,0.2636909103653799,43793 -3227.900470495224,0.668442964553833,5540.721517562866,17359,0,5540.721517562866,0.9858945608139038,0.0472031272947788,0.2509193630611767,43793,8769.763256072998,0.9904099106788636,0.031746108084917,0.377120961858111,0.9866794347763062,0.0444870181381702,0.2567422386746603,43793 -3359.1195571422577,0.6974725723266602,5780.809211492538,18114,0,5780.809211492538,0.9858099222183228,0.047458317130804,0.2474147599552491,43793,9141.121300935743,0.9906277656555176,0.0312679186463356,0.3893020073154302,0.9866157174110411,0.0446477569639682,0.2603163167861302,43793 -3489.12383890152,0.7305166721343994,6020.937862634659,18861,0,6020.937862634659,0.9858949780464172,0.047753270715475,0.2487265899364104,43793,9511.308304786682,0.990624189376831,0.0309741385281085,0.3934611392030673,0.9867175817489624,0.0448365472257137,0.2644804500767574,43793 -3618.504912137985,0.7599573135375977,6260.9293966293335,19617,0,6260.9293966293335,0.9857138991355896,0.0475363880395889,0.2473729470146687,43793,9880.730882644652,0.9908477663993835,0.0303032416850328,0.4078293290951283,0.9866339564323424,0.0445660911500453,0.2658152452066989,43793 -3745.8714208602905,0.7897212505340576,6501.003589630127,20378,0,6501.003589630127,0.985908031463623,0.0471081733703613,0.2536225909294436,43793,10248.222026586533,0.9909874200820924,0.0295681897550821,0.4334551141651773,0.986757755279541,0.0442159995436668,0.2665895056733096,43793 -3874.662131547928,0.8251934051513672,6741.249910831451,21131,0,6741.249910831451,0.9859185814857484,0.0475073382258415,0.2502607926465779,43793,10617.316989421844,0.9910106658935548,0.0294511504471302,0.4263613828291351,0.9867987632751464,0.0446124449372291,0.2653050845436331,43793 -4004.901304244995,0.8566954135894775,6981.331892490387,21888,0,6981.331892490387,0.985975444316864,0.0472784079611301,0.2519023897830796,43793,10987.689888238909,0.9911861419677734,0.0288388393819332,0.4450832660519178,0.9867784976959229,0.0443214364349842,0.2732250382415339,43793 -4132.4447610378265,0.8884739875793457,7221.436220884323,22651,0,7221.436220884323,0.985951006412506,0.0468320958316326,0.2558588523560185,43793,11355.389991521835,0.9912999868392944,0.0289224758744239,0.4529309025535197,0.9866639971733092,0.0441776476800441,0.2652814454767527,43793 -4260.824286222458,0.9203827381134032,7461.514452457428,23409,0,7461.514452457428,0.9860158562660216,0.0470472685992717,0.2603437330110667,43793,11723.900267839432,0.9911875128746032,0.0289385933429002,0.4468222720106902,0.98687344789505,0.0442219488322734,0.2659019283781217,43793 -4389.98766207695,1.248992681503296,7701.180008888245,24166,0,7701.180008888245,0.9858659505844116,0.0472520887851715,0.254193478821606,43793,12093.078660488129,0.9908936619758606,0.0299933366477489,0.4147140270965586,0.9866806268692015,0.0444285199046134,0.2666526009058164,43793 -4516.645851135254,1.2801299095153809,7941.191206932068,24915,0,7941.191206932068,0.9859838485717772,0.0469495318830013,0.2616324466083824,43793,12459.79927778244,0.9908816814422609,0.0299308989197015,0.4141630080073448,0.9867399334907532,0.0442919544875621,0.2720565043987341,43793 -4646.567767858505,1.3114888668060305,8181.43877363205,25671,0,8181.43877363205,0.985874354839325,0.0472862273454666,0.2556335293922175,43793,12830.020049333572,0.991015374660492,0.0294816028326749,0.4261409697864819,0.9866116046905518,0.0445607379078865,0.2697206290765936,43793 -4776.618052721024,1.3456156253814695,8421.660665512085,26422,0,8421.660665512085,0.9859535694122314,0.0478274375200271,0.2524827381782631,43793,13200.347779750824,0.9910532832145692,0.0291581321507692,0.4336587863156425,0.9867833256721495,0.0448125712573528,0.2679109911396861,43793 -4902.952599287033,1.3763277530670166,8661.61139369011,27176,0,8661.61139369011,0.9859000444412231,0.0473854392766952,0.2587810042911728,43793,13566.684311151505,0.9911747574806212,0.0289251338690519,0.4452254427445876,0.986806094646454,0.0443060398101806,0.2719736999539154,43793 -5034.454385757446,1.4076869487762451,8901.743763685226,27925,0,8901.743763685226,0.9859707951545716,0.0475916676223278,0.2610465782854482,43793,13938.37245965004,0.9912655353546144,0.0282652303576469,0.4763750571741873,0.9868454337120056,0.0445428192615509,0.2758354097559073,43793 -5161.768011569977,1.440685749053955,9141.946783781052,28676,0,9141.946783781052,0.9859851598739624,0.0473437346518039,0.2601059381326383,43793,14305.943809747696,0.99152410030365,0.0274818856269121,0.4806527878827299,0.986725687980652,0.0445623472332954,0.2700157613198305,43793 -5291.131680011749,1.4723448753356934,9382.060484170914,29433,0,9382.060484170914,0.9858773350715636,0.0478531457483768,0.2550559410887236,43793,14675.473503351212,0.9914757609367372,0.0276817418634891,0.4703134444485871,0.9867504835128784,0.0449969619512558,0.2625452748064528,43793 -5414.945363521576,1.5057015419006348,9622.095024824142,30189,0,9622.095024824142,0.9859526753425598,0.046930506825447,0.2677139955535795,43793,15039.375715255735,0.9917797446250916,0.0269236396998167,0.4778185300817284,0.986819863319397,0.0443765558302402,0.2750118181782855,43793 -5541.432502031326,1.537562131881714,9862.135322093964,30947,0,9862.135322093964,0.9860154390335084,0.0473567172884941,0.2568862944563916,43793,15405.955041885376,0.9916726350784302,0.0273495484143495,0.4843876213773574,0.986823558807373,0.0445965491235256,0.2758643690984573,43793 -5674.134570837021,1.574239730834961,10102.235368013382,31701,0,10102.235368013382,0.9858874082565308,0.0477321073412895,0.254013440712067,43793,15778.814150571823,0.9913953542709352,0.0280896257609128,0.4647237618074311,0.9867054224014282,0.0447933971881866,0.2632075575298583,43793 -5802.739213705063,1.6064386367797852,10342.45687031746,32453,0,10342.45687031746,0.9859194159507751,0.0471179857850074,0.2619408014981749,43793,16147.693119049072,0.9914029240608216,0.0280087292194366,0.4645824797357492,0.9867374897003174,0.0444966927170753,0.2662610658185099,43793 -5932.344577074051,1.6389267444610596,10582.4124584198,33209,0,10582.4124584198,0.985925316810608,0.0475473366677761,0.2548237538397603,43793,16517.30669927597,0.9913453459739684,0.028263833373785,0.4542848687332359,0.9869092106819152,0.0443997792899608,0.2762699325917176,43793 -6057.570507287979,1.6709973812103271,10822.565240621569,33967,0,10822.565240621569,0.9860167503356934,0.0483439229428768,0.2645479106855683,43793,16882.73788666725,0.991368293762207,0.0279243234544992,0.4556313674502258,0.986847460269928,0.0452116727828979,0.2753139830957439,43793 -6191.6519594192505,1.7039594650268557,11062.638055562971,34715,0,11062.638055562971,0.9859063625335692,0.0478287898004055,0.254685292772391,43793,17256.9448492527,0.9916467070579528,0.0272052455693483,0.477621432016776,0.9869270324707032,0.0444139018654823,0.2780114628194121,43793 -6315.94552397728,1.7400023937225342,11302.731583595276,35464,0,11302.731583595276,0.9860761165618896,0.0475175641477108,0.2631292091684028,43793,17621.38951563835,0.9917681217193604,0.0267904587090015,0.4988786142774347,0.986970067024231,0.0444306842982769,0.2751348822187652,43793 -6446.186127901077,1.7738418579101562,11542.948122262957,36217,0,11542.948122262957,0.9859569072723388,0.0481483377516269,0.2617889807945268,43793,17991.900566101074,0.9920303225517272,0.0256762281060218,0.5184381943115008,0.986899435520172,0.0450549945235252,0.2743834320358754,43793 -6571.886204242706,1.807714939117432,11782.979050397871,36976,0,11782.979050397871,0.9859548211097716,0.0478603355586528,0.2623279002542574,43793,18357.6857047081,0.9921303391456604,0.0254173502326011,0.5227613756415113,0.9869270324707032,0.0449628196656703,0.2739362064526262,43793 -6700.37837600708,1.8413872718811035,12023.23389339447,37725,0,12023.23389339447,0.9860011339187622,0.04807910323143,0.2635697863954901,43793,18726.489814043045,0.992214024066925,0.0252790115773677,0.532232625381326,0.9869457483291626,0.0449865348637104,0.2753142807744727,43793 -6825.7874138355255,1.8750951290130613,12263.267944574356,38488,0,12263.267944574356,0.9859851598739624,0.0478601045906543,0.2579441179379427,43793,19091.986697912216,0.9921120405197144,0.0254794359207153,0.5159274884360091,0.9868957996368408,0.0449107587337493,0.2744865870059713,43793 -6951.847330093384,1.91015625,12503.376068115234,39243,0,12503.376068115234,0.9859825968742372,0.0481785349547863,0.2605677246204606,43793,19458.20954990387,0.9920024275779724,0.0260459966957569,0.5064120084947242,0.986961543560028,0.0450559519231319,0.2774751291319751,43793 -7076.241222858429,1.944312334060669,12743.403692007065,40004,0,12743.403692007065,0.9859552383422852,0.0483670085668563,0.2582643340499317,43793,19822.685611486435,0.9919900298118592,0.0259577762335538,0.5071904458124843,0.9868775010108948,0.0453401543200016,0.2724499825709285,43793 -7204.345838546753,1.97922158241272,12983.375659227371,40763,0,12983.375659227371,0.9860268235206604,0.0480281524360179,0.2584114993433583,43793,20190.81739640236,0.9918793439865112,0.0262201521545648,0.4992228026531985,0.9869359731674194,0.0450229682028293,0.2775892452836066,43793 -7328.345608234405,2.01588773727417,13223.54290318489,41521,0,13223.54290318489,0.9859375357627868,0.0488403737545013,0.2549508720399013,43793,20555.041393518448,0.9919157028198242,0.0261086784303188,0.5037610390684488,0.986825168132782,0.0457451306283474,0.2686163429857319,43793 -7454.671268939972,2.0549404621124268,13463.506523132324,42273,0,13463.506523132324,0.985990583896637,0.048225313425064,0.2598064917046744,43793,20921.390644073486,0.9921371936798096,0.0254248585551977,0.517208637472588,0.986750066280365,0.0452563427388668,0.2774649701945903,43793 -7583.163735151291,2.093395233154297,13703.503736972809,43031,0,13703.503736972809,0.9858831763267516,0.0480500534176826,0.2615175880406216,43793,21289.939838171005,0.9923945665359496,0.0246466323733329,0.5287445837948694,0.9867159724235536,0.0451500564813613,0.2757726625150504,43793 -7711.316171169281,2.129408359527588,13943.55898284912,43777,0,13943.55898284912,0.986004114151001,0.0484849773347377,0.2589780692558592,43793,21658.204888105392,0.9924582839012146,0.0242256261408329,0.5548626450600167,0.9868783354759216,0.0454231277108192,0.2750619330472643,43793 -7834.429327011108,2.1637284755706787,14183.621313095093,44533,0,14183.621313095093,0.9860752820968628,0.0484214052557945,0.2575476179842053,43793,22021.435887813568,0.9928352236747742,0.0230323988944292,0.582281300015573,0.9868206977844238,0.0455101355910301,0.2740454856114717,43793 -7965.456509590149,2.204008817672729,14423.575603961945,45269,0,14423.575603961945,0.98585706949234,0.0489467233419418,0.2580098287432119,43793,22392.48312997818,0.9929222464561462,0.0229491274803876,0.5656440620216915,0.986701726913452,0.0458743907511234,0.2703104717691669,43793 -8088.594908952713,2.2400035858154297,14663.548845529556,46022,0,14663.548845529556,0.9859346151351928,0.0487636588513851,0.2573021483270661,43793,22755.651258468628,0.9929423928260804,0.022848380729556,0.5816965987814018,0.9868153929710388,0.0457025654613971,0.2774928305678912,43793 -8213.572768211365,2.277043104171753,14903.751368761064,46775,0,14903.751368761064,0.9860348105430604,0.0492957942187786,0.2616057469839685,43793,23120.890134334564,0.9926509261131288,0.023398483172059,0.5500882324221318,0.9869351387023926,0.0461748838424682,0.2759709745988475,43793 -8337.250361442566,2.31246018409729,15143.801016807556,47526,0,15143.801016807556,0.985968291759491,0.0490897744894027,0.2601720723494071,43793,23484.67352104187,0.9926035404205322,0.0236497167497873,0.5587826948908996,0.9867784976959229,0.0460036359727382,0.2770768478543392,43793 -8461.005041599274,2.348806619644165,15383.883342027664,48271,0,15383.883342027664,0.986004114151001,0.0490064099431037,0.2604601451995874,43793,23848.569514989853,0.9926998019218444,0.0234197005629539,0.5763855108885503,0.9868454337120056,0.0459731742739677,0.2784930722589843,43793 -8586.831431388855,2.384529590606689,15624.072097301483,49033,0,15624.072097301483,0.9859733581542968,0.0493910759687423,0.2595510327017606,43793,24214.64115262032,0.9927592277526855,0.0232525151222944,0.5598052472954665,0.9868369102478028,0.0462203919887542,0.276565655192867,43793 -8711.369632005692,2.420367956161499,15864.111089468002,49787,0,15864.111089468002,0.9858781695365906,0.0498117581009864,0.2538513804026239,43793,24579.27449965477,0.992779016494751,0.023049347102642,0.557587615426931,0.9867760539054872,0.0467174611985683,0.2682286383885479,43793 -8832.70336985588,2.456578016281128,16104.149505615234,50541,0,16104.149505615234,0.9860200881958008,0.0497410297393798,0.2600190963165037,43793,24940.70306921005,0.9930933713912964,0.0218581091612577,0.5951135774073243,0.986832857131958,0.0467489995062351,0.2769980029500767,43793 -8958.184759140015,2.49343204498291,16344.100379228592,51297,0,16344.100379228592,0.985889494419098,0.0497760213911533,0.2609768772472949,43793,25306.192680835724,0.9931976199150084,0.0217012390494346,0.591113654143453,0.9867837429046632,0.0467196889221668,0.276880985821581,43793 -9081.7381067276,2.5313422679901123,16584.350385904312,52052,0,16584.350385904312,0.985987663269043,0.0499919690191745,0.2584287179698735,43793,25670.05592918396,0.9936760067939758,0.0202678125351667,0.6317658547986209,0.9868324398994446,0.0467768274247646,0.2784929432021214,43793 -9207.574704885485,2.5672249794006348,16824.390946626663,52811,0,16824.390946626663,0.9858625531196594,0.0507209710776805,0.2562704321462084,43793,26035.98972201348,0.9940836429595948,0.0191185437142848,0.6474843811104636,0.9867330193519592,0.0475767217576503,0.2701701682771992,43793 -9331.2084608078,2.6037094593048096,17064.626448631287,53562,0,17064.626448631287,0.985888659954071,0.0507962331175804,0.2548926207726321,43793,26399.916737556458,0.9939260482788086,0.0195813234895467,0.6418360793553928,0.9867784976959229,0.0476591289043426,0.2710364539945412,43793 -9452.654334545135,2.644875288009644,17304.80330300331,54304,0,17304.80330300331,0.9858903884887696,0.0509893335402011,0.2528467102063643,43793,26761.601624012,0.9939098954200744,0.0195740088820457,0.6298671260120975,0.986743986606598,0.0479759089648723,0.2708166189977605,43793 -9579.312143564224,2.6809160709381104,17544.856785297394,55063,0,17544.856785297394,0.9858301281929016,0.050938531756401,0.2527652640350336,43793,27128.370005607605,0.9937072992324828,0.0202469173818826,0.6290030533487323,0.9866514205932616,0.0477222166955471,0.2730129936795538,43793 -9701.83114695549,2.7188405990600586,17784.952792406082,55815,0,17784.952792406082,0.9858187437057496,0.0514654070138931,0.2534089531791236,43793,27491.043491363525,0.9936484694480896,0.0201473720371723,0.6260407024076896,0.9866623878479004,0.0481687262654304,0.2720521240646281,43793 -9828.160341501236,2.756070375442505,18025.06869673729,56567,0,18025.06869673729,0.985710084438324,0.0522744171321392,0.2540701946363075,43793,27857.546080112457,0.993515133857727,0.0203472040593624,0.6213367212781136,0.9865515232086182,0.0489356853067874,0.2671761914545285,43793 -9950.428166866302,2.793794870376587,18265.143191337585,57320,0,18265.143191337585,0.985775351524353,0.05271177738904953,0.2552477887895668,43793,28219.94671702385,0.9935488104820251,0.020334692671895027,0.611723780806605,0.9865474700927734,0.049405910074710846,0.26607081788312853,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index e78b2a447..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,659 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.529245,0.72697896,,,,,,,,,,,,,,,,, -1,,,0.532486081123352,0.7277461886405945,0.0224462379105656,0.5230699777603149,0.7331878542900085,0.0255389908122996,43793.0,0.5214918851852417,0.7347381114959717,0.0268125009161917,43793.0,19.20242476463318,326.7589704990387,19.20242476463318,307.5564877986908,0.0,0.0 -100,0.6413539,0.44322443,,,,,,,,,,,,,,,,, -200,0.36645594,0.33015168,,,,,,,,,,,,,,,,, -300,0.26892924,0.2364992,,,,,,,,,,,,,,,,, -400,0.17710057,0.16276665,,,,,,,,,,,,,,,,, -500,0.11002149,0.11727608,,,,,,,,,,,,,,,,, -600,0.07019315,0.08634268,,,,,,,,,,,,,,,,, -700,0.045417,0.073882565,,,,,,,,,,,,,,,,, -753,,,0.986735999584198,0.0693851262331008,0.0327052996792796,0.9841179251670836,0.0775840729475021,0.033466476626962,43793.0,0.983142077922821,0.0804035291075706,0.0345979947513477,43793.0,259.26010942459106,685.4427721500397,259.26010942459106,426.1334116458893,0.0302038192749023,0.0 -800,0.035962455,0.06578028,,,,,,,,,,,,,,,,, -900,0.15797465,0.06229969,,,,,,,,,,,,,,,,, -1000,0.1458265,0.05729956,,,,,,,,,,,,,,,,, -1100,0.2224564,0.054426048,,,,,,,,,,,,,,,,, -1200,0.07493934,0.05002374,,,,,,,,,,,,,,,,, -1300,0.06657913,0.055273395,,,,,,,,,,,,,,,,, -1400,0.12239163,0.047961496,,,,,,,,,,,,,,,,, -1500,0.18562612,0.0562476,,,,,,,,,,,,,,,,, -1504,,,0.9868547320365906,0.051052525639534,0.0706327779101072,0.984196662902832,0.0607562810182571,0.0690115714268976,43793.0,0.9832296967506408,0.0641597285866737,0.0706124312681735,43793.0,499.3460896015167,1047.0379321575165,499.3460896015167,547.5880465507507,0.0656647682189941,0.0 -1600,0.077732734,0.049110703,,,,,,,,,,,,,,,,, -1700,0.21938302,0.054803528,,,,,,,,,,,,,,,,, -1800,0.05747898,0.045861464,,,,,,,,,,,,,,,,, -1900,0.22622846,0.046887353,,,,,,,,,,,,,,,,, -2000,0.13769892,0.048361514,,,,,,,,,,,,,,,,, -2100,0.07813229,0.048902303,,,,,,,,,,,,,,,,, -2200,0.08744592,0.050181538,,,,,,,,,,,,,,,,, -2258,,,0.9876076579093932,0.0449433661997318,0.1270505432485927,0.9849172234535216,0.0539407506585121,0.122546315511896,43793.0,0.9839473962783812,0.0569424144923687,0.1268228510230024,43793.0,739.3077754974365,1411.1662957668304,739.3077754974365,671.7086169719696,0.0926315784454345,0.0 -2300,0.13269427,0.05107642,,,,,,,,,,,,,,,,, -2400,0.071914114,0.04317651,,,,,,,,,,,,,,,,, -2500,0.2505563,0.046003744,,,,,,,,,,,,,,,,, -2600,0.12305068,0.042262394,,,,,,,,,,,,,,,,, -2700,0.093977936,0.042800374,,,,,,,,,,,,,,,,, -2800,0.10121738,0.044663318,,,,,,,,,,,,,,,,, -2900,0.11531147,0.049549814,,,,,,,,,,,,,,,,, -2998,,,0.9880425930023192,0.0425593219697475,0.1539329837726181,0.9850654006004332,0.0518801398575305,0.1472381553313431,43793.0,0.9841668605804444,0.0545504949986934,0.1515404133887407,43793.0,979.2893161773682,1779.545521259308,979.2893161773682,800.0604002475739,0.1198203563690185,0.0 -3000,0.07032323,0.04484592,,,,,,,,,,,,,,,,, -3100,0.11637143,0.04280544,,,,,,,,,,,,,,,,, -3200,0.118514426,0.044021916,,,,,,,,,,,,,,,,, -3300,0.07831112,0.042780574,,,,,,,,,,,,,,,,, -3400,0.06582521,0.039884847,,,,,,,,,,,,,,,,, -3500,0.10437279,0.042292736,,,,,,,,,,,,,,,,, -3600,0.119418226,0.0415783,,,,,,,,,,,,,,,,, -3700,0.05422346,0.043002926,,,,,,,,,,,,,,,,, -3740,,,0.9883670210838318,0.0407157242298126,0.1844254620145033,0.985426664352417,0.0505857840180397,0.1685717042001645,43793.0,0.9844903349876404,0.0534924790263175,0.166235172026189,43793.0,1219.3918023109436,2143.96270942688,1219.3918023109436,924.3276314735411,0.1471123695373535,0.0 -3800,0.050042197,0.039587718,,,,,,,,,,,,,,,,, -3900,0.14556271,0.042054985,,,,,,,,,,,,,,,,, -4000,0.074051924,0.045187734,,,,,,,,,,,,,,,,, -4100,0.07570779,0.04218684,,,,,,,,,,,,,,,,, -4200,0.047673296,0.038831197,,,,,,,,,,,,,,,,, -4300,0.101561144,0.043088425,,,,,,,,,,,,,,,,, -4400,0.043982275,0.040604845,,,,,,,,,,,,,,,,, -4482,,,0.9883249402046204,0.0400893315672874,0.2131278146848692,0.9855647087097168,0.0492448098957538,0.1808645836231969,43793.0,0.9846583604812622,0.0519588366150856,0.181732471899558,43793.0,1459.5320043563845,2515.1227877140045,1459.5320043563845,1055.3013689517977,0.1740467548370361,0.0 -4500,0.05279169,0.042923342,,,,,,,,,,,,,,,,, -4600,0.079356626,0.045454007,,,,,,,,,,,,,,,,, -4700,0.06072202,0.038605813,,,,,,,,,,,,,,,,, -4800,0.069105685,0.04193022,,,,,,,,,,,,,,,,, -4900,0.06130122,0.03956609,,,,,,,,,,,,,,,,, -5000,0.06665625,0.042551808,,,,,,,,,,,,,,,,, -5100,0.042627085,0.039126735,,,,,,,,,,,,,,,,, -5200,0.04094102,0.042777013,,,,,,,,,,,,,,,,, -5209,,,0.9886398315429688,0.0388185493648052,0.2206747630203164,0.9856414198875428,0.0487050414085388,0.1843567487809262,43793.0,0.9847750663757324,0.0513728559017181,0.1873674543732953,43793.0,1699.6076345443726,2881.247640132904,1699.6076345443726,1181.2931609153748,0.2058660984039306,0.0 -5300,0.04572012,0.040005542,,,,,,,,,,,,,,,,, -5400,0.06355058,0.04738893,,,,,,,,,,,,,,,,, -5500,0.035224628,0.042610094,,,,,,,,,,,,,,,,, -5600,0.05705489,0.042691935,,,,,,,,,,,,,,,,, -5700,0.054789793,0.036223613,,,,,,,,,,,,,,,,, -5800,0.11291303,0.04131167,,,,,,,,,,,,,,,,, -5900,0.032009535,0.038927656,,,,,,,,,,,,,,,,, -5974,,,0.988920032978058,0.038038682192564,0.2460739881269172,0.9857940673828124,0.0480422861874103,0.1994409548683768,43793.0,0.984929621219635,0.0506270751357078,0.2034318188046622,43793.0,1939.6774501800537,3246.82239818573,1939.6774501800537,1306.7508709430697,0.2331297397613525,0.0 -6000,0.05902141,0.037372336,,,,,,,,,,,,,,,,, -6100,0.050062764,0.036348455,,,,,,,,,,,,,,,,, -6200,0.033746522,0.03938482,,,,,,,,,,,,,,,,, -6300,0.028767187,0.039302476,,,,,,,,,,,,,,,,, -6400,0.040156353,0.040618747,,,,,,,,,,,,,,,,, -6500,0.041659415,0.040134493,,,,,,,,,,,,,,,,, -6600,0.030784529,0.038007893,,,,,,,,,,,,,,,,, -6700,0.033776756,0.0418758,,,,,,,,,,,,,,,,, -6743,,,0.9888412952423096,0.0381081067025661,0.2423005532237125,0.9858776926994324,0.0480373539030551,0.204004536859809,43793.0,0.984944760799408,0.0509179197251796,0.2028065330187146,43793.0,2179.744330406189,3612.562463760376,2179.744330406189,1432.3741779327393,0.2616612911224365,0.0 -6800,0.022529962,0.0384616,,,,,,,,,,,,,,,,, -6900,0.03518985,0.039586574,,,,,,,,,,,,,,,,, -7000,0.032192886,0.039798107,,,,,,,,,,,,,,,,, -7100,0.030501794,0.04418536,,,,,,,,,,,,,,,,, -7200,0.029285293,0.038511217,,,,,,,,,,,,,,,,, -7300,0.038993374,0.039420858,,,,,,,,,,,,,,,,, -7400,0.053721465,0.0429776,,,,,,,,,,,,,,,,, -7500,0.031862978,0.0418695,,,,,,,,,,,,,,,,, -7515,,,0.988990843296051,0.0374416075646877,0.2543271757587377,0.98616224527359,0.0469825714826583,0.215558121122398,43793.0,0.9852067828178406,0.0496997125446796,0.217869939416924,43793.0,2419.9255475997925,3978.6678981781006,2419.9255475997925,1558.2514843940735,0.2884867191314697,0.0 -7600,0.039960675,0.04109507,,,,,,,,,,,,,,,,, -7700,0.02997838,0.037924305,,,,,,,,,,,,,,,,, -7800,0.05045916,0.040044792,,,,,,,,,,,,,,,,, -7900,0.035064373,0.038285654,,,,,,,,,,,,,,,,, -8000,0.027464071,0.04163217,,,,,,,,,,,,,,,,, -8100,0.024216667,0.039952926,,,,,,,,,,,,,,,,, -8200,0.03501683,0.036356863,,,,,,,,,,,,,,,,, -8267,,,0.9891239404678344,0.0368091352283954,0.2656056616732034,0.9862223267555236,0.0464504174888134,0.2172402529453152,43793.0,0.9852614998817444,0.0492220148444175,0.214766806953296,43793.0,2660.0192317962646,4346.496450185776,2660.0192317962646,1685.931292772293,0.3178873062133789,0.0 -8300,0.020814734,0.039086368,,,,,,,,,,,,,,,,, -8400,0.022604128,0.03989181,,,,,,,,,,,,,,,,, -8500,0.023051469,0.037680786,,,,,,,,,,,,,,,,, -8600,0.028100122,0.03752898,,,,,,,,,,,,,,,,, -8700,0.02519499,0.03776733,,,,,,,,,,,,,,,,, -8800,0.02491354,0.041052643,,,,,,,,,,,,,,,,, -8900,0.025491694,0.040431444,,,,,,,,,,,,,,,,, -9000,0.0214753,0.041938964,,,,,,,,,,,,,,,,, -9025,,,0.989315629005432,0.0362415947020053,0.2720422042524107,0.9861720204353333,0.0463324151933193,0.2272592642219753,43793.0,0.9852830171585084,0.0490747392177581,0.233778299310469,43793.0,2900.125126838684,4715.513872861862,2900.125126838684,1814.795142650604,0.3454384803771972,0.0 -9100,0.022482045,0.03911855,,,,,,,,,,,,,,,,, -9200,0.020574598,0.03924833,,,,,,,,,,,,,,,,, -9300,0.050307553,0.04026924,,,,,,,,,,,,,,,,, -9400,0.024572063,0.03944726,,,,,,,,,,,,,,,,, -9500,0.024822053,0.037296772,,,,,,,,,,,,,,,,, -9600,0.03328547,0.032945782,,,,,,,,,,,,,,,,, -9700,0.03385912,0.03889129,,,,,,,,,,,,,,,,, -9783,,,0.9895707368850708,0.0354536771774292,0.2919737580040339,0.9863116145133972,0.0457900017499923,0.2333110153993411,43793.0,0.985456109046936,0.0484712347388267,0.2270379080828656,43793.0,3140.291603088379,5084.685031175613,3140.291603088379,1943.7513513565063,0.3736414909362793,0.0 -9800,0.039308958,0.04085786,,,,,,,,,,,,,,,,, -9900,0.042702854,0.041072704,,,,,,,,,,,,,,,,, -10000,0.030441582,0.035154127,,,,,,,,,,,,,,,,, -10100,0.033542167,0.038010966,,,,,,,,,,,,,,,,, -10200,0.033077125,0.03787434,,,,,,,,,,,,,,,,, -10300,0.04011439,0.041261885,,,,,,,,,,,,,,,,, -10400,0.043045692,0.039054926,,,,,,,,,,,,,,,,, -10500,0.0496105,0.04264045,,,,,,,,,,,,,,,,, -10547,,,0.9897831678390504,0.034679215401411,0.3002367546046378,0.9863839149475098,0.0453971065580844,0.2397068494129548,43793.0,0.9855917096138,0.047913584858179,0.2333297489808317,43793.0,3380.29745554924,5447.554109573364,3380.29745554924,2066.56652712822,0.4013993740081787,0.0 -10600,0.019308899,0.035304245,,,,,,,,,,,,,,,,, -10700,0.040731974,0.041333575,,,,,,,,,,,,,,,,, -10800,0.030875038,0.03542026,,,,,,,,,,,,,,,,, -10900,0.03233214,0.039953373,,,,,,,,,,,,,,,,, -11000,0.030338224,0.04133819,,,,,,,,,,,,,,,,, -11100,0.03339168,0.037359595,,,,,,,,,,,,,,,,, -11200,0.024089761,0.037405428,,,,,,,,,,,,,,,,, -11300,0.030799568,0.037171435,,,,,,,,,,,,,,,,, -11313,,,0.9898716807365416,0.034160740673542,0.3269157247033578,0.9863802194595336,0.0452970378100872,0.2409442471012714,43793.0,0.9855576157569884,0.0481827370822429,0.2386666619664352,43793.0,3620.2503366470337,5817.087311029434,3620.2503366470337,2196.096314430237,0.4317901134490967,0.0 -11400,0.024554228,0.034808796,,,,,,,,,,,,,,,,, -11500,0.025374478,0.037547067,,,,,,,,,,,,,,,,, -11600,0.027101971,0.037234902,,,,,,,,,,,,,,,,, -11700,0.0609472,0.037411764,,,,,,,,,,,,,,,,, -11800,0.039455123,0.034859434,,,,,,,,,,,,,,,,, -11900,0.02967416,0.03900902,,,,,,,,,,,,,,,,, -12000,0.03048971,0.039455224,,,,,,,,,,,,,,,,, -12067,,,0.9901003241539,0.0331801548600196,0.3477610696164991,0.986553966999054,0.0449106805026531,0.2418638094390883,43793.0,0.9856814742088318,0.047646339982748,0.241180494655057,43793.0,3860.471751451492,6186.656934499741,3860.471751451492,2325.3949341773987,0.4614150524139404,0.0 -12100,0.032861203,0.03867651,,,,,,,,,,,,,,,,, -12200,0.056718424,0.03699907,,,,,,,,,,,,,,,,, -12300,0.051211763,0.034429267,,,,,,,,,,,,,,,,, -12400,0.045121364,0.033312712,,,,,,,,,,,,,,,,, -12500,0.03326269,0.035583735,,,,,,,,,,,,,,,,, -12600,0.051872093,0.040356774,,,,,,,,,,,,,,,,, -12700,0.04356654,0.036720403,,,,,,,,,,,,,,,,, -12800,0.044012416,0.03784927,,,,,,,,,,,,,,,,, -12818,,,0.9901209473609924,0.0326627865433692,0.3559880776898775,0.98660147190094,0.0450294651091098,0.2545477236240031,43793.0,0.985745906829834,0.0477891266345977,0.2443888640165271,43793.0,4100.424854040146,6555.893530607224,4100.424854040146,2454.6309485435486,0.4894673824310303,0.0 -12900,0.02871989,0.033619512,,,,,,,,,,,,,,,,, -13000,0.043419898,0.03432855,,,,,,,,,,,,,,,,, -13100,0.052617256,0.040544752,,,,,,,,,,,,,,,,, -13200,0.050872397,0.03480863,,,,,,,,,,,,,,,,, -13300,0.03500018,0.03470618,,,,,,,,,,,,,,,,, -13400,0.08149761,0.041387744,,,,,,,,,,,,,,,,, -13500,0.044211455,0.031670775,,,,,,,,,,,,,,,,, -13570,,,0.990494430065155,0.0316592827439308,0.3861857685105628,0.9867119193077089,0.0446790158748626,0.2552969826284998,43793.0,0.9858596324920654,0.0472646281123161,0.2519613253121893,43793.0,4340.415853261948,6926.2187123298645,4340.415853261948,2584.913625717163,0.519852876663208,0.0 -13600,0.032666758,0.034786172,,,,,,,,,,,,,,,,, -13700,0.04634076,0.038550045,,,,,,,,,,,,,,,,, -13800,0.03454731,0.038178615,,,,,,,,,,,,,,,,, -13900,0.04921812,0.03592826,,,,,,,,,,,,,,,,, -14000,0.050231125,0.03551374,,,,,,,,,,,,,,,,, -14100,0.036499545,0.035197545,,,,,,,,,,,,,,,,, -14200,0.063081495,0.033420905,,,,,,,,,,,,,,,,, -14300,0.044942085,0.034538604,,,,,,,,,,,,,,,,, -14326,,,0.9905635118484496,0.0313184484839439,0.3779697651568524,0.9866960644721984,0.0443975627422332,0.2605003860303854,43793.0,0.985908031463623,0.0472202524542808,0.252440568707397,43793.0,4580.423997163773,7296.1311457157135,4580.423997163773,2714.768377542496,0.5493123531341553,0.0 -14400,0.046437126,0.03733667,,,,,,,,,,,,,,,,, -14500,0.0739381,0.032971557,,,,,,,,,,,,,,,,, -14600,0.052473567,0.03534728,,,,,,,,,,,,,,,,, -14700,0.041699775,0.033889152,,,,,,,,,,,,,,,,, -14800,0.059531137,0.03748376,,,,,,,,,,,,,,,,, -14900,0.044973902,0.03888281,,,,,,,,,,,,,,,,, -15000,0.046161436,0.037286397,,,,,,,,,,,,,,,,, -15088,,,0.9904943704605104,0.0314498618245124,0.3921774111864924,0.986743986606598,0.0443637333810329,0.2609996468968531,43793.0,0.9859038591384888,0.0472093522548675,0.249374297031773,43793.0,4820.594982147217,7664.723286628723,4820.594982147217,2843.13925409317,0.5789437294006348,0.0 -15100,0.0437633,0.035863586,,,,,,,,,,,,,,,,, -15200,0.055940483,0.035215806,,,,,,,,,,,,,,,,, -15300,0.045801446,0.036887627,,,,,,,,,,,,,,,,, -15400,0.05926455,0.038570233,,,,,,,,,,,,,,,,, -15500,0.04604901,0.035560604,,,,,,,,,,,,,,,,, -15600,0.060394943,0.038514134,,,,,,,,,,,,,,,,, -15700,0.04591761,0.034515698,,,,,,,,,,,,,,,,, -15800,0.046775356,0.033594187,,,,,,,,,,,,,,,,, -15848,,,0.990463137626648,0.0316627770662307,0.3854086649106583,0.9866436719894408,0.044806282967329,0.2567880837321566,43793.0,0.9858655333518982,0.0474898181855678,0.2491323187341557,43793.0,5060.655467748642,8032.338857412338,5060.655467748642,2970.6443254947662,0.6084580421447754,0.0 -15900,0.054409903,0.035529897,,,,,,,,,,,,,,,,, -16000,0.052090902,0.036142033,,,,,,,,,,,,,,,,, -16100,0.03935905,0.034970485,,,,,,,,,,,,,,,,, -16200,0.047558587,0.034787975,,,,,,,,,,,,,,,,, -16300,0.057446882,0.032181412,,,,,,,,,,,,,,,,, -16400,0.047706004,0.03489316,,,,,,,,,,,,,,,,, -16500,0.0819977,0.033419944,,,,,,,,,,,,,,,,, -16600,0.056789443,0.035375427,,,,,,,,,,,,,,,,, -16610,,,0.9904116988182068,0.0317547656595706,0.3720946199853792,0.9867720007896424,0.0442337729036808,0.2636909103653799,43793.0,0.9859662055969238,0.0469194948673248,0.2570484000552654,43793.0,5300.637059926987,8400.122673511505,5300.637059926987,3098.395025253296,0.6386611461639404,0.0 -16700,0.061720114,0.03697591,,,,,,,,,,,,,,,,, -16800,0.076414205,0.036425754,,,,,,,,,,,,,,,,, -16900,0.0938237,0.03727424,,,,,,,,,,,,,,,,, -17000,0.052310124,0.037111554,,,,,,,,,,,,,,,,, -17100,0.08285806,0.03578136,,,,,,,,,,,,,,,,, -17200,0.0630014,0.036468327,,,,,,,,,,,,,,,,, -17300,0.07349486,0.03222817,,,,,,,,,,,,,,,,, -17359,,,0.9904099106788636,0.031746108084917,0.377120961858111,0.9866794347763062,0.0444870181381702,0.2567422386746603,43793.0,0.9858945608139038,0.0472031272947788,0.2509193630611767,43793.0,5540.721517562866,8769.763256072998,5540.721517562866,3227.900470495224,0.668442964553833,0.0 -17400,0.081163675,0.036310315,,,,,,,,,,,,,,,,, -17500,0.08435263,0.038387287,,,,,,,,,,,,,,,,, -17600,0.056553062,0.03460254,,,,,,,,,,,,,,,,, -17700,0.05296621,0.034661345,,,,,,,,,,,,,,,,, -17800,0.07938159,0.0372841,,,,,,,,,,,,,,,,, -17900,0.09173659,0.03719428,,,,,,,,,,,,,,,,, -18000,0.09713211,0.037304543,,,,,,,,,,,,,,,,, -18100,0.084731,0.03632994,,,,,,,,,,,,,,,,, -18114,,,0.9906277656555176,0.0312679186463356,0.3893020073154302,0.9866157174110411,0.0446477569639682,0.2603163167861302,43793.0,0.9858099222183228,0.047458317130804,0.2474147599552491,43793.0,5780.809211492538,9141.121300935743,5780.809211492538,3359.1195571422577,0.6974725723266602,0.0 -18200,0.06159994,0.035017475,,,,,,,,,,,,,,,,, -18300,0.110146634,0.03569388,,,,,,,,,,,,,,,,, -18400,0.08431666,0.036400266,,,,,,,,,,,,,,,,, -18500,0.055442344,0.033091612,,,,,,,,,,,,,,,,, -18600,0.05431134,0.035762176,,,,,,,,,,,,,,,,, -18700,0.067169905,0.030938286,,,,,,,,,,,,,,,,, -18800,0.051578943,0.03295848,,,,,,,,,,,,,,,,, -18861,,,0.990624189376831,0.0309741385281085,0.3934611392030673,0.9867175817489624,0.0448365472257137,0.2644804500767574,43793.0,0.9858949780464172,0.047753270715475,0.2487265899364104,43793.0,6020.937862634659,9511.308304786682,6020.937862634659,3489.12383890152,0.7305166721343994,0.0 -18900,0.06268822,0.03484129,,,,,,,,,,,,,,,,, -19000,0.07419956,0.037316673,,,,,,,,,,,,,,,,, -19100,0.06812355,0.034271356,,,,,,,,,,,,,,,,, -19200,0.08419484,0.031815533,,,,,,,,,,,,,,,,, -19300,0.07028956,0.03448947,,,,,,,,,,,,,,,,, -19400,0.05871855,0.033195484,,,,,,,,,,,,,,,,, -19500,0.060457338,0.031983763,,,,,,,,,,,,,,,,, -19600,0.13445778,0.03655058,,,,,,,,,,,,,,,,, -19617,,,0.9908477663993835,0.0303032416850328,0.4078293290951283,0.9866339564323424,0.0445660911500453,0.2658152452066989,43793.0,0.9857138991355896,0.0475363880395889,0.2473729470146687,43793.0,6260.9293966293335,9880.730882644652,6260.9293966293335,3618.504912137985,0.7599573135375977,0.0 -19700,0.0722915,0.03479411,,,,,,,,,,,,,,,,, -19800,0.07900839,0.03739162,,,,,,,,,,,,,,,,, -19900,0.0661829,0.032461267,,,,,,,,,,,,,,,,, -20000,0.08453525,0.03521987,,,,,,,,,,,,,,,,, -20100,0.07101946,0.032479204,,,,,,,,,,,,,,,,, -20200,0.051112648,0.03243433,,,,,,,,,,,,,,,,, -20300,0.07296529,0.03284597,,,,,,,,,,,,,,,,, -20378,,,0.9909874200820924,0.0295681897550821,0.4334551141651773,0.986757755279541,0.0442159995436668,0.2665895056733096,43793.0,0.985908031463623,0.0471081733703613,0.2536225909294436,43793.0,6501.003589630127,10248.222026586533,6501.003589630127,3745.8714208602905,0.7897212505340576,0.0 -20400,0.06352527,0.03292507,,,,,,,,,,,,,,,,, -20500,0.08524453,0.036042634,,,,,,,,,,,,,,,,, -20600,0.094614394,0.029927244,,,,,,,,,,,,,,,,, -20700,0.105118774,0.0357746,,,,,,,,,,,,,,,,, -20800,0.08484981,0.03474385,,,,,,,,,,,,,,,,, -20900,0.06605669,0.03303686,,,,,,,,,,,,,,,,, -21000,0.05987664,0.035502568,,,,,,,,,,,,,,,,, -21100,0.072621904,0.032562025,,,,,,,,,,,,,,,,, -21131,,,0.9910106658935548,0.0294511504471302,0.4263613828291351,0.9867987632751464,0.0446124449372291,0.2653050845436331,43793.0,0.9859185814857484,0.0475073382258415,0.2502607926465779,43793.0,6741.249910831451,10617.316989421844,6741.249910831451,3874.662131547928,0.8251934051513672,0.0 -21200,0.06590269,0.032906864,,,,,,,,,,,,,,,,, -21300,0.1063307,0.03418622,,,,,,,,,,,,,,,,, -21400,0.057231627,0.034914076,,,,,,,,,,,,,,,,, -21500,0.13959408,0.032532312,,,,,,,,,,,,,,,,, -21600,0.060612697,0.036379293,,,,,,,,,,,,,,,,, -21700,0.088344835,0.03331294,,,,,,,,,,,,,,,,, -21800,0.061168324,0.03381445,,,,,,,,,,,,,,,,, -21888,,,0.9911861419677734,0.0288388393819332,0.4450832660519178,0.9867784976959229,0.0443214364349842,0.2732250382415339,43793.0,0.985975444316864,0.0472784079611301,0.2519023897830796,43793.0,6981.331892490387,10987.689888238909,6981.331892490387,4004.901304244995,0.8566954135894775,0.0 -21900,0.06429324,0.034287285,,,,,,,,,,,,,,,,, -22000,0.10354757,0.033145577,,,,,,,,,,,,,,,,, -22100,0.06962994,0.031515412,,,,,,,,,,,,,,,,, -22200,0.08362928,0.03654939,,,,,,,,,,,,,,,,, -22300,0.07424796,0.034412228,,,,,,,,,,,,,,,,, -22400,0.14770536,0.035494562,,,,,,,,,,,,,,,,, -22500,0.08081961,0.031544134,,,,,,,,,,,,,,,,, -22600,0.075193696,0.037122104,,,,,,,,,,,,,,,,, -22651,,,0.9912999868392944,0.0289224758744239,0.4529309025535197,0.9866639971733092,0.0441776476800441,0.2652814454767527,43793.0,0.985951006412506,0.0468320958316326,0.2558588523560185,43793.0,7221.436220884323,11355.389991521835,7221.436220884323,4132.4447610378265,0.8884739875793457,0.0 -22700,0.07054353,0.033506136,,,,,,,,,,,,,,,,, -22800,0.1021216,0.037589543,,,,,,,,,,,,,,,,, -22900,0.0742875,0.032581564,,,,,,,,,,,,,,,,, -23000,0.09218064,0.03584832,,,,,,,,,,,,,,,,, -23100,0.08092289,0.034707703,,,,,,,,,,,,,,,,, -23200,0.074509054,0.031277824,,,,,,,,,,,,,,,,, -23300,0.11245873,0.033957854,,,,,,,,,,,,,,,,, -23400,0.08744201,0.036561865,,,,,,,,,,,,,,,,, -23409,,,0.9911875128746032,0.0289385933429002,0.4468222720106902,0.98687344789505,0.0442219488322734,0.2659019283781217,43793.0,0.9860158562660216,0.0470472685992717,0.2603437330110667,43793.0,7461.514452457428,11723.900267839432,7461.514452457428,4260.824286222458,0.9203827381134032,0.0 -23500,0.069683015,0.032725263,,,,,,,,,,,,,,,,, -23600,0.075532615,0.03315312,,,,,,,,,,,,,,,,, -23700,0.08995807,0.035509963,,,,,,,,,,,,,,,,, -23800,0.10498764,0.03313389,,,,,,,,,,,,,,,,, -23900,0.13027628,0.034363203,,,,,,,,,,,,,,,,, -24000,0.108690016,0.035332385,,,,,,,,,,,,,,,,, -24100,0.08052382,0.031996485,,,,,,,,,,,,,,,,, -24166,,,0.9908936619758606,0.0299933366477489,0.4147140270965586,0.9866806268692015,0.0444285199046134,0.2666526009058164,43793.0,0.9858659505844116,0.0472520887851715,0.254193478821606,43793.0,7701.180008888245,12093.078660488129,7701.180008888245,4389.98766207695,1.248992681503296,0.0 -24200,0.07984105,0.03648249,,,,,,,,,,,,,,,,, -24300,0.09060758,0.03469499,,,,,,,,,,,,,,,,, -24400,0.09365717,0.0317133,,,,,,,,,,,,,,,,, -24500,0.066627204,0.033719257,,,,,,,,,,,,,,,,, -24600,0.06654938,0.03232811,,,,,,,,,,,,,,,,, -24700,0.079314664,0.036320888,,,,,,,,,,,,,,,,, -24800,0.077884346,0.03156104,,,,,,,,,,,,,,,,, -24900,0.069533646,0.030931825,,,,,,,,,,,,,,,,, -24915,,,0.9908816814422609,0.0299308989197015,0.4141630080073448,0.9867399334907532,0.0442919544875621,0.2720565043987341,43793.0,0.9859838485717772,0.0469495318830013,0.2616324466083824,43793.0,7941.191206932068,12459.79927778244,7941.191206932068,4516.645851135254,1.2801299095153809,0.0 -25000,0.059299286,0.030785272,,,,,,,,,,,,,,,,, -25100,0.10702231,0.034406,,,,,,,,,,,,,,,,, -25200,0.0631754,0.031773463,,,,,,,,,,,,,,,,, -25300,0.07283325,0.034401994,,,,,,,,,,,,,,,,, -25400,0.07816198,0.03117809,,,,,,,,,,,,,,,,, -25500,0.07057051,0.034005973,,,,,,,,,,,,,,,,, -25600,0.06680458,0.033138987,,,,,,,,,,,,,,,,, -25671,,,0.991015374660492,0.0294816028326749,0.4261409697864819,0.9866116046905518,0.0445607379078865,0.2697206290765936,43793.0,0.985874354839325,0.0472862273454666,0.2556335293922175,43793.0,8181.43877363205,12830.020049333572,8181.43877363205,4646.567767858505,1.3114888668060305,0.0 -25700,0.06577121,0.032956522,,,,,,,,,,,,,,,,, -25800,0.0994528,0.03443325,,,,,,,,,,,,,,,,, -25900,0.07204887,0.034482438,,,,,,,,,,,,,,,,, -26000,0.06579513,0.03184459,,,,,,,,,,,,,,,,, -26100,0.086190805,0.031364575,,,,,,,,,,,,,,,,, -26200,0.083571285,0.034048133,,,,,,,,,,,,,,,,, -26300,0.085482255,0.0346675,,,,,,,,,,,,,,,,, -26400,0.0945443,0.033029642,,,,,,,,,,,,,,,,, -26422,,,0.9910532832145692,0.0291581321507692,0.4336587863156425,0.9867833256721495,0.0448125712573528,0.2679109911396861,43793.0,0.9859535694122314,0.0478274375200271,0.2524827381782631,43793.0,8421.660665512085,13200.347779750824,8421.660665512085,4776.618052721024,1.3456156253814695,0.0 -26500,0.07270736,0.032668423,,,,,,,,,,,,,,,,, -26600,0.08683022,0.03411175,,,,,,,,,,,,,,,,, -26700,0.0802179,0.031715766,,,,,,,,,,,,,,,,, -26800,0.089993455,0.03519962,,,,,,,,,,,,,,,,, -26900,0.07990779,0.030156925,,,,,,,,,,,,,,,,, -27000,0.08825553,0.036091663,,,,,,,,,,,,,,,,, -27100,0.093814895,0.033236377,,,,,,,,,,,,,,,,, -27176,,,0.9911747574806212,0.0289251338690519,0.4452254427445876,0.986806094646454,0.0443060398101806,0.2719736999539154,43793.0,0.9859000444412231,0.0473854392766952,0.2587810042911728,43793.0,8661.61139369011,13566.684311151505,8661.61139369011,4902.952599287033,1.3763277530670166,0.0 -27200,0.0671893,0.031013297,,,,,,,,,,,,,,,,, -27300,0.08083262,0.034637388,,,,,,,,,,,,,,,,, -27400,0.07208888,0.035137385,,,,,,,,,,,,,,,,, -27500,0.09575877,0.031952478,,,,,,,,,,,,,,,,, -27600,0.09548804,0.03259571,,,,,,,,,,,,,,,,, -27700,0.088720374,0.03536516,,,,,,,,,,,,,,,,, -27800,0.074801736,0.03454967,,,,,,,,,,,,,,,,, -27900,0.07948497,0.031244315,,,,,,,,,,,,,,,,, -27925,,,0.9912655353546144,0.0282652303576469,0.4763750571741873,0.9868454337120056,0.0445428192615509,0.2758354097559073,43793.0,0.9859707951545716,0.0475916676223278,0.2610465782854482,43793.0,8901.743763685226,13938.37245965004,8901.743763685226,5034.454385757446,1.4076869487762451,0.0 -28000,0.07840765,0.033035368,,,,,,,,,,,,,,,,, -28100,0.07286295,0.03517778,,,,,,,,,,,,,,,,, -28200,0.08440772,0.032745935,,,,,,,,,,,,,,,,, -28300,0.09212208,0.029674353,,,,,,,,,,,,,,,,, -28400,0.08442844,0.03281379,,,,,,,,,,,,,,,,, -28500,0.07728451,0.03368174,,,,,,,,,,,,,,,,, -28600,0.08989219,0.037866384,,,,,,,,,,,,,,,,, -28676,,,0.99152410030365,0.0274818856269121,0.4806527878827299,0.986725687980652,0.0445623472332954,0.2700157613198305,43793.0,0.9859851598739624,0.0473437346518039,0.2601059381326383,43793.0,9141.946783781052,14305.943809747696,9141.946783781052,5161.768011569977,1.440685749053955,0.0 -28700,0.07084672,0.03207794,,,,,,,,,,,,,,,,, -28800,0.07846637,0.031208329,,,,,,,,,,,,,,,,, -28900,0.08264615,0.03693067,,,,,,,,,,,,,,,,, -29000,0.09465641,0.033893052,,,,,,,,,,,,,,,,, -29100,0.076522514,0.032443337,,,,,,,,,,,,,,,,, -29200,0.096495226,0.034534678,,,,,,,,,,,,,,,,, -29300,0.06834153,0.03379372,,,,,,,,,,,,,,,,, -29400,0.08514313,0.03462614,,,,,,,,,,,,,,,,, -29433,,,0.9914757609367372,0.0276817418634891,0.4703134444485871,0.9867504835128784,0.0449969619512558,0.2625452748064528,43793.0,0.9858773350715636,0.0478531457483768,0.2550559410887236,43793.0,9382.060484170914,14675.473503351212,9382.060484170914,5291.131680011749,1.4723448753356934,0.0 -29500,0.08978012,0.033324752,,,,,,,,,,,,,,,,, -29600,0.06722959,0.033639167,,,,,,,,,,,,,,,,, -29700,0.08038263,0.03191803,,,,,,,,,,,,,,,,, -29800,0.08953843,0.03328158,,,,,,,,,,,,,,,,, -29900,0.1227105,0.03142455,,,,,,,,,,,,,,,,, -30000,0.10968055,0.03527171,,,,,,,,,,,,,,,,, -30100,0.08604972,0.033663306,,,,,,,,,,,,,,,,, -30189,,,0.9917797446250916,0.0269236396998167,0.4778185300817284,0.986819863319397,0.0443765558302402,0.2750118181782855,43793.0,0.9859526753425598,0.046930506825447,0.2677139955535795,43793.0,9622.095024824142,15039.375715255735,9622.095024824142,5414.945363521576,1.5057015419006348,0.0 -30200,0.13255766,0.029962411,,,,,,,,,,,,,,,,, -30300,0.08240501,0.034051538,,,,,,,,,,,,,,,,, -30400,0.08512388,0.03126098,,,,,,,,,,,,,,,,, -30500,0.09501906,0.030409757,,,,,,,,,,,,,,,,, -30600,0.094023556,0.03508641,,,,,,,,,,,,,,,,, -30700,0.08850258,0.030764319,,,,,,,,,,,,,,,,, -30800,0.08555838,0.036056962,,,,,,,,,,,,,,,,, -30900,0.071270734,0.03263394,,,,,,,,,,,,,,,,, -30947,,,0.9916726350784302,0.0273495484143495,0.4843876213773574,0.986823558807373,0.0445965491235256,0.2758643690984573,43793.0,0.9860154390335084,0.0473567172884941,0.2568862944563916,43793.0,9862.135322093964,15405.955041885376,9862.135322093964,5541.432502031326,1.537562131881714,0.0 -31000,0.0882637,0.031217342,,,,,,,,,,,,,,,,, -31100,0.06851745,0.033427622,,,,,,,,,,,,,,,,, -31200,0.08202798,0.0322076,,,,,,,,,,,,,,,,, -31300,0.08412928,0.032375637,,,,,,,,,,,,,,,,, -31400,0.081174746,0.029708156,,,,,,,,,,,,,,,,, -31500,0.07564731,0.029015174,,,,,,,,,,,,,,,,, -31600,0.098171614,0.0337268,,,,,,,,,,,,,,,,, -31700,0.0788928,0.032979287,,,,,,,,,,,,,,,,, -31701,,,0.9913953542709352,0.0280896257609128,0.4647237618074311,0.9867054224014282,0.0447933971881866,0.2632075575298583,43793.0,0.9858874082565308,0.0477321073412895,0.254013440712067,43793.0,10102.235368013382,15778.814150571823,10102.235368013382,5674.134570837021,1.574239730834961,0.0 -31800,0.068132296,0.031940073,,,,,,,,,,,,,,,,, -31900,0.086261146,0.031308692,,,,,,,,,,,,,,,,, -32000,0.08451455,0.034251902,,,,,,,,,,,,,,,,, -32100,0.080370314,0.03076767,,,,,,,,,,,,,,,,, -32200,0.081463106,0.033741403,,,,,,,,,,,,,,,,, -32300,0.08481838,0.031589072,,,,,,,,,,,,,,,,, -32400,0.07441008,0.033848003,,,,,,,,,,,,,,,,, -32453,,,0.9914029240608216,0.0280087292194366,0.4645824797357492,0.9867374897003174,0.0444966927170753,0.2662610658185099,43793.0,0.9859194159507751,0.0471179857850074,0.2619408014981749,43793.0,10342.45687031746,16147.693119049072,10342.45687031746,5802.739213705063,1.6064386367797852,0.0 -32500,0.08458371,0.03208153,,,,,,,,,,,,,,,,, -32600,0.11479871,0.0350049,,,,,,,,,,,,,,,,, -32700,0.11274065,0.030232219,,,,,,,,,,,,,,,,, -32800,0.10416372,0.03244408,,,,,,,,,,,,,,,,, -32900,0.11675432,0.031192439,,,,,,,,,,,,,,,,, -33000,0.07735212,0.030816736,,,,,,,,,,,,,,,,, -33100,0.09984563,0.033501677,,,,,,,,,,,,,,,,, -33200,0.08051953,0.032489996,,,,,,,,,,,,,,,,, -33209,,,0.9913453459739684,0.028263833373785,0.4542848687332359,0.9869092106819152,0.0443997792899608,0.2762699325917176,43793.0,0.985925316810608,0.0475473366677761,0.2548237538397603,43793.0,10582.4124584198,16517.30669927597,10582.4124584198,5932.344577074051,1.6389267444610596,0.0 -33300,0.09894814,0.030610638,,,,,,,,,,,,,,,,, -33400,0.114964806,0.033346515,,,,,,,,,,,,,,,,, -33500,0.096379906,0.0305388,,,,,,,,,,,,,,,,, -33600,0.09787765,0.030086098,,,,,,,,,,,,,,,,, -33700,0.07734812,0.028532857,,,,,,,,,,,,,,,,, -33800,0.107730605,0.030474555,,,,,,,,,,,,,,,,, -33900,0.10026305,0.03214052,,,,,,,,,,,,,,,,, -33967,,,0.991368293762207,0.0279243234544992,0.4556313674502258,0.986847460269928,0.0452116727828979,0.2753139830957439,43793.0,0.9860167503356934,0.0483439229428768,0.2645479106855683,43793.0,10822.565240621569,16882.73788666725,10822.565240621569,6057.570507287979,1.6709973812103271,0.0 -34000,0.091951616,0.03293915,,,,,,,,,,,,,,,,, -34100,0.0882872,0.031238005,,,,,,,,,,,,,,,,, -34200,0.07487245,0.031664003,,,,,,,,,,,,,,,,, -34300,0.08124158,0.02987325,,,,,,,,,,,,,,,,, -34400,0.07857703,0.032422952,,,,,,,,,,,,,,,,, -34500,0.08038334,0.030284889,,,,,,,,,,,,,,,,, -34600,0.06869705,0.027912222,,,,,,,,,,,,,,,,, -34700,0.077042766,0.031116145,,,,,,,,,,,,,,,,, -34715,,,0.9916467070579528,0.0272052455693483,0.477621432016776,0.9869270324707032,0.0444139018654823,0.2780114628194121,43793.0,0.9859063625335692,0.0478287898004055,0.254685292772391,43793.0,11062.638055562971,17256.9448492527,11062.638055562971,6191.6519594192505,1.7039594650268557,0.0 -34800,0.109353475,0.031990193,,,,,,,,,,,,,,,,, -34900,0.099126264,0.03327276,,,,,,,,,,,,,,,,, -35000,0.08466589,0.03264318,,,,,,,,,,,,,,,,, -35100,0.08938074,0.03206832,,,,,,,,,,,,,,,,, -35200,0.10620276,0.033949897,,,,,,,,,,,,,,,,, -35300,0.08096087,0.03082991,,,,,,,,,,,,,,,,, -35400,0.081432015,0.031717874,,,,,,,,,,,,,,,,, -35464,,,0.9917681217193604,0.0267904587090015,0.4988786142774347,0.986970067024231,0.0444306842982769,0.2751348822187652,43793.0,0.9860761165618896,0.0475175641477108,0.2631292091684028,43793.0,11302.731583595276,17621.38951563835,11302.731583595276,6315.94552397728,1.7400023937225342,0.0 -35500,0.10897436,0.034625784,,,,,,,,,,,,,,,,, -35600,0.08412857,0.031669695,,,,,,,,,,,,,,,,, -35700,0.09310472,0.031180084,,,,,,,,,,,,,,,,, -35800,0.118614964,0.034232665,,,,,,,,,,,,,,,,, -35900,0.10235983,0.032625545,,,,,,,,,,,,,,,,, -36000,0.080255836,0.028836424,,,,,,,,,,,,,,,,, -36100,0.10629026,0.03102316,,,,,,,,,,,,,,,,, -36200,0.07953617,0.030684628,,,,,,,,,,,,,,,,, -36217,,,0.9920303225517272,0.0256762281060218,0.5184381943115008,0.986899435520172,0.0450549945235252,0.2743834320358754,43793.0,0.9859569072723388,0.0481483377516269,0.2617889807945268,43793.0,11542.948122262957,17991.900566101074,11542.948122262957,6446.186127901077,1.7738418579101562,0.0 -36300,0.14062807,0.034989778,,,,,,,,,,,,,,,,, -36400,0.09007475,0.030053621,,,,,,,,,,,,,,,,, -36500,0.10100163,0.0327748,,,,,,,,,,,,,,,,, -36600,0.09516553,0.03334493,,,,,,,,,,,,,,,,, -36700,0.10293968,0.03460936,,,,,,,,,,,,,,,,, -36800,0.099744804,0.031612992,,,,,,,,,,,,,,,,, -36900,0.08439587,0.026817068,,,,,,,,,,,,,,,,, -36976,,,0.9921303391456604,0.0254173502326011,0.5227613756415113,0.9869270324707032,0.0449628196656703,0.2739362064526262,43793.0,0.9859548211097716,0.0478603355586528,0.2623279002542574,43793.0,11782.979050397871,18357.6857047081,11782.979050397871,6571.886204242706,1.807714939117432,0.0 -37000,0.091186315,0.031236017,,,,,,,,,,,,,,,,, -37100,0.12389581,0.03475968,,,,,,,,,,,,,,,,, -37200,0.09846654,0.03314457,,,,,,,,,,,,,,,,, -37300,0.09227459,0.031306274,,,,,,,,,,,,,,,,, -37400,0.09396467,0.030550282,,,,,,,,,,,,,,,,, -37500,0.09674199,0.031315107,,,,,,,,,,,,,,,,, -37600,0.08272846,0.032745678,,,,,,,,,,,,,,,,, -37700,0.14504147,0.033511057,,,,,,,,,,,,,,,,, -37725,,,0.992214024066925,0.0252790115773677,0.532232625381326,0.9869457483291626,0.0449865348637104,0.2753142807744727,43793.0,0.9860011339187622,0.04807910323143,0.2635697863954901,43793.0,12023.23389339447,18726.489814043045,12023.23389339447,6700.37837600708,1.8413872718811035,0.0 -37800,0.121932484,0.029534632,,,,,,,,,,,,,,,,, -37900,0.083422855,0.03168167,,,,,,,,,,,,,,,,, -38000,0.09538265,0.03324194,,,,,,,,,,,,,,,,, -38100,0.10184775,0.031164275,,,,,,,,,,,,,,,,, -38200,0.09747408,0.033283178,,,,,,,,,,,,,,,,, -38300,0.12593599,0.030527093,,,,,,,,,,,,,,,,, -38400,0.09222951,0.032367665,,,,,,,,,,,,,,,,, -38488,,,0.9921120405197144,0.0254794359207153,0.5159274884360091,0.9868957996368408,0.0449107587337493,0.2744865870059713,43793.0,0.9859851598739624,0.0478601045906543,0.2579441179379427,43793.0,12263.267944574356,19091.986697912216,12263.267944574356,6825.7874138355255,1.8750951290130613,0.0 -38500,0.106211476,0.03006028,,,,,,,,,,,,,,,,, -38600,0.09058192,0.030045941,,,,,,,,,,,,,,,,, -38700,0.10254194,0.03306282,,,,,,,,,,,,,,,,, -38800,0.09861264,0.031833835,,,,,,,,,,,,,,,,, -38900,0.099272564,0.029372659,,,,,,,,,,,,,,,,, -39000,0.12414502,0.032870203,,,,,,,,,,,,,,,,, -39100,0.1185542,0.031799506,,,,,,,,,,,,,,,,, -39200,0.10824396,0.031918004,,,,,,,,,,,,,,,,, -39243,,,0.9920024275779724,0.0260459966957569,0.5064120084947242,0.986961543560028,0.0450559519231319,0.2774751291319751,43793.0,0.9859825968742372,0.0481785349547863,0.2605677246204606,43793.0,12503.376068115234,19458.20954990387,12503.376068115234,6951.847330093384,1.91015625,0.0 -39300,0.09352929,0.031070929,,,,,,,,,,,,,,,,, -39400,0.10940751,0.033123437,,,,,,,,,,,,,,,,, -39500,0.13633715,0.031703804,,,,,,,,,,,,,,,,, -39600,0.10170427,0.030035757,,,,,,,,,,,,,,,,, -39700,0.0895776,0.030348504,,,,,,,,,,,,,,,,, -39800,0.100724585,0.0314553,,,,,,,,,,,,,,,,, -39900,0.09557908,0.030001633,,,,,,,,,,,,,,,,, -40000,0.0957468,0.029237878,,,,,,,,,,,,,,,,, -40004,,,0.9919900298118592,0.0259577762335538,0.5071904458124843,0.9868775010108948,0.0453401543200016,0.2724499825709285,43793.0,0.9859552383422852,0.0483670085668563,0.2582643340499317,43793.0,12743.403692007065,19822.685611486435,12743.403692007065,7076.241222858429,1.944312334060669,0.0 -40100,0.13562123,0.030343166,,,,,,,,,,,,,,,,, -40200,0.09250287,0.029198142,,,,,,,,,,,,,,,,, -40300,0.09699492,0.0289239,,,,,,,,,,,,,,,,, -40400,0.10743992,0.029783588,,,,,,,,,,,,,,,,, -40500,0.121107474,0.033573616,,,,,,,,,,,,,,,,, -40600,0.08873023,0.02684137,,,,,,,,,,,,,,,,, -40700,0.09423127,0.028930703,,,,,,,,,,,,,,,,, -40763,,,0.9918793439865112,0.0262201521545648,0.4992228026531985,0.9869359731674194,0.0450229682028293,0.2775892452836066,43793.0,0.9860268235206604,0.0480281524360179,0.2584114993433583,43793.0,12983.375659227371,20190.81739640236,12983.375659227371,7204.345838546753,1.97922158241272,0.0 -40800,0.09841113,0.031373966,,,,,,,,,,,,,,,,, -40900,0.110715464,0.030736072,,,,,,,,,,,,,,,,, -41000,0.10205859,0.029456679,,,,,,,,,,,,,,,,, -41100,0.10332146,0.031959563,,,,,,,,,,,,,,,,, -41200,0.09106722,0.032019798,,,,,,,,,,,,,,,,, -41300,0.15009701,0.029765999,,,,,,,,,,,,,,,,, -41400,0.13209136,0.03254291,,,,,,,,,,,,,,,,, -41500,0.08877922,0.028139174,,,,,,,,,,,,,,,,, -41521,,,0.9919157028198242,0.0261086784303188,0.5037610390684488,0.986825168132782,0.0457451306283474,0.2686163429857319,43793.0,0.9859375357627868,0.0488403737545013,0.2549508720399013,43793.0,13223.54290318489,20555.041393518448,13223.54290318489,7328.345608234405,2.01588773727417,0.0 -41600,0.11307294,0.03000658,,,,,,,,,,,,,,,,, -41700,0.113302834,0.032025095,,,,,,,,,,,,,,,,, -41800,0.13920851,0.033953786,,,,,,,,,,,,,,,,, -41900,0.09623266,0.030322753,,,,,,,,,,,,,,,,, -42000,0.09427535,0.027333831,,,,,,,,,,,,,,,,, -42100,0.104224704,0.030321512,,,,,,,,,,,,,,,,, -42200,0.11782375,0.027918039,,,,,,,,,,,,,,,,, -42273,,,0.9921371936798096,0.0254248585551977,0.517208637472588,0.986750066280365,0.0452563427388668,0.2774649701945903,43793.0,0.985990583896637,0.048225313425064,0.2598064917046744,43793.0,13463.506523132324,20921.390644073486,13463.506523132324,7454.671268939972,2.0549404621124268,0.0 -42300,0.12433585,0.02924527,,,,,,,,,,,,,,,,, -42400,0.15008195,0.030590512,,,,,,,,,,,,,,,,, -42500,0.12038892,0.02796708,,,,,,,,,,,,,,,,, -42600,0.1109683,0.030447148,,,,,,,,,,,,,,,,, -42700,0.09897163,0.031219188,,,,,,,,,,,,,,,,, -42800,0.11249303,0.028502712,,,,,,,,,,,,,,,,, -42900,0.096477404,0.029708637,,,,,,,,,,,,,,,,, -43000,0.09998196,0.029676393,,,,,,,,,,,,,,,,, -43031,,,0.9923945665359496,0.0246466323733329,0.5287445837948694,0.9867159724235536,0.0451500564813613,0.2757726625150504,43793.0,0.9858831763267516,0.0480500534176826,0.2615175880406216,43793.0,13703.503736972809,21289.939838171005,13703.503736972809,7583.163735151291,2.093395233154297,0.0 -43100,0.09892063,0.027489295,,,,,,,,,,,,,,,,, -43200,0.12518844,0.028253669,,,,,,,,,,,,,,,,, -43300,0.16524942,0.031128136,,,,,,,,,,,,,,,,, -43400,0.1434906,0.029365383,,,,,,,,,,,,,,,,, -43500,0.115711346,0.029750464,,,,,,,,,,,,,,,,, -43600,0.09246093,0.026794204,,,,,,,,,,,,,,,,, -43700,0.13117354,0.031343583,,,,,,,,,,,,,,,,, -43777,,,0.9924582839012146,0.0242256261408329,0.5548626450600167,0.9868783354759216,0.0454231277108192,0.2750619330472643,43793.0,0.986004114151001,0.0484849773347377,0.2589780692558592,43793.0,13943.55898284912,21658.204888105392,13943.55898284912,7711.316171169281,2.129408359527588,0.0 -43800,0.12294719,0.031833846,,,,,,,,,,,,,,,,, -43900,0.08714084,0.027232967,,,,,,,,,,,,,,,,, -44000,0.10102877,0.031199811,,,,,,,,,,,,,,,,, -44100,0.10154096,0.028683456,,,,,,,,,,,,,,,,, -44200,0.12375943,0.03383207,,,,,,,,,,,,,,,,, -44300,0.10373781,0.030171463,,,,,,,,,,,,,,,,, -44400,0.10123131,0.029087782,,,,,,,,,,,,,,,,, -44500,0.09552009,0.02733619,,,,,,,,,,,,,,,,, -44533,,,0.9928352236747742,0.0230323988944292,0.582281300015573,0.9868206977844238,0.0455101355910301,0.2740454856114717,43793.0,0.9860752820968628,0.0484214052557945,0.2575476179842053,43793.0,14183.621313095093,22021.435887813568,14183.621313095093,7834.429327011108,2.1637284755706787,0.0 -44600,0.09582913,0.029543161,,,,,,,,,,,,,,,,, -44700,0.09886279,0.027020922,,,,,,,,,,,,,,,,, -44800,0.103487946,0.025758745,,,,,,,,,,,,,,,,, -44900,0.114212096,0.028678995,,,,,,,,,,,,,,,,, -45000,0.11456177,0.029633846,,,,,,,,,,,,,,,,, -45100,0.11440125,0.029234638,,,,,,,,,,,,,,,,, -45200,0.10377419,0.030585969,,,,,,,,,,,,,,,,, -45269,,,0.9929222464561462,0.0229491274803876,0.5656440620216915,0.986701726913452,0.0458743907511234,0.2703104717691669,43793.0,0.98585706949234,0.0489467233419418,0.2580098287432119,43793.0,14423.575603961945,22392.48312997818,14423.575603961945,7965.456509590149,2.204008817672729,0.0 -45300,0.10775481,0.027252069,,,,,,,,,,,,,,,,, -45400,0.12186837,0.028925948,,,,,,,,,,,,,,,,, -45500,0.11800821,0.028833354,,,,,,,,,,,,,,,,, -45600,0.117427595,0.03136138,,,,,,,,,,,,,,,,, -45700,0.10080525,0.027891573,,,,,,,,,,,,,,,,, -45800,0.12105036,0.031443145,,,,,,,,,,,,,,,,, -45900,0.10899895,0.032478563,,,,,,,,,,,,,,,,, -46000,0.101492725,0.02832943,,,,,,,,,,,,,,,,, -46022,,,0.9929423928260804,0.022848380729556,0.5816965987814018,0.9868153929710388,0.0457025654613971,0.2774928305678912,43793.0,0.9859346151351928,0.0487636588513851,0.2573021483270661,43793.0,14663.548845529556,22755.651258468628,14663.548845529556,8088.594908952713,2.2400035858154297,0.0 -46100,0.130554,0.027408436,,,,,,,,,,,,,,,,, -46200,0.10709958,0.028898519,,,,,,,,,,,,,,,,, -46300,0.1257,0.030336931,,,,,,,,,,,,,,,,, -46400,0.113729656,0.026666468,,,,,,,,,,,,,,,,, -46500,0.11098184,0.026618367,,,,,,,,,,,,,,,,, -46600,0.14571606,0.030622939,,,,,,,,,,,,,,,,, -46700,0.11734478,0.030237578,,,,,,,,,,,,,,,,, -46775,,,0.9926509261131288,0.023398483172059,0.5500882324221318,0.9869351387023926,0.0461748838424682,0.2759709745988475,43793.0,0.9860348105430604,0.0492957942187786,0.2616057469839685,43793.0,14903.751368761064,23120.890134334564,14903.751368761064,8213.572768211365,2.277043104171753,0.0 -46800,0.124859855,0.030127214,,,,,,,,,,,,,,,,, -46900,0.122373864,0.026967514,,,,,,,,,,,,,,,,, -47000,0.12051159,0.028671851,,,,,,,,,,,,,,,,, -47100,0.11676651,0.030077448,,,,,,,,,,,,,,,,, -47200,0.11343229,0.027019165,,,,,,,,,,,,,,,,, -47300,0.121927954,0.028906278,,,,,,,,,,,,,,,,, -47400,0.12408196,0.029110089,,,,,,,,,,,,,,,,, -47500,0.11052943,0.028086854,,,,,,,,,,,,,,,,, -47526,,,0.9926035404205322,0.0236497167497873,0.5587826948908996,0.9867784976959229,0.0460036359727382,0.2770768478543392,43793.0,0.985968291759491,0.0490897744894027,0.2601720723494071,43793.0,15143.801016807556,23484.67352104187,15143.801016807556,8337.250361442566,2.31246018409729,0.0 -47600,0.09946179,0.027361682,,,,,,,,,,,,,,,,, -47700,0.12348362,0.030130874,,,,,,,,,,,,,,,,, -47800,0.11535099,0.027730295,,,,,,,,,,,,,,,,, -47900,0.12172953,0.027949,,,,,,,,,,,,,,,,, -48000,0.1320656,0.02825259,,,,,,,,,,,,,,,,, -48100,0.10896316,0.024415867,,,,,,,,,,,,,,,,, -48200,0.10180655,0.025729176,,,,,,,,,,,,,,,,, -48271,,,0.9926998019218444,0.0234197005629539,0.5763855108885503,0.9868454337120056,0.0459731742739677,0.2784930722589843,43793.0,0.986004114151001,0.0490064099431037,0.2604601451995874,43793.0,15383.883342027664,23848.569514989853,15383.883342027664,8461.005041599274,2.348806619644165,0.0 -48300,0.12518056,0.02901248,,,,,,,,,,,,,,,,, -48400,0.1359394,0.031483773,,,,,,,,,,,,,,,,, -48500,0.10620523,0.025751626,,,,,,,,,,,,,,,,, -48600,0.11590498,0.027040405,,,,,,,,,,,,,,,,, -48700,0.1424541,0.0302452,,,,,,,,,,,,,,,,, -48800,0.113492414,0.026683358,,,,,,,,,,,,,,,,, -48900,0.1265842,0.025514904,,,,,,,,,,,,,,,,, -49000,0.13793011,0.029711071,,,,,,,,,,,,,,,,, -49033,,,0.9927592277526855,0.0232525151222944,0.5598052472954665,0.9868369102478028,0.0462203919887542,0.276565655192867,43793.0,0.9859733581542968,0.0493910759687423,0.2595510327017606,43793.0,15624.072097301483,24214.64115262032,15624.072097301483,8586.831431388855,2.384529590606689,0.0 -49100,0.11864756,0.028322699,,,,,,,,,,,,,,,,, -49200,0.1541635,0.026826745,,,,,,,,,,,,,,,,, -49300,0.11001639,0.026824929,,,,,,,,,,,,,,,,, -49400,0.12973626,0.028616961,,,,,,,,,,,,,,,,, -49500,0.13466968,0.026625412,,,,,,,,,,,,,,,,, -49600,0.1254961,0.027681828,,,,,,,,,,,,,,,,, -49700,0.119121306,0.026491698,,,,,,,,,,,,,,,,, -49787,,,0.992779016494751,0.023049347102642,0.557587615426931,0.9867760539054872,0.0467174611985683,0.2682286383885479,43793.0,0.9858781695365906,0.0498117581009864,0.2538513804026239,43793.0,15864.111089468002,24579.27449965477,15864.111089468002,8711.369632005692,2.420367956161499,0.0 -49800,0.12456214,0.027984433,,,,,,,,,,,,,,,,, -49900,0.130024,0.027399179,,,,,,,,,,,,,,,,, -50000,0.15258028,0.027393177,,,,,,,,,,,,,,,,, -50100,0.1511531,0.025836663,,,,,,,,,,,,,,,,, -50200,0.12086985,0.028385095,,,,,,,,,,,,,,,,, -50300,0.12353416,0.026910692,,,,,,,,,,,,,,,,, -50400,0.13613488,0.028084803,,,,,,,,,,,,,,,,, -50500,0.1319543,0.027797587,,,,,,,,,,,,,,,,, -50541,,,0.9930933713912964,0.0218581091612577,0.5951135774073243,0.986832857131958,0.0467489995062351,0.2769980029500767,43793.0,0.9860200881958008,0.0497410297393798,0.2600190963165037,43793.0,16104.149505615234,24940.70306921005,16104.149505615234,8832.70336985588,2.456578016281128,0.0 -50600,0.11609252,0.024000598,,,,,,,,,,,,,,,,, -50700,0.16638827,0.02771805,,,,,,,,,,,,,,,,, -50800,0.15328124,0.028238421,,,,,,,,,,,,,,,,, -50900,0.120833494,0.025213731,,,,,,,,,,,,,,,,, -51000,0.12434853,0.025762105,,,,,,,,,,,,,,,,, -51100,0.1319637,0.02651516,,,,,,,,,,,,,,,,, -51200,0.13015963,0.02692394,,,,,,,,,,,,,,,,, -51297,,,0.9931976199150084,0.0217012390494346,0.591113654143453,0.9867837429046632,0.0467196889221668,0.276880985821581,43793.0,0.985889494419098,0.0497760213911533,0.2609768772472949,43793.0,16344.100379228592,25306.192680835724,16344.100379228592,8958.184759140015,2.49343204498291,0.0 -51300,0.15341173,0.029856915,,,,,,,,,,,,,,,,, -51400,0.14176936,0.027535282,,,,,,,,,,,,,,,,, -51500,0.13638012,0.027948288,,,,,,,,,,,,,,,,, -51600,0.13449612,0.024856789,,,,,,,,,,,,,,,,, -51700,0.16698726,0.027385246,,,,,,,,,,,,,,,,, -51800,0.1424794,0.02924509,,,,,,,,,,,,,,,,, -51900,0.16290225,0.028526936,,,,,,,,,,,,,,,,, -52000,0.14781007,0.028384244,,,,,,,,,,,,,,,,, -52052,,,0.9936760067939758,0.0202678125351667,0.6317658547986209,0.9868324398994446,0.0467768274247646,0.2784929432021214,43793.0,0.985987663269043,0.0499919690191745,0.2584287179698735,43793.0,16584.350385904312,25670.05592918396,16584.350385904312,9081.7381067276,2.5313422679901123,0.0 -52100,0.15456262,0.029011697,,,,,,,,,,,,,,,,, -52200,0.12394321,0.025684338,,,,,,,,,,,,,,,,, -52300,0.12435216,0.02450041,,,,,,,,,,,,,,,,, -52400,0.13301598,0.02587181,,,,,,,,,,,,,,,,, -52500,0.12341772,0.025383336,,,,,,,,,,,,,,,,, -52600,0.1656124,0.03038611,,,,,,,,,,,,,,,,, -52700,0.15922111,0.028020924,,,,,,,,,,,,,,,,, -52800,0.14937054,0.026789417,,,,,,,,,,,,,,,,, -52811,,,0.9940836429595948,0.0191185437142848,0.6474843811104636,0.9867330193519592,0.0475767217576503,0.2701701682771992,43793.0,0.9858625531196594,0.0507209710776805,0.2562704321462084,43793.0,16824.390946626663,26035.98972201348,16824.390946626663,9207.574704885485,2.5672249794006348,0.0 -52900,0.13179341,0.026616288,,,,,,,,,,,,,,,,, -53000,0.13539752,0.024731087,,,,,,,,,,,,,,,,, -53100,0.15404303,0.025308283,,,,,,,,,,,,,,,,, -53200,0.13386932,0.027073663,,,,,,,,,,,,,,,,, -53300,0.15326811,0.025978746,,,,,,,,,,,,,,,,, -53400,0.14926325,0.022861596,,,,,,,,,,,,,,,,, -53500,0.15290296,0.027937189,,,,,,,,,,,,,,,,, -53562,,,0.9939260482788086,0.0195813234895467,0.6418360793553928,0.9867784976959229,0.0476591289043426,0.2710364539945412,43793.0,0.985888659954071,0.0507962331175804,0.2548926207726321,43793.0,17064.626448631287,26399.916737556458,17064.626448631287,9331.2084608078,2.6037094593048096,0.0 -53600,0.13466237,0.026116544,,,,,,,,,,,,,,,,, -53700,0.18735556,0.025881827,,,,,,,,,,,,,,,,, -53800,0.14311251,0.02695118,,,,,,,,,,,,,,,,, -53900,0.12747669,0.025293797,,,,,,,,,,,,,,,,, -54000,0.16674262,0.026026422,,,,,,,,,,,,,,,,, -54100,0.14684644,0.024485724,,,,,,,,,,,,,,,,, -54200,0.13558754,0.02475232,,,,,,,,,,,,,,,,, -54300,0.1480371,0.025487144,,,,,,,,,,,,,,,,, -54304,,,0.9939098954200744,0.0195740088820457,0.6298671260120975,0.986743986606598,0.0479759089648723,0.2708166189977605,43793.0,0.9858903884887696,0.0509893335402011,0.2528467102063643,43793.0,17304.80330300331,26761.601624012,17304.80330300331,9452.654334545135,2.644875288009644,0.0 -54400,0.13776268,0.025996754,,,,,,,,,,,,,,,,, -54500,0.14239967,0.025296116,,,,,,,,,,,,,,,,, -54600,0.16117825,0.026911352,,,,,,,,,,,,,,,,, -54700,0.15467662,0.026944341,,,,,,,,,,,,,,,,, -54800,0.15245491,0.0269644,,,,,,,,,,,,,,,,, -54900,0.15031877,0.02485877,,,,,,,,,,,,,,,,, -55000,0.15686464,0.026007785,,,,,,,,,,,,,,,,, -55063,,,0.9937072992324828,0.0202469173818826,0.6290030533487323,0.9866514205932616,0.0477222166955471,0.2730129936795538,43793.0,0.9858301281929016,0.050938531756401,0.2527652640350336,43793.0,17544.856785297394,27128.370005607605,17544.856785297394,9579.312143564224,2.6809160709381104,0.0 -55100,0.1513874,0.025796248,,,,,,,,,,,,,,,,, -55200,0.18002772,0.026292741,,,,,,,,,,,,,,,,, -55300,0.1475636,0.024528192,,,,,,,,,,,,,,,,, -55400,0.15074766,0.023573218,,,,,,,,,,,,,,,,, -55500,0.17977475,0.025924623,,,,,,,,,,,,,,,,, -55600,0.13710354,0.023485487,,,,,,,,,,,,,,,,, -55700,0.16045165,0.02501205,,,,,,,,,,,,,,,,, -55800,0.17845008,0.024679169,,,,,,,,,,,,,,,,, -55815,,,0.9936484694480896,0.0201473720371723,0.6260407024076896,0.9866623878479004,0.0481687262654304,0.2720521240646281,43793.0,0.9858187437057496,0.0514654070138931,0.2534089531791236,43793.0,17784.952792406082,27491.043491363525,17784.952792406082,9701.83114695549,2.7188405990600586,0.0 -55900,0.17474361,0.023888102,,,,,,,,,,,,,,,,, -56000,0.1442137,0.024251249,,,,,,,,,,,,,,,,, -56100,0.1484584,0.02319897,,,,,,,,,,,,,,,,, -56200,0.15939315,0.024730077,,,,,,,,,,,,,,,,, -56300,0.16182573,0.024091322,,,,,,,,,,,,,,,,, -56400,0.15230407,0.021183636,,,,,,,,,,,,,,,,, -56500,0.12736554,0.023668965,,,,,,,,,,,,,,,,, -56567,,,0.993515133857727,0.0203472040593624,0.6213367212781136,0.9865515232086182,0.0489356853067874,0.2671761914545285,43793.0,0.985710084438324,0.0522744171321392,0.2540701946363075,43793.0,18025.06869673729,27857.546080112457,18025.06869673729,9828.160341501236,2.756070375442505,0.0 -56600,0.15997815,0.022717755,,,,,,,,,,,,,,,,, -56700,0.17110027,0.024181323,,,,,,,,,,,,,,,,, -56800,0.15129054,0.025817912,,,,,,,,,,,,,,,,, -56900,0.17846483,0.024824468,,,,,,,,,,,,,,,,, -57000,0.17249662,0.024660528,,,,,,,,,,,,,,,,, -57100,0.15417945,0.024463063,,,,,,,,,,,,,,,,, -57200,0.19491753,0.025298966,,,,,,,,,,,,,,,,, -57300,0.17099395,0.025182111,,,,,,,,,,,,,,,,, -57320,,,0.9935488104820251,0.020334692671895,0.611723780806605,0.9865474700927734,0.0494059100747108,0.2660708178831285,43793.0,0.985775351524353,0.0527117773890495,0.2552477887895668,43793.0,18265.143191337585,28219.94671702385,18265.143191337585,9950.428166866302,2.793794870376587,0.0 -57400,0.1800905,0.023438016,,,,,,,,,,,,,,,,, -57500,0.16773298,0.023367492,,,,,,,,,,,,,,,,, -57600,0.1924936,0.026453184,,,,,,,,,,,,,,,,, -57700,0.18258938,0.023229413,,,,,,,,,,,,,,,,, -57800,0.17085025,0.024327246,,,,,,,,,,,,,,,,, -57900,0.15340988,0.022048034,,,,,,,,,,,,,,,,, -57989,,,,,,,,,,,,,,18477.2107527256,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 3cf68f2d7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -124.94838285446168,0.0,12.211849689483644,1,0,12.211849689483644,0.5214916467666626,0.7347381114959717,0.0268309887077987,43793,137.16028547286987,0.5325733423233032,0.7276434302330017,0.0231831062687627,0.5230699777603149,0.7331878542900085,0.0255029197338972,43793 -243.82266402244568,0.0216853618621826,252.3856241703033,750,0,252.3856241703033,0.983142077922821,0.0773934349417686,0.0377247415074901,43793,496.25007700920105,0.986755669116974,0.0659792274236679,0.0340223151502751,0.9841179251670836,0.0744884759187698,0.0362152987927236,43793 -374.7523839473725,0.0477268695831298,492.6344926357269,1498,0,492.6344926357269,0.9833973050117492,0.0620699375867843,0.0814007044953301,43793,867.4755127429962,0.9871229529380798,0.04890026897192,0.0824739438043924,0.984372854232788,0.0586623102426528,0.0836942279585238,43793 -497.15764117240906,0.0744333267211914,732.8614349365234,2244,0,732.8614349365234,0.983756184577942,0.0580357871949672,0.1187178026246305,43793,1230.154664993286,0.9873826503753662,0.0455460250377655,0.1263882455282694,0.984746754169464,0.054785881191492,0.1225127042343126,43793 -619.0059733390808,0.1008801460266113,972.8983051776886,2992,0,972.8983051776886,0.9839427471160888,0.056046262383461,0.141987513190904,43793,1592.0877315998075,0.9876368045806884,0.0436954535543918,0.1569229708133287,0.9849249720573424,0.053019069135189,0.1439419953105915,43793 -740.2879574298859,0.1300337314605713,1213.0681955814362,3746,0,1213.0681955814362,0.9841310381889344,0.0546961985528469,0.1534158208389616,43793,1953.5891268253329,0.9878502488136292,0.0424059890210628,0.1737429563879592,0.985138475894928,0.0518514774739742,0.1585169315699912,43793 -867.6484885215759,0.1572957038879394,1453.243047952652,4493,0,1453.243047952652,0.9840977787971495,0.0551101975142955,0.1655582362910332,43793,2321.1722240448,0.9877445101737976,0.0424637980759143,0.2003246065057177,0.985008955001831,0.0522321462631225,0.1644403373332049,43793 -989.522251367569,0.1868824958801269,1693.3792762756348,5241,0,1693.3792762756348,0.9846347570419312,0.0519169941544532,0.1835544135379636,43793,2683.23321557045,0.9884564280509948,0.0398259945213794,0.2073086561436093,0.9855472445487976,0.0492484420537948,0.1844483577769064,43793 -1109.290159702301,0.2166662216186523,1933.3688995838163,5976,0,1933.3688995838163,0.9846912026405334,0.0514706633985042,0.1953586825972949,43793,3043.0417470932007,0.988779842853546,0.0390032492578029,0.241654433364394,0.9856743216514589,0.0488737933337688,0.1980541880902195,43793 -1231.229267120361,0.2448453903198242,2173.444726228714,6723,0,2173.444726228714,0.9851002097129822,0.0501127280294895,0.210851843830605,43793,3405.105672597885,0.9888468384742736,0.0380408428609371,0.2605599738078626,0.9860299229621888,0.047458317130804,0.2136061366712834,43793 -1352.0426914691925,0.2722771167755127,2413.602122068405,7472,0,2413.602122068405,0.9851503372192384,0.0495170578360557,0.2198112775470612,43793,3766.124802827835,0.9890567660331726,0.0371920876204967,0.2785736894766114,0.9860782027244568,0.0469764843583107,0.2222152541430098,43793 -1477.9091057777405,0.3021590709686279,2653.557772874832,8212,0,2653.557772874832,0.9851882457733154,0.0498595125973224,0.2298814156917866,43793,4131.996986627579,0.9891743659973145,0.0361531190574169,0.3071437430999844,0.9860546588897704,0.0470578074455261,0.2279129954488742,43793 -1601.2729868888855,0.3315138816833496,2893.710347890854,8961,0,2893.710347890854,0.9853874444961548,0.0490616783499717,0.2329803435222461,43793,4495.563076972961,0.9895905256271362,0.0351231172680854,0.3273462661718691,0.9861902594566344,0.0463905036449432,0.2325769338745914,43793 -1721.642301082611,0.360389232635498,3133.842006921768,9718,0,3133.842006921768,0.985545814037323,0.0485690161585807,0.237077607090954,43793,4856.113333940506,0.9898381233215332,0.0341963469982147,0.3502433378063998,0.9864122867584229,0.0458182692527771,0.2419818426860593,43793 -1845.3364737033844,0.388399600982666,3373.964447259903,10465,0,3373.964447259903,0.9853895902633668,0.0494105294346809,0.2395083205783805,43793,5219.977976083756,0.989568829536438,0.0340709760785102,0.3687249605372311,0.986235737800598,0.0466529317200183,0.237131357755388,43793 -1969.332597494125,0.4206776618957519,3614.095441818237,11209,0,3614.095441818237,0.985528528690338,0.0485292673110961,0.2416833541478927,43793,5584.160804271698,0.9899488091468812,0.0335929244756698,0.3595911836357722,0.9864537119865416,0.0456045381724834,0.2528149574337986,43793 -2089.5103764534,0.4484295845031738,3854.335582494736,11956,0,3854.335582494736,0.985745906829834,0.0478426851332187,0.2463089958443751,43793,5944.627243518829,0.9904040098190308,0.0325356610119342,0.3633715206238281,0.986591339111328,0.0451359152793884,0.2521590708519562,43793 -2209.084566354752,0.4778716564178467,4094.5407037734985,12707,0,4094.5407037734985,0.9858217239379884,0.0480994060635566,0.2571150736686474,43793,6304.456825494766,0.9905290603637696,0.0317256972193717,0.3934554074221962,0.9866530299186708,0.0453526303172111,0.2577573768745234,43793 -2329.5390541553497,0.5078516006469727,4334.822235822678,13455,0,4334.822235822678,0.9856991171836852,0.0480859875679016,0.2459188570807226,43793,6665.2435131073,0.9906842708587646,0.0313801318407058,0.3929484641070109,0.9865840077400208,0.045293316245079,0.253105285712875,43793 -2449.6124868392944,0.5385310649871826,4574.961861371994,14202,0,4574.961861371994,0.9858486652374268,0.0480134636163711,0.2553494129192967,43793,7025.507632255554,0.9907562732696532,0.0304472688585519,0.428713869830881,0.9867236614227296,0.0452100075781345,0.2582183376390017,43793 -2572.057140827179,0.5675864219665527,4815.122054815292,14948,0,4815.122054815292,0.9857577085494996,0.0482803620398044,0.2439436774518225,43793,7388.161834478378,0.99077308177948,0.0302594359964132,0.4277688119308474,0.9866668581962584,0.0454121641814708,0.2544344404267181,43793 -2693.980701684952,0.5970251560211182,5055.139110803604,15697,0,5055.139110803604,0.9857909679412842,0.0481221489608287,0.2465367554993298,43793,7750.152556419373,0.9912256598472596,0.0288708545267581,0.4513888298738659,0.9866424798965454,0.0453258045017719,0.2573368540704658,43793 -2813.2641813755035,0.6277570724487305,5295.112805843353,16447,0,5295.112805843353,0.9858705401420592,0.048255406320095,0.2530989906320853,43793,8109.461350440979,0.9916675090789796,0.0275912955403327,0.4919298315701566,0.9866794347763062,0.0455493815243244,0.2621947058241379,43793 -2934.739273548126,0.6577551364898682,5535.315126657486,17191,0,5535.315126657486,0.9858924746513368,0.048205729573965,0.2521990759737141,43793,8471.188846826553,0.9916927814483644,0.0272750835865736,0.5031697196652744,0.986777663230896,0.0452458262443542,0.265259789435557,43793 -3056.4563794136047,0.6870527267456055,5775.409202814102,17944,0,5775.409202814102,0.9857863187789916,0.0490991212427616,0.2485788179160932,43793,8833.050078868866,0.9915154576301576,0.0276951789855957,0.5012935847392078,0.9867122769355774,0.0459314920008182,0.2675033684853384,43793 -3174.150098800659,0.7177164554595947,6015.5455322265625,18691,0,6015.5455322265625,0.985912263393402,0.0482953153550624,0.2565524213491001,43793,9190.930800199509,0.9915038347244264,0.0279141776263713,0.4830483010954101,0.9867801070213318,0.0454021245241165,0.2609649209664075,43793 -3294.866469860077,0.7476849555969238,6255.649400234222,19435,0,6255.649400234222,0.9858874082565308,0.0491726137697696,0.2569502106290384,43793,9551.801201581957,0.9913306832313538,0.0281452555209398,0.4829667100200081,0.9866834878921508,0.0461911484599113,0.260969365659252,43793 -3411.558787584305,0.7772469520568848,6495.83327627182,20189,0,6495.83327627182,0.9858680367469788,0.04932626709342,0.2496602999371342,43793,9908.72722864151,0.9917463064193726,0.0269054062664508,0.4939934566066215,0.9866786003112792,0.0461668036878109,0.258319712197705,43793 -3529.147953033448,0.8081989288330078,6736.029312372208,20938,0,6736.029312372208,0.9857537150382996,0.05003098025918,0.2476189983255035,43793,10266.563576698303,0.9917852282524108,0.0268196295946836,0.505956641813925,0.9866956472396852,0.0466153770685195,0.257373007364516,43793 -3646.929877519608,0.841012716293335,6976.128599882126,21689,0,6976.128599882126,0.9857938885688782,0.0493037439882755,0.2549565183568808,43793,10624.498348236084,0.9920173287391664,0.0260045584291219,0.523237135827197,0.9867163896560668,0.0463523082435131,0.2581223071410054,43793 -3765.5550322532654,0.8723323345184326,7216.098851442337,22436,0,7216.098851442337,0.9857235550880432,0.0491550788283348,0.2442960371593634,43793,10983.145557641985,0.9925153851509094,0.0246449559926986,0.5653510078298822,0.9865344762802124,0.0464587174355983,0.2524689786088972,43793 -3886.729905128479,0.9026894569396972,7456.205724477768,23182,0,7456.205724477768,0.9857202172279358,0.0501690804958343,0.2515987042061027,43793,11344.477650642397,0.9926448464393616,0.0239301789551973,0.5750513515882668,0.9865629076957704,0.0472385957837104,0.2531709487648458,43793 -4006.09627699852,0.9327542781829834,7696.241092205048,23933,0,7696.241092205048,0.9857092499732972,0.0499692484736442,0.2528821126225342,43793,11703.929669380188,0.9929741024971008,0.0230709183961153,0.5870177889735277,0.9866250157356262,0.0469329915940761,0.2611496862478765,43793 -4128.651603221893,0.9643106460571288,7936.516567230225,24676,0,7936.516567230225,0.9857812523841858,0.0499368458986282,0.2494686159353731,43793,12066.811733722689,0.9929828643798828,0.0229916647076606,0.5987426324371437,0.9867159724235536,0.04694814234972,0.2613548949456581,43793 -4253.220617294312,0.9951791763305664,8176.752241849899,25421,0,8176.752241849899,0.98576021194458,0.0513744801282882,0.2480600840810696,43793,12431.66696190834,0.992374300956726,0.0245363339781761,0.5646499916322825,0.9865953922271729,0.0482317917048931,0.2583915106640733,43793 -4369.113248348236,1.0298182964324951,8416.71799492836,26160,0,8416.71799492836,0.9857054352760316,0.0503102540969848,0.2487609686306002,43793,12787.58046579361,0.9926239848136902,0.0240995120257139,0.5744237415586881,0.986622154712677,0.0471982508897781,0.2568303117982489,43793 -4491.541553258896,1.0609753131866455,8656.93652176857,26901,0,8656.93652176857,0.9857370257377625,0.0514483153820037,0.2439620641067449,43793,13150.27823996544,0.9923655986785888,0.0245582349598407,0.5520001620106612,0.9866526126861572,0.0479371212422847,0.2553553219675096,43793 -4613.804739713669,1.095686435699463,8897.207918643951,27633,0,8897.207918643951,0.98567134141922,0.0520597845315933,0.2451742423721277,43793,13512.87291264534,0.9923880100250244,0.0241682678461074,0.5654214180134779,0.9865978360176086,0.0485782362520694,0.2502639805699408,43793 -4733.491109848023,1.128706693649292,9137.189074516296,28375,0,9137.189074516296,0.9856898784637452,0.051655750721693,0.2430242722986725,43793,13872.593567609789,0.9930915832519532,0.0223762076348066,0.5991473293839304,0.9865767359733582,0.0481585152447223,0.2556357184836858,43793 -4854.35968875885,1.1624870300292969,9377.43240237236,29123,0,9377.43240237236,0.9857964515686036,0.0519480742514133,0.2473759793119771,43793,14233.759857654572,0.9930724501609802,0.0221859905868768,0.617831682258646,0.9865877032279968,0.0487930439412593,0.2522404282257972,43793 -4972.047943592072,1.1940855979919434,9617.391327857971,29861,0,9617.391327857971,0.9856991171836852,0.0519113168120384,0.2484655358142954,43793,14591.459522247314,0.993553876876831,0.020813263952732,0.6401829967227319,0.9865182638168336,0.0486945249140262,0.2477831701067681,43793 -5097.874583005905,1.2269041538238523,9857.6364672184,30589,0,9857.6364672184,0.9856089949607848,0.0526963770389556,0.243400122037645,43793,14957.588129997252,0.9940480589866638,0.0193919260054826,0.6610642591112672,0.9865243434906006,0.0492027401924133,0.2488174179315952,43793 -5218.830732822418,1.2641723155975342,10097.875643253326,31334,0,10097.875643253326,0.9855576157569884,0.0529208593070507,0.2361791444678669,43793,15318.84247136116,0.9936291575431824,0.0207611881196498,0.6274333508360139,0.9864720106124878,0.0494552142918109,0.2492178935171289,43793 -5339.095104217529,1.2972569465637207,10337.898092508316,32075,0,10337.898092508316,0.9856974482536316,0.0535745881497859,0.2412373732637037,43793,15679.182544469832,0.9931486248970032,0.0215517878532409,0.6206713202358298,0.986558437347412,0.0501068904995918,0.2495110815580755,43793 -5457.711098432541,1.3318133354187012,10577.976407766342,32816,0,10577.976407766342,0.9856300950050354,0.0531991049647331,0.2408217543360258,43793,16037.931844472883,0.9932396411895752,0.0215398985892534,0.6191252736954957,0.9865308403968812,0.0497139617800712,0.2507458835114568,43793 -5574.642826318741,1.364924430847168,10818.098178625109,33566,0,10818.098178625109,0.985470414161682,0.05360097438097,0.2351262883357148,43793,16395.03837776184,0.9932653307914734,0.0214857589453458,0.6170808507068615,0.9865471124649048,0.0499107241630554,0.24726883668381,43793 -5692.171427726746,1.3974757194519043,11058.361001968384,34319,0,11058.361001968384,0.9855020046234132,0.0530079714953899,0.2383980651235781,43793,16752.882014513016,0.99366956949234,0.0203478969633579,0.6380343914037621,0.9864545464515686,0.0497109591960907,0.2510879231059746,43793 -5813.030182600021,1.4321606159210205,11298.55425786972,35066,0,11298.55425786972,0.9855298399925232,0.0547693707048893,0.2370133223069865,43793,17113.989943742752,0.9935685992240906,0.0202502589672803,0.6473796514189897,0.9865594506263732,0.0509484857320785,0.2562781973327978,43793 -5930.724012136459,1.4665467739105225,11538.779606342316,35807,0,11538.779606342316,0.9854856133461,0.0541211180388927,0.2356342042180603,43793,17471.963469982147,0.9941927194595336,0.0187986772507429,0.6769564832266253,0.9864675402641296,0.0506184510886669,0.2556838216760408,43793 -6048.290234088898,1.875577449798584,11778.403351545334,36556,0,11778.403351545334,0.9853832721710204,0.0543567053973674,0.2394302539002576,43793,17829.582918167114,0.9948610663414,0.0171932596713304,0.712666872419726,0.9863424897193908,0.0508019998669624,0.2482839235230926,43793 -6165.5471975803375,1.9106330871582031,12018.428232431412,37297,0,12018.428232431412,0.9853453636169434,0.0549082830548286,0.2335472181764784,43793,18186.92015695572,0.9950770735740662,0.0165905263274908,0.7307588502918585,0.9863532185554504,0.0512042194604873,0.2467354335083033,43793 -6284.926635742188,1.9459753036499023,12258.432970285416,38046,0,12258.432970285416,0.9854186177253724,0.0556958429515361,0.234332385061237,43793,18546.359894514084,0.9940646886825562,0.0189915187656879,0.6748852044013662,0.986367642879486,0.0518397986888885,0.2426539176487821,43793 -6400.845049619675,1.979938507080078,12498.420560121536,38788,0,12498.420560121536,0.9852977395057678,0.0552296452224254,0.2311972221425861,43793,18902.31952357292,0.9945722818374634,0.0179280396550893,0.6901741289062857,0.9861927032470704,0.0516926646232605,0.23949731854974,43793 -6518.578117609024,2.013803243637085,12738.50132727623,39532,0,12738.50132727623,0.9853373169898988,0.0559426136314868,0.2273990587653981,43793,19260.18678236008,0.9941584467887878,0.0185711476951837,0.6757015592453715,0.9862909317016602,0.0520914942026138,0.2429719702376356,43793 -6632.365000963211,2.0504045486450195,12978.7029337883,40275,0,12978.7029337883,0.9853878617286682,0.0571408942341804,0.2266936024700547,43793,19614.23141884804,0.9939488172531128,0.0189921930432319,0.6637098607290097,0.98630028963089,0.0531427562236785,0.2391469379130368,43793 -6753.425145626068,2.084524631500244,13218.936644792557,41017,0,13218.936644792557,0.9853529334068298,0.0570392683148384,0.2271843160553856,43793,19975.578807592392,0.993874728679657,0.0191944427788257,0.6714206706921495,0.9862105846405028,0.0532609857618808,0.2385562048796848,43793 -6871.238674879074,2.11984920501709,13458.923202037811,41761,0,13458.923202037811,0.9853084683418274,0.0573048181831836,0.2291360045914827,43793,20333.43417596817,0.994343101978302,0.0179709792137146,0.6866316394471659,0.9861910939216614,0.053494531661272,0.2414926688414058,43793 -6984.798540115356,2.154636144638061,13699.15291404724,42496,0,13699.15291404724,0.985352098941803,0.0576530136168003,0.2288912827762337,43793,20687.27829146385,0.9952325224876404,0.0154723525047302,0.7526896612136231,0.9862304329872132,0.0538027547299861,0.2387060657894706,43793 -7104.342164516449,2.190378427505493,13939.323972702026,43233,0,13939.323972702026,0.9852619171142578,0.0583874657750129,0.2256373697097565,43793,21047.05036020279,0.9956064820289612,0.0144054051488637,0.7794564332097744,0.986202836036682,0.0545261949300766,0.2326384161525987,43793 -7223.574427843094,2.22595739364624,14179.28227376938,43974,0,14179.28227376938,0.9852930903434752,0.058384072035551,0.22733958924486,43793,21406.30087685585,0.9957359433174132,0.014393120072782,0.762248438689672,0.986229658126831,0.0544007755815982,0.2338933867718675,43793 -7342.6540105342865,2.2634549140930176,14419.376125335692,44713,0,14419.376125335692,0.9853048920631408,0.0585244484245777,0.2281588202558792,43793,21765.531745910645,0.9954484105110168,0.0150855518877506,0.7581501313243499,0.9861987829208374,0.054727304726839,0.2367285787323215,43793 -7457.90939617157,2.2987446784973145,14659.376512050629,45457,0,14659.376512050629,0.9852097034454346,0.0588457509875297,0.2212094006776859,43793,22120.84386181832,0.9950907230377196,0.0157091245055198,0.7374597770421244,0.986100137233734,0.0549469888210296,0.2338895422966686,43793 -7573.364289522171,2.335345506668091,14899.567381620407,46200,0,14899.567381620407,0.9853684902191162,0.0596916303038597,0.2266012596243111,43793,22476.54651904106,0.9943767786026,0.0173879079520702,0.7149168519041117,0.9862478971481324,0.0558124035596847,0.2359758331316309,43793 -7689.344358444214,2.3755228519439697,15139.756244659424,46945,0,15139.756244659424,0.9852564930915833,0.0597845762968063,0.2232917765851285,43793,22832.776628017426,0.9943929314613342,0.0173799358308315,0.7088835674631668,0.9861464500427246,0.0558986105024814,0.2357999831802581,43793 -7808.494928121567,2.4163498878479004,15379.908016204834,47690,0,15379.908016204834,0.9851852655410768,0.0603056997060775,0.2267443674197929,43793,23192.14047503472,0.9948909282684326,0.0161126609891653,0.7247512557869674,0.9861387014389038,0.0563165694475173,0.2352581202099787,43793 -7924.379849433899,2.4527883529663086,15619.869005203249,48437,0,15619.869005203249,0.9852004647254944,0.0607633218169212,0.2229074234759547,43793,23548.04278063774,0.99583101272583,0.0137463333085179,0.7851234003757641,0.9861135482788086,0.0566686391830444,0.2343727002930017,43793 -8040.179135560989,2.490166187286377,15860.097087621689,49184,0,15860.097087621689,0.9851077795028688,0.0605574212968349,0.2294661277003323,43793,23904.12767291069,0.9950827360153198,0.0155350798740983,0.746528199797595,0.9861204624176024,0.0562176518142223,0.2370487070055823,43793 -8153.885211467743,2.5272974967956543,16100.1772043705,49935,0,16100.1772043705,0.9851545095443726,0.0619620457291603,0.2204592669232147,43793,24257.970780849457,0.996401846408844,0.0126142036169767,0.8146022634550463,0.9861245155334472,0.0576920099556446,0.2316445354722096,43793 -8270.046475887299,2.562939405441284,16340.363671779633,50680,0,16340.363671779633,0.9852147698402404,0.0626945868134498,0.2215540398824933,43793,24614.37414741516,0.9969805479049684,0.0111792050302028,0.8314752558336872,0.9861451983451844,0.0584052912890911,0.2308138767184729,43793 -8383.532366275787,2.60349440574646,16580.46027135849,51424,0,16580.46027135849,0.985121250152588,0.0625859647989273,0.2194823881609811,43793,24968.01723909378,0.9964102506637572,0.0121396416798233,0.8212676428096144,0.9860566854476928,0.0583216175436973,0.226055681241575,43793 -8497.374136447906,2.6400325298309326,16820.501957178116,52170,0,16820.501957178116,0.98509681224823,0.0628611296415329,0.2198022544309893,43793,25321.957409858704,0.9960031509399414,0.0130105959251523,0.7998794085137176,0.9860563278198242,0.0587018877267837,0.2295651743536794,43793 -8615.72803235054,2.677996158599853,17060.735783100128,52920,0,17060.735783100128,0.9851848483085632,0.0634528473019599,0.2199008821173299,43793,25680.60373020172,0.995352268218994,0.0142620839178562,0.7881900013083352,0.986114740371704,0.0593496747314929,0.225412239521873,43793 -8726.805708408356,2.715599775314331,17300.85765480995,53669,0,17300.85765480995,0.9851098656654358,0.0631409138441085,0.218001164903323,43793,26031.861436128616,0.9954304099082948,0.0140579780563712,0.785213941239531,0.9861135482788086,0.0589690953493118,0.2228026236983691,43793 -8839.761295318604,2.7525038719177246,17541.0244076252,54412,0,17541.0244076252,0.9851149320602416,0.0638305619359016,0.216414058452331,43793,26385.043430566788,0.995689332485199,0.0135448304936289,0.7969386333438457,0.9860498309135436,0.0595028214156627,0.2283987255755745,43793 -8953.982422113419,2.7906315326690674,17781.22932624817,55167,0,17781.22932624817,0.9850315451622008,0.0636202916502952,0.2174672035747131,43793,26739.52824115753,0.9960854053497314,0.0127092683687806,0.8059383088169321,0.9860132932662964,0.0594084672629833,0.2242151660658549,43793 -9071.203595876694,2.8327012062072754,18021.468203783035,55905,0,18021.468203783035,0.9850290417671204,0.0648826137185096,0.2126118880927797,43793,27097.05589222908,0.9968048930168152,0.0110593615099787,0.8559029305644175,0.9860246181488036,0.0603373870253562,0.2279856110308997,43793 -9182.202111959457,2.873664617538452,18261.61960697174,56650,0,18261.61960697174,0.9850433468818665,0.06487562507390976,0.21501642913040128,43793,27448.26845598221,0.9976639151573181,0.009552651084959507,0.8787869832304159,0.985992968082428,0.060625217854976654,0.22720330974813813,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/measurements.csv deleted file mode 100644 index 03ea3ece1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/measurements.csv +++ /dev/null @@ -1,653 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.5307114,0.72719467,,,,,,,,,,,,,,,,, -1,,,0.5325733423233032,0.7276434302330017,0.0231831062687627,0.5230699777603149,0.7331878542900085,0.0255029197338972,43793.0,0.5214916467666626,0.7347381114959717,0.0268309887077987,43793.0,12.211849689483644,137.16028547286987,12.211849689483644,124.94838285446168,0.0,0.0 -100,0.4575896,0.39823782,,,,,,,,,,,,,,,,, -200,0.34824225,0.30278274,,,,,,,,,,,,,,,,, -300,0.2488573,0.20728433,,,,,,,,,,,,,,,,, -400,0.15522498,0.13984518,,,,,,,,,,,,,,,,, -500,0.10214838,0.10229654,,,,,,,,,,,,,,,,, -600,0.07659836,0.07729929,,,,,,,,,,,,,,,,, -700,0.112654574,0.06860004,,,,,,,,,,,,,,,,, -750,,,0.986755669116974,0.0659792274236679,0.0340223151502751,0.9841179251670836,0.0744884759187698,0.0362152987927236,43793.0,0.983142077922821,0.0773934349417686,0.0377247415074901,43793.0,252.3856241703033,496.25007700920105,252.3856241703033,243.82266402244568,0.0216853618621826,0.0 -800,0.080179006,0.06171869,,,,,,,,,,,,,,,,, -900,0.034526605,0.060375728,,,,,,,,,,,,,,,,, -1000,0.06259471,0.056105983,,,,,,,,,,,,,,,,, -1100,0.053084724,0.053710885,,,,,,,,,,,,,,,,, -1200,0.023455517,0.049201973,,,,,,,,,,,,,,,,, -1300,0.043594126,0.054323155,,,,,,,,,,,,,,,,, -1400,0.028014261,0.04838261,,,,,,,,,,,,,,,,, -1498,,,0.9871229529380798,0.04890026897192,0.0824739438043924,0.984372854232788,0.0586623102426528,0.0836942279585238,43793.0,0.9833973050117492,0.0620699375867843,0.0814007044953301,43793.0,492.6344926357269,867.4755127429962,492.6344926357269,374.7523839473725,0.0477268695831298,0.0 -1500,0.09543994,0.057407882,,,,,,,,,,,,,,,,, -1600,0.027494902,0.05001165,,,,,,,,,,,,,,,,, -1700,0.07362223,0.055772226,,,,,,,,,,,,,,,,, -1800,0.032815866,0.046702,,,,,,,,,,,,,,,,, -1900,0.0625103,0.04788953,,,,,,,,,,,,,,,,, -2000,0.06566455,0.049555734,,,,,,,,,,,,,,,,, -2100,0.021909337,0.050403204,,,,,,,,,,,,,,,,, -2200,0.0296982,0.051342785,,,,,,,,,,,,,,,,, -2244,,,0.9873826503753662,0.0455460250377655,0.1263882455282694,0.984746754169464,0.054785881191492,0.1225127042343126,43793.0,0.983756184577942,0.0580357871949672,0.1187178026246305,43793.0,732.8614349365234,1230.154664993286,732.8614349365234,497.15764117240906,0.0744333267211914,0.0 -2300,0.049279805,0.05211553,,,,,,,,,,,,,,,,, -2400,0.03747727,0.045552947,,,,,,,,,,,,,,,,, -2500,0.058719814,0.048502814,,,,,,,,,,,,,,,,, -2600,0.020068184,0.044243548,,,,,,,,,,,,,,,,, -2700,0.05027727,0.045236964,,,,,,,,,,,,,,,,, -2800,0.056023896,0.047664948,,,,,,,,,,,,,,,,, -2900,0.045083914,0.051978864,,,,,,,,,,,,,,,,, -2992,,,0.9876368045806884,0.0436954535543918,0.1569229708133287,0.9849249720573424,0.053019069135189,0.1439419953105915,43793.0,0.9839427471160888,0.056046262383461,0.141987513190904,43793.0,972.8983051776886,1592.0877315998075,972.8983051776886,619.0059733390808,0.1008801460266113,0.0 -3000,0.02333577,0.046816498,,,,,,,,,,,,,,,,, -3100,0.023121856,0.044863053,,,,,,,,,,,,,,,,, -3200,0.04462889,0.04653849,,,,,,,,,,,,,,,,, -3300,0.022815071,0.044486865,,,,,,,,,,,,,,,,, -3400,0.024832305,0.043090284,,,,,,,,,,,,,,,,, -3500,0.033929855,0.044646896,,,,,,,,,,,,,,,,, -3600,0.05837636,0.04453454,,,,,,,,,,,,,,,,, -3700,0.015419278,0.04488977,,,,,,,,,,,,,,,,, -3746,,,0.9878502488136292,0.0424059890210628,0.1737429563879592,0.985138475894928,0.0518514774739742,0.1585169315699912,43793.0,0.9841310381889344,0.0546961985528469,0.1534158208389616,43793.0,1213.0681955814362,1953.5891268253329,1213.0681955814362,740.2879574298859,0.1300337314605713,0.0 -3800,0.018495012,0.042195227,,,,,,,,,,,,,,,,, -3900,0.06671204,0.046040166,,,,,,,,,,,,,,,,, -4000,0.021644497,0.04701069,,,,,,,,,,,,,,,,, -4100,0.019929154,0.04460039,,,,,,,,,,,,,,,,, -4200,0.020918982,0.0410958,,,,,,,,,,,,,,,,, -4300,0.017950984,0.044295475,,,,,,,,,,,,,,,,, -4400,0.016774554,0.042742696,,,,,,,,,,,,,,,,, -4493,,,0.9877445101737976,0.0424637980759143,0.2003246065057177,0.985008955001831,0.0522321462631225,0.1644403373332049,43793.0,0.9840977787971495,0.0551101975142955,0.1655582362910332,43793.0,1453.243047952652,2321.1722240448,1453.243047952652,867.6484885215759,0.1572957038879394,0.0 -4500,0.02045346,0.044842858,,,,,,,,,,,,,,,,, -4600,0.032280497,0.047840357,,,,,,,,,,,,,,,,, -4700,0.021949828,0.04055936,,,,,,,,,,,,,,,,, -4800,0.026293324,0.043701764,,,,,,,,,,,,,,,,, -4900,0.012517816,0.041401137,,,,,,,,,,,,,,,,, -5000,0.022682946,0.044639118,,,,,,,,,,,,,,,,, -5100,0.012252075,0.040674757,,,,,,,,,,,,,,,,, -5200,0.01889694,0.04412396,,,,,,,,,,,,,,,,, -5241,,,0.9884564280509948,0.0398259945213794,0.2073086561436093,0.9855472445487976,0.0492484420537948,0.1844483577769064,43793.0,0.9846347570419312,0.0519169941544532,0.1835544135379636,43793.0,1693.3792762756348,2683.23321557045,1693.3792762756348,989.522251367569,0.1868824958801269,0.0 -5300,0.02030062,0.04201492,,,,,,,,,,,,,,,,, -5400,0.022790497,0.048982352,,,,,,,,,,,,,,,,, -5500,0.014338486,0.044368856,,,,,,,,,,,,,,,,, -5600,0.020271681,0.04490464,,,,,,,,,,,,,,,,, -5700,0.017915707,0.037725464,,,,,,,,,,,,,,,,, -5800,0.029182699,0.043104168,,,,,,,,,,,,,,,,, -5900,0.013659059,0.041171186,,,,,,,,,,,,,,,,, -5976,,,0.988779842853546,0.0390032492578029,0.241654433364394,0.9856743216514589,0.0488737933337688,0.1980541880902195,43793.0,0.9846912026405334,0.0514706633985042,0.1953586825972949,43793.0,1933.3688995838163,3043.0417470932007,1933.3688995838163,1109.290159702301,0.2166662216186523,0.0 -6000,0.022507874,0.03990506,,,,,,,,,,,,,,,,, -6100,0.014033731,0.038198013,,,,,,,,,,,,,,,,, -6200,0.013670044,0.041050818,,,,,,,,,,,,,,,,, -6300,0.016873443,0.041152216,,,,,,,,,,,,,,,,, -6400,0.016814394,0.04294454,,,,,,,,,,,,,,,,, -6500,0.014326274,0.04175656,,,,,,,,,,,,,,,,, -6600,0.018171709,0.04005458,,,,,,,,,,,,,,,,, -6700,0.020429261,0.044185214,,,,,,,,,,,,,,,,, -6723,,,0.9888468384742736,0.0380408428609371,0.2605599738078626,0.9860299229621888,0.047458317130804,0.2136061366712834,43793.0,0.9851002097129822,0.0501127280294895,0.210851843830605,43793.0,2173.444726228714,3405.105672597885,2173.444726228714,1231.229267120361,0.2448453903198242,0.0 -6800,0.01359007,0.04053172,,,,,,,,,,,,,,,,, -6900,0.016228793,0.04148031,,,,,,,,,,,,,,,,, -7000,0.021098958,0.04192704,,,,,,,,,,,,,,,,, -7100,0.013487282,0.045937426,,,,,,,,,,,,,,,,, -7200,0.019881934,0.04100923,,,,,,,,,,,,,,,,, -7300,0.016583173,0.041413464,,,,,,,,,,,,,,,,, -7400,0.016211178,0.043200694,,,,,,,,,,,,,,,,, -7472,,,0.9890567660331726,0.0371920876204967,0.2785736894766114,0.9860782027244568,0.0469764843583107,0.2222152541430098,43793.0,0.9851503372192384,0.0495170578360557,0.2198112775470612,43793.0,2413.602122068405,3766.124802827835,2413.602122068405,1352.0426914691925,0.2722771167755127,0.0 -7500,0.01351139,0.04286298,,,,,,,,,,,,,,,,, -7600,0.023283377,0.042902883,,,,,,,,,,,,,,,,, -7700,0.017856704,0.03956027,,,,,,,,,,,,,,,,, -7800,0.02043879,0.041695528,,,,,,,,,,,,,,,,, -7900,0.02380284,0.04036782,,,,,,,,,,,,,,,,, -8000,0.020585267,0.042755272,,,,,,,,,,,,,,,,, -8100,0.020426944,0.040925533,,,,,,,,,,,,,,,,, -8200,0.020352779,0.037595212,,,,,,,,,,,,,,,,, -8212,,,0.9891743659973145,0.0361531190574169,0.3071437430999844,0.9860546588897704,0.0470578074455261,0.2279129954488742,43793.0,0.9851882457733154,0.0498595125973224,0.2298814156917866,43793.0,2653.557772874832,4131.996986627579,2653.557772874832,1477.9091057777405,0.3021590709686279,0.0 -8300,0.012609688,0.040621053,,,,,,,,,,,,,,,,, -8400,0.014167798,0.041110847,,,,,,,,,,,,,,,,, -8500,0.015668454,0.039577913,,,,,,,,,,,,,,,,, -8600,0.018637907,0.039343122,,,,,,,,,,,,,,,,, -8700,0.015727207,0.03963267,,,,,,,,,,,,,,,,, -8800,0.015710765,0.04290251,,,,,,,,,,,,,,,,, -8900,0.015257523,0.04158176,,,,,,,,,,,,,,,,, -8961,,,0.9895905256271362,0.0351231172680854,0.3273462661718691,0.9861902594566344,0.0463905036449432,0.2325769338745914,43793.0,0.9853874444961548,0.0490616783499717,0.2329803435222461,43793.0,2893.710347890854,4495.563076972961,2893.710347890854,1601.2729868888855,0.3315138816833496,0.0 -9000,0.016458008,0.043357458,,,,,,,,,,,,,,,,, -9100,0.016379768,0.041480288,,,,,,,,,,,,,,,,, -9200,0.013542288,0.040342793,,,,,,,,,,,,,,,,, -9300,0.018393839,0.04147755,,,,,,,,,,,,,,,,, -9400,0.016194876,0.040211614,,,,,,,,,,,,,,,,, -9500,0.012470877,0.038774967,,,,,,,,,,,,,,,,, -9600,0.017296275,0.035216194,,,,,,,,,,,,,,,,, -9700,0.017144734,0.040592488,,,,,,,,,,,,,,,,, -9718,,,0.9898381233215332,0.0341963469982147,0.3502433378063998,0.9864122867584229,0.0458182692527771,0.2419818426860593,43793.0,0.985545814037323,0.0485690161585807,0.237077607090954,43793.0,3133.842006921768,4856.113333940506,3133.842006921768,1721.642301082611,0.360389232635498,0.0 -9800,0.018801143,0.04198088,,,,,,,,,,,,,,,,, -9900,0.018857373,0.042757533,,,,,,,,,,,,,,,,, -10000,0.016186196,0.036521487,,,,,,,,,,,,,,,,, -10100,0.01921311,0.039273802,,,,,,,,,,,,,,,,, -10200,0.024957757,0.038829606,,,,,,,,,,,,,,,,, -10300,0.019052083,0.04260381,,,,,,,,,,,,,,,,, -10400,0.02268876,0.0402886,,,,,,,,,,,,,,,,, -10465,,,0.989568829536438,0.0340709760785102,0.3687249605372311,0.986235737800598,0.0466529317200183,0.237131357755388,43793.0,0.9853895902633668,0.0494105294346809,0.2395083205783805,43793.0,3373.964447259903,5219.977976083756,3373.964447259903,1845.3364737033844,0.388399600982666,0.0 -10500,0.021133646,0.04280686,,,,,,,,,,,,,,,,, -10600,0.015976485,0.037413314,,,,,,,,,,,,,,,,, -10700,0.024421146,0.04422278,,,,,,,,,,,,,,,,, -10800,0.016302148,0.037329108,,,,,,,,,,,,,,,,, -10900,0.016947009,0.04074273,,,,,,,,,,,,,,,,, -11000,0.0151487235,0.041951187,,,,,,,,,,,,,,,,, -11100,0.016470963,0.040031537,,,,,,,,,,,,,,,,, -11200,0.013566753,0.03936735,,,,,,,,,,,,,,,,, -11209,,,0.9899488091468812,0.0335929244756698,0.3595911836357722,0.9864537119865416,0.0456045381724834,0.2528149574337986,43793.0,0.985528528690338,0.0485292673110961,0.2416833541478927,43793.0,3614.095441818237,5584.160804271698,3614.095441818237,1969.332597494125,0.4206776618957519,0.0 -11300,0.0179295,0.038407974,,,,,,,,,,,,,,,,, -11400,0.015212051,0.03643754,,,,,,,,,,,,,,,,, -11500,0.018379144,0.038650215,,,,,,,,,,,,,,,,, -11600,0.014567685,0.03908965,,,,,,,,,,,,,,,,, -11700,0.025370674,0.039310977,,,,,,,,,,,,,,,,, -11800,0.01340015,0.036739647,,,,,,,,,,,,,,,,, -11900,0.017354827,0.040682737,,,,,,,,,,,,,,,,, -11956,,,0.9904040098190308,0.0325356610119342,0.3633715206238281,0.986591339111328,0.0451359152793884,0.2521590708519562,43793.0,0.985745906829834,0.0478426851332187,0.2463089958443751,43793.0,3854.335582494736,5944.627243518829,3854.335582494736,2089.5103764534,0.4484295845031738,0.0 -12000,0.0172082,0.039771426,,,,,,,,,,,,,,,,, -12100,0.015199969,0.04027798,,,,,,,,,,,,,,,,, -12200,0.022197917,0.038270157,,,,,,,,,,,,,,,,, -12300,0.015178836,0.03561281,,,,,,,,,,,,,,,,, -12400,0.017830381,0.035985943,,,,,,,,,,,,,,,,, -12500,0.015524568,0.03672164,,,,,,,,,,,,,,,,, -12600,0.017727626,0.040859792,,,,,,,,,,,,,,,,, -12700,0.018033953,0.037584644,,,,,,,,,,,,,,,,, -12707,,,0.9905290603637696,0.0317256972193717,0.3934554074221962,0.9866530299186708,0.0453526303172111,0.2577573768745234,43793.0,0.9858217239379884,0.0480994060635566,0.2571150736686474,43793.0,4094.5407037734985,6304.456825494766,4094.5407037734985,2209.084566354752,0.4778716564178467,0.0 -12800,0.017784126,0.038603827,,,,,,,,,,,,,,,,, -12900,0.01797836,0.03557872,,,,,,,,,,,,,,,,, -13000,0.01688702,0.03650944,,,,,,,,,,,,,,,,, -13100,0.02365772,0.040618878,,,,,,,,,,,,,,,,, -13200,0.022297693,0.037127152,,,,,,,,,,,,,,,,, -13300,0.015810046,0.036530986,,,,,,,,,,,,,,,,, -13400,0.023188733,0.04089543,,,,,,,,,,,,,,,,, -13455,,,0.9906842708587646,0.0313801318407058,0.3929484641070109,0.9865840077400208,0.045293316245079,0.253105285712875,43793.0,0.9856991171836852,0.0480859875679016,0.2459188570807226,43793.0,4334.822235822678,6665.2435131073,4334.822235822678,2329.5390541553497,0.5078516006469727,0.0 -13500,0.017959362,0.033813965,,,,,,,,,,,,,,,,, -13600,0.019081108,0.03650151,,,,,,,,,,,,,,,,, -13700,0.0211855,0.039562862,,,,,,,,,,,,,,,,, -13800,0.017586792,0.039498016,,,,,,,,,,,,,,,,, -13900,0.019152563,0.03665051,,,,,,,,,,,,,,,,, -14000,0.01876549,0.036768716,,,,,,,,,,,,,,,,, -14100,0.020829566,0.036592837,,,,,,,,,,,,,,,,, -14200,0.021209016,0.035200413,,,,,,,,,,,,,,,,, -14202,,,0.9907562732696532,0.0304472688585519,0.428713869830881,0.9867236614227296,0.0452100075781345,0.2582183376390017,43793.0,0.9858486652374268,0.0480134636163711,0.2553494129192967,43793.0,4574.961861371994,7025.507632255554,4574.961861371994,2449.6124868392944,0.5385310649871826,0.0 -14300,0.017949827,0.036711432,,,,,,,,,,,,,,,,, -14400,0.02253605,0.03789681,,,,,,,,,,,,,,,,, -14500,0.019664155,0.034261573,,,,,,,,,,,,,,,,, -14600,0.022710487,0.036515977,,,,,,,,,,,,,,,,, -14700,0.018364906,0.035958957,,,,,,,,,,,,,,,,, -14800,0.031274345,0.03915197,,,,,,,,,,,,,,,,, -14900,0.023107177,0.04024801,,,,,,,,,,,,,,,,, -14948,,,0.99077308177948,0.0302594359964132,0.4277688119308474,0.9866668581962584,0.0454121641814708,0.2544344404267181,43793.0,0.9857577085494996,0.0482803620398044,0.2439436774518225,43793.0,4815.122054815292,7388.161834478378,4815.122054815292,2572.057140827179,0.5675864219665527,0.0 -15000,0.018873038,0.037292976,,,,,,,,,,,,,,,,, -15100,0.018095901,0.036610924,,,,,,,,,,,,,,,,, -15200,0.017819962,0.036614645,,,,,,,,,,,,,,,,, -15300,0.018941054,0.037947945,,,,,,,,,,,,,,,,, -15400,0.020427847,0.037739586,,,,,,,,,,,,,,,,, -15500,0.020637585,0.036871552,,,,,,,,,,,,,,,,, -15600,0.021823987,0.038709503,,,,,,,,,,,,,,,,, -15697,,,0.9912256598472596,0.0288708545267581,0.4513888298738659,0.9866424798965454,0.0453258045017719,0.2573368540704658,43793.0,0.9857909679412842,0.0481221489608287,0.2465367554993298,43793.0,5055.139110803604,7750.152556419373,5055.139110803604,2693.980701684952,0.5970251560211182,0.0 -15700,0.02213493,0.035296943,,,,,,,,,,,,,,,,, -15800,0.01732784,0.03446562,,,,,,,,,,,,,,,,, -15900,0.020008028,0.03656383,,,,,,,,,,,,,,,,, -16000,0.025535755,0.037125845,,,,,,,,,,,,,,,,, -16100,0.020650037,0.036147792,,,,,,,,,,,,,,,,, -16200,0.0211103,0.036051854,,,,,,,,,,,,,,,,, -16300,0.020504968,0.033377558,,,,,,,,,,,,,,,,, -16400,0.019707663,0.035685737,,,,,,,,,,,,,,,,, -16447,,,0.9916675090789796,0.0275912955403327,0.4919298315701566,0.9866794347763062,0.0455493815243244,0.2621947058241379,43793.0,0.9858705401420592,0.048255406320095,0.2530989906320853,43793.0,5295.112805843353,8109.461350440979,5295.112805843353,2813.2641813755035,0.6277570724487305,0.0 -16500,0.021225763,0.03405424,,,,,,,,,,,,,,,,, -16600,0.020495437,0.035730004,,,,,,,,,,,,,,,,, -16700,0.021696294,0.037164513,,,,,,,,,,,,,,,,, -16800,0.021815687,0.036828097,,,,,,,,,,,,,,,,, -16900,0.02113127,0.03657432,,,,,,,,,,,,,,,,, -17000,0.019610927,0.037162304,,,,,,,,,,,,,,,,, -17100,0.02438424,0.036129482,,,,,,,,,,,,,,,,, -17191,,,0.9916927814483644,0.0272750835865736,0.5031697196652744,0.986777663230896,0.0452458262443542,0.265259789435557,43793.0,0.9858924746513368,0.048205729573965,0.2521990759737141,43793.0,5535.315126657486,8471.188846826553,5535.315126657486,2934.739273548126,0.6577551364898682,0.0 -17200,0.021680932,0.03593656,,,,,,,,,,,,,,,,, -17300,0.021483535,0.033582915,,,,,,,,,,,,,,,,, -17400,0.021980796,0.03660341,,,,,,,,,,,,,,,,, -17500,0.020737668,0.037886932,,,,,,,,,,,,,,,,, -17600,0.022244047,0.035054766,,,,,,,,,,,,,,,,, -17700,0.019977314,0.035864808,,,,,,,,,,,,,,,,, -17800,0.027030464,0.037724786,,,,,,,,,,,,,,,,, -17900,0.02625294,0.037979,,,,,,,,,,,,,,,,, -17944,,,0.9915154576301576,0.0276951789855957,0.5012935847392078,0.9867122769355774,0.0459314920008182,0.2675033684853384,43793.0,0.9857863187789916,0.0490991212427616,0.2485788179160932,43793.0,5775.409202814102,8833.050078868866,5775.409202814102,3056.4563794136047,0.6870527267456055,0.0 -18000,0.033569742,0.036414385,,,,,,,,,,,,,,,,, -18100,0.024429232,0.036436364,,,,,,,,,,,,,,,,, -18200,0.023744823,0.035925988,,,,,,,,,,,,,,,,, -18300,0.035400193,0.034974568,,,,,,,,,,,,,,,,, -18400,0.024910431,0.036101285,,,,,,,,,,,,,,,,, -18500,0.026043098,0.034752,,,,,,,,,,,,,,,,, -18600,0.02241981,0.03625252,,,,,,,,,,,,,,,,, -18691,,,0.9915038347244264,0.0279141776263713,0.4830483010954101,0.9867801070213318,0.0454021245241165,0.2609649209664075,43793.0,0.985912263393402,0.0482953153550624,0.2565524213491001,43793.0,6015.5455322265625,9190.930800199509,6015.5455322265625,3174.150098800659,0.7177164554595947,0.0 -18700,0.021974519,0.03175946,,,,,,,,,,,,,,,,, -18800,0.026578479,0.034032382,,,,,,,,,,,,,,,,, -18900,0.025234062,0.03547452,,,,,,,,,,,,,,,,, -19000,0.024748111,0.036081545,,,,,,,,,,,,,,,,, -19100,0.029775426,0.034614164,,,,,,,,,,,,,,,,, -19200,0.025228338,0.032244533,,,,,,,,,,,,,,,,, -19300,0.030099506,0.033957884,,,,,,,,,,,,,,,,, -19400,0.021308804,0.033313304,,,,,,,,,,,,,,,,, -19435,,,0.9913306832313538,0.0281452555209398,0.4829667100200081,0.9866834878921508,0.0461911484599113,0.260969365659252,43793.0,0.9858874082565308,0.0491726137697696,0.2569502106290384,43793.0,6255.649400234222,9551.801201581957,6255.649400234222,3294.866469860077,0.7476849555969238,0.0 -19500,0.023814991,0.032818586,,,,,,,,,,,,,,,,, -19600,0.027679661,0.036487274,,,,,,,,,,,,,,,,, -19700,0.027886517,0.034840316,,,,,,,,,,,,,,,,, -19800,0.02461941,0.0371136,,,,,,,,,,,,,,,,, -19900,0.024778074,0.0325737,,,,,,,,,,,,,,,,, -20000,0.024403289,0.033903785,,,,,,,,,,,,,,,,, -20100,0.03270416,0.03284584,,,,,,,,,,,,,,,,, -20189,,,0.9917463064193726,0.0269054062664508,0.4939934566066215,0.9866786003112792,0.0461668036878109,0.258319712197705,43793.0,0.9858680367469788,0.04932626709342,0.2496602999371342,43793.0,6495.83327627182,9908.72722864151,6495.83327627182,3411.558787584305,0.7772469520568848,0.0 -20200,0.024185676,0.032929763,,,,,,,,,,,,,,,,, -20300,0.022505125,0.032818355,,,,,,,,,,,,,,,,, -20400,0.030053776,0.03271638,,,,,,,,,,,,,,,,, -20500,0.030668806,0.035496976,,,,,,,,,,,,,,,,, -20600,0.027248735,0.030197779,,,,,,,,,,,,,,,,, -20700,0.031504907,0.035352618,,,,,,,,,,,,,,,,, -20800,0.027199201,0.034249343,,,,,,,,,,,,,,,,, -20900,0.030869944,0.032920614,,,,,,,,,,,,,,,,, -20938,,,0.9917852282524108,0.0268196295946836,0.505956641813925,0.9866956472396852,0.0466153770685195,0.257373007364516,43793.0,0.9857537150382996,0.05003098025918,0.2476189983255035,43793.0,6736.029312372208,10266.563576698303,6736.029312372208,3529.147953033448,0.8081989288330078,0.0 -21000,0.028925113,0.03443213,,,,,,,,,,,,,,,,, -21100,0.024941303,0.032499406,,,,,,,,,,,,,,,,, -21200,0.023362083,0.033196032,,,,,,,,,,,,,,,,, -21300,0.02936327,0.03465717,,,,,,,,,,,,,,,,, -21400,0.027894676,0.035199773,,,,,,,,,,,,,,,,, -21500,0.04110706,0.03197911,,,,,,,,,,,,,,,,, -21600,0.033024523,0.035772845,,,,,,,,,,,,,,,,, -21689,,,0.9920173287391664,0.0260045584291219,0.523237135827197,0.9867163896560668,0.0463523082435131,0.2581223071410054,43793.0,0.9857938885688782,0.0493037439882755,0.2549565183568808,43793.0,6976.128599882126,10624.498348236084,6976.128599882126,3646.929877519608,0.841012716293335,0.0 -21700,0.028939622,0.03239491,,,,,,,,,,,,,,,,, -21800,0.02793039,0.033494595,,,,,,,,,,,,,,,,, -21900,0.029349469,0.03320926,,,,,,,,,,,,,,,,, -22000,0.030144438,0.03286212,,,,,,,,,,,,,,,,, -22100,0.026493035,0.031319268,,,,,,,,,,,,,,,,, -22200,0.032170184,0.03449501,,,,,,,,,,,,,,,,, -22300,0.031247744,0.03436332,,,,,,,,,,,,,,,,, -22400,0.034821533,0.034177482,,,,,,,,,,,,,,,,, -22436,,,0.9925153851509094,0.0246449559926986,0.5653510078298822,0.9865344762802124,0.0464587174355983,0.2524689786088972,43793.0,0.9857235550880432,0.0491550788283348,0.2442960371593634,43793.0,7216.098851442337,10983.145557641985,7216.098851442337,3765.5550322532654,0.8723323345184326,0.0 -22500,0.030505324,0.030592289,,,,,,,,,,,,,,,,, -22600,0.038263626,0.03585213,,,,,,,,,,,,,,,,, -22700,0.036037434,0.032974664,,,,,,,,,,,,,,,,, -22800,0.033563502,0.03615954,,,,,,,,,,,,,,,,, -22900,0.034355015,0.031798985,,,,,,,,,,,,,,,,, -23000,0.033335913,0.034134813,,,,,,,,,,,,,,,,, -23100,0.029935801,0.0327351,,,,,,,,,,,,,,,,, -23182,,,0.9926448464393616,0.0239301789551973,0.5750513515882668,0.9865629076957704,0.0472385957837104,0.2531709487648458,43793.0,0.9857202172279358,0.0501690804958343,0.2515987042061027,43793.0,7456.205724477768,11344.477650642397,7456.205724477768,3886.729905128479,0.9026894569396972,0.0 -23200,0.028621104,0.030944815,,,,,,,,,,,,,,,,, -23300,0.032558683,0.0327543,,,,,,,,,,,,,,,,, -23400,0.03118568,0.034567375,,,,,,,,,,,,,,,,, -23500,0.037754685,0.03225824,,,,,,,,,,,,,,,,, -23600,0.027658155,0.032662448,,,,,,,,,,,,,,,,, -23700,0.03228615,0.03387656,,,,,,,,,,,,,,,,, -23800,0.034870688,0.032657072,,,,,,,,,,,,,,,,, -23900,0.034041338,0.03373994,,,,,,,,,,,,,,,,, -23933,,,0.9929741024971008,0.0230709183961153,0.5870177889735277,0.9866250157356262,0.0469329915940761,0.2611496862478765,43793.0,0.9857092499732972,0.0499692484736442,0.2528821126225342,43793.0,7696.241092205048,11703.929669380188,7696.241092205048,4006.09627699852,0.9327542781829834,0.0 -24000,0.036383603,0.034766182,,,,,,,,,,,,,,,,, -24100,0.037283033,0.03096128,,,,,,,,,,,,,,,,, -24200,0.03869085,0.03456052,,,,,,,,,,,,,,,,, -24300,0.03844648,0.03385853,,,,,,,,,,,,,,,,, -24400,0.03776251,0.031770118,,,,,,,,,,,,,,,,, -24500,0.035767388,0.033331472,,,,,,,,,,,,,,,,, -24600,0.036856472,0.032254264,,,,,,,,,,,,,,,,, -24676,,,0.9929828643798828,0.0229916647076606,0.5987426324371437,0.9867159724235536,0.04694814234972,0.2613548949456581,43793.0,0.9857812523841858,0.0499368458986282,0.2494686159353731,43793.0,7936.516567230225,12066.811733722689,7936.516567230225,4128.651603221893,0.9643106460571288,0.0 -24700,0.043577626,0.035313435,,,,,,,,,,,,,,,,, -24800,0.032208107,0.03054995,,,,,,,,,,,,,,,,, -24900,0.031087097,0.030350314,,,,,,,,,,,,,,,,, -25000,0.030518433,0.030008813,,,,,,,,,,,,,,,,, -25100,0.035685793,0.033159316,,,,,,,,,,,,,,,,, -25200,0.030565092,0.03112573,,,,,,,,,,,,,,,,, -25300,0.0345201,0.034006607,,,,,,,,,,,,,,,,, -25400,0.04451886,0.031044753,,,,,,,,,,,,,,,,, -25421,,,0.992374300956726,0.0245363339781761,0.5646499916322825,0.9865953922271729,0.0482317917048931,0.2583915106640733,43793.0,0.98576021194458,0.0513744801282882,0.2480600840810696,43793.0,8176.752241849899,12431.66696190834,8176.752241849899,4253.220617294312,0.9951791763305664,0.0 -25500,0.050080743,0.03289541,,,,,,,,,,,,,,,,, -25600,0.043721694,0.032262664,,,,,,,,,,,,,,,,, -25700,0.0408738,0.032589674,,,,,,,,,,,,,,,,, -25800,0.041087456,0.03226086,,,,,,,,,,,,,,,,, -25900,0.040518854,0.032990273,,,,,,,,,,,,,,,,, -26000,0.034151834,0.031029489,,,,,,,,,,,,,,,,, -26100,0.036539398,0.030614702,,,,,,,,,,,,,,,,, -26160,,,0.9926239848136902,0.0240995120257139,0.5744237415586881,0.986622154712677,0.0471982508897781,0.2568303117982489,43793.0,0.9857054352760316,0.0503102540969848,0.2487609686306002,43793.0,8416.71799492836,12787.58046579361,8416.71799492836,4369.113248348236,1.0298182964324951,0.0 -26200,0.035258166,0.03281574,,,,,,,,,,,,,,,,, -26300,0.04199033,0.033118006,,,,,,,,,,,,,,,,, -26400,0.036179993,0.03144126,,,,,,,,,,,,,,,,, -26500,0.038258668,0.032181006,,,,,,,,,,,,,,,,, -26600,0.0368844,0.032223146,,,,,,,,,,,,,,,,, -26700,0.035066634,0.031124154,,,,,,,,,,,,,,,,, -26800,0.04018282,0.032254115,,,,,,,,,,,,,,,,, -26900,0.033643793,0.029247534,,,,,,,,,,,,,,,,, -26901,,,0.9923655986785888,0.0245582349598407,0.5520001620106612,0.9866526126861572,0.0479371212422847,0.2553553219675096,43793.0,0.9857370257377625,0.0514483153820037,0.2439620641067449,43793.0,8656.93652176857,13150.27823996544,8656.93652176857,4491.541553258896,1.0609753131866455,0.0 -27000,0.045865357,0.034487713,,,,,,,,,,,,,,,,, -27100,0.044700924,0.031564113,,,,,,,,,,,,,,,,, -27200,0.042989373,0.032198325,,,,,,,,,,,,,,,,, -27300,0.04286168,0.03302141,,,,,,,,,,,,,,,,, -27400,0.03555391,0.03260705,,,,,,,,,,,,,,,,, -27500,0.045457907,0.030705567,,,,,,,,,,,,,,,,, -27600,0.042715203,0.03151455,,,,,,,,,,,,,,,,, -27633,,,0.9923880100250244,0.0241682678461074,0.5654214180134779,0.9865978360176086,0.0485782362520694,0.2502639805699408,43793.0,0.98567134141922,0.0520597845315933,0.2451742423721277,43793.0,8897.207918643951,13512.87291264534,8897.207918643951,4613.804739713669,1.095686435699463,0.0 -27700,0.043871336,0.03339775,,,,,,,,,,,,,,,,, -27800,0.04082134,0.032605335,,,,,,,,,,,,,,,,, -27900,0.033166256,0.029382396,,,,,,,,,,,,,,,,, -28000,0.039667353,0.031547237,,,,,,,,,,,,,,,,, -28100,0.040137857,0.03373124,,,,,,,,,,,,,,,,, -28200,0.049515914,0.03142295,,,,,,,,,,,,,,,,, -28300,0.040105093,0.029479917,,,,,,,,,,,,,,,,, -28375,,,0.9930915832519532,0.0223762076348066,0.5991473293839304,0.9865767359733582,0.0481585152447223,0.2556357184836858,43793.0,0.9856898784637452,0.051655750721693,0.2430242722986725,43793.0,9137.189074516296,13872.593567609789,9137.189074516296,4733.491109848023,1.128706693649292,0.0 -28400,0.04430613,0.030921023,,,,,,,,,,,,,,,,, -28500,0.043476496,0.030541215,,,,,,,,,,,,,,,,, -28600,0.048162624,0.033926856,,,,,,,,,,,,,,,,, -28700,0.040994145,0.03064178,,,,,,,,,,,,,,,,, -28800,0.037843615,0.029555999,,,,,,,,,,,,,,,,, -28900,0.04814222,0.033907313,,,,,,,,,,,,,,,,, -29000,0.0437126,0.031439133,,,,,,,,,,,,,,,,, -29100,0.041954413,0.03125001,,,,,,,,,,,,,,,,, -29123,,,0.9930724501609802,0.0221859905868768,0.617831682258646,0.9865877032279968,0.0487930439412593,0.2522404282257972,43793.0,0.9857964515686036,0.0519480742514133,0.2473759793119771,43793.0,9377.43240237236,14233.759857654572,9377.43240237236,4854.35968875885,1.1624870300292969,0.0 -29200,0.048280086,0.03164142,,,,,,,,,,,,,,,,, -29300,0.04523831,0.032463033,,,,,,,,,,,,,,,,, -29400,0.043977626,0.032239918,,,,,,,,,,,,,,,,, -29500,0.045149747,0.03122119,,,,,,,,,,,,,,,,, -29600,0.046281353,0.032030124,,,,,,,,,,,,,,,,, -29700,0.044754274,0.030021293,,,,,,,,,,,,,,,,, -29800,0.04715537,0.030689633,,,,,,,,,,,,,,,,, -29861,,,0.993553876876831,0.020813263952732,0.6401829967227319,0.9865182638168336,0.0486945249140262,0.2477831701067681,43793.0,0.9856991171836852,0.0519113168120384,0.2484655358142954,43793.0,9617.391327857971,14591.459522247314,9617.391327857971,4972.047943592072,1.1940855979919434,0.0 -29900,0.044958077,0.030320281,,,,,,,,,,,,,,,,, -30000,0.045772735,0.032990903,,,,,,,,,,,,,,,,, -30100,0.050987203,0.03192321,,,,,,,,,,,,,,,,, -30200,0.04358213,0.028647669,,,,,,,,,,,,,,,,, -30300,0.047223825,0.03173154,,,,,,,,,,,,,,,,, -30400,0.053093147,0.030231172,,,,,,,,,,,,,,,,, -30500,0.04143445,0.028456507,,,,,,,,,,,,,,,,, -30589,,,0.9940480589866638,0.0193919260054826,0.6610642591112672,0.9865243434906006,0.0492027401924133,0.2488174179315952,43793.0,0.9856089949607848,0.0526963770389556,0.243400122037645,43793.0,9857.6364672184,14957.588129997252,9857.6364672184,5097.874583005905,1.2269041538238523,0.0 -30600,0.043436196,0.0317331,,,,,,,,,,,,,,,,, -30700,0.04506129,0.029506326,,,,,,,,,,,,,,,,, -30800,0.05534754,0.033614364,,,,,,,,,,,,,,,,, -30900,0.0530824,0.031340092,,,,,,,,,,,,,,,,, -31000,0.049780138,0.029640377,,,,,,,,,,,,,,,,, -31100,0.04525036,0.03072539,,,,,,,,,,,,,,,,, -31200,0.043695055,0.03036028,,,,,,,,,,,,,,,,, -31300,0.051319793,0.029705018,,,,,,,,,,,,,,,,, -31334,,,0.9936291575431824,0.0207611881196498,0.6274333508360139,0.9864720106124878,0.0494552142918109,0.2492178935171289,43793.0,0.9855576157569884,0.0529208593070507,0.2361791444678669,43793.0,10097.875643253326,15318.84247136116,10097.875643253326,5218.830732822418,1.2641723155975342,0.0 -31400,0.050723188,0.02929955,,,,,,,,,,,,,,,,, -31500,0.048382863,0.028623588,,,,,,,,,,,,,,,,, -31600,0.048792534,0.031918265,,,,,,,,,,,,,,,,, -31700,0.055645566,0.03111296,,,,,,,,,,,,,,,,, -31800,0.046303365,0.030002091,,,,,,,,,,,,,,,,, -31900,0.052671522,0.029967086,,,,,,,,,,,,,,,,, -32000,0.055413634,0.031398367,,,,,,,,,,,,,,,,, -32075,,,0.9931486248970032,0.0215517878532409,0.6206713202358298,0.986558437347412,0.0501068904995918,0.2495110815580755,43793.0,0.9856974482536316,0.0535745881497859,0.2412373732637037,43793.0,10337.898092508316,15679.182544469832,10337.898092508316,5339.095104217529,1.2972569465637207,0.0 -32100,0.05602487,0.029106958,,,,,,,,,,,,,,,,, -32200,0.05472095,0.03149991,,,,,,,,,,,,,,,,, -32300,0.057173096,0.029994197,,,,,,,,,,,,,,,,, -32400,0.04865051,0.030882442,,,,,,,,,,,,,,,,, -32500,0.05410971,0.030119041,,,,,,,,,,,,,,,,, -32600,0.06867183,0.031466264,,,,,,,,,,,,,,,,, -32700,0.05679295,0.028507927,,,,,,,,,,,,,,,,, -32800,0.05077736,0.030698773,,,,,,,,,,,,,,,,, -32816,,,0.9932396411895752,0.0215398985892534,0.6191252736954957,0.9865308403968812,0.0497139617800712,0.2507458835114568,43793.0,0.9856300950050354,0.0531991049647331,0.2408217543360258,43793.0,10577.976407766342,16037.931844472883,10577.976407766342,5457.711098432541,1.3318133354187012,0.0 -32900,0.060257535,0.029050464,,,,,,,,,,,,,,,,, -33000,0.05246751,0.0284632,,,,,,,,,,,,,,,,, -33100,0.07067796,0.030083448,,,,,,,,,,,,,,,,, -33200,0.05309481,0.029865881,,,,,,,,,,,,,,,,, -33300,0.06438128,0.029283674,,,,,,,,,,,,,,,,, -33400,0.056413893,0.030924315,,,,,,,,,,,,,,,,, -33500,0.059394572,0.029287687,,,,,,,,,,,,,,,,, -33566,,,0.9932653307914734,0.0214857589453458,0.6170808507068615,0.9865471124649048,0.0499107241630554,0.24726883668381,43793.0,0.985470414161682,0.05360097438097,0.2351262883357148,43793.0,10818.098178625109,16395.03837776184,10818.098178625109,5574.642826318741,1.364924430847168,0.0 -33600,0.05536055,0.028172163,,,,,,,,,,,,,,,,, -33700,0.058178175,0.027568864,,,,,,,,,,,,,,,,, -33800,0.055886425,0.028150631,,,,,,,,,,,,,,,,, -33900,0.051238712,0.02953071,,,,,,,,,,,,,,,,, -34000,0.0505256,0.029451786,,,,,,,,,,,,,,,,, -34100,0.064264975,0.029106563,,,,,,,,,,,,,,,,, -34200,0.05573601,0.030538557,,,,,,,,,,,,,,,,, -34300,0.06300688,0.028284283,,,,,,,,,,,,,,,,, -34319,,,0.99366956949234,0.0203478969633579,0.6380343914037621,0.9864545464515686,0.0497109591960907,0.2510879231059746,43793.0,0.9855020046234132,0.0530079714953899,0.2383980651235781,43793.0,11058.361001968384,16752.882014513016,11058.361001968384,5692.171427726746,1.3974757194519043,0.0 -34400,0.06291304,0.029956097,,,,,,,,,,,,,,,,, -34500,0.053692296,0.02830443,,,,,,,,,,,,,,,,, -34600,0.055925,0.02670628,,,,,,,,,,,,,,,,, -34700,0.05794457,0.028756378,,,,,,,,,,,,,,,,, -34800,0.057002425,0.028311713,,,,,,,,,,,,,,,,, -34900,0.054363374,0.029886667,,,,,,,,,,,,,,,,, -35000,0.06054634,0.028951656,,,,,,,,,,,,,,,,, -35066,,,0.9935685992240906,0.0202502589672803,0.6473796514189897,0.9865594506263732,0.0509484857320785,0.2562781973327978,43793.0,0.9855298399925232,0.0547693707048893,0.2370133223069865,43793.0,11298.55425786972,17113.989943742752,11298.55425786972,5813.030182600021,1.4321606159210205,0.0 -35100,0.08765052,0.030767784,,,,,,,,,,,,,,,,, -35200,0.06007466,0.029275443,,,,,,,,,,,,,,,,, -35300,0.062963426,0.028638203,,,,,,,,,,,,,,,,, -35400,0.056084935,0.028848108,,,,,,,,,,,,,,,,, -35500,0.068567716,0.030891055,,,,,,,,,,,,,,,,, -35600,0.058005564,0.028359132,,,,,,,,,,,,,,,,, -35700,0.060060397,0.028677808,,,,,,,,,,,,,,,,, -35800,0.06183492,0.030603895,,,,,,,,,,,,,,,,, -35807,,,0.9941927194595336,0.0187986772507429,0.6769564832266253,0.9864675402641296,0.0506184510886669,0.2556838216760408,43793.0,0.9854856133461,0.0541211180388927,0.2356342042180603,43793.0,11538.779606342316,17471.963469982147,11538.779606342316,5930.724012136459,1.4665467739105225,0.0 -35900,0.06283609,0.030241571,,,,,,,,,,,,,,,,, -36000,0.06587359,0.028379418,,,,,,,,,,,,,,,,, -36100,0.061035696,0.028539347,,,,,,,,,,,,,,,,, -36200,0.07497358,0.029561337,,,,,,,,,,,,,,,,, -36300,0.07157966,0.031130968,,,,,,,,,,,,,,,,, -36400,0.060803287,0.02821397,,,,,,,,,,,,,,,,, -36500,0.05859475,0.029442059,,,,,,,,,,,,,,,,, -36556,,,0.9948610663414,0.0171932596713304,0.712666872419726,0.9863424897193908,0.0508019998669624,0.2482839235230926,43793.0,0.9853832721710204,0.0543567053973674,0.2394302539002576,43793.0,11778.403351545334,17829.582918167114,11778.403351545334,6048.290234088898,1.875577449798584,0.0 -36600,0.070645735,0.030194117,,,,,,,,,,,,,,,,, -36700,0.06451471,0.029931176,,,,,,,,,,,,,,,,, -36800,0.06239197,0.027610304,,,,,,,,,,,,,,,,, -36900,0.056777574,0.024682283,,,,,,,,,,,,,,,,, -37000,0.07404023,0.028629813,,,,,,,,,,,,,,,,, -37100,0.07575821,0.030533386,,,,,,,,,,,,,,,,, -37200,0.06984692,0.029424556,,,,,,,,,,,,,,,,, -37297,,,0.9950770735740662,0.0165905263274908,0.7307588502918585,0.9863532185554504,0.0512042194604873,0.2467354335083033,43793.0,0.9853453636169434,0.0549082830548286,0.2335472181764784,43793.0,12018.428232431412,18186.92015695572,12018.428232431412,6165.5471975803375,1.9106330871582031,0.0 -37300,0.06511912,0.028977064,,,,,,,,,,,,,,,,, -37400,0.057088178,0.02836713,,,,,,,,,,,,,,,,, -37500,0.06480012,0.028985703,,,,,,,,,,,,,,,,, -37600,0.07662914,0.029078452,,,,,,,,,,,,,,,,, -37700,0.0798481,0.030736562,,,,,,,,,,,,,,,,, -37800,0.07020856,0.026583575,,,,,,,,,,,,,,,,, -37900,0.06615881,0.029158257,,,,,,,,,,,,,,,,, -38000,0.073324375,0.02989239,,,,,,,,,,,,,,,,, -38046,,,0.9940646886825562,0.0189915187656879,0.6748852044013662,0.986367642879486,0.0518397986888885,0.2426539176487821,43793.0,0.9854186177253724,0.0556958429515361,0.234332385061237,43793.0,12258.432970285416,18546.359894514084,12258.432970285416,6284.926635742188,1.9459753036499023,0.0 -38100,0.06853443,0.028259613,,,,,,,,,,,,,,,,, -38200,0.071170926,0.029879635,,,,,,,,,,,,,,,,, -38300,0.07352384,0.027854731,,,,,,,,,,,,,,,,, -38400,0.06995002,0.029207155,,,,,,,,,,,,,,,,, -38500,0.09009637,0.028259097,,,,,,,,,,,,,,,,, -38600,0.064888015,0.02666168,,,,,,,,,,,,,,,,, -38700,0.073128805,0.028573751,,,,,,,,,,,,,,,,, -38788,,,0.9945722818374634,0.0179280396550893,0.6901741289062857,0.9861927032470704,0.0516926646232605,0.23949731854974,43793.0,0.9852977395057678,0.0552296452224254,0.2311972221425861,43793.0,12498.420560121536,18902.31952357292,12498.420560121536,6400.845049619675,1.979938507080078,0.0 -38800,0.06882678,0.027930532,,,,,,,,,,,,,,,,, -38900,0.06207743,0.02674915,,,,,,,,,,,,,,,,, -39000,0.07427007,0.029139165,,,,,,,,,,,,,,,,, -39100,0.07236572,0.027669996,,,,,,,,,,,,,,,,, -39200,0.07169745,0.029181803,,,,,,,,,,,,,,,,, -39300,0.07289152,0.02789359,,,,,,,,,,,,,,,,, -39400,0.075813904,0.028747655,,,,,,,,,,,,,,,,, -39500,0.0714636,0.02828302,,,,,,,,,,,,,,,,, -39532,,,0.9941584467887878,0.0185711476951837,0.6757015592453715,0.9862909317016602,0.0520914942026138,0.2429719702376356,43793.0,0.9853373169898988,0.0559426136314868,0.2273990587653981,43793.0,12738.50132727623,19260.18678236008,12738.50132727623,6518.578117609024,2.013803243637085,0.0 -39600,0.07157894,0.02733974,,,,,,,,,,,,,,,,, -39700,0.07147516,0.028304616,,,,,,,,,,,,,,,,, -39800,0.07634661,0.027836436,,,,,,,,,,,,,,,,, -39900,0.08254826,0.028383695,,,,,,,,,,,,,,,,, -40000,0.06671565,0.026382877,,,,,,,,,,,,,,,,, -40100,0.077786766,0.028080396,,,,,,,,,,,,,,,,, -40200,0.07238153,0.026557108,,,,,,,,,,,,,,,,, -40275,,,0.9939488172531128,0.0189921930432319,0.6637098607290097,0.98630028963089,0.0531427562236785,0.2391469379130368,43793.0,0.9853878617286682,0.0571408942341804,0.2266936024700547,43793.0,12978.7029337883,19614.23141884804,12978.7029337883,6632.365000963211,2.0504045486450195,0.0 -40300,0.083439164,0.02714298,,,,,,,,,,,,,,,,, -40400,0.07060069,0.02693459,,,,,,,,,,,,,,,,, -40500,0.101047166,0.029687176,,,,,,,,,,,,,,,,, -40600,0.075265005,0.025452135,,,,,,,,,,,,,,,,, -40700,0.06845642,0.026425058,,,,,,,,,,,,,,,,, -40800,0.0765221,0.027794283,,,,,,,,,,,,,,,,, -40900,0.072586514,0.026807137,,,,,,,,,,,,,,,,, -41000,0.075419724,0.026418975,,,,,,,,,,,,,,,,, -41017,,,0.993874728679657,0.0191944427788257,0.6714206706921495,0.9862105846405028,0.0532609857618808,0.2385562048796848,43793.0,0.9853529334068298,0.0570392683148384,0.2271843160553856,43793.0,13218.936644792557,19975.578807592392,13218.936644792557,6753.425145626068,2.084524631500244,0.0 -41100,0.07681496,0.028685309,,,,,,,,,,,,,,,,, -41200,0.07619943,0.027776979,,,,,,,,,,,,,,,,, -41300,0.07267502,0.027771585,,,,,,,,,,,,,,,,, -41400,0.08003831,0.027512737,,,,,,,,,,,,,,,,, -41500,0.074131556,0.025373833,,,,,,,,,,,,,,,,, -41600,0.07625931,0.027344285,,,,,,,,,,,,,,,,, -41700,0.08954001,0.028180797,,,,,,,,,,,,,,,,, -41761,,,0.994343101978302,0.0179709792137146,0.6866316394471659,0.9861910939216614,0.053494531661272,0.2414926688414058,43793.0,0.9853084683418274,0.0573048181831836,0.2291360045914827,43793.0,13458.923202037811,20333.43417596817,13458.923202037811,6871.238674879074,2.11984920501709,0.0 -41800,0.08566185,0.02877126,,,,,,,,,,,,,,,,, -41900,0.07321954,0.027301785,,,,,,,,,,,,,,,,, -42000,0.064053915,0.025269814,,,,,,,,,,,,,,,,, -42100,0.07103622,0.026822511,,,,,,,,,,,,,,,,, -42200,0.0771819,0.02532609,,,,,,,,,,,,,,,,, -42300,0.070234686,0.02603671,,,,,,,,,,,,,,,,, -42400,0.089152925,0.027294075,,,,,,,,,,,,,,,,, -42496,,,0.9952325224876404,0.0154723525047302,0.7526896612136231,0.9862304329872132,0.0538027547299861,0.2387060657894706,43793.0,0.985352098941803,0.0576530136168003,0.2288912827762337,43793.0,13699.15291404724,20687.27829146385,13699.15291404724,6984.798540115356,2.154636144638061,0.0 -42500,0.08203366,0.025488395,,,,,,,,,,,,,,,,, -42600,0.08998211,0.027067909,,,,,,,,,,,,,,,,, -42700,0.089517996,0.028001685,,,,,,,,,,,,,,,,, -42800,0.072212145,0.025870705,,,,,,,,,,,,,,,,, -42900,0.11570709,0.026720056,,,,,,,,,,,,,,,,, -43000,0.08124842,0.026944904,,,,,,,,,,,,,,,,, -43100,0.079381615,0.0250918,,,,,,,,,,,,,,,,, -43200,0.08838703,0.025552714,,,,,,,,,,,,,,,,, -43233,,,0.9956064820289612,0.0144054051488637,0.7794564332097744,0.986202836036682,0.0545261949300766,0.2326384161525987,43793.0,0.9852619171142578,0.0583874657750129,0.2256373697097565,43793.0,13939.323972702026,21047.05036020279,13939.323972702026,7104.342164516449,2.190378427505493,0.0 -43300,0.07567977,0.02696321,,,,,,,,,,,,,,,,, -43400,0.075561054,0.025771085,,,,,,,,,,,,,,,,, -43500,0.07419008,0.025899405,,,,,,,,,,,,,,,,, -43600,0.08538326,0.026147772,,,,,,,,,,,,,,,,, -43700,0.08300483,0.027186563,,,,,,,,,,,,,,,,, -43800,0.08255343,0.02810359,,,,,,,,,,,,,,,,, -43900,0.08061524,0.02565177,,,,,,,,,,,,,,,,, -43974,,,0.9957359433174132,0.014393120072782,0.762248438689672,0.986229658126831,0.0544007755815982,0.2338933867718675,43793.0,0.9852930903434752,0.058384072035551,0.22733958924486,43793.0,14179.28227376938,21406.30087685585,14179.28227376938,7223.574427843094,2.22595739364624,0.0 -44000,0.083891064,0.027282104,,,,,,,,,,,,,,,,, -44100,0.076369345,0.026264124,,,,,,,,,,,,,,,,, -44200,0.10291586,0.028691761,,,,,,,,,,,,,,,,, -44300,0.098255,0.028065473,,,,,,,,,,,,,,,,, -44400,0.08422261,0.025735114,,,,,,,,,,,,,,,,, -44500,0.07851273,0.02512544,,,,,,,,,,,,,,,,, -44600,0.08176148,0.026401533,,,,,,,,,,,,,,,,, -44700,0.084113166,0.025485106,,,,,,,,,,,,,,,,, -44713,,,0.9954484105110168,0.0150855518877506,0.7581501313243499,0.9861987829208374,0.054727304726839,0.2367285787323215,43793.0,0.9853048920631408,0.0585244484245777,0.2281588202558792,43793.0,14419.376125335692,21765.531745910645,14419.376125335692,7342.6540105342865,2.2634549140930176,0.0 -44800,0.0838433,0.024216896,,,,,,,,,,,,,,,,, -44900,0.088155,0.025841653,,,,,,,,,,,,,,,,, -45000,0.09044463,0.026170041,,,,,,,,,,,,,,,,, -45100,0.07808115,0.02628996,,,,,,,,,,,,,,,,, -45200,0.08405479,0.026555052,,,,,,,,,,,,,,,,, -45300,0.07756632,0.02443357,,,,,,,,,,,,,,,,, -45400,0.09188515,0.025429612,,,,,,,,,,,,,,,,, -45457,,,0.9950907230377196,0.0157091245055198,0.7374597770421244,0.986100137233734,0.0549469888210296,0.2338895422966686,43793.0,0.9852097034454346,0.0588457509875297,0.2212094006776859,43793.0,14659.376512050629,22120.84386181832,14659.376512050629,7457.90939617157,2.2987446784973145,0.0 -45500,0.079058185,0.025674483,,,,,,,,,,,,,,,,, -45600,0.092738815,0.026641238,,,,,,,,,,,,,,,,, -45700,0.085472114,0.025617482,,,,,,,,,,,,,,,,, -45800,0.09496306,0.027088301,,,,,,,,,,,,,,,,, -45900,0.08836322,0.028171506,,,,,,,,,,,,,,,,, -46000,0.08486378,0.025985101,,,,,,,,,,,,,,,,, -46100,0.09307892,0.024884457,,,,,,,,,,,,,,,,, -46200,,,0.9943767786026,0.0173879079520702,0.7149168519041117,0.9862478971481324,0.0558124035596847,0.2359758331316309,43793.0,0.9853684902191162,0.0596916303038597,0.2266012596243111,43793.0,14899.567381620407,22476.54651904106,14899.567381620407,7573.364289522171,2.335345506668091,0.0 -46200,0.079784624,0.025195235,,,,,,,,,,,,,,,,, -46300,0.0881039,0.02637098,,,,,,,,,,,,,,,,, -46400,0.08057984,0.024688683,,,,,,,,,,,,,,,,, -46500,0.08634275,0.023642546,,,,,,,,,,,,,,,,, -46600,0.11921095,0.026634708,,,,,,,,,,,,,,,,, -46700,0.0867686,0.025708579,,,,,,,,,,,,,,,,, -46800,0.08245244,0.025232071,,,,,,,,,,,,,,,,, -46900,0.08922614,0.02440839,,,,,,,,,,,,,,,,, -46945,,,0.9943929314613342,0.0173799358308315,0.7088835674631668,0.9861464500427246,0.0558986105024814,0.2357999831802581,43793.0,0.9852564930915833,0.0597845762968063,0.2232917765851285,43793.0,15139.756244659424,22832.776628017426,15139.756244659424,7689.344358444214,2.3755228519439697,0.0 -47000,0.077195644,0.025251953,,,,,,,,,,,,,,,,, -47100,0.10037526,0.0260785,,,,,,,,,,,,,,,,, -47200,0.07899322,0.02388762,,,,,,,,,,,,,,,,, -47300,0.09382376,0.025533538,,,,,,,,,,,,,,,,, -47400,0.08382133,0.025721306,,,,,,,,,,,,,,,,, -47500,0.09380503,0.024616059,,,,,,,,,,,,,,,,, -47600,0.09124819,0.024638297,,,,,,,,,,,,,,,,, -47690,,,0.9948909282684326,0.0161126609891653,0.7247512557869674,0.9861387014389038,0.0563165694475173,0.2352581202099787,43793.0,0.9851852655410768,0.0603056997060775,0.2267443674197929,43793.0,15379.908016204834,23192.14047503472,15379.908016204834,7808.494928121567,2.4163498878479004,0.0 -47700,0.11151822,0.025376977,,,,,,,,,,,,,,,,, -47800,0.08113681,0.024658348,,,,,,,,,,,,,,,,, -47900,0.114259206,0.025872929,,,,,,,,,,,,,,,,, -48000,0.08290905,0.024896506,,,,,,,,,,,,,,,,, -48100,0.08658843,0.022565603,,,,,,,,,,,,,,,,, -48200,0.084772386,0.02436738,,,,,,,,,,,,,,,,, -48300,0.094630465,0.024512451,,,,,,,,,,,,,,,,, -48400,0.0942276,0.026529416,,,,,,,,,,,,,,,,, -48437,,,0.99583101272583,0.0137463333085179,0.7851234003757641,0.9861135482788086,0.0566686391830444,0.2343727002930017,43793.0,0.9852004647254944,0.0607633218169212,0.2229074234759547,43793.0,15619.869005203249,23548.04278063774,15619.869005203249,7924.379849433899,2.4527883529663086,0.0 -48500,0.07813427,0.023097659,,,,,,,,,,,,,,,,, -48600,0.084973626,0.02436147,,,,,,,,,,,,,,,,, -48700,0.098976485,0.026259925,,,,,,,,,,,,,,,,, -48800,0.08384845,0.02395494,,,,,,,,,,,,,,,,, -48900,0.08425534,0.023421656,,,,,,,,,,,,,,,,, -49000,0.08937081,0.025785107,,,,,,,,,,,,,,,,, -49100,0.08495659,0.025788322,,,,,,,,,,,,,,,,, -49184,,,0.9950827360153198,0.0155350798740983,0.746528199797595,0.9861204624176024,0.0562176518142223,0.2370487070055823,43793.0,0.9851077795028688,0.0605574212968349,0.2294661277003323,43793.0,15860.097087621689,23904.12767291069,15860.097087621689,8040.179135560989,2.490166187286377,0.0 -49200,0.08371137,0.024102712,,,,,,,,,,,,,,,,, -49300,0.08227834,0.024267262,,,,,,,,,,,,,,,,, -49400,0.09250342,0.02503256,,,,,,,,,,,,,,,,, -49500,0.09701804,0.024211183,,,,,,,,,,,,,,,,, -49600,0.094213,0.024943694,,,,,,,,,,,,,,,,, -49700,0.090348065,0.023737881,,,,,,,,,,,,,,,,, -49800,0.09926595,0.02515847,,,,,,,,,,,,,,,,, -49900,0.07349897,0.023599703,,,,,,,,,,,,,,,,, -49935,,,0.996401846408844,0.0126142036169767,0.8146022634550463,0.9861245155334472,0.0576920099556446,0.2316445354722096,43793.0,0.9851545095443726,0.0619620457291603,0.2204592669232147,43793.0,16100.1772043705,24257.970780849457,16100.1772043705,8153.885211467743,2.5272974967956543,0.0 -50000,0.083486296,0.024005687,,,,,,,,,,,,,,,,, -50100,0.08891364,0.023750663,,,,,,,,,,,,,,,,, -50200,0.09291044,0.024665745,,,,,,,,,,,,,,,,, -50300,0.09907441,0.024692135,,,,,,,,,,,,,,,,, -50400,0.09136658,0.024732249,,,,,,,,,,,,,,,,, -50500,0.08858306,0.02484575,,,,,,,,,,,,,,,,, -50600,0.07894316,0.021968877,,,,,,,,,,,,,,,,, -50680,,,0.9969805479049684,0.0111792050302028,0.8314752558336872,0.9861451983451844,0.0584052912890911,0.2308138767184729,43793.0,0.9852147698402404,0.0626945868134498,0.2215540398824933,43793.0,16340.363671779633,24614.37414741516,16340.363671779633,8270.046475887299,2.562939405441284,0.0 -50700,0.096689895,0.024407705,,,,,,,,,,,,,,,,, -50800,0.09115014,0.025675688,,,,,,,,,,,,,,,,, -50900,0.08599929,0.023033727,,,,,,,,,,,,,,,,, -51000,0.09306364,0.023577344,,,,,,,,,,,,,,,,, -51100,0.08800509,0.023676993,,,,,,,,,,,,,,,,, -51200,0.08948987,0.023791317,,,,,,,,,,,,,,,,, -51300,0.08894912,0.025829386,,,,,,,,,,,,,,,,, -51400,0.09430829,0.023774289,,,,,,,,,,,,,,,,, -51424,,,0.9964102506637572,0.0121396416798233,0.8212676428096144,0.9860566854476928,0.0583216175436973,0.226055681241575,43793.0,0.985121250152588,0.0625859647989273,0.2194823881609811,43793.0,16580.46027135849,24968.01723909378,16580.46027135849,8383.532366275787,2.60349440574646,0.0 -51500,0.08832916,0.025316026,,,,,,,,,,,,,,,,, -51600,0.08906701,0.022934223,,,,,,,,,,,,,,,,, -51700,0.09685608,0.023651933,,,,,,,,,,,,,,,,, -51800,0.11044508,0.025423016,,,,,,,,,,,,,,,,, -51900,0.090576425,0.024660138,,,,,,,,,,,,,,,,, -52000,0.09527387,0.024259957,,,,,,,,,,,,,,,,, -52100,0.08797639,0.025573079,,,,,,,,,,,,,,,,, -52170,,,0.9960031509399414,0.0130105959251523,0.7998794085137176,0.9860563278198242,0.0587018877267837,0.2295651743536794,43793.0,0.98509681224823,0.0628611296415329,0.2198022544309893,43793.0,16820.501957178116,25321.957409858704,16820.501957178116,8497.374136447906,2.6400325298309326,0.0 -52200,0.09418979,0.023554621,,,,,,,,,,,,,,,,, -52300,0.09433476,0.023001246,,,,,,,,,,,,,,,,, -52400,0.08189071,0.023192322,,,,,,,,,,,,,,,,, -52500,0.08728755,0.022963483,,,,,,,,,,,,,,,,, -52600,0.120411746,0.026853587,,,,,,,,,,,,,,,,, -52700,0.088062204,0.02460801,,,,,,,,,,,,,,,,, -52800,0.09294389,0.023972012,,,,,,,,,,,,,,,,, -52900,0.094686456,0.023797384,,,,,,,,,,,,,,,,, -52920,,,0.995352268218994,0.0142620839178562,0.7881900013083352,0.986114740371704,0.0593496747314929,0.225412239521873,43793.0,0.9851848483085632,0.0634528473019599,0.2199008821173299,43793.0,17060.735783100128,25680.60373020172,17060.735783100128,8615.72803235054,2.677996158599853,0.0 -53000,0.08691299,0.022242665,,,,,,,,,,,,,,,,, -53100,0.09627154,0.023082692,,,,,,,,,,,,,,,,, -53200,0.090081856,0.024038685,,,,,,,,,,,,,,,,, -53300,0.103374995,0.023956051,,,,,,,,,,,,,,,,, -53400,0.09261231,0.02141115,,,,,,,,,,,,,,,,, -53500,0.08506359,0.023359854,,,,,,,,,,,,,,,,, -53600,0.08451999,0.023783784,,,,,,,,,,,,,,,,, -53669,,,0.9954304099082948,0.0140579780563712,0.785213941239531,0.9861135482788086,0.0589690953493118,0.2228026236983691,43793.0,0.9851098656654358,0.0631409138441085,0.218001164903323,43793.0,17300.85765480995,26031.861436128616,17300.85765480995,8726.805708408356,2.715599775314331,0.0 -53700,0.12221372,0.02331904,,,,,,,,,,,,,,,,, -53800,0.10064304,0.023890035,,,,,,,,,,,,,,,,, -53900,0.10318845,0.022173936,,,,,,,,,,,,,,,,, -54000,0.09562722,0.02349549,,,,,,,,,,,,,,,,, -54100,0.086707294,0.022071723,,,,,,,,,,,,,,,,, -54200,0.08941801,0.023033468,,,,,,,,,,,,,,,,, -54300,0.08862764,0.023512427,,,,,,,,,,,,,,,,, -54400,0.08098434,0.023923049,,,,,,,,,,,,,,,,, -54412,,,0.995689332485199,0.0135448304936289,0.7969386333438457,0.9860498309135436,0.0595028214156627,0.2283987255755745,43793.0,0.9851149320602416,0.0638305619359016,0.216414058452331,43793.0,17541.0244076252,26385.043430566788,17541.0244076252,8839.761295318604,2.7525038719177246,0.0 -54500,0.08060274,0.02271188,,,,,,,,,,,,,,,,, -54600,0.08928407,0.02332263,,,,,,,,,,,,,,,,, -54700,0.097691305,0.024795247,,,,,,,,,,,,,,,,, -54800,0.085639365,0.023277568,,,,,,,,,,,,,,,,, -54900,0.1014416,0.0224224,,,,,,,,,,,,,,,,, -55000,0.09303515,0.023240175,,,,,,,,,,,,,,,,, -55100,0.10480986,0.02221938,,,,,,,,,,,,,,,,, -55167,,,0.9960854053497314,0.0127092683687806,0.8059383088169321,0.9860132932662964,0.0594084672629833,0.2242151660658549,43793.0,0.9850315451622008,0.0636202916502952,0.2174672035747131,43793.0,17781.22932624817,26739.52824115753,17781.22932624817,8953.982422113419,2.7906315326690674,0.0 -55200,0.097341865,0.023409365,,,,,,,,,,,,,,,,, -55300,0.08732684,0.02155532,,,,,,,,,,,,,,,,, -55400,0.084817596,0.02234047,,,,,,,,,,,,,,,,, -55500,0.078003235,0.022996118,,,,,,,,,,,,,,,,, -55600,0.07629663,0.021077998,,,,,,,,,,,,,,,,, -55700,0.09362211,0.023069357,,,,,,,,,,,,,,,,, -55800,0.09578638,0.022567995,,,,,,,,,,,,,,,,, -55900,0.10009828,0.022649692,,,,,,,,,,,,,,,,, -55905,,,0.9968048930168152,0.0110593615099787,0.8559029305644175,0.9860246181488036,0.0603373870253562,0.2279856110308997,43793.0,0.9850290417671204,0.0648826137185096,0.2126118880927797,43793.0,18021.468203783035,27097.05589222908,18021.468203783035,9071.203595876694,2.8327012062072754,0.0 -56000,0.09277876,0.022915889,,,,,,,,,,,,,,,,, -56100,0.08186527,0.022022754,,,,,,,,,,,,,,,,, -56200,0.10626511,0.022901462,,,,,,,,,,,,,,,,, -56300,0.083175555,0.021303589,,,,,,,,,,,,,,,,, -56400,0.08022757,0.020668296,,,,,,,,,,,,,,,,, -56500,0.09144338,0.023074266,,,,,,,,,,,,,,,,, -56600,0.09950427,0.02130458,,,,,,,,,,,,,,,,, -56650,,,0.997663915157318,0.0095526510849595,0.8787869832304159,0.985992968082428,0.0606252178549766,0.2272033097481381,43793.0,0.9850433468818665,0.0648756250739097,0.2150164291304012,43793.0,18261.61960697174,27448.26845598221,18261.61960697174,9182.202111959456,2.873664617538452,0.0 -56700,0.08090251,0.021171637,,,,,,,,,,,,,,,,, -56800,0.08062764,0.022780126,,,,,,,,,,,,,,,,, -56900,0.08307634,0.022269377,,,,,,,,,,,,,,,,, -57000,0.09859039,0.022519186,,,,,,,,,,,,,,,,, -57100,0.08276982,0.022100877,,,,,,,,,,,,,,,,, -57200,0.087669194,0.022873797,,,,,,,,,,,,,,,,, -57300,0.09160783,0.023367109,,,,,,,,,,,,,,,,, -57322,,,,,,,,,,,,,,18477.167249679565,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 9e6edc4ef..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -114.24537301063538,0.0,11.491451501846312,1,0,11.491451501846312,0.5214916467666626,0.7347381114959717,0.0268228675787328,43793,125.73687386512756,0.5324387550354004,0.7276949882507324,0.0216668436768478,0.5230699777603149,0.7331878542900085,0.0255681898490336,43793 -231.0123994350433,0.0217573642730712,251.67628693580627,745,0,251.67628693580627,0.983142077922821,0.0807076543569564,0.0348263753677855,43793,482.7305812835693,0.9867120385169984,0.0699601098895073,0.0313649239899634,0.9841179251670836,0.0779209211468696,0.0335706418063978,43793 -347.83271861076355,0.0503137111663818,491.8714425563812,1483,0,491.8714425563812,0.9834638833999634,0.062710590660572,0.0691000026942505,43793,839.7973530292511,0.9868988394737244,0.0504347085952758,0.0713689412274672,0.9844674468040466,0.0596305727958679,0.0638470882588525,43793 -465.06700444221497,0.079200267791748,732.0550570487976,2227,0,732.0550570487976,0.9839284420013428,0.0571091137826442,0.1205723913361927,43793,1197.2648441791534,0.9876678586006165,0.044889573007822,0.1232982563297867,0.9848230481147766,0.0541297234594821,0.1191175490663034,43793 -579.8030261993408,0.1062445640563964,972.2738552093506,2976,0,972.2738552093506,0.984220325946808,0.054486583918333,0.1466533702010206,43793,1552.2670772075653,0.9881260991096495,0.0423507206141948,0.1503894734798555,0.9851656556129456,0.0517598614096641,0.1425854816835852,43793 -701.6240711212158,0.1329081058502197,1212.473397731781,3718,0,1212.473397731781,0.9843053817749025,0.0534604266285896,0.1665679376148152,43793,1914.3342943191528,0.9882087111473083,0.0411868691444397,0.1789513129569063,0.9852659106254578,0.0506561025977134,0.1721535565328543,43793 -816.5106444358826,0.1609640121459961,1452.6450970172882,4461,0,1452.6450970172882,0.9846844673156738,0.051958292722702,0.1845553027411022,43793,2269.441216945648,0.988564133644104,0.0396111719310283,0.2052638865801224,0.9856268167495728,0.0491602905094623,0.1869933395085386,43793 -930.4642643928528,0.1912925243377685,1692.7988908290863,5203,0,1692.7988908290863,0.984832763671875,0.0517894700169563,0.1853352581939542,43793,2623.5991291999817,0.9885818362236024,0.0390691086649894,0.21580086321456,0.9856730699539183,0.0491345338523387,0.1862065133678861,43793 -1039.5315363407135,0.2181088924407959,1932.7561764717104,5946,0,1932.7561764717104,0.9849889874458312,0.0500839799642562,0.1999434708760638,43793,2972.670731782913,0.988966464996338,0.0375383198261261,0.2455382972249783,0.9859256148338318,0.0473123341798782,0.2103226599902252,43793 -1151.2113268375397,0.250399112701416,2172.862287044525,6683,0,2172.862287044525,0.9849334359169006,0.0502394512295722,0.2088082447016146,43793,3324.5116848945618,0.989159107208252,0.0369097590446472,0.2577345522809846,0.9858171939849854,0.0476182177662849,0.2067422098356613,43793 -1267.4153888225555,0.2777352333068847,2413.044668197632,7430,0,2413.044668197632,0.985179364681244,0.0497350096702575,0.2149340764446602,43793,3680.946016788482,0.9891023635864258,0.0369815491139888,0.2569534259312815,0.9860132932662964,0.0470480434596538,0.2162440311839511,43793 -1377.8806405067444,0.3055968284606933,2653.0890650749207,8166,0,2653.0890650749207,0.9851751923561096,0.0500355958938598,0.2150636442012934,43793,4031.506335258484,0.9891697764396667,0.0367159433662891,0.2575354688274187,0.986077845096588,0.0471108704805374,0.2182111850928434,43793 -1489.8143291473389,0.3341653347015381,2893.1614439487457,8914,0,2893.1614439487457,0.9852905869483948,0.0490034371614456,0.2206016502531977,43793,4383.561316490173,0.9893730878829956,0.035877127200365,0.2799335236621727,0.9860900044441224,0.0463456958532333,0.2231633377875584,43793 -1602.1785380840302,0.3626327514648437,3133.301428318024,9658,0,3133.301428318024,0.9855382442474364,0.0484528020024299,0.2348079878439649,43793,4736.115564584732,0.9894848465919496,0.0353065766394138,0.2950740295810377,0.9863181114196776,0.0457585975527763,0.2377954955532402,43793 -1711.423900127411,0.3928844928741455,3373.3994760513306,10406,0,3373.3994760513306,0.9855803847312928,0.0481531098484993,0.2356283022953401,43793,5085.509356498718,0.9897177815437316,0.0344165675342083,0.3170789589270832,0.9863526225090028,0.0455629266798496,0.2372990797672682,43793 -1817.8432455062864,0.4235210418701172,3613.445493936538,11155,0,3613.445493936538,0.985577404499054,0.047991894185543,0.2370999874809867,43793,5432.026793479919,0.989963173866272,0.0334694199264049,0.3281541876391228,0.9864525198936462,0.0453196726739406,0.2426067549501819,43793 -1929.5064034461973,0.4515836238861084,3853.49520945549,11898,0,3853.49520945549,0.985605239868164,0.0476678498089313,0.2458516870009759,43793,5783.7891409397125,0.9901131391525269,0.0330640785396099,0.3472507179399345,0.9863274693489076,0.0450441613793373,0.2529347650672703,43793 -2041.57985496521,0.8727149963378906,4093.3299930095673,12641,0,4093.3299930095673,0.9857774972915648,0.0477763079106807,0.2443757778600505,43793,6136.139136552811,0.9902748465538024,0.0321195796132087,0.3678234983873403,0.9866047501564026,0.0449953004717826,0.251586093170994,43793 -2149.701101541519,0.902353286743164,4333.489721059799,13388,0,4333.489721059799,0.9857479929924012,0.0478766188025474,0.2479855948983023,43793,6484.470532894135,0.990310549736023,0.0318638868629932,0.3633886240567374,0.9865308403968812,0.045009970664978,0.2489073772637887,43793 -2259.2570674419403,0.931546688079834,4573.749706029892,14120,0,4573.749706029892,0.9856890439987184,0.0474424511194229,0.2457914692060729,43793,6834.34108376503,0.9904751777648926,0.0316066108644008,0.3673989488352646,0.9864829182624816,0.0447915568947792,0.2517094654821753,43793 -2367.700381278992,0.9632065296173096,4813.801182031632,14870,0,4813.801182031632,0.9857884645462036,0.0473058149218559,0.2542209605860204,43793,7182.887937784195,0.9906328320503236,0.0312632247805595,0.3707889618155518,0.9866303205490112,0.0445795319974422,0.2557782403995275,43793 -2474.837132692337,0.9940104484558104,5053.904675960541,15627,0,5053.904675960541,0.9857210516929626,0.0475239269435405,0.2492742742666057,43793,7530.179664611816,0.9904218912124634,0.0316820815205574,0.3776189810012174,0.9864882230758668,0.0450279936194419,0.2553386486819468,43793 -2582.3004174232483,1.0230059623718262,5293.972366333008,16374,0,5293.972366333008,0.9857126474380492,0.0474039055407047,0.2427294222040624,43793,7877.760135889053,0.9905163049697876,0.0314247533679008,0.3839333720689543,0.9865962266921996,0.0447120182216167,0.2527464871775334,43793 -2690.791320323944,1.059596300125122,5534.025603532791,17134,0,5534.025603532791,0.9857867360115052,0.0475266650319099,0.2523974319454819,43793,8226.361269712448,0.9904187321662904,0.0311676487326622,0.3788797100037003,0.986632764339447,0.0448093190789222,0.2590131502594755,43793 -2803.158401966095,1.0902180671691897,5774.065830469132,17887,0,5774.065830469132,0.9859261512756348,0.047035839408636,0.2652222081535117,43793,8578.81997179985,0.9908297061920166,0.0302246939390897,0.3976157937726178,0.9866778254508972,0.0445819795131683,0.2644920682135977,43793 -2912.605295658112,1.1236302852630615,6014.162916898727,18634,0,6014.162916898727,0.9858751893043518,0.0473973602056503,0.2532736599645985,43793,8928.417778730392,0.9909319281578064,0.0297775752842426,0.4131658217801549,0.9866737127304076,0.0447421967983245,0.2665251325435218,43793 -3025.486868619919,1.1551201343536377,6254.241195201874,19381,0,6254.241195201874,0.9859354496002196,0.0479034930467605,0.2580514846627701,43793,9281.429696083069,0.9910821914672852,0.0290072355419397,0.4330688829518003,0.9867277145385742,0.0453269369900226,0.2618424861378399,43793 -3141.0740916728973,1.1857497692108154,6494.479324102402,20119,0,6494.479324102402,0.985922396183014,0.0472350865602493,0.2566684458171706,43793,9637.310660123823,0.9911299347877502,0.029043648391962,0.4524596980950766,0.9866428971290588,0.0447989925742149,0.2666393752428824,43793 -3253.587816953659,1.219196319580078,6734.529579162598,20860,0,6734.529579162598,0.9855993390083312,0.0475965365767478,0.2488904178834033,43793,9989.929476737976,0.9911418557167052,0.0289900843054056,0.4361918090920916,0.9865061044692992,0.0448005199432373,0.2578541146981997,43793 -3363.8338465690613,1.250971794128418,6974.786199092865,21614,0,6974.786199092865,0.9857787489891052,0.0473995953798294,0.2512402387290738,43793,10340.48482823372,0.9909935593605042,0.0295106805860996,0.418704369794744,0.9864736199378968,0.0446580611169338,0.2688391855446387,43793 -3469.385509490967,1.2810804843902588,7215.030090808868,22366,0,7215.030090808868,0.9858646988868712,0.0471776910126209,0.2555786770835429,43793,10686.331381559372,0.990916907787323,0.0297023355960845,0.4092995417653871,0.9867342114448548,0.0445825308561325,0.2692697508641262,43793 -3579.345846414566,1.3130922317504885,7455.064225435257,23115,0,7455.064225435257,0.985886573791504,0.0472914464771747,0.2569587154031185,43793,11036.378544092178,0.990923285484314,0.0296106562018394,0.4129811539648565,0.9866039156913756,0.0448347702622413,0.2584141344443701,43793 -3686.243659973145,1.3458125591278076,7695.025829792023,23867,0,7695.025829792023,0.9857938885688782,0.0474799573421478,0.260062976833778,43793,11383.292026758194,0.991037666797638,0.0293105076998472,0.4443964159175245,0.9865726828575134,0.0448810681700706,0.2680395780729493,43793 -3798.9610619544974,1.377070426940918,7935.27251625061,24615,0,7935.27251625061,0.9858722686767578,0.0475184321403503,0.257685135742762,43793,11736.308470726011,0.9911251068115234,0.0287194959819316,0.4349780862460942,0.9867350459098816,0.0446908399462699,0.2686270225783731,43793 -3911.076207399368,1.4118154048919678,8175.348889112472,25352,0,8175.348889112472,0.9858680367469788,0.047340765595436,0.2599173126520697,43793,12088.560628414154,0.9912492036819458,0.0284839458763599,0.4410713011697503,0.9866904020309448,0.0446516759693622,0.2698735193364242,43793 -4024.9607586860657,1.4460327625274658,8415.341947555542,26095,0,8415.341947555542,0.9858587980270386,0.0475784353911876,0.2584336771859823,43793,12442.495477199554,0.9914388656616212,0.0277533438056707,0.4717187337805464,0.9866493940353394,0.0448740981519222,0.2677851521561414,43793 -4134.661108732224,1.4795305728912354,8655.540459394455,26836,0,8655.540459394455,0.9857859015464784,0.0476189441978931,0.2573965177777483,43793,12792.448338985443,0.9916952848434448,0.0269474107772111,0.4850118415830761,0.9866157174110411,0.0448866635560989,0.2657575785424613,43793 -4241.58500623703,1.5119729042053225,8895.518734931946,27589,0,8895.518734931946,0.9858987927436828,0.0480052009224891,0.2532023452419686,43793,13139.403820991516,0.9915966391563416,0.0271198097616434,0.4833182505356849,0.9867532849311828,0.0451139137148857,0.2680907991811648,43793 -4353.890632867813,1.5436317920684814,9135.784817695618,28338,0,9135.784817695618,0.9856258630752563,0.0483823120594024,0.2584529614231575,43793,13492.027368068697,0.9913761615753174,0.0279940068721771,0.4537976301349314,0.986455738544464,0.0457374155521392,0.265332734975108,43793 -4466.701440811157,1.5765349864959717,9375.92338514328,29081,0,9375.92338514328,0.9858810901641846,0.0476830154657363,0.2606048610788429,43793,13845.032121896744,0.9911956191062928,0.0283650364726781,0.4424061318322148,0.986704170703888,0.044825755059719,0.2695830069042759,43793 -4575.459840536118,1.6094539165496826,9616.0367436409,29837,0,9616.0367436409,0.9859964847564696,0.0474441237747669,0.2622114526014112,43793,14193.95768547058,0.9912551641464232,0.0284039881080389,0.4587308396449663,0.9867545366287231,0.0448542796075344,0.2684878121075197,43793 -4683.776865005493,1.6447956562042236,9856.1376850605,30584,0,9856.1376850605,0.985809087753296,0.0475344434380531,0.2657570707766837,43793,14542.43186235428,0.9915154576301576,0.0274634025990962,0.4633989594018279,0.9865645170211792,0.0447933971881866,0.2774304830450718,43793 -4794.186100482941,1.6768467426300049,10096.208181619644,31337,0,10096.208181619644,0.9858158230781556,0.0479998849332332,0.2566377316769796,43793,14892.963721752169,0.9915711283683776,0.0272007342427968,0.478147452430944,0.9866875410079956,0.0451299957931041,0.2771974970735388,43793 -4903.480740070343,1.7101600170135498,10336.211215496063,32094,0,10336.211215496063,0.9858823418617249,0.0483216866850853,0.2524618240322459,43793,15242.315237522123,0.991542398929596,0.0271677859127521,0.4717304585287615,0.9865958094596864,0.0454927459359169,0.2630080554557668,43793 -5009.653836011887,1.7483479976654053,10576.290118932724,32840,0,10576.290118932724,0.9857816696166992,0.048105664551258,0.2627687226551145,43793,15588.626574993134,0.9917683601379396,0.026312205940485,0.5015883602624208,0.9866209626197816,0.0450939089059829,0.2753164946013637,43793 -5118.999151468277,1.7833147048950195,10816.54667019844,33583,0,10816.54667019844,0.9858322143554688,0.0479882806539535,0.2613832404439288,43793,15938.288664340973,0.9920535683631896,0.0255418438464403,0.5182077647609746,0.9865832328796388,0.0453186854720115,0.2710398928255368,43793 -5229.406427145004,1.818897008895874,11056.78190946579,34333,0,11056.78190946579,0.9859931468963624,0.0483602136373519,0.2592067698810614,43793,16288.988464832306,0.992278516292572,0.0248276535421609,0.5308865696339622,0.9868900775909424,0.0454954542219638,0.2792153231160525,43793 -5336.452176809311,1.8567280769348145,11296.9742269516,35076,0,11296.9742269516,0.985697865486145,0.0481703765690326,0.2579584826472572,43793,16636.286905050278,0.9921972751617432,0.02528334595263,0.5149612251079351,0.9865661859512328,0.0451306663453578,0.2722435905529067,43793 -5446.469096899033,1.8912177085876465,11536.981744527817,35816,0,11536.981744527817,0.9858474135398864,0.0484303832054138,0.2547830847013111,43793,16986.371559381485,0.99190354347229,0.0260348320007324,0.5028728350713861,0.9866254329681396,0.0456982627511024,0.2760192571586574,43793 -5558.977123498917,1.9253017902374268,11777.182457208632,36574,0,11777.182457208632,0.9858174920082092,0.048509806394577,0.2587279235831513,43793,17339.134885072708,0.9917653799057008,0.0263422261923551,0.4941073351555926,0.9865604639053344,0.0457744225859642,0.2681013307066064,43793 -5664.32294178009,1.9741060733795168,12017.42915058136,37332,0,12017.42915058136,0.9857711791992188,0.0488524958491325,0.2530526844425635,43793,17684.79681611061,0.9916236400604248,0.0268276743590831,0.4900111688410533,0.98675936460495,0.0456772036850452,0.2650382750814447,43793 -5772.689391851425,2.009575843811035,12257.636492729189,38089,0,12257.636492729189,0.9857606291770936,0.0485462956130504,0.2595443308006202,43793,18033.42644929886,0.9918012619018556,0.0262242015451192,0.5010326044071443,0.9865389466285706,0.0456525199115276,0.2704411948702859,43793 -5879.004307746887,2.043231964111328,12497.743663549423,38846,0,12497.743663549423,0.9858554005622864,0.0490255393087863,0.2538144049564341,43793,18379.9026362896,0.9921699166297911,0.0252218600362539,0.5169255787785384,0.9866672158241272,0.0460537821054458,0.2735889551897641,43793 -5986.441025257111,2.076756715774536,12737.713785409927,39599,0,12737.713785409927,0.9858794212341307,0.0489414632320404,0.2525515987235432,43793,18727.36371779442,0.9920904040336608,0.0250936243683099,0.5144372891395481,0.9867464303970336,0.0459163263440132,0.2768331889155624,43793 -6097.138377904892,2.1142361164093018,12977.72698712349,40351,0,12977.72698712349,0.9858419299125672,0.0490048862993717,0.2534988346590776,43793,19078.132613658905,0.9924927949905396,0.0239283088594675,0.5607641462687354,0.9866968989372252,0.0457773953676223,0.2796187737769758,43793 -6207.048756837845,2.148378849029541,13217.773389101028,41103,0,13217.773389101028,0.9858731031417848,0.048900943249464,0.2547507713891866,43793,19428.14403939247,0.9927111864089966,0.0232475940138101,0.5671089549068034,0.986598253250122,0.0459187179803848,0.2702917025777467,43793 -6312.429904937744,2.183336973190308,13457.892441272736,41852,0,13457.892441272736,0.9858137369155884,0.0493646673858165,0.2584219040667957,43793,19773.699976205826,0.9927569031715392,0.0229802466928958,0.5699415537277851,0.9866871237754822,0.0462121404707431,0.275476008104804,43793 -6420.406673908234,2.2188665866851807,13698.05339050293,42605,0,13698.05339050293,0.985913097858429,0.0496393516659736,0.2544896576267384,43793,20121.89403295517,0.9927787184715272,0.0229641068726778,0.5776440654013806,0.98681378364563,0.0465125553309917,0.2742111553186739,43793 -6527.937375545502,2.253929853439331,13938.207823753355,43359,0,13938.207823753355,0.9858166575431824,0.049183864146471,0.2583898694445723,43793,20469.63530278206,0.9925683736801147,0.0236154980957508,0.553858979056238,0.9865787625312804,0.0461424030363559,0.2788279899894264,43793 -6634.2423985004425,2.289602518081665,14178.196440935137,44113,0,14178.196440935137,0.9856974482536316,0.0496170744299888,0.2565613538093232,43793,20815.98492026329,0.9924434423446656,0.0240654963999986,0.5444905132486455,0.9865491390228271,0.0465872064232826,0.2754028911325936,43793 -6742.28401350975,2.325899839401245,14418.435846090317,44869,0,14418.435846090317,0.9857791662216188,0.0501223616302013,0.2560017213560228,43793,21164.32275795937,0.9923385977745056,0.0242275260388851,0.5417031986400285,0.9865332841873168,0.0470042116940021,0.2736799668428959,43793 -6850.10337138176,2.360506772994995,14658.508082866669,45619,0,14658.508082866669,0.9858579039573668,0.0503933764994144,0.2503948832383014,43793,21512.26938462257,0.9924070835113524,0.0239708330482244,0.5425693027797756,0.9866347908973694,0.0474613681435585,0.2696670163389971,43793 -6953.782834768295,2.395135402679444,14898.713564157486,46379,0,14898.713564157486,0.985714316368103,0.0505896285176277,0.2525996286947263,43793,21856.210003376007,0.9926638007164,0.0232639666646718,0.5599083695297632,0.9865369200706482,0.0473793819546699,0.2749799347087452,43793 -7061.289536476135,2.430541515350342,15138.81972503662,47131,0,15138.81972503662,0.985743761062622,0.0504318475723266,0.2610750357508299,43793,22203.87849545479,0.992831289768219,0.0226766634732484,0.5819867362002796,0.9864833354949952,0.0474307835102081,0.2747054695439795,43793 -7165.529061079025,2.469257354736328,15378.858348846436,47888,0,15378.858348846436,0.9857400059700012,0.0510310530662536,0.2514952056072648,43793,22548.216205596924,0.9928385615348816,0.0222714468836784,0.5891975629574058,0.9865233302116394,0.0481309220194816,0.2687571695286418,43793 -7274.522526025772,2.506998538970948,15619.018918275831,48635,0,15619.018918275831,0.9856961965560912,0.0508407466113567,0.2539394671432382,43793,22897.42957997322,0.993541657924652,0.0203063413500785,0.6341293814817233,0.9865320920944214,0.0479825772345066,0.2730396903356236,43793 -7380.601233720779,2.544231653213501,15859.109230279922,49384,0,15859.109230279922,0.985526442527771,0.0514129288494586,0.2537168429971005,43793,23243.65624427796,0.993671178817749,0.0201415475457906,0.6405901932539897,0.9863559007644652,0.0482384450733661,0.2759097759683734,43793 -7484.907016038895,2.5809249877929688,16099.19183897972,50136,0,16099.19183897972,0.9857547283172609,0.0516500025987625,0.2574407049190096,43793,23588.10204672813,0.9936359524726868,0.0199294872581958,0.634786984643436,0.986668050289154,0.0484478995203971,0.2804146011801025,43793 -7590.119548559189,2.616666078567505,16339.40024280548,50894,0,16339.40024280548,0.9856056571006776,0.0517693571746349,0.2514201600194406,43793,23933.579761743546,0.9934646487236024,0.0205312222242355,0.6163652200340647,0.9863924384117126,0.0486702099442482,0.2739500108171812,43793 -7694.790858745575,2.667487859725952,16579.39709186554,51638,0,16579.39709186554,0.9855926036834716,0.0521041415631771,0.2548566982480128,43793,24278.319911956787,0.9933746457099916,0.0207659751176834,0.6203458093296251,0.986450493335724,0.0490680560469627,0.272940814365671,43793 -7799.301882743835,2.7027688026428223,16819.611674308777,52394,0,16819.611674308777,0.9857804179191588,0.0528693981468677,0.2487877520350712,43793,24623.10201358795,0.993254005908966,0.0207973401993513,0.6166767403336669,0.9865624904632568,0.0497842095792293,0.2697312671483332,43793 -7899.80190038681,2.738955497741699,17059.75382900238,53149,0,17059.75382900238,0.9855176210403442,0.0532625913619995,0.2513808514617977,43793,24963.80125451088,0.993232250213623,0.0211403518915176,0.600300756020649,0.9863465428352356,0.0502323098480701,0.2693128444946038,43793 -8012.406537294388,2.774937868118286,17300.010417699814,53905,0,17300.010417699814,0.98549485206604,0.0535668618977069,0.2530998305906409,43793,25316.71926093101,0.9932581186294556,0.0207524746656417,0.6190629736359106,0.9863465428352356,0.0504545755684375,0.2680434163900203,43793 -8116.179745674133,2.8153786659240723,17540.260680675507,54648,0,17540.260680675507,0.9853617548942566,0.0540123507380485,0.2467456727644569,43793,25660.806088209152,0.9935899972915648,0.0197139848023653,0.6550292985915764,0.9860900044441224,0.051197599619627,0.2668912946641175,43793 -8221.137031793594,2.852015495300293,17780.475964784622,55397,0,17780.475964784622,0.9854013323783876,0.0537953115999698,0.2553403779232295,43793,26006.03764986992,0.9940750002861024,0.0183645077049732,0.6688367949780897,0.9861955642700196,0.0506768971681594,0.272017019263525,43793 -8330.204983472824,2.890192985534668,18020.540464401245,56149,0,18020.540464401245,0.9854514598846436,0.0548352636396884,0.2518330540823449,43793,26355.229289531708,0.994035005569458,0.0182362999767065,0.6705837973689833,0.986273467540741,0.0518716983497142,0.2665896640185701,43793 -8438.41262793541,2.932584762573242,18260.47955775261,56909,0,18260.47955775261,0.9854662418365479,0.0550340972840786,0.2515098676230641,43793,26703.43917274475,0.9949232339859009,0.015860287472605705,0.7254194245765323,0.9862621426582336,0.05209961161017418,0.2709062672047048,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/measurements.csv deleted file mode 100644 index 527527f93..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/measurements.csv +++ /dev/null @@ -1,655 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.5278695,0.7267631,,,,,,,,,,,,,,,,, -1,,,0.5324387550354004,0.7276949882507324,0.0216668436768478,0.5230699777603149,0.7331878542900085,0.0255681898490336,43793.0,0.5214916467666626,0.7347381114959717,0.0268228675787328,43793.0,11.491451501846312,125.73687386512756,11.491451501846312,114.24537301063538,0.0,0.0 -100,0.6399088,0.44354528,,,,,,,,,,,,,,,,, -200,0.3658697,0.33111528,,,,,,,,,,,,,,,,, -300,0.2687598,0.23709032,,,,,,,,,,,,,,,,, -400,0.17727198,0.16360013,,,,,,,,,,,,,,,,, -500,0.10999312,0.118310496,,,,,,,,,,,,,,,,, -600,0.07090292,0.08696679,,,,,,,,,,,,,,,,, -700,0.046328593,0.07461166,,,,,,,,,,,,,,,,, -745,,,0.9867120385169984,0.0699601098895073,0.0313649239899634,0.9841179251670836,0.0779209211468696,0.0335706418063978,43793.0,0.983142077922821,0.0807076543569564,0.0348263753677855,43793.0,251.67628693580627,482.7305812835693,251.67628693580627,231.0123994350433,0.0217573642730712,0.0 -800,0.05602733,0.06633724,,,,,,,,,,,,,,,,, -900,0.17319648,0.06263813,,,,,,,,,,,,,,,,, -1000,0.06948149,0.056953806,,,,,,,,,,,,,,,,, -1100,0.2699116,0.054191608,,,,,,,,,,,,,,,,, -1200,0.09249926,0.049331307,,,,,,,,,,,,,,,,, -1300,0.09585201,0.05534915,,,,,,,,,,,,,,,,, -1400,0.16517383,0.046978876,,,,,,,,,,,,,,,,, -1483,,,0.9868988394737244,0.0504347085952758,0.0713689412274672,0.9844674468040466,0.0596305727958679,0.0638470882588525,43793.0,0.9834638833999634,0.062710590660572,0.0691000026942505,43793.0,491.8714425563812,839.7973530292511,491.8714425563812,347.83271861076355,0.0503137111663818,0.0 -1500,0.19452558,0.055934455,,,,,,,,,,,,,,,,, -1600,0.073522665,0.04800873,,,,,,,,,,,,,,,,, -1700,0.31626037,0.054258745,,,,,,,,,,,,,,,,, -1800,0.06685143,0.0440496,,,,,,,,,,,,,,,,, -1900,0.23380147,0.045515023,,,,,,,,,,,,,,,,, -2000,0.15649495,0.04676682,,,,,,,,,,,,,,,,, -2100,0.10633885,0.047795925,,,,,,,,,,,,,,,,, -2200,0.091105774,0.049360305,,,,,,,,,,,,,,,,, -2227,,,0.9876678586006165,0.044889573007822,0.1232982563297867,0.9848230481147766,0.0541297234594821,0.1191175490663034,43793.0,0.9839284420013428,0.0571091137826442,0.1205723913361927,43793.0,732.0550570487976,1197.2648441791534,732.0550570487976,465.06700444221497,0.079200267791748,0.0 -2300,0.1747647,0.04954589,,,,,,,,,,,,,,,,, -2400,0.12483621,0.041101456,,,,,,,,,,,,,,,,, -2500,0.30736268,0.044404987,,,,,,,,,,,,,,,,, -2600,0.15832478,0.04050728,,,,,,,,,,,,,,,,, -2700,0.1512795,0.041122016,,,,,,,,,,,,,,,,, -2800,0.12644161,0.042601988,,,,,,,,,,,,,,,,, -2900,0.15499957,0.04771293,,,,,,,,,,,,,,,,, -2976,,,0.9881260991096495,0.0423507206141948,0.1503894734798555,0.9851656556129456,0.0517598614096641,0.1425854816835852,43793.0,0.984220325946808,0.054486583918333,0.1466533702010206,43793.0,972.2738552093506,1552.2670772075653,972.2738552093506,579.8030261993408,0.1062445640563964,0.0 -3000,0.09294543,0.043429445,,,,,,,,,,,,,,,,, -3100,0.105646454,0.04024735,,,,,,,,,,,,,,,,, -3200,0.08180879,0.041840203,,,,,,,,,,,,,,,,, -3300,0.10171594,0.04120282,,,,,,,,,,,,,,,,, -3400,0.06250942,0.037140302,,,,,,,,,,,,,,,,, -3500,0.11052368,0.039773036,,,,,,,,,,,,,,,,, -3600,0.16559643,0.039553348,,,,,,,,,,,,,,,,, -3700,0.06188676,0.0416743,,,,,,,,,,,,,,,,, -3718,,,0.9882087111473083,0.0411868691444397,0.1789513129569063,0.9852659106254578,0.0506561025977134,0.1721535565328543,43793.0,0.9843053817749025,0.0534604266285896,0.1665679376148152,43793.0,1212.473397731781,1914.3342943191528,1212.473397731781,701.6240711212158,0.1329081058502197,0.0 -3800,0.052214805,0.036885485,,,,,,,,,,,,,,,,, -3900,0.16767652,0.039863832,,,,,,,,,,,,,,,,, -4000,0.059239306,0.042770647,,,,,,,,,,,,,,,,, -4100,0.07499896,0.039664596,,,,,,,,,,,,,,,,, -4200,0.06009964,0.035850644,,,,,,,,,,,,,,,,, -4300,0.089537255,0.03932177,,,,,,,,,,,,,,,,, -4400,0.06007594,0.0381387,,,,,,,,,,,,,,,,, -4461,,,0.988564133644104,0.0396111719310283,0.2052638865801224,0.9856268167495728,0.0491602905094623,0.1869933395085386,43793.0,0.9846844673156738,0.051958292722702,0.1845553027411022,43793.0,1452.6450970172882,2269.441216945648,1452.6450970172882,816.5106444358826,0.1609640121459961,0.0 -4500,0.04333918,0.03937303,,,,,,,,,,,,,,,,, -4600,0.07674811,0.043301784,,,,,,,,,,,,,,,,, -4700,0.05897932,0.03510336,,,,,,,,,,,,,,,,, -4800,0.07284914,0.039304856,,,,,,,,,,,,,,,,, -4900,0.050106026,0.03650977,,,,,,,,,,,,,,,,, -5000,0.073188044,0.04056985,,,,,,,,,,,,,,,,, -5100,0.05315291,0.03643718,,,,,,,,,,,,,,,,, -5200,0.038279023,0.040338468,,,,,,,,,,,,,,,,, -5203,,,0.9885818362236024,0.0390691086649894,0.21580086321456,0.9856730699539183,0.0491345338523387,0.1862065133678861,43793.0,0.984832763671875,0.0517894700169563,0.1853352581939542,43793.0,1692.7988908290863,2623.5991291999817,1692.7988908290863,930.4642643928528,0.1912925243377685,0.0 -5300,0.056051448,0.037176654,,,,,,,,,,,,,,,,, -5400,0.055502392,0.045257773,,,,,,,,,,,,,,,,, -5500,0.061633915,0.040674888,,,,,,,,,,,,,,,,, -5600,0.06093695,0.04024211,,,,,,,,,,,,,,,,, -5700,0.08226581,0.032580655,,,,,,,,,,,,,,,,, -5800,0.12138537,0.03931053,,,,,,,,,,,,,,,,, -5900,0.045040164,0.03596614,,,,,,,,,,,,,,,,, -5946,,,0.988966464996338,0.0375383198261261,0.2455382972249783,0.9859256148338318,0.0473123341798782,0.2103226599902252,43793.0,0.9849889874458312,0.0500839799642562,0.1999434708760638,43793.0,1932.7561764717104,2972.670731782913,1932.7561764717104,1039.5315363407135,0.2181088924407959,0.0 -6000,0.06568344,0.034443393,,,,,,,,,,,,,,,,, -6100,0.053760994,0.033514354,,,,,,,,,,,,,,,,, -6200,0.04083388,0.037028536,,,,,,,,,,,,,,,,, -6300,0.039345667,0.037387546,,,,,,,,,,,,,,,,, -6400,0.052630357,0.0383397,,,,,,,,,,,,,,,,, -6500,0.053117268,0.037832137,,,,,,,,,,,,,,,,, -6600,0.029980034,0.034543842,,,,,,,,,,,,,,,,, -6683,,,0.989159107208252,0.0369097590446472,0.2577345522809846,0.9858171939849854,0.0476182177662849,0.2067422098356613,43793.0,0.9849334359169006,0.0502394512295722,0.2088082447016146,43793.0,2172.862287044525,3324.5116848945618,2172.862287044525,1151.2113268375397,0.250399112701416,0.0 -6700,0.05269133,0.039785504,,,,,,,,,,,,,,,,, -6800,0.03045607,0.03597079,,,,,,,,,,,,,,,,, -6900,0.055903774,0.036570784,,,,,,,,,,,,,,,,, -7000,0.04313231,0.037141066,,,,,,,,,,,,,,,,, -7100,0.030547978,0.041343562,,,,,,,,,,,,,,,,, -7200,0.03406152,0.035053775,,,,,,,,,,,,,,,,, -7300,0.040500034,0.035942193,,,,,,,,,,,,,,,,, -7400,0.046132054,0.039704468,,,,,,,,,,,,,,,,, -7430,,,0.9891023635864258,0.0369815491139888,0.2569534259312815,0.9860132932662964,0.0470480434596538,0.2162440311839511,43793.0,0.985179364681244,0.0497350096702575,0.2149340764446602,43793.0,2413.044668197632,3680.946016788482,2413.044668197632,1267.4153888225555,0.2777352333068847,0.0 -7500,0.02955998,0.038506143,,,,,,,,,,,,,,,,, -7600,0.04041846,0.038779896,,,,,,,,,,,,,,,,, -7700,0.036382113,0.035421077,,,,,,,,,,,,,,,,, -7800,0.07237396,0.03659832,,,,,,,,,,,,,,,,, -7900,0.03715773,0.035184857,,,,,,,,,,,,,,,,, -8000,0.057031997,0.04003937,,,,,,,,,,,,,,,,, -8100,0.035709698,0.036893085,,,,,,,,,,,,,,,,, -8166,,,0.9891697764396667,0.0367159433662891,0.2575354688274187,0.986077845096588,0.0471108704805374,0.2182111850928434,43793.0,0.9851751923561096,0.0500355958938598,0.2150636442012934,43793.0,2653.0890650749207,4031.506335258484,2653.0890650749207,1377.8806405067444,0.3055968284606933,0.0 -8200,0.047610722,0.0336415,,,,,,,,,,,,,,,,, -8300,0.028400676,0.036006443,,,,,,,,,,,,,,,,, -8400,0.029131362,0.03754954,,,,,,,,,,,,,,,,, -8500,0.029594345,0.034286473,,,,,,,,,,,,,,,,, -8600,0.036767375,0.034853194,,,,,,,,,,,,,,,,, -8700,0.025804846,0.03414328,,,,,,,,,,,,,,,,, -8800,0.038551573,0.03891009,,,,,,,,,,,,,,,,, -8900,0.029206306,0.037343252,,,,,,,,,,,,,,,,, -8914,,,0.9893730878829956,0.035877127200365,0.2799335236621727,0.9860900044441224,0.0463456958532333,0.2231633377875584,43793.0,0.9852905869483948,0.0490034371614456,0.2206016502531977,43793.0,2893.1614439487457,4383.561316490173,2893.1614439487457,1489.8143291473389,0.3341653347015381,0.0 -9000,0.028556868,0.039145615,,,,,,,,,,,,,,,,, -9100,0.024077747,0.035959594,,,,,,,,,,,,,,,,, -9200,0.031839397,0.03647553,,,,,,,,,,,,,,,,, -9300,0.04437393,0.036680646,,,,,,,,,,,,,,,,, -9400,0.031755626,0.03703674,,,,,,,,,,,,,,,,, -9500,0.03094228,0.03388959,,,,,,,,,,,,,,,,, -9600,0.03720296,0.029390953,,,,,,,,,,,,,,,,, -9658,,,0.9894848465919496,0.0353065766394138,0.2950740295810377,0.9863181114196776,0.0457585975527763,0.2377954955532402,43793.0,0.9855382442474364,0.0484528020024299,0.2348079878439649,43793.0,3133.301428318024,4736.115564584732,3133.301428318024,1602.1785380840302,0.3626327514648437,0.0 -9700,0.03608036,0.035206348,,,,,,,,,,,,,,,,, -9800,0.052917026,0.039051082,,,,,,,,,,,,,,,,, -9900,0.053004272,0.038709216,,,,,,,,,,,,,,,,, -10000,0.035431404,0.032122824,,,,,,,,,,,,,,,,, -10100,0.053774316,0.034790862,,,,,,,,,,,,,,,,, -10200,0.040798526,0.034158897,,,,,,,,,,,,,,,,, -10300,0.048331536,0.038345538,,,,,,,,,,,,,,,,, -10400,0.047192615,0.03494905,,,,,,,,,,,,,,,,, -10406,,,0.9897177815437316,0.0344165675342083,0.3170789589270832,0.9863526225090028,0.0455629266798496,0.2372990797672682,43793.0,0.9855803847312928,0.0481531098484993,0.2356283022953401,43793.0,3373.3994760513306,5085.509356498718,3373.3994760513306,1711.423900127411,0.3928844928741455,0.0 -10500,0.05043773,0.039366137,,,,,,,,,,,,,,,,, -10600,0.024215158,0.03274842,,,,,,,,,,,,,,,,, -10700,0.076712,0.03931356,,,,,,,,,,,,,,,,, -10800,0.027461836,0.031447984,,,,,,,,,,,,,,,,, -10900,0.04090743,0.03682715,,,,,,,,,,,,,,,,, -11000,0.029165022,0.036930706,,,,,,,,,,,,,,,,, -11100,0.04432743,0.034480266,,,,,,,,,,,,,,,,, -11155,,,0.989963173866272,0.0334694199264049,0.3281541876391228,0.9864525198936462,0.0453196726739406,0.2426067549501819,43793.0,0.985577404499054,0.047991894185543,0.2370999874809867,43793.0,3613.445493936538,5432.026793479919,3613.445493936538,1817.8432455062864,0.4235210418701172,0.0 -11200,0.046448212,0.034392815,,,,,,,,,,,,,,,,, -11300,0.03227579,0.0331659,,,,,,,,,,,,,,,,, -11400,0.028064452,0.03094309,,,,,,,,,,,,,,,,, -11500,0.030534018,0.033737425,,,,,,,,,,,,,,,,, -11600,0.02993308,0.03397612,,,,,,,,,,,,,,,,, -11700,0.038618226,0.03481916,,,,,,,,,,,,,,,,, -11800,0.036364563,0.03093818,,,,,,,,,,,,,,,,, -11898,,,0.9901131391525269,0.0330640785396099,0.3472507179399345,0.9863274693489076,0.0450441613793373,0.2529347650672703,43793.0,0.985605239868164,0.0476678498089313,0.2458516870009759,43793.0,3853.49520945549,5783.7891409397125,3853.49520945549,1929.5064034461973,0.4515836238861084,0.0 -11900,0.03804837,0.0356071,,,,,,,,,,,,,,,,, -12000,0.034645107,0.03544477,,,,,,,,,,,,,,,,, -12100,0.044712674,0.036491264,,,,,,,,,,,,,,,,, -12200,0.05518841,0.03318243,,,,,,,,,,,,,,,,, -12300,0.03728307,0.029962674,,,,,,,,,,,,,,,,, -12400,0.03431614,0.029139247,,,,,,,,,,,,,,,,, -12500,0.036101345,0.031846993,,,,,,,,,,,,,,,,, -12600,0.04922856,0.036769513,,,,,,,,,,,,,,,,, -12641,,,0.9902748465538024,0.0321195796132087,0.3678234983873403,0.9866047501564026,0.0449953004717826,0.251586093170994,43793.0,0.9857774972915648,0.0477763079106807,0.2443757778600505,43793.0,4093.3299930095673,6136.139136552811,4093.3299930095673,2041.57985496521,0.8727149963378906,0.0 -12700,0.054999303,0.03325511,,,,,,,,,,,,,,,,, -12800,0.053641852,0.03455575,,,,,,,,,,,,,,,,, -12900,0.036547985,0.031177282,,,,,,,,,,,,,,,,, -13000,0.050668094,0.031563386,,,,,,,,,,,,,,,,, -13100,0.054475233,0.03730816,,,,,,,,,,,,,,,,, -13200,0.06261941,0.031271596,,,,,,,,,,,,,,,,, -13300,0.04697992,0.030261796,,,,,,,,,,,,,,,,, -13388,,,0.990310549736023,0.0318638868629932,0.3633886240567374,0.9865308403968812,0.045009970664978,0.2489073772637887,43793.0,0.9857479929924012,0.0478766188025474,0.2479855948983023,43793.0,4333.489721059799,6484.470532894135,4333.489721059799,2149.701101541519,0.902353286743164,0.0 -13400,0.06489416,0.037711434,,,,,,,,,,,,,,,,, -13500,0.05587024,0.027568366,,,,,,,,,,,,,,,,, -13600,0.05413153,0.030699199,,,,,,,,,,,,,,,,, -13700,0.05996189,0.03416471,,,,,,,,,,,,,,,,, -13800,0.055256803,0.035706908,,,,,,,,,,,,,,,,, -13900,0.0725704,0.031617045,,,,,,,,,,,,,,,,, -14000,0.04676682,0.03173811,,,,,,,,,,,,,,,,, -14100,0.054221522,0.032525364,,,,,,,,,,,,,,,,, -14120,,,0.9904751777648926,0.0316066108644008,0.3673989488352646,0.9864829182624816,0.0447915568947792,0.2517094654821753,43793.0,0.9856890439987184,0.0474424511194229,0.2457914692060729,43793.0,4573.749706029892,6834.34108376503,4573.749706029892,2259.2570674419403,0.931546688079834,0.0 -14200,0.101678185,0.03033142,,,,,,,,,,,,,,,,, -14300,0.05453836,0.030468224,,,,,,,,,,,,,,,,, -14400,0.05785368,0.03329445,,,,,,,,,,,,,,,,, -14500,0.06251432,0.028572543,,,,,,,,,,,,,,,,, -14600,0.05077401,0.03128728,,,,,,,,,,,,,,,,, -14700,0.07028906,0.031125473,,,,,,,,,,,,,,,,, -14800,0.08673399,0.033374507,,,,,,,,,,,,,,,,, -14870,,,0.9906328320503236,0.0312632247805595,0.3707889618155518,0.9866303205490112,0.0445795319974422,0.2557782403995275,43793.0,0.9857884645462036,0.0473058149218559,0.2542209605860204,43793.0,4813.801182031632,7182.887937784195,4813.801182031632,2367.700381278992,0.9632065296173096,0.0 -14900,0.06517581,0.035177756,,,,,,,,,,,,,,,,, -15000,0.04491948,0.032390874,,,,,,,,,,,,,,,,, -15100,0.077540584,0.032501306,,,,,,,,,,,,,,,,, -15200,0.070442036,0.03169334,,,,,,,,,,,,,,,,, -15300,0.061391793,0.034128558,,,,,,,,,,,,,,,,, -15400,0.051547524,0.034939036,,,,,,,,,,,,,,,,, -15500,0.05567931,0.032247417,,,,,,,,,,,,,,,,, -15600,0.10691823,0.036680415,,,,,,,,,,,,,,,,, -15627,,,0.9904218912124634,0.0316820815205574,0.3776189810012174,0.9864882230758668,0.0450279936194419,0.2553386486819468,43793.0,0.9857210516929626,0.0475239269435405,0.2492742742666057,43793.0,5053.904675960541,7530.179664611816,5053.904675960541,2474.837132692337,0.9940104484558104,0.0 -15700,0.06392847,0.031737227,,,,,,,,,,,,,,,,, -15800,0.056728322,0.02952526,,,,,,,,,,,,,,,,, -15900,0.08219942,0.031636123,,,,,,,,,,,,,,,,, -16000,0.050743684,0.03188365,,,,,,,,,,,,,,,,, -16100,0.058631815,0.03077364,,,,,,,,,,,,,,,,, -16200,0.05134626,0.031351328,,,,,,,,,,,,,,,,, -16300,0.05066355,0.027838362,,,,,,,,,,,,,,,,, -16374,,,0.9905163049697876,0.0314247533679008,0.3839333720689543,0.9865962266921996,0.0447120182216167,0.2527464871775334,43793.0,0.9857126474380492,0.0474039055407047,0.2427294222040624,43793.0,5293.972366333008,7877.760135889053,5293.972366333008,2582.3004174232483,1.0230059623718262,0.0 -16400,0.051527154,0.031507228,,,,,,,,,,,,,,,,, -16500,0.07883921,0.029613445,,,,,,,,,,,,,,,,, -16600,0.05337854,0.030957669,,,,,,,,,,,,,,,,, -16700,0.060176272,0.031946324,,,,,,,,,,,,,,,,, -16800,0.07503223,0.03203539,,,,,,,,,,,,,,,,, -16900,0.07326956,0.033040233,,,,,,,,,,,,,,,,, -17000,0.074842125,0.032688316,,,,,,,,,,,,,,,,, -17100,0.08606316,0.032627292,,,,,,,,,,,,,,,,, -17134,,,0.9904187321662904,0.0311676487326622,0.3788797100037003,0.986632764339447,0.0448093190789222,0.2590131502594755,43793.0,0.9857867360115052,0.0475266650319099,0.2523974319454819,43793.0,5534.025603532791,8226.361269712448,5534.025603532791,2690.791320323944,1.059596300125122,0.0 -17200,0.073757164,0.032596122,,,,,,,,,,,,,,,,, -17300,0.08155859,0.028240249,,,,,,,,,,,,,,,,, -17400,0.086697444,0.032590453,,,,,,,,,,,,,,,,, -17500,0.06061636,0.033941187,,,,,,,,,,,,,,,,, -17600,0.062171683,0.029967943,,,,,,,,,,,,,,,,, -17700,0.058221117,0.030337498,,,,,,,,,,,,,,,,, -17800,0.09771662,0.034787618,,,,,,,,,,,,,,,,, -17887,,,0.9908297061920166,0.0302246939390897,0.3976157937726178,0.9866778254508972,0.0445819795131683,0.2644920682135977,43793.0,0.9859261512756348,0.047035839408636,0.2652222081535117,43793.0,5774.065830469132,8578.81997179985,5774.065830469132,2803.158401966095,1.0902180671691897,0.0 -17900,0.07643674,0.032565847,,,,,,,,,,,,,,,,, -18000,0.1053109,0.032545015,,,,,,,,,,,,,,,,, -18100,0.080313206,0.03333903,,,,,,,,,,,,,,,,, -18200,0.09026878,0.03153564,,,,,,,,,,,,,,,,, -18300,0.14520061,0.031793192,,,,,,,,,,,,,,,,, -18400,0.10723079,0.032273572,,,,,,,,,,,,,,,,, -18500,0.07301578,0.029604001,,,,,,,,,,,,,,,,, -18600,0.07087953,0.031484768,,,,,,,,,,,,,,,,, -18634,,,0.9909319281578064,0.0297775752842426,0.4131658217801549,0.9866737127304076,0.0447421967983245,0.2665251325435218,43793.0,0.9858751893043518,0.0473973602056503,0.2532736599645985,43793.0,6014.162916898727,8928.417778730392,6014.162916898727,2912.605295658112,1.1236302852630615,0.0 -18700,0.08820364,0.027121488,,,,,,,,,,,,,,,,, -18800,0.06509674,0.029041959,,,,,,,,,,,,,,,,, -18900,0.07225748,0.031575684,,,,,,,,,,,,,,,,, -19000,0.08056239,0.032244813,,,,,,,,,,,,,,,,, -19100,0.08895111,0.03102375,,,,,,,,,,,,,,,,, -19200,0.09006452,0.02754474,,,,,,,,,,,,,,,,, -19300,0.08471254,0.030305602,,,,,,,,,,,,,,,,, -19381,,,0.9910821914672852,0.0290072355419397,0.4330688829518003,0.9867277145385742,0.0453269369900226,0.2618424861378399,43793.0,0.9859354496002196,0.0479034930467605,0.2580514846627701,43793.0,6254.241195201874,9281.429696083069,6254.241195201874,3025.486868619919,1.1551201343536377,0.0 -19400,0.07004339,0.028988484,,,,,,,,,,,,,,,,, -19500,0.08785358,0.028025454,,,,,,,,,,,,,,,,, -19600,0.09024173,0.032552846,,,,,,,,,,,,,,,,, -19700,0.13254689,0.0322698,,,,,,,,,,,,,,,,, -19800,0.102728315,0.03476309,,,,,,,,,,,,,,,,, -19900,0.075020775,0.027908945,,,,,,,,,,,,,,,,, -20000,0.08910119,0.029286038,,,,,,,,,,,,,,,,, -20100,0.10642895,0.027782382,,,,,,,,,,,,,,,,, -20119,,,0.9911299347877502,0.029043648391962,0.4524596980950766,0.9866428971290588,0.0447989925742149,0.2666393752428824,43793.0,0.985922396183014,0.0472350865602493,0.2566684458171706,43793.0,6494.479324102402,9637.310660123823,6494.479324102402,3141.0740916728973,1.1857497692108154,0.0 -20200,0.074126296,0.028852118,,,,,,,,,,,,,,,,, -20300,0.08756326,0.02911224,,,,,,,,,,,,,,,,, -20400,0.0678409,0.029441524,,,,,,,,,,,,,,,,, -20500,0.08882496,0.032805502,,,,,,,,,,,,,,,,, -20600,0.08099152,0.025380787,,,,,,,,,,,,,,,,, -20700,0.14069238,0.031916868,,,,,,,,,,,,,,,,, -20800,0.07484192,0.030323917,,,,,,,,,,,,,,,,, -20860,,,0.9911418557167052,0.0289900843054056,0.4361918090920916,0.9865061044692992,0.0448005199432373,0.2578541146981997,43793.0,0.9855993390083312,0.0475965365767478,0.2488904178834033,43793.0,6734.529579162598,9989.929476737976,6734.529579162598,3253.587816953659,1.219196319580078,0.0 -20900,0.068573296,0.028239736,,,,,,,,,,,,,,,,, -21000,0.097618096,0.031830702,,,,,,,,,,,,,,,,, -21100,0.08226869,0.027805554,,,,,,,,,,,,,,,,, -21200,0.08561142,0.02876054,,,,,,,,,,,,,,,,, -21300,0.12812993,0.030846206,,,,,,,,,,,,,,,,, -21400,0.09973603,0.031479612,,,,,,,,,,,,,,,,, -21500,0.1242519,0.028916508,,,,,,,,,,,,,,,,, -21600,0.083246194,0.03318468,,,,,,,,,,,,,,,,, -21614,,,0.9909935593605042,0.0295106805860996,0.418704369794744,0.9864736199378968,0.0446580611169338,0.2688391855446387,43793.0,0.9857787489891052,0.0473995953798294,0.2512402387290738,43793.0,6974.786199092865,10340.48482823372,6974.786199092865,3363.8338465690613,1.250971794128418,0.0 -21700,0.072472334,0.028163994,,,,,,,,,,,,,,,,, -21800,0.09243593,0.030512536,,,,,,,,,,,,,,,,, -21900,0.088393725,0.03055178,,,,,,,,,,,,,,,,, -22000,0.08971962,0.028433003,,,,,,,,,,,,,,,,, -22100,0.07293444,0.025919404,,,,,,,,,,,,,,,,, -22200,0.09935735,0.032914996,,,,,,,,,,,,,,,,, -22300,0.087803304,0.031749323,,,,,,,,,,,,,,,,, -22366,,,0.990916907787323,0.0297023355960845,0.4092995417653871,0.9867342114448548,0.0445825308561325,0.2692697508641262,43793.0,0.9858646988868712,0.0471776910126209,0.2555786770835429,43793.0,7215.030090808868,10686.331381559372,7215.030090808868,3469.385509490967,1.2810804843902588,0.0 -22400,0.15426175,0.03076668,,,,,,,,,,,,,,,,, -22500,0.111679226,0.028967278,,,,,,,,,,,,,,,,, -22600,0.10458249,0.033081796,,,,,,,,,,,,,,,,, -22700,0.08423284,0.030019138,,,,,,,,,,,,,,,,, -22800,0.09428255,0.03293792,,,,,,,,,,,,,,,,, -22900,0.10166023,0.028007852,,,,,,,,,,,,,,,,, -23000,0.10670856,0.032101333,,,,,,,,,,,,,,,,, -23100,0.13074473,0.030073432,,,,,,,,,,,,,,,,, -23115,,,0.990923285484314,0.0296106562018394,0.4129811539648565,0.9866039156913756,0.0448347702622413,0.2584141344443701,43793.0,0.985886573791504,0.0472914464771747,0.2569587154031185,43793.0,7455.064225435257,11036.378544092178,7455.064225435257,3579.345846414566,1.3130922317504885,0.0 -23200,0.07613528,0.027602645,,,,,,,,,,,,,,,,, -23300,0.086889185,0.029675383,,,,,,,,,,,,,,,,, -23400,0.08363404,0.031591445,,,,,,,,,,,,,,,,, -23500,0.087628186,0.027964424,,,,,,,,,,,,,,,,, -23600,0.08492426,0.02799781,,,,,,,,,,,,,,,,, -23700,0.124805674,0.031155461,,,,,,,,,,,,,,,,, -23800,0.09390278,0.0285746,,,,,,,,,,,,,,,,, -23867,,,0.991037666797638,0.0293105076998472,0.4443964159175245,0.9865726828575134,0.0448810681700706,0.2680395780729493,43793.0,0.9857938885688782,0.0474799573421478,0.260062976833778,43793.0,7695.025829792023,11383.292026758194,7695.025829792023,3686.243659973145,1.3458125591278076,0.0 -23900,0.15154442,0.029531417,,,,,,,,,,,,,,,,, -24000,0.090649426,0.031187188,,,,,,,,,,,,,,,,, -24100,0.080130406,0.026736652,,,,,,,,,,,,,,,,, -24200,0.08564589,0.031934105,,,,,,,,,,,,,,,,, -24300,0.08474239,0.030513829,,,,,,,,,,,,,,,,, -24400,0.10995846,0.027479568,,,,,,,,,,,,,,,,, -24500,0.094644904,0.029903403,,,,,,,,,,,,,,,,, -24600,0.082709804,0.027932648,,,,,,,,,,,,,,,,, -24615,,,0.9911251068115234,0.0287194959819316,0.4349780862460942,0.9867350459098816,0.0446908399462699,0.2686270225783731,43793.0,0.9858722686767578,0.0475184321403503,0.257685135742762,43793.0,7935.27251625061,11736.308470726011,7935.27251625061,3798.9610619544974,1.377070426940918,0.0 -24700,0.1130121,0.032626264,,,,,,,,,,,,,,,,, -24800,0.077703066,0.026972018,,,,,,,,,,,,,,,,, -24900,0.07416494,0.025782498,,,,,,,,,,,,,,,,, -25000,0.09142168,0.02707573,,,,,,,,,,,,,,,,, -25100,0.11401647,0.031345826,,,,,,,,,,,,,,,,, -25200,0.08677806,0.028098376,,,,,,,,,,,,,,,,, -25300,0.078576736,0.031000974,,,,,,,,,,,,,,,,, -25352,,,0.9912492036819458,0.0284839458763599,0.4410713011697503,0.9866904020309448,0.0446516759693622,0.2698735193364242,43793.0,0.9858680367469788,0.047340765595436,0.2599173126520697,43793.0,8175.348889112472,12088.560628414154,8175.348889112472,3911.076207399368,1.4118154048919678,0.0 -25400,0.097464226,0.026663704,,,,,,,,,,,,,,,,, -25500,0.082130104,0.030014278,,,,,,,,,,,,,,,,, -25600,0.0917234,0.030532831,,,,,,,,,,,,,,,,, -25700,0.0752827,0.028889664,,,,,,,,,,,,,,,,, -25800,0.14778237,0.030148981,,,,,,,,,,,,,,,,, -25900,0.09823194,0.03007707,,,,,,,,,,,,,,,,, -26000,0.09295443,0.027866045,,,,,,,,,,,,,,,,, -26095,,,0.9914388656616212,0.0277533438056707,0.4717187337805464,0.9866493940353394,0.0448740981519222,0.2677851521561414,43793.0,0.9858587980270386,0.0475784353911876,0.2584336771859823,43793.0,8415.341947555542,12442.495477199554,8415.341947555542,4024.9607586860657,1.4460327625274658,0.0 -26100,0.103684805,0.027231842,,,,,,,,,,,,,,,,, -26200,0.103071846,0.0303926,,,,,,,,,,,,,,,,, -26300,0.10011675,0.030235823,,,,,,,,,,,,,,,,, -26400,0.08513246,0.028692504,,,,,,,,,,,,,,,,, -26500,0.09052822,0.029024469,,,,,,,,,,,,,,,,, -26600,0.08954744,0.030057296,,,,,,,,,,,,,,,,, -26700,0.095063776,0.027579159,,,,,,,,,,,,,,,,, -26800,0.09802647,0.03113683,,,,,,,,,,,,,,,,, -26836,,,0.9916952848434448,0.0269474107772111,0.4850118415830761,0.9866157174110411,0.0448866635560989,0.2657575785424613,43793.0,0.9857859015464784,0.0476189441978931,0.2573965177777483,43793.0,8655.540459394455,12792.448338985443,8655.540459394455,4134.661108732224,1.4795305728912354,0.0 -26900,0.076522805,0.025825579,,,,,,,,,,,,,,,,, -27000,0.09878413,0.032582942,,,,,,,,,,,,,,,,, -27100,0.1413025,0.029228285,,,,,,,,,,,,,,,,, -27200,0.087975204,0.027596286,,,,,,,,,,,,,,,,, -27300,0.08887286,0.031384416,,,,,,,,,,,,,,,,, -27400,0.08937887,0.030002482,,,,,,,,,,,,,,,,, -27500,0.10722391,0.026657294,,,,,,,,,,,,,,,,, -27589,,,0.9915966391563416,0.0271198097616434,0.4833182505356849,0.9867532849311828,0.0451139137148857,0.2680907991811648,43793.0,0.9858987927436828,0.0480052009224891,0.2532023452419686,43793.0,8895.518734931946,13139.403820991516,8895.518734931946,4241.58500623703,1.5119729042053225,0.0 -27600,0.09310283,0.028164096,,,,,,,,,,,,,,,,, -27700,0.122561485,0.031411156,,,,,,,,,,,,,,,,, -27800,0.09798926,0.03062731,,,,,,,,,,,,,,,,, -27900,0.09101001,0.025919275,,,,,,,,,,,,,,,,, -28000,0.09721499,0.02883046,,,,,,,,,,,,,,,,, -28100,0.0951882,0.031911924,,,,,,,,,,,,,,,,, -28200,0.12164105,0.028761428,,,,,,,,,,,,,,,,, -28300,0.12841251,0.025030315,,,,,,,,,,,,,,,,, -28338,,,0.9913761615753174,0.0279940068721771,0.4537976301349314,0.986455738544464,0.0457374155521392,0.265332734975108,43793.0,0.9856258630752563,0.0483823120594024,0.2584529614231575,43793.0,9135.784817695618,13492.027368068697,9135.784817695618,4353.890632867813,1.5436317920684814,0.0 -28400,0.11962761,0.029253582,,,,,,,,,,,,,,,,, -28500,0.11315147,0.029791074,,,,,,,,,,,,,,,,, -28600,0.1142007,0.033903558,,,,,,,,,,,,,,,,, -28700,0.094934285,0.027868003,,,,,,,,,,,,,,,,, -28800,0.09065869,0.026794773,,,,,,,,,,,,,,,,, -28900,0.09576758,0.03212448,,,,,,,,,,,,,,,,, -29000,0.10638427,0.029961867,,,,,,,,,,,,,,,,, -29081,,,0.9911956191062928,0.0283650364726781,0.4424061318322148,0.986704170703888,0.044825755059719,0.2695830069042759,43793.0,0.9858810901641846,0.0476830154657363,0.2606048610788429,43793.0,9375.92338514328,13845.032121896744,9375.92338514328,4466.701440811157,1.5765349864959717,0.0 -29100,0.10691515,0.027775306,,,,,,,,,,,,,,,,, -29200,0.11133298,0.02953814,,,,,,,,,,,,,,,,, -29300,0.104310416,0.030706933,,,,,,,,,,,,,,,,, -29400,0.12346168,0.03122339,,,,,,,,,,,,,,,,, -29500,0.10837711,0.029951138,,,,,,,,,,,,,,,,, -29600,0.080910906,0.027723784,,,,,,,,,,,,,,,,, -29700,0.08535699,0.027025929,,,,,,,,,,,,,,,,, -29800,0.099511996,0.028552013,,,,,,,,,,,,,,,,, -29837,,,0.9912551641464232,0.0284039881080389,0.4587308396449663,0.9867545366287231,0.0448542796075344,0.2684878121075197,43793.0,0.9859964847564696,0.0474441237747669,0.2622114526014112,43793.0,9616.0367436409,14193.95768547058,9616.0367436409,4575.459840536118,1.6094539165496826,0.0 -29900,0.1155769,0.027014527,,,,,,,,,,,,,,,,, -30000,0.108845934,0.03188464,,,,,,,,,,,,,,,,, -30100,0.11978897,0.02963757,,,,,,,,,,,,,,,,, -30200,0.11020036,0.024670837,,,,,,,,,,,,,,,,, -30300,0.112575404,0.029968085,,,,,,,,,,,,,,,,, -30400,0.11534999,0.026721776,,,,,,,,,,,,,,,,, -30500,0.100708745,0.025189849,,,,,,,,,,,,,,,,, -30584,,,0.9915154576301576,0.0274634025990962,0.4633989594018279,0.9865645170211792,0.0447933971881866,0.2774304830450718,43793.0,0.985809087753296,0.0475344434380531,0.2657570707766837,43793.0,9856.1376850605,14542.43186235428,9856.1376850605,4683.776865005493,1.6447956562042236,0.0 -30600,0.108744174,0.031072676,,,,,,,,,,,,,,,,, -30700,0.08932245,0.026024224,,,,,,,,,,,,,,,,, -30800,0.09346867,0.030988785,,,,,,,,,,,,,,,,, -30900,0.09306573,0.02772866,,,,,,,,,,,,,,,,, -31000,0.11234341,0.027033206,,,,,,,,,,,,,,,,, -31100,0.10940453,0.029999882,,,,,,,,,,,,,,,,, -31200,0.11895496,0.028969133,,,,,,,,,,,,,,,,, -31300,0.091095455,0.027664825,,,,,,,,,,,,,,,,, -31337,,,0.9915711283683776,0.0272007342427968,0.478147452430944,0.9866875410079956,0.0451299957931041,0.2771974970735388,43793.0,0.9858158230781556,0.0479998849332332,0.2566377316769796,43793.0,10096.208181619644,14892.963721752169,10096.208181619644,4794.186100482941,1.6768467426300049,0.0 -31400,0.09897314,0.025677696,,,,,,,,,,,,,,,,, -31500,0.10083547,0.024667948,,,,,,,,,,,,,,,,, -31600,0.12658718,0.028930005,,,,,,,,,,,,,,,,, -31700,0.091895245,0.028302606,,,,,,,,,,,,,,,,, -31800,0.100480564,0.028127454,,,,,,,,,,,,,,,,, -31900,0.091749065,0.026011461,,,,,,,,,,,,,,,,, -32000,0.10511092,0.02999389,,,,,,,,,,,,,,,,, -32094,,,0.991542398929596,0.0271677859127521,0.4717304585287615,0.9865958094596864,0.0454927459359169,0.2630080554557668,43793.0,0.9858823418617249,0.0483216866850853,0.2524618240322459,43793.0,10336.211215496063,15242.315237522123,10336.211215496063,4903.480740070343,1.7101600170135498,0.0 -32100,0.11684662,0.02608954,,,,,,,,,,,,,,,,, -32200,0.096760705,0.028618526,,,,,,,,,,,,,,,,, -32300,0.10718178,0.027127521,,,,,,,,,,,,,,,,, -32400,0.13295983,0.029949162,,,,,,,,,,,,,,,,, -32500,0.08399997,0.027026704,,,,,,,,,,,,,,,,, -32600,0.12385791,0.030232968,,,,,,,,,,,,,,,,, -32700,0.101527676,0.02563331,,,,,,,,,,,,,,,,, -32800,0.11060258,0.02785989,,,,,,,,,,,,,,,,, -32840,,,0.9917683601379396,0.026312205940485,0.5015883602624208,0.9866209626197816,0.0450939089059829,0.2753164946013637,43793.0,0.9857816696166992,0.048105664551258,0.2627687226551145,43793.0,10576.290118932724,15588.626574993134,10576.290118932724,5009.653836011887,1.7483479976654053,0.0 -32900,0.09669823,0.026128247,,,,,,,,,,,,,,,,, -33000,0.118294574,0.025150433,,,,,,,,,,,,,,,,, -33100,0.102316424,0.029310167,,,,,,,,,,,,,,,,, -33200,0.09726081,0.028147534,,,,,,,,,,,,,,,,, -33300,0.11140795,0.026965518,,,,,,,,,,,,,,,,, -33400,0.10742369,0.028867545,,,,,,,,,,,,,,,,, -33500,0.11235143,0.026547693,,,,,,,,,,,,,,,,, -33583,,,0.9920535683631896,0.0255418438464403,0.5182077647609746,0.9865832328796388,0.0453186854720115,0.2710398928255368,43793.0,0.9858322143554688,0.0479882806539535,0.2613832404439288,43793.0,10816.54667019844,15938.288664340973,10816.54667019844,5118.999151468277,1.7833147048950195,0.0 -33600,0.122652605,0.025396276,,,,,,,,,,,,,,,,, -33700,0.11831696,0.02466039,,,,,,,,,,,,,,,,, -33800,0.099135466,0.025895525,,,,,,,,,,,,,,,,, -33900,0.09857582,0.02748654,,,,,,,,,,,,,,,,, -34000,0.11508345,0.028602913,,,,,,,,,,,,,,,,, -34100,0.10798472,0.026181463,,,,,,,,,,,,,,,,, -34200,0.11332141,0.026803372,,,,,,,,,,,,,,,,, -34300,0.109095074,0.024402006,,,,,,,,,,,,,,,,, -34333,,,0.992278516292572,0.0248276535421609,0.5308865696339622,0.9868900775909424,0.0454954542219638,0.2792153231160525,43793.0,0.9859931468963624,0.0483602136373519,0.2592067698810614,43793.0,11056.78190946579,16288.988464832306,11056.78190946579,5229.406427145004,1.818897008895874,0.0 -34400,0.14151396,0.027793635,,,,,,,,,,,,,,,,, -34500,0.098734,0.025376562,,,,,,,,,,,,,,,,, -34600,0.13729721,0.023474066,,,,,,,,,,,,,,,,, -34700,0.096676424,0.02594109,,,,,,,,,,,,,,,,, -34800,0.13122791,0.026940402,,,,,,,,,,,,,,,,, -34900,0.12661077,0.028260516,,,,,,,,,,,,,,,,, -35000,0.09012426,0.027794601,,,,,,,,,,,,,,,,, -35076,,,0.9921972751617432,0.02528334595263,0.5149612251079351,0.9865661859512328,0.0451306663453578,0.2722435905529067,43793.0,0.985697865486145,0.0481703765690326,0.2579584826472572,43793.0,11296.9742269516,16636.286905050278,11296.9742269516,5336.452176809311,1.8567280769348145,0.0 -35100,0.10908851,0.026084838,,,,,,,,,,,,,,,,, -35200,0.11939243,0.029096922,,,,,,,,,,,,,,,,, -35300,0.09952349,0.026209617,,,,,,,,,,,,,,,,, -35400,0.106449366,0.02687203,,,,,,,,,,,,,,,,, -35500,0.13436708,0.030932473,,,,,,,,,,,,,,,,, -35600,0.11600132,0.0263018,,,,,,,,,,,,,,,,, -35700,0.11753988,0.026519774,,,,,,,,,,,,,,,,, -35800,0.11136955,0.027685981,,,,,,,,,,,,,,,,, -35816,,,0.99190354347229,0.0260348320007324,0.5028728350713861,0.9866254329681396,0.0456982627511024,0.2760192571586574,43793.0,0.9858474135398864,0.0484303832054138,0.2547830847013111,43793.0,11536.981744527817,16986.371559381485,11536.981744527817,5446.469096899033,1.8912177085876465,0.0 -35900,0.13617377,0.029145682,,,,,,,,,,,,,,,,, -36000,0.10656775,0.024431884,,,,,,,,,,,,,,,,, -36100,0.1628167,0.026081042,,,,,,,,,,,,,,,,, -36200,0.1461885,0.026566664,,,,,,,,,,,,,,,,, -36300,0.14721368,0.030405415,,,,,,,,,,,,,,,,, -36400,0.12830831,0.026008196,,,,,,,,,,,,,,,,, -36500,0.10604855,0.028562915,,,,,,,,,,,,,,,,, -36574,,,0.9917653799057008,0.0263422261923551,0.4941073351555926,0.9865604639053344,0.0457744225859642,0.2681013307066064,43793.0,0.9858174920082092,0.048509806394577,0.2587279235831513,43793.0,11777.182457208632,17339.134885072708,11777.182457208632,5558.977123498917,1.9253017902374268,0.0 -36600,0.14844581,0.02930546,,,,,,,,,,,,,,,,, -36700,0.14514975,0.029727975,,,,,,,,,,,,,,,,, -36800,0.1147825,0.026719168,,,,,,,,,,,,,,,,, -36900,0.11427975,0.024096286,,,,,,,,,,,,,,,,, -37000,0.11579166,0.026312225,,,,,,,,,,,,,,,,, -37100,0.10882301,0.028357688,,,,,,,,,,,,,,,,, -37200,0.17356999,0.028275179,,,,,,,,,,,,,,,,, -37300,0.109464206,0.02528868,,,,,,,,,,,,,,,,, -37332,,,0.9916236400604248,0.0268276743590831,0.4900111688410533,0.98675936460495,0.0456772036850452,0.2650382750814447,43793.0,0.9857711791992188,0.0488524958491325,0.2530526844425635,43793.0,12017.42915058136,17684.79681611061,12017.42915058136,5664.32294178009,1.9741060733795168,0.0 -37400,0.09372962,0.024570556,,,,,,,,,,,,,,,,, -37500,0.14830853,0.027360732,,,,,,,,,,,,,,,,, -37600,0.119776934,0.02828529,,,,,,,,,,,,,,,,, -37700,0.13460547,0.027382137,,,,,,,,,,,,,,,,, -37800,0.12742278,0.02416513,,,,,,,,,,,,,,,,, -37900,0.13098761,0.0274853,,,,,,,,,,,,,,,,, -38000,0.12737152,0.029160293,,,,,,,,,,,,,,,,, -38089,,,0.9918012619018556,0.0262242015451192,0.5010326044071443,0.9865389466285706,0.0456525199115276,0.2704411948702859,43793.0,0.9857606291770936,0.0485462956130504,0.2595443308006202,43793.0,12257.636492729189,18033.42644929886,12257.636492729189,5772.689391851425,2.009575843811035,0.0 -38100,0.10870887,0.026225945,,,,,,,,,,,,,,,,, -38200,0.12759809,0.028635873,,,,,,,,,,,,,,,,, -38300,0.1288476,0.025745146,,,,,,,,,,,,,,,,, -38400,0.11690457,0.028213797,,,,,,,,,,,,,,,,, -38500,0.11804106,0.024624657,,,,,,,,,,,,,,,,, -38600,0.120487936,0.025527522,,,,,,,,,,,,,,,,, -38700,0.12080381,0.028155055,,,,,,,,,,,,,,,,, -38800,0.12746714,0.026387716,,,,,,,,,,,,,,,,, -38846,,,0.9921699166297911,0.0252218600362539,0.5169255787785384,0.9866672158241272,0.0460537821054458,0.2735889551897641,43793.0,0.9858554005622864,0.0490255393087863,0.2538144049564341,43793.0,12497.743663549423,18379.9026362896,12497.743663549423,5879.004307746887,2.043231964111328,0.0 -38900,0.12228766,0.024395593,,,,,,,,,,,,,,,,, -39000,0.1416662,0.027929954,,,,,,,,,,,,,,,,, -39100,0.16725023,0.028211072,,,,,,,,,,,,,,,,, -39200,0.14560209,0.027261175,,,,,,,,,,,,,,,,, -39300,0.10408326,0.026023392,,,,,,,,,,,,,,,,, -39400,0.1177492,0.026809199,,,,,,,,,,,,,,,,, -39500,0.17111543,0.025803596,,,,,,,,,,,,,,,,, -39599,,,0.9920904040336608,0.0250936243683099,0.5144372891395481,0.9867464303970336,0.0459163263440132,0.2768331889155624,43793.0,0.9858794212341307,0.0489414632320404,0.2525515987235432,43793.0,12737.713785409927,18727.36371779442,12737.713785409927,5986.441025257111,2.076756715774536,0.0 -39600,0.12719755,0.025198922,,,,,,,,,,,,,,,,, -39700,0.12774833,0.026148697,,,,,,,,,,,,,,,,, -39800,0.1489282,0.027179686,,,,,,,,,,,,,,,,, -39900,0.12523843,0.024587436,,,,,,,,,,,,,,,,, -40000,0.12187035,0.02287691,,,,,,,,,,,,,,,,, -40100,0.13262713,0.025373764,,,,,,,,,,,,,,,,, -40200,0.14333132,0.024779936,,,,,,,,,,,,,,,,, -40300,0.13369446,0.023997879,,,,,,,,,,,,,,,,, -40351,,,0.9924927949905396,0.0239283088594675,0.5607641462687354,0.9866968989372252,0.0457773953676223,0.2796187737769758,43793.0,0.9858419299125672,0.0490048862993717,0.2534988346590776,43793.0,12977.72698712349,19078.132613658905,12977.72698712349,6097.138377904892,2.1142361164093018,0.0 -40400,0.11148503,0.024597617,,,,,,,,,,,,,,,,, -40500,0.15597461,0.028940825,,,,,,,,,,,,,,,,, -40600,0.11034186,0.021863695,,,,,,,,,,,,,,,,, -40700,0.13147114,0.02391888,,,,,,,,,,,,,,,,, -40800,0.1577605,0.026743943,,,,,,,,,,,,,,,,, -40900,0.12615962,0.024456808,,,,,,,,,,,,,,,,, -41000,0.11915637,0.024388894,,,,,,,,,,,,,,,,, -41100,0.13675708,0.026040893,,,,,,,,,,,,,,,,, -41103,,,0.9927111864089966,0.0232475940138101,0.5671089549068034,0.986598253250122,0.0459187179803848,0.2702917025777467,43793.0,0.9858731031417848,0.048900943249464,0.2547507713891866,43793.0,13217.773389101028,19428.14403939247,13217.773389101028,6207.048756837845,2.148378849029541,0.0 -41200,0.116572924,0.02737507,,,,,,,,,,,,,,,,, -41300,0.1520012,0.02563812,,,,,,,,,,,,,,,,, -41400,0.16017886,0.027352056,,,,,,,,,,,,,,,,, -41500,0.11387326,0.022500377,,,,,,,,,,,,,,,,, -41600,0.11632449,0.02520771,,,,,,,,,,,,,,,,, -41700,0.13459891,0.026625557,,,,,,,,,,,,,,,,, -41800,0.1661445,0.028847765,,,,,,,,,,,,,,,,, -41852,,,0.9927569031715392,0.0229802466928958,0.5699415537277851,0.9866871237754822,0.0462121404707431,0.275476008104804,43793.0,0.9858137369155884,0.0493646673858165,0.2584219040667957,43793.0,13457.892441272736,19773.699976205826,13457.892441272736,6312.429904937744,2.183336973190308,0.0 -41900,0.12428466,0.025332851,,,,,,,,,,,,,,,,, -42000,0.12526107,0.022375664,,,,,,,,,,,,,,,,, -42100,0.15626796,0.02551014,,,,,,,,,,,,,,,,, -42200,0.1399504,0.022210246,,,,,,,,,,,,,,,,, -42300,0.12067365,0.023362665,,,,,,,,,,,,,,,,, -42400,0.17704578,0.025834875,,,,,,,,,,,,,,,,, -42500,0.15089542,0.023080073,,,,,,,,,,,,,,,,, -42600,0.14693734,0.024102658,,,,,,,,,,,,,,,,, -42605,,,0.9927787184715272,0.0229641068726778,0.5776440654013806,0.98681378364563,0.0465125553309917,0.2742111553186739,43793.0,0.985913097858429,0.0496393516659736,0.2544896576267384,43793.0,13698.05339050293,20121.89403295517,13698.05339050293,6420.406673908234,2.2188665866851807,0.0 -42700,0.11174598,0.025941117,,,,,,,,,,,,,,,,, -42800,0.15540461,0.02315041,,,,,,,,,,,,,,,,, -42900,0.13522051,0.02571662,,,,,,,,,,,,,,,,, -43000,0.11814134,0.023438172,,,,,,,,,,,,,,,,, -43100,0.14220089,0.021658827,,,,,,,,,,,,,,,,, -43200,0.14359759,0.024152918,,,,,,,,,,,,,,,,, -43300,0.159594,0.025080439,,,,,,,,,,,,,,,,, -43359,,,0.9925683736801147,0.0236154980957508,0.553858979056238,0.9865787625312804,0.0461424030363559,0.2788279899894264,43793.0,0.9858166575431824,0.049183864146471,0.2583898694445723,43793.0,13938.207823753355,20469.63530278206,13938.207823753355,6527.937375545502,2.253929853439331,0.0 -43400,0.13743545,0.023393895,,,,,,,,,,,,,,,,, -43500,0.13933372,0.024871917,,,,,,,,,,,,,,,,, -43600,0.12832801,0.021756459,,,,,,,,,,,,,,,,, -43700,0.14357777,0.024668938,,,,,,,,,,,,,,,,, -43800,0.13503031,0.026544195,,,,,,,,,,,,,,,,, -43900,0.13932543,0.021806339,,,,,,,,,,,,,,,,, -44000,0.150949,0.026321406,,,,,,,,,,,,,,,,, -44100,0.11856575,0.02240612,,,,,,,,,,,,,,,,, -44113,,,0.9924434423446656,0.0240654963999986,0.5444905132486455,0.9865491390228271,0.0465872064232826,0.2754028911325936,43793.0,0.9856974482536316,0.0496170744299888,0.2565613538093232,43793.0,14178.196440935137,20815.98492026329,14178.196440935137,6634.2423985004425,2.289602518081665,0.0 -44200,0.17302257,0.02764228,,,,,,,,,,,,,,,,, -44300,0.14335239,0.024307445,,,,,,,,,,,,,,,,, -44400,0.14615212,0.024265133,,,,,,,,,,,,,,,,, -44500,0.12366708,0.021921562,,,,,,,,,,,,,,,,, -44600,0.135936,0.024501985,,,,,,,,,,,,,,,,, -44700,0.12581992,0.020501647,,,,,,,,,,,,,,,,, -44800,0.14771843,0.02021278,,,,,,,,,,,,,,,,, -44869,,,0.9923385977745056,0.0242275260388851,0.5417031986400285,0.9865332841873168,0.0470042116940021,0.2736799668428959,43793.0,0.9857791662216188,0.0501223616302013,0.2560017213560228,43793.0,14418.435846090317,21164.32275795937,14418.435846090317,6742.28401350975,2.325899839401245,0.0 -44900,0.13449603,0.023473037,,,,,,,,,,,,,,,,, -45000,0.16847976,0.024243817,,,,,,,,,,,,,,,,, -45100,0.16776927,0.02346721,,,,,,,,,,,,,,,,, -45200,0.15650196,0.025465868,,,,,,,,,,,,,,,,, -45300,0.12098093,0.02088395,,,,,,,,,,,,,,,,, -45400,0.14702417,0.022906723,,,,,,,,,,,,,,,,, -45500,0.12470158,0.022827681,,,,,,,,,,,,,,,,, -45600,0.14594534,0.025093758,,,,,,,,,,,,,,,,, -45619,,,0.9924070835113524,0.0239708330482244,0.5425693027797756,0.9866347908973694,0.0474613681435585,0.2696670163389971,43793.0,0.9858579039573668,0.0503933764994144,0.2503948832383014,43793.0,14658.508082866669,21512.26938462257,14658.508082866669,6850.10337138176,2.360506772994995,0.0 -45700,0.1298972,0.022082033,,,,,,,,,,,,,,,,, -45800,0.15055317,0.026493883,,,,,,,,,,,,,,,,, -45900,0.18084385,0.02770582,,,,,,,,,,,,,,,,, -46000,0.13764359,0.023704443,,,,,,,,,,,,,,,,, -46100,0.15354252,0.021308424,,,,,,,,,,,,,,,,, -46200,0.14029256,0.022489345,,,,,,,,,,,,,,,,, -46300,0.18033357,0.02485049,,,,,,,,,,,,,,,,, -46379,,,0.9926638007164,0.0232639666646718,0.5599083695297632,0.9865369200706482,0.0473793819546699,0.2749799347087452,43793.0,0.985714316368103,0.0505896285176277,0.2525996286947263,43793.0,14898.713564157486,21856.210003376007,14898.713564157486,6953.782834768295,2.395135402679444,0.0 -46400,0.15478548,0.02054046,,,,,,,,,,,,,,,,, -46500,0.13132755,0.020966036,,,,,,,,,,,,,,,,, -46600,0.15812267,0.025771867,,,,,,,,,,,,,,,,, -46700,0.11458703,0.022806738,,,,,,,,,,,,,,,,, -46800,0.18858711,0.025265554,,,,,,,,,,,,,,,,, -46900,0.12309564,0.02143612,,,,,,,,,,,,,,,,, -47000,0.14032358,0.021774605,,,,,,,,,,,,,,,,, -47100,0.17248185,0.024512887,,,,,,,,,,,,,,,,, -47131,,,0.992831289768219,0.0226766634732484,0.5819867362002796,0.9864833354949952,0.0474307835102081,0.2747054695439795,43793.0,0.985743761062622,0.0504318475723266,0.2610750357508299,43793.0,15138.81972503662,22203.87849545479,15138.81972503662,7061.289536476135,2.430541515350342,0.0 -47200,0.13988903,0.021247387,,,,,,,,,,,,,,,,, -47300,0.17011477,0.023387117,,,,,,,,,,,,,,,,, -47400,0.16714504,0.022277268,,,,,,,,,,,,,,,,, -47500,0.16111708,0.021899348,,,,,,,,,,,,,,,,, -47600,0.1385741,0.022100033,,,,,,,,,,,,,,,,, -47700,0.14648739,0.025420768,,,,,,,,,,,,,,,,, -47800,0.13486822,0.022130223,,,,,,,,,,,,,,,,, -47888,,,0.9928385615348816,0.0222714468836784,0.5891975629574058,0.9865233302116394,0.0481309220194816,0.2687571695286418,43793.0,0.9857400059700012,0.0510310530662536,0.2514952056072648,43793.0,15378.858348846436,22548.216205596924,15378.858348846436,7165.529061079025,2.469257354736328,0.0 -47900,0.1721268,0.023095133,,,,,,,,,,,,,,,,, -48000,0.14246,0.023211548,,,,,,,,,,,,,,,,, -48100,0.12074382,0.01814287,,,,,,,,,,,,,,,,, -48200,0.13611345,0.020685004,,,,,,,,,,,,,,,,, -48300,0.13768288,0.022645393,,,,,,,,,,,,,,,,, -48400,0.14217769,0.024549564,,,,,,,,,,,,,,,,, -48500,0.15047175,0.020357605,,,,,,,,,,,,,,,,, -48600,0.14597362,0.021151561,,,,,,,,,,,,,,,,, -48635,,,0.993541657924652,0.0203063413500785,0.6341293814817233,0.9865320920944214,0.0479825772345066,0.2730396903356236,43793.0,0.9856961965560912,0.0508407466113567,0.2539394671432382,43793.0,15619.018918275831,22897.42957997322,15619.018918275831,7274.522526025772,2.506998538970948,0.0 -48700,0.1542436,0.024132693,,,,,,,,,,,,,,,,, -48800,0.13534638,0.020567888,,,,,,,,,,,,,,,,, -48900,0.13079892,0.019843848,,,,,,,,,,,,,,,,, -49000,0.16857706,0.024243658,,,,,,,,,,,,,,,,, -49100,0.17013003,0.022829704,,,,,,,,,,,,,,,,, -49200,0.18136893,0.02143636,,,,,,,,,,,,,,,,, -49300,0.14652038,0.022106346,,,,,,,,,,,,,,,,, -49384,,,0.993671178817749,0.0201415475457906,0.6405901932539897,0.9863559007644652,0.0482384450733661,0.2759097759683734,43793.0,0.985526442527771,0.0514129288494586,0.2537168429971005,43793.0,15859.109230279922,23243.65624427796,15859.109230279922,7380.601233720779,2.544231653213501,0.0 -49400,0.16600056,0.022365794,,,,,,,,,,,,,,,,, -49500,0.17290542,0.021308059,,,,,,,,,,,,,,,,, -49600,0.15978557,0.020477774,,,,,,,,,,,,,,,,, -49700,0.14458306,0.020432008,,,,,,,,,,,,,,,,, -49800,0.17958295,0.02254305,,,,,,,,,,,,,,,,, -49900,0.1484522,0.021275278,,,,,,,,,,,,,,,,, -50000,0.16335934,0.02035034,,,,,,,,,,,,,,,,, -50100,0.14918587,0.019983927,,,,,,,,,,,,,,,,, -50136,,,0.9936359524726868,0.0199294872581958,0.634786984643436,0.986668050289154,0.0484478995203971,0.2804146011801025,43793.0,0.9857547283172609,0.0516500025987625,0.2574407049190096,43793.0,16099.19183897972,23588.10204672813,16099.19183897972,7484.907016038895,2.5809249877929688,0.0 -50200,0.16545649,0.023145285,,,,,,,,,,,,,,,,, -50300,0.16876605,0.021054573,,,,,,,,,,,,,,,,, -50400,0.18476647,0.022636037,,,,,,,,,,,,,,,,, -50500,0.17035972,0.021242261,,,,,,,,,,,,,,,,, -50600,0.15501279,0.018074434,,,,,,,,,,,,,,,,, -50700,0.1861081,0.021442454,,,,,,,,,,,,,,,,, -50800,0.18194975,0.024085216,,,,,,,,,,,,,,,,, -50894,,,0.9934646487236024,0.0205312222242355,0.6163652200340647,0.9863924384117126,0.0486702099442482,0.2739500108171812,43793.0,0.9856056571006776,0.0517693571746349,0.2514201600194406,43793.0,16339.40024280548,23933.579761743546,16339.40024280548,7590.119548559189,2.616666078567505,0.0 -50900,0.17679319,0.02095212,,,,,,,,,,,,,,,,, -51000,0.15860261,0.020174922,,,,,,,,,,,,,,,,, -51100,0.15619102,0.021300968,,,,,,,,,,,,,,,,, -51200,0.16661356,0.021327855,,,,,,,,,,,,,,,,, -51300,0.18867245,0.022819359,,,,,,,,,,,,,,,,, -51400,0.19064252,0.022057872,,,,,,,,,,,,,,,,, -51500,0.19082035,0.021741105,,,,,,,,,,,,,,,,, -51600,0.18942994,0.018194837,,,,,,,,,,,,,,,,, -51638,,,0.9933746457099916,0.0207659751176834,0.6203458093296251,0.986450493335724,0.0490680560469627,0.272940814365671,43793.0,0.9855926036834716,0.0521041415631771,0.2548566982480128,43793.0,16579.39709186554,24278.319911956787,16579.39709186554,7694.790858745575,2.667487859725952,0.0 -51700,0.22140048,0.021869069,,,,,,,,,,,,,,,,, -51800,0.21681133,0.02350684,,,,,,,,,,,,,,,,, -51900,0.1694115,0.022207564,,,,,,,,,,,,,,,,, -52000,0.19532208,0.021225316,,,,,,,,,,,,,,,,, -52100,0.17430648,0.022288062,,,,,,,,,,,,,,,,, -52200,0.15672572,0.019891413,,,,,,,,,,,,,,,,, -52300,0.1645724,0.019054126,,,,,,,,,,,,,,,,, -52394,,,0.993254005908966,0.0207973401993513,0.6166767403336669,0.9865624904632568,0.0497842095792293,0.2697312671483332,43793.0,0.9857804179191588,0.0528693981468677,0.2487877520350712,43793.0,16819.611674308777,24623.10201358795,16819.611674308777,7799.301882743835,2.7027688026428223,0.0 -52400,0.16623455,0.019759316,,,,,,,,,,,,,,,,, -52500,0.19373655,0.018672086,,,,,,,,,,,,,,,,, -52600,0.2571776,0.023058588,,,,,,,,,,,,,,,,, -52700,0.22054194,0.02073824,,,,,,,,,,,,,,,,, -52800,0.18507162,0.020141017,,,,,,,,,,,,,,,,, -52900,0.18173838,0.020807946,,,,,,,,,,,,,,,,, -53000,0.19173715,0.020275898,,,,,,,,,,,,,,,,, -53100,0.1953088,0.018590469,,,,,,,,,,,,,,,,, -53149,,,0.993232250213623,0.0211403518915176,0.600300756020649,0.9863465428352356,0.0502323098480701,0.2693128444946038,43793.0,0.9855176210403442,0.0532625913619995,0.2513808514617977,43793.0,17059.75382900238,24963.80125451088,17059.75382900238,7899.80190038681,2.738955497741699,0.0 -53200,0.19296326,0.02104001,,,,,,,,,,,,,,,,, -53300,0.22506838,0.01936905,,,,,,,,,,,,,,,,, -53400,0.18285742,0.016280241,,,,,,,,,,,,,,,,, -53500,0.20320581,0.02230469,,,,,,,,,,,,,,,,, -53600,0.17496318,0.020008901,,,,,,,,,,,,,,,,, -53700,0.19630267,0.01985578,,,,,,,,,,,,,,,,, -53800,0.17488542,0.019771531,,,,,,,,,,,,,,,,, -53900,0.16675271,0.019187842,,,,,,,,,,,,,,,,, -53905,,,0.9932581186294556,0.0207524746656417,0.6190629736359106,0.9863465428352356,0.0504545755684375,0.2680434163900203,43793.0,0.98549485206604,0.0535668618977069,0.2530998305906409,43793.0,17300.010417699814,25316.71926093101,17300.010417699814,8012.406537294388,2.774937868118286,0.0 -54000,0.22540805,0.019885715,,,,,,,,,,,,,,,,, -54100,0.19508569,0.018407058,,,,,,,,,,,,,,,,, -54200,0.19857216,0.01919651,,,,,,,,,,,,,,,,, -54300,0.19342946,0.01901343,,,,,,,,,,,,,,,,, -54400,0.17909738,0.020213649,,,,,,,,,,,,,,,,, -54500,0.17419384,0.019251117,,,,,,,,,,,,,,,,, -54600,0.23078012,0.020661198,,,,,,,,,,,,,,,,, -54648,,,0.9935899972915648,0.0197139848023653,0.6550292985915764,0.9860900044441224,0.051197599619627,0.2668912946641175,43793.0,0.9853617548942566,0.0540123507380485,0.2467456727644569,43793.0,17540.260680675507,25660.806088209152,17540.260680675507,8116.179745674133,2.8153786659240723,0.0 -54700,0.19065398,0.020983508,,,,,,,,,,,,,,,,, -54800,0.21004838,0.020767998,,,,,,,,,,,,,,,,, -54900,0.18267713,0.01812841,,,,,,,,,,,,,,,,, -55000,0.17700379,0.01930481,,,,,,,,,,,,,,,,, -55100,0.18568449,0.019127425,,,,,,,,,,,,,,,,, -55200,0.23342726,0.020237189,,,,,,,,,,,,,,,,, -55300,0.19557098,0.01850275,,,,,,,,,,,,,,,,, -55397,,,0.9940750002861024,0.0183645077049732,0.6688367949780897,0.9861955642700196,0.0506768971681594,0.272017019263525,43793.0,0.9854013323783876,0.0537953115999698,0.2553403779232295,43793.0,17780.475964784622,26006.03764986992,17780.475964784622,8221.137031793594,2.852015495300293,0.0 -55400,0.21522336,0.017152835,,,,,,,,,,,,,,,,, -55500,0.24402772,0.019445743,,,,,,,,,,,,,,,,, -55600,0.23480392,0.018161818,,,,,,,,,,,,,,,,, -55700,0.1972607,0.019603424,,,,,,,,,,,,,,,,, -55800,0.24594066,0.019122822,,,,,,,,,,,,,,,,, -55900,0.21290113,0.016862364,,,,,,,,,,,,,,,,, -56000,0.19700062,0.01772223,,,,,,,,,,,,,,,,, -56100,0.19038221,0.015598496,,,,,,,,,,,,,,,,, -56149,,,0.994035005569458,0.0182362999767065,0.6705837973689833,0.986273467540741,0.0518716983497142,0.2665896640185701,43793.0,0.9854514598846436,0.0548352636396884,0.2518330540823449,43793.0,18020.540464401245,26355.229289531708,18020.540464401245,8330.204983472824,2.890192985534668,0.0 -56200,0.20463839,0.018916758,,,,,,,,,,,,,,,,, -56300,0.20757742,0.017367002,,,,,,,,,,,,,,,,, -56400,0.1972264,0.014858108,,,,,,,,,,,,,,,,, -56500,0.19401887,0.017119009,,,,,,,,,,,,,,,,, -56600,0.21020587,0.01606319,,,,,,,,,,,,,,,,, -56700,0.20447709,0.017452763,,,,,,,,,,,,,,,,, -56800,0.20590627,0.019860944,,,,,,,,,,,,,,,,, -56900,0.20263188,0.018518293,,,,,,,,,,,,,,,,, -56909,,,0.9949232339859008,0.0158602874726057,0.7254194245765323,0.9862621426582336,0.0520996116101741,0.2709062672047048,43793.0,0.985466241836548,0.0550340972840786,0.2515098676230641,43793.0,18260.47955775261,26703.43917274475,18260.47955775261,8438.41262793541,2.932584762573242,0.0 -57000,0.2133148,0.017821746,,,,,,,,,,,,,,,,, -57100,0.22909953,0.01923088,,,,,,,,,,,,,,,,, -57200,0.24950252,0.018368557,,,,,,,,,,,,,,,,, -57300,0.23549588,0.018308105,,,,,,,,,,,,,,,,, -57400,0.19914427,0.016752033,,,,,,,,,,,,,,,,, -57500,0.2360647,0.016881837,,,,,,,,,,,,,,,,, -57596,,,,,,,,,,,,,,18477.267731904984,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 286b4ff04..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -106.30569434165956,0.0,11.92359495162964,1,0,11.92359495162964,0.521492063999176,0.7347381114959717,0.0268472039754593,43793,118.22934317588806,0.5323500037193298,0.7277898788452148,0.0227137389590253,0.5230699777603149,0.7331878542900085,0.0255114594813027,43793 -215.32289910316467,0.021022081375122,251.88759779930115,756,0,251.88759779930115,0.9831761717796326,0.0671860426664352,0.042759928672415,43793,467.25233030319214,0.9868640899658204,0.0536021664738655,0.0415510551682086,0.9841504096984864,0.0639973059296608,0.0408722513917534,43793 -323.3419051170349,0.0500235557556152,492.0900263786316,1500,0,492.0900263786316,0.98328697681427,0.0643152892589569,0.0656007687911264,43793,815.5264096260071,0.9869927763938904,0.0505222678184509,0.0666700498028669,0.9842925071716307,0.060984082520008,0.0673987467194187,43793 -431.4344410896301,0.0763416290283203,732.36257147789,2254,0,732.36257147789,0.983359396457672,0.0613919794559478,0.0773755137544642,43793,1163.9381144046783,0.9870300889015198,0.04844581335783,0.0834768995327652,0.9843335151672364,0.0580961555242538,0.0776466173512693,43793 -532.6530933380127,0.1035521030426025,972.623060464859,3012,0,972.623060464859,0.9836963415145874,0.0595721825957298,0.0947401907228563,43793,1505.4653851985931,0.9873122572898864,0.046275608241558,0.1062962399690621,0.984628200531006,0.0561421811580657,0.0958496041953559,43793 -642.2608568668365,0.1302659511566162,1212.7634572982788,3767,0,1212.7634572982788,0.9836066365242004,0.0575594492256641,0.1131281724685992,43793,1855.2609441280365,0.987439751625061,0.0452545993030071,0.1275264101577466,0.9846375584602356,0.0547099821269512,0.1125843895014612,43793 -743.5346171855927,0.1563253402709961,1452.7605681419373,4524,0,1452.7605681419373,0.9839351773262024,0.0555778555572032,0.1285225618718639,43793,2196.578159570694,0.9877470135688782,0.0434085913002491,0.1467361120975908,0.984818994998932,0.0527540072798728,0.1231737757961029,43793 -848.6604132652283,0.1823468208312988,1692.9004187583923,5270,0,1692.9004187583923,0.9842110872268676,0.0552487820386886,0.1399958925446105,43793,2541.892335653305,0.9880778789520264,0.0419932827353477,0.1662197912800197,0.9851506352424622,0.0519580356776714,0.1390556061527263,43793 -951.2205955982208,0.2111051082611084,1932.855243206024,6029,0,1932.855243206024,0.9840817451477052,0.0550039038062095,0.1352366892757305,43793,2884.456712245941,0.9879913330078124,0.0421731173992157,0.1533392112329769,0.9850727319717408,0.0521065816283226,0.1399027724744271,43793 -1055.2664711475372,0.2395803928375244,2173.0677292346954,6784,0,2173.0677292346954,0.9840371012687684,0.0542875565588474,0.1423427844244604,43793,3228.763679265976,0.9880300164222716,0.0422561429440975,0.160798852638912,0.9850122332572936,0.051524419337511,0.147078046851123,43793 -1160.4659247398376,0.2674005031585693,2413.096028804779,7543,0,2413.096028804779,0.984203040599823,0.053709540516138,0.1514900506984701,43793,3574.0399181842804,0.9881213903427124,0.0411835014820098,0.1685122143369883,0.9851640462875366,0.0508228130638599,0.1580675882901569,43793 -1264.378232717514,0.2971460819244385,2653.28998541832,8297,0,2653.28998541832,0.9844747185707092,0.0528067424893379,0.1600640826682151,43793,3918.196222543717,0.9882885813713074,0.0404829382896423,0.1833573007638227,0.9854526519775392,0.0499641820788383,0.1665353489619279,43793 -1366.8684968948364,0.3246288299560547,2893.515405654907,9059,0,2893.515405654907,0.9844157695770264,0.0527447089552879,0.1588039709738697,43793,4260.960469484329,0.9884369373321532,0.040347833186388,0.1880444803569518,0.9853227734565736,0.0499247312545776,0.1570235449190813,43793 -1471.6854536533356,0.354102611541748,3133.483088493347,9818,0,3133.483088493347,0.984444797039032,0.0532606206834316,0.1585128542982423,43793,4605.794593811035,0.988334596157074,0.0405681803822517,0.1774276717068803,0.9853284358978271,0.050316285341978,0.1624657584398002,43793 -1575.0813403129578,0.3822371959686279,3373.640467405319,10579,0,3373.640467405319,0.9845644235610962,0.0525931566953659,0.1652371916166563,43793,4949.39643740654,0.9884284138679504,0.0399681627750396,0.187111713207441,0.985492467880249,0.049749307334423,0.1647399948632794,43793 -1676.585940361023,0.4333629608154297,3613.69028544426,11336,0,3613.69028544426,0.9844300746917723,0.0530269481241703,0.1644090315504303,43793,5291.022964477539,0.988300621509552,0.0401514880359172,0.1967531360042304,0.9854080080986024,0.0499443002045154,0.171657268922337,43793 -1783.1696891784668,0.4635937213897705,3853.765954732895,12099,0,3853.765954732895,0.9845282435417176,0.0532206334173679,0.1626573099817982,43793,5637.733014345169,0.9884596467018129,0.0397456176578998,0.1934612264169873,0.9854969382286072,0.0499612241983413,0.1620202862419266,43793 -1888.3162214756007,0.492189884185791,4093.766298055649,12858,0,4093.766298055649,0.9844393730163574,0.0540200620889663,0.1680094410696119,43793,5982.928592443466,0.9882970452308656,0.0401655845344066,0.2059234311374559,0.9854335784912108,0.050815675407648,0.1689011888219761,43793 -1997.0318686962128,0.5204670429229736,4334.008366107941,13619,0,4334.008366107941,0.9844822883605956,0.0527850054204463,0.1698031474244022,43793,6331.93460059166,0.9884986281394958,0.0397482253611087,0.2073664034633914,0.9854303598403932,0.0498954951763153,0.1671076033289334,43793 -2099.35003232956,0.5519809722900391,4574.039252996445,14376,0,4574.039252996445,0.9846579432487488,0.0521375574171543,0.1663883366384637,43793,6674.336926460266,0.9886337518692015,0.0393826253712177,0.1911487464016221,0.9855237007141112,0.0494581498205661,0.1669518591312967,43793 -2205.947719812393,0.58294677734375,4814.030332088471,15132,0,4814.030332088471,0.9846330881118774,0.0528563000261783,0.170379749437865,43793,7020.977716207504,0.988545536994934,0.0397002212703228,0.2028603164464443,0.9855809211730956,0.0498757548630237,0.1683711746065087,43793 -2306.9480743408203,0.6138248443603516,5054.056452035904,15892,0,5054.056452035904,0.9847198724746704,0.0521600879728794,0.177989276069501,43793,7362.055555820465,0.9885801672935486,0.039336010813713,0.1984520111062934,0.9856215119361876,0.0493347458541393,0.171709221241322,43793 -2413.8230736255646,0.6432387828826904,5294.24645280838,16644,0,5294.24645280838,0.9847078323364258,0.0525591969490051,0.1744295765461452,43793,7709.170320272446,0.9884446263313292,0.0397267453372478,0.1952258893047422,0.9856138229370116,0.0496997870504856,0.1718890388904479,43793 -2520.733085632324,0.6736774444580078,5534.219587802887,17393,0,5534.219587802887,0.9848057627677916,0.052097849547863,0.1789103944466997,43793,8056.105494737625,0.9886024594306946,0.0392973124980926,0.2104613935710548,0.9857274889945984,0.0492811389267444,0.1825998260774477,43793 -2627.6177830696106,0.7077550888061523,5774.297125339508,18143,0,5774.297125339508,0.9847438931465148,0.0521351657807827,0.1714601168378829,43793,8403.12332034111,0.9886455535888672,0.039150483906269,0.1976879380121534,0.9855988025665284,0.0494032129645347,0.1718450416887608,43793 -2732.3874497413635,0.7390317916870117,6014.361703634262,18896,0,6014.361703634262,0.9847944378852844,0.0522737577557563,0.1771656161784761,43793,8748.009192943573,0.988558292388916,0.0391103066504001,0.2097071248810855,0.9856860637664796,0.0494744516909122,0.171085756222413,43793 -2838.0074141025543,0.7716796398162842,6254.322240352631,19650,0,6254.322240352631,0.984729528427124,0.0522604845464229,0.1714548320037481,43793,9093.64295911789,0.9885598421096802,0.0392596535384655,0.203455128542813,0.9856438636779784,0.0494760014116764,0.1667181022324168,43793 -2942.955656528473,0.8033504486083984,6494.273699045181,20396,0,6494.273699045181,0.9847375750541688,0.0518952123820781,0.1791972795879412,43793,9438.594544649124,0.9887438416481018,0.038576565682888,0.212028509293876,0.9856438636779784,0.0489908568561077,0.1850536108788391,43793 -3046.653083801269,0.834456205368042,6734.436192750931,21159,0,6734.436192750931,0.9846756458282472,0.051584169268608,0.1779796657330824,43793,9782.50632095337,0.9888488054275512,0.0383082553744316,0.2137575777576014,0.9856317043304444,0.048804972320795,0.1773359528687752,43793 -3150.928318500519,0.8655087947845459,6974.460539340973,21912,0,6974.460539340973,0.984855055809021,0.0511500872671604,0.1821158694057571,43793,10126.857418060305,0.9886640906333924,0.038734383881092,0.2241103178423511,0.9857713580131532,0.0484201423823833,0.1833587164211968,43793 -3252.439876317978,0.89583420753479,7214.716456651688,22666,0,7214.716456651688,0.984789788722992,0.0514403767883777,0.1796530308851342,43793,10468.675518989565,0.9887723326683044,0.0386015810072422,0.2117687081291772,0.98566335439682,0.0487646721303463,0.1744092576389312,43793 -3357.5840377807617,0.9260289669036864,7454.776722192764,23426,0,7454.776722192764,0.9846773147583008,0.0538868382573127,0.1822086485298938,43793,10813.931846141815,0.9883869886398317,0.0401140339672565,0.2083748797725441,0.985582947731018,0.0504753105342388,0.1814963297205611,43793 -3461.6296710968018,0.956789255142212,7694.763477563858,24182,0,7694.763477563858,0.9847131371498108,0.0524004660546779,0.1733469909451849,43793,11158.016758441923,0.9885193705558776,0.0394197851419448,0.2086739845483023,0.985595166683197,0.0494046621024608,0.1779097963337989,43793 -3566.3573133945465,0.987828016281128,7935.051184654236,24932,0,7935.051184654236,0.9847312569618224,0.0529902204871177,0.173159771894493,43793,11503.084161281586,0.9886791110038756,0.0391275435686111,0.2061097173095994,0.9856958389282228,0.0496935658156871,0.1742417341431321,43793 -3669.8920063972473,1.0200772285461426,8175.208832502365,25691,0,8175.208832502365,0.9847291111946106,0.0530240833759307,0.1790258355233416,43793,11846.830078840256,0.988594651222229,0.0393481999635696,0.2041383473869536,0.9856219291687012,0.0499823614954948,0.1795545772703834,43793 -3775.2796547412872,1.051387071609497,8415.195302963257,26445,0,8415.195302963257,0.9847640991210938,0.0514288060367107,0.1743120148980194,43793,12192.255960464478,0.9887518882751464,0.0389961153268814,0.2137159547440802,0.9856349229812622,0.0487174801528453,0.177908458512726,43793 -3885.255741834641,1.0838682651519775,8655.201495409012,27195,0,8655.201495409012,0.9845766425132751,0.0518827773630619,0.1705559020023737,43793,12542.291038513184,0.9886566400527954,0.0389818735420703,0.2184096719692399,0.985464870929718,0.0490479469299316,0.1820737219027982,43793 -3990.845261335373,1.1152944564819336,8895.445118188858,27950,0,8895.445118188858,0.9847881197929382,0.0518213957548141,0.1771507754449422,43793,12888.176303863524,0.9887489676475524,0.038225021213293,0.2238591435715243,0.9858261346817015,0.0486630946397781,0.1828395677585086,43793 -4094.420918226242,1.1478347778320312,9135.70709347725,28714,0,9135.70709347725,0.9847543835639954,0.0520633272826671,0.1759829370541633,43793,13232.066885709764,0.988622546195984,0.0386340096592903,0.2167209401037269,0.9856897592544556,0.0489758588373661,0.1747642867575387,43793 -4199.661685228348,1.1791231632232666,9375.913627147676,29464,0,9375.913627147676,0.9848394989967346,0.0512321665883064,0.1788149519418192,43793,13577.565557718275,0.989011526107788,0.0376570969820022,0.2281551208005285,0.9858322143554688,0.0482140444219112,0.18028048993088,43793 -4303.645714521408,1.2145650386810305,9615.896932125092,30214,0,9615.896932125092,0.984809160232544,0.0517909489572048,0.1794148969781187,43793,13921.591437339785,0.988719403743744,0.0386365242302417,0.213871037985582,0.9857904314994812,0.0485792160034179,0.1853121992215218,43793 -4406.588081121445,1.246260166168213,9855.93932056427,30969,0,9855.93932056427,0.9845911860466005,0.0530778802931308,0.1786106783703574,43793,14264.628145694733,0.988382875919342,0.0396341979503631,0.2193603036571355,0.9854400753974916,0.0501162000000476,0.1812246976637882,43793 -4515.162441253662,1.278688669204712,10096.182999134064,31722,0,10096.182999134064,0.984890878200531,0.050958689302206,0.1822194739907897,43793,14613.498676538467,0.9886970520019532,0.0385732203722,0.2138601955226267,0.9858525395393372,0.0481051616370677,0.185948414920223,43793 -4618.139933824539,1.3119385242462158,10336.253832101822,32467,0,10336.253832101822,0.9847211241722108,0.0515131317079067,0.1742848635781939,43793,14956.602952480316,0.988766312599182,0.0382795669138431,0.2245607374487249,0.9856763482093812,0.0484305880963802,0.1850852054221676,43793 -4721.810349941254,1.3457691669464111,10576.26902961731,33217,0,10576.26902961731,0.9849864840507508,0.0509848780930042,0.1873119050115221,43793,15300.3434009552,0.989039421081543,0.0377273596823215,0.2320768232391687,0.985966980457306,0.047953937202692,0.1985244284184877,43793 -4826.063565731049,1.380265235900879,10816.320188999176,33971,0,10816.320188999176,0.9849759340286256,0.0510173439979553,0.1896981083265152,43793,15644.702628612518,0.9889031052589417,0.0379895083606243,0.2294375844036454,0.9858683347702026,0.048191137611866,0.1919149423802982,43793 -4932.482318401337,1.4145681858062744,11056.37366938591,34728,0,11056.37366938591,0.9849582314491272,0.0511214435100555,0.1860170698764786,43793,15991.2295897007,0.9889699220657348,0.0376724116504192,0.2294512134532897,0.9858846068382264,0.0480569601058959,0.1883638989656708,43793 -5038.0507435798645,1.4474399089813232,11296.35082745552,35491,0,11296.35082745552,0.9849721789360046,0.0509317182004451,0.1839866549682523,43793,16336.828307151794,0.9889968633651732,0.0373593904078006,0.2480038810268567,0.9859061241149902,0.0479318015277385,0.1854400662467406,43793 -5143.885582208633,1.4814918041229248,11536.564465284348,36246,0,11536.564465284348,0.9849587082862854,0.0507874898612499,0.1820291417245847,43793,16682.931549072266,0.9889400601387024,0.0376297198235988,0.2289945929014396,0.9858976006507874,0.0478270538151264,0.1902274000931431,43793 -5246.514811038971,1.5184621810913086,11776.666022777556,36989,0,11776.666022777556,0.984892964363098,0.0507130473852157,0.1814267991041178,43793,17025.721967220306,0.9891141057014464,0.0372612811625003,0.2363792503926179,0.9858058094978333,0.0480333343148231,0.185534120927033,43793 -5350.075026988983,1.552298069000244,12016.776044130323,37742,0,12016.776044130323,0.9848238825798036,0.0511325113475322,0.1817323313886335,43793,17369.446642637253,0.9889169335365297,0.0378450527787208,0.236725344213282,0.9857680797576904,0.0481749102473259,0.1899463025558951,43793 -5450.9605281353,1.5866341590881348,12257.044536828997,38492,0,12257.044536828997,0.98500794172287,0.0509182810783386,0.1852429104902491,43793,17710.656693458557,0.9891027808189392,0.0373391807079315,0.2300288003892227,0.9859219193458556,0.0479280315339565,0.1933169156122331,43793 -5551.708795070648,1.6208436489105225,12497.11010313034,39249,0,12497.11010313034,0.9849451780319214,0.0517506897449493,0.1873046864618587,43793,18051.52519583702,0.9887359738349916,0.0383051224052906,0.2391878173293923,0.9858935475349426,0.0485585033893585,0.1977071193646147,43793 -5652.283429861069,1.6545679569244385,12737.06660580635,40010,0,12737.06660580635,0.985103964805603,0.0504989437758922,0.1960561700278622,43793,18392.110761642456,0.9889850616455078,0.0373271107673645,0.2340222648566641,0.9859641790390016,0.0478209815919399,0.1918717627352083,43793 -5755.875368833542,1.688731670379639,12977.28224658966,40763,0,12977.28224658966,0.9850677847862244,0.0508184693753719,0.1908341660525139,43793,18735.97519636154,0.9889816045761108,0.0374413505196571,0.229548285313721,0.9859832525253296,0.0477732196450233,0.1986464619495477,43793 -5861.133890390396,1.7246840000152588,13217.424660682678,41526,0,13217.424660682678,0.9847068190574646,0.0516904331743717,0.1825132729511454,43793,19081.43302559853,0.988987386226654,0.0376761332154274,0.2321438418027664,0.9857603907585144,0.0486207045614719,0.1892344774317401,43793 -5968.8874089717865,1.759670972824097,13457.660992622375,42287,0,13457.660992622375,0.9851654767990112,0.0507156662642955,0.1907572109509011,43793,19429.478434562683,0.9890965819358826,0.0369728617370128,0.2431629217687319,0.9859690070152284,0.0477113537490367,0.1947566100337512,43793 -6071.891428232193,1.7949151992797852,13697.648321390152,43042,0,13697.648321390152,0.9849687814712524,0.0506608895957469,0.1921175012712416,43793,19772.52578687668,0.98912513256073,0.0369981303811073,0.2451984672837403,0.9858740568161012,0.047635443508625,0.1916597128129191,43793 -6176.329439401627,1.829290151596069,13937.767841815948,43802,0,13937.767841815948,0.9850454330444336,0.0504810027778148,0.1850773565321149,43793,20117.138426303864,0.989127278327942,0.0368660315871238,0.2467007561681576,0.9859893321990968,0.0476075746119022,0.1888054470829359,43793 -6276.924092531204,1.8644392490386963,14177.993677854538,44565,0,14177.993677854538,0.9850686192512512,0.0505295880138874,0.1911015084064031,43793,20458.01491880417,0.9892011284828186,0.0366797894239425,0.2519868990290798,0.9859905242919922,0.0475577712059021,0.1938699084327157,43793 -6382.364597320557,1.9001915454864504,14418.119437456133,45324,0,14418.119437456133,0.9850033521652222,0.0514786131680011,0.1892441479288406,43793,20803.637764453888,0.9893069863319396,0.0367147289216518,0.2476210653071464,0.9859580397605896,0.0481480509042739,0.1963034593315598,43793 -6487.826305150986,1.9359188079833984,14658.144562959673,46075,0,14658.144562959673,0.9850778579711914,0.0500235259532928,0.1901680219495883,43793,21149.18072462082,0.989287257194519,0.0366210155189037,0.2516408284043377,0.986051857471466,0.0470076762139797,0.1985409708187414,43793 -6591.78231549263,1.970759391784668,14898.28629231453,46830,0,14898.28629231453,0.9850918054580688,0.0501007810235023,0.2014320869136368,43793,21493.333992242813,0.9892902970314026,0.0364799313247203,0.2555223045112521,0.9860911965370178,0.0471407510340213,0.2065983850306616,43793 -6702.919489622116,2.0053157806396484,15138.5211622715,47588,0,15138.5211622715,0.9851625561714172,0.0500448420643806,0.1926117407578925,43793,21844.761209726334,0.98917818069458,0.0369547009468078,0.2552370542538443,0.9860566854476928,0.0472060516476631,0.2016886836342544,43793 -6804.498358488083,2.044494390487671,15378.47325873375,48349,0,15378.47325873375,0.9851372838020324,0.049987506121397,0.1951242118193244,43793,22186.35283780098,0.9893152117729188,0.0363342836499214,0.2480764058432878,0.9860628247261048,0.0470212027430534,0.2059009357194049,43793 -6908.433473587036,2.0814638137817383,15618.42730998993,49112,0,15618.42730998993,0.985241711139679,0.0499663092195987,0.1993356351891673,43793,22530.29949116707,0.9893460273742676,0.0361896976828575,0.2436113602084199,0.9861687421798706,0.0469433367252349,0.2054351799038135,43793 -7009.377001285553,2.12254285812378,15858.633188962936,49869,0,15858.633188962936,0.9851536750793456,0.0498836562037467,0.2006272679957678,43793,22871.511667251587,0.9894251823425292,0.0361647494137287,0.2601694948034501,0.9860376119613647,0.0469044111669063,0.2030483587220174,43793 -7113.305378198624,2.1576476097106934,16098.7771692276,50631,0,16098.7771692276,0.9852059483528136,0.0497694425284862,0.1960011431400151,43793,23215.639623880383,0.9894543290138244,0.0357305705547332,0.2622053498185402,0.9860794544219972,0.0468679517507553,0.2063089989078978,43793 -7217.148921728134,2.195411443710327,16338.885677576063,51381,0,16338.885677576063,0.9851823449134828,0.050275906920433,0.2007918248456357,43793,23559.649918794632,0.9894623160362244,0.0354856625199317,0.2692625219377699,0.986139714717865,0.0471389181911945,0.206980863824063,43793 -7323.358339309692,2.233850240707397,16579.111583709717,52137,0,16579.111583709717,0.9852589964866638,0.0498905703425407,0.1996205651212903,43793,23906.143963336945,0.9895649552345276,0.0351977571845054,0.274967294293767,0.9861720204353333,0.0469048433005809,0.2073715671016123,43793 -7428.175041437149,2.2763447761535645,16819.154947042465,52887,0,16819.154947042465,0.9852981567382812,0.0499219372868537,0.1996159199399781,43793,24251.066781044006,0.9894105792045592,0.0356871746480464,0.2732275237291334,0.9862263798713684,0.0467871576547622,0.2090118597592896,43793 -7530.755472898483,2.313117027282715,17059.26643562317,53642,0,17059.26643562317,0.9852105379104614,0.0497577525675296,0.1973662601774869,43793,24593.81644487381,0.9896687865257264,0.0350858755409717,0.2733144440364853,0.9860932230949402,0.0468956939876079,0.2068264701613324,43793 -7634.566126823425,2.349884271621704,17299.366802215576,54398,0,17299.366802215576,0.9852269887924194,0.0497096553444862,0.1957582689787017,43793,24937.78490638733,0.9894760251045228,0.0357528924942016,0.264047843704155,0.9861346483230592,0.0467527098953723,0.2083068753505612,43793 -7733.860890388489,2.3890788555145264,17539.38094496727,55156,0,17539.38094496727,0.9852808713912964,0.0498327724635601,0.202013044800903,43793,25277.153856515884,0.9895650148391724,0.0352049171924591,0.2680131655883736,0.986196756362915,0.04677639529109,0.2084649784585847,43793 -7836.111313343048,2.427854061126709,17779.42640852928,55922,0,17779.42640852928,0.9852501749992372,0.0495138205587863,0.2028329675264197,43793,25619.509373903275,0.9894495010375975,0.0355528406798839,0.2755270016128907,0.9862328767776488,0.0465220957994461,0.2131809277834291,43793 -7939.627541303635,2.465727567672729,18019.61223173141,56673,0,18019.61223173141,0.985371470451355,0.0493917502462863,0.2039670057834056,43793,25963.269825220108,0.9895038604736328,0.0352360047399997,0.2716795602529939,0.9862877130508424,0.0464650578796863,0.2110234043764053,43793 -8041.5682673454285,2.5035622119903564,18259.605164289474,57429,0,18259.605164289474,0.9852796196937561,0.049422599375247955,0.2015535494588887,43793,26305.261911392212,0.9898420572280884,0.03476691246032715,0.28116842626885963,0.9861756563186646,0.04659731686115265,0.21251331949642285,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/measurements.csv deleted file mode 100644 index 92312e895..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/measurements.csv +++ /dev/null @@ -1,661 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.5278704,0.7267631,,,,,,,,,,,,,,,,, -1,,,0.5323500037193298,0.7277898788452148,0.0227137389590253,0.5230699777603149,0.7331878542900085,0.0255114594813027,43793.0,0.521492063999176,0.7347381114959717,0.0268472039754593,43793.0,11.92359495162964,118.22934317588806,11.92359495162964,106.30569434165956,0.0,0.0 -100,0.10748224,0.11501641,,,,,,,,,,,,,,,,, -200,0.010978516,0.06538002,,,,,,,,,,,,,,,,, -300,0.012269616,0.053239267,,,,,,,,,,,,,,,,, -400,0.010108517,0.054288503,,,,,,,,,,,,,,,,, -500,0.014944052,0.058671553,,,,,,,,,,,,,,,,, -600,0.021478249,0.052838027,,,,,,,,,,,,,,,,, -700,0.013751589,0.05485368,,,,,,,,,,,,,,,,, -756,,,0.9868640899658204,0.0536021664738655,0.0415510551682086,0.9841504096984864,0.0639973059296608,0.0408722513917534,43793.0,0.9831761717796326,0.0671860426664352,0.042759928672415,43793.0,251.88759779930115,467.25233030319214,251.88759779930115,215.32289910316467,0.021022081375122,0.0 -800,0.011231126,0.05441603,,,,,,,,,,,,,,,,, -900,0.006347251,0.056638155,,,,,,,,,,,,,,,,, -1000,0.014655706,0.05349218,,,,,,,,,,,,,,,,, -1100,0.023302324,0.051346272,,,,,,,,,,,,,,,,, -1200,0.010931454,0.046686165,,,,,,,,,,,,,,,,, -1300,0.00633068,0.054627445,,,,,,,,,,,,,,,,, -1400,0.006388682,0.046706874,,,,,,,,,,,,,,,,, -1500,,,0.9869927763938904,0.0505222678184509,0.0666700498028669,0.9842925071716307,0.060984082520008,0.0673987467194187,43793.0,0.98328697681427,0.0643152892589569,0.0656007687911264,43793.0,492.0900263786316,815.5264096260071,492.0900263786316,323.3419051170349,0.0500235557556152,0.0 -1500,0.011483766,0.05689821,,,,,,,,,,,,,,,,, -1600,0.005824811,0.050035574,,,,,,,,,,,,,,,,, -1700,0.021564394,0.057703603,,,,,,,,,,,,,,,,, -1800,0.030162556,0.04750776,,,,,,,,,,,,,,,,, -1900,0.012138746,0.04787388,,,,,,,,,,,,,,,,, -2000,0.019660817,0.05158008,,,,,,,,,,,,,,,,, -2100,0.013670466,0.050253026,,,,,,,,,,,,,,,,, -2200,0.01329168,0.053313337,,,,,,,,,,,,,,,,, -2254,,,0.9870300889015198,0.04844581335783,0.0834768995327652,0.9843335151672364,0.0580961555242538,0.0776466173512693,43793.0,0.983359396457672,0.0613919794559478,0.0773755137544642,43793.0,732.36257147789,1163.9381144046783,732.36257147789,431.4344410896301,0.0763416290283203,0.0 -2300,0.014212702,0.05502523,,,,,,,,,,,,,,,,, -2400,0.0063739945,0.045601748,,,,,,,,,,,,,,,,, -2500,0.026870726,0.04948367,,,,,,,,,,,,,,,,, -2600,0.008485311,0.04296081,,,,,,,,,,,,,,,,, -2700,0.021850826,0.044909056,,,,,,,,,,,,,,,,, -2800,0.019387802,0.046778273,,,,,,,,,,,,,,,,, -2900,0.019445397,0.051450904,,,,,,,,,,,,,,,,, -3000,0.013287821,0.047806464,,,,,,,,,,,,,,,,, -3012,,,0.9873122572898864,0.046275608241558,0.1062962399690621,0.984628200531006,0.0561421811580657,0.0958496041953559,43793.0,0.9836963415145874,0.0595721825957298,0.0947401907228563,43793.0,972.623060464859,1505.4653851985931,972.623060464859,532.6530933380127,0.1035521030426025,0.0 -3100,0.03615168,0.04621534,,,,,,,,,,,,,,,,, -3200,0.021108199,0.045283794,,,,,,,,,,,,,,,,, -3300,0.016068809,0.043802872,,,,,,,,,,,,,,,,, -3400,0.014342737,0.04183886,,,,,,,,,,,,,,,,, -3500,0.03526293,0.043768894,,,,,,,,,,,,,,,,, -3600,0.03311484,0.043253325,,,,,,,,,,,,,,,,, -3700,0.013933019,0.04469903,,,,,,,,,,,,,,,,, -3767,,,0.987439751625061,0.0452545993030071,0.1275264101577466,0.9846375584602356,0.0547099821269512,0.1125843895014612,43793.0,0.9836066365242004,0.0575594492256641,0.1131281724685992,43793.0,1212.7634572982788,1855.2609441280365,1212.7634572982788,642.2608568668365,0.1302659511566162,0.0 -3800,0.02483564,0.03968668,,,,,,,,,,,,,,,,, -3900,0.044114463,0.042342648,,,,,,,,,,,,,,,,, -4000,0.01831421,0.046030134,,,,,,,,,,,,,,,,, -4100,0.037531674,0.04381893,,,,,,,,,,,,,,,,, -4200,0.026491499,0.03949414,,,,,,,,,,,,,,,,, -4300,0.028455863,0.04409347,,,,,,,,,,,,,,,,, -4400,0.017688854,0.040528562,,,,,,,,,,,,,,,,, -4500,0.033263527,0.04286382,,,,,,,,,,,,,,,,, -4524,,,0.9877470135688782,0.0434085913002491,0.1467361120975908,0.984818994998932,0.0527540072798728,0.1231737757961029,43793.0,0.9839351773262024,0.0555778555572032,0.1285225618718639,43793.0,1452.7605681419373,2196.578159570694,1452.7605681419373,743.5346171855927,0.1563253402709961,0.0 -4600,0.033949573,0.047642402,,,,,,,,,,,,,,,,, -4700,0.029257609,0.03865428,,,,,,,,,,,,,,,,, -4800,0.03397558,0.04293989,,,,,,,,,,,,,,,,, -4900,0.052845567,0.04024688,,,,,,,,,,,,,,,,, -5000,0.03175799,0.04374604,,,,,,,,,,,,,,,,, -5100,0.021604225,0.03981341,,,,,,,,,,,,,,,,, -5200,0.048219543,0.043398447,,,,,,,,,,,,,,,,, -5270,,,0.9880778789520264,0.0419932827353477,0.1662197912800197,0.9851506352424622,0.0519580356776714,0.1390556061527263,43793.0,0.9842110872268676,0.0552487820386886,0.1399958925446105,43793.0,1692.9004187583923,2541.892335653305,1692.9004187583923,848.6604132652283,0.1823468208312988,0.0 -5300,0.034604702,0.041246288,,,,,,,,,,,,,,,,, -5400,0.06636001,0.04836739,,,,,,,,,,,,,,,,, -5500,0.08612109,0.046231013,,,,,,,,,,,,,,,,, -5600,0.032196596,0.04389057,,,,,,,,,,,,,,,,, -5700,0.030483572,0.0356252,,,,,,,,,,,,,,,,, -5800,0.055502202,0.041967627,,,,,,,,,,,,,,,,, -5900,0.06680645,0.039019894,,,,,,,,,,,,,,,,, -6000,0.054478478,0.037649747,,,,,,,,,,,,,,,,, -6029,,,0.9879913330078124,0.0421731173992157,0.1533392112329769,0.9850727319717408,0.0521065816283226,0.1399027724744271,43793.0,0.9840817451477052,0.0550039038062095,0.1352366892757305,43793.0,1932.855243206024,2884.456712245941,1932.855243206024,951.2205955982208,0.2111051082611084,0.0 -6100,0.04494407,0.03507084,,,,,,,,,,,,,,,,, -6200,0.021888072,0.04015039,,,,,,,,,,,,,,,,, -6300,0.035241652,0.039748263,,,,,,,,,,,,,,,,, -6400,0.050332434,0.04230018,,,,,,,,,,,,,,,,, -6500,0.104535736,0.043382972,,,,,,,,,,,,,,,,, -6600,0.03525582,0.037942667,,,,,,,,,,,,,,,,, -6700,0.068657435,0.04408276,,,,,,,,,,,,,,,,, -6784,,,0.9880300164222716,0.0422561429440975,0.160798852638912,0.9850122332572936,0.051524419337511,0.147078046851123,43793.0,0.9840371012687684,0.0542875565588474,0.1423427844244604,43793.0,2173.0677292346954,3228.763679265976,2173.0677292346954,1055.2664711475372,0.2395803928375244,0.0 -6800,0.04075422,0.03926489,,,,,,,,,,,,,,,,, -6900,0.06685917,0.04112907,,,,,,,,,,,,,,,,, -7000,0.03833779,0.04052297,,,,,,,,,,,,,,,,, -7100,0.026945278,0.045190986,,,,,,,,,,,,,,,,, -7200,0.062098823,0.0404361,,,,,,,,,,,,,,,,, -7300,0.025642116,0.039253715,,,,,,,,,,,,,,,,, -7400,0.059465498,0.04480345,,,,,,,,,,,,,,,,, -7500,0.030551102,0.044167895,,,,,,,,,,,,,,,,, -7543,,,0.9881213903427124,0.0411835014820098,0.1685122143369883,0.9851640462875366,0.0508228130638599,0.1580675882901569,43793.0,0.984203040599823,0.053709540516138,0.1514900506984701,43793.0,2413.096028804779,3574.0399181842804,2413.096028804779,1160.4659247398376,0.2674005031585693,0.0 -7600,0.042070687,0.04182592,,,,,,,,,,,,,,,,, -7700,0.066714674,0.03872273,,,,,,,,,,,,,,,,, -7800,0.05750816,0.040170424,,,,,,,,,,,,,,,,, -7900,0.047981896,0.037464533,,,,,,,,,,,,,,,,, -8000,0.06513553,0.045017418,,,,,,,,,,,,,,,,, -8100,0.07201509,0.04063319,,,,,,,,,,,,,,,,, -8200,0.02801885,0.0357693,,,,,,,,,,,,,,,,, -8297,,,0.9882885813713074,0.0404829382896423,0.1833573007638227,0.9854526519775392,0.0499641820788383,0.1665353489619279,43793.0,0.9844747185707092,0.0528067424893379,0.1600640826682151,43793.0,2653.28998541832,3918.196222543717,2653.28998541832,1264.378232717514,0.2971460819244385,0.0 -8300,0.039786767,0.040006936,,,,,,,,,,,,,,,,, -8400,0.038241915,0.041242927,,,,,,,,,,,,,,,,, -8500,0.074700765,0.039110877,,,,,,,,,,,,,,,,, -8600,0.079202645,0.04006721,,,,,,,,,,,,,,,,, -8700,0.04395015,0.039451536,,,,,,,,,,,,,,,,, -8800,0.051660527,0.043473784,,,,,,,,,,,,,,,,, -8900,0.042481896,0.04138763,,,,,,,,,,,,,,,,, -9000,0.04339389,0.04530335,,,,,,,,,,,,,,,,, -9059,,,0.9884369373321532,0.040347833186388,0.1880444803569518,0.9853227734565736,0.0499247312545776,0.1570235449190813,43793.0,0.9844157695770264,0.0527447089552879,0.1588039709738697,43793.0,2893.515405654907,4260.960469484329,2893.515405654907,1366.8684968948364,0.3246288299560547,0.0 -9100,0.061546735,0.040987883,,,,,,,,,,,,,,,,, -9200,0.07280167,0.04056276,,,,,,,,,,,,,,,,, -9300,0.04015834,0.04230159,,,,,,,,,,,,,,,,, -9400,0.060314573,0.04286801,,,,,,,,,,,,,,,,, -9500,0.029076243,0.037558183,,,,,,,,,,,,,,,,, -9600,0.059445214,0.033998523,,,,,,,,,,,,,,,,, -9700,0.028657347,0.041155197,,,,,,,,,,,,,,,,, -9800,0.05603834,0.043246638,,,,,,,,,,,,,,,,, -9818,,,0.988334596157074,0.0405681803822517,0.1774276717068803,0.9853284358978271,0.050316285341978,0.1624657584398002,43793.0,0.984444797039032,0.0532606206834316,0.1585128542982423,43793.0,3133.483088493347,4605.794593811035,3133.483088493347,1471.6854536533356,0.354102611541748,0.0 -9900,0.061916478,0.045171402,,,,,,,,,,,,,,,,, -10000,0.043965332,0.037042063,,,,,,,,,,,,,,,,, -10100,0.09252955,0.041496202,,,,,,,,,,,,,,,,, -10200,0.07069134,0.039180722,,,,,,,,,,,,,,,,, -10300,0.055385098,0.044752736,,,,,,,,,,,,,,,,, -10400,0.061315343,0.039163087,,,,,,,,,,,,,,,,, -10500,0.031395845,0.044443257,,,,,,,,,,,,,,,,, -10579,,,0.9884284138679504,0.0399681627750396,0.187111713207441,0.985492467880249,0.049749307334423,0.1647399948632794,43793.0,0.9845644235610962,0.0525931566953659,0.1652371916166563,43793.0,3373.640467405319,4949.39643740654,3373.640467405319,1575.0813403129578,0.3822371959686279,0.0 -10600,0.037198085,0.039527692,,,,,,,,,,,,,,,,, -10700,0.0454305,0.044965714,,,,,,,,,,,,,,,,, -10800,0.058020186,0.039417684,,,,,,,,,,,,,,,,, -10900,0.054852616,0.042049896,,,,,,,,,,,,,,,,, -11000,0.03390998,0.04350208,,,,,,,,,,,,,,,,, -11100,0.075436704,0.040492132,,,,,,,,,,,,,,,,, -11200,0.037874755,0.03908492,,,,,,,,,,,,,,,,, -11300,0.04957818,0.0411904,,,,,,,,,,,,,,,,, -11336,,,0.988300621509552,0.0401514880359172,0.1967531360042304,0.9854080080986024,0.0499443002045154,0.171657268922337,43793.0,0.9844300746917723,0.0530269481241703,0.1644090315504303,43793.0,3613.69028544426,5291.022964477539,3613.69028544426,1676.585940361023,0.4333629608154297,0.0 -11400,0.023367079,0.036797497,,,,,,,,,,,,,,,,, -11500,0.029367624,0.039626945,,,,,,,,,,,,,,,,, -11600,0.033715826,0.038938716,,,,,,,,,,,,,,,,, -11700,0.059406698,0.040764246,,,,,,,,,,,,,,,,, -11800,0.028220251,0.03664301,,,,,,,,,,,,,,,,, -11900,0.049879074,0.04292362,,,,,,,,,,,,,,,,, -12000,0.07874212,0.041821484,,,,,,,,,,,,,,,,, -12099,,,0.9884596467018129,0.0397456176578998,0.1934612264169873,0.9854969382286072,0.0499612241983413,0.1620202862419266,43793.0,0.9845282435417176,0.0532206334173679,0.1626573099817982,43793.0,3853.765954732895,5637.733014345169,3853.765954732895,1783.1696891784668,0.4635937213897705,0.0 -12100,0.06721981,0.042585306,,,,,,,,,,,,,,,,, -12200,0.09700309,0.04184174,,,,,,,,,,,,,,,,, -12300,0.057222664,0.036555823,,,,,,,,,,,,,,,,, -12400,0.044364255,0.037400234,,,,,,,,,,,,,,,,, -12500,0.05995298,0.039088342,,,,,,,,,,,,,,,,, -12600,0.03714331,0.042442128,,,,,,,,,,,,,,,,, -12700,0.052566428,0.040318977,,,,,,,,,,,,,,,,, -12800,0.14137381,0.04180371,,,,,,,,,,,,,,,,, -12858,,,0.9882970452308656,0.0401655845344066,0.2059234311374559,0.9854335784912108,0.050815675407648,0.1689011888219761,43793.0,0.9844393730163574,0.0540200620889663,0.1680094410696119,43793.0,4093.766298055649,5982.928592443466,4093.766298055649,1888.3162214756007,0.492189884185791,0.0 -12900,0.11029748,0.03724459,,,,,,,,,,,,,,,,, -13000,0.04868952,0.037885405,,,,,,,,,,,,,,,,, -13100,0.07568922,0.043970615,,,,,,,,,,,,,,,,, -13200,0.12521684,0.038813096,,,,,,,,,,,,,,,,, -13300,0.05193574,0.036687303,,,,,,,,,,,,,,,,, -13400,0.055491,0.047461137,,,,,,,,,,,,,,,,, -13500,0.06406365,0.037154358,,,,,,,,,,,,,,,,, -13600,0.04550876,0.038945723,,,,,,,,,,,,,,,,, -13619,,,0.9884986281394958,0.0397482253611087,0.2073664034633914,0.9854303598403932,0.0498954951763153,0.1671076033289334,43793.0,0.9844822883605956,0.0527850054204463,0.1698031474244022,43793.0,4334.008366107941,6331.93460059166,4334.008366107941,1997.0318686962128,0.5204670429229736,0.0 -13700,0.037856508,0.043413743,,,,,,,,,,,,,,,,, -13800,0.044980463,0.043029223,,,,,,,,,,,,,,,,, -13900,0.02904219,0.040243722,,,,,,,,,,,,,,,,, -14000,0.03678311,0.03930272,,,,,,,,,,,,,,,,, -14100,0.06490809,0.039902486,,,,,,,,,,,,,,,,, -14200,0.13917631,0.0405405,,,,,,,,,,,,,,,,, -14300,0.037502617,0.037595253,,,,,,,,,,,,,,,,, -14376,,,0.9886337518692015,0.0393826253712177,0.1911487464016221,0.9855237007141112,0.0494581498205661,0.1669518591312967,43793.0,0.9846579432487488,0.0521375574171543,0.1663883366384637,43793.0,4574.039252996445,6674.336926460266,4574.039252996445,2099.35003232956,0.5519809722900391,0.0 -14400,0.057279423,0.043567088,,,,,,,,,,,,,,,,, -14500,0.072033,0.036949493,,,,,,,,,,,,,,,,, -14600,0.045124765,0.038151857,,,,,,,,,,,,,,,,, -14700,0.038193334,0.038401052,,,,,,,,,,,,,,,,, -14800,0.08494832,0.04270456,,,,,,,,,,,,,,,,, -14900,0.067780316,0.04512384,,,,,,,,,,,,,,,,, -15000,0.04518985,0.04163771,,,,,,,,,,,,,,,,, -15100,0.046136286,0.039722368,,,,,,,,,,,,,,,,, -15132,,,0.988545536994934,0.0397002212703228,0.2028603164464443,0.9855809211730956,0.0498757548630237,0.1683711746065087,43793.0,0.9846330881118774,0.0528563000261783,0.170379749437865,43793.0,4814.030332088471,7020.977716207504,4814.030332088471,2205.947719812393,0.58294677734375,0.0 -15200,0.09579507,0.038849946,,,,,,,,,,,,,,,,, -15300,0.06881531,0.044043504,,,,,,,,,,,,,,,,, -15400,0.07645311,0.043828353,,,,,,,,,,,,,,,,, -15500,0.029948864,0.037364528,,,,,,,,,,,,,,,,, -15600,0.10425653,0.043457333,,,,,,,,,,,,,,,,, -15700,0.039510652,0.038270164,,,,,,,,,,,,,,,,, -15800,0.04318703,0.036072023,,,,,,,,,,,,,,,,, -15892,,,0.9885801672935486,0.039336010813713,0.1984520111062934,0.9856215119361876,0.0493347458541393,0.171709221241322,43793.0,0.9847198724746704,0.0521600879728794,0.177989276069501,43793.0,5054.056452035904,7362.055555820465,5054.056452035904,2306.9480743408203,0.6138248443603516,0.0 -15900,0.06296673,0.039001197,,,,,,,,,,,,,,,,, -16000,0.05429381,0.04198288,,,,,,,,,,,,,,,,, -16100,0.06471943,0.039014634,,,,,,,,,,,,,,,,, -16200,0.03617266,0.038029738,,,,,,,,,,,,,,,,, -16300,0.05064534,0.03386402,,,,,,,,,,,,,,,,, -16400,0.05379619,0.039105333,,,,,,,,,,,,,,,,, -16500,0.07057439,0.033934217,,,,,,,,,,,,,,,,, -16600,0.04238674,0.039429516,,,,,,,,,,,,,,,,, -16644,,,0.9884446263313292,0.0397267453372478,0.1952258893047422,0.9856138229370116,0.0496997870504856,0.1718890388904479,43793.0,0.9847078323364258,0.0525591969490051,0.1744295765461452,43793.0,5294.24645280838,7709.170320272446,5294.24645280838,2413.8230736255646,0.6432387828826904,0.0 -16700,0.050007418,0.041954763,,,,,,,,,,,,,,,,, -16800,0.027646821,0.04055236,,,,,,,,,,,,,,,,, -16900,0.055896804,0.040838353,,,,,,,,,,,,,,,,, -17000,0.065196924,0.042822324,,,,,,,,,,,,,,,,, -17100,0.0544246,0.040482897,,,,,,,,,,,,,,,,, -17200,0.035243016,0.041579504,,,,,,,,,,,,,,,,, -17300,0.050443117,0.03644773,,,,,,,,,,,,,,,,, -17393,,,0.9886024594306946,0.0392973124980926,0.2104613935710548,0.9857274889945984,0.0492811389267444,0.1825998260774477,43793.0,0.9848057627677916,0.052097849547863,0.1789103944466997,43793.0,5534.219587802887,8056.105494737625,5534.219587802887,2520.733085632324,0.6736774444580078,0.0 -17400,0.033660263,0.040705573,,,,,,,,,,,,,,,,, -17500,0.04932944,0.043814667,,,,,,,,,,,,,,,,, -17600,0.02701204,0.03941426,,,,,,,,,,,,,,,,, -17700,0.031788126,0.039354652,,,,,,,,,,,,,,,,, -17800,0.05739844,0.04273956,,,,,,,,,,,,,,,,, -17900,0.09517173,0.043535117,,,,,,,,,,,,,,,,, -18000,0.12182443,0.041206434,,,,,,,,,,,,,,,,, -18100,0.045854148,0.042248566,,,,,,,,,,,,,,,,, -18143,,,0.9886455535888672,0.039150483906269,0.1976879380121534,0.9855988025665284,0.0494032129645347,0.1718450416887608,43793.0,0.9847438931465148,0.0521351657807827,0.1714601168378829,43793.0,5774.297125339508,8403.12332034111,5774.297125339508,2627.6177830696106,0.7077550888061523,0.0 -18200,0.043864526,0.03859648,,,,,,,,,,,,,,,,, -18300,0.09173401,0.040374756,,,,,,,,,,,,,,,,, -18400,0.0776128,0.044423807,,,,,,,,,,,,,,,,, -18500,0.07502354,0.038655505,,,,,,,,,,,,,,,,, -18600,0.071931824,0.042274944,,,,,,,,,,,,,,,,, -18700,0.05572965,0.035499185,,,,,,,,,,,,,,,,, -18800,0.062938824,0.037513703,,,,,,,,,,,,,,,,, -18896,,,0.988558292388916,0.0391103066504001,0.2097071248810855,0.9856860637664796,0.0494744516909122,0.171085756222413,43793.0,0.9847944378852844,0.0522737577557563,0.1771656161784761,43793.0,6014.361703634262,8748.009192943573,6014.361703634262,2732.3874497413635,0.7390317916870117,0.0 -18900,0.051435802,0.04048275,,,,,,,,,,,,,,,,, -19000,0.042795658,0.041461468,,,,,,,,,,,,,,,,, -19100,0.061183214,0.04033907,,,,,,,,,,,,,,,,, -19200,0.12157017,0.03664417,,,,,,,,,,,,,,,,, -19300,0.056592982,0.040021412,,,,,,,,,,,,,,,,, -19400,0.037650075,0.037036702,,,,,,,,,,,,,,,,, -19500,0.045678712,0.03501889,,,,,,,,,,,,,,,,, -19600,0.085264705,0.04332772,,,,,,,,,,,,,,,,, -19650,,,0.9885598421096802,0.0392596535384655,0.203455128542813,0.9856438636779784,0.0494760014116764,0.1667181022324168,43793.0,0.984729528427124,0.0522604845464229,0.1714548320037481,43793.0,6254.322240352631,9093.64295911789,6254.322240352631,2838.0074141025543,0.7716796398162842,0.0 -19700,0.065432824,0.040845674,,,,,,,,,,,,,,,,, -19800,0.05616981,0.043363184,,,,,,,,,,,,,,,,, -19900,0.028175972,0.038475595,,,,,,,,,,,,,,,,, -20000,0.050495848,0.039158825,,,,,,,,,,,,,,,,, -20100,0.07632966,0.036108516,,,,,,,,,,,,,,,,, -20200,0.03485448,0.0364506,,,,,,,,,,,,,,,,, -20300,0.031378224,0.03615754,,,,,,,,,,,,,,,,, -20396,,,0.9887438416481018,0.038576565682888,0.212028509293876,0.9856438636779784,0.0489908568561077,0.1850536108788391,43793.0,0.9847375750541688,0.0518952123820781,0.1791972795879412,43793.0,6494.273699045181,9438.594544649124,6494.273699045181,2942.955656528473,0.8033504486083984,0.0 -20400,0.030509558,0.037032433,,,,,,,,,,,,,,,,, -20500,0.0496734,0.042366683,,,,,,,,,,,,,,,,, -20600,0.039475136,0.033495508,,,,,,,,,,,,,,,,, -20700,0.083471745,0.041059233,,,,,,,,,,,,,,,,, -20800,0.06508389,0.03939382,,,,,,,,,,,,,,,,, -20900,0.058577955,0.038241975,,,,,,,,,,,,,,,,, -21000,0.07979378,0.040945068,,,,,,,,,,,,,,,,, -21100,0.069010206,0.037672263,,,,,,,,,,,,,,,,, -21159,,,0.9888488054275512,0.0383082553744316,0.2137575777576014,0.9856317043304444,0.048804972320795,0.1773359528687752,43793.0,0.9846756458282472,0.051584169268608,0.1779796657330824,43793.0,6734.436192750931,9782.50632095337,6734.436192750931,3046.653083801269,0.834456205368042,0.0 -21200,0.034733273,0.038199887,,,,,,,,,,,,,,,,, -21300,0.06337778,0.040304817,,,,,,,,,,,,,,,,, -21400,0.034097433,0.042621266,,,,,,,,,,,,,,,,, -21500,0.19430147,0.03977056,,,,,,,,,,,,,,,,, -21600,0.02874845,0.041192465,,,,,,,,,,,,,,,,, -21700,0.02974785,0.036119528,,,,,,,,,,,,,,,,, -21800,0.02817418,0.04066369,,,,,,,,,,,,,,,,, -21900,0.04223357,0.03939387,,,,,,,,,,,,,,,,, -21912,,,0.9886640906333924,0.038734383881092,0.2241103178423511,0.9857713580131532,0.0484201423823833,0.1833587164211968,43793.0,0.984855055809021,0.0511500872671604,0.1821158694057571,43793.0,6974.460539340973,10126.857418060305,6974.460539340973,3150.928318500519,0.8655087947845459,0.0 -22000,0.08581121,0.03694602,,,,,,,,,,,,,,,,, -22100,0.04657907,0.034792114,,,,,,,,,,,,,,,,, -22200,0.07712487,0.04214164,,,,,,,,,,,,,,,,, -22300,0.13470772,0.043278474,,,,,,,,,,,,,,,,, -22400,0.12281482,0.041647468,,,,,,,,,,,,,,,,, -22500,0.045004282,0.03632132,,,,,,,,,,,,,,,,, -22600,0.11629588,0.04348416,,,,,,,,,,,,,,,,, -22666,,,0.9887723326683044,0.0386015810072422,0.2117687081291772,0.98566335439682,0.0487646721303463,0.1744092576389312,43793.0,0.984789788722992,0.0514403767883777,0.1796530308851342,43793.0,7214.716456651688,10468.675518989565,7214.716456651688,3252.439876317978,0.89583420753479,0.0 -22700,0.065353855,0.043448295,,,,,,,,,,,,,,,,, -22800,0.044026796,0.04291797,,,,,,,,,,,,,,,,, -22900,0.047142725,0.037973646,,,,,,,,,,,,,,,,, -23000,0.084646806,0.041555293,,,,,,,,,,,,,,,,, -23100,0.082013056,0.039236158,,,,,,,,,,,,,,,,, -23200,0.025717914,0.034733597,,,,,,,,,,,,,,,,, -23300,0.077169284,0.041464698,,,,,,,,,,,,,,,,, -23400,0.043555655,0.04151667,,,,,,,,,,,,,,,,, -23426,,,0.9883869886398317,0.0401140339672565,0.2083748797725441,0.985582947731018,0.0504753105342388,0.1814963297205611,43793.0,0.9846773147583008,0.0538868382573127,0.1822086485298938,43793.0,7454.776722192764,10813.931846141815,7454.776722192764,3357.5840377807617,0.9260289669036864,0.0 -23500,0.09534104,0.03842884,,,,,,,,,,,,,,,,, -23600,0.049636256,0.03950212,,,,,,,,,,,,,,,,, -23700,0.051653907,0.039999362,,,,,,,,,,,,,,,,, -23800,0.053553917,0.037152294,,,,,,,,,,,,,,,,, -23900,0.030515064,0.040697824,,,,,,,,,,,,,,,,, -24000,0.06537342,0.04311029,,,,,,,,,,,,,,,,, -24100,0.043076638,0.035886995,,,,,,,,,,,,,,,,, -24182,,,0.9885193705558776,0.0394197851419448,0.2086739845483023,0.985595166683197,0.0494046621024608,0.1779097963337989,43793.0,0.9847131371498108,0.0524004660546779,0.1733469909451849,43793.0,7694.763477563858,11158.016758441923,7694.763477563858,3461.6296710968018,0.956789255142212,0.0 -24200,0.058859207,0.042039193,,,,,,,,,,,,,,,,, -24300,0.04940217,0.03868363,,,,,,,,,,,,,,,,, -24400,0.061998177,0.037186928,,,,,,,,,,,,,,,,, -24500,0.059323598,0.039274525,,,,,,,,,,,,,,,,, -24600,0.039731283,0.03841286,,,,,,,,,,,,,,,,, -24700,0.06447333,0.044743463,,,,,,,,,,,,,,,,, -24800,0.12162549,0.03564579,,,,,,,,,,,,,,,,, -24900,0.041090257,0.03465191,,,,,,,,,,,,,,,,, -24932,,,0.9886791110038756,0.0391275435686111,0.2061097173095994,0.9856958389282228,0.0496935658156871,0.1742417341431321,43793.0,0.9847312569618224,0.0529902204871177,0.173159771894493,43793.0,7935.051184654236,11503.084161281586,7935.051184654236,3566.3573133945465,0.987828016281128,0.0 -25000,0.062448706,0.035505123,,,,,,,,,,,,,,,,, -25100,0.057789464,0.0404611,,,,,,,,,,,,,,,,, -25200,0.03607652,0.035755802,,,,,,,,,,,,,,,,, -25300,0.087900974,0.041645397,,,,,,,,,,,,,,,,, -25400,0.07339945,0.035158645,,,,,,,,,,,,,,,,, -25500,0.047227476,0.039726097,,,,,,,,,,,,,,,,, -25600,0.045716543,0.04061769,,,,,,,,,,,,,,,,, -25691,,,0.988594651222229,0.0393481999635696,0.2041383473869536,0.9856219291687012,0.0499823614954948,0.1795545772703834,43793.0,0.9847291111946106,0.0530240833759307,0.1790258355233416,43793.0,8175.208832502365,11846.830078840256,8175.208832502365,3669.8920063972473,1.0200772285461426,0.0 -25700,0.06346593,0.038102295,,,,,,,,,,,,,,,,, -25800,0.04072734,0.039262637,,,,,,,,,,,,,,,,, -25900,0.046476766,0.04150639,,,,,,,,,,,,,,,,, -26000,0.03950496,0.03618564,,,,,,,,,,,,,,,,, -26100,0.04410787,0.036288604,,,,,,,,,,,,,,,,, -26200,0.04122809,0.039019126,,,,,,,,,,,,,,,,, -26300,0.060766086,0.041626863,,,,,,,,,,,,,,,,, -26400,0.039277833,0.036875065,,,,,,,,,,,,,,,,, -26445,,,0.9887518882751464,0.0389961153268814,0.2137159547440802,0.9856349229812622,0.0487174801528453,0.177908458512726,43793.0,0.9847640991210938,0.0514288060367107,0.1743120148980194,43793.0,8415.195302963257,12192.255960464478,8415.195302963257,3775.2796547412872,1.051387071609497,0.0 -26500,0.05537757,0.037397567,,,,,,,,,,,,,,,,, -26600,0.054034326,0.040282346,,,,,,,,,,,,,,,,, -26700,0.030429902,0.038757235,,,,,,,,,,,,,,,,, -26800,0.06380728,0.040842324,,,,,,,,,,,,,,,,, -26900,0.08870641,0.03626752,,,,,,,,,,,,,,,,, -27000,0.07829399,0.042598397,,,,,,,,,,,,,,,,, -27100,0.049577236,0.03970138,,,,,,,,,,,,,,,,, -27195,,,0.9886566400527954,0.0389818735420703,0.2184096719692399,0.985464870929718,0.0490479469299316,0.1820737219027982,43793.0,0.9845766425132751,0.0518827773630619,0.1705559020023737,43793.0,8655.201495409012,12542.291038513184,8655.201495409012,3885.255741834641,1.0838682651519775,0.0 -27200,0.061807483,0.037239213,,,,,,,,,,,,,,,,, -27300,0.02644829,0.039737567,,,,,,,,,,,,,,,,, -27400,0.038816724,0.040429253,,,,,,,,,,,,,,,,, -27500,0.034227863,0.034722522,,,,,,,,,,,,,,,,, -27600,0.054958697,0.038095467,,,,,,,,,,,,,,,,, -27700,0.068046756,0.041747253,,,,,,,,,,,,,,,,, -27800,0.035938058,0.04098355,,,,,,,,,,,,,,,,, -27900,0.04709104,0.03484279,,,,,,,,,,,,,,,,, -27950,,,0.9887489676475524,0.038225021213293,0.2238591435715243,0.9858261346817015,0.0486630946397781,0.1828395677585086,43793.0,0.9847881197929382,0.0518213957548141,0.1771507754449422,43793.0,8895.445118188858,12888.176303863524,8895.445118188858,3990.845261335373,1.1152944564819336,0.0 -28000,0.05258035,0.03922547,,,,,,,,,,,,,,,,, -28100,0.052569333,0.041823503,,,,,,,,,,,,,,,,, -28200,0.038921617,0.037880566,,,,,,,,,,,,,,,,, -28300,0.04545753,0.03287159,,,,,,,,,,,,,,,,, -28400,0.054874506,0.04015042,,,,,,,,,,,,,,,,, -28500,0.09592654,0.040909715,,,,,,,,,,,,,,,,, -28600,0.0641797,0.046070118,,,,,,,,,,,,,,,,, -28700,0.0288929,0.03546646,,,,,,,,,,,,,,,,, -28714,,,0.988622546195984,0.0386340096592903,0.2167209401037269,0.9856897592544556,0.0489758588373661,0.1747642867575387,43793.0,0.9847543835639954,0.0520633272826671,0.1759829370541633,43793.0,9135.70709347725,13232.066885709764,9135.70709347725,4094.420918226242,1.1478347778320312,0.0 -28800,0.03167557,0.036401033,,,,,,,,,,,,,,,,, -28900,0.09444103,0.042107265,,,,,,,,,,,,,,,,, -29000,0.03070179,0.038448703,,,,,,,,,,,,,,,,, -29100,0.040433116,0.03722565,,,,,,,,,,,,,,,,, -29200,0.088432975,0.04101323,,,,,,,,,,,,,,,,, -29300,0.051704735,0.04162486,,,,,,,,,,,,,,,,, -29400,0.042071432,0.041696105,,,,,,,,,,,,,,,,, -29464,,,0.989011526107788,0.0376570969820022,0.2281551208005285,0.9858322143554688,0.0482140444219112,0.18028048993088,43793.0,0.9848394989967346,0.0512321665883064,0.1788149519418192,43793.0,9375.913627147676,13577.565557718275,9375.913627147676,4199.661685228348,1.1791231632232666,0.0 -29500,0.082451634,0.040144484,,,,,,,,,,,,,,,,, -29600,0.0785209,0.041679975,,,,,,,,,,,,,,,,, -29700,0.035799667,0.036483943,,,,,,,,,,,,,,,,, -29800,0.05754511,0.039988406,,,,,,,,,,,,,,,,, -29900,0.06472755,0.038077135,,,,,,,,,,,,,,,,, -30000,0.091336526,0.043269936,,,,,,,,,,,,,,,,, -30100,0.058260523,0.039895646,,,,,,,,,,,,,,,,, -30200,0.06907852,0.037294507,,,,,,,,,,,,,,,,, -30214,,,0.988719403743744,0.0386365242302417,0.213871037985582,0.9857904314994812,0.0485792160034179,0.1853121992215218,43793.0,0.984809160232544,0.0517909489572048,0.1794148969781187,43793.0,9615.896932125092,13921.591437339785,9615.896932125092,4303.645714521408,1.2145650386810305,0.0 -30300,0.031860836,0.04028156,,,,,,,,,,,,,,,,, -30400,0.09158616,0.037178677,,,,,,,,,,,,,,,,, -30500,0.15278013,0.038055282,,,,,,,,,,,,,,,,, -30600,0.051373646,0.043337945,,,,,,,,,,,,,,,,, -30700,0.030980118,0.03604636,,,,,,,,,,,,,,,,, -30800,0.04376871,0.04104145,,,,,,,,,,,,,,,,, -30900,0.052625734,0.03945468,,,,,,,,,,,,,,,,, -30969,,,0.988382875919342,0.0396341979503631,0.2193603036571355,0.9854400753974916,0.0501162000000476,0.1812246976637882,43793.0,0.9845911860466005,0.0530778802931308,0.1786106783703574,43793.0,9855.93932056427,14264.628145694733,9855.93932056427,4406.588081121445,1.246260166168213,0.0 -31000,0.03214782,0.03714781,,,,,,,,,,,,,,,,, -31100,0.0727307,0.03933813,,,,,,,,,,,,,,,,, -31200,0.049463294,0.04061721,,,,,,,,,,,,,,,,, -31300,0.044039298,0.036580052,,,,,,,,,,,,,,,,, -31400,0.042529177,0.035118613,,,,,,,,,,,,,,,,, -31500,0.030824693,0.03374614,,,,,,,,,,,,,,,,, -31600,0.05330365,0.042291842,,,,,,,,,,,,,,,,, -31700,0.09587931,0.041775033,,,,,,,,,,,,,,,,, -31722,,,0.9886970520019532,0.0385732203722,0.2138601955226267,0.9858525395393372,0.0481051616370677,0.185948414920223,43793.0,0.984890878200531,0.050958689302206,0.1822194739907897,43793.0,10096.182999134064,14613.498676538467,10096.182999134064,4515.162441253662,1.278688669204712,0.0 -31800,0.04077042,0.039945785,,,,,,,,,,,,,,,,, -31900,0.040489204,0.037003785,,,,,,,,,,,,,,,,, -32000,0.043638986,0.041260466,,,,,,,,,,,,,,,,, -32100,0.078396186,0.03684258,,,,,,,,,,,,,,,,, -32200,0.058580756,0.039575737,,,,,,,,,,,,,,,,, -32300,0.043986727,0.03632179,,,,,,,,,,,,,,,,, -32400,0.08980411,0.041741826,,,,,,,,,,,,,,,,, -32467,,,0.988766312599182,0.0382795669138431,0.2245607374487249,0.9856763482093812,0.0484305880963802,0.1850852054221676,43793.0,0.9847211241722108,0.0515131317079067,0.1742848635781939,43793.0,10336.253832101822,14956.602952480316,10336.253832101822,4618.139933824539,1.3119385242462158,0.0 -32500,0.054376688,0.03672312,,,,,,,,,,,,,,,,, -32600,0.049831472,0.041447017,,,,,,,,,,,,,,,,, -32700,0.061235007,0.036511574,,,,,,,,,,,,,,,,, -32800,0.058693573,0.038516767,,,,,,,,,,,,,,,,, -32900,0.030152926,0.036831703,,,,,,,,,,,,,,,,, -33000,0.045972753,0.034971368,,,,,,,,,,,,,,,,, -33100,0.038145743,0.03864258,,,,,,,,,,,,,,,,, -33200,0.08089564,0.038093563,,,,,,,,,,,,,,,,, -33217,,,0.989039421081543,0.0377273596823215,0.2320768232391687,0.985966980457306,0.047953937202692,0.1985244284184877,43793.0,0.9849864840507508,0.0509848780930042,0.1873119050115221,43793.0,10576.26902961731,15300.3434009552,10576.26902961731,4721.810349941254,1.3457691669464111,0.0 -33300,0.050950114,0.037789054,,,,,,,,,,,,,,,,, -33400,0.07454111,0.0402078,,,,,,,,,,,,,,,,, -33500,0.05095385,0.035549425,,,,,,,,,,,,,,,,, -33600,0.044144142,0.034221705,,,,,,,,,,,,,,,,, -33700,0.04579909,0.034017242,,,,,,,,,,,,,,,,, -33800,0.04130906,0.036355752,,,,,,,,,,,,,,,,, -33900,0.07293234,0.04023449,,,,,,,,,,,,,,,,, -33971,,,0.9889031052589417,0.0379895083606243,0.2294375844036454,0.9858683347702026,0.048191137611866,0.1919149423802982,43793.0,0.9849759340286256,0.0510173439979553,0.1896981083265152,43793.0,10816.320188999176,15644.702628612518,10816.320188999176,4826.063565731049,1.380265235900879,0.0 -34000,0.06305722,0.04152864,,,,,,,,,,,,,,,,, -34100,0.0919395,0.03826996,,,,,,,,,,,,,,,,, -34200,0.04962274,0.036437307,,,,,,,,,,,,,,,,, -34300,0.08814435,0.03436003,,,,,,,,,,,,,,,,, -34400,0.120345235,0.038343307,,,,,,,,,,,,,,,,, -34500,0.041091952,0.03637251,,,,,,,,,,,,,,,,, -34600,0.079551384,0.033456262,,,,,,,,,,,,,,,,, -34700,0.068873055,0.037798353,,,,,,,,,,,,,,,,, -34728,,,0.9889699220657348,0.0376724116504192,0.2294512134532897,0.9858846068382264,0.0480569601058959,0.1883638989656708,43793.0,0.9849582314491272,0.0511214435100555,0.1860170698764786,43793.0,11056.37366938591,15991.2295897007,11056.37366938591,4932.482318401337,1.4145681858062744,0.0 -34800,0.06780557,0.038120303,,,,,,,,,,,,,,,,, -34900,0.07094199,0.039630484,,,,,,,,,,,,,,,,, -35000,0.044699825,0.039234634,,,,,,,,,,,,,,,,, -35100,0.051714588,0.03932482,,,,,,,,,,,,,,,,, -35200,0.07241295,0.03982641,,,,,,,,,,,,,,,,, -35300,0.06419403,0.03678305,,,,,,,,,,,,,,,,, -35400,0.055557165,0.039235566,,,,,,,,,,,,,,,,, -35491,,,0.9889968633651732,0.0373593904078006,0.2480038810268567,0.9859061241149902,0.0479318015277385,0.1854400662467406,43793.0,0.9849721789360046,0.0509317182004451,0.1839866549682523,43793.0,11296.35082745552,16336.828307151794,11296.35082745552,5038.0507435798645,1.4474399089813232,0.0 -35500,0.1308591,0.04389665,,,,,,,,,,,,,,,,, -35600,0.037092187,0.036389228,,,,,,,,,,,,,,,,, -35700,0.09065026,0.03753523,,,,,,,,,,,,,,,,, -35800,0.046916302,0.04209832,,,,,,,,,,,,,,,,, -35900,0.0575873,0.04053765,,,,,,,,,,,,,,,,, -36000,0.06773646,0.036212716,,,,,,,,,,,,,,,,, -36100,0.106317796,0.03687336,,,,,,,,,,,,,,,,, -36200,0.043118387,0.038106527,,,,,,,,,,,,,,,,, -36246,,,0.9889400601387024,0.0376297198235988,0.2289945929014396,0.9858976006507874,0.0478270538151264,0.1902274000931431,43793.0,0.9849587082862854,0.0507874898612499,0.1820291417245847,43793.0,11536.564465284348,16682.931549072266,11536.564465284348,5143.885582208633,1.4814918041229248,0.0 -36300,0.06876323,0.041075215,,,,,,,,,,,,,,,,, -36400,0.044999093,0.036030855,,,,,,,,,,,,,,,,, -36500,0.06701254,0.04268085,,,,,,,,,,,,,,,,, -36600,0.095343456,0.041165918,,,,,,,,,,,,,,,,, -36700,0.12141049,0.04359458,,,,,,,,,,,,,,,,, -36800,0.042047374,0.036369424,,,,,,,,,,,,,,,,, -36900,0.06344489,0.03147257,,,,,,,,,,,,,,,,, -36989,,,0.9891141057014464,0.0372612811625003,0.2363792503926179,0.9858058094978333,0.0480333343148231,0.185534120927033,43793.0,0.984892964363098,0.0507130473852157,0.1814267991041178,43793.0,11776.666022777556,17025.721967220306,11776.666022777556,5246.514811038971,1.5184621810913086,0.0 -37000,0.07778537,0.03810507,,,,,,,,,,,,,,,,, -37100,0.07488044,0.045400012,,,,,,,,,,,,,,,,, -37200,0.10408187,0.040029727,,,,,,,,,,,,,,,,, -37300,0.053424332,0.037254322,,,,,,,,,,,,,,,,, -37400,0.047452807,0.03626104,,,,,,,,,,,,,,,,, -37500,0.04204356,0.03981709,,,,,,,,,,,,,,,,, -37600,0.08718571,0.04099439,,,,,,,,,,,,,,,,, -37700,0.060687877,0.04071203,,,,,,,,,,,,,,,,, -37742,,,0.9889169335365297,0.0378450527787208,0.236725344213282,0.9857680797576904,0.0481749102473259,0.1899463025558951,43793.0,0.9848238825798036,0.0511325113475322,0.1817323313886335,43793.0,12016.776044130323,17369.446642637253,12016.776044130323,5350.075026988983,1.552298069000244,0.0 -37800,0.053970642,0.03639519,,,,,,,,,,,,,,,,, -37900,0.050309666,0.0382793,,,,,,,,,,,,,,,,, -38000,0.042183828,0.041448347,,,,,,,,,,,,,,,,, -38100,0.0913216,0.0389948,,,,,,,,,,,,,,,,, -38200,0.13839328,0.041387238,,,,,,,,,,,,,,,,, -38300,0.042183958,0.036949743,,,,,,,,,,,,,,,,, -38400,0.05874053,0.040313683,,,,,,,,,,,,,,,,, -38492,,,0.9891027808189392,0.0373391807079315,0.2300288003892227,0.9859219193458556,0.0479280315339565,0.1933169156122331,43793.0,0.98500794172287,0.0509182810783386,0.1852429104902491,43793.0,12257.044536828997,17710.656693458557,12257.044536828997,5450.9605281353,1.5866341590881348,0.0 -38500,0.10040391,0.03503734,,,,,,,,,,,,,,,,, -38600,0.06265267,0.03660272,,,,,,,,,,,,,,,,, -38700,0.0478909,0.04054362,,,,,,,,,,,,,,,,, -38800,0.122555405,0.038171086,,,,,,,,,,,,,,,,, -38900,0.09849038,0.034390647,,,,,,,,,,,,,,,,, -39000,0.05565476,0.04055474,,,,,,,,,,,,,,,,, -39100,0.053922493,0.039180323,,,,,,,,,,,,,,,,, -39200,0.056616172,0.040519923,,,,,,,,,,,,,,,,, -39249,,,0.9887359738349916,0.0383051224052906,0.2391878173293923,0.9858935475349426,0.0485585033893585,0.1977071193646147,43793.0,0.9849451780319214,0.0517506897449493,0.1873046864618587,43793.0,12497.11010313034,18051.52519583702,12497.11010313034,5551.708795070648,1.6208436489105225,0.0 -39300,0.050751597,0.037637364,,,,,,,,,,,,,,,,, -39400,0.082607955,0.041187234,,,,,,,,,,,,,,,,, -39500,0.054054,0.04077241,,,,,,,,,,,,,,,,, -39600,0.06189016,0.035864096,,,,,,,,,,,,,,,,, -39700,0.10148612,0.039704476,,,,,,,,,,,,,,,,, -39800,0.07244802,0.04019714,,,,,,,,,,,,,,,,, -39900,0.0979971,0.03940461,,,,,,,,,,,,,,,,, -40000,0.07708019,0.035473894,,,,,,,,,,,,,,,,, -40010,,,0.9889850616455078,0.0373271107673645,0.2340222648566641,0.9859641790390016,0.0478209815919399,0.1918717627352083,43793.0,0.985103964805603,0.0504989437758922,0.1960561700278622,43793.0,12737.06660580635,18392.110761642456,12737.06660580635,5652.283429861069,1.6545679569244385,0.0 -40100,0.078570135,0.038869444,,,,,,,,,,,,,,,,, -40200,0.04695643,0.035828453,,,,,,,,,,,,,,,,, -40300,0.06647508,0.034422643,,,,,,,,,,,,,,,,, -40400,0.053717878,0.037723698,,,,,,,,,,,,,,,,, -40500,0.10316402,0.04254475,,,,,,,,,,,,,,,,, -40600,0.055655617,0.033359192,,,,,,,,,,,,,,,,, -40700,0.05530867,0.034497965,,,,,,,,,,,,,,,,, -40763,,,0.9889816045761108,0.0374413505196571,0.229548285313721,0.9859832525253296,0.0477732196450233,0.1986464619495477,43793.0,0.9850677847862244,0.0508184693753719,0.1908341660525139,43793.0,12977.28224658966,18735.97519636154,12977.28224658966,5755.875368833542,1.688731670379639,0.0 -40800,0.04530939,0.039738268,,,,,,,,,,,,,,,,, -40900,0.082223415,0.039517798,,,,,,,,,,,,,,,,, -41000,0.0629037,0.036612306,,,,,,,,,,,,,,,,, -41100,0.07077773,0.038441923,,,,,,,,,,,,,,,,, -41200,0.06632818,0.039392717,,,,,,,,,,,,,,,,, -41300,0.15557653,0.040084448,,,,,,,,,,,,,,,,, -41400,0.085731804,0.040143088,,,,,,,,,,,,,,,,, -41500,0.04384323,0.032669075,,,,,,,,,,,,,,,,, -41526,,,0.988987386226654,0.0376761332154274,0.2321438418027664,0.9857603907585144,0.0486207045614719,0.1892344774317401,43793.0,0.9847068190574646,0.0516904331743717,0.1825132729511454,43793.0,13217.424660682678,19081.43302559853,13217.424660682678,5861.133890390396,1.7246840000152588,0.0 -41600,0.08487473,0.037816055,,,,,,,,,,,,,,,,, -41700,0.04201666,0.039681938,,,,,,,,,,,,,,,,, -41800,0.09491496,0.04396543,,,,,,,,,,,,,,,,, -41900,0.050546017,0.038041666,,,,,,,,,,,,,,,,, -42000,0.059507918,0.03456336,,,,,,,,,,,,,,,,, -42100,0.097886235,0.038913317,,,,,,,,,,,,,,,,, -42200,0.088700764,0.034416683,,,,,,,,,,,,,,,,, -42287,,,0.9890965819358826,0.0369728617370128,0.2431629217687319,0.9859690070152284,0.0477113537490367,0.1947566100337512,43793.0,0.9851654767990112,0.0507156662642955,0.1907572109509011,43793.0,13457.660992622375,19429.478434562683,13457.660992622375,5968.8874089717865,1.759670972824097,0.0 -42300,0.07510166,0.03880483,,,,,,,,,,,,,,,,, -42400,0.10762885,0.037693843,,,,,,,,,,,,,,,,, -42500,0.07390707,0.03541603,,,,,,,,,,,,,,,,, -42600,0.04018928,0.036639024,,,,,,,,,,,,,,,,, -42700,0.09999475,0.037929524,,,,,,,,,,,,,,,,, -42800,0.046492673,0.036262628,,,,,,,,,,,,,,,,, -42900,0.08427526,0.035687085,,,,,,,,,,,,,,,,, -43000,0.101300456,0.036812022,,,,,,,,,,,,,,,,, -43042,,,0.98912513256073,0.0369981303811073,0.2451984672837403,0.9858740568161012,0.047635443508625,0.1916597128129191,43793.0,0.9849687814712524,0.0506608895957469,0.1921175012712416,43793.0,13697.648321390152,19772.52578687668,13697.648321390152,6071.891428232193,1.7949151992797852,0.0 -43100,0.06884285,0.034399927,,,,,,,,,,,,,,,,, -43200,0.07247327,0.03425435,,,,,,,,,,,,,,,,, -43300,0.05787938,0.03840704,,,,,,,,,,,,,,,,, -43400,0.043351576,0.036972158,,,,,,,,,,,,,,,,, -43500,0.06741275,0.039421685,,,,,,,,,,,,,,,,, -43600,0.043884445,0.032210562,,,,,,,,,,,,,,,,, -43700,0.08110839,0.03923841,,,,,,,,,,,,,,,,, -43800,0.082360394,0.039080765,,,,,,,,,,,,,,,,, -43802,,,0.989127278327942,0.0368660315871238,0.2467007561681576,0.9859893321990968,0.0476075746119022,0.1888054470829359,43793.0,0.9850454330444336,0.0504810027778148,0.1850773565321149,43793.0,13937.767841815948,20117.138426303864,13937.767841815948,6176.329439401627,1.829290151596069,0.0 -43900,0.0678288,0.03336181,,,,,,,,,,,,,,,,, -44000,0.04710519,0.041745245,,,,,,,,,,,,,,,,, -44100,0.05556551,0.03779666,,,,,,,,,,,,,,,,, -44200,0.078040786,0.043862145,,,,,,,,,,,,,,,,, -44300,0.10300016,0.041999314,,,,,,,,,,,,,,,,, -44400,0.08883015,0.038779907,,,,,,,,,,,,,,,,, -44500,0.0993169,0.033973772,,,,,,,,,,,,,,,,, -44565,,,0.9892011284828186,0.0366797894239425,0.2519868990290798,0.9859905242919922,0.0475577712059021,0.1938699084327157,43793.0,0.9850686192512512,0.0505295880138874,0.1911015084064031,43793.0,14177.993677854538,20458.01491880417,14177.993677854538,6276.924092531204,1.8644392490386963,0.0 -44600,0.048203815,0.037738126,,,,,,,,,,,,,,,,, -44700,0.07133037,0.033553492,,,,,,,,,,,,,,,,, -44800,0.17229514,0.033041548,,,,,,,,,,,,,,,,, -44900,0.050137315,0.037303027,,,,,,,,,,,,,,,,, -45000,0.059316117,0.037708912,,,,,,,,,,,,,,,,, -45100,0.0558201,0.038206924,,,,,,,,,,,,,,,,, -45200,0.06388513,0.03676624,,,,,,,,,,,,,,,,, -45300,0.06829893,0.034047656,,,,,,,,,,,,,,,,, -45324,,,0.9893069863319396,0.0367147289216518,0.2476210653071464,0.9859580397605896,0.0481480509042739,0.1963034593315598,43793.0,0.9850033521652222,0.0514786131680011,0.1892441479288406,43793.0,14418.119437456133,20803.637764453888,14418.119437456133,6382.364597320557,1.9001915454864504,0.0 -45400,0.053402625,0.036463153,,,,,,,,,,,,,,,,, -45500,0.10177172,0.035246193,,,,,,,,,,,,,,,,, -45600,0.17536588,0.042442743,,,,,,,,,,,,,,,,, -45700,0.10329857,0.03532694,,,,,,,,,,,,,,,,, -45800,0.08061862,0.04109561,,,,,,,,,,,,,,,,, -45900,0.08478044,0.040758986,,,,,,,,,,,,,,,,, -46000,0.06761689,0.035973847,,,,,,,,,,,,,,,,, -46075,,,0.989287257194519,0.0366210155189037,0.2516408284043377,0.986051857471466,0.0470076762139797,0.1985409708187414,43793.0,0.9850778579711914,0.0500235259532928,0.1901680219495883,43793.0,14658.144562959673,21149.18072462082,14658.144562959673,6487.826305150986,1.9359188079833984,0.0 -46100,0.094894186,0.034003858,,,,,,,,,,,,,,,,, -46200,0.07498278,0.03631167,,,,,,,,,,,,,,,,, -46300,0.07140657,0.040256828,,,,,,,,,,,,,,,,, -46400,0.08660296,0.03399878,,,,,,,,,,,,,,,,, -46500,0.0774605,0.032429617,,,,,,,,,,,,,,,,, -46600,0.1643994,0.03928159,,,,,,,,,,,,,,,,, -46700,0.05085037,0.03720213,,,,,,,,,,,,,,,,, -46800,0.0873009,0.037652787,,,,,,,,,,,,,,,,, -46830,,,0.9892902970314026,0.0364799313247203,0.2555223045112521,0.9860911965370178,0.0471407510340213,0.2065983850306616,43793.0,0.9850918054580688,0.0501007810235023,0.2014320869136368,43793.0,14898.28629231453,21493.333992242813,14898.28629231453,6591.78231549263,1.970759391784668,0.0 -46900,0.058793306,0.03418039,,,,,,,,,,,,,,,,, -47000,0.09029037,0.037302237,,,,,,,,,,,,,,,,, -47100,0.04776088,0.03970958,,,,,,,,,,,,,,,,, -47200,0.11948309,0.035068657,,,,,,,,,,,,,,,,, -47300,0.07544487,0.03784193,,,,,,,,,,,,,,,,, -47400,0.10128856,0.03723047,,,,,,,,,,,,,,,,, -47500,0.052741088,0.03727803,,,,,,,,,,,,,,,,, -47588,,,0.98917818069458,0.0369547009468078,0.2552370542538443,0.9860566854476928,0.0472060516476631,0.2016886836342544,43793.0,0.9851625561714172,0.0500448420643806,0.1926117407578925,43793.0,15138.5211622715,21844.761209726334,15138.5211622715,6702.919489622116,2.0053157806396484,0.0 -47600,0.057408493,0.03740253,,,,,,,,,,,,,,,,, -47700,0.057559077,0.04017863,,,,,,,,,,,,,,,,, -47800,0.06369436,0.036539566,,,,,,,,,,,,,,,,, -47900,0.10239058,0.039268106,,,,,,,,,,,,,,,,, -48000,0.07813861,0.035106424,,,,,,,,,,,,,,,,, -48100,0.077288784,0.029494898,,,,,,,,,,,,,,,,, -48200,0.05459712,0.034139674,,,,,,,,,,,,,,,,, -48300,0.045883477,0.03445307,,,,,,,,,,,,,,,,, -48349,,,0.9893152117729188,0.0363342836499214,0.2480764058432878,0.9860628247261048,0.0470212027430534,0.2059009357194049,43793.0,0.9851372838020324,0.049987506121397,0.1951242118193244,43793.0,15378.47325873375,22186.35283780098,15378.47325873375,6804.498358488083,2.044494390487671,0.0 -48400,0.06879067,0.04099751,,,,,,,,,,,,,,,,, -48500,0.066603415,0.033741787,,,,,,,,,,,,,,,,, -48600,0.0790299,0.034419384,,,,,,,,,,,,,,,,, -48700,0.10840682,0.040254693,,,,,,,,,,,,,,,,, -48800,0.08156468,0.032770704,,,,,,,,,,,,,,,,, -48900,0.04937214,0.034173634,,,,,,,,,,,,,,,,, -49000,0.09210476,0.03887587,,,,,,,,,,,,,,,,, -49100,0.05629255,0.03938893,,,,,,,,,,,,,,,,, -49112,,,0.9893460273742676,0.0361896976828575,0.2436113602084199,0.9861687421798706,0.0469433367252349,0.2054351799038135,43793.0,0.985241711139679,0.0499663092195987,0.1993356351891673,43793.0,15618.42730998993,22530.29949116707,15618.42730998993,6908.433473587036,2.0814638137817383,0.0 -49200,0.07400133,0.036275364,,,,,,,,,,,,,,,,, -49300,0.06209398,0.034847718,,,,,,,,,,,,,,,,, -49400,0.109659806,0.03796814,,,,,,,,,,,,,,,,, -49500,0.09324205,0.034455385,,,,,,,,,,,,,,,,, -49600,0.06851207,0.037310764,,,,,,,,,,,,,,,,, -49700,0.08054742,0.03427755,,,,,,,,,,,,,,,,, -49800,0.095844075,0.039518896,,,,,,,,,,,,,,,,, -49869,,,0.9894251823425292,0.0361647494137287,0.2601694948034501,0.9860376119613647,0.0469044111669063,0.2030483587220174,43793.0,0.9851536750793456,0.0498836562037467,0.2006272679957678,43793.0,15858.633188962936,22871.511667251587,15858.633188962936,7009.377001285553,2.12254285812378,0.0 -49900,0.060117837,0.033988137,,,,,,,,,,,,,,,,, -50000,0.09823198,0.03388679,,,,,,,,,,,,,,,,, -50100,0.13536146,0.03527366,,,,,,,,,,,,,,,,, -50200,0.1007689,0.03809582,,,,,,,,,,,,,,,,, -50300,0.057467233,0.037310947,,,,,,,,,,,,,,,,, -50400,0.09631065,0.038537547,,,,,,,,,,,,,,,,, -50500,0.101723745,0.03945625,,,,,,,,,,,,,,,,, -50600,0.067183256,0.03150046,,,,,,,,,,,,,,,,, -50631,,,0.9894543290138244,0.0357305705547332,0.2622053498185402,0.9860794544219972,0.0468679517507553,0.2063089989078978,43793.0,0.9852059483528136,0.0497694425284862,0.1960011431400151,43793.0,16098.7771692276,23215.639623880383,16098.7771692276,7113.305378198624,2.1576476097106934,0.0 -50700,0.09884878,0.039067075,,,,,,,,,,,,,,,,, -50800,0.07701715,0.03811775,,,,,,,,,,,,,,,,, -50900,0.086415365,0.03297973,,,,,,,,,,,,,,,,, -51000,0.105869696,0.0341913,,,,,,,,,,,,,,,,, -51100,0.08891066,0.036240123,,,,,,,,,,,,,,,,, -51200,0.06385046,0.03594457,,,,,,,,,,,,,,,,, -51300,0.08077623,0.04501958,,,,,,,,,,,,,,,,, -51381,,,0.9894623160362244,0.0354856625199317,0.2692625219377699,0.986139714717865,0.0471389181911945,0.206980863824063,43793.0,0.9851823449134828,0.050275906920433,0.2007918248456357,43793.0,16338.885677576063,23559.649918794632,16338.885677576063,7217.148921728134,2.195411443710327,0.0 -51400,0.059581153,0.036025528,,,,,,,,,,,,,,,,, -51500,0.05990199,0.037294652,,,,,,,,,,,,,,,,, -51600,0.048645556,0.03360431,,,,,,,,,,,,,,,,, -51700,0.09893583,0.035950694,,,,,,,,,,,,,,,,, -51800,0.0762755,0.040480644,,,,,,,,,,,,,,,,, -51900,0.08055358,0.03972532,,,,,,,,,,,,,,,,, -52000,0.11096079,0.0393028,,,,,,,,,,,,,,,,, -52100,0.121126644,0.041377183,,,,,,,,,,,,,,,,, -52137,,,0.9895649552345276,0.0351977571845054,0.274967294293767,0.9861720204353333,0.0469048433005809,0.2073715671016123,43793.0,0.9852589964866638,0.0498905703425407,0.1996205651212903,43793.0,16579.111583709717,23906.143963336945,16579.111583709717,7323.358339309692,2.233850240707397,0.0 -52200,0.08026306,0.033174228,,,,,,,,,,,,,,,,, -52300,0.072433345,0.033206742,,,,,,,,,,,,,,,,, -52400,0.098073885,0.03340547,,,,,,,,,,,,,,,,, -52500,0.07485335,0.037429918,,,,,,,,,,,,,,,,, -52600,0.12202112,0.041307714,,,,,,,,,,,,,,,,, -52700,0.12889896,0.04094633,,,,,,,,,,,,,,,,, -52800,0.11824646,0.038257394,,,,,,,,,,,,,,,,, -52887,,,0.9894105792045592,0.0356871746480464,0.2732275237291334,0.9862263798713684,0.0467871576547622,0.2090118597592896,43793.0,0.9852981567382812,0.0499219372868537,0.1996159199399781,43793.0,16819.154947042465,24251.066781044006,16819.154947042465,7428.175041437149,2.2763447761535645,0.0 -52900,0.06655844,0.03566573,,,,,,,,,,,,,,,,, -53000,0.07229091,0.03224707,,,,,,,,,,,,,,,,, -53100,0.09009435,0.03534586,,,,,,,,,,,,,,,,, -53200,0.11255095,0.03728711,,,,,,,,,,,,,,,,, -53300,0.11821115,0.03657775,,,,,,,,,,,,,,,,, -53400,0.07097037,0.032183085,,,,,,,,,,,,,,,,, -53500,0.10107166,0.037288204,,,,,,,,,,,,,,,,, -53600,0.09751576,0.03771775,,,,,,,,,,,,,,,,, -53642,,,0.9896687865257264,0.0350858755409717,0.2733144440364853,0.9860932230949402,0.0468956939876079,0.2068264701613324,43793.0,0.9852105379104614,0.0497577525675296,0.1973662601774869,43793.0,17059.26643562317,24593.81644487381,17059.26643562317,7530.755472898483,2.313117027282715,0.0 -53700,0.07723693,0.034719437,,,,,,,,,,,,,,,,, -53800,0.05380317,0.03585416,,,,,,,,,,,,,,,,, -53900,0.11872893,0.033531055,,,,,,,,,,,,,,,,, -54000,0.13720378,0.03876983,,,,,,,,,,,,,,,,, -54100,0.054057717,0.033403974,,,,,,,,,,,,,,,,, -54200,0.15549184,0.035688777,,,,,,,,,,,,,,,,, -54300,0.10683789,0.037440643,,,,,,,,,,,,,,,,, -54398,,,0.9894760251045228,0.0357528924942016,0.264047843704155,0.9861346483230592,0.0467527098953723,0.2083068753505612,43793.0,0.9852269887924194,0.0497096553444862,0.1957582689787017,43793.0,17299.366802215576,24937.78490638733,17299.366802215576,7634.566126823425,2.349884271621704,0.0 -54400,0.13030598,0.035305623,,,,,,,,,,,,,,,,, -54500,0.06531579,0.03475837,,,,,,,,,,,,,,,,, -54600,0.10265939,0.037503097,,,,,,,,,,,,,,,,, -54700,0.058189932,0.03641167,,,,,,,,,,,,,,,,, -54800,0.1104924,0.038544968,,,,,,,,,,,,,,,,, -54900,0.07895589,0.033869956,,,,,,,,,,,,,,,,, -55000,0.12523344,0.037837923,,,,,,,,,,,,,,,,, -55100,0.08052392,0.033418,,,,,,,,,,,,,,,,, -55156,,,0.9895650148391724,0.0352049171924591,0.2680131655883736,0.986196756362915,0.04677639529109,0.2084649784585847,43793.0,0.9852808713912964,0.0498327724635601,0.202013044800903,43793.0,17539.38094496727,25277.153856515884,17539.38094496727,7733.860890388489,2.3890788555145264,0.0 -55200,0.13143247,0.039759096,,,,,,,,,,,,,,,,, -55300,0.1050117,0.03330101,,,,,,,,,,,,,,,,, -55400,0.08645153,0.033650767,,,,,,,,,,,,,,,,, -55500,0.05541183,0.036886927,,,,,,,,,,,,,,,,, -55600,0.090461165,0.03218902,,,,,,,,,,,,,,,,, -55700,0.07347834,0.03574276,,,,,,,,,,,,,,,,, -55800,0.09851773,0.03426602,,,,,,,,,,,,,,,,, -55900,0.09785926,0.03367571,,,,,,,,,,,,,,,,, -55922,,,0.9894495010375975,0.0355528406798839,0.2755270016128907,0.9862328767776488,0.0465220957994461,0.2131809277834291,43793.0,0.9852501749992372,0.0495138205587863,0.2028329675264197,43793.0,17779.42640852928,25619.509373903275,17779.42640852928,7836.111313343048,2.427854061126709,0.0 -56000,0.06189941,0.034887902,,,,,,,,,,,,,,,,, -56100,0.070304856,0.031241052,,,,,,,,,,,,,,,,, -56200,0.14134301,0.036548223,,,,,,,,,,,,,,,,, -56300,0.067834504,0.0309086,,,,,,,,,,,,,,,,, -56400,0.07301896,0.029598305,,,,,,,,,,,,,,,,, -56500,0.11669783,0.033463154,,,,,,,,,,,,,,,,, -56600,0.18027094,0.03315049,,,,,,,,,,,,,,,,, -56673,,,0.9895038604736328,0.0352360047399997,0.2716795602529939,0.9862877130508424,0.0464650578796863,0.2110234043764053,43793.0,0.985371470451355,0.0493917502462863,0.2039670057834056,43793.0,18019.61223173141,25963.269825220108,18019.61223173141,7939.627541303635,2.465727567672729,0.0 -56700,0.12136771,0.03266934,,,,,,,,,,,,,,,,, -56800,0.08463855,0.039579898,,,,,,,,,,,,,,,,, -56900,0.08139758,0.035563,,,,,,,,,,,,,,,,, -57000,0.09039937,0.034884304,,,,,,,,,,,,,,,,, -57100,0.093897656,0.03468455,,,,,,,,,,,,,,,,, -57200,0.12028515,0.038138643,,,,,,,,,,,,,,,,, -57300,0.105492026,0.037002984,,,,,,,,,,,,,,,,, -57400,0.0953722,0.03387757,,,,,,,,,,,,,,,,, -57429,,,0.9898420572280884,0.0347669124603271,0.2811684262688596,0.9861756563186646,0.0465973168611526,0.2125133194964228,43793.0,0.985279619693756,0.0494225993752479,0.2015535494588887,43793.0,18259.605164289474,26305.26191139221,18259.605164289474,8041.568267345428,2.5035622119903564,0.0 -57500,0.06678367,0.033528805,,,,,,,,,,,,,,,,, -57600,0.11836406,0.041545805,,,,,,,,,,,,,,,,, -57700,0.08160272,0.03173999,,,,,,,,,,,,,,,,, -57800,0.082254365,0.03591725,,,,,,,,,,,,,,,,, -57900,0.07603487,0.03156565,,,,,,,,,,,,,,,,, -58000,0.06316454,0.035390377,,,,,,,,,,,,,,,,, -58100,0.14292206,0.03475576,,,,,,,,,,,,,,,,, -58120,,,,,,,,,,,,,,18477.279756069183,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 7c9ffaa0b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -107.60885667800903,0.0,13.114962816238403,1,0,13.114962816238403,0.5214918851852417,0.7347381114959717,0.0268206381238046,43793,120.72387218475342,0.5324588418006897,0.7277070879936218,0.0217586781684841,0.5230699777603149,0.7331878542900085,0.0254813053980928,43793 -213.7212650775909,0.2872607707977295,253.0232331752777,760,0,253.0232331752777,0.983195960521698,0.0661912709474563,0.0479848150834761,43793,467.0528705120087,0.9868366122245787,0.0525698475539684,0.0459173332919591,0.9841719269752502,0.0628870725631713,0.0457983544387885,43793 -319.4787435531616,0.3140983581542969,493.0518238544464,1505,0,493.0518238544464,0.9834386110305786,0.0608416348695755,0.0933789730938457,43793,812.8888652324677,0.986958384513855,0.0480176247656345,0.0938309508173274,0.9844183325767516,0.057403702288866,0.0944905234983526,43793 -420.9634261131287,0.3397960662841797,733.1513078212738,2258,0,733.1513078212738,0.98393052816391,0.05640734359622,0.1286533700685239,43793,1154.5208294391632,0.987659990787506,0.0437299534678459,0.1344494258374315,0.9849497079849244,0.053313635289669,0.1305046429838764,43793 -520.0432939529419,0.3667905330657959,973.3656969070436,3011,0,973.3656969070436,0.9842051863670348,0.0541716180741786,0.1458128469841821,43793,1493.8625662326813,0.988203763961792,0.0414475202560424,0.1611080446191192,0.9852355122566224,0.0509338229894638,0.15108787353642,43793 -628.8051879405975,0.3951599597930908,1213.6298730373385,3759,0,1213.6298730373385,0.9845345616340636,0.0524315349757671,0.1698998980243238,43793,1842.9375042915344,0.9883896112442015,0.0404955521225929,0.1966472162561015,0.9854575395584106,0.0496173612773418,0.1722477732427589,43793 -730.8336308002472,0.4234771728515625,1453.6540446281433,4512,0,1453.6540446281433,0.984804093837738,0.0515812933444976,0.1859877344371198,43793,2185.0393300056458,0.9884734749794006,0.0396814867854118,0.2104161173825414,0.985728681087494,0.0488488934934139,0.1853613796800635,43793 -836.2591185569763,0.4546363353729248,1693.849843502045,5260,0,1693.849843502045,0.9851145148277284,0.0497923754155635,0.2028064962651752,43793,2530.7144179344177,0.9889073967933656,0.0377699360251426,0.240450900537608,0.9860051274299622,0.0472903698682785,0.2011653371217038,43793 -941.2021589279176,0.4842491149902344,1933.864909410477,6017,0,1933.864909410477,0.9850445985794068,0.0499980822205543,0.2039791274407398,43793,2875.7232887744904,0.9888168573379515,0.037816647440195,0.2434065181146786,0.9859337210655212,0.0472665354609489,0.2041797816516377,43793 -1044.834265708923,0.5117506980895996,2173.9619414806366,6766,0,2173.9619414806366,0.9852948188781738,0.0491566210985183,0.2295385995796318,43793,3219.5011546611786,0.9892390966415404,0.0362072438001632,0.271036118425679,0.9861992001533508,0.046460848301649,0.2200321466350437,43793 -1148.9634318351746,0.5409078598022461,2413.9666571617126,7516,0,2413.9666571617126,0.9854813814163208,0.0492261908948421,0.2348783154287064,43793,3563.6852464675903,0.9895783066749572,0.0349873229861259,0.2932618939418534,0.9863753914833068,0.0464226454496383,0.2287308422014867,43793 -1246.890321969986,0.5691602230072021,2653.955867290497,8266,0,2653.955867290497,0.9855201244354248,0.0480259098112583,0.2359621448547312,43793,3901.649888753891,0.9897823333740234,0.0342083573341369,0.3195737883451471,0.9863818883895874,0.0454126447439193,0.2325608492496531,43793 -1353.1766684055328,0.5972564220428467,2894.023741006851,9015,0,2894.023741006851,0.9856511354446412,0.0477934069931507,0.2427388808261531,43793,4248.052522659302,0.990003228187561,0.0334638282656669,0.3346555605107079,0.9865231513977052,0.0452743992209434,0.2372158837428065,43793 -1460.478672027588,0.6274769306182861,3134.196201562881,9761,0,3134.196201562881,0.9857004284858704,0.0474113151431083,0.2414298978085044,43793,4595.578535079956,0.9901729226112366,0.0329452231526374,0.3521926542739725,0.9865406155586244,0.0448672249913215,0.2515968842157506,43793 -1567.1986813545227,0.6583380699157715,3374.321034193039,10505,0,3374.321034193039,0.985818326473236,0.0471859909594059,0.2474191243238686,43793,4942.475167989731,0.9900105595588684,0.0333098098635673,0.3414721915273255,0.9866453409194946,0.044458620250225,0.2510299167733832,43793 -1673.4594593048096,0.6867432594299316,3614.390597343445,11258,0,3614.390597343445,0.9857791662216188,0.0471867583692073,0.2567930607884815,43793,5288.855046987534,0.990221619606018,0.0325211398303508,0.3451225530368046,0.986581563949585,0.0445379801094532,0.2583005286632394,43793 -1778.9284162521362,0.7189178466796875,3854.4456446170807,12007,0,3854.4456446170807,0.9858141541481018,0.0472727715969085,0.2503482735494056,43793,5634.432926416397,0.9903555512428284,0.0322158448398113,0.3548928727455357,0.986707866191864,0.0444315969944,0.2554358455220814,43793 -1877.9654395580287,0.7478671073913574,4094.401614904404,12768,0,4094.401614904404,0.985888659954071,0.0471827387809753,0.2574348691223886,43793,5973.4755423069,0.9902983903884888,0.0321623496711254,0.3728056786919886,0.9867314100265504,0.0442637987434864,0.2605304608129135,43793 -1978.278470039368,0.7787857055664062,4334.481739997864,13526,0,4334.481739997864,0.9858831763267516,0.0473159775137901,0.2539916084768826,43793,6313.919853925705,0.9905072450637816,0.0311836674809455,0.3852428672374816,0.986772358417511,0.0443389266729354,0.2648099308155049,43793 -2082.57606124878,0.807380199432373,4574.704688310623,14281,0,4574.704688310623,0.9858419299125672,0.0471290312707424,0.2575491563365366,43793,6658.48993730545,0.9905017018318176,0.0312279406934976,0.3851849660303862,0.9866887331008912,0.0443001464009285,0.2595797861296848,43793 -2182.466774225235,0.8405594825744629,4814.931003808975,15031,0,4814.931003808975,0.9859139323234558,0.0474883429706096,0.2540615756671235,43793,6998.6618309021,0.9906373023986816,0.0307508446276187,0.4132004282045805,0.9867650866508484,0.0446204729378223,0.2628655068956396,43793 -2286.075028419494,0.8707151412963867,5054.909048080444,15795,0,5054.909048080444,0.9858680367469788,0.0467845201492309,0.2557980212280484,43793,7342.298298835754,0.990784227848053,0.0301042962819337,0.4046015524426429,0.9867382645606996,0.044171217828989,0.2618256099036918,43793 -2390.076048135757,0.8997907638549805,5294.891679048538,16557,0,5294.891679048538,0.9858490824699402,0.0468463972210884,0.2617767080606321,43793,7686.331763267517,0.9911503791809082,0.029118113219738,0.4419770802823985,0.986614465713501,0.0442926213145256,0.2690583731034517,43793 -2490.6851439476013,0.932380437850952,5534.965999603272,17305,0,5534.965999603272,0.9858777523040771,0.0467349551618099,0.2607631513044034,43793,8027.069473028183,0.9909698367118835,0.0296294633299112,0.425239852981252,0.9866148829460144,0.0441667363047599,0.2680382857825597,43793 -2595.902805566788,0.9617249965667723,5775.012037992477,18067,0,5775.012037992477,0.9858953952789308,0.0467287376523017,0.2647181635576729,43793,8372.382522583008,0.99085795879364,0.0300395190715789,0.4119247732666293,0.9867743849754332,0.0440140813589096,0.2692154002615283,43793 -2695.526474237442,0.9920079708099364,6014.981848716736,18826,0,6014.981848716736,0.985905945301056,0.0468245893716812,0.261026378969448,43793,8712.026437044144,0.9908071160316468,0.0303335282951593,0.4082732006218971,0.986763834953308,0.0440635085105896,0.2648081561789626,43793 -2796.600333929062,1.0257935523986816,6255.140200138092,19588,0,6255.140200138092,0.9860677123069764,0.0467986539006233,0.269728862963543,43793,9053.31315279007,0.990985095500946,0.0293735228478908,0.42571662301725,0.986893355846405,0.0441452153027057,0.2726999962014126,43793 -2900.5887792110443,1.0550963878631592,6495.229241847992,20343,0,6495.229241847992,0.9860925674438475,0.0466288626194,0.2667238021460736,43793,9397.442924976349,0.9908535480499268,0.0297939609736204,0.4299498885823914,0.9868730306625366,0.0438564978539943,0.27432950631949,43793 -3003.7433593273163,1.089904546737671,6735.494882106781,21098,0,6735.494882106781,0.9860213398933412,0.0465683303773403,0.2609999197800985,43793,9740.919793367386,0.9910489320755004,0.0291868932545185,0.4427940285589564,0.9868564009666444,0.0438627004623413,0.2687042846949644,43793 -3106.980573654175,1.1237335205078125,6975.677932262421,21856,0,6975.677932262421,0.9861317276954652,0.0474246628582477,0.2671671664715521,43793,10084.39429950714,0.9909344911575316,0.0291950646787881,0.4393534144632743,0.9868978261947632,0.0446341596543788,0.2735190565901901,43793 -3208.3332257270813,1.1560900211334229,7215.721256017685,22612,0,7215.721256017685,0.9859792590141296,0.0466558374464511,0.2688925711781614,43793,10425.84306025505,0.9911161065101624,0.0289258304983377,0.4461390763695887,0.9867297410964966,0.0439446568489074,0.2784422440583035,43793 -3310.157596349716,1.1869611740112305,7455.825122117996,23366,0,7455.825122117996,0.9861093759536744,0.0466336794197559,0.2682321435300443,43793,10767.822446346285,0.9914578199386596,0.0278907660394907,0.4627701374352354,0.9868608713150024,0.0440641902387142,0.276925187531418,43793 -3409.015555858612,1.2186834812164309,7695.909645080566,24123,0,7695.909645080566,0.9861299991607666,0.0465262271463871,0.2760676809335625,43793,11106.817147493362,0.9916008114814758,0.0272237285971641,0.4873757299629083,0.9869104027748108,0.0439563207328319,0.2830847361580403,43793 -3511.65873837471,1.2492575645446775,7936.176169872284,24882,0,7936.176169872284,0.9860736131668092,0.0468006432056427,0.2785355949087478,43793,11449.778022766111,0.9916938543319702,0.026998370885849,0.4925730935016073,0.9868324398994446,0.0440956726670265,0.2810012658860785,43793 -3609.806525945664,1.2795326709747314,8176.215074539185,25646,0,8176.215074539185,0.986088752746582,0.0467182286083698,0.275252788624737,43793,11788.014889717102,0.9915238618850708,0.0275689717382192,0.4713439378127574,0.9869733452796936,0.0439281985163688,0.2816985678081601,43793 -3711.346279144287,1.3122563362121582,8416.223066806793,26404,0,8416.223066806793,0.9861670732498168,0.0467212721705436,0.272193655521218,43793,12129.615074634552,0.9913026094436646,0.0280513260513544,0.4540246569703544,0.9868937730789183,0.0440740473568439,0.2784015307768114,43793 -3815.66153049469,1.343522071838379,8656.387609004974,27167,0,8656.387609004974,0.9861894249916076,0.0466781705617904,0.2716197370065153,43793,12474.14634013176,0.9913366436958312,0.0279964916408061,0.4512330298512755,0.9870614409446716,0.0438373424112796,0.2824528369850656,43793 -3915.033325195313,1.374058961868286,8896.527443885803,27924,0,8896.527443885803,0.986270308494568,0.0467647463083267,0.2791487319363641,43793,12813.708843708038,0.9912831783294678,0.0282191280275583,0.459710992591313,0.9870139360427856,0.0440753027796745,0.2801922874370537,43793 -4021.181110858917,1.4105889797210691,9136.534759044647,28676,0,9136.534759044647,0.9860512614250184,0.0466743782162666,0.2696900439508503,43793,13159.920971632004,0.9914658069610596,0.0276060681790113,0.4804013392235434,0.9868357181549072,0.043917428702116,0.2829397828021821,43793 -4123.922771692276,1.4421751499176023,9376.566728830338,29429,0,9376.566728830338,0.9862012267112732,0.0470006801187992,0.27495089302911,43793,13502.746740341188,0.9914921522140504,0.0274810884147882,0.4786408746212388,0.98698753118515,0.0441337078809738,0.280014814493845,43793 -4224.7899651527405,1.474400281906128,9616.63538479805,30193,0,9616.63538479805,0.9862092137336732,0.0464126951992511,0.2754781499849553,43793,13843.735455036163,0.9915912747383118,0.0271249692887067,0.4820567404347408,0.9869375824928284,0.0437896288931369,0.2850369521914939,43793 -4328.463456869125,1.5062589645385742,9856.632263422012,30960,0,9856.632263422012,0.9862736463546752,0.0467136912047863,0.2736675093336144,43793,14187.458263158798,0.9916066527366638,0.0269829258322715,0.4897253311971881,0.987052083015442,0.0439271628856658,0.2813805981787919,43793 -4425.642622709274,1.5399868488311768,10096.578315734863,31724,0,10096.578315734863,0.9861574172973632,0.0463861934840679,0.2759647672414221,43793,14524.637803077698,0.991977870464325,0.025896318256855,0.4988183288868678,0.9869737029075624,0.0437305495142936,0.287502596920722,43793 -4528.410791397095,1.5721287727355957,10336.690904140472,32485,0,10336.690904140472,0.9862816333770752,0.0465779192745685,0.2800528015176483,43793,14867.57130074501,0.9921050071716307,0.0254569556564092,0.5302573368344676,0.986983060836792,0.0440465398132801,0.2828144092560208,43793 -4626.351637125015,1.6052203178405762,10576.917943954468,33248,0,10576.917943954468,0.986216366291046,0.0468509458005428,0.2793535608483301,43793,15205.792612075806,0.99212908744812,0.0253685247153043,0.5184021366579336,0.9869611263275146,0.0442101545631885,0.2848071117613658,43793 -4731.118203163147,1.6438887119293213,10816.88122177124,34005,0,10816.88122177124,0.9860959053039552,0.0467146001756191,0.2774428806202462,43793,15550.581326246262,0.9920830726623536,0.02560705691576,0.5206663689703317,0.9869534373283386,0.0438498817384243,0.2907716835350484,43793 -4829.9352684021,1.679349660873413,11056.937723875046,34766,0,11056.937723875046,0.9861780405044556,0.0469150133430957,0.2740014648998188,43793,15889.510733604431,0.9918881058692932,0.0260873194783926,0.5010801180054504,0.9869778156280518,0.0440643094480037,0.2841008787928426,43793 -4929.022217512131,1.7137835025787354,11296.953873872755,35525,0,11296.953873872755,0.986100971698761,0.0467952117323875,0.2735042606028195,43793,16228.668662548063,0.991990566253662,0.0258852709084749,0.5093072745034118,0.9869050979614258,0.0439064614474773,0.287430191108912,43793 -5034.475823879242,1.7479205131530762,11537.133439779282,36284,0,11537.133439779282,0.9861708879470824,0.047328058630228,0.2739957375683177,43793,16574.356207609177,0.9918835759162904,0.0260820463299751,0.50537218004944,0.987015962600708,0.0443647243082523,0.2867905841474967,43793 -5131.294198036194,1.7850227355957031,11777.099576950071,37041,0,11777.099576950071,0.9860870838165284,0.0468637980520725,0.2749342780922189,43793,16911.1993830204,0.991921603679657,0.0260071270167827,0.5088854546962834,0.9869713187217712,0.0440560989081859,0.2855104168089387,43793 -5236.3157749176025,1.818418979644776,12017.13541841507,37792,0,12017.13541841507,0.9861211776733398,0.0471453480422496,0.2790645129473144,43793,17256.310554504395,0.9920625686645508,0.0254651494324207,0.5246759598596786,0.9870808720588684,0.0442389287054538,0.2901478151348224,43793 -5339.024819612503,1.8535473346710205,12257.310943841934,38539,0,12257.310943841934,0.9862340688705444,0.0469984784722328,0.2798594395019039,43793,17599.25080871582,0.992132604122162,0.025216331705451,0.5271417505934661,0.9870041608810424,0.0443718582391738,0.284068813649276,43793 -5438.942674875259,1.888882160186768,12497.441692590714,39297,0,12497.441692590714,0.9863966703414916,0.0471530593931674,0.2812594940719487,43793,17939.355797052383,0.9923343062400818,0.0245491247624158,0.5338019086037505,0.987090229988098,0.0443548746407032,0.2878398362951603,43793 -5545.897831678391,1.9227867126464844,12737.689799547195,40049,0,12737.689799547195,0.9862483739852904,0.0473817698657512,0.2766928416127924,43793,18286.61421489716,0.9923198223114014,0.0243196655064821,0.5546808189751687,0.9871109127998352,0.0443590618669986,0.2893713804221078,43793 -5646.659322977066,1.9570858478546145,12977.942219495771,40802,0,12977.942219495771,0.9862837791442872,0.0476432628929615,0.2762729613561593,43793,18627.68426156044,0.9927132725715636,0.0232907254248857,0.5798397030318193,0.987059772014618,0.0447507388889789,0.2860022023207019,43793 -5757.701861858368,1.9907326698303225,13217.97829079628,41565,0,13217.97829079628,0.9863145351409912,0.0475908629596233,0.2851083116733161,43793,18978.81697440148,0.9927103519439696,0.023197915405035,0.5746574374680143,0.9871835708618164,0.0446155592799186,0.28433454948306,43793 -5869.774827003479,2.050032138824463,13458.056258440018,42321,0,13458.056258440018,0.9861207604408264,0.0483300276100635,0.2736255863559124,43793,19331.04762125016,0.9925374984741212,0.0237526260316371,0.5571466336294107,0.9870622158050536,0.0449502542614936,0.2945649533936957,43793 -5970.566066265106,2.092049360275269,13698.089906215668,43067,0,13698.089906215668,0.9861182570457458,0.0477258823812007,0.2800975663738683,43793,19671.93588352204,0.9924070835113524,0.0241846218705177,0.5549188476915506,0.9869388341903688,0.0447618216276168,0.291296804280545,43793 -6073.116245508194,2.1265196800231934,13938.330757379532,43822,0,13938.330757379532,0.9862075448036194,0.0476227030158042,0.2781757933964989,43793,20014.78182053566,0.9925096035003662,0.023991784080863,0.5581068188704152,0.987064242362976,0.0446946024894714,0.2883423968673922,43793 -6176.107170343399,2.1627137660980225,14178.45221710205,44573,0,14178.45221710205,0.9861910939216614,0.0475999042391777,0.2807915067602074,43793,20357.951191186905,0.9925292134284972,0.0238912533968687,0.5533522608731953,0.9870536923408508,0.0445699468255043,0.2946664919696532,43793 -6273.095078229904,2.1975855827331543,14418.65455508232,45328,0,14418.65455508232,0.9862349033355712,0.0475789941847324,0.2797156681465258,43793,20695.19634771347,0.9924713969230652,0.0238328706473112,0.5691807750617297,0.9870431423187256,0.0448387153446674,0.2838232387063066,43793 -6379.238595485687,2.2321431636810303,14658.61558651924,46094,0,14658.61558651924,0.9862656593322754,0.0479391925036907,0.2842984568305961,43793,21041.355713129044,0.9926730990409852,0.0232164915651083,0.5673038654986291,0.9871312379837036,0.0448323711752891,0.2927768764638855,43793 -6480.91828584671,2.2671091556549072,14898.701581716536,46841,0,14898.701581716536,0.9862319827079772,0.0485468283295631,0.2778092930415843,43793,21383.17679190636,0.9926999807357788,0.0229254197329282,0.5850914346518163,0.9871032238006592,0.0455067902803421,0.2930319312146007,43793 -6583.505940437317,2.3025686740875244,15138.667674064636,47593,0,15138.667674064636,0.9862273335456848,0.04789400100708,0.2869319142049717,43793,21725.787615060806,0.992985725402832,0.0222699958831071,0.5896478412213515,0.9871113300323486,0.0449162758886814,0.2926934946306269,43793 -6683.557134151459,2.3388681411743164,15378.75329375267,48351,0,15378.75329375267,0.986178457736969,0.0477825812995433,0.2834298438452727,43793,22065.98137497902,0.9932321310043336,0.0214752256870269,0.6224910459932752,0.987035036087036,0.0448025874793529,0.2923386664862276,43793 -6783.601722478867,2.3929810523986816,15618.89215707779,49109,0,15618.89215707779,0.9862424731254578,0.0486608073115348,0.2857276248478931,43793,22406.23919153213,0.9933977723121644,0.0210042484104633,0.6172502556198101,0.9870760440826416,0.0456252992153167,0.2926285593021795,43793 -6884.114263057709,2.4300451278686523,15859.115253686905,49864,0,15859.115253686905,0.9862942695617676,0.0483603700995445,0.2828741073363153,43793,22747.032354831696,0.9931698441505432,0.0215233471244573,0.6213504048826404,0.9871621131896972,0.0453357174992561,0.2962699552477201,43793 -6987.389058351517,2.4660558700561523,16099.191632032394,50621,0,16099.191632032394,0.9862441420555116,0.0485006384551525,0.2809669364114627,43793,23090.439814329147,0.9931451082229614,0.0217219870537519,0.5996855145951254,0.9870354533195496,0.0454207174479961,0.2876952709587337,43793 -7093.5865795612335,2.502041339874268,16339.3504114151,51379,0,16339.3504114151,0.9861915111541748,0.0485584177076816,0.2790803063476967,43793,23436.85264754296,0.9931361675262452,0.0217958111315965,0.6021560760301627,0.9870212078094482,0.0455244444310665,0.2931287192061453,43793 -7197.719510793686,2.5390548706054688,16579.44235610962,52138,0,16579.44235610962,0.98629891872406,0.0488433130085468,0.2831817590489205,43793,23781.13451218605,0.9931796193122864,0.0215019863098859,0.6067070557150565,0.9871442317962646,0.0456354469060897,0.293471528182527,43793 -7301.227471590042,2.5754523277282715,16819.4209959507,52896,0,16819.4209959507,0.9863018989562988,0.0489050596952438,0.2820189862990282,43793,24124.67778921128,0.9931758046150208,0.0214325115084648,0.6115915335475446,0.9871913194656372,0.0455203764140605,0.3000649690933642,43793 -7401.090955257416,2.6128950119018555,17059.447257995605,53651,0,17059.447257995605,0.9862803816795348,0.0489753000438213,0.2802482233155609,43793,24464.625368595123,0.9933634400367736,0.0208130627870559,0.6317767925373252,0.987166166305542,0.045723769813776,0.2988973768460285,43793 -7502.637751102447,2.65012526512146,17299.483036756516,54403,0,17299.483036756516,0.986240804195404,0.0494866259396076,0.2819279003808288,43793,24806.26559662819,0.9933775663375854,0.0207133274525403,0.6225736436664356,0.987188458442688,0.0462249219417572,0.296826822781387,43793 -7602.579082250595,2.6879782676696777,17539.461524248123,55156,0,17539.461524248123,0.9862689971923828,0.0496369414031505,0.2751525052977334,43793,25146.24363541603,0.993545413017273,0.0201831441372632,0.645784824725203,0.9871426224708556,0.0462013557553291,0.2940088533353942,43793 -7701.495981454849,2.7243547439575195,17779.411629915237,55921,0,17779.411629915237,0.9861751198768616,0.0495301969349384,0.278072814459413,43793,25485.167813539505,0.9938235282897948,0.0194733310490846,0.6602483439614082,0.9870220422744752,0.0461726337671279,0.2962127262744353,43793 -7807.280650138855,2.760650157928467,18019.554752588272,56685,0,18019.554752588272,0.9863424897193908,0.0498776361346244,0.2806236785910661,43793,25831.152674913406,0.993970274925232,0.0188592132180929,0.676547356716646,0.9871426224708556,0.0465945266187191,0.2926747635568142,43793 -7909.328285217285,2.7992615699768066,18259.80503797531,57443,0,18259.80503797531,0.9862766265869141,0.04957572743296623,0.2852527176577683,43793,26173.50976252556,0.9942606687545776,0.01822679676115513,0.6956112092775892,0.9871117472648621,0.04643256217241287,0.2988261302469323,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/measurements.csv deleted file mode 100644 index 64438963d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/measurements.csv +++ /dev/null @@ -1,661 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.8519781,0.7274107,,,,,,,,,,,,,,,,, -1,,,0.5324588418006897,0.7277070879936218,0.0217586781684841,0.5230699777603149,0.7331878542900085,0.0254813053980928,43793.0,0.5214918851852417,0.7347381114959717,0.0268206381238046,43793.0,13.114962816238403,120.72387218475342,13.114962816238403,107.60885667800903,0.0,0.0 -100,0.29227042,0.26782244,,,,,,,,,,,,,,,,, -200,0.093204245,0.11452889,,,,,,,,,,,,,,,,, -300,0.03005778,0.06601029,,,,,,,,,,,,,,,,, -400,0.01467846,0.05889775,,,,,,,,,,,,,,,,, -500,0.015004844,0.059379533,,,,,,,,,,,,,,,,, -600,0.016165385,0.051949106,,,,,,,,,,,,,,,,, -700,0.035877295,0.0546671,,,,,,,,,,,,,,,,, -760,,,0.9868366122245787,0.0525698475539684,0.0459173332919591,0.9841719269752502,0.0628870725631713,0.0457983544387885,43793.0,0.983195960521698,0.0661912709474563,0.0479848150834761,43793.0,253.0232331752777,467.0528705120087,253.0232331752777,213.7212650775909,0.2872607707977295,0.0 -800,0.014015678,0.05189471,,,,,,,,,,,,,,,,, -900,0.021262174,0.053366903,,,,,,,,,,,,,,,,, -1000,0.012379905,0.0507674,,,,,,,,,,,,,,,,, -1100,0.026322335,0.04949008,,,,,,,,,,,,,,,,, -1200,0.014613479,0.045164075,,,,,,,,,,,,,,,,, -1300,0.014748436,0.051697534,,,,,,,,,,,,,,,,, -1400,0.014796482,0.044635177,,,,,,,,,,,,,,,,, -1500,0.042326048,0.055388793,,,,,,,,,,,,,,,,, -1505,,,0.986958384513855,0.0480176247656345,0.0938309508173274,0.9844183325767516,0.057403702288866,0.0944905234983526,43793.0,0.9834386110305786,0.0608416348695755,0.0933789730938457,43793.0,493.0518238544464,812.8888652324677,493.0518238544464,319.4787435531616,0.3140983581542969,0.0 -1600,0.01450167,0.04754587,,,,,,,,,,,,,,,,, -1700,0.027748236,0.053749092,,,,,,,,,,,,,,,,, -1800,0.012324996,0.043660685,,,,,,,,,,,,,,,,, -1900,0.029833125,0.04537595,,,,,,,,,,,,,,,,, -2000,0.012626385,0.045738395,,,,,,,,,,,,,,,,, -2100,0.0119075775,0.047471464,,,,,,,,,,,,,,,,, -2200,0.010303105,0.0486721,,,,,,,,,,,,,,,,, -2258,,,0.987659990787506,0.0437299534678459,0.1344494258374315,0.9849497079849244,0.053313635289669,0.1305046429838764,43793.0,0.98393052816391,0.05640734359622,0.1286533700685239,43793.0,733.1513078212738,1154.5208294391632,733.1513078212738,420.9634261131287,0.3397960662841797,0.0 -2300,0.018464036,0.050404068,,,,,,,,,,,,,,,,, -2400,0.011053654,0.042145837,,,,,,,,,,,,,,,,, -2500,0.016733546,0.04281305,,,,,,,,,,,,,,,,, -2600,0.013121832,0.040640596,,,,,,,,,,,,,,,,, -2700,0.014669278,0.041561298,,,,,,,,,,,,,,,,, -2800,0.022569358,0.04405164,,,,,,,,,,,,,,,,, -2900,0.01306465,0.048115797,,,,,,,,,,,,,,,,, -3000,0.013924377,0.0442898,,,,,,,,,,,,,,,,, -3011,,,0.988203763961792,0.0414475202560424,0.1611080446191192,0.9852355122566224,0.0509338229894638,0.15108787353642,43793.0,0.9842051863670348,0.0541716180741786,0.1458128469841821,43793.0,973.3656969070436,1493.8625662326813,973.3656969070436,520.0432939529419,0.3667905330657959,0.0 -3100,0.01024891,0.040746015,,,,,,,,,,,,,,,,, -3200,0.0120299505,0.042728815,,,,,,,,,,,,,,,,, -3300,0.012076139,0.041084778,,,,,,,,,,,,,,,,, -3400,0.01190991,0.036925465,,,,,,,,,,,,,,,,, -3500,0.017982092,0.040676158,,,,,,,,,,,,,,,,, -3600,0.013415704,0.039477114,,,,,,,,,,,,,,,,, -3700,0.011137501,0.042230755,,,,,,,,,,,,,,,,, -3759,,,0.9883896112442015,0.0404955521225929,0.1966472162561015,0.9854575395584106,0.0496173612773418,0.1722477732427589,43793.0,0.9845345616340636,0.0524315349757671,0.1698998980243238,43793.0,1213.6298730373385,1842.9375042915344,1213.6298730373385,628.8051879405975,0.3951599597930908,0.0 -3800,0.010132716,0.037579447,,,,,,,,,,,,,,,,, -3900,0.014891327,0.03982604,,,,,,,,,,,,,,,,, -4000,0.013704952,0.04254143,,,,,,,,,,,,,,,,, -4100,0.010880386,0.039053015,,,,,,,,,,,,,,,,, -4200,0.010497885,0.036812924,,,,,,,,,,,,,,,,, -4300,0.015240235,0.03922535,,,,,,,,,,,,,,,,, -4400,0.009764204,0.037992682,,,,,,,,,,,,,,,,, -4500,0.010904689,0.040161192,,,,,,,,,,,,,,,,, -4512,,,0.9884734749794006,0.0396814867854118,0.2104161173825414,0.985728681087494,0.0488488934934139,0.1853613796800635,43793.0,0.984804093837738,0.0515812933444976,0.1859877344371198,43793.0,1453.6540446281433,2185.0393300056458,1453.6540446281433,730.8336308002472,0.4234771728515625,0.0 -4600,0.013823724,0.04502948,,,,,,,,,,,,,,,,, -4700,0.014028822,0.035692424,,,,,,,,,,,,,,,,, -4800,0.013766847,0.039080746,,,,,,,,,,,,,,,,, -4900,0.012124531,0.037226737,,,,,,,,,,,,,,,,, -5000,0.012025442,0.039452594,,,,,,,,,,,,,,,,, -5100,0.010714772,0.035619304,,,,,,,,,,,,,,,,, -5200,0.011461194,0.040534556,,,,,,,,,,,,,,,,, -5260,,,0.9889073967933656,0.0377699360251426,0.240450900537608,0.9860051274299622,0.0472903698682785,0.2011653371217038,43793.0,0.9851145148277284,0.0497923754155635,0.2028064962651752,43793.0,1693.849843502045,2530.7144179344177,1693.849843502045,836.2591185569763,0.4546363353729248,0.0 -5300,0.012019575,0.03662594,,,,,,,,,,,,,,,,, -5400,0.011617688,0.04429739,,,,,,,,,,,,,,,,, -5500,0.014676414,0.0411229,,,,,,,,,,,,,,,,, -5600,0.01652481,0.04031831,,,,,,,,,,,,,,,,, -5700,0.012024545,0.032152057,,,,,,,,,,,,,,,,, -5800,0.020529775,0.03826178,,,,,,,,,,,,,,,,, -5900,0.011715578,0.036650356,,,,,,,,,,,,,,,,, -6000,0.010306346,0.033979136,,,,,,,,,,,,,,,,, -6017,,,0.9888168573379515,0.037816647440195,0.2434065181146786,0.9859337210655212,0.0472665354609489,0.2041797816516377,43793.0,0.9850445985794068,0.0499980822205543,0.2039791274407398,43793.0,1933.864909410477,2875.7232887744904,1933.864909410477,941.2021589279176,0.4842491149902344,0.0 -6100,0.019048017,0.033176012,,,,,,,,,,,,,,,,, -6200,0.012994559,0.0366007,,,,,,,,,,,,,,,,, -6300,0.011765211,0.03595147,,,,,,,,,,,,,,,,, -6400,0.013851394,0.038343836,,,,,,,,,,,,,,,,, -6500,0.016825499,0.03746875,,,,,,,,,,,,,,,,, -6600,0.012658671,0.03464992,,,,,,,,,,,,,,,,, -6700,0.01204267,0.038636085,,,,,,,,,,,,,,,,, -6766,,,0.9892390966415404,0.0362072438001632,0.271036118425679,0.9861992001533508,0.046460848301649,0.2200321466350437,43793.0,0.9852948188781738,0.0491566210985183,0.2295385995796318,43793.0,2173.9619414806366,3219.5011546611786,2173.9619414806366,1044.834265708923,0.5117506980895996,0.0 -6800,0.013629362,0.035458688,,,,,,,,,,,,,,,,, -6900,0.015531112,0.0366109,,,,,,,,,,,,,,,,, -7000,0.015209029,0.035564315,,,,,,,,,,,,,,,,, -7100,0.012516628,0.04172401,,,,,,,,,,,,,,,,, -7200,0.01438081,0.034728326,,,,,,,,,,,,,,,,, -7300,0.013800424,0.03533316,,,,,,,,,,,,,,,,, -7400,0.026956575,0.039327227,,,,,,,,,,,,,,,,, -7500,0.013858918,0.03750212,,,,,,,,,,,,,,,,, -7516,,,0.9895783066749572,0.0349873229861259,0.2932618939418534,0.9863753914833068,0.0464226454496383,0.2287308422014867,43793.0,0.9854813814163208,0.0492261908948421,0.2348783154287064,43793.0,2413.9666571617126,3563.6852464675903,2413.9666571617126,1148.9634318351746,0.5409078598022461,0.0 -7600,0.01567872,0.037887912,,,,,,,,,,,,,,,,, -7700,0.017944366,0.034435965,,,,,,,,,,,,,,,,, -7800,0.030195236,0.036627144,,,,,,,,,,,,,,,,, -7900,0.023300083,0.034793854,,,,,,,,,,,,,,,,, -8000,0.017952561,0.038521204,,,,,,,,,,,,,,,,, -8100,0.022967981,0.03642484,,,,,,,,,,,,,,,,, -8200,0.01495837,0.032358732,,,,,,,,,,,,,,,,, -8266,,,0.9897823333740234,0.0342083573341369,0.3195737883451471,0.9863818883895874,0.0454126447439193,0.2325608492496531,43793.0,0.9855201244354248,0.0480259098112583,0.2359621448547312,43793.0,2653.955867290497,3901.649888753891,2653.955867290497,1246.890321969986,0.5691602230072021,0.0 -8300,0.01683211,0.035192087,,,,,,,,,,,,,,,,, -8400,0.023183677,0.03668776,,,,,,,,,,,,,,,,, -8500,0.021644056,0.032968357,,,,,,,,,,,,,,,,, -8600,0.03127518,0.035380222,,,,,,,,,,,,,,,,, -8700,0.015572782,0.0341106,,,,,,,,,,,,,,,,, -8800,0.016459558,0.036953434,,,,,,,,,,,,,,,,, -8900,0.016988177,0.036755968,,,,,,,,,,,,,,,,, -9000,0.018180804,0.039167512,,,,,,,,,,,,,,,,, -9015,,,0.990003228187561,0.0334638282656669,0.3346555605107079,0.9865231513977052,0.0452743992209434,0.2372158837428065,43793.0,0.9856511354446412,0.0477934069931507,0.2427388808261531,43793.0,2894.023741006851,4248.052522659302,2894.023741006851,1353.1766684055328,0.5972564220428467,0.0 -9100,0.018433152,0.03613725,,,,,,,,,,,,,,,,, -9200,0.01613841,0.03522839,,,,,,,,,,,,,,,,, -9300,0.025739998,0.03636865,,,,,,,,,,,,,,,,, -9400,0.021786517,0.035637684,,,,,,,,,,,,,,,,, -9500,0.02225653,0.032713644,,,,,,,,,,,,,,,,, -9600,0.024506722,0.028915405,,,,,,,,,,,,,,,,, -9700,0.024950135,0.03586829,,,,,,,,,,,,,,,,, -9761,,,0.9901729226112366,0.0329452231526374,0.3521926542739725,0.9865406155586244,0.0448672249913215,0.2515968842157506,43793.0,0.9857004284858704,0.0474113151431083,0.2414298978085044,43793.0,3134.196201562881,4595.578535079956,3134.196201562881,1460.478672027588,0.6274769306182861,0.0 -9800,0.017162496,0.036801938,,,,,,,,,,,,,,,,, -9900,0.027067304,0.039455574,,,,,,,,,,,,,,,,, -10000,0.040727038,0.03182499,,,,,,,,,,,,,,,,, -10100,0.029470727,0.034716703,,,,,,,,,,,,,,,,, -10200,0.037903313,0.034348436,,,,,,,,,,,,,,,,, -10300,0.023978295,0.037948135,,,,,,,,,,,,,,,,, -10400,0.033666432,0.036077525,,,,,,,,,,,,,,,,, -10500,0.021951381,0.037423465,,,,,,,,,,,,,,,,, -10505,,,0.9900105595588684,0.0333098098635673,0.3414721915273255,0.9866453409194946,0.044458620250225,0.2510299167733832,43793.0,0.985818326473236,0.0471859909594059,0.2474191243238686,43793.0,3374.321034193039,4942.475167989731,3374.321034193039,1567.1986813545227,0.6583380699157715,0.0 -10600,0.020150803,0.03238388,,,,,,,,,,,,,,,,, -10700,0.029678188,0.038109947,,,,,,,,,,,,,,,,, -10800,0.021347115,0.03173639,,,,,,,,,,,,,,,,, -10900,0.022324365,0.035330463,,,,,,,,,,,,,,,,, -11000,0.023543088,0.037410203,,,,,,,,,,,,,,,,, -11100,0.01961183,0.033987887,,,,,,,,,,,,,,,,, -11200,0.024565518,0.034606334,,,,,,,,,,,,,,,,, -11258,,,0.990221619606018,0.0325211398303508,0.3451225530368046,0.986581563949585,0.0445379801094532,0.2583005286632394,43793.0,0.9857791662216188,0.0471867583692073,0.2567930607884815,43793.0,3614.390597343445,5288.855046987534,3614.390597343445,1673.4594593048096,0.6867432594299316,0.0 -11300,0.026919514,0.033476457,,,,,,,,,,,,,,,,, -11400,0.020846127,0.030798823,,,,,,,,,,,,,,,,, -11500,0.021915548,0.033669185,,,,,,,,,,,,,,,,, -11600,0.025951304,0.034071513,,,,,,,,,,,,,,,,, -11700,0.033356134,0.034349784,,,,,,,,,,,,,,,,, -11800,0.023773266,0.031280152,,,,,,,,,,,,,,,,, -11900,0.023528881,0.035572257,,,,,,,,,,,,,,,,, -12000,0.030521415,0.035054326,,,,,,,,,,,,,,,,, -12007,,,0.9903555512428284,0.0322158448398113,0.3548928727455357,0.986707866191864,0.0444315969944,0.2554358455220814,43793.0,0.9858141541481018,0.0472727715969085,0.2503482735494056,43793.0,3854.4456446170807,5634.432926416397,3854.4456446170807,1778.9284162521362,0.7189178466796875,0.0 -12100,0.025720123,0.035832077,,,,,,,,,,,,,,,,, -12200,0.03361814,0.034126997,,,,,,,,,,,,,,,,, -12300,0.023120286,0.030561807,,,,,,,,,,,,,,,,, -12400,0.02409161,0.028873967,,,,,,,,,,,,,,,,, -12500,0.024807893,0.031983603,,,,,,,,,,,,,,,,, -12600,0.03048768,0.036163125,,,,,,,,,,,,,,,,, -12700,0.029302878,0.032673184,,,,,,,,,,,,,,,,, -12768,,,0.9902983903884888,0.0321623496711254,0.3728056786919886,0.9867314100265504,0.0442637987434864,0.2605304608129135,43793.0,0.985888659954071,0.0471827387809753,0.2574348691223886,43793.0,4094.401614904404,5973.4755423069,4094.401614904404,1877.9654395580287,0.7478671073913574,0.0 -12800,0.03599893,0.035020396,,,,,,,,,,,,,,,,, -12900,0.025444094,0.030986771,,,,,,,,,,,,,,,,, -13000,0.02828396,0.031911265,,,,,,,,,,,,,,,,, -13100,0.031481653,0.03697779,,,,,,,,,,,,,,,,, -13200,0.022647936,0.031708356,,,,,,,,,,,,,,,,, -13300,0.029990496,0.030247776,,,,,,,,,,,,,,,,, -13400,0.039732635,0.03853805,,,,,,,,,,,,,,,,, -13500,0.025611687,0.027078351,,,,,,,,,,,,,,,,, -13526,,,0.9905072450637816,0.0311836674809455,0.3852428672374816,0.986772358417511,0.0443389266729354,0.2648099308155049,43793.0,0.9858831763267516,0.0473159775137901,0.2539916084768826,43793.0,4334.481739997864,6313.919853925705,4334.481739997864,1978.278470039368,0.7787857055664062,0.0 -13600,0.04065854,0.032329474,,,,,,,,,,,,,,,,, -13700,0.040263973,0.035553332,,,,,,,,,,,,,,,,, -13800,0.03435867,0.035547383,,,,,,,,,,,,,,,,, -13900,0.03545283,0.032920957,,,,,,,,,,,,,,,,, -14000,0.03553705,0.032612786,,,,,,,,,,,,,,,,, -14100,0.032292023,0.03433002,,,,,,,,,,,,,,,,, -14200,0.050673563,0.031800434,,,,,,,,,,,,,,,,, -14281,,,0.9905017018318176,0.0312279406934976,0.3851849660303862,0.9866887331008912,0.0443001464009285,0.2595797861296848,43793.0,0.9858419299125672,0.0471290312707424,0.2575491563365366,43793.0,4574.704688310623,6658.48993730545,4574.704688310623,2082.57606124878,0.807380199432373,0.0 -14300,0.02963247,0.03223173,,,,,,,,,,,,,,,,, -14400,0.03598155,0.034365196,,,,,,,,,,,,,,,,, -14500,0.040094707,0.029248187,,,,,,,,,,,,,,,,, -14600,0.03148175,0.032387085,,,,,,,,,,,,,,,,, -14700,0.037509758,0.030941822,,,,,,,,,,,,,,,,, -14800,0.041765623,0.033660542,,,,,,,,,,,,,,,,, -14900,0.03192655,0.036666945,,,,,,,,,,,,,,,,, -15000,0.034613855,0.03234746,,,,,,,,,,,,,,,,, -15031,,,0.9906373023986816,0.0307508446276187,0.4132004282045805,0.9867650866508484,0.0446204729378223,0.2628655068956396,43793.0,0.9859139323234558,0.0474883429706096,0.2540615756671235,43793.0,4814.931003808975,6998.6618309021,4814.931003808975,2182.466774225235,0.8405594825744629,0.0 -15100,0.033626646,0.033648502,,,,,,,,,,,,,,,,, -15200,0.029966697,0.03233802,,,,,,,,,,,,,,,,, -15300,0.04172743,0.034365937,,,,,,,,,,,,,,,,, -15400,0.04325428,0.03535105,,,,,,,,,,,,,,,,, -15500,0.033558924,0.03098931,,,,,,,,,,,,,,,,, -15600,0.045372438,0.0364952,,,,,,,,,,,,,,,,, -15700,0.038498983,0.03208601,,,,,,,,,,,,,,,,, -15795,,,0.990784227848053,0.0301042962819337,0.4046015524426429,0.9867382645606996,0.044171217828989,0.2618256099036918,43793.0,0.9858680367469788,0.0467845201492309,0.2557980212280484,43793.0,5054.909048080444,7342.298298835754,5054.909048080444,2286.075028419494,0.8707151412963867,0.0 -15800,0.028416846,0.030227663,,,,,,,,,,,,,,,,, -15900,0.048907213,0.0330029,,,,,,,,,,,,,,,,, -16000,0.03662836,0.032546043,,,,,,,,,,,,,,,,, -16100,0.034461077,0.031972762,,,,,,,,,,,,,,,,, -16200,0.033100955,0.031361386,,,,,,,,,,,,,,,,, -16300,0.030747801,0.026942395,,,,,,,,,,,,,,,,, -16400,0.03484154,0.031915892,,,,,,,,,,,,,,,,, -16500,0.031446904,0.02920504,,,,,,,,,,,,,,,,, -16557,,,0.9911503791809082,0.029118113219738,0.4419770802823985,0.986614465713501,0.0442926213145256,0.2690583731034517,43793.0,0.9858490824699402,0.0468463972210884,0.2617767080606321,43793.0,5294.891679048538,7686.331763267517,5294.891679048538,2390.076048135757,0.8997907638549805,0.0 -16600,0.048939317,0.032847457,,,,,,,,,,,,,,,,, -16700,0.044274893,0.034475457,,,,,,,,,,,,,,,,, -16800,0.04508133,0.03401222,,,,,,,,,,,,,,,,, -16900,0.039463043,0.033897202,,,,,,,,,,,,,,,,, -17000,0.040696066,0.033507977,,,,,,,,,,,,,,,,, -17100,0.04026772,0.032540407,,,,,,,,,,,,,,,,, -17200,0.033752047,0.03370817,,,,,,,,,,,,,,,,, -17300,0.03849174,0.028844416,,,,,,,,,,,,,,,,, -17305,,,0.9909698367118835,0.0296294633299112,0.425239852981252,0.9866148829460144,0.0441667363047599,0.2680382857825597,43793.0,0.9858777523040771,0.0467349551618099,0.2607631513044034,43793.0,5534.965999603272,8027.069473028183,5534.965999603272,2490.6851439476013,0.932380437850952,0.0 -17400,0.03896544,0.03425124,,,,,,,,,,,,,,,,, -17500,0.040883496,0.035519846,,,,,,,,,,,,,,,,, -17600,0.042841576,0.031568293,,,,,,,,,,,,,,,,, -17700,0.040042114,0.032627627,,,,,,,,,,,,,,,,, -17800,0.041958623,0.03481358,,,,,,,,,,,,,,,,, -17900,0.042791475,0.03430356,,,,,,,,,,,,,,,,, -18000,0.055153415,0.03371705,,,,,,,,,,,,,,,,, -18067,,,0.99085795879364,0.0300395190715789,0.4119247732666293,0.9867743849754332,0.0440140813589096,0.2692154002615283,43793.0,0.9858953952789308,0.0467287376523017,0.2647181635576729,43793.0,5775.012037992477,8372.382522583008,5775.012037992477,2595.902805566788,0.9617249965667723,0.0 -18100,0.055627592,0.033893626,,,,,,,,,,,,,,,,, -18200,0.040104963,0.0320944,,,,,,,,,,,,,,,,, -18300,0.07381666,0.033131994,,,,,,,,,,,,,,,,, -18400,0.05362834,0.032989103,,,,,,,,,,,,,,,,, -18500,0.037924837,0.030105885,,,,,,,,,,,,,,,,, -18600,0.043675173,0.033157367,,,,,,,,,,,,,,,,, -18700,0.04571906,0.029672356,,,,,,,,,,,,,,,,, -18800,0.042660177,0.030260628,,,,,,,,,,,,,,,,, -18826,,,0.9908071160316468,0.0303335282951593,0.4082732006218971,0.986763834953308,0.0440635085105896,0.2648081561789626,43793.0,0.985905945301056,0.0468245893716812,0.261026378969448,43793.0,6014.981848716736,8712.026437044144,6014.981848716736,2695.526474237442,0.9920079708099364,0.0 -18900,0.044568595,0.032041706,,,,,,,,,,,,,,,,, -19000,0.039627507,0.0330381,,,,,,,,,,,,,,,,, -19100,0.039862353,0.031023415,,,,,,,,,,,,,,,,, -19200,0.049964968,0.02779777,,,,,,,,,,,,,,,,, -19300,0.046686362,0.03062361,,,,,,,,,,,,,,,,, -19400,0.042368107,0.030664556,,,,,,,,,,,,,,,,, -19500,0.036865562,0.027837697,,,,,,,,,,,,,,,,, -19588,,,0.990985095500946,0.0293735228478908,0.42571662301725,0.986893355846405,0.0441452153027057,0.2726999962014126,43793.0,0.9860677123069764,0.0467986539006233,0.269728862963543,43793.0,6255.140200138092,9053.31315279007,6255.140200138092,2796.600333929062,1.0257935523986816,0.0 -19600,0.048222158,0.03410129,,,,,,,,,,,,,,,,, -19700,0.07591638,0.033323854,,,,,,,,,,,,,,,,, -19800,0.057155553,0.035654172,,,,,,,,,,,,,,,,, -19900,0.037815686,0.028447995,,,,,,,,,,,,,,,,, -20000,0.04600245,0.029907247,,,,,,,,,,,,,,,,, -20100,0.047024637,0.029145401,,,,,,,,,,,,,,,,, -20200,0.03754542,0.029546987,,,,,,,,,,,,,,,,, -20300,0.041230332,0.029799266,,,,,,,,,,,,,,,,, -20343,,,0.9908535480499268,0.0297939609736204,0.4299498885823914,0.9868730306625366,0.0438564978539943,0.27432950631949,43793.0,0.9860925674438475,0.0466288626194,0.2667238021460736,43793.0,6495.229241847992,9397.442924976349,6495.229241847992,2900.5887792110443,1.0550963878631592,0.0 -20400,0.039200068,0.029964535,,,,,,,,,,,,,,,,, -20500,0.04458188,0.033079613,,,,,,,,,,,,,,,,, -20600,0.049571607,0.025950726,,,,,,,,,,,,,,,,, -20700,0.054589484,0.03244104,,,,,,,,,,,,,,,,, -20800,0.049604483,0.031800993,,,,,,,,,,,,,,,,, -20900,0.040115666,0.030120334,,,,,,,,,,,,,,,,, -21000,0.06750355,0.032426808,,,,,,,,,,,,,,,,, -21098,,,0.9910489320755004,0.0291868932545185,0.4427940285589564,0.9868564009666444,0.0438627004623413,0.2687042846949644,43793.0,0.9860213398933412,0.0465683303773403,0.2609999197800985,43793.0,6735.494882106781,9740.919793367386,6735.494882106781,3003.7433593273163,1.089904546737671,0.0 -21100,0.044165425,0.029104315,,,,,,,,,,,,,,,,, -21200,0.048487253,0.030657785,,,,,,,,,,,,,,,,, -21300,0.052999217,0.033338364,,,,,,,,,,,,,,,,, -21400,0.046163164,0.031263817,,,,,,,,,,,,,,,,, -21500,0.08054214,0.029409772,,,,,,,,,,,,,,,,, -21600,0.046234153,0.03383922,,,,,,,,,,,,,,,,, -21700,0.03896831,0.02979286,,,,,,,,,,,,,,,,, -21800,0.054267034,0.031136625,,,,,,,,,,,,,,,,, -21856,,,0.9909344911575316,0.0291950646787881,0.4393534144632743,0.9868978261947632,0.0446341596543788,0.2735190565901901,43793.0,0.9861317276954652,0.0474246628582477,0.2671671664715521,43793.0,6975.677932262421,10084.39429950714,6975.677932262421,3106.980573654175,1.1237335205078125,0.0 -21900,0.040683128,0.031465407,,,,,,,,,,,,,,,,, -22000,0.04013062,0.029553233,,,,,,,,,,,,,,,,, -22100,0.03852319,0.02702579,,,,,,,,,,,,,,,,, -22200,0.050040293,0.032927383,,,,,,,,,,,,,,,,, -22300,0.051816687,0.03275423,,,,,,,,,,,,,,,,, -22400,0.062775694,0.032043893,,,,,,,,,,,,,,,,, -22500,0.050235905,0.028876033,,,,,,,,,,,,,,,,, -22600,0.060930196,0.034340575,,,,,,,,,,,,,,,,, -22612,,,0.9911161065101624,0.0289258304983377,0.4461390763695887,0.9867297410964966,0.0439446568489074,0.2784422440583035,43793.0,0.9859792590141296,0.0466558374464511,0.2688925711781614,43793.0,7215.721256017685,10425.84306025505,7215.721256017685,3208.3332257270813,1.1560900211334229,0.0 -22700,0.044661243,0.030494975,,,,,,,,,,,,,,,,, -22800,0.051949024,0.03611276,,,,,,,,,,,,,,,,, -22900,0.058846958,0.030175198,,,,,,,,,,,,,,,,, -23000,0.051604494,0.03235457,,,,,,,,,,,,,,,,, -23100,0.056440026,0.031354178,,,,,,,,,,,,,,,,, -23200,0.0477246,0.028013606,,,,,,,,,,,,,,,,, -23300,0.048576053,0.030934548,,,,,,,,,,,,,,,,, -23366,,,0.9914578199386596,0.0278907660394907,0.4627701374352354,0.9868608713150024,0.0440641902387142,0.276925187531418,43793.0,0.9861093759536744,0.0466336794197559,0.2682321435300443,43793.0,7455.825122117996,10767.822446346285,7455.825122117996,3310.157596349716,1.1869611740112305,0.0 -23400,0.04745346,0.032676447,,,,,,,,,,,,,,,,, -23500,0.042796865,0.029041095,,,,,,,,,,,,,,,,, -23600,0.042266645,0.029346911,,,,,,,,,,,,,,,,, -23700,0.04858021,0.03267894,,,,,,,,,,,,,,,,, -23800,0.048568476,0.029214427,,,,,,,,,,,,,,,,, -23900,0.072398715,0.031504422,,,,,,,,,,,,,,,,, -24000,0.04264006,0.03188419,,,,,,,,,,,,,,,,, -24100,0.057005584,0.028883003,,,,,,,,,,,,,,,,, -24123,,,0.9916008114814758,0.0272237285971641,0.4873757299629083,0.9869104027748108,0.0439563207328319,0.2830847361580403,43793.0,0.9861299991607666,0.0465262271463871,0.2760676809335625,43793.0,7695.909645080566,11106.817147493362,7695.909645080566,3409.015555858612,1.2186834812164309,0.0 -24200,0.05272388,0.032619346,,,,,,,,,,,,,,,,, -24300,0.039821666,0.030722855,,,,,,,,,,,,,,,,, -24400,0.053230938,0.02945503,,,,,,,,,,,,,,,,, -24500,0.054203674,0.031192007,,,,,,,,,,,,,,,,, -24600,0.051685933,0.028513018,,,,,,,,,,,,,,,,, -24700,0.04709097,0.033145357,,,,,,,,,,,,,,,,, -24800,0.047009744,0.027945206,,,,,,,,,,,,,,,,, -24882,,,0.9916938543319702,0.026998370885849,0.4925730935016073,0.9868324398994446,0.0440956726670265,0.2810012658860785,43793.0,0.9860736131668092,0.0468006432056427,0.2785355949087478,43793.0,7936.176169872284,11449.778022766111,7936.176169872284,3511.65873837471,1.2492575645446775,0.0 -24900,0.043528456,0.027438786,,,,,,,,,,,,,,,,, -25000,0.046136264,0.027483756,,,,,,,,,,,,,,,,, -25100,0.060749765,0.03193876,,,,,,,,,,,,,,,,, -25200,0.050579924,0.029719157,,,,,,,,,,,,,,,,, -25300,0.047199447,0.032067742,,,,,,,,,,,,,,,,, -25400,0.059515294,0.027026432,,,,,,,,,,,,,,,,, -25500,0.047459256,0.030939864,,,,,,,,,,,,,,,,, -25600,0.053004995,0.03157911,,,,,,,,,,,,,,,,, -25646,,,0.9915238618850708,0.0275689717382192,0.4713439378127574,0.9869733452796936,0.0439281985163688,0.2816985678081601,43793.0,0.986088752746582,0.0467182286083698,0.275252788624737,43793.0,8176.215074539185,11788.014889717102,8176.215074539185,3609.806525945664,1.2795326709747314,0.0 -25700,0.043929,0.03051071,,,,,,,,,,,,,,,,, -25800,0.06486014,0.031181967,,,,,,,,,,,,,,,,, -25900,0.056474976,0.03154034,,,,,,,,,,,,,,,,, -26000,0.04373993,0.028705887,,,,,,,,,,,,,,,,, -26100,0.048401024,0.027227411,,,,,,,,,,,,,,,,, -26200,0.05072252,0.031299617,,,,,,,,,,,,,,,,, -26300,0.060823392,0.03183875,,,,,,,,,,,,,,,,, -26400,0.06014757,0.030153738,,,,,,,,,,,,,,,,, -26404,,,0.9913026094436646,0.0280513260513544,0.4540246569703544,0.9868937730789183,0.0440740473568439,0.2784015307768114,43793.0,0.9861670732498168,0.0467212721705436,0.272193655521218,43793.0,8416.223066806793,12129.615074634552,8416.223066806793,3711.346279144287,1.3122563362121582,0.0 -26500,0.052432515,0.029322304,,,,,,,,,,,,,,,,, -26600,0.051556885,0.03072498,,,,,,,,,,,,,,,,, -26700,0.05050896,0.0289872,,,,,,,,,,,,,,,,, -26800,0.050560273,0.03240213,,,,,,,,,,,,,,,,, -26900,0.054193012,0.027117888,,,,,,,,,,,,,,,,, -27000,0.06433236,0.033572514,,,,,,,,,,,,,,,,, -27100,0.055061094,0.030312618,,,,,,,,,,,,,,,,, -27167,,,0.9913366436958312,0.0279964916408061,0.4512330298512755,0.9870614409446716,0.0438373424112796,0.2824528369850656,43793.0,0.9861894249916076,0.0466781705617904,0.2716197370065153,43793.0,8656.387609004974,12474.14634013176,8656.387609004974,3815.66153049469,1.343522071838379,0.0 -27200,0.04663638,0.028563261,,,,,,,,,,,,,,,,, -27300,0.050960887,0.031227332,,,,,,,,,,,,,,,,, -27400,0.05276881,0.03086088,,,,,,,,,,,,,,,,, -27500,0.05059768,0.028083064,,,,,,,,,,,,,,,,, -27600,0.08629754,0.029227309,,,,,,,,,,,,,,,,, -27700,0.06276894,0.03160479,,,,,,,,,,,,,,,,, -27800,0.049501166,0.031959243,,,,,,,,,,,,,,,,, -27900,0.047186807,0.026410265,,,,,,,,,,,,,,,,, -27924,,,0.9912831783294678,0.0282191280275583,0.459710992591313,0.9870139360427856,0.0440753027796745,0.2801922874370537,43793.0,0.986270308494568,0.0467647463083267,0.2791487319363641,43793.0,8896.527443885803,12813.708843708038,8896.527443885803,3915.033325195313,1.374058961868286,0.0 -28000,0.055111438,0.030719137,,,,,,,,,,,,,,,,, -28100,0.053146668,0.033369433,,,,,,,,,,,,,,,,, -28200,0.053362854,0.02871244,,,,,,,,,,,,,,,,, -28300,0.054239947,0.025143031,,,,,,,,,,,,,,,,, -28400,0.06799235,0.030344725,,,,,,,,,,,,,,,,, -28500,0.05487769,0.030465709,,,,,,,,,,,,,,,,, -28600,0.07965666,0.035483,,,,,,,,,,,,,,,,, -28676,,,0.9914658069610596,0.0276060681790113,0.4804013392235434,0.9868357181549072,0.043917428702116,0.2829397828021821,43793.0,0.9860512614250184,0.0466743782162666,0.2696900439508503,43793.0,9136.534759044647,13159.920971632004,9136.534759044647,4021.181110858917,1.4105889797210691,0.0 -28700,0.049073126,0.028606413,,,,,,,,,,,,,,,,, -28800,0.046880744,0.027411921,,,,,,,,,,,,,,,,, -28900,0.054311566,0.0334329,,,,,,,,,,,,,,,,, -29000,0.055010073,0.03158235,,,,,,,,,,,,,,,,, -29100,0.05456161,0.029325655,,,,,,,,,,,,,,,,, -29200,0.053387027,0.030015577,,,,,,,,,,,,,,,,, -29300,0.056796666,0.03165876,,,,,,,,,,,,,,,,, -29400,0.06203795,0.032352805,,,,,,,,,,,,,,,,, -29429,,,0.9914921522140504,0.0274810884147882,0.4786408746212388,0.98698753118515,0.0441337078809738,0.280014814493845,43793.0,0.9862012267112732,0.0470006801187992,0.27495089302911,43793.0,9376.566728830338,13502.746740341188,9376.566728830338,4123.922771692276,1.4421751499176023,0.0 -29500,0.06175399,0.0309114,,,,,,,,,,,,,,,,, -29600,0.05023576,0.030346619,,,,,,,,,,,,,,,,, -29700,0.04703344,0.028814392,,,,,,,,,,,,,,,,, -29800,0.055732466,0.030166822,,,,,,,,,,,,,,,,, -29900,0.06934131,0.02928124,,,,,,,,,,,,,,,,, -30000,0.07280761,0.033540837,,,,,,,,,,,,,,,,, -30100,0.06547518,0.029994246,,,,,,,,,,,,,,,,, -30193,,,0.9915912747383118,0.0271249692887067,0.4820567404347408,0.9869375824928284,0.0437896288931369,0.2850369521914939,43793.0,0.9862092137336732,0.0464126951992511,0.2754781499849553,43793.0,9616.63538479805,13843.735455036163,9616.63538479805,4224.7899651527405,1.474400281906128,0.0 -30200,0.091807246,0.0266629,,,,,,,,,,,,,,,,, -30300,0.06301159,0.03002876,,,,,,,,,,,,,,,,, -30400,0.07728087,0.028245581,,,,,,,,,,,,,,,,, -30500,0.06161957,0.026492566,,,,,,,,,,,,,,,,, -30600,0.06475015,0.0318001,,,,,,,,,,,,,,,,, -30700,0.049901295,0.026699401,,,,,,,,,,,,,,,,, -30800,0.056401618,0.03233684,,,,,,,,,,,,,,,,, -30900,0.062386125,0.028742021,,,,,,,,,,,,,,,,, -30960,,,0.9916066527366638,0.0269829258322715,0.4897253311971881,0.987052083015442,0.0439271628856658,0.2813805981787919,43793.0,0.9862736463546752,0.0467136912047863,0.2736675093336144,43793.0,9856.632263422012,14187.458263158798,9856.632263422012,4328.463456869125,1.5062589645385742,0.0 -31000,0.063775346,0.029166613,,,,,,,,,,,,,,,,, -31100,0.07039522,0.031228177,,,,,,,,,,,,,,,,, -31200,0.06365692,0.029237537,,,,,,,,,,,,,,,,, -31300,0.057748288,0.028233917,,,,,,,,,,,,,,,,, -31400,0.06679183,0.028153587,,,,,,,,,,,,,,,,, -31500,0.060273338,0.026799494,,,,,,,,,,,,,,,,, -31600,0.064926475,0.030060047,,,,,,,,,,,,,,,,, -31700,0.056130044,0.029774537,,,,,,,,,,,,,,,,, -31724,,,0.991977870464325,0.025896318256855,0.4988183288868678,0.9869737029075624,0.0437305495142936,0.287502596920722,43793.0,0.9861574172973632,0.0463861934840679,0.2759647672414221,43793.0,10096.578315734863,14524.637803077698,10096.578315734863,4425.642622709274,1.5399868488311768,0.0 -31800,0.054418564,0.028243205,,,,,,,,,,,,,,,,, -31900,0.053867072,0.027491221,,,,,,,,,,,,,,,,, -32000,0.057209525,0.032556772,,,,,,,,,,,,,,,,, -32100,0.05379695,0.026088146,,,,,,,,,,,,,,,,, -32200,0.0523243,0.029510472,,,,,,,,,,,,,,,,, -32300,0.055360902,0.029350521,,,,,,,,,,,,,,,,, -32400,0.06158691,0.031285245,,,,,,,,,,,,,,,,, -32485,,,0.9921050071716307,0.0254569556564092,0.5302573368344676,0.986983060836792,0.0440465398132801,0.2828144092560208,43793.0,0.9862816333770752,0.0465779192745685,0.2800528015176483,43793.0,10336.690904140472,14867.57130074501,10336.690904140472,4528.410791397095,1.5721287727355957,0.0 -32500,0.05209392,0.02801238,,,,,,,,,,,,,,,,, -32600,0.067181274,0.031695552,,,,,,,,,,,,,,,,, -32700,0.07007427,0.027250879,,,,,,,,,,,,,,,,, -32800,0.059845217,0.029147316,,,,,,,,,,,,,,,,, -32900,0.06043786,0.027927341,,,,,,,,,,,,,,,,, -33000,0.053935967,0.027014326,,,,,,,,,,,,,,,,, -33100,0.05598169,0.030184785,,,,,,,,,,,,,,,,, -33200,0.06978069,0.029544987,,,,,,,,,,,,,,,,, -33248,,,0.99212908744812,0.0253685247153043,0.5184021366579336,0.9869611263275146,0.0442101545631885,0.2848071117613658,43793.0,0.986216366291046,0.0468509458005428,0.2793535608483301,43793.0,10576.917943954468,15205.792612075806,10576.917943954468,4626.351637125015,1.6052203178405762,0.0 -33300,0.0567337,0.028952068,,,,,,,,,,,,,,,,, -33400,0.06945728,0.030690059,,,,,,,,,,,,,,,,, -33500,0.057357974,0.027793022,,,,,,,,,,,,,,,,, -33600,0.06438409,0.026540691,,,,,,,,,,,,,,,,, -33700,0.052585326,0.025664344,,,,,,,,,,,,,,,,, -33800,0.06316133,0.026525648,,,,,,,,,,,,,,,,, -33900,0.06374389,0.030058613,,,,,,,,,,,,,,,,, -34000,0.08261108,0.030385956,,,,,,,,,,,,,,,,, -34005,,,0.9920830726623536,0.02560705691576,0.5206663689703317,0.9869534373283386,0.0438498817384243,0.2907716835350484,43793.0,0.9860959053039552,0.0467146001756191,0.2774428806202462,43793.0,10816.88122177124,15550.581326246262,10816.88122177124,4731.118203163147,1.6438887119293213,0.0 -34100,0.055312965,0.027157351,,,,,,,,,,,,,,,,, -34200,0.057043158,0.028081486,,,,,,,,,,,,,,,,, -34300,0.058511727,0.025832731,,,,,,,,,,,,,,,,, -34400,0.063097194,0.029552039,,,,,,,,,,,,,,,,, -34500,0.057919215,0.027189758,,,,,,,,,,,,,,,,, -34600,0.0488597,0.025121696,,,,,,,,,,,,,,,,, -34700,0.05825049,0.028511602,,,,,,,,,,,,,,,,, -34766,,,0.9918881058692932,0.0260873194783926,0.5010801180054504,0.9869778156280518,0.0440643094480037,0.2841008787928426,43793.0,0.9861780405044556,0.0469150133430957,0.2740014648998188,43793.0,11056.937723875046,15889.510733604431,11056.937723875046,4829.9352684021,1.679349660873413,0.0 -34800,0.07325739,0.029804146,,,,,,,,,,,,,,,,, -34900,0.066657774,0.029921388,,,,,,,,,,,,,,,,, -35000,0.054884247,0.028916886,,,,,,,,,,,,,,,,, -35100,0.063221656,0.028361727,,,,,,,,,,,,,,,,, -35200,0.062569454,0.030301802,,,,,,,,,,,,,,,,, -35300,0.06324074,0.028113218,,,,,,,,,,,,,,,,, -35400,0.07551968,0.030150376,,,,,,,,,,,,,,,,, -35500,0.08640131,0.034033075,,,,,,,,,,,,,,,,, -35525,,,0.991990566253662,0.0258852709084749,0.5093072745034118,0.9869050979614258,0.0439064614474773,0.287430191108912,43793.0,0.986100971698761,0.0467952117323875,0.2735042606028195,43793.0,11296.953873872755,16228.668662548063,11296.953873872755,4929.022217512131,1.7137835025787354,0.0 -35600,0.061959207,0.027035272,,,,,,,,,,,,,,,,, -35700,0.058769803,0.027930679,,,,,,,,,,,,,,,,, -35800,0.06265761,0.030466883,,,,,,,,,,,,,,,,, -35900,0.08310347,0.030223684,,,,,,,,,,,,,,,,, -36000,0.06291688,0.025857372,,,,,,,,,,,,,,,,, -36100,0.072065555,0.027961984,,,,,,,,,,,,,,,,, -36200,0.054615058,0.026954692,,,,,,,,,,,,,,,,, -36284,,,0.9918835759162904,0.0260820463299751,0.50537218004944,0.987015962600708,0.0443647243082523,0.2867905841474967,43793.0,0.9861708879470824,0.047328058630228,0.2739957375683177,43793.0,11537.133439779282,16574.356207609177,11537.133439779282,5034.475823879242,1.7479205131530762,0.0 -36300,0.06773881,0.03186379,,,,,,,,,,,,,,,,, -36400,0.057685602,0.026185816,,,,,,,,,,,,,,,,, -36500,0.08394299,0.031197244,,,,,,,,,,,,,,,,, -36600,0.06171966,0.030720487,,,,,,,,,,,,,,,,, -36700,0.057540428,0.03132042,,,,,,,,,,,,,,,,, -36800,0.068377115,0.027638385,,,,,,,,,,,,,,,,, -36900,0.059135545,0.023525095,,,,,,,,,,,,,,,,, -37000,0.056867503,0.028061073,,,,,,,,,,,,,,,,, -37041,,,0.991921603679657,0.0260071270167827,0.5088854546962834,0.9869713187217712,0.0440560989081859,0.2855104168089387,43793.0,0.9860870838165284,0.0468637980520725,0.2749342780922189,43793.0,11777.099576950071,16911.1993830204,11777.099576950071,5131.294198036194,1.7850227355957031,0.0 -37100,0.081625074,0.031504344,,,,,,,,,,,,,,,,, -37200,0.06654474,0.030461296,,,,,,,,,,,,,,,,, -37300,0.06719216,0.02765267,,,,,,,,,,,,,,,,, -37400,0.06263457,0.025866842,,,,,,,,,,,,,,,,, -37500,0.076494,0.028427573,,,,,,,,,,,,,,,,, -37600,0.061461758,0.0294307,,,,,,,,,,,,,,,,, -37700,0.08360696,0.029044965,,,,,,,,,,,,,,,,, -37792,,,0.9920625686645508,0.0254651494324207,0.5246759598596786,0.9870808720588684,0.0442389287054538,0.2901478151348224,43793.0,0.9861211776733398,0.0471453480422496,0.2790645129473144,43793.0,12017.13541841507,17256.310554504395,12017.13541841507,5236.3157749176025,1.818418979644776,0.0 -37800,0.062329415,0.024831166,,,,,,,,,,,,,,,,, -37900,0.06256646,0.028469326,,,,,,,,,,,,,,,,, -38000,0.072561085,0.030156326,,,,,,,,,,,,,,,,, -38100,0.071722515,0.027897868,,,,,,,,,,,,,,,,, -38200,0.07000978,0.03059865,,,,,,,,,,,,,,,,, -38300,0.06497229,0.026768602,,,,,,,,,,,,,,,,, -38400,0.057891738,0.030079814,,,,,,,,,,,,,,,,, -38500,0.08521105,0.026907315,,,,,,,,,,,,,,,,, -38539,,,0.992132604122162,0.025216331705451,0.5271417505934661,0.9870041608810424,0.0443718582391738,0.284068813649276,43793.0,0.9862340688705444,0.0469984784722328,0.2798594395019039,43793.0,12257.310943841934,17599.25080871582,12257.310943841934,5339.024819612503,1.8535473346710205,0.0 -38600,0.0654478,0.026376812,,,,,,,,,,,,,,,,, -38700,0.070408545,0.02946896,,,,,,,,,,,,,,,,, -38800,0.06414118,0.027700378,,,,,,,,,,,,,,,,, -38900,0.06677009,0.025461765,,,,,,,,,,,,,,,,, -39000,0.07795746,0.030359594,,,,,,,,,,,,,,,,, -39100,0.090928115,0.028922439,,,,,,,,,,,,,,,,, -39200,0.06659246,0.0285262,,,,,,,,,,,,,,,,, -39297,,,0.9923343062400818,0.0245491247624158,0.5338019086037505,0.987090229988098,0.0443548746407032,0.2878398362951603,43793.0,0.9863966703414916,0.0471530593931674,0.2812594940719487,43793.0,12497.441692590714,17939.355797052383,12497.441692590714,5438.942674875259,1.888882160186768,0.0 -39300,0.059641723,0.027654627,,,,,,,,,,,,,,,,, -39400,0.078811936,0.030174855,,,,,,,,,,,,,,,,, -39500,0.10355286,0.028010761,,,,,,,,,,,,,,,,, -39600,0.06621094,0.027052935,,,,,,,,,,,,,,,,, -39700,0.0725314,0.027251862,,,,,,,,,,,,,,,,, -39800,0.073500045,0.028922712,,,,,,,,,,,,,,,,, -39900,0.06465893,0.026393002,,,,,,,,,,,,,,,,, -40000,0.06627331,0.02554778,,,,,,,,,,,,,,,,, -40049,,,0.9923198223114014,0.0243196655064821,0.5546808189751687,0.9871109127998352,0.0443590618669986,0.2893713804221078,43793.0,0.9862483739852904,0.0473817698657512,0.2766928416127924,43793.0,12737.689799547195,18286.61421489716,12737.689799547195,5545.897831678391,1.9227867126464844,0.0 -40100,0.08370597,0.028492467,,,,,,,,,,,,,,,,, -40200,0.072739504,0.026484465,,,,,,,,,,,,,,,,, -40300,0.076494955,0.026024481,,,,,,,,,,,,,,,,, -40400,0.06322236,0.02690465,,,,,,,,,,,,,,,,, -40500,0.0758935,0.03039119,,,,,,,,,,,,,,,,, -40600,0.059944287,0.02356538,,,,,,,,,,,,,,,,, -40700,0.057689182,0.024486689,,,,,,,,,,,,,,,,, -40800,0.07111725,0.028085988,,,,,,,,,,,,,,,,, -40802,,,0.9927132725715636,0.0232907254248857,0.5798397030318193,0.987059772014618,0.0447507388889789,0.2860022023207019,43793.0,0.9862837791442872,0.0476432628929615,0.2762729613561593,43793.0,12977.942219495771,18627.68426156044,12977.942219495771,5646.659322977066,1.9570858478546145,0.0 -40900,0.08653608,0.027990159,,,,,,,,,,,,,,,,, -41000,0.0695601,0.025054356,,,,,,,,,,,,,,,,, -41100,0.07643107,0.029647773,,,,,,,,,,,,,,,,, -41200,0.060486462,0.028078604,,,,,,,,,,,,,,,,, -41300,0.088883385,0.027098257,,,,,,,,,,,,,,,,, -41400,0.06633561,0.029338717,,,,,,,,,,,,,,,,, -41500,0.06276861,0.024279656,,,,,,,,,,,,,,,,, -41565,,,0.9927103519439696,0.023197915405035,0.5746574374680143,0.9871835708618164,0.0446155592799186,0.28433454948306,43793.0,0.9863145351409912,0.0475908629596233,0.2851083116733161,43793.0,13217.97829079628,18978.81697440148,13217.97829079628,5757.701861858368,1.9907326698303225,0.0 -41600,0.08252718,0.026941137,,,,,,,,,,,,,,,,, -41700,0.0700563,0.028393658,,,,,,,,,,,,,,,,, -41800,0.08671772,0.030325452,,,,,,,,,,,,,,,,, -41900,0.06695008,0.027689096,,,,,,,,,,,,,,,,, -42000,0.068257704,0.024686625,,,,,,,,,,,,,,,,, -42100,0.070775025,0.027267365,,,,,,,,,,,,,,,,, -42200,0.07354007,0.02489987,,,,,,,,,,,,,,,,, -42300,0.07989872,0.026464067,,,,,,,,,,,,,,,,, -42321,,,0.9925374984741212,0.0237526260316371,0.5571466336294107,0.9870622158050536,0.0449502542614936,0.2945649533936957,43793.0,0.9861207604408264,0.0483300276100635,0.2736255863559124,43793.0,13458.056258440018,19331.04762125016,13458.056258440018,5869.774827003479,2.050032138824463,0.0 -42400,0.08916412,0.027177574,,,,,,,,,,,,,,,,, -42500,0.07413914,0.02422952,,,,,,,,,,,,,,,,, -42600,0.064481884,0.026293164,,,,,,,,,,,,,,,,, -42700,0.06820167,0.02902102,,,,,,,,,,,,,,,,, -42800,0.0734001,0.02597504,,,,,,,,,,,,,,,,, -42900,0.072716616,0.027326535,,,,,,,,,,,,,,,,, -43000,0.06808147,0.026160695,,,,,,,,,,,,,,,,, -43067,,,0.9924070835113524,0.0241846218705177,0.5549188476915506,0.9869388341903688,0.0447618216276168,0.291296804280545,43793.0,0.9861182570457458,0.0477258823812007,0.2800975663738683,43793.0,13698.089906215668,19671.93588352204,13698.089906215668,5970.566066265106,2.092049360275269,0.0 -43100,0.07319195,0.023968488,,,,,,,,,,,,,,,,, -43200,0.06874672,0.025319949,,,,,,,,,,,,,,,,, -43300,0.07940838,0.027706396,,,,,,,,,,,,,,,,, -43400,0.09794385,0.025914904,,,,,,,,,,,,,,,,, -43500,0.08099219,0.02693615,,,,,,,,,,,,,,,,, -43600,0.06365985,0.022458715,,,,,,,,,,,,,,,,, -43700,0.07680281,0.026436968,,,,,,,,,,,,,,,,, -43800,0.08038449,0.028185422,,,,,,,,,,,,,,,,, -43822,,,0.9925096035003662,0.023991784080863,0.5581068188704152,0.987064242362976,0.0446946024894714,0.2883423968673922,43793.0,0.9862075448036194,0.0476227030158042,0.2781757933964989,43793.0,13938.330757379532,20014.78182053566,13938.330757379532,6073.116245508194,2.1265196800231934,0.0 -43900,0.07236278,0.024324605,,,,,,,,,,,,,,,,, -44000,0.07696525,0.028375426,,,,,,,,,,,,,,,,, -44100,0.0674788,0.025891358,,,,,,,,,,,,,,,,, -44200,0.08164332,0.030718438,,,,,,,,,,,,,,,,, -44300,0.10512776,0.028828006,,,,,,,,,,,,,,,,, -44400,0.08425749,0.025955949,,,,,,,,,,,,,,,,, -44500,0.07281524,0.02501946,,,,,,,,,,,,,,,,, -44573,,,0.9925292134284972,0.0238912533968687,0.5533522608731953,0.9870536923408508,0.0445699468255043,0.2946664919696532,43793.0,0.9861910939216614,0.0475999042391777,0.2807915067602074,43793.0,14178.45221710205,20357.951191186905,14178.45221710205,6176.107170343399,2.1627137660980225,0.0 -44600,0.068729654,0.026126932,,,,,,,,,,,,,,,,, -44700,0.07570584,0.023895673,,,,,,,,,,,,,,,,, -44800,0.0692735,0.022533206,,,,,,,,,,,,,,,,, -44900,0.07900559,0.026709497,,,,,,,,,,,,,,,,, -45000,0.090938255,0.026590932,,,,,,,,,,,,,,,,, -45100,0.08472679,0.027125223,,,,,,,,,,,,,,,,, -45200,0.06579418,0.027465086,,,,,,,,,,,,,,,,, -45300,0.07220897,0.023997933,,,,,,,,,,,,,,,,, -45328,,,0.9924713969230652,0.0238328706473112,0.5691807750617297,0.9870431423187256,0.0448387153446674,0.2838232387063066,43793.0,0.9862349033355712,0.0475789941847324,0.2797156681465258,43793.0,14418.65455508232,20695.19634771347,14418.65455508232,6273.095078229904,2.1975855827331543,0.0 -45400,0.081673495,0.025229715,,,,,,,,,,,,,,,,, -45500,0.0697843,0.025634985,,,,,,,,,,,,,,,,, -45600,0.09300834,0.028550487,,,,,,,,,,,,,,,,, -45700,0.07416688,0.024003876,,,,,,,,,,,,,,,,, -45800,0.09177756,0.02926047,,,,,,,,,,,,,,,,, -45900,0.08417577,0.030413477,,,,,,,,,,,,,,,,, -46000,0.07487623,0.026346574,,,,,,,,,,,,,,,,, -46094,,,0.9926730990409852,0.0232164915651083,0.5673038654986291,0.9871312379837036,0.0448323711752891,0.2927768764638855,43793.0,0.9862656593322754,0.0479391925036907,0.2842984568305961,43793.0,14658.61558651924,21041.355713129044,14658.61558651924,6379.238595485687,2.2321431636810303,0.0 -46100,0.09123207,0.024221739,,,,,,,,,,,,,,,,, -46200,0.07930575,0.025388632,,,,,,,,,,,,,,,,, -46300,0.08670111,0.027544247,,,,,,,,,,,,,,,,, -46400,0.08664011,0.023257518,,,,,,,,,,,,,,,,, -46500,0.08093656,0.023030091,,,,,,,,,,,,,,,,, -46600,0.09070646,0.027731592,,,,,,,,,,,,,,,,, -46700,0.06881768,0.026062751,,,,,,,,,,,,,,,,, -46800,0.115313515,0.027364843,,,,,,,,,,,,,,,,, -46841,,,0.9926999807357788,0.0229254197329282,0.5850914346518163,0.9871032238006592,0.0455067902803421,0.2930319312146007,43793.0,0.9862319827079772,0.0485468283295631,0.2778092930415843,43793.0,14898.701581716536,21383.17679190636,14898.701581716536,6480.91828584671,2.2671091556549072,0.0 -46900,0.071413815,0.024504215,,,,,,,,,,,,,,,,, -47000,0.07915052,0.025566738,,,,,,,,,,,,,,,,, -47100,0.079288974,0.027144527,,,,,,,,,,,,,,,,, -47200,0.073618345,0.02407895,,,,,,,,,,,,,,,,, -47300,0.08320059,0.026586736,,,,,,,,,,,,,,,,, -47400,0.09341573,0.026212027,,,,,,,,,,,,,,,,, -47500,0.090505585,0.02543989,,,,,,,,,,,,,,,,, -47593,,,0.992985725402832,0.0222699958831071,0.5896478412213515,0.9871113300323486,0.0449162758886814,0.2926934946306269,43793.0,0.9862273335456848,0.04789400100708,0.2869319142049717,43793.0,15138.667674064636,21725.787615060806,15138.667674064636,6583.505940437317,2.3025686740875244,0.0 -47600,0.08460386,0.025183583,,,,,,,,,,,,,,,,, -47700,0.08929211,0.028993908,,,,,,,,,,,,,,,,, -47800,0.07661527,0.025663795,,,,,,,,,,,,,,,,, -47900,0.083696544,0.025621632,,,,,,,,,,,,,,,,, -48000,0.07975831,0.025389304,,,,,,,,,,,,,,,,, -48100,0.06495157,0.020963479,,,,,,,,,,,,,,,,, -48200,0.0739425,0.023127738,,,,,,,,,,,,,,,,, -48300,0.06847539,0.025238412,,,,,,,,,,,,,,,,, -48351,,,0.9932321310043336,0.0214752256870269,0.6224910459932752,0.987035036087036,0.0448025874793529,0.2923386664862276,43793.0,0.986178457736969,0.0477825812995433,0.2834298438452727,43793.0,15378.75329375267,22065.98137497902,15378.75329375267,6683.557134151459,2.3388681411743164,0.0 -48400,0.08255899,0.029060328,,,,,,,,,,,,,,,,, -48500,0.07784767,0.023945248,,,,,,,,,,,,,,,,, -48600,0.08843362,0.024695132,,,,,,,,,,,,,,,,, -48700,0.101069376,0.026536487,,,,,,,,,,,,,,,,, -48800,0.06921427,0.023445012,,,,,,,,,,,,,,,,, -48900,0.090409584,0.02292265,,,,,,,,,,,,,,,,, -49000,0.10174254,0.027932303,,,,,,,,,,,,,,,,, -49100,0.10775705,0.026568828,,,,,,,,,,,,,,,,, -49109,,,0.9933977723121644,0.0210042484104633,0.6172502556198101,0.9870760440826416,0.0456252992153167,0.2926285593021795,43793.0,0.9862424731254578,0.0486608073115348,0.2857276248478931,43793.0,15618.89215707779,22406.23919153213,15618.89215707779,6783.601722478867,2.3929810523986816,0.0 -49200,0.079310514,0.024223728,,,,,,,,,,,,,,,,, -49300,0.089584015,0.024708724,,,,,,,,,,,,,,,,, -49400,0.08479188,0.02534107,,,,,,,,,,,,,,,,, -49500,0.096405655,0.023497734,,,,,,,,,,,,,,,,, -49600,0.08865409,0.025854357,,,,,,,,,,,,,,,,, -49700,0.070915,0.022806754,,,,,,,,,,,,,,,,, -49800,0.086990185,0.025256732,,,,,,,,,,,,,,,,, -49864,,,0.9931698441505432,0.0215233471244573,0.6213504048826404,0.9871621131896972,0.0453357174992561,0.2962699552477201,43793.0,0.9862942695617676,0.0483603700995445,0.2828741073363153,43793.0,15859.115253686905,22747.032354831696,15859.115253686905,6884.114263057709,2.4300451278686523,0.0 -49900,0.081954375,0.02364034,,,,,,,,,,,,,,,,, -50000,0.09395557,0.023003351,,,,,,,,,,,,,,,,, -50100,0.0988334,0.023475496,,,,,,,,,,,,,,,,, -50200,0.086227275,0.026795015,,,,,,,,,,,,,,,,, -50300,0.08578126,0.024055444,,,,,,,,,,,,,,,,, -50400,0.09129675,0.024809826,,,,,,,,,,,,,,,,, -50500,0.08721392,0.025873598,,,,,,,,,,,,,,,,, -50600,0.08550501,0.021300793,,,,,,,,,,,,,,,,, -50621,,,0.9931451082229614,0.0217219870537519,0.5996855145951254,0.9870354533195496,0.0454207174479961,0.2876952709587337,43793.0,0.9862441420555116,0.0485006384551525,0.2809669364114627,43793.0,16099.191632032394,23090.439814329147,16099.191632032394,6987.389058351517,2.4660558700561523,0.0 -50700,0.09064903,0.02485538,,,,,,,,,,,,,,,,, -50800,0.09463633,0.027086657,,,,,,,,,,,,,,,,, -50900,0.082903504,0.023467848,,,,,,,,,,,,,,,,, -51000,0.08246872,0.023780564,,,,,,,,,,,,,,,,, -51100,0.08745383,0.025300208,,,,,,,,,,,,,,,,, -51200,0.08141431,0.02330799,,,,,,,,,,,,,,,,, -51300,0.100755975,0.027055666,,,,,,,,,,,,,,,,, -51379,,,0.9931361675262452,0.0217958111315965,0.6021560760301627,0.9870212078094482,0.0455244444310665,0.2931287192061453,43793.0,0.9861915111541748,0.0485584177076816,0.2790803063476967,43793.0,16339.3504114151,23436.85264754296,16339.3504114151,7093.5865795612335,2.502041339874268,0.0 -51400,0.096465304,0.02565124,,,,,,,,,,,,,,,,, -51500,0.08637915,0.026250629,,,,,,,,,,,,,,,,, -51600,0.0928488,0.021891093,,,,,,,,,,,,,,,,, -51700,0.09115295,0.024277372,,,,,,,,,,,,,,,,, -51800,0.10868805,0.027768796,,,,,,,,,,,,,,,,, -51900,0.116655834,0.027571648,,,,,,,,,,,,,,,,, -52000,0.10489863,0.026764655,,,,,,,,,,,,,,,,, -52100,0.09013112,0.027140176,,,,,,,,,,,,,,,,, -52138,,,0.9931796193122864,0.0215019863098859,0.6067070557150565,0.9871442317962646,0.0456354469060897,0.293471528182527,43793.0,0.98629891872406,0.0488433130085468,0.2831817590489205,43793.0,16579.44235610962,23781.13451218605,16579.44235610962,7197.719510793686,2.5390548706054688,0.0 -52200,0.07726332,0.022702748,,,,,,,,,,,,,,,,, -52300,0.085929364,0.021889783,,,,,,,,,,,,,,,,, -52400,0.08405907,0.023279496,,,,,,,,,,,,,,,,, -52500,0.09405532,0.023276368,,,,,,,,,,,,,,,,, -52600,0.10152245,0.027193421,,,,,,,,,,,,,,,,, -52700,0.10373026,0.025759578,,,,,,,,,,,,,,,,, -52800,0.09524183,0.023811733,,,,,,,,,,,,,,,,, -52896,,,0.9931758046150208,0.0214325115084648,0.6115915335475446,0.9871913194656372,0.0455203764140605,0.3000649690933642,43793.0,0.9863018989562988,0.0489050596952438,0.2820189862990282,43793.0,16819.4209959507,24124.67778921128,16819.4209959507,7301.227471590042,2.5754523277282715,0.0 -52900,0.08853965,0.02494494,,,,,,,,,,,,,,,,, -53000,0.090275384,0.022784926,,,,,,,,,,,,,,,,, -53100,0.09232981,0.021585623,,,,,,,,,,,,,,,,, -53200,0.09341097,0.024624474,,,,,,,,,,,,,,,,, -53300,0.11928396,0.023561826,,,,,,,,,,,,,,,,, -53400,0.0944908,0.019779244,,,,,,,,,,,,,,,,, -53500,0.095577024,0.02626709,,,,,,,,,,,,,,,,, -53600,0.094480895,0.024952328,,,,,,,,,,,,,,,,, -53651,,,0.9933634400367736,0.0208130627870559,0.6317767925373252,0.987166166305542,0.045723769813776,0.2988973768460285,43793.0,0.9862803816795348,0.0489753000438213,0.2802482233155609,43793.0,17059.447257995605,24464.625368595123,17059.447257995605,7401.090955257416,2.6128950119018555,0.0 -53700,0.12395328,0.023063796,,,,,,,,,,,,,,,,, -53800,0.10687474,0.025171671,,,,,,,,,,,,,,,,, -53900,0.0856903,0.022762483,,,,,,,,,,,,,,,,, -54000,0.13391495,0.0245591,,,,,,,,,,,,,,,,, -54100,0.101595275,0.022820491,,,,,,,,,,,,,,,,, -54200,0.09371125,0.022980131,,,,,,,,,,,,,,,,, -54300,0.089291506,0.023369715,,,,,,,,,,,,,,,,, -54400,0.08972549,0.02344101,,,,,,,,,,,,,,,,, -54403,,,0.9933775663375854,0.0207133274525403,0.6225736436664356,0.987188458442688,0.0462249219417572,0.296826822781387,43793.0,0.986240804195404,0.0494866259396076,0.2819279003808288,43793.0,17299.483036756516,24806.26559662819,17299.483036756516,7502.637751102447,2.65012526512146,0.0 -54500,0.1058977,0.023521516,,,,,,,,,,,,,,,,, -54600,0.097428165,0.024493275,,,,,,,,,,,,,,,,, -54700,0.09385392,0.024857476,,,,,,,,,,,,,,,,, -54800,0.10142763,0.024285099,,,,,,,,,,,,,,,,, -54900,0.095716275,0.022053761,,,,,,,,,,,,,,,,, -55000,0.10791269,0.022855991,,,,,,,,,,,,,,,,, -55100,0.096261054,0.022911914,,,,,,,,,,,,,,,,, -55156,,,0.993545413017273,0.0201831441372632,0.645784824725203,0.9871426224708556,0.0462013557553291,0.2940088533353942,43793.0,0.9862689971923828,0.0496369414031505,0.2751525052977334,43793.0,17539.461524248123,25146.24363541603,17539.461524248123,7602.579082250595,2.6879782676696777,0.0 -55200,0.10437034,0.024215894,,,,,,,,,,,,,,,,, -55300,0.08331918,0.02140507,,,,,,,,,,,,,,,,, -55400,0.11721713,0.023161491,,,,,,,,,,,,,,,,, -55500,0.102600165,0.024146266,,,,,,,,,,,,,,,,, -55600,0.09237011,0.021059819,,,,,,,,,,,,,,,,, -55700,0.09756423,0.023709608,,,,,,,,,,,,,,,,, -55800,0.08670118,0.02160031,,,,,,,,,,,,,,,,, -55900,0.12431345,0.02174129,,,,,,,,,,,,,,,,, -55921,,,0.9938235282897948,0.0194733310490846,0.6602483439614082,0.9870220422744752,0.0461726337671279,0.2962127262744353,43793.0,0.9861751198768616,0.0495301969349384,0.278072814459413,43793.0,17779.411629915237,25485.167813539505,17779.411629915237,7701.495981454849,2.7243547439575195,0.0 -56000,0.106827684,0.023299681,,,,,,,,,,,,,,,,, -56100,0.09909318,0.020482754,,,,,,,,,,,,,,,,, -56200,0.09812186,0.02357989,,,,,,,,,,,,,,,,, -56300,0.103073955,0.022186702,,,,,,,,,,,,,,,,, -56400,0.09603632,0.018850641,,,,,,,,,,,,,,,,, -56500,0.086646564,0.021573402,,,,,,,,,,,,,,,,, -56600,0.102402955,0.021717718,,,,,,,,,,,,,,,,, -56685,,,0.993970274925232,0.0188592132180929,0.676547356716646,0.9871426224708556,0.0465945266187191,0.2926747635568142,43793.0,0.9863424897193908,0.0498776361346244,0.2806236785910661,43793.0,18019.554752588272,25831.152674913406,18019.554752588272,7807.280650138855,2.760650157928467,0.0 -56700,0.08339809,0.021822877,,,,,,,,,,,,,,,,, -56800,0.12035506,0.025276601,,,,,,,,,,,,,,,,, -56900,0.116487116,0.022199681,,,,,,,,,,,,,,,,, -57000,0.10901129,0.023462735,,,,,,,,,,,,,,,,, -57100,0.10310051,0.02296927,,,,,,,,,,,,,,,,, -57200,0.12430808,0.02435384,,,,,,,,,,,,,,,,, -57300,0.110004224,0.02326603,,,,,,,,,,,,,,,,, -57400,0.11252246,0.02230067,,,,,,,,,,,,,,,,, -57443,,,0.9942606687545776,0.0182267967611551,0.6956112092775892,0.987111747264862,0.0464325621724128,0.2988261302469323,43793.0,0.986276626586914,0.0495757274329662,0.2852527176577683,43793.0,18259.80503797531,26173.50976252556,18259.80503797531,7909.328285217285,2.799261569976806,0.0 -57500,0.10732792,0.02207112,,,,,,,,,,,,,,,,, -57600,0.12207912,0.025582353,,,,,,,,,,,,,,,,, -57700,0.10698294,0.021058252,,,,,,,,,,,,,,,,, -57800,0.10955512,0.023332793,,,,,,,,,,,,,,,,, -57900,0.10741173,0.021012478,,,,,,,,,,,,,,,,, -58000,0.100100346,0.022266753,,,,,,,,,,,,,,,,, -58100,0.11555591,0.021856258,,,,,,,,,,,,,,,,, -58134,,,,,,,,,,,,,,18477.266256332397,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 3e1e3e3e6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -894.9750609397888,0.0,37.0385491847992,1,0,37.0385491847992,0.0007088489946909,0.0,11.19086742401123,3003,932.0136518478394,0.00065292842919,0.0,11.176665306091309,0.0004835649742744,0.0,11.208685874938965,3000 -1663.1117770671844,0.0305879116058349,877.1155626773834,2382,0,877.1155626773834,0.378374308347702,7.77435242429996,4.347209453582764,3003,2540.3347775936127,0.4112793505191803,14.314785399828414,4.001132965087891,0.3958165347576141,9.577254281286269,4.139353752136231,3000 -2150.2625029087067,0.0555925369262695,1717.0735466480255,4762,0,1717.0735466480255,0.5421649217605591,18.93392274615353,2.766301155090332,3003,3867.541526556015,0.5386283993721008,24.47574218709137,2.7876904010772705,0.5418035387992859,20.31710448845741,2.7314565181732178,3000 -2623.3933651447296,0.0818729400634765,2557.290193080902,7143,0,2557.290193080902,0.588565468788147,21.92302229850635,2.322375535964966,3003,5180.989482164383,0.5830803513526917,27.624996169924223,2.34594202041626,0.5854235887527466,23.29588649527409,2.320286512374878,3000 -3093.3391761779785,0.1232635974884033,3397.41008734703,9526,0,3397.41008734703,0.6118645071983337,23.48735964868124,2.1073811054229736,3003,6491.173132181168,0.5969531536102295,28.03945898046316,2.222651481628418,0.6077544093132019,24.906629328826103,2.134187936782837,3000 -3538.7417256832123,0.15330171585083,4237.438687324524,11909,0,4237.438687324524,0.6260995864868164,24.403812250497207,1.976151943206787,3003,7776.712178230286,0.6016323566436768,28.773179468343937,2.1511049270629883,0.621133029460907,25.84807362168108,2.0083134174346924,3000 -4019.282205581665,0.1813933849334716,5077.634784460068,14294,0,5077.634784460068,0.6379757523536682,25.565374635606645,1.875356078147888,3003,9097.55193066597,0.6140919923782349,29.75643167533389,2.048624277114868,0.6318830251693726,26.592895269365595,1.923273563385009,3000 -4478.662546873093,0.2098925113677978,5917.868166685104,16679,0,5917.868166685104,0.6479925513267517,25.933641625620712,1.8052959442138672,3003,10397.26683807373,0.6197834014892578,29.712456080288373,2.003021001815796,0.6387149691581726,26.800316120447345,1.858880519866944,3000 -5263.736715555191,0.236696720123291,6757.979300022125,19064,0,6757.979300022125,0.652710497379303,26.45541676211269,1.761230707168579,3003,12022.555188179016,0.6391572952270508,31.17832988950659,1.85425329208374,0.6441581845283508,27.1275590231522,1.8158868551254272,3000 -5778.866499423981,0.2639615535736084,7598.178004980087,21450,0,7598.178004980087,0.6565220355987549,26.60739545980313,1.7380309104919434,3003,13377.98326063156,0.6265982985496521,30.379573470875417,1.9437370300292969,0.6476795077323914,27.4281157734518,1.7847602367401123,3000 -6283.715788841248,0.291827917098999,8438.228722810745,23836,0,8438.228722810745,0.6588460803031921,26.95805362050392,1.7223812341690063,3003,14722.985248565674,0.6263420581817627,30.605509374050456,1.9479836225509644,0.6493533849716187,27.501190326452026,1.7742805480957031,3000 -6752.686856031418,0.3216967582702636,9278.243922948835,26221,0,9278.243922948835,0.665864884853363,27.71821619195752,1.6981230974197388,3003,16032.076689481735,0.6345329880714417,31.028171095176663,1.8784587383270264,0.65301114320755,27.96881610127737,1.7509132623672483,3000 -7214.063798904419,0.3504965305328369,10118.190644741058,28607,0,10118.190644741058,0.6636918187141418,26.97858312778024,1.6925480365753174,3003,17333.501986265182,0.6294631361961365,30.843768750457446,1.9184149503707888,0.6532467007637024,27.84541869955221,1.749211311340332,3000 -7759.080738306045,0.3789188861846924,10958.382172107697,30994,0,10958.382172107697,0.6665040254592896,27.64029504672969,1.6737943887710571,3003,18718.81236767769,0.6348527669906616,30.79260266740877,1.892822265625,0.6570532321929932,28.399133611509708,1.729580998420715,3000 -8410.495576143265,0.4079523086547851,11798.445326805117,33381,0,11798.445326805117,0.667026937007904,27.45520512735884,1.6664468050003052,3003,20210.39287185669,0.6381608843803406,30.97019323236983,1.8649847507476809,0.6565820574760437,28.203732298008507,1.7204769849777222,3000 -8903.872762203217,0.4422931671142578,12638.43499302864,35767,0,12638.43499302864,0.6706873774528503,27.893507040265977,1.6476588249206543,3003,21543.868954896927,0.6377241015434265,31.08551284303532,1.875410556793213,0.6565572619438171,28.019113346418383,1.71941077709198,3000 -9381.595716238022,0.4726207256317138,13478.334066152573,38152,0,13478.334066152573,0.6716054081916809,27.802578131818603,1.6476448774337769,3003,22861.5990588665,0.6488719582557678,31.61389258230864,1.7934083938598633,0.6610581278800964,28.280349091822146,1.70951247215271,3000 -9902.135681152344,0.5022103786468506,14318.342139482498,40538,0,14318.342139482498,0.6716402173042297,27.64028559420224,1.6384211778640747,3003,24222.2512857914,0.6424659490585327,31.08091692572428,1.8437772989273071,0.6593222618103027,28.3158290660357,1.7040035724639893,3000 -10383.810242652891,0.5315215587615967,15158.377641677856,42925,0,15158.377641677856,0.6723374724388123,28.000836489908377,1.6351051330566406,3003,25544.065200567245,0.6389986276626587,31.192170592865967,1.856657862663269,0.6593842506408691,28.26113223855552,1.7011802196502686,3000 -10852.916440725328,0.5626571178436279,15998.422271966934,45312,0,15998.422271966934,0.6726396083831787,27.586332945086863,1.62099027633667,3003,26853.32121229172,0.6493979692459106,31.863872415061568,1.7883561849594116,0.6628807783126831,28.467148185414068,1.6845715045928955,3000 -11316.14895439148,0.5931167602539062,16838.640387535095,47700,0,16838.640387535095,0.6753239631652832,28.22375180587874,1.6231544017791748,3003,28156.87495303154,0.641980767250061,31.306511234966308,1.836068153381348,0.6625956296920776,28.62275325157356,1.6862725019454956,3000 -11796.412520170212,0.6238071918487549,17678.58997631073,50087,0,17678.58997631073,0.6756957769393921,28.095035425482383,1.6147881746292114,3003,29477.193959236145,0.691525399684906,35.324233093710475,1.5449894666671753,0.6646042466163635,28.461571849251246,1.6788150072097778,3000 -12272.941549777985,0.6620500087738037,18518.520733118057,52474,0,18518.520733118057,0.6764511466026306,28.30420204851027,1.6009951829910278,3003,30793.76582384109,0.648737370967865,31.28400718143166,1.7925587892532349,0.6651870608329773,28.830817072279498,1.667005181312561,3000 -12735.091276407242,0.6979632377624512,19358.474761724472,54860,0,19358.474761724472,0.6795421838760376,28.502942846220005,1.5950560569763184,3003,32095.981645822525,0.6473584175109863,31.921184654644467,1.8072962760925293,0.6671088933944702,29.0922268005727,1.663117289543152,3000 -13202.08388352394,0.7302701473236084,20198.645827054977,57247,0,20198.645827054977,0.6824705004692078,28.750735001154936,1.5814528465270996,3003,33403.252178907394,0.6565492749214172,32.6825347018119,1.7392457723617554,0.6668609380722046,29.00122112780653,1.654740810394287,3000 -13678.836746692656,0.7625217437744141,21038.74125480652,59634,0,21038.74125480652,0.6804021000862122,28.77173907460039,1.5773099660873413,3003,34720.2066423893,0.6492637395858765,31.7395194006414,1.7962538003921509,0.6694399118423462,28.93402112722786,1.6482961177825928,3000 -14165.701899528503,0.7960004806518555,21878.683776140213,62021,0,21878.683776140213,0.6833536624908447,28.95119878578508,1.5698957443237305,3003,36047.12191772461,0.6487945318222046,32.09413766752779,1.79489004611969,0.6691547632217407,28.96760031948729,1.6473053693771362,3000 -14644.784934043884,0.8291144371032715,22718.78689146042,64408,0,22718.78689146042,0.6835861206054688,29.056131315412504,1.5647251605987549,3003,37366.41644477844,0.6550009250640869,32.452760515426576,1.7537881135940552,0.6708534359931946,28.91921608404348,1.635170578956604,3000 -15124.934435129166,0.8623538017272949,23558.88369011879,66795,0,23558.88369011879,0.6852013468742371,29.178349712756702,1.5587241649627686,3003,38686.76952624321,0.6545759439468384,32.203787090257514,1.7578319311141968,0.6715105772018433,29.41297111527696,1.635469675064087,3000 -15660.699274778366,0.9025025367736816,24398.984651088715,69181,0,24398.984651088715,0.6837836503982544,28.954724806981417,1.5548111200332642,3003,40062.75036764145,0.6702680587768555,33.066980321510584,1.6583515405654907,0.6727132797241211,29.20222921314499,1.62405264377594,3000 -16190.815400600432,0.9364674091339112,25239.0883140564,71568,0,25239.0883140564,0.6880018711090088,28.961093312146044,1.5388880968093872,3003,41433.078765153885,0.655456006526947,32.35903194112188,1.7491358518600464,0.6749451160430908,29.35684438996192,1.6112587451934814,3000 -16826.95787167549,0.9769396781921388,26079.173448324203,73955,0,26079.173448324203,0.6869211792945862,29.29112343334928,1.5449918508529663,3003,42909.42013859749,0.6537445783615112,32.23235903214701,1.7614141702651978,0.6735440492630005,28.445815162064275,1.6155986785888672,3000 -17308.0654463768,1.0143301486968994,26919.26244521141,76340,0,26919.26244521141,0.6900935769081116,29.379224504547373,1.5276169776916504,3003,44230.73479604721,0.6645991206169128,32.924616682714294,1.6866929531097412,0.67549067735672,29.642798192788497,1.6071454286575315,3000 -17798.384374141693,1.0526585578918457,27759.4205942154,78727,0,27759.4205942154,0.6900470852851868,29.493238645573108,1.5211902856826782,3003,45561.32496523857,0.6578866243362427,33.02379867229866,1.7336337566375732,0.6777225136756897,29.78239122588757,1.597773551940918,3000 -18273.101594686508,1.0890939235687256,28599.610867261887,81114,0,28599.610867261887,0.6944047808647156,29.8250633065196,1.50739848613739,3003,46876.34319114685,0.6640418171882629,32.88214843273441,1.6978520154953003,0.6789872050285339,29.676987911310327,1.593400001525879,3000 -18761.78837966919,1.1252844333648682,29440.040162086487,83501,0,29440.040162086487,0.6937307715415955,29.50003298374016,1.5014517307281494,3003,48205.56838059425,0.6630203723907471,32.74125649796642,1.7007750272750854,0.6791360378265381,29.65465038029376,1.5790903568267822,3000 -19246.59332036972,1.1620018482208252,30280.26576089859,85888,0,30280.26576089859,0.6941258907318115,29.87476830317458,1.5003278255462646,3003,49530.71009230614,0.6633317470550537,32.76356242688841,1.6965419054031372,0.6802891492843628,30.043155582785047,1.5811126232147217,3000 -19744.70170378685,1.199810266494751,31120.169873714447,88274,0,31120.169873714447,0.6963802576065063,29.658199063669382,1.4937410354614258,3003,50868.83356237412,0.682263195514679,34.32433497901006,1.5941760540008545,0.6819010376930237,29.823944589013955,1.5699979066848757,3000 -20236.372787475582,1.2370805740356443,31960.32098174095,90660,0,31960.32098174095,0.6983673572540283,30.366281457229302,1.478404879570007,3003,52200.76883912087,0.6701759696006775,32.95552252211354,1.6549545526504517,0.6838228702545166,30.1105725801552,1.5629757642745972,3000 -20746.14312171936,1.2735848426818848,32800.464728832245,93047,0,32800.464728832245,0.7005403637886047,30.228897107393603,1.4703547954559326,3003,53550.79271054268,0.6740824580192566,32.95555072276922,1.6410588026046753,0.6841452717781067,30.397595709889032,1.5564584732055664,3000 -21223.906438589096,1.3174645900726318,33640.66740679741,95434,0,33640.66740679741,0.7010632753372192,30.193043914956924,1.465505599975586,3003,54868.87886095047,0.6752694249153137,34.1304166462366,1.6265360116958618,0.6851744055747986,30.227560754570124,1.5516746044158936,3000 -21694.93043994904,1.3555314540863037,34480.82843494415,97821,0,34480.82843494415,0.7020161747932434,30.33091817118647,1.464707612991333,3003,56180.17577815056,0.6793409585952759,33.39148472321215,1.6154285669326782,0.6855835318565369,30.3606306216234,1.547755002975464,3000 -22157.154767990112,1.3990330696105957,35320.90512704849,100207,0,35320.90512704849,0.7030271291732788,30.08765700556365,1.4542731046676636,3003,57482.59611940384,0.7089230418205261,36.04611483927222,1.443196415901184,0.6880261898040771,30.359806178761417,1.534866452217102,3000 -22635.92946076393,1.444124460220337,36160.99453735352,102593,0,36160.99453735352,0.7057347297668457,30.57403197903927,1.445566177368164,3003,58801.58122611046,0.6842008233070374,34.35210443372337,1.579800724983215,0.6877781748771667,30.39731603776197,1.531459927558899,3000 -23147.89739346504,1.483485221862793,37001.09319806099,104979,0,37001.09319806099,0.7062576413154602,30.793022631949142,1.4383666515350342,3003,60153.76279973984,0.683157205581665,34.28504809514936,1.582433581352234,0.6898860335350037,30.47028869134486,1.5270402431488037,3000 -23635.432076931,1.522993564605713,37841.05610227585,107365,0,37841.05610227585,0.7057463526725769,30.804335144952017,1.435275673866272,3003,61481.371970653534,0.6957497596740723,35.563739134784505,1.5065596103668213,0.6896504759788513,30.50287564019885,1.5250399112701416,3000 -24178.042583703995,1.5672264099121094,38681.15629196167,109751,0,38681.15629196167,0.7073732018470764,30.37974883692844,1.4307950735092163,3003,62864.20189833641,0.6920896172523499,35.06577869851787,1.5295428037643433,0.6902828216552734,30.492036028114327,1.5207116603851318,3000 -24640.57275032997,1.606520652770996,39521.200323581696,112137,0,39521.200323581696,0.7082563638687134,30.58984000784102,1.4285781383514404,3003,64166.88848829269,0.6917627453804016,34.74488382371364,1.534698724746704,0.6911011338233948,30.914413893236805,1.5178720951080322,3000 -25145.35341453552,1.647085428237915,40361.25352883339,114523,0,40361.25352883339,0.708628237247467,30.869338740033143,1.4227757453918457,3003,65511.83819484711,0.698196291923523,35.77462367787801,1.495474338531494,0.6925270557403564,30.57723185326211,1.5115963220596311,3000 -25627.239145994183,1.689962387084961,41201.23071146011,116908,0,41201.23071146011,0.7094300389289856,30.701844657815705,1.418934345245361,3003,66833.81958413124,0.6974208354949951,35.78149866685925,1.5098053216934204,0.693990170955658,30.96133738228591,1.5112836360931396,3000 -26121.627092838287,1.7309155464172363,42041.36020541191,119295,0,42041.36020541191,0.7100808024406433,30.658603829809408,1.416991114616394,3003,68168.45203638077,0.710716187953949,36.5941331650276,1.4390244483947754,0.6939281225204468,30.915858804496704,1.5062028169631958,3000 -26617.650767326355,1.773043870925903,42881.50228333473,121681,0,42881.50228333473,0.7108593583106995,30.843223404698183,1.4113267660140991,3003,69504.73358535767,0.7038121819496155,36.19937840630523,1.4720726013183594,0.6947588920593262,30.82173925718412,1.5022433996200562,3000 -27111.19986152649,1.8156933784484863,43721.44666552544,124066,0,43721.44666552544,0.7117890119552612,30.892537332352088,1.4099397659301758,3003,70838.34599137306,0.7046647667884827,36.53354128224039,1.4676135778427124,0.6953912377357483,30.746032259348308,1.503679275512695,3000 -27582.21673822403,1.8660566806793213,44561.50462055206,126451,0,44561.50462055206,0.7123816609382629,31.215212632401546,1.4068301916122437,3003,72149.5507850647,0.7076228857040405,36.56570523268232,1.4430257081985474,0.6953416466712952,30.9751714367643,1.5003718137741089,3000 -28073.51711320877,1.908497333526612,45401.51370239258,128837,0,45401.51370239258,0.7122421860694885,30.932261071712,1.4065791368484497,3003,73480.97584033012,0.7105525732040405,36.75074571548151,1.4332529306411743,0.6956764459609985,30.823767338628773,1.4999107122421265,3000 -28565.360773801804,1.958054780960083,46241.69057679176,131223,0,46241.69057679176,0.7126489281654358,31.09791754474128,1.4054698944091797,3003,74813.12171435356,0.7070929408073425,36.070049531934,1.4555907249450684,0.695874810218811,30.95694516258592,1.499701976776123,3000 -29056.399307250977,2.001035213470459,46984.7851524353,133333,0,46984.7851524353,0.7125210762023926,31.041513578042657,1.4059910774230957,3003,76047.3649597168,0.7128967642784119,36.50041035697179,1.4201592206954956,0.6956764459609985,30.901218707876914,1.4999949932098389,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 786210cb7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.615253,11.164706,,,,,,,,,,,,,,,,, -1,,,0.00065292842919,11.176665306091309,0.0,0.0004835649742744,11.208685874938965,0.0,3000.0,0.0007088489946909,11.19086742401123,0.0,3003.0,37.0385491847992,932.0136518478394,37.0385491847992,894.9750609397888,0.0,0.0 -100,0.43962425,8.903551,,,,,,,,,,,,,,,,, -200,0.15697378,8.547919,,,,,,,,,,,,,,,,, -300,0.16404852,8.325197,,,,,,,,,,,,,,,,, -400,0.2881223,7.9730244,,,,,,,,,,,,,,,,, -500,0.35936484,7.6332126,,,,,,,,,,,,,,,,, -600,0.7561394,7.3986144,,,,,,,,,,,,,,,,, -700,0.69248503,7.2032824,,,,,,,,,,,,,,,,, -800,0.55900246,6.9315114,,,,,,,,,,,,,,,,, -900,0.51868254,6.815754,,,,,,,,,,,,,,,,, -1000,0.46499583,6.5605726,,,,,,,,,,,,,,,,, -1100,0.5652977,6.3587027,,,,,,,,,,,,,,,,, -1200,0.72699535,6.225107,,,,,,,,,,,,,,,,, -1300,0.5980017,6.0884295,,,,,,,,,,,,,,,,, -1400,0.57806206,5.932805,,,,,,,,,,,,,,,,, -1500,0.56722724,5.7909245,,,,,,,,,,,,,,,,, -1600,0.59383786,5.6594973,,,,,,,,,,,,,,,,, -1700,0.6037839,5.509642,,,,,,,,,,,,,,,,, -1800,0.8971289,5.4694314,,,,,,,,,,,,,,,,, -1900,0.8999788,5.352315,,,,,,,,,,,,,,,,, -2000,0.8161043,5.3346624,,,,,,,,,,,,,,,,, -2100,0.874891,5.1644545,,,,,,,,,,,,,,,,, -2200,0.6426465,5.0552006,,,,,,,,,,,,,,,,, -2300,0.80829674,4.9774356,,,,,,,,,,,,,,,,, -2382,,,0.4112793505191803,4.001132965087891,14.314785399828414,0.3958165347576141,4.139353752136231,9.577254281286269,3000.0,0.378374308347702,4.347209453582764,7.77435242429996,3003.0,877.1155626773834,2540.3347775936127,877.1155626773834,1663.1117770671844,0.0305879116058349,0.0 -2400,0.7200735,4.913783,,,,,,,,,,,,,,,,, -2500,0.6640373,4.827715,,,,,,,,,,,,,,,,, -2600,0.73737526,4.647209,,,,,,,,,,,,,,,,, -2700,0.8753207,4.636377,,,,,,,,,,,,,,,,, -2800,0.79440117,4.5540447,,,,,,,,,,,,,,,,, -2900,0.90815026,4.48922,,,,,,,,,,,,,,,,, -3000,1.0712976,4.4907627,,,,,,,,,,,,,,,,, -3100,0.62878,4.3538313,,,,,,,,,,,,,,,,, -3200,0.9690578,4.298713,,,,,,,,,,,,,,,,, -3300,0.90871495,4.282699,,,,,,,,,,,,,,,,, -3400,0.633427,4.2459836,,,,,,,,,,,,,,,,, -3500,0.6817268,4.1706905,,,,,,,,,,,,,,,,, -3600,0.6867973,4.155293,,,,,,,,,,,,,,,,, -3700,0.6364837,4.0546236,,,,,,,,,,,,,,,,, -3800,0.7742857,4.1065564,,,,,,,,,,,,,,,,, -3900,0.83822614,4.0557413,,,,,,,,,,,,,,,,, -4000,0.6543673,4.042284,,,,,,,,,,,,,,,,, -4100,0.61211014,3.9271352,,,,,,,,,,,,,,,,, -4200,0.60679436,3.942342,,,,,,,,,,,,,,,,, -4300,0.5395824,3.8672023,,,,,,,,,,,,,,,,, -4400,0.63684595,3.897582,,,,,,,,,,,,,,,,, -4500,0.5915032,3.8126645,,,,,,,,,,,,,,,,, -4600,0.56518567,3.852035,,,,,,,,,,,,,,,,, -4700,0.5189155,3.6826427,,,,,,,,,,,,,,,,, -4762,,,0.5386283993721008,2.7876904010772705,24.47574218709137,0.5418035387992859,2.7314565181732178,20.31710448845741,3000.0,0.5421649217605591,2.766301155090332,18.93392274615353,3003.0,1717.0735466480255,3867.541526556015,1717.0735466480255,2150.2625029087067,0.0555925369262695,0.0 -4800,0.61265427,3.7677507,,,,,,,,,,,,,,,,, -4900,0.5005627,3.7654216,,,,,,,,,,,,,,,,, -5000,0.6044249,3.7479687,,,,,,,,,,,,,,,,, -5100,0.7157528,3.7529547,,,,,,,,,,,,,,,,, -5200,0.56334007,3.7140796,,,,,,,,,,,,,,,,, -5300,0.4861076,3.6535606,,,,,,,,,,,,,,,,, -5400,0.5786,3.6599874,,,,,,,,,,,,,,,,, -5500,0.47269484,3.5786154,,,,,,,,,,,,,,,,, -5600,0.5622447,3.6600776,,,,,,,,,,,,,,,,, -5700,0.48843855,3.6111605,,,,,,,,,,,,,,,,, -5800,0.4943361,3.6684604,,,,,,,,,,,,,,,,, -5900,0.4754077,3.6294422,,,,,,,,,,,,,,,,, -6000,0.48069218,3.5548325,,,,,,,,,,,,,,,,, -6100,0.4794753,3.5668426,,,,,,,,,,,,,,,,, -6200,0.5369074,3.6094642,,,,,,,,,,,,,,,,, -6300,0.47647533,3.5750647,,,,,,,,,,,,,,,,, -6400,0.456153,3.5971382,,,,,,,,,,,,,,,,, -6500,0.48070893,3.5512733,,,,,,,,,,,,,,,,, -6600,0.42656776,3.455397,,,,,,,,,,,,,,,,, -6700,0.44946498,3.4567137,,,,,,,,,,,,,,,,, -6800,0.49270532,3.5364625,,,,,,,,,,,,,,,,, -6900,0.5001598,3.4601555,,,,,,,,,,,,,,,,, -7000,0.41544628,3.5212479,,,,,,,,,,,,,,,,, -7100,0.40347525,3.5532322,,,,,,,,,,,,,,,,, -7143,,,0.5830803513526917,2.34594202041626,27.624996169924223,0.5854235887527466,2.320286512374878,23.29588649527409,3000.0,0.588565468788147,2.322375535964966,21.92302229850635,3003.0,2557.290193080902,5180.989482164383,2557.290193080902,2623.3933651447296,0.0818729400634765,0.0 -7200,0.3904182,3.4755373,,,,,,,,,,,,,,,,, -7300,0.3966488,3.419415,,,,,,,,,,,,,,,,, -7400,0.46487588,3.464302,,,,,,,,,,,,,,,,, -7500,0.40328985,3.4444714,,,,,,,,,,,,,,,,, -7600,0.40169725,3.508006,,,,,,,,,,,,,,,,, -7700,0.48781827,3.3937793,,,,,,,,,,,,,,,,, -7800,0.4109493,3.3183866,,,,,,,,,,,,,,,,, -7900,0.37303248,3.4546714,,,,,,,,,,,,,,,,, -8000,0.3924186,3.393397,,,,,,,,,,,,,,,,, -8100,0.32644367,3.3318377,,,,,,,,,,,,,,,,, -8200,0.43118438,3.3201196,,,,,,,,,,,,,,,,, -8300,0.41096842,3.4575315,,,,,,,,,,,,,,,,, -8400,0.36654472,3.3938432,,,,,,,,,,,,,,,,, -8500,0.35806206,3.352966,,,,,,,,,,,,,,,,, -8600,0.5137874,3.4066696,,,,,,,,,,,,,,,,, -8700,0.34798825,3.3530335,,,,,,,,,,,,,,,,, -8800,0.3629391,3.32268,,,,,,,,,,,,,,,,, -8900,0.3649705,3.3465512,,,,,,,,,,,,,,,,, -9000,0.36301294,3.399304,,,,,,,,,,,,,,,,, -9100,0.3442541,3.356505,,,,,,,,,,,,,,,,, -9200,0.36234927,3.349675,,,,,,,,,,,,,,,,, -9300,0.31409055,3.2829974,,,,,,,,,,,,,,,,, -9400,0.28937644,3.3253777,,,,,,,,,,,,,,,,, -9500,0.31165513,3.296176,,,,,,,,,,,,,,,,, -9526,,,0.5969531536102295,2.222651481628418,28.03945898046316,0.6077544093132019,2.134187936782837,24.906629328826103,3000.0,0.6118645071983337,2.1073811054229736,23.48735964868124,3003.0,3397.41008734703,6491.173132181168,3397.41008734703,3093.3391761779785,0.1232635974884033,0.0 -9600,0.3201829,3.253135,,,,,,,,,,,,,,,,, -9700,0.3157981,3.2381082,,,,,,,,,,,,,,,,, -9800,0.3072799,3.269213,,,,,,,,,,,,,,,,, -9900,0.28482157,3.3776731,,,,,,,,,,,,,,,,, -10000,0.3116389,3.3014276,,,,,,,,,,,,,,,,, -10100,0.27995047,3.260913,,,,,,,,,,,,,,,,, -10200,0.27917877,3.2536545,,,,,,,,,,,,,,,,, -10300,0.2764756,3.2700036,,,,,,,,,,,,,,,,, -10400,0.3242167,3.2555516,,,,,,,,,,,,,,,,, -10500,0.28382924,3.2132027,,,,,,,,,,,,,,,,, -10600,0.2864701,3.3121278,,,,,,,,,,,,,,,,, -10700,0.29090872,3.2527547,,,,,,,,,,,,,,,,, -10800,0.26914948,3.2133436,,,,,,,,,,,,,,,,, -10900,0.36416933,3.342626,,,,,,,,,,,,,,,,, -11000,0.27612025,3.304437,,,,,,,,,,,,,,,,, -11100,0.29654506,3.257656,,,,,,,,,,,,,,,,, -11200,0.2950094,3.1680481,,,,,,,,,,,,,,,,, -11300,0.27697533,3.3143106,,,,,,,,,,,,,,,,, -11400,0.26173195,3.1697977,,,,,,,,,,,,,,,,, -11500,0.26324925,3.3251247,,,,,,,,,,,,,,,,, -11600,0.24885252,3.2233896,,,,,,,,,,,,,,,,, -11700,0.2436894,3.1846416,,,,,,,,,,,,,,,,, -11800,0.24030662,3.237229,,,,,,,,,,,,,,,,, -11900,0.2626157,3.2295928,,,,,,,,,,,,,,,,, -11909,,,0.6016323566436768,2.1511049270629883,28.773179468343937,0.621133029460907,2.0083134174346924,25.84807362168108,3000.0,0.6260995864868164,1.976151943206787,24.403812250497207,3003.0,4237.438687324524,7776.712178230286,4237.438687324524,3538.7417256832123,0.15330171585083,0.0 -12000,0.29474568,3.3479168,,,,,,,,,,,,,,,,, -12100,0.2815204,3.1662917,,,,,,,,,,,,,,,,, -12200,0.25260773,3.1315618,,,,,,,,,,,,,,,,, -12300,0.22466983,3.2245185,,,,,,,,,,,,,,,,, -12400,0.23433098,3.217785,,,,,,,,,,,,,,,,, -12500,0.26459372,3.156466,,,,,,,,,,,,,,,,, -12600,0.25869575,3.1878583,,,,,,,,,,,,,,,,, -12700,0.23580371,3.230319,,,,,,,,,,,,,,,,, -12800,0.23004726,3.2306762,,,,,,,,,,,,,,,,, -12900,0.21852973,3.1930852,,,,,,,,,,,,,,,,, -13000,0.2305431,3.1925874,,,,,,,,,,,,,,,,, -13100,0.26227945,3.1928217,,,,,,,,,,,,,,,,, -13200,0.27685848,3.20682,,,,,,,,,,,,,,,,, -13300,0.25440973,3.210832,,,,,,,,,,,,,,,,, -13400,0.285276,3.1897464,,,,,,,,,,,,,,,,, -13500,0.23693825,3.1860325,,,,,,,,,,,,,,,,, -13600,0.27093554,3.1309426,,,,,,,,,,,,,,,,, -13700,0.29563868,3.2369058,,,,,,,,,,,,,,,,, -13800,0.25903818,3.2215328,,,,,,,,,,,,,,,,, -13900,0.25759226,3.1282291,,,,,,,,,,,,,,,,, -14000,0.3129357,3.1826305,,,,,,,,,,,,,,,,, -14100,0.2368284,3.1476502,,,,,,,,,,,,,,,,, -14200,0.283243,3.11847,,,,,,,,,,,,,,,,, -14294,,,0.6140919923782349,2.048624277114868,29.75643167533389,0.6318830251693726,1.923273563385009,26.592895269365595,3000.0,0.6379757523536682,1.875356078147888,25.565374635606645,3003.0,5077.634784460068,9097.55193066597,5077.634784460068,4019.282205581665,0.1813933849334716,0.0 -14300,0.2322092,3.1899965,,,,,,,,,,,,,,,,, -14400,0.2708752,3.107969,,,,,,,,,,,,,,,,, -14500,0.25997254,3.0894756,,,,,,,,,,,,,,,,, -14600,0.2817272,3.1676354,,,,,,,,,,,,,,,,, -14700,0.30205902,3.1597857,,,,,,,,,,,,,,,,, -14800,0.28035665,3.1181855,,,,,,,,,,,,,,,,, -14900,0.26448432,3.22083,,,,,,,,,,,,,,,,, -15000,0.26727042,3.1389258,,,,,,,,,,,,,,,,, -15100,0.25909194,3.1525724,,,,,,,,,,,,,,,,, -15200,0.23611844,3.0794008,,,,,,,,,,,,,,,,, -15300,0.25427678,3.1745334,,,,,,,,,,,,,,,,, -15400,0.26842612,3.1804118,,,,,,,,,,,,,,,,, -15500,0.26136425,3.1345036,,,,,,,,,,,,,,,,, -15600,0.24567528,3.114893,,,,,,,,,,,,,,,,, -15700,0.27880207,3.1450202,,,,,,,,,,,,,,,,, -15800,0.33007467,3.0748205,,,,,,,,,,,,,,,,, -15900,0.25828272,3.0722036,,,,,,,,,,,,,,,,, -16000,0.25774112,3.1071796,,,,,,,,,,,,,,,,, -16100,0.30202705,3.079112,,,,,,,,,,,,,,,,, -16200,0.28144193,3.110781,,,,,,,,,,,,,,,,, -16300,0.26567027,3.0731578,,,,,,,,,,,,,,,,, -16400,0.3325257,3.0902486,,,,,,,,,,,,,,,,, -16500,0.28261894,3.1422205,,,,,,,,,,,,,,,,, -16600,0.2744674,3.0981193,,,,,,,,,,,,,,,,, -16679,,,0.6197834014892578,2.003021001815796,29.712456080288373,0.6387149691581726,1.858880519866944,26.800316120447345,3000.0,0.6479925513267517,1.8052959442138672,25.933641625620712,3003.0,5917.868166685104,10397.26683807373,5917.868166685104,4478.662546873093,0.2098925113677978,0.0 -16700,0.28889316,3.050753,,,,,,,,,,,,,,,,, -16800,0.2749777,3.1069746,,,,,,,,,,,,,,,,, -16900,0.3153204,3.0861275,,,,,,,,,,,,,,,,, -17000,0.2621589,3.084648,,,,,,,,,,,,,,,,, -17100,0.2608761,3.0877836,,,,,,,,,,,,,,,,, -17200,0.38158664,3.1546082,,,,,,,,,,,,,,,,, -17300,0.28574878,3.0708752,,,,,,,,,,,,,,,,, -17400,0.29340425,3.0619013,,,,,,,,,,,,,,,,, -17500,0.36288214,3.173222,,,,,,,,,,,,,,,,, -17600,0.33567116,3.0850058,,,,,,,,,,,,,,,,, -17700,0.26836744,3.0868397,,,,,,,,,,,,,,,,, -17800,0.30853072,3.0854247,,,,,,,,,,,,,,,,, -17900,0.3009001,3.0577507,,,,,,,,,,,,,,,,, -18000,0.27557468,2.9637723,,,,,,,,,,,,,,,,, -18100,0.3064645,3.0889373,,,,,,,,,,,,,,,,, -18200,0.28049728,3.1207469,,,,,,,,,,,,,,,,, -18300,0.3560372,3.0360363,,,,,,,,,,,,,,,,, -18400,0.34708738,3.0888062,,,,,,,,,,,,,,,,, -18500,0.28910118,3.0579457,,,,,,,,,,,,,,,,, -18600,0.2798195,3.0487347,,,,,,,,,,,,,,,,, -18700,0.39178607,3.0499387,,,,,,,,,,,,,,,,, -18800,0.31411412,3.1025412,,,,,,,,,,,,,,,,, -18900,0.29817638,3.0545883,,,,,,,,,,,,,,,,, -19000,0.31192157,3.0469592,,,,,,,,,,,,,,,,, -19064,,,0.6391572952270508,1.85425329208374,31.17832988950659,0.6441581845283508,1.8158868551254272,27.1275590231522,3000.0,0.652710497379303,1.761230707168579,26.45541676211269,3003.0,6757.979300022125,12022.555188179016,6757.979300022125,5263.736715555191,0.236696720123291,0.0 -19100,0.35315338,3.1066456,,,,,,,,,,,,,,,,, -19200,0.30574763,3.0781057,,,,,,,,,,,,,,,,, -19300,0.31463975,3.0373054,,,,,,,,,,,,,,,,, -19400,0.31303403,3.1110065,,,,,,,,,,,,,,,,, -19500,0.30843717,3.022775,,,,,,,,,,,,,,,,, -19600,0.33083594,3.0976758,,,,,,,,,,,,,,,,, -19700,0.34409124,3.0934954,,,,,,,,,,,,,,,,, -19800,0.33332074,3.088022,,,,,,,,,,,,,,,,, -19900,0.3366925,3.0165758,,,,,,,,,,,,,,,,, -20000,0.31483755,3.0588014,,,,,,,,,,,,,,,,, -20100,0.30243963,3.113312,,,,,,,,,,,,,,,,, -20200,0.3101504,2.9769595,,,,,,,,,,,,,,,,, -20300,0.4300287,2.982455,,,,,,,,,,,,,,,,, -20400,0.33212852,3.0876107,,,,,,,,,,,,,,,,, -20500,0.38016644,3.026979,,,,,,,,,,,,,,,,, -20600,0.33135745,2.9383004,,,,,,,,,,,,,,,,, -20700,0.3387811,3.074282,,,,,,,,,,,,,,,,, -20800,0.3495401,3.0362408,,,,,,,,,,,,,,,,, -20900,0.37653315,3.0598972,,,,,,,,,,,,,,,,, -21000,0.4376484,2.9976199,,,,,,,,,,,,,,,,, -21100,0.45227623,2.9881907,,,,,,,,,,,,,,,,, -21200,0.38100266,3.0606413,,,,,,,,,,,,,,,,, -21300,0.3555727,3.0639467,,,,,,,,,,,,,,,,, -21400,0.38556218,3.0641658,,,,,,,,,,,,,,,,, -21450,,,0.6265982985496521,1.9437370300292969,30.379573470875417,0.6476795077323914,1.7847602367401123,27.4281157734518,3000.0,0.6565220355987549,1.7380309104919434,26.60739545980313,3003.0,7598.178004980087,13377.98326063156,7598.178004980087,5778.866499423981,0.2639615535736084,0.0 -21500,0.37179446,3.0648453,,,,,,,,,,,,,,,,, -21600,0.42557594,3.077219,,,,,,,,,,,,,,,,, -21700,0.35815006,3.0294576,,,,,,,,,,,,,,,,, -21800,0.27671337,2.9945886,,,,,,,,,,,,,,,,, -21900,0.3419505,2.9848926,,,,,,,,,,,,,,,,, -22000,0.3927303,3.0775785,,,,,,,,,,,,,,,,, -22100,0.4403804,2.9675953,,,,,,,,,,,,,,,,, -22200,0.33082128,3.0091999,,,,,,,,,,,,,,,,, -22300,0.36169538,3.0189283,,,,,,,,,,,,,,,,, -22400,0.37774897,3.0195038,,,,,,,,,,,,,,,,, -22500,0.31225362,2.9824886,,,,,,,,,,,,,,,,, -22600,0.41020045,3.0473158,,,,,,,,,,,,,,,,, -22700,0.4099948,2.9924672,,,,,,,,,,,,,,,,, -22800,0.3852939,2.9736443,,,,,,,,,,,,,,,,, -22900,0.3585546,3.0297875,,,,,,,,,,,,,,,,, -23000,0.33851054,2.9983072,,,,,,,,,,,,,,,,, -23100,0.38123137,2.9847565,,,,,,,,,,,,,,,,, -23200,0.42833373,3.055193,,,,,,,,,,,,,,,,, -23300,0.37809458,3.0847747,,,,,,,,,,,,,,,,, -23400,0.3272053,3.0699096,,,,,,,,,,,,,,,,, -23500,0.33536297,2.9503965,,,,,,,,,,,,,,,,, -23600,0.39982492,3.0065312,,,,,,,,,,,,,,,,, -23700,0.33580977,2.9440506,,,,,,,,,,,,,,,,, -23800,0.40023142,2.9925895,,,,,,,,,,,,,,,,, -23836,,,0.6263420581817627,1.9479836225509644,30.605509374050456,0.6493533849716187,1.7742805480957031,27.501190326452026,3000.0,0.6588460803031921,1.7223812341690063,26.95805362050392,3003.0,8438.228722810745,14722.985248565674,8438.228722810745,6283.715788841248,0.291827917098999,0.0 -23900,0.32402858,3.0378432,,,,,,,,,,,,,,,,, -24000,0.32331485,2.936467,,,,,,,,,,,,,,,,, -24100,0.32722458,2.9843688,,,,,,,,,,,,,,,,, -24200,0.33250597,3.046513,,,,,,,,,,,,,,,,, -24300,0.40513727,3.0645864,,,,,,,,,,,,,,,,, -24400,0.35592666,3.0220263,,,,,,,,,,,,,,,,, -24500,0.40833023,3.0220797,,,,,,,,,,,,,,,,, -24600,0.44008943,3.0316498,,,,,,,,,,,,,,,,, -24700,0.37046483,3.055099,,,,,,,,,,,,,,,,, -24800,0.38278034,3.0460074,,,,,,,,,,,,,,,,, -24900,0.33696032,3.0206738,,,,,,,,,,,,,,,,, -25000,0.35426128,2.96515,,,,,,,,,,,,,,,,, -25100,0.33412245,3.0678515,,,,,,,,,,,,,,,,, -25200,0.38567814,3.0052323,,,,,,,,,,,,,,,,, -25300,0.3805121,2.9775846,,,,,,,,,,,,,,,,, -25400,0.41656178,3.0139601,,,,,,,,,,,,,,,,, -25500,0.3403456,3.0283897,,,,,,,,,,,,,,,,, -25600,0.33894676,3.0011387,,,,,,,,,,,,,,,,, -25700,0.3565841,3.0788321,,,,,,,,,,,,,,,,, -25800,0.3291951,3.0318964,,,,,,,,,,,,,,,,, -25900,0.34727737,3.0008821,,,,,,,,,,,,,,,,, -26000,0.38669297,3.0140312,,,,,,,,,,,,,,,,, -26100,0.41834134,3.0316145,,,,,,,,,,,,,,,,, -26200,0.4092515,3.0484536,,,,,,,,,,,,,,,,, -26221,,,0.6345329880714417,1.8784587383270264,31.028171095176663,0.65301114320755,1.7509132623672483,27.96881610127737,3000.0,0.665864884853363,1.6981230974197388,27.71821619195752,3003.0,9278.243922948835,16032.076689481735,9278.243922948835,6752.686856031418,0.3216967582702636,0.0 -26300,0.3822815,3.0169494,,,,,,,,,,,,,,,,, -26400,0.36212423,2.9491415,,,,,,,,,,,,,,,,, -26500,0.3427478,2.9629016,,,,,,,,,,,,,,,,, -26600,0.34356976,2.9458897,,,,,,,,,,,,,,,,, -26700,0.39888018,2.9791236,,,,,,,,,,,,,,,,, -26800,0.35193226,2.9467378,,,,,,,,,,,,,,,,, -26900,0.36080706,2.9697845,,,,,,,,,,,,,,,,, -27000,0.38443905,3.0084853,,,,,,,,,,,,,,,,, -27100,0.39885327,2.990078,,,,,,,,,,,,,,,,, -27200,0.33243066,2.9502103,,,,,,,,,,,,,,,,, -27300,0.31935203,2.9876728,,,,,,,,,,,,,,,,, -27400,0.37831676,2.977271,,,,,,,,,,,,,,,,, -27500,0.36107168,3.0225072,,,,,,,,,,,,,,,,, -27600,0.41260058,2.9598944,,,,,,,,,,,,,,,,, -27700,0.38131928,3.000123,,,,,,,,,,,,,,,,, -27800,0.35422245,2.964261,,,,,,,,,,,,,,,,, -27900,0.42133033,3.0193284,,,,,,,,,,,,,,,,, -28000,0.45057315,3.0338757,,,,,,,,,,,,,,,,, -28100,0.38384834,3.0126781,,,,,,,,,,,,,,,,, -28200,0.4110447,3.0363681,,,,,,,,,,,,,,,,, -28300,0.3393851,2.991179,,,,,,,,,,,,,,,,, -28400,0.40268895,2.998411,,,,,,,,,,,,,,,,, -28500,0.38455763,2.9753428,,,,,,,,,,,,,,,,, -28600,0.408224,2.9839406,,,,,,,,,,,,,,,,, -28607,,,0.6294631361961365,1.9184149503707888,30.843768750457446,0.6532467007637024,1.749211311340332,27.84541869955221,3000.0,0.6636918187141418,1.6925480365753174,26.97858312778024,3003.0,10118.190644741058,17333.501986265182,10118.190644741058,7214.063798904419,0.3504965305328369,0.0 -28700,0.46262506,2.937596,,,,,,,,,,,,,,,,, -28800,0.34005514,3.0252178,,,,,,,,,,,,,,,,, -28900,0.44896618,2.939941,,,,,,,,,,,,,,,,, -29000,0.3325673,3.0773435,,,,,,,,,,,,,,,,, -29100,0.37524834,2.921876,,,,,,,,,,,,,,,,, -29200,0.31646788,2.9791791,,,,,,,,,,,,,,,,, -29300,0.3463689,2.9923377,,,,,,,,,,,,,,,,, -29400,0.35077637,3.082157,,,,,,,,,,,,,,,,, -29500,0.3991097,3.0604043,,,,,,,,,,,,,,,,, -29600,0.41871545,2.994125,,,,,,,,,,,,,,,,, -29700,0.36145192,2.9802816,,,,,,,,,,,,,,,,, -29800,0.34006837,2.9624653,,,,,,,,,,,,,,,,, -29900,0.514126,2.974191,,,,,,,,,,,,,,,,, -30000,0.36380512,2.9399202,,,,,,,,,,,,,,,,, -30100,0.38577318,2.9739268,,,,,,,,,,,,,,,,, -30200,0.3643966,3.0167956,,,,,,,,,,,,,,,,, -30300,0.3882266,2.997693,,,,,,,,,,,,,,,,, -30400,0.346626,2.9637802,,,,,,,,,,,,,,,,, -30500,0.39936665,2.9735067,,,,,,,,,,,,,,,,, -30600,0.42611018,3.015388,,,,,,,,,,,,,,,,, -30700,1.185307,3.0401123,,,,,,,,,,,,,,,,, -30800,0.38290656,3.0089524,,,,,,,,,,,,,,,,, -30900,0.37297642,2.8934774,,,,,,,,,,,,,,,,, -30994,,,0.6348527669906616,1.892822265625,30.79260266740877,0.6570532321929932,1.729580998420715,28.399133611509708,3000.0,0.6665040254592896,1.6737943887710571,27.64029504672969,3003.0,10958.382172107697,18718.81236767769,10958.382172107697,7759.080738306045,0.3789188861846924,0.0 -31000,0.3738856,3.021176,,,,,,,,,,,,,,,,, -31100,0.36041838,2.9497747,,,,,,,,,,,,,,,,, -31200,0.35292122,2.9703224,,,,,,,,,,,,,,,,, -31300,0.3630878,3.0392945,,,,,,,,,,,,,,,,, -31400,0.34678054,3.011904,,,,,,,,,,,,,,,,, -31500,0.33887267,2.9808726,,,,,,,,,,,,,,,,, -31600,0.33169302,2.9424338,,,,,,,,,,,,,,,,, -31700,0.38054264,2.9939647,,,,,,,,,,,,,,,,, -31800,0.3681949,2.8566458,,,,,,,,,,,,,,,,, -31900,0.5040316,2.9834695,,,,,,,,,,,,,,,,, -32000,0.45824835,3.0003579,,,,,,,,,,,,,,,,, -32100,0.32252035,2.9534562,,,,,,,,,,,,,,,,, -32200,0.38331473,2.9898953,,,,,,,,,,,,,,,,, -32300,0.3433518,3.0152526,,,,,,,,,,,,,,,,, -32400,0.3349635,2.947413,,,,,,,,,,,,,,,,, -32500,0.35575035,2.947322,,,,,,,,,,,,,,,,, -32600,0.34497824,3.0099134,,,,,,,,,,,,,,,,, -32700,0.36078882,3.001474,,,,,,,,,,,,,,,,, -32800,0.35153285,3.0428672,,,,,,,,,,,,,,,,, -32900,0.3649122,2.956977,,,,,,,,,,,,,,,,, -33000,0.4901911,2.9539433,,,,,,,,,,,,,,,,, -33100,0.33308834,3.004056,,,,,,,,,,,,,,,,, -33200,0.38790017,2.9683,,,,,,,,,,,,,,,,, -33300,0.4032389,2.9436867,,,,,,,,,,,,,,,,, -33381,,,0.6381608843803406,1.8649847507476809,30.97019323236983,0.6565820574760437,1.7204769849777222,28.203732298008507,3000.0,0.667026937007904,1.6664468050003052,27.45520512735884,3003.0,11798.445326805117,20210.39287185669,11798.445326805117,8410.495576143265,0.4079523086547851,0.0 -33400,0.36156663,2.98767,,,,,,,,,,,,,,,,, -33500,0.34693888,2.9795961,,,,,,,,,,,,,,,,, -33600,0.4140553,2.938471,,,,,,,,,,,,,,,,, -33700,0.35215944,2.8618016,,,,,,,,,,,,,,,,, -33800,0.36043555,2.8833861,,,,,,,,,,,,,,,,, -33900,0.43334463,2.9603896,,,,,,,,,,,,,,,,, -34000,0.36483842,2.9652407,,,,,,,,,,,,,,,,, -34100,0.36333466,2.9336357,,,,,,,,,,,,,,,,, -34200,0.4186663,2.9030735,,,,,,,,,,,,,,,,, -34300,0.33634952,2.9861856,,,,,,,,,,,,,,,,, -34400,0.37209752,2.9720645,,,,,,,,,,,,,,,,, -34500,0.3671007,2.967333,,,,,,,,,,,,,,,,, -34600,0.3488706,2.9064887,,,,,,,,,,,,,,,,, -34700,0.4092797,3.0240672,,,,,,,,,,,,,,,,, -34800,0.33739936,2.9538667,,,,,,,,,,,,,,,,, -34900,0.3503781,2.9190922,,,,,,,,,,,,,,,,, -35000,0.39711276,3.0255437,,,,,,,,,,,,,,,,, -35100,0.38633165,2.9944386,,,,,,,,,,,,,,,,, -35200,0.38035578,2.9670937,,,,,,,,,,,,,,,,, -35300,0.34747723,2.9260285,,,,,,,,,,,,,,,,, -35400,0.41484538,2.980851,,,,,,,,,,,,,,,,, -35500,0.35927087,3.0025425,,,,,,,,,,,,,,,,, -35600,0.33817697,2.946354,,,,,,,,,,,,,,,,, -35700,0.389935,2.9490147,,,,,,,,,,,,,,,,, -35767,,,0.6377241015434265,1.875410556793213,31.08551284303532,0.6565572619438171,1.71941077709198,28.019113346418383,3000.0,0.6706873774528503,1.6476588249206543,27.893507040265977,3003.0,12638.43499302864,21543.868954896927,12638.43499302864,8903.872762203217,0.4422931671142578,0.0 -35800,0.3960464,3.0463998,,,,,,,,,,,,,,,,, -35900,0.40838367,2.975433,,,,,,,,,,,,,,,,, -36000,0.35036108,2.9313605,,,,,,,,,,,,,,,,, -36100,0.36041588,3.002123,,,,,,,,,,,,,,,,, -36200,0.49109837,2.9828265,,,,,,,,,,,,,,,,, -36300,0.3710385,3.0326827,,,,,,,,,,,,,,,,, -36400,0.4216648,2.977287,,,,,,,,,,,,,,,,, -36500,0.3632988,2.901403,,,,,,,,,,,,,,,,, -36600,0.4329915,2.9241383,,,,,,,,,,,,,,,,, -36700,0.34604624,2.919402,,,,,,,,,,,,,,,,, -36800,0.3414171,2.989499,,,,,,,,,,,,,,,,, -36900,0.37297845,3.0251906,,,,,,,,,,,,,,,,, -37000,0.37824327,2.954863,,,,,,,,,,,,,,,,, -37100,0.41370094,3.0027373,,,,,,,,,,,,,,,,, -37200,0.5514241,3.0248523,,,,,,,,,,,,,,,,, -37300,0.3785581,2.9665918,,,,,,,,,,,,,,,,, -37400,0.33936974,2.913688,,,,,,,,,,,,,,,,, -37500,0.4527417,2.9708354,,,,,,,,,,,,,,,,, -37600,0.38547212,2.9427857,,,,,,,,,,,,,,,,, -37700,0.37768194,2.8607392,,,,,,,,,,,,,,,,, -37800,0.46277255,2.945459,,,,,,,,,,,,,,,,, -37900,0.35556597,2.9365118,,,,,,,,,,,,,,,,, -38000,0.39243448,2.9494026,,,,,,,,,,,,,,,,, -38100,0.3740633,2.9782863,,,,,,,,,,,,,,,,, -38152,,,0.6488719582557678,1.7934083938598633,31.61389258230864,0.6610581278800964,1.70951247215271,28.280349091822146,3000.0,0.6716054081916809,1.6476448774337769,27.802578131818603,3003.0,13478.334066152573,22861.5990588665,13478.334066152573,9381.595716238022,0.4726207256317138,0.0 -38200,0.38702008,2.9557176,,,,,,,,,,,,,,,,, -38300,0.37222493,2.8703325,,,,,,,,,,,,,,,,, -38400,0.41328123,2.9780905,,,,,,,,,,,,,,,,, -38500,0.37749076,2.9069734,,,,,,,,,,,,,,,,, -38600,0.3537677,2.9040139,,,,,,,,,,,,,,,,, -38700,0.3454462,2.8857746,,,,,,,,,,,,,,,,, -38800,0.36429384,2.973431,,,,,,,,,,,,,,,,, -38900,0.38877806,2.904619,,,,,,,,,,,,,,,,, -39000,0.36349955,3.039256,,,,,,,,,,,,,,,,, -39100,0.31450424,2.959874,,,,,,,,,,,,,,,,, -39200,0.38398087,2.952504,,,,,,,,,,,,,,,,, -39300,0.36759347,2.966885,,,,,,,,,,,,,,,,, -39400,0.34539798,3.0221193,,,,,,,,,,,,,,,,, -39500,0.3806151,2.9634843,,,,,,,,,,,,,,,,, -39600,0.3552844,2.934116,,,,,,,,,,,,,,,,, -39700,0.40010235,2.9178588,,,,,,,,,,,,,,,,, -39800,0.3487037,2.9561799,,,,,,,,,,,,,,,,, -39900,0.35678872,2.9462874,,,,,,,,,,,,,,,,, -40000,0.3869596,3.01593,,,,,,,,,,,,,,,,, -40100,0.36889192,2.949618,,,,,,,,,,,,,,,,, -40200,0.39100242,2.9268205,,,,,,,,,,,,,,,,, -40300,0.37881938,2.9374175,,,,,,,,,,,,,,,,, -40400,0.40007037,2.9477408,,,,,,,,,,,,,,,,, -40500,0.36111805,2.96156,,,,,,,,,,,,,,,,, -40538,,,0.6424659490585327,1.8437772989273071,31.08091692572428,0.6593222618103027,1.7040035724639893,28.3158290660357,3000.0,0.6716402173042297,1.6384211778640747,27.64028559420224,3003.0,14318.342139482498,24222.2512857914,14318.342139482498,9902.135681152344,0.5022103786468506,0.0 -40600,0.35334852,2.9649417,,,,,,,,,,,,,,,,, -40700,0.3881036,2.9354823,,,,,,,,,,,,,,,,, -40800,0.3449164,2.903508,,,,,,,,,,,,,,,,, -40900,0.31441784,2.878823,,,,,,,,,,,,,,,,, -41000,0.35541943,2.9338417,,,,,,,,,,,,,,,,, -41100,0.4186297,2.9975324,,,,,,,,,,,,,,,,, -41200,0.40339828,2.966781,,,,,,,,,,,,,,,,, -41300,0.4522244,2.9511826,,,,,,,,,,,,,,,,, -41400,0.3504821,2.8971229,,,,,,,,,,,,,,,,, -41500,0.35194325,2.9103026,,,,,,,,,,,,,,,,, -41600,0.54061097,2.9984283,,,,,,,,,,,,,,,,, -41700,0.38850653,2.9897416,,,,,,,,,,,,,,,,, -41800,0.38326016,2.9977295,,,,,,,,,,,,,,,,, -41900,0.34581372,2.9720128,,,,,,,,,,,,,,,,, -42000,0.3196632,2.9427576,,,,,,,,,,,,,,,,, -42100,0.36054996,2.935103,,,,,,,,,,,,,,,,, -42200,0.3642828,2.9469776,,,,,,,,,,,,,,,,, -42300,0.32550967,2.9164994,,,,,,,,,,,,,,,,, -42400,0.36223885,2.9273734,,,,,,,,,,,,,,,,, -42500,0.34079614,2.9096835,,,,,,,,,,,,,,,,, -42600,0.39180893,2.9088256,,,,,,,,,,,,,,,,, -42700,0.33439472,2.922894,,,,,,,,,,,,,,,,, -42800,0.37712657,2.9336808,,,,,,,,,,,,,,,,, -42900,0.3695404,2.8959963,,,,,,,,,,,,,,,,, -42925,,,0.6389986276626587,1.856657862663269,31.192170592865967,0.6593842506408691,1.7011802196502686,28.26113223855552,3000.0,0.6723374724388123,1.6351051330566406,28.000836489908377,3003.0,15158.377641677856,25544.065200567245,15158.377641677856,10383.810242652891,0.5315215587615967,0.0 -43000,0.33866554,2.94325,,,,,,,,,,,,,,,,, -43100,0.36591658,2.9763894,,,,,,,,,,,,,,,,, -43200,0.3540355,2.9503856,,,,,,,,,,,,,,,,, -43300,0.4792659,2.908763,,,,,,,,,,,,,,,,, -43400,0.35301673,2.975446,,,,,,,,,,,,,,,,, -43500,0.39362657,2.9365795,,,,,,,,,,,,,,,,, -43600,0.40472746,2.91529,,,,,,,,,,,,,,,,, -43700,0.35333812,2.96226,,,,,,,,,,,,,,,,, -43800,0.34576875,2.940677,,,,,,,,,,,,,,,,, -43900,0.35160038,2.9184308,,,,,,,,,,,,,,,,, -44000,0.35545173,2.917368,,,,,,,,,,,,,,,,, -44100,0.37071893,2.9076028,,,,,,,,,,,,,,,,, -44200,0.3360731,2.999846,,,,,,,,,,,,,,,,, -44300,0.37913677,2.9245458,,,,,,,,,,,,,,,,, -44400,0.34480965,2.9553916,,,,,,,,,,,,,,,,, -44500,0.4324697,2.9663517,,,,,,,,,,,,,,,,, -44600,0.3438361,2.9522042,,,,,,,,,,,,,,,,, -44700,0.3240286,2.854837,,,,,,,,,,,,,,,,, -44800,0.35647303,2.894144,,,,,,,,,,,,,,,,, -44900,0.37706697,3.0826225,,,,,,,,,,,,,,,,, -45000,0.33058825,2.8844044,,,,,,,,,,,,,,,,, -45100,0.39938393,2.9146113,,,,,,,,,,,,,,,,, -45200,0.37279472,2.977888,,,,,,,,,,,,,,,,, -45300,0.3643253,2.9255054,,,,,,,,,,,,,,,,, -45312,,,0.6493979692459106,1.7883561849594116,31.863872415061568,0.6628807783126831,1.6845715045928955,28.467148185414068,3000.0,0.6726396083831787,1.62099027633667,27.586332945086863,3003.0,15998.422271966934,26853.32121229172,15998.422271966934,10852.916440725328,0.5626571178436279,0.0 -45400,0.41142523,2.9663966,,,,,,,,,,,,,,,,, -45500,0.39554673,3.0207844,,,,,,,,,,,,,,,,, -45600,0.4016975,2.9584832,,,,,,,,,,,,,,,,, -45700,0.33252776,2.9416304,,,,,,,,,,,,,,,,, -45800,0.31227222,2.928032,,,,,,,,,,,,,,,,, -45900,0.40654436,3.0086887,,,,,,,,,,,,,,,,, -46000,0.36336485,2.9020352,,,,,,,,,,,,,,,,, -46100,0.3914188,2.9960287,,,,,,,,,,,,,,,,, -46200,0.36984587,2.939759,,,,,,,,,,,,,,,,, -46300,0.401285,2.9301565,,,,,,,,,,,,,,,,, -46400,0.3281194,2.9130633,,,,,,,,,,,,,,,,, -46500,0.3903827,2.878838,,,,,,,,,,,,,,,,, -46600,0.35028005,2.8997006,,,,,,,,,,,,,,,,, -46700,0.37420896,2.9300084,,,,,,,,,,,,,,,,, -46800,0.35655007,2.938744,,,,,,,,,,,,,,,,, -46900,0.36794648,2.9293783,,,,,,,,,,,,,,,,, -47000,0.34002376,3.0197446,,,,,,,,,,,,,,,,, -47100,0.41161686,3.012664,,,,,,,,,,,,,,,,, -47200,0.37631747,3.0045073,,,,,,,,,,,,,,,,, -47300,0.36722508,2.8700414,,,,,,,,,,,,,,,,, -47400,0.34151584,2.9032302,,,,,,,,,,,,,,,,, -47500,0.38131815,2.953891,,,,,,,,,,,,,,,,, -47600,0.34498352,2.8796735,,,,,,,,,,,,,,,,, -47700,,,0.641980767250061,1.836068153381348,31.306511234966308,0.6625956296920776,1.6862725019454956,28.62275325157356,3000.0,0.6753239631652832,1.6231544017791748,28.22375180587874,3003.0,16838.640387535095,28156.87495303154,16838.640387535095,11316.14895439148,0.5931167602539062,0.0 -47700,0.37001142,2.982498,,,,,,,,,,,,,,,,, -47800,0.43230045,2.9274657,,,,,,,,,,,,,,,,, -47900,0.36745542,2.9292789,,,,,,,,,,,,,,,,, -48000,0.37478805,2.9073148,,,,,,,,,,,,,,,,, -48100,0.4145952,2.8503654,,,,,,,,,,,,,,,,, -48200,0.36020598,2.935817,,,,,,,,,,,,,,,,, -48300,0.3504834,2.9561317,,,,,,,,,,,,,,,,, -48400,0.4165104,2.9557538,,,,,,,,,,,,,,,,, -48500,0.35834092,2.9222116,,,,,,,,,,,,,,,,, -48600,0.33175853,2.9372885,,,,,,,,,,,,,,,,, -48700,0.38459858,2.963652,,,,,,,,,,,,,,,,, -48800,0.31699204,2.88534,,,,,,,,,,,,,,,,, -48900,0.35947958,2.9884837,,,,,,,,,,,,,,,,, -49000,0.36414298,2.9326038,,,,,,,,,,,,,,,,, -49100,0.36056444,2.9080145,,,,,,,,,,,,,,,,, -49200,0.3205224,2.9013987,,,,,,,,,,,,,,,,, -49300,0.37751082,2.9314446,,,,,,,,,,,,,,,,, -49400,0.36502913,2.9655166,,,,,,,,,,,,,,,,, -49500,0.36514813,2.8779364,,,,,,,,,,,,,,,,, -49600,0.35325438,2.9272583,,,,,,,,,,,,,,,,, -49700,0.3762549,2.9769096,,,,,,,,,,,,,,,,, -49800,0.38018328,3.0068343,,,,,,,,,,,,,,,,, -49900,0.39436486,3.007029,,,,,,,,,,,,,,,,, -50000,0.33275992,2.9261324,,,,,,,,,,,,,,,,, -50087,,,0.691525399684906,1.5449894666671753,35.324233093710475,0.6646042466163635,1.6788150072097778,28.461571849251246,3000.0,0.6756957769393921,1.6147881746292114,28.095035425482383,3003.0,17678.58997631073,29477.193959236145,17678.58997631073,11796.412520170212,0.6238071918487549,0.0 -50100,0.37836626,2.9295225,,,,,,,,,,,,,,,,, -50200,0.32229897,2.925793,,,,,,,,,,,,,,,,, -50300,0.40662333,2.9286406,,,,,,,,,,,,,,,,, -50400,0.34016904,2.934885,,,,,,,,,,,,,,,,, -50500,0.39101192,2.8978276,,,,,,,,,,,,,,,,, -50600,0.3792705,2.9312632,,,,,,,,,,,,,,,,, -50700,0.48447058,2.913826,,,,,,,,,,,,,,,,, -50800,0.3734644,2.9779122,,,,,,,,,,,,,,,,, -50900,0.40323174,2.9567246,,,,,,,,,,,,,,,,, -51000,0.36630553,2.905029,,,,,,,,,,,,,,,,, -51100,0.3333621,2.8725355,,,,,,,,,,,,,,,,, -51200,0.34661028,2.8984745,,,,,,,,,,,,,,,,, -51300,0.37534475,2.951036,,,,,,,,,,,,,,,,, -51400,0.39054406,2.856498,,,,,,,,,,,,,,,,, -51500,0.32129002,2.8824883,,,,,,,,,,,,,,,,, -51600,0.34584194,2.9746175,,,,,,,,,,,,,,,,, -51700,0.34131026,2.9299133,,,,,,,,,,,,,,,,, -51800,0.36120817,2.9876451,,,,,,,,,,,,,,,,, -51900,0.3942555,2.8941746,,,,,,,,,,,,,,,,, -52000,0.38886508,2.9025888,,,,,,,,,,,,,,,,, -52100,0.34598824,2.9375849,,,,,,,,,,,,,,,,, -52200,0.3560809,2.9700735,,,,,,,,,,,,,,,,, -52300,0.34301835,2.9109876,,,,,,,,,,,,,,,,, -52400,0.40279946,2.9828033,,,,,,,,,,,,,,,,, -52474,,,0.648737370967865,1.7925587892532349,31.28400718143166,0.6651870608329773,1.667005181312561,28.830817072279498,3000.0,0.6764511466026306,1.6009951829910278,28.30420204851027,3003.0,18518.520733118057,30793.76582384109,18518.520733118057,12272.941549777985,0.6620500087738037,0.0 -52500,0.36368343,2.9296129,,,,,,,,,,,,,,,,, -52600,0.4062297,2.9414315,,,,,,,,,,,,,,,,, -52700,0.3615379,2.873775,,,,,,,,,,,,,,,,, -52800,0.34619993,2.9222345,,,,,,,,,,,,,,,,, -52900,0.3721849,2.904576,,,,,,,,,,,,,,,,, -53000,0.34498844,2.9050138,,,,,,,,,,,,,,,,, -53100,0.35439926,2.8563719,,,,,,,,,,,,,,,,, -53200,0.3547515,2.8775673,,,,,,,,,,,,,,,,, -53300,0.3494804,2.8378217,,,,,,,,,,,,,,,,, -53400,0.36221886,2.924244,,,,,,,,,,,,,,,,, -53500,0.3315573,2.9112852,,,,,,,,,,,,,,,,, -53600,0.4418834,2.8853688,,,,,,,,,,,,,,,,, -53700,0.39348426,2.9679172,,,,,,,,,,,,,,,,, -53800,0.34190068,2.9055028,,,,,,,,,,,,,,,,, -53900,0.36282822,2.898566,,,,,,,,,,,,,,,,, -54000,0.34659767,2.8889687,,,,,,,,,,,,,,,,, -54100,0.36337864,3.0223353,,,,,,,,,,,,,,,,, -54200,0.36693236,2.946029,,,,,,,,,,,,,,,,, -54300,0.3733176,2.924258,,,,,,,,,,,,,,,,, -54400,0.33282936,2.9109998,,,,,,,,,,,,,,,,, -54500,0.38161933,2.9841948,,,,,,,,,,,,,,,,, -54600,0.32558835,2.92766,,,,,,,,,,,,,,,,, -54700,0.3730932,2.9515772,,,,,,,,,,,,,,,,, -54800,0.38267586,2.871406,,,,,,,,,,,,,,,,, -54860,,,0.6473584175109863,1.8072962760925293,31.921184654644467,0.6671088933944702,1.663117289543152,29.0922268005727,3000.0,0.6795421838760376,1.5950560569763184,28.502942846220005,3003.0,19358.474761724472,32095.981645822525,19358.474761724472,12735.091276407242,0.6979632377624512,0.0 -54900,0.40760285,2.9369247,,,,,,,,,,,,,,,,, -55000,0.4143403,2.9566636,,,,,,,,,,,,,,,,, -55100,0.3624591,2.8984082,,,,,,,,,,,,,,,,, -55200,0.34697458,2.8763087,,,,,,,,,,,,,,,,, -55300,0.3494669,2.8683474,,,,,,,,,,,,,,,,, -55400,0.34790775,2.9127178,,,,,,,,,,,,,,,,, -55500,0.41852728,2.9715793,,,,,,,,,,,,,,,,, -55600,0.38322031,2.981687,,,,,,,,,,,,,,,,, -55700,0.35518917,2.8124323,,,,,,,,,,,,,,,,, -55800,0.36336282,2.9110992,,,,,,,,,,,,,,,,, -55900,0.35128406,2.8934245,,,,,,,,,,,,,,,,, -56000,0.3793318,2.928317,,,,,,,,,,,,,,,,, -56100,0.38049734,2.913874,,,,,,,,,,,,,,,,, -56200,0.3608461,2.9681723,,,,,,,,,,,,,,,,, -56300,0.4043609,2.9435022,,,,,,,,,,,,,,,,, -56400,0.3678525,2.8982859,,,,,,,,,,,,,,,,, -56500,0.32429627,2.9471858,,,,,,,,,,,,,,,,, -56600,0.3476688,2.926886,,,,,,,,,,,,,,,,, -56700,0.34880468,2.8312473,,,,,,,,,,,,,,,,, -56800,0.31401074,2.8821409,,,,,,,,,,,,,,,,, -56900,0.3180441,2.8367255,,,,,,,,,,,,,,,,, -57000,0.3462409,2.7893848,,,,,,,,,,,,,,,,, -57100,0.37645698,2.9728274,,,,,,,,,,,,,,,,, -57200,0.34219852,2.9441562,,,,,,,,,,,,,,,,, -57247,,,0.6565492749214172,1.7392457723617554,32.6825347018119,0.6668609380722046,1.654740810394287,29.00122112780653,3000.0,0.6824705004692078,1.5814528465270996,28.750735001154936,3003.0,20198.645827054977,33403.252178907394,20198.645827054977,13202.08388352394,0.7302701473236084,0.0 -57300,0.3783384,2.9247437,,,,,,,,,,,,,,,,, -57400,0.37363,2.8812723,,,,,,,,,,,,,,,,, -57500,0.36109743,2.8715665,,,,,,,,,,,,,,,,, -57600,0.39488727,2.8479836,,,,,,,,,,,,,,,,, -57700,0.36626655,2.8818321,,,,,,,,,,,,,,,,, -57800,0.36264002,2.8483078,,,,,,,,,,,,,,,,, -57900,0.36109975,2.9434805,,,,,,,,,,,,,,,,, -58000,0.35448283,2.967452,,,,,,,,,,,,,,,,, -58100,0.36729604,2.909741,,,,,,,,,,,,,,,,, -58200,0.3180381,2.8859155,,,,,,,,,,,,,,,,, -58300,0.36314952,2.8497438,,,,,,,,,,,,,,,,, -58400,0.3422065,2.9131434,,,,,,,,,,,,,,,,, -58500,0.36074802,2.9068751,,,,,,,,,,,,,,,,, -58600,0.3700015,2.8709173,,,,,,,,,,,,,,,,, -58700,0.3467976,2.8863294,,,,,,,,,,,,,,,,, -58800,0.3580101,2.8792152,,,,,,,,,,,,,,,,, -58900,0.35469013,2.9931438,,,,,,,,,,,,,,,,, -59000,0.45061785,2.8983834,,,,,,,,,,,,,,,,, -59100,0.40385264,2.8632848,,,,,,,,,,,,,,,,, -59200,0.47240606,2.9759617,,,,,,,,,,,,,,,,, -59300,0.43436128,2.9050765,,,,,,,,,,,,,,,,, -59400,0.3891749,2.9063296,,,,,,,,,,,,,,,,, -59500,0.36430147,2.8704371,,,,,,,,,,,,,,,,, -59600,0.38844207,2.9477854,,,,,,,,,,,,,,,,, -59634,,,0.6492637395858765,1.7962538003921509,31.7395194006414,0.6694399118423462,1.6482961177825928,28.93402112722786,3000.0,0.6804021000862122,1.5773099660873413,28.77173907460039,3003.0,21038.74125480652,34720.2066423893,21038.74125480652,13678.836746692656,0.7625217437744141,0.0 -59700,0.34333184,2.8567703,,,,,,,,,,,,,,,,, -59800,0.33789268,2.845831,,,,,,,,,,,,,,,,, -59900,0.33995646,2.851083,,,,,,,,,,,,,,,,, -60000,0.35741204,2.9439611,,,,,,,,,,,,,,,,, -60100,0.34618303,2.8528178,,,,,,,,,,,,,,,,, -60200,0.40841383,2.9232857,,,,,,,,,,,,,,,,, -60300,0.3759478,2.8951738,,,,,,,,,,,,,,,,, -60400,0.33384183,2.877125,,,,,,,,,,,,,,,,, -60500,0.36360568,2.8950326,,,,,,,,,,,,,,,,, -60600,0.3558048,2.8806403,,,,,,,,,,,,,,,,, -60700,0.3923873,2.862977,,,,,,,,,,,,,,,,, -60800,0.3347488,2.909275,,,,,,,,,,,,,,,,, -60900,0.4567451,2.8526437,,,,,,,,,,,,,,,,, -61000,0.32831195,2.8690476,,,,,,,,,,,,,,,,, -61100,0.37187183,2.9208758,,,,,,,,,,,,,,,,, -61200,0.37455449,2.9334776,,,,,,,,,,,,,,,,, -61300,0.36141685,2.8683345,,,,,,,,,,,,,,,,, -61400,0.338929,2.8713255,,,,,,,,,,,,,,,,, -61500,0.36014655,2.8559327,,,,,,,,,,,,,,,,, -61600,0.40800464,2.8233676,,,,,,,,,,,,,,,,, -61700,0.34171373,2.892386,,,,,,,,,,,,,,,,, -61800,0.34051818,2.8832288,,,,,,,,,,,,,,,,, -61900,0.35152045,2.9322114,,,,,,,,,,,,,,,,, -62000,0.36628354,2.9473855,,,,,,,,,,,,,,,,, -62021,,,0.6487945318222046,1.79489004611969,32.09413766752779,0.6691547632217407,1.6473053693771362,28.96760031948729,3000.0,0.6833536624908447,1.5698957443237305,28.95119878578508,3003.0,21878.683776140213,36047.12191772461,21878.683776140213,14165.701899528503,0.7960004806518555,0.0 -62100,0.39284486,2.8731651,,,,,,,,,,,,,,,,, -62200,0.34071907,2.9105477,,,,,,,,,,,,,,,,, -62300,0.36239594,2.830334,,,,,,,,,,,,,,,,, -62400,0.36662135,2.8989708,,,,,,,,,,,,,,,,, -62500,0.3873128,2.9344802,,,,,,,,,,,,,,,,, -62600,0.3933289,2.9430847,,,,,,,,,,,,,,,,, -62700,0.3961607,2.9439027,,,,,,,,,,,,,,,,, -62800,0.32780203,2.8731527,,,,,,,,,,,,,,,,, -62900,0.34730303,2.9413717,,,,,,,,,,,,,,,,, -63000,0.36732018,2.9145741,,,,,,,,,,,,,,,,, -63100,0.35974863,2.9008775,,,,,,,,,,,,,,,,, -63200,0.37593895,2.9118888,,,,,,,,,,,,,,,,, -63300,0.40278488,2.9147477,,,,,,,,,,,,,,,,, -63400,0.36201143,2.8382995,,,,,,,,,,,,,,,,, -63500,0.36260274,2.8272426,,,,,,,,,,,,,,,,, -63600,0.3577426,2.850017,,,,,,,,,,,,,,,,, -63700,0.36376208,2.8895166,,,,,,,,,,,,,,,,, -63800,0.39544985,2.925074,,,,,,,,,,,,,,,,, -63900,0.40840065,2.9348767,,,,,,,,,,,,,,,,, -64000,0.32681724,2.8198047,,,,,,,,,,,,,,,,, -64100,0.3995207,2.8807502,,,,,,,,,,,,,,,,, -64200,0.34614947,2.9529414,,,,,,,,,,,,,,,,, -64300,0.35172096,2.876583,,,,,,,,,,,,,,,,, -64400,0.3576386,2.8660588,,,,,,,,,,,,,,,,, -64408,,,0.6550009250640869,1.7537881135940552,32.452760515426576,0.6708534359931946,1.635170578956604,28.91921608404348,3000.0,0.6835861206054688,1.5647251605987549,29.056131315412504,3003.0,22718.78689146042,37366.41644477844,22718.78689146042,14644.784934043884,0.8291144371032715,0.0 -64500,0.3918827,2.9497554,,,,,,,,,,,,,,,,, -64600,0.34876928,2.8585186,,,,,,,,,,,,,,,,, -64700,0.3341275,2.833041,,,,,,,,,,,,,,,,, -64800,0.33263993,2.8874893,,,,,,,,,,,,,,,,, -64900,0.35859153,2.9570765,,,,,,,,,,,,,,,,, -65000,0.36282656,2.844728,,,,,,,,,,,,,,,,, -65100,0.37669626,2.8547769,,,,,,,,,,,,,,,,, -65200,0.34139812,2.8851233,,,,,,,,,,,,,,,,, -65300,0.362689,2.9653265,,,,,,,,,,,,,,,,, -65400,0.3423501,2.843228,,,,,,,,,,,,,,,,, -65500,0.39455196,2.861307,,,,,,,,,,,,,,,,, -65600,0.39947245,2.8246095,,,,,,,,,,,,,,,,, -65700,0.34293535,2.8987043,,,,,,,,,,,,,,,,, -65800,0.3914863,2.8454711,,,,,,,,,,,,,,,,, -65900,0.3887566,2.9268768,,,,,,,,,,,,,,,,, -66000,0.38151225,2.904504,,,,,,,,,,,,,,,,, -66100,0.346279,2.8831403,,,,,,,,,,,,,,,,, -66200,0.43666705,2.811443,,,,,,,,,,,,,,,,, -66300,0.35278624,2.8119602,,,,,,,,,,,,,,,,, -66400,0.3998039,2.8299577,,,,,,,,,,,,,,,,, -66500,0.34950525,2.9330146,,,,,,,,,,,,,,,,, -66600,0.3601841,2.9380045,,,,,,,,,,,,,,,,, -66700,0.36332244,2.8515985,,,,,,,,,,,,,,,,, -66795,,,0.6545759439468384,1.7578319311141968,32.203787090257514,0.6715105772018433,1.635469675064087,29.41297111527696,3000.0,0.6852013468742371,1.5587241649627686,29.178349712756702,3003.0,23558.88369011879,38686.76952624321,23558.88369011879,15124.934435129166,0.8623538017272949,0.0 -66800,0.3412201,2.8950799,,,,,,,,,,,,,,,,, -66900,0.3621137,2.8382218,,,,,,,,,,,,,,,,, -67000,0.36978832,2.8648627,,,,,,,,,,,,,,,,, -67100,0.3759505,2.8910651,,,,,,,,,,,,,,,,, -67200,0.38241968,2.9647858,,,,,,,,,,,,,,,,, -67300,0.33158636,2.858234,,,,,,,,,,,,,,,,, -67400,0.34988448,2.8348615,,,,,,,,,,,,,,,,, -67500,0.3481727,2.8672106,,,,,,,,,,,,,,,,, -67600,0.35495538,2.8853655,,,,,,,,,,,,,,,,, -67700,0.33545193,2.9028552,,,,,,,,,,,,,,,,, -67800,0.35163745,2.881038,,,,,,,,,,,,,,,,, -67900,0.3607489,2.8039937,,,,,,,,,,,,,,,,, -68000,0.38677293,2.9485939,,,,,,,,,,,,,,,,, -68100,0.36582294,2.8530567,,,,,,,,,,,,,,,,, -68200,0.3849465,2.8718076,,,,,,,,,,,,,,,,, -68300,0.36501604,2.9136834,,,,,,,,,,,,,,,,, -68400,0.37544215,2.975252,,,,,,,,,,,,,,,,, -68500,0.36069864,2.869818,,,,,,,,,,,,,,,,, -68600,0.38180968,2.805455,,,,,,,,,,,,,,,,, -68700,0.35728738,2.8267236,,,,,,,,,,,,,,,,, -68800,0.39392734,2.9366667,,,,,,,,,,,,,,,,, -68900,0.3718972,2.9073324,,,,,,,,,,,,,,,,, -69000,0.37462348,2.8731859,,,,,,,,,,,,,,,,, -69100,0.38175455,2.9081037,,,,,,,,,,,,,,,,, -69181,,,0.6702680587768555,1.6583515405654907,33.066980321510584,0.6727132797241211,1.62405264377594,29.20222921314499,3000.0,0.6837836503982544,1.5548111200332642,28.954724806981417,3003.0,24398.984651088715,40062.75036764145,24398.984651088715,15660.699274778366,0.9025025367736816,0.0 -69200,0.35511136,2.891606,,,,,,,,,,,,,,,,, -69300,0.36773437,2.940162,,,,,,,,,,,,,,,,, -69400,0.3427697,2.8244627,,,,,,,,,,,,,,,,, -69500,0.3759562,2.8509839,,,,,,,,,,,,,,,,, -69600,0.34358126,2.8609645,,,,,,,,,,,,,,,,, -69700,0.40329596,2.866577,,,,,,,,,,,,,,,,, -69800,0.38206464,2.8451025,,,,,,,,,,,,,,,,, -69900,0.37654114,2.84502,,,,,,,,,,,,,,,,, -70000,0.32836014,2.8650734,,,,,,,,,,,,,,,,, -70100,0.3463364,2.8682773,,,,,,,,,,,,,,,,, -70200,0.37519905,2.8213224,,,,,,,,,,,,,,,,, -70300,0.3270203,2.8274243,,,,,,,,,,,,,,,,, -70400,0.399723,2.866746,,,,,,,,,,,,,,,,, -70500,0.3963324,2.8359418,,,,,,,,,,,,,,,,, -70600,0.31676704,2.8781056,,,,,,,,,,,,,,,,, -70700,0.35747716,2.854538,,,,,,,,,,,,,,,,, -70800,0.36281982,2.831595,,,,,,,,,,,,,,,,, -70900,0.37849212,2.8408606,,,,,,,,,,,,,,,,, -71000,0.3673647,2.819143,,,,,,,,,,,,,,,,, -71100,0.379703,2.8449156,,,,,,,,,,,,,,,,, -71200,0.35032433,2.8540897,,,,,,,,,,,,,,,,, -71300,0.34793374,2.8614073,,,,,,,,,,,,,,,,, -71400,0.36845824,2.8806822,,,,,,,,,,,,,,,,, -71500,0.38186964,2.8688934,,,,,,,,,,,,,,,,, -71568,,,0.655456006526947,1.7491358518600464,32.35903194112188,0.6749451160430908,1.6112587451934814,29.35684438996192,3000.0,0.6880018711090088,1.5388880968093872,28.961093312146044,3003.0,25239.0883140564,41433.078765153885,25239.0883140564,16190.815400600432,0.9364674091339112,0.0 -71600,0.4542327,2.8106933,,,,,,,,,,,,,,,,, -71700,0.3706144,2.9095383,,,,,,,,,,,,,,,,, -71800,0.33586502,2.8255165,,,,,,,,,,,,,,,,, -71900,0.38919523,2.8708906,,,,,,,,,,,,,,,,, -72000,0.37372908,2.7945704,,,,,,,,,,,,,,,,, -72100,0.36664554,2.826209,,,,,,,,,,,,,,,,, -72200,0.36236212,2.8365011,,,,,,,,,,,,,,,,, -72300,0.3734102,2.8298013,,,,,,,,,,,,,,,,, -72400,0.35355046,2.896274,,,,,,,,,,,,,,,,, -72500,0.36640513,2.8961465,,,,,,,,,,,,,,,,, -72600,0.36858943,2.835675,,,,,,,,,,,,,,,,, -72700,0.34219873,2.7921586,,,,,,,,,,,,,,,,, -72800,0.36187446,2.8981795,,,,,,,,,,,,,,,,, -72900,0.35969839,2.931909,,,,,,,,,,,,,,,,, -73000,0.3477342,2.8735359,,,,,,,,,,,,,,,,, -73100,0.3531297,2.856785,,,,,,,,,,,,,,,,, -73200,0.36559778,2.857525,,,,,,,,,,,,,,,,, -73300,0.3604804,2.8481157,,,,,,,,,,,,,,,,, -73400,0.3512067,2.8021936,,,,,,,,,,,,,,,,, -73500,0.3623971,2.8883486,,,,,,,,,,,,,,,,, -73600,0.40312785,2.7536304,,,,,,,,,,,,,,,,, -73700,0.37610126,2.9308467,,,,,,,,,,,,,,,,, -73800,0.39371234,2.8459883,,,,,,,,,,,,,,,,, -73900,0.41414875,2.8677576,,,,,,,,,,,,,,,,, -73955,,,0.6537445783615112,1.7614141702651978,32.23235903214701,0.6735440492630005,1.6155986785888672,28.445815162064275,3000.0,0.6869211792945862,1.5449918508529663,29.29112343334928,3003.0,26079.173448324203,42909.42013859749,26079.173448324203,16826.95787167549,0.9769396781921388,0.0 -74000,0.38516596,2.9217384,,,,,,,,,,,,,,,,, -74100,0.37560052,2.8822594,,,,,,,,,,,,,,,,, -74200,0.38414443,2.809568,,,,,,,,,,,,,,,,, -74300,0.36877862,2.866972,,,,,,,,,,,,,,,,, -74400,0.3534524,2.830523,,,,,,,,,,,,,,,,, -74500,0.35842887,2.8286104,,,,,,,,,,,,,,,,, -74600,0.3980264,2.834079,,,,,,,,,,,,,,,,, -74700,0.37703848,2.873896,,,,,,,,,,,,,,,,, -74800,0.38360924,2.882093,,,,,,,,,,,,,,,,, -74900,0.3496051,2.8403757,,,,,,,,,,,,,,,,, -75000,0.36172286,2.8475816,,,,,,,,,,,,,,,,, -75100,0.3771396,2.8575463,,,,,,,,,,,,,,,,, -75200,0.37536117,2.8422866,,,,,,,,,,,,,,,,, -75300,0.37155917,2.8120124,,,,,,,,,,,,,,,,, -75400,0.36885068,2.9275405,,,,,,,,,,,,,,,,, -75500,0.3490545,2.860911,,,,,,,,,,,,,,,,, -75600,0.36784995,2.7685773,,,,,,,,,,,,,,,,, -75700,0.3322781,2.868755,,,,,,,,,,,,,,,,, -75800,0.3642257,2.8160944,,,,,,,,,,,,,,,,, -75900,0.36247057,2.9464092,,,,,,,,,,,,,,,,, -76000,0.37663096,2.8046718,,,,,,,,,,,,,,,,, -76100,0.3664805,2.8176458,,,,,,,,,,,,,,,,, -76200,0.388671,2.8556647,,,,,,,,,,,,,,,,, -76300,0.3533168,2.870551,,,,,,,,,,,,,,,,, -76340,,,0.6645991206169128,1.6866929531097412,32.924616682714294,0.67549067735672,1.6071454286575315,29.642798192788497,3000.0,0.6900935769081116,1.5276169776916504,29.379224504547373,3003.0,26919.26244521141,44230.73479604721,26919.26244521141,17308.0654463768,1.0143301486968994,0.0 -76400,0.37419045,2.8016021,,,,,,,,,,,,,,,,, -76500,0.36290628,2.79455,,,,,,,,,,,,,,,,, -76600,0.3942861,2.831524,,,,,,,,,,,,,,,,, -76700,0.3691488,2.8816671,,,,,,,,,,,,,,,,, -76800,0.3747574,2.7707152,,,,,,,,,,,,,,,,, -76900,0.35922578,2.9230742,,,,,,,,,,,,,,,,, -77000,0.36856386,2.840554,,,,,,,,,,,,,,,,, -77100,0.332503,2.82677,,,,,,,,,,,,,,,,, -77200,0.42087832,2.8402443,,,,,,,,,,,,,,,,, -77300,0.38405168,2.8564148,,,,,,,,,,,,,,,,, -77400,0.39665553,2.8303387,,,,,,,,,,,,,,,,, -77500,0.37876335,2.8221936,,,,,,,,,,,,,,,,, -77600,0.38323107,2.8503785,,,,,,,,,,,,,,,,, -77700,0.36303547,2.8667514,,,,,,,,,,,,,,,,, -77800,0.37703907,2.8183231,,,,,,,,,,,,,,,,, -77900,0.40086392,2.8993633,,,,,,,,,,,,,,,,, -78000,0.37578344,2.8265793,,,,,,,,,,,,,,,,, -78100,0.33830294,2.8236315,,,,,,,,,,,,,,,,, -78200,0.42033675,2.893373,,,,,,,,,,,,,,,,, -78300,0.37510315,2.8718035,,,,,,,,,,,,,,,,, -78400,0.47391394,2.8669713,,,,,,,,,,,,,,,,, -78500,0.3687596,2.803927,,,,,,,,,,,,,,,,, -78600,0.36850137,2.7421424,,,,,,,,,,,,,,,,, -78700,0.38774264,2.831677,,,,,,,,,,,,,,,,, -78727,,,0.6578866243362427,1.7336337566375732,33.02379867229866,0.6777225136756897,1.597773551940918,29.78239122588757,3000.0,0.6900470852851868,1.5211902856826782,29.493238645573108,3003.0,27759.4205942154,45561.32496523857,27759.4205942154,17798.384374141693,1.0526585578918457,0.0 -78800,0.3726894,2.859133,,,,,,,,,,,,,,,,, -78900,0.3728995,2.8353715,,,,,,,,,,,,,,,,, -79000,0.3904902,2.8245642,,,,,,,,,,,,,,,,, -79100,0.39014557,2.7687204,,,,,,,,,,,,,,,,, -79200,0.42854998,2.877026,,,,,,,,,,,,,,,,, -79300,0.36777085,2.8149848,,,,,,,,,,,,,,,,, -79400,0.34273654,2.817055,,,,,,,,,,,,,,,,, -79500,0.37531936,2.7892642,,,,,,,,,,,,,,,,, -79600,0.3874235,2.8729646,,,,,,,,,,,,,,,,, -79700,0.39252058,2.8321338,,,,,,,,,,,,,,,,, -79800,0.36991847,2.8202412,,,,,,,,,,,,,,,,, -79900,0.40769377,2.84779,,,,,,,,,,,,,,,,, -80000,0.38193202,2.865148,,,,,,,,,,,,,,,,, -80100,0.3762118,2.801639,,,,,,,,,,,,,,,,, -80200,0.38432404,2.857743,,,,,,,,,,,,,,,,, -80300,0.36887994,2.8588843,,,,,,,,,,,,,,,,, -80400,2.882318,2.7913768,,,,,,,,,,,,,,,,, -80500,0.38750106,2.9328666,,,,,,,,,,,,,,,,, -80600,0.38852555,2.807044,,,,,,,,,,,,,,,,, -80700,0.37386972,2.8395736,,,,,,,,,,,,,,,,, -80800,0.3789359,2.8126454,,,,,,,,,,,,,,,,, -80900,0.35662884,2.7491198,,,,,,,,,,,,,,,,, -81000,0.36635605,2.8896601,,,,,,,,,,,,,,,,, -81100,0.35554582,2.8499353,,,,,,,,,,,,,,,,, -81114,,,0.6640418171882629,1.6978520154953003,32.88214843273441,0.6789872050285339,1.593400001525879,29.676987911310327,3000.0,0.6944047808647156,1.50739848613739,29.8250633065196,3003.0,28599.610867261887,46876.34319114685,28599.610867261887,18273.101594686508,1.0890939235687256,0.0 -81200,0.38638628,2.848693,,,,,,,,,,,,,,,,, -81300,0.3675015,2.8284311,,,,,,,,,,,,,,,,, -81400,0.3810979,2.8188906,,,,,,,,,,,,,,,,, -81500,0.37958825,2.8453794,,,,,,,,,,,,,,,,, -81600,0.37675768,2.802278,,,,,,,,,,,,,,,,, -81700,0.38356748,2.8294325,,,,,,,,,,,,,,,,, -81800,0.3619566,2.8050041,,,,,,,,,,,,,,,,, -81900,0.4012992,2.7911384,,,,,,,,,,,,,,,,, -82000,0.39943606,2.8740985,,,,,,,,,,,,,,,,, -82100,0.36769816,2.743172,,,,,,,,,,,,,,,,, -82200,0.4057025,2.8190222,,,,,,,,,,,,,,,,, -82300,0.36475116,2.795537,,,,,,,,,,,,,,,,, -82400,0.39079273,2.8410535,,,,,,,,,,,,,,,,, -82500,0.37690887,2.826653,,,,,,,,,,,,,,,,, -82600,0.38719723,2.8839464,,,,,,,,,,,,,,,,, -82700,0.38228253,2.837562,,,,,,,,,,,,,,,,, -82800,0.37704942,2.8702462,,,,,,,,,,,,,,,,, -82900,0.39351028,2.8137817,,,,,,,,,,,,,,,,, -83000,0.42063183,2.8289587,,,,,,,,,,,,,,,,, -83100,0.3704788,2.809563,,,,,,,,,,,,,,,,, -83200,0.3773887,2.8092048,,,,,,,,,,,,,,,,, -83300,0.37011334,2.8003285,,,,,,,,,,,,,,,,, -83400,0.35898992,2.8177724,,,,,,,,,,,,,,,,, -83500,0.379119,2.7519834,,,,,,,,,,,,,,,,, -83501,,,0.6630203723907471,1.7007750272750854,32.74125649796642,0.6791360378265381,1.5790903568267822,29.65465038029376,3000.0,0.6937307715415955,1.5014517307281494,29.50003298374016,3003.0,29440.040162086487,48205.56838059425,29440.040162086487,18761.78837966919,1.1252844333648682,0.0 -83600,0.38400233,2.7946463,,,,,,,,,,,,,,,,, -83700,0.39104357,2.8912444,,,,,,,,,,,,,,,,, -83800,0.37010065,2.81718,,,,,,,,,,,,,,,,, -83900,0.38741022,2.81709,,,,,,,,,,,,,,,,, -84000,0.4091809,2.8428867,,,,,,,,,,,,,,,,, -84100,0.3512344,2.8260114,,,,,,,,,,,,,,,,, -84200,0.43803793,2.8489907,,,,,,,,,,,,,,,,, -84300,0.4070689,2.7609203,,,,,,,,,,,,,,,,, -84400,0.36696228,2.8385053,,,,,,,,,,,,,,,,, -84500,0.38167292,2.7820067,,,,,,,,,,,,,,,,, -84600,0.37129965,2.7835774,,,,,,,,,,,,,,,,, -84700,0.40345016,2.8236272,,,,,,,,,,,,,,,,, -84800,0.46693844,2.814951,,,,,,,,,,,,,,,,, -84900,0.37112114,2.8380098,,,,,,,,,,,,,,,,, -85000,0.389069,2.7506416,,,,,,,,,,,,,,,,, -85100,0.40166584,2.7961826,,,,,,,,,,,,,,,,, -85200,0.4128211,2.8117244,,,,,,,,,,,,,,,,, -85300,0.4199362,2.7558515,,,,,,,,,,,,,,,,, -85400,0.38738635,2.8579087,,,,,,,,,,,,,,,,, -85500,0.4263467,2.7629414,,,,,,,,,,,,,,,,, -85600,0.38537362,2.7897677,,,,,,,,,,,,,,,,, -85700,0.41298252,2.860823,,,,,,,,,,,,,,,,, -85800,0.3776457,2.7887049,,,,,,,,,,,,,,,,, -85888,,,0.6633317470550537,1.6965419054031372,32.76356242688841,0.6802891492843628,1.5811126232147217,30.043155582785047,3000.0,0.6941258907318115,1.5003278255462646,29.87476830317458,3003.0,30280.26576089859,49530.71009230614,30280.26576089859,19246.59332036972,1.1620018482208252,0.0 -85900,0.3848093,2.7550175,,,,,,,,,,,,,,,,, -86000,0.52185404,2.8074775,,,,,,,,,,,,,,,,, -86100,0.35337672,2.754965,,,,,,,,,,,,,,,,, -86200,0.37731892,2.7793663,,,,,,,,,,,,,,,,, -86300,0.35972998,2.7927024,,,,,,,,,,,,,,,,, -86400,0.37401998,2.819304,,,,,,,,,,,,,,,,, -86500,0.38404125,2.8431556,,,,,,,,,,,,,,,,, -86600,0.4041275,2.8306355,,,,,,,,,,,,,,,,, -86700,0.39052552,2.7901437,,,,,,,,,,,,,,,,, -86800,0.38343665,2.7586641,,,,,,,,,,,,,,,,, -86900,0.37603047,2.8096116,,,,,,,,,,,,,,,,, -87000,0.38460204,2.7179832,,,,,,,,,,,,,,,,, -87100,0.38401768,2.6929734,,,,,,,,,,,,,,,,, -87200,0.3943275,2.7662857,,,,,,,,,,,,,,,,, -87300,0.4009412,2.7528508,,,,,,,,,,,,,,,,, -87400,0.37949917,2.8375282,,,,,,,,,,,,,,,,, -87500,0.40872785,2.7341557,,,,,,,,,,,,,,,,, -87600,0.40203497,2.754588,,,,,,,,,,,,,,,,, -87700,0.38976595,2.7521539,,,,,,,,,,,,,,,,, -87800,0.40144575,2.7993348,,,,,,,,,,,,,,,,, -87900,0.38561213,2.7833838,,,,,,,,,,,,,,,,, -88000,0.38682467,2.7600048,,,,,,,,,,,,,,,,, -88100,0.38675424,2.7657998,,,,,,,,,,,,,,,,, -88200,0.4091496,2.8510005,,,,,,,,,,,,,,,,, -88274,,,0.682263195514679,1.5941760540008545,34.32433497901006,0.6819010376930237,1.5699979066848757,29.823944589013955,3000.0,0.6963802576065063,1.4937410354614258,29.658199063669382,3003.0,31120.169873714447,50868.83356237412,31120.169873714447,19744.70170378685,1.199810266494751,0.0 -88300,0.38501814,2.8596659,,,,,,,,,,,,,,,,, -88400,0.40859938,2.8037448,,,,,,,,,,,,,,,,, -88500,0.4036761,2.755266,,,,,,,,,,,,,,,,, -88600,0.39942196,2.8019223,,,,,,,,,,,,,,,,, -88700,0.4297631,2.8093033,,,,,,,,,,,,,,,,, -88800,0.40166795,2.8301237,,,,,,,,,,,,,,,,, -88900,0.40872326,2.8046787,,,,,,,,,,,,,,,,, -89000,0.41915414,2.7957823,,,,,,,,,,,,,,,,, -89100,0.38122177,2.7840407,,,,,,,,,,,,,,,,, -89200,0.4074019,2.797661,,,,,,,,,,,,,,,,, -89300,0.37354204,2.735684,,,,,,,,,,,,,,,,, -89400,0.39732394,2.839849,,,,,,,,,,,,,,,,, -89500,0.3829883,2.7555182,,,,,,,,,,,,,,,,, -89600,0.40332562,2.7278001,,,,,,,,,,,,,,,,, -89700,0.4059126,2.7418144,,,,,,,,,,,,,,,,, -89800,0.3720674,2.7193844,,,,,,,,,,,,,,,,, -89900,0.3857379,2.7327662,,,,,,,,,,,,,,,,, -90000,0.40455446,2.7329538,,,,,,,,,,,,,,,,, -90100,0.40739825,2.8086436,,,,,,,,,,,,,,,,, -90200,0.4419395,2.8208032,,,,,,,,,,,,,,,,, -90300,0.4236061,2.7654586,,,,,,,,,,,,,,,,, -90400,0.42109036,2.7383754,,,,,,,,,,,,,,,,, -90500,0.4061935,2.7818303,,,,,,,,,,,,,,,,, -90600,0.40210873,2.7971292,,,,,,,,,,,,,,,,, -90660,,,0.6701759696006775,1.6549545526504517,32.95552252211354,0.6838228702545166,1.5629757642745972,30.1105725801552,3000.0,0.6983673572540283,1.478404879570007,30.366281457229302,3003.0,31960.32098174095,52200.76883912087,31960.32098174095,20236.372787475582,1.2370805740356443,0.0 -90700,0.4034303,2.8219573,,,,,,,,,,,,,,,,, -90800,0.40975344,2.8495882,,,,,,,,,,,,,,,,, -90900,0.4017236,2.7757535,,,,,,,,,,,,,,,,, -91000,0.38517034,2.7484398,,,,,,,,,,,,,,,,, -91100,0.41414785,2.7823248,,,,,,,,,,,,,,,,, -91200,0.40722892,2.720677,,,,,,,,,,,,,,,,, -91300,0.3972741,2.7128642,,,,,,,,,,,,,,,,, -91400,0.4116998,2.7963955,,,,,,,,,,,,,,,,, -91500,0.42935815,2.7728512,,,,,,,,,,,,,,,,, -91600,0.41332063,2.7698898,,,,,,,,,,,,,,,,, -91700,0.42548263,2.7340317,,,,,,,,,,,,,,,,, -91800,0.40503755,2.7372568,,,,,,,,,,,,,,,,, -91900,0.4332086,2.7679334,,,,,,,,,,,,,,,,, -92000,0.42759365,2.8114564,,,,,,,,,,,,,,,,, -92100,0.41436288,2.8332086,,,,,,,,,,,,,,,,, -92200,0.38085857,2.7258542,,,,,,,,,,,,,,,,, -92300,0.43493938,2.7536318,,,,,,,,,,,,,,,,, -92400,0.39782116,2.7290986,,,,,,,,,,,,,,,,, -92500,0.41191474,2.813788,,,,,,,,,,,,,,,,, -92600,0.42174163,2.8621562,,,,,,,,,,,,,,,,, -92700,0.40963885,2.7301395,,,,,,,,,,,,,,,,, -92800,0.45261276,2.7852182,,,,,,,,,,,,,,,,, -92900,0.38834977,2.74948,,,,,,,,,,,,,,,,, -93000,0.4242866,2.721284,,,,,,,,,,,,,,,,, -93047,,,0.6740824580192566,1.6410588026046753,32.95555072276922,0.6841452717781067,1.5564584732055664,30.397595709889032,3000.0,0.7005403637886047,1.4703547954559326,30.228897107393603,3003.0,32800.464728832245,53550.79271054268,32800.464728832245,20746.14312171936,1.2735848426818848,0.0 -93100,0.3936081,2.6954749,,,,,,,,,,,,,,,,, -93200,0.40761244,2.7818325,,,,,,,,,,,,,,,,, -93300,0.44204617,2.7974942,,,,,,,,,,,,,,,,, -93400,0.41987872,2.7973187,,,,,,,,,,,,,,,,, -93500,0.39965928,2.7829242,,,,,,,,,,,,,,,,, -93600,0.40515196,2.7794147,,,,,,,,,,,,,,,,, -93700,0.43519238,2.8426435,,,,,,,,,,,,,,,,, -93800,0.44704345,2.8589919,,,,,,,,,,,,,,,,, -93900,0.4246553,2.7223864,,,,,,,,,,,,,,,,, -94000,0.46583784,2.7794185,,,,,,,,,,,,,,,,, -94100,0.44784495,2.8011088,,,,,,,,,,,,,,,,, -94200,0.41522753,2.731444,,,,,,,,,,,,,,,,, -94300,0.42861608,2.7332118,,,,,,,,,,,,,,,,, -94400,0.43144885,2.8127367,,,,,,,,,,,,,,,,, -94500,0.4444779,2.7502704,,,,,,,,,,,,,,,,, -94600,0.41517666,2.7506785,,,,,,,,,,,,,,,,, -94700,0.3985623,2.8139968,,,,,,,,,,,,,,,,, -94800,0.42751032,2.8315926,,,,,,,,,,,,,,,,, -94900,0.39220348,2.8041272,,,,,,,,,,,,,,,,, -95000,0.4277982,2.7506788,,,,,,,,,,,,,,,,, -95100,0.42688987,2.668056,,,,,,,,,,,,,,,,, -95200,0.4245098,2.8170054,,,,,,,,,,,,,,,,, -95300,0.43596944,2.76951,,,,,,,,,,,,,,,,, -95400,0.4550327,2.7698872,,,,,,,,,,,,,,,,, -95434,,,0.6752694249153137,1.6265360116958618,34.1304166462366,0.6851744055747986,1.5516746044158936,30.227560754570124,3000.0,0.7010632753372192,1.465505599975586,30.193043914956924,3003.0,33640.66740679741,54868.87886095047,33640.66740679741,21223.906438589096,1.3174645900726318,0.0 -95500,0.44779176,2.767782,,,,,,,,,,,,,,,,, -95600,0.43548727,2.7010047,,,,,,,,,,,,,,,,, -95700,0.4129841,2.7580342,,,,,,,,,,,,,,,,, -95800,0.42503467,2.755218,,,,,,,,,,,,,,,,, -95900,0.41915935,2.7437837,,,,,,,,,,,,,,,,, -96000,0.4370612,2.808517,,,,,,,,,,,,,,,,, -96100,0.45772457,2.7369745,,,,,,,,,,,,,,,,, -96200,0.456451,2.7748196,,,,,,,,,,,,,,,,, -96300,0.41922897,2.7260125,,,,,,,,,,,,,,,,, -96400,0.44400543,2.8177927,,,,,,,,,,,,,,,,, -96500,0.4407884,2.7014458,,,,,,,,,,,,,,,,, -96600,0.45319912,2.7667873,,,,,,,,,,,,,,,,, -96700,0.4414747,2.817814,,,,,,,,,,,,,,,,, -96800,0.43323353,2.7710133,,,,,,,,,,,,,,,,, -96900,0.4400522,2.7390394,,,,,,,,,,,,,,,,, -97000,0.43753687,2.8052835,,,,,,,,,,,,,,,,, -97100,0.40227687,2.7148418,,,,,,,,,,,,,,,,, -97200,0.4394483,2.7950804,,,,,,,,,,,,,,,,, -97300,0.45348036,2.7751956,,,,,,,,,,,,,,,,, -97400,0.45951977,2.7145941,,,,,,,,,,,,,,,,, -97500,0.4486833,2.6932235,,,,,,,,,,,,,,,,, -97600,0.4162593,2.7217975,,,,,,,,,,,,,,,,, -97700,0.46465224,2.814621,,,,,,,,,,,,,,,,, -97800,0.42486575,2.7102234,,,,,,,,,,,,,,,,, -97821,,,0.6793409585952759,1.6154285669326782,33.39148472321215,0.6855835318565369,1.547755002975464,30.3606306216234,3000.0,0.7020161747932434,1.464707612991333,30.33091817118647,3003.0,34480.82843494415,56180.17577815056,34480.82843494415,21694.93043994904,1.3555314540863037,0.0 -97900,0.4582218,2.6909122,,,,,,,,,,,,,,,,, -98000,0.4575694,2.7794693,,,,,,,,,,,,,,,,, -98100,0.46658346,2.7717133,,,,,,,,,,,,,,,,, -98200,0.43189433,2.687551,,,,,,,,,,,,,,,,, -98300,0.47664815,2.7960956,,,,,,,,,,,,,,,,, -98400,0.49776173,2.7979376,,,,,,,,,,,,,,,,, -98500,0.43893778,2.7405343,,,,,,,,,,,,,,,,, -98600,0.4209287,2.6607122,,,,,,,,,,,,,,,,, -98700,0.42869976,2.710132,,,,,,,,,,,,,,,,, -98800,0.45109886,2.7791064,,,,,,,,,,,,,,,,, -98900,0.43055326,2.7597127,,,,,,,,,,,,,,,,, -99000,0.45274088,2.7421148,,,,,,,,,,,,,,,,, -99100,0.45167893,2.7297,,,,,,,,,,,,,,,,, -99200,0.45408744,2.6973546,,,,,,,,,,,,,,,,, -99300,0.4521922,2.6984522,,,,,,,,,,,,,,,,, -99400,0.43061882,2.7512596,,,,,,,,,,,,,,,,, -99500,0.47527155,2.7200458,,,,,,,,,,,,,,,,, -99600,0.47364932,2.7194498,,,,,,,,,,,,,,,,, -99700,0.46535832,2.7485993,,,,,,,,,,,,,,,,, -99800,0.45514488,2.6991293,,,,,,,,,,,,,,,,, -99900,0.46140757,2.7092724,,,,,,,,,,,,,,,,, -100000,0.4640581,2.7340124,,,,,,,,,,,,,,,,, -100100,0.44374126,2.7589307,,,,,,,,,,,,,,,,, -100200,0.44659632,2.7222211,,,,,,,,,,,,,,,,, -100207,,,0.7089230418205261,1.443196415901184,36.04611483927222,0.6880261898040771,1.534866452217102,30.359806178761417,3000.0,0.7030271291732788,1.4542731046676636,30.08765700556365,3003.0,35320.90512704849,57482.59611940384,35320.90512704849,22157.154767990112,1.3990330696105957,0.0 -100300,0.43696645,2.6498787,,,,,,,,,,,,,,,,, -100400,0.42661992,2.7828734,,,,,,,,,,,,,,,,, -100500,0.44560534,2.7470715,,,,,,,,,,,,,,,,, -100600,0.4452286,2.73502,,,,,,,,,,,,,,,,, -100700,0.4661126,2.675688,,,,,,,,,,,,,,,,, -100800,0.46351278,2.6696107,,,,,,,,,,,,,,,,, -100900,0.46897206,2.7982254,,,,,,,,,,,,,,,,, -101000,0.5099988,2.729761,,,,,,,,,,,,,,,,, -101100,0.4342309,2.6791947,,,,,,,,,,,,,,,,, -101200,0.48830277,2.720391,,,,,,,,,,,,,,,,, -101300,0.4651472,2.7032702,,,,,,,,,,,,,,,,, -101400,0.46526968,2.71776,,,,,,,,,,,,,,,,, -101500,0.47793016,2.764624,,,,,,,,,,,,,,,,, -101600,0.46366826,2.7967818,,,,,,,,,,,,,,,,, -101700,0.47380888,2.7716882,,,,,,,,,,,,,,,,, -101800,0.4911435,2.7802775,,,,,,,,,,,,,,,,, -101900,0.47832948,2.7706516,,,,,,,,,,,,,,,,, -102000,0.48529556,2.7490735,,,,,,,,,,,,,,,,, -102100,0.46336108,2.6476195,,,,,,,,,,,,,,,,, -102200,0.4983702,2.7355485,,,,,,,,,,,,,,,,, -102300,0.45567584,2.7154765,,,,,,,,,,,,,,,,, -102400,0.47027043,2.7036448,,,,,,,,,,,,,,,,, -102500,0.47634634,2.7089176,,,,,,,,,,,,,,,,, -102593,,,0.6842008233070374,1.579800724983215,34.35210443372337,0.6877781748771667,1.531459927558899,30.39731603776197,3000.0,0.7057347297668457,1.445566177368164,30.57403197903927,3003.0,36160.99453735352,58801.58122611046,36160.99453735352,22635.92946076393,1.444124460220337,0.0 -102600,0.4972383,2.7429006,,,,,,,,,,,,,,,,, -102700,0.48257217,2.8071477,,,,,,,,,,,,,,,,, -102800,0.4809215,2.7009602,,,,,,,,,,,,,,,,, -102900,0.47480854,2.6861684,,,,,,,,,,,,,,,,, -103000,0.49837932,2.6965227,,,,,,,,,,,,,,,,, -103100,0.48380655,2.7203443,,,,,,,,,,,,,,,,, -103200,0.5105604,2.7056906,,,,,,,,,,,,,,,,, -103300,0.5027172,2.7147322,,,,,,,,,,,,,,,,, -103400,0.49685803,2.7247424,,,,,,,,,,,,,,,,, -103500,0.49275452,2.6596112,,,,,,,,,,,,,,,,, -103600,0.47397292,2.6710904,,,,,,,,,,,,,,,,, -103700,0.48689532,2.735639,,,,,,,,,,,,,,,,, -103800,0.48837242,2.7025437,,,,,,,,,,,,,,,,, -103900,0.4966295,2.6712494,,,,,,,,,,,,,,,,, -104000,0.4968099,2.6848378,,,,,,,,,,,,,,,,, -104100,0.49565634,2.7013233,,,,,,,,,,,,,,,,, -104200,0.49878824,2.6918044,,,,,,,,,,,,,,,,, -104300,0.4832136,2.6873739,,,,,,,,,,,,,,,,, -104400,0.54993266,2.7165747,,,,,,,,,,,,,,,,, -104500,0.4994125,2.7420948,,,,,,,,,,,,,,,,, -104600,0.5047266,2.7211359,,,,,,,,,,,,,,,,, -104700,0.50360215,2.7024028,,,,,,,,,,,,,,,,, -104800,0.48796874,2.6589665,,,,,,,,,,,,,,,,, -104900,0.5231071,2.7193708,,,,,,,,,,,,,,,,, -104979,,,0.683157205581665,1.582433581352234,34.28504809514936,0.6898860335350037,1.5270402431488037,30.47028869134486,3000.0,0.7062576413154602,1.4383666515350342,30.793022631949142,3003.0,37001.09319806099,60153.76279973984,37001.09319806099,23147.89739346504,1.483485221862793,0.0 -105000,0.4708321,2.6834142,,,,,,,,,,,,,,,,, -105100,0.48155814,2.6460583,,,,,,,,,,,,,,,,, -105200,0.49197993,2.7477274,,,,,,,,,,,,,,,,, -105300,0.5177643,2.7031975,,,,,,,,,,,,,,,,, -105400,0.50464803,2.691027,,,,,,,,,,,,,,,,, -105500,0.4920119,2.6461112,,,,,,,,,,,,,,,,, -105600,0.51775736,2.6733546,,,,,,,,,,,,,,,,, -105700,0.5220239,2.6934116,,,,,,,,,,,,,,,,, -105800,0.5031309,2.7317293,,,,,,,,,,,,,,,,, -105900,0.56482065,2.7346916,,,,,,,,,,,,,,,,, -106000,0.49214393,2.7019346,,,,,,,,,,,,,,,,, -106100,0.5039709,2.702329,,,,,,,,,,,,,,,,, -106200,0.49550074,2.6027384,,,,,,,,,,,,,,,,, -106300,0.51594067,2.6862504,,,,,,,,,,,,,,,,, -106400,0.50881386,2.652474,,,,,,,,,,,,,,,,, -106500,0.5487058,2.7244992,,,,,,,,,,,,,,,,, -106600,0.5274123,2.753338,,,,,,,,,,,,,,,,, -106700,0.5152894,2.652728,,,,,,,,,,,,,,,,, -106800,0.52337193,2.7001917,,,,,,,,,,,,,,,,, -106900,0.49649906,2.6509027,,,,,,,,,,,,,,,,, -107000,0.5308513,2.7050138,,,,,,,,,,,,,,,,, -107100,0.5218177,2.6946542,,,,,,,,,,,,,,,,, -107200,0.5256703,2.6851966,,,,,,,,,,,,,,,,, -107300,0.5205349,2.704291,,,,,,,,,,,,,,,,, -107365,,,0.6957497596740723,1.5065596103668213,35.563739134784505,0.6896504759788513,1.5250399112701416,30.50287564019885,3000.0,0.7057463526725769,1.435275673866272,30.804335144952017,3003.0,37841.05610227585,61481.371970653534,37841.05610227585,23635.432076931,1.522993564605713,0.0 -107400,0.52147806,2.7555816,,,,,,,,,,,,,,,,, -107500,0.5630663,2.6553967,,,,,,,,,,,,,,,,, -107600,0.514329,2.6250331,,,,,,,,,,,,,,,,, -107700,0.52892816,2.7004113,,,,,,,,,,,,,,,,, -107800,0.5596179,2.757376,,,,,,,,,,,,,,,,, -107900,0.5173555,2.6798868,,,,,,,,,,,,,,,,, -108000,0.547236,2.6942937,,,,,,,,,,,,,,,,, -108100,0.5155372,2.6672235,,,,,,,,,,,,,,,,, -108200,0.5311218,2.7036102,,,,,,,,,,,,,,,,, -108300,0.5106531,2.7250493,,,,,,,,,,,,,,,,, -108400,0.5165537,2.7139537,,,,,,,,,,,,,,,,, -108500,0.5306741,2.6754556,,,,,,,,,,,,,,,,, -108600,0.5461031,2.7300382,,,,,,,,,,,,,,,,, -108700,0.52838767,2.6715443,,,,,,,,,,,,,,,,, -108800,0.56628615,2.7007916,,,,,,,,,,,,,,,,, -108900,0.5496601,2.6338975,,,,,,,,,,,,,,,,, -109000,0.54975396,2.6731353,,,,,,,,,,,,,,,,, -109100,0.5214586,2.6862783,,,,,,,,,,,,,,,,, -109200,0.5502922,2.66593,,,,,,,,,,,,,,,,, -109300,0.5515851,2.6923757,,,,,,,,,,,,,,,,, -109400,0.5582844,2.724206,,,,,,,,,,,,,,,,, -109500,0.5675018,2.6642656,,,,,,,,,,,,,,,,, -109600,0.53069156,2.667533,,,,,,,,,,,,,,,,, -109700,0.5790786,2.711871,,,,,,,,,,,,,,,,, -109751,,,0.6920896172523499,1.5295428037643433,35.06577869851787,0.6902828216552734,1.5207116603851318,30.492036028114327,3000.0,0.7073732018470764,1.4307950735092163,30.37974883692844,3003.0,38681.15629196167,62864.20189833641,38681.15629196167,24178.042583703995,1.5672264099121094,0.0 -109800,0.5385588,2.6732247,,,,,,,,,,,,,,,,, -109900,0.5465653,2.7100434,,,,,,,,,,,,,,,,, -110000,0.54029596,2.678369,,,,,,,,,,,,,,,,, -110100,0.5587586,2.6845076,,,,,,,,,,,,,,,,, -110200,0.56343204,2.6964848,,,,,,,,,,,,,,,,, -110300,0.5774981,2.7144914,,,,,,,,,,,,,,,,, -110400,0.5559031,2.6516647,,,,,,,,,,,,,,,,, -110500,0.5746293,2.6699631,,,,,,,,,,,,,,,,, -110600,0.5389323,2.6723301,,,,,,,,,,,,,,,,, -110700,0.5549944,2.6593666,,,,,,,,,,,,,,,,, -110800,0.58996004,2.666639,,,,,,,,,,,,,,,,, -110900,0.5594439,2.738938,,,,,,,,,,,,,,,,, -111000,0.5405295,2.670995,,,,,,,,,,,,,,,,, -111100,0.53526,2.6514268,,,,,,,,,,,,,,,,, -111200,0.54496396,2.623273,,,,,,,,,,,,,,,,, -111300,0.5825534,2.6707816,,,,,,,,,,,,,,,,, -111400,0.5656853,2.6618156,,,,,,,,,,,,,,,,, -111500,0.57240695,2.683769,,,,,,,,,,,,,,,,, -111600,0.56584555,2.677716,,,,,,,,,,,,,,,,, -111700,0.5683456,2.6381729,,,,,,,,,,,,,,,,, -111800,0.5651491,2.5941782,,,,,,,,,,,,,,,,, -111900,0.56089574,2.659377,,,,,,,,,,,,,,,,, -112000,0.5862379,2.634398,,,,,,,,,,,,,,,,, -112100,0.5636808,2.6606534,,,,,,,,,,,,,,,,, -112137,,,0.6917627453804016,1.534698724746704,34.74488382371364,0.6911011338233948,1.5178720951080322,30.914413893236805,3000.0,0.7082563638687134,1.4285781383514404,30.58984000784102,3003.0,39521.200323581696,64166.88848829269,39521.200323581696,24640.57275032997,1.606520652770996,0.0 -112200,0.5861618,2.622304,,,,,,,,,,,,,,,,, -112300,0.5574143,2.6606574,,,,,,,,,,,,,,,,, -112400,0.57680726,2.6538546,,,,,,,,,,,,,,,,, -112500,0.5742762,2.641152,,,,,,,,,,,,,,,,, -112600,0.5980757,2.6515472,,,,,,,,,,,,,,,,, -112700,0.5916985,2.689498,,,,,,,,,,,,,,,,, -112800,0.551144,2.6983297,,,,,,,,,,,,,,,,, -112900,0.5823511,2.6052113,,,,,,,,,,,,,,,,, -113000,0.5980466,2.6761866,,,,,,,,,,,,,,,,, -113100,0.61739653,2.690747,,,,,,,,,,,,,,,,, -113200,0.6060722,2.6999989,,,,,,,,,,,,,,,,, -113300,1.4534415,2.7039497,,,,,,,,,,,,,,,,, -113400,0.5861207,2.6660674,,,,,,,,,,,,,,,,, -113500,0.57969564,2.596975,,,,,,,,,,,,,,,,, -113600,0.5941401,2.6408632,,,,,,,,,,,,,,,,, -113700,0.604896,2.6210313,,,,,,,,,,,,,,,,, -113800,0.56442434,2.6261592,,,,,,,,,,,,,,,,, -113900,0.60066026,2.6607873,,,,,,,,,,,,,,,,, -114000,0.57390153,2.590979,,,,,,,,,,,,,,,,, -114100,0.5686833,2.649226,,,,,,,,,,,,,,,,, -114200,0.5980059,2.590106,,,,,,,,,,,,,,,,, -114300,0.59525913,2.576217,,,,,,,,,,,,,,,,, -114400,0.5995312,2.6279862,,,,,,,,,,,,,,,,, -114500,0.57264507,2.6484249,,,,,,,,,,,,,,,,, -114523,,,0.698196291923523,1.495474338531494,35.77462367787801,0.6925270557403564,1.5115963220596311,30.57723185326211,3000.0,0.708628237247467,1.4227757453918457,30.869338740033143,3003.0,40361.25352883339,65511.83819484711,40361.25352883339,25145.35341453552,1.647085428237915,0.0 -114600,0.61898637,2.693518,,,,,,,,,,,,,,,,, -114700,0.61995775,2.6425776,,,,,,,,,,,,,,,,, -114800,0.6034339,2.6202383,,,,,,,,,,,,,,,,, -114900,0.5892857,2.712823,,,,,,,,,,,,,,,,, -115000,0.63656044,2.636129,,,,,,,,,,,,,,,,, -115100,0.60499793,2.6238766,,,,,,,,,,,,,,,,, -115200,0.6608823,2.6790462,,,,,,,,,,,,,,,,, -115300,0.6037489,2.6334167,,,,,,,,,,,,,,,,, -115400,0.6034464,2.661429,,,,,,,,,,,,,,,,, -115500,0.6005269,2.6234992,,,,,,,,,,,,,,,,, -115600,0.6199247,2.6324687,,,,,,,,,,,,,,,,, -115700,0.5976554,2.5817895,,,,,,,,,,,,,,,,, -115800,0.62047344,2.667185,,,,,,,,,,,,,,,,, -115900,0.5915803,2.6290104,,,,,,,,,,,,,,,,, -116000,0.59842634,2.6311436,,,,,,,,,,,,,,,,, -116100,0.6212242,2.6377645,,,,,,,,,,,,,,,,, -116200,0.6218866,2.6921744,,,,,,,,,,,,,,,,, -116300,0.61308664,2.547949,,,,,,,,,,,,,,,,, -116400,0.6147698,2.6324756,,,,,,,,,,,,,,,,, -116500,0.61272573,2.6015956,,,,,,,,,,,,,,,,, -116600,0.6103781,2.580407,,,,,,,,,,,,,,,,, -116700,0.6183231,2.6572664,,,,,,,,,,,,,,,,, -116800,0.6140527,2.6625404,,,,,,,,,,,,,,,,, -116900,0.6320236,2.6586044,,,,,,,,,,,,,,,,, -116908,,,0.6974208354949951,1.5098053216934204,35.78149866685925,0.693990170955658,1.5112836360931396,30.96133738228591,3000.0,0.7094300389289856,1.418934345245361,30.701844657815705,3003.0,41201.23071146011,66833.81958413124,41201.23071146011,25627.239145994183,1.689962387084961,0.0 -117000,0.6045906,2.5715775,,,,,,,,,,,,,,,,, -117100,0.6154137,2.6387315,,,,,,,,,,,,,,,,, -117200,0.6355865,2.6356535,,,,,,,,,,,,,,,,, -117300,0.65500027,2.6939633,,,,,,,,,,,,,,,,, -117400,0.649651,2.643894,,,,,,,,,,,,,,,,, -117500,0.6307248,2.6131155,,,,,,,,,,,,,,,,, -117600,0.62787604,2.6166081,,,,,,,,,,,,,,,,, -117700,0.6076589,2.605536,,,,,,,,,,,,,,,,, -117800,0.6275453,2.5998342,,,,,,,,,,,,,,,,, -117900,0.63582844,2.5935466,,,,,,,,,,,,,,,,, -118000,0.6071549,2.5349412,,,,,,,,,,,,,,,,, -118100,0.6301588,2.5634258,,,,,,,,,,,,,,,,, -118200,0.6426219,2.6483254,,,,,,,,,,,,,,,,, -118300,0.63769144,2.6254032,,,,,,,,,,,,,,,,, -118400,0.6344725,2.6065502,,,,,,,,,,,,,,,,, -118500,0.6553537,2.7104414,,,,,,,,,,,,,,,,, -118600,0.65929025,2.756737,,,,,,,,,,,,,,,,, -118700,0.6475625,2.6752086,,,,,,,,,,,,,,,,, -118800,0.64357096,2.551854,,,,,,,,,,,,,,,,, -118900,0.6393743,2.6363633,,,,,,,,,,,,,,,,, -119000,0.65941584,2.6035576,,,,,,,,,,,,,,,,, -119100,0.6300196,2.5675383,,,,,,,,,,,,,,,,, -119200,0.6414104,2.642427,,,,,,,,,,,,,,,,, -119295,,,0.710716187953949,1.4390244483947754,36.5941331650276,0.6939281225204468,1.5062028169631958,30.915858804496704,3000.0,0.7100808024406433,1.416991114616394,30.658603829809408,3003.0,42041.36020541191,68168.45203638077,42041.36020541191,26121.627092838287,1.7309155464172363,0.0 -119300,0.64910585,2.601091,,,,,,,,,,,,,,,,, -119400,0.67501587,2.6287136,,,,,,,,,,,,,,,,, -119500,0.6363393,2.607064,,,,,,,,,,,,,,,,, -119600,0.63244385,2.6133244,,,,,,,,,,,,,,,,, -119700,0.65152705,2.546908,,,,,,,,,,,,,,,,, -119800,0.6318665,2.6228516,,,,,,,,,,,,,,,,, -119900,0.66871667,2.6149695,,,,,,,,,,,,,,,,, -120000,0.6644925,2.6060495,,,,,,,,,,,,,,,,, -120100,0.65557927,2.5862837,,,,,,,,,,,,,,,,, -120200,0.64129174,2.6204154,,,,,,,,,,,,,,,,, -120300,0.63619673,2.6203098,,,,,,,,,,,,,,,,, -120400,0.66448426,2.6475492,,,,,,,,,,,,,,,,, -120500,0.6484217,2.5452302,,,,,,,,,,,,,,,,, -120600,0.64496017,2.6391308,,,,,,,,,,,,,,,,, -120700,0.68501633,2.6062057,,,,,,,,,,,,,,,,, -120800,0.6429657,2.6382813,,,,,,,,,,,,,,,,, -120900,0.64953315,2.5654514,,,,,,,,,,,,,,,,, -121000,0.65515447,2.5749457,,,,,,,,,,,,,,,,, -121100,0.6452827,2.6352506,,,,,,,,,,,,,,,,, -121200,0.64232403,2.5606568,,,,,,,,,,,,,,,,, -121300,0.65630955,2.550269,,,,,,,,,,,,,,,,, -121400,0.66790426,2.5947573,,,,,,,,,,,,,,,,, -121500,0.6326739,2.6921654,,,,,,,,,,,,,,,,, -121600,0.71129143,2.564948,,,,,,,,,,,,,,,,, -121681,,,0.7038121819496155,1.4720726013183594,36.19937840630523,0.6947588920593262,1.5022433996200562,30.82173925718412,3000.0,0.7108593583106995,1.4113267660140991,30.843223404698183,3003.0,42881.50228333473,69504.73358535767,42881.50228333473,26617.650767326355,1.773043870925903,0.0 -121700,0.64106655,2.5771453,,,,,,,,,,,,,,,,, -121800,0.65784234,2.6425643,,,,,,,,,,,,,,,,, -121900,0.64971775,2.5944476,,,,,,,,,,,,,,,,, -122000,0.6374318,2.5655034,,,,,,,,,,,,,,,,, -122100,0.6752068,2.634853,,,,,,,,,,,,,,,,, -122200,0.66397303,2.6000304,,,,,,,,,,,,,,,,, -122300,0.6512524,2.5769932,,,,,,,,,,,,,,,,, -122400,0.68286484,2.686718,,,,,,,,,,,,,,,,, -122500,0.7217521,2.5789776,,,,,,,,,,,,,,,,, -122600,0.6381603,2.5793521,,,,,,,,,,,,,,,,, -122700,0.68464524,2.6626685,,,,,,,,,,,,,,,,, -122800,0.67255986,2.5451887,,,,,,,,,,,,,,,,, -122900,0.6783908,2.603724,,,,,,,,,,,,,,,,, -123000,0.658355,2.602995,,,,,,,,,,,,,,,,, -123100,0.6606525,2.583669,,,,,,,,,,,,,,,,, -123200,0.6512394,2.6220117,,,,,,,,,,,,,,,,, -123300,0.66137296,2.5805585,,,,,,,,,,,,,,,,, -123400,0.6707082,2.5489638,,,,,,,,,,,,,,,,, -123500,0.6655172,2.6220245,,,,,,,,,,,,,,,,, -123600,0.69550395,2.5863717,,,,,,,,,,,,,,,,, -123700,0.6887687,2.6109548,,,,,,,,,,,,,,,,, -123800,0.6576668,2.5995255,,,,,,,,,,,,,,,,, -123900,0.6524142,2.629098,,,,,,,,,,,,,,,,, -124000,0.6628754,2.5774128,,,,,,,,,,,,,,,,, -124066,,,0.7046647667884827,1.4676135778427124,36.53354128224039,0.6953912377357483,1.503679275512695,30.746032259348308,3000.0,0.7117890119552612,1.4099397659301758,30.892537332352088,3003.0,43721.44666552544,70838.34599137306,43721.44666552544,27111.19986152649,1.8156933784484863,0.0 -124100,0.67109454,2.5821629,,,,,,,,,,,,,,,,, -124200,0.68210137,2.6179056,,,,,,,,,,,,,,,,, -124300,0.699173,2.6205816,,,,,,,,,,,,,,,,, -124400,0.69184625,2.6284335,,,,,,,,,,,,,,,,, -124500,0.67379117,2.649669,,,,,,,,,,,,,,,,, -124600,0.6703661,2.617268,,,,,,,,,,,,,,,,, -124700,0.6692795,2.6284137,,,,,,,,,,,,,,,,, -124800,0.6710323,2.5308788,,,,,,,,,,,,,,,,, -124900,0.6912603,2.5857308,,,,,,,,,,,,,,,,, -125000,0.6704618,2.5107129,,,,,,,,,,,,,,,,, -125100,0.6798971,2.6422868,,,,,,,,,,,,,,,,, -125200,0.66666806,2.608575,,,,,,,,,,,,,,,,, -125300,0.68405575,2.6027625,,,,,,,,,,,,,,,,, -125400,0.6669812,2.6196196,,,,,,,,,,,,,,,,, -125500,0.65303314,2.6031716,,,,,,,,,,,,,,,,, -125600,0.6727695,2.5949414,,,,,,,,,,,,,,,,, -125700,0.6609603,2.5775044,,,,,,,,,,,,,,,,, -125800,0.7109084,2.6257405,,,,,,,,,,,,,,,,, -125900,0.63792306,2.538986,,,,,,,,,,,,,,,,, -126000,0.7058674,2.6097438,,,,,,,,,,,,,,,,, -126100,0.68929636,2.5991993,,,,,,,,,,,,,,,,, -126200,0.68796927,2.6258178,,,,,,,,,,,,,,,,, -126300,0.6879781,2.6125824,,,,,,,,,,,,,,,,, -126400,0.67061615,2.639556,,,,,,,,,,,,,,,,, -126451,,,0.7076228857040405,1.4430257081985474,36.56570523268232,0.6953416466712952,1.5003718137741089,30.9751714367643,3000.0,0.7123816609382629,1.4068301916122437,31.215212632401546,3003.0,44561.50462055206,72149.5507850647,44561.50462055206,27582.21673822403,1.8660566806793213,0.0 -126500,0.6996571,2.5470748,,,,,,,,,,,,,,,,, -126600,0.6831348,2.6295812,,,,,,,,,,,,,,,,, -126700,0.68483716,2.5945003,,,,,,,,,,,,,,,,, -126800,0.66576046,2.570959,,,,,,,,,,,,,,,,, -126900,0.705004,2.652044,,,,,,,,,,,,,,,,, -127000,0.6635172,2.5404925,,,,,,,,,,,,,,,,, -127100,0.68676263,2.5691786,,,,,,,,,,,,,,,,, -127200,0.6940981,2.6112204,,,,,,,,,,,,,,,,, -127300,0.667187,2.6151748,,,,,,,,,,,,,,,,, -127400,0.7084215,2.7298365,,,,,,,,,,,,,,,,, -127500,0.68582803,2.5656657,,,,,,,,,,,,,,,,, -127600,0.7043501,2.6383202,,,,,,,,,,,,,,,,, -127700,0.67808044,2.600478,,,,,,,,,,,,,,,,, -127800,0.6506777,2.613163,,,,,,,,,,,,,,,,, -127900,0.6879685,2.557089,,,,,,,,,,,,,,,,, -128000,0.68091,2.609022,,,,,,,,,,,,,,,,, -128100,0.6876742,2.5601754,,,,,,,,,,,,,,,,, -128200,0.678126,2.5331495,,,,,,,,,,,,,,,,, -128300,0.6881393,2.5385418,,,,,,,,,,,,,,,,, -128400,0.6662889,2.5782917,,,,,,,,,,,,,,,,, -128500,0.7061918,2.5267346,,,,,,,,,,,,,,,,, -128600,0.6864908,2.5956438,,,,,,,,,,,,,,,,, -128700,0.6821494,2.6137807,,,,,,,,,,,,,,,,, -128800,0.6612832,2.5412724,,,,,,,,,,,,,,,,, -128837,,,0.7105525732040405,1.4332529306411743,36.75074571548151,0.6956764459609985,1.4999107122421265,30.823767338628773,3000.0,0.7122421860694885,1.4065791368484497,30.932261071712,3003.0,45401.51370239258,73480.97584033012,45401.51370239258,28073.51711320877,1.908497333526612,0.0 -128900,0.6867172,2.6147375,,,,,,,,,,,,,,,,, -129000,0.6763449,2.6353438,,,,,,,,,,,,,,,,, -129100,0.66646993,2.572113,,,,,,,,,,,,,,,,, -129200,0.677192,2.591903,,,,,,,,,,,,,,,,, -129300,0.6934004,2.6153138,,,,,,,,,,,,,,,,, -129400,0.718432,2.5884554,,,,,,,,,,,,,,,,, -129500,0.676452,2.5452673,,,,,,,,,,,,,,,,, -129600,0.68060184,2.5705895,,,,,,,,,,,,,,,,, -129700,0.66222876,2.5603285,,,,,,,,,,,,,,,,, -129800,0.67642957,2.5485861,,,,,,,,,,,,,,,,, -129900,0.70816606,2.6185827,,,,,,,,,,,,,,,,, -130000,0.66396487,2.5286455,,,,,,,,,,,,,,,,, -130100,0.6758043,2.5753872,,,,,,,,,,,,,,,,, -130200,0.69960403,2.6111825,,,,,,,,,,,,,,,,, -130300,0.68809384,2.5872018,,,,,,,,,,,,,,,,, -130400,0.6684685,2.615661,,,,,,,,,,,,,,,,, -130500,0.6709823,2.5592923,,,,,,,,,,,,,,,,, -130600,0.68056744,2.5943158,,,,,,,,,,,,,,,,, -130700,0.6787098,2.556052,,,,,,,,,,,,,,,,, -130800,0.6731489,2.6047318,,,,,,,,,,,,,,,,, -130900,0.6936778,2.615951,,,,,,,,,,,,,,,,, -131000,0.6703476,2.598213,,,,,,,,,,,,,,,,, -131100,0.64358974,2.6222024,,,,,,,,,,,,,,,,, -131200,0.68662184,2.6045551,,,,,,,,,,,,,,,,, -131223,,,0.7070929408073425,1.4555907249450684,36.070049531934,0.695874810218811,1.499701976776123,30.95694516258592,3000.0,0.7126489281654358,1.4054698944091797,31.09791754474128,3003.0,46241.69057679176,74813.12171435356,46241.69057679176,28565.360773801804,1.958054780960083,0.0 -131300,0.67320514,2.5941632,,,,,,,,,,,,,,,,, -131400,0.6644044,2.5777712,,,,,,,,,,,,,,,,, -131500,0.67940724,2.534072,,,,,,,,,,,,,,,,, -131600,0.6843348,2.6405294,,,,,,,,,,,,,,,,, -131700,0.6828975,2.6339617,,,,,,,,,,,,,,,,, -131800,0.6653447,2.5911322,,,,,,,,,,,,,,,,, -131900,0.6538019,2.5077138,,,,,,,,,,,,,,,,, -132000,0.6853585,2.6527586,,,,,,,,,,,,,,,,, -132100,0.68566805,2.5663548,,,,,,,,,,,,,,,,, -132200,0.67157245,2.5866554,,,,,,,,,,,,,,,,, -132300,0.68685776,2.6220627,,,,,,,,,,,,,,,,, -132400,0.6573507,2.5458617,,,,,,,,,,,,,,,,, -132500,0.6823226,2.5873015,,,,,,,,,,,,,,,,, -132600,0.68399894,2.6152387,,,,,,,,,,,,,,,,, -132700,0.67692906,2.5784602,,,,,,,,,,,,,,,,, -132800,0.6882717,2.6270368,,,,,,,,,,,,,,,,, -132900,0.67881775,2.5906222,,,,,,,,,,,,,,,,, -133000,0.6667266,2.536411,,,,,,,,,,,,,,,,, -133100,0.6934144,2.6330624,,,,,,,,,,,,,,,,, -133200,0.66183656,2.5986702,,,,,,,,,,,,,,,,, -133300,0.6428279,2.53317,,,,,,,,,,,,,,,,, -133333,,,0.7128967642784119,1.4201592206954956,36.50041035697179,0.6956764459609985,1.4999949932098389,30.901218707876914,3000.0,0.7125210762023926,1.4059910774230957,31.041513578042657,3003.0,46984.7851524353,76047.3649597168,46984.7851524353,29056.399307250977,2.001035213470459,0.0 -133333,,,,,,,,,,,,,,46984.7851524353,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/eval_measurements.csv deleted file mode 100644 index d137f2ada..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -875.0153396129608,0.0,26.80888557434082,1,0,26.80888557434082,0.0007088489946909,0.0,11.19086742401123,3003,901.8242671489716,0.000589943723753,0.0,11.17578411102295,0.0004835649742744,0.0,11.208685874938965,3000 -1597.1277873516085,0.0198423862457275,866.9106457233429,2380,0,866.9106457233429,0.4006739854812622,9.119701587852656,4.220223426818848,3003,2464.131246328354,0.4236276149749756,15.317106956985512,3.945071458816528,0.4104350805282593,10.871939591746957,4.064540863037109,3000 -2090.888371706009,0.0458595752716064,1707.1207020282743,4760,0,1707.1207020282743,0.5474405884742737,19.069473249226416,2.8263823986053467,3003,3798.200165033341,0.5412089824676514,24.40067253420964,2.8526296615600586,0.5490322709083557,20.54890937518002,2.781830072402954,3000 -2580.725028514862,0.0715317726135253,2547.193795681,7142,0,2547.193795681,0.5926558971405029,22.39646133848995,2.3840017318725586,3003,5128.208259344101,0.5862109065055847,26.95864129778303,2.43117094039917,0.5920323133468628,23.53881679992252,2.384349584579468,3000 -3060.3883907794952,0.0963890552520752,3387.143469810486,9524,0,3387.143469810486,0.6223113536834717,24.093482603258604,2.157191276550293,3003,6447.921786308289,0.5987629294395447,28.450225323180657,2.312525987625122,0.6168057322502136,25.08437942199972,2.1870875358581543,3000 -3535.278825521469,0.1225256919860839,4227.378788471222,11908,0,4227.378788471222,0.6385219097137451,25.16989331342719,2.006153345108032,3003,7763.146024465561,0.6138762831687927,29.359989662943025,2.191060304641724,0.632627010345459,26.211492834655424,2.045173168182373,3000 -4005.4279165267935,0.1491467952728271,5067.369838953018,14290,0,5067.369838953018,0.6499332189559937,26.31243160118559,1.9334949254989624,3003,9073.389715909958,0.6261821985244751,30.646065698482577,2.09044885635376,0.6428066492080688,27.19086719876297,1.982466459274292,3000 -4500.4322681427,0.177018404006958,5907.608487606049,16673,0,5907.608487606049,0.6567311882972717,26.530187275714983,1.875162959098816,3003,10408.735023498535,0.6261575818061829,30.48252226177977,2.085031509399414,0.648510217666626,27.83188217261842,1.9287440776824951,3000 -4989.699973106384,0.2039761543273925,6747.7297422885895,19056,0,6747.7297422885895,0.6621579527854919,27.09557119536779,1.8095208406448364,3003,11738.225252628326,0.6523534059524536,32.224112132764745,1.89583683013916,0.6549577713012695,28.0233294530836,1.867025375366211,3000 -5463.887475013733,0.2319626808166504,7587.909989356995,21439,0,7587.909989356995,0.667910099029541,27.43810439700653,1.771620273590088,3003,13052.69588947296,0.6392760276794434,31.3656819463444,1.9621353149414065,0.6572020053863525,28.291012927380795,1.832504153251648,3000 -5934.689563751221,0.2605419158935547,8427.87146282196,23822,0,8427.87146282196,0.6710127592086792,27.898585149515966,1.7506834268569946,3003,14363.563488960266,0.643540620803833,31.37433547834663,1.939845085144043,0.6603513956069946,28.12367394678881,1.811151027679444,3000 -6465.22674870491,0.2896280288696289,9268.00053191185,26206,0,9268.00053191185,0.6742897033691406,28.15439954133418,1.722639799118042,3003,15734.331931114197,0.653814435005188,32.41593082315556,1.8643923997879028,0.6627939939498901,28.8887559201092,1.7915291786193848,3000 -6962.611318826675,0.3182723522186279,10108.170245409012,28591,0,10108.170245409012,0.6772529482841492,28.447757417888468,1.7303847074508667,3003,17071.990015506744,0.6467625498771667,31.52587145991113,1.921960711479187,0.6663525700569153,28.85844748497176,1.796502947807312,3000 -7471.288013458252,0.3465893268585205,10948.20868062973,30974,0,10948.20868062973,0.6762768030166626,27.87737284091013,1.725675344467163,3003,18420.80942606926,0.6443696022033691,31.986524329560304,1.931141018867493,0.6660549640655518,28.60632817214931,1.7905811071395874,3000 -8128.417171955109,0.3758647441864013,11788.108996391296,33357,0,11788.108996391296,0.680727481842041,28.342088839250422,1.6848061084747314,3003,19917.94206786156,0.6579074263572693,32.069965372459464,1.84610378742218,0.6691919565200806,29.00561299490477,1.755408525466919,3000 -8605.174662351608,0.4066076278686523,12628.127183675766,35740,0,12628.127183675766,0.6793214082717896,28.543354366147703,1.6866551637649536,3003,21234.825357437134,0.6511590480804443,32.00241740281816,1.890103459358216,0.6663649678230286,29.13012192945443,1.756929874420166,3000 -9172.487488031387,0.4418942928314209,13468.265769004822,38123,0,13468.265769004822,0.6831677556037903,28.57284390796016,1.668337106704712,3003,22642.39179039001,0.6650872230529785,33.173089200434895,1.7879432439804075,0.6712377667427063,29.3184200998274,1.7390223741531372,3000 -9689.880198001862,0.4748435020446777,14308.177471637726,40506,0,14308.177471637726,0.6837139129638672,28.80755072318091,1.6359859704971311,3003,23999.80347084999,0.6558361649513245,32.338058333076155,1.8276304006576536,0.6708658337593079,29.019836053244,1.7183750867843628,3000 -10389.728868246078,0.51078200340271,15148.184902191162,42889,0,15148.184902191162,0.3018534779548645,0.033897742842188,4.571069240570068,3003,25539.77074050904,0.3256044685840606,0.1383137865844718,4.233160018920898,0.3041127920150757,0.0647455762933908,4.484804153442383,3000 -10857.352401733398,0.5476312637329102,15988.277658700945,45272,0,15988.277658700945,0.6868398189544678,28.890927730804112,1.6206845045089722,3003,26847.600742578503,0.6611015796661377,32.56452484565516,1.7925008535385132,0.6752551198005676,29.62041815683345,1.697988986968994,3000 -11367.964815616608,0.5800197124481201,16828.508974790573,47656,0,16828.508974790573,0.6875022053718567,29.12670142648821,1.6302539110183716,3003,28198.55082011223,0.6571767330169678,32.09281506764354,1.834312081336975,0.6745359301567078,29.54546456908649,1.7079156637191772,3000 -11943.41303062439,0.6114578247070312,17668.53111076355,50039,0,17668.53111076355,0.6908721327781677,29.203659503289817,1.6230562925338743,3003,29614.12636780739,0.70138019323349,35.59088268760344,1.5984097719192505,0.6770529747009277,29.58461302392992,1.7046748399734497,3000 -12451.49473619461,0.6497724056243896,18508.57137870789,52421,0,18508.57137870789,0.6908256411552429,29.400656364664417,1.6125904321670532,3003,30962.36435317993,0.6642529368400574,32.94824790704417,1.7819786071777344,0.6771645545959473,29.778122229123387,1.690804123878479,3000 -12936.39828658104,0.6831979751586914,19348.67028999329,54804,0,19348.67028999329,0.6921155452728271,29.701937568525977,1.6047054529190063,3003,32287.47656297684,0.6630371809005737,32.97631954929356,1.7826213836669922,0.6774993538856506,29.82505230190574,1.6878796815872192,3000 -13417.802710533142,0.7159018516540527,20188.675621509552,57187,0,20188.675621509552,0.6951368451118469,29.96182087882542,1.5948320627212524,3003,33608.99349451065,0.677360475063324,33.62485481934153,1.7102564573287964,0.6791856288909912,29.688366128750683,1.6819928884506226,3000 -13998.15175628662,0.7499048709869385,21028.83904337883,59571,0,21028.83904337883,0.6934635043144226,29.45679735869189,1.5827269554138184,3003,35029.61338472366,0.665337324142456,32.771778877290295,1.7709975242614746,0.6795576214790344,29.518673328869195,1.667377233505249,3000 -14507.747369527817,0.7831311225891113,21869.08454298973,61955,0,21869.08454298973,0.6972982287406921,30.065476502372743,1.5744260549545288,3003,36379.56156635285,0.6696307063102722,33.5078406006622,1.765308141708374,0.6800783276557922,29.881170707042948,1.670266032218933,3000 -14981.440528154371,0.8170902729034424,22709.00170826912,64338,0,22709.00170826912,0.6957643628120422,29.864767151801697,1.577563762664795,3003,37693.27958345413,0.6764718890190125,33.57440543575773,1.7071046829223633,0.680251955986023,29.700324861556748,1.6660431623458862,3000 -15489.14245057106,0.854212760925293,23549.12446165085,66721,0,23549.12446165085,0.6970425844192505,30.122066124861863,1.569135665893555,3003,39041.21701860428,0.6704260110855103,33.43165042153991,1.7423919439315796,0.6816902160644531,29.84663557839197,1.6634081602096558,3000 -15992.877562999724,0.8907725811004639,24389.31098389625,69104,0,24389.31098389625,0.6975655555725098,29.880227745785337,1.562856674194336,3003,40385.25058054924,0.6957202553749084,35.166386173569414,1.5890588760375977,0.682223379611969,29.95974849835134,1.655207276344299,3000 -16507.55722308159,0.925260066986084,25229.35171794892,71487,0,25229.35171794892,0.6996223330497742,30.269282466335337,1.5564168691635132,3003,41740.079825639725,0.676636815071106,34.26073413751061,1.7049415111541748,0.6833145022392273,30.023471272179897,1.646494746208191,3000 -17084.658698558807,0.9607217311859132,26069.48642897606,73871,0,26069.48642897606,0.6999593377113342,30.42530986159958,1.545217990875244,3003,43157.42514300346,0.6750386357307434,34.45133599614152,1.720869064331055,0.6839592456817627,30.166010042652832,1.6430405378341677,3000 -17658.08572268486,1.0019049644470217,26909.4674077034,76254,0,26909.4674077034,0.701655924320221,30.24334076025487,1.541452407836914,3003,44570.950585365295,0.6888977885246277,35.07580034747715,1.6275510787963867,0.6844180226325989,29.92807660712913,1.6386370658874512,3000 -18163.086530447006,1.0375986099243164,27749.411451101303,78637,0,27749.411451101303,0.7021439671516418,30.256775334023622,1.543318271636963,3003,45916.00671863556,0.6790190935134888,34.606128151781625,1.6930485963821411,0.6853355765342712,30.084906994619733,1.6412664651870728,3000 -18639.112541913983,1.0760555267333984,28589.64216661453,81021,0,28589.64216661453,0.7027831077575684,30.454600655162768,1.532280445098877,3003,47232.37526059151,0.68350750207901,33.94700163998571,1.6653733253479004,0.6868730783462524,30.37649273654434,1.628083348274231,3000 -19135.941576480865,1.1130588054656982,29429.75515246392,83404,0,29429.75515246392,0.7037476301193237,30.497443854719528,1.527652382850647,3003,48569.43083524704,0.689186692237854,34.95898421051862,1.628559947013855,0.6865258812904358,29.923109223085948,1.6255191564559937,3000 -19781.944012403488,1.1491875648498535,30269.69832634926,85787,0,30269.69832634926,0.7035732865333557,30.16393657409564,1.526257038116455,3003,50055.488450050354,0.6846340894699097,34.29854509964452,1.6518151760101318,0.6845916509628296,29.913907501321336,1.627881407737732,3000 -20433.31798386573,1.1875977516174316,31109.80075287819,88170,0,31109.80075287819,0.7039335370063782,30.32476310275814,1.526638388633728,3003,51547.07810497284,0.7073625326156616,36.36983182542021,1.5367120504379272,0.6875426173210144,30.037688410889466,1.6262383460998535,3000 -21065.105221271515,1.225377082824707,31949.7300992012,90553,0,31949.7300992012,0.7055139541625977,30.369156672036247,1.5144983530044556,3003,53018.90501999855,0.6886942386627197,34.94074502185821,1.6172972917556765,0.6868730783462524,29.63990232421176,1.6169028282165527,3000 -21553.11935710907,1.263375759124756,32789.64198088646,92936,0,32789.64198088646,0.7075939774513245,30.907346055474644,1.5099539756774902,3003,54346.94282460213,0.6906367540359497,35.15019654432984,1.6220505237579346,0.6880013942718506,30.125960524935937,1.6124529838562012,3000 -22071.868659973145,1.3095154762268066,33629.78120470047,95319,0,33629.78120470047,0.7068386673927307,30.55684966808048,1.5075197219848633,3003,55705.952450037,0.7033336758613586,35.7991515308961,1.5496033430099487,0.6894024610519409,30.590930186553223,1.6060791015625,3000 -22554.345523118973,1.3491921424865725,34469.73361515999,97702,0,34469.73361515999,0.7058508992195129,30.394577530924465,1.5085893869400024,3003,57028.49427843094,0.6969262957572937,35.152333339015506,1.5804659128189087,0.6876417994499207,30.332051794275426,1.6106321811676023,3000 -23047.190549373627,1.3953723907470703,35309.84344100952,100084,0,35309.84344100952,0.7073267102241516,30.53523360311065,1.5063436031341553,3003,58361.5737016201,0.7314824461936951,38.33886509323845,1.4231904745101929,0.6890305280685425,30.71176452861133,1.6068017482757568,3000 -23557.168686389923,1.442002773284912,36149.92675304413,102466,0,36149.92675304413,0.7066527605056763,30.829963453250123,1.5068280696868896,3003,59711.75960898399,0.7023757696151733,36.134819561514,1.553908348083496,0.6896008849143982,30.46585574558103,1.6050572395324707,3000 -24042.96623015404,1.4812438488006592,36989.896939754486,104849,0,36989.896939754486,0.7081517577171326,30.92962167513607,1.4993467330932615,3003,61037.63996696472,0.6993989944458008,35.754191956860694,1.559681415557861,0.6894644498825073,30.594905174890503,1.6042819023132324,3000 -24527.8506731987,1.5237393379211426,37830.11774921417,107233,0,37830.11774921417,0.7072221636772156,30.69685450178556,1.5013386011123655,3003,62362.86051940918,0.7179781794548035,37.12911174176097,1.4720852375030518,0.6896132826805115,30.57522490142288,1.6041953563690186,3000 -25037.91539287567,1.563863754272461,38670.0972571373,109615,0,38670.0972571373,0.7097089290618896,30.87070303183882,1.493978500366211,3003,63713.02125096321,0.710374653339386,36.08658871891254,1.5148351192474363,0.690220832824707,30.252796426395733,1.601181983947754,3000 -25521.773129224777,1.6069166660308838,39510.23357391357,111997,0,39510.23357391357,0.7084655165672302,30.64387931009616,1.4979416131973269,3003,65037.137236356735,0.7104139924049377,36.71417614009322,1.5093110799789429,0.6904935836791992,30.74636830222096,1.6021604537963867,3000 -26007.337995052338,1.6503000259399414,40350.37732386589,114380,0,40350.37732386589,0.7098019123077393,30.778936015094025,1.4975818395614624,3003,66362.9657073021,0.7173652052879333,37.131145115956905,1.4794282913208008,0.6896628737449646,30.39730108265441,1.6034998893737793,3000 -26494.117411851883,1.6911771297454834,41190.49334859848,116764,0,41190.49334859848,0.7097205519676208,30.84845509744524,1.494398832321167,3003,67689.97535419464,0.7147499918937683,36.79285558146597,1.4900606870651243,0.6898736357688904,30.654046685474967,1.601643681526184,3000 -26965.506422758102,1.7330989837646484,42030.68079471588,119147,0,42030.68079471588,0.709674060344696,30.740337151689708,1.493552803993225,3003,69001.66783833504,0.7253099083900452,38.21547448718535,1.4400979280471802,0.6901092529296875,30.736349330402373,1.599493384361267,3000 -27487.700913906097,1.7821390628814695,42870.59420728684,121529,0,42870.59420728684,0.7105339765548706,30.733517509111195,1.4927688837051392,3003,70363.89932894707,0.7211779356002808,37.837886173614265,1.4619131088256836,0.6910391449928284,30.736317737646832,1.5997225046157837,3000 -27984.25898528099,1.8253540992736816,43710.78532648087,123912,0,43710.78532648087,0.7107663750648499,30.85889317102481,1.4905011653900146,3003,71700.7684583664,0.7188615202903748,37.37068803202791,1.4707727432250977,0.6907415986061096,30.730796295851167,1.5994369983673096,3000 -28494.218544483185,1.867745876312256,44550.9940226078,126295,0,44550.9940226078,0.7114520072937012,30.760997036011428,1.4903055429458618,3003,73051.0552790165,0.7226437330245972,38.29304533888482,1.4515371322631836,0.6909151673316956,30.73582258184852,1.599329710006714,3000 -29010.47008228302,1.919017314910889,45391.09645104408,128677,0,45391.09645104408,0.7110103964805603,30.711279312995767,1.489971160888672,3003,74407.53640341759,0.7253398895263672,38.15674927738666,1.437965750694275,0.6909027695655823,30.539796950932832,1.5986698865890503,3000 -29529.583674669266,1.9653499126434328,46231.21893525124,131059,0,46231.21893525124,0.7107431292533875,30.77321518935283,1.490159273147583,3003,75766.89620161057,0.724202036857605,38.04270301127323,1.443503499031067,0.6908655762672424,30.667733220414544,1.599054217338562,3000 -30052.825205802917,2.0094943046569824,47033.39378809929,133333,0,47033.39378809929,0.7107315063476562,30.821354274030078,1.4902734756469727,3003,77092.4309284687,0.7267444133758545,37.58713677051531,1.427182912826538,0.6907787919044495,30.66343634971797,1.5991333723068237,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/measurements.csv deleted file mode 100644 index 018cb4569..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.094406,11.156319,,,,,,,,,,,,,,,,, -1,,,0.000589943723753,11.17578411102295,0.0,0.0004835649742744,11.208685874938965,0.0,3000.0,0.0007088489946909,11.19086742401123,0.0,3003.0,26.80888557434082,901.8242671489716,26.80888557434082,875.0153396129608,0.0,0.0 -100,0.2789818,9.017484,,,,,,,,,,,,,,,,, -200,0.22253759,8.678911,,,,,,,,,,,,,,,,, -300,0.8212967,8.304846,,,,,,,,,,,,,,,,, -400,1.0460213,8.050576,,,,,,,,,,,,,,,,, -500,0.4112684,7.799064,,,,,,,,,,,,,,,,, -600,0.81785643,7.6367874,,,,,,,,,,,,,,,,, -700,0.84371156,7.4954195,,,,,,,,,,,,,,,,, -800,0.5832978,7.2425513,,,,,,,,,,,,,,,,, -900,0.6592276,7.1811166,,,,,,,,,,,,,,,,, -1000,0.881578,6.993343,,,,,,,,,,,,,,,,, -1100,0.61998314,6.80459,,,,,,,,,,,,,,,,, -1200,0.68671936,6.73033,,,,,,,,,,,,,,,,, -1300,0.7108012,6.6308146,,,,,,,,,,,,,,,,, -1400,0.46918193,6.4935446,,,,,,,,,,,,,,,,, -1500,0.49586278,6.383498,,,,,,,,,,,,,,,,, -1600,0.6141541,6.289968,,,,,,,,,,,,,,,,, -1700,0.7745579,6.1829786,,,,,,,,,,,,,,,,, -1800,0.6876537,6.12112,,,,,,,,,,,,,,,,, -1900,0.7547601,6.0416555,,,,,,,,,,,,,,,,, -2000,0.7007388,6.011113,,,,,,,,,,,,,,,,, -2100,0.78674036,5.8610315,,,,,,,,,,,,,,,,, -2200,0.71506965,5.7423778,,,,,,,,,,,,,,,,, -2300,0.5899731,5.6556196,,,,,,,,,,,,,,,,, -2380,,,0.4236276149749756,3.945071458816528,15.317106956985512,0.4104350805282593,4.064540863037109,10.871939591746957,3000.0,0.4006739854812622,4.220223426818848,9.119701587852656,3003.0,866.9106457233429,2464.131246328354,866.9106457233429,1597.1277873516085,0.0198423862457275,0.0 -2400,0.7224027,5.615198,,,,,,,,,,,,,,,,, -2500,0.7253809,5.550821,,,,,,,,,,,,,,,,, -2600,0.51509917,5.381719,,,,,,,,,,,,,,,,, -2700,0.5492712,5.3829627,,,,,,,,,,,,,,,,, -2800,0.52376425,5.3086543,,,,,,,,,,,,,,,,, -2900,0.5224447,5.2594876,,,,,,,,,,,,,,,,, -3000,0.7560966,5.3161163,,,,,,,,,,,,,,,,, -3100,0.4734177,5.1533465,,,,,,,,,,,,,,,,, -3200,0.49483413,5.1124434,,,,,,,,,,,,,,,,, -3300,0.54435045,5.111471,,,,,,,,,,,,,,,,, -3400,0.49037665,5.0845895,,,,,,,,,,,,,,,,, -3500,0.5616364,5.0355096,,,,,,,,,,,,,,,,, -3600,0.57816046,5.029184,,,,,,,,,,,,,,,,, -3700,0.44219783,4.917858,,,,,,,,,,,,,,,,, -3800,0.4730027,4.9765944,,,,,,,,,,,,,,,,, -3900,0.4372673,4.939034,,,,,,,,,,,,,,,,, -4000,0.39592385,4.91766,,,,,,,,,,,,,,,,, -4100,0.45288685,4.834805,,,,,,,,,,,,,,,,, -4200,0.4530061,4.84723,,,,,,,,,,,,,,,,, -4300,0.42027006,4.7958827,,,,,,,,,,,,,,,,, -4400,0.44481272,4.8204103,,,,,,,,,,,,,,,,, -4500,0.37232125,4.732013,,,,,,,,,,,,,,,,, -4600,0.43444496,4.787936,,,,,,,,,,,,,,,,, -4700,0.3546022,4.6297474,,,,,,,,,,,,,,,,, -4760,,,0.5412089824676514,2.8526296615600586,24.40067253420964,0.5490322709083557,2.781830072402954,20.54890937518002,3000.0,0.5474405884742737,2.8263823986053467,19.069473249226416,3003.0,1707.1207020282743,3798.200165033341,1707.1207020282743,2090.888371706009,0.0458595752716064,0.0 -4800,0.39607435,4.7160063,,,,,,,,,,,,,,,,, -4900,0.38115346,4.7089963,,,,,,,,,,,,,,,,, -5000,0.33485857,4.690218,,,,,,,,,,,,,,,,, -5100,0.3630901,4.691867,,,,,,,,,,,,,,,,, -5200,0.48848152,4.701407,,,,,,,,,,,,,,,,, -5300,0.31470096,4.6037188,,,,,,,,,,,,,,,,, -5400,0.31621677,4.6118326,,,,,,,,,,,,,,,,, -5500,0.34162876,4.55241,,,,,,,,,,,,,,,,, -5600,0.3212125,4.6108775,,,,,,,,,,,,,,,,, -5700,0.30856526,4.57234,,,,,,,,,,,,,,,,, -5800,0.30408135,4.629739,,,,,,,,,,,,,,,,, -5900,0.2998805,4.5947375,,,,,,,,,,,,,,,,, -6000,0.30565158,4.527969,,,,,,,,,,,,,,,,, -6100,0.28224596,4.530093,,,,,,,,,,,,,,,,, -6200,0.27902365,4.5672812,,,,,,,,,,,,,,,,, -6300,0.2835936,4.539339,,,,,,,,,,,,,,,,, -6400,0.30903983,4.553556,,,,,,,,,,,,,,,,, -6500,0.34630534,4.5312996,,,,,,,,,,,,,,,,, -6600,0.29495782,4.447339,,,,,,,,,,,,,,,,, -6700,0.26455808,4.4397473,,,,,,,,,,,,,,,,, -6800,0.27039847,4.512269,,,,,,,,,,,,,,,,, -6900,0.24541844,4.4290457,,,,,,,,,,,,,,,,, -7000,0.25725734,4.495728,,,,,,,,,,,,,,,,, -7100,0.22706848,4.512598,,,,,,,,,,,,,,,,, -7142,,,0.5862109065055847,2.43117094039917,26.95864129778303,0.5920323133468628,2.384349584579468,23.53881679992252,3000.0,0.5926558971405029,2.3840017318725586,22.39646133848995,3003.0,2547.193795681,5128.208259344101,2547.193795681,2580.725028514862,0.0715317726135253,0.0 -7200,0.23086773,4.4409156,,,,,,,,,,,,,,,,, -7300,0.24870493,4.3924227,,,,,,,,,,,,,,,,, -7400,0.27274302,4.433617,,,,,,,,,,,,,,,,, -7500,0.23900798,4.413307,,,,,,,,,,,,,,,,, -7600,0.2801812,4.4757175,,,,,,,,,,,,,,,,, -7700,0.2511085,4.358884,,,,,,,,,,,,,,,,, -7800,0.22099614,4.2976804,,,,,,,,,,,,,,,,, -7900,0.20380348,4.414115,,,,,,,,,,,,,,,,, -8000,0.23798475,4.3645015,,,,,,,,,,,,,,,,, -8100,0.20899445,4.3063836,,,,,,,,,,,,,,,,, -8200,0.23387769,4.2911596,,,,,,,,,,,,,,,,, -8300,0.224473,4.413812,,,,,,,,,,,,,,,,, -8400,0.20932908,4.3530793,,,,,,,,,,,,,,,,, -8500,0.20192622,4.3203077,,,,,,,,,,,,,,,,, -8600,0.21450403,4.360709,,,,,,,,,,,,,,,,, -8700,0.21427718,4.320144,,,,,,,,,,,,,,,,, -8800,0.20958449,4.2978983,,,,,,,,,,,,,,,,, -8900,0.2061272,4.3109493,,,,,,,,,,,,,,,,, -9000,0.22617093,4.3605213,,,,,,,,,,,,,,,,, -9100,0.23448858,4.325312,,,,,,,,,,,,,,,,, -9200,0.20598592,4.3188057,,,,,,,,,,,,,,,,, -9300,0.19901465,4.263724,,,,,,,,,,,,,,,,, -9400,0.1982049,4.2922955,,,,,,,,,,,,,,,,, -9500,0.19187844,4.2591724,,,,,,,,,,,,,,,,, -9524,,,0.5987629294395447,2.312525987625122,28.450225323180657,0.6168057322502136,2.1870875358581543,25.08437942199972,3000.0,0.6223113536834717,2.157191276550293,24.093482603258604,3003.0,3387.143469810486,6447.921786308289,3387.143469810486,3060.3883907794952,0.0963890552520752,0.0 -9600,0.17839167,4.2217913,,,,,,,,,,,,,,,,, -9700,0.19078843,4.204837,,,,,,,,,,,,,,,,, -9800,0.178237,4.243494,,,,,,,,,,,,,,,,, -9900,0.17688218,4.33753,,,,,,,,,,,,,,,,, -10000,0.19539766,4.268058,,,,,,,,,,,,,,,,, -10100,0.17200913,4.2296224,,,,,,,,,,,,,,,,, -10200,0.17575593,4.223508,,,,,,,,,,,,,,,,, -10300,0.17639503,4.235,,,,,,,,,,,,,,,,, -10400,0.18295623,4.2275543,,,,,,,,,,,,,,,,, -10500,0.17506763,4.188692,,,,,,,,,,,,,,,,, -10600,0.17849536,4.266315,,,,,,,,,,,,,,,,, -10700,0.17576809,4.218778,,,,,,,,,,,,,,,,, -10800,0.18099746,4.1893783,,,,,,,,,,,,,,,,, -10900,0.262085,4.30389,,,,,,,,,,,,,,,,, -11000,0.18974529,4.259544,,,,,,,,,,,,,,,,, -11100,0.19824794,4.2244353,,,,,,,,,,,,,,,,, -11200,0.19399123,4.149767,,,,,,,,,,,,,,,,, -11300,0.2123506,4.276064,,,,,,,,,,,,,,,,, -11400,0.17377408,4.1524434,,,,,,,,,,,,,,,,, -11500,0.1690053,4.2846985,,,,,,,,,,,,,,,,, -11600,0.1604487,4.1935806,,,,,,,,,,,,,,,,, -11700,0.17786005,4.158576,,,,,,,,,,,,,,,,, -11800,0.17749794,4.200365,,,,,,,,,,,,,,,,, -11900,0.16521887,4.1924567,,,,,,,,,,,,,,,,, -11908,,,0.6138762831687927,2.191060304641724,29.359989662943025,0.632627010345459,2.045173168182373,26.211492834655424,3000.0,0.6385219097137451,2.006153345108032,25.16989331342719,3003.0,4227.378788471222,7763.146024465561,4227.378788471222,3535.278825521469,0.1225256919860839,0.0 -12000,0.1655349,4.2896223,,,,,,,,,,,,,,,,, -12100,0.21316756,4.1461,,,,,,,,,,,,,,,,, -12200,0.18620515,4.118473,,,,,,,,,,,,,,,,, -12300,0.16241574,4.1902647,,,,,,,,,,,,,,,,, -12400,0.15681325,4.1867948,,,,,,,,,,,,,,,,, -12500,0.17113878,4.1281385,,,,,,,,,,,,,,,,, -12600,0.20011216,4.154877,,,,,,,,,,,,,,,,, -12700,0.16213845,4.197838,,,,,,,,,,,,,,,,, -12800,0.15124407,4.1953835,,,,,,,,,,,,,,,,, -12900,0.153224,4.1645265,,,,,,,,,,,,,,,,, -13000,0.15507148,4.157407,,,,,,,,,,,,,,,,, -13100,0.161122,4.155795,,,,,,,,,,,,,,,,, -13200,0.1766099,4.177253,,,,,,,,,,,,,,,,, -13300,0.15741321,4.174194,,,,,,,,,,,,,,,,, -13400,0.17785503,4.158334,,,,,,,,,,,,,,,,, -13500,0.16939747,4.1543174,,,,,,,,,,,,,,,,, -13600,0.23623617,4.105915,,,,,,,,,,,,,,,,, -13700,0.1909684,4.202484,,,,,,,,,,,,,,,,, -13800,0.15236953,4.187981,,,,,,,,,,,,,,,,, -13900,0.16209249,4.102571,,,,,,,,,,,,,,,,, -14000,0.1793646,4.1571794,,,,,,,,,,,,,,,,, -14100,0.1594822,4.1278806,,,,,,,,,,,,,,,,, -14200,0.15612294,4.097624,,,,,,,,,,,,,,,,, -14290,,,0.6261821985244751,2.09044885635376,30.646065698482577,0.6428066492080688,1.982466459274292,27.19086719876297,3000.0,0.6499332189559937,1.9334949254989624,26.31243160118559,3003.0,5067.369838953018,9073.389715909958,5067.369838953018,4005.4279165267935,0.1491467952728271,0.0 -14300,0.17471117,4.166968,,,,,,,,,,,,,,,,, -14400,0.20423736,4.0945525,,,,,,,,,,,,,,,,, -14500,0.16263023,4.079585,,,,,,,,,,,,,,,,, -14600,0.17513284,4.1404924,,,,,,,,,,,,,,,,, -14700,0.16728042,4.1392903,,,,,,,,,,,,,,,,, -14800,0.17374216,4.1022167,,,,,,,,,,,,,,,,, -14900,0.16185082,4.188585,,,,,,,,,,,,,,,,, -15000,0.16394666,4.1170154,,,,,,,,,,,,,,,,, -15100,0.15483287,4.1304464,,,,,,,,,,,,,,,,, -15200,0.16136031,4.0668774,,,,,,,,,,,,,,,,, -15300,0.15709646,4.1495466,,,,,,,,,,,,,,,,, -15400,0.1566591,4.1570625,,,,,,,,,,,,,,,,, -15500,0.1943031,4.1171827,,,,,,,,,,,,,,,,, -15600,0.15090995,4.094519,,,,,,,,,,,,,,,,, -15700,0.15726803,4.123547,,,,,,,,,,,,,,,,, -15800,0.17503412,4.0682406,,,,,,,,,,,,,,,,, -15900,0.15052605,4.0643687,,,,,,,,,,,,,,,,, -16000,0.1996905,4.103729,,,,,,,,,,,,,,,,, -16100,0.16194795,4.069751,,,,,,,,,,,,,,,,, -16200,0.17590025,4.098689,,,,,,,,,,,,,,,,, -16300,0.1539413,4.0596824,,,,,,,,,,,,,,,,, -16400,0.17847802,4.079201,,,,,,,,,,,,,,,,, -16500,0.16653077,4.1263905,,,,,,,,,,,,,,,,, -16600,0.16578501,4.086057,,,,,,,,,,,,,,,,, -16673,,,0.6261575818061829,2.085031509399414,30.48252226177977,0.648510217666626,1.9287440776824951,27.83188217261842,3000.0,0.6567311882972717,1.875162959098816,26.530187275714983,3003.0,5907.608487606049,10408.735023498535,5907.608487606049,4500.4322681427,0.177018404006958,0.0 -16700,0.16575916,4.042547,,,,,,,,,,,,,,,,, -16800,0.16460589,4.0937004,,,,,,,,,,,,,,,,, -16900,0.15777507,4.0800323,,,,,,,,,,,,,,,,, -17000,0.20644438,4.080104,,,,,,,,,,,,,,,,, -17100,0.16186617,4.078488,,,,,,,,,,,,,,,,, -17200,0.2024731,4.1392736,,,,,,,,,,,,,,,,, -17300,0.15170723,4.0606585,,,,,,,,,,,,,,,,, -17400,0.1566784,4.056077,,,,,,,,,,,,,,,,, -17500,0.1639392,4.1434135,,,,,,,,,,,,,,,,, -17600,0.16044645,4.0646157,,,,,,,,,,,,,,,,, -17700,0.15603544,4.072052,,,,,,,,,,,,,,,,, -17800,0.19705044,4.0827656,,,,,,,,,,,,,,,,, -17900,0.15653576,4.048994,,,,,,,,,,,,,,,,, -18000,0.15347666,3.9728987,,,,,,,,,,,,,,,,, -18100,0.16919157,4.0769563,,,,,,,,,,,,,,,,, -18200,0.17077103,4.1097302,,,,,,,,,,,,,,,,, -18300,0.1688665,4.028417,,,,,,,,,,,,,,,,, -18400,0.18112859,4.07668,,,,,,,,,,,,,,,,, -18500,0.1589518,4.045352,,,,,,,,,,,,,,,,, -18600,0.14589885,4.0394325,,,,,,,,,,,,,,,,, -18700,0.1740636,4.040088,,,,,,,,,,,,,,,,, -18800,0.15336177,4.081617,,,,,,,,,,,,,,,,, -18900,0.16922039,4.044367,,,,,,,,,,,,,,,,, -19000,0.16367063,4.033416,,,,,,,,,,,,,,,,, -19056,,,0.6523534059524536,1.89583683013916,32.224112132764745,0.6549577713012695,1.867025375366211,28.0233294530836,3000.0,0.6621579527854919,1.8095208406448364,27.09557119536779,3003.0,6747.7297422885895,11738.225252628326,6747.7297422885895,4989.699973106384,0.2039761543273925,0.0 -19100,0.16988929,4.0901175,,,,,,,,,,,,,,,,, -19200,0.19807252,4.0649667,,,,,,,,,,,,,,,,, -19300,0.14692695,4.0218844,,,,,,,,,,,,,,,,, -19400,0.1955231,4.097844,,,,,,,,,,,,,,,,, -19500,0.15423733,4.0067177,,,,,,,,,,,,,,,,, -19600,0.1611779,4.079405,,,,,,,,,,,,,,,,, -19700,0.21111393,4.083848,,,,,,,,,,,,,,,,, -19800,0.1787385,4.0724926,,,,,,,,,,,,,,,,, -19900,0.16092998,4.005908,,,,,,,,,,,,,,,,, -20000,0.17142178,4.048115,,,,,,,,,,,,,,,,, -20100,0.1663738,4.088156,,,,,,,,,,,,,,,,, -20200,0.1497671,3.9730246,,,,,,,,,,,,,,,,, -20300,0.19526027,3.9791079,,,,,,,,,,,,,,,,, -20400,0.16904594,4.07155,,,,,,,,,,,,,,,,, -20500,0.20331725,4.0083337,,,,,,,,,,,,,,,,, -20600,0.18005873,3.9359894,,,,,,,,,,,,,,,,, -20700,0.15710925,4.0533977,,,,,,,,,,,,,,,,, -20800,0.16226292,4.0209394,,,,,,,,,,,,,,,,, -20900,0.15678549,4.0394597,,,,,,,,,,,,,,,,, -21000,0.29663354,3.991042,,,,,,,,,,,,,,,,, -21100,0.16245195,3.9782605,,,,,,,,,,,,,,,,, -21200,0.1958615,4.0354595,,,,,,,,,,,,,,,,, -21300,0.17139417,4.0502973,,,,,,,,,,,,,,,,, -21400,0.18279703,4.041817,,,,,,,,,,,,,,,,, -21439,,,0.6392760276794434,1.9621353149414065,31.3656819463444,0.6572020053863525,1.832504153251648,28.291012927380795,3000.0,0.667910099029541,1.771620273590088,27.43810439700653,3003.0,7587.909989356995,13052.69588947296,7587.909989356995,5463.887475013733,0.2319626808166504,0.0 -21500,0.17182131,4.0447755,,,,,,,,,,,,,,,,, -21600,0.1648108,4.0598054,,,,,,,,,,,,,,,,, -21700,0.15949084,4.0170183,,,,,,,,,,,,,,,,, -21800,0.14682196,3.9832919,,,,,,,,,,,,,,,,, -21900,0.17295192,3.9776316,,,,,,,,,,,,,,,,, -22000,0.19020002,4.0487976,,,,,,,,,,,,,,,,, -22100,0.17573763,3.9554026,,,,,,,,,,,,,,,,, -22200,0.18894124,3.9933264,,,,,,,,,,,,,,,,, -22300,0.18926388,4.006589,,,,,,,,,,,,,,,,, -22400,0.1810479,4.0000563,,,,,,,,,,,,,,,,, -22500,0.16939518,3.9683828,,,,,,,,,,,,,,,,, -22600,0.16549268,4.021272,,,,,,,,,,,,,,,,, -22700,0.27216798,3.9816766,,,,,,,,,,,,,,,,, -22800,0.1763144,3.959508,,,,,,,,,,,,,,,,, -22900,0.1533595,4.005563,,,,,,,,,,,,,,,,, -23000,0.19970615,3.9811807,,,,,,,,,,,,,,,,, -23100,0.17875199,3.968547,,,,,,,,,,,,,,,,, -23200,0.19244495,4.0236325,,,,,,,,,,,,,,,,, -23300,0.1720317,4.060032,,,,,,,,,,,,,,,,, -23400,0.16277571,4.041503,,,,,,,,,,,,,,,,, -23500,0.17384353,3.9357908,,,,,,,,,,,,,,,,, -23600,0.1675274,3.9879892,,,,,,,,,,,,,,,,, -23700,0.17186733,3.9366078,,,,,,,,,,,,,,,,, -23800,0.19381198,3.9692736,,,,,,,,,,,,,,,,, -23822,,,0.643540620803833,1.939845085144043,31.37433547834663,0.6603513956069946,1.811151027679444,28.12367394678881,3000.0,0.6710127592086792,1.7506834268569946,27.898585149515966,3003.0,8427.87146282196,14363.563488960266,8427.87146282196,5934.689563751221,0.2605419158935547,0.0 -23900,0.18764132,4.0178294,,,,,,,,,,,,,,,,, -24000,0.16287534,3.9203427,,,,,,,,,,,,,,,,, -24100,0.16819888,3.961662,,,,,,,,,,,,,,,,, -24200,0.17461671,4.0266705,,,,,,,,,,,,,,,,, -24300,0.19209363,4.0308485,,,,,,,,,,,,,,,,, -24400,0.19812103,3.9938126,,,,,,,,,,,,,,,,, -24500,0.17645578,3.9976914,,,,,,,,,,,,,,,,, -24600,0.1982108,3.988019,,,,,,,,,,,,,,,,, -24700,0.21004881,4.018054,,,,,,,,,,,,,,,,, -24800,0.2040026,4.014822,,,,,,,,,,,,,,,,, -24900,0.17328563,3.9848156,,,,,,,,,,,,,,,,, -25000,0.17060867,3.947987,,,,,,,,,,,,,,,,, -25100,0.26018983,4.037556,,,,,,,,,,,,,,,,, -25200,0.17849708,3.9769418,,,,,,,,,,,,,,,,, -25300,0.17268676,3.9558585,,,,,,,,,,,,,,,,, -25400,0.18810086,3.9795709,,,,,,,,,,,,,,,,, -25500,0.18105385,3.9960291,,,,,,,,,,,,,,,,, -25600,0.17638904,3.9796927,,,,,,,,,,,,,,,,, -25700,0.18436545,4.041126,,,,,,,,,,,,,,,,, -25800,0.18802775,3.9957786,,,,,,,,,,,,,,,,, -25900,0.22400104,3.9663389,,,,,,,,,,,,,,,,, -26000,0.2961288,3.9898713,,,,,,,,,,,,,,,,, -26100,0.18097192,4.003506,,,,,,,,,,,,,,,,, -26200,0.1835383,4.0112,,,,,,,,,,,,,,,,, -26206,,,0.653814435005188,1.8643923997879028,32.41593082315556,0.6627939939498901,1.7915291786193848,28.8887559201092,3000.0,0.6742897033691406,1.722639799118042,28.15439954133418,3003.0,9268.00053191185,15734.331931114197,9268.00053191185,6465.22674870491,0.2896280288696289,0.0 -26300,0.22097875,3.9861362,,,,,,,,,,,,,,,,, -26400,0.18087839,3.9241228,,,,,,,,,,,,,,,,, -26500,0.18169494,3.9413247,,,,,,,,,,,,,,,,, -26600,0.19219789,3.9209728,,,,,,,,,,,,,,,,, -26700,0.18591471,3.9556813,,,,,,,,,,,,,,,,, -26800,0.22602874,6.177408,,,,,,,,,,,,,,,,, -26900,0.22640017,5.6750317,,,,,,,,,,,,,,,,, -27000,0.2196189,5.572152,,,,,,,,,,,,,,,,, -27100,0.2253264,5.5760417,,,,,,,,,,,,,,,,, -27200,0.23266347,5.4952836,,,,,,,,,,,,,,,,, -27300,0.57147056,5.507762,,,,,,,,,,,,,,,,, -27400,0.64075035,5.4911494,,,,,,,,,,,,,,,,, -27500,0.97834444,5.4387593,,,,,,,,,,,,,,,,, -27600,0.20043911,4.0110946,,,,,,,,,,,,,,,,, -27700,0.20015444,3.9939048,,,,,,,,,,,,,,,,, -27800,0.19473644,3.9585445,,,,,,,,,,,,,,,,, -27900,0.18816435,3.992001,,,,,,,,,,,,,,,,, -28000,0.18675977,4.00122,,,,,,,,,,,,,,,,, -28100,0.21085864,3.9844441,,,,,,,,,,,,,,,,, -28200,0.18389247,4.000641,,,,,,,,,,,,,,,,, -28300,0.19469611,3.9624977,,,,,,,,,,,,,,,,, -28400,0.23674111,3.9713938,,,,,,,,,,,,,,,,, -28500,0.17424215,3.9420855,,,,,,,,,,,,,,,,, -28591,,,0.6467625498771667,1.921960711479187,31.52587145991113,0.6663525700569153,1.796502947807312,28.85844748497176,3000.0,0.6772529482841492,1.7303847074508667,28.447757417888468,3003.0,10108.170245409012,17071.990015506744,10108.170245409012,6962.611318826675,0.3182723522186279,0.0 -28600,0.2488492,3.9573739,,,,,,,,,,,,,,,,, -28700,0.2618618,3.913902,,,,,,,,,,,,,,,,, -28800,0.20717897,3.9914753,,,,,,,,,,,,,,,,, -28900,0.24586329,3.9180155,,,,,,,,,,,,,,,,, -29000,0.19511014,4.0354958,,,,,,,,,,,,,,,,, -29100,0.24011184,3.902607,,,,,,,,,,,,,,,,, -29200,0.2708419,3.9585962,,,,,,,,,,,,,,,,, -29300,0.18438222,3.9599247,,,,,,,,,,,,,,,,, -29400,0.1849112,4.043294,,,,,,,,,,,,,,,,, -29500,0.18263319,4.0193214,,,,,,,,,,,,,,,,, -29600,0.21997207,3.96662,,,,,,,,,,,,,,,,, -29700,0.27282116,3.9547386,,,,,,,,,,,,,,,,, -29800,0.18760906,3.9381576,,,,,,,,,,,,,,,,, -29900,0.28410777,3.93753,,,,,,,,,,,,,,,,, -30000,0.18659557,3.9112594,,,,,,,,,,,,,,,,, -30100,0.19810058,3.9375405,,,,,,,,,,,,,,,,, -30200,0.19588462,3.9812891,,,,,,,,,,,,,,,,, -30300,0.19203743,3.959199,,,,,,,,,,,,,,,,, -30400,0.24577112,3.9398031,,,,,,,,,,,,,,,,, -30500,0.20590435,3.9422233,,,,,,,,,,,,,,,,, -30600,0.24264303,3.979879,,,,,,,,,,,,,,,,, -30700,0.2578972,3.9964888,,,,,,,,,,,,,,,,, -30800,0.185791,3.966255,,,,,,,,,,,,,,,,, -30900,0.19645122,3.864366,,,,,,,,,,,,,,,,, -30974,,,0.6443696022033691,1.931141018867493,31.986524329560304,0.6660549640655518,1.7905811071395874,28.60632817214931,3000.0,0.6762768030166626,1.725675344467163,27.87737284091013,3003.0,10948.20868062973,18420.80942606926,10948.20868062973,7471.288013458252,0.3465893268585205,0.0 -31000,0.25260937,3.9917006,,,,,,,,,,,,,,,,, -31100,0.19003326,3.9223905,,,,,,,,,,,,,,,,, -31200,0.23101445,3.9437308,,,,,,,,,,,,,,,,, -31300,0.1994633,4.0000443,,,,,,,,,,,,,,,,, -31400,0.20217656,3.968123,,,,,,,,,,,,,,,,, -31500,0.204907,3.9466913,,,,,,,,,,,,,,,,, -31600,0.19438209,3.912966,,,,,,,,,,,,,,,,, -31700,0.21778837,3.9545515,,,,,,,,,,,,,,,,, -31800,0.2538154,3.8450758,,,,,,,,,,,,,,,,, -31900,0.30877972,3.950447,,,,,,,,,,,,,,,,, -32000,0.20622467,3.9592752,,,,,,,,,,,,,,,,, -32100,0.1889045,3.9198027,,,,,,,,,,,,,,,,, -32200,0.22104119,3.9483058,,,,,,,,,,,,,,,,, -32300,0.19967094,3.9710612,,,,,,,,,,,,,,,,, -32400,0.21789603,3.915516,,,,,,,,,,,,,,,,, -32500,0.21401145,3.9162045,,,,,,,,,,,,,,,,, -32600,0.19826646,3.9734037,,,,,,,,,,,,,,,,, -32700,0.21035887,3.9593883,,,,,,,,,,,,,,,,, -32800,0.20398645,3.98774,,,,,,,,,,,,,,,,, -32900,0.2183258,3.9223986,,,,,,,,,,,,,,,,, -33000,0.21960804,3.9116297,,,,,,,,,,,,,,,,, -33100,0.19125108,3.979899,,,,,,,,,,,,,,,,, -33200,0.26888838,3.9446888,,,,,,,,,,,,,,,,, -33300,0.19961359,3.9248517,,,,,,,,,,,,,,,,, -33357,,,0.6579074263572693,1.84610378742218,32.069965372459464,0.6691919565200806,1.755408525466919,29.00561299490477,3000.0,0.680727481842041,1.6848061084747314,28.342088839250422,3003.0,11788.108996391296,19917.94206786156,11788.108996391296,8128.417171955109,0.3758647441864013,0.0 -33400,0.20930767,3.9580622,,,,,,,,,,,,,,,,, -33500,0.19211587,3.9461725,,,,,,,,,,,,,,,,, -33600,0.23739879,3.923314,,,,,,,,,,,,,,,,, -33700,0.18414433,3.8513408,,,,,,,,,,,,,,,,, -33800,0.2087007,3.8659832,,,,,,,,,,,,,,,,, -33900,0.32447332,3.925576,,,,,,,,,,,,,,,,, -34000,0.23777133,3.9276848,,,,,,,,,,,,,,,,, -34100,0.22731648,3.9028866,,,,,,,,,,,,,,,,, -34200,0.20238458,3.876138,,,,,,,,,,,,,,,,, -34300,0.2233596,3.946669,,,,,,,,,,,,,,,,, -34400,0.20091847,3.9282513,,,,,,,,,,,,,,,,, -34500,0.19553125,3.9261231,,,,,,,,,,,,,,,,, -34600,0.29613525,3.8769522,,,,,,,,,,,,,,,,, -34700,0.20642403,3.975989,,,,,,,,,,,,,,,,, -34800,0.19064246,3.9145634,,,,,,,,,,,,,,,,, -34900,0.19224839,3.8818731,,,,,,,,,,,,,,,,, -35000,0.24262822,3.9682727,,,,,,,,,,,,,,,,, -35100,0.22242153,3.9419897,,,,,,,,,,,,,,,,, -35200,0.23250492,3.9231734,,,,,,,,,,,,,,,,, -35300,0.19168468,3.8911095,,,,,,,,,,,,,,,,, -35400,2.5040803,5.064644,,,,,,,,,,,,,,,,, -35500,0.22802775,4.0531306,,,,,,,,,,,,,,,,, -35600,0.22008368,3.949859,,,,,,,,,,,,,,,,, -35700,0.20354393,3.9340267,,,,,,,,,,,,,,,,, -35740,,,0.6511590480804443,1.890103459358216,32.00241740281816,0.6663649678230286,1.756929874420166,29.13012192945443,3000.0,0.6793214082717896,1.6866551637649536,28.543354366147703,3003.0,12628.127183675766,21234.825357437134,12628.127183675766,8605.174662351608,0.4066076278686523,0.0 -35800,0.21498586,4.00287,,,,,,,,,,,,,,,,, -35900,0.20987833,3.9382522,,,,,,,,,,,,,,,,, -36000,0.21228288,3.89305,,,,,,,,,,,,,,,,, -36100,0.21780407,3.9523187,,,,,,,,,,,,,,,,, -36200,0.25843918,3.9285686,,,,,,,,,,,,,,,,, -36300,0.19779234,3.976664,,,,,,,,,,,,,,,,, -36400,0.20446967,3.9225514,,,,,,,,,,,,,,,,, -36500,0.19834377,3.8563173,,,,,,,,,,,,,,,,, -36600,0.18898052,3.8800578,,,,,,,,,,,,,,,,, -36700,0.18764393,3.8772297,,,,,,,,,,,,,,,,, -36800,0.23189388,3.935535,,,,,,,,,,,,,,,,, -36900,0.20349856,3.9666378,,,,,,,,,,,,,,,,, -37000,0.1981038,3.9040954,,,,,,,,,,,,,,,,, -37100,0.26944214,3.947822,,,,,,,,,,,,,,,,, -37200,0.2909083,3.957145,,,,,,,,,,,,,,,,, -37300,0.26556134,3.9141374,,,,,,,,,,,,,,,,, -37400,0.21031234,3.8656683,,,,,,,,,,,,,,,,, -37500,0.3042712,3.9170315,,,,,,,,,,,,,,,,, -37600,0.22094205,3.8953416,,,,,,,,,,,,,,,,, -37700,0.25615165,3.8217146,,,,,,,,,,,,,,,,, -37800,0.3063685,3.8961859,,,,,,,,,,,,,,,,, -37900,0.26780677,3.8930337,,,,,,,,,,,,,,,,, -38000,0.20671752,3.900279,,,,,,,,,,,,,,,,, -38100,0.21931091,3.9298997,,,,,,,,,,,,,,,,, -38123,,,0.6650872230529785,1.7879432439804075,33.173089200434895,0.6712377667427063,1.7390223741531372,29.3184200998274,3000.0,0.6831677556037903,1.668337106704712,28.57284390796016,3003.0,13468.265769004822,22642.39179039001,13468.265769004822,9172.487488031387,0.4418942928314209,0.0 -38200,0.21850128,3.9074826,,,,,,,,,,,,,,,,, -38300,0.21944147,3.829821,,,,,,,,,,,,,,,,, -38400,0.25912708,3.9258554,,,,,,,,,,,,,,,,, -38500,0.24778752,3.8680391,,,,,,,,,,,,,,,,, -38600,0.19824685,3.8689375,,,,,,,,,,,,,,,,, -38700,0.23179324,3.8474658,,,,,,,,,,,,,,,,, -38800,0.22082825,3.9200225,,,,,,,,,,,,,,,,, -38900,0.24400207,3.8689678,,,,,,,,,,,,,,,,, -39000,0.22380231,3.981078,,,,,,,,,,,,,,,,, -39100,0.22554158,3.917763,,,,,,,,,,,,,,,,, -39200,0.22527766,3.908036,,,,,,,,,,,,,,,,, -39300,0.22859834,3.9263604,,,,,,,,,,,,,,,,, -39400,0.22680445,3.971602,,,,,,,,,,,,,,,,, -39500,0.21654688,3.9174054,,,,,,,,,,,,,,,,, -39600,0.255574,3.895102,,,,,,,,,,,,,,,,, -39700,0.28496656,3.8766053,,,,,,,,,,,,,,,,, -39800,0.255404,3.915324,,,,,,,,,,,,,,,,, -39900,0.23389682,3.9041147,,,,,,,,,,,,,,,,, -40000,0.22041048,3.963947,,,,,,,,,,,,,,,,, -40100,0.23842098,3.9063022,,,,,,,,,,,,,,,,, -40200,0.2219386,3.8851333,,,,,,,,,,,,,,,,, -40300,0.24358942,3.8929777,,,,,,,,,,,,,,,,, -40400,0.24692303,3.9068282,,,,,,,,,,,,,,,,, -40500,0.2258803,3.9137616,,,,,,,,,,,,,,,,, -40506,,,0.6558361649513245,1.8276304006576536,32.338058333076155,0.6708658337593079,1.7183750867843628,29.019836053244,3000.0,0.6837139129638672,1.6359859704971311,28.80755072318091,3003.0,14308.177471637726,23999.80347084999,14308.177471637726,9689.880198001862,0.4748435020446777,0.0 -40600,0.23445748,3.9135122,,,,,,,,,,,,,,,,, -40700,0.23307285,3.8882892,,,,,,,,,,,,,,,,, -40800,0.31480482,3.8762283,,,,,,,,,,,,,,,,, -40900,0.22872664,3.8416214,,,,,,,,,,,,,,,,, -41000,0.21160297,3.8925152,,,,,,,,,,,,,,,,, -41100,0.2408963,3.9443789,,,,,,,,,,,,,,,,, -41200,0.2565087,3.9186,,,,,,,,,,,,,,,,, -41300,0.23272921,3.8953996,,,,,,,,,,,,,,,,, -41400,0.227606,3.8557117,,,,,,,,,,,,,,,,, -41500,0.21453391,3.871677,,,,,,,,,,,,,,,,, -41600,0.2628017,3.908059,,,,,,,,,,,,,,,,, -41700,0.378015,3.9478407,,,,,,,,,,,,,,,,, -41800,0.23084891,3.9453425,,,,,,,,,,,,,,,,, -41900,0.22319438,3.92468,,,,,,,,,,,,,,,,, -42000,0.24681287,3.9042425,,,,,,,,,,,,,,,,, -42100,0.22803022,3.8977726,,,,,,,,,,,,,,,,, -42200,0.22667335,3.906285,,,,,,,,,,,,,,,,, -42300,0.21659422,3.8826833,,,,,,,,,,,,,,,,, -42400,0.23627533,3.8842545,,,,,,,,,,,,,,,,, -42500,0.21909684,3.8696473,,,,,,,,,,,,,,,,, -42600,0.24628134,3.8612232,,,,,,,,,,,,,,,,, -42700,0.22814125,3.8894863,,,,,,,,,,,,,,,,, -42800,0.24994712,3.8883152,,,,,,,,,,,,,,,,, -42889,,,0.3256044685840606,4.233160018920898,0.1383137865844718,0.3041127920150757,4.484804153442383,0.0647455762933908,3000.0,0.3018534779548645,4.571069240570068,0.033897742842188,3003.0,15148.184902191162,25539.77074050904,15148.184902191162,10389.728868246078,0.51078200340271,0.0 -42900,9.599934,9.04253,,,,,,,,,,,,,,,,, -43000,0.3391931,5.5749846,,,,,,,,,,,,,,,,, -43100,0.6890672,5.480607,,,,,,,,,,,,,,,,, -43200,0.2602715,5.4626756,,,,,,,,,,,,,,,,, -43300,0.3091417,5.471932,,,,,,,,,,,,,,,,, -43400,1.891618,5.385476,,,,,,,,,,,,,,,,, -43500,0.90639156,4.0625134,,,,,,,,,,,,,,,,, -43600,0.23844583,3.8934457,,,,,,,,,,,,,,,,, -43700,0.27480876,3.9244936,,,,,,,,,,,,,,,,, -43800,0.24581201,3.905408,,,,,,,,,,,,,,,,, -43900,0.20962031,3.8771312,,,,,,,,,,,,,,,,, -44000,0.21973896,3.869232,,,,,,,,,,,,,,,,, -44100,0.22922051,3.8627858,,,,,,,,,,,,,,,,, -44200,0.22853959,3.9403353,,,,,,,,,,,,,,,,, -44300,0.2174688,3.8695312,,,,,,,,,,,,,,,,, -44400,0.20208958,3.895259,,,,,,,,,,,,,,,,, -44500,0.21107158,3.9051614,,,,,,,,,,,,,,,,, -44600,0.22033586,3.8965542,,,,,,,,,,,,,,,,, -44700,0.21355745,3.812528,,,,,,,,,,,,,,,,, -44800,0.22011626,3.8519447,,,,,,,,,,,,,,,,, -44900,0.23369849,4.0016017,,,,,,,,,,,,,,,,, -45000,0.29617846,3.8425968,,,,,,,,,,,,,,,,, -45100,0.2482856,3.8691885,,,,,,,,,,,,,,,,, -45200,0.23180665,3.9153092,,,,,,,,,,,,,,,,, -45272,,,0.6611015796661377,1.7925008535385132,32.56452484565516,0.6752551198005676,1.697988986968994,29.62041815683345,3000.0,0.6868398189544678,1.6206845045089722,28.890927730804112,3003.0,15988.277658700945,26847.600742578503,15988.277658700945,10857.352401733398,0.5476312637329102,0.0 -45300,0.2838479,3.8735642,,,,,,,,,,,,,,,,, -45400,0.24916121,3.9085972,,,,,,,,,,,,,,,,, -45500,0.27054986,3.956018,,,,,,,,,,,,,,,,, -45600,0.2602482,3.908778,,,,,,,,,,,,,,,,, -45700,0.20943809,3.9006636,,,,,,,,,,,,,,,,, -45800,0.22103211,3.8880277,,,,,,,,,,,,,,,,, -45900,0.23842171,3.9469352,,,,,,,,,,,,,,,,, -46000,0.26611385,3.8574522,,,,,,,,,,,,,,,,, -46100,0.22548406,3.9437976,,,,,,,,,,,,,,,,, -46200,0.22511752,3.889745,,,,,,,,,,,,,,,,, -46300,0.22609825,3.8872364,,,,,,,,,,,,,,,,, -46400,0.21734704,3.8659449,,,,,,,,,,,,,,,,, -46500,0.21654865,3.839165,,,,,,,,,,,,,,,,, -46600,0.38347507,3.8700607,,,,,,,,,,,,,,,,, -46700,0.3207214,3.8827028,,,,,,,,,,,,,,,,, -46800,0.24754015,3.888252,,,,,,,,,,,,,,,,, -46900,0.2334464,3.8766687,,,,,,,,,,,,,,,,, -47000,0.24414992,3.955511,,,,,,,,,,,,,,,,, -47100,0.24140874,3.9509625,,,,,,,,,,,,,,,,, -47200,0.23129956,3.9341714,,,,,,,,,,,,,,,,, -47300,0.2344891,3.823359,,,,,,,,,,,,,,,,, -47400,0.23595789,3.8559124,,,,,,,,,,,,,,,,, -47500,0.2664866,3.9059713,,,,,,,,,,,,,,,,, -47600,0.23217689,3.8337715,,,,,,,,,,,,,,,,, -47656,,,0.6571767330169678,1.834312081336975,32.09281506764354,0.6745359301567078,1.7079156637191772,29.54546456908649,3000.0,0.6875022053718567,1.6302539110183716,29.12670142648821,3003.0,16828.508974790573,28198.55082011223,16828.508974790573,11367.964815616608,0.5800197124481201,0.0 -47700,0.29183868,3.9302511,,,,,,,,,,,,,,,,, -47800,0.27568668,3.8733828,,,,,,,,,,,,,,,,, -47900,0.27069694,3.8964317,,,,,,,,,,,,,,,,, -48000,0.2695859,3.8635736,,,,,,,,,,,,,,,,, -48100,0.25570902,3.819749,,,,,,,,,,,,,,,,, -48200,0.29135656,3.8897743,,,,,,,,,,,,,,,,, -48300,0.2354265,3.905189,,,,,,,,,,,,,,,,, -48400,0.2513518,3.9008956,,,,,,,,,,,,,,,,, -48500,0.27644363,3.8761427,,,,,,,,,,,,,,,,, -48600,0.2361801,3.8852906,,,,,,,,,,,,,,,,, -48700,0.294161,3.9113362,,,,,,,,,,,,,,,,, -48800,0.23020613,3.838882,,,,,,,,,,,,,,,,, -48900,0.24692433,3.9261756,,,,,,,,,,,,,,,,, -49000,0.22905752,3.8770864,,,,,,,,,,,,,,,,, -49100,0.2232262,3.8698335,,,,,,,,,,,,,,,,, -49200,0.22583413,3.8694756,,,,,,,,,,,,,,,,, -49300,0.24745315,3.9026153,,,,,,,,,,,,,,,,, -49400,0.25235263,3.9237995,,,,,,,,,,,,,,,,, -49500,0.24908279,3.849708,,,,,,,,,,,,,,,,, -49600,0.23864628,3.8929124,,,,,,,,,,,,,,,,, -49700,0.24687463,3.927888,,,,,,,,,,,,,,,,, -49800,0.25636184,3.9444046,,,,,,,,,,,,,,,,, -49900,0.25326833,3.9505866,,,,,,,,,,,,,,,,, -50000,0.22121494,3.883209,,,,,,,,,,,,,,,,, -50039,,,0.70138019323349,1.5984097719192505,35.59088268760344,0.6770529747009277,1.7046748399734497,29.58461302392992,3000.0,0.6908721327781677,1.6230562925338743,29.203659503289817,3003.0,17668.53111076355,29614.12636780739,17668.53111076355,11943.41303062439,0.6114578247070312,0.0 -50100,0.28519192,3.8834713,,,,,,,,,,,,,,,,, -50200,0.24605705,3.881535,,,,,,,,,,,,,,,,, -50300,0.25942245,3.8799238,,,,,,,,,,,,,,,,, -50400,0.27313045,3.881454,,,,,,,,,,,,,,,,, -50500,0.23635289,3.84606,,,,,,,,,,,,,,,,, -50600,0.23941204,3.8711128,,,,,,,,,,,,,,,,, -50700,0.22206974,3.869877,,,,,,,,,,,,,,,,, -50800,0.24387783,3.9165268,,,,,,,,,,,,,,,,, -50900,0.24986154,3.8979332,,,,,,,,,,,,,,,,, -51000,0.2919052,3.8479853,,,,,,,,,,,,,,,,, -51100,0.24284913,3.8276944,,,,,,,,,,,,,,,,, -51200,0.23975743,3.8454378,,,,,,,,,,,,,,,,, -51300,0.28185865,3.8967154,,,,,,,,,,,,,,,,, -51400,0.25060514,3.8114462,,,,,,,,,,,,,,,,, -51500,0.25513428,3.838882,,,,,,,,,,,,,,,,, -51600,0.24708836,3.9132025,,,,,,,,,,,,,,,,, -51700,0.23968995,3.878535,,,,,,,,,,,,,,,,, -51800,0.25583422,3.9323077,,,,,,,,,,,,,,,,, -51900,0.24859983,3.8463945,,,,,,,,,,,,,,,,, -52000,0.32063404,3.8498938,,,,,,,,,,,,,,,,, -52100,0.23146623,3.88068,,,,,,,,,,,,,,,,, -52200,0.25347334,3.9050772,,,,,,,,,,,,,,,,, -52300,0.25735393,3.8510642,,,,,,,,,,,,,,,,, -52400,0.25066203,3.9199438,,,,,,,,,,,,,,,,, -52421,,,0.6642529368400574,1.7819786071777344,32.94824790704417,0.6771645545959473,1.690804123878479,29.778122229123387,3000.0,0.6908256411552429,1.6125904321670532,29.400656364664417,3003.0,18508.57137870789,30962.36435317993,18508.57137870789,12451.49473619461,0.6497724056243896,0.0 -52500,0.24475229,3.8752494,,,,,,,,,,,,,,,,, -52600,0.26187035,3.8867583,,,,,,,,,,,,,,,,, -52700,0.246642,3.8262582,,,,,,,,,,,,,,,,, -52800,0.24566475,3.8606553,,,,,,,,,,,,,,,,, -52900,0.25702998,3.8564744,,,,,,,,,,,,,,,,, -53000,0.25905606,3.856453,,,,,,,,,,,,,,,,, -53100,0.25122246,3.8030055,,,,,,,,,,,,,,,,, -53200,0.24215448,3.8255107,,,,,,,,,,,,,,,,, -53300,0.2592376,3.7949824,,,,,,,,,,,,,,,,, -53400,0.24250871,3.865639,,,,,,,,,,,,,,,,, -53500,0.27091584,3.8577,,,,,,,,,,,,,,,,, -53600,0.27994293,3.83026,,,,,,,,,,,,,,,,, -53700,0.2802844,3.8974252,,,,,,,,,,,,,,,,, -53800,0.24800263,3.8390336,,,,,,,,,,,,,,,,, -53900,0.25370452,3.8515768,,,,,,,,,,,,,,,,, -54000,0.41286922,3.8436406,,,,,,,,,,,,,,,,, -54100,0.26519412,3.942608,,,,,,,,,,,,,,,,, -54200,0.26803792,3.884551,,,,,,,,,,,,,,,,, -54300,0.2813899,3.8665814,,,,,,,,,,,,,,,,, -54400,0.30018324,3.8536417,,,,,,,,,,,,,,,,, -54500,0.2688539,3.9198787,,,,,,,,,,,,,,,,, -54600,0.23584367,3.8681188,,,,,,,,,,,,,,,,, -54700,0.24359737,3.8884814,,,,,,,,,,,,,,,,, -54800,0.31115586,3.8136466,,,,,,,,,,,,,,,,, -54804,,,0.6630371809005737,1.7826213836669922,32.97631954929356,0.6774993538856506,1.6878796815872192,29.82505230190574,3000.0,0.6921155452728271,1.6047054529190063,29.701937568525977,3003.0,19348.67028999329,32287.47656297684,19348.67028999329,12936.39828658104,0.6831979751586914,0.0 -54900,0.26619974,3.8829184,,,,,,,,,,,,,,,,, -55000,0.31062406,3.8887718,,,,,,,,,,,,,,,,, -55100,0.26006502,3.8492603,,,,,,,,,,,,,,,,, -55200,0.38310087,3.831146,,,,,,,,,,,,,,,,, -55300,0.25251693,3.822272,,,,,,,,,,,,,,,,, -55400,0.23165691,3.8632035,,,,,,,,,,,,,,,,, -55500,0.3217497,3.9053068,,,,,,,,,,,,,,,,, -55600,0.28480947,3.9156587,,,,,,,,,,,,,,,,, -55700,0.26279148,3.780928,,,,,,,,,,,,,,,,, -55800,0.2501911,3.8627708,,,,,,,,,,,,,,,,, -55900,0.23949976,3.8450131,,,,,,,,,,,,,,,,, -56000,0.24574375,3.8719068,,,,,,,,,,,,,,,,, -56100,0.2532229,3.8583786,,,,,,,,,,,,,,,,, -56200,0.28623748,3.9015546,,,,,,,,,,,,,,,,, -56300,0.26101562,3.8930109,,,,,,,,,,,,,,,,, -56400,0.23923963,3.8440561,,,,,,,,,,,,,,,,, -56500,0.24090962,3.883886,,,,,,,,,,,,,,,,, -56600,0.25997078,3.864087,,,,,,,,,,,,,,,,, -56700,0.29385206,3.7863162,,,,,,,,,,,,,,,,, -56800,0.24899358,3.838602,,,,,,,,,,,,,,,,, -56900,0.23584507,3.7889216,,,,,,,,,,,,,,,,, -57000,0.24018969,3.752681,,,,,,,,,,,,,,,,, -57100,0.2517158,3.9046671,,,,,,,,,,,,,,,,, -57187,,,0.677360475063324,1.7102564573287964,33.62485481934153,0.6791856288909912,1.6819928884506226,29.688366128750683,3000.0,0.6951368451118469,1.5948320627212524,29.96182087882542,3003.0,20188.675621509552,33608.99349451065,20188.675621509552,13417.802710533142,0.7159018516540527,0.0 -57200,0.27393985,3.881677,,,,,,,,,,,,,,,,, -57300,0.26838067,3.8638592,,,,,,,,,,,,,,,,, -57400,0.25216582,3.829955,,,,,,,,,,,,,,,,, -57500,0.35179418,3.8266037,,,,,,,,,,,,,,,,, -57600,0.25778714,3.7996016,,,,,,,,,,,,,,,,, -57700,0.25513563,3.8316026,,,,,,,,,,,,,,,,, -57800,0.24663053,3.7960658,,,,,,,,,,,,,,,,, -57900,0.26050127,3.880553,,,,,,,,,,,,,,,,, -58000,0.26290345,3.8909693,,,,,,,,,,,,,,,,, -58100,0.23115367,3.8530893,,,,,,,,,,,,,,,,, -58200,0.24620025,3.828422,,,,,,,,,,,,,,,,, -58300,0.26380184,3.7969472,,,,,,,,,,,,,,,,, -58400,0.25300893,3.8619297,,,,,,,,,,,,,,,,, -58500,0.25091264,3.8454216,,,,,,,,,,,,,,,,, -58600,0.26268294,3.8223925,,,,,,,,,,,,,,,,, -58700,0.26208064,3.8333323,,,,,,,,,,,,,,,,, -58800,0.27618235,3.8255386,,,,,,,,,,,,,,,,, -58900,0.34240162,3.9153833,,,,,,,,,,,,,,,,, -59000,0.32112253,3.8380277,,,,,,,,,,,,,,,,, -59100,0.27160594,3.7956,,,,,,,,,,,,,,,,, -59200,0.3996905,3.8995185,,,,,,,,,,,,,,,,, -59300,0.2707474,3.8374362,,,,,,,,,,,,,,,,, -59400,0.24763225,3.8429255,,,,,,,,,,,,,,,,, -59500,0.26350647,3.814895,,,,,,,,,,,,,,,,, -59571,,,0.665337324142456,1.7709975242614746,32.771778877290295,0.6795576214790344,1.667377233505249,29.518673328869195,3000.0,0.6934635043144226,1.5827269554138184,29.45679735869189,3003.0,21028.83904337883,35029.61338472366,21028.83904337883,13998.15175628662,0.7499048709869385,0.0 -59600,0.29546675,3.880469,,,,,,,,,,,,,,,,, -59700,0.23331936,3.800466,,,,,,,,,,,,,,,,, -59800,0.30238572,3.7925153,,,,,,,,,,,,,,,,, -59900,0.2540615,3.8012679,,,,,,,,,,,,,,,,, -60000,0.2650016,3.8709998,,,,,,,,,,,,,,,,, -60100,0.31742737,3.7910311,,,,,,,,,,,,,,,,, -60200,0.27460733,3.8547957,,,,,,,,,,,,,,,,, -60300,0.2632245,3.835099,,,,,,,,,,,,,,,,, -60400,0.25400716,3.8288336,,,,,,,,,,,,,,,,, -60500,0.26466915,3.8336644,,,,,,,,,,,,,,,,, -60600,0.26596043,3.8288467,,,,,,,,,,,,,,,,, -60700,0.256969,3.801208,,,,,,,,,,,,,,,,, -60800,0.25988582,3.8454804,,,,,,,,,,,,,,,,, -60900,0.26160535,3.800325,,,,,,,,,,,,,,,,, -61000,0.27616808,3.8166804,,,,,,,,,,,,,,,,, -61100,0.27576217,3.8531468,,,,,,,,,,,,,,,,, -61200,0.25662532,3.8755682,,,,,,,,,,,,,,,,, -61300,0.25142613,3.8148155,,,,,,,,,,,,,,,,, -61400,0.27441344,3.8098233,,,,,,,,,,,,,,,,, -61500,0.25243813,3.7977448,,,,,,,,,,,,,,,,, -61600,0.27043903,3.774268,,,,,,,,,,,,,,,,, -61700,0.2912185,3.8361144,,,,,,,,,,,,,,,,, -61800,0.2443746,3.8328445,,,,,,,,,,,,,,,,, -61900,0.2699383,3.8792167,,,,,,,,,,,,,,,,, -61955,,,0.6696307063102722,1.765308141708374,33.5078406006622,0.6800783276557922,1.670266032218933,29.881170707042948,3000.0,0.6972982287406921,1.5744260549545288,30.065476502372743,3003.0,21869.08454298973,36379.56156635285,21869.08454298973,14507.747369527817,0.7831311225891113,0.0 -62000,0.27459627,3.871702,,,,,,,,,,,,,,,,, -62100,0.30189502,3.8233125,,,,,,,,,,,,,,,,, -62200,0.25884598,3.8497734,,,,,,,,,,,,,,,,, -62300,0.26234937,3.7836018,,,,,,,,,,,,,,,,, -62400,0.2652778,3.8324187,,,,,,,,,,,,,,,,, -62500,0.26590428,3.864733,,,,,,,,,,,,,,,,, -62600,0.29907927,3.8768296,,,,,,,,,,,,,,,,, -62700,0.27339083,3.8764265,,,,,,,,,,,,,,,,, -62800,0.23416695,3.8169122,,,,,,,,,,,,,,,,, -62900,0.27495638,3.8703825,,,,,,,,,,,,,,,,, -63000,0.25853977,3.8511095,,,,,,,,,,,,,,,,, -63100,0.28224567,3.840384,,,,,,,,,,,,,,,,, -63200,0.2701351,3.846072,,,,,,,,,,,,,,,,, -63300,0.27987567,3.853782,,,,,,,,,,,,,,,,, -63400,0.28186673,3.7900631,,,,,,,,,,,,,,,,, -63500,0.27552637,3.771329,,,,,,,,,,,,,,,,, -63600,0.2741779,3.8030047,,,,,,,,,,,,,,,,, -63700,0.27040577,3.823247,,,,,,,,,,,,,,,,, -63800,0.32703078,3.8525279,,,,,,,,,,,,,,,,, -63900,0.25898135,3.8688881,,,,,,,,,,,,,,,,, -64000,0.2686843,3.7675006,,,,,,,,,,,,,,,,, -64100,0.280993,3.8173964,,,,,,,,,,,,,,,,, -64200,0.25354445,3.868873,,,,,,,,,,,,,,,,, -64300,0.2572745,3.826697,,,,,,,,,,,,,,,,, -64338,,,0.6764718890190125,1.7071046829223633,33.57440543575773,0.680251955986023,1.6660431623458862,29.700324861556748,3000.0,0.6957643628120422,1.577563762664795,29.864767151801697,3003.0,22709.00170826912,37693.27958345413,22709.00170826912,14981.440528154371,0.8170902729034424,0.0 -64400,0.26389125,3.8184285,,,,,,,,,,,,,,,,, -64500,0.27757174,3.879032,,,,,,,,,,,,,,,,, -64600,0.2569582,3.8026307,,,,,,,,,,,,,,,,, -64700,0.34665954,3.7648947,,,,,,,,,,,,,,,,, -64800,0.2511974,3.825832,,,,,,,,,,,,,,,,, -64900,0.27763122,3.8865235,,,,,,,,,,,,,,,,, -65000,0.2748749,3.7879097,,,,,,,,,,,,,,,,, -65100,0.27286866,3.793277,,,,,,,,,,,,,,,,, -65200,0.25834072,3.821338,,,,,,,,,,,,,,,,, -65300,0.3053768,3.8900046,,,,,,,,,,,,,,,,, -65400,0.27585483,3.790151,,,,,,,,,,,,,,,,, -65500,0.2561965,3.8027601,,,,,,,,,,,,,,,,, -65600,0.3067303,3.7623684,,,,,,,,,,,,,,,,, -65700,0.27960378,3.8211477,,,,,,,,,,,,,,,,, -65800,0.28227028,3.79157,,,,,,,,,,,,,,,,, -65900,0.28047806,3.8606656,,,,,,,,,,,,,,,,, -66000,0.32435718,3.8340316,,,,,,,,,,,,,,,,, -66100,0.25368077,3.821072,,,,,,,,,,,,,,,,, -66200,0.29830295,3.758087,,,,,,,,,,,,,,,,, -66300,0.25133127,3.761554,,,,,,,,,,,,,,,,, -66400,0.25278178,3.7813911,,,,,,,,,,,,,,,,, -66500,0.26885512,3.8565166,,,,,,,,,,,,,,,,, -66600,0.270058,3.8757675,,,,,,,,,,,,,,,,, -66700,0.2692777,3.803295,,,,,,,,,,,,,,,,, -66721,,,0.6704260110855103,1.7423919439315796,33.43165042153991,0.6816902160644531,1.6634081602096558,29.84663557839197,3000.0,0.6970425844192505,1.569135665893555,30.122066124861863,3003.0,23549.12446165085,39041.21701860428,23549.12446165085,15489.14245057106,0.854212760925293,0.0 -66800,0.26952177,3.8351285,,,,,,,,,,,,,,,,, -66900,0.27918768,3.7835312,,,,,,,,,,,,,,,,, -67000,0.2734913,3.803743,,,,,,,,,,,,,,,,, -67100,0.2706645,3.8289468,,,,,,,,,,,,,,,,, -67200,0.26790833,3.8825593,,,,,,,,,,,,,,,,, -67300,0.2798802,3.808795,,,,,,,,,,,,,,,,, -67400,0.25382185,3.784809,,,,,,,,,,,,,,,,, -67500,0.2975129,3.8037958,,,,,,,,,,,,,,,,, -67600,0.2727601,3.8263545,,,,,,,,,,,,,,,,, -67700,0.2827465,3.82637,,,,,,,,,,,,,,,,, -67800,0.25719148,3.8241234,,,,,,,,,,,,,,,,, -67900,0.29522476,3.7561562,,,,,,,,,,,,,,,,, -68000,0.28749055,3.8732615,,,,,,,,,,,,,,,,, -68100,0.28217942,3.7974946,,,,,,,,,,,,,,,,, -68200,0.28191075,3.8143318,,,,,,,,,,,,,,,,, -68300,0.30472204,3.847659,,,,,,,,,,,,,,,,, -68400,0.26672867,3.8925788,,,,,,,,,,,,,,,,, -68500,0.2569397,3.812866,,,,,,,,,,,,,,,,, -68600,0.27561235,3.7523656,,,,,,,,,,,,,,,,, -68700,0.27203938,3.7817385,,,,,,,,,,,,,,,,, -68800,0.2792932,3.8684366,,,,,,,,,,,,,,,,, -68900,0.3001932,3.8361435,,,,,,,,,,,,,,,,, -69000,0.26457906,3.8106813,,,,,,,,,,,,,,,,, -69100,0.27883935,3.8508003,,,,,,,,,,,,,,,,, -69104,,,0.6957202553749084,1.5890588760375977,35.166386173569414,0.682223379611969,1.655207276344299,29.95974849835134,3000.0,0.6975655555725098,1.562856674194336,29.880227745785337,3003.0,24389.31098389625,40385.25058054924,24389.31098389625,15992.877562999724,0.8907725811004639,0.0 -69200,0.26532376,3.83413,,,,,,,,,,,,,,,,, -69300,0.28878748,3.870197,,,,,,,,,,,,,,,,, -69400,0.2552825,3.7684085,,,,,,,,,,,,,,,,, -69500,0.2542869,3.800049,,,,,,,,,,,,,,,,, -69600,0.25859696,3.8115525,,,,,,,,,,,,,,,,, -69700,0.25767633,3.808712,,,,,,,,,,,,,,,,, -69800,0.27068427,3.7918918,,,,,,,,,,,,,,,,, -69900,0.27321777,3.782109,,,,,,,,,,,,,,,,, -70000,0.26747656,3.8040385,,,,,,,,,,,,,,,,, -70100,0.30093724,3.8032503,,,,,,,,,,,,,,,,, -70200,0.28392822,3.766859,,,,,,,,,,,,,,,,, -70300,0.30153176,3.771017,,,,,,,,,,,,,,,,, -70400,0.28407574,3.8067513,,,,,,,,,,,,,,,,, -70500,0.30215982,3.7830715,,,,,,,,,,,,,,,,, -70600,0.27399194,3.82132,,,,,,,,,,,,,,,,, -70700,0.26474714,3.7993293,,,,,,,,,,,,,,,,, -70800,0.2815783,3.7801068,,,,,,,,,,,,,,,,, -70900,0.24501082,3.7815642,,,,,,,,,,,,,,,,, -71000,0.25057602,3.7639031,,,,,,,,,,,,,,,,, -71100,0.31345695,3.7869692,,,,,,,,,,,,,,,,, -71200,0.26450175,3.7947795,,,,,,,,,,,,,,,,, -71300,0.27880907,3.802553,,,,,,,,,,,,,,,,, -71400,0.26601404,3.8122041,,,,,,,,,,,,,,,,, -71487,,,0.676636815071106,1.7049415111541748,34.26073413751061,0.6833145022392273,1.646494746208191,30.023471272179897,3000.0,0.6996223330497742,1.5564168691635132,30.269282466335337,3003.0,25229.35171794892,41740.079825639725,25229.35171794892,16507.55722308159,0.925260066986084,0.0 -71500,0.25898063,3.8098598,,,,,,,,,,,,,,,,, -71600,0.40363535,3.7530966,,,,,,,,,,,,,,,,, -71700,0.32215106,3.8441715,,,,,,,,,,,,,,,,, -71800,0.27147493,3.773622,,,,,,,,,,,,,,,,, -71900,0.268343,3.793069,,,,,,,,,,,,,,,,, -72000,0.2832016,3.741521,,,,,,,,,,,,,,,,, -72100,0.27436432,3.772285,,,,,,,,,,,,,,,,, -72200,0.2685448,3.778838,,,,,,,,,,,,,,,,, -72300,0.30032092,3.766434,,,,,,,,,,,,,,,,, -72400,0.28532007,3.8227177,,,,,,,,,,,,,,,,, -72500,0.2873668,3.83504,,,,,,,,,,,,,,,,, -72600,0.27710563,3.770482,,,,,,,,,,,,,,,,, -72700,0.25301456,3.7439802,,,,,,,,,,,,,,,,, -72800,0.28118435,3.8226104,,,,,,,,,,,,,,,,, -72900,0.35500354,3.8626573,,,,,,,,,,,,,,,,, -73000,0.2752903,3.7996652,,,,,,,,,,,,,,,,, -73100,0.28159615,3.7876482,,,,,,,,,,,,,,,,, -73200,0.26013714,3.8044102,,,,,,,,,,,,,,,,, -73300,0.27052948,3.7835464,,,,,,,,,,,,,,,,, -73400,0.26694074,3.7496052,,,,,,,,,,,,,,,,, -73500,0.28514192,3.8248842,,,,,,,,,,,,,,,,, -73600,0.27770182,3.7017744,,,,,,,,,,,,,,,,, -73700,0.29508647,3.8559158,,,,,,,,,,,,,,,,, -73800,0.28614503,3.7880938,,,,,,,,,,,,,,,,, -73871,,,0.6750386357307434,1.720869064331055,34.45133599614152,0.6839592456817627,1.6430405378341677,30.166010042652832,3000.0,0.6999593377113342,1.545217990875244,30.42530986159958,3003.0,26069.48642897606,43157.42514300346,26069.48642897606,17084.658698558807,0.9607217311859132,0.0 -73900,0.27015847,3.803383,,,,,,,,,,,,,,,,, -74000,0.33868948,3.8514242,,,,,,,,,,,,,,,,, -74100,0.27723983,3.8233557,,,,,,,,,,,,,,,,, -74200,0.26234022,3.757622,,,,,,,,,,,,,,,,, -74300,0.27048808,3.8058376,,,,,,,,,,,,,,,,, -74400,0.2830567,3.774762,,,,,,,,,,,,,,,,, -74500,0.26321357,3.7806792,,,,,,,,,,,,,,,,, -74600,0.27974582,3.7766807,,,,,,,,,,,,,,,,, -74700,0.26314002,3.805206,,,,,,,,,,,,,,,,, -74800,0.2971271,3.8092878,,,,,,,,,,,,,,,,, -74900,0.25877067,3.7778049,,,,,,,,,,,,,,,,, -75000,0.2942011,3.7963843,,,,,,,,,,,,,,,,, -75100,0.27423617,3.790055,,,,,,,,,,,,,,,,, -75200,0.2820093,3.7830305,,,,,,,,,,,,,,,,, -75300,0.29309097,3.7639155,,,,,,,,,,,,,,,,, -75400,0.27379873,3.8431842,,,,,,,,,,,,,,,,, -75500,0.2789754,3.7971334,,,,,,,,,,,,,,,,, -75600,0.26583806,3.7238586,,,,,,,,,,,,,,,,, -75700,0.2822349,3.8113513,,,,,,,,,,,,,,,,, -75800,0.29099324,3.7684128,,,,,,,,,,,,,,,,, -75900,0.28504264,3.8668525,,,,,,,,,,,,,,,,, -76000,0.26714766,3.7527287,,,,,,,,,,,,,,,,, -76100,0.2846203,3.7603812,,,,,,,,,,,,,,,,, -76200,0.2754805,3.7931802,,,,,,,,,,,,,,,,, -76254,,,0.6888977885246277,1.6275510787963867,35.07580034747715,0.6844180226325989,1.6386370658874512,29.92807660712913,3000.0,0.701655924320221,1.541452407836914,30.24334076025487,3003.0,26909.4674077034,44570.950585365295,26909.4674077034,17658.08572268486,1.0019049644470217,0.0 -76300,0.28336203,3.808411,,,,,,,,,,,,,,,,, -76400,0.2874291,3.7442877,,,,,,,,,,,,,,,,, -76500,0.2726147,3.7514756,,,,,,,,,,,,,,,,, -76600,0.27582076,3.7709973,,,,,,,,,,,,,,,,, -76700,0.2931475,3.8082194,,,,,,,,,,,,,,,,, -76800,0.27775574,3.7220218,,,,,,,,,,,,,,,,, -76900,0.26186007,3.8438623,,,,,,,,,,,,,,,,, -77000,0.2913215,3.7769594,,,,,,,,,,,,,,,,, -77100,0.25773796,3.7728488,,,,,,,,,,,,,,,,, -77200,0.29485616,3.7836955,,,,,,,,,,,,,,,,, -77300,0.28409192,3.7918963,,,,,,,,,,,,,,,,, -77400,0.2874028,3.7798803,,,,,,,,,,,,,,,,, -77500,0.2840695,3.756073,,,,,,,,,,,,,,,,, -77600,0.27712566,3.7881887,,,,,,,,,,,,,,,,, -77700,0.27005488,3.8046963,,,,,,,,,,,,,,,,, -77800,0.27038485,3.7641952,,,,,,,,,,,,,,,,, -77900,0.299047,3.8173919,,,,,,,,,,,,,,,,, -78000,0.29746142,3.7621965,,,,,,,,,,,,,,,,, -78100,0.2658538,3.7706656,,,,,,,,,,,,,,,,, -78200,0.3061676,3.8147175,,,,,,,,,,,,,,,,, -78300,0.2801973,3.8070555,,,,,,,,,,,,,,,,, -78400,0.28570002,3.794311,,,,,,,,,,,,,,,,, -78500,0.27235818,3.7550588,,,,,,,,,,,,,,,,, -78600,0.26666632,3.7066865,,,,,,,,,,,,,,,,, -78637,,,0.6790190935134888,1.6930485963821411,34.606128151781625,0.6853355765342712,1.6412664651870728,30.084906994619733,3000.0,0.7021439671516418,1.543318271636963,30.256775334023622,3003.0,27749.411451101303,45916.00671863556,27749.411451101303,18163.086530447006,1.0375986099243164,0.0 -78700,0.2851483,3.7824895,,,,,,,,,,,,,,,,, -78800,0.30135754,3.804271,,,,,,,,,,,,,,,,, -78900,0.28771737,3.7710834,,,,,,,,,,,,,,,,, -79000,0.2716332,3.7557359,,,,,,,,,,,,,,,,, -79100,0.29313338,3.7211115,,,,,,,,,,,,,,,,, -79200,0.29131466,3.803565,,,,,,,,,,,,,,,,, -79300,0.28685382,3.7626014,,,,,,,,,,,,,,,,, -79400,0.32311064,3.767413,,,,,,,,,,,,,,,,, -79500,0.2760459,3.7411416,,,,,,,,,,,,,,,,, -79600,0.28500068,3.80831,,,,,,,,,,,,,,,,, -79700,0.31096727,3.774746,,,,,,,,,,,,,,,,, -79800,0.29264832,3.7681715,,,,,,,,,,,,,,,,, -79900,0.32099682,3.7921002,,,,,,,,,,,,,,,,, -80000,0.28835264,3.8019023,,,,,,,,,,,,,,,,, -80100,0.30818838,3.7480676,,,,,,,,,,,,,,,,, -80200,0.27858382,3.797856,,,,,,,,,,,,,,,,, -80300,0.27441546,3.8003886,,,,,,,,,,,,,,,,, -80400,0.28352118,3.74162,,,,,,,,,,,,,,,,, -80500,0.30211535,3.8630872,,,,,,,,,,,,,,,,, -80600,0.29379037,3.7494977,,,,,,,,,,,,,,,,, -80700,0.2780783,3.7849197,,,,,,,,,,,,,,,,, -80800,0.28585154,3.7629733,,,,,,,,,,,,,,,,, -80900,0.2816458,3.7097602,,,,,,,,,,,,,,,,, -81000,0.27857047,3.8237865,,,,,,,,,,,,,,,,, -81021,,,0.68350750207901,1.6653733253479004,33.94700163998571,0.6868730783462524,1.628083348274231,30.37649273654434,3000.0,0.7027831077575684,1.532280445098877,30.454600655162768,3003.0,28589.64216661453,47232.37526059151,28589.64216661453,18639.112541913983,1.0760555267333984,0.0 -81100,0.28421205,3.7858958,,,,,,,,,,,,,,,,, -81200,0.31378704,3.7815847,,,,,,,,,,,,,,,,, -81300,0.29023263,3.7745337,,,,,,,,,,,,,,,,, -81400,0.29354173,3.7713099,,,,,,,,,,,,,,,,, -81500,0.2909095,3.786536,,,,,,,,,,,,,,,,, -81600,0.29652804,3.7511919,,,,,,,,,,,,,,,,, -81700,0.2843526,3.7605338,,,,,,,,,,,,,,,,, -81800,0.27344334,3.750123,,,,,,,,,,,,,,,,, -81900,0.29523283,3.737761,,,,,,,,,,,,,,,,, -82000,0.30184773,3.8065417,,,,,,,,,,,,,,,,, -82100,0.28478247,3.6911132,,,,,,,,,,,,,,,,, -82200,0.27561444,3.7611363,,,,,,,,,,,,,,,,, -82300,0.32097378,3.7367356,,,,,,,,,,,,,,,,, -82400,0.29527143,3.7750776,,,,,,,,,,,,,,,,, -82500,0.30396464,3.7716522,,,,,,,,,,,,,,,,, -82600,0.32868347,3.8161814,,,,,,,,,,,,,,,,, -82700,0.30247042,3.7754004,,,,,,,,,,,,,,,,, -82800,0.31521568,3.8084478,,,,,,,,,,,,,,,,, -82900,0.30748245,3.7537413,,,,,,,,,,,,,,,,, -83000,0.2907643,3.770643,,,,,,,,,,,,,,,,, -83100,0.2954344,3.7531822,,,,,,,,,,,,,,,,, -83200,0.30087084,3.757258,,,,,,,,,,,,,,,,, -83300,0.29470834,3.7471695,,,,,,,,,,,,,,,,, -83400,0.2930379,3.748115,,,,,,,,,,,,,,,,, -83404,,,0.689186692237854,1.628559947013855,34.95898421051862,0.6865258812904358,1.6255191564559937,29.923109223085948,3000.0,0.7037476301193237,1.527652382850647,30.497443854719528,3003.0,29429.75515246392,48569.43083524704,29429.75515246392,19135.941576480865,1.1130588054656982,0.0 -83500,0.28149998,3.710929,,,,,,,,,,,,,,,,, -83600,0.2984291,3.7413712,,,,,,,,,,,,,,,,, -83700,0.30728695,3.8204045,,,,,,,,,,,,,,,,, -83800,0.28324938,3.7622414,,,,,,,,,,,,,,,,, -83900,0.3139744,3.76291,,,,,,,,,,,,,,,,, -84000,0.28532234,3.7833037,,,,,,,,,,,,,,,,, -84100,0.287826,3.772436,,,,,,,,,,,,,,,,, -84200,0.3093048,3.7817693,,,,,,,,,,,,,,,,, -84300,0.27909866,3.7040038,,,,,,,,,,,,,,,,, -84400,0.28050232,3.7752926,,,,,,,,,,,,,,,,, -84500,0.29555246,3.729051,,,,,,,,,,,,,,,,, -84600,0.2769681,3.7343085,,,,,,,,,,,,,,,,, -84700,0.29930422,3.7638388,,,,,,,,,,,,,,,,, -84800,0.2904414,3.752825,,,,,,,,,,,,,,,,, -84900,0.28944874,3.7792833,,,,,,,,,,,,,,,,, -85000,0.27624574,3.705952,,,,,,,,,,,,,,,,, -85100,0.32059208,3.7355902,,,,,,,,,,,,,,,,, -85200,0.31322664,3.7480624,,,,,,,,,,,,,,,,, -85300,0.28141662,3.7063925,,,,,,,,,,,,,,,,, -85400,0.2903796,3.7970862,,,,,,,,,,,,,,,,, -85500,0.3042835,3.71561,,,,,,,,,,,,,,,,, -85600,0.28445145,3.730991,,,,,,,,,,,,,,,,, -85700,0.30128703,3.79675,,,,,,,,,,,,,,,,, -85787,,,0.6846340894699097,1.6518151760101318,34.29854509964452,0.6845916509628296,1.627881407737732,29.913907501321336,3000.0,0.7035732865333557,1.526257038116455,30.16393657409564,3003.0,30269.69832634926,50055.488450050354,30269.69832634926,19781.944012403488,1.1491875648498535,0.0 -85800,0.33693802,3.7372713,,,,,,,,,,,,,,,,, -85900,0.3024061,3.7108154,,,,,,,,,,,,,,,,, -86000,0.2984802,3.752509,,,,,,,,,,,,,,,,, -86100,0.28786096,3.7029812,,,,,,,,,,,,,,,,, -86200,0.28788245,3.7330086,,,,,,,,,,,,,,,,, -86300,0.2787852,3.7421556,,,,,,,,,,,,,,,,, -86400,0.30490503,3.7624953,,,,,,,,,,,,,,,,, -86500,0.31661382,3.785962,,,,,,,,,,,,,,,,, -86600,0.3166994,3.777596,,,,,,,,,,,,,,,,, -86700,0.29446346,3.741675,,,,,,,,,,,,,,,,, -86800,0.30038986,3.7144275,,,,,,,,,,,,,,,,, -86900,0.29010472,3.75354,,,,,,,,,,,,,,,,, -87000,0.2822807,3.6789196,,,,,,,,,,,,,,,,, -87100,0.2948632,3.655376,,,,,,,,,,,,,,,,, -87200,0.28632864,3.7171378,,,,,,,,,,,,,,,,, -87300,0.30683777,3.7077496,,,,,,,,,,,,,,,,, -87400,0.28630155,3.783458,,,,,,,,,,,,,,,,, -87500,0.31802395,3.6966617,,,,,,,,,,,,,,,,, -87600,0.29263496,3.7108324,,,,,,,,,,,,,,,,, -87700,0.29504645,3.7112832,,,,,,,,,,,,,,,,, -87800,0.3294126,3.742176,,,,,,,,,,,,,,,,, -87900,0.30421573,3.7262993,,,,,,,,,,,,,,,,, -88000,0.28787455,3.7150018,,,,,,,,,,,,,,,,, -88100,0.31966612,3.7120242,,,,,,,,,,,,,,,,, -88170,,,0.7073625326156616,1.5367120504379272,36.36983182542021,0.6875426173210144,1.6262383460998535,30.037688410889466,3000.0,0.7039335370063782,1.526638388633728,30.32476310275814,3003.0,31109.80075287819,51547.07810497284,31109.80075287819,20433.31798386573,1.1875977516174316,0.0 -88200,0.31408185,3.7855246,,,,,,,,,,,,,,,,, -88300,0.3045656,3.7964084,,,,,,,,,,,,,,,,, -88400,0.31501576,3.7464993,,,,,,,,,,,,,,,,, -88500,0.31709194,3.7049446,,,,,,,,,,,,,,,,, -88600,0.31081203,3.7485309,,,,,,,,,,,,,,,,, -88700,0.30808893,3.7570226,,,,,,,,,,,,,,,,, -88800,0.30010813,3.7716246,,,,,,,,,,,,,,,,, -88900,0.31334236,3.747237,,,,,,,,,,,,,,,,, -89000,0.32038343,3.7439766,,,,,,,,,,,,,,,,, -89100,0.30221933,3.7305515,,,,,,,,,,,,,,,,, -89200,0.29358634,3.7428904,,,,,,,,,,,,,,,,, -89300,0.2891397,3.688928,,,,,,,,,,,,,,,,, -89400,0.3194002,3.7776525,,,,,,,,,,,,,,,,, -89500,0.3130347,3.7060146,,,,,,,,,,,,,,,,, -89600,0.3143503,3.6786296,,,,,,,,,,,,,,,,, -89700,0.28866267,3.6957543,,,,,,,,,,,,,,,,, -89800,0.305944,3.6804733,,,,,,,,,,,,,,,,, -89900,0.28859797,3.690331,,,,,,,,,,,,,,,,, -90000,0.29389033,3.686382,,,,,,,,,,,,,,,,, -90100,0.30765164,3.7480438,,,,,,,,,,,,,,,,, -90200,0.3490161,3.754518,,,,,,,,,,,,,,,,, -90300,0.3064229,3.7214274,,,,,,,,,,,,,,,,, -90400,0.3184439,3.6895385,,,,,,,,,,,,,,,,, -90500,0.3083993,3.729078,,,,,,,,,,,,,,,,, -90553,,,0.6886942386627197,1.6172972917556765,34.94074502185821,0.6868730783462524,1.6169028282165527,29.63990232421176,3000.0,0.7055139541625977,1.5144983530044556,30.369156672036247,3003.0,31949.7300992012,53018.90501999855,31949.7300992012,21065.105221271515,1.225377082824707,0.0 -90600,0.3088035,3.740828,,,,,,,,,,,,,,,,, -90700,0.30445582,3.7613745,,,,,,,,,,,,,,,,, -90800,0.33086535,3.7858477,,,,,,,,,,,,,,,,, -90900,0.3098204,3.7210548,,,,,,,,,,,,,,,,, -91000,0.30400145,3.7028866,,,,,,,,,,,,,,,,, -91100,0.33833686,3.7301085,,,,,,,,,,,,,,,,, -91200,0.30963507,3.6820042,,,,,,,,,,,,,,,,, -91300,0.28198346,3.6688657,,,,,,,,,,,,,,,,, -91400,0.32662213,3.739017,,,,,,,,,,,,,,,,, -91500,0.30434054,3.7187142,,,,,,,,,,,,,,,,, -91600,0.29378211,3.7208288,,,,,,,,,,,,,,,,, -91700,0.29977527,3.686884,,,,,,,,,,,,,,,,, -91800,0.2946018,3.6994662,,,,,,,,,,,,,,,,, -91900,0.32871845,3.7177672,,,,,,,,,,,,,,,,, -92000,0.35576543,3.7497206,,,,,,,,,,,,,,,,, -92100,0.30788872,3.775427,,,,,,,,,,,,,,,,, -92200,0.29057482,3.6845238,,,,,,,,,,,,,,,,, -92300,0.31423703,3.70254,,,,,,,,,,,,,,,,, -92400,0.3041899,3.6885982,,,,,,,,,,,,,,,,, -92500,0.3239282,3.7614553,,,,,,,,,,,,,,,,, -92600,0.31829697,3.792219,,,,,,,,,,,,,,,,, -92700,0.29598698,3.685208,,,,,,,,,,,,,,,,, -92800,0.31518337,3.7265844,,,,,,,,,,,,,,,,, -92900,0.31610766,3.7033346,,,,,,,,,,,,,,,,, -92936,,,0.6906367540359497,1.6220505237579346,35.15019654432984,0.6880013942718506,1.6124529838562012,30.125960524935937,3000.0,0.7075939774513245,1.5099539756774902,30.907346055474644,3003.0,32789.64198088646,54346.94282460213,32789.64198088646,21553.11935710907,1.263375759124756,0.0 -93000,0.3094518,3.6846473,,,,,,,,,,,,,,,,, -93100,0.3061622,3.6647007,,,,,,,,,,,,,,,,, -93200,0.32843027,3.734151,,,,,,,,,,,,,,,,, -93300,0.31996095,3.7465205,,,,,,,,,,,,,,,,, -93400,0.31770402,3.7453146,,,,,,,,,,,,,,,,, -93500,0.3089782,3.7251818,,,,,,,,,,,,,,,,, -93600,0.3110201,3.72829,,,,,,,,,,,,,,,,, -93700,0.35533774,3.7762368,,,,,,,,,,,,,,,,, -93800,0.348656,3.7972724,,,,,,,,,,,,,,,,, -93900,0.31699225,3.685242,,,,,,,,,,,,,,,,, -94000,0.31225386,3.7306664,,,,,,,,,,,,,,,,, -94100,0.31302172,3.7455807,,,,,,,,,,,,,,,,, -94200,0.29868713,3.6910267,,,,,,,,,,,,,,,,, -94300,0.3130381,3.6911442,,,,,,,,,,,,,,,,, -94400,0.31767002,3.758811,,,,,,,,,,,,,,,,, -94500,0.3024596,3.7043319,,,,,,,,,,,,,,,,, -94600,0.3138598,3.7071147,,,,,,,,,,,,,,,,, -94700,0.3302102,3.759371,,,,,,,,,,,,,,,,, -94800,0.32646406,3.7735913,,,,,,,,,,,,,,,,, -94900,0.29908642,3.7490442,,,,,,,,,,,,,,,,, -95000,0.33044308,3.7054195,,,,,,,,,,,,,,,,, -95100,0.30384663,3.6326528,,,,,,,,,,,,,,,,, -95200,0.31955853,3.757239,,,,,,,,,,,,,,,,, -95300,0.34337983,3.7209008,,,,,,,,,,,,,,,,, -95319,,,0.7033336758613586,1.5496033430099487,35.7991515308961,0.6894024610519409,1.6060791015625,30.590930186553223,3000.0,0.7068386673927307,1.5075197219848633,30.55684966808048,3003.0,33629.78120470047,55705.952450037,33629.78120470047,22071.868659973145,1.3095154762268066,0.0 -95400,0.3190637,3.7166226,,,,,,,,,,,,,,,,, -95500,0.31375426,3.7157946,,,,,,,,,,,,,,,,, -95600,0.31900358,3.6709938,,,,,,,,,,,,,,,,, -95700,0.3288189,3.709352,,,,,,,,,,,,,,,,, -95800,0.3095267,3.708266,,,,,,,,,,,,,,,,, -95900,0.3241069,3.6972435,,,,,,,,,,,,,,,,, -96000,0.3188852,3.757909,,,,,,,,,,,,,,,,, -96100,0.3220809,3.6936111,,,,,,,,,,,,,,,,, -96200,0.32991788,3.7197032,,,,,,,,,,,,,,,,, -96300,0.31960124,3.683851,,,,,,,,,,,,,,,,, -96400,0.36495793,3.758399,,,,,,,,,,,,,,,,, -96500,0.3141196,3.6639836,,,,,,,,,,,,,,,,, -96600,0.3563379,3.7222543,,,,,,,,,,,,,,,,, -96700,0.31811664,3.75842,,,,,,,,,,,,,,,,, -96800,0.32610223,3.7156863,,,,,,,,,,,,,,,,, -96900,0.33634648,3.6903613,,,,,,,,,,,,,,,,, -97000,0.3245576,3.7396417,,,,,,,,,,,,,,,,, -97100,0.3141251,3.6726792,,,,,,,,,,,,,,,,, -97200,0.3358172,3.7456412,,,,,,,,,,,,,,,,, -97300,0.32607037,3.7239652,,,,,,,,,,,,,,,,, -97400,0.32335863,3.6660864,,,,,,,,,,,,,,,,, -97500,0.31996834,3.657435,,,,,,,,,,,,,,,,, -97600,0.32495663,3.673525,,,,,,,,,,,,,,,,, -97700,0.33100194,3.7496312,,,,,,,,,,,,,,,,, -97702,,,0.6969262957572937,1.5804659128189087,35.152333339015506,0.6876417994499207,1.6106321811676023,30.332051794275426,3000.0,0.7058508992195129,1.5085893869400024,30.394577530924465,3003.0,34469.73361515999,57028.49427843094,34469.73361515999,22554.345523118973,1.3491921424865725,0.0 -97800,0.33649662,3.6709988,,,,,,,,,,,,,,,,, -97900,0.31941274,3.6530786,,,,,,,,,,,,,,,,, -98000,0.33404493,3.7278707,,,,,,,,,,,,,,,,, -98100,0.33453125,3.72151,,,,,,,,,,,,,,,,, -98200,0.3055714,3.651727,,,,,,,,,,,,,,,,, -98300,0.33908314,3.7403593,,,,,,,,,,,,,,,,, -98400,0.34238788,3.7411265,,,,,,,,,,,,,,,,, -98500,0.32592845,3.695485,,,,,,,,,,,,,,,,, -98600,0.33577153,3.6278446,,,,,,,,,,,,,,,,, -98700,0.33553892,3.6695924,,,,,,,,,,,,,,,,, -98800,0.335651,3.7262878,,,,,,,,,,,,,,,,, -98900,0.31989405,3.7168589,,,,,,,,,,,,,,,,, -99000,0.32149553,3.6851542,,,,,,,,,,,,,,,,, -99100,0.3225605,3.6874352,,,,,,,,,,,,,,,,, -99200,0.31900764,3.6632326,,,,,,,,,,,,,,,,, -99300,0.31996912,3.6651952,,,,,,,,,,,,,,,,, -99400,0.31964818,3.7122939,,,,,,,,,,,,,,,,, -99500,0.32567778,3.6803157,,,,,,,,,,,,,,,,, -99600,0.32901794,3.678584,,,,,,,,,,,,,,,,, -99700,0.33454213,3.7024095,,,,,,,,,,,,,,,,, -99800,0.33168915,3.664664,,,,,,,,,,,,,,,,, -99900,0.34756568,3.673591,,,,,,,,,,,,,,,,, -100000,0.32662138,3.687026,,,,,,,,,,,,,,,,, -100084,,,0.7314824461936951,1.4231904745101929,38.33886509323845,0.6890305280685425,1.6068017482757568,30.71176452861133,3000.0,0.7073267102241516,1.5063436031341553,30.53523360311065,3003.0,35309.84344100952,58361.5737016201,35309.84344100952,23047.190549373627,1.3953723907470703,0.0 -100100,0.33956602,3.7057323,,,,,,,,,,,,,,,,, -100200,0.32714292,3.681033,,,,,,,,,,,,,,,,, -100300,0.32626712,3.6271608,,,,,,,,,,,,,,,,, -100400,0.3293426,3.7394626,,,,,,,,,,,,,,,,, -100500,0.32637307,3.7015631,,,,,,,,,,,,,,,,, -100600,0.33105278,3.6906939,,,,,,,,,,,,,,,,, -100700,0.33126226,3.6417465,,,,,,,,,,,,,,,,, -100800,0.3175649,3.6330068,,,,,,,,,,,,,,,,, -100900,0.33369395,3.7481024,,,,,,,,,,,,,,,,, -101000,0.3581497,3.6834981,,,,,,,,,,,,,,,,, -101100,0.32368013,3.6539783,,,,,,,,,,,,,,,,, -101200,0.33694896,3.684292,,,,,,,,,,,,,,,,, -101300,0.33012918,3.6636798,,,,,,,,,,,,,,,,, -101400,0.34437454,3.6779306,,,,,,,,,,,,,,,,, -101500,0.3483148,3.7176924,,,,,,,,,,,,,,,,, -101600,0.34765142,3.743191,,,,,,,,,,,,,,,,, -101700,0.34440306,3.7242808,,,,,,,,,,,,,,,,, -101800,0.34886676,3.733612,,,,,,,,,,,,,,,,, -101900,0.3461616,3.730155,,,,,,,,,,,,,,,,, -102000,0.34017903,3.7050855,,,,,,,,,,,,,,,,, -102100,0.33304545,3.6225407,,,,,,,,,,,,,,,,, -102200,0.34637487,3.6970606,,,,,,,,,,,,,,,,, -102300,0.32757846,3.6818984,,,,,,,,,,,,,,,,, -102400,0.32790816,3.6728132,,,,,,,,,,,,,,,,, -102466,,,0.7023757696151733,1.553908348083496,36.134819561514,0.6896008849143982,1.6050572395324707,30.46585574558103,3000.0,0.7066527605056763,1.5068280696868896,30.829963453250123,3003.0,36149.92675304413,59711.75960898399,36149.92675304413,23557.168686389923,1.442002773284912,0.0 -102500,0.33154804,3.6768162,,,,,,,,,,,,,,,,, -102600,0.36578172,3.702713,,,,,,,,,,,,,,,,, -102700,0.34613425,3.7487059,,,,,,,,,,,,,,,,, -102800,0.34537408,3.6568244,,,,,,,,,,,,,,,,, -102900,0.3302511,3.6566176,,,,,,,,,,,,,,,,, -103000,0.3455934,3.661617,,,,,,,,,,,,,,,,, -103100,0.33789036,3.679997,,,,,,,,,,,,,,,,, -103200,0.35349327,3.6652956,,,,,,,,,,,,,,,,, -103300,0.3350512,3.6792471,,,,,,,,,,,,,,,,, -103400,0.35195732,3.6815805,,,,,,,,,,,,,,,,, -103500,0.3395593,3.6359956,,,,,,,,,,,,,,,,, -103600,0.3553757,3.6389902,,,,,,,,,,,,,,,,, -103700,0.33866274,3.6936758,,,,,,,,,,,,,,,,, -103800,0.33498767,3.6683831,,,,,,,,,,,,,,,,, -103900,0.3465086,3.6435337,,,,,,,,,,,,,,,,, -104000,0.33468515,3.6537626,,,,,,,,,,,,,,,,, -104100,0.33842552,3.656694,,,,,,,,,,,,,,,,, -104200,0.3549105,3.661543,,,,,,,,,,,,,,,,, -104300,0.33674547,3.6557608,,,,,,,,,,,,,,,,, -104400,0.3477839,3.6796277,,,,,,,,,,,,,,,,, -104500,0.36280307,3.6953819,,,,,,,,,,,,,,,,, -104600,0.33781555,3.681042,,,,,,,,,,,,,,,,, -104700,0.35308826,3.6624825,,,,,,,,,,,,,,,,, -104800,0.3368321,3.634031,,,,,,,,,,,,,,,,, -104849,,,0.6993989944458008,1.559681415557861,35.754191956860694,0.6894644498825073,1.6042819023132324,30.594905174890503,3000.0,0.7081517577171326,1.4993467330932615,30.92962167513607,3003.0,36989.896939754486,61037.63996696472,36989.896939754486,24042.96623015404,1.4812438488006592,0.0 -104900,0.35719964,3.6839468,,,,,,,,,,,,,,,,, -105000,0.34107712,3.651855,,,,,,,,,,,,,,,,, -105100,0.33796096,3.6254628,,,,,,,,,,,,,,,,, -105200,0.34389263,3.6980488,,,,,,,,,,,,,,,,, -105300,0.33799762,3.6633425,,,,,,,,,,,,,,,,, -105400,0.3450342,3.6633668,,,,,,,,,,,,,,,,, -105500,0.34649473,3.6296687,,,,,,,,,,,,,,,,, -105600,0.38019422,3.6452281,,,,,,,,,,,,,,,,, -105700,0.3364301,3.6593823,,,,,,,,,,,,,,,,, -105800,0.35976112,3.6991158,,,,,,,,,,,,,,,,, -105900,0.34525645,3.700347,,,,,,,,,,,,,,,,, -106000,0.34543392,3.6726685,,,,,,,,,,,,,,,,, -106100,0.38793018,3.6702583,,,,,,,,,,,,,,,,, -106200,0.34930107,3.5852714,,,,,,,,,,,,,,,,, -106300,0.35591137,3.652982,,,,,,,,,,,,,,,,, -106400,0.35674006,3.6323628,,,,,,,,,,,,,,,,, -106500,0.36788228,3.6843781,,,,,,,,,,,,,,,,, -106600,0.35386017,3.7198129,,,,,,,,,,,,,,,,, -106700,0.37226406,3.6275277,,,,,,,,,,,,,,,,, -106800,0.3513761,3.6674762,,,,,,,,,,,,,,,,, -106900,0.341779,3.621026,,,,,,,,,,,,,,,,, -107000,0.34866023,3.670657,,,,,,,,,,,,,,,,, -107100,0.3549184,3.6663873,,,,,,,,,,,,,,,,, -107200,0.36189148,3.6521702,,,,,,,,,,,,,,,,, -107233,,,0.7179781794548035,1.4720852375030518,37.12911174176097,0.6896132826805115,1.6041953563690186,30.57522490142288,3000.0,0.7072221636772156,1.5013386011123655,30.69685450178556,3003.0,37830.11774921417,62362.86051940918,37830.11774921417,24527.8506731987,1.5237393379211426,0.0 -107300,0.33614266,3.6700022,,,,,,,,,,,,,,,,, -107400,0.3570749,3.7163217,,,,,,,,,,,,,,,,, -107500,0.35634327,3.632464,,,,,,,,,,,,,,,,, -107600,0.35999933,3.6044028,,,,,,,,,,,,,,,,, -107700,0.3556035,3.6703327,,,,,,,,,,,,,,,,, -107800,0.37808022,3.715761,,,,,,,,,,,,,,,,, -107900,0.3587266,3.6581078,,,,,,,,,,,,,,,,, -108000,0.3693275,3.6647308,,,,,,,,,,,,,,,,, -108100,0.34662184,3.6417983,,,,,,,,,,,,,,,,, -108200,0.36418027,3.6717002,,,,,,,,,,,,,,,,, -108300,0.35177737,3.6932766,,,,,,,,,,,,,,,,, -108400,0.34754443,3.6824236,,,,,,,,,,,,,,,,, -108500,0.35671493,3.6552439,,,,,,,,,,,,,,,,, -108600,0.3619693,3.693587,,,,,,,,,,,,,,,,, -108700,0.36228415,3.6426787,,,,,,,,,,,,,,,,, -108800,0.36206475,3.6633818,,,,,,,,,,,,,,,,, -108900,0.3448359,3.6152174,,,,,,,,,,,,,,,,, -109000,0.36172363,3.6466422,,,,,,,,,,,,,,,,, -109100,0.33627164,3.655938,,,,,,,,,,,,,,,,, -109200,0.372427,3.6392834,,,,,,,,,,,,,,,,, -109300,0.36655897,3.6617386,,,,,,,,,,,,,,,,, -109400,0.35841966,3.6939552,,,,,,,,,,,,,,,,, -109500,0.36430493,3.6364393,,,,,,,,,,,,,,,,, -109600,0.3506239,3.640911,,,,,,,,,,,,,,,,, -109615,,,0.710374653339386,1.5148351192474363,36.08658871891254,0.690220832824707,1.601181983947754,30.252796426395733,3000.0,0.7097089290618896,1.493978500366211,30.87070303183882,3003.0,38670.0972571373,63713.02125096321,38670.0972571373,25037.91539287567,1.563863754272461,0.0 -109700,0.36643568,3.6805158,,,,,,,,,,,,,,,,, -109800,0.3562972,3.6480768,,,,,,,,,,,,,,,,, -109900,0.35322574,3.6778367,,,,,,,,,,,,,,,,, -110000,0.36016104,3.6556413,,,,,,,,,,,,,,,,, -110100,0.35301918,3.6540093,,,,,,,,,,,,,,,,, -110200,0.37353304,3.6757119,,,,,,,,,,,,,,,,, -110300,0.3606571,3.6855345,,,,,,,,,,,,,,,,, -110400,0.3685629,3.6259847,,,,,,,,,,,,,,,,, -110500,0.3635669,3.6478953,,,,,,,,,,,,,,,,, -110600,0.35097682,3.6489978,,,,,,,,,,,,,,,,, -110700,0.3515648,3.641,,,,,,,,,,,,,,,,, -110800,0.37008741,3.643967,,,,,,,,,,,,,,,,, -110900,0.36141565,3.7067204,,,,,,,,,,,,,,,,, -111000,0.34366176,3.6481333,,,,,,,,,,,,,,,,, -111100,0.3581412,3.6406124,,,,,,,,,,,,,,,,, -111200,0.37118334,3.6117659,,,,,,,,,,,,,,,,, -111300,0.3813133,3.6397882,,,,,,,,,,,,,,,,, -111400,0.37010816,3.6408522,,,,,,,,,,,,,,,,, -111500,0.3901853,3.655686,,,,,,,,,,,,,,,,, -111600,0.36163017,3.6530595,,,,,,,,,,,,,,,,, -111700,0.3726026,3.6255755,,,,,,,,,,,,,,,,, -111800,0.3802357,3.5911126,,,,,,,,,,,,,,,,, -111900,0.37815994,3.6402578,,,,,,,,,,,,,,,,, -111997,,,0.7104139924049377,1.5093110799789429,36.71417614009322,0.6904935836791992,1.6021604537963867,30.74636830222096,3000.0,0.7084655165672302,1.4979416131973269,30.64387931009616,3003.0,39510.23357391357,65037.137236356735,39510.23357391357,25521.773129224777,1.6069166660308838,0.0 -112000,0.36710906,3.6258972,,,,,,,,,,,,,,,,, -112100,0.35847893,3.651915,,,,,,,,,,,,,,,,, -112200,0.3586922,3.605431,,,,,,,,,,,,,,,,, -112300,0.3614443,3.6502666,,,,,,,,,,,,,,,,, -112400,0.3907239,3.6411338,,,,,,,,,,,,,,,,, -112500,0.3754698,3.6286108,,,,,,,,,,,,,,,,, -112600,0.38260937,3.6363602,,,,,,,,,,,,,,,,, -112700,0.36188948,3.6759207,,,,,,,,,,,,,,,,, -112800,0.37028143,3.6741428,,,,,,,,,,,,,,,,, -112900,0.35560468,3.5969338,,,,,,,,,,,,,,,,, -113000,0.37350753,3.6624084,,,,,,,,,,,,,,,,, -113100,0.38283584,3.6678069,,,,,,,,,,,,,,,,, -113200,0.38421923,3.6753132,,,,,,,,,,,,,,,,, -113300,0.37111956,3.677785,,,,,,,,,,,,,,,,, -113400,0.37160236,3.6470478,,,,,,,,,,,,,,,,, -113500,0.37632096,3.5840583,,,,,,,,,,,,,,,,, -113600,0.3727267,3.6254134,,,,,,,,,,,,,,,,, -113700,0.36658335,3.6071417,,,,,,,,,,,,,,,,, -113800,0.36880472,3.6120884,,,,,,,,,,,,,,,,, -113900,0.37878716,3.6454294,,,,,,,,,,,,,,,,, -114000,0.3830986,3.5860517,,,,,,,,,,,,,,,,, -114100,0.3625866,3.629841,,,,,,,,,,,,,,,,, -114200,0.35729182,3.5880578,,,,,,,,,,,,,,,,, -114300,0.38464546,3.567522,,,,,,,,,,,,,,,,, -114380,,,0.7173652052879333,1.4794282913208008,37.131145115956905,0.6896628737449646,1.6034998893737793,30.39730108265441,3000.0,0.7098019123077393,1.4975818395614624,30.778936015094025,3003.0,40350.37732386589,66362.9657073021,40350.37732386589,26007.337995052338,1.6503000259399414,0.0 -114400,0.36716923,3.6182063,,,,,,,,,,,,,,,,, -114500,0.37081194,3.633432,,,,,,,,,,,,,,,,, -114600,0.38909516,3.676659,,,,,,,,,,,,,,,,, -114700,0.36935115,3.6313813,,,,,,,,,,,,,,,,, -114800,0.38012403,3.6125066,,,,,,,,,,,,,,,,, -114900,0.37232703,3.6841226,,,,,,,,,,,,,,,,, -115000,0.3688534,3.619511,,,,,,,,,,,,,,,,, -115100,0.37043074,3.6111434,,,,,,,,,,,,,,,,, -115200,0.3855329,3.6634321,,,,,,,,,,,,,,,,, -115300,0.37375468,3.6255748,,,,,,,,,,,,,,,,, -115400,0.3723489,3.639783,,,,,,,,,,,,,,,,, -115500,0.35952243,3.6100857,,,,,,,,,,,,,,,,, -115600,0.36937436,3.6176395,,,,,,,,,,,,,,,,, -115700,0.3503178,3.5789056,,,,,,,,,,,,,,,,, -115800,0.37728187,3.6519783,,,,,,,,,,,,,,,,, -115900,0.37143537,3.6172428,,,,,,,,,,,,,,,,, -116000,0.38527346,3.6171834,,,,,,,,,,,,,,,,, -116100,0.37285104,3.6232264,,,,,,,,,,,,,,,,, -116200,0.37927386,3.6761723,,,,,,,,,,,,,,,,, -116300,0.3718545,3.5458167,,,,,,,,,,,,,,,,, -116400,0.37194982,3.6259382,,,,,,,,,,,,,,,,, -116500,0.3595461,3.5921555,,,,,,,,,,,,,,,,, -116600,0.3770078,3.580675,,,,,,,,,,,,,,,,, -116700,0.38287753,3.6518428,,,,,,,,,,,,,,,,, -116764,,,0.7147499918937683,1.4900606870651243,36.79285558146597,0.6898736357688904,1.601643681526184,30.654046685474967,3000.0,0.7097205519676208,1.494398832321167,30.84845509744524,3003.0,41190.49334859848,67689.97535419464,41190.49334859848,26494.117411851883,1.6911771297454834,0.0 -116800,0.38754928,3.647185,,,,,,,,,,,,,,,,, -116900,0.37206346,3.644124,,,,,,,,,,,,,,,,, -117000,0.3579332,3.5744767,,,,,,,,,,,,,,,,, -117100,0.37622148,3.6295342,,,,,,,,,,,,,,,,, -117200,0.38912767,3.6294625,,,,,,,,,,,,,,,,, -117300,0.41372928,3.671488,,,,,,,,,,,,,,,,, -117400,0.38503954,3.6281173,,,,,,,,,,,,,,,,, -117500,0.37167782,3.6062765,,,,,,,,,,,,,,,,, -117600,0.38730776,3.6092434,,,,,,,,,,,,,,,,, -117700,0.3706799,3.5984192,,,,,,,,,,,,,,,,, -117800,0.38054398,3.6009948,,,,,,,,,,,,,,,,, -117900,0.36906496,3.5927806,,,,,,,,,,,,,,,,, -118000,0.3878047,3.5472214,,,,,,,,,,,,,,,,, -118100,0.36712933,3.573346,,,,,,,,,,,,,,,,, -118200,0.38974202,3.6453466,,,,,,,,,,,,,,,,, -118300,0.375045,3.6239972,,,,,,,,,,,,,,,,, -118400,0.37369165,3.6102405,,,,,,,,,,,,,,,,, -118500,0.40706405,3.6951375,,,,,,,,,,,,,,,,, -118600,0.40577215,3.726021,,,,,,,,,,,,,,,,, -118700,0.39477557,3.662655,,,,,,,,,,,,,,,,, -118800,0.36346516,3.5614753,,,,,,,,,,,,,,,,, -118900,0.37482882,3.6341183,,,,,,,,,,,,,,,,, -119000,0.37677416,3.6044416,,,,,,,,,,,,,,,,, -119100,0.38760853,3.5782588,,,,,,,,,,,,,,,,, -119147,,,0.7253099083900452,1.4400979280471802,38.21547448718535,0.6901092529296875,1.599493384361267,30.736349330402373,3000.0,0.709674060344696,1.493552803993225,30.740337151689708,3003.0,42030.68079471588,69001.66783833504,42030.68079471588,26965.506422758102,1.7330989837646484,0.0 -119200,0.38598895,3.6366873,,,,,,,,,,,,,,,,, -119300,0.381721,3.601826,,,,,,,,,,,,,,,,, -119400,0.38037652,3.6229818,,,,,,,,,,,,,,,,, -119500,0.40040818,3.6056192,,,,,,,,,,,,,,,,, -119600,0.38694888,3.6163788,,,,,,,,,,,,,,,,, -119700,0.36635888,3.553118,,,,,,,,,,,,,,,,, -119800,0.37411416,3.6233866,,,,,,,,,,,,,,,,, -119900,0.38669991,3.6133149,,,,,,,,,,,,,,,,, -120000,0.38642362,3.6066887,,,,,,,,,,,,,,,,, -120100,0.36985758,3.5906794,,,,,,,,,,,,,,,,, -120200,0.37541783,3.6208284,,,,,,,,,,,,,,,,, -120300,0.39937386,3.6210897,,,,,,,,,,,,,,,,, -120400,0.38691956,3.647379,,,,,,,,,,,,,,,,, -120500,0.37453038,3.5606198,,,,,,,,,,,,,,,,, -120600,0.37907118,3.6339986,,,,,,,,,,,,,,,,, -120700,0.41394797,3.601796,,,,,,,,,,,,,,,,, -120800,0.3757306,3.6355002,,,,,,,,,,,,,,,,, -120900,0.3715643,3.5728607,,,,,,,,,,,,,,,,, -121000,0.37264198,3.5847619,,,,,,,,,,,,,,,,, -121100,0.37417465,3.6361396,,,,,,,,,,,,,,,,, -121200,0.3720298,3.5690532,,,,,,,,,,,,,,,,, -121300,0.372418,3.561767,,,,,,,,,,,,,,,,, -121400,0.39049068,3.600969,,,,,,,,,,,,,,,,, -121500,0.38698927,3.6826632,,,,,,,,,,,,,,,,, -121529,,,0.7211779356002808,1.4619131088256836,37.837886173614265,0.6910391449928284,1.5997225046157837,30.736317737646832,3000.0,0.7105339765548706,1.4927688837051392,30.733517509111195,3003.0,42870.59420728684,70363.89932894707,42870.59420728684,27487.700913906097,1.7821390628814695,0.0 -121600,0.36634797,3.5681877,,,,,,,,,,,,,,,,, -121700,0.38647044,3.5807037,,,,,,,,,,,,,,,,, -121800,0.38190746,3.6401055,,,,,,,,,,,,,,,,, -121900,0.37758055,3.5997016,,,,,,,,,,,,,,,,, -122000,0.36155877,3.5746891,,,,,,,,,,,,,,,,, -122100,0.40115812,3.634511,,,,,,,,,,,,,,,,, -122200,0.38560146,3.601572,,,,,,,,,,,,,,,,, -122300,0.37571082,3.5867906,,,,,,,,,,,,,,,,, -122400,0.41644713,3.6753054,,,,,,,,,,,,,,,,, -122500,0.36711633,3.5923567,,,,,,,,,,,,,,,,, -122600,0.36685777,3.5881064,,,,,,,,,,,,,,,,, -122700,0.40811914,3.6583376,,,,,,,,,,,,,,,,, -122800,0.37153506,3.5550108,,,,,,,,,,,,,,,,, -122900,0.38717583,3.610049,,,,,,,,,,,,,,,,, -123000,0.3942377,3.6070182,,,,,,,,,,,,,,,,, -123100,0.3721373,3.5955904,,,,,,,,,,,,,,,,, -123200,0.38622552,3.6316583,,,,,,,,,,,,,,,,, -123300,0.38559145,3.5875838,,,,,,,,,,,,,,,,, -123400,0.3845051,3.565373,,,,,,,,,,,,,,,,, -123500,0.3931137,3.6298816,,,,,,,,,,,,,,,,, -123600,0.3844695,3.5977163,,,,,,,,,,,,,,,,, -123700,0.3863331,3.6207552,,,,,,,,,,,,,,,,, -123800,0.36848852,3.6108928,,,,,,,,,,,,,,,,, -123900,0.38798475,3.6266608,,,,,,,,,,,,,,,,, -123912,,,0.7188615202903748,1.4707727432250977,37.37068803202791,0.6907415986061096,1.5994369983673096,30.730796295851167,3000.0,0.7107663750648499,1.4905011653900146,30.85889317102481,3003.0,43710.78532648087,71700.7684583664,43710.78532648087,27984.25898528099,1.8253540992736816,0.0 -124000,0.37962514,3.5901406,,,,,,,,,,,,,,,,, -124100,0.39451885,3.5998895,,,,,,,,,,,,,,,,, -124200,0.4005135,3.626257,,,,,,,,,,,,,,,,, -124300,0.3921626,3.629556,,,,,,,,,,,,,,,,, -124400,0.38339883,3.6316476,,,,,,,,,,,,,,,,, -124500,0.40879327,3.6495876,,,,,,,,,,,,,,,,, -124600,0.37885916,3.624472,,,,,,,,,,,,,,,,, -124700,0.39522052,3.6328244,,,,,,,,,,,,,,,,, -124800,0.37867624,3.5540943,,,,,,,,,,,,,,,,, -124900,0.4040949,3.6026058,,,,,,,,,,,,,,,,, -125000,0.3837568,3.536398,,,,,,,,,,,,,,,,, -125100,0.4042482,3.6487877,,,,,,,,,,,,,,,,, -125200,0.3928908,3.6241865,,,,,,,,,,,,,,,,, -125300,0.38331118,3.615827,,,,,,,,,,,,,,,,, -125400,0.37837967,3.6236384,,,,,,,,,,,,,,,,, -125500,0.3773389,3.610164,,,,,,,,,,,,,,,,, -125600,0.37635034,3.609951,,,,,,,,,,,,,,,,, -125700,0.37935206,3.5973346,,,,,,,,,,,,,,,,, -125800,0.3967336,3.6326222,,,,,,,,,,,,,,,,, -125900,0.3817283,3.5600476,,,,,,,,,,,,,,,,, -126000,0.39855847,3.6116378,,,,,,,,,,,,,,,,, -126100,0.3892066,3.6100829,,,,,,,,,,,,,,,,, -126200,0.39520586,3.6289604,,,,,,,,,,,,,,,,, -126295,,,0.7226437330245972,1.4515371322631836,38.29304533888482,0.6909151673316956,1.599329710006714,30.73582258184852,3000.0,0.7114520072937012,1.4903055429458618,30.760997036011428,3003.0,44550.9940226078,73051.0552790165,44550.9940226078,28494.218544483185,1.867745876312256,0.0 -126300,0.38097957,3.6290138,,,,,,,,,,,,,,,,, -126400,0.38001528,3.640281,,,,,,,,,,,,,,,,, -126500,0.3724874,3.5681252,,,,,,,,,,,,,,,,, -126600,0.3814512,3.6367757,,,,,,,,,,,,,,,,, -126700,0.3721823,3.6071332,,,,,,,,,,,,,,,,, -126800,0.3739372,3.58134,,,,,,,,,,,,,,,,, -126900,0.40060765,3.6512074,,,,,,,,,,,,,,,,, -127000,0.38229224,3.5659635,,,,,,,,,,,,,,,,, -127100,0.3936613,3.5937335,,,,,,,,,,,,,,,,, -127200,0.3841741,3.6244352,,,,,,,,,,,,,,,,, -127300,0.3779478,3.6284702,,,,,,,,,,,,,,,,, -127400,0.4108704,3.7192202,,,,,,,,,,,,,,,,, -127500,0.3987561,3.587624,,,,,,,,,,,,,,,,, -127600,0.40243027,3.6474996,,,,,,,,,,,,,,,,, -127700,0.39704397,3.6137457,,,,,,,,,,,,,,,,, -127800,0.38548088,3.6241817,,,,,,,,,,,,,,,,, -127900,0.38270816,3.5765176,,,,,,,,,,,,,,,,, -128000,0.41218495,3.6145966,,,,,,,,,,,,,,,,, -128100,0.380929,3.5814009,,,,,,,,,,,,,,,,, -128200,0.38657755,3.5565438,,,,,,,,,,,,,,,,, -128300,0.3827588,3.5685728,,,,,,,,,,,,,,,,, -128400,0.381014,3.5921512,,,,,,,,,,,,,,,,, -128500,0.37834874,3.551559,,,,,,,,,,,,,,,,, -128600,0.3773635,3.612779,,,,,,,,,,,,,,,,, -128677,,,0.7253398895263672,1.437965750694275,38.15674927738666,0.6909027695655823,1.5986698865890503,30.539796950932832,3000.0,0.7110103964805603,1.489971160888672,30.711279312995767,3003.0,45391.09645104408,74407.53640341759,45391.09645104408,29010.47008228302,1.919017314910889,0.0 -128700,0.3907149,3.6168742,,,,,,,,,,,,,,,,, -128800,0.37178954,3.5683897,,,,,,,,,,,,,,,,, -128900,0.4003298,3.627666,,,,,,,,,,,,,,,,, -129000,0.38914979,3.6465828,,,,,,,,,,,,,,,,, -129100,0.3774921,3.5964274,,,,,,,,,,,,,,,,, -129200,0.39958218,3.6087646,,,,,,,,,,,,,,,,, -129300,0.3982173,3.6267152,,,,,,,,,,,,,,,,, -129400,0.38502932,3.6047597,,,,,,,,,,,,,,,,, -129500,0.37726882,3.569039,,,,,,,,,,,,,,,,, -129600,0.39687535,3.5945413,,,,,,,,,,,,,,,,, -129700,0.38755453,3.5726159,,,,,,,,,,,,,,,,, -129800,0.38304093,3.572586,,,,,,,,,,,,,,,,, -129900,0.41276506,3.6281776,,,,,,,,,,,,,,,,, -130000,0.36806455,3.5577834,,,,,,,,,,,,,,,,, -130100,0.38367194,3.5935006,,,,,,,,,,,,,,,,, -130200,0.39316258,3.6242142,,,,,,,,,,,,,,,,, -130300,0.38553393,3.604751,,,,,,,,,,,,,,,,, -130400,0.3866892,3.6249652,,,,,,,,,,,,,,,,, -130500,0.3809268,3.583261,,,,,,,,,,,,,,,,, -130600,0.38114983,3.6139102,,,,,,,,,,,,,,,,, -130700,0.36943656,3.575307,,,,,,,,,,,,,,,,, -130800,0.38099575,3.6172707,,,,,,,,,,,,,,,,, -130900,0.3939498,3.6334147,,,,,,,,,,,,,,,,, -131000,0.3832871,3.616247,,,,,,,,,,,,,,,,, -131059,,,0.724202036857605,1.443503499031067,38.04270301127323,0.6908655762672424,1.599054217338562,30.667733220414544,3000.0,0.7107431292533875,1.490159273147583,30.77321518935283,3003.0,46231.21893525124,75766.89620161057,46231.21893525124,29529.583674669266,1.9653499126434328,0.0 -131100,0.3920844,3.6367886,,,,,,,,,,,,,,,,, -131200,0.3888253,3.6246905,,,,,,,,,,,,,,,,, -131300,0.39512318,3.6153839,,,,,,,,,,,,,,,,, -131400,0.37972826,3.5915604,,,,,,,,,,,,,,,,, -131500,0.36819184,3.5638719,,,,,,,,,,,,,,,,, -131600,0.3931399,3.6467323,,,,,,,,,,,,,,,,, -131700,0.38934273,3.648162,,,,,,,,,,,,,,,,, -131800,0.37591052,3.6099474,,,,,,,,,,,,,,,,, -131900,0.36560667,3.5386164,,,,,,,,,,,,,,,,, -132000,0.3965348,3.6553662,,,,,,,,,,,,,,,,, -132100,0.39161786,3.5918312,,,,,,,,,,,,,,,,, -132200,0.37798938,3.6075318,,,,,,,,,,,,,,,,, -132300,0.39134595,3.6302502,,,,,,,,,,,,,,,,, -132400,0.3780763,3.5712361,,,,,,,,,,,,,,,,, -132500,0.3910327,3.6042457,,,,,,,,,,,,,,,,, -132600,0.39897823,3.627102,,,,,,,,,,,,,,,,, -132700,0.38583797,3.5985117,,,,,,,,,,,,,,,,, -132800,0.39989707,3.640325,,,,,,,,,,,,,,,,, -132900,0.3945519,3.615336,,,,,,,,,,,,,,,,, -133000,0.37851515,3.562757,,,,,,,,,,,,,,,,, -133100,0.39287302,3.6395683,,,,,,,,,,,,,,,,, -133200,0.37146458,3.615749,,,,,,,,,,,,,,,,, -133300,0.37000135,3.561571,,,,,,,,,,,,,,,,, -133333,,,0.7267444133758545,1.427182912826538,37.58713677051531,0.6907787919044495,1.5991333723068235,30.66343634971797,3000.0,0.7107315063476562,1.4902734756469729,30.82135427403008,3003.0,47033.39378809929,77092.4309284687,47033.39378809929,30052.82520580292,2.0094943046569824,0.0 -133333,,,,,,,,,,,,,,47033.39378809929,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/eval_measurements.csv deleted file mode 100644 index f06a0617f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -878.3704223632812,0.0,26.88200736045837,1,0,26.88200736045837,0.0007088489946909,0.0,11.19086742401123,3003,905.2524693012238,0.0005962961004115,0.0,11.175597190856934,0.0004835649742744,0.0,11.208685874938965,3000 -1604.2824800014496,0.0188047885894775,866.837769985199,2379,0,866.837769985199,0.375143826007843,7.492265560875259,4.313692092895508,3003,2471.213627576828,0.4058598577976227,14.225651382899391,3.990895748138428,0.3933243155479431,9.435176775149335,4.097593307495117,3000 -2088.90660238266,0.0470664501190185,1707.074345588684,4757,0,1707.074345588684,0.5409215092658997,18.88999064105362,2.6869993209838867,3003,3796.1771688461295,0.5393306016921997,24.873358742540944,2.7030396461486816,0.542708694934845,20.43187495336819,2.6487877368927,3000 -2560.031564474106,0.0735063552856445,2547.0738015174866,7136,0,2547.0738015174866,0.586462140083313,22.075670983383866,2.236505508422852,3003,5107.401966333389,0.5833343267440796,27.53477110939534,2.2597906589508057,0.5829437971115112,23.01984843027956,2.237121105194092,3000 -3023.601635694504,0.1044983863830566,3387.072328567505,9516,0,3387.072328567505,0.6093777418136597,23.438955386735778,2.014265537261963,3003,6411.077891111374,0.5903454422950745,28.423563496088704,2.161125421524048,0.60672527551651,24.557212452962293,2.036038398742676,3000 -3474.739589214325,0.130730390548706,4227.306416749954,11897,0,4227.306416749954,0.6275289058685303,24.511840943821287,1.8683427572250368,3003,7702.555572986603,0.6027729511260986,28.85630338895498,2.055947542190552,0.6205874681472778,26.28002901639568,1.9075512886047363,3000 -3980.2613196372986,0.1580393314361572,5067.252546310425,14279,0,5067.252546310425,0.6389285922050476,25.56158413081232,1.783985257148743,3003,9048.127070188522,0.6161245703697205,29.923540706085337,1.941306471824646,0.6305191516876221,26.47629077209134,1.827352523803711,3000 -4433.9902555942535,0.1892220973968505,5907.344487428665,16662,0,5907.344487428665,0.6466794610023499,26.0778676097258,1.7082695960998535,3003,10342.05437874794,0.6206269860267639,29.945450713323183,1.902586579322815,0.6397068500518799,27.15322636447869,1.7634724378585815,3000 -5199.53736281395,0.2170188426971435,6747.582292556763,19046,0,6747.582292556763,0.6498286128044128,26.317142488955124,1.6713584661483765,3003,11947.941531896591,0.641328752040863,31.378610385408912,1.730721354484558,0.6421619057655334,26.5024963346979,1.719527244567871,3000 -5681.113517045975,0.2463798522949218,7587.546016216278,21429,0,7587.546016216278,0.6571030020713806,26.55280495333868,1.642043113708496,3003,13269.586208820345,0.6279956698417664,30.87111135681036,1.8332117795944207,0.6472703218460083,27.717725941527146,1.690517783164978,3000 -6163.827691793442,0.2743749618530273,8427.509428739548,23812,0,8427.509428739548,0.6596711277961731,27.04909577649694,1.6173677444458008,3003,14592.367801189424,0.6292793154716492,31.07930974577942,1.8349435329437256,0.6511016488075256,27.76755747673576,1.673044204711914,3000 -6667.039098739624,0.3026127815246582,9267.73613333702,26196,0,9267.73613333702,0.662216067314148,27.36525046340449,1.600104570388794,3003,15935.911216020584,0.636094868183136,31.10254632386268,1.7777622938156128,0.6532343029975891,27.915533557356824,1.6610873937606812,3000 -7190.688607931137,0.3311488628387451,10107.839243650436,28581,0,10107.839243650436,0.6586834192276001,26.97462392898852,1.6259721517562866,3003,17299.766478300095,0.6279000639915466,30.217147413299987,1.8329882621765137,0.650444507598877,27.74585837118948,1.677872896194458,3000 -7703.6917362213135,0.3620364665985107,10947.769191265106,30964,0,10947.769191265106,0.666748046875,27.358672182603275,1.587527871131897,3003,18652.81289315224,0.6326751112937927,30.968724846774887,1.8045499324798584,0.6552057266235352,28.025201830906543,1.6448639631271362,3000 -8247.851219177246,0.391165018081665,11787.76186323166,33348,0,11787.76186323166,0.6678636074066162,27.742504843960525,1.5693237781524658,3003,20037.068594694138,0.6399797201156616,31.4189488530457,1.7495006322860718,0.6567928194999695,28.44525274290272,1.632373929023743,3000 -8822.155555486679,0.4218254089355469,12627.94023323059,35733,0,12627.94023323059,0.6670036911964417,27.611120742716544,1.5616695880889893,3003,21451.65731573105,0.636881947517395,31.0226532366646,1.780168533325195,0.6587890982627869,28.25602613523417,1.621066689491272,3000 -9305.749348640442,0.4531617164611816,13468.088101387024,38118,0,13468.088101387024,0.6677822470664978,27.36167683701787,1.551979899406433,3003,22775.50716114044,0.6518439054489136,32.117936466755715,1.66854727268219,0.6590867042541504,28.395664950065505,1.615612506866455,3000 -9882.452818155289,0.4833953380584717,14308.046524524689,40502,0,14308.046524524689,0.6714543104171753,28.16023044547092,1.5414283275604248,3003,24192.2749774456,0.6441013216972351,31.448439318045764,1.7263169288635254,0.6599918007850647,28.42181715596443,1.606161117553711,3000 -10409.5962266922,0.5133147239685059,15148.275726556778,42887,0,15148.275726556778,0.6733600497245789,28.174196940148576,1.5363410711288452,3003,25559.754118919373,0.6383638978004456,31.42605783809,1.7570757865905762,0.6582187414169312,28.302365767108448,1.6062474250793457,3000 -11000.106136083605,0.5449538230895996,15988.28655910492,45271,0,15988.28655910492,0.6741386651992798,27.84871434765751,1.5258550643920898,3003,26990.38134407997,0.6441451907157898,31.51585807077358,1.7160241603851318,0.6614177227020264,28.40174838792601,1.5945888757705688,3000 -11570.585859537125,0.5759387016296387,16828.508977890015,47656,0,16828.508977890015,0.674312949180603,28.13550040341368,1.5223636627197266,3003,28401.18878436089,0.6404057145118713,31.17807012669652,1.7511805295944214,0.6617772579193115,28.050460138448788,1.587974190711975,3000 -12135.090401649475,0.6083238124847412,17668.541675567627,50040,0,17668.541675567627,0.6757422685623169,27.995736464992472,1.516922116279602,3003,29805.83444905281,0.6673862934112549,33.58396302504942,1.58917498588562,0.6645546555519104,28.75264872649214,1.5788758993148804,3000 -12665.82209134102,0.6415340900421143,18508.65367841721,52425,0,18508.65367841721,0.6774504780769348,27.875327607675807,1.505465745925903,3003,31176.78390431404,0.6453410387039185,31.580043661207323,1.7113089561462402,0.6636619567871094,28.384619001282065,1.575170874595642,3000 -13179.79739499092,0.6752762794494629,19348.648066282272,54810,0,19348.648066282272,0.6790425181388855,28.44183774132029,1.4980599880218506,3003,32530.86065387726,0.6456965804100037,31.87085102078416,1.71318256855011,0.6642695069313049,28.47791587474528,1.5717233419418335,3000 -13681.056218147278,0.7081527709960938,20188.61677551269,57195,0,20188.61677551269,0.6807158589363098,28.4510571567512,1.4926679134368896,3003,33872.19545149803,0.6577078104019165,32.78105983753737,1.6320240497589111,0.667592465877533,29.03892667850046,1.5603337287902832,3000 -14264.738429784777,0.74080491065979,21028.747178077698,59580,0,21028.747178077698,0.6815292835235596,28.475886821190706,1.4816378355026243,3003,35296.116545677185,0.6465297937393188,32.05305332294858,1.694580316543579,0.6670469045639038,28.62439054702784,1.5555315017700195,3000 -14754.18920135498,0.7733290195465088,21868.823503017426,61965,0,21868.823503017426,0.6838417649269104,28.84332825064846,1.4731414318084717,3003,36625.751838207245,0.648226261138916,32.419402937612425,1.7013450860977173,0.6691423654556274,29.033741216883577,1.5469614267349243,3000 -15264.377411842346,0.8070766925811768,22708.90673828125,64349,0,22708.90673828125,0.6819010972976685,28.5848930702908,1.4711955785751345,3003,37976.13408732414,0.6545460820198059,32.617605237832414,1.6462438106536863,0.6699978709220886,29.007868109997457,1.5386850833892822,3000 -15738.510041713716,0.8472380638122559,23548.84869503975,66733,0,23548.84869503975,0.6844227910041809,28.940553718862866,1.462852954864502,3003,39290.32270479202,0.6512013673782349,32.131119778833856,1.6751712560653689,0.6714361906051636,29.08472915961638,1.5344096422195437,3000 -16280.106495141985,0.8839507102966309,24388.915120363235,69117,0,24388.915120363235,0.6858404874801636,29.00716191550685,1.4534401893615725,3003,40672.09962892532,0.6723154783248901,33.717381365500145,1.534518480300903,0.6729241013526917,29.173103041104703,1.5261874198913574,3000 -16826.37229347229,0.9204680919647216,25228.996363401413,71501,0,25228.996363401413,0.685910165309906,28.593936374173712,1.4481548070907593,3003,42058.55702161789,0.65537029504776,32.375148555377464,1.6473228931427002,0.6737300157546997,29.28750192928385,1.5176098346710205,3000 -17396.327433347702,0.9564275741577148,26069.110144138336,73886,0,26069.110144138336,0.6876416206359863,28.943381743475623,1.4385989904403689,3003,43468.737097263336,0.6524078845977783,32.701466533580295,1.6667128801345823,0.6737548112869263,29.32046658710141,1.5123512744903564,3000 -17861.11874818802,0.9940569400787354,26909.085456848145,76269,0,26909.085456848145,0.6899773478507996,29.44708155748432,1.4313315153121948,3003,44773.62126874924,0.6652563810348511,33.20367223177718,1.5872588157653809,0.6766065955162048,29.703277161865785,1.50298011302948,3000 -18400.16328859329,1.0356485843658447,27749.27845311165,78654,0,27749.27845311165,0.6922898292541504,29.830781468889445,1.42089581489563,3003,46152.97341346741,0.6580626368522644,32.67718189408843,1.6298712491989136,0.67764812707901,29.64695478586008,1.4988480806350708,3000 -19072.02882409096,1.0736279487609863,28589.20459485054,81039,0,28589.20459485054,0.6907210946083069,29.21138067716528,1.4177829027175903,3003,47664.87612128258,0.6602153778076172,32.57357972662718,1.6250991821289062,0.6790988445281982,29.84311518307996,1.4938795566558838,3000 -19616.61013817787,1.1113903522491455,29429.326312065125,83423,0,29429.326312065125,0.6939747929573059,29.7245642760348,1.405526041984558,3003,49049.69407105446,0.6681274175643921,32.9723800626729,1.57500159740448,0.6786648631095886,29.676601764832707,1.4860962629318235,3000 -20176.76056933403,1.147826910018921,30269.53056025505,85808,0,30269.53056025505,0.694195568561554,29.68178603516924,1.39683997631073,3003,50450.16050791741,0.6655429601669312,33.31576160365833,1.5882421731948853,0.6799171566963196,29.79216033426816,1.4798542261123655,3000 -20677.904341459274,1.1917221546173096,31109.656126499176,88193,0,31109.656126499176,0.6972517967224121,30.091496010180137,1.393563151359558,3003,51791.54896807671,0.6824063062667847,34.05213368530559,1.4828976392745972,0.6825209856033325,30.0387532654638,1.4713144302368164,3000 -21295.32447552681,1.2300746440887451,31949.85806655884,90578,0,31949.85806655884,0.6974958181381226,29.89253457886926,1.3812869787216189,3003,53249.28406596184,0.6718730330467224,33.41231035188514,1.5490468740463257,0.6825333833694458,30.102713832189707,1.4652026891708374,3000 -21805.096822977062,1.2744011878967283,32790.02381038666,92963,0,32790.02381038666,0.6990064382553101,30.2127779651062,1.3774524927139282,3003,54599.34149551392,0.6706978678703308,33.93765018553232,1.5593184232711792,0.6834261417388916,29.94001242273573,1.4623475074768066,3000 -22337.2406938076,1.3123936653137207,33630.12921476364,95347,0,33630.12921476364,0.6996804475784302,29.98158297653488,1.3698073625564575,3003,55971.70452427864,0.6808968186378479,34.68351336887508,1.4978055953979492,0.6850131750106812,30.293797765390387,1.4518790245056152,3000 -22831.28604865074,1.359565496444702,34470.28120446205,97731,0,34470.28120446205,0.7033641338348389,30.140830859153333,1.3601025342941284,3003,57306.02733922005,0.6694570183753967,33.762738030108665,1.563796043395996,0.6860547065734863,30.13158593552824,1.4455714225769043,3000 -23346.078429937363,1.3988215923309326,35310.500405311584,100116,0,35310.500405311584,0.7017953991889954,30.204320024066764,1.3589884042739868,3003,58661.151460170746,0.709169864654541,36.85112230721641,1.3537744283676147,0.6860795021057129,30.376215081158747,1.4437220096588137,3000 -23869.03919649124,1.439997673034668,36150.45610380173,102500,0,36150.45610380173,0.7032363414764404,30.55794406358037,1.3476167917251587,3003,60024.18383717537,0.6830866932868958,34.33374070683871,1.4718409776687622,0.6886585354804993,30.327454897877665,1.4326103925704956,3000 -24395.162294387817,1.479119062423706,36990.53392624855,104884,0,36990.53392624855,0.7036778926849365,30.484548709173747,1.347154140472412,3003,61390.49967384338,0.6831173896789551,34.37978318584616,1.4802350997924805,0.6883485317230225,30.31289615139876,1.4367409944534302,3000 -24910.0284614563,1.5198464393615725,37830.54970383644,107269,0,37830.54970383644,0.7058392763137817,30.684268005675964,1.339435338973999,3003,62745.49480700493,0.6964491009712219,35.48056756531287,1.404707908630371,0.6896876692771912,30.40983792345602,1.4275447130203247,3000 -25458.17726159096,1.561103343963623,38670.55089187622,109653,0,38670.55089187622,0.7071059346199036,30.78908048297665,1.3327090740203855,3003,64133.76025557518,0.6924313902854919,35.2917509143731,1.4321004152297974,0.6913739442825317,30.570859095014143,1.4228721857070925,3000 -25961.32877779007,1.60233473777771,39510.55522322655,112037,0,39510.55522322655,0.7082796096801758,30.976652739571417,1.3301644325256348,3003,65477.0303850174,0.6871734857559204,35.0512106849693,1.4625294208526611,0.6905679702758789,30.782762669814545,1.4227373600006104,3000 -26473.64466929436,1.644965410232544,40350.648983716965,114421,0,40350.648983716965,0.7078379988670349,30.80652754405374,1.3299546241760254,3003,66829.55691623688,0.699521541595459,36.07696290784867,1.3912333250045776,0.6913739442825317,30.715100454032235,1.4179686307907104,3000 -27004.1587600708,1.687828779220581,41190.77096199989,116804,0,41190.77096199989,0.7087211608886719,30.802705175820623,1.3230842351913452,3003,68200.31424498558,0.6963258385658264,35.96184061155457,1.4070066213607788,0.6927750110626221,30.68112671416496,1.415552258491516,3000 -27537.889677286148,1.728813409805298,42030.960902929306,119188,0,42030.960902929306,0.7092092633247375,30.994883238838508,1.32043719291687,3003,69574.35052037239,0.7110795378684998,37.20115594615644,1.327668070793152,0.6929734349250793,30.86193975348537,1.4138745069503784,3000 -28090.25924921036,1.7727289199829102,42870.91543865204,121572,0,42870.91543865204,0.7098135352134705,30.807151339414787,1.3198758363723757,3003,70966.79245257378,0.7089225053787231,36.699779532755656,1.3409868478775024,0.6929982304573059,30.743967260596456,1.4128586053848269,3000 -28615.730256080627,1.814807415008545,43710.88386774063,123956,0,43710.88386774063,0.7102085947990417,31.06983647732182,1.3139033317565918,3003,72332.34792470932,0.7044199109077454,36.63037870395407,1.3619587421417236,0.6941761374473572,30.92822877713484,1.409651756286621,3000 -29137.58055591584,1.8581435680389404,44550.9148273468,126339,0,44550.9148273468,0.7101853489875793,30.995070613341326,1.3157200813293457,3003,73694.35182857513,0.7150582075119019,36.93187865217039,1.3104259967803955,0.6944612860679626,30.66054748234449,1.411051869392395,3000 -29686.40299320221,1.902180433273316,45391.13405776024,128723,0,45391.13405776024,0.7109407186508179,31.110231469925782,1.3132773637771606,3003,75083.51441955566,0.7099118232727051,36.93021401210776,1.337271809577942,0.6949200630187988,30.931063555305773,1.4093292951583862,3000 -30232.55574464798,1.945484161376953,46231.019748449326,131106,0,46231.019748449326,0.710777997970581,31.17334827670532,1.3123667240142822,3003,76469.67206168175,0.7094533443450928,36.79860717293863,1.3412723541259766,0.6949820518493652,30.92847868651178,1.4096693992614746,3000 -30766.507289409637,1.9903368949890137,47015.76848649979,133333,0,47015.76848649979,0.7105455994606018,31.11336078741953,1.312601923942566,3003,77788.48857069016,0.7125383019447327,36.68598863876687,1.3173706531524658,0.6947216987609863,30.934123880652553,1.4095044136047363,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/measurements.csv deleted file mode 100644 index 5661e0948..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.1584854,11.173096,,,,,,,,,,,,,,,,, -1,,,0.0005962961004115,11.175597190856934,0.0,0.0004835649742744,11.208685874938965,0.0,3000.0,0.0007088489946909,11.19086742401123,0.0,3003.0,26.88200736045837,905.2524693012238,26.88200736045837,878.3704223632812,0.0,0.0 -100,0.48677242,8.652959,,,,,,,,,,,,,,,,, -200,0.17784809,8.261594,,,,,,,,,,,,,,,,, -300,0.19116548,8.004621,,,,,,,,,,,,,,,,, -400,0.32695454,7.603524,,,,,,,,,,,,,,,,, -500,0.42935234,7.2227,,,,,,,,,,,,,,,,, -600,0.86744225,6.959226,,,,,,,,,,,,,,,,, -700,0.7864968,6.7341805,,,,,,,,,,,,,,,,, -800,0.57157576,6.4225173,,,,,,,,,,,,,,,,, -900,0.587758,6.2873297,,,,,,,,,,,,,,,,, -1000,0.5786174,5.991468,,,,,,,,,,,,,,,,, -1100,0.64687693,5.7553263,,,,,,,,,,,,,,,,, -1200,0.82211345,5.595517,,,,,,,,,,,,,,,,, -1300,0.6952407,5.4331155,,,,,,,,,,,,,,,,, -1400,0.7274879,5.2476826,,,,,,,,,,,,,,,,, -1500,0.78069323,5.0791206,,,,,,,,,,,,,,,,, -1600,0.69493365,4.917864,,,,,,,,,,,,,,,,, -1700,0.7181932,4.7420974,,,,,,,,,,,,,,,,, -1800,0.88924205,4.6800594,,,,,,,,,,,,,,,,, -1900,1.1575704,4.5513387,,,,,,,,,,,,,,,,, -2000,0.8849507,4.5247746,,,,,,,,,,,,,,,,, -2100,0.87966794,4.3206935,,,,,,,,,,,,,,,,, -2200,0.65073377,4.187289,,,,,,,,,,,,,,,,, -2300,1.7135797,4.1238055,,,,,,,,,,,,,,,,, -2379,,,0.4058598577976227,3.990895748138428,14.225651382899391,0.3933243155479431,4.097593307495117,9.435176775149335,3000.0,0.375143826007843,4.313692092895508,7.492265560875259,3003.0,866.837769985199,2471.213627576828,866.837769985199,1604.2824800014496,0.0188047885894775,0.0 -2400,1.178652,4.0291867,,,,,,,,,,,,,,,,, -2500,1.1551702,3.9276261,,,,,,,,,,,,,,,,, -2600,0.8621888,3.7073448,,,,,,,,,,,,,,,,, -2700,0.7915192,3.6848323,,,,,,,,,,,,,,,,, -2800,0.8885874,3.5938187,,,,,,,,,,,,,,,,, -2900,1.0712729,3.521239,,,,,,,,,,,,,,,,, -3000,0.98927367,3.5092945,,,,,,,,,,,,,,,,, -3100,0.88668084,3.3548667,,,,,,,,,,,,,,,,, -3200,0.88895166,3.2800548,,,,,,,,,,,,,,,,, -3300,0.8547276,3.2627742,,,,,,,,,,,,,,,,, -3400,0.8212223,3.2239745,,,,,,,,,,,,,,,,, -3500,0.84374195,3.1363707,,,,,,,,,,,,,,,,, -3600,0.66167766,3.114939,,,,,,,,,,,,,,,,, -3700,0.7704966,2.9949017,,,,,,,,,,,,,,,,, -3800,0.81578505,3.0500724,,,,,,,,,,,,,,,,, -3900,0.73811865,2.9959931,,,,,,,,,,,,,,,,, -4000,0.69582516,2.9841707,,,,,,,,,,,,,,,,, -4100,0.72958845,2.8491647,,,,,,,,,,,,,,,,, -4200,0.69009423,2.8629415,,,,,,,,,,,,,,,,, -4300,0.607979,2.7843516,,,,,,,,,,,,,,,,, -4400,0.700057,2.8010337,,,,,,,,,,,,,,,,, -4500,0.79860324,2.7178938,,,,,,,,,,,,,,,,, -4600,0.8457116,2.7669427,,,,,,,,,,,,,,,,, -4700,0.66635007,2.5736806,,,,,,,,,,,,,,,,, -4757,,,0.5393306016921997,2.7030396461486816,24.873358742540944,0.542708694934845,2.6487877368927,20.43187495336819,3000.0,0.5409215092658997,2.6869993209838867,18.88999064105362,3003.0,1707.074345588684,3796.1771688461295,1707.074345588684,2088.90660238266,0.0470664501190185,0.0 -4800,0.6437942,2.6650753,,,,,,,,,,,,,,,,, -4900,0.59739953,2.6593156,,,,,,,,,,,,,,,,, -5000,0.58783984,2.6397483,,,,,,,,,,,,,,,,, -5100,0.62085027,2.633085,,,,,,,,,,,,,,,,, -5200,0.821644,2.6048148,,,,,,,,,,,,,,,,, -5300,0.60358936,2.5327475,,,,,,,,,,,,,,,,, -5400,0.55349624,2.5411584,,,,,,,,,,,,,,,,, -5500,0.60505307,2.435639,,,,,,,,,,,,,,,,, -5600,0.6798263,2.5379643,,,,,,,,,,,,,,,,, -5700,0.7974442,2.484939,,,,,,,,,,,,,,,,, -5800,0.5671144,2.54453,,,,,,,,,,,,,,,,, -5900,0.67384976,2.502227,,,,,,,,,,,,,,,,, -6000,0.6123608,2.414762,,,,,,,,,,,,,,,,, -6100,0.5678585,2.4216514,,,,,,,,,,,,,,,,, -6200,0.5692446,2.4776607,,,,,,,,,,,,,,,,, -6300,0.5912674,2.435577,,,,,,,,,,,,,,,,, -6400,0.50264347,2.4538546,,,,,,,,,,,,,,,,, -6500,0.52253145,2.4028306,,,,,,,,,,,,,,,,, -6600,0.52777517,2.300141,,,,,,,,,,,,,,,,, -6700,0.53493536,2.3111868,,,,,,,,,,,,,,,,, -6800,0.49996144,2.3951948,,,,,,,,,,,,,,,,, -6900,0.46280524,2.3099604,,,,,,,,,,,,,,,,, -7000,0.46996647,2.3834746,,,,,,,,,,,,,,,,, -7100,0.56858754,2.4194736,,,,,,,,,,,,,,,,, -7136,,,0.5833343267440796,2.2597906589508057,27.53477110939534,0.5829437971115112,2.237121105194092,23.01984843027956,3000.0,0.586462140083313,2.236505508422852,22.075670983383866,3003.0,2547.0738015174866,5107.401966333389,2547.0738015174866,2560.031564474106,0.0735063552856445,0.0 -7200,0.5286665,2.3378992,,,,,,,,,,,,,,,,, -7300,0.45813715,2.26374,,,,,,,,,,,,,,,,, -7400,0.42627537,2.308846,,,,,,,,,,,,,,,,, -7500,0.49456242,2.2886302,,,,,,,,,,,,,,,,, -7600,0.52072686,2.3696659,,,,,,,,,,,,,,,,, -7700,0.6029446,2.2284436,,,,,,,,,,,,,,,,, -7800,0.42946082,2.142279,,,,,,,,,,,,,,,,, -7900,0.46510872,2.3102717,,,,,,,,,,,,,,,,, -8000,0.53597665,2.2424388,,,,,,,,,,,,,,,,, -8100,0.45774812,2.1647835,,,,,,,,,,,,,,,,, -8200,0.48682728,2.1537976,,,,,,,,,,,,,,,,, -8300,0.50031835,2.307095,,,,,,,,,,,,,,,,, -8400,0.51817685,2.233416,,,,,,,,,,,,,,,,, -8500,0.38009372,2.182737,,,,,,,,,,,,,,,,, -8600,0.5281151,2.2427564,,,,,,,,,,,,,,,,, -8700,0.3876056,2.187864,,,,,,,,,,,,,,,,, -8800,0.39042938,2.1595268,,,,,,,,,,,,,,,,, -8900,0.429746,2.1796453,,,,,,,,,,,,,,,,, -9000,0.40556905,2.2387652,,,,,,,,,,,,,,,,, -9100,0.3741479,2.194949,,,,,,,,,,,,,,,,, -9200,0.3641271,2.1944966,,,,,,,,,,,,,,,,, -9300,0.37198928,2.1178815,,,,,,,,,,,,,,,,, -9400,0.3616983,2.1621623,,,,,,,,,,,,,,,,, -9500,0.35213298,2.1253395,,,,,,,,,,,,,,,,, -9516,,,0.5903454422950745,2.161125421524048,28.423563496088704,0.60672527551651,2.036038398742676,24.557212452962293,3000.0,0.6093777418136597,2.014265537261963,23.438955386735778,3003.0,3387.072328567505,6411.077891111374,3387.072328567505,3023.601635694504,0.1044983863830566,0.0 -9600,0.36596465,2.0795758,,,,,,,,,,,,,,,,, -9700,0.3572914,2.055694,,,,,,,,,,,,,,,,, -9800,0.37259224,2.0939147,,,,,,,,,,,,,,,,, -9900,0.33680245,2.222411,,,,,,,,,,,,,,,,, -10000,0.3243034,2.1339774,,,,,,,,,,,,,,,,, -10100,0.3109046,2.0891757,,,,,,,,,,,,,,,,, -10200,0.32097128,2.0778694,,,,,,,,,,,,,,,,, -10300,0.31126896,2.1017802,,,,,,,,,,,,,,,,, -10400,0.4171325,2.0913284,,,,,,,,,,,,,,,,, -10500,0.36179388,2.0325093,,,,,,,,,,,,,,,,, -10600,0.30427,2.148152,,,,,,,,,,,,,,,,, -10700,0.33198154,2.07988,,,,,,,,,,,,,,,,, -10800,0.3128207,2.0358872,,,,,,,,,,,,,,,,, -10900,0.4124844,2.183565,,,,,,,,,,,,,,,,, -11000,0.30900526,2.142721,,,,,,,,,,,,,,,,, -11100,0.34466556,2.092085,,,,,,,,,,,,,,,,, -11200,0.31564286,1.9912994,,,,,,,,,,,,,,,,, -11300,0.38112983,2.1647484,,,,,,,,,,,,,,,,, -11400,0.26760128,1.9877716,,,,,,,,,,,,,,,,, -11500,0.32356074,2.1704001,,,,,,,,,,,,,,,,, -11600,0.28567556,2.048138,,,,,,,,,,,,,,,,, -11700,0.30695868,2.0120766,,,,,,,,,,,,,,,,, -11800,0.3101881,2.0599608,,,,,,,,,,,,,,,,, -11897,,,0.6027729511260986,2.055947542190552,28.85630338895498,0.6205874681472778,1.9075512886047363,26.28002901639568,3000.0,0.6275289058685303,1.8683427572250368,24.511840943821287,3003.0,4227.306416749954,7702.555572986603,4227.306416749954,3474.739589214325,0.130730390548706,0.0 -11900,0.30103627,2.0525713,,,,,,,,,,,,,,,,, -12000,0.32699358,2.192074,,,,,,,,,,,,,,,,, -12100,0.29394442,1.9849443,,,,,,,,,,,,,,,,, -12200,0.2892164,1.9589218,,,,,,,,,,,,,,,,, -12300,0.27389914,2.057032,,,,,,,,,,,,,,,,, -12400,0.2580455,2.0508757,,,,,,,,,,,,,,,,, -12500,0.31083286,1.9835061,,,,,,,,,,,,,,,,, -12600,0.39585108,2.029704,,,,,,,,,,,,,,,,, -12700,0.27393788,2.069589,,,,,,,,,,,,,,,,, -12800,0.27489063,2.073074,,,,,,,,,,,,,,,,, -12900,0.25647777,2.0305386,,,,,,,,,,,,,,,,, -13000,0.25755346,2.0185943,,,,,,,,,,,,,,,,, -13100,0.30379105,2.013145,,,,,,,,,,,,,,,,, -13200,0.3306728,2.0360756,,,,,,,,,,,,,,,,, -13300,0.27538705,2.0422082,,,,,,,,,,,,,,,,, -13400,0.3404995,2.0171251,,,,,,,,,,,,,,,,, -13500,0.2533542,2.0105503,,,,,,,,,,,,,,,,, -13600,0.2918167,1.9498813,,,,,,,,,,,,,,,,, -13700,0.30751964,2.0671453,,,,,,,,,,,,,,,,, -13800,0.28493392,2.0514565,,,,,,,,,,,,,,,,, -13900,0.25518605,1.955368,,,,,,,,,,,,,,,,, -14000,0.27823117,2.019453,,,,,,,,,,,,,,,,, -14100,0.28816804,1.9725865,,,,,,,,,,,,,,,,, -14200,0.2901635,1.9401327,,,,,,,,,,,,,,,,, -14279,,,0.6161245703697205,1.941306471824646,29.923540706085337,0.6305191516876221,1.827352523803711,26.47629077209134,3000.0,0.6389285922050476,1.783985257148743,25.56158413081232,3003.0,5067.252546310425,9048.127070188522,5067.252546310425,3980.2613196372986,0.1580393314361572,0.0 -14300,0.29392064,2.0257537,,,,,,,,,,,,,,,,, -14400,0.327857,1.9340107,,,,,,,,,,,,,,,,, -14500,0.34180474,1.9091586,,,,,,,,,,,,,,,,, -14600,0.3417486,1.9942814,,,,,,,,,,,,,,,,, -14700,0.35114866,2.0022721,,,,,,,,,,,,,,,,, -14800,0.2962665,1.9410243,,,,,,,,,,,,,,,,, -14900,0.3382628,2.0609121,,,,,,,,,,,,,,,,, -15000,0.33939695,1.9647717,,,,,,,,,,,,,,,,, -15100,0.3207996,1.9900873,,,,,,,,,,,,,,,,, -15200,0.30829185,1.9004948,,,,,,,,,,,,,,,,, -15300,0.26689294,2.0018857,,,,,,,,,,,,,,,,, -15400,0.31002188,2.0154724,,,,,,,,,,,,,,,,, -15500,0.30581015,1.9682726,,,,,,,,,,,,,,,,, -15600,0.27996463,1.935804,,,,,,,,,,,,,,,,, -15700,0.30307376,1.981131,,,,,,,,,,,,,,,,, -15800,0.40306452,1.8985318,,,,,,,,,,,,,,,,, -15900,0.28207776,1.8889315,,,,,,,,,,,,,,,,, -16000,0.28477257,1.9404881,,,,,,,,,,,,,,,,, -16100,0.3118978,1.9062464,,,,,,,,,,,,,,,,, -16200,0.32547957,1.9338263,,,,,,,,,,,,,,,,, -16300,0.29609317,1.8851569,,,,,,,,,,,,,,,,, -16400,0.33479393,1.9176363,,,,,,,,,,,,,,,,, -16500,0.360775,1.9736307,,,,,,,,,,,,,,,,, -16600,0.30631882,1.9184465,,,,,,,,,,,,,,,,, -16662,,,0.6206269860267639,1.902586579322815,29.945450713323183,0.6397068500518799,1.7634724378585815,27.15322636447869,3000.0,0.6466794610023499,1.7082695960998535,26.0778676097258,3003.0,5907.344487428665,10342.05437874794,5907.344487428665,4433.9902555942535,0.1892220973968505,0.0 -16700,0.38678473,1.8685464,,,,,,,,,,,,,,,,, -16800,0.33045828,1.9315546,,,,,,,,,,,,,,,,, -16900,0.37358886,1.9144057,,,,,,,,,,,,,,,,, -17000,0.28952035,1.9108689,,,,,,,,,,,,,,,,, -17100,0.30021882,1.9176188,,,,,,,,,,,,,,,,, -17200,0.61890125,2.0065465,,,,,,,,,,,,,,,,, -17300,0.37876764,1.8920652,,,,,,,,,,,,,,,,, -17400,0.29982084,1.882604,,,,,,,,,,,,,,,,, -17500,0.30967,2.0012727,,,,,,,,,,,,,,,,, -17600,0.32484174,1.9107256,,,,,,,,,,,,,,,,, -17700,0.35663205,1.907659,,,,,,,,,,,,,,,,, -17800,0.3431147,1.9162616,,,,,,,,,,,,,,,,, -17900,0.32596868,1.8779393,,,,,,,,,,,,,,,,, -18000,0.3039833,1.7842643,,,,,,,,,,,,,,,,, -18100,0.47184512,1.9122518,,,,,,,,,,,,,,,,, -18200,0.33981058,1.9582772,,,,,,,,,,,,,,,,, -18300,0.3536119,1.8558997,,,,,,,,,,,,,,,,, -18400,0.30618817,1.9153202,,,,,,,,,,,,,,,,, -18500,0.34099418,1.8774165,,,,,,,,,,,,,,,,, -18600,0.313475,1.8643749,,,,,,,,,,,,,,,,, -18700,0.36423025,1.8627324,,,,,,,,,,,,,,,,, -18800,0.43557858,1.9291594,,,,,,,,,,,,,,,,, -18900,0.3590957,1.8852915,,,,,,,,,,,,,,,,, -19000,0.37784928,1.8734146,,,,,,,,,,,,,,,,, -19046,,,0.641328752040863,1.730721354484558,31.378610385408912,0.6421619057655334,1.719527244567871,26.5024963346979,3000.0,0.6498286128044128,1.6713584661483765,26.317142488955124,3003.0,6747.582292556763,11947.941531896591,6747.582292556763,5199.53736281395,0.2170188426971435,0.0 -19100,0.3634989,1.9372413,,,,,,,,,,,,,,,,, -19200,0.42038423,1.901368,,,,,,,,,,,,,,,,, -19300,0.35205483,1.8569356,,,,,,,,,,,,,,,,, -19400,0.39685678,1.9366904,,,,,,,,,,,,,,,,, -19500,0.3236896,1.8378751,,,,,,,,,,,,,,,,, -19600,0.35868266,1.9244015,,,,,,,,,,,,,,,,, -19700,0.60863036,1.9334859,,,,,,,,,,,,,,,,, -19800,0.40446064,1.9115481,,,,,,,,,,,,,,,,, -19900,0.35869527,1.8251581,,,,,,,,,,,,,,,,, -20000,0.359728,1.8819352,,,,,,,,,,,,,,,,, -20100,0.38535982,1.9462775,,,,,,,,,,,,,,,,, -20200,0.37090164,1.7887896,,,,,,,,,,,,,,,,, -20300,0.41803333,1.7948157,,,,,,,,,,,,,,,,, -20400,0.41025773,1.9216176,,,,,,,,,,,,,,,,, -20500,0.32341126,1.8286377,,,,,,,,,,,,,,,,, -20600,0.3271865,1.7453756,,,,,,,,,,,,,,,,, -20700,0.35157207,1.893852,,,,,,,,,,,,,,,,, -20800,0.37226364,1.8531392,,,,,,,,,,,,,,,,, -20900,0.41518855,1.888725,,,,,,,,,,,,,,,,, -21000,0.41523904,1.8107882,,,,,,,,,,,,,,,,, -21100,0.37490034,1.8030095,,,,,,,,,,,,,,,,, -21200,0.41520408,1.8792993,,,,,,,,,,,,,,,,, -21300,0.34695423,1.8832308,,,,,,,,,,,,,,,,, -21400,0.38152573,1.8789945,,,,,,,,,,,,,,,,, -21429,,,0.6279956698417664,1.8332117795944207,30.87111135681036,0.6472703218460083,1.690517783164978,27.717725941527146,3000.0,0.6571030020713806,1.642043113708496,26.55280495333868,3003.0,7587.546016216278,13269.586208820345,7587.546016216278,5681.113517045975,0.2463798522949218,0.0 -21500,0.38289934,1.8902884,,,,,,,,,,,,,,,,, -21600,0.392081,1.8999078,,,,,,,,,,,,,,,,, -21700,0.3748583,1.8520892,,,,,,,,,,,,,,,,, -21800,0.38915604,1.8140892,,,,,,,,,,,,,,,,, -21900,0.40741056,1.7983891,,,,,,,,,,,,,,,,, -22000,0.39608115,1.899838,,,,,,,,,,,,,,,,, -22100,0.4078421,1.7785093,,,,,,,,,,,,,,,,, -22200,0.36088574,1.8231934,,,,,,,,,,,,,,,,, -22300,0.3671069,1.8334912,,,,,,,,,,,,,,,,, -22400,0.39732224,1.8382792,,,,,,,,,,,,,,,,, -22500,0.37452698,1.7941583,,,,,,,,,,,,,,,,, -22600,0.3801025,1.8715994,,,,,,,,,,,,,,,,, -22700,0.4991032,1.8171527,,,,,,,,,,,,,,,,, -22800,0.35194173,1.7859159,,,,,,,,,,,,,,,,, -22900,0.3462113,1.8437666,,,,,,,,,,,,,,,,, -23000,0.35833508,1.8174279,,,,,,,,,,,,,,,,, -23100,0.4204698,1.7941525,,,,,,,,,,,,,,,,, -23200,0.47542745,1.8700124,,,,,,,,,,,,,,,,, -23300,0.38790122,1.9228466,,,,,,,,,,,,,,,,, -23400,0.40826443,1.8977333,,,,,,,,,,,,,,,,, -23500,0.41362056,1.7577422,,,,,,,,,,,,,,,,, -23600,0.44213405,1.82952,,,,,,,,,,,,,,,,, -23700,0.39119294,1.7486624,,,,,,,,,,,,,,,,, -23800,0.36655405,1.8100522,,,,,,,,,,,,,,,,, -23812,,,0.6292793154716492,1.8349435329437256,31.07930974577942,0.6511016488075256,1.673044204711914,27.76755747673576,3000.0,0.6596711277961731,1.6173677444458008,27.04909577649694,3003.0,8427.509428739548,14592.367801189424,8427.509428739548,6163.827691793442,0.2743749618530273,0.0 -23900,0.41771308,1.8680489,,,,,,,,,,,,,,,,, -24000,0.39692461,1.752272,,,,,,,,,,,,,,,,, -24100,0.37743738,1.7997782,,,,,,,,,,,,,,,,, -24200,0.3798668,1.8712548,,,,,,,,,,,,,,,,, -24300,0.38307652,1.8912233,,,,,,,,,,,,,,,,, -24400,0.43138602,1.8446182,,,,,,,,,,,,,,,,, -24500,0.5037142,1.8491312,,,,,,,,,,,,,,,,, -24600,0.43086535,1.8441589,,,,,,,,,,,,,,,,, -24700,0.4327395,1.8708988,,,,,,,,,,,,,,,,, -24800,0.40055618,1.8713402,,,,,,,,,,,,,,,,, -24900,0.38162968,1.8271147,,,,,,,,,,,,,,,,, -25000,0.4030435,1.7815907,,,,,,,,,,,,,,,,, -25100,0.3737698,1.8863047,,,,,,,,,,,,,,,,, -25200,0.38883817,1.8145229,,,,,,,,,,,,,,,,, -25300,0.42710638,1.7910172,,,,,,,,,,,,,,,,, -25400,0.58063674,1.8350562,,,,,,,,,,,,,,,,, -25500,0.42889616,1.8482065,,,,,,,,,,,,,,,,, -25600,0.40092447,1.824826,,,,,,,,,,,,,,,,, -25700,0.37000367,1.9044336,,,,,,,,,,,,,,,,, -25800,0.3633836,1.8537904,,,,,,,,,,,,,,,,, -25900,0.404605,1.8092744,,,,,,,,,,,,,,,,, -26000,0.61530304,1.8343548,,,,,,,,,,,,,,,,, -26100,0.40362698,1.8618325,,,,,,,,,,,,,,,,, -26196,,,0.636094868183136,1.7777622938156128,31.10254632386268,0.6532343029975891,1.6610873937606812,27.915533557356824,3000.0,0.662216067314148,1.600104570388794,27.36525046340449,3003.0,9267.73613333702,15935.911216020584,9267.73613333702,6667.039098739624,0.3026127815246582,0.0 -26200,0.5501435,1.8651053,,,,,,,,,,,,,,,,, -26300,0.40648836,1.8346719,,,,,,,,,,,,,,,,, -26400,0.3826982,1.7578374,,,,,,,,,,,,,,,,, -26500,0.403186,1.78077,,,,,,,,,,,,,,,,, -26600,0.39663747,1.7556419,,,,,,,,,,,,,,,,, -26700,0.40160698,1.7998453,,,,,,,,,,,,,,,,, -26800,0.3583477,1.7632376,,,,,,,,,,,,,,,,, -26900,0.41966486,1.7849087,,,,,,,,,,,,,,,,, -27000,0.41339862,1.8223412,,,,,,,,,,,,,,,,, -27100,0.42731348,1.8051724,,,,,,,,,,,,,,,,, -27200,0.3579623,1.7645471,,,,,,,,,,,,,,,,, -27300,0.37514296,1.8048576,,,,,,,,,,,,,,,,, -27400,0.48583898,1.7926708,,,,,,,,,,,,,,,,, -27500,0.52191615,1.8502008,,,,,,,,,,,,,,,,, -27600,0.43421212,1.7785599,,,,,,,,,,,,,,,,, -27700,0.3892622,1.8168309,,,,,,,,,,,,,,,,, -27800,0.40529212,1.7786474,,,,,,,,,,,,,,,,, -27900,0.43069056,1.8338426,,,,,,,,,,,,,,,,, -28000,0.42371768,1.8546191,,,,,,,,,,,,,,,,, -28100,0.4352614,1.8374469,,,,,,,,,,,,,,,,, -28200,0.40271235,1.851758,,,,,,,,,,,,,,,,, -28300,0.5655062,1.8421636,,,,,,,,,,,,,,,,, -28400,0.41160882,1.8205363,,,,,,,,,,,,,,,,, -28500,0.66902506,1.8743638,,,,,,,,,,,,,,,,, -28581,,,0.6279000639915466,1.8329882621765137,30.217147413299987,0.650444507598877,1.677872896194458,27.74585837118948,3000.0,0.6586834192276001,1.6259721517562866,26.97462392898852,3003.0,10107.839243650436,17299.766478300095,10107.839243650436,7190.688607931137,0.3311488628387451,0.0 -28600,0.457318,1.8222944,,,,,,,,,,,,,,,,, -28700,0.40039057,1.747164,,,,,,,,,,,,,,,,, -28800,0.4201367,1.839744,,,,,,,,,,,,,,,,, -28900,0.39259747,1.7415845,,,,,,,,,,,,,,,,, -29000,0.3571498,1.8933114,,,,,,,,,,,,,,,,, -29100,0.41857934,1.7249877,,,,,,,,,,,,,,,,, -29200,0.3663897,1.7858815,,,,,,,,,,,,,,,,, -29300,0.3757193,1.7920502,,,,,,,,,,,,,,,,, -29400,0.39011994,1.9059961,,,,,,,,,,,,,,,,, -29500,0.39258903,1.8753399,,,,,,,,,,,,,,,,, -29600,0.3919989,1.8121531,,,,,,,,,,,,,,,,, -29700,0.4013027,1.7950308,,,,,,,,,,,,,,,,, -29800,0.3378596,1.7697642,,,,,,,,,,,,,,,,, -29900,0.39008945,1.7745363,,,,,,,,,,,,,,,,, -30000,0.36566493,1.744128,,,,,,,,,,,,,,,,, -30100,0.39109915,1.7890841,,,,,,,,,,,,,,,,, -30200,0.44662115,1.8367636,,,,,,,,,,,,,,,,, -30300,0.40647686,1.810891,,,,,,,,,,,,,,,,, -30400,0.41132724,1.7839454,,,,,,,,,,,,,,,,, -30500,0.42037845,1.7854751,,,,,,,,,,,,,,,,, -30600,0.4067414,1.835292,,,,,,,,,,,,,,,,, -30700,0.41654477,1.8538008,,,,,,,,,,,,,,,,, -30800,0.39724103,1.8277513,,,,,,,,,,,,,,,,, -30900,0.37085405,1.6953804,,,,,,,,,,,,,,,,, -30964,,,0.6326751112937927,1.8045499324798584,30.968724846774887,0.6552057266235352,1.6448639631271362,28.025201830906543,3000.0,0.666748046875,1.587527871131897,27.358672182603275,3003.0,10947.769191265106,18652.81289315224,10947.769191265106,7703.6917362213135,0.3620364665985107,0.0 -31000,0.43320444,1.8477826,,,,,,,,,,,,,,,,, -31100,0.41192,1.7705427,,,,,,,,,,,,,,,,, -31200,0.41489807,1.7874032,,,,,,,,,,,,,,,,, -31300,0.37101838,1.8705412,,,,,,,,,,,,,,,,, -31400,0.38933378,1.835221,,,,,,,,,,,,,,,,, -31500,0.39278913,1.7965847,,,,,,,,,,,,,,,,, -31600,0.42457286,1.7546272,,,,,,,,,,,,,,,,, -31700,0.41318867,1.8212212,,,,,,,,,,,,,,,,, -31800,0.44631237,1.6648537,,,,,,,,,,,,,,,,, -31900,0.41859835,1.798849,,,,,,,,,,,,,,,,, -32000,0.4036398,1.8180171,,,,,,,,,,,,,,,,, -32100,0.39442658,1.7724552,,,,,,,,,,,,,,,,, -32200,0.48045346,1.8136489,,,,,,,,,,,,,,,,, -32300,0.36711672,1.8396841,,,,,,,,,,,,,,,,, -32400,0.37635735,1.7575737,,,,,,,,,,,,,,,,, -32500,0.3937444,1.7584442,,,,,,,,,,,,,,,,, -32600,0.38397387,1.8400596,,,,,,,,,,,,,,,,, -32700,0.40512627,1.8187157,,,,,,,,,,,,,,,,, -32800,0.43719837,1.8663191,,,,,,,,,,,,,,,,, -32900,0.41199282,1.7827305,,,,,,,,,,,,,,,,, -33000,0.4206243,1.7571175,,,,,,,,,,,,,,,,, -33100,0.440662,1.8249199,,,,,,,,,,,,,,,,, -33200,0.4135907,1.7889037,,,,,,,,,,,,,,,,, -33300,0.42421204,1.7534817,,,,,,,,,,,,,,,,, -33348,,,0.6399797201156616,1.7495006322860718,31.4189488530457,0.6567928194999695,1.632373929023743,28.44525274290272,3000.0,0.6678636074066162,1.5693237781524658,27.742504843960525,3003.0,11787.76186323166,20037.068594694138,11787.76186323166,8247.851219177246,0.391165018081665,0.0 -33400,0.4827853,1.8100107,,,,,,,,,,,,,,,,, -33500,0.40017575,1.7903081,,,,,,,,,,,,,,,,, -33600,0.42632148,1.7478173,,,,,,,,,,,,,,,,, -33700,0.3916099,1.6643604,,,,,,,,,,,,,,,,, -33800,0.4882535,1.6937276,,,,,,,,,,,,,,,,, -33900,0.39735374,1.7748226,,,,,,,,,,,,,,,,, -34000,0.41984344,1.7816257,,,,,,,,,,,,,,,,, -34100,0.387802,1.744023,,,,,,,,,,,,,,,,, -34200,0.42622513,1.7092342,,,,,,,,,,,,,,,,, -34300,0.4402734,1.8043349,,,,,,,,,,,,,,,,, -34400,0.36771384,1.7857684,,,,,,,,,,,,,,,,, -34500,0.39129364,1.7799767,,,,,,,,,,,,,,,,, -34600,0.3979456,1.7177184,,,,,,,,,,,,,,,,, -34700,0.45749196,1.8555287,,,,,,,,,,,,,,,,, -34800,0.43088618,1.7701387,,,,,,,,,,,,,,,,, -34900,0.3731722,1.7280912,,,,,,,,,,,,,,,,, -35000,0.39038232,1.8389049,,,,,,,,,,,,,,,,, -35100,0.42838848,1.8113021,,,,,,,,,,,,,,,,, -35200,0.40297064,1.7781972,,,,,,,,,,,,,,,,, -35300,0.61722875,1.7385055,,,,,,,,,,,,,,,,, -35400,0.4263313,1.7963539,,,,,,,,,,,,,,,,, -35500,0.39860076,1.8276298,,,,,,,,,,,,,,,,, -35600,0.38548008,1.7537391,,,,,,,,,,,,,,,,, -35700,0.3793282,1.7652605,,,,,,,,,,,,,,,,, -35733,,,0.636881947517395,1.780168533325195,31.0226532366646,0.6587890982627869,1.621066689491272,28.25602613523417,3000.0,0.6670036911964417,1.5616695880889893,27.611120742716544,3003.0,12627.94023323059,21451.65731573105,12627.94023323059,8822.155555486679,0.4218254089355469,0.0 -35800,0.37713167,1.868692,,,,,,,,,,,,,,,,, -35900,0.44525898,1.7999594,,,,,,,,,,,,,,,,, -36000,0.39459762,1.7450505,,,,,,,,,,,,,,,,, -36100,0.43443057,1.8238038,,,,,,,,,,,,,,,,, -36200,0.4277164,1.8021055,,,,,,,,,,,,,,,,, -36300,0.38061798,1.8623872,,,,,,,,,,,,,,,,, -36400,0.42941487,1.7849905,,,,,,,,,,,,,,,,, -36500,0.39475036,1.7014197,,,,,,,,,,,,,,,,, -36600,0.38001195,1.7334388,,,,,,,,,,,,,,,,, -36700,0.3749497,1.7267702,,,,,,,,,,,,,,,,, -36800,0.46326628,1.8058032,,,,,,,,,,,,,,,,, -36900,0.42053995,1.851469,,,,,,,,,,,,,,,,, -37000,0.4516826,1.7783216,,,,,,,,,,,,,,,,, -37100,0.4265278,1.8329558,,,,,,,,,,,,,,,,, -37200,0.461619,1.8477137,,,,,,,,,,,,,,,,, -37300,0.44292822,1.7800952,,,,,,,,,,,,,,,,, -37400,0.44711247,1.7138324,,,,,,,,,,,,,,,,, -37500,0.4098627,1.7813072,,,,,,,,,,,,,,,,, -37600,0.44354895,1.7483754,,,,,,,,,,,,,,,,, -37700,0.39416584,1.651689,,,,,,,,,,,,,,,,, -37800,0.4393594,1.7557132,,,,,,,,,,,,,,,,, -37900,0.43022898,1.7416117,,,,,,,,,,,,,,,,, -38000,0.4407445,1.7675902,,,,,,,,,,,,,,,,, -38100,0.4274095,1.7921482,,,,,,,,,,,,,,,,, -38118,,,0.6518439054489136,1.66854727268219,32.117936466755715,0.6590867042541504,1.615612506866455,28.395664950065505,3000.0,0.6677822470664978,1.551979899406433,27.36167683701787,3003.0,13468.088101387024,22775.50716114044,13468.088101387024,9305.749348640442,0.4531617164611816,0.0 -38200,0.44681215,1.761925,,,,,,,,,,,,,,,,, -38300,0.423811,1.6664015,,,,,,,,,,,,,,,,, -38400,0.4164234,1.7861099,,,,,,,,,,,,,,,,, -38500,0.40708548,1.7225417,,,,,,,,,,,,,,,,, -38600,0.37147012,1.7082762,,,,,,,,,,,,,,,,, -38700,0.36940658,1.6901343,,,,,,,,,,,,,,,,, -38800,0.39725175,1.7928702,,,,,,,,,,,,,,,,, -38900,0.39818546,1.7091628,,,,,,,,,,,,,,,,, -39000,0.4395902,1.864736,,,,,,,,,,,,,,,,, -39100,0.36265123,1.7788006,,,,,,,,,,,,,,,,, -39200,0.50461227,1.7646639,,,,,,,,,,,,,,,,, -39300,0.40480846,1.7814898,,,,,,,,,,,,,,,,, -39400,0.4180691,1.8500347,,,,,,,,,,,,,,,,, -39500,0.47947958,1.7787417,,,,,,,,,,,,,,,,, -39600,0.39182144,1.7447944,,,,,,,,,,,,,,,,, -39700,0.43595654,1.7252959,,,,,,,,,,,,,,,,, -39800,0.43153602,1.7696668,,,,,,,,,,,,,,,,, -39900,0.3851747,1.7610475,,,,,,,,,,,,,,,,, -40000,0.4364949,1.8308696,,,,,,,,,,,,,,,,, -40100,0.3858137,1.7568004,,,,,,,,,,,,,,,,, -40200,0.40628132,1.7346411,,,,,,,,,,,,,,,,, -40300,0.672211,1.7552929,,,,,,,,,,,,,,,,, -40400,0.49647596,1.7641764,,,,,,,,,,,,,,,,, -40500,0.4285573,1.7688192,,,,,,,,,,,,,,,,, -40502,,,0.6441013216972351,1.7263169288635254,31.448439318045764,0.6599918007850647,1.606161117553711,28.42181715596443,3000.0,0.6714543104171753,1.5414283275604248,28.16023044547092,3003.0,14308.046524524689,24192.2749774456,14308.046524524689,9882.452818155289,0.4833953380584717,0.0 -40600,0.37228522,1.7783601,,,,,,,,,,,,,,,,, -40700,0.40369272,1.7440768,,,,,,,,,,,,,,,,, -40800,0.39951354,1.7116125,,,,,,,,,,,,,,,,, -40900,0.38260296,1.6853102,,,,,,,,,,,,,,,,, -41000,0.40044084,1.7430625,,,,,,,,,,,,,,,,, -41100,0.4493143,1.8200383,,,,,,,,,,,,,,,,, -41200,0.42641607,1.7866359,,,,,,,,,,,,,,,,, -41300,0.48101574,1.7529837,,,,,,,,,,,,,,,,, -41400,0.42709544,1.7038109,,,,,,,,,,,,,,,,, -41500,0.41221836,1.7211308,,,,,,,,,,,,,,,,, -41600,0.4507934,1.7859205,,,,,,,,,,,,,,,,, -41700,0.4606913,1.8079388,,,,,,,,,,,,,,,,, -41800,0.41839445,1.8191732,,,,,,,,,,,,,,,,, -41900,0.36642164,1.7806296,,,,,,,,,,,,,,,,, -42000,0.35835806,1.7562541,,,,,,,,,,,,,,,,, -42100,0.37663963,1.7498213,,,,,,,,,,,,,,,,, -42200,0.4454656,1.7552643,,,,,,,,,,,,,,,,, -42300,0.42454788,1.7290993,,,,,,,,,,,,,,,,, -42400,0.41957253,1.7329384,,,,,,,,,,,,,,,,, -42500,0.3750802,1.7142857,,,,,,,,,,,,,,,,, -42600,0.42583632,1.711397,,,,,,,,,,,,,,,,, -42700,0.42417735,1.7335026,,,,,,,,,,,,,,,,, -42800,0.38813585,1.7411817,,,,,,,,,,,,,,,,, -42887,,,0.6383638978004456,1.7570757865905762,31.42605783809,0.6582187414169312,1.6062474250793457,28.302365767108448,3000.0,0.6733600497245789,1.5363410711288452,28.174196940148576,3003.0,15148.275726556778,25559.754118919373,15148.275726556778,10409.5962266922,0.5133147239685059,0.0 -42900,0.39340648,1.696264,,,,,,,,,,,,,,,,, -43000,0.41908893,1.750681,,,,,,,,,,,,,,,,, -43100,0.40044922,1.7885184,,,,,,,,,,,,,,,,, -43200,0.37045193,1.757416,,,,,,,,,,,,,,,,, -43300,0.3765215,1.6970011,,,,,,,,,,,,,,,,, -43400,0.39259693,1.7846867,,,,,,,,,,,,,,,,, -43500,0.36743465,1.7433081,,,,,,,,,,,,,,,,, -43600,0.45179376,1.7146244,,,,,,,,,,,,,,,,, -43700,0.44176242,1.7722827,,,,,,,,,,,,,,,,, -43800,0.394659,1.7504419,,,,,,,,,,,,,,,,, -43900,0.40229002,1.7258424,,,,,,,,,,,,,,,,, -44000,0.38978225,1.7225691,,,,,,,,,,,,,,,,, -44100,0.3653315,1.712751,,,,,,,,,,,,,,,,, -44200,0.47193426,1.8213261,,,,,,,,,,,,,,,,, -44300,0.4685324,1.7314205,,,,,,,,,,,,,,,,, -44400,0.39485762,1.7631216,,,,,,,,,,,,,,,,, -44500,0.45040968,1.7769306,,,,,,,,,,,,,,,,, -44600,0.37394837,1.7629093,,,,,,,,,,,,,,,,, -44700,0.36921203,1.6556097,,,,,,,,,,,,,,,,, -44800,0.4122119,1.7005948,,,,,,,,,,,,,,,,, -44900,0.46263945,1.9084008,,,,,,,,,,,,,,,,, -45000,0.438212,1.6952828,,,,,,,,,,,,,,,,, -45100,0.3914101,1.7242533,,,,,,,,,,,,,,,,, -45200,0.40097752,1.790145,,,,,,,,,,,,,,,,, -45271,,,0.6441451907157898,1.7160241603851318,31.51585807077358,0.6614177227020264,1.5945888757705688,28.40174838792601,3000.0,0.6741386651992798,1.5258550643920898,27.84871434765751,3003.0,15988.28655910492,26990.38134407997,15988.28655910492,11000.106136083605,0.5449538230895996,0.0 -45300,0.41776627,1.7379758,,,,,,,,,,,,,,,,, -45400,0.41671985,1.7887436,,,,,,,,,,,,,,,,, -45500,0.44094384,1.839021,,,,,,,,,,,,,,,,, -45600,0.38986734,1.7680917,,,,,,,,,,,,,,,,, -45700,0.3956402,1.7593005,,,,,,,,,,,,,,,,, -45800,0.36327207,1.7406989,,,,,,,,,,,,,,,,, -45900,0.4425843,1.8260283,,,,,,,,,,,,,,,,, -46000,0.4192664,1.7072474,,,,,,,,,,,,,,,,, -46100,0.38415396,1.8177619,,,,,,,,,,,,,,,,, -46200,0.42257696,1.7460848,,,,,,,,,,,,,,,,, -46300,0.38782004,1.7372488,,,,,,,,,,,,,,,,, -46400,0.3798155,1.7163869,,,,,,,,,,,,,,,,, -46500,0.3949398,1.6711653,,,,,,,,,,,,,,,,, -46600,0.4262255,1.7026787,,,,,,,,,,,,,,,,, -46700,0.4262584,1.7426825,,,,,,,,,,,,,,,,, -46800,0.38432664,1.7518492,,,,,,,,,,,,,,,,, -46900,0.4300752,1.7427444,,,,,,,,,,,,,,,,, -47000,0.40893993,1.8441266,,,,,,,,,,,,,,,,, -47100,0.39830756,1.8369648,,,,,,,,,,,,,,,,, -47200,0.39922932,1.8198768,,,,,,,,,,,,,,,,, -47300,0.4353289,1.6716015,,,,,,,,,,,,,,,,, -47400,0.4286153,1.7116281,,,,,,,,,,,,,,,,, -47500,0.37453157,1.7690014,,,,,,,,,,,,,,,,, -47600,0.42897683,1.6824337,,,,,,,,,,,,,,,,, -47656,,,0.6404057145118713,1.7511805295944214,31.17807012669652,0.6617772579193115,1.587974190711975,28.050460138448788,3000.0,0.674312949180603,1.5223636627197266,28.13550040341368,3003.0,16828.508977890015,28401.18878436089,16828.508977890015,11570.585859537125,0.5759387016296387,0.0 -47700,0.3724965,1.8084376,,,,,,,,,,,,,,,,, -47800,0.58096653,1.7330805,,,,,,,,,,,,,,,,, -47900,0.40161574,1.744397,,,,,,,,,,,,,,,,, -48000,0.4273681,1.709191,,,,,,,,,,,,,,,,, -48100,0.5279874,1.6516944,,,,,,,,,,,,,,,,, -48200,0.4429927,1.7442596,,,,,,,,,,,,,,,,, -48300,0.39547887,1.7747788,,,,,,,,,,,,,,,,, -48400,0.49257392,1.768335,,,,,,,,,,,,,,,,, -48500,0.41311067,1.7283797,,,,,,,,,,,,,,,,, -48600,0.36604652,1.7453381,,,,,,,,,,,,,,,,, -48700,0.4409349,1.7735393,,,,,,,,,,,,,,,,, -48800,0.3739456,1.6899073,,,,,,,,,,,,,,,,, -48900,0.42483622,1.8034654,,,,,,,,,,,,,,,,, -49000,0.44365227,1.736878,,,,,,,,,,,,,,,,, -49100,0.42095977,1.7238541,,,,,,,,,,,,,,,,, -49200,0.3723727,1.7003894,,,,,,,,,,,,,,,,, -49300,1.026758,1.7536438,,,,,,,,,,,,,,,,, -49400,0.4304264,1.7782967,,,,,,,,,,,,,,,,, -49500,0.4190399,1.6824181,,,,,,,,,,,,,,,,, -49600,0.5051933,1.7328762,,,,,,,,,,,,,,,,, -49700,0.42024845,1.7878823,,,,,,,,,,,,,,,,, -49800,0.454995,1.8233647,,,,,,,,,,,,,,,,, -49900,0.41575623,1.8221581,,,,,,,,,,,,,,,,, -50000,0.41048,1.7381811,,,,,,,,,,,,,,,,, -50040,,,0.6673862934112549,1.58917498588562,33.58396302504942,0.6645546555519104,1.5788758993148804,28.75264872649214,3000.0,0.6757422685623169,1.516922116279602,27.995736464992472,3003.0,17668.541675567627,29805.83444905281,17668.541675567627,12135.090401649475,0.6083238124847412,0.0 -50100,0.4317098,1.7431549,,,,,,,,,,,,,,,,, -50200,0.428172,1.7379228,,,,,,,,,,,,,,,,, -50300,0.4803364,1.7361671,,,,,,,,,,,,,,,,, -50400,0.43012267,1.742029,,,,,,,,,,,,,,,,, -50500,0.46276805,1.7019523,,,,,,,,,,,,,,,,, -50600,0.4417053,1.7242284,,,,,,,,,,,,,,,,, -50700,0.41427708,1.7290876,,,,,,,,,,,,,,,,, -50800,0.38167465,1.7972922,,,,,,,,,,,,,,,,, -50900,0.38846397,1.7719499,,,,,,,,,,,,,,,,, -51000,0.43479964,1.7051747,,,,,,,,,,,,,,,,, -51100,0.3899558,1.6766465,,,,,,,,,,,,,,,,, -51200,0.38502023,1.701538,,,,,,,,,,,,,,,,, -51300,0.44639835,1.762364,,,,,,,,,,,,,,,,, -51400,0.41232604,1.6596364,,,,,,,,,,,,,,,,, -51500,0.36104,1.6919886,,,,,,,,,,,,,,,,, -51600,0.41801098,1.7837611,,,,,,,,,,,,,,,,, -51700,0.42451584,1.7402437,,,,,,,,,,,,,,,,, -51800,0.41799936,1.8089926,,,,,,,,,,,,,,,,, -51900,0.39306334,1.6987576,,,,,,,,,,,,,,,,, -52000,0.41181144,1.6988366,,,,,,,,,,,,,,,,, -52100,0.40695596,1.7522774,,,,,,,,,,,,,,,,, -52200,0.39940172,1.7945255,,,,,,,,,,,,,,,,, -52300,0.39124537,1.7144123,,,,,,,,,,,,,,,,, -52400,0.41386452,1.7928289,,,,,,,,,,,,,,,,, -52425,,,0.6453410387039185,1.7113089561462402,31.580043661207323,0.6636619567871094,1.575170874595642,28.384619001282065,3000.0,0.6774504780769348,1.505465745925903,27.875327607675807,3003.0,18508.65367841721,31176.78390431404,18508.65367841721,12665.82209134102,0.6415340900421143,0.0 -52500,0.37430367,1.7405266,,,,,,,,,,,,,,,,, -52600,0.40276757,1.7485418,,,,,,,,,,,,,,,,, -52700,0.45146704,1.6792264,,,,,,,,,,,,,,,,, -52800,0.42167342,1.728744,,,,,,,,,,,,,,,,, -52900,0.40443298,1.7122158,,,,,,,,,,,,,,,,, -53000,0.3996206,1.71045,,,,,,,,,,,,,,,,, -53100,0.3675148,1.6555616,,,,,,,,,,,,,,,,, -53200,0.43037513,1.6738114,,,,,,,,,,,,,,,,, -53300,0.42196494,1.6323273,,,,,,,,,,,,,,,,, -53400,0.43526837,1.7330277,,,,,,,,,,,,,,,,, -53500,0.37900704,1.7198253,,,,,,,,,,,,,,,,, -53600,0.50389755,1.6846591,,,,,,,,,,,,,,,,, -53700,0.46559164,1.7866989,,,,,,,,,,,,,,,,, -53800,0.46027267,1.7044681,,,,,,,,,,,,,,,,, -53900,0.4100626,1.7139524,,,,,,,,,,,,,,,,, -54000,0.422541,1.6969582,,,,,,,,,,,,,,,,, -54100,0.39731547,1.845972,,,,,,,,,,,,,,,,, -54200,0.48555535,1.7653214,,,,,,,,,,,,,,,,, -54300,0.4280834,1.7336674,,,,,,,,,,,,,,,,, -54400,0.4386439,1.7269278,,,,,,,,,,,,,,,,, -54500,0.44356975,1.8066437,,,,,,,,,,,,,,,,, -54600,0.38100654,1.7378455,,,,,,,,,,,,,,,,, -54700,0.38012058,1.7512616,,,,,,,,,,,,,,,,, -54800,0.38739404,1.6659119,,,,,,,,,,,,,,,,, -54810,,,0.6456965804100037,1.71318256855011,31.87085102078416,0.6642695069313049,1.5717233419418335,28.47791587474528,3000.0,0.6790425181388855,1.4980599880218506,28.44183774132029,3003.0,19348.648066282272,32530.86065387726,19348.648066282272,13179.79739499092,0.6752762794494629,0.0 -54900,0.3925523,1.7455006,,,,,,,,,,,,,,,,, -55000,0.5389719,1.7637941,,,,,,,,,,,,,,,,, -55100,0.40844497,1.7102197,,,,,,,,,,,,,,,,, -55200,0.38823417,1.6701636,,,,,,,,,,,,,,,,, -55300,0.42056662,1.673244,,,,,,,,,,,,,,,,, -55400,0.42361432,1.7103606,,,,,,,,,,,,,,,,, -55500,0.39189398,1.7753129,,,,,,,,,,,,,,,,, -55600,0.4192489,1.7942241,,,,,,,,,,,,,,,,, -55700,0.39214814,1.612374,,,,,,,,,,,,,,,,, -55800,0.37232977,1.7091764,,,,,,,,,,,,,,,,, -55900,0.37623695,1.6932523,,,,,,,,,,,,,,,,, -56000,0.39508963,1.7305311,,,,,,,,,,,,,,,,, -56100,0.41077158,1.7191195,,,,,,,,,,,,,,,,, -56200,0.3910582,1.7795647,,,,,,,,,,,,,,,,, -56300,0.4436746,1.7491711,,,,,,,,,,,,,,,,, -56400,0.37593275,1.7027028,,,,,,,,,,,,,,,,, -56500,0.416561,1.7571481,,,,,,,,,,,,,,,,, -56600,0.39991847,1.7411913,,,,,,,,,,,,,,,,, -56700,0.45912552,1.6386775,,,,,,,,,,,,,,,,, -56800,0.40330198,1.689229,,,,,,,,,,,,,,,,, -56900,0.36016774,1.6381799,,,,,,,,,,,,,,,,, -57000,0.39236245,1.5849999,,,,,,,,,,,,,,,,, -57100,0.41555646,1.7906823,,,,,,,,,,,,,,,,, -57195,,,0.6577078104019165,1.6320240497589111,32.78105983753737,0.667592465877533,1.5603337287902832,29.03892667850046,3000.0,0.6807158589363098,1.4926679134368896,28.4510571567512,3003.0,20188.61677551269,33872.19545149803,20188.61677551269,13681.056218147278,0.7081527709960938,0.0 -57200,0.42101523,1.7473681,,,,,,,,,,,,,,,,, -57300,0.38171673,1.733603,,,,,,,,,,,,,,,,, -57400,0.37776166,1.6819657,,,,,,,,,,,,,,,,, -57500,0.40065554,1.6857661,,,,,,,,,,,,,,,,, -57600,0.40377578,1.6521219,,,,,,,,,,,,,,,,, -57700,0.4008058,1.6839626,,,,,,,,,,,,,,,,, -57800,0.4783065,1.6473689,,,,,,,,,,,,,,,,, -57900,0.37634328,1.754425,,,,,,,,,,,,,,,,, -58000,0.39037544,1.7813886,,,,,,,,,,,,,,,,, -58100,0.3943809,1.7170709,,,,,,,,,,,,,,,,, -58200,0.39612362,1.6857196,,,,,,,,,,,,,,,,, -58300,0.4354896,1.6468791,,,,,,,,,,,,,,,,, -58400,0.40675282,1.7244374,,,,,,,,,,,,,,,,, -58500,0.40088573,1.7109184,,,,,,,,,,,,,,,,, -58600,0.43391907,1.6728903,,,,,,,,,,,,,,,,, -58700,0.39194486,1.6942528,,,,,,,,,,,,,,,,, -58800,0.3934068,1.6813723,,,,,,,,,,,,,,,,, -58900,0.40849292,1.8083817,,,,,,,,,,,,,,,,, -59000,0.3979589,1.6890237,,,,,,,,,,,,,,,,, -59100,0.36941585,1.642354,,,,,,,,,,,,,,,,, -59200,0.430814,1.7820251,,,,,,,,,,,,,,,,, -59300,0.41040012,1.704797,,,,,,,,,,,,,,,,, -59400,0.3952569,1.7120138,,,,,,,,,,,,,,,,, -59500,0.3863021,1.6675555,,,,,,,,,,,,,,,,, -59580,,,0.6465297937393188,1.694580316543579,32.05305332294858,0.6670469045639038,1.5555315017700195,28.62439054702784,3000.0,0.6815292835235596,1.4816378355026243,28.475886821190706,3003.0,21028.747178077698,35296.116545677185,21028.747178077698,14264.738429784777,0.74080491065979,0.0 -59600,0.46722445,1.7587065,,,,,,,,,,,,,,,,, -59700,0.38178295,1.6651424,,,,,,,,,,,,,,,,, -59800,0.40047583,1.6430638,,,,,,,,,,,,,,,,, -59900,0.4090994,1.6534805,,,,,,,,,,,,,,,,, -60000,0.4177482,1.7486166,,,,,,,,,,,,,,,,, -60100,0.40828493,1.648522,,,,,,,,,,,,,,,,, -60200,0.41645962,1.735779,,,,,,,,,,,,,,,,, -60300,0.9109086,1.7031162,,,,,,,,,,,,,,,,, -60400,0.37085795,1.6851727,,,,,,,,,,,,,,,,, -60500,0.41794607,1.6980969,,,,,,,,,,,,,,,,, -60600,0.41697052,1.69116,,,,,,,,,,,,,,,,, -60700,0.398363,1.6579949,,,,,,,,,,,,,,,,, -60800,0.42972785,1.7172992,,,,,,,,,,,,,,,,, -60900,0.4484093,1.6419854,,,,,,,,,,,,,,,,, -61000,0.42606783,1.6665711,,,,,,,,,,,,,,,,, -61100,0.41140524,1.7234063,,,,,,,,,,,,,,,,, -61200,0.4962393,1.754761,,,,,,,,,,,,,,,,, -61300,0.42072648,1.6700827,,,,,,,,,,,,,,,,, -61400,0.40158138,1.6704762,,,,,,,,,,,,,,,,, -61500,0.41234058,1.6563091,,,,,,,,,,,,,,,,, -61600,0.44095987,1.6234658,,,,,,,,,,,,,,,,, -61700,0.38851985,1.6922745,,,,,,,,,,,,,,,,, -61800,0.40818736,1.6906811,,,,,,,,,,,,,,,,, -61900,0.40358615,1.744319,,,,,,,,,,,,,,,,, -61965,,,0.648226261138916,1.7013450860977173,32.419402937612425,0.6691423654556274,1.5469614267349243,29.033741216883577,3000.0,0.6838417649269104,1.4731414318084717,28.84332825064846,3003.0,21868.823503017426,36625.751838207245,21868.823503017426,14754.18920135498,0.7733290195465088,0.0 -62000,0.42382494,1.7532995,,,,,,,,,,,,,,,,, -62100,0.43296018,1.6724858,,,,,,,,,,,,,,,,, -62200,0.39774343,1.7151725,,,,,,,,,,,,,,,,, -62300,0.37586504,1.6234542,,,,,,,,,,,,,,,,, -62400,0.37363628,1.6916856,,,,,,,,,,,,,,,,, -62500,0.40258923,1.7443511,,,,,,,,,,,,,,,,, -62600,0.40606946,1.7611979,,,,,,,,,,,,,,,,, -62700,0.45024702,1.7562237,,,,,,,,,,,,,,,,, -62800,0.3977984,1.6767404,,,,,,,,,,,,,,,,, -62900,0.425033,1.7477008,,,,,,,,,,,,,,,,, -63000,0.39061046,1.7260064,,,,,,,,,,,,,,,,, -63100,0.37725252,1.7080315,,,,,,,,,,,,,,,,, -63200,0.43060732,1.7186749,,,,,,,,,,,,,,,,, -63300,0.5071885,1.7313339,,,,,,,,,,,,,,,,, -63400,0.38790527,1.6360599,,,,,,,,,,,,,,,,, -63500,0.40721858,1.6202444,,,,,,,,,,,,,,,,, -63600,0.38406077,1.6498126,,,,,,,,,,,,,,,,, -63700,0.40774018,1.6859953,,,,,,,,,,,,,,,,, -63800,0.40857714,1.7256296,,,,,,,,,,,,,,,,, -63900,0.39389253,1.7456393,,,,,,,,,,,,,,,,, -64000,0.36036927,1.6098678,,,,,,,,,,,,,,,,, -64100,0.4044848,1.7013743,,,,,,,,,,,,,,,,, -64200,0.4384326,1.759786,,,,,,,,,,,,,,,,, -64300,0.40323249,1.682199,,,,,,,,,,,,,,,,, -64349,,,0.6545460820198059,1.6462438106536863,32.617605237832414,0.6699978709220886,1.5386850833892822,29.007868109997457,3000.0,0.6819010972976685,1.4711955785751345,28.5848930702908,3003.0,22708.90673828125,37976.13408732414,22708.90673828125,15264.377411842346,0.8070766925811768,0.0 -64400,0.3987776,1.6695962,,,,,,,,,,,,,,,,, -64500,0.4478307,1.7644292,,,,,,,,,,,,,,,,, -64600,0.37377962,1.6545163,,,,,,,,,,,,,,,,, -64700,0.4188652,1.5980006,,,,,,,,,,,,,,,,, -64800,0.39417773,1.6925207,,,,,,,,,,,,,,,,, -64900,0.3883163,1.7697568,,,,,,,,,,,,,,,,, -65000,0.44365162,1.6397749,,,,,,,,,,,,,,,,, -65100,0.40598974,1.654275,,,,,,,,,,,,,,,,, -65200,0.40324122,1.6887832,,,,,,,,,,,,,,,,, -65300,0.42257562,1.7803544,,,,,,,,,,,,,,,,, -65400,0.43487695,1.6385343,,,,,,,,,,,,,,,,, -65500,0.4928438,1.6663016,,,,,,,,,,,,,,,,, -65600,0.41058818,1.6071581,,,,,,,,,,,,,,,,, -65700,0.4000558,1.6981187,,,,,,,,,,,,,,,,, -65800,0.44301575,1.64231,,,,,,,,,,,,,,,,, -65900,0.43084577,1.7308153,,,,,,,,,,,,,,,,, -66000,0.44434586,1.7002606,,,,,,,,,,,,,,,,, -66100,0.413866,1.6891222,,,,,,,,,,,,,,,,, -66200,0.4290026,1.5969383,,,,,,,,,,,,,,,,, -66300,0.42687535,1.6166077,,,,,,,,,,,,,,,,, -66400,0.39420182,1.6371936,,,,,,,,,,,,,,,,, -66500,0.42916,1.7366449,,,,,,,,,,,,,,,,, -66600,0.38764143,1.7461084,,,,,,,,,,,,,,,,, -66700,0.3782773,1.6576275,,,,,,,,,,,,,,,,, -66733,,,0.6512013673782349,1.6751712560653689,32.131119778833856,0.6714361906051636,1.5344096422195437,29.08472915961638,3000.0,0.6844227910041809,1.462852954864502,28.940553718862866,3003.0,23548.84869503975,39290.32270479202,23548.84869503975,15738.510041713716,0.8472380638122559,0.0 -66800,0.39769307,1.6995963,,,,,,,,,,,,,,,,, -66900,0.3905306,1.6358914,,,,,,,,,,,,,,,,, -67000,0.3756325,1.663804,,,,,,,,,,,,,,,,, -67100,0.3862692,1.6955061,,,,,,,,,,,,,,,,, -67200,0.42820945,1.7706045,,,,,,,,,,,,,,,,, -67300,0.38916156,1.6630886,,,,,,,,,,,,,,,,, -67400,0.38377994,1.6321766,,,,,,,,,,,,,,,,, -67500,0.39581037,1.6725094,,,,,,,,,,,,,,,,, -67600,0.42161122,1.6892692,,,,,,,,,,,,,,,,, -67700,0.40322286,1.6884744,,,,,,,,,,,,,,,,, -67800,0.39206436,1.6920916,,,,,,,,,,,,,,,,, -67900,0.4000872,1.6005975,,,,,,,,,,,,,,,,, -68000,0.40612012,1.755309,,,,,,,,,,,,,,,,, -68100,0.40887755,1.6514896,,,,,,,,,,,,,,,,, -68200,0.4212512,1.6741914,,,,,,,,,,,,,,,,, -68300,0.39367795,1.7149564,,,,,,,,,,,,,,,,, -68400,0.42383754,1.7753246,,,,,,,,,,,,,,,,, -68500,0.38653627,1.672739,,,,,,,,,,,,,,,,, -68600,0.4067442,1.5995252,,,,,,,,,,,,,,,,, -68700,0.41079408,1.6256644,,,,,,,,,,,,,,,,, -68800,0.41592845,1.7434523,,,,,,,,,,,,,,,,, -68900,0.42945442,1.7063336,,,,,,,,,,,,,,,,, -69000,0.41766143,1.6720542,,,,,,,,,,,,,,,,, -69100,0.44280505,1.7253451,,,,,,,,,,,,,,,,, -69117,,,0.6723154783248901,1.534518480300903,33.717381365500145,0.6729241013526917,1.5261874198913574,29.173103041104703,3000.0,0.6858404874801636,1.4534401893615725,29.00716191550685,3003.0,24388.915120363235,40672.09962892532,24388.915120363235,16280.106495141985,0.8839507102966309,0.0 -69200,0.38564345,1.7016748,,,,,,,,,,,,,,,,, -69300,0.39898127,1.7545305,,,,,,,,,,,,,,,,, -69400,0.39774123,1.6158828,,,,,,,,,,,,,,,,, -69500,0.41314724,1.655859,,,,,,,,,,,,,,,,, -69600,0.42903715,1.6677893,,,,,,,,,,,,,,,,, -69700,0.41141382,1.6663032,,,,,,,,,,,,,,,,, -69800,0.41499126,1.6523345,,,,,,,,,,,,,,,,, -69900,0.38832182,1.6380696,,,,,,,,,,,,,,,,, -70000,0.3905949,1.6655945,,,,,,,,,,,,,,,,, -70100,0.3994601,1.6672688,,,,,,,,,,,,,,,,, -70200,0.45743015,1.6194923,,,,,,,,,,,,,,,,, -70300,0.4110142,1.6257703,,,,,,,,,,,,,,,,, -70400,0.40982762,1.6720672,,,,,,,,,,,,,,,,, -70500,0.41276857,1.6285148,,,,,,,,,,,,,,,,, -70600,0.3626666,1.6746083,,,,,,,,,,,,,,,,, -70700,0.40317413,1.6558539,,,,,,,,,,,,,,,,, -70800,0.41226983,1.6342361,,,,,,,,,,,,,,,,, -70900,0.4082426,1.6387112,,,,,,,,,,,,,,,,, -71000,0.411119,1.611752,,,,,,,,,,,,,,,,, -71100,0.4411955,1.6441277,,,,,,,,,,,,,,,,, -71200,0.43548447,1.6472654,,,,,,,,,,,,,,,,, -71300,0.4045268,1.6664922,,,,,,,,,,,,,,,,, -71400,0.42024302,1.6844349,,,,,,,,,,,,,,,,, -71500,0.3872433,1.6725107,,,,,,,,,,,,,,,,, -71501,,,0.65537029504776,1.6473228931427002,32.375148555377464,0.6737300157546997,1.5176098346710205,29.28750192928385,3000.0,0.685910165309906,1.4481548070907593,28.593936374173712,3003.0,25228.996363401413,42058.55702161789,25228.996363401413,16826.37229347229,0.9204680919647216,0.0 -71600,0.4133409,1.6007106,,,,,,,,,,,,,,,,, -71700,0.41361666,1.713381,,,,,,,,,,,,,,,,, -71800,0.3804048,1.6195858,,,,,,,,,,,,,,,,, -71900,0.4192469,1.6506279,,,,,,,,,,,,,,,,, -72000,0.40210032,1.5911855,,,,,,,,,,,,,,,,, -72100,0.4043919,1.61979,,,,,,,,,,,,,,,,, -72200,0.40579963,1.638669,,,,,,,,,,,,,,,,, -72300,0.41789097,1.6214007,,,,,,,,,,,,,,,,, -72400,0.4253897,1.6943225,,,,,,,,,,,,,,,,, -72500,0.39481315,1.7110022,,,,,,,,,,,,,,,,, -72600,0.4305816,1.6290903,,,,,,,,,,,,,,,,, -72700,0.41970852,1.5855896,,,,,,,,,,,,,,,,, -72800,0.44031334,1.7054712,,,,,,,,,,,,,,,,, -72900,0.44393775,1.7498902,,,,,,,,,,,,,,,,, -73000,0.4224051,1.6696049,,,,,,,,,,,,,,,,, -73100,0.40918133,1.6414952,,,,,,,,,,,,,,,,, -73200,0.41611055,1.6694813,,,,,,,,,,,,,,,,, -73300,0.4170882,1.6466012,,,,,,,,,,,,,,,,, -73400,0.42269582,1.5999423,,,,,,,,,,,,,,,,, -73500,0.38899222,1.6956426,,,,,,,,,,,,,,,,, -73600,0.3997476,1.5285345,,,,,,,,,,,,,,,,, -73700,0.40322453,1.7354419,,,,,,,,,,,,,,,,, -73800,0.39792278,1.6505327,,,,,,,,,,,,,,,,, -73886,,,0.6524078845977783,1.6667128801345823,32.701466533580295,0.6737548112869263,1.5123512744903564,29.32046658710141,3000.0,0.6876416206359863,1.4385989904403689,28.943381743475623,3003.0,26069.110144138336,43468.737097263336,26069.110144138336,17396.327433347702,0.9564275741577148,0.0 -73900,0.45469698,1.6718203,,,,,,,,,,,,,,,,, -74000,0.42549428,1.733208,,,,,,,,,,,,,,,,, -74100,0.4156398,1.6940804,,,,,,,,,,,,,,,,, -74200,0.42821848,1.6018199,,,,,,,,,,,,,,,,, -74300,0.41700476,1.6751758,,,,,,,,,,,,,,,,, -74400,0.40943736,1.617662,,,,,,,,,,,,,,,,, -74500,0.44118607,1.6316954,,,,,,,,,,,,,,,,, -74600,0.42122373,1.6274626,,,,,,,,,,,,,,,,, -74700,0.3890705,1.6708432,,,,,,,,,,,,,,,,, -74800,0.4294229,1.6809566,,,,,,,,,,,,,,,,, -74900,0.41814446,1.631114,,,,,,,,,,,,,,,,, -75000,0.43544638,1.6499164,,,,,,,,,,,,,,,,, -75100,0.40194046,1.6477789,,,,,,,,,,,,,,,,, -75200,0.41490284,1.6423829,,,,,,,,,,,,,,,,, -75300,0.3724993,1.6093398,,,,,,,,,,,,,,,,, -75400,0.44054747,1.7117406,,,,,,,,,,,,,,,,, -75500,0.39612988,1.6588955,,,,,,,,,,,,,,,,, -75600,0.39392734,1.5673765,,,,,,,,,,,,,,,,, -75700,0.3920588,1.6730448,,,,,,,,,,,,,,,,, -75800,0.38826856,1.6091791,,,,,,,,,,,,,,,,, -75900,0.41680866,1.7449114,,,,,,,,,,,,,,,,, -76000,0.3978739,1.5970969,,,,,,,,,,,,,,,,, -76100,0.40495783,1.6164632,,,,,,,,,,,,,,,,, -76200,0.44955686,1.6637431,,,,,,,,,,,,,,,,, -76269,,,0.6652563810348511,1.5872588157653809,33.20367223177718,0.6766065955162048,1.50298011302948,29.703277161865785,3000.0,0.6899773478507996,1.4313315153121948,29.44708155748432,3003.0,26909.085456848145,44773.62126874924,26909.085456848145,17861.11874818802,0.9940569400787354,0.0 -76300,0.41019988,1.6754577,,,,,,,,,,,,,,,,, -76400,0.40798098,1.5923164,,,,,,,,,,,,,,,,, -76500,0.4108922,1.5999578,,,,,,,,,,,,,,,,, -76600,0.4296974,1.6223075,,,,,,,,,,,,,,,,, -76700,0.40874508,1.6794325,,,,,,,,,,,,,,,,, -76800,0.41736138,1.5632466,,,,,,,,,,,,,,,,, -76900,0.4119093,1.732046,,,,,,,,,,,,,,,,, -77000,0.42954624,1.6296556,,,,,,,,,,,,,,,,, -77100,0.41269222,1.6264607,,,,,,,,,,,,,,,,, -77200,0.4591298,1.6433098,,,,,,,,,,,,,,,,, -77300,0.44450146,1.6507764,,,,,,,,,,,,,,,,, -77400,0.42939728,1.6344471,,,,,,,,,,,,,,,,, -77500,0.42528602,1.6195399,,,,,,,,,,,,,,,,, -77600,0.42933488,1.6478256,,,,,,,,,,,,,,,,, -77700,0.45259625,1.6742383,,,,,,,,,,,,,,,,, -77800,0.3943322,1.6113863,,,,,,,,,,,,,,,,, -77900,0.4487643,1.6973256,,,,,,,,,,,,,,,,, -78000,0.43177494,1.6189873,,,,,,,,,,,,,,,,, -78100,0.3868838,1.6291949,,,,,,,,,,,,,,,,, -78200,0.43883228,1.6927392,,,,,,,,,,,,,,,,, -78300,0.43124872,1.6846567,,,,,,,,,,,,,,,,, -78400,0.40548423,1.6626494,,,,,,,,,,,,,,,,, -78500,0.38967854,1.6044811,,,,,,,,,,,,,,,,, -78600,0.42092004,1.5329399,,,,,,,,,,,,,,,,, -78654,,,0.6580626368522644,1.6298712491989136,32.67718189408843,0.67764812707901,1.4988480806350708,29.64695478586008,3000.0,0.6922898292541504,1.42089581489563,29.830781468889445,3003.0,27749.27845311165,46152.97341346741,27749.27845311165,18400.16328859329,1.0356485843658447,0.0 -78700,0.43111554,1.6273583,,,,,,,,,,,,,,,,, -78800,0.42626065,1.6665995,,,,,,,,,,,,,,,,, -78900,0.40265402,1.6284941,,,,,,,,,,,,,,,,, -79000,0.48366153,1.6211354,,,,,,,,,,,,,,,,, -79100,0.4335621,1.5564594,,,,,,,,,,,,,,,,, -79200,0.45285025,1.672298,,,,,,,,,,,,,,,,, -79300,0.42707345,1.6127937,,,,,,,,,,,,,,,,, -79400,0.45665824,1.6173854,,,,,,,,,,,,,,,,, -79500,0.4321177,1.5916103,,,,,,,,,,,,,,,,, -79600,0.41207403,1.6759185,,,,,,,,,,,,,,,,, -79700,0.4292186,1.6354035,,,,,,,,,,,,,,,,, -79800,0.4327654,1.6231644,,,,,,,,,,,,,,,,, -79900,0.42866975,1.6574093,,,,,,,,,,,,,,,,, -80000,0.45037773,1.6685627,,,,,,,,,,,,,,,,, -80100,0.42885616,1.594273,,,,,,,,,,,,,,,,, -80200,0.43311474,1.6590419,,,,,,,,,,,,,,,,, -80300,0.41619027,1.6654739,,,,,,,,,,,,,,,,, -80400,0.4209502,1.5898244,,,,,,,,,,,,,,,,, -80500,0.43752435,1.7393001,,,,,,,,,,,,,,,,, -80600,0.44935557,1.5997168,,,,,,,,,,,,,,,,, -80700,0.4206463,1.6352396,,,,,,,,,,,,,,,,, -80800,0.46165466,1.6065644,,,,,,,,,,,,,,,,, -80900,0.4264446,1.5368065,,,,,,,,,,,,,,,,, -81000,0.4263401,1.7004017,,,,,,,,,,,,,,,,, -81039,,,0.6602153778076172,1.6250991821289062,32.57357972662718,0.6790988445281982,1.4938795566558838,29.84311518307996,3000.0,0.6907210946083069,1.4177829027175903,29.21138067716528,3003.0,28589.20459485054,47664.87612128258,28589.20459485054,19072.02882409096,1.0736279487609863,0.0 -81100,0.42735153,1.644766,,,,,,,,,,,,,,,,, -81200,0.43354255,1.6536316,,,,,,,,,,,,,,,,, -81300,0.42612353,1.6276675,,,,,,,,,,,,,,,,, -81400,0.43756187,1.6145498,,,,,,,,,,,,,,,,, -81500,0.457939,1.6476598,,,,,,,,,,,,,,,,, -81600,0.42852205,1.5959373,,,,,,,,,,,,,,,,, -81700,0.4369222,1.6127783,,,,,,,,,,,,,,,,, -81800,0.43026453,1.5985662,,,,,,,,,,,,,,,,, -81900,0.41362378,1.5836303,,,,,,,,,,,,,,,,, -82000,0.4264717,1.6789944,,,,,,,,,,,,,,,,, -82100,0.42380247,1.5245894,,,,,,,,,,,,,,,,, -82200,0.42030248,1.6218309,,,,,,,,,,,,,,,,, -82300,0.43159783,1.5831937,,,,,,,,,,,,,,,,, -82400,0.43986073,1.6396468,,,,,,,,,,,,,,,,, -82500,0.4052289,1.6276333,,,,,,,,,,,,,,,,, -82600,0.4712847,1.6960655,,,,,,,,,,,,,,,,, -82700,0.41736558,1.6335127,,,,,,,,,,,,,,,,, -82800,0.46475846,1.6795241,,,,,,,,,,,,,,,,, -82900,0.45192263,1.6057447,,,,,,,,,,,,,,,,, -83000,0.47083044,1.6367084,,,,,,,,,,,,,,,,, -83100,0.41319618,1.6034257,,,,,,,,,,,,,,,,, -83200,0.4425406,1.614699,,,,,,,,,,,,,,,,, -83300,0.42228976,1.5922229,,,,,,,,,,,,,,,,, -83400,0.4402605,1.5988985,,,,,,,,,,,,,,,,, -83423,,,0.6681274175643921,1.57500159740448,32.9723800626729,0.6786648631095886,1.4860962629318235,29.676601764832707,3000.0,0.6939747929573059,1.405526041984558,29.7245642760348,3003.0,29429.326312065125,49049.69407105446,29429.326312065125,19616.61013817787,1.1113903522491455,0.0 -83500,0.44242686,1.5342594,,,,,,,,,,,,,,,,, -83600,0.45315343,1.5830553,,,,,,,,,,,,,,,,, -83700,0.42660385,1.7004641,,,,,,,,,,,,,,,,, -83800,0.4274038,1.6230233,,,,,,,,,,,,,,,,, -83900,0.5106296,1.6191604,,,,,,,,,,,,,,,,, -84000,0.45822182,1.64641,,,,,,,,,,,,,,,,, -84100,0.40979984,1.6278819,,,,,,,,,,,,,,,,, -84200,0.48917106,1.637597,,,,,,,,,,,,,,,,, -84300,0.41287512,1.5402588,,,,,,,,,,,,,,,,, -84400,0.4364007,1.634553,,,,,,,,,,,,,,,,, -84500,0.4462982,1.5689267,,,,,,,,,,,,,,,,, -84600,0.430843,1.5857544,,,,,,,,,,,,,,,,, -84700,0.43738696,1.6158987,,,,,,,,,,,,,,,,, -84800,0.4485637,1.6078696,,,,,,,,,,,,,,,,, -84900,0.4285966,1.6346259,,,,,,,,,,,,,,,,, -85000,0.47959167,1.5494205,,,,,,,,,,,,,,,,, -85100,0.43846446,1.5905138,,,,,,,,,,,,,,,,, -85200,0.4814103,1.60938,,,,,,,,,,,,,,,,, -85300,0.42320198,1.5471143,,,,,,,,,,,,,,,,, -85400,0.41120398,1.6634189,,,,,,,,,,,,,,,,, -85500,0.4882292,1.5549746,,,,,,,,,,,,,,,,, -85600,0.44694072,1.5792077,,,,,,,,,,,,,,,,, -85700,0.4644587,1.6625701,,,,,,,,,,,,,,,,, -85800,0.43459514,1.5779556,,,,,,,,,,,,,,,,, -85808,,,0.6655429601669312,1.5882421731948853,33.31576160365833,0.6799171566963196,1.4798542261123655,29.79216033426816,3000.0,0.694195568561554,1.39683997631073,29.68178603516924,3003.0,30269.53056025505,50450.16050791741,30269.53056025505,20176.76056933403,1.147826910018921,0.0 -85900,0.44416493,1.5444734,,,,,,,,,,,,,,,,, -86000,0.4450491,1.5985769,,,,,,,,,,,,,,,,, -86100,0.44622692,1.541527,,,,,,,,,,,,,,,,, -86200,0.45143864,1.5723946,,,,,,,,,,,,,,,,, -86300,0.42513415,1.5900259,,,,,,,,,,,,,,,,, -86400,0.43887138,1.6114545,,,,,,,,,,,,,,,,, -86500,0.44203007,1.6416144,,,,,,,,,,,,,,,,, -86600,0.43401298,1.6298397,,,,,,,,,,,,,,,,, -86700,0.44590873,1.578681,,,,,,,,,,,,,,,,, -86800,0.43574786,1.5513465,,,,,,,,,,,,,,,,, -86900,0.43255362,1.6025187,,,,,,,,,,,,,,,,, -87000,0.44774324,1.5017849,,,,,,,,,,,,,,,,, -87100,0.45964006,1.4784347,,,,,,,,,,,,,,,,, -87200,0.42769894,1.5602738,,,,,,,,,,,,,,,,, -87300,0.47017306,1.5497073,,,,,,,,,,,,,,,,, -87400,0.44099763,1.643539,,,,,,,,,,,,,,,,, -87500,0.4814083,1.5222929,,,,,,,,,,,,,,,,, -87600,0.48988765,1.5490388,,,,,,,,,,,,,,,,, -87700,0.4137143,1.5413529,,,,,,,,,,,,,,,,, -87800,0.4491038,1.5976721,,,,,,,,,,,,,,,,, -87900,0.46287644,1.5726717,,,,,,,,,,,,,,,,, -88000,0.43872228,1.5532398,,,,,,,,,,,,,,,,, -88100,0.4411029,1.5572367,,,,,,,,,,,,,,,,, -88193,,,0.6824063062667847,1.4828976392745972,34.05213368530559,0.6825209856033325,1.4713144302368164,30.0387532654638,3000.0,0.6972517967224121,1.393563151359558,30.091496010180137,3003.0,31109.656126499176,51791.54896807671,31109.656126499176,20677.904341459274,1.1917221546173096,0.0 -88200,3.1013393,1.6561157,,,,,,,,,,,,,,,,, -88300,0.44250575,1.6667454,,,,,,,,,,,,,,,,, -88400,0.44931445,1.599445,,,,,,,,,,,,,,,,, -88500,0.45976707,1.5417751,,,,,,,,,,,,,,,,, -88600,0.45981458,1.5983515,,,,,,,,,,,,,,,,, -88700,0.4842911,1.6103022,,,,,,,,,,,,,,,,, -88800,0.4407438,1.6295996,,,,,,,,,,,,,,,,, -88900,0.46093556,1.5968211,,,,,,,,,,,,,,,,, -89000,0.46990475,1.5914636,,,,,,,,,,,,,,,,, -89100,0.45069343,1.575162,,,,,,,,,,,,,,,,, -89200,0.4546612,1.5939382,,,,,,,,,,,,,,,,, -89300,0.43157488,1.5213507,,,,,,,,,,,,,,,,, -89400,0.4820731,1.6427255,,,,,,,,,,,,,,,,, -89500,0.47620764,1.5437871,,,,,,,,,,,,,,,,, -89600,0.4524229,1.5081805,,,,,,,,,,,,,,,,, -89700,0.47423378,1.5285748,,,,,,,,,,,,,,,,, -89800,0.44453043,1.5062782,,,,,,,,,,,,,,,,, -89900,0.44377333,1.5213318,,,,,,,,,,,,,,,,, -90000,0.44654346,1.5185218,,,,,,,,,,,,,,,,, -90100,0.45918855,1.5997632,,,,,,,,,,,,,,,,, -90200,0.4763061,1.6136857,,,,,,,,,,,,,,,,, -90300,0.42995986,1.5563477,,,,,,,,,,,,,,,,, -90400,0.47044542,1.523159,,,,,,,,,,,,,,,,, -90500,0.46913728,1.5715051,,,,,,,,,,,,,,,,, -90578,,,0.6718730330467224,1.5490468740463257,33.41231035188514,0.6825333833694458,1.4652026891708374,30.102713832189707,3000.0,0.6974958181381226,1.3812869787216189,29.89253457886926,3003.0,31949.85806655884,53249.28406596184,31949.85806655884,21295.32447552681,1.2300746440887451,0.0 -90600,0.44741312,1.5890474,,,,,,,,,,,,,,,,, -90700,0.46869102,1.6178677,,,,,,,,,,,,,,,,, -90800,0.48013908,1.6556033,,,,,,,,,,,,,,,,, -90900,0.4768074,1.5663794,,,,,,,,,,,,,,,,, -91000,0.47222367,1.5358684,,,,,,,,,,,,,,,,, -91100,0.4666518,1.5747744,,,,,,,,,,,,,,,,, -91200,0.48189268,1.5100607,,,,,,,,,,,,,,,,, -91300,0.43330166,1.4996089,,,,,,,,,,,,,,,,, -91400,0.45821163,1.5897796,,,,,,,,,,,,,,,,, -91500,0.46310684,1.5639256,,,,,,,,,,,,,,,,, -91600,0.4570727,1.5665092,,,,,,,,,,,,,,,,, -91700,0.4635091,1.5263128,,,,,,,,,,,,,,,,, -91800,0.47551548,1.5262191,,,,,,,,,,,,,,,,, -91900,0.4791724,1.5556545,,,,,,,,,,,,,,,,, -92000,0.48249787,1.611546,,,,,,,,,,,,,,,,, -92100,0.47103536,1.632214,,,,,,,,,,,,,,,,, -92200,0.47127074,1.519246,,,,,,,,,,,,,,,,, -92300,0.46040738,1.535132,,,,,,,,,,,,,,,,, -92400,0.4688132,1.5183212,,,,,,,,,,,,,,,,, -92500,0.49028456,1.6128894,,,,,,,,,,,,,,,,, -92600,0.46960813,1.668349,,,,,,,,,,,,,,,,, -92700,0.47100115,1.5148156,,,,,,,,,,,,,,,,, -92800,0.4955612,1.5780714,,,,,,,,,,,,,,,,, -92900,0.46737194,1.5414145,,,,,,,,,,,,,,,,, -92963,,,0.6706978678703308,1.5593184232711792,33.93765018553232,0.6834261417388916,1.4623475074768066,29.94001242273573,3000.0,0.6990064382553101,1.3774524927139282,30.2127779651062,3003.0,32790.02381038666,54599.34149551392,32790.02381038666,21805.096822977062,1.2744011878967283,0.0 -93000,0.46518433,1.5046492,,,,,,,,,,,,,,,,, -93100,0.45271137,1.4807292,,,,,,,,,,,,,,,,, -93200,0.49263963,1.5747966,,,,,,,,,,,,,,,,, -93300,0.5390742,1.5909102,,,,,,,,,,,,,,,,, -93400,0.45403302,1.5873818,,,,,,,,,,,,,,,,, -93500,0.4629457,1.5782486,,,,,,,,,,,,,,,,, -93600,0.48275715,1.5660539,,,,,,,,,,,,,,,,, -93700,0.49400952,1.6506273,,,,,,,,,,,,,,,,, -93800,0.53811467,1.6664286,,,,,,,,,,,,,,,,, -93900,0.4653017,1.5128748,,,,,,,,,,,,,,,,, -94000,0.47450882,1.573179,,,,,,,,,,,,,,,,, -94100,0.50127983,1.5995928,,,,,,,,,,,,,,,,, -94200,0.493847,1.5232918,,,,,,,,,,,,,,,,, -94300,0.48922518,1.5166603,,,,,,,,,,,,,,,,, -94400,0.5166504,1.6228805,,,,,,,,,,,,,,,,, -94500,0.49781153,1.53832,,,,,,,,,,,,,,,,, -94600,0.48907673,1.5425206,,,,,,,,,,,,,,,,, -94700,0.46256828,1.6106246,,,,,,,,,,,,,,,,, -94800,0.49217302,1.6377556,,,,,,,,,,,,,,,,, -94900,0.50395316,1.6009665,,,,,,,,,,,,,,,,, -95000,0.48934102,1.5418869,,,,,,,,,,,,,,,,, -95100,0.49591345,1.4418553,,,,,,,,,,,,,,,,, -95200,0.51119953,1.6169317,,,,,,,,,,,,,,,,, -95300,0.51974756,1.5702883,,,,,,,,,,,,,,,,, -95347,,,0.6808968186378479,1.4978055953979492,34.68351336887508,0.6850131750106812,1.4518790245056152,30.293797765390387,3000.0,0.6996804475784302,1.3698073625564575,29.98158297653488,3003.0,33630.12921476364,55971.70452427864,33630.12921476364,22337.2406938076,1.3123936653137207,0.0 -95400,0.51786,1.561944,,,,,,,,,,,,,,,,, -95500,0.4851647,1.5538573,,,,,,,,,,,,,,,,, -95600,0.47101274,1.4823921,,,,,,,,,,,,,,,,, -95700,0.53766114,1.5553368,,,,,,,,,,,,,,,,, -95800,0.49366885,1.5506389,,,,,,,,,,,,,,,,, -95900,0.51558053,1.5334399,,,,,,,,,,,,,,,,, -96000,0.52672446,1.6069185,,,,,,,,,,,,,,,,, -96100,0.5281377,1.5226052,,,,,,,,,,,,,,,,, -96200,0.52449113,1.5660634,,,,,,,,,,,,,,,,, -96300,0.5170149,1.5058984,,,,,,,,,,,,,,,,, -96400,0.5103264,1.6139073,,,,,,,,,,,,,,,,, -96500,0.52683014,1.4867394,,,,,,,,,,,,,,,,, -96600,0.5249603,1.5622005,,,,,,,,,,,,,,,,, -96700,0.5322072,1.6176084,,,,,,,,,,,,,,,,, -96800,0.51931036,1.5594031,,,,,,,,,,,,,,,,, -96900,0.55906487,1.5316226,,,,,,,,,,,,,,,,, -97000,0.5157563,1.6027143,,,,,,,,,,,,,,,,, -97100,0.4884932,1.5004307,,,,,,,,,,,,,,,,, -97200,0.49080762,1.5913727,,,,,,,,,,,,,,,,, -97300,0.51522064,1.5655794,,,,,,,,,,,,,,,,, -97400,0.5287914,1.4951862,,,,,,,,,,,,,,,,, -97500,0.5188262,1.4732317,,,,,,,,,,,,,,,,, -97600,0.5216129,1.5048658,,,,,,,,,,,,,,,,, -97700,0.53328264,1.6046544,,,,,,,,,,,,,,,,, -97731,,,0.6694570183753967,1.563796043395996,33.762738030108665,0.6860547065734863,1.4455714225769043,30.13158593552824,3000.0,0.7033641338348389,1.3601025342941284,30.140830859153333,3003.0,34470.28120446205,57306.02733922005,34470.28120446205,22831.28604865074,1.359565496444702,0.0 -97800,0.49550083,1.493175,,,,,,,,,,,,,,,,, -97900,0.49969995,1.4755878,,,,,,,,,,,,,,,,, -98000,0.52876234,1.5700299,,,,,,,,,,,,,,,,, -98100,0.5557528,1.5654107,,,,,,,,,,,,,,,,, -98200,0.54488015,1.4701095,,,,,,,,,,,,,,,,, -98300,0.52785736,1.5837703,,,,,,,,,,,,,,,,, -98400,0.5285608,1.5896008,,,,,,,,,,,,,,,,, -98500,0.53276026,1.5289419,,,,,,,,,,,,,,,,, -98600,0.50538355,1.436304,,,,,,,,,,,,,,,,, -98700,0.52494895,1.4956976,,,,,,,,,,,,,,,,, -98800,0.514267,1.5709305,,,,,,,,,,,,,,,,, -98900,0.55021536,1.5545849,,,,,,,,,,,,,,,,, -99000,0.51027256,1.5246089,,,,,,,,,,,,,,,,, -99100,0.5208521,1.5113974,,,,,,,,,,,,,,,,, -99200,0.5308717,1.4801652,,,,,,,,,,,,,,,,, -99300,0.5354752,1.4889891,,,,,,,,,,,,,,,,, -99400,0.53190196,1.5506084,,,,,,,,,,,,,,,,, -99500,0.5412725,1.5084783,,,,,,,,,,,,,,,,, -99600,0.5413439,1.5000731,,,,,,,,,,,,,,,,, -99700,0.5353095,1.5332397,,,,,,,,,,,,,,,,, -99800,0.526591,1.4846888,,,,,,,,,,,,,,,,, -99900,0.5433148,1.4850022,,,,,,,,,,,,,,,,, -100000,0.5352668,1.5193812,,,,,,,,,,,,,,,,, -100100,0.5426322,1.5475416,,,,,,,,,,,,,,,,, -100116,,,0.709169864654541,1.3537744283676147,36.85112230721641,0.6860795021057129,1.4437220096588137,30.376215081158747,3000.0,0.7017953991889954,1.3589884042739868,30.204320024066764,3003.0,35310.500405311584,58661.151460170746,35310.500405311584,23346.078429937363,1.3988215923309326,0.0 -100200,0.5284074,1.5064094,,,,,,,,,,,,,,,,, -100300,0.54185003,1.4349128,,,,,,,,,,,,,,,,, -100400,0.5173299,1.5786493,,,,,,,,,,,,,,,,, -100500,0.53231496,1.540682,,,,,,,,,,,,,,,,, -100600,0.5620786,1.5255312,,,,,,,,,,,,,,,,, -100700,0.5397808,1.4610405,,,,,,,,,,,,,,,,, -100800,0.5565125,1.4455247,,,,,,,,,,,,,,,,, -100900,0.5629106,1.5914564,,,,,,,,,,,,,,,,, -101000,0.5854602,1.5148621,,,,,,,,,,,,,,,,, -101100,0.5448044,1.4681717,,,,,,,,,,,,,,,,, -101200,0.5433939,1.5107584,,,,,,,,,,,,,,,,, -101300,0.59375894,1.4875935,,,,,,,,,,,,,,,,, -101400,0.5461406,1.5049146,,,,,,,,,,,,,,,,, -101500,0.5843495,1.5585244,,,,,,,,,,,,,,,,, -101600,0.5533163,1.5962186,,,,,,,,,,,,,,,,, -101700,0.56308347,1.5684865,,,,,,,,,,,,,,,,, -101800,0.5490636,1.5769454,,,,,,,,,,,,,,,,, -101900,0.55713767,1.564138,,,,,,,,,,,,,,,,, -102000,0.56594807,1.5369965,,,,,,,,,,,,,,,,, -102100,0.5583497,1.4306947,,,,,,,,,,,,,,,,, -102200,0.55266374,1.5240169,,,,,,,,,,,,,,,,, -102300,0.55420196,1.5061207,,,,,,,,,,,,,,,,, -102400,0.5485406,1.4978414,,,,,,,,,,,,,,,,, -102500,,,0.6830866932868958,1.4718409776687622,34.33374070683871,0.6886585354804993,1.4326103925704956,30.327454897877665,3000.0,0.7032363414764404,1.3476167917251587,30.55794406358037,3003.0,36150.45610380173,60024.18383717537,36150.45610380173,23869.03919649124,1.439997673034668,0.0 -102500,0.56874657,1.5025712,,,,,,,,,,,,,,,,, -102600,0.5779676,1.538971,,,,,,,,,,,,,,,,, -102700,0.53765625,1.5980163,,,,,,,,,,,,,,,,, -102800,0.555407,1.4829414,,,,,,,,,,,,,,,,, -102900,0.55887455,1.4670807,,,,,,,,,,,,,,,,, -103000,0.59427226,1.4810569,,,,,,,,,,,,,,,,, -103100,0.6051866,1.5033349,,,,,,,,,,,,,,,,, -103200,0.58566725,1.4897732,,,,,,,,,,,,,,,,, -103300,0.5751529,1.5090892,,,,,,,,,,,,,,,,, -103400,0.56867826,1.5159875,,,,,,,,,,,,,,,,, -103500,0.5786239,1.4417418,,,,,,,,,,,,,,,,, -103600,0.59147525,1.455115,,,,,,,,,,,,,,,,, -103700,0.59087896,1.5215064,,,,,,,,,,,,,,,,, -103800,0.5596412,1.4829491,,,,,,,,,,,,,,,,, -103900,0.58272624,1.4551189,,,,,,,,,,,,,,,,, -104000,0.5708946,1.4720802,,,,,,,,,,,,,,,,, -104100,0.5666601,1.4815216,,,,,,,,,,,,,,,,, -104200,0.61297727,1.470529,,,,,,,,,,,,,,,,, -104300,0.5744552,1.4649452,,,,,,,,,,,,,,,,, -104400,0.5913099,1.5008675,,,,,,,,,,,,,,,,, -104500,0.58533573,1.5316207,,,,,,,,,,,,,,,,, -104600,0.5717028,1.5076114,,,,,,,,,,,,,,,,, -104700,0.6064511,1.4881243,,,,,,,,,,,,,,,,, -104800,0.56622845,1.4366293,,,,,,,,,,,,,,,,, -104884,,,0.6831173896789551,1.4802350997924805,34.37978318584616,0.6883485317230225,1.4367409944534302,30.31289615139876,3000.0,0.7036778926849365,1.347154140472412,30.484548709173747,3003.0,36990.53392624855,61390.49967384338,36990.53392624855,24395.162294387817,1.479119062423706,0.0 -104900,0.57597744,1.5018734,,,,,,,,,,,,,,,,, -105000,0.58365035,1.4625907,,,,,,,,,,,,,,,,, -105100,0.5645471,1.4260792,,,,,,,,,,,,,,,,, -105200,0.5842636,1.5374014,,,,,,,,,,,,,,,,, -105300,0.6001341,1.4846653,,,,,,,,,,,,,,,,, -105400,0.6084955,1.4786685,,,,,,,,,,,,,,,,, -105500,0.6134509,1.428146,,,,,,,,,,,,,,,,, -105600,0.600207,1.462116,,,,,,,,,,,,,,,,, -105700,0.6186848,1.4683048,,,,,,,,,,,,,,,,, -105800,0.5763856,1.5224893,,,,,,,,,,,,,,,,, -105900,0.62953997,1.5253221,,,,,,,,,,,,,,,,, -106000,0.57535094,1.4860448,,,,,,,,,,,,,,,,, -106100,0.6250343,1.48886,,,,,,,,,,,,,,,,, -106200,0.6072406,1.3698078,,,,,,,,,,,,,,,,, -106300,0.59713286,1.4687951,,,,,,,,,,,,,,,,, -106400,0.59633064,1.4347899,,,,,,,,,,,,,,,,, -106500,0.6175041,1.5130006,,,,,,,,,,,,,,,,, -106600,0.6170439,1.5476874,,,,,,,,,,,,,,,,, -106700,0.60821533,1.4274569,,,,,,,,,,,,,,,,, -106800,0.63061124,1.4879099,,,,,,,,,,,,,,,,, -106900,0.59425485,1.4288551,,,,,,,,,,,,,,,,, -107000,0.6116223,1.4877254,,,,,,,,,,,,,,,,, -107100,0.58416104,1.4822345,,,,,,,,,,,,,,,,, -107200,0.6092724,1.4685494,,,,,,,,,,,,,,,,, -107269,,,0.6964491009712219,1.404707908630371,35.48056756531287,0.6896876692771912,1.4275447130203247,30.40983792345602,3000.0,0.7058392763137817,1.339435338973999,30.684268005675964,3003.0,37830.54970383644,62745.49480700493,37830.54970383644,24910.0284614563,1.5198464393615725,0.0 -107300,0.62016886,1.4895874,,,,,,,,,,,,,,,,, -107400,0.6047633,1.547679,,,,,,,,,,,,,,,,, -107500,0.6314848,1.4387176,,,,,,,,,,,,,,,,, -107600,0.6344593,1.3999735,,,,,,,,,,,,,,,,, -107700,0.6317504,1.4817262,,,,,,,,,,,,,,,,, -107800,0.6370122,1.5548404,,,,,,,,,,,,,,,,, -107900,0.6237031,1.4646002,,,,,,,,,,,,,,,,, -108000,0.64339554,1.4805657,,,,,,,,,,,,,,,,, -108100,0.6205877,1.4441481,,,,,,,,,,,,,,,,, -108200,0.62562406,1.4929335,,,,,,,,,,,,,,,,, -108300,0.6064852,1.5123578,,,,,,,,,,,,,,,,, -108400,0.6308994,1.5044382,,,,,,,,,,,,,,,,, -108500,0.60799706,1.4618812,,,,,,,,,,,,,,,,, -108600,0.6360176,1.5139704,,,,,,,,,,,,,,,,, -108700,0.6364672,1.4534683,,,,,,,,,,,,,,,,, -108800,0.6476688,1.4818411,,,,,,,,,,,,,,,,, -108900,0.6406988,1.4105145,,,,,,,,,,,,,,,,, -109000,0.65776706,1.4557722,,,,,,,,,,,,,,,,, -109100,0.6137456,1.4654322,,,,,,,,,,,,,,,,, -109200,0.64756674,1.4433125,,,,,,,,,,,,,,,,, -109300,0.6586117,1.47838,,,,,,,,,,,,,,,,, -109400,0.66869104,1.5114479,,,,,,,,,,,,,,,,, -109500,0.68694913,1.4436411,,,,,,,,,,,,,,,,, -109600,0.63848513,1.4473035,,,,,,,,,,,,,,,,, -109653,,,0.6924313902854919,1.4321004152297974,35.2917509143731,0.6913739442825317,1.4228721857070925,30.570859095014143,3000.0,0.7071059346199036,1.3327090740203855,30.78908048297665,3003.0,38670.55089187622,64133.76025557518,38670.55089187622,25458.17726159096,1.561103343963623,0.0 -109700,0.64918983,1.4949588,,,,,,,,,,,,,,,,, -109800,0.6437075,1.4476246,,,,,,,,,,,,,,,,, -109900,0.65246207,1.4928411,,,,,,,,,,,,,,,,, -110000,0.64604133,1.4579068,,,,,,,,,,,,,,,,, -110100,0.6626095,1.4639782,,,,,,,,,,,,,,,,, -110200,0.6373838,1.4865565,,,,,,,,,,,,,,,,, -110300,0.66869646,1.5023897,,,,,,,,,,,,,,,,, -110400,0.6513093,1.4289464,,,,,,,,,,,,,,,,, -110500,0.65586543,1.4460256,,,,,,,,,,,,,,,,, -110600,0.67223686,1.45227,,,,,,,,,,,,,,,,, -110700,0.66050684,1.4369671,,,,,,,,,,,,,,,,, -110800,0.67262113,1.455122,,,,,,,,,,,,,,,,, -110900,0.6513951,1.5378213,,,,,,,,,,,,,,,,, -111000,0.66096723,1.4562204,,,,,,,,,,,,,,,,, -111100,0.65183276,1.4355782,,,,,,,,,,,,,,,,, -111200,0.6512275,1.3992707,,,,,,,,,,,,,,,,, -111300,0.7032759,1.4415858,,,,,,,,,,,,,,,,, -111400,0.70000476,1.4435129,,,,,,,,,,,,,,,,, -111500,0.67216384,1.4658595,,,,,,,,,,,,,,,,, -111600,0.6693129,1.4593375,,,,,,,,,,,,,,,,, -111700,0.65511304,1.4144374,,,,,,,,,,,,,,,,, -111800,0.67955136,1.3673817,,,,,,,,,,,,,,,,, -111900,0.66309214,1.4371263,,,,,,,,,,,,,,,,, -112000,0.7141027,1.417278,,,,,,,,,,,,,,,,, -112037,,,0.6871734857559204,1.4625294208526611,35.0512106849693,0.6905679702758789,1.4227373600006104,30.782762669814545,3000.0,0.7082796096801758,1.3301644325256348,30.976652739571417,3003.0,39510.55522322655,65477.0303850174,39510.55522322655,25961.32877779007,1.60233473777771,0.0 -112100,0.6661554,1.4414558,,,,,,,,,,,,,,,,, -112200,0.6603346,1.3898985,,,,,,,,,,,,,,,,, -112300,0.6758601,1.4403006,,,,,,,,,,,,,,,,, -112400,0.6813685,1.4417835,,,,,,,,,,,,,,,,, -112500,0.68680155,1.4178605,,,,,,,,,,,,,,,,, -112600,0.7044011,1.425852,,,,,,,,,,,,,,,,, -112700,0.7367507,1.4742798,,,,,,,,,,,,,,,,, -112800,0.6687355,1.4852513,,,,,,,,,,,,,,,,, -112900,0.6752382,1.3846861,,,,,,,,,,,,,,,,, -113000,0.6776175,1.4610116,,,,,,,,,,,,,,,,, -113100,0.728234,1.4754616,,,,,,,,,,,,,,,,, -113200,0.76558334,1.4863911,,,,,,,,,,,,,,,,, -113300,0.71713746,1.4837952,,,,,,,,,,,,,,,,, -113400,0.6896097,1.4489683,,,,,,,,,,,,,,,,, -113500,0.6858049,1.3701056,,,,,,,,,,,,,,,,, -113600,0.6699166,1.4164988,,,,,,,,,,,,,,,,, -113700,0.7317565,1.3971002,,,,,,,,,,,,,,,,, -113800,0.68175095,1.406836,,,,,,,,,,,,,,,,, -113900,0.69198644,1.4380755,,,,,,,,,,,,,,,,, -114000,0.6880346,1.3610044,,,,,,,,,,,,,,,,, -114100,0.6930318,1.428309,,,,,,,,,,,,,,,,, -114200,0.6675452,1.356225,,,,,,,,,,,,,,,,, -114300,0.70445514,1.3452586,,,,,,,,,,,,,,,,, -114400,0.69635034,1.4028714,,,,,,,,,,,,,,,,, -114421,,,0.699521541595459,1.3912333250045776,36.07696290784867,0.6913739442825317,1.4179686307907104,30.715100454032235,3000.0,0.7078379988670349,1.3299546241760254,30.80652754405374,3003.0,40350.648983716965,66829.55691623688,40350.648983716965,26473.64466929436,1.644965410232544,0.0 -114500,0.70056874,1.4236059,,,,,,,,,,,,,,,,, -114600,0.74147123,1.4848036,,,,,,,,,,,,,,,,, -114700,0.71586955,1.4194868,,,,,,,,,,,,,,,,, -114800,0.72833043,1.3923837,,,,,,,,,,,,,,,,, -114900,0.72001785,1.5003088,,,,,,,,,,,,,,,,, -115000,0.7171684,1.4066991,,,,,,,,,,,,,,,,, -115100,0.7243978,1.395957,,,,,,,,,,,,,,,,, -115200,0.7527031,1.4690685,,,,,,,,,,,,,,,,, -115300,0.71650255,1.411343,,,,,,,,,,,,,,,,, -115400,0.74075735,1.4377271,,,,,,,,,,,,,,,,, -115500,0.71500653,1.3968214,,,,,,,,,,,,,,,,, -115600,0.7363202,1.4081163,,,,,,,,,,,,,,,,, -115700,0.71982014,1.352361,,,,,,,,,,,,,,,,, -115800,0.7267415,1.4506755,,,,,,,,,,,,,,,,, -115900,0.6827257,1.405344,,,,,,,,,,,,,,,,, -116000,0.7094354,1.4092739,,,,,,,,,,,,,,,,, -116100,0.7257175,1.4098431,,,,,,,,,,,,,,,,, -116200,0.72209746,1.4749599,,,,,,,,,,,,,,,,, -116300,0.70970637,1.308893,,,,,,,,,,,,,,,,, -116400,0.71391314,1.4111944,,,,,,,,,,,,,,,,, -116500,0.7124197,1.3744739,,,,,,,,,,,,,,,,, -116600,0.76726615,1.3547987,,,,,,,,,,,,,,,,, -116700,0.7379277,1.4405835,,,,,,,,,,,,,,,,, -116800,0.751013,1.4403622,,,,,,,,,,,,,,,,, -116804,,,0.6963258385658264,1.4070066213607788,35.96184061155457,0.6927750110626221,1.415552258491516,30.68112671416496,3000.0,0.7087211608886719,1.3230842351913452,30.802705175820623,3003.0,41190.77096199989,68200.31424498558,41190.77096199989,27004.1587600708,1.687828779220581,0.0 -116900,0.76860034,1.4321923,,,,,,,,,,,,,,,,, -117000,0.76339793,1.3457865,,,,,,,,,,,,,,,,, -117100,0.72190326,1.4143662,,,,,,,,,,,,,,,,, -117200,0.7790015,1.4208909,,,,,,,,,,,,,,,,, -117300,0.7899764,1.4737356,,,,,,,,,,,,,,,,, -117400,0.752622,1.425484,,,,,,,,,,,,,,,,, -117500,0.7617242,1.3862139,,,,,,,,,,,,,,,,, -117600,0.7513908,1.3880426,,,,,,,,,,,,,,,,, -117700,0.73038673,1.3771025,,,,,,,,,,,,,,,,, -117800,0.74438125,1.3764027,,,,,,,,,,,,,,,,, -117900,0.76564634,1.3696755,,,,,,,,,,,,,,,,, -118000,0.73483276,1.29301,,,,,,,,,,,,,,,,, -118100,0.7547256,1.3383651,,,,,,,,,,,,,,,,, -118200,0.75322837,1.4293282,,,,,,,,,,,,,,,,, -118300,0.74312264,1.3938322,,,,,,,,,,,,,,,,, -118400,0.7663346,1.3812659,,,,,,,,,,,,,,,,, -118500,0.7515804,1.5013498,,,,,,,,,,,,,,,,, -118600,0.79738307,1.5502737,,,,,,,,,,,,,,,,, -118700,0.7540745,1.4533048,,,,,,,,,,,,,,,,, -118800,0.77634937,1.3186755,,,,,,,,,,,,,,,,, -118900,0.7817979,1.4106767,,,,,,,,,,,,,,,,, -119000,0.77485853,1.3756298,,,,,,,,,,,,,,,,, -119100,0.7646254,1.3403288,,,,,,,,,,,,,,,,, -119188,,,0.7110795378684998,1.327668070793152,37.20115594615644,0.6929734349250793,1.4138745069503784,30.86193975348537,3000.0,0.7092092633247375,1.32043719291687,30.994883238838508,3003.0,42030.960902929306,69574.35052037239,42030.960902929306,27537.889677286148,1.728813409805298,0.0 -119200,0.74322146,1.4118329,,,,,,,,,,,,,,,,, -119300,0.78708583,1.3714747,,,,,,,,,,,,,,,,, -119400,0.77612686,1.4029686,,,,,,,,,,,,,,,,, -119500,0.76962084,1.3797014,,,,,,,,,,,,,,,,, -119600,0.7627725,1.3904195,,,,,,,,,,,,,,,,, -119700,0.72915727,1.3142127,,,,,,,,,,,,,,,,, -119800,0.7685909,1.397898,,,,,,,,,,,,,,,,, -119900,0.75969976,1.3852862,,,,,,,,,,,,,,,,, -120000,0.77262664,1.378047,,,,,,,,,,,,,,,,, -120100,0.7545254,1.3596346,,,,,,,,,,,,,,,,, -120200,0.77414125,1.3990302,,,,,,,,,,,,,,,,, -120300,0.76939714,1.3960916,,,,,,,,,,,,,,,,, -120400,0.8031572,1.431254,,,,,,,,,,,,,,,,, -120500,0.7683781,1.3172243,,,,,,,,,,,,,,,,, -120600,0.81377935,1.4140756,,,,,,,,,,,,,,,,, -120700,0.8160536,1.3759995,,,,,,,,,,,,,,,,, -120800,0.76672965,1.4089649,,,,,,,,,,,,,,,,, -120900,0.79516023,1.3335363,,,,,,,,,,,,,,,,, -121000,0.7602905,1.3440316,,,,,,,,,,,,,,,,, -121100,0.7576947,1.4180554,,,,,,,,,,,,,,,,, -121200,0.7674746,1.3282766,,,,,,,,,,,,,,,,, -121300,0.7474697,1.3148907,,,,,,,,,,,,,,,,, -121400,0.7941837,1.3631045,,,,,,,,,,,,,,,,, -121500,0.7555129,1.4717491,,,,,,,,,,,,,,,,, -121572,,,0.7089225053787231,1.3409868478775024,36.699779532755656,0.6929982304573059,1.4128586053848269,30.743967260596456,3000.0,0.7098135352134705,1.3198758363723757,30.807151339414787,3003.0,42870.91543865204,70966.79245257378,42870.91543865204,28090.25924921036,1.7727289199829102,0.0 -121600,0.7725278,1.326179,,,,,,,,,,,,,,,,, -121700,0.80102795,1.3474457,,,,,,,,,,,,,,,,, -121800,0.79736394,1.417722,,,,,,,,,,,,,,,,, -121900,0.7973575,1.3561598,,,,,,,,,,,,,,,,, -122000,0.7826694,1.331216,,,,,,,,,,,,,,,,, -122100,0.78416026,1.4132233,,,,,,,,,,,,,,,,, -122200,0.821879,1.3686925,,,,,,,,,,,,,,,,, -122300,0.75673914,1.3447287,,,,,,,,,,,,,,,,, -122400,0.80352294,1.4685891,,,,,,,,,,,,,,,,, -122500,0.7932455,1.3538157,,,,,,,,,,,,,,,,, -122600,0.77523386,1.3472412,,,,,,,,,,,,,,,,, -122700,0.79484415,1.4425032,,,,,,,,,,,,,,,,, -122800,0.79225135,1.308186,,,,,,,,,,,,,,,,, -122900,0.7997876,1.378834,,,,,,,,,,,,,,,,, -123000,0.7842967,1.3745253,,,,,,,,,,,,,,,,, -123100,0.79913884,1.352323,,,,,,,,,,,,,,,,, -123200,0.7803296,1.4017764,,,,,,,,,,,,,,,,, -123300,0.8243646,1.3490776,,,,,,,,,,,,,,,,, -123400,0.78916174,1.3126371,,,,,,,,,,,,,,,,, -123500,0.7903181,1.400762,,,,,,,,,,,,,,,,, -123600,0.82511413,1.3564718,,,,,,,,,,,,,,,,, -123700,0.80387324,1.3834777,,,,,,,,,,,,,,,,, -123800,0.8004527,1.3734848,,,,,,,,,,,,,,,,, -123900,0.79991466,1.4002634,,,,,,,,,,,,,,,,, -123956,,,0.7044199109077454,1.3619587421417236,36.63037870395407,0.6941761374473572,1.409651756286621,30.92822877713484,3000.0,0.7102085947990417,1.3139033317565918,31.06983647732182,3003.0,43710.88386774063,72332.34792470932,43710.88386774063,28615.730256080627,1.814807415008545,0.0 -124000,0.78422236,1.3451599,,,,,,,,,,,,,,,,, -124100,0.7990505,1.3578868,,,,,,,,,,,,,,,,, -124200,0.8092689,1.3884917,,,,,,,,,,,,,,,,, -124300,0.813294,1.4006749,,,,,,,,,,,,,,,,, -124400,0.8229496,1.4098747,,,,,,,,,,,,,,,,, -124500,0.80145586,1.4313961,,,,,,,,,,,,,,,,, -124600,0.8155123,1.3869996,,,,,,,,,,,,,,,,, -124700,0.76691234,1.4037373,,,,,,,,,,,,,,,,, -124800,0.7863633,1.2931212,,,,,,,,,,,,,,,,, -124900,0.83742654,1.3603667,,,,,,,,,,,,,,,,, -125000,0.8228392,1.2722147,,,,,,,,,,,,,,,,, -125100,0.8280682,1.4159757,,,,,,,,,,,,,,,,, -125200,0.77027285,1.3826659,,,,,,,,,,,,,,,,, -125300,0.8206331,1.3827889,,,,,,,,,,,,,,,,, -125400,0.78019464,1.3915472,,,,,,,,,,,,,,,,, -125500,0.8006535,1.379196,,,,,,,,,,,,,,,,, -125600,0.84394914,1.3651265,,,,,,,,,,,,,,,,, -125700,0.8106063,1.3515623,,,,,,,,,,,,,,,,, -125800,0.8308254,1.3999484,,,,,,,,,,,,,,,,, -125900,0.7929588,1.3121294,,,,,,,,,,,,,,,,, -126000,0.8409401,1.3785094,,,,,,,,,,,,,,,,, -126100,0.8167555,1.3730894,,,,,,,,,,,,,,,,, -126200,0.8291469,1.4033091,,,,,,,,,,,,,,,,, -126300,0.80316615,1.383001,,,,,,,,,,,,,,,,, -126339,,,0.7150582075119019,1.3104259967803955,36.93187865217039,0.6944612860679626,1.411051869392395,30.66054748234449,3000.0,0.7101853489875793,1.3157200813293457,30.995070613341326,3003.0,44550.9148273468,73694.35182857513,44550.9148273468,29137.58055591584,1.8581435680389404,0.0 -126400,0.8401384,1.4107411,,,,,,,,,,,,,,,,, -126500,0.815524,1.31537,,,,,,,,,,,,,,,,, -126600,0.83814216,1.4140805,,,,,,,,,,,,,,,,, -126700,0.8153953,1.3667537,,,,,,,,,,,,,,,,, -126800,0.7899538,1.3369157,,,,,,,,,,,,,,,,, -126900,0.8092709,1.4287695,,,,,,,,,,,,,,,,, -127000,0.79094374,1.3024696,,,,,,,,,,,,,,,,, -127100,0.83928984,1.3429364,,,,,,,,,,,,,,,,, -127200,0.82047695,1.3781365,,,,,,,,,,,,,,,,, -127300,0.7995368,1.3919756,,,,,,,,,,,,,,,,, -127400,0.819096,1.5191166,,,,,,,,,,,,,,,,, -127500,0.80591786,1.3320545,,,,,,,,,,,,,,,,, -127600,0.84217846,1.4159763,,,,,,,,,,,,,,,,, -127700,0.8379217,1.3669587,,,,,,,,,,,,,,,,, -127800,0.82552576,1.3916258,,,,,,,,,,,,,,,,, -127900,0.816499,1.321926,,,,,,,,,,,,,,,,, -128000,0.83401275,1.3790954,,,,,,,,,,,,,,,,, -128100,0.8340176,1.3321831,,,,,,,,,,,,,,,,, -128200,0.81722784,1.2958708,,,,,,,,,,,,,,,,, -128300,0.78066224,1.301349,,,,,,,,,,,,,,,,, -128400,0.8117229,1.3441654,,,,,,,,,,,,,,,,, -128500,0.82707304,1.2946587,,,,,,,,,,,,,,,,, -128600,0.8283779,1.3731298,,,,,,,,,,,,,,,,, -128700,0.8186079,1.3912231,,,,,,,,,,,,,,,,, -128723,,,0.7099118232727051,1.337271809577942,36.93021401210776,0.6949200630187988,1.4093292951583862,30.931063555305773,3000.0,0.7109407186508179,1.3132773637771606,31.110231469925782,3003.0,45391.13405776024,75083.51441955566,45391.13405776024,29686.40299320221,1.902180433273316,0.0 -128800,0.81333625,1.3048455,,,,,,,,,,,,,,,,, -128900,0.8102985,1.3817567,,,,,,,,,,,,,,,,, -129000,0.83143413,1.416617,,,,,,,,,,,,,,,,, -129100,0.78303295,1.3479909,,,,,,,,,,,,,,,,, -129200,0.8199688,1.3640516,,,,,,,,,,,,,,,,, -129300,0.83095753,1.388446,,,,,,,,,,,,,,,,, -129400,0.8660396,1.3570281,,,,,,,,,,,,,,,,, -129500,0.8005993,1.3166788,,,,,,,,,,,,,,,,, -129600,0.8072997,1.3443426,,,,,,,,,,,,,,,,, -129700,0.84247184,1.3228375,,,,,,,,,,,,,,,,, -129800,0.81696624,1.3125428,,,,,,,,,,,,,,,,, -129900,0.82938194,1.3891996,,,,,,,,,,,,,,,,, -130000,0.80418634,1.2979692,,,,,,,,,,,,,,,,, -130100,0.797359,1.338191,,,,,,,,,,,,,,,,, -130200,0.82876587,1.3839748,,,,,,,,,,,,,,,,, -130300,0.837217,1.354922,,,,,,,,,,,,,,,,, -130400,0.7838867,1.3822294,,,,,,,,,,,,,,,,, -130500,0.8362021,1.3269595,,,,,,,,,,,,,,,,, -130600,0.7991824,1.3670564,,,,,,,,,,,,,,,,, -130700,0.7912655,1.3174233,,,,,,,,,,,,,,,,, -130800,0.8158257,1.3711782,,,,,,,,,,,,,,,,, -130900,0.80399674,1.3919042,,,,,,,,,,,,,,,,, -131000,0.78388816,1.3731339,,,,,,,,,,,,,,,,, -131100,0.8105445,1.3922513,,,,,,,,,,,,,,,,, -131106,,,0.7094533443450928,1.3412723541259766,36.79860717293863,0.6949820518493652,1.4096693992614746,30.92847868651178,3000.0,0.710777997970581,1.3123667240142822,31.17334827670532,3003.0,46231.019748449326,76469.67206168175,46231.019748449326,30232.55574464798,1.945484161376953,0.0 -131200,0.80743784,1.3777373,,,,,,,,,,,,,,,,, -131300,0.7945691,1.3627187,,,,,,,,,,,,,,,,, -131400,0.8205682,1.3449605,,,,,,,,,,,,,,,,, -131500,0.8055134,1.2996026,,,,,,,,,,,,,,,,, -131600,0.82722646,1.4148571,,,,,,,,,,,,,,,,, -131700,0.83168507,1.4104061,,,,,,,,,,,,,,,,, -131800,0.77843,1.3605585,,,,,,,,,,,,,,,,, -131900,0.7931131,1.2712852,,,,,,,,,,,,,,,,, -132000,0.81140834,1.4321029,,,,,,,,,,,,,,,,, -132100,0.8185827,1.3371919,,,,,,,,,,,,,,,,, -132200,0.81202096,1.3572118,,,,,,,,,,,,,,,,, -132300,0.82786816,1.3933624,,,,,,,,,,,,,,,,, -132400,0.79213935,1.3128248,,,,,,,,,,,,,,,,, -132500,0.8087013,1.3574301,,,,,,,,,,,,,,,,, -132600,0.81340945,1.397083,,,,,,,,,,,,,,,,, -132700,0.79511875,1.3443449,,,,,,,,,,,,,,,,, -132800,0.80660886,1.3940783,,,,,,,,,,,,,,,,, -132900,0.803988,1.3693217,,,,,,,,,,,,,,,,, -133000,0.806468,1.3027664,,,,,,,,,,,,,,,,, -133100,0.828774,1.4064863,,,,,,,,,,,,,,,,, -133200,0.81862485,1.3754576,,,,,,,,,,,,,,,,, -133300,0.7776787,1.3021781,,,,,,,,,,,,,,,,, -133333,,,0.7125383019447327,1.3173706531524658,36.68598863876687,0.6947216987609863,1.4095044136047363,30.934123880652557,3000.0,0.7105455994606018,1.312601923942566,31.11336078741953,3003.0,47015.76848649979,77788.48857069016,47015.76848649979,30766.50728940964,1.9903368949890137,0.0 -133333,,,,,,,,,,,,,,47015.76848649979,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 3dfa4aebd..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -877.5347895622253,0.0,30.8999125957489,1,0,30.8999125957489,0.0007088489946909,0.0,11.19086742401123,3003,908.4347548484802,0.0006454141112044,0.0,11.175609588623049,0.0004835649742744,0.0,11.208685874938965,3000 -1403.2590272426603,0.02254319190979,870.8914546966553,2382,0,870.8914546966553,0.5347859263420105,18.522566174985347,2.591620922088623,3003,2274.2477111816406,0.5327548980712891,24.206353735215085,2.6247782707214355,0.537860631942749,20.04948386044301,2.565381050109864,3000 -1915.1850049495697,0.0476388931274414,1711.0473158359528,4766,0,1711.0473158359528,0.5945848822593689,21.990632388952925,2.082348823547364,3003,3626.431683540344,0.5796718597412109,26.36870746881244,2.213855266571045,0.5924167037010193,23.469450249820703,2.0987660884857178,3000 -2407.65248298645,0.0771441459655761,2551.1354496479034,7150,0,2551.1354496479034,0.6053221821784973,22.893728461005857,2.0033843517303467,3003,4959.091591835022,0.5893054604530334,27.176336003425604,2.118350028991699,0.6037743091583252,24.357176752244964,2.024528980255127,3000 -2982.4601430892944,0.1054227352142334,3391.3415517807007,9535,0,3391.3415517807007,0.6162105798721313,23.417765370115543,1.9303609132766724,3003,6374.207869529724,0.5954850912094116,27.71182659066221,2.119326114654541,0.6125404238700867,24.82005795946175,1.9642142057418823,3000 -3488.283415555954,0.1319782733917236,4231.374727487564,11917,0,4231.374727487564,0.6186973452568054,22.844513408768293,1.9178035259246824,3003,7720.172181844711,0.5949356555938721,27.32630178120262,2.093808650970459,0.6108169555664062,23.938898942260515,1.9556013345718384,3000 -3977.3424422740936,0.1598775386810302,5071.90565609932,14301,0,5071.90565609932,0.621997594833374,23.81654558217464,1.8840601444244385,3003,9049.865262746813,0.597704291343689,28.308626342993755,2.0802969932556152,0.6167685389518738,25.06334814984212,1.930830478668213,3000 -4463.997585058212,0.1864314079284668,5911.835096359253,16684,0,5911.835096359253,0.6254836916923523,23.955626051084025,1.8665118217468264,3003,10376.551353693008,0.5992457866668701,28.287729139742407,2.077332258224488,0.6202650666236877,25.16760638686699,1.911595344543457,3000 -4972.11691904068,0.2197494506835937,6751.987172842026,19068,0,6751.987172842026,0.6246935129165649,24.06795003200436,1.8686076402664185,3003,11724.933151721954,0.6137471199035645,29.14411802678976,1.926009178161621,0.6178720593452454,25.27364721195347,1.9107155799865725,3000 -5525.0277309417725,0.2484176158905029,7591.907975435257,21452,0,7591.907975435257,0.6274592280387878,24.41130121879812,1.8478097915649407,3003,13117.867561101912,0.6039170026779175,28.499650206528244,2.04196834564209,0.6211702227592468,25.326143244982028,1.893813848495484,3000 -6023.364975690842,0.2760303020477295,8432.142497062683,23837,0,8432.142497062683,0.6271222233772278,24.527137407140724,1.846617221832276,3003,14456.54249548912,0.6014580130577087,28.452707209093468,2.0570669174194336,0.6210214495658875,25.541055572678147,1.8938277959823608,3000 -6616.417924404144,0.303934097290039,9272.08500289917,26220,0,9272.08500289917,0.6297716498374939,24.47234568073476,1.828667402267456,3003,15889.645047426224,0.6070371866226196,28.626233950948446,2.0028188228607178,0.6237616539001465,25.473166487207465,1.874969720840454,3000 -7134.273945808411,0.3334333896636963,10112.03343486786,28604,0,10112.03343486786,0.6322584748268127,24.058762358703923,1.824683427810669,3003,17247.554879665375,0.6041083335876465,28.460018032548263,2.042267322540283,0.6240839958190918,25.687265643742247,1.880605697631836,3000 -7647.579026222229,0.3638780117034912,10952.269091129305,30989,0,10952.269091129305,0.6323397755622864,24.526306536985867,1.8168457746505733,3003,18601.20199108124,0.6057415008544922,28.31511913979044,2.0303850173950195,0.6273201704025269,25.75657484227801,1.8572407960891724,3000 -8246.12324142456,0.3925106525421142,11792.17632174492,33374,0,11792.17632174492,0.6319795846939087,24.44685371781997,1.811674952507019,3003,20039.75516843796,0.60844486951828,28.007482098976688,2.00085997581482,0.6259438991546631,25.505005927829227,1.856594443321228,3000 -8765.0341360569,0.4228677749633789,12632.082436800005,35758,0,12632.082436800005,0.6337923407554626,24.817481823924734,1.8009341955184937,3003,21398.67829966545,0.6062001585960388,28.496158366728405,2.012900114059448,0.6244683861732483,25.335470426237546,1.8566988706588743,3000 -9230.926263809204,0.4587104320526123,13472.220349311829,38143,0,13472.220349311829,0.6382081508636475,25.00556614291527,1.7820231914520264,3003,22704.81827187538,0.6146321892738342,29.38206898813265,1.95678985118866,0.628361701965332,25.92643316412883,1.8436554670333865,3000 -9721.91780424118,0.4886271953582763,14312.36628293991,40527,0,14312.36628293991,0.6370925903320312,24.96963784239136,1.7789210081100464,3003,24036.06175875664,0.611473023891449,28.94678678261578,1.9678258895874023,0.6278533339500427,25.68769018292687,1.835180640220642,3000 -10253.97549557686,0.5196661949157715,15152.403260946274,42911,0,15152.403260946274,0.6399396061897278,25.392114429594827,1.767079472541809,3003,25408.2639605999,0.6131593585014343,29.0899991882196,1.972790241241455,0.631461501121521,26.09814806128841,1.8121095895767207,3000 -10814.103539466858,0.5507168769836426,15992.33811044693,45296,0,15992.33811044693,0.6396374702453613,24.73479737642498,1.7606006860733032,3003,26808.432076215744,0.6127845048904419,28.965295301006385,1.955664873123169,0.6328253746032715,25.96706766904877,1.8045439720153809,3000 -11321.025541305542,0.5871026515960693,16832.49195432663,47681,0,16832.49195432663,0.6436465382575989,25.42608692174256,1.74813711643219,3003,28155.61889219284,0.6143538355827332,29.107341526376448,1.9528555870056152,0.6314491033554077,25.94393430285941,1.8058665990829468,3000 -11801.74766111374,0.6183607578277588,17672.648672819138,50064,0,17672.648672819138,0.6434489488601685,25.14636965746108,1.7467548847198486,3003,29476.605541706085,0.6843953728675842,34.336430518788745,1.47307026386261,0.635949969291687,26.441084450311465,1.7930233478546145,3000 -12377.643058538437,0.6505851745605469,18512.559674978256,52447,0,18512.559674978256,0.6453315019607544,25.591061740837908,1.7311948537826538,3003,30892.52356314659,0.6166776418685913,28.009326838866563,1.935380220413208,0.6369170546531677,26.137517861350776,1.7796038389205933,3000 -12930.402215003967,0.6834166049957275,19352.45453119278,54831,0,19352.45453119278,0.6477136611938477,25.65318269635807,1.711084008216858,3003,32285.287123441696,0.6168438792228699,29.60037342386599,1.9355334043502808,0.639086902141571,26.58346724934059,1.7692811489105225,3000 -13528.37666606903,0.7170102596282959,20192.690421819687,57216,0,20192.690421819687,0.6459241509437561,25.75472864828333,1.7121566534042358,3003,33723.60852479935,0.6217406392097473,30.00592278859685,1.8886256217956543,0.6395828723907471,26.48839796432684,1.764250636100769,3000 -14096.32988357544,0.7592217922210693,21033.113726854324,59601,0,21033.113726854324,0.6510720252990723,26.27622806513249,1.687710523605347,3003,35132.104048252106,0.6205396056175232,29.974718274766825,1.9234768152236936,0.6435505747795105,26.64118576387656,1.7510372400283811,3000 -14926.03948044777,0.7936761379241943,21873.043880462646,61986,0,21873.043880462646,0.6511649489402771,25.027113486174706,1.6913459300994873,3003,36801.8515625,0.620170533657074,29.704826058174017,1.924217939376831,0.6417403221130371,24.938980730325103,1.7392182350158691,3000 -15509.657069206238,0.826836109161377,22713.07751083374,64370,0,22713.07751083374,0.6523153781890869,26.33216731822289,1.6770296096801758,3003,38225.61360192299,0.6249828934669495,30.12080442859549,1.8825191259384155,0.6434885859489441,27.00111898618707,1.73857581615448,3000 -16021.18220758438,0.8614444732666016,23553.293513298035,66755,0,23553.293513298035,0.6545000672340393,26.066548963881,1.6544857025146484,3003,39577.4663040638,0.6223474144935608,29.60616800850276,1.897873878479004,0.6448525190353394,26.909955268306703,1.712667465209961,3000 -16597.801033496857,0.8966538906097412,24393.23588967324,69140,0,24393.23588967324,0.6583929061889648,26.70894910112787,1.6351416110992432,3003,40994.1373796463,0.6475690007209778,31.34708892707513,1.708027482032776,0.6484978199005127,27.0442542027674,1.7034063339233398,3000 -17131.623816490173,0.9319117069244384,25233.399693965912,71525,0,25233.399693965912,0.6592528223991394,26.910900978321823,1.626090168952942,3003,42368.23453712464,0.6270906329154968,30.49007289439116,1.85956346988678,0.6483118534088135,27.250290188229883,1.6941295862197876,3000 -17622.726563453674,0.9732100963592528,26073.622700452805,73910,0,26073.622700452805,0.658288300037384,26.540084120816232,1.6216771602630615,3003,43699.67828798294,0.6320022940635681,30.474995803724568,1.837198734283448,0.6498989462852478,27.39593869430475,1.6905559301376345,3000 -18340.518897771835,1.016244649887085,26913.55896472931,76295,0,26913.55896472931,0.6646563410758972,27.065361558941973,1.5929890871047974,3003,45257.52446436882,0.6367040872573853,31.028093772921565,1.7800288200378418,0.6544989943504333,27.57831994512013,1.6648274660110474,3000 -18900.15974450112,1.0552592277526855,27753.546751499176,78680,0,27753.546751499176,0.6647725701332092,27.32055158379405,1.5881857872009275,3003,46657.269523620605,0.6351809501647949,30.51627421938524,1.8063995838165283,0.6567060351371765,27.850444978291744,1.6517740488052368,3000 -19485.708025693893,1.0978131294250488,28593.6186144352,81065,0,28593.6186144352,0.6671895980834961,27.59701393455958,1.572178840637207,3003,48083.00713443756,0.6338257789611816,30.37254047659617,1.8154295682907104,0.657177209854126,27.84527507937404,1.6409265995025637,3000 -20160.911412000656,1.1356170177459717,29433.67586684227,83449,0,29433.67586684227,0.6720237135887146,28.04304378534904,1.551444411277771,3003,49598.38488411904,0.6420011520385742,31.104798039546385,1.7547723054885864,0.6573383808135986,28.031213845797343,1.6285314559936523,3000 -20746.017300128937,1.178145170211792,30273.84391236305,85834,0,30273.84391236305,0.673267126083374,28.295872994291173,1.5379685163497925,3003,51023.77744102478,0.6379316449165344,31.381088323768665,1.7899552583694458,0.6602894067764282,27.88452353796739,1.61657977104187,3000 -21216.517703533173,1.216465711593628,31114.019277334213,88220,0,31114.019277334213,0.6726628541946411,28.065620186470788,1.5305873155593872,3003,52334.56629276276,0.6552841663360596,32.29373394632669,1.6534067392349243,0.6635503768920898,28.25207895536669,1.599921703338623,3000 -21782.82957220077,1.2534148693084717,31953.927941083908,90603,0,31953.927941083908,0.6796583533287048,28.84191883449241,1.5105689764022827,3003,53740.90544724464,0.6459221243858337,31.6283829980775,1.725430607795715,0.6639595031738281,28.32560009538912,1.5904563665390017,3000 -22296.06116771698,1.2910587787628174,32794.00373888016,92989,0,32794.00373888016,0.6814130544662476,28.43917903285333,1.4934310913085938,3003,55094.32453846932,0.6447259783744812,31.694070491390008,1.7305859327316284,0.6674808859825134,28.977459287447072,1.5697031021118164,3000 -22882.23731899261,1.3375027179718018,33634.21266889572,95375,0,33634.21266889572,0.6810063719749451,28.512990721724428,1.4797627925872805,3003,56520.830107450485,0.6559216380119324,32.57261780221355,1.6479884386062622,0.6692167520523071,28.733379976232605,1.5522282123565674,3000 -23430.37916445732,1.3787147998809814,34474.18307328224,97760,0,34474.18307328224,0.6845040917396545,28.75097303659052,1.4635298252105713,3003,57909.06037116051,0.6507320404052734,32.16376730049248,1.6990152597427368,0.6716469526290894,28.933022568504928,1.5430229902267456,3000 -24041.794478416443,1.417504072189331,35314.18997120857,100145,0,35314.18997120857,0.6881529688835144,29.32560119300237,1.4465465545654297,3003,59360.59610676765,0.6949189305305481,35.23068643277161,1.418954849243164,0.673246443271637,29.01353235327636,1.528011441230774,3000 -24610.273792028427,1.4581010341644287,36154.3091506958,102530,0,36154.3091506958,0.6910231709480286,29.189676258834265,1.4268368482589722,3003,60769.31097340584,0.6625471115112305,32.88659371950133,1.6055678129196167,0.6756394505500793,29.382846968354546,1.5099588632583618,3000 -25181.71244192124,1.5033392906188965,36994.25947546959,104916,0,36994.25947546959,0.6924060583114624,29.2301595025155,1.420128583908081,3003,62180.81967806816,0.6604390144348145,32.50623094459274,1.628412842750549,0.6783796548843384,29.517135048861,1.500727891921997,3000 -25798.53511714936,1.545297145843506,37834.20816755295,107300,0,37834.20816755295,0.693591296672821,29.43719279595945,1.4022724628448486,3003,63637.71153593063,0.6785178780555725,33.844285592149795,1.51469624042511,0.6805495023727417,29.774191434883186,1.4880130290985107,3000 -26334.07120013237,1.5848398208618164,38674.33937954903,109684,0,38674.33937954903,0.6969845294952393,29.99551478418762,1.3863232135772705,3003,65013.49522304535,0.6720339059829712,32.979392765729266,1.5570589303970337,0.6831905245780945,30.14964707735585,1.4739441871643066,3000 -26838.75919866562,1.627079963684082,39514.344621658325,112068,0,39514.344621658325,0.6990064382553101,29.9190960808925,1.3751919269561768,3003,66358.30993127823,0.6718950271606445,33.424706492563324,1.5602552890777588,0.6845420598983765,30.09190590265441,1.466409683227539,3000 -27475.601682901382,1.6685771942138672,40354.538182258606,114454,0,40354.538182258606,0.7002847194671631,30.209978551978672,1.361093282699585,3003,67835.46319293976,0.68068528175354,34.31321273170774,1.5048444271087646,0.6864762902259827,30.02917097359945,1.4497908353805542,3000 -28011.064224004745,1.7157025337219238,41194.56631875038,116839,0,41194.56631875038,0.7036197781562805,30.64315992147018,1.350566267967224,3003,69211.07856273651,0.6848151087760925,33.9288495000838,1.493071436882019,0.6874682307243347,30.30958769276966,1.4417047500610352,3000 -28580.59091734886,1.7584803104400637,42034.72462916374,119225,0,42034.72462916374,0.7048050761222839,30.31912476620588,1.3403512239456177,3003,70620.8813958168,0.6958141326904297,35.5492473489212,1.417291522026062,0.6902580261230469,30.490945218469427,1.4339897632598877,3000 -29180.690301418304,1.8024253845214844,42874.72810292244,121610,0,42874.72810292244,0.7055488228797913,30.548803892158418,1.332283616065979,3003,72061.10469961166,0.6900114417076111,34.7617892280219,1.4532064199447632,0.6903696060180664,30.629143421025496,1.4234853982925415,3000 -29717.70730495453,1.845533847808838,43714.89873600006,123995,0,43714.89873600006,0.7058973908424377,30.80911951363941,1.32646906375885,3003,73438.41144561768,0.6892982721328735,34.6818584249282,1.4594749212265017,0.6919319033622742,30.84775682494113,1.4199069738388062,3000 -30282.248594284058,1.8876619338989256,44555.02311301232,126380,0,44555.02311301232,0.7090000510215759,30.86978933852004,1.3189260959625244,3003,74843.19479346275,0.697314977645874,36.055260217867456,1.4237242937088013,0.6940025687217712,30.82466800846016,1.413859486579895,3000 -30855.69163513184,1.931976318359375,45395.007142305374,128765,0,45395.007142305374,0.7091976404190063,30.909184782342955,1.3158727884292605,3003,76256.74099755287,0.7001673579216003,36.05400553646607,1.3921157121658323,0.6933701634407043,30.57451985032652,1.4130243062973022,3000 -31395.65654921532,1.9757180213928225,46235.168941020966,131150,0,46235.168941020966,0.7093021869659424,30.79200209401173,1.3146555423736572,3003,77636.98960375786,0.6955586671829224,36.22844690063096,1.4250540733337402,0.6936181783676147,30.819000657109928,1.4108177423477173,3000 -31937.92197227478,2.0215542316436768,47003.91788029671,133333,0,47003.91788029671,0.70927894115448,30.84290524056852,1.3146251440048218,3003,78948.11796474457,0.6961641311645508,36.12654870431947,1.423098087310791,0.6939653754234314,30.919892623504534,1.4107656478881836,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/measurements.csv deleted file mode 100644 index ab5b558e0..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.1584854,11.173096,,,,,,,,,,,,,,,,, -1,,,0.0006454141112044,11.175609588623049,0.0,0.0004835649742744,11.208685874938965,0.0,3000.0,0.0007088489946909,11.19086742401123,0.0,3003.0,30.8999125957489,908.4347548484802,30.8999125957489,877.5347895622253,0.0,0.0 -100,0.9120364,7.576749,,,,,,,,,,,,,,,,, -200,0.6719931,6.5999084,,,,,,,,,,,,,,,,, -300,0.5261837,5.804872,,,,,,,,,,,,,,,,, -400,0.4605179,5.349124,,,,,,,,,,,,,,,,, -500,0.6017723,5.0637417,,,,,,,,,,,,,,,,, -600,0.534091,4.7921677,,,,,,,,,,,,,,,,, -700,0.50313216,4.5757413,,,,,,,,,,,,,,,,, -800,0.46413988,4.175586,,,,,,,,,,,,,,,,, -900,0.442608,4.074409,,,,,,,,,,,,,,,,, -1000,0.42510727,3.7905827,,,,,,,,,,,,,,,,, -1100,0.36903572,3.592195,,,,,,,,,,,,,,,,, -1200,0.57173824,3.616431,,,,,,,,,,,,,,,,, -1300,0.3741342,3.4885936,,,,,,,,,,,,,,,,, -1400,0.3712481,3.3654962,,,,,,,,,,,,,,,,, -1500,0.27076098,3.2276917,,,,,,,,,,,,,,,,, -1600,0.26548573,3.0955732,,,,,,,,,,,,,,,,, -1700,0.2645238,2.998419,,,,,,,,,,,,,,,,, -1800,0.23360653,2.9169614,,,,,,,,,,,,,,,,, -1900,0.28621972,2.9298196,,,,,,,,,,,,,,,,, -2000,0.23085307,2.9039712,,,,,,,,,,,,,,,,, -2100,0.2871246,2.7874963,,,,,,,,,,,,,,,,, -2200,0.231484,2.7511895,,,,,,,,,,,,,,,,, -2300,0.2463208,2.715552,,,,,,,,,,,,,,,,, -2382,,,0.5327548980712891,2.6247782707214355,24.206353735215085,0.537860631942749,2.565381050109864,20.04948386044301,3000.0,0.5347859263420105,2.591620922088623,18.522566174985347,3003.0,870.8914546966553,2274.2477111816406,870.8914546966553,1403.2590272426603,0.02254319190979,0.0 -2400,0.22320266,2.6450233,,,,,,,,,,,,,,,,, -2500,0.2702448,2.6548362,,,,,,,,,,,,,,,,, -2600,0.20373525,2.477144,,,,,,,,,,,,,,,,, -2700,0.30140772,2.5108051,,,,,,,,,,,,,,,,, -2800,0.20846424,2.4710326,,,,,,,,,,,,,,,,, -2900,0.21569899,2.4159162,,,,,,,,,,,,,,,,, -3000,0.19579476,2.4818466,,,,,,,,,,,,,,,,, -3100,0.23698457,2.3927662,,,,,,,,,,,,,,,,, -3200,0.53971606,2.3762112,,,,,,,,,,,,,,,,, -3300,0.36811438,2.3664236,,,,,,,,,,,,,,,,, -3400,0.36867535,2.393944,,,,,,,,,,,,,,,,, -3500,0.37933743,2.323135,,,,,,,,,,,,,,,,, -3600,0.21555902,2.3386583,,,,,,,,,,,,,,,,, -3700,0.38449967,2.258017,,,,,,,,,,,,,,,,, -3800,0.40978748,2.3632214,,,,,,,,,,,,,,,,, -3900,0.24523932,2.2997625,,,,,,,,,,,,,,,,, -4000,0.3034655,2.307016,,,,,,,,,,,,,,,,, -4100,0.43867955,2.2307115,,,,,,,,,,,,,,,,, -4200,0.9533927,2.250341,,,,,,,,,,,,,,,,, -4300,0.40942362,2.2291045,,,,,,,,,,,,,,,,, -4400,0.29540914,2.2568324,,,,,,,,,,,,,,,,, -4500,0.7420628,2.189945,,,,,,,,,,,,,,,,, -4600,0.22798173,2.2295685,,,,,,,,,,,,,,,,, -4700,0.26891887,2.0984516,,,,,,,,,,,,,,,,, -4766,,,0.5796718597412109,2.213855266571045,26.36870746881244,0.5924167037010193,2.0987660884857178,23.469450249820703,3000.0,0.5945848822593689,2.082348823547364,21.990632388952925,3003.0,1711.0473158359528,3626.431683540344,1711.0473158359528,1915.1850049495697,0.0476388931274414,0.0 -4800,0.3454454,2.1811283,,,,,,,,,,,,,,,,, -4900,0.4182499,2.2185144,,,,,,,,,,,,,,,,, -5000,0.55948806,2.2389207,,,,,,,,,,,,,,,,, -5100,0.5500541,2.2270498,,,,,,,,,,,,,,,,, -5200,0.64561296,2.2233684,,,,,,,,,,,,,,,,, -5300,0.90033644,2.1742742,,,,,,,,,,,,,,,,, -5400,0.45718333,2.1769514,,,,,,,,,,,,,,,,, -5500,0.5643282,2.121436,,,,,,,,,,,,,,,,, -5600,0.51557875,2.194338,,,,,,,,,,,,,,,,, -5700,0.5684475,2.1490862,,,,,,,,,,,,,,,,, -5800,0.31239566,2.241068,,,,,,,,,,,,,,,,, -5900,0.46157867,2.2009869,,,,,,,,,,,,,,,,, -6000,0.31899387,2.1391034,,,,,,,,,,,,,,,,, -6100,0.26976103,2.147893,,,,,,,,,,,,,,,,, -6200,0.6110589,2.218746,,,,,,,,,,,,,,,,, -6300,0.5518816,2.192727,,,,,,,,,,,,,,,,, -6400,0.4172338,2.2077541,,,,,,,,,,,,,,,,, -6500,0.32801536,2.190568,,,,,,,,,,,,,,,,, -6600,0.65502757,2.0760186,,,,,,,,,,,,,,,,, -6700,0.2748082,2.093727,,,,,,,,,,,,,,,,, -6800,0.3251537,2.1953316,,,,,,,,,,,,,,,,, -6900,0.32082304,2.1058311,,,,,,,,,,,,,,,,, -7000,0.38863158,2.190027,,,,,,,,,,,,,,,,, -7100,0.43442765,2.2363665,,,,,,,,,,,,,,,,, -7150,,,0.5893054604530334,2.118350028991699,27.176336003425604,0.6037743091583252,2.024528980255127,24.357176752244964,3000.0,0.6053221821784973,2.0033843517303467,22.893728461005857,3003.0,2551.1354496479034,4959.091591835022,2551.1354496479034,2407.65248298645,0.0771441459655761,0.0 -7200,0.5497874,2.159987,,,,,,,,,,,,,,,,, -7300,0.5342136,2.0906894,,,,,,,,,,,,,,,,, -7400,0.3760766,2.1590526,,,,,,,,,,,,,,,,, -7500,0.24422522,2.141638,,,,,,,,,,,,,,,,, -7600,0.45036423,2.223539,,,,,,,,,,,,,,,,, -7700,0.7464891,2.10053,,,,,,,,,,,,,,,,, -7800,0.39127403,2.0225136,,,,,,,,,,,,,,,,, -7900,0.39247277,2.1780248,,,,,,,,,,,,,,,,, -8000,0.5346445,2.1167903,,,,,,,,,,,,,,,,, -8100,0.26856703,2.0515742,,,,,,,,,,,,,,,,, -8200,0.36635548,2.05923,,,,,,,,,,,,,,,,, -8300,0.44014797,2.2029715,,,,,,,,,,,,,,,,, -8400,0.5292083,2.132677,,,,,,,,,,,,,,,,, -8500,0.5194341,2.1112688,,,,,,,,,,,,,,,,, -8600,0.285794,2.1524477,,,,,,,,,,,,,,,,, -8700,0.40245894,2.1140537,,,,,,,,,,,,,,,,, -8800,0.50363386,2.0976799,,,,,,,,,,,,,,,,, -8900,0.8825774,2.1168947,,,,,,,,,,,,,,,,, -9000,0.26811436,2.1848886,,,,,,,,,,,,,,,,, -9100,0.49316192,2.1343663,,,,,,,,,,,,,,,,, -9200,0.80911696,2.13224,,,,,,,,,,,,,,,,, -9300,1.008302,2.0991828,,,,,,,,,,,,,,,,, -9400,0.8418443,2.1200922,,,,,,,,,,,,,,,,, -9500,0.33083734,2.0869515,,,,,,,,,,,,,,,,, -9535,,,0.5954850912094116,2.119326114654541,27.71182659066221,0.6125404238700867,1.9642142057418823,24.82005795946175,3000.0,0.6162105798721313,1.9303609132766724,23.417765370115543,3003.0,3391.3415517807007,6374.207869529724,3391.3415517807007,2982.4601430892944,0.1054227352142334,0.0 -9600,0.40931022,2.0470743,,,,,,,,,,,,,,,,, -9700,0.51336706,2.0302572,,,,,,,,,,,,,,,,, -9800,0.47940615,2.0796983,,,,,,,,,,,,,,,,, -9900,0.975681,2.2078643,,,,,,,,,,,,,,,,, -10000,0.3559521,2.1155019,,,,,,,,,,,,,,,,, -10100,0.25735864,2.0813537,,,,,,,,,,,,,,,,, -10200,0.41529667,2.080102,,,,,,,,,,,,,,,,, -10300,0.6205228,2.0903056,,,,,,,,,,,,,,,,, -10400,0.33171362,2.0813308,,,,,,,,,,,,,,,,, -10500,0.29981732,2.0381784,,,,,,,,,,,,,,,,, -10600,0.29022318,2.1531544,,,,,,,,,,,,,,,,, -10700,0.45627,2.0942278,,,,,,,,,,,,,,,,, -10800,0.38336384,2.0515413,,,,,,,,,,,,,,,,, -10900,0.97189337,2.1931381,,,,,,,,,,,,,,,,, -11000,0.63979286,2.1729999,,,,,,,,,,,,,,,,, -11100,0.32099718,2.112202,,,,,,,,,,,,,,,,, -11200,0.4162245,2.010409,,,,,,,,,,,,,,,,, -11300,0.43276405,2.177522,,,,,,,,,,,,,,,,, -11400,0.49525687,2.0274942,,,,,,,,,,,,,,,,, -11500,0.71644074,2.2081258,,,,,,,,,,,,,,,,, -11600,0.30165312,2.0858066,,,,,,,,,,,,,,,,, -11700,0.50302213,2.0472593,,,,,,,,,,,,,,,,, -11800,0.6182633,2.1079218,,,,,,,,,,,,,,,,, -11900,0.2917768,2.0966337,,,,,,,,,,,,,,,,, -11917,,,0.5949356555938721,2.093808650970459,27.32630178120262,0.6108169555664062,1.9556013345718384,23.938898942260515,3000.0,0.6186973452568054,1.9178035259246824,22.844513408768293,3003.0,4231.374727487564,7720.172181844711,4231.374727487564,3488.283415555954,0.1319782733917236,0.0 -12000,0.39049387,2.2282195,,,,,,,,,,,,,,,,, -12100,0.5503755,2.0346546,,,,,,,,,,,,,,,,, -12200,0.81119233,2.0185685,,,,,,,,,,,,,,,,, -12300,0.30769196,2.1073658,,,,,,,,,,,,,,,,, -12400,0.4791057,2.097389,,,,,,,,,,,,,,,,, -12500,0.3432019,2.0378067,,,,,,,,,,,,,,,,, -12600,0.5757584,2.0765731,,,,,,,,,,,,,,,,, -12700,0.26034898,2.1198163,,,,,,,,,,,,,,,,, -12800,0.3342583,2.134662,,,,,,,,,,,,,,,,, -12900,0.6730784,2.094427,,,,,,,,,,,,,,,,, -13000,0.66465205,2.0958877,,,,,,,,,,,,,,,,, -13100,0.44567832,2.0902681,,,,,,,,,,,,,,,,, -13200,0.2975904,2.128179,,,,,,,,,,,,,,,,, -13300,0.5065231,2.124542,,,,,,,,,,,,,,,,, -13400,0.43998316,2.091806,,,,,,,,,,,,,,,,, -13500,0.39988554,2.0994334,,,,,,,,,,,,,,,,, -13600,0.5344711,2.0355108,,,,,,,,,,,,,,,,, -13700,0.64139044,2.1600375,,,,,,,,,,,,,,,,, -13800,0.5757676,2.1532574,,,,,,,,,,,,,,,,, -13900,0.6554133,2.0430703,,,,,,,,,,,,,,,,, -14000,0.24247247,2.1212585,,,,,,,,,,,,,,,,, -14100,0.2615011,2.0813384,,,,,,,,,,,,,,,,, -14200,0.28706503,2.0218267,,,,,,,,,,,,,,,,, -14300,0.8033952,2.1338224,,,,,,,,,,,,,,,,, -14301,,,0.597704291343689,2.0802969932556152,28.308626342993755,0.6167685389518738,1.930830478668213,25.06334814984212,3000.0,0.621997594833374,1.8840601444244385,23.81654558217464,3003.0,5071.90565609932,9049.865262746813,5071.90565609932,3977.3424422740936,0.1598775386810302,0.0 -14400,0.25685516,2.0323784,,,,,,,,,,,,,,,,, -14500,0.69598275,2.023879,,,,,,,,,,,,,,,,, -14600,0.34198684,2.108326,,,,,,,,,,,,,,,,, -14700,0.48550546,2.1154783,,,,,,,,,,,,,,,,, -14800,0.77879965,2.0727515,,,,,,,,,,,,,,,,, -14900,0.74102193,2.1838684,,,,,,,,,,,,,,,,, -15000,0.51319075,2.0890474,,,,,,,,,,,,,,,,, -15100,0.24721065,2.0830266,,,,,,,,,,,,,,,,, -15200,0.5536363,2.0304327,,,,,,,,,,,,,,,,, -15300,0.35676152,2.139812,,,,,,,,,,,,,,,,, -15400,0.6958998,2.1604483,,,,,,,,,,,,,,,,, -15500,0.61792976,2.086767,,,,,,,,,,,,,,,,, -15600,0.58125705,2.0635347,,,,,,,,,,,,,,,,, -15700,0.47537512,2.1095297,,,,,,,,,,,,,,,,, -15800,0.32013285,2.035679,,,,,,,,,,,,,,,,, -15900,0.47834983,2.0273724,,,,,,,,,,,,,,,,, -16000,0.60546696,2.0780718,,,,,,,,,,,,,,,,, -16100,0.65225196,2.0565925,,,,,,,,,,,,,,,,, -16200,0.4713755,2.1164603,,,,,,,,,,,,,,,,, -16300,1.0738262,2.052368,,,,,,,,,,,,,,,,, -16400,0.27698442,2.074411,,,,,,,,,,,,,,,,, -16500,0.6366992,2.126417,,,,,,,,,,,,,,,,, -16600,0.5279014,2.0923998,,,,,,,,,,,,,,,,, -16684,,,0.5992457866668701,2.077332258224488,28.287729139742407,0.6202650666236877,1.911595344543457,25.16760638686699,3000.0,0.6254836916923523,1.8665118217468264,23.955626051084025,3003.0,5911.835096359253,10376.551353693008,5911.835096359253,4463.997585058212,0.1864314079284668,0.0 -16700,0.4986914,2.0202987,,,,,,,,,,,,,,,,, -16800,0.47080883,2.091581,,,,,,,,,,,,,,,,, -16900,0.7126114,2.0578315,,,,,,,,,,,,,,,,, -17000,0.512526,2.0756543,,,,,,,,,,,,,,,,, -17100,0.39639926,2.0658233,,,,,,,,,,,,,,,,, -17200,0.34228107,2.1618319,,,,,,,,,,,,,,,,, -17300,0.37878183,2.0670009,,,,,,,,,,,,,,,,, -17400,0.39176154,2.0438552,,,,,,,,,,,,,,,,, -17500,0.30768874,2.1849484,,,,,,,,,,,,,,,,, -17600,0.34650496,2.0820007,,,,,,,,,,,,,,,,, -17700,0.84366167,2.0774107,,,,,,,,,,,,,,,,, -17800,0.3503298,2.085209,,,,,,,,,,,,,,,,, -17900,0.5743376,2.0642743,,,,,,,,,,,,,,,,, -18000,0.5210725,1.94504,,,,,,,,,,,,,,,,, -18100,0.74398816,2.0893764,,,,,,,,,,,,,,,,, -18200,0.71692324,2.1321616,,,,,,,,,,,,,,,,, -18300,0.62855154,2.0386877,,,,,,,,,,,,,,,,, -18400,0.30080554,2.0924525,,,,,,,,,,,,,,,,, -18500,0.5946322,2.034745,,,,,,,,,,,,,,,,, -18600,0.3059503,2.0395954,,,,,,,,,,,,,,,,, -18700,0.62252873,2.052302,,,,,,,,,,,,,,,,, -18800,0.3601119,2.1033664,,,,,,,,,,,,,,,,, -18900,0.4274263,2.0569286,,,,,,,,,,,,,,,,, -19000,0.6820303,2.039543,,,,,,,,,,,,,,,,, -19068,,,0.6137471199035645,1.926009178161621,29.14411802678976,0.6178720593452454,1.9107155799865725,25.27364721195347,3000.0,0.6246935129165649,1.8686076402664185,24.06795003200436,3003.0,6751.987172842026,11724.933151721954,6751.987172842026,4972.11691904068,0.2197494506835937,0.0 -19100,0.7520465,2.110833,,,,,,,,,,,,,,,,, -19200,0.40657088,2.079062,,,,,,,,,,,,,,,,, -19300,0.72061026,2.0445433,,,,,,,,,,,,,,,,, -19400,0.3414867,2.1201422,,,,,,,,,,,,,,,,, -19500,0.41171777,2.018846,,,,,,,,,,,,,,,,, -19600,0.29221943,2.1200044,,,,,,,,,,,,,,,,, -19700,0.73227197,2.1154811,,,,,,,,,,,,,,,,, -19800,0.5148663,2.092703,,,,,,,,,,,,,,,,, -19900,0.5122927,2.0311677,,,,,,,,,,,,,,,,, -20000,0.5819007,2.081273,,,,,,,,,,,,,,,,, -20100,0.64290816,2.1381195,,,,,,,,,,,,,,,,, -20200,0.4987082,1.9768935,,,,,,,,,,,,,,,,, -20300,0.30639577,1.9776391,,,,,,,,,,,,,,,,, -20400,0.73966503,2.103664,,,,,,,,,,,,,,,,, -20500,0.5561065,2.0250442,,,,,,,,,,,,,,,,, -20600,0.5357244,1.9417436,,,,,,,,,,,,,,,,, -20700,0.24414529,2.0911255,,,,,,,,,,,,,,,,, -20800,0.4698121,2.06425,,,,,,,,,,,,,,,,, -20900,0.30878156,2.086443,,,,,,,,,,,,,,,,, -21000,0.44854462,2.0119972,,,,,,,,,,,,,,,,, -21100,0.74153334,2.011284,,,,,,,,,,,,,,,,, -21200,0.4943399,2.0872023,,,,,,,,,,,,,,,,, -21300,0.2588037,2.0933852,,,,,,,,,,,,,,,,, -21400,0.34836912,2.0847492,,,,,,,,,,,,,,,,, -21452,,,0.6039170026779175,2.04196834564209,28.499650206528244,0.6211702227592468,1.893813848495484,25.326143244982028,3000.0,0.6274592280387878,1.8478097915649407,24.41130121879812,3003.0,7591.907975435257,13117.867561101912,7591.907975435257,5525.0277309417725,0.2484176158905029,0.0 -21500,0.89396316,2.1003938,,,,,,,,,,,,,,,,, -21600,0.43218398,2.1111846,,,,,,,,,,,,,,,,, -21700,0.5196032,2.0377874,,,,,,,,,,,,,,,,, -21800,0.53162307,2.0154269,,,,,,,,,,,,,,,,, -21900,0.37058803,1.9904662,,,,,,,,,,,,,,,,, -22000,0.41969347,2.1125028,,,,,,,,,,,,,,,,, -22100,0.3728482,1.9654055,,,,,,,,,,,,,,,,, -22200,0.56617105,2.0364432,,,,,,,,,,,,,,,,, -22300,0.51310426,2.042431,,,,,,,,,,,,,,,,, -22400,0.5775495,2.033118,,,,,,,,,,,,,,,,, -22500,0.50598705,2.002819,,,,,,,,,,,,,,,,, -22600,0.710224,2.0748935,,,,,,,,,,,,,,,,, -22700,0.70027095,2.020578,,,,,,,,,,,,,,,,, -22800,0.28601182,1.9828366,,,,,,,,,,,,,,,,, -22900,0.5448546,2.0638714,,,,,,,,,,,,,,,,, -23000,0.5684876,2.013894,,,,,,,,,,,,,,,,, -23100,0.25617304,1.9850959,,,,,,,,,,,,,,,,, -23200,0.8121614,2.1092467,,,,,,,,,,,,,,,,, -23300,0.5711292,2.1261122,,,,,,,,,,,,,,,,, -23400,0.591505,2.111605,,,,,,,,,,,,,,,,, -23500,0.60637826,1.9613458,,,,,,,,,,,,,,,,, -23600,0.38722622,2.0305252,,,,,,,,,,,,,,,,, -23700,0.41160947,1.9651375,,,,,,,,,,,,,,,,, -23800,0.4004746,2.0199704,,,,,,,,,,,,,,,,, -23837,,,0.6014580130577087,2.0570669174194336,28.452707209093468,0.6210214495658875,1.8938277959823608,25.541055572678147,3000.0,0.6271222233772278,1.846617221832276,24.527137407140724,3003.0,8432.142497062683,14456.54249548912,8432.142497062683,6023.364975690842,0.2760303020477295,0.0 -23900,0.44310263,2.0717175,,,,,,,,,,,,,,,,, -24000,0.3236027,1.9552057,,,,,,,,,,,,,,,,, -24100,0.28530586,2.006041,,,,,,,,,,,,,,,,, -24200,0.67217064,2.093265,,,,,,,,,,,,,,,,, -24300,0.34719193,2.0977242,,,,,,,,,,,,,,,,, -24400,0.55852056,2.050623,,,,,,,,,,,,,,,,, -24500,0.2683384,2.0420346,,,,,,,,,,,,,,,,, -24600,1.101818,2.072652,,,,,,,,,,,,,,,,, -24700,0.29764995,2.0759668,,,,,,,,,,,,,,,,, -24800,0.40291274,2.0815847,,,,,,,,,,,,,,,,, -24900,0.3865073,2.0517144,,,,,,,,,,,,,,,,, -25000,0.3827033,1.9875734,,,,,,,,,,,,,,,,, -25100,0.34801596,2.1196427,,,,,,,,,,,,,,,,, -25200,0.3452461,2.017267,,,,,,,,,,,,,,,,, -25300,0.2742406,1.994555,,,,,,,,,,,,,,,,, -25400,0.37013847,2.0497043,,,,,,,,,,,,,,,,, -25500,0.26328623,2.0598118,,,,,,,,,,,,,,,,, -25600,0.55447733,2.0341566,,,,,,,,,,,,,,,,, -25700,0.5187611,2.1362493,,,,,,,,,,,,,,,,, -25800,0.79366964,2.0705578,,,,,,,,,,,,,,,,, -25900,0.45817629,2.0443754,,,,,,,,,,,,,,,,, -26000,0.8059812,2.040383,,,,,,,,,,,,,,,,, -26100,0.37380132,2.0742068,,,,,,,,,,,,,,,,, -26200,0.602813,2.0830538,,,,,,,,,,,,,,,,, -26220,,,0.6070371866226196,2.0028188228607178,28.626233950948446,0.6237616539001465,1.874969720840454,25.473166487207465,3000.0,0.6297716498374939,1.828667402267456,24.47234568073476,3003.0,9272.08500289917,15889.645047426224,9272.08500289917,6616.417924404144,0.303934097290039,0.0 -26300,0.48833573,2.054113,,,,,,,,,,,,,,,,, -26400,0.29296273,1.9770898,,,,,,,,,,,,,,,,, -26500,0.63874173,1.9819659,,,,,,,,,,,,,,,,, -26600,0.4414703,1.955903,,,,,,,,,,,,,,,,, -26700,0.37055418,2.003642,,,,,,,,,,,,,,,,, -26800,0.3119369,1.9625763,,,,,,,,,,,,,,,,, -26900,0.729306,2.0135849,,,,,,,,,,,,,,,,, -27000,0.27458847,2.045089,,,,,,,,,,,,,,,,, -27100,0.34454638,2.0144112,,,,,,,,,,,,,,,,, -27200,0.31776404,1.9834253,,,,,,,,,,,,,,,,, -27300,0.38566867,2.0109982,,,,,,,,,,,,,,,,, -27400,0.55039763,2.0093424,,,,,,,,,,,,,,,,, -27500,0.49551153,2.063279,,,,,,,,,,,,,,,,, -27600,0.49301633,1.9805402,,,,,,,,,,,,,,,,, -27700,0.42972746,2.0264268,,,,,,,,,,,,,,,,, -27800,0.29793483,1.9845917,,,,,,,,,,,,,,,,, -27900,0.33380416,2.049722,,,,,,,,,,,,,,,,, -28000,0.49872002,2.0602436,,,,,,,,,,,,,,,,, -28100,0.46164683,2.0591872,,,,,,,,,,,,,,,,, -28200,0.3301714,2.0853708,,,,,,,,,,,,,,,,, -28300,0.52543813,2.0277016,,,,,,,,,,,,,,,,, -28400,0.4166311,2.0391688,,,,,,,,,,,,,,,,, -28500,0.5262695,2.0046728,,,,,,,,,,,,,,,,, -28600,0.27706817,2.015594,,,,,,,,,,,,,,,,, -28604,,,0.6041083335876465,2.042267322540283,28.460018032548263,0.6240839958190918,1.880605697631836,25.687265643742247,3000.0,0.6322584748268127,1.824683427810669,24.058762358703923,3003.0,10112.03343486786,17247.554879665375,10112.03343486786,7134.273945808411,0.3334333896636963,0.0 -28700,0.8042214,1.9662097,,,,,,,,,,,,,,,,, -28800,0.7947277,2.0843308,,,,,,,,,,,,,,,,, -28900,0.7587766,1.9697928,,,,,,,,,,,,,,,,, -29000,0.33442235,2.1233118,,,,,,,,,,,,,,,,, -29100,0.5281761,1.9484047,,,,,,,,,,,,,,,,, -29200,0.5095912,2.0241375,,,,,,,,,,,,,,,,, -29300,0.45314494,2.030292,,,,,,,,,,,,,,,,, -29400,0.4175993,2.1476793,,,,,,,,,,,,,,,,, -29500,0.3416028,2.112893,,,,,,,,,,,,,,,,, -29600,0.3025031,2.02594,,,,,,,,,,,,,,,,, -29700,0.4246565,2.0247834,,,,,,,,,,,,,,,,, -29800,0.56630653,2.0063727,,,,,,,,,,,,,,,,, -29900,0.57347625,2.0139828,,,,,,,,,,,,,,,,, -30000,0.610801,1.9676375,,,,,,,,,,,,,,,,, -30100,0.2775765,2.002184,,,,,,,,,,,,,,,,, -30200,0.44685316,2.0742545,,,,,,,,,,,,,,,,, -30300,0.31457686,2.033009,,,,,,,,,,,,,,,,, -30400,0.56223816,2.0072296,,,,,,,,,,,,,,,,, -30500,0.8330741,2.0062065,,,,,,,,,,,,,,,,, -30600,0.43997592,2.0921888,,,,,,,,,,,,,,,,, -30700,0.48163146,2.103687,,,,,,,,,,,,,,,,, -30800,0.44849342,2.0503004,,,,,,,,,,,,,,,,, -30900,0.54150593,1.9098384,,,,,,,,,,,,,,,,, -30989,,,0.6057415008544922,2.0303850173950195,28.31511913979044,0.6273201704025269,1.8572407960891724,25.75657484227801,3000.0,0.6323397755622864,1.8168457746505733,24.526306536985867,3003.0,10952.269091129305,18601.20199108124,10952.269091129305,7647.579026222229,0.3638780117034912,0.0 -31000,0.6501099,2.0804706,,,,,,,,,,,,,,,,, -31100,0.4075855,1.9816874,,,,,,,,,,,,,,,,, -31200,0.31475896,2.004291,,,,,,,,,,,,,,,,, -31300,0.3658772,2.094629,,,,,,,,,,,,,,,,, -31400,0.4069051,2.058246,,,,,,,,,,,,,,,,, -31500,0.71043456,2.0123668,,,,,,,,,,,,,,,,, -31600,0.592562,1.966457,,,,,,,,,,,,,,,,, -31700,0.33494365,2.0543184,,,,,,,,,,,,,,,,, -31800,0.45625046,1.8829668,,,,,,,,,,,,,,,,, -31900,0.37404865,2.0068433,,,,,,,,,,,,,,,,, -32000,0.49899626,2.0421422,,,,,,,,,,,,,,,,, -32100,0.41334412,1.9983277,,,,,,,,,,,,,,,,, -32200,0.40197843,2.034851,,,,,,,,,,,,,,,,, -32300,0.31551176,2.0699816,,,,,,,,,,,,,,,,, -32400,0.75888044,1.9961416,,,,,,,,,,,,,,,,, -32500,0.32827133,1.9788581,,,,,,,,,,,,,,,,, -32600,0.32982215,2.058436,,,,,,,,,,,,,,,,, -32700,0.33565393,2.0568538,,,,,,,,,,,,,,,,, -32800,0.495217,2.0968912,,,,,,,,,,,,,,,,, -32900,0.24907638,1.9941854,,,,,,,,,,,,,,,,, -33000,0.52271265,1.9885254,,,,,,,,,,,,,,,,, -33100,0.28824142,2.060371,,,,,,,,,,,,,,,,, -33200,0.41017494,1.9964824,,,,,,,,,,,,,,,,, -33300,0.5256883,1.9679561,,,,,,,,,,,,,,,,, -33374,,,0.60844486951828,2.00085997581482,28.007482098976688,0.6259438991546631,1.856594443321228,25.505005927829227,3000.0,0.6319795846939087,1.811674952507019,24.44685371781997,3003.0,11792.17632174492,20039.75516843796,11792.17632174492,8246.12324142456,0.3925106525421142,0.0 -33400,0.39901257,2.0256286,,,,,,,,,,,,,,,,, -33500,0.32115272,2.0114126,,,,,,,,,,,,,,,,, -33600,0.36859998,1.9837235,,,,,,,,,,,,,,,,, -33700,0.34570011,1.8929564,,,,,,,,,,,,,,,,, -33800,0.57783705,1.9152153,,,,,,,,,,,,,,,,, -33900,0.5832654,1.9949117,,,,,,,,,,,,,,,,, -34000,0.48938343,2.0042143,,,,,,,,,,,,,,,,, -34100,0.52142715,1.9688376,,,,,,,,,,,,,,,,, -34200,0.35491666,1.9254075,,,,,,,,,,,,,,,,, -34300,0.49667835,2.024783,,,,,,,,,,,,,,,,, -34400,0.29917502,2.005683,,,,,,,,,,,,,,,,, -34500,0.32673267,1.9960997,,,,,,,,,,,,,,,,, -34600,0.29335755,1.9394264,,,,,,,,,,,,,,,,, -34700,0.42478648,2.0754673,,,,,,,,,,,,,,,,, -34800,0.6282155,1.9876872,,,,,,,,,,,,,,,,, -34900,0.35901138,1.9589108,,,,,,,,,,,,,,,,, -35000,0.3397401,2.0622969,,,,,,,,,,,,,,,,, -35100,0.28198183,2.0361838,,,,,,,,,,,,,,,,, -35200,0.4350345,2.0054219,,,,,,,,,,,,,,,,, -35300,0.28938845,1.9539353,,,,,,,,,,,,,,,,, -35400,0.51868945,2.0211768,,,,,,,,,,,,,,,,, -35500,0.64634883,2.0468292,,,,,,,,,,,,,,,,, -35600,0.51673913,1.9816543,,,,,,,,,,,,,,,,, -35700,0.30799925,1.9944878,,,,,,,,,,,,,,,,, -35758,,,0.6062001585960388,2.012900114059448,28.496158366728405,0.6244683861732483,1.8566988706588743,25.335470426237546,3000.0,0.6337923407554626,1.8009341955184937,24.817481823924734,3003.0,12632.082436800005,21398.67829966545,12632.082436800005,8765.0341360569,0.4228677749633789,0.0 -35800,0.3064977,2.0965075,,,,,,,,,,,,,,,,, -35900,0.42016813,2.0108552,,,,,,,,,,,,,,,,, -36000,0.53239554,1.9771969,,,,,,,,,,,,,,,,, -36100,0.37868538,2.057973,,,,,,,,,,,,,,,,, -36200,0.31036785,2.0288215,,,,,,,,,,,,,,,,, -36300,0.83764637,2.0836744,,,,,,,,,,,,,,,,, -36400,0.49458423,2.0184996,,,,,,,,,,,,,,,,, -36500,0.42458737,1.9250543,,,,,,,,,,,,,,,,, -36600,0.49749383,1.9567943,,,,,,,,,,,,,,,,, -36700,0.57551605,1.9499607,,,,,,,,,,,,,,,,, -36800,0.60899895,2.0410075,,,,,,,,,,,,,,,,, -36900,0.50142807,2.0841928,,,,,,,,,,,,,,,,, -37000,0.34891745,2.0020416,,,,,,,,,,,,,,,,, -37100,0.37790555,2.0544722,,,,,,,,,,,,,,,,, -37200,0.36586872,2.087675,,,,,,,,,,,,,,,,, -37300,0.35879862,2.0131285,,,,,,,,,,,,,,,,, -37400,0.4836905,1.9362261,,,,,,,,,,,,,,,,, -37500,0.51062095,1.998936,,,,,,,,,,,,,,,,, -37600,0.46977088,1.9658489,,,,,,,,,,,,,,,,, -37700,0.26608,1.8712692,,,,,,,,,,,,,,,,, -37800,0.42187798,1.9828411,,,,,,,,,,,,,,,,, -37900,0.41052818,1.9819828,,,,,,,,,,,,,,,,, -38000,0.7622226,1.9888794,,,,,,,,,,,,,,,,, -38100,0.48726273,2.0182776,,,,,,,,,,,,,,,,, -38143,,,0.6146321892738342,1.95678985118866,29.38206898813265,0.628361701965332,1.8436554670333865,25.92643316412883,3000.0,0.6382081508636475,1.7820231914520264,25.00556614291527,3003.0,13472.220349311829,22704.81827187538,13472.220349311829,9230.926263809204,0.4587104320526123,0.0 -38200,0.3656439,1.9930873,,,,,,,,,,,,,,,,, -38300,1.0586543,1.8971745,,,,,,,,,,,,,,,,, -38400,0.35704413,2.006668,,,,,,,,,,,,,,,,, -38500,0.328064,1.9327348,,,,,,,,,,,,,,,,, -38600,0.33995065,1.9433395,,,,,,,,,,,,,,,,, -38700,0.639554,1.9160506,,,,,,,,,,,,,,,,, -38800,0.2911974,2.010898,,,,,,,,,,,,,,,,, -38900,0.30329704,1.930161,,,,,,,,,,,,,,,,, -39000,0.45733386,2.0852127,,,,,,,,,,,,,,,,, -39100,0.43909398,2.009696,,,,,,,,,,,,,,,,, -39200,0.60485095,1.9902394,,,,,,,,,,,,,,,,, -39300,0.27053547,1.999566,,,,,,,,,,,,,,,,, -39400,0.3149782,2.0772412,,,,,,,,,,,,,,,,, -39500,0.3114941,2.011865,,,,,,,,,,,,,,,,, -39600,0.44312918,1.950698,,,,,,,,,,,,,,,,, -39700,0.23395564,1.9434742,,,,,,,,,,,,,,,,, -39800,0.9036982,2.0067208,,,,,,,,,,,,,,,,, -39900,0.30824405,1.9834082,,,,,,,,,,,,,,,,, -40000,0.87017655,2.0705075,,,,,,,,,,,,,,,,, -40100,0.45470956,1.9895866,,,,,,,,,,,,,,,,, -40200,0.5767929,1.9575156,,,,,,,,,,,,,,,,, -40300,0.291917,1.9708333,,,,,,,,,,,,,,,,, -40400,0.65740496,1.9965277,,,,,,,,,,,,,,,,, -40500,0.694172,2.0043113,,,,,,,,,,,,,,,,, -40527,,,0.611473023891449,1.9678258895874023,28.94678678261578,0.6278533339500427,1.835180640220642,25.68769018292687,3000.0,0.6370925903320312,1.7789210081100464,24.96963784239136,3003.0,14312.36628293991,24036.06175875664,14312.36628293991,9721.91780424118,0.4886271953582763,0.0 -40600,0.50118685,2.0010462,,,,,,,,,,,,,,,,, -40700,0.6966173,1.9670748,,,,,,,,,,,,,,,,, -40800,0.31447402,1.9439673,,,,,,,,,,,,,,,,, -40900,0.3803868,1.9149135,,,,,,,,,,,,,,,,, -41000,0.30451146,1.9734386,,,,,,,,,,,,,,,,, -41100,0.36902273,2.044089,,,,,,,,,,,,,,,,, -41200,0.30731228,2.0204227,,,,,,,,,,,,,,,,, -41300,0.24781203,1.9664835,,,,,,,,,,,,,,,,, -41400,0.29797852,1.9141114,,,,,,,,,,,,,,,,, -41500,0.31363863,1.9398706,,,,,,,,,,,,,,,,, -41600,0.5615715,1.9999866,,,,,,,,,,,,,,,,, -41700,0.5053487,2.0327094,,,,,,,,,,,,,,,,, -41800,0.30126098,2.0426404,,,,,,,,,,,,,,,,, -41900,0.33499125,2.004493,,,,,,,,,,,,,,,,, -42000,0.33980995,1.962534,,,,,,,,,,,,,,,,, -42100,0.38500434,1.9729825,,,,,,,,,,,,,,,,, -42200,0.34501088,1.9873645,,,,,,,,,,,,,,,,, -42300,0.26321423,1.9576824,,,,,,,,,,,,,,,,, -42400,0.36671826,1.9623685,,,,,,,,,,,,,,,,, -42500,0.35489598,1.9387237,,,,,,,,,,,,,,,,, -42600,0.29613233,1.9442407,,,,,,,,,,,,,,,,, -42700,0.41124496,1.9570312,,,,,,,,,,,,,,,,, -42800,0.80102503,1.9805163,,,,,,,,,,,,,,,,, -42900,0.44284913,1.928315,,,,,,,,,,,,,,,,, -42911,,,0.6131593585014343,1.972790241241455,29.0899991882196,0.631461501121521,1.8121095895767207,26.09814806128841,3000.0,0.6399396061897278,1.767079472541809,25.392114429594827,3003.0,15152.403260946274,25408.2639605999,15152.403260946274,10253.97549557686,0.5196661949157715,0.0 -43000,0.32715848,1.9757844,,,,,,,,,,,,,,,,, -43100,0.6732866,2.0150058,,,,,,,,,,,,,,,,, -43200,0.3321061,1.9791659,,,,,,,,,,,,,,,,, -43300,0.39490327,1.9097748,,,,,,,,,,,,,,,,, -43400,0.5628517,2.017397,,,,,,,,,,,,,,,,, -43500,0.52770543,1.9697663,,,,,,,,,,,,,,,,, -43600,0.65654534,1.9466555,,,,,,,,,,,,,,,,, -43700,0.6826919,2.0004492,,,,,,,,,,,,,,,,, -43800,0.32663855,1.9798315,,,,,,,,,,,,,,,,, -43900,0.38836163,1.9375834,,,,,,,,,,,,,,,,, -44000,0.4129268,1.9390024,,,,,,,,,,,,,,,,, -44100,0.45610625,1.9473671,,,,,,,,,,,,,,,,, -44200,0.27824906,2.0391855,,,,,,,,,,,,,,,,, -44300,0.28157842,1.9512746,,,,,,,,,,,,,,,,, -44400,0.28559515,1.9924893,,,,,,,,,,,,,,,,, -44500,0.31976157,1.9974424,,,,,,,,,,,,,,,,, -44600,0.31617048,2.0002372,,,,,,,,,,,,,,,,, -44700,0.43261617,1.8744588,,,,,,,,,,,,,,,,, -44800,0.32709196,1.9302552,,,,,,,,,,,,,,,,, -44900,0.4550475,2.1380396,,,,,,,,,,,,,,,,, -45000,0.38353723,1.9036314,,,,,,,,,,,,,,,,, -45100,0.33255917,1.945942,,,,,,,,,,,,,,,,, -45200,0.41203725,1.9995207,,,,,,,,,,,,,,,,, -45296,,,0.6127845048904419,1.955664873123169,28.965295301006385,0.6328253746032715,1.8045439720153809,25.96706766904877,3000.0,0.6396374702453613,1.7606006860733032,24.73479737642498,3003.0,15992.33811044693,26808.432076215744,15992.33811044693,10814.103539466858,0.5507168769836426,0.0 -45300,0.8606577,1.9570986,,,,,,,,,,,,,,,,, -45400,0.47257188,2.0068376,,,,,,,,,,,,,,,,, -45500,0.80305713,2.0642805,,,,,,,,,,,,,,,,, -45600,0.48224607,1.9912168,,,,,,,,,,,,,,,,, -45700,0.5294795,1.9806813,,,,,,,,,,,,,,,,, -45800,0.34820074,1.9576484,,,,,,,,,,,,,,,,, -45900,0.3801084,2.0558577,,,,,,,,,,,,,,,,, -46000,0.41578412,1.924957,,,,,,,,,,,,,,,,, -46100,0.30368614,2.033445,,,,,,,,,,,,,,,,, -46200,0.32684204,1.9626623,,,,,,,,,,,,,,,,, -46300,0.4987399,1.9696032,,,,,,,,,,,,,,,,, -46400,0.30568007,1.9309645,,,,,,,,,,,,,,,,, -46500,0.40162998,1.8849965,,,,,,,,,,,,,,,,, -46600,0.31525096,1.9234374,,,,,,,,,,,,,,,,, -46700,0.4533151,1.960294,,,,,,,,,,,,,,,,, -46800,0.6432995,1.9747186,,,,,,,,,,,,,,,,, -46900,0.8542832,1.9768702,,,,,,,,,,,,,,,,, -47000,0.35105243,2.072251,,,,,,,,,,,,,,,,, -47100,0.4344335,2.0679245,,,,,,,,,,,,,,,,, -47200,0.24784787,2.0421917,,,,,,,,,,,,,,,,, -47300,0.39884433,1.8859625,,,,,,,,,,,,,,,,, -47400,0.3872806,1.9382463,,,,,,,,,,,,,,,,, -47500,0.53571516,2.002181,,,,,,,,,,,,,,,,, -47600,0.7695101,1.9060742,,,,,,,,,,,,,,,,, -47681,,,0.6143538355827332,1.9528555870056152,29.107341526376448,0.6314491033554077,1.8058665990829468,25.94393430285941,3000.0,0.6436465382575989,1.74813711643219,25.42608692174256,3003.0,16832.49195432663,28155.61889219284,16832.49195432663,11321.025541305542,0.5871026515960693,0.0 -47700,0.67768574,2.0356143,,,,,,,,,,,,,,,,, -47800,0.36972943,1.9545453,,,,,,,,,,,,,,,,, -47900,0.46283305,1.9674891,,,,,,,,,,,,,,,,, -48000,0.30706733,1.9242362,,,,,,,,,,,,,,,,, -48100,0.54838145,1.8722814,,,,,,,,,,,,,,,,, -48200,0.5931497,1.9607632,,,,,,,,,,,,,,,,, -48300,0.3241058,2.0020118,,,,,,,,,,,,,,,,, -48400,0.5029869,1.9948132,,,,,,,,,,,,,,,,, -48500,0.33334413,1.9454398,,,,,,,,,,,,,,,,, -48600,0.3097873,1.9562341,,,,,,,,,,,,,,,,, -48700,0.3814128,2.012861,,,,,,,,,,,,,,,,, -48800,0.36114746,1.913479,,,,,,,,,,,,,,,,, -48900,0.4515172,2.0301926,,,,,,,,,,,,,,,,, -49000,0.31377718,1.9576175,,,,,,,,,,,,,,,,, -49100,0.5553693,1.9383788,,,,,,,,,,,,,,,,, -49200,0.6998724,1.931328,,,,,,,,,,,,,,,,, -49300,0.28842878,1.9655163,,,,,,,,,,,,,,,,, -49400,0.6282877,2.0141919,,,,,,,,,,,,,,,,, -49500,0.5353689,1.9056301,,,,,,,,,,,,,,,,, -49600,0.94641584,1.9686826,,,,,,,,,,,,,,,,, -49700,0.31158188,2.0075684,,,,,,,,,,,,,,,,, -49800,0.50137717,2.055878,,,,,,,,,,,,,,,,, -49900,0.552258,2.0471795,,,,,,,,,,,,,,,,, -50000,0.36377487,1.9580047,,,,,,,,,,,,,,,,, -50064,,,0.6843953728675842,1.47307026386261,34.336430518788745,0.635949969291687,1.7930233478546145,26.441084450311465,3000.0,0.6434489488601685,1.7467548847198486,25.14636965746108,3003.0,17672.648672819138,29476.605541706085,17672.648672819138,11801.74766111374,0.6183607578277588,0.0 -50100,0.33141217,1.9574188,,,,,,,,,,,,,,,,, -50200,0.4549382,1.9511453,,,,,,,,,,,,,,,,, -50300,0.32334936,1.9555396,,,,,,,,,,,,,,,,, -50400,0.27706876,1.9599587,,,,,,,,,,,,,,,,, -50500,0.35359603,1.9189895,,,,,,,,,,,,,,,,, -50600,0.37070927,1.9372092,,,,,,,,,,,,,,,,, -50700,0.27552053,1.9412395,,,,,,,,,,,,,,,,, -50800,0.8805007,2.0299103,,,,,,,,,,,,,,,,, -50900,0.41174555,1.9957036,,,,,,,,,,,,,,,,, -51000,0.6811971,1.9458061,,,,,,,,,,,,,,,,, -51100,0.2791209,1.8938862,,,,,,,,,,,,,,,,, -51200,0.41398394,1.9157871,,,,,,,,,,,,,,,,, -51300,0.428377,1.9908848,,,,,,,,,,,,,,,,, -51400,0.6198921,1.8746966,,,,,,,,,,,,,,,,, -51500,0.29912347,1.9004672,,,,,,,,,,,,,,,,, -51600,0.34278622,2.0324085,,,,,,,,,,,,,,,,, -51700,0.29887122,1.9590114,,,,,,,,,,,,,,,,, -51800,0.28358218,2.042571,,,,,,,,,,,,,,,,, -51900,0.5867409,1.9069924,,,,,,,,,,,,,,,,, -52000,0.475961,1.9158903,,,,,,,,,,,,,,,,, -52100,0.5224607,1.9698064,,,,,,,,,,,,,,,,, -52200,0.3535178,2.0003574,,,,,,,,,,,,,,,,, -52300,0.56275904,1.9225655,,,,,,,,,,,,,,,,, -52400,0.35258475,2.0219483,,,,,,,,,,,,,,,,, -52447,,,0.6166776418685913,1.935380220413208,28.009326838866563,0.6369170546531677,1.7796038389205933,26.137517861350776,3000.0,0.6453315019607544,1.7311948537826538,25.591061740837908,3003.0,18512.559674978256,30892.52356314659,18512.559674978256,12377.643058538437,0.6505851745605469,0.0 -52500,0.43309432,1.9598655,,,,,,,,,,,,,,,,, -52600,0.42222986,1.9615607,,,,,,,,,,,,,,,,, -52700,0.3795238,1.8885474,,,,,,,,,,,,,,,,, -52800,0.5210483,1.9554048,,,,,,,,,,,,,,,,, -52900,0.35798496,1.9264044,,,,,,,,,,,,,,,,, -53000,0.27302417,1.9238386,,,,,,,,,,,,,,,,, -53100,0.460881,1.874376,,,,,,,,,,,,,,,,, -53200,0.26153174,1.89014,,,,,,,,,,,,,,,,, -53300,0.32724264,1.8564318,,,,,,,,,,,,,,,,, -53400,0.34278145,1.9526643,,,,,,,,,,,,,,,,, -53500,0.34091347,1.9215075,,,,,,,,,,,,,,,,, -53600,0.59564847,1.896924,,,,,,,,,,,,,,,,, -53700,0.51759976,1.991805,,,,,,,,,,,,,,,,, -53800,0.3551204,1.9118282,,,,,,,,,,,,,,,,, -53900,0.328916,1.9279743,,,,,,,,,,,,,,,,, -54000,0.7647821,1.9121578,,,,,,,,,,,,,,,,, -54100,0.31509623,2.0765975,,,,,,,,,,,,,,,,, -54200,0.2718476,1.970214,,,,,,,,,,,,,,,,, -54300,0.27766103,1.945123,,,,,,,,,,,,,,,,, -54400,0.30416405,1.9325948,,,,,,,,,,,,,,,,, -54500,0.480119,2.013258,,,,,,,,,,,,,,,,, -54600,0.30112284,1.9569073,,,,,,,,,,,,,,,,, -54700,0.35548764,1.9746934,,,,,,,,,,,,,,,,, -54800,0.3271167,1.8979388,,,,,,,,,,,,,,,,, -54831,,,0.6168438792228699,1.9355334043502808,29.60037342386599,0.639086902141571,1.7692811489105225,26.58346724934059,3000.0,0.6477136611938477,1.711084008216858,25.65318269635807,3003.0,19352.45453119278,32285.287123441696,19352.45453119278,12930.402215003967,0.6834166049957275,0.0 -54900,0.56087273,1.9745268,,,,,,,,,,,,,,,,, -55000,0.29247263,1.9753397,,,,,,,,,,,,,,,,, -55100,0.50836676,1.9303683,,,,,,,,,,,,,,,,, -55200,0.43964642,1.8842045,,,,,,,,,,,,,,,,, -55300,0.5090469,1.8807651,,,,,,,,,,,,,,,,, -55400,0.35527855,1.9300503,,,,,,,,,,,,,,,,, -55500,0.586227,1.9946119,,,,,,,,,,,,,,,,, -55600,0.3315799,2.0079968,,,,,,,,,,,,,,,,, -55700,0.7172245,1.81939,,,,,,,,,,,,,,,,, -55800,0.31202117,1.9261336,,,,,,,,,,,,,,,,, -55900,0.3082214,1.9069933,,,,,,,,,,,,,,,,, -56000,0.31261173,1.9507202,,,,,,,,,,,,,,,,, -56100,0.31664547,1.9353336,,,,,,,,,,,,,,,,, -56200,0.683464,2.0014925,,,,,,,,,,,,,,,,, -56300,0.3533947,1.993602,,,,,,,,,,,,,,,,, -56400,0.49914253,1.908008,,,,,,,,,,,,,,,,, -56500,0.5057847,1.9802223,,,,,,,,,,,,,,,,, -56600,0.66320884,1.9547087,,,,,,,,,,,,,,,,, -56700,0.5688057,1.8412275,,,,,,,,,,,,,,,,, -56800,0.31935057,1.9056398,,,,,,,,,,,,,,,,, -56900,0.37057844,1.8455148,,,,,,,,,,,,,,,,, -57000,0.3239553,1.7955275,,,,,,,,,,,,,,,,, -57100,0.3777784,2.0059078,,,,,,,,,,,,,,,,, -57200,0.2807908,1.9701792,,,,,,,,,,,,,,,,, -57216,,,0.6217406392097473,1.8886256217956543,30.00592278859685,0.6395828723907471,1.764250636100769,26.48839796432684,3000.0,0.6459241509437561,1.7121566534042358,25.75472864828333,3003.0,20192.690421819687,33723.60852479935,20192.690421819687,13528.37666606903,0.7170102596282959,0.0 -57300,0.36197847,1.9416838,,,,,,,,,,,,,,,,, -57400,0.38998497,1.8949573,,,,,,,,,,,,,,,,, -57500,0.5026672,1.8821391,,,,,,,,,,,,,,,,, -57600,0.6089393,1.8568556,,,,,,,,,,,,,,,,, -57700,0.25500238,1.891756,,,,,,,,,,,,,,,,, -57800,0.3783727,1.8570056,,,,,,,,,,,,,,,,, -57900,0.45599928,1.9687717,,,,,,,,,,,,,,,,, -58000,0.30140564,1.996852,,,,,,,,,,,,,,,,, -58100,0.27078143,1.9283813,,,,,,,,,,,,,,,,, -58200,0.51469743,1.900943,,,,,,,,,,,,,,,,, -58300,0.4648916,1.8510262,,,,,,,,,,,,,,,,, -58400,0.4781644,1.945796,,,,,,,,,,,,,,,,, -58500,0.29100826,1.9188665,,,,,,,,,,,,,,,,, -58600,0.41484645,1.889907,,,,,,,,,,,,,,,,, -58700,0.35944736,1.9040058,,,,,,,,,,,,,,,,, -58800,0.3960272,1.8915102,,,,,,,,,,,,,,,,, -58900,0.35082653,2.0252721,,,,,,,,,,,,,,,,, -59000,0.3357453,1.9112507,,,,,,,,,,,,,,,,, -59100,0.61484873,1.8759995,,,,,,,,,,,,,,,,, -59200,0.2889408,1.9965683,,,,,,,,,,,,,,,,, -59300,0.37838867,1.9143237,,,,,,,,,,,,,,,,, -59400,0.4449607,1.9187045,,,,,,,,,,,,,,,,, -59500,0.5933381,1.8837119,,,,,,,,,,,,,,,,, -59600,0.57690245,1.9773612,,,,,,,,,,,,,,,,, -59601,,,0.6205396056175232,1.9234768152236936,29.974718274766825,0.6435505747795105,1.7510372400283811,26.64118576387656,3000.0,0.6510720252990723,1.687710523605347,26.27622806513249,3003.0,21033.113726854324,35132.104048252106,21033.113726854324,14096.32988357544,0.7592217922210693,0.0 -59700,0.36233187,1.8657781,,,,,,,,,,,,,,,,, -59800,0.74929535,1.8552603,,,,,,,,,,,,,,,,, -59900,0.27637246,1.8568561,,,,,,,,,,,,,,,,, -60000,0.4419946,1.9621922,,,,,,,,,,,,,,,,, -60100,0.35962328,1.861439,,,,,,,,,,,,,,,,, -60200,0.33382988,1.9416891,,,,,,,,,,,,,,,,, -60300,0.46955484,1.9116923,,,,,,,,,,,,,,,,, -60400,0.58319855,1.8960891,,,,,,,,,,,,,,,,, -60500,0.30050516,1.910355,,,,,,,,,,,,,,,,, -60600,0.28640068,1.8920078,,,,,,,,,,,,,,,,, -60700,0.31343868,1.8627771,,,,,,,,,,,,,,,,, -60800,0.3052985,1.9224731,,,,,,,,,,,,,,,,, -60900,0.3185183,1.8534732,,,,,,,,,,,,,,,,, -61000,0.33196583,1.8761744,,,,,,,,,,,,,,,,, -61100,0.4912053,1.9433774,,,,,,,,,,,,,,,,, -61200,0.48846334,1.9608611,,,,,,,,,,,,,,,,, -61300,0.31454313,1.8708416,,,,,,,,,,,,,,,,, -61400,0.342489,1.8822128,,,,,,,,,,,,,,,,, -61500,0.46317422,1.8690723,,,,,,,,,,,,,,,,, -61600,0.2967418,1.822775,,,,,,,,,,,,,,,,, -61700,0.32148433,1.8920316,,,,,,,,,,,,,,,,, -61800,0.41715857,1.9010649,,,,,,,,,,,,,,,,, -61900,0.5512334,1.9472976,,,,,,,,,,,,,,,,, -61986,,,0.620170533657074,1.924217939376831,29.704826058174017,0.6417403221130371,1.7392182350158691,24.938980730325103,3000.0,0.6511649489402771,1.6913459300994873,25.027113486174706,3003.0,21873.043880462646,36801.8515625,21873.043880462646,14926.03948044777,0.7936761379241943,0.0 -62000,0.54737186,1.9640055,,,,,,,,,,,,,,,,, -62100,0.42861044,1.8770105,,,,,,,,,,,,,,,,, -62200,0.5779524,1.9241612,,,,,,,,,,,,,,,,, -62300,0.3161524,1.8238481,,,,,,,,,,,,,,,,, -62400,0.30387524,1.9119042,,,,,,,,,,,,,,,,, -62500,0.46093544,1.9495999,,,,,,,,,,,,,,,,, -62600,0.66469216,1.9700662,,,,,,,,,,,,,,,,, -62700,0.38742927,1.9516243,,,,,,,,,,,,,,,,, -62800,0.2724183,1.8677901,,,,,,,,,,,,,,,,, -62900,0.4239887,1.9655977,,,,,,,,,,,,,,,,, -63000,0.45035636,1.9311178,,,,,,,,,,,,,,,,, -63100,0.29312924,1.8979158,,,,,,,,,,,,,,,,, -63200,0.43424317,1.9188786,,,,,,,,,,,,,,,,, -63300,0.29303664,1.9328138,,,,,,,,,,,,,,,,, -63400,0.39214557,1.8440802,,,,,,,,,,,,,,,,, -63500,0.3901246,1.8262296,,,,,,,,,,,,,,,,, -63600,0.3543121,1.8557198,,,,,,,,,,,,,,,,, -63700,0.2925161,1.8846135,,,,,,,,,,,,,,,,, -63800,0.3331366,1.9496964,,,,,,,,,,,,,,,,, -63900,0.34932655,1.9436835,,,,,,,,,,,,,,,,, -64000,0.31612438,1.8128833,,,,,,,,,,,,,,,,, -64100,0.48282918,1.8856238,,,,,,,,,,,,,,,,, -64200,0.32268628,1.9641676,,,,,,,,,,,,,,,,, -64300,0.6640392,1.8828769,,,,,,,,,,,,,,,,, -64370,,,0.6249828934669495,1.8825191259384155,30.12080442859549,0.6434885859489441,1.73857581615448,27.00111898618707,3000.0,0.6523153781890869,1.6770296096801758,26.33216731822289,3003.0,22713.07751083374,38225.61360192299,22713.07751083374,15509.657069206238,0.826836109161377,0.0 -64400,0.41865963,1.8649257,,,,,,,,,,,,,,,,, -64500,0.5050294,1.9667645,,,,,,,,,,,,,,,,, -64600,0.29885152,1.8446193,,,,,,,,,,,,,,,,, -64700,0.30179068,1.7923435,,,,,,,,,,,,,,,,, -64800,0.5630227,1.9035217,,,,,,,,,,,,,,,,, -64900,0.4821045,1.9803729,,,,,,,,,,,,,,,,, -65000,0.45892692,1.8465987,,,,,,,,,,,,,,,,, -65100,0.2723769,1.8461734,,,,,,,,,,,,,,,,, -65200,0.5490656,1.8810852,,,,,,,,,,,,,,,,, -65300,0.6277667,1.9946076,,,,,,,,,,,,,,,,, -65400,0.30905256,1.8401009,,,,,,,,,,,,,,,,, -65500,0.30997303,1.86035,,,,,,,,,,,,,,,,, -65600,0.38047236,1.7995498,,,,,,,,,,,,,,,,, -65700,0.4581164,1.8978493,,,,,,,,,,,,,,,,, -65800,0.4938876,1.8417356,,,,,,,,,,,,,,,,, -65900,0.34595793,1.9311267,,,,,,,,,,,,,,,,, -66000,0.42141616,1.9016274,,,,,,,,,,,,,,,,, -66100,0.2579603,1.8920033,,,,,,,,,,,,,,,,, -66200,0.37314722,1.7974823,,,,,,,,,,,,,,,,, -66300,0.2927164,1.800574,,,,,,,,,,,,,,,,, -66400,0.32594258,1.8332243,,,,,,,,,,,,,,,,, -66500,0.30745134,1.9485111,,,,,,,,,,,,,,,,, -66600,0.35693294,1.9456593,,,,,,,,,,,,,,,,, -66700,0.32264218,1.8572628,,,,,,,,,,,,,,,,, -66755,,,0.6223474144935608,1.897873878479004,29.60616800850276,0.6448525190353394,1.712667465209961,26.909955268306703,3000.0,0.6545000672340393,1.6544857025146484,26.066548963881,3003.0,23553.293513298035,39577.4663040638,23553.293513298035,16021.18220758438,0.8614444732666016,0.0 -66800,0.28993347,1.8935649,,,,,,,,,,,,,,,,, -66900,0.35174668,1.8394946,,,,,,,,,,,,,,,,, -67000,0.5152329,1.8611794,,,,,,,,,,,,,,,,, -67100,0.40270013,1.909003,,,,,,,,,,,,,,,,, -67200,0.27027953,1.9787785,,,,,,,,,,,,,,,,, -67300,0.49112424,1.8535525,,,,,,,,,,,,,,,,, -67400,0.39993605,1.8232512,,,,,,,,,,,,,,,,, -67500,0.40077367,1.8682102,,,,,,,,,,,,,,,,, -67600,0.359696,1.8884618,,,,,,,,,,,,,,,,, -67700,0.5468928,1.885476,,,,,,,,,,,,,,,,, -67800,0.34773454,1.8923126,,,,,,,,,,,,,,,,, -67900,0.5384384,1.7873511,,,,,,,,,,,,,,,,, -68000,0.36986995,1.9589348,,,,,,,,,,,,,,,,, -68100,0.360704,1.8507963,,,,,,,,,,,,,,,,, -68200,0.27332473,1.8715297,,,,,,,,,,,,,,,,, -68300,0.2908354,1.9082615,,,,,,,,,,,,,,,,, -68400,0.33881378,1.9840792,,,,,,,,,,,,,,,,, -68500,0.30541396,1.867789,,,,,,,,,,,,,,,,, -68600,0.34706473,1.7986362,,,,,,,,,,,,,,,,, -68700,0.4140309,1.8254149,,,,,,,,,,,,,,,,, -68800,0.33442804,1.9714911,,,,,,,,,,,,,,,,, -68900,0.27871525,1.9016124,,,,,,,,,,,,,,,,, -69000,0.31996125,1.8779842,,,,,,,,,,,,,,,,, -69100,0.5179702,1.9205315,,,,,,,,,,,,,,,,, -69140,,,0.6475690007209778,1.708027482032776,31.34708892707513,0.6484978199005127,1.7034063339233398,27.0442542027674,3000.0,0.6583929061889648,1.6351416110992432,26.70894910112787,3003.0,24393.23588967324,40994.1373796463,24393.23588967324,16597.801033496857,0.8966538906097412,0.0 -69200,0.36966762,1.8928137,,,,,,,,,,,,,,,,, -69300,0.29804075,1.9608635,,,,,,,,,,,,,,,,, -69400,0.3670861,1.8016983,,,,,,,,,,,,,,,,, -69500,0.29576987,1.863444,,,,,,,,,,,,,,,,, -69600,0.7019678,1.8656902,,,,,,,,,,,,,,,,, -69700,0.3335821,1.869197,,,,,,,,,,,,,,,,, -69800,0.30807048,1.8411369,,,,,,,,,,,,,,,,, -69900,0.49886346,1.825167,,,,,,,,,,,,,,,,, -70000,0.2954946,1.868343,,,,,,,,,,,,,,,,, -70100,0.73664975,1.8745133,,,,,,,,,,,,,,,,, -70200,0.28701866,1.8003255,,,,,,,,,,,,,,,,, -70300,0.38417324,1.8175879,,,,,,,,,,,,,,,,, -70400,0.33567423,1.8713809,,,,,,,,,,,,,,,,, -70500,0.54715794,1.829527,,,,,,,,,,,,,,,,, -70600,0.32736447,1.8803183,,,,,,,,,,,,,,,,, -70700,0.29365727,1.8359752,,,,,,,,,,,,,,,,, -70800,0.28942838,1.8271662,,,,,,,,,,,,,,,,, -70900,0.30271375,1.8398349,,,,,,,,,,,,,,,,, -71000,0.3052751,1.8178984,,,,,,,,,,,,,,,,, -71100,0.29791114,1.8462288,,,,,,,,,,,,,,,,, -71200,0.29790124,1.8490661,,,,,,,,,,,,,,,,, -71300,0.2665948,1.8655149,,,,,,,,,,,,,,,,, -71400,0.3831137,1.8860633,,,,,,,,,,,,,,,,, -71500,0.27657282,1.8644172,,,,,,,,,,,,,,,,, -71525,,,0.6270906329154968,1.85956346988678,30.49007289439116,0.6483118534088135,1.6941295862197876,27.250290188229883,3000.0,0.6592528223991394,1.626090168952942,26.910900978321823,3003.0,25233.399693965912,42368.23453712464,25233.399693965912,17131.623816490173,0.9319117069244384,0.0 -71600,0.56314975,1.7929962,,,,,,,,,,,,,,,,, -71700,0.3204599,1.9119412,,,,,,,,,,,,,,,,, -71800,0.5848027,1.8187397,,,,,,,,,,,,,,,,, -71900,0.38867077,1.8509325,,,,,,,,,,,,,,,,, -72000,0.50645065,1.7744576,,,,,,,,,,,,,,,,, -72100,0.53109044,1.8156065,,,,,,,,,,,,,,,,, -72200,0.5356476,1.8351704,,,,,,,,,,,,,,,,, -72300,0.35721505,1.8078824,,,,,,,,,,,,,,,,, -72400,0.39487413,1.8888092,,,,,,,,,,,,,,,,, -72500,0.40398306,1.9104798,,,,,,,,,,,,,,,,, -72600,0.37283525,1.8177195,,,,,,,,,,,,,,,,, -72700,0.25985807,1.7659386,,,,,,,,,,,,,,,,, -72800,0.32722023,1.8883454,,,,,,,,,,,,,,,,, -72900,0.44611084,1.9518832,,,,,,,,,,,,,,,,, -73000,0.28965542,1.8557214,,,,,,,,,,,,,,,,, -73100,0.25342023,1.8286632,,,,,,,,,,,,,,,,, -73200,0.3906055,1.86046,,,,,,,,,,,,,,,,, -73300,0.3614595,1.8291916,,,,,,,,,,,,,,,,, -73400,0.30232954,1.7827736,,,,,,,,,,,,,,,,, -73500,0.28396612,1.8885983,,,,,,,,,,,,,,,,, -73600,0.30099592,1.7200192,,,,,,,,,,,,,,,,, -73700,0.6041139,1.9299643,,,,,,,,,,,,,,,,, -73800,0.33450854,1.8407416,,,,,,,,,,,,,,,,, -73900,0.48112535,1.8645769,,,,,,,,,,,,,,,,, -73910,,,0.6320022940635681,1.837198734283448,30.474995803724568,0.6498989462852478,1.6905559301376345,27.39593869430475,3000.0,0.658288300037384,1.6216771602630615,26.540084120816232,3003.0,26073.622700452805,43699.67828798294,26073.622700452805,17622.726563453674,0.9732100963592528,0.0 -74000,0.6166513,1.9130052,,,,,,,,,,,,,,,,, -74100,0.30517697,1.8847132,,,,,,,,,,,,,,,,, -74200,0.4975176,1.7843512,,,,,,,,,,,,,,,,, -74300,0.31084982,1.8687917,,,,,,,,,,,,,,,,, -74400,0.50913286,1.823427,,,,,,,,,,,,,,,,, -74500,0.31159273,1.8152692,,,,,,,,,,,,,,,,, -74600,0.27513716,1.8196917,,,,,,,,,,,,,,,,, -74700,0.37657648,1.8457389,,,,,,,,,,,,,,,,, -74800,0.31464255,1.8712522,,,,,,,,,,,,,,,,, -74900,0.7645925,1.8312843,,,,,,,,,,,,,,,,, -75000,0.66481006,1.8448993,,,,,,,,,,,,,,,,, -75100,0.29910594,1.838998,,,,,,,,,,,,,,,,, -75200,0.4974525,1.8367864,,,,,,,,,,,,,,,,, -75300,0.41878754,1.7961813,,,,,,,,,,,,,,,,, -75400,0.35460225,1.8949033,,,,,,,,,,,,,,,,, -75500,0.36240888,1.845402,,,,,,,,,,,,,,,,, -75600,0.49434704,1.7497411,,,,,,,,,,,,,,,,, -75700,0.4598066,1.8485479,,,,,,,,,,,,,,,,, -75800,0.45946184,1.8025929,,,,,,,,,,,,,,,,, -75900,0.4293413,1.9360904,,,,,,,,,,,,,,,,, -76000,0.4490737,1.7850916,,,,,,,,,,,,,,,,, -76100,0.42993182,1.8142596,,,,,,,,,,,,,,,,, -76200,0.29220042,1.8444308,,,,,,,,,,,,,,,,, -76295,,,0.6367040872573853,1.7800288200378418,31.028093772921565,0.6544989943504333,1.6648274660110474,27.57831994512013,3000.0,0.6646563410758972,1.5929890871047974,27.065361558941973,3003.0,26913.55896472931,45257.52446436882,26913.55896472931,18340.518897771835,1.016244649887085,0.0 -76300,0.6567799,1.8671255,,,,,,,,,,,,,,,,, -76400,0.4431786,1.784214,,,,,,,,,,,,,,,,, -76500,0.46432546,1.7763995,,,,,,,,,,,,,,,,, -76600,0.41313192,1.8132135,,,,,,,,,,,,,,,,, -76700,0.39086324,1.869113,,,,,,,,,,,,,,,,, -76800,0.35474545,1.7452718,,,,,,,,,,,,,,,,, -76900,0.55180866,1.9281253,,,,,,,,,,,,,,,,, -77000,0.30354843,1.8142647,,,,,,,,,,,,,,,,, -77100,0.25971022,1.8121197,,,,,,,,,,,,,,,,, -77200,0.5928953,1.8311502,,,,,,,,,,,,,,,,, -77300,0.47912577,1.8370521,,,,,,,,,,,,,,,,, -77400,0.3036336,1.8101635,,,,,,,,,,,,,,,,, -77500,0.43786514,1.803303,,,,,,,,,,,,,,,,, -77600,0.29622164,1.841124,,,,,,,,,,,,,,,,, -77700,0.3710168,1.8533213,,,,,,,,,,,,,,,,, -77800,0.31177884,1.7945129,,,,,,,,,,,,,,,,, -77900,0.34957483,1.8771529,,,,,,,,,,,,,,,,, -78000,0.30207103,1.8082179,,,,,,,,,,,,,,,,, -78100,0.33174887,1.8066,,,,,,,,,,,,,,,,, -78200,0.29897287,1.8794435,,,,,,,,,,,,,,,,, -78300,0.49190986,1.8649796,,,,,,,,,,,,,,,,, -78400,0.2893806,1.8603103,,,,,,,,,,,,,,,,, -78500,0.31395304,1.7854942,,,,,,,,,,,,,,,,, -78600,0.2747938,1.704001,,,,,,,,,,,,,,,,, -78680,,,0.6351809501647949,1.8063995838165283,30.51627421938524,0.6567060351371765,1.6517740488052368,27.850444978291744,3000.0,0.6647725701332092,1.5881857872009275,27.32055158379405,3003.0,27753.546751499176,46657.269523620605,27753.546751499176,18900.15974450112,1.0552592277526855,0.0 -78700,0.36197823,1.8134842,,,,,,,,,,,,,,,,, -78800,0.5095082,1.8556689,,,,,,,,,,,,,,,,, -78900,0.33421668,1.8221588,,,,,,,,,,,,,,,,, -79000,0.3199886,1.810578,,,,,,,,,,,,,,,,, -79100,0.26928705,1.7386261,,,,,,,,,,,,,,,,, -79200,0.2994283,1.8435082,,,,,,,,,,,,,,,,, -79300,0.2741091,1.7923567,,,,,,,,,,,,,,,,, -79400,0.5493619,1.8030248,,,,,,,,,,,,,,,,, -79500,0.30784696,1.7648922,,,,,,,,,,,,,,,,, -79600,0.46682677,1.8688289,,,,,,,,,,,,,,,,, -79700,0.30305845,1.8020978,,,,,,,,,,,,,,,,, -79800,0.2985605,1.809737,,,,,,,,,,,,,,,,, -79900,0.51813585,1.8345062,,,,,,,,,,,,,,,,, -80000,0.42459393,1.8545885,,,,,,,,,,,,,,,,, -80100,0.3326616,1.772328,,,,,,,,,,,,,,,,, -80200,0.32009116,1.8322132,,,,,,,,,,,,,,,,, -80300,0.3011109,1.8534801,,,,,,,,,,,,,,,,, -80400,0.29584664,1.7580948,,,,,,,,,,,,,,,,, -80500,0.32938966,1.92839,,,,,,,,,,,,,,,,, -80600,0.30998403,1.7815093,,,,,,,,,,,,,,,,, -80700,0.31615818,1.8166466,,,,,,,,,,,,,,,,, -80800,0.31171542,1.7905371,,,,,,,,,,,,,,,,, -80900,0.35047534,1.7160757,,,,,,,,,,,,,,,,, -81000,0.42589727,1.8712679,,,,,,,,,,,,,,,,, -81065,,,0.6338257789611816,1.8154295682907104,30.37254047659617,0.657177209854126,1.6409265995025637,27.84527507937404,3000.0,0.6671895980834961,1.572178840637207,27.59701393455958,3003.0,28593.6186144352,48083.00713443756,28593.6186144352,19485.708025693893,1.0978131294250488,0.0 -81100,0.3266131,1.8296556,,,,,,,,,,,,,,,,, -81200,0.60883176,1.8354789,,,,,,,,,,,,,,,,, -81300,0.34758162,1.8127476,,,,,,,,,,,,,,,,, -81400,0.30507505,1.8003287,,,,,,,,,,,,,,,,, -81500,0.35763496,1.8231956,,,,,,,,,,,,,,,,, -81600,0.30801195,1.7764273,,,,,,,,,,,,,,,,, -81700,0.3403429,1.8055644,,,,,,,,,,,,,,,,, -81800,0.5638218,1.7677101,,,,,,,,,,,,,,,,, -81900,0.40823647,1.7605276,,,,,,,,,,,,,,,,, -82000,0.41523668,1.850051,,,,,,,,,,,,,,,,, -82100,0.36154756,1.6966629,,,,,,,,,,,,,,,,, -82200,0.29598513,1.7989235,,,,,,,,,,,,,,,,, -82300,0.29458416,1.7681031,,,,,,,,,,,,,,,,, -82400,0.35280538,1.8278513,,,,,,,,,,,,,,,,, -82500,0.4144086,1.8068483,,,,,,,,,,,,,,,,, -82600,0.3683485,1.8728638,,,,,,,,,,,,,,,,, -82700,0.295482,1.8099234,,,,,,,,,,,,,,,,, -82800,0.5379762,1.8697839,,,,,,,,,,,,,,,,, -82900,0.33412728,1.7758293,,,,,,,,,,,,,,,,, -83000,0.45786238,1.8025184,,,,,,,,,,,,,,,,, -83100,0.3664122,1.7761683,,,,,,,,,,,,,,,,, -83200,0.2873042,1.7924896,,,,,,,,,,,,,,,,, -83300,0.33534795,1.7711037,,,,,,,,,,,,,,,,, -83400,0.29047552,1.7627784,,,,,,,,,,,,,,,,, -83449,,,0.6420011520385742,1.7547723054885864,31.104798039546385,0.6573383808135986,1.6285314559936523,28.031213845797343,3000.0,0.6720237135887146,1.551444411277771,28.04304378534904,3003.0,29433.67586684227,49598.38488411904,29433.67586684227,20160.911412000656,1.1356170177459717,0.0 -83500,0.6426148,1.7107589,,,,,,,,,,,,,,,,, -83600,0.3513908,1.7714769,,,,,,,,,,,,,,,,, -83700,0.28578508,1.8853971,,,,,,,,,,,,,,,,, -83800,0.27391252,1.7976081,,,,,,,,,,,,,,,,, -83900,0.4103234,1.7886739,,,,,,,,,,,,,,,,, -84000,0.4415932,1.8226877,,,,,,,,,,,,,,,,, -84100,0.42278472,1.7976288,,,,,,,,,,,,,,,,, -84200,0.37144545,1.8112297,,,,,,,,,,,,,,,,, -84300,0.29261363,1.7092967,,,,,,,,,,,,,,,,, -84400,0.33487254,1.8093386,,,,,,,,,,,,,,,,, -84500,0.34127495,1.7386057,,,,,,,,,,,,,,,,, -84600,0.51860255,1.7488519,,,,,,,,,,,,,,,,, -84700,0.31412572,1.7929792,,,,,,,,,,,,,,,,, -84800,0.3202342,1.7773752,,,,,,,,,,,,,,,,, -84900,0.46911234,1.8216631,,,,,,,,,,,,,,,,, -85000,0.30873272,1.7191035,,,,,,,,,,,,,,,,, -85100,0.3679011,1.7695056,,,,,,,,,,,,,,,,, -85200,0.3195132,1.7729995,,,,,,,,,,,,,,,,, -85300,0.29252517,1.7085966,,,,,,,,,,,,,,,,, -85400,0.38744518,1.8450416,,,,,,,,,,,,,,,,, -85500,0.31097326,1.7276595,,,,,,,,,,,,,,,,, -85600,0.30751666,1.7477293,,,,,,,,,,,,,,,,, -85700,0.3935705,1.8398632,,,,,,,,,,,,,,,,, -85800,0.28315246,1.7522424,,,,,,,,,,,,,,,,, -85834,,,0.6379316449165344,1.7899552583694458,31.381088323768665,0.6602894067764282,1.61657977104187,27.88452353796739,3000.0,0.673267126083374,1.5379685163497925,28.295872994291173,3003.0,30273.84391236305,51023.77744102478,30273.84391236305,20746.017300128937,1.178145170211792,0.0 -85900,0.32855472,1.7140317,,,,,,,,,,,,,,,,, -86000,0.44930047,1.7774129,,,,,,,,,,,,,,,,, -86100,0.37324312,1.7018858,,,,,,,,,,,,,,,,, -86200,0.37465692,1.7382095,,,,,,,,,,,,,,,,, -86300,0.40831244,1.7673707,,,,,,,,,,,,,,,,, -86400,0.43414748,1.7819693,,,,,,,,,,,,,,,,, -86500,0.384867,1.8124825,,,,,,,,,,,,,,,,, -86600,0.30648047,1.8001496,,,,,,,,,,,,,,,,, -86700,0.31888986,1.7536604,,,,,,,,,,,,,,,,, -86800,0.3065125,1.7207754,,,,,,,,,,,,,,,,, -86900,0.30932373,1.7686131,,,,,,,,,,,,,,,,, -87000,0.33763197,1.674347,,,,,,,,,,,,,,,,, -87100,0.5443436,1.6347791,,,,,,,,,,,,,,,,, -87200,0.38787052,1.73184,,,,,,,,,,,,,,,,, -87300,0.35650605,1.7053614,,,,,,,,,,,,,,,,, -87400,0.50928104,1.8233018,,,,,,,,,,,,,,,,, -87500,0.36725768,1.6935136,,,,,,,,,,,,,,,,, -87600,0.27041805,1.7143408,,,,,,,,,,,,,,,,, -87700,0.44251305,1.709971,,,,,,,,,,,,,,,,, -87800,0.34390622,1.7764261,,,,,,,,,,,,,,,,, -87900,0.2951613,1.7402669,,,,,,,,,,,,,,,,, -88000,0.36260146,1.7145661,,,,,,,,,,,,,,,,, -88100,0.37403682,1.743901,,,,,,,,,,,,,,,,, -88200,0.33900818,1.8340367,,,,,,,,,,,,,,,,, -88220,,,0.6552841663360596,1.6534067392349243,32.29373394632669,0.6635503768920898,1.599921703338623,28.25207895536669,3000.0,0.6726628541946411,1.5305873155593872,28.065620186470788,3003.0,31114.019277334213,52334.56629276276,31114.019277334213,21216.517703533173,1.216465711593628,0.0 -88300,0.3447927,1.8341565,,,,,,,,,,,,,,,,, -88400,0.31215024,1.7576941,,,,,,,,,,,,,,,,, -88500,0.30397055,1.7036138,,,,,,,,,,,,,,,,, -88600,0.35492477,1.7689495,,,,,,,,,,,,,,,,, -88700,0.3101234,1.7719601,,,,,,,,,,,,,,,,, -88800,0.35773453,1.7943401,,,,,,,,,,,,,,,,, -88900,0.3144371,1.7597526,,,,,,,,,,,,,,,,, -89000,0.40818268,1.7655766,,,,,,,,,,,,,,,,, -89100,0.3129725,1.7445203,,,,,,,,,,,,,,,,, -89200,0.40495536,1.7578517,,,,,,,,,,,,,,,,, -89300,0.2945262,1.6856469,,,,,,,,,,,,,,,,, -89400,0.3237181,1.8305442,,,,,,,,,,,,,,,,, -89500,0.3080404,1.7096387,,,,,,,,,,,,,,,,, -89600,0.31728068,1.6686804,,,,,,,,,,,,,,,,, -89700,0.34959847,1.6908658,,,,,,,,,,,,,,,,, -89800,0.39265084,1.6619792,,,,,,,,,,,,,,,,, -89900,0.42521963,1.6803514,,,,,,,,,,,,,,,,, -90000,0.34576094,1.6820066,,,,,,,,,,,,,,,,, -90100,0.3316023,1.7576019,,,,,,,,,,,,,,,,, -90200,0.3136152,1.7840854,,,,,,,,,,,,,,,,, -90300,0.36735582,1.719633,,,,,,,,,,,,,,,,, -90400,0.33698997,1.6981987,,,,,,,,,,,,,,,,, -90500,0.35673815,1.7378991,,,,,,,,,,,,,,,,, -90600,0.32102656,1.7525066,,,,,,,,,,,,,,,,, -90603,,,0.6459221243858337,1.725430607795715,31.6283829980775,0.6639595031738281,1.5904563665390017,28.32560009538912,3000.0,0.6796583533287048,1.5105689764022827,28.84191883449241,3003.0,31953.927941083908,53740.90544724464,31953.927941083908,21782.82957220077,1.2534148693084717,0.0 -90700,0.30086133,1.7774562,,,,,,,,,,,,,,,,, -90800,0.3564593,1.821412,,,,,,,,,,,,,,,,, -90900,0.30670232,1.7301449,,,,,,,,,,,,,,,,, -91000,0.42051643,1.6950468,,,,,,,,,,,,,,,,, -91100,0.29182848,1.7419981,,,,,,,,,,,,,,,,, -91200,0.3190314,1.6730453,,,,,,,,,,,,,,,,, -91300,0.27951956,1.6502545,,,,,,,,,,,,,,,,, -91400,0.29558617,1.7587562,,,,,,,,,,,,,,,,, -91500,0.74355197,1.7305324,,,,,,,,,,,,,,,,, -91600,0.30238417,1.7227223,,,,,,,,,,,,,,,,, -91700,0.50914615,1.6882533,,,,,,,,,,,,,,,,, -91800,0.6432868,1.6797271,,,,,,,,,,,,,,,,, -91900,0.30648077,1.7187599,,,,,,,,,,,,,,,,, -92000,0.32199323,1.7848802,,,,,,,,,,,,,,,,, -92100,0.32211754,1.7992665,,,,,,,,,,,,,,,,, -92200,0.30776465,1.6702137,,,,,,,,,,,,,,,,, -92300,0.42454404,1.710186,,,,,,,,,,,,,,,,, -92400,0.31829935,1.6755496,,,,,,,,,,,,,,,,, -92500,0.41655964,1.7702973,,,,,,,,,,,,,,,,, -92600,0.28502914,1.832623,,,,,,,,,,,,,,,,, -92700,0.46269864,1.6774089,,,,,,,,,,,,,,,,, -92800,0.34340805,1.7445005,,,,,,,,,,,,,,,,, -92900,0.36080617,1.6991384,,,,,,,,,,,,,,,,, -92989,,,0.6447259783744812,1.7305859327316284,31.694070491390008,0.6674808859825134,1.5697031021118164,28.977459287447072,3000.0,0.6814130544662476,1.4934310913085938,28.43917903285333,3003.0,32794.00373888016,55094.32453846932,32794.00373888016,22296.06116771698,1.2910587787628174,0.0 -93000,0.32886547,1.6646008,,,,,,,,,,,,,,,,, -93100,0.50314605,1.6390383,,,,,,,,,,,,,,,,, -93200,0.3139894,1.7395519,,,,,,,,,,,,,,,,, -93300,0.34048718,1.741576,,,,,,,,,,,,,,,,, -93400,0.346727,1.7564169,,,,,,,,,,,,,,,,, -93500,0.30804884,1.7310247,,,,,,,,,,,,,,,,, -93600,0.32814008,1.7325191,,,,,,,,,,,,,,,,, -93700,0.33968812,1.7998691,,,,,,,,,,,,,,,,, -93800,0.3291237,1.8378758,,,,,,,,,,,,,,,,, -93900,0.3076104,1.66454,,,,,,,,,,,,,,,,, -94000,0.34500226,1.7344103,,,,,,,,,,,,,,,,, -94100,0.29435274,1.7633728,,,,,,,,,,,,,,,,, -94200,0.3045405,1.6820396,,,,,,,,,,,,,,,,, -94300,0.2979248,1.6649675,,,,,,,,,,,,,,,,, -94400,0.43170887,1.7823644,,,,,,,,,,,,,,,,, -94500,0.37664503,1.6957575,,,,,,,,,,,,,,,,, -94600,0.3410908,1.6926082,,,,,,,,,,,,,,,,, -94700,0.3526688,1.7705001,,,,,,,,,,,,,,,,, -94800,0.31996003,1.7902893,,,,,,,,,,,,,,,,, -94900,0.343175,1.7553899,,,,,,,,,,,,,,,,, -95000,0.31879047,1.7017164,,,,,,,,,,,,,,,,, -95100,0.33731252,1.5919253,,,,,,,,,,,,,,,,, -95200,0.31878468,1.7774792,,,,,,,,,,,,,,,,, -95300,0.33598915,1.7254896,,,,,,,,,,,,,,,,, -95375,,,0.6559216380119324,1.6479884386062622,32.57261780221355,0.6692167520523071,1.5522282123565674,28.733379976232605,3000.0,0.6810063719749451,1.4797627925872805,28.512990721724428,3003.0,33634.21266889572,56520.830107450485,33634.21266889572,22882.23731899261,1.3375027179718018,0.0 -95400,0.32707492,1.7079387,,,,,,,,,,,,,,,,, -95500,0.34632072,1.7198817,,,,,,,,,,,,,,,,, -95600,0.2947439,1.6366565,,,,,,,,,,,,,,,,, -95700,0.3606768,1.7101607,,,,,,,,,,,,,,,,, -95800,0.36091474,1.6963625,,,,,,,,,,,,,,,,, -95900,0.34021655,1.6815404,,,,,,,,,,,,,,,,, -96000,0.30524203,1.7555934,,,,,,,,,,,,,,,,, -96100,0.32152686,1.6723133,,,,,,,,,,,,,,,,, -96200,0.31554627,1.7179172,,,,,,,,,,,,,,,,, -96300,0.36708733,1.6676538,,,,,,,,,,,,,,,,, -96400,0.65445226,1.7703395,,,,,,,,,,,,,,,,, -96500,0.30409443,1.6325291,,,,,,,,,,,,,,,,, -96600,0.3564927,1.7143285,,,,,,,,,,,,,,,,, -96700,0.32353866,1.7715988,,,,,,,,,,,,,,,,, -96800,0.30670467,1.7126985,,,,,,,,,,,,,,,,, -96900,0.32619083,1.6891627,,,,,,,,,,,,,,,,, -97000,0.34622347,1.7554132,,,,,,,,,,,,,,,,, -97100,0.28841347,1.6422838,,,,,,,,,,,,,,,,, -97200,0.37041083,1.7438644,,,,,,,,,,,,,,,,, -97300,0.29523647,1.720651,,,,,,,,,,,,,,,,, -97400,0.3576137,1.6499163,,,,,,,,,,,,,,,,, -97500,0.3278434,1.6208807,,,,,,,,,,,,,,,,, -97600,0.36288458,1.6484958,,,,,,,,,,,,,,,,, -97700,0.44528624,1.7603077,,,,,,,,,,,,,,,,, -97760,,,0.6507320404052734,1.6990152597427368,32.16376730049248,0.6716469526290894,1.5430229902267456,28.933022568504928,3000.0,0.6845040917396545,1.4635298252105713,28.75097303659052,3003.0,34474.18307328224,57909.06037116051,34474.18307328224,23430.37916445732,1.3787147998809814,0.0 -97800,0.34043667,1.6428899,,,,,,,,,,,,,,,,, -97900,0.5111045,1.6191902,,,,,,,,,,,,,,,,, -98000,0.33451107,1.7329692,,,,,,,,,,,,,,,,, -98100,0.31005698,1.7190225,,,,,,,,,,,,,,,,, -98200,0.32227,1.6038053,,,,,,,,,,,,,,,,, -98300,0.3238155,1.7372519,,,,,,,,,,,,,,,,, -98400,0.3278571,1.7405549,,,,,,,,,,,,,,,,, -98500,0.39204723,1.6749021,,,,,,,,,,,,,,,,, -98600,0.3246299,1.575461,,,,,,,,,,,,,,,,, -98700,0.35338038,1.6483294,,,,,,,,,,,,,,,,, -98800,0.34758556,1.7238979,,,,,,,,,,,,,,,,, -98900,0.33464435,1.7049949,,,,,,,,,,,,,,,,, -99000,0.36273813,1.6727642,,,,,,,,,,,,,,,,, -99100,0.3182178,1.6590086,,,,,,,,,,,,,,,,, -99200,0.3273616,1.6255984,,,,,,,,,,,,,,,,, -99300,0.37837562,1.6285594,,,,,,,,,,,,,,,,, -99400,0.3604319,1.6994808,,,,,,,,,,,,,,,,, -99500,0.33510926,1.6509689,,,,,,,,,,,,,,,,, -99600,0.34273446,1.6429478,,,,,,,,,,,,,,,,, -99700,0.35236725,1.6824915,,,,,,,,,,,,,,,,, -99800,0.416063,1.6247802,,,,,,,,,,,,,,,,, -99900,0.35162923,1.6306508,,,,,,,,,,,,,,,,, -100000,0.32218286,1.6716532,,,,,,,,,,,,,,,,, -100100,0.38846967,1.708484,,,,,,,,,,,,,,,,, -100145,,,0.6949189305305481,1.418954849243164,35.23068643277161,0.673246443271637,1.528011441230774,29.01353235327636,3000.0,0.6881529688835144,1.4465465545654297,29.32560119300237,3003.0,35314.18997120857,59360.59610676765,35314.18997120857,24041.794478416443,1.417504072189331,0.0 -100200,0.32128885,1.6542927,,,,,,,,,,,,,,,,, -100300,0.31757063,1.5692353,,,,,,,,,,,,,,,,, -100400,0.31235144,1.7160621,,,,,,,,,,,,,,,,, -100500,0.32552862,1.6906613,,,,,,,,,,,,,,,,, -100600,0.34178826,1.6698573,,,,,,,,,,,,,,,,, -100700,0.35122117,1.5895772,,,,,,,,,,,,,,,,, -100800,0.3077885,1.5917103,,,,,,,,,,,,,,,,, -100900,0.35497922,1.7474817,,,,,,,,,,,,,,,,, -101000,0.31795448,1.6604577,,,,,,,,,,,,,,,,, -101100,0.32115746,1.6019144,,,,,,,,,,,,,,,,, -101200,0.42199335,1.6575185,,,,,,,,,,,,,,,,, -101300,0.30882633,1.6210463,,,,,,,,,,,,,,,,, -101400,0.33100984,1.6467886,,,,,,,,,,,,,,,,, -101500,0.30550307,1.7046409,,,,,,,,,,,,,,,,, -101600,0.34559873,1.7377558,,,,,,,,,,,,,,,,, -101700,0.32193157,1.7130663,,,,,,,,,,,,,,,,, -101800,0.4012931,1.7232735,,,,,,,,,,,,,,,,, -101900,0.43535626,1.6992946,,,,,,,,,,,,,,,,, -102000,0.321918,1.6871703,,,,,,,,,,,,,,,,, -102100,0.35033765,1.5698909,,,,,,,,,,,,,,,,, -102200,0.3134712,1.6657474,,,,,,,,,,,,,,,,, -102300,0.37074184,1.6505967,,,,,,,,,,,,,,,,, -102400,0.36549357,1.6331637,,,,,,,,,,,,,,,,, -102500,0.36214462,1.6440103,,,,,,,,,,,,,,,,, -102530,,,0.6625471115112305,1.6055678129196167,32.88659371950133,0.6756394505500793,1.5099588632583618,29.382846968354546,3000.0,0.6910231709480286,1.4268368482589722,29.189676258834265,3003.0,36154.3091506958,60769.31097340584,36154.3091506958,24610.273792028427,1.4581010341644287,0.0 -102600,0.38744986,1.68177,,,,,,,,,,,,,,,,, -102700,0.33944622,1.7348893,,,,,,,,,,,,,,,,, -102800,0.35356852,1.6165767,,,,,,,,,,,,,,,,, -102900,0.5381776,1.6181777,,,,,,,,,,,,,,,,, -103000,0.36720285,1.6162053,,,,,,,,,,,,,,,,, -103100,0.33989024,1.6372101,,,,,,,,,,,,,,,,, -103200,0.30892172,1.6285735,,,,,,,,,,,,,,,,, -103300,0.3237698,1.645628,,,,,,,,,,,,,,,,, -103400,0.32093558,1.6472238,,,,,,,,,,,,,,,,, -103500,0.37482402,1.5742885,,,,,,,,,,,,,,,,, -103600,0.3798993,1.5953542,,,,,,,,,,,,,,,,, -103700,0.33482265,1.6599514,,,,,,,,,,,,,,,,, -103800,0.34055942,1.6166208,,,,,,,,,,,,,,,,, -103900,0.35390297,1.5883383,,,,,,,,,,,,,,,,, -104000,0.32459152,1.6106187,,,,,,,,,,,,,,,,, -104100,0.31331438,1.6115084,,,,,,,,,,,,,,,,, -104200,0.34462604,1.60655,,,,,,,,,,,,,,,,, -104300,0.33796144,1.611033,,,,,,,,,,,,,,,,, -104400,0.34463847,1.639546,,,,,,,,,,,,,,,,, -104500,0.3362143,1.6688107,,,,,,,,,,,,,,,,, -104600,0.3478915,1.641643,,,,,,,,,,,,,,,,, -104700,0.39830327,1.6121365,,,,,,,,,,,,,,,,, -104800,0.3260918,1.573283,,,,,,,,,,,,,,,,, -104900,0.32802543,1.6390864,,,,,,,,,,,,,,,,, -104916,,,0.6604390144348145,1.628412842750549,32.50623094459274,0.6783796548843384,1.500727891921997,29.517135048861,3000.0,0.6924060583114624,1.420128583908081,29.2301595025155,3003.0,36994.25947546959,62180.81967806816,36994.25947546959,25181.71244192124,1.5033392906188965,0.0 -105000,0.35960215,1.6024323,,,,,,,,,,,,,,,,, -105100,0.37794536,1.5454693,,,,,,,,,,,,,,,,, -105200,0.3727764,1.6709561,,,,,,,,,,,,,,,,, -105300,0.32051274,1.6213253,,,,,,,,,,,,,,,,, -105400,0.33870998,1.6112052,,,,,,,,,,,,,,,,, -105500,0.3630618,1.5543965,,,,,,,,,,,,,,,,, -105600,0.3647078,1.5809822,,,,,,,,,,,,,,,,, -105700,0.39213887,1.6073663,,,,,,,,,,,,,,,,, -105800,0.38962457,1.6512785,,,,,,,,,,,,,,,,, -105900,0.41968387,1.6597446,,,,,,,,,,,,,,,,, -106000,0.39265376,1.6202956,,,,,,,,,,,,,,,,, -106100,0.44996837,1.6173524,,,,,,,,,,,,,,,,, -106200,0.34666014,1.4974822,,,,,,,,,,,,,,,,, -106300,0.34685153,1.6108682,,,,,,,,,,,,,,,,, -106400,0.34146777,1.5594552,,,,,,,,,,,,,,,,, -106500,0.40361533,1.6441886,,,,,,,,,,,,,,,,, -106600,0.39864933,1.6817236,,,,,,,,,,,,,,,,, -106700,0.38117936,1.5737265,,,,,,,,,,,,,,,,, -106800,0.33891997,1.6205239,,,,,,,,,,,,,,,,, -106900,0.32937375,1.558229,,,,,,,,,,,,,,,,, -107000,0.3625436,1.6266901,,,,,,,,,,,,,,,,, -107100,0.37703466,1.6116652,,,,,,,,,,,,,,,,, -107200,0.36051935,1.5944827,,,,,,,,,,,,,,,,, -107300,,,0.6785178780555725,1.51469624042511,33.844285592149795,0.6805495023727417,1.4880130290985107,29.774191434883186,3000.0,0.693591296672821,1.4022724628448486,29.43719279595945,3003.0,37834.20816755295,63637.71153593063,37834.20816755295,25798.53511714936,1.545297145843506,0.0 -107300,0.38518292,1.6133595,,,,,,,,,,,,,,,,, -107400,0.3430658,1.6787164,,,,,,,,,,,,,,,,, -107500,0.31206673,1.5486679,,,,,,,,,,,,,,,,, -107600,0.36333442,1.5239109,,,,,,,,,,,,,,,,, -107700,0.35437182,1.614286,,,,,,,,,,,,,,,,, -107800,0.4167279,1.6759499,,,,,,,,,,,,,,,,, -107900,0.36005905,1.5988644,,,,,,,,,,,,,,,,, -108000,0.33203402,1.6050963,,,,,,,,,,,,,,,,, -108100,0.3282697,1.5762763,,,,,,,,,,,,,,,,, -108200,0.36938426,1.6208025,,,,,,,,,,,,,,,,, -108300,0.3677914,1.6461883,,,,,,,,,,,,,,,,, -108400,0.387215,1.6292224,,,,,,,,,,,,,,,,, -108500,0.35293055,1.5894148,,,,,,,,,,,,,,,,, -108600,0.37918752,1.6510985,,,,,,,,,,,,,,,,, -108700,0.3461452,1.5755898,,,,,,,,,,,,,,,,, -108800,0.33670315,1.6125182,,,,,,,,,,,,,,,,, -108900,0.35942897,1.5308309,,,,,,,,,,,,,,,,, -109000,0.38645387,1.5723451,,,,,,,,,,,,,,,,, -109100,0.367377,1.5902523,,,,,,,,,,,,,,,,, -109200,0.35101166,1.5682273,,,,,,,,,,,,,,,,, -109300,0.34988242,1.6087908,,,,,,,,,,,,,,,,, -109400,0.34724662,1.6389968,,,,,,,,,,,,,,,,, -109500,0.36581755,1.5629165,,,,,,,,,,,,,,,,, -109600,0.33637756,1.5695487,,,,,,,,,,,,,,,,, -109684,,,0.6720339059829712,1.5570589303970337,32.979392765729266,0.6831905245780945,1.4739441871643066,30.14964707735585,3000.0,0.6969845294952393,1.3863232135772705,29.99551478418762,3003.0,38674.33937954903,65013.49522304535,38674.33937954903,26334.07120013237,1.5848398208618164,0.0 -109700,0.36705807,1.6193238,,,,,,,,,,,,,,,,, -109800,0.3634005,1.5728836,,,,,,,,,,,,,,,,, -109900,0.3742885,1.6160895,,,,,,,,,,,,,,,,, -110000,0.3591244,1.5850502,,,,,,,,,,,,,,,,, -110100,0.3518682,1.5817614,,,,,,,,,,,,,,,,, -110200,0.36327374,1.6058589,,,,,,,,,,,,,,,,, -110300,0.36073866,1.6334182,,,,,,,,,,,,,,,,, -110400,0.34176108,1.551437,,,,,,,,,,,,,,,,, -110500,0.3662826,1.5741403,,,,,,,,,,,,,,,,, -110600,0.35055038,1.5700567,,,,,,,,,,,,,,,,, -110700,0.42009723,1.5603367,,,,,,,,,,,,,,,,, -110800,0.3770843,1.5701729,,,,,,,,,,,,,,,,, -110900,0.3172313,1.6577563,,,,,,,,,,,,,,,,, -111000,0.33001065,1.571312,,,,,,,,,,,,,,,,, -111100,0.340479,1.5543764,,,,,,,,,,,,,,,,, -111200,0.34067997,1.518222,,,,,,,,,,,,,,,,, -111300,0.34588438,1.5710001,,,,,,,,,,,,,,,,, -111400,0.39012104,1.5619285,,,,,,,,,,,,,,,,, -111500,0.36514527,1.5923443,,,,,,,,,,,,,,,,, -111600,0.35852176,1.5745395,,,,,,,,,,,,,,,,, -111700,0.3824893,1.5298231,,,,,,,,,,,,,,,,, -111800,0.38763928,1.4815196,,,,,,,,,,,,,,,,, -111900,0.34844106,1.5562668,,,,,,,,,,,,,,,,, -112000,0.33225608,1.5143358,,,,,,,,,,,,,,,,, -112068,,,0.6718950271606445,1.5602552890777588,33.424706492563324,0.6845420598983765,1.466409683227539,30.09190590265441,3000.0,0.6990064382553101,1.3751919269561768,29.9190960808925,3003.0,39514.344621658325,66358.30993127823,39514.344621658325,26838.75919866562,1.627079963684082,0.0 -112100,0.33450463,1.5507413,,,,,,,,,,,,,,,,, -112200,0.3471371,1.5036021,,,,,,,,,,,,,,,,, -112300,0.38729116,1.556885,,,,,,,,,,,,,,,,, -112400,0.34817848,1.55305,,,,,,,,,,,,,,,,, -112500,0.3526376,1.5320885,,,,,,,,,,,,,,,,, -112600,0.3844298,1.5507755,,,,,,,,,,,,,,,,, -112700,0.36628255,1.5908539,,,,,,,,,,,,,,,,, -112800,0.39453718,1.5993726,,,,,,,,,,,,,,,,, -112900,0.37738612,1.4923893,,,,,,,,,,,,,,,,, -113000,0.36060655,1.5730531,,,,,,,,,,,,,,,,, -113100,0.38095823,1.6016994,,,,,,,,,,,,,,,,, -113200,0.40952784,1.6031084,,,,,,,,,,,,,,,,, -113300,0.39075398,1.6002358,,,,,,,,,,,,,,,,, -113400,0.36026174,1.5606031,,,,,,,,,,,,,,,,, -113500,0.35217485,1.4685549,,,,,,,,,,,,,,,,, -113600,0.35877362,1.5241476,,,,,,,,,,,,,,,,, -113700,0.35489407,1.5068438,,,,,,,,,,,,,,,,, -113800,0.34639707,1.5113745,,,,,,,,,,,,,,,,, -113900,0.39535707,1.553251,,,,,,,,,,,,,,,,, -114000,0.3559461,1.4711984,,,,,,,,,,,,,,,,, -114100,0.36526912,1.5424023,,,,,,,,,,,,,,,,, -114200,0.327588,1.4645786,,,,,,,,,,,,,,,,, -114300,0.35719156,1.4488647,,,,,,,,,,,,,,,,, -114400,0.38367933,1.5074046,,,,,,,,,,,,,,,,, -114454,,,0.68068528175354,1.5048444271087646,34.31321273170774,0.6864762902259827,1.4497908353805542,30.02917097359945,3000.0,0.7002847194671631,1.361093282699585,30.209978551978672,3003.0,40354.538182258606,67835.46319293976,40354.538182258606,27475.601682901382,1.6685771942138672,0.0 -114500,0.35704294,1.5326498,,,,,,,,,,,,,,,,, -114600,0.3644631,1.5979041,,,,,,,,,,,,,,,,, -114700,0.39299494,1.5320698,,,,,,,,,,,,,,,,, -114800,0.3770495,1.5119214,,,,,,,,,,,,,,,,, -114900,0.38998646,1.6122335,,,,,,,,,,,,,,,,, -115000,0.37531745,1.5123857,,,,,,,,,,,,,,,,, -115100,0.3842918,1.5094978,,,,,,,,,,,,,,,,, -115200,0.39382765,1.5743052,,,,,,,,,,,,,,,,, -115300,0.34965235,1.5083601,,,,,,,,,,,,,,,,, -115400,0.35034838,1.5424345,,,,,,,,,,,,,,,,, -115500,0.36730975,1.5033151,,,,,,,,,,,,,,,,, -115600,0.36031613,1.513699,,,,,,,,,,,,,,,,, -115700,0.35416633,1.4498549,,,,,,,,,,,,,,,,, -115800,0.36223048,1.5645379,,,,,,,,,,,,,,,,, -115900,0.36924294,1.510747,,,,,,,,,,,,,,,,, -116000,0.37559333,1.5150905,,,,,,,,,,,,,,,,, -116100,0.35671473,1.5095072,,,,,,,,,,,,,,,,, -116200,0.38250184,1.5803238,,,,,,,,,,,,,,,,, -116300,0.36471656,1.414807,,,,,,,,,,,,,,,,, -116400,0.35309818,1.5101937,,,,,,,,,,,,,,,,, -116500,0.35765684,1.4669877,,,,,,,,,,,,,,,,, -116600,0.41611198,1.4506167,,,,,,,,,,,,,,,,, -116700,0.3626908,1.5520468,,,,,,,,,,,,,,,,, -116800,0.38435677,1.5511783,,,,,,,,,,,,,,,,, -116839,,,0.6848151087760925,1.493071436882019,33.9288495000838,0.6874682307243347,1.4417047500610352,30.30958769276966,3000.0,0.7036197781562805,1.350566267967224,30.64315992147018,3003.0,41194.56631875038,69211.07856273651,41194.56631875038,28011.064224004745,1.7157025337219238,0.0 -116900,0.41244248,1.5413482,,,,,,,,,,,,,,,,, -117000,0.35304615,1.4364353,,,,,,,,,,,,,,,,, -117100,0.36534747,1.5213648,,,,,,,,,,,,,,,,, -117200,0.39595944,1.514991,,,,,,,,,,,,,,,,, -117300,0.38030243,1.5716308,,,,,,,,,,,,,,,,, -117400,0.38972104,1.5237905,,,,,,,,,,,,,,,,, -117500,0.3831879,1.4847984,,,,,,,,,,,,,,,,, -117600,0.37605745,1.4848783,,,,,,,,,,,,,,,,, -117700,0.35591805,1.4680369,,,,,,,,,,,,,,,,, -117800,0.4116698,1.4788411,,,,,,,,,,,,,,,,, -117900,0.37225708,1.4599578,,,,,,,,,,,,,,,,, -118000,0.37944546,1.3837265,,,,,,,,,,,,,,,,, -118100,0.3646163,1.4318203,,,,,,,,,,,,,,,,, -118200,0.37963632,1.5231023,,,,,,,,,,,,,,,,, -118300,0.38364598,1.4996101,,,,,,,,,,,,,,,,, -118400,0.34626657,1.4753182,,,,,,,,,,,,,,,,, -118500,0.3897905,1.6057813,,,,,,,,,,,,,,,,, -118600,0.39290866,1.6413624,,,,,,,,,,,,,,,,, -118700,0.3980822,1.551899,,,,,,,,,,,,,,,,, -118800,0.3766207,1.4058914,,,,,,,,,,,,,,,,, -118900,0.36918736,1.5122758,,,,,,,,,,,,,,,,, -119000,0.3771348,1.4694484,,,,,,,,,,,,,,,,, -119100,0.37479624,1.432432,,,,,,,,,,,,,,,,, -119200,0.4078867,1.5149113,,,,,,,,,,,,,,,,, -119225,,,0.6958141326904297,1.417291522026062,35.5492473489212,0.6902580261230469,1.4339897632598877,30.490945218469427,3000.0,0.7048050761222839,1.3403512239456177,30.31912476620588,3003.0,42034.72462916374,70620.8813958168,42034.72462916374,28580.59091734886,1.7584803104400637,0.0 -119300,0.38263524,1.4653058,,,,,,,,,,,,,,,,, -119400,0.3988951,1.5019095,,,,,,,,,,,,,,,,, -119500,0.3714422,1.4808377,,,,,,,,,,,,,,,,, -119600,0.37718934,1.4812546,,,,,,,,,,,,,,,,, -119700,0.37528494,1.4030288,,,,,,,,,,,,,,,,, -119800,0.38749892,1.4906301,,,,,,,,,,,,,,,,, -119900,0.38917843,1.4921902,,,,,,,,,,,,,,,,, -120000,0.39550078,1.4695652,,,,,,,,,,,,,,,,, -120100,0.44112027,1.4508278,,,,,,,,,,,,,,,,, -120200,0.3770341,1.4868882,,,,,,,,,,,,,,,,, -120300,0.40447462,1.4880841,,,,,,,,,,,,,,,,, -120400,0.37835595,1.5326567,,,,,,,,,,,,,,,,, -120500,0.39250633,1.4044151,,,,,,,,,,,,,,,,, -120600,0.39944658,1.5087082,,,,,,,,,,,,,,,,, -120700,0.4074144,1.4669703,,,,,,,,,,,,,,,,, -120800,0.36363545,1.5050141,,,,,,,,,,,,,,,,, -120900,0.38090014,1.4215994,,,,,,,,,,,,,,,,, -121000,0.37991527,1.4394096,,,,,,,,,,,,,,,,, -121100,0.38879147,1.5044206,,,,,,,,,,,,,,,,, -121200,0.38206673,1.407098,,,,,,,,,,,,,,,,, -121300,0.39592016,1.3990159,,,,,,,,,,,,,,,,, -121400,0.3826805,1.4530634,,,,,,,,,,,,,,,,, -121500,0.40085027,1.5658602,,,,,,,,,,,,,,,,, -121600,0.4033983,1.4053308,,,,,,,,,,,,,,,,, -121610,,,0.6900114417076111,1.4532064199447632,34.7617892280219,0.6903696060180664,1.4234853982925415,30.629143421025496,3000.0,0.7055488228797913,1.332283616065979,30.548803892158418,3003.0,42874.72810292244,72061.10469961166,42874.72810292244,29180.690301418304,1.8024253845214844,0.0 -121700,0.4120974,1.4300536,,,,,,,,,,,,,,,,, -121800,0.39486057,1.512807,,,,,,,,,,,,,,,,, -121900,0.3785688,1.444862,,,,,,,,,,,,,,,,, -122000,0.39495343,1.4146956,,,,,,,,,,,,,,,,, -122100,0.40407065,1.5028158,,,,,,,,,,,,,,,,, -122200,0.3946396,1.4572666,,,,,,,,,,,,,,,,, -122300,0.41670352,1.4244099,,,,,,,,,,,,,,,,, -122400,0.4010874,1.5556049,,,,,,,,,,,,,,,,, -122500,0.39056614,1.4353063,,,,,,,,,,,,,,,,, -122600,0.3903367,1.4313626,,,,,,,,,,,,,,,,, -122700,0.40353504,1.530962,,,,,,,,,,,,,,,,, -122800,0.3991758,1.3881221,,,,,,,,,,,,,,,,, -122900,0.41837537,1.4587922,,,,,,,,,,,,,,,,, -123000,0.37757328,1.4595945,,,,,,,,,,,,,,,,, -123100,0.41112345,1.436873,,,,,,,,,,,,,,,,, -123200,0.39707354,1.4838219,,,,,,,,,,,,,,,,, -123300,0.38641703,1.4259478,,,,,,,,,,,,,,,,, -123400,0.40698856,1.3945293,,,,,,,,,,,,,,,,, -123500,0.39262524,1.4812062,,,,,,,,,,,,,,,,, -123600,0.4180727,1.4413071,,,,,,,,,,,,,,,,, -123700,0.3982821,1.467915,,,,,,,,,,,,,,,,, -123800,0.38983405,1.4504906,,,,,,,,,,,,,,,,, -123900,0.38582328,1.481779,,,,,,,,,,,,,,,,, -123995,,,0.6892982721328735,1.4594749212265017,34.6818584249282,0.6919319033622742,1.4199069738388062,30.84775682494113,3000.0,0.7058973908424377,1.32646906375885,30.80911951363941,3003.0,43714.89873600006,73438.41144561768,43714.89873600006,29717.70730495453,1.845533847808838,0.0 -124000,0.39324144,1.4276334,,,,,,,,,,,,,,,,, -124100,0.4305968,1.4328196,,,,,,,,,,,,,,,,, -124200,0.38534188,1.4721267,,,,,,,,,,,,,,,,, -124300,0.39396608,1.477908,,,,,,,,,,,,,,,,, -124400,0.42898375,1.4845842,,,,,,,,,,,,,,,,, -124500,0.41323435,1.5138855,,,,,,,,,,,,,,,,, -124600,0.4071481,1.4627234,,,,,,,,,,,,,,,,, -124700,0.40155372,1.4923421,,,,,,,,,,,,,,,,, -124800,0.40185466,1.3719159,,,,,,,,,,,,,,,,, -124900,0.39759073,1.4440538,,,,,,,,,,,,,,,,, -125000,0.38730338,1.3480301,,,,,,,,,,,,,,,,, -125100,0.38019523,1.5045885,,,,,,,,,,,,,,,,, -125200,0.41196945,1.4598117,,,,,,,,,,,,,,,,, -125300,0.40158987,1.4630603,,,,,,,,,,,,,,,,, -125400,0.37878594,1.4735756,,,,,,,,,,,,,,,,, -125500,0.40166825,1.4522939,,,,,,,,,,,,,,,,, -125600,0.40867493,1.444552,,,,,,,,,,,,,,,,, -125700,0.3763224,1.4246714,,,,,,,,,,,,,,,,, -125800,0.41742456,1.4812001,,,,,,,,,,,,,,,,, -125900,0.38515744,1.3775058,,,,,,,,,,,,,,,,, -126000,0.42166668,1.4539872,,,,,,,,,,,,,,,,, -126100,0.392099,1.4469901,,,,,,,,,,,,,,,,, -126200,0.42422777,1.4778832,,,,,,,,,,,,,,,,, -126300,0.40289265,1.4625881,,,,,,,,,,,,,,,,, -126380,,,0.697314977645874,1.4237242937088013,36.055260217867456,0.6940025687217712,1.413859486579895,30.82466800846016,3000.0,0.7090000510215759,1.3189260959625244,30.86978933852004,3003.0,44555.02311301232,74843.19479346275,44555.02311301232,30282.248594284058,1.8876619338989256,0.0 -126400,0.39027712,1.4883158,,,,,,,,,,,,,,,,, -126500,0.39126822,1.3868697,,,,,,,,,,,,,,,,, -126600,0.40564266,1.4838222,,,,,,,,,,,,,,,,, -126700,0.4129226,1.4349203,,,,,,,,,,,,,,,,, -126800,0.36812133,1.4065009,,,,,,,,,,,,,,,,, -126900,0.39926776,1.4994079,,,,,,,,,,,,,,,,, -127000,0.3812725,1.3825297,,,,,,,,,,,,,,,,, -127100,0.3946764,1.4093752,,,,,,,,,,,,,,,,, -127200,0.41847947,1.4567734,,,,,,,,,,,,,,,,, -127300,0.3947022,1.4646784,,,,,,,,,,,,,,,,, -127400,0.4157921,1.5974476,,,,,,,,,,,,,,,,, -127500,0.4013287,1.4129403,,,,,,,,,,,,,,,,, -127600,0.39568627,1.4893936,,,,,,,,,,,,,,,,, -127700,0.4051364,1.4427055,,,,,,,,,,,,,,,,, -127800,0.41863114,1.465752,,,,,,,,,,,,,,,,, -127900,0.41757956,1.3975365,,,,,,,,,,,,,,,,, -128000,0.4290406,1.4503008,,,,,,,,,,,,,,,,, -128100,0.39031407,1.401864,,,,,,,,,,,,,,,,, -128200,0.3911999,1.3567168,,,,,,,,,,,,,,,,, -128300,0.39041853,1.3754328,,,,,,,,,,,,,,,,, -128400,0.4099912,1.4141905,,,,,,,,,,,,,,,,, -128500,0.40152872,1.35458,,,,,,,,,,,,,,,,, -128600,0.40689418,1.44241,,,,,,,,,,,,,,,,, -128700,0.38145283,1.4510353,,,,,,,,,,,,,,,,, -128765,,,0.7001673579216003,1.3921157121658323,36.05400553646607,0.6933701634407043,1.4130243062973022,30.57451985032652,3000.0,0.7091976404190063,1.3158727884292605,30.909184782342955,3003.0,45395.007142305374,76256.74099755287,45395.007142305374,30855.69163513184,1.931976318359375,0.0 -128800,0.38767362,1.3766472,,,,,,,,,,,,,,,,, -128900,0.41050896,1.4603736,,,,,,,,,,,,,,,,, -129000,0.42366332,1.4877398,,,,,,,,,,,,,,,,, -129100,0.3991243,1.4075617,,,,,,,,,,,,,,,,, -129200,0.40540558,1.4379939,,,,,,,,,,,,,,,,, -129300,0.41379258,1.4599651,,,,,,,,,,,,,,,,, -129400,0.3883696,1.4165256,,,,,,,,,,,,,,,,, -129500,0.40818444,1.3745406,,,,,,,,,,,,,,,,, -129600,0.39710665,1.4052836,,,,,,,,,,,,,,,,, -129700,0.40714288,1.3861102,,,,,,,,,,,,,,,,, -129800,0.409723,1.3759576,,,,,,,,,,,,,,,,, -129900,0.40701488,1.454733,,,,,,,,,,,,,,,,, -130000,0.3942949,1.3604928,,,,,,,,,,,,,,,,, -130100,0.39097133,1.4082139,,,,,,,,,,,,,,,,, -130200,0.401052,1.4480094,,,,,,,,,,,,,,,,, -130300,0.41192517,1.4253855,,,,,,,,,,,,,,,,, -130400,0.39977777,1.4448062,,,,,,,,,,,,,,,,, -130500,0.39515257,1.391379,,,,,,,,,,,,,,,,, -130600,0.4038377,1.4397758,,,,,,,,,,,,,,,,, -130700,0.38514286,1.3796515,,,,,,,,,,,,,,,,, -130800,0.3833484,1.4407983,,,,,,,,,,,,,,,,, -130900,0.37915802,1.4583211,,,,,,,,,,,,,,,,, -131000,0.41198868,1.4313954,,,,,,,,,,,,,,,,, -131100,0.40547314,1.4609982,,,,,,,,,,,,,,,,, -131150,,,0.6955586671829224,1.4250540733337402,36.22844690063096,0.6936181783676147,1.4108177423477173,30.819000657109928,3000.0,0.7093021869659424,1.3146555423736572,30.79200209401173,3003.0,46235.168941020966,77636.98960375786,46235.168941020966,31395.65654921532,1.9757180213928225,0.0 -131200,0.39994764,1.4472193,,,,,,,,,,,,,,,,, -131300,0.39157686,1.4407232,,,,,,,,,,,,,,,,, -131400,0.41079935,1.4116278,,,,,,,,,,,,,,,,, -131500,0.39510208,1.3642479,,,,,,,,,,,,,,,,, -131600,0.39996853,1.4812453,,,,,,,,,,,,,,,,, -131700,0.415381,1.4818678,,,,,,,,,,,,,,,,, -131800,0.40244368,1.4319555,,,,,,,,,,,,,,,,, -131900,0.38354447,1.330935,,,,,,,,,,,,,,,,, -132000,0.40285042,1.4952744,,,,,,,,,,,,,,,,, -132100,0.3982565,1.3966889,,,,,,,,,,,,,,,,, -132200,0.3987432,1.4218304,,,,,,,,,,,,,,,,, -132300,0.4042187,1.4654621,,,,,,,,,,,,,,,,, -132400,0.39073992,1.3767637,,,,,,,,,,,,,,,,, -132500,0.389808,1.4242538,,,,,,,,,,,,,,,,, -132600,0.3987617,1.4580251,,,,,,,,,,,,,,,,, -132700,0.39177063,1.4036653,,,,,,,,,,,,,,,,, -132800,0.39517882,1.4624963,,,,,,,,,,,,,,,,, -132900,0.40240076,1.4262182,,,,,,,,,,,,,,,,, -133000,0.37875322,1.3552773,,,,,,,,,,,,,,,,, -133100,0.4030785,1.4723406,,,,,,,,,,,,,,,,, -133200,0.39314404,1.432515,,,,,,,,,,,,,,,,, -133300,0.3814773,1.3565934,,,,,,,,,,,,,,,,, -133333,,,0.6961641311645508,1.423098087310791,36.12654870431947,0.6939653754234314,1.4107656478881836,30.91989262350453,3000.0,0.70927894115448,1.3146251440048218,30.84290524056852,3003.0,47003.91788029671,78948.11796474457,47003.91788029671,31937.92197227478,2.0215542316436768,0.0 -133333,,,,,,,,,,,,,,47003.91788029671,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 92dbf3e3d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,59 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -876.4952499866486,0.0,31.33621072769165,1,0,31.33621072769165,0.0007088489946909,0.0,11.19086742401123,3003,907.8314986228944,0.0005905735306441,0.0,11.173606872558594,0.0004835649742744,0.0,11.208685874938965,3000 -1380.3259477615356,0.0205812454223632,871.4278464317322,2326,0,871.4278464317322,0.5088954567909241,16.326466481940816,2.922269582748413,3003,2251.8469796180725,0.5102726817131042,22.3400601817346,2.8770992755889893,0.5084871649742126,18.049709526813007,2.875525951385498,3000 -1861.728487491608,0.0453600883483886,1711.6315701007843,4653,0,1711.6315701007843,0.5914473533630371,22.11066399058141,2.134783983230591,3003,3573.551340818405,0.5776917338371277,27.71633074707988,2.2594664096832275,0.5888209342956543,23.508677243219108,2.168548107147217,3000 -2313.41427898407,0.0724508762359619,2551.706431388855,6981,0,2551.706431388855,0.621997594833374,24.20917412371607,1.889957070350647,3003,4865.410325527191,0.6117618680000305,29.191234059510283,1.9756730794906616,0.6168429255485535,25.35601131488413,1.9347010850906368,3000 -2782.745606899261,0.0980725288391113,3391.958536148072,9308,0,3391.958536148072,0.6395677328109741,25.220842753067227,1.7598669528961182,3003,6175.092213869095,0.6131592392921448,29.21159701320609,1.9403873682022093,0.6278781294822693,26.29559810187768,1.825577735900879,3000 -3258.116788625717,0.1238188743591308,4232.144702672958,11635,0,4232.144702672958,0.6465632319450378,25.978056856540427,1.6927350759506226,3003,7490.7502863407135,0.6183011531829834,30.174969727349268,1.9032855033874512,0.637698233127594,26.727296923815302,1.7577568292617798,3000 -3727.089662313461,0.1512885093688964,5072.34353518486,13963,0,5072.34353518486,0.6531288027763367,26.288234040380576,1.6488524675369265,3003,8800.02228808403,0.6274576783180237,30.42799003102148,1.82904314994812,0.6435381770133972,27.2194505792734,1.7107415199279783,3000 -4267.290963888168,0.1784923076629638,5912.543488740921,16291,0,5912.543488740921,0.6560688018798828,26.360974311764373,1.6182109117507937,3003,10180.52417087555,0.6281635761260986,30.12821649759861,1.831966042518616,0.6473942995071411,27.31210332692129,1.6852028369903564,3000 -4784.1533625125885,0.2073345184326172,6752.588777065277,18617,0,6752.588777065277,0.6614025831222534,27.088659970344388,1.5896934270858765,3003,11537.538368225098,0.631950318813324,30.512143847900504,1.8058927059173584,0.6519696116447449,27.90245230878121,1.6625659465789795,3000 -5275.937527179718,0.2346513271331787,7592.666503190994,20945,0,7592.666503190994,0.6645168662071228,27.15118496076002,1.5687122344970703,3003,12869.502164840698,0.6327740550041199,30.94379672227982,1.7765021324157717,0.6536186933517456,27.590168338973967,1.6433725357055664,3000 -5772.678349494934,0.2628781795501709,8432.74540233612,23273,0,8432.74540233612,0.6667364239692688,27.47751462370135,1.5609209537506104,3003,14206.423070907593,0.6339693069458008,30.655361633286923,1.7776130437850952,0.6550569534301758,27.783212559536747,1.6391043663024902,3000 -6227.92778301239,0.2916662693023681,9272.903683662416,25600,0,9272.903683662416,0.6655627489089966,27.11589523195827,1.5492188930511477,3003,15501.935836553574,0.6479560136795044,31.84595571310796,1.6769108772277832,0.6565200686454773,27.658885874545486,1.623263597488403,3000 -6993.833231925964,0.3207814693450928,10112.873045921326,27927,0,10112.873045921326,0.6682470440864563,27.582243730159053,1.5407686233520508,3003,17107.914113998413,0.6404833793640137,31.049956913695816,1.7259860038757324,0.6574624180793762,27.82174760624576,1.6162623167037964,3000 -7534.302425146103,0.3493866920471191,10952.89643716812,30255,0,10952.89643716812,0.6700947284698486,27.54839713124831,1.525154709815979,3003,18488.50726699829,0.636620819568634,31.611790041850664,1.76451575756073,0.6559621095657349,27.67252673218362,1.605042576789856,3000 -8067.295173883438,0.3822178840637207,11793.089814901352,32583,0,11793.089814901352,0.6725466251373291,27.34830812687249,1.5128222703933716,3003,19861.80053019524,0.6440646648406982,31.360567916844506,1.7038763761520386,0.6583923101425171,27.90004436424476,1.594013810157776,3000 -8553.564245939255,0.4125397205352783,12633.118577480316,34910,0,12633.118577480316,0.6729998588562012,27.251809248166783,1.5030328035354614,3003,21188.203240394592,0.6403083801269531,31.28614280491014,1.734409213066101,0.6616036891937256,28.19766421877236,1.5844405889511108,3000 -9075.40495300293,0.4424059391021728,13473.14170718193,37238,0,13473.14170718193,0.6725582480430603,27.28591892054456,1.502700686454773,3003,22550.16905093193,0.6411272883415222,31.164002746131025,1.7296642065048218,0.6619756817817688,28.1188025528331,1.5773353576660156,3000 -9618.16955590248,0.4737052917480469,14313.19044804573,39565,0,14313.19044804573,0.6753936409950256,27.86312619079604,1.4936915636062622,3003,23933.09076309204,0.6465423107147217,31.69269439740908,1.6878236532211304,0.6620872616767883,28.22098356404612,1.5737826824188232,3000 -10165.055995941162,0.5107996463775635,15153.36160159111,41893,0,15153.36160159111,0.6745802164077759,28.026926555958124,1.4858415126800537,3003,25320.262265205383,0.6447368264198303,31.56412929199626,1.7097009420394895,0.6631907820701599,28.26689393593714,1.5640273094177246,3000 -10733.730982542038,0.5428597927093506,15993.378037929537,44221,0,15993.378037929537,0.6778572201728821,28.2497816979482,1.474234104156494,3003,26729.05960392952,0.6637552380561829,32.93034087770076,1.5781522989273071,0.6653606295585632,28.47223571166506,1.5589519739151,3000 -11356.05702495575,0.5749258995056152,16833.506383895874,46549,0,16833.506383895874,0.6787520051002502,28.08245586929687,1.4658998250961304,3003,28191.62198972702,0.6493027806282043,31.197649864945813,1.6772871017456057,0.6661293506622314,28.40430065663764,1.5487326383590698,3000 -11906.429987430573,0.6067063808441162,17673.46990466118,48877,0,17673.46990466118,0.6791703104972839,28.13524029971456,1.4580789804458618,3003,29582.065421819687,0.6498969793319702,31.67379922402138,1.671565294265747,0.6682496070861816,28.52788848704636,1.5388306379318235,3000 -12463.738171815872,0.638897180557251,18513.54372811317,51205,0,18513.54372811317,0.6821568012237549,28.69368597124545,1.45171320438385,3003,30979.550669908524,0.6544747352600098,32.19056354138812,1.6260355710983276,0.6689439415931702,29.092913139909744,1.538604497909546,3000 -13027.610567808151,0.6803698539733887,19353.507332086563,53533,0,19353.507332086563,0.6848992109298706,28.1405950398642,1.4424127340316772,3003,32383.50212812424,0.6519391536712646,32.23124119536476,1.6503171920776367,0.6698243021965027,28.78384577586669,1.5270503759384155,3000 -13604.082396030426,0.7132043838500977,20193.692486524586,55860,0,20193.692486524586,0.68580561876297,29.12164847468511,1.4303960800170898,3003,33800.267602682114,0.6507437229156494,32.41903297717362,1.672051191329956,0.6705806255340576,29.09612197316751,1.521011233329773,3000 -14238.36943602562,0.7463061809539795,21033.878321647644,58188,0,21033.878321647644,0.6871768236160278,28.67539179023136,1.433237910270691,3003,35274.847340106964,0.6537730693817139,32.227818247947184,1.6372880935668943,0.6702582836151123,28.863251676095945,1.5269571542739868,3000 -14961.058718919754,0.7788941860198975,21873.81586742401,60515,0,21873.81586742401,0.6871535778045654,29.073916460981607,1.4229494333267212,3003,36837.5823905468,0.6578980684280396,32.07658106977582,1.6208689212799072,0.6717089414596558,28.92682264501012,1.5151821374893188,3000 -15496.90365743637,0.8135547637939453,22713.780904769897,62841,0,22713.780904769897,0.6859566569328308,28.6389104597473,1.4162017107009888,3003,38213.50196003914,0.6713626980781555,33.69910168169331,1.508202075958252,0.674411952495575,29.20011186816427,1.5051673650741575,3000 -16053.180361270905,0.8553059101104736,23553.788153648376,65168,0,23553.788153648376,0.6899308562278748,28.924297355169426,1.4054077863693235,3003,39609.90386343002,0.65871661901474,32.66574289741518,1.6069515943527222,0.6746847629547119,29.051769116288494,1.4960321187973022,3000 -16575.54416203499,0.8911569118499756,24393.98091578484,67496,0,24393.98091578484,0.689059317111969,29.383744057926585,1.4040769338607788,3003,40972.572177410126,0.6580450534820557,32.3845854245405,1.6125664710998535,0.6748707294464111,29.354178552156,1.4918677806854248,3000 -17175.61923766136,0.9281957149505616,25234.14297890663,69823,0,25234.14297890663,0.6909418702125549,29.22117196182941,1.3975272178649902,3003,42412.92000794411,0.6646310091018677,33.108476509492625,1.5558394193649292,0.6774001717567444,29.732924813863225,1.4817792177200315,3000 -17716.850410461426,0.9643950462341307,26074.134548187256,72151,0,26074.134548187256,0.6935332417488098,29.335350984322613,1.3811521530151367,3003,43794.25180768967,0.6619483828544617,32.684329463700614,1.584561824798584,0.676978588104248,29.4450521112024,1.4765617847442627,3000 -18214.69597125053,1.0013093948364258,26914.059263944622,74478,0,26914.059263944622,0.69459068775177,29.20670844457939,1.3774783611297607,3003,45132.13513803482,0.6594235301017761,32.67356970090248,1.5999988317489624,0.6777969002723694,29.571455647456062,1.4699158668518066,3000 -18751.840750455856,1.0371296405792236,27754.242109537125,76806,0,27754.242109537125,0.6945674419403076,29.462568269990985,1.3712241649627686,3003,46509.57319569588,0.6694381833076477,33.0531892644505,1.5369466543197632,0.6802147626876831,29.47381621833204,1.4623948335647583,3000 -19320.94375896454,1.0745155811309814,28594.37763762474,79133,0,28594.37763762474,0.6958921551704407,29.431835227142336,1.3631484508514404,3003,47918.92470264435,0.6636430025100708,32.54973127642245,1.570115566253662,0.6804007291793823,29.552745828566785,1.45924973487854,3000 -19830.34655547142,1.111635684967041,29434.364958763123,81461,0,29434.364958763123,0.6979489922523499,29.82934247972239,1.3512593507766724,3003,49268.42466902733,0.6872650384902954,34.87201890341066,1.421802282333374,0.6823846101760864,30.02667336732894,1.4490602016448977,3000 -20590.173165798187,1.1480538845062256,30274.487243413925,83788,0,30274.487243413925,0.6982162594795227,29.646421542412785,1.348582744598389,3003,50868.4852745533,0.6727224588394165,33.53385947656073,1.5115309953689575,0.6825209856033325,29.710900837467445,1.441343903541565,3000 -21151.93498682976,1.1864585876464844,31114.39369344712,86116,0,31114.39369344712,0.7002963423728943,30.026394171826187,1.3392986059188845,3003,52270.26392364502,0.6685099005699158,33.509164043830715,1.5362203121185305,0.6834757328033447,30.021007104354624,1.4383317232131958,3000 -21738.805801153183,1.2259869575500488,31954.47871041298,88444,0,31954.47871041298,0.7007378935813904,30.07429817068781,1.3346697092056274,3003,53697.33389949799,0.6788003444671631,33.85755785036112,1.4714220762252808,0.6860795021057129,30.326511463742506,1.428539156913757,3000 -22414.37882566452,1.265702247619629,32794.429097890854,90772,0,32794.429097890854,0.7035732865333557,30.3331563210614,1.3192323446273804,3003,55212.970027923584,0.6721503138542175,33.66060615621202,1.5121186971664429,0.6863027215003967,30.17695861596039,1.4202866554260254,3000 -23005.836725473404,1.3031470775604248,33634.516040802,93100,0,33634.516040802,0.7037127614021301,30.392667213925773,1.315152645111084,3003,56644.627078294754,0.6721413135528564,33.86056555520624,1.5171512365341189,0.6868358850479126,30.403902679657968,1.417452335357666,3000 -23493.724118709564,1.34435772895813,34474.600940704346,95428,0,34474.600940704346,0.7046656608581543,30.250803153996863,1.3117072582244873,3003,57972.71556472778,0.6813426613807678,34.027032781976644,1.451947808265686,0.6877037882804871,30.14882976469225,1.4141026735305786,3000 -24081.37470555305,1.387169599533081,35314.54950070381,97755,0,35314.54950070381,0.7072570323944092,30.597164472048394,1.2985440492630005,3003,59400.434348106384,0.6785141825675964,34.25009675294788,1.4695278406143188,0.6903820037841797,30.519843494061867,1.4033619165420532,3000 -24619.3114862442,1.4285976886749268,36154.70352482796,100083,0,36154.70352482796,0.706187903881073,30.34443331307036,1.299963116645813,3003,60778.641570568085,0.6936557292938232,35.65112004922339,1.3875911235809326,0.6906548142433167,30.85509335453766,1.400636911392212,3000 -25245.94492340088,1.4772801399230957,36994.59205460549,102411,0,36994.59205460549,0.7067921757698059,30.49037124141108,1.296403884887695,3003,62245.28328108788,0.6837599873542786,34.42078439556147,1.4399648904800415,0.691696286201477,30.65930224150146,1.3963810205459597,3000 -25783.427124261856,1.5160844326019287,37834.66395068169,104739,0,37834.66395068169,0.7083609700202942,30.74846146018644,1.2872958183288574,3003,63622.94854474068,0.6803861260414124,34.19161231852919,1.4522088766098022,0.6913491487503052,30.690073446109658,1.3914282321929932,3000 -26338.833918571472,1.5553884506225586,38674.80152177811,107068,0,38674.80152177811,0.708570122718811,30.67077610265788,1.2839752435684204,3003,65018.60505485535,0.691776692867279,34.81961982311779,1.3914330005645752,0.6922914981842041,30.58321900927457,1.3859933614730835,3000 -26914.005674123764,1.5969746112823486,39515.00077295303,109397,0,39515.00077295303,0.7111498713493347,31.09661239196818,1.274350881576538,3003,66434.0912425518,0.6863054037094116,34.77338731179625,1.422512769699097,0.692328691482544,30.9349353146716,1.3812284469604492,3000 -27458.749623537064,1.6440293788909912,40354.97988796234,111725,0,40354.97988796234,0.7117308974266052,31.12296226341149,1.268806219100952,3003,67818.93628644943,0.6838766932487488,34.7839582194981,1.4372414350509644,0.6921178698539734,30.739287863649857,1.376046061515808,3000 -28038.789351701736,1.6860826015472412,41195.22360944748,114053,0,41195.22360944748,0.7118470668792725,31.035798009097697,1.2679171562194824,3003,69239.3367304802,0.6951395273208618,35.48070667530032,1.3722827434539795,0.6938289403915405,30.685051686956182,1.3752260208129885,3000 -28615.922476530075,1.726020097732544,42035.293724536896,116381,0,42035.293724536896,0.713706374168396,31.34912036522872,1.2614657878875732,3003,70656.65222883224,0.6891108155250549,35.24992178839545,1.4094853401184082,0.6954160332679749,30.92770434969378,1.370386242866516,3000 -29169.59684228897,1.7702577114105225,42875.2884888649,118709,0,42875.2884888649,0.7125559449195862,31.390352867613373,1.260074496269226,3003,72050.43827652931,0.6923836469650269,34.80738047912399,1.3908350467681885,0.6954408288002014,31.010793828916093,1.3662816286087036,3000 -29696.142201185223,1.8214778900146484,43715.2911632061,121036,0,43715.2911632061,0.7132415771484375,31.33621769982612,1.257387399673462,3003,73417.11321353912,0.6926527619361877,35.392830009300184,1.3857016563415527,0.6950812935829163,30.888696962120584,1.3645596504211426,3000 -30242.33323359489,1.863266944885254,44555.2459564209,123364,0,44555.2459564209,0.7135552763938904,31.546370535616788,1.2557114362716677,3003,74803.37408471107,0.6927144527435303,35.22963795099508,1.3827232122421265,0.6956640481948853,30.82173282251918,1.36439049243927,3000 -30818.95568537712,1.906174659729004,45395.13310265541,125691,0,45395.13310265541,0.7137877345085144,31.45919176016124,1.2536863088607788,3003,76220.00295948982,0.6953275799751282,35.52793475734423,1.366080403327942,0.6961103677749634,30.92001661383089,1.3615988492965698,3000 -31372.873242139816,1.95019006729126,46235.07247853279,128018,0,46235.07247853279,0.7140550017356873,31.35299748207434,1.2544089555740356,3003,77613.97647738457,0.6966461539268494,35.648902859429185,1.3577289581298828,0.6961227655410767,30.90892111794316,1.362878441810608,3000 -31922.08831977844,1.9932560920715328,47074.99239993096,130345,0,47074.99239993096,0.7141711711883545,31.45476105787029,1.253180742263794,3003,79003.22927308083,0.6950264573097229,35.27114728847377,1.3692277669906616,0.6961600184440613,30.8710946996457,1.3616727590560913,3000 -32485.429721593857,2.036381959915161,47915.134707927704,132673,0,47915.134707927704,0.7141944169998169,31.444790126895835,1.252968668937683,3003,80406.8292388916,0.6964210271835327,35.60231110750975,1.3611470460891724,0.6960855722427368,30.815935514007432,1.3613717555999756,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/measurements.csv deleted file mode 100644 index b123e54e8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1394 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.0709844,11.151057,,,,,,,,,,,,,,,,, -1,,,0.0005905735306441,11.173606872558594,0.0,0.0004835649742744,11.208685874938965,0.0,3000.0,0.0007088489946909,11.19086742401123,0.0,3003.0,31.33621072769165,907.8314986228944,31.33621072769165,876.4952499866486,0.0,0.0 -100,0.17051536,8.175846,,,,,,,,,,,,,,,,, -200,0.30922076,7.413442,,,,,,,,,,,,,,,,, -300,0.43103236,6.7901664,,,,,,,,,,,,,,,,, -400,0.49881947,6.3037496,,,,,,,,,,,,,,,,, -500,0.42326468,5.8367963,,,,,,,,,,,,,,,,, -600,0.4175631,5.5292673,,,,,,,,,,,,,,,,, -700,0.6128253,5.3610353,,,,,,,,,,,,,,,,, -800,0.513506,4.9849524,,,,,,,,,,,,,,,,, -900,0.40322492,4.8841753,,,,,,,,,,,,,,,,, -1000,0.5988842,4.5885496,,,,,,,,,,,,,,,,, -1100,0.4769583,4.260439,,,,,,,,,,,,,,,,, -1200,0.4557067,4.0947948,,,,,,,,,,,,,,,,, -1300,0.56488097,3.959901,,,,,,,,,,,,,,,,, -1400,0.71995044,3.7761683,,,,,,,,,,,,,,,,, -1500,0.56151974,3.6272004,,,,,,,,,,,,,,,,, -1600,0.44908032,3.4695973,,,,,,,,,,,,,,,,, -1700,0.5819913,3.4005222,,,,,,,,,,,,,,,,, -1800,0.43784136,3.2916994,,,,,,,,,,,,,,,,, -1900,0.36867625,3.2660198,,,,,,,,,,,,,,,,, -2000,0.4670055,3.2658477,,,,,,,,,,,,,,,,, -2100,0.43061733,3.1344328,,,,,,,,,,,,,,,,, -2200,0.36687475,3.0881896,,,,,,,,,,,,,,,,, -2300,0.5213032,3.0543957,,,,,,,,,,,,,,,,, -2326,,,0.5102726817131042,2.8770992755889893,22.3400601817346,0.5084871649742126,2.875525951385498,18.049709526813007,3000.0,0.5088954567909241,2.922269582748413,16.326466481940816,3003.0,871.4278464317322,2251.8469796180725,871.4278464317322,1380.3259477615356,0.0205812454223632,0.0 -2400,0.36342654,2.9743538,,,,,,,,,,,,,,,,, -2500,0.3259593,2.9665034,,,,,,,,,,,,,,,,, -2600,0.34682226,2.8014069,,,,,,,,,,,,,,,,, -2700,0.35967237,2.810858,,,,,,,,,,,,,,,,, -2800,0.29393986,2.760002,,,,,,,,,,,,,,,,, -2900,0.25503564,2.6971724,,,,,,,,,,,,,,,,, -3000,0.27004868,2.756567,,,,,,,,,,,,,,,,, -3100,0.28247172,2.6333408,,,,,,,,,,,,,,,,, -3200,0.32218352,2.6205246,,,,,,,,,,,,,,,,, -3300,0.2616875,2.5980358,,,,,,,,,,,,,,,,, -3400,0.22213274,2.6171083,,,,,,,,,,,,,,,,, -3500,0.29637975,2.5511563,,,,,,,,,,,,,,,,, -3600,0.20662355,2.5642543,,,,,,,,,,,,,,,,, -3700,0.21513434,2.440582,,,,,,,,,,,,,,,,, -3800,0.192832,2.5334651,,,,,,,,,,,,,,,,, -3900,0.21484452,2.5001311,,,,,,,,,,,,,,,,, -4000,0.20695381,2.4875731,,,,,,,,,,,,,,,,, -4100,0.18996176,2.383494,,,,,,,,,,,,,,,,, -4200,0.21068406,2.4106205,,,,,,,,,,,,,,,,, -4300,0.19723204,2.374454,,,,,,,,,,,,,,,,, -4400,0.19219087,2.3879707,,,,,,,,,,,,,,,,, -4500,0.1730839,2.3189752,,,,,,,,,,,,,,,,, -4600,0.20468603,2.3631153,,,,,,,,,,,,,,,,, -4653,,,0.5776917338371277,2.2594664096832275,27.71633074707988,0.5888209342956543,2.168548107147217,23.508677243219108,3000.0,0.5914473533630371,2.134783983230591,22.11066399058141,3003.0,1711.6315701007843,3573.551340818405,1711.6315701007843,1861.728487491608,0.0453600883483886,0.0 -4700,0.16910776,2.2099595,,,,,,,,,,,,,,,,, -4800,0.16594855,2.2872283,,,,,,,,,,,,,,,,, -4900,0.16478367,2.3121145,,,,,,,,,,,,,,,,, -5000,0.16926092,2.3293014,,,,,,,,,,,,,,,,, -5100,0.17939836,2.3069608,,,,,,,,,,,,,,,,, -5200,0.16439685,2.2808003,,,,,,,,,,,,,,,,, -5300,0.17661713,2.2352204,,,,,,,,,,,,,,,,, -5400,0.16243315,2.2494562,,,,,,,,,,,,,,,,, -5500,0.18624766,2.1724463,,,,,,,,,,,,,,,,, -5600,0.1685546,2.257855,,,,,,,,,,,,,,,,, -5700,0.16008414,2.2052104,,,,,,,,,,,,,,,,, -5800,0.15324405,2.2850163,,,,,,,,,,,,,,,,, -5900,0.15623178,2.2348733,,,,,,,,,,,,,,,,, -6000,0.17860211,2.1671853,,,,,,,,,,,,,,,,, -6100,0.17483544,2.1822672,,,,,,,,,,,,,,,,, -6200,0.16094643,2.244263,,,,,,,,,,,,,,,,, -6300,0.15026212,2.2089448,,,,,,,,,,,,,,,,, -6400,0.15895215,2.2219872,,,,,,,,,,,,,,,,, -6500,0.14836642,2.1837509,,,,,,,,,,,,,,,,, -6600,0.15252608,2.0936253,,,,,,,,,,,,,,,,, -6700,0.17327131,2.10805,,,,,,,,,,,,,,,,, -6800,0.18554749,2.184935,,,,,,,,,,,,,,,,, -6900,0.18405734,2.1079516,,,,,,,,,,,,,,,,, -6981,,,0.6117618680000305,1.9756730794906616,29.191234059510283,0.6168429255485535,1.9347010850906368,25.35601131488413,3000.0,0.621997594833374,1.889957070350647,24.20917412371607,3003.0,2551.706431388855,4865.410325527191,2551.706431388855,2313.41427898407,0.0724508762359619,0.0 -7000,0.17481825,2.1866615,,,,,,,,,,,,,,,,, -7100,0.17612176,2.2159398,,,,,,,,,,,,,,,,, -7200,0.1467012,2.134803,,,,,,,,,,,,,,,,, -7300,0.17321372,2.0827978,,,,,,,,,,,,,,,,, -7400,0.15234576,2.1269658,,,,,,,,,,,,,,,,, -7500,0.15725566,2.1070344,,,,,,,,,,,,,,,,, -7600,0.16088253,2.1898124,,,,,,,,,,,,,,,,, -7700,0.20593289,2.050563,,,,,,,,,,,,,,,,, -7800,0.22697029,1.9785173,,,,,,,,,,,,,,,,, -7900,0.16223103,2.1358523,,,,,,,,,,,,,,,,, -8000,0.23068042,2.0698783,,,,,,,,,,,,,,,,, -8100,0.1605311,2.0045884,,,,,,,,,,,,,,,,, -8200,0.30591372,1.9854543,,,,,,,,,,,,,,,,, -8300,0.17978479,2.1443958,,,,,,,,,,,,,,,,, -8400,0.16597904,2.0792637,,,,,,,,,,,,,,,,, -8500,0.17732736,2.04002,,,,,,,,,,,,,,,,, -8600,0.16704032,2.0897238,,,,,,,,,,,,,,,,, -8700,0.18118964,2.0459723,,,,,,,,,,,,,,,,, -8800,0.16240081,2.016642,,,,,,,,,,,,,,,,, -8900,0.16506347,2.0427547,,,,,,,,,,,,,,,,, -9000,0.17272677,2.1017985,,,,,,,,,,,,,,,,, -9100,0.2088827,2.051776,,,,,,,,,,,,,,,,, -9200,0.27239344,2.0599601,,,,,,,,,,,,,,,,, -9300,0.23550856,1.9933805,,,,,,,,,,,,,,,,, -9308,,,0.6131592392921448,1.9403873682022093,29.21159701320609,0.6278781294822693,1.825577735900879,26.29559810187768,3000.0,0.6395677328109741,1.7598669528961182,25.220842753067227,3003.0,3391.958536148072,6175.092213869095,3391.958536148072,2782.745606899261,0.0980725288391113,0.0 -9400,0.177053,2.0302815,,,,,,,,,,,,,,,,, -9500,0.17073143,1.9994396,,,,,,,,,,,,,,,,, -9600,0.18084349,1.9566296,,,,,,,,,,,,,,,,, -9700,0.18472674,1.93855,,,,,,,,,,,,,,,,, -9800,0.1484644,1.981031,,,,,,,,,,,,,,,,, -9900,0.22402674,2.1063828,,,,,,,,,,,,,,,,, -10000,0.16792661,2.0257754,,,,,,,,,,,,,,,,, -10100,0.17984357,1.9891852,,,,,,,,,,,,,,,,, -10200,0.18940444,1.9643674,,,,,,,,,,,,,,,,, -10300,0.26952612,1.9811033,,,,,,,,,,,,,,,,, -10400,0.28498542,1.9844111,,,,,,,,,,,,,,,,, -10500,0.2052019,1.9328448,,,,,,,,,,,,,,,,, -10600,0.20295464,2.0393772,,,,,,,,,,,,,,,,, -10700,0.20843336,1.9808723,,,,,,,,,,,,,,,,, -10800,0.22004175,1.9183238,,,,,,,,,,,,,,,,, -10900,0.26619816,2.0944355,,,,,,,,,,,,,,,,, -11000,0.17940079,2.04154,,,,,,,,,,,,,,,,, -11100,0.19660138,1.9974616,,,,,,,,,,,,,,,,, -11200,0.20768562,1.880593,,,,,,,,,,,,,,,,, -11300,0.22907853,2.045151,,,,,,,,,,,,,,,,, -11400,0.19749375,1.9057356,,,,,,,,,,,,,,,,, -11500,0.208663,2.0824926,,,,,,,,,,,,,,,,, -11600,0.19359413,1.9541172,,,,,,,,,,,,,,,,, -11635,,,0.6183011531829834,1.9032855033874512,30.174969727349268,0.637698233127594,1.7577568292617798,26.727296923815302,3000.0,0.6465632319450378,1.6927350759506226,25.978056856540427,3003.0,4232.144702672958,7490.7502863407135,4232.144702672958,3258.116788625717,0.1238188743591308,0.0 -11700,0.26692227,1.926694,,,,,,,,,,,,,,,,, -11800,0.1956186,1.9615716,,,,,,,,,,,,,,,,, -11900,0.19155324,1.9664323,,,,,,,,,,,,,,,,, -12000,0.18898846,2.1034732,,,,,,,,,,,,,,,,, -12100,0.18270794,1.90496,,,,,,,,,,,,,,,,, -12200,0.25805414,1.8799806,,,,,,,,,,,,,,,,, -12300,0.21251172,1.9697341,,,,,,,,,,,,,,,,, -12400,0.19539861,1.9674288,,,,,,,,,,,,,,,,, -12500,0.2115806,1.9074435,,,,,,,,,,,,,,,,, -12600,0.20386337,1.9388584,,,,,,,,,,,,,,,,, -12700,0.17938276,1.9823812,,,,,,,,,,,,,,,,, -12800,0.17615329,1.9904441,,,,,,,,,,,,,,,,, -12900,0.16777249,1.9450704,,,,,,,,,,,,,,,,, -13000,0.2460557,1.9528139,,,,,,,,,,,,,,,,, -13100,0.2238266,1.9462624,,,,,,,,,,,,,,,,, -13200,0.27431476,1.9718045,,,,,,,,,,,,,,,,, -13300,0.20098753,1.9772427,,,,,,,,,,,,,,,,, -13400,0.21460198,1.9441187,,,,,,,,,,,,,,,,, -13500,0.17810524,1.9425279,,,,,,,,,,,,,,,,, -13600,0.26013333,1.8919836,,,,,,,,,,,,,,,,, -13700,0.18034628,2.0019882,,,,,,,,,,,,,,,,, -13800,0.19148242,1.9933902,,,,,,,,,,,,,,,,, -13900,0.25407928,1.9025779,,,,,,,,,,,,,,,,, -13963,,,0.6274576783180237,1.82904314994812,30.42799003102148,0.6435381770133972,1.7107415199279783,27.2194505792734,3000.0,0.6531288027763367,1.6488524675369265,26.288234040380576,3003.0,5072.34353518486,8800.02228808403,5072.34353518486,3727.089662313461,0.1512885093688964,0.0 -14000,0.22834699,1.9577769,,,,,,,,,,,,,,,,, -14100,0.20646201,1.9239925,,,,,,,,,,,,,,,,, -14200,0.17228538,1.8787425,,,,,,,,,,,,,,,,, -14300,0.17694487,1.9631793,,,,,,,,,,,,,,,,, -14400,0.17833194,1.8710138,,,,,,,,,,,,,,,,, -14500,0.18585081,1.8587021,,,,,,,,,,,,,,,,, -14600,0.25461116,1.9609222,,,,,,,,,,,,,,,,, -14700,0.19770557,1.9408104,,,,,,,,,,,,,,,,, -14800,0.2042112,1.8959146,,,,,,,,,,,,,,,,, -14900,0.21709348,2.0126917,,,,,,,,,,,,,,,,, -15000,0.1892891,1.9183035,,,,,,,,,,,,,,,,, -15100,0.18935047,1.9420377,,,,,,,,,,,,,,,,, -15200,0.24797226,1.8581206,,,,,,,,,,,,,,,,, -15300,0.19564357,1.9653459,,,,,,,,,,,,,,,,, -15400,0.18852592,1.9706203,,,,,,,,,,,,,,,,, -15500,0.19063528,1.9205936,,,,,,,,,,,,,,,,, -15600,0.17300434,1.8984789,,,,,,,,,,,,,,,,, -15700,0.18569459,1.9159188,,,,,,,,,,,,,,,,, -15800,0.20816721,1.8574986,,,,,,,,,,,,,,,,, -15900,0.18543781,1.8625932,,,,,,,,,,,,,,,,, -16000,0.20911168,1.9137378,,,,,,,,,,,,,,,,, -16100,0.19298534,1.8728937,,,,,,,,,,,,,,,,, -16200,0.18604183,1.9159691,,,,,,,,,,,,,,,,, -16291,,,0.6281635761260986,1.831966042518616,30.12821649759861,0.6473942995071411,1.6852028369903564,27.31210332692129,3000.0,0.6560688018798828,1.6182109117507937,26.360974311764373,3003.0,5912.543488740921,10180.52417087555,5912.543488740921,4267.290963888168,0.1784923076629638,0.0 -16300,0.19864553,1.8655977,,,,,,,,,,,,,,,,, -16400,0.20833585,1.8945135,,,,,,,,,,,,,,,,, -16500,0.17366445,1.9524864,,,,,,,,,,,,,,,,, -16600,0.1950228,1.902562,,,,,,,,,,,,,,,,, -16700,0.20138754,1.8490326,,,,,,,,,,,,,,,,, -16800,0.18687372,1.9153849,,,,,,,,,,,,,,,,, -16900,0.29619712,1.8909795,,,,,,,,,,,,,,,,, -17000,0.20897706,1.8897111,,,,,,,,,,,,,,,,, -17100,0.18332943,1.908059,,,,,,,,,,,,,,,,, -17200,0.23273134,1.9770168,,,,,,,,,,,,,,,,, -17300,0.19770548,1.8683939,,,,,,,,,,,,,,,,, -17400,0.19533159,1.86344,,,,,,,,,,,,,,,,, -17500,0.19088799,1.9946907,,,,,,,,,,,,,,,,, -17600,0.24080452,1.8855822,,,,,,,,,,,,,,,,, -17700,0.21130602,1.8974946,,,,,,,,,,,,,,,,, -17800,0.23150516,1.8953977,,,,,,,,,,,,,,,,, -17900,0.20580858,1.8671145,,,,,,,,,,,,,,,,, -18000,0.18976156,1.7681766,,,,,,,,,,,,,,,,, -18100,0.23763663,1.9007155,,,,,,,,,,,,,,,,, -18200,0.19213356,1.9407948,,,,,,,,,,,,,,,,, -18300,0.21890098,1.8382019,,,,,,,,,,,,,,,,, -18400,0.20091371,1.901545,,,,,,,,,,,,,,,,, -18500,0.1819535,1.866188,,,,,,,,,,,,,,,,, -18600,0.18485235,1.8568709,,,,,,,,,,,,,,,,, -18617,,,0.631950318813324,1.8058927059173584,30.512143847900504,0.6519696116447449,1.6625659465789795,27.90245230878121,3000.0,0.6614025831222534,1.5896934270858765,27.088659970344388,3003.0,6752.588777065277,11537.538368225098,6752.588777065277,4784.1533625125885,0.2073345184326172,0.0 -18700,0.19188379,1.8561329,,,,,,,,,,,,,,,,, -18800,0.21873793,1.9160986,,,,,,,,,,,,,,,,, -18900,0.21694265,1.869878,,,,,,,,,,,,,,,,, -19000,0.21612976,1.8538196,,,,,,,,,,,,,,,,, -19100,0.21319713,1.9269333,,,,,,,,,,,,,,,,, -19200,0.35122538,1.8971882,,,,,,,,,,,,,,,,, -19300,0.20308238,1.8481027,,,,,,,,,,,,,,,,, -19400,0.23887147,1.9313617,,,,,,,,,,,,,,,,, -19500,0.18912551,1.8362746,,,,,,,,,,,,,,,,, -19600,0.23497486,1.9223368,,,,,,,,,,,,,,,,, -19700,0.2741515,1.9202425,,,,,,,,,,,,,,,,, -19800,0.19809473,1.9063168,,,,,,,,,,,,,,,,, -19900,0.20204636,1.8317763,,,,,,,,,,,,,,,,, -20000,0.20601614,1.884875,,,,,,,,,,,,,,,,, -20100,0.18636042,1.9429408,,,,,,,,,,,,,,,,, -20200,0.18053173,1.783901,,,,,,,,,,,,,,,,, -20300,0.21694359,1.7959388,,,,,,,,,,,,,,,,, -20400,0.21631907,1.9091359,,,,,,,,,,,,,,,,, -20500,0.2678109,1.8310187,,,,,,,,,,,,,,,,, -20600,0.20588307,1.7389715,,,,,,,,,,,,,,,,, -20700,0.26806638,1.9108412,,,,,,,,,,,,,,,,, -20800,0.30203635,1.856831,,,,,,,,,,,,,,,,, -20900,0.21716933,1.8819762,,,,,,,,,,,,,,,,, -20945,,,0.6327740550041199,1.7765021324157717,30.94379672227982,0.6536186933517456,1.6433725357055664,27.590168338973967,3000.0,0.6645168662071228,1.5687122344970703,27.15118496076002,3003.0,7592.666503190994,12869.502164840698,7592.666503190994,5275.937527179718,0.2346513271331787,0.0 -21000,0.33774188,1.8115168,,,,,,,,,,,,,,,,, -21100,0.20403782,1.7956442,,,,,,,,,,,,,,,,, -21200,0.21873984,1.8819089,,,,,,,,,,,,,,,,, -21300,0.19070743,1.8903601,,,,,,,,,,,,,,,,, -21400,0.20643619,1.8935813,,,,,,,,,,,,,,,,, -21500,0.22279651,1.8912234,,,,,,,,,,,,,,,,, -21600,0.19599798,1.9055015,,,,,,,,,,,,,,,,, -21700,0.20272115,1.8511565,,,,,,,,,,,,,,,,, -21800,0.20948659,1.8244194,,,,,,,,,,,,,,,,, -21900,0.19252393,1.8038857,,,,,,,,,,,,,,,,, -22000,0.20933071,1.9000064,,,,,,,,,,,,,,,,, -22100,0.23205267,1.7755038,,,,,,,,,,,,,,,,, -22200,0.18597183,1.8325768,,,,,,,,,,,,,,,,, -22300,0.19511725,1.8432541,,,,,,,,,,,,,,,,, -22400,0.18450761,1.8471242,,,,,,,,,,,,,,,,, -22500,0.18500537,1.7984579,,,,,,,,,,,,,,,,, -22600,0.19988415,1.8723414,,,,,,,,,,,,,,,,, -22700,0.31267396,1.8189567,,,,,,,,,,,,,,,,, -22800,0.20362891,1.7927401,,,,,,,,,,,,,,,,, -22900,0.18393971,1.8486627,,,,,,,,,,,,,,,,, -23000,0.20513768,1.8170818,,,,,,,,,,,,,,,,, -23100,0.23257475,1.8004193,,,,,,,,,,,,,,,,, -23200,0.25847486,1.8751526,,,,,,,,,,,,,,,,, -23273,,,0.6339693069458008,1.7776130437850952,30.655361633286923,0.6550569534301758,1.6391043663024902,27.783212559536747,3000.0,0.6667364239692688,1.5609209537506104,27.47751462370135,3003.0,8432.74540233612,14206.423070907593,8432.74540233612,5772.678349494934,0.2628781795501709,0.0 -23300,0.21056814,1.92067,,,,,,,,,,,,,,,,, -23400,0.20052864,1.8986512,,,,,,,,,,,,,,,,, -23500,0.24915242,1.7669632,,,,,,,,,,,,,,,,, -23600,0.3220394,1.8313652,,,,,,,,,,,,,,,,, -23700,0.21760994,1.7571279,,,,,,,,,,,,,,,,, -23800,0.30785367,1.8155131,,,,,,,,,,,,,,,,, -23900,0.20309518,1.8696667,,,,,,,,,,,,,,,,, -24000,0.20799606,1.7541211,,,,,,,,,,,,,,,,, -24100,0.19995345,1.8026633,,,,,,,,,,,,,,,,, -24200,0.20357008,1.8751637,,,,,,,,,,,,,,,,, -24300,0.21864091,1.8963562,,,,,,,,,,,,,,,,, -24400,0.21913786,1.8431522,,,,,,,,,,,,,,,,, -24500,0.23506515,1.8539202,,,,,,,,,,,,,,,,, -24600,0.22770949,1.841679,,,,,,,,,,,,,,,,, -24700,0.21246764,1.8753618,,,,,,,,,,,,,,,,, -24800,0.24479207,1.8748533,,,,,,,,,,,,,,,,, -24900,0.19930317,1.83525,,,,,,,,,,,,,,,,, -25000,0.28852287,1.7787915,,,,,,,,,,,,,,,,, -25100,0.22409108,1.8993735,,,,,,,,,,,,,,,,, -25200,0.20574632,1.8233887,,,,,,,,,,,,,,,,, -25300,0.21667662,1.8061837,,,,,,,,,,,,,,,,, -25400,0.23964775,1.8313229,,,,,,,,,,,,,,,,, -25500,0.19404604,1.8449478,,,,,,,,,,,,,,,,, -25600,,,0.6479560136795044,1.6769108772277832,31.84595571310796,0.6565200686454773,1.623263597488403,27.658885874545486,3000.0,0.6655627489089966,1.5492188930511477,27.11589523195827,3003.0,9272.903683662416,15501.935836553574,9272.903683662416,6227.92778301239,0.2916662693023681,0.0 -25600,0.2010024,1.8253686,,,,,,,,,,,,,,,,, -25700,0.19072114,1.9195595,,,,,,,,,,,,,,,,, -25800,0.20421647,1.8540149,,,,,,,,,,,,,,,,, -25900,0.2535445,1.8237219,,,,,,,,,,,,,,,,, -26000,0.2654417,1.836747,,,,,,,,,,,,,,,,, -26100,0.19068348,1.8574125,,,,,,,,,,,,,,,,, -26200,3.6450677,2.3459764,,,,,,,,,,,,,,,,, -26300,0.26452032,1.8575739,,,,,,,,,,,,,,,,, -26400,0.2194023,1.7903596,,,,,,,,,,,,,,,,, -26500,0.18275584,1.780358,,,,,,,,,,,,,,,,, -26600,0.18265016,1.7564269,,,,,,,,,,,,,,,,, -26700,0.19314843,1.7974846,,,,,,,,,,,,,,,,, -26800,0.17558005,1.755657,,,,,,,,,,,,,,,,, -26900,0.19268915,1.789379,,,,,,,,,,,,,,,,, -27000,0.21279466,1.8341414,,,,,,,,,,,,,,,,, -27100,0.21307863,1.8031027,,,,,,,,,,,,,,,,, -27200,0.20292535,1.7727201,,,,,,,,,,,,,,,,, -27300,0.1886034,1.8043748,,,,,,,,,,,,,,,,, -27400,0.19407804,1.7959232,,,,,,,,,,,,,,,,, -27500,0.19119495,1.8417548,,,,,,,,,,,,,,,,, -27600,0.21604827,1.7810452,,,,,,,,,,,,,,,,, -27700,0.1904251,1.8266202,,,,,,,,,,,,,,,,, -27800,0.23249052,1.7747461,,,,,,,,,,,,,,,,, -27900,0.23272759,1.8479041,,,,,,,,,,,,,,,,, -27927,,,0.6404833793640137,1.7259860038757324,31.049956913695816,0.6574624180793762,1.6162623167037964,27.82174760624576,3000.0,0.6682470440864563,1.5407686233520508,27.582243730159053,3003.0,10112.873045921326,17107.914113998413,10112.873045921326,6993.833231925964,0.3207814693450928,0.0 -28000,0.23405448,1.8658392,,,,,,,,,,,,,,,,, -28100,0.22081938,1.8461441,,,,,,,,,,,,,,,,, -28200,0.24097835,1.8720798,,,,,,,,,,,,,,,,, -28300,0.21227519,1.8106215,,,,,,,,,,,,,,,,, -28400,0.2093174,1.8150438,,,,,,,,,,,,,,,,, -28500,0.20493057,1.7937442,,,,,,,,,,,,,,,,, -28600,0.27211612,1.8186632,,,,,,,,,,,,,,,,, -28700,0.27619702,1.762826,,,,,,,,,,,,,,,,, -28800,0.30663165,1.8737326,,,,,,,,,,,,,,,,, -28900,0.24123177,1.7589456,,,,,,,,,,,,,,,,, -29000,0.19367632,1.9160675,,,,,,,,,,,,,,,,, -29100,0.21120203,1.7396873,,,,,,,,,,,,,,,,, -29200,0.23362964,1.8095455,,,,,,,,,,,,,,,,, -29300,0.20086205,1.8098894,,,,,,,,,,,,,,,,, -29400,0.19726801,1.9220244,,,,,,,,,,,,,,,,, -29500,0.22976097,1.892355,,,,,,,,,,,,,,,,, -29600,0.24373886,1.8293498,,,,,,,,,,,,,,,,, -29700,0.2102876,1.811088,,,,,,,,,,,,,,,,, -29800,0.20108797,1.7916623,,,,,,,,,,,,,,,,, -29900,0.23656473,1.7937168,,,,,,,,,,,,,,,,, -30000,0.20413032,1.7408955,,,,,,,,,,,,,,,,, -30100,0.2087807,1.7877947,,,,,,,,,,,,,,,,, -30200,0.2304427,1.8555301,,,,,,,,,,,,,,,,, -30255,,,0.636620819568634,1.76451575756073,31.611790041850664,0.6559621095657349,1.605042576789856,27.67252673218362,3000.0,0.6700947284698486,1.525154709815979,27.54839713124831,3003.0,10952.89643716812,18488.50726699829,10952.89643716812,7534.302425146103,0.3493866920471191,0.0 -30300,0.25110847,1.8168391,,,,,,,,,,,,,,,,, -30400,0.24605486,1.7957172,,,,,,,,,,,,,,,,, -30500,0.27699223,1.7994086,,,,,,,,,,,,,,,,, -30600,0.2505512,1.8537196,,,,,,,,,,,,,,,,, -30700,0.28216752,1.8677459,,,,,,,,,,,,,,,,, -30800,0.23866278,1.8393879,,,,,,,,,,,,,,,,, -30900,0.17609055,1.6934261,,,,,,,,,,,,,,,,, -31000,0.2208721,1.8655384,,,,,,,,,,,,,,,,, -31100,0.1976841,1.7676015,,,,,,,,,,,,,,,,, -31200,0.23483628,1.7959503,,,,,,,,,,,,,,,,, -31300,0.20557004,1.8749884,,,,,,,,,,,,,,,,, -31400,0.37966475,1.8839655,,,,,,,,,,,,,,,,, -31500,0.19934766,1.8136889,,,,,,,,,,,,,,,,, -31600,0.18575461,1.75838,,,,,,,,,,,,,,,,, -31700,0.20738922,1.8181448,,,,,,,,,,,,,,,,, -31800,0.23963094,1.6941471,,,,,,,,,,,,,,,,, -31900,0.20220864,1.804381,,,,,,,,,,,,,,,,, -32000,0.20841481,1.8383648,,,,,,,,,,,,,,,,, -32100,2.075461,1.7875264,,,,,,,,,,,,,,,,, -32200,0.23191527,1.8252432,,,,,,,,,,,,,,,,, -32300,0.20302442,1.8397076,,,,,,,,,,,,,,,,, -32400,0.20645179,1.7598432,,,,,,,,,,,,,,,,, -32500,0.19836086,1.7693136,,,,,,,,,,,,,,,,, -32583,,,0.6440646648406982,1.7038763761520386,31.360567916844506,0.6583923101425171,1.594013810157776,27.90004436424476,3000.0,0.6725466251373291,1.5128222703933716,27.34830812687249,3003.0,11793.089814901352,19861.80053019524,11793.089814901352,8067.295173883438,0.3822178840637207,0.0 -32600,0.21493618,1.8425559,,,,,,,,,,,,,,,,, -32700,0.19748801,1.8309935,,,,,,,,,,,,,,,,, -32800,0.25268453,1.8722273,,,,,,,,,,,,,,,,, -32900,0.19876389,1.7854326,,,,,,,,,,,,,,,,, -33000,0.20980373,1.768528,,,,,,,,,,,,,,,,, -33100,0.19383313,1.8293549,,,,,,,,,,,,,,,,, -33200,0.19355631,1.7919059,,,,,,,,,,,,,,,,, -33300,0.22673447,1.764387,,,,,,,,,,,,,,,,, -33400,0.28130203,1.8137411,,,,,,,,,,,,,,,,, -33500,0.20892769,1.8044729,,,,,,,,,,,,,,,,, -33600,0.23557271,1.7593288,,,,,,,,,,,,,,,,, -33700,0.24761221,1.6913471,,,,,,,,,,,,,,,,, -33800,0.24004747,1.696582,,,,,,,,,,,,,,,,, -33900,0.23466676,1.7937863,,,,,,,,,,,,,,,,, -34000,0.20370114,1.7813406,,,,,,,,,,,,,,,,, -34100,0.19330809,1.7551814,,,,,,,,,,,,,,,,, -34200,0.20395313,1.7232151,,,,,,,,,,,,,,,,, -34300,0.22311768,1.8115746,,,,,,,,,,,,,,,,, -34400,0.19046487,1.787739,,,,,,,,,,,,,,,,, -34500,0.21377148,1.7865652,,,,,,,,,,,,,,,,, -34600,0.2175044,1.7269264,,,,,,,,,,,,,,,,, -34700,0.20095882,1.855666,,,,,,,,,,,,,,,,, -34800,0.30980027,1.7985911,,,,,,,,,,,,,,,,, -34900,0.20527929,1.7405151,,,,,,,,,,,,,,,,, -34910,,,0.6403083801269531,1.734409213066101,31.28614280491014,0.6616036891937256,1.5844405889511108,28.19766421877236,3000.0,0.6729998588562012,1.5030328035354614,27.251809248166783,3003.0,12633.118577480316,21188.203240394592,12633.118577480316,8553.564245939255,0.4125397205352783,0.0 -35000,0.2952125,1.8489769,,,,,,,,,,,,,,,,, -35100,0.22104244,1.8067043,,,,,,,,,,,,,,,,, -35200,0.19757564,1.7930707,,,,,,,,,,,,,,,,, -35300,0.19715908,1.746088,,,,,,,,,,,,,,,,, -35400,0.20613748,1.8103529,,,,,,,,,,,,,,,,, -35500,0.20886609,1.8297547,,,,,,,,,,,,,,,,, -35600,0.1802629,1.7597805,,,,,,,,,,,,,,,,, -35700,0.20797375,1.7755176,,,,,,,,,,,,,,,,, -35800,0.20815478,1.8785579,,,,,,,,,,,,,,,,, -35900,0.24346334,1.7999603,,,,,,,,,,,,,,,,, -36000,0.19434436,1.7428269,,,,,,,,,,,,,,,,, -36100,0.2065891,1.8348888,,,,,,,,,,,,,,,,, -36200,0.24034992,1.8260779,,,,,,,,,,,,,,,,, -36300,0.26535746,1.86298,,,,,,,,,,,,,,,,, -36400,0.21404059,1.7996781,,,,,,,,,,,,,,,,, -36500,0.2135314,1.7147495,,,,,,,,,,,,,,,,, -36600,0.18618645,1.7434127,,,,,,,,,,,,,,,,, -36700,0.21407147,1.7402977,,,,,,,,,,,,,,,,, -36800,0.18167204,1.8181955,,,,,,,,,,,,,,,,, -36900,0.2119898,1.8545513,,,,,,,,,,,,,,,,, -37000,0.19630386,1.7819468,,,,,,,,,,,,,,,,, -37100,0.24837838,1.8318144,,,,,,,,,,,,,,,,, -37200,0.23324403,1.8466209,,,,,,,,,,,,,,,,, -37238,,,0.6411272883415222,1.7296642065048218,31.164002746131025,0.6619756817817688,1.5773353576660156,28.1188025528331,3000.0,0.6725582480430603,1.502700686454773,27.28591892054456,3003.0,13473.14170718193,22550.16905093193,13473.14170718193,9075.40495300293,0.4424059391021728,0.0 -37300,0.22836722,1.784407,,,,,,,,,,,,,,,,, -37400,0.18711941,1.7206974,,,,,,,,,,,,,,,,, -37500,0.2493476,1.7864081,,,,,,,,,,,,,,,,, -37600,0.23558402,1.7629317,,,,,,,,,,,,,,,,, -37700,0.22595769,1.6577871,,,,,,,,,,,,,,,,, -37800,0.29475516,1.7675365,,,,,,,,,,,,,,,,, -37900,0.20793685,1.7460754,,,,,,,,,,,,,,,,, -38000,0.20708133,1.7719499,,,,,,,,,,,,,,,,, -38100,0.19310431,1.8058823,,,,,,,,,,,,,,,,, -38200,0.21205218,1.7766676,,,,,,,,,,,,,,,,, -38300,0.20347293,1.6818622,,,,,,,,,,,,,,,,, -38400,0.24202497,1.7943817,,,,,,,,,,,,,,,,, -38500,0.21682431,1.7247698,,,,,,,,,,,,,,,,, -38600,0.22676979,1.7259812,,,,,,,,,,,,,,,,, -38700,0.19294913,1.6992571,,,,,,,,,,,,,,,,, -38800,0.22137748,1.7907825,,,,,,,,,,,,,,,,, -38900,0.23969123,1.721625,,,,,,,,,,,,,,,,, -39000,0.21032752,1.8743775,,,,,,,,,,,,,,,,, -39100,0.19765517,1.7859404,,,,,,,,,,,,,,,,, -39200,0.21925835,1.7702309,,,,,,,,,,,,,,,,, -39300,0.22626723,1.792019,,,,,,,,,,,,,,,,, -39400,0.21950334,1.8591503,,,,,,,,,,,,,,,,, -39500,0.27523813,1.7915226,,,,,,,,,,,,,,,,, -39565,,,0.6465423107147217,1.6878236532211304,31.69269439740908,0.6620872616767883,1.5737826824188232,28.22098356404612,3000.0,0.6753936409950256,1.4936915636062622,27.86312619079604,3003.0,14313.19044804573,23933.09076309204,14313.19044804573,9618.16955590248,0.4737052917480469,0.0 -39600,0.23068139,1.7518942,,,,,,,,,,,,,,,,, -39700,0.30494618,1.7330306,,,,,,,,,,,,,,,,, -39800,0.22036034,1.7776698,,,,,,,,,,,,,,,,, -39900,0.21596786,1.7532319,,,,,,,,,,,,,,,,, -40000,0.19569083,1.8370175,,,,,,,,,,,,,,,,, -40100,0.20646423,1.766671,,,,,,,,,,,,,,,,, -40200,0.23770103,1.7467656,,,,,,,,,,,,,,,,, -40300,0.21222554,1.7532705,,,,,,,,,,,,,,,,, -40400,0.22483224,1.7647994,,,,,,,,,,,,,,,,, -40500,0.20577246,1.7839773,,,,,,,,,,,,,,,,, -40600,0.20785019,1.8075569,,,,,,,,,,,,,,,,, -40700,0.21054853,1.7481523,,,,,,,,,,,,,,,,, -40800,0.19081548,1.7113289,,,,,,,,,,,,,,,,, -40900,0.18192472,1.6853175,,,,,,,,,,,,,,,,, -41000,0.18946178,1.7499869,,,,,,,,,,,,,,,,, -41100,0.25530955,1.823973,,,,,,,,,,,,,,,,, -41200,0.21473603,1.7881037,,,,,,,,,,,,,,,,, -41300,0.24446772,1.7690145,,,,,,,,,,,,,,,,, -41400,0.22126171,1.7091033,,,,,,,,,,,,,,,,, -41500,0.20182654,1.7276577,,,,,,,,,,,,,,,,, -41600,0.2872176,1.7790238,,,,,,,,,,,,,,,,, -41700,0.19507653,1.8119398,,,,,,,,,,,,,,,,, -41800,0.20935887,1.818727,,,,,,,,,,,,,,,,, -41893,,,0.6447368264198303,1.7097009420394895,31.56412929199626,0.6631907820701599,1.5640273094177246,28.26689393593714,3000.0,0.6745802164077759,1.4858415126800537,28.026926555958124,3003.0,15153.36160159111,25320.262265205383,15153.36160159111,10165.055995941162,0.5107996463775635,0.0 -41900,0.21260293,1.7795867,,,,,,,,,,,,,,,,, -42000,0.18477291,1.7542164,,,,,,,,,,,,,,,,, -42100,0.19013135,1.7546366,,,,,,,,,,,,,,,,, -42200,0.213901,1.7519999,,,,,,,,,,,,,,,,, -42300,0.19047733,1.7301284,,,,,,,,,,,,,,,,, -42400,0.24539174,1.7374392,,,,,,,,,,,,,,,,, -42500,0.18769489,1.7328187,,,,,,,,,,,,,,,,, -42600,0.24796402,1.7310258,,,,,,,,,,,,,,,,, -42700,0.18916243,1.740466,,,,,,,,,,,,,,,,, -42800,0.2259639,1.7544681,,,,,,,,,,,,,,,,, -42900,0.21992251,1.7019318,,,,,,,,,,,,,,,,, -43000,0.19717145,1.7511406,,,,,,,,,,,,,,,,, -43100,0.24193022,1.7953135,,,,,,,,,,,,,,,,, -43200,0.21697868,1.7729299,,,,,,,,,,,,,,,,, -43300,0.2104671,1.7120874,,,,,,,,,,,,,,,,, -43400,0.18732992,1.7841271,,,,,,,,,,,,,,,,, -43500,0.20291093,1.7461824,,,,,,,,,,,,,,,,, -43600,0.20982437,1.7255309,,,,,,,,,,,,,,,,, -43700,0.24600117,1.7804682,,,,,,,,,,,,,,,,, -43800,0.2456023,1.7583824,,,,,,,,,,,,,,,,, -43900,0.22395083,1.7313222,,,,,,,,,,,,,,,,, -44000,0.21346517,1.7347133,,,,,,,,,,,,,,,,, -44100,0.22837996,1.7239085,,,,,,,,,,,,,,,,, -44200,0.2154347,1.8230205,,,,,,,,,,,,,,,,, -44221,,,0.6637552380561829,1.5781522989273071,32.93034087770076,0.6653606295585632,1.5589519739151,28.47223571166506,3000.0,0.6778572201728821,1.474234104156494,28.2497816979482,3003.0,15993.378037929537,26729.05960392952,15993.378037929537,10733.730982542038,0.5428597927093506,0.0 -44300,0.19756477,1.7420547,,,,,,,,,,,,,,,,, -44400,0.21004808,1.7760552,,,,,,,,,,,,,,,,, -44500,0.21272698,1.7789766,,,,,,,,,,,,,,,,, -44600,0.20795226,1.7721243,,,,,,,,,,,,,,,,, -44700,0.17978776,1.6576805,,,,,,,,,,,,,,,,, -44800,0.18954487,1.7067866,,,,,,,,,,,,,,,,, -44900,0.20865567,1.9103377,,,,,,,,,,,,,,,,, -45000,0.21565145,1.700529,,,,,,,,,,,,,,,,, -45100,0.21133696,1.7338837,,,,,,,,,,,,,,,,, -45200,0.19749354,1.7940099,,,,,,,,,,,,,,,,, -45300,0.23307349,1.7435743,,,,,,,,,,,,,,,,, -45400,0.2023259,1.7803527,,,,,,,,,,,,,,,,, -45500,0.22248784,1.8409283,,,,,,,,,,,,,,,,, -45600,0.20397192,1.7806716,,,,,,,,,,,,,,,,, -45700,0.19270843,1.7669913,,,,,,,,,,,,,,,,, -45800,0.18678676,1.7476611,,,,,,,,,,,,,,,,, -45900,0.21911432,1.8279519,,,,,,,,,,,,,,,,, -46000,0.2507118,1.7220912,,,,,,,,,,,,,,,,, -46100,0.2192665,1.8226597,,,,,,,,,,,,,,,,, -46200,0.19690232,1.75181,,,,,,,,,,,,,,,,, -46300,0.19524097,1.7405279,,,,,,,,,,,,,,,,, -46400,0.19118127,1.7264187,,,,,,,,,,,,,,,,, -46500,0.1898954,1.6868842,,,,,,,,,,,,,,,,, -46549,,,0.6493027806282043,1.6772871017456057,31.197649864945813,0.6661293506622314,1.5487326383590698,28.40430065663764,3000.0,0.6787520051002502,1.4658998250961304,28.08245586929687,3003.0,16833.506383895874,28191.62198972702,16833.506383895874,11356.05702495575,0.5749258995056152,0.0 -46600,0.3120428,1.753388,,,,,,,,,,,,,,,,, -46700,0.28961864,1.7479478,,,,,,,,,,,,,,,,, -46800,0.2068722,1.7559866,,,,,,,,,,,,,,,,, -46900,0.21482101,1.7499697,,,,,,,,,,,,,,,,, -47000,0.2023796,1.842102,,,,,,,,,,,,,,,,, -47100,0.19728762,1.8444912,,,,,,,,,,,,,,,,, -47200,0.26352584,1.8247414,,,,,,,,,,,,,,,,, -47300,0.20642008,1.672726,,,,,,,,,,,,,,,,, -47400,0.22654244,1.7225896,,,,,,,,,,,,,,,,, -47500,0.19975619,1.7775255,,,,,,,,,,,,,,,,, -47600,0.21140973,1.7062722,,,,,,,,,,,,,,,,, -47700,0.18505819,1.8178428,,,,,,,,,,,,,,,,, -47800,0.2121547,1.7356384,,,,,,,,,,,,,,,,, -47900,0.200638,1.759378,,,,,,,,,,,,,,,,, -48000,0.20345373,1.7113856,,,,,,,,,,,,,,,,, -48100,0.20290205,1.6574371,,,,,,,,,,,,,,,,, -48200,0.20546831,1.7551923,,,,,,,,,,,,,,,,, -48300,0.20564651,1.7728019,,,,,,,,,,,,,,,,, -48400,0.22093388,1.7738205,,,,,,,,,,,,,,,,, -48500,0.21370383,1.7421429,,,,,,,,,,,,,,,,, -48600,0.24572349,1.7512693,,,,,,,,,,,,,,,,, -48700,0.23415558,1.7941713,,,,,,,,,,,,,,,,, -48800,0.18773307,1.6966642,,,,,,,,,,,,,,,,, -48877,,,0.6498969793319702,1.671565294265747,31.67379922402138,0.6682496070861816,1.5388306379318235,28.52788848704636,3000.0,0.6791703104972839,1.4580789804458618,28.13524029971456,3003.0,17673.46990466118,29582.065421819687,17673.46990466118,11906.429987430573,0.6067063808441162,0.0 -48900,0.21160571,1.8115308,,,,,,,,,,,,,,,,, -49000,0.20848958,1.7450941,,,,,,,,,,,,,,,,, -49100,0.21392669,1.7267553,,,,,,,,,,,,,,,,, -49200,0.22148986,1.7113842,,,,,,,,,,,,,,,,, -49300,0.21907085,1.7565427,,,,,,,,,,,,,,,,, -49400,0.20231055,1.7875848,,,,,,,,,,,,,,,,, -49500,0.23583129,1.6897832,,,,,,,,,,,,,,,,, -49600,0.2055211,1.7425739,,,,,,,,,,,,,,,,, -49700,0.20979273,1.8060045,,,,,,,,,,,,,,,,, -49800,0.22421493,1.8254982,,,,,,,,,,,,,,,,, -49900,0.1981269,1.8391424,,,,,,,,,,,,,,,,, -50000,0.42087048,1.8956972,,,,,,,,,,,,,,,,, -50100,0.20781143,1.7541741,,,,,,,,,,,,,,,,, -50200,0.19810303,1.7425069,,,,,,,,,,,,,,,,, -50300,0.64446753,1.7686442,,,,,,,,,,,,,,,,, -50400,0.1969565,1.7614514,,,,,,,,,,,,,,,,, -50500,0.21119444,1.7037362,,,,,,,,,,,,,,,,, -50600,0.24279541,1.7370704,,,,,,,,,,,,,,,,, -50700,0.23216137,1.7355514,,,,,,,,,,,,,,,,, -50800,0.21522443,1.8021194,,,,,,,,,,,,,,,,, -50900,0.23743024,1.7725654,,,,,,,,,,,,,,,,, -51000,0.22121876,1.71103,,,,,,,,,,,,,,,,, -51100,0.20079567,1.6777118,,,,,,,,,,,,,,,,, -51200,0.20252098,1.7148302,,,,,,,,,,,,,,,,, -51205,,,0.6544747352600098,1.6260355710983276,32.19056354138812,0.6689439415931702,1.538604497909546,29.092913139909744,3000.0,0.6821568012237549,1.45171320438385,28.69368597124545,3003.0,18513.54372811317,30979.550669908524,18513.54372811317,12463.738171815872,0.638897180557251,0.0 -51300,0.20864087,1.7782602,,,,,,,,,,,,,,,,, -51400,0.25026894,1.6658393,,,,,,,,,,,,,,,,, -51500,0.21062146,1.6923659,,,,,,,,,,,,,,,,, -51600,0.21696699,1.8084874,,,,,,,,,,,,,,,,, -51700,0.21441394,1.7457088,,,,,,,,,,,,,,,,, -51800,0.21351089,1.817359,,,,,,,,,,,,,,,,, -51900,0.22454631,1.706009,,,,,,,,,,,,,,,,, -52000,0.2179172,1.7162052,,,,,,,,,,,,,,,,, -52100,0.18832865,1.7511696,,,,,,,,,,,,,,,,, -52200,0.25809297,1.8038505,,,,,,,,,,,,,,,,, -52300,0.20048185,1.7189326,,,,,,,,,,,,,,,,, -52400,0.19593358,1.7979848,,,,,,,,,,,,,,,,, -52500,0.20021239,1.7386079,,,,,,,,,,,,,,,,, -52600,0.2042267,1.7573206,,,,,,,,,,,,,,,,, -52700,0.21109748,1.677455,,,,,,,,,,,,,,,,, -52800,0.20950927,1.7333753,,,,,,,,,,,,,,,,, -52900,0.19363877,1.717959,,,,,,,,,,,,,,,,, -53000,0.21898893,1.7105631,,,,,,,,,,,,,,,,, -53100,0.21295878,1.6615039,,,,,,,,,,,,,,,,, -53200,0.19153753,1.6860532,,,,,,,,,,,,,,,,, -53300,0.23341666,1.6421341,,,,,,,,,,,,,,,,, -53400,0.19335134,1.7426631,,,,,,,,,,,,,,,,, -53500,0.19897294,1.7260989,,,,,,,,,,,,,,,,, -53533,,,0.6519391536712646,1.6503171920776367,32.23124119536476,0.6698243021965027,1.5270503759384155,28.78384577586669,3000.0,0.6848992109298706,1.4424127340316772,28.1405950398642,3003.0,19353.507332086563,32383.50212812424,19353.507332086563,13027.610567808151,0.6803698539733887,0.0 -53600,0.18343662,1.6854755,,,,,,,,,,,,,,,,, -53700,0.23112746,1.7793236,,,,,,,,,,,,,,,,, -53800,0.19659011,1.706921,,,,,,,,,,,,,,,,, -53900,0.22652407,1.7192398,,,,,,,,,,,,,,,,, -54000,0.20855215,1.7043067,,,,,,,,,,,,,,,,, -54100,0.21763249,1.85798,,,,,,,,,,,,,,,,, -54200,0.18945868,1.7627,,,,,,,,,,,,,,,,, -54300,0.2133473,1.7362206,,,,,,,,,,,,,,,,, -54400,0.2625072,1.7178725,,,,,,,,,,,,,,,,, -54500,0.21541916,1.8130045,,,,,,,,,,,,,,,,, -54600,0.20307745,1.7490939,,,,,,,,,,,,,,,,, -54700,0.21281879,1.7632673,,,,,,,,,,,,,,,,, -54800,0.20302697,1.675553,,,,,,,,,,,,,,,,, -54900,0.2129876,1.7604896,,,,,,,,,,,,,,,,, -55000,0.2564043,1.7749774,,,,,,,,,,,,,,,,, -55100,0.21181078,1.7119813,,,,,,,,,,,,,,,,, -55200,0.2165962,1.6757421,,,,,,,,,,,,,,,,, -55300,0.1814194,1.6747667,,,,,,,,,,,,,,,,, -55400,0.20624816,1.7211356,,,,,,,,,,,,,,,,, -55500,0.20234235,1.7821789,,,,,,,,,,,,,,,,, -55600,0.25522795,1.8026583,,,,,,,,,,,,,,,,, -55700,0.225532,1.6089114,,,,,,,,,,,,,,,,, -55800,0.5491846,1.7241173,,,,,,,,,,,,,,,,, -55860,,,0.6507437229156494,1.672051191329956,32.41903297717362,0.6705806255340576,1.521011233329773,29.09612197316751,3000.0,0.68580561876297,1.4303960800170898,29.12164847468511,3003.0,20193.692486524586,33800.267602682114,20193.692486524586,13604.082396030426,0.7132043838500977,0.0 -55900,0.20000006,1.697029,,,,,,,,,,,,,,,,, -56000,0.18629651,1.7373415,,,,,,,,,,,,,,,,, -56100,0.20613632,1.7172213,,,,,,,,,,,,,,,,, -56200,0.23845764,1.7987279,,,,,,,,,,,,,,,,, -56300,0.20114973,1.7581148,,,,,,,,,,,,,,,,, -56400,0.19063704,1.6999265,,,,,,,,,,,,,,,,, -56500,0.20253134,1.7654833,,,,,,,,,,,,,,,,, -56600,0.21868092,1.7478559,,,,,,,,,,,,,,,,, -56700,0.1948107,1.6319638,,,,,,,,,,,,,,,,, -56800,0.18534443,1.6986942,,,,,,,,,,,,,,,,, -56900,0.20551097,1.6377658,,,,,,,,,,,,,,,,, -57000,0.20323707,1.589928,,,,,,,,,,,,,,,,, -57100,0.19763228,1.7934064,,,,,,,,,,,,,,,,, -57200,0.22898197,1.756281,,,,,,,,,,,,,,,,, -57300,0.8156439,1.7368932,,,,,,,,,,,,,,,,, -57400,0.24551867,1.6987199,,,,,,,,,,,,,,,,, -57500,0.18649016,1.6765066,,,,,,,,,,,,,,,,, -57600,0.21749163,1.6545304,,,,,,,,,,,,,,,,, -57700,0.19056047,1.6909037,,,,,,,,,,,,,,,,, -57800,0.19958064,1.6481631,,,,,,,,,,,,,,,,, -57900,0.20345093,1.7580544,,,,,,,,,,,,,,,,, -58000,0.2189055,1.7782662,,,,,,,,,,,,,,,,, -58100,0.6585905,1.7887396,,,,,,,,,,,,,,,,, -58188,,,0.6537730693817139,1.6372880935668943,32.227818247947184,0.6702582836151123,1.5269571542739868,28.863251676095945,3000.0,0.6871768236160278,1.433237910270691,28.67539179023136,3003.0,21033.878321647644,35274.847340106964,21033.878321647644,14238.36943602562,0.7463061809539795,0.0 -58200,0.19522609,1.6894777,,,,,,,,,,,,,,,,, -58300,0.21305265,1.6513498,,,,,,,,,,,,,,,,, -58400,0.20678471,1.734565,,,,,,,,,,,,,,,,, -58500,0.19921148,1.7228422,,,,,,,,,,,,,,,,, -58600,0.20689508,1.6788857,,,,,,,,,,,,,,,,, -58700,0.35053536,1.7048912,,,,,,,,,,,,,,,,, -58800,0.22412698,1.6966529,,,,,,,,,,,,,,,,, -58900,0.20242652,1.8122512,,,,,,,,,,,,,,,,, -59000,0.25764847,1.7061186,,,,,,,,,,,,,,,,, -59100,0.19837724,1.6569731,,,,,,,,,,,,,,,,, -59200,0.33186528,1.7904155,,,,,,,,,,,,,,,,, -59300,0.23149689,1.7060784,,,,,,,,,,,,,,,,, -59400,0.20531641,1.7243464,,,,,,,,,,,,,,,,, -59500,0.21259275,1.6744413,,,,,,,,,,,,,,,,, -59600,0.23887812,1.7571284,,,,,,,,,,,,,,,,, -59700,0.21122181,1.6677378,,,,,,,,,,,,,,,,, -59800,0.2010869,1.642187,,,,,,,,,,,,,,,,, -59900,0.1916936,1.6553831,,,,,,,,,,,,,,,,, -60000,0.20543967,1.7479739,,,,,,,,,,,,,,,,, -60100,0.2037205,1.6506615,,,,,,,,,,,,,,,,, -60200,0.411916,1.7372487,,,,,,,,,,,,,,,,, -60300,0.20644858,1.7165103,,,,,,,,,,,,,,,,, -60400,0.20370209,1.6889849,,,,,,,,,,,,,,,,, -60500,0.19709112,1.7118521,,,,,,,,,,,,,,,,, -60515,,,0.6578980684280396,1.6208689212799072,32.07658106977582,0.6717089414596558,1.5151821374893188,28.92682264501012,3000.0,0.6871535778045654,1.4229494333267212,29.073916460981607,3003.0,21873.81586742401,36837.5823905468,21873.81586742401,14961.058718919754,0.7788941860198975,0.0 -60600,0.18903485,1.6881952,,,,,,,,,,,,,,,,, -60700,0.42346403,1.6697896,,,,,,,,,,,,,,,,, -60800,0.2126557,1.7175442,,,,,,,,,,,,,,,,, -60900,0.21319929,1.6512194,,,,,,,,,,,,,,,,, -61000,0.20114902,1.6757251,,,,,,,,,,,,,,,,, -61100,0.30578583,1.7294902,,,,,,,,,,,,,,,,, -61200,0.22205348,1.7556955,,,,,,,,,,,,,,,,, -61300,0.21661168,1.6775159,,,,,,,,,,,,,,,,, -61400,0.19303419,1.6779522,,,,,,,,,,,,,,,,, -61500,0.44595522,1.6530837,,,,,,,,,,,,,,,,, -61600,0.19987632,1.6216178,,,,,,,,,,,,,,,,, -61700,0.1898924,1.6968032,,,,,,,,,,,,,,,,, -61800,0.19508655,1.6913972,,,,,,,,,,,,,,,,, -61900,0.2077607,1.7520278,,,,,,,,,,,,,,,,, -62000,0.21266796,1.7487838,,,,,,,,,,,,,,,,, -62100,0.21435887,1.6730037,,,,,,,,,,,,,,,,, -62200,0.19346659,1.7222047,,,,,,,,,,,,,,,,, -62300,0.1997423,1.6226978,,,,,,,,,,,,,,,,, -62400,0.20762794,1.7023005,,,,,,,,,,,,,,,,, -62500,0.2040907,1.744774,,,,,,,,,,,,,,,,, -62600,0.2194926,1.7703977,,,,,,,,,,,,,,,,, -62700,0.194606,1.7576008,,,,,,,,,,,,,,,,, -62800,0.19249752,1.6800007,,,,,,,,,,,,,,,,, -62841,,,0.6713626980781555,1.508202075958252,33.69910168169331,0.674411952495575,1.5051673650741575,29.20011186816427,3000.0,0.6859566569328308,1.4162017107009888,28.6389104597473,3003.0,22713.780904769897,38213.50196003914,22713.780904769897,15496.90365743637,0.8135547637939453,0.0 -62900,0.21317819,1.7681825,,,,,,,,,,,,,,,,, -63000,0.19206646,1.7254598,,,,,,,,,,,,,,,,, -63100,0.19946393,1.7031945,,,,,,,,,,,,,,,,, -63200,0.22502646,1.7216243,,,,,,,,,,,,,,,,, -63300,0.2598138,1.7307472,,,,,,,,,,,,,,,,, -63400,0.21256284,1.6440142,,,,,,,,,,,,,,,,, -63500,0.22530536,1.6234525,,,,,,,,,,,,,,,,, -63600,0.19676995,1.656512,,,,,,,,,,,,,,,,, -63700,0.20154738,1.6895636,,,,,,,,,,,,,,,,, -63800,0.2370202,1.735593,,,,,,,,,,,,,,,,, -63900,0.21464191,1.766689,,,,,,,,,,,,,,,,, -64000,0.20482227,1.6177343,,,,,,,,,,,,,,,,, -64100,0.21361585,1.6824504,,,,,,,,,,,,,,,,, -64200,0.20114572,1.7640966,,,,,,,,,,,,,,,,, -64300,0.21647608,1.6883053,,,,,,,,,,,,,,,,, -64400,0.19543698,1.6853262,,,,,,,,,,,,,,,,, -64500,0.20113042,1.7645477,,,,,,,,,,,,,,,,, -64600,0.19185877,1.657832,,,,,,,,,,,,,,,,, -64700,0.1957207,1.6040521,,,,,,,,,,,,,,,,, -64800,0.19633539,1.7004968,,,,,,,,,,,,,,,,, -64900,0.22062032,1.7797707,,,,,,,,,,,,,,,,, -65000,0.19932827,1.6416168,,,,,,,,,,,,,,,,, -65100,0.20604654,1.6532794,,,,,,,,,,,,,,,,, -65168,,,0.65871661901474,1.6069515943527222,32.66574289741518,0.6746847629547119,1.4960321187973022,29.051769116288494,3000.0,0.6899308562278748,1.4054077863693235,28.924297355169426,3003.0,23553.788153648376,39609.90386343002,23553.788153648376,16053.180361270905,0.8553059101104736,0.0 -65200,0.22522108,1.6901271,,,,,,,,,,,,,,,,, -65300,0.20660377,1.7879354,,,,,,,,,,,,,,,,, -65400,0.18452075,1.6428308,,,,,,,,,,,,,,,,, -65500,0.23298267,1.6704707,,,,,,,,,,,,,,,,, -65600,0.38206965,1.6166134,,,,,,,,,,,,,,,,, -65700,0.20444778,1.6985368,,,,,,,,,,,,,,,,, -65800,0.21522145,1.6481093,,,,,,,,,,,,,,,,, -65900,0.20352782,1.7358059,,,,,,,,,,,,,,,,, -66000,0.26284498,1.7101455,,,,,,,,,,,,,,,,, -66100,0.20460673,1.692835,,,,,,,,,,,,,,,,, -66200,0.22979015,1.6016905,,,,,,,,,,,,,,,,, -66300,0.19004284,1.6118852,,,,,,,,,,,,,,,,, -66400,0.19867639,1.6348988,,,,,,,,,,,,,,,,, -66500,0.21599302,1.7605176,,,,,,,,,,,,,,,,, -66600,0.2372602,1.7612516,,,,,,,,,,,,,,,,, -66700,0.22302414,1.666915,,,,,,,,,,,,,,,,, -66800,0.2116086,1.6967151,,,,,,,,,,,,,,,,, -66900,0.19544609,1.6346065,,,,,,,,,,,,,,,,, -67000,0.20787363,1.6678897,,,,,,,,,,,,,,,,, -67100,0.1980715,1.6991938,,,,,,,,,,,,,,,,, -67200,0.20178585,1.7678845,,,,,,,,,,,,,,,,, -67300,0.18943527,1.6625926,,,,,,,,,,,,,,,,, -67400,0.31599697,1.6423812,,,,,,,,,,,,,,,,, -67496,,,0.6580450534820557,1.6125664710998535,32.3845854245405,0.6748707294464111,1.4918677806854248,29.354178552156,3000.0,0.689059317111969,1.4040769338607788,29.383744057926585,3003.0,24393.98091578484,40972.572177410126,24393.98091578484,16575.54416203499,0.8911569118499756,0.0 -67500,0.27167392,1.6666784,,,,,,,,,,,,,,,,, -67600,0.21443462,1.707311,,,,,,,,,,,,,,,,, -67700,0.20050691,1.692273,,,,,,,,,,,,,,,,, -67800,0.20517634,1.699089,,,,,,,,,,,,,,,,, -67900,0.19644031,1.5927941,,,,,,,,,,,,,,,,, -68000,0.20912847,1.7699351,,,,,,,,,,,,,,,,, -68100,0.20788774,1.6560491,,,,,,,,,,,,,,,,, -68200,0.22000095,1.6866009,,,,,,,,,,,,,,,,, -68300,0.21393047,1.7217621,,,,,,,,,,,,,,,,, -68400,0.2107259,1.7808126,,,,,,,,,,,,,,,,, -68500,0.20807247,1.6850038,,,,,,,,,,,,,,,,, -68600,0.19394135,1.5996791,,,,,,,,,,,,,,,,, -68700,0.20379199,1.6384737,,,,,,,,,,,,,,,,, -68800,0.21042557,1.7445294,,,,,,,,,,,,,,,,, -68900,0.21962598,1.7233744,,,,,,,,,,,,,,,,, -69000,0.2123954,1.6758488,,,,,,,,,,,,,,,,, -69100,0.22159258,1.7216218,,,,,,,,,,,,,,,,, -69200,0.20706904,1.6991361,,,,,,,,,,,,,,,,, -69300,0.22274028,1.7657691,,,,,,,,,,,,,,,,, -69400,0.21759097,1.6183215,,,,,,,,,,,,,,,,, -69500,0.18130478,1.6602197,,,,,,,,,,,,,,,,, -69600,0.19577432,1.6721398,,,,,,,,,,,,,,,,, -69700,0.19294567,1.6884862,,,,,,,,,,,,,,,,, -69800,0.20639893,1.6559922,,,,,,,,,,,,,,,,, -69823,,,0.6646310091018677,1.5558394193649292,33.108476509492625,0.6774001717567444,1.4817792177200315,29.732924813863225,3000.0,0.6909418702125549,1.3975272178649902,29.22117196182941,3003.0,25234.14297890663,42412.92000794411,25234.14297890663,17175.61923766136,0.9281957149505616,0.0 -69900,0.18861665,1.6420432,,,,,,,,,,,,,,,,, -70000,0.18853134,1.6733129,,,,,,,,,,,,,,,,, -70100,0.20572087,1.6804128,,,,,,,,,,,,,,,,, -70200,0.20372662,1.6172719,,,,,,,,,,,,,,,,, -70300,0.20629592,1.6318477,,,,,,,,,,,,,,,,, -70400,0.20647767,1.6741111,,,,,,,,,,,,,,,,, -70500,0.26944873,1.6449955,,,,,,,,,,,,,,,,, -70600,0.18602128,1.6845312,,,,,,,,,,,,,,,,, -70700,0.21395938,1.6546724,,,,,,,,,,,,,,,,, -70800,0.25303715,1.6572992,,,,,,,,,,,,,,,,, -70900,0.18978335,1.6489687,,,,,,,,,,,,,,,,, -71000,0.21219996,1.6220928,,,,,,,,,,,,,,,,, -71100,0.21390824,1.6553346,,,,,,,,,,,,,,,,, -71200,0.19000772,1.6579834,,,,,,,,,,,,,,,,, -71300,0.21030958,1.6777409,,,,,,,,,,,,,,,,, -71400,0.19803962,1.6908456,,,,,,,,,,,,,,,,, -71500,0.19670184,1.6734848,,,,,,,,,,,,,,,,, -71600,0.26474282,1.6091354,,,,,,,,,,,,,,,,, -71700,0.23982373,1.7247171,,,,,,,,,,,,,,,,, -71800,0.19652845,1.6236533,,,,,,,,,,,,,,,,, -71900,0.21485487,1.6568239,,,,,,,,,,,,,,,,, -72000,0.20640129,1.600156,,,,,,,,,,,,,,,,, -72100,0.20312834,1.6285444,,,,,,,,,,,,,,,,, -72151,,,0.6619483828544617,1.584561824798584,32.684329463700614,0.676978588104248,1.4765617847442627,29.4450521112024,3000.0,0.6935332417488098,1.3811521530151367,29.335350984322613,3003.0,26074.134548187256,43794.25180768967,26074.134548187256,17716.850410461426,0.9643950462341307,0.0 -72200,0.18988097,1.6440222,,,,,,,,,,,,,,,,, -72300,0.21144418,1.624483,,,,,,,,,,,,,,,,, -72400,0.19376798,1.7047056,,,,,,,,,,,,,,,,, -72500,0.20084807,1.7139379,,,,,,,,,,,,,,,,, -72600,0.19318911,1.6302835,,,,,,,,,,,,,,,,, -72700,0.19042443,1.5841904,,,,,,,,,,,,,,,,, -72800,0.20746735,1.6993307,,,,,,,,,,,,,,,,, -72900,0.2527875,1.758402,,,,,,,,,,,,,,,,, -73000,0.20451279,1.6727315,,,,,,,,,,,,,,,,, -73100,0.20008755,1.6501433,,,,,,,,,,,,,,,,, -73200,0.19688849,1.6726687,,,,,,,,,,,,,,,,, -73300,0.19933161,1.6491473,,,,,,,,,,,,,,,,, -73400,0.2112923,1.6032153,,,,,,,,,,,,,,,,, -73500,0.22601089,1.703228,,,,,,,,,,,,,,,,, -73600,0.20124735,1.5405895,,,,,,,,,,,,,,,,, -73700,0.19818585,1.738798,,,,,,,,,,,,,,,,, -73800,0.21523482,1.6656264,,,,,,,,,,,,,,,,, -73900,0.1885761,1.6773826,,,,,,,,,,,,,,,,, -74000,0.21219333,1.7362487,,,,,,,,,,,,,,,,, -74100,0.20485778,1.699761,,,,,,,,,,,,,,,,, -74200,0.20931004,1.6116328,,,,,,,,,,,,,,,,, -74300,0.18453197,1.6794926,,,,,,,,,,,,,,,,, -74400,0.19860876,1.6255763,,,,,,,,,,,,,,,,, -74478,,,0.6594235301017761,1.5999988317489624,32.67356970090248,0.6777969002723694,1.4699158668518066,29.571455647456062,3000.0,0.69459068775177,1.3774783611297607,29.20670844457939,3003.0,26914.059263944622,45132.13513803482,26914.059263944622,18214.69597125053,1.0013093948364258,0.0 -74500,0.19477563,1.62764,,,,,,,,,,,,,,,,, -74600,0.19581768,1.6343693,,,,,,,,,,,,,,,,, -74700,0.2043488,1.6777313,,,,,,,,,,,,,,,,, -74800,0.22370163,1.68193,,,,,,,,,,,,,,,,, -74900,0.21095924,1.6473004,,,,,,,,,,,,,,,,, -75000,0.20689,1.6577753,,,,,,,,,,,,,,,,, -75100,0.19487518,1.6591582,,,,,,,,,,,,,,,,, -75200,0.25349095,1.6563396,,,,,,,,,,,,,,,,, -75300,0.19217545,1.6144869,,,,,,,,,,,,,,,,, -75400,0.19231094,1.7212814,,,,,,,,,,,,,,,,, -75500,0.21902798,1.6613594,,,,,,,,,,,,,,,,, -75600,0.19668935,1.5715392,,,,,,,,,,,,,,,,, -75700,0.20094688,1.673276,,,,,,,,,,,,,,,,, -75800,0.21800731,1.6236874,,,,,,,,,,,,,,,,, -75900,0.21309267,1.7573491,,,,,,,,,,,,,,,,, -76000,0.2008532,1.6017598,,,,,,,,,,,,,,,,, -76100,0.22191148,1.6321553,,,,,,,,,,,,,,,,, -76200,0.21621965,1.6648206,,,,,,,,,,,,,,,,, -76300,0.22910072,1.6827884,,,,,,,,,,,,,,,,, -76400,0.19308111,1.601577,,,,,,,,,,,,,,,,, -76500,0.19972332,1.6121516,,,,,,,,,,,,,,,,, -76600,0.20754655,1.6447229,,,,,,,,,,,,,,,,, -76700,0.20781896,1.6880939,,,,,,,,,,,,,,,,, -76800,0.2086599,1.5725553,,,,,,,,,,,,,,,,, -76806,,,0.6694381833076477,1.5369466543197632,33.0531892644505,0.6802147626876831,1.4623948335647583,29.47381621833204,3000.0,0.6945674419403076,1.3712241649627686,29.462568269990985,3003.0,27754.242109537125,46509.57319569588,27754.242109537125,18751.840750455856,1.0371296405792236,0.0 -76900,0.20370029,1.7356551,,,,,,,,,,,,,,,,, -77000,0.20404972,1.6338953,,,,,,,,,,,,,,,,, -77100,0.18531194,1.6373388,,,,,,,,,,,,,,,,, -77200,0.20311704,1.6459072,,,,,,,,,,,,,,,,, -77300,0.20960246,1.6643361,,,,,,,,,,,,,,,,, -77400,0.20507024,1.6409111,,,,,,,,,,,,,,,,, -77500,0.21144581,1.6251143,,,,,,,,,,,,,,,,, -77600,0.21145238,1.6502962,,,,,,,,,,,,,,,,, -77700,0.21035266,1.6816179,,,,,,,,,,,,,,,,, -77800,0.21020727,1.6235789,,,,,,,,,,,,,,,,, -77900,0.20535058,1.695101,,,,,,,,,,,,,,,,, -78000,0.22626297,1.6310048,,,,,,,,,,,,,,,,, -78100,0.1986451,1.6346983,,,,,,,,,,,,,,,,, -78200,0.22885102,1.7012932,,,,,,,,,,,,,,,,, -78300,0.19721915,1.6798834,,,,,,,,,,,,,,,,, -78400,0.20295002,1.6735976,,,,,,,,,,,,,,,,, -78500,0.21801512,1.6112268,,,,,,,,,,,,,,,,, -78600,0.19312403,1.5357343,,,,,,,,,,,,,,,,, -78700,0.20715578,1.6438696,,,,,,,,,,,,,,,,, -78800,0.21587372,1.6749992,,,,,,,,,,,,,,,,, -78900,0.20573102,1.6380321,,,,,,,,,,,,,,,,, -79000,0.28538427,1.63473,,,,,,,,,,,,,,,,, -79100,0.20345612,1.5713527,,,,,,,,,,,,,,,,, -79133,,,0.6636430025100708,1.570115566253662,32.54973127642245,0.6804007291793823,1.45924973487854,29.552745828566785,3000.0,0.6958921551704407,1.3631484508514404,29.431835227142336,3003.0,28594.37763762474,47918.92470264435,28594.37763762474,19320.94375896454,1.0745155811309814,0.0 -79200,0.22834,1.6724588,,,,,,,,,,,,,,,,, -79300,0.19803636,1.6234484,,,,,,,,,,,,,,,,, -79400,0.20783006,1.6275822,,,,,,,,,,,,,,,,, -79500,0.21334535,1.5952644,,,,,,,,,,,,,,,,, -79600,0.21360612,1.684355,,,,,,,,,,,,,,,,, -79700,0.2080651,1.6373097,,,,,,,,,,,,,,,,, -79800,0.22074871,1.6351777,,,,,,,,,,,,,,,,, -79900,0.21005893,1.6526923,,,,,,,,,,,,,,,,, -80000,0.20146936,1.6768698,,,,,,,,,,,,,,,,, -80100,0.21135981,1.6065484,,,,,,,,,,,,,,,,, -80200,0.21971188,1.6702063,,,,,,,,,,,,,,,,, -80300,0.19947013,1.6797185,,,,,,,,,,,,,,,,, -80400,2.5033486,1.5966997,,,,,,,,,,,,,,,,, -80500,0.2037861,1.7569474,,,,,,,,,,,,,,,,, -80600,0.21913183,1.6083344,,,,,,,,,,,,,,,,, -80700,0.20267119,1.6441935,,,,,,,,,,,,,,,,, -80800,0.21272363,1.6165551,,,,,,,,,,,,,,,,, -80900,0.207656,1.5451435,,,,,,,,,,,,,,,,, -81000,0.19394171,1.6988504,,,,,,,,,,,,,,,,, -81100,0.20509815,1.6566696,,,,,,,,,,,,,,,,, -81200,0.19885291,1.6670003,,,,,,,,,,,,,,,,, -81300,0.24113385,1.640716,,,,,,,,,,,,,,,,, -81400,0.19951864,1.624534,,,,,,,,,,,,,,,,, -81461,,,0.6872650384902954,1.421802282333374,34.87201890341066,0.6823846101760864,1.4490602016448977,30.02667336732894,3000.0,0.6979489922523499,1.3512593507766724,29.82934247972239,3003.0,29434.364958763123,49268.42466902733,29434.364958763123,19830.34655547142,1.111635684967041,0.0 -81500,0.21828046,1.6560805,,,,,,,,,,,,,,,,, -81600,0.21842799,1.6059343,,,,,,,,,,,,,,,,, -81700,0.20981097,1.6172111,,,,,,,,,,,,,,,,, -81800,0.20852962,1.6110513,,,,,,,,,,,,,,,,, -81900,0.19055231,1.5920057,,,,,,,,,,,,,,,,, -82000,0.21143137,1.6840869,,,,,,,,,,,,,,,,, -82100,0.21898817,1.5409048,,,,,,,,,,,,,,,,, -82200,0.18949072,1.6274723,,,,,,,,,,,,,,,,, -82300,0.22785652,1.5964786,,,,,,,,,,,,,,,,, -82400,0.21495357,1.6533041,,,,,,,,,,,,,,,,, -82500,0.2048758,1.6367147,,,,,,,,,,,,,,,,, -82600,0.22735545,1.7119205,,,,,,,,,,,,,,,,, -82700,0.19973093,1.6428024,,,,,,,,,,,,,,,,, -82800,0.21187557,1.6935382,,,,,,,,,,,,,,,,, -82900,0.20777023,1.6130856,,,,,,,,,,,,,,,,, -83000,0.24068561,1.6435502,,,,,,,,,,,,,,,,, -83100,0.20897311,1.6084968,,,,,,,,,,,,,,,,, -83200,0.20963505,1.6268014,,,,,,,,,,,,,,,,, -83300,0.20859206,1.5951204,,,,,,,,,,,,,,,,, -83400,0.20272468,1.6144947,,,,,,,,,,,,,,,,, -83500,0.19817336,1.5485241,,,,,,,,,,,,,,,,, -83600,0.2199746,1.5951709,,,,,,,,,,,,,,,,, -83700,0.20439374,1.7076248,,,,,,,,,,,,,,,,, -83788,,,0.6727224588394165,1.5115309953689575,33.53385947656073,0.6825209856033325,1.441343903541565,29.710900837467445,3000.0,0.6982162594795227,1.348582744598389,29.646421542412785,3003.0,30274.487243413925,50868.4852745533,30274.487243413925,20590.173165798187,1.1480538845062256,0.0 -83800,0.20192882,1.621514,,,,,,,,,,,,,,,,, -83900,0.21265905,1.6291354,,,,,,,,,,,,,,,,, -84000,0.21400271,1.6584563,,,,,,,,,,,,,,,,, -84100,0.18465123,1.6345009,,,,,,,,,,,,,,,,, -84200,0.23395084,1.6536028,,,,,,,,,,,,,,,,, -84300,0.20358656,1.5502318,,,,,,,,,,,,,,,,, -84400,0.20073925,1.6329547,,,,,,,,,,,,,,,,, -84500,0.2024757,1.5823741,,,,,,,,,,,,,,,,, -84600,0.23100226,1.5911304,,,,,,,,,,,,,,,,, -84700,0.19824228,1.6266338,,,,,,,,,,,,,,,,, -84800,0.2114836,1.6182706,,,,,,,,,,,,,,,,, -84900,0.19579305,1.651436,,,,,,,,,,,,,,,,, -85000,0.2416704,1.5573089,,,,,,,,,,,,,,,,, -85100,0.21561646,1.594116,,,,,,,,,,,,,,,,, -85200,0.20891103,1.6121411,,,,,,,,,,,,,,,,, -85300,0.20117271,1.5512031,,,,,,,,,,,,,,,,, -85400,0.20596214,1.6756988,,,,,,,,,,,,,,,,, -85500,0.2264175,1.5677094,,,,,,,,,,,,,,,,, -85600,0.2092522,1.5858885,,,,,,,,,,,,,,,,, -85700,0.22923654,1.670185,,,,,,,,,,,,,,,,, -85800,0.20951973,1.5877292,,,,,,,,,,,,,,,,, -85900,0.20557274,1.5614008,,,,,,,,,,,,,,,,, -86000,0.203047,1.6126151,,,,,,,,,,,,,,,,, -86100,0.20383969,1.5439,,,,,,,,,,,,,,,,, -86116,,,0.6685099005699158,1.5362203121185305,33.509164043830715,0.6834757328033447,1.4383317232131958,30.021007104354624,3000.0,0.7002963423728943,1.3392986059188845,30.026394171826187,3003.0,31114.39369344712,52270.26392364502,31114.39369344712,21151.93498682976,1.1864585876464844,0.0 -86200,0.20122996,1.5836027,,,,,,,,,,,,,,,,, -86300,0.20472848,1.5943334,,,,,,,,,,,,,,,,, -86400,0.19889463,1.6207937,,,,,,,,,,,,,,,,, -86500,0.20966667,1.6572567,,,,,,,,,,,,,,,,, -86600,0.21157038,1.6427693,,,,,,,,,,,,,,,,, -86700,0.20478654,1.5935442,,,,,,,,,,,,,,,,, -86800,0.20749721,1.5685991,,,,,,,,,,,,,,,,, -86900,0.20260215,1.6167706,,,,,,,,,,,,,,,,, -87000,0.19604601,1.5129822,,,,,,,,,,,,,,,,, -87100,0.21649586,1.4874387,,,,,,,,,,,,,,,,, -87200,0.20390442,1.5748405,,,,,,,,,,,,,,,,, -87300,0.20985086,1.5518967,,,,,,,,,,,,,,,,, -87400,0.2118555,1.6617578,,,,,,,,,,,,,,,,, -87500,0.20482437,1.5353589,,,,,,,,,,,,,,,,, -87600,0.20947397,1.5625194,,,,,,,,,,,,,,,,, -87700,0.20463893,1.5565556,,,,,,,,,,,,,,,,, -87800,0.21877702,1.6169415,,,,,,,,,,,,,,,,, -87900,0.2068597,1.5834959,,,,,,,,,,,,,,,,, -88000,0.21230818,1.563703,,,,,,,,,,,,,,,,, -88100,0.23691869,1.574864,,,,,,,,,,,,,,,,, -88200,0.21765545,1.6649086,,,,,,,,,,,,,,,,, -88300,0.20860814,1.6787508,,,,,,,,,,,,,,,,, -88400,0.2037219,1.6018072,,,,,,,,,,,,,,,,, -88444,,,0.6788003444671631,1.4714220762252808,33.85755785036112,0.6860795021057129,1.428539156913757,30.326511463742506,3000.0,0.7007378935813904,1.3346697092056274,30.07429817068781,3003.0,31954.47871041298,53697.33389949799,31954.47871041298,21738.805801153183,1.2259869575500488,0.0 -88500,0.21401043,1.5550245,,,,,,,,,,,,,,,,, -88600,0.22458297,1.6094302,,,,,,,,,,,,,,,,, -88700,0.22973497,1.6257488,,,,,,,,,,,,,,,,, -88800,0.2351017,1.6351303,,,,,,,,,,,,,,,,, -88900,0.20962545,1.604171,,,,,,,,,,,,,,,,, -89000,0.24964032,1.610918,,,,,,,,,,,,,,,,, -89100,0.19893923,1.5929921,,,,,,,,,,,,,,,,, -89200,0.20777914,1.6093781,,,,,,,,,,,,,,,,, -89300,0.20505509,1.5433148,,,,,,,,,,,,,,,,, -89400,0.20924398,1.6675394,,,,,,,,,,,,,,,,, -89500,0.20790096,1.5527285,,,,,,,,,,,,,,,,, -89600,0.20684907,1.5195874,,,,,,,,,,,,,,,,, -89700,0.20534666,1.54792,,,,,,,,,,,,,,,,, -89800,0.20835005,1.5265385,,,,,,,,,,,,,,,,, -89900,0.20657787,1.530672,,,,,,,,,,,,,,,,, -90000,0.20590466,1.5373518,,,,,,,,,,,,,,,,, -90100,0.2118289,1.6132293,,,,,,,,,,,,,,,,, -90200,0.21797663,1.6250936,,,,,,,,,,,,,,,,, -90300,0.2206593,1.5738317,,,,,,,,,,,,,,,,, -90400,0.22131574,1.5386294,,,,,,,,,,,,,,,,, -90500,0.24730797,1.5926661,,,,,,,,,,,,,,,,, -90600,0.22433256,1.6037589,,,,,,,,,,,,,,,,, -90700,0.2065223,1.6362425,,,,,,,,,,,,,,,,, -90772,,,0.6721503138542175,1.5121186971664429,33.66060615621202,0.6863027215003967,1.4202866554260254,30.17695861596039,3000.0,0.7035732865333557,1.3192323446273804,30.3331563210614,3003.0,32794.429097890854,55212.970027923584,32794.429097890854,22414.37882566452,1.265702247619629,0.0 -90800,0.21689892,1.6680171,,,,,,,,,,,,,,,,, -90900,0.21080242,1.5893413,,,,,,,,,,,,,,,,, -91000,0.21713471,1.5463077,,,,,,,,,,,,,,,,, -91100,0.2286908,1.5822341,,,,,,,,,,,,,,,,, -91200,0.21402599,1.5225753,,,,,,,,,,,,,,,,, -91300,0.21017282,1.5106071,,,,,,,,,,,,,,,,, -91400,0.21419759,1.6114124,,,,,,,,,,,,,,,,, -91500,0.19818656,1.5805951,,,,,,,,,,,,,,,,, -91600,0.21245624,1.5812624,,,,,,,,,,,,,,,,, -91700,0.22480422,1.5367377,,,,,,,,,,,,,,,,, -91800,0.20610206,1.5357333,,,,,,,,,,,,,,,,, -91900,0.19824931,1.5808511,,,,,,,,,,,,,,,,, -92000,0.22629656,1.626631,,,,,,,,,,,,,,,,, -92100,0.21310702,1.6529104,,,,,,,,,,,,,,,,, -92200,0.20602866,1.5308377,,,,,,,,,,,,,,,,, -92300,0.21053801,1.5589331,,,,,,,,,,,,,,,,, -92400,0.2092063,1.5313987,,,,,,,,,,,,,,,,, -92500,0.2205467,1.632283,,,,,,,,,,,,,,,,, -92600,0.20757735,1.6841807,,,,,,,,,,,,,,,,, -92700,0.20971099,1.5321603,,,,,,,,,,,,,,,,, -92800,0.23373662,1.5913013,,,,,,,,,,,,,,,,, -92900,0.21373288,1.5600455,,,,,,,,,,,,,,,,, -93000,0.22117166,1.5332737,,,,,,,,,,,,,,,,, -93100,,,0.6721413135528564,1.5171512365341189,33.86056555520624,0.6868358850479126,1.417452335357666,30.403902679657968,3000.0,0.7037127614021301,1.315152645111084,30.392667213925773,3003.0,33634.516040802,56644.627078294754,33634.516040802,23005.836725473404,1.3031470775604248,0.0 -93100,0.20675698,1.5039774,,,,,,,,,,,,,,,,, -93200,0.21484713,1.5975541,,,,,,,,,,,,,,,,, -93300,0.20935528,1.6039801,,,,,,,,,,,,,,,,, -93400,0.20670456,1.6111473,,,,,,,,,,,,,,,,, -93500,0.1954499,1.5902948,,,,,,,,,,,,,,,,, -93600,0.20573096,1.5797076,,,,,,,,,,,,,,,,, -93700,0.22609873,1.6621488,,,,,,,,,,,,,,,,, -93800,0.238235,1.6901366,,,,,,,,,,,,,,,,, -93900,0.20874013,1.5374055,,,,,,,,,,,,,,,,, -94000,0.20276079,1.5869508,,,,,,,,,,,,,,,,, -94100,0.22656614,1.62397,,,,,,,,,,,,,,,,, -94200,0.20576277,1.5410671,,,,,,,,,,,,,,,,, -94300,0.21411741,1.5308714,,,,,,,,,,,,,,,,, -94400,0.22015415,1.6385065,,,,,,,,,,,,,,,,, -94500,0.24930845,1.5685247,,,,,,,,,,,,,,,,, -94600,0.20735319,1.56004,,,,,,,,,,,,,,,,, -94700,0.20877033,1.628413,,,,,,,,,,,,,,,,, -94800,0.23426855,1.6588291,,,,,,,,,,,,,,,,, -94900,0.21269706,1.6143143,,,,,,,,,,,,,,,,, -95000,0.19713429,1.5607826,,,,,,,,,,,,,,,,, -95100,0.20835225,1.461334,,,,,,,,,,,,,,,,, -95200,0.2274768,1.6416454,,,,,,,,,,,,,,,,, -95300,0.21882546,1.5903873,,,,,,,,,,,,,,,,, -95400,0.21255201,1.5816505,,,,,,,,,,,,,,,,, -95428,,,0.6813426613807678,1.451947808265686,34.027032781976644,0.6877037882804871,1.4141026735305786,30.14882976469225,3000.0,0.7046656608581543,1.3117072582244873,30.250803153996863,3003.0,34474.600940704346,57972.71556472778,34474.600940704346,23493.724118709564,1.34435772895813,0.0 -95500,0.20295854,1.5819849,,,,,,,,,,,,,,,,, -95600,0.20262007,1.506334,,,,,,,,,,,,,,,,, -95700,0.21617605,1.566701,,,,,,,,,,,,,,,,, -95800,0.21581152,1.5656544,,,,,,,,,,,,,,,,, -95900,0.21253641,1.5569649,,,,,,,,,,,,,,,,, -96000,0.22048478,1.624944,,,,,,,,,,,,,,,,, -96100,0.22624378,1.5409817,,,,,,,,,,,,,,,,, -96200,0.22240138,1.581563,,,,,,,,,,,,,,,,, -96300,0.20744343,1.5316991,,,,,,,,,,,,,,,,, -96400,0.23026128,1.6365255,,,,,,,,,,,,,,,,, -96500,0.21765544,1.5005968,,,,,,,,,,,,,,,,, -96600,0.23034135,1.5811688,,,,,,,,,,,,,,,,, -96700,0.2168646,1.6380386,,,,,,,,,,,,,,,,, -96800,0.22079487,1.579305,,,,,,,,,,,,,,,,, -96900,0.22202253,1.5605942,,,,,,,,,,,,,,,,, -97000,0.21193543,1.6222036,,,,,,,,,,,,,,,,, -97100,0.20391491,1.5092986,,,,,,,,,,,,,,,,, -97200,0.21824601,1.6165375,,,,,,,,,,,,,,,,, -97300,0.23401247,1.5994314,,,,,,,,,,,,,,,,, -97400,0.23677328,1.5091432,,,,,,,,,,,,,,,,, -97500,0.20717525,1.4982345,,,,,,,,,,,,,,,,, -97600,0.20789759,1.5287241,,,,,,,,,,,,,,,,, -97700,0.22738816,1.6329244,,,,,,,,,,,,,,,,, -97755,,,0.6785141825675964,1.4695278406143188,34.25009675294788,0.6903820037841797,1.4033619165420532,30.519843494061867,3000.0,0.7072570323944092,1.2985440492630005,30.597164472048394,3003.0,35314.54950070381,59400.434348106384,35314.54950070381,24081.37470555305,1.387169599533081,0.0 -97800,0.22144319,1.5192069,,,,,,,,,,,,,,,,, -97900,0.2147031,1.4920558,,,,,,,,,,,,,,,,, -98000,0.21753713,1.6014513,,,,,,,,,,,,,,,,, -98100,0.241145,1.5932266,,,,,,,,,,,,,,,,, -98200,0.2132888,1.4919853,,,,,,,,,,,,,,,,, -98300,0.2334675,1.6096023,,,,,,,,,,,,,,,,, -98400,0.22414105,1.612018,,,,,,,,,,,,,,,,, -98500,0.22720326,1.55475,,,,,,,,,,,,,,,,, -98600,0.21580918,1.4560797,,,,,,,,,,,,,,,,, -98700,0.23058134,1.5193006,,,,,,,,,,,,,,,,, -98800,0.21401703,1.6066221,,,,,,,,,,,,,,,,, -98900,0.22103624,1.5788541,,,,,,,,,,,,,,,,, -99000,0.21266657,1.5486462,,,,,,,,,,,,,,,,, -99100,0.21086648,1.5388997,,,,,,,,,,,,,,,,, -99200,0.2087203,1.5091966,,,,,,,,,,,,,,,,, -99300,0.21036047,1.5111952,,,,,,,,,,,,,,,,, -99400,0.20657918,1.5721579,,,,,,,,,,,,,,,,, -99500,0.22281805,1.5232279,,,,,,,,,,,,,,,,, -99600,0.21897168,1.5305597,,,,,,,,,,,,,,,,, -99700,0.21390803,1.5638056,,,,,,,,,,,,,,,,, -99800,0.22465348,1.5085661,,,,,,,,,,,,,,,,, -99900,0.22953711,1.5162964,,,,,,,,,,,,,,,,, -100000,0.22675993,1.5489722,,,,,,,,,,,,,,,,, -100083,,,0.6936557292938232,1.3875911235809326,35.65112004922339,0.6906548142433167,1.400636911392212,30.85509335453766,3000.0,0.706187903881073,1.299963116645813,30.34443331307036,3003.0,36154.70352482796,60778.641570568085,36154.70352482796,24619.3114862442,1.4285976886749268,0.0 -100100,0.23511206,1.5735186,,,,,,,,,,,,,,,,, -100200,0.20491025,1.5391271,,,,,,,,,,,,,,,,, -100300,0.2105485,1.4618075,,,,,,,,,,,,,,,,, -100400,0.20985672,1.6040167,,,,,,,,,,,,,,,,, -100500,0.23032297,1.5736713,,,,,,,,,,,,,,,,, -100600,0.21478452,1.5439227,,,,,,,,,,,,,,,,, -100700,0.2117618,1.4815619,,,,,,,,,,,,,,,,, -100800,0.22505726,1.4718537,,,,,,,,,,,,,,,,, -100900,0.28814363,1.624716,,,,,,,,,,,,,,,,, -101000,0.22520295,1.5442848,,,,,,,,,,,,,,,,, -101100,0.20533888,1.4960345,,,,,,,,,,,,,,,,, -101200,0.23015139,1.5402024,,,,,,,,,,,,,,,,, -101300,0.22432536,1.5112197,,,,,,,,,,,,,,,,, -101400,0.22974484,1.5312572,,,,,,,,,,,,,,,,, -101500,0.22739533,1.5851191,,,,,,,,,,,,,,,,, -101600,0.22409861,1.6190284,,,,,,,,,,,,,,,,, -101700,0.25467858,1.5982518,,,,,,,,,,,,,,,,, -101800,0.23355182,1.6036509,,,,,,,,,,,,,,,,, -101900,0.22702532,1.5952014,,,,,,,,,,,,,,,,, -102000,0.22403842,1.5727714,,,,,,,,,,,,,,,,, -102100,0.23135385,1.4538051,,,,,,,,,,,,,,,,, -102200,0.22629125,1.5456363,,,,,,,,,,,,,,,,, -102300,0.21706142,1.5368832,,,,,,,,,,,,,,,,, -102400,0.2260945,1.5300971,,,,,,,,,,,,,,,,, -102411,,,0.6837599873542786,1.4399648904800415,34.42078439556147,0.691696286201477,1.3963810205459597,30.65930224150146,3000.0,0.7067921757698059,1.296403884887695,30.49037124141108,3003.0,36994.59205460549,62245.28328108788,36994.59205460549,25245.94492340088,1.4772801399230957,0.0 -102500,0.21818879,1.5249822,,,,,,,,,,,,,,,,, -102600,0.22916754,1.5686015,,,,,,,,,,,,,,,,, -102700,0.23092227,1.6336987,,,,,,,,,,,,,,,,, -102800,0.21273583,1.5096345,,,,,,,,,,,,,,,,, -102900,0.22165212,1.5074546,,,,,,,,,,,,,,,,, -103000,0.22516131,1.509155,,,,,,,,,,,,,,,,, -103100,0.22816321,1.5334783,,,,,,,,,,,,,,,,, -103200,0.2255388,1.5222389,,,,,,,,,,,,,,,,, -103300,0.22066902,1.5333214,,,,,,,,,,,,,,,,, -103400,0.22554316,1.5353557,,,,,,,,,,,,,,,,, -103500,0.23170106,1.4695297,,,,,,,,,,,,,,,,, -103600,0.23504516,1.4872969,,,,,,,,,,,,,,,,, -103700,0.22327663,1.5525863,,,,,,,,,,,,,,,,, -103800,0.22358637,1.5118896,,,,,,,,,,,,,,,,, -103900,0.22235192,1.4815823,,,,,,,,,,,,,,,,, -104000,0.22168866,1.5060576,,,,,,,,,,,,,,,,, -104100,0.22970033,1.5128659,,,,,,,,,,,,,,,,, -104200,0.23123656,1.5090746,,,,,,,,,,,,,,,,, -104300,0.22850214,1.5099442,,,,,,,,,,,,,,,,, -104400,0.23884109,1.530866,,,,,,,,,,,,,,,,, -104500,0.2301809,1.5621861,,,,,,,,,,,,,,,,, -104600,0.22100306,1.5438617,,,,,,,,,,,,,,,,, -104700,0.2289161,1.5181286,,,,,,,,,,,,,,,,, -104739,,,0.6803861260414124,1.4522088766098022,34.19161231852919,0.6913491487503052,1.3914282321929932,30.690073446109658,3000.0,0.7083609700202942,1.2872958183288574,30.74846146018644,3003.0,37834.66395068169,63622.94854474068,37834.66395068169,25783.427124261856,1.5160844326019287,0.0 -104800,0.22195122,1.4764445,,,,,,,,,,,,,,,,, -104900,0.23394321,1.5401177,,,,,,,,,,,,,,,,, -105000,0.22083423,1.5051488,,,,,,,,,,,,,,,,, -105100,0.21193615,1.4653968,,,,,,,,,,,,,,,,, -105200,0.22116934,1.5710367,,,,,,,,,,,,,,,,, -105300,0.22088563,1.5244837,,,,,,,,,,,,,,,,, -105400,0.22846182,1.5156202,,,,,,,,,,,,,,,,, -105500,0.23284037,1.4619156,,,,,,,,,,,,,,,,, -105600,0.24035902,1.4894072,,,,,,,,,,,,,,,,, -105700,0.22401458,1.5056105,,,,,,,,,,,,,,,,, -105800,0.22631645,1.5554396,,,,,,,,,,,,,,,,, -105900,0.23454528,1.5665267,,,,,,,,,,,,,,,,, -106000,0.21246155,1.5212101,,,,,,,,,,,,,,,,, -106100,0.22630288,1.5235016,,,,,,,,,,,,,,,,, -106200,0.24370326,1.4069067,,,,,,,,,,,,,,,,, -106300,0.2290271,1.501154,,,,,,,,,,,,,,,,, -106400,0.23109163,1.4751428,,,,,,,,,,,,,,,,, -106500,0.22929391,1.554266,,,,,,,,,,,,,,,,, -106600,0.24444383,1.5872295,,,,,,,,,,,,,,,,, -106700,0.21998708,1.4826124,,,,,,,,,,,,,,,,, -106800,0.227388,1.5268109,,,,,,,,,,,,,,,,, -106900,0.21177338,1.4655201,,,,,,,,,,,,,,,,, -107000,0.22519533,1.52802,,,,,,,,,,,,,,,,, -107068,,,0.691776692867279,1.3914330005645752,34.81961982311779,0.6922914981842041,1.3859933614730835,30.58321900927457,3000.0,0.708570122718811,1.2839752435684204,30.67077610265788,3003.0,38674.80152177811,65018.60505485535,38674.80152177811,26338.833918571472,1.5553884506225586,0.0 -107100,0.2329591,1.527521,,,,,,,,,,,,,,,,, -107200,0.23157847,1.5073178,,,,,,,,,,,,,,,,, -107300,0.23061553,1.5230706,,,,,,,,,,,,,,,,, -107400,0.23217003,1.5930386,,,,,,,,,,,,,,,,, -107500,0.2402208,1.4731336,,,,,,,,,,,,,,,,, -107600,0.23630875,1.4485757,,,,,,,,,,,,,,,,, -107700,0.23352814,1.5317659,,,,,,,,,,,,,,,,, -107800,0.23628221,1.5840205,,,,,,,,,,,,,,,,, -107900,0.23121382,1.5138296,,,,,,,,,,,,,,,,, -108000,0.22358859,1.5237288,,,,,,,,,,,,,,,,, -108100,0.22040881,1.4874848,,,,,,,,,,,,,,,,, -108200,0.22785102,1.5323924,,,,,,,,,,,,,,,,, -108300,0.23385273,1.565137,,,,,,,,,,,,,,,,, -108400,0.22941037,1.5512387,,,,,,,,,,,,,,,,, -108500,0.2382393,1.5017264,,,,,,,,,,,,,,,,, -108600,0.23600744,1.5587717,,,,,,,,,,,,,,,,, -108700,0.21755147,1.4934126,,,,,,,,,,,,,,,,, -108800,0.22773586,1.520055,,,,,,,,,,,,,,,,, -108900,0.2247769,1.4576355,,,,,,,,,,,,,,,,, -109000,0.2498683,1.4972147,,,,,,,,,,,,,,,,, -109100,0.22271265,1.5087618,,,,,,,,,,,,,,,,, -109200,0.22275166,1.4877219,,,,,,,,,,,,,,,,, -109300,0.24478264,1.52362,,,,,,,,,,,,,,,,, -109397,,,0.6863054037094116,1.422512769699097,34.77338731179625,0.692328691482544,1.3812284469604492,30.9349353146716,3000.0,0.7111498713493347,1.274350881576538,31.09661239196818,3003.0,39515.00077295303,66434.0912425518,39515.00077295303,26914.005674123764,1.5969746112823486,0.0 -109400,0.2396723,1.5610975,,,,,,,,,,,,,,,,, -109500,0.23638588,1.4845011,,,,,,,,,,,,,,,,, -109600,0.21623984,1.4942319,,,,,,,,,,,,,,,,, -109700,0.23492463,1.540625,,,,,,,,,,,,,,,,, -109800,0.22515665,1.4971613,,,,,,,,,,,,,,,,, -109900,0.21611112,1.5463425,,,,,,,,,,,,,,,,, -110000,0.22091228,1.5072881,,,,,,,,,,,,,,,,, -110100,0.22667211,1.509083,,,,,,,,,,,,,,,,, -110200,0.22643536,1.5241591,,,,,,,,,,,,,,,,, -110300,0.22182567,1.5510261,,,,,,,,,,,,,,,,, -110400,0.23500583,1.4842898,,,,,,,,,,,,,,,,, -110500,0.23149842,1.497777,,,,,,,,,,,,,,,,, -110600,0.23939645,1.4986526,,,,,,,,,,,,,,,,, -110700,0.24290974,1.4902289,,,,,,,,,,,,,,,,, -110800,0.23205587,1.4979192,,,,,,,,,,,,,,,,, -110900,0.22782402,1.5879444,,,,,,,,,,,,,,,,, -111000,0.22070687,1.494947,,,,,,,,,,,,,,,,, -111100,0.2199675,1.483411,,,,,,,,,,,,,,,,, -111200,0.23128822,1.446215,,,,,,,,,,,,,,,,, -111300,0.2335161,1.4947999,,,,,,,,,,,,,,,,, -111400,0.23531932,1.4981303,,,,,,,,,,,,,,,,, -111500,0.2306406,1.5214243,,,,,,,,,,,,,,,,, -111600,0.23376293,1.5084815,,,,,,,,,,,,,,,,, -111700,0.22996064,1.4668639,,,,,,,,,,,,,,,,, -111725,,,0.6838766932487488,1.4372414350509644,34.7839582194981,0.6921178698539734,1.376046061515808,30.739287863649857,3000.0,0.7117308974266052,1.268806219100952,31.12296226341149,3003.0,40354.97988796234,67818.93628644943,40354.97988796234,27458.749623537064,1.6440293788909912,0.0 -111800,0.22065566,1.4210789,,,,,,,,,,,,,,,,, -111900,0.23179899,1.4804991,,,,,,,,,,,,,,,,, -112000,0.24751005,1.4624653,,,,,,,,,,,,,,,,, -112100,0.21980059,1.496451,,,,,,,,,,,,,,,,, -112200,0.22910239,1.4397132,,,,,,,,,,,,,,,,, -112300,0.23688982,1.4990505,,,,,,,,,,,,,,,,, -112400,0.22907047,1.4935616,,,,,,,,,,,,,,,,, -112500,0.23425676,1.4668413,,,,,,,,,,,,,,,,, -112600,0.2385299,1.4922438,,,,,,,,,,,,,,,,, -112700,0.23876661,1.5287994,,,,,,,,,,,,,,,,, -112800,0.23697484,1.5433071,,,,,,,,,,,,,,,,, -112900,0.2357098,1.44262,,,,,,,,,,,,,,,,, -113000,0.23434998,1.522132,,,,,,,,,,,,,,,,, -113100,0.25667855,1.5462303,,,,,,,,,,,,,,,,, -113200,0.2551429,1.5358938,,,,,,,,,,,,,,,,, -113300,0.24139176,1.5423301,,,,,,,,,,,,,,,,, -113400,0.23347278,1.501748,,,,,,,,,,,,,,,,, -113500,0.21942958,1.4095964,,,,,,,,,,,,,,,,, -113600,0.24472515,1.4641837,,,,,,,,,,,,,,,,, -113700,0.2362114,1.4494182,,,,,,,,,,,,,,,,, -113800,0.2290194,1.4620624,,,,,,,,,,,,,,,,, -113900,0.23424329,1.5018045,,,,,,,,,,,,,,,,, -114000,0.23007025,1.4157977,,,,,,,,,,,,,,,,, -114053,,,0.6951395273208618,1.3722827434539795,35.48070667530032,0.6938289403915405,1.3752260208129885,30.685051686956182,3000.0,0.7118470668792725,1.2679171562194824,31.035798009097697,3003.0,41195.22360944748,69239.3367304802,41195.22360944748,28038.789351701736,1.6860826015472412,0.0 -114100,0.23116492,1.4838678,,,,,,,,,,,,,,,,, -114200,0.22463234,1.4136561,,,,,,,,,,,,,,,,, -114300,0.23311803,1.4004226,,,,,,,,,,,,,,,,, -114400,0.22239311,1.4624785,,,,,,,,,,,,,,,,, -114500,0.22642897,1.4855353,,,,,,,,,,,,,,,,, -114600,0.2499393,1.542905,,,,,,,,,,,,,,,,, -114700,0.22700304,1.4824775,,,,,,,,,,,,,,,,, -114800,0.23683704,1.4578093,,,,,,,,,,,,,,,,, -114900,0.22489502,1.5623134,,,,,,,,,,,,,,,,, -115000,0.23688734,1.4592733,,,,,,,,,,,,,,,,, -115100,0.24121608,1.4628073,,,,,,,,,,,,,,,,, -115200,0.24161954,1.5268289,,,,,,,,,,,,,,,,, -115300,0.23435654,1.4720341,,,,,,,,,,,,,,,,, -115400,0.23053493,1.5044179,,,,,,,,,,,,,,,,, -115500,0.23614682,1.4538354,,,,,,,,,,,,,,,,, -115600,0.23162721,1.4747461,,,,,,,,,,,,,,,,, -115700,0.2298565,1.4116187,,,,,,,,,,,,,,,,, -115800,0.22967021,1.512448,,,,,,,,,,,,,,,,, -115900,0.2226428,1.4732782,,,,,,,,,,,,,,,,, -116000,0.23334266,1.4753782,,,,,,,,,,,,,,,,, -116100,0.23565169,1.4661208,,,,,,,,,,,,,,,,, -116200,0.23826478,1.5392747,,,,,,,,,,,,,,,,, -116300,0.23782592,1.3753127,,,,,,,,,,,,,,,,, -116381,,,0.6891108155250549,1.4094853401184082,35.24992178839545,0.6954160332679749,1.370386242866516,30.92770434969378,3000.0,0.713706374168396,1.2614657878875732,31.34912036522872,3003.0,42035.293724536896,70656.65222883224,42035.293724536896,28615.922476530075,1.726020097732544,0.0 -116400,0.22645487,1.4736246,,,,,,,,,,,,,,,,, -116500,0.23188677,1.4328521,,,,,,,,,,,,,,,,, -116600,0.2205405,1.4182433,,,,,,,,,,,,,,,,, -116700,0.25060263,1.511299,,,,,,,,,,,,,,,,, -116800,0.2530103,1.5185714,,,,,,,,,,,,,,,,, -116900,0.23552631,1.5062375,,,,,,,,,,,,,,,,, -117000,0.22757235,1.406748,,,,,,,,,,,,,,,,, -117100,0.23322369,1.4832473,,,,,,,,,,,,,,,,, -117200,0.2451887,1.4788029,,,,,,,,,,,,,,,,, -117300,0.24630485,1.5368541,,,,,,,,,,,,,,,,, -117400,0.23689917,1.4907477,,,,,,,,,,,,,,,,, -117500,0.22672059,1.446768,,,,,,,,,,,,,,,,, -117600,0.23380598,1.4529394,,,,,,,,,,,,,,,,, -117700,0.23478605,1.4459448,,,,,,,,,,,,,,,,, -117800,0.2404794,1.4484521,,,,,,,,,,,,,,,,, -117900,0.24570446,1.4318827,,,,,,,,,,,,,,,,, -118000,0.21964154,1.3568898,,,,,,,,,,,,,,,,, -118100,0.2295275,1.4030025,,,,,,,,,,,,,,,,, -118200,0.23422599,1.4960049,,,,,,,,,,,,,,,,, -118300,0.23565663,1.4716077,,,,,,,,,,,,,,,,, -118400,0.23194437,1.447225,,,,,,,,,,,,,,,,, -118500,0.24464437,1.577408,,,,,,,,,,,,,,,,, -118600,0.24093075,1.615765,,,,,,,,,,,,,,,,, -118700,0.23073097,1.5291001,,,,,,,,,,,,,,,,, -118709,,,0.6923836469650269,1.3908350467681885,34.80738047912399,0.6954408288002014,1.3662816286087036,31.010793828916093,3000.0,0.7125559449195862,1.260074496269226,31.390352867613373,3003.0,42875.2884888649,72050.43827652931,42875.2884888649,29169.59684228897,1.7702577114105225,0.0 -118800,0.23000456,1.3871716,,,,,,,,,,,,,,,,, -118900,0.23593876,1.4897028,,,,,,,,,,,,,,,,, -119000,0.23196076,1.4465445,,,,,,,,,,,,,,,,, -119100,0.22857212,1.4107283,,,,,,,,,,,,,,,,, -119200,0.23596053,1.4884001,,,,,,,,,,,,,,,,, -119300,0.23440307,1.449095,,,,,,,,,,,,,,,,, -119400,0.24414456,1.4749998,,,,,,,,,,,,,,,,, -119500,0.24296865,1.4655894,,,,,,,,,,,,,,,,, -119600,0.23894681,1.4672958,,,,,,,,,,,,,,,,, -119700,0.23013939,1.388907,,,,,,,,,,,,,,,,, -119800,0.22058438,1.4775069,,,,,,,,,,,,,,,,, -119900,0.2388342,1.4693669,,,,,,,,,,,,,,,,, -120000,0.23974624,1.4601173,,,,,,,,,,,,,,,,, -120100,0.23037569,1.4289049,,,,,,,,,,,,,,,,, -120200,0.23832127,1.4761498,,,,,,,,,,,,,,,,, -120300,0.23843977,1.478283,,,,,,,,,,,,,,,,, -120400,0.23911932,1.5071403,,,,,,,,,,,,,,,,, -120500,0.24055526,1.3950156,,,,,,,,,,,,,,,,, -120600,0.24562265,1.4937887,,,,,,,,,,,,,,,,, -120700,0.24663033,1.4521446,,,,,,,,,,,,,,,,, -120800,0.23308599,1.4922439,,,,,,,,,,,,,,,,, -120900,0.25366876,1.4061356,,,,,,,,,,,,,,,,, -121000,0.22744533,1.4208721,,,,,,,,,,,,,,,,, -121036,,,0.6926527619361877,1.3857016563415527,35.392830009300184,0.6950812935829163,1.3645596504211426,30.888696962120584,3000.0,0.7132415771484375,1.257387399673462,31.33621769982612,3003.0,43715.2911632061,73417.11321353912,43715.2911632061,29696.142201185223,1.8214778900146484,0.0 -121100,0.24123742,1.495886,,,,,,,,,,,,,,,,, -121200,0.22595952,1.3915873,,,,,,,,,,,,,,,,, -121300,0.24098723,1.3926951,,,,,,,,,,,,,,,,, -121400,0.23627403,1.439214,,,,,,,,,,,,,,,,, -121500,0.23761545,1.5574411,,,,,,,,,,,,,,,,, -121600,0.23000146,1.406599,,,,,,,,,,,,,,,,, -121700,0.23126838,1.4261252,,,,,,,,,,,,,,,,, -121800,0.23368289,1.4997239,,,,,,,,,,,,,,,,, -121900,0.24337643,1.4404634,,,,,,,,,,,,,,,,, -122000,0.23557355,1.40912,,,,,,,,,,,,,,,,, -122100,0.23791644,1.5016814,,,,,,,,,,,,,,,,, -122200,0.23338164,1.4487016,,,,,,,,,,,,,,,,, -122300,0.24166483,1.4197128,,,,,,,,,,,,,,,,, -122400,0.25083363,1.553164,,,,,,,,,,,,,,,,, -122500,0.23217939,1.4233739,,,,,,,,,,,,,,,,, -122600,0.22930029,1.4285692,,,,,,,,,,,,,,,,, -122700,0.240638,1.5358722,,,,,,,,,,,,,,,,, -122800,0.23571439,1.3839881,,,,,,,,,,,,,,,,, -122900,0.24360533,1.4597821,,,,,,,,,,,,,,,,, -123000,0.23246634,1.4561027,,,,,,,,,,,,,,,,, -123100,0.22656688,1.4382228,,,,,,,,,,,,,,,,, -123200,0.23450676,1.4779799,,,,,,,,,,,,,,,,, -123300,0.25327358,1.4277561,,,,,,,,,,,,,,,,, -123364,,,0.6927144527435303,1.3827232122421265,35.22963795099508,0.6956640481948853,1.36439049243927,30.82173282251918,3000.0,0.7135552763938904,1.2557114362716677,31.546370535616788,3003.0,44555.2459564209,74803.37408471107,44555.2459564209,30242.33323359489,1.863266944885254,0.0 -123400,0.24200204,1.3917838,,,,,,,,,,,,,,,,, -123500,0.24151362,1.484011,,,,,,,,,,,,,,,,, -123600,0.23698442,1.4419229,,,,,,,,,,,,,,,,, -123700,0.24257806,1.470251,,,,,,,,,,,,,,,,, -123800,0.22995,1.4567211,,,,,,,,,,,,,,,,, -123900,0.23819509,1.4875354,,,,,,,,,,,,,,,,, -124000,0.22930552,1.4270049,,,,,,,,,,,,,,,,, -124100,0.23779164,1.435672,,,,,,,,,,,,,,,,, -124200,0.23907124,1.4736397,,,,,,,,,,,,,,,,, -124300,0.2569718,1.476616,,,,,,,,,,,,,,,,, -124400,0.24160199,1.4984266,,,,,,,,,,,,,,,,, -124500,0.23966227,1.5229062,,,,,,,,,,,,,,,,, -124600,0.24309456,1.474734,,,,,,,,,,,,,,,,, -124700,0.23911263,1.4974061,,,,,,,,,,,,,,,,, -124800,0.23000404,1.373553,,,,,,,,,,,,,,,,, -124900,0.22826897,1.4509562,,,,,,,,,,,,,,,,, -125000,0.23759584,1.3506174,,,,,,,,,,,,,,,,, -125100,0.24315763,1.5059906,,,,,,,,,,,,,,,,, -125200,0.22576378,1.4688536,,,,,,,,,,,,,,,,, -125300,0.23495029,1.47447,,,,,,,,,,,,,,,,, -125400,0.23781663,1.487507,,,,,,,,,,,,,,,,, -125500,0.23476465,1.4638455,,,,,,,,,,,,,,,,, -125600,0.23294298,1.4604181,,,,,,,,,,,,,,,,, -125691,,,0.6953275799751282,1.366080403327942,35.52793475734423,0.6961103677749634,1.3615988492965698,30.92001661383089,3000.0,0.7137877345085144,1.2536863088607788,31.45919176016124,3003.0,45395.13310265541,76220.00295948982,45395.13310265541,30818.95568537712,1.906174659729004,0.0 -125700,0.22807756,1.4344485,,,,,,,,,,,,,,,,, -125800,0.24802491,1.4957949,,,,,,,,,,,,,,,,, -125900,0.2232459,1.3926866,,,,,,,,,,,,,,,,, -126000,0.24217078,1.461569,,,,,,,,,,,,,,,,, -126100,0.24917294,1.4590478,,,,,,,,,,,,,,,,, -126200,0.24578202,1.4810411,,,,,,,,,,,,,,,,, -126300,0.23378184,1.4803971,,,,,,,,,,,,,,,,, -126400,0.23864539,1.5069218,,,,,,,,,,,,,,,,, -126500,0.234408,1.4013531,,,,,,,,,,,,,,,,, -126600,0.23597772,1.5067866,,,,,,,,,,,,,,,,, -126700,0.22912656,1.4474171,,,,,,,,,,,,,,,,, -126800,0.23119488,1.4176614,,,,,,,,,,,,,,,,, -126900,0.2399095,1.5190467,,,,,,,,,,,,,,,,, -127000,0.23456496,1.4068925,,,,,,,,,,,,,,,,, -127100,0.2295193,1.428124,,,,,,,,,,,,,,,,, -127200,0.2307006,1.4759278,,,,,,,,,,,,,,,,, -127300,0.22701256,1.4892963,,,,,,,,,,,,,,,,, -127400,0.23839246,1.6167696,,,,,,,,,,,,,,,,, -127500,0.22474296,1.4330568,,,,,,,,,,,,,,,,, -127600,0.24112284,1.5176969,,,,,,,,,,,,,,,,, -127700,0.2350332,1.4630109,,,,,,,,,,,,,,,,, -127800,0.25810462,1.4842881,,,,,,,,,,,,,,,,, -127900,0.23575819,1.4209318,,,,,,,,,,,,,,,,, -128000,0.24797858,1.4759233,,,,,,,,,,,,,,,,, -128018,,,0.6966461539268494,1.3577289581298828,35.648902859429185,0.6961227655410767,1.362878441810608,30.90892111794316,3000.0,0.7140550017356873,1.2544089555740356,31.35299748207434,3003.0,46235.07247853279,77613.97647738457,46235.07247853279,31372.873242139816,1.95019006729126,0.0 -128100,0.22732987,1.4254198,,,,,,,,,,,,,,,,, -128200,0.23842926,1.3778394,,,,,,,,,,,,,,,,, -128300,0.23102032,1.4017489,,,,,,,,,,,,,,,,, -128400,0.23986873,1.4384992,,,,,,,,,,,,,,,,, -128500,0.2400516,1.373998,,,,,,,,,,,,,,,,, -128600,0.2364584,1.4789279,,,,,,,,,,,,,,,,, -128700,0.22801688,1.4749217,,,,,,,,,,,,,,,,, -128800,0.23401281,1.391589,,,,,,,,,,,,,,,,, -128900,0.24244912,1.4871603,,,,,,,,,,,,,,,,, -129000,0.23409836,1.516095,,,,,,,,,,,,,,,,, -129100,0.23082943,1.4396944,,,,,,,,,,,,,,,,, -129200,0.24134178,1.4602656,,,,,,,,,,,,,,,,, -129300,0.24428552,1.4874054,,,,,,,,,,,,,,,,, -129400,0.24400331,1.4430467,,,,,,,,,,,,,,,,, -129500,0.24280334,1.4073169,,,,,,,,,,,,,,,,, -129600,0.22832637,1.4349765,,,,,,,,,,,,,,,,, -129700,0.23100257,1.4079837,,,,,,,,,,,,,,,,, -129800,0.2282926,1.4069525,,,,,,,,,,,,,,,,, -129900,0.25617805,1.4868355,,,,,,,,,,,,,,,,, -130000,0.22807886,1.3862092,,,,,,,,,,,,,,,,, -130100,0.23167348,1.4316429,,,,,,,,,,,,,,,,, -130200,0.23477201,1.4775844,,,,,,,,,,,,,,,,, -130300,0.24381797,1.4569445,,,,,,,,,,,,,,,,, -130345,,,0.6950264573097229,1.3692277669906616,35.27114728847377,0.6961600184440613,1.3616727590560913,30.8710946996457,3000.0,0.7141711711883545,1.253180742263794,31.45476105787029,3003.0,47074.99239993096,79003.22927308083,47074.99239993096,31922.08831977844,1.9932560920715328,0.0 -130400,0.24322501,1.4789895,,,,,,,,,,,,,,,,, -130500,0.24173112,1.4266448,,,,,,,,,,,,,,,,, -130600,0.23603117,1.4720521,,,,,,,,,,,,,,,,, -130700,0.22958231,1.414699,,,,,,,,,,,,,,,,, -130800,0.23693909,1.4714452,,,,,,,,,,,,,,,,, -130900,0.23467189,1.4934541,,,,,,,,,,,,,,,,, -131000,0.23771363,1.4697613,,,,,,,,,,,,,,,,, -131100,0.2421715,1.489175,,,,,,,,,,,,,,,,, -131200,0.23957044,1.4806824,,,,,,,,,,,,,,,,, -131300,0.23527014,1.4662333,,,,,,,,,,,,,,,,, -131400,0.24680541,1.4474046,,,,,,,,,,,,,,,,, -131500,0.23637614,1.4013585,,,,,,,,,,,,,,,,, -131600,0.23986322,1.5209278,,,,,,,,,,,,,,,,, -131700,0.23983213,1.5144105,,,,,,,,,,,,,,,,, -131800,0.23702796,1.4578412,,,,,,,,,,,,,,,,, -131900,0.22596337,1.3668668,,,,,,,,,,,,,,,,, -132000,0.23845011,1.5370953,,,,,,,,,,,,,,,,, -132100,0.23549986,1.4376538,,,,,,,,,,,,,,,,, -132200,0.23581907,1.4657154,,,,,,,,,,,,,,,,, -132300,0.24014977,1.4988025,,,,,,,,,,,,,,,,, -132400,0.22657824,1.402033,,,,,,,,,,,,,,,,, -132500,0.23306312,1.4541357,,,,,,,,,,,,,,,,, -132600,0.23005065,1.4885048,,,,,,,,,,,,,,,,, -132673,,,0.6964210271835327,1.3611470460891724,35.60231110750975,0.6960855722427368,1.3613717555999756,30.815935514007432,3000.0,0.7141944169998169,1.252968668937683,31.444790126895835,3003.0,47915.1347079277,80406.8292388916,47915.1347079277,32485.429721593857,2.036381959915161,0.0 -132700,0.24284899,1.4458365,,,,,,,,,,,,,,,,, -132800,0.22713149,1.5015695,,,,,,,,,,,,,,,,, -132900,0.23947193,1.4632012,,,,,,,,,,,,,,,,, -133000,0.23949172,1.4011075,,,,,,,,,,,,,,,,, -133100,0.24515203,1.5047631,,,,,,,,,,,,,,,,, -133200,0.22666809,1.462961,,,,,,,,,,,,,,,,, -133300,0.23105665,1.3917176,,,,,,,,,,,,,,,,, -133329,,,,,,,,,,,,,,48151.32396864891,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 5e3115b38..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -800.0324959754944,0.0,20.40406966209412,1,0,20.40406966209412,0.4188618532072368,95000000,820.4366059303284,0.417869397116907,0.4188120656037804,83274637 -1453.5206379890442,0.0263786315917968,1220.853224515915,1534,0,1220.853224515915,0.1284171869860197,95000000,2674.450860977173,0.1249114120343946,0.1259749513847325,83274637 -2008.7453598976133,0.051041841506958,2420.9457845687866,3077,0,2420.9457845687866,0.1273812083778782,95000000,4429.84237241745,0.1229683316353732,0.124958537045319,83274637 -2557.459654331208,0.0768871307373046,3621.1655600070953,4620,0,3621.1655600070953,0.1268847081208881,95000000,6178.8516590595245,0.1236519870236984,0.1245262420713854,83274637 -3071.446784257889,0.1051466464996337,4821.371724128723,6166,0,4821.371724128723,0.1269528404091283,95000000,7893.121944189072,0.1236871280853853,0.124615452719941,83274637 -3525.393592596054,0.1279835700988769,6021.972544908524,7706,0,6021.972544908524,0.1261517924547697,95000000,9547.742681503296,0.1238551849745354,0.1238614922883872,83274637 -3847.1143724918365,0.15393447875976562,7222.520458698273,9250,0,7222.520458698273,0.12587958279194078,95000000,11070.087594270706,0.12005350195199439,0.1235807650252051,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 9e5bc2c5e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.1300435,0.41734096,,,,,,,,,,, -1,,,0.417869397116907,0.4188120656037804,83274637.0,0.4188618532072368,95000000.0,20.40406966209412,820.4366059303284,20.40406966209412,800.0324959754944,0.0,0.0 -100,0.046868667,0.13723466,,,,,,,,,,, -200,0.00933943,0.12970889,,,,,,,,,,, -300,0.0078046834,0.13196057,,,,,,,,,,, -400,0.019773718,0.122103415,,,,,,,,,,, -500,0.009630637,0.12984219,,,,,,,,,,, -600,0.015679639,0.13681656,,,,,,,,,,, -700,0.020832954,0.13475241,,,,,,,,,,, -800,0.025686113,0.12150377,,,,,,,,,,, -900,0.014975373,0.1197229,,,,,,,,,,, -1000,0.007296873,0.12594578,,,,,,,,,,, -1100,0.019458938,0.13005218,,,,,,,,,,, -1200,0.01688626,0.12659907,,,,,,,,,,, -1300,0.02245634,0.125616,,,,,,,,,,, -1400,0.008070337,0.12458321,,,,,,,,,,, -1500,0.0148451375,0.11961827,,,,,,,,,,, -1534,,,0.1249114120343946,0.1259749513847325,83274637.0,0.1284171869860197,95000000.0,1220.853224515915,2674.450860977173,1220.853224515915,1453.5206379890442,0.0263786315917968,0.0 -1600,0.027443236,0.122365534,,,,,,,,,,, -1700,0.008066008,0.11906021,,,,,,,,,,, -1800,0.018191345,0.1239223,,,,,,,,,,, -1900,0.01838196,0.12935147,,,,,,,,,,, -2000,0.008834813,0.11899212,,,,,,,,,,, -2100,0.007981515,0.12164228,,,,,,,,,,, -2200,0.013716713,0.11776258,,,,,,,,,,, -2300,0.019258425,0.12486133,,,,,,,,,,, -2400,0.0056750188,0.12965702,,,,,,,,,,, -2500,0.006895105,0.11819805,,,,,,,,,,, -2600,0.005851446,0.12710726,,,,,,,,,,, -2700,0.011987198,0.1171715,,,,,,,,,,, -2800,0.01194877,0.12362203,,,,,,,,,,, -2900,0.017925823,0.124452606,,,,,,,,,,, -3000,0.005310939,0.11916889,,,,,,,,,,, -3077,,,0.1229683316353732,0.124958537045319,83274637.0,0.1273812083778782,95000000.0,2420.9457845687866,4429.84237241745,2420.9457845687866,2008.7453598976133,0.051041841506958,0.0 -3100,0.0052965656,0.12280719,,,,,,,,,,, -3200,0.010803661,0.12523381,,,,,,,,,,, -3300,0.034388144,0.13069996,,,,,,,,,,, -3400,0.012946341,0.12339009,,,,,,,,,,, -3500,0.017838217,0.12335347,,,,,,,,,,, -3600,0.008686968,0.118296936,,,,,,,,,,, -3700,0.0056894026,0.12161028,,,,,,,,,,, -3800,0.0068512335,0.12730075,,,,,,,,,,, -3900,0.016903926,0.12366908,,,,,,,,,,, -4000,0.021797359,0.119503975,,,,,,,,,,, -4100,0.011699685,0.11718469,,,,,,,,,,, -4200,0.0054046446,0.1265591,,,,,,,,,,, -4300,0.0055448404,0.119739145,,,,,,,,,,, -4400,0.013835544,0.12461038,,,,,,,,,,, -4500,0.0061823837,0.120991796,,,,,,,,,,, -4600,0.02197685,0.1276631,,,,,,,,,,, -4620,,,0.1236519870236984,0.1245262420713854,83274637.0,0.1268847081208881,95000000.0,3621.1655600070953,6178.8516590595245,3621.1655600070953,2557.459654331208,0.0768871307373046,0.0 -4700,0.009877806,0.12869865,,,,,,,,,,, -4800,0.006774147,0.12480929,,,,,,,,,,, -4900,0.009370758,0.12494961,,,,,,,,,,, -5000,0.0061270827,0.13229352,,,,,,,,,,, -5100,0.013759749,0.12731385,,,,,,,,,,, -5200,0.0143626565,0.11978536,,,,,,,,,,, -5300,0.011644762,0.12557386,,,,,,,,,,, -5400,0.005563262,0.12294042,,,,,,,,,,, -5500,0.019307962,0.12492308,,,,,,,,,,, -5600,0.008722358,0.12237504,,,,,,,,,,, -5700,0.008613521,0.13123745,,,,,,,,,,, -5800,0.013793572,0.13025028,,,,,,,,,,, -5900,0.0063070315,0.12056251,,,,,,,,,,, -6000,0.007940414,0.11955399,,,,,,,,,,, -6100,0.01682088,0.12239388,,,,,,,,,,, -6166,,,0.1236871280853853,0.124615452719941,83274637.0,0.1269528404091283,95000000.0,4821.371724128723,7893.121944189072,4821.371724128723,3071.446784257889,0.1051466464996337,0.0 -6200,0.008068258,0.120786205,,,,,,,,,,, -6300,0.0069976146,0.13649724,,,,,,,,,,, -6400,0.011782006,0.122622125,,,,,,,,,,, -6500,0.007490282,0.12084503,,,,,,,,,,, -6600,0.006467042,0.13002129,,,,,,,,,,, -6700,0.007862785,0.12042201,,,,,,,,,,, -6800,0.0066438457,0.122565106,,,,,,,,,,, -6900,0.006504104,0.12286995,,,,,,,,,,, -7000,0.025729584,0.12714376,,,,,,,,,,, -7100,0.008562439,0.12102871,,,,,,,,,,, -7200,0.009115764,0.1372438,,,,,,,,,,, -7300,0.0068767117,0.11908598,,,,,,,,,,, -7400,0.013232676,0.122126155,,,,,,,,,,, -7500,0.008364772,0.13251981,,,,,,,,,,, -7600,0.008175692,0.12223479,,,,,,,,,,, -7700,0.006488293,0.12369485,,,,,,,,,,, -7706,,,0.1238551849745354,0.1238614922883872,83274637.0,0.1261517924547697,95000000.0,6021.972544908524,9547.742681503296,6021.972544908524,3525.393592596054,0.1279835700988769,0.0 -7800,0.010508659,0.11615224,,,,,,,,,,, -7900,0.006231223,0.11944622,,,,,,,,,,, -8000,0.006374241,0.11803985,,,,,,,,,,, -8100,0.009378811,0.12608626,,,,,,,,,,, -8200,0.00840828,0.119811945,,,,,,,,,,, -8300,0.010660024,0.123544745,,,,,,,,,,, -8400,0.007910148,0.12099914,,,,,,,,,,, -8500,0.012354042,0.11940734,,,,,,,,,,, -8600,0.007888773,0.11789787,,,,,,,,,,, -8700,0.007528118,0.11215532,,,,,,,,,,, -8800,0.0075850994,0.12045368,,,,,,,,,,, -8900,0.013249356,0.120327525,,,,,,,,,,, -9000,0.0062596346,0.12612458,,,,,,,,,,, -9100,0.009397706,0.11863978,,,,,,,,,,, -9200,0.008931331,0.12295435,,,,,,,,,,, -9250,,,0.1200535019519943,0.1235807650252051,83274637.0,0.1258795827919407,95000000.0,7222.520458698273,11070.087594270706,7222.520458698273,3847.114372491837,0.1539344787597656,0.0 -9300,0.006773992,0.12660727,,,,,,,,,,, -9400,0.008320434,0.115578555,,,,,,,,,,, -9500,0.008499878,0.11739156,,,,,,,,,,, -9600,0.008192979,0.11916334,,,,,,,,,,, -9700,0.006963786,0.121016994,,,,,,,,,,, -9800,0.013732812,0.12257892,,,,,,,,,,, -9882,,,,,,,,7703.072590589523,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/eval_measurements.csv deleted file mode 100644 index b6cf71ebe..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -165.15218567848206,0.0,6.695144176483154,1,0,6.695144176483154,0.4188618532072368,95000000,171.84738278388977,0.4174958868596539,0.4188120656037804,83274637 -189.17342853546145,0.0179743766784667,1207.2788932323456,1524,0,1207.2788932323456,0.128335424712171,95000000,1396.5214047431946,0.1258377176035875,0.1260088345630524,83274637 -212.356232881546,0.0437788963317871,2407.3932876586914,3042,0,2407.3932876586914,0.1280807847861842,95000000,2619.8952460289,0.1236309052544569,0.1256461182378825,83274637 -236.05473399162287,0.0727894306182861,3607.7674717903137,4562,0,3607.7674717903137,0.1268517716899671,95000000,3844.047542095184,0.1213925681867689,0.1244713451700486,83274637 -259.6351001262665,0.097424030303955,4808.567242622376,6063,0,4808.567242622376,0.1263481180612664,95000000,5068.501587867737,0.1228239984257416,0.1238887687976352,83274637 -282.9353101253509,0.1199071407318115,6008.508641242981,7559,0,6008.508641242981,0.126163258696546,95000000,6291.814614772797,0.1220807054507657,0.1238988479750893,83274637 -306.9958863258362,0.14849400520324707,7208.577575683594,9077,0,7208.577575683594,0.12595633828125,95000000,7516.021926641464,0.12144491376367005,0.12366305692549881,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/measurements.csv deleted file mode 100644 index cd5354e73..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/measurements.csv +++ /dev/null @@ -1,106 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.12649,0.41882348,,,,,,,,,,, -1,,,0.4174958868596539,0.4188120656037804,83274637.0,0.4188618532072368,95000000.0,6.695144176483154,171.84738278388977,6.695144176483154,165.15218567848206,0.0,0.0 -100,0.069315426,0.13665459,,,,,,,,,,, -200,0.31735063,0.1376458,,,,,,,,,,, -300,0.11813917,0.12592348,,,,,,,,,,, -400,0.09627492,0.1339238,,,,,,,,,,, -500,0.03169876,0.12493286,,,,,,,,,,, -600,0.07278622,0.1276412,,,,,,,,,,, -700,0.047849283,0.118594155,,,,,,,,,,, -800,0.007550995,0.13479342,,,,,,,,,,, -900,0.01499886,0.1235461,,,,,,,,,,, -1000,0.019969417,0.11478802,,,,,,,,,,, -1100,0.008187329,0.120884225,,,,,,,,,,, -1200,0.010354724,0.12190902,,,,,,,,,,, -1300,0.012756186,0.12631902,,,,,,,,,,, -1400,0.015630182,0.12289633,,,,,,,,,,, -1500,0.013146747,0.118372396,,,,,,,,,,, -1524,,,0.1258377176035875,0.1260088345630524,83274637.0,0.128335424712171,95000000.0,1207.2788932323456,1396.5214047431946,1207.2788932323456,189.17342853546145,0.0179743766784667,0.0 -1600,0.018267341,0.12734999,,,,,,,,,,, -1700,0.009156034,0.12445205,,,,,,,,,,, -1800,0.014288886,0.11735453,,,,,,,,,,, -1900,0.035325374,0.12908924,,,,,,,,,,, -2000,0.030155225,0.12494795,,,,,,,,,,, -2100,0.030358274,0.12900294,,,,,,,,,,, -2200,0.011003703,0.12930971,,,,,,,,,,, -2300,0.007254431,0.11974519,,,,,,,,,,, -2400,0.02863278,0.12244289,,,,,,,,,,, -2500,0.020260822,0.13025822,,,,,,,,,,, -2600,0.033862334,0.13425186,,,,,,,,,,, -2700,0.0074698753,0.12905177,,,,,,,,,,, -2800,0.009033797,0.12775256,,,,,,,,,,, -2900,0.025136033,0.13571447,,,,,,,,,,, -3000,0.009019874,0.119424246,,,,,,,,,,, -3042,,,0.1236309052544569,0.1256461182378825,83274637.0,0.1280807847861842,95000000.0,2407.3932876586914,2619.8952460289,2407.3932876586914,212.356232881546,0.0437788963317871,0.0 -3100,0.022960963,0.11940293,,,,,,,,,,, -3200,0.008205292,0.12316069,,,,,,,,,,, -3300,0.015494601,0.12239374,,,,,,,,,,, -3400,0.007463012,0.11932185,,,,,,,,,,, -3500,0.012773791,0.11751199,,,,,,,,,,, -3600,0.0064658145,0.11752272,,,,,,,,,,, -3700,0.019629663,0.13014847,,,,,,,,,,, -3800,0.008595461,0.11705929,,,,,,,,,,, -3900,0.026130147,0.1250312,,,,,,,,,,, -4000,0.005796016,0.11751756,,,,,,,,,,, -4100,0.007957535,0.120769165,,,,,,,,,,, -4200,0.029605651,0.13258761,,,,,,,,,,, -4300,0.009479223,0.12296583,,,,,,,,,,, -4400,0.0105207935,0.1370041,,,,,,,,,,, -4500,0.013645223,0.12301948,,,,,,,,,,, -4562,,,0.1213925681867689,0.1244713451700486,83274637.0,0.1268517716899671,95000000.0,3607.7674717903137,3844.047542095184,3607.7674717903137,236.05473399162287,0.0727894306182861,0.0 -4600,0.00965062,0.120092675,,,,,,,,,,, -4700,0.00647982,0.11576885,,,,,,,,,,, -4800,0.009407394,0.11458351,,,,,,,,,,, -4900,0.016665705,0.123678565,,,,,,,,,,, -5000,0.013624759,0.11944592,,,,,,,,,,, -5100,0.009721186,0.123252064,,,,,,,,,,, -5200,0.018058373,0.11866301,,,,,,,,,,, -5300,0.019420343,0.12441787,,,,,,,,,,, -5400,0.012197129,0.11969119,,,,,,,,,,, -5500,0.005893697,0.12461862,,,,,,,,,,, -5600,0.017452495,0.1130707,,,,,,,,,,, -5700,0.007834065,0.12470255,,,,,,,,,,, -5800,0.0056524766,0.12911478,,,,,,,,,,, -5900,0.008476515,0.12746109,,,,,,,,,,, -6000,0.00527629,0.11751158,,,,,,,,,,, -6063,,,0.1228239984257416,0.1238887687976352,83274637.0,0.1263481180612664,95000000.0,4808.567242622376,5068.501587867737,4808.567242622376,259.6351001262665,0.097424030303955,0.0 -6100,0.0067138127,0.11783271,,,,,,,,,,, -6200,0.0075419005,0.120706685,,,,,,,,,,, -6300,0.01006799,0.121356264,,,,,,,,,,, -6400,0.009934723,0.1223303,,,,,,,,,,, -6500,0.00884864,0.12892377,,,,,,,,,,, -6600,0.006506252,0.11757383,,,,,,,,,,, -6700,0.007113942,0.12622014,,,,,,,,,,, -6800,0.010840247,0.1226648,,,,,,,,,,, -6900,0.016277947,0.12829153,,,,,,,,,,, -7000,0.017363425,0.13626488,,,,,,,,,,, -7100,0.0067263525,0.13120747,,,,,,,,,,, -7200,0.006754859,0.12422407,,,,,,,,,,, -7300,0.0067965453,0.11938777,,,,,,,,,,, -7400,0.02541698,0.12889947,,,,,,,,,,, -7500,0.0061654653,0.12117975,,,,,,,,,,, -7559,,,0.1220807054507657,0.1238988479750893,83274637.0,0.126163258696546,95000000.0,6008.508641242981,6291.814614772797,6008.508641242981,282.9353101253509,0.1199071407318115,0.0 -7600,0.007757574,0.12445613,,,,,,,,,,, -7700,0.0083438605,0.117243804,,,,,,,,,,, -7800,0.0072974004,0.11944964,,,,,,,,,,, -7900,0.0077933343,0.12327983,,,,,,,,,,, -8000,0.008782008,0.13296865,,,,,,,,,,, -8100,0.0066077807,0.11485075,,,,,,,,,,, -8200,0.0075467573,0.12934498,,,,,,,,,,, -8300,0.0065341443,0.12518601,,,,,,,,,,, -8400,0.012613281,0.118837506,,,,,,,,,,, -8500,0.006630254,0.121419534,,,,,,,,,,, -8600,0.0077483663,0.12078263,,,,,,,,,,, -8700,0.006407058,0.1188699,,,,,,,,,,, -8800,0.0106661245,0.12722553,,,,,,,,,,, -8900,0.007656328,0.12166223,,,,,,,,,,, -9000,0.007471064,0.11935034,,,,,,,,,,, -9077,,,0.12144491376367,0.1236630569254988,83274637.0,0.12595633828125,95000000.0,7208.577575683594,7516.021926641464,7208.577575683594,306.9958863258362,0.148494005203247,0.0 -9100,0.011132785,0.13049057,,,,,,,,,,, -9200,0.008335578,0.12612331,,,,,,,,,,, -9300,0.006213434,0.11296882,,,,,,,,,,, -9400,0.0068339678,0.13193733,,,,,,,,,,, -9500,0.007294891,0.122376435,,,,,,,,,,, -9600,0.007469563,0.11802992,,,,,,,,,,, -9699,,,,,,,,7703.099191665649,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/eval_measurements.csv deleted file mode 100644 index bcb03c5eb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -23.736982822418213,0.0,6.709789276123047,1,0,6.709789276123047,0.4188618532072368,95000000,30.446840286254883,0.4181562261011615,0.4188120656037804,83274637 -50.14362454414368,0.0191893577575683,1207.0080687999723,1524,0,1207.0080687999723,0.1283900325041118,95000000,1257.2221710681915,0.1257212884114973,0.1256773305193152,83274637 -74.33149528503418,0.0458357334136962,2407.0765080451965,3026,0,2407.0765080451965,0.127627279399671,95000000,2481.554374933243,0.1224746711599002,0.1250115902730533,83274637 -98.81931662559508,0.0688531398773193,3607.626647233963,4533,0,3607.626647233963,0.1269304060855263,95000000,3706.6648955345154,0.1217223097037219,0.1245626172862362,83274637 -122.87645244598389,0.0949859619140625,4807.566356658936,6050,0,4807.566356658936,0.1267423980263158,95000000,4930.737677097321,0.1214077326486695,0.1243880322693571,83274637 -150.06325364112854,0.1183829307556152,6007.942964553833,7549,0,6007.942964553833,0.1263650322985197,95000000,6158.37321305275,0.1233120016613096,0.1240022305744935,83274637 -174.17526245117188,0.145538330078125,7208.53520655632,9067,0,7208.53520655632,0.12593963112664475,95000000,7383.154753684998,0.12435142550640886,0.12369175456072537,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/measurements.csv deleted file mode 100644 index 85fa5ccf7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/measurements.csv +++ /dev/null @@ -1,106 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.1088495,0.41638312,,,,,,,,,,, -1,,,0.4181562261011615,0.4188120656037804,83274637.0,0.4188618532072368,95000000.0,6.709789276123047,30.446840286254883,6.709789276123047,23.736982822418213,0.0,0.0 -100,0.04437572,0.14898217,,,,,,,,,,, -200,0.010766367,0.13230987,,,,,,,,,,, -300,0.0074101556,0.12187229,,,,,,,,,,, -400,0.01163024,0.12169005,,,,,,,,,,, -500,0.007714396,0.11940138,,,,,,,,,,, -600,0.028407658,0.12162355,,,,,,,,,,, -700,0.039921127,0.13252282,,,,,,,,,,, -800,0.0238902,0.1285742,,,,,,,,,,, -900,0.020065175,0.12500061,,,,,,,,,,, -1000,0.007308983,0.115374364,,,,,,,,,,, -1100,0.006967698,0.124523595,,,,,,,,,,, -1200,0.03132821,0.12436326,,,,,,,,,,, -1300,0.009244671,0.13013275,,,,,,,,,,, -1400,0.017732562,0.12076672,,,,,,,,,,, -1500,0.0075502754,0.11527002,,,,,,,,,,, -1524,,,0.1257212884114973,0.1256773305193152,83274637.0,0.1283900325041118,95000000.0,1207.0080687999723,1257.2221710681915,1207.0080687999723,50.14362454414368,0.0191893577575683,0.0 -1600,0.02563187,0.12059872,,,,,,,,,,, -1700,0.007234006,0.13061339,,,,,,,,,,, -1800,0.008587296,0.11967661,,,,,,,,,,, -1900,0.021932222,0.13167377,,,,,,,,,,, -2000,0.009018105,0.12224355,,,,,,,,,,, -2100,0.017180055,0.1305574,,,,,,,,,,, -2200,0.03685636,0.13600618,,,,,,,,,,, -2300,0.013954563,0.13044247,,,,,,,,,,, -2400,0.011245997,0.1280711,,,,,,,,,,, -2500,0.0065778242,0.12415639,,,,,,,,,,, -2600,0.013592801,0.11403544,,,,,,,,,,, -2700,0.0065790243,0.12693904,,,,,,,,,,, -2800,0.01598632,0.12933509,,,,,,,,,,, -2900,0.020649645,0.12851684,,,,,,,,,,, -3000,0.018532228,0.12171498,,,,,,,,,,, -3026,,,0.1224746711599002,0.1250115902730533,83274637.0,0.127627279399671,95000000.0,2407.0765080451965,2481.554374933243,2407.0765080451965,74.33149528503418,0.0458357334136962,0.0 -3100,0.009929377,0.11794225,,,,,,,,,,, -3200,0.0110088205,0.11684926,,,,,,,,,,, -3300,0.020279113,0.1256082,,,,,,,,,,, -3400,0.013682723,0.12540686,,,,,,,,,,, -3500,0.01912126,0.12730525,,,,,,,,,,, -3600,0.006400622,0.12706594,,,,,,,,,,, -3700,0.011024957,0.120550476,,,,,,,,,,, -3800,0.009521538,0.12396223,,,,,,,,,,, -3900,0.022643222,0.118114956,,,,,,,,,,, -4000,0.015224331,0.1190912,,,,,,,,,,, -4100,0.0052726255,0.12806728,,,,,,,,,,, -4200,0.0070098303,0.122205205,,,,,,,,,,, -4300,0.019361312,0.12202984,,,,,,,,,,, -4400,0.0064342064,0.12464918,,,,,,,,,,, -4500,0.010819853,0.112823576,,,,,,,,,,, -4533,,,0.1217223097037219,0.1245626172862362,83274637.0,0.1269304060855263,95000000.0,3607.626647233963,3706.6648955345154,3607.626647233963,98.81931662559508,0.0688531398773193,0.0 -4600,0.008554922,0.11856547,,,,,,,,,,, -4700,0.01392386,0.117156714,,,,,,,,,,, -4800,0.01153347,0.12765652,,,,,,,,,,, -4900,0.01637117,0.1262282,,,,,,,,,,, -5000,0.012333384,0.1221838,,,,,,,,,,, -5100,0.014878256,0.12702799,,,,,,,,,,, -5200,0.005794587,0.12384351,,,,,,,,,,, -5300,0.0065942802,0.11765575,,,,,,,,,,, -5400,0.022818055,0.12674996,,,,,,,,,,, -5500,0.018939368,0.12301654,,,,,,,,,,, -5600,0.006329,0.12629321,,,,,,,,,,, -5700,0.014856412,0.12376113,,,,,,,,,,, -5800,0.00820919,0.12829936,,,,,,,,,,, -5900,0.0069807316,0.12045059,,,,,,,,,,, -6000,0.016810095,0.1271125,,,,,,,,,,, -6050,,,0.1214077326486695,0.1243880322693571,83274637.0,0.1267423980263158,95000000.0,4807.566356658936,4930.737677097321,4807.566356658936,122.87645244598389,0.0949859619140625,0.0 -6100,0.009055465,0.11826011,,,,,,,,,,, -6200,0.011048434,0.11810014,,,,,,,,,,, -6300,0.008532663,0.123198114,,,,,,,,,,, -6400,0.0074778185,0.11825468,,,,,,,,,,, -6500,0.005586061,0.12837552,,,,,,,,,,, -6600,0.007269635,0.113859355,,,,,,,,,,, -6700,0.009587508,0.1209035,,,,,,,,,,, -6800,0.0061477832,0.11963532,,,,,,,,,,, -6900,0.0068080286,0.12983881,,,,,,,,,,, -7000,0.025004266,0.12462178,,,,,,,,,,, -7100,0.007230068,0.12260697,,,,,,,,,,, -7200,0.013574017,0.1225424,,,,,,,,,,, -7300,0.005745982,0.12769917,,,,,,,,,,, -7400,0.010676174,0.122537576,,,,,,,,,,, -7500,0.007895908,0.116997145,,,,,,,,,,, -7549,,,0.1233120016613096,0.1240022305744935,83274637.0,0.1263650322985197,95000000.0,6007.942964553833,6158.37321305275,6007.942964553833,150.06325364112854,0.1183829307556152,0.0 -7600,0.0071503483,0.11407839,,,,,,,,,,, -7700,0.0077345087,0.12220016,,,,,,,,,,, -7800,0.008336514,0.12894331,,,,,,,,,,, -7900,0.010515497,0.118974335,,,,,,,,,,, -8000,0.010740182,0.12904966,,,,,,,,,,, -8100,0.008736752,0.11478947,,,,,,,,,,, -8200,0.007969342,0.12291095,,,,,,,,,,, -8300,0.011715288,0.13405538,,,,,,,,,,, -8400,0.009298875,0.120420605,,,,,,,,,,, -8500,0.011781535,0.12795568,,,,,,,,,,, -8600,0.008417941,0.120130636,,,,,,,,,,, -8700,0.012148678,0.12667337,,,,,,,,,,, -8800,0.0068928828,0.121463686,,,,,,,,,,, -8900,0.012119613,0.113768056,,,,,,,,,,, -9000,0.010007698,0.118103005,,,,,,,,,,, -9067,,,0.1243514255064088,0.1236917545607253,83274637.0,0.1259396311266447,95000000.0,7208.53520655632,7383.154753684998,7208.53520655632,174.17526245117188,0.145538330078125,0.0 -9100,0.006609633,0.12079452,,,,,,,,,,, -9200,0.0070929644,0.116610795,,,,,,,,,,, -9300,0.0071244286,0.124309696,,,,,,,,,,, -9400,0.007886714,0.12233779,,,,,,,,,,, -9500,0.009798959,0.12797444,,,,,,,,,,, -9600,0.010554994,0.11936444,,,,,,,,,,, -9698,,,,,,,,7703.667432069778,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 218fdc7bc..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -23.535862922668457,0.0,10.69684362411499,1,0,10.69684362411499,0.4188618532072368,95000000,34.23276495933533,0.4176470236208454,0.4188120656037804,83274637 -46.04471206665039,0.0211014747619628,1210.9046003818512,1385,0,1210.9046003818512,0.1286561589432565,95000000,1257.0154948234558,0.1257467939043944,0.1262581310471518,83274637 -68.33290123939514,0.0490431785583496,2411.299080133438,2767,0,2411.299080133438,0.1284978528166118,95000000,2479.7721927165985,0.1240582037834251,0.126174739059505,83274637 -90.88132357597352,0.0729458332061767,3611.7111065387726,4148,0,3611.7111065387726,0.1287472412006579,95000000,3702.8027780056,0.1246463721090892,0.1263915748046881,83274637 -113.57464456558228,0.0985362529754638,4812.121766090393,5532,0,4812.121766090393,0.1281059535773026,95000000,4925.977107524872,0.1261061483865264,0.1257231651156882,83274637 -136.3490800857544,0.1220607757568359,6012.393161773682,6894,0,6012.393161773682,0.1278534981599506,95000000,6149.090483188629,0.1229537926942297,0.1254273176556296,83274637 -159.09422087669373,0.14670157432556152,7213.036543607712,8258,0,7213.036543607712,0.12774747687088817,95000000,7372.547692298889,0.12317954655150948,0.12541083008938844,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/measurements.csv deleted file mode 100644 index 89b97ab5d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/measurements.csv +++ /dev/null @@ -1,97 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.0929966,0.41928858,,,,,,,,,,, -1,,,0.4176470236208454,0.4188120656037804,83274637.0,0.4188618532072368,95000000.0,10.69684362411499,34.23276495933533,10.69684362411499,23.535862922668457,0.0,0.0 -100,0.0807778,0.13230287,,,,,,,,,,, -200,0.07150222,0.13679527,,,,,,,,,,, -300,0.009048812,0.13359185,,,,,,,,,,, -400,0.0076093897,0.1305742,,,,,,,,,,, -500,0.046218306,0.1273943,,,,,,,,,,, -600,0.028172657,0.12524833,,,,,,,,,,, -700,0.008790972,0.12942088,,,,,,,,,,, -800,0.030185018,0.13116977,,,,,,,,,,, -900,0.0095966505,0.119867384,,,,,,,,,,, -1000,0.022729877,0.12304518,,,,,,,,,,, -1100,0.047835294,0.119999014,,,,,,,,,,, -1200,0.025185551,0.120653644,,,,,,,,,,, -1300,0.009370137,0.12690547,,,,,,,,,,, -1385,,,0.1257467939043944,0.1262581310471518,83274637.0,0.1286561589432565,95000000.0,1210.9046003818512,1257.0154948234558,1210.9046003818512,46.04471206665039,0.0211014747619628,0.0 -1400,0.013787351,0.12520333,,,,,,,,,,, -1500,0.027656363,0.12876114,,,,,,,,,,, -1600,0.0247154,0.12528421,,,,,,,,,,, -1700,0.02508968,0.11813856,,,,,,,,,,, -1800,0.01758578,0.1315279,,,,,,,,,,, -1900,0.004244441,0.12233712,,,,,,,,,,, -2000,0.010228067,0.12056616,,,,,,,,,,, -2100,0.019845061,0.12410459,,,,,,,,,,, -2200,0.0072398805,0.12284984,,,,,,,,,,, -2300,0.053669192,0.1258275,,,,,,,,,,, -2400,0.0051607867,0.12627646,,,,,,,,,,, -2500,0.023172457,0.1181493,,,,,,,,,,, -2600,0.034353115,0.12920511,,,,,,,,,,, -2700,0.014268195,0.13764735,,,,,,,,,,, -2767,,,0.1240582037834251,0.126174739059505,83274637.0,0.1284978528166118,95000000.0,2411.299080133438,2479.7721927165985,2411.299080133438,68.33290123939514,0.0490431785583496,0.0 -2800,0.00828969,0.11671193,,,,,,,,,,, -2900,0.03197531,0.12502185,,,,,,,,,,, -3000,0.053116534,0.122031495,,,,,,,,,,, -3100,0.016484238,0.121777356,,,,,,,,,,, -3200,0.040841583,0.13707566,,,,,,,,,,, -3300,0.02515905,0.120638445,,,,,,,,,,, -3400,0.03506166,0.12952495,,,,,,,,,,, -3500,0.03742706,0.12559336,,,,,,,,,,, -3600,0.036931526,0.121770315,,,,,,,,,,, -3700,0.032821618,0.1310759,,,,,,,,,,, -3800,0.03386659,0.13132471,,,,,,,,,,, -3900,0.004557242,0.123039216,,,,,,,,,,, -4000,0.0040009664,0.120571196,,,,,,,,,,, -4100,0.0149166165,0.12779929,,,,,,,,,,, -4148,,,0.1246463721090892,0.1263915748046881,83274637.0,0.1287472412006579,95000000.0,3611.7111065387726,3702.8027780056,3611.7111065387726,90.88132357597352,0.0729458332061767,0.0 -4200,0.05197221,0.1263977,,,,,,,,,,, -4300,0.042917445,0.121288165,,,,,,,,,,, -4400,0.049677473,0.119286284,,,,,,,,,,, -4500,0.010803903,0.132038,,,,,,,,,,, -4600,0.014010695,0.12271087,,,,,,,,,,, -4700,0.0046882667,0.13150755,,,,,,,,,,, -4800,0.028115366,0.124584384,,,,,,,,,,, -4900,0.00983935,0.12296531,,,,,,,,,,, -5000,0.014730053,0.13183834,,,,,,,,,,, -5100,0.042346906,0.12668377,,,,,,,,,,, -5200,0.026349023,0.12831886,,,,,,,,,,, -5300,0.041305088,0.1322378,,,,,,,,,,, -5400,0.007893044,0.1266757,,,,,,,,,,, -5500,0.046122573,0.11977975,,,,,,,,,,, -5532,,,0.1261061483865264,0.1257231651156882,83274637.0,0.1281059535773026,95000000.0,4812.121766090393,4925.977107524872,4812.121766090393,113.57464456558228,0.0985362529754638,0.0 -5600,0.046406493,0.12362985,,,,,,,,,,, -5700,0.022495668,0.120445594,,,,,,,,,,, -5800,0.027208257,0.12976614,,,,,,,,,,, -5900,0.023332918,0.13907541,,,,,,,,,,, -6000,0.016934982,0.121976815,,,,,,,,,,, -6100,0.012603023,0.12777913,,,,,,,,,,, -6200,0.0081364205,0.13310869,,,,,,,,,,, -6300,0.006364949,0.12317997,,,,,,,,,,, -6400,0.011180502,0.12529789,,,,,,,,,,, -6500,0.0050768643,0.117635064,,,,,,,,,,, -6600,0.014012624,0.11976478,,,,,,,,,,, -6700,0.019544637,0.12646894,,,,,,,,,,, -6800,0.020542895,0.13083816,,,,,,,,,,, -6894,,,0.1229537926942297,0.1254273176556296,83274637.0,0.1278534981599506,95000000.0,6012.393161773682,6149.090483188629,6012.393161773682,136.3490800857544,0.1220607757568359,0.0 -6900,0.0129055055,0.1171488,,,,,,,,,,, -7000,0.006543059,0.125728,,,,,,,,,,, -7100,0.011063947,0.120712966,,,,,,,,,,, -7200,0.017471874,0.13324034,,,,,,,,,,, -7300,0.012629436,0.13364968,,,,,,,,,,, -7400,0.012036447,0.12341838,,,,,,,,,,, -7500,0.021022327,0.12151978,,,,,,,,,,, -7600,0.00696875,0.11521623,,,,,,,,,,, -7700,0.012665197,0.124069735,,,,,,,,,,, -7800,0.008804278,0.12816274,,,,,,,,,,, -7900,0.015571233,0.120649695,,,,,,,,,,, -8000,0.009424241,0.12470996,,,,,,,,,,, -8100,0.007274986,0.12190605,,,,,,,,,,, -8200,0.006868649,0.12314265,,,,,,,,,,, -8258,,,0.1231795465515094,0.1254108300893884,83274637.0,0.1277474768708881,95000000.0,7213.036543607712,7372.547692298889,7213.036543607712,159.09422087669373,0.1467015743255615,0.0 -8300,0.01837232,0.12830779,,,,,,,,,,, -8400,0.0067478404,0.117859244,,,,,,,,,,, -8500,0.0063683,0.118687406,,,,,,,,,,, -8600,0.013484476,0.12236837,,,,,,,,,,, -8700,0.006113844,0.13219272,,,,,,,,,,, -8746,,,,,,,,7703.177574872971,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/eval_measurements.csv deleted file mode 100644 index c5ac241e4..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.568405866622925,0.0,5.746478080749512,1,0,5.746478080749512,0.4188618532072368,95000000,28.314929246902462,0.417861022289444,0.4188120656037804,83274637 -45.48547530174256,0.0170862674713134,1205.8514742851255,1410,0,1205.8514742851255,0.1280632156044407,95000000,1251.4007487297058,0.1253754194802458,0.1256152709010007,83274637 -68.39558744430542,0.0397999286651611,2405.9904704093933,2812,0,2405.9904704093933,0.1279820143194901,95000000,2474.51877117157,0.1242157695522098,0.1256440325187938,83274637 -91.25166869163512,0.0668396949768066,3605.944534778595,4206,0,3605.944534778595,0.1269549632709704,95000000,3697.401435613632,0.1228075881236754,0.1245392818028735,83274637 -114.03592133522034,0.0925056934356689,4806.842922210693,5604,0,4806.842922210693,0.1264720616570723,95000000,4921.155634880066,0.1248605644927834,0.1240328853538713,83274637 -136.51566314697266,0.1157255172729492,6007.645996809006,7000,0,6007.645996809006,0.1262262012335526,95000000,6144.508240938187,0.1227599759146852,0.123875571689673,83274637 -159.04064083099365,0.13881611824035645,7207.703269481659,8390,0,7207.703269481659,0.12607612123766448,95000000,7367.158077716827,0.12214348680755627,0.12379352441480772,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/measurements.csv deleted file mode 100644 index 71a9512a7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/measurements.csv +++ /dev/null @@ -1,98 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.0669575,0.41919458,,,,,,,,,,, -1,,,0.417861022289444,0.4188120656037804,83274637.0,0.4188618532072368,95000000.0,5.746478080749512,28.314929246902462,5.746478080749512,22.568405866622925,0.0,0.0 -100,0.07664004,0.12763922,,,,,,,,,,, -200,0.047432136,0.120679334,,,,,,,,,,, -300,0.03430597,0.13065101,,,,,,,,,,, -400,0.012565528,0.12752873,,,,,,,,,,, -500,0.016936917,0.12090288,,,,,,,,,,, -600,0.0062515927,0.118614085,,,,,,,,,,, -700,0.01067067,0.12601523,,,,,,,,,,, -800,0.007073245,0.13762203,,,,,,,,,,, -900,0.013397502,0.12711838,,,,,,,,,,, -1000,0.014276954,0.11984275,,,,,,,,,,, -1100,0.015175631,0.12122571,,,,,,,,,,, -1200,0.017992506,0.12779507,,,,,,,,,,, -1300,0.00922378,0.120300315,,,,,,,,,,, -1400,0.025307605,0.13231573,,,,,,,,,,, -1410,,,0.1253754194802458,0.1256152709010007,83274637.0,0.1280632156044407,95000000.0,1205.8514742851255,1251.4007487297058,1205.8514742851255,45.48547530174256,0.0170862674713134,0.0 -1500,0.042079087,0.13036571,,,,,,,,,,, -1600,0.048744872,0.13570946,,,,,,,,,,, -1700,0.013249716,0.12244758,,,,,,,,,,, -1800,0.04245892,0.1256924,,,,,,,,,,, -1900,0.005226544,0.12144653,,,,,,,,,,, -2000,0.028074984,0.12698042,,,,,,,,,,, -2100,0.024626922,0.12856084,,,,,,,,,,, -2200,0.026869824,0.12158877,,,,,,,,,,, -2300,0.013248722,0.12861027,,,,,,,,,,, -2400,0.029112425,0.11430491,,,,,,,,,,, -2500,0.00493395,0.12115323,,,,,,,,,,, -2600,0.031259794,0.13439491,,,,,,,,,,, -2700,0.0059108245,0.12824579,,,,,,,,,,, -2800,0.011190518,0.12623367,,,,,,,,,,, -2812,,,0.1242157695522098,0.1256440325187938,83274637.0,0.1279820143194901,95000000.0,2405.9904704093933,2474.51877117157,2405.9904704093933,68.39558744430542,0.0397999286651611,0.0 -2900,0.033815786,0.121047564,,,,,,,,,,, -3000,0.015802022,0.12685007,,,,,,,,,,, -3100,0.036752034,0.12368798,,,,,,,,,,, -3200,0.021941306,0.13120499,,,,,,,,,,, -3300,0.026730595,0.12963589,,,,,,,,,,, -3400,0.014540445,0.120987654,,,,,,,,,,, -3500,0.021978192,0.11902094,,,,,,,,,,, -3600,0.016324347,0.12627591,,,,,,,,,,, -3700,0.016429713,0.12479744,,,,,,,,,,, -3800,0.006770631,0.11865483,,,,,,,,,,, -3900,0.012751977,0.12723103,,,,,,,,,,, -4000,0.013068935,0.1231549,,,,,,,,,,, -4100,0.017040921,0.11631675,,,,,,,,,,, -4200,0.0050648097,0.12448146,,,,,,,,,,, -4206,,,0.1228075881236754,0.1245392818028735,83274637.0,0.1269549632709704,95000000.0,3605.944534778595,3697.401435613632,3605.944534778595,91.25166869163512,0.0668396949768066,0.0 -4300,0.008664669,0.123208836,,,,,,,,,,, -4400,0.010758286,0.119159244,,,,,,,,,,, -4500,0.006242982,0.123159505,,,,,,,,,,, -4600,0.0071120164,0.12277117,,,,,,,,,,, -4700,0.016329171,0.11561136,,,,,,,,,,, -4800,0.006993725,0.11974328,,,,,,,,,,, -4900,0.0044575604,0.11963337,,,,,,,,,,, -5000,0.014737965,0.11955218,,,,,,,,,,, -5100,0.005813671,0.12645037,,,,,,,,,,, -5200,0.012102163,0.11881255,,,,,,,,,,, -5300,0.015433091,0.1275289,,,,,,,,,,, -5400,0.009806295,0.11798766,,,,,,,,,,, -5500,0.012738183,0.11751615,,,,,,,,,,, -5600,0.0054993206,0.12925577,,,,,,,,,,, -5604,,,0.1248605644927834,0.1240328853538713,83274637.0,0.1264720616570723,95000000.0,4806.842922210693,4921.155634880066,4806.842922210693,114.03592133522034,0.0925056934356689,0.0 -5700,0.006458926,0.123172075,,,,,,,,,,, -5800,0.011890605,0.12125274,,,,,,,,,,, -5900,0.005743629,0.11664085,,,,,,,,,,, -6000,0.006080791,0.12789805,,,,,,,,,,, -6100,0.005373599,0.115991,,,,,,,,,,, -6200,0.0052005295,0.12472001,,,,,,,,,,, -6300,0.008708165,0.12155621,,,,,,,,,,, -6400,0.008167361,0.13436106,,,,,,,,,,, -6500,0.009830714,0.1256759,,,,,,,,,,, -6600,0.0059293658,0.12090233,,,,,,,,,,, -6700,0.0060230144,0.12675011,,,,,,,,,,, -6800,0.013305511,0.121538,,,,,,,,,,, -6900,0.0065931445,0.1354225,,,,,,,,,,, -7000,,,0.1227599759146852,0.123875571689673,83274637.0,0.1262262012335526,95000000.0,6007.645996809006,6144.508240938187,6007.645996809006,136.51566314697266,0.1157255172729492,0.0 -7000,0.0059054815,0.11997863,,,,,,,,,,, -7100,0.0054848054,0.11991491,,,,,,,,,,, -7200,0.0051424815,0.12487902,,,,,,,,,,, -7300,0.0058279173,0.11572364,,,,,,,,,,, -7400,0.0066281916,0.12409259,,,,,,,,,,, -7500,0.0057948115,0.123314135,,,,,,,,,,, -7600,0.0053046537,0.12373963,,,,,,,,,,, -7700,0.005484649,0.12561639,,,,,,,,,,, -7800,0.0051645855,0.11625183,,,,,,,,,,, -7900,0.008029766,0.11489172,,,,,,,,,,, -8000,0.0070482865,0.11984636,,,,,,,,,,, -8100,0.008052788,0.123668574,,,,,,,,,,, -8200,0.0080422005,0.12877256,,,,,,,,,,, -8300,0.005819268,0.121866815,,,,,,,,,,, -8390,,,0.1221434868075562,0.1237935244148077,83274637.0,0.1260761212376644,95000000.0,7207.703269481659,7367.158077716827,7207.703269481659,159.04064083099365,0.1388161182403564,0.0 -8400,0.0062660663,0.11729492,,,,,,,,,,, -8500,0.006297216,0.1171326,,,,,,,,,,, -8600,0.009818638,0.12312515,,,,,,,,,,, -8700,0.0074722576,0.12388565,,,,,,,,,,, -8800,0.007586937,0.13063224,,,,,,,,,,, -8900,,,,,,,,7703.252153635025,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index ba272da3c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -210.51775097846985,0.0,57.71510553359985,1,0,57.71510553359985,0.8250461419645351,3581,0.2972505852284627,268.233216047287,0.8154546192714146,0.2794570582253592,0.8275999671365011,3554,0.2738082104341059 -214.952023267746,0.0298702716827392,137.94220185279846,335,0,137.94220185279846,0.3449955194298031,3581,0.6816368071069534,352.9363169670105,0.3210630076272147,0.6868676458086286,0.3440120355695167,3554,0.6621346546145188 -218.9958727359772,0.0759840011596679,217.9908742904663,571,0,217.9908742904663,0.3222129244406241,3581,0.7039003049951131,437.0828356742859,0.2987121854509626,0.7094067164829799,0.3204514004906443,3554,0.6854377192731781 -223.04977464675903,0.1100139617919921,298.2785224914551,810,0,298.2785224914551,0.3111053441230627,3581,0.7163270014486177,521.4665069580078,0.2876839297158377,0.7223307064601353,0.3089031504994372,3554,0.6987798601707583 -227.0959722995758,0.1418704986572265,378.4060695171356,1076,0,378.4060695171356,0.3054533625274888,3581,0.7218287216821418,605.6814365386963,0.2821281467165266,0.7281755038670131,0.3031973079408764,3554,0.7045336515633793 -231.14115858078003,0.1719951629638672,458.46973156929016,1422,0,458.46973156929016,0.3009750762487783,3581,0.7276111252312553,689.8326303958893,0.2774519920349121,0.7344260896955218,0.2988459855972144,3554,0.7104694149637732 -235.19049859046936,0.1966481208801269,538.5077736377716,1769,0,538.5077736377716,0.2979088308956995,3581,0.7299332905002094,773.9570178985596,0.2749485969543457,0.7361435890197754,0.2960623084794422,3554,0.7127721268069429 -239.24358010292053,0.2220118045806884,618.6638140678406,2112,0,618.6638140678406,0.297212610827981,3581,0.7305576523666574,858.2030894756317,0.2739390305110386,0.7374092510768345,0.2953523153928672,3554,0.7133519779649691 -243.29773139953613,0.2474856376647949,698.8899874687195,2459,0,698.8899874687195,0.2946925627988865,3581,0.7334981799558433,942.5207052230836,0.2711841038295201,0.740647724696568,0.2929722877721405,3554,0.7163059146604178 -247.3470792770385,0.2752392292022705,778.935720205307,2807,0,778.935720205307,0.2933455624083705,3581,0.734273553127618,1026.655560255051,0.2702146598270961,0.7409558977399554,0.2917060744715462,3554,0.7170057065981992 -251.3970603942871,0.304495096206665,859.1168282032013,3150,0,859.1168282032013,0.292795649456332,3581,0.7359791968462022,1110.9284377098083,0.2697741644723074,0.7425024850027901,0.2911750995247432,3554,0.7186918843644485 -255.4488706588745,0.3315091133117676,939.1118621826172,3493,0,939.1118621826172,0.2918500050723436,3581,0.7366152169217747,1195.0142102241516,0.2683522530964443,0.743992533002581,0.2901963731446961,3554,0.7194040413310706 -259.5003592967987,0.3605566024780273,1019.31481051445,3841,0,1019.31481051445,0.2915548001278448,3581,0.7365999453495881,1279.309923171997,0.2686336721692766,0.7430756432669503,0.2900099359876196,3554,0.719275445031127 -263.5515856742859,0.3859472274780273,1099.3168652057648,4188,0,1099.3168652057648,0.2911928843200223,3581,0.7377528126963487,1363.4006741046906,0.2682199818747384,0.7445659637451172,0.2897664136140616,3554,0.7204679833682118 -267.60273838043213,0.4113495349884033,1179.4377205371857,4534,0,1179.4377205371857,0.2905081179314437,3581,0.7376765230120776,1447.6101825237274,0.2670318399156843,0.7451190267290387,0.2889700370621131,3554,0.7203840385613042 -271.65371799468994,0.4404506683349609,1259.612999200821,4880,0,1259.612999200821,0.2904610078583147,3581,0.7378385107599135,1531.8774890899658,0.2672845636095319,0.7446834700448173,0.2889450665733152,3554,0.7206724872063168 -275.7093436717987,0.4694664478302002,1339.6272230148315,5225,0,1339.6272230148315,0.2902309798022549,3581,0.7388448664610094,1615.988268136978,0.2669649294444493,0.745838097163609,0.2887557785901449,3554,0.7215957426930923 -279.7575418949127,0.4949519634246826,1419.6141102313995,5566,0,1419.6141102313995,0.2897936946950921,3581,0.7388160959098367,1700.0605845451355,0.2661159038543701,0.7464524677821568,0.2883142440186234,3554,0.7216524844365504 -283.80796813964844,0.5214133262634277,1499.6323685646057,5911,0,1499.6323685646057,0.2897185640140673,3581,0.7389985366561366,1784.1678936481476,0.2663944959640503,0.7461219515119281,0.2883554264341235,3554,0.7216612086513435 -287.8547184467316,0.5506565570831299,1579.788170337677,6253,0,1579.788170337677,0.2895057505650482,3581,0.7397630697387252,1868.4111258983608,0.2661893367767334,0.7469372068132673,0.2880133788238165,3554,0.7225887232036086 -291.9043619632721,0.5786528587341309,1659.8735864162445,6597,0,1659.8735864162445,0.2904327486321034,3581,0.7394097101019268,1952.585978269577,0.2666650669915335,0.7465224947248187,0.2888153024650921,3554,0.722412315458814 -295.9577133655548,0.6068413257598877,1740.0795395374298,6941,0,1740.0795395374298,0.2935818968056234,3581,0.7301608641964535,2036.8851308822632,0.2708651849201747,0.7345943450927734,0.2920314808138717,3554,0.7138045380205402 -300.0156321525574,0.6333053112030029,1820.227460861206,7285,0,1820.227460861206,0.2889406342174846,3581,0.7397586382557246,2121.1292040348053,0.265736630984715,0.7467425210135323,0.2875278797053056,3554,0.7224672024479459 -304.0689377784729,0.6600399017333984,1900.2204134464264,7626,0,1900.2204134464264,0.2898893124629119,3581,0.7389341097109746,2205.213802576065,0.2662156649998256,0.7467371395656041,0.2884126833871166,3554,0.7217820424609594 -308.1188626289368,0.6859474182128906,1980.26410484314,7972,0,1980.26410484314,0.2898612577666853,3581,0.7401055211096761,2289.345111608505,0.2661494357245309,0.7473570278712681,0.2882626200232133,3554,0.7229981430474466 -312.1712157726288,0.7163846492767334,2060.3492851257324,8317,0,2060.3492851257324,0.2888488002543807,3581,0.7399076724378665,2373.52497792244,0.2657549721854074,0.7467418398175921,0.2874924332892955,3554,0.7227271428320202 -316.221538066864,0.7425732612609863,2140.48207116127,8661,0,2140.48207116127,0.2888023719478672,3581,0.7406642288510542,2457.746108531952,0.2654073068073818,0.7479797771998814,0.2872930471992385,3554,0.7234581908149268 -320.27386260032654,0.7702662944793701,2220.520433425904,9006,0,2220.520433425904,0.2888830249384774,3581,0.7397593881990017,2541.876602172852,0.2653544800622122,0.7469651358468192,0.2874644745851945,3554,0.7226491057611142 -324.3219790458679,0.7968940734863281,2300.557582139969,9351,0,2300.557582139969,0.2882040535661302,3581,0.7412121646799078,2626.0004420280457,0.2648039885929653,0.7483945574079242,0.2868160146907533,3554,0.7240381793621623 -328.377671957016,0.8233621120452881,2380.7025940418243,9695,0,2380.7025940418243,0.2885059398234606,3581,0.7412599565196524,2710.239602804184,0.2649398190634591,0.7488522529602051,0.2871628022287827,3554,0.723960004902047 -332.429728269577,0.8554947376251221,2460.9026103019714,10041,0,2460.9026103019714,0.2881796463212964,3581,0.7420153539295937,2794.535893678665,0.2646245104925973,0.7495920317513602,0.2867716551497872,3554,0.7249149286015757 -336.4816882610321,0.8848049640655518,2540.861899137497,10385,0,2540.861899137497,0.2881357405512601,3581,0.7409942720696034,2878.588441133499,0.2647161313465663,0.7483953067234584,0.2867452592479425,3554,0.7237816737083216 -340.54004883766174,0.9131312370300292,2621.0233914852142,10727,0,2621.0233914852142,0.2881166169976962,3581,0.7407105208042446,2962.848497867584,0.2646113463810512,0.7482667650495257,0.2867335296441949,3554,0.72345544303074 -344.5905284881592,0.939763069152832,2701.2316172122955,11073,0,2701.2316172122955,0.2880745519975216,3581,0.7407743341594527,3047.145696401596,0.2641687563487461,0.748828615461077,0.2866743664159222,3554,0.7235340296584833 -348.6447901725769,0.966834306716919,2781.212673664093,11420,0,2781.212673664093,0.2879572199629991,3581,0.7405599867355487,3131.220014810562,0.2644648722239903,0.7480402673993792,0.286574295550568,3554,0.7233010862540448 -352.6982641220093,0.994499444961548,2861.333312511444,11765,0,2861.333312511444,0.2878505916643396,3581,0.7415539342842432,3215.433751344681,0.2642995119094848,0.7490818841116769,0.2864849238698913,3554,0.72429846321926 -356.7520639896393,1.0221521854400637,2941.3339407444,12108,0,2941.3339407444,0.2886686775145734,3581,0.7395374049890044,3299.527735233307,0.2648725509643554,0.7470649310520717,0.2872223261037299,3554,0.7225018932233047 -360.8086497783661,1.0507407188415527,3021.476846933365,12453,0,3021.476846933365,0.2878965086458915,3581,0.7415426169584264,3383.7677912712097,0.2643582820892334,0.7491067477634975,0.2864743964217255,3554,0.7243488163644837 -364.86098647117615,1.0791373252868652,3101.4897875785828,12798,0,3101.4897875785828,0.2876597651886693,3581,0.7412870226542865,3467.873590707779,0.2641389199665614,0.7487072263445173,0.2862849538756946,3554,0.7240591312165869 -368.9163925647736,1.107081413269043,3181.5577688217163,13140,0,3181.5577688217163,0.2878337520289374,3581,0.7415685240898143,3552.0366492271423,0.2637655564716884,0.7497420992170062,0.2864398945565296,3554,0.7244437523081387 -372.96993708610535,1.135343074798584,3261.6042971611023,13487,0,3261.6042971611023,0.2875157079028204,3581,0.7415469120881039,3636.176833629608,0.2638783114297049,0.749126638684954,0.286114694298018,3554,0.7242455683736635 -377.0191743373871,1.1649174690246582,3341.729484319687,13833,0,3341.729484319687,0.2879423233624511,3581,0.7416033623638648,3720.393041372299,0.2641657420567104,0.7492619241986956,0.2865796365560811,3554,0.7243551362681134 -381.06914806365967,1.1924426555633545,3421.8819210529327,14177,0,3421.8819210529327,0.2877021028954901,3581,0.7419544039941707,3804.6349902153015,0.2635764905384609,0.7501425061907087,0.2862090291638822,3554,0.7248468522483469 -385.1245248317719,1.2210686206817627,3501.987695455551,14523,0,3501.987695455551,0.2876874108249267,3581,0.7405440333967467,3888.836674690247,0.2640088626316615,0.7482326371329171,0.2863345685539181,3554,0.7233014984216728 -389.177859544754,1.2528254985809326,3581.99942946434,14867,0,3581.99942946434,0.2877675524905753,3581,0.741584068368647,3972.9454820156097,0.2641704423086984,0.7490936688014439,0.2864648993926297,3554,0.724210396736072 -393.2313735485077,1.2852871417999268,3662.1490709781647,15209,0,3662.1490709781647,0.2880258397728463,3581,0.7425437230565833,4057.1929366588593,0.2643688576562064,0.7499757494245257,0.2867057426766056,3554,0.7254071941386466 -397.2786636352539,1.3142294883728027,3742.141808271408,15555,0,3742.141808271408,0.2875579433446488,3581,0.7413624942186191,4141.27409529686,0.2637040104184832,0.7491991860525948,0.2861622309644502,3554,0.7240965697761326 -401.3295774459839,1.3458054065704346,3822.2436985969534,15900,0,3822.2436985969534,0.2874640981700293,3581,0.7411458969648841,4225.470196008682,0.2637172596795218,0.7490659441266742,0.2860872336298009,3554,0.7240223109084833 -405.3828492164612,1.3787052631378174,3902.2939944267273,16242,0,3902.2939944267273,0.2879943080668807,3581,0.7415718647462302,4309.618505954742,0.2642509256090437,0.7492741176060268,0.2865908166029913,3554,0.7244184726936199 -409.4357068538666,1.4074318408966064,3982.34534406662,16586,0,3982.34534406662,0.2875285932918528,3581,0.7408519873769548,4393.763057470322,0.2635989189147949,0.7487504822867257,0.2861476848819112,3554,0.7237115365169527 -413.48853278160095,1.4365577697753906,4062.446593284607,16931,0,4062.446593284607,0.2871454404539584,3581,0.7420229897156869,4477.958353757858,0.2632872888020107,0.7499645096915108,0.2857487753125879,3554,0.7248403262609032 -417.5439202785492,1.4648287296295166,4142.558726072311,17274,0,4142.558726072311,0.2872500234527715,3581,0.7423415110784348,4562.1660623550415,0.2633872883660452,0.750265257699149,0.2858453255794527,3554,0.725125202786473 -421.599889755249,1.4937596321105957,4222.760106086731,17619,0,4222.760106086731,0.2873178592310109,3581,0.7431117710005934,4646.464495420456,0.2629937955311366,0.7514595985412598,0.2859021875384689,3554,0.7259761228545301 -425.6540808677673,1.5217459201812744,4302.788812160492,17965,0,4302.788812160492,0.2870001559882016,3581,0.7422263606970818,4730.587537527084,0.2631190163748605,0.7501124654497419,0.2856347937897879,3554,0.7249581375079136 -429.7097146511078,1.550818681716919,4382.871379613876,18309,0,4382.871379613876,0.2876765366474099,3581,0.7415111193407917,4814.766731977463,0.2637862648282732,0.7495180538722447,0.2862694117213878,3554,0.7243303375158272 -433.7644882202149,1.580070734024048,4463.077668428421,18655,0,4463.077668428421,0.2869414217942788,3581,0.7428010217772619,4899.068983078003,0.2626204831259591,0.7512383460998535,0.2855628018440929,3554,0.725613896204101 -437.8175001144409,1.6196582317352295,4543.044121026993,18999,0,4543.044121026993,0.2871799719330319,3581,0.7423001278448758,4983.139835596085,0.2631433010101318,0.7502283368791852,0.2858119915225362,3554,0.7251444372757808 -441.870076417923,1.6532104015350342,4623.194087982178,19346,0,4623.194087982178,0.2870098370741413,3581,0.7424429579508168,5067.387616157532,0.2632224900381906,0.7503267696925572,0.2856542858838632,3554,0.7251956147562606 -445.9244599342346,1.6831650733947754,4703.263406991959,19688,0,4703.263406991959,0.2870990462357756,3581,0.7424449350740017,5151.553251743317,0.2627083233424595,0.750896794455392,0.2856937337605955,3554,0.7252511199968346 -449.9741785526276,1.7135276794433594,4783.346960544586,20033,0,4783.346960544586,0.2871399863210346,3581,0.7425795839805571,5235.728672742844,0.2630380732672555,0.7507200922284808,0.2857979091285787,3554,0.7253897457090602 -454.02555441856384,1.7433743476867676,4863.308755397797,20379,0,4863.308755397797,0.2867996143382609,3581,0.7427057108044192,5319.78363609314,0.2628463676997593,0.7507028579711914,0.2854411952201744,3554,0.725503091806767 -458.0803413391113,1.773087501525879,4943.37055516243,20722,0,4943.37055516243,0.2879258246103567,3581,0.742558381038816,5403.941605091095,0.2634522574288504,0.7512781279427665,0.286552278929771,3554,0.7255425225098481 -462.1315383911133,1.8046362400054927,5023.388708114624,21068,0,5023.388708114624,0.2867335170648911,3581,0.7424042336070581,5488.054424762726,0.2626910550253732,0.7504569462367466,0.2853945172363006,3554,0.7252177344189645 -466.18581223487854,1.8341941833496087,5103.514678239822,21413,0,5103.514678239822,0.2866953722227555,3581,0.7430837503926976,5572.2762241363525,0.2626854521887643,0.7511151858738491,0.2853362985588421,3554,0.7259615595983399 -470.2389633655548,1.8642263412475584,5183.712336778641,21757,0,5183.712336778641,0.2870615490719247,3581,0.7417080817160011,5656.568719625473,0.2627014432634626,0.7500740459987095,0.2856341583646947,3554,0.7243928496060776 -474.2914884090424,1.8956067562103271,5263.737190961838,22102,0,5263.737190961838,0.2870317558708287,3581,0.743197332710835,5740.689452886581,0.2628119502748762,0.7515127999441964,0.2855904857697752,3554,0.7261375551755065 -478.3473429679871,1.929384469985962,5343.85110616684,22449,0,5343.85110616684,0.2867978758333915,3581,0.7433537981490854,5824.905149459839,0.2626760687146868,0.7513835770743233,0.2854385676515458,3554,0.7262300181133934 -482.4012656211853,1.959122896194458,5423.849596738815,22791,0,5423.849596738815,0.2869999173698862,3581,0.7425605626919854,5908.999294519424,0.2628911222730364,0.7507671628679548,0.285624558293692,3554,0.7254199713351154 -486.4510595798493,1.9885139465332031,5504.067927360535,23137,0,5504.067927360535,0.2869907816972389,3581,0.7428933329769967,5993.308559894562,0.2626312630517142,0.7513228143964495,0.2855588003833708,3554,0.7257843962128939 -490.5085999965668,2.0182549953460693,5584.158223390579,23484,0,5584.158223390579,0.286600777104859,3581,0.7433634110583636,6077.498130559921,0.2624009847640991,0.7516647747584752,0.2852267650116946,3554,0.7261988994574775 -494.56139755249023,2.053964614868164,5664.2836174964905,23829,0,5664.2836174964905,0.2867686621339186,3581,0.7430474804087546,6161.723926067352,0.2625121729714529,0.7513493810381208,0.2853360753013769,3554,0.7258792634619443 -498.6148250102997,2.0857253074646,5744.359104633331,24175,0,5744.359104633331,0.286753390561732,3581,0.7425226564681653,6245.896435976028,0.2620893716812134,0.7513031278337751,0.2852982074005522,3554,0.7253518262872819 -502.6727757453919,2.1174325942993164,5824.3427176475525,24520,0,5824.3427176475525,0.2865376455162664,3581,0.7431070668109466,6329.981694221497,0.2621761390141078,0.7515650476728167,0.2851459458092993,3554,0.7259875948535102 -506.7254252433777,2.1474146842956543,5904.422456741333,24863,0,5904.422456741333,0.2865232261523492,3581,0.7429354661538328,6414.155948877335,0.2622285911015102,0.7513407298496791,0.2851708647771437,3554,0.7257067026150112 -510.77738857269287,2.179636716842652,5984.5337653160095,25206,0,5984.5337653160095,0.2865558145965687,3581,0.7434358828495881,6498.363244771957,0.2616928815841675,0.7524091175624302,0.2851238261465953,3554,0.7262610680747046 -514.8336462974548,2.210814952850342,6064.5873601436615,25552,0,6064.5873601436615,0.2865705066671321,3581,0.7429550328556968,6582.516287326813,0.2620738404137747,0.7515433175223214,0.2851855654225432,3554,0.7256952993106359 -518.8966798782349,2.243125438690185,6144.693779706955,25900,0,6144.693779706955,0.2865279644303267,3581,0.7424625246526808,6666.730274915695,0.26214143208095,0.7509380068097796,0.2851501533538355,3554,0.7252735831325618 -522.9531097412109,2.274085521697998,6224.853043317795,26246,0,6224.853043317795,0.2866125716673066,3581,0.7433814778736736,6750.988937854767,0.2616041558129446,0.7524256025041852,0.285166794621817,3554,0.7262390858012099 -527.0004780292511,2.305572271347046,6304.8623919487,26590,0,6304.8623919487,0.2864727413344736,3581,0.7433889773064437,6835.089107990265,0.2618626015526907,0.752068178994315,0.2850863875870498,3554,0.7262528934167487 -531.0505583286285,2.336199998855591,6384.856774330139,26937,0,6384.856774330139,0.2864647646650726,3581,0.7432566464063809,6919.176200866699,0.2620003904615129,0.751748970576695,0.2851182103626635,3554,0.7261211371649902 -535.1086058616638,2.369391679763794,6465.06779050827,27281,0,6465.06779050827,0.2865492355487294,3581,0.7433453442430537,7003.490379571915,0.2614716802324567,0.7525568008422852,0.2850966059094946,3554,0.7262717157384285 -539.1609346866608,2.4011919498443604,6545.176929950714,27625,0,6545.176929950714,0.2864156774687587,3581,0.7431984235374197,7087.695506572723,0.2617475816181728,0.751924855368478,0.2850241674488692,3554,0.726097574915588 -543.2160265445709,2.432178735733032,6625.335506677628,27971,0,6625.335506677628,0.28652326024068,3581,0.7433388674602066,7171.952100038528,0.2618813855307443,0.7521243095397949,0.2851246848291537,3554,0.7263136881418824 -547.2668516635895,2.463989496231079,6705.314663171768,28311,0,6705.314663171768,0.2869765668633063,3581,0.7433338223872522,7256.025635004044,0.2617536953517368,0.7526357514517648,0.2855354957387362,3554,0.7263437076841235 -551.3233077526093,2.4959301948547363,6785.449561357498,28660,0,6785.449561357498,0.2864036442879957,3581,0.7434398370959578,7340.261151790619,0.2616182225091116,0.7522928374154227,0.2850380265853615,3554,0.7263367008344471 -555.37380027771,2.5343716144561768,6865.401733875275,29003,0,6865.401733875275,0.2864864789317753,3581,0.7432160812927604,7424.313965082169,0.2617198399135044,0.7521061216081891,0.2851091598484981,3554,0.7261561027187676 -559.4301190376282,2.566793918609619,6945.3704397678375,29346,0,6945.3704397678375,0.286642160338418,3581,0.7429957343226403,7508.383171081543,0.2616387775966099,0.7521747861589704,0.2852646500861705,3554,0.7258670358223129 -563.4838757514954,2.602396249771118,7025.511574745178,29694,0,7025.511574745178,0.2864821838020979,3581,0.743683500484327,7592.625883340836,0.2614959818976266,0.7528272356305804,0.2850999032505187,3554,0.7266219208330402 -567.5378227233887,2.6340932846069336,7105.662876844406,30036,0,7105.662876844406,0.2864304036276529,3581,0.7435130588304594,7676.874650716782,0.2615717138562883,0.7525065967014858,0.285084979347654,3554,0.7264429713878728 -571.5938925743103,2.66501784324646,7185.761651039124,30379,0,7185.761651039124,0.2867632420893256,3581,0.7427278000427604,7761.072353124618,0.2618185792650495,0.7517011506216866,0.2853469118752638,3554,0.7255941808525604 -575.6503157615662,2.6965012550354004,7265.859354972839,30724,0,7265.859354972839,0.2864934329512531,3581,0.7432732133351369,7845.270067453384,0.2612915209361485,0.7526508740016392,0.285073198222953,3554,0.7261841301174733 -579.7019543647766,2.7283411026000977,7345.888344764709,31071,0,7345.888344764709,0.2864299604793528,3581,0.7436164146493647,7929.394406795502,0.2614444323948451,0.752741881779262,0.2850723395403946,3554,0.7265454637380416 -583.755449295044,2.7600905895233154,7426.11357998848,31415,0,7426.11357998848,0.2864075985343654,3581,0.7435497378743717,8013.716660499573,0.2614004441670009,0.7526612281799316,0.2850400359025481,3554,0.7264517642972707 -587.8157875537872,2.792599678039551,7506.210152387619,31761,0,7506.210152387619,0.2863236730640009,3581,0.7436201643657497,8097.917795181274,0.2610013655253819,0.753065654209682,0.2849258998168876,3554,0.7265352282419457 -591.8755309581757,2.8241755962371826,7586.232357978821,32107,0,7586.232357978821,0.2863925996688251,3581,0.7436131421696105,8182.043341875076,0.2612640857696533,0.7528303691319057,0.2850118195936796,3554,0.7265279466138506 -595.9300870895386,2.857153415679932,7666.246356487274,32449,0,7666.246356487274,0.2863342745348715,3581,0.7437789478104929,8266.15687918663,0.2612173386982509,0.7530434472220284,0.2849734708306222,3554,0.726705865639948 -599.9857411384583,2.8888802528381348,7746.339337348938,32793,0,7746.339337348938,0.2862743131610409,3581,0.7436942042201898,8350.34925699234,0.2608901262283325,0.7532211031232562,0.2848969965619724,3554,0.7266475439205824 -604.0404622554779,2.9211058616638184,7826.380793333054,33140,0,7826.380793333054,0.2862382136187517,3581,0.7436322998115051,8434.489802360535,0.2610472440719604,0.7529222624642509,0.2849144965125123,3554,0.7264977896824001 -608.09494972229,2.9546470642089844,7906.548151254654,33484,0,7906.548151254654,0.2863115376182456,3581,0.7435476925745252,8518.75707435608,0.2611274208341326,0.7528365680149623,0.2849736940880873,3554,0.7264392618792206 -612.147744178772,2.9883530139923096,7986.679588794708,33829,0,7986.679588794708,0.2862855282218654,3581,0.7438938254851997,8602.986902713776,0.2608636447361537,0.7534403119768415,0.2848896977602261,3554,0.7268620084763646 -616.2017261981964,3.026043176651001,8066.71623635292,34175,0,8066.71623635292,0.2861680598340198,3581,0.7438721453068277,8687.127203702927,0.2609024899346487,0.7532878603254046,0.2848174653834148,3554,0.7268059536789533 -620.2588531970978,3.0591280460357666,8146.856377363205,34521,0,8146.856377363205,0.2862098180392174,3581,0.7437584266353672,8771.369478464127,0.2609563895634242,0.7531764847891671,0.2848718199893605,3554,0.7267027056881331 -624.3104538917542,3.092402696609497,8226.93552994728,34863,0,8226.93552994728,0.2861729344653204,3581,0.7436710241552639,8855.545527935028,0.2608116865158081,0.7531660624912807,0.2848065772885745,3554,0.7266088001635481 -628.3667876720428,3.125974416732788,8306.98347043991,35207,0,8306.98347043991,0.2860854297202247,3581,0.7435840989117914,8939.69522190094,0.260757531438555,0.7530843189784459,0.2847236285534345,3554,0.7264942175629572 -632.4175064563751,3.1593573093414307,8386.943142175674,35554,0,8386.943142175674,0.2861068031036198,3581,0.7436912726237433,9023.751316070557,0.260784523827689,0.7531565257481166,0.2847584567180026,3554,0.7266150513725732 -636.4669382572174,3.19321870803833,8467.084789514542,35897,0,8467.084789514542,0.2861113709399434,3581,0.7436200280124267,9107.988319396973,0.2607801982334682,0.7530922208513532,0.2847584910653049,3554,0.7265309691764561 -640.5183072090149,3.226508617401123,8534.14320230484,36189,0,8534.14320230484,0.2861253471555606,3581,0.7438771222031206,9179.14162349701,0.2607856137411935,0.7533482824053083,0.28476645963944675,3554,0.7268125483610017 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 1f85576c8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.989067,0.812358,,,,,,,,,,,,,, -1,,,0.2794570582253592,0.8154546192714146,0.2738082104341059,0.8275999671365011,3554.0,0.2972505852284627,0.8250461419645351,3581.0,57.71510553359985,268.233216047287,57.71510553359985,210.51775097846985,0.0,0.0 -100,0.58920413,0.34720257,,,,,,,,,,,,,, -200,0.1586968,0.39448,,,,,,,,,,,,,, -300,0.095219925,0.35154366,,,,,,,,,,,,,, -335,,,0.6868676458086286,0.3210630076272147,0.6621346546145188,0.3440120355695167,3554.0,0.6816368071069534,0.3449955194298031,3581.0,137.94220185279846,352.9363169670105,137.94220185279846,214.952023267746,0.0298702716827392,0.0 -400,0.12559126,0.33677927,,,,,,,,,,,,,, -500,0.08939382,0.23391655,,,,,,,,,,,,,, -571,,,0.7094067164829799,0.2987121854509626,0.6854377192731781,0.3204514004906443,3554.0,0.7039003049951131,0.3222129244406241,3581.0,217.9908742904663,437.0828356742859,217.9908742904663,218.9958727359772,0.0759840011596679,0.0 -600,0.09621504,0.41580975,,,,,,,,,,,,,, -700,0.100117736,0.31280568,,,,,,,,,,,,,, -800,0.15714906,0.21720444,,,,,,,,,,,,,, -810,,,0.7223307064601353,0.2876839297158377,0.6987798601707583,0.3089031504994372,3554.0,0.7163270014486177,0.3111053441230627,3581.0,298.2785224914551,521.4665069580078,298.2785224914551,223.04977464675903,0.1100139617919921,0.0 -900,0.17126283,0.32293475,,,,,,,,,,,,,, -1000,0.11933245,0.28629637,,,,,,,,,,,,,, -1076,,,0.7281755038670131,0.2821281467165266,0.7045336515633793,0.3031973079408764,3554.0,0.7218287216821418,0.3054533625274888,3581.0,378.4060695171356,605.6814365386963,378.4060695171356,227.0959722995758,0.1418704986572265,0.0 -1100,0.2663192,0.31351754,,,,,,,,,,,,,, -1200,0.09060602,0.2242957,,,,,,,,,,,,,, -1300,0.12533529,0.4012684,,,,,,,,,,,,,, -1400,0.0901713,0.22473328,,,,,,,,,,,,,, -1422,,,0.7344260896955218,0.2774519920349121,0.7104694149637732,0.2988459855972144,3554.0,0.7276111252312553,0.3009750762487783,3581.0,458.46973156929016,689.8326303958893,458.46973156929016,231.14115858078003,0.1719951629638672,0.0 -1500,0.057339136,0.28885803,,,,,,,,,,,,,, -1600,0.20481823,0.22553092,,,,,,,,,,,,,, -1700,0.14316607,0.40114397,,,,,,,,,,,,,, -1769,,,0.7361435890197754,0.2749485969543457,0.7127721268069429,0.2960623084794422,3554.0,0.7299332905002094,0.2979088308956995,3581.0,538.5077736377716,773.9570178985596,538.5077736377716,235.19049859046936,0.1966481208801269,0.0 -1800,0.3161283,0.25857943,,,,,,,,,,,,,, -1900,0.086517304,0.2271758,,,,,,,,,,,,,, -2000,0.074893855,0.2096474,,,,,,,,,,,,,, -2100,0.083265916,0.227808,,,,,,,,,,,,,, -2112,,,0.7374092510768345,0.2739390305110386,0.7133519779649691,0.2953523153928672,3554.0,0.7305576523666574,0.297212610827981,3581.0,618.6638140678406,858.2030894756317,618.6638140678406,239.24358010292053,0.2220118045806884,0.0 -2200,0.0932411,0.27263016,,,,,,,,,,,,,, -2300,0.077058524,0.28949225,,,,,,,,,,,,,, -2400,0.05096764,0.26130182,,,,,,,,,,,,,, -2459,,,0.740647724696568,0.2711841038295201,0.7163059146604178,0.2929722877721405,3554.0,0.7334981799558433,0.2946925627988865,3581.0,698.8899874687195,942.5207052230836,698.8899874687195,243.29773139953613,0.2474856376647949,0.0 -2500,0.09897017,0.2095439,,,,,,,,,,,,,, -2600,0.18400626,0.29955155,,,,,,,,,,,,,, -2700,0.056811146,0.28681016,,,,,,,,,,,,,, -2800,0.46485692,0.33454633,,,,,,,,,,,,,, -2807,,,0.7409558977399554,0.2702146598270961,0.7170057065981992,0.2917060744715462,3554.0,0.734273553127618,0.2933455624083705,3581.0,778.935720205307,1026.655560255051,778.935720205307,247.3470792770385,0.2752392292022705,0.0 -2900,0.24122646,0.23588428,,,,,,,,,,,,,, -3000,0.05523843,0.33331266,,,,,,,,,,,,,, -3100,0.1248921,0.19625005,,,,,,,,,,,,,, -3150,,,0.7425024850027901,0.2697741644723074,0.7186918843644485,0.2911750995247432,3554.0,0.7359791968462022,0.292795649456332,3581.0,859.1168282032013,1110.9284377098083,859.1168282032013,251.3970603942871,0.304495096206665,0.0 -3200,0.086714275,0.39221695,,,,,,,,,,,,,, -3300,0.22106212,0.21160299,,,,,,,,,,,,,, -3400,0.04001222,0.3179382,,,,,,,,,,,,,, -3493,,,0.743992533002581,0.2683522530964443,0.7194040413310706,0.2901963731446961,3554.0,0.7366152169217747,0.2918500050723436,3581.0,939.1118621826172,1195.0142102241516,939.1118621826172,255.4488706588745,0.3315091133117676,0.0 -3500,0.21327014,0.31735155,,,,,,,,,,,,,, -3600,0.059436668,0.30547073,,,,,,,,,,,,,, -3700,0.28325766,0.33037904,,,,,,,,,,,,,, -3800,0.06258707,0.23522341,,,,,,,,,,,,,, -3841,,,0.7430756432669503,0.2686336721692766,0.719275445031127,0.2900099359876196,3554.0,0.7365999453495881,0.2915548001278448,3581.0,1019.31481051445,1279.309923171997,1019.31481051445,259.5003592967987,0.3605566024780273,0.0 -3900,0.12143822,0.28612682,,,,,,,,,,,,,, -4000,0.2313267,0.20248728,,,,,,,,,,,,,, -4100,0.07055551,0.279187,,,,,,,,,,,,,, -4188,,,0.7445659637451172,0.2682199818747384,0.7204679833682118,0.2897664136140616,3554.0,0.7377528126963487,0.2911928843200223,3581.0,1099.3168652057648,1363.4006741046906,1099.3168652057648,263.5515856742859,0.3859472274780273,0.0 -4200,0.16869576,0.2488719,,,,,,,,,,,,,, -4300,0.10269535,0.35294735,,,,,,,,,,,,,, -4400,0.07789467,0.2321518,,,,,,,,,,,,,, -4500,0.14352657,0.22011209,,,,,,,,,,,,,, -4534,,,0.7451190267290387,0.2670318399156843,0.7203840385613042,0.2889700370621131,3554.0,0.7376765230120776,0.2905081179314437,3581.0,1179.4377205371857,1447.6101825237274,1179.4377205371857,267.60273838043213,0.4113495349884033,0.0 -4600,0.085384384,0.24574223,,,,,,,,,,,,,, -4700,0.19704281,0.30372617,,,,,,,,,,,,,, -4800,0.17835931,0.21233891,,,,,,,,,,,,,, -4880,,,0.7446834700448173,0.2672845636095319,0.7206724872063168,0.2889450665733152,3554.0,0.7378385107599135,0.2904610078583147,3581.0,1259.612999200821,1531.8774890899658,1259.612999200821,271.65371799468994,0.4404506683349609,0.0 -4900,0.037787154,0.41831166,,,,,,,,,,,,,, -5000,0.06464655,0.20902136,,,,,,,,,,,,,, -5100,0.40624842,0.27726865,,,,,,,,,,,,,, -5200,0.13835779,0.25398886,,,,,,,,,,,,,, -5225,,,0.745838097163609,0.2669649294444493,0.7215957426930923,0.2887557785901449,3554.0,0.7388448664610094,0.2902309798022549,3581.0,1339.6272230148315,1615.988268136978,1339.6272230148315,275.7093436717987,0.4694664478302002,0.0 -5300,0.084842116,0.26597908,,,,,,,,,,,,,, -5400,0.1265526,0.29695088,,,,,,,,,,,,,, -5500,0.12803128,0.26585758,,,,,,,,,,,,,, -5566,,,0.7464524677821568,0.2661159038543701,0.7216524844365504,0.2883142440186234,3554.0,0.7388160959098367,0.2897936946950921,3581.0,1419.6141102313995,1700.0605845451355,1419.6141102313995,279.7575418949127,0.4949519634246826,0.0 -5600,0.08293215,0.342425,,,,,,,,,,,,,, -5700,0.2723892,0.2664102,,,,,,,,,,,,,, -5800,0.18034992,0.34898716,,,,,,,,,,,,,, -5900,0.07540537,0.32688797,,,,,,,,,,,,,, -5911,,,0.7461219515119281,0.2663944959640503,0.7216612086513435,0.2883554264341235,3554.0,0.7389985366561366,0.2897185640140673,3581.0,1499.6323685646057,1784.1678936481476,1499.6323685646057,283.80796813964844,0.5214133262634277,0.0 -6000,0.59220105,0.20862347,,,,,,,,,,,,,, -6100,0.12171459,0.27298725,,,,,,,,,,,,,, -6200,0.06105061,0.34158245,,,,,,,,,,,,,, -6253,,,0.7469372068132673,0.2661893367767334,0.7225887232036086,0.2880133788238165,3554.0,0.7397630697387252,0.2895057505650482,3581.0,1579.788170337677,1868.4111258983608,1579.788170337677,287.8547184467316,0.5506565570831299,0.0 -6300,0.33670798,0.2577365,,,,,,,,,,,,,, -6400,0.10773279,0.31896555,,,,,,,,,,,,,, -6500,0.14427286,0.26117432,,,,,,,,,,,,,, -6597,,,0.7465224947248187,0.2666650669915335,0.722412315458814,0.2888153024650921,3554.0,0.7394097101019268,0.2904327486321034,3581.0,1659.8735864162445,1952.585978269577,1659.8735864162445,291.9043619632721,0.5786528587341309,0.0 -6600,0.13164917,0.2288013,,,,,,,,,,,,,, -6700,0.073746935,0.24054289,,,,,,,,,,,,,, -6800,0.17488907,0.27153745,,,,,,,,,,,,,, -6900,0.12407624,0.35990235,,,,,,,,,,,,,, -6941,,,0.7345943450927734,0.2708651849201747,0.7138045380205402,0.2920314808138717,3554.0,0.7301608641964535,0.2935818968056234,3581.0,1740.0795395374298,2036.8851308822632,1740.0795395374298,295.9577133655548,0.6068413257598877,0.0 -7000,0.21240251,0.22431934,,,,,,,,,,,,,, -7100,0.18898351,0.28847912,,,,,,,,,,,,,, -7200,0.14619945,0.23371351,,,,,,,,,,,,,, -7285,,,0.7467425210135323,0.265736630984715,0.7224672024479459,0.2875278797053056,3554.0,0.7397586382557246,0.2889406342174846,3581.0,1820.227460861206,2121.1292040348053,1820.227460861206,300.0156321525574,0.6333053112030029,0.0 -7300,0.07060672,0.25100654,,,,,,,,,,,,,, -7400,0.17994098,0.31781164,,,,,,,,,,,,,, -7500,0.09707641,0.23923059,,,,,,,,,,,,,, -7600,0.06717066,0.26941,,,,,,,,,,,,,, -7626,,,0.7467371395656041,0.2662156649998256,0.7217820424609594,0.2884126833871166,3554.0,0.7389341097109746,0.2898893124629119,3581.0,1900.2204134464264,2205.213802576065,1900.2204134464264,304.0689377784729,0.6600399017333984,0.0 -7700,0.18109553,0.21837448,,,,,,,,,,,,,, -7800,0.1756123,0.30225965,,,,,,,,,,,,,, -7900,0.08886168,0.2952508,,,,,,,,,,,,,, -7972,,,0.7473570278712681,0.2661494357245309,0.7229981430474466,0.2882626200232133,3554.0,0.7401055211096761,0.2898612577666853,3581.0,1980.26410484314,2289.345111608505,1980.26410484314,308.1188626289368,0.6859474182128906,0.0 -8000,0.08527156,0.2712155,,,,,,,,,,,,,, -8100,0.16224475,0.19412734,,,,,,,,,,,,,, -8200,0.07272914,0.31846613,,,,,,,,,,,,,, -8300,0.3205471,0.21408427,,,,,,,,,,,,,, -8317,,,0.7467418398175921,0.2657549721854074,0.7227271428320202,0.2874924332892955,3554.0,0.7399076724378665,0.2888488002543807,3581.0,2060.3492851257324,2373.52497792244,2060.3492851257324,312.1712157726288,0.7163846492767334,0.0 -8400,0.09641868,0.19269164,,,,,,,,,,,,,, -8500,0.12169332,0.29381907,,,,,,,,,,,,,, -8600,0.21031648,0.2871158,,,,,,,,,,,,,, -8661,,,0.7479797771998814,0.2654073068073818,0.7234581908149268,0.2872930471992385,3554.0,0.7406642288510542,0.2888023719478672,3581.0,2140.48207116127,2457.746108531952,2140.48207116127,316.221538066864,0.7425732612609863,0.0 -8700,0.06190369,0.26081055,,,,,,,,,,,,,, -8800,0.1652367,0.20444804,,,,,,,,,,,,,, -8900,0.11687401,0.25417742,,,,,,,,,,,,,, -9000,0.17579348,0.23442975,,,,,,,,,,,,,, -9006,,,0.7469651358468192,0.2653544800622122,0.7226491057611142,0.2874644745851945,3554.0,0.7397593881990017,0.2888830249384774,3581.0,2220.520433425904,2541.876602172852,2220.520433425904,320.27386260032654,0.7702662944793701,0.0 -9100,0.124134034,0.36945623,,,,,,,,,,,,,, -9200,0.16463447,0.28100577,,,,,,,,,,,,,, -9300,0.08917528,0.33195847,,,,,,,,,,,,,, -9351,,,0.7483945574079242,0.2648039885929653,0.7240381793621623,0.2868160146907533,3554.0,0.7412121646799078,0.2882040535661302,3581.0,2300.557582139969,2626.0004420280457,2300.557582139969,324.3219790458679,0.7968940734863281,0.0 -9400,0.28294492,0.20003825,,,,,,,,,,,,,, -9500,0.06996092,0.3279146,,,,,,,,,,,,,, -9600,0.047938798,0.29640973,,,,,,,,,,,,,, -9695,,,0.7488522529602051,0.2649398190634591,0.723960004902047,0.2871628022287827,3554.0,0.7412599565196524,0.2885059398234606,3581.0,2380.7025940418243,2710.239602804184,2380.7025940418243,328.377671957016,0.8233621120452881,0.0 -9700,0.06720126,0.2697099,,,,,,,,,,,,,, -9800,0.11755798,0.24100548,,,,,,,,,,,,,, -9900,0.12592177,0.23162276,,,,,,,,,,,,,, -10000,0.061691362,0.2617587,,,,,,,,,,,,,, -10041,,,0.7495920317513602,0.2646245104925973,0.7249149286015757,0.2867716551497872,3554.0,0.7420153539295937,0.2881796463212964,3581.0,2460.9026103019714,2794.535893678665,2460.9026103019714,332.429728269577,0.8554947376251221,0.0 -10100,0.16841641,0.3477613,,,,,,,,,,,,,, -10200,0.20101187,0.27768648,,,,,,,,,,,,,, -10300,0.110392794,0.2743569,,,,,,,,,,,,,, -10385,,,0.7483953067234584,0.2647161313465663,0.7237816737083216,0.2867452592479425,3554.0,0.7409942720696034,0.2881357405512601,3581.0,2540.861899137497,2878.588441133499,2540.861899137497,336.4816882610321,0.8848049640655518,0.0 -10400,0.11240452,0.28717226,,,,,,,,,,,,,, -10500,0.15085796,0.286116,,,,,,,,,,,,,, -10600,0.08880901,0.2786437,,,,,,,,,,,,,, -10700,0.13785887,0.20626098,,,,,,,,,,,,,, -10727,,,0.7482667650495257,0.2646113463810512,0.72345544303074,0.2867335296441949,3554.0,0.7407105208042446,0.2881166169976962,3581.0,2621.0233914852142,2962.848497867584,2621.0233914852142,340.54004883766174,0.9131312370300292,0.0 -10800,0.08706867,0.3059653,,,,,,,,,,,,,, -10900,0.12786016,0.32686207,,,,,,,,,,,,,, -11000,0.07969847,0.26578236,,,,,,,,,,,,,, -11073,,,0.748828615461077,0.2641687563487461,0.7235340296584833,0.2866743664159222,3554.0,0.7407743341594527,0.2880745519975216,3581.0,2701.2316172122955,3047.145696401596,2701.2316172122955,344.5905284881592,0.939763069152832,0.0 -11100,0.07399699,0.35212865,,,,,,,,,,,,,, -11200,0.13142559,0.2667174,,,,,,,,,,,,,, -11300,0.1864301,0.3103883,,,,,,,,,,,,,, -11400,0.04098665,0.27494636,,,,,,,,,,,,,, -11420,,,0.7480402673993792,0.2644648722239903,0.7233010862540448,0.286574295550568,3554.0,0.7405599867355487,0.2879572199629991,3581.0,2781.212673664093,3131.220014810562,2781.212673664093,348.6447901725769,0.966834306716919,0.0 -11500,0.096649095,0.33586907,,,,,,,,,,,,,, -11600,0.15346731,0.20781556,,,,,,,,,,,,,, -11700,0.1720338,0.23725942,,,,,,,,,,,,,, -11765,,,0.7490818841116769,0.2642995119094848,0.72429846321926,0.2864849238698913,3554.0,0.7415539342842432,0.2878505916643396,3581.0,2861.333312511444,3215.433751344681,2861.333312511444,352.6982641220093,0.994499444961548,0.0 -11800,0.25611463,0.21975008,,,,,,,,,,,,,, -11900,0.3573865,0.41236928,,,,,,,,,,,,,, -12000,0.13256684,0.26386997,,,,,,,,,,,,,, -12100,0.12577857,0.21374175,,,,,,,,,,,,,, -12108,,,0.7470649310520717,0.2648725509643554,0.7225018932233047,0.2872223261037299,3554.0,0.7395374049890044,0.2886686775145734,3581.0,2941.3339407444,3299.527735233307,2941.3339407444,356.7520639896393,1.0221521854400637,0.0 -12200,0.12496407,0.20620735,,,,,,,,,,,,,, -12300,0.07872862,0.27498585,,,,,,,,,,,,,, -12400,0.102752894,0.24254166,,,,,,,,,,,,,, -12453,,,0.7491067477634975,0.2643582820892334,0.7243488163644837,0.2864743964217255,3554.0,0.7415426169584264,0.2878965086458915,3581.0,3021.476846933365,3383.7677912712097,3021.476846933365,360.8086497783661,1.0507407188415527,0.0 -12500,0.103772044,0.34392664,,,,,,,,,,,,,, -12600,0.2693736,0.34087765,,,,,,,,,,,,,, -12700,0.16076909,0.353208,,,,,,,,,,,,,, -12798,,,0.7487072263445173,0.2641389199665614,0.7240591312165869,0.2862849538756946,3554.0,0.7412870226542865,0.2876597651886693,3581.0,3101.4897875785828,3467.873590707779,3101.4897875785828,364.86098647117615,1.0791373252868652,0.0 -12800,0.12892015,0.24129808,,,,,,,,,,,,,, -12900,0.20140912,0.24444845,,,,,,,,,,,,,, -13000,0.38186717,0.26347306,,,,,,,,,,,,,, -13100,0.05461713,0.27503076,,,,,,,,,,,,,, -13140,,,0.7497420992170062,0.2637655564716884,0.7244437523081387,0.2864398945565296,3554.0,0.7415685240898143,0.2878337520289374,3581.0,3181.5577688217163,3552.0366492271423,3181.5577688217163,368.9163925647736,1.107081413269043,0.0 -13200,0.08155038,0.19626229,,,,,,,,,,,,,, -13300,0.12181919,0.34141338,,,,,,,,,,,,,, -13400,0.0980256,0.31581292,,,,,,,,,,,,,, -13487,,,0.749126638684954,0.2638783114297049,0.7242455683736635,0.286114694298018,3554.0,0.7415469120881039,0.2875157079028204,3581.0,3261.6042971611023,3636.176833629608,3261.6042971611023,372.96993708610535,1.135343074798584,0.0 -13500,0.14522982,0.33411604,,,,,,,,,,,,,, -13600,0.12326343,0.24056105,,,,,,,,,,,,,, -13700,0.11474382,0.2864201,,,,,,,,,,,,,, -13800,0.17638664,0.24855867,,,,,,,,,,,,,, -13833,,,0.7492619241986956,0.2641657420567104,0.7243551362681134,0.2865796365560811,3554.0,0.7416033623638648,0.2879423233624511,3581.0,3341.729484319687,3720.393041372299,3341.729484319687,377.0191743373871,1.1649174690246582,0.0 -13900,0.1809352,0.2516886,,,,,,,,,,,,,, -14000,0.0999721,0.279803,,,,,,,,,,,,,, -14100,0.09046851,0.2902615,,,,,,,,,,,,,, -14177,,,0.7501425061907087,0.2635764905384609,0.7248468522483469,0.2862090291638822,3554.0,0.7419544039941707,0.2877021028954901,3581.0,3421.8819210529327,3804.6349902153015,3421.8819210529327,381.06914806365967,1.1924426555633545,0.0 -14200,0.36958042,0.22935858,,,,,,,,,,,,,, -14300,0.10344644,0.30487916,,,,,,,,,,,,,, -14400,0.17193809,0.23033658,,,,,,,,,,,,,, -14500,0.15947418,0.3577425,,,,,,,,,,,,,, -14523,,,0.7482326371329171,0.2640088626316615,0.7233014984216728,0.2863345685539181,3554.0,0.7405440333967467,0.2876874108249267,3581.0,3501.987695455551,3888.836674690247,3501.987695455551,385.1245248317719,1.2210686206817627,0.0 -14600,0.24294835,0.3467489,,,,,,,,,,,,,, -14700,0.12495415,0.2402792,,,,,,,,,,,,,, -14800,0.19805677,0.25700185,,,,,,,,,,,,,, -14867,,,0.7490936688014439,0.2641704423086984,0.724210396736072,0.2864648993926297,3554.0,0.741584068368647,0.2877675524905753,3581.0,3581.99942946434,3972.9454820156097,3581.99942946434,389.177859544754,1.2528254985809326,0.0 -14900,0.26318112,0.23445213,,,,,,,,,,,,,, -15000,0.18313834,0.24526757,,,,,,,,,,,,,, -15100,0.18589294,0.2170494,,,,,,,,,,,,,, -15200,0.107467845,0.21793377,,,,,,,,,,,,,, -15209,,,0.7499757494245257,0.2643688576562064,0.7254071941386466,0.2867057426766056,3554.0,0.7425437230565833,0.2880258397728463,3581.0,3662.1490709781647,4057.1929366588593,3662.1490709781647,393.2313735485077,1.2852871417999268,0.0 -15300,0.18185823,0.19950986,,,,,,,,,,,,,, -15400,0.067857236,0.35286498,,,,,,,,,,,,,, -15500,0.17511693,0.26886117,,,,,,,,,,,,,, -15555,,,0.7491991860525948,0.2637040104184832,0.7240965697761326,0.2861622309644502,3554.0,0.7413624942186191,0.2875579433446488,3581.0,3742.141808271408,4141.27409529686,3742.141808271408,397.2786636352539,1.3142294883728027,0.0 -15600,0.26917875,0.26778537,,,,,,,,,,,,,, -15700,0.18993503,0.22194421,,,,,,,,,,,,,, -15800,0.12722833,0.3023996,,,,,,,,,,,,,, -15900,,,0.7490659441266742,0.2637172596795218,0.7240223109084833,0.2860872336298009,3554.0,0.7411458969648841,0.2874640981700293,3581.0,3822.2436985969534,4225.470196008682,3822.2436985969534,401.3295774459839,1.3458054065704346,0.0 -15900,0.12415524,0.30260694,,,,,,,,,,,,,, -16000,0.09978506,0.19717148,,,,,,,,,,,,,, -16100,0.17796591,0.29420784,,,,,,,,,,,,,, -16200,0.12916784,0.28353375,,,,,,,,,,,,,, -16242,,,0.7492741176060268,0.2642509256090437,0.7244184726936199,0.2865908166029913,3554.0,0.7415718647462302,0.2879943080668807,3581.0,3902.2939944267273,4309.618505954742,3902.2939944267273,405.3828492164612,1.3787052631378174,0.0 -16300,0.16945514,0.2897928,,,,,,,,,,,,,, -16400,0.13318369,0.33430994,,,,,,,,,,,,,, -16500,0.36001822,0.34180012,,,,,,,,,,,,,, -16586,,,0.7487504822867257,0.2635989189147949,0.7237115365169527,0.2861476848819112,3554.0,0.7408519873769548,0.2875285932918528,3581.0,3982.34534406662,4393.763057470322,3982.34534406662,409.4357068538666,1.4074318408966064,0.0 -16600,0.18854283,0.23314634,,,,,,,,,,,,,, -16700,0.32787937,0.31045735,,,,,,,,,,,,,, -16800,0.23095486,0.23490241,,,,,,,,,,,,,, -16900,0.07027738,0.3467363,,,,,,,,,,,,,, -16931,,,0.7499645096915108,0.2632872888020107,0.7248403262609032,0.2857487753125879,3554.0,0.7420229897156869,0.2871454404539584,3581.0,4062.446593284607,4477.958353757858,4062.446593284607,413.48853278160095,1.4365577697753906,0.0 -17000,0.099068366,0.22148147,,,,,,,,,,,,,, -17100,0.08537359,0.2546469,,,,,,,,,,,,,, -17200,0.08744361,0.2604441,,,,,,,,,,,,,, -17274,,,0.750265257699149,0.2633872883660452,0.725125202786473,0.2858453255794527,3554.0,0.7423415110784348,0.2872500234527715,3581.0,4142.558726072311,4562.1660623550415,4142.558726072311,417.5439202785492,1.4648287296295166,0.0 -17300,0.12595837,0.27577215,,,,,,,,,,,,,, -17400,0.4572419,0.26326036,,,,,,,,,,,,,, -17500,0.09461056,0.26556325,,,,,,,,,,,,,, -17600,0.14209086,0.31594422,,,,,,,,,,,,,, -17619,,,0.7514595985412598,0.2629937955311366,0.7259761228545301,0.2859021875384689,3554.0,0.7431117710005934,0.2873178592310109,3581.0,4222.760106086731,4646.464495420456,4222.760106086731,421.599889755249,1.4937596321105957,0.0 -17700,0.28375915,0.2564631,,,,,,,,,,,,,, -17800,0.35600656,0.33648273,,,,,,,,,,,,,, -17900,0.15105905,0.29459164,,,,,,,,,,,,,, -17965,,,0.7501124654497419,0.2631190163748605,0.7249581375079136,0.2856347937897879,3554.0,0.7422263606970818,0.2870001559882016,3581.0,4302.788812160492,4730.587537527084,4302.788812160492,425.6540808677673,1.5217459201812744,0.0 -18000,0.07260054,0.2863964,,,,,,,,,,,,,, -18100,0.18794067,0.25332993,,,,,,,,,,,,,, -18200,0.07785706,0.30057448,,,,,,,,,,,,,, -18300,0.18884067,0.21744487,,,,,,,,,,,,,, -18309,,,0.7495180538722447,0.2637862648282732,0.7243303375158272,0.2862694117213878,3554.0,0.7415111193407917,0.2876765366474099,3581.0,4382.871379613876,4814.766731977463,4382.871379613876,429.7097146511078,1.550818681716919,0.0 -18400,0.121681936,0.3011419,,,,,,,,,,,,,, -18500,0.07347176,0.2671079,,,,,,,,,,,,,, -18600,0.17551562,0.29850632,,,,,,,,,,,,,, -18655,,,0.7512383460998535,0.2626204831259591,0.725613896204101,0.2855628018440929,3554.0,0.7428010217772619,0.2869414217942788,3581.0,4463.077668428421,4899.068983078003,4463.077668428421,433.7644882202149,1.580070734024048,0.0 -18700,0.123425856,0.21002355,,,,,,,,,,,,,, -18800,0.11207002,0.38490915,,,,,,,,,,,,,, -18900,0.2805911,0.21942486,,,,,,,,,,,,,, -18999,,,0.7502283368791852,0.2631433010101318,0.7251444372757808,0.2858119915225362,3554.0,0.7423001278448758,0.2871799719330319,3581.0,4543.044121026993,4983.139835596085,4543.044121026993,437.8175001144409,1.6196582317352295,0.0 -19000,0.13706131,0.24731036,,,,,,,,,,,,,, -19100,0.19140705,0.2442188,,,,,,,,,,,,,, -19200,0.79195005,0.19536665,,,,,,,,,,,,,, -19300,0.116575584,0.24283203,,,,,,,,,,,,,, -19346,,,0.7503267696925572,0.2632224900381906,0.7251956147562606,0.2856542858838632,3554.0,0.7424429579508168,0.2870098370741413,3581.0,4623.194087982178,5067.387616157532,4623.194087982178,441.870076417923,1.6532104015350342,0.0 -19400,0.31517252,0.2694945,,,,,,,,,,,,,, -19500,0.14703397,0.25747213,,,,,,,,,,,,,, -19600,0.14233184,0.27562648,,,,,,,,,,,,,, -19688,,,0.750896794455392,0.2627083233424595,0.7252511199968346,0.2856937337605955,3554.0,0.7424449350740017,0.2870990462357756,3581.0,4703.263406991959,5151.553251743317,4703.263406991959,445.9244599342346,1.6831650733947754,0.0 -19700,0.13545008,0.22199263,,,,,,,,,,,,,, -19800,0.15235545,0.3453434,,,,,,,,,,,,,, -19900,0.081933446,0.2427128,,,,,,,,,,,,,, -20000,0.2874324,0.28787774,,,,,,,,,,,,,, -20033,,,0.7507200922284808,0.2630380732672555,0.7253897457090602,0.2857979091285787,3554.0,0.7425795839805571,0.2871399863210346,3581.0,4783.346960544586,5235.728672742844,4783.346960544586,449.9741785526276,1.7135276794433594,0.0 -20100,0.20886959,0.35006163,,,,,,,,,,,,,, -20200,0.13458066,0.35473663,,,,,,,,,,,,,, -20300,0.23024812,0.26824266,,,,,,,,,,,,,, -20379,,,0.7507028579711914,0.2628463676997593,0.725503091806767,0.2854411952201744,3554.0,0.7427057108044192,0.2867996143382609,3581.0,4863.308755397797,5319.78363609314,4863.308755397797,454.02555441856384,1.7433743476867676,0.0 -20400,0.11879435,0.26516208,,,,,,,,,,,,,, -20500,0.1773664,0.21705379,,,,,,,,,,,,,, -20600,0.13636766,0.28906918,,,,,,,,,,,,,, -20700,0.13566773,0.30522868,,,,,,,,,,,,,, -20722,,,0.7512781279427665,0.2634522574288504,0.7255425225098481,0.286552278929771,3554.0,0.742558381038816,0.2879258246103567,3581.0,4943.37055516243,5403.941605091095,4943.37055516243,458.0803413391113,1.773087501525879,0.0 -20800,0.09761411,0.26051563,,,,,,,,,,,,,, -20900,0.15454675,0.3035571,,,,,,,,,,,,,, -21000,0.28101057,0.32251754,,,,,,,,,,,,,, -21068,,,0.7504569462367466,0.2626910550253732,0.7252177344189645,0.2853945172363006,3554.0,0.7424042336070581,0.2867335170648911,3581.0,5023.388708114624,5488.054424762726,5023.388708114624,462.1315383911133,1.8046362400054927,0.0 -21100,0.10318578,0.26253957,,,,,,,,,,,,,, -21200,0.1412843,0.21987385,,,,,,,,,,,,,, -21300,0.12582119,0.2403382,,,,,,,,,,,,,, -21400,0.14613633,0.21465856,,,,,,,,,,,,,, -21413,,,0.7511151858738491,0.2626854521887643,0.7259615595983399,0.2853362985588421,3554.0,0.7430837503926976,0.2866953722227555,3581.0,5103.514678239822,5572.2762241363525,5103.514678239822,466.18581223487854,1.8341941833496087,0.0 -21500,0.40682253,0.20553924,,,,,,,,,,,,,, -21600,0.15075201,0.2313644,,,,,,,,,,,,,, -21700,0.09947518,0.24704014,,,,,,,,,,,,,, -21757,,,0.7500740459987095,0.2627014432634626,0.7243928496060776,0.2856341583646947,3554.0,0.7417080817160011,0.2870615490719247,3581.0,5183.712336778641,5656.568719625473,5183.712336778641,470.2389633655548,1.8642263412475584,0.0 -21800,0.14836358,0.2327213,,,,,,,,,,,,,, -21900,0.10847842,0.20361158,,,,,,,,,,,,,, -22000,0.12905687,0.27712387,,,,,,,,,,,,,, -22100,0.13785343,0.373096,,,,,,,,,,,,,, -22102,,,0.7515127999441964,0.2628119502748762,0.7261375551755065,0.2855904857697752,3554.0,0.743197332710835,0.2870317558708287,3581.0,5263.737190961838,5740.689452886581,5263.737190961838,474.2914884090424,1.8956067562103271,0.0 -22200,0.23831551,0.2279997,,,,,,,,,,,,,, -22300,0.092288144,0.21964926,,,,,,,,,,,,,, -22400,0.09289285,0.22428122,,,,,,,,,,,,,, -22449,,,0.7513835770743233,0.2626760687146868,0.7262300181133934,0.2854385676515458,3554.0,0.7433537981490854,0.2867978758333915,3581.0,5343.85110616684,5824.905149459839,5343.85110616684,478.3473429679871,1.929384469985962,0.0 -22500,0.22056927,0.24840079,,,,,,,,,,,,,, -22600,0.16017118,0.21107893,,,,,,,,,,,,,, -22700,0.096982405,0.15840206,,,,,,,,,,,,,, -22791,,,0.7507671628679548,0.2628911222730364,0.7254199713351154,0.285624558293692,3554.0,0.7425605626919854,0.2869999173698862,3581.0,5423.849596738815,5908.999294519424,5423.849596738815,482.4012656211853,1.959122896194458,0.0 -22800,0.1099287,0.3040151,,,,,,,,,,,,,, -22900,0.33173764,0.2334768,,,,,,,,,,,,,, -23000,0.15226847,0.32152408,,,,,,,,,,,,,, -23100,0.16389997,0.2763794,,,,,,,,,,,,,, -23137,,,0.7513228143964495,0.2626312630517142,0.7257843962128939,0.2855588003833708,3554.0,0.7428933329769967,0.2869907816972389,3581.0,5504.067927360535,5993.308559894562,5504.067927360535,486.4510595798493,1.9885139465332031,0.0 -23200,0.23889017,0.30449212,,,,,,,,,,,,,, -23300,0.14409007,0.24631485,,,,,,,,,,,,,, -23400,0.1834604,0.35526192,,,,,,,,,,,,,, -23484,,,0.7516647747584752,0.2624009847640991,0.7261988994574775,0.2852267650116946,3554.0,0.7433634110583636,0.286600777104859,3581.0,5584.158223390579,6077.498130559921,5584.158223390579,490.5085999965668,2.0182549953460693,0.0 -23500,0.16871688,0.27382717,,,,,,,,,,,,,, -23600,0.1636202,0.2810173,,,,,,,,,,,,,, -23700,0.094172634,0.2507594,,,,,,,,,,,,,, -23800,0.21382071,0.23270203,,,,,,,,,,,,,, -23829,,,0.7513493810381208,0.2625121729714529,0.7258792634619443,0.2853360753013769,3554.0,0.7430474804087546,0.2867686621339186,3581.0,5664.2836174964905,6161.723926067352,5664.2836174964905,494.56139755249023,2.053964614868164,0.0 -23900,0.1108655,0.2325446,,,,,,,,,,,,,, -24000,0.28450516,0.22749788,,,,,,,,,,,,,, -24100,0.18426076,0.27122355,,,,,,,,,,,,,, -24175,,,0.7513031278337751,0.2620893716812134,0.7253518262872819,0.2852982074005522,3554.0,0.7425226564681653,0.286753390561732,3581.0,5744.359104633331,6245.896435976028,5744.359104633331,498.6148250102997,2.0857253074646,0.0 -24200,0.13620757,0.27348652,,,,,,,,,,,,,, -24300,0.18140787,0.2941166,,,,,,,,,,,,,, -24400,0.102421,0.21505684,,,,,,,,,,,,,, -24500,0.1957413,0.35988742,,,,,,,,,,,,,, -24520,,,0.7515650476728167,0.2621761390141078,0.7259875948535102,0.2851459458092993,3554.0,0.7431070668109466,0.2865376455162664,3581.0,5824.3427176475525,6329.981694221497,5824.3427176475525,502.6727757453919,2.1174325942993164,0.0 -24600,0.1469726,0.19165774,,,,,,,,,,,,,, -24700,0.16570926,0.2797115,,,,,,,,,,,,,, -24800,0.09654388,0.26722512,,,,,,,,,,,,,, -24863,,,0.7513407298496791,0.2622285911015102,0.7257067026150112,0.2851708647771437,3554.0,0.7429354661538328,0.2865232261523492,3581.0,5904.422456741333,6414.155948877335,5904.422456741333,506.7254252433777,2.1474146842956543,0.0 -24900,0.12278477,0.30162352,,,,,,,,,,,,,, -25000,0.07702211,0.24922526,,,,,,,,,,,,,, -25100,0.122807704,0.25593278,,,,,,,,,,,,,, -25200,0.17824616,0.326679,,,,,,,,,,,,,, -25206,,,0.7524091175624302,0.2616928815841675,0.7262610680747046,0.2851238261465953,3554.0,0.7434358828495881,0.2865558145965687,3581.0,5984.5337653160095,6498.363244771957,5984.5337653160095,510.77738857269287,2.179636716842652,0.0 -25300,0.20618357,0.25181141,,,,,,,,,,,,,, -25400,0.1149571,0.2780113,,,,,,,,,,,,,, -25500,0.070915334,0.20715857,,,,,,,,,,,,,, -25552,,,0.7515433175223214,0.2620738404137747,0.7256952993106359,0.2851855654225432,3554.0,0.7429550328556968,0.2865705066671321,3581.0,6064.5873601436615,6582.516287326813,6064.5873601436615,514.8336462974548,2.210814952850342,0.0 -25600,0.064861596,0.24876359,,,,,,,,,,,,,, -25700,0.12863643,0.26129285,,,,,,,,,,,,,, -25800,0.06351844,0.2707811,,,,,,,,,,,,,, -25900,,,0.7509380068097796,0.26214143208095,0.7252735831325618,0.2851501533538355,3554.0,0.7424625246526808,0.2865279644303267,3581.0,6144.693779706955,6666.730274915695,6144.693779706955,518.8966798782349,2.243125438690185,0.0 -25900,0.17082195,0.19354519,,,,,,,,,,,,,, -26000,0.14773367,0.2576508,,,,,,,,,,,,,, -26100,0.06838004,0.21681795,,,,,,,,,,,,,, -26200,0.1489371,0.25156984,,,,,,,,,,,,,, -26246,,,0.7524256025041852,0.2616041558129446,0.7262390858012099,0.285166794621817,3554.0,0.7433814778736736,0.2866125716673066,3581.0,6224.853043317795,6750.988937854767,6224.853043317795,522.9531097412109,2.274085521697998,0.0 -26300,0.11552157,0.28046018,,,,,,,,,,,,,, -26400,0.13490772,0.2596028,,,,,,,,,,,,,, -26500,0.18084122,0.2462406,,,,,,,,,,,,,, -26590,,,0.752068178994315,0.2618626015526907,0.7262528934167487,0.2850863875870498,3554.0,0.7433889773064437,0.2864727413344736,3581.0,6304.8623919487,6835.089107990265,6304.8623919487,527.0004780292511,2.305572271347046,0.0 -26600,0.094920844,0.3152577,,,,,,,,,,,,,, -26700,0.12407105,0.33256382,,,,,,,,,,,,,, -26800,0.21488328,0.28887314,,,,,,,,,,,,,, -26900,0.095952086,0.20038354,,,,,,,,,,,,,, -26937,,,0.751748970576695,0.2620003904615129,0.7261211371649902,0.2851182103626635,3554.0,0.7432566464063809,0.2864647646650726,3581.0,6384.856774330139,6919.176200866699,6384.856774330139,531.0505583286285,2.336199998855591,0.0 -27000,0.11297567,0.2421701,,,,,,,,,,,,,, -27100,0.1915111,0.38150403,,,,,,,,,,,,,, -27200,0.15464707,0.20848559,,,,,,,,,,,,,, -27281,,,0.7525568008422852,0.2614716802324567,0.7262717157384285,0.2850966059094946,3554.0,0.7433453442430537,0.2865492355487294,3581.0,6465.06779050827,7003.490379571915,6465.06779050827,535.1086058616638,2.369391679763794,0.0 -27300,0.20986684,0.234573,,,,,,,,,,,,,, -27400,0.093071006,0.21581441,,,,,,,,,,,,,, -27500,0.07026327,0.25721014,,,,,,,,,,,,,, -27600,0.11064211,0.3166965,,,,,,,,,,,,,, -27625,,,0.751924855368478,0.2617475816181728,0.726097574915588,0.2850241674488692,3554.0,0.7431984235374197,0.2864156774687587,3581.0,6545.176929950714,7087.695506572723,6545.176929950714,539.1609346866608,2.4011919498443604,0.0 -27700,0.08311089,0.4022999,,,,,,,,,,,,,, -27800,0.11869152,0.33977833,,,,,,,,,,,,,, -27900,0.13953657,0.20245267,,,,,,,,,,,,,, -27971,,,0.7521243095397949,0.2618813855307443,0.7263136881418824,0.2851246848291537,3554.0,0.7433388674602066,0.28652326024068,3581.0,6625.335506677628,7171.952100038528,6625.335506677628,543.2160265445709,2.432178735733032,0.0 -28000,0.10367123,0.23698452,,,,,,,,,,,,,, -28100,0.41468748,0.23595369,,,,,,,,,,,,,, -28200,0.18462849,0.28293747,,,,,,,,,,,,,, -28300,0.074060805,0.2612844,,,,,,,,,,,,,, -28311,,,0.7526357514517648,0.2617536953517368,0.7263437076841235,0.2855354957387362,3554.0,0.7433338223872522,0.2869765668633063,3581.0,6705.314663171768,7256.025635004044,6705.314663171768,547.2668516635895,2.463989496231079,0.0 -28400,0.11956976,0.2777286,,,,,,,,,,,,,, -28500,0.123635344,0.25670332,,,,,,,,,,,,,, -28600,0.077143125,0.33786646,,,,,,,,,,,,,, -28660,,,0.7522928374154227,0.2616182225091116,0.7263367008344471,0.2850380265853615,3554.0,0.7434398370959578,0.2864036442879957,3581.0,6785.449561357498,7340.261151790619,6785.449561357498,551.3233077526093,2.4959301948547363,0.0 -28700,0.104439706,0.25155967,,,,,,,,,,,,,, -28800,0.11921684,0.23428303,,,,,,,,,,,,,, -28900,0.18806109,0.309704,,,,,,,,,,,,,, -29000,0.19102138,0.24128531,,,,,,,,,,,,,, -29003,,,0.7521061216081891,0.2617198399135044,0.7261561027187676,0.2851091598484981,3554.0,0.7432160812927604,0.2864864789317753,3581.0,6865.401733875275,7424.313965082169,6865.401733875275,555.37380027771,2.5343716144561768,0.0 -29100,0.17484847,0.20205906,,,,,,,,,,,,,, -29200,0.15024501,0.27459815,,,,,,,,,,,,,, -29300,0.11493448,0.26650777,,,,,,,,,,,,,, -29346,,,0.7521747861589704,0.2616387775966099,0.7258670358223129,0.2852646500861705,3554.0,0.7429957343226403,0.286642160338418,3581.0,6945.3704397678375,7508.383171081543,6945.3704397678375,559.4301190376282,2.566793918609619,0.0 -29400,0.15156914,0.22089729,,,,,,,,,,,,,, -29500,0.10096923,0.19412684,,,,,,,,,,,,,, -29600,0.08237953,0.28265837,,,,,,,,,,,,,, -29694,,,0.7528272356305804,0.2614959818976266,0.7266219208330402,0.2850999032505187,3554.0,0.743683500484327,0.2864821838020979,3581.0,7025.511574745178,7592.625883340836,7025.511574745178,563.4838757514954,2.602396249771118,0.0 -29700,0.11969846,0.3804017,,,,,,,,,,,,,, -29800,0.11015223,0.24761193,,,,,,,,,,,,,, -29900,0.087968126,0.27047557,,,,,,,,,,,,,, -30000,0.08838491,0.24497242,,,,,,,,,,,,,, -30036,,,0.7525065967014858,0.2615717138562883,0.7264429713878728,0.285084979347654,3554.0,0.7435130588304594,0.2864304036276529,3581.0,7105.662876844406,7676.874650716782,7105.662876844406,567.5378227233887,2.6340932846069336,0.0 -30100,0.07399071,0.265703,,,,,,,,,,,,,, -30200,0.12035121,0.24334833,,,,,,,,,,,,,, -30300,0.07640061,0.24015212,,,,,,,,,,,,,, -30379,,,0.7517011506216866,0.2618185792650495,0.7255941808525604,0.2853469118752638,3554.0,0.7427278000427604,0.2867632420893256,3581.0,7185.761651039124,7761.072353124618,7185.761651039124,571.5938925743103,2.66501784324646,0.0 -30400,0.34087816,0.22304474,,,,,,,,,,,,,, -30500,0.12950629,0.21131678,,,,,,,,,,,,,, -30600,0.0866017,0.24720727,,,,,,,,,,,,,, -30700,0.18771349,0.31225148,,,,,,,,,,,,,, -30724,,,0.7526508740016392,0.2612915209361485,0.7261841301174733,0.285073198222953,3554.0,0.7432732133351369,0.2864934329512531,3581.0,7265.859354972839,7845.270067453384,7265.859354972839,575.6503157615662,2.6965012550354004,0.0 -30800,0.0931526,0.31141722,,,,,,,,,,,,,, -30900,0.0920744,0.19367182,,,,,,,,,,,,,, -31000,0.075568125,0.30892313,,,,,,,,,,,,,, -31071,,,0.752741881779262,0.2614444323948451,0.7265454637380416,0.2850723395403946,3554.0,0.7436164146493647,0.2864299604793528,3581.0,7345.888344764709,7929.394406795502,7345.888344764709,579.7019543647766,2.7283411026000977,0.0 -31100,0.063455015,0.2581575,,,,,,,,,,,,,, -31200,0.090879604,0.25317395,,,,,,,,,,,,,, -31300,0.102673754,0.29722843,,,,,,,,,,,,,, -31400,0.05814025,0.2604366,,,,,,,,,,,,,, -31415,,,0.7526612281799316,0.2614004441670009,0.7264517642972707,0.2850400359025481,3554.0,0.7435497378743717,0.2864075985343654,3581.0,7426.11357998848,8013.716660499573,7426.11357998848,583.755449295044,2.7600905895233154,0.0 -31500,0.14036298,0.19957529,,,,,,,,,,,,,, -31600,0.04750959,0.21799088,,,,,,,,,,,,,, -31700,0.06757388,0.24015802,,,,,,,,,,,,,, -31761,,,0.753065654209682,0.2610013655253819,0.7265352282419457,0.2849258998168876,3554.0,0.7436201643657497,0.2863236730640009,3581.0,7506.210152387619,8097.917795181274,7506.210152387619,587.8157875537872,2.792599678039551,0.0 -31800,0.11997647,0.21547082,,,,,,,,,,,,,, -31900,0.08443608,0.32219052,,,,,,,,,,,,,, -32000,0.08600074,0.18692586,,,,,,,,,,,,,, -32100,0.061689865,0.38814557,,,,,,,,,,,,,, -32107,,,0.7528303691319057,0.2612640857696533,0.7265279466138506,0.2850118195936796,3554.0,0.7436131421696105,0.2863925996688251,3581.0,7586.232357978821,8182.043341875076,7586.232357978821,591.8755309581757,2.8241755962371826,0.0 -32200,0.086184196,0.22431351,,,,,,,,,,,,,, -32300,0.08953046,0.31026015,,,,,,,,,,,,,, -32400,0.08342853,0.31772488,,,,,,,,,,,,,, -32449,,,0.7530434472220284,0.2612173386982509,0.726705865639948,0.2849734708306222,3554.0,0.7437789478104929,0.2863342745348715,3581.0,7666.246356487274,8266.15687918663,7666.246356487274,595.9300870895386,2.857153415679932,0.0 -32500,0.057267107,0.28252405,,,,,,,,,,,,,, -32600,0.08048428,0.24871594,,,,,,,,,,,,,, -32700,0.077407,0.3322822,,,,,,,,,,,,,, -32793,,,0.7532211031232562,0.2608901262283325,0.7266475439205824,0.2848969965619724,3554.0,0.7436942042201898,0.2862743131610409,3581.0,7746.339337348938,8350.34925699234,7746.339337348938,599.9857411384583,2.8888802528381348,0.0 -32800,0.08959947,0.25990528,,,,,,,,,,,,,, -32900,0.07582265,0.22557098,,,,,,,,,,,,,, -33000,0.08004727,0.2687794,,,,,,,,,,,,,, -33100,0.07629226,0.26365343,,,,,,,,,,,,,, -33140,,,0.7529222624642509,0.2610472440719604,0.7264977896824001,0.2849144965125123,3554.0,0.7436322998115051,0.2862382136187517,3581.0,7826.380793333054,8434.489802360535,7826.380793333054,604.0404622554779,2.9211058616638184,0.0 -33200,0.112433545,0.34739596,,,,,,,,,,,,,, -33300,0.10411219,0.25107712,,,,,,,,,,,,,, -33400,0.055127308,0.25058255,,,,,,,,,,,,,, -33484,,,0.7528365680149623,0.2611274208341326,0.7264392618792206,0.2849736940880873,3554.0,0.7435476925745252,0.2863115376182456,3581.0,7906.548151254654,8518.75707435608,7906.548151254654,608.09494972229,2.9546470642089844,0.0 -33500,0.06370863,0.25128815,,,,,,,,,,,,,, -33600,0.11425105,0.2266666,,,,,,,,,,,,,, -33700,0.1002542,0.28560504,,,,,,,,,,,,,, -33800,0.075409316,0.29403064,,,,,,,,,,,,,, -33829,,,0.7534403119768415,0.2608636447361537,0.7268620084763646,0.2848896977602261,3554.0,0.7438938254851997,0.2862855282218654,3581.0,7986.679588794708,8602.986902713776,7986.679588794708,612.147744178772,2.9883530139923096,0.0 -33900,0.05234618,0.36836755,,,,,,,,,,,,,, -34000,0.0641533,0.25401556,,,,,,,,,,,,,, -34100,0.05656467,0.30038744,,,,,,,,,,,,,, -34175,,,0.7532878603254046,0.2609024899346487,0.7268059536789533,0.2848174653834148,3554.0,0.7438721453068277,0.2861680598340198,3581.0,8066.71623635292,8687.127203702927,8066.71623635292,616.2017261981964,3.026043176651001,0.0 -34200,0.07269703,0.3339774,,,,,,,,,,,,,, -34300,0.12467185,0.24989642,,,,,,,,,,,,,, -34400,0.05965389,0.26854223,,,,,,,,,,,,,, -34500,0.08209733,0.262409,,,,,,,,,,,,,, -34521,,,0.7531764847891671,0.2609563895634242,0.7267027056881331,0.2848718199893605,3554.0,0.7437584266353672,0.2862098180392174,3581.0,8146.856377363205,8771.369478464127,8146.856377363205,620.2588531970978,3.0591280460357666,0.0 -34600,0.056309994,0.2587156,,,,,,,,,,,,,, -34700,0.25198686,0.35557154,,,,,,,,,,,,,, -34800,0.05161955,0.3641845,,,,,,,,,,,,,, -34863,,,0.7531660624912807,0.2608116865158081,0.7266088001635481,0.2848065772885745,3554.0,0.7436710241552639,0.2861729344653204,3581.0,8226.93552994728,8855.545527935028,8226.93552994728,624.3104538917542,3.092402696609497,0.0 -34900,0.086863406,0.22588508,,,,,,,,,,,,,, -35000,0.04297794,0.2967342,,,,,,,,,,,,,, -35100,0.052220825,0.27063832,,,,,,,,,,,,,, -35200,0.06562105,0.22075357,,,,,,,,,,,,,, -35207,,,0.7530843189784459,0.260757531438555,0.7264942175629572,0.2847236285534345,3554.0,0.7435840989117914,0.2860854297202247,3581.0,8306.98347043991,8939.69522190094,8306.98347043991,628.3667876720428,3.125974416732788,0.0 -35300,0.07797539,0.2522673,,,,,,,,,,,,,, -35400,0.07374002,0.29265416,,,,,,,,,,,,,, -35500,0.06736435,0.192171,,,,,,,,,,,,,, -35554,,,0.7531565257481166,0.260784523827689,0.7266150513725732,0.2847584567180026,3554.0,0.7436912726237433,0.2861068031036198,3581.0,8386.943142175674,9023.751316070557,8386.943142175674,632.4175064563751,3.1593573093414307,0.0 -35600,0.111483015,0.23259121,,,,,,,,,,,,,, -35700,0.0638197,0.29613692,,,,,,,,,,,,,, -35800,0.15303646,0.26629108,,,,,,,,,,,,,, -35897,,,0.7530922208513532,0.2607801982334682,0.7265309691764561,0.2847584910653049,3554.0,0.7436200280124267,0.2861113709399434,3581.0,8467.084789514542,9107.988319396973,8467.084789514542,636.4669382572174,3.19321870803833,0.0 -35900,0.06623433,0.2646526,,,,,,,,,,,,,, -36000,0.055850897,0.22781476,,,,,,,,,,,,,, -36100,0.08472898,0.25734323,,,,,,,,,,,,,, -36189,,,0.7533482824053083,0.2607856137411935,0.7268125483610017,0.2847664596394467,3554.0,0.7438771222031206,0.2861253471555606,3581.0,8534.14320230484,9179.14162349701,8534.14320230484,640.5183072090149,3.226508617401123,0.0 -36189,,,,,,,,,,,8534.14320230484,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 1f41ec132..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.9995675086975098,0.0,29.85322141647339,1,0,29.85322141647339,0.8250461419645351,3581,0.2972505852284627,33.85292458534241,0.8154546192714146,0.2794570582253592,0.8275999671365011,3554,0.2738082104341059 -8.045104265213013,0.0208237171173095,109.82941102981567,340,0,109.82941102981567,0.3325069185676138,3581,0.6910679574054035,117.90757751464844,0.3081912994384765,0.6971728461129325,0.3307215873619513,3554,0.672379355787493 -12.096370697021484,0.0461506843566894,189.9423379898072,681,0,189.9423379898072,0.312450026507086,3581,0.7154252969229964,202.10964155197144,0.2886877400534494,0.7214180401393345,0.3102809581853193,3554,0.6979783315234594 -16.14963984489441,0.0723109245300293,270.07433581352234,1018,0,270.07433581352234,0.3053255653754189,3581,0.7213962089412873,286.3332691192627,0.2825293030057634,0.7268994195120675,0.3032789171312254,3554,0.7041054780924663 -20.19764614105225,0.0971243381500244,350.1551387310028,1362,0,350.1551387310028,0.3010481616299567,3581,0.7271120038920692,370.4994876384735,0.2778271777289254,0.7332515716552734,0.2989388263554269,3554,0.7102090624120709 -24.24586677551269,0.1208417415618896,430.2985694408417,1708,0,430.2985694408417,0.3023199290744555,3581,0.726705943695895,454.72812390327454,0.278896621295384,0.7331899915422712,0.3001783518021068,3554,0.7097868653717642 -28.29649257659912,0.1473636627197265,510.4877436161041,2051,0,510.4877436161041,0.2962755226150168,3581,0.7314937179296984,539.0076286792755,0.2731998818261282,0.7381429672241211,0.2945284609990504,3554,0.7141687568145048 -32.3504912853241,0.1729648113250732,590.4974267482758,2396,0,590.4974267482758,0.2954339499048799,3581,0.7335627432543284,623.1101274490356,0.2719523566109793,0.7405281748090472,0.2936053429014842,3554,0.7164755903339547 -36.397353172302246,0.1988835334777832,670.5267918109894,2742,0,670.5267918109894,0.2942995925217292,3581,0.7319510469753561,707.2261955738068,0.271491391318185,0.7381634031023298,0.2926420384601857,3554,0.7148019836803602 -40.44942855834961,0.2239551544189453,750.6180312633514,3087,0,750.6180312633514,0.2932155154264695,3581,0.7362375182168039,791.4083216190338,0.2699452127729143,0.7429150853838239,0.2915490386052687,3554,0.7192059260912 -44.49841260910034,0.2509458065032959,830.8124248981476,3429,0,830.8124248981476,0.2926150153915631,3581,0.7352532517540491,875.6934387683868,0.2693698746817453,0.7420486041477748,0.2910673863846194,3554,0.7180215624120709 -48.5478253364563,0.2797403335571289,910.9857234954834,3776,0,910.9857234954834,0.2917920208216978,3581,0.7350357682037141,959.9583053588868,0.2686145646231515,0.7421434947422573,0.2901438217721229,3554,0.7178022892339617 -52.598410844802856,0.3055357933044433,991.1750037670135,4122,0,991.1750037670135,0.2922372485099309,3581,0.7351968014782881,1044.237408399582,0.2698146786008562,0.7409387997218541,0.2908186432211065,3554,0.717972514464336 -56.64663577079773,0.3317186832427978,1071.141833782196,4465,0,1071.141833782196,0.2911474445751012,3581,0.7380251784592292,1128.2922093868256,0.2674111298152378,0.7455438886369977,0.2895034850146841,3554,0.7208618782313942 -60.70039987564087,0.3574385643005371,1151.3040285110474,4811,0,1151.3040285110474,0.291878093856901,3581,0.7367109369545867,1212.547758102417,0.2683026790618896,0.7439666475568499,0.2903159361041256,3554,0.7196454341718838 -64.746572971344,0.3820641040802002,1231.4249053001404,5157,0,1231.4249053001404,0.2899794420094771,3581,0.7389520401729615,1296.7527680397034,0.2668769189289638,0.7457617350987026,0.2885046311154685,3554,0.721741031781971 -68.8025426864624,0.4079747200012207,1311.5613470077517,5500,0,1311.5613470077517,0.2923064478214011,3581,0.7373460025568975,1380.9851520061493,0.2686341660363333,0.7450404848371234,0.2907542763765299,3554,0.7202586022131753 -72.85080456733704,0.4369354248046875,1391.7456901073456,5846,0,1391.7456901073456,0.2911164582824281,3581,0.7389854467371195,1465.259878873825,0.2673809358051845,0.746405805860247,0.2896439654812359,3554,0.7218080090215251 -76.90266561508179,0.462430477142334,1471.8602805137634,6191,0,1471.8602805137634,0.2897505047800021,3581,0.7380078615871963,1549.4647357463837,0.2664087159293039,0.7452504975455148,0.2882869379132667,3554,0.7206758532419457 -80.94577598571777,0.4891648292541504,1552.0047433376312,6532,0,1552.0047433376312,0.2895002282554628,3581,0.7375512824848157,1633.692532300949,0.2663153750555856,0.7444896016802106,0.2880332315645663,3554,0.7202943921022088 -84.99467325210571,0.5149352550506592,1632.0608353614807,6877,0,1632.0608353614807,0.29036252667071,3581,0.7384655315161617,1717.8377559185028,0.2664291858673095,0.746187550680978,0.2888295565955613,3554,0.721219777121729 -89.0437970161438,0.541454553604126,1712.099158525467,7222,0,1712.099158525467,0.2894576860186575,3581,0.7394924083923834,1801.9656219482424,0.266190699168614,0.7466784885951451,0.2880130353507931,3554,0.7223175169043683 -93.09405064582825,0.5671203136444092,1792.2901685237885,7566,0,1792.2901685237885,0.2897122917612049,3581,0.7396306024853393,1886.2462952137,0.2662081207547869,0.7469291005815778,0.2881514034582512,3554,0.722539606561269 -97.14030265808104,0.5930294990539551,1872.4543488025663,7912,0,1872.4543488025663,0.28995356896642,3581,0.7389450179768221,1970.496957540512,0.2667276688984462,0.7459709303719657,0.2883923497841341,3554,0.7218876947629431 -101.19419121742249,0.6205320358276367,1952.4229590892792,8257,0,1952.4229590892792,0.2885494706218584,3581,0.7408515101403239,2054.5611073970795,0.2652080740247454,0.7481414931161063,0.2870385852098867,3554,0.7236107615319006 -105.19413876533508,0.648857593536377,2032.4846720695496,8601,0,2032.4846720695496,0.2882610151668528,3581,0.7406954537620427,2138.6657240390778,0.2648976189749581,0.7478885650634766,0.286846188795855,3554,0.7233566601892234 -109.24890971183775,0.6749727725982666,2112.54460477829,8947,0,2112.54460477829,0.2884022090329168,3581,0.7395091116744624,2222.8213012218475,0.2646905183792114,0.7472258976527623,0.2870105062902275,3554,0.7222594699634215 -113.30114507675172,0.7012271881103516,2192.608017683029,9291,0,2192.608017683029,0.2882008492630376,3581,0.739634556731709,2306.978061437607,0.2647973639624459,0.7469348226274762,0.2868361937308754,3554,0.7223313245199071 -117.35312056541444,0.7275919914245605,2272.6760840415955,9634,0,2272.6760840415955,0.2884851459416888,3581,0.7401850832737015,2391.1396033763885,0.265298673084804,0.747213499886649,0.2870020053328995,3554,0.7230594186348129 -121.39879179000854,0.7551271915435791,2352.786567211151,9977,0,2352.786567211151,0.2881633180108559,3581,0.7410903329857232,2475.3380110263824,0.2640678882598877,0.7491606984819684,0.2867365006858469,3554,0.7238872573157006 -125.44814491271973,0.7824423313140869,2432.979870557785,10323,0,2432.979870557785,0.2881919181203748,3581,0.7409134145490086,2559.622112035752,0.2644612789154053,0.7485763686043876,0.286777992227068,3554,0.7236826160883864 -129.49608874320984,0.8082091808319092,2512.958654642105,10668,0,2512.958654642105,0.2879405848575816,3581,0.740052547843654,2643.6903455257416,0.2646090303148542,0.7473513739449638,0.2865494796246307,3554,0.7227766716419879 -133.54543042182922,0.8410682678222656,2593.1541748046875,11011,0,2593.1541748046875,0.2879999326414584,3581,0.740695862822012,2727.9833042621613,0.2639562232153756,0.7484585217067173,0.2865472813972812,3554,0.7235746968644485 -137.59311985969543,0.8715968132019043,2673.2036843299866,11356,0,2673.2036843299866,0.2886352027737538,3581,0.7422605853811785,2812.125711202621,0.264902012688773,0.7498021806989398,0.2872951767319833,3554,0.7250792460959482 -141.6435444355011,0.89859938621521,2753.203345775604,11700,0,2753.203345775604,0.2885399940659033,3581,0.7410890376291539,2896.2168004512787,0.2647356305803571,0.7487577029636928,0.2870766248472232,3554,0.7239579440639069 -145.69473385810852,0.925682783126831,2833.208889245987,12041,0,2833.208889245987,0.2889885624105522,3581,0.7389065663397095,2980.31423330307,0.2650294133595058,0.746711186000279,0.2875592387923378,3554,0.721740413530529 -149.74429845809937,0.9543476104736328,2913.327432394028,12386,0,2913.327432394028,0.2881583070262322,3581,0.7407446091350182,3064.5256729125977,0.2644084010805402,0.7485940115792411,0.2867051415988147,3554,0.7234813408967009 -153.79073357582092,0.9809386730194092,2993.439972639084,12732,0,2993.439972639084,0.2877276350552394,3581,0.7404874467676626,3148.726415157318,0.2642101390021188,0.7478584562029157,0.2863673874012996,3554,0.7233078870199071 -157.837087392807,1.010098218917847,3073.533169031143,13073,0,3073.533169031143,0.287606382862678,3581,0.741620951942544,3232.910295248032,0.2635124070303781,0.7499058587210519,0.2861709208319411,3554,0.7244298073033906 -161.88950777053833,1.0378003120422363,3153.508836746216,13418,0,3153.508836746216,0.2875869525141371,3581,0.7418854773893465,3316.9802169799805,0.263601439339774,0.7499217305864606,0.2861146084297622,3554,0.7247510232748312 -165.9377157688141,1.0656042098999023,3233.6690809726715,13764,0,3233.6690809726715,0.287572669503543,3581,0.7428763569882715,3401.2306904792786,0.2637285164424351,0.7505524499075753,0.2861339459609771,3554,0.7258133166414603 -169.98314476013184,1.094580888748169,3313.6374888420105,14107,0,3313.6374888420105,0.2878133331188041,3581,0.7414161492512567,3485.28772354126,0.2639488322394235,0.74928161076137,0.286385351040421,3554,0.7241858040675999 -174.02948689460754,1.1231722831726074,3393.6883749961853,14452,0,3393.6883749961853,0.2880063412476438,3581,0.7408737357319882,3569.427688598633,0.2637116909027099,0.7491108349391392,0.286574501634382,3554,0.7236119980347847 -178.07557320594788,1.1510798931121826,3473.860435962677,14797,0,3473.860435962677,0.2877132838679838,3581,0.7409820684471865,3653.68798494339,0.2637397732053484,0.7490803854806083,0.2863007192874666,3554,0.7236975915122046 -182.1272554397583,1.1790404319763184,3554.036841392517,15139,0,3554.036841392517,0.2873795591097109,3581,0.7423500331611281,3737.95885014534,0.2634006227765764,0.7502944810049874,0.2859343709607572,3554,0.7251582448913196 -186.17304944992063,1.210641384124756,3634.21212887764,15484,0,3634.21212887764,0.2882788433638474,3581,0.7405888936400447,3822.227237462997,0.2640561206000192,0.7487557274954659,0.2868814978226558,3554,0.7233342657481008 -190.22082328796387,1.2392685413360596,3714.2076604366302,15828,0,3714.2076604366302,0.2873993303415596,3581,0.7421144146188215,3906.3140411376953,0.2632385492324829,0.7502681868416923,0.2859557349828098,3554,0.724986439685038 -194.2671341896057,1.2684621810913086,3794.377538204193,16173,0,3794.377538204193,0.287493789106133,3581,0.741831072413432,3990.574172496796,0.2634538071496146,0.7497570855276925,0.2861162914475766,3554,0.724560327052265 -198.3175041675568,1.296675205230713,3874.452456474304,16521,0,3874.452456474304,0.2873837178860653,3581,0.742643260982442,4074.742446660996,0.2628585611070905,0.7512686593191964,0.285953914575786,3554,0.7255132586082583 -202.363365650177,1.3254055976867676,3954.5216364860535,16869,0,3954.5216364860535,0.287424862501309,3581,0.7417947342528274,4158.901402235031,0.2634029899324689,0.7496285438537598,0.2859665715566967,3554,0.7246958615072805 -206.41446828842163,1.353689670562744,4034.547282218933,17212,0,4034.547282218933,0.2874083637492146,3581,0.7424672970189891,4243.02034330368,0.2632824352809361,0.7505200249808175,0.2859747462146525,3554,0.7253164485658765 -210.46268796920776,1.3817291259765625,4114.6579785346985,17557,0,4114.6579785346985,0.2870543564341315,3581,0.7418575249581123,4327.222203493118,0.2625634159360613,0.750438894544329,0.2856668741701691,3554,0.7246600029236424 -214.46094393730164,1.415083646774292,4194.84974193573,17901,0,4194.84974193573,0.2870739231359955,3581,0.7430210960407359,4411.460675477982,0.2628797122410365,0.7512284687587193,0.2856506278961645,3554,0.7259621778497819 -218.50892448425293,1.4444615840911863,4274.968697547913,18247,0,4274.968697547913,0.287193061852049,3581,0.742275856953365,4495.671122074127,0.2630462987082345,0.7504652568272182,0.2858168344921655,3554,0.7250845355805079 -222.5591013431549,4.762778997421265,4351.688469409943,18574,0,4351.688469409943,0.287297372144216,3581,0.7420176037594247,4579.772929430008,0.2625848565782819,0.7509215899876186,0.2858960221976998,3554,0.7248626520074212 -226.6106996536255,4.797757148742676,4431.760053873062,18917,0,4431.760053873062,0.2873650374808014,3581,0.7426776220198618,4663.943806886673,0.263042824608939,0.7508281299046108,0.2859329798950126,3554,0.7255474685213844 -230.6644184589386,4.827903747558594,4511.818215370178,19261,0,4511.818215370178,0.2874655980565833,3581,0.7427368675387461,4748.09862780571,0.2631254366465977,0.7510519027709961,0.2860428740888348,3554,0.725652502571926 -234.71754837036133,4.8576250076293945,4591.8331344127655,19602,0,4591.8331344127655,0.2871664047773841,3581,0.7420789627548171,4832.209055185318,0.2624681336539132,0.7507716587611607,0.285777455310038,3554,0.7249418568866066 -238.76827216148376,4.890393733978272,4671.805415391922,19946,0,4671.805415391922,0.2871128179214081,3581,0.7427063925710347,4916.277668714523,0.2625590392521449,0.7512575558253697,0.2856975463111547,3554,0.7255555744847355 -242.82312202453613,4.920037508010864,4751.956250667572,20291,0,4751.956250667572,0.2879711961786163,3581,0.7410374960730243,5000.525846481323,0.2636475052152361,0.7494308607918876,0.286644346873681,3554,0.7237927335396737 -246.8783664703369,4.950624942779541,4832.079285860062,20632,0,4832.079285860062,0.2880584282170658,3581,0.7424082560300893,5084.747328519821,0.2634527172361101,0.7509378705705915,0.2865787263525693,3554,0.7253593139991911 -250.92867803573608,4.980751514434815,4912.069168806076,20977,0,4912.069168806076,0.2870752525808957,3581,0.7423472379180047,5168.830347061157,0.2624333586011614,0.7510521071297782,0.2856248502457618,3554,0.7251787158835116 -254.97441744804385,5.011352777481079,4992.18848824501,21320,0,4992.18848824501,0.2872758624074979,3581,0.7418286862302779,5253.038994789124,0.2626627002443586,0.7505565370832171,0.2858772513969734,3554,0.7246191983284679 -259.02623534202576,5.041996479034424,5072.292250871658,21662,0,5072.292250871658,0.2869897249589849,3581,0.7420900073739877,5337.238107681274,0.2624525683266775,0.7506332397460938,0.2855532017730902,3554,0.724930934444464 -263.08213925361633,5.072854518890381,5152.445718765259,22007,0,5152.445718765259,0.2871232830389556,3581,0.7414707587571558,5421.491142511368,0.2623602492468698,0.7502215249197823,0.2856884099287334,3554,0.7243254601988957 -267.1392109394073,5.108336448669434,5232.639133930206,22354,0,5232.639133930206,0.2872255480312762,3581,0.7433244140079587,5505.790293693543,0.2626238891056606,0.7519216537475586,0.2857978232603229,3554,0.7262842868510833 -271.1893949508667,5.138692140579224,5312.724345445633,22696,0,5312.724345445633,0.2871369865479266,3581,0.7426496014119659,5589.968581676483,0.2624912091663905,0.7513668196541923,0.2857631324849641,3554,0.7254721792346651 -275.2423415184021,5.170776605606079,5392.689038276672,23038,0,5392.689038276672,0.2870495499794924,3581,0.743064865457449,5674.031042814255,0.2620306015014648,0.7522259439740863,0.2856072816006172,3554,0.7259720698728546 -279.2945177555084,5.201882123947144,5472.827522277832,23383,0,5472.827522277832,0.2870913763613515,3581,0.7427012793214186,5758.265685558319,0.2623745373317173,0.7514401844569615,0.2856796341929867,3554,0.7255796862909749 -283.3469295501709,5.232615470886231,5552.985707521439,23728,0,5552.985707521439,0.2872230254947989,3581,0.7428281560885577,5842.519861698151,0.2625692401613508,0.7514640944344657,0.2858105489358381,3554,0.725731570061902 -287.3980450630188,5.264841794967651,5633.126035690308,24072,0,5633.126035690308,0.2869076402584822,3581,0.7427041427412036,5926.756481170654,0.2617558411189488,0.7519587108067104,0.285483957611582,3554,0.7255711681599958 -291.4509608745575,5.295077800750732,5713.109417676926,24417,0,5713.109417676926,0.2869837935894303,3581,0.741841162559341,6010.835768699646,0.262204783303397,0.7507854189191546,0.285589644260868,3554,0.7247238202113815 -295.5001335144043,5.32674765586853,5793.201881885529,24762,0,5793.201881885529,0.2872373425937238,3581,0.7423777810623778,6095.022074460983,0.2625353336334228,0.7509933880397252,0.2858242535094699,3554,0.7253233180263435 -299.5552673339844,5.357968091964722,5873.248897790909,25102,0,5873.248897790909,0.2869809642579761,3581,0.7422795384930885,6179.167838096619,0.2616557223456247,0.751718384878976,0.2855815211238657,3554,0.7251388043181978 -303.6081876754761,5.389884471893311,5953.253943920136,25447,0,5953.253943920136,0.2870431754616378,3581,0.7427448442081471,6263.270909070969,0.262043969971793,0.7519270351954869,0.2856289547483909,3554,0.7256108736414955 -307.6641731262207,5.420477151870728,6033.245709180832,25793,0,6033.245709180832,0.2870635602834404,3581,0.7423027185580146,6347.361993312836,0.2622495889663696,0.7512048993791852,0.2856707038943796,3554,0.7252663015044668 -311.72106552124023,5.456787824630737,6113.267695903778,26136,0,6113.267695903778,0.2872991106490855,3581,0.7421742737276599,6431.490020036697,0.2618169954844883,0.7519177028111049,0.2858858553962085,3554,0.7250756052819006 -315.7747297286988,5.48795223236084,6193.231685638428,26479,0,6193.231685638428,0.2869440465957484,3581,0.7427881363882295,6515.551483869553,0.2618359157017299,0.7520029204232352,0.2855081381124261,3554,0.7256874681257034 -319.8229854106903,5.520106554031372,6273.216883420944,26824,0,6273.216883420944,0.2869271046953539,3581,0.7426322163632715,6599.63019657135,0.2619940042495727,0.7516803741455078,0.2855152308303584,3554,0.7255216393500281 -323.87314200401306,5.551224231719971,6353.233630180359,27166,0,6353.233630180359,0.2870705824795797,3581,0.7424758191016825,6683.7409517765045,0.2619033200400216,0.7516687938145229,0.2856453899325584,3554,0.7254074689170653 -327.92193126678467,5.58203649520874,6433.285531997681,27512,0,6433.285531997681,0.2869271046953539,3581,0.742613467781346,6767.885113239288,0.2616982460021972,0.7520516259329659,0.2854895047009091,3554,0.7255234254097496 -331.9698152542114,5.613497018814087,6513.400120258331,27856,0,6513.400120258331,0.2869478304004642,3581,0.7426463289322117,6852.091889381409,0.2618895769119262,0.7518525123596191,0.2855585599522545,3554,0.7255723359682752 -336.02450037002563,5.647374629974365,6593.375950098038,28197,0,6593.375950098038,0.2871377705795343,3581,0.7426058319952528,6936.169029951096,0.2619935785021101,0.7518358911786761,0.285708434405995,3554,0.725502679639139 -340.0794105529785,5.67836856842041,6673.345023870468,28539,0,6673.345023870468,0.2870893310615052,3581,0.7427282091027296,7020.236711502075,0.2616695165634155,0.7522450855800084,0.2856231500542962,3554,0.725679843024585 -344.1304316520691,5.709791898727417,6753.366040945053,28885,0,6753.366040945053,0.2870189386584578,3581,0.7426622822710137,7104.3532173633575,0.2618741648537772,0.7519869804382324,0.2856465405671866,3554,0.7255535823412 -348.1815276145935,5.743354797363281,6833.333642244339,29226,0,6833.333642244339,0.2869982470416783,3581,0.7426421019791958,7188.418164491653,0.261786869594029,0.7519833700997489,0.2855797007168419,3554,0.7255307070378447 -352.23563385009766,5.774363517761231,6913.417159795761,29570,0,6913.417159795761,0.2870593674187552,3581,0.7426630322142908,7272.599680185318,0.2614668778010777,0.7523882048470634,0.2855992099845684,3554,0.7256052406839125 -356.2854199409485,5.811537981033325,6993.5478773117065,29916,0,6993.5478773117065,0.28703622144216,3581,0.7427758645891511,7356.830544710159,0.2617281675338745,0.7522600037711007,0.2856097374327342,3554,0.7256939941131472 -360.3349132537842,5.845282316207886,7073.549101829529,30258,0,7073.549101829529,0.286973158030229,3581,0.7426906437622173,7440.927565097809,0.2616961683545794,0.7521281242370605,0.2855764377231201,3554,0.7255890287572102 -364.3896791934967,5.879350185394287,7153.538716077805,30602,0,7153.538716077805,0.2869701241687901,3581,0.7428522224500838,7525.018535852432,0.2613735369273594,0.7526561873299735,0.2855524633060899,3554,0.7257843962128939 -368.44449734687805,5.916398048400879,7233.716419696808,30947,0,7233.716419696808,0.2870090189542027,3581,0.742683007976124,7609.30117225647,0.2616010393415178,0.7522655895778111,0.2855880814586117,3554,0.7255678708189716 -372.5036156177521,5.948800563812256,7313.770256280899,31291,0,7313.770256280899,0.287009871162472,3581,0.7427899771580914,7693.459502458572,0.2616640159061977,0.7522608212062291,0.2856347594424855,3554,0.7256755152644907 -376.55426692962646,5.986757755279541,7393.891540288925,31633,0,7393.891540288925,0.2870043829412175,3581,0.7427310725225147,7777.6818697452545,0.2613063539777483,0.7526048932756696,0.2855985230385217,3554,0.7256450148600169 -380.60776138305664,6.020724773406982,7473.958404064178,31978,0,7473.958404064178,0.2869842367377304,3581,0.7428728799785326,7861.848997831345,0.2615161623273577,0.752514089856829,0.2855669235203731,3554,0.725779793674381 -384.6572859287262,6.052553415298462,7553.929656028748,32323,0,7553.929656028748,0.2870202340150272,3581,0.742837428114528,7945.914715051651,0.261630654335022,0.7524120467049735,0.2856301053830191,3554,0.7257753972196821 -388.7059483528137,6.086850643157959,7633.9339427948,32664,0,7633.9339427948,0.2869662721874127,3581,0.7427972720608769,8030.01468038559,0.2613300766263689,0.7525908606392997,0.2855583195211381,3554,0.7257525906109313 -392.7595648765564,6.119038105010986,7714.073157310486,33009,0,7714.073157310486,0.2869116626815135,3581,0.742825701728742,8114.253061294556,0.2613956247057233,0.7525243759155273,0.2854895218745603,3554,0.7257540331976294 -396.8049416542053,6.152233600616455,7794.181841135025,33353,0,7794.181841135025,0.2869396832894094,3581,0.7428152025228637,8198.453286409378,0.2614794969558716,0.752453054700579,0.2855558808626723,3554,0.7257225710686902 -400.8564488887787,6.191194772720337,7874.324627399445,33695,0,7874.324627399445,0.2868952321060807,3581,0.7426832125061086,8282.699506282806,0.2613553319658552,0.7523847988673619,0.2854971984966323,3554,0.7255684890704136 -404.9108922481537,6.224079132080078,7954.292649030685,34037,0,7954.292649030685,0.28690225430222,3581,0.7429591234553895,8366.768176078796,0.2613096748079572,0.7527414730616978,0.2854857780186058,3554,0.725891834574599 -408.96024656295776,6.263574361801148,8034.458181381226,34380,0,8034.458181381226,0.2868850056068486,3581,0.7428880833740575,8451.035452127457,0.261373656136649,0.7526027134486607,0.2854935405089336,3554,0.7257960056010833 -413.01319456100464,6.295818567276001,8114.476765394211,34723,0,8114.476765394211,0.2869422058258866,3581,0.7429870758866238,8535.151907682419,0.2613887275968279,0.7526653153555733,0.2855344824933173,3554,0.7258881250659468 -417.0644700527191,6.32902193069458,8194.555145740509,35068,0,8194.555145740509,0.2868667683498848,3581,0.7428655168990854,8619.32661819458,0.2612641368593488,0.7526427677699498,0.2854579567037141,3554,0.7257840527398706 -421.1167623996735,6.363311767578125,8274.71585559845,35411,0,8274.71585559845,0.2868662911132539,3581,0.7429296029609397,8703.585729598999,0.2612950801849365,0.7526873179844448,0.2854678315531355,3554,0.7258443666027715 -425.1749198436737,6.397544860839844,8354.723905563354,35754,0,8354.723905563354,0.2868410998368123,3581,0.743001665692195,8787.698282241821,0.2612860543387277,0.7527467863900321,0.2854532682969453,3554,0.7259219915060495 -429.2230050563812,6.4317402839660645,8434.793033599854,36099,0,8434.793033599854,0.2868526557809445,3581,0.7429603506352974,8871.863282203674,0.261279889515468,0.7527127265930176,0.2854581971348304,3554,0.7258752104802687 -433.2783992290497,6.466166973114014,8453.836063861847,36189,0,8453.836063861847,0.2868524853392907,3581,0.7429589871020664,8895.000623464584,0.2612797021865845,0.7527111598423549,0.28545797387736527,3554,0.7258738365881753 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/measurements.csv deleted file mode 100644 index e359653a2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.989067,0.812358,,,,,,,,,,,,,, -1,,,0.2794570582253592,0.8154546192714146,0.2738082104341059,0.8275999671365011,3554.0,0.2972505852284627,0.8250461419645351,3581.0,29.85322141647339,33.85292458534241,29.85322141647339,3.9995675086975098,0.0,0.0 -100,0.4719467,0.3213598,,,,,,,,,,,,,, -200,0.2393438,0.38450125,,,,,,,,,,,,,, -300,0.12403842,0.33800998,,,,,,,,,,,,,, -340,,,0.6971728461129325,0.3081912994384765,0.672379355787493,0.3307215873619513,3554.0,0.6910679574054035,0.3325069185676138,3581.0,109.82941102981567,117.90757751464844,109.82941102981567,8.045104265213013,0.0208237171173095,0.0 -400,0.1686696,0.3214888,,,,,,,,,,,,,, -500,0.23209704,0.22575936,,,,,,,,,,,,,, -600,0.1154517,0.40880114,,,,,,,,,,,,,, -681,,,0.7214180401393345,0.2886877400534494,0.6979783315234594,0.3102809581853193,3554.0,0.7154252969229964,0.312450026507086,3581.0,189.9423379898072,202.10964155197144,189.9423379898072,12.096370697021484,0.0461506843566894,0.0 -700,0.41298154,0.31055477,,,,,,,,,,,,,, -800,0.12411045,0.21450521,,,,,,,,,,,,,, -900,0.11841275,0.3212934,,,,,,,,,,,,,, -1000,0.24094768,0.28615832,,,,,,,,,,,,,, -1018,,,0.7268994195120675,0.2825293030057634,0.7041054780924663,0.3032789171312254,3554.0,0.7213962089412873,0.3053255653754189,3581.0,270.07433581352234,286.3332691192627,270.07433581352234,16.14963984489441,0.0723109245300293,0.0 -1100,0.11686276,0.31195983,,,,,,,,,,,,,, -1200,0.21953504,0.22400066,,,,,,,,,,,,,, -1300,0.093312524,0.3991742,,,,,,,,,,,,,, -1362,,,0.7332515716552734,0.2778271777289254,0.7102090624120709,0.2989388263554269,3554.0,0.7271120038920692,0.3010481616299567,3581.0,350.1551387310028,370.4994876384735,350.1551387310028,20.19764614105225,0.0971243381500244,0.0 -1400,0.22210364,0.2259044,,,,,,,,,,,,,, -1500,0.23130155,0.28937796,,,,,,,,,,,,,, -1600,0.3371049,0.22654346,,,,,,,,,,,,,, -1700,0.15217644,0.40093854,,,,,,,,,,,,,, -1708,,,0.7331899915422712,0.278896621295384,0.7097868653717642,0.3001783518021068,3554.0,0.726705943695895,0.3023199290744555,3581.0,430.2985694408417,454.72812390327454,430.2985694408417,24.24586677551269,0.1208417415618896,0.0 -1800,0.2148821,0.25857124,,,,,,,,,,,,,, -1900,0.05965836,0.2260209,,,,,,,,,,,,,, -2000,0.1282169,0.20950496,,,,,,,,,,,,,, -2051,,,0.7381429672241211,0.2731998818261282,0.7141687568145048,0.2945284609990504,3554.0,0.7314937179296984,0.2962755226150168,3581.0,510.4877436161041,539.0076286792755,510.4877436161041,28.29649257659912,0.1473636627197265,0.0 -2100,0.11555115,0.22700912,,,,,,,,,,,,,, -2200,0.1133446,0.27220672,,,,,,,,,,,,,, -2300,0.3295565,0.30653772,,,,,,,,,,,,,, -2396,,,0.7405281748090472,0.2719523566109793,0.7164755903339547,0.2936053429014842,3554.0,0.7335627432543284,0.2954339499048799,3581.0,590.4974267482758,623.1101274490356,590.4974267482758,32.3504912853241,0.1729648113250732,0.0 -2400,0.18383986,0.2631085,,,,,,,,,,,,,, -2500,0.27582592,0.21268928,,,,,,,,,,,,,, -2600,0.18272904,0.30024198,,,,,,,,,,,,,, -2700,0.04469243,0.2866637,,,,,,,,,,,,,, -2742,,,0.7381634031023298,0.271491391318185,0.7148019836803602,0.2926420384601857,3554.0,0.7319510469753561,0.2942995925217292,3581.0,670.5267918109894,707.2261955738068,670.5267918109894,36.397353172302246,0.1988835334777832,0.0 -2800,0.20861852,0.33624044,,,,,,,,,,,,,, -2900,0.15674224,0.23579445,,,,,,,,,,,,,, -3000,0.09953926,0.33320078,,,,,,,,,,,,,, -3087,,,0.7429150853838239,0.2699452127729143,0.7192059260912,0.2915490386052687,3554.0,0.7362375182168039,0.2932155154264695,3581.0,750.6180312633514,791.4083216190338,750.6180312633514,40.44942855834961,0.2239551544189453,0.0 -3100,0.29081804,0.19719566,,,,,,,,,,,,,, -3200,0.13262442,0.393619,,,,,,,,,,,,,, -3300,0.09952576,0.21097714,,,,,,,,,,,,,, -3400,0.03736325,0.318415,,,,,,,,,,,,,, -3429,,,0.7420486041477748,0.2693698746817453,0.7180215624120709,0.2910673863846194,3554.0,0.7352532517540491,0.2926150153915631,3581.0,830.8124248981476,875.6934387683868,830.8124248981476,44.49841260910034,0.2509458065032959,0.0 -3500,0.06555342,0.31682974,,,,,,,,,,,,,, -3600,0.06420254,0.30515343,,,,,,,,,,,,,, -3700,0.17844121,0.3300063,,,,,,,,,,,,,, -3776,,,0.7421434947422573,0.2686145646231515,0.7178022892339617,0.2901438217721229,3554.0,0.7350357682037141,0.2917920208216978,3581.0,910.9857234954834,959.9583053588868,910.9857234954834,48.5478253364563,0.2797403335571289,0.0 -3800,0.048436712,0.23526809,,,,,,,,,,,,,, -3900,0.110705644,0.28619754,,,,,,,,,,,,,, -4000,0.2575554,0.20394439,,,,,,,,,,,,,, -4100,0.114690036,0.2811556,,,,,,,,,,,,,, -4122,,,0.7409387997218541,0.2698146786008562,0.717972514464336,0.2908186432211065,3554.0,0.7351968014782881,0.2922372485099309,3581.0,991.1750037670135,1044.237408399582,991.1750037670135,52.598410844802856,0.3055357933044433,0.0 -4200,0.5871381,0.25358564,,,,,,,,,,,,,, -4300,0.15089622,0.3561511,,,,,,,,,,,,,, -4400,0.058269218,0.23204824,,,,,,,,,,,,,, -4465,,,0.7455438886369977,0.2674111298152378,0.7208618782313942,0.2895034850146841,3554.0,0.7380251784592292,0.2911474445751012,3581.0,1071.141833782196,1128.2922093868256,1071.141833782196,56.64663577079773,0.3317186832427978,0.0 -4500,0.11252828,0.22023392,,,,,,,,,,,,,, -4600,0.0857599,0.24607073,,,,,,,,,,,,,, -4700,0.23515432,0.30393407,,,,,,,,,,,,,, -4800,0.0740354,0.21254627,,,,,,,,,,,,,, -4811,,,0.7439666475568499,0.2683026790618896,0.7196454341718838,0.2903159361041256,3554.0,0.7367109369545867,0.291878093856901,3581.0,1151.3040285110474,1212.547758102417,1151.3040285110474,60.70039987564087,0.3574385643005371,0.0 -4900,0.053765193,0.4176647,,,,,,,,,,,,,, -5000,0.05537064,0.20934871,,,,,,,,,,,,,, -5100,0.09229196,0.27480707,,,,,,,,,,,,,, -5157,,,0.7457617350987026,0.2668769189289638,0.721741031781971,0.2885046311154685,3554.0,0.7389520401729615,0.2899794420094771,3581.0,1231.4249053001404,1296.7527680397034,1231.4249053001404,64.746572971344,0.3820641040802002,0.0 -5200,0.10233337,0.25414717,,,,,,,,,,,,,, -5300,0.08181685,0.26614624,,,,,,,,,,,,,, -5400,0.13191526,0.29978934,,,,,,,,,,,,,, -5500,,,0.7450404848371234,0.2686341660363333,0.7202586022131753,0.2907542763765299,3554.0,0.7373460025568975,0.2923064478214011,3581.0,1311.5613470077517,1380.9851520061493,1311.5613470077517,68.8025426864624,0.4079747200012207,0.0 -5500,0.26346946,0.2679282,,,,,,,,,,,,,, -5600,0.1169663,0.34175435,,,,,,,,,,,,,, -5700,0.29677436,0.26894745,,,,,,,,,,,,,, -5800,0.06612008,0.34891832,,,,,,,,,,,,,, -5846,,,0.746405805860247,0.2673809358051845,0.7218080090215251,0.2896439654812359,3554.0,0.7389854467371195,0.2911164582824281,3581.0,1391.7456901073456,1465.259878873825,1391.7456901073456,72.85080456733704,0.4369354248046875,0.0 -5900,0.058763977,0.32625937,,,,,,,,,,,,,, -6000,0.49441454,0.20993915,,,,,,,,,,,,,, -6100,0.1014076,0.27312285,,,,,,,,,,,,,, -6191,,,0.7452504975455148,0.2664087159293039,0.7206758532419457,0.2882869379132667,3554.0,0.7380078615871963,0.2897505047800021,3581.0,1471.8602805137634,1549.4647357463837,1471.8602805137634,76.90266561508179,0.462430477142334,0.0 -6200,0.038245868,0.34087774,,,,,,,,,,,,,, -6300,0.27951956,0.25480443,,,,,,,,,,,,,, -6400,0.073334605,0.3176179,,,,,,,,,,,,,, -6500,0.1686828,0.2608994,,,,,,,,,,,,,, -6532,,,0.7444896016802106,0.2663153750555856,0.7202943921022088,0.2880332315645663,3554.0,0.7375512824848157,0.2895002282554628,3581.0,1552.0047433376312,1633.692532300949,1552.0047433376312,80.94577598571777,0.4891648292541504,0.0 -6600,0.05822493,0.22843364,,,,,,,,,,,,,, -6700,0.09224364,0.2404438,,,,,,,,,,,,,, -6800,0.075596176,0.27095702,,,,,,,,,,,,,, -6877,,,0.746187550680978,0.2664291858673095,0.721219777121729,0.2888295565955613,3554.0,0.7384655315161617,0.29036252667071,3581.0,1632.0608353614807,1717.8377559185028,1632.0608353614807,84.99467325210571,0.5149352550506592,0.0 -6900,0.10683096,0.3582033,,,,,,,,,,,,,, -7000,0.1570653,0.22401248,,,,,,,,,,,,,, -7100,0.06514005,0.28798816,,,,,,,,,,,,,, -7200,0.13871975,0.23364015,,,,,,,,,,,,,, -7222,,,0.7466784885951451,0.266190699168614,0.7223175169043683,0.2880130353507931,3554.0,0.7394924083923834,0.2894576860186575,3581.0,1712.099158525467,1801.9656219482424,1712.099158525467,89.0437970161438,0.541454553604126,0.0 -7300,0.16066481,0.25120917,,,,,,,,,,,,,, -7400,0.08605038,0.317536,,,,,,,,,,,,,, -7500,0.07874159,0.23902163,,,,,,,,,,,,,, -7566,,,0.7469291005815778,0.2662081207547869,0.722539606561269,0.2881514034582512,3554.0,0.7396306024853393,0.2897122917612049,3581.0,1792.2901685237885,1886.2462952137,1792.2901685237885,93.09405064582825,0.5671203136444092,0.0 -7600,0.095227756,0.2697892,,,,,,,,,,,,,, -7700,0.22223942,0.21896905,,,,,,,,,,,,,, -7800,0.067674674,0.30192485,,,,,,,,,,,,,, -7900,0.055946417,0.2953677,,,,,,,,,,,,,, -7912,,,0.7459709303719657,0.2667276688984462,0.7218876947629431,0.2883923497841341,3554.0,0.7389450179768221,0.28995356896642,3581.0,1872.4543488025663,1970.496957540512,1872.4543488025663,97.14030265808104,0.5930294990539551,0.0 -8000,0.07197276,0.27062815,,,,,,,,,,,,,, -8100,0.121554,0.19582911,,,,,,,,,,,,,, -8200,0.040035106,0.3184112,,,,,,,,,,,,,, -8257,,,0.7481414931161063,0.2652080740247454,0.7236107615319006,0.2870385852098867,3554.0,0.7408515101403239,0.2885494706218584,3581.0,1952.4229590892792,2054.5611073970795,1952.4229590892792,101.19419121742249,0.6205320358276367,0.0 -8300,0.12480665,0.21276666,,,,,,,,,,,,,, -8400,0.12000295,0.19272673,,,,,,,,,,,,,, -8500,0.073355004,0.29337317,,,,,,,,,,,,,, -8600,0.17544955,0.28580183,,,,,,,,,,,,,, -8601,,,0.7478885650634766,0.2648976189749581,0.7233566601892234,0.286846188795855,3554.0,0.7406954537620427,0.2882610151668528,3581.0,2032.4846720695496,2138.6657240390778,2032.4846720695496,105.19413876533508,0.648857593536377,0.0 -8700,0.07549623,0.26054832,,,,,,,,,,,,,, -8800,0.16819893,0.20553625,,,,,,,,,,,,,, -8900,0.04383266,0.25377065,,,,,,,,,,,,,, -8947,,,0.7472258976527623,0.2646905183792114,0.7222594699634215,0.2870105062902275,3554.0,0.7395091116744624,0.2884022090329168,3581.0,2112.54460477829,2222.8213012218475,2112.54460477829,109.24890971183775,0.6749727725982666,0.0 -9000,0.19936219,0.23494253,,,,,,,,,,,,,, -9100,0.22758651,0.3698348,,,,,,,,,,,,,, -9200,0.16010094,0.28083375,,,,,,,,,,,,,, -9291,,,0.7469348226274762,0.2647973639624459,0.7223313245199071,0.2868361937308754,3554.0,0.739634556731709,0.2882008492630376,3581.0,2192.608017683029,2306.978061437607,2192.608017683029,113.30114507675172,0.7012271881103516,0.0 -9300,0.10185886,0.3317519,,,,,,,,,,,,,, -9400,0.107108735,0.19987023,,,,,,,,,,,,,, -9500,0.06723956,0.32763425,,,,,,,,,,,,,, -9600,0.030918373,0.29602444,,,,,,,,,,,,,, -9634,,,0.747213499886649,0.265298673084804,0.7230594186348129,0.2870020053328995,3554.0,0.7401850832737015,0.2884851459416888,3581.0,2272.6760840415955,2391.1396033763885,2272.6760840415955,117.35312056541444,0.7275919914245605,0.0 -9700,0.06024976,0.26974472,,,,,,,,,,,,,, -9800,0.06730823,0.24090938,,,,,,,,,,,,,, -9900,0.054795936,0.23104127,,,,,,,,,,,,,, -9977,,,0.7491606984819684,0.2640678882598877,0.7238872573157006,0.2867365006858469,3554.0,0.7410903329857232,0.2881633180108559,3581.0,2352.786567211151,2475.3380110263824,2352.786567211151,121.39879179000854,0.7551271915435791,0.0 -10000,0.08063174,0.26168796,,,,,,,,,,,,,, -10100,0.07595343,0.3476817,,,,,,,,,,,,,, -10200,0.076155685,0.27686292,,,,,,,,,,,,,, -10300,0.032257028,0.2741322,,,,,,,,,,,,,, -10323,,,0.7485763686043876,0.2644612789154053,0.7236826160883864,0.286777992227068,3554.0,0.7409134145490086,0.2881919181203748,3581.0,2432.979870557785,2559.622112035752,2432.979870557785,125.44814491271973,0.7824423313140869,0.0 -10400,0.05048112,0.28672627,,,,,,,,,,,,,, -10500,0.06852129,0.28524953,,,,,,,,,,,,,, -10600,0.06596575,0.27875143,,,,,,,,,,,,,, -10668,,,0.7473513739449638,0.2646090303148542,0.7227766716419879,0.2865494796246307,3554.0,0.740052547843654,0.2879405848575816,3581.0,2512.958654642105,2643.6903455257416,2512.958654642105,129.49608874320984,0.8082091808319092,0.0 -10700,0.06424837,0.20601201,,,,,,,,,,,,,, -10800,0.068441,0.30511484,,,,,,,,,,,,,, -10900,0.103666924,0.32640532,,,,,,,,,,,,,, -11000,0.14359882,0.265825,,,,,,,,,,,,,, -11011,,,0.7484585217067173,0.2639562232153756,0.7235746968644485,0.2865472813972812,3554.0,0.740695862822012,0.2879999326414584,3581.0,2593.1541748046875,2727.9833042621613,2593.1541748046875,133.54543042182922,0.8410682678222656,0.0 -11100,0.035093598,0.351827,,,,,,,,,,,,,, -11200,0.057513267,0.26635182,,,,,,,,,,,,,, -11300,0.15125534,0.31079456,,,,,,,,,,,,,, -11356,,,0.7498021806989398,0.264902012688773,0.7250792460959482,0.2872951767319833,3554.0,0.7422605853811785,0.2886352027737538,3581.0,2673.2036843299866,2812.125711202621,2673.2036843299866,137.59311985969543,0.8715968132019043,0.0 -11400,0.039822653,0.27471283,,,,,,,,,,,,,, -11500,0.10757869,0.33548564,,,,,,,,,,,,,, -11600,0.11954445,0.2078806,,,,,,,,,,,,,, -11700,,,0.7487577029636928,0.2647356305803571,0.7239579440639069,0.2870766248472232,3554.0,0.7410890376291539,0.2885399940659033,3581.0,2753.203345775604,2896.2168004512787,2753.203345775604,141.6435444355011,0.89859938621521,0.0 -11700,0.16412489,0.23720214,,,,,,,,,,,,,, -11800,0.22673523,0.21982561,,,,,,,,,,,,,, -11900,0.2838886,0.41261125,,,,,,,,,,,,,, -12000,0.059168085,0.2632023,,,,,,,,,,,,,, -12041,,,0.746711186000279,0.2650294133595058,0.721740413530529,0.2875592387923378,3554.0,0.7389065663397095,0.2889885624105522,3581.0,2833.208889245987,2980.31423330307,2833.208889245987,145.69473385810852,0.925682783126831,0.0 -12100,0.042567167,0.21280277,,,,,,,,,,,,,, -12200,0.07384232,0.20573357,,,,,,,,,,,,,, -12300,0.074657135,0.27393597,,,,,,,,,,,,,, -12386,,,0.7485940115792411,0.2644084010805402,0.7234813408967009,0.2867051415988147,3554.0,0.7407446091350182,0.2881583070262322,3581.0,2913.327432394028,3064.5256729125977,2913.327432394028,149.74429845809937,0.9543476104736328,0.0 -12400,0.05858995,0.24211933,,,,,,,,,,,,,, -12500,0.05579192,0.34303576,,,,,,,,,,,,,, -12600,0.17066398,0.34016776,,,,,,,,,,,,,, -12700,0.09616747,0.3526005,,,,,,,,,,,,,, -12732,,,0.7478584562029157,0.2642101390021188,0.7233078870199071,0.2863673874012996,3554.0,0.7404874467676626,0.2877276350552394,3581.0,2993.439972639084,3148.726415157318,2993.439972639084,153.79073357582092,0.9809386730194092,0.0 -12800,0.039047264,0.24020398,,,,,,,,,,,,,, -12900,0.10695607,0.24367997,,,,,,,,,,,,,, -13000,0.12862581,0.26275566,,,,,,,,,,,,,, -13073,,,0.7499058587210519,0.2635124070303781,0.7244298073033906,0.2861709208319411,3554.0,0.741620951942544,0.287606382862678,3581.0,3073.533169031143,3232.910295248032,3073.533169031143,157.837087392807,1.010098218917847,0.0 -13100,0.064317755,0.2743685,,,,,,,,,,,,,, -13200,0.102469556,0.19582154,,,,,,,,,,,,,, -13300,0.10370774,0.3411152,,,,,,,,,,,,,, -13400,0.044201937,0.31545693,,,,,,,,,,,,,, -13418,,,0.7499217305864606,0.263601439339774,0.7247510232748312,0.2861146084297622,3554.0,0.7418854773893465,0.2875869525141371,3581.0,3153.508836746216,3316.9802169799805,3153.508836746216,161.88950777053833,1.0378003120422363,0.0 -13500,0.11752258,0.33382848,,,,,,,,,,,,,, -13600,0.0939069,0.24003138,,,,,,,,,,,,,, -13700,0.06873524,0.28574157,,,,,,,,,,,,,, -13764,,,0.7505524499075753,0.2637285164424351,0.7258133166414603,0.2861339459609771,3554.0,0.7428763569882715,0.287572669503543,3581.0,3233.6690809726715,3401.2306904792786,3233.6690809726715,165.9377157688141,1.0656042098999023,0.0 -13800,0.07237784,0.24775602,,,,,,,,,,,,,, -13900,0.17756964,0.25120884,,,,,,,,,,,,,, -14000,0.045005977,0.27925012,,,,,,,,,,,,,, -14100,0.029808844,0.28955126,,,,,,,,,,,,,, -14107,,,0.74928161076137,0.2639488322394235,0.7241858040675999,0.286385351040421,3554.0,0.7414161492512567,0.2878133331188041,3581.0,3313.6374888420105,3485.28772354126,3313.6374888420105,169.98314476013184,1.094580888748169,0.0 -14200,0.21395752,0.22883059,,,,,,,,,,,,,, -14300,0.024783175,0.30391127,,,,,,,,,,,,,, -14400,0.059581827,0.23006421,,,,,,,,,,,,,, -14452,,,0.7491108349391392,0.2637116909027099,0.7236119980347847,0.286574501634382,3554.0,0.7408737357319882,0.2880063412476438,3581.0,3393.6883749961853,3569.427688598633,3393.6883749961853,174.02948689460754,1.1231722831726074,0.0 -14500,0.062364202,0.35662034,,,,,,,,,,,,,, -14600,0.18150733,0.346129,,,,,,,,,,,,,, -14700,0.067496546,0.2398696,,,,,,,,,,,,,, -14797,,,0.7490803854806083,0.2637397732053484,0.7236975915122046,0.2863007192874666,3554.0,0.7409820684471865,0.2877132838679838,3581.0,3473.860435962677,3653.68798494339,3473.860435962677,178.07557320594788,1.1510798931121826,0.0 -14800,0.109174475,0.25656655,,,,,,,,,,,,,, -14900,0.086373344,0.23380342,,,,,,,,,,,,,, -15000,0.04278829,0.24407175,,,,,,,,,,,,,, -15100,0.14740284,0.21622886,,,,,,,,,,,,,, -15139,,,0.7502944810049874,0.2634006227765764,0.7251582448913196,0.2859343709607572,3554.0,0.7423500331611281,0.2873795591097109,3581.0,3554.036841392517,3737.95885014534,3554.036841392517,182.1272554397583,1.1790404319763184,0.0 -15200,0.048991274,0.21732588,,,,,,,,,,,,,, -15300,0.12428231,0.19886927,,,,,,,,,,,,,, -15400,0.04690125,0.35216457,,,,,,,,,,,,,, -15484,,,0.7487557274954659,0.2640561206000192,0.7233342657481008,0.2868814978226558,3554.0,0.7405888936400447,0.2882788433638474,3581.0,3634.21212887764,3822.227237462997,3634.21212887764,186.17304944992063,1.210641384124756,0.0 -15500,0.095151566,0.26792118,,,,,,,,,,,,,, -15600,0.15535401,0.26664564,,,,,,,,,,,,,, -15700,0.102625206,0.22164482,,,,,,,,,,,,,, -15800,0.065389834,0.3015237,,,,,,,,,,,,,, -15828,,,0.7502681868416923,0.2632385492324829,0.724986439685038,0.2859557349828098,3554.0,0.7421144146188215,0.2873993303415596,3581.0,3714.2076604366302,3906.3140411376953,3714.2076604366302,190.22082328796387,1.2392685413360596,0.0 -15900,0.06870774,0.30233237,,,,,,,,,,,,,, -16000,0.04064338,0.19650392,,,,,,,,,,,,,, -16100,0.054350633,0.29287523,,,,,,,,,,,,,, -16173,,,0.7497570855276925,0.2634538071496146,0.724560327052265,0.2861162914475766,3554.0,0.741831072413432,0.287493789106133,3581.0,3794.377538204193,3990.574172496796,3794.377538204193,194.2671341896057,1.2684621810913086,0.0 -16200,0.0250038,0.28229037,,,,,,,,,,,,,, -16300,0.046306044,0.28897333,,,,,,,,,,,,,, -16400,0.03252152,0.3329029,,,,,,,,,,,,,, -16500,0.13806811,0.33994013,,,,,,,,,,,,,, -16521,,,0.7512686593191964,0.2628585611070905,0.7255132586082583,0.285953914575786,3554.0,0.742643260982442,0.2873837178860653,3581.0,3874.452456474304,4074.742446660996,3874.452456474304,198.3175041675568,1.296675205230713,0.0 -16600,0.07928823,0.23132984,,,,,,,,,,,,,, -16700,0.08943371,0.30913427,,,,,,,,,,,,,, -16800,0.12993222,0.2341345,,,,,,,,,,,,,, -16869,,,0.7496285438537598,0.2634029899324689,0.7246958615072805,0.2859665715566967,3554.0,0.7417947342528274,0.287424862501309,3581.0,3954.5216364860535,4158.901402235031,3954.5216364860535,202.363365650177,1.3254055976867676,0.0 -16900,0.032229915,0.34572744,,,,,,,,,,,,,, -17000,0.036765978,0.22094561,,,,,,,,,,,,,, -17100,0.061337136,0.25460213,,,,,,,,,,,,,, -17200,0.09698074,0.2598984,,,,,,,,,,,,,, -17212,,,0.7505200249808175,0.2632824352809361,0.7253164485658765,0.2859747462146525,3554.0,0.7424672970189891,0.2874083637492146,3581.0,4034.547282218933,4243.02034330368,4034.547282218933,206.41446828842163,1.353689670562744,0.0 -17300,0.08762988,0.27469915,,,,,,,,,,,,,, -17400,0.059884027,0.26117578,,,,,,,,,,,,,, -17500,0.022011358,0.26459506,,,,,,,,,,,,,, -17557,,,0.750438894544329,0.2625634159360613,0.7246600029236424,0.2856668741701691,3554.0,0.7418575249581123,0.2870543564341315,3581.0,4114.6579785346985,4327.222203493118,4114.6579785346985,210.46268796920776,1.3817291259765625,0.0 -17600,0.062244534,0.31507352,,,,,,,,,,,,,, -17700,0.15828022,0.2556056,,,,,,,,,,,,,, -17800,0.1095783,0.33495107,,,,,,,,,,,,,, -17900,0.043293055,0.2940228,,,,,,,,,,,,,, -17901,,,0.7512284687587193,0.2628797122410365,0.7259621778497819,0.2856506278961645,3554.0,0.7430210960407359,0.2870739231359955,3581.0,4194.84974193573,4411.460675477982,4194.84974193573,214.46094393730164,1.415083646774292,0.0 -18000,0.05962533,0.2853999,,,,,,,,,,,,,, -18100,0.03784248,0.2523638,,,,,,,,,,,,,, -18200,0.028767738,0.29941532,,,,,,,,,,,,,, -18247,,,0.7504652568272182,0.2630462987082345,0.7250845355805079,0.2858168344921655,3554.0,0.742275856953365,0.287193061852049,3581.0,4274.968697547913,4495.671122074127,4274.968697547913,218.50892448425293,1.4444615840911863,0.0 -18300,0.04952387,0.21673115,,,,,,,,,,,,,, -18400,0.07592197,0.2998343,,,,,,,,,,,,,, -18500,0.024950132,0.26574674,,,,,,,,,,,,,, -18574,,,0.7509215899876186,0.2625848565782819,0.7248626520074212,0.2858960221976998,3554.0,0.7420176037594247,0.287297372144216,3581.0,4351.688469409943,4579.772929430008,4351.688469409943,222.5591013431549,4.762778997421265,0.0 -18600,0.04947434,0.2974972,,,,,,,,,,,,,, -18700,0.05464649,0.2090697,,,,,,,,,,,,,, -18800,0.15415198,0.38461956,,,,,,,,,,,,,, -18900,0.10772902,0.2185217,,,,,,,,,,,,,, -18917,,,0.7508281299046108,0.263042824608939,0.7255474685213844,0.2859329798950126,3554.0,0.7426776220198618,0.2873650374808014,3581.0,4431.760053873062,4663.943806886673,4431.760053873062,226.6106996536255,4.797757148742676,0.0 -19000,0.031979155,0.24649788,,,,,,,,,,,,,, -19100,0.06713143,0.2434088,,,,,,,,,,,,,, -19200,0.22087568,0.19384263,,,,,,,,,,,,,, -19261,,,0.7510519027709961,0.2631254366465977,0.725652502571926,0.2860428740888348,3554.0,0.7427368675387461,0.2874655980565833,3581.0,4511.818215370178,4748.09862780571,4511.818215370178,230.6644184589386,4.827903747558594,0.0 -19300,0.072633244,0.24154927,,,,,,,,,,,,,, -19400,0.1540341,0.2690306,,,,,,,,,,,,,, -19500,0.09728107,0.2571189,,,,,,,,,,,,,, -19600,0.045731135,0.27406234,,,,,,,,,,,,,, -19602,,,0.7507716587611607,0.2624681336539132,0.7249418568866066,0.285777455310038,3554.0,0.7420789627548171,0.2871664047773841,3581.0,4591.8331344127655,4832.209055185318,4591.8331344127655,234.71754837036133,4.8576250076293945,0.0 -19700,0.08725933,0.22090869,,,,,,,,,,,,,, -19800,0.07521394,0.34433174,,,,,,,,,,,,,, -19900,0.026409363,0.24207732,,,,,,,,,,,,,, -19946,,,0.7512575558253697,0.2625590392521449,0.7255555744847355,0.2856975463111547,3554.0,0.7427063925710347,0.2871128179214081,3581.0,4671.805415391922,4916.277668714523,4671.805415391922,238.76827216148376,4.890393733978272,0.0 -20000,0.14090127,0.28730553,,,,,,,,,,,,,, -20100,0.100727834,0.3490257,,,,,,,,,,,,,, -20200,0.040857293,0.35366318,,,,,,,,,,,,,, -20291,,,0.7494308607918876,0.2636475052152361,0.7237927335396737,0.286644346873681,3554.0,0.7410374960730243,0.2879711961786163,3581.0,4751.956250667572,5000.525846481323,4751.956250667572,242.82312202453613,4.920037508010864,0.0 -20300,0.059295624,0.2675045,,,,,,,,,,,,,, -20400,0.03457206,0.26409304,,,,,,,,,,,,,, -20500,0.085356325,0.21640098,,,,,,,,,,,,,, -20600,0.051056106,0.28891796,,,,,,,,,,,,,, -20632,,,0.7509378705705915,0.2634527172361101,0.7253593139991911,0.2865787263525693,3554.0,0.7424082560300893,0.2880584282170658,3581.0,4832.079285860062,5084.747328519821,4832.079285860062,246.8783664703369,4.950624942779541,0.0 -20700,0.028595604,0.30399975,,,,,,,,,,,,,, -20800,0.037913535,0.259364,,,,,,,,,,,,,, -20900,0.02966879,0.3029432,,,,,,,,,,,,,, -20977,,,0.7510521071297782,0.2624333586011614,0.7251787158835116,0.2856248502457618,3554.0,0.7423472379180047,0.2870752525808957,3581.0,4912.069168806076,5168.830347061157,4912.069168806076,250.92867803573608,4.980751514434815,0.0 -21000,0.11175878,0.32190657,,,,,,,,,,,,,, -21100,0.028543513,0.26142156,,,,,,,,,,,,,, -21200,0.029747484,0.21883698,,,,,,,,,,,,,, -21300,0.023746777,0.23906212,,,,,,,,,,,,,, -21320,,,0.7505565370832171,0.2626627002443586,0.7246191983284679,0.2858772513969734,3554.0,0.7418286862302779,0.2872758624074979,3581.0,4992.18848824501,5253.038994789124,4992.18848824501,254.97441744804385,5.011352777481079,0.0 -21400,0.08372564,0.21419704,,,,,,,,,,,,,, -21500,0.1187755,0.20407641,,,,,,,,,,,,,, -21600,0.059449546,0.23064591,,,,,,,,,,,,,, -21662,,,0.7506332397460938,0.2624525683266775,0.724930934444464,0.2855532017730902,3554.0,0.7420900073739877,0.2869897249589849,3581.0,5072.292250871658,5337.238107681274,5072.292250871658,259.02623534202576,5.041996479034424,0.0 -21700,0.051049024,0.24732667,,,,,,,,,,,,,, -21800,0.08278699,0.23176631,,,,,,,,,,,,,, -21900,0.032487154,0.20234276,,,,,,,,,,,,,, -22000,0.117674455,0.27576184,,,,,,,,,,,,,, -22007,,,0.7502215249197823,0.2623602492468698,0.7243254601988957,0.2856884099287334,3554.0,0.7414707587571558,0.2871232830389556,3581.0,5152.445718765259,5421.491142511368,5152.445718765259,263.08213925361633,5.072854518890381,0.0 -22100,0.08645818,0.37157273,,,,,,,,,,,,,, -22200,0.18319188,0.22722301,,,,,,,,,,,,,, -22300,0.02472099,0.21899486,,,,,,,,,,,,,, -22354,,,0.7519216537475586,0.2626238891056606,0.7262842868510833,0.2857978232603229,3554.0,0.7433244140079587,0.2872255480312762,3581.0,5232.639133930206,5505.790293693543,5232.639133930206,267.1392109394073,5.108336448669434,0.0 -22400,0.024728362,0.22343455,,,,,,,,,,,,,, -22500,0.053730942,0.24747007,,,,,,,,,,,,,, -22600,0.052776746,0.2104357,,,,,,,,,,,,,, -22696,,,0.7513668196541923,0.2624912091663905,0.7254721792346651,0.2857631324849641,3554.0,0.7426496014119659,0.2871369865479266,3581.0,5312.724345445633,5589.968581676483,5312.724345445633,271.1893949508667,5.138692140579224,0.0 -22700,0.049513888,0.15753141,,,,,,,,,,,,,, -22800,0.065718226,0.30262494,,,,,,,,,,,,,, -22900,0.091320306,0.23278026,,,,,,,,,,,,,, -23000,0.07984613,0.3206231,,,,,,,,,,,,,, -23038,,,0.7522259439740863,0.2620306015014648,0.7259720698728546,0.2856072816006172,3554.0,0.743064865457449,0.2870495499794924,3581.0,5392.689038276672,5674.031042814255,5392.689038276672,275.2423415184021,5.170776605606079,0.0 -23100,0.042486783,0.27578974,,,,,,,,,,,,,, -23200,0.06155732,0.30319977,,,,,,,,,,,,,, -23300,0.039714728,0.24512824,,,,,,,,,,,,,, -23383,,,0.7514401844569615,0.2623745373317173,0.7255796862909749,0.2856796341929867,3554.0,0.7427012793214186,0.2870913763613515,3581.0,5472.827522277832,5758.265685558319,5472.827522277832,279.2945177555084,5.201882123947144,0.0 -23400,0.076726764,0.35420084,,,,,,,,,,,,,, -23500,0.12907052,0.2731906,,,,,,,,,,,,,, -23600,0.08729953,0.28009298,,,,,,,,,,,,,, -23700,0.042740084,0.24993855,,,,,,,,,,,,,, -23728,,,0.7514640944344657,0.2625692401613508,0.725731570061902,0.2858105489358381,3554.0,0.7428281560885577,0.2872230254947989,3581.0,5552.985707521439,5842.519861698151,5552.985707521439,283.3469295501709,5.232615470886231,0.0 -23800,0.074788466,0.23205577,,,,,,,,,,,,,, -23900,0.03931431,0.23148586,,,,,,,,,,,,,, -24000,0.051139742,0.22667602,,,,,,,,,,,,,, -24072,,,0.7519587108067104,0.2617558411189488,0.7255711681599958,0.285483957611582,3554.0,0.7427041427412036,0.2869076402584822,3581.0,5633.126035690308,5926.756481170654,5633.126035690308,287.3980450630188,5.264841794967651,0.0 -24100,0.065445445,0.2706221,,,,,,,,,,,,,, -24200,0.07502761,0.27261302,,,,,,,,,,,,,, -24300,0.043741588,0.29302675,,,,,,,,,,,,,, -24400,0.05096108,0.21420158,,,,,,,,,,,,,, -24417,,,0.7507854189191546,0.262204783303397,0.7247238202113815,0.285589644260868,3554.0,0.741841162559341,0.2869837935894303,3581.0,5713.109417676926,6010.835768699646,5713.109417676926,291.4509608745575,5.295077800750732,0.0 -24500,0.06540305,0.3591047,,,,,,,,,,,,,, -24600,0.036757953,0.19095692,,,,,,,,,,,,,, -24700,0.076544605,0.27867696,,,,,,,,,,,,,, -24762,,,0.7509933880397252,0.2625353336334228,0.7253233180263435,0.2858242535094699,3554.0,0.7423777810623778,0.2872373425937238,3581.0,5793.201881885529,6095.022074460983,5793.201881885529,295.5001335144043,5.32674765586853,0.0 -24800,0.040269252,0.26673937,,,,,,,,,,,,,, -24900,0.059100766,0.30092716,,,,,,,,,,,,,, -25000,0.03839438,0.2486074,,,,,,,,,,,,,, -25100,0.062432252,0.25511873,,,,,,,,,,,,,, -25102,,,0.751718384878976,0.2616557223456247,0.7251388043181978,0.2855815211238657,3554.0,0.7422795384930885,0.2869809642579761,3581.0,5873.248897790909,6179.167838096619,5873.248897790909,299.5552673339844,5.357968091964722,0.0 -25200,0.051989194,0.32549706,,,,,,,,,,,,,, -25300,0.04809449,0.251086,,,,,,,,,,,,,, -25400,0.026643917,0.2771975,,,,,,,,,,,,,, -25447,,,0.7519270351954869,0.262043969971793,0.7256108736414955,0.2856289547483909,3554.0,0.7427448442081471,0.2870431754616378,3581.0,5953.253943920136,6263.270909070969,5953.253943920136,303.6081876754761,5.389884471893311,0.0 -25500,0.020848917,0.20641074,,,,,,,,,,,,,, -25600,0.024957115,0.24793018,,,,,,,,,,,,,, -25700,0.020281471,0.26044136,,,,,,,,,,,,,, -25793,,,0.7512048993791852,0.2622495889663696,0.7252663015044668,0.2856707038943796,3554.0,0.7423027185580146,0.2870635602834404,3581.0,6033.245709180832,6347.361993312836,6033.245709180832,307.6641731262207,5.420477151870728,0.0 -25800,0.03410252,0.2701539,,,,,,,,,,,,,, -25900,0.03130162,0.19282866,,,,,,,,,,,,,, -26000,0.03390351,0.2569982,,,,,,,,,,,,,, -26100,0.049963966,0.2159216,,,,,,,,,,,,,, -26136,,,0.7519177028111049,0.2618169954844883,0.7250756052819006,0.2858858553962085,3554.0,0.7421742737276599,0.2872991106490855,3581.0,6113.267695903778,6431.490020036697,6113.267695903778,311.72106552124023,5.456787824630737,0.0 -26200,0.02514406,0.25059935,,,,,,,,,,,,,, -26300,0.04482649,0.27947298,,,,,,,,,,,,,, -26400,0.018528132,0.25892252,,,,,,,,,,,,,, -26479,,,0.7520029204232352,0.2618359157017299,0.7256874681257034,0.2855081381124261,3554.0,0.7427881363882295,0.2869440465957484,3581.0,6193.231685638428,6515.551483869553,6193.231685638428,315.7747297286988,5.48795223236084,0.0 -26500,0.046178024,0.24553692,,,,,,,,,,,,,, -26600,0.027785351,0.3141317,,,,,,,,,,,,,, -26700,0.049005765,0.33161378,,,,,,,,,,,,,, -26800,0.032343645,0.28831816,,,,,,,,,,,,,, -26824,,,0.7516803741455078,0.2619940042495727,0.7255216393500281,0.2855152308303584,3554.0,0.7426322163632715,0.2869271046953539,3581.0,6273.216883420944,6599.63019657135,6273.216883420944,319.8229854106903,5.520106554031372,0.0 -26900,0.029597161,0.19929236,,,,,,,,,,,,,, -27000,0.037024498,0.24126443,,,,,,,,,,,,,, -27100,0.1008876,0.38091534,,,,,,,,,,,,,, -27166,,,0.7516687938145229,0.2619033200400216,0.7254074689170653,0.2856453899325584,3554.0,0.7424758191016825,0.2870705824795797,3581.0,6353.233630180359,6683.7409517765045,6353.233630180359,323.87314200401306,5.551224231719971,0.0 -27200,0.04303857,0.20765238,,,,,,,,,,,,,, -27300,0.059855342,0.23382084,,,,,,,,,,,,,, -27400,0.03961594,0.21514341,,,,,,,,,,,,,, -27500,0.056981325,0.25649557,,,,,,,,,,,,,, -27512,,,0.7520516259329659,0.2616982460021972,0.7255234254097496,0.2854895047009091,3554.0,0.742613467781346,0.2869271046953539,3581.0,6433.285531997681,6767.885113239288,6433.285531997681,327.92193126678467,5.58203649520874,0.0 -27600,0.024679141,0.31587598,,,,,,,,,,,,,, -27700,0.019565886,0.40159383,,,,,,,,,,,,,, -27800,0.035132937,0.3389671,,,,,,,,,,,,,, -27856,,,0.7518525123596191,0.2618895769119262,0.7255723359682752,0.2855585599522545,3554.0,0.7426463289322117,0.2869478304004642,3581.0,6513.400120258331,6852.091889381409,6513.400120258331,331.9698152542114,5.613497018814087,0.0 -27900,0.045933187,0.20173076,,,,,,,,,,,,,, -28000,0.02289857,0.23635823,,,,,,,,,,,,,, -28100,0.1354056,0.23529157,,,,,,,,,,,,,, -28197,,,0.7518358911786761,0.2619935785021101,0.725502679639139,0.285708434405995,3554.0,0.7426058319952528,0.2871377705795343,3581.0,6593.375950098038,6936.169029951096,6593.375950098038,336.02450037002563,5.647374629974365,0.0 -28200,0.04786581,0.2822851,,,,,,,,,,,,,, -28300,0.04741522,0.2606429,,,,,,,,,,,,,, -28400,0.03137417,0.27715358,,,,,,,,,,,,,, -28500,0.022440312,0.2560407,,,,,,,,,,,,,, -28539,,,0.7522450855800084,0.2616695165634155,0.725679843024585,0.2856231500542962,3554.0,0.7427282091027296,0.2870893310615052,3581.0,6673.345023870468,7020.236711502075,6673.345023870468,340.0794105529785,5.67836856842041,0.0 -28600,0.03177177,0.33711535,,,,,,,,,,,,,, -28700,0.039988805,0.2509694,,,,,,,,,,,,,, -28800,0.027770888,0.23376228,,,,,,,,,,,,,, -28885,,,0.7519869804382324,0.2618741648537772,0.7255535823412,0.2856465405671866,3554.0,0.7426622822710137,0.2870189386584578,3581.0,6753.366040945053,7104.3532173633575,6753.366040945053,344.1304316520691,5.709791898727417,0.0 -28900,0.05221286,0.30901366,,,,,,,,,,,,,, -29000,0.05536591,0.24060552,,,,,,,,,,,,,, -29100,0.065881595,0.20149657,,,,,,,,,,,,,, -29200,0.039693426,0.27353513,,,,,,,,,,,,,, -29226,,,0.7519833700997489,0.261786869594029,0.7255307070378447,0.2855797007168419,3554.0,0.7426421019791958,0.2869982470416783,3581.0,6833.333642244339,7188.418164491653,6833.333642244339,348.1815276145935,5.743354797363281,0.0 -29300,0.030699082,0.2659401,,,,,,,,,,,,,, -29400,0.074417956,0.22029455,,,,,,,,,,,,,, -29500,0.023760634,0.19326247,,,,,,,,,,,,,, -29570,,,0.7523882048470634,0.2614668778010777,0.7256052406839125,0.2855992099845684,3554.0,0.7426630322142908,0.2870593674187552,3581.0,6913.417159795761,7272.599680185318,6913.417159795761,352.23563385009766,5.774363517761231,0.0 -29600,0.038084473,0.28221327,,,,,,,,,,,,,, -29700,0.023421524,0.3796457,,,,,,,,,,,,,, -29800,0.024789857,0.24731284,,,,,,,,,,,,,, -29900,0.02917818,0.26997185,,,,,,,,,,,,,, -29916,,,0.7522600037711007,0.2617281675338745,0.7256939941131472,0.2856097374327342,3554.0,0.7427758645891511,0.28703622144216,3581.0,6993.5478773117065,7356.830544710159,6993.5478773117065,356.2854199409485,5.811537981033325,0.0 -30000,0.02991261,0.24453184,,,,,,,,,,,,,, -30100,0.028538529,0.26537296,,,,,,,,,,,,,, -30200,0.040333867,0.24305688,,,,,,,,,,,,,, -30258,,,0.7521281242370605,0.2616961683545794,0.7255890287572102,0.2855764377231201,3554.0,0.7426906437622173,0.286973158030229,3581.0,7073.549101829529,7440.927565097809,7073.549101829529,360.3349132537842,5.845282316207886,0.0 -30300,0.04641494,0.23978621,,,,,,,,,,,,,, -30400,0.06572593,0.222418,,,,,,,,,,,,,, -30500,0.026479047,0.21086808,,,,,,,,,,,,,, -30600,0.039791666,0.2466041,,,,,,,,,,,,,, -30602,,,0.7526561873299735,0.2613735369273594,0.7257843962128939,0.2855524633060899,3554.0,0.7428522224500838,0.2869701241687901,3581.0,7153.538716077805,7525.018535852432,7153.538716077805,364.3896791934967,5.879350185394287,0.0 -30700,0.04452491,0.31171274,,,,,,,,,,,,,, -30800,0.038566116,0.31086528,,,,,,,,,,,,,, -30900,0.022643974,0.1933577,,,,,,,,,,,,,, -30947,,,0.7522655895778111,0.2616010393415178,0.7255678708189716,0.2855880814586117,3554.0,0.742683007976124,0.2870090189542027,3581.0,7233.716419696808,7609.30117225647,7233.716419696808,368.44449734687805,5.916398048400879,0.0 -31000,0.02925068,0.30868408,,,,,,,,,,,,,, -31100,0.018741835,0.25777724,,,,,,,,,,,,,, -31200,0.035895415,0.253021,,,,,,,,,,,,,, -31291,,,0.7522608212062291,0.2616640159061977,0.7256755152644907,0.2856347594424855,3554.0,0.7427899771580914,0.287009871162472,3581.0,7313.770256280899,7693.459502458572,7313.770256280899,372.5036156177521,5.948800563812256,0.0 -31300,0.023405127,0.29702824,,,,,,,,,,,,,, -31400,0.020195154,0.26024252,,,,,,,,,,,,,, -31500,0.020403598,0.19913736,,,,,,,,,,,,,, -31600,0.015043253,0.21767978,,,,,,,,,,,,,, -31633,,,0.7526048932756696,0.2613063539777483,0.7256450148600169,0.2855985230385217,3554.0,0.7427310725225147,0.2870043829412175,3581.0,7393.891540288925,7777.6818697452545,7393.891540288925,376.55426692962646,5.986757755279541,0.0 -31700,0.025730826,0.23970555,,,,,,,,,,,,,, -31800,0.05204666,0.21505408,,,,,,,,,,,,,, -31900,0.02531168,0.3218243,,,,,,,,,,,,,, -31978,,,0.752514089856829,0.2615161623273577,0.725779793674381,0.2855669235203731,3554.0,0.7428728799785326,0.2869842367377304,3581.0,7473.958404064178,7861.848997831345,7473.958404064178,380.60776138305664,6.020724773406982,0.0 -32000,0.02411378,0.18660891,,,,,,,,,,,,,, -32100,0.018083032,0.3876562,,,,,,,,,,,,,, -32200,0.033307236,0.22433661,,,,,,,,,,,,,, -32300,0.023357933,0.31017706,,,,,,,,,,,,,, -32323,,,0.7524120467049735,0.261630654335022,0.7257753972196821,0.2856301053830191,3554.0,0.742837428114528,0.2870202340150272,3581.0,7553.929656028748,7945.914715051651,7553.929656028748,384.6572859287262,6.052553415298462,0.0 -32400,0.014162642,0.3174763,,,,,,,,,,,,,, -32500,0.020277781,0.2822586,,,,,,,,,,,,,, -32600,0.015270295,0.24857916,,,,,,,,,,,,,, -32664,,,0.7525908606392997,0.2613300766263689,0.7257525906109313,0.2855583195211381,3554.0,0.7427972720608769,0.2869662721874127,3581.0,7633.9339427948,8030.01468038559,7633.9339427948,388.7059483528137,6.086850643157959,0.0 -32700,0.021566486,0.33221644,,,,,,,,,,,,,, -32800,0.016844943,0.25988975,,,,,,,,,,,,,, -32900,0.0381314,0.2254978,,,,,,,,,,,,,, -33000,0.022438321,0.26855963,,,,,,,,,,,,,, -33009,,,0.7525243759155273,0.2613956247057233,0.7257540331976294,0.2854895218745603,3554.0,0.742825701728742,0.2869116626815135,3581.0,7714.073157310486,8114.253061294556,7714.073157310486,392.7595648765564,6.119038105010986,0.0 -33100,0.020939887,0.2636215,,,,,,,,,,,,,, -33200,0.021403236,0.34729263,,,,,,,,,,,,,, -33300,0.04835814,0.25111187,,,,,,,,,,,,,, -33353,,,0.752453054700579,0.2614794969558716,0.7257225710686902,0.2855558808626723,3554.0,0.7428152025228637,0.2869396832894094,3581.0,7794.181841135025,8198.453286409378,7794.181841135025,396.8049416542053,6.152233600616455,0.0 -33400,0.01868613,0.25056672,,,,,,,,,,,,,, -33500,0.015473155,0.2513371,,,,,,,,,,,,,, -33600,0.018308625,0.22659643,,,,,,,,,,,,,, -33695,,,0.7523847988673619,0.2613553319658552,0.7255684890704136,0.2854971984966323,3554.0,0.7426832125061086,0.2868952321060807,3581.0,7874.324627399445,8282.699506282806,7874.324627399445,400.8564488887787,6.191194772720337,0.0 -33700,0.01899721,0.2850942,,,,,,,,,,,,,, -33800,0.01745108,0.29404607,,,,,,,,,,,,,, -33900,0.017115744,0.36824444,,,,,,,,,,,,,, -34000,0.01884046,0.25391468,,,,,,,,,,,,,, -34037,,,0.7527414730616978,0.2613096748079572,0.725891834574599,0.2854857780186058,3554.0,0.7429591234553895,0.28690225430222,3581.0,7954.292649030685,8366.768176078796,7954.292649030685,404.9108922481537,6.224079132080078,0.0 -34100,0.019018171,0.30027544,,,,,,,,,,,,,, -34200,0.018380485,0.33383256,,,,,,,,,,,,,, -34300,0.04278497,0.24990167,,,,,,,,,,,,,, -34380,,,0.7526027134486607,0.261373656136649,0.7257960056010833,0.2854935405089336,3554.0,0.7428880833740575,0.2868850056068486,3581.0,8034.458181381226,8451.035452127457,8034.458181381226,408.96024656295776,6.263574361801148,0.0 -34400,0.022307595,0.2687466,,,,,,,,,,,,,, -34500,0.020178154,0.26256716,,,,,,,,,,,,,, -34600,0.017825829,0.25873187,,,,,,,,,,,,,, -34700,0.02942196,0.35550895,,,,,,,,,,,,,, -34723,,,0.7526653153555733,0.2613887275968279,0.7258881250659468,0.2855344824933173,3554.0,0.7429870758866238,0.2869422058258866,3581.0,8114.476765394211,8535.151907682419,8114.476765394211,413.01319456100464,6.295818567276001,0.0 -34800,0.0156081095,0.3642138,,,,,,,,,,,,,, -34900,0.017363725,0.2258203,,,,,,,,,,,,,, -35000,0.014635308,0.29662913,,,,,,,,,,,,,, -35068,,,0.7526427677699498,0.2612641368593488,0.7257840527398706,0.2854579567037141,3554.0,0.7428655168990854,0.2868667683498848,3581.0,8194.555145740509,8619.32661819458,8194.555145740509,417.0644700527191,6.32902193069458,0.0 -35100,0.02383036,0.27059457,,,,,,,,,,,,,, -35200,0.02584832,0.22074896,,,,,,,,,,,,,, -35300,0.018186629,0.25227472,,,,,,,,,,,,,, -35400,0.022709887,0.29272977,,,,,,,,,,,,,, -35411,,,0.7526873179844448,0.2612950801849365,0.7258443666027715,0.2854678315531355,3554.0,0.7429296029609397,0.2868662911132539,3581.0,8274.71585559845,8703.585729598999,8274.71585559845,421.1167623996735,6.363311767578125,0.0 -35500,0.017899897,0.19225629,,,,,,,,,,,,,, -35600,0.021012546,0.23267533,,,,,,,,,,,,,, -35700,0.016587125,0.29624,,,,,,,,,,,,,, -35754,,,0.7527467863900321,0.2612860543387277,0.7259219915060495,0.2854532682969453,3554.0,0.743001665692195,0.2868410998368123,3581.0,8354.723905563354,8787.698282241821,8354.723905563354,425.1749198436737,6.397544860839844,0.0 -35800,0.06431034,0.26615533,,,,,,,,,,,,,, -35900,0.022891479,0.26474628,,,,,,,,,,,,,, -36000,0.016589995,0.22781391,,,,,,,,,,,,,, -36099,,,0.7527127265930176,0.261279889515468,0.7258752104802687,0.2854581971348304,3554.0,0.7429603506352974,0.2868526557809445,3581.0,8434.793033599854,8871.863282203674,8434.793033599854,429.2230050563812,6.4317402839660645,0.0 -36100,0.023470083,0.25741407,,,,,,,,,,,,,, -36189,,,0.7527111598423549,0.2612797021865845,0.7258738365881753,0.2854579738773652,3554.0,0.7429589871020664,0.2868524853392907,3581.0,8453.836063861847,8895.000623464584,8453.836063861847,433.2783992290497,6.466166973114014,0.0 -36189,,,,,,,,,,,8453.836063861847,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/eval_measurements.csv deleted file mode 100644 index e91a5a15e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.998514652252197,0.0,46.38900184631348,1,0,46.38900184631348,0.8250461419645351,3581,0.2972505852284627,50.38763356208801,0.8154546192714146,0.2794570582253592,0.8275999671365011,3554,0.2738082104341059 -8.051229238510132,0.0280277729034423,126.72829818725586,261,0,126.72829818725586,0.3541430889045483,3581,0.6680879909810458,134.8169128894806,0.3304575851985386,0.6731563976832798,0.3530414948385622,3554,0.648352799552441 -12.102293968200684,0.0650091171264648,206.8462879657745,594,0,206.8462879657745,0.3206396797987643,3581,0.705759073495532,219.03469967842105,0.2972275870186942,0.7111654281616211,0.3188105268161754,3554,0.6874991069701393 -16.150205850601196,0.0895030498504638,286.9167993068695,932,0,286.9167993068695,0.30844171604475,3581,0.7187924058834823,303.1899762153625,0.2850749833243234,0.7249079431806292,0.3062318577549064,3554,0.7014162904034187 -20.195985317230225,0.114079475402832,367.0610167980194,1282,0,367.0610167980194,0.304571906470347,3581,0.7228088975452038,387.4175651073456,0.2807219198771885,0.730029787336077,0.3024191354591657,3554,0.7057390357915377 -24.246984004974365,0.1383695602416992,447.1965122222901,1630,0,447.1965122222901,0.2986177659107966,3581,0.7293881499144792,471.6418776512146,0.2754407099315098,0.735931396484375,0.2966684353237549,3554,0.7122415296804657 -28.29820895195008,0.1634666919708252,527.3426928520203,1978,0,527.3426928520203,0.296682025959491,3581,0.7315722574438006,555.8780257701874,0.2737154960632324,0.7379615647452218,0.2948487152460256,3554,0.714338569877251 -32.347129821777344,0.1886870861053466,607.5200967788696,2323,0,607.5200967788696,0.2949648262967746,3581,0.7332982178075258,640.1423740386963,0.2714020695005144,0.7403804234095982,0.2932167031755592,3554,0.7161126080428742 -36.39129662513733,0.2129256725311279,687.5188465118408,2670,0,687.5188465118408,0.2938280145538083,3581,0.7340108002740157,724.2226903438568,0.2707600423267909,0.7409505844116211,0.2921713773762838,3554,0.7167952263294879 -40.43844723701477,0.2373144626617431,767.621922492981,3018,0,767.621922492981,0.2929147881723855,3581,0.7352949076942544,808.4106514453888,0.2698723758969988,0.7421620913914272,0.2912445497700654,3554,0.7181647219682048 -44.491214990615845,0.2660079002380371,847.7596549987793,3363,0,847.7596549987793,0.2936922748206681,3581,0.7324179207536303,892.6428055763245,0.2705576249531337,0.7382982117789132,0.291953581132175,3554,0.7156640322743739 -48.54154658317566,0.2912108898162842,927.7310743331908,3699,0,927.7310743331908,0.2921372674357721,3581,0.7347732198800964,976.7025051116944,0.2692724636622837,0.7407500403267997,0.2905863180681098,3554,0.7178452920564856 -52.58805227279663,0.3167808055877685,1007.7034420967102,4045,0,1007.7034420967102,0.2922567811234641,3581,0.736945396493647,1060.7598226070404,0.2689505474908011,0.7437830652509417,0.2908049386474747,3554,0.7196033930738253 -56.636749029159546,0.3462316989898681,1087.85382604599,4389,0,1087.85382604599,0.2913125343610374,3581,0.7368078841673066,1145.0013349056244,0.2679150785718645,0.7440434864589146,0.2897296276532604,3554,0.7194905278383511 -60.68485951423645,0.3707010746002197,1167.8536245822906,4735,0,1167.8536245822906,0.2911158787808049,3581,0.7372943928241064,1229.0866062641144,0.268069931438991,0.7440556798662458,0.2896225327645786,3554,0.7199958453503095 -64.73780488967896,0.3951351642608642,1248.007008075714,5085,0,1248.007008075714,0.291414049410081,3581,0.7373068009765079,1313.3300881385803,0.2685292107718331,0.7439850398472377,0.2899754856433771,3554,0.7199961888233328 -68.79254746437073,0.4260334968566894,1328.0173542499542,5426,0,1328.0173542499542,0.2915667310436156,3581,0.7372558048336708,1397.438951253891,0.2681744609560285,0.7444303376334054,0.2899265407375492,3554,0.7201586515633793 -72.83991551399231,0.4525535106658935,1408.1050882339478,5773,0,1408.1050882339478,0.2898276807608733,3581,0.7386466769058923,1481.6133811473846,0.2664764608655657,0.7458662986755371,0.2882950782239202,3554,0.7213260476751547 -76.8948724269867,0.4782447814941406,1488.1629683971405,6121,0,1488.1629683971405,0.2905847484990226,3581,0.7392027939341316,1565.7648272514343,0.2676864010947091,0.7458045823233468,0.289156302482678,3554,0.7220802457398354 -80.9424307346344,0.5033724308013916,1568.169510126114,6468,0,1568.169510126114,0.2906994216437447,3581,0.7396826894547612,1649.8570857048037,0.2672385488237653,0.7469580514090401,0.2890733880948403,3554,0.7227373096335116 -84.99346947669983,0.5285801887512207,1648.3162939548492,6814,0,1648.3162939548492,0.2890812485819254,3581,0.7403746825694638,1734.0930144786837,0.265754290989467,0.7474136352539062,0.2876301659716604,3554,0.7232966897993458 -89.04334568977356,0.5537066459655762,1728.470938205719,7160,0,1728.470938205719,0.2891557997613271,3581,0.740983363803756,1818.335706949234,0.26585396698543,0.7481129510062081,0.2876633111184141,3554,0.7238215852736354 -93.09039211273192,0.5837688446044922,1808.6441133022308,7506,0,1808.6441133022308,0.289021696268064,3581,0.7396961884337475,1902.599097251892,0.2659519399915422,0.7462959289550781,0.287555254505267,3554,0.7225214024910313 -97.14750504493712,0.6135847568511963,1888.6966087818143,7851,0,1888.6966087818143,0.2891605721276354,3581,0.7380947186540072,1986.7513935565948,0.2657542398997715,0.7452442305428642,0.2876985857979126,3554,0.7210275009232555 -101.14794516563416,0.6397542953491211,1968.838838338852,8201,0,1968.838838338852,0.2886514629075328,3581,0.7414120586515638,2070.93323135376,0.2652214254651751,0.7487215995788574,0.2872155940324722,3554,0.7242683749824141 -105.2003893852234,0.6659073829650879,2048.9702048301697,8547,0,2048.9702048301697,0.2884978949773981,3581,0.7406799776598716,2155.156111717224,0.2652098281042916,0.7477490561349052,0.2870953269533448,3554,0.7233854432285804 -109.25191283226012,0.6932220458984375,2129.0309269428253,8893,0,2129.0309269428253,0.2885711167118996,3581,0.7415842047219702,2239.3087198734283,0.2646728924342564,0.7493375369480678,0.2870908102830877,3554,0.7244903959447102 -113.3003273010254,0.718714714050293,2209.164315223694,9243,0,2209.164315223694,0.2882860700899713,3581,0.7409059832929,2323.52946972847,0.2648956435067313,0.7482168333871024,0.2868886935824951,3554,0.7236484061752603 -117.35204768180849,0.7489063739776611,2289.1893439292908,9587,0,2289.1893439292908,0.2883126248996439,3581,0.7408063090137182,2407.648950576782,0.2649180889129638,0.7480886323111398,0.2869371919733927,3554,0.723539319143043 -121.4030601978302,0.7770266532897949,2369.2818806171417,9931,0,2369.2818806171417,0.2880700523378595,3581,0.7419528359309551,2491.8336610794067,0.2642585379736764,0.7497763633728027,0.2866745038051315,3554,0.7247282853606851 -125.45101857185364,0.8046424388885498,2449.29647898674,10278,0,2449.29647898674,0.2882836498184865,3581,0.7403298905028274,2575.937170982361,0.2647982495171683,0.7476978983197894,0.2868940517616594,3554,0.7231131378156654 -129.50124597549438,0.8316726684570312,2529.2785215377808,10622,0,2529.2785215377808,0.2890171284317404,3581,0.7399198760602834,2660.009659767151,0.2655226673398699,0.7473840713500977,0.2875874379275552,3554,0.7227836784916644 -133.54821157455444,0.8579127788543701,2609.270174264908,10964,0,2609.270174264908,0.2882112802922542,3581,0.7413440183433399,2744.0874071121216,0.2642334188733782,0.7493657384599958,0.2867996653748417,3554,0.7241476098674029 -137.60043573379517,0.8868420124053955,2689.234837770462,11311,0,2689.234837770462,0.2881827824477276,3581,0.740676296120148,2828.1469349861145,0.2644272191183908,0.7484683990478516,0.2867133506040729,3554,0.7234407423853405 -141.65505647659302,0.9141678810119628,2769.3530809879303,11658,0,2769.3530809879303,0.2879696281154007,3581,0.7413526767793563,2912.360531568527,0.2645096097673688,0.748514039175851,0.2865941998122714,3554,0.7242095037062113 -145.70746207237244,0.9428093433380128,2849.3807203769684,12000,0,2849.3807203769684,0.2880192948133377,3581,0.7408110813800265,2996.482173204422,0.2642149073737008,0.7486372675214495,0.2866262114980479,3554,0.7235448834060214 -149.75933504104614,0.9710886478424072,2929.544560432434,12348,0,2929.544560432434,0.2882019400896223,3581,0.741044313739179,3080.739272594452,0.2645761966705322,0.7484827722821917,0.2867579849234577,3554,0.7237893675040448 -153.8161265850067,1.0013494491577148,3009.52698135376,12693,0,3009.52698135376,0.2877790743463767,3581,0.7417327616674811,3164.821781158448,0.2641494955335344,0.749316828591483,0.2864611898839775,3554,0.7244157936040377 -157.86707878112793,1.0301060676574707,3089.589159965515,13037,0,3089.589159965515,0.288313068047944,3581,0.7397638196820022,3248.976585388184,0.2649313041142055,0.7469122750418526,0.286979628065428,3554,0.7224884977753939 -161.91984677314758,1.057826042175293,3169.674957036972,13384,0,3169.674957036972,0.2876959669959508,3581,0.7425961509093131,3333.155841112137,0.2638043846402849,0.7504325594220843,0.286266870021015,3554,0.7254452509496342 -165.96881675720215,1.0872135162353516,3249.79615855217,13730,0,3249.79615855217,0.2878152761536582,3581,0.7423545328207903,3417.368681192398,0.2643736771174839,0.7498242514474052,0.286442298867693,3554,0.7252430140334833 -170.0182363986969,1.114762544631958,3329.8468322753906,14073,0,3329.8468322753906,0.2881390471193452,3581,0.739502566714954,3501.50940656662,0.2646465301513672,0.7470395905630929,0.2867674647789023,3554,0.7220492644731289 -174.07473397254944,1.145578145980835,3409.9828832149506,14419,0,3409.9828832149506,0.2874166131252618,3581,0.7422518587685004,3585.745890855789,0.263518077986581,0.7502109663827079,0.2860264045573649,3554,0.7250840547182752 -178.1279594898224,1.1743721961975098,3490.001576662064,14767,0,3490.001576662064,0.2871996749882191,3581,0.7418965901851787,3669.860261678696,0.2635535172053745,0.749675818852016,0.2858485198785699,3554,0.7246531334631753 -182.1823709011078,1.2033040523529053,3570.1850593090057,15112,0,3570.1850593090057,0.2872077880109431,3581,0.7416073166102346,3754.140084505081,0.2635910511016845,0.749295847756522,0.2858831247856728,3554,0.7243161177326604 -186.23437309265137,1.2322709560394287,3650.1696586608887,15458,0,3650.1696586608887,0.287198481896642,3581,0.7424763645149749,3838.218438386917,0.2631913423538208,0.7505433900015694,0.2858433162622661,3554,0.7253112964705262 -190.2851979732513,1.260035514831543,3730.133190393448,15806,0,3730.133190393448,0.2874296007792865,3581,0.7423772356490854,3922.273967027664,0.2635800157274519,0.7500426428658622,0.2860392847957407,3554,0.7251619543999719 -194.33792114257807,1.291689157485962,3810.2701032161713,16151,0,3810.2701032161713,0.2871168744327702,3581,0.7428711755619939,4006.5082714557648,0.2634295565741403,0.7505623272487095,0.2857569499705437,3554,0.7256722866180711 -198.39164185523987,1.3205831050872805,3890.3815047740936,16496,0,3890.3815047740936,0.2872027770263194,3581,0.7426029685754677,4090.71523809433,0.2629409176962716,0.7508247920445034,0.2858233433059581,3554,0.7254057515519485 -202.4444003105164,1.3496923446655271,3970.4745304584494,16843,0,3970.4745304584494,0.2873442095106988,3581,0.7414484649888299,4174.903291225433,0.2634856700897217,0.7493993214198521,0.2860307838384127,3554,0.7241385421795864 -206.4969160556793,1.3776051998138428,4050.521205425264,17188,0,4050.521205425264,0.2872095946924741,3581,0.7419645623167411,4259.043537378311,0.2636538743972778,0.7493541581290108,0.2859189490220086,3554,0.7246931137230938 -210.5419554710388,1.405665159225464,4130.596678495407,17533,0,4130.596678495407,0.2872696583312971,3581,0.7426409429759494,4343.205314636231,0.2629106725965227,0.7509161404200962,0.2858201490068409,3554,0.7254793234735509 -214.59174036979675,1.435006618499756,4210.794536828995,17881,0,4210.794536828995,0.2874863237616936,3581,0.7417941206628735,4427.495761632919,0.2636230162211826,0.7497342654636928,0.286127179542417,3554,0.7245976282226013 -218.64333319664,1.4626963138580322,4290.841006994247,18229,0,4290.841006994247,0.2877632573608978,3581,0.7428972190467048,4511.634707212448,0.2640169177736555,0.7507637568882534,0.2863463496786191,3554,0.7257824727639631 -222.6942172050476,1.4911997318267822,4371.013791799545,18573,0,4371.013791799545,0.2871066479335381,3581,0.7424235276022759,4595.899594783783,0.2628540311540876,0.750854355948312,0.2857621192395452,3554,0.725241227973762 -226.7476847171784,1.5213429927825928,4451.124196529388,18920,0,4451.124196529388,0.2870736504293493,3581,0.7426334435431793,4680.10669708252,0.2630536556243896,0.7506722041538784,0.285700448658202,3554,0.7254103540904615 -230.79849290847773,1.5496416091918943,4531.102238416672,19265,0,4531.102238416672,0.2871111135048694,3581,0.7431179750767942,4764.176918029785,0.2633555105754307,0.7508053779602051,0.2857915205303443,3554,0.7259667116936902 -234.8485188484192,1.5785152912139893,4611.263741254807,19612,0,4611.263741254807,0.2869758510083601,3581,0.7418924314088243,4848.430703639984,0.262720022882734,0.750293459211077,0.2855622351136044,3554,0.7247455277064575 -238.89465618133545,1.6072266101837158,4691.397416830063,19958,0,4691.397416830063,0.2871143859846237,3581,0.7426680772872452,4932.652416944504,0.2630909170423235,0.7505944115774972,0.2857109245854143,3554,0.7255061830639772 -242.94447922706604,1.635915994644165,4771.36471247673,20305,0,4771.36471247673,0.2868059888561156,3581,0.7428100892732477,5016.711462259293,0.2629550184522356,0.7507155282156808,0.2854636583559018,3554,0.7256489991470878 -246.99908590316767,1.6671850681304932,4851.4915273189545,20652,0,4851.4915273189545,0.2872346155272619,3581,0.7420031503071768,5100.93705368042,0.2631739548274449,0.7502062661307198,0.2858611596858293,3554,0.7248037807312183 -251.05397963523865,1.69809889793396,4931.621717214584,20997,0,4931.621717214584,0.2871397136143884,3581,0.743488583408964,5185.16579914093,0.2628318411963327,0.7517170224870954,0.2857012214725046,3554,0.7264534816623874 -255.10294556617737,1.7279188632965088,5011.650506019592,21345,0,5011.650506019592,0.2868220444599099,3581,0.7424934086803616,5269.286895036697,0.2627440861293247,0.7506068774632045,0.285448013159688,3554,0.7253145251169457 -259.1506371498108,1.763467788696289,5091.84877705574,21690,0,5091.84877705574,0.2881662155189716,3581,0.7442467078853672,5353.5817584991455,0.2637862818581717,0.7520913396562848,0.286662052908035,3554,0.7273531062051561 -263.2007920742035,1.793210744857788,5171.93511557579,22038,0,5171.93511557579,0.2868673819398387,3581,0.7424092786800126,5437.761346340179,0.262592111315046,0.7507982935224261,0.2854774316241383,3554,0.7252544860324635 -267.25358629226685,1.8238272666931152,5252.042159557343,22384,0,5252.042159557343,0.2866638746051207,3581,0.7426881894024016,5521.96520614624,0.2625446489879063,0.7508326257978167,0.2852868727907815,3554,0.7254954667056486 -271.30557203292847,1.8549041748046875,5332.193813800812,22731,0,5332.193813800812,0.2867556744798938,3581,0.743461789980976,5606.21314406395,0.2625406299318586,0.7517132077898298,0.2853445419114026,3554,0.7263388303671919 -275.3547251224518,1.886202812194824,5412.402036905289,23078,0,5412.402036905289,0.2867709460520804,3581,0.7425826519303267,5690.515196084976,0.2621175731931414,0.7513458388192313,0.2853178712311392,3554,0.7253771059018008 -279.40776801109314,1.916372299194336,5492.464476823807,23423,0,5492.464476823807,0.2868492128595364,3581,0.742896741810074,5774.673628091812,0.2625161579677036,0.7511231558663505,0.2854214798686339,3554,0.7256902846044949 -283.4580383300781,1.949268341064453,5572.554956436157,23767,0,5572.554956436157,0.2868443382282358,3581,0.7434283152401564,5858.86038517952,0.2626339537756784,0.7518348693847656,0.285478427695906,3554,0.7262959649338773 -287.5047097206116,1.9845812320709229,5652.6396470069885,24112,0,5652.6396470069885,0.2865892211607267,3581,0.7427327087623918,5943.040595293045,0.2618482112884521,0.7516160011291504,0.2851764462137732,3554,0.7255825027697664 -291.55988907814026,2.014088153839112,5732.652978897095,24460,0,5732.652978897095,0.2867930012020909,3581,0.7436621611892628,6027.151556968689,0.2623759337833949,0.7521277836390904,0.2854077066003974,3554,0.726524717967431 -295.6112983226776,2.044174194335937,5812.645381689072,24805,0,5812.645381689072,0.2865136473314018,3581,0.7429490333094806,6111.23862695694,0.2622409377779279,0.7513463837759835,0.2851812204887978,3554,0.7257656425858188 -299.6589345932007,2.075486183166504,5892.711507558823,25151,0,5892.711507558823,0.2866022429030822,3581,0.7433074380192335,6195.396842479706,0.2617241314479283,0.7522298949105399,0.2851499300963703,3554,0.7262362006278137 -303.71176838874817,2.105675220489502,5972.851857185364,25500,0,5972.851857185364,0.2866298885393396,3581,0.7438676456471656,6279.633534908295,0.2621656315667288,0.7523572785513741,0.2852240859221124,3554,0.7267561500905669 -307.7598524093628,2.14125919342041,6052.838146686554,25846,0,6052.838146686554,0.2866360244388788,3581,0.743266872905613,6363.716654062271,0.2622879913875034,0.7516216550554548,0.2852599445057505,3554,0.726183992728264 -311.8156065940857,2.172098398208618,6132.949994087219,26190,0,6132.949994087219,0.286731574030037,3581,0.743290802913816,6447.928359508514,0.2617473602294922,0.7522223336356026,0.2853095763576252,3554,0.7261280066254572 -315.86508679389954,2.203031063079834,6212.970285415649,26538,0,6212.970285415649,0.2864591400904949,3581,0.7435386250785395,6532.04209446907,0.2619017532893589,0.7521584374564034,0.2850819396113974,3554,0.7264184474140053 -319.9152567386627,2.234532356262207,6293.131700754166,26884,0,6293.131700754166,0.2866152646454377,3581,0.7431427232049358,6616.29847073555,0.2620894398008074,0.7518197468348912,0.2851997336847566,3554,0.7260853472759566 -323.96526765823364,2.2656874656677246,6373.266193151474,27228,0,6373.266193151474,0.2866919633896782,3581,0.743790674196279,6700.527494668961,0.2616877555847168,0.7528512137276786,0.2852557713085168,3554,0.726850673866594 -328.01572155952454,2.301090955734253,6453.276751041412,27577,0,6453.276751041412,0.2864751275176277,3581,0.7433965449158755,6784.637478113174,0.2617858478001186,0.7521657262529645,0.2850842065333515,3554,0.7262759061093135 -332.0697326660156,2.3325891494750977,6533.27613568306,27924,0,6533.27613568306,0.2866200029234153,3581,0.7430076652384111,6868.735721111298,0.2619636058807373,0.7517365046909877,0.2852299421371606,3554,0.7259130612074424 -336.1240510940552,2.365093231201172,6613.255655527115,28267,0,6613.255655527115,0.2866514323643884,3581,0.7435234898596761,6952.815067529678,0.2617320503507341,0.7525230816432408,0.2852017601755944,3554,0.7264022354873031 -340.17482447624207,2.3963160514831543,6693.265082120895,28614,0,6693.265082120895,0.2865337594465582,3581,0.743265509372382,7036.919741868973,0.2616327660424368,0.7522595269339425,0.2851194468655476,3554,0.7261526679885341 -344.2283432483673,2.427100419998169,6773.419958114624,28962,0,6773.419958114624,0.2864951714561226,3581,0.7436247322020735,7121.172005176544,0.2617249659129551,0.7524974686758858,0.2851314169004115,3554,0.7265269848893852 -348.2759618759155,2.458501100540161,6853.474825620651,29307,0,6853.474825620651,0.2866700786813215,3581,0.7429303529042167,7205.318714380264,0.2619178976331438,0.7518026488167899,0.2852925229220157,3554,0.7258280859814645 -352.32594203948975,2.4895973205566406,6933.58352804184,29654,0,6933.58352804184,0.2866022088147514,3581,0.743440109802604,7289.521868467331,0.2615166221346174,0.7526144981384277,0.2851482127312535,3554,0.7264150126837718 -356.37841534614563,2.5211710929870605,7013.698459863663,30003,0,7013.698459863663,0.2864873652283755,3581,0.7438008325188494,7373.734119653702,0.2616254431860788,0.7527388163975307,0.2851014488791238,3554,0.7267130785734384 -360.4285726547241,2.558183431625366,7093.814469337463,30346,0,7093.814469337463,0.2865387022545204,3581,0.7434115437814158,7457.95020198822,0.2616192102432251,0.7524821417672294,0.2851408452349026,3554,0.7263430894326814 -364.4844374656677,2.589385509490967,7173.847417593002,30692,0,7173.847417593002,0.2864683439398038,3581,0.7432238534321768,7542.083223819733,0.2612156527382986,0.7525557109287807,0.2850283062988006,3554,0.7261212745541995 -368.5385098457336,2.620823383331299,7254.003686904907,31038,0,7254.003686904907,0.2863576250414514,3581,0.7436732058084334,7626.33788728714,0.2613873652049473,0.7527132715497699,0.2849861278115327,3554,0.726601724619267 -372.5858800411224,2.65349555015564,7334.081294298172,31381,0,7334.081294298172,0.2863485575454657,3581,0.7433383220469143,7710.508943557739,0.2613832950592041,0.7524211066109794,0.2850200285989378,3554,0.7261790467167276 -376.6345460414887,2.6875240802764893,7414.202308893204,31728,0,7414.202308893204,0.286343989709142,3581,0.7435512377609257,7794.726012229919,0.2609822239194597,0.7530506678989956,0.2849472294916379,3554,0.7264788986661156 -380.68638134002686,2.719357490539551,7494.393157243729,32077,0,7494.393157243729,0.2863007316173904,3581,0.7436087106866098,7879.013952970505,0.261205928666251,0.7527856826782227,0.2849471092760797,3554,0.72648425684528 -384.7363765239716,2.7521777153015137,7574.377175331116,32421,0,7574.377175331116,0.2863673402157218,3581,0.7437317013840408,7963.093930482864,0.2613171679633004,0.7528745106288365,0.2850152199766108,3554,0.7266306450478335 -388.787223815918,2.7848803997039795,7654.5212988853455,32768,0,7654.5212988853455,0.2863153214229614,3581,0.7434478137653588,8047.334677219391,0.2608934129987444,0.7530239650181362,0.2849223276974448,3554,0.7263984572840462 -392.8369917869568,2.824607610702514,7734.511190891266,33114,0,7734.511190891266,0.2862740745427255,3581,0.7433103014390184,8131.42751955986,0.261080128805978,0.7526590483529227,0.2849193051348392,3554,0.7261710781425859 -396.8901824951172,2.85813307762146,7814.515537023544,33461,0,7814.515537023544,0.2863309338784557,3581,0.7440449049671879,8215.531928062439,0.2611795323235648,0.7532807758876255,0.2850068048875387,3554,0.726994863841798 -400.94162368774414,2.8913586139678955,7894.61809015274,33807,0,7894.61809015274,0.2862424064834369,3581,0.7437748572108001,8299.732083559036,0.2608542442321777,0.7532802990504673,0.2848649505288935,3554,0.7267084760349254 -404.9890365600586,2.923970937728882,7974.578327178955,34153,0,7974.578327178955,0.2861671053607582,3581,0.7437286334342712,8383.78547167778,0.2609092337744577,0.7531078883579799,0.2848093422464124,3554,0.7266382014543472 -409.04110741615295,2.958050012588501,8054.679794073105,34499,0,8054.679794073105,0.2862170106770106,3581,0.7436772964081262,8467.986461162567,0.260980623109,0.7530832971845355,0.2848652940019168,3554,0.7266129218398284 -413.0929937362671,2.9907097816467285,8134.677657842636,34842,0,8134.677657842636,0.286145459270717,3581,0.7437078395524993,8552.080843925476,0.2608306578227451,0.753159659249442,0.284789936020593,3554,0.726635728448579 -417.1439554691314,3.023453950881958,8214.799205303192,35191,0,8214.799205303192,0.2861053373053965,3581,0.743826194236945,8636.29803609848,0.2607841832297189,0.7532715116228376,0.2847346712111353,3554,0.726746120678285 -421.1991891860962,3.063876152038574,8294.869604349136,35537,0,8294.869604349136,0.2861471636872557,3581,0.7437623127050754,8720.475849628448,0.2608253444944109,0.7532127244131905,0.2847871195418014,3554,0.7266732357027293 -425.2508318424225,3.098477363586426,8374.977184772491,35881,0,8374.977184772491,0.28610857569682,3581,0.7437063396659452,8804.681557416916,0.2607960360390799,0.7531414713178363,0.2847555543709553,3554,0.7266046097926632 -429.2997944355011,3.1342813968658447,8446.071097612381,36189,0,8446.071097612381,0.2861180522527751,3581,0.7439003704447081,8879.871040344238,0.26079542296273367,0.7533424922398159,0.28475758086179304,3554,0.7268173569833286 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/measurements.csv deleted file mode 100644 index 1ed5698a2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/measurements.csv +++ /dev/null @@ -1,470 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.989067,0.812358,,,,,,,,,,,,,, -1,,,0.2794570582253592,0.8154546192714146,0.2738082104341059,0.8275999671365011,3554.0,0.2972505852284627,0.8250461419645351,3581.0,46.38900184631348,50.38763356208801,46.38900184631348,3.998514652252197,0.0,0.0 -100,0.5893506,0.34720498,,,,,,,,,,,,,, -200,0.15868145,0.39448005,,,,,,,,,,,,,, -261,,,0.6731563976832798,0.3304575851985386,0.648352799552441,0.3530414948385622,3554.0,0.6680879909810458,0.3541430889045483,3581.0,126.72829818725586,134.8169128894806,126.72829818725586,8.051229238510132,0.0280277729034423,0.0 -300,0.09503922,0.35154456,,,,,,,,,,,,,, -400,0.126197,0.3367831,,,,,,,,,,,,,, -500,0.088687204,0.23391059,,,,,,,,,,,,,, -594,,,0.7111654281616211,0.2972275870186942,0.6874991069701393,0.3188105268161754,3554.0,0.705759073495532,0.3206396797987643,3581.0,206.8462879657745,219.03469967842105,206.8462879657745,12.102293968200684,0.0650091171264648,0.0 -600,0.07491033,0.41589707,,,,,,,,,,,,,, -700,0.106648676,0.31620044,,,,,,,,,,,,,, -800,0.1636534,0.21787257,,,,,,,,,,,,,, -900,0.14091212,0.32293734,,,,,,,,,,,,,, -932,,,0.7249079431806292,0.2850749833243234,0.7014162904034187,0.3062318577549064,3554.0,0.7187924058834823,0.30844171604475,3581.0,286.9167993068695,303.1899762153625,286.9167993068695,16.150205850601196,0.0895030498504638,0.0 -1000,0.09233715,0.28666976,,,,,,,,,,,,,, -1100,0.26836273,0.31359395,,,,,,,,,,,,,, -1200,0.11761506,0.22462979,,,,,,,,,,,,,, -1282,,,0.730029787336077,0.2807219198771885,0.7057390357915377,0.3024191354591657,3554.0,0.7228088975452038,0.304571906470347,3581.0,367.0610167980194,387.4175651073456,367.0610167980194,20.195985317230225,0.114079475402832,0.0 -1300,0.12464189,0.4016387,,,,,,,,,,,,,, -1400,0.08202205,0.2248701,,,,,,,,,,,,,, -1500,0.06640223,0.28899947,,,,,,,,,,,,,, -1600,0.2773001,0.22582254,,,,,,,,,,,,,, -1630,,,0.735931396484375,0.2754407099315098,0.7122415296804657,0.2966684353237549,3554.0,0.7293881499144792,0.2986177659107966,3581.0,447.1965122222901,471.6418776512146,447.1965122222901,24.246984004974365,0.1383695602416992,0.0 -1700,0.12894683,0.40098298,,,,,,,,,,,,,, -1800,0.23564339,0.25822923,,,,,,,,,,,,,, -1900,0.08762662,0.22711562,,,,,,,,,,,,,, -1978,,,0.7379615647452218,0.2737154960632324,0.714338569877251,0.2948487152460256,3554.0,0.7315722574438006,0.296682025959491,3581.0,527.3426928520203,555.8780257701874,527.3426928520203,28.29820895195008,0.1634666919708252,0.0 -2000,0.08394114,0.20968847,,,,,,,,,,,,,, -2100,0.09942038,0.22717318,,,,,,,,,,,,,, -2200,0.10348651,0.27238894,,,,,,,,,,,,,, -2300,0.11369975,0.28952557,,,,,,,,,,,,,, -2323,,,0.7403804234095982,0.2714020695005144,0.7161126080428742,0.2932167031755592,3554.0,0.7332982178075258,0.2949648262967746,3581.0,607.5200967788696,640.1423740386963,607.5200967788696,32.347129821777344,0.1886870861053466,0.0 -2400,0.045883365,0.2612398,,,,,,,,,,,,,, -2500,0.2859389,0.21046656,,,,,,,,,,,,,, -2600,0.17703137,0.2996417,,,,,,,,,,,,,, -2670,,,0.7409505844116211,0.2707600423267909,0.7167952263294879,0.2921713773762838,3554.0,0.7340108002740157,0.2938280145538083,3581.0,687.5188465118408,724.2226903438568,687.5188465118408,36.39129662513733,0.2129256725311279,0.0 -2700,0.06535689,0.28701413,,,,,,,,,,,,,, -2800,0.27182034,0.3325807,,,,,,,,,,,,,, -2900,0.15363625,0.23562989,,,,,,,,,,,,,, -3000,0.043136545,0.33222207,,,,,,,,,,,,,, -3018,,,0.7421620913914272,0.2698723758969988,0.7181647219682048,0.2912445497700654,3554.0,0.7352949076942544,0.2929147881723855,3581.0,767.621922492981,808.4106514453888,767.621922492981,40.43844723701477,0.2373144626617431,0.0 -3100,0.09435315,0.19602343,,,,,,,,,,,,,, -3200,0.13134179,0.39226845,,,,,,,,,,,,,, -3300,0.19237064,0.21142265,,,,,,,,,,,,,, -3363,,,0.7382982117789132,0.2705576249531337,0.7156640322743739,0.291953581132175,3554.0,0.7324179207536303,0.2936922748206681,3581.0,847.7596549987793,892.6428055763245,847.7596549987793,44.491214990615845,0.2660079002380371,0.0 -3400,0.046112202,0.31814733,,,,,,,,,,,,,, -3500,0.093509845,0.31659955,,,,,,,,,,,,,, -3600,0.06268135,0.30525288,,,,,,,,,,,,,, -3699,,,0.7407500403267997,0.2692724636622837,0.7178452920564856,0.2905863180681098,3554.0,0.7347732198800964,0.2921372674357721,3581.0,927.7310743331908,976.7025051116944,927.7310743331908,48.54154658317566,0.2912108898162842,0.0 -3700,0.27010447,0.32986102,,,,,,,,,,,,,, -3800,0.050227065,0.23516035,,,,,,,,,,,,,, -3900,0.047317673,0.28591356,,,,,,,,,,,,,, -4000,0.25004357,0.20261101,,,,,,,,,,,,,, -4045,,,0.7437830652509417,0.2689505474908011,0.7196033930738253,0.2908049386474747,3554.0,0.736945396493647,0.2922567811234641,3581.0,1007.7034420967102,1060.7598226070404,1007.7034420967102,52.58805227279663,0.3167808055877685,0.0 -4100,0.070527405,0.27896088,,,,,,,,,,,,,, -4200,0.31384623,0.24924946,,,,,,,,,,,,,, -4300,0.14993781,0.3534894,,,,,,,,,,,,,, -4389,,,0.7440434864589146,0.2679150785718645,0.7194905278383511,0.2897296276532604,3554.0,0.7368078841673066,0.2913125343610374,3581.0,1087.85382604599,1145.0013349056244,1087.85382604599,56.636749029159546,0.3462316989898681,0.0 -4400,0.088978134,0.23214947,,,,,,,,,,,,,, -4500,0.14998129,0.22006467,,,,,,,,,,,,,, -4600,0.10513298,0.24565518,,,,,,,,,,,,,, -4700,0.2900158,0.30432078,,,,,,,,,,,,,, -4735,,,0.7440556798662458,0.268069931438991,0.7199958453503095,0.2896225327645786,3554.0,0.7372943928241064,0.2911158787808049,3581.0,1167.8536245822906,1229.0866062641144,1167.8536245822906,60.68485951423645,0.3707010746002197,0.0 -4800,0.19509344,0.21265331,,,,,,,,,,,,,, -4900,0.043914046,0.41860336,,,,,,,,,,,,,, -5000,0.1040872,0.20932919,,,,,,,,,,,,,, -5085,,,0.7439850398472377,0.2685292107718331,0.7199961888233328,0.2899754856433771,3554.0,0.7373068009765079,0.291414049410081,3581.0,1248.007008075714,1313.3300881385803,1248.007008075714,64.73780488967896,0.3951351642608642,0.0 -5100,0.25162217,0.27637145,,,,,,,,,,,,,, -5200,0.1390212,0.2542575,,,,,,,,,,,,,, -5300,0.10694723,0.26627588,,,,,,,,,,,,,, -5400,0.10016241,0.29679713,,,,,,,,,,,,,, -5426,,,0.7444303376334054,0.2681744609560285,0.7201586515633793,0.2899265407375492,3554.0,0.7372558048336708,0.2915667310436156,3581.0,1328.0173542499542,1397.438951253891,1328.0173542499542,68.79254746437073,0.4260334968566894,0.0 -5500,0.10586326,0.26603758,,,,,,,,,,,,,, -5600,0.11002235,0.34286684,,,,,,,,,,,,,, -5700,0.35183644,0.2675675,,,,,,,,,,,,,, -5773,,,0.7458662986755371,0.2664764608655657,0.7213260476751547,0.2882950782239202,3554.0,0.7386466769058923,0.2898276807608733,3581.0,1408.1050882339478,1481.6133811473846,1408.1050882339478,72.83991551399231,0.4525535106658935,0.0 -5800,0.10453923,0.348362,,,,,,,,,,,,,, -5900,0.057785008,0.32693908,,,,,,,,,,,,,, -6000,0.54198164,0.20685148,,,,,,,,,,,,,, -6100,0.18246077,0.2731545,,,,,,,,,,,,,, -6121,,,0.7458045823233468,0.2676864010947091,0.7220802457398354,0.289156302482678,3554.0,0.7392027939341316,0.2905847484990226,3581.0,1488.1629683971405,1565.7648272514343,1488.1629683971405,76.8948724269867,0.4782447814941406,0.0 -6200,0.07088211,0.34174508,,,,,,,,,,,,,, -6300,0.34643054,0.25400192,,,,,,,,,,,,,, -6400,0.1073685,0.31865183,,,,,,,,,,,,,, -6468,,,0.7469580514090401,0.2672385488237653,0.7227373096335116,0.2890733880948403,3554.0,0.7396826894547612,0.2906994216437447,3581.0,1568.169510126114,1649.8570857048037,1568.169510126114,80.9424307346344,0.5033724308013916,0.0 -6500,0.15625256,0.2610709,,,,,,,,,,,,,, -6600,0.0678724,0.22838688,,,,,,,,,,,,,, -6700,0.06511096,0.24034756,,,,,,,,,,,,,, -6800,0.16110642,0.2712975,,,,,,,,,,,,,, -6814,,,0.7474136352539062,0.265754290989467,0.7232966897993458,0.2876301659716604,3554.0,0.7403746825694638,0.2890812485819254,3581.0,1648.3162939548492,1734.0930144786837,1648.3162939548492,84.99346947669983,0.5285801887512207,0.0 -6900,0.114531815,0.35892627,,,,,,,,,,,,,, -7000,0.20194107,0.22404005,,,,,,,,,,,,,, -7100,0.19032057,0.288208,,,,,,,,,,,,,, -7160,,,0.7481129510062081,0.26585396698543,0.7238215852736354,0.2876633111184141,3554.0,0.740983363803756,0.2891557997613271,3581.0,1728.470938205719,1818.335706949234,1728.470938205719,89.04334568977356,0.5537066459655762,0.0 -7200,0.17390476,0.23361921,,,,,,,,,,,,,, -7300,0.27174932,0.25243646,,,,,,,,,,,,,, -7400,0.12877254,0.3177258,,,,,,,,,,,,,, -7500,0.10052202,0.23943587,,,,,,,,,,,,,, -7506,,,0.7462959289550781,0.2659519399915422,0.7225214024910313,0.287555254505267,3554.0,0.7396961884337475,0.289021696268064,3581.0,1808.6441133022308,1902.599097251892,1808.6441133022308,93.09039211273192,0.5837688446044922,0.0 -7600,0.4833922,0.2722954,,,,,,,,,,,,,, -7700,0.120482415,0.21823747,,,,,,,,,,,,,, -7800,0.10579469,0.30217782,,,,,,,,,,,,,, -7851,,,0.7452442305428642,0.2657542398997715,0.7210275009232555,0.2876985857979126,3554.0,0.7380947186540072,0.2891605721276354,3581.0,1888.6966087818143,1986.7513935565948,1888.6966087818143,97.14750504493712,0.6135847568511963,0.0 -7900,0.13822469,0.2951423,,,,,,,,,,,,,, -8000,0.06986907,0.27075037,,,,,,,,,,,,,, -8100,0.20908076,0.19430488,,,,,,,,,,,,,, -8200,0.06669846,0.31853002,,,,,,,,,,,,,, -8201,,,0.7487215995788574,0.2652214254651751,0.7242683749824141,0.2872155940324722,3554.0,0.7414120586515638,0.2886514629075328,3581.0,1968.838838338852,2070.93323135376,1968.838838338852,101.14794516563416,0.6397542953491211,0.0 -8300,0.10130418,0.21259275,,,,,,,,,,,,,, -8400,0.08313502,0.19260092,,,,,,,,,,,,,, -8500,0.05392728,0.2935426,,,,,,,,,,,,,, -8547,,,0.7477490561349052,0.2652098281042916,0.7233854432285804,0.2870953269533448,3554.0,0.7406799776598716,0.2884978949773981,3581.0,2048.9702048301697,2155.156111717224,2048.9702048301697,105.2003893852234,0.6659073829650879,0.0 -8600,0.38682,0.2887273,,,,,,,,,,,,,, -8700,0.06808456,0.2607398,,,,,,,,,,,,,, -8800,0.2066699,0.20468226,,,,,,,,,,,,,, -8893,,,0.7493375369480678,0.2646728924342564,0.7244903959447102,0.2870908102830877,3554.0,0.7415842047219702,0.2885711167118996,3581.0,2129.0309269428253,2239.3087198734283,2129.0309269428253,109.25191283226012,0.6932220458984375,0.0 -8900,0.105930775,0.25413388,,,,,,,,,,,,,, -9000,0.21472166,0.23431684,,,,,,,,,,,,,, -9100,0.117131226,0.3699497,,,,,,,,,,,,,, -9200,0.12753695,0.28063652,,,,,,,,,,,,,, -9243,,,0.7482168333871024,0.2648956435067313,0.7236484061752603,0.2868886935824951,3554.0,0.7409059832929,0.2882860700899713,3581.0,2209.164315223694,2323.52946972847,2209.164315223694,113.3003273010254,0.718714714050293,0.0 -9300,0.110851154,0.33195847,,,,,,,,,,,,,, -9400,0.19786657,0.19958144,,,,,,,,,,,,,, -9500,0.09726611,0.32807183,,,,,,,,,,,,,, -9587,,,0.7480886323111398,0.2649180889129638,0.723539319143043,0.2869371919733927,3554.0,0.7408063090137182,0.2883126248996439,3581.0,2289.1893439292908,2407.648950576782,2289.1893439292908,117.35204768180849,0.7489063739776611,0.0 -9600,0.05090764,0.2962217,,,,,,,,,,,,,, -9700,0.1039973,0.26978353,,,,,,,,,,,,,, -9800,0.15923135,0.2410921,,,,,,,,,,,,,, -9900,0.15241933,0.23169896,,,,,,,,,,,,,, -9931,,,0.7497763633728027,0.2642585379736764,0.7247282853606851,0.2866745038051315,3554.0,0.7419528359309551,0.2880700523378595,3581.0,2369.2818806171417,2491.8336610794067,2369.2818806171417,121.4030601978302,0.7770266532897949,0.0 -10000,0.1160787,0.26189953,,,,,,,,,,,,,, -10100,0.12901203,0.3480165,,,,,,,,,,,,,, -10200,0.12632564,0.27725852,,,,,,,,,,,,,, -10278,,,0.7476978983197894,0.2647982495171683,0.7231131378156654,0.2868940517616594,3554.0,0.7403298905028274,0.2882836498184865,3581.0,2449.29647898674,2575.937170982361,2449.29647898674,125.45101857185364,0.8046424388885498,0.0 -10300,0.054116536,0.27441168,,,,,,,,,,,,,, -10400,0.12718788,0.286949,,,,,,,,,,,,,, -10500,0.14014548,0.28652853,,,,,,,,,,,,,, -10600,0.109975964,0.27898413,,,,,,,,,,,,,, -10622,,,0.7473840713500977,0.2655226673398699,0.7227836784916644,0.2875874379275552,3554.0,0.7399198760602834,0.2890171284317404,3581.0,2529.2785215377808,2660.009659767151,2529.2785215377808,129.50124597549438,0.8316726684570312,0.0 -10700,0.067598194,0.20619914,,,,,,,,,,,,,, -10800,0.08929014,0.30578208,,,,,,,,,,,,,, -10900,0.11470813,0.32670334,,,,,,,,,,,,,, -10964,,,0.7493657384599958,0.2642334188733782,0.7241476098674029,0.2867996653748417,3554.0,0.7413440183433399,0.2882112802922542,3581.0,2609.270174264908,2744.0874071121216,2609.270174264908,133.54821157455444,0.8579127788543701,0.0 -11000,0.21040179,0.2661454,,,,,,,,,,,,,, -11100,0.095487356,0.35226268,,,,,,,,,,,,,, -11200,0.13149835,0.26682654,,,,,,,,,,,,,, -11300,0.24002752,0.31065997,,,,,,,,,,,,,, -11311,,,0.7484683990478516,0.2644272191183908,0.7234407423853405,0.2867133506040729,3554.0,0.740676296120148,0.2881827824477276,3581.0,2689.234837770462,2828.1469349861145,2689.234837770462,137.60043573379517,0.8868420124053955,0.0 -11400,0.04144552,0.27496293,,,,,,,,,,,,,, -11500,0.15031534,0.33622032,,,,,,,,,,,,,, -11600,0.30682424,0.20829128,,,,,,,,,,,,,, -11658,,,0.748514039175851,0.2645096097673688,0.7242095037062113,0.2865941998122714,3554.0,0.7413526767793563,0.2879696281154007,3581.0,2769.3530809879303,2912.360531568527,2769.3530809879303,141.65505647659302,0.9141678810119628,0.0 -11700,0.21618187,0.2374794,,,,,,,,,,,,,, -11800,0.4631412,0.22076577,,,,,,,,,,,,,, -11900,0.26518247,0.41114664,,,,,,,,,,,,,, -12000,,,0.7486372675214495,0.2642149073737008,0.7235448834060214,0.2866262114980479,3554.0,0.7408110813800265,0.2880192948133377,3581.0,2849.3807203769684,2996.482173204422,2849.3807203769684,145.70746207237244,0.9428093433380128,0.0 -12000,0.11321314,0.26380524,,,,,,,,,,,,,, -12100,0.14844516,0.21368666,,,,,,,,,,,,,, -12200,0.1053879,0.20605461,,,,,,,,,,,,,, -12300,0.10519658,0.27440807,,,,,,,,,,,,,, -12348,,,0.7484827722821917,0.2645761966705322,0.7237893675040448,0.2867579849234577,3554.0,0.741044313739179,0.2882019400896223,3581.0,2929.544560432434,3080.739272594452,2929.544560432434,149.75933504104614,0.9710886478424072,0.0 -12400,0.15437391,0.24270386,,,,,,,,,,,,,, -12500,0.101959206,0.34413043,,,,,,,,,,,,,, -12600,0.19533816,0.34009033,,,,,,,,,,,,,, -12693,,,0.749316828591483,0.2641494955335344,0.7244157936040377,0.2864611898839775,3554.0,0.7417327616674811,0.2877790743463767,3581.0,3009.52698135376,3164.821781158448,3009.52698135376,153.8161265850067,1.0013494491577148,0.0 -12700,0.22841659,0.35353613,,,,,,,,,,,,,, -12800,0.13477992,0.2413864,,,,,,,,,,,,,, -12900,0.13910359,0.24434628,,,,,,,,,,,,,, -13000,0.31134814,0.2629872,,,,,,,,,,,,,, -13037,,,0.7469122750418526,0.2649313041142055,0.7224884977753939,0.286979628065428,3554.0,0.7397638196820022,0.288313068047944,3581.0,3089.589159965515,3248.976585388184,3089.589159965515,157.86707878112793,1.0301060676574707,0.0 -13100,0.09331838,0.27496427,,,,,,,,,,,,,, -13200,0.12280373,0.19631377,,,,,,,,,,,,,, -13300,0.11800425,0.34162855,,,,,,,,,,,,,, -13384,,,0.7504325594220843,0.2638043846402849,0.7254452509496342,0.286266870021015,3554.0,0.7425961509093131,0.2876959669959508,3581.0,3169.674957036972,3333.155841112137,3169.674957036972,161.91984677314758,1.057826042175293,0.0 -13400,0.9554759,0.3260545,,,,,,,,,,,,,, -13500,0.12609273,0.33433223,,,,,,,,,,,,,, -13600,0.08253068,0.24080431,,,,,,,,,,,,,, -13700,0.08520811,0.2865092,,,,,,,,,,,,,, -13730,,,0.7498242514474052,0.2643736771174839,0.7252430140334833,0.286442298867693,3554.0,0.7423545328207903,0.2878152761536582,3581.0,3249.79615855217,3417.368681192398,3249.79615855217,165.96881675720215,1.0872135162353516,0.0 -13800,0.11289948,0.24845172,,,,,,,,,,,,,, -13900,0.20785932,0.25183117,,,,,,,,,,,,,, -14000,0.12044363,0.2800213,,,,,,,,,,,,,, -14073,,,0.7470395905630929,0.2646465301513672,0.7220492644731289,0.2867674647789023,3554.0,0.739502566714954,0.2881390471193452,3581.0,3329.8468322753906,3501.50940656662,3329.8468322753906,170.0182363986969,1.114762544631958,0.0 -14100,0.064056,0.29019752,,,,,,,,,,,,,, -14200,0.29915145,0.228928,,,,,,,,,,,,,, -14300,0.12177319,0.30494314,,,,,,,,,,,,,, -14400,0.21240422,0.23041472,,,,,,,,,,,,,, -14419,,,0.7502109663827079,0.263518077986581,0.7250840547182752,0.2860264045573649,3554.0,0.7422518587685004,0.2874166131252618,3581.0,3409.9828832149506,3585.745890855789,3409.9828832149506,174.07473397254944,1.145578145980835,0.0 -14500,0.1395059,0.35774636,,,,,,,,,,,,,, -14600,0.262619,0.3467614,,,,,,,,,,,,,, -14700,0.13622275,0.24024722,,,,,,,,,,,,,, -14767,,,0.749675818852016,0.2635535172053745,0.7246531334631753,0.2858485198785699,3554.0,0.7418965901851787,0.2871996749882191,3581.0,3490.001576662064,3669.860261678696,3490.001576662064,178.1279594898224,1.1743721961975098,0.0 -14800,0.2963807,0.25732955,,,,,,,,,,,,,, -14900,0.3118613,0.23448965,,,,,,,,,,,,,, -15000,0.20601428,0.24522883,,,,,,,,,,,,,, -15100,0.16416247,0.21696508,,,,,,,,,,,,,, -15112,,,0.749295847756522,0.2635910511016845,0.7243161177326604,0.2858831247856728,3554.0,0.7416073166102346,0.2872077880109431,3581.0,3570.1850593090057,3754.140084505081,3570.1850593090057,182.1823709011078,1.2033040523529053,0.0 -15200,0.07465778,0.21779785,,,,,,,,,,,,,, -15300,0.4089003,0.1999866,,,,,,,,,,,,,, -15400,0.0639366,0.35304576,,,,,,,,,,,,,, -15458,,,0.7505433900015694,0.2631913423538208,0.7253112964705262,0.2858433162622661,3554.0,0.7424763645149749,0.287198481896642,3581.0,3650.1696586608887,3838.218438386917,3650.1696586608887,186.23437309265137,1.2322709560394287,0.0 -15500,0.14633816,0.26867527,,,,,,,,,,,,,, -15600,0.29791728,0.26741993,,,,,,,,,,,,,, -15700,0.22199538,0.22192213,,,,,,,,,,,,,, -15800,0.114920445,0.30250147,,,,,,,,,,,,,, -15806,,,0.7500426428658622,0.2635800157274519,0.7251619543999719,0.2860392847957407,3554.0,0.7423772356490854,0.2874296007792865,3581.0,3730.133190393448,3922.273967027664,3730.133190393448,190.2851979732513,1.260035514831543,0.0 -15900,0.12007408,0.3025258,,,,,,,,,,,,,, -16000,0.16804521,0.19731025,,,,,,,,,,,,,, -16100,0.13927034,0.2939843,,,,,,,,,,,,,, -16151,,,0.7505623272487095,0.2634295565741403,0.7256722866180711,0.2857569499705437,3554.0,0.7428711755619939,0.2871168744327702,3581.0,3810.2701032161713,4006.5082714557648,3810.2701032161713,194.33792114257807,1.291689157485962,0.0 -16200,0.12939535,0.2834383,,,,,,,,,,,,,, -16300,0.17989117,0.28985965,,,,,,,,,,,,,, -16400,0.101154014,0.334219,,,,,,,,,,,,,, -16496,,,0.7508247920445034,0.2629409176962716,0.7254057515519485,0.2858233433059581,3554.0,0.7426029685754677,0.2872027770263194,3581.0,3890.3815047740936,4090.71523809433,3890.3815047740936,198.39164185523987,1.3205831050872805,0.0 -16500,0.26823035,0.34055343,,,,,,,,,,,,,, -16600,0.23311497,0.23335785,,,,,,,,,,,,,, -16700,0.1603976,0.3097485,,,,,,,,,,,,,, -16800,0.27543506,0.23502687,,,,,,,,,,,,,, -16843,,,0.7493993214198521,0.2634856700897217,0.7241385421795864,0.2860307838384127,3554.0,0.7414484649888299,0.2873442095106988,3581.0,3970.4745304584494,4174.903291225433,3970.4745304584494,202.4444003105164,1.3496923446655271,0.0 -16900,0.086695075,0.3466404,,,,,,,,,,,,,, -17000,0.11066164,0.22149165,,,,,,,,,,,,,, -17100,0.079521134,0.25440136,,,,,,,,,,,,,, -17188,,,0.7493541581290108,0.2636538743972778,0.7246931137230938,0.2859189490220086,3554.0,0.7419645623167411,0.2872095946924741,3581.0,4050.521205425264,4259.043537378311,4050.521205425264,206.4969160556793,1.3776051998138428,0.0 -17200,0.12202091,0.26040274,,,,,,,,,,,,,, -17300,0.15139948,0.2758847,,,,,,,,,,,,,, -17400,0.1199659,0.26195922,,,,,,,,,,,,,, -17500,0.083966695,0.26547948,,,,,,,,,,,,,, -17533,,,0.7509161404200962,0.2629106725965227,0.7254793234735509,0.2858201490068409,3554.0,0.7426409429759494,0.2872696583312971,3581.0,4130.596678495407,4343.205314636231,4130.596678495407,210.5419554710388,1.405665159225464,0.0 -17600,0.12672603,0.31587437,,,,,,,,,,,,,, -17700,0.29368907,0.25607732,,,,,,,,,,,,,, -17800,0.30014312,0.33619833,,,,,,,,,,,,,, -17881,,,0.7497342654636928,0.2636230162211826,0.7245976282226013,0.286127179542417,3554.0,0.7417941206628735,0.2874863237616936,3581.0,4210.794536828995,4427.495761632919,4210.794536828995,214.59174036979675,1.435006618499756,0.0 -17900,0.18235965,0.29451004,,,,,,,,,,,,,, -18000,0.12919623,0.2864251,,,,,,,,,,,,,, -18100,0.17204298,0.25320664,,,,,,,,,,,,,, -18200,0.08547848,0.3005429,,,,,,,,,,,,,, -18229,,,0.7507637568882534,0.2640169177736555,0.7257824727639631,0.2863463496786191,3554.0,0.7428972190467048,0.2877632573608978,3581.0,4290.841006994247,4511.634707212448,4290.841006994247,218.64333319664,1.4626963138580322,0.0 -18300,0.13317466,0.21732274,,,,,,,,,,,,,, -18400,0.20641002,0.30097723,,,,,,,,,,,,,, -18500,0.12504646,0.267182,,,,,,,,,,,,,, -18573,,,0.750854355948312,0.2628540311540876,0.725241227973762,0.2857621192395452,3554.0,0.7424235276022759,0.2871066479335381,3581.0,4371.013791799545,4595.899594783783,4371.013791799545,222.6942172050476,1.4911997318267822,0.0 -18600,0.19021812,0.29840863,,,,,,,,,,,,,, -18700,0.13601026,0.2099289,,,,,,,,,,,,,, -18800,0.13361986,0.38456357,,,,,,,,,,,,,, -18900,0.22244892,0.21926099,,,,,,,,,,,,,, -18920,,,0.7506722041538784,0.2630536556243896,0.7254103540904615,0.285700448658202,3554.0,0.7426334435431793,0.2870736504293493,3581.0,4451.124196529388,4680.10669708252,4451.124196529388,226.7476847171784,1.5213429927825928,0.0 -19000,0.13334443,0.2472365,,,,,,,,,,,,,, -19100,0.18196677,0.24419256,,,,,,,,,,,,,, -19200,0.61834794,0.19475497,,,,,,,,,,,,,, -19265,,,0.7508053779602051,0.2633555105754307,0.7259667116936902,0.2857915205303443,3554.0,0.7431179750767942,0.2871111135048694,3581.0,4531.102238416672,4764.176918029785,4531.102238416672,230.79849290847773,1.5496416091918943,0.0 -19300,0.15468629,0.24280758,,,,,,,,,,,,,, -19400,0.28058815,0.26971495,,,,,,,,,,,,,, -19500,0.14001161,0.25750697,,,,,,,,,,,,,, -19600,0.18900709,0.27560967,,,,,,,,,,,,,, -19612,,,0.750293459211077,0.262720022882734,0.7247455277064575,0.2855622351136044,3554.0,0.7418924314088243,0.2869758510083601,3581.0,4611.263741254807,4848.430703639984,4611.263741254807,234.8485188484192,1.5785152912139893,0.0 -19700,0.2036388,0.22215205,,,,,,,,,,,,,, -19800,0.12139107,0.34540862,,,,,,,,,,,,,, -19900,0.12737225,0.24267685,,,,,,,,,,,,,, -19958,,,0.7505944115774972,0.2630909170423235,0.7255061830639772,0.2857109245854143,3554.0,0.7426680772872452,0.2871143859846237,3581.0,4691.397416830063,4932.652416944504,4691.397416830063,238.89465618133545,1.6072266101837158,0.0 -20000,0.2027625,0.2874959,,,,,,,,,,,,,, -20100,0.20635875,0.3500464,,,,,,,,,,,,,, -20200,0.24590388,0.35485366,,,,,,,,,,,,,, -20300,0.16979092,0.2681031,,,,,,,,,,,,,, -20305,,,0.7507155282156808,0.2629550184522356,0.7256489991470878,0.2854636583559018,3554.0,0.7428100892732477,0.2868059888561156,3581.0,4771.36471247673,5016.711462259293,4771.36471247673,242.94447922706604,1.635915994644165,0.0 -20400,0.0972085,0.26513696,,,,,,,,,,,,,, -20500,0.1288506,0.21695858,,,,,,,,,,,,,, -20600,0.12899753,0.28900018,,,,,,,,,,,,,, -20652,,,0.7502062661307198,0.2631739548274449,0.7248037807312183,0.2858611596858293,3554.0,0.7420031503071768,0.2872346155272619,3581.0,4851.4915273189545,5100.93705368042,4851.4915273189545,246.99908590316767,1.6671850681304932,0.0 -20700,0.1619878,0.30520365,,,,,,,,,,,,,, -20800,0.110063106,0.26063162,,,,,,,,,,,,,, -20900,0.18952191,0.30338535,,,,,,,,,,,,,, -20997,,,0.7517170224870954,0.2628318411963327,0.7264534816623874,0.2857012214725046,3554.0,0.743488583408964,0.2871397136143884,3581.0,4931.621717214584,5185.16579914093,4931.621717214584,251.05397963523865,1.69809889793396,0.0 -21000,0.40499315,0.32255024,,,,,,,,,,,,,, -21100,0.09834829,0.2626217,,,,,,,,,,,,,, -21200,0.14082141,0.21993905,,,,,,,,,,,,,, -21300,0.12127643,0.24044102,,,,,,,,,,,,,, -21345,,,0.7506068774632045,0.2627440861293247,0.7253145251169457,0.285448013159688,3554.0,0.7424934086803616,0.2868220444599099,3581.0,5011.650506019592,5269.286895036697,5011.650506019592,255.10294556617737,1.7279188632965088,0.0 -21400,0.1406166,0.21455671,,,,,,,,,,,,,, -21500,0.3579969,0.2053563,,,,,,,,,,,,,, -21600,0.14822333,0.23132607,,,,,,,,,,,,,, -21690,,,0.7520913396562848,0.2637862818581717,0.7273531062051561,0.286662052908035,3554.0,0.7442467078853672,0.2881662155189716,3581.0,5091.84877705574,5353.5817584991455,5091.84877705574,259.1506371498108,1.763467788696289,0.0 -21700,0.13909926,0.24768652,,,,,,,,,,,,,, -21800,0.1485469,0.23258913,,,,,,,,,,,,,, -21900,0.1117953,0.20352279,,,,,,,,,,,,,, -22000,0.121393755,0.27700078,,,,,,,,,,,,,, -22038,,,0.7507982935224261,0.262592111315046,0.7252544860324635,0.2854774316241383,3554.0,0.7424092786800126,0.2868673819398387,3581.0,5171.93511557579,5437.761346340179,5171.93511557579,263.2007920742035,1.793210744857788,0.0 -22100,0.19420686,0.37277573,,,,,,,,,,,,,, -22200,0.37204826,0.22812365,,,,,,,,,,,,,, -22300,0.08490671,0.21958938,,,,,,,,,,,,,, -22384,,,0.7508326257978167,0.2625446489879063,0.7254954667056486,0.2852868727907815,3554.0,0.7426881894024016,0.2866638746051207,3581.0,5252.042159557343,5521.96520614624,5252.042159557343,267.25358629226685,1.8238272666931152,0.0 -22400,0.07807068,0.22432385,,,,,,,,,,,,,, -22500,0.24124771,0.2485819,,,,,,,,,,,,,, -22600,0.13581431,0.21114732,,,,,,,,,,,,,, -22700,0.13100965,0.15837598,,,,,,,,,,,,,, -22731,,,0.7517132077898298,0.2625406299318586,0.7263388303671919,0.2853445419114026,3554.0,0.743461789980976,0.2867556744798938,3581.0,5332.193813800812,5606.21314406395,5332.193813800812,271.30557203292847,1.8549041748046875,0.0 -22800,0.14129913,0.30407804,,,,,,,,,,,,,, -22900,0.25593874,0.23329495,,,,,,,,,,,,,, -23000,0.19894136,0.32173085,,,,,,,,,,,,,, -23078,,,0.7513458388192313,0.2621175731931414,0.7253771059018008,0.2853178712311392,3554.0,0.7425826519303267,0.2867709460520804,3581.0,5412.402036905289,5690.515196084976,5412.402036905289,275.3547251224518,1.886202812194824,0.0 -23100,0.14441751,0.27630904,,,,,,,,,,,,,, -23200,0.22949277,0.3046853,,,,,,,,,,,,,, -23300,0.15938778,0.24638125,,,,,,,,,,,,,, -23400,0.1591165,0.35519683,,,,,,,,,,,,,, -23423,,,0.7511231558663505,0.2625161579677036,0.7256902846044949,0.2854214798686339,3554.0,0.742896741810074,0.2868492128595364,3581.0,5492.464476823807,5774.673628091812,5492.464476823807,279.40776801109314,1.916372299194336,0.0 -23500,0.19372332,0.27363366,,,,,,,,,,,,,, -23600,0.091184095,0.280835,,,,,,,,,,,,,, -23700,0.107292674,0.250901,,,,,,,,,,,,,, -23767,,,0.7518348693847656,0.2626339537756784,0.7262959649338773,0.285478427695906,3554.0,0.7434283152401564,0.2868443382282358,3581.0,5572.554956436157,5858.86038517952,5572.554956436157,283.4580383300781,1.949268341064453,0.0 -23800,0.22344205,0.23278862,,,,,,,,,,,,,, -23900,0.12386371,0.2326814,,,,,,,,,,,,,, -24000,0.20045482,0.22757544,,,,,,,,,,,,,, -24100,0.17310104,0.27123448,,,,,,,,,,,,,, -24112,,,0.7516160011291504,0.2618482112884521,0.7255825027697664,0.2851764462137732,3554.0,0.7427327087623918,0.2865892211607267,3581.0,5652.6396470069885,5943.040595293045,5652.6396470069885,287.5047097206116,1.9845812320709229,0.0 -24200,0.2169618,0.2735992,,,,,,,,,,,,,, -24300,0.3042982,0.2947892,,,,,,,,,,,,,, -24400,0.13582914,0.2151804,,,,,,,,,,,,,, -24460,,,0.7521277836390904,0.2623759337833949,0.726524717967431,0.2854077066003974,3554.0,0.7436621611892628,0.2867930012020909,3581.0,5732.652978897095,6027.151556968689,5732.652978897095,291.55988907814026,2.014088153839112,0.0 -24500,0.14911778,0.35990208,,,,,,,,,,,,,, -24600,0.17020667,0.19170389,,,,,,,,,,,,,, -24700,0.21463004,0.27974138,,,,,,,,,,,,,, -24800,0.11015693,0.26728508,,,,,,,,,,,,,, -24805,,,0.7513463837759835,0.2622409377779279,0.7257656425858188,0.2851812204887978,3554.0,0.7429490333094806,0.2865136473314018,3581.0,5812.645381689072,6111.23862695694,5812.645381689072,295.6112983226776,2.044174194335937,0.0 -24900,0.122257,0.30174953,,,,,,,,,,,,,, -25000,0.07685213,0.24920137,,,,,,,,,,,,,, -25100,0.16668338,0.25593555,,,,,,,,,,,,,, -25151,,,0.7522298949105399,0.2617241314479283,0.7262362006278137,0.2851499300963703,3554.0,0.7433074380192335,0.2866022429030822,3581.0,5892.711507558823,6195.396842479706,5892.711507558823,299.6589345932007,2.075486183166504,0.0 -25200,0.21767516,0.32640144,,,,,,,,,,,,,, -25300,0.21684471,0.25166312,,,,,,,,,,,,,, -25400,0.13262275,0.2782961,,,,,,,,,,,,,, -25500,,,0.7523572785513741,0.2621656315667288,0.7267561500905669,0.2852240859221124,3554.0,0.7438676456471656,0.2866298885393396,3581.0,5972.851857185364,6279.633534908295,5972.851857185364,303.71176838874817,2.105675220489502,0.0 -25500,0.15185642,0.20727353,,,,,,,,,,,,,, -25600,0.10019862,0.24886307,,,,,,,,,,,,,, -25700,0.14962222,0.26128116,,,,,,,,,,,,,, -25800,0.08761603,0.2707921,,,,,,,,,,,,,, -25846,,,0.7516216550554548,0.2622879913875034,0.726183992728264,0.2852599445057505,3554.0,0.743266872905613,0.2866360244388788,3581.0,6052.838146686554,6363.716654062271,6052.838146686554,307.7598524093628,2.14125919342041,0.0 -25900,0.20034118,0.19361813,,,,,,,,,,,,,, -26000,0.13296364,0.2576952,,,,,,,,,,,,,, -26100,0.102198,0.21735309,,,,,,,,,,,,,, -26190,,,0.7522223336356026,0.2617473602294922,0.7261280066254572,0.2853095763576252,3554.0,0.743290802913816,0.286731574030037,3581.0,6132.949994087219,6447.928359508514,6132.949994087219,311.8156065940857,2.172098398208618,0.0 -26200,0.13156323,0.2516036,,,,,,,,,,,,,, -26300,0.12106691,0.28041336,,,,,,,,,,,,,, -26400,0.09382277,0.25959992,,,,,,,,,,,,,, -26500,0.13815486,0.24627215,,,,,,,,,,,,,, -26538,,,0.7521584374564034,0.2619017532893589,0.7264184474140053,0.2850819396113974,3554.0,0.7435386250785395,0.2864591400904949,3581.0,6212.970285415649,6532.04209446907,6212.970285415649,315.86508679389954,2.203031063079834,0.0 -26600,0.11023352,0.31530958,,,,,,,,,,,,,, -26700,0.12737457,0.33264834,,,,,,,,,,,,,, -26800,0.16549976,0.28869307,,,,,,,,,,,,,, -26884,,,0.7518197468348912,0.2620894398008074,0.7260853472759566,0.2851997336847566,3554.0,0.7431427232049358,0.2866152646454377,3581.0,6293.131700754166,6616.29847073555,6293.131700754166,319.9152567386627,2.234532356262207,0.0 -26900,0.15053464,0.20042989,,,,,,,,,,,,,, -27000,0.15561208,0.24212192,,,,,,,,,,,,,, -27100,0.20614968,0.3817905,,,,,,,,,,,,,, -27200,0.1866552,0.20838156,,,,,,,,,,,,,, -27228,,,0.7528512137276786,0.2616877555847168,0.726850673866594,0.2852557713085168,3554.0,0.743790674196279,0.2866919633896782,3581.0,6373.266193151474,6700.527494668961,6373.266193151474,323.96526765823364,2.2656874656677246,0.0 -27300,0.24872293,0.23462558,,,,,,,,,,,,,, -27400,0.12118628,0.21592721,,,,,,,,,,,,,, -27500,0.11763859,0.25711644,,,,,,,,,,,,,, -27577,,,0.7521657262529645,0.2617858478001186,0.7262759061093135,0.2850842065333515,3554.0,0.7433965449158755,0.2864751275176277,3581.0,6453.276751041412,6784.637478113174,6453.276751041412,328.01572155952454,2.301090955734253,0.0 -27600,0.11802829,0.31669688,,,,,,,,,,,,,, -27700,0.07633119,0.4022944,,,,,,,,,,,,,, -27800,0.13930203,0.33998486,,,,,,,,,,,,,, -27900,0.21908286,0.20257878,,,,,,,,,,,,,, -27924,,,0.7517365046909877,0.2619636058807373,0.7259130612074424,0.2852299421371606,3554.0,0.7430076652384111,0.2866200029234153,3581.0,6533.27613568306,6868.735721111298,6533.27613568306,332.0697326660156,2.3325891494750977,0.0 -28000,0.1077953,0.23701316,,,,,,,,,,,,,, -28100,0.47012877,0.23614547,,,,,,,,,,,,,, -28200,0.36824414,0.28326,,,,,,,,,,,,,, -28267,,,0.7525230816432408,0.2617320503507341,0.7264022354873031,0.2852017601755944,3554.0,0.7435234898596761,0.2866514323643884,3581.0,6613.255655527115,6952.815067529678,6613.255655527115,336.1240510940552,2.365093231201172,0.0 -28300,0.15329011,0.2614051,,,,,,,,,,,,,, -28400,0.14184923,0.27787203,,,,,,,,,,,,,, -28500,0.11984189,0.25674096,,,,,,,,,,,,,, -28600,0.07817113,0.33787692,,,,,,,,,,,,,, -28614,,,0.7522595269339425,0.2616327660424368,0.7261526679885341,0.2851194468655476,3554.0,0.743265509372382,0.2865337594465582,3581.0,6693.265082120895,7036.919741868973,6693.265082120895,340.17482447624207,2.3963160514831543,0.0 -28700,0.11899655,0.25159845,,,,,,,,,,,,,, -28800,0.11409744,0.23419686,,,,,,,,,,,,,, -28900,0.19770823,0.30969644,,,,,,,,,,,,,, -28962,,,0.7524974686758858,0.2617249659129551,0.7265269848893852,0.2851314169004115,3554.0,0.7436247322020735,0.2864951714561226,3581.0,6773.419958114624,7121.172005176544,6773.419958114624,344.2283432483673,2.427100419998169,0.0 -29000,0.19865178,0.24124472,,,,,,,,,,,,,, -29100,0.18319704,0.20189571,,,,,,,,,,,,,, -29200,0.17647552,0.27453727,,,,,,,,,,,,,, -29300,0.11190341,0.26661733,,,,,,,,,,,,,, -29307,,,0.7518026488167899,0.2619178976331438,0.7258280859814645,0.2852925229220157,3554.0,0.7429303529042167,0.2866700786813215,3581.0,6853.474825620651,7205.318714380264,6853.474825620651,348.2759618759155,2.458501100540161,0.0 -29400,0.15300505,0.22095726,,,,,,,,,,,,,, -29500,0.12003506,0.19421852,,,,,,,,,,,,,, -29600,0.081166446,0.28250876,,,,,,,,,,,,,, -29654,,,0.7526144981384277,0.2615166221346174,0.7264150126837718,0.2851482127312535,3554.0,0.743440109802604,0.2866022088147514,3581.0,6933.58352804184,7289.521868467331,6933.58352804184,352.32594203948975,2.4895973205566406,0.0 -29700,0.10705215,0.38037705,,,,,,,,,,,,,, -29800,0.054951627,0.24768445,,,,,,,,,,,,,, -29900,0.09736342,0.27063316,,,,,,,,,,,,,, -30000,0.08044993,0.24498004,,,,,,,,,,,,,, -30003,,,0.7527388163975307,0.2616254431860788,0.7267130785734384,0.2851014488791238,3554.0,0.7438008325188494,0.2864873652283755,3581.0,7013.698459863663,7373.734119653702,7013.698459863663,356.37841534614563,2.5211710929870605,0.0 -30100,0.08983642,0.2656029,,,,,,,,,,,,,, -30200,0.18798275,0.24336497,,,,,,,,,,,,,, -30300,0.09956381,0.2402651,,,,,,,,,,,,,, -30346,,,0.7524821417672294,0.2616192102432251,0.7263430894326814,0.2851408452349026,3554.0,0.7434115437814158,0.2865387022545204,3581.0,7093.814469337463,7457.95020198822,7093.814469337463,360.4285726547241,2.558183431625366,0.0 -30400,0.24709876,0.22302218,,,,,,,,,,,,,, -30500,0.2275936,0.21145573,,,,,,,,,,,,,, -30600,0.083112314,0.24721777,,,,,,,,,,,,,, -30692,,,0.7525557109287807,0.2612156527382986,0.7261212745541995,0.2850283062988006,3554.0,0.7432238534321768,0.2864683439398038,3581.0,7173.847417593002,7542.083223819733,7173.847417593002,364.4844374656677,2.589385509490967,0.0 -30700,0.2969575,0.31215444,,,,,,,,,,,,,, -30800,0.087951615,0.31149262,,,,,,,,,,,,,, -30900,0.08257181,0.19374637,,,,,,,,,,,,,, -31000,0.09757625,0.30899033,,,,,,,,,,,,,, -31038,,,0.7527132715497699,0.2613873652049473,0.726601724619267,0.2849861278115327,3554.0,0.7436732058084334,0.2863576250414514,3581.0,7254.003686904907,7626.33788728714,7254.003686904907,368.5385098457336,2.620823383331299,0.0 -31100,0.111775145,0.25817928,,,,,,,,,,,,,, -31200,0.094069935,0.25333488,,,,,,,,,,,,,, -31300,0.105048746,0.29727566,,,,,,,,,,,,,, -31381,,,0.7524211066109794,0.2613832950592041,0.7261790467167276,0.2850200285989378,3554.0,0.7433383220469143,0.2863485575454657,3581.0,7334.081294298172,7710.508943557739,7334.081294298172,372.5858800411224,2.65349555015564,0.0 -31400,0.06066174,0.26039287,,,,,,,,,,,,,, -31500,0.123486295,0.1994394,,,,,,,,,,,,,, -31600,0.081446454,0.21815437,,,,,,,,,,,,,, -31700,0.06511427,0.24000551,,,,,,,,,,,,,, -31728,,,0.7530506678989956,0.2609822239194597,0.7264788986661156,0.2849472294916379,3554.0,0.7435512377609257,0.286343989709142,3581.0,7414.202308893204,7794.726012229919,7414.202308893204,376.6345460414887,2.6875240802764893,0.0 -31800,0.13048516,0.21549505,,,,,,,,,,,,,, -31900,0.07317866,0.32218423,,,,,,,,,,,,,, -32000,0.09709072,0.18699129,,,,,,,,,,,,,, -32077,,,0.7527856826782227,0.261205928666251,0.72648425684528,0.2849471092760797,3554.0,0.7436087106866098,0.2863007316173904,3581.0,7494.393157243729,7879.013952970505,7494.393157243729,380.68638134002686,2.719357490539551,0.0 -32100,0.076134056,0.3881074,,,,,,,,,,,,,, -32200,0.106457435,0.22444977,,,,,,,,,,,,,, -32300,0.057962924,0.3103045,,,,,,,,,,,,,, -32400,0.10392969,0.31767845,,,,,,,,,,,,,, -32421,,,0.7528745106288365,0.2613171679633004,0.7266306450478335,0.2850152199766108,3554.0,0.7437317013840408,0.2863673402157218,3581.0,7574.377175331116,7963.093930482864,7574.377175331116,384.7363765239716,2.7521777153015137,0.0 -32500,0.071249686,0.28253075,,,,,,,,,,,,,, -32600,0.08425302,0.24872503,,,,,,,,,,,,,, -32700,0.09769997,0.33241558,,,,,,,,,,,,,, -32768,,,0.7530239650181362,0.2608934129987444,0.7263984572840462,0.2849223276974448,3554.0,0.7434478137653588,0.2863153214229614,3581.0,7654.5212988853455,8047.334677219391,7654.5212988853455,388.787223815918,2.7848803997039795,0.0 -32800,0.08430729,0.25998458,,,,,,,,,,,,,, -32900,0.09586887,0.22574018,,,,,,,,,,,,,, -33000,0.0654055,0.26888326,,,,,,,,,,,,,, -33100,0.059330568,0.26370773,,,,,,,,,,,,,, -33114,,,0.7526590483529227,0.261080128805978,0.7261710781425859,0.2849193051348392,3554.0,0.7433103014390184,0.2862740745427255,3581.0,7734.511190891266,8131.42751955986,7734.511190891266,392.8369917869568,2.824607610702514,0.0 -33200,0.10026174,0.3472522,,,,,,,,,,,,,, -33300,0.090240166,0.25117278,,,,,,,,,,,,,, -33400,0.079422645,0.2507039,,,,,,,,,,,,,, -33461,,,0.7532807758876255,0.2611795323235648,0.726994863841798,0.2850068048875387,3554.0,0.7440449049671879,0.2863309338784557,3581.0,7814.515537023544,8215.531928062439,7814.515537023544,396.8901824951172,2.85813307762146,0.0 -33500,0.07586191,0.25141376,,,,,,,,,,,,,, -33600,0.09471509,0.22661698,,,,,,,,,,,,,, -33700,0.07859772,0.28564343,,,,,,,,,,,,,, -33800,0.11528587,0.29432255,,,,,,,,,,,,,, -33807,,,0.7532802990504673,0.2608542442321777,0.7267084760349254,0.2848649505288935,3554.0,0.7437748572108001,0.2862424064834369,3581.0,7894.61809015274,8299.732083559036,7894.61809015274,400.94162368774414,2.8913586139678955,0.0 -33900,0.080007024,0.36831266,,,,,,,,,,,,,, -34000,0.06864514,0.25389588,,,,,,,,,,,,,, -34100,0.07668186,0.30035374,,,,,,,,,,,,,, -34153,,,0.7531078883579799,0.2609092337744577,0.7266382014543472,0.2848093422464124,3554.0,0.7437286334342712,0.2861671053607582,3581.0,7974.578327178955,8383.78547167778,7974.578327178955,404.9890365600586,2.923970937728882,0.0 -34200,0.06957292,0.33384857,,,,,,,,,,,,,, -34300,0.12770575,0.24990702,,,,,,,,,,,,,, -34400,0.064535044,0.26869228,,,,,,,,,,,,,, -34499,,,0.7530832971845355,0.260980623109,0.7266129218398284,0.2848652940019168,3554.0,0.7436772964081262,0.2862170106770106,3581.0,8054.679794073105,8467.986461162567,8054.679794073105,409.04110741615295,2.958050012588501,0.0 -34500,0.08319271,0.26262006,,,,,,,,,,,,,, -34600,0.0736434,0.25872636,,,,,,,,,,,,,, -34700,0.17248301,0.3554754,,,,,,,,,,,,,, -34800,0.0739981,0.3642606,,,,,,,,,,,,,, -34842,,,0.753159659249442,0.2608306578227451,0.726635728448579,0.284789936020593,3554.0,0.7437078395524993,0.286145459270717,3581.0,8134.677657842636,8552.080843925476,8134.677657842636,413.0929937362671,2.9907097816467285,0.0 -34900,0.07918812,0.22587974,,,,,,,,,,,,,, -35000,0.054252643,0.29670697,,,,,,,,,,,,,, -35100,0.09014333,0.270541,,,,,,,,,,,,,, -35191,,,0.7532715116228376,0.2607841832297189,0.726746120678285,0.2847346712111353,3554.0,0.743826194236945,0.2861053373053965,3581.0,8214.799205303192,8636.29803609848,8214.799205303192,417.1439554691314,3.023453950881958,0.0 -35200,0.07477863,0.22071287,,,,,,,,,,,,,, -35300,0.08930795,0.2523591,,,,,,,,,,,,,, -35400,0.10781976,0.2926974,,,,,,,,,,,,,, -35500,0.072432026,0.19216615,,,,,,,,,,,,,, -35537,,,0.7532127244131905,0.2608253444944109,0.7266732357027293,0.2847871195418014,3554.0,0.7437623127050754,0.2861471636872557,3581.0,8294.869604349136,8720.475849628448,8294.869604349136,421.1991891860962,3.063876152038574,0.0 -35600,0.10644449,0.23268858,,,,,,,,,,,,,, -35700,0.07046689,0.29634595,,,,,,,,,,,,,, -35800,0.11501973,0.26608717,,,,,,,,,,,,,, -35881,,,0.7531414713178363,0.2607960360390799,0.7266046097926632,0.2847555543709553,3554.0,0.7437063396659452,0.28610857569682,3581.0,8374.977184772491,8804.681557416916,8374.977184772491,425.2508318424225,3.098477363586426,0.0 -35900,0.088387765,0.26478463,,,,,,,,,,,,,, -36000,0.05718526,0.22798607,,,,,,,,,,,,,, -36100,0.08686351,0.2573295,,,,,,,,,,,,,, -36189,,,0.7533424922398159,0.2607954229627336,0.7268173569833286,0.284757580861793,3554.0,0.7439003704447081,0.2861180522527751,3581.0,8446.071097612381,8879.871040344238,8446.071097612381,429.2997944355011,3.1342813968658447,0.0 -36189,,,,,,,,,,,8446.071097612381,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 13203c4a2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.987623453140259,0.0,30.755135536193848,1,0,30.755135536193848,0.8250461419645351,3581,0.2972505852284627,34.74285674095154,0.8154546192714146,0.2794570582253592,0.8275999671365011,3554,0.2738082104341059 -8.039196014404297,0.0185468196868896,110.89293766021729,343,0,110.89293766021729,0.3104794482817299,3581,0.7181294560265987,118.96266174316406,0.2883876051221575,0.7228562491280692,0.3085966008260938,3554,0.7009436715232836 -12.090436458587646,0.0432403087615966,190.94924879074097,682,0,190.94924879074097,0.3040596611238131,3581,0.7172930647427395,203.1064257621765,0.2817029271806989,0.7227890832083566,0.3021568251112303,3554,0.7006566654649691 -16.14446997642517,0.0671541690826416,270.9735543727875,1021,0,270.9735543727875,0.2995580243385227,3581,0.7312118074342013,287.2202172279358,0.2768511431557791,0.737203666142055,0.2975125889732519,3554,0.7143450958646947 -20.1980848312378,0.0915133953094482,351.1611316204071,1367,0,351.1611316204071,0.296217538364371,3581,0.7299091559620218,371.4979031085968,0.274091226713998,0.7352179118565151,0.2945705707917135,3554,0.7125450224439013 -24.2577121257782,0.1165361404418945,431.1475255489349,1712,0,431.1475255489349,0.2956740340185179,3581,0.7371561987311506,455.5808773040772,0.2730787651879446,0.7424981934683663,0.2939981386509918,3554,0.7203474930316193 -28.313236951828003,0.1428146362304687,511.2024140357971,2057,0,511.2024140357971,0.2945841619070267,3581,0.7336673262531416,539.7296917438507,0.2719074657985142,0.7400174140930176,0.2929652122278594,3554,0.7164212529016601 -32.36935877799988,0.1679451465606689,591.3643488883972,2403,0,591.3643488883972,0.2938471381073722,3581,0.7362529261423136,623.9852411746979,0.2707936423165457,0.7423323222569057,0.2921564362997678,3554,0.7194922452034679 -36.42832922935486,0.1926379203796386,671.3672761917114,2748,0,671.3672761917114,0.2929049366447919,3581,0.7373717051583007,708.083958864212,0.2697193452290126,0.7440058163234166,0.2913248537629255,3554,0.7201583767849606 -40.48659348487854,0.2232506275177002,751.5329916477203,3093,0,751.5329916477203,0.2933784917358978,3581,0.7317550390734082,792.350554227829,0.2706356218882969,0.7383698054722377,0.2919343809901695,3554,0.7140315736889772 -44.54193639755249,0.2484974861145019,831.5558547973633,3437,0,831.5558547973633,0.2942423582143605,3581,0.7320612886370776,876.4660019874573,0.2714465005057199,0.7383463723318917,0.2927429508344471,3554,0.7147255952799663 -48.595699310302734,0.2732932567596435,911.606143951416,3783,0,911.606143951416,0.2923623526838697,3581,0.7369284886815833,960.606726884842,0.26948082447052,0.7434975760323661,0.2907994087317986,3554,0.7195923332424733 -52.64918851852417,0.2977328300476074,991.7642273902892,4127,0,991.7642273902892,0.292056307650185,3581,0.7355574560178721,1044.8545212745669,0.2693336691175188,0.7420037133353097,0.2905618971361494,3554,0.7181948102050506 -56.71219515800476,0.3224613666534424,1071.9001495838163,4473,0,1071.9001495838163,0.293317644065467,3581,0.7364044146842712,1129.090569972992,0.2699408190590994,0.7434308188302177,0.2917687582983082,3554,0.7193578098621272 -60.76723957061768,0.3475849628448486,1151.9339997768402,4819,0,1151.9339997768402,0.2934710756422787,3581,0.7332005888281905,1213.2164514064789,0.2705390623637608,0.7400569915771484,0.2919701021845983,3554,0.7159953463527012 -64.8251564502716,0.3724362850189209,1231.9920871257782,5165,0,1231.9920871257782,0.291159818639172,3581,0.7376907037576794,1297.369089126587,0.2682174614497593,0.7443958691188267,0.2895858841929867,3554,0.7204964916291503 -68.88069200515747,0.3992469310760498,1311.9626994132996,5505,0,1311.9626994132996,0.2934232497142034,3581,0.7343092776982686,1381.4337227344513,0.2700718641281128,0.7416069167000907,0.2917918396854776,3554,0.7170744698974747 -72.9327085018158,0.4251205921173095,1391.9404091835022,5851,0,1391.9404091835022,0.2913406231455948,3581,0.7387483283082589,1465.5014822483065,0.2684932265962873,0.744828428540911,0.2897629445365257,3554,0.7217045549468908 -76.98855376243591,0.451430082321167,1472.0305316448212,6197,0,1472.0305316448212,0.2911732494414968,3581,0.7351471006920204,1549.6855239868164,0.2683504138674055,0.741647584097726,0.2897854420195554,3554,0.7175782761281303 -81.04493761062622,0.4772851467132568,1552.044453382492,6539,0,1552.044453382492,0.2946943013037559,3581,0.7376347307185492,1633.7936868667605,0.2707888058253697,0.7450931412833077,0.2930879007918014,3554,0.7206491310407288 -85.1014678478241,0.5031547546386719,1632.2036936283112,6885,0,1632.2036936283112,0.2927858661054,3581,0.7375047860016406,1718.0472900867462,0.2690360375813075,0.7446616036551339,0.2912670816003974,3554,0.7202530379501969 -89.15229082107544,0.5326182842254639,1712.2826426029203,7230,0,1712.2826426029203,0.2910377824150028,3581,0.7374284281407079,1802.2185401916504,0.2680553708757673,0.7442918504987445,0.2895192504264561,3554,0.7200282692037141 -93.20718359947205,0.5594367980957031,1792.2599248886108,7573,0,1792.2599248886108,0.2932264236923171,3581,0.7413721071278973,1886.2895441055296,0.2697854382651193,0.7479324340820312,0.291658640847021,3554,0.7244976088782006 -97.26000475883484,0.5858092308044434,1872.4262573719025,7920,0,1872.4262573719025,0.2909352106277053,3581,0.7383865829420903,1970.5471255779264,0.2679447957447597,0.7448616709027972,0.2893895550128376,3554,0.7212492471071328 -101.31509613990784,0.6120538711547852,1952.595893383026,8266,0,1952.595893383026,0.2911959863581227,3581,0.7410095436417901,2054.8102276325226,0.2681253126689366,0.7474032129560199,0.2896141520228088,3554,0.7241728894819218 -105.36998534202576,0.6376612186431885,2032.6832168102264,8609,0,2032.6832168102264,0.2912986944987433,3581,0.7344287913859606,2138.9900114536285,0.2686831951141357,0.7405197960989816,0.2899430617899726,3554,0.7167463844655669 -109.42215585708618,0.6642770767211914,2112.7515869140625,8955,0,2112.7515869140625,0.2902178557949071,3581,0.739195499031346,2223.14924287796,0.266757777759007,0.7465249470302037,0.2887311859216727,3554,0.7218790392427546 -113.47452092170715,0.6899559497833252,2192.7990798950195,9299,0,2192.7990798950195,0.2902843621282463,3581,0.7367883856421041,2307.28660941124,0.2673584733690534,0.7434581347874233,0.2887551259914005,3554,0.7193650227956176 -117.53413605690002,0.7210087776184082,2273.013704776764,9642,0,2273.013704776764,0.2910233630510856,3581,0.7355956349483385,2391.6038851737976,0.2680794170924595,0.7420975821358817,0.2894657373294175,3554,0.7183655849922622 -121.5886685848236,0.748103141784668,2353.11772108078,9987,0,2353.11772108078,0.2913739274447605,3581,0.7351992558381039,2475.801580429077,0.2682451180049351,0.7422417231968471,0.289962261931978,3554,0.7175602094471019 -125.6482367515564,0.7754595279693604,2433.091839790344,10332,0,2433.091839790344,0.2899997245662873,3581,0.739086075489563,2559.874430179596,0.2669530766350882,0.7457905496869769,0.2885337232805465,3554,0.721928430663513 -129.70497846603394,0.8024251461029053,2513.0755269527435,10677,0,2513.0755269527435,0.2908333206070231,3581,0.7353981953364982,2643.954001903534,0.2680451188768659,0.741919994354248,0.2894272340034995,3554,0.7178960573693374 -133.76000213623047,0.829798698425293,2593.118797302246,11020,0,2593.118797302246,0.2911618639390184,3581,0.7369526913964326,2728.0914771556854,0.2675068889345441,0.7446566990443638,0.28960034440727,3554,0.7196864448508723 -137.81110620498657,0.8560574054718018,2673.3326404094696,11366,0,2673.3326404094696,0.2903087011964186,3581,0.7356547441138998,2812.394788742065,0.2673390763146536,0.7424424035208566,0.2889510430039216,3554,0.7179938784863886 -141.86714124679563,0.8846144676208496,2753.3085412979126,11711,0,2753.3085412979126,0.2908966908139311,3581,0.736668599247766,2896.467357158661,0.2679608038493565,0.7434718949454171,0.2893848494324177,3554,0.7194423042258723 -145.92257690429688,0.9114329814910888,2833.4092671871185,12053,0,2833.4092671871185,0.2898577466686156,3581,0.7418477075188494,2980.6624734401703,0.2662361689976283,0.7494621276855469,0.2883218347724395,3554,0.7248269995075971 -149.9753041267395,0.9421532154083252,2913.51412153244,12397,0,2913.51412153244,0.290946391600199,3581,0.7357152849893536,3064.8627235889435,0.2680514369692121,0.7417304856436593,0.2894438237505276,3554,0.7187457409345104 -154.03021216392517,0.969123363494873,2993.7409834861755,12743,0,2993.7409834861755,0.2910311011021712,3581,0.736301331572012,3149.1837322711945,0.2683757202965872,0.7417310987200055,0.2895662031887486,3554,0.7193276529306767 -158.08467936515808,0.9979074001312256,3073.824416399002,13087,0,3073.824416399002,0.2895660528221865,3581,0.7402932796355767,3233.3625707626343,0.2658177273614066,0.7479927880423409,0.288119220035963,3554,0.7230232852727561 -162.13646292686462,1.0256588459014893,3153.797593355179,13433,0,3153.797593355179,0.2898521561823687,3581,0.7369149215259355,3317.427288770676,0.2668966054916382,0.7433491434369769,0.288490754805325,3554,0.7194958173229108 -166.1916003227234,1.0534124374389648,3233.8913176059723,13780,0,3233.8913176059723,0.2901215221721411,3581,0.7372898249877827,3401.6160349845886,0.267004234450204,0.744424547467913,0.288745646135956,3554,0.7197686722926632 -170.2469208240509,1.0821444988250732,3313.9427967071533,14119,0,3313.9427967071533,0.2922536449970329,3581,0.7371195196872382,3485.763567209244,0.2686776093074253,0.744767325265067,0.2906745562878095,3554,0.7200198197673396 -174.30322456359863,1.1104457378387451,3393.917350292206,14465,0,3393.917350292206,0.2901606214875384,3581,0.7368043389809061,3569.834620714188,0.2669003861291067,0.743727479662214,0.2886858818298923,3554,0.7195809986327026 -178.3620719909668,1.138521432876587,3474.073793888092,14813,0,3474.073793888092,0.2903403692557072,3581,0.7348131714037629,3654.090247869492,0.2672189133507864,0.7420859336853027,0.2889552333748065,3554,0.7172307501231008 -182.4195261001587,1.1659214496612549,3554.166335821152,15157,0,3554.166335821152,0.289497194394024,3581,0.7399728493263055,3738.2798635959625,0.2663203307560512,0.7469425882611956,0.2880200250268184,3554,0.7227949444068303 -186.43106985092163,1.1933035850524902,3638.211625099182,15501,0,3638.211625099182,0.2897083375148352,3581,0.7369443056670623,3826.375822782517,0.2666252681187221,0.74376494543893,0.2883164078986705,3554,0.7195468574141812 -190.48362708091736,1.2215111255645752,3718.3154895305634,15845,0,3718.3154895305634,0.2901179769857407,3581,0.7381134672359327,3910.5722908973694,0.2667724745614188,0.7452110563005719,0.2887018876727806,3554,0.7208480706158554 -194.5382580757141,1.2500929832458496,3798.436872243881,16188,0,3798.436872243881,0.289755788471272,3581,0.7381221256719491,3994.7888975143433,0.2665325743811471,0.7452451160975865,0.2883681349359876,3554,0.7206068838588562 -198.59056663513184,1.278146743774414,3878.410897016525,16531,0,3878.410897016525,0.2888766504206227,3581,0.7409147099055781,4078.8552227020255,0.2653563703809465,0.7482778685433524,0.2874118201707143,3554,0.7236172188247397 -202.6435163021088,1.31022047996521,3958.4317643642426,16876,0,3958.4317643642426,0.2892800858153274,3581,0.737625595045902,4162.973185300827,0.2660742487226213,0.7445151465279716,0.2879061465459254,3554,0.7202597700214547 -206.69846272468567,1.3392434120178225,4038.617102861405,17219,0,4038.617102861405,0.2890235370379259,3581,0.7409924994764032,4247.254467487335,0.2657501527241298,0.7482468741280692,0.2875324135492139,3554,0.7237485629088702 -210.75095224380493,1.368750810623169,4118.709421873093,17563,0,4118.709421873093,0.2890411607049358,3581,0.7399236257766685,4331.440586566925,0.2655086347034999,0.7473959241594587,0.2876453646529439,3554,0.7226291156311551 -214.8023371696472,1.397280216217041,4198.857976913452,17910,0,4198.857976913452,0.2900685148177883,3581,0.737226557045867,4415.681092500687,0.2669028724942888,0.7436424664088658,0.2886433255222988,3554,0.7204089747027996 -218.8605060577393,1.4254579544067385,4278.897131919861,18255,0,4278.897131919861,0.2886897441029915,3581,0.7405509874162245,4499.818293809891,0.2655533381870815,0.7477162224905831,0.287289921594726,3554,0.7232185153392304 -222.91399669647217,1.4582040309906006,4358.909790039063,18597,0,4358.909790039063,0.2893079700699001,3581,0.7419169750069813,4583.929002761841,0.2652980940682547,0.749964850289481,0.2878781363208708,3554,0.7248652624023987 -226.9695131778717,1.4869599342346191,4439.060678720474,18941,0,4439.060678720474,0.2892563262487783,3581,0.7415229139032393,4668.17610168457,0.265641553061349,0.7486487797328404,0.2878657712920301,3554,0.724370455164955 -231.0234580039978,1.5253701210021973,4519.121737480164,19286,0,4519.121737480164,0.2885983191998569,3581,0.7422357690763753,4752.34153175354,0.2651561668940952,0.7493959154401507,0.2871962393276062,3554,0.7251388730128024 -235.07659649848927,1.554959058761597,4599.299153327942,19628,0,4599.299153327942,0.2888892971913397,3581,0.7369073539165037,4836.61348938942,0.2651771136692592,0.7444512503487724,0.287422742612857,3554,0.7195708318312113 -239.13303208351127,1.5843939781188965,4679.32307100296,19972,0,4679.32307100296,0.2881554436064472,3581,0.7399095132077282,4920.735057592392,0.2648437534059797,0.7471006257193429,0.2867585516539462,3554,0.7225424917346651 -243.1871690750122,1.61844801902771,4759.342108011246,20317,0,4759.342108011246,0.2887418310724134,3581,0.7426582598479824,5004.854192733765,0.2653532709394182,0.7499920981270927,0.2872765776677687,3554,0.7255031605013716 -247.2418978214264,1.6471836566925049,4839.4510152339935,20660,0,4839.4510152339935,0.2885014060754677,3581,0.7406960673519967,5089.058373451233,0.2645001922334943,0.7488868577139718,0.2871020246773002,3554,0.7233265032577729 -251.29483437538147,1.676776647567749,4919.556501865387,21006,0,4919.556501865387,0.2881249686387357,3581,0.7413716980679279,5173.258333683014,0.2645917109080723,0.7488184656415667,0.2867401930208479,3554,0.7240962949977139 -255.34806871414185,1.7086007595062256,4999.725428342819,21349,0,4999.725428342819,0.2880124089705215,3581,0.7390779624668389,5257.524056196213,0.2646516731807163,0.746589115687779,0.2866960223900446,3554,0.721595399220069 -259.3984282016754,1.74009370803833,5079.9467849731445,21692,0,5079.9467849731445,0.2891149619410604,3581,0.742009149853393,5341.839040994644,0.2651270457676479,0.7493458475385394,0.2876467900659908,3554,0.7248130545028489 -263.4543421268463,1.769385814666748,5160.113996267319,22037,0,5160.113996267319,0.2879780479331018,3581,0.7397672285150796,5426.103360176086,0.2643566301890782,0.7473893165588379,0.286699439946627,3554,0.7223388122318163 -267.50645327568054,1.8041374683380127,5240.150569677353,22382,0,5240.150569677353,0.2877156700511379,3581,0.7427755918825049,5510.238574266434,0.2640033108847482,0.7503771781921387,0.2863131358372608,3554,0.7255582535743177 -271.55556893348694,1.8343710899353027,5320.23034787178,22723,0,5320.23034787178,0.2882191887849937,3581,0.7379846133456087,5594.409419775009,0.2649799925940377,0.7449972970145089,0.2868631220159063,3554,0.720803899985052 -275.6089406013489,1.864561557769776,5400.266304969788,23069,0,5400.266304969788,0.2876683554480242,3581,0.7411706450930257,5678.5408182144165,0.2638182810374668,0.7490667615618024,0.2863307560033589,3554,0.723865893293648 -279.65731549263,1.895724058151245,5480.29541182518,23414,0,5480.29541182518,0.2879020650438076,3581,0.7431733345259705,5762.661432981491,0.2641075338636126,0.75071838923863,0.2864420584365767,3554,0.7261588505029544 -283.71100759506226,1.9275357723236084,5560.32515335083,23757,0,5560.32515335083,0.2876682872713627,3581,0.7428464956105139,5846.788516521454,0.2638962609427316,0.7506724766322544,0.2863045318380258,3554,0.725684651646912 -287.70102548599243,1.963218688964844,5640.499708890915,24101,0,5640.499708890915,0.2873740368001257,3581,0.7408622138761868,5931.000749349594,0.2632827929088047,0.7491707120622907,0.2859516648274831,3554,0.7234838139024691 -291.7552845478058,1.997541904449463,5720.588813781738,24446,0,5720.588813781738,0.2876302787825502,3581,0.7441044231927185,6015.1900470256805,0.2637123380388532,0.751819406236921,0.2862745638167381,3554,0.7269841474834693 -295.8091485500336,2.027472972869873,5800.679148674011,24788,0,5800.679148674011,0.2870329489624058,3581,0.7423019686147375,6099.376039981842,0.2634058339255197,0.7498753411429269,0.2857233583088597,3554,0.7250181078977912 -299.86060905456543,2.058157205581665,5880.765575647354,25132,0,5880.765575647354,0.2870723209844492,3581,0.7424756145716979,6183.556484937668,0.262839674949646,0.7507461139133998,0.2856844428153137,3554,0.7252613554929305 -303.9175374507904,2.090134382247925,5960.800450563431,25477,0,5960.800450563431,0.2872153897087057,3581,0.7417635775185004,6267.692300796509,0.2632796083177839,0.7495565414428711,0.2858350900833568,3554,0.7245314066236986 -307.97506952285767,2.125797748565674,6040.952090978622,25822,0,6040.952090978622,0.2868910392413956,3581,0.7427894999214605,6351.949378728867,0.2631086451666696,0.7506115777151925,0.2855094433099149,3554,0.7256509225960186 -312.0299699306488,2.155958414077759,6121.030804872513,26162,0,6121.030804872513,0.2870183250685039,3581,0.742301900438076,6436.124837398529,0.2625585283551897,0.7509149823869977,0.2856405126156268,3554,0.7250569203494303 -316.0838363170624,2.1931822299957275,6201.037936925888,26506,0,6201.037936925888,0.2870047238245253,3581,0.7421060288894513,6520.235012292862,0.2629836286817278,0.7500993183680943,0.2856195607612021,3554,0.724892946328081 -320.1396276950836,2.2305245399475098,6281.198880434036,26852,0,6281.198880434036,0.2867839677944359,3581,0.7423404884285116,6604.5012130737305,0.2629303080695016,0.750182969229562,0.2854501426924328,3554,0.725069285378271 -324.193115234375,2.2626664638519287,6361.350186109543,27197,0,6361.350186109543,0.2868568827339605,3581,0.7428371554078819,6688.750422000885,0.2624048846108572,0.7514676366533551,0.2854863447490943,3554,0.7256813543058878 -328.2473568916321,2.2936313152313232,6441.491717815399,27544,0,6441.491717815399,0.2870722868961184,3581,0.7433970903291678,6772.989144325256,0.2626441035951887,0.7516813278198242,0.2857002940953415,3554,0.7261767797947735 -332.3001654148102,2.324267625808716,6521.546169281006,27890,0,6521.546169281006,0.2868547351691217,3581,0.7436194144224728,6857.138879299164,0.2627055134092058,0.7515984943934849,0.2854648776851347,3554,0.72654058642111 -336.35513162612915,2.356107234954834,6601.5510313510895,28230,0,6601.5510313510895,0.2867318126483524,3581,0.7424111876265359,6941.242228984833,0.2624414137431553,0.7507645743233817,0.2853406606662387,3554,0.7252138188264983 -340.4086203575134,2.393435478210449,6681.67812538147,28575,0,6681.67812538147,0.2865787219548485,3581,0.7428912195004886,7025.472114086151,0.2621686969484602,0.7513136182512555,0.2852038553610369,3554,0.7256709814205824 -344.46311378479004,2.425102949142456,6761.666354894638,28920,0,6761.666354894638,0.2863938268487329,3581,0.7425641760550474,7109.558537721634,0.2622612885066441,0.7508598055158343,0.2850795353002339,3554,0.7252665075882808 -348.5168535709381,2.458962202072144,6841.841877698898,29264,0,6841.841877698898,0.2878441830581541,3581,0.7443117484204831,7193.83359837532,0.2630555118833269,0.7527449471609933,0.2864035035897053,3554,0.7273026156707231 -352.57623052597046,2.48972225189209,6921.887184858322,29609,0,6921.887184858322,0.2864239950214675,3581,0.7430781599064508,7277.980953931808,0.261823228427342,0.7518707684108189,0.2850722021511853,3554,0.7259042682980444 -356.6236889362335,2.5209431648254395,7002.001019716263,29954,0,7002.001019716263,0.2863624655844212,3581,0.7435643958566043,7362.185358047485,0.2619952474321638,0.7520835059029716,0.284983792194974,3554,0.7263647282331528 -360.6707994937897,2.5538601875305176,7081.958555698395,30295,0,7081.958555698395,0.2862053524678861,3581,0.7430245048738132,7446.234877347946,0.2618656499045236,0.7514946801321847,0.284842779345236,3554,0.7258097445220174 -364.7251925468445,2.58677077293396,7162.066142082214,30639,0,7162.066142082214,0.2862288734161198,3581,0.7440924241002862,7530.441558122635,0.2613992350442068,0.7530231475830078,0.2848137730484137,3554,0.7269642947427195 -368.7781274318695,2.6195969581604004,7242.0318784713745,30982,0,7242.0318784713745,0.2863513187002583,3581,0.7440678805021292,7614.505212068558,0.2616708278656006,0.7528040749686105,0.2849631838135727,3554,0.7269228032014983 -372.8330626487732,2.653053998947144,7322.19585609436,31328,0,7322.19585609436,0.2861925352555152,3581,0.7435416248516475,7698.769409894943,0.2616599457604544,0.7522945404052734,0.2848349481603035,3554,0.7264111657859103 -376.886527299881,2.6854922771453857,7402.319948911667,31673,0,7402.319948911667,0.2861175068394826,3581,0.7437561086288746,7782.991308689117,0.2611620766775949,0.7529420171465192,0.2847431721684633,3554,0.7266534516565841 -380.9351851940155,2.71712589263916,7482.313843011856,32017,0,7482.313843011856,0.2862083522409941,3581,0.7436601840660779,7867.077488183975,0.2614663669041225,0.7525466510227748,0.2848308780049768,3554,0.7265271222785945 -384.9920229911804,2.7492971420288086,7562.373685836792,32363,0,7562.373685836792,0.286116859161198,3581,0.7436575933529391,7951.238343477249,0.261459265436445,0.7524669510977608,0.2847815037578696,3554,0.7265055521727279 -389.0479950904846,2.7822179794311523,7642.534946680069,32704,0,7642.534946680069,0.2861338010615924,3581,0.7436016884904706,8035.50048160553,0.2610211712973458,0.7528903824942452,0.2847771588241242,3554,0.7264628241286227 -393.1053960323334,2.8138113021850586,7722.624789953232,33048,0,7722.624789953232,0.2861019284723192,3581,0.7432874622574002,8119.691402673721,0.2612138135092599,0.7523625237601144,0.2847444773659521,3554,0.7261041695976365 -397.1605281829834,2.84609603881836,7802.774275302887,33393,0,7802.774275302887,0.2860446941649504,3581,0.7436916816837127,8203.940031290054,0.2612191098076956,0.7526929037911552,0.2846999976094277,3554,0.7265927943206598 -401.21487951278687,2.8802406787872314,7882.909619808197,33734,0,7882.909619808197,0.2862703930030019,3581,0.7437669487180606,8288.175689220428,0.2610527958188738,0.7531563213893345,0.2848787581444323,3554,0.7266605958954699 -405.2720079421997,2.913221836090088,7963.108328819275,34081,0,7963.108328819275,0.286056795522375,3581,0.7439093697640324,8372.476554393768,0.2610229934964861,0.7531188556126186,0.284685382832284,3554,0.7267718811550365 -409.32266187667847,2.9466326236724854,8043.085624933243,34426,0,8043.085624933243,0.2859959137636135,3581,0.743857350971272,8456.55000448227,0.2610433782849993,0.7529817308698382,0.2846396322255733,3554,0.7267184367526027 -413.373916387558,2.9849205017089844,8123.120993852615,34767,0,8123.120993852615,0.286009378654269,3581,0.7439245731595574,8540.69055390358,0.2610036305018833,0.7531920160566058,0.2846758686295371,3554,0.7267810175374578 -417.4235715866089,3.0231704711914062,8203.180258989334,35114,0,8203.180258989334,0.2859910050439821,3581,0.7440337921713558,8624.85003042221,0.2608975853238787,0.7533129964556012,0.2846280571846862,3554,0.7268919593240011 -421.4790394306183,3.05560564994812,8283.291579961777,35460,0,8283.291579961777,0.2859707224871718,3581,0.7440485865069115,8709.06115269661,0.2609228576932634,0.7532985551016671,0.2846089600845878,3554,0.7269186128306134 -425.5373492240906,3.088459014892578,8363.375246286392,35806,0,8363.375246286392,0.2859652001775866,3581,0.7440701303319603,8793.248180627823,0.2609026942934309,0.7533233506338937,0.284607654887099,3554,0.7269327639191756 -429.58997654914856,3.1220977306365967,8443.398798942566,36151,0,8443.398798942566,0.2859424973492914,3581,0.7439553890105767,8877.369792938232,0.2608860220227922,0.7532243728637695,0.2845854665297904,3554,0.7268157770074212 -433.6436011791229,3.156363010406494,8450.03044462204,36189,0,8450.03044462204,0.2859423439518029,3581,0.7439548435972843,8888.094502449036,0.2608859198434012,0.7532238279070173,0.2845853634878834,3554,0.7268150213667698 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/measurements.csv deleted file mode 100644 index 0ea61bfeb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.989067,0.812358,,,,,,,,,,,,,, -1,,,0.2794570582253592,0.8154546192714146,0.2738082104341059,0.8275999671365011,3554.0,0.2972505852284627,0.8250461419645351,3581.0,30.755135536193848,34.74285674095154,30.755135536193848,3.987623453140259,0.0,0.0 -100,0.514676,0.2509368,,,,,,,,,,,,,, -200,0.74592197,0.35523313,,,,,,,,,,,,,, -300,0.20089613,0.31525013,,,,,,,,,,,,,, -343,,,0.7228562491280692,0.2883876051221575,0.7009436715232836,0.3085966008260938,3554.0,0.7181294560265987,0.3104794482817299,3581.0,110.89293766021729,118.96266174316406,110.89293766021729,8.039196014404297,0.0185468196868896,0.0 -400,0.32358837,0.30901593,,,,,,,,,,,,,, -500,0.55632246,0.22281909,,,,,,,,,,,,,, -600,0.31315553,0.3897825,,,,,,,,,,,,,, -682,,,0.7227890832083566,0.2817029271806989,0.7006566654649691,0.3021568251112303,3554.0,0.7172930647427395,0.3040596611238131,3581.0,190.94924879074097,203.1064257621765,190.94924879074097,12.090436458587646,0.0432403087615966,0.0 -700,0.13118924,0.3050019,,,,,,,,,,,,,, -800,0.404001,0.2138642,,,,,,,,,,,,,, -900,0.2670999,0.31160304,,,,,,,,,,,,,, -1000,0.34227666,0.27773565,,,,,,,,,,,,,, -1021,,,0.737203666142055,0.2768511431557791,0.7143450958646947,0.2975125889732519,3554.0,0.7312118074342013,0.2995580243385227,3581.0,270.9735543727875,287.2202172279358,270.9735543727875,16.14446997642517,0.0671541690826416,0.0 -1100,0.5278814,0.31214052,,,,,,,,,,,,,, -1200,0.28010157,0.22142988,,,,,,,,,,,,,, -1300,0.2743484,0.38447583,,,,,,,,,,,,,, -1367,,,0.7352179118565151,0.274091226713998,0.7125450224439013,0.2945705707917135,3554.0,0.7299091559620218,0.296217538364371,3581.0,351.1611316204071,371.4979031085968,351.1611316204071,20.1980848312378,0.0915133953094482,0.0 -1400,0.11037203,0.22225507,,,,,,,,,,,,,, -1500,0.15995173,0.28352693,,,,,,,,,,,,,, -1600,0.35087708,0.22670645,,,,,,,,,,,,,, -1700,0.15465201,0.39817047,,,,,,,,,,,,,, -1712,,,0.7424981934683663,0.2730787651879446,0.7203474930316193,0.2939981386509918,3554.0,0.7371561987311506,0.2956740340185179,3581.0,431.1475255489349,455.5808773040772,431.1475255489349,24.2577121257782,0.1165361404418945,0.0 -1800,0.17508106,0.2576336,,,,,,,,,,,,,, -1900,0.10675119,0.22470652,,,,,,,,,,,,,, -2000,0.2546092,0.20954761,,,,,,,,,,,,,, -2057,,,0.7400174140930176,0.2719074657985142,0.7164212529016601,0.2929652122278594,3554.0,0.7336673262531416,0.2945841619070267,3581.0,511.2024140357971,539.7296917438507,511.2024140357971,28.313236951828003,0.1428146362304687,0.0 -2100,0.3130182,0.22670688,,,,,,,,,,,,,, -2200,0.2827903,0.27284914,,,,,,,,,,,,,, -2300,0.09563558,0.28709152,,,,,,,,,,,,,, -2400,0.38435832,0.2617041,,,,,,,,,,,,,, -2403,,,0.7423323222569057,0.2707936423165457,0.7194922452034679,0.2921564362997678,3554.0,0.7362529261423136,0.2938471381073722,3581.0,591.3643488883972,623.9852411746979,591.3643488883972,32.36935877799988,0.1679451465606689,0.0 -2500,0.2504872,0.21108036,,,,,,,,,,,,,, -2600,0.17178608,0.29907653,,,,,,,,,,,,,, -2700,0.20389397,0.28791758,,,,,,,,,,,,,, -2748,,,0.7440058163234166,0.2697193452290126,0.7201583767849606,0.2913248537629255,3554.0,0.7373717051583007,0.2929049366447919,3581.0,671.3672761917114,708.083958864212,671.3672761917114,36.42832922935486,0.1926379203796386,0.0 -2800,0.10866256,0.33096695,,,,,,,,,,,,,, -2900,0.26139447,0.23523022,,,,,,,,,,,,,, -3000,0.20904793,0.3332338,,,,,,,,,,,,,, -3093,,,0.7383698054722377,0.2706356218882969,0.7140315736889772,0.2919343809901695,3554.0,0.7317550390734082,0.2933784917358978,3581.0,751.5329916477203,792.350554227829,751.5329916477203,40.48659348487854,0.2232506275177002,0.0 -3100,0.1900137,0.19745392,,,,,,,,,,,,,, -3200,0.1062326,0.39215878,,,,,,,,,,,,,, -3300,0.11437465,0.21123484,,,,,,,,,,,,,, -3400,0.12074619,0.31886375,,,,,,,,,,,,,, -3437,,,0.7383463723318917,0.2714465005057199,0.7147255952799663,0.2927429508344471,3554.0,0.7320612886370776,0.2942423582143605,3581.0,831.5558547973633,876.4660019874573,831.5558547973633,44.54193639755249,0.2484974861145019,0.0 -3500,0.08694079,0.31767392,,,,,,,,,,,,,, -3600,0.24881768,0.30604604,,,,,,,,,,,,,, -3700,0.14864752,0.3295785,,,,,,,,,,,,,, -3783,,,0.7434975760323661,0.26948082447052,0.7195923332424733,0.2907994087317986,3554.0,0.7369284886815833,0.2923623526838697,3581.0,911.606143951416,960.606726884842,911.606143951416,48.595699310302734,0.2732932567596435,0.0 -3800,0.18822265,0.23577422,,,,,,,,,,,,,, -3900,0.12056084,0.2870297,,,,,,,,,,,,,, -4000,0.44204432,0.20916206,,,,,,,,,,,,,, -4100,0.10470914,0.28153688,,,,,,,,,,,,,, -4127,,,0.7420037133353097,0.2693336691175188,0.7181948102050506,0.2905618971361494,3554.0,0.7355574560178721,0.292056307650185,3581.0,991.7642273902892,1044.8545212745669,991.7642273902892,52.64918851852417,0.2977328300476074,0.0 -4200,0.07573862,0.25045484,,,,,,,,,,,,,, -4300,0.063927285,0.35381496,,,,,,,,,,,,,, -4400,0.24778357,0.23390724,,,,,,,,,,,,,, -4473,,,0.7434308188302177,0.2699408190590994,0.7193578098621272,0.2917687582983082,3554.0,0.7364044146842712,0.293317644065467,3581.0,1071.9001495838163,1129.090569972992,1071.9001495838163,56.71219515800476,0.3224613666534424,0.0 -4500,0.32150847,0.22299522,,,,,,,,,,,,,, -4600,0.16803752,0.2476287,,,,,,,,,,,,,, -4700,0.2795761,0.30690014,,,,,,,,,,,,,, -4800,0.41775057,0.215795,,,,,,,,,,,,,, -4819,,,0.7400569915771484,0.2705390623637608,0.7159953463527012,0.2919701021845983,3554.0,0.7332005888281905,0.2934710756422787,3581.0,1151.9339997768402,1213.2164514064789,1151.9339997768402,60.76723957061768,0.3475849628448486,0.0 -4900,0.13810727,0.42012006,,,,,,,,,,,,,, -5000,0.26112232,0.21137008,,,,,,,,,,,,,, -5100,0.14198077,0.2761838,,,,,,,,,,,,,, -5165,,,0.7443958691188267,0.2682174614497593,0.7204964916291503,0.2895858841929867,3554.0,0.7376907037576794,0.291159818639172,3581.0,1231.9920871257782,1297.369089126587,1231.9920871257782,64.8251564502716,0.3724362850189209,0.0 -5200,0.43073103,0.25943547,,,,,,,,,,,,,, -5300,0.14202756,0.26744646,,,,,,,,,,,,,, -5400,0.26653266,0.3002927,,,,,,,,,,,,,, -5500,0.22690786,0.2680599,,,,,,,,,,,,,, -5505,,,0.7416069167000907,0.2700718641281128,0.7170744698974747,0.2917918396854776,3554.0,0.7343092776982686,0.2934232497142034,3581.0,1311.9626994132996,1381.4337227344513,1311.9626994132996,68.88069200515747,0.3992469310760498,0.0 -5600,0.09813904,0.3467555,,,,,,,,,,,,,, -5700,0.10053253,0.26780486,,,,,,,,,,,,,, -5800,0.2762919,0.3508786,,,,,,,,,,,,,, -5851,,,0.744828428540911,0.2684932265962873,0.7217045549468908,0.2897629445365257,3554.0,0.7387483283082589,0.2913406231455948,3581.0,1391.9404091835022,1465.5014822483065,1391.9404091835022,72.9327085018158,0.4251205921173095,0.0 -5900,0.2992611,0.33139184,,,,,,,,,,,,,, -6000,0.42689225,0.20819159,,,,,,,,,,,,,, -6100,0.09028967,0.27390125,,,,,,,,,,,,,, -6197,,,0.741647584097726,0.2683504138674055,0.7175782761281303,0.2897854420195554,3554.0,0.7351471006920204,0.2911732494414968,3581.0,1472.0305316448212,1549.6855239868164,1472.0305316448212,76.98855376243591,0.451430082321167,0.0 -6200,0.2638806,0.3464596,,,,,,,,,,,,,, -6300,0.30793378,0.2575234,,,,,,,,,,,,,, -6400,0.20937416,0.3230451,,,,,,,,,,,,,, -6500,0.2223091,0.2621946,,,,,,,,,,,,,, -6539,,,0.7450931412833077,0.2707888058253697,0.7206491310407288,0.2930879007918014,3554.0,0.7376347307185492,0.2946943013037559,3581.0,1552.044453382492,1633.7936868667605,1552.044453382492,81.04493761062622,0.4772851467132568,0.0 -6600,0.23889941,0.23072544,,,,,,,,,,,,,, -6700,0.09405093,0.24191621,,,,,,,,,,,,,, -6800,0.106289595,0.27428147,,,,,,,,,,,,,, -6885,,,0.7446616036551339,0.2690360375813075,0.7202530379501969,0.2912670816003974,3554.0,0.7375047860016406,0.2927858661054,3581.0,1632.2036936283112,1718.0472900867462,1632.2036936283112,85.1014678478241,0.5031547546386719,0.0 -6900,0.08971363,0.3607558,,,,,,,,,,,,,, -7000,0.2386419,0.22716732,,,,,,,,,,,,,, -7100,0.11530652,0.29042658,,,,,,,,,,,,,, -7200,0.098007806,0.2358234,,,,,,,,,,,,,, -7230,,,0.7442918504987445,0.2680553708757673,0.7200282692037141,0.2895192504264561,3554.0,0.7374284281407079,0.2910377824150028,3581.0,1712.2826426029203,1802.2185401916504,1712.2826426029203,89.15229082107544,0.5326182842254639,0.0 -7300,0.31944966,0.25651857,,,,,,,,,,,,,, -7400,0.14550675,0.32144135,,,,,,,,,,,,,, -7500,0.24796124,0.24379918,,,,,,,,,,,,,, -7573,,,0.7479324340820312,0.2697854382651193,0.7244976088782006,0.291658640847021,3554.0,0.7413721071278973,0.2932264236923171,3581.0,1792.2599248886108,1886.2895441055296,1792.2599248886108,93.20718359947205,0.5594367980957031,0.0 -7600,0.2599789,0.27332515,,,,,,,,,,,,,, -7700,0.1896949,0.22161116,,,,,,,,,,,,,, -7800,0.19671659,0.3049659,,,,,,,,,,,,,, -7900,0.09172844,0.2975106,,,,,,,,,,,,,, -7920,,,0.7448616709027972,0.2679447957447597,0.7212492471071328,0.2893895550128376,3554.0,0.7383865829420903,0.2909352106277053,3581.0,1872.4262573719025,1970.5471255779264,1872.4262573719025,97.26000475883484,0.5858092308044434,0.0 -8000,0.16261995,0.27310452,,,,,,,,,,,,,, -8100,0.3436862,0.1974329,,,,,,,,,,,,,, -8200,0.104963444,0.32178316,,,,,,,,,,,,,, -8266,,,0.7474032129560199,0.2681253126689366,0.7241728894819218,0.2896141520228088,3554.0,0.7410095436417901,0.2911959863581227,3581.0,1952.595893383026,2054.8102276325226,1952.595893383026,101.31509613990784,0.6120538711547852,0.0 -8300,0.121834755,0.21402803,,,,,,,,,,,,,, -8400,0.16450877,0.19566688,,,,,,,,,,,,,, -8500,0.08797789,0.29567242,,,,,,,,,,,,,, -8600,0.20372438,0.29092833,,,,,,,,,,,,,, -8609,,,0.7405197960989816,0.2686831951141357,0.7167463844655669,0.2899430617899726,3554.0,0.7344287913859606,0.2912986944987433,3581.0,2032.6832168102264,2138.9900114536285,2032.6832168102264,105.36998534202576,0.6376612186431885,0.0 -8700,0.05353328,0.2627376,,,,,,,,,,,,,, -8800,0.2872122,0.20800655,,,,,,,,,,,,,, -8900,0.13470824,0.2572341,,,,,,,,,,,,,, -8955,,,0.7465249470302037,0.266757777759007,0.7218790392427546,0.2887311859216727,3554.0,0.739195499031346,0.2902178557949071,3581.0,2112.7515869140625,2223.14924287796,2112.7515869140625,109.42215585708618,0.6642770767211914,0.0 -9000,0.28194484,0.23726821,,,,,,,,,,,,,, -9100,0.21044333,0.3725266,,,,,,,,,,,,,, -9200,0.15742663,0.28452528,,,,,,,,,,,,,, -9299,,,0.7434581347874233,0.2673584733690534,0.7193650227956176,0.2887551259914005,3554.0,0.7367883856421041,0.2902843621282463,3581.0,2192.7990798950195,2307.28660941124,2192.7990798950195,113.47452092170715,0.6899559497833252,0.0 -9300,0.27122864,0.33627236,,,,,,,,,,,,,, -9400,0.16386648,0.2032339,,,,,,,,,,,,,, -9500,0.15139735,0.33251214,,,,,,,,,,,,,, -9600,0.07909122,0.29882824,,,,,,,,,,,,,, -9642,,,0.7420975821358817,0.2680794170924595,0.7183655849922622,0.2894657373294175,3554.0,0.7355956349483385,0.2910233630510856,3581.0,2273.013704776764,2391.6038851737976,2273.013704776764,117.53413605690002,0.7210087776184082,0.0 -9700,0.12133128,0.2719108,,,,,,,,,,,,,, -9800,0.3357869,0.24428676,,,,,,,,,,,,,, -9900,0.135515,0.23410042,,,,,,,,,,,,,, -9987,,,0.7422417231968471,0.2682451180049351,0.7175602094471019,0.289962261931978,3554.0,0.7351992558381039,0.2913739274447605,3581.0,2353.11772108078,2475.801580429077,2353.11772108078,121.5886685848236,0.748103141784668,0.0 -10000,0.18738471,0.26483506,,,,,,,,,,,,,, -10100,0.12816192,0.35158187,,,,,,,,,,,,,, -10200,0.3293252,0.2824388,,,,,,,,,,,,,, -10300,0.14563046,0.27775106,,,,,,,,,,,,,, -10332,,,0.7457905496869769,0.2669530766350882,0.721928430663513,0.2885337232805465,3554.0,0.739086075489563,0.2899997245662873,3581.0,2433.091839790344,2559.874430179596,2433.091839790344,125.6482367515564,0.7754595279693604,0.0 -10400,0.079140574,0.28894043,,,,,,,,,,,,,, -10500,0.18425132,0.28804886,,,,,,,,,,,,,, -10600,0.10468171,0.2808659,,,,,,,,,,,,,, -10677,,,0.741919994354248,0.2680451188768659,0.7178960573693374,0.2894272340034995,3554.0,0.7353981953364982,0.2908333206070231,3581.0,2513.0755269527435,2643.954001903534,2513.0755269527435,129.70497846603394,0.8024251461029053,0.0 -10700,0.29663217,0.2096886,,,,,,,,,,,,,, -10800,0.14272085,0.30912474,,,,,,,,,,,,,, -10900,0.11686711,0.32930702,,,,,,,,,,,,,, -11000,0.13906164,0.26817486,,,,,,,,,,,,,, -11020,,,0.7446566990443638,0.2675068889345441,0.7196864448508723,0.28960034440727,3554.0,0.7369526913964326,0.2911618639390184,3581.0,2593.118797302246,2728.0914771556854,2593.118797302246,133.76000213623047,0.829798698425293,0.0 -11100,0.25276765,0.35527387,,,,,,,,,,,,,, -11200,0.3154794,0.2705821,,,,,,,,,,,,,, -11300,0.21866657,0.31418625,,,,,,,,,,,,,, -11366,,,0.7424424035208566,0.2673390763146536,0.7179938784863886,0.2889510430039216,3554.0,0.7356547441138998,0.2903087011964186,3581.0,2673.3326404094696,2812.394788742065,2673.3326404094696,137.81110620498657,0.8560574054718018,0.0 -11400,0.1255832,0.27666882,,,,,,,,,,,,,, -11500,0.22659446,0.33982208,,,,,,,,,,,,,, -11600,0.28423053,0.21126027,,,,,,,,,,,,,, -11700,0.32924986,0.24139687,,,,,,,,,,,,,, -11711,,,0.7434718949454171,0.2679608038493565,0.7194423042258723,0.2893848494324177,3554.0,0.736668599247766,0.2908966908139311,3581.0,2753.3085412979126,2896.467357158661,2753.3085412979126,141.86714124679563,0.8846144676208496,0.0 -11800,0.09807556,0.22127244,,,,,,,,,,,,,, -11900,0.11787422,0.41212368,,,,,,,,,,,,,, -12000,0.07197814,0.266298,,,,,,,,,,,,,, -12053,,,0.7494621276855469,0.2662361689976283,0.7248269995075971,0.2883218347724395,3554.0,0.7418477075188494,0.2898577466686156,3581.0,2833.4092671871185,2980.6624734401703,2833.4092671871185,145.92257690429688,0.9114329814910888,0.0 -12100,0.21317461,0.21725714,,,,,,,,,,,,,, -12200,0.24135096,0.20989881,,,,,,,,,,,,,, -12300,0.23838226,0.2795399,,,,,,,,,,,,,, -12397,,,0.7417304856436593,0.2680514369692121,0.7187457409345104,0.2894438237505276,3554.0,0.7357152849893536,0.290946391600199,3581.0,2913.51412153244,3064.8627235889435,2913.51412153244,149.9753041267395,0.9421532154083252,0.0 -12400,0.07722856,0.24382533,,,,,,,,,,,,,, -12500,0.14148112,0.3471246,,,,,,,,,,,,,, -12600,0.1393955,0.34491238,,,,,,,,,,,,,, -12700,0.1380783,0.35647652,,,,,,,,,,,,,, -12743,,,0.7417310987200055,0.2683757202965872,0.7193276529306767,0.2895662031887486,3554.0,0.736301331572012,0.2910311011021712,3581.0,2993.7409834861755,3149.1837322711945,2993.7409834861755,154.03021216392517,0.969123363494873,0.0 -12800,0.05242135,0.24463972,,,,,,,,,,,,,, -12900,0.18568723,0.24800414,,,,,,,,,,,,,, -13000,0.24482128,0.26497748,,,,,,,,,,,,,, -13087,,,0.7479927880423409,0.2658177273614066,0.7230232852727561,0.288119220035963,3554.0,0.7402932796355767,0.2895660528221865,3581.0,3073.824416399002,3233.3625707626343,3073.824416399002,158.08467936515808,0.9979074001312256,0.0 -13100,0.10697616,0.27782813,,,,,,,,,,,,,, -13200,0.2231414,0.19944824,,,,,,,,,,,,,, -13300,0.14787218,0.34369826,,,,,,,,,,,,,, -13400,0.09848216,0.31804475,,,,,,,,,,,,,, -13433,,,0.7433491434369769,0.2668966054916382,0.7194958173229108,0.288490754805325,3554.0,0.7369149215259355,0.2898521561823687,3581.0,3153.797593355179,3317.427288770676,3153.797593355179,162.13646292686462,1.0256588459014893,0.0 -13500,0.106905736,0.33637393,,,,,,,,,,,,,, -13600,0.25314602,0.24400514,,,,,,,,,,,,,, -13700,0.19768813,0.29016814,,,,,,,,,,,,,, -13780,,,0.744424547467913,0.267004234450204,0.7197686722926632,0.288745646135956,3554.0,0.7372898249877827,0.2901215221721411,3581.0,3233.8913176059723,3401.6160349845886,3233.8913176059723,166.1916003227234,1.0534124374389648,0.0 -13800,0.18975039,0.2513768,,,,,,,,,,,,,, -13900,0.093507275,0.25420293,,,,,,,,,,,,,, -14000,0.09226702,0.2814428,,,,,,,,,,,,,, -14100,0.24826783,0.29354292,,,,,,,,,,,,,, -14119,,,0.744767325265067,0.2686776093074253,0.7200198197673396,0.2906745562878095,3554.0,0.7371195196872382,0.2922536449970329,3581.0,3313.9427967071533,3485.763567209244,3313.9427967071533,170.2469208240509,1.0821444988250732,0.0 -14200,0.34603336,0.23278958,,,,,,,,,,,,,, -14300,0.20872459,0.30931044,,,,,,,,,,,,,, -14400,0.31541064,0.23458791,,,,,,,,,,,,,, -14465,,,0.743727479662214,0.2669003861291067,0.7195809986327026,0.2886858818298923,3554.0,0.7368043389809061,0.2901606214875384,3581.0,3393.917350292206,3569.834620714188,3393.917350292206,174.30322456359863,1.1104457378387451,0.0 -14500,0.09480324,0.36006987,,,,,,,,,,,,,, -14600,0.16598892,0.3504266,,,,,,,,,,,,,, -14700,0.08109122,0.2423788,,,,,,,,,,,,,, -14800,0.19487716,0.26026312,,,,,,,,,,,,,, -14813,,,0.7420859336853027,0.2672189133507864,0.7172307501231008,0.2889552333748065,3554.0,0.7348131714037629,0.2903403692557072,3581.0,3474.073793888092,3654.090247869492,3474.073793888092,178.3620719909668,1.138521432876587,0.0 -14900,0.28362393,0.23685268,,,,,,,,,,,,,, -15000,0.2414888,0.24800931,,,,,,,,,,,,,, -15100,0.12181049,0.21940416,,,,,,,,,,,,,, -15157,,,0.7469425882611956,0.2663203307560512,0.7227949444068303,0.2880200250268184,3554.0,0.7399728493263055,0.289497194394024,3581.0,3554.166335821152,3738.2798635959625,3554.166335821152,182.4195261001587,1.1659214496612549,0.0 -15200,0.10509395,0.22032762,,,,,,,,,,,,,, -15300,0.15382329,0.20225304,,,,,,,,,,,,,, -15400,0.029774902,0.3553143,,,,,,,,,,,,,, -15500,0.31758833,0.27322507,,,,,,,,,,,,,, -15501,,,0.74376494543893,0.2666252681187221,0.7195468574141812,0.2883164078986705,3554.0,0.7369443056670623,0.2897083375148352,3581.0,3638.211625099182,3826.375822782517,3638.211625099182,186.43106985092163,1.1933035850524902,0.0 -15600,0.12559356,0.26880246,,,,,,,,,,,,,, -15700,0.08364001,0.22433919,,,,,,,,,,,,,, -15800,0.13656823,0.3074873,,,,,,,,,,,,,, -15845,,,0.7452110563005719,0.2667724745614188,0.7208480706158554,0.2887018876727806,3554.0,0.7381134672359327,0.2901179769857407,3581.0,3718.3154895305634,3910.5722908973694,3718.3154895305634,190.48362708091736,1.2215111255645752,0.0 -15900,0.1871746,0.30673435,,,,,,,,,,,,,, -16000,0.17021054,0.20037325,,,,,,,,,,,,,, -16100,0.15358445,0.29680598,,,,,,,,,,,,,, -16188,,,0.7452451160975865,0.2665325743811471,0.7206068838588562,0.2883681349359876,3554.0,0.7381221256719491,0.289755788471272,3581.0,3798.436872243881,3994.7888975143433,3798.436872243881,194.5382580757141,1.2500929832458496,0.0 -16200,0.24687317,0.2872494,,,,,,,,,,,,,, -16300,0.072479315,0.29246762,,,,,,,,,,,,,, -16400,0.16694291,0.3378675,,,,,,,,,,,,,, -16500,0.13032357,0.34423536,,,,,,,,,,,,,, -16531,,,0.7482778685433524,0.2653563703809465,0.7236172188247397,0.2874118201707143,3554.0,0.7409147099055781,0.2888766504206227,3581.0,3878.410897016525,4078.8552227020255,3878.410897016525,198.59056663513184,1.278146743774414,0.0 -16600,0.16344051,0.23613483,,,,,,,,,,,,,, -16700,0.10558458,0.311645,,,,,,,,,,,,,, -16800,0.15176462,0.23822285,,,,,,,,,,,,,, -16876,,,0.7445151465279716,0.2660742487226213,0.7202597700214547,0.2879061465459254,3554.0,0.737625595045902,0.2892800858153274,3581.0,3958.4317643642426,4162.973185300827,3958.4317643642426,202.6435163021088,1.31022047996521,0.0 -16900,0.13273218,0.34980845,,,,,,,,,,,,,, -17000,0.19122799,0.22434533,,,,,,,,,,,,,, -17100,0.16070178,0.25722268,,,,,,,,,,,,,, -17200,0.088679954,0.26295024,,,,,,,,,,,,,, -17219,,,0.7482468741280692,0.2657501527241298,0.7237485629088702,0.2875324135492139,3554.0,0.7409924994764032,0.2890235370379259,3581.0,4038.617102861405,4247.254467487335,4038.617102861405,206.69846272468567,1.3392434120178225,0.0 -17300,0.17809358,0.27987522,,,,,,,,,,,,,, -17400,0.17601138,0.2643921,,,,,,,,,,,,,, -17500,0.10624358,0.2675184,,,,,,,,,,,,,, -17563,,,0.7473959241594587,0.2655086347034999,0.7226291156311551,0.2876453646529439,3554.0,0.7399236257766685,0.2890411607049358,3581.0,4118.709421873093,4331.440586566925,4118.709421873093,210.75095224380493,1.368750810623169,0.0 -17600,0.10278285,0.31901947,,,,,,,,,,,,,, -17700,0.13832818,0.2586004,,,,,,,,,,,,,, -17800,0.22780831,0.3392429,,,,,,,,,,,,,, -17900,0.12497261,0.29671767,,,,,,,,,,,,,, -17910,,,0.7436424664088658,0.2669028724942888,0.7204089747027996,0.2886433255222988,3554.0,0.737226557045867,0.2900685148177883,3581.0,4198.857976913452,4415.681092500687,4198.857976913452,214.8023371696472,1.397280216217041,0.0 -18000,0.12643358,0.28849357,,,,,,,,,,,,,, -18100,0.08195696,0.25573835,,,,,,,,,,,,,, -18200,0.10215136,0.30251676,,,,,,,,,,,,,, -18255,,,0.7477162224905831,0.2655533381870815,0.7232185153392304,0.287289921594726,3554.0,0.7405509874162245,0.2886897441029915,3581.0,4278.897131919861,4499.818293809891,4278.897131919861,218.8605060577393,1.4254579544067385,0.0 -18300,0.093971185,0.21915844,,,,,,,,,,,,,, -18400,0.12820348,0.3055405,,,,,,,,,,,,,, -18500,0.29499027,0.27059403,,,,,,,,,,,,,, -18597,,,0.749964850289481,0.2652980940682547,0.7248652624023987,0.2878781363208708,3554.0,0.7419169750069813,0.2893079700699001,3581.0,4358.909790039063,4583.929002761841,4358.909790039063,222.91399669647217,1.4582040309906006,0.0 -18600,0.091796696,0.29995123,,,,,,,,,,,,,, -18700,0.106588885,0.21217947,,,,,,,,,,,,,, -18800,0.15250744,0.3882336,,,,,,,,,,,,,, -18900,0.2265283,0.22216298,,,,,,,,,,,,,, -18941,,,0.7486487797328404,0.265641553061349,0.724370455164955,0.2878657712920301,3554.0,0.7415229139032393,0.2892563262487783,3581.0,4439.060678720474,4668.17610168457,4439.060678720474,226.9695131778717,1.4869599342346191,0.0 -19000,0.067497484,0.24953412,,,,,,,,,,,,,, -19100,0.10797943,0.24637598,,,,,,,,,,,,,, -19200,0.21246827,0.19517462,,,,,,,,,,,,,, -19286,,,0.7493959154401507,0.2651561668940952,0.7251388730128024,0.2871962393276062,3554.0,0.7422357690763753,0.2885983191998569,3581.0,4519.121737480164,4752.34153175354,4519.121737480164,231.0234580039978,1.5253701210021973,0.0 -19300,0.1398362,0.24597548,,,,,,,,,,,,,, -19400,0.10081337,0.2721154,,,,,,,,,,,,,, -19500,0.10577917,0.25951472,,,,,,,,,,,,,, -19600,0.103030406,0.27812183,,,,,,,,,,,,,, -19628,,,0.7444512503487724,0.2651771136692592,0.7195708318312113,0.287422742612857,3554.0,0.7369073539165037,0.2888892971913397,3581.0,4599.299153327942,4836.61348938942,4599.299153327942,235.07659649848927,1.554959058761597,0.0 -19700,0.26149982,0.22509518,,,,,,,,,,,,,, -19800,0.15752932,0.34915197,,,,,,,,,,,,,, -19900,0.24924389,0.24490036,,,,,,,,,,,,,, -19972,,,0.7471006257193429,0.2648437534059797,0.7225424917346651,0.2867585516539462,3554.0,0.7399095132077282,0.2881554436064472,3581.0,4679.32307100296,4920.735057592392,4679.32307100296,239.13303208351127,1.5843939781188965,0.0 -20000,0.05656625,0.2891072,,,,,,,,,,,,,, -20100,0.2190646,0.35274395,,,,,,,,,,,,,, -20200,0.10002956,0.35751122,,,,,,,,,,,,,, -20300,0.08217113,0.2696283,,,,,,,,,,,,,, -20317,,,0.7499920981270927,0.2653532709394182,0.7255031605013716,0.2872765776677687,3554.0,0.7426582598479824,0.2887418310724134,3581.0,4759.342108011246,5004.854192733765,4759.342108011246,243.1871690750122,1.61844801902771,0.0 -20400,0.056808162,0.26753813,,,,,,,,,,,,,, -20500,0.13376778,0.2188629,,,,,,,,,,,,,, -20600,0.18500704,0.29131126,,,,,,,,,,,,,, -20660,,,0.7488868577139718,0.2645001922334943,0.7233265032577729,0.2871020246773002,3554.0,0.7406960673519967,0.2885014060754677,3581.0,4839.4510152339935,5089.058373451233,4839.4510152339935,247.2418978214264,1.6471836566925049,0.0 -20700,0.10786109,0.30792838,,,,,,,,,,,,,, -20800,0.14072637,0.26324424,,,,,,,,,,,,,, -20900,0.11726256,0.3048966,,,,,,,,,,,,,, -21000,0.18917556,0.32442456,,,,,,,,,,,,,, -21006,,,0.7488184656415667,0.2645917109080723,0.7240962949977139,0.2867401930208479,3554.0,0.7413716980679279,0.2881249686387357,3581.0,4919.556501865387,5173.258333683014,4919.556501865387,251.29483437538147,1.676776647567749,0.0 -21100,0.10066877,0.26474863,,,,,,,,,,,,,, -21200,0.17445572,0.22203039,,,,,,,,,,,,,, -21300,0.16375585,0.24330416,,,,,,,,,,,,,, -21349,,,0.746589115687779,0.2646516731807163,0.721595399220069,0.2866960223900446,3554.0,0.7390779624668389,0.2880124089705215,3581.0,4999.725428342819,5257.524056196213,4999.725428342819,255.34806871414185,1.7086007595062256,0.0 -21400,0.12356733,0.21647422,,,,,,,,,,,,,, -21500,0.17712967,0.20721924,,,,,,,,,,,,,, -21600,0.16386175,0.2332235,,,,,,,,,,,,,, -21692,,,0.7493458475385394,0.2651270457676479,0.7248130545028489,0.2876467900659908,3554.0,0.742009149853393,0.2891149619410604,3581.0,5079.9467849731445,5341.839040994644,5079.9467849731445,259.3984282016754,1.74009370803833,0.0 -21700,0.096778914,0.25072083,,,,,,,,,,,,,, -21800,0.053104598,0.23397589,,,,,,,,,,,,,, -21900,0.179288,0.20599885,,,,,,,,,,,,,, -22000,0.13637601,0.2807248,,,,,,,,,,,,,, -22037,,,0.7473893165588379,0.2643566301890782,0.7223388122318163,0.286699439946627,3554.0,0.7397672285150796,0.2879780479331018,3581.0,5160.113996267319,5426.103360176086,5160.113996267319,263.4543421268463,1.769385814666748,0.0 -22100,0.091377765,0.37554935,,,,,,,,,,,,,, -22200,0.11968252,0.22994193,,,,,,,,,,,,,, -22300,0.06869254,0.2209152,,,,,,,,,,,,,, -22382,,,0.7503771781921387,0.2640033108847482,0.7255582535743177,0.2863131358372608,3554.0,0.7427755918825049,0.2877156700511379,3581.0,5240.150569677353,5510.238574266434,5240.150569677353,267.50645327568054,1.8041374683380127,0.0 -22400,0.13096295,0.22629805,,,,,,,,,,,,,, -22500,0.13040783,0.2511494,,,,,,,,,,,,,, -22600,0.07892277,0.21353965,,,,,,,,,,,,,, -22700,0.27451897,0.16121133,,,,,,,,,,,,,, -22723,,,0.7449972970145089,0.2649799925940377,0.720803899985052,0.2868631220159063,3554.0,0.7379846133456087,0.2882191887849937,3581.0,5320.23034787178,5594.409419775009,5320.23034787178,271.55556893348694,1.8343710899353027,0.0 -22800,0.15508752,0.30736652,,,,,,,,,,,,,, -22900,0.178705,0.23568058,,,,,,,,,,,,,, -23000,0.06525999,0.32364705,,,,,,,,,,,,,, -23069,,,0.7490667615618024,0.2638182810374668,0.723865893293648,0.2863307560033589,3554.0,0.7411706450930257,0.2876683554480242,3581.0,5400.266304969788,5678.5408182144165,5400.266304969788,275.6089406013489,1.864561557769776,0.0 -23100,0.08859985,0.27801648,,,,,,,,,,,,,, -23200,0.07836716,0.30655986,,,,,,,,,,,,,, -23300,0.037660435,0.24863525,,,,,,,,,,,,,, -23400,0.09998226,0.35725492,,,,,,,,,,,,,, -23414,,,0.75071838923863,0.2641075338636126,0.7261588505029544,0.2864420584365767,3554.0,0.7431733345259705,0.2879020650438076,3581.0,5480.29541182518,5762.661432981491,5480.29541182518,279.65731549263,1.895724058151245,0.0 -23500,0.12693065,0.27536136,,,,,,,,,,,,,, -23600,0.13106807,0.28364596,,,,,,,,,,,,,, -23700,0.07756099,0.2529026,,,,,,,,,,,,,, -23757,,,0.7506724766322544,0.2638962609427316,0.725684651646912,0.2863045318380258,3554.0,0.7428464956105139,0.2876682872713627,3581.0,5560.32515335083,5846.788516521454,5560.32515335083,283.71100759506226,1.9275357723236084,0.0 -23800,0.0858587,0.23442121,,,,,,,,,,,,,, -23900,0.16294673,0.23468879,,,,,,,,,,,,,, -24000,0.06915961,0.2289792,,,,,,,,,,,,,, -24100,0.12200589,0.27307233,,,,,,,,,,,,,, -24101,,,0.7491707120622907,0.2632827929088047,0.7234838139024691,0.2859516648274831,3554.0,0.7408622138761868,0.2873740368001257,3581.0,5640.499708890915,5931.000749349594,5640.499708890915,287.70102548599243,1.963218688964844,0.0 -24200,0.13150118,0.27524683,,,,,,,,,,,,,, -24300,0.11221302,0.29660702,,,,,,,,,,,,,, -24400,0.22947983,0.2176639,,,,,,,,,,,,,, -24446,,,0.751819406236921,0.2637123380388532,0.7269841474834693,0.2862745638167381,3554.0,0.7441044231927185,0.2876302787825502,3581.0,5720.588813781738,6015.1900470256805,5720.588813781738,291.7552845478058,1.997541904449463,0.0 -24500,0.12437662,0.36199248,,,,,,,,,,,,,, -24600,0.08168582,0.19356044,,,,,,,,,,,,,, -24700,0.1362882,0.28187686,,,,,,,,,,,,,, -24788,,,0.7498753411429269,0.2634058339255197,0.7250181078977912,0.2857233583088597,3554.0,0.7423019686147375,0.2870329489624058,3581.0,5800.679148674011,6099.376039981842,5800.679148674011,295.8091485500336,2.027472972869873,0.0 -24800,0.09562112,0.26872578,,,,,,,,,,,,,, -24900,0.17476647,0.30422464,,,,,,,,,,,,,, -25000,0.09108461,0.2504093,,,,,,,,,,,,,, -25100,0.11556434,0.25750598,,,,,,,,,,,,,, -25132,,,0.7507461139133998,0.262839674949646,0.7252613554929305,0.2856844428153137,3554.0,0.7424756145716979,0.2870723209844492,3581.0,5880.765575647354,6183.556484937668,5880.765575647354,299.86060905456543,2.058157205581665,0.0 -25200,0.07152415,0.3292465,,,,,,,,,,,,,, -25300,0.19428423,0.2537584,,,,,,,,,,,,,, -25400,0.1096887,0.2796635,,,,,,,,,,,,,, -25477,,,0.7495565414428711,0.2632796083177839,0.7245314066236986,0.2858350900833568,3554.0,0.7417635775185004,0.2872153897087057,3581.0,5960.800450563431,6267.692300796509,5960.800450563431,303.9175374507904,2.090134382247925,0.0 -25500,0.118537664,0.20872179,,,,,,,,,,,,,, -25600,0.1468344,0.2506011,,,,,,,,,,,,,, -25700,0.04999792,0.26269072,,,,,,,,,,,,,, -25800,0.079051636,0.27225077,,,,,,,,,,,,,, -25822,,,0.7506115777151925,0.2631086451666696,0.7256509225960186,0.2855094433099149,3554.0,0.7427894999214605,0.2868910392413956,3581.0,6040.952090978622,6351.949378728867,6040.952090978622,307.97506952285767,2.125797748565674,0.0 -25900,0.096207894,0.19512883,,,,,,,,,,,,,, -26000,0.15196009,0.26011062,,,,,,,,,,,,,, -26100,0.07463877,0.21875237,,,,,,,,,,,,,, -26162,,,0.7509149823869977,0.2625585283551897,0.7250569203494303,0.2856405126156268,3554.0,0.742301900438076,0.2870183250685039,3581.0,6121.030804872513,6436.124837398529,6121.030804872513,312.0299699306488,2.155958414077759,0.0 -26200,0.115750134,0.253106,,,,,,,,,,,,,, -26300,0.13630134,0.28216293,,,,,,,,,,,,,, -26400,0.11253358,0.26133212,,,,,,,,,,,,,, -26500,0.14888792,0.24749915,,,,,,,,,,,,,, -26506,,,0.7500993183680943,0.2629836286817278,0.724892946328081,0.2856195607612021,3554.0,0.7421060288894513,0.2870047238245253,3581.0,6201.037936925888,6520.235012292862,6201.037936925888,316.0838363170624,2.1931822299957275,0.0 -26600,0.047904525,0.31736812,,,,,,,,,,,,,, -26700,0.080437824,0.33510113,,,,,,,,,,,,,, -26800,0.0613183,0.2901887,,,,,,,,,,,,,, -26852,,,0.750182969229562,0.2629303080695016,0.725069285378271,0.2854501426924328,3554.0,0.7423404884285116,0.2867839677944359,3581.0,6281.198880434036,6604.5012130737305,6281.198880434036,320.1396276950836,2.2305245399475098,0.0 -26900,0.091056265,0.20274058,,,,,,,,,,,,,, -27000,0.083470315,0.2435665,,,,,,,,,,,,,, -27100,0.06059736,0.38364726,,,,,,,,,,,,,, -27197,,,0.7514676366533551,0.2624048846108572,0.7256813543058878,0.2854863447490943,3554.0,0.7428371554078819,0.2868568827339605,3581.0,6361.350186109543,6688.750422000885,6361.350186109543,324.193115234375,2.2626664638519287,0.0 -27200,0.06697046,0.20942396,,,,,,,,,,,,,, -27300,0.076049276,0.236196,,,,,,,,,,,,,, -27400,0.10806658,0.217197,,,,,,,,,,,,,, -27500,0.08973122,0.2585779,,,,,,,,,,,,,, -27544,,,0.7516813278198242,0.2626441035951887,0.7261767797947735,0.2857002940953415,3554.0,0.7433970903291678,0.2870722868961184,3581.0,6441.491717815399,6772.989144325256,6441.491717815399,328.2473568916321,2.2936313152313232,0.0 -27600,0.07072892,0.31849372,,,,,,,,,,,,,, -27700,0.049875457,0.4038784,,,,,,,,,,,,,, -27800,0.070085675,0.34231782,,,,,,,,,,,,,, -27890,,,0.7515984943934849,0.2627055134092058,0.72654058642111,0.2854648776851347,3554.0,0.7436194144224728,0.2868547351691217,3581.0,6521.546169281006,6857.138879299164,6521.546169281006,332.3001654148102,2.324267625808716,0.0 -27900,0.109678544,0.20433427,,,,,,,,,,,,,, -28000,0.039229065,0.23859356,,,,,,,,,,,,,, -28100,0.13537608,0.23785977,,,,,,,,,,,,,, -28200,0.07282678,0.28410396,,,,,,,,,,,,,, -28230,,,0.7507645743233817,0.2624414137431553,0.7252138188264983,0.2853406606662387,3554.0,0.7424111876265359,0.2867318126483524,3581.0,6601.5510313510895,6941.242228984833,6601.5510313510895,336.35513162612915,2.356107234954834,0.0 -28300,0.046392594,0.26261133,,,,,,,,,,,,,, -28400,0.07823344,0.27926934,,,,,,,,,,,,,, -28500,0.033111062,0.25803244,,,,,,,,,,,,,, -28575,,,0.7513136182512555,0.2621686969484602,0.7256709814205824,0.2852038553610369,3554.0,0.7428912195004886,0.2865787219548485,3581.0,6681.67812538147,7025.472114086151,6681.67812538147,340.4086203575134,2.393435478210449,0.0 -28600,0.073562846,0.33931798,,,,,,,,,,,,,, -28700,0.06785339,0.25267082,,,,,,,,,,,,,, -28800,0.05730267,0.23508272,,,,,,,,,,,,,, -28900,0.076093495,0.31103158,,,,,,,,,,,,,, -28920,,,0.7508598055158343,0.2622612885066441,0.7252665075882808,0.2850795353002339,3554.0,0.7425641760550474,0.2863938268487329,3581.0,6761.666354894638,7109.558537721634,6761.666354894638,344.46311378479004,2.425102949142456,0.0 -29000,0.050313447,0.24261814,,,,,,,,,,,,,, -29100,0.07517458,0.20343211,,,,,,,,,,,,,, -29200,0.12395911,0.27655193,,,,,,,,,,,,,, -29264,,,0.7527449471609933,0.2630555118833269,0.7273026156707231,0.2864035035897053,3554.0,0.7443117484204831,0.2878441830581541,3581.0,6841.841877698898,7193.83359837532,6841.841877698898,348.5168535709381,2.458962202072144,0.0 -29300,0.03706483,0.2676813,,,,,,,,,,,,,, -29400,0.079543225,0.222135,,,,,,,,,,,,,, -29500,0.051321227,0.1953676,,,,,,,,,,,,,, -29600,0.11039052,0.2837155,,,,,,,,,,,,,, -29609,,,0.7518707684108189,0.261823228427342,0.7259042682980444,0.2850722021511853,3554.0,0.7430781599064508,0.2864239950214675,3581.0,6921.887184858322,7277.980953931808,6921.887184858322,352.57623052597046,2.48972225189209,0.0 -29700,0.04491567,0.3814631,,,,,,,,,,,,,, -29800,0.053442206,0.24849595,,,,,,,,,,,,,, -29900,0.07158513,0.27162737,,,,,,,,,,,,,, -29954,,,0.7520835059029716,0.2619952474321638,0.7263647282331528,0.284983792194974,3554.0,0.7435643958566043,0.2863624655844212,3581.0,7002.001019716263,7362.185358047485,7002.001019716263,356.6236889362335,2.5209431648254395,0.0 -30000,0.05197224,0.24638113,,,,,,,,,,,,,, -30100,0.03218053,0.2664817,,,,,,,,,,,,,, -30200,0.0635251,0.24449816,,,,,,,,,,,,,, -30295,,,0.7514946801321847,0.2618656499045236,0.7258097445220174,0.284842779345236,3554.0,0.7430245048738132,0.2862053524678861,3581.0,7081.958555698395,7446.234877347946,7081.958555698395,360.6707994937897,2.5538601875305176,0.0 -30300,0.039061736,0.24177149,,,,,,,,,,,,,, -30400,0.045133803,0.2243509,,,,,,,,,,,,,, -30500,0.054950763,0.21203865,,,,,,,,,,,,,, -30600,0.07646884,0.24820763,,,,,,,,,,,,,, -30639,,,0.7530231475830078,0.2613992350442068,0.7269642947427195,0.2848137730484137,3554.0,0.7440924241002862,0.2862288734161198,3581.0,7162.066142082214,7530.441558122635,7162.066142082214,364.7251925468445,2.58677077293396,0.0 -30700,0.10729136,0.31329757,,,,,,,,,,,,,, -30800,0.053569928,0.3125068,,,,,,,,,,,,,, -30900,0.022170756,0.19442667,,,,,,,,,,,,,, -30982,,,0.7528040749686105,0.2616708278656006,0.7269228032014983,0.2849631838135727,3554.0,0.7440678805021292,0.2863513187002583,3581.0,7242.0318784713745,7614.505212068558,7242.0318784713745,368.7781274318695,2.6195969581604004,0.0 -31000,0.02413684,0.30950338,,,,,,,,,,,,,, -31100,0.05885296,0.25907266,,,,,,,,,,,,,, -31200,0.036812264,0.25416246,,,,,,,,,,,,,, -31300,0.027739018,0.29782948,,,,,,,,,,,,,, -31328,,,0.7522945404052734,0.2616599457604544,0.7264111657859103,0.2848349481603035,3554.0,0.7435416248516475,0.2861925352555152,3581.0,7322.19585609436,7698.769409894943,7322.19585609436,372.8330626487732,2.653053998947144,0.0 -31400,0.05353403,0.26132542,,,,,,,,,,,,,, -31500,0.07973087,0.20010327,,,,,,,,,,,,,, -31600,0.036664575,0.21879221,,,,,,,,,,,,,, -31673,,,0.7529420171465192,0.2611620766775949,0.7266534516565841,0.2847431721684633,3554.0,0.7437561086288746,0.2861175068394826,3581.0,7402.319948911667,7782.991308689117,7402.319948911667,376.886527299881,2.6854922771453857,0.0 -31700,0.041152693,0.24066679,,,,,,,,,,,,,, -31800,0.062991895,0.21614817,,,,,,,,,,,,,, -31900,0.019447839,0.32254913,,,,,,,,,,,,,, -32000,0.0615038,0.18745087,,,,,,,,,,,,,, -32017,,,0.7525466510227748,0.2614663669041225,0.7265271222785945,0.2848308780049768,3554.0,0.7436601840660779,0.2862083522409941,3581.0,7482.313843011856,7867.077488183975,7482.313843011856,380.9351851940155,2.71712589263916,0.0 -32100,0.016812967,0.3887905,,,,,,,,,,,,,, -32200,0.075901024,0.2250903,,,,,,,,,,,,,, -32300,0.01831712,0.3106378,,,,,,,,,,,,,, -32363,,,0.7524669510977608,0.261459265436445,0.7265055521727279,0.2847815037578696,3554.0,0.7436575933529391,0.286116859161198,3581.0,7562.373685836792,7951.238343477249,7562.373685836792,384.9920229911804,2.7492971420288086,0.0 -32400,0.036860995,0.31834206,,,,,,,,,,,,,, -32500,0.054657754,0.2834528,,,,,,,,,,,,,, -32600,0.016641855,0.24895936,,,,,,,,,,,,,, -32700,0.018781347,0.3327886,,,,,,,,,,,,,, -32704,,,0.7528903824942452,0.2610211712973458,0.7264628241286227,0.2847771588241242,3554.0,0.7436016884904706,0.2861338010615924,3581.0,7642.534946680069,8035.50048160553,7642.534946680069,389.0479950904846,2.7822179794311523,0.0 -32800,0.023665618,0.26034892,,,,,,,,,,,,,, -32900,0.036571655,0.22596054,,,,,,,,,,,,,, -33000,0.027588734,0.26931605,,,,,,,,,,,,,, -33048,,,0.7523625237601144,0.2612138135092599,0.7261041695976365,0.2847444773659521,3554.0,0.7432874622574002,0.2861019284723192,3581.0,7722.624789953232,8119.691402673721,7722.624789953232,393.1053960323334,2.8138113021850586,0.0 -33100,0.023609594,0.2640354,,,,,,,,,,,,,, -33200,0.033540763,0.34788716,,,,,,,,,,,,,, -33300,0.053743374,0.2515204,,,,,,,,,,,,,, -33393,,,0.7526929037911552,0.2612191098076956,0.7265927943206598,0.2846999976094277,3554.0,0.7436916816837127,0.2860446941649504,3581.0,7802.774275302887,8203.940031290054,7802.774275302887,397.1605281829834,2.84609603881836,0.0 -33400,0.02917264,0.2511644,,,,,,,,,,,,,, -33500,0.025394706,0.25194356,,,,,,,,,,,,,, -33600,0.045058824,0.22722647,,,,,,,,,,,,,, -33700,0.033793684,0.2861045,,,,,,,,,,,,,, -33734,,,0.7531563213893345,0.2610527958188738,0.7266605958954699,0.2848787581444323,3554.0,0.7437669487180606,0.2862703930030019,3581.0,7882.909619808197,8288.175689220428,7882.909619808197,401.21487951278687,2.8802406787872314,0.0 -33800,0.01882552,0.2948755,,,,,,,,,,,,,, -33900,0.022676283,0.36903974,,,,,,,,,,,,,, -34000,0.020229498,0.25428143,,,,,,,,,,,,,, -34081,,,0.7531188556126186,0.2610229934964861,0.7267718811550365,0.284685382832284,3554.0,0.7439093697640324,0.286056795522375,3581.0,7963.108328819275,8372.476554393768,7963.108328819275,405.2720079421997,2.913221836090088,0.0 -34100,0.032060787,0.30068025,,,,,,,,,,,,,, -34200,0.018761065,0.33419478,,,,,,,,,,,,,, -34300,0.046508145,0.25028726,,,,,,,,,,,,,, -34400,0.01682359,0.26908153,,,,,,,,,,,,,, -34426,,,0.7529817308698382,0.2610433782849993,0.7267184367526027,0.2846396322255733,3554.0,0.743857350971272,0.2859959137636135,3581.0,8043.085624933243,8456.55000448227,8043.085624933243,409.32266187667847,2.9466326236724854,0.0 -34500,0.02006406,0.26337606,,,,,,,,,,,,,, -34600,0.019923203,0.2590447,,,,,,,,,,,,,, -34700,0.0848638,0.35597408,,,,,,,,,,,,,, -34767,,,0.7531920160566058,0.2610036305018833,0.7267810175374578,0.2846758686295371,3554.0,0.7439245731595574,0.286009378654269,3581.0,8123.120993852615,8540.69055390358,8123.120993852615,413.373916387558,2.9849205017089844,0.0 -34800,0.025178961,0.36476988,,,,,,,,,,,,,, -34900,0.021063598,0.2264029,,,,,,,,,,,,,, -35000,0.020939814,0.2970116,,,,,,,,,,,,,, -35100,0.015770707,0.27077484,,,,,,,,,,,,,, -35114,,,0.7533129964556012,0.2608975853238787,0.7268919593240011,0.2846280571846862,3554.0,0.7440337921713558,0.2859910050439821,3581.0,8203.180258989334,8624.85003042221,8203.180258989334,417.4235715866089,3.0231704711914062,0.0 -35200,0.01767767,0.22105145,,,,,,,,,,,,,, -35300,0.03034605,0.25275043,,,,,,,,,,,,,, -35400,0.030363046,0.2929788,,,,,,,,,,,,,, -35460,,,0.7532985551016671,0.2609228576932634,0.7269186128306134,0.2846089600845878,3554.0,0.7440485865069115,0.2859707224871718,3581.0,8283.291579961777,8709.06115269661,8283.291579961777,421.4790394306183,3.05560564994812,0.0 -35500,0.021810064,0.1925283,,,,,,,,,,,,,, -35600,0.025369085,0.23321062,,,,,,,,,,,,,, -35700,0.015793202,0.2969398,,,,,,,,,,,,,, -35800,0.07547373,0.27197835,,,,,,,,,,,,,, -35806,,,0.7533233506338937,0.2609026942934309,0.7269327639191756,0.284607654887099,3554.0,0.7440701303319603,0.2859652001775866,3581.0,8363.375246286392,8793.248180627823,8363.375246286392,425.5373492240906,3.088459014892578,0.0 -35900,0.028373383,0.2650392,,,,,,,,,,,,,, -36000,0.016285896,0.22829711,,,,,,,,,,,,,, -36100,0.035409413,0.2578897,,,,,,,,,,,,,, -36151,,,0.7532243728637695,0.2608860220227922,0.7268157770074212,0.2845854665297904,3554.0,0.7439553890105767,0.2859424973492914,3581.0,8443.398798942566,8877.369792938232,8443.398798942566,429.58997654914856,3.1220977306365967,0.0 -36189,,,0.7532238279070173,0.2608859198434012,0.7268150213667698,0.2845853634878834,3554.0,0.7439548435972843,0.2859423439518029,3581.0,8450.03044462204,8888.094502449036,8450.03044462204,433.6436011791229,3.156363010406494,0.0 -36189,,,,,,,,,,,8450.03044462204,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 71f4f659a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,83 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -201.3988561630249,0.0,55.26950407028198,1,0,55.26950407028198,0.9828197539967886,3581,0.2181308195598296,256.6686406135559,0.983511243547712,0.199856059891837,0.9851154591437464,3554,0.1934721543672182 -205.707136631012,0.0273573398590087,135.39524340629578,335,0,135.39524340629578,0.3296139101115261,3581,0.6998325444839081,341.14130687713623,0.3055055141448974,0.7025839260646275,0.3277747949053883,3554,0.6819424690929234 -209.72573137283325,0.0587446689605712,215.4636936187744,574,0,215.4636936187744,0.3168418647789549,3581,0.7133247058040701,425.2674949169159,0.2934543064662388,0.7165115901402065,0.3148490463540201,3554,0.6956642163759145 -213.747370004654,0.0868842601776123,295.4432170391083,816,0,295.4432170391083,0.3096011283510193,3581,0.7193678850836009,509.3046169281006,0.286161150251116,0.7226732798985073,0.3074523891433948,3554,0.701933079914357 -217.76836442947388,0.1191539764404296,375.56945061683655,1097,0,375.56945061683655,0.3083697896668179,3581,0.7202405463514032,593.4932446479797,0.2846875020435878,0.7242628506251744,0.3061716812812148,3554,0.7031247252215813 -221.79093170166016,0.1434721946716308,455.658816576004,1446,0,455.658816576004,0.3043452872473645,3581,0.7224767408501466,677.641458272934,0.2813701118741716,0.725240979875837,0.3022148033575724,3554,0.7055216173677546 -225.81270360946647,0.1660394668579101,535.7831084728241,1793,0,535.7831084728241,0.2997427149146537,3581,0.7277350022252862,761.8218240737915,0.2767983845302036,0.7315128190176827,0.2979158949972742,3554,0.7101681204276871 -229.82861018180847,0.191880464553833,615.940948009491,2139,0,615.940948009491,0.298683317770874,3581,0.7281596064734013,846.0330247879028,0.2759296894073486,0.731431416102818,0.2970016385037106,3554,0.7106693162633653 -233.85149574279785,0.2148540019989013,695.9960536956787,2488,0,695.9960536956787,0.2965060279077073,3581,0.7310327073443172,930.1457843780518,0.2735965251922607,0.7350499289376395,0.2947435094589722,3554,0.7137692976883441 -237.86601328849792,0.2370862960815429,776.1169848442078,2836,0,776.1169848442078,0.2955520318826794,3581,0.7336718259128037,1014.315054178238,0.2729426962988717,0.7375479425702777,0.293957299708515,3554,0.7163687015290869 -241.8804790973664,0.2626070976257324,856.2830686569214,3184,0,856.2830686569214,0.2947986115959229,3581,0.7322620007286722,1098.5327405929563,0.2717396872384207,0.7368855476379395,0.2929867823337261,3554,0.7150457808323368 -245.89754986763,0.2908327579498291,936.331609249115,3532,0,936.331609249115,0.296571102531154,3581,0.7314617430754329,1182.6382551193235,0.2730930362428937,0.7362298965454102,0.2945232745563977,3554,0.7147958698605444 -249.9260466098785,0.3137433528900146,1016.4203460216522,3881,0,1016.4203460216522,0.2941487175697256,3581,0.7325347755515219,1266.7900638580322,0.2710892983845302,0.7370754650660923,0.2924375346220808,3554,0.7154304706184933 -253.9429223537445,0.3369905948638916,1096.4697914123535,4225,0,1096.4697914123535,0.2935189356586847,3581,0.7330105122957973,1350.8909401893616,0.2704846348081316,0.7373881340026855,0.2919068001063942,3554,0.7155515105119232 -257.966778755188,0.3606770038604736,1176.554782152176,4571,0,1176.554782152176,0.2930620497613271,3581,0.7342269884677813,1435.0348629951477,0.2703211988721575,0.7382924216134208,0.2915198777455859,3554,0.7167820369653911 -261.98525977134705,0.3870325088500976,1256.7266416549685,4919,0,1256.7266416549685,0.2930749692386903,3581,0.7337398662210276,1519.2632131576538,0.2701424019677298,0.7382422174726214,0.2915389061510797,3554,0.7163420480224747 -266.0015046596527,0.4115355014801025,1336.8073723316193,5265,0,1336.8073723316193,0.2942258935505969,3581,0.7334462293397445,1603.3959367275238,0.2712574005126953,0.7380596569606236,0.2926849725881049,3554,0.7163017929841375 -270.02198791503906,0.4362020492553711,1416.8781580924988,5615,0,1416.8781580924988,0.2927201438036687,3581,0.7347357227162454,1687.523372173309,0.2696823562894548,0.7392725263323102,0.2913162325900394,3554,0.7172675704312043 -274.0418329238892,0.4638032913208008,1497.0349748134613,5962,0,1497.0349748134613,0.2924396309297333,3581,0.7354078764224379,1771.7392427921295,0.2692754779543195,0.7401208196367536,0.29083251953125,3554,0.7180653895698509 -278.0578043460846,0.4883151054382324,1577.034504652023,6308,0,1577.034504652023,0.292567462170134,3581,0.7354130578487155,1855.790392160416,0.2695643220629011,0.7398928233555385,0.2911026610641179,3554,0.7180783728501337 -282.0802209377289,0.5139882564544678,1656.9976651668549,6652,0,1656.9976651668549,0.292407008397183,3581,0.7372841663248744,1939.8130419254303,0.2689920834132603,0.7422077996390206,0.29080590037194,3554,0.7201190834710889 -286.04563570022583,0.54081130027771,1737.2716624736786,7001,0,1737.2716624736786,0.2913742683280683,3581,0.7360620314899818,2024.090726852417,0.2684498514447893,0.7404723848615374,0.2899037341287985,3554,0.7186478511228546 -290.0675663948059,0.5697238445281982,1817.2381384372711,7348,0,1817.2381384372711,0.292688339391057,3581,0.7348790982354789,2108.1195402145386,0.2693685122898647,0.7398408481052944,0.2910204679696292,3554,0.7177683540992543 -294.08826303482056,0.5942308902740479,1897.236734867096,7693,0,1897.236734867096,0.2940985736351578,3581,0.7350832191601508,2192.174660682678,0.2696377038955688,0.7407291276114327,0.2923774611902962,3554,0.7177994727551702 -298.10858392715454,0.6188774108886719,1977.2866151332853,8041,0,1977.2866151332853,0.2935847943137392,3581,0.7356848781983035,2276.280938625336,0.2701274497168405,0.740661757332938,0.2919904014402785,3554,0.7186807558384919 -302.1303942203522,0.6429617404937744,2057.261200904846,8388,0,2057.261200904846,0.2911733176181583,3581,0.7354499414226124,2360.312845468521,0.2680303198950631,0.7401862825666156,0.2897438474364273,3554,0.7179852229662 -306.15238642692566,0.6702282428741455,2137.390530109405,8734,0,2137.390530109405,0.2910374415316951,3581,0.7359388362625663,2444.502839803696,0.2678526129041399,0.7407048089163644,0.2895385536103686,3554,0.7185499613111986 -310.17205476760864,0.694101095199585,2217.4832949638367,9083,0,2217.4832949638367,0.2911403542023003,3581,0.7380135884267662,2528.650801420212,0.2678829772131784,0.7429026194981166,0.2896388133858856,3554,0.7207812994601154 -314.19159865379333,0.7229745388031006,2297.473820209503,9429,0,2297.473820209503,0.291500292886938,3581,0.7354532139023666,2612.7011408805847,0.2684066636221749,0.7403473854064941,0.2899045584640546,3554,0.7182699621025604 -318.208441734314,0.7502224445343018,2377.535681962967,9775,0,2377.535681962967,0.2908467854976787,3581,0.7358350713836918,2696.8188314437866,0.2679011821746826,0.7402353286743164,0.2893958062218627,3554,0.7183779500211029 -322.22816157341003,0.7746529579162598,2457.5765509605408,10124,0,2457.5765509605408,0.2905376043375628,3581,0.7375286478331821,2780.9151532649994,0.2672719274248396,0.7422716276986259,0.288979654306767,3554,0.7202255601083286 -326.24861216545105,0.7995173931121826,2537.5425441265106,10472,0,2537.5425441265106,0.2906910700027053,3581,0.7372738716489807,2864.9380617141724,0.2674891608101981,0.7421283040727887,0.2892322787154439,3554,0.7200168658993388 -330.2671887874603,0.8240396976470947,2617.5111916065216,10817,0,2617.5111916065216,0.2917194808538118,3581,0.7364756592955878,2948.9611835479736,0.2683074474334717,0.7415690422058105,0.2901125313796954,3554,0.7192934430175506 -334.2903254032135,0.8486440181732178,2697.6871478557587,11166,0,2697.6871478557587,0.2924594021615819,3581,0.7356054523876012,3033.1966512203217,0.2690391199929373,0.7408914566040039,0.2906372551174733,3554,0.7186734055157921 -338.3120629787445,0.8769309520721436,2777.7934505939484,11513,0,2777.7934505939484,0.2905638523522584,3581,0.7360119216437447,3117.364461660385,0.2675546067101614,0.7404755183628627,0.2891333928320203,3554,0.718561982867016 -342.33744502067566,0.901536226272583,2857.8316600322723,11861,0,2857.8316600322723,0.2904494860025132,3581,0.7377858102005376,3201.4641349315643,0.2674191849572317,0.7427958760942731,0.288963579769274,3554,0.7204063643078221 -346.3597838878632,0.927215337753296,2937.796384334564,12208,0,2937.796384334564,0.2900276769975216,3581,0.7372341928319603,3285.4885654449463,0.266594444002424,0.7424193790980748,0.2885294985623593,3554,0.7199619102156022 -350.3804326057434,0.953420639038086,3017.7934601306915,12556,0,3017.7934601306915,0.2915937290015882,3581,0.7345436008840058,3369.5441093444824,0.2682104791913713,0.7396091052464077,0.2900364864523248,3554,0.7172720355805079 -354.3416240215301,0.9813022613525392,3097.760371685028,12901,0,3097.760371685028,0.2904608033283301,3581,0.7364720459325258,3453.51145529747,0.2670316185270037,0.7418176106044224,0.288919134360052,3554,0.7191108527583356 -358.3672223091125,1.009075403213501,3177.790649652481,13248,0,3177.790649652481,0.2904348961969422,3581,0.7377360412376082,3537.606767892837,0.2669127668653215,0.7430195127214704,0.2889232216890299,3554,0.7204350099579699 -362.3905634880066,1.035358190536499,3257.7921764850616,13594,0,3257.7921764850616,0.2906754916355417,3581,0.7375735762531416,3621.669471979141,0.2671223027365548,0.7428833416530064,0.2891223330006682,3554,0.7201877780757597 -366.4184386730194,1.0607011318206787,3337.8404335975647,13941,0,3337.8404335975647,0.290944721271991,3581,0.737586461642174,3705.7823617458334,0.2675494125911167,0.742699214390346,0.2894603448029509,3554,0.7202629299732696 -370.4419300556183,1.0870215892791748,3417.8745822906494,14285,0,3417.8745822906494,0.2906121555169645,3581,0.737531511252967,3789.877601146698,0.266897337777274,0.7431567737034389,0.2891662288530529,3554,0.7201880528541784 -374.460782289505,1.1130621433258057,3497.9089074134827,14634,0,3497.9089074134827,0.2904956075140498,3581,0.7367213679838034,3873.968332767487,0.2673311744417463,0.7417561667306083,0.2890202184708251,3554,0.7192847188027575 -378.48295879364014,1.1384129524230957,3578.090413331985,14982,0,3578.090413331985,0.2899907593352939,3581,0.7384148080799707,3958.208960294724,0.2665043217795236,0.7435366085597447,0.28846259001741,3554,0.7212925247080754 -382.5058283805847,1.1644511222839355,3658.228110074997,15328,0,3658.228110074997,0.2898182382932491,3581,0.7373226861386484,4042.40718126297,0.2664320639201573,0.7421719006129673,0.2883614028647299,3554,0.7199183578362408 -386.5298075675965,1.1899957656860352,3738.270714044571,15676,0,3738.270714044571,0.2904098753621544,3581,0.7366578955119031,4126.510515928268,0.2670072657721383,0.7417474474225726,0.2889765287022545,3554,0.7191569468380697 -390.55300784111023,1.218813180923462,3818.447353839874,16023,0,3818.447353839874,0.2894058035792202,3581,0.7378498962623918,4210.751033306122,0.2661642347063337,0.7427665165492466,0.2879930108735316,3554,0.7204003191826112 -394.5731847286224,1.2473206520080566,3898.562789916992,16368,0,3898.562789916992,0.2905316729680082,3581,0.7390052179689681,4294.926712274551,0.2670840535845075,0.7440041133335659,0.2891163909173642,3554,0.721756488068022 -398.5912523269653,1.2744569778442385,3978.636076450348,16717,0,3978.636076450348,0.2903541750296705,3581,0.7394075284487573,4379.056773900986,0.2667323521205357,0.7443607875279018,0.2889188595816334,3554,0.7220803144344401 -402.61247849464417,1.3024535179138184,4058.681727647781,17064,0,4058.681727647781,0.2895092275747871,3581,0.7393874163336009,4463.163305521011,0.2659415687833513,0.7447352409362793,0.2880070245728844,3554,0.7221488716499015 -406.63641810417175,1.332447528839111,4138.6701855659485,17408,0,4138.6701855659485,0.2894624583849658,3581,0.7389856512671041,4547.21695804596,0.2660709619522095,0.7440425327845982,0.2880502506528735,3554,0.7215390009496342 -410.6612322330475,1.359151840209961,4218.735862731934,17753,0,4218.735862731934,0.2899840098458007,3581,0.7387822121090477,4631.345724105835,0.2663700580596924,0.7439319065638951,0.2884311622357731,3554,0.7215269106992122 -414.6260409355164,1.3860557079315186,4298.713285446167,18101,0,4298.713285446167,0.2894187230565833,3581,0.7404509040770735,4715.32683968544,0.2659097569329398,0.745412962777274,0.2879792032579927,3554,0.7232113711003447 -418.6499736309052,1.4169938564300537,4378.784018278122,18446,0,4378.784018278122,0.2898725410041713,3581,0.7406133008848785,4799.463942289352,0.2661910397665841,0.7459585326058524,0.2883196021977877,3554,0.7234733036279544 -422.6686074733734,1.4473583698272705,4458.971214294434,18794,0,4458.971214294434,0.2896872368380864,3581,0.7384542141903448,4883.711813926697,0.2661190032958984,0.7438090188162667,0.2880788276084166,3554,0.7211505329602209 -426.6920883655548,1.4761953353881836,4539.066566228867,19144,0,4539.066566228867,0.2891817750693766,3581,0.738919588082065,4967.871089935303,0.2657319477626255,0.7441033635820661,0.2877311813878288,3554,0.7216333873364519 -430.71262383461,1.503026008605957,4619.122064113617,19492,0,4619.122064113617,0.2901595306609536,3581,0.7386280646772898,5051.985657215118,0.2665566887174334,0.7439759799412319,0.2886482715338351,3554,0.7214370581563028 -434.7377324104309,1.530184984207153,4699.332726478577,19839,0,4699.332726478577,0.2890906569612189,3581,0.7386809015899888,5136.260441064835,0.2655234677450998,0.7438088144574847,0.2876793684822559,3554,0.7212626425550436 -438.7622768878937,1.5600879192352295,4779.307057142258,20187,0,4779.307057142258,0.2889807220944743,3581,0.739436162646607,5220.3008596897125,0.2655076810291835,0.744532653263637,0.2875561475351276,3554,0.7220739258362057 -442.7827990055084,1.588029146194458,4859.422160625458,20533,0,4859.422160625458,0.2897922970735304,3581,0.7386787199368193,5304.475937128067,0.2663663114820208,0.7438911029270717,0.2883558729490539,3554,0.7213166365143149 -446.8024094104767,1.6191449165344238,4939.427295207977,20880,0,4939.427295207977,0.2902587958801662,3581,0.7385392304872941,5388.543436765671,0.2665532827377319,0.7441916465759277,0.2887949688621096,3554,0.7212896395346793 -450.8265299797058,1.6496500968933103,5019.592027187347,21230,0,5019.592027187347,0.2887227075188495,3581,0.7394924765690449,5472.774369478226,0.2651681729725429,0.744783878326416,0.2872741905302564,3554,0.722162335792417 -454.8499348163605,1.676215887069702,5099.608054637909,21576,0,5099.608054637909,0.2892016826545483,3581,0.7398384731263963,5556.852010011673,0.2656759704862322,0.7450062888009208,0.2878276457864378,3554,0.7225232572453574 -458.8681221008301,1.7044212818145752,5179.616717100143,21923,0,5179.616717100143,0.2890725901459089,3581,0.7393251710416084,5640.918618440628,0.2651782206126621,0.7450105122157505,0.2876252543074264,3554,0.7219767916652012 -462.88522577285767,1.7367875576019287,5259.582234382629,22270,0,5259.582234382629,0.2891758096114912,3581,0.7391410940554315,5724.94517993927,0.2653919628688267,0.7445428030831474,0.2877102295334042,3554,0.7217431613147158 -466.90236949920654,1.7640461921691897,5339.7165422439575,22619,0,5339.7165422439575,0.2887698175919785,3581,0.740124951458217,5809.135733127594,0.265277692249843,0.7453750882829938,0.2873608316003974,3554,0.7228236587515827 -470.92625880241394,1.79207444190979,5419.942188262939,22966,0,5419.942188262939,0.2904986754638194,3581,0.7396289662454621,5893.424808740616,0.266116806438991,0.7456490652901786,0.2888991098827905,3554,0.7225788998751407 -474.9463920593262,1.8207497596740725,5499.9059715271,23314,0,5499.9059715271,0.2886775404805745,3581,0.7400252090023737,5977.449015617371,0.2647743225097656,0.745732034955706,0.2871776230897404,3554,0.7227975548018079 -478.9675254821777,1.84963059425354,5579.946591615677,23664,0,5579.946591615677,0.2885994782031031,3581,0.7392118614301173,6061.551117658615,0.264919672693525,0.7447335379464286,0.2872225665348463,3554,0.7217938579329629 -482.9878783226013,1.8781225681304927,5660.062964439392,24009,0,5660.062964439392,0.289814488576864,3581,0.7381948701698199,6145.727813243866,0.2658603702272687,0.7441058840070452,0.2883601663618458,3554,0.7207657744794598 -487.0089511871338,1.9069738388061523,5740.089933633804,24358,0,5740.089933633804,0.2896226394512706,3581,0.7391160732206437,6229.816264629364,0.2656828846250261,0.7448805400303432,0.2882363443369267,3554,0.7218370668393008 -491.0269269943237,1.9357292652130127,5820.203050613403,24707,0,5820.203050613403,0.2886322711773073,3581,0.7399242393666224,6313.987454175949,0.2650761604309082,0.7452448436192104,0.2873384028119724,3554,0.7225511472548537 -495.0495693683624,1.9641578197479248,5900.204864025116,25050,0,5900.204864025116,0.2885487206785814,3581,0.7402670997975426,6398.051856517792,0.2647052322115217,0.7459461348397392,0.2871748581319024,3554,0.7229634522720878 -499.0717160701752,1.9927153587341309,5980.308407306671,25399,0,5980.308407306671,0.2885024287253909,3581,0.7401602669688984,6482.217605352402,0.2647386278424944,0.7456408909388951,0.2871405795241717,3554,0.7228097137468346 -503.0915832519531,2.0212597846984863,6060.476011991501,25746,0,6060.476011991501,0.2882216090564786,3581,0.7408707359588802,6566.445360660553,0.2646439586366926,0.7461174556187221,0.2868086300207512,3554,0.7234955606798678 -507.1141304969788,2.0498805046081543,6140.551958799362,26090,0,6140.551958799362,0.2886733135275586,3581,0.7393630090887671,6650.584136009216,0.2648283072880336,0.7450517245701381,0.2872904024569587,3554,0.721940314830121 -511.1384792327881,2.083644151687622,6220.706364631653,26437,0,6220.706364631653,0.2884955087942439,3581,0.7400255498856814,6734.808057308197,0.2645855120250157,0.745732239314488,0.2870743922725714,3554,0.7227190368686691 -515.1594040393829,2.1129167079925537,6300.665889263153,26785,0,6300.665889263153,0.2885674692605068,3581,0.7394161187081123,6818.829693555832,0.2647560664585658,0.7451646668570382,0.287235549815129,3554,0.7219787151141319 -519.1855278015137,2.141589641571045,6380.640214920044,27130,0,6380.640214920044,0.2884770670072955,3581,0.7396342840250628,6902.870454549789,0.264656594821385,0.7452775410243443,0.2870864996966446,3554,0.7223087926895752 -523.2129678726196,2.17048716545105,6460.715427160263,27477,0,6460.715427160263,0.2880691660412594,3581,0.7399376701689472,6987.013831138611,0.2641756364277431,0.745558534349714,0.2866485715918683,3554,0.7225227763831247 -527.2320716381073,2.2036807537078857,6540.796165466309,27825,0,6540.796165466309,0.28817057882531066,3581,0.741383288100391,7071.158291816711,0.26425789083753315,0.7469542367117745,0.2867645280845526,3554,0.7240524678399338 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/measurements.csv deleted file mode 100644 index def5fc036..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/measurements.csv +++ /dev/null @@ -1,363 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.9456766,0.987718,,,,,,,,,,,,,, -1,,,0.199856059891837,0.983511243547712,0.1934721543672182,0.9851154591437464,3554.0,0.2181308195598296,0.9828197539967886,3581.0,55.26950407028198,256.6686406135559,55.26950407028198,201.3988561630249,0.0,0.0 -100,0.22717272,0.3913332,,,,,,,,,,,,,, -200,0.22004968,0.34265584,,,,,,,,,,,,,, -300,0.17814352,0.32464212,,,,,,,,,,,,,, -335,,,0.7025839260646275,0.3055055141448974,0.6819424690929234,0.3277747949053883,3554.0,0.6998325444839081,0.3296139101115261,3581.0,135.39524340629578,341.14130687713623,135.39524340629578,205.707136631012,0.0273573398590087,0.0 -400,0.40439382,0.37097666,,,,,,,,,,,,,, -500,0.30768043,0.27869257,,,,,,,,,,,,,, -574,,,0.7165115901402065,0.2934543064662388,0.6956642163759145,0.3148490463540201,3554.0,0.7133247058040701,0.3168418647789549,3581.0,215.4636936187744,425.2674949169159,215.4636936187744,209.72573137283325,0.0587446689605712,0.0 -600,0.16263413,0.31414968,,,,,,,,,,,,,, -700,0.14998862,0.28069142,,,,,,,,,,,,,, -800,0.31837636,0.23769799,,,,,,,,,,,,,, -816,,,0.7226732798985073,0.286161150251116,0.701933079914357,0.3074523891433948,3554.0,0.7193678850836009,0.3096011283510193,3581.0,295.4432170391083,509.3046169281006,295.4432170391083,213.747370004654,0.0868842601776123,0.0 -900,0.26337853,0.3409694,,,,,,,,,,,,,, -1000,0.08239091,0.30549258,,,,,,,,,,,,,, -1097,,,0.7242628506251744,0.2846875020435878,0.7031247252215813,0.3061716812812148,3554.0,0.7202405463514032,0.3083697896668179,3581.0,375.56945061683655,593.4932446479797,375.56945061683655,217.76836442947388,0.1191539764404296,0.0 -1100,0.29562002,0.2994892,,,,,,,,,,,,,, -1200,0.14239812,0.25276095,,,,,,,,,,,,,, -1300,0.16303328,0.27413622,,,,,,,,,,,,,, -1400,0.45291013,0.26263204,,,,,,,,,,,,,, -1446,,,0.725240979875837,0.2813701118741716,0.7055216173677546,0.3022148033575724,3554.0,0.7224767408501466,0.3043452872473645,3581.0,455.658816576004,677.641458272934,455.658816576004,221.79093170166016,0.1434721946716308,0.0 -1500,0.33241606,0.28269655,,,,,,,,,,,,,, -1600,0.4652627,0.23563159,,,,,,,,,,,,,, -1700,0.0711194,0.38822716,,,,,,,,,,,,,, -1793,,,0.7315128190176827,0.2767983845302036,0.7101681204276871,0.2979158949972742,3554.0,0.7277350022252862,0.2997427149146537,3581.0,535.7831084728241,761.8218240737915,535.7831084728241,225.81270360946647,0.1660394668579101,0.0 -1800,0.175603,0.22186816,,,,,,,,,,,,,, -1900,0.10075879,0.29518133,,,,,,,,,,,,,, -2000,0.08038706,0.26200938,,,,,,,,,,,,,, -2100,0.07138075,0.44388863,,,,,,,,,,,,,, -2139,,,0.731431416102818,0.2759296894073486,0.7106693162633653,0.2970016385037106,3554.0,0.7281596064734013,0.298683317770874,3581.0,615.940948009491,846.0330247879028,615.940948009491,229.82861018180847,0.191880464553833,0.0 -2200,0.21755832,0.22179966,,,,,,,,,,,,,, -2300,0.12296039,0.2547567,,,,,,,,,,,,,, -2400,0.2694771,0.29670337,,,,,,,,,,,,,, -2488,,,0.7350499289376395,0.2735965251922607,0.7137692976883441,0.2947435094589722,3554.0,0.7310327073443172,0.2965060279077073,3581.0,695.9960536956787,930.1457843780518,695.9960536956787,233.85149574279785,0.2148540019989013,0.0 -2500,0.24883783,0.2874599,,,,,,,,,,,,,, -2600,0.26023638,0.25299618,,,,,,,,,,,,,, -2700,0.1581653,0.2575066,,,,,,,,,,,,,, -2800,0.20109771,0.24880788,,,,,,,,,,,,,, -2836,,,0.7375479425702777,0.2729426962988717,0.7163687015290869,0.293957299708515,3554.0,0.7336718259128037,0.2955520318826794,3581.0,776.1169848442078,1014.315054178238,776.1169848442078,237.86601328849792,0.2370862960815429,0.0 -2900,0.20647803,0.24985749,,,,,,,,,,,,,, -3000,0.12997892,0.29382002,,,,,,,,,,,,,, -3100,0.17771323,0.29267472,,,,,,,,,,,,,, -3184,,,0.7368855476379395,0.2717396872384207,0.7150457808323368,0.2929867823337261,3554.0,0.7322620007286722,0.2947986115959229,3581.0,856.2830686569214,1098.5327405929563,856.2830686569214,241.8804790973664,0.2626070976257324,0.0 -3200,0.14460754,0.2807281,,,,,,,,,,,,,, -3300,0.099664874,0.30272913,,,,,,,,,,,,,, -3400,0.2747766,0.2912874,,,,,,,,,,,,,, -3500,0.13803427,0.28261352,,,,,,,,,,,,,, -3532,,,0.7362298965454102,0.2730930362428937,0.7147958698605444,0.2945232745563977,3554.0,0.7314617430754329,0.296571102531154,3581.0,936.331609249115,1182.6382551193235,936.331609249115,245.89754986763,0.2908327579498291,0.0 -3600,0.19823708,0.27775133,,,,,,,,,,,,,, -3700,0.05645802,0.2877745,,,,,,,,,,,,,, -3800,0.09199055,0.27645415,,,,,,,,,,,,,, -3881,,,0.7370754650660923,0.2710892983845302,0.7154304706184933,0.2924375346220808,3554.0,0.7325347755515219,0.2941487175697256,3581.0,1016.4203460216522,1266.7900638580322,1016.4203460216522,249.9260466098785,0.3137433528900146,0.0 -3900,0.16212294,0.3851105,,,,,,,,,,,,,, -4000,0.07803555,0.24746317,,,,,,,,,,,,,, -4100,0.09047414,0.32447413,,,,,,,,,,,,,, -4200,0.21484597,0.2532031,,,,,,,,,,,,,, -4225,,,0.7373881340026855,0.2704846348081316,0.7155515105119232,0.2919068001063942,3554.0,0.7330105122957973,0.2935189356586847,3581.0,1096.4697914123535,1350.8909401893616,1096.4697914123535,253.9429223537445,0.3369905948638916,0.0 -4300,0.122298636,0.35544086,,,,,,,,,,,,,, -4400,0.16852379,0.34310877,,,,,,,,,,,,,, -4500,0.24999145,0.2762935,,,,,,,,,,,,,, -4571,,,0.7382924216134208,0.2703211988721575,0.7167820369653911,0.2915198777455859,3554.0,0.7342269884677813,0.2930620497613271,3581.0,1176.554782152176,1435.0348629951477,1176.554782152176,257.966778755188,0.3606770038604736,0.0 -4600,0.12836322,0.26730186,,,,,,,,,,,,,, -4700,0.35815495,0.3192168,,,,,,,,,,,,,, -4800,0.056599155,0.26083338,,,,,,,,,,,,,, -4900,0.22468929,0.3322918,,,,,,,,,,,,,, -4919,,,0.7382422174726214,0.2701424019677298,0.7163420480224747,0.2915389061510797,3554.0,0.7337398662210276,0.2930749692386903,3581.0,1256.7266416549685,1519.2632131576538,1256.7266416549685,261.98525977134705,0.3870325088500976,0.0 -5000,0.22780189,0.32681322,,,,,,,,,,,,,, -5100,0.24458438,0.2817754,,,,,,,,,,,,,, -5200,0.11271637,0.27005088,,,,,,,,,,,,,, -5265,,,0.7380596569606236,0.2712574005126953,0.7163017929841375,0.2926849725881049,3554.0,0.7334462293397445,0.2942258935505969,3581.0,1336.8073723316193,1603.3959367275238,1336.8073723316193,266.0015046596527,0.4115355014801025,0.0 -5300,0.2383943,0.22887726,,,,,,,,,,,,,, -5400,0.11487937,0.31154275,,,,,,,,,,,,,, -5500,0.20822228,0.26969534,,,,,,,,,,,,,, -5600,0.12772839,0.22048046,,,,,,,,,,,,,, -5615,,,0.7392725263323102,0.2696823562894548,0.7172675704312043,0.2913162325900394,3554.0,0.7347357227162454,0.2927201438036687,3581.0,1416.8781580924988,1687.523372173309,1416.8781580924988,270.02198791503906,0.4362020492553711,0.0 -5700,0.12544194,0.28740355,,,,,,,,,,,,,, -5800,0.11721,0.3477234,,,,,,,,,,,,,, -5900,0.19702584,0.24426055,,,,,,,,,,,,,, -5962,,,0.7401208196367536,0.2692754779543195,0.7180653895698509,0.29083251953125,3554.0,0.7354078764224379,0.2924396309297333,3581.0,1497.0349748134613,1771.7392427921295,1497.0349748134613,274.0418329238892,0.4638032913208008,0.0 -6000,0.121520884,0.19333887,,,,,,,,,,,,,, -6100,0.16344213,0.24549404,,,,,,,,,,,,,, -6200,0.12333001,0.23434854,,,,,,,,,,,,,, -6300,0.12924062,0.29378176,,,,,,,,,,,,,, -6308,,,0.7398928233555385,0.2695643220629011,0.7180783728501337,0.2911026610641179,3554.0,0.7354130578487155,0.292567462170134,3581.0,1577.034504652023,1855.790392160416,1577.034504652023,278.0578043460846,0.4883151054382324,0.0 -6400,0.1797835,0.25749624,,,,,,,,,,,,,, -6500,0.17348006,0.27118766,,,,,,,,,,,,,, -6600,0.30009407,0.28104147,,,,,,,,,,,,,, -6652,,,0.7422077996390206,0.2689920834132603,0.7201190834710889,0.29080590037194,3554.0,0.7372841663248744,0.292407008397183,3581.0,1656.9976651668549,1939.8130419254303,1656.9976651668549,282.0802209377289,0.5139882564544678,0.0 -6700,0.16396028,0.23406449,,,,,,,,,,,,,, -6800,0.23039828,0.2433434,,,,,,,,,,,,,, -6900,0.065271035,0.33741817,,,,,,,,,,,,,, -7000,0.19259286,0.22580093,,,,,,,,,,,,,, -7001,,,0.7404723848615374,0.2684498514447893,0.7186478511228546,0.2899037341287985,3554.0,0.7360620314899818,0.2913742683280683,3581.0,1737.2716624736786,2024.090726852417,1737.2716624736786,286.04563570022583,0.54081130027771,0.0 -7100,0.2537214,0.34783012,,,,,,,,,,,,,, -7200,0.1783617,0.27520284,,,,,,,,,,,,,, -7300,0.31335303,0.30853498,,,,,,,,,,,,,, -7348,,,0.7398408481052944,0.2693685122898647,0.7177683540992543,0.2910204679696292,3554.0,0.7348790982354789,0.292688339391057,3581.0,1817.2381384372711,2108.1195402145386,1817.2381384372711,290.0675663948059,0.5697238445281982,0.0 -7400,0.08668356,0.29277912,,,,,,,,,,,,,, -7500,0.055765584,0.25366613,,,,,,,,,,,,,, -7600,0.09678299,0.33582282,,,,,,,,,,,,,, -7693,,,0.7407291276114327,0.2696377038955688,0.7177994727551702,0.2923774611902962,3554.0,0.7350832191601508,0.2940985736351578,3581.0,1897.236734867096,2192.174660682678,1897.236734867096,294.08826303482056,0.5942308902740479,0.0 -7700,0.2842033,0.21745345,,,,,,,,,,,,,, -7800,0.07450785,0.26515236,,,,,,,,,,,,,, -7900,0.18451226,0.33516935,,,,,,,,,,,,,, -8000,0.08798091,0.2810524,,,,,,,,,,,,,, -8041,,,0.740661757332938,0.2701274497168405,0.7186807558384919,0.2919904014402785,3554.0,0.7356848781983035,0.2935847943137392,3581.0,1977.2866151332853,2276.280938625336,1977.2866151332853,298.10858392715454,0.6188774108886719,0.0 -8100,0.088909246,0.25425246,,,,,,,,,,,,,, -8200,0.16100088,0.26719275,,,,,,,,,,,,,, -8300,0.08167838,0.2327693,,,,,,,,,,,,,, -8388,,,0.7401862825666156,0.2680303198950631,0.7179852229662,0.2897438474364273,3554.0,0.7354499414226124,0.2911733176181583,3581.0,2057.261200904846,2360.312845468521,2057.261200904846,302.1303942203522,0.6429617404937744,0.0 -8400,0.24071093,0.26703125,,,,,,,,,,,,,, -8500,0.070702545,0.35522163,,,,,,,,,,,,,, -8600,0.15281993,0.22444797,,,,,,,,,,,,,, -8700,0.18870986,0.26776025,,,,,,,,,,,,,, -8734,,,0.7407048089163644,0.2678526129041399,0.7185499613111986,0.2895385536103686,3554.0,0.7359388362625663,0.2910374415316951,3581.0,2137.390530109405,2444.502839803696,2137.390530109405,306.15238642692566,0.6702282428741455,0.0 -8800,0.2596324,0.27193254,,,,,,,,,,,,,, -8900,0.21851797,0.31903327,,,,,,,,,,,,,, -9000,0.11732849,0.23967497,,,,,,,,,,,,,, -9083,,,0.7429026194981166,0.2678829772131784,0.7207812994601154,0.2896388133858856,3554.0,0.7380135884267662,0.2911403542023003,3581.0,2217.4832949638367,2528.650801420212,2217.4832949638367,310.17205476760864,0.694101095199585,0.0 -9100,0.0654296,0.25766817,,,,,,,,,,,,,, -9200,0.23696786,0.24994656,,,,,,,,,,,,,, -9300,0.10664628,0.35826644,,,,,,,,,,,,,, -9400,0.1268815,0.2530511,,,,,,,,,,,,,, -9429,,,0.7403473854064941,0.2684066636221749,0.7182699621025604,0.2899045584640546,3554.0,0.7354532139023666,0.291500292886938,3581.0,2297.473820209503,2612.7011408805847,2297.473820209503,314.19159865379333,0.7229745388031006,0.0 -9500,0.3157319,0.27753073,,,,,,,,,,,,,, -9600,0.2432146,0.22609343,,,,,,,,,,,,,, -9700,0.28684106,0.28321657,,,,,,,,,,,,,, -9775,,,0.7402353286743164,0.2679011821746826,0.7183779500211029,0.2893958062218627,3554.0,0.7358350713836918,0.2908467854976787,3581.0,2377.535681962967,2696.8188314437866,2377.535681962967,318.208441734314,0.7502224445343018,0.0 -9800,0.50913817,0.33386484,,,,,,,,,,,,,, -9900,0.16138785,0.2570335,,,,,,,,,,,,,, -10000,0.076434806,0.30963546,,,,,,,,,,,,,, -10100,0.18814322,0.24172848,,,,,,,,,,,,,, -10124,,,0.7422716276986259,0.2672719274248396,0.7202255601083286,0.288979654306767,3554.0,0.7375286478331821,0.2905376043375628,3581.0,2457.5765509605408,2780.9151532649994,2457.5765509605408,322.22816157341003,0.7746529579162598,0.0 -10200,0.122994736,0.41297138,,,,,,,,,,,,,, -10300,0.07813913,0.33758116,,,,,,,,,,,,,, -10400,0.22658697,0.27988055,,,,,,,,,,,,,, -10472,,,0.7421283040727887,0.2674891608101981,0.7200168658993388,0.2892322787154439,3554.0,0.7372738716489807,0.2906910700027053,3581.0,2537.5425441265106,2864.9380617141724,2537.5425441265106,326.24861216545105,0.7995173931121826,0.0 -10500,0.14204368,0.25121394,,,,,,,,,,,,,, -10600,0.20088787,0.23967412,,,,,,,,,,,,,, -10700,0.3849271,0.26075435,,,,,,,,,,,,,, -10800,0.25304115,0.32015017,,,,,,,,,,,,,, -10817,,,0.7415690422058105,0.2683074474334717,0.7192934430175506,0.2901125313796954,3554.0,0.7364756592955878,0.2917194808538118,3581.0,2617.5111916065216,2948.9611835479736,2617.5111916065216,330.2671887874603,0.8240396976470947,0.0 -10900,0.2764779,0.29331928,,,,,,,,,,,,,, -11000,0.102561206,0.30766085,,,,,,,,,,,,,, -11100,0.30675712,0.22977845,,,,,,,,,,,,,, -11166,,,0.7408914566040039,0.2690391199929373,0.7186734055157921,0.2906372551174733,3554.0,0.7356054523876012,0.2924594021615819,3581.0,2697.6871478557587,3033.1966512203217,2697.6871478557587,334.2903254032135,0.8486440181732178,0.0 -11200,0.084069036,0.28566837,,,,,,,,,,,,,, -11300,0.19006644,0.266127,,,,,,,,,,,,,, -11400,0.052535094,0.360085,,,,,,,,,,,,,, -11500,0.067192085,0.340675,,,,,,,,,,,,,, -11513,,,0.7404755183628627,0.2675546067101614,0.718561982867016,0.2891333928320203,3554.0,0.7360119216437447,0.2905638523522584,3581.0,2777.7934505939484,3117.364461660385,2777.7934505939484,338.3120629787445,0.8769309520721436,0.0 -11600,0.42517635,0.28765976,,,,,,,,,,,,,, -11700,0.17436065,0.27411625,,,,,,,,,,,,,, -11800,0.17166664,0.2764938,,,,,,,,,,,,,, -11861,,,0.7427958760942731,0.2674191849572317,0.7204063643078221,0.288963579769274,3554.0,0.7377858102005376,0.2904494860025132,3581.0,2857.8316600322723,3201.4641349315643,2857.8316600322723,342.33744502067566,0.901536226272583,0.0 -11900,0.26856178,0.30012468,,,,,,,,,,,,,, -12000,0.11214446,0.34603518,,,,,,,,,,,,,, -12100,0.054142855,0.23718993,,,,,,,,,,,,,, -12200,0.38642085,0.3290135,,,,,,,,,,,,,, -12208,,,0.7424193790980748,0.266594444002424,0.7199619102156022,0.2885294985623593,3554.0,0.7372341928319603,0.2900276769975216,3581.0,2937.796384334564,3285.4885654449463,2937.796384334564,346.3597838878632,0.927215337753296,0.0 -12300,0.24842149,0.20429724,,,,,,,,,,,,,, -12400,0.107591316,0.311322,,,,,,,,,,,,,, -12500,0.10554833,0.3278578,,,,,,,,,,,,,, -12556,,,0.7396091052464077,0.2682104791913713,0.7172720355805079,0.2900364864523248,3554.0,0.7345436008840058,0.2915937290015882,3581.0,3017.7934601306915,3369.5441093444824,3017.7934601306915,350.3804326057434,0.953420639038086,0.0 -12600,0.09202795,0.31160572,,,,,,,,,,,,,, -12700,0.077801235,0.36222106,,,,,,,,,,,,,, -12800,0.3408218,0.28155345,,,,,,,,,,,,,, -12900,0.111285806,0.23031431,,,,,,,,,,,,,, -12901,,,0.7418176106044224,0.2670316185270037,0.7191108527583356,0.288919134360052,3554.0,0.7364720459325258,0.2904608033283301,3581.0,3097.760371685028,3453.51145529747,3097.760371685028,354.3416240215301,0.9813022613525392,0.0 -13000,0.07834292,0.3139531,,,,,,,,,,,,,, -13100,0.10252889,0.2585244,,,,,,,,,,,,,, -13200,0.1837113,0.26283154,,,,,,,,,,,,,, -13248,,,0.7430195127214704,0.2669127668653215,0.7204350099579699,0.2889232216890299,3554.0,0.7377360412376082,0.2904348961969422,3581.0,3177.790649652481,3537.606767892837,3177.790649652481,358.3672223091125,1.009075403213501,0.0 -13300,0.11987678,0.3839294,,,,,,,,,,,,,, -13400,0.19064407,0.38782942,,,,,,,,,,,,,, -13500,0.10818145,0.2823751,,,,,,,,,,,,,, -13594,,,0.7428833416530064,0.2671223027365548,0.7201877780757597,0.2891223330006682,3554.0,0.7375735762531416,0.2906754916355417,3581.0,3257.7921764850616,3621.669471979141,3257.7921764850616,362.3905634880066,1.035358190536499,0.0 -13600,0.11217571,0.2528827,,,,,,,,,,,,,, -13700,0.23650345,0.27018806,,,,,,,,,,,,,, -13800,0.3125655,0.33955917,,,,,,,,,,,,,, -13900,0.2043795,0.23295544,,,,,,,,,,,,,, -13941,,,0.742699214390346,0.2675494125911167,0.7202629299732696,0.2894603448029509,3554.0,0.737586461642174,0.290944721271991,3581.0,3337.8404335975647,3705.7823617458334,3337.8404335975647,366.4184386730194,1.0607011318206787,0.0 -14000,0.20149632,0.23568462,,,,,,,,,,,,,, -14100,0.14922825,0.26617014,,,,,,,,,,,,,, -14200,0.35601303,0.21368322,,,,,,,,,,,,,, -14285,,,0.7431567737034389,0.266897337777274,0.7201880528541784,0.2891662288530529,3554.0,0.737531511252967,0.2906121555169645,3581.0,3417.8745822906494,3789.877601146698,3417.8745822906494,370.4419300556183,1.0870215892791748,0.0 -14300,0.101954035,0.30347764,,,,,,,,,,,,,, -14400,0.44932315,0.31957823,,,,,,,,,,,,,, -14500,0.15893835,0.29337707,,,,,,,,,,,,,, -14600,0.18386903,0.24480619,,,,,,,,,,,,,, -14634,,,0.7417561667306083,0.2673311744417463,0.7192847188027575,0.2890202184708251,3554.0,0.7367213679838034,0.2904956075140498,3581.0,3497.9089074134827,3873.968332767487,3497.9089074134827,374.460782289505,1.1130621433258057,0.0 -14700,0.08416967,0.2879344,,,,,,,,,,,,,, -14800,0.19383837,0.32006595,,,,,,,,,,,,,, -14900,0.2387308,0.28416583,,,,,,,,,,,,,, -14982,,,0.7435366085597447,0.2665043217795236,0.7212925247080754,0.28846259001741,3554.0,0.7384148080799707,0.2899907593352939,3581.0,3578.090413331985,3958.208960294724,3578.090413331985,378.48295879364014,1.1384129524230957,0.0 -15000,0.22381839,0.2669391,,,,,,,,,,,,,, -15100,0.18146214,0.2925478,,,,,,,,,,,,,, -15200,0.23387313,0.31459394,,,,,,,,,,,,,, -15300,0.24176402,0.21888907,,,,,,,,,,,,,, -15328,,,0.7421719006129673,0.2664320639201573,0.7199183578362408,0.2883614028647299,3554.0,0.7373226861386484,0.2898182382932491,3581.0,3658.228110074997,4042.40718126297,3658.228110074997,382.5058283805847,1.1644511222839355,0.0 -15400,0.13295294,0.32826483,,,,,,,,,,,,,, -15500,0.15011878,0.38688403,,,,,,,,,,,,,, -15600,0.0895191,0.31146976,,,,,,,,,,,,,, -15676,,,0.7417474474225726,0.2670072657721383,0.7191569468380697,0.2889765287022545,3554.0,0.7366578955119031,0.2904098753621544,3581.0,3738.270714044571,4126.510515928268,3738.270714044571,386.5298075675965,1.1899957656860352,0.0 -15700,0.13422062,0.19870314,,,,,,,,,,,,,, -15800,0.15155242,0.31984612,,,,,,,,,,,,,, -15900,0.10047485,0.23455682,,,,,,,,,,,,,, -16000,0.08712958,0.23512574,,,,,,,,,,,,,, -16023,,,0.7427665165492466,0.2661642347063337,0.7204003191826112,0.2879930108735316,3554.0,0.7378498962623918,0.2894058035792202,3581.0,3818.447353839874,4210.751033306122,3818.447353839874,390.55300784111023,1.218813180923462,0.0 -16100,0.0871218,0.38539442,,,,,,,,,,,,,, -16200,0.111281835,0.24460444,,,,,,,,,,,,,, -16300,0.15691946,0.33861324,,,,,,,,,,,,,, -16368,,,0.7440041133335659,0.2670840535845075,0.721756488068022,0.2891163909173642,3554.0,0.7390052179689681,0.2905316729680082,3581.0,3898.562789916992,4294.926712274551,3898.562789916992,394.5731847286224,1.2473206520080566,0.0 -16400,0.17491658,0.23382911,,,,,,,,,,,,,, -16500,0.22467405,0.23883075,,,,,,,,,,,,,, -16600,0.22884409,0.23593417,,,,,,,,,,,,,, -16700,0.31235445,0.30036741,,,,,,,,,,,,,, -16717,,,0.7443607875279018,0.2667323521205357,0.7220803144344401,0.2889188595816334,3554.0,0.7394075284487573,0.2903541750296705,3581.0,3978.636076450348,4379.056773900986,3978.636076450348,398.5912523269653,1.2744569778442385,0.0 -16800,0.2393918,0.25947785,,,,,,,,,,,,,, -16900,0.13694294,0.35464704,,,,,,,,,,,,,, -17000,0.1286641,0.27392676,,,,,,,,,,,,,, -17064,,,0.7447352409362793,0.2659415687833513,0.7221488716499015,0.2880070245728844,3554.0,0.7393874163336009,0.2895092275747871,3581.0,4058.681727647781,4463.163305521011,4058.681727647781,402.61247849464417,1.3024535179138184,0.0 -17100,0.23900926,0.26141575,,,,,,,,,,,,,, -17200,0.09239128,0.36117512,,,,,,,,,,,,,, -17300,0.30856708,0.25281364,,,,,,,,,,,,,, -17400,0.21124484,0.27294084,,,,,,,,,,,,,, -17408,,,0.7440425327845982,0.2660709619522095,0.7215390009496342,0.2880502506528735,3554.0,0.7389856512671041,0.2894624583849658,3581.0,4138.6701855659485,4547.21695804596,4138.6701855659485,406.63641810417175,1.332447528839111,0.0 -17500,0.23150295,0.28560814,,,,,,,,,,,,,, -17600,0.21594039,0.27907953,,,,,,,,,,,,,, -17700,0.14447561,0.2292556,,,,,,,,,,,,,, -17753,,,0.7439319065638951,0.2663700580596924,0.7215269106992122,0.2884311622357731,3554.0,0.7387822121090477,0.2899840098458007,3581.0,4218.735862731934,4631.345724105835,4218.735862731934,410.6612322330475,1.359151840209961,0.0 -17800,0.21747825,0.246762,,,,,,,,,,,,,, -17900,0.09507173,0.33435068,,,,,,,,,,,,,, -18000,0.12267545,0.2622404,,,,,,,,,,,,,, -18100,0.14239307,0.2475523,,,,,,,,,,,,,, -18101,,,0.745412962777274,0.2659097569329398,0.7232113711003447,0.2879792032579927,3554.0,0.7404509040770735,0.2894187230565833,3581.0,4298.713285446167,4715.32683968544,4298.713285446167,414.6260409355164,1.3860557079315186,0.0 -18200,0.34561864,0.2948073,,,,,,,,,,,,,, -18300,0.11271246,0.32788578,,,,,,,,,,,,,, -18400,0.10766158,0.25726038,,,,,,,,,,,,,, -18446,,,0.7459585326058524,0.2661910397665841,0.7234733036279544,0.2883196021977877,3554.0,0.7406133008848785,0.2898725410041713,3581.0,4378.784018278122,4799.463942289352,4378.784018278122,418.6499736309052,1.4169938564300537,0.0 -18500,0.12959583,0.2618258,,,,,,,,,,,,,, -18600,0.2518487,0.3448649,,,,,,,,,,,,,, -18700,0.23599544,0.35494643,,,,,,,,,,,,,, -18794,,,0.7438090188162667,0.2661190032958984,0.7211505329602209,0.2880788276084166,3554.0,0.7384542141903448,0.2896872368380864,3581.0,4458.971214294434,4883.711813926697,4458.971214294434,422.6686074733734,1.4473583698272705,0.0 -18800,0.11490498,0.32547295,,,,,,,,,,,,,, -18900,0.1671014,0.29327032,,,,,,,,,,,,,, -19000,0.10779816,0.244589,,,,,,,,,,,,,, -19100,0.21707396,0.33332923,,,,,,,,,,,,,, -19144,,,0.7441033635820661,0.2657319477626255,0.7216333873364519,0.2877311813878288,3554.0,0.738919588082065,0.2891817750693766,3581.0,4539.066566228867,4967.871089935303,4539.066566228867,426.6920883655548,1.4761953353881836,0.0 -19200,0.12421605,0.24314928,,,,,,,,,,,,,, -19300,0.17793235,0.29092902,,,,,,,,,,,,,, -19400,0.18565571,0.19668952,,,,,,,,,,,,,, -19492,,,0.7439759799412319,0.2665566887174334,0.7214370581563028,0.2886482715338351,3554.0,0.7386280646772898,0.2901595306609536,3581.0,4619.122064113617,5051.985657215118,4619.122064113617,430.71262383461,1.503026008605957,0.0 -19500,0.12118213,0.3782464,,,,,,,,,,,,,, -19600,0.1754063,0.32313332,,,,,,,,,,,,,, -19700,0.15138833,0.28438875,,,,,,,,,,,,,, -19800,0.11901009,0.28912196,,,,,,,,,,,,,, -19839,,,0.7438088144574847,0.2655234677450998,0.7212626425550436,0.2876793684822559,3554.0,0.7386809015899888,0.2890906569612189,3581.0,4699.332726478577,5136.260441064835,4699.332726478577,434.7377324104309,1.530184984207153,0.0 -19900,0.09374597,0.3406664,,,,,,,,,,,,,, -20000,0.13295634,0.23918217,,,,,,,,,,,,,, -20100,0.2930784,0.34435916,,,,,,,,,,,,,, -20187,,,0.744532653263637,0.2655076810291835,0.7220739258362057,0.2875561475351276,3554.0,0.739436162646607,0.2889807220944743,3581.0,4779.307057142258,5220.3008596897125,4779.307057142258,438.7622768878937,1.5600879192352295,0.0 -20200,0.0789623,0.25682384,,,,,,,,,,,,,, -20300,0.23829755,0.24081534,,,,,,,,,,,,,, -20400,0.17679606,0.2489949,,,,,,,,,,,,,, -20500,0.12808588,0.26769707,,,,,,,,,,,,,, -20533,,,0.7438911029270717,0.2663663114820208,0.7213166365143149,0.2883558729490539,3554.0,0.7386787199368193,0.2897922970735304,3581.0,4859.422160625458,5304.475937128067,4859.422160625458,442.7827990055084,1.588029146194458,0.0 -20600,0.1966726,0.41157138,,,,,,,,,,,,,, -20700,0.10794267,0.319903,,,,,,,,,,,,,, -20800,0.07379225,0.2260136,,,,,,,,,,,,,, -20880,,,0.7441916465759277,0.2665532827377319,0.7212896395346793,0.2887949688621096,3554.0,0.7385392304872941,0.2902587958801662,3581.0,4939.427295207977,5388.543436765671,4939.427295207977,446.8024094104767,1.6191449165344238,0.0 -20900,0.15541421,0.32052875,,,,,,,,,,,,,, -21000,0.08950439,0.32614338,,,,,,,,,,,,,, -21100,0.25579602,0.34243405,,,,,,,,,,,,,, -21200,0.121854834,0.29180104,,,,,,,,,,,,,, -21230,,,0.744783878326416,0.2651681729725429,0.722162335792417,0.2872741905302564,3554.0,0.7394924765690449,0.2887227075188495,3581.0,5019.592027187347,5472.774369478226,5019.592027187347,450.8265299797058,1.6496500968933103,0.0 -21300,0.11880778,0.3288235,,,,,,,,,,,,,, -21400,0.11188248,0.28359568,,,,,,,,,,,,,, -21500,0.10246691,0.22156587,,,,,,,,,,,,,, -21576,,,0.7450062888009208,0.2656759704862322,0.7225232572453574,0.2878276457864378,3554.0,0.7398384731263963,0.2892016826545483,3581.0,5099.608054637909,5556.852010011673,5099.608054637909,454.8499348163605,1.676215887069702,0.0 -21600,0.17942835,0.36016232,,,,,,,,,,,,,, -21700,0.16271791,0.36531612,,,,,,,,,,,,,, -21800,0.16176347,0.22466084,,,,,,,,,,,,,, -21900,0.24034484,0.26252782,,,,,,,,,,,,,, -21923,,,0.7450105122157505,0.2651782206126621,0.7219767916652012,0.2876252543074264,3554.0,0.7393251710416084,0.2890725901459089,3581.0,5179.616717100143,5640.918618440628,5179.616717100143,458.8681221008301,1.7044212818145752,0.0 -22000,0.12173256,0.3182319,,,,,,,,,,,,,, -22100,0.13213992,0.3760463,,,,,,,,,,,,,, -22200,0.14198148,0.30071786,,,,,,,,,,,,,, -22270,,,0.7445428030831474,0.2653919628688267,0.7217431613147158,0.2877102295334042,3554.0,0.7391410940554315,0.2891758096114912,3581.0,5259.582234382629,5724.94517993927,5259.582234382629,462.88522577285767,1.7367875576019287,0.0 -22300,0.15381674,0.23555303,,,,,,,,,,,,,, -22400,0.21205205,0.27418664,,,,,,,,,,,,,, -22500,0.24693458,0.2894867,,,,,,,,,,,,,, -22600,0.15783216,0.3170452,,,,,,,,,,,,,, -22619,,,0.7453750882829938,0.265277692249843,0.7228236587515827,0.2873608316003974,3554.0,0.740124951458217,0.2887698175919785,3581.0,5339.7165422439575,5809.135733127594,5339.7165422439575,466.90236949920654,1.7640461921691897,0.0 -22700,0.09926717,0.23190758,,,,,,,,,,,,,, -22800,0.0993647,0.23811987,,,,,,,,,,,,,, -22900,0.17794058,0.22922441,,,,,,,,,,,,,, -22966,,,0.7456490652901786,0.266116806438991,0.7225788998751407,0.2888991098827905,3554.0,0.7396289662454621,0.2904986754638194,3581.0,5419.942188262939,5893.424808740616,5419.942188262939,470.92625880241394,1.79207444190979,0.0 -23000,0.1186298,0.32979104,,,,,,,,,,,,,, -23100,0.15823491,0.2723261,,,,,,,,,,,,,, -23200,0.09434637,0.3378655,,,,,,,,,,,,,, -23300,0.2963837,0.3603753,,,,,,,,,,,,,, -23314,,,0.745732034955706,0.2647743225097656,0.7227975548018079,0.2871776230897404,3554.0,0.7400252090023737,0.2886775404805745,3581.0,5499.9059715271,5977.449015617371,5499.9059715271,474.9463920593262,1.8207497596740725,0.0 -23400,0.15665358,0.43210524,,,,,,,,,,,,,, -23500,0.051910467,0.2463406,,,,,,,,,,,,,, -23600,0.36965626,0.2548004,,,,,,,,,,,,,, -23664,,,0.7447335379464286,0.264919672693525,0.7217938579329629,0.2872225665348463,3554.0,0.7392118614301173,0.2885994782031031,3581.0,5579.946591615677,6061.551117658615,5579.946591615677,478.9675254821777,1.84963059425354,0.0 -23700,0.173181,0.37709323,,,,,,,,,,,,,, -23800,0.36449113,0.210882,,,,,,,,,,,,,, -23900,0.12249413,0.2715695,,,,,,,,,,,,,, -24000,0.14883478,0.26778117,,,,,,,,,,,,,, -24009,,,0.7441058840070452,0.2658603702272687,0.7207657744794598,0.2883601663618458,3554.0,0.7381948701698199,0.289814488576864,3581.0,5660.062964439392,6145.727813243866,5660.062964439392,482.9878783226013,1.8781225681304927,0.0 -24100,0.10971041,0.2954665,,,,,,,,,,,,,, -24200,0.196463,0.25935498,,,,,,,,,,,,,, -24300,0.22314025,0.31419635,,,,,,,,,,,,,, -24358,,,0.7448805400303432,0.2656828846250261,0.7218370668393008,0.2882363443369267,3554.0,0.7391160732206437,0.2896226394512706,3581.0,5740.089933633804,6229.816264629364,5740.089933633804,487.0089511871338,1.9069738388061523,0.0 -24400,0.22299571,0.2155633,,,,,,,,,,,,,, -24500,0.07860922,0.33866572,,,,,,,,,,,,,, -24600,0.14269641,0.29145634,,,,,,,,,,,,,, -24700,0.17502168,0.29496032,,,,,,,,,,,,,, -24707,,,0.7452448436192104,0.2650761604309082,0.7225511472548537,0.2873384028119724,3554.0,0.7399242393666224,0.2886322711773073,3581.0,5820.203050613403,6313.987454175949,5820.203050613403,491.0269269943237,1.9357292652130127,0.0 -24800,0.056090347,0.3884374,,,,,,,,,,,,,, -24900,0.13421342,0.30080214,,,,,,,,,,,,,, -25000,0.22379045,0.26959762,,,,,,,,,,,,,, -25050,,,0.7459461348397392,0.2647052322115217,0.7229634522720878,0.2871748581319024,3554.0,0.7402670997975426,0.2885487206785814,3581.0,5900.204864025116,6398.051856517792,5900.204864025116,495.0495693683624,1.9641578197479248,0.0 -25100,0.07782109,0.2789336,,,,,,,,,,,,,, -25200,0.14111054,0.28406116,,,,,,,,,,,,,, -25300,0.3367781,0.24479702,,,,,,,,,,,,,, -25399,,,0.7456408909388951,0.2647386278424944,0.7228097137468346,0.2871405795241717,3554.0,0.7401602669688984,0.2885024287253909,3581.0,5980.308407306671,6482.217605352402,5980.308407306671,499.0717160701752,1.9927153587341309,0.0 -25400,0.11242574,0.3431042,,,,,,,,,,,,,, -25500,0.19137463,0.22623548,,,,,,,,,,,,,, -25600,0.092830755,0.27520743,,,,,,,,,,,,,, -25700,0.25699013,0.2682395,,,,,,,,,,,,,, -25746,,,0.7461174556187221,0.2646439586366926,0.7234955606798678,0.2868086300207512,3554.0,0.7408707359588802,0.2882216090564786,3581.0,6060.476011991501,6566.445360660553,6060.476011991501,503.0915832519531,2.0212597846984863,0.0 -25800,0.20650925,0.2571497,,,,,,,,,,,,,, -25900,0.53024495,0.20686117,,,,,,,,,,,,,, -26000,0.1343016,0.24400012,,,,,,,,,,,,,, -26090,,,0.7450517245701381,0.2648283072880336,0.721940314830121,0.2872904024569587,3554.0,0.7393630090887671,0.2886733135275586,3581.0,6140.551958799362,6650.584136009216,6140.551958799362,507.1141304969788,2.0498805046081543,0.0 -26100,0.21355194,0.2888685,,,,,,,,,,,,,, -26200,0.10582866,0.3897729,,,,,,,,,,,,,, -26300,0.21708119,0.32380834,,,,,,,,,,,,,, -26400,0.13004205,0.27747604,,,,,,,,,,,,,, -26437,,,0.745732239314488,0.2645855120250157,0.7227190368686691,0.2870743922725714,3554.0,0.7400255498856814,0.2884955087942439,3581.0,6220.706364631653,6734.808057308197,6220.706364631653,511.1384792327881,2.083644151687622,0.0 -26500,0.1443717,0.26070103,,,,,,,,,,,,,, -26600,0.05523196,0.25828835,,,,,,,,,,,,,, -26700,0.09168793,0.32622814,,,,,,,,,,,,,, -26785,,,0.7451646668570382,0.2647560664585658,0.7219787151141319,0.287235549815129,3554.0,0.7394161187081123,0.2885674692605068,3581.0,6300.665889263153,6818.829693555832,6300.665889263153,515.1594040393829,2.1129167079925537,0.0 -26800,0.44383788,0.24749795,,,,,,,,,,,,,, -26900,0.1930567,0.23786065,,,,,,,,,,,,,, -27000,0.2762449,0.26454198,,,,,,,,,,,,,, -27100,0.09282485,0.3166148,,,,,,,,,,,,,, -27130,,,0.7452775410243443,0.264656594821385,0.7223087926895752,0.2870864996966446,3554.0,0.7396342840250628,0.2884770670072955,3581.0,6380.640214920044,6902.870454549789,6380.640214920044,519.1855278015137,2.141589641571045,0.0 -27200,0.13973983,0.35896635,,,,,,,,,,,,,, -27300,0.12950924,0.3001641,,,,,,,,,,,,,, -27400,0.13759413,0.23101613,,,,,,,,,,,,,, -27477,,,0.745558534349714,0.2641756364277431,0.7225227763831247,0.2866485715918683,3554.0,0.7399376701689472,0.2880691660412594,3581.0,6460.715427160263,6987.013831138611,6460.715427160263,523.2129678726196,2.17048716545105,0.0 -27500,0.14907743,0.3103478,,,,,,,,,,,,,, -27600,0.17216444,0.2889022,,,,,,,,,,,,,, -27700,0.12965968,0.3430941,,,,,,,,,,,,,, -27800,0.19382465,0.3606869,,,,,,,,,,,,,, -27825,,,0.7469542367117745,0.2642578908375331,0.7240524678399338,0.2867645280845526,3554.0,0.741383288100391,0.2881705788253106,3581.0,6540.796165466309,7071.158291816711,6540.796165466309,527.2320716381073,2.2036807537078857,0.0 -27825,,,,,,,,,,,6540.796165466309,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 488c994ee..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -36.88790965080261,0.0,51.92995691299439,1,0,51.92995691299439,0.0006000000284984,6.9125494956970215,10000,88.81795763969421,0.0010363520123064,6.91261625289917,0.0007599999662488,6.913174629211426,50000 -55.00457406044006,0.0275373458862304,562.0372688770294,1498,0,562.0372688770294,0.0465000011026859,5.657770156860352,10000,617.1234951019287,0.0717275142669677,5.377313137054443,0.0648999959230423,5.448593616485596,50000 -73.41657638549805,0.0556209087371826,1072.2711553573608,2994,0,1072.2711553573608,0.1105000078678131,4.828097820281982,10000,1145.850881814957,0.1758011728525161,4.277118682861328,0.1584599912166595,4.395031929016113,50000 -91.58906817436218,0.0812327861785888,1582.2347013950348,4489,0,1582.2347013950348,0.1837000101804733,4.219701766967773,10000,1674.067317724228,0.2791972160339355,3.528715372085572,0.2509799897670746,3.671311378479004,50000 -109.70492148399352,0.1139328479766845,2092.3124701976776,5985,0,2092.3124701976776,0.2619000077247619,3.715914726257324,10000,2202.3470935821533,0.3691804707050323,2.9538686275482178,0.3462799787521362,3.0891966819763184,50000 -127.85242891311646,0.1410117149353027,2602.396719932556,7481,0,2602.396719932556,0.2989000082015991,3.4827613830566406,10000,2730.6603696346283,0.4307437837123871,2.613607168197632,0.3971000015735626,2.790703296661377,50000 -146.1283574104309,0.1749634742736816,3112.5727632045746,8978,0,3112.5727632045746,0.357200026512146,3.100750684738159,10000,3259.198537349701,0.492287129163742,2.275744676589966,0.4618600010871887,2.433438301086426,50000 -164.5659377574921,0.2022130489349365,3622.6284663677216,10475,0,3622.6284663677216,0.3951000273227691,2.8704304695129395,10000,3787.772131204605,0.5693957209587097,1.8849722146987915,0.5029999613761902,2.2190101146698,50000 -182.8322350978852,0.2424328327178955,4132.844066381455,11974,0,4132.844066381455,0.4093000292778015,2.793846368789673,10000,4316.348064184189,0.5763313174247742,1.9071600437164309,0.5275999903678894,2.128957271575928,50000 -201.3808000087738,0.2749552726745605,4643.087154865265,13473,0,4643.087154865265,0.4231000244617462,2.7273426055908203,10000,4845.225155115128,0.5816326141357422,1.8359532356262207,0.535539984703064,2.061918258666992,50000 -220.55117321014404,0.304619550704956,5153.3032858371735,14973,0,5153.3032858371735,0.431300014257431,2.6510255336761475,10000,5374.694349765778,0.5989516973495483,1.7546666860580444,0.5537399649620056,1.964607000350952,50000 -240.77084040641785,0.3464560508728027,5663.362612962723,16473,0,5663.362612962723,0.4360000193119049,2.6539673805236816,10000,5905.069729804993,0.59375,1.8033223152160645,0.5557399988174438,1.9902366399765008,50000 -262.16706109046936,0.3816823959350586,6173.529769182205,17974,0,6173.529769182205,0.4388000071048736,2.6314282417297363,10000,6436.721809148788,0.6002271771430969,1.7657809257507324,0.5660600066184998,1.9302324056625368,50000 -287.0052680969238,0.423555850982666,6683.501819372177,19474,0,6683.501819372177,0.4509000182151794,2.5466179847717285,10000,6971.627715110779,0.6611328125,1.485448122024536,0.5753799676895142,1.8646032810211184,50000 -309.8524570465088,0.4640743732452392,7193.579409122467,20975,0,7193.579409122467,0.4554000198841095,2.530134439468384,10000,7504.646897554398,0.6342673897743225,1.583891749382019,0.5761199593544006,1.8759549856185915,50000 -331.14792466163635,0.4928333759307861,7703.516690969467,22476,0,7703.516690969467,0.4526000320911407,2.587834596633911,10000,8035.961810588837,0.6265146732330322,1.6444129943847656,0.5791199803352356,1.8640754222869875,50000 -354.0397229194641,0.5262980461120605,8213.743519306183,23978,0,8213.743519306183,0.4561000168323517,2.549467325210572,10000,8569.167505025864,0.6319355964660645,1.6008868217468262,0.582260012626648,1.830167293548584,50000 -376.9086351394653,0.5557169914245605,8723.919181346893,25480,0,8723.919181346893,0.4687000215053558,2.476848602294922,10000,9102.295284986496,0.6276904940605164,1.620834469795227,0.586080014705658,1.8189243078231807,50000 -398.1472208499909,0.58970046043396,9234.135885238647,26982,0,9234.135885238647,0.469400018453598,2.4492757320404053,10000,9633.838790893557,0.6330516338348389,1.5673516988754272,0.5966199636459351,1.7551690340042114,50000 -420.8384771347046,1.513392210006714,9743.258620500565,28481,0,9743.258620500565,0.4660000205039978,2.4653191566467285,10000,10166.62963104248,0.6598572731018066,1.473987102508545,0.5931400060653687,1.779598593711853,50000 -442.5229060649872,1.5470540523529053,10253.200542211533,29983,0,10253.200542211533,0.4546000361442566,2.542966604232788,10000,10698.343393564224,0.6408442258834839,1.5831868648529053,0.5834800004959106,1.8493436574935915,50000 -465.9139356613159,1.5822596549987793,10763.20768213272,31484,0,10763.20768213272,0.4751000106334686,2.450975179672241,10000,11231.830409526823,0.650390625,1.5332188606262207,0.6009599566459656,1.777511715888977,50000 -487.7795009613037,1.6160292625427246,11273.447633743286,32987,0,11273.447633743286,0.4719000160694122,2.464462995529175,10000,11764.022417783735,0.6478196382522583,1.545523166656494,0.5983999967575073,1.7723729610443115,50000 -508.4510877132416,1.65091872215271,11783.682542085648,34490,0,11783.682542085648,0.4844000339508056,2.4039146900177,10000,12295.01685166359,0.6456871628761292,1.5215743780136108,0.6019399762153625,1.7319201231002808,50000 -530.7041218280792,1.6843512058258057,12293.693517684937,35993,0,12293.693517684937,0.4715000092983246,2.450953722000122,10000,12827.367676734924,0.6405652165412903,1.5603916645050049,0.5936599969863892,1.7699402570724487,50000 -552.2334206104279,1.7202045917510986,12803.7000541687,37495,0,12803.7000541687,0.4916000366210937,2.3600785732269287,10000,13358.99266719818,0.6504703164100647,1.510545015335083,0.608460009098053,1.7156388759613037,50000 -572.3573670387268,1.755732774734497,13313.690217733383,38998,0,13313.690217733383,0.4841000139713287,2.391342639923096,10000,13889.194969892502,0.6715760231018066,1.4409266710281372,0.6063599586486816,1.7327011823654177,50000 -590.4436695575714,1.7907118797302246,13823.646218061447,40500,0,13823.646218061447,0.488500028848648,2.359315633773804,10000,14417.326255083084,0.6689851880073547,1.414209485054016,0.6127600073814392,1.6750468015670776,50000 -610.9255712032318,1.8257253170013428,14333.642135620115,42003,0,14333.642135620115,0.4949000179767608,2.352324247360229,10000,14947.89179635048,0.661531388759613,1.4491329193115234,0.6156399846076965,1.6673866510391235,50000 -629.7090165615082,1.8625688552856443,14843.627077817917,43506,0,14843.627077817917,0.4839000105857849,2.3706626892089844,10000,15476.748891830444,0.6548947691917419,1.4952327013015747,0.6090599894523621,1.7058863639831543,50000 -648.3707220554352,1.9011075496673584,15353.6026597023,45009,0,15353.6026597023,0.4889000356197357,2.364584922790528,10000,16005.477684020996,0.6541972160339355,1.492714524269104,0.6135199666023254,1.686639666557312,50000 -666.4376637935638,1.9357657432556152,15863.77936911583,46513,0,15863.77936911583,0.4930000305175781,2.3379759788513184,10000,16533.809260606766,0.6575055718421936,1.4960453510284424,0.6189199686050415,1.6851189136505127,50000 -685.0665671825409,1.9751687049865725,16373.80011177063,48017,0,16373.80011177063,0.498600035905838,2.3249881267547607,10000,17062.553616285324,0.6881377696990967,1.3390893936157229,0.6141799688339233,1.6819976568222046,50000 -703.5050427913666,2.014011144638061,16883.71029162407,49520,0,16883.71029162407,0.4930000305175781,2.3653693199157715,10000,17590.99581003189,0.6678889989852905,1.4150813817977903,0.6149399876594543,1.6707707643508911,50000 -721.513111114502,2.0556466579437256,17393.85679912567,51024,0,17393.85679912567,0.5006999969482422,2.3009283542633057,10000,18119.246876478195,0.6670718789100647,1.4345439672470093,0.6176199913024902,1.679994821548462,50000 -740.029061794281,2.0966885089874268,17904.098637342453,52528,0,17904.098637342453,0.4999000132083893,2.28476881980896,10000,18648.099063634872,0.6692044138908386,1.4314494132995603,0.6190999746322632,1.6480151414871216,50000 -757.3529677391052,2.1330432891845703,18414.03647136688,54032,0,18414.03647136688,0.4923000335693359,2.3091156482696533,10000,19175.45029401779,0.6640425324440002,1.432912826538086,0.620199978351593,1.6460167169570925,50000 -774.8156905174255,2.171757459640503,18923.95730996132,55535,0,18923.95730996132,0.5022000074386597,2.2728426456451416,10000,19702.92580795288,0.6664939522743225,1.427367925643921,0.6240000128746033,1.622045636177063,50000 -792.2390511035919,2.209519863128662,19434.118763685223,57040,0,19434.118763685223,0.5054000020027161,2.2856898307800293,10000,20230.6025326252,0.7063934803009033,1.2701016664505005,0.6281599998474121,1.6143368482589722,50000 -809.6045508384705,2.2512757778167725,19944.30988574028,58544,0,19944.30988574028,0.4928000271320343,2.3111507892608643,10000,20758.25435495377,0.6819196343421936,1.3621110916137695,0.6212799549102783,1.638480544090271,50000 -826.8658409118652,2.2905819416046143,20454.509740829468,60048,0,20454.509740829468,0.5053000450134277,2.2946648597717285,10000,21285.80756020546,0.6807836294174194,1.359082579612732,0.6322000026702881,1.601298213005066,50000 -844.2586009502411,2.3358445167541504,20964.46904230117,61552,0,20964.46904230117,0.5041000247001648,2.258517980575561,10000,21813.258487939835,0.6801857352256775,1.3753252029418943,0.6317799687385559,1.5977727174758911,50000 -861.6014168262482,2.37718152999878,21474.617376327515,63057,0,21474.617376327515,0.5072000026702881,2.271667242050171,10000,22340.84404706955,0.6864038705825806,1.3631954193115234,0.6362400054931641,1.5933818817138672,50000 -879.075288772583,2.416655778884888,21984.827194929123,64562,0,21984.827194929123,0.4985000193119049,2.332926034927368,10000,22868.620908498764,0.6725525856018066,1.4098135232925415,0.6275399923324585,1.6258821487426758,50000 -896.5920617580414,2.4555139541625977,22494.96067333221,66067,0,22494.96067333221,0.5062000155448914,2.2767369747161865,10000,23396.3625805378,0.7122927308082581,1.2576779127120972,0.624459981918335,1.6419285535812378,50000 -914.1040511131288,2.495955228805542,23005.19539809227,67572,0,23005.19539809227,0.5186000466346741,2.217724323272705,10000,23924.20377588272,0.6955117583274841,1.2980284690856934,0.6328999996185303,1.5780168771743774,50000 -933.078031539917,2.535719633102417,23515.14908361435,69077,0,23515.14908361435,0.5082000494003296,2.25681209564209,10000,24453.22411513329,0.6938177347183228,1.3356457948684692,0.634719967842102,1.593294978141785,50000 -951.5091438293456,2.5747079849243164,24025.31227278709,70582,0,24025.31227278709,0.5163000226020813,2.2319271564483643,10000,24981.911822795868,0.6916653513908386,1.3394980430603027,0.6408199667930603,1.581676959991455,50000 -968.733303785324,2.617192506790161,24535.51449584961,72087,0,24535.51449584961,0.518500030040741,2.226827621459961,10000,25509.434837818146,0.6868821382522583,1.351656198501587,0.6393600106239319,1.5750479698181152,50000 -986.4949653148652,2.6583826541900635,25045.693425178528,73592,0,25045.693425178528,0.513700008392334,2.240742921829224,10000,26037.469779729843,0.6874202489852905,1.3545960187911987,0.6374799609184265,1.5858798027038574,50000 -1004.9125220775604,2.7030186653137207,25555.87436771393,75098,0,25555.87436771393,0.5095000267028809,2.269789695739746,10000,26566.167140960693,0.7245495915412903,1.1946289539337158,0.6363799571990967,1.5921239852905271,50000 -1022.8778517246246,2.7443735599517822,26066.06677961349,76603,0,26066.06677961349,0.5212000012397766,2.208439588546753,10000,27094.41897177696,0.7009127736091614,1.2936850786209106,0.6375199556350708,1.5906641483306885,50000 -1039.984982252121,2.788038969039917,26576.22903299332,78107,0,26576.22903299332,0.5258000493049622,2.185645580291748,10000,27621.786662578583,0.7107182741165161,1.2404513359069824,0.6543399691581726,1.5085474252700806,50000 -1057.1927177906036,2.82971453666687,27086.374658584595,79612,0,27086.374658584595,0.5260000228881836,2.1862270832061768,10000,28149.237203598022,0.7019491195678711,1.2742984294891355,0.6469199657440186,1.5245815515518188,50000 -1074.5033564567566,2.8747684955596924,27596.460722208023,81117,0,27596.460722208023,0.5253000259399414,2.175584554672241,10000,28676.73327088356,0.6994778513908386,1.2843109369277954,0.6515199542045593,1.5109398365020752,50000 -1091.787621974945,2.9192492961883545,28106.558556318283,82622,0,28106.558556318283,0.52510005235672,2.1800341606140137,10000,29204.21426296234,0.6990393400192261,1.291176676750183,0.6479799747467041,1.5192651748657229,50000 -1109.4084930419922,2.964075803756714,28616.616693496704,84127,0,28616.616693496704,0.5231000185012817,2.172151565551758,10000,29731.992182970047,0.7421077489852905,1.1094766855239868,0.6539199948310852,1.498048186302185,50000 -1126.7367713451383,2.999528408050537,29126.722969293594,85632,0,29126.722969293594,0.5339000225067139,2.140961170196533,10000,30259.5151386261,0.72562575340271,1.1676472425460815,0.6571399569511414,1.469819188117981,50000 -1143.8874711990356,3.0535805225372314,29636.656057357788,87137,0,29636.656057357788,0.5339000225067139,2.1381936073303223,10000,30786.70604276657,0.7182118892669678,1.213655948638916,0.6566799879074097,1.4903264045715332,50000 -1161.2099361419678,3.103337049484253,30146.884006261826,88642,0,30146.884006261826,0.5324000120162964,2.1352744102478027,10000,31314.36168217659,0.7137675285339355,1.2036211490631104,0.6636999845504761,1.4569283723831177,50000 -1178.7714076042175,3.1447083950042725,30656.89363193512,90147,0,30656.89363193512,0.5302000045776367,2.188394784927368,10000,31842.028629779816,0.7099409699440002,1.2612212896347046,0.6620799899101257,1.5005483627319336,50000 -1196.192389011383,3.1903398036956787,31167.033656597137,91652,0,31167.033656597137,0.532800018787384,2.14026951789856,10000,32369.68866419792,0.7141063213348389,1.2344672679901123,0.6605799794197083,1.4747819900512695,50000 -1213.5849130153656,3.2358620166778564,31677.014585733414,93157,0,31677.014585733414,0.5348000526428223,2.143916606903076,10000,32897.162084817886,0.7379822731018066,1.1373311281204224,0.6595999598503113,1.4904597997665403,50000 -1230.840161561966,3.279825687408448,32187.24665856361,94662,0,32187.24665856361,0.534500002861023,2.139446258544922,10000,33424.74766254425,0.7356903553009033,1.137479305267334,0.6634199619293213,1.4608635902404783,50000 -1248.0705358982086,3.32486629486084,32697.393936157227,96167,0,32697.393936157227,0.5420000553131104,2.080425262451172,10000,33952.223504543304,0.7338966727256775,1.1363016366958618,0.6714999675750732,1.4211866855621338,50000 -1265.554355621338,3.372291326522827,33207.3950073719,97672,0,33207.3950073719,0.5360000133514404,2.12292218208313,10000,34479.80842757225,0.7274991869926453,1.16644549369812,0.6618399620056152,1.453052043914795,50000 -1282.9458377361298,3.418320894241333,33717.37340140343,99177,0,33717.37340140343,0.538100004196167,2.1251816749572754,10000,35007.27703619003,0.7198660373687744,1.182017803192139,0.6629999876022339,1.4383418560028076,50000 -1300.0815467834473,3.4637675285339355,34227.44268536568,100682,0,34227.44268536568,0.5496000051498413,2.0684101581573486,10000,35534.58135795593,0.7286949753761292,1.1676808595657349,0.6708599925041199,1.4283438920974731,50000 -1317.113971233368,3.5098297595977783,34737.54347777367,102187,0,34737.54347777367,0.5542000532150269,2.0193703174591064,10000,36061.81385445595,0.748046875,1.0852841138839722,0.6756599545478821,1.4007805585861206,50000 -1334.212421655655,3.555453062057495,35247.70157814026,103692,0,35247.70157814026,0.5421000123023987,2.0770750045776367,10000,36589.16969251633,0.7497408986091614,1.068345069885254,0.6728799939155579,1.414284348487854,50000 -1351.752160072327,3.603590250015259,35757.91668009758,105198,0,35757.91668009758,0.5439000129699707,2.1047375202178955,10000,37117.02583503723,0.7359095811843872,1.1194480657577517,0.6708999872207642,1.4225313663482666,50000 -1369.1742932796478,3.651136636734009,36268.02729392052,106703,0,36268.02729392052,0.5574000477790833,2.037686824798584,10000,37644.65886926651,0.7438616156578064,1.0795286893844604,0.6801799535751343,1.3656352758407593,50000 -1386.2812361717224,3.69752836227417,36778.0198366642,108205,0,36778.0198366642,0.5433000326156616,2.072263479232788,10000,38171.86295318604,0.7301897406578064,1.1509002447128296,0.6698799729347229,1.42633056640625,50000 -1403.6111969947815,3.74524450302124,37288.15925168991,109710,0,37288.15925168991,0.5539000034332275,2.023701429367065,10000,38699.43433403969,0.7444196343421936,1.0855257511138916,0.6843999624252319,1.3646774291992188,50000 -1421.063180923462,3.7959115505218506,37798.17662835121,111215,0,37798.17662835121,0.5539000034332275,2.0336456298828125,10000,39227.00764942169,0.7463129758834839,1.0860766172409058,0.6835599541664124,1.3672128915786743,50000 -1438.587425947189,3.8423855304718018,38308.20759654045,112720,0,38308.20759654045,0.5567000508308411,2.027822256088257,10000,39754.66309762001,0.7732580900192261,0.9794655442237854,0.6888999938964844,1.3610857725143433,50000 -1455.857293367386,3.897749662399292,38818.21781897545,114225,0,38818.21781897545,0.557200014591217,2.0630359649658203,10000,40282.05206513405,0.7552614808082581,1.0510189533233645,0.6833999752998352,1.376708984375,50000 -1472.9484844207764,3.9506664276123047,39328.42887806893,115730,0,39328.42887806893,0.5611000061035156,2.0068650245666504,10000,40809.46044826508,0.7583107352256775,1.042015790939331,0.6886799931526184,1.344862461090088,50000 -1490.2100987434387,4.000177621841431,39838.38658428192,117235,0,39838.38658428192,0.5618000030517578,2.0201923847198486,10000,41336.782682180405,0.7524114847183228,1.064433455467224,0.6860799789428711,1.3524702787399292,50000 -1507.3889908790588,4.046831607818604,40348.35553407669,118739,0,40348.35553407669,0.565500020980835,2.018036365509033,10000,41864.03025341034,0.7526506781578064,1.0996969938278198,0.6875199675559998,1.3770368099212646,50000 -1524.8025405406952,4.097309589385986,40858.48894238472,120244,0,40858.48894238472,0.5662000179290771,1.9515352249145508,10000,42391.6834628582,0.7641701102256775,1.0034244060516355,0.699999988079071,1.2891603708267212,50000 -1542.2404384613037,4.155265808105469,41368.59728837013,121749,0,41368.59728837013,0.5667999982833862,1.9649931192398071,10000,42919.34297060967,0.7843789458274841,0.9068344831466676,0.6967200040817261,1.3004614114761353,50000 -1559.600687265396,4.203623294830322,41878.64205765724,123254,0,41878.64205765724,0.5749000310897827,1.9503682851791384,10000,43446.8511133194,0.78324294090271,0.9465728998184204,0.7019199728965759,1.2940689325332642,50000 -1577.003856897354,4.252768278121948,42388.77073264122,124759,0,42388.77073264122,0.578000009059906,1.9135336875915527,10000,43974.48748540878,0.7840202450752258,0.9154542088508606,0.7075799703598022,1.2590969800949097,50000 -1594.5007934570312,4.299294471740723,42898.89861416817,126264,0,42898.89861416817,0.5743000507354736,1.9358104467391968,10000,44502.21145987511,0.7784199714660645,0.9533615708351136,0.7020599842071533,1.289789795875549,50000 -1612.0683364868164,4.361209392547607,43408.88532400131,127768,0,43408.88532400131,0.5785000324249268,1.9376139640808103,10000,45029.88151431084,0.7771444320678711,0.9476613402366638,0.7056399583816528,1.2707682847976685,50000 -1629.2153718471527,4.409427642822266,43919.058978796005,129273,0,43919.058978796005,0.5729000568389893,1.9563974142074585,10000,45557.30363321304,0.7789580225944519,0.9515725374221802,0.7039799690246582,1.283861756324768,50000 -1646.6670017242432,4.463230848312378,44429.077756643295,130777,0,44429.077756643295,0.5872000455856323,1.876109480857849,10000,46084.88196539879,0.8083147406578064,0.815697968006134,0.7135199904441833,1.2330472469329834,50000 -1664.0718927383425,4.518594264984131,44939.26374220848,132282,0,44939.26374220848,0.5893000364303589,1.8553305864334104,10000,46612.58217954636,0.8037906289100647,0.8384227156639099,0.7140399813652039,1.224155306816101,50000 -1681.5357718467712,4.574284315109253,45449.338752985,133787,0,45449.338752985,0.5929000377655029,1.859167098999024,10000,47140.23160409928,0.8019172549247742,0.8607712388038635,0.715499997138977,1.2324215173721311,50000 -1698.6805226802826,4.625326871871948,45959.35618376732,135292,0,45959.35618376732,0.5944000482559204,1.8514662981033323,10000,47667.49844503403,0.8028938174247742,0.844454824924469,0.7201399803161621,1.20967698097229,50000 -1715.7776033878326,4.675408601760864,46469.46252202988,136797,0,46469.46252202988,0.5927000045776367,1.855873107910156,10000,48194.80530762672,0.8018972873687744,0.8656412959098816,0.7197399735450745,1.2183958292007446,50000 -1733.1091315746307,4.73111629486084,46979.67847561836,138303,0,46979.67847561836,0.6022000312805176,1.815556883811951,10000,48722.46280956268,0.8061822056770325,0.8341156244277954,0.7231799960136414,1.196844220161438,50000 -1750.324533700943,4.7820470333099365,47489.7852473259,139808,0,47489.7852473259,0.601900041103363,1.808569073677063,10000,49249.88934969902,0.8362762928009033,0.7129075527191162,0.7239599823951721,1.182400465011597,50000 -1767.8215091228485,4.833433866500855,47999.80777215958,141313,0,47999.80777215958,0.6055999994277954,1.8135945796966555,10000,49777.51341342926,0.8279455900192261,0.7402110695838928,0.727679967880249,1.1678305864334106,50000 -1784.957479953766,4.888729572296143,48510.02785205841,142818,0,48510.02785205841,0.6040000319480896,1.7834153175354004,10000,50304.98026776314,0.8234215378761292,0.757717490196228,0.729479968547821,1.166062593460083,50000 -1802.3611364364624,4.945554733276367,49020.1178920269,144323,0,49020.1178920269,0.6178000569343567,1.7489664554595947,10000,50832.58455181122,0.832429826259613,0.7265508770942688,0.7337599992752075,1.1390953063964844,50000 -1819.614022731781,4.998095750808716,49530.3389377594,145828,0,49530.3389377594,0.6071000099182129,1.7726635932922363,10000,51360.16606426239,0.8336455225944519,0.722388744354248,0.7360000014305115,1.13362717628479,50000 -1836.791166067124,5.049542188644409,50040.367628097534,147333,0,50040.367628097534,0.6159000396728516,1.7439031600952148,10000,51887.47683286667,0.8375119566917419,0.7116343975067139,0.7391600012779236,1.125153422355652,50000 -1854.272777080536,5.106694459915161,50550.32852196693,148837,0,50550.32852196693,0.615600049495697,1.7641743421554563,10000,52415.02962350845,0.8622449040412903,0.6133614182472229,0.7362200021743774,1.1291574239730835,50000 -1871.5532939434047,5.160580635070801,51060.45466089249,150342,0,51060.45466089249,0.6177000403404236,1.7354857921600342,10000,52942.5445561409,0.8563655614852905,0.637763261795044,0.7451800107955933,1.1021682024002075,50000 -1889.6330585479736,5.211879253387451,51570.480170726776,151846,0,51570.480170726776,0.6177000403404236,1.743034839630127,10000,53470.75521993637,0.8573222160339355,0.6447144746780396,0.7456799745559692,1.0978809595108032,50000 -1907.001557826996,5.265218734741211,52080.545378923416,153351,0,52080.545378923416,0.6243000030517578,1.7114746570587158,10000,53998.2950387001,0.8552295565605164,0.6356821060180664,0.7467799782752991,1.086501955986023,50000 -1924.1550514698029,5.317140579223633,52590.51016449928,154855,0,52590.51016449928,0.6264000535011292,1.6894590854644775,10000,54525.517634391785,0.8594945669174194,0.6123570203781128,0.7492199540138245,1.0768321752548218,50000 -1942.180507183075,5.372523307800293,53100.41500091553,156359,0,53100.41500091553,0.628600001335144,1.6836930513381958,10000,55053.556411504745,0.861348032951355,0.6097769141197205,0.7504799962043762,1.0702511072158811,50000 -1959.568165779113,5.417553663253784,53610.552248716354,157863,0,53610.552248716354,0.6276000142097473,1.6905221939086914,10000,55581.18113279343,0.8867984414100647,0.5364611744880676,0.7528600096702576,1.0770390033721924,50000 -1977.037132024765,5.476731061935425,54120.69530892372,159368,0,54120.69530892372,0.6304000020027161,1.6812326908111572,10000,56108.90738940239,0.8807397484779358,0.5454962849617004,0.7541199922561646,1.0671308040618896,50000 -1994.1933534145355,5.536860942840576,54630.79351758957,160872,0,54630.79351758957,0.6320000290870667,1.6694732904434204,10000,56636.27542257309,0.882254421710968,0.5351816415786743,0.7583799958229065,1.0473202466964722,50000 -2011.8303670883176,5.594317197799683,55140.68953895569,162376,0,55140.68953895569,0.6378000378608704,1.6772630214691162,10000,57163.9195151329,0.8803212642669678,0.5474730134010315,0.7545599937438965,1.0637322664260864,50000 -2029.086805820465,5.652773380279541,55650.60866093636,163880,0,55650.60866093636,0.6351000070571899,1.6635215282440186,10000,57691.20793008804,0.8883529901504517,0.5204653143882751,0.7597799897193909,1.0434074401855469,50000 -2046.2316403388977,5.71399450302124,56160.56981110573,165384,0,56160.56981110573,0.6367000341415405,1.6468669176101685,10000,58218.42898082733,0.8879145383834839,0.5141134858131409,0.7617999911308289,1.0365240573883057,50000 -2063.513118505478,5.7704408168792725,56670.46351933479,166888,0,56670.46351933479,0.6419000029563904,1.6486574411392212,10000,58745.71489930153,0.9036790132522584,0.4611811637878418,0.7634199857711792,1.030709147453308,50000 -2080.8864829540253,5.8460328578948975,57180.640429496765,168393,0,57180.640429496765,0.6416000127792358,1.6349159479141235,10000,59273.3946313858,0.9052734375,0.4601774215698242,0.7659199833869934,1.0219361782073977,50000 -2098.3980734348297,5.909879684448242,57690.83558368683,169898,0,57690.83558368683,0.6459000110626221,1.622611403465271,10000,59801.218878507614,0.9046555757522584,0.4588156342506408,0.7649399638175964,1.0206623077392578,50000 -2116.0283353328705,5.96552038192749,58201.03377509117,171403,0,58201.03377509117,0.6456000208854675,1.6306302547454834,10000,60329.157285928726,0.9075254797935486,0.4488637149333954,0.7670800089836121,1.009906530380249,50000 -2133.4003245830536,6.02742600440979,58711.01109552384,172907,0,58711.01109552384,0.6478000283241272,1.615654230117798,10000,60856.62333703041,0.910375475883484,0.4398549795150757,0.7702599763870239,1.003976345062256,50000 -2150.674258470536,6.087927341461182,59221.15896511078,174412,0,59221.15896511078,0.6509000062942505,1.6078853607177734,10000,61384.15951418877,0.9120894074440002,0.4342247247695923,0.7703799605369568,1.0012692213058472,50000 -2167.850204706192,6.143632411956787,59731.146606206894,175916,0,59731.146606206894,0.6498000025749207,1.614419937133789,10000,61911.43213367462,0.915796399116516,0.4226926863193512,0.7699999809265137,1.0016485452651978,50000 -2185.389719486237,6.202255487442017,60241.354083538055,177420,0,60241.354083538055,0.6514000296592712,1.6060065031051636,10000,62439.29220581055,0.9207788109779358,0.4045875966548919,0.7719599604606628,0.9974290132522584,50000 -2202.798728942871,6.266868829727173,60751.29431128502,178924,0,60751.29431128502,0.6512000560760498,1.60286545753479,10000,62966.76009917259,0.9169324040412904,0.405758649110794,0.7729399800300598,0.9932107329368592,50000 -2220.0592410564423,6.343884229660034,61261.49818825722,180429,0,61261.49818825722,0.651900053024292,1.6055456399917605,10000,63494.355257987976,0.9206991195678712,0.4018170833587646,0.7720800042152405,0.9955241084098816,50000 -2237.2278864383698,6.401177883148193,61771.65120244026,181934,0,61771.65120244026,0.653700053691864,1.6013190746307373,10000,64021.787281513214,0.91898512840271,0.4023058712482452,0.7721999883651733,0.9920402765274048,50000 -2254.500765562057,6.462839603424072,62281.68023562431,183438,0,62281.68023562431,0.6526000499725342,1.6008130311965942,10000,64549.20379114151,0.9219945669174194,0.3971076905727386,0.772819995880127,0.9916077852249146,50000 -2272.0287766456604,6.52376127243042,62791.72662734985,184942,0,62791.72662734985,0.653700053691864,1.5995631217956543,10000,65076.893508434296,0.9215561151504517,0.3981364667415619,0.7727400064468384,0.9910697937011719,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index 210e3ca8b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1982 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.5919652,6.933517,,,,,,,,,,,,,, -1,,,0.0010363520123064,6.91261625289917,0.0007599999662488,6.913174629211426,50000.0,0.0006000000284984,6.9125494956970215,10000.0,51.92995691299439,88.81795763969421,51.92995691299439,36.88790965080261,0.0,0.0 -100,0.58846265,6.8997173,,,,,,,,,,,,,, -200,0.5819929,6.867396,,,,,,,,,,,,,, -300,0.62214595,6.8104377,,,,,,,,,,,,,, -400,0.6753133,6.7137947,,,,,,,,,,,,,, -500,0.70964605,6.602459,,,,,,,,,,,,,, -600,0.75311965,6.518093,,,,,,,,,,,,,, -700,0.7649732,6.465991,,,,,,,,,,,,,, -800,0.88811225,6.3917313,,,,,,,,,,,,,, -900,1.9567851,6.2281647,,,,,,,,,,,,,, -1000,1.7885996,6.170498,,,,,,,,,,,,,, -1100,1.4164201,6.0284443,,,,,,,,,,,,,, -1200,1.6268216,6.0001955,,,,,,,,,,,,,, -1300,1.740271,5.934003,,,,,,,,,,,,,, -1400,1.7124648,5.8331723,,,,,,,,,,,,,, -1498,,,0.0717275142669677,5.377313137054443,0.0648999959230423,5.448593616485596,50000.0,0.0465000011026859,5.657770156860352,10000.0,562.0372688770294,617.1234951019287,562.0372688770294,55.00457406044006,0.0275373458862304,0.0 -1500,2.8617582,5.851534,,,,,,,,,,,,,, -1600,3.0732517,5.6887183,,,,,,,,,,,,,, -1700,2.3801003,5.639028,,,,,,,,,,,,,, -1800,2.451807,5.5824795,,,,,,,,,,,,,, -1900,2.5285578,5.58612,,,,,,,,,,,,,, -2000,2.3546085,5.5459986,,,,,,,,,,,,,, -2100,2.9188645,5.4895654,,,,,,,,,,,,,, -2200,3.7264977,5.410718,,,,,,,,,,,,,, -2300,4.1654186,5.3296857,,,,,,,,,,,,,, -2400,3.2242742,5.4102883,,,,,,,,,,,,,, -2500,5.3624635,5.292855,,,,,,,,,,,,,, -2600,2.7326562,5.201928,,,,,,,,,,,,,, -2700,6.9369793,5.240113,,,,,,,,,,,,,, -2800,5.3758755,5.130969,,,,,,,,,,,,,, -2900,4.3831916,5.0823016,,,,,,,,,,,,,, -2994,,,0.1758011728525161,4.277118682861328,0.1584599912166595,4.395031929016113,50000.0,0.1105000078678131,4.828097820281982,10000.0,1072.2711553573608,1145.850881814957,1072.2711553573608,73.41657638549805,0.0556209087371826,0.0 -3000,4.6778913,5.0031075,,,,,,,,,,,,,, -3100,6.679556,5.019977,,,,,,,,,,,,,, -3200,6.2255654,5.0138283,,,,,,,,,,,,,, -3300,5.7555237,4.9188294,,,,,,,,,,,,,, -3400,5.021463,4.90181,,,,,,,,,,,,,, -3500,4.1789675,4.856604,,,,,,,,,,,,,, -3600,5.741874,4.824374,,,,,,,,,,,,,, -3700,4.229619,4.7428923,,,,,,,,,,,,,, -3800,5.0221395,4.6450095,,,,,,,,,,,,,, -3900,4.8980637,4.6339326,,,,,,,,,,,,,, -4000,6.840312,4.669504,,,,,,,,,,,,,, -4100,5.3198895,4.6470633,,,,,,,,,,,,,, -4200,7.880004,4.608964,,,,,,,,,,,,,, -4300,3.9618962,4.4742117,,,,,,,,,,,,,, -4400,7.255221,4.460317,,,,,,,,,,,,,, -4489,,,0.2791972160339355,3.528715372085572,0.2509799897670746,3.671311378479004,50000.0,0.1837000101804733,4.219701766967773,10000.0,1582.2347013950348,1674.067317724228,1582.2347013950348,91.58906817436218,0.0812327861785888,0.0 -4500,5.7255607,4.44133,,,,,,,,,,,,,, -4600,5.435438,4.38627,,,,,,,,,,,,,, -4700,4.833036,4.3947687,,,,,,,,,,,,,, -4800,4.362995,4.349086,,,,,,,,,,,,,, -4900,5.903903,4.3029346,,,,,,,,,,,,,, -5000,4.7214813,4.168232,,,,,,,,,,,,,, -5100,7.8385115,4.2158337,,,,,,,,,,,,,, -5200,6.5320196,4.1757574,,,,,,,,,,,,,, -5300,6.41319,4.3246193,,,,,,,,,,,,,, -5400,5.210958,4.0922666,,,,,,,,,,,,,, -5500,5.6356683,4.15179,,,,,,,,,,,,,, -5600,6.9639173,4.1141176,,,,,,,,,,,,,, -5700,3.644831,4.011348,,,,,,,,,,,,,, -5800,7.577235,4.109173,,,,,,,,,,,,,, -5900,6.230823,4.142295,,,,,,,,,,,,,, -5985,,,0.3691804707050323,2.9538686275482178,0.3462799787521362,3.0891966819763184,50000.0,0.2619000077247619,3.715914726257324,10000.0,2092.3124701976776,2202.3470935821533,2092.3124701976776,109.70492148399352,0.1139328479766845,0.0 -6000,5.417344,4.015612,,,,,,,,,,,,,, -6100,5.07187,4.058652,,,,,,,,,,,,,, -6200,5.2558107,3.905042,,,,,,,,,,,,,, -6300,8.691951,3.9412653,,,,,,,,,,,,,, -6400,6.652378,3.9015734,,,,,,,,,,,,,, -6500,6.2290263,3.8110354,,,,,,,,,,,,,, -6600,8.657057,3.9495301,,,,,,,,,,,,,, -6700,4.9344335,3.7452872,,,,,,,,,,,,,, -6800,4.3928566,3.833798,,,,,,,,,,,,,, -6900,5.986685,3.6988919,,,,,,,,,,,,,, -7000,5.542161,3.6813993,,,,,,,,,,,,,, -7100,5.99728,3.738625,,,,,,,,,,,,,, -7200,6.1309447,3.804022,,,,,,,,,,,,,, -7300,3.6314871,3.6685443,,,,,,,,,,,,,, -7400,4.405891,3.6511035,,,,,,,,,,,,,, -7481,,,0.4307437837123871,2.613607168197632,0.3971000015735626,2.790703296661377,50000.0,0.2989000082015991,3.4827613830566406,10000.0,2602.396719932556,2730.6603696346283,2602.396719932556,127.85242891311646,0.1410117149353027,0.0 -7500,6.075757,3.6120074,,,,,,,,,,,,,, -7600,4.7045274,3.6343048,,,,,,,,,,,,,, -7700,6.462488,3.7985146,,,,,,,,,,,,,, -7800,5.88028,3.5949943,,,,,,,,,,,,,, -7900,7.0052447,3.7167616,,,,,,,,,,,,,, -8000,4.6509085,3.567554,,,,,,,,,,,,,, -8100,7.0404654,3.562647,,,,,,,,,,,,,, -8200,5.4887247,3.4963279,,,,,,,,,,,,,, -8300,5.1811843,3.600115,,,,,,,,,,,,,, -8400,6.4578996,3.4788055,,,,,,,,,,,,,, -8500,4.335598,3.4663649,,,,,,,,,,,,,, -8600,7.7358937,3.5744839,,,,,,,,,,,,,, -8700,6.2867975,3.521585,,,,,,,,,,,,,, -8800,4.4662127,3.475379,,,,,,,,,,,,,, -8900,5.3914833,3.465237,,,,,,,,,,,,,, -8978,,,0.492287129163742,2.275744676589966,0.4618600010871887,2.433438301086426,50000.0,0.357200026512146,3.100750684738159,10000.0,3112.5727632045746,3259.198537349701,3112.5727632045746,146.1283574104309,0.1749634742736816,0.0 -9000,6.0473585,3.4290743,,,,,,,,,,,,,, -9100,3.0692112,3.4291654,,,,,,,,,,,,,, -9200,6.119689,3.5042782,,,,,,,,,,,,,, -9300,5.946896,3.3898158,,,,,,,,,,,,,, -9400,8.53267,3.393443,,,,,,,,,,,,,, -9500,4.2618403,3.4348106,,,,,,,,,,,,,, -9600,4.802609,3.2982812,,,,,,,,,,,,,, -9700,6.187919,3.4674504,,,,,,,,,,,,,, -9800,4.2790594,3.3631303,,,,,,,,,,,,,, -9900,4.996441,3.3673425,,,,,,,,,,,,,, -10000,6.4027762,3.272825,,,,,,,,,,,,,, -10100,4.788958,3.5001767,,,,,,,,,,,,,, -10200,6.288,3.463393,,,,,,,,,,,,,, -10300,4.688681,3.261662,,,,,,,,,,,,,, -10400,3.9352496,3.323901,,,,,,,,,,,,,, -10475,,,0.5693957209587097,1.8849722146987915,0.5029999613761902,2.2190101146698,50000.0,0.3951000273227691,2.8704304695129395,10000.0,3622.6284663677216,3787.772131204605,3622.6284663677216,164.5659377574921,0.2022130489349365,0.0 -10500,4.3381205,3.2532887,,,,,,,,,,,,,, -10600,6.510359,3.1702914,,,,,,,,,,,,,, -10700,6.3861217,3.3298464,,,,,,,,,,,,,, -10800,4.219602,3.2786794,,,,,,,,,,,,,, -10900,3.9676368,3.3029706,,,,,,,,,,,,,, -11000,5.3176913,3.3002517,,,,,,,,,,,,,, -11100,6.9873405,3.1763747,,,,,,,,,,,,,, -11200,6.789555,3.199216,,,,,,,,,,,,,, -11300,6.940995,3.230531,,,,,,,,,,,,,, -11400,4.4795003,3.134665,,,,,,,,,,,,,, -11500,3.8769712,3.1810741,,,,,,,,,,,,,, -11600,4.529537,3.1175365,,,,,,,,,,,,,, -11700,5.6142497,3.197114,,,,,,,,,,,,,, -11800,6.106865,3.182045,,,,,,,,,,,,,, -11900,4.5536733,3.1733115,,,,,,,,,,,,,, -11974,,,0.5763313174247742,1.9071600437164309,0.5275999903678894,2.128957271575928,50000.0,0.4093000292778015,2.793846368789673,10000.0,4132.844066381455,4316.348064184189,4132.844066381455,182.8322350978852,0.2424328327178955,0.0 -12000,5.312479,3.2431056,,,,,,,,,,,,,, -12100,5.9104304,3.1286469,,,,,,,,,,,,,, -12200,4.610932,3.2323995,,,,,,,,,,,,,, -12300,4.5401454,3.1670477,,,,,,,,,,,,,, -12400,7.489271,3.0768745,,,,,,,,,,,,,, -12500,5.3529944,3.1421216,,,,,,,,,,,,,, -12600,4.7708526,3.1707525,,,,,,,,,,,,,, -12700,5.044092,3.0631993,,,,,,,,,,,,,, -12800,5.7982574,3.182397,,,,,,,,,,,,,, -12900,7.1598663,3.1738503,,,,,,,,,,,,,, -13000,4.546345,3.0553913,,,,,,,,,,,,,, -13100,6.084323,3.1337872,,,,,,,,,,,,,, -13200,6.332659,3.0568318,,,,,,,,,,,,,, -13300,5.885892,3.1348176,,,,,,,,,,,,,, -13400,5.264166,3.1330495,,,,,,,,,,,,,, -13473,,,0.5816326141357422,1.8359532356262207,0.535539984703064,2.061918258666992,50000.0,0.4231000244617462,2.7273426055908203,10000.0,4643.087154865265,4845.225155115128,4643.087154865265,201.3808000087738,0.2749552726745605,0.0 -13500,6.4775367,3.076951,,,,,,,,,,,,,, -13600,8.468356,3.0745857,,,,,,,,,,,,,, -13700,4.6893015,3.0862741,,,,,,,,,,,,,, -13800,4.309243,3.0836587,,,,,,,,,,,,,, -13900,5.265733,3.0579374,,,,,,,,,,,,,, -14000,5.137474,3.0738869,,,,,,,,,,,,,, -14100,5.2922506,3.050061,,,,,,,,,,,,,, -14200,5.5577345,3.1158915,,,,,,,,,,,,,, -14300,5.47991,3.0223157,,,,,,,,,,,,,, -14400,7.789869,3.1392238,,,,,,,,,,,,,, -14500,4.776466,3.1531177,,,,,,,,,,,,,, -14600,5.3034153,3.0571332,,,,,,,,,,,,,, -14700,5.566613,2.982052,,,,,,,,,,,,,, -14800,4.224511,2.9955559,,,,,,,,,,,,,, -14900,4.6879067,3.012995,,,,,,,,,,,,,, -14973,,,0.5989516973495483,1.7546666860580444,0.5537399649620056,1.964607000350952,50000.0,0.431300014257431,2.6510255336761475,10000.0,5153.3032858371735,5374.694349765778,5153.3032858371735,220.55117321014404,0.304619550704956,0.0 -15000,5.0417757,3.1142657,,,,,,,,,,,,,, -15100,4.69557,3.02176,,,,,,,,,,,,,, -15200,4.4328427,3.0076063,,,,,,,,,,,,,, -15300,5.6561303,3.0606456,,,,,,,,,,,,,, -15400,5.794324,2.9635186,,,,,,,,,,,,,, -15500,4.302029,3.0559604,,,,,,,,,,,,,, -15600,5.1702924,3.0175223,,,,,,,,,,,,,, -15700,6.0849223,3.0207996,,,,,,,,,,,,,, -15800,3.5355363,3.0397253,,,,,,,,,,,,,, -15900,4.3230133,2.9990268,,,,,,,,,,,,,, -16000,6.8497915,3.0577147,,,,,,,,,,,,,, -16100,4.701148,3.034987,,,,,,,,,,,,,, -16200,5.710892,3.0651011,,,,,,,,,,,,,, -16300,6.111589,3.0616674,,,,,,,,,,,,,, -16400,5.79092,3.0826478,,,,,,,,,,,,,, -16473,,,0.59375,1.8033223152160645,0.5557399988174438,1.9902366399765008,50000.0,0.4360000193119049,2.6539673805236816,10000.0,5663.362612962723,5905.069729804993,5663.362612962723,240.77084040641785,0.3464560508728027,0.0 -16500,4.752538,2.9943976,,,,,,,,,,,,,, -16600,5.8301234,3.025496,,,,,,,,,,,,,, -16700,4.454416,2.9928432,,,,,,,,,,,,,, -16800,4.6859345,3.0117228,,,,,,,,,,,,,, -16900,4.5001917,2.9829059,,,,,,,,,,,,,, -17000,5.1401467,3.06726,,,,,,,,,,,,,, -17100,4.402298,3.0495782,,,,,,,,,,,,,, -17200,5.19754,2.9541073,,,,,,,,,,,,,, -17300,3.1203263,2.9368742,,,,,,,,,,,,,, -17400,4.204293,3.1200984,,,,,,,,,,,,,, -17500,5.7832394,2.9085088,,,,,,,,,,,,,, -17600,6.3360796,2.9443612,,,,,,,,,,,,,, -17700,4.589482,2.9321225,,,,,,,,,,,,,, -17800,4.504002,2.9891605,,,,,,,,,,,,,, -17900,5.5380507,2.9387746,,,,,,,,,,,,,, -17974,,,0.6002271771430969,1.7657809257507324,0.5660600066184998,1.9302324056625368,50000.0,0.4388000071048736,2.6314282417297363,10000.0,6173.529769182205,6436.721809148788,6173.529769182205,262.16706109046936,0.3816823959350586,0.0 -18000,7.0638905,2.9806225,,,,,,,,,,,,,, -18100,5.424322,3.100714,,,,,,,,,,,,,, -18200,5.1539006,3.0119023,,,,,,,,,,,,,, -18300,3.7761793,3.0183935,,,,,,,,,,,,,, -18400,3.3085425,2.9478526,,,,,,,,,,,,,, -18500,3.4427264,2.979712,,,,,,,,,,,,,, -18600,3.8728876,2.9049668,,,,,,,,,,,,,, -18700,3.753921,3.1367602,,,,,,,,,,,,,, -18800,4.992184,2.9422312,,,,,,,,,,,,,, -18900,3.3111968,2.9560883,,,,,,,,,,,,,, -19000,3.958573,2.9567692,,,,,,,,,,,,,, -19100,4.8238006,2.980788,,,,,,,,,,,,,, -19200,3.5978382,2.9721797,,,,,,,,,,,,,, -19300,4.3549156,3.0332525,,,,,,,,,,,,,, -19400,4.122372,3.0263317,,,,,,,,,,,,,, -19474,,,0.6611328125,1.485448122024536,0.5753799676895142,1.8646032810211184,50000.0,0.4509000182151794,2.5466179847717285,10000.0,6683.501819372177,6971.627715110779,6683.501819372177,287.0052680969238,0.423555850982666,0.0 -19500,4.7598343,3.083894,,,,,,,,,,,,,, -19600,3.5009403,3.0684376,,,,,,,,,,,,,, -19700,4.4091606,2.9272919,,,,,,,,,,,,,, -19800,4.29672,2.877136,,,,,,,,,,,,,, -19900,3.3212237,2.97336,,,,,,,,,,,,,, -20000,3.5463514,2.8683786,,,,,,,,,,,,,, -20100,3.7615023,2.9198956,,,,,,,,,,,,,, -20200,5.781681,3.0437608,,,,,,,,,,,,,, -20300,3.9075217,2.9061992,,,,,,,,,,,,,, -20400,3.4821157,2.9209602,,,,,,,,,,,,,, -20500,4.1165566,2.9976969,,,,,,,,,,,,,, -20600,3.8041973,2.8928719,,,,,,,,,,,,,, -20700,3.5172195,2.8678784,,,,,,,,,,,,,, -20800,3.034744,2.9054248,,,,,,,,,,,,,, -20900,3.59803,2.834648,,,,,,,,,,,,,, -20975,,,0.6342673897743225,1.583891749382019,0.5761199593544006,1.8759549856185915,50000.0,0.4554000198841095,2.530134439468384,10000.0,7193.579409122467,7504.646897554398,7193.579409122467,309.8524570465088,0.4640743732452392,0.0 -21000,2.7481716,3.0300546,,,,,,,,,,,,,, -21100,3.7322862,2.9158094,,,,,,,,,,,,,, -21200,4.794798,2.961384,,,,,,,,,,,,,, -21300,2.955123,2.9578786,,,,,,,,,,,,,, -21400,3.855058,2.9765677,,,,,,,,,,,,,, -21500,3.8367438,2.9930346,,,,,,,,,,,,,, -21600,2.9353077,2.9546971,,,,,,,,,,,,,, -21700,4.3118405,2.896258,,,,,,,,,,,,,, -21800,3.6963475,2.89742,,,,,,,,,,,,,, -21900,3.2234588,2.9078484,,,,,,,,,,,,,, -22000,3.2045503,2.8599458,,,,,,,,,,,,,, -22100,4.2618027,2.9638796,,,,,,,,,,,,,, -22200,2.8329754,2.9392817,,,,,,,,,,,,,, -22300,3.910834,2.8975043,,,,,,,,,,,,,, -22400,3.53623,2.9626408,,,,,,,,,,,,,, -22476,,,0.6265146732330322,1.6444129943847656,0.5791199803352356,1.8640754222869875,50000.0,0.4526000320911407,2.587834596633911,10000.0,7703.516690969467,8035.961810588837,7703.516690969467,331.14792466163635,0.4928333759307861,0.0 -22500,3.1034274,2.7991543,,,,,,,,,,,,,, -22600,4.384188,2.9413178,,,,,,,,,,,,,, -22700,3.0463123,2.9986668,,,,,,,,,,,,,, -22800,3.0338848,2.925931,,,,,,,,,,,,,, -22900,3.8024418,2.819707,,,,,,,,,,,,,, -23000,2.727463,2.8906834,,,,,,,,,,,,,, -23100,2.6499574,2.8596425,,,,,,,,,,,,,, -23200,2.9500978,2.9087617,,,,,,,,,,,,,, -23300,3.6910872,2.855183,,,,,,,,,,,,,, -23400,3.1852124,2.9217572,,,,,,,,,,,,,, -23500,2.624263,2.8786047,,,,,,,,,,,,,, -23600,3.4833834,2.9779081,,,,,,,,,,,,,, -23700,3.292454,2.9514217,,,,,,,,,,,,,, -23800,3.6460183,2.9324787,,,,,,,,,,,,,, -23900,4.014514,2.9070742,,,,,,,,,,,,,, -23978,,,0.6319355964660645,1.6008868217468262,0.582260012626648,1.830167293548584,50000.0,0.4561000168323517,2.549467325210572,10000.0,8213.743519306183,8569.167505025864,8213.743519306183,354.0397229194641,0.5262980461120605,0.0 -24000,3.3361943,2.94651,,,,,,,,,,,,,, -24100,3.0140839,2.90743,,,,,,,,,,,,,, -24200,3.4321797,2.8671346,,,,,,,,,,,,,, -24300,3.2554111,2.9603977,,,,,,,,,,,,,, -24400,3.258323,2.936456,,,,,,,,,,,,,, -24500,3.3607037,2.9693058,,,,,,,,,,,,,, -24600,3.095902,2.9045966,,,,,,,,,,,,,, -24700,2.7245946,2.8113098,,,,,,,,,,,,,, -24800,3.3361347,2.8608108,,,,,,,,,,,,,, -24900,2.8661358,2.871297,,,,,,,,,,,,,, -25000,3.3591344,2.902032,,,,,,,,,,,,,, -25100,2.7100132,2.9319894,,,,,,,,,,,,,, -25200,3.4249537,2.9757617,,,,,,,,,,,,,, -25300,3.3187494,2.9344134,,,,,,,,,,,,,, -25400,4.4046006,2.930528,,,,,,,,,,,,,, -25480,,,0.6276904940605164,1.620834469795227,0.586080014705658,1.8189243078231807,50000.0,0.4687000215053558,2.476848602294922,10000.0,8723.919181346893,9102.295284986496,8723.919181346893,376.9086351394653,0.5557169914245605,0.0 -25500,2.5866463,2.929401,,,,,,,,,,,,,, -25600,3.2085354,2.8617752,,,,,,,,,,,,,, -25700,2.8288586,2.891939,,,,,,,,,,,,,, -25800,4.9063025,2.849119,,,,,,,,,,,,,, -25900,3.657254,2.8549166,,,,,,,,,,,,,, -26000,3.5907557,2.9899032,,,,,,,,,,,,,, -26100,3.2685921,2.8105712,,,,,,,,,,,,,, -26200,3.318164,2.8739913,,,,,,,,,,,,,, -26300,3.8666515,2.9284496,,,,,,,,,,,,,, -26400,4.252199,2.832787,,,,,,,,,,,,,, -26500,3.4164596,2.9144835,,,,,,,,,,,,,, -26600,3.1875777,2.8368456,,,,,,,,,,,,,, -26700,2.9147415,2.8907743,,,,,,,,,,,,,, -26800,3.8679512,2.8486192,,,,,,,,,,,,,, -26900,3.2322521,2.7034433,,,,,,,,,,,,,, -26982,,,0.6330516338348389,1.5673516988754272,0.5966199636459351,1.7551690340042114,50000.0,0.469400018453598,2.4492757320404053,10000.0,9234.135885238647,9633.838790893557,9234.135885238647,398.1472208499909,0.58970046043396,0.0 -27000,3.0001423,2.9029503,,,,,,,,,,,,,, -27100,3.1757286,2.8102798,,,,,,,,,,,,,, -27200,3.471916,2.904969,,,,,,,,,,,,,, -27300,2.981191,2.8899312,,,,,,,,,,,,,, -27400,3.025294,2.8347464,,,,,,,,,,,,,, -27500,2.9213262,2.8096662,,,,,,,,,,,,,, -27600,2.5656962,2.8470125,,,,,,,,,,,,,, -27700,3.6790807,2.8054743,,,,,,,,,,,,,, -27800,3.353541,2.9180768,,,,,,,,,,,,,, -27900,3.3662229,2.894955,,,,,,,,,,,,,, -28000,3.574006,2.890126,,,,,,,,,,,,,, -28100,4.4368033,2.85317,,,,,,,,,,,,,, -28200,3.14515,2.8850594,,,,,,,,,,,,,, -28300,3.1399386,2.8513143,,,,,,,,,,,,,, -28400,2.984831,2.803988,,,,,,,,,,,,,, -28481,,,0.6598572731018066,1.473987102508545,0.5931400060653687,1.779598593711853,50000.0,0.4660000205039978,2.4653191566467285,10000.0,9743.258620500565,10166.62963104248,9743.258620500565,420.8384771347046,1.513392210006714,0.0 -28500,3.040568,2.7238312,,,,,,,,,,,,,, -28600,3.2620103,2.913906,,,,,,,,,,,,,, -28700,2.7362723,2.9004247,,,,,,,,,,,,,, -28800,3.3342702,2.7840059,,,,,,,,,,,,,, -28900,2.86144,2.8551598,,,,,,,,,,,,,, -29000,2.9259932,2.7576172,,,,,,,,,,,,,, -29100,2.824231,2.859121,,,,,,,,,,,,,, -29200,3.076077,2.8271236,,,,,,,,,,,,,, -29300,3.107486,2.9115424,,,,,,,,,,,,,, -29400,3.0471942,2.9092598,,,,,,,,,,,,,, -29500,3.5170815,2.8762245,,,,,,,,,,,,,, -29600,3.0546212,2.8918426,,,,,,,,,,,,,, -29700,3.0002217,2.772924,,,,,,,,,,,,,, -29800,2.9165487,2.828837,,,,,,,,,,,,,, -29900,3.0496707,2.9017553,,,,,,,,,,,,,, -29983,,,0.6408442258834839,1.5831868648529053,0.5834800004959106,1.8493436574935915,50000.0,0.4546000361442566,2.542966604232788,10000.0,10253.200542211533,10698.343393564224,10253.200542211533,442.5229060649872,1.5470540523529053,0.0 -30000,3.0958092,2.913754,,,,,,,,,,,,,, -30100,2.8589,2.838427,,,,,,,,,,,,,, -30200,3.2966666,2.9053884,,,,,,,,,,,,,, -30300,2.5949426,2.8583767,,,,,,,,,,,,,, -30400,2.593708,2.7820733,,,,,,,,,,,,,, -30500,3.3114033,2.8387194,,,,,,,,,,,,,, -30600,2.8326333,2.8183975,,,,,,,,,,,,,, -30700,2.9349985,2.8042178,,,,,,,,,,,,,, -30800,3.9718914,2.8300178,,,,,,,,,,,,,, -30900,3.1885881,2.8784041,,,,,,,,,,,,,, -31000,3.031079,2.8605442,,,,,,,,,,,,,, -31100,3.2098203,2.8172722,,,,,,,,,,,,,, -31200,2.8656228,2.8085678,,,,,,,,,,,,,, -31300,3.3122296,2.803444,,,,,,,,,,,,,, -31400,2.7885218,2.8547716,,,,,,,,,,,,,, -31484,,,0.650390625,1.5332188606262207,0.6009599566459656,1.777511715888977,50000.0,0.4751000106334686,2.450975179672241,10000.0,10763.20768213272,11231.830409526823,10763.20768213272,465.9139356613159,1.5822596549987793,0.0 -31500,2.8339329,2.888496,,,,,,,,,,,,,, -31600,2.6139412,2.701789,,,,,,,,,,,,,, -31700,2.868307,2.7814949,,,,,,,,,,,,,, -31800,2.6098168,2.839581,,,,,,,,,,,,,, -31900,3.1383653,2.8108664,,,,,,,,,,,,,, -32000,2.602061,2.7450368,,,,,,,,,,,,,, -32100,2.6150057,2.847413,,,,,,,,,,,,,, -32200,3.1289704,2.8485758,,,,,,,,,,,,,, -32300,3.4877708,2.7046287,,,,,,,,,,,,,, -32400,2.8535118,2.8483586,,,,,,,,,,,,,, -32500,3.0199847,2.7523584,,,,,,,,,,,,,, -32600,2.8395774,2.8320122,,,,,,,,,,,,,, -32700,2.847025,2.8828483,,,,,,,,,,,,,, -32800,3.0507982,2.8363888,,,,,,,,,,,,,, -32900,3.1638627,2.747665,,,,,,,,,,,,,, -32987,,,0.6478196382522583,1.545523166656494,0.5983999967575073,1.7723729610443115,50000.0,0.4719000160694122,2.464462995529175,10000.0,11273.447633743286,11764.022417783735,11273.447633743286,487.7795009613037,1.6160292625427246,0.0 -33000,3.6491218,2.7894375,,,,,,,,,,,,,, -33100,3.2146282,2.8630345,,,,,,,,,,,,,, -33200,2.8144717,2.820672,,,,,,,,,,,,,, -33300,2.525388,2.7905095,,,,,,,,,,,,,, -33400,3.1484942,2.8769505,,,,,,,,,,,,,, -33500,3.1201773,2.8510957,,,,,,,,,,,,,, -33600,3.5985577,2.9781756,,,,,,,,,,,,,, -33700,3.1794949,2.7038746,,,,,,,,,,,,,, -33800,4.31697,2.8433156,,,,,,,,,,,,,, -33900,3.1504774,2.683724,,,,,,,,,,,,,, -34000,2.9391387,2.8589058,,,,,,,,,,,,,, -34100,3.3855011,2.8047624,,,,,,,,,,,,,, -34200,2.7565606,2.7049487,,,,,,,,,,,,,, -34300,2.7895126,2.859203,,,,,,,,,,,,,, -34400,4.430565,2.754624,,,,,,,,,,,,,, -34490,,,0.6456871628761292,1.5215743780136108,0.6019399762153625,1.7319201231002808,50000.0,0.4844000339508056,2.4039146900177,10000.0,11783.682542085648,12295.01685166359,11783.682542085648,508.4510877132416,1.65091872215271,0.0 -34500,2.9461844,2.7351413,,,,,,,,,,,,,, -34600,2.7681992,2.7300003,,,,,,,,,,,,,, -34700,3.549235,2.7950094,,,,,,,,,,,,,, -34800,3.581368,2.9430676,,,,,,,,,,,,,, -34900,3.116576,2.8090253,,,,,,,,,,,,,, -35000,3.2440329,2.8109827,,,,,,,,,,,,,, -35100,2.8866792,2.8140779,,,,,,,,,,,,,, -35200,3.3744664,2.8013632,,,,,,,,,,,,,, -35300,3.4379938,2.791251,,,,,,,,,,,,,, -35400,3.2877638,2.7733116,,,,,,,,,,,,,, -35500,2.806984,2.7899873,,,,,,,,,,,,,, -35600,3.0731215,2.8645651,,,,,,,,,,,,,, -35700,3.6704624,2.733434,,,,,,,,,,,,,, -35800,3.3420217,2.794395,,,,,,,,,,,,,, -35900,3.593412,2.7875333,,,,,,,,,,,,,, -35993,,,0.6405652165412903,1.5603916645050049,0.5936599969863892,1.7699402570724487,50000.0,0.4715000092983246,2.450953722000122,10000.0,12293.693517684937,12827.367676734924,12293.693517684937,530.7041218280792,1.6843512058258057,0.0 -36000,3.2017612,2.7518132,,,,,,,,,,,,,, -36100,3.0059788,2.8682647,,,,,,,,,,,,,, -36200,3.4584446,2.7876887,,,,,,,,,,,,,, -36300,3.057537,2.871449,,,,,,,,,,,,,, -36400,3.4325805,2.8488166,,,,,,,,,,,,,, -36500,3.23723,2.8198588,,,,,,,,,,,,,, -36600,3.1337152,2.8502164,,,,,,,,,,,,,, -36700,3.9065833,2.7315273,,,,,,,,,,,,,, -36800,3.1922672,2.7538552,,,,,,,,,,,,,, -36900,3.1554825,2.7827692,,,,,,,,,,,,,, -37000,3.1075373,2.8886182,,,,,,,,,,,,,, -37100,3.1688488,2.8442235,,,,,,,,,,,,,, -37200,2.8717077,2.7857666,,,,,,,,,,,,,, -37300,3.439596,2.7808437,,,,,,,,,,,,,, -37400,3.2894554,2.7862816,,,,,,,,,,,,,, -37495,,,0.6504703164100647,1.510545015335083,0.608460009098053,1.7156388759613037,50000.0,0.4916000366210937,2.3600785732269287,10000.0,12803.7000541687,13358.99266719818,12803.7000541687,552.2334206104279,1.7202045917510986,0.0 -37500,3.069196,2.892313,,,,,,,,,,,,,, -37600,3.1282346,2.7432783,,,,,,,,,,,,,, -37700,2.936448,2.8215773,,,,,,,,,,,,,, -37800,2.8408868,2.799479,,,,,,,,,,,,,, -37900,2.7796597,2.8756793,,,,,,,,,,,,,, -38000,3.19667,2.8617296,,,,,,,,,,,,,, -38100,2.8235793,2.8180285,,,,,,,,,,,,,, -38200,2.8778095,2.799939,,,,,,,,,,,,,, -38300,3.5233133,2.8878555,,,,,,,,,,,,,, -38400,3.1333833,2.7741818,,,,,,,,,,,,,, -38500,3.234831,2.847465,,,,,,,,,,,,,, -38600,2.8379128,2.7828932,,,,,,,,,,,,,, -38700,2.9216292,2.741209,,,,,,,,,,,,,, -38800,2.901669,2.7559597,,,,,,,,,,,,,, -38900,2.782698,2.77602,,,,,,,,,,,,,, -38998,,,0.6715760231018066,1.4409266710281372,0.6063599586486816,1.7327011823654177,50000.0,0.4841000139713287,2.391342639923096,10000.0,13313.690217733383,13889.194969892502,13313.690217733383,572.3573670387268,1.755732774734497,0.0 -39000,2.8516412,2.6812646,,,,,,,,,,,,,, -39100,3.8599832,2.7598119,,,,,,,,,,,,,, -39200,3.2336366,2.7756116,,,,,,,,,,,,,, -39300,3.4453056,2.7773807,,,,,,,,,,,,,, -39400,3.003864,2.8544545,,,,,,,,,,,,,, -39500,2.8659432,2.7794845,,,,,,,,,,,,,, -39600,3.1560187,2.7058706,,,,,,,,,,,,,, -39700,3.549739,2.8016522,,,,,,,,,,,,,, -39800,4.207063,2.8258576,,,,,,,,,,,,,, -39900,3.15072,2.7913008,,,,,,,,,,,,,, -40000,3.6490169,2.8772368,,,,,,,,,,,,,, -40100,3.4243824,2.748292,,,,,,,,,,,,,, -40200,2.6626217,2.7162735,,,,,,,,,,,,,, -40300,3.2014725,2.81857,,,,,,,,,,,,,, -40400,3.4188354,2.8148613,,,,,,,,,,,,,, -40500,,,0.6689851880073547,1.414209485054016,0.6127600073814392,1.6750468015670776,50000.0,0.488500028848648,2.359315633773804,10000.0,13823.646218061447,14417.326255083084,13823.646218061447,590.4436695575714,1.7907118797302246,0.0 -40500,2.7633526,2.7700567,,,,,,,,,,,,,, -40600,3.081237,2.6913476,,,,,,,,,,,,,, -40700,2.8817174,2.7743318,,,,,,,,,,,,,, -40800,2.7537916,2.8267205,,,,,,,,,,,,,, -40900,3.0998356,2.788089,,,,,,,,,,,,,, -41000,2.9616096,2.703278,,,,,,,,,,,,,, -41100,3.4837465,2.8434389,,,,,,,,,,,,,, -41200,2.7177277,2.6936731,,,,,,,,,,,,,, -41300,3.4580736,2.7510595,,,,,,,,,,,,,, -41400,3.2685275,2.7637389,,,,,,,,,,,,,, -41500,3.4942873,2.808578,,,,,,,,,,,,,, -41600,2.9059145,2.7026896,,,,,,,,,,,,,, -41700,3.0253494,2.7742298,,,,,,,,,,,,,, -41800,3.4275892,2.7532735,,,,,,,,,,,,,, -41900,3.2765498,2.795001,,,,,,,,,,,,,, -42000,2.9647202,2.6765146,,,,,,,,,,,,,, -42003,,,0.661531388759613,1.4491329193115234,0.6156399846076965,1.6673866510391235,50000.0,0.4949000179767608,2.352324247360229,10000.0,14333.642135620115,14947.89179635048,14333.642135620115,610.9255712032318,1.8257253170013428,0.0 -42100,2.999505,2.8161511,,,,,,,,,,,,,, -42200,3.109534,2.7971523,,,,,,,,,,,,,, -42300,2.8637557,2.775804,,,,,,,,,,,,,, -42400,4.2743263,2.7637393,,,,,,,,,,,,,, -42500,3.0893178,2.7878075,,,,,,,,,,,,,, -42600,3.4377682,2.7183044,,,,,,,,,,,,,, -42700,3.9703395,2.8188708,,,,,,,,,,,,,, -42800,3.225838,2.697426,,,,,,,,,,,,,, -42900,3.1353955,2.6897478,,,,,,,,,,,,,, -43000,3.4197454,2.7029474,,,,,,,,,,,,,, -43100,3.0190203,2.7114978,,,,,,,,,,,,,, -43200,2.8725853,2.660944,,,,,,,,,,,,,, -43300,2.8750186,2.7413135,,,,,,,,,,,,,, -43400,3.8060741,2.8362122,,,,,,,,,,,,,, -43500,3.1869369,2.8561869,,,,,,,,,,,,,, -43506,,,0.6548947691917419,1.4952327013015747,0.6090599894523621,1.7058863639831543,50000.0,0.4839000105857849,2.3706626892089844,10000.0,14843.627077817917,15476.748891830444,14843.627077817917,629.7090165615082,1.8625688552856443,0.0 -43600,3.1817014,2.7716815,,,,,,,,,,,,,, -43700,3.1634936,2.761542,,,,,,,,,,,,,, -43800,2.8244913,2.706292,,,,,,,,,,,,,, -43900,3.61629,2.819595,,,,,,,,,,,,,, -44000,3.4333453,2.841167,,,,,,,,,,,,,, -44100,2.970283,2.7755117,,,,,,,,,,,,,, -44200,3.1836247,2.706784,,,,,,,,,,,,,, -44300,2.826328,2.6795416,,,,,,,,,,,,,, -44400,3.1193233,2.771802,,,,,,,,,,,,,, -44500,2.649385,2.7625484,,,,,,,,,,,,,, -44600,3.0273356,2.7245793,,,,,,,,,,,,,, -44700,2.890752,2.7271066,,,,,,,,,,,,,, -44800,3.0657456,2.6826222,,,,,,,,,,,,,, -44900,3.144546,2.7548316,,,,,,,,,,,,,, -45000,3.346162,2.7741609,,,,,,,,,,,,,, -45009,,,0.6541972160339355,1.492714524269104,0.6135199666023254,1.686639666557312,50000.0,0.4889000356197357,2.364584922790528,10000.0,15353.6026597023,16005.477684020996,15353.6026597023,648.3707220554352,1.9011075496673584,0.0 -45100,2.9908278,2.8576617,,,,,,,,,,,,,, -45200,3.8594298,2.807421,,,,,,,,,,,,,, -45300,2.9004629,2.7481077,,,,,,,,,,,,,, -45400,3.5129335,2.715786,,,,,,,,,,,,,, -45500,3.0366693,2.750863,,,,,,,,,,,,,, -45600,3.2729573,2.7604444,,,,,,,,,,,,,, -45700,2.7268581,2.7285156,,,,,,,,,,,,,, -45800,3.0897682,2.8295865,,,,,,,,,,,,,, -45900,3.350059,2.8302937,,,,,,,,,,,,,, -46000,3.2982159,2.7639613,,,,,,,,,,,,,, -46100,3.6659129,2.7273426,,,,,,,,,,,,,, -46200,2.954255,2.7861853,,,,,,,,,,,,,, -46300,3.5217562,2.6932461,,,,,,,,,,,,,, -46400,3.1485124,2.7119074,,,,,,,,,,,,,, -46500,3.237253,2.7329082,,,,,,,,,,,,,, -46513,,,0.6575055718421936,1.4960453510284424,0.6189199686050415,1.6851189136505127,50000.0,0.4930000305175781,2.3379759788513184,10000.0,15863.77936911583,16533.809260606766,15863.77936911583,666.4376637935638,1.9357657432556152,0.0 -46600,2.9840355,2.766056,,,,,,,,,,,,,, -46700,2.8465297,2.7579532,,,,,,,,,,,,,, -46800,3.3376455,2.7598364,,,,,,,,,,,,,, -46900,3.4926748,2.7522397,,,,,,,,,,,,,, -47000,3.2016644,2.770838,,,,,,,,,,,,,, -47100,4.3568573,2.7071285,,,,,,,,,,,,,, -47200,3.527173,2.7292676,,,,,,,,,,,,,, -47300,3.9949183,2.6421564,,,,,,,,,,,,,, -47400,3.2889874,2.6368985,,,,,,,,,,,,,, -47500,3.4778502,2.7587132,,,,,,,,,,,,,, -47600,3.521319,2.764423,,,,,,,,,,,,,, -47700,2.7948115,2.7655084,,,,,,,,,,,,,, -47800,2.910789,2.6743429,,,,,,,,,,,,,, -47900,2.8036747,2.7199712,,,,,,,,,,,,,, -48000,3.079715,2.7591586,,,,,,,,,,,,,, -48017,,,0.6881377696990967,1.3390893936157229,0.6141799688339233,1.6819976568222046,50000.0,0.498600035905838,2.3249881267547607,10000.0,16373.80011177063,17062.553616285324,16373.80011177063,685.0665671825409,1.9751687049865725,0.0 -48100,3.2062283,2.7364068,,,,,,,,,,,,,, -48200,2.790268,2.6900475,,,,,,,,,,,,,, -48300,3.5872443,2.8599195,,,,,,,,,,,,,, -48400,2.831749,2.7055643,,,,,,,,,,,,,, -48500,3.4108837,2.71851,,,,,,,,,,,,,, -48600,2.9014044,2.7518811,,,,,,,,,,,,,, -48700,3.1351855,2.7115784,,,,,,,,,,,,,, -48800,2.9175606,2.7332604,,,,,,,,,,,,,, -48900,3.0387557,2.7406635,,,,,,,,,,,,,, -49000,2.8768098,2.7005148,,,,,,,,,,,,,, -49100,3.3667436,2.7690547,,,,,,,,,,,,,, -49200,3.026835,2.7164595,,,,,,,,,,,,,, -49300,3.507746,2.7408223,,,,,,,,,,,,,, -49400,3.8517306,2.6661928,,,,,,,,,,,,,, -49500,3.115149,2.6275933,,,,,,,,,,,,,, -49520,,,0.6678889989852905,1.4150813817977903,0.6149399876594543,1.6707707643508911,50000.0,0.4930000305175781,2.3653693199157715,10000.0,16883.71029162407,17590.99581003189,16883.71029162407,703.5050427913666,2.014011144638061,0.0 -49600,3.308122,2.7167482,,,,,,,,,,,,,, -49700,3.5558772,2.7102745,,,,,,,,,,,,,, -49800,2.7382643,2.6949966,,,,,,,,,,,,,, -49900,3.076148,2.8184705,,,,,,,,,,,,,, -50000,3.4216232,2.7779334,,,,,,,,,,,,,, -50100,3.2239792,2.7709498,,,,,,,,,,,,,, -50200,3.165511,2.7628865,,,,,,,,,,,,,, -50300,3.9116926,2.6236541,,,,,,,,,,,,,, -50400,2.9309707,2.6926198,,,,,,,,,,,,,, -50500,2.950565,2.7305596,,,,,,,,,,,,,, -50600,3.0188487,2.7297084,,,,,,,,,,,,,, -50700,2.8647678,2.681109,,,,,,,,,,,,,, -50800,3.3731148,2.8152752,,,,,,,,,,,,,, -50900,3.0155752,2.7530065,,,,,,,,,,,,,, -51000,3.2945027,2.7442133,,,,,,,,,,,,,, -51024,,,0.6670718789100647,1.4345439672470093,0.6176199913024902,1.679994821548462,50000.0,0.5006999969482422,2.3009283542633057,10000.0,17393.85679912567,18119.246876478195,17393.85679912567,721.513111114502,2.0556466579437256,0.0 -51100,2.955706,2.7519774,,,,,,,,,,,,,, -51200,2.9579222,2.6751132,,,,,,,,,,,,,, -51300,3.16447,2.7524831,,,,,,,,,,,,,, -51400,4.2876425,2.7056239,,,,,,,,,,,,,, -51500,3.3731146,2.6884615,,,,,,,,,,,,,, -51600,2.9816034,2.6594093,,,,,,,,,,,,,, -51700,2.9718664,2.7536123,,,,,,,,,,,,,, -51800,3.0565212,2.7677045,,,,,,,,,,,,,, -51900,3.3422086,2.7750611,,,,,,,,,,,,,, -52000,2.9435916,2.661878,,,,,,,,,,,,,, -52100,3.5019395,2.695129,,,,,,,,,,,,,, -52200,3.3092997,2.7502131,,,,,,,,,,,,,, -52300,2.7333875,2.751184,,,,,,,,,,,,,, -52400,2.8515828,2.6533258,,,,,,,,,,,,,, -52500,3.25562,2.66167,,,,,,,,,,,,,, -52528,,,0.6692044138908386,1.4314494132995603,0.6190999746322632,1.6480151414871216,50000.0,0.4999000132083893,2.28476881980896,10000.0,17904.098637342453,18648.099063634872,17904.098637342453,740.029061794281,2.0966885089874268,0.0 -52600,3.4586844,2.7777088,,,,,,,,,,,,,, -52700,3.205151,2.7506835,,,,,,,,,,,,,, -52800,2.6842227,2.7145736,,,,,,,,,,,,,, -52900,3.4519508,2.832682,,,,,,,,,,,,,, -53000,3.2035596,2.8047705,,,,,,,,,,,,,, -53100,3.4527352,2.6851234,,,,,,,,,,,,,, -53200,3.2514718,2.8008704,,,,,,,,,,,,,, -53300,3.2149937,2.699142,,,,,,,,,,,,,, -53400,3.2041903,2.7495482,,,,,,,,,,,,,, -53500,3.0964177,2.7319024,,,,,,,,,,,,,, -53600,2.9677098,2.7525897,,,,,,,,,,,,,, -53700,3.2666464,2.7711797,,,,,,,,,,,,,, -53800,3.0557704,2.7121155,,,,,,,,,,,,,, -53900,2.857573,2.7426307,,,,,,,,,,,,,, -54000,3.756385,2.7081673,,,,,,,,,,,,,, -54032,,,0.6640425324440002,1.432912826538086,0.620199978351593,1.6460167169570925,50000.0,0.4923000335693359,2.3091156482696533,10000.0,18414.03647136688,19175.45029401779,18414.03647136688,757.3529677391052,2.1330432891845703,0.0 -54100,3.5948687,2.7819417,,,,,,,,,,,,,, -54200,3.2789724,2.7176757,,,,,,,,,,,,,, -54300,3.0523136,2.7143252,,,,,,,,,,,,,, -54400,2.9051714,2.783866,,,,,,,,,,,,,, -54500,2.9671278,2.69815,,,,,,,,,,,,,, -54600,3.0932128,2.721677,,,,,,,,,,,,,, -54700,3.4643643,2.7541335,,,,,,,,,,,,,, -54800,3.3194056,2.7265563,,,,,,,,,,,,,, -54900,3.476033,2.7522416,,,,,,,,,,,,,, -55000,3.0617163,2.6571598,,,,,,,,,,,,,, -55100,3.1945074,2.764235,,,,,,,,,,,,,, -55200,2.9640334,2.6391222,,,,,,,,,,,,,, -55300,3.1387668,2.6591854,,,,,,,,,,,,,, -55400,3.3097346,2.6605694,,,,,,,,,,,,,, -55500,3.0764005,2.6934047,,,,,,,,,,,,,, -55535,,,0.6664939522743225,1.427367925643921,0.6240000128746033,1.622045636177063,50000.0,0.5022000074386597,2.2728426456451416,10000.0,18923.95730996132,19702.92580795288,18923.95730996132,774.8156905174255,2.171757459640503,0.0 -55600,3.014796,2.7113166,,,,,,,,,,,,,, -55700,3.7021399,2.726395,,,,,,,,,,,,,, -55800,3.5798485,2.6787436,,,,,,,,,,,,,, -55900,3.3792045,2.7334108,,,,,,,,,,,,,, -56000,2.961431,2.7031388,,,,,,,,,,,,,, -56100,3.3290353,2.7649508,,,,,,,,,,,,,, -56200,3.6821716,2.75731,,,,,,,,,,,,,, -56300,3.2066402,2.6621532,,,,,,,,,,,,,, -56400,2.9474401,2.639636,,,,,,,,,,,,,, -56500,2.930461,2.722999,,,,,,,,,,,,,, -56600,3.01132,2.7023916,,,,,,,,,,,,,, -56700,3.090929,2.6841028,,,,,,,,,,,,,, -56800,3.219278,2.7485983,,,,,,,,,,,,,, -56900,2.9237983,2.7965524,,,,,,,,,,,,,, -57000,3.2439227,2.778859,,,,,,,,,,,,,, -57040,,,0.7063934803009033,1.2701016664505005,0.6281599998474121,1.6143368482589722,50000.0,0.5054000020027161,2.2856898307800293,10000.0,19434.118763685223,20230.6025326252,19434.118763685223,792.2390511035919,2.209519863128662,0.0 -57100,3.0826037,2.6493134,,,,,,,,,,,,,, -57200,2.8929973,2.615087,,,,,,,,,,,,,, -57300,3.8769047,2.696083,,,,,,,,,,,,,, -57400,3.688017,2.7678208,,,,,,,,,,,,,, -57500,3.153292,2.7495835,,,,,,,,,,,,,, -57600,3.0490277,2.7008543,,,,,,,,,,,,,, -57700,3.0736125,2.7347775,,,,,,,,,,,,,, -57800,3.265085,2.6883085,,,,,,,,,,,,,, -57900,3.534561,2.7442226,,,,,,,,,,,,,, -58000,3.027301,2.731069,,,,,,,,,,,,,, -58100,3.1523666,2.7945964,,,,,,,,,,,,,, -58200,2.9180226,2.853773,,,,,,,,,,,,,, -58300,3.8544815,2.7248173,,,,,,,,,,,,,, -58400,4.2199397,2.6590881,,,,,,,,,,,,,, -58500,3.002284,2.727596,,,,,,,,,,,,,, -58544,,,0.6819196343421936,1.3621110916137695,0.6212799549102783,1.638480544090271,50000.0,0.4928000271320343,2.3111507892608643,10000.0,19944.30988574028,20758.25435495377,19944.30988574028,809.6045508384705,2.2512757778167725,0.0 -58600,2.9981914,2.5930192,,,,,,,,,,,,,, -58700,2.8551784,2.6647418,,,,,,,,,,,,,, -58800,3.0630171,2.6146326,,,,,,,,,,,,,, -58900,3.324618,2.7107592,,,,,,,,,,,,,, -59000,3.1165843,2.6685076,,,,,,,,,,,,,, -59100,3.8126855,2.6476536,,,,,,,,,,,,,, -59200,3.1176736,2.845731,,,,,,,,,,,,,, -59300,2.9910538,2.6080785,,,,,,,,,,,,,, -59400,3.1278272,2.7686684,,,,,,,,,,,,,, -59500,3.5513039,2.7387824,,,,,,,,,,,,,, -59600,3.076677,2.7121446,,,,,,,,,,,,,, -59700,2.947901,2.6041086,,,,,,,,,,,,,, -59800,2.8138623,2.6518316,,,,,,,,,,,,,, -59900,3.9029694,2.7199988,,,,,,,,,,,,,, -60000,3.2786078,2.655262,,,,,,,,,,,,,, -60048,,,0.6807836294174194,1.359082579612732,0.6322000026702881,1.601298213005066,50000.0,0.5053000450134277,2.2946648597717285,10000.0,20454.509740829468,21285.80756020546,20454.509740829468,826.8658409118652,2.2905819416046143,0.0 -60100,3.6456025,2.7575946,,,,,,,,,,,,,, -60200,2.7524674,2.6594548,,,,,,,,,,,,,, -60300,3.6696475,2.7102942,,,,,,,,,,,,,, -60400,2.7817407,2.6741176,,,,,,,,,,,,,, -60500,3.0231905,2.6070743,,,,,,,,,,,,,, -60600,3.6947012,2.7088397,,,,,,,,,,,,,, -60700,3.0431855,2.624305,,,,,,,,,,,,,, -60800,3.3973093,2.7532523,,,,,,,,,,,,,, -60900,3.2015865,2.719667,,,,,,,,,,,,,, -61000,3.0483952,2.6775737,,,,,,,,,,,,,, -61100,3.2760322,2.6275883,,,,,,,,,,,,,, -61200,3.5806217,2.7320404,,,,,,,,,,,,,, -61300,3.5345838,2.674148,,,,,,,,,,,,,, -61400,3.260603,2.6975255,,,,,,,,,,,,,, -61500,3.2622125,2.7830637,,,,,,,,,,,,,, -61552,,,0.6801857352256775,1.3753252029418943,0.6317799687385559,1.5977727174758911,50000.0,0.5041000247001648,2.258517980575561,10000.0,20964.46904230117,21813.258487939835,20964.46904230117,844.2586009502411,2.3358445167541504,0.0 -61600,3.248789,2.7450042,,,,,,,,,,,,,, -61700,3.0049264,2.6124148,,,,,,,,,,,,,, -61800,3.42676,2.675169,,,,,,,,,,,,,, -61900,3.4410648,2.6523638,,,,,,,,,,,,,, -62000,2.9884183,2.6658602,,,,,,,,,,,,,, -62100,3.0617828,2.7061768,,,,,,,,,,,,,, -62200,3.0617988,2.7062743,,,,,,,,,,,,,, -62300,3.123027,2.6543512,,,,,,,,,,,,,, -62400,3.4825861,2.7051437,,,,,,,,,,,,,, -62500,3.4376032,2.6978085,,,,,,,,,,,,,, -62600,3.1461806,2.6267462,,,,,,,,,,,,,, -62700,3.5375721,2.875062,,,,,,,,,,,,,, -62800,2.9013276,2.6276448,,,,,,,,,,,,,, -62900,3.2743628,2.6325433,,,,,,,,,,,,,, -63000,3.5250945,2.78941,,,,,,,,,,,,,, -63057,,,0.6864038705825806,1.3631954193115234,0.6362400054931641,1.5933818817138672,50000.0,0.5072000026702881,2.271667242050171,10000.0,21474.617376327515,22340.84404706955,21474.617376327515,861.6014168262482,2.37718152999878,0.0 -63100,3.2446427,2.6693807,,,,,,,,,,,,,, -63200,3.184499,2.6679232,,,,,,,,,,,,,, -63300,3.1930778,2.6213164,,,,,,,,,,,,,, -63400,3.3925433,2.7240033,,,,,,,,,,,,,, -63500,3.115375,2.5798461,,,,,,,,,,,,,, -63600,2.9328244,2.5427222,,,,,,,,,,,,,, -63700,2.7904072,2.6804936,,,,,,,,,,,,,, -63800,3.5883732,2.700181,,,,,,,,,,,,,, -63900,3.2608054,2.559125,,,,,,,,,,,,,, -64000,2.9046588,2.7015762,,,,,,,,,,,,,, -64100,3.6646159,2.7861924,,,,,,,,,,,,,, -64200,3.3013716,2.6635695,,,,,,,,,,,,,, -64300,3.1256633,2.6276212,,,,,,,,,,,,,, -64400,3.2637153,2.6353803,,,,,,,,,,,,,, -64500,3.7015393,2.7956786,,,,,,,,,,,,,, -64562,,,0.6725525856018066,1.4098135232925415,0.6275399923324585,1.6258821487426758,50000.0,0.4985000193119049,2.332926034927368,10000.0,21984.827194929123,22868.620908498764,21984.827194929123,879.075288772583,2.416655778884888,0.0 -64600,3.5445487,2.6144233,,,,,,,,,,,,,, -64700,3.684392,2.6968253,,,,,,,,,,,,,, -64800,3.0096412,2.5946488,,,,,,,,,,,,,, -64900,2.8196611,2.7087986,,,,,,,,,,,,,, -65000,3.3615773,2.6283162,,,,,,,,,,,,,, -65100,3.8456888,2.7226267,,,,,,,,,,,,,, -65200,3.8335445,2.5700445,,,,,,,,,,,,,, -65300,3.2304945,2.7763627,,,,,,,,,,,,,, -65400,3.2450368,2.643385,,,,,,,,,,,,,, -65500,3.172986,2.5698104,,,,,,,,,,,,,, -65600,3.7269514,2.709453,,,,,,,,,,,,,, -65700,3.2731147,2.7626438,,,,,,,,,,,,,, -65800,3.4736798,2.6234386,,,,,,,,,,,,,, -65900,3.6628227,2.6463575,,,,,,,,,,,,,, -66000,3.03666,2.6356583,,,,,,,,,,,,,, -66067,,,0.7122927308082581,1.2576779127120972,0.624459981918335,1.6419285535812378,50000.0,0.5062000155448914,2.2767369747161865,10000.0,22494.96067333221,23396.3625805378,22494.96067333221,896.5920617580414,2.4555139541625977,0.0 -66100,3.2469234,2.67528,,,,,,,,,,,,,, -66200,3.6469789,2.6390114,,,,,,,,,,,,,, -66300,3.049616,2.6826067,,,,,,,,,,,,,, -66400,3.447437,2.7742443,,,,,,,,,,,,,, -66500,3.568798,2.6403408,,,,,,,,,,,,,, -66600,3.114887,2.6661031,,,,,,,,,,,,,, -66700,3.3941915,2.6380608,,,,,,,,,,,,,, -66800,3.0939922,2.5780027,,,,,,,,,,,,,, -66900,3.3722022,2.711598,,,,,,,,,,,,,, -67000,3.2082531,2.593729,,,,,,,,,,,,,, -67100,4.0441766,2.7027922,,,,,,,,,,,,,, -67200,3.2045684,2.6775937,,,,,,,,,,,,,, -67300,3.327312,2.6770926,,,,,,,,,,,,,, -67400,3.076762,2.5691357,,,,,,,,,,,,,, -67500,3.3025994,2.6573703,,,,,,,,,,,,,, -67572,,,0.6955117583274841,1.2980284690856934,0.6328999996185303,1.5780168771743774,50000.0,0.5186000466346741,2.217724323272705,10000.0,23005.19539809227,23924.20377588272,23005.19539809227,914.1040511131288,2.495955228805542,0.0 -67600,3.1223483,2.6815977,,,,,,,,,,,,,, -67700,4.075495,2.6391857,,,,,,,,,,,,,, -67800,3.3652291,2.7478518,,,,,,,,,,,,,, -67900,3.2552893,2.7137632,,,,,,,,,,,,,, -68000,3.3386638,2.6419733,,,,,,,,,,,,,, -68100,3.9620962,2.6873555,,,,,,,,,,,,,, -68200,3.9059122,2.5368056,,,,,,,,,,,,,, -68300,3.0711553,2.6607015,,,,,,,,,,,,,, -68400,3.0665455,2.657897,,,,,,,,,,,,,, -68500,3.5419664,2.5780149,,,,,,,,,,,,,, -68600,3.2197602,2.6315997,,,,,,,,,,,,,, -68700,3.1825655,2.6669223,,,,,,,,,,,,,, -68800,3.1133409,2.6159341,,,,,,,,,,,,,, -68900,3.1655052,2.664574,,,,,,,,,,,,,, -69000,3.2384713,2.6596308,,,,,,,,,,,,,, -69077,,,0.6938177347183228,1.3356457948684692,0.634719967842102,1.593294978141785,50000.0,0.5082000494003296,2.25681209564209,10000.0,23515.14908361435,24453.22411513329,23515.14908361435,933.078031539917,2.535719633102417,0.0 -69100,3.5910664,2.6090486,,,,,,,,,,,,,, -69200,3.3435404,2.5449219,,,,,,,,,,,,,, -69300,3.1493313,2.670137,,,,,,,,,,,,,, -69400,3.5290194,2.7039742,,,,,,,,,,,,,, -69500,3.7637115,2.6991544,,,,,,,,,,,,,, -69600,3.3837492,2.6815362,,,,,,,,,,,,,, -69700,3.7880404,2.6469285,,,,,,,,,,,,,, -69800,3.2032373,2.689397,,,,,,,,,,,,,, -69900,3.5385609,2.5860615,,,,,,,,,,,,,, -70000,3.0084982,2.6400406,,,,,,,,,,,,,, -70100,2.803121,2.6671367,,,,,,,,,,,,,, -70200,3.9511366,2.5667677,,,,,,,,,,,,,, -70300,3.5862298,2.6076763,,,,,,,,,,,,,, -70400,4.3538923,2.6966145,,,,,,,,,,,,,, -70500,3.6186004,2.5996761,,,,,,,,,,,,,, -70582,,,0.6916653513908386,1.3394980430603027,0.6408199667930603,1.581676959991455,50000.0,0.5163000226020813,2.2319271564483643,10000.0,24025.31227278709,24981.911822795868,24025.31227278709,951.5091438293456,2.5747079849243164,0.0 -70600,3.2671814,2.6539173,,,,,,,,,,,,,, -70700,3.5437148,2.7027311,,,,,,,,,,,,,, -70800,3.435298,2.6899614,,,,,,,,,,,,,, -70900,4.086979,2.7232363,,,,,,,,,,,,,, -71000,3.3283632,2.6093652,,,,,,,,,,,,,, -71100,3.1302843,2.51526,,,,,,,,,,,,,, -71200,3.6658406,2.6436863,,,,,,,,,,,,,, -71300,3.2906444,2.708667,,,,,,,,,,,,,, -71400,3.259542,2.6080303,,,,,,,,,,,,,, -71500,3.8491309,2.712987,,,,,,,,,,,,,, -71600,3.990106,2.711335,,,,,,,,,,,,,, -71700,3.238091,2.6675458,,,,,,,,,,,,,, -71800,3.667558,2.5920959,,,,,,,,,,,,,, -71900,3.2478042,2.6640909,,,,,,,,,,,,,, -72000,3.5065067,2.5706542,,,,,,,,,,,,,, -72087,,,0.6868821382522583,1.351656198501587,0.6393600106239319,1.5750479698181152,50000.0,0.518500030040741,2.226827621459961,10000.0,24535.51449584961,25509.434837818146,24535.51449584961,968.733303785324,2.617192506790161,0.0 -72100,3.0574057,2.6481383,,,,,,,,,,,,,, -72200,3.3138611,2.6386673,,,,,,,,,,,,,, -72300,3.7046266,2.6731918,,,,,,,,,,,,,, -72400,3.9596899,2.701198,,,,,,,,,,,,,, -72500,3.1946049,2.6223311,,,,,,,,,,,,,, -72600,3.4481034,2.5528731,,,,,,,,,,,,,, -72700,3.78149,2.561407,,,,,,,,,,,,,, -72800,3.3175435,2.6292315,,,,,,,,,,,,,, -72900,3.104965,2.5624442,,,,,,,,,,,,,, -73000,3.4898825,2.6942835,,,,,,,,,,,,,, -73100,3.5837934,2.6611862,,,,,,,,,,,,,, -73200,3.2902772,2.6587183,,,,,,,,,,,,,, -73300,3.0646298,2.4959767,,,,,,,,,,,,,, -73400,3.040267,2.5921624,,,,,,,,,,,,,, -73500,3.106275,2.6092606,,,,,,,,,,,,,, -73592,,,0.6874202489852905,1.3545960187911987,0.6374799609184265,1.5858798027038574,50000.0,0.513700008392334,2.240742921829224,10000.0,25045.693425178528,26037.469779729843,25045.693425178528,986.4949653148652,2.6583826541900635,0.0 -73600,3.3448086,2.5232136,,,,,,,,,,,,,, -73700,3.1059513,2.5552287,,,,,,,,,,,,,, -73800,3.268865,2.5518293,,,,,,,,,,,,,, -73900,3.4654176,2.69302,,,,,,,,,,,,,, -74000,3.1215773,2.5391054,,,,,,,,,,,,,, -74100,4.064324,2.4908133,,,,,,,,,,,,,, -74200,3.7782962,2.7156935,,,,,,,,,,,,,, -74300,3.0113034,2.7054496,,,,,,,,,,,,,, -74400,3.303907,2.572212,,,,,,,,,,,,,, -74500,3.4099505,2.6171741,,,,,,,,,,,,,, -74600,3.305474,2.7819092,,,,,,,,,,,,,, -74700,3.849975,2.677137,,,,,,,,,,,,,, -74800,3.5187514,2.7122848,,,,,,,,,,,,,, -74900,3.2771194,2.6196225,,,,,,,,,,,,,, -75000,3.2991116,2.6467204,,,,,,,,,,,,,, -75098,,,0.7245495915412903,1.1946289539337158,0.6363799571990967,1.5921239852905271,50000.0,0.5095000267028809,2.269789695739746,10000.0,25555.87436771393,26566.167140960693,25555.87436771393,1004.9125220775604,2.7030186653137207,0.0 -75100,3.562554,2.6137574,,,,,,,,,,,,,, -75200,3.3810935,2.6779006,,,,,,,,,,,,,, -75300,3.137885,2.6682794,,,,,,,,,,,,,, -75400,4.104538,2.6238308,,,,,,,,,,,,,, -75500,3.6680794,2.6578999,,,,,,,,,,,,,, -75600,3.393812,2.6502485,,,,,,,,,,,,,, -75700,3.2448869,2.6065083,,,,,,,,,,,,,, -75800,3.9702947,2.6692762,,,,,,,,,,,,,, -75900,3.220001,2.7353532,,,,,,,,,,,,,, -76000,3.1720629,2.6402023,,,,,,,,,,,,,, -76100,3.8495095,2.671981,,,,,,,,,,,,,, -76200,3.464563,2.607471,,,,,,,,,,,,,, -76300,3.3487718,2.5900588,,,,,,,,,,,,,, -76400,3.5401127,2.6493518,,,,,,,,,,,,,, -76500,3.6133907,2.5862622,,,,,,,,,,,,,, -76600,3.4884589,2.633348,,,,,,,,,,,,,, -76603,,,0.7009127736091614,1.2936850786209106,0.6375199556350708,1.5906641483306885,50000.0,0.5212000012397766,2.208439588546753,10000.0,26066.06677961349,27094.41897177696,26066.06677961349,1022.8778517246246,2.7443735599517822,0.0 -76700,3.2312331,2.6150696,,,,,,,,,,,,,, -76800,3.4681656,2.5939212,,,,,,,,,,,,,, -76900,3.6052136,2.5683258,,,,,,,,,,,,,, -77000,3.4288418,2.6480064,,,,,,,,,,,,,, -77100,3.3515747,2.5921524,,,,,,,,,,,,,, -77200,3.6344278,2.5683217,,,,,,,,,,,,,, -77300,3.490059,2.6393363,,,,,,,,,,,,,, -77400,4.0151916,2.6313927,,,,,,,,,,,,,, -77500,3.4087908,2.5464585,,,,,,,,,,,,,, -77600,3.85394,2.603021,,,,,,,,,,,,,, -77700,3.366233,2.5317008,,,,,,,,,,,,,, -77800,3.4927838,2.6629896,,,,,,,,,,,,,, -77900,4.0911355,2.7208152,,,,,,,,,,,,,, -78000,3.3552825,2.5098693,,,,,,,,,,,,,, -78100,3.4301128,2.523695,,,,,,,,,,,,,, -78107,,,0.7107182741165161,1.2404513359069824,0.6543399691581726,1.5085474252700806,50000.0,0.5258000493049622,2.185645580291748,10000.0,26576.22903299332,27621.786662578583,26576.22903299332,1039.984982252121,2.788038969039917,0.0 -78200,3.354368,2.6507928,,,,,,,,,,,,,, -78300,3.1301327,2.5695753,,,,,,,,,,,,,, -78400,3.9323778,2.6571646,,,,,,,,,,,,,, -78500,3.7310221,2.6956844,,,,,,,,,,,,,, -78600,3.525899,2.6060686,,,,,,,,,,,,,, -78700,3.6162405,2.5395198,,,,,,,,,,,,,, -78800,3.7005467,2.580789,,,,,,,,,,,,,, -78900,3.5856183,2.58855,,,,,,,,,,,,,, -79000,3.795913,2.676248,,,,,,,,,,,,,, -79100,3.3561983,2.598122,,,,,,,,,,,,,, -79200,3.3476684,2.564008,,,,,,,,,,,,,, -79300,3.4550574,2.5658748,,,,,,,,,,,,,, -79400,3.3211715,2.5503685,,,,,,,,,,,,,, -79500,3.447696,2.6357036,,,,,,,,,,,,,, -79600,3.3642046,2.599392,,,,,,,,,,,,,, -79612,,,0.7019491195678711,1.2742984294891355,0.6469199657440186,1.5245815515518188,50000.0,0.5260000228881836,2.1862270832061768,10000.0,27086.374658584595,28149.237203598022,27086.374658584595,1057.1927177906036,2.82971453666687,0.0 -79700,4.5059443,2.6753337,,,,,,,,,,,,,, -79800,3.3294992,2.557775,,,,,,,,,,,,,, -79900,3.411395,2.534563,,,,,,,,,,,,,, -80000,3.2058494,2.4989212,,,,,,,,,,,,,, -80100,3.3852382,2.590839,,,,,,,,,,,,,, -80200,3.5533288,2.6012626,,,,,,,,,,,,,, -80300,3.1682494,2.6180954,,,,,,,,,,,,,, -80400,3.3250163,2.577752,,,,,,,,,,,,,, -80500,3.4102833,2.6068077,,,,,,,,,,,,,, -80600,3.546195,2.6419253,,,,,,,,,,,,,, -80700,3.6832168,2.6350188,,,,,,,,,,,,,, -80800,3.4450128,2.5922558,,,,,,,,,,,,,, -80900,3.2910275,2.5222726,,,,,,,,,,,,,, -81000,3.4969356,2.5252478,,,,,,,,,,,,,, -81100,3.4435678,2.5148394,,,,,,,,,,,,,, -81117,,,0.6994778513908386,1.2843109369277954,0.6515199542045593,1.5109398365020752,50000.0,0.5253000259399414,2.175584554672241,10000.0,27596.460722208023,28676.73327088356,27596.460722208023,1074.5033564567566,2.8747684955596924,0.0 -81200,3.7173178,2.4840794,,,,,,,,,,,,,, -81300,3.4435852,2.5209606,,,,,,,,,,,,,, -81400,3.3884788,2.6137571,,,,,,,,,,,,,, -81500,3.5024972,2.5702693,,,,,,,,,,,,,, -81600,3.7605715,2.53016,,,,,,,,,,,,,, -81700,3.6618025,2.66248,,,,,,,,,,,,,, -81800,3.3443544,2.5856352,,,,,,,,,,,,,, -81900,3.333302,2.620035,,,,,,,,,,,,,, -82000,3.5554967,2.6040885,,,,,,,,,,,,,, -82100,3.4086277,2.5408838,,,,,,,,,,,,,, -82200,3.6377058,2.5651581,,,,,,,,,,,,,, -82300,3.4567916,2.6285186,,,,,,,,,,,,,, -82400,4.1410136,2.5893373,,,,,,,,,,,,,, -82500,3.3991416,2.5908337,,,,,,,,,,,,,, -82600,3.681516,2.5420325,,,,,,,,,,,,,, -82622,,,0.6990393400192261,1.291176676750183,0.6479799747467041,1.5192651748657229,50000.0,0.52510005235672,2.1800341606140137,10000.0,28106.558556318283,29204.21426296234,28106.558556318283,1091.787621974945,2.9192492961883545,0.0 -82700,3.5117733,2.6428034,,,,,,,,,,,,,, -82800,3.2299037,2.5435412,,,,,,,,,,,,,, -82900,3.3012831,2.6277487,,,,,,,,,,,,,, -83000,4.480883,2.5901291,,,,,,,,,,,,,, -83100,3.405983,2.5567791,,,,,,,,,,,,,, -83200,3.7771647,2.5623767,,,,,,,,,,,,,, -83300,3.406113,2.5821917,,,,,,,,,,,,,, -83400,3.9240835,2.653432,,,,,,,,,,,,,, -83500,3.7469037,2.607208,,,,,,,,,,,,,, -83600,3.3464947,2.6196349,,,,,,,,,,,,,, -83700,3.702064,2.609382,,,,,,,,,,,,,, -83800,4.010138,2.5188837,,,,,,,,,,,,,, -83900,3.4693482,2.5556252,,,,,,,,,,,,,, -84000,4.062319,2.57976,,,,,,,,,,,,,, -84100,3.2594578,2.6296184,,,,,,,,,,,,,, -84127,,,0.7421077489852905,1.1094766855239868,0.6539199948310852,1.498048186302185,50000.0,0.5231000185012817,2.172151565551758,10000.0,28616.616693496704,29731.992182970047,28616.616693496704,1109.4084930419922,2.964075803756714,0.0 -84200,3.5281737,2.4998596,,,,,,,,,,,,,, -84300,3.7088013,2.5925133,,,,,,,,,,,,,, -84400,3.4406686,2.564549,,,,,,,,,,,,,, -84500,4.386578,2.5825455,,,,,,,,,,,,,, -84600,3.5981784,2.626177,,,,,,,,,,,,,, -84700,3.5762854,2.5545068,,,,,,,,,,,,,, -84800,4.298593,2.5367408,,,,,,,,,,,,,, -84900,3.9533212,2.5232573,,,,,,,,,,,,,, -85000,3.9706366,2.5770664,,,,,,,,,,,,,, -85100,3.842358,2.5630562,,,,,,,,,,,,,, -85200,3.30245,2.501516,,,,,,,,,,,,,, -85300,3.374791,2.6007533,,,,,,,,,,,,,, -85400,3.9488144,2.5067065,,,,,,,,,,,,,, -85500,4.1143866,2.5767105,,,,,,,,,,,,,, -85600,3.5497525,2.5162196,,,,,,,,,,,,,, -85632,,,0.72562575340271,1.1676472425460815,0.6571399569511414,1.469819188117981,50000.0,0.5339000225067139,2.140961170196533,10000.0,29126.722969293594,30259.5151386261,29126.722969293594,1126.7367713451383,2.999528408050537,0.0 -85700,3.2271638,2.5935774,,,,,,,,,,,,,, -85800,3.605181,2.5766134,,,,,,,,,,,,,, -85900,3.6191185,2.5610006,,,,,,,,,,,,,, -86000,3.8439176,2.5561337,,,,,,,,,,,,,, -86100,3.6625257,2.5349214,,,,,,,,,,,,,, -86200,3.8240595,2.6443563,,,,,,,,,,,,,, -86300,3.3465965,2.572959,,,,,,,,,,,,,, -86400,3.8380694,2.6390457,,,,,,,,,,,,,, -86500,3.6605299,2.6443572,,,,,,,,,,,,,, -86600,3.4921184,2.62611,,,,,,,,,,,,,, -86700,3.4110484,2.6597314,,,,,,,,,,,,,, -86800,3.9998066,2.6125264,,,,,,,,,,,,,, -86900,3.7138493,2.508768,,,,,,,,,,,,,, -87000,4.016875,2.5413575,,,,,,,,,,,,,, -87100,3.14971,2.5898137,,,,,,,,,,,,,, -87137,,,0.7182118892669678,1.213655948638916,0.6566799879074097,1.4903264045715332,50000.0,0.5339000225067139,2.1381936073303223,10000.0,29636.656057357788,30786.70604276657,29636.656057357788,1143.8874711990356,3.0535805225372314,0.0 -87200,3.8011158,2.5487149,,,,,,,,,,,,,, -87300,3.8679175,2.5410178,,,,,,,,,,,,,, -87400,3.5323484,2.5842977,,,,,,,,,,,,,, -87500,3.4965715,2.5961585,,,,,,,,,,,,,, -87600,3.594009,2.5518806,,,,,,,,,,,,,, -87700,3.8498454,2.572019,,,,,,,,,,,,,, -87800,3.4810784,2.5523663,,,,,,,,,,,,,, -87900,3.7538052,2.597374,,,,,,,,,,,,,, -88000,3.7676494,2.618662,,,,,,,,,,,,,, -88100,3.5047877,2.591002,,,,,,,,,,,,,, -88200,3.8264375,2.6163347,,,,,,,,,,,,,, -88300,3.5062134,2.5068853,,,,,,,,,,,,,, -88400,3.553185,2.609613,,,,,,,,,,,,,, -88500,3.825901,2.6591244,,,,,,,,,,,,,, -88600,3.458163,2.487541,,,,,,,,,,,,,, -88642,,,0.7137675285339355,1.2036211490631104,0.6636999845504761,1.4569283723831177,50000.0,0.5324000120162964,2.1352744102478027,10000.0,30146.884006261826,31314.36168217659,30146.884006261826,1161.2099361419678,3.103337049484253,0.0 -88700,3.9292507,2.4581676,,,,,,,,,,,,,, -88800,3.474811,2.6331072,,,,,,,,,,,,,, -88900,3.603686,2.5290713,,,,,,,,,,,,,, -89000,3.6289976,2.573082,,,,,,,,,,,,,, -89100,4.435964,2.5123067,,,,,,,,,,,,,, -89200,3.4904962,2.5312738,,,,,,,,,,,,,, -89300,3.4742136,2.5745614,,,,,,,,,,,,,, -89400,3.5802982,2.6188862,,,,,,,,,,,,,, -89500,3.7241285,2.5887709,,,,,,,,,,,,,, -89600,4.287895,2.6663418,,,,,,,,,,,,,, -89700,3.5204885,2.4996026,,,,,,,,,,,,,, -89800,3.8312573,2.5925622,,,,,,,,,,,,,, -89900,3.802878,2.6021721,,,,,,,,,,,,,, -90000,3.776965,2.6042278,,,,,,,,,,,,,, -90100,3.5416496,2.567632,,,,,,,,,,,,,, -90147,,,0.7099409699440002,1.2612212896347046,0.6620799899101257,1.5005483627319336,50000.0,0.5302000045776367,2.188394784927368,10000.0,30656.89363193512,31842.028629779816,30656.89363193512,1178.7714076042175,3.1447083950042725,0.0 -90200,4.7094975,2.5823815,,,,,,,,,,,,,, -90300,3.5159295,2.6170638,,,,,,,,,,,,,, -90400,3.364855,2.4885669,,,,,,,,,,,,,, -90500,4.0671883,2.5437322,,,,,,,,,,,,,, -90600,3.7018907,2.5272474,,,,,,,,,,,,,, -90700,3.6675692,2.604786,,,,,,,,,,,,,, -90800,4.4554763,2.4865994,,,,,,,,,,,,,, -90900,3.529672,2.4804924,,,,,,,,,,,,,, -91000,4.474466,2.4839213,,,,,,,,,,,,,, -91100,3.8912113,2.5888705,,,,,,,,,,,,,, -91200,3.5677738,2.5438938,,,,,,,,,,,,,, -91300,3.5910816,2.537339,,,,,,,,,,,,,, -91400,3.5719578,2.5118418,,,,,,,,,,,,,, -91500,4.160109,2.4917054,,,,,,,,,,,,,, -91600,3.708237,2.5907114,,,,,,,,,,,,,, -91652,,,0.7141063213348389,1.2344672679901123,0.6605799794197083,1.4747819900512695,50000.0,0.532800018787384,2.14026951789856,10000.0,31167.033656597137,32369.68866419792,31167.033656597137,1196.192389011383,3.1903398036956787,0.0 -91700,4.2497244,2.569336,,,,,,,,,,,,,, -91800,3.5101864,2.6257625,,,,,,,,,,,,,, -91900,4.096653,2.6050558,,,,,,,,,,,,,, -92000,3.6296182,2.5879483,,,,,,,,,,,,,, -92100,4.0127897,2.5335925,,,,,,,,,,,,,, -92200,4.0551467,2.4792783,,,,,,,,,,,,,, -92300,3.4917257,2.50975,,,,,,,,,,,,,, -92400,3.763248,2.5806518,,,,,,,,,,,,,, -92500,3.6146436,2.5331635,,,,,,,,,,,,,, -92600,3.9864862,2.4446223,,,,,,,,,,,,,, -92700,3.4893565,2.5126371,,,,,,,,,,,,,, -92800,4.1443286,2.4612665,,,,,,,,,,,,,, -92900,3.7966764,2.6058843,,,,,,,,,,,,,, -93000,4.049747,2.6247642,,,,,,,,,,,,,, -93100,4.054458,2.4785974,,,,,,,,,,,,,, -93157,,,0.7379822731018066,1.1373311281204224,0.6595999598503113,1.4904597997665403,50000.0,0.5348000526428223,2.143916606903076,10000.0,31677.014585733414,32897.162084817886,31677.014585733414,1213.5849130153656,3.2358620166778564,0.0 -93200,4.4371476,2.5860388,,,,,,,,,,,,,, -93300,3.6613336,2.4570985,,,,,,,,,,,,,, -93400,3.6046276,2.6377523,,,,,,,,,,,,,, -93500,3.8957133,2.4871583,,,,,,,,,,,,,, -93600,3.6844933,2.4835272,,,,,,,,,,,,,, -93700,3.964091,2.5470417,,,,,,,,,,,,,, -93800,4.1371875,2.5890322,,,,,,,,,,,,,, -93900,3.9537933,2.4419982,,,,,,,,,,,,,, -94000,3.8449495,2.576779,,,,,,,,,,,,,, -94100,3.5874074,2.5630643,,,,,,,,,,,,,, -94200,3.941196,2.5852156,,,,,,,,,,,,,, -94300,3.96011,2.480853,,,,,,,,,,,,,, -94400,3.8789659,2.44347,,,,,,,,,,,,,, -94500,3.3877947,2.5256052,,,,,,,,,,,,,, -94600,3.8511827,2.4959946,,,,,,,,,,,,,, -94662,,,0.7356903553009033,1.137479305267334,0.6634199619293213,1.4608635902404783,50000.0,0.534500002861023,2.139446258544922,10000.0,32187.24665856361,33424.74766254425,32187.24665856361,1230.840161561966,3.279825687408448,0.0 -94700,3.8866653,2.464415,,,,,,,,,,,,,, -94800,4.5355635,2.6201892,,,,,,,,,,,,,, -94900,4.0765586,2.4910035,,,,,,,,,,,,,, -95000,3.6578584,2.5587728,,,,,,,,,,,,,, -95100,3.3590276,2.517982,,,,,,,,,,,,,, -95200,3.8230574,2.4056396,,,,,,,,,,,,,, -95300,3.9712396,2.554119,,,,,,,,,,,,,, -95400,4.0999703,2.4900663,,,,,,,,,,,,,, -95500,3.92106,2.4410257,,,,,,,,,,,,,, -95600,3.6683464,2.5008297,,,,,,,,,,,,,, -95700,3.7594774,2.4959917,,,,,,,,,,,,,, -95800,4.0992136,2.569718,,,,,,,,,,,,,, -95900,3.5908608,2.4646826,,,,,,,,,,,,,, -96000,3.777127,2.6637504,,,,,,,,,,,,,, -96100,4.1707807,2.6656423,,,,,,,,,,,,,, -96167,,,0.7338966727256775,1.1363016366958618,0.6714999675750732,1.4211866855621338,50000.0,0.5420000553131104,2.080425262451172,10000.0,32697.393936157227,33952.223504543304,32697.393936157227,1248.0705358982086,3.32486629486084,0.0 -96200,3.7274246,2.5546174,,,,,,,,,,,,,, -96300,4.1946926,2.5571647,,,,,,,,,,,,,, -96400,3.3616724,2.4468887,,,,,,,,,,,,,, -96500,4.2874904,2.5344942,,,,,,,,,,,,,, -96600,3.9238305,2.5460534,,,,,,,,,,,,,, -96700,4.282932,2.5839481,,,,,,,,,,,,,, -96800,4.0080047,2.481394,,,,,,,,,,,,,, -96900,3.330403,2.4881866,,,,,,,,,,,,,, -97000,4.1237392,2.4892936,,,,,,,,,,,,,, -97100,3.9388397,2.5120249,,,,,,,,,,,,,, -97200,3.7414174,2.464644,,,,,,,,,,,,,, -97300,3.8046649,2.5092525,,,,,,,,,,,,,, -97400,4.05203,2.472979,,,,,,,,,,,,,, -97500,3.8124485,2.414277,,,,,,,,,,,,,, -97600,3.5298946,2.5815086,,,,,,,,,,,,,, -97672,,,0.7274991869926453,1.16644549369812,0.6618399620056152,1.453052043914795,50000.0,0.5360000133514404,2.12292218208313,10000.0,33207.3950073719,34479.80842757225,33207.3950073719,1265.554355621338,3.372291326522827,0.0 -97700,4.158326,2.469777,,,,,,,,,,,,,, -97800,3.884182,2.545891,,,,,,,,,,,,,, -97900,4.7427235,2.4288375,,,,,,,,,,,,,, -98000,3.7196434,2.5761235,,,,,,,,,,,,,, -98100,4.5712776,2.502759,,,,,,,,,,,,,, -98200,3.4841359,2.5360541,,,,,,,,,,,,,, -98300,4.2548137,2.461391,,,,,,,,,,,,,, -98400,3.8236134,2.4234645,,,,,,,,,,,,,, -98500,3.7163727,2.5391536,,,,,,,,,,,,,, -98600,4.128727,2.5593104,,,,,,,,,,,,,, -98700,4.1769657,2.4822092,,,,,,,,,,,,,, -98800,3.8612466,2.5481317,,,,,,,,,,,,,, -98900,4.0372295,2.5704675,,,,,,,,,,,,,, -99000,3.615736,2.4009726,,,,,,,,,,,,,, -99100,3.9388592,2.424845,,,,,,,,,,,,,, -99177,,,0.7198660373687744,1.182017803192139,0.6629999876022339,1.4383418560028076,50000.0,0.538100004196167,2.1251816749572754,10000.0,33717.37340140343,35007.27703619003,33717.37340140343,1282.9458377361298,3.418320894241333,0.0 -99200,4.52858,2.4679983,,,,,,,,,,,,,, -99300,4.2492223,2.5924542,,,,,,,,,,,,,, -99400,3.9855564,2.5017564,,,,,,,,,,,,,, -99500,3.567299,2.5180423,,,,,,,,,,,,,, -99600,3.9285362,2.4626698,,,,,,,,,,,,,, -99700,4.0925703,2.5824,,,,,,,,,,,,,, -99800,3.5369053,2.4614053,,,,,,,,,,,,,, -99900,3.732581,2.4787407,,,,,,,,,,,,,, -100000,4.055588,2.5401907,,,,,,,,,,,,,, -100100,4.386351,2.451534,,,,,,,,,,,,,, -100200,3.4161453,2.4007444,,,,,,,,,,,,,, -100300,4.3266377,2.5167027,,,,,,,,,,,,,, -100400,3.8622184,2.440689,,,,,,,,,,,,,, -100500,5.246373,2.5477061,,,,,,,,,,,,,, -100600,3.5862992,2.4399524,,,,,,,,,,,,,, -100682,,,0.7286949753761292,1.1676808595657349,0.6708599925041199,1.4283438920974731,50000.0,0.5496000051498413,2.0684101581573486,10000.0,34227.44268536568,35534.58135795593,34227.44268536568,1300.0815467834473,3.4637675285339355,0.0 -100700,3.7817216,2.4820952,,,,,,,,,,,,,, -100800,4.208529,2.5202343,,,,,,,,,,,,,, -100900,3.5157156,2.5162268,,,,,,,,,,,,,, -101000,3.9031527,2.4699783,,,,,,,,,,,,,, -101100,3.666284,2.5614474,,,,,,,,,,,,,, -101200,3.9238338,2.4338188,,,,,,,,,,,,,, -101300,4.049907,2.4804034,,,,,,,,,,,,,, -101400,3.519705,2.4955213,,,,,,,,,,,,,, -101500,3.78157,2.4251807,,,,,,,,,,,,,, -101600,4.0358653,2.4628625,,,,,,,,,,,,,, -101700,4.3417735,2.5382013,,,,,,,,,,,,,, -101800,4.2058015,2.476564,,,,,,,,,,,,,, -101900,4.058688,2.6141086,,,,,,,,,,,,,, -102000,4.486819,2.525995,,,,,,,,,,,,,, -102100,3.5696378,2.3412273,,,,,,,,,,,,,, -102187,,,0.748046875,1.0852841138839722,0.6756599545478821,1.4007805585861206,50000.0,0.5542000532150269,2.0193703174591064,10000.0,34737.54347777367,36061.81385445595,34737.54347777367,1317.113971233368,3.5098297595977783,0.0 -102200,3.7449293,2.51132,,,,,,,,,,,,,, -102300,4.0388536,2.493324,,,,,,,,,,,,,, -102400,4.083013,2.3842213,,,,,,,,,,,,,, -102500,4.4916964,2.4652326,,,,,,,,,,,,,, -102600,3.6941137,2.3743777,,,,,,,,,,,,,, -102700,3.5575633,2.3683655,,,,,,,,,,,,,, -102800,4.258756,2.5401807,,,,,,,,,,,,,, -102900,4.74531,2.556356,,,,,,,,,,,,,, -103000,4.057252,2.3900933,,,,,,,,,,,,,, -103100,4.3047333,2.4628313,,,,,,,,,,,,,, -103200,4.1023455,2.4457443,,,,,,,,,,,,,, -103300,3.7991376,2.3748572,,,,,,,,,,,,,, -103400,3.7527616,2.4053092,,,,,,,,,,,,,, -103500,4.2468843,2.541068,,,,,,,,,,,,,, -103600,3.524085,2.4784472,,,,,,,,,,,,,, -103692,,,0.7497408986091614,1.068345069885254,0.6728799939155579,1.414284348487854,50000.0,0.5421000123023987,2.0770750045776367,10000.0,35247.70157814026,36589.16969251633,35247.70157814026,1334.212421655655,3.555453062057495,0.0 -103700,4.5082903,2.527287,,,,,,,,,,,,,, -103800,3.8596244,2.389669,,,,,,,,,,,,,, -103900,3.9447467,2.4885833,,,,,,,,,,,,,, -104000,4.528373,2.4994316,,,,,,,,,,,,,, -104100,4.1401534,2.4934773,,,,,,,,,,,,,, -104200,3.6625106,2.4028134,,,,,,,,,,,,,, -104300,3.886672,2.4511194,,,,,,,,,,,,,, -104400,3.9609044,2.5329185,,,,,,,,,,,,,, -104500,4.3689737,2.496521,,,,,,,,,,,,,, -104600,4.4428196,2.430795,,,,,,,,,,,,,, -104700,4.390935,2.5008416,,,,,,,,,,,,,, -104800,3.992829,2.457272,,,,,,,,,,,,,, -104900,3.6304605,2.4229631,,,,,,,,,,,,,, -105000,4.965823,2.427457,,,,,,,,,,,,,, -105100,3.7177207,2.4663565,,,,,,,,,,,,,, -105198,,,0.7359095811843872,1.1194480657577517,0.6708999872207642,1.4225313663482666,50000.0,0.5439000129699707,2.1047375202178955,10000.0,35757.91668009758,37117.02583503723,35757.91668009758,1351.752160072327,3.603590250015259,0.0 -105200,3.9764104,2.4253042,,,,,,,,,,,,,, -105300,4.407947,2.5684159,,,,,,,,,,,,,, -105400,4.123981,2.4464152,,,,,,,,,,,,,, -105500,3.6232207,2.3315988,,,,,,,,,,,,,, -105600,4.120918,2.4623797,,,,,,,,,,,,,, -105700,4.0753856,2.4577622,,,,,,,,,,,,,, -105800,4.0804214,2.4593773,,,,,,,,,,,,,, -105900,4.1197557,2.476403,,,,,,,,,,,,,, -106000,4.0431075,2.4202223,,,,,,,,,,,,,, -106100,4.849368,2.4559977,,,,,,,,,,,,,, -106200,3.715636,2.4462652,,,,,,,,,,,,,, -106300,4.4748154,2.551409,,,,,,,,,,,,,, -106400,4.123206,2.452122,,,,,,,,,,,,,, -106500,4.0152373,2.4310586,,,,,,,,,,,,,, -106600,4.0459123,2.4966564,,,,,,,,,,,,,, -106700,4.224289,2.4442215,,,,,,,,,,,,,, -106703,,,0.7438616156578064,1.0795286893844604,0.6801799535751343,1.3656352758407593,50000.0,0.5574000477790833,2.037686824798584,10000.0,36268.02729392052,37644.65886926651,36268.02729392052,1369.1742932796478,3.651136636734009,0.0 -106800,4.1594496,2.4995482,,,,,,,,,,,,,, -106900,3.5899763,2.407876,,,,,,,,,,,,,, -107000,3.9609718,2.5036159,,,,,,,,,,,,,, -107100,4.9813356,2.448769,,,,,,,,,,,,,, -107200,4.7012043,2.5008366,,,,,,,,,,,,,, -107300,4.004021,2.4064605,,,,,,,,,,,,,, -107400,3.9450777,2.3812072,,,,,,,,,,,,,, -107500,4.26982,2.3734193,,,,,,,,,,,,,, -107600,4.186048,2.4670138,,,,,,,,,,,,,, -107700,4.3734865,2.43565,,,,,,,,,,,,,, -107800,4.2751245,2.5302947,,,,,,,,,,,,,, -107900,4.194717,2.3363779,,,,,,,,,,,,,, -108000,3.9619808,2.5460887,,,,,,,,,,,,,, -108100,4.606566,2.399211,,,,,,,,,,,,,, -108200,4.187807,2.4537108,,,,,,,,,,,,,, -108205,,,0.7301897406578064,1.1509002447128296,0.6698799729347229,1.42633056640625,50000.0,0.5433000326156616,2.072263479232788,10000.0,36778.0198366642,38171.86295318604,36778.0198366642,1386.2812361717224,3.69752836227417,0.0 -108300,4.429505,2.5189362,,,,,,,,,,,,,, -108400,4.209603,2.4667063,,,,,,,,,,,,,, -108500,4.948783,2.4086685,,,,,,,,,,,,,, -108600,3.8250964,2.3949656,,,,,,,,,,,,,, -108700,4.0568123,2.4521923,,,,,,,,,,,,,, -108800,4.615379,2.483245,,,,,,,,,,,,,, -108900,4.2400947,2.3987994,,,,,,,,,,,,,, -109000,3.9823022,2.4137871,,,,,,,,,,,,,, -109100,4.604797,2.4570065,,,,,,,,,,,,,, -109200,4.0688777,2.3927724,,,,,,,,,,,,,, -109300,4.1585426,2.4121828,,,,,,,,,,,,,, -109400,4.334082,2.4192429,,,,,,,,,,,,,, -109500,4.656282,2.478746,,,,,,,,,,,,,, -109600,3.6550078,2.375256,,,,,,,,,,,,,, -109700,4.2302566,2.4756613,,,,,,,,,,,,,, -109710,,,0.7444196343421936,1.0855257511138916,0.6843999624252319,1.3646774291992188,50000.0,0.5539000034332275,2.023701429367065,10000.0,37288.15925168991,38699.43433403969,37288.15925168991,1403.6111969947815,3.74524450302124,0.0 -109800,4.0446277,2.4793026,,,,,,,,,,,,,, -109900,4.0051146,2.433888,,,,,,,,,,,,,, -110000,3.948592,2.4744234,,,,,,,,,,,,,, -110100,4.0913105,2.3764064,,,,,,,,,,,,,, -110200,3.9887471,2.3795018,,,,,,,,,,,,,, -110300,4.081305,2.3988788,,,,,,,,,,,,,, -110400,4.91312,2.4669955,,,,,,,,,,,,,, -110500,3.9996414,2.4198544,,,,,,,,,,,,,, -110600,4.3452387,2.5240517,,,,,,,,,,,,,, -110700,3.8848944,2.449475,,,,,,,,,,,,,, -110800,4.1750464,2.4582238,,,,,,,,,,,,,, -110900,4.339505,2.4578218,,,,,,,,,,,,,, -111000,3.9673865,2.3882618,,,,,,,,,,,,,, -111100,4.786189,2.4230273,,,,,,,,,,,,,, -111200,4.529336,2.5388384,,,,,,,,,,,,,, -111215,,,0.7463129758834839,1.0860766172409058,0.6835599541664124,1.3672128915786743,50000.0,0.5539000034332275,2.0336456298828125,10000.0,37798.17662835121,39227.00764942169,37798.17662835121,1421.063180923462,3.7959115505218506,0.0 -111300,3.9202876,2.3993487,,,,,,,,,,,,,, -111400,4.37099,2.3159442,,,,,,,,,,,,,, -111500,4.1669197,2.3434443,,,,,,,,,,,,,, -111600,4.309573,2.4598792,,,,,,,,,,,,,, -111700,4.0536675,2.3914616,,,,,,,,,,,,,, -111800,4.1695952,2.3691018,,,,,,,,,,,,,, -111900,4.373642,2.4282906,,,,,,,,,,,,,, -112000,4.8038535,2.4768982,,,,,,,,,,,,,, -112100,4.2008853,2.4559212,,,,,,,,,,,,,, -112200,4.020308,2.25111,,,,,,,,,,,,,, -112300,4.9129057,2.3999174,,,,,,,,,,,,,, -112400,4.1668878,2.4799833,,,,,,,,,,,,,, -112500,4.513907,2.4534323,,,,,,,,,,,,,, -112600,4.1173763,2.386625,,,,,,,,,,,,,, -112700,4.3170457,2.3834095,,,,,,,,,,,,,, -112720,,,0.7732580900192261,0.9794655442237854,0.6888999938964844,1.3610857725143433,50000.0,0.5567000508308411,2.027822256088257,10000.0,38308.20759654045,39754.66309762001,38308.20759654045,1438.587425947189,3.8423855304718018,0.0 -112800,4.4914494,2.4116912,,,,,,,,,,,,,, -112900,4.457342,2.468013,,,,,,,,,,,,,, -113000,4.2405324,2.4232838,,,,,,,,,,,,,, -113100,4.007741,2.4583368,,,,,,,,,,,,,, -113200,4.9186435,2.3384979,,,,,,,,,,,,,, -113300,4.3823977,2.3846226,,,,,,,,,,,,,, -113400,4.591521,2.4442024,,,,,,,,,,,,,, -113500,4.694924,2.5078118,,,,,,,,,,,,,, -113600,4.350431,2.2865794,,,,,,,,,,,,,, -113700,4.5775185,2.443256,,,,,,,,,,,,,, -113800,4.192428,2.3451521,,,,,,,,,,,,,, -113900,4.639268,2.4098692,,,,,,,,,,,,,, -114000,4.2639213,2.3631868,,,,,,,,,,,,,, -114100,4.359783,2.3731105,,,,,,,,,,,,,, -114200,4.9283915,2.5424063,,,,,,,,,,,,,, -114225,,,0.7552614808082581,1.0510189533233645,0.6833999752998352,1.376708984375,50000.0,0.557200014591217,2.0630359649658203,10000.0,38818.21781897545,40282.05206513405,38818.21781897545,1455.857293367386,3.897749662399292,0.0 -114300,4.008652,2.3458233,,,,,,,,,,,,,, -114400,4.2749023,2.416981,,,,,,,,,,,,,, -114500,4.3045945,2.4020543,,,,,,,,,,,,,, -114600,4.7750645,2.3779342,,,,,,,,,,,,,, -114700,4.6676292,2.4133885,,,,,,,,,,,,,, -114800,4.697245,2.4220521,,,,,,,,,,,,,, -114900,4.088782,2.332686,,,,,,,,,,,,,, -115000,5.29797,2.364975,,,,,,,,,,,,,, -115100,4.7466273,2.313993,,,,,,,,,,,,,, -115200,3.9832315,2.3544781,,,,,,,,,,,,,, -115300,4.331354,2.316708,,,,,,,,,,,,,, -115400,4.0579743,2.3310494,,,,,,,,,,,,,, -115500,4.206115,2.3902342,,,,,,,,,,,,,, -115600,4.7456584,2.3309867,,,,,,,,,,,,,, -115700,4.503324,2.3768506,,,,,,,,,,,,,, -115730,,,0.7583107352256775,1.042015790939331,0.6886799931526184,1.344862461090088,50000.0,0.5611000061035156,2.0068650245666504,10000.0,39328.42887806893,40809.46044826508,39328.42887806893,1472.9484844207764,3.9506664276123047,0.0 -115800,4.5516,2.3353286,,,,,,,,,,,,,, -115900,4.3607755,2.311841,,,,,,,,,,,,,, -116000,4.308072,2.3194997,,,,,,,,,,,,,, -116100,4.5352287,2.4382873,,,,,,,,,,,,,, -116200,4.396334,2.3757687,,,,,,,,,,,,,, -116300,4.745739,2.446779,,,,,,,,,,,,,, -116400,4.218837,2.3369427,,,,,,,,,,,,,, -116500,4.7668486,2.51052,,,,,,,,,,,,,, -116600,4.6662555,2.4507523,,,,,,,,,,,,,, -116700,4.4001846,2.376288,,,,,,,,,,,,,, -116800,4.796417,2.34741,,,,,,,,,,,,,, -116900,5.2394295,2.3995993,,,,,,,,,,,,,, -117000,4.2717023,2.3655086,,,,,,,,,,,,,, -117100,5.018251,2.3892844,,,,,,,,,,,,,, -117200,4.58093,2.3955393,,,,,,,,,,,,,, -117235,,,0.7524114847183228,1.064433455467224,0.6860799789428711,1.3524702787399292,50000.0,0.5618000030517578,2.0201923847198486,10000.0,39838.38658428192,41336.782682180405,39838.38658428192,1490.2100987434387,4.000177621841431,0.0 -117300,4.208919,2.3753715,,,,,,,,,,,,,, -117400,4.4134145,2.3533258,,,,,,,,,,,,,, -117500,4.245613,2.3901978,,,,,,,,,,,,,, -117600,4.588048,2.3756182,,,,,,,,,,,,,, -117700,4.7308674,2.3104203,,,,,,,,,,,,,, -117800,4.6971436,2.396932,,,,,,,,,,,,,, -117900,4.419909,2.3685055,,,,,,,,,,,,,, -118000,5.490351,2.3379564,,,,,,,,,,,,,, -118100,4.646319,2.356968,,,,,,,,,,,,,, -118200,4.8881044,2.4083698,,,,,,,,,,,,,, -118300,4.2276697,2.2676303,,,,,,,,,,,,,, -118400,4.595452,2.3855362,,,,,,,,,,,,,, -118500,4.550757,2.44137,,,,,,,,,,,,,, -118600,4.391236,2.3738222,,,,,,,,,,,,,, -118700,4.383045,2.309486,,,,,,,,,,,,,, -118739,,,0.7526506781578064,1.0996969938278198,0.6875199675559998,1.3770368099212646,50000.0,0.565500020980835,2.018036365509033,10000.0,40348.35553407669,41864.03025341034,40348.35553407669,1507.3889908790588,4.046831607818604,0.0 -118800,4.666704,2.4416544,,,,,,,,,,,,,, -118900,4.420248,2.3164127,,,,,,,,,,,,,, -119000,4.190904,2.373224,,,,,,,,,,,,,, -119100,4.724602,2.3887343,,,,,,,,,,,,,, -119200,5.0189443,2.3793192,,,,,,,,,,,,,, -119300,5.263013,2.314363,,,,,,,,,,,,,, -119400,4.87897,2.3798594,,,,,,,,,,,,,, -119500,4.157686,2.3000956,,,,,,,,,,,,,, -119600,4.7238927,2.4127693,,,,,,,,,,,,,, -119700,4.533367,2.3942735,,,,,,,,,,,,,, -119800,4.302761,2.3021128,,,,,,,,,,,,,, -119900,4.5024767,2.3801,,,,,,,,,,,,,, -120000,4.5732813,2.3351765,,,,,,,,,,,,,, -120100,4.7347417,2.2994313,,,,,,,,,,,,,, -120200,4.8370137,2.3727207,,,,,,,,,,,,,, -120244,,,0.7641701102256775,1.0034244060516355,0.699999988079071,1.2891603708267212,50000.0,0.5662000179290771,1.9515352249145508,10000.0,40858.48894238472,42391.6834628582,40858.48894238472,1524.8025405406952,4.097309589385986,0.0 -120300,4.420705,2.32165,,,,,,,,,,,,,, -120400,4.613088,2.443376,,,,,,,,,,,,,, -120500,4.650066,2.32318,,,,,,,,,,,,,, -120600,4.7537374,2.2520192,,,,,,,,,,,,,, -120700,4.4548907,2.3958418,,,,,,,,,,,,,, -120800,4.471038,2.322901,,,,,,,,,,,,,, -120900,5.0525956,2.3324323,,,,,,,,,,,,,, -121000,4.4347143,2.3195665,,,,,,,,,,,,,, -121100,5.092274,2.273209,,,,,,,,,,,,,, -121200,4.795506,2.389999,,,,,,,,,,,,,, -121300,4.955557,2.3887253,,,,,,,,,,,,,, -121400,4.2594504,2.3097582,,,,,,,,,,,,,, -121500,4.8166223,2.2453113,,,,,,,,,,,,,, -121600,4.6073766,2.3670652,,,,,,,,,,,,,, -121700,4.6441803,2.3365943,,,,,,,,,,,,,, -121749,,,0.7843789458274841,0.9068344831466676,0.6967200040817261,1.3004614114761353,50000.0,0.5667999982833862,1.9649931192398071,10000.0,41368.59728837013,42919.34297060967,41368.59728837013,1542.2404384613037,4.155265808105469,0.0 -121800,4.398267,2.2156725,,,,,,,,,,,,,, -121900,4.453065,2.3686624,,,,,,,,,,,,,, -122000,5.187709,2.3699925,,,,,,,,,,,,,, -122100,4.496101,2.2821565,,,,,,,,,,,,,, -122200,4.966672,2.3860695,,,,,,,,,,,,,, -122300,4.547307,2.3242626,,,,,,,,,,,,,, -122400,4.6262193,2.3581681,,,,,,,,,,,,,, -122500,4.6912327,2.2578764,,,,,,,,,,,,,, -122600,4.5577216,2.3684392,,,,,,,,,,,,,, -122700,5.0645995,2.2827008,,,,,,,,,,,,,, -122800,4.6117916,2.3029068,,,,,,,,,,,,,, -122900,4.964198,2.3382568,,,,,,,,,,,,,, -123000,4.019436,2.230982,,,,,,,,,,,,,, -123100,4.8701835,2.3462505,,,,,,,,,,,,,, -123200,4.3812943,2.2899518,,,,,,,,,,,,,, -123254,,,0.78324294090271,0.9465728998184204,0.7019199728965759,1.2940689325332642,50000.0,0.5749000310897827,1.9503682851791384,10000.0,41878.64205765724,43446.8511133194,41878.64205765724,1559.600687265396,4.203623294830322,0.0 -123300,4.2878976,2.3675447,,,,,,,,,,,,,, -123400,4.472567,2.330868,,,,,,,,,,,,,, -123500,4.658865,2.3141286,,,,,,,,,,,,,, -123600,5.3841786,2.3762665,,,,,,,,,,,,,, -123700,4.8647,2.3943236,,,,,,,,,,,,,, -123800,4.87672,2.2811244,,,,,,,,,,,,,, -123900,4.796325,2.3376882,,,,,,,,,,,,,, -124000,4.629236,2.3070335,,,,,,,,,,,,,, -124100,4.453521,2.3466268,,,,,,,,,,,,,, -124200,4.571991,2.3614333,,,,,,,,,,,,,, -124300,4.659864,2.336448,,,,,,,,,,,,,, -124400,5.025426,2.3058844,,,,,,,,,,,,,, -124500,4.5671487,2.2815804,,,,,,,,,,,,,, -124600,4.790898,2.269117,,,,,,,,,,,,,, -124700,4.60785,2.2865057,,,,,,,,,,,,,, -124759,,,0.7840202450752258,0.9154542088508606,0.7075799703598022,1.2590969800949097,50000.0,0.578000009059906,1.9135336875915527,10000.0,42388.77073264122,43974.48748540878,42388.77073264122,1577.003856897354,4.252768278121948,0.0 -124800,4.8276367,2.2290776,,,,,,,,,,,,,, -124900,4.8502593,2.3421338,,,,,,,,,,,,,, -125000,5.010198,2.287698,,,,,,,,,,,,,, -125100,5.474386,2.2547166,,,,,,,,,,,,,, -125200,4.4719505,2.2375724,,,,,,,,,,,,,, -125300,4.70017,2.303416,,,,,,,,,,,,,, -125400,4.958733,2.341085,,,,,,,,,,,,,, -125500,5.20532,2.3172965,,,,,,,,,,,,,, -125600,4.6729827,2.2893274,,,,,,,,,,,,,, -125700,5.0511427,2.3151052,,,,,,,,,,,,,, -125800,4.9423,2.313988,,,,,,,,,,,,,, -125900,4.664368,2.2867227,,,,,,,,,,,,,, -126000,4.986156,2.3206315,,,,,,,,,,,,,, -126100,4.5388985,2.2719646,,,,,,,,,,,,,, -126200,4.862328,2.3068676,,,,,,,,,,,,,, -126264,,,0.7784199714660645,0.9533615708351136,0.7020599842071533,1.289789795875549,50000.0,0.5743000507354736,1.9358104467391968,10000.0,42898.89861416817,44502.21145987511,42898.89861416817,1594.5007934570312,4.299294471740723,0.0 -126300,4.9979496,2.2765105,,,,,,,,,,,,,, -126400,4.825022,2.2621782,,,,,,,,,,,,,, -126500,4.627394,2.186074,,,,,,,,,,,,,, -126600,5.1159716,2.3519177,,,,,,,,,,,,,, -126700,4.7346916,2.3153162,,,,,,,,,,,,,, -126800,5.0697136,2.341482,,,,,,,,,,,,,, -126900,5.108167,2.294845,,,,,,,,,,,,,, -127000,5.1327424,2.2606351,,,,,,,,,,,,,, -127100,4.7990537,2.336836,,,,,,,,,,,,,, -127200,4.630134,2.199642,,,,,,,,,,,,,, -127300,4.758153,2.2845669,,,,,,,,,,,,,, -127400,4.8880677,2.2594225,,,,,,,,,,,,,, -127500,5.3167367,2.2602413,,,,,,,,,,,,,, -127600,5.0055513,2.250153,,,,,,,,,,,,,, -127700,4.498688,2.2050347,,,,,,,,,,,,,, -127768,,,0.7771444320678711,0.9476613402366638,0.7056399583816528,1.2707682847976685,50000.0,0.5785000324249268,1.9376139640808103,10000.0,43408.88532400131,45029.88151431084,43408.88532400131,1612.0683364868164,4.361209392547607,0.0 -127800,5.106462,2.3619015,,,,,,,,,,,,,, -127900,4.542871,2.209622,,,,,,,,,,,,,, -128000,5.28678,2.3215687,,,,,,,,,,,,,, -128100,5.362543,2.2313735,,,,,,,,,,,,,, -128200,5.1101065,2.2683396,,,,,,,,,,,,,, -128300,5.279693,2.2864249,,,,,,,,,,,,,, -128400,5.202924,2.2188501,,,,,,,,,,,,,, -128500,4.8550053,2.290084,,,,,,,,,,,,,, -128600,4.722277,2.225182,,,,,,,,,,,,,, -128700,5.533536,2.224698,,,,,,,,,,,,,, -128800,5.3688774,2.394371,,,,,,,,,,,,,, -128900,4.94059,2.292315,,,,,,,,,,,,,, -129000,5.439989,2.2910395,,,,,,,,,,,,,, -129100,4.6988325,2.2560341,,,,,,,,,,,,,, -129200,5.1186943,2.2324257,,,,,,,,,,,,,, -129273,,,0.7789580225944519,0.9515725374221802,0.7039799690246582,1.283861756324768,50000.0,0.5729000568389893,1.9563974142074585,10000.0,43919.058978796005,45557.30363321304,43919.058978796005,1629.2153718471527,4.409427642822266,0.0 -129300,5.613831,2.3808641,,,,,,,,,,,,,, -129400,5.41602,2.2857647,,,,,,,,,,,,,, -129500,5.548021,2.3537383,,,,,,,,,,,,,, -129600,4.61096,2.2617047,,,,,,,,,,,,,, -129700,4.869393,2.1989193,,,,,,,,,,,,,, -129800,5.8016086,2.30093,,,,,,,,,,,,,, -129900,4.9968486,2.2190425,,,,,,,,,,,,,, -130000,5.4415665,2.2680268,,,,,,,,,,,,,, -130100,5.0609956,2.2005374,,,,,,,,,,,,,, -130200,5.42716,2.213676,,,,,,,,,,,,,, -130300,5.286772,2.3117661,,,,,,,,,,,,,, -130400,5.1191287,2.2723546,,,,,,,,,,,,,, -130500,5.3183885,2.370761,,,,,,,,,,,,,, -130600,4.931729,2.3449214,,,,,,,,,,,,,, -130700,4.7188864,2.1922154,,,,,,,,,,,,,, -130777,,,0.8083147406578064,0.815697968006134,0.7135199904441833,1.2330472469329834,50000.0,0.5872000455856323,1.876109480857849,10000.0,44429.077756643295,46084.88196539879,44429.077756643295,1646.6670017242432,4.463230848312378,0.0 -130800,5.2321844,2.2274272,,,,,,,,,,,,,, -130900,5.384903,2.2945938,,,,,,,,,,,,,, -131000,5.463231,2.24894,,,,,,,,,,,,,, -131100,5.2304397,2.2314744,,,,,,,,,,,,,, -131200,4.874256,2.2116396,,,,,,,,,,,,,, -131300,5.2769265,2.2460852,,,,,,,,,,,,,, -131400,5.604053,2.2447674,,,,,,,,,,,,,, -131500,4.980529,2.1411715,,,,,,,,,,,,,, -131600,5.5555553,2.2522693,,,,,,,,,,,,,, -131700,5.057232,2.1527376,,,,,,,,,,,,,, -131800,5.057864,2.1284332,,,,,,,,,,,,,, -131900,4.8060884,2.2488148,,,,,,,,,,,,,, -132000,4.9733005,2.2347696,,,,,,,,,,,,,, -132100,5.2373834,2.300879,,,,,,,,,,,,,, -132200,4.9796634,2.302579,,,,,,,,,,,,,, -132282,,,0.8037906289100647,0.8384227156639099,0.7140399813652039,1.224155306816101,50000.0,0.5893000364303589,1.8553305864334104,10000.0,44939.26374220848,46612.58217954636,44939.26374220848,1664.0718927383425,4.518594264984131,0.0 -132300,5.2246814,2.286631,,,,,,,,,,,,,, -132400,4.965164,2.2392247,,,,,,,,,,,,,, -132500,5.563509,2.2652144,,,,,,,,,,,,,, -132600,5.639295,2.212163,,,,,,,,,,,,,, -132700,5.158884,2.2136922,,,,,,,,,,,,,, -132800,5.4193497,2.2372005,,,,,,,,,,,,,, -132900,5.17646,2.2593608,,,,,,,,,,,,,, -133000,5.3723927,2.3002548,,,,,,,,,,,,,, -133100,5.7173195,2.2017324,,,,,,,,,,,,,, -133200,5.1515274,2.2395895,,,,,,,,,,,,,, -133300,5.0673223,2.1803875,,,,,,,,,,,,,, -133400,5.359608,2.2198017,,,,,,,,,,,,,, -133500,5.5214653,2.187059,,,,,,,,,,,,,, -133600,5.408185,2.2123847,,,,,,,,,,,,,, -133700,4.916724,2.1351528,,,,,,,,,,,,,, -133787,,,0.8019172549247742,0.8607712388038635,0.715499997138977,1.2324215173721311,50000.0,0.5929000377655029,1.859167098999024,10000.0,45449.338752985,47140.23160409928,45449.338752985,1681.5357718467712,4.574284315109253,0.0 -133800,5.104365,2.2754848,,,,,,,,,,,,,, -133900,4.9280868,2.230206,,,,,,,,,,,,,, -134000,5.6506524,2.2931926,,,,,,,,,,,,,, -134100,5.89618,2.2985399,,,,,,,,,,,,,, -134200,5.727422,2.2475147,,,,,,,,,,,,,, -134300,5.5170007,2.2146316,,,,,,,,,,,,,, -134400,5.3479166,2.1691165,,,,,,,,,,,,,, -134500,5.128697,2.1558022,,,,,,,,,,,,,, -134600,5.1602473,2.1569457,,,,,,,,,,,,,, -134700,5.8180285,2.2372534,,,,,,,,,,,,,, -134800,5.306554,2.2398067,,,,,,,,,,,,,, -134900,5.0808735,2.2622247,,,,,,,,,,,,,, -135000,5.3574867,2.1587718,,,,,,,,,,,,,, -135100,6.395025,2.2447324,,,,,,,,,,,,,, -135200,5.6991906,2.2640243,,,,,,,,,,,,,, -135292,,,0.8028938174247742,0.844454824924469,0.7201399803161621,1.20967698097229,50000.0,0.5944000482559204,1.8514662981033323,10000.0,45959.35618376732,47667.49844503403,45959.35618376732,1698.6805226802826,4.625326871871948,0.0 -135300,5.4163146,2.252962,,,,,,,,,,,,,, -135400,5.318693,2.2546937,,,,,,,,,,,,,, -135500,5.0962152,2.1612544,,,,,,,,,,,,,, -135600,5.316424,2.2255752,,,,,,,,,,,,,, -135700,5.4429345,2.2138252,,,,,,,,,,,,,, -135800,5.40668,2.1977375,,,,,,,,,,,,,, -135900,4.797821,2.1369195,,,,,,,,,,,,,, -136000,4.9281883,2.1719944,,,,,,,,,,,,,, -136100,5.1266103,2.2105818,,,,,,,,,,,,,, -136200,5.2548957,2.1368675,,,,,,,,,,,,,, -136300,5.1782293,2.1255798,,,,,,,,,,,,,, -136400,5.0267825,2.1752,,,,,,,,,,,,,, -136500,6.136671,2.261764,,,,,,,,,,,,,, -136600,5.547479,2.2189517,,,,,,,,,,,,,, -136700,5.7056475,2.1533148,,,,,,,,,,,,,, -136797,,,0.8018972873687744,0.8656412959098816,0.7197399735450745,1.2183958292007446,50000.0,0.5927000045776367,1.855873107910156,10000.0,46469.46252202988,48194.80530762672,46469.46252202988,1715.7776033878326,4.675408601760864,0.0 -136800,5.2148523,2.1268716,,,,,,,,,,,,,, -136900,5.1548657,2.0741458,,,,,,,,,,,,,, -137000,5.1202674,2.2907715,,,,,,,,,,,,,, -137100,5.7240753,2.2142267,,,,,,,,,,,,,, -137200,6.27016,2.223142,,,,,,,,,,,,,, -137300,5.3461556,2.101502,,,,,,,,,,,,,, -137400,5.444846,2.1965177,,,,,,,,,,,,,, -137500,5.4356456,2.1337101,,,,,,,,,,,,,, -137600,5.567657,2.1927104,,,,,,,,,,,,,, -137700,5.776443,2.237841,,,,,,,,,,,,,, -137800,5.97406,2.1693926,,,,,,,,,,,,,, -137900,5.108037,2.2034235,,,,,,,,,,,,,, -138000,5.7643056,2.1489203,,,,,,,,,,,,,, -138100,5.443096,2.2416747,,,,,,,,,,,,,, -138200,5.4835258,2.172433,,,,,,,,,,,,,, -138300,5.752432,2.2187212,,,,,,,,,,,,,, -138303,,,0.8061822056770325,0.8341156244277954,0.7231799960136414,1.196844220161438,50000.0,0.6022000312805176,1.815556883811951,10000.0,46979.67847561836,48722.46280956268,46979.67847561836,1733.1091315746307,4.73111629486084,0.0 -138400,5.2577252,2.1632998,,,,,,,,,,,,,, -138500,5.623497,2.1431754,,,,,,,,,,,,,, -138600,5.4374237,2.1376922,,,,,,,,,,,,,, -138700,5.419795,2.1198344,,,,,,,,,,,,,, -138800,5.065165,2.1610324,,,,,,,,,,,,,, -138900,5.415762,2.1763048,,,,,,,,,,,,,, -139000,5.5281935,2.2237878,,,,,,,,,,,,,, -139100,5.4183316,2.17972,,,,,,,,,,,,,, -139200,5.9004292,2.160263,,,,,,,,,,,,,, -139300,5.701277,2.1923747,,,,,,,,,,,,,, -139400,5.3014145,2.122593,,,,,,,,,,,,,, -139500,5.644303,2.1328259,,,,,,,,,,,,,, -139600,5.41795,2.1806056,,,,,,,,,,,,,, -139700,5.457916,2.194757,,,,,,,,,,,,,, -139800,5.561614,2.163284,,,,,,,,,,,,,, -139808,,,0.8362762928009033,0.7129075527191162,0.7239599823951721,1.182400465011597,50000.0,0.601900041103363,1.808569073677063,10000.0,47489.7852473259,49249.88934969902,47489.7852473259,1750.324533700943,4.7820470333099365,0.0 -139900,5.7862067,2.1577559,,,,,,,,,,,,,, -140000,5.6264496,2.1794648,,,,,,,,,,,,,, -140100,5.7150207,2.1576529,,,,,,,,,,,,,, -140200,5.572227,2.147622,,,,,,,,,,,,,, -140300,6.4885798,2.1678212,,,,,,,,,,,,,, -140400,6.0933776,2.1869278,,,,,,,,,,,,,, -140500,6.5699477,2.1740685,,,,,,,,,,,,,, -140600,5.408058,2.108539,,,,,,,,,,,,,, -140700,6.30381,2.1773376,,,,,,,,,,,,,, -140800,6.0254364,2.1543987,,,,,,,,,,,,,, -140900,6.3826213,2.2155685,,,,,,,,,,,,,, -141000,5.8154497,2.1860619,,,,,,,,,,,,,, -141100,5.394862,2.1520498,,,,,,,,,,,,,, -141200,5.678049,2.1174471,,,,,,,,,,,,,, -141300,5.970764,2.1245964,,,,,,,,,,,,,, -141313,,,0.8279455900192261,0.7402110695838928,0.727679967880249,1.1678305864334106,50000.0,0.6055999994277954,1.8135945796966555,10000.0,47999.80777215958,49777.51341342926,47999.80777215958,1767.8215091228485,4.833433866500855,0.0 -141400,5.163781,2.0940416,,,,,,,,,,,,,, -141500,6.243452,2.1048117,,,,,,,,,,,,,, -141600,6.025624,2.1030846,,,,,,,,,,,,,, -141700,5.6682525,2.1833055,,,,,,,,,,,,,, -141800,5.9326906,2.1658545,,,,,,,,,,,,,, -141900,5.972759,2.0877402,,,,,,,,,,,,,, -142000,5.8745008,2.0502944,,,,,,,,,,,,,, -142100,5.3777165,2.1059628,,,,,,,,,,,,,, -142200,6.0749583,2.1217175,,,,,,,,,,,,,, -142300,5.8638744,2.2037299,,,,,,,,,,,,,, -142400,5.9008365,2.1005623,,,,,,,,,,,,,, -142500,5.6063104,2.1258473,,,,,,,,,,,,,, -142600,5.847613,2.0949013,,,,,,,,,,,,,, -142700,6.7536263,2.1324372,,,,,,,,,,,,,, -142800,5.4808774,2.1274736,,,,,,,,,,,,,, -142818,,,0.8234215378761292,0.757717490196228,0.729479968547821,1.166062593460083,50000.0,0.6040000319480896,1.7834153175354004,10000.0,48510.02785205841,50304.98026776314,48510.02785205841,1784.957479953766,4.888729572296143,0.0 -142900,6.0851474,2.2575433,,,,,,,,,,,,,, -143000,5.756905,2.1473808,,,,,,,,,,,,,, -143100,6.055606,2.0758567,,,,,,,,,,,,,, -143200,5.7113633,2.066846,,,,,,,,,,,,,, -143300,6.0685987,2.153519,,,,,,,,,,,,,, -143400,5.94016,2.1063855,,,,,,,,,,,,,, -143500,5.971694,2.0781527,,,,,,,,,,,,,, -143600,5.7421436,2.174855,,,,,,,,,,,,,, -143700,5.883481,2.1247258,,,,,,,,,,,,,, -143800,5.903039,2.0574822,,,,,,,,,,,,,, -143900,6.4291964,2.105908,,,,,,,,,,,,,, -144000,6.32089,2.1796126,,,,,,,,,,,,,, -144100,6.5597,2.1733801,,,,,,,,,,,,,, -144200,6.0736794,2.1185133,,,,,,,,,,,,,, -144300,5.8524528,2.1137102,,,,,,,,,,,,,, -144323,,,0.832429826259613,0.7265508770942688,0.7337599992752075,1.1390953063964844,50000.0,0.6178000569343567,1.7489664554595947,10000.0,49020.1178920269,50832.58455181122,49020.1178920269,1802.3611364364624,4.945554733276367,0.0 -144400,5.5037284,2.1248796,,,,,,,,,,,,,, -144500,5.3645644,2.0961473,,,,,,,,,,,,,, -144600,5.7413087,2.112318,,,,,,,,,,,,,, -144700,6.842453,2.1577806,,,,,,,,,,,,,, -144800,6.6438007,2.1442924,,,,,,,,,,,,,, -144900,6.284237,2.1653333,,,,,,,,,,,,,, -145000,6.446436,2.1242301,,,,,,,,,,,,,, -145100,5.654735,2.090189,,,,,,,,,,,,,, -145200,6.0818458,2.0917506,,,,,,,,,,,,,, -145300,6.2915764,2.1178772,,,,,,,,,,,,,, -145400,5.8237543,2.0366747,,,,,,,,,,,,,, -145500,5.6851535,2.053378,,,,,,,,,,,,,, -145600,6.299025,2.0440865,,,,,,,,,,,,,, -145700,6.659218,2.0855267,,,,,,,,,,,,,, -145800,6.2353525,2.175056,,,,,,,,,,,,,, -145828,,,0.8336455225944519,0.722388744354248,0.7360000014305115,1.13362717628479,50000.0,0.6071000099182129,1.7726635932922363,10000.0,49530.3389377594,51360.16606426239,49530.3389377594,1819.614022731781,4.998095750808716,0.0 -145900,5.4222913,2.0795345,,,,,,,,,,,,,, -146000,6.0110803,2.1418276,,,,,,,,,,,,,, -146100,6.306202,2.0629776,,,,,,,,,,,,,, -146200,5.9350896,2.1075516,,,,,,,,,,,,,, -146300,5.8980007,2.0607815,,,,,,,,,,,,,, -146400,6.490942,2.0791361,,,,,,,,,,,,,, -146500,6.192542,2.0834825,,,,,,,,,,,,,, -146600,5.934211,2.1013422,,,,,,,,,,,,,, -146700,6.0404177,2.060998,,,,,,,,,,,,,, -146800,6.428039,2.0893795,,,,,,,,,,,,,, -146900,5.717761,2.123192,,,,,,,,,,,,,, -147000,6.1645966,2.1391509,,,,,,,,,,,,,, -147100,5.8983736,2.0878897,,,,,,,,,,,,,, -147200,5.88666,2.0578737,,,,,,,,,,,,,, -147300,5.751424,2.0014353,,,,,,,,,,,,,, -147333,,,0.8375119566917419,0.7116343975067139,0.7391600012779236,1.125153422355652,50000.0,0.6159000396728516,1.7439031600952148,10000.0,50040.367628097534,51887.47683286667,50040.367628097534,1836.791166067124,5.049542188644409,0.0 -147400,6.0979757,2.0550625,,,,,,,,,,,,,, -147500,5.7545805,2.0609365,,,,,,,,,,,,,, -147600,6.326064,2.0807886,,,,,,,,,,,,,, -147700,6.535182,2.109924,,,,,,,,,,,,,, -147800,6.920758,2.1032608,,,,,,,,,,,,,, -147900,6.230408,2.1299796,,,,,,,,,,,,,, -148000,5.918161,2.0534117,,,,,,,,,,,,,, -148100,5.9308424,2.0307183,,,,,,,,,,,,,, -148200,6.4606357,2.0064988,,,,,,,,,,,,,, -148300,6.088585,1.966887,,,,,,,,,,,,,, -148400,6.7435055,2.0380852,,,,,,,,,,,,,, -148500,5.891897,2.0663047,,,,,,,,,,,,,, -148600,6.1743646,2.0512362,,,,,,,,,,,,,, -148700,5.954997,2.0923982,,,,,,,,,,,,,, -148800,6.416053,1.9867874,,,,,,,,,,,,,, -148837,,,0.8622449040412903,0.6133614182472229,0.7362200021743774,1.1291574239730835,50000.0,0.615600049495697,1.7641743421554563,10000.0,50550.32852196693,52415.02962350845,50550.32852196693,1854.272777080536,5.106694459915161,0.0 -148900,5.9761186,2.0164483,,,,,,,,,,,,,, -149000,6.3635063,2.1314094,,,,,,,,,,,,,, -149100,6.192904,2.0987217,,,,,,,,,,,,,, -149200,5.855015,2.0058079,,,,,,,,,,,,,, -149300,6.61208,2.076858,,,,,,,,,,,,,, -149400,5.57979,2.0426455,,,,,,,,,,,,,, -149500,6.6326003,2.0771682,,,,,,,,,,,,,, -149600,5.7814627,2.0859766,,,,,,,,,,,,,, -149700,5.890029,2.0642238,,,,,,,,,,,,,, -149800,7.3917885,2.093988,,,,,,,,,,,,,, -149900,6.1335583,2.0291102,,,,,,,,,,,,,, -150000,7.0618625,2.0505009,,,,,,,,,,,,,, -150100,6.3971643,2.0185094,,,,,,,,,,,,,, -150200,6.2738657,1.9714885,,,,,,,,,,,,,, -150300,6.352561,2.034326,,,,,,,,,,,,,, -150342,,,0.8563655614852905,0.637763261795044,0.7451800107955933,1.1021682024002075,50000.0,0.6177000403404236,1.7354857921600342,10000.0,51060.45466089249,52942.5445561409,51060.45466089249,1871.5532939434047,5.160580635070801,0.0 -150400,6.2074604,1.9671754,,,,,,,,,,,,,, -150500,6.1201267,2.0684104,,,,,,,,,,,,,, -150600,6.799393,2.119086,,,,,,,,,,,,,, -150700,6.4630065,1.9972832,,,,,,,,,,,,,, -150800,6.6896377,1.9780527,,,,,,,,,,,,,, -150900,7.0746126,2.1612976,,,,,,,,,,,,,, -151000,5.944238,1.9736392,,,,,,,,,,,,,, -151100,6.5519724,2.0652404,,,,,,,,,,,,,, -151200,6.3171186,2.014799,,,,,,,,,,,,,, -151300,5.8540597,2.026155,,,,,,,,,,,,,, -151400,6.0591183,2.061812,,,,,,,,,,,,,, -151500,6.8207464,2.048217,,,,,,,,,,,,,, -151600,6.750732,2.0882978,,,,,,,,,,,,,, -151700,6.0551333,2.006477,,,,,,,,,,,,,, -151800,6.51437,2.0285587,,,,,,,,,,,,,, -151846,,,0.8573222160339355,0.6447144746780396,0.7456799745559692,1.0978809595108032,50000.0,0.6177000403404236,1.743034839630127,10000.0,51570.480170726776,53470.75521993637,51570.480170726776,1889.6330585479736,5.211879253387451,0.0 -151900,7.163013,2.0625062,,,,,,,,,,,,,, -152000,6.6871758,2.0965295,,,,,,,,,,,,,, -152100,6.672937,2.09176,,,,,,,,,,,,,, -152200,7.005534,2.0548832,,,,,,,,,,,,,, -152300,6.2637444,2.0592413,,,,,,,,,,,,,, -152400,6.474575,1.9937106,,,,,,,,,,,,,, -152500,6.641602,1.9861444,,,,,,,,,,,,,, -152600,5.8861337,2.0665584,,,,,,,,,,,,,, -152700,6.305742,1.9851153,,,,,,,,,,,,,, -152800,6.7402472,2.0195298,,,,,,,,,,,,,, -152900,6.6136203,2.0515268,,,,,,,,,,,,,, -153000,6.9309115,2.0377002,,,,,,,,,,,,,, -153100,6.4018874,2.0426085,,,,,,,,,,,,,, -153200,6.443773,1.963356,,,,,,,,,,,,,, -153300,6.981008,2.0627193,,,,,,,,,,,,,, -153351,,,0.8552295565605164,0.6356821060180664,0.7467799782752991,1.086501955986023,50000.0,0.6243000030517578,1.7114746570587158,10000.0,52080.545378923416,53998.2950387001,52080.545378923416,1907.001557826996,5.265218734741211,0.0 -153400,6.483533,2.0847015,,,,,,,,,,,,,, -153500,6.2677526,2.0428674,,,,,,,,,,,,,, -153600,6.3904796,2.0114439,,,,,,,,,,,,,, -153700,6.4014983,1.9534287,,,,,,,,,,,,,, -153800,6.362013,1.9558102,,,,,,,,,,,,,, -153900,6.185624,1.9705725,,,,,,,,,,,,,, -154000,6.7213492,2.00354,,,,,,,,,,,,,, -154100,6.866757,2.0987394,,,,,,,,,,,,,, -154200,6.3227997,1.9348947,,,,,,,,,,,,,, -154300,6.9180837,2.0410311,,,,,,,,,,,,,, -154400,6.6733475,1.8990276,,,,,,,,,,,,,, -154500,7.313526,2.0484838,,,,,,,,,,,,,, -154600,5.859514,1.9189407,,,,,,,,,,,,,, -154700,6.4142566,1.9280264,,,,,,,,,,,,,, -154800,6.232276,2.05801,,,,,,,,,,,,,, -154855,,,0.8594945669174194,0.6123570203781128,0.7492199540138245,1.0768321752548218,50000.0,0.6264000535011292,1.6894590854644775,10000.0,52590.51016449928,54525.517634391785,52590.51016449928,1924.1550514698029,5.317140579223633,0.0 -154900,6.686344,1.8999658,,,,,,,,,,,,,, -155000,6.8440685,1.9895306,,,,,,,,,,,,,, -155100,6.8847003,1.972703,,,,,,,,,,,,,, -155200,6.7650175,2.0123756,,,,,,,,,,,,,, -155300,6.8801126,2.0378067,,,,,,,,,,,,,, -155400,6.841871,1.9933558,,,,,,,,,,,,,, -155500,6.5039644,1.9166608,,,,,,,,,,,,,, -155600,6.7218127,1.9878051,,,,,,,,,,,,,, -155700,6.5578933,2.0162733,,,,,,,,,,,,,, -155800,6.520502,2.022991,,,,,,,,,,,,,, -155900,7.170797,2.0392828,,,,,,,,,,,,,, -156000,6.5360036,1.9823838,,,,,,,,,,,,,, -156100,6.154618,1.9346173,,,,,,,,,,,,,, -156200,7.0851984,1.9829315,,,,,,,,,,,,,, -156300,7.1644044,1.999844,,,,,,,,,,,,,, -156359,,,0.861348032951355,0.6097769141197205,0.7504799962043762,1.0702511072158811,50000.0,0.628600001335144,1.6836930513381958,10000.0,53100.41500091553,55053.556411504745,53100.41500091553,1942.180507183075,5.372523307800293,0.0 -156400,6.653852,1.9742098,,,,,,,,,,,,,, -156500,6.330656,1.9555953,,,,,,,,,,,,,, -156600,6.5708876,2.001463,,,,,,,,,,,,,, -156700,6.8088927,1.9739949,,,,,,,,,,,,,, -156800,7.074761,1.9538369,,,,,,,,,,,,,, -156900,6.8856134,1.9231129,,,,,,,,,,,,,, -157000,7.2616405,1.988369,,,,,,,,,,,,,, -157100,7.5836673,2.0223205,,,,,,,,,,,,,, -157200,7.072617,1.9152589,,,,,,,,,,,,,, -157300,6.5945883,1.9832546,,,,,,,,,,,,,, -157400,6.9752884,1.9765414,,,,,,,,,,,,,, -157500,6.9430842,1.9080096,,,,,,,,,,,,,, -157600,6.63391,1.9785769,,,,,,,,,,,,,, -157700,7.04336,1.9935988,,,,,,,,,,,,,, -157800,6.8456826,1.9443922,,,,,,,,,,,,,, -157863,,,0.8867984414100647,0.5364611744880676,0.7528600096702576,1.0770390033721924,50000.0,0.6276000142097473,1.6905221939086914,10000.0,53610.552248716354,55581.18113279343,53610.552248716354,1959.568165779113,5.417553663253784,0.0 -157900,7.450242,2.0183833,,,,,,,,,,,,,, -158000,7.1071115,1.9063903,,,,,,,,,,,,,, -158100,7.177637,1.9524657,,,,,,,,,,,,,, -158200,7.082538,1.9625375,,,,,,,,,,,,,, -158300,6.706003,1.9919277,,,,,,,,,,,,,, -158400,7.098618,1.9974782,,,,,,,,,,,,,, -158500,7.3510575,2.0923338,,,,,,,,,,,,,, -158600,6.896576,1.9268268,,,,,,,,,,,,,, -158700,6.9860373,1.9679286,,,,,,,,,,,,,, -158800,7.225819,1.9850488,,,,,,,,,,,,,, -158900,6.610016,1.9929831,,,,,,,,,,,,,, -159000,7.823366,1.9883517,,,,,,,,,,,,,, -159100,7.561599,1.9595453,,,,,,,,,,,,,, -159200,6.808052,1.941611,,,,,,,,,,,,,, -159300,6.996288,1.9700131,,,,,,,,,,,,,, -159368,,,0.8807397484779358,0.5454962849617004,0.7541199922561646,1.0671308040618896,50000.0,0.6304000020027161,1.6812326908111572,10000.0,54120.69530892372,56108.90738940239,54120.69530892372,1977.037132024765,5.476731061935425,0.0 -159400,6.4202642,1.9440321,,,,,,,,,,,,,, -159500,6.556852,1.9625278,,,,,,,,,,,,,, -159600,7.174051,1.9374628,,,,,,,,,,,,,, -159700,7.3567095,1.9443402,,,,,,,,,,,,,, -159800,6.8144884,1.9633579,,,,,,,,,,,,,, -159900,6.9918823,1.9797934,,,,,,,,,,,,,, -160000,6.8617773,2.021132,,,,,,,,,,,,,, -160100,7.040738,1.9366925,,,,,,,,,,,,,, -160200,6.599379,1.8969954,,,,,,,,,,,,,, -160300,7.634989,1.9912617,,,,,,,,,,,,,, -160400,7.601546,1.9478108,,,,,,,,,,,,,, -160500,7.177368,1.9552629,,,,,,,,,,,,,, -160600,6.891651,1.9032317,,,,,,,,,,,,,, -160700,7.4275827,1.9188305,,,,,,,,,,,,,, -160800,6.8556232,1.9153059,,,,,,,,,,,,,, -160872,,,0.882254421710968,0.5351816415786743,0.7583799958229065,1.0473202466964722,50000.0,0.6320000290870667,1.6694732904434204,10000.0,54630.79351758957,56636.27542257309,54630.79351758957,1994.1933534145355,5.536860942840576,0.0 -160900,7.113852,1.8940902,,,,,,,,,,,,,, -161000,6.6418977,1.9251297,,,,,,,,,,,,,, -161100,7.127837,1.8535799,,,,,,,,,,,,,, -161200,7.241662,1.956312,,,,,,,,,,,,,, -161300,6.7384567,1.9518974,,,,,,,,,,,,,, -161400,7.925613,1.9199402,,,,,,,,,,,,,, -161500,8.122516,1.8921624,,,,,,,,,,,,,, -161600,7.6812224,1.8854126,,,,,,,,,,,,,, -161700,7.3936496,1.9103682,,,,,,,,,,,,,, -161800,7.3541713,1.8932989,,,,,,,,,,,,,, -161900,7.1498146,1.881024,,,,,,,,,,,,,, -162000,7.3003035,1.8974905,,,,,,,,,,,,,, -162100,7.144917,1.9753556,,,,,,,,,,,,,, -162200,6.966197,1.974358,,,,,,,,,,,,,, -162300,7.32732,1.8945984,,,,,,,,,,,,,, -162376,,,0.8803212642669678,0.5474730134010315,0.7545599937438965,1.0637322664260864,50000.0,0.6378000378608704,1.6772630214691162,10000.0,55140.68953895569,57163.9195151329,55140.68953895569,2011.8303670883176,5.594317197799683,0.0 -162400,7.2810287,1.9610544,,,,,,,,,,,,,, -162500,7.382462,1.8948367,,,,,,,,,,,,,, -162600,7.2908483,1.8939466,,,,,,,,,,,,,, -162700,6.6624513,1.8833312,,,,,,,,,,,,,, -162800,7.273501,1.8803755,,,,,,,,,,,,,, -162900,7.530103,1.8688796,,,,,,,,,,,,,, -163000,7.4111915,1.9039183,,,,,,,,,,,,,, -163100,7.7924476,1.927607,,,,,,,,,,,,,, -163200,7.3282804,1.8736383,,,,,,,,,,,,,, -163300,6.7747865,1.9725468,,,,,,,,,,,,,, -163400,7.167571,1.8913293,,,,,,,,,,,,,, -163500,7.4892397,1.898074,,,,,,,,,,,,,, -163600,7.070474,1.8999546,,,,,,,,,,,,,, -163700,7.3680544,1.9296176,,,,,,,,,,,,,, -163800,7.1766458,1.8707842,,,,,,,,,,,,,, -163880,,,0.8883529901504517,0.5204653143882751,0.7597799897193909,1.0434074401855469,50000.0,0.6351000070571899,1.6635215282440186,10000.0,55650.60866093636,57691.20793008804,55650.60866093636,2029.086805820465,5.652773380279541,0.0 -163900,7.9494576,1.917069,,,,,,,,,,,,,, -164000,7.055153,1.8497589,,,,,,,,,,,,,, -164100,6.760568,1.8829422,,,,,,,,,,,,,, -164200,6.925742,1.8986552,,,,,,,,,,,,,, -164300,7.4488134,1.938199,,,,,,,,,,,,,, -164400,7.4172845,1.8941581,,,,,,,,,,,,,, -164500,7.3521705,1.8745579,,,,,,,,,,,,,, -164600,7.3709416,1.8873342,,,,,,,,,,,,,, -164700,6.884269,1.8365579,,,,,,,,,,,,,, -164800,7.073046,1.9142996,,,,,,,,,,,,,, -164900,7.127713,1.9368167,,,,,,,,,,,,,, -165000,7.1639647,1.8772073,,,,,,,,,,,,,, -165100,7.437895,1.9182959,,,,,,,,,,,,,, -165200,6.5966144,1.7828282,,,,,,,,,,,,,, -165300,7.750193,1.9431843,,,,,,,,,,,,,, -165384,,,0.8879145383834839,0.5141134858131409,0.7617999911308289,1.0365240573883057,50000.0,0.6367000341415405,1.6468669176101685,10000.0,56160.56981110573,58218.42898082733,56160.56981110573,2046.2316403388977,5.71399450302124,0.0 -165400,7.378778,1.8386933,,,,,,,,,,,,,, -165500,7.7266874,1.8698852,,,,,,,,,,,,,, -165600,6.928042,1.8948274,,,,,,,,,,,,,, -165700,6.9019628,1.8475513,,,,,,,,,,,,,, -165800,7.8218136,1.8556293,,,,,,,,,,,,,, -165900,7.8471084,1.8798544,,,,,,,,,,,,,, -166000,7.1738224,1.8742144,,,,,,,,,,,,,, -166100,7.81832,1.7926304,,,,,,,,,,,,,, -166200,7.3273973,1.8738632,,,,,,,,,,,,,, -166300,8.338043,1.9406637,,,,,,,,,,,,,, -166400,7.393373,1.8838233,,,,,,,,,,,,,, -166500,7.00188,1.861383,,,,,,,,,,,,,, -166600,7.69799,1.9705043,,,,,,,,,,,,,, -166700,8.1463585,1.8733225,,,,,,,,,,,,,, -166800,7.3310294,1.8597053,,,,,,,,,,,,,, -166888,,,0.9036790132522584,0.4611811637878418,0.7634199857711792,1.030709147453308,50000.0,0.6419000029563904,1.6486574411392212,10000.0,56670.46351933479,58745.71489930153,56670.46351933479,2063.513118505478,5.7704408168792725,0.0 -166900,7.043892,1.9052016,,,,,,,,,,,,,, -167000,7.2637873,1.903839,,,,,,,,,,,,,, -167100,7.2350817,1.79561,,,,,,,,,,,,,, -167200,7.6484733,1.8222464,,,,,,,,,,,,,, -167300,8.857079,1.8787332,,,,,,,,,,,,,, -167400,8.213928,1.8473065,,,,,,,,,,,,,, -167500,6.9716444,1.9021833,,,,,,,,,,,,,, -167600,7.3471255,1.808795,,,,,,,,,,,,,, -167700,8.653309,1.9126385,,,,,,,,,,,,,, -167800,7.1365275,1.8618243,,,,,,,,,,,,,, -167900,7.9050217,1.9126692,,,,,,,,,,,,,, -168000,7.3981123,1.8156691,,,,,,,,,,,,,, -168100,7.618997,1.9633348,,,,,,,,,,,,,, -168200,8.368803,1.8745781,,,,,,,,,,,,,, -168300,7.74217,1.9003319,,,,,,,,,,,,,, -168393,,,0.9052734375,0.4601774215698242,0.7659199833869934,1.0219361782073977,50000.0,0.6416000127792358,1.6349159479141235,10000.0,57180.640429496765,59273.3946313858,57180.640429496765,2080.8864829540253,5.8460328578948975,0.0 -168400,7.8234334,1.8918478,,,,,,,,,,,,,, -168500,7.3138156,1.8286057,,,,,,,,,,,,,, -168600,8.21831,1.8155285,,,,,,,,,,,,,, -168700,6.9165955,1.8990989,,,,,,,,,,,,,, -168800,8.303925,1.8908849,,,,,,,,,,,,,, -168900,7.6499567,1.8467357,,,,,,,,,,,,,, -169000,7.62094,1.8693123,,,,,,,,,,,,,, -169100,7.4730983,1.884766,,,,,,,,,,,,,, -169200,7.007213,1.804611,,,,,,,,,,,,,, -169300,6.849884,1.8230623,,,,,,,,,,,,,, -169400,7.4560075,1.8656588,,,,,,,,,,,,,, -169500,7.2611947,1.8961235,,,,,,,,,,,,,, -169600,7.271722,1.8000515,,,,,,,,,,,,,, -169700,7.251143,1.8692753,,,,,,,,,,,,,, -169800,7.131125,1.879725,,,,,,,,,,,,,, -169898,,,0.9046555757522584,0.4588156342506408,0.7649399638175964,1.0206623077392578,50000.0,0.6459000110626221,1.622611403465271,10000.0,57690.83558368683,59801.218878507614,57690.83558368683,2098.3980734348297,5.909879684448242,0.0 -169900,7.929098,1.8366683,,,,,,,,,,,,,, -170000,6.7385406,1.8599818,,,,,,,,,,,,,, -170100,7.222889,1.8140292,,,,,,,,,,,,,, -170200,7.182392,1.7851375,,,,,,,,,,,,,, -170300,7.376771,1.834846,,,,,,,,,,,,,, -170400,7.4649124,1.7756894,,,,,,,,,,,,,, -170500,7.096928,1.8849074,,,,,,,,,,,,,, -170600,7.411853,1.7690272,,,,,,,,,,,,,, -170700,7.1119485,1.8393891,,,,,,,,,,,,,, -170800,7.1860595,1.769301,,,,,,,,,,,,,, -170900,7.684721,1.8038445,,,,,,,,,,,,,, -171000,8.025446,1.8063512,,,,,,,,,,,,,, -171100,7.3369145,1.8316069,,,,,,,,,,,,,, -171200,7.9511337,1.8210318,,,,,,,,,,,,,, -171300,8.270143,1.8572524,,,,,,,,,,,,,, -171400,8.30361,1.8331187,,,,,,,,,,,,,, -171403,,,0.9075254797935486,0.4488637149333954,0.7670800089836121,1.009906530380249,50000.0,0.6456000208854675,1.6306302547454834,10000.0,58201.03377509117,60329.157285928726,58201.03377509117,2116.0283353328705,5.96552038192749,0.0 -171500,8.041868,1.812221,,,,,,,,,,,,,, -171600,7.0621815,1.7975806,,,,,,,,,,,,,, -171700,7.3578463,1.8154024,,,,,,,,,,,,,, -171800,7.9016294,1.7290287,,,,,,,,,,,,,, -171900,7.6773934,1.8510879,,,,,,,,,,,,,, -172000,7.1990266,1.9092393,,,,,,,,,,,,,, -172100,8.013542,1.8085444,,,,,,,,,,,,,, -172200,6.8741345,1.8218981,,,,,,,,,,,,,, -172300,7.5885043,1.8120726,,,,,,,,,,,,,, -172400,7.242466,1.813339,,,,,,,,,,,,,, -172500,7.279773,1.8121947,,,,,,,,,,,,,, -172600,7.654772,1.7926208,,,,,,,,,,,,,, -172700,8.028924,1.7772884,,,,,,,,,,,,,, -172800,8.260859,1.8704828,,,,,,,,,,,,,, -172900,7.333858,1.8496534,,,,,,,,,,,,,, -172907,,,0.910375475883484,0.4398549795150757,0.7702599763870239,1.003976345062256,50000.0,0.6478000283241272,1.615654230117798,10000.0,58711.01109552384,60856.62333703041,58711.01109552384,2133.4003245830536,6.02742600440979,0.0 -173000,7.6336966,1.7898347,,,,,,,,,,,,,, -173100,7.633703,1.854233,,,,,,,,,,,,,, -173200,8.44422,1.7956414,,,,,,,,,,,,,, -173300,8.104816,1.8316066,,,,,,,,,,,,,, -173400,7.9775186,1.7997096,,,,,,,,,,,,,, -173500,8.098632,1.8501604,,,,,,,,,,,,,, -173600,7.9413333,1.8118719,,,,,,,,,,,,,, -173700,8.317412,1.8389252,,,,,,,,,,,,,, -173800,7.6564536,1.7908111,,,,,,,,,,,,,, -173900,8.882854,1.8707105,,,,,,,,,,,,,, -174000,8.06939,1.7724304,,,,,,,,,,,,,, -174100,7.505933,1.8508772,,,,,,,,,,,,,, -174200,7.640812,1.8409535,,,,,,,,,,,,,, -174300,9.072524,1.8439219,,,,,,,,,,,,,, -174400,7.826278,1.734746,,,,,,,,,,,,,, -174412,,,0.9120894074440002,0.4342247247695923,0.7703799605369568,1.0012692213058472,50000.0,0.6509000062942505,1.6078853607177734,10000.0,59221.15896511078,61384.15951418877,59221.15896511078,2150.674258470536,6.087927341461182,0.0 -174500,7.2279596,1.7434555,,,,,,,,,,,,,, -174600,7.320183,1.7590979,,,,,,,,,,,,,, -174700,6.856408,1.701622,,,,,,,,,,,,,, -174800,7.4599347,1.7279603,,,,,,,,,,,,,, -174900,8.579656,1.8318956,,,,,,,,,,,,,, -175000,7.896465,1.796337,,,,,,,,,,,,,, -175100,7.642576,1.8344759,,,,,,,,,,,,,, -175200,8.393542,1.7652893,,,,,,,,,,,,,, -175300,6.8249974,1.7221576,,,,,,,,,,,,,, -175400,7.781236,1.7809675,,,,,,,,,,,,,, -175500,7.4305773,1.7760395,,,,,,,,,,,,,, -175600,8.675871,1.811244,,,,,,,,,,,,,, -175700,7.788252,1.7803317,,,,,,,,,,,,,, -175800,8.427952,1.8059165,,,,,,,,,,,,,, -175900,8.256399,1.7940711,,,,,,,,,,,,,, -175916,,,0.915796399116516,0.4226926863193512,0.7699999809265137,1.0016485452651978,50000.0,0.6498000025749207,1.614419937133789,10000.0,59731.146606206894,61911.43213367462,59731.146606206894,2167.850204706192,6.143632411956787,0.0 -176000,7.3106384,1.7717775,,,,,,,,,,,,,, -176100,7.8515134,1.7645856,,,,,,,,,,,,,, -176200,7.2466536,1.7755588,,,,,,,,,,,,,, -176300,8.512975,1.7844863,,,,,,,,,,,,,, -176400,7.575219,1.7070265,,,,,,,,,,,,,, -176500,6.96633,1.7601148,,,,,,,,,,,,,, -176600,8.014874,1.7876393,,,,,,,,,,,,,, -176700,7.7644,1.7599456,,,,,,,,,,,,,, -176800,7.3890796,1.8298782,,,,,,,,,,,,,, -176900,8.321338,1.7829984,,,,,,,,,,,,,, -177000,7.616015,1.7915007,,,,,,,,,,,,,, -177100,8.103554,1.812465,,,,,,,,,,,,,, -177200,8.021571,1.8718787,,,,,,,,,,,,,, -177300,7.563213,1.7686726,,,,,,,,,,,,,, -177400,7.806105,1.7770473,,,,,,,,,,,,,, -177420,,,0.9207788109779358,0.4045875966548919,0.7719599604606628,0.9974290132522584,50000.0,0.6514000296592712,1.6060065031051636,10000.0,60241.354083538055,62439.29220581055,60241.354083538055,2185.389719486237,6.202255487442017,0.0 -177500,6.9036193,1.753578,,,,,,,,,,,,,, -177600,7.5625167,1.7528259,,,,,,,,,,,,,, -177700,8.208888,1.7485594,,,,,,,,,,,,,, -177800,7.776823,1.7463304,,,,,,,,,,,,,, -177900,7.8368316,1.7818431,,,,,,,,,,,,,, -178000,8.8163,1.8645301,,,,,,,,,,,,,, -178100,7.830197,1.7517029,,,,,,,,,,,,,, -178200,7.6668243,1.8324052,,,,,,,,,,,,,, -178300,8.407273,1.8475233,,,,,,,,,,,,,, -178400,7.5926447,1.7638173,,,,,,,,,,,,,, -178500,7.5392103,1.7736285,,,,,,,,,,,,,, -178600,8.641904,1.8325627,,,,,,,,,,,,,, -178700,7.8740363,1.7276189,,,,,,,,,,,,,, -178800,7.4818835,1.7661135,,,,,,,,,,,,,, -178900,8.137812,1.8498803,,,,,,,,,,,,,, -178924,,,0.9169324040412904,0.405758649110794,0.7729399800300598,0.9932107329368592,50000.0,0.6512000560760498,1.60286545753479,10000.0,60751.29431128502,62966.76009917259,60751.29431128502,2202.798728942871,6.266868829727173,0.0 -179000,7.6374826,1.858877,,,,,,,,,,,,,, -179100,7.7699986,1.7887349,,,,,,,,,,,,,, -179200,8.508171,1.8149769,,,,,,,,,,,,,, -179300,8.17946,1.7914307,,,,,,,,,,,,,, -179400,7.8234825,1.8012164,,,,,,,,,,,,,, -179500,8.888666,1.8343271,,,,,,,,,,,,,, -179600,7.6345663,1.7776244,,,,,,,,,,,,,, -179700,7.8548527,1.7704921,,,,,,,,,,,,,, -179800,8.515234,1.7846159,,,,,,,,,,,,,, -179900,7.124692,1.7613345,,,,,,,,,,,,,, -180000,8.354176,1.7854185,,,,,,,,,,,,,, -180100,7.903352,1.770898,,,,,,,,,,,,,, -180200,7.5311604,1.7671824,,,,,,,,,,,,,, -180300,8.230057,1.7647557,,,,,,,,,,,,,, -180400,8.403834,1.778406,,,,,,,,,,,,,, -180429,,,0.9206991195678712,0.4018170833587646,0.7720800042152405,0.9955241084098816,50000.0,0.651900053024292,1.6055456399917605,10000.0,61261.49818825722,63494.355257987976,61261.49818825722,2220.0592410564423,6.343884229660034,0.0 -180500,7.6049247,1.8059044,,,,,,,,,,,,,, -180600,7.656254,1.7812035,,,,,,,,,,,,,, -180700,7.813617,1.7695876,,,,,,,,,,,,,, -180800,7.8847747,1.7537724,,,,,,,,,,,,,, -180900,7.8791227,1.8050286,,,,,,,,,,,,,, -181000,7.536798,1.6893222,,,,,,,,,,,,,, -181100,8.481448,1.8081102,,,,,,,,,,,,,, -181200,7.309296,1.7688348,,,,,,,,,,,,,, -181300,7.9323406,1.7954623,,,,,,,,,,,,,, -181400,7.842725,1.7774198,,,,,,,,,,,,,, -181500,8.931661,1.7881553,,,,,,,,,,,,,, -181600,7.63468,1.8203129,,,,,,,,,,,,,, -181700,7.861558,1.7535722,,,,,,,,,,,,,, -181800,7.020387,1.723068,,,,,,,,,,,,,, -181900,8.106373,1.7544229,,,,,,,,,,,,,, -181934,,,0.91898512840271,0.4023058712482452,0.7721999883651733,0.9920402765274048,50000.0,0.653700053691864,1.6013190746307373,10000.0,61771.65120244026,64021.787281513214,61771.65120244026,2237.2278864383698,6.401177883148193,0.0 -182000,7.7536397,1.7638929,,,,,,,,,,,,,, -182100,7.5813513,1.7759647,,,,,,,,,,,,,, -182200,7.082845,1.7970452,,,,,,,,,,,,,, -182300,7.360333,1.7584708,,,,,,,,,,,,,, -182400,7.6692796,1.775353,,,,,,,,,,,,,, -182500,8.171822,1.727575,,,,,,,,,,,,,, -182600,8.431208,1.8187157,,,,,,,,,,,,,, -182700,7.3924165,1.7607471,,,,,,,,,,,,,, -182800,8.101474,1.8671697,,,,,,,,,,,,,, -182900,7.8269234,1.7807279,,,,,,,,,,,,,, -183000,7.765093,1.7922921,,,,,,,,,,,,,, -183100,7.489946,1.7437459,,,,,,,,,,,,,, -183200,8.188781,1.8179858,,,,,,,,,,,,,, -183300,8.114649,1.7386578,,,,,,,,,,,,,, -183400,7.7670336,1.70561,,,,,,,,,,,,,, -183438,,,0.9219945669174194,0.3971076905727386,0.772819995880127,0.9916077852249146,50000.0,0.6526000499725342,1.6008130311965942,10000.0,62281.68023562431,64549.20379114151,62281.68023562431,2254.500765562057,6.462839603424072,0.0 -183500,8.388041,1.8336322,,,,,,,,,,,,,, -183600,8.372842,1.8273058,,,,,,,,,,,,,, -183700,8.227224,1.815994,,,,,,,,,,,,,, -183800,7.8921614,1.7685137,,,,,,,,,,,,,, -183900,8.356347,1.870714,,,,,,,,,,,,,, -184000,7.3539157,1.7607114,,,,,,,,,,,,,, -184100,8.07355,1.6708062,,,,,,,,,,,,,, -184200,8.124724,1.830565,,,,,,,,,,,,,, -184300,8.550961,1.8128574,,,,,,,,,,,,,, -184400,7.874198,1.723808,,,,,,,,,,,,,, -184500,7.2472153,1.7590894,,,,,,,,,,,,,, -184600,7.257191,1.7675745,,,,,,,,,,,,,, -184700,8.018742,1.7847611,,,,,,,,,,,,,, -184800,7.3905654,1.7274071,,,,,,,,,,,,,, -184900,7.6920223,1.7813636,,,,,,,,,,,,,, -184942,,,0.9215561151504515,0.3981364667415619,0.7727400064468384,0.991069793701172,50000.0,0.653700053691864,1.5995631217956543,10000.0,62791.72662734985,65076.893508434296,62791.72662734985,2272.0287766456604,6.52376127243042,0.0 -185000,7.7338395,1.7691851,,,,,,,,,,,,,, -185100,8.010542,1.7581822,,,,,,,,,,,,,, -185200,8.194973,1.7270019,,,,,,,,,,,,,, -185300,8.122476,1.8141881,,,,,,,,,,,,,, -185400,7.6082854,1.7880368,,,,,,,,,,,,,, -185500,8.674299,1.7706611,,,,,,,,,,,,,, -185581,,,,,,,,,,,63008.197350502014,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/eval_measurements.csv deleted file mode 100644 index dc0b6ae88..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -18.382753133773804,0.0,31.89585375785828,1,0,31.89585375785828,0.0006000000284984,6.9125494956970215,10000,50.27870035171509,0.0007174744969233,6.912591457366943,0.0007599999662488,6.913174629211426,50000 -36.13430953025818,0.0164451599121093,541.9562501907349,1497,0,541.9562501907349,0.053700003772974,5.532280921936035,10000,578.1606819629669,0.0817721635103225,5.256331920623779,0.0762000009417533,5.30866289138794,50000 -53.70322799682617,0.0438079833984375,1051.935497045517,2992,0,1051.935497045517,0.1266000121831894,4.754936218261719,10000,1105.7899084091189,0.1966477930545807,4.2057881355285645,0.1745599955320358,4.332620143890381,50000 -72.37804913520813,0.0735633373260498,1562.1542949676514,4488,0,1562.1542949676514,0.2101000100374221,4.122847080230713,10000,1634.7661957740784,0.3034917116165161,3.4856112003326416,0.2791000008583069,3.62486720085144,50000 -90.1362087726593,0.1037073135375976,2072.106372833252,5984,0,2072.106372833252,0.2732000052928924,3.685528755187988,10000,2162.559784412384,0.3869180381298065,2.9858076572418213,0.359059989452362,3.1390581130981445,50000 -107.99761819839478,0.1325531005859375,2582.1680810451508,7480,0,2582.1680810451508,0.3192000091075897,3.449917316436768,10000,2690.564907312393,0.4513911008834839,2.6858718395233154,0.4205600023269653,2.84183406829834,50000 -125.6869819164276,0.1603870391845703,3092.3153219223022,8977,0,3092.3153219223022,0.3625000119209289,3.1810948848724365,10000,3218.483143806457,0.5364516973495483,2.2150745391845703,0.4718599915504455,2.5311429500579834,50000 -143.2770929336548,0.1908240318298339,3602.40500998497,10474,0,3602.40500998497,0.3958000242710113,3.038412094116211,10000,3746.247165679932,0.5558035373687744,2.1703338623046875,0.5116599798202515,2.3818368911743164,50000 -161.05826830863953,0.2227044105529785,4112.561676979065,11972,0,4112.561676979065,0.4250000119209289,2.862744092941284,10000,4274.270484447479,0.5846021771430969,1.9957231283187864,0.5405799746513367,2.2037601470947266,50000 -178.40104007720947,0.2536995410919189,4622.485732078552,13469,0,4622.485732078552,0.4506000280380249,2.675619602203369,10000,4801.622956514359,0.6219108700752258,1.801672101020813,0.5726400017738342,2.031932592391968,50000 -196.3314588069916,0.285891056060791,5132.471919298172,14968,0,5132.471919298172,0.4617000222206116,2.591858148574829,10000,5329.625428676605,0.6355029940605164,1.682795166969299,0.5859599709510803,1.926688551902771,50000 -214.10228490829468,0.3106076717376709,5642.635751485825,16467,0,5642.635751485825,0.4690000116825104,2.6094350814819336,10000,5857.638697385788,0.6412428021430969,1.7201507091522217,0.5932999849319458,1.9380810260772705,50000 -231.8304960727692,0.3401076793670654,6152.796566486359,17967,0,6152.796566486359,0.4798000156879425,2.5136966705322266,10000,6385.609614610672,0.7058354616165161,1.4462586641311646,0.6126999855041504,1.8431637287139893,50000 -249.6162793636322,0.3739089965820312,6662.721271753311,19466,0,6662.721271753311,0.4852000176906585,2.5008463859558105,10000,6913.4088344573975,0.6865234375,1.4992294311523438,0.6133399605751038,1.816870212554932,50000 -267.1690058708191,0.4088256359100342,7172.737153053284,20966,0,7172.737153053284,0.501300036907196,2.41690993309021,10000,7441.06632399559,0.6920639276504517,1.470927596092224,0.6231799721717834,1.7682565450668335,50000 -284.8941237926483,0.4400563240051269,7682.723104953766,22465,0,7682.723104953766,0.5006999969482422,2.429164409637451,10000,7968.861889123917,0.6925621628761292,1.4761664867401123,0.6293399930000305,1.753197431564331,50000 -302.46875619888306,0.4770851135253906,8192.924956321716,23966,0,8192.924956321716,0.5092000365257263,2.4119579792022705,10000,8496.729290246964,0.6956911683082581,1.4922162294387815,0.634719967842102,1.7566068172454834,50000 -320.05310821533203,0.508368968963623,8702.947353839874,25466,0,8702.947353839874,0.5162000060081482,2.3508965969085693,10000,9024.420098781586,0.7027263641357422,1.419603705406189,0.6396999955177307,1.6906665563583374,50000 -337.60744285583496,0.5415892601013184,9212.861717700958,26966,0,9212.861717700958,0.4951000213623047,2.4571080207824707,10000,9551.975267887115,0.6951530575752258,1.484920859336853,0.6298199892044067,1.7593168020248413,50000 -355.4415764808655,0.5747489929199219,9722.92922616005,28466,0,9722.92922616005,0.5190000534057617,2.301248073577881,10000,10079.964128255844,0.7315250039100647,1.2711821794509888,0.6417799592018127,1.6622222661972046,50000 -373.3464798927307,0.6060409545898438,10233.182272434236,29967,0,10233.182272434236,0.5196000337600708,2.333815813064575,10000,10608.206499814987,0.7155213356018066,1.3536566495895386,0.6436799764633179,1.67797589302063,50000 -391.1074600219727,0.6514892578125,10743.343856096268,31468,0,10743.343856096268,0.5074000358581543,2.405803203582764,10000,11136.228532791138,0.7057158946990967,1.420884132385254,0.6421200037002563,1.7156926393508911,50000 -408.85849595069885,0.6858878135681152,11253.426638364792,32968,0,11253.426638364792,0.5246000289916992,2.299521446228028,10000,11664.150138139725,0.7166374325752258,1.359977960586548,0.6487799882888794,1.6539162397384644,50000 -426.4737284183502,0.7184295654296875,11763.347280740738,34468,0,11763.347280740738,0.5151000022888184,2.3533051013946533,10000,12191.771873950958,0.70703125,1.396243453025818,0.6412400007247925,1.682578444480896,50000 -444.2426521778106,0.7576742172241211,12273.410103797913,35969,0,12273.410103797913,0.522599995136261,2.307307243347168,10000,12719.69810295105,0.7151227593421936,1.3773218393325806,0.6521199941635132,1.6565839052200315,50000 -462.0181083679199,0.7907888889312744,12783.423173427582,37469,0,12783.423173427582,0.5078000426292419,2.3833701610565186,10000,13247.574071884155,0.7365872263908386,1.2994897365570068,0.6460199952125549,1.7007862329483032,50000 -479.7775735855103,0.8264029026031494,13293.508579730988,38970,0,13293.508579730988,0.5218999981880188,2.2923173904418945,10000,13775.50724887848,0.7347337007522583,1.2530560493469238,0.6542199850082397,1.6197998523712158,50000 -497.2023296356201,0.8594932556152344,13803.586438894272,40471,0,13803.586438894272,0.5252000093460083,2.321791410446167,10000,14303.099082946776,0.7262037396430969,1.325469732284546,0.6515399813652039,1.6518381834030151,50000 -515.5967583656311,0.8947463035583496,14313.738857030869,41972,0,14313.738857030869,0.5282000303268433,2.2479236125946045,10000,14831.734774827955,0.7281568646430969,1.2736696004867554,0.6581000089645386,1.5901299715042114,50000 -533.4117612838745,0.9289801120758056,14823.836078882216,43473,0,14823.836078882216,0.5232000350952148,2.333250999450684,10000,15359.73616051674,0.7159597873687744,1.3737213611602783,0.6536200046539307,1.6578227281570437,50000 -551.2723331451416,0.964834690093994,15333.796644449234,44974,0,15333.796644449234,0.532800018787384,2.273008108139038,10000,15887.6466050148,0.7290138602256775,1.3344924449920654,0.6631999611854553,1.6209385395050049,50000 -568.7053790092468,1.002387285232544,15843.76311159134,46475,0,15843.76311159134,0.527999997138977,2.3292038440704346,10000,16415.138216257095,0.75882887840271,1.2321513891220093,0.6564799547195435,1.6605032682418823,50000 -586.3989787101746,1.043410301208496,16353.741973161696,47976,0,16353.741973161696,0.5361000299453735,2.2724545001983643,10000,16942.905286073685,0.7443000674247742,1.2563570737838743,0.6606599688529968,1.619337797164917,50000 -604.2181794643402,1.0814259052276611,16863.93771147728,49477,0,16863.93771147728,0.5272000432014465,2.2808008193969727,10000,17471.012765169144,0.73539137840271,1.2704975605010986,0.6642000079154968,1.5893797874450684,50000 -622.0479846000671,1.1158974170684814,17373.93242096901,50978,0,17373.93242096901,0.5261000394821167,2.2644474506378174,10000,17998.925570249557,0.7208425998687744,1.3237276077270508,0.6489999890327454,1.642844319343567,50000 -639.8506090641022,1.1524429321289062,17884.122307777405,52479,0,17884.122307777405,0.5471000075340271,2.1574642658233643,10000,18527.0086209774,0.7428650856018066,1.2034056186676023,0.6698399782180786,1.527472972869873,50000 -657.5397083759308,1.1924428939819336,18394.34732890129,53981,0,18394.34732890129,0.5386000275611877,2.2656149864196777,10000,19055.01717185974,0.7333585619926453,1.3001765012741089,0.6660999655723572,1.604732871055603,50000 -675.1203720569611,1.2325246334075928,18904.38467264176,55482,0,18904.38467264176,0.5451000332832336,2.194451332092285,10000,19582.730887413025,0.7835220098495483,1.07129967212677,0.6714000105857849,1.5408074855804443,50000 -692.9199199676514,1.2718331813812256,19414.40906858444,56983,0,19414.40906858444,0.5392000079154968,2.25341796875,10000,20110.64827489853,0.753926157951355,1.2177770137786863,0.6661799550056458,1.6002877950668335,50000 -710.7420144081116,1.3092410564422607,19924.53598690033,58484,0,19924.53598690033,0.5412000417709351,2.232109785079956,10000,20638.68911504745,0.7466916441917419,1.2169824838638306,0.6627799868583679,1.5723552703857422,50000 -728.5081448554993,1.348177433013916,20434.61958694458,59985,0,20434.61958694458,0.5350000262260437,2.256978988647461,10000,21166.6318423748,0.7434031963348389,1.2602286338806152,0.6678599715232849,1.5897306203842163,50000 -746.0784072875977,1.3884241580963137,20944.60235857964,61486,0,20944.60235857964,0.5541000366210938,2.180800676345825,10000,21694.27863621712,0.7493622303009033,1.2253031730651855,0.6746999621391296,1.5467970371246338,50000 -764.052640914917,1.428706407546997,21454.522994041443,62987,0,21454.522994041443,0.5515000224113464,2.1404056549072266,10000,22222.26675367356,0.7569953799247742,1.1516437530517578,0.6823599934577942,1.4813121557235718,50000 -782.0063388347626,1.467320203781128,21964.64013814926,64488,0,21964.64013814926,0.5467000007629395,2.184284448623657,10000,22750.43190002441,0.7630141973495483,1.167246699333191,0.6773399710655212,1.535491704940796,50000 -799.6415371894836,1.5086112022399902,22474.734596014023,65928,0,22474.734596014023,0.5439000129699707,2.232006549835205,10000,23278.254689455032,0.7694514989852905,1.1408820152282717,0.6707199811935425,1.5621381998062134,50000 -817.1650323867798,1.54809308052063,22984.724188804623,67429,0,22984.724188804623,0.551800012588501,2.163882255554199,10000,23805.86190366745,0.7701291441917419,1.1213020086288452,0.6814000010490417,1.508509635925293,50000 -834.686586856842,1.5904114246368408,23494.799648284912,68931,0,23494.799648284912,0.5585000514984131,2.107130527496338,10000,24333.555045366287,0.7631736397743225,1.1109957695007324,0.6834200024604797,1.4736833572387695,50000 -852.1690158843994,1.633352518081665,24004.760673046112,70432,0,24004.760673046112,0.554900050163269,2.130442380905152,10000,24861.097346305847,0.7593470811843872,1.125109314918518,0.6784999966621399,1.4743475914001465,50000 -869.810240983963,1.6746103763580322,24514.92568397522,71934,0,24514.92568397522,0.5496000051498413,2.1928555965423584,10000,25388.99902510643,0.7483657598495483,1.220489263534546,0.6733199954032898,1.5536441802978516,50000 -887.5698609352112,1.713430643081665,25025.14162898064,73436,0,25025.14162898064,0.5415000319480896,2.236577272415161,10000,25917.068249225616,0.7437220811843872,1.2418133020401,0.6713399887084961,1.5635813474655151,50000 -905.3917419910432,1.753354549407959,25535.273594856262,74938,0,25535.273594856262,0.5593000054359436,2.1356232166290283,10000,26445.11689734459,0.7940250039100647,1.0290786027908323,0.6819599866867065,1.4988394975662231,50000 -923.286732673645,1.793353796005249,26045.1993291378,76439,0,26045.1993291378,0.5587000250816345,2.140300989151001,10000,26973.03084754944,0.7794363498687744,1.0432428121566772,0.6851999759674072,1.4631946086883545,50000 -940.8660507202148,1.8331985473632808,26555.232617139816,77941,0,26555.232617139816,0.5613000392913818,2.144953489303589,10000,27500.738260746,0.7739556431770325,1.1201239824295044,0.686199963092804,1.504704236984253,50000 -958.206443309784,1.8757927417755127,27065.26093101501,79442,0,27065.26093101501,0.5567000508308411,2.17029881477356,10000,28028.20407128334,0.7662029266357422,1.158005952835083,0.6850000023841858,1.5323312282562256,50000 -976.9262478351592,1.91663146018982,27575.289494991302,80944,0,27575.289494991302,0.5541000366210938,2.1207895278930664,10000,28557.046733379364,0.7602837681770325,1.1421968936920166,0.6798799633979797,1.5045496225357056,50000 -994.8886668682098,1.9587581157684328,28085.512244701385,82445,0,28085.512244701385,0.5697000026702881,2.080371379852295,10000,29085.32898545265,0.7724011540412903,1.0933239459991455,0.6927399635314941,1.4421255588531494,50000 -1012.3704881668092,2.003067970275879,28595.73851132393,83948,0,28595.73851132393,0.5618000030517578,2.10579776763916,10000,29613.13614249229,0.8116828799247742,0.9399776458740234,0.6905199885368347,1.4603952169418335,50000 -1029.901986837387,2.045266628265381,29105.90775370598,85450,0,29105.90775370598,0.567300021648407,2.1164779663085938,10000,30140.93297529221,0.7909757494926453,1.06205952167511,0.6925599575042725,1.4879108667373655,50000 -1047.5002024173737,2.0875396728515625,29615.94760608673,86952,0,29615.94760608673,0.5677000284194946,2.063734769821167,10000,30668.66878247261,0.7831034660339355,1.0348970890045166,0.6964600086212158,1.4234379529953003,50000 -1065.1176307201383,2.137878894805908,30125.87562441826,88453,0,30125.87562441826,0.5667000412940979,2.1170437335968018,10000,31196.319259643555,0.7817681431770325,1.088742971420288,0.6913599967956543,1.4801713228225708,50000 -1082.7853560447693,2.1824607849121094,30635.928003787994,89955,0,30635.928003787994,0.5601000189781189,2.1493637561798096,10000,31724.13882732392,0.7678770422935486,1.122750759124756,0.6862999796867371,1.481406569480896,50000 -1100.424178123474,2.2250864505767822,31146.005562067032,91457,0,31146.005562067032,0.5726000070571899,2.05700421333313,10000,32251.952335357662,0.7881656289100647,1.0311802625656128,0.7014600038528442,1.415004014968872,50000 -1118.377030134201,2.269920825958252,31655.91071796417,92959,0,31655.91071796417,0.563800036907196,2.128091812133789,10000,32779.90780115128,0.7984693646430969,1.0095969438552856,0.6953799724578857,1.4463399648666382,50000 -1135.9845206737518,2.3124051094055176,32166.15032839775,94462,0,32166.15032839775,0.5697000026702881,2.11687970161438,10000,33307.851593732834,0.8036909699440002,1.0080145597457886,0.6985599994659424,1.4534142017364502,50000 -1153.5908043384552,2.356222152709961,32676.37836766243,95965,0,32676.37836766243,0.5717000365257263,2.1151468753814697,10000,33835.78617501259,0.7919324040412903,1.0704740285873413,0.6911799907684326,1.506101369857788,50000 -1171.4737193584442,2.403014659881592,33186.609236717224,97467,0,33186.609236717224,0.5795000195503235,2.038626194000244,10000,34364.00185251236,0.7996053695678711,0.9840648174285888,0.70551997423172,1.3981292247772217,50000 -1189.1691591739657,2.448249578475952,33696.592266082764,98969,0,33696.592266082764,0.5839000344276428,2.0269968509674072,10000,34891.780616760254,0.7989476919174194,0.9818856716156006,0.7056799530982971,1.3842735290527344,50000 -1206.7572317123413,2.495995283126831,34206.52460384369,100470,0,34206.52460384369,0.5842000246047974,2.023866653442383,10000,35419.403853178024,0.8025948405265808,0.9827648401260376,0.707539975643158,1.3925204277038574,50000 -1224.4450266361237,2.540762186050415,34716.45545458794,101972,0,34716.45545458794,0.5784000158309937,2.0674097537994385,10000,35947.121560812,0.8014987111091614,1.016085505485535,0.707040011882782,1.427125334739685,50000 -1242.2984237670898,2.5861055850982666,35226.452053546906,103474,0,35226.452053546906,0.5851000547409058,2.004522562026977,10000,36475.07280921936,0.8282246589660645,0.8729314804077148,0.7114599943161011,1.368599534034729,50000 -1259.911861896515,2.632046937942505,35736.37076330185,104976,0,35736.37076330185,0.5789000391960144,2.0361974239349365,10000,37002.70536971092,0.8075972199440002,0.9604279398918152,0.7060999870300293,1.4053521156311035,50000 -1277.4861352443695,2.677516460418701,36246.44602751732,106478,0,36246.44602751732,0.5820000171661377,2.037717342376709,10000,37530.45402097702,0.8165656924247742,0.916153609752655,0.7117599844932556,1.3760380744934082,50000 -1295.2741153240204,2.727273941040039,36756.588076114655,107981,0,36756.588076114655,0.5901000499725342,2.000663042068481,10000,38058.4887046814,0.8157086968421936,0.9369272589683532,0.7161999940872192,1.3731088638305664,50000 -1312.8663160800934,2.7767326831817627,37266.67953324318,109483,0,37266.67953324318,0.5835000276565552,2.0119309425354004,10000,38586.27570772171,0.8148317933082581,0.9315711855888368,0.7142199873924255,1.35802960395813,50000 -1330.7699587345123,2.829568386077881,37776.58733320236,110985,0,37776.58733320236,0.5811000466346741,1.99495792388916,10000,39114.19459462166,0.8172432780265808,0.890960693359375,0.7128599882125854,1.340518832206726,50000 -1348.5736198425293,2.8754875659942627,38286.6563835144,112487,0,38286.6563835144,0.5963000059127808,1.9578930139541624,10000,39642.166607141495,0.8461814522743225,0.7963310480117798,0.7188400030136108,1.3273828029632568,50000 -1366.0709924697876,2.9247477054595947,38796.7903342247,113990,0,38796.7903342247,0.5891000032424927,2.0083277225494385,10000,40169.9016866684,0.8255141973495483,0.8634473085403442,0.7139399647712708,1.3523683547973633,50000 -1383.8353281021118,2.9718353748321533,39306.94250631333,115492,0,39306.94250631333,0.5921000242233276,1.966264724731445,10000,40697.91963505745,0.8307756781578064,0.861309826374054,0.7196199893951416,1.3414368629455566,50000 -1401.4225118160248,3.0217180252075195,39817.17044043541,116995,0,39817.17044043541,0.5955000519752502,1.9473471641540527,10000,41225.83962345123,0.8356983065605164,0.8533991575241089,0.7249400019645691,1.3167498111724854,50000 -1419.805627822876,3.0692250728607178,40327.12989664078,118497,0,40327.12989664078,0.5940999984741211,1.9697046279907229,10000,41754.2838177681,0.8339644074440002,0.8613070249557495,0.7239999771118164,1.3303016424179075,50000 -1437.6517629623413,3.1198184490203857,40837.23765873909,119999,0,40837.23765873909,0.6055999994277954,1.9222824573516848,10000,42282.34254765511,0.8390266299247742,0.8340352177619934,0.7273600101470947,1.301812767982483,50000 -1455.0882284641266,3.1719682216644287,41347.40583300591,121501,0,41347.40583300591,0.6010000109672546,1.9819520711898804,10000,42810.052292108536,0.8613081574440002,0.7704369425773621,0.7250399589538574,1.3356198072433472,50000 -1472.3993520736694,3.2225170135498047,41857.52238154411,123004,0,41857.52238154411,0.6055000424385071,1.9438683986663816,10000,43337.58539104462,0.8552694320678711,0.7926141023635864,0.7265399694442749,1.3171532154083252,50000 -1489.850107908249,3.269981622695923,42367.69151568413,124506,0,42367.69151568413,0.6068000197410583,1.9477180242538448,10000,43865.30625462532,0.8481743931770325,0.8100681304931641,0.7289199829101562,1.3144792318344116,50000 -1507.7119023799896,3.3195204734802246,42877.82626962662,126009,0,42877.82626962662,0.6034000515937805,1.9218626022338867,10000,44393.40523290634,0.845723032951355,0.7846372127532959,0.7291799783706665,1.2905118465423584,50000 -1525.374079704285,3.3690085411071777,43387.9768986702,127511,0,43387.9768986702,0.615600049495697,1.869235634803772,10000,44921.3228840828,0.8580795526504517,0.7403988838195801,0.7360000014305115,1.253514289855957,50000 -1542.8680157661438,3.4177677631378174,43898.0478746891,129014,0,43898.0478746891,0.6055999994277954,1.904442310333252,10000,45448.98998808861,0.85550856590271,0.7635818123817444,0.7354599833488464,1.2662724256515503,50000 -1560.2620012760162,3.47461485862732,44407.95130634308,130515,0,44407.95130634308,0.6105000376701355,1.9217170476913448,10000,45976.39903450012,0.8826330900192261,0.6655307412147522,0.7340999841690063,1.2825183868408203,50000 -1578.0767569541931,3.5267868041992188,44918.10623574257,132018,0,44918.10623574257,0.6077000498771667,1.936708688735962,10000,46504.47780227661,0.8704758882522583,0.7248345613479614,0.731499969959259,1.2995328903198242,50000 -1595.8508143424988,4.649079084396362,45426.990837574005,133517,0,45426.990837574005,0.6086000204086304,1.885425329208374,10000,47032.312911748886,0.8717314600944519,0.7035813331604004,0.7369599938392639,1.2577601671218872,50000 -1613.7076733112335,4.703073501586914,45937.16322660446,135020,0,45937.16322660446,0.6134000420570374,1.8910014629364007,10000,47560.45282816887,0.868582546710968,0.7156649827957153,0.7356799840927124,1.2662253379821775,50000 -1631.1530323028564,4.756369113922119,46447.39303159714,136523,0,46447.39303159714,0.6157000064849854,1.8840872049331665,10000,48088.235530138016,0.8720503449440002,0.7015858888626099,0.7404599785804749,1.2451156377792358,50000 -1648.9669427871704,4.806999683380127,46957.294365644455,138025,0,46957.294365644455,0.6074000000953674,1.9359248876571653,10000,48616.05543756485,0.8668486475944519,0.7391188144683838,0.7368199825286865,1.2867945432662964,50000 -1666.5136742591858,4.857511758804321,47467.2943482399,139527,0,47467.2943482399,0.6183000206947327,1.87476646900177,10000,49143.7070350647,0.8956273794174194,0.627384603023529,0.7456600069999695,1.240861415863037,50000 -1684.1839241981506,4.914875984191895,47977.450786590576,141030,0,47977.450786590576,0.619700014591217,1.8548864126205444,10000,49671.64498496056,0.893973171710968,0.6151627898216248,0.7440399527549744,1.2287087440490725,50000 -1701.8740639686584,4.969464063644409,48487.50527572632,142532,0,48487.50527572632,0.6160000562667847,1.8964543342590328,10000,50199.49908399582,0.8904854655265808,0.6626537442207336,0.7447999715805054,1.2583589553833008,50000 -1719.357551574707,5.022829294204712,48997.59909081459,144035,0,48997.59909081459,0.6198000311851501,1.864946722984314,10000,50727.18270134926,0.892598032951355,0.6262313723564148,0.7450399994850159,1.2290778160095217,50000 -1737.1198983192444,5.078593730926514,49507.578125715256,145537,0,49507.578125715256,0.6224000453948975,1.8546254634857176,10000,51255.033161878586,0.8907046914100647,0.6271235346794128,0.7487599849700928,1.223987102508545,50000 -1754.819967508316,5.133638858795166,50017.51508355141,147039,0,50017.51508355141,0.6242000460624695,1.8519423007965088,10000,51782.78130912781,0.8966238498687744,0.6165596842765808,0.7476199865341187,1.228817582130432,50000 -1772.9918661117554,5.189709186553955,50527.50337028504,148541,0,50527.50337028504,0.6272000074386597,1.8344119787216189,10000,52311.05200695992,0.9024234414100648,0.5903157591819763,0.7495200037956238,1.215433120727539,50000 -1790.6060602664948,5.234796047210693,51037.676836013794,150044,0,51037.676836013794,0.6228000521659851,1.844643831253052,10000,52838.9393966198,0.910574734210968,0.5532589554786682,0.75,1.207600712776184,50000 -1808.1750228405,5.288285970687866,51547.57273697853,151545,0,51547.57273697853,0.6243000030517578,1.843218684196472,10000,53366.51215076447,0.909817397594452,0.566749632358551,0.7523599863052368,1.2168418169021606,50000 -1826.0408027172089,5.341819047927856,52057.77176237106,153048,0,52057.77176237106,0.6313000321388245,1.818287968635559,10000,53894.68529486656,0.915058970451355,0.5383015871047974,0.7546399831771851,1.1951937675476074,50000 -1843.5144710540767,5.40025782585144,52567.898394823074,154550,0,52567.898394823074,0.6278000473976135,1.8361276388168333,10000,54422.400752067566,0.9122289419174194,0.5614323616027832,0.75382000207901,1.2097517251968384,50000 -1861.369544029236,5.455489873886108,53077.98852467537,156053,0,53077.98852467537,0.6291000247001648,1.8328757286071773,10000,54950.4566886425,0.9175502061843872,0.5501120090484619,0.7555999755859375,1.2017172574996948,50000 -1879.8014032840729,5.5035552978515625,53588.06019878388,157555,0,53588.06019878388,0.6296000480651855,1.8347876071929927,10000,55479.06328034401,0.9153977632522584,0.5394268035888672,0.7566399574279785,1.1991684436798096,50000 -1897.5863370895383,5.557616472244263,54098.17010354996,159058,0,54098.17010354996,0.6288000345230103,1.838891744613648,10000,56007.06683254242,0.9265186190605164,0.5162858366966248,0.7549399733543396,1.2046244144439695,50000 -1914.979204893112,5.611308336257935,54608.08381175995,160560,0,54608.08381175995,0.6307000517845154,1.830616116523743,10000,56534.48181200028,0.924824595451355,0.5198838114738464,0.7564599514007568,1.200865626335144,50000 -1932.6419672966003,5.669963121414185,55118.07119345665,162062,0,55118.07119345665,0.6338000297546387,1.8089417219161987,10000,57062.24536585808,0.9266381859779358,0.4990904629230499,0.7584199905395508,1.1812278032302856,50000 -1950.3917593956,5.731624841690064,55628.10046887398,163564,0,55628.10046887398,0.6325000524520874,1.8226237297058103,10000,57590.141486644745,0.9269371628761292,0.5071867108345032,0.7584199905395508,1.188880205154419,50000 -1967.9118270874023,5.779476404190064,56138.13663291931,165066,0,56138.13663291931,0.6345000267028809,1.815388798713684,10000,58117.80002164841,0.9267378449440002,0.4970170259475708,0.7583799958229065,1.1788029670715332,50000 -1985.462195396424,5.834491014480591,56648.308772563934,166568,0,56648.308772563934,0.6345000267028809,1.816957950592041,10000,58645.63115429878,0.9300262928009032,0.5010972619056702,0.7594999670982361,1.1896491050720217,50000 -2003.1597566604607,5.8996758460998535,57158.31609511376,168070,0,57158.31609511376,0.6356000304222107,1.80722975730896,10000,59173.45578980446,0.9396922588348388,0.4563548862934112,0.7609599828720093,1.1790025234222412,50000 -2020.6457545757287,5.955905914306641,57668.43019104004,169572,0,57668.43019104004,0.6360000371932983,1.8080520629882808,10000,59701.16584134102,0.9342115521430968,0.4681693911552429,0.7597799897193909,1.1787148714065552,50000 -2038.581184387207,6.014849662780762,58178.49935603142,171074,0,58178.49935603142,0.6367000341415405,1.8073753118515008,10000,60229.28386855125,0.9379583597183228,0.4724557101726532,0.7612800002098083,1.180039882659912,50000 -2056.078993320465,6.078775405883789,58688.43921303749,172575,0,58688.43921303749,0.6373000144958496,1.7978694438934326,10000,60756.84143066406,0.9350087642669678,0.4679067730903625,0.761680006980896,1.1774041652679443,50000 -2073.7847397327423,6.137209415435791,59198.377883434296,174077,0,59198.377883434296,0.6359000205993652,1.802971720695496,10000,61284.59939098358,0.936344027519226,0.4684905409812927,0.7619799971580505,1.1775026321411133,50000 -2091.134479045868,6.194151878356934,59708.31727671623,175578,0,59708.31727671623,0.6399000287055969,1.8003582954406738,10000,61812.00027608872,0.9384765625,0.4636580049991607,0.7625199556350708,1.177310824394226,50000 -2109.08202624321,6.251053094863892,60218.3447508812,177080,0,60218.3447508812,0.6403000354766846,1.799221396446228,10000,62340.08955526352,0.94046950340271,0.4543436765670776,0.762779951095581,1.1745065450668335,50000 -2126.7394778728485,6.30780291557312,60728.52793264389,178582,0,60728.52793264389,0.6392000317573547,1.7995479106903076,10000,62868.04338693619,0.9402901530265808,0.4555262923240661,0.7623199820518494,1.173393964767456,50000 -2144.196943998337,6.366096258163452,61238.67451667786,180084,0,61238.67451667786,0.6391000151634216,1.7973567247390747,10000,63395.76078367233,0.9409279227256776,0.4550089836120605,0.7625799775123596,1.172922968864441,50000 -2161.671940803528,6.423700332641602,61748.84334039688,181586,0,61748.84334039688,0.6389000415802002,1.7996957302093506,10000,63923.51679396629,0.941824734210968,0.455964058637619,0.7629599571228027,1.1759765148162842,50000 -2179.367330789566,6.485302448272705,62258.9725549221,183087,0,62258.9725549221,0.6394000053405762,1.797763705253601,10000,64451.45683383942,0.9397321343421936,0.4528777301311493,0.7628600001335144,1.174703598022461,50000 -2197.1233875751495,6.553771495819092,62768.99567198753,184588,0,62768.99567198753,0.6394000053405762,1.797345757484436,10000,64979.35884261131,0.939473032951355,0.4559187591075897,0.7632799744606018,1.1745176315307617,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/measurements.csv deleted file mode 100644 index 8a89bb066..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1979 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.5283198,6.9328847,,,,,,,,,,,,,, -1,,,0.0007174744969233,6.912591457366943,0.0007599999662488,6.913174629211426,50000.0,0.0006000000284984,6.9125494956970215,10000.0,31.89585375785828,50.27870035171509,31.89585375785828,18.382753133773804,0.0,0.0 -100,0.5241796,6.898794,,,,,,,,,,,,,, -200,0.52206147,6.8605213,,,,,,,,,,,,,, -300,0.57412195,6.788781,,,,,,,,,,,,,, -400,0.6246461,6.6790133,,,,,,,,,,,,,, -500,0.6514214,6.5994415,,,,,,,,,,,,,, -600,0.69758564,6.512314,,,,,,,,,,,,,, -700,0.7566064,6.461299,,,,,,,,,,,,,, -800,1.135739,6.3821087,,,,,,,,,,,,,, -900,2.1312175,6.2485323,,,,,,,,,,,,,, -1000,1.9387649,6.2148266,,,,,,,,,,,,,, -1100,2.9553583,6.104237,,,,,,,,,,,,,, -1200,2.3821092,6.08021,,,,,,,,,,,,,, -1300,1.8901248,6.0169168,,,,,,,,,,,,,, -1400,1.646095,5.9225245,,,,,,,,,,,,,, -1497,,,0.0817721635103225,5.256331920623779,0.0762000009417533,5.30866289138794,50000.0,0.053700003772974,5.532280921936035,10000.0,541.9562501907349,578.1606819629669,541.9562501907349,36.13430953025818,0.0164451599121093,0.0 -1500,1.8909599,5.9691405,,,,,,,,,,,,,, -1600,2.170697,5.821894,,,,,,,,,,,,,, -1700,2.8188143,5.800589,,,,,,,,,,,,,, -1800,2.7524683,5.750079,,,,,,,,,,,,,, -1900,2.6154566,5.741132,,,,,,,,,,,,,, -2000,3.7645876,5.7192364,,,,,,,,,,,,,, -2100,2.2502007,5.6725483,,,,,,,,,,,,,, -2200,3.384165,5.608825,,,,,,,,,,,,,, -2300,3.6330209,5.5409164,,,,,,,,,,,,,, -2400,2.7742827,5.576455,,,,,,,,,,,,,, -2500,3.6329434,5.463419,,,,,,,,,,,,,, -2600,4.5306497,5.411351,,,,,,,,,,,,,, -2700,5.019565,5.4204025,,,,,,,,,,,,,, -2800,3.3124259,5.357515,,,,,,,,,,,,,, -2900,3.2198222,5.294589,,,,,,,,,,,,,, -2992,,,0.1966477930545807,4.2057881355285645,0.1745599955320358,4.332620143890381,50000.0,0.1266000121831894,4.754936218261719,10000.0,1051.935497045517,1105.7899084091189,1051.935497045517,53.70322799682617,0.0438079833984375,0.0 -3000,3.8637617,5.247735,,,,,,,,,,,,,, -3100,8.558143,5.257143,,,,,,,,,,,,,, -3200,4.2212853,5.219947,,,,,,,,,,,,,, -3300,4.067183,5.1441917,,,,,,,,,,,,,, -3400,4.2746177,5.1429787,,,,,,,,,,,,,, -3500,4.661178,5.117332,,,,,,,,,,,,,, -3600,3.7771175,5.0686593,,,,,,,,,,,,,, -3700,3.3259697,5.0062203,,,,,,,,,,,,,, -3800,4.3260098,4.920019,,,,,,,,,,,,,, -3900,4.0552235,4.9167185,,,,,,,,,,,,,, -4000,3.6045206,4.9553804,,,,,,,,,,,,,, -4100,3.8966835,4.9085865,,,,,,,,,,,,,, -4200,6.3161254,4.884485,,,,,,,,,,,,,, -4300,4.157747,4.752102,,,,,,,,,,,,,, -4400,2.5730386,4.7495575,,,,,,,,,,,,,, -4488,,,0.3034917116165161,3.4856112003326416,0.2791000008583069,3.62486720085144,50000.0,0.2101000100374221,4.122847080230713,10000.0,1562.1542949676514,1634.7661957740784,1562.1542949676514,72.37804913520813,0.0735633373260498,0.0 -4500,3.087323,4.732796,,,,,,,,,,,,,, -4600,5.328942,4.718866,,,,,,,,,,,,,, -4700,2.6070623,4.6965547,,,,,,,,,,,,,, -4800,3.1849854,4.6467075,,,,,,,,,,,,,, -4900,3.3476872,4.650573,,,,,,,,,,,,,, -5000,2.7831793,4.548101,,,,,,,,,,,,,, -5100,3.2844114,4.5763793,,,,,,,,,,,,,, -5200,3.2538784,4.5684547,,,,,,,,,,,,,, -5300,2.930448,4.633493,,,,,,,,,,,,,, -5400,4.1911397,4.4748993,,,,,,,,,,,,,, -5500,3.8786807,4.4702425,,,,,,,,,,,,,, -5600,4.3059473,4.5008783,,,,,,,,,,,,,, -5700,3.5864675,4.3854866,,,,,,,,,,,,,, -5800,3.2332742,4.488038,,,,,,,,,,,,,, -5900,2.5284195,4.463169,,,,,,,,,,,,,, -5984,,,0.3869180381298065,2.9858076572418213,0.359059989452362,3.1390581130981445,50000.0,0.2732000052928924,3.685528755187988,10000.0,2072.106372833252,2162.559784412384,2072.106372833252,90.1362087726593,0.1037073135375976,0.0 -6000,3.2141957,4.391487,,,,,,,,,,,,,, -6100,2.930135,4.4424634,,,,,,,,,,,,,, -6200,2.9498386,4.307802,,,,,,,,,,,,,, -6300,4.6050067,4.337803,,,,,,,,,,,,,, -6400,2.693526,4.308502,,,,,,,,,,,,,, -6500,3.295055,4.239626,,,,,,,,,,,,,, -6600,3.5005908,4.3308744,,,,,,,,,,,,,, -6700,3.4913795,4.2045765,,,,,,,,,,,,,, -6800,2.1095755,4.2507257,,,,,,,,,,,,,, -6900,2.6484313,4.1688223,,,,,,,,,,,,,, -7000,3.0650318,4.1407127,,,,,,,,,,,,,, -7100,2.589636,4.181346,,,,,,,,,,,,,, -7200,2.9015603,4.2316346,,,,,,,,,,,,,, -7300,1.7121224,4.151163,,,,,,,,,,,,,, -7400,2.2345614,4.1051626,,,,,,,,,,,,,, -7480,,,0.4513911008834839,2.6858718395233154,0.4205600023269653,2.84183406829834,50000.0,0.3192000091075897,3.449917316436768,10000.0,2582.1680810451508,2690.564907312393,2582.1680810451508,107.99761819839478,0.1325531005859375,0.0 -7500,3.160996,4.1207747,,,,,,,,,,,,,, -7600,2.0984707,4.0795856,,,,,,,,,,,,,, -7700,4.5703597,4.2041707,,,,,,,,,,,,,, -7800,2.3721395,4.0732565,,,,,,,,,,,,,, -7900,2.3430889,4.175673,,,,,,,,,,,,,, -8000,2.5874655,4.0576134,,,,,,,,,,,,,, -8100,1.8192383,4.0208745,,,,,,,,,,,,,, -8200,2.5658236,4.0210924,,,,,,,,,,,,,, -8300,2.5638773,4.0718145,,,,,,,,,,,,,, -8400,2.428488,3.98462,,,,,,,,,,,,,, -8500,1.8562537,3.9804497,,,,,,,,,,,,,, -8600,2.5700817,4.0664487,,,,,,,,,,,,,, -8700,1.9473099,4.0004544,,,,,,,,,,,,,, -8800,1.8312047,3.9638174,,,,,,,,,,,,,, -8900,1.8122392,3.9537725,,,,,,,,,,,,,, -8977,,,0.5364516973495483,2.2150745391845703,0.4718599915504455,2.5311429500579834,50000.0,0.3625000119209289,3.1810948848724365,10000.0,3092.3153219223022,3218.483143806457,3092.3153219223022,125.6869819164276,0.1603870391845703,0.0 -9000,2.3466165,3.9414344,,,,,,,,,,,,,, -9100,2.2409563,3.9262,,,,,,,,,,,,,, -9200,2.8388999,4.0004582,,,,,,,,,,,,,, -9300,2.0417151,3.8756845,,,,,,,,,,,,,, -9400,2.5272675,3.9083562,,,,,,,,,,,,,, -9500,2.4844139,3.9661367,,,,,,,,,,,,,, -9600,2.1827288,3.847108,,,,,,,,,,,,,, -9700,2.0230193,3.9765348,,,,,,,,,,,,,, -9800,1.6997831,3.8402104,,,,,,,,,,,,,, -9900,2.557971,3.8878448,,,,,,,,,,,,,, -10000,2.7434857,3.7793248,,,,,,,,,,,,,, -10100,1.6489272,3.9680552,,,,,,,,,,,,,, -10200,1.7710121,3.9406335,,,,,,,,,,,,,, -10300,1.5906348,3.7651324,,,,,,,,,,,,,, -10400,1.9730321,3.840883,,,,,,,,,,,,,, -10474,,,0.5558035373687744,2.1703338623046875,0.5116599798202515,2.3818368911743164,50000.0,0.3958000242710113,3.038412094116211,10000.0,3602.40500998497,3746.247165679932,3602.40500998497,143.2770929336548,0.1908240318298339,0.0 -10500,1.4556956,3.77397,,,,,,,,,,,,,, -10600,2.978499,3.69481,,,,,,,,,,,,,, -10700,1.6310514,3.8080816,,,,,,,,,,,,,, -10800,2.0602067,3.7943535,,,,,,,,,,,,,, -10900,2.0177212,3.8169034,,,,,,,,,,,,,, -11000,2.8241372,3.8214655,,,,,,,,,,,,,, -11100,1.7370955,3.677622,,,,,,,,,,,,,, -11200,1.6415027,3.7186809,,,,,,,,,,,,,, -11300,2.3631063,3.7525325,,,,,,,,,,,,,, -11400,1.6746709,3.6565704,,,,,,,,,,,,,, -11500,2.3634188,3.689394,,,,,,,,,,,,,, -11600,1.4864968,3.6192355,,,,,,,,,,,,,, -11700,2.165313,3.695363,,,,,,,,,,,,,, -11800,1.8907957,3.67137,,,,,,,,,,,,,, -11900,1.664368,3.6835437,,,,,,,,,,,,,, -11972,,,0.5846021771430969,1.9957231283187864,0.5405799746513367,2.2037601470947266,50000.0,0.4250000119209289,2.862744092941284,10000.0,4112.561676979065,4274.270484447479,4112.561676979065,161.05826830863953,0.2227044105529785,0.0 -12000,1.8608259,3.7448626,,,,,,,,,,,,,, -12100,1.5861801,3.6528006,,,,,,,,,,,,,, -12200,1.7729118,3.6995716,,,,,,,,,,,,,, -12300,1.8914872,3.6784203,,,,,,,,,,,,,, -12400,2.4540408,3.5998716,,,,,,,,,,,,,, -12500,1.7541232,3.632567,,,,,,,,,,,,,, -12600,1.8973943,3.6192586,,,,,,,,,,,,,, -12700,2.2958777,3.559391,,,,,,,,,,,,,, -12800,1.9402441,3.6550705,,,,,,,,,,,,,, -12900,1.830741,3.6944153,,,,,,,,,,,,,, -13000,1.5203447,3.5880625,,,,,,,,,,,,,, -13100,1.5295982,3.6502838,,,,,,,,,,,,,, -13200,1.7963171,3.59414,,,,,,,,,,,,,, -13300,1.4951613,3.6180234,,,,,,,,,,,,,, -13400,1.621865,3.624858,,,,,,,,,,,,,, -13469,,,0.6219108700752258,1.801672101020813,0.5726400017738342,2.031932592391968,50000.0,0.4506000280380249,2.675619602203369,10000.0,4622.485732078552,4801.622956514359,4622.485732078552,178.40104007720947,0.2536995410919189,0.0 -13500,1.8349671,3.5878394,,,,,,,,,,,,,, -13600,1.8539764,3.5432737,,,,,,,,,,,,,, -13700,1.4802797,3.5894103,,,,,,,,,,,,,, -13800,1.6011323,3.536979,,,,,,,,,,,,,, -13900,2.551987,3.5588298,,,,,,,,,,,,,, -14000,2.010386,3.5767157,,,,,,,,,,,,,, -14100,1.5861363,3.5767965,,,,,,,,,,,,,, -14200,2.0754018,3.6057305,,,,,,,,,,,,,, -14300,1.6665641,3.5407972,,,,,,,,,,,,,, -14400,1.6555679,3.6090143,,,,,,,,,,,,,, -14500,1.6084344,3.5982616,,,,,,,,,,,,,, -14600,1.360292,3.5180008,,,,,,,,,,,,,, -14700,1.3419622,3.4775524,,,,,,,,,,,,,, -14800,1.5491649,3.4731913,,,,,,,,,,,,,, -14900,1.4666903,3.4801674,,,,,,,,,,,,,, -14968,,,0.6355029940605164,1.682795166969299,0.5859599709510803,1.926688551902771,50000.0,0.4617000222206116,2.591858148574829,10000.0,5132.471919298172,5329.625428676605,5132.471919298172,196.3314588069916,0.285891056060791,0.0 -15000,1.7891246,3.560409,,,,,,,,,,,,,, -15100,1.7020175,3.4943273,,,,,,,,,,,,,, -15200,1.3569092,3.4911392,,,,,,,,,,,,,, -15300,1.6996375,3.5259914,,,,,,,,,,,,,, -15400,1.4445978,3.4429116,,,,,,,,,,,,,, -15500,2.0429003,3.5206087,,,,,,,,,,,,,, -15600,1.5598063,3.5055628,,,,,,,,,,,,,, -15700,1.5545195,3.5115654,,,,,,,,,,,,,, -15800,1.4297017,3.4601378,,,,,,,,,,,,,, -15900,1.3845029,3.4971416,,,,,,,,,,,,,, -16000,1.1556736,3.4793644,,,,,,,,,,,,,, -16100,1.6618459,3.5136058,,,,,,,,,,,,,, -16200,1.408853,3.492556,,,,,,,,,,,,,, -16300,1.5321128,3.4875042,,,,,,,,,,,,,, -16400,1.3149525,3.5460143,,,,,,,,,,,,,, -16467,,,0.6412428021430969,1.7201507091522217,0.5932999849319458,1.9380810260772705,50000.0,0.4690000116825104,2.6094350814819336,10000.0,5642.635751485825,5857.638697385788,5642.635751485825,214.10228490829468,0.3106076717376709,0.0 -16500,1.834585,3.4825158,,,,,,,,,,,,,, -16600,1.4616971,3.4847293,,,,,,,,,,,,,, -16700,1.3869349,3.4220998,,,,,,,,,,,,,, -16800,1.4440433,3.4796462,,,,,,,,,,,,,, -16900,1.384665,3.4339502,,,,,,,,,,,,,, -17000,1.5575535,3.4864433,,,,,,,,,,,,,, -17100,1.9025927,3.4916987,,,,,,,,,,,,,, -17200,1.6030581,3.4133902,,,,,,,,,,,,,, -17300,1.3507742,3.3779452,,,,,,,,,,,,,, -17400,1.3850855,3.5239346,,,,,,,,,,,,,, -17500,1.3876996,3.4096107,,,,,,,,,,,,,, -17600,1.2540349,3.3971162,,,,,,,,,,,,,, -17700,1.4078841,3.3839803,,,,,,,,,,,,,, -17800,1.5934136,3.4294133,,,,,,,,,,,,,, -17900,1.3744987,3.3491113,,,,,,,,,,,,,, -17967,,,0.7058354616165161,1.4462586641311646,0.6126999855041504,1.8431637287139893,50000.0,0.4798000156879425,2.5136966705322266,10000.0,6152.796566486359,6385.609614610672,6152.796566486359,231.8304960727692,0.3401076793670654,0.0 -18000,1.096348,3.400273,,,,,,,,,,,,,, -18100,1.5520208,3.486947,,,,,,,,,,,,,, -18200,1.4916172,3.453117,,,,,,,,,,,,,, -18300,1.4596202,3.4282026,,,,,,,,,,,,,, -18400,1.2805926,3.3793192,,,,,,,,,,,,,, -18500,1.5901549,3.4124124,,,,,,,,,,,,,, -18600,1.6713231,3.3414583,,,,,,,,,,,,,, -18700,1.6947097,3.5116203,,,,,,,,,,,,,, -18800,1.7133088,3.367697,,,,,,,,,,,,,, -18900,1.3500212,3.3987637,,,,,,,,,,,,,, -19000,1.2275169,3.4076807,,,,,,,,,,,,,, -19100,1.4643979,3.4076545,,,,,,,,,,,,,, -19200,1.5271051,3.4198666,,,,,,,,,,,,,, -19300,1.3366375,3.4324934,,,,,,,,,,,,,, -19400,1.3030791,3.4301863,,,,,,,,,,,,,, -19466,,,0.6865234375,1.4992294311523438,0.6133399605751038,1.816870212554932,50000.0,0.4852000176906585,2.5008463859558105,10000.0,6662.721271753311,6913.4088344573975,6662.721271753311,249.6162793636322,0.3739089965820312,0.0 -19500,1.2679945,3.458918,,,,,,,,,,,,,, -19600,1.5753176,3.4929788,,,,,,,,,,,,,, -19700,1.4744117,3.376565,,,,,,,,,,,,,, -19800,1.307913,3.335631,,,,,,,,,,,,,, -19900,1.9410918,3.4199033,,,,,,,,,,,,,, -20000,1.5134418,3.3562994,,,,,,,,,,,,,, -20100,1.4645442,3.373961,,,,,,,,,,,,,, -20200,1.3493736,3.4343846,,,,,,,,,,,,,, -20300,1.2155823,3.382232,,,,,,,,,,,,,, -20400,1.4501508,3.347276,,,,,,,,,,,,,, -20500,1.2673472,3.397231,,,,,,,,,,,,,, -20600,1.4706128,3.3020573,,,,,,,,,,,,,, -20700,1.0927155,3.2833273,,,,,,,,,,,,,, -20800,1.467301,3.381575,,,,,,,,,,,,,, -20900,1.1877114,3.3153932,,,,,,,,,,,,,, -20966,,,0.6920639276504517,1.470927596092224,0.6231799721717834,1.7682565450668335,50000.0,0.501300036907196,2.41690993309021,10000.0,7172.737153053284,7441.06632399559,7172.737153053284,267.1690058708191,0.4088256359100342,0.0 -21000,1.2550592,3.4102125,,,,,,,,,,,,,, -21100,1.1401473,3.3204114,,,,,,,,,,,,,, -21200,1.7677896,3.4005747,,,,,,,,,,,,,, -21300,1.8391361,3.39985,,,,,,,,,,,,,, -21400,1.6672812,3.376821,,,,,,,,,,,,,, -21500,1.8906811,3.3909044,,,,,,,,,,,,,, -21600,1.2704034,3.3972607,,,,,,,,,,,,,, -21700,1.3164011,3.3385494,,,,,,,,,,,,,, -21800,1.4169477,3.3355703,,,,,,,,,,,,,, -21900,1.3852291,3.3585148,,,,,,,,,,,,,, -22000,1.4538008,3.3226366,,,,,,,,,,,,,, -22100,1.4749391,3.391639,,,,,,,,,,,,,, -22200,1.3574789,3.3454885,,,,,,,,,,,,,, -22300,1.3324914,3.2957072,,,,,,,,,,,,,, -22400,1.6014754,3.3756318,,,,,,,,,,,,,, -22465,,,0.6925621628761292,1.4761664867401123,0.6293399930000305,1.753197431564331,50000.0,0.5006999969482422,2.429164409637451,10000.0,7682.723104953766,7968.861889123917,7682.723104953766,284.8941237926483,0.4400563240051269,0.0 -22500,1.3101505,3.2909927,,,,,,,,,,,,,, -22600,1.3532025,3.3352542,,,,,,,,,,,,,, -22700,1.4462137,3.3952937,,,,,,,,,,,,,, -22800,1.0667354,3.355881,,,,,,,,,,,,,, -22900,1.463807,3.275091,,,,,,,,,,,,,, -23000,1.4020044,3.3602057,,,,,,,,,,,,,, -23100,1.4748548,3.3066914,,,,,,,,,,,,,, -23200,1.238661,3.3358126,,,,,,,,,,,,,, -23300,1.5508473,3.3186367,,,,,,,,,,,,,, -23400,1.3094006,3.3201854,,,,,,,,,,,,,, -23500,1.4455941,3.309695,,,,,,,,,,,,,, -23600,1.3100777,3.3830228,,,,,,,,,,,,,, -23700,1.2724595,3.3704581,,,,,,,,,,,,,, -23800,1.3073813,3.3859944,,,,,,,,,,,,,, -23900,1.574313,3.3356912,,,,,,,,,,,,,, -23966,,,0.6956911683082581,1.4922162294387815,0.634719967842102,1.7566068172454834,50000.0,0.5092000365257263,2.4119579792022705,10000.0,8192.924956321716,8496.729290246964,8192.924956321716,302.46875619888306,0.4770851135253906,0.0 -24000,1.5733496,3.3564355,,,,,,,,,,,,,, -24100,1.29834,3.3442357,,,,,,,,,,,,,, -24200,1.2330309,3.3126435,,,,,,,,,,,,,, -24300,1.6697016,3.3666916,,,,,,,,,,,,,, -24400,1.3457882,3.3624198,,,,,,,,,,,,,, -24500,1.3601652,3.3549166,,,,,,,,,,,,,, -24600,1.4499872,3.3666642,,,,,,,,,,,,,, -24700,1.2885109,3.2739851,,,,,,,,,,,,,, -24800,1.5096678,3.2816513,,,,,,,,,,,,,, -24900,1.4204267,3.33487,,,,,,,,,,,,,, -25000,1.3219984,3.3505573,,,,,,,,,,,,,, -25100,1.3165311,3.3698244,,,,,,,,,,,,,, -25200,1.3596787,3.3685787,,,,,,,,,,,,,, -25300,1.2689062,3.3767123,,,,,,,,,,,,,, -25400,1.268156,3.3565474,,,,,,,,,,,,,, -25466,,,0.7027263641357422,1.419603705406189,0.6396999955177307,1.6906665563583374,50000.0,0.5162000060081482,2.3508965969085693,10000.0,8702.947353839874,9024.420098781586,8702.947353839874,320.05310821533203,0.508368968963623,0.0 -25500,1.3156027,3.351417,,,,,,,,,,,,,, -25600,1.35685,3.278268,,,,,,,,,,,,,, -25700,1.333398,3.317309,,,,,,,,,,,,,, -25800,1.8105712,3.2889507,,,,,,,,,,,,,, -25900,1.3544956,3.2721772,,,,,,,,,,,,,, -26000,1.5119681,3.3780708,,,,,,,,,,,,,, -26100,1.770259,3.2327309,,,,,,,,,,,,,, -26200,1.3735023,3.297627,,,,,,,,,,,,,, -26300,1.237117,3.3119478,,,,,,,,,,,,,, -26400,1.675851,3.2888377,,,,,,,,,,,,,, -26500,1.4254296,3.3413274,,,,,,,,,,,,,, -26600,1.3300227,3.292577,,,,,,,,,,,,,, -26700,1.6941563,3.3227248,,,,,,,,,,,,,, -26800,1.2421155,3.295206,,,,,,,,,,,,,, -26900,1.3783182,3.1836617,,,,,,,,,,,,,, -26966,,,0.6951530575752258,1.484920859336853,0.6298199892044067,1.7593168020248413,50000.0,0.4951000213623047,2.4571080207824707,10000.0,9212.861717700958,9551.975267887115,9212.861717700958,337.60744285583496,0.5415892601013184,0.0 -27000,1.3287219,3.3246427,,,,,,,,,,,,,, -27100,1.2010545,3.2559714,,,,,,,,,,,,,, -27200,1.4231368,3.3425252,,,,,,,,,,,,,, -27300,1.2691578,3.3315475,,,,,,,,,,,,,, -27400,1.3431518,3.257994,,,,,,,,,,,,,, -27500,1.368314,3.2593875,,,,,,,,,,,,,, -27600,1.5389137,3.2935681,,,,,,,,,,,,,, -27700,1.2156265,3.2595382,,,,,,,,,,,,,, -27800,1.4973458,3.3342264,,,,,,,,,,,,,, -27900,1.2910719,3.3078127,,,,,,,,,,,,,, -28000,1.3665305,3.2878807,,,,,,,,,,,,,, -28100,1.5032713,3.2901776,,,,,,,,,,,,,, -28200,1.6126231,3.3071983,,,,,,,,,,,,,, -28300,1.2886844,3.2427335,,,,,,,,,,,,,, -28400,1.3746569,3.229061,,,,,,,,,,,,,, -28466,,,0.7315250039100647,1.2711821794509888,0.6417799592018127,1.6622222661972046,50000.0,0.5190000534057617,2.301248073577881,10000.0,9722.92922616005,10079.964128255844,9722.92922616005,355.4415764808655,0.5747489929199219,0.0 -28500,1.1700723,3.1905997,,,,,,,,,,,,,, -28600,1.3205484,3.343742,,,,,,,,,,,,,, -28700,1.3205229,3.3259413,,,,,,,,,,,,,, -28800,1.2776606,3.2407873,,,,,,,,,,,,,, -28900,1.2898448,3.3052883,,,,,,,,,,,,,, -29000,1.3570479,3.232236,,,,,,,,,,,,,, -29100,1.4693669,3.2483208,,,,,,,,,,,,,, -29200,1.3753164,3.2342248,,,,,,,,,,,,,, -29300,1.4855056,3.3348842,,,,,,,,,,,,,, -29400,1.7182355,3.3182821,,,,,,,,,,,,,, -29500,1.3515348,3.32091,,,,,,,,,,,,,, -29600,1.312782,3.2552364,,,,,,,,,,,,,, -29700,1.3738242,3.2203605,,,,,,,,,,,,,, -29800,1.3700948,3.2368677,,,,,,,,,,,,,, -29900,1.444437,3.2761197,,,,,,,,,,,,,, -29967,,,0.7155213356018066,1.3536566495895386,0.6436799764633179,1.67797589302063,50000.0,0.5196000337600708,2.333815813064575,10000.0,10233.182272434236,10608.206499814987,10233.182272434236,373.3464798927307,0.6060409545898438,0.0 -30000,1.5442457,3.3019438,,,,,,,,,,,,,, -30100,1.4453877,3.247599,,,,,,,,,,,,,, -30200,1.3216798,3.3049874,,,,,,,,,,,,,, -30300,1.3419735,3.2778897,,,,,,,,,,,,,, -30400,1.5077024,3.2171774,,,,,,,,,,,,,, -30500,1.4382607,3.2593033,,,,,,,,,,,,,, -30600,1.3406478,3.2751267,,,,,,,,,,,,,, -30700,1.3536104,3.2265167,,,,,,,,,,,,,, -30800,1.3627405,3.2793944,,,,,,,,,,,,,, -30900,1.4606522,3.3320599,,,,,,,,,,,,,, -31000,1.547414,3.2757103,,,,,,,,,,,,,, -31100,1.4651533,3.262547,,,,,,,,,,,,,, -31200,1.3662856,3.2923496,,,,,,,,,,,,,, -31300,1.4754635,3.247284,,,,,,,,,,,,,, -31400,1.3262419,3.3081932,,,,,,,,,,,,,, -31468,,,0.7057158946990967,1.420884132385254,0.6421200037002563,1.7156926393508911,50000.0,0.5074000358581543,2.405803203582764,10000.0,10743.343856096268,11136.228532791138,10743.343856096268,391.1074600219727,0.6514892578125,0.0 -31500,1.421159,3.2863812,,,,,,,,,,,,,, -31600,1.5831879,3.1420722,,,,,,,,,,,,,, -31700,1.1381956,3.1958616,,,,,,,,,,,,,, -31800,1.3522683,3.2672749,,,,,,,,,,,,,, -31900,1.489488,3.2253659,,,,,,,,,,,,,, -32000,1.3197564,3.2031994,,,,,,,,,,,,,, -32100,1.5152638,3.2962828,,,,,,,,,,,,,, -32200,1.4031519,3.25154,,,,,,,,,,,,,, -32300,1.2014909,3.138996,,,,,,,,,,,,,, -32400,1.3877122,3.3004773,,,,,,,,,,,,,, -32500,1.4006225,3.2171075,,,,,,,,,,,,,, -32600,1.4169319,3.260499,,,,,,,,,,,,,, -32700,1.5269971,3.2859483,,,,,,,,,,,,,, -32800,1.2984364,3.2391074,,,,,,,,,,,,,, -32900,1.3668567,3.1769557,,,,,,,,,,,,,, -32968,,,0.7166374325752258,1.359977960586548,0.6487799882888794,1.6539162397384644,50000.0,0.5246000289916992,2.299521446228028,10000.0,11253.426638364792,11664.150138139725,11253.426638364792,408.85849595069885,0.6858878135681152,0.0 -33000,1.4125587,3.2081344,,,,,,,,,,,,,, -33100,1.4317306,3.2566001,,,,,,,,,,,,,, -33200,1.3807633,3.2380323,,,,,,,,,,,,,, -33300,1.3792028,3.2165108,,,,,,,,,,,,,, -33400,1.7005692,3.2615604,,,,,,,,,,,,,, -33500,1.3802718,3.2585526,,,,,,,,,,,,,, -33600,1.4115453,3.3505812,,,,,,,,,,,,,, -33700,1.6548789,3.1270173,,,,,,,,,,,,,, -33800,1.4211758,3.267778,,,,,,,,,,,,,, -33900,1.3567433,3.16332,,,,,,,,,,,,,, -34000,1.267877,3.256898,,,,,,,,,,,,,, -34100,1.7512294,3.2528894,,,,,,,,,,,,,, -34200,1.3435296,3.134774,,,,,,,,,,,,,, -34300,1.3475505,3.2728608,,,,,,,,,,,,,, -34400,1.5285721,3.18921,,,,,,,,,,,,,, -34468,,,0.70703125,1.396243453025818,0.6412400007247925,1.682578444480896,50000.0,0.5151000022888184,2.3533051013946533,10000.0,11763.347280740738,12191.771873950958,11763.347280740738,426.4737284183502,0.7184295654296875,0.0 -34500,2.4018073,3.1825752,,,,,,,,,,,,,, -34600,1.3567549,3.2017236,,,,,,,,,,,,,, -34700,1.3059043,3.1927607,,,,,,,,,,,,,, -34800,1.5453038,3.281506,,,,,,,,,,,,,, -34900,1.4428718,3.240562,,,,,,,,,,,,,, -35000,1.4385328,3.2334032,,,,,,,,,,,,,, -35100,1.4978836,3.2069993,,,,,,,,,,,,,, -35200,1.4566385,3.204801,,,,,,,,,,,,,, -35300,1.5192435,3.2414644,,,,,,,,,,,,,, -35400,1.4238966,3.195009,,,,,,,,,,,,,, -35500,1.3367873,3.2037115,,,,,,,,,,,,,, -35600,1.6225901,3.2790258,,,,,,,,,,,,,, -35700,1.2654195,3.20513,,,,,,,,,,,,,, -35800,1.4959688,3.2077913,,,,,,,,,,,,,, -35900,1.701215,3.2019598,,,,,,,,,,,,,, -35969,,,0.7151227593421936,1.3773218393325806,0.6521199941635132,1.6565839052200315,50000.0,0.522599995136261,2.307307243347168,10000.0,12273.410103797913,12719.69810295105,12273.410103797913,444.2426521778106,0.7576742172241211,0.0 -36000,1.6025653,3.1903048,,,,,,,,,,,,,, -36100,1.5841511,3.282559,,,,,,,,,,,,,, -36200,1.4882604,3.1876574,,,,,,,,,,,,,, -36300,1.551175,3.2622006,,,,,,,,,,,,,, -36400,1.457344,3.272387,,,,,,,,,,,,,, -36500,1.4038508,3.2280626,,,,,,,,,,,,,, -36600,1.4671478,3.2246146,,,,,,,,,,,,,, -36700,1.4980707,3.179183,,,,,,,,,,,,,, -36800,1.3430902,3.199854,,,,,,,,,,,,,, -36900,1.4981223,3.2122421,,,,,,,,,,,,,, -37000,1.3830748,3.286107,,,,,,,,,,,,,, -37100,1.5758682,3.2356534,,,,,,,,,,,,,, -37200,1.4801412,3.190946,,,,,,,,,,,,,, -37300,1.6767535,3.1962838,,,,,,,,,,,,,, -37400,1.4530421,3.1399517,,,,,,,,,,,,,, -37469,,,0.7365872263908386,1.2994897365570068,0.6460199952125549,1.7007862329483032,50000.0,0.5078000426292419,2.3833701610565186,10000.0,12783.423173427582,13247.574071884155,12783.423173427582,462.0181083679199,0.7907888889312744,0.0 -37500,1.5310118,3.2788095,,,,,,,,,,,,,, -37600,1.4445587,3.1677766,,,,,,,,,,,,,, -37700,1.4160094,3.2088304,,,,,,,,,,,,,, -37800,1.3030545,3.1894476,,,,,,,,,,,,,, -37900,1.5891311,3.2992506,,,,,,,,,,,,,, -38000,1.5156657,3.2575445,,,,,,,,,,,,,, -38100,1.3953688,3.2758214,,,,,,,,,,,,,, -38200,1.5577465,3.1942067,,,,,,,,,,,,,, -38300,1.7100283,3.312676,,,,,,,,,,,,,, -38400,1.7151835,3.1905124,,,,,,,,,,,,,, -38500,1.7906806,3.201303,,,,,,,,,,,,,, -38600,1.4178487,3.2203712,,,,,,,,,,,,,, -38700,1.5861655,3.146536,,,,,,,,,,,,,, -38800,1.4957926,3.2289746,,,,,,,,,,,,,, -38900,1.4492806,3.1954644,,,,,,,,,,,,,, -38970,,,0.7347337007522583,1.2530560493469238,0.6542199850082397,1.6197998523712158,50000.0,0.5218999981880188,2.2923173904418945,10000.0,13293.508579730988,13775.50724887848,13293.508579730988,479.7775735855103,0.8264029026031494,0.0 -39000,1.4416811,3.1374798,,,,,,,,,,,,,, -39100,1.6043457,3.1483793,,,,,,,,,,,,,, -39200,1.6753587,3.1739173,,,,,,,,,,,,,, -39300,1.5565561,3.2094283,,,,,,,,,,,,,, -39400,1.5243418,3.2779822,,,,,,,,,,,,,, -39500,1.6442515,3.2077873,,,,,,,,,,,,,, -39600,1.4732958,3.1649659,,,,,,,,,,,,,, -39700,1.6657567,3.2214327,,,,,,,,,,,,,, -39800,1.6745706,3.2380927,,,,,,,,,,,,,, -39900,1.5187309,3.1889684,,,,,,,,,,,,,, -40000,1.935772,3.231267,,,,,,,,,,,,,, -40100,1.6122868,3.1989315,,,,,,,,,,,,,, -40200,1.537995,3.1707022,,,,,,,,,,,,,, -40300,1.7056445,3.2360315,,,,,,,,,,,,,, -40400,1.5736318,3.2293313,,,,,,,,,,,,,, -40471,,,0.7262037396430969,1.325469732284546,0.6515399813652039,1.6518381834030151,50000.0,0.5252000093460083,2.321791410446167,10000.0,13803.586438894272,14303.099082946776,13803.586438894272,497.2023296356201,0.8594932556152344,0.0 -40500,1.5502143,3.192442,,,,,,,,,,,,,, -40600,1.6543412,3.144175,,,,,,,,,,,,,, -40700,1.6540872,3.1875682,,,,,,,,,,,,,, -40800,1.574836,3.2562716,,,,,,,,,,,,,, -40900,1.4658161,3.1974702,,,,,,,,,,,,,, -41000,1.4802054,3.1527863,,,,,,,,,,,,,, -41100,1.4445461,3.2037427,,,,,,,,,,,,,, -41200,1.6051959,3.1292977,,,,,,,,,,,,,, -41300,2.2939656,3.214634,,,,,,,,,,,,,, -41400,1.5706314,3.1475115,,,,,,,,,,,,,, -41500,1.714084,3.2317472,,,,,,,,,,,,,, -41600,1.5608681,3.1673574,,,,,,,,,,,,,, -41700,1.4777788,3.128015,,,,,,,,,,,,,, -41800,1.5261964,3.1704626,,,,,,,,,,,,,, -41900,1.8470858,3.2079928,,,,,,,,,,,,,, -41972,,,0.7281568646430969,1.2736696004867554,0.6581000089645386,1.5901299715042114,50000.0,0.5282000303268433,2.2479236125946045,10000.0,14313.738857030869,14831.734774827955,14313.738857030869,515.5967583656311,0.8947463035583496,0.0 -42000,1.556733,3.1423419,,,,,,,,,,,,,, -42100,1.7072948,3.226346,,,,,,,,,,,,,, -42200,1.7581369,3.2200286,,,,,,,,,,,,,, -42300,1.6242781,3.1947455,,,,,,,,,,,,,, -42400,1.8747979,3.145209,,,,,,,,,,,,,, -42500,1.5478789,3.1939733,,,,,,,,,,,,,, -42600,1.7657969,3.1440372,,,,,,,,,,,,,, -42700,1.5421256,3.1995423,,,,,,,,,,,,,, -42800,1.5924094,3.1160204,,,,,,,,,,,,,, -42900,1.636934,3.143193,,,,,,,,,,,,,, -43000,1.5862763,3.1426187,,,,,,,,,,,,,, -43100,1.5482839,3.14273,,,,,,,,,,,,,, -43200,1.7785,3.1279223,,,,,,,,,,,,,, -43300,1.656356,3.161026,,,,,,,,,,,,,, -43400,1.7340786,3.259137,,,,,,,,,,,,,, -43473,,,0.7159597873687744,1.3737213611602783,0.6536200046539307,1.6578227281570437,50000.0,0.5232000350952148,2.333250999450684,10000.0,14823.836078882216,15359.73616051674,14823.836078882216,533.4117612838745,0.9289801120758056,0.0 -43500,1.585978,3.248937,,,,,,,,,,,,,, -43600,1.7585335,3.1713977,,,,,,,,,,,,,, -43700,1.5874991,3.153115,,,,,,,,,,,,,, -43800,1.7156444,3.1505697,,,,,,,,,,,,,, -43900,1.6450268,3.2356238,,,,,,,,,,,,,, -44000,1.6429545,3.1990602,,,,,,,,,,,,,, -44100,1.6023428,3.204117,,,,,,,,,,,,,, -44200,1.5068424,3.120942,,,,,,,,,,,,,, -44300,1.8875289,3.120522,,,,,,,,,,,,,, -44400,1.7036586,3.183532,,,,,,,,,,,,,, -44500,1.6871836,3.1618495,,,,,,,,,,,,,, -44600,1.6044072,3.1287553,,,,,,,,,,,,,, -44700,1.5665048,3.149992,,,,,,,,,,,,,, -44800,1.5310628,3.128574,,,,,,,,,,,,,, -44900,1.6873978,3.1973186,,,,,,,,,,,,,, -44974,,,0.7290138602256775,1.3344924449920654,0.6631999611854553,1.6209385395050049,50000.0,0.532800018787384,2.273008108139038,10000.0,15333.796644449234,15887.6466050148,15333.796644449234,551.2723331451416,0.964834690093994,0.0 -45000,1.8043516,3.1608279,,,,,,,,,,,,,, -45100,1.618753,3.264614,,,,,,,,,,,,,, -45200,1.5112519,3.2404015,,,,,,,,,,,,,, -45300,1.6579331,3.1898892,,,,,,,,,,,,,, -45400,1.7538261,3.1787534,,,,,,,,,,,,,, -45500,1.7020539,3.1609933,,,,,,,,,,,,,, -45600,1.6798159,3.1778102,,,,,,,,,,,,,, -45700,1.6056838,3.151739,,,,,,,,,,,,,, -45800,1.740066,3.2213268,,,,,,,,,,,,,, -45900,1.8397437,3.2349267,,,,,,,,,,,,,, -46000,1.7762958,3.1985934,,,,,,,,,,,,,, -46100,1.7609087,3.1450396,,,,,,,,,,,,,, -46200,1.6500325,3.20378,,,,,,,,,,,,,, -46300,1.6748072,3.1250095,,,,,,,,,,,,,, -46400,1.702653,3.1453192,,,,,,,,,,,,,, -46475,,,0.75882887840271,1.2321513891220093,0.6564799547195435,1.6605032682418823,50000.0,0.527999997138977,2.3292038440704346,10000.0,15843.76311159134,16415.138216257095,15843.76311159134,568.7053790092468,1.002387285232544,0.0 -46500,1.9285995,3.1460202,,,,,,,,,,,,,, -46600,1.5719335,3.1389923,,,,,,,,,,,,,, -46700,1.6132199,3.1940503,,,,,,,,,,,,,, -46800,1.6139542,3.1215057,,,,,,,,,,,,,, -46900,1.7982216,3.1408658,,,,,,,,,,,,,, -47000,1.7983563,3.1738012,,,,,,,,,,,,,, -47100,1.6326362,3.1440182,,,,,,,,,,,,,, -47200,1.7400621,3.1842933,,,,,,,,,,,,,, -47300,1.8370981,3.110148,,,,,,,,,,,,,, -47400,1.5742356,3.0784595,,,,,,,,,,,,,, -47500,1.9434958,3.167157,,,,,,,,,,,,,, -47600,2.0428321,3.1490166,,,,,,,,,,,,,, -47700,1.6416941,3.101628,,,,,,,,,,,,,, -47800,1.5830747,3.0969806,,,,,,,,,,,,,, -47900,1.805001,3.1750891,,,,,,,,,,,,,, -47976,,,0.7443000674247742,1.2563570737838743,0.6606599688529968,1.619337797164917,50000.0,0.5361000299453735,2.2724545001983643,10000.0,16353.741973161696,16942.905286073685,16353.741973161696,586.3989787101746,1.043410301208496,0.0 -48000,1.961592,3.1910458,,,,,,,,,,,,,, -48100,1.7565665,3.1467397,,,,,,,,,,,,,, -48200,1.7444999,3.115983,,,,,,,,,,,,,, -48300,2.0172663,3.2141006,,,,,,,,,,,,,, -48400,1.7154372,3.1604161,,,,,,,,,,,,,, -48500,1.7332146,3.122484,,,,,,,,,,,,,, -48600,1.7371622,3.2019796,,,,,,,,,,,,,, -48700,1.8193531,3.1539311,,,,,,,,,,,,,, -48800,1.5951085,3.1828628,,,,,,,,,,,,,, -48900,2.014455,3.1329076,,,,,,,,,,,,,, -49000,1.5968354,3.1141737,,,,,,,,,,,,,, -49100,1.8301549,3.1516204,,,,,,,,,,,,,, -49200,1.7179263,3.1247244,,,,,,,,,,,,,, -49300,1.7723918,3.1000488,,,,,,,,,,,,,, -49400,1.6105756,3.0889359,,,,,,,,,,,,,, -49477,,,0.73539137840271,1.2704975605010986,0.6642000079154968,1.5893797874450684,50000.0,0.5272000432014465,2.2808008193969727,10000.0,16863.93771147728,17471.012765169144,16863.93771147728,604.2181794643402,1.0814259052276611,0.0 -49500,1.7334933,3.0808284,,,,,,,,,,,,,, -49600,1.8205773,3.165821,,,,,,,,,,,,,, -49700,1.6921741,3.1288538,,,,,,,,,,,,,, -49800,1.7683772,3.1286712,,,,,,,,,,,,,, -49900,1.7257283,3.2016022,,,,,,,,,,,,,, -50000,1.7067395,3.2127378,,,,,,,,,,,,,, -50100,1.7719823,3.150943,,,,,,,,,,,,,, -50200,1.8344059,3.1638484,,,,,,,,,,,,,, -50300,1.7465215,3.088173,,,,,,,,,,,,,, -50400,1.8610983,3.0813823,,,,,,,,,,,,,, -50500,1.7313303,3.1744757,,,,,,,,,,,,,, -50600,1.7754036,3.171228,,,,,,,,,,,,,, -50700,1.8238181,3.1206899,,,,,,,,,,,,,, -50800,2.121404,3.2541108,,,,,,,,,,,,,, -50900,1.8728685,3.1720102,,,,,,,,,,,,,, -50978,,,0.7208425998687744,1.3237276077270508,0.6489999890327454,1.642844319343567,50000.0,0.5261000394821167,2.2644474506378174,10000.0,17373.93242096901,17998.925570249557,17373.93242096901,622.0479846000671,1.1158974170684814,0.0 -51000,1.7893653,3.1581237,,,,,,,,,,,,,, -51100,1.7119757,3.195416,,,,,,,,,,,,,, -51200,1.8089428,3.1280367,,,,,,,,,,,,,, -51300,1.9566407,3.2087808,,,,,,,,,,,,,, -51400,1.5969335,3.0966547,,,,,,,,,,,,,, -51500,1.8107724,3.1244729,,,,,,,,,,,,,, -51600,1.9835335,3.0992675,,,,,,,,,,,,,, -51700,1.8579524,3.158266,,,,,,,,,,,,,, -51800,1.7603538,3.1724753,,,,,,,,,,,,,, -51900,1.9901268,3.222651,,,,,,,,,,,,,, -52000,1.8039469,3.0928109,,,,,,,,,,,,,, -52100,1.706116,3.1364856,,,,,,,,,,,,,, -52200,1.7303275,3.1192527,,,,,,,,,,,,,, -52300,1.7403431,3.134058,,,,,,,,,,,,,, -52400,1.8568028,3.0882146,,,,,,,,,,,,,, -52479,,,0.7428650856018066,1.2034056186676023,0.6698399782180786,1.527472972869873,50000.0,0.5471000075340271,2.1574642658233643,10000.0,17884.122307777405,18527.0086209774,17884.122307777405,639.8506090641022,1.1524429321289062,0.0 -52500,1.8357184,3.1300259,,,,,,,,,,,,,, -52600,1.7885655,3.151059,,,,,,,,,,,,,, -52700,1.7626694,3.1438932,,,,,,,,,,,,,, -52800,1.8021567,3.1223898,,,,,,,,,,,,,, -52900,1.8763527,3.222508,,,,,,,,,,,,,, -53000,1.7558256,3.1805053,,,,,,,,,,,,,, -53100,1.6957736,3.148619,,,,,,,,,,,,,, -53200,1.7336899,3.1816595,,,,,,,,,,,,,, -53300,2.2596598,3.1423116,,,,,,,,,,,,,, -53400,1.776369,3.157034,,,,,,,,,,,,,, -53500,1.8317826,3.1179368,,,,,,,,,,,,,, -53600,1.8847303,3.1716921,,,,,,,,,,,,,, -53700,1.8085729,3.174392,,,,,,,,,,,,,, -53800,1.901089,3.1280274,,,,,,,,,,,,,, -53900,1.8759153,3.1289465,,,,,,,,,,,,,, -53981,,,0.7333585619926453,1.3001765012741089,0.6660999655723572,1.604732871055603,50000.0,0.5386000275611877,2.2656149864196777,10000.0,18394.34732890129,19055.01717185974,18394.34732890129,657.5397083759308,1.1924428939819336,0.0 -54000,1.8711098,3.100419,,,,,,,,,,,,,, -54100,1.8533983,3.2061443,,,,,,,,,,,,,, -54200,1.8635279,3.1070967,,,,,,,,,,,,,, -54300,1.8938196,3.1132932,,,,,,,,,,,,,, -54400,1.7796828,3.1828156,,,,,,,,,,,,,, -54500,1.8806068,3.0913675,,,,,,,,,,,,,, -54600,1.7647812,3.1191728,,,,,,,,,,,,,, -54700,1.7050436,3.1530757,,,,,,,,,,,,,, -54800,1.8589807,3.1233416,,,,,,,,,,,,,, -54900,1.8674619,3.1910195,,,,,,,,,,,,,, -55000,1.7238609,3.1099672,,,,,,,,,,,,,, -55100,1.9098321,3.1461906,,,,,,,,,,,,,, -55200,1.8545394,3.1150174,,,,,,,,,,,,,, -55300,1.9622928,3.0854292,,,,,,,,,,,,,, -55400,1.9093722,3.0819738,,,,,,,,,,,,,, -55482,,,0.7835220098495483,1.07129967212677,0.6714000105857849,1.5408074855804443,50000.0,0.5451000332832336,2.194451332092285,10000.0,18904.38467264176,19582.730887413025,18904.38467264176,675.1203720569611,1.2325246334075928,0.0 -55500,1.8245813,3.1477537,,,,,,,,,,,,,, -55600,1.9645563,3.1386333,,,,,,,,,,,,,, -55700,2.0501845,3.1708982,,,,,,,,,,,,,, -55800,1.8358238,3.111088,,,,,,,,,,,,,, -55900,1.9489642,3.160126,,,,,,,,,,,,,, -56000,1.9102681,3.1287575,,,,,,,,,,,,,, -56100,1.8487349,3.1831784,,,,,,,,,,,,,, -56200,1.9024001,3.1757615,,,,,,,,,,,,,, -56300,1.8546206,3.1124797,,,,,,,,,,,,,, -56400,1.8467325,3.1166575,,,,,,,,,,,,,, -56500,2.2512352,3.1686819,,,,,,,,,,,,,, -56600,2.0918136,3.0951955,,,,,,,,,,,,,, -56700,2.0049329,3.0879626,,,,,,,,,,,,,, -56800,1.973061,3.177382,,,,,,,,,,,,,, -56900,1.9175912,3.2090185,,,,,,,,,,,,,, -56983,,,0.753926157951355,1.2177770137786863,0.6661799550056458,1.6002877950668335,50000.0,0.5392000079154968,2.25341796875,10000.0,19414.40906858444,20110.64827489853,19414.40906858444,692.9199199676514,1.2718331813812256,0.0 -57000,1.959891,3.190118,,,,,,,,,,,,,, -57100,1.8629675,3.1172214,,,,,,,,,,,,,, -57200,1.7641686,3.0432944,,,,,,,,,,,,,, -57300,2.4272983,3.1529775,,,,,,,,,,,,,, -57400,2.0283794,3.197332,,,,,,,,,,,,,, -57500,1.9538269,3.1822686,,,,,,,,,,,,,, -57600,1.9621413,3.1144412,,,,,,,,,,,,,, -57700,2.1603804,3.1269069,,,,,,,,,,,,,, -57800,1.9378701,3.0895803,,,,,,,,,,,,,, -57900,1.8894614,3.1490073,,,,,,,,,,,,,, -58000,1.9142468,3.142649,,,,,,,,,,,,,, -58100,2.1052983,3.1934354,,,,,,,,,,,,,, -58200,1.8690649,3.2061982,,,,,,,,,,,,,, -58300,1.9238907,3.1263316,,,,,,,,,,,,,, -58400,1.9094708,3.0902972,,,,,,,,,,,,,, -58484,,,0.7466916441917419,1.2169824838638306,0.6627799868583679,1.5723552703857422,50000.0,0.5412000417709351,2.232109785079956,10000.0,19924.53598690033,20638.68911504745,19924.53598690033,710.7420144081116,1.3092410564422607,0.0 -58500,1.8610135,3.1018097,,,,,,,,,,,,,, -58600,1.7791474,3.0792577,,,,,,,,,,,,,, -58700,1.7305658,3.0885398,,,,,,,,,,,,,, -58800,1.9158826,3.0709784,,,,,,,,,,,,,, -58900,2.0229974,3.1338348,,,,,,,,,,,,,, -59000,1.9435352,3.1089578,,,,,,,,,,,,,, -59100,1.9345461,3.0736237,,,,,,,,,,,,,, -59200,1.9353551,3.199184,,,,,,,,,,,,,, -59300,1.773202,3.0876462,,,,,,,,,,,,,, -59400,2.0494256,3.1734526,,,,,,,,,,,,,, -59500,2.0308099,3.1646268,,,,,,,,,,,,,, -59600,1.871775,3.1218283,,,,,,,,,,,,,, -59700,2.1565151,3.070734,,,,,,,,,,,,,, -59800,2.1446579,3.0785565,,,,,,,,,,,,,, -59900,1.9681709,3.1236725,,,,,,,,,,,,,, -59985,,,0.7434031963348389,1.2602286338806152,0.6678599715232849,1.5897306203842163,50000.0,0.5350000262260437,2.256978988647461,10000.0,20434.61958694458,21166.6318423748,20434.61958694458,728.5081448554993,1.348177433013916,0.0 -60000,2.052761,3.0656297,,,,,,,,,,,,,, -60100,1.9730297,3.179383,,,,,,,,,,,,,, -60200,2.0795755,3.090517,,,,,,,,,,,,,, -60300,1.9434152,3.1199055,,,,,,,,,,,,,, -60400,1.9630325,3.1182733,,,,,,,,,,,,,, -60500,1.8488598,3.0228562,,,,,,,,,,,,,, -60600,1.9987906,3.1087477,,,,,,,,,,,,,, -60700,1.9804434,3.0565095,,,,,,,,,,,,,, -60800,1.855811,3.1312983,,,,,,,,,,,,,, -60900,1.9061575,3.1494672,,,,,,,,,,,,,, -61000,2.004561,3.0959797,,,,,,,,,,,,,, -61100,1.9180396,3.0445406,,,,,,,,,,,,,, -61200,2.281776,3.1494093,,,,,,,,,,,,,, -61300,1.8788061,3.097375,,,,,,,,,,,,,, -61400,2.2056727,3.1174183,,,,,,,,,,,,,, -61486,,,0.7493622303009033,1.2253031730651855,0.6746999621391296,1.5467970371246338,50000.0,0.5541000366210938,2.180800676345825,10000.0,20944.60235857964,21694.27863621712,20944.60235857964,746.0784072875977,1.3884241580963137,0.0 -61500,1.9713048,3.1606584,,,,,,,,,,,,,, -61600,2.0918808,3.1258793,,,,,,,,,,,,,, -61700,1.8932583,3.0783212,,,,,,,,,,,,,, -61800,1.9225957,3.093715,,,,,,,,,,,,,, -61900,1.94282,3.065205,,,,,,,,,,,,,, -62000,1.9252963,3.1292617,,,,,,,,,,,,,, -62100,2.064412,3.1293201,,,,,,,,,,,,,, -62200,2.1295507,3.1271527,,,,,,,,,,,,,, -62300,1.9254898,3.0880127,,,,,,,,,,,,,, -62400,2.1771703,3.1261215,,,,,,,,,,,,,, -62500,2.081229,3.088168,,,,,,,,,,,,,, -62600,2.0099766,3.0190349,,,,,,,,,,,,,, -62700,2.1696508,3.260363,,,,,,,,,,,,,, -62800,2.0067184,3.0227978,,,,,,,,,,,,,, -62900,1.9904816,3.0504332,,,,,,,,,,,,,, -62987,,,0.7569953799247742,1.1516437530517578,0.6823599934577942,1.4813121557235718,50000.0,0.5515000224113464,2.1404056549072266,10000.0,21454.522994041443,22222.26675367356,21454.522994041443,764.052640914917,1.428706407546997,0.0 -63000,1.9537125,3.1595035,,,,,,,,,,,,,, -63100,2.0107257,3.0985324,,,,,,,,,,,,,, -63200,1.9637545,3.09359,,,,,,,,,,,,,, -63300,1.9754568,3.0517733,,,,,,,,,,,,,, -63400,1.963176,3.12909,,,,,,,,,,,,,, -63500,2.0765498,3.0505245,,,,,,,,,,,,,, -63600,1.81314,2.990022,,,,,,,,,,,,,, -63700,2.0292473,3.0568116,,,,,,,,,,,,,, -63800,2.0213397,3.142583,,,,,,,,,,,,,, -63900,1.9876915,3.0457437,,,,,,,,,,,,,, -64000,1.934588,3.0741313,,,,,,,,,,,,,, -64100,2.2691011,3.153822,,,,,,,,,,,,,, -64200,2.0121481,3.1065624,,,,,,,,,,,,,, -64300,2.0735862,3.0548723,,,,,,,,,,,,,, -64400,1.9447329,3.0726013,,,,,,,,,,,,,, -64488,,,0.7630141973495483,1.167246699333191,0.6773399710655212,1.535491704940796,50000.0,0.5467000007629395,2.184284448623657,10000.0,21964.64013814926,22750.43190002441,21964.64013814926,782.0063388347626,1.467320203781128,0.0 -64500,2.1085064,3.202826,,,,,,,,,,,,,, -64600,2.0629833,3.010476,,,,,,,,,,,,,, -64700,2.0652835,3.108691,,,,,,,,,,,,,, -64800,2.2344768,3.0572627,,,,,,,,,,,,,, -64900,1.9607049,3.1118302,,,,,,,,,,,,,, -65000,1.955187,3.0730166,,,,,,,,,,,,,, -65100,2.1168616,3.1539748,,,,,,,,,,,,,, -65200,1.9581631,3.0554986,,,,,,,,,,,,,, -65300,2.0156665,3.1620798,,,,,,,,,,,,,, -65400,2.033752,3.0310802,,,,,,,,,,,,,, -65500,1.9298061,3.0145364,,,,,,,,,,,,,, -65600,2.0102742,3.148679,,,,,,,,,,,,,, -65700,2.0795732,3.147431,,,,,,,,,,,,,, -65800,1.9223257,3.0552204,,,,,,,,,,,,,, -65900,2.0778327,3.0663981,,,,,,,,,,,,,, -65928,,,0.7694514989852905,1.1408820152282717,0.6707199811935425,1.5621381998062134,50000.0,0.5439000129699707,2.232006549835205,10000.0,22474.734596014023,23278.254689455032,22474.734596014023,799.6415371894836,1.5086112022399902,0.0 -66000,2.0104463,3.0681527,,,,,,,,,,,,,, -66100,1.9638646,3.087806,,,,,,,,,,,,,, -66200,1.9859829,3.0933738,,,,,,,,,,,,,, -66300,2.031523,3.0963116,,,,,,,,,,,,,, -66400,2.080548,3.1639688,,,,,,,,,,,,,, -66500,2.0245361,3.0669942,,,,,,,,,,,,,, -66600,1.8856167,3.08341,,,,,,,,,,,,,, -66700,2.0193307,3.031529,,,,,,,,,,,,,, -66800,2.0605037,3.0324993,,,,,,,,,,,,,, -66900,2.1449366,3.1054773,,,,,,,,,,,,,, -67000,1.8957679,3.025875,,,,,,,,,,,,,, -67100,2.0660923,3.097138,,,,,,,,,,,,,, -67200,2.2250526,3.0794613,,,,,,,,,,,,,, -67300,2.0339465,3.1112208,,,,,,,,,,,,,, -67400,1.9210614,3.0236797,,,,,,,,,,,,,, -67429,,,0.7701291441917419,1.1213020086288452,0.6814000010490417,1.508509635925293,50000.0,0.551800012588501,2.163882255554199,10000.0,22984.724188804623,23805.86190366745,22984.724188804623,817.1650323867798,1.54809308052063,0.0 -67500,1.9500461,3.047475,,,,,,,,,,,,,, -67600,2.0474243,3.098531,,,,,,,,,,,,,, -67700,2.1221073,3.070971,,,,,,,,,,,,,, -67800,2.05075,3.1493664,,,,,,,,,,,,,, -67900,2.0760753,3.1249785,,,,,,,,,,,,,, -68000,2.0684335,3.0391984,,,,,,,,,,,,,, -68100,2.0973403,3.1186523,,,,,,,,,,,,,, -68200,2.1065507,2.9825428,,,,,,,,,,,,,, -68300,2.0690138,3.0749154,,,,,,,,,,,,,, -68400,2.019324,3.0952816,,,,,,,,,,,,,, -68500,2.0830307,3.0467443,,,,,,,,,,,,,, -68600,1.9699366,3.0403042,,,,,,,,,,,,,, -68700,2.196928,3.0755904,,,,,,,,,,,,,, -68800,2.0770838,3.0375452,,,,,,,,,,,,,, -68900,2.1208208,3.0870934,,,,,,,,,,,,,, -68931,,,0.7631736397743225,1.1109957695007324,0.6834200024604797,1.4736833572387695,50000.0,0.5585000514984131,2.107130527496338,10000.0,23494.799648284912,24333.555045366287,23494.799648284912,834.686586856842,1.5904114246368408,0.0 -69000,2.233861,3.1152465,,,,,,,,,,,,,, -69100,2.1049857,3.0310738,,,,,,,,,,,,,, -69200,2.183998,2.9798641,,,,,,,,,,,,,, -69300,2.0450943,3.0823514,,,,,,,,,,,,,, -69400,2.0773008,3.101995,,,,,,,,,,,,,, -69500,2.2029595,3.0966015,,,,,,,,,,,,,, -69600,2.2726262,3.0791905,,,,,,,,,,,,,, -69700,2.0435238,3.0333192,,,,,,,,,,,,,, -69800,2.2002628,3.0825303,,,,,,,,,,,,,, -69900,2.0568335,3.0413241,,,,,,,,,,,,,, -70000,1.9988021,3.0742347,,,,,,,,,,,,,, -70100,1.9327774,3.0745294,,,,,,,,,,,,,, -70200,2.0492344,2.971654,,,,,,,,,,,,,, -70300,2.1006758,3.0520854,,,,,,,,,,,,,, -70400,2.2194178,3.1553192,,,,,,,,,,,,,, -70432,,,0.7593470811843872,1.125109314918518,0.6784999966621399,1.4743475914001465,50000.0,0.554900050163269,2.130442380905152,10000.0,24004.760673046112,24861.097346305847,24004.760673046112,852.1690158843994,1.633352518081665,0.0 -70500,2.145851,3.0380607,,,,,,,,,,,,,, -70600,2.1071615,3.0353956,,,,,,,,,,,,,, -70700,2.1532168,3.1309147,,,,,,,,,,,,,, -70800,2.1549838,3.1069229,,,,,,,,,,,,,, -70900,2.0808706,3.15552,,,,,,,,,,,,,, -71000,2.1416428,3.036115,,,,,,,,,,,,,, -71100,1.9338206,2.9987578,,,,,,,,,,,,,, -71200,2.077679,3.0810926,,,,,,,,,,,,,, -71300,1.9077917,3.086592,,,,,,,,,,,,,, -71400,2.1868234,3.0471168,,,,,,,,,,,,,, -71500,2.110025,3.112339,,,,,,,,,,,,,, -71600,2.137073,3.128062,,,,,,,,,,,,,, -71700,1.9922364,3.0808911,,,,,,,,,,,,,, -71800,2.0198898,2.9978602,,,,,,,,,,,,,, -71900,2.530313,3.1244717,,,,,,,,,,,,,, -71934,,,0.7483657598495483,1.220489263534546,0.6733199954032898,1.5536441802978516,50000.0,0.5496000051498413,2.1928555965423584,10000.0,24514.92568397522,25388.99902510643,24514.92568397522,869.810240983963,1.6746103763580322,0.0 -72000,2.1752048,3.035628,,,,,,,,,,,,,, -72100,2.2709289,3.1116095,,,,,,,,,,,,,, -72200,2.1228204,3.0699084,,,,,,,,,,,,,, -72300,2.125176,3.0861557,,,,,,,,,,,,,, -72400,2.1172209,3.1245801,,,,,,,,,,,,,, -72500,2.1642654,3.0789757,,,,,,,,,,,,,, -72600,2.3025746,2.9833586,,,,,,,,,,,,,, -72700,2.1803162,2.9665937,,,,,,,,,,,,,, -72800,2.1256418,3.043464,,,,,,,,,,,,,, -72900,2.014201,3.0369627,,,,,,,,,,,,,, -73000,2.2249184,3.0827558,,,,,,,,,,,,,, -73100,2.1190367,3.0969305,,,,,,,,,,,,,, -73200,2.0336268,3.063838,,,,,,,,,,,,,, -73300,2.1862788,2.9789777,,,,,,,,,,,,,, -73400,2.0921304,3.0192416,,,,,,,,,,,,,, -73436,,,0.7437220811843872,1.2418133020401,0.6713399887084961,1.5635813474655151,50000.0,0.5415000319480896,2.236577272415161,10000.0,25025.14162898064,25917.068249225616,25025.14162898064,887.5698609352112,1.713430643081665,0.0 -73500,2.0879939,3.0541291,,,,,,,,,,,,,, -73600,2.1927118,2.9683952,,,,,,,,,,,,,, -73700,2.0999455,2.9946609,,,,,,,,,,,,,, -73800,2.2253754,3.025817,,,,,,,,,,,,,, -73900,2.1375153,3.0707848,,,,,,,,,,,,,, -74000,1.9826627,2.991334,,,,,,,,,,,,,, -74100,2.0361614,2.9952826,,,,,,,,,,,,,, -74200,2.2936974,3.1267717,,,,,,,,,,,,,, -74300,2.0302942,3.080624,,,,,,,,,,,,,, -74400,2.4101806,3.0236166,,,,,,,,,,,,,, -74500,2.0990252,3.0416493,,,,,,,,,,,,,, -74600,2.111329,3.1590831,,,,,,,,,,,,,, -74700,2.052691,3.0746253,,,,,,,,,,,,,, -74800,2.2326643,3.1275907,,,,,,,,,,,,,, -74900,2.1800323,3.0372849,,,,,,,,,,,,,, -74938,,,0.7940250039100647,1.0290786027908323,0.6819599866867065,1.4988394975662231,50000.0,0.5593000054359436,2.1356232166290283,10000.0,25535.273594856262,26445.11689734459,25535.273594856262,905.3917419910432,1.753354549407959,0.0 -75000,2.1777742,3.0307007,,,,,,,,,,,,,, -75100,2.3131878,3.0749903,,,,,,,,,,,,,, -75200,2.180227,3.0709558,,,,,,,,,,,,,, -75300,2.1073875,3.0649385,,,,,,,,,,,,,, -75400,2.3442438,3.0571415,,,,,,,,,,,,,, -75500,2.1663876,3.0598614,,,,,,,,,,,,,, -75600,2.1561992,3.0386498,,,,,,,,,,,,,, -75700,2.3522744,3.0620184,,,,,,,,,,,,,, -75800,2.2159433,3.0971825,,,,,,,,,,,,,, -75900,2.0708256,3.110198,,,,,,,,,,,,,, -76000,2.0641682,3.0467973,,,,,,,,,,,,,, -76100,2.2119544,3.0685294,,,,,,,,,,,,,, -76200,2.2132022,3.024975,,,,,,,,,,,,,, -76300,2.252534,3.0022585,,,,,,,,,,,,,, -76400,2.1302252,3.0652473,,,,,,,,,,,,,, -76439,,,0.7794363498687744,1.0432428121566772,0.6851999759674072,1.4631946086883545,50000.0,0.5587000250816345,2.140300989151001,10000.0,26045.1993291378,26973.03084754944,26045.1993291378,923.286732673645,1.793353796005249,0.0 -76500,2.1869445,3.0543275,,,,,,,,,,,,,, -76600,2.0595403,3.0632834,,,,,,,,,,,,,, -76700,2.2041137,3.0575423,,,,,,,,,,,,,, -76800,2.1558614,3.020774,,,,,,,,,,,,,, -76900,2.0750723,3.0093338,,,,,,,,,,,,,, -77000,2.108698,3.1054003,,,,,,,,,,,,,, -77100,2.0311913,3.0165927,,,,,,,,,,,,,, -77200,2.1979494,3.032883,,,,,,,,,,,,,, -77300,2.25185,3.0686712,,,,,,,,,,,,,, -77400,2.1402583,3.0795088,,,,,,,,,,,,,, -77500,2.2432616,3.0108294,,,,,,,,,,,,,, -77600,2.1966572,3.0452237,,,,,,,,,,,,,, -77700,2.2513318,2.9919791,,,,,,,,,,,,,, -77800,2.146611,3.0784447,,,,,,,,,,,,,, -77900,2.2634394,3.082751,,,,,,,,,,,,,, -77941,,,0.7739556431770325,1.1201239824295044,0.686199963092804,1.504704236984253,50000.0,0.5613000392913818,2.144953489303589,10000.0,26555.232617139816,27500.738260746,26555.232617139816,940.8660507202148,1.8331985473632808,0.0 -78000,2.1218116,3.0016632,,,,,,,,,,,,,, -78100,2.248009,2.9619265,,,,,,,,,,,,,, -78200,2.1760058,3.082573,,,,,,,,,,,,,, -78300,2.3327737,3.0517042,,,,,,,,,,,,,, -78400,2.2449064,3.0067945,,,,,,,,,,,,,, -78500,2.381479,3.1090555,,,,,,,,,,,,,, -78600,2.2070572,3.0236163,,,,,,,,,,,,,, -78700,2.4380176,2.9884837,,,,,,,,,,,,,, -78800,2.2488844,3.0391526,,,,,,,,,,,,,, -78900,2.1899638,3.0092037,,,,,,,,,,,,,, -79000,2.1825368,3.072902,,,,,,,,,,,,,, -79100,2.1147048,3.0565896,,,,,,,,,,,,,, -79200,2.2408414,3.0375037,,,,,,,,,,,,,, -79300,2.0519743,2.9851997,,,,,,,,,,,,,, -79400,2.1977973,3.024421,,,,,,,,,,,,,, -79442,,,0.7662029266357422,1.158005952835083,0.6850000023841858,1.5323312282562256,50000.0,0.5567000508308411,2.17029881477356,10000.0,27065.26093101501,28028.20407128334,27065.26093101501,958.206443309784,1.8757927417755127,0.0 -79500,2.2337263,3.051776,,,,,,,,,,,,,, -79600,2.4440947,3.0594964,,,,,,,,,,,,,, -79700,2.372074,3.0589113,,,,,,,,,,,,,, -79800,2.2653425,2.9845295,,,,,,,,,,,,,, -79900,2.3718443,2.9361987,,,,,,,,,,,,,, -80000,2.1614308,2.9372358,,,,,,,,,,,,,, -80100,2.3010468,3.049033,,,,,,,,,,,,,, -80200,2.2094302,3.0344894,,,,,,,,,,,,,, -80300,2.1741788,3.0252645,,,,,,,,,,,,,, -80400,2.0786433,2.9893935,,,,,,,,,,,,,, -80500,2.4183738,3.0622735,,,,,,,,,,,,,, -80600,2.100649,3.0783033,,,,,,,,,,,,,, -80700,2.200427,3.0408022,,,,,,,,,,,,,, -80800,2.1692188,3.037363,,,,,,,,,,,,,, -80900,2.2080104,2.952355,,,,,,,,,,,,,, -80944,,,0.7602837681770325,1.1421968936920166,0.6798799633979797,1.5045496225357056,50000.0,0.5541000366210938,2.1207895278930664,10000.0,27575.289494991302,28557.046733379364,27575.289494991302,976.9262478351592,1.91663146018982,0.0 -81000,2.3200676,3.0152383,,,,,,,,,,,,,, -81100,2.2988892,2.9576266,,,,,,,,,,,,,, -81200,2.0899348,2.9264235,,,,,,,,,,,,,, -81300,2.3425758,2.9833374,,,,,,,,,,,,,, -81400,2.1362653,2.9928932,,,,,,,,,,,,,, -81500,2.3135664,3.004956,,,,,,,,,,,,,, -81600,2.3160772,3.0004644,,,,,,,,,,,,,, -81700,2.2188172,3.0407434,,,,,,,,,,,,,, -81800,2.2153327,2.9782515,,,,,,,,,,,,,, -81900,2.1740377,3.0422378,,,,,,,,,,,,,, -82000,2.1993957,2.980946,,,,,,,,,,,,,, -82100,2.2054849,2.9794014,,,,,,,,,,,,,, -82200,2.2822516,3.0506039,,,,,,,,,,,,,, -82300,2.3667035,3.0298212,,,,,,,,,,,,,, -82400,2.3703518,3.019821,,,,,,,,,,,,,, -82445,,,0.7724011540412903,1.0933239459991455,0.6927399635314941,1.4421255588531494,50000.0,0.5697000026702881,2.080371379852295,10000.0,28085.512244701385,29085.32898545265,28085.512244701385,994.8886668682098,1.9587581157684328,0.0 -82500,2.1932805,3.0122654,,,,,,,,,,,,,, -82600,2.2178562,2.9770472,,,,,,,,,,,,,, -82700,2.3273733,3.0676498,,,,,,,,,,,,,, -82800,2.3529456,3.0112152,,,,,,,,,,,,,, -82900,2.1145992,3.0244603,,,,,,,,,,,,,, -83000,2.1569536,2.9844577,,,,,,,,,,,,,, -83100,2.158108,2.972303,,,,,,,,,,,,,, -83200,2.3727837,3.0093083,,,,,,,,,,,,,, -83300,2.2176628,2.9959247,,,,,,,,,,,,,, -83400,2.327072,3.0431418,,,,,,,,,,,,,, -83500,2.2396867,3.0468142,,,,,,,,,,,,,, -83600,2.2885253,3.0551345,,,,,,,,,,,,,, -83700,2.2726479,3.0017753,,,,,,,,,,,,,, -83800,2.369791,3.0057328,,,,,,,,,,,,,, -83900,2.2353582,2.9492626,,,,,,,,,,,,,, -83948,,,0.8116828799247742,0.9399776458740234,0.6905199885368347,1.4603952169418335,50000.0,0.5618000030517578,2.10579776763916,10000.0,28595.73851132393,29613.13614249229,28595.73851132393,1012.3704881668092,2.003067970275879,0.0 -84000,2.3192785,3.000928,,,,,,,,,,,,,, -84100,2.5275214,3.047193,,,,,,,,,,,,,, -84200,2.2293265,2.9285173,,,,,,,,,,,,,, -84300,2.4320722,3.0303288,,,,,,,,,,,,,, -84400,2.285661,2.9702845,,,,,,,,,,,,,, -84500,2.151889,2.9943528,,,,,,,,,,,,,, -84600,2.2838194,3.0580442,,,,,,,,,,,,,, -84700,2.3623645,3.0480416,,,,,,,,,,,,,, -84800,2.3680022,2.9882421,,,,,,,,,,,,,, -84900,2.3598707,2.947548,,,,,,,,,,,,,, -85000,2.280392,3.0296044,,,,,,,,,,,,,, -85100,2.3064845,2.996368,,,,,,,,,,,,,, -85200,2.2228203,2.9804933,,,,,,,,,,,,,, -85300,2.27038,3.0201242,,,,,,,,,,,,,, -85400,2.2458174,2.9698875,,,,,,,,,,,,,, -85450,,,0.7909757494926453,1.06205952167511,0.6925599575042725,1.4879108667373655,50000.0,0.567300021648407,2.1164779663085938,10000.0,29105.90775370598,30140.93297529221,29105.90775370598,1029.901986837387,2.045266628265381,0.0 -85500,2.3388553,3.0361245,,,,,,,,,,,,,, -85600,2.1687963,2.957604,,,,,,,,,,,,,, -85700,2.3685153,3.0383613,,,,,,,,,,,,,, -85800,2.3178513,3.003937,,,,,,,,,,,,,, -85900,2.3000445,3.009237,,,,,,,,,,,,,, -86000,2.237173,3.009811,,,,,,,,,,,,,, -86100,2.2326164,2.9528832,,,,,,,,,,,,,, -86200,2.26,3.0109954,,,,,,,,,,,,,, -86300,2.3949876,3.007852,,,,,,,,,,,,,, -86400,2.4254098,3.0437212,,,,,,,,,,,,,, -86500,2.2745445,3.0282347,,,,,,,,,,,,,, -86600,2.3701627,3.0381997,,,,,,,,,,,,,, -86700,2.2751791,3.0933595,,,,,,,,,,,,,, -86800,2.4576633,3.0501115,,,,,,,,,,,,,, -86900,2.2858896,2.9549248,,,,,,,,,,,,,, -86952,,,0.7831034660339355,1.0348970890045166,0.6964600086212158,1.4234379529953003,50000.0,0.5677000284194946,2.063734769821167,10000.0,29615.94760608673,30668.66878247261,29615.94760608673,1047.5002024173737,2.0875396728515625,0.0 -87000,2.2196195,2.9868007,,,,,,,,,,,,,, -87100,2.369766,3.0270412,,,,,,,,,,,,,, -87200,2.2116182,2.9973812,,,,,,,,,,,,,, -87300,2.4204278,2.9591532,,,,,,,,,,,,,, -87400,2.3570952,3.0302196,,,,,,,,,,,,,, -87500,2.3033772,3.0208588,,,,,,,,,,,,,, -87600,2.2976978,3.0191388,,,,,,,,,,,,,, -87700,2.3643963,2.993976,,,,,,,,,,,,,, -87800,2.0730505,2.9960182,,,,,,,,,,,,,, -87900,2.185199,3.017227,,,,,,,,,,,,,, -88000,2.2520611,3.0369635,,,,,,,,,,,,,, -88100,2.3314033,3.0059137,,,,,,,,,,,,,, -88200,2.3567488,3.0738988,,,,,,,,,,,,,, -88300,2.3865054,2.9418414,,,,,,,,,,,,,, -88400,2.517759,3.0454,,,,,,,,,,,,,, -88453,,,0.7817681431770325,1.088742971420288,0.6913599967956543,1.4801713228225708,50000.0,0.5667000412940979,2.1170437335968018,10000.0,30125.87562441826,31196.319259643555,30125.87562441826,1065.1176307201383,2.137878894805908,0.0 -88500,2.4759264,3.050342,,,,,,,,,,,,,, -88600,2.2219067,2.9641442,,,,,,,,,,,,,, -88700,2.4633543,2.9512093,,,,,,,,,,,,,, -88800,2.4006078,3.0433822,,,,,,,,,,,,,, -88900,2.4160533,2.987553,,,,,,,,,,,,,, -89000,2.2876756,2.9818072,,,,,,,,,,,,,, -89100,2.5074544,2.9525506,,,,,,,,,,,,,, -89200,2.2775204,3.0049539,,,,,,,,,,,,,, -89300,2.3673615,2.9687128,,,,,,,,,,,,,, -89400,2.3006217,3.0219538,,,,,,,,,,,,,, -89500,2.5770097,3.0064466,,,,,,,,,,,,,, -89600,2.4437802,3.0764823,,,,,,,,,,,,,, -89700,2.2305303,2.933984,,,,,,,,,,,,,, -89800,2.5404887,3.03209,,,,,,,,,,,,,, -89900,2.3634114,3.037964,,,,,,,,,,,,,, -89955,,,0.7678770422935486,1.122750759124756,0.6862999796867371,1.481406569480896,50000.0,0.5601000189781189,2.1493637561798096,10000.0,30635.928003787994,31724.13882732392,30635.928003787994,1082.7853560447693,2.1824607849121094,0.0 -90000,2.3512084,3.0363312,,,,,,,,,,,,,, -90100,2.3995032,3.0184422,,,,,,,,,,,,,, -90200,2.35665,2.9988115,,,,,,,,,,,,,, -90300,2.3804016,3.0184975,,,,,,,,,,,,,, -90400,2.1411674,2.9538822,,,,,,,,,,,,,, -90500,2.357974,2.990461,,,,,,,,,,,,,, -90600,2.3820438,3.0100744,,,,,,,,,,,,,, -90700,2.2481976,3.0211716,,,,,,,,,,,,,, -90800,2.341164,2.9241915,,,,,,,,,,,,,, -90900,2.359691,2.9605405,,,,,,,,,,,,,, -91000,2.4494433,2.9358537,,,,,,,,,,,,,, -91100,2.2070405,2.9768302,,,,,,,,,,,,,, -91200,2.439059,2.967353,,,,,,,,,,,,,, -91300,2.353785,2.99435,,,,,,,,,,,,,, -91400,2.382357,2.940793,,,,,,,,,,,,,, -91457,,,0.7881656289100647,1.0311802625656128,0.7014600038528442,1.415004014968872,50000.0,0.5726000070571899,2.05700421333313,10000.0,31146.005562067032,32251.952335357662,31146.005562067032,1100.424178123474,2.2250864505767822,0.0 -91500,2.4425635,2.9483688,,,,,,,,,,,,,, -91600,2.2430124,3.028259,,,,,,,,,,,,,, -91700,2.4566202,3.007529,,,,,,,,,,,,,, -91800,2.4683168,3.02287,,,,,,,,,,,,,, -91900,2.3384135,2.9949143,,,,,,,,,,,,,, -92000,2.5308962,2.988325,,,,,,,,,,,,,, -92100,2.5155964,2.993234,,,,,,,,,,,,,, -92200,2.2709324,2.930729,,,,,,,,,,,,,, -92300,2.3216138,2.9742317,,,,,,,,,,,,,, -92400,2.3993008,2.9824727,,,,,,,,,,,,,, -92500,2.4307563,2.9778602,,,,,,,,,,,,,, -92600,2.3227236,2.880611,,,,,,,,,,,,,, -92700,2.392787,2.9320946,,,,,,,,,,,,,, -92800,2.434592,2.9003465,,,,,,,,,,,,,, -92900,2.3881168,3.0200913,,,,,,,,,,,,,, -92959,,,0.7984693646430969,1.0095969438552856,0.6953799724578857,1.4463399648666382,50000.0,0.563800036907196,2.128091812133789,10000.0,31655.91071796417,32779.90780115128,31655.91071796417,1118.377030134201,2.269920825958252,0.0 -93000,2.576114,3.018157,,,,,,,,,,,,,, -93100,2.34096,2.903059,,,,,,,,,,,,,, -93200,2.3827543,2.9553883,,,,,,,,,,,,,, -93300,2.1810315,2.8882232,,,,,,,,,,,,,, -93400,2.559771,3.0642154,,,,,,,,,,,,,, -93500,2.3008823,2.9284027,,,,,,,,,,,,,, -93600,2.4257019,2.9775527,,,,,,,,,,,,,, -93700,2.3281467,2.9756165,,,,,,,,,,,,,, -93800,2.544361,2.9867,,,,,,,,,,,,,, -93900,2.276541,2.881102,,,,,,,,,,,,,, -94000,2.6148655,3.0304706,,,,,,,,,,,,,, -94100,2.5518262,2.9844408,,,,,,,,,,,,,, -94200,2.2804415,2.9703097,,,,,,,,,,,,,, -94300,2.494641,2.9049945,,,,,,,,,,,,,, -94400,2.4791017,2.8971488,,,,,,,,,,,,,, -94462,,,0.8036909699440002,1.0080145597457886,0.6985599994659424,1.4534142017364502,50000.0,0.5697000026702881,2.11687970161438,10000.0,32166.15032839775,33307.851593732834,32166.15032839775,1135.9845206737518,2.3124051094055176,0.0 -94500,2.2819605,2.9876087,,,,,,,,,,,,,, -94600,2.3672278,2.967139,,,,,,,,,,,,,, -94700,2.2932565,2.9204724,,,,,,,,,,,,,, -94800,2.6254256,3.0392365,,,,,,,,,,,,,, -94900,2.3403018,2.9242435,,,,,,,,,,,,,, -95000,2.3635457,2.955128,,,,,,,,,,,,,, -95100,2.5020554,3.0155118,,,,,,,,,,,,,, -95200,2.353408,2.8867404,,,,,,,,,,,,,, -95300,2.5653036,2.9574697,,,,,,,,,,,,,, -95400,2.5417757,2.901098,,,,,,,,,,,,,, -95500,2.5190675,2.9219825,,,,,,,,,,,,,, -95600,2.9229245,2.9430246,,,,,,,,,,,,,, -95700,2.4298656,2.9456925,,,,,,,,,,,,,, -95800,2.5878668,2.969101,,,,,,,,,,,,,, -95900,2.540552,2.928479,,,,,,,,,,,,,, -95965,,,0.7919324040412903,1.0704740285873413,0.6911799907684326,1.506101369857788,50000.0,0.5717000365257263,2.1151468753814697,10000.0,32676.37836766243,33835.78617501259,32676.37836766243,1153.5908043384552,2.356222152709961,0.0 -96000,2.3117979,3.0490801,,,,,,,,,,,,,, -96100,2.6170316,3.0721536,,,,,,,,,,,,,, -96200,2.5253458,2.9827416,,,,,,,,,,,,,, -96300,2.3832283,2.9908755,,,,,,,,,,,,,, -96400,2.4014149,2.9344747,,,,,,,,,,,,,, -96500,2.6605358,2.9801729,,,,,,,,,,,,,, -96600,2.4915128,3.000491,,,,,,,,,,,,,, -96700,2.4007068,3.010168,,,,,,,,,,,,,, -96800,2.3834631,2.9193518,,,,,,,,,,,,,, -96900,2.5456579,2.9557045,,,,,,,,,,,,,, -97000,2.4481158,2.948171,,,,,,,,,,,,,, -97100,2.3290603,2.910923,,,,,,,,,,,,,, -97200,2.4517915,2.9321957,,,,,,,,,,,,,, -97300,2.4459023,2.927949,,,,,,,,,,,,,, -97400,2.4440029,2.966473,,,,,,,,,,,,,, -97467,,,0.7996053695678711,0.9840648174285888,0.70551997423172,1.3981292247772217,50000.0,0.5795000195503235,2.038626194000244,10000.0,33186.609236717224,34364.00185251236,33186.609236717224,1171.4737193584442,2.403014659881592,0.0 -97500,2.343001,2.8841121,,,,,,,,,,,,,, -97600,2.5086844,3.0088167,,,,,,,,,,,,,, -97700,2.411116,2.9315944,,,,,,,,,,,,,, -97800,2.4770572,2.9605474,,,,,,,,,,,,,, -97900,2.44239,2.9319205,,,,,,,,,,,,,, -98000,2.5950973,2.9929683,,,,,,,,,,,,,, -98100,2.5936286,2.9438865,,,,,,,,,,,,,, -98200,2.5586543,2.9820242,,,,,,,,,,,,,, -98300,2.5384815,2.9249933,,,,,,,,,,,,,, -98400,2.4097645,2.869155,,,,,,,,,,,,,, -98500,2.6240888,2.972596,,,,,,,,,,,,,, -98600,2.5334144,3.0085878,,,,,,,,,,,,,, -98700,2.4733138,2.922644,,,,,,,,,,,,,, -98800,2.430476,2.9548624,,,,,,,,,,,,,, -98900,2.4962666,2.9657557,,,,,,,,,,,,,, -98969,,,0.7989476919174194,0.9818856716156006,0.7056799530982971,1.3842735290527344,50000.0,0.5839000344276428,2.0269968509674072,10000.0,33696.592266082764,34891.780616760254,33696.592266082764,1189.1691591739657,2.448249578475952,0.0 -99000,2.5287874,2.8845966,,,,,,,,,,,,,, -99100,2.2719924,2.8448563,,,,,,,,,,,,,, -99200,2.5880303,2.9431129,,,,,,,,,,,,,, -99300,2.7289288,3.015428,,,,,,,,,,,,,, -99400,2.7211568,2.9560237,,,,,,,,,,,,,, -99500,2.7013216,2.985514,,,,,,,,,,,,,, -99600,2.4494517,2.953634,,,,,,,,,,,,,, -99700,2.5791657,2.9726644,,,,,,,,,,,,,, -99800,2.4861236,2.915289,,,,,,,,,,,,,, -99900,2.5990932,2.8848646,,,,,,,,,,,,,, -100000,2.4300745,2.9423766,,,,,,,,,,,,,, -100100,2.4400907,2.8704624,,,,,,,,,,,,,, -100200,2.370252,2.886606,,,,,,,,,,,,,, -100300,2.5401137,2.9154623,,,,,,,,,,,,,, -100400,2.5080993,2.9110382,,,,,,,,,,,,,, -100470,,,0.8025948405265808,0.9827648401260376,0.707539975643158,1.3925204277038574,50000.0,0.5842000246047974,2.023866653442383,10000.0,34206.52460384369,35419.403853178024,34206.52460384369,1206.7572317123413,2.495995283126831,0.0 -100500,2.555194,2.9520328,,,,,,,,,,,,,, -100600,2.5688527,2.880209,,,,,,,,,,,,,, -100700,2.8428743,2.9241104,,,,,,,,,,,,,, -100800,2.617205,2.9234936,,,,,,,,,,,,,, -100900,2.531557,2.9675248,,,,,,,,,,,,,, -101000,2.5564752,2.9319816,,,,,,,,,,,,,, -101100,2.5586555,2.949426,,,,,,,,,,,,,, -101200,2.871461,2.8866565,,,,,,,,,,,,,, -101300,2.562248,2.9298034,,,,,,,,,,,,,, -101400,2.5871298,2.9330287,,,,,,,,,,,,,, -101500,2.6392202,2.8860521,,,,,,,,,,,,,, -101600,2.6739485,2.904231,,,,,,,,,,,,,, -101700,2.7652664,3.0101528,,,,,,,,,,,,,, -101800,2.5937726,2.9294975,,,,,,,,,,,,,, -101900,2.5300982,3.005974,,,,,,,,,,,,,, -101972,,,0.8014987111091614,1.016085505485535,0.707040011882782,1.427125334739685,50000.0,0.5784000158309937,2.0674097537994385,10000.0,34716.45545458794,35947.121560812,34716.45545458794,1224.4450266361237,2.540762186050415,0.0 -102000,2.4822116,2.9315176,,,,,,,,,,,,,, -102100,2.4947476,2.8481839,,,,,,,,,,,,,, -102200,2.6658697,2.98964,,,,,,,,,,,,,, -102300,2.5621772,2.9264803,,,,,,,,,,,,,, -102400,2.5091805,2.8136258,,,,,,,,,,,,,, -102500,2.732895,2.8979115,,,,,,,,,,,,,, -102600,2.6626472,2.8767385,,,,,,,,,,,,,, -102700,2.5159996,2.8664439,,,,,,,,,,,,,, -102800,2.6744962,2.9735758,,,,,,,,,,,,,, -102900,2.6069376,2.9635026,,,,,,,,,,,,,, -103000,2.7110696,2.8741026,,,,,,,,,,,,,, -103100,2.7534146,2.922583,,,,,,,,,,,,,, -103200,2.5847194,2.887261,,,,,,,,,,,,,, -103300,2.5888333,2.858472,,,,,,,,,,,,,, -103400,2.6201038,2.862263,,,,,,,,,,,,,, -103474,,,0.8282246589660645,0.8729314804077148,0.7114599943161011,1.368599534034729,50000.0,0.5851000547409058,2.004522562026977,10000.0,35226.452053546906,36475.07280921936,35226.452053546906,1242.2984237670898,2.5861055850982666,0.0 -103500,2.726917,2.968567,,,,,,,,,,,,,, -103600,2.647316,2.9445229,,,,,,,,,,,,,, -103700,2.679004,2.9596255,,,,,,,,,,,,,, -103800,2.4760396,2.8710976,,,,,,,,,,,,,, -103900,2.5756078,2.9213848,,,,,,,,,,,,,, -104000,2.667053,2.9322658,,,,,,,,,,,,,, -104100,2.7138379,2.916305,,,,,,,,,,,,,, -104200,2.604333,2.8634522,,,,,,,,,,,,,, -104300,2.6430783,2.8811588,,,,,,,,,,,,,, -104400,2.5879898,2.9529421,,,,,,,,,,,,,, -104500,2.6928804,2.9474092,,,,,,,,,,,,,, -104600,2.6263304,2.9038653,,,,,,,,,,,,,, -104700,2.5940473,2.9370809,,,,,,,,,,,,,, -104800,2.704926,2.9070802,,,,,,,,,,,,,, -104900,2.6396906,2.8888254,,,,,,,,,,,,,, -104976,,,0.8075972199440002,0.9604279398918152,0.7060999870300293,1.4053521156311035,50000.0,0.5789000391960144,2.0361974239349365,10000.0,35736.37076330185,37002.70536971092,35736.37076330185,1259.911861896515,2.632046937942505,0.0 -105000,2.519027,2.879271,,,,,,,,,,,,,, -105100,2.595862,2.9345355,,,,,,,,,,,,,, -105200,2.5280752,2.885212,,,,,,,,,,,,,, -105300,2.7814136,2.9812593,,,,,,,,,,,,,, -105400,2.7748754,2.9132338,,,,,,,,,,,,,, -105500,2.692958,2.8303452,,,,,,,,,,,,,, -105600,2.6895263,2.90952,,,,,,,,,,,,,, -105700,2.6134467,2.912188,,,,,,,,,,,,,, -105800,2.5992715,2.9139183,,,,,,,,,,,,,, -105900,2.6031435,2.8958426,,,,,,,,,,,,,, -106000,2.7492824,2.881915,,,,,,,,,,,,,, -106100,2.7902527,2.8873518,,,,,,,,,,,,,, -106200,2.6085114,2.9184623,,,,,,,,,,,,,, -106300,2.6802611,2.9863598,,,,,,,,,,,,,, -106400,2.6959424,2.892499,,,,,,,,,,,,,, -106478,,,0.8165656924247742,0.916153609752655,0.7117599844932556,1.3760380744934082,50000.0,0.5820000171661377,2.037717342376709,10000.0,36246.44602751732,37530.45402097702,36246.44602751732,1277.4861352443695,2.677516460418701,0.0 -106500,2.7602828,2.8692298,,,,,,,,,,,,,, -106600,2.6679814,2.9397292,,,,,,,,,,,,,, -106700,2.7074838,2.8577747,,,,,,,,,,,,,, -106800,2.6920986,2.9192,,,,,,,,,,,,,, -106900,2.716123,2.9045453,,,,,,,,,,,,,, -107000,2.4740825,2.9076478,,,,,,,,,,,,,, -107100,2.7742872,2.924469,,,,,,,,,,,,,, -107200,2.5536559,2.9250934,,,,,,,,,,,,,, -107300,2.6837583,2.8694913,,,,,,,,,,,,,, -107400,2.7311304,2.8323982,,,,,,,,,,,,,, -107500,2.5719397,2.8895383,,,,,,,,,,,,,, -107600,2.7145295,2.9127874,,,,,,,,,,,,,, -107700,2.763216,2.900689,,,,,,,,,,,,,, -107800,2.845598,2.9925876,,,,,,,,,,,,,, -107900,2.6225247,2.8083615,,,,,,,,,,,,,, -107981,,,0.8157086968421936,0.9369272589683532,0.7161999940872192,1.3731088638305664,50000.0,0.5901000499725342,2.000663042068481,10000.0,36756.588076114655,38058.4887046814,36756.588076114655,1295.2741153240204,2.727273941040039,0.0 -108000,2.795294,2.9609456,,,,,,,,,,,,,, -108100,2.6688068,2.872983,,,,,,,,,,,,,, -108200,2.6295724,2.905869,,,,,,,,,,,,,, -108300,2.6942797,2.9418106,,,,,,,,,,,,,, -108400,2.6567574,2.9143429,,,,,,,,,,,,,, -108500,2.705006,2.8624854,,,,,,,,,,,,,, -108600,2.628212,2.8281353,,,,,,,,,,,,,, -108700,2.6125991,2.878748,,,,,,,,,,,,,, -108800,2.7269557,2.9094198,,,,,,,,,,,,,, -108900,2.9471364,2.9225457,,,,,,,,,,,,,, -109000,2.518268,2.862458,,,,,,,,,,,,,, -109100,2.8486948,2.9253483,,,,,,,,,,,,,, -109200,2.7231042,2.9105,,,,,,,,,,,,,, -109300,2.5515425,2.8335583,,,,,,,,,,,,,, -109400,2.8396056,2.911,,,,,,,,,,,,,, -109483,,,0.8148317933082581,0.9315711855888368,0.7142199873924255,1.35802960395813,50000.0,0.5835000276565552,2.0119309425354004,10000.0,37266.67953324318,38586.27570772171,37266.67953324318,1312.8663160800934,2.7767326831817627,0.0 -109500,2.8156557,2.8932056,,,,,,,,,,,,,, -109600,2.5324519,2.845992,,,,,,,,,,,,,, -109700,2.6583033,2.9022355,,,,,,,,,,,,,, -109800,2.7871008,2.9217768,,,,,,,,,,,,,, -109900,2.6060047,2.852145,,,,,,,,,,,,,, -110000,2.6786935,2.915538,,,,,,,,,,,,,, -110100,2.8046546,2.8612406,,,,,,,,,,,,,, -110200,2.701628,2.8658702,,,,,,,,,,,,,, -110300,2.8515694,2.8706732,,,,,,,,,,,,,, -110400,2.7694302,2.908706,,,,,,,,,,,,,, -110500,2.736179,2.8904347,,,,,,,,,,,,,, -110600,2.957656,2.9219096,,,,,,,,,,,,,, -110700,2.8496165,2.9102561,,,,,,,,,,,,,, -110800,2.8516436,2.8860104,,,,,,,,,,,,,, -110900,2.7722557,2.8970954,,,,,,,,,,,,,, -110985,,,0.8172432780265808,0.890960693359375,0.7128599882125854,1.340518832206726,50000.0,0.5811000466346741,1.99495792388916,10000.0,37776.58733320236,39114.19459462166,37776.58733320236,1330.7699587345123,2.829568386077881,0.0 -111000,2.665255,2.8589022,,,,,,,,,,,,,, -111100,2.8465047,2.8711195,,,,,,,,,,,,,, -111200,2.9601884,2.941562,,,,,,,,,,,,,, -111300,2.763879,2.879105,,,,,,,,,,,,,, -111400,2.7347503,2.8013828,,,,,,,,,,,,,, -111500,2.7468617,2.8153949,,,,,,,,,,,,,, -111600,2.7334485,2.8847172,,,,,,,,,,,,,, -111700,2.7001476,2.8539212,,,,,,,,,,,,,, -111800,2.5930033,2.8296258,,,,,,,,,,,,,, -111900,2.8297782,2.8953063,,,,,,,,,,,,,, -112000,2.7865386,2.930922,,,,,,,,,,,,,, -112100,2.9174712,2.91149,,,,,,,,,,,,,, -112200,2.610961,2.7906916,,,,,,,,,,,,,, -112300,2.722511,2.831341,,,,,,,,,,,,,, -112400,2.7200475,2.916823,,,,,,,,,,,,,, -112487,,,0.8461814522743225,0.7963310480117798,0.7188400030136108,1.3273828029632568,50000.0,0.5963000059127808,1.9578930139541624,10000.0,38286.6563835144,39642.166607141495,38286.6563835144,1348.5736198425293,2.8754875659942627,0.0 -112500,2.772983,2.8967037,,,,,,,,,,,,,, -112600,2.6670282,2.8471785,,,,,,,,,,,,,, -112700,2.7063873,2.8202431,,,,,,,,,,,,,, -112800,2.7675374,2.887404,,,,,,,,,,,,,, -112900,2.8666818,2.9187493,,,,,,,,,,,,,, -113000,2.6659374,2.8724384,,,,,,,,,,,,,, -113100,2.8797684,2.9054043,,,,,,,,,,,,,, -113200,2.8481283,2.8133552,,,,,,,,,,,,,, -113300,2.812097,2.8329482,,,,,,,,,,,,,, -113400,2.9503655,2.8729258,,,,,,,,,,,,,, -113500,2.9062245,2.9245067,,,,,,,,,,,,,, -113600,2.8455083,2.801244,,,,,,,,,,,,,, -113700,2.8623865,2.8519142,,,,,,,,,,,,,, -113800,2.8464868,2.8127284,,,,,,,,,,,,,, -113900,2.849801,2.8282726,,,,,,,,,,,,,, -113990,,,0.8255141973495483,0.8634473085403442,0.7139399647712708,1.3523683547973633,50000.0,0.5891000032424927,2.0083277225494385,10000.0,38796.7903342247,40169.9016866684,38796.7903342247,1366.0709924697876,2.9247477054595947,0.0 -114000,3.0090938,2.8408322,,,,,,,,,,,,,, -114100,2.8835497,2.8335109,,,,,,,,,,,,,, -114200,2.821863,2.9546518,,,,,,,,,,,,,, -114300,3.035724,2.8542297,,,,,,,,,,,,,, -114400,2.9274473,2.8643277,,,,,,,,,,,,,, -114500,2.68584,2.8366547,,,,,,,,,,,,,, -114600,2.7880363,2.865427,,,,,,,,,,,,,, -114700,2.7896845,2.8792949,,,,,,,,,,,,,, -114800,3.1574929,2.8854423,,,,,,,,,,,,,, -114900,2.8777003,2.8169909,,,,,,,,,,,,,, -115000,3.1392503,2.8656437,,,,,,,,,,,,,, -115100,2.8553047,2.8015962,,,,,,,,,,,,,, -115200,2.9733975,2.8300457,,,,,,,,,,,,,, -115300,2.8216398,2.857024,,,,,,,,,,,,,, -115400,2.9468079,2.8450162,,,,,,,,,,,,,, -115492,,,0.8307756781578064,0.861309826374054,0.7196199893951416,1.3414368629455566,50000.0,0.5921000242233276,1.966264724731445,10000.0,39306.94250631333,40697.91963505745,39306.94250631333,1383.8353281021118,2.9718353748321533,0.0 -115500,2.6901455,2.8355393,,,,,,,,,,,,,, -115600,2.8488288,2.821124,,,,,,,,,,,,,, -115700,2.959236,2.841883,,,,,,,,,,,,,, -115800,2.7173712,2.8136725,,,,,,,,,,,,,, -115900,2.7290702,2.819672,,,,,,,,,,,,,, -116000,2.8435242,2.8455565,,,,,,,,,,,,,, -116100,2.8990796,2.9054298,,,,,,,,,,,,,, -116200,2.946813,2.8502738,,,,,,,,,,,,,, -116300,2.9556324,2.8970146,,,,,,,,,,,,,, -116400,2.995506,2.8038301,,,,,,,,,,,,,, -116500,2.9560907,2.948886,,,,,,,,,,,,,, -116600,2.7928135,2.9333732,,,,,,,,,,,,,, -116700,2.8184195,2.836785,,,,,,,,,,,,,, -116800,2.801669,2.824902,,,,,,,,,,,,,, -116900,3.0160723,2.8363693,,,,,,,,,,,,,, -116995,,,0.8356983065605164,0.8533991575241089,0.7249400019645691,1.3167498111724854,50000.0,0.5955000519752502,1.9473471641540527,10000.0,39817.17044043541,41225.83962345123,39817.17044043541,1401.4225118160248,3.0217180252075195,0.0 -117000,2.9051192,2.8671517,,,,,,,,,,,,,, -117100,3.1569717,2.8748193,,,,,,,,,,,,,, -117200,3.0029914,2.863435,,,,,,,,,,,,,, -117300,3.014217,2.8719628,,,,,,,,,,,,,, -117400,2.8759096,2.8615975,,,,,,,,,,,,,, -117500,2.8866727,2.860104,,,,,,,,,,,,,, -117600,2.91198,2.8728147,,,,,,,,,,,,,, -117700,3.009343,2.7839262,,,,,,,,,,,,,, -117800,3.0060792,2.8532712,,,,,,,,,,,,,, -117900,2.999638,2.8320777,,,,,,,,,,,,,, -118000,2.9277163,2.8120651,,,,,,,,,,,,,, -118100,2.9088192,2.8071332,,,,,,,,,,,,,, -118200,3.0557935,2.8656678,,,,,,,,,,,,,, -118300,2.7309651,2.776984,,,,,,,,,,,,,, -118400,2.7538166,2.8198986,,,,,,,,,,,,,, -118497,,,0.8339644074440002,0.8613070249557495,0.7239999771118164,1.3303016424179075,50000.0,0.5940999984741211,1.9697046279907229,10000.0,40327.12989664078,41754.2838177681,40327.12989664078,1419.805627822876,3.0692250728607178,0.0 -118500,2.8984683,2.9196486,,,,,,,,,,,,,, -118600,3.0662713,2.8042467,,,,,,,,,,,,,, -118700,2.7465823,2.7739818,,,,,,,,,,,,,, -118800,3.069991,2.8711236,,,,,,,,,,,,,, -118900,2.9066384,2.8022008,,,,,,,,,,,,,, -119000,3.107464,2.8152196,,,,,,,,,,,,,, -119100,3.0155663,2.8321822,,,,,,,,,,,,,, -119200,3.0216477,2.80504,,,,,,,,,,,,,, -119300,3.0270894,2.8232303,,,,,,,,,,,,,, -119400,3.0905764,2.8451314,,,,,,,,,,,,,, -119500,2.9741182,2.7835765,,,,,,,,,,,,,, -119600,3.099059,2.8896298,,,,,,,,,,,,,, -119700,2.9151182,2.8337026,,,,,,,,,,,,,, -119800,2.8962934,2.8118157,,,,,,,,,,,,,, -119900,3.0690594,2.8083932,,,,,,,,,,,,,, -119999,,,0.8390266299247742,0.8340352177619934,0.7273600101470947,1.301812767982483,50000.0,0.6055999994277954,1.9222824573516848,10000.0,40837.23765873909,42282.34254765511,40837.23765873909,1437.6517629623413,3.1198184490203857,0.0 -120000,3.0596128,2.8160653,,,,,,,,,,,,,, -120100,2.9980228,2.8341434,,,,,,,,,,,,,, -120200,3.1705785,2.824985,,,,,,,,,,,,,, -120300,2.9177082,2.7975106,,,,,,,,,,,,,, -120400,3.100425,2.8897147,,,,,,,,,,,,,, -120500,3.1418505,2.8257384,,,,,,,,,,,,,, -120600,2.9472814,2.7438169,,,,,,,,,,,,,, -120700,2.9671774,2.9151156,,,,,,,,,,,,,, -120800,3.0596561,2.7837312,,,,,,,,,,,,,, -120900,2.9757676,2.8201036,,,,,,,,,,,,,, -121000,2.9647827,2.7908454,,,,,,,,,,,,,, -121100,2.8703904,2.7806587,,,,,,,,,,,,,, -121200,3.2167304,2.830457,,,,,,,,,,,,,, -121300,3.2658138,2.810526,,,,,,,,,,,,,, -121400,3.0170944,2.7995648,,,,,,,,,,,,,, -121500,2.9377723,2.7343354,,,,,,,,,,,,,, -121501,,,0.8613081574440002,0.7704369425773621,0.7250399589538574,1.3356198072433472,50000.0,0.6010000109672546,1.9819520711898804,10000.0,41347.40583300591,42810.052292108536,41347.40583300591,1455.0882284641266,3.1719682216644287,0.0 -121600,3.0547273,2.8377328,,,,,,,,,,,,,, -121700,3.142031,2.834987,,,,,,,,,,,,,, -121800,2.9797604,2.7702663,,,,,,,,,,,,,, -121900,2.7709374,2.819982,,,,,,,,,,,,,, -122000,3.20161,2.8445585,,,,,,,,,,,,,, -122100,2.9614866,2.794572,,,,,,,,,,,,,, -122200,3.1299174,2.8864393,,,,,,,,,,,,,, -122300,3.105484,2.7986426,,,,,,,,,,,,,, -122400,2.8649871,2.8163345,,,,,,,,,,,,,, -122500,2.994236,2.789719,,,,,,,,,,,,,, -122600,3.0659068,2.8123229,,,,,,,,,,,,,, -122700,3.0302198,2.7973776,,,,,,,,,,,,,, -122800,2.8580873,2.774391,,,,,,,,,,,,,, -122900,3.1994646,2.8312628,,,,,,,,,,,,,, -123000,2.826774,2.7580109,,,,,,,,,,,,,, -123004,,,0.8552694320678711,0.7926141023635864,0.7265399694442749,1.3171532154083252,50000.0,0.6055000424385071,1.9438683986663816,10000.0,41857.52238154411,43337.58539104462,41857.52238154411,1472.3993520736694,3.2225170135498047,0.0 -123100,3.1109567,2.8243136,,,,,,,,,,,,,, -123200,3.1104684,2.8059592,,,,,,,,,,,,,, -123300,3.1356921,2.8609328,,,,,,,,,,,,,, -123400,3.0177283,2.8070467,,,,,,,,,,,,,, -123500,3.042725,2.8173006,,,,,,,,,,,,,, -123600,3.0147867,2.8376207,,,,,,,,,,,,,, -123700,3.2938697,2.8650284,,,,,,,,,,,,,, -123800,3.1887333,2.751554,,,,,,,,,,,,,, -123900,3.042251,2.8455148,,,,,,,,,,,,,, -124000,3.0113437,2.77924,,,,,,,,,,,,,, -124100,3.214754,2.8314252,,,,,,,,,,,,,, -124200,3.007388,2.8092525,,,,,,,,,,,,,, -124300,3.1390514,2.8111567,,,,,,,,,,,,,, -124400,2.8928418,2.7633832,,,,,,,,,,,,,, -124500,3.1622872,2.7526584,,,,,,,,,,,,,, -124506,,,0.8481743931770325,0.8100681304931641,0.7289199829101562,1.3144792318344116,50000.0,0.6068000197410583,1.9477180242538448,10000.0,42367.69151568413,43865.30625462532,42367.69151568413,1489.850107908249,3.269981622695923,0.0 -124600,3.141691,2.7829046,,,,,,,,,,,,,, -124700,3.34435,2.8051558,,,,,,,,,,,,,, -124800,3.067353,2.7483068,,,,,,,,,,,,,, -124900,3.2047222,2.8356836,,,,,,,,,,,,,, -125000,2.9675815,2.7617638,,,,,,,,,,,,,, -125100,3.115425,2.7631903,,,,,,,,,,,,,, -125200,3.0798051,2.7455022,,,,,,,,,,,,,, -125300,3.045227,2.7873888,,,,,,,,,,,,,, -125400,3.1321661,2.8079488,,,,,,,,,,,,,, -125500,3.2289386,2.7835708,,,,,,,,,,,,,, -125600,3.1332128,2.7877643,,,,,,,,,,,,,, -125700,3.0799763,2.7885547,,,,,,,,,,,,,, -125800,3.3033702,2.8066967,,,,,,,,,,,,,, -125900,3.0914245,2.7967162,,,,,,,,,,,,,, -126000,3.2003858,2.8167348,,,,,,,,,,,,,, -126009,,,0.845723032951355,0.7846372127532959,0.7291799783706665,1.2905118465423584,50000.0,0.6034000515937805,1.9218626022338867,10000.0,42877.82626962662,44393.40523290634,42877.82626962662,1507.7119023799896,3.3195204734802246,0.0 -126100,3.0250843,2.7561944,,,,,,,,,,,,,, -126200,2.963163,2.7561052,,,,,,,,,,,,,, -126300,3.0081453,2.7575324,,,,,,,,,,,,,, -126400,3.2495208,2.7694051,,,,,,,,,,,,,, -126500,2.963042,2.71071,,,,,,,,,,,,,, -126600,3.2559247,2.8048906,,,,,,,,,,,,,, -126700,3.1933005,2.8386254,,,,,,,,,,,,,, -126800,3.1683533,2.8003755,,,,,,,,,,,,,, -126900,3.1398838,2.7691548,,,,,,,,,,,,,, -127000,3.2010827,2.7271247,,,,,,,,,,,,,, -127100,3.1489837,2.765964,,,,,,,,,,,,,, -127200,3.3181183,2.7555037,,,,,,,,,,,,,, -127300,3.091573,2.8144643,,,,,,,,,,,,,, -127400,3.0517008,2.7556164,,,,,,,,,,,,,, -127500,3.0321243,2.7392082,,,,,,,,,,,,,, -127511,,,0.8580795526504517,0.7403988838195801,0.7360000014305115,1.253514289855957,50000.0,0.615600049495697,1.869235634803772,10000.0,43387.9768986702,44921.3228840828,43387.9768986702,1525.374079704285,3.3690085411071777,0.0 -127600,3.3877423,2.7644625,,,,,,,,,,,,,, -127700,3.2073197,2.730711,,,,,,,,,,,,,, -127800,3.238844,2.8346457,,,,,,,,,,,,,, -127900,2.9869382,2.7250705,,,,,,,,,,,,,, -128000,3.35956,2.823605,,,,,,,,,,,,,, -128100,3.0154371,2.7744126,,,,,,,,,,,,,, -128200,3.3751426,2.7655063,,,,,,,,,,,,,, -128300,3.4192564,2.763334,,,,,,,,,,,,,, -128400,2.9968243,2.7027938,,,,,,,,,,,,,, -128500,3.2728546,2.810104,,,,,,,,,,,,,, -128600,2.8856096,2.715608,,,,,,,,,,,,,, -128700,3.1619372,2.7298286,,,,,,,,,,,,,, -128800,3.2116992,2.8841307,,,,,,,,,,,,,, -128900,3.12186,2.7952826,,,,,,,,,,,,,, -129000,3.45211,2.8026228,,,,,,,,,,,,,, -129014,,,0.85550856590271,0.7635818123817444,0.7354599833488464,1.2662724256515503,50000.0,0.6055999994277954,1.904442310333252,10000.0,43898.0478746891,45448.98998808861,43898.0478746891,1542.8680157661438,3.4177677631378174,0.0 -129100,3.3427937,2.8001597,,,,,,,,,,,,,, -129200,3.2067947,2.7429247,,,,,,,,,,,,,, -129300,3.4414585,2.837203,,,,,,,,,,,,,, -129400,3.155892,2.7665262,,,,,,,,,,,,,, -129500,3.3149612,2.8391528,,,,,,,,,,,,,, -129600,3.158327,2.7747343,,,,,,,,,,,,,, -129700,3.0128934,2.7062287,,,,,,,,,,,,,, -129800,3.2247088,2.7851865,,,,,,,,,,,,,, -129900,3.5020442,2.727715,,,,,,,,,,,,,, -130000,3.5145245,2.7783692,,,,,,,,,,,,,, -130100,3.3277535,2.7308834,,,,,,,,,,,,,, -130200,3.3063,2.7246275,,,,,,,,,,,,,, -130300,3.0848656,2.7794785,,,,,,,,,,,,,, -130400,3.5712886,2.7714615,,,,,,,,,,,,,, -130500,3.2604666,2.8028045,,,,,,,,,,,,,, -130515,,,0.8826330900192261,0.6655307412147522,0.7340999841690063,1.2825183868408203,50000.0,0.6105000376701355,1.9217170476913448,10000.0,44407.95130634308,45976.39903450012,44407.95130634308,1560.2620012760162,3.47461485862732,0.0 -130600,3.3120883,2.799063,,,,,,,,,,,,,, -130700,3.3671942,2.713103,,,,,,,,,,,,,, -130800,3.2647352,2.723103,,,,,,,,,,,,,, -130900,3.2002518,2.7523904,,,,,,,,,,,,,, -131000,3.1949575,2.7255812,,,,,,,,,,,,,, -131100,3.0975602,2.7562618,,,,,,,,,,,,,, -131200,3.1082656,2.724534,,,,,,,,,,,,,, -131300,3.178059,2.7352877,,,,,,,,,,,,,, -131400,3.525513,2.76863,,,,,,,,,,,,,, -131500,3.1669905,2.708365,,,,,,,,,,,,,, -131600,3.3842018,2.7767198,,,,,,,,,,,,,, -131700,3.1097379,2.6904166,,,,,,,,,,,,,, -131800,3.279423,2.6644778,,,,,,,,,,,,,, -131900,3.326476,2.7463021,,,,,,,,,,,,,, -132000,3.2933023,2.7419448,,,,,,,,,,,,,, -132018,,,0.8704758882522583,0.7248345613479614,0.731499969959259,1.2995328903198242,50000.0,0.6077000498771667,1.936708688735962,10000.0,44918.10623574257,46504.47780227661,44918.10623574257,1578.0767569541931,3.5267868041992188,0.0 -132100,3.303252,2.813704,,,,,,,,,,,,,, -132200,3.1842453,2.758869,,,,,,,,,,,,,, -132300,3.3293293,2.7846184,,,,,,,,,,,,,, -132400,3.2225118,2.7325563,,,,,,,,,,,,,, -132500,3.4395587,2.7513027,,,,,,,,,,,,,, -132600,3.3239944,2.7081213,,,,,,,,,,,,,, -132700,3.2099237,2.750095,,,,,,,,,,,,,, -132800,3.4127545,2.7450154,,,,,,,,,,,,,, -132900,3.2676947,2.7540941,,,,,,,,,,,,,, -133000,3.1707382,2.779911,,,,,,,,,,,,,, -133100,3.3518221,2.7080436,,,,,,,,,,,,,, -133200,3.2852747,2.765482,,,,,,,,,,,,,, -133300,3.279176,2.698197,,,,,,,,,,,,,, -133400,3.223128,2.770392,,,,,,,,,,,,,, -133500,3.5258098,2.7216375,,,,,,,,,,,,,, -133517,,,0.8717314600944519,0.7035813331604004,0.7369599938392639,1.2577601671218872,50000.0,0.6086000204086304,1.885425329208374,10000.0,45426.990837574005,47032.312911748886,45426.990837574005,1595.8508143424988,4.649079084396362,0.0 -133600,3.2416985,2.6895893,,,,,,,,,,,,,, -133700,3.297573,2.6905532,,,,,,,,,,,,,, -133800,3.356009,2.804061,,,,,,,,,,,,,, -133900,3.3355262,2.733906,,,,,,,,,,,,,, -134000,3.488606,2.8155935,,,,,,,,,,,,,, -134100,3.530521,2.8192356,,,,,,,,,,,,,, -134200,3.4756804,2.7506797,,,,,,,,,,,,,, -134300,3.1937792,2.72189,,,,,,,,,,,,,, -134400,3.2194636,2.6909122,,,,,,,,,,,,,, -134500,3.2776785,2.7063942,,,,,,,,,,,,,, -134600,3.3002427,2.6969824,,,,,,,,,,,,,, -134700,3.3794804,2.737864,,,,,,,,,,,,,, -134800,3.2961757,2.741856,,,,,,,,,,,,,, -134900,3.3726504,2.7513344,,,,,,,,,,,,,, -135000,3.2389138,2.6670775,,,,,,,,,,,,,, -135020,,,0.868582546710968,0.7156649827957153,0.7356799840927124,1.2662253379821775,50000.0,0.6134000420570374,1.8910014629364007,10000.0,45937.16322660446,47560.45282816887,45937.16322660446,1613.7076733112335,4.703073501586914,0.0 -135100,3.3890524,2.7814426,,,,,,,,,,,,,, -135200,3.488947,2.7766879,,,,,,,,,,,,,, -135300,3.5391269,2.7521038,,,,,,,,,,,,,, -135400,3.2225535,2.7184944,,,,,,,,,,,,,, -135500,3.232158,2.6963131,,,,,,,,,,,,,, -135600,3.5977867,2.7615333,,,,,,,,,,,,,, -135700,3.2103193,2.7347698,,,,,,,,,,,,,, -135800,3.3998926,2.7347543,,,,,,,,,,,,,, -135900,3.2917717,2.6720946,,,,,,,,,,,,,, -136000,3.3074486,2.7177205,,,,,,,,,,,,,, -136100,3.3895736,2.7185922,,,,,,,,,,,,,, -136200,3.3395286,2.6548963,,,,,,,,,,,,,, -136300,3.430432,2.7116096,,,,,,,,,,,,,, -136400,3.3067462,2.71402,,,,,,,,,,,,,, -136500,3.5320837,2.7643402,,,,,,,,,,,,,, -136523,,,0.8720503449440002,0.7015858888626099,0.7404599785804749,1.2451156377792358,50000.0,0.6157000064849854,1.8840872049331665,10000.0,46447.39303159714,48088.235530138016,46447.39303159714,1631.1530323028564,4.756369113922119,0.0 -136600,3.5559473,2.756557,,,,,,,,,,,,,, -136700,3.3589547,2.7111287,,,,,,,,,,,,,, -136800,3.2214627,2.671615,,,,,,,,,,,,,, -136900,3.0638125,2.6255388,,,,,,,,,,,,,, -137000,3.3015118,2.774632,,,,,,,,,,,,,, -137100,3.6575682,2.7333612,,,,,,,,,,,,,, -137200,3.4908795,2.7466989,,,,,,,,,,,,,, -137300,3.5232658,2.6530263,,,,,,,,,,,,,, -137400,3.571386,2.7280192,,,,,,,,,,,,,, -137500,3.521789,2.654686,,,,,,,,,,,,,, -137600,3.2800727,2.7332454,,,,,,,,,,,,,, -137700,3.3750308,2.7229028,,,,,,,,,,,,,, -137800,3.4928734,2.7051413,,,,,,,,,,,,,, -137900,3.4474442,2.7350898,,,,,,,,,,,,,, -138000,3.506447,2.6719112,,,,,,,,,,,,,, -138025,,,0.8668486475944519,0.7391188144683838,0.7368199825286865,1.2867945432662964,50000.0,0.6074000000953674,1.9359248876571653,10000.0,46957.294365644455,48616.05543756485,46957.294365644455,1648.9669427871704,4.806999683380127,0.0 -138100,3.5209312,2.731948,,,,,,,,,,,,,, -138200,3.7190764,2.7201633,,,,,,,,,,,,,, -138300,3.3874989,2.747615,,,,,,,,,,,,,, -138400,3.6012676,2.682063,,,,,,,,,,,,,, -138500,3.2788894,2.6765707,,,,,,,,,,,,,, -138600,3.5182297,2.7127922,,,,,,,,,,,,,, -138700,3.6648226,2.7211757,,,,,,,,,,,,,, -138800,3.5862803,2.7157278,,,,,,,,,,,,,, -138900,3.5043306,2.7205951,,,,,,,,,,,,,, -139000,3.4781733,2.7387238,,,,,,,,,,,,,, -139100,3.6739361,2.6942668,,,,,,,,,,,,,, -139200,3.2278032,2.671236,,,,,,,,,,,,,, -139300,3.5165055,2.6948147,,,,,,,,,,,,,, -139400,3.490012,2.6921566,,,,,,,,,,,,,, -139500,3.5534225,2.6674204,,,,,,,,,,,,,, -139527,,,0.8956273794174194,0.627384603023529,0.7456600069999695,1.240861415863037,50000.0,0.6183000206947327,1.87476646900177,10000.0,47467.2943482399,49143.7070350647,47467.2943482399,1666.5136742591858,4.857511758804321,0.0 -139600,3.5512047,2.7403185,,,,,,,,,,,,,, -139700,3.5040855,2.7105012,,,,,,,,,,,,,, -139800,3.4819372,2.712097,,,,,,,,,,,,,, -139900,3.329461,2.6808395,,,,,,,,,,,,,, -140000,3.519411,2.7158246,,,,,,,,,,,,,, -140100,3.4656956,2.6773715,,,,,,,,,,,,,, -140200,3.5036752,2.699788,,,,,,,,,,,,,, -140300,3.5274749,2.6912,,,,,,,,,,,,,, -140400,3.3788085,2.706058,,,,,,,,,,,,,, -140500,3.4253633,2.6982436,,,,,,,,,,,,,, -140600,3.4360826,2.6587718,,,,,,,,,,,,,, -140700,3.543608,2.67093,,,,,,,,,,,,,, -140800,3.5772748,2.6666198,,,,,,,,,,,,,, -140900,3.7375927,2.7399478,,,,,,,,,,,,,, -141000,3.6690223,2.7170224,,,,,,,,,,,,,, -141030,,,0.893973171710968,0.6151627898216248,0.7440399527549744,1.2287087440490725,50000.0,0.619700014591217,1.8548864126205444,10000.0,47977.450786590576,49671.64498496056,47977.450786590576,1684.1839241981506,4.914875984191895,0.0 -141100,3.4359937,2.6792984,,,,,,,,,,,,,, -141200,3.4817944,2.683031,,,,,,,,,,,,,, -141300,3.5788324,2.6546712,,,,,,,,,,,,,, -141400,3.2060022,2.6826546,,,,,,,,,,,,,, -141500,3.5298712,2.6807454,,,,,,,,,,,,,, -141600,3.4351418,2.6551516,,,,,,,,,,,,,, -141700,3.6812475,2.744186,,,,,,,,,,,,,, -141800,3.4328697,2.6833892,,,,,,,,,,,,,, -141900,3.6117935,2.6292098,,,,,,,,,,,,,, -142000,3.6967044,2.636239,,,,,,,,,,,,,, -142100,3.5611253,2.6582453,,,,,,,,,,,,,, -142200,3.553194,2.6843154,,,,,,,,,,,,,, -142300,3.6182473,2.742576,,,,,,,,,,,,,, -142400,3.5449967,2.6522238,,,,,,,,,,,,,, -142500,3.3379803,2.669724,,,,,,,,,,,,,, -142532,,,0.8904854655265808,0.6626537442207336,0.7447999715805054,1.2583589553833008,50000.0,0.6160000562667847,1.8964543342590328,10000.0,48487.50527572632,50199.49908399582,48487.50527572632,1701.8740639686584,4.969464063644409,0.0 -142600,3.8509054,2.641052,,,,,,,,,,,,,, -142700,3.62368,2.6911907,,,,,,,,,,,,,, -142800,4.0191345,2.6815975,,,,,,,,,,,,,, -142900,3.6824698,2.7707572,,,,,,,,,,,,,, -143000,3.4525936,2.6720684,,,,,,,,,,,,,, -143100,3.5574508,2.6579905,,,,,,,,,,,,,, -143200,3.6014185,2.6426303,,,,,,,,,,,,,, -143300,3.4259846,2.7132485,,,,,,,,,,,,,, -143400,3.5041704,2.6713097,,,,,,,,,,,,,, -143500,3.6349742,2.6654313,,,,,,,,,,,,,, -143600,3.7629771,2.7157412,,,,,,,,,,,,,, -143700,3.613033,2.6922066,,,,,,,,,,,,,, -143800,3.535899,2.6623254,,,,,,,,,,,,,, -143900,3.4625745,2.6695442,,,,,,,,,,,,,, -144000,3.4847605,2.6971712,,,,,,,,,,,,,, -144035,,,0.892598032951355,0.6262313723564148,0.7450399994850159,1.2290778160095217,50000.0,0.6198000311851501,1.864946722984314,10000.0,48997.59909081459,50727.18270134926,48997.59909081459,1719.357551574707,5.022829294204712,0.0 -144100,3.5584545,2.7344518,,,,,,,,,,,,,, -144200,3.7171493,2.6831927,,,,,,,,,,,,,, -144300,3.3649166,2.673726,,,,,,,,,,,,,, -144400,3.8822792,2.6752992,,,,,,,,,,,,,, -144500,3.427364,2.6678905,,,,,,,,,,,,,, -144600,3.5122075,2.6661816,,,,,,,,,,,,,, -144700,3.8242762,2.7194188,,,,,,,,,,,,,, -144800,4.085008,2.712782,,,,,,,,,,,,,, -144900,3.876251,2.689693,,,,,,,,,,,,,, -145000,3.7472389,2.6947732,,,,,,,,,,,,,, -145100,3.4726372,2.6584044,,,,,,,,,,,,,, -145200,3.5660949,2.6731002,,,,,,,,,,,,,, -145300,3.587718,2.666783,,,,,,,,,,,,,, -145400,3.451207,2.5794091,,,,,,,,,,,,,, -145500,3.7834136,2.629137,,,,,,,,,,,,,, -145537,,,0.8907046914100647,0.6271235346794128,0.7487599849700928,1.223987102508545,50000.0,0.6224000453948975,1.8546254634857176,10000.0,49507.578125715256,51255.033161878586,49507.578125715256,1737.1198983192444,5.078593730926514,0.0 -145600,3.6373816,2.5881546,,,,,,,,,,,,,, -145700,3.924258,2.6610284,,,,,,,,,,,,,, -145800,3.637124,2.712144,,,,,,,,,,,,,, -145900,3.4006171,2.62565,,,,,,,,,,,,,, -146000,3.7152185,2.6692612,,,,,,,,,,,,,, -146100,3.603473,2.6270602,,,,,,,,,,,,,, -146200,3.627312,2.6626952,,,,,,,,,,,,,, -146300,3.6558735,2.6341596,,,,,,,,,,,,,, -146400,3.6294038,2.6770506,,,,,,,,,,,,,, -146500,3.6886787,2.648842,,,,,,,,,,,,,, -146600,3.5035152,2.635738,,,,,,,,,,,,,, -146700,3.615999,2.6324716,,,,,,,,,,,,,, -146800,3.7149086,2.6370692,,,,,,,,,,,,,, -146900,3.6323202,2.660079,,,,,,,,,,,,,, -147000,3.8211937,2.698399,,,,,,,,,,,,,, -147039,,,0.8966238498687744,0.6165596842765808,0.7476199865341187,1.228817582130432,50000.0,0.6242000460624695,1.8519423007965088,10000.0,50017.51508355141,51782.78130912781,50017.51508355141,1754.819967508316,5.133638858795166,0.0 -147100,3.6936553,2.6593423,,,,,,,,,,,,,, -147200,3.6260207,2.6444447,,,,,,,,,,,,,, -147300,3.8156402,2.6101027,,,,,,,,,,,,,, -147400,3.5394351,2.6079202,,,,,,,,,,,,,, -147500,3.4520242,2.6492064,,,,,,,,,,,,,, -147600,3.604811,2.6604936,,,,,,,,,,,,,, -147700,3.767073,2.6784012,,,,,,,,,,,,,, -147800,4.0999346,2.6637468,,,,,,,,,,,,,, -147900,3.468,2.6999753,,,,,,,,,,,,,, -148000,3.4544559,2.6373725,,,,,,,,,,,,,, -148100,3.6235821,2.6062655,,,,,,,,,,,,,, -148200,3.485375,2.5912526,,,,,,,,,,,,,, -148300,3.3569417,2.5757434,,,,,,,,,,,,,, -148400,3.5231283,2.618812,,,,,,,,,,,,,, -148500,3.5667937,2.652279,,,,,,,,,,,,,, -148541,,,0.9024234414100648,0.5903157591819763,0.7495200037956238,1.215433120727539,50000.0,0.6272000074386597,1.8344119787216189,10000.0,50527.50337028504,52311.05200695992,50527.50337028504,1772.9918661117554,5.189709186553955,0.0 -148600,3.664223,2.6319342,,,,,,,,,,,,,, -148700,3.8115034,2.6475236,,,,,,,,,,,,,, -148800,3.7881482,2.594775,,,,,,,,,,,,,, -148900,3.7670176,2.6214893,,,,,,,,,,,,,, -149000,3.7857246,2.6926095,,,,,,,,,,,,,, -149100,3.902332,2.6659832,,,,,,,,,,,,,, -149200,3.6997118,2.6230555,,,,,,,,,,,,,, -149300,3.5109541,2.6452975,,,,,,,,,,,,,, -149400,3.6382036,2.6297255,,,,,,,,,,,,,, -149500,3.7343466,2.629896,,,,,,,,,,,,,, -149600,3.8486307,2.6479135,,,,,,,,,,,,,, -149700,3.6836035,2.6405149,,,,,,,,,,,,,, -149800,3.7172596,2.637939,,,,,,,,,,,,,, -149900,3.8836765,2.6285393,,,,,,,,,,,,,, -150000,3.7673402,2.6122565,,,,,,,,,,,,,, -150044,,,0.910574734210968,0.5532589554786682,0.75,1.207600712776184,50000.0,0.6228000521659851,1.844643831253052,10000.0,51037.676836013794,52838.9393966198,51037.676836013794,1790.6060602664948,5.234796047210693,0.0 -150100,3.7533317,2.6243784,,,,,,,,,,,,,, -150200,3.7658038,2.5633287,,,,,,,,,,,,,, -150300,3.6309867,2.6464267,,,,,,,,,,,,,, -150400,3.9714284,2.577012,,,,,,,,,,,,,, -150500,3.6261837,2.6508255,,,,,,,,,,,,,, -150600,3.8840292,2.680525,,,,,,,,,,,,,, -150700,3.7762566,2.594697,,,,,,,,,,,,,, -150800,3.6086843,2.5841343,,,,,,,,,,,,,, -150900,3.9340837,2.7135131,,,,,,,,,,,,,, -151000,3.9184732,2.5729084,,,,,,,,,,,,,, -151100,3.5300124,2.6799924,,,,,,,,,,,,,, -151200,3.7327957,2.6073413,,,,,,,,,,,,,, -151300,3.7398996,2.6312087,,,,,,,,,,,,,, -151400,3.647682,2.6413102,,,,,,,,,,,,,, -151500,3.74111,2.6116924,,,,,,,,,,,,,, -151545,,,0.909817397594452,0.566749632358551,0.7523599863052368,1.2168418169021606,50000.0,0.6243000030517578,1.843218684196472,10000.0,51547.57273697853,53366.51215076447,51547.57273697853,1808.1750228405,5.288285970687866,0.0 -151600,3.9429598,2.653317,,,,,,,,,,,,,, -151700,3.5904183,2.6022563,,,,,,,,,,,,,, -151800,3.688404,2.6122417,,,,,,,,,,,,,, -151900,3.693281,2.647799,,,,,,,,,,,,,, -152000,3.7653818,2.6564386,,,,,,,,,,,,,, -152100,3.6809812,2.6585586,,,,,,,,,,,,,, -152200,3.9461956,2.61424,,,,,,,,,,,,,, -152300,3.854369,2.6581721,,,,,,,,,,,,,, -152400,3.9934132,2.596518,,,,,,,,,,,,,, -152500,3.7318754,2.574985,,,,,,,,,,,,,, -152600,3.570209,2.6694274,,,,,,,,,,,,,, -152700,3.6058292,2.5587091,,,,,,,,,,,,,, -152800,3.614356,2.6047864,,,,,,,,,,,,,, -152900,3.9626758,2.681234,,,,,,,,,,,,,, -153000,3.857631,2.6064298,,,,,,,,,,,,,, -153048,,,0.915058970451355,0.5383015871047974,0.7546399831771851,1.1951937675476074,50000.0,0.6313000321388245,1.818287968635559,10000.0,52057.77176237106,53894.68529486656,52057.77176237106,1826.0408027172089,5.341819047927856,0.0 -153100,3.9653888,2.627553,,,,,,,,,,,,,, -153200,3.6289775,2.580917,,,,,,,,,,,,,, -153300,4.1903358,2.656369,,,,,,,,,,,,,, -153400,3.9120731,2.6730258,,,,,,,,,,,,,, -153500,3.8709755,2.6728735,,,,,,,,,,,,,, -153600,3.6630235,2.6251369,,,,,,,,,,,,,, -153700,3.7491412,2.578239,,,,,,,,,,,,,, -153800,3.4677665,2.5714006,,,,,,,,,,,,,, -153900,3.6307507,2.569118,,,,,,,,,,,,,, -154000,3.7667158,2.6275318,,,,,,,,,,,,,, -154100,3.9449737,2.6868343,,,,,,,,,,,,,, -154200,3.762706,2.5416415,,,,,,,,,,,,,, -154300,4.214511,2.6428595,,,,,,,,,,,,,, -154400,3.5666254,2.5259657,,,,,,,,,,,,,, -154500,3.9577725,2.6263435,,,,,,,,,,,,,, -154550,,,0.9122289419174194,0.5614323616027832,0.75382000207901,1.2097517251968384,50000.0,0.6278000473976135,1.8361276388168333,10000.0,52567.898394823074,54422.400752067566,52567.898394823074,1843.5144710540767,5.40025782585144,0.0 -154600,3.5459943,2.5381577,,,,,,,,,,,,,, -154700,3.5262785,2.5711918,,,,,,,,,,,,,, -154800,3.7462213,2.6700532,,,,,,,,,,,,,, -154900,3.587391,2.5278122,,,,,,,,,,,,,, -155000,3.8899186,2.613766,,,,,,,,,,,,,, -155100,3.864156,2.5877557,,,,,,,,,,,,,, -155200,3.7540596,2.6390922,,,,,,,,,,,,,, -155300,4.0717316,2.6478133,,,,,,,,,,,,,, -155400,3.7716327,2.5980296,,,,,,,,,,,,,, -155500,3.6745298,2.572721,,,,,,,,,,,,,, -155600,3.8443246,2.5842748,,,,,,,,,,,,,, -155700,3.93102,2.6396253,,,,,,,,,,,,,, -155800,3.8537235,2.6176488,,,,,,,,,,,,,, -155900,3.9270918,2.6531694,,,,,,,,,,,,,, -156000,3.689058,2.602862,,,,,,,,,,,,,, -156053,,,0.9175502061843872,0.5501120090484619,0.7555999755859375,1.2017172574996948,50000.0,0.6291000247001648,1.8328757286071773,10000.0,53077.98852467537,54950.4566886425,53077.98852467537,1861.369544029236,5.455489873886108,0.0 -156100,3.6776319,2.560062,,,,,,,,,,,,,, -156200,3.8687255,2.6221836,,,,,,,,,,,,,, -156300,3.937055,2.598071,,,,,,,,,,,,,, -156400,3.739452,2.582587,,,,,,,,,,,,,, -156500,3.8249607,2.5922115,,,,,,,,,,,,,, -156600,3.649926,2.6258683,,,,,,,,,,,,,, -156700,4.236872,2.6068542,,,,,,,,,,,,,, -156800,3.728548,2.572894,,,,,,,,,,,,,, -156900,3.9500778,2.5512009,,,,,,,,,,,,,, -157000,3.7949715,2.5978472,,,,,,,,,,,,,, -157100,4.09181,2.6125605,,,,,,,,,,,,,, -157200,3.9551947,2.5554006,,,,,,,,,,,,,, -157300,3.8491812,2.585172,,,,,,,,,,,,,, -157400,4.104538,2.599752,,,,,,,,,,,,,, -157500,3.6881835,2.5235171,,,,,,,,,,,,,, -157555,,,0.9153977632522584,0.5394268035888672,0.7566399574279785,1.1991684436798096,50000.0,0.6296000480651855,1.8347876071929927,10000.0,53588.06019878388,55479.06328034401,53588.06019878388,1879.8014032840729,5.5035552978515625,0.0 -157600,3.6868742,2.5911055,,,,,,,,,,,,,, -157700,3.9318345,2.6050644,,,,,,,,,,,,,, -157800,3.8739698,2.5684466,,,,,,,,,,,,,, -157900,3.9178822,2.673845,,,,,,,,,,,,,, -158000,3.7718148,2.5289068,,,,,,,,,,,,,, -158100,3.7894778,2.5755575,,,,,,,,,,,,,, -158200,4.0874805,2.6315854,,,,,,,,,,,,,, -158300,3.7902007,2.6206756,,,,,,,,,,,,,, -158400,3.771783,2.6310241,,,,,,,,,,,,,, -158500,4.3809705,2.6709533,,,,,,,,,,,,,, -158600,3.965476,2.5917113,,,,,,,,,,,,,, -158700,4.053146,2.6101406,,,,,,,,,,,,,, -158800,3.9724014,2.647825,,,,,,,,,,,,,, -158900,3.9377189,2.6103158,,,,,,,,,,,,,, -159000,3.973293,2.589638,,,,,,,,,,,,,, -159058,,,0.9265186190605164,0.5162858366966248,0.7549399733543396,1.2046244144439695,50000.0,0.6288000345230103,1.838891744613648,10000.0,54098.17010354996,56007.06683254242,54098.17010354996,1897.5863370895383,5.557616472244263,0.0 -159100,4.26144,2.601939,,,,,,,,,,,,,, -159200,3.958495,2.5866685,,,,,,,,,,,,,, -159300,3.7039177,2.5923913,,,,,,,,,,,,,, -159400,3.7165995,2.595333,,,,,,,,,,,,,, -159500,3.7850149,2.636967,,,,,,,,,,,,,, -159600,3.7166326,2.5603173,,,,,,,,,,,,,, -159700,4.032635,2.5578876,,,,,,,,,,,,,, -159800,4.008383,2.5773304,,,,,,,,,,,,,, -159900,3.9130225,2.6091428,,,,,,,,,,,,,, -160000,3.922452,2.6528594,,,,,,,,,,,,,, -160100,3.949869,2.5829582,,,,,,,,,,,,,, -160200,3.847804,2.5716214,,,,,,,,,,,,,, -160300,3.951779,2.6392033,,,,,,,,,,,,,, -160400,3.9071233,2.5851018,,,,,,,,,,,,,, -160500,4.061559,2.6175783,,,,,,,,,,,,,, -160560,,,0.924824595451355,0.5198838114738464,0.7564599514007568,1.200865626335144,50000.0,0.6307000517845154,1.830616116523743,10000.0,54608.08381175995,56534.48181200028,54608.08381175995,1914.979204893112,5.611308336257935,0.0 -160600,3.7791266,2.5193696,,,,,,,,,,,,,, -160700,3.729581,2.5295222,,,,,,,,,,,,,, -160800,3.8800232,2.5526233,,,,,,,,,,,,,, -160900,4.0546083,2.5527086,,,,,,,,,,,,,, -161000,4.0070477,2.5855508,,,,,,,,,,,,,, -161100,3.608801,2.5213645,,,,,,,,,,,,,, -161200,4.3919168,2.60923,,,,,,,,,,,,,, -161300,3.893381,2.6121743,,,,,,,,,,,,,, -161400,4.464067,2.5928624,,,,,,,,,,,,,, -161500,4.107287,2.5528936,,,,,,,,,,,,,, -161600,3.9799206,2.5577273,,,,,,,,,,,,,, -161700,3.9826524,2.5690422,,,,,,,,,,,,,, -161800,3.747503,2.548804,,,,,,,,,,,,,, -161900,3.7950315,2.5517359,,,,,,,,,,,,,, -162000,3.950774,2.5565236,,,,,,,,,,,,,, -162062,,,0.9266381859779358,0.4990904629230499,0.7584199905395508,1.1812278032302856,50000.0,0.6338000297546387,1.8089417219161987,10000.0,55118.07119345665,57062.24536585808,55118.07119345665,1932.6419672966003,5.669963121414185,0.0 -162100,4.0139303,2.627452,,,,,,,,,,,,,, -162200,3.9553971,2.5903635,,,,,,,,,,,,,, -162300,3.936158,2.5451694,,,,,,,,,,,,,, -162400,3.785499,2.574199,,,,,,,,,,,,,, -162500,4.0491495,2.5551853,,,,,,,,,,,,,, -162600,4.038548,2.5458164,,,,,,,,,,,,,, -162700,3.8135085,2.569275,,,,,,,,,,,,,, -162800,3.985501,2.551557,,,,,,,,,,,,,, -162900,3.8661609,2.548439,,,,,,,,,,,,,, -163000,4.101587,2.5553052,,,,,,,,,,,,,, -163100,3.9605494,2.572979,,,,,,,,,,,,,, -163200,3.9987245,2.5353558,,,,,,,,,,,,,, -163300,4.0411654,2.6141949,,,,,,,,,,,,,, -163400,3.9081538,2.5666485,,,,,,,,,,,,,, -163500,4.2052264,2.5579169,,,,,,,,,,,,,, -163564,,,0.9269371628761292,0.5071867108345032,0.7584199905395508,1.188880205154419,50000.0,0.6325000524520874,1.8226237297058103,10000.0,55628.10046887398,57590.141486644745,55628.10046887398,1950.3917593956,5.731624841690064,0.0 -163600,4.2783184,2.573719,,,,,,,,,,,,,, -163700,3.84818,2.5496345,,,,,,,,,,,,,, -163800,3.8480244,2.552859,,,,,,,,,,,,,, -163900,4.211067,2.5695052,,,,,,,,,,,,,, -164000,3.8905277,2.5090594,,,,,,,,,,,,,, -164100,4.0719404,2.5459836,,,,,,,,,,,,,, -164200,3.699343,2.569457,,,,,,,,,,,,,, -164300,3.927957,2.5880089,,,,,,,,,,,,,, -164400,4.1628795,2.572523,,,,,,,,,,,,,, -164500,3.6943305,2.5608552,,,,,,,,,,,,,, -164600,3.7808228,2.5346153,,,,,,,,,,,,,, -164700,4.113251,2.5373607,,,,,,,,,,,,,, -164800,3.8742766,2.5489945,,,,,,,,,,,,,, -164900,3.9156084,2.6176603,,,,,,,,,,,,,, -165000,4.2761445,2.5552716,,,,,,,,,,,,,, -165066,,,0.9267378449440002,0.4970170259475708,0.7583799958229065,1.1788029670715332,50000.0,0.6345000267028809,1.815388798713684,10000.0,56138.13663291931,58117.80002164841,56138.13663291931,1967.9118270874023,5.779476404190064,0.0 -165100,4.0954723,2.578295,,,,,,,,,,,,,, -165200,3.6636376,2.4779596,,,,,,,,,,,,,, -165300,4.085382,2.6368322,,,,,,,,,,,,,, -165400,3.9224367,2.5350432,,,,,,,,,,,,,, -165500,4.109632,2.5387895,,,,,,,,,,,,,, -165600,3.8809762,2.5959418,,,,,,,,,,,,,, -165700,3.7826264,2.5370438,,,,,,,,,,,,,, -165800,3.6676311,2.5166328,,,,,,,,,,,,,, -165900,4.10157,2.5775402,,,,,,,,,,,,,, -166000,4.143861,2.565705,,,,,,,,,,,,,, -166100,3.7810514,2.483598,,,,,,,,,,,,,, -166200,4.3869777,2.5518832,,,,,,,,,,,,,, -166300,4.075483,2.6156604,,,,,,,,,,,,,, -166400,4.2269855,2.5522778,,,,,,,,,,,,,, -166500,4.0410314,2.5656753,,,,,,,,,,,,,, -166568,,,0.9300262928009032,0.5010972619056702,0.7594999670982361,1.1896491050720217,50000.0,0.6345000267028809,1.816957950592041,10000.0,56648.308772563934,58645.63115429878,56648.308772563934,1985.462195396424,5.834491014480591,0.0 -166600,4.246737,2.6042814,,,,,,,,,,,,,, -166700,4.502079,2.542125,,,,,,,,,,,,,, -166800,3.8700905,2.5671072,,,,,,,,,,,,,, -166900,4.040892,2.5648093,,,,,,,,,,,,,, -167000,3.8218226,2.562787,,,,,,,,,,,,,, -167100,3.8672564,2.5057933,,,,,,,,,,,,,, -167200,3.9682531,2.506561,,,,,,,,,,,,,, -167300,3.9242377,2.5335674,,,,,,,,,,,,,, -167400,4.0199714,2.5491993,,,,,,,,,,,,,, -167500,4.0938034,2.5812025,,,,,,,,,,,,,, -167600,3.9344597,2.5073106,,,,,,,,,,,,,, -167700,4.275846,2.5959399,,,,,,,,,,,,,, -167800,3.8747091,2.5371242,,,,,,,,,,,,,, -167900,4.3125863,2.5869877,,,,,,,,,,,,,, -168000,3.9778376,2.512589,,,,,,,,,,,,,, -168070,,,0.9396922588348388,0.4563548862934112,0.7609599828720093,1.1790025234222412,50000.0,0.6356000304222107,1.80722975730896,10000.0,57158.31609511376,59173.45578980446,57158.31609511376,2003.1597566604607,5.8996758460998535,0.0 -168100,3.9902868,2.6361058,,,,,,,,,,,,,, -168200,4.1628294,2.5466769,,,,,,,,,,,,,, -168300,4.5686665,2.6230316,,,,,,,,,,,,,, -168400,4.0622783,2.5548487,,,,,,,,,,,,,, -168500,3.9275322,2.5318682,,,,,,,,,,,,,, -168600,3.9608264,2.517726,,,,,,,,,,,,,, -168700,3.9672046,2.5769866,,,,,,,,,,,,,, -168800,4.006992,2.5643575,,,,,,,,,,,,,, -168900,4.1895576,2.548225,,,,,,,,,,,,,, -169000,4.11176,2.5729284,,,,,,,,,,,,,, -169100,4.0602508,2.5702524,,,,,,,,,,,,,, -169200,3.8768325,2.48959,,,,,,,,,,,,,, -169300,3.8907986,2.5406365,,,,,,,,,,,,,, -169400,4.092389,2.5465486,,,,,,,,,,,,,, -169500,4.0631266,2.559446,,,,,,,,,,,,,, -169572,,,0.9342115521430968,0.4681693911552429,0.7597799897193909,1.1787148714065552,50000.0,0.6360000371932983,1.8080520629882808,10000.0,57668.43019104004,59701.16584134102,57668.43019104004,2020.6457545757287,5.955905914306641,0.0 -169600,4.0051045,2.5134914,,,,,,,,,,,,,, -169700,4.1418667,2.5665603,,,,,,,,,,,,,, -169800,4.0733123,2.5510447,,,,,,,,,,,,,, -169900,4.0628204,2.5383458,,,,,,,,,,,,,, -170000,4.0195775,2.5399525,,,,,,,,,,,,,, -170100,3.775385,2.5012739,,,,,,,,,,,,,, -170200,4.2892594,2.5076814,,,,,,,,,,,,,, -170300,4.0826836,2.5597045,,,,,,,,,,,,,, -170400,3.8965414,2.498259,,,,,,,,,,,,,, -170500,3.829483,2.5725527,,,,,,,,,,,,,, -170600,4.4069896,2.505072,,,,,,,,,,,,,, -170700,4.2324643,2.5403156,,,,,,,,,,,,,, -170800,4.009616,2.4845052,,,,,,,,,,,,,, -170900,3.9845202,2.5368383,,,,,,,,,,,,,, -171000,4.2173576,2.5380008,,,,,,,,,,,,,, -171074,,,0.9379583597183228,0.4724557101726532,0.7612800002098083,1.180039882659912,50000.0,0.6367000341415405,1.8073753118515008,10000.0,58178.49935603142,60229.28386855125,58178.49935603142,2038.581184387207,6.014849662780762,0.0 -171100,4.1327343,2.5349488,,,,,,,,,,,,,, -171200,3.799416,2.5266023,,,,,,,,,,,,,, -171300,4.0744047,2.5816326,,,,,,,,,,,,,, -171400,4.0009103,2.5346143,,,,,,,,,,,,,, -171500,3.9990182,2.5055242,,,,,,,,,,,,,, -171600,3.9394484,2.5027149,,,,,,,,,,,,,, -171700,3.984606,2.521673,,,,,,,,,,,,,, -171800,3.798546,2.4364562,,,,,,,,,,,,,, -171900,4.2151985,2.5151553,,,,,,,,,,,,,, -172000,4.2663593,2.6105766,,,,,,,,,,,,,, -172100,4.1582484,2.5193937,,,,,,,,,,,,,, -172200,4.073495,2.523716,,,,,,,,,,,,,, -172300,3.8862243,2.5274272,,,,,,,,,,,,,, -172400,4.002135,2.5205743,,,,,,,,,,,,,, -172500,4.0266666,2.5123358,,,,,,,,,,,,,, -172575,,,0.9350087642669678,0.4679067730903625,0.761680006980896,1.1774041652679443,50000.0,0.6373000144958496,1.7978694438934326,10000.0,58688.43921303749,60756.84143066406,58688.43921303749,2056.078993320465,6.078775405883789,0.0 -172600,3.9574037,2.5256171,,,,,,,,,,,,,, -172700,4.1958528,2.5013607,,,,,,,,,,,,,, -172800,4.46245,2.5637357,,,,,,,,,,,,,, -172900,4.312463,2.5480344,,,,,,,,,,,,,, -173000,4.2838464,2.5121524,,,,,,,,,,,,,, -173100,4.1627035,2.5430684,,,,,,,,,,,,,, -173200,4.0890813,2.5262592,,,,,,,,,,,,,, -173300,4.189907,2.5341058,,,,,,,,,,,,,, -173400,4.026056,2.5102184,,,,,,,,,,,,,, -173500,4.3372974,2.5756886,,,,,,,,,,,,,, -173600,4.456143,2.5412695,,,,,,,,,,,,,, -173700,4.5788045,2.563806,,,,,,,,,,,,,, -173800,4.2964754,2.492326,,,,,,,,,,,,,, -173900,3.9989662,2.5519576,,,,,,,,,,,,,, -174000,4.103417,2.4909384,,,,,,,,,,,,,, -174077,,,0.936344027519226,0.4684905409812927,0.7619799971580505,1.1775026321411133,50000.0,0.6359000205993652,1.802971720695496,10000.0,59198.377883434296,61284.59939098358,59198.377883434296,2073.7847397327423,6.137209415435791,0.0 -174100,4.0600595,2.5698693,,,,,,,,,,,,,, -174200,4.095358,2.5533032,,,,,,,,,,,,,, -174300,4.092764,2.5461264,,,,,,,,,,,,,, -174400,3.6643233,2.4616685,,,,,,,,,,,,,, -174500,3.897497,2.481592,,,,,,,,,,,,,, -174600,4.0930862,2.4784083,,,,,,,,,,,,,, -174700,3.7993407,2.476505,,,,,,,,,,,,,, -174800,4.052763,2.4878123,,,,,,,,,,,,,, -174900,4.0842266,2.5304136,,,,,,,,,,,,,, -175000,4.0206456,2.5279663,,,,,,,,,,,,,, -175100,4.2577996,2.562805,,,,,,,,,,,,,, -175200,3.8958952,2.4831302,,,,,,,,,,,,,, -175300,4.032837,2.4786837,,,,,,,,,,,,,, -175400,4.22609,2.5209553,,,,,,,,,,,,,, -175500,4.0673366,2.5013475,,,,,,,,,,,,,, -175578,,,0.9384765625,0.4636580049991607,0.7625199556350708,1.177310824394226,50000.0,0.6399000287055969,1.8003582954406738,10000.0,59708.31727671623,61812.00027608872,59708.31727671623,2091.134479045868,6.194151878356934,0.0 -175600,3.8876674,2.545309,,,,,,,,,,,,,, -175700,3.8539639,2.5020554,,,,,,,,,,,,,, -175800,4.1126904,2.5140288,,,,,,,,,,,,,, -175900,4.00838,2.4978626,,,,,,,,,,,,,, -176000,3.9758105,2.5498412,,,,,,,,,,,,,, -176100,4.074686,2.4859462,,,,,,,,,,,,,, -176200,4.026954,2.531215,,,,,,,,,,,,,, -176300,3.9087234,2.4977756,,,,,,,,,,,,,, -176400,3.9691594,2.458797,,,,,,,,,,,,,, -176500,4.54362,2.5099144,,,,,,,,,,,,,, -176600,4.369246,2.519712,,,,,,,,,,,,,, -176700,3.9679828,2.4889596,,,,,,,,,,,,,, -176800,4.067309,2.5590541,,,,,,,,,,,,,, -176900,4.2445774,2.4769223,,,,,,,,,,,,,, -177000,4.2160664,2.501441,,,,,,,,,,,,,, -177080,,,0.94046950340271,0.4543436765670776,0.762779951095581,1.1745065450668335,50000.0,0.6403000354766846,1.799221396446228,10000.0,60218.3447508812,62340.08955526352,60218.3447508812,2109.08202624321,6.251053094863892,0.0 -177100,4.1892214,2.5011754,,,,,,,,,,,,,, -177200,4.139248,2.5656772,,,,,,,,,,,,,, -177300,4.0714874,2.5108843,,,,,,,,,,,,,, -177400,3.9617143,2.5416832,,,,,,,,,,,,,, -177500,3.9200072,2.5056558,,,,,,,,,,,,,, -177600,4.031348,2.5109434,,,,,,,,,,,,,, -177700,3.6898222,2.4848547,,,,,,,,,,,,,, -177800,4.0345263,2.4920912,,,,,,,,,,,,,, -177900,4.4763203,2.5584593,,,,,,,,,,,,,, -178000,4.2527256,2.5490127,,,,,,,,,,,,,, -178100,3.8824828,2.4966025,,,,,,,,,,,,,, -178200,4.137436,2.5505977,,,,,,,,,,,,,, -178300,4.4938283,2.5793705,,,,,,,,,,,,,, -178400,4.0057225,2.500049,,,,,,,,,,,,,, -178500,3.8906941,2.5317552,,,,,,,,,,,,,, -178582,,,0.9402901530265808,0.4555262923240661,0.7623199820518494,1.173393964767456,50000.0,0.6392000317573547,1.7995479106903076,10000.0,60728.52793264389,62868.04338693619,60728.52793264389,2126.7394778728485,6.30780291557312,0.0 -178600,4.2397156,2.5602448,,,,,,,,,,,,,, -178700,4.0592203,2.490098,,,,,,,,,,,,,, -178800,4.418627,2.5227828,,,,,,,,,,,,,, -178900,4.066219,2.5742466,,,,,,,,,,,,,, -179000,3.9838545,2.6033134,,,,,,,,,,,,,, -179100,4.1017833,2.5369184,,,,,,,,,,,,,, -179200,4.0285826,2.533519,,,,,,,,,,,,,, -179300,4.288788,2.5153084,,,,,,,,,,,,,, -179400,3.9374492,2.5261908,,,,,,,,,,,,,, -179500,4.319802,2.5668125,,,,,,,,,,,,,, -179600,3.8574936,2.516226,,,,,,,,,,,,,, -179700,3.8928864,2.513348,,,,,,,,,,,,,, -179800,4.179974,2.5241134,,,,,,,,,,,,,, -179900,3.9345691,2.5256498,,,,,,,,,,,,,, -180000,4.4686203,2.4995527,,,,,,,,,,,,,, -180084,,,0.9409279227256776,0.4550089836120605,0.7625799775123596,1.172922968864441,50000.0,0.6391000151634216,1.7973567247390747,10000.0,61238.67451667786,63395.76078367233,61238.67451667786,2144.196943998337,6.366096258163452,0.0 -180100,3.713095,2.501761,,,,,,,,,,,,,, -180200,4.13243,2.4994378,,,,,,,,,,,,,, -180300,4.437306,2.5035555,,,,,,,,,,,,,, -180400,4.0161743,2.515047,,,,,,,,,,,,,, -180500,3.931924,2.5521333,,,,,,,,,,,,,, -180600,4.0910172,2.5339422,,,,,,,,,,,,,, -180700,4.075099,2.5456202,,,,,,,,,,,,,, -180800,4.0359654,2.5011241,,,,,,,,,,,,,, -180900,4.0409007,2.5354466,,,,,,,,,,,,,, -181000,4.133274,2.4643016,,,,,,,,,,,,,, -181100,4.1236277,2.5716429,,,,,,,,,,,,,, -181200,4.155701,2.5370421,,,,,,,,,,,,,, -181300,3.9060636,2.5098896,,,,,,,,,,,,,, -181400,3.8575046,2.4993734,,,,,,,,,,,,,, -181500,4.218663,2.5365615,,,,,,,,,,,,,, -181586,,,0.941824734210968,0.455964058637619,0.7629599571228027,1.1759765148162842,50000.0,0.6389000415802002,1.7996957302093506,10000.0,61748.84334039688,63923.51679396629,61748.84334039688,2161.671940803528,6.423700332641602,0.0 -181600,4.2505784,2.5746388,,,,,,,,,,,,,, -181700,3.9505224,2.4753237,,,,,,,,,,,,,, -181800,4.252544,2.4718509,,,,,,,,,,,,,, -181900,4.062149,2.5005279,,,,,,,,,,,,,, -182000,3.7696707,2.5098658,,,,,,,,,,,,,, -182100,4.190776,2.5306268,,,,,,,,,,,,,, -182200,4.153747,2.5459874,,,,,,,,,,,,,, -182300,4.0702953,2.5216303,,,,,,,,,,,,,, -182400,4.3497844,2.5416667,,,,,,,,,,,,,, -182500,4.05983,2.4746916,,,,,,,,,,,,,, -182600,4.3583384,2.5567408,,,,,,,,,,,,,, -182700,4.091492,2.5187469,,,,,,,,,,,,,, -182800,4.317921,2.588028,,,,,,,,,,,,,, -182900,4.272286,2.5282962,,,,,,,,,,,,,, -183000,4.1488743,2.5290718,,,,,,,,,,,,,, -183087,,,0.9397321343421936,0.4528777301311493,0.7628600001335144,1.174703598022461,50000.0,0.6394000053405762,1.797763705253601,10000.0,62258.9725549221,64451.45683383942,62258.9725549221,2179.367330789566,6.485302448272705,0.0 -183100,4.171867,2.4969726,,,,,,,,,,,,,, -183200,4.116329,2.5369513,,,,,,,,,,,,,, -183300,4.350289,2.5121524,,,,,,,,,,,,,, -183400,3.7303586,2.4668093,,,,,,,,,,,,,, -183500,4.4172373,2.5656462,,,,,,,,,,,,,, -183600,4.3930554,2.5558453,,,,,,,,,,,,,, -183700,4.103432,2.562167,,,,,,,,,,,,,, -183800,3.7659118,2.505311,,,,,,,,,,,,,, -183900,4.153039,2.5928657,,,,,,,,,,,,,, -184000,4.0949364,2.5458033,,,,,,,,,,,,,, -184100,3.9638572,2.4121823,,,,,,,,,,,,,, -184200,4.250788,2.5304248,,,,,,,,,,,,,, -184300,3.9388504,2.5653949,,,,,,,,,,,,,, -184400,4.206244,2.490188,,,,,,,,,,,,,, -184500,4.3383613,2.5121133,,,,,,,,,,,,,, -184588,,,0.939473032951355,0.4559187591075897,0.7632799744606018,1.1745176315307615,50000.0,0.6394000053405762,1.797345757484436,10000.0,62768.99567198753,64979.35884261131,62768.99567198753,2197.123387575149,6.553771495819092,0.0 -184600,4.06933,2.4919362,,,,,,,,,,,,,, -184700,4.0010695,2.5146046,,,,,,,,,,,,,, -184800,3.991652,2.4978719,,,,,,,,,,,,,, -184900,4.1952586,2.5373778,,,,,,,,,,,,,, -185000,3.9919982,2.5210798,,,,,,,,,,,,,, -185100,4.245226,2.5227108,,,,,,,,,,,,,, -185200,4.175022,2.4668334,,,,,,,,,,,,,, -185293,,,,,,,,,,,63008.30585741997,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/eval_measurements.csv deleted file mode 100644 index bd94c886b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.85161852836609,0.0,30.50831913948059,1,0,30.50831913948059,0.0006000000284984,6.9125494956970215,10000,48.36005544662476,0.0009964923374354,6.91312313079834,0.0007599999662488,6.913174629211426,50000 -35.9192590713501,0.019089937210083,540.5099921226501,1497,0,540.5099921226501,0.0462000034749507,5.633206367492676,10000,576.5030663013458,0.0686981827020645,5.348140716552734,0.0642599985003471,5.396807193756104,50000 -53.85086178779602,0.0537657737731933,1050.4253158569336,2993,0,1050.4253158569336,0.1075000017881393,4.922767639160156,10000,1104.4395382404327,0.1697425097227096,4.304430961608887,0.1521599888801574,4.434348106384277,50000 -71.5679407119751,0.0806002616882324,1560.6546158790588,4489,0,1560.6546158790588,0.1746000051498413,4.269079208374023,10000,1632.4659051895142,0.2671595811843872,3.542244672775269,0.245619997382164,3.673764944076538,50000 -89.33231997489929,0.1103501319885253,2070.620194196701,5985,0,2070.620194196701,0.2565000057220459,3.7345449924468994,10000,2160.280205488205,0.3645368218421936,2.9136829376220703,0.340719997882843,3.063777446746826,50000 -107.3965380191803,0.142303466796875,2580.722687244416,7481,0,2580.722687244416,0.2895000278949737,3.5531110763549805,10000,2688.5322892665863,0.442402720451355,2.474280834197998,0.3840200006961822,2.817934513092041,50000 -126.0262610912323,0.1721491813659668,3090.6557648181915,8978,0,3090.6557648181915,0.345300018787384,3.1740477085113525,10000,3217.1792571544647,0.4957549273967743,2.157732963562012,0.4500799775123596,2.436403274536133,50000 -144.09838795661926,0.2016129493713379,3600.778788328171,10475,0,3600.778788328171,0.3767000138759613,2.959792375564575,10000,3745.4569323062897,0.521882951259613,2.0246825218200684,0.4817799925804138,2.25955581665039,50000 -161.8473880290985,0.2317366600036621,4110.76530623436,11973,0,4110.76530623436,0.397100031375885,2.806823492050171,10000,4273.276990890503,0.5642538070678711,1.8322232961654663,0.5165799856185913,2.0682852268218994,50000 -179.95350456237793,0.2778005599975586,4620.987542629242,13472,0,4620.987542629242,0.4182000160217285,2.717848777770996,10000,4801.707133054733,0.5740792155265808,1.7880722284317017,0.5296199917793274,2.0117011070251465,50000 -197.8390641212464,0.3077218532562256,5131.111432313919,14972,0,5131.111432313919,0.4134000241756439,2.7285807132720947,10000,5329.801098108292,0.5755141973495483,1.759779930114746,0.540399968624115,1.9560120105743408,50000 -215.526605129242,0.3375465869903564,5641.149353981018,16471,0,5641.149353981018,0.4357000291347503,2.652667284011841,10000,5857.610563755035,0.6073421239852905,1.608778953552246,0.5467000007629395,1.9259779453277588,50000 -233.45811223983765,0.3709759712219238,6151.183410406113,17972,0,6151.183410406113,0.432200014591217,2.6180338859558105,10000,6385.664488315582,0.6112284660339355,1.6073660850524902,0.5543400049209595,1.895092725753784,50000 -251.256432056427,0.4122631549835205,6661.199192047119,19473,0,6661.199192047119,0.4406000077724457,2.583203077316284,10000,6913.575413227081,0.6114277839660645,1.587865114212036,0.5617799758911133,1.8461424112319944,50000 -268.92280554771423,0.4528038501739502,7171.261469125748,20974,0,7171.261469125748,0.4406000077724457,2.576707363128662,10000,7441.400032997131,0.6094148755073547,1.583729267120361,0.5648800134658813,1.8347382545471191,50000 -286.828584432602,0.4840741157531738,7681.373394966125,22476,0,7681.373394966125,0.443200021982193,2.5712270736694336,10000,7969.503405809402,0.6075215339660645,1.6099942922592163,0.5667600035667419,1.828776240348816,50000 -304.44716572761536,0.5166065692901611,8191.535412549972,23978,0,8191.535412549972,0.4435000121593475,2.56406831741333,10000,8497.370985507965,0.6109893321990967,1.584595799446106,0.5736599564552307,1.797381043434143,50000 -322.6385922431946,0.5493104457855225,8701.451757669449,25479,0,8701.451757669449,0.4689000248908996,2.438860893249512,10000,9025.565544128418,0.6316764950752258,1.49774432182312,0.5838800072669983,1.7310395240783691,50000 -340.3928325176239,0.5819401741027832,9211.581802606584,26982,0,9211.581802606584,0.4593000113964081,2.5006914138793945,10000,9553.536763191223,0.6469228267669678,1.421025037765503,0.5805799961090088,1.7613977193832395,50000 -358.2810490131378,0.6139397621154785,9721.777070045471,28484,0,9721.777070045471,0.4655000269412994,2.420314311981201,10000,10081.70643401146,0.6459661722183228,1.4371527433395386,0.5914199948310852,1.6988749504089355,50000 -376.0208065509796,0.6470503807067871,10232.031190395355,29988,0,10232.031190395355,0.4718000292778015,2.4125912189483643,10000,10609.787842988968,0.6504305005073547,1.4086905717849731,0.5985400080680847,1.6694458723068235,50000 -393.6835870742798,0.680239200592041,10742.144251823423,31491,0,10742.144251823423,0.4706000089645386,2.442723035812378,10000,11137.651804208755,0.6327327489852905,1.4778612852096558,0.5869199633598328,1.7286511659622192,50000 -411.8246719837189,0.7132534980773926,11252.236163139343,32994,0,11252.236163139343,0.4778000116348266,2.3845551013946533,10000,11665.972929954529,0.6381337642669678,1.4649090766906738,0.5950799584388733,1.6732388734817505,50000 -429.4828112125397,0.7562506198883057,11762.316961288452,34497,0,11762.316961288452,0.4724000096321106,2.404728889465332,10000,12193.809639453888,0.647480845451355,1.4073861837387085,0.6032800078392029,1.6388838291168213,50000 -447.3253636360169,0.7911787033081055,12272.395174503326,36000,0,12272.395174503326,0.4748000204563141,2.377642869949341,10000,12721.819028377531,0.6731704473495483,1.2901920080184937,0.5949400067329407,1.6796997785568235,50000 -465.0970587730408,0.8253750801086426,12782.483393192291,37501,0,12782.483393192291,0.4794000089168548,2.349362134933472,10000,13249.76772403717,0.6614716053009033,1.3486958742141724,0.602679967880249,1.6470184326171875,50000 -483.1254951953888,0.8638536930084229,13292.577105998991,39005,0,13292.577105998991,0.4794000089168548,2.3809192180633545,10000,13777.983241558077,0.6461654901504517,1.4173647165298462,0.5987799763679504,1.6708097457885742,50000 -501.1253571510315,0.8978090286254883,13802.59973692894,40508,0,13802.59973692894,0.488500028848648,2.363870143890381,10000,14306.094644546509,0.6537986397743225,1.3841086626052856,0.6029999852180481,1.6390022039413452,50000 -518.831524848938,0.941624641418457,14312.696018218994,42012,0,14312.696018218994,0.4715000092983246,2.4117226600646973,10000,14833.994425058365,0.644949734210968,1.4297263622283936,0.5964999794960022,1.6860605478286743,50000 -536.7683305740356,0.976036548614502,14822.645218849182,43515,0,14822.645218849182,0.4844000339508056,2.345468759536743,10000,15361.96886754036,0.6421595811843872,1.4424054622650146,0.5974400043487549,1.672808289527893,50000 -554.4605889320374,1.0155045986175537,15332.848045110704,45019,0,15332.848045110704,0.4869000315666199,2.34109878540039,10000,15889.959641456604,0.6957908272743225,1.203609585762024,0.6162199974060059,1.5929490327835083,50000 -572.4644169807434,1.0506083965301514,15842.934381008148,46523,0,15842.934381008148,0.4812000095844269,2.37629771232605,10000,16418.139196634293,0.6674505472183228,1.308665156364441,0.6080999970436096,1.629995584487915,50000 -591.0792412757874,1.088939905166626,16352.884350776672,48027,0,16352.884350776672,0.4751000106334686,2.403264045715332,10000,16946.79745745659,0.6508689522743225,1.3953670263290403,0.5989199876785278,1.662032961845398,50000 -608.8275811672211,1.131298542022705,16863.04797935486,49531,0,16863.04797935486,0.4963000118732452,2.311565637588501,10000,17474.806631326675,0.6604352593421936,1.361952304840088,0.6140999794006348,1.5894731283187866,50000 -626.5100448131561,1.1692650318145752,17373.247692346573,51036,0,17373.247692346573,0.4839000105857849,2.3610148429870605,10000,18002.78094768524,0.6528220772743225,1.377036690711975,0.6045599579811096,1.6280007362365725,50000 -644.3940415382385,1.2060277462005615,17883.32636666298,52540,0,17883.32636666298,0.4944000244140625,2.280700206756592,10000,18530.83543086052,0.6638432741165161,1.3441935777664185,0.6182799935340881,1.5748586654663086,50000 -662.1686675548553,1.2445552349090576,18393.53794503212,54045,0,18393.53794503212,0.4820000231266022,2.356093168258667,10000,19058.91395521164,0.6905492544174194,1.2071837186813354,0.604699969291687,1.6392205953598022,50000 -679.9338908195496,1.2859973907470703,18903.47699022293,55549,0,18903.47699022293,0.48580002784729,2.349951982498169,10000,19586.715607881542,0.6698620915412903,1.3108255863189695,0.608959972858429,1.6123535633087158,50000 -697.7373259067535,1.3273625373840332,19413.60482478141,57054,0,19413.60482478141,0.497700035572052,2.252488136291504,10000,20114.74164962769,0.6702606678009033,1.301419734954834,0.6184399724006653,1.5728846788406372,50000 -715.5969526767731,1.3676202297210691,19923.63752818108,58558,0,19923.63752818108,0.4946000277996063,2.278593063354492,10000,20642.72845196724,0.6655771732330322,1.3216028213500977,0.6157599687576294,1.5845470428466797,50000 -733.2713446617126,1.408951997756958,20433.673320770264,60063,0,20433.673320770264,0.4969000220298767,2.297972679138184,10000,21170.53536248207,0.6641222834587097,1.3311790227890017,0.6173999905586243,1.5765674114227295,50000 -750.9849836826324,1.4479811191558838,20943.756311655045,61568,0,20943.756311655045,0.497700035572052,2.2889938354492188,10000,21698.42645382881,0.6678690910339355,1.3077174425125122,0.622439980506897,1.5476828813552856,50000 -769.0031280517578,1.4872050285339355,21454.002192497253,63074,0,21454.002192497253,0.5130000114440918,2.2014949321746826,10000,22226.784834861755,0.7181122303009033,1.0941146612167358,0.6288599967956543,1.528099536895752,50000 -786.8960626125336,1.5244412422180176,21964.22785615921,64579,0,21964.22785615921,0.5029000043869019,2.242102861404419,10000,22754.9960372448,0.6921635866165161,1.2201071977615356,0.6279599666595459,1.5359902381896973,50000 -804.7957236766815,1.5703482627868652,22474.41934657097,66085,0,22474.41934657097,0.4992000162601471,2.274237632751465,10000,23283.18795681,0.6760801672935486,1.270575761795044,0.623199999332428,1.5508557558059692,50000 -822.6233458518982,1.609081745147705,22984.531358480453,67590,0,22984.531358480453,0.5045000314712524,2.2707667350769043,10000,23811.22071957588,0.675203263759613,1.2853842973709106,0.6229400038719177,1.5472275018692017,50000 -840.2972972393036,1.6531808376312256,23494.5044093132,69095,0,23494.5044093132,0.5063000321388245,2.2389981746673584,10000,24338.96833062172,0.682039201259613,1.2529497146606443,0.6260799765586853,1.544729471206665,50000 -858.0913376808167,1.6934871673583984,24004.64289021492,70600,0,24004.64289021492,0.5125000476837158,2.193894147872925,10000,24866.99529266357,0.6823381781578064,1.261600375175476,0.6331599950790405,1.5052260160446167,50000 -875.8923208713531,1.7388732433319092,24514.733004808422,72105,0,24514.733004808422,0.5054000020027161,2.2285428047180176,10000,25394.98640203476,0.7115353941917419,1.123003363609314,0.6309399604797363,1.522621989250183,50000 -893.5084192752838,1.7819523811340332,25024.659342050552,73610,0,25024.659342050552,0.5118000507354736,2.229922294616699,10000,25922.62748122216,0.7038823366165161,1.1389704942703247,0.6362000107765198,1.4967260360717771,50000 -911.36985373497,1.823401689529419,25534.734695911407,75116,0,25534.734695911407,0.5091000199317932,2.216155767440796,10000,26450.662058591843,0.6950334906578064,1.1846526861190796,0.6358999609947205,1.490337610244751,50000 -929.1047916412354,1.8654460906982424,26044.821749687195,76621,0,26044.821749687195,0.4903000295162201,2.2797510623931885,10000,26978.580530405045,0.6759606003761292,1.2714606523513794,0.6222000122070312,1.5547786951065063,50000 -946.9249217510225,1.913517951965332,26555.042127132416,78127,0,26555.042127132416,0.5200999975204468,2.1450984477996826,10000,27506.72307229042,0.696687638759613,1.187483787536621,0.6402400135993958,1.469867467880249,50000 -964.8990716934204,1.956645011901856,27065.05168557167,79632,0,27065.05168557167,0.5033000111579895,2.213355302810669,10000,28034.80432486534,0.6890544891357422,1.2304526567459106,0.635919988155365,1.4974093437194824,50000 -982.7576727867126,2.000518560409546,27575.18598818779,81138,0,27575.18598818779,0.5028000473976135,2.267728090286255,10000,28562.895799398422,0.7074896097183228,1.138108491897583,0.6323999762535095,1.5167434215545654,50000 -1000.4262602329254,2.045931100845337,28085.256512880325,82643,0,28085.256512880325,0.5162000060081482,2.1949095726013184,10000,29090.73493361473,0.7142258882522583,1.0972977876663208,0.6411199569702148,1.474290132522583,50000 -1018.34730553627,2.0879697799682617,28595.185331583023,84148,0,28595.185331583023,0.5175999999046326,2.1688649654388428,10000,29618.681260347366,0.7117147445678711,1.1199214458465576,0.6462799906730652,1.4425675868988037,50000 -1036.582043170929,2.133475542068481,29105.374361276627,85654,0,29105.374361276627,0.5216000080108643,2.153033971786499,10000,30147.20574998856,0.7042809128761292,1.16113018989563,0.6452400088310242,1.4567824602127075,50000 -1054.7166481018066,2.183864116668701,29615.305659532547,87160,0,29615.305659532547,0.5284000039100647,2.1416590213775635,10000,30675.376630306244,0.7081273794174194,1.1381235122680664,0.6510800123214722,1.4280773401260376,50000 -1072.3675389289856,3.314368963241577,30124.24169325829,88662,0,30124.24169325829,0.520300030708313,2.128875970840454,10000,31203.1488199234,0.706074595451355,1.148975849151611,0.6509799957275391,1.4250391721725464,50000 -1090.077528476715,3.361715793609619,30634.32846808433,90168,0,30634.32846808433,0.5240000486373901,2.1700925827026367,10000,31731.047548294067,0.7113161683082581,1.1303361654281616,0.6467599868774414,1.444667100906372,50000 -1107.6928596496582,3.4071123600006104,31144.40614700317,91673,0,31144.40614700317,0.5311000347137451,2.098185062408448,10000,32258.84111380577,0.7281768321990967,1.0502692461013794,0.6510199904441833,1.4127343893051147,50000 -1125.5177392959597,3.4525933265686035,31654.537391901016,93179,0,31654.537391901016,0.5208000540733337,2.1194303035736084,10000,32786.896673202515,0.7147042155265808,1.1151316165924072,0.6510599851608276,1.43448007106781,50000 -1143.2149093151093,3.500083923339844,32164.47088861465,94684,0,32164.47088861465,0.5320000052452087,2.085729122161865,10000,33314.628957271576,0.7174146771430969,1.0910450220108032,0.6616799831390381,1.3794057369232178,50000 -1161.172973394394,3.5434374809265137,32674.50990009308,96190,0,32674.50990009308,0.530500054359436,2.102585792541504,10000,33842.723615169525,0.7141262888908386,1.1053730249404907,0.6563599705696106,1.3915393352508545,50000 -1178.9914045333862,3.5905425548553467,33184.6513376236,97696,0,33184.6513376236,0.5344000458717346,2.091149091720581,10000,34370.78517818451,0.7200254797935486,1.0727653503417969,0.6601399779319763,1.3852746486663818,50000 -1196.9177355766296,3.6347944736480713,33694.69447398186,99202,0,33694.69447398186,0.5306000113487244,2.0861809253692627,10000,34898.85359764099,0.7228953838348389,1.072148680686951,0.6584399938583374,1.380805730819702,50000 -1214.7148866653442,3.682264804840088,34204.65418744087,100707,0,34204.65418744087,0.5213000178337097,2.180105209350586,10000,35426.71285367012,0.7236925959587097,1.0682148933410645,0.6428999900817871,1.46091890335083,50000 -1232.6231932640076,3.7265255451202393,34714.8821105957,102213,0,34714.8821105957,0.5330000519752502,2.092923641204834,10000,35954.948600530624,0.7354512214660645,1.0134378671646118,0.6638399958610535,1.3594108819961548,50000 -1250.2376444339752,3.770775318145752,35225.05365109444,103718,0,35225.05365109444,0.5445000529289246,2.0400118827819824,10000,36482.834594249725,0.735371470451355,1.017613410949707,0.6706799864768982,1.3379477262496948,50000 -1267.9070928096771,3.817821502685547,35735.07813715935,105224,0,35735.07813715935,0.5378000140190125,2.079288482666016,10000,37010.63081288338,0.7281967401504517,1.0495648384094238,0.6665599942207336,1.3577643632888794,50000 -1285.6735136508942,3.866531372070313,36245.17560458183,106730,0,36245.17560458183,0.5428000092506409,2.031836748123169,10000,37538.59745979309,0.7329201102256775,1.0269001722335815,0.6684399843215942,1.340019464492798,50000 -1303.4949452877045,3.911306619644165,36755.24262714386,108236,0,36755.24262714386,0.5467000007629395,2.014052152633667,10000,38066.58639526367,0.7391581535339355,1.0007308721542358,0.6759799718856812,1.3103091716766355,50000 -1321.0843846797943,3.95632004737854,37265.14801168442,109741,0,37265.14801168442,0.5461000204086304,2.040199279785156,10000,38594.18107557297,0.7607222199440002,0.9054629802703856,0.6724399924278259,1.3212249279022217,50000 -1338.969208240509,4.006459951400757,37775.20650601387,111246,0,37775.20650601387,0.55840003490448,1.99822998046875,10000,39122.22860813141,0.7552016973495483,0.9197171330451964,0.6800199747085571,1.300967574119568,50000 -1356.734162569046,4.056137800216675,38285.42838454247,112752,0,38285.42838454247,0.5509999990463257,2.019953727722168,10000,39650.31968307495,0.752949595451355,0.93806129693985,0.6771599650382996,1.3039065599441528,50000 -1374.3892624378204,4.116119623184204,38795.5488409996,114258,0,38795.5488409996,0.5581000447273254,1.993594765663147,10000,40178.211097717285,0.7528300285339355,0.94208025932312,0.6818199753761292,1.2975364923477173,50000 -1392.7512967586515,4.165997743606567,39305.58757019043,115763,0,39305.58757019043,0.5466000437736511,2.012835025787353,10000,40706.71689796448,0.74418044090271,0.970153272151947,0.678059995174408,1.2999392747879028,50000 -1410.537811756134,4.214937686920166,39815.64774036408,117269,0,39815.64774036408,0.5684000253677368,1.921729564666748,10000,41234.668796777725,0.7614795565605164,0.9109672904014589,0.6902799606323242,1.2496405839920044,50000 -1428.5066511631012,4.262691259384155,40325.84529519081,118775,0,40325.84529519081,0.5646000504493713,1.9624043703079224,10000,41762.935829401016,0.7867307066917419,0.7906128764152527,0.6891199946403503,1.260579228401184,50000 -1446.361969947815,4.311857223510742,40836.02567625046,120281,0,40836.02567625046,0.5605000257492065,1.9643940925598145,10000,42291.07681274414,0.77543044090271,0.8411076664924622,0.6890400052070618,1.254690408706665,50000 -1463.9844024181366,4.359987735748291,41345.97124314308,121786,0,41345.97124314308,0.5678000450134277,1.9029667377471924,10000,42818.74826264381,0.7752710580825806,0.8376134634017944,0.6941999793052673,1.22850501537323,50000 -1481.8402979373932,4.414201498031616,41856.06367182732,123292,0,41856.06367182732,0.5701000094413757,1.924597978591919,10000,43346.80474662781,0.7697106003761292,0.8630506992340088,0.6924200057983398,1.2353243827819824,50000 -1500.4037234783173,4.462406873703003,42366.00140285492,124797,0,42366.00140285492,0.5696000456809998,1.931737184524536,10000,43875.40986657143,0.7671197056770325,0.862572968006134,0.6951199769973755,1.2315285205841064,50000 -1517.8734288215635,4.515023231506348,42876.07256484032,126303,0,42876.07256484032,0.579200029373169,1.8880500793457031,10000,44403.05740857124,0.776387095451355,0.831611156463623,0.7033199667930603,1.20175302028656,50000 -1535.7307217121124,4.566674709320068,43386.099005937576,127808,0,43386.099005937576,0.5752000212669373,1.917343020439148,10000,44931.04763793945,0.808035671710968,0.7111296057701111,0.7016400098800659,1.2024482488632202,50000 -1553.4010951519012,4.621357679367065,43896.101038217545,129314,0,43896.101038217545,0.5738000273704529,1.8980298042297363,10000,45458.82877635956,0.7916733026504517,0.7742475271224976,0.6988999843597412,1.205346941947937,50000 -1571.1410930156708,4.673174381256104,44406.101815223694,130819,0,44406.101815223694,0.5729000568389893,1.9065872430801392,10000,45986.675889253616,0.7887834906578064,0.7834305763244629,0.7028999924659729,1.1931791305541992,50000 -1589.0011916160583,4.726458549499512,44916.1741373539,132325,0,44916.1741373539,0.5933000445365906,1.8415168523788448,10000,46514.71562767029,0.7939453125,0.7579774260520935,0.7079199552536011,1.1673496961593628,50000 -1606.9133660793304,4.778220176696777,45426.37725400925,133831,0,45426.37725400925,0.5860000252723694,1.8372397422790527,10000,47042.93878364563,0.7958984375,0.757809579372406,0.708620011806488,1.1658791303634644,50000 -1624.4669604301453,4.833124399185181,45936.42216873169,135336,0,45936.42216873169,0.5838000178337097,1.8571194410324097,10000,47570.64797115326,0.7983697056770325,0.737368106842041,0.7152799963951111,1.1455098390579224,50000 -1641.9889187812803,4.884382009506226,46446.49920320511,136841,0,46446.49920320511,0.5967000126838684,1.7885128259658811,10000,48098.353201150894,0.8374122977256775,0.5970585942268372,0.7161399722099304,1.1301541328430176,50000 -1659.7750248908997,4.95368218421936,46956.61994481087,138346,0,46956.61994481087,0.5999000072479248,1.7960389852523804,10000,48626.38387274742,0.8237802982330322,0.637283980846405,0.7233399748802185,1.104615330696106,50000 -1677.6271243095398,5.008821725845337,47466.56193733215,139851,0,47466.56193733215,0.6011000275611877,1.7942116260528564,10000,49154.288517951965,0.8210897445678711,0.6470831036567688,0.7231400012969971,1.1141338348388672,50000 -1695.204603433609,5.06158185005188,47976.69469380379,141357,0,47976.69469380379,0.5990000367164612,1.7983546257019043,10000,49682.10616827011,0.8201530575752258,0.6532760858535767,0.7235199809074402,1.1173137426376345,50000 -1713.1225311756134,5.118523836135864,48486.89080500603,142862,0,48486.89080500603,0.5948000550270081,1.8159778118133545,10000,50210.33151769638,0.8167450428009033,0.6668460965156555,0.7207799553871155,1.1217389106750488,50000 -1730.7020156383514,5.169455766677856,48996.92267632485,144368,0,48996.92267632485,0.6010000109672546,1.7762413024902344,10000,50738.048347473145,0.8262914419174194,0.6314259767532349,0.7257599830627441,1.0909123420715332,50000 -1748.4203968048096,5.229321718215942,49507.11916804314,145874,0,49507.11916804314,0.6047000288963318,1.778953194618225,10000,51266.07844781876,0.8601721525192261,0.5021381378173828,0.732479989528656,1.0820106267929075,50000 -1766.0413398742676,5.282959222793579,50017.27359175682,147380,0,50017.27359175682,0.6089000105857849,1.7504609823226929,10000,51793.96232128143,0.8510442972183228,0.5359805822372437,0.7305399775505066,1.0758004188537598,50000 -1783.8505229949951,5.341572046279907,50527.27390384674,148885,0,50527.27390384674,0.6100000143051147,1.7439473867416382,10000,52321.88641309738,0.8487523794174194,0.5430769324302673,0.7354199886322021,1.058487057685852,50000 -1801.683106660843,5.4006664752960205,51037.273169994354,150390,0,51037.273169994354,0.6028000116348267,1.7548463344573977,10000,52849.83254790306,0.8517817258834839,0.5283650159835815,0.7335999608039856,1.0614495277404783,50000 -1819.265034675598,5.456265211105347,51547.48435902596,151895,0,51547.48435902596,0.6148000359535217,1.7296031713485718,10000,53377.73521447182,0.8530372977256775,0.5189365148544312,0.7375400066375732,1.058759093284607,50000 -1837.022742509842,5.511146545410156,52057.91526436806,153401,0,52057.91526436806,0.6172000169754028,1.733088731765747,10000,53906.03312087059,0.8583585619926453,0.5022905468940735,0.7417399883270264,1.0438984632492063,50000 -1854.6535007953644,5.569386959075928,52567.88165092468,154906,0,52567.88165092468,0.6126000285148621,1.7265610694885254,10000,54433.74296832085,0.8868981003761292,0.4032793939113617,0.741159975528717,1.0439447164535522,50000 -1872.464199066162,5.626893997192383,53077.91947197914,156412,0,53077.91947197914,0.6198000311851501,1.7015371322631836,10000,54961.70411801338,0.8819355964660645,0.4201908409595489,0.7441399693489075,1.0274673700332642,50000 -1890.2672145366669,5.684285640716553,53587.85387945175,157917,0,53587.85387945175,0.6189000010490417,1.7277756929397583,10000,55489.55252146721,0.8812978267669678,0.4185983538627624,0.7456600069999695,1.027588129043579,50000 -1908.046015739441,5.742774486541748,54097.80532884598,159422,0,54097.80532884598,0.6230000257492065,1.718498468399048,10000,56017.39664173126,0.8851044178009033,0.404142826795578,0.7461400032043457,1.024307131767273,50000 -1925.8649232387545,5.801524877548218,54608.0150744915,160928,0,54608.0150744915,0.6189000010490417,1.7309526205062866,10000,56545.54099678993,0.8844267725944519,0.4088110625743866,0.7474600076675415,1.0270127058029177,50000 -1944.0792744159696,5.855859756469727,55118.058817625046,162434,0,55118.058817625046,0.6270000338554382,1.6860774755477903,10000,57073.90779566765,0.8910036683082581,0.3811632096767425,0.7514399886131287,1.0032466650009155,50000 -1961.663684368133,5.913207292556763,55627.95647931099,163939,0,55627.95647931099,0.6303000450134277,1.6824558973312378,10000,57601.5028333664,0.90921950340271,0.3235359787940979,0.7534599900245667,0.9915898442268372,50000 -1979.5256507396696,5.970006227493286,56138.10664725304,165445,0,56138.10664725304,0.6299000382423401,1.6766432523727417,10000,58129.62777304649,0.9094586968421936,0.3204634189605713,0.7548999786376953,0.9951203465461732,50000 -1997.2634472846985,6.031470537185669,56648.064351558685,166950,0,56648.064351558685,0.6342000365257263,1.6934216022491455,10000,58657.44054579735,0.9108538031578064,0.3147288560867309,0.7562599778175354,0.9941707849502563,50000 -2014.961641788483,6.093728303909302,57158.12152385712,168456,0,57158.12152385712,0.6330000162124634,1.6975325345993042,10000,59185.312942266464,0.9100167155265808,0.3191568851470947,0.7554000020027161,0.9903133511543274,50000 -2032.660756349564,6.152337074279785,57668.05534219742,169961,0,57668.05534219742,0.6330000162124634,1.672535061836243,10000,59713.05906128883,0.913843274116516,0.3032202124595642,0.7573599815368652,0.9809739589691162,50000 -2050.4763667583466,6.209350109100342,58178.16114640236,171467,0,58178.16114640236,0.6367000341415405,1.6685278415679932,10000,60241.09210109711,0.9151387214660645,0.2979476153850555,0.7592200040817261,0.9786510467529296,50000 -2068.201717376709,6.271347761154175,58688.1078979969,172972,0,58688.1078979969,0.6366000175476074,1.6629676818847656,10000,60768.88009810448,0.92386794090271,0.2707740664482116,0.759619951248169,0.9723941683769226,50000 -2085.809932947159,6.333725452423096,59198.22804784775,174478,0,59198.22804784775,0.6362000107765198,1.6596471071243286,10000,61296.72523832321,0.9291493892669678,0.2514609694480896,0.7603799700737,0.972877323627472,50000 -2103.5739080905914,6.391595363616943,59708.26288509369,175983,0,59708.26288509369,0.636900007724762,1.6625109910964966,10000,61824.635633945465,0.9298469424247742,0.2528488039970398,0.7617799639701843,0.9658573865890504,50000 -2121.300199508667,6.451011419296265,60218.22137951851,177488,0,60218.22137951851,0.6383000016212463,1.6531953811645508,10000,62352.43422079086,0.928730845451355,0.2565664649009704,0.7629799842834473,0.9615415930747986,50000 -2139.160705327988,6.521094560623169,60728.37961912155,178993,0,60728.37961912155,0.6381000280380249,1.6532292366027832,10000,62880.57772922516,0.9295678734779358,0.2503544390201568,0.7623199820518494,0.9605156183242798,50000 -2156.840073823929,6.580701112747192,61238.27491044998,180498,0,61238.27491044998,0.638200044631958,1.647757053375244,10000,63408.26725912094,0.9304248690605164,0.244537353515625,0.7628200054168701,0.9591808319091796,50000 -2174.361873626709,6.64280366897583,61748.39874982834,182003,0,61748.39874982834,0.6396000385284424,1.6491045951843262,10000,63936.03000330925,0.9319595098495485,0.2480978667736053,0.7633199691772461,0.958383858203888,50000 -2192.3842589855194,6.70346999168396,62258.506739377975,183508,0,62258.506739377975,0.6368000507354736,1.650039553642273,10000,64464.27554774284,0.93359375,0.2431531846523285,0.7633199691772461,0.9569819569587708,50000 -2210.148278236389,6.768203496932983,62768.50212907791,185013,0,62768.50212907791,0.6378000378608704,1.648453950881958,10000,64992.15400886536,0.9331752061843872,0.24071836471557617,0.7636199593544006,0.957119882106781,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/measurements.csv deleted file mode 100644 index c466618bb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1984 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6558059,6.934149,,,,,,,,,,,,,, -1,,,0.0009964923374354,6.91312313079834,0.0007599999662488,6.913174629211426,50000.0,0.0006000000284984,6.9125494956970215,10000.0,30.50831913948059,48.36005544662476,30.50831913948059,17.85161852836609,0.0,0.0 -100,0.65337265,6.8976083,,,,,,,,,,,,,, -200,0.64755255,6.861892,,,,,,,,,,,,,, -300,0.6928496,6.7967167,,,,,,,,,,,,,, -400,0.75381017,6.6834807,,,,,,,,,,,,,, -500,0.7955704,6.5478535,,,,,,,,,,,,,, -600,0.84733903,6.4404464,,,,,,,,,,,,,, -700,0.866842,6.3692975,,,,,,,,,,,,,, -800,1.0322355,6.273805,,,,,,,,,,,,,, -900,2.3473341,6.067587,,,,,,,,,,,,,, -1000,2.1739361,5.989801,,,,,,,,,,,,,, -1100,1.840093,5.812634,,,,,,,,,,,,,, -1200,2.0297399,5.772386,,,,,,,,,,,,,, -1300,2.2532763,5.6973476,,,,,,,,,,,,,, -1400,2.1662786,5.566329,,,,,,,,,,,,,, -1497,,,0.0686981827020645,5.348140716552734,0.0642599985003471,5.396807193756104,50000.0,0.0462000034749507,5.633206367492676,10000.0,540.5099921226501,576.5030663013458,540.5099921226501,35.9192590713501,0.019089937210083,0.0 -1500,3.6213071,5.5810533,,,,,,,,,,,,,, -1600,3.77403,5.382949,,,,,,,,,,,,,, -1700,3.3148944,5.325825,,,,,,,,,,,,,, -1800,3.0707648,5.2483873,,,,,,,,,,,,,, -1900,3.3767056,5.260184,,,,,,,,,,,,,, -2000,3.1754265,5.2139053,,,,,,,,,,,,,, -2100,3.4612184,5.1313167,,,,,,,,,,,,,, -2200,5.1640463,5.0564923,,,,,,,,,,,,,, -2300,5.192972,4.94169,,,,,,,,,,,,,, -2400,4.0578594,5.0446377,,,,,,,,,,,,,, -2500,7.26339,4.912806,,,,,,,,,,,,,, -2600,3.7798471,4.800103,,,,,,,,,,,,,, -2700,8.33402,4.8424273,,,,,,,,,,,,,, -2800,6.545034,4.708309,,,,,,,,,,,,,, -2900,5.6803517,4.656521,,,,,,,,,,,,,, -2993,,,0.1697425097227096,4.304430961608887,0.1521599888801574,4.434348106384277,50000.0,0.1075000017881393,4.922767639160156,10000.0,1050.4253158569336,1104.4395382404327,1050.4253158569336,53.85086178779602,0.0537657737731933,0.0 -3000,4.6641116,4.56707,,,,,,,,,,,,,, -3100,7.4319143,4.579644,,,,,,,,,,,,,, -3200,9.659853,4.610473,,,,,,,,,,,,,, -3300,6.8049603,4.465188,,,,,,,,,,,,,, -3400,6.364278,4.4587045,,,,,,,,,,,,,, -3500,5.937528,4.385242,,,,,,,,,,,,,, -3600,7.0836873,4.3761573,,,,,,,,,,,,,, -3700,5.8924923,4.243226,,,,,,,,,,,,,, -3800,5.0786324,4.150761,,,,,,,,,,,,,, -3900,6.4977417,4.1231623,,,,,,,,,,,,,, -4000,9.134161,4.175894,,,,,,,,,,,,,, -4100,6.750136,4.153454,,,,,,,,,,,,,, -4200,7.36004,4.0922294,,,,,,,,,,,,,, -4300,4.2562094,3.9537828,,,,,,,,,,,,,, -4400,7.746972,3.924368,,,,,,,,,,,,,, -4489,,,0.2671595811843872,3.542244672775269,0.245619997382164,3.673764944076538,50000.0,0.1746000051498413,4.269079208374023,10000.0,1560.6546158790588,1632.4659051895142,1560.6546158790588,71.5679407119751,0.0806002616882324,0.0 -4500,8.334443,3.9139903,,,,,,,,,,,,,, -4600,8.465131,3.8219237,,,,,,,,,,,,,, -4700,6.504534,3.8517642,,,,,,,,,,,,,, -4800,6.096896,3.8029916,,,,,,,,,,,,,, -4900,6.467418,3.7121139,,,,,,,,,,,,,, -5000,7.743506,3.5807128,,,,,,,,,,,,,, -5100,7.8403726,3.6361241,,,,,,,,,,,,,, -5200,6.955584,3.58969,,,,,,,,,,,,,, -5300,10.001385,3.7592995,,,,,,,,,,,,,, -5400,7.597902,3.4627938,,,,,,,,,,,,,, -5500,6.7715883,3.5814369,,,,,,,,,,,,,, -5600,8.848642,3.5002813,,,,,,,,,,,,,, -5700,5.333061,3.3834782,,,,,,,,,,,,,, -5800,8.220879,3.5509386,,,,,,,,,,,,,, -5900,7.1631374,3.5751781,,,,,,,,,,,,,, -5985,,,0.3645368218421936,2.9136829376220703,0.340719997882843,3.063777446746826,50000.0,0.2565000057220459,3.7345449924468994,10000.0,2070.620194196701,2160.280205488205,2070.620194196701,89.33231997489929,0.1103501319885253,0.0 -6000,7.051062,3.3946342,,,,,,,,,,,,,, -6100,6.4189425,3.4730902,,,,,,,,,,,,,, -6200,7.6608844,3.2417586,,,,,,,,,,,,,, -6300,11.5959425,3.309542,,,,,,,,,,,,,, -6400,12.078389,3.272767,,,,,,,,,,,,,, -6500,5.586671,3.1611772,,,,,,,,,,,,,, -6600,12.769292,3.3508246,,,,,,,,,,,,,, -6700,7.069141,3.1067924,,,,,,,,,,,,,, -6800,4.777975,3.1911252,,,,,,,,,,,,,, -6900,5.7416396,3.0378907,,,,,,,,,,,,,, -7000,7.30078,3.0059118,,,,,,,,,,,,,, -7100,4.7983785,3.0770328,,,,,,,,,,,,,, -7200,7.4131536,3.1569252,,,,,,,,,,,,,, -7300,4.040307,3.020931,,,,,,,,,,,,,, -7400,5.3293123,2.9952898,,,,,,,,,,,,,, -7481,,,0.442402720451355,2.474280834197998,0.3840200006961822,2.817934513092041,50000.0,0.2895000278949737,3.5531110763549805,10000.0,2580.722687244416,2688.5322892665863,2580.722687244416,107.3965380191803,0.142303466796875,0.0 -7500,5.936029,2.9389699,,,,,,,,,,,,,, -7600,4.9621253,2.9735556,,,,,,,,,,,,,, -7700,7.310459,3.12214,,,,,,,,,,,,,, -7800,8.509534,2.9000623,,,,,,,,,,,,,, -7900,6.414226,3.0475197,,,,,,,,,,,,,, -8000,6.125955,2.8778296,,,,,,,,,,,,,, -8100,6.3183913,2.8578875,,,,,,,,,,,,,, -8200,6.156884,2.8085883,,,,,,,,,,,,,, -8300,4.2591867,2.8987873,,,,,,,,,,,,,, -8400,5.2010145,2.753026,,,,,,,,,,,,,, -8500,5.1719246,2.7210417,,,,,,,,,,,,,, -8600,6.954334,2.8823514,,,,,,,,,,,,,, -8700,7.4918876,2.8046489,,,,,,,,,,,,,, -8800,4.46936,2.747847,,,,,,,,,,,,,, -8900,6.9674788,2.7343757,,,,,,,,,,,,,, -8978,,,0.4957549273967743,2.157732963562012,0.4500799775123596,2.436403274536133,50000.0,0.345300018787384,3.1740477085113525,10000.0,3090.6557648181915,3217.1792571544647,3090.6557648181915,126.0262610912323,0.1721491813659668,0.0 -9000,5.175506,2.7260518,,,,,,,,,,,,,, -9100,3.5533576,2.7288055,,,,,,,,,,,,,, -9200,7.4081645,2.805272,,,,,,,,,,,,,, -9300,5.5010767,2.645924,,,,,,,,,,,,,, -9400,8.155899,2.7011056,,,,,,,,,,,,,, -9500,8.532637,2.7040367,,,,,,,,,,,,,, -9600,6.234306,2.5529542,,,,,,,,,,,,,, -9700,8.199096,2.7445579,,,,,,,,,,,,,, -9800,6.8043466,2.606707,,,,,,,,,,,,,, -9900,7.6880965,2.621864,,,,,,,,,,,,,, -10000,6.110202,2.5242395,,,,,,,,,,,,,, -10100,4.5680437,2.7936828,,,,,,,,,,,,,, -10200,5.8304977,2.7423134,,,,,,,,,,,,,, -10300,6.3125525,2.5341103,,,,,,,,,,,,,, -10400,4.3448577,2.566465,,,,,,,,,,,,,, -10475,,,0.521882951259613,2.0246825218200684,0.4817799925804138,2.25955581665039,50000.0,0.3767000138759613,2.959792375564575,10000.0,3600.778788328171,3745.4569323062897,3600.778788328171,144.09838795661926,0.2016129493713379,0.0 -10500,4.9656053,2.4681177,,,,,,,,,,,,,, -10600,6.096773,2.368266,,,,,,,,,,,,,, -10700,9.07395,2.6026316,,,,,,,,,,,,,, -10800,5.3754616,2.525475,,,,,,,,,,,,,, -10900,4.663888,2.5704408,,,,,,,,,,,,,, -11000,7.7257037,2.5553415,,,,,,,,,,,,,, -11100,6.7612543,2.3403506,,,,,,,,,,,,,, -11200,4.9766707,2.43076,,,,,,,,,,,,,, -11300,5.651117,2.4471664,,,,,,,,,,,,,, -11400,7.217275,2.3884509,,,,,,,,,,,,,, -11500,9.025783,2.4393725,,,,,,,,,,,,,, -11600,7.179778,2.3121433,,,,,,,,,,,,,, -11700,6.223219,2.4208405,,,,,,,,,,,,,, -11800,5.5907774,2.4030094,,,,,,,,,,,,,, -11900,7.0477805,2.4052355,,,,,,,,,,,,,, -11973,,,0.5642538070678711,1.8322232961654663,0.5165799856185913,2.0682852268218994,50000.0,0.397100031375885,2.806823492050171,10000.0,4110.76530623436,4273.276990890503,4110.76530623436,161.8473880290985,0.2317366600036621,0.0 -12000,6.424739,2.5139527,,,,,,,,,,,,,, -12100,5.993559,2.356701,,,,,,,,,,,,,, -12200,4.804718,2.4670577,,,,,,,,,,,,,, -12300,6.355086,2.386715,,,,,,,,,,,,,, -12400,4.6427183,2.290423,,,,,,,,,,,,,, -12500,5.7294354,2.36634,,,,,,,,,,,,,, -12600,6.0420094,2.3726497,,,,,,,,,,,,,, -12700,5.217107,2.250072,,,,,,,,,,,,,, -12800,8.593526,2.3800893,,,,,,,,,,,,,, -12900,5.374526,2.418047,,,,,,,,,,,,,, -13000,8.500883,2.291576,,,,,,,,,,,,,, -13100,7.8713994,2.342085,,,,,,,,,,,,,, -13200,8.220527,2.2667863,,,,,,,,,,,,,, -13300,8.197153,2.3241653,,,,,,,,,,,,,, -13400,8.651589,2.325655,,,,,,,,,,,,,, -13472,,,0.5740792155265808,1.7880722284317017,0.5296199917793274,2.0117011070251465,50000.0,0.4182000160217285,2.717848777770996,10000.0,4620.987542629242,4801.707133054733,4620.987542629242,179.95350456237793,0.2778005599975586,0.0 -13500,9.177803,2.331404,,,,,,,,,,,,,, -13600,6.514846,2.2308612,,,,,,,,,,,,,, -13700,6.3769207,2.3052225,,,,,,,,,,,,,, -13800,7.6573567,2.2528448,,,,,,,,,,,,,, -13900,8.924978,2.24757,,,,,,,,,,,,,, -14000,4.925011,2.285035,,,,,,,,,,,,,, -14100,5.882618,2.220512,,,,,,,,,,,,,, -14200,8.469834,2.2972696,,,,,,,,,,,,,, -14300,6.9152093,2.2822924,,,,,,,,,,,,,, -14400,6.5752807,2.3565884,,,,,,,,,,,,,, -14500,5.578109,2.3820755,,,,,,,,,,,,,, -14600,8.179267,2.281917,,,,,,,,,,,,,, -14700,7.697649,2.1697323,,,,,,,,,,,,,, -14800,5.418071,2.1751623,,,,,,,,,,,,,, -14900,5.9914536,2.1811247,,,,,,,,,,,,,, -14972,,,0.5755141973495483,1.759779930114746,0.540399968624115,1.9560120105743408,50000.0,0.4134000241756439,2.7285807132720947,10000.0,5131.111432313919,5329.801098108292,5131.111432313919,197.8390641212464,0.3077218532562256,0.0 -15000,6.3769016,2.3003204,,,,,,,,,,,,,, -15100,5.010117,2.20716,,,,,,,,,,,,,, -15200,7.093149,2.210526,,,,,,,,,,,,,, -15300,9.4791355,2.25905,,,,,,,,,,,,,, -15400,7.547457,2.1339746,,,,,,,,,,,,,, -15500,7.4712224,2.296635,,,,,,,,,,,,,, -15600,6.1853275,2.214767,,,,,,,,,,,,,, -15700,5.515449,2.2095346,,,,,,,,,,,,,, -15800,5.3595133,2.2291508,,,,,,,,,,,,,, -15900,7.4386396,2.2435596,,,,,,,,,,,,,, -16000,8.248577,2.2724833,,,,,,,,,,,,,, -16100,6.334516,2.2307363,,,,,,,,,,,,,, -16200,8.125646,2.278481,,,,,,,,,,,,,, -16300,5.498949,2.2313893,,,,,,,,,,,,,, -16400,4.5694323,2.3194146,,,,,,,,,,,,,, -16471,,,0.6073421239852905,1.608778953552246,0.5467000007629395,1.9259779453277588,50000.0,0.4357000291347503,2.652667284011841,10000.0,5641.149353981018,5857.610563755035,5641.149353981018,215.526605129242,0.3375465869903564,0.0 -16500,5.3323326,2.1251814,,,,,,,,,,,,,, -16600,6.4710507,2.2504964,,,,,,,,,,,,,, -16700,4.2448335,2.129351,,,,,,,,,,,,,, -16800,5.5176053,2.1941133,,,,,,,,,,,,,, -16900,8.509009,2.2088425,,,,,,,,,,,,,, -17000,6.2042427,2.2725005,,,,,,,,,,,,,, -17100,5.643759,2.2266424,,,,,,,,,,,,,, -17200,4.566801,2.1263099,,,,,,,,,,,,,, -17300,4.061115,2.097709,,,,,,,,,,,,,, -17400,5.189391,2.283129,,,,,,,,,,,,,, -17500,4.271319,2.0683768,,,,,,,,,,,,,, -17600,5.5937295,2.0637925,,,,,,,,,,,,,, -17700,4.9163313,2.1443915,,,,,,,,,,,,,, -17800,6.4491816,2.1574292,,,,,,,,,,,,,, -17900,5.8545737,2.0780404,,,,,,,,,,,,,, -17972,,,0.6112284660339355,1.6073660850524902,0.5543400049209595,1.895092725753784,50000.0,0.432200014591217,2.6180338859558105,10000.0,6151.183410406113,6385.664488315582,6151.183410406113,233.45811223983765,0.3709759712219238,0.0 -18000,8.330803,2.1643746,,,,,,,,,,,,,, -18100,3.9394023,2.2602062,,,,,,,,,,,,,, -18200,4.9629006,2.1751132,,,,,,,,,,,,,, -18300,5.1187673,2.1921325,,,,,,,,,,,,,, -18400,4.7964506,2.1183455,,,,,,,,,,,,,, -18500,4.80964,2.1934547,,,,,,,,,,,,,, -18600,7.4399815,2.1183472,,,,,,,,,,,,,, -18700,3.3445024,2.3250544,,,,,,,,,,,,,, -18800,4.312058,2.0841537,,,,,,,,,,,,,, -18900,5.789975,2.1638207,,,,,,,,,,,,,, -19000,4.280836,2.143437,,,,,,,,,,,,,, -19100,3.9556289,2.1435707,,,,,,,,,,,,,, -19200,2.8497932,2.181882,,,,,,,,,,,,,, -19300,3.8054075,2.2285955,,,,,,,,,,,,,, -19400,3.0574045,2.1608005,,,,,,,,,,,,,, -19473,,,0.6114277839660645,1.587865114212036,0.5617799758911133,1.8461424112319944,50000.0,0.4406000077724457,2.583203077316284,10000.0,6661.199192047119,6913.575413227081,6661.199192047119,251.256432056427,0.4122631549835205,0.0 -19500,6.1803803,2.265343,,,,,,,,,,,,,, -19600,4.488199,2.2451415,,,,,,,,,,,,,, -19700,4.4844265,2.077299,,,,,,,,,,,,,, -19800,3.9265647,2.0499685,,,,,,,,,,,,,, -19900,4.0936317,2.1685066,,,,,,,,,,,,,, -20000,4.839858,2.056186,,,,,,,,,,,,,, -20100,4.442946,2.0742536,,,,,,,,,,,,,, -20200,5.0043597,2.2609987,,,,,,,,,,,,,, -20300,4.527859,2.1658716,,,,,,,,,,,,,, -20400,4.047442,2.0527198,,,,,,,,,,,,,, -20500,4.1415453,2.1847494,,,,,,,,,,,,,, -20600,3.2201173,2.052269,,,,,,,,,,,,,, -20700,8.451511,2.024291,,,,,,,,,,,,,, -20800,6.373572,2.096163,,,,,,,,,,,,,, -20900,4.2795978,2.0018542,,,,,,,,,,,,,, -20974,,,0.6094148755073547,1.583729267120361,0.5648800134658813,1.8347382545471191,50000.0,0.4406000077724457,2.576707363128662,10000.0,7171.261469125748,7441.400032997131,7171.261469125748,268.92280554771423,0.4528038501739502,0.0 -21000,4.6063757,2.1970778,,,,,,,,,,,,,, -21100,4.266023,2.1063626,,,,,,,,,,,,,, -21200,4.1868005,2.1457858,,,,,,,,,,,,,, -21300,4.5194044,2.1049733,,,,,,,,,,,,,, -21400,3.0554066,2.0972347,,,,,,,,,,,,,, -21500,4.985654,2.1468866,,,,,,,,,,,,,, -21600,4.5068135,2.0869749,,,,,,,,,,,,,, -21700,3.2544718,2.090745,,,,,,,,,,,,,, -21800,4.03934,2.0689213,,,,,,,,,,,,,, -21900,3.4275315,2.1184745,,,,,,,,,,,,,, -22000,5.029726,2.0265372,,,,,,,,,,,,,, -22100,3.6849852,2.089213,,,,,,,,,,,,,, -22200,4.245305,2.099954,,,,,,,,,,,,,, -22300,3.5621414,2.0143347,,,,,,,,,,,,,, -22400,3.5653553,2.1071956,,,,,,,,,,,,,, -22476,,,0.6075215339660645,1.6099942922592163,0.5667600035667419,1.828776240348816,50000.0,0.443200021982193,2.5712270736694336,10000.0,7681.373394966125,7969.503405809402,7681.373394966125,286.828584432602,0.4840741157531738,0.0 -22500,4.0225673,1.9183253,,,,,,,,,,,,,, -22600,4.058149,2.0738935,,,,,,,,,,,,,, -22700,3.2535655,2.1640635,,,,,,,,,,,,,, -22800,3.7346528,2.125577,,,,,,,,,,,,,, -22900,3.4158933,1.9747254,,,,,,,,,,,,,, -23000,3.4540105,2.1173809,,,,,,,,,,,,,, -23100,3.0372136,2.0098548,,,,,,,,,,,,,, -23200,3.3944566,2.024068,,,,,,,,,,,,,, -23300,5.231933,1.9887521,,,,,,,,,,,,,, -23400,4.421061,2.0437338,,,,,,,,,,,,,, -23500,4.5322237,2.069774,,,,,,,,,,,,,, -23600,4.777667,2.1278243,,,,,,,,,,,,,, -23700,4.1029515,2.1057518,,,,,,,,,,,,,, -23800,4.3171105,2.1454167,,,,,,,,,,,,,, -23900,3.1183867,2.038726,,,,,,,,,,,,,, -23978,,,0.6109893321990967,1.584595799446106,0.5736599564552307,1.797381043434143,50000.0,0.4435000121593475,2.56406831741333,10000.0,8191.535412549972,8497.370985507965,8191.535412549972,304.44716572761536,0.5166065692901611,0.0 -24000,4.1525455,2.0810354,,,,,,,,,,,,,, -24100,4.3296905,2.0132565,,,,,,,,,,,,,, -24200,3.3741403,2.0407934,,,,,,,,,,,,,, -24300,4.0783463,2.1071725,,,,,,,,,,,,,, -24400,4.951961,2.1313586,,,,,,,,,,,,,, -24500,3.1732295,2.141245,,,,,,,,,,,,,, -24600,3.5533657,2.0405786,,,,,,,,,,,,,, -24700,3.1932433,1.966291,,,,,,,,,,,,,, -24800,3.4718554,2.0250373,,,,,,,,,,,,,, -24900,3.3981478,2.0396123,,,,,,,,,,,,,, -25000,4.5022044,2.0794775,,,,,,,,,,,,,, -25100,4.0164447,2.1291075,,,,,,,,,,,,,, -25200,3.8138266,2.187622,,,,,,,,,,,,,, -25300,4.303244,2.1586533,,,,,,,,,,,,,, -25400,4.2027683,2.0496616,,,,,,,,,,,,,, -25479,,,0.6316764950752258,1.49774432182312,0.5838800072669983,1.7310395240783691,50000.0,0.4689000248908996,2.438860893249512,10000.0,8701.451757669449,9025.565544128418,8701.451757669449,322.6385922431946,0.5493104457855225,0.0 -25500,3.8162754,2.0801237,,,,,,,,,,,,,, -25600,3.6592557,1.9956261,,,,,,,,,,,,,, -25700,3.8203824,2.064241,,,,,,,,,,,,,, -25800,3.196938,1.9530971,,,,,,,,,,,,,, -25900,3.6944492,1.988237,,,,,,,,,,,,,, -26000,4.391988,2.1996856,,,,,,,,,,,,,, -26100,3.3263302,1.9582362,,,,,,,,,,,,,, -26200,3.867757,2.0122519,,,,,,,,,,,,,, -26300,3.6919317,2.1383889,,,,,,,,,,,,,, -26400,4.8695326,1.9807444,,,,,,,,,,,,,, -26500,3.7735393,2.0909212,,,,,,,,,,,,,, -26600,4.031737,2.02292,,,,,,,,,,,,,, -26700,3.2460377,2.0552013,,,,,,,,,,,,,, -26800,4.4654493,1.9902726,,,,,,,,,,,,,, -26900,4.1504407,1.8229526,,,,,,,,,,,,,, -26982,,,0.6469228267669678,1.421025037765503,0.5805799961090088,1.7613977193832395,50000.0,0.4593000113964081,2.5006914138793945,10000.0,9211.581802606584,9553.536763191223,9211.581802606584,340.3928325176239,0.5819401741027832,0.0 -27000,3.575584,2.0711226,,,,,,,,,,,,,, -27100,3.8618267,2.0103736,,,,,,,,,,,,,, -27200,3.8148074,2.0707688,,,,,,,,,,,,,, -27300,3.8700976,2.030191,,,,,,,,,,,,,, -27400,3.415759,1.9861661,,,,,,,,,,,,,, -27500,3.830515,1.9696374,,,,,,,,,,,,,, -27600,3.6749003,2.0105283,,,,,,,,,,,,,, -27700,3.243936,1.9218583,,,,,,,,,,,,,, -27800,4.5300884,2.0565841,,,,,,,,,,,,,, -27900,3.4991071,2.1007156,,,,,,,,,,,,,, -28000,3.8472264,1.9717493,,,,,,,,,,,,,, -28100,4.070619,1.9723485,,,,,,,,,,,,,, -28200,3.4722679,2.0829263,,,,,,,,,,,,,, -28300,3.3774495,1.9330022,,,,,,,,,,,,,, -28400,3.6617756,1.9325001,,,,,,,,,,,,,, -28484,,,0.6459661722183228,1.4371527433395386,0.5914199948310852,1.6988749504089355,50000.0,0.4655000269412994,2.420314311981201,10000.0,9721.777070045471,10081.70643401146,9721.777070045471,358.2810490131378,0.6139397621154785,0.0 -28500,3.8321497,1.8385576,,,,,,,,,,,,,, -28600,3.7075827,2.0707905,,,,,,,,,,,,,, -28700,4.191793,2.0442991,,,,,,,,,,,,,, -28800,3.5826092,1.9284937,,,,,,,,,,,,,, -28900,3.324598,2.0168605,,,,,,,,,,,,,, -29000,3.3512912,1.9009395,,,,,,,,,,,,,, -29100,3.9763725,1.9510353,,,,,,,,,,,,,, -29200,3.0638084,1.9514363,,,,,,,,,,,,,, -29300,4.113017,2.0844274,,,,,,,,,,,,,, -29400,3.9044092,2.0151365,,,,,,,,,,,,,, -29500,4.209317,2.0119128,,,,,,,,,,,,,, -29600,3.4653683,2.0130212,,,,,,,,,,,,,, -29700,3.5196009,1.8804907,,,,,,,,,,,,,, -29800,3.7667186,1.9660542,,,,,,,,,,,,,, -29900,3.9237363,2.0794554,,,,,,,,,,,,,, -29988,,,0.6504305005073547,1.4086905717849731,0.5985400080680847,1.6694458723068235,50000.0,0.4718000292778015,2.4125912189483643,10000.0,10232.031190395355,10609.787842988968,10232.031190395355,376.0208065509796,0.6470503807067871,0.0 -30000,4.0954847,2.0941286,,,,,,,,,,,,,, -30100,3.2283378,1.958734,,,,,,,,,,,,,, -30200,4.3719606,2.0445306,,,,,,,,,,,,,, -30300,3.186183,2.0017388,,,,,,,,,,,,,, -30400,3.3631732,1.919333,,,,,,,,,,,,,, -30500,3.2013507,1.927017,,,,,,,,,,,,,, -30600,3.27903,1.9893568,,,,,,,,,,,,,, -30700,3.4623196,1.9521517,,,,,,,,,,,,,, -30800,3.831807,1.9753414,,,,,,,,,,,,,, -30900,3.7776961,2.061562,,,,,,,,,,,,,, -31000,3.8177042,1.9563715,,,,,,,,,,,,,, -31100,3.661306,1.9359192,,,,,,,,,,,,,, -31200,3.7240956,1.9594955,,,,,,,,,,,,,, -31300,3.6520991,1.9274985,,,,,,,,,,,,,, -31400,3.7090259,2.0202618,,,,,,,,,,,,,, -31491,,,0.6327327489852905,1.4778612852096558,0.5869199633598328,1.7286511659622192,50000.0,0.4706000089645386,2.442723035812378,10000.0,10742.144251823423,11137.651804208755,10742.144251823423,393.6835870742798,0.680239200592041,0.0 -31500,3.8525438,1.9941742,,,,,,,,,,,,,, -31600,3.6527622,1.8519094,,,,,,,,,,,,,, -31700,3.3345308,1.9254096,,,,,,,,,,,,,, -31800,3.651813,2.0022066,,,,,,,,,,,,,, -31900,3.311867,1.9548542,,,,,,,,,,,,,, -32000,3.9247613,1.8567395,,,,,,,,,,,,,, -32100,3.2227876,2.0132954,,,,,,,,,,,,,, -32200,3.849761,2.0422392,,,,,,,,,,,,,, -32300,3.3165278,1.8070481,,,,,,,,,,,,,, -32400,3.9825737,2.0428674,,,,,,,,,,,,,, -32500,3.9937282,1.9376206,,,,,,,,,,,,,, -32600,4.1319456,1.9896612,,,,,,,,,,,,,, -32700,3.2351654,1.9978476,,,,,,,,,,,,,, -32800,3.3225648,1.9452883,,,,,,,,,,,,,, -32900,3.9821777,1.9056435,,,,,,,,,,,,,, -32994,,,0.6381337642669678,1.4649090766906738,0.5950799584388733,1.6732388734817505,50000.0,0.4778000116348266,2.3845551013946533,10000.0,11252.236163139343,11665.972929954529,11252.236163139343,411.8246719837189,0.7132534980773926,0.0 -33000,3.244766,1.8939003,,,,,,,,,,,,,, -33100,3.647297,1.9855087,,,,,,,,,,,,,, -33200,3.2831895,2.0080867,,,,,,,,,,,,,, -33300,3.744005,1.9088765,,,,,,,,,,,,,, -33400,4.2663007,1.997848,,,,,,,,,,,,,, -33500,3.7812994,1.98895,,,,,,,,,,,,,, -33600,3.51415,2.1567898,,,,,,,,,,,,,, -33700,3.4006612,1.7898293,,,,,,,,,,,,,, -33800,3.1788514,2.014168,,,,,,,,,,,,,, -33900,3.0059204,1.8528075,,,,,,,,,,,,,, -34000,3.3896823,2.0213292,,,,,,,,,,,,,, -34100,3.2186518,1.9507897,,,,,,,,,,,,,, -34200,3.919927,1.7633984,,,,,,,,,,,,,, -34300,4.190972,2.0438695,,,,,,,,,,,,,, -34400,3.605304,1.9020445,,,,,,,,,,,,,, -34497,,,0.647480845451355,1.4073861837387085,0.6032800078392029,1.6388838291168213,50000.0,0.4724000096321106,2.404728889465332,10000.0,11762.316961288452,12193.809639453888,11762.316961288452,429.4828112125397,0.7562506198883057,0.0 -34500,3.6715558,1.8973042,,,,,,,,,,,,,, -34600,3.0074909,1.8705325,,,,,,,,,,,,,, -34700,5.576546,1.989402,,,,,,,,,,,,,, -34800,4.3639083,2.0899684,,,,,,,,,,,,,, -34900,3.9134598,1.9753608,,,,,,,,,,,,,, -35000,3.9840815,1.9281096,,,,,,,,,,,,,, -35100,3.4915195,1.9167383,,,,,,,,,,,,,, -35200,3.2263284,1.9747431,,,,,,,,,,,,,, -35300,4.19099,1.9615184,,,,,,,,,,,,,, -35400,4.875026,1.9279851,,,,,,,,,,,,,, -35500,3.7965977,1.9257565,,,,,,,,,,,,,, -35600,3.2499883,1.968776,,,,,,,,,,,,,, -35700,3.2210548,1.823985,,,,,,,,,,,,,, -35800,3.2981977,1.910497,,,,,,,,,,,,,, -35900,3.7600179,1.9206145,,,,,,,,,,,,,, -36000,,,0.6731704473495483,1.2901920080184937,0.5949400067329407,1.6796997785568235,50000.0,0.4748000204563141,2.377642869949341,10000.0,12272.395174503326,12721.819028377531,12272.395174503326,447.3253636360169,0.7911787033081055,0.0 -36000,3.5730472,1.8525037,,,,,,,,,,,,,, -36100,4.1430616,1.997744,,,,,,,,,,,,,, -36200,4.218591,1.9629135,,,,,,,,,,,,,, -36300,4.6340275,1.9938257,,,,,,,,,,,,,, -36400,3.4412777,1.9455239,,,,,,,,,,,,,, -36500,3.5443435,1.9334962,,,,,,,,,,,,,, -36600,3.2580812,2.00587,,,,,,,,,,,,,, -36700,3.671563,1.8526322,,,,,,,,,,,,,, -36800,3.4651313,1.8770089,,,,,,,,,,,,,, -36900,3.3231575,1.9377707,,,,,,,,,,,,,, -37000,3.5151365,2.0054998,,,,,,,,,,,,,, -37100,3.0111535,2.0454118,,,,,,,,,,,,,, -37200,3.6032574,1.9264393,,,,,,,,,,,,,, -37300,3.8288212,1.9573185,,,,,,,,,,,,,, -37400,3.3507795,1.8890212,,,,,,,,,,,,,, -37500,4.029606,2.0634239,,,,,,,,,,,,,, -37501,,,0.6614716053009033,1.3486958742141724,0.602679967880249,1.6470184326171875,50000.0,0.4794000089168548,2.349362134933472,10000.0,12782.483393192291,13249.76772403717,12782.483393192291,465.0970587730408,0.8253750801086426,0.0 -37600,3.7267535,1.8817325,,,,,,,,,,,,,, -37700,3.963963,1.931218,,,,,,,,,,,,,, -37800,4.307525,1.9567958,,,,,,,,,,,,,, -37900,3.9415634,2.0468013,,,,,,,,,,,,,, -38000,3.6577795,2.0305614,,,,,,,,,,,,,, -38100,4.1894813,2.021584,,,,,,,,,,,,,, -38200,4.4902945,1.9250584,,,,,,,,,,,,,, -38300,4.1308413,2.0722985,,,,,,,,,,,,,, -38400,3.5605779,1.9469672,,,,,,,,,,,,,, -38500,3.752462,1.9830136,,,,,,,,,,,,,, -38600,3.267445,1.8979816,,,,,,,,,,,,,, -38700,4.163271,1.8993675,,,,,,,,,,,,,, -38800,3.3122535,1.927022,,,,,,,,,,,,,, -38900,3.1114283,1.9079633,,,,,,,,,,,,,, -39000,3.8338072,1.8901287,,,,,,,,,,,,,, -39005,,,0.6461654901504517,1.4173647165298462,0.5987799763679504,1.6708097457885742,50000.0,0.4794000089168548,2.3809192180633545,10000.0,13292.577105998991,13777.983241558077,13292.577105998991,483.1254951953888,0.8638536930084229,0.0 -39100,3.5126526,1.9494519,,,,,,,,,,,,,, -39200,4.1212134,1.8877692,,,,,,,,,,,,,, -39300,3.7709765,1.9678103,,,,,,,,,,,,,, -39400,3.2364445,2.0040793,,,,,,,,,,,,,, -39500,4.0054293,1.8984523,,,,,,,,,,,,,, -39600,4.239578,1.8979948,,,,,,,,,,,,,, -39700,4.2662787,1.9520583,,,,,,,,,,,,,, -39800,3.7256222,1.9408268,,,,,,,,,,,,,, -39900,3.620235,1.9308734,,,,,,,,,,,,,, -40000,3.3339646,1.9660095,,,,,,,,,,,,,, -40100,3.3050659,1.8882813,,,,,,,,,,,,,, -40200,3.7981586,1.8576801,,,,,,,,,,,,,, -40300,3.8010936,1.9991037,,,,,,,,,,,,,, -40400,3.9121857,1.9777173,,,,,,,,,,,,,, -40500,3.4700916,1.9248224,,,,,,,,,,,,,, -40508,,,0.6537986397743225,1.3841086626052856,0.6029999852180481,1.6390022039413452,50000.0,0.488500028848648,2.363870143890381,10000.0,13802.59973692894,14306.094644546509,13802.59973692894,501.1253571510315,0.8978090286254883,0.0 -40600,3.6535413,1.8288141,,,,,,,,,,,,,, -40700,3.4232452,1.969665,,,,,,,,,,,,,, -40800,3.5903888,1.988723,,,,,,,,,,,,,, -40900,3.6192307,1.9248672,,,,,,,,,,,,,, -41000,3.5960705,1.9190861,,,,,,,,,,,,,, -41100,4.3833795,1.975471,,,,,,,,,,,,,, -41200,3.750308,1.7974007,,,,,,,,,,,,,, -41300,3.6449783,1.9254956,,,,,,,,,,,,,, -41400,3.8682554,1.9314393,,,,,,,,,,,,,, -41500,3.4604473,1.9421767,,,,,,,,,,,,,, -41600,4.0286064,1.8983222,,,,,,,,,,,,,, -41700,3.5459938,1.8607918,,,,,,,,,,,,,, -41800,4.376284,1.9506501,,,,,,,,,,,,,, -41900,3.291356,1.9182475,,,,,,,,,,,,,, -42000,3.9771838,1.761962,,,,,,,,,,,,,, -42012,,,0.644949734210968,1.4297263622283936,0.5964999794960022,1.6860605478286743,50000.0,0.4715000092983246,2.4117226600646973,10000.0,14312.696018218994,14833.994425058365,14312.696018218994,518.831524848938,0.941624641418457,0.0 -42100,4.557737,1.9427998,,,,,,,,,,,,,, -42200,3.4375007,1.9297872,,,,,,,,,,,,,, -42300,3.193629,1.9092281,,,,,,,,,,,,,, -42400,3.9577806,1.8531349,,,,,,,,,,,,,, -42500,3.5986998,1.9492464,,,,,,,,,,,,,, -42600,3.995644,1.86549,,,,,,,,,,,,,, -42700,5.5321803,1.974063,,,,,,,,,,,,,, -42800,3.8674254,1.8342056,,,,,,,,,,,,,, -42900,3.556351,1.7949116,,,,,,,,,,,,,, -43000,3.510775,1.8481737,,,,,,,,,,,,,, -43100,3.6651464,1.900169,,,,,,,,,,,,,, -43200,4.07443,1.8420857,,,,,,,,,,,,,, -43300,3.6616738,1.8503898,,,,,,,,,,,,,, -43400,4.238622,1.9723387,,,,,,,,,,,,,, -43500,3.5302167,1.9729583,,,,,,,,,,,,,, -43515,,,0.6421595811843872,1.4424054622650146,0.5974400043487549,1.672808289527893,50000.0,0.4844000339508056,2.345468759536743,10000.0,14822.645218849182,15361.96886754036,14822.645218849182,536.7683305740356,0.976036548614502,0.0 -43600,3.414589,1.9139543,,,,,,,,,,,,,, -43700,3.517018,1.8896291,,,,,,,,,,,,,, -43800,3.3497894,1.8842295,,,,,,,,,,,,,, -43900,4.0445695,1.9634969,,,,,,,,,,,,,, -44000,3.6996891,1.9554989,,,,,,,,,,,,,, -44100,3.1630993,1.9352765,,,,,,,,,,,,,, -44200,3.8018146,1.8237207,,,,,,,,,,,,,, -44300,3.8940628,1.8307339,,,,,,,,,,,,,, -44400,3.7455559,1.9603207,,,,,,,,,,,,,, -44500,3.741041,1.9007537,,,,,,,,,,,,,, -44600,3.6095297,1.9074577,,,,,,,,,,,,,, -44700,3.2332304,1.8222746,,,,,,,,,,,,,, -44800,3.3964252,1.815641,,,,,,,,,,,,,, -44900,3.6851723,1.9219662,,,,,,,,,,,,,, -45000,3.4985154,1.9332188,,,,,,,,,,,,,, -45019,,,0.6957908272743225,1.203609585762024,0.6162199974060059,1.5929490327835083,50000.0,0.4869000315666199,2.34109878540039,10000.0,15332.848045110704,15889.959641456604,15332.848045110704,554.4605889320374,1.0155045986175537,0.0 -45100,4.0756097,2.0526195,,,,,,,,,,,,,, -45200,4.099412,1.9327784,,,,,,,,,,,,,, -45300,3.834317,1.9341786,,,,,,,,,,,,,, -45400,3.4612966,1.88215,,,,,,,,,,,,,, -45500,4.4759417,1.9120948,,,,,,,,,,,,,, -45600,3.8404534,1.8603185,,,,,,,,,,,,,, -45700,3.6163018,1.9012098,,,,,,,,,,,,,, -45800,3.6661785,1.9680878,,,,,,,,,,,,,, -45900,4.0257277,1.9953878,,,,,,,,,,,,,, -46000,4.122666,1.8974223,,,,,,,,,,,,,, -46100,4.378319,1.8588108,,,,,,,,,,,,,, -46200,3.5702808,1.95939,,,,,,,,,,,,,, -46300,3.9760833,1.8214593,,,,,,,,,,,,,, -46400,3.7874703,1.8851236,,,,,,,,,,,,,, -46500,3.8615694,1.8693814,,,,,,,,,,,,,, -46523,,,0.6674505472183228,1.308665156364441,0.6080999970436096,1.629995584487915,50000.0,0.4812000095844269,2.37629771232605,10000.0,15842.934381008148,16418.139196634293,15842.934381008148,572.4644169807434,1.0506083965301514,0.0 -46600,3.4523444,1.8695586,,,,,,,,,,,,,, -46700,3.4404864,1.9001149,,,,,,,,,,,,,, -46800,4.7547355,1.9028208,,,,,,,,,,,,,, -46900,3.3429074,1.8877164,,,,,,,,,,,,,, -47000,3.767171,1.9208063,,,,,,,,,,,,,, -47100,3.8917868,1.8306665,,,,,,,,,,,,,, -47200,4.667358,1.951718,,,,,,,,,,,,,, -47300,3.8196404,1.8195794,,,,,,,,,,,,,, -47400,3.460979,1.7517928,,,,,,,,,,,,,, -47500,4.0649376,1.9159319,,,,,,,,,,,,,, -47600,4.1542144,1.8817745,,,,,,,,,,,,,, -47700,3.3979042,1.8684872,,,,,,,,,,,,,, -47800,3.1273415,1.8696277,,,,,,,,,,,,,, -47900,3.5353599,1.861179,,,,,,,,,,,,,, -48000,3.4911726,1.9127821,,,,,,,,,,,,,, -48027,,,0.6508689522743225,1.3953670263290403,0.5989199876785278,1.662032961845398,50000.0,0.4751000106334686,2.403264045715332,10000.0,16352.884350776672,16946.79745745659,16352.884350776672,591.0792412757874,1.088939905166626,0.0 -48100,3.6438172,1.9053801,,,,,,,,,,,,,, -48200,3.643496,1.7847056,,,,,,,,,,,,,, -48300,3.9090242,2.0235667,,,,,,,,,,,,,, -48400,3.5380745,1.8402435,,,,,,,,,,,,,, -48500,3.685601,1.893786,,,,,,,,,,,,,, -48600,4.0769324,1.9096925,,,,,,,,,,,,,, -48700,4.24719,1.884895,,,,,,,,,,,,,, -48800,4.5100284,1.8751065,,,,,,,,,,,,,, -48900,4.2824664,1.8762379,,,,,,,,,,,,,, -49000,3.342417,1.869484,,,,,,,,,,,,,, -49100,3.568794,1.9114872,,,,,,,,,,,,,, -49200,3.5778039,1.8611059,,,,,,,,,,,,,, -49300,3.9431021,1.8927066,,,,,,,,,,,,,, -49400,3.5436337,1.7814236,,,,,,,,,,,,,, -49500,3.446061,1.7304386,,,,,,,,,,,,,, -49531,,,0.6604352593421936,1.361952304840088,0.6140999794006348,1.5894731283187866,50000.0,0.4963000118732452,2.311565637588501,10000.0,16863.04797935486,17474.806631326675,16863.04797935486,608.8275811672211,1.131298542022705,0.0 -49600,3.473786,1.8932106,,,,,,,,,,,,,, -49700,3.342486,1.8447092,,,,,,,,,,,,,, -49800,4.0479155,1.8445954,,,,,,,,,,,,,, -49900,3.7797003,1.9614421,,,,,,,,,,,,,, -50000,3.754056,1.932813,,,,,,,,,,,,,, -50100,3.8183262,1.8927945,,,,,,,,,,,,,, -50200,3.5214374,1.874459,,,,,,,,,,,,,, -50300,3.5074978,1.7821462,,,,,,,,,,,,,, -50400,3.7678428,1.8046519,,,,,,,,,,,,,, -50500,3.3780518,1.8239406,,,,,,,,,,,,,, -50600,3.7659478,1.8610183,,,,,,,,,,,,,, -50700,3.7668734,1.79337,,,,,,,,,,,,,, -50800,3.86015,1.9992046,,,,,,,,,,,,,, -50900,3.2486672,1.8559031,,,,,,,,,,,,,, -51000,3.0635183,1.8765411,,,,,,,,,,,,,, -51036,,,0.6528220772743225,1.377036690711975,0.6045599579811096,1.6280007362365725,50000.0,0.4839000105857849,2.3610148429870605,10000.0,17373.247692346573,18002.78094768524,17373.247692346573,626.5100448131561,1.1692650318145752,0.0 -51100,4.28746,1.9629068,,,,,,,,,,,,,, -51200,4.57391,1.8234288,,,,,,,,,,,,,, -51300,3.666174,1.9644549,,,,,,,,,,,,,, -51400,4.3712616,1.8166355,,,,,,,,,,,,,, -51500,3.3477833,1.7825661,,,,,,,,,,,,,, -51600,3.933163,1.8102008,,,,,,,,,,,,,, -51700,3.0505676,1.9078848,,,,,,,,,,,,,, -51800,3.8755825,1.9191054,,,,,,,,,,,,,, -51900,3.3892171,1.9392258,,,,,,,,,,,,,, -52000,3.5653055,1.7623923,,,,,,,,,,,,,, -52100,3.6124716,1.8307588,,,,,,,,,,,,,, -52200,3.9917746,1.8810868,,,,,,,,,,,,,, -52300,3.2241914,1.89241,,,,,,,,,,,,,, -52400,4.6339636,1.842413,,,,,,,,,,,,,, -52500,3.5919874,1.8030006,,,,,,,,,,,,,, -52540,,,0.6638432741165161,1.3441935777664185,0.6182799935340881,1.5748586654663086,50000.0,0.4944000244140625,2.280700206756592,10000.0,17883.32636666298,18530.83543086052,17883.32636666298,644.3940415382385,1.2060277462005615,0.0 -52600,3.7662795,1.9394165,,,,,,,,,,,,,, -52700,3.5352726,1.913947,,,,,,,,,,,,,, -52800,3.3637261,1.8164498,,,,,,,,,,,,,, -52900,4.1206217,1.9505755,,,,,,,,,,,,,, -53000,3.6697624,1.9051708,,,,,,,,,,,,,, -53100,3.9460893,1.8600601,,,,,,,,,,,,,, -53200,3.3534064,1.9564501,,,,,,,,,,,,,, -53300,3.7255719,1.845951,,,,,,,,,,,,,, -53400,3.8961575,1.8984416,,,,,,,,,,,,,, -53500,3.5608644,1.8628281,,,,,,,,,,,,,, -53600,3.798764,1.9416745,,,,,,,,,,,,,, -53700,3.9492157,1.937415,,,,,,,,,,,,,, -53800,3.8479023,1.842599,,,,,,,,,,,,,, -53900,3.4195924,1.9047649,,,,,,,,,,,,,, -54000,3.888552,1.8233242,,,,,,,,,,,,,, -54045,,,0.6905492544174194,1.2071837186813354,0.604699969291687,1.6392205953598022,50000.0,0.4820000231266022,2.356093168258667,10000.0,18393.53794503212,19058.91395521164,18393.53794503212,662.1686675548553,1.2445552349090576,0.0 -54100,3.8551357,1.9089998,,,,,,,,,,,,,, -54200,3.9386158,1.8093508,,,,,,,,,,,,,, -54300,3.5369425,1.8301148,,,,,,,,,,,,,, -54400,3.6543274,1.9792671,,,,,,,,,,,,,, -54500,3.3611856,1.8425612,,,,,,,,,,,,,, -54600,3.7913868,1.8566356,,,,,,,,,,,,,, -54700,4.04342,1.9165858,,,,,,,,,,,,,, -54800,3.4444525,1.8402606,,,,,,,,,,,,,, -54900,4.4432654,1.9129372,,,,,,,,,,,,,, -55000,3.5438979,1.8437612,,,,,,,,,,,,,, -55100,3.527232,1.9029101,,,,,,,,,,,,,, -55200,4.0385375,1.7739147,,,,,,,,,,,,,, -55300,3.954696,1.7751899,,,,,,,,,,,,,, -55400,3.943297,1.8200686,,,,,,,,,,,,,, -55500,3.8412392,1.9255413,,,,,,,,,,,,,, -55549,,,0.6698620915412903,1.3108255863189695,0.608959972858429,1.6123535633087158,50000.0,0.48580002784729,2.349951982498169,10000.0,18903.47699022293,19586.715607881542,18903.47699022293,679.9338908195496,1.2859973907470703,0.0 -55600,3.7474637,1.826819,,,,,,,,,,,,,, -55700,3.9346077,1.9126298,,,,,,,,,,,,,, -55800,4.1192083,1.865868,,,,,,,,,,,,,, -55900,3.9985197,1.8717916,,,,,,,,,,,,,, -56000,4.0381846,1.8663585,,,,,,,,,,,,,, -56100,4.0214214,1.9347577,,,,,,,,,,,,,, -56200,4.4005737,1.8994684,,,,,,,,,,,,,, -56300,3.5964448,1.8010526,,,,,,,,,,,,,, -56400,3.7623277,1.8198755,,,,,,,,,,,,,, -56500,3.79679,1.85242,,,,,,,,,,,,,, -56600,4.65692,1.848888,,,,,,,,,,,,,, -56700,3.4616184,1.7785538,,,,,,,,,,,,,, -56800,3.9035165,1.840944,,,,,,,,,,,,,, -56900,3.2626843,1.9230164,,,,,,,,,,,,,, -57000,3.7241442,1.9336885,,,,,,,,,,,,,, -57054,,,0.6702606678009033,1.301419734954834,0.6184399724006653,1.5728846788406372,50000.0,0.497700035572052,2.252488136291504,10000.0,19413.60482478141,20114.74164962769,19413.60482478141,697.7373259067535,1.3273625373840332,0.0 -57100,3.5563238,1.7991538,,,,,,,,,,,,,, -57200,3.8666315,1.7631624,,,,,,,,,,,,,, -57300,4.684195,1.8898828,,,,,,,,,,,,,, -57400,3.949804,1.9221662,,,,,,,,,,,,,, -57500,3.830646,1.9238025,,,,,,,,,,,,,, -57600,4.4699783,1.8125256,,,,,,,,,,,,,, -57700,4.163243,1.9049816,,,,,,,,,,,,,, -57800,4.5230947,1.7371721,,,,,,,,,,,,,, -57900,3.8078797,1.9047416,,,,,,,,,,,,,, -58000,4.360791,1.8686919,,,,,,,,,,,,,, -58100,3.6378098,1.9283589,,,,,,,,,,,,,, -58200,3.2839863,1.9680874,,,,,,,,,,,,,, -58300,3.2557766,1.8501282,,,,,,,,,,,,,, -58400,3.4618752,1.8068014,,,,,,,,,,,,,, -58500,3.7363443,1.8555086,,,,,,,,,,,,,, -58558,,,0.6655771732330322,1.3216028213500977,0.6157599687576294,1.5845470428466797,50000.0,0.4946000277996063,2.278593063354492,10000.0,19923.63752818108,20642.72845196724,19923.63752818108,715.5969526767731,1.3676202297210691,0.0 -58600,3.354971,1.8165516,,,,,,,,,,,,,, -58700,3.771425,1.7896748,,,,,,,,,,,,,, -58800,3.847175,1.734155,,,,,,,,,,,,,, -58900,3.6223252,1.8686832,,,,,,,,,,,,,, -59000,3.7179031,1.7766533,,,,,,,,,,,,,, -59100,4.05716,1.7375354,,,,,,,,,,,,,, -59200,4.0068173,2.0111473,,,,,,,,,,,,,, -59300,3.399686,1.7656155,,,,,,,,,,,,,, -59400,3.644207,1.9153099,,,,,,,,,,,,,, -59500,4.227798,1.8499999,,,,,,,,,,,,,, -59600,3.938883,1.8929578,,,,,,,,,,,,,, -59700,4.1684413,1.7668171,,,,,,,,,,,,,, -59800,4.0463357,1.8071688,,,,,,,,,,,,,, -59900,3.971589,1.7980969,,,,,,,,,,,,,, -60000,5.8296375,1.8104502,,,,,,,,,,,,,, -60063,,,0.6641222834587097,1.3311790227890017,0.6173999905586243,1.5765674114227295,50000.0,0.4969000220298767,2.297972679138184,10000.0,20433.673320770264,21170.53536248207,20433.673320770264,733.2713446617126,1.408951997756958,0.0 -60100,4.139057,1.9813855,,,,,,,,,,,,,, -60200,4.0368567,1.8296981,,,,,,,,,,,,,, -60300,5.0108867,1.8396755,,,,,,,,,,,,,, -60400,3.865987,1.7869477,,,,,,,,,,,,,, -60500,3.6185522,1.7031202,,,,,,,,,,,,,, -60600,5.944306,1.8645463,,,,,,,,,,,,,, -60700,4.499605,1.784776,,,,,,,,,,,,,, -60800,3.4694333,1.9008865,,,,,,,,,,,,,, -60900,3.8002958,1.8414992,,,,,,,,,,,,,, -61000,4.6266623,1.8078642,,,,,,,,,,,,,, -61100,3.6482418,1.7226934,,,,,,,,,,,,,, -61200,3.8165617,1.8538598,,,,,,,,,,,,,, -61300,3.9647117,1.849446,,,,,,,,,,,,,, -61400,3.7312777,1.8272266,,,,,,,,,,,,,, -61500,4.0272436,1.9112786,,,,,,,,,,,,,, -61568,,,0.6678690910339355,1.3077174425125122,0.622439980506897,1.5476828813552856,50000.0,0.497700035572052,2.2889938354492188,10000.0,20943.756311655045,21698.42645382881,20943.756311655045,750.9849836826324,1.4479811191558838,0.0 -61600,4.049765,1.8214321,,,,,,,,,,,,,, -61700,3.8725963,1.7603009,,,,,,,,,,,,,, -61800,3.8344083,1.8439498,,,,,,,,,,,,,, -61900,3.5396523,1.7713137,,,,,,,,,,,,,, -62000,4.017493,1.8470702,,,,,,,,,,,,,, -62100,3.7068758,1.8506751,,,,,,,,,,,,,, -62200,3.6398804,1.8146118,,,,,,,,,,,,,, -62300,3.6213236,1.7816802,,,,,,,,,,,,,, -62400,4.5697484,1.88731,,,,,,,,,,,,,, -62500,3.76394,1.8732816,,,,,,,,,,,,,, -62600,4.052282,1.7111914,,,,,,,,,,,,,, -62700,3.7210703,2.0844615,,,,,,,,,,,,,, -62800,3.818941,1.7713223,,,,,,,,,,,,,, -62900,3.5049317,1.7512269,,,,,,,,,,,,,, -63000,3.6387098,1.8981278,,,,,,,,,,,,,, -63074,,,0.7181122303009033,1.0941146612167358,0.6288599967956543,1.528099536895752,50000.0,0.5130000114440918,2.2014949321746826,10000.0,21454.002192497253,22226.784834861755,21454.002192497253,769.0031280517578,1.4872050285339355,0.0 -63100,4.202753,1.8145066,,,,,,,,,,,,,, -63200,3.7506943,1.8025503,,,,,,,,,,,,,, -63300,3.3296223,1.7838787,,,,,,,,,,,,,, -63400,3.4886806,1.8739738,,,,,,,,,,,,,, -63500,3.5508466,1.7087686,,,,,,,,,,,,,, -63600,3.9200442,1.6866903,,,,,,,,,,,,,, -63700,3.7577605,1.7925875,,,,,,,,,,,,,, -63800,4.1035447,1.8773727,,,,,,,,,,,,,, -63900,3.438935,1.7044713,,,,,,,,,,,,,, -64000,3.9223604,1.8325851,,,,,,,,,,,,,, -64100,4.0386095,1.9249289,,,,,,,,,,,,,, -64200,3.4237094,1.7603214,,,,,,,,,,,,,, -64300,3.8155684,1.7542095,,,,,,,,,,,,,, -64400,3.7976754,1.7339053,,,,,,,,,,,,,, -64500,3.923082,1.9420987,,,,,,,,,,,,,, -64579,,,0.6921635866165161,1.2201071977615356,0.6279599666595459,1.5359902381896973,50000.0,0.5029000043869019,2.242102861404419,10000.0,21964.22785615921,22754.9960372448,21964.22785615921,786.8960626125336,1.5244412422180176,0.0 -64600,3.6079385,1.7427918,,,,,,,,,,,,,, -64700,3.725939,1.8310175,,,,,,,,,,,,,, -64800,3.924435,1.7717131,,,,,,,,,,,,,, -64900,3.5873547,1.8920511,,,,,,,,,,,,,, -65000,4.201822,1.7153279,,,,,,,,,,,,,, -65100,3.909051,1.9366591,,,,,,,,,,,,,, -65200,3.764332,1.7911086,,,,,,,,,,,,,, -65300,3.7244601,1.9252238,,,,,,,,,,,,,, -65400,3.8476522,1.7615404,,,,,,,,,,,,,, -65500,3.4882357,1.705147,,,,,,,,,,,,,, -65600,3.8922548,1.894081,,,,,,,,,,,,,, -65700,4.1786094,1.8529425,,,,,,,,,,,,,, -65800,3.8624678,1.7132919,,,,,,,,,,,,,, -65900,3.9040976,1.7884396,,,,,,,,,,,,,, -66000,3.9617305,1.7673628,,,,,,,,,,,,,, -66085,,,0.6760801672935486,1.270575761795044,0.623199999332428,1.5508557558059692,50000.0,0.4992000162601471,2.274237632751465,10000.0,22474.41934657097,23283.18795681,22474.41934657097,804.7957236766815,1.5703482627868652,0.0 -66100,3.781227,1.8036206,,,,,,,,,,,,,, -66200,4.26958,1.7958899,,,,,,,,,,,,,, -66300,3.943312,1.7613884,,,,,,,,,,,,,, -66400,4.3264885,1.888688,,,,,,,,,,,,,, -66500,4.83304,1.7246698,,,,,,,,,,,,,, -66600,3.7485414,1.8167162,,,,,,,,,,,,,, -66700,3.7803771,1.7747893,,,,,,,,,,,,,, -66800,3.5546184,1.7038628,,,,,,,,,,,,,, -66900,3.575219,1.8765221,,,,,,,,,,,,,, -67000,4.103999,1.7354192,,,,,,,,,,,,,, -67100,4.073138,1.8372494,,,,,,,,,,,,,, -67200,4.1045604,1.786528,,,,,,,,,,,,,, -67300,3.919502,1.7950321,,,,,,,,,,,,,, -67400,3.6290457,1.6850638,,,,,,,,,,,,,, -67500,3.6682959,1.728211,,,,,,,,,,,,,, -67590,,,0.675203263759613,1.2853842973709106,0.6229400038719177,1.5472275018692017,50000.0,0.5045000314712524,2.2707667350769043,10000.0,22984.531358480453,23811.22071957588,22984.531358480453,822.6233458518982,1.609081745147705,0.0 -67600,3.7807884,1.8120332,,,,,,,,,,,,,, -67700,3.9344049,1.774742,,,,,,,,,,,,,, -67800,4.5170918,1.9440536,,,,,,,,,,,,,, -67900,3.8652558,1.8931799,,,,,,,,,,,,,, -68000,3.6601918,1.7268946,,,,,,,,,,,,,, -68100,3.8152456,1.8327647,,,,,,,,,,,,,, -68200,3.6609712,1.6444244,,,,,,,,,,,,,, -68300,4.1595116,1.78422,,,,,,,,,,,,,, -68400,4.1385293,1.8351026,,,,,,,,,,,,,, -68500,4.319787,1.7094574,,,,,,,,,,,,,, -68600,4.2356424,1.745285,,,,,,,,,,,,,, -68700,4.248506,1.8244293,,,,,,,,,,,,,, -68800,4.368505,1.7616997,,,,,,,,,,,,,, -68900,4.819145,1.8016316,,,,,,,,,,,,,, -69000,3.8084965,1.8243376,,,,,,,,,,,,,, -69095,,,0.682039201259613,1.2529497146606443,0.6260799765586853,1.544729471206665,50000.0,0.5063000321388245,2.2389981746673584,10000.0,23494.5044093132,24338.96833062172,23494.5044093132,840.2972972393036,1.6531808376312256,0.0 -69100,3.6576905,1.6968066,,,,,,,,,,,,,, -69200,4.1702437,1.639331,,,,,,,,,,,,,, -69300,3.9865344,1.8157269,,,,,,,,,,,,,, -69400,3.7736812,1.8399813,,,,,,,,,,,,,, -69500,4.6335607,1.8243757,,,,,,,,,,,,,, -69600,3.89459,1.7728548,,,,,,,,,,,,,, -69700,3.7816896,1.7152213,,,,,,,,,,,,,, -69800,3.8044026,1.7939473,,,,,,,,,,,,,, -69900,3.6587667,1.7585919,,,,,,,,,,,,,, -70000,4.1697884,1.7616581,,,,,,,,,,,,,, -70100,4.120089,1.7864969,,,,,,,,,,,,,, -70200,5.17399,1.714258,,,,,,,,,,,,,, -70300,3.449455,1.7428267,,,,,,,,,,,,,, -70400,3.5961187,1.8359216,,,,,,,,,,,,,, -70500,4.103145,1.7250762,,,,,,,,,,,,,, -70600,,,0.6823381781578064,1.261600375175476,0.6331599950790405,1.5052260160446167,50000.0,0.5125000476837158,2.193894147872925,10000.0,24004.64289021492,24866.99529266357,24004.64289021492,858.0913376808167,1.6934871673583984,0.0 -70600,4.3042536,1.8036551,,,,,,,,,,,,,, -70700,3.5973768,1.8194585,,,,,,,,,,,,,, -70800,3.5210226,1.7820439,,,,,,,,,,,,,, -70900,3.749658,1.8739438,,,,,,,,,,,,,, -71000,3.8857992,1.7113549,,,,,,,,,,,,,, -71100,3.623407,1.6254098,,,,,,,,,,,,,, -71200,4.4693913,1.7980491,,,,,,,,,,,,,, -71300,3.5259004,1.8628179,,,,,,,,,,,,,, -71400,3.9115791,1.7914333,,,,,,,,,,,,,, -71500,4.5442796,1.8739178,,,,,,,,,,,,,, -71600,3.9841304,1.9032168,,,,,,,,,,,,,, -71700,4.3878922,1.8362992,,,,,,,,,,,,,, -71800,3.905769,1.6655754,,,,,,,,,,,,,, -71900,4.531508,1.8531619,,,,,,,,,,,,,, -72000,3.9244125,1.6752281,,,,,,,,,,,,,, -72100,4.526089,1.812746,,,,,,,,,,,,,, -72105,,,0.7115353941917419,1.123003363609314,0.6309399604797363,1.522621989250183,50000.0,0.5054000020027161,2.2285428047180176,10000.0,24514.733004808422,25394.98640203476,24514.733004808422,875.8923208713531,1.7388732433319092,0.0 -72200,4.5946445,1.7102861,,,,,,,,,,,,,, -72300,4.114622,1.8102903,,,,,,,,,,,,,, -72400,4.3850517,1.8385909,,,,,,,,,,,,,, -72500,3.6645844,1.7900459,,,,,,,,,,,,,, -72600,4.215743,1.6369424,,,,,,,,,,,,,, -72700,4.8988256,1.6664711,,,,,,,,,,,,,, -72800,3.9215803,1.7569818,,,,,,,,,,,,,, -72900,3.5088906,1.7113283,,,,,,,,,,,,,, -73000,3.5423107,1.8053043,,,,,,,,,,,,,, -73100,4.3344817,1.8797885,,,,,,,,,,,,,, -73200,3.5643888,1.7755319,,,,,,,,,,,,,, -73300,3.90922,1.6030815,,,,,,,,,,,,,, -73400,4.188562,1.7365773,,,,,,,,,,,,,, -73500,3.6346815,1.7668164,,,,,,,,,,,,,, -73600,4.340678,1.6567419,,,,,,,,,,,,,, -73610,,,0.7038823366165161,1.1389704942703247,0.6362000107765198,1.4967260360717771,50000.0,0.5118000507354736,2.229922294616699,10000.0,25024.659342050552,25922.62748122216,25024.659342050552,893.5084192752838,1.7819523811340332,0.0 -73700,4.0212026,1.6675097,,,,,,,,,,,,,, -73800,4.0457554,1.6707995,,,,,,,,,,,,,, -73900,4.5022144,1.8853446,,,,,,,,,,,,,, -74000,4.275156,1.6484029,,,,,,,,,,,,,, -74100,3.6788635,1.6025386,,,,,,,,,,,,,, -74200,3.8137603,1.7908151,,,,,,,,,,,,,, -74300,3.7723942,1.8302988,,,,,,,,,,,,,, -74400,3.821423,1.737375,,,,,,,,,,,,,, -74500,4.04442,1.7946341,,,,,,,,,,,,,, -74600,4.415854,1.9245533,,,,,,,,,,,,,, -74700,4.5069137,1.805539,,,,,,,,,,,,,, -74800,4.148097,1.8885691,,,,,,,,,,,,,, -74900,3.8634393,1.7580376,,,,,,,,,,,,,, -75000,4.2475133,1.7775972,,,,,,,,,,,,,, -75100,3.7620208,1.7272232,,,,,,,,,,,,,, -75116,,,0.6950334906578064,1.1846526861190796,0.6358999609947205,1.490337610244751,50000.0,0.5091000199317932,2.216155767440796,10000.0,25534.734695911407,26450.662058591843,25534.734695911407,911.36985373497,1.823401689529419,0.0 -75200,3.929922,1.7814784,,,,,,,,,,,,,, -75300,4.0065627,1.8002372,,,,,,,,,,,,,, -75400,4.229354,1.8080536,,,,,,,,,,,,,, -75500,3.8980625,1.7430066,,,,,,,,,,,,,, -75600,3.7809107,1.7190593,,,,,,,,,,,,,, -75700,4.1726108,1.7450519,,,,,,,,,,,,,, -75800,4.221935,1.7909632,,,,,,,,,,,,,, -75900,4.1234884,1.8532255,,,,,,,,,,,,,, -76000,4.1263604,1.7297461,,,,,,,,,,,,,, -76100,4.2340164,1.8255113,,,,,,,,,,,,,, -76200,3.9277267,1.6806331,,,,,,,,,,,,,, -76300,4.25189,1.6548524,,,,,,,,,,,,,, -76400,3.6963418,1.7647913,,,,,,,,,,,,,, -76500,3.5153103,1.6944968,,,,,,,,,,,,,, -76600,3.775362,1.7674235,,,,,,,,,,,,,, -76621,,,0.6759606003761292,1.2714606523513794,0.6222000122070312,1.5547786951065063,50000.0,0.4903000295162201,2.2797510623931885,10000.0,26044.821749687195,26978.580530405045,26044.821749687195,929.1047916412354,1.8654460906982424,0.0 -76700,3.8852246,1.7802811,,,,,,,,,,,,,, -76800,4.413227,1.7252158,,,,,,,,,,,,,, -76900,3.658199,1.682368,,,,,,,,,,,,,, -77000,3.6346543,1.7682439,,,,,,,,,,,,,, -77100,3.8796866,1.7632613,,,,,,,,,,,,,, -77200,4.641309,1.7234802,,,,,,,,,,,,,, -77300,4.169009,1.7924631,,,,,,,,,,,,,, -77400,3.8836722,1.7994115,,,,,,,,,,,,,, -77500,3.7329435,1.6467073,,,,,,,,,,,,,, -77600,4.420047,1.7483137,,,,,,,,,,,,,, -77700,3.5472198,1.6312677,,,,,,,,,,,,,, -77800,4.038581,1.7922292,,,,,,,,,,,,,, -77900,4.2659216,1.8295839,,,,,,,,,,,,,, -78000,4.0400953,1.6019967,,,,,,,,,,,,,, -78100,3.931924,1.5870391,,,,,,,,,,,,,, -78127,,,0.696687638759613,1.187483787536621,0.6402400135993958,1.469867467880249,50000.0,0.5200999975204468,2.1450984477996826,10000.0,26555.042127132416,27506.72307229042,26555.042127132416,946.9249217510225,1.913517951965332,0.0 -78200,3.8724442,1.8206787,,,,,,,,,,,,,, -78300,4.120112,1.6486746,,,,,,,,,,,,,, -78400,4.301408,1.7005281,,,,,,,,,,,,,, -78500,4.401671,1.7858287,,,,,,,,,,,,,, -78600,4.312145,1.7672782,,,,,,,,,,,,,, -78700,3.8951514,1.6233215,,,,,,,,,,,,,, -78800,4.579044,1.7022319,,,,,,,,,,,,,, -78900,3.5350187,1.6834699,,,,,,,,,,,,,, -79000,4.193444,1.8454744,,,,,,,,,,,,,, -79100,4.2679396,1.7279822,,,,,,,,,,,,,, -79200,4.101369,1.6969993,,,,,,,,,,,,,, -79300,4.371409,1.6503253,,,,,,,,,,,,,, -79400,3.7326028,1.6929992,,,,,,,,,,,,,, -79500,4.027888,1.7597307,,,,,,,,,,,,,, -79600,4.08063,1.7323887,,,,,,,,,,,,,, -79632,,,0.6890544891357422,1.2304526567459106,0.635919988155365,1.4974093437194824,50000.0,0.5033000111579895,2.213355302810669,10000.0,27065.05168557167,28034.80432486534,27065.05168557167,964.8990716934204,1.956645011901856,0.0 -79700,4.794127,1.8042475,,,,,,,,,,,,,, -79800,3.9548562,1.6735513,,,,,,,,,,,,,, -79900,4.2842846,1.639157,,,,,,,,,,,,,, -80000,4.3599696,1.5723767,,,,,,,,,,,,,, -80100,4.6109624,1.7750103,,,,,,,,,,,,,, -80200,4.239351,1.743694,,,,,,,,,,,,,, -80300,4.0072513,1.6655288,,,,,,,,,,,,,, -80400,4.018439,1.7028135,,,,,,,,,,,,,, -80500,3.7550063,1.7345737,,,,,,,,,,,,,, -80600,4.1027794,1.8192573,,,,,,,,,,,,,, -80700,4.0705175,1.7785746,,,,,,,,,,,,,, -80800,4.148474,1.6858592,,,,,,,,,,,,,, -80900,4.546386,1.6427474,,,,,,,,,,,,,, -81000,3.7916017,1.6743368,,,,,,,,,,,,,, -81100,4.8704057,1.6092002,,,,,,,,,,,,,, -81138,,,0.7074896097183228,1.138108491897583,0.6323999762535095,1.5167434215545654,50000.0,0.5028000473976135,2.267728090286255,10000.0,27575.18598818779,28562.895799398422,27575.18598818779,982.7576727867126,2.000518560409546,0.0 -81200,4.3790526,1.5534458,,,,,,,,,,,,,, -81300,3.618181,1.5968596,,,,,,,,,,,,,, -81400,3.9580812,1.7263942,,,,,,,,,,,,,, -81500,4.3272495,1.7332793,,,,,,,,,,,,,, -81600,4.1143413,1.702277,,,,,,,,,,,,,, -81700,4.423195,1.7533643,,,,,,,,,,,,,, -81800,3.8295937,1.7134116,,,,,,,,,,,,,, -81900,4.6217957,1.6921442,,,,,,,,,,,,,, -82000,4.288305,1.6861526,,,,,,,,,,,,,, -82100,4.9432063,1.6830401,,,,,,,,,,,,,, -82200,4.5436964,1.721505,,,,,,,,,,,,,, -82300,3.6691349,1.6991854,,,,,,,,,,,,,, -82400,4.3992715,1.6868455,,,,,,,,,,,,,, -82500,4.6599298,1.6778016,,,,,,,,,,,,,, -82600,4.8121324,1.6422105,,,,,,,,,,,,,, -82643,,,0.7142258882522583,1.0972977876663208,0.6411199569702148,1.474290132522583,50000.0,0.5162000060081482,2.1949095726013184,10000.0,28085.256512880325,29090.73493361473,28085.256512880325,1000.4262602329254,2.045931100845337,0.0 -82700,4.3008122,1.7884438,,,,,,,,,,,,,, -82800,3.7822182,1.6585724,,,,,,,,,,,,,, -82900,4.2278085,1.7486148,,,,,,,,,,,,,, -83000,5.03697,1.7053303,,,,,,,,,,,,,, -83100,4.3518085,1.6624092,,,,,,,,,,,,,, -83200,4.911066,1.699764,,,,,,,,,,,,,, -83300,3.700288,1.6606588,,,,,,,,,,,,,, -83400,4.8110623,1.7634109,,,,,,,,,,,,,, -83500,4.370795,1.7030988,,,,,,,,,,,,,, -83600,4.0184875,1.7664707,,,,,,,,,,,,,, -83700,4.2858357,1.7036173,,,,,,,,,,,,,, -83800,4.9730196,1.6929591,,,,,,,,,,,,,, -83900,4.0944304,1.7191125,,,,,,,,,,,,,, -84000,4.3400364,1.7297393,,,,,,,,,,,,,, -84100,3.5150342,1.7400213,,,,,,,,,,,,,, -84148,,,0.7117147445678711,1.1199214458465576,0.6462799906730652,1.4425675868988037,50000.0,0.5175999999046326,2.1688649654388428,10000.0,28595.185331583023,29618.681260347366,28595.185331583023,1018.34730553627,2.0879697799682617,0.0 -84200,3.854864,1.6310354,,,,,,,,,,,,,, -84300,4.45085,1.6826931,,,,,,,,,,,,,, -84400,5.0399976,1.6793945,,,,,,,,,,,,,, -84500,3.9682057,1.7005062,,,,,,,,,,,,,, -84600,3.8926091,1.7873849,,,,,,,,,,,,,, -84700,4.1763225,1.6873779,,,,,,,,,,,,,, -84800,4.62752,1.653133,,,,,,,,,,,,,, -84900,4.1221952,1.6368982,,,,,,,,,,,,,, -85000,4.197249,1.6744587,,,,,,,,,,,,,, -85100,4.311009,1.6677085,,,,,,,,,,,,,, -85200,3.955552,1.5796906,,,,,,,,,,,,,, -85300,3.5399292,1.7547586,,,,,,,,,,,,,, -85400,4.403648,1.6277244,,,,,,,,,,,,,, -85500,4.388246,1.6997868,,,,,,,,,,,,,, -85600,3.7527788,1.6691406,,,,,,,,,,,,,, -85654,,,0.7042809128761292,1.16113018989563,0.6452400088310242,1.4567824602127075,50000.0,0.5216000080108643,2.153033971786499,10000.0,29105.374361276627,30147.20574998856,29105.374361276627,1036.582043170929,2.133475542068481,0.0 -85700,4.4593253,1.6991366,,,,,,,,,,,,,, -85800,3.6807697,1.6930742,,,,,,,,,,,,,, -85900,4.133704,1.7300427,,,,,,,,,,,,,, -86000,3.7873807,1.667981,,,,,,,,,,,,,, -86100,4.007974,1.6733954,,,,,,,,,,,,,, -86200,5.2018514,1.8252312,,,,,,,,,,,,,, -86300,4.427922,1.6624444,,,,,,,,,,,,,, -86400,4.3889194,1.8192015,,,,,,,,,,,,,, -86500,4.777964,1.7265896,,,,,,,,,,,,,, -86600,4.256991,1.7426153,,,,,,,,,,,,,, -86700,4.190359,1.7885065,,,,,,,,,,,,,, -86800,4.2608204,1.7540224,,,,,,,,,,,,,, -86900,4.7663603,1.6066875,,,,,,,,,,,,,, -87000,4.739272,1.6602468,,,,,,,,,,,,,, -87100,5.076337,1.726305,,,,,,,,,,,,,, -87160,,,0.7081273794174194,1.1381235122680664,0.6510800123214722,1.4280773401260376,50000.0,0.5284000039100647,2.1416590213775635,10000.0,29615.305659532547,30675.376630306244,29615.305659532547,1054.7166481018066,2.183864116668701,0.0 -87200,4.0117536,1.6539782,,,,,,,,,,,,,, -87300,4.4229074,1.6281416,,,,,,,,,,,,,, -87400,4.061122,1.7024144,,,,,,,,,,,,,, -87500,4.273985,1.725284,,,,,,,,,,,,,, -87600,4.1844983,1.667146,,,,,,,,,,,,,, -87700,4.1589513,1.6855087,,,,,,,,,,,,,, -87800,4.5667505,1.698363,,,,,,,,,,,,,, -87900,4.6490474,1.7059546,,,,,,,,,,,,,, -88000,4.077484,1.7759417,,,,,,,,,,,,,, -88100,3.9314566,1.6918111,,,,,,,,,,,,,, -88200,3.9433997,1.7557194,,,,,,,,,,,,,, -88300,5.0640345,1.5801477,,,,,,,,,,,,,, -88400,4.405108,1.715067,,,,,,,,,,,,,, -88500,4.5518985,1.7866366,,,,,,,,,,,,,, -88600,4.0967293,1.5969009,,,,,,,,,,,,,, -88662,,,0.706074595451355,1.148975849151611,0.6509799957275391,1.4250391721725464,50000.0,0.520300030708313,2.128875970840454,10000.0,30124.24169325829,31203.1488199234,30124.24169325829,1072.3675389289856,3.314368963241577,0.0 -88700,3.8245573,1.5854933,,,,,,,,,,,,,, -88800,4.2651844,1.763152,,,,,,,,,,,,,, -88900,4.701159,1.6432292,,,,,,,,,,,,,, -89000,4.645038,1.6767108,,,,,,,,,,,,,, -89100,4.51324,1.607679,,,,,,,,,,,,,, -89200,4.323447,1.6925327,,,,,,,,,,,,,, -89300,4.2099447,1.6704404,,,,,,,,,,,,,, -89400,4.3790994,1.7801092,,,,,,,,,,,,,, -89500,4.073634,1.7249196,,,,,,,,,,,,,, -89600,4.265574,1.7711349,,,,,,,,,,,,,, -89700,4.0913587,1.5833267,,,,,,,,,,,,,, -89800,4.9540834,1.70277,,,,,,,,,,,,,, -89900,4.98392,1.7194989,,,,,,,,,,,,,, -90000,4.294571,1.7766532,,,,,,,,,,,,,, -90100,4.2705555,1.664431,,,,,,,,,,,,,, -90168,,,0.7113161683082581,1.1303361654281616,0.6467599868774414,1.444667100906372,50000.0,0.5240000486373901,2.1700925827026367,10000.0,30634.32846808433,31731.047548294067,30634.32846808433,1090.077528476715,3.361715793609619,0.0 -90200,4.0412116,1.661475,,,,,,,,,,,,,, -90300,4.2077193,1.7249547,,,,,,,,,,,,,, -90400,4.284868,1.6134021,,,,,,,,,,,,,, -90500,4.8742247,1.6360133,,,,,,,,,,,,,, -90600,4.263751,1.6349026,,,,,,,,,,,,,, -90700,4.372862,1.7267957,,,,,,,,,,,,,, -90800,4.3041377,1.5965624,,,,,,,,,,,,,, -90900,4.8605275,1.6202812,,,,,,,,,,,,,, -91000,4.414419,1.6243271,,,,,,,,,,,,,, -91100,5.4554186,1.6482003,,,,,,,,,,,,,, -91200,4.892167,1.6397197,,,,,,,,,,,,,, -91300,3.915668,1.6750972,,,,,,,,,,,,,, -91400,4.205492,1.6611296,,,,,,,,,,,,,, -91500,4.1153426,1.5831587,,,,,,,,,,,,,, -91600,4.0848055,1.695955,,,,,,,,,,,,,, -91673,,,0.7281768321990967,1.0502692461013794,0.6510199904441833,1.4127343893051147,50000.0,0.5311000347137451,2.098185062408448,10000.0,31144.40614700317,32258.84111380577,31144.40614700317,1107.6928596496582,3.4071123600006104,0.0 -91700,5.1332297,1.7504389,,,,,,,,,,,,,, -91800,6.0426235,1.7529807,,,,,,,,,,,,,, -91900,4.963945,1.6810472,,,,,,,,,,,,,, -92000,3.8795877,1.6852151,,,,,,,,,,,,,, -92100,4.2441845,1.6584382,,,,,,,,,,,,,, -92200,4.0098834,1.5786988,,,,,,,,,,,,,, -92300,4.723,1.6247693,,,,,,,,,,,,,, -92400,3.9904807,1.6470054,,,,,,,,,,,,,, -92500,4.705143,1.6338278,,,,,,,,,,,,,, -92600,4.421271,1.532951,,,,,,,,,,,,,, -92700,4.624884,1.6285323,,,,,,,,,,,,,, -92800,5.411213,1.5740364,,,,,,,,,,,,,, -92900,4.4627266,1.7347128,,,,,,,,,,,,,, -93000,4.3055325,1.7060661,,,,,,,,,,,,,, -93100,4.1534834,1.5638694,,,,,,,,,,,,,, -93179,,,0.7147042155265808,1.1151316165924072,0.6510599851608276,1.43448007106781,50000.0,0.5208000540733337,2.1194303035736084,10000.0,31654.537391901016,32786.896673202515,31654.537391901016,1125.5177392959597,3.4525933265686035,0.0 -93200,4.3146634,1.673984,,,,,,,,,,,,,, -93300,5.096919,1.5589073,,,,,,,,,,,,,, -93400,4.3375845,1.7813679,,,,,,,,,,,,,, -93500,4.1340404,1.5863982,,,,,,,,,,,,,, -93600,4.863106,1.6176351,,,,,,,,,,,,,, -93700,4.6778007,1.6100417,,,,,,,,,,,,,, -93800,4.558878,1.6383973,,,,,,,,,,,,,, -93900,5.060881,1.5384523,,,,,,,,,,,,,, -94000,4.297629,1.6891067,,,,,,,,,,,,,, -94100,3.8790944,1.6736919,,,,,,,,,,,,,, -94200,5.3404994,1.6413774,,,,,,,,,,,,,, -94300,4.0799446,1.5158452,,,,,,,,,,,,,, -94400,4.2107058,1.5410817,,,,,,,,,,,,,, -94500,4.664633,1.6162493,,,,,,,,,,,,,, -94600,4.423232,1.6705626,,,,,,,,,,,,,, -94684,,,0.7174146771430969,1.0910450220108032,0.6616799831390381,1.3794057369232178,50000.0,0.5320000052452087,2.085729122161865,10000.0,32164.47088861465,33314.628957271576,32164.47088861465,1143.2149093151093,3.500083923339844,0.0 -94700,4.194376,1.547493,,,,,,,,,,,,,, -94800,4.536416,1.7151709,,,,,,,,,,,,,, -94900,4.001586,1.5568231,,,,,,,,,,,,,, -95000,4.2955804,1.7067536,,,,,,,,,,,,,, -95100,4.060018,1.6525438,,,,,,,,,,,,,, -95200,4.759288,1.5490875,,,,,,,,,,,,,, -95300,4.3453455,1.5931545,,,,,,,,,,,,,, -95400,4.903692,1.6244795,,,,,,,,,,,,,, -95500,5.021834,1.5520186,,,,,,,,,,,,,, -95600,4.5276923,1.5958824,,,,,,,,,,,,,, -95700,4.38609,1.6240307,,,,,,,,,,,,,, -95800,4.571598,1.6842892,,,,,,,,,,,,,, -95900,3.8427794,1.5706043,,,,,,,,,,,,,, -96000,4.591556,1.7687395,,,,,,,,,,,,,, -96100,4.7345166,1.8372848,,,,,,,,,,,,,, -96190,,,0.7141262888908386,1.1053730249404907,0.6563599705696106,1.3915393352508545,50000.0,0.530500054359436,2.102585792541504,10000.0,32674.50990009308,33842.723615169525,32674.50990009308,1161.172973394394,3.5434374809265137,0.0 -96200,4.5152144,1.705254,,,,,,,,,,,,,, -96300,4.705754,1.6882815,,,,,,,,,,,,,, -96400,4.9667187,1.5862439,,,,,,,,,,,,,, -96500,4.6403327,1.6407517,,,,,,,,,,,,,, -96600,4.7628226,1.7529461,,,,,,,,,,,,,, -96700,4.961684,1.6646264,,,,,,,,,,,,,, -96800,4.6228046,1.5576286,,,,,,,,,,,,,, -96900,4.7568684,1.6201346,,,,,,,,,,,,,, -97000,3.9667835,1.6154959,,,,,,,,,,,,,, -97100,4.94637,1.6275724,,,,,,,,,,,,,, -97200,4.527281,1.5694898,,,,,,,,,,,,,, -97300,4.6452837,1.6054575,,,,,,,,,,,,,, -97400,4.7378116,1.5972805,,,,,,,,,,,,,, -97500,4.3658977,1.451036,,,,,,,,,,,,,, -97600,4.90925,1.7213478,,,,,,,,,,,,,, -97696,,,0.7200254797935486,1.0727653503417969,0.6601399779319763,1.3852746486663818,50000.0,0.5344000458717346,2.091149091720581,10000.0,33184.6513376236,34370.78517818451,33184.6513376236,1178.9914045333862,3.5905425548553467,0.0 -97700,5.4737997,1.6105485,,,,,,,,,,,,,, -97800,4.799472,1.6899812,,,,,,,,,,,,,, -97900,4.6975713,1.5894475,,,,,,,,,,,,,, -98000,4.6068945,1.6985304,,,,,,,,,,,,,, -98100,5.10935,1.6273913,,,,,,,,,,,,,, -98200,4.13677,1.701131,,,,,,,,,,,,,, -98300,4.7007136,1.5807037,,,,,,,,,,,,,, -98400,4.5441184,1.4583164,,,,,,,,,,,,,, -98500,4.357648,1.7119777,,,,,,,,,,,,,, -98600,5.268182,1.676958,,,,,,,,,,,,,, -98700,4.2924147,1.6033218,,,,,,,,,,,,,, -98800,5.0735426,1.6852119,,,,,,,,,,,,,, -98900,4.5862584,1.702525,,,,,,,,,,,,,, -99000,4.3060136,1.449323,,,,,,,,,,,,,, -99100,4.2816052,1.4955796,,,,,,,,,,,,,, -99200,4.3754597,1.5660802,,,,,,,,,,,,,, -99202,,,0.7228953838348389,1.072148680686951,0.6584399938583374,1.380805730819702,50000.0,0.5306000113487244,2.0861809253692627,10000.0,33694.69447398186,34898.85359764099,33694.69447398186,1196.9177355766296,3.6347944736480713,0.0 -99300,5.327344,1.7566061,,,,,,,,,,,,,, -99400,4.762598,1.6616046,,,,,,,,,,,,,, -99500,4.8201237,1.5935265,,,,,,,,,,,,,, -99600,4.3314214,1.6057869,,,,,,,,,,,,,, -99700,4.8197293,1.6632776,,,,,,,,,,,,,, -99800,4.034734,1.5679357,,,,,,,,,,,,,, -99900,4.3025727,1.5388443,,,,,,,,,,,,,, -100000,4.3704967,1.6202049,,,,,,,,,,,,,, -100100,4.812176,1.5349028,,,,,,,,,,,,,, -100200,4.5817094,1.5021737,,,,,,,,,,,,,, -100300,4.661375,1.6030014,,,,,,,,,,,,,, -100400,4.6562777,1.5596468,,,,,,,,,,,,,, -100500,4.3693657,1.6153224,,,,,,,,,,,,,, -100600,5.132881,1.5447377,,,,,,,,,,,,,, -100700,4.9149194,1.5579886,,,,,,,,,,,,,, -100707,,,0.7236925959587097,1.0682148933410645,0.6428999900817871,1.46091890335083,50000.0,0.5213000178337097,2.180105209350586,10000.0,34204.65418744087,35426.71285367012,34204.65418744087,1214.7148866653442,3.682264804840088,0.0 -100800,5.0146008,1.6121656,,,,,,,,,,,,,, -100900,4.2667475,1.6595721,,,,,,,,,,,,,, -101000,4.619228,1.6144264,,,,,,,,,,,,,, -101100,4.881141,1.6340302,,,,,,,,,,,,,, -101200,5.6695595,1.5164286,,,,,,,,,,,,,, -101300,4.885592,1.5385522,,,,,,,,,,,,,, -101400,5.0153823,1.62643,,,,,,,,,,,,,, -101500,4.3204184,1.5446692,,,,,,,,,,,,,, -101600,4.352453,1.5311389,,,,,,,,,,,,,, -101700,4.4022145,1.6721082,,,,,,,,,,,,,, -101800,4.8538084,1.5729209,,,,,,,,,,,,,, -101900,5.130869,1.701681,,,,,,,,,,,,,, -102000,4.727898,1.646484,,,,,,,,,,,,,, -102100,4.2771316,1.4481984,,,,,,,,,,,,,, -102200,4.722918,1.6171248,,,,,,,,,,,,,, -102213,,,0.7354512214660645,1.0134378671646118,0.6638399958610535,1.3594108819961548,50000.0,0.5330000519752502,2.092923641204834,10000.0,34714.8821105957,35954.948600530624,34714.8821105957,1232.6231932640076,3.7265255451202393,0.0 -102300,4.767277,1.6029785,,,,,,,,,,,,,, -102400,4.1878543,1.4456981,,,,,,,,,,,,,, -102500,5.0423884,1.5119838,,,,,,,,,,,,,, -102600,4.7617583,1.4517668,,,,,,,,,,,,,, -102700,4.281914,1.4484153,,,,,,,,,,,,,, -102800,4.419691,1.6514809,,,,,,,,,,,,,, -102900,4.779642,1.6457937,,,,,,,,,,,,,, -103000,4.274019,1.47072,,,,,,,,,,,,,, -103100,5.2274103,1.6028007,,,,,,,,,,,,,, -103200,4.3931403,1.540629,,,,,,,,,,,,,, -103300,4.7951236,1.5278658,,,,,,,,,,,,,, -103400,4.4577413,1.443533,,,,,,,,,,,,,, -103500,4.6703353,1.6720911,,,,,,,,,,,,,, -103600,4.259428,1.5985383,,,,,,,,,,,,,, -103700,4.7014356,1.6504829,,,,,,,,,,,,,, -103718,,,0.735371470451355,1.017613410949707,0.6706799864768982,1.3379477262496948,50000.0,0.5445000529289246,2.0400118827819824,10000.0,35225.05365109444,36482.834594249725,35225.05365109444,1250.2376444339752,3.770775318145752,0.0 -103800,4.330952,1.4937084,,,,,,,,,,,,,, -103900,4.9847965,1.5633731,,,,,,,,,,,,,, -104000,4.5234714,1.6182914,,,,,,,,,,,,,, -104100,4.4003015,1.6158566,,,,,,,,,,,,,, -104200,4.0415454,1.5170771,,,,,,,,,,,,,, -104300,4.617198,1.5213859,,,,,,,,,,,,,, -104400,4.987395,1.6568632,,,,,,,,,,,,,, -104500,5.088245,1.636898,,,,,,,,,,,,,, -104600,5.4956293,1.488887,,,,,,,,,,,,,, -104700,5.549114,1.6385231,,,,,,,,,,,,,, -104800,5.210948,1.5777805,,,,,,,,,,,,,, -104900,4.2893343,1.4984417,,,,,,,,,,,,,, -105000,5.1148643,1.5591171,,,,,,,,,,,,,, -105100,5.0674424,1.5994232,,,,,,,,,,,,,, -105200,5.1368523,1.5194565,,,,,,,,,,,,,, -105224,,,0.7281967401504517,1.0495648384094238,0.6665599942207336,1.3577643632888794,50000.0,0.5378000140190125,2.079288482666016,10000.0,35735.07813715935,37010.63081288338,35735.07813715935,1267.9070928096771,3.817821502685547,0.0 -105300,4.8030357,1.6781301,,,,,,,,,,,,,, -105400,4.9670286,1.5434506,,,,,,,,,,,,,, -105500,4.1735005,1.4358573,,,,,,,,,,,,,, -105600,4.685895,1.5469875,,,,,,,,,,,,,, -105700,4.92488,1.5368162,,,,,,,,,,,,,, -105800,4.285521,1.5623108,,,,,,,,,,,,,, -105900,5.0693808,1.6111443,,,,,,,,,,,,,, -106000,4.859341,1.5158775,,,,,,,,,,,,,, -106100,4.6431446,1.4759741,,,,,,,,,,,,,, -106200,5.09603,1.5523497,,,,,,,,,,,,,, -106300,4.735018,1.6450081,,,,,,,,,,,,,, -106400,5.043698,1.5491273,,,,,,,,,,,,,, -106500,5.309418,1.4897616,,,,,,,,,,,,,, -106600,4.583704,1.630744,,,,,,,,,,,,,, -106700,5.449844,1.5323572,,,,,,,,,,,,,, -106730,,,0.7329201102256775,1.0269001722335815,0.6684399843215942,1.340019464492798,50000.0,0.5428000092506409,2.031836748123169,10000.0,36245.17560458183,37538.59745979309,36245.17560458183,1285.6735136508942,3.866531372070313,0.0 -106800,4.5405984,1.568959,,,,,,,,,,,,,, -106900,4.8925004,1.5356624,,,,,,,,,,,,,, -107000,5.476916,1.6405132,,,,,,,,,,,,,, -107100,4.773144,1.4802463,,,,,,,,,,,,,, -107200,4.763234,1.593052,,,,,,,,,,,,,, -107300,4.8716316,1.468905,,,,,,,,,,,,,, -107400,4.276195,1.4882305,,,,,,,,,,,,,, -107500,4.402693,1.5425082,,,,,,,,,,,,,, -107600,4.912902,1.5860599,,,,,,,,,,,,,, -107700,4.7593446,1.5616388,,,,,,,,,,,,,, -107800,5.118365,1.6352849,,,,,,,,,,,,,, -107900,4.5339646,1.4195871,,,,,,,,,,,,,, -108000,5.0888247,1.6806338,,,,,,,,,,,,,, -108100,4.597148,1.5310043,,,,,,,,,,,,,, -108200,4.817286,1.5536783,,,,,,,,,,,,,, -108236,,,0.7391581535339355,1.0007308721542358,0.6759799718856812,1.3103091716766355,50000.0,0.5467000007629395,2.014052152633667,10000.0,36755.24262714386,38066.58639526367,36755.24262714386,1303.4949452877045,3.911306619644165,0.0 -108300,5.226675,1.6011934,,,,,,,,,,,,,, -108400,4.7423,1.56252,,,,,,,,,,,,,, -108500,5.0322785,1.5453787,,,,,,,,,,,,,, -108600,4.7390594,1.4268591,,,,,,,,,,,,,, -108700,4.8190885,1.5631386,,,,,,,,,,,,,, -108800,5.2359924,1.6187935,,,,,,,,,,,,,, -108900,5.2673087,1.4861076,,,,,,,,,,,,,, -109000,4.914627,1.4848309,,,,,,,,,,,,,, -109100,5.109213,1.5758405,,,,,,,,,,,,,, -109200,4.385089,1.5393066,,,,,,,,,,,,,, -109300,4.9732165,1.4757187,,,,,,,,,,,,,, -109400,6.1452255,1.5389255,,,,,,,,,,,,,, -109500,5.2364063,1.5299639,,,,,,,,,,,,,, -109600,4.7722707,1.4822476,,,,,,,,,,,,,, -109700,4.903363,1.5844958,,,,,,,,,,,,,, -109741,,,0.7607222199440002,0.9054629802703856,0.6724399924278259,1.3212249279022217,50000.0,0.5461000204086304,2.040199279785156,10000.0,37265.14801168442,38594.18107557297,37265.14801168442,1321.0843846797943,3.95632004737854,0.0 -109800,4.910785,1.5780531,,,,,,,,,,,,,, -109900,4.4749,1.4592402,,,,,,,,,,,,,, -110000,5.5872145,1.5144,,,,,,,,,,,,,, -110100,4.6745496,1.444447,,,,,,,,,,,,,, -110200,4.8274302,1.4369203,,,,,,,,,,,,,, -110300,5.373266,1.494434,,,,,,,,,,,,,, -110400,4.847868,1.5836327,,,,,,,,,,,,,, -110500,4.7445335,1.5184475,,,,,,,,,,,,,, -110600,4.8294506,1.6147108,,,,,,,,,,,,,, -110700,5.2173343,1.5305506,,,,,,,,,,,,,, -110800,5.03727,1.5405549,,,,,,,,,,,,,, -110900,4.744043,1.576366,,,,,,,,,,,,,, -111000,4.8123674,1.4530191,,,,,,,,,,,,,, -111100,5.0164013,1.5357608,,,,,,,,,,,,,, -111200,5.5352845,1.6660165,,,,,,,,,,,,,, -111246,,,0.7552016973495483,0.9197171330451964,0.6800199747085571,1.300967574119568,50000.0,0.55840003490448,1.99822998046875,10000.0,37775.20650601387,39122.22860813141,37775.20650601387,1338.969208240509,4.006459951400757,0.0 -111300,4.847781,1.5152793,,,,,,,,,,,,,, -111400,5.0311904,1.3796015,,,,,,,,,,,,,, -111500,4.5822916,1.4423777,,,,,,,,,,,,,, -111600,5.1348596,1.5128281,,,,,,,,,,,,,, -111700,5.0180254,1.4924918,,,,,,,,,,,,,, -111800,4.9811087,1.4631101,,,,,,,,,,,,,, -111900,5.011725,1.5053768,,,,,,,,,,,,,, -112000,4.807519,1.5920212,,,,,,,,,,,,,, -112100,5.187317,1.587752,,,,,,,,,,,,,, -112200,5.274477,1.3320085,,,,,,,,,,,,,, -112300,5.23182,1.4681455,,,,,,,,,,,,,, -112400,4.699969,1.5621635,,,,,,,,,,,,,, -112500,5.4585714,1.5455017,,,,,,,,,,,,,, -112600,5.7304826,1.4810188,,,,,,,,,,,,,, -112700,4.5859437,1.4330356,,,,,,,,,,,,,, -112752,,,0.752949595451355,0.93806129693985,0.6771599650382996,1.3039065599441528,50000.0,0.5509999990463257,2.019953727722168,10000.0,38285.42838454247,39650.31968307495,38285.42838454247,1356.734162569046,4.056137800216675,0.0 -112800,5.298063,1.5545217,,,,,,,,,,,,,, -112900,4.9716535,1.5719248,,,,,,,,,,,,,, -113000,5.1634164,1.5277872,,,,,,,,,,,,,, -113100,4.8492203,1.5613747,,,,,,,,,,,,,, -113200,5.126415,1.4156592,,,,,,,,,,,,,, -113300,4.845466,1.4081,,,,,,,,,,,,,, -113400,5.025466,1.512826,,,,,,,,,,,,,, -113500,4.970951,1.6221389,,,,,,,,,,,,,, -113600,4.744605,1.3945597,,,,,,,,,,,,,, -113700,4.9260373,1.515425,,,,,,,,,,,,,, -113800,5.1317325,1.4164085,,,,,,,,,,,,,, -113900,4.993083,1.492928,,,,,,,,,,,,,, -114000,5.0764723,1.4012969,,,,,,,,,,,,,, -114100,5.660285,1.4560153,,,,,,,,,,,,,, -114200,6.0405345,1.6244274,,,,,,,,,,,,,, -114258,,,0.7528300285339355,0.94208025932312,0.6818199753761292,1.2975364923477173,50000.0,0.5581000447273254,1.993594765663147,10000.0,38795.5488409996,40178.211097717285,38795.5488409996,1374.3892624378204,4.116119623184204,0.0 -114300,5.5667315,1.4316242,,,,,,,,,,,,,, -114400,4.7709017,1.4813652,,,,,,,,,,,,,, -114500,5.091893,1.4900155,,,,,,,,,,,,,, -114600,4.714279,1.5111849,,,,,,,,,,,,,, -114700,5.1710124,1.5689311,,,,,,,,,,,,,, -114800,5.4496055,1.5683892,,,,,,,,,,,,,, -114900,5.223009,1.4475256,,,,,,,,,,,,,, -115000,5.229155,1.4753635,,,,,,,,,,,,,, -115100,4.69613,1.400086,,,,,,,,,,,,,, -115200,4.826831,1.395956,,,,,,,,,,,,,, -115300,5.2393007,1.477,,,,,,,,,,,,,, -115400,5.1551685,1.4427507,,,,,,,,,,,,,, -115500,5.316657,1.4405135,,,,,,,,,,,,,, -115600,4.9954247,1.4089566,,,,,,,,,,,,,, -115700,5.9263105,1.5185597,,,,,,,,,,,,,, -115763,,,0.74418044090271,0.970153272151947,0.678059995174408,1.2999392747879028,50000.0,0.5466000437736511,2.012835025787353,10000.0,39305.58757019043,40706.71689796448,39305.58757019043,1392.7512967586515,4.165997743606567,0.0 -115800,4.717543,1.381165,,,,,,,,,,,,,, -115900,4.9208527,1.421915,,,,,,,,,,,,,, -116000,5.053949,1.4522407,,,,,,,,,,,,,, -116100,4.817529,1.5661623,,,,,,,,,,,,,, -116200,5.040992,1.4621537,,,,,,,,,,,,,, -116300,5.7342153,1.5702524,,,,,,,,,,,,,, -116400,5.3225865,1.4141222,,,,,,,,,,,,,, -116500,5.410504,1.6038978,,,,,,,,,,,,,, -116600,4.9903727,1.5808716,,,,,,,,,,,,,, -116700,5.09708,1.4961455,,,,,,,,,,,,,, -116800,5.736003,1.404329,,,,,,,,,,,,,, -116900,5.1388702,1.4686953,,,,,,,,,,,,,, -117000,5.6573553,1.4966125,,,,,,,,,,,,,, -117100,5.7100496,1.4910647,,,,,,,,,,,,,, -117200,4.6621904,1.4510062,,,,,,,,,,,,,, -117269,,,0.7614795565605164,0.9109672904014589,0.6902799606323242,1.2496405839920044,50000.0,0.5684000253677368,1.921729564666748,10000.0,39815.64774036408,41234.668796777725,39815.64774036408,1410.537811756134,4.214937686920166,0.0 -117300,5.190941,1.4402794,,,,,,,,,,,,,, -117400,4.898635,1.4348605,,,,,,,,,,,,,, -117500,5.666644,1.5268191,,,,,,,,,,,,,, -117600,5.5084224,1.5045376,,,,,,,,,,,,,, -117700,6.01758,1.3852618,,,,,,,,,,,,,, -117800,5.705491,1.438017,,,,,,,,,,,,,, -117900,5.4815235,1.4477903,,,,,,,,,,,,,, -118000,5.100871,1.4542153,,,,,,,,,,,,,, -118100,5.637913,1.4292741,,,,,,,,,,,,,, -118200,5.224607,1.4713806,,,,,,,,,,,,,, -118300,4.752706,1.3484368,,,,,,,,,,,,,, -118400,5.134251,1.446016,,,,,,,,,,,,,, -118500,5.008507,1.5637513,,,,,,,,,,,,,, -118600,5.5572596,1.445088,,,,,,,,,,,,,, -118700,5.3532763,1.3634611,,,,,,,,,,,,,, -118775,,,0.7867307066917419,0.7906128764152527,0.6891199946403503,1.260579228401184,50000.0,0.5646000504493713,1.9624043703079224,10000.0,40325.84529519081,41762.935829401016,40325.84529519081,1428.5066511631012,4.262691259384155,0.0 -118800,6.4761252,1.5322063,,,,,,,,,,,,,, -118900,6.091903,1.4367231,,,,,,,,,,,,,, -119000,5.615605,1.4574368,,,,,,,,,,,,,, -119100,5.731327,1.5115668,,,,,,,,,,,,,, -119200,5.068005,1.4310205,,,,,,,,,,,,,, -119300,5.227488,1.3833609,,,,,,,,,,,,,, -119400,4.946012,1.4098824,,,,,,,,,,,,,, -119500,4.9185734,1.3884469,,,,,,,,,,,,,, -119600,5.4256134,1.5436921,,,,,,,,,,,,,, -119700,5.3995924,1.481184,,,,,,,,,,,,,, -119800,5.449892,1.3935943,,,,,,,,,,,,,, -119900,5.5337067,1.4220762,,,,,,,,,,,,,, -120000,5.3074245,1.3789021,,,,,,,,,,,,,, -120100,4.7506166,1.3752768,,,,,,,,,,,,,, -120200,5.9043736,1.4653602,,,,,,,,,,,,,, -120281,,,0.77543044090271,0.8411076664924622,0.6890400052070618,1.254690408706665,50000.0,0.5605000257492065,1.9643940925598145,10000.0,40836.02567625046,42291.07681274414,40836.02567625046,1446.361969947815,4.311857223510742,0.0 -120300,5.163466,1.3754023,,,,,,,,,,,,,, -120400,5.101946,1.5469878,,,,,,,,,,,,,, -120500,5.4690714,1.3719404,,,,,,,,,,,,,, -120600,5.00973,1.2787986,,,,,,,,,,,,,, -120700,5.4544177,1.4484127,,,,,,,,,,,,,, -120800,5.379204,1.3756142,,,,,,,,,,,,,, -120900,5.1653256,1.4676399,,,,,,,,,,,,,, -121000,5.1099734,1.408514,,,,,,,,,,,,,, -121100,5.823263,1.3425903,,,,,,,,,,,,,, -121200,5.3040977,1.4775635,,,,,,,,,,,,,, -121300,5.4969163,1.4538912,,,,,,,,,,,,,, -121400,5.171266,1.359937,,,,,,,,,,,,,, -121500,5.576567,1.3546603,,,,,,,,,,,,,, -121600,5.6120067,1.4628102,,,,,,,,,,,,,, -121700,4.909593,1.3677064,,,,,,,,,,,,,, -121786,,,0.7752710580825806,0.8376134634017944,0.6941999793052673,1.22850501537323,50000.0,0.5678000450134277,1.9029667377471924,10000.0,41345.97124314308,42818.74826264381,41345.97124314308,1463.9844024181366,4.359987735748291,0.0 -121800,5.0057693,1.2965302,,,,,,,,,,,,,, -121900,5.928442,1.446632,,,,,,,,,,,,,, -122000,6.0914135,1.4814502,,,,,,,,,,,,,, -122100,5.7129307,1.3394417,,,,,,,,,,,,,, -122200,6.4170294,1.5058944,,,,,,,,,,,,,, -122300,5.3197274,1.397817,,,,,,,,,,,,,, -122400,4.884404,1.4309828,,,,,,,,,,,,,, -122500,5.3885756,1.3366032,,,,,,,,,,,,,, -122600,5.010085,1.4309576,,,,,,,,,,,,,, -122700,5.4511514,1.3341327,,,,,,,,,,,,,, -122800,5.7680674,1.3589245,,,,,,,,,,,,,, -122900,5.482477,1.4040519,,,,,,,,,,,,,, -123000,5.6503196,1.3309613,,,,,,,,,,,,,, -123100,5.066765,1.3749408,,,,,,,,,,,,,, -123200,5.439923,1.4127251,,,,,,,,,,,,,, -123292,,,0.7697106003761292,0.8630506992340088,0.6924200057983398,1.2353243827819824,50000.0,0.5701000094413757,1.924597978591919,10000.0,41856.06367182732,43346.80474662781,41856.06367182732,1481.8402979373932,4.414201498031616,0.0 -123300,6.4124026,1.4791963,,,,,,,,,,,,,, -123400,6.0443034,1.3563118,,,,,,,,,,,,,, -123500,5.0895233,1.3828182,,,,,,,,,,,,,, -123600,5.208608,1.4470983,,,,,,,,,,,,,, -123700,5.4972334,1.4581285,,,,,,,,,,,,,, -123800,5.534083,1.339332,,,,,,,,,,,,,, -123900,5.269308,1.4391081,,,,,,,,,,,,,, -124000,5.349242,1.3170849,,,,,,,,,,,,,, -124100,5.5827847,1.4360232,,,,,,,,,,,,,, -124200,5.440183,1.4251509,,,,,,,,,,,,,, -124300,5.6009045,1.3928798,,,,,,,,,,,,,, -124400,5.265215,1.3264982,,,,,,,,,,,,,, -124500,5.3991065,1.3321433,,,,,,,,,,,,,, -124600,5.808335,1.3365356,,,,,,,,,,,,,, -124700,5.4227448,1.3683482,,,,,,,,,,,,,, -124797,,,0.7671197056770325,0.862572968006134,0.6951199769973755,1.2315285205841064,50000.0,0.5696000456809998,1.931737184524536,10000.0,42366.00140285492,43875.40986657143,42366.00140285492,1500.4037234783173,4.462406873703003,0.0 -124800,5.58238,1.2795289,,,,,,,,,,,,,, -124900,5.306728,1.4305382,,,,,,,,,,,,,, -125000,5.920665,1.3374935,,,,,,,,,,,,,, -125100,5.7196,1.315338,,,,,,,,,,,,,, -125200,5.4467216,1.3023651,,,,,,,,,,,,,, -125300,5.609954,1.3458419,,,,,,,,,,,,,, -125400,5.290535,1.3984269,,,,,,,,,,,,,, -125500,5.881986,1.3562502,,,,,,,,,,,,,, -125600,5.676828,1.3720498,,,,,,,,,,,,,, -125700,6.0391555,1.3785274,,,,,,,,,,,,,, -125800,5.85836,1.368025,,,,,,,,,,,,,, -125900,5.7852273,1.3527564,,,,,,,,,,,,,, -126000,5.948167,1.4078692,,,,,,,,,,,,,, -126100,6.2855115,1.2919545,,,,,,,,,,,,,, -126200,5.3674455,1.3089705,,,,,,,,,,,,,, -126300,5.4568,1.3648852,,,,,,,,,,,,,, -126303,,,0.776387095451355,0.831611156463623,0.7033199667930603,1.20175302028656,50000.0,0.579200029373169,1.8880500793457031,10000.0,42876.07256484032,44403.05740857124,42876.07256484032,1517.8734288215635,4.515023231506348,0.0 -126400,5.730918,1.311674,,,,,,,,,,,,,, -126500,5.549295,1.2238892,,,,,,,,,,,,,, -126600,7.0245676,1.4048241,,,,,,,,,,,,,, -126700,5.7449517,1.4578559,,,,,,,,,,,,,, -126800,5.804548,1.4185375,,,,,,,,,,,,,, -126900,6.163105,1.3519448,,,,,,,,,,,,,, -127000,5.7003946,1.3370218,,,,,,,,,,,,,, -127100,5.9358406,1.3715916,,,,,,,,,,,,,, -127200,5.869277,1.2701392,,,,,,,,,,,,,, -127300,5.307508,1.3725514,,,,,,,,,,,,,, -127400,5.481137,1.3383615,,,,,,,,,,,,,, -127500,5.96962,1.3362111,,,,,,,,,,,,,, -127600,5.734567,1.297044,,,,,,,,,,,,,, -127700,5.6954036,1.2245994,,,,,,,,,,,,,, -127800,5.298725,1.3904452,,,,,,,,,,,,,, -127808,,,0.808035671710968,0.7111296057701111,0.7016400098800659,1.2024482488632202,50000.0,0.5752000212669373,1.917343020439148,10000.0,43386.099005937576,44931.04763793945,43386.099005937576,1535.7307217121124,4.566674709320068,0.0 -127900,5.205946,1.2434291,,,,,,,,,,,,,, -128000,5.639888,1.3918674,,,,,,,,,,,,,, -128100,6.1334324,1.3263704,,,,,,,,,,,,,, -128200,6.233139,1.343602,,,,,,,,,,,,,, -128300,5.3586683,1.3131456,,,,,,,,,,,,,, -128400,6.015101,1.2580509,,,,,,,,,,,,,, -128500,5.6803894,1.3749912,,,,,,,,,,,,,, -128600,6.3819766,1.258287,,,,,,,,,,,,,, -128700,5.9774637,1.305826,,,,,,,,,,,,,, -128800,6.22195,1.4936279,,,,,,,,,,,,,, -128900,5.541473,1.3617618,,,,,,,,,,,,,, -129000,6.144256,1.325277,,,,,,,,,,,,,, -129100,5.231399,1.3683435,,,,,,,,,,,,,, -129200,6.105222,1.2898014,,,,,,,,,,,,,, -129300,6.7413635,1.3975724,,,,,,,,,,,,,, -129314,,,0.7916733026504517,0.7742475271224976,0.6988999843597412,1.205346941947937,50000.0,0.5738000273704529,1.8980298042297363,10000.0,43896.101038217545,45458.82877635956,43896.101038217545,1553.4010951519012,4.621357679367065,0.0 -129400,5.783458,1.3412951,,,,,,,,,,,,,, -129500,5.815868,1.4434896,,,,,,,,,,,,,, -129600,5.623167,1.3129662,,,,,,,,,,,,,, -129700,6.205418,1.2474664,,,,,,,,,,,,,, -129800,6.391082,1.3457588,,,,,,,,,,,,,, -129900,6.527136,1.3318564,,,,,,,,,,,,,, -130000,6.1524863,1.330631,,,,,,,,,,,,,, -130100,6.659298,1.2770445,,,,,,,,,,,,,, -130200,6.2979856,1.2363136,,,,,,,,,,,,,, -130300,6.0698357,1.3888304,,,,,,,,,,,,,, -130400,5.4554296,1.3277807,,,,,,,,,,,,,, -130500,6.408342,1.4615144,,,,,,,,,,,,,, -130600,6.035574,1.4668854,,,,,,,,,,,,,, -130700,6.2261233,1.2220604,,,,,,,,,,,,,, -130800,5.8362584,1.29565,,,,,,,,,,,,,, -130819,,,0.7887834906578064,0.7834305763244629,0.7028999924659729,1.1931791305541992,50000.0,0.5729000568389893,1.9065872430801392,10000.0,44406.101815223694,45986.675889253616,44406.101815223694,1571.1410930156708,4.673174381256104,0.0 -130900,5.6709995,1.318944,,,,,,,,,,,,,, -131000,5.6188903,1.2592138,,,,,,,,,,,,,, -131100,6.136836,1.3611913,,,,,,,,,,,,,, -131200,6.2141194,1.2272575,,,,,,,,,,,,,, -131300,6.4541163,1.318639,,,,,,,,,,,,,, -131400,7.286132,1.368388,,,,,,,,,,,,,, -131500,5.516319,1.182096,,,,,,,,,,,,,, -131600,6.503944,1.3540188,,,,,,,,,,,,,, -131700,5.3234496,1.1884927,,,,,,,,,,,,,, -131800,6.4227943,1.1668396,,,,,,,,,,,,,, -131900,6.316816,1.3413308,,,,,,,,,,,,,, -132000,5.65488,1.2829454,,,,,,,,,,,,,, -132100,5.800506,1.3759813,,,,,,,,,,,,,, -132200,5.522543,1.330436,,,,,,,,,,,,,, -132300,5.8369,1.3569263,,,,,,,,,,,,,, -132325,,,0.7939453125,0.7579774260520935,0.7079199552536011,1.1673496961593628,50000.0,0.5933000445365906,1.8415168523788448,10000.0,44916.1741373539,46514.71562767029,44916.1741373539,1589.0011916160583,4.726458549499512,0.0 -132400,6.284045,1.2905256,,,,,,,,,,,,,, -132500,6.1616697,1.3195214,,,,,,,,,,,,,, -132600,6.007938,1.2560425,,,,,,,,,,,,,, -132700,6.1497393,1.3284868,,,,,,,,,,,,,, -132800,6.68355,1.2799861,,,,,,,,,,,,,, -132900,6.383335,1.3027585,,,,,,,,,,,,,, -133000,6.1306686,1.3632481,,,,,,,,,,,,,, -133100,6.1620407,1.2643476,,,,,,,,,,,,,, -133200,5.911288,1.3361439,,,,,,,,,,,,,, -133300,6.0381546,1.2475703,,,,,,,,,,,,,, -133400,6.2467837,1.2955067,,,,,,,,,,,,,, -133500,7.3677826,1.2366247,,,,,,,,,,,,,, -133600,6.1756063,1.2508068,,,,,,,,,,,,,, -133700,5.826551,1.1789329,,,,,,,,,,,,,, -133800,6.411605,1.361798,,,,,,,,,,,,,, -133831,,,0.7958984375,0.757809579372406,0.708620011806488,1.1658791303634644,50000.0,0.5860000252723694,1.8372397422790527,10000.0,45426.37725400925,47042.93878364563,45426.37725400925,1606.9133660793304,4.778220176696777,0.0 -133900,5.8226914,1.2778152,,,,,,,,,,,,,, -134000,6.140585,1.3610266,,,,,,,,,,,,,, -134100,5.842912,1.3536323,,,,,,,,,,,,,, -134200,7.0171585,1.2949011,,,,,,,,,,,,,, -134300,5.7278066,1.2178969,,,,,,,,,,,,,, -134400,5.5138764,1.2223938,,,,,,,,,,,,,, -134500,5.7219896,1.2285765,,,,,,,,,,,,,, -134600,5.574174,1.2183356,,,,,,,,,,,,,, -134700,5.975752,1.2543683,,,,,,,,,,,,,, -134800,6.4142685,1.2759343,,,,,,,,,,,,,, -134900,6.89361,1.3419886,,,,,,,,,,,,,, -135000,5.6358867,1.2047033,,,,,,,,,,,,,, -135100,7.8337536,1.3584915,,,,,,,,,,,,,, -135200,6.718262,1.304614,,,,,,,,,,,,,, -135300,6.5064216,1.3161147,,,,,,,,,,,,,, -135336,,,0.7983697056770325,0.737368106842041,0.7152799963951111,1.1455098390579224,50000.0,0.5838000178337097,1.8571194410324097,10000.0,45936.42216873169,47570.64797115326,45936.42216873169,1624.4669604301453,4.833124399185181,0.0 -135400,6.592905,1.3022597,,,,,,,,,,,,,, -135500,5.8419805,1.180093,,,,,,,,,,,,,, -135600,6.4654894,1.3569322,,,,,,,,,,,,,, -135700,5.988889,1.2525694,,,,,,,,,,,,,, -135800,5.578446,1.2436059,,,,,,,,,,,,,, -135900,5.8883395,1.2009177,,,,,,,,,,,,,, -136000,6.1295533,1.2449983,,,,,,,,,,,,,, -136100,6.7029347,1.2950938,,,,,,,,,,,,,, -136200,6.227942,1.1603984,,,,,,,,,,,,,, -136300,6.640349,1.1555756,,,,,,,,,,,,,, -136400,6.3960776,1.1953344,,,,,,,,,,,,,, -136500,6.16356,1.3177272,,,,,,,,,,,,,, -136600,6.666453,1.214158,,,,,,,,,,,,,, -136700,6.267289,1.2033157,,,,,,,,,,,,,, -136800,5.956807,1.1740384,,,,,,,,,,,,,, -136841,,,0.8374122977256775,0.5970585942268372,0.7161399722099304,1.1301541328430176,50000.0,0.5967000126838684,1.7885128259658811,10000.0,46446.49920320511,48098.353201150894,46446.49920320511,1641.9889187812803,4.884382009506226,0.0 -136900,5.5046625,1.0763068,,,,,,,,,,,,,, -137000,6.9653115,1.3535279,,,,,,,,,,,,,, -137100,6.4331136,1.2723953,,,,,,,,,,,,,, -137200,6.5585823,1.2512032,,,,,,,,,,,,,, -137300,6.3253546,1.1589224,,,,,,,,,,,,,, -137400,6.4986405,1.2739594,,,,,,,,,,,,,, -137500,6.6693745,1.1797853,,,,,,,,,,,,,, -137600,6.0691934,1.2189533,,,,,,,,,,,,,, -137700,7.2791076,1.2844874,,,,,,,,,,,,,, -137800,6.6301465,1.2388585,,,,,,,,,,,,,, -137900,6.3892384,1.2338756,,,,,,,,,,,,,, -138000,5.89133,1.1769025,,,,,,,,,,,,,, -138100,6.860828,1.2938992,,,,,,,,,,,,,, -138200,6.1146975,1.2251589,,,,,,,,,,,,,, -138300,6.1645474,1.2859184,,,,,,,,,,,,,, -138346,,,0.8237802982330322,0.637283980846405,0.7233399748802185,1.104615330696106,50000.0,0.5999000072479248,1.7960389852523804,10000.0,46956.61994481087,48626.38387274742,46956.61994481087,1659.7750248908997,4.95368218421936,0.0 -138400,5.980234,1.1685584,,,,,,,,,,,,,, -138500,7.4170666,1.2076526,,,,,,,,,,,,,, -138600,6.7745223,1.1516092,,,,,,,,,,,,,, -138700,7.272353,1.2197955,,,,,,,,,,,,,, -138800,7.136132,1.2487428,,,,,,,,,,,,,, -138900,6.2589893,1.2074441,,,,,,,,,,,,,, -139000,5.9931383,1.3026179,,,,,,,,,,,,,, -139100,5.5887733,1.1835576,,,,,,,,,,,,,, -139200,6.0382333,1.1533267,,,,,,,,,,,,,, -139300,6.3768964,1.1760597,,,,,,,,,,,,,, -139400,6.3160872,1.1752292,,,,,,,,,,,,,, -139500,6.4914575,1.2274394,,,,,,,,,,,,,, -139600,6.0048547,1.2065225,,,,,,,,,,,,,, -139700,7.061938,1.2000159,,,,,,,,,,,,,, -139800,7.024865,1.1431851,,,,,,,,,,,,,, -139851,,,0.8210897445678711,0.6470831036567688,0.7231400012969971,1.1141338348388672,50000.0,0.6011000275611877,1.7942116260528564,10000.0,47466.56193733215,49154.288517951965,47466.56193733215,1677.6271243095398,5.008821725845337,0.0 -139900,6.1822143,1.2399715,,,,,,,,,,,,,, -140000,6.635211,1.246599,,,,,,,,,,,,,, -140100,6.769287,1.195288,,,,,,,,,,,,,, -140200,6.485034,1.1950632,,,,,,,,,,,,,, -140300,7.0767927,1.2263458,,,,,,,,,,,,,, -140400,7.0147367,1.2279274,,,,,,,,,,,,,, -140500,6.3438244,1.2310096,,,,,,,,,,,,,, -140600,6.299579,1.1846104,,,,,,,,,,,,,, -140700,7.0057535,1.1959219,,,,,,,,,,,,,, -140800,6.708429,1.1992409,,,,,,,,,,,,,, -140900,6.569495,1.237265,,,,,,,,,,,,,, -141000,6.6385636,1.2431064,,,,,,,,,,,,,, -141100,6.2758746,1.2023847,,,,,,,,,,,,,, -141200,6.3193245,1.1620518,,,,,,,,,,,,,, -141300,6.964432,1.133766,,,,,,,,,,,,,, -141357,,,0.8201530575752258,0.6532760858535767,0.7235199809074402,1.1173137426376345,50000.0,0.5990000367164612,1.7983546257019043,10000.0,47976.69469380379,49682.10616827011,47976.69469380379,1695.204603433609,5.06158185005188,0.0 -141400,6.4562216,1.1478317,,,,,,,,,,,,,, -141500,6.8939753,1.1426349,,,,,,,,,,,,,, -141600,6.521117,1.1360773,,,,,,,,,,,,,, -141700,6.968421,1.3049958,,,,,,,,,,,,,, -141800,7.0483866,1.2164177,,,,,,,,,,,,,, -141900,6.5786633,1.1073401,,,,,,,,,,,,,, -142000,6.5287356,1.085962,,,,,,,,,,,,,, -142100,6.256194,1.1192982,,,,,,,,,,,,,, -142200,6.19571,1.1193314,,,,,,,,,,,,,, -142300,7.646391,1.2428035,,,,,,,,,,,,,, -142400,6.392212,1.1355689,,,,,,,,,,,,,, -142500,6.402413,1.1721674,,,,,,,,,,,,,, -142600,6.1592646,1.102651,,,,,,,,,,,,,, -142700,7.4463496,1.1398225,,,,,,,,,,,,,, -142800,6.623602,1.1637361,,,,,,,,,,,,,, -142862,,,0.8167450428009033,0.6668460965156555,0.7207799553871155,1.1217389106750488,50000.0,0.5948000550270081,1.8159778118133545,10000.0,48486.89080500603,50210.33151769638,48486.89080500603,1713.1225311756134,5.118523836135864,0.0 -142900,6.985396,1.3149849,,,,,,,,,,,,,, -143000,6.3239484,1.1630479,,,,,,,,,,,,,, -143100,6.757783,1.125021,,,,,,,,,,,,,, -143200,6.37344,1.05134,,,,,,,,,,,,,, -143300,6.817916,1.2389972,,,,,,,,,,,,,, -143400,6.1267433,1.1361068,,,,,,,,,,,,,, -143500,7.186723,1.1316519,,,,,,,,,,,,,, -143600,7.6220636,1.2136885,,,,,,,,,,,,,, -143700,6.7240424,1.1340182,,,,,,,,,,,,,, -143800,6.5833426,1.0878217,,,,,,,,,,,,,, -143900,6.344087,1.1201208,,,,,,,,,,,,,, -144000,7.6377,1.1905133,,,,,,,,,,,,,, -144100,6.469934,1.1998473,,,,,,,,,,,,,, -144200,6.5695143,1.1893045,,,,,,,,,,,,,, -144300,6.963667,1.1522261,,,,,,,,,,,,,, -144368,,,0.8262914419174194,0.6314259767532349,0.7257599830627441,1.0909123420715332,50000.0,0.6010000109672546,1.7762413024902344,10000.0,48996.92267632485,50738.048347473145,48996.92267632485,1730.7020156383514,5.169455766677856,0.0 -144400,7.2155275,1.1752664,,,,,,,,,,,,,, -144500,6.4538956,1.1571494,,,,,,,,,,,,,, -144600,6.972941,1.1352129,,,,,,,,,,,,,, -144700,7.20616,1.2336502,,,,,,,,,,,,,, -144800,7.309404,1.1855316,,,,,,,,,,,,,, -144900,6.7687,1.1707546,,,,,,,,,,,,,, -145000,7.100604,1.1143855,,,,,,,,,,,,,, -145100,7.2320256,1.1191813,,,,,,,,,,,,,, -145200,6.722704,1.1158533,,,,,,,,,,,,,, -145300,6.899202,1.1229591,,,,,,,,,,,,,, -145400,6.9016466,1.0533254,,,,,,,,,,,,,, -145500,6.7095914,1.1043165,,,,,,,,,,,,,, -145600,6.993575,1.0577784,,,,,,,,,,,,,, -145700,6.980824,1.1133144,,,,,,,,,,,,,, -145800,7.4268494,1.253984,,,,,,,,,,,,,, -145874,,,0.8601721525192261,0.5021381378173828,0.732479989528656,1.0820106267929075,50000.0,0.6047000288963318,1.778953194618225,10000.0,49507.11916804314,51266.07844781876,49507.11916804314,1748.4203968048096,5.229321718215942,0.0 -145900,6.760235,1.0606673,,,,,,,,,,,,,, -146000,8.048353,1.1455919,,,,,,,,,,,,,, -146100,6.806696,1.0389558,,,,,,,,,,,,,, -146200,6.709886,1.0635464,,,,,,,,,,,,,, -146300,7.282582,1.1105776,,,,,,,,,,,,,, -146400,8.25104,1.1228039,,,,,,,,,,,,,, -146500,7.303922,1.1096835,,,,,,,,,,,,,, -146600,6.612259,1.1575292,,,,,,,,,,,,,, -146700,7.159777,1.0735892,,,,,,,,,,,,,, -146800,7.286536,1.1116276,,,,,,,,,,,,,, -146900,7.2932725,1.1561191,,,,,,,,,,,,,, -147000,7.7593937,1.1587849,,,,,,,,,,,,,, -147100,7.2120895,1.149126,,,,,,,,,,,,,, -147200,6.607455,1.0767571,,,,,,,,,,,,,, -147300,7.1572647,1.0753735,,,,,,,,,,,,,, -147380,,,0.8510442972183228,0.5359805822372437,0.7305399775505066,1.0758004188537598,50000.0,0.6089000105857849,1.7504609823226929,10000.0,50017.27359175682,51793.96232128143,50017.27359175682,1766.0413398742676,5.282959222793579,0.0 -147400,7.1393332,1.0589633,,,,,,,,,,,,,, -147500,6.1984816,1.0655776,,,,,,,,,,,,,, -147600,7.3347664,1.1210868,,,,,,,,,,,,,, -147700,7.6488514,1.1411573,,,,,,,,,,,,,, -147800,7.0504885,1.0919644,,,,,,,,,,,,,, -147900,7.400415,1.1919333,,,,,,,,,,,,,, -148000,7.120337,1.1019027,,,,,,,,,,,,,, -148100,8.286779,1.0541154,,,,,,,,,,,,,, -148200,7.3146014,1.0672297,,,,,,,,,,,,,, -148300,6.8514924,0.95628905,,,,,,,,,,,,,, -148400,6.796743,1.0540425,,,,,,,,,,,,,, -148500,6.871062,1.1278758,,,,,,,,,,,,,, -148600,7.0962863,1.0862843,,,,,,,,,,,,,, -148700,7.2909966,1.1226628,,,,,,,,,,,,,, -148800,7.8834467,1.0263957,,,,,,,,,,,,,, -148885,,,0.8487523794174194,0.5430769324302673,0.7354199886322021,1.058487057685852,50000.0,0.6100000143051147,1.7439473867416382,10000.0,50527.27390384674,52321.88641309738,50527.27390384674,1783.8505229949951,5.341572046279907,0.0 -148900,6.8580465,1.0402745,,,,,,,,,,,,,, -149000,7.203763,1.1480489,,,,,,,,,,,,,, -149100,7.529423,1.1420296,,,,,,,,,,,,,, -149200,6.928211,1.0452292,,,,,,,,,,,,,, -149300,6.83511,1.1040211,,,,,,,,,,,,,, -149400,7.2876186,1.0685085,,,,,,,,,,,,,, -149500,7.293938,1.1419605,,,,,,,,,,,,,, -149600,7.7436213,1.1063309,,,,,,,,,,,,,, -149700,7.259373,1.0767744,,,,,,,,,,,,,, -149800,7.712945,1.0696557,,,,,,,,,,,,,, -149900,7.003843,1.0745308,,,,,,,,,,,,,, -150000,7.2064857,1.0032728,,,,,,,,,,,,,, -150100,7.305999,1.0689129,,,,,,,,,,,,,, -150200,6.6614785,0.9823468,,,,,,,,,,,,,, -150300,7.555049,1.087448,,,,,,,,,,,,,, -150390,,,0.8517817258834839,0.5283650159835815,0.7335999608039856,1.0614495277404783,50000.0,0.6028000116348267,1.7548463344573977,10000.0,51037.273169994354,52849.83254790306,51037.273169994354,1801.683106660843,5.4006664752960205,0.0 -150400,7.7695966,0.9654485,,,,,,,,,,,,,, -150500,6.959124,1.1270156,,,,,,,,,,,,,, -150600,7.095384,1.1480842,,,,,,,,,,,,,, -150700,7.58861,1.0323006,,,,,,,,,,,,,, -150800,7.511801,1.0092217,,,,,,,,,,,,,, -150900,8.257583,1.1863097,,,,,,,,,,,,,, -151000,6.684237,0.9545932,,,,,,,,,,,,,, -151100,6.637971,1.0897423,,,,,,,,,,,,,, -151200,6.6053553,1.0502553,,,,,,,,,,,,,, -151300,7.257942,1.076487,,,,,,,,,,,,,, -151400,8.246136,1.0753254,,,,,,,,,,,,,, -151500,7.499073,1.0627203,,,,,,,,,,,,,, -151600,8.026384,1.0537485,,,,,,,,,,,,,, -151700,6.830392,0.99687445,,,,,,,,,,,,,, -151800,7.7105474,1.0568392,,,,,,,,,,,,,, -151895,,,0.8530372977256775,0.5189365148544312,0.7375400066375732,1.058759093284607,50000.0,0.6148000359535217,1.7296031713485718,10000.0,51547.48435902596,53377.73521447182,51547.48435902596,1819.265034675598,5.456265211105347,0.0 -151900,7.928322,1.1104321,,,,,,,,,,,,,, -152000,7.2970343,1.0906142,,,,,,,,,,,,,, -152100,7.155725,1.077493,,,,,,,,,,,,,, -152200,7.674265,1.0200734,,,,,,,,,,,,,, -152300,7.620877,1.0814897,,,,,,,,,,,,,, -152400,7.5135946,0.99834293,,,,,,,,,,,,,, -152500,7.6153975,0.9606351,,,,,,,,,,,,,, -152600,7.118752,1.1123816,,,,,,,,,,,,,, -152700,7.1596484,0.97917944,,,,,,,,,,,,,, -152800,7.013693,1.0055709,,,,,,,,,,,,,, -152900,7.4855523,1.0720946,,,,,,,,,,,,,, -153000,7.4030204,1.0461419,,,,,,,,,,,,,, -153100,7.5940905,1.0447948,,,,,,,,,,,,,, -153200,8.237913,1.0376618,,,,,,,,,,,,,, -153300,8.177466,1.0829122,,,,,,,,,,,,,, -153400,7.9635425,1.1624031,,,,,,,,,,,,,, -153401,,,0.8583585619926453,0.5022905468940735,0.7417399883270264,1.0438984632492063,50000.0,0.6172000169754028,1.733088731765747,10000.0,52057.91526436806,53906.03312087059,52057.91526436806,1837.022742509842,5.511146545410156,0.0 -153500,7.6228614,1.0737785,,,,,,,,,,,,,, -153600,7.967565,1.0555702,,,,,,,,,,,,,, -153700,6.9818764,0.95239174,,,,,,,,,,,,,, -153800,7.7417145,0.93885106,,,,,,,,,,,,,, -153900,7.3285785,0.96535033,,,,,,,,,,,,,, -154000,7.637964,1.0083729,,,,,,,,,,,,,, -154100,8.770835,1.1318614,,,,,,,,,,,,,, -154200,7.310567,0.93187296,,,,,,,,,,,,,, -154300,7.6945643,1.0341153,,,,,,,,,,,,,, -154400,7.451059,0.9086881,,,,,,,,,,,,,, -154500,7.3078012,1.0922745,,,,,,,,,,,,,, -154600,7.0702004,0.92265403,,,,,,,,,,,,,, -154700,7.0910616,0.97168237,,,,,,,,,,,,,, -154800,7.479511,1.0797718,,,,,,,,,,,,,, -154900,8.925624,0.89648724,,,,,,,,,,,,,, -154906,,,0.8868981003761292,0.4032793939113617,0.741159975528717,1.0439447164535522,50000.0,0.6126000285148621,1.7265610694885254,10000.0,52567.88165092468,54433.74296832085,52567.88165092468,1854.6535007953644,5.569386959075928,0.0 -155000,7.564387,1.0166802,,,,,,,,,,,,,, -155100,7.788368,0.96249604,,,,,,,,,,,,,, -155200,8.068836,1.0612401,,,,,,,,,,,,,, -155300,7.996786,1.0731231,,,,,,,,,,,,,, -155400,7.962257,1.0198281,,,,,,,,,,,,,, -155500,7.149638,0.9125876,,,,,,,,,,,,,, -155600,7.711408,0.96929735,,,,,,,,,,,,,, -155700,8.933285,1.0185158,,,,,,,,,,,,,, -155800,8.04065,1.0189732,,,,,,,,,,,,,, -155900,9.169998,1.0711542,,,,,,,,,,,,,, -156000,8.612204,0.99090004,,,,,,,,,,,,,, -156100,6.9662576,0.934757,,,,,,,,,,,,,, -156200,7.5412464,1.0042406,,,,,,,,,,,,,, -156300,7.747526,0.99974304,,,,,,,,,,,,,, -156400,7.395843,0.9848769,,,,,,,,,,,,,, -156412,,,0.8819355964660645,0.4201908409595489,0.7441399693489075,1.0274673700332642,50000.0,0.6198000311851501,1.7015371322631836,10000.0,53077.91947197914,54961.70411801338,53077.91947197914,1872.464199066162,5.626893997192383,0.0 -156500,8.197565,1.0100003,,,,,,,,,,,,,, -156600,7.697284,1.0363603,,,,,,,,,,,,,, -156700,7.430027,0.9481138,,,,,,,,,,,,,, -156800,8.084707,0.91045916,,,,,,,,,,,,,, -156900,7.3447165,0.89071196,,,,,,,,,,,,,, -157000,8.513804,1.003996,,,,,,,,,,,,,, -157100,7.9112043,1.0017544,,,,,,,,,,,,,, -157200,8.052506,0.9397397,,,,,,,,,,,,,, -157300,8.111717,0.95732564,,,,,,,,,,,,,, -157400,8.852112,0.98849595,,,,,,,,,,,,,, -157500,7.847053,0.8898567,,,,,,,,,,,,,, -157600,7.5623355,0.97557056,,,,,,,,,,,,,, -157700,8.751049,1.0126336,,,,,,,,,,,,,, -157800,8.533379,0.92677385,,,,,,,,,,,,,, -157900,9.278703,1.0324898,,,,,,,,,,,,,, -157917,,,0.8812978267669678,0.4185983538627624,0.7456600069999695,1.027588129043579,50000.0,0.6189000010490417,1.7277756929397583,10000.0,53587.85387945175,55489.55252146721,53587.85387945175,1890.2672145366669,5.684285640716553,0.0 -158000,7.94133,0.89265984,,,,,,,,,,,,,, -158100,8.537555,0.94306386,,,,,,,,,,,,,, -158200,8.741367,0.9655416,,,,,,,,,,,,,, -158300,7.905381,0.9913487,,,,,,,,,,,,,, -158400,7.676555,1.039944,,,,,,,,,,,,,, -158500,8.453586,1.0722096,,,,,,,,,,,,,, -158600,7.584768,0.8852939,,,,,,,,,,,,,, -158700,8.181936,0.9615986,,,,,,,,,,,,,, -158800,7.977841,1.0171726,,,,,,,,,,,,,, -158900,8.002048,1.0081313,,,,,,,,,,,,,, -159000,8.476195,0.9614287,,,,,,,,,,,,,, -159100,9.397336,0.9695442,,,,,,,,,,,,,, -159200,8.117224,0.93013346,,,,,,,,,,,,,, -159300,7.7488317,0.9574491,,,,,,,,,,,,,, -159400,7.8047795,0.9255817,,,,,,,,,,,,,, -159422,,,0.8851044178009033,0.404142826795578,0.7461400032043457,1.024307131767273,50000.0,0.6230000257492065,1.718498468399048,10000.0,54097.80532884598,56017.39664173126,54097.80532884598,1908.046015739441,5.742774486541748,0.0 -159500,7.836547,1.0192486,,,,,,,,,,,,,, -159600,7.9587545,0.9137031,,,,,,,,,,,,,, -159700,8.535691,0.9456661,,,,,,,,,,,,,, -159800,7.8699093,0.9751,,,,,,,,,,,,,, -159900,7.9657745,0.9430407,,,,,,,,,,,,,, -160000,7.6358123,1.0210546,,,,,,,,,,,,,, -160100,8.654668,0.9412401,,,,,,,,,,,,,, -160200,7.445873,0.89418185,,,,,,,,,,,,,, -160300,9.019413,1.0353556,,,,,,,,,,,,,, -160400,8.389172,0.989656,,,,,,,,,,,,,, -160500,9.025357,0.9321019,,,,,,,,,,,,,, -160600,7.7512417,0.8756385,,,,,,,,,,,,,, -160700,8.750205,0.91306317,,,,,,,,,,,,,, -160800,8.184844,0.9070167,,,,,,,,,,,,,, -160900,7.537041,0.8581356,,,,,,,,,,,,,, -160928,,,0.8844267725944519,0.4088110625743866,0.7474600076675415,1.0270127058029177,50000.0,0.6189000010490417,1.7309526205062866,10000.0,54608.0150744915,56545.54099678993,54608.0150744915,1925.8649232387545,5.801524877548218,0.0 -161000,8.164405,0.8877801,,,,,,,,,,,,,, -161100,7.9981627,0.85286576,,,,,,,,,,,,,, -161200,8.535563,0.9474952,,,,,,,,,,,,,, -161300,7.910818,0.9816619,,,,,,,,,,,,,, -161400,8.430324,0.923285,,,,,,,,,,,,,, -161500,8.33146,0.8612373,,,,,,,,,,,,,, -161600,8.177144,0.9000219,,,,,,,,,,,,,, -161700,8.046163,0.90984356,,,,,,,,,,,,,, -161800,8.3963785,0.8702352,,,,,,,,,,,,,, -161900,7.6863875,0.87355316,,,,,,,,,,,,,, -162000,8.21367,0.8935968,,,,,,,,,,,,,, -162100,7.876366,0.9394037,,,,,,,,,,,,,, -162200,8.649223,0.97961307,,,,,,,,,,,,,, -162300,8.574458,0.9202242,,,,,,,,,,,,,, -162400,8.497309,0.92529863,,,,,,,,,,,,,, -162434,,,0.8910036683082581,0.3811632096767425,0.7514399886131287,1.0032466650009155,50000.0,0.6270000338554382,1.6860774755477903,10000.0,55118.058817625046,57073.90779566765,55118.058817625046,1944.0792744159696,5.855859756469727,0.0 -162500,8.321629,0.868478,,,,,,,,,,,,,, -162600,8.302848,0.8860518,,,,,,,,,,,,,, -162700,8.150881,0.86729383,,,,,,,,,,,,,, -162800,8.253338,0.8726906,,,,,,,,,,,,,, -162900,7.8786163,0.8646869,,,,,,,,,,,,,, -163000,8.315932,0.88323414,,,,,,,,,,,,,, -163100,8.653832,0.9063229,,,,,,,,,,,,,, -163200,8.593638,0.86028636,,,,,,,,,,,,,, -163300,8.442214,0.95251596,,,,,,,,,,,,,, -163400,7.9061866,0.8635695,,,,,,,,,,,,,, -163500,8.487437,0.87925154,,,,,,,,,,,,,, -163600,9.189187,0.8852979,,,,,,,,,,,,,, -163700,8.40562,0.8448557,,,,,,,,,,,,,, -163800,8.024897,0.8570492,,,,,,,,,,,,,, -163900,8.389454,0.8750092,,,,,,,,,,,,,, -163939,,,0.90921950340271,0.3235359787940979,0.7534599900245667,0.9915898442268372,50000.0,0.6303000450134277,1.6824558973312378,10000.0,55627.95647931099,57601.5028333664,55627.95647931099,1961.663684368133,5.913207292556763,0.0 -164000,8.738793,0.79241467,,,,,,,,,,,,,, -164100,8.64611,0.8451657,,,,,,,,,,,,,, -164200,7.9410377,0.8639795,,,,,,,,,,,,,, -164300,8.249949,0.9085127,,,,,,,,,,,,,, -164400,8.410092,0.87379235,,,,,,,,,,,,,, -164500,8.386768,0.8874089,,,,,,,,,,,,,, -164600,8.467381,0.8670895,,,,,,,,,,,,,, -164700,8.185041,0.80673605,,,,,,,,,,,,,, -164800,8.140058,0.84714043,,,,,,,,,,,,,, -164900,8.784489,0.9680544,,,,,,,,,,,,,, -165000,8.63965,0.83186543,,,,,,,,,,,,,, -165100,8.323372,0.8849478,,,,,,,,,,,,,, -165200,7.4385467,0.7411825,,,,,,,,,,,,,, -165300,8.954542,0.9129588,,,,,,,,,,,,,, -165400,7.989693,0.784394,,,,,,,,,,,,,, -165445,,,0.9094586968421936,0.3204634189605713,0.7548999786376953,0.9951203465461732,50000.0,0.6299000382423401,1.6766432523727417,10000.0,56138.10664725304,58129.62777304649,56138.10664725304,1979.5256507396696,5.970006227493286,0.0 -165500,9.485728,0.85412544,,,,,,,,,,,,,, -165600,8.04858,0.8808103,,,,,,,,,,,,,, -165700,7.580618,0.86031336,,,,,,,,,,,,,, -165800,8.9313345,0.84972245,,,,,,,,,,,,,, -165900,8.669374,0.8785649,,,,,,,,,,,,,, -166000,9.4879055,0.89091617,,,,,,,,,,,,,, -166100,8.653397,0.7873676,,,,,,,,,,,,,, -166200,8.52817,0.82113266,,,,,,,,,,,,,, -166300,9.069027,0.9477559,,,,,,,,,,,,,, -166400,9.1419115,0.88521546,,,,,,,,,,,,,, -166500,8.453964,0.84264004,,,,,,,,,,,,,, -166600,8.4903965,0.92493033,,,,,,,,,,,,,, -166700,9.065573,0.85439575,,,,,,,,,,,,,, -166800,9.133621,0.88808095,,,,,,,,,,,,,, -166900,9.069817,0.9040082,,,,,,,,,,,,,, -166950,,,0.9108538031578064,0.3147288560867309,0.7562599778175354,0.9941707849502563,50000.0,0.6342000365257263,1.6934216022491455,10000.0,56648.064351558685,58657.44054579735,56648.064351558685,1997.2634472846985,6.031470537185669,0.0 -167000,8.897828,0.87777364,,,,,,,,,,,,,, -167100,8.7189455,0.7958157,,,,,,,,,,,,,, -167200,8.884516,0.78463423,,,,,,,,,,,,,, -167300,8.391943,0.807543,,,,,,,,,,,,,, -167400,8.904191,0.851018,,,,,,,,,,,,,, -167500,8.53531,0.8854998,,,,,,,,,,,,,, -167600,8.900493,0.77411747,,,,,,,,,,,,,, -167700,9.360258,0.91096133,,,,,,,,,,,,,, -167800,9.703303,0.8215849,,,,,,,,,,,,,, -167900,9.3452835,0.9000125,,,,,,,,,,,,,, -168000,8.601395,0.77974975,,,,,,,,,,,,,, -168100,9.326934,0.9333147,,,,,,,,,,,,,, -168200,9.887579,0.841836,,,,,,,,,,,,,, -168300,8.76172,0.86908513,,,,,,,,,,,,,, -168400,8.919871,0.8763445,,,,,,,,,,,,,, -168456,,,0.9100167155265808,0.3191568851470947,0.7554000020027161,0.9903133511543274,50000.0,0.6330000162124634,1.6975325345993042,10000.0,57158.12152385712,59185.312942266464,57158.12152385712,2014.961641788483,6.093728303909302,0.0 -168500,8.03898,0.81908077,,,,,,,,,,,,,, -168600,8.912004,0.813007,,,,,,,,,,,,,, -168700,8.68866,0.891792,,,,,,,,,,,,,, -168800,9.105606,0.8413818,,,,,,,,,,,,,, -168900,8.901915,0.8325179,,,,,,,,,,,,,, -169000,8.605404,0.8467147,,,,,,,,,,,,,, -169100,8.620702,0.8402456,,,,,,,,,,,,,, -169200,7.9599967,0.7574463,,,,,,,,,,,,,, -169300,8.414045,0.8343065,,,,,,,,,,,,,, -169400,9.0611,0.85386336,,,,,,,,,,,,,, -169500,8.538441,0.86865395,,,,,,,,,,,,,, -169600,8.507243,0.76729566,,,,,,,,,,,,,, -169700,8.918632,0.8630873,,,,,,,,,,,,,, -169800,8.52441,0.82110316,,,,,,,,,,,,,, -169900,8.679917,0.8156318,,,,,,,,,,,,,, -169961,,,0.913843274116516,0.3032202124595642,0.7573599815368652,0.9809739589691162,50000.0,0.6330000162124634,1.672535061836243,10000.0,57668.05534219742,59713.05906128883,57668.05534219742,2032.660756349564,6.152337074279785,0.0 -170000,8.119278,0.8023194,,,,,,,,,,,,,, -170100,7.744203,0.77146804,,,,,,,,,,,,,, -170200,8.387452,0.73096997,,,,,,,,,,,,,, -170300,8.482311,0.80199575,,,,,,,,,,,,,, -170400,8.206524,0.75678134,,,,,,,,,,,,,, -170500,8.470091,0.8692888,,,,,,,,,,,,,, -170600,10.269438,0.7468994,,,,,,,,,,,,,, -170700,9.025297,0.82067096,,,,,,,,,,,,,, -170800,9.669613,0.7369323,,,,,,,,,,,,,, -170900,8.221779,0.7615716,,,,,,,,,,,,,, -171000,9.46376,0.79321784,,,,,,,,,,,,,, -171100,8.538305,0.7733072,,,,,,,,,,,,,, -171200,9.801412,0.74845016,,,,,,,,,,,,,, -171300,9.807942,0.8454754,,,,,,,,,,,,,, -171400,8.70937,0.8120924,,,,,,,,,,,,,, -171467,,,0.9151387214660645,0.2979476153850555,0.7592200040817261,0.9786510467529296,50000.0,0.6367000341415405,1.6685278415679932,10000.0,58178.16114640236,60241.09210109711,58178.16114640236,2050.4763667583466,6.209350109100342,0.0 -171500,8.971468,0.7682956,,,,,,,,,,,,,, -171600,8.026164,0.7586969,,,,,,,,,,,,,, -171700,9.169056,0.7652765,,,,,,,,,,,,,, -171800,8.715974,0.6677642,,,,,,,,,,,,,, -171900,9.758629,0.8098292,,,,,,,,,,,,,, -172000,9.271377,0.90119547,,,,,,,,,,,,,, -172100,8.706139,0.77709854,,,,,,,,,,,,,, -172200,8.820711,0.8077713,,,,,,,,,,,,,, -172300,8.781903,0.76797855,,,,,,,,,,,,,, -172400,8.332614,0.7738071,,,,,,,,,,,,,, -172500,7.895977,0.7567351,,,,,,,,,,,,,, -172600,8.755182,0.7427488,,,,,,,,,,,,,, -172700,9.634292,0.7702914,,,,,,,,,,,,,, -172800,9.340139,0.8592293,,,,,,,,,,,,,, -172900,9.96168,0.8678269,,,,,,,,,,,,,, -172972,,,0.92386794090271,0.2707740664482116,0.759619951248169,0.9723941683769226,50000.0,0.6366000175476074,1.6629676818847656,10000.0,58688.1078979969,60768.88009810448,58688.1078979969,2068.201717376709,6.271347761154175,0.0 -173000,8.594854,0.75961196,,,,,,,,,,,,,, -173100,9.585924,0.80601776,,,,,,,,,,,,,, -173200,9.57504,0.7807369,,,,,,,,,,,,,, -173300,9.233519,0.81512076,,,,,,,,,,,,,, -173400,10.137301,0.7828534,,,,,,,,,,,,,, -173500,8.907897,0.803873,,,,,,,,,,,,,, -173600,10.294262,0.79547465,,,,,,,,,,,,,, -173700,9.659908,0.820402,,,,,,,,,,,,,, -173800,10.923583,0.7572964,,,,,,,,,,,,,, -173900,9.698726,0.85455513,,,,,,,,,,,,,, -174000,10.025697,0.7387782,,,,,,,,,,,,,, -174100,8.947298,0.7958953,,,,,,,,,,,,,, -174200,8.982894,0.85614824,,,,,,,,,,,,,, -174300,8.777735,0.82087994,,,,,,,,,,,,,, -174400,8.960303,0.69811064,,,,,,,,,,,,,, -174478,,,0.9291493892669678,0.2514609694480896,0.7603799700737,0.972877323627472,50000.0,0.6362000107765198,1.6596471071243286,10000.0,59198.22804784775,61296.72523832321,59198.22804784775,2085.809932947159,6.333725452423096,0.0 -174500,8.424347,0.7046167,,,,,,,,,,,,,, -174600,8.570312,0.7243212,,,,,,,,,,,,,, -174700,7.652475,0.64683294,,,,,,,,,,,,,, -174800,8.705972,0.70179725,,,,,,,,,,,,,, -174900,9.300635,0.77761245,,,,,,,,,,,,,, -175000,8.620006,0.7874844,,,,,,,,,,,,,, -175100,9.523096,0.84014696,,,,,,,,,,,,,, -175200,9.149161,0.72197616,,,,,,,,,,,,,, -175300,8.8673115,0.6979325,,,,,,,,,,,,,, -175400,9.11282,0.7705695,,,,,,,,,,,,,, -175500,8.67582,0.72256035,,,,,,,,,,,,,, -175600,8.377939,0.7983055,,,,,,,,,,,,,, -175700,8.856939,0.7636601,,,,,,,,,,,,,, -175800,8.080408,0.7352811,,,,,,,,,,,,,, -175900,10.542345,0.7309542,,,,,,,,,,,,,, -175983,,,0.9298469424247742,0.2528488039970398,0.7617799639701843,0.9658573865890504,50000.0,0.636900007724762,1.6625109910964966,10000.0,59708.26288509369,61824.635633945465,59708.26288509369,2103.5739080905914,6.391595363616943,0.0 -176000,8.583758,0.70917934,,,,,,,,,,,,,, -176100,9.87815,0.7165991,,,,,,,,,,,,,, -176200,8.486223,0.76886827,,,,,,,,,,,,,, -176300,9.144442,0.74686694,,,,,,,,,,,,,, -176400,8.793,0.6721122,,,,,,,,,,,,,, -176500,9.671788,0.716807,,,,,,,,,,,,,, -176600,9.065035,0.7507217,,,,,,,,,,,,,, -176700,10.010795,0.7056916,,,,,,,,,,,,,, -176800,8.730184,0.824761,,,,,,,,,,,,,, -176900,8.517241,0.67621255,,,,,,,,,,,,,, -177000,8.603064,0.7561098,,,,,,,,,,,,,, -177100,9.126188,0.77759504,,,,,,,,,,,,,, -177200,9.4908085,0.8608352,,,,,,,,,,,,,, -177300,8.946759,0.72019887,,,,,,,,,,,,,, -177400,9.108384,0.7462239,,,,,,,,,,,,,, -177488,,,0.928730845451355,0.2565664649009704,0.7629799842834473,0.9615415930747986,50000.0,0.6383000016212463,1.6531953811645508,10000.0,60218.22137951851,62352.43422079086,60218.22137951851,2121.300199508667,6.451011419296265,0.0 -177500,9.570676,0.7543113,,,,,,,,,,,,,, -177600,9.7117195,0.7396354,,,,,,,,,,,,,, -177700,8.261297,0.68162584,,,,,,,,,,,,,, -177800,9.281864,0.70635164,,,,,,,,,,,,,, -177900,9.829368,0.7403432,,,,,,,,,,,,,, -178000,9.851309,0.7806231,,,,,,,,,,,,,, -178100,9.0550785,0.72548854,,,,,,,,,,,,,, -178200,9.313773,0.79718685,,,,,,,,,,,,,, -178300,9.697166,0.833084,,,,,,,,,,,,,, -178400,8.535954,0.77185297,,,,,,,,,,,,,, -178500,9.123644,0.75487554,,,,,,,,,,,,,, -178600,10.133786,0.81215245,,,,,,,,,,,,,, -178700,9.467583,0.69463885,,,,,,,,,,,,,, -178800,9.95308,0.75124246,,,,,,,,,,,,,, -178900,9.127736,0.82297885,,,,,,,,,,,,,, -178993,,,0.9295678734779358,0.2503544390201568,0.7623199820518494,0.9605156183242798,50000.0,0.6381000280380249,1.6532292366027832,10000.0,60728.37961912155,62880.57772922516,60728.37961912155,2139.160705327988,6.521094560623169,0.0 -179000,8.222974,0.82189864,,,,,,,,,,,,,, -179100,8.965132,0.76622546,,,,,,,,,,,,,, -179200,9.218553,0.7604549,,,,,,,,,,,,,, -179300,9.9645605,0.75274545,,,,,,,,,,,,,, -179400,9.454602,0.74490035,,,,,,,,,,,,,, -179500,10.015554,0.82460386,,,,,,,,,,,,,, -179600,8.532604,0.752223,,,,,,,,,,,,,, -179700,8.711238,0.7074256,,,,,,,,,,,,,, -179800,8.932909,0.77520335,,,,,,,,,,,,,, -179900,8.664107,0.7591541,,,,,,,,,,,,,, -180000,9.880764,0.7335593,,,,,,,,,,,,,, -180100,8.377816,0.7396017,,,,,,,,,,,,,, -180200,9.130872,0.7564129,,,,,,,,,,,,,, -180300,9.200827,0.6835066,,,,,,,,,,,,,, -180400,9.360336,0.74173653,,,,,,,,,,,,,, -180498,,,0.9304248690605164,0.244537353515625,0.7628200054168701,0.9591808319091796,50000.0,0.638200044631958,1.647757053375244,10000.0,61238.27491044998,63408.26725912094,61238.27491044998,2156.840073823929,6.580701112747192,0.0 -180500,8.313469,0.7884447,,,,,,,,,,,,,, -180600,9.098681,0.77242017,,,,,,,,,,,,,, -180700,10.245217,0.7459005,,,,,,,,,,,,,, -180800,9.1308975,0.73095244,,,,,,,,,,,,,, -180900,9.392126,0.7851853,,,,,,,,,,,,,, -181000,8.340416,0.6638748,,,,,,,,,,,,,, -181100,9.012318,0.7844553,,,,,,,,,,,,,, -181200,9.329056,0.7519989,,,,,,,,,,,,,, -181300,9.399156,0.7300211,,,,,,,,,,,,,, -181400,8.530986,0.73614615,,,,,,,,,,,,,, -181500,9.646076,0.7639184,,,,,,,,,,,,,, -181600,8.796009,0.7810178,,,,,,,,,,,,,, -181700,9.455642,0.7381169,,,,,,,,,,,,,, -181800,8.277154,0.7104387,,,,,,,,,,,,,, -181900,8.924979,0.7298016,,,,,,,,,,,,,, -182000,8.792897,0.72995347,,,,,,,,,,,,,, -182003,,,0.9319595098495485,0.2480978667736053,0.7633199691772461,0.958383858203888,50000.0,0.6396000385284424,1.6491045951843262,10000.0,61748.39874982834,63936.03000330925,61748.39874982834,2174.361873626709,6.64280366897583,0.0 -182100,8.594622,0.7423986,,,,,,,,,,,,,, -182200,9.128986,0.7719408,,,,,,,,,,,,,, -182300,9.193796,0.7191367,,,,,,,,,,,,,, -182400,9.1554,0.7365853,,,,,,,,,,,,,, -182500,8.221519,0.6734039,,,,,,,,,,,,,, -182600,9.257147,0.77346444,,,,,,,,,,,,,, -182700,8.61188,0.7343214,,,,,,,,,,,,,, -182800,9.319537,0.8104249,,,,,,,,,,,,,, -182900,8.997934,0.7184775,,,,,,,,,,,,,, -183000,8.860897,0.78939843,,,,,,,,,,,,,, -183100,8.894595,0.70575935,,,,,,,,,,,,,, -183200,9.057471,0.7616729,,,,,,,,,,,,,, -183300,9.481912,0.71345925,,,,,,,,,,,,,, -183400,8.026166,0.70230675,,,,,,,,,,,,,, -183500,8.581702,0.7821125,,,,,,,,,,,,,, -183508,,,0.93359375,0.2431531846523285,0.7633199691772461,0.9569819569587708,50000.0,0.6368000507354736,1.650039553642273,10000.0,62258.506739377975,64464.27554774284,62258.506739377975,2192.3842589855194,6.70346999168396,0.0 -183600,10.261045,0.7802158,,,,,,,,,,,,,, -183700,8.910693,0.80798835,,,,,,,,,,,,,, -183800,8.806255,0.7309091,,,,,,,,,,,,,, -183900,9.036113,0.8634082,,,,,,,,,,,,,, -184000,10.217736,0.7414354,,,,,,,,,,,,,, -184100,8.418113,0.62340176,,,,,,,,,,,,,, -184200,10.160118,0.7762836,,,,,,,,,,,,,, -184300,9.68691,0.7806283,,,,,,,,,,,,,, -184400,9.362571,0.72553146,,,,,,,,,,,,,, -184500,9.527659,0.71002394,,,,,,,,,,,,,, -184600,8.404279,0.7152992,,,,,,,,,,,,,, -184700,8.201759,0.7309009,,,,,,,,,,,,,, -184800,8.77157,0.7020699,,,,,,,,,,,,,, -184900,9.142817,0.7402829,,,,,,,,,,,,,, -185000,9.019348,0.72184545,,,,,,,,,,,,,, -185013,,,0.9331752061843872,0.2407183647155761,0.7636199593544006,0.957119882106781,50000.0,0.6378000378608704,1.648453950881958,10000.0,62768.50212907791,64992.15400886536,62768.50212907791,2210.148278236389,6.768203496932983,0.0 -185100,9.462385,0.74127525,,,,,,,,,,,,,, -185200,8.946919,0.6756021,,,,,,,,,,,,,, -185300,9.991117,0.7908383,,,,,,,,,,,,,, -185400,10.325288,0.74327314,,,,,,,,,,,,,, -185500,9.689745,0.7382238,,,,,,,,,,,,,, -185600,8.840704,0.697145,,,,,,,,,,,,,, -185700,9.807121,0.7607864,,,,,,,,,,,,,, -185721,,,,,,,,,,,63008.252861738205,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 48c25392b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -18.02929162979126,0.0,32.62680268287659,1,0,32.62680268287659,0.0006000000284984,6.9125494956970215,10000,50.65618777275085,0.0010961415246129,6.912662982940674,0.0007599999662488,6.913174629211426,50000 -36.14868259429932,0.0185163021087646,542.689546585083,1498,0,542.689546585083,0.1358000040054321,4.661078929901123,10000,578.9105360507965,0.2016103267669677,4.02529239654541,0.1814000010490417,4.143918514251709,50000 -54.20573401451111,0.0462877750396728,1052.8628568649292,2997,0,1052.8628568649292,0.2497000098228454,3.80063271522522,10000,1107.221934556961,0.3574816584587097,2.9869542121887207,0.3305200040340423,3.145972728729248,50000 -72.17443251609802,0.0769317150115966,1562.7982881069183,4497,0,1562.7982881069183,0.2798000276088714,3.6492860317230233,10000,1635.2103700637815,0.3845065236091614,2.834846258163452,0.3623199760913849,2.950068712234497,50000 -89.74597930908203,0.1048753261566162,2072.8568086624146,5999,0,2072.8568086624146,0.2388000041246414,4.093076705932617,10000,2162.92281293869,0.3507254421710968,3.0492608547210693,0.3096399903297424,3.407315969467163,50000 -107.88517737388612,0.1373639106750488,2582.8044941425323,7501,0,2582.8044941425323,0.0470000021159648,6.509797096252441,10000,2691.0967836380005,0.0716478005051612,5.990123271942139,0.0696799978613853,6.023623466491699,50000 -126.39080882072447,0.1694235801696777,3092.775918245316,9005,0,3092.775918245316,0.2943000197410583,3.5717177391052246,10000,3219.6616904735565,0.4273158311843872,2.571629285812378,0.3967399895191192,2.782094717025757,50000 -144.12997436523438,0.1947331428527832,3602.7280580997467,10509,0,3602.7280580997467,0.133200004696846,5.446406364440918,10000,3747.433574438095,0.1981425285339355,4.607287406921387,0.1787599921226501,4.729046821594238,50000 -161.86528515815735,0.223168134689331,4112.816953420639,12015,0,4112.816953420639,0.0442000031471252,6.679653644561768,10000,4275.339448213577,0.0677016898989677,6.2029805183410645,0.0630199983716011,6.263773441314697,50000 -179.61972188949585,0.2544693946838379,4622.73382973671,13520,0,4622.73382973671,0.2696000039577484,3.811286449432373,10000,4803.095787525177,0.3900669515132904,2.867115497589112,0.3664799928665161,3.0192630290985107,50000 -198.42433190345764,0.2853665351867676,5132.905265569687,15027,0,5132.905265569687,0.1473000049591064,5.18111515045166,10000,5332.158295869827,0.2157804518938064,4.200702667236328,0.1911999881267547,4.499402523040772,50000 -216.898282289505,0.3170950412750244,5643.002831459045,16534,0,5643.002831459045,0.1986000090837478,4.769838333129883,10000,5860.815819740295,0.2916533648967743,3.804319143295288,0.2672999799251556,4.0446624755859375,50000 -234.83249616622925,0.34627366065979,6152.994265079498,18042,0,6152.994265079498,0.1946000158786773,4.378692626953125,10000,6388.826258897781,0.2917530238628387,3.487156391143799,0.2727600038051605,3.6579201221466064,50000 -252.6183815002441,0.3780183792114258,6663.09573674202,19550,0,6663.09573674202,0.1752000153064727,4.855306148529053,10000,6916.799047708511,0.2448979616165161,4.03620719909668,0.2307799905538559,4.1693949699401855,50000 -270.57379508018494,0.4093115329742431,7173.257848501205,21059,0,7173.257848501205,0.0834000036120414,6.58660888671875,10000,7445.002496004105,0.1134805455803871,5.853148460388184,0.1053599938750267,6.068696022033691,50000 -288.48663353919983,0.4424080848693847,7683.180972576141,22567,0,7683.180972576141,0.142300009727478,5.258711814880371,10000,7972.925594329834,0.2086256295442581,4.400444030761719,0.1987399905920028,4.497589111328125,50000 -306.22909712791443,0.473712682723999,8193.239090681076,24076,0,8193.239090681076,0.1926000118255615,4.453147411346436,10000,8500.813712358475,0.2997648119926452,3.4221742153167725,0.2604199945926666,3.78203558921814,50000 -324.04403138160706,0.506464958190918,8703.223582744598,25584,0,8703.223582744598,0.0534000024199485,6.625743389129639,10000,9028.702632427216,0.085339605808258,5.972334861755371,0.0833600014448165,5.932559967041016,50000 -342.0292069911957,0.5385315418243408,9213.1674451828,27093,0,9213.1674451828,0.2060000151395797,4.422956943511963,10000,9556.718665838242,0.2820471823215484,3.6558451652526855,0.2640599906444549,3.76244592666626,50000 -359.71742606163025,0.5759317874908447,9723.193513393402,28602,0,9723.193513393402,0.1536000072956085,5.504821300506592,10000,10084.52348947525,0.213428720831871,4.641663551330566,0.1976799964904785,4.830415725708008,50000 -377.4810211658478,0.612412691116333,10233.195832252502,30048,0,10233.195832252502,0.1106000021100044,5.6790618896484375,10000,10612.37854552269,0.1551937162876129,4.965692043304443,0.1513599902391433,5.025054454803467,50000 -395.98668384552,0.6475231647491455,10743.147997140884,31558,0,10743.147997140884,0.1656000018119812,5.351226806640625,10000,11140.925688028336,0.2456154227256775,4.217416763305664,0.2311199903488159,4.373788833618164,50000 -413.9356341362,0.678107738494873,11253.401313781738,33070,0,11253.401313781738,0.1248000040650367,5.6210150718688965,10000,11669.21297764778,0.1723333895206451,4.832897663116455,0.1630599945783615,4.949711322784424,50000 -431.7884068489074,0.7146234512329102,11763.362285137177,34580,0,11763.362285137177,0.096000000834465,6.327773094177246,10000,12197.119595527647,0.136957898736,5.655052661895752,0.125459998846054,5.8631768226623535,50000 -449.6273202896118,0.7484467029571533,12273.451851844788,36091,0,12273.451851844788,0.1918000131845474,4.572559833526611,10000,12725.135549068453,0.2906768023967743,3.589463233947754,0.2702600061893463,3.757100105285645,50000 -467.3155233860016,0.7832720279693604,12783.49753022194,37601,0,12783.49753022194,0.2678000032901764,3.7928645610809326,10000,13252.958825826645,0.3878348171710968,2.8524887561798096,0.3601000010967254,3.0229721069335938,50000 -485.3442895412445,0.820784330368042,13293.57846236229,39112,0,13293.57846236229,0.0463000014424324,6.7392120361328125,10000,13781.161801338196,0.0672831609845161,6.174365043640137,0.0680800005793571,6.197511196136475,50000 -503.2651972770691,0.8609738349914551,13803.59059214592,40623,0,13803.59059214592,0.066500000655651,6.820237159729004,10000,14309.18965625763,0.1066844686865806,6.06245756149292,0.0982799977064132,6.124608516693115,50000 -521.0472767353058,0.9110279083251952,14313.546487808228,42134,0,14313.546487808228,0.1122000068426132,5.937928199768066,10000,14837.033590316772,0.1703603267669677,4.943631649017334,0.162320002913475,5.046461582183838,50000 -538.9020249843597,2.015630960464477,14822.672254800797,43643,0,14822.672254800797,0.1562000066041946,5.226414203643799,10000,15365.175357818604,0.2089445143938064,4.4410600662231445,0.1917800009250641,4.693207263946533,50000 -556.4439558982849,2.052147626876831,15332.741003751757,45155,0,15332.741003751757,0.1218000054359436,5.921328067779541,10000,15892.876887321472,0.1822783797979354,4.9106268882751465,0.1706199944019317,5.011672019958496,50000 -574.490583896637,2.092453956604004,15842.871906518936,46667,0,15842.871906518936,0.0630000010132789,6.328012466430664,10000,16421.14906358719,0.0934709832072258,5.805931568145752,0.0856800004839897,5.914002418518066,50000 -592.2423067092896,2.131907224655152,16352.813577651978,48179,0,16352.813577651978,0.0724000036716461,6.272417545318604,10000,16948.93683218956,0.0985132306814193,5.712707996368408,0.0935999974608421,5.7977824211120605,50000 -609.932549238205,2.17134976387024,16862.992620944977,49691,0,16862.992620944977,0.2041000127792358,4.3102006912231445,10000,17476.90072107315,0.3025948703289032,3.468749523162842,0.2851999998092651,3.577686786651612,50000 -627.7616715431213,2.2113375663757324,17373.104477643967,51203,0,17373.104477643967,0.2173000127077102,4.417811393737793,10000,18004.936414718628,0.3303371965885162,3.299055576324463,0.2969200015068054,3.578896999359131,50000 -645.7019193172455,2.252317190170288,17883.095085144043,52715,0,17883.095085144043,0.130400002002716,5.461226463317871,10000,18532.963203907013,0.1820591539144516,4.70458459854126,0.1709399968385696,4.80840539932251,50000 -663.4520955085754,2.2944045066833496,18393.22921514511,54228,0,18393.22921514511,0.110100008547306,5.747947216033936,10000,19060.944731235504,0.1476801633834839,5.142282009124756,0.1403799951076507,5.216727256774902,50000 -681.3409023284912,2.333213329315185,18903.42562484741,55740,0,18903.42562484741,0.1557000130414962,5.351230621337891,10000,19589.123804807663,0.2161591202020645,4.46263599395752,0.2102999985218048,4.564239025115967,50000 -699.0399971008301,2.3717331886291504,19413.61908721924,57253,0,19413.61908721924,0.2063000053167343,4.711694240570068,10000,20117.10993623733,0.2811702787876129,3.879862785339356,0.2648199796676636,4.011868476867676,50000 -716.9220430850983,2.4119575023651123,19923.75394082069,58765,0,19923.75394082069,0.1857000142335891,4.536026477813721,10000,20645.222013950348,0.2714046537876129,3.727704286575317,0.2528599798679352,3.881164312362671,50000 -734.5041456222534,2.4547340869903564,20433.7780482769,60278,0,20433.7780482769,0.0945000052452087,6.115200519561768,10000,21172.926019191746,0.1557318270206451,5.1116814613342285,0.1419599950313568,5.236762523651123,50000 -752.4799780845642,2.491457223892212,20943.91853928566,61791,0,20943.91853928566,0.1754000037908554,4.864076614379883,10000,21701.13439130783,0.2585897445678711,3.929126262664795,0.2430599927902221,4.085925579071045,50000 -770.3702204227448,2.534074068069458,21454.120364904404,63304,0,21454.120364904404,0.1388000100851059,5.745086193084717,10000,22229.32262706757,0.1938775479793548,4.898688316345215,0.184919998049736,4.995222091674805,50000 -788.0393702983856,2.598047971725464,21964.04625606537,64817,0,21964.04625606537,0.1337000131607055,5.076531410217285,10000,22757.03709292412,0.1991988122463226,4.335831165313721,0.1844199895858764,4.489870548248291,50000 -805.9393737316132,2.638539552688598,22474.17966222763,66330,0,22474.17966222763,0.2068000137805938,4.329137802124023,10000,23285.16569185257,0.3051060140132904,3.4450433254241943,0.2985999882221222,3.5019185543060303,50000 -823.8984444141388,2.679395914077759,22984.28769302368,67842,0,22984.28769302368,0.2841000258922577,3.717777013778688,10000,23813.329062461853,0.3987165093421936,2.7879419326782227,0.3781400024890899,2.944031238555908,50000 -841.9661107063293,2.723883628845215,23494.273504018784,69355,0,23494.273504018784,0.1639000028371811,5.079490184783936,10000,24341.48214435577,0.2458346635103225,4.067709922790527,0.2218199968338012,4.309610843658447,50000 -859.8252913951874,2.7633161544799805,24004.327831745148,70867,0,24004.327831745148,0.2218000143766403,4.1046953201293945,10000,24869.49009847641,0.3173628747463226,3.2800393104553223,0.2933799922466278,3.43172574043274,50000 -877.5954036712646,2.8088531494140625,24514.24355506897,72379,0,24514.24355506897,0.2957000136375427,3.51944899559021,10000,25397.27818083763,0.4204001724720001,2.6156625747680664,0.3950199782848358,2.788191795349121,50000 -895.3173720836639,2.8476545810699463,25024.23308992386,73891,0,25024.23308992386,0.2730000019073486,3.748530149459839,10000,25925.08330845833,0.3917610049247741,2.7942934036254883,0.3677600026130676,2.944324016571045,50000 -912.9306666851044,2.89243745803833,25534.32436609268,75404,0,25534.32436609268,0.0658000037074089,6.601757526397705,10000,26452.890166282654,0.0949457883834838,6.09211540222168,0.0862199962139129,6.208457946777344,50000 -930.827305316925,2.9355061054229736,26044.48452091217,76917,0,26044.48452091217,0.3020000159740448,3.514342784881592,10000,26981.044637680054,0.4280731678009033,2.589535713195801,0.4032399952411651,2.7519068717956543,50000 -948.81911110878,2.9777910709381104,26554.670390605927,78430,0,26554.670390605927,0.2303000092506408,4.192265510559082,10000,27509.320876836777,0.3382294178009033,3.221769094467163,0.3022799789905548,3.51043963432312,50000 -966.4048013687134,3.0196309089660645,27064.858870744705,79943,0,27064.858870744705,0.2179000079631805,4.380500316619873,10000,28037.1936275959,0.3127590715885162,3.484292507171631,0.2857199907302856,3.710571527481079,50000 -984.0845103263856,3.068606376647949,27575.083287000656,81457,0,27575.083287000656,0.0578000023961067,7.317497253417969,10000,28565.20376110077,0.0907804518938064,6.651303291320801,0.0834999978542327,6.782551288604736,50000 -1001.7920260429382,3.1208367347717285,28085.3017513752,82971,0,28085.3017513752,0.2320000082254409,4.121487140655518,10000,29093.237498044968,0.3289221823215484,3.341825723648072,0.3080599904060364,3.4806392192840576,50000 -1019.570032596588,3.163076639175415,28595.36977601052,84484,0,28595.36977601052,0.2698000073432922,3.844083786010742,10000,29621.18208193779,0.3950494229793548,2.8487935066223145,0.3696599900722503,3.0001261234283447,50000 -1037.0942661762238,3.2089014053344727,29105.444316864014,85997,0,29105.444316864014,0.101500004529953,6.495641708374023,10000,30148.880268096924,0.1472815722227096,5.7417826652526855,0.1353199928998947,5.94066047668457,50000 -1054.9764783382416,3.251824378967285,29615.636449813843,87510,0,29615.636449813843,0.2642000019550323,4.0185370445251465,10000,30677.051599264145,0.3875956535339355,2.937028884887696,0.3501999974250793,3.214970588684082,50000 -1072.5670273303986,3.307556867599488,30125.56599378585,89023,0,30125.56599378585,0.2768000066280365,3.74910569190979,10000,31204.683507680893,0.4036192595958709,2.742282629013061,0.3774600028991699,2.935097932815552,50000 -1090.2088098526,3.3545897006988525,30635.79352426529,90537,0,30635.79352426529,0.1845000088214874,4.86121940612793,10000,31732.65471124649,0.2742745578289032,3.888530969619751,0.2615199983119964,4.029567241668701,50000 -1108.64000082016,3.399713516235352,31145.958006620407,92050,0,31145.958006620407,0.2576000094413757,3.8726961612701416,10000,32261.35049700737,0.3800821006298065,2.910261869430542,0.3625999987125397,3.03848934173584,50000 -1126.3750030994415,4.565646648406982,31654.82076215744,93560,0,31654.82076215744,0.3085000216960907,3.6801254749298096,10000,32789.16884326935,0.4333545863628387,2.672496795654297,0.4023999869823456,2.8671212196350098,50000 -1144.1742820739746,4.614187479019165,32165.047719717026,95073,0,32165.047719717026,0.2123000174760818,4.541915893554688,10000,33317.30035114288,0.3159478604793548,3.5270910263061523,0.2946999967098236,3.69120979309082,50000 -1162.073585271835,4.659753799438477,32675.180696725845,96586,0,32675.180696725845,0.25,4.006485462188721,10000,33845.43358707428,0.3710738122463226,2.9733757972717285,0.3378999829292297,3.213762044906616,50000 -1179.7177047729492,4.704460382461548,33185.274107694626,98099,0,33185.274107694626,0.2865000069141388,3.734477996826172,10000,34373.27248048782,0.4164939224720001,2.686826467514038,0.3877999782562256,2.8848915100097656,50000 -1197.6669921875,4.754195690155029,33695.28577399254,99612,0,33695.28577399254,0.3606000244617462,3.1262447834014893,10000,34901.336842536926,0.5056201815605164,2.150707960128784,0.4710799753665924,2.3676950931549072,50000 -1215.8287003040314,4.804080247879028,34205.252484321594,101124,0,34205.252484321594,0.313400000333786,3.591003179550171,10000,35429.5738132,0.4436583220958709,2.589078903198242,0.4133199751377105,2.795851230621338,50000 -1233.6238696575165,4.856037855148315,34715.23096227646,102637,0,34715.23096227646,0.3857000172138214,3.0341577529907227,10000,35957.45670199394,0.5242147445678711,2.068694829940796,0.4916799962520599,2.279438018798828,50000 -1251.489592075348,4.904949903488159,35225.40841984749,104150,0,35225.40841984749,0.3177000284194946,3.549726247787476,10000,36485.60284900665,0.440828263759613,2.609811305999756,0.4175199866294861,2.7619614601135254,50000 -1269.1765999794006,4.950104713439941,35735.49154949188,105662,0,35735.49154949188,0.1985000073909759,4.91162109375,10000,37013.47520160675,0.2795559465885162,3.90887188911438,0.2576999962329864,4.139658451080322,50000 -1286.9605538845062,4.994918346405029,36245.43359160423,107174,0,36245.43359160423,0.3568000197410583,3.284795045852661,10000,37541.30133938789,0.500019907951355,2.244053363800049,0.4604199826717376,2.513838768005371,50000 -1304.5049712657928,5.042736291885376,36755.34883475304,108686,0,36755.34883475304,0.355100005865097,3.2092649936676025,10000,38068.86385965347,0.4937818646430969,2.252739191055298,0.4607200026512146,2.472726345062256,50000 -1322.303986787796,5.093266010284424,37265.49964928627,110199,0,37265.49964928627,0.3678000271320343,3.0940093994140625,10000,38596.9192211628,0.5153260231018066,2.0981321334838867,0.4820399880409241,2.304888963699341,50000 -1340.0610992908478,5.142434120178223,37775.48208498955,111712,0,37775.48208498955,0.3568000197410583,3.1949143409729004,10000,39124.7622282505,0.5056201815605164,2.1512796878814697,0.47079998254776,2.372743368148804,50000 -1357.7409224510193,5.193975210189819,38285.63218688965,113225,0,38285.63218688965,0.3579000234603882,3.2229323387146,10000,39652.69889450073,0.5051419138908386,2.175563335418701,0.4692799746990204,2.409187078475952,50000 -1376.1895372867584,5.24059534072876,38795.65815782547,114739,0,38795.65815782547,0.3734000325202942,3.111175060272217,10000,40181.27318763733,0.5345184803009033,2.0138721466064453,0.4868199825286865,2.304708957672119,50000 -1393.7754967212677,5.281083822250366,39305.71879410744,116252,0,39305.71879410744,0.4410000145435333,2.634089708328247,10000,40709.015604019165,0.6103116869926453,1.59967303276062,0.5604000091552734,1.880994200706482,50000 -1411.296749830246,5.329135894775391,39815.93658471108,117765,0,39815.93658471108,0.4052000045776367,2.835646152496338,10000,41236.85900259018,0.5663065910339355,1.824598789215088,0.5275200009346008,2.038508892059326,50000 -1428.9782931804657,5.38359808921814,40326.08441853523,119278,0,40326.08441853523,0.2028000056743621,4.84499979019165,10000,41764.800423145294,0.2941246628761291,3.7996959686279297,0.2820599973201751,3.9288508892059326,50000 -1446.6527774333954,5.439541339874268,40836.241518974304,120791,0,40836.241518974304,0.3850000202655792,3.043755292892456,10000,42292.74313831329,0.5305325388908386,2.007243633270264,0.4994799792766571,2.1964006423950195,50000 -1465.3159244060516,5.488732576370239,41346.37443685532,122303,0,41346.37443685532,0.3849000036716461,3.013498067855835,10000,42821.643122434616,0.5705317258834839,1.820460081100464,0.5156199932098389,2.1237943172454834,50000 -1483.166562795639,5.5297324657440186,41856.57735204697,123817,0,41856.57735204697,0.4502000212669372,2.5984723567962646,10000,43349.79244160652,0.6350645422935486,1.4768481254577637,0.5727599859237671,1.826435089111328,50000 -1500.7979154586792,5.595268964767456,42366.60649180412,125330,0,42366.60649180412,0.4202000200748443,2.7875430583953857,10000,43877.57340598106,0.5938695669174194,1.693787932395935,0.5395399928092957,2.001538276672364,50000 -1518.3210427761078,5.644772291183472,42876.69231629372,126843,0,42876.69231629372,0.443200021982193,2.6522583961486816,10000,44405.28691577912,0.5989516973495483,1.6441659927368164,0.5536400079727173,1.9106814861297607,50000 -1535.9049890041351,5.6943440437316895,43386.7288172245,128356,0,43386.7288172245,0.4610000252723694,2.505321979522705,10000,44933.01505827904,0.6427175998687744,1.4331533908843994,0.5913599729537964,1.713654637336731,50000 -1553.898535490036,5.74742603302002,43896.84852051735,129869,0,43896.84852051735,0.4061000049114227,2.9258809089660645,10000,45461.23577904701,0.5605269074440002,1.889477491378784,0.5187999606132507,2.140009880065918,50000 -1571.618933916092,5.801463842391968,44406.79634475708,131383,0,44406.79634475708,0.4692000150680542,2.497609138488769,10000,45989.0142326355,0.6741071343421936,1.2939298152923584,0.5915399789810181,1.7248343229293823,50000 -1589.165275335312,5.855878114700317,44916.78594422341,132895,0,44916.78594422341,0.444100022315979,2.6210970878601074,10000,46516.66002821922,0.6228475570678711,1.5460199117660522,0.5662800073623657,1.8539618253707888,50000 -1607.0003921985626,5.908795595169067,45426.89608311653,134408,0,45426.89608311653,0.5121000409126282,2.2486801147460938,10000,47044.71316599846,0.6938576102256775,1.2104731798171997,0.6318199634552002,1.532360315322876,50000 -1624.6373193264008,5.960654020309448,45937.00923109055,135922,0,45937.00923109055,0.4917000234127044,2.346391439437866,10000,47572.57046985626,0.6696428656578064,1.3166732788085938,0.6161999702453613,1.6147363185882568,50000 -1642.9792296886444,6.0111939907073975,46447.15961503983,137436,0,46447.15961503983,0.4622000157833099,2.4913787841796875,10000,48101.16772198677,0.6374163031578064,1.4651330709457395,0.5859400033950806,1.7452064752578735,50000 -1660.8099303245544,6.053507328033447,46957.08819317818,138949,0,46957.08819317818,0.49590003490448,2.300657033920288,10000,48629.02605581284,0.6802853941917419,1.2585314512252808,0.6293399930000305,1.5424376726150513,50000 -1678.656478881836,6.13153076171875,47467.14528131485,140462,0,47467.14528131485,0.5039000511169434,2.279113292694092,10000,49157.06298923493,0.7210419178009033,1.0726977586746216,0.6291199922561646,1.5267126560211182,50000 -1696.3789055347445,6.181941986083984,47977.29451775551,141976,0,47977.29451775551,0.5092000365257263,2.2130398750305176,10000,49685.04183840752,0.7115154266357422,1.108195662498474,0.6423599720001221,1.4797382354736328,50000 -1714.0664644241333,6.234250068664551,48487.30615639687,143487,0,48487.30615639687,0.5218999981880188,2.1681156158447266,10000,50212.84715676308,0.7179726958274841,1.1000560522079468,0.6461799740791321,1.4582358598709106,50000 -1732.0806908607483,6.286098003387451,48997.7052898407,145001,0,48997.7052898407,0.5112000107765198,2.244545221328736,10000,50741.36831307411,0.7001753449440002,1.1895331144332886,0.639519989490509,1.4978569746017456,50000 -1749.8422105312347,6.343450784683228,49507.65221524239,146513,0,49507.65221524239,0.5179000496864319,2.226635932922364,10000,51269.19027972221,0.6998565196990967,1.1637591123580933,0.6406199932098389,1.486459493637085,50000 -1767.6306097507477,6.397595882415772,50017.61883044243,148026,0,50017.61883044243,0.524399995803833,2.186429262161255,10000,51797.05554533005,0.7155213356018066,1.1146916151046753,0.6486999988555908,1.4627606868743896,50000 -1785.4646661281586,6.454508543014526,50527.6418004036,149538,0,50527.6418004036,0.539900004863739,2.03213119506836,10000,52325.024513721466,0.7698501348495483,0.8804860711097717,0.6730799674987793,1.3303167819976809,50000 -1803.1257722377777,6.508534669876099,51037.73308491707,151051,0,51037.73308491707,0.5393000245094299,2.099137544631958,10000,52852.88480973244,0.7565170526504517,0.9239798188209534,0.6729199886322021,1.331039309501648,50000 -1821.800268173217,6.5628204345703125,51547.81756424904,152564,0,51547.81756424904,0.5270000100135803,2.163541555404663,10000,53381.75227236748,0.7428451776504517,0.9893755316734314,0.6662999987602234,1.3725179433822632,50000 -1839.209418058396,6.6128644943237305,52057.731224775314,154077,0,52057.731224775314,0.5506000518798828,2.063292980194092,10000,53909.18029880524,0.7563576102256775,0.9220529198646544,0.6801199913024902,1.3138034343719482,50000 -1857.059386253357,6.671154737472534,52567.773965358734,155590,0,52567.773965358734,0.555400013923645,2.034101963043213,10000,54437.1859266758,0.7622169852256775,0.8947357535362244,0.6869399547576904,1.2795292139053345,50000 -1874.6021373271944,6.727156400680542,53077.83566641808,157103,0,53077.83566641808,0.5601000189781189,2.017871379852295,10000,54964.90230464935,0.7747727632522583,0.8410025238990784,0.6900999546051025,1.268452286720276,50000 -1892.21878695488,6.782275915145874,53588.04853415489,158616,0,53588.04853415489,0.5782000422477722,1.9279481172561648,10000,55492.84319996834,0.8008609414100647,0.7219278812408447,0.6961399912834167,1.2478055953979492,50000 -1910.420448064804,6.837030410766602,54098.26779890061,160130,0,54098.26779890061,0.5812000036239624,1.8732895851135247,10000,56021.37393474579,0.8089724183082581,0.7101982831954956,0.7047399878501892,1.194737195968628,50000 -1928.0951828956604,6.881301641464233,54608.44234919548,161643,0,54608.44234919548,0.5758000016212463,1.9370954036712649,10000,56549.32446479797,0.7973134517669678,0.7421764135360718,0.702739953994751,1.2192460298538208,50000 -1946.0246217250824,6.949363708496094,55118.50452852249,163156,0,55118.50452852249,0.5887000560760498,1.8791732788085933,10000,57077.438390254974,0.8103076815605164,0.7060856819152832,0.7111200094223022,1.171807885169983,50000 -1963.8206391334527,7.007922172546387,55628.45368814469,164668,0,55628.45368814469,0.5937000513076782,1.8554983139038088,10000,57605.29715466499,0.8194754123687744,0.6554054021835327,0.7178800106048584,1.144294261932373,50000 -1981.6831741333008,7.071213245391846,56138.54439759255,166181,0,56138.54439759255,0.6082000136375427,1.8160563707351685,10000,58133.3680229187,0.8342036008834839,0.6015270948410034,0.727840006351471,1.113991379737854,50000 -2000.268955469132,7.146226406097412,56648.53292417526,167694,0,56648.53292417526,0.5982000231742859,1.8255773782730105,10000,58662.07354712486,0.8501076102256775,0.546515941619873,0.7238799929618835,1.1284018754959106,50000 -2018.9037404060364,7.192639112472534,57158.69640493393,169207,0,57158.69640493393,0.6070000529289246,1.7918938398361206,10000,59190.97307920456,0.8522600531578064,0.5386980772018433,0.7333599925041199,1.0911046266555786,50000 -2036.7781956195831,7.252174377441406,57668.74922180176,170720,0,57668.74922180176,0.6105000376701355,1.766804337501526,10000,59719.0146727562,0.8514229655265808,0.5230114459991455,0.7345799803733826,1.079545021057129,50000 -2054.613529920578,7.311715364456177,58178.75817775726,172233,0,58178.75817775726,0.6189000010490417,1.7463383674621582,10000,60246.97335195541,0.8605110049247742,0.5039197206497192,0.7382400035858154,1.062277913093567,50000 -2072.3738420009613,7.36877703666687,58688.94418215752,173746,0,58688.94418215752,0.6200000047683716,1.7367162704467771,10000,60775.03083705902,0.8649752736091614,0.4853964745998382,0.7422399520874023,1.0512233972549438,50000 -2090.135448217392,7.425171852111816,59199.13371539116,175260,0,59199.13371539116,0.6203000545501709,1.7439327239990234,10000,61303.094465732574,0.8682437539100647,0.4675627946853637,0.7414199709892273,1.049114465713501,50000 -2108.078211784363,7.485778331756592,59709.20754027367,176772,0,59709.20754027367,0.6271000504493713,1.7231155633926392,10000,61831.22635412216,0.8796038031578064,0.4347046613693237,0.7447999715805054,1.036797285079956,50000 -2125.9084413051605,7.544202566146851,60219.31211185455,178285,0,60219.31211185455,0.6279000043869019,1.7289866209030151,10000,62359.2754983902,0.8785474896430969,0.4288856983184814,0.746399998664856,1.0343586206436155,50000 -2143.679017782212,7.606815099716186,60729.40054988861,179797,0,60729.40054988861,0.6247000098228455,1.7195037603378296,10000,62887.25140142441,0.8817960619926453,0.4216182827949524,0.7474600076675415,1.027718424797058,50000 -2161.401031255722,7.670835256576538,61239.36246538162,181309,0,61239.36246538162,0.6256000399589539,1.7120555639266968,10000,63415.05498671532,0.8843869566917419,0.4094845950603485,0.7476199865341187,1.02272629737854,50000 -2179.212869882584,7.729203224182129,61749.32241177559,182821,0,61749.32241177559,0.6282000541687012,1.714728593826294,10000,63942.94005918503,0.8831712007522583,0.4109700918197632,0.7479400038719177,1.0229642391204834,50000 -2196.8955330848694,7.789600133895874,62259.33857059479,184333,0,62259.33857059479,0.6284000277519226,1.7110487222671509,10000,64470.75457596779,0.8838687539100647,0.4128356277942657,0.7488600015640259,1.020200252532959,50000 -2214.4908051490784,7.853912591934204,62769.32061004639,185845,0,62769.32061004639,0.6296000480651855,1.7125041484832764,10000,64998.45347905159,0.8853037357330322,0.4091644287109375,0.7495599985122681,1.0208008289337158,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/measurements.csv deleted file mode 100644 index aad9915e0..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1992 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6558059,6.934149,,,,,,,,,,,,,, -1,,,0.0010961415246129,6.912662982940674,0.0007599999662488,6.913174629211426,50000.0,0.0006000000284984,6.9125494956970215,10000.0,32.62680268287659,50.65618777275085,32.62680268287659,18.02929162979126,0.0,0.0 -100,0.7810716,6.672268,,,,,,,,,,,,,, -200,1.0219444,6.297262,,,,,,,,,,,,,, -300,6.2087893,6.054476,,,,,,,,,,,,,, -400,4.5331187,5.6799917,,,,,,,,,,,,,, -500,5.169176,5.634752,,,,,,,,,,,,,, -600,6.997705,5.380351,,,,,,,,,,,,,, -700,4.0257077,5.267664,,,,,,,,,,,,,, -800,3.036806,5.053461,,,,,,,,,,,,,, -900,6.0927377,4.8379254,,,,,,,,,,,,,, -1000,4.313483,4.7127924,,,,,,,,,,,,,, -1100,4.073485,4.5392456,,,,,,,,,,,,,, -1200,2.4220073,4.396843,,,,,,,,,,,,,, -1300,2.994902,4.337839,,,,,,,,,,,,,, -1400,2.7875988,4.076362,,,,,,,,,,,,,, -1498,,,0.2016103267669677,4.02529239654541,0.1814000010490417,4.143918514251709,50000.0,0.1358000040054321,4.661078929901123,10000.0,542.689546585083,578.9105360507965,542.689546585083,36.14868259429932,0.0185163021087646,0.0 -1500,2.5834363,4.0971704,,,,,,,,,,,,,, -1600,2.5362196,3.9006062,,,,,,,,,,,,,, -1700,3.0322263,3.7790582,,,,,,,,,,,,,, -1800,2.1933367,3.77353,,,,,,,,,,,,,, -1900,2.1986265,3.7950382,,,,,,,,,,,,,, -2000,2.167945,3.7234352,,,,,,,,,,,,,, -2100,2.6661234,3.5189586,,,,,,,,,,,,,, -2200,1.643664,3.5774548,,,,,,,,,,,,,, -2300,1.4902219,3.4548335,,,,,,,,,,,,,, -2400,1.1049017,3.4938786,,,,,,,,,,,,,, -2500,2.0165987,3.3232234,,,,,,,,,,,,,, -2600,1.1022348,3.2839165,,,,,,,,,,,,,, -2700,2.0350766,3.2241626,,,,,,,,,,,,,, -2800,1.2950163,3.1968985,,,,,,,,,,,,,, -2900,1.494967,3.1876187,,,,,,,,,,,,,, -2997,,,0.3574816584587097,2.9869542121887207,0.3305200040340423,3.145972728729248,50000.0,0.2497000098228454,3.80063271522522,10000.0,1052.8628568649292,1107.221934556961,1052.8628568649292,54.20573401451111,0.0462877750396728,0.0 -3000,1.3317198,3.0252867,,,,,,,,,,,,,, -3100,1.148226,3.1465838,,,,,,,,,,,,,, -3200,1.2535005,3.1704586,,,,,,,,,,,,,, -3300,1.2909877,3.001973,,,,,,,,,,,,,, -3400,0.99989504,3.1714962,,,,,,,,,,,,,, -3500,1.4223672,3.0832045,,,,,,,,,,,,,, -3600,0.8520223,3.0378304,,,,,,,,,,,,,, -3700,1.0149817,2.9314656,,,,,,,,,,,,,, -3800,0.9228158,2.9042933,,,,,,,,,,,,,, -3900,1.31578,2.8699548,,,,,,,,,,,,,, -4000,0.8089478,3.0122123,,,,,,,,,,,,,, -4100,0.7885021,2.916813,,,,,,,,,,,,,, -4200,0.9977814,2.9631293,,,,,,,,,,,,,, -4300,0.7248213,2.7879226,,,,,,,,,,,,,, -4400,0.80387384,2.805394,,,,,,,,,,,,,, -4497,,,0.3845065236091614,2.834846258163452,0.3623199760913849,2.950068712234497,50000.0,0.2798000276088714,3.6492860317230233,10000.0,1562.7982881069183,1635.2103700637815,1562.7982881069183,72.17443251609802,0.0769317150115966,0.0 -4500,0.7877586,2.7325804,,,,,,,,,,,,,, -4600,1.2071753,2.7381735,,,,,,,,,,,,,, -4700,0.9939976,2.742183,,,,,,,,,,,,,, -4800,0.88558775,2.683934,,,,,,,,,,,,,, -4900,0.7930295,2.7896836,,,,,,,,,,,,,, -5000,0.71104026,2.647992,,,,,,,,,,,,,, -5100,1.1151904,2.6948514,,,,,,,,,,,,,, -5200,1.0309209,2.7543397,,,,,,,,,,,,,, -5300,0.83406264,2.8780878,,,,,,,,,,,,,, -5400,0.8980828,2.621498,,,,,,,,,,,,,, -5500,1.155366,2.6752663,,,,,,,,,,,,,, -5600,0.89348304,2.7187543,,,,,,,,,,,,,, -5700,0.90105927,2.5813732,,,,,,,,,,,,,, -5800,0.8883833,2.7209044,,,,,,,,,,,,,, -5900,1.0386847,2.6994154,,,,,,,,,,,,,, -5999,,,0.3507254421710968,3.0492608547210693,0.3096399903297424,3.407315969467163,50000.0,0.2388000041246414,4.093076705932617,10000.0,2072.8568086624146,2162.92281293869,2072.8568086624146,89.74597930908203,0.1048753261566162,0.0 -6000,0.8566075,2.655824,,,,,,,,,,,,,, -6100,1.0328885,2.7919736,,,,,,,,,,,,,, -6200,1.1114051,2.597199,,,,,,,,,,,,,, -6300,0.8957926,2.628284,,,,,,,,,,,,,, -6400,0.92908525,2.5813115,,,,,,,,,,,,,, -6500,0.8604776,2.4763522,,,,,,,,,,,,,, -6600,0.8485209,2.7306228,,,,,,,,,,,,,, -6700,0.8928013,2.4737082,,,,,,,,,,,,,, -6800,0.86538786,2.6161494,,,,,,,,,,,,,, -6900,1.1169052,2.5412652,,,,,,,,,,,,,, -7000,1.0213157,2.4714158,,,,,,,,,,,,,, -7100,0.9949207,2.5565023,,,,,,,,,,,,,, -7200,1.0219148,2.5620716,,,,,,,,,,,,,, -7300,0.9257643,2.5300949,,,,,,,,,,,,,, -7400,0.9879349,2.4789932,,,,,,,,,,,,,, -7500,0.94037074,2.5070217,,,,,,,,,,,,,, -7501,,,0.0716478005051612,5.990123271942139,0.0696799978613853,6.023623466491699,50000.0,0.0470000021159648,6.509797096252441,10000.0,2582.8044941425323,2691.0967836380005,2582.8044941425323,107.88517737388612,0.1373639106750488,0.0 -7600,1.099115,2.5019689,,,,,,,,,,,,,, -7700,0.9435857,2.6873193,,,,,,,,,,,,,, -7800,1.0454364,2.5313942,,,,,,,,,,,,,, -7900,1.1603632,2.702977,,,,,,,,,,,,,, -8000,0.9601969,2.5123158,,,,,,,,,,,,,, -8100,0.9489946,2.5103362,,,,,,,,,,,,,, -8200,0.9429532,2.490365,,,,,,,,,,,,,, -8300,1.0100548,2.5079978,,,,,,,,,,,,,, -8400,1.0318152,2.4810963,,,,,,,,,,,,,, -8500,0.944565,2.4333324,,,,,,,,,,,,,, -8600,1.0615335,2.5821443,,,,,,,,,,,,,, -8700,1.0667629,2.5666568,,,,,,,,,,,,,, -8800,0.9972908,2.512037,,,,,,,,,,,,,, -8900,0.9774855,2.5071602,,,,,,,,,,,,,, -9000,1.0276476,2.436366,,,,,,,,,,,,,, -9005,,,0.4273158311843872,2.571629285812378,0.3967399895191192,2.782094717025757,50000.0,0.2943000197410583,3.5717177391052246,10000.0,3092.775918245316,3219.6616904735565,3092.775918245316,126.39080882072447,0.1694235801696777,0.0 -9100,0.97465265,2.4246988,,,,,,,,,,,,,, -9200,1.0398958,2.5793824,,,,,,,,,,,,,, -9300,0.89451987,2.4392638,,,,,,,,,,,,,, -9400,1.0572649,2.4877234,,,,,,,,,,,,,, -9500,1.0361509,2.5917592,,,,,,,,,,,,,, -9600,0.9599092,2.4174922,,,,,,,,,,,,,, -9700,1.0208335,2.5740652,,,,,,,,,,,,,, -9800,1.0143667,2.4825826,,,,,,,,,,,,,, -9900,1.021359,2.462029,,,,,,,,,,,,,, -10000,0.9485484,2.371998,,,,,,,,,,,,,, -10100,0.9222192,2.6402273,,,,,,,,,,,,,, -10200,0.9690929,2.5753195,,,,,,,,,,,,,, -10300,0.92077404,2.3652525,,,,,,,,,,,,,, -10400,1.0128121,2.464471,,,,,,,,,,,,,, -10500,1.1199573,2.47365,,,,,,,,,,,,,, -10509,,,0.1981425285339355,4.607287406921387,0.1787599921226501,4.729046821594238,50000.0,0.133200004696846,5.446406364440918,10000.0,3602.7280580997467,3747.433574438095,3602.7280580997467,144.12997436523438,0.1947331428527832,0.0 -10600,1.0474057,2.392128,,,,,,,,,,,,,, -10700,1.1597993,2.5113692,,,,,,,,,,,,,, -10800,1.0132638,2.4861922,,,,,,,,,,,,,, -10900,1.0737923,2.5502138,,,,,,,,,,,,,, -11000,1.0645865,2.579109,,,,,,,,,,,,,, -11100,0.9957958,2.387385,,,,,,,,,,,,,, -11200,0.9825701,2.3745625,,,,,,,,,,,,,, -11300,1.1347897,2.4194124,,,,,,,,,,,,,, -11400,1.0781779,2.3091755,,,,,,,,,,,,,, -11500,1.0920986,2.41461,,,,,,,,,,,,,, -11600,1.0194137,2.2947633,,,,,,,,,,,,,, -11700,1.0071566,2.4503765,,,,,,,,,,,,,, -11800,0.9649272,2.3610842,,,,,,,,,,,,,, -11900,1.0210148,2.4537606,,,,,,,,,,,,,, -12000,1.065958,2.5009007,,,,,,,,,,,,,, -12015,,,0.0677016898989677,6.2029805183410645,0.0630199983716011,6.263773441314697,50000.0,0.0442000031471252,6.679653644561768,10000.0,4112.816953420639,4275.339448213577,4112.816953420639,161.86528515815735,0.223168134689331,0.0 -12100,1.0937594,2.421866,,,,,,,,,,,,,, -12200,1.140674,2.5501761,,,,,,,,,,,,,, -12300,1.1027385,2.4750075,,,,,,,,,,,,,, -12400,0.9461284,2.310285,,,,,,,,,,,,,, -12500,0.9514884,2.3729105,,,,,,,,,,,,,, -12600,1.0623171,2.3753834,,,,,,,,,,,,,, -12700,0.910302,2.305409,,,,,,,,,,,,,, -12800,1.2190113,2.4641442,,,,,,,,,,,,,, -12900,0.9553207,2.4781833,,,,,,,,,,,,,, -13000,0.9125773,2.3634403,,,,,,,,,,,,,, -13100,0.95432746,2.4457488,,,,,,,,,,,,,, -13200,0.94127595,2.4047644,,,,,,,,,,,,,, -13300,1.1091976,2.4344888,,,,,,,,,,,,,, -13400,1.0197641,2.4313972,,,,,,,,,,,,,, -13500,1.0606093,2.3569534,,,,,,,,,,,,,, -13520,,,0.3900669515132904,2.867115497589112,0.3664799928665161,3.0192630290985107,50000.0,0.2696000039577484,3.811286449432373,10000.0,4622.73382973671,4803.095787525177,4622.73382973671,179.61972188949585,0.2544693946838379,0.0 -13600,1.0038626,2.3029587,,,,,,,,,,,,,, -13700,1.0515902,2.4313204,,,,,,,,,,,,,, -13800,1.1017478,2.3531826,,,,,,,,,,,,,, -13900,1.0851378,2.3192074,,,,,,,,,,,,,, -14000,0.9746529,2.3771532,,,,,,,,,,,,,, -14100,1.048719,2.379042,,,,,,,,,,,,,, -14200,1.0940796,2.4496703,,,,,,,,,,,,,, -14300,1.1306407,2.3772683,,,,,,,,,,,,,, -14400,1.071237,2.4496977,,,,,,,,,,,,,, -14500,1.0061893,2.5110705,,,,,,,,,,,,,, -14600,1.067338,2.3351514,,,,,,,,,,,,,, -14700,1.0433359,2.345097,,,,,,,,,,,,,, -14800,1.1203321,2.2918582,,,,,,,,,,,,,, -14900,0.9712454,2.2694998,,,,,,,,,,,,,, -15000,1.0379075,2.4613008,,,,,,,,,,,,,, -15027,,,0.2157804518938064,4.200702667236328,0.1911999881267547,4.499402523040772,50000.0,0.1473000049591064,5.18111515045166,10000.0,5132.905265569687,5332.158295869827,5132.905265569687,198.42433190345764,0.2853665351867676,0.0 -15100,0.9873829,2.3725808,,,,,,,,,,,,,, -15200,1.0940052,2.3361802,,,,,,,,,,,,,, -15300,0.9785429,2.3857808,,,,,,,,,,,,,, -15400,0.9912422,2.283329,,,,,,,,,,,,,, -15500,1.0094783,2.413804,,,,,,,,,,,,,, -15600,0.94887996,2.3305962,,,,,,,,,,,,,, -15700,1.0850699,2.3060598,,,,,,,,,,,,,, -15800,0.8975415,2.3720999,,,,,,,,,,,,,, -15900,1.0207454,2.3711543,,,,,,,,,,,,,, -16000,1.1127238,2.3899643,,,,,,,,,,,,,, -16100,1.028219,2.4158092,,,,,,,,,,,,,, -16200,1.0483499,2.4223425,,,,,,,,,,,,,, -16300,1.122263,2.3429673,,,,,,,,,,,,,, -16400,1.1079483,2.432877,,,,,,,,,,,,,, -16500,1.0681388,2.3409584,,,,,,,,,,,,,, -16534,,,0.2916533648967743,3.804319143295288,0.2672999799251556,4.0446624755859375,50000.0,0.1986000090837478,4.769838333129883,10000.0,5643.002831459045,5860.815819740295,5643.002831459045,216.898282289505,0.3170950412750244,0.0 -16600,0.9449347,2.3410923,,,,,,,,,,,,,, -16700,0.9917714,2.304698,,,,,,,,,,,,,, -16800,1.000176,2.3373036,,,,,,,,,,,,,, -16900,1.0125871,2.2860804,,,,,,,,,,,,,, -17000,1.0981817,2.4609659,,,,,,,,,,,,,, -17100,1.1717176,2.3922396,,,,,,,,,,,,,, -17200,1.039264,2.291299,,,,,,,,,,,,,, -17300,1.0986269,2.2669497,,,,,,,,,,,,,, -17400,1.0093448,2.4771972,,,,,,,,,,,,,, -17500,1.0548619,2.2325938,,,,,,,,,,,,,, -17600,1.0074437,2.2383926,,,,,,,,,,,,,, -17700,1.2162435,2.256122,,,,,,,,,,,,,, -17800,0.9807532,2.2865186,,,,,,,,,,,,,, -17900,0.9901072,2.2324123,,,,,,,,,,,,,, -18000,1.0256228,2.3116832,,,,,,,,,,,,,, -18042,,,0.2917530238628387,3.487156391143799,0.2727600038051605,3.6579201221466064,50000.0,0.1946000158786773,4.378692626953125,10000.0,6152.994265079498,6388.826258897781,6152.994265079498,234.83249616622925,0.34627366065979,0.0 -18100,1.0185726,2.4697104,,,,,,,,,,,,,, -18200,0.98772115,2.3453846,,,,,,,,,,,,,, -18300,1.0418397,2.3973875,,,,,,,,,,,,,, -18400,1.1902057,2.2929528,,,,,,,,,,,,,, -18500,1.064385,2.3580887,,,,,,,,,,,,,, -18600,1.0582111,2.2414384,,,,,,,,,,,,,, -18700,1.1682161,2.542976,,,,,,,,,,,,,, -18800,1.0523765,2.2872257,,,,,,,,,,,,,, -18900,1.0284337,2.4310658,,,,,,,,,,,,,, -19000,1.2964822,2.368856,,,,,,,,,,,,,, -19100,1.0785936,2.4019938,,,,,,,,,,,,,, -19200,1.076006,2.356125,,,,,,,,,,,,,, -19300,1.1083624,2.3336632,,,,,,,,,,,,,, -19400,1.05227,2.3237858,,,,,,,,,,,,,, -19500,1.1389114,2.4712214,,,,,,,,,,,,,, -19550,,,0.2448979616165161,4.03620719909668,0.2307799905538559,4.1693949699401855,50000.0,0.1752000153064727,4.855306148529053,10000.0,6663.09573674202,6916.799047708511,6663.09573674202,252.6183815002441,0.3780183792114258,0.0 -19600,1.0979705,2.4795167,,,,,,,,,,,,,, -19700,1.0774912,2.3174932,,,,,,,,,,,,,, -19800,1.2031554,2.263986,,,,,,,,,,,,,, -19900,0.97164524,2.3551064,,,,,,,,,,,,,, -20000,1.091394,2.1976106,,,,,,,,,,,,,, -20100,0.9761988,2.3149126,,,,,,,,,,,,,, -20200,1.1353391,2.4374185,,,,,,,,,,,,,, -20300,1.1180017,2.3264174,,,,,,,,,,,,,, -20400,0.96665835,2.2606976,,,,,,,,,,,,,, -20500,1.064656,2.4099872,,,,,,,,,,,,,, -20600,1.1864136,2.351471,,,,,,,,,,,,,, -20700,1.048781,2.2540178,,,,,,,,,,,,,, -20800,1.0766588,2.373395,,,,,,,,,,,,,, -20900,1.0684878,2.1985319,,,,,,,,,,,,,, -21000,1.0192348,2.374344,,,,,,,,,,,,,, -21059,,,0.1134805455803871,5.853148460388184,0.1053599938750267,6.068696022033691,50000.0,0.0834000036120414,6.58660888671875,10000.0,7173.257848501205,7445.002496004105,7173.257848501205,270.57379508018494,0.4093115329742431,0.0 -21100,0.9707028,2.2807436,,,,,,,,,,,,,, -21200,1.0422069,2.372676,,,,,,,,,,,,,, -21300,1.1078361,2.3672366,,,,,,,,,,,,,, -21400,1.1069989,2.3831544,,,,,,,,,,,,,, -21500,1.1681168,2.3967814,,,,,,,,,,,,,, -21600,1.115913,2.333932,,,,,,,,,,,,,, -21700,1.0718212,2.280603,,,,,,,,,,,,,, -21800,1.0520818,2.3001819,,,,,,,,,,,,,, -21900,1.1123685,2.3549292,,,,,,,,,,,,,, -22000,1.181628,2.314123,,,,,,,,,,,,,, -22100,1.1109927,2.365902,,,,,,,,,,,,,, -22200,1.175576,2.3510032,,,,,,,,,,,,,, -22300,0.99202347,2.2398212,,,,,,,,,,,,,, -22400,1.0718849,2.3263872,,,,,,,,,,,,,, -22500,0.97199976,2.1761942,,,,,,,,,,,,,, -22567,,,0.2086256295442581,4.400444030761719,0.1987399905920028,4.497589111328125,50000.0,0.142300009727478,5.258711814880371,10000.0,7683.180972576141,7972.925594329834,7683.180972576141,288.48663353919983,0.4424080848693847,0.0 -22600,1.0295544,2.2672234,,,,,,,,,,,,,, -22700,1.027496,2.395813,,,,,,,,,,,,,, -22800,1.1650957,2.4012737,,,,,,,,,,,,,, -22900,1.1405773,2.2427225,,,,,,,,,,,,,, -23000,1.0902079,2.357241,,,,,,,,,,,,,, -23100,1.0514164,2.2702622,,,,,,,,,,,,,, -23200,1.3251231,2.4317298,,,,,,,,,,,,,, -23300,1.1281388,2.2533495,,,,,,,,,,,,,, -23400,1.1125504,2.3736103,,,,,,,,,,,,,, -23500,1.1120303,2.3059857,,,,,,,,,,,,,, -23600,1.0554506,2.39809,,,,,,,,,,,,,, -23700,1.1174233,2.409859,,,,,,,,,,,,,, -23800,1.0968015,2.3764806,,,,,,,,,,,,,, -23900,1.0592251,2.4036632,,,,,,,,,,,,,, -24000,1.2600749,2.3632078,,,,,,,,,,,,,, -24076,,,0.2997648119926452,3.4221742153167725,0.2604199945926666,3.78203558921814,50000.0,0.1926000118255615,4.453147411346436,10000.0,8193.239090681076,8500.813712358475,8193.239090681076,306.22909712791443,0.473712682723999,0.0 -24100,1.0789695,2.2978423,,,,,,,,,,,,,, -24200,0.9958641,2.2708006,,,,,,,,,,,,,, -24300,1.0042092,2.3628361,,,,,,,,,,,,,, -24400,1.0712016,2.3618188,,,,,,,,,,,,,, -24500,1.189136,2.392353,,,,,,,,,,,,,, -24600,1.1005512,2.316404,,,,,,,,,,,,,, -24700,1.0644518,2.2113745,,,,,,,,,,,,,, -24800,0.99538153,2.2714725,,,,,,,,,,,,,, -24900,1.0779419,2.3376675,,,,,,,,,,,,,, -25000,1.1105853,2.342821,,,,,,,,,,,,,, -25100,1.0543197,2.3498342,,,,,,,,,,,,,, -25200,1.169304,2.4521337,,,,,,,,,,,,,, -25300,1.1433587,2.415439,,,,,,,,,,,,,, -25400,1.0392652,2.3399696,,,,,,,,,,,,,, -25500,1.1395037,2.3616586,,,,,,,,,,,,,, -25584,,,0.085339605808258,5.972334861755371,0.0833600014448165,5.932559967041016,50000.0,0.0534000024199485,6.625743389129639,10000.0,8703.223582744598,9028.702632427216,8703.223582744598,324.04403138160706,0.506464958190918,0.0 -25600,1.1993698,2.286853,,,,,,,,,,,,,, -25700,1.1061746,2.4044654,,,,,,,,,,,,,, -25800,1.071452,2.2403588,,,,,,,,,,,,,, -25900,1.0652562,2.2167883,,,,,,,,,,,,,, -26000,1.1243747,2.463367,,,,,,,,,,,,,, -26100,1.1590755,2.220294,,,,,,,,,,,,,, -26200,1.1131712,2.320201,,,,,,,,,,,,,, -26300,1.0815372,2.352457,,,,,,,,,,,,,, -26400,1.0949832,2.2328794,,,,,,,,,,,,,, -26500,1.1783352,2.3908453,,,,,,,,,,,,,, -26600,1.0732632,2.334608,,,,,,,,,,,,,, -26700,1.0997908,2.3099005,,,,,,,,,,,,,, -26800,1.1756449,2.3244112,,,,,,,,,,,,,, -26900,1.1250345,2.1180644,,,,,,,,,,,,,, -27000,1.0217075,2.3917086,,,,,,,,,,,,,, -27093,,,0.2820471823215484,3.6558451652526855,0.2640599906444549,3.76244592666626,50000.0,0.2060000151395797,4.422956943511963,10000.0,9213.1674451828,9556.718665838242,9213.1674451828,342.0292069911957,0.5385315418243408,0.0 -27100,1.116375,2.2692463,,,,,,,,,,,,,, -27200,1.0626343,2.3911316,,,,,,,,,,,,,, -27300,1.017978,2.3950675,,,,,,,,,,,,,, -27400,1.1536828,2.3466363,,,,,,,,,,,,,, -27500,1.104717,2.3077772,,,,,,,,,,,,,, -27600,1.1092652,2.3060308,,,,,,,,,,,,,, -27700,1.1197565,2.2707388,,,,,,,,,,,,,, -27800,1.2013221,2.3985596,,,,,,,,,,,,,, -27900,1.1000272,2.4081564,,,,,,,,,,,,,, -28000,1.1379092,2.3373113,,,,,,,,,,,,,, -28100,1.1258452,2.3428931,,,,,,,,,,,,,, -28200,1.0571623,2.3767464,,,,,,,,,,,,,, -28300,1.0271299,2.288055,,,,,,,,,,,,,, -28400,1.1150073,2.2758927,,,,,,,,,,,,,, -28500,1.0979187,2.1412077,,,,,,,,,,,,,, -28600,1.0915543,2.3586667,,,,,,,,,,,,,, -28602,,,0.213428720831871,4.641663551330566,0.1976799964904785,4.830415725708008,50000.0,0.1536000072956085,5.504821300506592,10000.0,9723.193513393402,10084.52348947525,9723.193513393402,359.71742606163025,0.5759317874908447,0.0 -28700,1.0949657,2.3782256,,,,,,,,,,,,,, -28800,1.0322657,2.2555926,,,,,,,,,,,,,, -28900,1.0983279,2.3379314,,,,,,,,,,,,,, -29000,0.98607,2.2355185,,,,,,,,,,,,,, -29100,0.98770714,2.304498,,,,,,,,,,,,,, -29200,1.079972,2.2597716,,,,,,,,,,,,,, -29300,1.0548759,2.3814282,,,,,,,,,,,,,, -29400,1.0997959,2.3834147,,,,,,,,,,,,,, -29500,1.195382,2.3796964,,,,,,,,,,,,,, -29600,1.1641672,2.3731542,,,,,,,,,,,,,, -29700,1.010184,2.1881323,,,,,,,,,,,,,, -29800,1.1987936,2.2391195,,,,,,,,,,,,,, -29900,1.123799,2.363965,,,,,,,,,,,,,, -30000,1.1288084,2.4220166,,,,,,,,,,,,,, -30048,,,0.1551937162876129,4.965692043304443,0.1513599902391433,5.025054454803467,50000.0,0.1106000021100044,5.6790618896484375,10000.0,10233.195832252502,10612.37854552269,10233.195832252502,377.4810211658478,0.612412691116333,0.0 -30100,1.1544608,2.3563035,,,,,,,,,,,,,, -30200,1.1591035,2.4039245,,,,,,,,,,,,,, -30300,1.0002372,2.2938943,,,,,,,,,,,,,, -30400,1.1327547,2.2654727,,,,,,,,,,,,,, -30500,1.0246985,2.2876663,,,,,,,,,,,,,, -30600,1.1280506,2.3459637,,,,,,,,,,,,,, -30700,1.1205949,2.288369,,,,,,,,,,,,,, -30800,1.071229,2.2942553,,,,,,,,,,,,,, -30900,1.0236107,2.3865879,,,,,,,,,,,,,, -31000,1.0709355,2.3052044,,,,,,,,,,,,,, -31100,1.4029585,2.3144567,,,,,,,,,,,,,, -31200,1.2679255,2.2794724,,,,,,,,,,,,,, -31300,1.0606346,2.2813723,,,,,,,,,,,,,, -31400,1.035927,2.3528264,,,,,,,,,,,,,, -31500,1.0408114,2.3285494,,,,,,,,,,,,,, -31558,,,0.2456154227256775,4.217416763305664,0.2311199903488159,4.373788833618164,50000.0,0.1656000018119812,5.351226806640625,10000.0,10743.147997140884,11140.925688028336,10743.147997140884,395.98668384552,0.6475231647491455,0.0 -31600,1.0555074,2.157215,,,,,,,,,,,,,, -31700,1.1237549,2.267316,,,,,,,,,,,,,, -31800,1.2897207,2.30998,,,,,,,,,,,,,, -31900,1.3032197,2.2732558,,,,,,,,,,,,,, -32000,1.2939382,2.248088,,,,,,,,,,,,,, -32100,1.0316403,2.3294709,,,,,,,,,,,,,, -32200,1.1502404,2.3241706,,,,,,,,,,,,,, -32300,1.1583802,2.1833596,,,,,,,,,,,,,, -32400,1.0439909,2.3581135,,,,,,,,,,,,,, -32500,1.068956,2.2415082,,,,,,,,,,,,,, -32600,1.0938823,2.2903588,,,,,,,,,,,,,, -32700,0.9715568,2.3390808,,,,,,,,,,,,,, -32800,1.1782179,2.3004375,,,,,,,,,,,,,, -32900,1.1932169,2.2558184,,,,,,,,,,,,,, -33000,1.0532748,2.2582164,,,,,,,,,,,,,, -33070,,,0.1723333895206451,4.832897663116455,0.1630599945783615,4.949711322784424,50000.0,0.1248000040650367,5.6210150718688965,10000.0,11253.401313781738,11669.21297764778,11253.401313781738,413.9356341362,0.678107738494873,0.0 -33100,1.1389717,2.376559,,,,,,,,,,,,,, -33200,1.115522,2.2948995,,,,,,,,,,,,,, -33300,1.0459621,2.2491796,,,,,,,,,,,,,, -33400,1.1600183,2.4015484,,,,,,,,,,,,,, -33500,1.1371484,2.3329043,,,,,,,,,,,,,, -33600,1.1726581,2.4977608,,,,,,,,,,,,,, -33700,1.2783016,2.1190243,,,,,,,,,,,,,, -33800,1.1212459,2.3241227,,,,,,,,,,,,,, -33900,1.0257045,2.1600966,,,,,,,,,,,,,, -34000,1.037919,2.3263931,,,,,,,,,,,,,, -34100,1.1366444,2.2776477,,,,,,,,,,,,,, -34200,1.0941913,2.139445,,,,,,,,,,,,,, -34300,1.0747594,2.3645265,,,,,,,,,,,,,, -34400,1.0972068,2.2516875,,,,,,,,,,,,,, -34500,1.074216,2.1777887,,,,,,,,,,,,,, -34580,,,0.136957898736,5.655052661895752,0.125459998846054,5.8631768226623535,50000.0,0.096000000834465,6.327773094177246,10000.0,11763.362285137177,12197.119595527647,11763.362285137177,431.7884068489074,0.7146234512329102,0.0 -34600,1.1294402,2.220941,,,,,,,,,,,,,, -34700,1.0914139,2.238156,,,,,,,,,,,,,, -34800,1.338976,2.4665492,,,,,,,,,,,,,, -34900,1.1925558,2.3134909,,,,,,,,,,,,,, -35000,1.2664684,2.2926567,,,,,,,,,,,,,, -35100,1.1004455,2.2650661,,,,,,,,,,,,,, -35200,1.0206324,2.276114,,,,,,,,,,,,,, -35300,1.0592505,2.2628965,,,,,,,,,,,,,, -35400,1.2223598,2.2262216,,,,,,,,,,,,,, -35500,1.1191838,2.2622023,,,,,,,,,,,,,, -35600,1.2657098,2.4009461,,,,,,,,,,,,,, -35700,1.1369945,2.2067828,,,,,,,,,,,,,, -35800,1.1393147,2.2343774,,,,,,,,,,,,,, -35900,1.30377,2.1880686,,,,,,,,,,,,,, -36000,1.1360898,2.1993601,,,,,,,,,,,,,, -36091,,,0.2906768023967743,3.589463233947754,0.2702600061893463,3.757100105285645,50000.0,0.1918000131845474,4.572559833526611,10000.0,12273.451851844788,12725.135549068453,12273.451851844788,449.6273202896118,0.7484467029571533,0.0 -36100,1.0979759,2.303823,,,,,,,,,,,,,, -36200,1.1668656,2.2767997,,,,,,,,,,,,,, -36300,1.1803807,2.3283694,,,,,,,,,,,,,, -36400,1.1424888,2.338965,,,,,,,,,,,,,, -36500,1.1535966,2.3271587,,,,,,,,,,,,,, -36600,1.1808726,2.3769968,,,,,,,,,,,,,, -36700,1.0505043,2.1962671,,,,,,,,,,,,,, -36800,1.0740476,2.2846541,,,,,,,,,,,,,, -36900,1.0509635,2.2417471,,,,,,,,,,,,,, -37000,1.1082019,2.3102822,,,,,,,,,,,,,, -37100,1.0590029,2.3095098,,,,,,,,,,,,,, -37200,1.1747851,2.3052504,,,,,,,,,,,,,, -37300,1.3064716,2.2841504,,,,,,,,,,,,,, -37400,1.2806902,2.2428744,,,,,,,,,,,,,, -37500,1.1813221,2.3916862,,,,,,,,,,,,,, -37600,1.1973671,2.2435267,,,,,,,,,,,,,, -37601,,,0.3878348171710968,2.8524887561798096,0.3601000010967254,3.0229721069335938,50000.0,0.2678000032901764,3.7928645610809326,10000.0,12783.49753022194,13252.958825826645,12783.49753022194,467.3155233860016,0.7832720279693604,0.0 -37700,1.2119406,2.3685572,,,,,,,,,,,,,, -37800,1.1633703,2.2911315,,,,,,,,,,,,,, -37900,1.1171362,2.4215991,,,,,,,,,,,,,, -38000,1.0906887,2.3448858,,,,,,,,,,,,,, -38100,1.1679217,2.3546202,,,,,,,,,,,,,, -38200,1.118729,2.168219,,,,,,,,,,,,,, -38300,1.165508,2.4191086,,,,,,,,,,,,,, -38400,1.1091375,2.3205683,,,,,,,,,,,,,, -38500,1.2911767,2.383605,,,,,,,,,,,,,, -38600,1.2825505,2.3145502,,,,,,,,,,,,,, -38700,1.2251326,2.225765,,,,,,,,,,,,,, -38800,1.2708278,2.2395067,,,,,,,,,,,,,, -38900,1.115028,2.228818,,,,,,,,,,,,,, -39000,1.0793166,2.1642046,,,,,,,,,,,,,, -39100,1.2231048,2.2971263,,,,,,,,,,,,,, -39112,,,0.0672831609845161,6.174365043640137,0.0680800005793571,6.197511196136475,50000.0,0.0463000014424324,6.7392120361328125,10000.0,13293.57846236229,13781.161801338196,13293.57846236229,485.3442895412445,0.820784330368042,0.0 -39200,1.1749053,2.2183151,,,,,,,,,,,,,, -39300,1.3255692,2.3112895,,,,,,,,,,,,,, -39400,1.1212076,2.3111718,,,,,,,,,,,,,, -39500,1.1036491,2.2750685,,,,,,,,,,,,,, -39600,1.1920489,2.1742125,,,,,,,,,,,,,, -39700,1.1412488,2.3156676,,,,,,,,,,,,,, -39800,1.1673621,2.296813,,,,,,,,,,,,,, -39900,1.1499434,2.255745,,,,,,,,,,,,,, -40000,1.2589939,2.327443,,,,,,,,,,,,,, -40100,1.0664016,2.1916282,,,,,,,,,,,,,, -40200,0.99934596,2.171986,,,,,,,,,,,,,, -40300,1.1510526,2.336375,,,,,,,,,,,,,, -40400,1.2263012,2.3393722,,,,,,,,,,,,,, -40500,1.1496968,2.2399056,,,,,,,,,,,,,, -40600,1.172645,2.1460721,,,,,,,,,,,,,, -40623,,,0.1066844686865806,6.06245756149292,0.0982799977064132,6.124608516693115,50000.0,0.066500000655651,6.820237159729004,10000.0,13803.59059214592,14309.18965625763,13803.59059214592,503.2651972770691,0.8609738349914551,0.0 -40700,1.2522115,2.2635002,,,,,,,,,,,,,, -40800,1.2290356,2.3971612,,,,,,,,,,,,,, -40900,1.2083387,2.279293,,,,,,,,,,,,,, -41000,1.1289673,2.209017,,,,,,,,,,,,,, -41100,1.1688215,2.3160374,,,,,,,,,,,,,, -41200,1.2826464,2.2033758,,,,,,,,,,,,,, -41300,1.250016,2.2684176,,,,,,,,,,,,,, -41400,1.339926,2.3295846,,,,,,,,,,,,,, -41500,1.1598328,2.2966034,,,,,,,,,,,,,, -41600,1.1946659,2.2274842,,,,,,,,,,,,,, -41700,1.1578208,2.1852794,,,,,,,,,,,,,, -41800,1.2320151,2.2216446,,,,,,,,,,,,,, -41900,1.101825,2.316664,,,,,,,,,,,,,, -42000,1.1818174,2.178697,,,,,,,,,,,,,, -42100,1.3535166,2.3179214,,,,,,,,,,,,,, -42134,,,0.1703603267669677,4.943631649017334,0.162320002913475,5.046461582183838,50000.0,0.1122000068426132,5.937928199768066,10000.0,14313.546487808228,14837.033590316772,14313.546487808228,521.0472767353058,0.9110279083251952,0.0 -42200,1.2482812,2.350287,,,,,,,,,,,,,, -42300,1.1402133,2.2753057,,,,,,,,,,,,,, -42400,1.2108526,2.2642324,,,,,,,,,,,,,, -42500,1.1508558,2.2374403,,,,,,,,,,,,,, -42600,1.1188774,2.2133093,,,,,,,,,,,,,, -42700,1.1595086,2.3145037,,,,,,,,,,,,,, -42800,1.2136286,2.2264733,,,,,,,,,,,,,, -42900,1.1788573,2.1275706,,,,,,,,,,,,,, -43000,1.2210269,2.2209387,,,,,,,,,,,,,, -43100,1.1001886,2.2337768,,,,,,,,,,,,,, -43200,1.1109455,2.1822798,,,,,,,,,,,,,, -43300,1.1750375,2.2512982,,,,,,,,,,,,,, -43400,1.2178155,2.3473818,,,,,,,,,,,,,, -43500,1.1082699,2.3454132,,,,,,,,,,,,,, -43600,1.2228135,2.249189,,,,,,,,,,,,,, -43643,,,0.2089445143938064,4.4410600662231445,0.1917800009250641,4.693207263946533,50000.0,0.1562000066041946,5.226414203643799,10000.0,14822.672254800797,15365.175357818604,14822.672254800797,538.9020249843597,2.015630960464477,0.0 -43700,1.1796291,2.2475781,,,,,,,,,,,,,, -43800,1.2104611,2.2194827,,,,,,,,,,,,,, -43900,1.2438245,2.356829,,,,,,,,,,,,,, -44000,1.201434,2.2963223,,,,,,,,,,,,,, -44100,1.0778874,2.2419534,,,,,,,,,,,,,, -44200,1.2149293,2.1850653,,,,,,,,,,,,,, -44300,1.049734,2.2041144,,,,,,,,,,,,,, -44400,1.313222,2.2493117,,,,,,,,,,,,,, -44500,1.2241304,2.238169,,,,,,,,,,,,,, -44600,1.1687284,2.2023163,,,,,,,,,,,,,, -44700,1.1925712,2.167689,,,,,,,,,,,,,, -44800,1.1593741,2.1576605,,,,,,,,,,,,,, -44900,1.2819873,2.2950616,,,,,,,,,,,,,, -45000,1.11607,2.1993406,,,,,,,,,,,,,, -45100,1.1844519,2.398654,,,,,,,,,,,,,, -45155,,,0.1822783797979354,4.9106268882751465,0.1706199944019317,5.011672019958496,50000.0,0.1218000054359436,5.921328067779541,10000.0,15332.741003751757,15892.876887321472,15332.741003751757,556.4439558982849,2.052147626876831,0.0 -45200,1.258711,2.2951355,,,,,,,,,,,,,, -45300,1.2963583,2.282247,,,,,,,,,,,,,, -45400,1.1897566,2.2391794,,,,,,,,,,,,,, -45500,1.3013564,2.3353796,,,,,,,,,,,,,, -45600,1.2007413,2.3196635,,,,,,,,,,,,,, -45700,1.2417094,2.209596,,,,,,,,,,,,,, -45800,1.2107278,2.267056,,,,,,,,,,,,,, -45900,1.2891923,2.3836293,,,,,,,,,,,,,, -46000,1.1564738,2.2002945,,,,,,,,,,,,,, -46100,1.2329626,2.1545238,,,,,,,,,,,,,, -46200,1.167117,2.2326863,,,,,,,,,,,,,, -46300,1.2421967,2.1841002,,,,,,,,,,,,,, -46400,1.1953468,2.1148424,,,,,,,,,,,,,, -46500,1.1700581,2.1923215,,,,,,,,,,,,,, -46600,1.127432,2.1733015,,,,,,,,,,,,,, -46667,,,0.0934709832072258,5.805931568145752,0.0856800004839897,5.914002418518066,50000.0,0.0630000010132789,6.328012466430664,10000.0,15842.871906518936,16421.14906358719,15842.871906518936,574.490583896637,2.092453956604004,0.0 -46700,1.2053872,2.2309952,,,,,,,,,,,,,, -46800,1.2831912,2.1785817,,,,,,,,,,,,,, -46900,1.1749651,2.2542398,,,,,,,,,,,,,, -47000,1.2481073,2.2951035,,,,,,,,,,,,,, -47100,1.3061366,2.2205307,,,,,,,,,,,,,, -47200,1.1859524,2.3121836,,,,,,,,,,,,,, -47300,1.203389,2.1841464,,,,,,,,,,,,,, -47400,1.1386555,2.1229346,,,,,,,,,,,,,, -47500,1.5352098,2.267614,,,,,,,,,,,,,, -47600,1.2464198,2.2042742,,,,,,,,,,,,,, -47700,1.255902,2.1871555,,,,,,,,,,,,,, -47800,1.2007166,2.2036562,,,,,,,,,,,,,, -47900,1.3516264,2.2343614,,,,,,,,,,,,,, -48000,1.2145988,2.265582,,,,,,,,,,,,,, -48100,1.2355986,2.2430081,,,,,,,,,,,,,, -48179,,,0.0985132306814193,5.712707996368408,0.0935999974608421,5.7977824211120605,50000.0,0.0724000036716461,6.272417545318604,10000.0,16352.813577651978,16948.93683218956,16352.813577651978,592.2423067092896,2.131907224655152,0.0 -48200,1.2548183,2.1662118,,,,,,,,,,,,,, -48300,1.1899232,2.3421373,,,,,,,,,,,,,, -48400,1.219345,2.219366,,,,,,,,,,,,,, -48500,1.210027,2.2450297,,,,,,,,,,,,,, -48600,1.1550996,2.275972,,,,,,,,,,,,,, -48700,1.1470723,2.2383988,,,,,,,,,,,,,, -48800,1.2065297,2.2634535,,,,,,,,,,,,,, -48900,1.2288486,2.2378366,,,,,,,,,,,,,, -49000,1.2654494,2.1746426,,,,,,,,,,,,,, -49100,1.2880542,2.3176556,,,,,,,,,,,,,, -49200,1.1585015,2.1959891,,,,,,,,,,,,,, -49300,1.1547515,2.1553786,,,,,,,,,,,,,, -49400,1.2953967,2.159866,,,,,,,,,,,,,, -49500,1.3009359,2.1152494,,,,,,,,,,,,,, -49600,1.2021453,2.228241,,,,,,,,,,,,,, -49691,,,0.3025948703289032,3.468749523162842,0.2851999998092651,3.577686786651612,50000.0,0.2041000127792358,4.3102006912231445,10000.0,16862.992620944977,17476.90072107315,16862.992620944977,609.932549238205,2.17134976387024,0.0 -49700,1.1723946,2.1832535,,,,,,,,,,,,,, -49800,1.1912982,2.1995668,,,,,,,,,,,,,, -49900,1.2265854,2.3400707,,,,,,,,,,,,,, -50000,1.3730781,2.330531,,,,,,,,,,,,,, -50100,1.3119571,2.2922127,,,,,,,,,,,,,, -50200,1.2365857,2.2730546,,,,,,,,,,,,,, -50300,1.1162378,2.062758,,,,,,,,,,,,,, -50400,1.1662376,2.137651,,,,,,,,,,,,,, -50500,1.3132141,2.2063577,,,,,,,,,,,,,, -50600,1.2533993,2.2197814,,,,,,,,,,,,,, -50700,1.1994586,2.215732,,,,,,,,,,,,,, -50800,1.1807702,2.3501747,,,,,,,,,,,,,, -50900,1.3225722,2.2386804,,,,,,,,,,,,,, -51000,1.1979053,2.232292,,,,,,,,,,,,,, -51100,1.1187527,2.277072,,,,,,,,,,,,,, -51200,1.6313201,2.2070873,,,,,,,,,,,,,, -51203,,,0.3303371965885162,3.299055576324463,0.2969200015068054,3.578896999359131,50000.0,0.2173000127077102,4.417811393737793,10000.0,17373.104477643967,18004.936414718628,17373.104477643967,627.7616715431213,2.2113375663757324,0.0 -51300,1.228652,2.320199,,,,,,,,,,,,,, -51400,1.283737,2.1357503,,,,,,,,,,,,,, -51500,1.1567065,2.187432,,,,,,,,,,,,,, -51600,1.1261933,2.1464581,,,,,,,,,,,,,, -51700,1.1126711,2.2429543,,,,,,,,,,,,,, -51800,1.3350933,2.2838774,,,,,,,,,,,,,, -51900,1.2796489,2.2832983,,,,,,,,,,,,,, -52000,1.414843,2.2026298,,,,,,,,,,,,,, -52100,1.1674349,2.230494,,,,,,,,,,,,,, -52200,1.3175445,2.1479685,,,,,,,,,,,,,, -52300,1.281198,2.214613,,,,,,,,,,,,,, -52400,1.2428662,2.1373262,,,,,,,,,,,,,, -52500,1.2112614,2.1489375,,,,,,,,,,,,,, -52600,1.1733398,2.2338552,,,,,,,,,,,,,, -52700,1.1853827,2.2557907,,,,,,,,,,,,,, -52715,,,0.1820591539144516,4.70458459854126,0.1709399968385696,4.80840539932251,50000.0,0.130400002002716,5.461226463317871,10000.0,17883.095085144043,18532.963203907013,17883.095085144043,645.7019193172455,2.252317190170288,0.0 -52800,1.2059823,2.1838446,,,,,,,,,,,,,, -52900,1.2664658,2.3141627,,,,,,,,,,,,,, -53000,1.3832408,2.2765007,,,,,,,,,,,,,, -53100,1.2783453,2.123835,,,,,,,,,,,,,, -53200,1.2248766,2.3140717,,,,,,,,,,,,,, -53300,1.175728,2.1768646,,,,,,,,,,,,,, -53400,1.2227191,2.2877493,,,,,,,,,,,,,, -53500,1.3116007,2.2313995,,,,,,,,,,,,,, -53600,1.3171227,2.2824504,,,,,,,,,,,,,, -53700,1.269609,2.2875943,,,,,,,,,,,,,, -53800,1.4005728,2.2640643,,,,,,,,,,,,,, -53900,1.2234956,2.1849473,,,,,,,,,,,,,, -54000,1.2592838,2.2075753,,,,,,,,,,,,,, -54100,1.2574964,2.279374,,,,,,,,,,,,,, -54200,1.2740326,2.2337599,,,,,,,,,,,,,, -54228,,,0.1476801633834839,5.142282009124756,0.1403799951076507,5.216727256774902,50000.0,0.110100008547306,5.747947216033936,10000.0,18393.22921514511,19060.944731235504,18393.22921514511,663.4520955085754,2.2944045066833496,0.0 -54300,1.2001138,2.172352,,,,,,,,,,,,,, -54400,1.2441226,2.3471956,,,,,,,,,,,,,, -54500,1.2058039,2.234682,,,,,,,,,,,,,, -54600,1.5140554,2.2494385,,,,,,,,,,,,,, -54700,1.2572993,2.2421188,,,,,,,,,,,,,, -54800,1.1787401,2.2048104,,,,,,,,,,,,,, -54900,1.326137,2.266259,,,,,,,,,,,,,, -55000,1.1654626,2.2047548,,,,,,,,,,,,,, -55100,1.2773142,2.294346,,,,,,,,,,,,,, -55200,1.2085602,2.113582,,,,,,,,,,,,,, -55300,1.2162948,2.1753538,,,,,,,,,,,,,, -55400,1.3477322,2.164805,,,,,,,,,,,,,, -55500,1.4007572,2.199597,,,,,,,,,,,,,, -55600,1.4148856,2.2362814,,,,,,,,,,,,,, -55700,1.212833,2.2533298,,,,,,,,,,,,,, -55740,,,0.2161591202020645,4.46263599395752,0.2102999985218048,4.564239025115967,50000.0,0.1557000130414962,5.351230621337891,10000.0,18903.42562484741,19589.123804807663,18903.42562484741,681.3409023284912,2.333213329315185,0.0 -55800,1.3547038,2.2009401,,,,,,,,,,,,,, -55900,1.2810638,2.2383027,,,,,,,,,,,,,, -56000,1.2198634,2.2260027,,,,,,,,,,,,,, -56100,1.3147457,2.3154528,,,,,,,,,,,,,, -56200,1.288537,2.295114,,,,,,,,,,,,,, -56300,1.1951405,2.1888933,,,,,,,,,,,,,, -56400,1.2294298,2.1639192,,,,,,,,,,,,,, -56500,1.2870387,2.2477558,,,,,,,,,,,,,, -56600,1.2832211,2.1783004,,,,,,,,,,,,,, -56700,1.2526042,2.1937273,,,,,,,,,,,,,, -56800,1.2300123,2.305351,,,,,,,,,,,,,, -56900,1.3074607,2.3272853,,,,,,,,,,,,,, -57000,1.2914515,2.2940798,,,,,,,,,,,,,, -57100,1.243353,2.112157,,,,,,,,,,,,,, -57200,1.2168536,2.102408,,,,,,,,,,,,,, -57253,,,0.2811702787876129,3.879862785339356,0.2648199796676636,4.011868476867676,50000.0,0.2063000053167343,4.711694240570068,10000.0,19413.61908721924,20117.10993623733,19413.61908721924,699.0399971008301,2.3717331886291504,0.0 -57300,1.238356,2.183279,,,,,,,,,,,,,, -57400,1.3065879,2.3360486,,,,,,,,,,,,,, -57500,1.2684872,2.2248366,,,,,,,,,,,,,, -57600,1.406216,2.1741118,,,,,,,,,,,,,, -57700,1.3243177,2.2284362,,,,,,,,,,,,,, -57800,1.3286096,2.1838574,,,,,,,,,,,,,, -57900,1.3526084,2.2478886,,,,,,,,,,,,,, -58000,1.3225424,2.2688272,,,,,,,,,,,,,, -58100,1.2614968,2.318245,,,,,,,,,,,,,, -58200,1.2786199,2.3641086,,,,,,,,,,,,,, -58300,1.1870176,2.189965,,,,,,,,,,,,,, -58400,1.3418875,2.1484156,,,,,,,,,,,,,, -58500,1.4219584,2.231036,,,,,,,,,,,,,, -58600,1.1990187,2.1309178,,,,,,,,,,,,,, -58700,1.2882344,2.188941,,,,,,,,,,,,,, -58765,,,0.2714046537876129,3.727704286575317,0.2528599798679352,3.881164312362671,50000.0,0.1857000142335891,4.536026477813721,10000.0,19923.75394082069,20645.222013950348,19923.75394082069,716.9220430850983,2.4119575023651123,0.0 -58800,1.3285595,2.1122043,,,,,,,,,,,,,, -58900,1.2336155,2.2329113,,,,,,,,,,,,,, -59000,1.2843494,2.152944,,,,,,,,,,,,,, -59100,1.3016039,2.1478586,,,,,,,,,,,,,, -59200,1.2540239,2.368957,,,,,,,,,,,,,, -59300,1.2256237,2.177031,,,,,,,,,,,,,, -59400,1.3660612,2.3110833,,,,,,,,,,,,,, -59500,1.288035,2.2509944,,,,,,,,,,,,,, -59600,1.2091354,2.2300992,,,,,,,,,,,,,, -59700,1.410662,2.1601663,,,,,,,,,,,,,, -59800,1.4117029,2.1194258,,,,,,,,,,,,,, -59900,1.5503975,2.2759619,,,,,,,,,,,,,, -60000,1.3642111,2.1526291,,,,,,,,,,,,,, -60100,1.2079558,2.2941625,,,,,,,,,,,,,, -60200,1.2032347,2.1816401,,,,,,,,,,,,,, -60278,,,0.1557318270206451,5.1116814613342285,0.1419599950313568,5.236762523651123,50000.0,0.0945000052452087,6.115200519561768,10000.0,20433.7780482769,21172.926019191746,20433.7780482769,734.5041456222534,2.4547340869903564,0.0 -60300,1.3527576,2.2248416,,,,,,,,,,,,,, -60400,1.2400826,2.1478214,,,,,,,,,,,,,, -60500,1.2117157,2.0330243,,,,,,,,,,,,,, -60600,1.3319947,2.198973,,,,,,,,,,,,,, -60700,1.2248573,2.011834,,,,,,,,,,,,,, -60800,1.235879,2.2548144,,,,,,,,,,,,,, -60900,1.2423214,2.226592,,,,,,,,,,,,,, -61000,1.2153502,2.1547441,,,,,,,,,,,,,, -61100,1.19721,2.080743,,,,,,,,,,,,,, -61200,1.4232891,2.2964022,,,,,,,,,,,,,, -61300,1.3401717,2.2341952,,,,,,,,,,,,,, -61400,1.2840775,2.1875148,,,,,,,,,,,,,, -61500,1.317114,2.348532,,,,,,,,,,,,,, -61600,1.6127194,2.2530546,,,,,,,,,,,,,, -61700,1.4160763,2.1324785,,,,,,,,,,,,,, -61791,,,0.2585897445678711,3.929126262664795,0.2430599927902221,4.085925579071045,50000.0,0.1754000037908554,4.864076614379883,10000.0,20943.91853928566,21701.13439130783,20943.91853928566,752.4799780845642,2.491457223892212,0.0 -61800,1.3233838,2.2058964,,,,,,,,,,,,,, -61900,1.2667075,2.1370883,,,,,,,,,,,,,, -62000,1.4209418,2.1685266,,,,,,,,,,,,,, -62100,1.3188553,2.238481,,,,,,,,,,,,,, -62200,1.3099769,2.228169,,,,,,,,,,,,,, -62300,1.2472334,2.0723271,,,,,,,,,,,,,, -62400,1.3949414,2.2518501,,,,,,,,,,,,,, -62500,1.2914851,2.1927733,,,,,,,,,,,,,, -62600,1.464678,2.0031319,,,,,,,,,,,,,, -62700,1.3476295,2.4059,,,,,,,,,,,,,, -62800,1.280738,2.1386113,,,,,,,,,,,,,, -62900,1.1813952,2.091012,,,,,,,,,,,,,, -63000,1.2981011,2.2861993,,,,,,,,,,,,,, -63100,1.2912221,2.1537957,,,,,,,,,,,,,, -63200,1.3755896,2.1879706,,,,,,,,,,,,,, -63300,1.3368491,2.1820636,,,,,,,,,,,,,, -63304,,,0.1938775479793548,4.898688316345215,0.184919998049736,4.995222091674805,50000.0,0.1388000100851059,5.745086193084717,10000.0,21454.120364904404,22229.32262706757,21454.120364904404,770.3702204227448,2.534074068069458,0.0 -63400,1.4308205,2.2070565,,,,,,,,,,,,,, -63500,1.2408828,2.045497,,,,,,,,,,,,,, -63600,1.3139238,2.0226202,,,,,,,,,,,,,, -63700,1.3389357,2.2128248,,,,,,,,,,,,,, -63800,1.3621771,2.209304,,,,,,,,,,,,,, -63900,1.3651255,2.0626426,,,,,,,,,,,,,, -64000,1.1545136,2.1216612,,,,,,,,,,,,,, -64100,1.3676707,2.287193,,,,,,,,,,,,,, -64200,1.3558507,2.1899238,,,,,,,,,,,,,, -64300,1.2394905,2.1554513,,,,,,,,,,,,,, -64400,1.2298055,2.1161523,,,,,,,,,,,,,, -64500,1.3078222,2.2722504,,,,,,,,,,,,,, -64600,1.2240889,2.1043797,,,,,,,,,,,,,, -64700,1.2396234,2.153605,,,,,,,,,,,,,, -64800,1.274657,2.0744085,,,,,,,,,,,,,, -64817,,,0.1991988122463226,4.335831165313721,0.1844199895858764,4.489870548248291,50000.0,0.1337000131607055,5.076531410217285,10000.0,21964.04625606537,22757.03709292412,21964.04625606537,788.0393702983856,2.598047971725464,0.0 -64900,1.2790109,2.1638558,,,,,,,,,,,,,, -65000,1.3961023,2.1654751,,,,,,,,,,,,,, -65100,1.586968,2.2361658,,,,,,,,,,,,,, -65200,1.48834,2.0774045,,,,,,,,,,,,,, -65300,1.3583962,2.3180938,,,,,,,,,,,,,, -65400,1.3856499,2.0557244,,,,,,,,,,,,,, -65500,1.4077777,2.0187056,,,,,,,,,,,,,, -65600,1.3976526,2.2472715,,,,,,,,,,,,,, -65700,1.5664546,2.2611303,,,,,,,,,,,,,, -65800,1.3564732,2.0849986,,,,,,,,,,,,,, -65900,1.2965742,2.115613,,,,,,,,,,,,,, -66000,1.3424753,2.0624292,,,,,,,,,,,,,, -66100,1.2723447,2.1713355,,,,,,,,,,,,,, -66200,1.3990968,2.1764276,,,,,,,,,,,,,, -66300,1.397936,2.1613262,,,,,,,,,,,,,, -66330,,,0.3051060140132904,3.4450433254241943,0.2985999882221222,3.5019185543060303,50000.0,0.2068000137805938,4.329137802124023,10000.0,22474.17966222763,23285.16569185257,22474.17966222763,805.9393737316132,2.638539552688598,0.0 -66400,1.3643645,2.2665257,,,,,,,,,,,,,, -66500,1.3065027,2.0968633,,,,,,,,,,,,,, -66600,1.4861628,2.1917136,,,,,,,,,,,,,, -66700,1.1746749,2.0636544,,,,,,,,,,,,,, -66800,1.3549243,2.0965385,,,,,,,,,,,,,, -66900,1.4300607,2.2360964,,,,,,,,,,,,,, -67000,1.3059125,2.04774,,,,,,,,,,,,,, -67100,1.4427346,2.2074194,,,,,,,,,,,,,, -67200,1.3280938,2.1718485,,,,,,,,,,,,,, -67300,1.5569947,2.2177613,,,,,,,,,,,,,, -67400,1.2315443,2.0708568,,,,,,,,,,,,,, -67500,1.3091458,2.1391563,,,,,,,,,,,,,, -67600,1.3937777,2.2188964,,,,,,,,,,,,,, -67700,1.4373176,2.1408129,,,,,,,,,,,,,, -67800,1.3215111,2.2304316,,,,,,,,,,,,,, -67842,,,0.3987165093421936,2.7879419326782227,0.3781400024890899,2.944031238555908,50000.0,0.2841000258922577,3.717777013778688,10000.0,22984.28769302368,23813.329062461853,22984.28769302368,823.8984444141388,2.679395914077759,0.0 -67900,1.384757,2.1783018,,,,,,,,,,,,,, -68000,1.3373581,2.1132865,,,,,,,,,,,,,, -68100,1.395274,2.2334068,,,,,,,,,,,,,, -68200,1.3679208,2.0701313,,,,,,,,,,,,,, -68300,1.4277512,2.147853,,,,,,,,,,,,,, -68400,1.2967596,2.1840546,,,,,,,,,,,,,, -68500,1.3340589,2.0760174,,,,,,,,,,,,,, -68600,1.3039529,2.1062827,,,,,,,,,,,,,, -68700,1.293603,2.2158504,,,,,,,,,,,,,, -68800,1.2648487,2.0630357,,,,,,,,,,,,,, -68900,1.3147874,2.1382322,,,,,,,,,,,,,, -69000,1.2213771,2.1086676,,,,,,,,,,,,,, -69100,1.2052768,2.0332942,,,,,,,,,,,,,, -69200,1.3015033,1.9803379,,,,,,,,,,,,,, -69300,1.2505536,2.124039,,,,,,,,,,,,,, -69355,,,0.2458346635103225,4.067709922790527,0.2218199968338012,4.309610843658447,50000.0,0.1639000028371811,5.079490184783936,10000.0,23494.273504018784,24341.48214435577,23494.273504018784,841.9661107063293,2.723883628845215,0.0 -69400,1.4945847,2.2235212,,,,,,,,,,,,,, -69500,1.4759699,2.2143397,,,,,,,,,,,,,, -69600,1.2934266,2.1762774,,,,,,,,,,,,,, -69700,1.2701911,2.0257647,,,,,,,,,,,,,, -69800,1.4868824,2.2278135,,,,,,,,,,,,,, -69900,1.3875319,2.1724222,,,,,,,,,,,,,, -70000,1.2007703,2.1193082,,,,,,,,,,,,,, -70100,1.3299059,2.154335,,,,,,,,,,,,,, -70200,1.3564645,2.0723972,,,,,,,,,,,,,, -70300,1.3902051,2.0883756,,,,,,,,,,,,,, -70400,1.4779923,2.225999,,,,,,,,,,,,,, -70500,1.5652423,2.105237,,,,,,,,,,,,,, -70600,1.6226614,2.1930997,,,,,,,,,,,,,, -70700,1.3272945,2.1730263,,,,,,,,,,,,,, -70800,1.4490288,2.2358758,,,,,,,,,,,,,, -70867,,,0.3173628747463226,3.2800393104553223,0.2933799922466278,3.43172574043274,50000.0,0.2218000143766403,4.1046953201293945,10000.0,24004.327831745148,24869.49009847641,24004.327831745148,859.8252913951874,2.7633161544799805,0.0 -70900,1.4088544,2.2661376,,,,,,,,,,,,,, -71000,1.3348837,2.0801957,,,,,,,,,,,,,, -71100,1.2335432,1.9869858,,,,,,,,,,,,,, -71200,1.517432,2.1684413,,,,,,,,,,,,,, -71300,1.3664858,2.1963744,,,,,,,,,,,,,, -71400,1.3043518,2.063426,,,,,,,,,,,,,, -71500,1.3845918,2.226204,,,,,,,,,,,,,, -71600,1.2833927,2.2017128,,,,,,,,,,,,,, -71700,1.4714769,2.149345,,,,,,,,,,,,,, -71800,1.2131878,2.0485988,,,,,,,,,,,,,, -71900,1.3714117,2.1767135,,,,,,,,,,,,,, -72000,1.3267366,2.0514386,,,,,,,,,,,,,, -72100,1.6341189,2.1919065,,,,,,,,,,,,,, -72200,1.4936391,2.1247675,,,,,,,,,,,,,, -72300,1.4353253,2.1437833,,,,,,,,,,,,,, -72379,,,0.4204001724720001,2.6156625747680664,0.3950199782848358,2.788191795349121,50000.0,0.2957000136375427,3.51944899559021,10000.0,24514.24355506897,25397.27818083763,24514.24355506897,877.5954036712646,2.8088531494140625,0.0 -72400,1.5910803,2.2014225,,,,,,,,,,,,,, -72500,1.4418991,2.1271856,,,,,,,,,,,,,, -72600,1.6495507,1.9807642,,,,,,,,,,,,,, -72700,1.4527403,2.0198584,,,,,,,,,,,,,, -72800,1.392946,2.1115363,,,,,,,,,,,,,, -72900,1.2895619,2.080418,,,,,,,,,,,,,, -73000,1.3006749,2.1404645,,,,,,,,,,,,,, -73100,1.332787,2.1436925,,,,,,,,,,,,,, -73200,1.3432251,2.1599462,,,,,,,,,,,,,, -73300,1.3980297,2.0126004,,,,,,,,,,,,,, -73400,1.3790869,2.0389402,,,,,,,,,,,,,, -73500,1.3488088,2.122915,,,,,,,,,,,,,, -73600,1.4274514,1.945245,,,,,,,,,,,,,, -73700,1.3906418,2.0537336,,,,,,,,,,,,,, -73800,1.273969,2.0491996,,,,,,,,,,,,,, -73891,,,0.3917610049247741,2.7942934036254883,0.3677600026130676,2.944324016571045,50000.0,0.2730000019073486,3.748530149459839,10000.0,25024.23308992386,25925.08330845833,25024.23308992386,895.3173720836639,2.8476545810699463,0.0 -73900,1.5565494,2.2719417,,,,,,,,,,,,,, -74000,1.3331218,2.0067115,,,,,,,,,,,,,, -74100,1.6028361,2.026701,,,,,,,,,,,,,, -74200,1.3608935,2.1890044,,,,,,,,,,,,,, -74300,1.369736,2.1618657,,,,,,,,,,,,,, -74400,1.3452343,2.1022863,,,,,,,,,,,,,, -74500,1.3389052,2.0969665,,,,,,,,,,,,,, -74600,1.4191573,2.2726297,,,,,,,,,,,,,, -74700,1.4188813,2.158716,,,,,,,,,,,,,, -74800,1.4284884,2.2541542,,,,,,,,,,,,,, -74900,1.3148535,2.074136,,,,,,,,,,,,,, -75000,1.4781785,2.1391664,,,,,,,,,,,,,, -75100,1.4659705,2.172639,,,,,,,,,,,,,, -75200,1.4127374,2.2172785,,,,,,,,,,,,,, -75300,1.316678,2.0450983,,,,,,,,,,,,,, -75400,1.3746904,2.1108117,,,,,,,,,,,,,, -75404,,,0.0949457883834838,6.09211540222168,0.0862199962139129,6.208457946777344,50000.0,0.0658000037074089,6.601757526397705,10000.0,25534.32436609268,26452.890166282654,25534.32436609268,912.9306666851044,2.89243745803833,0.0 -75500,1.3992928,2.1361308,,,,,,,,,,,,,, -75600,1.3111787,2.139219,,,,,,,,,,,,,, -75700,1.4653821,2.0900538,,,,,,,,,,,,,, -75800,1.4132589,2.192779,,,,,,,,,,,,,, -75900,1.3551996,2.2138908,,,,,,,,,,,,,, -76000,1.4393966,2.1775234,,,,,,,,,,,,,, -76100,1.4763763,2.1648831,,,,,,,,,,,,,, -76200,1.3634456,2.087877,,,,,,,,,,,,,, -76300,1.5894867,2.0125282,,,,,,,,,,,,,, -76400,1.3389269,2.1617572,,,,,,,,,,,,,, -76500,1.4104602,2.0945625,,,,,,,,,,,,,, -76600,1.3997027,2.1124895,,,,,,,,,,,,,, -76700,1.4697508,2.1591403,,,,,,,,,,,,,, -76800,1.5375415,2.0792794,,,,,,,,,,,,,, -76900,1.343617,2.0519698,,,,,,,,,,,,,, -76917,,,0.4280731678009033,2.589535713195801,0.4032399952411651,2.7519068717956543,50000.0,0.3020000159740448,3.514342784881592,10000.0,26044.48452091217,26981.044637680054,26044.48452091217,930.827305316925,2.9355061054229736,0.0 -77000,1.505393,2.244464,,,,,,,,,,,,,, -77100,1.4926288,2.1180549,,,,,,,,,,,,,, -77200,1.7390634,2.1553383,,,,,,,,,,,,,, -77300,1.583215,2.198282,,,,,,,,,,,,,, -77400,1.6861395,2.1969528,,,,,,,,,,,,,, -77500,1.421964,2.0572548,,,,,,,,,,,,,, -77600,1.5288295,2.0714428,,,,,,,,,,,,,, -77700,1.4634287,1.9865521,,,,,,,,,,,,,, -77800,1.6237761,2.1679115,,,,,,,,,,,,,, -77900,1.5792414,2.1918933,,,,,,,,,,,,,, -78000,1.2914379,1.9762595,,,,,,,,,,,,,, -78100,1.3090674,1.993471,,,,,,,,,,,,,, -78200,1.3468715,2.1591136,,,,,,,,,,,,,, -78300,1.3279266,2.0587409,,,,,,,,,,,,,, -78400,1.6304058,2.184085,,,,,,,,,,,,,, -78430,,,0.3382294178009033,3.221769094467163,0.3022799789905548,3.51043963432312,50000.0,0.2303000092506408,4.192265510559082,10000.0,26554.670390605927,27509.320876836777,26554.670390605927,948.81911110878,2.9777910709381104,0.0 -78500,1.4441193,2.1547844,,,,,,,,,,,,,, -78600,1.5016022,2.1789124,,,,,,,,,,,,,, -78700,1.4093188,1.9926589,,,,,,,,,,,,,, -78800,1.5055791,2.0865414,,,,,,,,,,,,,, -78900,1.5282379,2.0319147,,,,,,,,,,,,,, -79000,1.3975797,2.2264671,,,,,,,,,,,,,, -79100,1.3618273,2.1375334,,,,,,,,,,,,,, -79200,1.3569895,2.0088072,,,,,,,,,,,,,, -79300,1.3555485,2.0555532,,,,,,,,,,,,,, -79400,1.6291921,2.0583148,,,,,,,,,,,,,, -79500,1.4802431,2.157498,,,,,,,,,,,,,, -79600,1.5212433,2.119699,,,,,,,,,,,,,, -79700,1.4720277,2.16963,,,,,,,,,,,,,, -79800,1.5521644,2.0173993,,,,,,,,,,,,,, -79900,1.4951212,2.0144415,,,,,,,,,,,,,, -79943,,,0.3127590715885162,3.484292507171631,0.2857199907302856,3.710571527481079,50000.0,0.2179000079631805,4.380500316619873,10000.0,27064.858870744705,28037.1936275959,27064.858870744705,966.4048013687134,3.0196309089660645,0.0 -80000,1.4227227,1.9519296,,,,,,,,,,,,,, -80100,1.4936358,2.0742223,,,,,,,,,,,,,, -80200,1.4108566,2.0986016,,,,,,,,,,,,,, -80300,1.4178458,2.0747085,,,,,,,,,,,,,, -80400,1.416207,2.081668,,,,,,,,,,,,,, -80500,1.5765375,2.1208792,,,,,,,,,,,,,, -80600,1.3329955,2.1366444,,,,,,,,,,,,,, -80700,1.3882931,2.1343713,,,,,,,,,,,,,, -80800,1.435989,2.0948942,,,,,,,,,,,,,, -80900,1.5664428,2.0063524,,,,,,,,,,,,,, -81000,1.3770524,2.039238,,,,,,,,,,,,,, -81100,1.3998008,1.9340249,,,,,,,,,,,,,, -81200,1.5385337,1.9245721,,,,,,,,,,,,,, -81300,1.6824299,1.9748732,,,,,,,,,,,,,, -81400,1.5446503,2.0981638,,,,,,,,,,,,,, -81457,,,0.0907804518938064,6.651303291320801,0.0834999978542327,6.782551288604736,50000.0,0.0578000023961067,7.317497253417969,10000.0,27575.083287000656,28565.20376110077,27575.083287000656,984.0845103263856,3.068606376647949,0.0 -81500,1.5138155,2.0787568,,,,,,,,,,,,,, -81600,1.4865077,2.0128748,,,,,,,,,,,,,, -81700,1.4802337,2.1657448,,,,,,,,,,,,,, -81800,1.4654318,2.049639,,,,,,,,,,,,,, -81900,1.4569176,2.0464497,,,,,,,,,,,,,, -82000,1.4180439,2.1025221,,,,,,,,,,,,,, -82100,1.4751173,2.0056171,,,,,,,,,,,,,, -82200,1.4289047,2.0851986,,,,,,,,,,,,,, -82300,1.5640725,2.0673654,,,,,,,,,,,,,, -82400,1.568019,2.0343552,,,,,,,,,,,,,, -82500,1.5129372,2.0446997,,,,,,,,,,,,,, -82600,1.4523375,1.8972678,,,,,,,,,,,,,, -82700,1.6938318,2.2224355,,,,,,,,,,,,,, -82800,1.4504961,2.0271578,,,,,,,,,,,,,, -82900,1.6107962,2.116346,,,,,,,,,,,,,, -82971,,,0.3289221823215484,3.341825723648072,0.3080599904060364,3.4806392192840576,50000.0,0.2320000082254409,4.121487140655518,10000.0,28085.3017513752,29093.237498044968,28085.3017513752,1001.7920260429382,3.1208367347717285,0.0 -83000,1.647262,2.1090605,,,,,,,,,,,,,, -83100,1.4160022,2.0691004,,,,,,,,,,,,,, -83200,1.8149973,2.0719786,,,,,,,,,,,,,, -83300,1.521016,2.002103,,,,,,,,,,,,,, -83400,1.4457129,2.1041794,,,,,,,,,,,,,, -83500,1.579825,2.1095502,,,,,,,,,,,,,, -83600,1.4580896,2.091848,,,,,,,,,,,,,, -83700,1.6269011,2.1040373,,,,,,,,,,,,,, -83800,1.5679232,2.0046844,,,,,,,,,,,,,, -83900,1.3916675,2.0094693,,,,,,,,,,,,,, -84000,1.4843067,2.0319614,,,,,,,,,,,,,, -84100,1.3761429,2.054889,,,,,,,,,,,,,, -84200,1.6021774,2.0217555,,,,,,,,,,,,,, -84300,1.5529064,2.1229687,,,,,,,,,,,,,, -84400,1.551856,2.0323327,,,,,,,,,,,,,, -84484,,,0.3950494229793548,2.8487935066223145,0.3696599900722503,3.0001261234283447,50000.0,0.2698000073432922,3.844083786010742,10000.0,28595.36977601052,29621.18208193779,28595.36977601052,1019.570032596588,3.163076639175415,0.0 -84500,1.4653999,2.058901,,,,,,,,,,,,,, -84600,1.4209526,2.122616,,,,,,,,,,,,,, -84700,1.5838258,2.066669,,,,,,,,,,,,,, -84800,1.4585856,2.0020266,,,,,,,,,,,,,, -84900,1.3552836,1.970833,,,,,,,,,,,,,, -85000,1.5677431,2.1023364,,,,,,,,,,,,,, -85100,1.3985406,2.047662,,,,,,,,,,,,,, -85200,1.42042,2.0102448,,,,,,,,,,,,,, -85300,1.418997,2.094852,,,,,,,,,,,,,, -85400,1.3710209,1.9769208,,,,,,,,,,,,,, -85500,1.4529477,2.0536149,,,,,,,,,,,,,, -85600,1.4656134,1.9994962,,,,,,,,,,,,,, -85700,1.4523702,2.0914304,,,,,,,,,,,,,, -85800,1.5290129,2.1221902,,,,,,,,,,,,,, -85900,1.4105496,2.0848157,,,,,,,,,,,,,, -85997,,,0.1472815722227096,5.7417826652526855,0.1353199928998947,5.94066047668457,50000.0,0.101500004529953,6.495641708374023,10000.0,29105.444316864014,30148.880268096924,29105.444316864014,1037.0942661762238,3.2089014053344727,0.0 -86000,1.4507765,2.0034926,,,,,,,,,,,,,, -86100,1.459059,1.9968282,,,,,,,,,,,,,, -86200,1.7304424,2.1699748,,,,,,,,,,,,,, -86300,1.630705,2.038168,,,,,,,,,,,,,, -86400,1.5178051,2.1203787,,,,,,,,,,,,,, -86500,1.5330507,2.0873966,,,,,,,,,,,,,, -86600,1.7576406,2.1213136,,,,,,,,,,,,,, -86700,1.4468167,2.1831493,,,,,,,,,,,,,, -86800,1.4992183,2.0474,,,,,,,,,,,,,, -86900,1.9072661,2.0023146,,,,,,,,,,,,,, -87000,1.5901572,2.0293412,,,,,,,,,,,,,, -87100,1.5228575,2.1137993,,,,,,,,,,,,,, -87200,1.4620229,2.064629,,,,,,,,,,,,,, -87300,1.5128944,2.034192,,,,,,,,,,,,,, -87400,1.6125112,2.0759292,,,,,,,,,,,,,, -87500,1.5596731,2.089084,,,,,,,,,,,,,, -87510,,,0.3875956535339355,2.937028884887696,0.3501999974250793,3.214970588684082,50000.0,0.2642000019550323,4.0185370445251465,10000.0,29615.636449813843,30677.051599264145,29615.636449813843,1054.9764783382416,3.251824378967285,0.0 -87600,1.3868959,1.9988574,,,,,,,,,,,,,, -87700,1.4840351,2.1334617,,,,,,,,,,,,,, -87800,1.5939684,2.0266628,,,,,,,,,,,,,, -87900,1.5500064,2.0731091,,,,,,,,,,,,,, -88000,1.4763292,2.1177988,,,,,,,,,,,,,, -88100,1.5890157,2.1386614,,,,,,,,,,,,,, -88200,1.556083,2.0925002,,,,,,,,,,,,,, -88300,1.455715,1.979754,,,,,,,,,,,,,, -88400,1.5249113,2.1208258,,,,,,,,,,,,,, -88500,1.6425772,2.1560817,,,,,,,,,,,,,, -88600,1.4427412,1.9103451,,,,,,,,,,,,,, -88700,1.4309105,1.9632198,,,,,,,,,,,,,, -88800,1.5579139,2.1389346,,,,,,,,,,,,,, -88900,1.4728454,2.0344381,,,,,,,,,,,,,, -89000,1.4801775,2.0402808,,,,,,,,,,,,,, -89023,,,0.4036192595958709,2.742282629013061,0.3774600028991699,2.935097932815552,50000.0,0.2768000066280365,3.74910569190979,10000.0,30125.56599378585,31204.683507680893,30125.56599378585,1072.5670273303986,3.307556867599488,0.0 -89100,1.4742007,1.9603978,,,,,,,,,,,,,, -89200,1.6663529,2.0863743,,,,,,,,,,,,,, -89300,1.6159216,2.0638094,,,,,,,,,,,,,, -89400,1.71352,2.0922277,,,,,,,,,,,,,, -89500,1.5767293,2.0526204,,,,,,,,,,,,,, -89600,1.5950118,2.0905325,,,,,,,,,,,,,, -89700,1.4943978,1.9402041,,,,,,,,,,,,,, -89800,1.5510026,2.072056,,,,,,,,,,,,,, -89900,1.630365,2.1109023,,,,,,,,,,,,,, -90000,1.5566666,2.131006,,,,,,,,,,,,,, -90100,1.6123127,2.1560066,,,,,,,,,,,,,, -90200,1.6050507,2.1132934,,,,,,,,,,,,,, -90300,1.5210924,2.1177843,,,,,,,,,,,,,, -90400,1.5557426,1.9370687,,,,,,,,,,,,,, -90500,1.5586631,2.0348368,,,,,,,,,,,,,, -90537,,,0.2742745578289032,3.888530969619751,0.2615199983119964,4.029567241668701,50000.0,0.1845000088214874,4.86121940612793,10000.0,30635.79352426529,31732.65471124649,30635.79352426529,1090.2088098526,3.3545897006988525,0.0 -90600,1.4980043,1.9919857,,,,,,,,,,,,,, -90700,1.5402615,2.090345,,,,,,,,,,,,,, -90800,1.614395,1.9559414,,,,,,,,,,,,,, -90900,1.4292606,1.9746041,,,,,,,,,,,,,, -91000,1.542128,1.9671533,,,,,,,,,,,,,, -91100,1.8017325,2.024385,,,,,,,,,,,,,, -91200,1.5959198,2.0388374,,,,,,,,,,,,,, -91300,1.5324695,2.0726678,,,,,,,,,,,,,, -91400,1.6610993,2.006288,,,,,,,,,,,,,, -91500,1.4417801,1.9147639,,,,,,,,,,,,,, -91600,1.5242538,2.0547256,,,,,,,,,,,,,, -91700,1.5741149,2.0902572,,,,,,,,,,,,,, -91800,1.5798637,2.1193452,,,,,,,,,,,,,, -91900,1.5107787,2.0450969,,,,,,,,,,,,,, -92000,1.6567602,2.0333736,,,,,,,,,,,,,, -92050,,,0.3800821006298065,2.910261869430542,0.3625999987125397,3.03848934173584,50000.0,0.2576000094413757,3.8726961612701416,10000.0,31145.958006620407,32261.35049700737,31145.958006620407,1108.64000082016,3.399713516235352,0.0 -92100,1.5979102,2.0583086,,,,,,,,,,,,,, -92200,1.4791923,1.9666204,,,,,,,,,,,,,, -92300,1.4823887,1.9680768,,,,,,,,,,,,,, -92400,1.7898986,1.985643,,,,,,,,,,,,,, -92500,1.4653249,1.9975346,,,,,,,,,,,,,, -92600,1.6253811,1.8963537,,,,,,,,,,,,,, -92700,1.6296297,2.0260413,,,,,,,,,,,,,, -92800,1.5574156,1.9253473,,,,,,,,,,,,,, -92900,1.5774784,2.1069398,,,,,,,,,,,,,, -93000,1.6984301,2.103976,,,,,,,,,,,,,, -93100,1.5845543,1.9008797,,,,,,,,,,,,,, -93200,1.6215538,2.003755,,,,,,,,,,,,,, -93300,1.5261528,1.9130121,,,,,,,,,,,,,, -93400,1.6089064,2.1089447,,,,,,,,,,,,,, -93500,1.7632524,1.9293773,,,,,,,,,,,,,, -93560,,,0.4333545863628387,2.672496795654297,0.4023999869823456,2.8671212196350098,50000.0,0.3085000216960907,3.6801254749298096,10000.0,31654.82076215744,32789.16884326935,31654.82076215744,1126.3750030994415,4.565646648406982,0.0 -93600,1.73353,1.952544,,,,,,,,,,,,,, -93700,1.5272053,1.9697659,,,,,,,,,,,,,, -93800,1.6428006,2.0279074,,,,,,,,,,,,,, -93900,1.5147198,1.8695159,,,,,,,,,,,,,, -94000,1.6821479,2.1272755,,,,,,,,,,,,,, -94100,1.5771284,1.9825294,,,,,,,,,,,,,, -94200,1.6156056,2.0266864,,,,,,,,,,,,,, -94300,1.9560026,1.9997376,,,,,,,,,,,,,, -94400,1.671177,1.8791921,,,,,,,,,,,,,, -94500,1.5473793,2.0019236,,,,,,,,,,,,,, -94600,1.44052,1.9407794,,,,,,,,,,,,,, -94700,1.5646564,1.9260441,,,,,,,,,,,,,, -94800,1.6276466,2.0545068,,,,,,,,,,,,,, -94900,1.7858326,1.9306598,,,,,,,,,,,,,, -95000,1.630532,2.0352008,,,,,,,,,,,,,, -95073,,,0.3159478604793548,3.5270910263061523,0.2946999967098236,3.69120979309082,50000.0,0.2123000174760818,4.541915893554688,10000.0,32165.047719717026,33317.30035114288,32165.047719717026,1144.1742820739746,4.614187479019165,0.0 -95100,1.5370256,2.0181859,,,,,,,,,,,,,, -95200,1.5258939,1.8927287,,,,,,,,,,,,,, -95300,1.5939782,2.0432014,,,,,,,,,,,,,, -95400,1.8133866,1.9331186,,,,,,,,,,,,,, -95500,1.7065187,1.9458802,,,,,,,,,,,,,, -95600,1.6639793,1.9808009,,,,,,,,,,,,,, -95700,1.5065144,1.9884946,,,,,,,,,,,,,, -95800,1.6564641,2.0352318,,,,,,,,,,,,,, -95900,1.5483099,1.9132905,,,,,,,,,,,,,, -96000,1.7374762,2.1549442,,,,,,,,,,,,,, -96100,1.8786725,2.2284412,,,,,,,,,,,,,, -96200,1.7551193,2.0516891,,,,,,,,,,,,,, -96300,1.7363249,2.0577545,,,,,,,,,,,,,, -96400,1.5486712,1.8973953,,,,,,,,,,,,,, -96500,1.67556,1.9899639,,,,,,,,,,,,,, -96586,,,0.3710738122463226,2.9733757972717285,0.3378999829292297,3.213762044906616,50000.0,0.25,4.006485462188721,10000.0,32675.180696725845,33845.43358707428,32675.180696725845,1162.073585271835,4.659753799438477,0.0 -96600,1.6634582,2.0394738,,,,,,,,,,,,,, -96700,1.8062146,2.0958018,,,,,,,,,,,,,, -96800,1.7295113,1.8818021,,,,,,,,,,,,,, -96900,1.6065863,2.0028434,,,,,,,,,,,,,, -97000,1.6217637,1.994796,,,,,,,,,,,,,, -97100,1.5704213,1.9553468,,,,,,,,,,,,,, -97200,1.9049584,1.8902018,,,,,,,,,,,,,, -97300,1.6864245,1.9339839,,,,,,,,,,,,,, -97400,1.6592472,1.9866931,,,,,,,,,,,,,, -97500,1.7044617,1.8315883,,,,,,,,,,,,,, -97600,1.6794435,2.073608,,,,,,,,,,,,,, -97700,1.5886039,1.9334425,,,,,,,,,,,,,, -97800,1.6668047,2.0413942,,,,,,,,,,,,,, -97900,1.7024771,1.8632777,,,,,,,,,,,,,, -98000,1.6994513,2.1098063,,,,,,,,,,,,,, -98099,,,0.4164939224720001,2.686826467514038,0.3877999782562256,2.8848915100097656,50000.0,0.2865000069141388,3.734477996826172,10000.0,33185.274107694626,34373.27248048782,33185.274107694626,1179.7177047729492,4.704460382461548,0.0 -98100,1.5785273,1.9937291,,,,,,,,,,,,,, -98200,1.6987053,2.0846663,,,,,,,,,,,,,, -98300,1.7082038,2.005275,,,,,,,,,,,,,, -98400,1.9523348,1.8435372,,,,,,,,,,,,,, -98500,1.6204824,1.9869407,,,,,,,,,,,,,, -98600,1.6450438,2.039218,,,,,,,,,,,,,, -98700,1.7408019,1.9277998,,,,,,,,,,,,,, -98800,1.6482648,2.0510206,,,,,,,,,,,,,, -98900,1.6191691,2.0324738,,,,,,,,,,,,,, -99000,2.0082207,1.880839,,,,,,,,,,,,,, -99100,1.6398841,1.8408761,,,,,,,,,,,,,, -99200,1.849698,1.9398717,,,,,,,,,,,,,, -99300,2.0619502,2.1480267,,,,,,,,,,,,,, -99400,1.5800556,1.915645,,,,,,,,,,,,,, -99500,1.7422924,2.0216758,,,,,,,,,,,,,, -99600,1.5980932,1.9022236,,,,,,,,,,,,,, -99612,,,0.5056201815605164,2.150707960128784,0.4710799753665924,2.3676950931549072,50000.0,0.3606000244617462,3.1262447834014893,10000.0,33695.28577399254,34901.336842536926,33695.28577399254,1197.6669921875,4.754195690155029,0.0 -99700,1.6565787,1.9637129,,,,,,,,,,,,,, -99800,1.7521018,1.9744349,,,,,,,,,,,,,, -99900,1.7657913,1.9371302,,,,,,,,,,,,,, -100000,1.6259195,1.960834,,,,,,,,,,,,,, -100100,1.7713668,1.9196916,,,,,,,,,,,,,, -100200,1.8469673,1.8975809,,,,,,,,,,,,,, -100300,1.6721722,1.9434988,,,,,,,,,,,,,, -100400,1.6198155,1.8780265,,,,,,,,,,,,,, -100500,1.7692704,1.9428525,,,,,,,,,,,,,, -100600,1.8066634,1.8425145,,,,,,,,,,,,,, -100700,1.6812422,1.8823795,,,,,,,,,,,,,, -100800,2.4336867,1.9902893,,,,,,,,,,,,,, -100900,1.6621912,1.9981974,,,,,,,,,,,,,, -101000,1.7518475,1.9412831,,,,,,,,,,,,,, -101100,1.8342726,2.0424328,,,,,,,,,,,,,, -101124,,,0.4436583220958709,2.589078903198242,0.4133199751377105,2.795851230621338,50000.0,0.313400000333786,3.591003179550171,10000.0,34205.252484321594,35429.5738132,34205.252484321594,1215.8287003040314,4.804080247879028,0.0 -101200,1.7151995,1.8874111,,,,,,,,,,,,,, -101300,1.7942848,1.9646547,,,,,,,,,,,,,, -101400,1.7218176,1.9029045,,,,,,,,,,,,,, -101500,1.975013,1.9085556,,,,,,,,,,,,,, -101600,1.7756815,1.9684627,,,,,,,,,,,,,, -101700,1.786206,2.052622,,,,,,,,,,,,,, -101800,1.8278444,1.9669206,,,,,,,,,,,,,, -101900,1.8850775,2.0765872,,,,,,,,,,,,,, -102000,1.9090396,2.0285883,,,,,,,,,,,,,, -102100,1.6889682,1.8452544,,,,,,,,,,,,,, -102200,1.7026365,1.9971646,,,,,,,,,,,,,, -102300,1.710696,1.920136,,,,,,,,,,,,,, -102400,1.636037,1.7616175,,,,,,,,,,,,,, -102500,1.6530116,1.8711963,,,,,,,,,,,,,, -102600,2.006812,1.8385956,,,,,,,,,,,,,, -102637,,,0.5242147445678711,2.068694829940796,0.4916799962520599,2.279438018798828,50000.0,0.3857000172138214,3.0341577529907227,10000.0,34715.23096227646,35957.45670199394,34715.23096227646,1233.6238696575165,4.856037855148315,0.0 -102700,1.6376584,1.8224206,,,,,,,,,,,,,, -102800,1.7635119,2.0397856,,,,,,,,,,,,,, -102900,1.7789242,2.06894,,,,,,,,,,,,,, -103000,1.7172348,1.8358371,,,,,,,,,,,,,, -103100,1.7821716,1.9751055,,,,,,,,,,,,,, -103200,1.8746891,1.9302149,,,,,,,,,,,,,, -103300,1.8137568,1.8337361,,,,,,,,,,,,,, -103400,1.797705,1.8208902,,,,,,,,,,,,,, -103500,1.9623505,2.0660245,,,,,,,,,,,,,, -103600,1.7189264,1.9128436,,,,,,,,,,,,,, -103700,1.7512503,2.0021155,,,,,,,,,,,,,, -103800,1.6478478,1.7998295,,,,,,,,,,,,,, -103900,1.8158251,1.9489696,,,,,,,,,,,,,, -104000,1.8272746,1.9672654,,,,,,,,,,,,,, -104100,1.7073549,1.9463933,,,,,,,,,,,,,, -104150,,,0.440828263759613,2.609811305999756,0.4175199866294861,2.7619614601135254,50000.0,0.3177000284194946,3.549726247787476,10000.0,35225.40841984749,36485.60284900665,35225.40841984749,1251.489592075348,4.904949903488159,0.0 -104200,1.6078192,1.8633825,,,,,,,,,,,,,, -104300,1.6961156,1.8869828,,,,,,,,,,,,,, -104400,1.8461546,1.9568945,,,,,,,,,,,,,, -104500,1.7715994,2.0286446,,,,,,,,,,,,,, -104600,1.7650167,1.8838711,,,,,,,,,,,,,, -104700,1.7779008,1.9423068,,,,,,,,,,,,,, -104800,1.8601462,1.9156605,,,,,,,,,,,,,, -104900,1.8623784,1.8205705,,,,,,,,,,,,,, -105000,1.7313608,1.9359531,,,,,,,,,,,,,, -105100,1.8798991,1.9918923,,,,,,,,,,,,,, -105200,1.7831321,1.8529251,,,,,,,,,,,,,, -105300,1.9313757,2.1035483,,,,,,,,,,,,,, -105400,1.754409,1.8841385,,,,,,,,,,,,,, -105500,1.7156253,1.8422617,,,,,,,,,,,,,, -105600,1.8518133,1.8992242,,,,,,,,,,,,,, -105662,,,0.2795559465885162,3.90887188911438,0.2576999962329864,4.139658451080322,50000.0,0.1985000073909759,4.91162109375,10000.0,35735.49154949188,37013.47520160675,35735.49154949188,1269.1765999794006,4.950104713439941,0.0 -105700,1.7121772,1.8671318,,,,,,,,,,,,,, -105800,1.7258509,1.9090341,,,,,,,,,,,,,, -105900,1.7793144,1.8883616,,,,,,,,,,,,,, -106000,2.1448028,1.8704151,,,,,,,,,,,,,, -106100,2.0003965,1.8575908,,,,,,,,,,,,,, -106200,1.8865497,1.9546487,,,,,,,,,,,,,, -106300,1.937532,2.018349,,,,,,,,,,,,,, -106400,2.1914098,1.9228456,,,,,,,,,,,,,, -106500,1.775928,1.8453057,,,,,,,,,,,,,, -106600,1.7943672,1.9008058,,,,,,,,,,,,,, -106700,1.9447029,1.9337041,,,,,,,,,,,,,, -106800,1.9997703,1.9203463,,,,,,,,,,,,,, -106900,2.0481818,1.9129115,,,,,,,,,,,,,, -107000,1.8768137,1.9050009,,,,,,,,,,,,,, -107100,1.9632845,1.9268562,,,,,,,,,,,,,, -107174,,,0.500019907951355,2.244053363800049,0.4604199826717376,2.513838768005371,50000.0,0.3568000197410583,3.284795045852661,10000.0,36245.43359160423,37541.30133938789,36245.43359160423,1286.9605538845062,4.994918346405029,0.0 -107200,1.7301353,1.9716581,,,,,,,,,,,,,, -107300,1.895329,1.837691,,,,,,,,,,,,,, -107400,1.7146432,1.8524303,,,,,,,,,,,,,, -107500,1.7979648,1.8240392,,,,,,,,,,,,,, -107600,1.7252266,1.9428148,,,,,,,,,,,,,, -107700,1.7112991,1.9510503,,,,,,,,,,,,,, -107800,1.8731345,1.9730675,,,,,,,,,,,,,, -107900,1.776231,1.7655208,,,,,,,,,,,,,, -108000,1.990268,2.0446646,,,,,,,,,,,,,, -108100,1.8587625,1.8539302,,,,,,,,,,,,,, -108200,1.8828869,1.8735316,,,,,,,,,,,,,, -108300,1.891958,1.9665122,,,,,,,,,,,,,, -108400,1.779138,1.8883375,,,,,,,,,,,,,, -108500,2.0948398,1.9659325,,,,,,,,,,,,,, -108600,2.0275822,1.8783109,,,,,,,,,,,,,, -108686,,,0.4937818646430969,2.252739191055298,0.4607200026512146,2.472726345062256,50000.0,0.355100005865097,3.2092649936676025,10000.0,36755.34883475304,38068.86385965347,36755.34883475304,1304.5049712657928,5.042736291885376,0.0 -108700,1.9646415,1.9885594,,,,,,,,,,,,,, -108800,1.9353486,1.9664804,,,,,,,,,,,,,, -108900,1.9772319,1.9066586,,,,,,,,,,,,,, -109000,1.9895219,1.870941,,,,,,,,,,,,,, -109100,1.8752011,1.9938502,,,,,,,,,,,,,, -109200,1.8656485,1.9058343,,,,,,,,,,,,,, -109300,1.8462362,1.8133504,,,,,,,,,,,,,, -109400,1.8587934,1.9253509,,,,,,,,,,,,,, -109500,1.8049412,1.8630866,,,,,,,,,,,,,, -109600,1.8228209,1.8293332,,,,,,,,,,,,,, -109700,1.9545025,1.8574522,,,,,,,,,,,,,, -109800,2.125839,1.9609782,,,,,,,,,,,,,, -109900,1.885767,1.8519216,,,,,,,,,,,,,, -110000,1.9774901,1.9289417,,,,,,,,,,,,,, -110100,1.9689249,1.8467176,,,,,,,,,,,,,, -110199,,,0.5153260231018066,2.0981321334838867,0.4820399880409241,2.304888963699341,50000.0,0.3678000271320343,3.0940093994140625,10000.0,37265.49964928627,38596.9192211628,37265.49964928627,1322.303986787796,5.093266010284424,0.0 -110200,1.9262422,1.8747296,,,,,,,,,,,,,, -110300,2.0145893,1.8614122,,,,,,,,,,,,,, -110400,1.7986722,1.9244957,,,,,,,,,,,,,, -110500,1.9683808,1.8918678,,,,,,,,,,,,,, -110600,1.9596876,1.9861534,,,,,,,,,,,,,, -110700,2.1240876,1.9269657,,,,,,,,,,,,,, -110800,1.9340243,1.9059095,,,,,,,,,,,,,, -110900,2.0176814,1.9654251,,,,,,,,,,,,,, -111000,2.0076942,1.8238685,,,,,,,,,,,,,, -111100,1.8372475,1.8763934,,,,,,,,,,,,,, -111200,2.2012808,2.0132303,,,,,,,,,,,,,, -111300,1.9260716,1.8539507,,,,,,,,,,,,,, -111400,1.9733889,1.7762678,,,,,,,,,,,,,, -111500,2.0941243,1.8391314,,,,,,,,,,,,,, -111600,2.1962132,1.9707751,,,,,,,,,,,,,, -111700,1.8176621,1.8669405,,,,,,,,,,,,,, -111712,,,0.5056201815605164,2.1512796878814697,0.47079998254776,2.372743368148804,50000.0,0.3568000197410583,3.1949143409729004,10000.0,37775.48208498955,39124.7622282505,37775.48208498955,1340.0610992908478,5.142434120178223,0.0 -111800,1.8753841,1.7322581,,,,,,,,,,,,,, -111900,2.0320613,1.8676045,,,,,,,,,,,,,, -112000,1.9000672,1.9209237,,,,,,,,,,,,,, -112100,2.1525059,1.9200183,,,,,,,,,,,,,, -112200,2.0412667,1.7757833,,,,,,,,,,,,,, -112300,1.9722534,1.7281015,,,,,,,,,,,,,, -112400,2.056624,1.950409,,,,,,,,,,,,,, -112500,1.8630321,1.8803097,,,,,,,,,,,,,, -112600,1.9064311,1.7984438,,,,,,,,,,,,,, -112700,1.9729781,1.8498621,,,,,,,,,,,,,, -112800,1.8717042,1.925792,,,,,,,,,,,,,, -112900,2.0259848,1.9942627,,,,,,,,,,,,,, -113000,1.907945,1.8731104,,,,,,,,,,,,,, -113100,1.9991299,1.9387954,,,,,,,,,,,,,, -113200,1.9851468,1.7959034,,,,,,,,,,,,,, -113225,,,0.5051419138908386,2.175563335418701,0.4692799746990204,2.409187078475952,50000.0,0.3579000234603882,3.2229323387146,10000.0,38285.63218688965,39652.69889450073,38285.63218688965,1357.7409224510193,5.193975210189819,0.0 -113300,1.8989105,1.7752461,,,,,,,,,,,,,, -113400,2.0090692,1.8968,,,,,,,,,,,,,, -113500,2.0211859,1.9535048,,,,,,,,,,,,,, -113600,1.9273491,1.7337916,,,,,,,,,,,,,, -113700,2.0204256,1.8559138,,,,,,,,,,,,,, -113800,1.9764838,1.8067331,,,,,,,,,,,,,, -113900,2.0333757,1.7857127,,,,,,,,,,,,,, -114000,2.0210972,1.7675462,,,,,,,,,,,,,, -114100,1.8691926,1.7883116,,,,,,,,,,,,,, -114200,2.1196468,1.9928402,,,,,,,,,,,,,, -114300,2.0129244,1.8299439,,,,,,,,,,,,,, -114400,1.887639,1.8393728,,,,,,,,,,,,,, -114500,1.9230963,1.859513,,,,,,,,,,,,,, -114600,1.9603883,1.8595436,,,,,,,,,,,,,, -114700,2.0110807,1.8969026,,,,,,,,,,,,,, -114739,,,0.5345184803009033,2.0138721466064453,0.4868199825286865,2.304708957672119,50000.0,0.3734000325202942,3.111175060272217,10000.0,38795.65815782547,40181.27318763733,38795.65815782547,1376.1895372867584,5.24059534072876,0.0 -114800,2.2486043,1.9068283,,,,,,,,,,,,,, -114900,2.0212007,1.8408761,,,,,,,,,,,,,, -115000,2.135246,1.8941517,,,,,,,,,,,,,, -115100,1.8639252,1.6926415,,,,,,,,,,,,,, -115200,1.8487178,1.78633,,,,,,,,,,,,,, -115300,1.9424642,1.8021866,,,,,,,,,,,,,, -115400,1.8501076,1.7435918,,,,,,,,,,,,,, -115500,1.8158183,1.8489599,,,,,,,,,,,,,, -115600,1.9945314,1.7471728,,,,,,,,,,,,,, -115700,1.9102064,1.8364995,,,,,,,,,,,,,, -115800,1.9520648,1.7468071,,,,,,,,,,,,,, -115900,2.0626714,1.7909088,,,,,,,,,,,,,, -116000,2.0649378,1.8450522,,,,,,,,,,,,,, -116100,2.0418797,1.9524317,,,,,,,,,,,,,, -116200,2.2591295,1.8883197,,,,,,,,,,,,,, -116252,,,0.6103116869926453,1.59967303276062,0.5604000091552734,1.880994200706482,50000.0,0.4410000145435333,2.634089708328247,10000.0,39305.71879410744,40709.015604019165,39305.71879410744,1393.7754967212677,5.281083822250366,0.0 -116300,2.1540005,1.8759518,,,,,,,,,,,,,, -116400,2.0406415,1.7976019,,,,,,,,,,,,,, -116500,2.1782002,1.987546,,,,,,,,,,,,,, -116600,2.1153796,1.950846,,,,,,,,,,,,,, -116700,2.1059444,1.8376732,,,,,,,,,,,,,, -116800,1.8599594,1.794717,,,,,,,,,,,,,, -116900,2.0483005,1.7960206,,,,,,,,,,,,,, -117000,2.0012035,1.8367841,,,,,,,,,,,,,, -117100,2.1843128,1.8941706,,,,,,,,,,,,,, -117200,2.1769624,1.7603691,,,,,,,,,,,,,, -117300,1.9930849,1.8709266,,,,,,,,,,,,,, -117400,2.1170986,1.8220904,,,,,,,,,,,,,, -117500,2.0903163,1.835192,,,,,,,,,,,,,, -117600,2.377807,1.9053996,,,,,,,,,,,,,, -117700,2.1184993,1.8354588,,,,,,,,,,,,,, -117765,,,0.5663065910339355,1.824598789215088,0.5275200009346008,2.038508892059326,50000.0,0.4052000045776367,2.835646152496338,10000.0,39815.93658471108,41236.85900259018,39815.93658471108,1411.296749830246,5.329135894775391,0.0 -117800,2.1948907,1.8302109,,,,,,,,,,,,,, -117900,2.1199484,1.869288,,,,,,,,,,,,,, -118000,2.158856,1.8069465,,,,,,,,,,,,,, -118100,2.194653,1.7655035,,,,,,,,,,,,,, -118200,2.2675385,1.8302989,,,,,,,,,,,,,, -118300,2.031043,1.7007115,,,,,,,,,,,,,, -118400,2.0311496,1.8180299,,,,,,,,,,,,,, -118500,2.1671891,1.9610931,,,,,,,,,,,,,, -118600,2.1542685,1.7993671,,,,,,,,,,,,,, -118700,2.084633,1.7273948,,,,,,,,,,,,,, -118800,2.3462498,1.9556206,,,,,,,,,,,,,, -118900,2.1340418,1.7420633,,,,,,,,,,,,,, -119000,2.3510883,1.8297265,,,,,,,,,,,,,, -119100,1.9712608,1.8341253,,,,,,,,,,,,,, -119200,2.1574905,1.7765172,,,,,,,,,,,,,, -119278,,,0.2941246628761291,3.7996959686279297,0.2820599973201751,3.9288508892059326,50000.0,0.2028000056743621,4.84499979019165,10000.0,40326.08441853523,41764.800423145294,40326.08441853523,1428.9782931804657,5.38359808921814,0.0 -119300,2.01395,1.7487983,,,,,,,,,,,,,, -119400,2.11589,1.8350178,,,,,,,,,,,,,, -119500,2.0122824,1.7941856,,,,,,,,,,,,,, -119600,2.0384762,1.9459535,,,,,,,,,,,,,, -119700,2.1962168,1.8677362,,,,,,,,,,,,,, -119800,2.1852064,1.8281666,,,,,,,,,,,,,, -119900,2.1201549,1.7911496,,,,,,,,,,,,,, -120000,2.1432674,1.7427709,,,,,,,,,,,,,, -120100,2.0636106,1.7310256,,,,,,,,,,,,,, -120200,2.2592683,1.8126485,,,,,,,,,,,,,, -120300,2.2178342,1.765034,,,,,,,,,,,,,, -120400,2.1296258,1.8753388,,,,,,,,,,,,,, -120500,2.1151156,1.7341536,,,,,,,,,,,,,, -120600,2.1522055,1.6817542,,,,,,,,,,,,,, -120700,2.2509513,1.8392723,,,,,,,,,,,,,, -120791,,,0.5305325388908386,2.007243633270264,0.4994799792766571,2.1964006423950195,50000.0,0.3850000202655792,3.043755292892456,10000.0,40836.241518974304,42292.74313831329,40836.241518974304,1446.6527774333954,5.439541339874268,0.0 -120800,2.1123483,1.7579358,,,,,,,,,,,,,, -120900,2.413191,1.7971431,,,,,,,,,,,,,, -121000,1.9890426,1.7830348,,,,,,,,,,,,,, -121100,2.1242115,1.7052106,,,,,,,,,,,,,, -121200,2.506984,1.8044754,,,,,,,,,,,,,, -121300,2.2231827,1.840596,,,,,,,,,,,,,, -121400,2.3355324,1.7662913,,,,,,,,,,,,,, -121500,2.4698598,1.717386,,,,,,,,,,,,,, -121600,2.422776,1.7957973,,,,,,,,,,,,,, -121700,2.206024,1.7403498,,,,,,,,,,,,,, -121800,2.2365112,1.6941199,,,,,,,,,,,,,, -121900,2.245078,1.7966163,,,,,,,,,,,,,, -122000,2.1463666,1.8176742,,,,,,,,,,,,,, -122100,2.1913652,1.7175301,,,,,,,,,,,,,, -122200,2.2680655,1.8553588,,,,,,,,,,,,,, -122300,2.5449088,1.7797965,,,,,,,,,,,,,, -122303,,,0.5705317258834839,1.820460081100464,0.5156199932098389,2.1237943172454834,50000.0,0.3849000036716461,3.013498067855835,10000.0,41346.37443685532,42821.643122434616,41346.37443685532,1465.3159244060516,5.488732576370239,0.0 -122400,2.100716,1.7790837,,,,,,,,,,,,,, -122500,2.2496386,1.7280797,,,,,,,,,,,,,, -122600,2.2004402,1.7911474,,,,,,,,,,,,,, -122700,2.1227124,1.6730626,,,,,,,,,,,,,, -122800,2.1803577,1.7865372,,,,,,,,,,,,,, -122900,2.5297096,1.7652199,,,,,,,,,,,,,, -123000,2.0043702,1.665074,,,,,,,,,,,,,, -123100,2.1582406,1.7155685,,,,,,,,,,,,,, -123200,2.4271154,1.73051,,,,,,,,,,,,,, -123300,2.3943503,1.8891786,,,,,,,,,,,,,, -123400,2.3773086,1.7989404,,,,,,,,,,,,,, -123500,2.3037055,1.7027625,,,,,,,,,,,,,, -123600,2.299184,1.7972847,,,,,,,,,,,,,, -123700,2.4352298,1.886941,,,,,,,,,,,,,, -123800,2.1373527,1.7029468,,,,,,,,,,,,,, -123817,,,0.6350645422935486,1.4768481254577637,0.5727599859237671,1.826435089111328,50000.0,0.4502000212669372,2.5984723567962646,10000.0,41856.57735204697,43349.79244160652,41856.57735204697,1483.166562795639,5.5297324657440186,0.0 -123900,2.4757018,1.795956,,,,,,,,,,,,,, -124000,2.2662425,1.723665,,,,,,,,,,,,,, -124100,2.4789655,1.8318803,,,,,,,,,,,,,, -124200,2.4206643,1.721711,,,,,,,,,,,,,, -124300,2.3973136,1.7190055,,,,,,,,,,,,,, -124400,2.2109683,1.7236513,,,,,,,,,,,,,, -124500,2.3174765,1.697725,,,,,,,,,,,,,, -124600,2.6832771,1.7497251,,,,,,,,,,,,,, -124700,2.2834735,1.7381874,,,,,,,,,,,,,, -124800,2.4664917,1.6627032,,,,,,,,,,,,,, -124900,2.2655337,1.790554,,,,,,,,,,,,,, -125000,2.235253,1.7176543,,,,,,,,,,,,,, -125100,2.2091205,1.7022113,,,,,,,,,,,,,, -125200,2.2143502,1.6124724,,,,,,,,,,,,,, -125300,2.2847223,1.7338634,,,,,,,,,,,,,, -125330,,,0.5938695669174194,1.693787932395935,0.5395399928092957,2.001538276672364,50000.0,0.4202000200748443,2.7875430583953857,10000.0,42366.60649180412,43877.57340598106,42366.60649180412,1500.7979154586792,5.595268964767456,0.0 -125400,2.1969006,1.7438172,,,,,,,,,,,,,, -125500,2.2234251,1.7009754,,,,,,,,,,,,,, -125600,2.349488,1.7580307,,,,,,,,,,,,,, -125700,2.1414747,1.692484,,,,,,,,,,,,,, -125800,2.381898,1.7680272,,,,,,,,,,,,,, -125900,2.304241,1.7030258,,,,,,,,,,,,,, -126000,2.458339,1.7843587,,,,,,,,,,,,,, -126100,2.2016017,1.6796356,,,,,,,,,,,,,, -126200,2.38043,1.7328193,,,,,,,,,,,,,, -126300,2.493923,1.6858985,,,,,,,,,,,,,, -126400,2.3597527,1.7047582,,,,,,,,,,,,,, -126500,2.5554361,1.6418202,,,,,,,,,,,,,, -126600,2.6134782,1.8408145,,,,,,,,,,,,,, -126700,2.357439,1.8340359,,,,,,,,,,,,,, -126800,2.472641,1.8434255,,,,,,,,,,,,,, -126843,,,0.5989516973495483,1.6441659927368164,0.5536400079727173,1.9106814861297607,50000.0,0.443200021982193,2.6522583961486816,10000.0,42876.69231629372,44405.28691577912,42876.69231629372,1518.3210427761078,5.644772291183472,0.0 -126900,2.287936,1.6932998,,,,,,,,,,,,,, -127000,2.3314393,1.7111056,,,,,,,,,,,,,, -127100,2.3286958,1.7202559,,,,,,,,,,,,,, -127200,2.423984,1.6578628,,,,,,,,,,,,,, -127300,2.371707,1.74174,,,,,,,,,,,,,, -127400,2.493331,1.7201314,,,,,,,,,,,,,, -127500,2.825914,1.6737216,,,,,,,,,,,,,, -127600,2.3138876,1.6397388,,,,,,,,,,,,,, -127700,2.2568731,1.6354172,,,,,,,,,,,,,, -127800,2.3473148,1.8248968,,,,,,,,,,,,,, -127900,2.3432732,1.5740681,,,,,,,,,,,,,, -128000,2.7452157,1.784085,,,,,,,,,,,,,, -128100,2.3404286,1.6125672,,,,,,,,,,,,,, -128200,2.2995048,1.6886436,,,,,,,,,,,,,, -128300,2.3239572,1.6957203,,,,,,,,,,,,,, -128356,,,0.6427175998687744,1.4331533908843994,0.5913599729537964,1.713654637336731,50000.0,0.4610000252723694,2.505321979522705,10000.0,43386.7288172245,44933.01505827904,43386.7288172245,1535.9049890041351,5.6943440437316895,0.0 -128400,2.3866374,1.6169913,,,,,,,,,,,,,, -128500,2.5401003,1.7466669,,,,,,,,,,,,,, -128600,2.4007435,1.6172838,,,,,,,,,,,,,, -128700,2.3793204,1.6603663,,,,,,,,,,,,,, -128800,2.5641406,1.8449123,,,,,,,,,,,,,, -128900,2.4914827,1.7256715,,,,,,,,,,,,,, -129000,2.4061956,1.6953615,,,,,,,,,,,,,, -129100,2.5330057,1.7479248,,,,,,,,,,,,,, -129200,2.8104188,1.653068,,,,,,,,,,,,,, -129300,2.522371,1.793645,,,,,,,,,,,,,, -129400,2.5393534,1.6923364,,,,,,,,,,,,,, -129500,2.599916,1.8334926,,,,,,,,,,,,,, -129600,2.4323359,1.5997152,,,,,,,,,,,,,, -129700,2.4726188,1.6335437,,,,,,,,,,,,,, -129800,2.6409504,1.7505695,,,,,,,,,,,,,, -129869,,,0.5605269074440002,1.889477491378784,0.5187999606132507,2.140009880065918,50000.0,0.4061000049114227,2.9258809089660645,10000.0,43896.84852051735,45461.23577904701,43896.84852051735,1553.898535490036,5.74742603302002,0.0 -129900,2.56252,1.6189864,,,,,,,,,,,,,, -130000,2.4411695,1.6987708,,,,,,,,,,,,,, -130100,2.4817882,1.6847241,,,,,,,,,,,,,, -130200,2.4447167,1.5823823,,,,,,,,,,,,,, -130300,2.571192,1.7608529,,,,,,,,,,,,,, -130400,2.753886,1.7461572,,,,,,,,,,,,,, -130500,2.4853048,1.8026822,,,,,,,,,,,,,, -130600,2.5672503,1.810142,,,,,,,,,,,,,, -130700,2.6550834,1.5985819,,,,,,,,,,,,,, -130800,2.510601,1.6133816,,,,,,,,,,,,,, -130900,2.5160882,1.6554799,,,,,,,,,,,,,, -131000,2.6349702,1.6358356,,,,,,,,,,,,,, -131100,2.5048914,1.6774621,,,,,,,,,,,,,, -131200,2.385751,1.609703,,,,,,,,,,,,,, -131300,2.5912602,1.6799465,,,,,,,,,,,,,, -131383,,,0.6741071343421936,1.2939298152923584,0.5915399789810181,1.7248343229293823,50000.0,0.4692000150680542,2.497609138488769,10000.0,44406.79634475708,45989.0142326355,44406.79634475708,1571.618933916092,5.801463842391968,0.0 -131400,2.534978,1.6527832,,,,,,,,,,,,,, -131500,2.4317656,1.5243835,,,,,,,,,,,,,, -131600,2.5907595,1.7379373,,,,,,,,,,,,,, -131700,2.6427627,1.5803372,,,,,,,,,,,,,, -131800,2.8142922,1.5956442,,,,,,,,,,,,,, -131900,2.4813316,1.7187309,,,,,,,,,,,,,, -132000,2.5866594,1.6742779,,,,,,,,,,,,,, -132100,2.5651348,1.754443,,,,,,,,,,,,,, -132200,2.6476376,1.711014,,,,,,,,,,,,,, -132300,2.50948,1.7441082,,,,,,,,,,,,,, -132400,2.5845945,1.6509428,,,,,,,,,,,,,, -132500,2.5573437,1.6989502,,,,,,,,,,,,,, -132600,2.565142,1.6143167,,,,,,,,,,,,,, -132700,2.7090857,1.6905624,,,,,,,,,,,,,, -132800,2.728172,1.6517112,,,,,,,,,,,,,, -132895,,,0.6228475570678711,1.5460199117660522,0.5662800073623657,1.8539618253707888,50000.0,0.444100022315979,2.6210970878601074,10000.0,44916.78594422341,46516.66002821922,44916.78594422341,1589.165275335312,5.855878114700317,0.0 -132900,2.5918589,1.6959426,,,,,,,,,,,,,, -133000,2.7597053,1.7275357,,,,,,,,,,,,,, -133100,2.6349976,1.6246805,,,,,,,,,,,,,, -133200,2.5970898,1.6612607,,,,,,,,,,,,,, -133300,2.727763,1.6108325,,,,,,,,,,,,,, -133400,2.5095747,1.6487831,,,,,,,,,,,,,, -133500,2.8644996,1.5979655,,,,,,,,,,,,,, -133600,2.5542378,1.6174352,,,,,,,,,,,,,, -133700,2.5595348,1.5863801,,,,,,,,,,,,,, -133800,2.611759,1.7023973,,,,,,,,,,,,,, -133900,2.6200092,1.6505997,,,,,,,,,,,,,, -134000,2.6795628,1.7495483,,,,,,,,,,,,,, -134100,2.751828,1.7395008,,,,,,,,,,,,,, -134200,2.8742976,1.6909325,,,,,,,,,,,,,, -134300,2.6937132,1.6245056,,,,,,,,,,,,,, -134400,2.7498815,1.5289836,,,,,,,,,,,,,, -134408,,,0.6938576102256775,1.2104731798171997,0.6318199634552002,1.532360315322876,50000.0,0.5121000409126282,2.2486801147460938,10000.0,45426.89608311653,47044.71316599846,45426.89608311653,1607.0003921985626,5.908795595169067,0.0 -134500,2.6188748,1.6002035,,,,,,,,,,,,,, -134600,2.6718712,1.6424618,,,,,,,,,,,,,, -134700,2.8221984,1.687814,,,,,,,,,,,,,, -134800,2.708048,1.6453578,,,,,,,,,,,,,, -134900,2.7750754,1.6863353,,,,,,,,,,,,,, -135000,2.7880974,1.5958557,,,,,,,,,,,,,, -135100,3.0073695,1.7571474,,,,,,,,,,,,,, -135200,2.8402753,1.6720349,,,,,,,,,,,,,, -135300,2.8839588,1.7192068,,,,,,,,,,,,,, -135400,2.755597,1.6899834,,,,,,,,,,,,,, -135500,2.60934,1.5383582,,,,,,,,,,,,,, -135600,2.5780835,1.6586796,,,,,,,,,,,,,, -135700,2.583758,1.6021392,,,,,,,,,,,,,, -135800,2.606391,1.6198832,,,,,,,,,,,,,, -135900,2.813613,1.5007355,,,,,,,,,,,,,, -135922,,,0.6696428656578064,1.3166732788085938,0.6161999702453613,1.6147363185882568,50000.0,0.4917000234127044,2.346391439437866,10000.0,45937.00923109055,47572.57046985626,45937.00923109055,1624.6373193264008,5.960654020309448,0.0 -136000,2.610472,1.5653688,,,,,,,,,,,,,, -136100,2.9183142,1.6736189,,,,,,,,,,,,,, -136200,2.6207383,1.5566154,,,,,,,,,,,,,, -136300,2.7123024,1.5855106,,,,,,,,,,,,,, -136400,2.5801017,1.5892926,,,,,,,,,,,,,, -136500,2.898173,1.6971368,,,,,,,,,,,,,, -136600,2.9329858,1.61974,,,,,,,,,,,,,, -136700,2.7033043,1.5123467,,,,,,,,,,,,,, -136800,2.565677,1.4716281,,,,,,,,,,,,,, -136900,2.4748447,1.3695371,,,,,,,,,,,,,, -137000,2.861777,1.7045516,,,,,,,,,,,,,, -137100,2.8316731,1.5968509,,,,,,,,,,,,,, -137200,2.8688395,1.624827,,,,,,,,,,,,,, -137300,2.7133236,1.4399148,,,,,,,,,,,,,, -137400,2.8984556,1.6332291,,,,,,,,,,,,,, -137436,,,0.6374163031578064,1.4651330709457395,0.5859400033950806,1.7452064752578735,50000.0,0.4622000157833099,2.4913787841796875,10000.0,46447.15961503983,48101.16772198677,46447.15961503983,1642.9792296886444,6.0111939907073975,0.0 -137500,2.6563563,1.4964887,,,,,,,,,,,,,, -137600,2.860861,1.6018859,,,,,,,,,,,,,, -137700,2.8301046,1.6307224,,,,,,,,,,,,,, -137800,2.7980275,1.5667858,,,,,,,,,,,,,, -137900,2.818139,1.5440438,,,,,,,,,,,,,, -138000,2.631309,1.5538901,,,,,,,,,,,,,, -138100,2.8291206,1.6156762,,,,,,,,,,,,,, -138200,2.873768,1.5477489,,,,,,,,,,,,,, -138300,2.911943,1.6506042,,,,,,,,,,,,,, -138400,2.784326,1.5555601,,,,,,,,,,,,,, -138500,2.7149036,1.5646122,,,,,,,,,,,,,, -138600,2.9040277,1.5659301,,,,,,,,,,,,,, -138700,3.0618742,1.5786691,,,,,,,,,,,,,, -138800,2.9327948,1.5938253,,,,,,,,,,,,,, -138900,2.681929,1.4897232,,,,,,,,,,,,,, -138949,,,0.6802853941917419,1.2585314512252808,0.6293399930000305,1.5424376726150513,50000.0,0.49590003490448,2.300657033920288,10000.0,46957.08819317818,48629.02605581284,46957.08819317818,1660.8099303245544,6.053507328033447,0.0 -139000,2.7925563,1.6220151,,,,,,,,,,,,,, -139100,2.854995,1.5570326,,,,,,,,,,,,,, -139200,3.1445339,1.4825801,,,,,,,,,,,,,, -139300,2.904072,1.5377463,,,,,,,,,,,,,, -139400,2.7204373,1.5425711,,,,,,,,,,,,,, -139500,2.9252641,1.520068,,,,,,,,,,,,,, -139600,2.9653342,1.5937593,,,,,,,,,,,,,, -139700,2.9672768,1.5758193,,,,,,,,,,,,,, -139800,2.8778265,1.5538509,,,,,,,,,,,,,, -139900,2.9017217,1.5664561,,,,,,,,,,,,,, -140000,2.667847,1.5897864,,,,,,,,,,,,,, -140100,2.952451,1.5622796,,,,,,,,,,,,,, -140200,2.9226515,1.5676144,,,,,,,,,,,,,, -140300,3.3723195,1.6135502,,,,,,,,,,,,,, -140400,2.8241684,1.5229561,,,,,,,,,,,,,, -140462,,,0.7210419178009033,1.0726977586746216,0.6291199922561646,1.5267126560211182,50000.0,0.5039000511169434,2.279113292694092,10000.0,47467.14528131485,49157.06298923493,47467.14528131485,1678.656478881836,6.13153076171875,0.0 -140500,2.9755151,1.5505154,,,,,,,,,,,,,, -140600,2.7868862,1.5592844,,,,,,,,,,,,,, -140700,3.2838817,1.518437,,,,,,,,,,,,,, -140800,3.0908473,1.5638944,,,,,,,,,,,,,, -140900,3.167748,1.6321433,,,,,,,,,,,,,, -141000,3.1129973,1.6301978,,,,,,,,,,,,,, -141100,2.9720893,1.5312529,,,,,,,,,,,,,, -141200,2.969812,1.515198,,,,,,,,,,,,,, -141300,3.2037318,1.4832124,,,,,,,,,,,,,, -141400,3.0850394,1.5146017,,,,,,,,,,,,,, -141500,3.0575953,1.5175331,,,,,,,,,,,,,, -141600,3.216226,1.528719,,,,,,,,,,,,,, -141700,3.2235484,1.6519089,,,,,,,,,,,,,, -141800,2.9859383,1.5462773,,,,,,,,,,,,,, -141900,2.9594865,1.3989319,,,,,,,,,,,,,, -141976,,,0.7115154266357422,1.108195662498474,0.6423599720001221,1.4797382354736328,50000.0,0.5092000365257263,2.2130398750305176,10000.0,47977.29451775551,49685.04183840752,47977.29451775551,1696.3789055347445,6.181941986083984,0.0 -142000,3.084812,1.4700127,,,,,,,,,,,,,, -142100,3.255526,1.4930611,,,,,,,,,,,,,, -142200,2.9521563,1.5192199,,,,,,,,,,,,,, -142300,3.119589,1.6016675,,,,,,,,,,,,,, -142400,2.9961593,1.5277439,,,,,,,,,,,,,, -142500,2.8802598,1.4817746,,,,,,,,,,,,,, -142600,3.1709163,1.4553776,,,,,,,,,,,,,, -142700,2.9951193,1.5081757,,,,,,,,,,,,,, -142800,3.0948827,1.4807701,,,,,,,,,,,,,, -142900,3.2533607,1.7099378,,,,,,,,,,,,,, -143000,3.0154033,1.5668987,,,,,,,,,,,,,, -143100,3.0684903,1.4634695,,,,,,,,,,,,,, -143200,3.0447702,1.4291335,,,,,,,,,,,,,, -143300,3.1402056,1.5922742,,,,,,,,,,,,,, -143400,3.1736286,1.505445,,,,,,,,,,,,,, -143487,,,0.7179726958274841,1.1000560522079468,0.6461799740791321,1.4582358598709106,50000.0,0.5218999981880188,2.1681156158447266,10000.0,48487.30615639687,50212.84715676308,48487.30615639687,1714.0664644241333,6.234250068664551,0.0 -143500,3.3475008,1.5290378,,,,,,,,,,,,,, -143600,3.213611,1.580918,,,,,,,,,,,,,, -143700,3.0934749,1.4837836,,,,,,,,,,,,,, -143800,3.172968,1.4479246,,,,,,,,,,,,,, -143900,3.166955,1.500142,,,,,,,,,,,,,, -144000,3.2235458,1.5626397,,,,,,,,,,,,,, -144100,3.3815303,1.5893427,,,,,,,,,,,,,, -144200,3.204959,1.5836692,,,,,,,,,,,,,, -144300,3.0721316,1.4843032,,,,,,,,,,,,,, -144400,3.3426716,1.5502774,,,,,,,,,,,,,, -144500,2.9327683,1.5081251,,,,,,,,,,,,,, -144600,2.9906857,1.4646293,,,,,,,,,,,,,, -144700,3.223907,1.6085457,,,,,,,,,,,,,, -144800,3.152667,1.5056118,,,,,,,,,,,,,, -144900,3.2328308,1.4945427,,,,,,,,,,,,,, -145000,3.4511132,1.4924474,,,,,,,,,,,,,, -145001,,,0.7001753449440002,1.1895331144332886,0.639519989490509,1.4978569746017456,50000.0,0.5112000107765198,2.244545221328736,10000.0,48997.7052898407,50741.36831307411,48997.7052898407,1732.0806908607483,6.286098003387451,0.0 -145100,3.1652548,1.4218378,,,,,,,,,,,,,, -145200,3.1679263,1.4978112,,,,,,,,,,,,,, -145300,3.0550158,1.4705637,,,,,,,,,,,,,, -145400,3.4100502,1.3452911,,,,,,,,,,,,,, -145500,3.2172153,1.4844921,,,,,,,,,,,,,, -145600,3.1095738,1.4261563,,,,,,,,,,,,,, -145700,3.1481535,1.5073161,,,,,,,,,,,,,, -145800,3.2045379,1.6091404,,,,,,,,,,,,,, -145900,3.074637,1.4012054,,,,,,,,,,,,,, -146000,3.30028,1.5203445,,,,,,,,,,,,,, -146100,3.1570928,1.4165413,,,,,,,,,,,,,, -146200,3.2607474,1.4603288,,,,,,,,,,,,,, -146300,3.410899,1.4618087,,,,,,,,,,,,,, -146400,3.3209865,1.4886582,,,,,,,,,,,,,, -146500,3.1885302,1.4278007,,,,,,,,,,,,,, -146513,,,0.6998565196990967,1.1637591123580933,0.6406199932098389,1.486459493637085,50000.0,0.5179000496864319,2.226635932922364,10000.0,49507.65221524239,51269.19027972221,49507.65221524239,1749.8422105312347,6.343450784683228,0.0 -146600,3.1062002,1.483386,,,,,,,,,,,,,, -146700,3.1780457,1.3925104,,,,,,,,,,,,,, -146800,3.6810257,1.491505,,,,,,,,,,,,,, -146900,3.174426,1.5480757,,,,,,,,,,,,,, -147000,3.287595,1.5715507,,,,,,,,,,,,,, -147100,3.2022917,1.4818643,,,,,,,,,,,,,, -147200,3.202513,1.3961385,,,,,,,,,,,,,, -147300,3.4512825,1.3910489,,,,,,,,,,,,,, -147400,3.1983256,1.4407247,,,,,,,,,,,,,, -147500,3.3165379,1.4283077,,,,,,,,,,,,,, -147600,3.2552283,1.5004053,,,,,,,,,,,,,, -147700,3.2149425,1.4576018,,,,,,,,,,,,,, -147800,3.8548565,1.5106939,,,,,,,,,,,,,, -147900,3.2827022,1.497334,,,,,,,,,,,,,, -148000,3.4319909,1.3899151,,,,,,,,,,,,,, -148026,,,0.7155213356018066,1.1146916151046753,0.6486999988555908,1.4627606868743896,50000.0,0.524399995803833,2.186429262161255,10000.0,50017.61883044243,51797.05554533005,50017.61883044243,1767.6306097507477,6.397595882415772,0.0 -148100,3.3349798,1.4029756,,,,,,,,,,,,,, -148200,3.3310916,1.3603352,,,,,,,,,,,,,, -148300,3.401891,1.2894629,,,,,,,,,,,,,, -148400,3.4114168,1.358032,,,,,,,,,,,,,, -148500,3.2699609,1.4766417,,,,,,,,,,,,,, -148600,3.3525333,1.46664,,,,,,,,,,,,,, -148700,3.3743818,1.472884,,,,,,,,,,,,,, -148800,3.5135882,1.3654075,,,,,,,,,,,,,, -148900,3.0981178,1.3112143,,,,,,,,,,,,,, -149000,3.4870772,1.4866247,,,,,,,,,,,,,, -149100,3.8091378,1.4956645,,,,,,,,,,,,,, -149200,3.4848487,1.3476522,,,,,,,,,,,,,, -149300,3.5779858,1.4713796,,,,,,,,,,,,,, -149400,3.3903627,1.404856,,,,,,,,,,,,,, -149500,3.330027,1.4577198,,,,,,,,,,,,,, -149538,,,0.7698501348495483,0.8804860711097717,0.6730799674987793,1.3303167819976809,50000.0,0.539900004863739,2.03213119506836,10000.0,50527.6418004036,52325.024513721466,50527.6418004036,1785.4646661281586,6.454508543014526,0.0 -149600,3.508543,1.4631011,,,,,,,,,,,,,, -149700,3.1562283,1.3997537,,,,,,,,,,,,,, -149800,3.5542135,1.4245954,,,,,,,,,,,,,, -149900,3.4983733,1.3823252,,,,,,,,,,,,,, -150000,3.5266137,1.3610326,,,,,,,,,,,,,, -150100,3.4888437,1.4618477,,,,,,,,,,,,,, -150200,3.3614826,1.282184,,,,,,,,,,,,,, -150300,3.727773,1.438931,,,,,,,,,,,,,, -150400,3.567918,1.3246408,,,,,,,,,,,,,, -150500,3.5283365,1.4496719,,,,,,,,,,,,,, -150600,3.420122,1.4806706,,,,,,,,,,,,,, -150700,3.584104,1.3601735,,,,,,,,,,,,,, -150800,3.5391252,1.3232412,,,,,,,,,,,,,, -150900,3.7357469,1.5671993,,,,,,,,,,,,,, -151000,3.5031304,1.2907991,,,,,,,,,,,,,, -151051,,,0.7565170526504517,0.9239798188209534,0.6729199886322021,1.331039309501648,50000.0,0.5393000245094299,2.099137544631958,10000.0,51037.73308491707,52852.88480973244,51037.73308491707,1803.1257722377777,6.508534669876099,0.0 -151100,3.410952,1.4281154,,,,,,,,,,,,,, -151200,3.6260343,1.4127253,,,,,,,,,,,,,, -151300,3.4817345,1.3786789,,,,,,,,,,,,,, -151400,3.6527436,1.4092226,,,,,,,,,,,,,, -151500,3.4791262,1.3907032,,,,,,,,,,,,,, -151600,3.636342,1.444553,,,,,,,,,,,,,, -151700,3.538605,1.3306112,,,,,,,,,,,,,, -151800,3.4724731,1.330107,,,,,,,,,,,,,, -151900,3.8170402,1.4316345,,,,,,,,,,,,,, -152000,3.6035328,1.4089465,,,,,,,,,,,,,, -152100,3.5897431,1.4064723,,,,,,,,,,,,,, -152200,3.6885135,1.4215491,,,,,,,,,,,,,, -152300,3.708663,1.4214146,,,,,,,,,,,,,, -152400,3.7028162,1.397355,,,,,,,,,,,,,, -152500,3.2706845,1.2907497,,,,,,,,,,,,,, -152564,,,0.7428451776504517,0.9893755316734314,0.6662999987602234,1.3725179433822632,50000.0,0.5270000100135803,2.163541555404663,10000.0,51547.81756424904,53381.75227236748,51547.81756424904,1821.800268173217,6.5628204345703125,0.0 -152600,3.5711153,1.4553361,,,,,,,,,,,,,, -152700,3.952268,1.2973883,,,,,,,,,,,,,, -152800,3.569057,1.3500162,,,,,,,,,,,,,, -152900,3.9700823,1.4184551,,,,,,,,,,,,,, -153000,3.7328186,1.3535911,,,,,,,,,,,,,, -153100,3.8585565,1.421308,,,,,,,,,,,,,, -153200,3.7467463,1.3013222,,,,,,,,,,,,,, -153300,4.230438,1.4655638,,,,,,,,,,,,,, -153400,3.8488386,1.4610864,,,,,,,,,,,,,, -153500,3.83671,1.4197302,,,,,,,,,,,,,, -153600,3.6537745,1.3892437,,,,,,,,,,,,,, -153700,3.5136957,1.2853715,,,,,,,,,,,,,, -153800,3.7744994,1.319015,,,,,,,,,,,,,, -153900,3.4077318,1.3290651,,,,,,,,,,,,,, -154000,3.8958628,1.3298372,,,,,,,,,,,,,, -154077,,,0.7563576102256775,0.9220529198646544,0.6801199913024902,1.3138034343719482,50000.0,0.5506000518798828,2.063292980194092,10000.0,52057.731224775314,53909.18029880524,52057.731224775314,1839.209418058396,6.6128644943237305,0.0 -154100,3.8426092,1.472359,,,,,,,,,,,,,, -154200,3.4661536,1.25391,,,,,,,,,,,,,, -154300,3.8676026,1.4136659,,,,,,,,,,,,,, -154400,3.5955827,1.2005184,,,,,,,,,,,,,, -154500,3.7983074,1.42944,,,,,,,,,,,,,, -154600,3.6587076,1.2282691,,,,,,,,,,,,,, -154700,3.7446997,1.3325495,,,,,,,,,,,,,, -154800,4.0790553,1.4780324,,,,,,,,,,,,,, -154900,3.5342915,1.1889544,,,,,,,,,,,,,, -155000,3.7138057,1.3292508,,,,,,,,,,,,,, -155100,3.7506828,1.2875048,,,,,,,,,,,,,, -155200,3.9794257,1.3727869,,,,,,,,,,,,,, -155300,4.053749,1.4391624,,,,,,,,,,,,,, -155400,3.9506207,1.3158826,,,,,,,,,,,,,, -155500,3.626866,1.2719791,,,,,,,,,,,,,, -155590,,,0.7622169852256775,0.8947357535362244,0.6869399547576904,1.2795292139053345,50000.0,0.555400013923645,2.034101963043213,10000.0,52567.773965358734,54437.1859266758,52567.773965358734,1857.059386253357,6.671154737472534,0.0 -155600,3.7317026,1.3161263,,,,,,,,,,,,,, -155700,3.9794998,1.3852617,,,,,,,,,,,,,, -155800,3.7594578,1.3366895,,,,,,,,,,,,,, -155900,3.9755914,1.4323199,,,,,,,,,,,,,, -156000,3.6010337,1.2518073,,,,,,,,,,,,,, -156100,3.8123133,1.2203128,,,,,,,,,,,,,, -156200,4.107534,1.356415,,,,,,,,,,,,,, -156300,3.8316834,1.2987148,,,,,,,,,,,,,, -156400,4.0071564,1.3099996,,,,,,,,,,,,,, -156500,4.11666,1.3387786,,,,,,,,,,,,,, -156600,3.883818,1.3486513,,,,,,,,,,,,,, -156700,4.172687,1.3172462,,,,,,,,,,,,,, -156800,4.2931767,1.256636,,,,,,,,,,,,,, -156900,4.0027113,1.2256079,,,,,,,,,,,,,, -157000,3.8902726,1.3198836,,,,,,,,,,,,,, -157100,4.14655,1.3411945,,,,,,,,,,,,,, -157103,,,0.7747727632522583,0.8410025238990784,0.6900999546051025,1.268452286720276,50000.0,0.5601000189781189,2.017871379852295,10000.0,53077.83566641808,54964.90230464935,53077.83566641808,1874.6021373271944,6.727156400680542,0.0 -157200,3.9225788,1.2368362,,,,,,,,,,,,,, -157300,3.791703,1.2950953,,,,,,,,,,,,,, -157400,4.126387,1.3111477,,,,,,,,,,,,,, -157500,4.134638,1.2546804,,,,,,,,,,,,,, -157600,3.7754881,1.2836999,,,,,,,,,,,,,, -157700,4.050344,1.30836,,,,,,,,,,,,,, -157800,4.218883,1.3231288,,,,,,,,,,,,,, -157900,4.288705,1.3841,,,,,,,,,,,,,, -158000,4.076632,1.1740258,,,,,,,,,,,,,, -158100,3.9439526,1.2314955,,,,,,,,,,,,,, -158200,4.0390463,1.336721,,,,,,,,,,,,,, -158300,4.1802163,1.3035465,,,,,,,,,,,,,, -158400,3.9438527,1.3363227,,,,,,,,,,,,,, -158500,4.372042,1.4235795,,,,,,,,,,,,,, -158600,3.941734,1.2423409,,,,,,,,,,,,,, -158616,,,0.8008609414100647,0.7219278812408447,0.6961399912834167,1.2478055953979492,50000.0,0.5782000422477722,1.9279481172561648,10000.0,53588.04853415489,55492.84319996834,53588.04853415489,1892.21878695488,6.782275915145874,0.0 -158700,4.383186,1.3005332,,,,,,,,,,,,,, -158800,3.9279873,1.3229691,,,,,,,,,,,,,, -158900,3.891074,1.2909775,,,,,,,,,,,,,, -159000,3.9884026,1.2872705,,,,,,,,,,,,,, -159100,4.467808,1.3272073,,,,,,,,,,,,,, -159200,4.3560176,1.2847307,,,,,,,,,,,,,, -159300,4.0998564,1.2628962,,,,,,,,,,,,,, -159400,4.1316686,1.2782124,,,,,,,,,,,,,, -159500,4.171126,1.3121723,,,,,,,,,,,,,, -159600,4.3045883,1.2666342,,,,,,,,,,,,,, -159700,3.8332753,1.2355447,,,,,,,,,,,,,, -159800,4.221223,1.2958338,,,,,,,,,,,,,, -159900,4.6540895,1.3149697,,,,,,,,,,,,,, -160000,4.3065314,1.3884946,,,,,,,,,,,,,, -160100,4.4025073,1.244709,,,,,,,,,,,,,, -160130,,,0.8089724183082581,0.7101982831954956,0.7047399878501892,1.194737195968628,50000.0,0.5812000036239624,1.8732895851135247,10000.0,54098.26779890061,56021.37393474579,54098.26779890061,1910.420448064804,6.837030410766602,0.0 -160200,3.9080179,1.185091,,,,,,,,,,,,,, -160300,4.6467385,1.3522644,,,,,,,,,,,,,, -160400,4.64732,1.3246286,,,,,,,,,,,,,, -160500,3.9039788,1.239986,,,,,,,,,,,,,, -160600,3.951142,1.1434915,,,,,,,,,,,,,, -160700,4.1085396,1.1935095,,,,,,,,,,,,,, -160800,4.294335,1.1902599,,,,,,,,,,,,,, -160900,4.221175,1.2057934,,,,,,,,,,,,,, -161000,4.570474,1.2557968,,,,,,,,,,,,,, -161100,3.8935587,1.13097,,,,,,,,,,,,,, -161200,3.9547887,1.2656565,,,,,,,,,,,,,, -161300,4.3642297,1.2977637,,,,,,,,,,,,,, -161400,4.6441956,1.268663,,,,,,,,,,,,,, -161500,4.3083277,1.15313,,,,,,,,,,,,,, -161600,4.196878,1.1671422,,,,,,,,,,,,,, -161643,,,0.7973134517669678,0.7421764135360718,0.702739953994751,1.2192460298538208,50000.0,0.5758000016212463,1.9370954036712649,10000.0,54608.44234919548,56549.32446479797,54608.44234919548,1928.0951828956604,6.881301641464233,0.0 -161700,4.285399,1.2457255,,,,,,,,,,,,,, -161800,4.267925,1.200028,,,,,,,,,,,,,, -161900,4.653656,1.2029098,,,,,,,,,,,,,, -162000,4.441015,1.2345321,,,,,,,,,,,,,, -162100,4.3488913,1.280916,,,,,,,,,,,,,, -162200,4.359267,1.3053439,,,,,,,,,,,,,, -162300,4.404143,1.2379118,,,,,,,,,,,,,, -162400,4.421928,1.2331786,,,,,,,,,,,,,, -162500,4.25628,1.1874151,,,,,,,,,,,,,, -162600,4.4436564,1.1466241,,,,,,,,,,,,,, -162700,4.333523,1.155528,,,,,,,,,,,,,, -162800,4.2582526,1.1782297,,,,,,,,,,,,,, -162900,4.34112,1.1471198,,,,,,,,,,,,,, -163000,4.535535,1.1889347,,,,,,,,,,,,,, -163100,4.362617,1.19122,,,,,,,,,,,,,, -163156,,,0.8103076815605164,0.7060856819152832,0.7111200094223022,1.171807885169983,50000.0,0.5887000560760498,1.8791732788085933,10000.0,55118.50452852249,57077.438390254974,55118.50452852249,1946.0246217250824,6.949363708496094,0.0 -163200,4.1151376,1.0808609,,,,,,,,,,,,,, -163300,4.601838,1.3053517,,,,,,,,,,,,,, -163400,4.5915875,1.177891,,,,,,,,,,,,,, -163500,4.3849607,1.1730704,,,,,,,,,,,,,, -163600,4.573286,1.1806526,,,,,,,,,,,,,, -163700,4.7608953,1.151593,,,,,,,,,,,,,, -163800,4.509412,1.167189,,,,,,,,,,,,,, -163900,4.468714,1.1946547,,,,,,,,,,,,,, -164000,4.226554,1.0624704,,,,,,,,,,,,,, -164100,4.412477,1.1198967,,,,,,,,,,,,,, -164200,4.273656,1.1423111,,,,,,,,,,,,,, -164300,4.760163,1.2271382,,,,,,,,,,,,,, -164400,4.814959,1.1803195,,,,,,,,,,,,,, -164500,4.4698315,1.1657612,,,,,,,,,,,,,, -164600,4.355335,1.1881907,,,,,,,,,,,,,, -164668,,,0.8194754123687744,0.6554054021835327,0.7178800106048584,1.144294261932373,50000.0,0.5937000513076782,1.8554983139038088,10000.0,55628.45368814469,57605.29715466499,55628.45368814469,1963.8206391334527,7.007922172546387,0.0 -164700,4.2416086,1.0645779,,,,,,,,,,,,,, -164800,4.7153006,1.1857444,,,,,,,,,,,,,, -164900,4.6422853,1.2843852,,,,,,,,,,,,,, -165000,4.7897377,1.1483456,,,,,,,,,,,,,, -165100,4.6969953,1.1625959,,,,,,,,,,,,,, -165200,4.3438687,1.0303032,,,,,,,,,,,,,, -165300,4.750995,1.2496696,,,,,,,,,,,,,, -165400,4.476917,1.1145194,,,,,,,,,,,,,, -165500,4.3564825,1.0994538,,,,,,,,,,,,,, -165600,4.6718416,1.1917559,,,,,,,,,,,,,, -165700,4.4429865,1.1024494,,,,,,,,,,,,,, -165800,4.35326,1.1140448,,,,,,,,,,,,,, -165900,4.8720264,1.1411992,,,,,,,,,,,,,, -166000,4.427727,1.134085,,,,,,,,,,,,,, -166100,4.424196,1.0389669,,,,,,,,,,,,,, -166181,,,0.8342036008834839,0.6015270948410034,0.727840006351471,1.113991379737854,50000.0,0.6082000136375427,1.8160563707351685,10000.0,56138.54439759255,58133.3680229187,56138.54439759255,1981.6831741333008,7.071213245391846,0.0 -166200,4.380922,1.0853236,,,,,,,,,,,,,, -166300,4.93746,1.2893802,,,,,,,,,,,,,, -166400,5.004527,1.120145,,,,,,,,,,,,,, -166500,4.8047957,1.1292009,,,,,,,,,,,,,, -166600,4.578129,1.2124034,,,,,,,,,,,,,, -166700,4.8184204,1.1346997,,,,,,,,,,,,,, -166800,4.801632,1.1644651,,,,,,,,,,,,,, -166900,4.531892,1.1114923,,,,,,,,,,,,,, -167000,4.5674725,1.1498014,,,,,,,,,,,,,, -167100,4.7170873,1.0244081,,,,,,,,,,,,,, -167200,4.354073,1.0420682,,,,,,,,,,,,,, -167300,4.601362,1.0637227,,,,,,,,,,,,,, -167400,4.738226,1.0516496,,,,,,,,,,,,,, -167500,4.6004257,1.1359284,,,,,,,,,,,,,, -167600,4.4567475,1.0318341,,,,,,,,,,,,,, -167694,,,0.8501076102256775,0.546515941619873,0.7238799929618835,1.1284018754959106,50000.0,0.5982000231742859,1.8255773782730105,10000.0,56648.53292417526,58662.07354712486,56648.53292417526,2000.268955469132,7.146226406097412,0.0 -167700,5.2303243,1.1958277,,,,,,,,,,,,,, -167800,4.677933,1.1023197,,,,,,,,,,,,,, -167900,5.094013,1.149649,,,,,,,,,,,,,, -168000,4.379588,1.0204502,,,,,,,,,,,,,, -168100,5.1372733,1.2737446,,,,,,,,,,,,,, -168200,4.958009,1.1659205,,,,,,,,,,,,,, -168300,5.0544577,1.1520004,,,,,,,,,,,,,, -168400,5.0048943,1.1348906,,,,,,,,,,,,,, -168500,4.74641,1.0716865,,,,,,,,,,,,,, -168600,4.9714837,1.0605717,,,,,,,,,,,,,, -168700,4.8936567,1.1798478,,,,,,,,,,,,,, -168800,5.043288,1.1320493,,,,,,,,,,,,,, -168900,4.4140024,1.072344,,,,,,,,,,,,,, -169000,4.898145,1.0911224,,,,,,,,,,,,,, -169100,5.463826,1.1440108,,,,,,,,,,,,,, -169200,4.770723,1.049867,,,,,,,,,,,,,, -169207,,,0.8522600531578064,0.5386980772018433,0.7333599925041199,1.0911046266555786,50000.0,0.6070000529289246,1.7918938398361206,10000.0,57158.69640493393,59190.97307920456,57158.69640493393,2018.9037404060364,7.192639112472534,0.0 -169300,5.238467,1.0840907,,,,,,,,,,,,,, -169400,4.909138,1.097197,,,,,,,,,,,,,, -169500,4.7190437,1.1215856,,,,,,,,,,,,,, -169600,4.401987,0.9882909,,,,,,,,,,,,,, -169700,4.9812965,1.1271185,,,,,,,,,,,,,, -169800,4.962094,1.1266553,,,,,,,,,,,,,, -169900,4.688963,1.0320122,,,,,,,,,,,,,, -170000,4.4447784,1.0536736,,,,,,,,,,,,,, -170100,4.991676,1.036954,,,,,,,,,,,,,, -170200,5.062522,0.99507105,,,,,,,,,,,,,, -170300,4.714839,1.033172,,,,,,,,,,,,,, -170400,4.5621424,0.98038745,,,,,,,,,,,,,, -170500,4.86911,1.1097335,,,,,,,,,,,,,, -170600,4.8504257,0.9938638,,,,,,,,,,,,,, -170700,5.069633,1.0461049,,,,,,,,,,,,,, -170720,,,0.8514229655265808,0.5230114459991455,0.7345799803733826,1.079545021057129,50000.0,0.6105000376701355,1.766804337501526,10000.0,57668.74922180176,59719.0146727562,57668.74922180176,2036.7781956195831,7.252174377441406,0.0 -170800,5.07111,0.96797734,,,,,,,,,,,,,, -170900,4.9588733,1.0350609,,,,,,,,,,,,,, -171000,5.066299,1.0375946,,,,,,,,,,,,,, -171100,4.829416,1.0806446,,,,,,,,,,,,,, -171200,4.753137,1.043238,,,,,,,,,,,,,, -171300,5.2530384,1.113045,,,,,,,,,,,,,, -171400,5.1877,1.092167,,,,,,,,,,,,,, -171500,5.48822,1.07133,,,,,,,,,,,,,, -171600,4.6124597,0.9800498,,,,,,,,,,,,,, -171700,5.114557,1.0476829,,,,,,,,,,,,,, -171800,4.7948966,0.89827013,,,,,,,,,,,,,, -171900,4.9563856,1.0606188,,,,,,,,,,,,,, -172000,5.176217,1.2089114,,,,,,,,,,,,,, -172100,5.25189,1.0318086,,,,,,,,,,,,,, -172200,5.273314,1.0839223,,,,,,,,,,,,,, -172233,,,0.8605110049247742,0.5039197206497192,0.7382400035858154,1.062277913093567,50000.0,0.6189000010490417,1.7463383674621582,10000.0,58178.75817775726,60246.97335195541,58178.75817775726,2054.613529920578,7.311715364456177,0.0 -172300,4.9456234,1.0298071,,,,,,,,,,,,,, -172400,4.8934097,1.0097977,,,,,,,,,,,,,, -172500,4.8564677,0.9744826,,,,,,,,,,,,,, -172600,4.753774,1.0284526,,,,,,,,,,,,,, -172700,5.11179,0.9931404,,,,,,,,,,,,,, -172800,5.3746557,1.097893,,,,,,,,,,,,,, -172900,5.365886,1.0941498,,,,,,,,,,,,,, -173000,4.980993,1.0253284,,,,,,,,,,,,,, -173100,5.367035,1.0758091,,,,,,,,,,,,,, -173200,4.8698683,0.988585,,,,,,,,,,,,,, -173300,5.371211,1.0694234,,,,,,,,,,,,,, -173400,4.833382,0.9854538,,,,,,,,,,,,,, -173500,4.9243174,1.066307,,,,,,,,,,,,,, -173600,5.543603,1.0483309,,,,,,,,,,,,,, -173700,5.2756395,1.0844356,,,,,,,,,,,,,, -173746,,,0.8649752736091614,0.4853964745998382,0.7422399520874023,1.0512233972549438,50000.0,0.6200000047683716,1.7367162704467771,10000.0,58688.94418215752,60775.03083705902,58688.94418215752,2072.3738420009613,7.36877703666687,0.0 -173800,4.949644,0.9548417,,,,,,,,,,,,,, -173900,5.155918,1.0793606,,,,,,,,,,,,,, -174000,5.3106513,0.9794936,,,,,,,,,,,,,, -174100,4.9834456,1.0218886,,,,,,,,,,,,,, -174200,5.347258,1.0596285,,,,,,,,,,,,,, -174300,4.9291596,1.0460045,,,,,,,,,,,,,, -174400,4.690568,0.91509366,,,,,,,,,,,,,, -174500,4.6899123,0.9120961,,,,,,,,,,,,,, -174600,5.128848,0.94513565,,,,,,,,,,,,,, -174700,4.8070464,0.8860636,,,,,,,,,,,,,, -174800,4.862586,0.92136884,,,,,,,,,,,,,, -174900,4.78847,0.9980161,,,,,,,,,,,,,, -175000,5.3011775,1.062842,,,,,,,,,,,,,, -175100,5.453509,1.0549846,,,,,,,,,,,,,, -175200,5.2199354,0.9622445,,,,,,,,,,,,,, -175260,,,0.8682437539100647,0.4675627946853637,0.7414199709892273,1.049114465713501,50000.0,0.6203000545501709,1.7439327239990234,10000.0,59199.13371539116,61303.094465732574,59199.13371539116,2090.135448217392,7.425171852111816,0.0 -175300,4.919724,0.90982395,,,,,,,,,,,,,, -175400,5.3282657,0.99191964,,,,,,,,,,,,,, -175500,5.181478,0.9487407,,,,,,,,,,,,,, -175600,4.9193773,0.9996835,,,,,,,,,,,,,, -175700,4.9980493,0.9142719,,,,,,,,,,,,,, -175800,4.974729,0.9735118,,,,,,,,,,,,,, -175900,5.276927,0.9592659,,,,,,,,,,,,,, -176000,5.0460663,0.9978785,,,,,,,,,,,,,, -176100,5.0652146,0.9668655,,,,,,,,,,,,,, -176200,5.0504875,1.0066639,,,,,,,,,,,,,, -176300,5.3442397,0.9424641,,,,,,,,,,,,,, -176400,4.761039,0.8986981,,,,,,,,,,,,,, -176500,4.8797827,0.95459265,,,,,,,,,,,,,, -176600,4.920545,0.95673215,,,,,,,,,,,,,, -176700,5.1601286,0.94862777,,,,,,,,,,,,,, -176772,,,0.8796038031578064,0.4347046613693237,0.7447999715805054,1.036797285079956,50000.0,0.6271000504493713,1.7231155633926392,10000.0,59709.20754027367,61831.22635412216,59709.20754027367,2108.078211784363,7.485778331756592,0.0 -176800,5.381981,1.011652,,,,,,,,,,,,,, -176900,5.007193,0.91519606,,,,,,,,,,,,,, -177000,5.1750865,0.9313742,,,,,,,,,,,,,, -177100,5.550713,0.9713228,,,,,,,,,,,,,, -177200,5.594734,1.0661311,,,,,,,,,,,,,, -177300,4.8367605,0.9147572,,,,,,,,,,,,,, -177400,5.2357,0.9515079,,,,,,,,,,,,,, -177500,5.486882,0.9418496,,,,,,,,,,,,,, -177600,5.0678253,0.9432184,,,,,,,,,,,,,, -177700,5.270625,0.8896926,,,,,,,,,,,,,, -177800,5.3252373,0.88146615,,,,,,,,,,,,,, -177900,5.5445952,1.0028564,,,,,,,,,,,,,, -178000,5.4733005,1.0253966,,,,,,,,,,,,,, -178100,5.533215,0.9117426,,,,,,,,,,,,,, -178200,5.3679347,1.0245857,,,,,,,,,,,,,, -178285,,,0.8785474896430969,0.4288856983184814,0.746399998664856,1.0343586206436155,50000.0,0.6279000043869019,1.7289866209030151,10000.0,60219.31211185455,62359.2754983902,60219.31211185455,2125.9084413051605,7.544202566146851,0.0 -178300,5.370605,1.0450896,,,,,,,,,,,,,, -178400,5.291601,0.94458354,,,,,,,,,,,,,, -178500,5.475281,0.98783153,,,,,,,,,,,,,, -178600,5.713727,1.0472808,,,,,,,,,,,,,, -178700,5.0507994,0.8940735,,,,,,,,,,,,,, -178800,5.4418273,0.97445273,,,,,,,,,,,,,, -178900,5.141085,1.0344822,,,,,,,,,,,,,, -179000,5.686951,1.1005428,,,,,,,,,,,,,, -179100,5.364227,0.9950694,,,,,,,,,,,,,, -179200,5.099856,0.9903912,,,,,,,,,,,,,, -179300,5.909412,0.9569622,,,,,,,,,,,,,, -179400,4.957117,0.9529001,,,,,,,,,,,,,, -179500,5.2770295,0.9809321,,,,,,,,,,,,,, -179600,5.447071,1.0046159,,,,,,,,,,,,,, -179700,5.362776,0.97353303,,,,,,,,,,,,,, -179797,,,0.8817960619926453,0.4216182827949524,0.7474600076675415,1.027718424797058,50000.0,0.6247000098228455,1.7195037603378296,10000.0,60729.40054988861,62887.25140142441,60729.40054988861,2143.679017782212,7.606815099716186,0.0 -179800,5.29142,0.97914314,,,,,,,,,,,,,, -179900,5.477518,0.9416968,,,,,,,,,,,,,, -180000,5.318886,1.0066898,,,,,,,,,,,,,, -180100,5.0610948,0.9298854,,,,,,,,,,,,,, -180200,5.477623,0.9890789,,,,,,,,,,,,,, -180300,5.4049106,0.9561018,,,,,,,,,,,,,, -180400,5.6164293,0.94739306,,,,,,,,,,,,,, -180500,4.91705,0.99860376,,,,,,,,,,,,,, -180600,5.080845,0.96500003,,,,,,,,,,,,,, -180700,5.4532685,1.0160329,,,,,,,,,,,,,, -180800,5.2394614,0.9465799,,,,,,,,,,,,,, -180900,5.381522,1.0081002,,,,,,,,,,,,,, -181000,5.113601,0.83304894,,,,,,,,,,,,,, -181100,5.3939204,0.97095984,,,,,,,,,,,,,, -181200,5.4098444,0.99092627,,,,,,,,,,,,,, -181300,5.3562117,0.948844,,,,,,,,,,,,,, -181309,,,0.8843869566917419,0.4094845950603485,0.7476199865341187,1.02272629737854,50000.0,0.6256000399589539,1.7120555639266968,10000.0,61239.36246538162,63415.05498671532,61239.36246538162,2161.401031255722,7.670835256576538,0.0 -181400,4.7818522,0.92553896,,,,,,,,,,,,,, -181500,5.5169225,0.9824356,,,,,,,,,,,,,, -181600,5.399964,1.0364746,,,,,,,,,,,,,, -181700,5.626503,0.8837422,,,,,,,,,,,,,, -181800,5.1680245,0.8867746,,,,,,,,,,,,,, -181900,5.274159,0.9376025,,,,,,,,,,,,,, -182000,5.2186823,0.9612715,,,,,,,,,,,,,, -182100,5.209294,0.9865149,,,,,,,,,,,,,, -182200,5.254197,0.97305846,,,,,,,,,,,,,, -182300,5.474855,0.94805205,,,,,,,,,,,,,, -182400,5.40803,0.9800776,,,,,,,,,,,,,, -182500,5.1889415,0.8736851,,,,,,,,,,,,,, -182600,5.5903172,1.0133755,,,,,,,,,,,,,, -182700,5.4884973,0.95121795,,,,,,,,,,,,,, -182800,5.7917085,1.0491575,,,,,,,,,,,,,, -182821,,,0.8831712007522583,0.4109700918197632,0.7479400038719177,1.0229642391204834,50000.0,0.6282000541687012,1.714728593826294,10000.0,61749.32241177559,63942.94005918503,61749.32241177559,2179.212869882584,7.729203224182129,0.0 -182900,5.5618277,0.94064385,,,,,,,,,,,,,, -183000,5.462431,0.9582249,,,,,,,,,,,,,, -183100,4.98202,0.8813539,,,,,,,,,,,,,, -183200,5.4904923,0.99370015,,,,,,,,,,,,,, -183300,5.197696,0.9277254,,,,,,,,,,,,,, -183400,4.8859186,0.89627457,,,,,,,,,,,,,, -183500,5.368655,1.0084296,,,,,,,,,,,,,, -183600,5.3714423,1.0107054,,,,,,,,,,,,,, -183700,5.447846,0.9881313,,,,,,,,,,,,,, -183800,5.6277313,0.9529034,,,,,,,,,,,,,, -183900,5.6278863,1.0540372,,,,,,,,,,,,,, -184000,5.2965612,0.9373045,,,,,,,,,,,,,, -184100,4.9623885,0.8131571,,,,,,,,,,,,,, -184200,5.393029,0.9664191,,,,,,,,,,,,,, -184300,5.789408,1.0109682,,,,,,,,,,,,,, -184333,,,0.8838687539100647,0.4128356277942657,0.7488600015640259,1.020200252532959,50000.0,0.6284000277519226,1.7110487222671509,10000.0,62259.33857059479,64470.75457596779,62259.33857059479,2196.8955330848694,7.789600133895874,0.0 -184400,5.273094,0.9452543,,,,,,,,,,,,,, -184500,5.0449114,0.925468,,,,,,,,,,,,,, -184600,4.9894037,0.93707246,,,,,,,,,,,,,, -184700,5.3544655,0.9725454,,,,,,,,,,,,,, -184800,5.2190065,0.9120327,,,,,,,,,,,,,, -184900,5.2118316,0.9581812,,,,,,,,,,,,,, -185000,5.1144757,0.92749894,,,,,,,,,,,,,, -185100,5.336351,0.9505724,,,,,,,,,,,,,, -185200,5.315153,0.8886902,,,,,,,,,,,,,, -185300,5.8041253,1.0089253,,,,,,,,,,,,,, -185400,5.2165136,0.967949,,,,,,,,,,,,,, -185500,5.3965144,0.9164293,,,,,,,,,,,,,, -185600,5.279721,0.8910395,,,,,,,,,,,,,, -185700,5.518393,1.0008813,,,,,,,,,,,,,, -185800,5.2870574,0.96679705,,,,,,,,,,,,,, -185845,,,0.8853037357330322,0.4091644287109375,0.7495599985122681,1.0208008289337158,50000.0,0.6296000480651855,1.7125041484832764,10000.0,62769.32061004639,64998.45347905159,62769.32061004639,2214.4908051490784,7.853912591934204,0.0 -185900,5.625935,0.99908864,,,,,,,,,,,,,, -186000,5.6473327,1.003317,,,,,,,,,,,,,, -186100,4.7849913,0.9046483,,,,,,,,,,,,,, -186200,5.0472302,0.8982937,,,,,,,,,,,,,, -186300,5.2065253,0.9164883,,,,,,,,,,,,,, -186400,4.9711905,0.83571434,,,,,,,,,,,,,, -186500,5.598561,0.9194703,,,,,,,,,,,,,, -186554,,,,,,,,,,,63008.32200908661,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/eval_measurements.csv deleted file mode 100644 index bb1571bef..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.920067310333252,0.0,32.45261359214783,1,0,32.45261359214783,0.0006000000284984,6.9125494956970215,10000,50.37277865409851,0.0010363520123064,6.912950038909912,0.0007599999662488,6.913174629211426,50000 -35.60792398452759,0.0200452804565429,542.546288728714,1497,0,542.546288728714,0.1091000065207481,4.887209415435791,10000,578.2290046215057,0.1638831347227096,4.294137477874756,0.1471199989318847,4.409202575683594,50000 -53.48245787620544,0.0480029582977294,1052.6925213336945,2994,0,1052.6925213336945,0.2400000095367431,3.841279029846192,10000,1106.3321468830109,0.3364756107330322,3.0776820182800293,0.310839980840683,3.213662624359131,50000 -71.33141088485718,0.0783967971801757,1562.6216881275177,4492,0,1562.6216881275177,0.3355000019073486,3.185488224029541,10000,1634.1957335472107,0.5123365521430969,2.0910186767578125,0.4371799826622009,2.503953695297241,50000 -89.09158182144165,0.1058940887451171,2072.839183807373,5992,0,2072.839183807373,0.3891000151634216,2.906574010848999,10000,2162.2559444904327,0.5540696382522583,1.884698987007141,0.5013599991798401,2.165916919708252,50000 -107.05718922615053,0.1353254318237304,2582.7628107070923,7492,0,2582.7628107070923,0.4227000176906585,2.6873884201049805,10000,2690.230018377304,0.5963010191917419,1.6671967506408691,0.5468400120735168,1.941999793052673,50000 -124.83120155334473,0.1644771099090576,3092.9980075359344,8993,0,3092.9980075359344,0.4240000247955322,2.697845935821533,10000,3218.3233294487,0.5882493257522583,1.7096151113510132,0.5422999858856201,1.9613946676254272,50000 -142.59801578521729,0.1943874359130859,3603.048814535141,10494,0,3603.048814535141,0.4456000328063965,2.5981123447418213,10000,3746.226803779602,0.6168686151504517,1.563010334968567,0.5707399845123291,1.820184588432312,50000 -160.46470522880554,0.2256078720092773,4112.968335390091,11995,0,4112.968335390091,0.4604000151157379,2.5150883197784424,10000,4274.098203659058,0.6272122263908386,1.5147466659545898,0.5814799666404724,1.769814372062683,50000 -178.39834880828855,0.2538959980010986,4623.030744314194,13496,0,4623.030744314194,0.4640000164508819,2.4521992206573486,10000,4802.1791207790375,0.6594387888908386,1.379641890525818,0.5920199751853943,1.711396336555481,50000 -196.40883922576904,0.2858829498291015,5133.006846666336,14997,0,5133.006846666336,0.4669000208377838,2.463068723678589,10000,5330.253460884094,0.6680684089660645,1.324404001235962,0.5925399661064148,1.721736192703247,50000 -214.2844240665436,0.3177447319030761,5643.204411029816,16499,0,5643.204411029816,0.4693000316619873,2.453063726425171,10000,5858.414639472961,0.6616111397743225,1.3371883630752563,0.6007199883460999,1.6780246496200562,50000 -232.5098798274994,0.3536627292633056,6153.446529865265,18001,0,6153.446529865265,0.486700028181076,2.3434627056121826,10000,6386.974694013596,0.6668726205825806,1.3255809545516968,0.6147199869155884,1.603179693222046,50000 -250.44188237190247,0.3847367763519287,6663.575809955597,19503,0,6663.575809955597,0.4778000116348266,2.4080591201782227,10000,6915.123049736023,0.6614915132522583,1.3690069913864136,0.6030600070953369,1.656613826751709,50000 -268.6344575881958,0.4189190864562988,7173.793117523193,21006,0,7173.793117523193,0.4838000237941742,2.356985330581665,10000,7443.621410608292,0.6614915132522583,1.3474102020263672,0.6093999743461609,1.626532793045044,50000 -286.4828236103058,0.4501662254333496,7683.747024536133,22508,0,7683.747024536133,0.4811000227928161,2.3702409267425537,10000,7971.509313106537,0.6671914458274841,1.326819896697998,0.6113399863243103,1.6232869625091553,50000 -304.32325291633606,0.4812209606170654,8193.990616083145,24011,0,8193.990616083145,0.4948000311851501,2.310030937194824,10000,8499.679381370544,0.6983218789100647,1.184372901916504,0.6136800050735474,1.602954387664795,50000 -322.419335603714,0.5180819034576416,8703.985595703125,25513,0,8703.985595703125,0.4944000244140625,2.344824314117432,10000,9027.862311124802,0.6822385191917419,1.2491592168807983,0.6146399974822998,1.6218963861465454,50000 -340.03698801994324,0.5498857498168945,9214.018256664276,27016,0,9214.018256664276,0.501800000667572,2.2529296875,10000,9555.59879374504,0.6840720772743225,1.2421938180923462,0.6232399940490723,1.559043526649475,50000 -357.99346256256104,0.5839834213256836,9724.112269163132,28518,0,9724.112269163132,0.4957000315189361,2.3093740940093994,10000,10083.73843884468,0.6774553656578064,1.2737376689910889,0.6197400093078613,1.5842479467391968,50000 -376.1608896255493,0.6163861751556396,10234.21173286438,30021,0,10234.21173286438,0.5035000443458557,2.2645206451416016,10000,10612.092364549637,0.6818000674247742,1.2576472759246826,0.6284199953079224,1.5533267259597778,50000 -394.0955708026886,0.6547572612762451,10744.19499206543,31523,0,10744.19499206543,0.5014000535011292,2.269998788833618,10000,11140.10380768776,0.6839724183082581,1.2390429973602295,0.6254400014877319,1.5558955669403076,50000 -411.72698998451233,0.6967041492462158,11254.327833890917,33025,0,11254.327833890917,0.492900013923645,2.32331657409668,10000,11667.965344667437,0.7113161683082581,1.1245791912078855,0.6225999593734741,1.5649104118347168,50000 -429.7202451229096,0.731212854385376,11764.347550868988,34527,0,11764.347550868988,0.4936000108718872,2.312537670135498,10000,12196.06828379631,0.6932995915412903,1.201788306236267,0.624459981918335,1.5534310340881348,50000 -447.4782257080078,0.7672967910766602,12274.329077005386,36030,0,12274.329077005386,0.5020000338554382,2.281333446502685,10000,12723.898445367811,0.6974449753761292,1.1905690431594849,0.6278799772262573,1.5417182445526123,50000 -465.655154466629,0.8066210746765137,12784.488025665283,37533,0,12784.488025665283,0.5141000151634216,2.2257673740386963,10000,13252.32969903946,0.7012914419174194,1.1706738471984863,0.6375399827957153,1.4896409511566162,50000 -483.5669913291931,0.8619179725646973,13294.657634973526,39037,0,13294.657634973526,0.4931000173091888,2.3207204341888428,10000,13780.520778179169,0.6754822731018066,1.2862061262130735,0.6221799850463867,1.5718142986297607,50000 -501.5477590560913,0.9012980461120604,13804.57847905159,40539,0,13804.57847905159,0.5141000151634216,2.200263738632202,10000,14308.516256809236,0.7030452489852905,1.1629226207733154,0.6412999629974365,1.4571741819381714,50000 -519.2402155399323,0.9375874996185304,14314.496087789536,42041,0,14314.496087789536,0.4976000189781189,2.3147168159484863,10000,14836.219805955889,0.7258250713348389,1.0648127794265747,0.6235199570655823,1.5412627458572388,50000 -537.1130058765411,0.9754059314727784,14824.710835456848,43544,0,14824.710835456848,0.5039000511169434,2.242692947387696,10000,15364.400060892103,0.7102598547935486,1.1271930932998655,0.6373999714851379,1.4905970096588137,50000 -555.0589742660522,1.01684308052063,15334.921788215635,45048,0,15334.921788215635,0.5164999961853027,2.19850754737854,10000,15892.65364933014,0.7119140625,1.117753505706787,0.6432799696922302,1.4690630435943604,50000 -572.9238193035126,1.068709373474121,15844.842143058777,46549,0,15844.842143058777,0.5101000070571899,2.2613413333892822,10000,16420.54642868042,0.6991389989852905,1.1732345819473269,0.6344599723815918,1.5220282077789309,50000 -590.5062322616577,1.105797290802002,16355.0407371521,48053,0,16355.0407371521,0.5202000141143799,2.180363893508911,10000,16948.419927835464,0.7051379084587097,1.1431678533554075,0.6401799917221069,1.4874356985092163,50000 -608.299302816391,1.1421661376953125,16865.1335606575,49556,0,16865.1335606575,0.5054000020027161,2.229286432266236,10000,17476.397280454636,0.6994977593421936,1.168281078338623,0.638759970664978,1.4861197471618652,50000 -626.4494802951813,1.178267478942871,17375.330893039703,51060,0,17375.330893039703,0.5080000162124634,2.239075422286988,10000,18004.835172891617,0.7434629797935486,0.9698396325111388,0.6350199580192566,1.5050801038742063,50000 -644.3647968769073,1.2180454730987549,17885.269745588303,52563,0,17885.269745588303,0.5177000164985657,2.173795223236084,10000,18532.783682346344,0.7270607352256775,1.0618703365325928,0.6416199803352356,1.4836870431900024,50000 -662.2336344718933,1.2604291439056396,18395.33567595482,54067,0,18395.33567595482,0.5090000033378601,2.266648769378662,10000,19060.81605863571,0.70316481590271,1.1438255310058594,0.6356399655342102,1.5182785987854004,50000 -680.2520830631256,1.3046739101409912,18905.46354651451,55570,0,18905.46354651451,0.5169000029563904,2.2187652587890625,10000,19589.064910411835,0.7121531963348389,1.1044635772705078,0.6428200006484985,1.4710193872451782,50000 -697.8852643966675,1.3421645164489746,19415.57575273513,57074,0,19415.57575273513,0.5103000402450562,2.253450393676758,10000,20116.90283894539,0.7025071382522583,1.1650789976119995,0.6407999992370605,1.4790819883346558,50000 -716.4295771121979,1.383310317993164,19925.48170900345,58577,0,19925.48170900345,0.5179000496864319,2.2345850467681885,10000,20645.44976592064,0.7065728306770325,1.1429468393325806,0.6433799862861633,1.4730268716812134,50000 -734.2747831344604,1.4285414218902588,20435.709728956223,60081,0,20435.709728956223,0.5254000425338745,2.1215386390686035,10000,21173.62443089485,0.7592673897743225,0.929104745388031,0.658079981803894,1.402864933013916,50000 -752.2365992069244,1.4622957706451416,20945.643161058422,61584,0,20945.643161058422,0.5242000222206116,2.153595447540283,10000,21701.610381364822,0.73636794090271,1.0095868110656738,0.6518799662590027,1.4322763681411743,50000 -769.8576102256775,1.5023293495178225,21455.75441765785,63088,0,21455.75441765785,0.5323000550270081,2.100260734558105,10000,22229.440237522125,0.7388990521430969,0.9919620156288148,0.6604399681091309,1.3843040466308594,50000 -787.5250680446625,1.551323413848877,21965.73911070824,64591,0,21965.73911070824,0.51910001039505,2.223273992538452,10000,22757.195997476578,0.7138074040412903,1.106141448020935,0.6453199982643127,1.459357976913452,50000 -805.305394411087,1.60355544090271,22475.68429350853,66095,0,22475.68429350853,0.5354000329971313,2.1329658031463623,10000,23285.028936624527,0.7254065275192261,1.0538322925567627,0.6597200036048889,1.3996851444244385,50000 -823.2509181499481,1.649055242538452,22985.69870376587,67598,0,22985.69870376587,0.5164999961853027,2.1947619915008545,10000,23813.08985543251,0.7191087007522583,1.084785223007202,0.6524999737739563,1.4338654279708862,50000 -841.3075633049011,1.700188159942627,23495.98047399521,69101,0,23495.98047399521,0.5143000483512878,2.2291743755340576,10000,24341.53601646424,0.7221181392669678,1.076885223388672,0.6430599689483643,1.46754789352417,50000 -859.0242967605591,1.7394630908966064,24006.12465786934,70605,0,24006.12465786934,0.5434000492095947,2.08642315864563,10000,24869.49318599701,0.7581911683082581,0.9060986638069152,0.6655600070953369,1.3723647594451904,50000 -876.6616532802582,1.7846426963806152,24516.31637406349,72109,0,24516.31637406349,0.5228000283241272,2.215155601501465,10000,25397.42277216912,0.7310466766357422,1.0297045707702637,0.6527999639511108,1.4410232305526731,50000 -894.5995259284973,1.8255927562713623,25026.27029204369,73611,0,25026.27029204369,0.5325000286102295,2.110114812850952,10000,25925.412529945374,0.7375039458274841,0.9993064999580384,0.6623799800872803,1.390560746192932,50000 -912.429455280304,1.87069034576416,25536.4892745018,75115,0,25536.4892745018,0.5391000509262085,2.0707874298095703,10000,26453.562220811844,0.7437619566917419,0.9866275191307068,0.6665399670600891,1.3468725681304932,50000 -929.9996762275696,1.908890724182129,26046.512956619263,76618,0,26046.512956619263,0.5473999977111816,2.035409450531006,10000,26981.249529123303,0.7401546239852905,0.9870557188987732,0.6706399917602539,1.3400790691375732,50000 -947.8705537319184,1.9521074295043943,26556.6778280735,78122,0,26556.6778280735,0.5392000079154968,2.11980938911438,10000,27509.383778572083,0.7384207248687744,1.0084189176559448,0.665340006351471,1.3739938735961914,50000 -965.515125989914,1.991746425628662,27066.751952409744,79626,0,27066.751952409744,0.5445000529289246,2.0731024742126465,10000,28037.19790363312,0.767578125,0.8716039061546326,0.6688399910926819,1.3541743755340576,50000 -983.458841085434,2.036340713500977,27576.86024737358,81130,0,27576.86024737358,0.5324000120162964,2.085909128189087,10000,28565.35233569145,0.7565369606018066,0.9142917394638062,0.6692000031471252,1.3499724864959717,50000 -1001.4011144638062,2.0771901607513428,28086.937801122665,82633,0,28086.937801122665,0.5343000292778015,2.1417131423950195,10000,29093.46788740158,0.7384008169174194,1.0000159740447998,0.6617599725723267,1.3959338665008545,50000 -1019.136206626892,2.137763500213623,28597.01343441009,84137,0,28597.01343441009,0.541100025177002,2.071115255355835,10000,29621.39469194412,0.7523317933082581,0.9355828166007996,0.6721799969673157,1.334437608718872,50000 -1036.7360591888428,2.182596445083618,29106.93808484077,85640,0,29106.93808484077,0.5494000315666199,2.091202735900879,10000,30149.020438194275,0.7469307780265808,0.961391270160675,0.66975998878479,1.3483177423477173,50000 -1054.3037416934967,2.2248997688293457,29617.127873182297,87144,0,29617.127873182297,0.534600019454956,2.0925774574279785,10000,30676.874872922897,0.7379224896430969,0.9985449314117432,0.6692799925804138,1.3523105382919312,50000 -1072.1030399799347,2.271844387054444,30127.066581964493,88647,0,30127.066581964493,0.5424000024795532,2.068551778793335,10000,31204.71547460556,0.7837810516357422,0.8073683381080627,0.6779199838638306,1.3172043561935425,50000 -1090.0969746112823,2.317918062210083,30637.134697914124,90151,0,30637.134697914124,0.5539000034332275,2.0205748081207275,10000,31732.880301237103,0.7757493257522583,0.8434444069862366,0.6840199828147888,1.2816094160079956,50000 -1107.751292705536,2.369476079940796,31147.257095575333,91655,0,31147.257095575333,0.5481000542640686,2.098851203918457,10000,32260.765964984894,0.7606425285339355,0.8926047086715698,0.6730200052261353,1.342376470565796,50000 -1125.5106115341189,2.420135974884033,31657.1976583004,93159,0,31657.1976583004,0.5511000156402588,2.040935516357422,10000,32788.57222819328,0.7567561864852905,0.9086245894432068,0.6795399785041809,1.3091254234313965,50000 -1143.3546307086945,2.4641125202178955,32167.262050628666,94663,0,32167.262050628666,0.5546000003814697,2.026935338973999,10000,33316.578904390335,0.7587292790412903,0.9139228463172911,0.6759399771690369,1.32182776927948,50000 -1161.3488364219666,2.514535427093506,32677.42819571495,96167,0,32677.42819571495,0.5508000254631042,2.0234200954437256,10000,33844.844178915024,0.7642498016357422,0.8854460120201111,0.6845600008964539,1.2905325889587402,50000 -1180.1326916217804,2.562276601791382,33187.63098335266,97669,0,33187.63098335266,0.5527999997138977,2.02603530883789,10000,34373.93581676483,0.8039301633834839,0.7246366143226624,0.6893399953842163,1.2651199102401731,50000 -1198.0567321777344,2.6048731803894043,33697.79170894623,99173,0,33697.79170894623,0.5654000043869019,1.98188054561615,10000,34902.117525577545,0.7836814522743225,0.7890745997428894,0.6879400014877319,1.263701558113098,50000 -1215.964411497116,2.652205467224121,34207.87293791771,100676,0,34207.87293791771,0.5601000189781189,1.9989219903945925,10000,35430.207654953,0.7801538705825806,0.8223650455474854,0.689579963684082,1.26771342754364,50000 -1233.6210358142853,2.698038578033448,34717.97048306465,102180,0,34717.97048306465,0.5678000450134277,1.9510797262191768,10000,35958.062203884125,0.78226637840271,0.813494086265564,0.6924200057983398,1.249699354171753,50000 -1251.2748274803162,2.7438771724700928,35228.02629613876,103683,0,35228.02629613876,0.5628000497817993,1.992747783660889,10000,36485.87252473831,0.7759087681770325,0.8256220817565918,0.6915599703788757,1.2632023096084597,50000 -1269.1238374710083,2.793518781661988,35737.980461120605,105186,0,35737.980461120605,0.5644000172615051,1.9662601947784424,10000,37013.7807905674,0.7785993218421936,0.8219675421714783,0.6947000026702881,1.2501887083053589,50000 -1287.228770017624,2.8367748260498047,36248.21604681015,106690,0,36248.21604681015,0.5685000419616699,2.02057147026062,10000,37542.218354702,0.8191565275192261,0.6560265421867371,0.6922399997711182,1.2624542713165283,50000 -1304.8902144432068,2.8855910301208496,36758.368527174,108194,0,36758.368527174,0.556600034236908,2.0142979621887207,10000,38070.13637447357,0.7940050959587097,0.7586848735809326,0.6922000050544739,1.26578688621521,50000 -1322.726529121399,2.933062791824341,37268.60179138184,109698,0,37268.60179138184,0.5653000473976135,1.950202465057373,10000,38598.30977582932,0.7897400856018066,0.7790143489837646,0.6975199580192566,1.232077956199646,50000 -1340.6236248016355,2.9807047843933105,37778.79350161552,111201,0,37778.79350161552,0.5591000318527222,2.025722026824951,10000,39126.50262880325,0.7818080186843872,0.8038931488990784,0.6892600059509277,1.272587776184082,50000 -1358.8687388896942,3.033480167388916,38288.8953294754,112705,0,38288.8953294754,0.5699000358581543,1.9524030685424805,10000,39654.957559108734,0.7928690910339355,0.7587718367576599,0.699400007724762,1.2161651849746704,50000 -1377.1002910137177,3.080017566680908,38798.96592998505,114209,0,38798.96592998505,0.570900022983551,1.9734641313552856,10000,40183.36325955391,0.7909358739852905,0.765129566192627,0.6998800039291382,1.2264665365219116,50000 -1395.0114710330963,3.126986265182495,39309.17578649521,115713,0,39309.17578649521,0.5711000561714172,1.942007064819336,10000,40711.58852005005,0.8435705900192261,0.576335072517395,0.7015399932861328,1.2031558752059937,50000 -1412.585030078888,3.1796674728393555,39819.32137942314,117217,0,39819.32137942314,0.5755000114440918,1.9088903665542605,10000,41239.41584944725,0.8229233026504517,0.6437040567398071,0.7073799967765808,1.194665551185608,50000 -1430.6261870861051,3.2323155403137207,40329.29307627678,118720,0,40329.29307627678,0.5786000490188599,1.9398770332336424,10000,41767.53636336327,0.8165457248687744,0.6733990907669067,0.7057200074195862,1.2064604759216309,50000 -1449.2406451702118,3.280336380004883,40839.21898150444,120223,0,40839.21898150444,0.5800000429153442,1.913909673690796,10000,42296.18041563034,0.8181002736091614,0.6516917943954468,0.7089999914169312,1.1775261163711548,50000 -1467.3116641044617,3.323556900024414,41349.17365002632,121726,0,41349.17365002632,0.5759000182151794,1.9429177045822144,10000,42824.30308508873,0.8082947731018066,0.6896772384643555,0.7066999673843384,1.1927813291549685,50000 -1485.2464039325714,3.3781511783599854,41859.14821410179,123230,0,41859.14821410179,0.5870000123977661,1.930961012840271,10000,43352.32322573662,0.8172034025192261,0.6480165719985962,0.7120800018310547,1.1730077266693115,50000 -1502.886422872543,3.428584575653076,42369.25692343712,124734,0,42369.25692343712,0.5822000503540039,1.8956857919692995,10000,43880.17747545242,0.8444275856018066,0.5557413101196289,0.7135199904441833,1.162333846092224,50000 -1520.7035658359528,3.4789483547210693,42879.32626509666,126238,0,42879.32626509666,0.5902000069618225,1.887725830078125,10000,44408.17130422592,0.8463608026504517,0.5420656800270081,0.7162599563598633,1.1581156253814695,50000 -1538.350778579712,3.5324387550354004,43389.325770139694,127741,0,43389.325770139694,0.5898000001907349,1.8950108289718628,10000,44935.92632341385,0.8425143361091614,0.5580957531929016,0.7177199721336365,1.1547913551330566,50000 -1556.3912541866302,3.587545394897461,43899.508662223816,129246,0,43899.508662223816,0.5910000205039978,1.917563915252685,10000,45464.26081061363,0.8390465378761292,0.5831946134567261,0.7163999676704407,1.167493462562561,50000 -1574.4115755558014,3.64362382888794,44409.489119291306,130750,0,44409.489119291306,0.5909000039100647,1.8844317197799685,10000,45992.3717007637,0.8318319320678711,0.5993456840515137,0.715939998626709,1.1581324338912964,50000 -1592.0317244529724,3.695392608642578,44919.46421599388,132254,0,44919.46421599388,0.5952000021934509,1.8953807353973389,10000,46520.07335758209,0.8413584232330322,0.5591051578521729,0.7198799848556519,1.1507633924484253,50000 -1609.954525232315,3.7486047744750977,45429.54328536987,133758,0,45429.54328536987,0.5903000235557556,1.915621519088745,10000,47048.18325304985,0.8465999364852905,0.5491646528244019,0.7141599655151367,1.1705526113510132,50000 -1628.5150740146637,3.798022508621216,45939.4492623806,135262,0,45939.4492623806,0.5943000316619873,1.874566674232483,10000,47576.754868507385,0.8697385191917419,0.4575299024581909,0.7210599780082703,1.1391361951828003,50000 -1646.445203781128,3.8481035232543945,46449.43899035454,136765,0,46449.43899035454,0.6003000140190125,1.883261442184448,10000,48104.78061199188,0.8618462681770325,0.4852306246757507,0.7225599884986877,1.1408337354660034,50000 -1664.1922266483307,3.9070560932159424,46959.6368894577,138269,0,46959.6368894577,0.6046000123023987,1.865692138671875,10000,48632.8400952816,0.867586076259613,0.4603193998336792,0.731939971446991,1.108080506324768,50000 -1681.8307964801788,3.9598209857940674,47469.786170721054,139773,0,47469.786170721054,0.5985000133514404,1.8925247192382808,10000,49160.73694300652,0.8621252775192261,0.4767018556594848,0.7252799868583679,1.134071946144104,50000 -1699.5504500865936,4.012192010879517,47979.92598223686,141277,0,47979.92598223686,0.6062000393867493,1.8601926565170288,10000,49688.70345711708,0.8630420565605164,0.4747589826583862,0.726419985294342,1.13043475151062,50000 -1717.231991291046,4.073648452758789,48490.05924510956,142781,0,48490.05924510956,0.5963000059127808,1.9035028219223025,10000,50216.637149333954,0.8639389276504517,0.4681302905082702,0.7273600101470947,1.1400127410888672,50000 -1735.258573770523,4.126613616943359,49000.12387108803,144285,0,49000.12387108803,0.6078000068664551,1.842755913734436,10000,50744.83660268784,0.8989357352256775,0.352335661649704,0.7356399893760681,1.1080385446548462,50000 -1752.9578087329865,4.184151411056519,49510.03820419312,145788,0,49510.03820419312,0.6132000088691711,1.842835783958435,10000,51272.56175875664,0.8931760191917419,0.36942258477211,0.7378199696540833,1.0927354097366333,50000 -1770.8793761730194,4.238028526306152,50020.07937192917,147290,0,50020.07937192917,0.6048000454902649,1.8468471765518188,10000,51800.63540673256,0.8905253410339355,0.3740461766719818,0.7361399531364441,1.08772873878479,50000 -1788.7686505317688,4.294397830963135,50529.98758006096,148793,0,50529.98758006096,0.6073000431060791,1.875683546066284,10000,52328.54508733749,0.8932557106018066,0.3641088008880615,0.7346799969673157,1.1118369102478027,50000 -1806.554343700409,4.349008321762085,51040.08410692215,150296,0,51040.08410692215,0.6160000562667847,1.858383297920227,10000,52856.537358284,0.8932358026504517,0.363602340221405,0.7393400073051453,1.103121042251587,50000 -1824.7137916088104,4.405257701873779,51550.05728435516,151799,0,51550.05728435516,0.6099000573158264,1.8680452108383176,10000,53384.782229185104,0.8980388641357422,0.3513565361499786,0.7346799969673157,1.1100929975509644,50000 -1842.537333250045,4.462253093719482,52060.17003774643,153303,0,52060.17003774643,0.6100000143051147,1.8766512870788568,10000,53912.83220410347,0.9187459945678712,0.2857046723365783,0.7361999750137329,1.1064831018447876,50000 -1860.575585603714,4.516483306884766,52570.1029856205,154806,0,52570.1029856205,0.6166000366210938,1.84649658203125,10000,54440.91394495964,0.9205396771430968,0.2786844074726105,0.7440399527549744,1.0820631980895996,50000 -1878.458645582199,4.574311494827271,53080.2790760994,156310,0,53080.2790760994,0.6159000396728516,1.860220432281494,10000,54969.0854177475,0.9159757494926452,0.2899870276451111,0.7417399883270264,1.092580795288086,50000 -1896.2873928546903,4.627979040145874,53590.2423658371,157813,0,53590.2423658371,0.617400050163269,1.8521299362182613,10000,55496.98831796646,0.9194834232330322,0.2723006904125213,0.7447599768638611,1.0868667364120483,50000 -1914.3551132678983,4.682864904403687,54100.30746936798,159316,0,54100.30746936798,0.6223000288009644,1.8350087404251096,10000,56025.2314991951,0.9231106042861938,0.2618570923805237,0.746239960193634,1.0792582035064695,50000 -1932.119315862656,4.740338563919067,54610.30392670632,160819,0,54610.30392670632,0.6165000200271606,1.841195821762085,10000,56553.10640120506,0.9272759556770324,0.2506012916564941,0.7461400032043457,1.0742896795272827,50000 -1949.6665840148928,4.796144008636475,55120.26196694374,162322,0,55120.26196694374,0.6172000169754028,1.8549610376358032,10000,57080.722338199615,0.9462890625,0.1946813315153122,0.7486400008201599,1.0698437690734863,50000 -1967.51427936554,4.854479789733887,55630.45647478104,163825,0,55630.45647478104,0.6202000379562378,1.8455793857574463,10000,57608.87885069847,0.9412069320678712,0.2064539641141891,0.7478599548339844,1.0675770044326782,50000 -1985.47976231575,4.911173105239868,56140.58478808403,165329,0,56140.58478808403,0.6196000576019287,1.8581254482269287,10000,58137.0842730999,0.9415258169174194,0.2043623626232147,0.7491599917411804,1.068067193031311,50000 -2003.517992734909,4.971322536468506,56650.707070589066,166832,0,56650.707070589066,0.6226000189781189,1.836353421211243,10000,58665.36136960983,0.9414859414100648,0.2066220045089721,0.7490400075912476,1.0683197975158691,50000 -2021.207232952118,5.02895712852478,57160.77901077271,168335,0,57160.77901077271,0.6261000037193298,1.8424190282821653,10000,59193.23601198197,0.9429408311843872,0.1995099037885666,0.7509599924087524,1.0619680881500244,50000 -2038.746482372284,5.086165189743042,57670.78219342232,169838,0,57670.78219342232,0.6240000128746033,1.8413702249526973,10000,59720.89051890373,0.949238657951355,0.1808584332466125,0.7513399720191956,1.060373306274414,50000 -2056.577041864395,5.145231485366821,58180.96397686005,171342,0,58180.96397686005,0.6264000535011292,1.832197308540344,10000,60249.019006729126,0.9575095176696776,0.1561500877141952,0.7531999945640564,1.0549283027648926,50000 -2074.348899126053,5.2050487995147705,58691.0938334465,172845,0,58691.0938334465,0.6274000406265259,1.8417407274246216,10000,60777.03437685967,0.9551578164100648,0.1624404788017273,0.7537599802017212,1.0578937530517578,50000 -2092.863926887512,5.264968156814575,59201.03386545181,174348,0,59201.03386545181,0.6251000165939331,1.828043818473816,10000,61305.60614085197,0.9562938213348388,0.1600802987813949,0.7548199892044067,1.0536705255508425,50000 -2110.787750482559,5.325360298156738,59710.93507862091,175850,0,59710.93507862091,0.628000020980835,1.828550815582276,10000,61833.54664039612,0.9552574753761292,0.1618704348802566,0.753879964351654,1.0522363185882568,50000 -2128.663361310959,5.3837199211120605,60220.999058008194,177353,0,60220.999058008194,0.6294000148773193,1.820769429206848,10000,62361.60022234917,0.9591039419174194,0.1550398766994476,0.7555800080299377,1.0516177415847778,50000 -2146.2478954792023,5.467687129974365,60731.09321784973,178856,0,60731.09321784973,0.6295000314712524,1.821165680885315,10000,62889.41806507111,0.9585060477256776,0.1536491960287094,0.7564199566841125,1.0479017496109009,50000 -2164.042452096939,5.532188653945923,61241.01958680153,180359,0,61241.01958680153,0.6314000487327576,1.8269197940826416,10000,63417.25886678696,0.9599609375,0.1484406441450119,0.7556799650192261,1.0471597909927368,50000 -2181.832026720047,5.597205638885498,61751.20690441132,181862,0,61751.20690441132,0.6303000450134277,1.8241914510726929,10000,63945.35855007172,0.961355984210968,0.1469281464815139,0.7557799816131592,1.046109676361084,50000 -2199.4081242084503,5.6746907234191895,62261.4028301239,183366,0,62261.4028301239,0.6312000155448914,1.8238227367401123,10000,64473.26201725006,0.9613958597183228,0.1471957266330719,0.7557199597358704,1.0465097427368164,50000 -2217.1566207408905,5.737559795379639,62771.28944039345,184868,0,62771.28944039345,0.6312000155448914,1.8226069211959839,10000,65001.01459479332,0.959980845451355,0.14590847492218018,0.7561999559402466,1.0453925132751465,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/measurements.csv deleted file mode 100644 index c101362b7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1982 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6558059,6.934149,,,,,,,,,,,,,, -1,,,0.0010363520123064,6.912950038909912,0.0007599999662488,6.913174629211426,50000.0,0.0006000000284984,6.9125494956970215,10000.0,32.45261359214783,50.37277865409851,32.45261359214783,17.920067310333252,0.0,0.0 -100,0.6718987,6.8274693,,,,,,,,,,,,,, -200,0.7918666,6.5631614,,,,,,,,,,,,,, -300,1.0237862,6.3212547,,,,,,,,,,,,,, -400,1.4975822,5.9742107,,,,,,,,,,,,,, -500,3.178749,5.8881435,,,,,,,,,,,,,, -600,5.8108263,5.606251,,,,,,,,,,,,,, -700,3.4054508,5.5038304,,,,,,,,,,,,,, -800,2.9803205,5.338414,,,,,,,,,,,,,, -900,7.145216,5.1278114,,,,,,,,,,,,,, -1000,6.9615974,5.02708,,,,,,,,,,,,,, -1100,5.8819985,4.8256645,,,,,,,,,,,,,, -1200,3.1157818,4.8339214,,,,,,,,,,,,,, -1300,4.747957,4.7436166,,,,,,,,,,,,,, -1400,3.2620387,4.506455,,,,,,,,,,,,,, -1497,,,0.1638831347227096,4.294137477874756,0.1471199989318847,4.409202575683594,50000.0,0.1091000065207481,4.887209415435791,10000.0,542.546288728714,578.2290046215057,542.546288728714,35.60792398452759,0.0200452804565429,0.0 -1500,3.4361029,4.5533996,,,,,,,,,,,,,, -1600,5.347383,4.3413477,,,,,,,,,,,,,, -1700,5.4807034,4.2591453,,,,,,,,,,,,,, -1800,5.552583,4.179899,,,,,,,,,,,,,, -1900,4.2913575,4.1644273,,,,,,,,,,,,,, -2000,6.182657,4.1737633,,,,,,,,,,,,,, -2100,4.2754383,3.910357,,,,,,,,,,,,,, -2200,3.1794512,3.9872143,,,,,,,,,,,,,, -2300,5.250101,3.8454547,,,,,,,,,,,,,, -2400,4.5101724,3.9257154,,,,,,,,,,,,,, -2500,4.833979,3.6831253,,,,,,,,,,,,,, -2600,3.7704244,3.5967464,,,,,,,,,,,,,, -2700,3.404623,3.5598073,,,,,,,,,,,,,, -2800,4.8649507,3.5372503,,,,,,,,,,,,,, -2900,5.5798774,3.5046062,,,,,,,,,,,,,, -2994,,,0.3364756107330322,3.0776820182800293,0.310839980840683,3.213662624359131,50000.0,0.2400000095367431,3.841279029846192,10000.0,1052.6925213336945,1106.3321468830109,1052.6925213336945,53.48245787620544,0.0480029582977294,0.0 -3000,3.305367,3.2925968,,,,,,,,,,,,,, -3100,4.286718,3.4606073,,,,,,,,,,,,,, -3200,5.6266065,3.3816514,,,,,,,,,,,,,, -3300,3.552451,3.2515733,,,,,,,,,,,,,, -3400,3.2499845,3.3218586,,,,,,,,,,,,,, -3500,3.8776982,3.2771459,,,,,,,,,,,,,, -3600,3.1724317,3.1846714,,,,,,,,,,,,,, -3700,2.9869487,3.0538952,,,,,,,,,,,,,, -3800,5.8235664,3.0825114,,,,,,,,,,,,,, -3900,2.5431843,2.9975348,,,,,,,,,,,,,, -4000,4.2119718,3.050357,,,,,,,,,,,,,, -4100,3.3156505,2.9922848,,,,,,,,,,,,,, -4200,3.0967546,2.9594514,,,,,,,,,,,,,, -4300,4.0979476,2.8633492,,,,,,,,,,,,,, -4400,2.3321176,2.8514335,,,,,,,,,,,,,, -4492,,,0.5123365521430969,2.0910186767578125,0.4371799826622009,2.503953695297241,50000.0,0.3355000019073486,3.185488224029541,10000.0,1562.6216881275177,1634.1957335472107,1562.6216881275177,71.33141088485718,0.0783967971801757,0.0 -4500,2.7188635,2.7674904,,,,,,,,,,,,,, -4600,3.5820951,2.7682872,,,,,,,,,,,,,, -4700,3.2473805,2.771235,,,,,,,,,,,,,, -4800,5.1548843,2.7190142,,,,,,,,,,,,,, -4900,3.1580667,2.7331212,,,,,,,,,,,,,, -5000,2.2043095,2.6126626,,,,,,,,,,,,,, -5100,1.718035,2.609489,,,,,,,,,,,,,, -5200,2.5728455,2.6368496,,,,,,,,,,,,,, -5300,3.7560132,2.7507856,,,,,,,,,,,,,, -5400,2.4848971,2.5449896,,,,,,,,,,,,,, -5500,3.1074994,2.5914373,,,,,,,,,,,,,, -5600,2.3791192,2.6082473,,,,,,,,,,,,,, -5700,1.976633,2.4551477,,,,,,,,,,,,,, -5800,1.9006494,2.5554276,,,,,,,,,,,,,, -5900,1.9953665,2.6091237,,,,,,,,,,,,,, -5992,,,0.5540696382522583,1.884698987007141,0.5013599991798401,2.165916919708252,50000.0,0.3891000151634216,2.906574010848999,10000.0,2072.839183807373,2162.2559444904327,2072.839183807373,89.09158182144165,0.1058940887451171,0.0 -6000,1.9952801,2.5206714,,,,,,,,,,,,,, -6100,2.1704047,2.5956302,,,,,,,,,,,,,, -6200,1.8525758,2.3748806,,,,,,,,,,,,,, -6300,2.5108984,2.429748,,,,,,,,,,,,,, -6400,2.5043797,2.4395146,,,,,,,,,,,,,, -6500,1.8044349,2.2753084,,,,,,,,,,,,,, -6600,2.6476932,2.5323582,,,,,,,,,,,,,, -6700,2.2247875,2.2805407,,,,,,,,,,,,,, -6800,1.6141992,2.3951201,,,,,,,,,,,,,, -6900,1.9341769,2.2716115,,,,,,,,,,,,,, -7000,1.3427562,2.199628,,,,,,,,,,,,,, -7100,2.7809923,2.3127735,,,,,,,,,,,,,, -7200,2.8915899,2.3366032,,,,,,,,,,,,,, -7300,2.3653674,2.2158263,,,,,,,,,,,,,, -7400,1.9738334,2.2099285,,,,,,,,,,,,,, -7492,,,0.5963010191917419,1.6671967506408691,0.5468400120735168,1.941999793052673,50000.0,0.4227000176906585,2.6873884201049805,10000.0,2582.7628107070923,2690.230018377304,2582.7628107070923,107.05718922615053,0.1353254318237304,0.0 -7500,1.7759808,2.2052293,,,,,,,,,,,,,, -7600,1.5266311,2.2386775,,,,,,,,,,,,,, -7700,1.9885845,2.399657,,,,,,,,,,,,,, -7800,2.1925073,2.2636714,,,,,,,,,,,,,, -7900,1.424557,2.3507338,,,,,,,,,,,,,, -8000,1.5591629,2.1510825,,,,,,,,,,,,,, -8100,1.7569371,2.2141652,,,,,,,,,,,,,, -8200,1.5360157,2.1754684,,,,,,,,,,,,,, -8300,2.5284834,2.235492,,,,,,,,,,,,,, -8400,1.3259512,2.126424,,,,,,,,,,,,,, -8500,1.4896046,2.1424606,,,,,,,,,,,,,, -8600,1.9434869,2.2517602,,,,,,,,,,,,,, -8700,1.7482758,2.174961,,,,,,,,,,,,,, -8800,1.9206271,2.1684222,,,,,,,,,,,,,, -8900,1.7018827,2.1234937,,,,,,,,,,,,,, -8993,,,0.5882493257522583,1.7096151113510132,0.5422999858856201,1.9613946676254272,50000.0,0.4240000247955322,2.697845935821533,10000.0,3092.9980075359344,3218.3233294487,3092.9980075359344,124.83120155334473,0.1644771099090576,0.0 -9000,1.3903716,2.0925317,,,,,,,,,,,,,, -9100,1.7917624,2.0808687,,,,,,,,,,,,,, -9200,1.6861194,2.1642203,,,,,,,,,,,,,, -9300,1.6177715,2.093241,,,,,,,,,,,,,, -9400,1.7732509,2.0944633,,,,,,,,,,,,,, -9500,2.1239457,2.1879742,,,,,,,,,,,,,, -9600,2.3108401,2.0233655,,,,,,,,,,,,,, -9700,1.6010185,2.189098,,,,,,,,,,,,,, -9800,1.354861,2.1195774,,,,,,,,,,,,,, -9900,2.1218452,2.0953467,,,,,,,,,,,,,, -10000,1.8139231,2.0094526,,,,,,,,,,,,,, -10100,1.9862077,2.2990596,,,,,,,,,,,,,, -10200,1.8280882,2.1899981,,,,,,,,,,,,,, -10300,1.8001136,2.0042012,,,,,,,,,,,,,, -10400,1.5689006,2.0460322,,,,,,,,,,,,,, -10494,,,0.6168686151504517,1.563010334968567,0.5707399845123291,1.820184588432312,50000.0,0.4456000328063965,2.5981123447418213,10000.0,3603.048814535141,3746.226803779602,3603.048814535141,142.59801578521729,0.1943874359130859,0.0 -10500,1.8015913,2.0183923,,,,,,,,,,,,,, -10600,1.6550474,1.9055208,,,,,,,,,,,,,, -10700,2.532491,2.116743,,,,,,,,,,,,,, -10800,1.6296837,2.1222706,,,,,,,,,,,,,, -10900,1.2824866,2.119035,,,,,,,,,,,,,, -11000,1.846482,2.1067007,,,,,,,,,,,,,, -11100,1.8342911,1.9723431,,,,,,,,,,,,,, -11200,1.4824071,1.9713066,,,,,,,,,,,,,, -11300,1.7470887,2.018032,,,,,,,,,,,,,, -11400,1.7216444,1.9415349,,,,,,,,,,,,,, -11500,1.6026591,1.9942987,,,,,,,,,,,,,, -11600,1.7384981,1.9529785,,,,,,,,,,,,,, -11700,1.8674581,2.002211,,,,,,,,,,,,,, -11800,1.7411429,1.9645399,,,,,,,,,,,,,, -11900,1.6444271,2.03912,,,,,,,,,,,,,, -11995,,,0.6272122263908386,1.5147466659545898,0.5814799666404724,1.769814372062683,50000.0,0.4604000151157379,2.5150883197784424,10000.0,4112.968335390091,4274.098203659058,4112.968335390091,160.46470522880554,0.2256078720092773,0.0 -12000,2.1091497,2.0791569,,,,,,,,,,,,,, -12100,2.0141976,2.046133,,,,,,,,,,,,,, -12200,1.9359121,2.0929039,,,,,,,,,,,,,, -12300,1.5719355,2.0019407,,,,,,,,,,,,,, -12400,1.2964824,1.8605525,,,,,,,,,,,,,, -12500,1.8296108,1.9342936,,,,,,,,,,,,,, -12600,1.3353212,1.9217556,,,,,,,,,,,,,, -12700,1.8714536,1.8650709,,,,,,,,,,,,,, -12800,2.1012897,1.9959509,,,,,,,,,,,,,, -12900,1.7930608,2.009353,,,,,,,,,,,,,, -13000,1.6810968,1.9098372,,,,,,,,,,,,,, -13100,1.3438814,1.9320664,,,,,,,,,,,,,, -13200,1.4996192,1.9218882,,,,,,,,,,,,,, -13300,1.5622846,1.9540737,,,,,,,,,,,,,, -13400,1.5148635,1.9523671,,,,,,,,,,,,,, -13496,,,0.6594387888908386,1.379641890525818,0.5920199751853943,1.711396336555481,50000.0,0.4640000164508819,2.4521992206573486,10000.0,4623.030744314194,4802.1791207790375,4623.030744314194,178.39834880828855,0.2538959980010986,0.0 -13500,1.9401946,1.9143996,,,,,,,,,,,,,, -13600,2.2063315,1.9195917,,,,,,,,,,,,,, -13700,1.700989,1.9417561,,,,,,,,,,,,,, -13800,1.385096,1.8628042,,,,,,,,,,,,,, -13900,1.5736297,1.8741318,,,,,,,,,,,,,, -14000,1.5753489,1.9490285,,,,,,,,,,,,,, -14100,1.9436406,1.8841133,,,,,,,,,,,,,, -14200,2.3227081,1.9948406,,,,,,,,,,,,,, -14300,1.2729859,1.9030231,,,,,,,,,,,,,, -14400,1.5191364,2.0038164,,,,,,,,,,,,,, -14500,1.6143105,1.9903271,,,,,,,,,,,,,, -14600,1.4860089,1.8804309,,,,,,,,,,,,,, -14700,1.430097,1.8759468,,,,,,,,,,,,,, -14800,1.5673044,1.8272469,,,,,,,,,,,,,, -14900,1.4484162,1.8386059,,,,,,,,,,,,,, -14997,,,0.6680684089660645,1.324404001235962,0.5925399661064148,1.721736192703247,50000.0,0.4669000208377838,2.463068723678589,10000.0,5133.006846666336,5330.253460884094,5133.006846666336,196.40883922576904,0.2858829498291015,0.0 -15000,1.7505766,1.9662535,,,,,,,,,,,,,, -15100,2.5196545,1.8456615,,,,,,,,,,,,,, -15200,1.7325988,1.8408732,,,,,,,,,,,,,, -15300,1.4650403,1.9543467,,,,,,,,,,,,,, -15400,1.5696129,1.7764022,,,,,,,,,,,,,, -15500,1.7103058,1.9285797,,,,,,,,,,,,,, -15600,1.7422845,1.8716837,,,,,,,,,,,,,, -15700,1.656561,1.9305787,,,,,,,,,,,,,, -15800,1.7582426,1.9121419,,,,,,,,,,,,,, -15900,1.5264806,1.8778596,,,,,,,,,,,,,, -16000,1.5467184,1.9134076,,,,,,,,,,,,,, -16100,1.4293064,1.9150608,,,,,,,,,,,,,, -16200,1.5606437,1.8855239,,,,,,,,,,,,,, -16300,1.9481398,1.8908638,,,,,,,,,,,,,, -16400,1.4930935,1.9336172,,,,,,,,,,,,,, -16499,,,0.6616111397743225,1.3371883630752563,0.6007199883460999,1.6780246496200562,50000.0,0.4693000316619873,2.453063726425171,10000.0,5643.204411029816,5858.414639472961,5643.204411029816,214.2844240665436,0.3177447319030761,0.0 -16500,1.4872373,1.8948371,,,,,,,,,,,,,, -16600,1.6186775,1.8539093,,,,,,,,,,,,,, -16700,1.6299521,1.7917662,,,,,,,,,,,,,, -16800,1.5846744,1.8682371,,,,,,,,,,,,,, -16900,1.606244,1.8262866,,,,,,,,,,,,,, -17000,1.9221076,1.9090102,,,,,,,,,,,,,, -17100,1.7255155,1.9284865,,,,,,,,,,,,,, -17200,1.6015729,1.8126023,,,,,,,,,,,,,, -17300,1.7064786,1.7890812,,,,,,,,,,,,,, -17400,2.427092,2.0111787,,,,,,,,,,,,,, -17500,1.7198012,1.7771345,,,,,,,,,,,,,, -17600,1.5518156,1.8132082,,,,,,,,,,,,,, -17700,2.3437805,1.7826654,,,,,,,,,,,,,, -17800,1.659259,1.8160061,,,,,,,,,,,,,, -17900,1.5069739,1.7407908,,,,,,,,,,,,,, -18000,1.5944427,1.8289208,,,,,,,,,,,,,, -18001,,,0.6668726205825806,1.3255809545516968,0.6147199869155884,1.603179693222046,50000.0,0.486700028181076,2.3434627056121826,10000.0,6153.446529865265,6386.974694013596,6153.446529865265,232.5098798274994,0.3536627292633056,0.0 -18100,1.5023533,1.8746431,,,,,,,,,,,,,, -18200,1.6068106,1.8537713,,,,,,,,,,,,,, -18300,2.1969466,1.8714517,,,,,,,,,,,,,, -18400,1.9801214,1.7693375,,,,,,,,,,,,,, -18500,1.4862199,1.8556169,,,,,,,,,,,,,, -18600,1.5573003,1.7671547,,,,,,,,,,,,,, -18700,1.7688291,1.9663732,,,,,,,,,,,,,, -18800,1.7864296,1.8193395,,,,,,,,,,,,,, -18900,1.5029086,1.840872,,,,,,,,,,,,,, -19000,1.8988345,1.8034307,,,,,,,,,,,,,, -19100,1.6227353,1.816648,,,,,,,,,,,,,, -19200,1.717673,1.8885759,,,,,,,,,,,,,, -19300,1.5991058,1.8549316,,,,,,,,,,,,,, -19400,1.5699055,1.8497272,,,,,,,,,,,,,, -19500,1.7065781,1.9463865,,,,,,,,,,,,,, -19503,,,0.6614915132522583,1.3690069913864136,0.6030600070953369,1.656613826751709,50000.0,0.4778000116348266,2.4080591201782227,10000.0,6663.575809955597,6915.123049736023,6663.575809955597,250.44188237190247,0.3847367763519287,0.0 -19600,1.7217457,1.9324908,,,,,,,,,,,,,, -19700,1.800533,1.7404758,,,,,,,,,,,,,, -19800,1.7243538,1.7596961,,,,,,,,,,,,,, -19900,1.6035739,1.8052133,,,,,,,,,,,,,, -20000,1.6479064,1.7117736,,,,,,,,,,,,,, -20100,1.927476,1.8350145,,,,,,,,,,,,,, -20200,1.7101965,1.9014235,,,,,,,,,,,,,, -20300,1.5534883,1.8138771,,,,,,,,,,,,,, -20400,1.7646471,1.7359825,,,,,,,,,,,,,, -20500,1.6244994,1.8346864,,,,,,,,,,,,,, -20600,1.5385361,1.7286068,,,,,,,,,,,,,, -20700,1.8459269,1.7243706,,,,,,,,,,,,,, -20800,1.5134753,1.748004,,,,,,,,,,,,,, -20900,1.7915515,1.7141539,,,,,,,,,,,,,, -21000,1.8623264,1.8696024,,,,,,,,,,,,,, -21006,,,0.6614915132522583,1.3474102020263672,0.6093999743461609,1.626532793045044,50000.0,0.4838000237941742,2.356985330581665,10000.0,7173.793117523193,7443.621410608292,7173.793117523193,268.6344575881958,0.4189190864562988,0.0 -21100,1.5865897,1.7691934,,,,,,,,,,,,,, -21200,2.077864,1.8731244,,,,,,,,,,,,,, -21300,1.7467391,1.8430754,,,,,,,,,,,,,, -21400,1.7024914,1.8406131,,,,,,,,,,,,,, -21500,1.6742125,1.785747,,,,,,,,,,,,,, -21600,1.7431713,1.826397,,,,,,,,,,,,,, -21700,1.590146,1.7421108,,,,,,,,,,,,,, -21800,1.8390352,1.7439251,,,,,,,,,,,,,, -21900,1.6600004,1.7901292,,,,,,,,,,,,,, -22000,1.7437395,1.7683692,,,,,,,,,,,,,, -22100,1.6483679,1.8230262,,,,,,,,,,,,,, -22200,1.8113792,1.8006487,,,,,,,,,,,,,, -22300,1.6710986,1.7612792,,,,,,,,,,,,,, -22400,1.614013,1.8022068,,,,,,,,,,,,,, -22500,2.0283203,1.6456584,,,,,,,,,,,,,, -22508,,,0.6671914458274841,1.326819896697998,0.6113399863243103,1.6232869625091553,50000.0,0.4811000227928161,2.3702409267425537,10000.0,7683.747024536133,7971.509313106537,7683.747024536133,286.4828236103058,0.4501662254333496,0.0 -22600,1.8185616,1.7810678,,,,,,,,,,,,,, -22700,1.8115994,1.8636963,,,,,,,,,,,,,, -22800,1.6394376,1.8164878,,,,,,,,,,,,,, -22900,1.5165211,1.6931795,,,,,,,,,,,,,, -23000,1.6455884,1.8406086,,,,,,,,,,,,,, -23100,1.7933791,1.7577785,,,,,,,,,,,,,, -23200,1.7460183,1.7784555,,,,,,,,,,,,,, -23300,1.8659159,1.6830441,,,,,,,,,,,,,, -23400,1.6082969,1.7701615,,,,,,,,,,,,,, -23500,1.6326144,1.7706548,,,,,,,,,,,,,, -23600,1.8764611,1.825115,,,,,,,,,,,,,, -23700,1.7563585,1.8580943,,,,,,,,,,,,,, -23800,1.765341,1.8122395,,,,,,,,,,,,,, -23900,1.5786612,1.7566961,,,,,,,,,,,,,, -24000,1.7214158,1.8001603,,,,,,,,,,,,,, -24011,,,0.6983218789100647,1.184372901916504,0.6136800050735474,1.602954387664795,50000.0,0.4948000311851501,2.310030937194824,10000.0,8193.990616083145,8499.679381370544,8193.990616083145,304.32325291633606,0.4812209606170654,0.0 -24100,2.1490824,1.8047746,,,,,,,,,,,,,, -24200,1.7518932,1.7417504,,,,,,,,,,,,,, -24300,1.9411361,1.8029699,,,,,,,,,,,,,, -24400,1.6312726,1.805182,,,,,,,,,,,,,, -24500,1.7395016,1.8327016,,,,,,,,,,,,,, -24600,1.7354517,1.7402809,,,,,,,,,,,,,, -24700,1.749697,1.7221545,,,,,,,,,,,,,, -24800,1.8002377,1.7405521,,,,,,,,,,,,,, -24900,2.0161226,1.8224053,,,,,,,,,,,,,, -25000,2.2731795,1.8517616,,,,,,,,,,,,,, -25100,1.773554,1.8294716,,,,,,,,,,,,,, -25200,1.8670524,1.7972081,,,,,,,,,,,,,, -25300,1.8639048,1.882245,,,,,,,,,,,,,, -25400,1.6429656,1.7997388,,,,,,,,,,,,,, -25500,2.016669,1.8705302,,,,,,,,,,,,,, -25513,,,0.6822385191917419,1.2491592168807983,0.6146399974822998,1.6218963861465454,50000.0,0.4944000244140625,2.344824314117432,10000.0,8703.985595703125,9027.862311124802,8703.985595703125,322.419335603714,0.5180819034576416,0.0 -25600,1.5345557,1.742786,,,,,,,,,,,,,, -25700,1.6939133,1.8021599,,,,,,,,,,,,,, -25800,1.8545976,1.7078387,,,,,,,,,,,,,, -25900,1.7018945,1.6947541,,,,,,,,,,,,,, -26000,1.8747807,1.8838328,,,,,,,,,,,,,, -26100,1.803004,1.6677818,,,,,,,,,,,,,, -26200,1.8254346,1.7163888,,,,,,,,,,,,,, -26300,1.7695844,1.8106374,,,,,,,,,,,,,, -26400,1.8394692,1.7664268,,,,,,,,,,,,,, -26500,1.8447924,1.8540033,,,,,,,,,,,,,, -26600,1.781399,1.7685503,,,,,,,,,,,,,, -26700,1.6532545,1.808515,,,,,,,,,,,,,, -26800,1.6605778,1.7946228,,,,,,,,,,,,,, -26900,1.6339704,1.5621353,,,,,,,,,,,,,, -27000,1.6399474,1.802996,,,,,,,,,,,,,, -27016,,,0.6840720772743225,1.2421938180923462,0.6232399940490723,1.559043526649475,50000.0,0.501800000667572,2.2529296875,10000.0,9214.018256664276,9555.59879374504,9214.018256664276,340.03698801994324,0.5498857498168945,0.0 -27100,1.8636862,1.7028841,,,,,,,,,,,,,, -27200,1.9217397,1.8376606,,,,,,,,,,,,,, -27300,1.7899656,1.8097854,,,,,,,,,,,,,, -27400,1.8043083,1.7224802,,,,,,,,,,,,,, -27500,1.7497634,1.6995841,,,,,,,,,,,,,, -27600,1.9387181,1.7680752,,,,,,,,,,,,,, -27700,1.9516441,1.68912,,,,,,,,,,,,,, -27800,1.8595477,1.8110557,,,,,,,,,,,,,, -27900,1.7096944,1.7765019,,,,,,,,,,,,,, -28000,1.764697,1.7659276,,,,,,,,,,,,,, -28100,1.731825,1.7251387,,,,,,,,,,,,,, -28200,1.6408726,1.7874603,,,,,,,,,,,,,, -28300,1.654625,1.7007427,,,,,,,,,,,,,, -28400,1.681903,1.6422868,,,,,,,,,,,,,, -28500,1.7052251,1.6176671,,,,,,,,,,,,,, -28518,,,0.6774553656578064,1.2737376689910889,0.6197400093078613,1.5842479467391968,50000.0,0.4957000315189361,2.3093740940093994,10000.0,9724.112269163132,10083.73843884468,9724.112269163132,357.99346256256104,0.5839834213256836,0.0 -28600,1.7642602,1.7928363,,,,,,,,,,,,,, -28700,1.6567596,1.7579395,,,,,,,,,,,,,, -28800,1.8137866,1.7142041,,,,,,,,,,,,,, -28900,1.623168,1.7719926,,,,,,,,,,,,,, -29000,1.9208989,1.6799737,,,,,,,,,,,,,, -29100,1.6709706,1.7218323,,,,,,,,,,,,,, -29200,1.8394841,1.7376992,,,,,,,,,,,,,, -29300,1.7288766,1.7580957,,,,,,,,,,,,,, -29400,1.8409584,1.7765938,,,,,,,,,,,,,, -29500,1.5986277,1.7755948,,,,,,,,,,,,,, -29600,2.0637407,1.7211998,,,,,,,,,,,,,, -29700,1.8508849,1.658998,,,,,,,,,,,,,, -29800,2.0461073,1.669039,,,,,,,,,,,,,, -29900,2.0083985,1.7593461,,,,,,,,,,,,,, -30000,2.0904474,1.8040794,,,,,,,,,,,,,, -30021,,,0.6818000674247742,1.2576472759246826,0.6284199953079224,1.5533267259597778,50000.0,0.5035000443458557,2.2645206451416016,10000.0,10234.21173286438,10612.092364549637,10234.21173286438,376.1608896255493,0.6163861751556396,0.0 -30100,1.653449,1.6795433,,,,,,,,,,,,,, -30200,1.5496274,1.7507044,,,,,,,,,,,,,, -30300,1.9006226,1.7641505,,,,,,,,,,,,,, -30400,1.8185815,1.6420416,,,,,,,,,,,,,, -30500,1.7256887,1.715061,,,,,,,,,,,,,, -30600,1.6494374,1.7633082,,,,,,,,,,,,,, -30700,1.7124326,1.7268608,,,,,,,,,,,,,, -30800,1.7964978,1.7189384,,,,,,,,,,,,,, -30900,1.9308072,1.8294318,,,,,,,,,,,,,, -31000,1.6828587,1.7121084,,,,,,,,,,,,,, -31100,1.9140431,1.7028058,,,,,,,,,,,,,, -31200,1.7861965,1.6666417,,,,,,,,,,,,,, -31300,1.8032044,1.7054293,,,,,,,,,,,,,, -31400,1.697422,1.7457186,,,,,,,,,,,,,, -31500,1.8341947,1.7654732,,,,,,,,,,,,,, -31523,,,0.6839724183082581,1.2390429973602295,0.6254400014877319,1.5558955669403076,50000.0,0.5014000535011292,2.269998788833618,10000.0,10744.19499206543,11140.10380768776,10744.19499206543,394.0955708026886,0.6547572612762451,0.0 -31600,1.4419068,1.5951456,,,,,,,,,,,,,, -31700,1.5609157,1.6516758,,,,,,,,,,,,,, -31800,1.708526,1.7024943,,,,,,,,,,,,,, -31900,1.8020784,1.7070277,,,,,,,,,,,,,, -32000,1.8144418,1.6744354,,,,,,,,,,,,,, -32100,1.7792784,1.7838974,,,,,,,,,,,,,, -32200,1.6549331,1.7335207,,,,,,,,,,,,,, -32300,1.6636584,1.60992,,,,,,,,,,,,,, -32400,2.0563822,1.7829206,,,,,,,,,,,,,, -32500,1.7998189,1.6411713,,,,,,,,,,,,,, -32600,1.8208652,1.7261164,,,,,,,,,,,,,, -32700,1.8021435,1.7646619,,,,,,,,,,,,,, -32800,1.7274863,1.735745,,,,,,,,,,,,,, -32900,1.8287895,1.6096158,,,,,,,,,,,,,, -33000,1.8030158,1.6402141,,,,,,,,,,,,,, -33025,,,0.7113161683082581,1.1245791912078855,0.6225999593734741,1.5649104118347168,50000.0,0.492900013923645,2.32331657409668,10000.0,11254.327833890917,11667.965344667437,11254.327833890917,411.72698998451233,0.6967041492462158,0.0 -33100,1.7346019,1.7294133,,,,,,,,,,,,,, -33200,1.5940748,1.7231545,,,,,,,,,,,,,, -33300,1.752503,1.6770204,,,,,,,,,,,,,, -33400,1.8737832,1.773451,,,,,,,,,,,,,, -33500,1.961233,1.7268562,,,,,,,,,,,,,, -33600,1.8254625,1.9048033,,,,,,,,,,,,,, -33700,1.6030475,1.5323286,,,,,,,,,,,,,, -33800,1.8428707,1.7590616,,,,,,,,,,,,,, -33900,1.5349618,1.5742819,,,,,,,,,,,,,, -34000,1.7279791,1.729623,,,,,,,,,,,,,, -34100,2.0569704,1.7199183,,,,,,,,,,,,,, -34200,1.6788783,1.5597659,,,,,,,,,,,,,, -34300,1.6052815,1.6941632,,,,,,,,,,,,,, -34400,1.7254562,1.5937881,,,,,,,,,,,,,, -34500,1.87731,1.594881,,,,,,,,,,,,,, -34527,,,0.6932995915412903,1.201788306236267,0.624459981918335,1.5534310340881348,50000.0,0.4936000108718872,2.312537670135498,10000.0,11764.347550868988,12196.06828379631,11764.347550868988,429.7202451229096,0.731212854385376,0.0 -34600,1.6642979,1.6249285,,,,,,,,,,,,,, -34700,1.7043428,1.6806922,,,,,,,,,,,,,, -34800,1.9254044,1.8208374,,,,,,,,,,,,,, -34900,1.7773072,1.7285527,,,,,,,,,,,,,, -35000,1.7774459,1.6478914,,,,,,,,,,,,,, -35100,2.2504878,1.6904664,,,,,,,,,,,,,, -35200,1.7528915,1.6620333,,,,,,,,,,,,,, -35300,2.140822,1.6943026,,,,,,,,,,,,,, -35400,2.0814064,1.6765422,,,,,,,,,,,,,, -35500,1.6908436,1.6836973,,,,,,,,,,,,,, -35600,1.7441818,1.7817036,,,,,,,,,,,,,, -35700,1.7402731,1.6660641,,,,,,,,,,,,,, -35800,1.7284548,1.6154912,,,,,,,,,,,,,, -35900,1.9265412,1.6499732,,,,,,,,,,,,,, -36000,1.7410223,1.6472361,,,,,,,,,,,,,, -36030,,,0.6974449753761292,1.1905690431594849,0.6278799772262573,1.5417182445526123,50000.0,0.5020000338554382,2.281333446502685,10000.0,12274.329077005386,12723.898445367811,12274.329077005386,447.4782257080078,0.7672967910766602,0.0 -36100,1.7671194,1.8027482,,,,,,,,,,,,,, -36200,1.8757145,1.6721697,,,,,,,,,,,,,, -36300,1.6735208,1.727373,,,,,,,,,,,,,, -36400,1.8738211,1.812441,,,,,,,,,,,,,, -36500,1.7506353,1.6732178,,,,,,,,,,,,,, -36600,2.356602,1.7562115,,,,,,,,,,,,,, -36700,1.5937303,1.6365634,,,,,,,,,,,,,, -36800,1.8080175,1.6465166,,,,,,,,,,,,,, -36900,1.6140319,1.6515394,,,,,,,,,,,,,, -37000,1.8459111,1.7474376,,,,,,,,,,,,,, -37100,1.7321378,1.7585908,,,,,,,,,,,,,, -37200,1.8468332,1.6824077,,,,,,,,,,,,,, -37300,2.1202934,1.6761312,,,,,,,,,,,,,, -37400,1.7716497,1.6373878,,,,,,,,,,,,,, -37500,1.913829,1.7897192,,,,,,,,,,,,,, -37533,,,0.7012914419174194,1.1706738471984863,0.6375399827957153,1.4896409511566162,50000.0,0.5141000151634216,2.2257673740386963,10000.0,12784.488025665283,13252.32969903946,12784.488025665283,465.655154466629,0.8066210746765137,0.0 -37600,1.8177763,1.6420438,,,,,,,,,,,,,, -37700,2.079263,1.7172616,,,,,,,,,,,,,, -37800,1.5834205,1.6763679,,,,,,,,,,,,,, -37900,1.9359565,1.8214017,,,,,,,,,,,,,, -38000,1.9637603,1.7247832,,,,,,,,,,,,,, -38100,1.8170806,1.7891366,,,,,,,,,,,,,, -38200,1.8490248,1.6407387,,,,,,,,,,,,,, -38300,1.9493164,1.7840316,,,,,,,,,,,,,, -38400,1.8773041,1.6815107,,,,,,,,,,,,,, -38500,1.9735368,1.7148874,,,,,,,,,,,,,, -38600,1.8318212,1.6569587,,,,,,,,,,,,,, -38700,1.6968246,1.6186268,,,,,,,,,,,,,, -38800,1.7069179,1.6457635,,,,,,,,,,,,,, -38900,1.5670307,1.6659037,,,,,,,,,,,,,, -39000,1.9469692,1.5959766,,,,,,,,,,,,,, -39037,,,0.6754822731018066,1.2862061262130735,0.6221799850463867,1.5718142986297607,50000.0,0.4931000173091888,2.3207204341888428,10000.0,13294.657634973526,13780.520778179169,13294.657634973526,483.5669913291931,0.8619179725646973,0.0 -39100,1.8752849,1.6552949,,,,,,,,,,,,,, -39200,1.928156,1.6287365,,,,,,,,,,,,,, -39300,1.8084941,1.6737869,,,,,,,,,,,,,, -39400,1.9740528,1.7349396,,,,,,,,,,,,,, -39500,1.9180697,1.6321571,,,,,,,,,,,,,, -39600,1.7554225,1.5615475,,,,,,,,,,,,,, -39700,1.9772527,1.7183775,,,,,,,,,,,,,, -39800,1.989046,1.697478,,,,,,,,,,,,,, -39900,1.8092957,1.634403,,,,,,,,,,,,,, -40000,1.7539911,1.6939856,,,,,,,,,,,,,, -40100,1.5966376,1.6104072,,,,,,,,,,,,,, -40200,1.7229629,1.6352614,,,,,,,,,,,,,, -40300,1.8573799,1.6951048,,,,,,,,,,,,,, -40400,1.980447,1.698159,,,,,,,,,,,,,, -40500,2.2938483,1.6840134,,,,,,,,,,,,,, -40539,,,0.7030452489852905,1.1629226207733154,0.6412999629974365,1.4571741819381714,50000.0,0.5141000151634216,2.200263738632202,10000.0,13804.57847905159,14308.516256809236,13804.57847905159,501.5477590560913,0.9012980461120604,0.0 -40600,1.5993166,1.5778446,,,,,,,,,,,,,, -40700,1.7729056,1.63924,,,,,,,,,,,,,, -40800,1.6954886,1.749423,,,,,,,,,,,,,, -40900,1.7230418,1.6896551,,,,,,,,,,,,,, -41000,1.8199661,1.5989752,,,,,,,,,,,,,, -41100,1.727265,1.7342024,,,,,,,,,,,,,, -41200,1.7465222,1.5807288,,,,,,,,,,,,,, -41300,1.8944031,1.6717423,,,,,,,,,,,,,, -41400,1.7531831,1.6200224,,,,,,,,,,,,,, -41500,1.7681899,1.6565183,,,,,,,,,,,,,, -41600,1.8935508,1.6011415,,,,,,,,,,,,,, -41700,1.7728709,1.5958092,,,,,,,,,,,,,, -41800,2.067941,1.6319118,,,,,,,,,,,,,, -41900,1.7384647,1.7117233,,,,,,,,,,,,,, -42000,1.7543927,1.5621377,,,,,,,,,,,,,, -42041,,,0.7258250713348389,1.0648127794265747,0.6235199570655823,1.5412627458572388,50000.0,0.4976000189781189,2.3147168159484863,10000.0,14314.496087789536,14836.219805955889,14314.496087789536,519.2402155399323,0.9375874996185304,0.0 -42100,1.9260345,1.6796439,,,,,,,,,,,,,, -42200,1.8875003,1.7031351,,,,,,,,,,,,,, -42300,1.7415568,1.645935,,,,,,,,,,,,,, -42400,1.9999366,1.6496253,,,,,,,,,,,,,, -42500,1.7752492,1.6993513,,,,,,,,,,,,,, -42600,1.8488519,1.6054924,,,,,,,,,,,,,, -42700,1.9597658,1.6508179,,,,,,,,,,,,,, -42800,1.9865929,1.5687858,,,,,,,,,,,,,, -42900,1.7951964,1.5798935,,,,,,,,,,,,,, -43000,1.7785417,1.5982673,,,,,,,,,,,,,, -43100,1.9240625,1.6223631,,,,,,,,,,,,,, -43200,1.607543,1.5554844,,,,,,,,,,,,,, -43300,1.6603637,1.6813872,,,,,,,,,,,,,, -43400,1.7825812,1.7107732,,,,,,,,,,,,,, -43500,1.8126822,1.714396,,,,,,,,,,,,,, -43544,,,0.7102598547935486,1.1271930932998655,0.6373999714851379,1.4905970096588137,50000.0,0.5039000511169434,2.242692947387696,10000.0,14824.710835456848,15364.400060892103,14824.710835456848,537.1130058765411,0.9754059314727784,0.0 -43600,2.0297384,1.6422461,,,,,,,,,,,,,, -43700,1.7585046,1.586827,,,,,,,,,,,,,, -43800,1.7830796,1.5853906,,,,,,,,,,,,,, -43900,1.7878073,1.7030855,,,,,,,,,,,,,, -44000,2.1993768,1.6688977,,,,,,,,,,,,,, -44100,1.9383565,1.6638094,,,,,,,,,,,,,, -44200,1.7859374,1.5664757,,,,,,,,,,,,,, -44300,1.8387741,1.5980797,,,,,,,,,,,,,, -44400,2.041244,1.6099055,,,,,,,,,,,,,, -44500,1.9741926,1.5827531,,,,,,,,,,,,,, -44600,1.7552583,1.5679342,,,,,,,,,,,,,, -44700,1.7201732,1.5845687,,,,,,,,,,,,,, -44800,1.7011226,1.5680135,,,,,,,,,,,,,, -44900,1.9142843,1.6442348,,,,,,,,,,,,,, -45000,1.8578867,1.6128408,,,,,,,,,,,,,, -45048,,,0.7119140625,1.117753505706787,0.6432799696922302,1.4690630435943604,50000.0,0.5164999961853027,2.19850754737854,10000.0,15334.921788215635,15892.65364933014,15334.921788215635,555.0589742660522,1.01684308052063,0.0 -45100,2.2530572,1.7825774,,,,,,,,,,,,,, -45200,1.759243,1.6749867,,,,,,,,,,,,,, -45300,2.0438914,1.6599945,,,,,,,,,,,,,, -45400,1.915775,1.6328564,,,,,,,,,,,,,, -45500,1.9571639,1.6148734,,,,,,,,,,,,,, -45600,1.7465032,1.6326458,,,,,,,,,,,,,, -45700,1.8827131,1.5821698,,,,,,,,,,,,,, -45800,2.0158029,1.7022856,,,,,,,,,,,,,, -45900,1.8868803,1.7469469,,,,,,,,,,,,,, -46000,1.6812896,1.6217916,,,,,,,,,,,,,, -46100,1.8792279,1.5754915,,,,,,,,,,,,,, -46200,1.8042272,1.6387092,,,,,,,,,,,,,, -46300,1.8259858,1.6025628,,,,,,,,,,,,,, -46400,1.6962272,1.5594938,,,,,,,,,,,,,, -46500,2.0058193,1.6221553,,,,,,,,,,,,,, -46549,,,0.6991389989852905,1.1732345819473269,0.6344599723815918,1.5220282077789309,50000.0,0.5101000070571899,2.2613413333892822,10000.0,15844.842143058777,16420.54642868042,15844.842143058777,572.9238193035126,1.068709373474121,0.0 -46600,1.7598583,1.6107539,,,,,,,,,,,,,, -46700,1.8701214,1.6240966,,,,,,,,,,,,,, -46800,1.8737928,1.6015811,,,,,,,,,,,,,, -46900,2.0637975,1.6546311,,,,,,,,,,,,,, -47000,2.0152893,1.6600167,,,,,,,,,,,,,, -47100,1.7877867,1.6029652,,,,,,,,,,,,,, -47200,1.8278339,1.6257758,,,,,,,,,,,,,, -47300,1.8022125,1.547922,,,,,,,,,,,,,, -47400,1.7878717,1.5213606,,,,,,,,,,,,,, -47500,2.0804033,1.6594107,,,,,,,,,,,,,, -47600,1.9368265,1.5913986,,,,,,,,,,,,,, -47700,1.818208,1.5834454,,,,,,,,,,,,,, -47800,1.9288887,1.5893428,,,,,,,,,,,,,, -47900,1.9434159,1.6509141,,,,,,,,,,,,,, -48000,2.1616511,1.6674973,,,,,,,,,,,,,, -48053,,,0.7051379084587097,1.1431678533554075,0.6401799917221069,1.4874356985092163,50000.0,0.5202000141143799,2.180363893508911,10000.0,16355.0407371521,16948.419927835464,16355.0407371521,590.5062322616577,1.105797290802002,0.0 -48100,1.8393549,1.635782,,,,,,,,,,,,,, -48200,1.7603086,1.5267911,,,,,,,,,,,,,, -48300,2.12626,1.7354097,,,,,,,,,,,,,, -48400,1.88186,1.5997857,,,,,,,,,,,,,, -48500,1.9143988,1.6050112,,,,,,,,,,,,,, -48600,1.8048509,1.6218704,,,,,,,,,,,,,, -48700,1.8927544,1.611135,,,,,,,,,,,,,, -48800,1.8118917,1.6474376,,,,,,,,,,,,,, -48900,2.0528307,1.6292548,,,,,,,,,,,,,, -49000,1.7848288,1.561706,,,,,,,,,,,,,, -49100,1.9088404,1.5972215,,,,,,,,,,,,,, -49200,1.8299625,1.5616833,,,,,,,,,,,,,, -49300,1.643105,1.6265101,,,,,,,,,,,,,, -49400,1.7272271,1.5144985,,,,,,,,,,,,,, -49500,1.7496176,1.4721234,,,,,,,,,,,,,, -49556,,,0.6994977593421936,1.168281078338623,0.638759970664978,1.4861197471618652,50000.0,0.5054000020027161,2.229286432266236,10000.0,16865.1335606575,17476.397280454636,16865.1335606575,608.299302816391,1.1421661376953125,0.0 -49600,1.7966621,1.5957721,,,,,,,,,,,,,, -49700,1.8921133,1.571388,,,,,,,,,,,,,, -49800,1.8959355,1.5926856,,,,,,,,,,,,,, -49900,1.8934634,1.6537915,,,,,,,,,,,,,, -50000,1.7499844,1.6554956,,,,,,,,,,,,,, -50100,1.8444611,1.6419008,,,,,,,,,,,,,, -50200,2.1244233,1.6118641,,,,,,,,,,,,,, -50300,2.0442398,1.5114126,,,,,,,,,,,,,, -50400,1.7391673,1.5245281,,,,,,,,,,,,,, -50500,1.9113584,1.6324303,,,,,,,,,,,,,, -50600,1.8362976,1.6246003,,,,,,,,,,,,,, -50700,1.8578956,1.5616019,,,,,,,,,,,,,, -50800,2.1469016,1.7657043,,,,,,,,,,,,,, -50900,1.7706697,1.576591,,,,,,,,,,,,,, -51000,1.8419702,1.6388792,,,,,,,,,,,,,, -51060,,,0.7434629797935486,0.9698396325111388,0.6350199580192566,1.5050801038742063,50000.0,0.5080000162124634,2.239075422286988,10000.0,17375.330893039703,18004.835172891617,17375.330893039703,626.4494802951813,1.178267478942871,0.0 -51100,1.8359612,1.7093295,,,,,,,,,,,,,, -51200,2.0542877,1.6136026,,,,,,,,,,,,,, -51300,2.0745804,1.6870034,,,,,,,,,,,,,, -51400,2.2562785,1.5357858,,,,,,,,,,,,,, -51500,1.958192,1.5666838,,,,,,,,,,,,,, -51600,1.9573408,1.5668006,,,,,,,,,,,,,, -51700,1.665048,1.6329169,,,,,,,,,,,,,, -51800,1.7443388,1.6147994,,,,,,,,,,,,,, -51900,2.1087162,1.6832381,,,,,,,,,,,,,, -52000,1.8061985,1.5577605,,,,,,,,,,,,,, -52100,1.8612347,1.6331872,,,,,,,,,,,,,, -52200,1.9479536,1.6300286,,,,,,,,,,,,,, -52300,1.6943372,1.5902362,,,,,,,,,,,,,, -52400,1.8486542,1.5777065,,,,,,,,,,,,,, -52500,1.8158755,1.5580243,,,,,,,,,,,,,, -52563,,,0.7270607352256775,1.0618703365325928,0.6416199803352356,1.4836870431900024,50000.0,0.5177000164985657,2.173795223236084,10000.0,17885.269745588303,18532.783682346344,17885.269745588303,644.3647968769073,1.2180454730987549,0.0 -52600,1.9018915,1.6023004,,,,,,,,,,,,,, -52700,1.8639951,1.5878835,,,,,,,,,,,,,, -52800,1.8698477,1.5749245,,,,,,,,,,,,,, -52900,1.8959665,1.694817,,,,,,,,,,,,,, -53000,1.8424273,1.6656275,,,,,,,,,,,,,, -53100,1.817902,1.5469093,,,,,,,,,,,,,, -53200,1.7709975,1.6500442,,,,,,,,,,,,,, -53300,2.0677388,1.5949214,,,,,,,,,,,,,, -53400,1.9355202,1.6342938,,,,,,,,,,,,,, -53500,2.025416,1.5737138,,,,,,,,,,,,,, -53600,1.8588539,1.6436572,,,,,,,,,,,,,, -53700,2.1418972,1.6933953,,,,,,,,,,,,,, -53800,2.1281204,1.5767289,,,,,,,,,,,,,, -53900,1.8935499,1.6105549,,,,,,,,,,,,,, -54000,2.034752,1.5927181,,,,,,,,,,,,,, -54067,,,0.70316481590271,1.1438255310058594,0.6356399655342102,1.5182785987854004,50000.0,0.5090000033378601,2.266648769378662,10000.0,18395.33567595482,19060.81605863571,18395.33567595482,662.2336344718933,1.2604291439056396,0.0 -54100,2.0147655,1.6404601,,,,,,,,,,,,,, -54200,1.8326135,1.5284717,,,,,,,,,,,,,, -54300,2.145308,1.5996131,,,,,,,,,,,,,, -54400,1.9121871,1.6644253,,,,,,,,,,,,,, -54500,2.0020556,1.580644,,,,,,,,,,,,,, -54600,2.140771,1.6223621,,,,,,,,,,,,,, -54700,1.8904378,1.5830338,,,,,,,,,,,,,, -54800,1.8695731,1.6087878,,,,,,,,,,,,,, -54900,1.9152243,1.6591439,,,,,,,,,,,,,, -55000,2.0332341,1.5971454,,,,,,,,,,,,,, -55100,1.7837625,1.6051915,,,,,,,,,,,,,, -55200,1.991177,1.5346215,,,,,,,,,,,,,, -55300,1.9492549,1.5384521,,,,,,,,,,,,,, -55400,1.8918003,1.5445381,,,,,,,,,,,,,, -55500,1.9501786,1.6037704,,,,,,,,,,,,,, -55570,,,0.7121531963348389,1.1044635772705078,0.6428200006484985,1.4710193872451782,50000.0,0.5169000029563904,2.2187652587890625,10000.0,18905.46354651451,19589.064910411835,18905.46354651451,680.2520830631256,1.3046739101409912,0.0 -55600,2.039848,1.5896478,,,,,,,,,,,,,, -55700,1.9225895,1.6564671,,,,,,,,,,,,,, -55800,1.9655157,1.5756956,,,,,,,,,,,,,, -55900,1.9653577,1.6147164,,,,,,,,,,,,,, -56000,2.0916355,1.5494125,,,,,,,,,,,,,, -56100,2.0192893,1.6938007,,,,,,,,,,,,,, -56200,1.9337355,1.6410189,,,,,,,,,,,,,, -56300,1.6983045,1.5449419,,,,,,,,,,,,,, -56400,1.9030528,1.5775217,,,,,,,,,,,,,, -56500,1.961319,1.6069994,,,,,,,,,,,,,, -56600,1.9116484,1.5607698,,,,,,,,,,,,,, -56700,1.9668698,1.5611047,,,,,,,,,,,,,, -56800,1.8948207,1.6400231,,,,,,,,,,,,,, -56900,1.8318353,1.6649362,,,,,,,,,,,,,, -57000,2.0161996,1.6594442,,,,,,,,,,,,,, -57074,,,0.7025071382522583,1.1650789976119995,0.6407999992370605,1.4790819883346558,50000.0,0.5103000402450562,2.253450393676758,10000.0,19415.57575273513,20116.90283894539,19415.57575273513,697.8852643966675,1.3421645164489746,0.0 -57100,1.9443629,1.5667037,,,,,,,,,,,,,, -57200,1.624176,1.4626062,,,,,,,,,,,,,, -57300,2.0510392,1.599389,,,,,,,,,,,,,, -57400,1.9050983,1.6354623,,,,,,,,,,,,,, -57500,1.9783263,1.6753836,,,,,,,,,,,,,, -57600,2.0519385,1.5260005,,,,,,,,,,,,,, -57700,2.3329654,1.6299475,,,,,,,,,,,,,, -57800,1.8835576,1.5283481,,,,,,,,,,,,,, -57900,1.8295481,1.5902866,,,,,,,,,,,,,, -58000,2.0057282,1.5938019,,,,,,,,,,,,,, -58100,2.0526803,1.7023933,,,,,,,,,,,,,, -58200,1.9283775,1.7008903,,,,,,,,,,,,,, -58300,1.7831748,1.5804572,,,,,,,,,,,,,, -58400,2.1096792,1.5013773,,,,,,,,,,,,,, -58500,1.9391692,1.5790864,,,,,,,,,,,,,, -58577,,,0.7065728306770325,1.1429468393325806,0.6433799862861633,1.4730268716812134,50000.0,0.5179000496864319,2.2345850467681885,10000.0,19925.48170900345,20645.44976592064,19925.48170900345,716.4295771121979,1.383310317993164,0.0 -58600,1.9391854,1.52338,,,,,,,,,,,,,, -58700,1.9246099,1.5359018,,,,,,,,,,,,,, -58800,1.9295379,1.4536322,,,,,,,,,,,,,, -58900,1.8289709,1.6166694,,,,,,,,,,,,,, -59000,1.9544461,1.566452,,,,,,,,,,,,,, -59100,1.9282522,1.5022157,,,,,,,,,,,,,, -59200,2.1138833,1.7683014,,,,,,,,,,,,,, -59300,2.047937,1.5188262,,,,,,,,,,,,,, -59400,2.0257514,1.6797831,,,,,,,,,,,,,, -59500,1.9225563,1.5943005,,,,,,,,,,,,,, -59600,1.8180072,1.6107256,,,,,,,,,,,,,, -59700,1.9484304,1.5258884,,,,,,,,,,,,,, -59800,2.0572052,1.5567913,,,,,,,,,,,,,, -59900,2.0215569,1.5709283,,,,,,,,,,,,,, -60000,2.1986074,1.5501382,,,,,,,,,,,,,, -60081,,,0.7592673897743225,0.929104745388031,0.658079981803894,1.402864933013916,50000.0,0.5254000425338745,2.1215386390686035,10000.0,20435.709728956223,21173.62443089485,20435.709728956223,734.2747831344604,1.4285414218902588,0.0 -60100,1.8213279,1.636739,,,,,,,,,,,,,, -60200,1.839675,1.5572281,,,,,,,,,,,,,, -60300,1.8423618,1.55983,,,,,,,,,,,,,, -60400,1.9078466,1.5563653,,,,,,,,,,,,,, -60500,1.8431039,1.3911217,,,,,,,,,,,,,, -60600,1.9678204,1.5507493,,,,,,,,,,,,,, -60700,1.9416724,1.4545484,,,,,,,,,,,,,, -60800,1.8287846,1.5748812,,,,,,,,,,,,,, -60900,1.7239891,1.6307541,,,,,,,,,,,,,, -61000,1.9743918,1.5158573,,,,,,,,,,,,,, -61100,1.9987947,1.4559,,,,,,,,,,,,,, -61200,2.0454204,1.6250255,,,,,,,,,,,,,, -61300,1.9051523,1.5535694,,,,,,,,,,,,,, -61400,1.9237236,1.609716,,,,,,,,,,,,,, -61500,1.8413492,1.6188796,,,,,,,,,,,,,, -61584,,,0.73636794090271,1.0095868110656738,0.6518799662590027,1.4322763681411743,50000.0,0.5242000222206116,2.153595447540283,10000.0,20945.643161058422,21701.610381364822,20945.643161058422,752.2365992069244,1.4622957706451416,0.0 -61600,2.0521066,1.5640125,,,,,,,,,,,,,, -61700,1.9965506,1.5109155,,,,,,,,,,,,,, -61800,1.8814971,1.6050198,,,,,,,,,,,,,, -61900,1.8132833,1.4844773,,,,,,,,,,,,,, -62000,1.8913561,1.5818119,,,,,,,,,,,,,, -62100,1.9102335,1.5629033,,,,,,,,,,,,,, -62200,2.0126042,1.5800164,,,,,,,,,,,,,, -62300,1.8700151,1.5384303,,,,,,,,,,,,,, -62400,2.0661244,1.5814437,,,,,,,,,,,,,, -62500,1.9956084,1.5181019,,,,,,,,,,,,,, -62600,1.9363633,1.4320298,,,,,,,,,,,,,, -62700,2.3198984,1.7753985,,,,,,,,,,,,,, -62800,1.7528211,1.4405907,,,,,,,,,,,,,, -62900,1.7918605,1.472116,,,,,,,,,,,,,, -63000,2.200808,1.676567,,,,,,,,,,,,,, -63088,,,0.7388990521430969,0.9919620156288148,0.6604399681091309,1.3843040466308594,50000.0,0.5323000550270081,2.100260734558105,10000.0,21455.75441765785,22229.440237522125,21455.75441765785,769.8576102256775,1.5023293495178225,0.0 -63100,1.9566828,1.5418829,,,,,,,,,,,,,, -63200,1.8489677,1.5266157,,,,,,,,,,,,,, -63300,1.7034168,1.5181347,,,,,,,,,,,,,, -63400,1.8307323,1.574626,,,,,,,,,,,,,, -63500,1.8401299,1.4834111,,,,,,,,,,,,,, -63600,1.9452367,1.3946886,,,,,,,,,,,,,, -63700,2.096981,1.5207816,,,,,,,,,,,,,, -63800,2.0235112,1.564183,,,,,,,,,,,,,, -63900,1.7868814,1.4450877,,,,,,,,,,,,,, -64000,2.0444388,1.5586705,,,,,,,,,,,,,, -64100,2.4310167,1.6432209,,,,,,,,,,,,,, -64200,2.0110822,1.5112913,,,,,,,,,,,,,, -64300,1.9723715,1.4940944,,,,,,,,,,,,,, -64400,1.9089764,1.5078074,,,,,,,,,,,,,, -64500,1.9230123,1.6585821,,,,,,,,,,,,,, -64591,,,0.7138074040412903,1.106141448020935,0.6453199982643127,1.459357976913452,50000.0,0.51910001039505,2.223273992538452,10000.0,21965.73911070824,22757.195997476578,21965.73911070824,787.5250680446625,1.551323413848877,0.0 -64600,1.9052885,1.4632667,,,,,,,,,,,,,, -64700,1.9303397,1.5353007,,,,,,,,,,,,,, -64800,1.8416526,1.5264627,,,,,,,,,,,,,, -64900,1.9036983,1.5788977,,,,,,,,,,,,,, -65000,2.096432,1.4752352,,,,,,,,,,,,,, -65100,2.0335765,1.5781124,,,,,,,,,,,,,, -65200,1.8095604,1.4315658,,,,,,,,,,,,,, -65300,2.0932648,1.7007602,,,,,,,,,,,,,, -65400,2.0192184,1.4177812,,,,,,,,,,,,,, -65500,1.8899407,1.4326673,,,,,,,,,,,,,, -65600,1.9457772,1.5753347,,,,,,,,,,,,,, -65700,2.1548023,1.6373445,,,,,,,,,,,,,, -65800,1.8392689,1.4768314,,,,,,,,,,,,,, -65900,2.0988848,1.4918283,,,,,,,,,,,,,, -66000,1.9037212,1.4619421,,,,,,,,,,,,,, -66095,,,0.7254065275192261,1.0538322925567627,0.6597200036048889,1.3996851444244385,50000.0,0.5354000329971313,2.1329658031463623,10000.0,22475.68429350853,23285.028936624527,22475.68429350853,805.305394411087,1.60355544090271,0.0 -66100,1.8424271,1.5135684,,,,,,,,,,,,,, -66200,2.2471743,1.5749768,,,,,,,,,,,,,, -66300,2.251278,1.5395534,,,,,,,,,,,,,, -66400,2.0831041,1.6194472,,,,,,,,,,,,,, -66500,2.0853162,1.5278583,,,,,,,,,,,,,, -66600,1.9065527,1.5529482,,,,,,,,,,,,,, -66700,1.8092817,1.462575,,,,,,,,,,,,,, -66800,1.9091698,1.4288895,,,,,,,,,,,,,, -66900,2.059137,1.5895821,,,,,,,,,,,,,, -67000,2.0333009,1.45645,,,,,,,,,,,,,, -67100,2.0441425,1.5626478,,,,,,,,,,,,,, -67200,2.0238593,1.5401849,,,,,,,,,,,,,, -67300,1.9641153,1.579458,,,,,,,,,,,,,, -67400,1.8410679,1.4287951,,,,,,,,,,,,,, -67500,1.7025604,1.46542,,,,,,,,,,,,,, -67598,,,0.7191087007522583,1.084785223007202,0.6524999737739563,1.4338654279708862,50000.0,0.5164999961853027,2.1947619915008545,10000.0,22985.69870376587,23813.08985543251,22985.69870376587,823.2509181499481,1.649055242538452,0.0 -67600,2.0326219,1.542855,,,,,,,,,,,,,, -67700,2.041691,1.5051495,,,,,,,,,,,,,, -67800,2.2160046,1.6202374,,,,,,,,,,,,,, -67900,2.3817482,1.600666,,,,,,,,,,,,,, -68000,1.8930423,1.447429,,,,,,,,,,,,,, -68100,2.1898162,1.5545504,,,,,,,,,,,,,, -68200,2.0688155,1.3883458,,,,,,,,,,,,,, -68300,2.2523885,1.5388889,,,,,,,,,,,,,, -68400,1.8922505,1.5113599,,,,,,,,,,,,,, -68500,1.9223065,1.4710538,,,,,,,,,,,,,, -68600,1.8890959,1.5192665,,,,,,,,,,,,,, -68700,2.079569,1.4654243,,,,,,,,,,,,,, -68800,2.1333826,1.4782437,,,,,,,,,,,,,, -68900,1.959854,1.5200775,,,,,,,,,,,,,, -69000,1.9079616,1.5362318,,,,,,,,,,,,,, -69100,1.8735248,1.4037911,,,,,,,,,,,,,, -69101,,,0.7221181392669678,1.076885223388672,0.6430599689483643,1.46754789352417,50000.0,0.5143000483512878,2.2291743755340576,10000.0,23495.98047399521,24341.53601646424,23495.98047399521,841.3075633049011,1.700188159942627,0.0 -69200,2.1586328,1.4012278,,,,,,,,,,,,,, -69300,2.0252566,1.5599898,,,,,,,,,,,,,, -69400,2.152269,1.578856,,,,,,,,,,,,,, -69500,1.8505129,1.5127313,,,,,,,,,,,,,, -69600,1.9400761,1.4848433,,,,,,,,,,,,,, -69700,1.9621186,1.4582294,,,,,,,,,,,,,, -69800,2.1215389,1.5469699,,,,,,,,,,,,,, -69900,2.0480886,1.4713845,,,,,,,,,,,,,, -70000,1.9693741,1.5395834,,,,,,,,,,,,,, -70100,1.7853702,1.5461106,,,,,,,,,,,,,, -70200,1.9604672,1.3869212,,,,,,,,,,,,,, -70300,1.8923072,1.4553626,,,,,,,,,,,,,, -70400,2.3980963,1.6006166,,,,,,,,,,,,,, -70500,2.083424,1.4575558,,,,,,,,,,,,,, -70600,1.9195136,1.468641,,,,,,,,,,,,,, -70605,,,0.7581911683082581,0.9060986638069152,0.6655600070953369,1.3723647594451904,50000.0,0.5434000492095947,2.08642315864563,10000.0,24006.12465786934,24869.49318599701,24006.12465786934,859.0242967605591,1.7394630908966064,0.0 -70700,2.0991552,1.5294131,,,,,,,,,,,,,, -70800,2.031522,1.5594819,,,,,,,,,,,,,, -70900,2.2059891,1.6438185,,,,,,,,,,,,,, -71000,2.084591,1.498662,,,,,,,,,,,,,, -71100,1.7323861,1.3652202,,,,,,,,,,,,,, -71200,2.1244068,1.5311049,,,,,,,,,,,,,, -71300,2.0228593,1.6101635,,,,,,,,,,,,,, -71400,2.172771,1.4755001,,,,,,,,,,,,,, -71500,1.931476,1.5716721,,,,,,,,,,,,,, -71600,2.3543773,1.6239687,,,,,,,,,,,,,, -71700,2.1660662,1.569323,,,,,,,,,,,,,, -71800,1.9086051,1.4009204,,,,,,,,,,,,,, -71900,2.1560547,1.5506151,,,,,,,,,,,,,, -72000,1.9911826,1.4511921,,,,,,,,,,,,,, -72100,1.9098874,1.5071088,,,,,,,,,,,,,, -72109,,,0.7310466766357422,1.0297045707702637,0.6527999639511108,1.4410232305526731,50000.0,0.5228000283241272,2.215155601501465,10000.0,24516.31637406349,25397.42277216912,24516.31637406349,876.6616532802582,1.7846426963806152,0.0 -72200,2.117472,1.5220906,,,,,,,,,,,,,, -72300,2.0668256,1.5348711,,,,,,,,,,,,,, -72400,2.2521849,1.5818555,,,,,,,,,,,,,, -72500,2.059749,1.5089049,,,,,,,,,,,,,, -72600,2.1423802,1.3665016,,,,,,,,,,,,,, -72700,1.9019921,1.4088223,,,,,,,,,,,,,, -72800,2.1330829,1.4699568,,,,,,,,,,,,,, -72900,1.8463385,1.4664764,,,,,,,,,,,,,, -73000,2.0054507,1.5526782,,,,,,,,,,,,,, -73100,2.1433206,1.5708703,,,,,,,,,,,,,, -73200,2.2058754,1.5159742,,,,,,,,,,,,,, -73300,2.1624959,1.3798945,,,,,,,,,,,,,, -73400,1.90758,1.3928415,,,,,,,,,,,,,, -73500,1.9778942,1.484847,,,,,,,,,,,,,, -73600,2.0491815,1.3815362,,,,,,,,,,,,,, -73611,,,0.7375039458274841,0.9993064999580384,0.6623799800872803,1.390560746192932,50000.0,0.5325000286102295,2.110114812850952,10000.0,25026.27029204369,25925.412529945374,25026.27029204369,894.5995259284973,1.8255927562713623,0.0 -73700,1.9254999,1.4214135,,,,,,,,,,,,,, -73800,2.099902,1.4278159,,,,,,,,,,,,,, -73900,2.127434,1.5541673,,,,,,,,,,,,,, -74000,1.9598078,1.4044313,,,,,,,,,,,,,, -74100,2.0684335,1.4156667,,,,,,,,,,,,,, -74200,2.0682983,1.6111095,,,,,,,,,,,,,, -74300,1.8947229,1.5344852,,,,,,,,,,,,,, -74400,2.087146,1.4619341,,,,,,,,,,,,,, -74500,2.161768,1.4861249,,,,,,,,,,,,,, -74600,2.176046,1.6365155,,,,,,,,,,,,,, -74700,2.117852,1.5028996,,,,,,,,,,,,,, -74800,2.1653779,1.6270298,,,,,,,,,,,,,, -74900,1.979013,1.453345,,,,,,,,,,,,,, -75000,2.2501833,1.4884367,,,,,,,,,,,,,, -75100,2.2196221,1.506376,,,,,,,,,,,,,, -75115,,,0.7437619566917419,0.9866275191307068,0.6665399670600891,1.3468725681304932,50000.0,0.5391000509262085,2.0707874298095703,10000.0,25536.4892745018,26453.562220811844,25536.4892745018,912.429455280304,1.87069034576416,0.0 -75200,2.4130337,1.5565448,,,,,,,,,,,,,, -75300,2.177486,1.52424,,,,,,,,,,,,,, -75400,2.0164235,1.4769944,,,,,,,,,,,,,, -75500,2.1718802,1.5521381,,,,,,,,,,,,,, -75600,2.9136503,1.5212135,,,,,,,,,,,,,, -75700,2.2286165,1.4754751,,,,,,,,,,,,,, -75800,2.1286857,1.518881,,,,,,,,,,,,,, -75900,2.0412984,1.5991764,,,,,,,,,,,,,, -76000,2.2726574,1.4873214,,,,,,,,,,,,,, -76100,2.098136,1.5033902,,,,,,,,,,,,,, -76200,2.4876688,1.4329975,,,,,,,,,,,,,, -76300,2.5046213,1.4412773,,,,,,,,,,,,,, -76400,2.0827374,1.4983853,,,,,,,,,,,,,, -76500,1.9455377,1.4319484,,,,,,,,,,,,,, -76600,2.165535,1.5366704,,,,,,,,,,,,,, -76618,,,0.7401546239852905,0.9870557188987732,0.6706399917602539,1.3400790691375732,50000.0,0.5473999977111816,2.035409450531006,10000.0,26046.512956619263,26981.249529123303,26046.512956619263,929.9996762275696,1.908890724182129,0.0 -76700,2.1537256,1.5056466,,,,,,,,,,,,,, -76800,2.080247,1.4537601,,,,,,,,,,,,,, -76900,2.0694392,1.4293339,,,,,,,,,,,,,, -77000,2.1524549,1.5820793,,,,,,,,,,,,,, -77100,2.2705734,1.4902868,,,,,,,,,,,,,, -77200,2.1144667,1.5000303,,,,,,,,,,,,,, -77300,2.0338793,1.5079421,,,,,,,,,,,,,, -77400,2.3344219,1.5120666,,,,,,,,,,,,,, -77500,2.2613547,1.436867,,,,,,,,,,,,,, -77600,2.4420342,1.4773195,,,,,,,,,,,,,, -77700,2.274978,1.3918619,,,,,,,,,,,,,, -77800,2.1611516,1.4926085,,,,,,,,,,,,,, -77900,2.2445226,1.534386,,,,,,,,,,,,,, -78000,1.9430287,1.3624482,,,,,,,,,,,,,, -78100,2.184231,1.3965952,,,,,,,,,,,,,, -78122,,,0.7384207248687744,1.0084189176559448,0.665340006351471,1.3739938735961914,50000.0,0.5392000079154968,2.11980938911438,10000.0,26556.6778280735,27509.383778572083,26556.6778280735,947.8705537319184,1.9521074295043943,0.0 -78200,2.2203493,1.5445495,,,,,,,,,,,,,, -78300,2.2090194,1.4477733,,,,,,,,,,,,,, -78400,2.36484,1.4650909,,,,,,,,,,,,,, -78500,2.3748813,1.5480043,,,,,,,,,,,,,, -78600,2.3646088,1.5185074,,,,,,,,,,,,,, -78700,2.2994647,1.3787861,,,,,,,,,,,,,, -78800,2.3866737,1.4593078,,,,,,,,,,,,,, -78900,2.0721426,1.424949,,,,,,,,,,,,,, -79000,2.0993333,1.510956,,,,,,,,,,,,,, -79100,2.066069,1.4630786,,,,,,,,,,,,,, -79200,2.185969,1.4545935,,,,,,,,,,,,,, -79300,2.3302033,1.4120953,,,,,,,,,,,,,, -79400,2.0873053,1.4906027,,,,,,,,,,,,,, -79500,2.0795362,1.4943881,,,,,,,,,,,,,, -79600,2.2494504,1.4641789,,,,,,,,,,,,,, -79626,,,0.767578125,0.8716039061546326,0.6688399910926819,1.3541743755340576,50000.0,0.5445000529289246,2.0731024742126465,10000.0,27066.751952409744,28037.19790363312,27066.751952409744,965.515125989914,1.991746425628662,0.0 -79700,2.1586363,1.5325923,,,,,,,,,,,,,, -79800,2.2476218,1.4083357,,,,,,,,,,,,,, -79900,2.1349251,1.3522296,,,,,,,,,,,,,, -80000,2.5074308,1.3365941,,,,,,,,,,,,,, -80100,2.2954547,1.5095649,,,,,,,,,,,,,, -80200,2.0804963,1.4626772,,,,,,,,,,,,,, -80300,2.1724257,1.4356563,,,,,,,,,,,,,, -80400,2.0769958,1.4340433,,,,,,,,,,,,,, -80500,2.3906205,1.4610245,,,,,,,,,,,,,, -80600,2.1923354,1.5666608,,,,,,,,,,,,,, -80700,2.4018424,1.4980924,,,,,,,,,,,,,, -80800,2.2845469,1.4777471,,,,,,,,,,,,,, -80900,2.1433496,1.366183,,,,,,,,,,,,,, -81000,2.1400135,1.3792236,,,,,,,,,,,,,, -81100,2.177797,1.3073552,,,,,,,,,,,,,, -81130,,,0.7565369606018066,0.9142917394638062,0.6692000031471252,1.3499724864959717,50000.0,0.5324000120162964,2.085909128189087,10000.0,27576.86024737358,28565.35233569145,27576.86024737358,983.458841085434,2.036340713500977,0.0 -81200,2.0522408,1.3117411,,,,,,,,,,,,,, -81300,2.063443,1.3345792,,,,,,,,,,,,,, -81400,2.1946511,1.4538958,,,,,,,,,,,,,, -81500,2.3783758,1.4831002,,,,,,,,,,,,,, -81600,2.1589189,1.4135962,,,,,,,,,,,,,, -81700,2.3835166,1.4664549,,,,,,,,,,,,,, -81800,2.0805948,1.4346852,,,,,,,,,,,,,, -81900,2.1414232,1.445184,,,,,,,,,,,,,, -82000,1.9454539,1.4176493,,,,,,,,,,,,,, -82100,2.2017121,1.3789394,,,,,,,,,,,,,, -82200,2.1954997,1.4777391,,,,,,,,,,,,,, -82300,2.3594868,1.4665977,,,,,,,,,,,,,, -82400,2.5183613,1.4913176,,,,,,,,,,,,,, -82500,2.2461019,1.4197595,,,,,,,,,,,,,, -82600,2.0555205,1.3980776,,,,,,,,,,,,,, -82633,,,0.7384008169174194,1.0000159740447998,0.6617599725723267,1.3959338665008545,50000.0,0.5343000292778015,2.1417131423950195,10000.0,28086.937801122665,29093.46788740158,28086.937801122665,1001.4011144638062,2.0771901607513428,0.0 -82700,2.2289338,1.4931586,,,,,,,,,,,,,, -82800,1.8956801,1.3837556,,,,,,,,,,,,,, -82900,2.26606,1.4626805,,,,,,,,,,,,,, -83000,2.3436155,1.4952527,,,,,,,,,,,,,, -83100,2.0638602,1.392676,,,,,,,,,,,,,, -83200,2.4656513,1.4534515,,,,,,,,,,,,,, -83300,2.1527686,1.4283092,,,,,,,,,,,,,, -83400,2.3025794,1.4739877,,,,,,,,,,,,,, -83500,2.1514273,1.4632761,,,,,,,,,,,,,, -83600,2.195868,1.4799439,,,,,,,,,,,,,, -83700,2.1498697,1.4371499,,,,,,,,,,,,,, -83800,2.4412124,1.4453897,,,,,,,,,,,,,, -83900,2.0931258,1.3730878,,,,,,,,,,,,,, -84000,2.477377,1.436702,,,,,,,,,,,,,, -84100,2.331456,1.4760457,,,,,,,,,,,,,, -84137,,,0.7523317933082581,0.9355828166007996,0.6721799969673157,1.334437608718872,50000.0,0.541100025177002,2.071115255355835,10000.0,28597.01343441009,29621.39469194412,28597.01343441009,1019.136206626892,2.137763500213623,0.0 -84200,2.0730119,1.3326464,,,,,,,,,,,,,, -84300,2.0718842,1.5089084,,,,,,,,,,,,,, -84400,2.4728267,1.4100717,,,,,,,,,,,,,, -84500,2.2362752,1.452852,,,,,,,,,,,,,, -84600,2.2177556,1.5256519,,,,,,,,,,,,,, -84700,2.2893472,1.4477482,,,,,,,,,,,,,, -84800,2.4227939,1.443297,,,,,,,,,,,,,, -84900,2.3757155,1.3607283,,,,,,,,,,,,,, -85000,2.365369,1.4470656,,,,,,,,,,,,,, -85100,2.202292,1.4333334,,,,,,,,,,,,,, -85200,2.1378965,1.3888928,,,,,,,,,,,,,, -85300,2.1350462,1.4364457,,,,,,,,,,,,,, -85400,2.2510114,1.4092157,,,,,,,,,,,,,, -85500,2.3525965,1.4556386,,,,,,,,,,,,,, -85600,2.1506288,1.4090359,,,,,,,,,,,,,, -85640,,,0.7469307780265808,0.961391270160675,0.66975998878479,1.3483177423477173,50000.0,0.5494000315666199,2.091202735900879,10000.0,29106.93808484077,30149.020438194275,29106.93808484077,1036.7360591888428,2.182596445083618,0.0 -85700,2.1203113,1.4822345,,,,,,,,,,,,,, -85800,2.3747604,1.3942754,,,,,,,,,,,,,, -85900,2.2995455,1.4715023,,,,,,,,,,,,,, -86000,2.2619877,1.4336056,,,,,,,,,,,,,, -86100,2.1498015,1.4249979,,,,,,,,,,,,,, -86200,2.2736292,1.4858319,,,,,,,,,,,,,, -86300,2.2731,1.3990049,,,,,,,,,,,,,, -86400,2.2544978,1.4819268,,,,,,,,,,,,,, -86500,2.5452256,1.4382465,,,,,,,,,,,,,, -86600,2.2892497,1.5164382,,,,,,,,,,,,,, -86700,2.3634837,1.5674262,,,,,,,,,,,,,, -86800,2.4140534,1.4629308,,,,,,,,,,,,,, -86900,2.2588649,1.3782406,,,,,,,,,,,,,, -87000,2.0881126,1.3662724,,,,,,,,,,,,,, -87100,2.3802443,1.477572,,,,,,,,,,,,,, -87144,,,0.7379224896430969,0.9985449314117432,0.6692799925804138,1.3523105382919312,50000.0,0.534600019454956,2.0925774574279785,10000.0,29617.127873182297,30676.874872922897,29617.127873182297,1054.3037416934967,2.2248997688293457,0.0 -87200,2.1459768,1.3959379,,,,,,,,,,,,,, -87300,2.3516724,1.3991731,,,,,,,,,,,,,, -87400,2.2817502,1.4867078,,,,,,,,,,,,,, -87500,2.436472,1.4570312,,,,,,,,,,,,,, -87600,2.2332258,1.4204854,,,,,,,,,,,,,, -87700,2.2364562,1.4766546,,,,,,,,,,,,,, -87800,2.3156798,1.4453009,,,,,,,,,,,,,, -87900,2.271036,1.4449475,,,,,,,,,,,,,, -88000,2.2338703,1.4777013,,,,,,,,,,,,,, -88100,2.1888041,1.4090532,,,,,,,,,,,,,, -88200,2.4274294,1.51927,,,,,,,,,,,,,, -88300,2.2756636,1.3441113,,,,,,,,,,,,,, -88400,2.3359594,1.472626,,,,,,,,,,,,,, -88500,2.4289284,1.5084697,,,,,,,,,,,,,, -88600,2.4707983,1.3809338,,,,,,,,,,,,,, -88647,,,0.7837810516357422,0.8073683381080627,0.6779199838638306,1.3172043561935425,50000.0,0.5424000024795532,2.068551778793335,10000.0,30127.066581964493,31204.71547460556,30127.066581964493,1072.1030399799347,2.271844387054444,0.0 -88700,2.1401749,1.3085741,,,,,,,,,,,,,, -88800,2.1915953,1.4788667,,,,,,,,,,,,,, -88900,2.1414711,1.3954878,,,,,,,,,,,,,, -89000,2.1667802,1.3843083,,,,,,,,,,,,,, -89100,2.2312448,1.3594935,,,,,,,,,,,,,, -89200,2.1861398,1.4331492,,,,,,,,,,,,,, -89300,2.2704449,1.399199,,,,,,,,,,,,,, -89400,2.3833296,1.4947889,,,,,,,,,,,,,, -89500,2.3980668,1.4890403,,,,,,,,,,,,,, -89600,2.2597013,1.5155416,,,,,,,,,,,,,, -89700,2.2060812,1.3220518,,,,,,,,,,,,,, -89800,2.3175993,1.4679303,,,,,,,,,,,,,, -89900,2.4774117,1.4988649,,,,,,,,,,,,,, -90000,2.4009426,1.5349166,,,,,,,,,,,,,, -90100,2.573183,1.4362986,,,,,,,,,,,,,, -90151,,,0.7757493257522583,0.8434444069862366,0.6840199828147888,1.2816094160079956,50000.0,0.5539000034332275,2.0205748081207275,10000.0,30637.134697914124,31732.880301237103,30637.134697914124,1090.0969746112823,2.317918062210083,0.0 -90200,2.3955688,1.4505155,,,,,,,,,,,,,, -90300,2.1321065,1.4871571,,,,,,,,,,,,,, -90400,2.3829951,1.3179771,,,,,,,,,,,,,, -90500,2.1666088,1.3858652,,,,,,,,,,,,,, -90600,2.363655,1.3689476,,,,,,,,,,,,,, -90700,2.1408744,1.489753,,,,,,,,,,,,,, -90800,2.2378845,1.3435067,,,,,,,,,,,,,, -90900,2.2454638,1.406197,,,,,,,,,,,,,, -91000,2.2789626,1.3274441,,,,,,,,,,,,,, -91100,2.2026021,1.4205557,,,,,,,,,,,,,, -91200,2.2938428,1.4018552,,,,,,,,,,,,,, -91300,2.4401069,1.4385556,,,,,,,,,,,,,, -91400,2.3394773,1.3733175,,,,,,,,,,,,,, -91500,2.2727072,1.3717566,,,,,,,,,,,,,, -91600,2.2939,1.4773626,,,,,,,,,,,,,, -91655,,,0.7606425285339355,0.8926047086715698,0.6730200052261353,1.342376470565796,50000.0,0.5481000542640686,2.098851203918457,10000.0,31147.257095575333,32260.765964984894,31147.257095575333,1107.751292705536,2.369476079940796,0.0 -91700,2.498075,1.4637213,,,,,,,,,,,,,, -91800,2.4470475,1.4444563,,,,,,,,,,,,,, -91900,2.2809901,1.4296049,,,,,,,,,,,,,, -92000,2.1955597,1.4230498,,,,,,,,,,,,,, -92100,2.2280855,1.3621716,,,,,,,,,,,,,, -92200,2.3548954,1.3496113,,,,,,,,,,,,,, -92300,2.2198815,1.3489007,,,,,,,,,,,,,, -92400,2.3616772,1.3875518,,,,,,,,,,,,,, -92500,2.0679698,1.3423237,,,,,,,,,,,,,, -92600,2.2838247,1.2582308,,,,,,,,,,,,,, -92700,2.475274,1.3808709,,,,,,,,,,,,,, -92800,2.7254891,1.2972037,,,,,,,,,,,,,, -92900,2.5693629,1.4493722,,,,,,,,,,,,,, -93000,2.3421109,1.4512252,,,,,,,,,,,,,, -93100,2.1318243,1.2794474,,,,,,,,,,,,,, -93159,,,0.7567561864852905,0.9086245894432068,0.6795399785041809,1.3091254234313965,50000.0,0.5511000156402588,2.040935516357422,10000.0,31657.1976583004,32788.57222819328,31657.1976583004,1125.5106115341189,2.420135974884033,0.0 -93200,2.3286262,1.3918864,,,,,,,,,,,,,, -93300,2.2431424,1.2713594,,,,,,,,,,,,,, -93400,2.4284704,1.5470697,,,,,,,,,,,,,, -93500,2.4429243,1.311889,,,,,,,,,,,,,, -93600,2.2914603,1.3792677,,,,,,,,,,,,,, -93700,2.355379,1.3532374,,,,,,,,,,,,,, -93800,2.191593,1.3564999,,,,,,,,,,,,,, -93900,2.2401814,1.2660027,,,,,,,,,,,,,, -94000,2.5270088,1.4873133,,,,,,,,,,,,,, -94100,2.3334181,1.3917716,,,,,,,,,,,,,, -94200,2.5003285,1.4008893,,,,,,,,,,,,,, -94300,2.4290571,1.2901165,,,,,,,,,,,,,, -94400,2.5048885,1.274532,,,,,,,,,,,,,, -94500,2.311963,1.4064264,,,,,,,,,,,,,, -94600,2.3478463,1.3956225,,,,,,,,,,,,,, -94663,,,0.7587292790412903,0.9139228463172911,0.6759399771690369,1.32182776927948,50000.0,0.5546000003814697,2.026935338973999,10000.0,32167.262050628666,33316.578904390335,32167.262050628666,1143.3546307086945,2.4641125202178955,0.0 -94700,2.3213336,1.3066745,,,,,,,,,,,,,, -94800,2.539998,1.4962487,,,,,,,,,,,,,, -94900,2.558796,1.3089113,,,,,,,,,,,,,, -95000,2.486598,1.4315223,,,,,,,,,,,,,, -95100,2.677326,1.4239886,,,,,,,,,,,,,, -95200,2.135923,1.251631,,,,,,,,,,,,,, -95300,2.3655925,1.3947775,,,,,,,,,,,,,, -95400,2.3406441,1.3476307,,,,,,,,,,,,,, -95500,2.5207138,1.3381184,,,,,,,,,,,,,, -95600,2.519617,1.334974,,,,,,,,,,,,,, -95700,2.3785453,1.3783523,,,,,,,,,,,,,, -95800,2.4076278,1.4155726,,,,,,,,,,,,,, -95900,2.6171105,1.3084756,,,,,,,,,,,,,, -96000,2.3930242,1.5454094,,,,,,,,,,,,,, -96100,2.3393226,1.5324776,,,,,,,,,,,,,, -96167,,,0.7642498016357422,0.8854460120201111,0.6845600008964539,1.2905325889587402,50000.0,0.5508000254631042,2.0234200954437256,10000.0,32677.42819571495,33844.844178915024,32677.42819571495,1161.3488364219666,2.514535427093506,0.0 -96200,2.5496268,1.4392352,,,,,,,,,,,,,, -96300,2.488017,1.3789351,,,,,,,,,,,,,, -96400,2.2571523,1.3182851,,,,,,,,,,,,,, -96500,2.39357,1.3946296,,,,,,,,,,,,,, -96600,2.4906926,1.454054,,,,,,,,,,,,,, -96700,2.4740505,1.4314042,,,,,,,,,,,,,, -96800,2.6982222,1.3219599,,,,,,,,,,,,,, -96900,2.6648564,1.3585433,,,,,,,,,,,,,, -97000,2.4001217,1.3537475,,,,,,,,,,,,,, -97100,2.6514316,1.344942,,,,,,,,,,,,,, -97200,2.2793663,1.3463291,,,,,,,,,,,,,, -97300,2.3665183,1.3214806,,,,,,,,,,,,,, -97400,2.437925,1.3502522,,,,,,,,,,,,,, -97500,2.384498,1.2175272,,,,,,,,,,,,,, -97600,2.4140627,1.4364616,,,,,,,,,,,,,, -97669,,,0.8039301633834839,0.7246366143226624,0.6893399953842163,1.2651199102401731,50000.0,0.5527999997138977,2.02603530883789,10000.0,33187.63098335266,34373.93581676483,33187.63098335266,1180.1326916217804,2.562276601791382,0.0 -97700,2.4520342,1.3664494,,,,,,,,,,,,,, -97800,2.4968822,1.4462221,,,,,,,,,,,,,, -97900,2.5266395,1.3114527,,,,,,,,,,,,,, -98000,2.322739,1.4162177,,,,,,,,,,,,,, -98100,2.3959792,1.3285996,,,,,,,,,,,,,, -98200,2.3739672,1.4552194,,,,,,,,,,,,,, -98300,2.4857557,1.352095,,,,,,,,,,,,,, -98400,2.2294078,1.2380859,,,,,,,,,,,,,, -98500,2.5220127,1.3985467,,,,,,,,,,,,,, -98600,2.607404,1.4048182,,,,,,,,,,,,,, -98700,2.3559291,1.3605182,,,,,,,,,,,,,, -98800,2.4227672,1.4695517,,,,,,,,,,,,,, -98900,2.543365,1.4111187,,,,,,,,,,,,,, -99000,2.3015943,1.276366,,,,,,,,,,,,,, -99100,2.3306506,1.2051395,,,,,,,,,,,,,, -99173,,,0.7836814522743225,0.7890745997428894,0.6879400014877319,1.263701558113098,50000.0,0.5654000043869019,1.98188054561615,10000.0,33697.79170894623,34902.117525577545,33697.79170894623,1198.0567321777344,2.6048731803894043,0.0 -99200,2.5321777,1.3313289,,,,,,,,,,,,,, -99300,2.7546241,1.4853985,,,,,,,,,,,,,, -99400,2.7315433,1.3367765,,,,,,,,,,,,,, -99500,2.5970185,1.3965403,,,,,,,,,,,,,, -99600,2.5105255,1.3278638,,,,,,,,,,,,,, -99700,2.5243363,1.4230794,,,,,,,,,,,,,, -99800,2.5618231,1.3079545,,,,,,,,,,,,,, -99900,2.535599,1.3109851,,,,,,,,,,,,,, -100000,2.3319774,1.3488092,,,,,,,,,,,,,, -100100,2.6111186,1.3573344,,,,,,,,,,,,,, -100200,2.4350858,1.2546039,,,,,,,,,,,,,, -100300,2.4468994,1.3217622,,,,,,,,,,,,,, -100400,2.484838,1.2834502,,,,,,,,,,,,,, -100500,2.414493,1.3429761,,,,,,,,,,,,,, -100600,2.4507718,1.2833061,,,,,,,,,,,,,, -100676,,,0.7801538705825806,0.8223650455474854,0.689579963684082,1.26771342754364,50000.0,0.5601000189781189,1.9989219903945925,10000.0,34207.87293791771,35430.207654953,34207.87293791771,1215.964411497116,2.652205467224121,0.0 -100700,2.2489924,1.3041648,,,,,,,,,,,,,, -100800,2.4778059,1.3562989,,,,,,,,,,,,,, -100900,2.601712,1.4003334,,,,,,,,,,,,,, -101000,2.417781,1.354106,,,,,,,,,,,,,, -101100,2.575976,1.4275224,,,,,,,,,,,,,, -101200,2.4231734,1.2286944,,,,,,,,,,,,,, -101300,2.5352447,1.34954,,,,,,,,,,,,,, -101400,2.7872694,1.3864826,,,,,,,,,,,,,, -101500,2.4771378,1.2711706,,,,,,,,,,,,,, -101600,3.0918841,1.290113,,,,,,,,,,,,,, -101700,2.4864883,1.4398271,,,,,,,,,,,,,, -101800,2.7935972,1.3056486,,,,,,,,,,,,,, -101900,2.650485,1.4383081,,,,,,,,,,,,,, -102000,2.392114,1.4001585,,,,,,,,,,,,,, -102100,2.8012373,1.2622094,,,,,,,,,,,,,, -102180,,,0.78226637840271,0.813494086265564,0.6924200057983398,1.249699354171753,50000.0,0.5678000450134277,1.9510797262191768,10000.0,34717.97048306465,35958.062203884125,34717.97048306465,1233.6210358142853,2.698038578033448,0.0 -102200,2.4695017,1.4175496,,,,,,,,,,,,,, -102300,2.47822,1.371909,,,,,,,,,,,,,, -102400,2.3988464,1.1609868,,,,,,,,,,,,,, -102500,2.662475,1.2878064,,,,,,,,,,,,,, -102600,2.493016,1.2110231,,,,,,,,,,,,,, -102700,2.590145,1.1654636,,,,,,,,,,,,,, -102800,2.4719088,1.4109411,,,,,,,,,,,,,, -102900,2.782105,1.4213531,,,,,,,,,,,,,, -103000,2.4609053,1.2728157,,,,,,,,,,,,,, -103100,2.4578388,1.3185807,,,,,,,,,,,,,, -103200,2.499239,1.2839111,,,,,,,,,,,,,, -103300,2.879365,1.2554864,,,,,,,,,,,,,, -103400,2.393172,1.1832204,,,,,,,,,,,,,, -103500,2.6669748,1.4560099,,,,,,,,,,,,,, -103600,2.6071289,1.2889482,,,,,,,,,,,,,, -103683,,,0.7759087681770325,0.8256220817565918,0.6915599703788757,1.2632023096084597,50000.0,0.5628000497817993,1.992747783660889,10000.0,35228.02629613876,36485.87252473831,35228.02629613876,1251.2748274803162,2.7438771724700928,0.0 -103700,2.439481,1.3526738,,,,,,,,,,,,,, -103800,2.4482112,1.2009678,,,,,,,,,,,,,, -103900,2.7223532,1.3578715,,,,,,,,,,,,,, -104000,2.5914977,1.347698,,,,,,,,,,,,,, -104100,2.452539,1.3254936,,,,,,,,,,,,,, -104200,2.515255,1.2464049,,,,,,,,,,,,,, -104300,2.7629323,1.2467487,,,,,,,,,,,,,, -104400,2.7182872,1.3329781,,,,,,,,,,,,,, -104500,2.4756138,1.3518503,,,,,,,,,,,,,, -104600,2.384915,1.2720194,,,,,,,,,,,,,, -104700,2.801616,1.3514794,,,,,,,,,,,,,, -104800,2.712848,1.2905519,,,,,,,,,,,,,, -104900,2.459304,1.2716808,,,,,,,,,,,,,, -105000,2.6883302,1.3102168,,,,,,,,,,,,,, -105100,2.4588072,1.3612094,,,,,,,,,,,,,, -105186,,,0.7785993218421936,0.8219675421714783,0.6947000026702881,1.2501887083053589,50000.0,0.5644000172615051,1.9662601947784424,10000.0,35737.980461120605,37013.7807905674,35737.980461120605,1269.1238374710083,2.793518781661988,0.0 -105200,2.6136894,1.3172715,,,,,,,,,,,,,, -105300,2.584012,1.4325275,,,,,,,,,,,,,, -105400,2.5636928,1.2719599,,,,,,,,,,,,,, -105500,2.4371295,1.195278,,,,,,,,,,,,,, -105600,2.5690513,1.3076128,,,,,,,,,,,,,, -105700,2.5892637,1.2939007,,,,,,,,,,,,,, -105800,2.5020602,1.3011136,,,,,,,,,,,,,, -105900,2.678258,1.3216329,,,,,,,,,,,,,, -106000,2.9371133,1.3034016,,,,,,,,,,,,,, -106100,2.43809,1.2545434,,,,,,,,,,,,,, -106200,2.6612277,1.2943066,,,,,,,,,,,,,, -106300,2.5866842,1.4220955,,,,,,,,,,,,,, -106400,2.825068,1.3136263,,,,,,,,,,,,,, -106500,2.6842387,1.2587119,,,,,,,,,,,,,, -106600,2.4722703,1.3182926,,,,,,,,,,,,,, -106690,,,0.8191565275192261,0.6560265421867371,0.6922399997711182,1.2624542713165283,50000.0,0.5685000419616699,2.02057147026062,10000.0,36248.21604681015,37542.218354702,36248.21604681015,1287.228770017624,2.8367748260498047,0.0 -106700,2.699781,1.3289218,,,,,,,,,,,,,, -106800,2.7427278,1.3150237,,,,,,,,,,,,,, -106900,2.5041459,1.3007748,,,,,,,,,,,,,, -107000,2.5440784,1.3283135,,,,,,,,,,,,,, -107100,2.7158618,1.2888837,,,,,,,,,,,,,, -107200,2.8024921,1.415595,,,,,,,,,,,,,, -107300,2.7497075,1.2566013,,,,,,,,,,,,,, -107400,2.6328957,1.1989245,,,,,,,,,,,,,, -107500,2.6285295,1.2672588,,,,,,,,,,,,,, -107600,2.6000698,1.3043077,,,,,,,,,,,,,, -107700,2.7005742,1.3002318,,,,,,,,,,,,,, -107800,2.6208584,1.35075,,,,,,,,,,,,,, -107900,2.6337454,1.1959987,,,,,,,,,,,,,, -108000,2.826217,1.4178321,,,,,,,,,,,,,, -108100,2.5547388,1.242373,,,,,,,,,,,,,, -108194,,,0.7940050959587097,0.7586848735809326,0.6922000050544739,1.26578688621521,50000.0,0.556600034236908,2.0142979621887207,10000.0,36758.368527174,38070.13637447357,36758.368527174,1304.8902144432068,2.8855910301208496,0.0 -108200,2.7380714,1.2438375,,,,,,,,,,,,,, -108300,2.591788,1.36239,,,,,,,,,,,,,, -108400,2.4918637,1.3208379,,,,,,,,,,,,,, -108500,2.7214725,1.2479217,,,,,,,,,,,,,, -108600,2.5348065,1.2028103,,,,,,,,,,,,,, -108700,2.5934143,1.2785141,,,,,,,,,,,,,, -108800,2.6334963,1.3323873,,,,,,,,,,,,,, -108900,2.727911,1.3023696,,,,,,,,,,,,,, -109000,2.7960696,1.2614067,,,,,,,,,,,,,, -109100,2.604654,1.3040007,,,,,,,,,,,,,, -109200,2.6154628,1.2905855,,,,,,,,,,,,,, -109300,2.7135284,1.2674489,,,,,,,,,,,,,, -109400,2.8930118,1.3061954,,,,,,,,,,,,,, -109500,2.6305676,1.2775865,,,,,,,,,,,,,, -109600,2.380522,1.2193822,,,,,,,,,,,,,, -109698,,,0.7897400856018066,0.7790143489837646,0.6975199580192566,1.232077956199646,50000.0,0.5653000473976135,1.950202465057373,10000.0,37268.60179138184,38598.30977582932,37268.60179138184,1322.726529121399,2.933062791824341,0.0 -109700,2.724482,1.3327144,,,,,,,,,,,,,, -109800,2.7679212,1.2797923,,,,,,,,,,,,,, -109900,2.7296925,1.246259,,,,,,,,,,,,,, -110000,2.7605155,1.30724,,,,,,,,,,,,,, -110100,2.6032636,1.228565,,,,,,,,,,,,,, -110200,2.74765,1.226782,,,,,,,,,,,,,, -110300,2.7580204,1.2539345,,,,,,,,,,,,,, -110400,2.7686412,1.3474836,,,,,,,,,,,,,, -110500,2.882885,1.2751784,,,,,,,,,,,,,, -110600,2.7596474,1.3394752,,,,,,,,,,,,,, -110700,2.887542,1.3087747,,,,,,,,,,,,,, -110800,2.8836339,1.3079422,,,,,,,,,,,,,, -110900,2.7338548,1.3038104,,,,,,,,,,,,,, -111000,2.5618849,1.211454,,,,,,,,,,,,,, -111100,2.6982079,1.2922168,,,,,,,,,,,,,, -111200,3.3191795,1.4012505,,,,,,,,,,,,,, -111201,,,0.7818080186843872,0.8038931488990784,0.6892600059509277,1.272587776184082,50000.0,0.5591000318527222,2.025722026824951,10000.0,37778.79350161552,39126.50262880325,37778.79350161552,1340.6236248016355,2.9807047843933105,0.0 -111300,2.690197,1.2385015,,,,,,,,,,,,,, -111400,2.6215444,1.1386852,,,,,,,,,,,,,, -111500,2.40464,1.1693096,,,,,,,,,,,,,, -111600,2.6712072,1.2936864,,,,,,,,,,,,,, -111700,2.6267374,1.1896032,,,,,,,,,,,,,, -111800,2.642698,1.241968,,,,,,,,,,,,,, -111900,2.7603056,1.2697629,,,,,,,,,,,,,, -112000,2.8513122,1.3428923,,,,,,,,,,,,,, -112100,3.078964,1.3306727,,,,,,,,,,,,,, -112200,2.5204008,1.1163435,,,,,,,,,,,,,, -112300,2.846894,1.237432,,,,,,,,,,,,,, -112400,2.7547932,1.3828557,,,,,,,,,,,,,, -112500,2.8146417,1.2802007,,,,,,,,,,,,,, -112600,2.945411,1.2544445,,,,,,,,,,,,,, -112700,2.5801601,1.2141273,,,,,,,,,,,,,, -112705,,,0.7928690910339355,0.7587718367576599,0.699400007724762,1.2161651849746704,50000.0,0.5699000358581543,1.9524030685424805,10000.0,38288.8953294754,39654.957559108734,38288.8953294754,1358.8687388896942,3.033480167388916,0.0 -112800,2.6566272,1.2857653,,,,,,,,,,,,,, -112900,2.9261205,1.334699,,,,,,,,,,,,,, -113000,2.8966901,1.274214,,,,,,,,,,,,,, -113100,2.9445686,1.3014656,,,,,,,,,,,,,, -113200,2.8194673,1.1898029,,,,,,,,,,,,,, -113300,2.9394646,1.1895479,,,,,,,,,,,,,, -113400,2.8203952,1.2813504,,,,,,,,,,,,,, -113500,2.9775045,1.3801628,,,,,,,,,,,,,, -113600,2.6949077,1.1227667,,,,,,,,,,,,,, -113700,2.6784809,1.269233,,,,,,,,,,,,,, -113800,2.7560763,1.185316,,,,,,,,,,,,,, -113900,2.8546793,1.2388594,,,,,,,,,,,,,, -114000,2.8900752,1.1642765,,,,,,,,,,,,,, -114100,2.7937274,1.2040143,,,,,,,,,,,,,, -114200,2.9052846,1.3904165,,,,,,,,,,,,,, -114209,,,0.7909358739852905,0.765129566192627,0.6998800039291382,1.2264665365219116,50000.0,0.570900022983551,1.9734641313552856,10000.0,38798.96592998505,40183.36325955391,38798.96592998505,1377.1002910137177,3.080017566680908,0.0 -114300,2.9079175,1.181371,,,,,,,,,,,,,, -114400,2.6606553,1.232332,,,,,,,,,,,,,, -114500,2.6440423,1.2193738,,,,,,,,,,,,,, -114600,2.8794324,1.2692276,,,,,,,,,,,,,, -114700,2.9279542,1.2783694,,,,,,,,,,,,,, -114800,3.0560486,1.3189757,,,,,,,,,,,,,, -114900,2.7835383,1.1802261,,,,,,,,,,,,,, -115000,2.8914638,1.2606497,,,,,,,,,,,,,, -115100,2.741499,1.1327933,,,,,,,,,,,,,, -115200,2.8375854,1.2132354,,,,,,,,,,,,,, -115300,3.1212432,1.2526252,,,,,,,,,,,,,, -115400,2.8036926,1.2044909,,,,,,,,,,,,,, -115500,2.7608247,1.2234294,,,,,,,,,,,,,, -115600,2.9408617,1.2025857,,,,,,,,,,,,,, -115700,2.9776218,1.2979656,,,,,,,,,,,,,, -115713,,,0.8435705900192261,0.576335072517395,0.7015399932861328,1.2031558752059937,50000.0,0.5711000561714172,1.942007064819336,10000.0,39309.17578649521,40711.58852005005,39309.17578649521,1395.0114710330963,3.126986265182495,0.0 -115800,2.6517959,1.1852689,,,,,,,,,,,,,, -115900,3.110615,1.2048117,,,,,,,,,,,,,, -116000,2.8656557,1.2377663,,,,,,,,,,,,,, -116100,3.4509044,1.3046528,,,,,,,,,,,,,, -116200,2.825519,1.2090424,,,,,,,,,,,,,, -116300,2.8998678,1.3175957,,,,,,,,,,,,,, -116400,2.6054842,1.1343594,,,,,,,,,,,,,, -116500,3.0216322,1.3746619,,,,,,,,,,,,,, -116600,2.7708747,1.322392,,,,,,,,,,,,,, -116700,2.6783013,1.2288423,,,,,,,,,,,,,, -116800,2.8323731,1.2188003,,,,,,,,,,,,,, -116900,2.903129,1.2360995,,,,,,,,,,,,,, -117000,3.0825431,1.2486627,,,,,,,,,,,,,, -117100,2.7877426,1.277652,,,,,,,,,,,,,, -117200,3.186711,1.2639993,,,,,,,,,,,,,, -117217,,,0.8229233026504517,0.6437040567398071,0.7073799967765808,1.194665551185608,50000.0,0.5755000114440918,1.9088903665542605,10000.0,39819.32137942314,41239.41584944725,39819.32137942314,1412.585030078888,3.1796674728393555,0.0 -117300,2.9128535,1.2695296,,,,,,,,,,,,,, -117400,2.6993034,1.2158613,,,,,,,,,,,,,, -117500,2.8100574,1.2153157,,,,,,,,,,,,,, -117600,3.1227174,1.2572395,,,,,,,,,,,,,, -117700,3.2771575,1.1639044,,,,,,,,,,,,,, -117800,2.78555,1.242057,,,,,,,,,,,,,, -117900,2.8207588,1.2073567,,,,,,,,,,,,,, -118000,3.0052335,1.1771214,,,,,,,,,,,,,, -118100,2.9326763,1.1778626,,,,,,,,,,,,,, -118200,2.9565637,1.2397059,,,,,,,,,,,,,, -118300,2.6118634,1.1015486,,,,,,,,,,,,,, -118400,2.9054668,1.199503,,,,,,,,,,,,,, -118500,2.8884108,1.3271692,,,,,,,,,,,,,, -118600,3.016029,1.1728555,,,,,,,,,,,,,, -118700,2.7309725,1.1235899,,,,,,,,,,,,,, -118720,,,0.8165457248687744,0.6733990907669067,0.7057200074195862,1.2064604759216309,50000.0,0.5786000490188599,1.9398770332336424,10000.0,40329.29307627678,41767.53636336327,40329.29307627678,1430.6261870861051,3.2323155403137207,0.0 -118800,3.2445216,1.286664,,,,,,,,,,,,,, -118900,2.9405644,1.1872814,,,,,,,,,,,,,, -119000,3.115278,1.2490981,,,,,,,,,,,,,, -119100,2.8837829,1.2382814,,,,,,,,,,,,,, -119200,2.9353414,1.2090721,,,,,,,,,,,,,, -119300,2.9747248,1.1761789,,,,,,,,,,,,,, -119400,3.2155368,1.2422397,,,,,,,,,,,,,, -119500,2.7900975,1.134077,,,,,,,,,,,,,, -119600,2.8745751,1.2591636,,,,,,,,,,,,,, -119700,3.132216,1.2405257,,,,,,,,,,,,,, -119800,2.95338,1.1624043,,,,,,,,,,,,,, -119900,2.929473,1.208438,,,,,,,,,,,,,, -120000,2.6488414,1.1180804,,,,,,,,,,,,,, -120100,3.0903363,1.1737506,,,,,,,,,,,,,, -120200,3.064873,1.2205176,,,,,,,,,,,,,, -120223,,,0.8181002736091614,0.6516917943954468,0.7089999914169312,1.1775261163711548,50000.0,0.5800000429153442,1.913909673690796,10000.0,40839.21898150444,42296.18041563034,40839.21898150444,1449.2406451702118,3.280336380004883,0.0 -120300,2.994198,1.2252483,,,,,,,,,,,,,, -120400,2.8967028,1.2984333,,,,,,,,,,,,,, -120500,3.0063286,1.1582767,,,,,,,,,,,,,, -120600,2.929241,1.0835621,,,,,,,,,,,,,, -120700,2.8360617,1.2794309,,,,,,,,,,,,,, -120800,2.8432276,1.1358464,,,,,,,,,,,,,, -120900,2.990575,1.1857841,,,,,,,,,,,,,, -121000,2.8198085,1.1527408,,,,,,,,,,,,,, -121100,2.837075,1.132851,,,,,,,,,,,,,, -121200,3.2151217,1.2271014,,,,,,,,,,,,,, -121300,3.0720923,1.211847,,,,,,,,,,,,,, -121400,3.0954273,1.1526948,,,,,,,,,,,,,, -121500,2.9071186,1.0654514,,,,,,,,,,,,,, -121600,3.0299032,1.2220894,,,,,,,,,,,,,, -121700,2.9144888,1.1649623,,,,,,,,,,,,,, -121726,,,0.8082947731018066,0.6896772384643555,0.7066999673843384,1.1927813291549685,50000.0,0.5759000182151794,1.9429177045822144,10000.0,41349.17365002632,42824.30308508873,41349.17365002632,1467.3116641044617,3.323556900024414,0.0 -121800,2.7907348,1.1066607,,,,,,,,,,,,,, -121900,2.8878236,1.1743766,,,,,,,,,,,,,, -122000,3.319104,1.2157748,,,,,,,,,,,,,, -122100,3.0602489,1.119182,,,,,,,,,,,,,, -122200,3.1367068,1.2420247,,,,,,,,,,,,,, -122300,3.1802168,1.1328506,,,,,,,,,,,,,, -122400,2.955199,1.1881607,,,,,,,,,,,,,, -122500,3.4016354,1.1225345,,,,,,,,,,,,,, -122600,2.91931,1.1859666,,,,,,,,,,,,,, -122700,2.8904047,1.0999676,,,,,,,,,,,,,, -122800,3.0285907,1.1482468,,,,,,,,,,,,,, -122900,3.219238,1.191299,,,,,,,,,,,,,, -123000,2.79036,1.1024008,,,,,,,,,,,,,, -123100,2.9151118,1.1895481,,,,,,,,,,,,,, -123200,3.0083623,1.1371496,,,,,,,,,,,,,, -123230,,,0.8172034025192261,0.6480165719985962,0.7120800018310547,1.1730077266693115,50000.0,0.5870000123977661,1.930961012840271,10000.0,41859.14821410179,43352.32322573662,41859.14821410179,1485.2464039325714,3.3781511783599854,0.0 -123300,3.0493944,1.2639844,,,,,,,,,,,,,, -123400,3.2485225,1.1728647,,,,,,,,,,,,,, -123500,2.8621376,1.1903712,,,,,,,,,,,,,, -123600,2.8815227,1.2272493,,,,,,,,,,,,,, -123700,3.2094955,1.311947,,,,,,,,,,,,,, -123800,3.4181588,1.1100807,,,,,,,,,,,,,, -123900,3.1170318,1.229209,,,,,,,,,,,,,, -124000,2.9128137,1.1251876,,,,,,,,,,,,,, -124100,3.1329157,1.1580299,,,,,,,,,,,,,, -124200,3.1555495,1.1702994,,,,,,,,,,,,,, -124300,2.9444547,1.1636097,,,,,,,,,,,,,, -124400,3.1259599,1.1383381,,,,,,,,,,,,,, -124500,2.9075165,1.0438178,,,,,,,,,,,,,, -124600,2.9882722,1.1175731,,,,,,,,,,,,,, -124700,3.0494735,1.1279353,,,,,,,,,,,,,, -124734,,,0.8444275856018066,0.5557413101196289,0.7135199904441833,1.162333846092224,50000.0,0.5822000503540039,1.8956857919692995,10000.0,42369.25692343712,43880.17747545242,42369.25692343712,1502.886422872543,3.428584575653076,0.0 -124800,3.147045,1.062606,,,,,,,,,,,,,, -124900,3.31509,1.2202693,,,,,,,,,,,,,, -125000,3.0137148,1.1420425,,,,,,,,,,,,,, -125100,2.9897768,1.1206788,,,,,,,,,,,,,, -125200,3.1306212,1.1258332,,,,,,,,,,,,,, -125300,3.1680827,1.1672604,,,,,,,,,,,,,, -125400,3.219991,1.176992,,,,,,,,,,,,,, -125500,3.232974,1.1676592,,,,,,,,,,,,,, -125600,3.4539468,1.1642798,,,,,,,,,,,,,, -125700,3.0684936,1.1030272,,,,,,,,,,,,,, -125800,3.423631,1.1481433,,,,,,,,,,,,,, -125900,3.218137,1.1146452,,,,,,,,,,,,,, -126000,3.3034463,1.1671249,,,,,,,,,,,,,, -126100,3.2062109,1.08992,,,,,,,,,,,,,, -126200,3.240592,1.0955093,,,,,,,,,,,,,, -126238,,,0.8463608026504517,0.5420656800270081,0.7162599563598633,1.1581156253814695,50000.0,0.5902000069618225,1.887725830078125,10000.0,42879.32626509666,44408.17130422592,42879.32626509666,1520.7035658359528,3.4789483547210693,0.0 -126300,3.0331135,1.0719652,,,,,,,,,,,,,, -126400,3.0956848,1.0925056,,,,,,,,,,,,,, -126500,2.9983032,1.0358891,,,,,,,,,,,,,, -126600,3.2864215,1.1391623,,,,,,,,,,,,,, -126700,3.0036967,1.1873266,,,,,,,,,,,,,, -126800,3.5599573,1.2058381,,,,,,,,,,,,,, -126900,3.006422,1.0798666,,,,,,,,,,,,,, -127000,3.3364882,1.1029809,,,,,,,,,,,,,, -127100,3.0430605,1.1105665,,,,,,,,,,,,,, -127200,3.1812696,1.0213525,,,,,,,,,,,,,, -127300,3.0482082,1.1571944,,,,,,,,,,,,,, -127400,3.0911036,1.096846,,,,,,,,,,,,,, -127500,3.286533,1.0897157,,,,,,,,,,,,,, -127600,3.2364995,1.0719497,,,,,,,,,,,,,, -127700,3.1146297,1.0193105,,,,,,,,,,,,,, -127741,,,0.8425143361091614,0.5580957531929016,0.7177199721336365,1.1547913551330566,50000.0,0.5898000001907349,1.8950108289718628,10000.0,43389.325770139694,44935.92632341385,43389.325770139694,1538.350778579712,3.5324387550354004,0.0 -127800,3.3545058,1.1926236,,,,,,,,,,,,,, -127900,2.969495,1.0214046,,,,,,,,,,,,,, -128000,2.8992555,1.1647382,,,,,,,,,,,,,, -128100,3.0463948,1.104565,,,,,,,,,,,,,, -128200,3.3793495,1.1298081,,,,,,,,,,,,,, -128300,3.4997096,1.1306549,,,,,,,,,,,,,, -128400,3.0923455,1.0185679,,,,,,,,,,,,,, -128500,3.2482386,1.1600523,,,,,,,,,,,,,, -128600,3.029673,1.0610707,,,,,,,,,,,,,, -128700,3.5334003,1.0687114,,,,,,,,,,,,,, -128800,3.3665276,1.2235479,,,,,,,,,,,,,, -128900,3.338936,1.1451665,,,,,,,,,,,,,, -129000,3.2488062,1.1268797,,,,,,,,,,,,,, -129100,3.6306832,1.1654382,,,,,,,,,,,,,, -129200,3.3061981,1.0547496,,,,,,,,,,,,,, -129246,,,0.8390465378761292,0.5831946134567261,0.7163999676704407,1.167493462562561,50000.0,0.5910000205039978,1.917563915252685,10000.0,43899.508662223816,45464.26081061363,43899.508662223816,1556.3912541866302,3.587545394897461,0.0 -129300,3.7286072,1.1821582,,,,,,,,,,,,,, -129400,3.479579,1.0981358,,,,,,,,,,,,,, -129500,3.2785077,1.1789368,,,,,,,,,,,,,, -129600,3.1367223,1.1046655,,,,,,,,,,,,,, -129700,3.1435976,1.0325782,,,,,,,,,,,,,, -129800,3.4396217,1.1364028,,,,,,,,,,,,,, -129900,3.6455266,1.0243202,,,,,,,,,,,,,, -130000,3.1774719,1.1251388,,,,,,,,,,,,,, -130100,3.6198165,1.0768732,,,,,,,,,,,,,, -130200,3.417147,1.0231124,,,,,,,,,,,,,, -130300,3.4230404,1.122234,,,,,,,,,,,,,, -130400,3.5552227,1.0843246,,,,,,,,,,,,,, -130500,3.476044,1.1733563,,,,,,,,,,,,,, -130600,3.4081824,1.1928736,,,,,,,,,,,,,, -130700,3.263362,1.0188335,,,,,,,,,,,,,, -130750,,,0.8318319320678711,0.5993456840515137,0.715939998626709,1.1581324338912964,50000.0,0.5909000039100647,1.8844317197799685,10000.0,44409.489119291306,45992.3717007637,44409.489119291306,1574.4115755558014,3.64362382888794,0.0 -130800,3.166349,1.0530926,,,,,,,,,,,,,, -130900,3.2041392,1.1016986,,,,,,,,,,,,,, -131000,3.210295,1.0331718,,,,,,,,,,,,,, -131100,3.4292467,1.1349602,,,,,,,,,,,,,, -131200,3.072051,1.0413948,,,,,,,,,,,,,, -131300,3.104579,1.0528345,,,,,,,,,,,,,, -131400,3.2621217,1.1141487,,,,,,,,,,,,,, -131500,3.383867,1.0025663,,,,,,,,,,,,,, -131600,3.0890746,1.0926073,,,,,,,,,,,,,, -131700,3.2765648,1.0228208,,,,,,,,,,,,,, -131800,3.0743365,0.96347123,,,,,,,,,,,,,, -131900,3.3398242,1.066818,,,,,,,,,,,,,, -132000,3.3448966,1.1009207,,,,,,,,,,,,,, -132100,3.41178,1.1441809,,,,,,,,,,,,,, -132200,3.1817431,1.0732665,,,,,,,,,,,,,, -132254,,,0.8413584232330322,0.5591051578521729,0.7198799848556519,1.1507633924484253,50000.0,0.5952000021934509,1.8953807353973389,10000.0,44919.46421599388,46520.07335758209,44919.46421599388,1592.0317244529724,3.695392608642578,0.0 -132300,3.2922077,1.1414083,,,,,,,,,,,,,, -132400,3.375648,1.0426921,,,,,,,,,,,,,, -132500,3.4579651,1.1208401,,,,,,,,,,,,,, -132600,3.2860332,1.0173515,,,,,,,,,,,,,, -132700,3.2952273,1.0365696,,,,,,,,,,,,,, -132800,3.4977102,1.0852731,,,,,,,,,,,,,, -132900,3.3235276,1.0950638,,,,,,,,,,,,,, -133000,3.371402,1.1280262,,,,,,,,,,,,,, -133100,3.5878026,1.046695,,,,,,,,,,,,,, -133200,3.2784781,1.1382484,,,,,,,,,,,,,, -133300,3.3858817,0.99754095,,,,,,,,,,,,,, -133400,3.2342339,1.0616249,,,,,,,,,,,,,, -133500,3.4324381,1.0433912,,,,,,,,,,,,,, -133600,3.3498812,1.0079702,,,,,,,,,,,,,, -133700,3.4496813,1.0113552,,,,,,,,,,,,,, -133758,,,0.8465999364852905,0.5491646528244019,0.7141599655151367,1.1705526113510132,50000.0,0.5903000235557556,1.915621519088745,10000.0,45429.54328536987,47048.18325304985,45429.54328536987,1609.954525232315,3.7486047744750977,0.0 -133800,3.5811036,1.1489302,,,,,,,,,,,,,, -133900,3.674712,1.0515752,,,,,,,,,,,,,, -134000,3.5165725,1.1252863,,,,,,,,,,,,,, -134100,3.5549128,1.0988157,,,,,,,,,,,,,, -134200,3.6616662,1.1004992,,,,,,,,,,,,,, -134300,3.2238338,1.0614948,,,,,,,,,,,,,, -134400,3.293256,0.9810147,,,,,,,,,,,,,, -134500,3.401154,1.0160754,,,,,,,,,,,,,, -134600,3.3857908,0.9904347,,,,,,,,,,,,,, -134700,3.6191103,1.100167,,,,,,,,,,,,,, -134800,3.45206,1.0504575,,,,,,,,,,,,,, -134900,3.5029151,1.1087897,,,,,,,,,,,,,, -135000,3.5833888,0.981004,,,,,,,,,,,,,, -135100,3.737561,1.097184,,,,,,,,,,,,,, -135200,3.6262503,1.0925736,,,,,,,,,,,,,, -135262,,,0.8697385191917419,0.4575299024581909,0.7210599780082703,1.1391361951828003,50000.0,0.5943000316619873,1.874566674232483,10000.0,45939.4492623806,47576.754868507385,45939.4492623806,1628.5150740146637,3.798022508621216,0.0 -135300,3.664606,1.0679151,,,,,,,,,,,,,, -135400,3.5313892,1.0607779,,,,,,,,,,,,,, -135500,3.3199277,1.0058439,,,,,,,,,,,,,, -135600,3.835577,1.1328223,,,,,,,,,,,,,, -135700,3.3831637,1.0230008,,,,,,,,,,,,,, -135800,3.6647935,1.0345314,,,,,,,,,,,,,, -135900,3.153609,0.98478353,,,,,,,,,,,,,, -136000,3.4040232,1.0323265,,,,,,,,,,,,,, -136100,4.0201735,1.0349009,,,,,,,,,,,,,, -136200,3.3912513,0.954795,,,,,,,,,,,,,, -136300,3.681294,1.024572,,,,,,,,,,,,,, -136400,3.8460217,1.0263181,,,,,,,,,,,,,, -136500,3.825619,1.0962299,,,,,,,,,,,,,, -136600,3.8869019,1.0367334,,,,,,,,,,,,,, -136700,3.3594193,0.99437237,,,,,,,,,,,,,, -136765,,,0.8618462681770325,0.4852306246757507,0.7225599884986877,1.1408337354660034,50000.0,0.6003000140190125,1.883261442184448,10000.0,46449.43899035454,48104.78061199188,46449.43899035454,1646.445203781128,3.8481035232543945,0.0 -136800,3.3735027,0.97221977,,,,,,,,,,,,,, -136900,3.068377,0.87532026,,,,,,,,,,,,,, -137000,3.4457054,1.1285638,,,,,,,,,,,,,, -137100,3.5106823,1.0322447,,,,,,,,,,,,,, -137200,3.8025112,1.0075091,,,,,,,,,,,,,, -137300,3.3035603,0.9241159,,,,,,,,,,,,,, -137400,3.4111261,1.0591612,,,,,,,,,,,,,, -137500,3.4203227,0.89352256,,,,,,,,,,,,,, -137600,3.2688003,1.0246022,,,,,,,,,,,,,, -137700,3.7187681,1.0211902,,,,,,,,,,,,,, -137800,3.6220448,1.0394244,,,,,,,,,,,,,, -137900,3.3903759,1.0256635,,,,,,,,,,,,,, -138000,3.320879,0.9408304,,,,,,,,,,,,,, -138100,3.6952324,1.0517842,,,,,,,,,,,,,, -138200,3.5261395,0.99780136,,,,,,,,,,,,,, -138269,,,0.867586076259613,0.4603193998336792,0.731939971446991,1.108080506324768,50000.0,0.6046000123023987,1.865692138671875,10000.0,46959.6368894577,48632.8400952816,46959.6368894577,1664.1922266483307,3.9070560932159424,0.0 -138300,3.6715415,1.0643821,,,,,,,,,,,,,, -138400,4.016096,0.9816513,,,,,,,,,,,,,, -138500,3.395166,0.96890885,,,,,,,,,,,,,, -138600,3.663291,0.97361237,,,,,,,,,,,,,, -138700,3.7610946,0.99626595,,,,,,,,,,,,,, -138800,3.717425,1.027615,,,,,,,,,,,,,, -138900,3.5738454,1.00517,,,,,,,,,,,,,, -139000,3.7742863,1.0856756,,,,,,,,,,,,,, -139100,3.664257,0.9851903,,,,,,,,,,,,,, -139200,3.475156,0.9503119,,,,,,,,,,,,,, -139300,3.500091,1.0013945,,,,,,,,,,,,,, -139400,3.5394318,0.96914285,,,,,,,,,,,,,, -139500,3.6231856,0.964018,,,,,,,,,,,,,, -139600,3.7659936,1.0381892,,,,,,,,,,,,,, -139700,3.8603635,0.97053957,,,,,,,,,,,,,, -139773,,,0.8621252775192261,0.4767018556594848,0.7252799868583679,1.134071946144104,50000.0,0.5985000133514404,1.8925247192382808,10000.0,47469.786170721054,49160.73694300652,47469.786170721054,1681.8307964801788,3.9598209857940674,0.0 -139800,3.6543765,1.0247211,,,,,,,,,,,,,, -139900,3.5974176,0.9991468,,,,,,,,,,,,,, -140000,3.824343,1.0245569,,,,,,,,,,,,,, -140100,3.890325,0.9432076,,,,,,,,,,,,,, -140200,3.576235,0.970964,,,,,,,,,,,,,, -140300,3.7308307,0.9732309,,,,,,,,,,,,,, -140400,3.7189932,0.98731613,,,,,,,,,,,,,, -140500,3.6627302,0.9964497,,,,,,,,,,,,,, -140600,3.3245716,0.9924284,,,,,,,,,,,,,, -140700,3.4069219,0.95499766,,,,,,,,,,,,,, -140800,3.7155495,0.9593866,,,,,,,,,,,,,, -140900,3.7071505,1.0107008,,,,,,,,,,,,,, -141000,3.8482852,0.99221367,,,,,,,,,,,,,, -141100,3.8808045,0.9737568,,,,,,,,,,,,,, -141200,3.6055887,0.9629847,,,,,,,,,,,,,, -141277,,,0.8630420565605164,0.4747589826583862,0.726419985294342,1.13043475151062,50000.0,0.6062000393867493,1.8601926565170288,10000.0,47979.92598223686,49688.70345711708,47979.92598223686,1699.5504500865936,4.012192010879517,0.0 -141300,3.7428637,0.95073974,,,,,,,,,,,,,, -141400,3.5976255,0.948257,,,,,,,,,,,,,, -141500,3.5780945,0.94967115,,,,,,,,,,,,,, -141600,3.9194214,0.90882957,,,,,,,,,,,,,, -141700,3.8848627,1.0277541,,,,,,,,,,,,,, -141800,3.6831975,1.000812,,,,,,,,,,,,,, -141900,3.390679,0.9017362,,,,,,,,,,,,,, -142000,3.7578897,0.8874672,,,,,,,,,,,,,, -142100,3.5061767,0.9108491,,,,,,,,,,,,,, -142200,3.4142225,0.9260143,,,,,,,,,,,,,, -142300,3.8000686,1.0168954,,,,,,,,,,,,,, -142400,3.5186777,0.93327856,,,,,,,,,,,,,, -142500,3.8193586,0.98710686,,,,,,,,,,,,,, -142600,3.5666003,0.8884373,,,,,,,,,,,,,, -142700,3.7017539,0.9500049,,,,,,,,,,,,,, -142781,,,0.8639389276504517,0.4681302905082702,0.7273600101470947,1.1400127410888672,50000.0,0.5963000059127808,1.9035028219223025,10000.0,48490.05924510956,50216.637149333954,48490.05924510956,1717.231991291046,4.073648452758789,0.0 -142800,3.633811,0.945179,,,,,,,,,,,,,, -142900,3.800329,1.0499947,,,,,,,,,,,,,, -143000,3.9726024,0.9776081,,,,,,,,,,,,,, -143100,3.8088198,0.9309573,,,,,,,,,,,,,, -143200,3.500832,0.83083016,,,,,,,,,,,,,, -143300,3.6628187,1.0090613,,,,,,,,,,,,,, -143400,3.8371396,0.94319373,,,,,,,,,,,,,, -143500,3.6982481,0.92851955,,,,,,,,,,,,,, -143600,4.101797,0.9675739,,,,,,,,,,,,,, -143700,3.865907,0.9445795,,,,,,,,,,,,,, -143800,3.6297307,0.9377163,,,,,,,,,,,,,, -143900,3.8571143,0.92097795,,,,,,,,,,,,,, -144000,3.6566758,0.9641007,,,,,,,,,,,,,, -144100,3.6612716,1.0005815,,,,,,,,,,,,,, -144200,3.8101504,0.9505954,,,,,,,,,,,,,, -144285,,,0.8989357352256775,0.352335661649704,0.7356399893760681,1.1080385446548462,50000.0,0.6078000068664551,1.842755913734436,10000.0,49000.12387108803,50744.83660268784,49000.12387108803,1735.258573770523,4.126613616943359,0.0 -144300,3.8004982,0.94500923,,,,,,,,,,,,,, -144400,3.718982,0.9083558,,,,,,,,,,,,,, -144500,4.1021276,0.9694247,,,,,,,,,,,,,, -144600,3.4031188,0.91110563,,,,,,,,,,,,,, -144700,3.8993375,1.0190248,,,,,,,,,,,,,, -144800,3.938311,0.92928606,,,,,,,,,,,,,, -144900,3.7306614,0.95027304,,,,,,,,,,,,,, -145000,4.029912,0.9433789,,,,,,,,,,,,,, -145100,3.6994653,0.901302,,,,,,,,,,,,,, -145200,3.7155926,0.9356544,,,,,,,,,,,,,, -145300,3.764297,0.9483532,,,,,,,,,,,,,, -145400,3.586445,0.8161033,,,,,,,,,,,,,, -145500,3.8852758,0.8918799,,,,,,,,,,,,,, -145600,3.7489338,0.83543575,,,,,,,,,,,,,, -145700,4.1202164,0.90200794,,,,,,,,,,,,,, -145788,,,0.8931760191917419,0.36942258477211,0.7378199696540833,1.0927354097366333,50000.0,0.6132000088691711,1.842835783958435,10000.0,49510.03820419312,51272.56175875664,49510.03820419312,1752.9578087329865,4.184151411056519,0.0 -145800,3.9128115,1.0203059,,,,,,,,,,,,,, -145900,3.6116207,0.8879971,,,,,,,,,,,,,, -146000,3.9785624,0.9270524,,,,,,,,,,,,,, -146100,3.5987725,0.8509158,,,,,,,,,,,,,, -146200,3.9386683,0.9083041,,,,,,,,,,,,,, -146300,4.0868397,0.88379586,,,,,,,,,,,,,, -146400,4.0181594,0.9236322,,,,,,,,,,,,,, -146500,3.8787324,0.9202452,,,,,,,,,,,,,, -146600,3.665702,0.8961467,,,,,,,,,,,,,, -146700,3.9700594,0.9053736,,,,,,,,,,,,,, -146800,4.246483,0.87293506,,,,,,,,,,,,,, -146900,3.716837,0.9609589,,,,,,,,,,,,,, -147000,3.9230003,0.95501804,,,,,,,,,,,,,, -147100,3.9688203,0.9242407,,,,,,,,,,,,,, -147200,3.9108386,0.8952426,,,,,,,,,,,,,, -147290,,,0.8905253410339355,0.3740461766719818,0.7361399531364441,1.08772873878479,50000.0,0.6048000454902649,1.8468471765518188,10000.0,50020.07937192917,51800.63540673256,50020.07937192917,1770.8793761730194,4.238028526306152,0.0 -147300,4.031343,0.8225105,,,,,,,,,,,,,, -147400,3.805918,0.8239774,,,,,,,,,,,,,, -147500,3.721566,0.8957561,,,,,,,,,,,,,, -147600,3.8363254,0.8823501,,,,,,,,,,,,,, -147700,3.832077,0.90712845,,,,,,,,,,,,,, -147800,4.3264475,0.9121554,,,,,,,,,,,,,, -147900,4.1470675,1.0014799,,,,,,,,,,,,,, -148000,4.023884,0.9111209,,,,,,,,,,,,,, -148100,3.805103,0.83299476,,,,,,,,,,,,,, -148200,3.688946,0.85579133,,,,,,,,,,,,,, -148300,3.881833,0.80324,,,,,,,,,,,,,, -148400,3.7343903,0.8405344,,,,,,,,,,,,,, -148500,4.2461114,0.9363639,,,,,,,,,,,,,, -148600,3.9549682,0.8849101,,,,,,,,,,,,,, -148700,3.9045331,0.8936128,,,,,,,,,,,,,, -148793,,,0.8932557106018066,0.3641088008880615,0.7346799969673157,1.1118369102478027,50000.0,0.6073000431060791,1.875683546066284,10000.0,50529.98758006096,52328.54508733749,50529.98758006096,1788.7686505317688,4.294397830963135,0.0 -148800,3.867504,0.83245337,,,,,,,,,,,,,, -148900,3.833201,0.83210444,,,,,,,,,,,,,, -149000,4.364492,0.93681294,,,,,,,,,,,,,, -149100,3.82754,0.89546704,,,,,,,,,,,,,, -149200,3.9519322,0.8519892,,,,,,,,,,,,,, -149300,4.1512227,0.89758855,,,,,,,,,,,,,, -149400,3.9117181,0.85032606,,,,,,,,,,,,,, -149500,4.44224,0.89142555,,,,,,,,,,,,,, -149600,3.7695496,0.8681403,,,,,,,,,,,,,, -149700,3.8659153,0.8624532,,,,,,,,,,,,,, -149800,4.000084,0.8898487,,,,,,,,,,,,,, -149900,4.211828,0.8893753,,,,,,,,,,,,,, -150000,4.0997705,0.83906984,,,,,,,,,,,,,, -150100,4.0310574,0.88889426,,,,,,,,,,,,,, -150200,3.7906592,0.7705477,,,,,,,,,,,,,, -150296,,,0.8932358026504517,0.363602340221405,0.7393400073051453,1.103121042251587,50000.0,0.6160000562667847,1.858383297920227,10000.0,51040.08410692215,52856.537358284,51040.08410692215,1806.554343700409,4.349008321762085,0.0 -150300,3.9084086,0.84891874,,,,,,,,,,,,,, -150400,3.824875,0.7968542,,,,,,,,,,,,,, -150500,3.7542365,0.89312166,,,,,,,,,,,,,, -150600,4.106807,0.9080759,,,,,,,,,,,,,, -150700,3.6620958,0.80054015,,,,,,,,,,,,,, -150800,3.6985407,0.8305339,,,,,,,,,,,,,, -150900,4.211924,0.9882556,,,,,,,,,,,,,, -151000,3.827437,0.8036897,,,,,,,,,,,,,, -151100,3.8237183,0.94685125,,,,,,,,,,,,,, -151200,4.1526246,0.8714194,,,,,,,,,,,,,, -151300,3.697098,0.8380047,,,,,,,,,,,,,, -151400,3.9378462,0.8535041,,,,,,,,,,,,,, -151500,4.06798,0.8880647,,,,,,,,,,,,,, -151600,4.0646787,0.858876,,,,,,,,,,,,,, -151700,4.0862875,0.8347682,,,,,,,,,,,,,, -151799,,,0.8980388641357422,0.3513565361499786,0.7346799969673157,1.1100929975509644,50000.0,0.6099000573158264,1.8680452108383176,10000.0,51550.05728435516,53384.782229185104,51550.05728435516,1824.7137916088104,4.405257701873779,0.0 -151800,4.188156,0.8436711,,,,,,,,,,,,,, -151900,4.389619,0.926002,,,,,,,,,,,,,, -152000,3.827487,0.9009323,,,,,,,,,,,,,, -152100,4.0923405,0.8484601,,,,,,,,,,,,,, -152200,3.950677,0.81487805,,,,,,,,,,,,,, -152300,3.9740653,0.872634,,,,,,,,,,,,,, -152400,4.0765157,0.81166303,,,,,,,,,,,,,, -152500,3.9615946,0.80684817,,,,,,,,,,,,,, -152600,4.011521,0.9101854,,,,,,,,,,,,,, -152700,3.9611528,0.773957,,,,,,,,,,,,,, -152800,3.7466226,0.81107163,,,,,,,,,,,,,, -152900,4.129103,0.85201144,,,,,,,,,,,,,, -153000,4.110419,0.84795815,,,,,,,,,,,,,, -153100,4.06363,0.8297019,,,,,,,,,,,,,, -153200,4.081508,0.7753943,,,,,,,,,,,,,, -153300,4.428372,0.8642963,,,,,,,,,,,,,, -153303,,,0.9187459945678712,0.2857046723365783,0.7361999750137329,1.1064831018447876,50000.0,0.6100000143051147,1.8766512870788568,10000.0,52060.17003774643,53912.83220410347,52060.17003774643,1842.537333250045,4.462253093719482,0.0 -153400,4.304401,0.9705068,,,,,,,,,,,,,, -153500,4.012359,0.88354456,,,,,,,,,,,,,, -153600,3.9101863,0.8450036,,,,,,,,,,,,,, -153700,3.7801208,0.7784791,,,,,,,,,,,,,, -153800,3.9506311,0.7830916,,,,,,,,,,,,,, -153900,3.698471,0.7711214,,,,,,,,,,,,,, -154000,4.29649,0.8104387,,,,,,,,,,,,,, -154100,4.087879,0.8837658,,,,,,,,,,,,,, -154200,3.935905,0.72800195,,,,,,,,,,,,,, -154300,4.380361,0.85163176,,,,,,,,,,,,,, -154400,3.9178138,0.7382307,,,,,,,,,,,,,, -154500,4.1313744,0.8906855,,,,,,,,,,,,,, -154600,3.7261703,0.7264492,,,,,,,,,,,,,, -154700,3.8138962,0.77505356,,,,,,,,,,,,,, -154800,4.265891,0.88506013,,,,,,,,,,,,,, -154806,,,0.9205396771430968,0.2786844074726105,0.7440399527549744,1.0820631980895996,50000.0,0.6166000366210938,1.84649658203125,10000.0,52570.1029856205,54440.91394495964,52570.1029856205,1860.575585603714,4.516483306884766,0.0 -154900,4.0866103,0.7267026,,,,,,,,,,,,,, -155000,4.1923347,0.84377295,,,,,,,,,,,,,, -155100,4.1467123,0.78347826,,,,,,,,,,,,,, -155200,4.465802,0.8458768,,,,,,,,,,,,,, -155300,4.2023153,0.87932813,,,,,,,,,,,,,, -155400,4.2718277,0.82831866,,,,,,,,,,,,,, -155500,4.0527816,0.74721533,,,,,,,,,,,,,, -155600,3.9829097,0.81844705,,,,,,,,,,,,,, -155700,4.193409,0.85771316,,,,,,,,,,,,,, -155800,4.4802246,0.8029051,,,,,,,,,,,,,, -155900,4.475282,0.8546789,,,,,,,,,,,,,, -156000,4.102091,0.80162627,,,,,,,,,,,,,, -156100,4.3604927,0.7546677,,,,,,,,,,,,,, -156200,4.099206,0.86115456,,,,,,,,,,,,,, -156300,4.3063073,0.8083625,,,,,,,,,,,,,, -156310,,,0.9159757494926452,0.2899870276451111,0.7417399883270264,1.092580795288086,50000.0,0.6159000396728516,1.860220432281494,10000.0,53080.2790760994,54969.0854177475,53080.2790760994,1878.458645582199,4.574311494827271,0.0 -156400,4.2420425,0.8050773,,,,,,,,,,,,,, -156500,4.419969,0.7698693,,,,,,,,,,,,,, -156600,3.940554,0.8244856,,,,,,,,,,,,,, -156700,3.870112,0.7929901,,,,,,,,,,,,,, -156800,4.4290733,0.7544625,,,,,,,,,,,,,, -156900,4.0942373,0.7487783,,,,,,,,,,,,,, -157000,4.1334023,0.7983089,,,,,,,,,,,,,, -157100,4.3017287,0.80520177,,,,,,,,,,,,,, -157200,3.9141526,0.7390211,,,,,,,,,,,,,, -157300,4.341584,0.79580605,,,,,,,,,,,,,, -157400,4.3625774,0.80604196,,,,,,,,,,,,,, -157500,4.042393,0.7297207,,,,,,,,,,,,,, -157600,4.1128263,0.7897986,,,,,,,,,,,,,, -157700,4.4695063,0.80952245,,,,,,,,,,,,,, -157800,4.221804,0.76118654,,,,,,,,,,,,,, -157813,,,0.9194834232330322,0.2723006904125213,0.7447599768638611,1.0868667364120483,50000.0,0.617400050163269,1.8521299362182613,10000.0,53590.2423658371,55496.98831796646,53590.2423658371,1896.2873928546903,4.627979040145874,0.0 -157900,4.592637,0.8745746,,,,,,,,,,,,,, -158000,3.9767814,0.7173193,,,,,,,,,,,,,, -158100,4.084746,0.7537709,,,,,,,,,,,,,, -158200,4.2822785,0.8619109,,,,,,,,,,,,,, -158300,4.298292,0.8354071,,,,,,,,,,,,,, -158400,4.0901985,0.84887314,,,,,,,,,,,,,, -158500,4.676873,0.87955403,,,,,,,,,,,,,, -158600,3.9956675,0.7359059,,,,,,,,,,,,,, -158700,4.167384,0.79297304,,,,,,,,,,,,,, -158800,4.274775,0.85578084,,,,,,,,,,,,,, -158900,4.2139025,0.8221836,,,,,,,,,,,,,, -159000,4.094312,0.76140827,,,,,,,,,,,,,, -159100,4.538099,0.77837515,,,,,,,,,,,,,, -159200,4.5530605,0.8129656,,,,,,,,,,,,,, -159300,3.886095,0.7935217,,,,,,,,,,,,,, -159316,,,0.9231106042861938,0.2618570923805237,0.746239960193634,1.0792582035064695,50000.0,0.6223000288009644,1.8350087404251096,10000.0,54100.30746936798,56025.2314991951,54100.30746936798,1914.3551132678983,4.682864904403687,0.0 -159400,4.131347,0.7886603,,,,,,,,,,,,,, -159500,4.5384436,0.8328931,,,,,,,,,,,,,, -159600,4.5228963,0.76008403,,,,,,,,,,,,,, -159700,4.315317,0.7846221,,,,,,,,,,,,,, -159800,4.371409,0.77456427,,,,,,,,,,,,,, -159900,4.2659388,0.8131582,,,,,,,,,,,,,, -160000,4.5536585,0.8547004,,,,,,,,,,,,,, -160100,4.32583,0.7575029,,,,,,,,,,,,,, -160200,4.145773,0.7244063,,,,,,,,,,,,,, -160300,4.409366,0.8773477,,,,,,,,,,,,,, -160400,4.5270624,0.7878803,,,,,,,,,,,,,, -160500,4.0544844,0.7726365,,,,,,,,,,,,,, -160600,4.0380063,0.69211495,,,,,,,,,,,,,, -160700,4.486453,0.70966494,,,,,,,,,,,,,, -160800,4.2172413,0.7444823,,,,,,,,,,,,,, -160819,,,0.9272759556770324,0.2506012916564941,0.7461400032043457,1.0742896795272827,50000.0,0.6165000200271606,1.841195821762085,10000.0,54610.30392670632,56553.10640120506,54610.30392670632,1932.119315862656,4.740338563919067,0.0 -160900,4.1229224,0.7096086,,,,,,,,,,,,,, -161000,4.177314,0.733248,,,,,,,,,,,,,, -161100,4.0756044,0.6818749,,,,,,,,,,,,,, -161200,4.7379518,0.77908367,,,,,,,,,,,,,, -161300,4.364773,0.81282204,,,,,,,,,,,,,, -161400,4.737295,0.74367464,,,,,,,,,,,,,, -161500,4.201579,0.68108743,,,,,,,,,,,,,, -161600,4.0922036,0.6935081,,,,,,,,,,,,,, -161700,4.312067,0.7486347,,,,,,,,,,,,,, -161800,4.1772885,0.72078913,,,,,,,,,,,,,, -161900,4.4703116,0.7123143,,,,,,,,,,,,,, -162000,4.819591,0.76957726,,,,,,,,,,,,,, -162100,4.265822,0.76095736,,,,,,,,,,,,,, -162200,4.746348,0.83162814,,,,,,,,,,,,,, -162300,4.813008,0.8000669,,,,,,,,,,,,,, -162322,,,0.9462890625,0.1946813315153122,0.7486400008201599,1.0698437690734863,50000.0,0.6172000169754028,1.8549610376358032,10000.0,55120.26196694374,57080.722338199615,55120.26196694374,1949.6665840148928,4.796144008636475,0.0 -162400,4.2791986,0.73716277,,,,,,,,,,,,,, -162500,4.109647,0.67933255,,,,,,,,,,,,,, -162600,4.4200377,0.70088774,,,,,,,,,,,,,, -162700,4.397527,0.7186357,,,,,,,,,,,,,, -162800,4.3923755,0.7091321,,,,,,,,,,,,,, -162900,4.366483,0.7217092,,,,,,,,,,,,,, -163000,4.353551,0.7148872,,,,,,,,,,,,,, -163100,4.434035,0.749265,,,,,,,,,,,,,, -163200,4.39056,0.7106069,,,,,,,,,,,,,, -163300,4.424605,0.7856679,,,,,,,,,,,,,, -163400,4.103768,0.71689075,,,,,,,,,,,,,, -163500,4.210501,0.69556856,,,,,,,,,,,,,, -163600,4.2851605,0.7285475,,,,,,,,,,,,,, -163700,4.3733964,0.6916755,,,,,,,,,,,,,, -163800,4.3913665,0.70756155,,,,,,,,,,,,,, -163825,,,0.9412069320678712,0.2064539641141891,0.7478599548339844,1.0675770044326782,50000.0,0.6202000379562378,1.8455793857574463,10000.0,55630.45647478104,57608.87885069847,55630.45647478104,1967.51427936554,4.854479789733887,0.0 -163900,4.2283716,0.70592463,,,,,,,,,,,,,, -164000,4.1476583,0.6690192,,,,,,,,,,,,,, -164100,4.380987,0.70624125,,,,,,,,,,,,,, -164200,3.9927442,0.7073163,,,,,,,,,,,,,, -164300,4.575537,0.75977796,,,,,,,,,,,,,, -164400,4.5802393,0.7306526,,,,,,,,,,,,,, -164500,4.3398128,0.7023195,,,,,,,,,,,,,, -164600,4.1549263,0.70107424,,,,,,,,,,,,,, -164700,4.412324,0.6732291,,,,,,,,,,,,,, -164800,4.42903,0.7229731,,,,,,,,,,,,,, -164900,4.5821834,0.7927813,,,,,,,,,,,,,, -165000,4.380654,0.70854306,,,,,,,,,,,,,, -165100,4.584676,0.755088,,,,,,,,,,,,,, -165200,4.0007443,0.5804426,,,,,,,,,,,,,, -165300,4.7429004,0.7861776,,,,,,,,,,,,,, -165329,,,0.9415258169174194,0.2043623626232147,0.7491599917411804,1.068067193031311,50000.0,0.6196000576019287,1.8581254482269287,10000.0,56140.58478808403,58137.0842730999,56140.58478808403,1985.47976231575,4.911173105239868,0.0 -165400,4.131954,0.66825116,,,,,,,,,,,,,, -165500,4.46731,0.6813753,,,,,,,,,,,,,, -165600,4.4878983,0.7512912,,,,,,,,,,,,,, -165700,4.2623878,0.7109117,,,,,,,,,,,,,, -165800,4.607316,0.69671667,,,,,,,,,,,,,, -165900,4.383069,0.72981125,,,,,,,,,,,,,, -166000,4.364512,0.71295327,,,,,,,,,,,,,, -166100,4.1236014,0.62045455,,,,,,,,,,,,,, -166200,4.3331337,0.68224543,,,,,,,,,,,,,, -166300,4.8099236,0.7845649,,,,,,,,,,,,,, -166400,4.623006,0.71531844,,,,,,,,,,,,,, -166500,4.596371,0.6978604,,,,,,,,,,,,,, -166600,4.3637524,0.77048683,,,,,,,,,,,,,, -166700,4.23077,0.6987916,,,,,,,,,,,,,, -166800,4.591548,0.7217709,,,,,,,,,,,,,, -166832,,,0.9414859414100648,0.2066220045089721,0.7490400075912476,1.0683197975158691,50000.0,0.6226000189781189,1.836353421211243,10000.0,56650.707070589066,58665.36136960983,56650.707070589066,2003.517992734909,4.971322536468506,0.0 -166900,4.4663606,0.7364661,,,,,,,,,,,,,, -167000,4.2400913,0.6986699,,,,,,,,,,,,,, -167100,4.3005004,0.64074665,,,,,,,,,,,,,, -167200,4.276172,0.66229117,,,,,,,,,,,,,, -167300,4.6578655,0.69251883,,,,,,,,,,,,,, -167400,4.4670534,0.73426884,,,,,,,,,,,,,, -167500,4.677985,0.7536009,,,,,,,,,,,,,, -167600,4.546933,0.6650293,,,,,,,,,,,,,, -167700,4.726865,0.75923187,,,,,,,,,,,,,, -167800,4.3740344,0.69389236,,,,,,,,,,,,,, -167900,5.109883,0.7571018,,,,,,,,,,,,,, -168000,4.2079377,0.65437526,,,,,,,,,,,,,, -168100,4.8015447,0.7787396,,,,,,,,,,,,,, -168200,4.5466914,0.66829026,,,,,,,,,,,,,, -168300,4.7298403,0.7719623,,,,,,,,,,,,,, -168335,,,0.9429408311843872,0.1995099037885666,0.7509599924087524,1.0619680881500244,50000.0,0.6261000037193298,1.8424190282821653,10000.0,57160.77901077271,59193.23601198197,57160.77901077271,2021.207232952118,5.02895712852478,0.0 -168400,4.807946,0.7007116,,,,,,,,,,,,,, -168500,4.3803635,0.6952865,,,,,,,,,,,,,, -168600,4.692153,0.69554013,,,,,,,,,,,,,, -168700,4.3024707,0.7497127,,,,,,,,,,,,,, -168800,4.759723,0.6883726,,,,,,,,,,,,,, -168900,4.640013,0.7117714,,,,,,,,,,,,,, -169000,4.2440305,0.69350874,,,,,,,,,,,,,, -169100,4.7510386,0.7025767,,,,,,,,,,,,,, -169200,4.1288705,0.60939217,,,,,,,,,,,,,, -169300,4.070508,0.690728,,,,,,,,,,,,,, -169400,4.1880956,0.6605699,,,,,,,,,,,,,, -169500,4.4978456,0.72137684,,,,,,,,,,,,,, -169600,4.377101,0.64312077,,,,,,,,,,,,,, -169700,4.324048,0.73130834,,,,,,,,,,,,,, -169800,4.571334,0.6937921,,,,,,,,,,,,,, -169838,,,0.949238657951355,0.1808584332466125,0.7513399720191956,1.060373306274414,50000.0,0.6240000128746033,1.8413702249526973,10000.0,57670.78219342232,59720.89051890373,57670.78219342232,2038.746482372284,5.086165189743042,0.0 -169900,4.107297,0.63170445,,,,,,,,,,,,,, -170000,4.232758,0.7294427,,,,,,,,,,,,,, -170100,4.1406665,0.6398449,,,,,,,,,,,,,, -170200,4.526888,0.60862505,,,,,,,,,,,,,, -170300,4.4637523,0.67447186,,,,,,,,,,,,,, -170400,4.16709,0.6346611,,,,,,,,,,,,,, -170500,4.8352923,0.7351571,,,,,,,,,,,,,, -170600,4.6664267,0.6235671,,,,,,,,,,,,,, -170700,4.4081163,0.69799423,,,,,,,,,,,,,, -170800,4.3634124,0.61978996,,,,,,,,,,,,,, -170900,4.320544,0.6302501,,,,,,,,,,,,,, -171000,4.535399,0.661378,,,,,,,,,,,,,, -171100,4.3403788,0.64031667,,,,,,,,,,,,,, -171200,4.619087,0.660119,,,,,,,,,,,,,, -171300,4.9148874,0.7234287,,,,,,,,,,,,,, -171342,,,0.9575095176696776,0.1561500877141952,0.7531999945640564,1.0549283027648926,50000.0,0.6264000535011292,1.832197308540344,10000.0,58180.96397686005,60249.019006729126,58180.96397686005,2056.577041864395,5.145231485366821,0.0 -171400,4.4854503,0.6753837,,,,,,,,,,,,,, -171500,4.644359,0.634243,,,,,,,,,,,,,, -171600,4.221565,0.63042516,,,,,,,,,,,,,, -171700,4.635718,0.655708,,,,,,,,,,,,,, -171800,4.65927,0.5456376,,,,,,,,,,,,,, -171900,4.3391423,0.65118486,,,,,,,,,,,,,, -172000,5.0092793,0.7638219,,,,,,,,,,,,,, -172100,4.4791565,0.63773715,,,,,,,,,,,,,, -172200,4.1426353,0.62700075,,,,,,,,,,,,,, -172300,4.106591,0.6359603,,,,,,,,,,,,,, -172400,4.537055,0.67141294,,,,,,,,,,,,,, -172500,4.7582765,0.6546058,,,,,,,,,,,,,, -172600,4.687445,0.6077806,,,,,,,,,,,,,, -172700,4.65462,0.6456942,,,,,,,,,,,,,, -172800,5.1846766,0.7403704,,,,,,,,,,,,,, -172845,,,0.9551578164100648,0.1624404788017273,0.7537599802017212,1.0578937530517578,50000.0,0.6274000406265259,1.8417407274246216,10000.0,58691.0938334465,60777.03437685967,58691.0938334465,2074.348899126053,5.2050487995147705,0.0 -172900,4.55274,0.68826663,,,,,,,,,,,,,, -173000,4.7013273,0.61565375,,,,,,,,,,,,,, -173100,4.88246,0.67093676,,,,,,,,,,,,,, -173200,4.4304175,0.65545154,,,,,,,,,,,,,, -173300,4.967025,0.68801093,,,,,,,,,,,,,, -173400,4.4470515,0.6508327,,,,,,,,,,,,,, -173500,4.8876543,0.6900376,,,,,,,,,,,,,, -173600,4.4933157,0.6897182,,,,,,,,,,,,,, -173700,4.690702,0.7100977,,,,,,,,,,,,,, -173800,4.323722,0.63569325,,,,,,,,,,,,,, -173900,4.6587124,0.7133093,,,,,,,,,,,,,, -174000,4.530672,0.6032926,,,,,,,,,,,,,, -174100,4.749698,0.68493104,,,,,,,,,,,,,, -174200,4.284716,0.7119328,,,,,,,,,,,,,, -174300,4.7268114,0.6461692,,,,,,,,,,,,,, -174348,,,0.9562938213348388,0.1600802987813949,0.7548199892044067,1.0536705255508425,50000.0,0.6251000165939331,1.828043818473816,10000.0,59201.03386545181,61305.60614085197,59201.03386545181,2092.863926887512,5.264968156814575,0.0 -174400,4.3489647,0.58464336,,,,,,,,,,,,,, -174500,4.4176507,0.60921454,,,,,,,,,,,,,, -174600,4.3171377,0.6291987,,,,,,,,,,,,,, -174700,3.9985127,0.54983604,,,,,,,,,,,,,, -174800,4.2693286,0.5626882,,,,,,,,,,,,,, -174900,4.2707887,0.6276396,,,,,,,,,,,,,, -175000,4.453269,0.6490452,,,,,,,,,,,,,, -175100,4.725427,0.7042542,,,,,,,,,,,,,, -175200,4.5700846,0.6073024,,,,,,,,,,,,,, -175300,4.1230555,0.5625778,,,,,,,,,,,,,, -175400,4.43355,0.6360874,,,,,,,,,,,,,, -175500,4.6429625,0.618358,,,,,,,,,,,,,, -175600,4.23127,0.654737,,,,,,,,,,,,,, -175700,4.649459,0.6207784,,,,,,,,,,,,,, -175800,4.513749,0.65325737,,,,,,,,,,,,,, -175850,,,0.9552574753761292,0.1618704348802566,0.753879964351654,1.0522363185882568,50000.0,0.628000020980835,1.828550815582276,10000.0,59710.93507862091,61833.54664039612,59710.93507862091,2110.787750482559,5.325360298156738,0.0 -175900,4.619442,0.61642337,,,,,,,,,,,,,, -176000,4.5453167,0.6594496,,,,,,,,,,,,,, -176100,4.444076,0.62365323,,,,,,,,,,,,,, -176200,4.531541,0.646615,,,,,,,,,,,,,, -176300,4.8149962,0.607762,,,,,,,,,,,,,, -176400,4.0148273,0.5888925,,,,,,,,,,,,,, -176500,4.439782,0.60480124,,,,,,,,,,,,,, -176600,4.470874,0.6410275,,,,,,,,,,,,,, -176700,4.547436,0.62591887,,,,,,,,,,,,,, -176800,4.543537,0.67255425,,,,,,,,,,,,,, -176900,4.475057,0.5780136,,,,,,,,,,,,,, -177000,4.5548043,0.6309147,,,,,,,,,,,,,, -177100,4.886173,0.63939464,,,,,,,,,,,,,, -177200,4.7862635,0.71050215,,,,,,,,,,,,,, -177300,4.9420395,0.6039721,,,,,,,,,,,,,, -177353,,,0.9591039419174194,0.1550398766994476,0.7555800080299377,1.0516177415847778,50000.0,0.6294000148773193,1.820769429206848,10000.0,60220.999058008194,62361.60022234917,60220.999058008194,2128.663361310959,5.3837199211120605,0.0 -177400,4.387962,0.63288677,,,,,,,,,,,,,, -177500,4.18938,0.6308017,,,,,,,,,,,,,, -177600,4.836048,0.6308858,,,,,,,,,,,,,, -177700,4.367558,0.594525,,,,,,,,,,,,,, -177800,4.6669292,0.6054982,,,,,,,,,,,,,, -177900,4.7175956,0.62088704,,,,,,,,,,,,,, -178000,4.656584,0.692094,,,,,,,,,,,,,, -178100,4.5046415,0.6218072,,,,,,,,,,,,,, -178200,4.9264116,0.6893868,,,,,,,,,,,,,, -178300,4.677552,0.68034244,,,,,,,,,,,,,, -178400,4.468532,0.6384386,,,,,,,,,,,,,, -178500,4.2940903,0.6727109,,,,,,,,,,,,,, -178600,4.555825,0.659645,,,,,,,,,,,,,, -178700,4.414423,0.57997245,,,,,,,,,,,,,, -178800,5.153641,0.6195537,,,,,,,,,,,,,, -178856,,,0.9585060477256776,0.1536491960287094,0.7564199566841125,1.0479017496109009,50000.0,0.6295000314712524,1.821165680885315,10000.0,60731.09321784973,62889.41806507111,60731.09321784973,2146.2478954792023,5.467687129974365,0.0 -178900,4.6821747,0.69545275,,,,,,,,,,,,,, -179000,4.5429053,0.71541363,,,,,,,,,,,,,, -179100,4.450822,0.6357672,,,,,,,,,,,,,, -179200,4.7325325,0.6338256,,,,,,,,,,,,,, -179300,4.9196224,0.6737886,,,,,,,,,,,,,, -179400,4.296518,0.64218855,,,,,,,,,,,,,, -179500,4.910374,0.68594104,,,,,,,,,,,,,, -179600,4.7254353,0.6560375,,,,,,,,,,,,,, -179700,4.272888,0.5952716,,,,,,,,,,,,,, -179800,4.546713,0.6367809,,,,,,,,,,,,,, -179900,4.3388395,0.6471196,,,,,,,,,,,,,, -180000,4.60827,0.62391293,,,,,,,,,,,,,, -180100,4.2030034,0.63914,,,,,,,,,,,,,, -180200,4.7833815,0.65882885,,,,,,,,,,,,,, -180300,4.913914,0.5871588,,,,,,,,,,,,,, -180359,,,0.9599609375,0.1484406441450119,0.7556799650192261,1.0471597909927368,50000.0,0.6314000487327576,1.8269197940826416,10000.0,61241.01958680153,63417.25886678696,61241.01958680153,2164.042452096939,5.532188653945923,0.0 -180400,4.533183,0.6284253,,,,,,,,,,,,,, -180500,4.5743923,0.67366475,,,,,,,,,,,,,, -180600,4.607141,0.69040406,,,,,,,,,,,,,, -180700,4.4116263,0.61744094,,,,,,,,,,,,,, -180800,4.3176684,0.6093015,,,,,,,,,,,,,, -180900,4.321938,0.63091797,,,,,,,,,,,,,, -181000,4.3926597,0.5649346,,,,,,,,,,,,,, -181100,4.567826,0.64902794,,,,,,,,,,,,,, -181200,4.617874,0.6220678,,,,,,,,,,,,,, -181300,4.8761597,0.6489299,,,,,,,,,,,,,, -181400,4.492257,0.6365363,,,,,,,,,,,,,, -181500,4.799636,0.6697243,,,,,,,,,,,,,, -181600,4.6046553,0.676582,,,,,,,,,,,,,, -181700,4.693967,0.61133707,,,,,,,,,,,,,, -181800,4.4762845,0.5916349,,,,,,,,,,,,,, -181862,,,0.961355984210968,0.1469281464815139,0.7557799816131592,1.046109676361084,50000.0,0.6303000450134277,1.8241914510726929,10000.0,61751.20690441132,63945.35855007172,61751.20690441132,2181.832026720047,5.597205638885498,0.0 -181900,4.584424,0.62992316,,,,,,,,,,,,,, -182000,4.8311663,0.6332154,,,,,,,,,,,,,, -182100,4.395992,0.6402653,,,,,,,,,,,,,, -182200,4.8324776,0.6912076,,,,,,,,,,,,,, -182300,4.217497,0.59603953,,,,,,,,,,,,,, -182400,4.8720617,0.63646483,,,,,,,,,,,,,, -182500,4.3399053,0.5842087,,,,,,,,,,,,,, -182600,4.795415,0.6497725,,,,,,,,,,,,,, -182700,4.433975,0.6121318,,,,,,,,,,,,,, -182800,4.6314106,0.7081212,,,,,,,,,,,,,, -182900,4.3963623,0.63822234,,,,,,,,,,,,,, -183000,5.352112,0.69038594,,,,,,,,,,,,,, -183100,4.629441,0.64526737,,,,,,,,,,,,,, -183200,4.9575586,0.6680967,,,,,,,,,,,,,, -183300,4.520844,0.6056003,,,,,,,,,,,,,, -183366,,,0.9613958597183228,0.1471957266330719,0.7557199597358704,1.0465097427368164,50000.0,0.6312000155448914,1.8238227367401123,10000.0,62261.4028301239,64473.26201725006,62261.4028301239,2199.4081242084503,5.6746907234191895,0.0 -183400,4.0834274,0.5756307,,,,,,,,,,,,,, -183500,4.4686093,0.67460704,,,,,,,,,,,,,, -183600,5.0620074,0.6302975,,,,,,,,,,,,,, -183700,4.604502,0.69164,,,,,,,,,,,,,, -183800,4.63226,0.6569227,,,,,,,,,,,,,, -183900,4.6757174,0.7317283,,,,,,,,,,,,,, -184000,4.9488244,0.631391,,,,,,,,,,,,,, -184100,4.2046194,0.53396946,,,,,,,,,,,,,, -184200,5.0628433,0.6720149,,,,,,,,,,,,,, -184300,4.6117787,0.65999466,,,,,,,,,,,,,, -184400,4.539718,0.59219253,,,,,,,,,,,,,, -184500,4.480992,0.59167415,,,,,,,,,,,,,, -184600,4.1129284,0.61338353,,,,,,,,,,,,,, -184700,4.394139,0.629273,,,,,,,,,,,,,, -184800,4.334815,0.60849065,,,,,,,,,,,,,, -184868,,,0.959980845451355,0.1459084749221801,0.7561999559402466,1.0453925132751465,50000.0,0.6312000155448914,1.822606921195984,10000.0,62771.28944039345,65001.01459479332,62771.28944039345,2217.1566207408905,5.737559795379639,0.0 -184900,4.5480704,0.6527473,,,,,,,,,,,,,, -185000,4.5830207,0.61798555,,,,,,,,,,,,,, -185100,4.6356335,0.5944183,,,,,,,,,,,,,, -185200,4.43008,0.5575676,,,,,,,,,,,,,, -185300,4.5743914,0.66021156,,,,,,,,,,,,,, -185400,4.498904,0.60537136,,,,,,,,,,,,,, -185500,4.6476946,0.6361096,,,,,,,,,,,,,, -185566,,,,,,,,,,,63008.05158352852,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index e8da8c0b2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.06261396408081,0.0,42.45515441894531,1,0,42.45515441894531,0.0010000000474974,6.907756805419922,10000,82.5178644657135,0.0010351561941206,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -61.94122099876404,0.0286300182342529,462.6178922653198,898,0,462.6178922653198,0.0108000002801418,6.456206321716309,10000,524.6362130641937,0.0137695306912064,6.404803276062012,0.0135800000280141,6.413737773895264,50000 -83.57465052604675,0.06095552444458,882.9875965118408,1842,0,882.9875965118408,0.0306000020354986,5.982430934906006,10000,966.72283244133,0.0413671880960464,5.8381195068359375,0.0377600006759166,5.869866847991943,50000 -105.4310748577118,0.0899255275726318,1303.2035462856293,2788,0,1303.2035462856293,0.05180000141263,5.644834041595459,10000,1408.8744959831238,0.0702148452401161,5.4130401611328125,0.0637599974870681,5.468223571777344,50000 -127.17176485061646,0.118675947189331,1723.364368915558,3733,0,1723.364368915558,0.0715000033378601,5.371340274810791,10000,1850.8554532527924,0.1027148440480232,5.1039581298828125,0.0944399982690811,5.144573211669922,50000 -148.97430968284607,0.1471822261810302,2143.616794347763,4673,0,2143.616794347763,0.0976000055670738,5.071451663970947,10000,2292.988788843155,0.1353906244039535,4.738257884979248,0.1279999911785125,4.791199207305908,50000 -170.8790421485901,0.1756596565246582,2563.884337425232,5609,0,2563.884337425232,0.1264000087976455,4.761846542358398,10000,2735.237434864044,0.1873828023672104,4.345278739929199,0.1711599975824356,4.43392276763916,50000 -195.88995552062988,0.2058713436126709,2984.221193313598,6546,0,2984.221193313598,0.1563000082969665,4.464807987213135,10000,3180.6647255420685,0.2308789044618606,3.995484590530396,0.211119994521141,4.086441993713379,50000 -222.11826848983765,0.2382426261901855,3404.525137901306,7486,0,3404.525137901306,0.1887000054121017,4.262177467346191,10000,3627.2830555439,0.2678515613079071,3.742283821105957,0.2485999912023544,3.837204694747925,50000 -253.2263140678405,0.2754535675048828,3824.85044836998,8427,0,3824.85044836998,0.2143000066280365,4.080821514129639,10000,4078.802620887756,0.3073437511920929,3.501837491989136,0.2824800014495849,3.621398448944092,50000 -278.9440326690674,0.3049573898315429,4245.004207611084,9367,0,4245.004207611084,0.242000013589859,3.8549141883850098,10000,4524.752751588821,0.353339821100235,3.1838603019714355,0.315420001745224,3.375602960586548,50000 -307.616507768631,0.3441452980041504,4665.314122676849,10306,0,4665.314122676849,0.2660000026226043,3.685535430908203,10000,4973.823203802109,0.3710156083106994,3.0405492782592773,0.3440199792385101,3.174036741256714,50000 -335.2799074649811,0.3747379779815674,5085.516779184341,11243,0,5085.516779184341,0.2772000133991241,3.618096351623535,10000,5421.768236398697,0.3888476490974426,2.948613166809082,0.3602799773216247,3.093329668045044,50000 -367.0149767398834,0.4094424247741699,5505.593339443207,12180,0,5505.593339443207,0.2879000008106231,3.568051338195801,10000,5873.66360616684,0.4136328101158142,2.8385469913482666,0.3738999962806701,3.021648406982422,50000 -400.35482573509216,0.441788911819458,5925.677830457687,13113,0,5925.677830457687,0.3093000054359436,3.4295434951782227,10000,6327.169443368912,0.4274218678474426,2.739063739776612,0.3973399996757507,2.8888654708862305,50000 -437.1570258140564,0.4712910652160644,6345.869318246841,14049,0,6345.869318246841,0.3188000023365021,3.367639303207397,10000,6784.241494894028,0.4437695145606994,2.6229381561279297,0.4103399813175201,2.784433364868164,50000 -473.5686767101288,0.5005159378051758,6765.792047977448,14985,0,6765.792047977448,0.3266000151634216,3.2856686115264893,10000,7240.653775691986,0.4652148187160492,2.516399383544922,0.4221400022506714,2.725985288619995,50000 -506.31141448020935,0.5309438705444336,7186.127139806747,15915,0,7186.127139806747,0.3323000073432922,3.2542564868927,10000,7693.80969619751,0.4888085722923279,2.4176840782165527,0.4336199760437011,2.675327777862549,50000 -544.1211948394775,0.5615954399108887,7606.272791385651,16839,0,7606.272791385651,0.3374000191688537,3.2630906105041504,10000,8151.8433039188385,0.4630078077316284,2.538491725921631,0.4345999956130981,2.679887056350708,50000 -579.3966374397278,0.5898916721343994,8026.660900354385,17769,0,8026.660900354385,0.34170001745224,3.178121328353882,10000,8607.584362506866,0.4804101586341858,2.426698684692383,0.4465200006961822,2.5907909870147705,50000 -613.8807055950165,0.6182742118835449,8446.603140592575,18701,0,8446.603140592575,0.3521000146865845,3.1127684116363525,10000,9062.087691783903,0.5080273151397705,2.2943875789642334,0.4573400020599365,2.531662702560425,50000 -650.9721372127533,0.6507935523986816,8866.9126598835,19630,0,8866.9126598835,0.3632000088691711,3.1231794357299805,10000,9519.569938898088,0.4916210770606994,2.3779008388519287,0.4596799910068512,2.5362136363983154,50000 -687.1052870750427,0.6784398555755615,9286.968410253525,20558,0,9286.968410253525,0.3674000203609466,2.994854211807251,10000,9975.836344718931,0.5131444931030273,2.2290053367614746,0.4773799777030945,2.398954629898072,50000 -719.8414082527161,0.7146031856536865,9707.312128067017,21482,0,9707.312128067017,0.3772000074386596,3.007101535797119,10000,10429.00049996376,0.5301367044448853,2.176165580749512,0.4836399853229522,2.393795728683472,50000 -756.3632752895355,0.7525720596313477,10127.587964057922,22340,0,10127.587964057922,0.3882000148296356,2.9433727264404297,10000,10885.880244970322,0.5262304544448853,2.1738510131835938,0.4916400015354156,2.3356337547302246,50000 -792.3766157627106,0.7804503440856934,10547.819781303406,23269,0,10547.819781303406,0.3884000182151794,2.912286758422852,10000,11342.201602220535,0.5324609279632568,2.129218816757202,0.4975399971008301,2.299388885498047,50000 -829.1624467372894,0.8119678497314453,10967.767451047896,24199,0,10967.767451047896,0.3945000171661377,2.9152963161468506,10000,11799.014899253843,0.5442578196525574,2.104851245880127,0.498339980840683,2.314196348190308,50000 -865.5352845191956,0.8451135158538818,11388.04214978218,25128,0,11388.04214978218,0.3976000249385834,2.9255411624908447,10000,12255.743125915527,0.5494335889816284,2.1070497035980225,0.5054199695587158,2.3007595539093018,50000 -901.2724430561066,0.883490800857544,11808.012126922607,26047,0,11808.012126922607,0.4003000259399414,2.8565263748168945,10000,12711.5367333889,0.5480077862739563,2.075366735458374,0.5156199932098389,2.2329912185668945,50000 -936.7529728412628,0.9200353622436525,12228.003628730774,26974,0,12228.003628730774,0.3957000076770782,2.879225015640259,10000,13167.093488931656,0.5512499809265137,2.059241533279419,0.510919988155365,2.25130581855774,50000 -972.5665490627288,0.9505248069763184,12648.078053474426,27900,0,12648.078053474426,0.4039000272750854,2.825509786605835,10000,13623.059293746948,0.5764257907867432,1.9356085062026973,0.5181599855422974,2.198256015777588,50000 -1008.0687556266784,0.9811530113220216,13068.219451665878,28825,0,13068.219451665878,0.4154000282287597,2.7733428478240967,10000,14078.781080007551,0.5618749856948853,1.990774154663086,0.5294600129127502,2.153543472290039,50000 -1043.8474705219269,1.0138554573059082,13488.522999286652,29752,0,13488.522999286652,0.414900004863739,2.78692626953125,10000,14534.944336652756,0.5666210651397705,1.9570411443710327,0.5254200100898743,2.154048442840576,50000 -1079.546533346176,1.0415377616882324,13908.572232484818,30682,0,13908.572232484818,0.4161000251770019,2.84878921508789,10000,14990.768680334091,0.5687890648841858,2.034207344055176,0.5227400064468384,2.254662036895752,50000 -1114.5431122779846,1.0704760551452637,14328.762746572496,31610,0,14328.762746572496,0.4204000234603882,2.7725627422332764,10000,15446.033729076384,0.5677539110183716,1.991437315940857,0.531279981136322,2.1594009399414062,50000 -1149.8860309123993,1.1039931774139404,14749.06371474266,32539,0,14749.06371474266,0.4249000251293182,2.7316830158233643,10000,15901.75974059105,0.5756444931030273,1.9299191236495967,0.5379799604415894,2.113074541091919,50000 -1186.6391808986664,1.1376256942749023,15168.989698171616,33466,0,15168.989698171616,0.4214000105857849,2.734881401062012,10000,16358.529019117355,0.5826953053474426,1.9023802280426023,0.5382999777793884,2.1023356914520264,50000 -1222.977279663086,1.1691491603851318,15589.234191179276,34393,0,15589.234191179276,0.431300014257431,2.7067244052886963,10000,16815.190942525864,0.5888671875,1.8822269439697263,0.5423799753189087,2.086999654769897,50000 -1259.0506103038788,1.20351243019104,16009.42619729042,35320,0,16009.42619729042,0.4256000220775604,2.749296188354492,10000,17271.53852891922,0.5822656154632568,1.9264488220214844,0.5397199988365173,2.1193225383758545,50000 -1294.5700707435608,1.2412841320037842,16429.708421945572,36250,0,16429.708421945572,0.4310000240802765,2.680048704147339,10000,17727.43417429924,0.5885351300239563,1.876665472984314,0.5481399893760681,2.063407182693481,50000 -1331.338150024414,1.278425931930542,16850.04904961586,37179,0,16850.04904961586,0.4376000165939331,2.6541614532470703,10000,18184.6283288002,0.6117382645606995,1.750945806503296,0.5479400157928467,2.041366577148437,50000 -1367.6382720470428,1.311347246170044,17269.996902942657,38104,0,17269.996902942657,0.4384000301361084,2.6467978954315186,10000,18640.95697760582,0.5925390720367432,1.8541145324707031,0.5517799854278564,2.030083656311035,50000 -1403.9910144805908,1.3433549404144287,17689.955061912537,39030,0,17689.955061912537,0.4350000321865082,2.658357620239258,10000,19097.34910964966,0.5927343368530273,1.835362672805786,0.5510199666023254,2.033008813858032,50000 -1440.6161260604858,1.377694845199585,18110.31490755081,39957,0,18110.31490755081,0.4463000297546386,2.623008251190185,10000,19554.41698360443,0.5999609231948853,1.784259557723999,0.5545600056648254,2.00947380065918,50000 -1475.5338282585144,1.4079561233520508,18530.302005290985,40882,0,18530.302005290985,0.4439000189304352,2.64489483833313,10000,20009.39934825897,0.5937694907188416,1.8420950174331665,0.5575799942016602,2.006721258163452,50000 -1511.2537944316864,1.4423680305480957,18950.62165570259,41810,0,18950.62165570259,0.4422000348567962,2.627874851226806,10000,20465.520961999893,0.6000195145606995,1.8037185668945312,0.5605800151824951,1.9919346570968628,50000 -1546.9716968536377,1.474189043045044,19370.62746167183,42736,0,19370.62746167183,0.446800023317337,2.621208667755127,10000,20921.324761152267,0.6067187190055847,1.7677797079086304,0.5633800029754639,1.979863405227661,50000 -1581.4821391105652,1.5045185089111328,19790.838992118835,43664,0,19790.838992118835,0.4404000341892242,2.6423449516296387,10000,21376.125276327133,0.6116601228713989,1.7431319952011108,0.5555999875068665,1.9900355339050293,50000 -1617.130084991455,1.5374326705932615,20211.031491041183,44589,0,20211.031491041183,0.4478000104427337,2.59022855758667,10000,21832.046981096268,0.6066015362739563,1.761168122291565,0.562720000743866,1.9542808532714844,50000 -1652.8864409923551,1.57989239692688,20631.116145849228,45511,0,20631.116145849228,0.4520000219345093,2.569453239440918,10000,22287.97873067856,0.6109179258346558,1.7445123195648191,0.5663999915122986,1.957970380783081,50000 -1688.2459979057312,1.6172716617584229,21051.443721055984,46439,0,21051.443721055984,0.4525000154972076,2.548551559448242,10000,22743.75170326233,0.6299023032188416,1.6438137292861938,0.5723599791526794,1.911130428314209,50000 -1724.9623546600342,1.65169358253479,21471.823915719982,47366,0,21471.823915719982,0.4547000229358673,2.593113422393799,10000,23200.93217253685,0.6089648008346558,1.7728519439697266,0.5699399709701538,1.960967898368836,50000 -1761.9988887310028,1.6848242282867432,21892.04855132103,48292,0,21892.04855132103,0.4553000330924988,2.6022982597351074,10000,23658.27480411529,0.6068750023841858,1.7780888080596924,0.5670199990272522,1.9639018774032595,50000 -1797.1702196598053,1.7208774089813232,22312.30050706864,49216,0,22312.30050706864,0.4580000340938568,2.5392799377441406,10000,24113.78163957596,0.6234374642372131,1.6764730215072632,0.5714200139045715,1.9154047966003416,50000 -1831.73202419281,1.7526824474334717,22732.495416641235,50139,0,22732.495416641235,0.455700010061264,2.5697062015533447,10000,24568.618636846542,0.6090039014816284,1.7651121616363523,0.569819986820221,1.935936689376831,50000 -1868.3653919696808,1.794832706451416,23152.728211402893,51065,0,23152.728211402893,0.45210000872612,2.5792510509490967,10000,25025.574861764908,0.6101757884025574,1.7590184211730957,0.5707600116729736,1.938204765319824,50000 -1905.6698813438416,1.828599452972412,23573.097232103348,51992,0,23573.097232103348,0.4656000137329101,2.487163782119751,10000,25483.33177614212,0.6263476610183716,1.6484354734420776,0.5796999931335449,1.865336775779724,50000 -1943.5027811527248,1.8709113597869875,23993.060639619827,52919,0,23993.060639619827,0.4649000167846679,2.5275325775146484,10000,25941.21884179116,0.6298437118530273,1.6609742641448977,0.5778399705886841,1.9005619287490845,50000 -1979.9512028694155,1.9040093421936035,24413.19846820832,53845,0,24413.19846820832,0.457800030708313,2.5673742294311523,10000,26397.88577604294,0.6171679496765137,1.7618088722229004,0.5772799849510193,1.945756793022156,50000 -2016.8710873126984,1.9429829120635984,24833.537580490112,54771,0,24833.537580490112,0.4653000235557556,2.465291976928711,10000,26855.23208403588,0.6280664205551147,1.6353827714920044,0.581820011138916,1.8503868579864504,50000 -2052.716703414917,1.9796583652496336,25253.59513092041,55700,0,25253.59513092041,0.461400032043457,2.5292446613311768,10000,27311.22057056427,0.6409375071525574,1.605479121208191,0.5806399583816528,1.8812267780303955,50000 -2088.7680180072784,2.0139429569244385,25673.629153251648,56627,0,25673.629153251648,0.4660000205039978,2.484644651412964,10000,27767.3883125782,0.6262499690055847,1.657572865486145,0.5814999938011169,1.8544203042984009,50000 -2126.1931478977203,2.051152229309082,26093.89505290985,57555,0,26093.89505290985,0.4666000306606293,2.503213167190552,10000,28225.16506052017,0.6287499666213989,1.6723737716674805,0.5836600065231323,1.8777090311050413,50000 -2164.820201158524,2.0870184898376465,26513.911805152893,58481,0,26513.911805152893,0.4684000313282013,2.4887142181396484,10000,28683.89664721489,0.6403515338897705,1.61244797706604,0.582859992980957,1.8621419668197632,50000 -2203.4964802265167,2.1264965534210205,26934.146147489548,59407,0,26934.146147489548,0.4738000333309173,2.4757091999053955,10000,29142.89506626129,0.6268945336341858,1.6685230731964111,0.5874199867248535,1.8534951210021973,50000 -2240.920811891556,2.165325164794922,27354.376095294952,60334,0,27354.376095294952,0.4762000143527984,2.4529848098754883,10000,29600.636869430546,0.6327733993530273,1.6337242126464844,0.5888599753379822,1.8365858793258667,50000 -2274.5890777111053,2.203279733657837,27774.38081717491,61261,0,27774.38081717491,0.4675000309944153,2.50374174118042,10000,30054.39617562294,0.6378515362739563,1.6398062705993652,0.5875999927520752,1.8620257377624512,50000 -2310.474866867065,2.243015766143799,28194.3759431839,62185,0,28194.3759431839,0.4650000333786011,2.49658751487732,10000,30510.36556315422,0.6422656178474426,1.6118204593658447,0.5854799747467041,1.8537278175354004,50000 -2346.715988636017,2.281113147735596,28614.36855793,63111,0,28614.36855793,0.4726000130176544,2.5159428119659424,10000,30966.68574333191,0.6295117139816284,1.700981855392456,0.5878599882125854,1.883209466934204,50000 -2383.1910014152527,2.3187735080718994,29034.53926801681,64038,0,29034.53926801681,0.4771000146865845,2.4646663665771484,10000,31423.41744923592,0.6376757621765137,1.6427497863769531,0.5946999788284302,1.846629977226257,50000 -2420.782904148102,2.3532867431640625,29454.639416217804,64965,0,29454.639416217804,0.4723000228404999,2.4600064754486084,10000,31881.19240355492,0.6559374928474426,1.5458028316497805,0.5884599685668945,1.832724690437317,50000 -2457.5085434913635,2.387108564376831,29874.70513272285,65893,0,29874.70513272285,0.4759000241756439,2.482262134552002,10000,32338.065663814545,0.6297070384025574,1.6710429191589355,0.5890399813652039,1.8539769649505613,50000 -2495.827398777008,2.426995038986206,30294.62602829933,66814,0,30294.62602829933,0.4705000221729278,2.4594082832336426,10000,32796.39333939552,0.6425976157188416,1.6025398969650269,0.5962600111961365,1.8162277936935425,50000 -2535.44087266922,2.470527648925781,30714.93037724495,67738,0,30714.93037724495,0.4780000150203705,2.456949710845948,10000,33256.4021191597,0.6525781154632568,1.570324182510376,0.6031599640846252,1.8069418668746948,50000 -2575.6301860809326,2.5075490474700928,31135.09487080574,68663,0,31135.09487080574,0.4754000306129455,2.4471421241760254,10000,33716.840874910355,0.6356640458106995,1.6274088621139526,0.5971199870109558,1.8149526119232176,50000 -2614.288892507553,2.7598090171813965,31555.165743112564,69587,0,31555.165743112564,0.4761000275611877,2.4608607292175293,10000,34175.87082648277,0.6425390243530273,1.6058720350265503,0.6004999876022339,1.7994102239608765,50000 -2654.033591747284,2.801017999649048,31975.08659768105,70508,0,31975.08659768105,0.4789000153541565,2.4507243633270264,10000,34635.624881505966,0.6518359184265137,1.5789217948913574,0.5998600125312805,1.8045138120651243,50000 -2692.672725915909,2.847104072570801,32395.51756548881,71431,0,32395.51756548881,0.484000027179718,2.4054059982299805,10000,35094.78996706009,0.6530663967132568,1.5456774234771729,0.6032800078392029,1.7672827243804932,50000 -2729.4140496253967,2.888112783432007,32815.659552812576,72353,0,32815.659552812576,0.4842000305652618,2.39375638961792,10000,35551.76320028305,0.6449413895606995,1.5734144449234009,0.6037600040435791,1.765405535697937,50000 -2766.906188249588,2.9232640266418457,33235.68685173988,73277,0,33235.68685173988,0.484000027179718,2.3881523609161377,10000,36009.36540389061,0.6486523151397705,1.5562613010406494,0.6032199859619141,1.7581266164779663,50000 -2806.733047246933,2.963426113128662,33655.99447274208,74204,0,33655.99447274208,0.4800000190734863,2.4043872356414795,10000,36469.58887457848,0.6731249690055847,1.4761346578598022,0.6058200001716614,1.766079068183899,50000 -2847.1502919197083,3.003435611724853,34076.297404289246,75128,0,34076.297404289246,0.486700028181076,2.39943790435791,10000,36930.39732980728,0.6450976133346558,1.5667901039123535,0.6061399579048157,1.760995626449585,50000 -2887.1610465049744,3.0404083728790283,34496.459659576416,76053,0,34496.459659576416,0.4927000105381012,2.395556688308716,10000,37390.65586042404,0.6499999761581421,1.5527905225753784,0.6094799637794495,1.7591296434402466,50000 -2924.444561958313,3.0794270038604736,34916.55329370499,76976,0,34916.55329370499,0.4882000088691711,2.4163527488708496,10000,37848.11969184876,0.6608788967132568,1.5342339277267456,0.6050800085067749,1.7775715589523315,50000 -2963.59290099144,3.125219821929932,35336.832350969315,77903,0,35336.832350969315,0.4905000329017639,2.4130568504333496,10000,38307.64156937599,0.6502929329872131,1.5727944374084473,0.6092599630355835,1.7659096717834473,50000 -3003.001363515854,3.161072254180908,35756.88972687721,78829,0,35756.88972687721,0.4865000247955322,2.392758369445801,10000,38767.19116520882,0.6518749594688416,1.5727999210357666,0.6112599968910217,1.7607871294021606,50000 -3042.4453415870667,3.199512720108032,36177.00625920296,79754,0,36177.00625920296,0.4955000281333923,2.342292547225952,10000,39226.83818221092,0.6651171445846558,1.4895260334014893,0.6138399839401245,1.7199392318725586,50000 -3083.108058929444,3.2435245513916016,36597.088631391525,80677,0,36597.088631391525,0.493800014257431,2.342613458633423,10000,39687.675043821335,0.6646288633346558,1.4810144901275637,0.6114999651908875,1.7160485982894895,50000 -3124.51851439476,3.2905027866363525,37017.36079597473,81603,0,37017.36079597473,0.4974000155925751,2.352527379989624,10000,40149.45268511772,0.6604687571525574,1.5106688737869265,0.6142599582672119,1.7114691734313965,50000 -3163.785692691803,3.336221933364868,37437.69989728928,82530,0,37437.69989728928,0.4939000308513641,2.370671510696411,10000,40609.15281009674,0.6646093726158142,1.5200191736221311,0.6138399839401245,1.7434662580490112,50000 -3203.184502363205,3.375791311264038,37857.93138933182,83454,0,37857.93138933182,0.5004000067710876,2.3427956104278564,10000,41068.87111830712,0.6852148175239563,1.3989347219467163,0.6177399754524231,1.7047220468521118,50000 -3242.8420696258545,3.4148974418640137,38278.19797229767,84377,0,38278.19797229767,0.4927000105381012,2.361658811569214,10000,41528.8823723793,0.6599413752555847,1.5086867809295654,0.6137599945068359,1.7188671827316284,50000 -3284.55818772316,3.455544948577881,38698.49995803833,85301,0,38698.49995803833,0.5024999976158142,2.314694881439209,10000,41990.98984336853,0.6687890291213989,1.4672608375549316,0.6248399615287781,1.679897427558899,50000 -3321.14311671257,3.497130870819092,39118.564464092255,86227,0,39118.564464092255,0.5035000443458557,2.308379650115967,10000,42447.72882437706,0.682421863079071,1.420057773590088,0.6224600076675415,1.6837719678878784,50000 -3361.6144444942474,3.535960912704468,39538.55706048012,87154,0,39538.55706048012,0.5034000277519226,2.3107705116271973,10000,42908.27966308594,0.6678515672683716,1.4901386499404907,0.624019980430603,1.6841695308685305,50000 -3399.0924847126007,3.5792877674102783,39958.77399516106,88081,0,39958.77399516106,0.5009000301361084,2.3405404090881348,10000,43366.066150188446,0.6660546660423279,1.5081756114959717,0.6201399564743042,1.717879295349121,50000 -3436.8513662815094,3.6288912296295166,40378.939730882645,89007,0,40378.939730882645,0.506600022315979,2.309653282165528,10000,43824.08883070946,0.6788476705551147,1.4367852210998535,0.6260199546813965,1.682494044303894,50000 -3474.23304772377,3.666169404983521,40798.99773478508,89929,0,40798.99773478508,0.5065000057220459,2.3380980491638184,10000,44281.61442565918,0.6749609112739563,1.4869822263717651,0.6248999834060669,1.7146426439285278,50000 -3512.803384065628,3.704136610031128,41219.33115816116,90855,0,41219.33115816116,0.5069000124931335,2.29979944229126,10000,44740.604519844055,0.6749218702316284,1.44557523727417,0.6292200088500977,1.6595513820648191,50000 -3550.632279634476,3.7427337169647217,41639.402092933655,91780,0,41639.402092933655,0.5057000517845154,2.329899787902832,10000,45198.59097504616,0.6800000071525574,1.466551423072815,0.6314399838447571,1.6920958757400513,50000 -3590.631927251816,3.7838430404663086,42059.464174985886,92701,0,42059.464174985886,0.5121000409126282,2.2850232124328613,10000,45658.74293327332,0.69873046875,1.3424489498138428,0.6325399875640869,1.649196982383728,50000 -3628.571489095688,3.82320785522461,42479.74120259285,93626,0,42479.74120259285,0.5180000066757202,2.231893301010132,10000,46117.04714727402,0.6816992163658142,1.4043458700180054,0.637179970741272,1.6172446012496948,50000 -3665.710597515106,3.861924171447754,42900.02947735786,94550,0,42900.02947735786,0.5154000520706177,2.23335862159729,10000,46574.56220269203,0.6850976347923279,1.3780698776245115,0.6401599645614624,1.591439127922058,50000 -3703.507269144058,3.901413679122925,43320.07567238808,95475,0,43320.07567238808,0.5127000212669373,2.2695205211639404,10000,47032.49260401726,0.6938085556030273,1.3812767267227173,0.6334999799728394,1.646828293800354,50000 -3738.279561281204,3.94172739982605,43740.39475917816,96400,0,43740.39475917816,0.5180000066757202,2.2730095386505127,10000,47487.6726129055,0.6801366806030273,1.449390888214111,0.6377800107002258,1.6380335092544556,50000 -3778.8772070407854,3.9924333095550537,44160.30851793289,97326,0,44160.30851793289,0.5163000226020813,2.233106136322021,10000,47948.28272128105,0.6853125095367432,1.400512933731079,0.6398599743843079,1.6111207008361816,50000 -3818.3037271499634,4.035202503204346,44580.45514702797,98253,0,44580.45514702797,0.5175999999046326,2.243485689163208,10000,48407.94709467888,0.6956640481948853,1.3512005805969238,0.6403399705886841,1.6066975593566897,50000 -3855.362357854843,4.076540231704712,45000.66364359856,99180,0,45000.66364359856,0.526199996471405,2.183600187301636,10000,48865.30378293991,0.6977148056030273,1.3450783491134644,0.6491000056266785,1.5536638498306274,50000 -3894.365116834641,4.119498014450073,45420.91577172279,100104,0,45420.91577172279,0.5207000374794006,2.2531089782714844,10000,49324.64928340912,0.6898828148841858,1.416399598121643,0.6423400044441223,1.6274499893188477,50000 -3932.808772087097,4.1595988273620605,45841.04868769646,101032,0,45841.04868769646,0.52510005235672,2.190631628036499,10000,49783.31498479843,0.6932812333106995,1.3365551233291626,0.6417999863624573,1.5768588781356812,50000 -3972.7331693172455,4.202354431152344,46261.34437060356,101958,0,46261.34437060356,0.5157000422477722,2.233988046646118,10000,50243.625947237015,0.7127343416213989,1.287361979484558,0.642579972743988,1.5980144739151,50000 -4012.935410261154,4.241785049438477,46681.6217648983,102883,0,46681.6217648983,0.5223000049591064,2.177493095397949,10000,50704.19277739525,0.6974804401397705,1.331035614013672,0.6518599987030029,1.5392245054244995,50000 -4054.28226852417,4.292704820632935,47101.53124284744,103809,0,47101.53124284744,0.5236000418663025,2.234383583068848,10000,51165.548437833786,0.6997656226158142,1.369573950767517,0.6473199725151062,1.58622944355011,50000 -4095.0082075595856,4.342754125595093,47521.49310541153,104733,0,47521.49310541153,0.528700053691864,2.189589262008667,10000,51626.334483385086,0.7157617211341858,1.2719745635986328,0.6475399732589722,1.5650979280471802,50000 -4132.022391796112,4.3900251388549805,47941.640141010284,105660,0,47941.640141010284,0.5309000015258789,2.172654867172241,10000,52083.59052538872,0.6998632550239563,1.3254988193511963,0.649619996547699,1.5387935638427734,50000 -4172.140084028244,4.437947511672974,48361.77505850792,106589,0,48361.77505850792,0.5260000228881836,2.2354602813720703,10000,52543.93924975395,0.7017773389816284,1.3748812675476074,0.6519799828529358,1.600887656211853,50000 -4209.313284873962,4.48431396484375,48782.06770968437,107516,0,48782.06770968437,0.5263000130653381,2.170165777206421,10000,53001.499126434326,0.7152148485183716,1.2641282081604004,0.6518200039863586,1.5298153162002563,50000 -4247.780642032623,4.5283708572387695,49202.40798187256,108440,0,49202.40798187256,0.5293000340461731,2.17182731628418,10000,53460.39822816849,0.7088671922683716,1.30809485912323,0.6545000076293945,1.5428231954574585,50000 -4287.5615401268005,4.569385528564453,49622.73654794693,109364,0,49622.73654794693,0.5425000190734863,2.130657434463501,10000,53920.59532928467,0.7081640362739563,1.2886130809783936,0.6595799922943115,1.5078322887420654,50000 -4325.538824796677,4.612484693527222,50042.85763192177,110290,0,50042.85763192177,0.5387000441551208,2.11667799949646,10000,54378.7845196724,0.7182226181030273,1.246045470237732,0.6599000096321106,1.4922248125076294,50000 -4362.5080988407135,4.6590001583099365,50462.81347370148,111215,0,50462.81347370148,0.5321000218391418,2.1034798622131348,10000,54835.804351091385,0.7352343797683716,1.1619881391525269,0.6606599688529968,1.4827560186386108,50000 -4398.260158777237,4.708902835845947,50883.19580984116,112141,0,50883.19580984116,0.534000039100647,2.13720703125,10000,55292.03625512123,0.7141991853713989,1.254716873168945,0.6646999716758728,1.48099946975708,50000 -4438.158957719803,4.753890514373779,51303.45684719086,113064,0,51303.45684719086,0.5439000129699707,2.109809875488281,10000,55752.28976273537,0.7215819954872131,1.2358224391937256,0.668940007686615,1.475932002067566,50000 -4476.102520704269,4.79875922203064,51723.392731666565,113989,0,51723.392731666565,0.5433000326156616,2.090823173522949,10000,56210.262149095535,0.7289257645606995,1.1877706050872805,0.660860002040863,1.4783949851989746,50000 -4513.86473441124,4.843943119049072,52143.375336408615,114913,0,52143.375336408615,0.5435000061988831,2.1070306301116943,10000,56668.10083293915,0.71644526720047,1.250329852104187,0.6658799648284912,1.4862215518951416,50000 -4553.015449285507,4.888232231140137,52563.48961663246,115837,0,52563.48961663246,0.5459000468254089,2.074211597442627,10000,57127.45796251297,0.7221484184265137,1.2188655138015747,0.669439971446991,1.451919674873352,50000 -4593.209508657455,4.9340245723724365,52983.74047589302,116762,0,52983.74047589302,0.5437999963760376,2.064565420150757,10000,57588.00782227516,0.7323827743530273,1.1630433797836304,0.6712200045585632,1.4364888668060305,50000 -4634.027329921722,4.976112604141235,53403.952806949615,117688,0,53403.952806949615,0.5520000457763672,2.056349277496338,10000,58049.12762546539,0.7284570336341858,1.1957414150238037,0.6744399666786194,1.430295467376709,50000 -4675.380663156509,5.025782823562622,53824.20411038399,118612,0,53824.20411038399,0.5512000322341919,2.073585271835327,10000,58510.82946014404,0.724804699420929,1.209438681602478,0.6693399548530579,1.4543384313583374,50000 -4716.138568401337,5.069893598556519,54244.48342585564,119537,0,54244.48342585564,0.5481000542640686,2.072385549545288,10000,58971.95789599419,0.7267187237739563,1.2144628763198853,0.670520007610321,1.462047100067139,50000 -4753.07851433754,5.121261596679688,54664.61124134064,120462,0,54664.61124134064,0.5540000200271606,2.0518739223480225,10000,59429.12511634827,0.7491210699081421,1.108149766921997,0.6756599545478821,1.4270612001419067,50000 -4793.797434568405,5.165852785110474,55084.6750433445,121388,0,55084.6750433445,0.5573000311851501,2.0155694484710693,10000,59890.00051164627,0.7331249713897705,1.1577279567718506,0.6822599768638611,1.3907045125961304,50000 -4831.949808120728,5.2081685066223145,55504.61960840225,122313,0,55504.61960840225,0.5590000152587891,2.026743173599243,10000,60348.186989068985,0.7346875071525574,1.1657814979553225,0.675599992275238,1.4246152639389038,50000 -4871.815351247788,5.260644197463989,55924.80131602287,123236,0,55924.80131602287,0.5630000233650208,2.019491672515869,10000,60808.3349378109,0.7521093487739563,1.1196391582489014,0.6818400025367737,1.4134577512741089,50000 -4912.768722057343,5.302764654159546,56345.114104270935,124159,0,56345.114104270935,0.560200035572052,2.0003364086151123,10000,61269.6905901432,0.7383007407188416,1.1520472764968872,0.6840800046920776,1.3852416276931765,50000 -4951.472690105438,5.346803903579712,56765.40418744087,125086,0,56765.40418744087,0.558899998664856,2.0429952144622803,10000,61728.776727199554,0.7397655844688416,1.1573693752288818,0.6845999956130981,1.4100217819213867,50000 -4992.742039680481,5.39652943611145,57185.52679872513,126011,0,57185.52679872513,0.5542000532150269,2.0321836471557617,10000,62190.26622200012,0.7486132383346558,1.118463397026062,0.6848799586296082,1.392722725868225,50000 -5031.7367787361145,5.445305347442627,57605.49915289879,126934,0,57605.49915289879,0.5667000412940979,1.988448977470398,10000,62649.33008289337,0.7451757788658142,1.1258550882339478,0.686739981174469,1.3771331310272217,50000 -5072.641215085983,5.489341497421265,58025.73322200775,127859,0,58025.73322200775,0.5637000203132629,2.017214298248291,10000,63110.56328248978,0.7473437190055847,1.1448750495910645,0.6888999938964844,1.3919219970703125,50000 -5113.915201663971,5.5364158153533936,58445.72785902023,128782,0,58445.72785902023,0.5646000504493713,1.990260481834412,10000,63571.92673802376,0.753710925579071,1.085994005203247,0.6908999681472778,1.3590346574783323,50000 -5154.739242553711,5.584019184112549,58865.84973907471,129707,0,58865.84973907471,0.562000036239624,2.0006537437438965,10000,64032.96802806854,0.7651562094688416,1.0557949542999268,0.6894400119781494,1.376419186592102,50000 -5195.786374568939,5.630152940750122,59285.97770404816,130627,0,59285.97770404816,0.5738000273704529,1.951003074645996,10000,64494.23661541939,0.7515429258346558,1.082430362701416,0.6952399611473083,1.3379631042480469,50000 -5232.168748378754,5.67841386795044,59706.24283266068,131551,0,59706.24283266068,0.5703999996185303,1.9871017932891848,10000,64950.979976415634,0.7570898532867432,1.0904736518859863,0.6944199800491333,1.3673386573791504,50000 -5272.644042253494,5.7267396450042725,60126.579996824265,132475,0,60126.579996824265,0.5768000483512878,1.9146699905395508,10000,65411.887921094894,0.769238293170929,1.0069472789764404,0.7003200054168701,1.3125197887420654,50000 -5314.366637706757,5.773090362548828,60547.2616622448,133401,0,60547.2616622448,0.5755000114440918,1.93559992313385,10000,65874.38661026955,0.757519543170929,1.062558889389038,0.6997599601745605,1.305146098136902,50000 -5351.010791301727,5.823288679122925,60967.64869689941,134327,0,60967.64869689941,0.5751000046730042,1.9424078464508057,10000,66331.51563692093,0.7621874809265137,1.0537630319595337,0.700980007648468,1.31785249710083,50000 -5391.782269239426,5.869398355484009,61387.66164803505,135250,0,61387.66164803505,0.5711000561714172,1.9518496990203853,10000,66792.39415669441,0.7675390243530273,1.0237082242965698,0.6987599730491638,1.3225077390670776,50000 -5430.429076910019,5.918476819992065,61807.8627281189,136174,0,61807.8627281189,0.5789000391960144,1.929173469543457,10000,67251.33851337433,0.764843761920929,1.0483845472335815,0.7023999691009521,1.321294188499451,50000 -5469.642811059952,5.971429824829102,62228.40162968636,137101,0,62228.40162968636,0.58160001039505,1.9345344305038448,10000,67711.19112372398,0.7681640386581421,1.0459219217300415,0.7051399946212769,1.3112800121307373,50000 -5510.844088315964,6.016285419464111,62648.38740777969,138024,0,62648.38740777969,0.5817000269889832,1.9207830429077148,10000,68172.46992921829,0.7739452719688416,1.0145626068115234,0.706559956073761,1.2995346784591677,50000 -5552.058122396469,6.076954126358032,63068.49566245079,138947,0,63068.49566245079,0.579200029373169,1.925778031349182,10000,68633.90044283867,0.7826952934265137,0.9653533697128296,0.7051399946212769,1.3022921085357666,50000 -5591.511062860489,6.124045372009277,63488.75289964676,139869,0,63488.75289964676,0.5893000364303589,1.8732690811157229,10000,69093.70464968681,0.7697656154632568,1.0047340393066406,0.7098000049591064,1.2667945623397827,50000 -5629.565707683563,6.17889142036438,63908.82472872734,140795,0,63908.82472872734,0.5814000368118286,1.9018096923828125,10000,69551.93505644798,0.7763866782188416,0.9946498870849608,0.7087399959564209,1.28446626663208,50000 -5668.733921766281,6.22956919670105,64329.11201548576,141719,0,64329.11201548576,0.5896000266075134,1.8770052194595337,10000,70011.48867511749,0.7819921970367432,0.9503366351127625,0.7135999798774719,1.2518689632415771,50000 -5706.832077026367,6.278799533843994,64749.359623909,142641,0,64749.359623909,0.5865000486373901,1.8613325357437127,10000,70469.93096780777,0.7760937213897705,0.9692516326904296,0.7108399868011475,1.23958420753479,50000 -5747.56706738472,6.333415985107422,65169.36147618294,143566,0,65169.36147618294,0.5878000259399414,1.894475340843201,10000,70930.76969718933,0.778613269329071,0.9984338283538818,0.7130399942398071,1.2804381847381592,50000 -5789.146682500839,6.380536079406738,65589.43929386139,144491,0,65589.43929386139,0.5898000001907349,1.884268045425415,10000,71392.52113199234,0.7850390672683716,0.9747884273529052,0.7129600048065186,1.2740776538848877,50000 -5830.013171672821,6.438591480255127,66009.40965223312,145414,0,66009.40965223312,0.5966000556945801,1.8826606273651123,10000,71853.4624812603,0.77845698595047,0.9919044971466064,0.7156599760055542,1.2642072439193726,50000 -5867.35782957077,6.490723133087158,66429.68575835228,146341,0,66429.68575835228,0.5951000452041626,1.8369319438934328,10000,72311.18190431595,0.7854687571525574,0.936896562576294,0.7184799909591675,1.2213993072509766,50000 -5906.406289815903,6.545153379440308,66849.90239930153,147267,0,66849.90239930153,0.602400004863739,1.8256595134735107,10000,72770.54873371124,0.791308581829071,0.9221143126487732,0.7231599688529968,1.2180984020233154,50000 -5947.267498254776,6.596628427505493,67269.83102846146,148193,0,67269.83102846146,0.6012000441551208,1.8258556127548216,10000,73231.43658638,0.8017187118530273,0.8867132663726807,0.7227199673652649,1.2239965200424194,50000 -5984.706177949905,6.646602392196655,67690.06477189064,149120,0,67690.06477189064,0.6017000079154968,1.801848292350769,10000,73689.20619821548,0.7932226657867432,0.906248927116394,0.7246599793434143,1.2037746906280518,50000 -6023.770362854004,6.693514585494995,68110.0372145176,150046,0,68110.0372145176,0.6027000546455383,1.8345398902893064,10000,74148.33674430847,0.7907617092132568,0.9388483762741088,0.7214599847793579,1.2328137159347534,50000 -6066.470200300217,6.744734525680542,68529.95536541939,150968,0,68529.95536541939,0.6043000221252441,1.8021141290664675,10000,74611.0534362793,0.8061718344688416,0.8653810024261475,0.7301200032234192,1.189891338348389,50000 -6108.044515609741,6.796485185623169,68950.07535648346,151891,0,68950.07535648346,0.6104000210762024,1.778713345527649,10000,75072.84614467621,0.7955273389816284,0.8970678448677063,0.7301200032234192,1.1849876642227173,50000 -6148.80455327034,6.853741884231567,69370.09718680382,152814,0,69370.09718680382,0.6066000461578369,1.7887028455734253,10000,75533.73214793205,0.8013476133346558,0.8696554899215698,0.7305999994277954,1.1746702194213867,50000 -6188.927048921585,6.903278827667236,69790.06775164604,153736,0,69790.06775164604,0.6065000295639038,1.7751833200454712,10000,75993.92200398445,0.8064843416213989,0.8482201099395752,0.7335599660873413,1.1666568517684937,50000 -6229.90651512146,6.94967794418335,70210.01730847359,154662,0,70210.01730847359,0.6170000433921814,1.75464928150177,10000,76454.94499826431,0.804492175579071,0.8526339530944824,0.7369999885559082,1.151193141937256,50000 -6271.519709348679,6.999774217605591,70630.25607657433,155584,0,70630.25607657433,0.6175000071525574,1.73326575756073,10000,76916.89459323883,0.811328113079071,0.8368220329284668,0.738860011100769,1.1395392417907717,50000 -6309.6483726501465,7.052521228790283,71050.56174874306,156508,0,71050.56174874306,0.6134000420570374,1.745561599731445,10000,77375.42791485786,0.8138476610183716,0.8137789964675903,0.7384200096130371,1.1399335861206057,50000 -6349.01679110527,7.105899810791016,71470.68224668503,157431,0,71470.68224668503,0.619100034236908,1.722597599029541,10000,77835.01753425598,0.8206444978713989,0.7946438193321228,0.7416599988937378,1.1222891807556152,50000 -6390.761610031128,7.158337116241455,71890.93097662926,158285,0,71890.93097662926,0.6203000545501709,1.7233086824417114,10000,78297.10611104965,0.81201171875,0.8173814415931702,0.7415399551391602,1.131630539894104,50000 -6427.128840446472,7.209970235824585,72310.98796439171,159207,0,72310.98796439171,0.625700056552887,1.7257004976272583,10000,78753.62868118286,0.8217968344688416,0.7972414493560791,0.7410199642181396,1.1312137842178345,50000 -6467.850144386292,7.2619194984436035,72731.21923279762,160129,0,72731.21923279762,0.6238000392913818,1.7093664407730105,10000,79214.68068003654,0.8250585794448853,0.7695590257644653,0.7425199747085571,1.1162033081054688,50000 -6508.185884714127,7.314181804656982,73151.4669148922,161054,0,73151.4669148922,0.62090003490448,1.732371807098389,10000,79675.36319732666,0.8220117092132568,0.7986540794372559,0.7463399767875671,1.1236673593521118,50000 -6547.046494960785,7.370314836502075,73571.6182346344,161979,0,73571.6182346344,0.6240000128746033,1.6988554000854492,10000,80134.4779574871,0.82289057970047,0.774276852607727,0.7457799911499023,1.096876859664917,50000 -6587.196877479553,7.43206524848938,73991.6619336605,162902,0,73991.6619336605,0.6278000473976135,1.7115119695663452,10000,80594.78010439873,0.82958984375,0.768218994140625,0.7471399903297424,1.1131218671798706,50000 -6625.662905454636,7.482219219207764,74412.06664991379,163829,0,74412.06664991379,0.624500036239624,1.695138692855835,10000,81053.74798631668,0.8261523246765137,0.7748673558235168,0.747219979763031,1.0991054773330688,50000 -6665.215897798538,7.531557321548462,74831.999625206,164754,0,74831.999625206,0.6292000412940979,1.695865511894226,10000,81513.33116149902,0.8268163800239563,0.7791228890419006,0.7498199939727783,1.0978347063064575,50000 -6702.711914539337,7.5804502964019775,75252.15470719337,165676,0,75252.15470719337,0.6304000020027161,1.6812580823898315,10000,81971.07764077187,0.8290429711341858,0.7599804997444153,0.749239981174469,1.0925363302230835,50000 -6744.158909320831,7.63471245765686,75672.05558896065,166600,0,75672.05558896065,0.6350000500679016,1.662460446357727,10000,82432.52755188942,0.8282421827316284,0.7497026920318604,0.7534399628639221,1.0718141794204712,50000 -6783.7116050720215,7.687152862548828,76092.31971812248,167525,0,76092.31971812248,0.6344000101089478,1.650127410888672,10000,82892.44341945648,0.8338280916213989,0.7270570993423462,0.7521399855613708,1.0613549947738647,50000 -6821.297668457031,7.742358684539795,76512.35372972488,168447,0,76512.35372972488,0.6376000046730042,1.6524871587753296,10000,83350.1652495861,0.8368163704872131,0.721832275390625,0.755840003490448,1.0656956434249878,50000 -6862.45266699791,7.796396732330322,76932.46559858322,169369,0,76932.46559858322,0.6318000555038452,1.6552094221115112,10000,83811.53285884857,0.8389843702316284,0.7196028828620911,0.7570399641990662,1.0582653284072876,50000 -6903.587861776352,7.856116533279419,77352.67522978783,170291,0,77352.67522978783,0.6374000310897827,1.660707950592041,10000,84272.98402452469,0.8350781202316284,0.7398861646652222,0.7601999640464783,1.0654001235961914,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 290748920..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1894 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.3286792,6.907757,,,,,,,,,,,,,, -1,,,0.0010351561941206,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.45515441894531,82.5178644657135,42.45515441894531,40.06261396408081,0.0,0.0 -100,0.3491868,6.906249,,,,,,,,,,,,,, -200,0.4333102,6.8949423,,,,,,,,,,,,,, -300,0.6003257,6.858579,,,,,,,,,,,,,, -400,0.65800726,6.8211856,,,,,,,,,,,,,, -500,0.91314083,6.8351703,,,,,,,,,,,,,, -600,0.971586,6.757582,,,,,,,,,,,,,, -700,1.7296199,6.665935,,,,,,,,,,,,,, -800,0.99669135,6.733124,,,,,,,,,,,,,, -898,,,0.0137695306912064,6.404803276062012,0.0135800000280141,6.413737773895264,50000.0,0.0108000002801418,6.456206321716309,10000.0,462.6178922653198,524.6362130641937,462.6178922653198,61.94122099876404,0.0286300182342529,0.0 -900,1.50529,6.696054,,,,,,,,,,,,,, -1000,1.6505151,6.542445,,,,,,,,,,,,,, -1100,1.7018461,6.765303,,,,,,,,,,,,,, -1200,2.298553,6.4679255,,,,,,,,,,,,,, -1300,1.5578537,6.7429476,,,,,,,,,,,,,, -1400,1.5815386,6.4819508,,,,,,,,,,,,,, -1500,1.5045583,6.494715,,,,,,,,,,,,,, -1600,1.6224985,6.27287,,,,,,,,,,,,,, -1700,2.1008453,6.2321305,,,,,,,,,,,,,, -1800,1.5846894,6.2617702,,,,,,,,,,,,,, -1842,,,0.0413671880960464,5.8381195068359375,0.0377600006759166,5.869866847991943,50000.0,0.0306000020354986,5.982430934906006,10000.0,882.9875965118408,966.72283244133,882.9875965118408,83.57465052604675,0.06095552444458,0.0 -1900,1.7387218,6.542836,,,,,,,,,,,,,, -2000,1.7435842,6.637529,,,,,,,,,,,,,, -2100,1.9850839,6.1445775,,,,,,,,,,,,,, -2200,1.5393717,6.417714,,,,,,,,,,,,,, -2300,1.9603856,6.1295767,,,,,,,,,,,,,, -2400,1.8304434,6.0425954,,,,,,,,,,,,,, -2500,1.2487843,6.551854,,,,,,,,,,,,,, -2600,1.754286,6.364751,,,,,,,,,,,,,, -2700,1.2047023,6.196744,,,,,,,,,,,,,, -2788,,,0.0702148452401161,5.4130401611328125,0.0637599974870681,5.468223571777344,50000.0,0.05180000141263,5.644834041595459,10000.0,1303.2035462856293,1408.8744959831238,1303.2035462856293,105.4310748577118,0.0899255275726318,0.0 -2800,1.9568479,6.033197,,,,,,,,,,,,,, -2900,1.9357738,6.011172,,,,,,,,,,,,,, -3000,1.7649513,5.926674,,,,,,,,,,,,,, -3100,1.3825182,5.8831916,,,,,,,,,,,,,, -3200,1.4217485,6.6312265,,,,,,,,,,,,,, -3300,1.6542344,5.8989735,,,,,,,,,,,,,, -3400,1.2877213,6.151626,,,,,,,,,,,,,, -3500,1.6562256,5.841148,,,,,,,,,,,,,, -3600,1.7618433,5.9018025,,,,,,,,,,,,,, -3700,2.0593758,5.7885227,,,,,,,,,,,,,, -3733,,,0.1027148440480232,5.1039581298828125,0.0944399982690811,5.144573211669922,50000.0,0.0715000033378601,5.371340274810791,10000.0,1723.364368915558,1850.8554532527924,1723.364368915558,127.17176485061646,0.118675947189331,0.0 -3800,1.5571783,5.804906,,,,,,,,,,,,,, -3900,1.7640923,5.7201385,,,,,,,,,,,,,, -4000,1.3575981,6.5382314,,,,,,,,,,,,,, -4100,1.4323733,6.5151486,,,,,,,,,,,,,, -4200,1.495558,5.794457,,,,,,,,,,,,,, -4300,1.2876421,6.5104322,,,,,,,,,,,,,, -4400,1.7652639,5.626124,,,,,,,,,,,,,, -4500,1.2792889,6.289568,,,,,,,,,,,,,, -4600,1.7631204,5.535327,,,,,,,,,,,,,, -4673,,,0.1353906244039535,4.738257884979248,0.1279999911785125,4.791199207305908,50000.0,0.0976000055670738,5.071451663970947,10000.0,2143.616794347763,2292.988788843155,2143.616794347763,148.97430968284607,0.1471822261810302,0.0 -4700,1.6352257,5.6396465,,,,,,,,,,,,,, -4800,1.5143949,5.4461174,,,,,,,,,,,,,, -4900,1.6944826,6.0450945,,,,,,,,,,,,,, -5000,1.5045933,5.928248,,,,,,,,,,,,,, -5100,1.5830668,5.614556,,,,,,,,,,,,,, -5200,1.5618026,5.487426,,,,,,,,,,,,,, -5300,1.795515,5.4159884,,,,,,,,,,,,,, -5400,1.5797744,5.315223,,,,,,,,,,,,,, -5500,1.6490965,5.3554997,,,,,,,,,,,,,, -5600,1.1718752,6.3545046,,,,,,,,,,,,,, -5609,,,0.1873828023672104,4.345278739929199,0.1711599975824356,4.43392276763916,50000.0,0.1264000087976455,4.761846542358398,10000.0,2563.884337425232,2735.237434864044,2563.884337425232,170.8790421485901,0.1756596565246582,0.0 -5700,1.1174432,6.4142623,,,,,,,,,,,,,, -5800,1.4455318,5.394262,,,,,,,,,,,,,, -5900,1.6759751,5.155884,,,,,,,,,,,,,, -6000,1.4524496,5.2333536,,,,,,,,,,,,,, -6100,1.8172513,5.213385,,,,,,,,,,,,,, -6200,1.6185715,5.4579163,,,,,,,,,,,,,, -6300,1.2367314,6.351158,,,,,,,,,,,,,, -6400,1.2135254,5.855071,,,,,,,,,,,,,, -6500,1.9450175,5.2565613,,,,,,,,,,,,,, -6546,,,0.2308789044618606,3.995484590530396,0.211119994521141,4.086441993713379,50000.0,0.1563000082969665,4.464807987213135,10000.0,2984.221193313598,3180.6647255420685,2984.221193313598,195.88995552062988,0.2058713436126709,0.0 -6600,1.6446847,5.20199,,,,,,,,,,,,,, -6700,1.4615428,5.3791456,,,,,,,,,,,,,, -6800,1.4760684,5.2063603,,,,,,,,,,,,,, -6900,1.7748157,4.985867,,,,,,,,,,,,,, -7000,1.5776674,5.021708,,,,,,,,,,,,,, -7100,2.003713,4.9136024,,,,,,,,,,,,,, -7200,1.726365,4.9958324,,,,,,,,,,,,,, -7300,1.1851683,6.3095264,,,,,,,,,,,,,, -7400,1.9657135,4.900971,,,,,,,,,,,,,, -7486,,,0.2678515613079071,3.742283821105957,0.2485999912023544,3.837204694747925,50000.0,0.1887000054121017,4.262177467346191,10000.0,3404.525137901306,3627.2830555439,3404.525137901306,222.11826848983765,0.2382426261901855,0.0 -7500,1.6636218,4.769843,,,,,,,,,,,,,, -7600,1.4910915,5.091313,,,,,,,,,,,,,, -7700,1.6612563,4.818793,,,,,,,,,,,,,, -7800,1.6821806,4.8843694,,,,,,,,,,,,,, -7900,2.0101902,5.166768,,,,,,,,,,,,,, -8000,1.3780807,6.1692266,,,,,,,,,,,,,, -8100,1.6807791,4.856743,,,,,,,,,,,,,, -8200,1.3156024,5.5345707,,,,,,,,,,,,,, -8300,1.786056,4.8267794,,,,,,,,,,,,,, -8400,1.5731655,4.662838,,,,,,,,,,,,,, -8427,,,0.3073437511920929,3.501837491989136,0.2824800014495849,3.621398448944092,50000.0,0.2143000066280365,4.080821514129639,10000.0,3824.85044836998,4078.802620887756,3824.85044836998,253.2263140678405,0.2754535675048828,0.0 -8500,1.822447,4.723553,,,,,,,,,,,,,, -8600,1.4958105,5.4069824,,,,,,,,,,,,,, -8700,1.120398,6.0096045,,,,,,,,,,,,,, -8800,1.7965362,4.5886545,,,,,,,,,,,,,, -8900,1.219901,5.842665,,,,,,,,,,,,,, -9000,1.4874723,4.9332857,,,,,,,,,,,,,, -9100,1.8032911,4.583003,,,,,,,,,,,,,, -9200,1.6840687,4.5208044,,,,,,,,,,,,,, -9300,1.2059276,5.6444235,,,,,,,,,,,,,, -9367,,,0.353339821100235,3.1838603019714355,0.315420001745224,3.375602960586548,50000.0,0.242000013589859,3.8549141883850098,10000.0,4245.004207611084,4524.752751588821,4245.004207611084,278.9440326690674,0.3049573898315429,0.0 -9400,1.7881317,4.460726,,,,,,,,,,,,,, -9500,1.2700764,6.1924615,,,,,,,,,,,,,, -9600,1.4239256,5.949154,,,,,,,,,,,,,, -9700,1.151299,6.0431156,,,,,,,,,,,,,, -9800,1.5978388,4.455246,,,,,,,,,,,,,, -9900,1.1169374,6.097369,,,,,,,,,,,,,, -10000,1.9232117,4.3787475,,,,,,,,,,,,,, -10100,1.0659624,6.0891023,,,,,,,,,,,,,, -10200,1.4592726,5.520472,,,,,,,,,,,,,, -10300,1.1690868,5.419787,,,,,,,,,,,,,, -10306,,,0.3710156083106994,3.0405492782592773,0.3440199792385101,3.174036741256714,50000.0,0.2660000026226043,3.685535430908203,10000.0,4665.314122676849,4973.823203802109,4665.314122676849,307.616507768631,0.3441452980041504,0.0 -10400,1.4706628,5.0122175,,,,,,,,,,,,,, -10500,1.7354456,4.997095,,,,,,,,,,,,,, -10600,1.5101663,4.3852777,,,,,,,,,,,,,, -10700,1.7639055,4.4136524,,,,,,,,,,,,,, -10800,1.4470963,4.3690443,,,,,,,,,,,,,, -10900,1.1394184,5.3061314,,,,,,,,,,,,,, -11000,2.5757735,4.2717867,,,,,,,,,,,,,, -11100,1.5169947,4.30663,,,,,,,,,,,,,, -11200,1.4965892,4.436069,,,,,,,,,,,,,, -11243,,,0.3888476490974426,2.948613166809082,0.3602799773216247,3.093329668045044,50000.0,0.2772000133991241,3.618096351623535,10000.0,5085.516779184341,5421.768236398697,5085.516779184341,335.2799074649811,0.3747379779815674,0.0 -11300,1.3768398,4.3571634,,,,,,,,,,,,,, -11400,1.1241628,5.9358425,,,,,,,,,,,,,, -11500,1.3704191,4.8967814,,,,,,,,,,,,,, -11600,1.8054312,4.2930837,,,,,,,,,,,,,, -11700,1.1390454,5.39995,,,,,,,,,,,,,, -11800,1.6318107,4.2207375,,,,,,,,,,,,,, -11900,1.0817549,5.955562,,,,,,,,,,,,,, -12000,1.6022519,4.0562015,,,,,,,,,,,,,, -12100,1.4730207,4.669614,,,,,,,,,,,,,, -12180,,,0.4136328101158142,2.8385469913482666,0.3738999962806701,3.021648406982422,50000.0,0.2879000008106231,3.568051338195801,10000.0,5505.593339443207,5873.66360616684,5505.593339443207,367.0149767398834,0.4094424247741699,0.0 -12200,1.2776785,4.3042965,,,,,,,,,,,,,, -12300,1.7131479,4.1267858,,,,,,,,,,,,,, -12400,1.5979685,4.2983294,,,,,,,,,,,,,, -12500,1.9700018,4.136613,,,,,,,,,,,,,, -12600,1.3084913,4.7134085,,,,,,,,,,,,,, -12700,1.403967,4.163977,,,,,,,,,,,,,, -12800,1.1788198,5.3377533,,,,,,,,,,,,,, -12900,1.3426033,4.1467257,,,,,,,,,,,,,, -13000,1.390297,4.101584,,,,,,,,,,,,,, -13100,1.3730873,4.2291427,,,,,,,,,,,,,, -13113,,,0.4274218678474426,2.739063739776612,0.3973399996757507,2.8888654708862305,50000.0,0.3093000054359436,3.4295434951782227,10000.0,5925.677830457687,6327.169443368912,5925.677830457687,400.35482573509216,0.441788911819458,0.0 -13200,1.7190506,4.1075983,,,,,,,,,,,,,, -13300,1.4494153,4.0785127,,,,,,,,,,,,,, -13400,0.987859,6.0547748,,,,,,,,,,,,,, -13500,1.3848507,4.1127214,,,,,,,,,,,,,, -13600,1.369407,5.4153085,,,,,,,,,,,,,, -13700,1.0953524,5.3302298,,,,,,,,,,,,,, -13800,1.1432791,4.954447,,,,,,,,,,,,,, -13900,0.91693014,5.7233586,,,,,,,,,,,,,, -14000,1.6359279,4.0746665,,,,,,,,,,,,,, -14049,,,0.4437695145606994,2.6229381561279297,0.4103399813175201,2.784433364868164,50000.0,0.3188000023365021,3.367639303207397,10000.0,6345.869318246841,6784.241494894028,6345.869318246841,437.1570258140564,0.4712910652160644,0.0 -14100,1.4783611,4.0383177,,,,,,,,,,,,,, -14200,1.6438762,4.0945783,,,,,,,,,,,,,, -14300,1.0403429,5.214268,,,,,,,,,,,,,, -14400,1.5043589,4.153425,,,,,,,,,,,,,, -14500,1.4801244,3.9621007,,,,,,,,,,,,,, -14600,1.3483151,4.0782533,,,,,,,,,,,,,, -14700,1.4424864,3.9575582,,,,,,,,,,,,,, -14800,1.3654577,4.0847797,,,,,,,,,,,,,, -14900,1.0136104,6.073407,,,,,,,,,,,,,, -14985,,,0.4652148187160492,2.516399383544922,0.4221400022506714,2.725985288619995,50000.0,0.3266000151634216,3.2856686115264893,10000.0,6765.792047977448,7240.653775691986,6765.792047977448,473.5686767101288,0.5005159378051758,0.0 -15000,1.2760416,4.769415,,,,,,,,,,,,,, -15100,1.4831548,4.051008,,,,,,,,,,,,,, -15200,1.3265115,4.4589024,,,,,,,,,,,,,, -15300,1.5147572,4.3200126,,,,,,,,,,,,,, -15400,1.3316445,3.9160202,,,,,,,,,,,,,, -15500,0.9367556,5.651531,,,,,,,,,,,,,, -15600,1.2364699,4.909948,,,,,,,,,,,,,, -15700,1.2933769,3.9345276,,,,,,,,,,,,,, -15800,0.8705867,5.9018803,,,,,,,,,,,,,, -15900,1.2686561,4.2383437,,,,,,,,,,,,,, -15915,,,0.4888085722923279,2.4176840782165527,0.4336199760437011,2.675327777862549,50000.0,0.3323000073432922,3.2542564868927,10000.0,7186.127139806747,7693.80969619751,7186.127139806747,506.31141448020935,0.5309438705444336,0.0 -16000,1.269305,4.672481,,,,,,,,,,,,,, -16100,0.97042173,5.841693,,,,,,,,,,,,,, -16200,1.5933534,3.9972115,,,,,,,,,,,,,, -16300,1.4166266,4.0616574,,,,,,,,,,,,,, -16400,1.3749677,4.0131106,,,,,,,,,,,,,, -16500,1.2934701,3.8671675,,,,,,,,,,,,,, -16600,1.5004644,4.033512,,,,,,,,,,,,,, -16700,1.3756238,3.92485,,,,,,,,,,,,,, -16800,1.4708693,4.0370655,,,,,,,,,,,,,, -16839,,,0.4630078077316284,2.538491725921631,0.4345999956130981,2.679887056350708,50000.0,0.3374000191688537,3.2630906105041504,10000.0,7606.272791385651,8151.8433039188385,7606.272791385651,544.1211948394775,0.5615954399108887,0.0 -16900,1.072543,5.0264835,,,,,,,,,,,,,, -17000,1.5270644,4.0218177,,,,,,,,,,,,,, -17100,1.4549288,3.8614569,,,,,,,,,,,,,, -17200,1.588584,4.0249553,,,,,,,,,,,,,, -17300,1.3983619,3.8491359,,,,,,,,,,,,,, -17400,1.2354529,4.615018,,,,,,,,,,,,,, -17500,1.6048232,3.986241,,,,,,,,,,,,,, -17600,2.3408928,3.9257984,,,,,,,,,,,,,, -17700,1.3337092,3.92275,,,,,,,,,,,,,, -17769,,,0.4804101586341858,2.426698684692383,0.4465200006961822,2.5907909870147705,50000.0,0.34170001745224,3.178121328353882,10000.0,8026.660900354385,8607.584362506866,8026.660900354385,579.3966374397278,0.5898916721343994,0.0 -17800,1.3102384,4.1092014,,,,,,,,,,,,,, -17900,1.4979391,3.951738,,,,,,,,,,,,,, -18000,1.622771,3.9690351,,,,,,,,,,,,,, -18100,0.9336597,5.850613,,,,,,,,,,,,,, -18200,1.4521263,3.8747501,,,,,,,,,,,,,, -18300,1.5881537,4.0267787,,,,,,,,,,,,,, -18400,0.86796445,5.84725,,,,,,,,,,,,,, -18500,1.3443651,3.936944,,,,,,,,,,,,,, -18600,1.3627561,3.8360248,,,,,,,,,,,,,, -18700,1.3260341,3.9198298,,,,,,,,,,,,,, -18701,,,0.5080273151397705,2.2943875789642334,0.4573400020599365,2.531662702560425,50000.0,0.3521000146865845,3.1127684116363525,10000.0,8446.603140592575,9062.087691783903,8446.603140592575,613.8807055950165,0.6182742118835449,0.0 -18800,1.398272,3.837881,,,,,,,,,,,,,, -18900,1.3866965,3.8179865,,,,,,,,,,,,,, -19000,1.5315527,3.8139946,,,,,,,,,,,,,, -19100,1.267206,5.3257933,,,,,,,,,,,,,, -19200,1.3605851,4.0234966,,,,,,,,,,,,,, -19300,1.26873,3.8810337,,,,,,,,,,,,,, -19400,1.212848,4.1249995,,,,,,,,,,,,,, -19500,1.5605952,3.8053265,,,,,,,,,,,,,, -19600,1.310194,3.8887925,,,,,,,,,,,,,, -19630,,,0.4916210770606994,2.3779008388519287,0.4596799910068512,2.5362136363983154,50000.0,0.3632000088691711,3.1231794357299805,10000.0,8866.9126598835,9519.569938898088,8866.9126598835,650.9721372127533,0.6507935523986816,0.0 -19700,1.314166,3.80448,,,,,,,,,,,,,, -19800,1.1860144,4.078653,,,,,,,,,,,,,, -19900,1.2584009,4.076191,,,,,,,,,,,,,, -20000,1.3074852,3.8598447,,,,,,,,,,,,,, -20100,1.1571399,4.7334824,,,,,,,,,,,,,, -20200,0.8228751,5.8699207,,,,,,,,,,,,,, -20300,1.0540975,4.4773326,,,,,,,,,,,,,, -20400,0.9427068,5.796263,,,,,,,,,,,,,, -20500,0.9039198,5.8265076,,,,,,,,,,,,,, -20558,,,0.5131444931030273,2.2290053367614746,0.4773799777030945,2.398954629898072,50000.0,0.3674000203609466,2.994854211807251,10000.0,9286.968410253525,9975.836344718931,9286.968410253525,687.1052870750427,0.6784398555755615,0.0 -20600,1.3948442,3.873175,,,,,,,,,,,,,, -20700,1.3984745,3.771641,,,,,,,,,,,,,, -20800,1.3437271,5.328828,,,,,,,,,,,,,, -20900,1.1299556,4.4878283,,,,,,,,,,,,,, -21000,0.9522876,5.244108,,,,,,,,,,,,,, -21100,1.2465779,3.8262756,,,,,,,,,,,,,, -21200,1.4406425,3.9502351,,,,,,,,,,,,,, -21300,0.8010691,5.817972,,,,,,,,,,,,,, -21400,1.1822851,4.3947973,,,,,,,,,,,,,, -21482,,,0.5301367044448853,2.176165580749512,0.4836399853229522,2.393795728683472,50000.0,0.3772000074386596,3.007101535797119,10000.0,9707.312128067017,10429.00049996376,9707.312128067017,719.8414082527161,0.7146031856536865,0.0 -21500,0.87535334,5.756853,,,,,,,,,,,,,, -21600,0.8535374,5.687702,,,,,,,,,,,,,, -21700,0.8319389,5.5447497,,,,,,,,,,,,,, -21800,1.2513765,3.7932487,,,,,,,,,,,,,, -21900,1.5444651,3.8573382,,,,,,,,,,,,,, -22000,1.3049692,3.767013,,,,,,,,,,,,,, -22100,1.1879253,3.901008,,,,,,,,,,,,,, -22200,1.2393595,4.6251454,,,,,,,,,,,,,, -22300,1.2684112,3.7156854,,,,,,,,,,,,,, -22340,,,0.5262304544448853,2.1738510131835938,0.4916400015354156,2.3356337547302246,50000.0,0.3882000148296356,2.9433727264404297,10000.0,10127.587964057922,10885.880244970322,10127.587964057922,756.3632752895355,0.7525720596313477,0.0 -22400,0.9638263,5.4144883,,,,,,,,,,,,,, -22500,1.1583817,4.3250537,,,,,,,,,,,,,, -22600,1.1178197,4.109894,,,,,,,,,,,,,, -22700,1.5758026,3.7551813,,,,,,,,,,,,,, -22800,1.207914,4.1673703,,,,,,,,,,,,,, -22900,1.401479,3.9274433,,,,,,,,,,,,,, -23000,1.3121428,3.781129,,,,,,,,,,,,,, -23100,1.40249,3.6953683,,,,,,,,,,,,,, -23200,1.3081667,3.7073505,,,,,,,,,,,,,, -23269,,,0.5324609279632568,2.129218816757202,0.4975399971008301,2.299388885498047,50000.0,0.3884000182151794,2.912286758422852,10000.0,10547.819781303406,11342.201602220535,10547.819781303406,792.3766157627106,0.7804503440856934,0.0 -23300,1.0891423,4.730091,,,,,,,,,,,,,, -23400,1.3348937,3.7373204,,,,,,,,,,,,,, -23500,0.9685274,5.5561767,,,,,,,,,,,,,, -23600,1.3625383,4.2875156,,,,,,,,,,,,,, -23700,1.4301455,3.6164434,,,,,,,,,,,,,, -23800,1.2544768,4.023901,,,,,,,,,,,,,, -23900,1.3518035,3.7960129,,,,,,,,,,,,,, -24000,1.2004833,4.69696,,,,,,,,,,,,,, -24100,1.2310792,3.9983516,,,,,,,,,,,,,, -24199,,,0.5442578196525574,2.104851245880127,0.498339980840683,2.314196348190308,50000.0,0.3945000171661377,2.9152963161468506,10000.0,10967.767451047896,11799.014899253843,10967.767451047896,829.1624467372894,0.8119678497314453,0.0 -24200,1.1120588,4.4434547,,,,,,,,,,,,,, -24300,1.4269603,3.7041686,,,,,,,,,,,,,, -24400,1.0782536,5.1436715,,,,,,,,,,,,,, -24500,1.2571359,3.764905,,,,,,,,,,,,,, -24600,1.2564093,3.6825135,,,,,,,,,,,,,, -24700,1.1706293,4.7880793,,,,,,,,,,,,,, -24800,0.924462,5.679656,,,,,,,,,,,,,, -24900,1.1539962,5.455414,,,,,,,,,,,,,, -25000,1.0336385,5.838314,,,,,,,,,,,,,, -25100,1.2332542,4.1499963,,,,,,,,,,,,,, -25128,,,0.5494335889816284,2.1070497035980225,0.5054199695587158,2.3007595539093018,50000.0,0.3976000249385834,2.9255411624908447,10000.0,11388.04214978218,12255.743125915527,11388.04214978218,865.5352845191956,0.8451135158538818,0.0 -25200,1.0361422,4.7452908,,,,,,,,,,,,,, -25300,0.9812696,5.45128,,,,,,,,,,,,,, -25400,1.240541,3.623427,,,,,,,,,,,,,, -25500,1.3238655,3.6927388,,,,,,,,,,,,,, -25600,1.0493083,4.895126,,,,,,,,,,,,,, -25700,1.2916104,3.887684,,,,,,,,,,,,,, -25800,1.4004033,3.6676445,,,,,,,,,,,,,, -25900,1.2581016,3.8625634,,,,,,,,,,,,,, -26000,1.3976403,3.5557528,,,,,,,,,,,,,, -26047,,,0.5480077862739563,2.075366735458374,0.5156199932098389,2.2329912185668945,50000.0,0.4003000259399414,2.8565263748168945,10000.0,11808.012126922607,12711.5367333889,11808.012126922607,901.2724430561066,0.883490800857544,0.0 -26100,1.0986378,5.7307673,,,,,,,,,,,,,, -26200,1.2770381,3.9214926,,,,,,,,,,,,,, -26300,1.4793637,3.5824983,,,,,,,,,,,,,, -26400,1.1223314,4.368575,,,,,,,,,,,,,, -26500,1.4929482,3.6031866,,,,,,,,,,,,,, -26600,1.4287294,3.6496344,,,,,,,,,,,,,, -26700,1.3357087,3.7388575,,,,,,,,,,,,,, -26800,1.0041838,5.4484415,,,,,,,,,,,,,, -26900,1.0742602,4.839417,,,,,,,,,,,,,, -26974,,,0.5512499809265137,2.059241533279419,0.510919988155365,2.25130581855774,50000.0,0.3957000076770782,2.879225015640259,10000.0,12228.003628730774,13167.093488931656,12228.003628730774,936.7529728412628,0.9200353622436525,0.0 -27000,1.3955276,3.5024452,,,,,,,,,,,,,, -27100,1.1810149,3.8511376,,,,,,,,,,,,,, -27200,1.4264784,3.7275157,,,,,,,,,,,,,, -27300,1.34899,3.5234554,,,,,,,,,,,,,, -27400,1.205274,4.907874,,,,,,,,,,,,,, -27500,1.1270579,3.7349288,,,,,,,,,,,,,, -27600,1.016684,5.6041803,,,,,,,,,,,,,, -27700,1.0322405,4.8121967,,,,,,,,,,,,,, -27800,0.99328077,5.737479,,,,,,,,,,,,,, -27900,,,0.5764257907867432,1.9356085062026973,0.5181599855422974,2.198256015777588,50000.0,0.4039000272750854,2.825509786605835,10000.0,12648.078053474426,13623.059293746948,12648.078053474426,972.5665490627288,0.9505248069763184,0.0 -27900,1.2960627,3.6692953,,,,,,,,,,,,,, -28000,1.4802665,3.563031,,,,,,,,,,,,,, -28100,1.3728085,3.751822,,,,,,,,,,,,,, -28200,1.2377893,4.379945,,,,,,,,,,,,,, -28300,1.5140281,3.614008,,,,,,,,,,,,,, -28400,1.2994537,3.5550923,,,,,,,,,,,,,, -28500,1.0876354,3.9869304,,,,,,,,,,,,,, -28600,1.1931688,3.761272,,,,,,,,,,,,,, -28700,1.3101733,4.806696,,,,,,,,,,,,,, -28800,1.0730133,4.9760947,,,,,,,,,,,,,, -28825,,,0.5618749856948853,1.990774154663086,0.5294600129127502,2.153543472290039,50000.0,0.4154000282287597,2.7733428478240967,10000.0,13068.219451665878,14078.781080007551,13068.219451665878,1008.0687556266784,0.9811530113220216,0.0 -28900,1.4997779,4.0696898,,,,,,,,,,,,,, -29000,1.3145499,3.6486955,,,,,,,,,,,,,, -29100,1.0623331,4.834098,,,,,,,,,,,,,, -29200,1.3091551,4.030813,,,,,,,,,,,,,, -29300,1.3382776,3.8333497,,,,,,,,,,,,,, -29400,1.2729064,3.7033505,,,,,,,,,,,,,, -29500,1.2692733,3.680381,,,,,,,,,,,,,, -29600,1.3578877,3.5426092,,,,,,,,,,,,,, -29700,1.3643844,3.5313165,,,,,,,,,,,,,, -29752,,,0.5666210651397705,1.9570411443710327,0.5254200100898743,2.154048442840576,50000.0,0.414900004863739,2.78692626953125,10000.0,13488.522999286652,14534.944336652756,13488.522999286652,1043.8474705219269,1.0138554573059082,0.0 -29800,1.4275908,3.547422,,,,,,,,,,,,,, -29900,1.3987334,3.5339448,,,,,,,,,,,,,, -30000,1.5866907,3.5430937,,,,,,,,,,,,,, -30100,1.4256448,3.9714823,,,,,,,,,,,,,, -30200,1.1672124,4.391115,,,,,,,,,,,,,, -30300,1.3340737,3.6177468,,,,,,,,,,,,,, -30400,1.0924978,5.415821,,,,,,,,,,,,,, -30500,1.3028544,3.876987,,,,,,,,,,,,,, -30600,1.6307116,3.585103,,,,,,,,,,,,,, -30682,,,0.5687890648841858,2.034207344055176,0.5227400064468384,2.254662036895752,50000.0,0.4161000251770019,2.84878921508789,10000.0,13908.572232484818,14990.768680334091,13908.572232484818,1079.546533346176,1.0415377616882324,0.0 -30700,1.1708281,4.135778,,,,,,,,,,,,,, -30800,1.1581283,5.139643,,,,,,,,,,,,,, -30900,1.2309612,4.32562,,,,,,,,,,,,,, -31000,1.4240328,3.4666884,,,,,,,,,,,,,, -31100,1.4491062,3.6077464,,,,,,,,,,,,,, -31200,1.1068927,4.8677125,,,,,,,,,,,,,, -31300,1.1434181,4.7075424,,,,,,,,,,,,,, -31400,1.3704551,3.5235221,,,,,,,,,,,,,, -31500,1.1950032,4.053247,,,,,,,,,,,,,, -31600,1.4105023,3.4845405,,,,,,,,,,,,,, -31610,,,0.5677539110183716,1.991437315940857,0.531279981136322,2.1594009399414062,50000.0,0.4204000234603882,2.7725627422332764,10000.0,14328.762746572496,15446.033729076384,14328.762746572496,1114.5431122779846,1.0704760551452637,0.0 -31700,1.3027003,3.7004995,,,,,,,,,,,,,, -31800,1.335704,3.6314197,,,,,,,,,,,,,, -31900,1.3068855,3.6307847,,,,,,,,,,,,,, -32000,1.3204551,3.5713387,,,,,,,,,,,,,, -32100,1.1876396,5.6974936,,,,,,,,,,,,,, -32200,1.237105,5.720601,,,,,,,,,,,,,, -32300,1.2971659,3.7642844,,,,,,,,,,,,,, -32400,1.5073385,3.6589372,,,,,,,,,,,,,, -32500,0.9501278,5.582318,,,,,,,,,,,,,, -32539,,,0.5756444931030273,1.9299191236495967,0.5379799604415894,2.113074541091919,50000.0,0.4249000251293182,2.7316830158233643,10000.0,14749.06371474266,15901.75974059105,14749.06371474266,1149.8860309123993,1.1039931774139404,0.0 -32600,1.4098198,3.4997566,,,,,,,,,,,,,, -32700,1.381813,3.5106044,,,,,,,,,,,,,, -32800,1.3796653,3.3909166,,,,,,,,,,,,,, -32900,1.385864,3.5513403,,,,,,,,,,,,,, -33000,1.054408,4.6468725,,,,,,,,,,,,,, -33100,1.4175429,3.445407,,,,,,,,,,,,,, -33200,1.4094354,3.5756228,,,,,,,,,,,,,, -33300,1.4839922,3.5319004,,,,,,,,,,,,,, -33400,1.2877681,4.7922306,,,,,,,,,,,,,, -33466,,,0.5826953053474426,1.9023802280426023,0.5382999777793884,2.1023356914520264,50000.0,0.4214000105857849,2.734881401062012,10000.0,15168.989698171616,16358.529019117355,15168.989698171616,1186.6391808986664,1.1376256942749023,0.0 -33500,1.0296485,4.956491,,,,,,,,,,,,,, -33600,1.4240904,3.53367,,,,,,,,,,,,,, -33700,1.3231711,3.560034,,,,,,,,,,,,,, -33800,1.3234195,3.6174612,,,,,,,,,,,,,, -33900,1.3713292,3.5469582,,,,,,,,,,,,,, -34000,1.4852079,3.5674746,,,,,,,,,,,,,, -34100,1.5041853,3.5651352,,,,,,,,,,,,,, -34200,1.3926188,3.6544201,,,,,,,,,,,,,, -34300,1.1392933,5.545694,,,,,,,,,,,,,, -34393,,,0.5888671875,1.8822269439697263,0.5423799753189087,2.086999654769897,50000.0,0.431300014257431,2.7067244052886963,10000.0,15589.234191179276,16815.190942525864,15589.234191179276,1222.977279663086,1.1691491603851318,0.0 -34400,1.4094145,3.5529306,,,,,,,,,,,,,, -34500,1.6672454,3.5642061,,,,,,,,,,,,,, -34600,1.486484,3.3528228,,,,,,,,,,,,,, -34700,1.4098763,3.4547415,,,,,,,,,,,,,, -34800,1.0483737,4.6302967,,,,,,,,,,,,,, -34900,1.2222357,4.5634646,,,,,,,,,,,,,, -35000,1.1251229,5.5297737,,,,,,,,,,,,,, -35100,1.0609283,5.2334857,,,,,,,,,,,,,, -35200,1.4073862,3.4915888,,,,,,,,,,,,,, -35300,1.3147949,3.8102832,,,,,,,,,,,,,, -35320,,,0.5822656154632568,1.9264488220214844,0.5397199988365173,2.1193225383758545,50000.0,0.4256000220775604,2.749296188354492,10000.0,16009.42619729042,17271.53852891922,16009.42619729042,1259.0506103038788,1.20351243019104,0.0 -35400,1.4339917,3.445674,,,,,,,,,,,,,, -35500,1.4431401,3.4275768,,,,,,,,,,,,,, -35600,1.2241207,5.13134,,,,,,,,,,,,,, -35700,1.4133234,3.5792542,,,,,,,,,,,,,, -35800,1.1313184,5.6440353,,,,,,,,,,,,,, -35900,1.5971643,3.6223438,,,,,,,,,,,,,, -36000,1.498956,3.4025931,,,,,,,,,,,,,, -36100,1.491033,3.4155378,,,,,,,,,,,,,, -36200,1.3323488,3.452718,,,,,,,,,,,,,, -36250,,,0.5885351300239563,1.876665472984314,0.5481399893760681,2.063407182693481,50000.0,0.4310000240802765,2.680048704147339,10000.0,16429.708421945572,17727.43417429924,16429.708421945572,1294.5700707435608,1.2412841320037842,0.0 -36300,1.4129547,3.486539,,,,,,,,,,,,,, -36400,1.3613743,3.7768083,,,,,,,,,,,,,, -36500,1.3739685,4.368074,,,,,,,,,,,,,, -36600,1.3958303,3.4424548,,,,,,,,,,,,,, -36700,1.4504461,3.4414213,,,,,,,,,,,,,, -36800,1.52811,3.5283608,,,,,,,,,,,,,, -36900,1.1915034,5.0446615,,,,,,,,,,,,,, -37000,1.5818248,3.4972045,,,,,,,,,,,,,, -37100,1.4013922,3.5104053,,,,,,,,,,,,,, -37179,,,0.6117382645606995,1.750945806503296,0.5479400157928467,2.041366577148437,50000.0,0.4376000165939331,2.6541614532470703,10000.0,16850.04904961586,18184.6283288002,16850.04904961586,1331.338150024414,1.278425931930542,0.0 -37200,1.3204877,3.8116543,,,,,,,,,,,,,, -37300,1.3549953,3.3088236,,,,,,,,,,,,,, -37400,1.4474669,3.315754,,,,,,,,,,,,,, -37500,1.4669327,3.6108582,,,,,,,,,,,,,, -37600,1.7295442,3.4950743,,,,,,,,,,,,,, -37700,1.2904123,3.7021809,,,,,,,,,,,,,, -37800,1.2206414,4.5888076,,,,,,,,,,,,,, -37900,1.1725017,4.701391,,,,,,,,,,,,,, -38000,1.4823895,3.3722918,,,,,,,,,,,,,, -38100,1.0447302,5.6713114,,,,,,,,,,,,,, -38104,,,0.5925390720367432,1.8541145324707031,0.5517799854278564,2.030083656311035,50000.0,0.4384000301361084,2.6467978954315186,10000.0,17269.996902942657,18640.95697760582,17269.996902942657,1367.6382720470428,1.311347246170044,0.0 -38200,1.4221218,3.9952497,,,,,,,,,,,,,, -38300,1.6298459,3.5403616,,,,,,,,,,,,,, -38400,1.2150303,4.345988,,,,,,,,,,,,,, -38500,1.3623418,3.4216561,,,,,,,,,,,,,, -38600,1.5685619,3.4836774,,,,,,,,,,,,,, -38700,1.2558162,4.4658437,,,,,,,,,,,,,, -38800,1.5014392,3.44521,,,,,,,,,,,,,, -38900,1.6838489,3.369714,,,,,,,,,,,,,, -39000,1.3772153,3.8670235,,,,,,,,,,,,,, -39030,,,0.5927343368530273,1.835362672805786,0.5510199666023254,2.033008813858032,50000.0,0.4350000321865082,2.658357620239258,10000.0,17689.955061912537,19097.34910964966,17689.955061912537,1403.9910144805908,1.3433549404144287,0.0 -39100,1.4605441,3.6529012,,,,,,,,,,,,,, -39200,1.3375614,3.8738966,,,,,,,,,,,,,, -39300,1.416002,4.303237,,,,,,,,,,,,,, -39400,1.1644193,4.2343025,,,,,,,,,,,,,, -39500,1.5009965,3.4904516,,,,,,,,,,,,,, -39600,1.4635339,3.400538,,,,,,,,,,,,,, -39700,1.6271316,3.597887,,,,,,,,,,,,,, -39800,1.5088947,3.3872788,,,,,,,,,,,,,, -39900,1.3777103,3.4292839,,,,,,,,,,,,,, -39957,,,0.5999609231948853,1.784259557723999,0.5545600056648254,2.00947380065918,50000.0,0.4463000297546386,2.623008251190185,10000.0,18110.31490755081,19554.41698360443,18110.31490755081,1440.6161260604858,1.377694845199585,0.0 -40000,1.0730422,5.101966,,,,,,,,,,,,,, -40100,1.6068758,3.4180562,,,,,,,,,,,,,, -40200,1.6714971,3.4362583,,,,,,,,,,,,,, -40300,1.3407238,3.3550897,,,,,,,,,,,,,, -40400,1.3378165,3.3511055,,,,,,,,,,,,,, -40500,1.515975,3.3625336,,,,,,,,,,,,,, -40600,1.2029754,4.1285925,,,,,,,,,,,,,, -40700,1.5182973,3.464805,,,,,,,,,,,,,, -40800,1.3751297,3.636691,,,,,,,,,,,,,, -40882,,,0.5937694907188416,1.8420950174331665,0.5575799942016602,2.006721258163452,50000.0,0.4439000189304352,2.64489483833313,10000.0,18530.302005290985,20009.39934825897,18530.302005290985,1475.5338282585144,1.4079561233520508,0.0 -40900,1.0028774,5.5202208,,,,,,,,,,,,,, -41000,1.4492271,3.881847,,,,,,,,,,,,,, -41100,1.3786874,3.5147743,,,,,,,,,,,,,, -41200,1.4446857,3.4572637,,,,,,,,,,,,,, -41300,1.5904055,3.7406287,,,,,,,,,,,,,, -41400,1.3582181,3.3406231,,,,,,,,,,,,,, -41500,1.1355387,4.9287205,,,,,,,,,,,,,, -41600,1.1924856,5.049332,,,,,,,,,,,,,, -41700,1.1719286,5.5986614,,,,,,,,,,,,,, -41800,1.6035287,3.3809965,,,,,,,,,,,,,, -41810,,,0.6000195145606995,1.8037185668945312,0.5605800151824951,1.9919346570968628,50000.0,0.4422000348567962,2.627874851226806,10000.0,18950.62165570259,20465.520961999893,18950.62165570259,1511.2537944316864,1.4423680305480957,0.0 -41900,1.0182881,5.1523914,,,,,,,,,,,,,, -42000,1.3441006,3.4273973,,,,,,,,,,,,,, -42100,1.2304366,4.7307596,,,,,,,,,,,,,, -42200,1.5632782,3.4334414,,,,,,,,,,,,,, -42300,1.4273881,3.4069958,,,,,,,,,,,,,, -42400,1.2888194,3.8909373,,,,,,,,,,,,,, -42500,1.4368839,3.5163794,,,,,,,,,,,,,, -42600,1.1803032,4.4303875,,,,,,,,,,,,,, -42700,1.511061,3.3925242,,,,,,,,,,,,,, -42736,,,0.6067187190055847,1.7677797079086304,0.5633800029754639,1.979863405227661,50000.0,0.446800023317337,2.621208667755127,10000.0,19370.62746167183,20921.324761152267,19370.62746167183,1546.9716968536377,1.474189043045044,0.0 -42800,1.3552668,3.3522744,,,,,,,,,,,,,, -42900,1.8469342,3.4318068,,,,,,,,,,,,,, -43000,1.5071441,3.5479145,,,,,,,,,,,,,, -43100,1.5443445,3.4197705,,,,,,,,,,,,,, -43200,1.4344764,3.4256825,,,,,,,,,,,,,, -43300,1.4821064,3.6636388,,,,,,,,,,,,,, -43400,1.3460901,3.3817456,,,,,,,,,,,,,, -43500,1.4897463,3.9553206,,,,,,,,,,,,,, -43600,1.3689321,3.509856,,,,,,,,,,,,,, -43664,,,0.6116601228713989,1.7431319952011108,0.5555999875068665,1.9900355339050293,50000.0,0.4404000341892242,2.6423449516296387,10000.0,19790.838992118835,21376.125276327133,19790.838992118835,1581.4821391105652,1.5045185089111328,0.0 -43700,1.2632706,3.99385,,,,,,,,,,,,,, -43800,1.2191658,4.627868,,,,,,,,,,,,,, -43900,1.2351601,5.0628448,,,,,,,,,,,,,, -44000,1.5510157,3.3635764,,,,,,,,,,,,,, -44100,1.3732551,3.391531,,,,,,,,,,,,,, -44200,1.4103644,3.612073,,,,,,,,,,,,,, -44300,1.4524565,3.4849007,,,,,,,,,,,,,, -44400,1.4992036,3.5250022,,,,,,,,,,,,,, -44500,1.4719815,3.4707732,,,,,,,,,,,,,, -44589,,,0.6066015362739563,1.761168122291565,0.562720000743866,1.9542808532714844,50000.0,0.4478000104427337,2.59022855758667,10000.0,20211.031491041183,21832.046981096268,20211.031491041183,1617.130084991455,1.5374326705932615,0.0 -44600,1.6235795,3.6080544,,,,,,,,,,,,,, -44700,1.5969923,3.4222076,,,,,,,,,,,,,, -44800,1.2940797,3.7808135,,,,,,,,,,,,,, -44900,1.4220896,3.4991007,,,,,,,,,,,,,, -45000,1.4112445,3.66885,,,,,,,,,,,,,, -45100,1.4348093,3.4159079,,,,,,,,,,,,,, -45200,1.3204412,5.4587336,,,,,,,,,,,,,, -45300,1.6477866,3.387185,,,,,,,,,,,,,, -45400,1.5797967,3.4175653,,,,,,,,,,,,,, -45500,1.4130517,3.3615055,,,,,,,,,,,,,, -45511,,,0.6109179258346558,1.7445123195648191,0.5663999915122986,1.957970380783081,50000.0,0.4520000219345093,2.569453239440918,10000.0,20631.116145849228,22287.97873067856,20631.116145849228,1652.8864409923551,1.57989239692688,0.0 -45600,1.3717997,3.9535227,,,,,,,,,,,,,, -45700,1.3906595,3.4318476,,,,,,,,,,,,,, -45800,1.4485418,3.4744287,,,,,,,,,,,,,, -45900,1.4883512,3.9411259,,,,,,,,,,,,,, -46000,1.3102393,3.721053,,,,,,,,,,,,,, -46100,1.40232,3.8270173,,,,,,,,,,,,,, -46200,1.1154197,4.525117,,,,,,,,,,,,,, -46300,1.4205239,3.4116511,,,,,,,,,,,,,, -46400,1.2727975,4.3916016,,,,,,,,,,,,,, -46439,,,0.6299023032188416,1.6438137292861938,0.5723599791526794,1.911130428314209,50000.0,0.4525000154972076,2.548551559448242,10000.0,21051.443721055984,22743.75170326233,21051.443721055984,1688.2459979057312,1.6172716617584229,0.0 -46500,1.3943185,3.8728275,,,,,,,,,,,,,, -46600,1.6858964,3.3739562,,,,,,,,,,,,,, -46700,1.5005594,3.4087505,,,,,,,,,,,,,, -46800,1.4778693,3.4894109,,,,,,,,,,,,,, -46900,1.4539553,5.2187033,,,,,,,,,,,,,, -47000,1.3939786,3.320624,,,,,,,,,,,,,, -47100,1.4917542,3.3132644,,,,,,,,,,,,,, -47200,1.147506,5.1717005,,,,,,,,,,,,,, -47300,1.5186753,3.37987,,,,,,,,,,,,,, -47366,,,0.6089648008346558,1.7728519439697266,0.5699399709701538,1.960967898368836,50000.0,0.4547000229358673,2.593113422393799,10000.0,21471.823915719982,23200.93217253685,21471.823915719982,1724.9623546600342,1.65169358253479,0.0 -47400,1.3676859,4.0095224,,,,,,,,,,,,,, -47500,1.605692,3.4495914,,,,,,,,,,,,,, -47600,1.4786499,3.3972025,,,,,,,,,,,,,, -47700,1.496633,3.3159142,,,,,,,,,,,,,, -47800,1.1628985,4.842612,,,,,,,,,,,,,, -47900,1.7221731,3.383319,,,,,,,,,,,,,, -48000,1.1885072,4.276,,,,,,,,,,,,,, -48100,1.2185616,4.8049016,,,,,,,,,,,,,, -48200,1.5511658,3.4395165,,,,,,,,,,,,,, -48292,,,0.6068750023841858,1.7780888080596924,0.5670199990272522,1.9639018774032595,50000.0,0.4553000330924988,2.6022982597351074,10000.0,21892.04855132103,23658.27480411529,21892.04855132103,1761.9988887310028,1.6848242282867432,0.0 -48300,1.3949305,3.5111215,,,,,,,,,,,,,, -48400,1.3729805,3.3486054,,,,,,,,,,,,,, -48500,1.524731,3.4194074,,,,,,,,,,,,,, -48600,1.4834974,3.2388313,,,,,,,,,,,,,, -48700,1.5213385,3.343448,,,,,,,,,,,,,, -48800,1.4359548,3.2730012,,,,,,,,,,,,,, -48900,1.5484226,3.283302,,,,,,,,,,,,,, -49000,1.3048738,5.1611614,,,,,,,,,,,,,, -49100,1.1218495,5.591569,,,,,,,,,,,,,, -49200,2.090178,3.358727,,,,,,,,,,,,,, -49216,,,0.6234374642372131,1.6764730215072632,0.5714200139045715,1.9154047966003416,50000.0,0.4580000340938568,2.5392799377441406,10000.0,22312.30050706864,24113.78163957596,22312.30050706864,1797.1702196598053,1.7208774089813232,0.0 -49300,1.3403322,3.6986482,,,,,,,,,,,,,, -49400,1.7064992,3.383956,,,,,,,,,,,,,, -49500,1.4331785,3.7838712,,,,,,,,,,,,,, -49600,1.5720694,3.2988105,,,,,,,,,,,,,, -49700,1.5393023,3.3596234,,,,,,,,,,,,,, -49800,1.2285556,5.5273094,,,,,,,,,,,,,, -49900,1.195889,4.9067583,,,,,,,,,,,,,, -50000,1.6748378,3.3383126,,,,,,,,,,,,,, -50100,1.1990845,4.5231123,,,,,,,,,,,,,, -50139,,,0.6090039014816284,1.7651121616363523,0.569819986820221,1.935936689376831,50000.0,0.455700010061264,2.5697062015533447,10000.0,22732.495416641235,24568.618636846542,22732.495416641235,1831.73202419281,1.7526824474334717,0.0 -50200,1.4063075,3.4829779,,,,,,,,,,,,,, -50300,1.2194809,4.5269814,,,,,,,,,,,,,, -50400,1.4403117,3.9052253,,,,,,,,,,,,,, -50500,1.440061,3.343279,,,,,,,,,,,,,, -50600,1.0412238,5.4661818,,,,,,,,,,,,,, -50700,1.4281343,3.9262362,,,,,,,,,,,,,, -50800,1.3650061,3.2885463,,,,,,,,,,,,,, -50900,1.4667852,3.3682203,,,,,,,,,,,,,, -51000,1.6116537,3.8679051,,,,,,,,,,,,,, -51065,,,0.6101757884025574,1.7590184211730957,0.5707600116729736,1.938204765319824,50000.0,0.45210000872612,2.5792510509490967,10000.0,23152.728211402893,25025.574861764908,23152.728211402893,1868.3653919696808,1.794832706451416,0.0 -51100,1.4345891,3.4030921,,,,,,,,,,,,,, -51200,1.5823152,3.279744,,,,,,,,,,,,,, -51300,1.548411,3.338859,,,,,,,,,,,,,, -51400,1.5637871,3.7211804,,,,,,,,,,,,,, -51500,1.5712808,3.4537141,,,,,,,,,,,,,, -51600,1.2809341,4.8069086,,,,,,,,,,,,,, -51700,1.6183051,3.3327932,,,,,,,,,,,,,, -51800,1.3419788,4.14547,,,,,,,,,,,,,, -51900,1.3113419,3.9990733,,,,,,,,,,,,,, -51992,,,0.6263476610183716,1.6484354734420776,0.5796999931335449,1.865336775779724,50000.0,0.4656000137329101,2.487163782119751,10000.0,23573.097232103348,25483.33177614212,23573.097232103348,1905.6698813438416,1.828599452972412,0.0 -52000,1.5094332,3.7004614,,,,,,,,,,,,,, -52100,1.2197092,4.3223057,,,,,,,,,,,,,, -52200,1.6301633,3.499522,,,,,,,,,,,,,, -52300,1.4159987,3.636109,,,,,,,,,,,,,, -52400,1.4634871,3.4666467,,,,,,,,,,,,,, -52500,1.5686375,3.330757,,,,,,,,,,,,,, -52600,1.4345968,3.8646355,,,,,,,,,,,,,, -52700,1.5149012,3.318666,,,,,,,,,,,,,, -52800,1.5117047,3.393682,,,,,,,,,,,,,, -52900,1.2882208,4.479917,,,,,,,,,,,,,, -52919,,,0.6298437118530273,1.6609742641448977,0.5778399705886841,1.9005619287490845,50000.0,0.4649000167846679,2.5275325775146484,10000.0,23993.060639619827,25941.21884179116,23993.060639619827,1943.5027811527248,1.8709113597869875,0.0 -53000,1.3090826,5.3077,,,,,,,,,,,,,, -53100,1.7054398,3.2987769,,,,,,,,,,,,,, -53200,1.6671844,3.352186,,,,,,,,,,,,,, -53300,1.3525172,5.429901,,,,,,,,,,,,,, -53400,1.4884073,3.4382765,,,,,,,,,,,,,, -53500,1.4344174,3.6175115,,,,,,,,,,,,,, -53600,1.7220294,3.3064294,,,,,,,,,,,,,, -53700,1.3390882,3.8679416,,,,,,,,,,,,,, -53800,1.3391145,3.956828,,,,,,,,,,,,,, -53845,,,0.6171679496765137,1.7618088722229004,0.5772799849510193,1.945756793022156,50000.0,0.457800030708313,2.5673742294311523,10000.0,24413.19846820832,26397.88577604294,24413.19846820832,1979.9512028694155,1.9040093421936035,0.0 -53900,1.2533616,4.7509537,,,,,,,,,,,,,, -54000,1.1088799,4.6682796,,,,,,,,,,,,,, -54100,1.2529287,3.9990144,,,,,,,,,,,,,, -54200,1.4616286,3.3328538,,,,,,,,,,,,,, -54300,1.3797231,4.2273464,,,,,,,,,,,,,, -54400,1.5042139,3.319408,,,,,,,,,,,,,, -54500,1.5503343,3.2816691,,,,,,,,,,,,,, -54600,1.4598488,3.8176074,,,,,,,,,,,,,, -54700,1.5836177,3.376825,,,,,,,,,,,,,, -54771,,,0.6280664205551147,1.6353827714920044,0.581820011138916,1.8503868579864504,50000.0,0.4653000235557556,2.465291976928711,10000.0,24833.537580490112,26855.23208403588,24833.537580490112,2016.8710873126984,1.9429829120635984,0.0 -54800,1.1628165,4.713432,,,,,,,,,,,,,, -54900,1.6719007,3.2708576,,,,,,,,,,,,,, -55000,1.6813568,3.3791938,,,,,,,,,,,,,, -55100,1.4660041,3.228722,,,,,,,,,,,,,, -55200,1.5489116,3.4447103,,,,,,,,,,,,,, -55300,1.5594436,3.3703103,,,,,,,,,,,,,, -55400,1.4924674,3.1964579,,,,,,,,,,,,,, -55500,1.4378114,3.358801,,,,,,,,,,,,,, -55600,1.5449836,3.3716533,,,,,,,,,,,,,, -55700,,,0.6409375071525574,1.605479121208191,0.5806399583816528,1.8812267780303955,50000.0,0.461400032043457,2.5292446613311768,10000.0,25253.59513092041,27311.22057056427,25253.59513092041,2052.716703414917,1.9796583652496336,0.0 -55700,1.5941616,3.3236754,,,,,,,,,,,,,, -55800,1.5835307,3.344381,,,,,,,,,,,,,, -55900,1.6737314,3.3507547,,,,,,,,,,,,,, -56000,1.6998451,3.5626798,,,,,,,,,,,,,, -56100,1.5883102,3.6970723,,,,,,,,,,,,,, -56200,1.5525203,3.3281434,,,,,,,,,,,,,, -56300,1.4984516,3.2951372,,,,,,,,,,,,,, -56400,1.6604145,3.4153533,,,,,,,,,,,,,, -56500,1.8354889,3.427022,,,,,,,,,,,,,, -56600,1.2869201,5.5535703,,,,,,,,,,,,,, -56627,,,0.6262499690055847,1.657572865486145,0.5814999938011169,1.8544203042984009,50000.0,0.4660000205039978,2.484644651412964,10000.0,25673.629153251648,27767.3883125782,25673.629153251648,2088.7680180072784,2.0139429569244385,0.0 -56700,1.5514188,3.2477605,,,,,,,,,,,,,, -56800,1.2224795,4.9621677,,,,,,,,,,,,,, -56900,1.225491,4.963935,,,,,,,,,,,,,, -57000,1.1374646,4.97673,,,,,,,,,,,,,, -57100,1.0845866,4.958519,,,,,,,,,,,,,, -57200,1.4384803,4.866322,,,,,,,,,,,,,, -57300,1.6729645,3.8945434,,,,,,,,,,,,,, -57400,1.5839701,3.2179775,,,,,,,,,,,,,, -57500,1.5653417,3.4789357,,,,,,,,,,,,,, -57555,,,0.6287499666213989,1.6723737716674805,0.5836600065231323,1.8777090311050413,50000.0,0.4666000306606293,2.503213167190552,10000.0,26093.89505290985,28225.16506052017,26093.89505290985,2126.1931478977203,2.051152229309082,0.0 -57600,1.3478378,4.1287394,,,,,,,,,,,,,, -57700,1.6451479,3.3611724,,,,,,,,,,,,,, -57800,1.4996792,3.8146012,,,,,,,,,,,,,, -57900,1.6837534,3.3823428,,,,,,,,,,,,,, -58000,1.5351279,3.2773569,,,,,,,,,,,,,, -58100,1.3995737,4.554545,,,,,,,,,,,,,, -58200,1.3197026,4.7220106,,,,,,,,,,,,,, -58300,1.5173866,3.2604108,,,,,,,,,,,,,, -58400,1.6222804,3.3198109,,,,,,,,,,,,,, -58481,,,0.6403515338897705,1.61244797706604,0.582859992980957,1.8621419668197632,50000.0,0.4684000313282013,2.4887142181396484,10000.0,26513.911805152893,28683.89664721489,26513.911805152893,2164.820201158524,2.0870184898376465,0.0 -58500,1.2983412,4.25698,,,,,,,,,,,,,, -58600,1.4950194,3.602977,,,,,,,,,,,,,, -58700,1.6429241,3.332358,,,,,,,,,,,,,, -58800,1.4904308,3.3859875,,,,,,,,,,,,,, -58900,1.1705171,4.4564686,,,,,,,,,,,,,, -59000,1.5525653,3.2914245,,,,,,,,,,,,,, -59100,1.4673378,3.8551135,,,,,,,,,,,,,, -59200,1.7851063,3.2274468,,,,,,,,,,,,,, -59300,1.119426,5.405107,,,,,,,,,,,,,, -59400,1.2525655,5.5633373,,,,,,,,,,,,,, -59407,,,0.6268945336341858,1.6685230731964111,0.5874199867248535,1.8534951210021973,50000.0,0.4738000333309173,2.4757091999053955,10000.0,26934.146147489548,29142.89506626129,26934.146147489548,2203.4964802265167,2.1264965534210205,0.0 -59500,1.3542918,5.434774,,,,,,,,,,,,,, -59600,1.7159153,3.296309,,,,,,,,,,,,,, -59700,1.1795055,5.193677,,,,,,,,,,,,,, -59800,1.4985044,3.5277567,,,,,,,,,,,,,, -59900,1.6337299,3.2606947,,,,,,,,,,,,,, -60000,1.3038152,4.05112,,,,,,,,,,,,,, -60100,1.7048718,3.3117237,,,,,,,,,,,,,, -60200,1.5950979,3.3028994,,,,,,,,,,,,,, -60300,1.4477384,4.462007,,,,,,,,,,,,,, -60334,,,0.6327733993530273,1.6337242126464844,0.5888599753379822,1.8365858793258667,50000.0,0.4762000143527984,2.4529848098754883,10000.0,27354.376095294952,29600.636869430546,27354.376095294952,2240.920811891556,2.165325164794922,0.0 -60400,1.60569,3.2873812,,,,,,,,,,,,,, -60500,1.9439312,3.380148,,,,,,,,,,,,,, -60600,1.4641417,3.609096,,,,,,,,,,,,,, -60700,1.5558234,3.3087862,,,,,,,,,,,,,, -60800,1.5342232,3.3983204,,,,,,,,,,,,,, -60900,1.4645692,4.188256,,,,,,,,,,,,,, -61000,1.3695499,4.8141975,,,,,,,,,,,,,, -61100,1.580779,3.4229598,,,,,,,,,,,,,, -61200,1.1966629,4.500713,,,,,,,,,,,,,, -61261,,,0.6378515362739563,1.6398062705993652,0.5875999927520752,1.8620257377624512,50000.0,0.4675000309944153,2.50374174118042,10000.0,27774.38081717491,30054.39617562294,27774.38081717491,2274.5890777111053,2.203279733657837,0.0 -61300,1.5652198,3.2485366,,,,,,,,,,,,,, -61400,1.5925906,3.2320025,,,,,,,,,,,,,, -61500,1.9371151,3.2843633,,,,,,,,,,,,,, -61600,1.5766526,3.561977,,,,,,,,,,,,,, -61700,1.4312409,4.483053,,,,,,,,,,,,,, -61800,1.4974657,3.3080978,,,,,,,,,,,,,, -61900,1.707016,3.2736464,,,,,,,,,,,,,, -62000,1.6324397,3.3330674,,,,,,,,,,,,,, -62100,1.7219219,3.5784595,,,,,,,,,,,,,, -62185,,,0.6422656178474426,1.6118204593658447,0.5854799747467041,1.8537278175354004,50000.0,0.4650000333786011,2.49658751487732,10000.0,28194.3759431839,30510.36556315422,28194.3759431839,2310.474866867065,2.243015766143799,0.0 -62200,1.6104059,3.2569096,,,,,,,,,,,,,, -62300,1.4464118,5.4549346,,,,,,,,,,,,,, -62400,1.511524,3.339013,,,,,,,,,,,,,, -62500,1.7265941,3.7783957,,,,,,,,,,,,,, -62600,1.4524437,3.3341117,,,,,,,,,,,,,, -62700,1.2926615,4.5898433,,,,,,,,,,,,,, -62800,1.282076,5.1477323,,,,,,,,,,,,,, -62900,1.7086456,3.2439258,,,,,,,,,,,,,, -63000,1.6836782,3.2565544,,,,,,,,,,,,,, -63100,1.4770757,3.222607,,,,,,,,,,,,,, -63111,,,0.6295117139816284,1.700981855392456,0.5878599882125854,1.883209466934204,50000.0,0.4726000130176544,2.5159428119659424,10000.0,28614.36855793,30966.68574333191,28614.36855793,2346.715988636017,2.281113147735596,0.0 -63200,1.3098357,4.7751474,,,,,,,,,,,,,, -63300,1.4141225,3.5988038,,,,,,,,,,,,,, -63400,1.1637806,5.4102225,,,,,,,,,,,,,, -63500,1.5779457,4.046376,,,,,,,,,,,,,, -63600,1.2158942,4.818108,,,,,,,,,,,,,, -63700,1.5266703,3.304695,,,,,,,,,,,,,, -63800,1.2268038,4.7750816,,,,,,,,,,,,,, -63900,1.571695,4.968616,,,,,,,,,,,,,, -64000,1.5534407,3.255981,,,,,,,,,,,,,, -64038,,,0.6376757621765137,1.6427497863769531,0.5946999788284302,1.846629977226257,50000.0,0.4771000146865845,2.4646663665771484,10000.0,29034.53926801681,31423.41744923592,29034.53926801681,2383.1910014152527,2.3187735080718994,0.0 -64100,1.7525935,3.4533765,,,,,,,,,,,,,, -64200,1.2369696,5.361806,,,,,,,,,,,,,, -64300,1.53254,3.2095141,,,,,,,,,,,,,, -64400,1.3037943,5.2049594,,,,,,,,,,,,,, -64500,1.7966821,3.2842588,,,,,,,,,,,,,, -64600,1.471589,3.537188,,,,,,,,,,,,,, -64700,1.6254983,3.29289,,,,,,,,,,,,,, -64800,1.5313761,3.5892184,,,,,,,,,,,,,, -64900,1.8206766,4.0556083,,,,,,,,,,,,,, -64965,,,0.6559374928474426,1.5458028316497805,0.5884599685668945,1.832724690437317,50000.0,0.4723000228404999,2.4600064754486084,10000.0,29454.639416217804,31881.19240355492,29454.639416217804,2420.782904148102,2.3532867431640625,0.0 -65000,1.4823256,5.079878,,,,,,,,,,,,,, -65100,1.3340098,4.0894294,,,,,,,,,,,,,, -65200,1.6281252,3.3415875,,,,,,,,,,,,,, -65300,1.5532376,5.433507,,,,,,,,,,,,,, -65400,1.6629006,3.322429,,,,,,,,,,,,,, -65500,1.542021,3.370785,,,,,,,,,,,,,, -65600,1.6132554,3.4879751,,,,,,,,,,,,,, -65700,1.7689935,3.2769973,,,,,,,,,,,,,, -65800,1.6402221,3.2263582,,,,,,,,,,,,,, -65893,,,0.6297070384025574,1.6710429191589355,0.5890399813652039,1.8539769649505613,50000.0,0.4759000241756439,2.482262134552002,10000.0,29874.70513272285,32338.065663814545,29874.70513272285,2457.5085434913635,2.387108564376831,0.0 -65900,1.6626853,3.747278,,,,,,,,,,,,,, -66000,1.6999601,3.198123,,,,,,,,,,,,,, -66100,1.2109866,4.85737,,,,,,,,,,,,,, -66200,1.4964435,3.4727738,,,,,,,,,,,,,, -66300,1.627105,3.2394466,,,,,,,,,,,,,, -66400,1.7047405,3.5817854,,,,,,,,,,,,,, -66500,1.6890951,3.173499,,,,,,,,,,,,,, -66600,1.5429387,3.318259,,,,,,,,,,,,,, -66700,1.500985,5.3500624,,,,,,,,,,,,,, -66800,1.6042176,3.6917193,,,,,,,,,,,,,, -66814,,,0.6425976157188416,1.6025398969650269,0.5962600111961365,1.8162277936935425,50000.0,0.4705000221729278,2.4594082832336426,10000.0,30294.62602829933,32796.39333939552,30294.62602829933,2495.827398777008,2.426995038986206,0.0 -66900,1.3712826,4.5776973,,,,,,,,,,,,,, -67000,1.2845569,5.219567,,,,,,,,,,,,,, -67100,1.6481767,3.2668238,,,,,,,,,,,,,, -67200,1.293664,4.416525,,,,,,,,,,,,,, -67300,1.5680734,4.52918,,,,,,,,,,,,,, -67400,1.7999133,3.569198,,,,,,,,,,,,,, -67500,1.7498263,3.3084223,,,,,,,,,,,,,, -67600,1.7678099,3.441628,,,,,,,,,,,,,, -67700,1.5968771,3.2487826,,,,,,,,,,,,,, -67738,,,0.6525781154632568,1.570324182510376,0.6031599640846252,1.8069418668746948,50000.0,0.4780000150203705,2.456949710845948,10000.0,30714.93037724495,33256.4021191597,30714.93037724495,2535.44087266922,2.470527648925781,0.0 -67800,1.5480946,3.224039,,,,,,,,,,,,,, -67900,1.6470618,5.455694,,,,,,,,,,,,,, -68000,1.4334589,3.807479,,,,,,,,,,,,,, -68100,1.5760602,3.7196007,,,,,,,,,,,,,, -68200,1.499724,3.5921988,,,,,,,,,,,,,, -68300,1.6469123,3.4653945,,,,,,,,,,,,,, -68400,1.3962533,4.136296,,,,,,,,,,,,,, -68500,1.6919734,3.331491,,,,,,,,,,,,,, -68600,1.7495718,3.248263,,,,,,,,,,,,,, -68663,,,0.6356640458106995,1.6274088621139526,0.5971199870109558,1.8149526119232176,50000.0,0.4754000306129455,2.4471421241760254,10000.0,31135.09487080574,33716.840874910355,31135.09487080574,2575.6301860809326,2.5075490474700928,0.0 -68700,1.5928609,3.1039119,,,,,,,,,,,,,, -68800,1.2623667,5.3134394,,,,,,,,,,,,,, -68900,1.7144897,3.333601,,,,,,,,,,,,,, -69000,1.5801265,3.169252,,,,,,,,,,,,,, -69100,1.572188,3.1489317,,,,,,,,,,,,,, -69200,1.5582559,3.1589,,,,,,,,,,,,,, -69300,1.6455816,3.238667,,,,,,,,,,,,,, -69400,1.6276156,3.2121797,,,,,,,,,,,,,, -69500,1.8217005,3.3006568,,,,,,,,,,,,,, -69587,,,0.6425390243530273,1.6058720350265503,0.6004999876022339,1.7994102239608765,50000.0,0.4761000275611877,2.4608607292175293,10000.0,31555.165743112564,34175.87082648277,31555.165743112564,2614.288892507553,2.7598090171813965,0.0 -69600,1.528826,3.2180583,,,,,,,,,,,,,, -69700,1.707543,3.15905,,,,,,,,,,,,,, -69800,1.4438169,4.11073,,,,,,,,,,,,,, -69900,1.6802362,3.2547715,,,,,,,,,,,,,, -70000,1.7442489,3.2188535,,,,,,,,,,,,,, -70100,1.5995606,3.288609,,,,,,,,,,,,,, -70200,1.6279831,3.5497084,,,,,,,,,,,,,, -70300,1.4771376,4.218611,,,,,,,,,,,,,, -70400,1.5161037,3.3602812,,,,,,,,,,,,,, -70500,1.7381986,3.2136226,,,,,,,,,,,,,, -70508,,,0.6518359184265137,1.5789217948913574,0.5998600125312805,1.8045138120651243,50000.0,0.4789000153541565,2.4507243633270264,10000.0,31975.08659768105,34635.624881505966,31975.08659768105,2654.033591747284,2.801017999649048,0.0 -70600,1.6981012,3.2544513,,,,,,,,,,,,,, -70700,1.7069219,3.2660153,,,,,,,,,,,,,, -70800,1.7342348,3.1804545,,,,,,,,,,,,,, -70900,1.2924823,4.3659062,,,,,,,,,,,,,, -71000,1.3968992,5.407196,,,,,,,,,,,,,, -71100,1.7066191,3.2507849,,,,,,,,,,,,,, -71200,1.6853182,3.3442783,,,,,,,,,,,,,, -71300,1.760662,3.433177,,,,,,,,,,,,,, -71400,1.2919816,4.1771736,,,,,,,,,,,,,, -71431,,,0.6530663967132568,1.5456774234771729,0.6032800078392029,1.7672827243804932,50000.0,0.484000027179718,2.4054059982299805,10000.0,32395.51756548881,35094.78996706009,32395.51756548881,2692.672725915909,2.847104072570801,0.0 -71500,1.9696773,3.4860134,,,,,,,,,,,,,, -71600,1.3967499,4.1875105,,,,,,,,,,,,,, -71700,1.8034481,3.130632,,,,,,,,,,,,,, -71800,1.3113499,4.7324166,,,,,,,,,,,,,, -71900,1.7028627,3.189457,,,,,,,,,,,,,, -72000,1.6301345,3.5225403,,,,,,,,,,,,,, -72100,1.448736,5.303419,,,,,,,,,,,,,, -72200,1.6683954,3.2869756,,,,,,,,,,,,,, -72300,1.3463813,3.7327154,,,,,,,,,,,,,, -72353,,,0.6449413895606995,1.5734144449234009,0.6037600040435791,1.765405535697937,50000.0,0.4842000305652618,2.39375638961792,10000.0,32815.659552812576,35551.76320028305,32815.659552812576,2729.4140496253967,2.888112783432007,0.0 -72400,1.2550946,4.6616673,,,,,,,,,,,,,, -72500,1.7353476,3.2711134,,,,,,,,,,,,,, -72600,1.5466418,3.9988232,,,,,,,,,,,,,, -72700,1.5774531,3.353241,,,,,,,,,,,,,, -72800,1.2461871,4.9684253,,,,,,,,,,,,,, -72900,1.4002503,5.414728,,,,,,,,,,,,,, -73000,1.5571648,4.3758907,,,,,,,,,,,,,, -73100,1.8154031,3.2002802,,,,,,,,,,,,,, -73200,1.287321,5.37223,,,,,,,,,,,,,, -73277,,,0.6486523151397705,1.5562613010406494,0.6032199859619141,1.7581266164779663,50000.0,0.484000027179718,2.3881523609161377,10000.0,33235.68685173988,36009.36540389061,33235.68685173988,2766.906188249588,2.9232640266418457,0.0 -73300,1.7604285,3.387514,,,,,,,,,,,,,, -73400,1.5889328,3.3389268,,,,,,,,,,,,,, -73500,1.6929302,3.3227236,,,,,,,,,,,,,, -73600,1.8133076,3.245232,,,,,,,,,,,,,, -73700,1.5045788,3.4593482,,,,,,,,,,,,,, -73800,1.7364446,3.2028868,,,,,,,,,,,,,, -73900,1.3605193,5.274911,,,,,,,,,,,,,, -74000,1.7425518,3.1254349,,,,,,,,,,,,,, -74100,1.5445096,5.4714413,,,,,,,,,,,,,, -74200,1.7644027,3.2574966,,,,,,,,,,,,,, -74204,,,0.6731249690055847,1.4761346578598022,0.6058200001716614,1.766079068183899,50000.0,0.4800000190734863,2.4043872356414795,10000.0,33655.99447274208,36469.58887457848,33655.99447274208,2806.733047246933,2.963426113128662,0.0 -74300,1.7402797,3.1905146,,,,,,,,,,,,,, -74400,1.6019441,4.392081,,,,,,,,,,,,,, -74500,1.627459,3.1930463,,,,,,,,,,,,,, -74600,1.519108,3.678345,,,,,,,,,,,,,, -74700,1.6090678,5.3787303,,,,,,,,,,,,,, -74800,1.7561496,3.1930146,,,,,,,,,,,,,, -74900,1.3573526,4.5525684,,,,,,,,,,,,,, -75000,1.2976468,5.0661817,,,,,,,,,,,,,, -75100,1.7560498,3.3340669,,,,,,,,,,,,,, -75128,,,0.6450976133346558,1.5667901039123535,0.6061399579048157,1.760995626449585,50000.0,0.486700028181076,2.39943790435791,10000.0,34076.297404289246,36930.39732980728,34076.297404289246,2847.1502919197083,3.003435611724853,0.0 -75200,1.8936226,3.1017437,,,,,,,,,,,,,, -75300,1.7368805,3.249125,,,,,,,,,,,,,, -75400,1.8807378,3.270784,,,,,,,,,,,,,, -75500,1.5814092,3.279995,,,,,,,,,,,,,, -75600,1.506921,3.6936402,,,,,,,,,,,,,, -75700,1.4388031,4.1945553,,,,,,,,,,,,,, -75800,1.836501,3.367577,,,,,,,,,,,,,, -75900,1.8152119,3.1531339,,,,,,,,,,,,,, -76000,1.7336217,3.4521804,,,,,,,,,,,,,, -76053,,,0.6499999761581421,1.5527905225753784,0.6094799637794495,1.7591296434402466,50000.0,0.4927000105381012,2.395556688308716,10000.0,34496.459659576416,37390.65586042404,34496.459659576416,2887.1610465049744,3.0404083728790283,0.0 -76100,1.6338878,3.0433917,,,,,,,,,,,,,, -76200,1.978798,3.2310534,,,,,,,,,,,,,, -76300,1.699347,3.2741804,,,,,,,,,,,,,, -76400,1.430182,4.9506273,,,,,,,,,,,,,, -76500,1.415168,4.4876957,,,,,,,,,,,,,, -76600,1.5269763,5.298348,,,,,,,,,,,,,, -76700,1.8150748,3.3096101,,,,,,,,,,,,,, -76800,1.3412617,4.3766847,,,,,,,,,,,,,, -76900,1.3902781,4.932516,,,,,,,,,,,,,, -76976,,,0.6608788967132568,1.5342339277267456,0.6050800085067749,1.7775715589523315,50000.0,0.4882000088691711,2.4163527488708496,10000.0,34916.55329370499,37848.11969184876,34916.55329370499,2924.444561958313,3.0794270038604736,0.0 -77000,1.9564288,3.1598136,,,,,,,,,,,,,, -77100,1.6584646,3.0981035,,,,,,,,,,,,,, -77200,1.7672118,3.183806,,,,,,,,,,,,,, -77300,1.793515,3.2179575,,,,,,,,,,,,,, -77400,1.5748088,3.3481452,,,,,,,,,,,,,, -77500,1.9288212,3.1323295,,,,,,,,,,,,,, -77600,1.5997635,3.4702506,,,,,,,,,,,,,, -77700,1.7752213,3.2791839,,,,,,,,,,,,,, -77800,1.6079789,5.2504215,,,,,,,,,,,,,, -77900,1.5914836,3.5474381,,,,,,,,,,,,,, -77903,,,0.6502929329872131,1.5727944374084473,0.6092599630355835,1.7659096717834473,50000.0,0.4905000329017639,2.4130568504333496,10000.0,35336.832350969315,38307.64156937599,35336.832350969315,2963.59290099144,3.125219821929932,0.0 -78000,1.7507964,3.1394296,,,,,,,,,,,,,, -78100,1.6692,3.217828,,,,,,,,,,,,,, -78200,1.8072704,3.2253177,,,,,,,,,,,,,, -78300,1.382955,4.4874144,,,,,,,,,,,,,, -78400,1.7205195,3.1616778,,,,,,,,,,,,,, -78500,1.6702638,3.2127585,,,,,,,,,,,,,, -78600,1.4874144,3.6849782,,,,,,,,,,,,,, -78700,1.7817302,3.191146,,,,,,,,,,,,,, -78800,1.4631414,5.0838885,,,,,,,,,,,,,, -78829,,,0.6518749594688416,1.5727999210357666,0.6112599968910217,1.7607871294021606,50000.0,0.4865000247955322,2.392758369445801,10000.0,35756.88972687721,38767.19116520882,35756.88972687721,3003.001363515854,3.161072254180908,0.0 -78900,1.377505,5.3243513,,,,,,,,,,,,,, -79000,1.5186759,4.3591022,,,,,,,,,,,,,, -79100,2.041824,3.2284234,,,,,,,,,,,,,, -79200,1.3855215,5.268198,,,,,,,,,,,,,, -79300,1.9662844,3.2776256,,,,,,,,,,,,,, -79400,1.74559,3.185479,,,,,,,,,,,,,, -79500,1.5585561,3.1461506,,,,,,,,,,,,,, -79600,1.3666135,4.6770697,,,,,,,,,,,,,, -79700,1.3717189,5.2881417,,,,,,,,,,,,,, -79754,,,0.6651171445846558,1.4895260334014893,0.6138399839401245,1.7199392318725586,50000.0,0.4955000281333923,2.342292547225952,10000.0,36177.00625920296,39226.83818221092,36177.00625920296,3042.4453415870667,3.199512720108032,0.0 -79800,2.0563922,3.3094797,,,,,,,,,,,,,, -79900,1.8283209,3.2163236,,,,,,,,,,,,,, -80000,1.6318989,3.074951,,,,,,,,,,,,,, -80100,1.7078803,3.3465745,,,,,,,,,,,,,, -80200,1.4815533,3.9006968,,,,,,,,,,,,,, -80300,1.7446682,3.1722386,,,,,,,,,,,,,, -80400,1.4091452,4.510696,,,,,,,,,,,,,, -80500,1.3951882,5.158534,,,,,,,,,,,,,, -80600,1.9184946,3.2071025,,,,,,,,,,,,,, -80677,,,0.6646288633346558,1.4810144901275637,0.6114999651908875,1.7160485982894895,50000.0,0.493800014257431,2.342613458633423,10000.0,36597.088631391525,39687.675043821335,36597.088631391525,3083.108058929444,3.2435245513916016,0.0 -80700,1.6755918,3.2353618,,,,,,,,,,,,,, -80800,1.5585458,4.2369223,,,,,,,,,,,,,, -80900,1.7515845,3.0736213,,,,,,,,,,,,,, -81000,1.7109152,3.4779851,,,,,,,,,,,,,, -81100,1.7922697,3.018824,,,,,,,,,,,,,, -81200,1.4378703,5.1553698,,,,,,,,,,,,,, -81300,1.9026742,3.2046208,,,,,,,,,,,,,, -81400,1.7655046,3.1684976,,,,,,,,,,,,,, -81500,1.6510628,3.3045976,,,,,,,,,,,,,, -81600,1.5645677,4.7685885,,,,,,,,,,,,,, -81603,,,0.6604687571525574,1.5106688737869265,0.6142599582672119,1.7114691734313965,50000.0,0.4974000155925751,2.352527379989624,10000.0,37017.36079597473,40149.45268511772,37017.36079597473,3124.51851439476,3.2905027866363525,0.0 -81700,1.4960212,5.284489,,,,,,,,,,,,,, -81800,1.4011791,5.2996445,,,,,,,,,,,,,, -81900,1.8833268,3.2004972,,,,,,,,,,,,,, -82000,1.5652249,4.7614264,,,,,,,,,,,,,, -82100,1.6980342,3.2164574,,,,,,,,,,,,,, -82200,1.516121,5.339248,,,,,,,,,,,,,, -82300,1.6501833,3.1972032,,,,,,,,,,,,,, -82400,1.754153,3.0894265,,,,,,,,,,,,,, -82500,1.5217662,3.9308023,,,,,,,,,,,,,, -82530,,,0.6646093726158142,1.5200191736221311,0.6138399839401245,1.7434662580490112,50000.0,0.4939000308513641,2.370671510696411,10000.0,37437.69989728928,40609.15281009674,37437.69989728928,3163.785692691803,3.336221933364868,0.0 -82600,1.3616611,4.927919,,,,,,,,,,,,,, -82700,1.7115638,3.0985456,,,,,,,,,,,,,, -82800,1.70502,3.2250404,,,,,,,,,,,,,, -82900,1.7187389,3.0851812,,,,,,,,,,,,,, -83000,1.6872406,4.1528344,,,,,,,,,,,,,, -83100,1.8461472,3.164871,,,,,,,,,,,,,, -83200,1.868045,3.3251488,,,,,,,,,,,,,, -83300,1.649776,5.133897,,,,,,,,,,,,,, -83400,1.8596618,5.2666326,,,,,,,,,,,,,, -83454,,,0.6852148175239563,1.3989347219467163,0.6177399754524231,1.7047220468521118,50000.0,0.5004000067710876,2.3427956104278564,10000.0,37857.93138933182,41068.87111830712,37857.93138933182,3203.184502363205,3.375791311264038,0.0 -83500,1.8633212,3.086342,,,,,,,,,,,,,, -83600,1.5512841,3.849631,,,,,,,,,,,,,, -83700,1.5863612,4.7892485,,,,,,,,,,,,,, -83800,1.6601605,3.384779,,,,,,,,,,,,,, -83900,1.948005,3.116246,,,,,,,,,,,,,, -84000,1.8295668,3.4877937,,,,,,,,,,,,,, -84100,1.83353,3.1432827,,,,,,,,,,,,,, -84200,1.7904216,3.0947807,,,,,,,,,,,,,, -84300,1.5298505,4.177771,,,,,,,,,,,,,, -84377,,,0.6599413752555847,1.5086867809295654,0.6137599945068359,1.7188671827316284,50000.0,0.4927000105381012,2.361658811569214,10000.0,38278.19797229767,41528.8823723793,38278.19797229767,3242.8420696258545,3.4148974418640137,0.0 -84400,1.9653118,3.174033,,,,,,,,,,,,,, -84500,1.7545476,3.2640498,,,,,,,,,,,,,, -84600,1.6938112,3.351079,,,,,,,,,,,,,, -84700,2.0915365,3.201777,,,,,,,,,,,,,, -84800,1.9659694,3.1395717,,,,,,,,,,,,,, -84900,1.5933769,5.259899,,,,,,,,,,,,,, -85000,1.7020674,3.3756418,,,,,,,,,,,,,, -85100,1.8004802,3.1716886,,,,,,,,,,,,,, -85200,1.8868756,3.0985003,,,,,,,,,,,,,, -85300,1.5035032,5.253617,,,,,,,,,,,,,, -85301,,,0.6687890291213989,1.4672608375549316,0.6248399615287781,1.679897427558899,50000.0,0.5024999976158142,2.314694881439209,10000.0,38698.49995803833,41990.98984336853,38698.49995803833,3284.55818772316,3.455544948577881,0.0 -85400,1.8073126,3.131301,,,,,,,,,,,,,, -85500,1.8219098,3.2391005,,,,,,,,,,,,,, -85600,1.8776311,3.2011364,,,,,,,,,,,,,, -85700,1.7986745,3.1822257,,,,,,,,,,,,,, -85800,1.6539338,3.3471043,,,,,,,,,,,,,, -85900,1.5977998,3.147736,,,,,,,,,,,,,, -86000,1.7497244,3.1930084,,,,,,,,,,,,,, -86100,1.6379846,3.34574,,,,,,,,,,,,,, -86200,1.561832,5.146417,,,,,,,,,,,,,, -86227,,,0.682421863079071,1.420057773590088,0.6224600076675415,1.6837719678878784,50000.0,0.5035000443458557,2.308379650115967,10000.0,39118.564464092255,42447.72882437706,39118.564464092255,3321.14311671257,3.497130870819092,0.0 -86300,1.9711845,4.812784,,,,,,,,,,,,,, -86400,1.6442429,5.1282735,,,,,,,,,,,,,, -86500,1.7422034,3.0105531,,,,,,,,,,,,,, -86600,1.569773,3.9816587,,,,,,,,,,,,,, -86700,1.7655895,3.065672,,,,,,,,,,,,,, -86800,1.5192926,3.551937,,,,,,,,,,,,,, -86900,1.599082,3.8465948,,,,,,,,,,,,,, -87000,1.6615508,3.218011,,,,,,,,,,,,,, -87100,1.7712581,3.1582823,,,,,,,,,,,,,, -87154,,,0.6678515672683716,1.4901386499404907,0.624019980430603,1.6841695308685305,50000.0,0.5034000277519226,2.3107705116271973,10000.0,39538.55706048012,42908.27966308594,39538.55706048012,3361.6144444942474,3.535960912704468,0.0 -87200,1.7861605,3.091462,,,,,,,,,,,,,, -87300,1.6233608,3.373958,,,,,,,,,,,,,, -87400,1.8326911,3.1591964,,,,,,,,,,,,,, -87500,1.9958404,3.1652827,,,,,,,,,,,,,, -87600,1.5592866,4.255619,,,,,,,,,,,,,, -87700,1.5886893,4.2287188,,,,,,,,,,,,,, -87800,1.4798024,4.547367,,,,,,,,,,,,,, -87900,1.6752889,3.1269927,,,,,,,,,,,,,, -88000,1.5899224,4.463546,,,,,,,,,,,,,, -88081,,,0.6660546660423279,1.5081756114959717,0.6201399564743042,1.717879295349121,50000.0,0.5009000301361084,2.3405404090881348,10000.0,39958.77399516106,43366.066150188446,39958.77399516106,3399.0924847126007,3.5792877674102783,0.0 -88100,1.954877,3.1520402,,,,,,,,,,,,,, -88200,1.6495732,4.947705,,,,,,,,,,,,,, -88300,1.8568043,3.085929,,,,,,,,,,,,,, -88400,1.5074061,5.0879445,,,,,,,,,,,,,, -88500,1.6182108,4.0214915,,,,,,,,,,,,,, -88600,1.9232025,3.1679149,,,,,,,,,,,,,, -88700,1.8813927,3.0850842,,,,,,,,,,,,,, -88800,1.7930747,3.2243497,,,,,,,,,,,,,, -88900,1.6729662,3.411502,,,,,,,,,,,,,, -89000,1.8098265,3.0654893,,,,,,,,,,,,,, -89007,,,0.6788476705551147,1.4367852210998535,0.6260199546813965,1.682494044303894,50000.0,0.506600022315979,2.309653282165528,10000.0,40378.939730882645,43824.08883070946,40378.939730882645,3436.8513662815094,3.6288912296295166,0.0 -89100,2.0483315,3.2288468,,,,,,,,,,,,,, -89200,1.708695,3.3763251,,,,,,,,,,,,,, -89300,1.712437,3.444908,,,,,,,,,,,,,, -89400,1.6440526,4.751516,,,,,,,,,,,,,, -89500,1.4479964,4.1042576,,,,,,,,,,,,,, -89600,1.8013313,3.1304657,,,,,,,,,,,,,, -89700,2.099098,3.466819,,,,,,,,,,,,,, -89800,1.7899698,5.0878754,,,,,,,,,,,,,, -89900,1.7404231,3.0240784,,,,,,,,,,,,,, -89929,,,0.6749609112739563,1.4869822263717651,0.6248999834060669,1.7146426439285278,50000.0,0.5065000057220459,2.3380980491638184,10000.0,40798.99773478508,44281.61442565918,40798.99773478508,3474.23304772377,3.666169404983521,0.0 -90000,1.9876059,3.1126447,,,,,,,,,,,,,, -90100,1.5024222,5.2416544,,,,,,,,,,,,,, -90200,1.5721163,3.795159,,,,,,,,,,,,,, -90300,1.6838913,3.6235564,,,,,,,,,,,,,, -90400,2.1628304,3.0967407,,,,,,,,,,,,,, -90500,1.7725259,3.1077836,,,,,,,,,,,,,, -90600,1.480778,4.2852254,,,,,,,,,,,,,, -90700,2.2769237,3.094719,,,,,,,,,,,,,, -90800,1.7509946,3.5491192,,,,,,,,,,,,,, -90855,,,0.6749218702316284,1.44557523727417,0.6292200088500977,1.6595513820648191,50000.0,0.5069000124931335,2.29979944229126,10000.0,41219.33115816116,44740.604519844055,41219.33115816116,3512.803384065628,3.704136610031128,0.0 -90900,1.8462925,3.2416668,,,,,,,,,,,,,, -91000,1.8309653,3.8592424,,,,,,,,,,,,,, -91100,1.8326312,3.1306486,,,,,,,,,,,,,, -91200,1.9832985,3.003304,,,,,,,,,,,,,, -91300,1.687084,3.4045315,,,,,,,,,,,,,, -91400,1.8409463,3.0074234,,,,,,,,,,,,,, -91500,1.5898607,4.95417,,,,,,,,,,,,,, -91600,1.7788743,3.0461435,,,,,,,,,,,,,, -91700,1.837815,3.0983765,,,,,,,,,,,,,, -91780,,,0.6800000071525574,1.466551423072815,0.6314399838447571,1.6920958757400513,50000.0,0.5057000517845154,2.329899787902832,10000.0,41639.402092933655,45198.59097504616,41639.402092933655,3550.632279634476,3.7427337169647217,0.0 -91800,1.5222251,4.9657826,,,,,,,,,,,,,, -91900,1.7047831,3.3385696,,,,,,,,,,,,,, -92000,1.7960004,3.1711721,,,,,,,,,,,,,, -92100,1.5504113,4.9958696,,,,,,,,,,,,,, -92200,1.8581876,3.0822806,,,,,,,,,,,,,, -92300,1.5862663,5.158376,,,,,,,,,,,,,, -92400,1.5273141,4.5510178,,,,,,,,,,,,,, -92500,1.5048134,4.2354593,,,,,,,,,,,,,, -92600,2.1132565,3.1135502,,,,,,,,,,,,,, -92700,1.5091724,4.8319273,,,,,,,,,,,,,, -92701,,,0.69873046875,1.3424489498138428,0.6325399875640869,1.649196982383728,50000.0,0.5121000409126282,2.2850232124328613,10000.0,42059.464174985886,45658.74293327332,42059.464174985886,3590.631927251816,3.7838430404663086,0.0 -92800,1.8615923,3.1711912,,,,,,,,,,,,,, -92900,1.651667,4.0826163,,,,,,,,,,,,,, -93000,1.679664,3.8059976,,,,,,,,,,,,,, -93100,1.8676631,3.0619853,,,,,,,,,,,,,, -93200,1.7298762,3.7739997,,,,,,,,,,,,,, -93300,1.7413063,3.1083531,,,,,,,,,,,,,, -93400,1.7433097,3.6783745,,,,,,,,,,,,,, -93500,1.6541954,4.0028276,,,,,,,,,,,,,, -93600,1.8460352,3.307341,,,,,,,,,,,,,, -93626,,,0.6816992163658142,1.4043458700180054,0.637179970741272,1.6172446012496948,50000.0,0.5180000066757202,2.231893301010132,10000.0,42479.74120259285,46117.04714727402,42479.74120259285,3628.571489095688,3.82320785522461,0.0 -93700,1.754795,3.1355348,,,,,,,,,,,,,, -93800,1.806142,3.1585312,,,,,,,,,,,,,, -93900,1.613846,4.391349,,,,,,,,,,,,,, -94000,1.9952707,2.9726362,,,,,,,,,,,,,, -94100,1.5417001,5.082616,,,,,,,,,,,,,, -94200,1.5749003,4.867973,,,,,,,,,,,,,, -94300,1.61313,4.9554095,,,,,,,,,,,,,, -94400,1.8047609,3.2409625,,,,,,,,,,,,,, -94500,1.6723642,3.401351,,,,,,,,,,,,,, -94550,,,0.6850976347923279,1.3780698776245115,0.6401599645614624,1.591439127922058,50000.0,0.5154000520706177,2.23335862159729,10000.0,42900.02947735786,46574.56220269203,42900.02947735786,3665.710597515106,3.861924171447754,0.0 -94600,1.8731622,3.0638258,,,,,,,,,,,,,, -94700,1.7311319,3.3049986,,,,,,,,,,,,,, -94800,1.8696495,3.2234988,,,,,,,,,,,,,, -94900,1.7672844,5.1505594,,,,,,,,,,,,,, -95000,1.5896021,4.317278,,,,,,,,,,,,,, -95100,1.8447803,3.0049891,,,,,,,,,,,,,, -95200,1.8920884,3.0845728,,,,,,,,,,,,,, -95300,1.8720814,3.0437336,,,,,,,,,,,,,, -95400,1.9560847,4.863335,,,,,,,,,,,,,, -95475,,,0.6938085556030273,1.3812767267227173,0.6334999799728394,1.646828293800354,50000.0,0.5127000212669373,2.2695205211639404,10000.0,43320.07567238808,47032.49260401726,43320.07567238808,3703.507269144058,3.901413679122925,0.0 -95500,1.9110575,5.2167387,,,,,,,,,,,,,, -95600,1.7596825,3.220847,,,,,,,,,,,,,, -95700,1.7704282,3.829152,,,,,,,,,,,,,, -95800,1.615243,4.9643474,,,,,,,,,,,,,, -95900,1.8738475,3.2639635,,,,,,,,,,,,,, -96000,1.8753717,3.0351188,,,,,,,,,,,,,, -96100,1.5398052,5.139383,,,,,,,,,,,,,, -96200,2.0224657,3.3866825,,,,,,,,,,,,,, -96300,1.8825275,3.0437837,,,,,,,,,,,,,, -96400,,,0.6801366806030273,1.449390888214111,0.6377800107002258,1.6380335092544556,50000.0,0.5180000066757202,2.2730095386505127,10000.0,43740.39475917816,47487.6726129055,43740.39475917816,3738.279561281204,3.94172739982605,0.0 -96400,1.5449985,5.194209,,,,,,,,,,,,,, -96500,2.1668954,3.0649872,,,,,,,,,,,,,, -96600,1.9543133,2.9848905,,,,,,,,,,,,,, -96700,1.833853,2.964246,,,,,,,,,,,,,, -96800,2.098419,3.0753455,,,,,,,,,,,,,, -96900,2.1181707,3.1146667,,,,,,,,,,,,,, -97000,1.7943659,3.2130668,,,,,,,,,,,,,, -97100,1.7995569,5.0708127,,,,,,,,,,,,,, -97200,1.60247,3.6787965,,,,,,,,,,,,,, -97300,1.6578295,4.7497334,,,,,,,,,,,,,, -97326,,,0.6853125095367432,1.400512933731079,0.6398599743843079,1.6111207008361816,50000.0,0.5163000226020813,2.233106136322021,10000.0,44160.30851793289,47948.28272128105,44160.30851793289,3778.8772070407854,3.9924333095550537,0.0 -97400,1.6875975,5.1405272,,,,,,,,,,,,,, -97500,1.6917815,3.2683141,,,,,,,,,,,,,, -97600,1.8833312,3.2021341,,,,,,,,,,,,,, -97700,2.0989892,3.3155408,,,,,,,,,,,,,, -97800,1.6385491,4.2253966,,,,,,,,,,,,,, -97900,2.1163435,2.99831,,,,,,,,,,,,,, -98000,1.8419775,2.9900553,,,,,,,,,,,,,, -98100,1.9435685,3.2203064,,,,,,,,,,,,,, -98200,2.0126243,2.9738185,,,,,,,,,,,,,, -98253,,,0.6956640481948853,1.3512005805969238,0.6403399705886841,1.6066975593566897,50000.0,0.5175999999046326,2.243485689163208,10000.0,44580.45514702797,48407.94709467888,44580.45514702797,3818.3037271499634,4.035202503204346,0.0 -98300,1.9620686,3.175089,,,,,,,,,,,,,, -98400,1.5608071,5.1038547,,,,,,,,,,,,,, -98500,2.0955753,5.190289,,,,,,,,,,,,,, -98600,1.4767544,5.120361,,,,,,,,,,,,,, -98700,2.0970368,2.9533238,,,,,,,,,,,,,, -98800,2.0478947,2.9702685,,,,,,,,,,,,,, -98900,1.9932187,3.1493275,,,,,,,,,,,,,, -99000,2.034099,5.093334,,,,,,,,,,,,,, -99100,2.0401473,3.0169392,,,,,,,,,,,,,, -99180,,,0.6977148056030273,1.3450783491134644,0.6491000056266785,1.5536638498306274,50000.0,0.526199996471405,2.183600187301636,10000.0,45000.66364359856,48865.30378293991,45000.66364359856,3855.362357854843,4.076540231704712,0.0 -99200,1.7959332,4.943288,,,,,,,,,,,,,, -99300,2.0139196,3.094925,,,,,,,,,,,,,, -99400,1.9605757,2.9518538,,,,,,,,,,,,,, -99500,2.0088873,5.013284,,,,,,,,,,,,,, -99600,1.877904,3.3896031,,,,,,,,,,,,,, -99700,1.6672786,4.4847116,,,,,,,,,,,,,, -99800,1.7899693,3.897884,,,,,,,,,,,,,, -99900,1.5280282,4.586857,,,,,,,,,,,,,, -100000,1.8778054,2.8950987,,,,,,,,,,,,,, -100100,1.7201504,5.0678983,,,,,,,,,,,,,, -100104,,,0.6898828148841858,1.416399598121643,0.6423400044441223,1.6274499893188477,50000.0,0.5207000374794006,2.2531089782714844,10000.0,45420.91577172279,49324.64928340912,45420.91577172279,3894.365116834641,4.119498014450073,0.0 -100200,1.5719087,4.8373065,,,,,,,,,,,,,, -100300,1.5713705,4.557147,,,,,,,,,,,,,, -100400,2.027921,3.0921426,,,,,,,,,,,,,, -100500,2.12084,2.986764,,,,,,,,,,,,,, -100600,1.666868,3.9877918,,,,,,,,,,,,,, -100700,2.0706372,2.8706126,,,,,,,,,,,,,, -100800,2.0895085,2.9766827,,,,,,,,,,,,,, -100900,1.6869626,4.608253,,,,,,,,,,,,,, -101000,1.9919331,2.817091,,,,,,,,,,,,,, -101032,,,0.6932812333106995,1.3365551233291626,0.6417999863624573,1.5768588781356812,50000.0,0.52510005235672,2.190631628036499,10000.0,45841.04868769646,49783.31498479843,45841.04868769646,3932.808772087097,4.1595988273620605,0.0 -101100,1.8121579,3.6074724,,,,,,,,,,,,,, -101200,1.9802788,3.3109267,,,,,,,,,,,,,, -101300,2.0344775,2.9440808,,,,,,,,,,,,,, -101400,1.8640584,5.1528444,,,,,,,,,,,,,, -101500,2.05317,2.9930656,,,,,,,,,,,,,, -101600,1.7306175,4.224236,,,,,,,,,,,,,, -101700,1.8820064,3.1593647,,,,,,,,,,,,,, -101800,1.6791656,4.5385904,,,,,,,,,,,,,, -101900,1.8657322,3.0596433,,,,,,,,,,,,,, -101958,,,0.7127343416213989,1.287361979484558,0.642579972743988,1.5980144739151,50000.0,0.5157000422477722,2.233988046646118,10000.0,46261.34437060356,50243.625947237015,46261.34437060356,3972.7331693172455,4.202354431152344,0.0 -102000,1.866907,3.0245833,,,,,,,,,,,,,, -102100,1.7269763,3.9143214,,,,,,,,,,,,,, -102200,1.7508793,3.787746,,,,,,,,,,,,,, -102300,1.8461651,3.6341963,,,,,,,,,,,,,, -102400,2.0201235,3.6390364,,,,,,,,,,,,,, -102500,1.9038947,2.9370737,,,,,,,,,,,,,, -102600,2.2592921,2.9766264,,,,,,,,,,,,,, -102700,1.8842041,3.2674088,,,,,,,,,,,,,, -102800,1.9602523,5.1432004,,,,,,,,,,,,,, -102883,,,0.6974804401397705,1.331035614013672,0.6518599987030029,1.5392245054244995,50000.0,0.5223000049591064,2.177493095397949,10000.0,46681.6217648983,50704.19277739525,46681.6217648983,4012.935410261154,4.241785049438477,0.0 -102900,1.6949122,5.0583386,,,,,,,,,,,,,, -103000,1.9713762,4.949708,,,,,,,,,,,,,, -103100,2.0353074,2.9053304,,,,,,,,,,,,,, -103200,1.7256085,4.189928,,,,,,,,,,,,,, -103300,1.8364842,2.8927217,,,,,,,,,,,,,, -103400,1.9623058,3.0196173,,,,,,,,,,,,,, -103500,1.8492112,3.7403567,,,,,,,,,,,,,, -103600,2.1252298,2.8932104,,,,,,,,,,,,,, -103700,1.9710498,2.922377,,,,,,,,,,,,,, -103800,2.1765368,4.950062,,,,,,,,,,,,,, -103809,,,0.6997656226158142,1.369573950767517,0.6473199725151062,1.58622944355011,50000.0,0.5236000418663025,2.234383583068848,10000.0,47101.53124284744,51165.548437833786,47101.53124284744,4054.28226852417,4.292704820632935,0.0 -103900,1.8391773,4.447103,,,,,,,,,,,,,, -104000,2.1052475,2.881452,,,,,,,,,,,,,, -104100,1.9302973,3.160742,,,,,,,,,,,,,, -104200,2.1585507,4.819146,,,,,,,,,,,,,, -104300,2.1156874,2.9350197,,,,,,,,,,,,,, -104400,2.1176338,2.7982488,,,,,,,,,,,,,, -104500,2.0350835,2.866962,,,,,,,,,,,,,, -104600,1.8332057,3.6895897,,,,,,,,,,,,,, -104700,2.0659113,2.9701543,,,,,,,,,,,,,, -104733,,,0.7157617211341858,1.2719745635986328,0.6475399732589722,1.5650979280471802,50000.0,0.528700053691864,2.189589262008667,10000.0,47521.49310541153,51626.334483385086,47521.49310541153,4095.0082075595856,4.342754125595093,0.0 -104800,1.695601,3.8067074,,,,,,,,,,,,,, -104900,1.9798039,2.9300015,,,,,,,,,,,,,, -105000,2.1581156,2.956809,,,,,,,,,,,,,, -105100,2.0242233,2.937541,,,,,,,,,,,,,, -105200,1.7216332,4.502529,,,,,,,,,,,,,, -105300,2.167502,2.94237,,,,,,,,,,,,,, -105400,1.9064492,3.690351,,,,,,,,,,,,,, -105500,2.0887787,2.9789104,,,,,,,,,,,,,, -105600,1.8496224,4.4639883,,,,,,,,,,,,,, -105660,,,0.6998632550239563,1.3254988193511963,0.649619996547699,1.5387935638427734,50000.0,0.5309000015258789,2.172654867172241,10000.0,47941.640141010284,52083.59052538872,47941.640141010284,4132.022391796112,4.3900251388549805,0.0 -105700,1.8724015,3.1832423,,,,,,,,,,,,,, -105800,1.9975523,2.9113603,,,,,,,,,,,,,, -105900,1.9162828,4.1470795,,,,,,,,,,,,,, -106000,1.7979698,4.0826926,,,,,,,,,,,,,, -106100,2.3273795,3.2134328,,,,,,,,,,,,,, -106200,2.0296419,2.8715138,,,,,,,,,,,,,, -106300,1.9678516,3.1317425,,,,,,,,,,,,,, -106400,2.181859,5.026623,,,,,,,,,,,,,, -106500,2.2198117,2.9812222,,,,,,,,,,,,,, -106589,,,0.7017773389816284,1.3748812675476074,0.6519799828529358,1.600887656211853,50000.0,0.5260000228881836,2.2354602813720703,10000.0,48361.77505850792,52543.93924975395,48361.77505850792,4172.140084028244,4.437947511672974,0.0 -106600,1.9999301,2.9731832,,,,,,,,,,,,,, -106700,2.0382853,2.9581888,,,,,,,,,,,,,, -106800,1.7794344,3.4169755,,,,,,,,,,,,,, -106900,1.995275,4.993963,,,,,,,,,,,,,, -107000,1.9109677,3.0388045,,,,,,,,,,,,,, -107100,1.7929783,4.190413,,,,,,,,,,,,,, -107200,2.0232885,3.8757052,,,,,,,,,,,,,, -107300,1.7819756,4.7421823,,,,,,,,,,,,,, -107400,2.0531573,3.1368535,,,,,,,,,,,,,, -107500,2.1982431,2.9071078,,,,,,,,,,,,,, -107516,,,0.7152148485183716,1.2641282081604004,0.6518200039863586,1.5298153162002563,50000.0,0.5263000130653381,2.170165777206421,10000.0,48782.06770968437,53001.499126434326,48782.06770968437,4209.313284873962,4.48431396484375,0.0 -107600,2.0522,3.0088255,,,,,,,,,,,,,, -107700,1.8242235,3.424683,,,,,,,,,,,,,, -107800,2.2840981,5.1807814,,,,,,,,,,,,,, -107900,1.7966679,4.76514,,,,,,,,,,,,,, -108000,2.1808214,2.8450646,,,,,,,,,,,,,, -108100,2.220922,3.5224738,,,,,,,,,,,,,, -108200,2.2732544,2.9403257,,,,,,,,,,,,,, -108300,2.2120848,2.8756516,,,,,,,,,,,,,, -108400,1.7890139,3.7066147,,,,,,,,,,,,,, -108440,,,0.7088671922683716,1.30809485912323,0.6545000076293945,1.5428231954574585,50000.0,0.5293000340461731,2.17182731628418,10000.0,49202.40798187256,53460.39822816849,49202.40798187256,4247.780642032623,4.5283708572387695,0.0 -108500,2.1819046,3.1168528,,,,,,,,,,,,,, -108600,2.1591914,2.9394531,,,,,,,,,,,,,, -108700,2.1381118,2.8818388,,,,,,,,,,,,,, -108800,2.1482944,2.8612747,,,,,,,,,,,,,, -108900,2.207276,3.1513696,,,,,,,,,,,,,, -109000,1.9713286,3.171012,,,,,,,,,,,,,, -109100,2.0928953,2.9085202,,,,,,,,,,,,,, -109200,2.0812645,2.8482199,,,,,,,,,,,,,, -109300,2.1006722,4.673169,,,,,,,,,,,,,, -109364,,,0.7081640362739563,1.2886130809783936,0.6595799922943115,1.5078322887420654,50000.0,0.5425000190734863,2.130657434463501,10000.0,49622.73654794693,53920.59532928467,49622.73654794693,4287.5615401268005,4.569385528564453,0.0 -109400,2.2785463,2.9987633,,,,,,,,,,,,,, -109500,2.0676072,2.8351278,,,,,,,,,,,,,, -109600,2.2731123,5.11145,,,,,,,,,,,,,, -109700,1.7090728,3.7539566,,,,,,,,,,,,,, -109800,2.3752546,2.9121337,,,,,,,,,,,,,, -109900,2.0977225,3.2923012,,,,,,,,,,,,,, -110000,2.100485,3.332254,,,,,,,,,,,,,, -110100,2.1173844,3.3495789,,,,,,,,,,,,,, -110200,2.1920836,3.2086513,,,,,,,,,,,,,, -110290,,,0.7182226181030273,1.246045470237732,0.6599000096321106,1.4922248125076294,50000.0,0.5387000441551208,2.11667799949646,10000.0,50042.85763192177,54378.7845196724,50042.85763192177,4325.538824796677,4.612484693527222,0.0 -110300,2.3343308,2.9280062,,,,,,,,,,,,,, -110400,2.2433088,2.9037805,,,,,,,,,,,,,, -110500,1.9077765,4.6656632,,,,,,,,,,,,,, -110600,2.067259,2.83198,,,,,,,,,,,,,, -110700,2.0227873,2.8903203,,,,,,,,,,,,,, -110800,2.0651047,3.136997,,,,,,,,,,,,,, -110900,2.0426133,5.058805,,,,,,,,,,,,,, -111000,2.0708473,3.2311463,,,,,,,,,,,,,, -111100,1.8375747,3.9484253,,,,,,,,,,,,,, -111200,2.0885928,4.20889,,,,,,,,,,,,,, -111215,,,0.7352343797683716,1.1619881391525269,0.6606599688529968,1.4827560186386108,50000.0,0.5321000218391418,2.1034798622131348,10000.0,50462.81347370148,54835.804351091385,50462.81347370148,4362.5080988407135,4.6590001583099365,0.0 -111300,1.9711205,2.8028321,,,,,,,,,,,,,, -111400,1.8753319,4.4426413,,,,,,,,,,,,,, -111500,1.9932842,4.4208913,,,,,,,,,,,,,, -111600,2.1663036,2.9678524,,,,,,,,,,,,,, -111700,2.076362,2.7765064,,,,,,,,,,,,,, -111800,2.1542358,3.024045,,,,,,,,,,,,,, -111900,2.1490653,3.3289852,,,,,,,,,,,,,, -112000,2.00685,4.6549263,,,,,,,,,,,,,, -112100,1.8639841,4.3285394,,,,,,,,,,,,,, -112141,,,0.7141991853713989,1.254716873168945,0.6646999716758728,1.48099946975708,50000.0,0.534000039100647,2.13720703125,10000.0,50883.19580984116,55292.03625512123,50883.19580984116,4398.260158777237,4.708902835845947,0.0 -112200,2.2697453,2.8741708,,,,,,,,,,,,,, -112300,2.0126586,3.3110657,,,,,,,,,,,,,, -112400,2.1722386,2.9427204,,,,,,,,,,,,,, -112500,2.588569,2.898995,,,,,,,,,,,,,, -112600,1.9292928,3.583552,,,,,,,,,,,,,, -112700,1.9821318,3.2247586,,,,,,,,,,,,,, -112800,2.3449738,2.9028027,,,,,,,,,,,,,, -112900,2.1537461,2.996454,,,,,,,,,,,,,, -113000,2.024173,2.809462,,,,,,,,,,,,,, -113064,,,0.7215819954872131,1.2358224391937256,0.668940007686615,1.475932002067566,50000.0,0.5439000129699707,2.109809875488281,10000.0,51303.45684719086,55752.28976273537,51303.45684719086,4438.158957719803,4.753890514373779,0.0 -113100,2.0805473,2.8183157,,,,,,,,,,,,,, -113200,2.101261,2.7951326,,,,,,,,,,,,,, -113300,2.1305907,2.9450378,,,,,,,,,,,,,, -113400,1.9693402,4.651368,,,,,,,,,,,,,, -113500,1.9781767,3.287905,,,,,,,,,,,,,, -113600,2.2094443,3.0387552,,,,,,,,,,,,,, -113700,2.4012313,2.9041042,,,,,,,,,,,,,, -113800,2.1173987,2.8289506,,,,,,,,,,,,,, -113900,2.2722564,3.0165694,,,,,,,,,,,,,, -113989,,,0.7289257645606995,1.1877706050872805,0.660860002040863,1.4783949851989746,50000.0,0.5433000326156616,2.090823173522949,10000.0,51723.392731666565,56210.262149095535,51723.392731666565,4476.102520704269,4.79875922203064,0.0 -114000,1.9038862,3.7413409,,,,,,,,,,,,,, -114100,2.3476756,2.8106742,,,,,,,,,,,,,, -114200,2.0092075,2.806758,,,,,,,,,,,,,, -114300,2.102022,2.8661945,,,,,,,,,,,,,, -114400,2.0526342,3.4274993,,,,,,,,,,,,,, -114500,2.1174476,2.8971605,,,,,,,,,,,,,, -114600,2.3255947,3.0797756,,,,,,,,,,,,,, -114700,2.199122,2.9278255,,,,,,,,,,,,,, -114800,2.3540084,5.0806227,,,,,,,,,,,,,, -114900,2.307838,3.150866,,,,,,,,,,,,,, -114913,,,0.71644526720047,1.250329852104187,0.6658799648284912,1.4862215518951416,50000.0,0.5435000061988831,2.1070306301116943,10000.0,52143.375336408615,56668.10083293915,52143.375336408615,4513.86473441124,4.843943119049072,0.0 -115000,2.2743201,3.384947,,,,,,,,,,,,,, -115100,2.3645983,2.785311,,,,,,,,,,,,,, -115200,2.1314664,3.7648544,,,,,,,,,,,,,, -115300,2.2996006,2.8884037,,,,,,,,,,,,,, -115400,2.2493029,2.841762,,,,,,,,,,,,,, -115500,2.4385426,2.9276955,,,,,,,,,,,,,, -115600,2.2846498,2.9687786,,,,,,,,,,,,,, -115700,1.8584492,3.7331886,,,,,,,,,,,,,, -115800,2.474702,2.8271005,,,,,,,,,,,,,, -115837,,,0.7221484184265137,1.2188655138015747,0.669439971446991,1.451919674873352,50000.0,0.5459000468254089,2.074211597442627,10000.0,52563.48961663246,57127.45796251297,52563.48961663246,4553.015449285507,4.888232231140137,0.0 -115900,2.2668486,4.936178,,,,,,,,,,,,,, -116000,2.3807483,2.867293,,,,,,,,,,,,,, -116100,1.9763421,4.142353,,,,,,,,,,,,,, -116200,2.5008445,2.9197307,,,,,,,,,,,,,, -116300,2.189558,2.8862183,,,,,,,,,,,,,, -116400,2.4477925,3.0589352,,,,,,,,,,,,,, -116500,1.9003682,3.8048701,,,,,,,,,,,,,, -116600,2.2883315,2.8573139,,,,,,,,,,,,,, -116700,2.2140105,2.824048,,,,,,,,,,,,,, -116762,,,0.7323827743530273,1.1630433797836304,0.6712200045585632,1.4364888668060305,50000.0,0.5437999963760376,2.064565420150757,10000.0,52983.74047589302,57588.00782227516,52983.74047589302,4593.209508657455,4.9340245723724365,0.0 -116800,2.3338888,2.8445086,,,,,,,,,,,,,, -116900,2.4021266,2.9290965,,,,,,,,,,,,,, -117000,2.3094237,2.955356,,,,,,,,,,,,,, -117100,2.4042952,2.948381,,,,,,,,,,,,,, -117200,2.2087786,3.6630855,,,,,,,,,,,,,, -117300,2.217282,2.8474522,,,,,,,,,,,,,, -117400,2.1357024,4.987015,,,,,,,,,,,,,, -117500,2.22811,4.6828074,,,,,,,,,,,,,, -117600,1.9789993,4.606616,,,,,,,,,,,,,, -117688,,,0.7284570336341858,1.1957414150238037,0.6744399666786194,1.430295467376709,50000.0,0.5520000457763672,2.056349277496338,10000.0,53403.952806949615,58049.12762546539,53403.952806949615,4634.027329921722,4.976112604141235,0.0 -117700,2.4760196,2.7609766,,,,,,,,,,,,,, -117800,2.5453274,2.8588371,,,,,,,,,,,,,, -117900,1.9669294,4.623246,,,,,,,,,,,,,, -118000,2.1919808,4.9988604,,,,,,,,,,,,,, -118100,2.1327708,4.535395,,,,,,,,,,,,,, -118200,2.0924196,3.7107959,,,,,,,,,,,,,, -118300,2.1863906,3.5457788,,,,,,,,,,,,,, -118400,2.5330162,2.9328632,,,,,,,,,,,,,, -118500,2.0221558,3.978305,,,,,,,,,,,,,, -118600,2.300767,3.118719,,,,,,,,,,,,,, -118612,,,0.724804699420929,1.209438681602478,0.6693399548530579,1.4543384313583374,50000.0,0.5512000322341919,2.073585271835327,10000.0,53824.20411038399,58510.82946014404,53824.20411038399,4675.380663156509,5.025782823562622,0.0 -118700,2.2908733,2.842093,,,,,,,,,,,,,, -118800,2.0773554,3.034472,,,,,,,,,,,,,, -118900,2.290571,2.7397256,,,,,,,,,,,,,, -119000,2.5318003,3.069756,,,,,,,,,,,,,, -119100,2.264211,3.5111816,,,,,,,,,,,,,, -119200,2.0908606,4.1721992,,,,,,,,,,,,,, -119300,2.3533964,2.762715,,,,,,,,,,,,,, -119400,2.1844637,3.8600755,,,,,,,,,,,,,, -119500,2.420194,2.9952912,,,,,,,,,,,,,, -119537,,,0.7267187237739563,1.2144628763198853,0.670520007610321,1.462047100067139,50000.0,0.5481000542640686,2.072385549545288,10000.0,54244.48342585564,58971.95789599419,54244.48342585564,4716.138568401337,5.069893598556519,0.0 -119600,2.2646375,3.8965628,,,,,,,,,,,,,, -119700,2.26195,4.6039343,,,,,,,,,,,,,, -119800,2.1973686,3.4465637,,,,,,,,,,,,,, -119900,2.5708282,2.7870708,,,,,,,,,,,,,, -120000,2.6062882,3.0146105,,,,,,,,,,,,,, -120100,2.1758657,2.8455691,,,,,,,,,,,,,, -120200,2.515536,3.0227838,,,,,,,,,,,,,, -120300,2.4327557,3.3401814,,,,,,,,,,,,,, -120400,2.4562836,2.8658586,,,,,,,,,,,,,, -120462,,,0.7491210699081421,1.108149766921997,0.6756599545478821,1.4270612001419067,50000.0,0.5540000200271606,2.0518739223480225,10000.0,54664.61124134064,59429.12511634827,54664.61124134064,4753.07851433754,5.121261596679688,0.0 -120500,2.318334,2.8681705,,,,,,,,,,,,,, -120600,2.5339653,2.8755703,,,,,,,,,,,,,, -120700,2.409985,2.835534,,,,,,,,,,,,,, -120800,2.217463,3.7689788,,,,,,,,,,,,,, -120900,2.38116,3.00528,,,,,,,,,,,,,, -121000,2.3568482,4.4258947,,,,,,,,,,,,,, -121100,2.4940493,2.8226452,,,,,,,,,,,,,, -121200,2.213104,4.3157463,,,,,,,,,,,,,, -121300,2.1557746,3.2369747,,,,,,,,,,,,,, -121388,,,0.7331249713897705,1.1577279567718506,0.6822599768638611,1.3907045125961304,50000.0,0.5573000311851501,2.0155694484710693,10000.0,55084.6750433445,59890.00051164627,55084.6750433445,4793.797434568405,5.165852785110474,0.0 -121400,2.2253914,4.215461,,,,,,,,,,,,,, -121500,2.361024,2.7738094,,,,,,,,,,,,,, -121600,2.6376355,2.7989597,,,,,,,,,,,,,, -121700,2.3564055,2.79475,,,,,,,,,,,,,, -121800,2.227179,3.153706,,,,,,,,,,,,,, -121900,2.5322433,2.8664186,,,,,,,,,,,,,, -122000,2.3835316,2.850706,,,,,,,,,,,,,, -122100,2.3826349,2.7873201,,,,,,,,,,,,,, -122200,2.191743,3.1713467,,,,,,,,,,,,,, -122300,2.430794,3.1683192,,,,,,,,,,,,,, -122313,,,0.7346875071525574,1.1657814979553225,0.675599992275238,1.4246152639389038,50000.0,0.5590000152587891,2.026743173599243,10000.0,55504.61960840225,60348.186989068985,55504.61960840225,4831.949808120728,5.2081685066223145,0.0 -122400,2.4070523,4.3878155,,,,,,,,,,,,,, -122500,2.3668451,2.8522186,,,,,,,,,,,,,, -122600,2.215058,2.745631,,,,,,,,,,,,,, -122700,2.6913378,2.8050544,,,,,,,,,,,,,, -122800,2.6207094,2.8313918,,,,,,,,,,,,,, -122900,2.1739538,2.9989216,,,,,,,,,,,,,, -123000,2.265946,3.3232412,,,,,,,,,,,,,, -123100,1.9879959,3.7527783,,,,,,,,,,,,,, -123200,2.3413873,2.8456037,,,,,,,,,,,,,, -123236,,,0.7521093487739563,1.1196391582489014,0.6818400025367737,1.4134577512741089,50000.0,0.5630000233650208,2.019491672515869,10000.0,55924.80131602287,60808.3349378109,55924.80131602287,4871.815351247788,5.260644197463989,0.0 -123300,2.5067735,2.9559212,,,,,,,,,,,,,, -123400,2.4960463,2.8162014,,,,,,,,,,,,,, -123500,2.5627675,2.9182527,,,,,,,,,,,,,, -123600,2.1782174,4.5628777,,,,,,,,,,,,,, -123700,2.3621156,2.8016877,,,,,,,,,,,,,, -123800,2.305402,4.371302,,,,,,,,,,,,,, -123900,2.185181,4.0671196,,,,,,,,,,,,,, -124000,2.4304194,2.7547162,,,,,,,,,,,,,, -124100,2.3795702,2.966619,,,,,,,,,,,,,, -124159,,,0.7383007407188416,1.1520472764968872,0.6840800046920776,1.3852416276931765,50000.0,0.560200035572052,2.0003364086151123,10000.0,56345.114104270935,61269.6905901432,56345.114104270935,4912.768722057343,5.302764654159546,0.0 -124200,2.5000389,2.787468,,,,,,,,,,,,,, -124300,2.5400646,2.817443,,,,,,,,,,,,,, -124400,2.3121932,3.635969,,,,,,,,,,,,,, -124500,2.1361952,3.7970634,,,,,,,,,,,,,, -124600,2.1701047,3.4126894,,,,,,,,,,,,,, -124700,2.2262688,3.8753064,,,,,,,,,,,,,, -124800,2.3949883,2.716964,,,,,,,,,,,,,, -124900,2.446817,2.7217522,,,,,,,,,,,,,, -125000,2.4199233,2.8956609,,,,,,,,,,,,,, -125086,,,0.7397655844688416,1.1573693752288818,0.6845999956130981,1.4100217819213867,50000.0,0.558899998664856,2.0429952144622803,10000.0,56765.40418744087,61728.776727199554,56765.40418744087,4951.472690105438,5.346803903579712,0.0 -125100,2.6662445,2.7168565,,,,,,,,,,,,,, -125200,2.343427,4.16923,,,,,,,,,,,,,, -125300,2.3320441,2.891007,,,,,,,,,,,,,, -125400,2.762535,2.6880817,,,,,,,,,,,,,, -125500,2.5521486,2.7379496,,,,,,,,,,,,,, -125600,2.2759757,4.476016,,,,,,,,,,,,,, -125700,2.2187831,3.72317,,,,,,,,,,,,,, -125800,2.4207792,3.0130737,,,,,,,,,,,,,, -125900,2.476293,4.8354397,,,,,,,,,,,,,, -126000,2.8542545,2.692937,,,,,,,,,,,,,, -126011,,,0.7486132383346558,1.118463397026062,0.6848799586296082,1.392722725868225,50000.0,0.5542000532150269,2.0321836471557617,10000.0,57185.52679872513,62190.26622200012,57185.52679872513,4992.742039680481,5.39652943611145,0.0 -126100,2.5518909,3.067135,,,,,,,,,,,,,, -126200,2.6170504,3.168226,,,,,,,,,,,,,, -126300,2.4060545,4.835778,,,,,,,,,,,,,, -126400,2.4830449,4.738393,,,,,,,,,,,,,, -126500,2.807015,2.6761694,,,,,,,,,,,,,, -126600,2.4019778,4.45207,,,,,,,,,,,,,, -126700,2.5606918,2.8965454,,,,,,,,,,,,,, -126800,2.4186985,2.7815323,,,,,,,,,,,,,, -126900,2.1643937,3.786946,,,,,,,,,,,,,, -126934,,,0.7451757788658142,1.1258550882339478,0.686739981174469,1.3771331310272217,50000.0,0.5667000412940979,1.988448977470398,10000.0,57605.49915289879,62649.33008289337,57605.49915289879,5031.7367787361145,5.445305347442627,0.0 -127000,2.5605907,2.8418865,,,,,,,,,,,,,, -127100,2.6977015,2.6635108,,,,,,,,,,,,,, -127200,2.5683916,3.0382004,,,,,,,,,,,,,, -127300,2.4484491,2.833065,,,,,,,,,,,,,, -127400,2.401334,2.6324854,,,,,,,,,,,,,, -127500,2.6806755,2.7474718,,,,,,,,,,,,,, -127600,2.8120863,2.8248372,,,,,,,,,,,,,, -127700,2.404517,2.9443343,,,,,,,,,,,,,, -127800,2.6773288,2.8069575,,,,,,,,,,,,,, -127859,,,0.7473437190055847,1.1448750495910645,0.6888999938964844,1.3919219970703125,50000.0,0.5637000203132629,2.017214298248291,10000.0,58025.73322200775,63110.56328248978,58025.73322200775,5072.641215085983,5.489341497421265,0.0 -127900,2.2935097,4.1509604,,,,,,,,,,,,,, -128000,2.6930132,2.699334,,,,,,,,,,,,,, -128100,2.5842438,3.2533019,,,,,,,,,,,,,, -128200,2.6769876,2.8095877,,,,,,,,,,,,,, -128300,2.5035472,2.791458,,,,,,,,,,,,,, -128400,2.4037943,4.7307787,,,,,,,,,,,,,, -128500,2.3056886,3.897868,,,,,,,,,,,,,, -128600,2.8920672,3.5129857,,,,,,,,,,,,,, -128700,2.4459832,3.8768487,,,,,,,,,,,,,, -128782,,,0.753710925579071,1.085994005203247,0.6908999681472778,1.3590346574783323,50000.0,0.5646000504493713,1.990260481834412,10000.0,58445.72785902023,63571.92673802376,58445.72785902023,5113.915201663971,5.5364158153533936,0.0 -128800,2.8182485,2.73761,,,,,,,,,,,,,, -128900,2.6709027,2.965612,,,,,,,,,,,,,, -129000,2.4111807,3.7372122,,,,,,,,,,,,,, -129100,2.467992,3.5291262,,,,,,,,,,,,,, -129200,2.541331,3.3326032,,,,,,,,,,,,,, -129300,2.624773,2.7355018,,,,,,,,,,,,,, -129400,2.9254332,2.6894908,,,,,,,,,,,,,, -129500,2.4822493,3.157474,,,,,,,,,,,,,, -129600,2.5630374,2.8574505,,,,,,,,,,,,,, -129700,2.3526957,4.8342113,,,,,,,,,,,,,, -129707,,,0.7651562094688416,1.0557949542999268,0.6894400119781494,1.376419186592102,50000.0,0.562000036239624,2.0006537437438965,10000.0,58865.84973907471,64032.96802806854,58865.84973907471,5154.739242553711,5.584019184112549,0.0 -129800,2.5019388,2.6575348,,,,,,,,,,,,,, -129900,3.254937,4.84643,,,,,,,,,,,,,, -130000,2.846138,3.2081068,,,,,,,,,,,,,, -130100,2.3557503,3.6538196,,,,,,,,,,,,,, -130200,2.5159214,4.094752,,,,,,,,,,,,,, -130300,2.8413017,2.7292185,,,,,,,,,,,,,, -130400,2.7700386,4.277227,,,,,,,,,,,,,, -130500,2.7054865,2.6479263,,,,,,,,,,,,,, -130600,2.5579643,4.2682943,,,,,,,,,,,,,, -130627,,,0.7515429258346558,1.082430362701416,0.6952399611473083,1.3379631042480469,50000.0,0.5738000273704529,1.951003074645996,10000.0,59285.97770404816,64494.23661541939,59285.97770404816,5195.786374568939,5.630152940750122,0.0 -130700,2.7522755,2.7188377,,,,,,,,,,,,,, -130800,2.3593643,3.6619046,,,,,,,,,,,,,, -130900,2.758334,4.3344445,,,,,,,,,,,,,, -131000,2.5056286,2.6635976,,,,,,,,,,,,,, -131100,2.469119,3.8811173,,,,,,,,,,,,,, -131200,3.027609,2.7863665,,,,,,,,,,,,,, -131300,2.8054578,4.437685,,,,,,,,,,,,,, -131400,2.7759516,3.9104328,,,,,,,,,,,,,, -131500,2.6034718,4.6789007,,,,,,,,,,,,,, -131551,,,0.7570898532867432,1.0904736518859863,0.6944199800491333,1.3673386573791504,50000.0,0.5703999996185303,1.9871017932891848,10000.0,59706.24283266068,64950.979976415634,59706.24283266068,5232.168748378754,5.67841386795044,0.0 -131600,2.840279,2.9899683,,,,,,,,,,,,,, -131700,2.6728501,2.6910625,,,,,,,,,,,,,, -131800,2.609059,2.8166552,,,,,,,,,,,,,, -131900,2.5190556,2.6837964,,,,,,,,,,,,,, -132000,2.9761708,2.8110566,,,,,,,,,,,,,, -132100,2.7722278,2.7599802,,,,,,,,,,,,,, -132200,2.7039742,4.4909945,,,,,,,,,,,,,, -132300,2.834429,2.640015,,,,,,,,,,,,,, -132400,2.6241012,2.7993739,,,,,,,,,,,,,, -132475,,,0.769238293170929,1.0069472789764404,0.7003200054168701,1.3125197887420654,50000.0,0.5768000483512878,1.9146699905395508,10000.0,60126.579996824265,65411.887921094894,60126.579996824265,5272.644042253494,5.7267396450042725,0.0 -132500,2.8180158,2.896297,,,,,,,,,,,,,, -132600,2.6653488,2.6030152,,,,,,,,,,,,,, -132700,2.803698,2.6198452,,,,,,,,,,,,,, -132800,2.7446835,3.8039122,,,,,,,,,,,,,, -132900,2.7748442,2.7903466,,,,,,,,,,,,,, -133000,2.695586,2.9383893,,,,,,,,,,,,,, -133100,2.766896,2.68208,,,,,,,,,,,,,, -133200,2.9901161,2.6387963,,,,,,,,,,,,,, -133300,2.418732,3.624841,,,,,,,,,,,,,, -133400,2.9619277,2.7643936,,,,,,,,,,,,,, -133401,,,0.757519543170929,1.062558889389038,0.6997599601745605,1.305146098136902,50000.0,0.5755000114440918,1.93559992313385,10000.0,60547.2616622448,65874.38661026955,60547.2616622448,5314.366637706757,5.773090362548828,0.0 -133500,3.0967166,2.8210852,,,,,,,,,,,,,, -133600,2.7923908,4.515991,,,,,,,,,,,,,, -133700,2.6779437,4.040703,,,,,,,,,,,,,, -133800,3.0774655,2.7510774,,,,,,,,,,,,,, -133900,2.7713456,2.7716424,,,,,,,,,,,,,, -134000,2.7279987,3.0741267,,,,,,,,,,,,,, -134100,2.5504816,3.1205444,,,,,,,,,,,,,, -134200,2.572896,3.4271603,,,,,,,,,,,,,, -134300,3.0415823,4.671838,,,,,,,,,,,,,, -134327,,,0.7621874809265137,1.0537630319595337,0.700980007648468,1.31785249710083,50000.0,0.5751000046730042,1.9424078464508057,10000.0,60967.64869689941,66331.51563692093,60967.64869689941,5351.010791301727,5.823288679122925,0.0 -134400,2.5551507,3.005569,,,,,,,,,,,,,, -134500,2.8987052,2.8496075,,,,,,,,,,,,,, -134600,3.2345126,2.6460063,,,,,,,,,,,,,, -134700,2.438314,3.2654438,,,,,,,,,,,,,, -134800,2.4951131,3.486089,,,,,,,,,,,,,, -134900,3.7456498,2.824974,,,,,,,,,,,,,, -135000,2.7310338,3.2787604,,,,,,,,,,,,,, -135100,2.7504537,2.6859825,,,,,,,,,,,,,, -135200,2.8105726,2.5678244,,,,,,,,,,,,,, -135250,,,0.7675390243530273,1.0237082242965698,0.6987599730491638,1.3225077390670776,50000.0,0.5711000561714172,1.9518496990203853,10000.0,61387.66164803505,66792.39415669441,61387.66164803505,5391.782269239426,5.869398355484009,0.0 -135300,2.8881557,2.7686076,,,,,,,,,,,,,, -135400,2.8689473,2.6285017,,,,,,,,,,,,,, -135500,3.0152962,2.7475076,,,,,,,,,,,,,, -135600,3.0183163,2.755862,,,,,,,,,,,,,, -135700,2.7727313,2.9125702,,,,,,,,,,,,,, -135800,2.7961686,2.799457,,,,,,,,,,,,,, -135900,2.9078348,2.908799,,,,,,,,,,,,,, -136000,2.8227148,3.174289,,,,,,,,,,,,,, -136100,2.8979213,3.1701362,,,,,,,,,,,,,, -136174,,,0.764843761920929,1.0483845472335815,0.7023999691009521,1.321294188499451,50000.0,0.5789000391960144,1.929173469543457,10000.0,61807.8627281189,67251.33851337433,61807.8627281189,5430.429076910019,5.918476819992065,0.0 -136200,2.9586236,2.6732917,,,,,,,,,,,,,, -136300,2.5897892,3.6116774,,,,,,,,,,,,,, -136400,2.9116118,2.6315799,,,,,,,,,,,,,, -136500,3.2790968,2.7524915,,,,,,,,,,,,,, -136600,3.411429,2.6745536,,,,,,,,,,,,,, -136700,2.769404,2.640717,,,,,,,,,,,,,, -136800,2.952683,2.723394,,,,,,,,,,,,,, -136900,2.8757744,2.6029725,,,,,,,,,,,,,, -137000,2.806549,2.565077,,,,,,,,,,,,,, -137100,2.7878294,4.0464635,,,,,,,,,,,,,, -137101,,,0.7681640386581421,1.0459219217300415,0.7051399946212769,1.3112800121307373,50000.0,0.58160001039505,1.9345344305038448,10000.0,62228.40162968636,67711.19112372398,62228.40162968636,5469.642811059952,5.971429824829102,0.0 -137200,2.9228303,2.6888614,,,,,,,,,,,,,, -137300,2.8774748,2.6161547,,,,,,,,,,,,,, -137400,3.161305,2.6758966,,,,,,,,,,,,,, -137500,2.8821974,4.3540173,,,,,,,,,,,,,, -137600,3.0825398,2.7359068,,,,,,,,,,,,,, -137700,2.803317,2.6027257,,,,,,,,,,,,,, -137800,3.1308994,2.5800097,,,,,,,,,,,,,, -137900,2.8429363,2.6050186,,,,,,,,,,,,,, -138000,2.9540186,2.6643102,,,,,,,,,,,,,, -138024,,,0.7739452719688416,1.0145626068115234,0.706559956073761,1.2995346784591677,50000.0,0.5817000269889832,1.9207830429077148,10000.0,62648.38740777969,68172.46992921829,62648.38740777969,5510.844088315964,6.016285419464111,0.0 -138100,3.0046248,3.6212869,,,,,,,,,,,,,, -138200,2.8872254,2.6658688,,,,,,,,,,,,,, -138300,2.9719324,2.6207771,,,,,,,,,,,,,, -138400,3.0197606,4.3883004,,,,,,,,,,,,,, -138500,2.8057506,2.6243823,,,,,,,,,,,,,, -138600,3.2239,2.66851,,,,,,,,,,,,,, -138700,2.9793437,2.579208,,,,,,,,,,,,,, -138800,2.885836,2.5878038,,,,,,,,,,,,,, -138900,3.1191392,2.6680012,,,,,,,,,,,,,, -138947,,,0.7826952934265137,0.9653533697128296,0.7051399946212769,1.3022921085357666,50000.0,0.579200029373169,1.925778031349182,10000.0,63068.49566245079,68633.90044283867,63068.49566245079,5552.058122396469,6.076954126358032,0.0 -139000,3.0825336,3.8903847,,,,,,,,,,,,,, -139100,3.0760534,2.6386602,,,,,,,,,,,,,, -139200,3.0874903,2.6503248,,,,,,,,,,,,,, -139300,3.14924,2.912624,,,,,,,,,,,,,, -139400,2.7507238,3.9284568,,,,,,,,,,,,,, -139500,2.9628813,4.3632097,,,,,,,,,,,,,, -139600,3.2309253,2.6830506,,,,,,,,,,,,,, -139700,3.313119,2.5093014,,,,,,,,,,,,,, -139800,2.8815813,3.7516608,,,,,,,,,,,,,, -139869,,,0.7697656154632568,1.0047340393066406,0.7098000049591064,1.2667945623397827,50000.0,0.5893000364303589,1.8732690811157229,10000.0,63488.75289964676,69093.70464968681,63488.75289964676,5591.511062860489,6.124045372009277,0.0 -139900,2.8164408,2.9976602,,,,,,,,,,,,,, -140000,2.761999,3.6648655,,,,,,,,,,,,,, -140100,3.0134377,2.4057326,,,,,,,,,,,,,, -140200,3.1708927,2.6444964,,,,,,,,,,,,,, -140300,2.887961,2.7120612,,,,,,,,,,,,,, -140400,3.0391598,4.1386604,,,,,,,,,,,,,, -140500,2.8497224,3.3381577,,,,,,,,,,,,,, -140600,3.155278,2.655078,,,,,,,,,,,,,, -140700,3.2982547,2.6817915,,,,,,,,,,,,,, -140795,,,0.7763866782188416,0.9946498870849608,0.7087399959564209,1.28446626663208,50000.0,0.5814000368118286,1.9018096923828125,10000.0,63908.82472872734,69551.93505644798,63908.82472872734,5629.565707683563,6.17889142036438,0.0 -140800,2.9384406,2.6344488,,,,,,,,,,,,,, -140900,3.3332095,2.630745,,,,,,,,,,,,,, -141000,3.2283134,4.338232,,,,,,,,,,,,,, -141100,2.862116,2.964431,,,,,,,,,,,,,, -141200,2.902077,2.676174,,,,,,,,,,,,,, -141300,3.1804636,2.560028,,,,,,,,,,,,,, -141400,2.7552722,2.645001,,,,,,,,,,,,,, -141500,3.2664483,4.2669835,,,,,,,,,,,,,, -141600,3.0536613,2.5676348,,,,,,,,,,,,,, -141700,3.3169112,2.6346543,,,,,,,,,,,,,, -141719,,,0.7819921970367432,0.9503366351127625,0.7135999798774719,1.2518689632415771,50000.0,0.5896000266075134,1.8770052194595337,10000.0,64329.11201548576,70011.48867511749,64329.11201548576,5668.733921766281,6.22956919670105,0.0 -141800,2.969514,3.3562617,,,,,,,,,,,,,, -141900,2.8581717,2.6111455,,,,,,,,,,,,,, -142000,2.7678878,3.9245443,,,,,,,,,,,,,, -142100,2.7922666,3.871713,,,,,,,,,,,,,, -142200,3.0775237,2.6529253,,,,,,,,,,,,,, -142300,3.1527212,2.576105,,,,,,,,,,,,,, -142400,3.1125054,4.3302293,,,,,,,,,,,,,, -142500,3.055268,3.6545284,,,,,,,,,,,,,, -142600,3.2601297,2.6224415,,,,,,,,,,,,,, -142641,,,0.7760937213897705,0.9692516326904296,0.7108399868011475,1.23958420753479,50000.0,0.5865000486373901,1.8613325357437127,10000.0,64749.359623909,70469.93096780777,64749.359623909,5706.832077026367,6.278799533843994,0.0 -142700,3.217925,2.5669732,,,,,,,,,,,,,, -142800,3.2586975,3.7625332,,,,,,,,,,,,,, -142900,3.3024347,2.6601448,,,,,,,,,,,,,, -143000,3.4007273,4.2982764,,,,,,,,,,,,,, -143100,2.9342103,2.7350035,,,,,,,,,,,,,, -143200,2.9608662,4.0968857,,,,,,,,,,,,,, -143300,3.1045423,2.820893,,,,,,,,,,,,,, -143400,3.201281,2.6191635,,,,,,,,,,,,,, -143500,3.7467227,2.6372814,,,,,,,,,,,,,, -143566,,,0.778613269329071,0.9984338283538818,0.7130399942398071,1.2804381847381592,50000.0,0.5878000259399414,1.894475340843201,10000.0,65169.36147618294,70930.76969718933,65169.36147618294,5747.56706738472,6.333415985107422,0.0 -143600,3.0858502,3.1505535,,,,,,,,,,,,,, -143700,3.366118,2.7804089,,,,,,,,,,,,,, -143800,3.1479716,2.701108,,,,,,,,,,,,,, -143900,3.029183,2.538218,,,,,,,,,,,,,, -144000,3.1528015,2.6611643,,,,,,,,,,,,,, -144100,3.4814904,3.9446635,,,,,,,,,,,,,, -144200,3.1855595,3.0987682,,,,,,,,,,,,,, -144300,3.1862369,2.8879924,,,,,,,,,,,,,, -144400,3.4807835,2.5113199,,,,,,,,,,,,,, -144491,,,0.7850390672683716,0.9747884273529052,0.7129600048065186,1.2740776538848877,50000.0,0.5898000001907349,1.884268045425415,10000.0,65589.43929386139,71392.52113199234,65589.43929386139,5789.146682500839,6.380536079406738,0.0 -144500,3.2410588,2.641115,,,,,,,,,,,,,, -144600,3.1043978,3.6302161,,,,,,,,,,,,,, -144700,3.0285175,3.074037,,,,,,,,,,,,,, -144800,3.1081154,2.8353791,,,,,,,,,,,,,, -144900,3.7136335,3.4806268,,,,,,,,,,,,,, -145000,3.5600402,3.0020728,,,,,,,,,,,,,, -145100,3.3294237,2.8527002,,,,,,,,,,,,,, -145200,3.3696074,2.6600385,,,,,,,,,,,,,, -145300,3.021271,2.7647023,,,,,,,,,,,,,, -145400,3.2005014,4.0481505,,,,,,,,,,,,,, -145414,,,0.77845698595047,0.9919044971466064,0.7156599760055542,1.2642072439193726,50000.0,0.5966000556945801,1.8826606273651123,10000.0,66009.40965223312,71853.4624812603,66009.40965223312,5830.013171672821,6.438591480255127,0.0 -145500,3.4834113,2.9262054,,,,,,,,,,,,,, -145600,3.9772875,4.5604057,,,,,,,,,,,,,, -145700,3.5444794,2.674994,,,,,,,,,,,,,, -145800,3.7201052,2.527428,,,,,,,,,,,,,, -145900,3.9984798,2.5491579,,,,,,,,,,,,,, -146000,3.5079691,2.573698,,,,,,,,,,,,,, -146100,3.1374085,3.5257854,,,,,,,,,,,,,, -146200,3.6600957,4.4464564,,,,,,,,,,,,,, -146300,3.462332,2.3845391,,,,,,,,,,,,,, -146341,,,0.7854687571525574,0.936896562576294,0.7184799909591675,1.2213993072509766,50000.0,0.5951000452041626,1.8369319438934328,10000.0,66429.68575835228,72311.18190431595,66429.68575835228,5867.35782957077,6.490723133087158,0.0 -146400,3.8244007,4.566668,,,,,,,,,,,,,, -146500,3.5450034,2.4887667,,,,,,,,,,,,,, -146600,3.6191978,4.3723235,,,,,,,,,,,,,, -146700,3.801799,2.6134505,,,,,,,,,,,,,, -146800,3.285107,2.561232,,,,,,,,,,,,,, -146900,3.7158675,2.4038904,,,,,,,,,,,,,, -147000,3.4143665,2.7472606,,,,,,,,,,,,,, -147100,3.5377126,2.5103023,,,,,,,,,,,,,, -147200,3.6436467,3.0776577,,,,,,,,,,,,,, -147267,,,0.791308581829071,0.9221143126487732,0.7231599688529968,1.2180984020233154,50000.0,0.602400004863739,1.8256595134735107,10000.0,66849.90239930153,72770.54873371124,66849.90239930153,5906.406289815903,6.545153379440308,0.0 -147300,3.4798357,2.4959767,,,,,,,,,,,,,, -147400,3.459355,3.3385346,,,,,,,,,,,,,, -147500,3.9589238,4.119162,,,,,,,,,,,,,, -147600,3.009428,3.2861156,,,,,,,,,,,,,, -147700,3.5608494,2.9028008,,,,,,,,,,,,,, -147800,3.4043846,2.5844774,,,,,,,,,,,,,, -147900,3.3130717,2.4648066,,,,,,,,,,,,,, -148000,3.7657828,2.7787414,,,,,,,,,,,,,, -148100,3.6340394,4.060913,,,,,,,,,,,,,, -148193,,,0.8017187118530273,0.8867132663726807,0.7227199673652649,1.2239965200424194,50000.0,0.6012000441551208,1.8258556127548216,10000.0,67269.83102846146,73231.43658638,67269.83102846146,5947.267498254776,6.596628427505493,0.0 -148200,3.6486957,2.5210562,,,,,,,,,,,,,, -148300,3.6561608,2.6124384,,,,,,,,,,,,,, -148400,3.4911003,2.5490406,,,,,,,,,,,,,, -148500,3.5354505,4.3119364,,,,,,,,,,,,,, -148600,3.52884,2.5923264,,,,,,,,,,,,,, -148700,3.558841,2.865343,,,,,,,,,,,,,, -148800,3.4714127,2.6159756,,,,,,,,,,,,,, -148900,3.5060039,2.939373,,,,,,,,,,,,,, -149000,4.329131,4.386219,,,,,,,,,,,,,, -149100,3.6636837,2.531009,,,,,,,,,,,,,, -149120,,,0.7932226657867432,0.906248927116394,0.7246599793434143,1.2037746906280518,50000.0,0.6017000079154968,1.801848292350769,10000.0,67690.06477189064,73689.20619821548,67690.06477189064,5984.706177949905,6.646602392196655,0.0 -149200,3.6045291,2.7264948,,,,,,,,,,,,,, -149300,3.1834881,3.256159,,,,,,,,,,,,,, -149400,3.1728704,2.94897,,,,,,,,,,,,,, -149500,3.9373841,4.1900163,,,,,,,,,,,,,, -149600,3.718753,3.7829742,,,,,,,,,,,,,, -149700,4.018608,2.6076512,,,,,,,,,,,,,, -149800,3.5401003,2.4239988,,,,,,,,,,,,,, -149900,3.5731847,2.6273687,,,,,,,,,,,,,, -150000,4.008464,2.6467023,,,,,,,,,,,,,, -150046,,,0.7907617092132568,0.9388483762741088,0.7214599847793579,1.2328137159347534,50000.0,0.6027000546455383,1.8345398902893064,10000.0,68110.0372145176,74148.33674430847,68110.0372145176,6023.770362854004,6.693514585494995,0.0 -150100,3.7632964,2.532368,,,,,,,,,,,,,, -150200,3.9069731,4.07922,,,,,,,,,,,,,, -150300,3.936102,4.4448457,,,,,,,,,,,,,, -150400,3.8088555,2.57048,,,,,,,,,,,,,, -150500,3.867344,2.5145698,,,,,,,,,,,,,, -150600,3.2834928,2.6348205,,,,,,,,,,,,,, -150700,3.8310819,4.055469,,,,,,,,,,,,,, -150800,3.4628825,2.7937462,,,,,,,,,,,,,, -150900,4.0969086,3.993544,,,,,,,,,,,,,, -150968,,,0.8061718344688416,0.8653810024261475,0.7301200032234192,1.189891338348389,50000.0,0.6043000221252441,1.8021141290664675,10000.0,68529.95536541939,74611.0534362793,68529.95536541939,6066.470200300217,6.744734525680542,0.0 -151000,4.079763,2.54597,,,,,,,,,,,,,, -151100,3.8585534,2.4604511,,,,,,,,,,,,,, -151200,3.6298704,2.5495818,,,,,,,,,,,,,, -151300,3.821417,2.7117965,,,,,,,,,,,,,, -151400,4.00089,2.4014263,,,,,,,,,,,,,, -151500,3.7111182,2.8052468,,,,,,,,,,,,,, -151600,3.600886,3.1130702,,,,,,,,,,,,,, -151700,3.4367383,2.9976234,,,,,,,,,,,,,, -151800,3.7167966,2.6632733,,,,,,,,,,,,,, -151891,,,0.7955273389816284,0.8970678448677063,0.7301200032234192,1.1849876642227173,50000.0,0.6104000210762024,1.778713345527649,10000.0,68950.07535648346,75072.84614467621,68950.07535648346,6108.044515609741,6.796485185623169,0.0 -151900,3.691994,3.7929065,,,,,,,,,,,,,, -152000,3.738798,2.467391,,,,,,,,,,,,,, -152100,3.5365634,3.2522893,,,,,,,,,,,,,, -152200,3.9386969,3.918543,,,,,,,,,,,,,, -152300,3.681706,2.5426724,,,,,,,,,,,,,, -152400,3.65609,2.3432112,,,,,,,,,,,,,, -152500,4.249145,2.4615393,,,,,,,,,,,,,, -152600,4.127778,2.4783945,,,,,,,,,,,,,, -152700,4.0840716,4.195085,,,,,,,,,,,,,, -152800,3.3829422,3.0100033,,,,,,,,,,,,,, -152814,,,0.8013476133346558,0.8696554899215698,0.7305999994277954,1.1746702194213867,50000.0,0.6066000461578369,1.7887028455734253,10000.0,69370.09718680382,75533.73214793205,69370.09718680382,6148.80455327034,6.853741884231567,0.0 -152900,4.1131644,2.47506,,,,,,,,,,,,,, -153000,3.8995912,3.0546927,,,,,,,,,,,,,, -153100,4.1992836,2.4458318,,,,,,,,,,,,,, -153200,3.9646451,2.439017,,,,,,,,,,,,,, -153300,3.9118483,2.6521635,,,,,,,,,,,,,, -153400,4.3264117,2.5755985,,,,,,,,,,,,,, -153500,4.0766587,2.4995384,,,,,,,,,,,,,, -153600,4.416893,2.4429705,,,,,,,,,,,,,, -153700,4.456159,4.3144226,,,,,,,,,,,,,, -153736,,,0.8064843416213989,0.8482201099395752,0.7335599660873413,1.1666568517684937,50000.0,0.6065000295639038,1.7751833200454712,10000.0,69790.06775164604,75993.92200398445,69790.06775164604,6188.927048921585,6.903278827667236,0.0 -153800,3.953729,4.0395303,,,,,,,,,,,,,, -153900,4.2034483,2.7349164,,,,,,,,,,,,,, -154000,4.10118,2.4603064,,,,,,,,,,,,,, -154100,4.170846,2.4817786,,,,,,,,,,,,,, -154200,3.6296098,2.5513709,,,,,,,,,,,,,, -154300,4.1044374,2.5056734,,,,,,,,,,,,,, -154400,4.624526,3.0989397,,,,,,,,,,,,,, -154500,3.921249,3.5160213,,,,,,,,,,,,,, -154600,4.5493903,3.69572,,,,,,,,,,,,,, -154662,,,0.804492175579071,0.8526339530944824,0.7369999885559082,1.151193141937256,50000.0,0.6170000433921814,1.75464928150177,10000.0,70210.01730847359,76454.94499826431,70210.01730847359,6229.90651512146,6.94967794418335,0.0 -154700,3.7852416,2.3376055,,,,,,,,,,,,,, -154800,3.7630157,2.7078505,,,,,,,,,,,,,, -154900,4.0564404,2.5514412,,,,,,,,,,,,,, -155000,4.1054826,2.487822,,,,,,,,,,,,,, -155100,4.0029054,2.4181905,,,,,,,,,,,,,, -155200,3.8169918,2.6139898,,,,,,,,,,,,,, -155300,4.345041,4.3354096,,,,,,,,,,,,,, -155400,4.0157247,2.381062,,,,,,,,,,,,,, -155500,4.4441075,2.8895063,,,,,,,,,,,,,, -155584,,,0.811328113079071,0.8368220329284668,0.738860011100769,1.1395392417907717,50000.0,0.6175000071525574,1.73326575756073,10000.0,70630.25607657433,76916.89459323883,70630.25607657433,6271.519709348679,6.999774217605591,0.0 -155600,5.305935,4.4616256,,,,,,,,,,,,,, -155700,3.6566803,3.2479632,,,,,,,,,,,,,, -155800,3.8436809,3.906638,,,,,,,,,,,,,, -155900,4.237891,2.8965786,,,,,,,,,,,,,, -156000,4.425569,2.7149427,,,,,,,,,,,,,, -156100,4.0557547,2.4598165,,,,,,,,,,,,,, -156200,4.2582183,2.3094058,,,,,,,,,,,,,, -156300,4.096556,2.3831675,,,,,,,,,,,,,, -156400,4.2795863,2.4993877,,,,,,,,,,,,,, -156500,4.170824,3.5077477,,,,,,,,,,,,,, -156508,,,0.8138476610183716,0.8137789964675903,0.7384200096130371,1.1399335861206057,50000.0,0.6134000420570374,1.745561599731445,10000.0,71050.56174874306,77375.42791485786,71050.56174874306,6309.6483726501465,7.052521228790283,0.0 -156600,4.359539,2.8583724,,,,,,,,,,,,,, -156700,4.3615274,2.6990325,,,,,,,,,,,,,, -156800,4.4282713,4.283901,,,,,,,,,,,,,, -156900,4.1219196,2.3971581,,,,,,,,,,,,,, -157000,4.069471,2.7472005,,,,,,,,,,,,,, -157100,4.3209224,2.484463,,,,,,,,,,,,,, -157200,3.9388962,3.1467414,,,,,,,,,,,,,, -157300,4.221241,2.4100611,,,,,,,,,,,,,, -157400,4.108594,3.1597197,,,,,,,,,,,,,, -157431,,,0.8206444978713989,0.7946438193321228,0.7416599988937378,1.1222891807556152,50000.0,0.619100034236908,1.722597599029541,10000.0,71470.68224668503,77835.01753425598,71470.68224668503,6349.01679110527,7.105899810791016,0.0 -157500,4.60255,2.415132,,,,,,,,,,,,,, -157600,4.525764,3.927588,,,,,,,,,,,,,, -157700,4.206341,2.410816,,,,,,,,,,,,,, -157800,4.53046,2.494195,,,,,,,,,,,,,, -157900,3.9835956,2.282561,,,,,,,,,,,,,, -158000,4.403176,2.3933372,,,,,,,,,,,,,, -158100,4.0319567,3.4493787,,,,,,,,,,,,,, -158200,3.8494265,2.5490382,,,,,,,,,,,,,, -158285,,,0.81201171875,0.8173814415931702,0.7415399551391602,1.131630539894104,50000.0,0.6203000545501709,1.7233086824417114,10000.0,71890.93097662926,78297.10611104965,71890.93097662926,6390.761610031128,7.158337116241455,0.0 -158300,3.8183262,2.362543,,,,,,,,,,,,,, -158400,4.6752996,2.381338,,,,,,,,,,,,,, -158500,5.1262484,4.2755404,,,,,,,,,,,,,, -158600,4.0216537,2.6103935,,,,,,,,,,,,,, -158700,4.575051,4.0347724,,,,,,,,,,,,,, -158800,4.463932,2.8567772,,,,,,,,,,,,,, -158900,4.5015974,2.3590262,,,,,,,,,,,,,, -159000,4.1778293,3.3702533,,,,,,,,,,,,,, -159100,4.2864556,3.7182245,,,,,,,,,,,,,, -159200,4.4156604,2.536406,,,,,,,,,,,,,, -159207,,,0.8217968344688416,0.7972414493560791,0.7410199642181396,1.1312137842178345,50000.0,0.625700056552887,1.7257004976272583,10000.0,72310.98796439171,78753.62868118286,72310.98796439171,6427.128840446472,7.209970235824585,0.0 -159300,4.230275,2.4866228,,,,,,,,,,,,,, -159400,4.16951,2.6924536,,,,,,,,,,,,,, -159500,4.3823376,2.3508995,,,,,,,,,,,,,, -159600,4.7847686,2.3563101,,,,,,,,,,,,,, -159700,4.4035487,2.3670783,,,,,,,,,,,,,, -159800,4.6697445,2.4854474,,,,,,,,,,,,,, -159900,4.2702055,3.5266519,,,,,,,,,,,,,, -160000,4.4343486,3.1122727,,,,,,,,,,,,,, -160100,4.6063957,2.3693051,,,,,,,,,,,,,, -160129,,,0.8250585794448853,0.7695590257644653,0.7425199747085571,1.1162033081054688,50000.0,0.6238000392913818,1.7093664407730105,10000.0,72731.21923279762,79214.68068003654,72731.21923279762,6467.850144386292,7.2619194984436035,0.0 -160200,4.4714785,2.821301,,,,,,,,,,,,,, -160300,4.204909,3.731617,,,,,,,,,,,,,, -160400,4.456357,2.3424149,,,,,,,,,,,,,, -160500,4.088897,3.011675,,,,,,,,,,,,,, -160600,4.286915,3.1149714,,,,,,,,,,,,,, -160700,4.855698,3.535695,,,,,,,,,,,,,, -160800,4.328449,3.8732765,,,,,,,,,,,,,, -160900,4.45834,2.3506699,,,,,,,,,,,,,, -161000,4.071206,3.1405182,,,,,,,,,,,,,, -161054,,,0.8220117092132568,0.7986540794372559,0.7463399767875671,1.1236673593521118,50000.0,0.62090003490448,1.732371807098389,10000.0,73151.4669148922,79675.36319732666,73151.4669148922,6508.185884714127,7.314181804656982,0.0 -161100,4.529213,2.4488826,,,,,,,,,,,,,, -161200,4.5504637,2.3299916,,,,,,,,,,,,,, -161300,5.3565016,2.475175,,,,,,,,,,,,,, -161400,4.0553,3.6812196,,,,,,,,,,,,,, -161500,4.5487676,2.3048592,,,,,,,,,,,,,, -161600,4.511119,2.3855379,,,,,,,,,,,,,, -161700,4.3467846,2.6591792,,,,,,,,,,,,,, -161800,4.5380683,3.806693,,,,,,,,,,,,,, -161900,4.5737762,2.377841,,,,,,,,,,,,,, -161979,,,0.82289057970047,0.774276852607727,0.7457799911499023,1.096876859664917,50000.0,0.6240000128746033,1.6988554000854492,10000.0,73571.6182346344,80134.4779574871,73571.6182346344,6547.046494960785,7.370314836502075,0.0 -162000,4.5993023,3.1281598,,,,,,,,,,,,,, -162100,4.8712015,2.3053262,,,,,,,,,,,,,, -162200,4.6830845,2.3780336,,,,,,,,,,,,,, -162300,5.2609186,4.1654882,,,,,,,,,,,,,, -162400,4.712434,2.434522,,,,,,,,,,,,,, -162500,4.5573173,2.2824674,,,,,,,,,,,,,, -162600,4.827246,2.5693157,,,,,,,,,,,,,, -162700,4.348932,2.802515,,,,,,,,,,,,,, -162800,5.549857,4.4026837,,,,,,,,,,,,,, -162900,4.643725,2.3517766,,,,,,,,,,,,,, -162902,,,0.82958984375,0.768218994140625,0.7471399903297424,1.1131218671798706,50000.0,0.6278000473976135,1.7115119695663452,10000.0,73991.6619336605,80594.78010439873,73991.6619336605,6587.196877479553,7.43206524848938,0.0 -163000,4.9036555,2.4041257,,,,,,,,,,,,,, -163100,4.440646,2.317734,,,,,,,,,,,,,, -163200,4.386372,2.590894,,,,,,,,,,,,,, -163300,4.3465886,3.2082288,,,,,,,,,,,,,, -163400,4.242778,2.6656961,,,,,,,,,,,,,, -163500,4.583412,2.2387557,,,,,,,,,,,,,, -163600,5.484599,4.2286477,,,,,,,,,,,,,, -163700,4.806463,2.384744,,,,,,,,,,,,,, -163800,5.589351,2.322607,,,,,,,,,,,,,, -163829,,,0.8261523246765137,0.7748673558235168,0.747219979763031,1.0991054773330688,50000.0,0.624500036239624,1.695138692855835,10000.0,74412.06664991379,81053.74798631668,74412.06664991379,6625.662905454636,7.482219219207764,0.0 -163900,4.6818604,2.3784237,,,,,,,,,,,,,, -164000,4.5043244,2.3808315,,,,,,,,,,,,,, -164100,4.4152765,3.3470852,,,,,,,,,,,,,, -164200,4.469691,2.438967,,,,,,,,,,,,,, -164300,4.691882,2.323707,,,,,,,,,,,,,, -164400,4.825071,3.4843063,,,,,,,,,,,,,, -164500,4.6536107,2.3735812,,,,,,,,,,,,,, -164600,4.9609513,2.312037,,,,,,,,,,,,,, -164700,5.2073007,2.3490338,,,,,,,,,,,,,, -164754,,,0.8268163800239563,0.7791228890419006,0.7498199939727783,1.0978347063064575,50000.0,0.6292000412940979,1.695865511894226,10000.0,74831.999625206,81513.33116149902,74831.999625206,6665.215897798538,7.531557321548462,0.0 -164800,5.0486803,2.3232157,,,,,,,,,,,,,, -164900,5.0519223,2.4317966,,,,,,,,,,,,,, -165000,4.533422,2.3764439,,,,,,,,,,,,,, -165100,4.95135,2.3986092,,,,,,,,,,,,,, -165200,4.80088,2.3744512,,,,,,,,,,,,,, -165300,6.426563,4.302765,,,,,,,,,,,,,, -165400,4.9384127,2.3129592,,,,,,,,,,,,,, -165500,4.9108114,2.6827202,,,,,,,,,,,,,, -165600,4.6372695,2.3997967,,,,,,,,,,,,,, -165676,,,0.8290429711341858,0.7599804997444153,0.749239981174469,1.0925363302230835,50000.0,0.6304000020027161,1.6812580823898315,10000.0,75252.15470719337,81971.07764077187,75252.15470719337,6702.711914539337,7.5804502964019775,0.0 -165700,5.0468373,2.2463582,,,,,,,,,,,,,, -165800,5.3499603,4.0813355,,,,,,,,,,,,,, -165900,5.676519,4.3200555,,,,,,,,,,,,,, -166000,5.0914955,2.320167,,,,,,,,,,,,,, -166100,4.6093607,2.4294355,,,,,,,,,,,,,, -166200,4.770245,2.2635872,,,,,,,,,,,,,, -166300,4.621412,2.3838985,,,,,,,,,,,,,, -166400,5.478832,3.9688888,,,,,,,,,,,,,, -166500,5.040826,2.4204652,,,,,,,,,,,,,, -166600,,,0.8282421827316284,0.7497026920318604,0.7534399628639221,1.0718141794204712,50000.0,0.6350000500679016,1.662460446357727,10000.0,75672.05558896065,82432.52755188942,75672.05558896065,6744.158909320831,7.63471245765686,0.0 -166600,6.464906,4.324336,,,,,,,,,,,,,, -166700,5.3172035,2.2194085,,,,,,,,,,,,,, -166800,5.2153807,2.8754342,,,,,,,,,,,,,, -166900,5.4395814,2.9570436,,,,,,,,,,,,,, -167000,4.985899,2.3264334,,,,,,,,,,,,,, -167100,5.342129,2.6931057,,,,,,,,,,,,,, -167200,5.2885685,2.3749943,,,,,,,,,,,,,, -167300,5.227533,2.8997917,,,,,,,,,,,,,, -167400,6.3289247,3.1181414,,,,,,,,,,,,,, -167500,5.957146,2.2388742,,,,,,,,,,,,,, -167525,,,0.8338280916213989,0.7270570993423462,0.7521399855613708,1.0613549947738647,50000.0,0.6344000101089478,1.650127410888672,10000.0,76092.31971812248,82892.44341945648,76092.31971812248,6783.7116050720215,7.687152862548828,0.0 -167600,5.381686,2.396761,,,,,,,,,,,,,, -167700,6.207218,4.2601585,,,,,,,,,,,,,, -167800,4.846186,2.2780285,,,,,,,,,,,,,, -167900,5.806461,2.4463444,,,,,,,,,,,,,, -168000,4.7458644,2.2724621,,,,,,,,,,,,,, -168100,5.331947,2.6108363,,,,,,,,,,,,,, -168200,5.6752977,2.405989,,,,,,,,,,,,,, -168300,4.839433,2.2372766,,,,,,,,,,,,,, -168400,5.3774915,4.102246,,,,,,,,,,,,,, -168447,,,0.8368163704872131,0.721832275390625,0.755840003490448,1.0656956434249878,50000.0,0.6376000046730042,1.6524871587753296,10000.0,76512.35372972488,83350.1652495861,76512.35372972488,6821.297668457031,7.742358684539795,0.0 -168500,4.953656,2.408424,,,,,,,,,,,,,, -168600,4.941088,2.4331727,,,,,,,,,,,,,, -168700,5.225963,3.620211,,,,,,,,,,,,,, -168800,6.1085687,4.1769986,,,,,,,,,,,,,, -168900,4.962438,2.2378168,,,,,,,,,,,,,, -169000,5.735519,4.049717,,,,,,,,,,,,,, -169100,6.163852,4.1270356,,,,,,,,,,,,,, -169200,5.197424,2.2651417,,,,,,,,,,,,,, -169300,5.044522,2.3414197,,,,,,,,,,,,,, -169369,,,0.8389843702316284,0.7196028828620911,0.7570399641990662,1.0582653284072876,50000.0,0.6318000555038452,1.6552094221115112,10000.0,76932.46559858322,83811.53285884857,76932.46559858322,6862.45266699791,7.796396732330322,0.0 -169400,4.992339,2.2771158,,,,,,,,,,,,,, -169500,5.0051084,3.3078613,,,,,,,,,,,,,, -169600,5.026718,2.2251637,,,,,,,,,,,,,, -169700,4.831605,2.244085,,,,,,,,,,,,,, -169800,4.878189,2.2851477,,,,,,,,,,,,,, -169900,4.792595,2.9387026,,,,,,,,,,,,,, -170000,5.0296435,3.2587311,,,,,,,,,,,,,, -170100,5.3804674,3.2294612,,,,,,,,,,,,,, -170200,4.938982,2.1729243,,,,,,,,,,,,,, -170291,,,0.8350781202316284,0.7398861646652222,0.7601999640464783,1.0654001235961914,50000.0,0.6374000310897827,1.660707950592041,10000.0,77352.67522978783,84272.98402452469,77352.67522978783,6903.587861776352,7.856116533279419,0.0 -170300,5.3799253,2.413763,,,,,,,,,,,,,, -170400,5.1687713,2.2497232,,,,,,,,,,,,,, -170500,5.1156764,2.4130416,,,,,,,,,,,,,, -170600,5.444323,2.4511547,,,,,,,,,,,,,, -170666,,,,,,,,,,,77520.27442193031,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 8652db09d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -26.727606058120728,0.0,35.84359121322632,1,0,35.84359121322632,0.0010000000474974,6.907756805419922,10000,62.57128691673279,0.0009765625,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -70.29547381401062,0.020416259765625,455.78622579574585,844,0,455.78622579574585,0.0114000001922249,6.438459873199463,10000,526.145667552948,0.0159765612334013,6.382034778594971,0.015999998897314,6.395730972290039,50000 -110.43268370628356,0.0518798828125,876.1266677379608,1764,0,876.1266677379608,0.033500000834465,5.981106758117676,10000,986.7028484344482,0.0448632799088954,5.828395366668701,0.0431599989533424,5.85699987411499,50000 -154.37527751922607,0.0791792869567871,1296.264136552811,2684,0,1296.264136552811,0.0502000041306018,5.668014049530029,10000,1450.8591334819794,0.0735937505960464,5.448266983032227,0.0674799978733062,5.489231586456299,50000 -198.7751681804657,0.1038756370544433,1716.3032870292664,3604,0,1716.3032870292664,0.0767000019550323,5.356302261352539,10000,1915.3701028823853,0.1104687452316284,5.068248271942139,0.1017199978232383,5.126224994659424,50000 -242.7034797668457,0.1274757385253906,2136.638533353805,4522,0,2136.638533353805,0.0994000062346458,5.049302577972412,10000,2379.705213546753,0.1553320288658142,4.668953895568848,0.1380600035190582,4.754825592041016,50000 -283.0039734840393,0.1524102687835693,2556.660418987274,5441,0,2556.660418987274,0.1346000134944915,4.739532470703125,10000,2840.099986076355,0.1989648342132568,4.309050559997559,0.1846799999475479,4.396907806396484,50000 -325.1707737445832,0.1802799701690673,2976.674590110779,6362,0,2976.674590110779,0.1720000058412552,4.464690208435059,10000,3302.356697320938,0.2446484267711639,3.974132537841797,0.2258399873971939,4.073367118835449,50000 -366.0614733695984,0.2094607353210449,3396.9373681545258,7286,0,3396.9373681545258,0.1998000144958496,4.2360687255859375,10000,3763.586677074432,0.2900195121765136,3.6638667583465576,0.2622399926185608,3.812857151031494,50000 -405.1697099208832,0.2350239753723144,3817.3410184383392,8209,0,3817.3410184383392,0.233800008893013,4.013123989105225,10000,4223.172071218491,0.3228124976158142,3.440293550491333,0.3006999790668487,3.549184560775757,50000 -445.9749083518982,0.2604801654815674,4237.481600284576,9131,0,4237.481600284576,0.2590000033378601,3.83144736289978,10000,4684.190475702286,0.3623827993869781,3.1926872730255127,0.3360999822616577,3.323276996612549,50000 -488.92419385910034,0.2864758968353271,4657.512858867645,10053,0,4657.512858867645,0.2822000086307525,3.700833797454834,10000,5147.245245218277,0.394843727350235,3.041377305984497,0.3632999956607818,3.2022409439086914,50000 -529.4197680950165,0.3179991245269775,5077.823751926422,10977,0,5077.823751926422,0.3065000176429748,3.502833366394043,10000,5608.130947113037,0.4287695288658142,2.809704542160034,0.3992599844932556,2.950516700744629,50000 -569.8549757003784,0.3435926437377929,5497.804342985153,11899,0,5497.804342985153,0.3294000029563904,3.33755874633789,10000,6068.620263576508,0.461249977350235,2.6188957691192627,0.4248999953269958,2.788549900054932,50000 -612.0777878761292,0.3792257308959961,5918.146535158157,12820,0,5918.146535158157,0.3380000293254852,3.348545789718628,10000,6531.268862962723,0.4747070074081421,2.617384910583496,0.4347599744796753,2.79663348197937,50000 -651.4733724594116,0.4065427780151367,6338.404702425003,13743,0,6338.404702425003,0.3629000186920166,3.206737995147705,10000,6990.997879266739,0.4976562261581421,2.4742987155914307,0.4613799750804901,2.645811080932617,50000 -692.7202491760254,0.4355850219726562,6758.347489833832,14665,0,6758.347489833832,0.3760000169277191,3.085344791412353,10000,7452.263117074966,0.5162500143051147,2.3231639862060547,0.4807399809360504,2.4985995292663574,50000 -731.0242302417755,0.4632751941680908,7178.313052892685,15584,0,7178.313052892685,0.3820000290870666,3.095673084259033,10000,7910.607768535614,0.5301562547683716,2.32466983795166,0.4898599982261657,2.5129311084747314,50000 -769.5679693222046,0.4897994995117187,7598.365840673447,16507,0,7598.365840673447,0.3920000195503235,2.980851411819458,10000,8369.278602838516,0.564160168170929,2.104198932647705,0.505620002746582,2.3738718032836914,50000 -807.6075274944305,0.5183203220367432,8018.321758508682,17428,0,8018.321758508682,0.4103000164031982,2.918965101242065,10000,8827.351588010788,0.5541601181030273,2.172273874282837,0.5168799757957458,2.334871292114258,50000 -850.389208316803,0.544304609298706,8438.259444952011,18351,0,8438.259444952011,0.4129000306129455,2.89074444770813,10000,9290.145105838776,0.5694726705551147,2.0965421199798584,0.5231199860572815,2.2878384590148926,50000 -887.2015011310577,0.5706558227539062,8858.576433181763,19264,0,8858.576433181763,0.4234000146389007,2.851503372192383,10000,9747.34814119339,0.5887304544448853,1.99932599067688,0.5343999862670898,2.2401058673858643,50000 -928.9679560661316,0.6004471778869629,9278.755592107773,20183,0,9278.755592107773,0.426000028848648,2.823843002319336,10000,10209.370984315872,0.5824609398841858,2.048027515411377,0.540619969367981,2.229803562164306,50000 -973.2437303066254,0.6334230899810791,9699.127198934557,21105,0,9699.127198934557,0.4392000138759613,2.759177923202514,10000,10674.098667144775,0.5959179401397705,1.9627938270568848,0.5534799695014954,2.157258987426758,50000 -1012.2991769313812,0.6604936122894287,10119.155968666077,22026,0,10119.155968666077,0.4387000203132629,2.6716549396514893,10000,11133.25730252266,0.6109570264816284,1.83113980293274,0.5616799592971802,2.0698909759521484,50000 -1050.435089111328,0.6903214454650879,10539.468144655228,22949,0,10539.468144655228,0.450300008058548,2.7123382091522217,10000,11591.783080339432,0.6087695360183716,1.8961181640625,0.5664600133895874,2.086631774902344,50000 -1092.7479412555697,0.7183361053466797,10959.501294612885,23869,0,10959.501294612885,0.4593000113964081,2.6749086380004883,10000,12054.205533742905,0.6216992139816284,1.848590850830078,0.5771799683570862,2.052497148513794,50000 -1132.2981944084167,0.7464418411254883,11379.689157009125,24790,0,11379.689157009125,0.4606000185012817,2.6689186096191406,10000,12514.022708654404,0.6266406178474426,1.8320873975753784,0.5758000016212463,2.063634157180786,50000 -1172.748398065567,0.7861979007720947,11799.745354413986,25709,0,11799.745354413986,0.4701000154018402,2.585104465484619,10000,12974.617196798325,0.6539843678474426,1.6860125064849854,0.5881400108337402,1.9701340198516848,50000 -1214.0835967063904,0.8152399063110352,12219.731608867643,26630,0,12219.731608867643,0.4780000150203705,2.573791742324829,10000,13436.014949560164,0.6346874833106995,1.760985255241394,0.5895199775695801,1.968305706977844,50000 -1256.7773969173431,0.8430163860321045,12639.99943780899,27553,0,12639.99943780899,0.4765000343322754,2.5679101943969727,10000,13899.052494764328,0.6429687142372131,1.7258527278900146,0.5952999591827393,1.952037453651428,50000 -1296.0321819782257,0.8778200149536133,13060.122702360151,28476,0,13060.122702360151,0.4812000095844269,2.56640625,10000,14358.513455867767,0.6584765315055847,1.6636217832565308,0.595579981803894,1.949023962020874,50000 -1337.395380973816,0.9113750457763672,13480.046933412552,29396,0,13480.046933412552,0.4892000257968902,2.5049943923950195,10000,14819.881961107254,0.6520312428474426,1.674699306488037,0.6082800030708313,1.8775641918182373,50000 -1383.1154959201813,0.945807695388794,13900.2150182724,30317,0,13900.2150182724,0.489300012588501,2.4801666736602783,10000,15285.852749586104,0.6537694931030273,1.6482510566711426,0.6024599671363831,1.874757409095764,50000 -1422.7271156311035,0.9807932376861572,14320.48095369339,31241,0,14320.48095369339,0.4948000311851501,2.441969156265259,10000,15745.813037872314,0.6695898175239563,1.5669158697128296,0.6098799705505371,1.8242343664169312,50000 -1466.1945168972015,1.0090515613555908,14740.702543497086,32162,0,14740.702543497086,0.5010000467300415,2.439016103744507,10000,16209.5774371624,0.6646093726158142,1.623708724975586,0.616159975528717,1.841352462768555,50000 -1504.403111219406,1.0377919673919678,15160.87868309021,33084,0,15160.87868309021,0.494700014591217,2.4732367992401123,10000,16668.039219379425,0.6647070050239563,1.6350743770599363,0.6173399686813354,1.8481569290161133,50000 -1543.6040349006653,1.0717051029205322,15581.12509894371,34004,0,15581.12509894371,0.506600022315979,2.3888542652130127,10000,17127.568870782852,0.6825194954872131,1.5208829641342163,0.6213200092315674,1.7882128953933716,50000 -1583.8245911598206,1.103961944580078,16001.194901704788,34924,0,16001.194901704788,0.5037000179290771,2.400809526443481,10000,17587.93961954117,0.6859960556030273,1.5173721313476562,0.6280999779701233,1.7702747583389282,50000 -1627.0497500896454,1.141261339187622,16421.54888010025,35848,0,16421.54888010025,0.5037000179290771,2.3965654373168945,10000,18051.60387778282,0.6784374713897705,1.537048101425171,0.6295599937438965,1.763940691947937,50000 -1666.6609530448914,1.1702823638916016,16841.60574913025,36770,0,16841.60574913025,0.5098000168800354,2.362467765808105,10000,18511.34856534004,0.6861132383346558,1.495723843574524,0.6327999830245972,1.7372503280639648,50000 -1710.5507550239563,1.207282543182373,17261.76012301445,37690,0,17261.76012301445,0.5088000297546387,2.347130537033081,10000,18975.477598905563,0.7042187452316284,1.429381012916565,0.6326799988746643,1.742835283279419,50000 -1747.6528568267822,1.241055250167847,17681.86350107193,38611,0,17681.86350107193,0.5082000494003296,2.393516540527344,10000,19432.76448178292,0.6838085651397705,1.5368809700012207,0.6318199634552002,1.7639278173446655,50000 -1788.8630316257477,1.2756617069244385,18101.93009710312,39530,0,18101.93009710312,0.5134000182151794,2.360651969909668,10000,19894.123265028,0.6908202767372131,1.493054986000061,0.6378999948501587,1.7385770082473757,50000 -1827.666071653366,1.3061726093292236,18521.99777531624,40446,0,18521.99777531624,0.5159000158309937,2.3539485931396484,10000,20353.072145223618,0.7030664086341858,1.4640662670135498,0.6423999667167664,1.7443610429763794,50000 -1869.7756474018097,1.337547779083252,18942.32962703705,41368,0,18942.32962703705,0.5213000178337097,2.339286088943481,10000,20815.592352867126,0.6928125023841858,1.5066217184066772,0.6416199803352356,1.7189924716949463,50000 -1915.4070200920105,1.3723368644714355,19362.28750777245,42290,0,19362.28750777245,0.5215000510215759,2.330148935317993,10000,21281.26524925232,0.6962890625,1.4691346883773804,0.6421599984169006,1.710872769355774,50000 -1955.033771038056,1.4100327491760254,19782.306488990784,43213,0,19782.306488990784,0.5273000001907349,2.284856081008911,10000,21741.00474333763,0.707324206829071,1.4024572372436523,0.647159993648529,1.6708543300628662,50000 -1997.4169154167173,1.4440712928771973,20202.631172180176,44135,0,20202.631172180176,0.5218999981880188,2.304472684860229,10000,22203.795504808422,0.6955273151397705,1.4521766901016235,0.6490199565887451,1.6639647483825684,50000 -2038.766408443451,1.4792227745056152,20622.772392749783,45055,0,20622.772392749783,0.5235000252723694,2.2860796451568604,10000,22665.36895251274,0.698925793170929,1.434710144996643,0.647379994392395,1.667491793632507,50000 -2081.991260766983,1.511063575744629,21043.10510325432,45974,0,21043.10510325432,0.5283000469207764,2.2812087535858154,10000,23129.006335258484,0.7076953053474426,1.4150549173355105,0.6478599905967712,1.678601622581482,50000 -2123.710966825485,1.544229507446289,21463.18249320984,46895,0,21463.18249320984,0.5331000089645386,2.2646377086639404,10000,23590.88441562653,0.7375390529632568,1.29774272441864,0.6542400121688843,1.6448653936386108,50000 -2166.239804506302,1.5809102058410645,21883.53296923637,47818,0,21883.53296923637,0.5327000021934509,2.245063304901123,10000,24053.848252534863,0.7093945145606995,1.4072017669677734,0.6576399803161621,1.6411067247390747,50000 -2208.6510157585144,1.6165921688079834,22303.59793663025,48741,0,22303.59793663025,0.5318000316619873,2.262031078338623,10000,24516.407782793045,0.7172656059265137,1.3772408962249756,0.6559399962425232,1.6439871788024902,50000 -2252.7241654396057,1.656675100326538,22723.800693511963,49664,0,22723.800693511963,0.5281000137329102,2.343284845352173,10000,24980.77202296257,0.71875,1.425883650779724,0.6511600017547607,1.7147786617279053,50000 -2295.708538770676,1.69374680519104,23143.94828104973,50586,0,23143.94828104973,0.5351999998092651,2.278514862060547,10000,25443.989156007767,0.7098437547683716,1.4361015558242798,0.6548399925231934,1.6722317934036257,50000 -2337.176256418228,1.7330005168914795,23564.28750729561,51508,0,23564.28750729561,0.5435000061988831,2.247091293334961,10000,25905.883530139923,0.71498042345047,1.3889296054840088,0.6620399951934814,1.6310160160064695,50000 -2377.6872231960297,1.768537521362305,23984.374716758728,52432,0,23984.374716758728,0.5354000329971313,2.2272799015045166,10000,26366.565184354786,0.7337304353713989,1.2980376482009888,0.6647999882698059,1.5864545106887815,50000 -2417.860870361328,1.8072423934936523,24404.3652215004,53354,0,24404.3652215004,0.532200038433075,2.2386462688446045,10000,26826.81566309929,0.7135351300239563,1.3821560144424438,0.6652799844741821,1.598161697387695,50000 -2461.081460237503,1.847649097442627,24824.307297229767,54275,0,24824.307297229767,0.5452000498771667,2.222517967224121,10000,27290.06626176834,0.7189062237739563,1.361011028289795,0.6606799960136414,1.6143720149993896,50000 -2500.354782104492,1.8835597038269043,25244.55240249633,55196,0,25244.55240249633,0.5435000061988831,2.220703601837158,10000,27749.66876244545,0.7296484112739563,1.3426628112792969,0.664139986038208,1.6135691404342651,50000 -2546.253002643585,1.9170572757720947,25664.7479467392,56120,0,25664.7479467392,0.5502000451087952,2.189172744750977,10000,28215.84439206124,0.7380273342132568,1.2842124700546265,0.6672599911689758,1.5820810794830322,50000 -2589.74267745018,1.952627658843994,26085.00351262093,57042,0,26085.00351262093,0.5464000105857849,2.211744546890259,10000,28679.673223495483,0.7276562452316284,1.3468148708343506,0.6674000024795532,1.603977918624878,50000 -2631.289387464524,1.9895341396331787,26505.091208696365,57963,0,26505.091208696365,0.5430999994277954,2.196040153503418,10000,29141.39177846909,0.7269140481948853,1.328726887702942,0.667419970035553,1.5865769386291504,50000 -2675.127780675888,2.025451183319092,26925.22168302536,58883,0,26925.22168302536,0.5457000136375427,2.201898097991944,10000,29605.44416475296,0.7373241782188416,1.2906686067581177,0.6669600009918213,1.6132616996765137,50000 -2718.944101333618,2.061138868331909,27345.44739818573,59803,0,27345.44739818573,0.5478000044822693,2.153144598007202,10000,30069.56980085373,0.7282031178474426,1.2871874570846558,0.6700199842453003,1.5382176637649536,50000 -2762.44873046875,2.097840070724488,27765.79726195336,60724,0,27765.79726195336,0.5454000234603882,2.2170250415802,10000,30533.50876736641,0.733593761920929,1.3361026048660278,0.6747199892997742,1.584994196891785,50000 -2805.484763622284,2.13291072845459,28185.98828697205,61645,0,28185.98828697205,0.5508000254631042,2.175717830657959,10000,30996.8189907074,0.7450000047683716,1.2544151544570925,0.6762799620628357,1.5565282106399536,50000 -2849.9011821746826,2.1665310859680176,28606.290912389755,62563,0,28606.290912389755,0.5521000027656555,2.1780121326446533,10000,31461.618797063828,0.7283593416213989,1.3276382684707642,0.6751199960708618,1.5634821653366089,50000 -2894.8881330490112,2.20430326461792,29026.3101747036,63483,0,29026.3101747036,0.557200014591217,2.149172306060791,10000,31926.710065841675,0.7360742092132568,1.2766711711883545,0.6779199838638306,1.5261484384536743,50000 -2937.920587062836,2.25028133392334,29446.25080990792,64404,0,29446.25080990792,0.5481000542640686,2.1536831855773926,10000,32389.77768635749,0.7432031035423279,1.2548317909240725,0.6798799633979797,1.5307791233062744,50000 -2976.272282600403,2.2853519916534424,29866.379588842392,65324,0,29866.379588842392,0.5658000111579895,2.10740065574646,10000,32848.34146499634,0.7388671636581421,1.2673910856246948,0.6782000064849854,1.518659234046936,50000 -3020.252357006073,2.32259488105774,30286.45598578453,66244,0,30286.45598578453,0.5547000169754028,2.1452724933624268,10000,33312.4827747345,0.7400780916213989,1.258586883544922,0.6813799738883972,1.5165966749191284,50000 -3061.212740421295,2.3604607582092285,30706.626230478287,67165,0,30706.626230478287,0.5561000108718872,2.169353723526001,10000,33773.69924163818,0.7415820360183716,1.2981585264205933,0.6811999678611755,1.5677835941314695,50000 -3102.8121032714844,2.3978052139282227,31126.832008838654,68086,0,31126.832008838654,0.5576000213623047,2.1319308280944824,10000,34235.588967084885,0.7633007764816284,1.1590244770050049,0.6800999641418457,1.5066139698028564,50000 -3146.148453235626,2.4373722076416016,31546.80517745018,69004,0,31546.80517745018,0.556600034236908,2.1219170093536377,10000,34698.98634314537,0.7406054735183716,1.2410019636154177,0.684719979763031,1.4892326593399048,50000 -3189.873576402664,2.472073554992676,31966.791477441788,69925,0,31966.791477441788,0.5660000443458557,2.104238510131836,10000,35162.78013944626,0.7484570145606995,1.2337993383407593,0.6854599714279175,1.5051804780960083,50000 -3236.175390481949,2.508272171020508,32386.71866321564,70847,0,32386.71866321564,0.5552999973297119,2.1310625076293945,10000,35629.093203783035,0.7518359422683716,1.2066234350204468,0.6805199980735779,1.5176000595092771,50000 -3274.3536903858185,2.5455551147460938,32807.06958389282,71770,0,32807.06958389282,0.5648000240325928,2.08892560005188,10000,36087.70830178261,0.747363269329071,1.2343018054962158,0.6899799704551697,1.4829862117767334,50000 -3319.4512207508087,2.5845770835876465,33227.219742536545,72687,0,33227.219742536545,0.5640000104904175,2.0902061462402344,10000,36553.0427236557,0.7515820264816284,1.2129734754562378,0.6888999938964844,1.4854317903518677,50000 -3362.12451672554,2.621159076690674,33647.15845131874,73609,0,33647.15845131874,0.5670000314712524,2.104422807693481,10000,37015.7394669056,0.7533984184265137,1.191983938217163,0.6872400045394897,1.485839605331421,50000 -3404.320331096649,2.6617767810821533,34067.44424152374,74530,0,34067.44424152374,0.5658000111579895,2.069862604141236,10000,37478.30951237679,0.7507421970367432,1.195424199104309,0.6930800080299377,1.4571106433868408,50000 -3446.2519228458405,2.700667142868042,34487.681963682175,75453,0,34487.681963682175,0.5706000328063965,2.06187105178833,10000,37940.56576251984,0.7525194883346558,1.1894110441207886,0.6906599998474121,1.454953908920288,50000 -3487.078436851501,2.736494779586792,34908.01672363281,76374,0,34908.01672363281,0.5688000321388245,2.0531435012817383,10000,38401.809537410736,0.7610741853713989,1.160759210586548,0.6957799792289734,1.4432342052459717,50000 -3529.1288471221924,2.78158974647522,35327.95745301247,77298,0,35327.95745301247,0.5685999989509583,2.070373773574829,10000,38863.893261671066,0.7732812166213989,1.1148550510406494,0.6937199831008911,1.447901725769043,50000 -3570.3713760375977,2.831173181533813,35748.12657356262,78220,0,35748.12657356262,0.5758000016212463,2.019524097442627,10000,39325.401894807816,0.7600781321525574,1.1587432622909546,0.6990599632263184,1.4184653759002686,50000 -3612.3097426891327,2.8709871768951416,36168.48544359207,79142,0,36168.48544359207,0.5730000138282776,2.055783987045288,10000,39787.786851882935,0.7625195384025574,1.148511528968811,0.6945199966430664,1.4347800016403198,50000 -3655.476496696472,2.9082815647125244,36588.77942371368,80065,0,36588.77942371368,0.5708000063896179,2.067552328109741,10000,40251.33292555809,0.7742968797683716,1.1092896461486816,0.697380006313324,1.4453190565109253,50000 -3697.0613508224487,2.9498679637908936,37008.99789881706,80988,0,37008.99789881706,0.5790000557899475,2.017435073852539,10000,40713.225462675095,0.7635741829872131,1.1445281505584717,0.7017799615859985,1.4090139865875244,50000 -3740.842449903488,2.9926469326019287,37429.169956445694,81909,0,37429.169956445694,0.5771000385284424,2.034175157546997,10000,41177.268661022186,0.7632030844688416,1.150298237800598,0.6981399655342102,1.4378381967544556,50000 -3787.0803208351135,3.0316107273101807,37849.39163827896,82831,0,37849.39163827896,0.5746000409126282,2.0344953536987305,10000,41643.815761089325,0.7751367092132568,1.1169837713241575,0.7000600099563599,1.4290475845336914,50000 -3831.722299337387,3.0690486431121826,38269.435145139694,83754,0,38269.435145139694,0.5769000053405762,2.039259433746338,10000,42108.586940288544,0.7632226347923279,1.1648662090301514,0.6999799609184265,1.4325613975524902,50000 -3870.7454376220703,3.106007099151612,38689.39722657204,84676,0,38689.39722657204,0.5788000226020813,2.06295108795166,10000,42567.65654802322,0.7675585746765137,1.1792206764221191,0.7015199661254883,1.4529019594192505,50000 -3915.285789489746,3.146677255630493,39109.64110136032,85598,0,39109.64110136032,0.5811000466346741,2.0271904468536377,10000,43032.52892065048,0.77259761095047,1.1098675727844238,0.7027599811553955,1.4117202758789062,50000 -3962.117981433869,3.1898202896118164,39529.84063959122,86519,0,39529.84063959122,0.5836000442504883,2.0163145065307617,10000,43499.65197634697,0.7715820074081421,1.1164155006408691,0.7065399885177612,1.3979099988937378,50000 -4006.3306062221527,3.2322468757629395,39949.914578437805,87441,0,39949.914578437805,0.5821000337600708,2.0062777996063232,10000,43964.02792263031,0.769726574420929,1.133750319480896,0.7074599862098694,1.397431492805481,50000 -4046.780064105988,3.2729523181915283,40369.85162806511,88363,0,40369.85162806511,0.5845000147819519,1.9934600591659544,10000,44424.50291514397,0.7766796946525574,1.0799094438552856,0.7064200043678284,1.3791130781173706,50000 -4090.746128797531,3.310702085494995,40790.13759255409,89279,0,40790.13759255409,0.5841000080108643,2.0293197631835938,10000,44888.84012579918,0.7904296517372131,1.0638716220855713,0.7069599628448486,1.4218652248382568,50000 -4132.355852603912,3.3663439750671387,41210.48593831062,90199,0,41210.48593831062,0.5837000012397766,1.990799069404602,10000,45350.90107703209,0.7699609398841858,1.117898941040039,0.7085599899291992,1.3838273286819458,50000 -4169.649255514145,3.409395456314087,41630.54138350487,91120,0,41630.54138350487,0.5908000469207764,1.9569255113601685,10000,45808.34008026123,0.7765820026397705,1.0714341402053833,0.7109000086784363,1.358513593673706,50000 -4214.131242513657,3.4492197036743164,42050.45993804932,92041,0,42050.45993804932,0.5905000567436218,1.978685021400452,10000,46272.82831025124,0.7859570384025574,1.049735188484192,0.7099999785423279,1.3754807710647583,50000 -4257.513018131256,3.4884932041168213,42470.80727481842,92962,0,42470.80727481842,0.5901000499725342,1.974666714668274,10000,46736.64430379868,0.7772070169448853,1.0873217582702637,0.7098199725151062,1.380346417427063,50000 -4300.7799434661865,3.5301735401153564,42890.88827753067,93883,0,42890.88827753067,0.5901000499725342,1.9860336780548096,10000,47200.08113312721,0.781445324420929,1.0885802507400513,0.7143999934196472,1.3779401779174805,50000 -4343.885124206543,3.571937799453736,43311.09178900719,94806,0,43311.09178900719,0.5893000364303589,1.966513991355896,10000,47663.47908735275,0.7893944978713989,1.0502266883850098,0.7142399549484253,1.3713054656982422,50000 -4390.079038143158,3.618546485900879,43731.27753829956,95724,0,43731.27753829956,0.5923000574111938,1.9620274305343628,10000,48129.95305871964,0.7759374976158142,1.094131588935852,0.7149199843406677,1.3668842315673828,50000 -4434.304823875427,3.6608870029449454,44151.39954543114,96646,0,44151.39954543114,0.5936000347137451,1.9557676315307613,10000,48594.39071774483,0.7854687571525574,1.0612901449203491,0.7143399715423584,1.3597755432128906,50000 -4477.041138410568,3.7091317176818848,44571.57318592072,97568,0,44571.57318592072,0.594700038433075,1.9495747089385984,10000,49057.39753556252,0.7908984422683716,1.0401599407196045,0.7178399562835693,1.3506696224212646,50000 -4521.90961265564,3.7491395473480233,44991.559143066406,98488,0,44991.559143066406,0.5902000069618225,1.973228454589844,10000,49522.340443611145,0.8031054735183716,0.9849762320518494,0.7144799828529358,1.3556915521621704,50000 -4565.364971160889,3.7899091243743896,45411.88533329964,99411,0,45411.88533329964,0.5999000072479248,1.932340145111084,10000,49986.21019554138,0.7883203029632568,1.0426161289215088,0.7203399538993835,1.3304067850112915,50000 -4606.589636087418,3.83087944984436,45831.8272600174,100332,0,45831.8272600174,0.5913000106811523,1.9910321235656736,10000,50447.46520733833,0.7882617115974426,1.0813069343566897,0.7168599963188171,1.389996886253357,50000 -4646.38002371788,3.875702142715454,46251.89824414253,101255,0,46251.89824414253,0.5976000428199768,1.9517590999603271,10000,50907.4188015461,0.8056640625,1.0057238340377808,0.7194799780845642,1.3632384538650513,50000 -4685.471276283264,3.91915512084961,46672.21717500687,102178,0,46672.21717500687,0.5955000519752502,1.9908602237701416,10000,51366.9196677208,0.785839855670929,1.0791743993759155,0.7178199887275696,1.3750110864639282,50000 -4729.725798368454,3.966155767440796,47092.51544976234,103099,0,47092.51544976234,0.6014000177383423,1.91221821308136,10000,51831.56687259674,0.8001952767372131,1.0049448013305664,0.7236599922180176,1.3200238943099976,50000 -4769.427984714508,4.009620189666748,47512.693658828735,104021,0,47512.693658828735,0.6020000576972961,1.969410419464112,10000,52291.53897356987,0.802050769329071,1.0305854082107544,0.723580002784729,1.3682783842086792,50000 -4815.495166063309,4.050985097885132,47932.840673685074,104942,0,47932.840673685074,0.6025000214576721,1.9171737432479856,10000,52757.84145927429,0.7951952815055847,1.0146361589431765,0.7224000096321106,1.322086215019226,50000 -4855.456739187241,4.464348077774048,48352.49347019196,105860,0,48352.49347019196,0.6021000146865845,1.924709677696228,10000,53217.91585254669,0.79212886095047,1.0310273170471191,0.7222200036048889,1.32786226272583,50000 -4898.399055242538,4.509274482727051,48772.86058592796,106780,0,48772.86058592796,0.6099000573158264,1.9118989706039429,10000,53681.318464279175,0.8042968511581421,0.9922462105751038,0.7252999544143677,1.32797110080719,50000 -4942.739294528961,4.553928375244141,49192.9281001091,107699,0,49192.9281001091,0.6027000546455383,1.9079450368881223,10000,54145.81826877594,0.802734375,0.9875910878181458,0.7257199883460999,1.3120582103729248,50000 -4984.543601036072,4.5956196784973145,49613.25364589691,108619,0,49613.25364589691,0.6051000356674194,1.904710292816162,10000,54608.0376534462,0.8020898103713989,0.997575342655182,0.7286199927330017,1.3085391521453855,50000 -5025.457714557648,4.6387939453125,50033.40543913841,109540,0,50033.40543913841,0.6074000000953674,1.8863749504089355,10000,55069.19483280182,0.8037304282188416,0.9704073071479796,0.7299799919128418,1.2849246263504028,50000 -5071.100483894348,4.6809704303741455,50453.32793235779,110460,0,50453.32793235779,0.6077000498771667,1.892633080482483,10000,55534.84929966927,0.8193163871765137,0.9186491370201112,0.7305399775505066,1.2858415842056274,50000 -5116.349600315094,4.732409715652466,50873.48181271553,111382,0,50873.48181271553,0.6093000173568726,1.9110552072525024,10000,56000.36197376251,0.8045117259025574,0.9946995973587036,0.7307999730110168,1.30134117603302,50000 -5161.157242059708,4.781898736953735,51293.81174302101,112305,0,51293.81174302101,0.609000027179718,1.8916889429092407,10000,56465.59639286995,0.8092968463897705,0.9513814449310304,0.7310999631881714,1.288879632949829,50000 -5204.9848573207855,4.822944164276123,51713.83294534683,113227,0,51713.83294534683,0.6177000403404236,1.8420090675354004,10000,56929.53366589546,0.81849604845047,0.898280680179596,0.7363399863243103,1.2454630136489868,50000 -5244.5484964847565,4.866366386413574,52133.946326971054,114148,0,52133.946326971054,0.6145000457763672,1.8680979013442995,10000,57389.301633358,0.8068945407867432,0.9638850688934326,0.7357199788093567,1.2680931091308594,50000 -5285.962988376617,4.9072349071502686,52554.23753380776,115071,0,52554.23753380776,0.614300012588501,1.8860503435134888,10000,57851.095802783966,0.81201171875,0.9675384163856506,0.7343999743461609,1.2889907360076904,50000 -5329.898587703705,4.956961631774902,52974.53467059136,115991,0,52974.53467059136,0.6172000169754028,1.8698878288269043,10000,58315.42613291741,0.8161523342132568,0.93476402759552,0.7350599765777588,1.2736767530441284,50000 -5371.47057056427,5.00386118888855,53394.7826769352,116909,0,53394.7826769352,0.6185000538825989,1.8774621486663816,10000,58777.340618133545,0.8122069835662842,0.9754149317741394,0.7371199727058411,1.2887877225875854,50000 -5416.606198310852,5.0547590255737305,53814.80868077278,117830,0,53814.80868077278,0.617900013923645,1.8480632305145264,10000,59242.60025715828,0.81898432970047,0.918323814868927,0.7381199598312378,1.2621748447418213,50000 -5461.611189365387,5.096725225448608,54235.02416753769,118753,0,54235.02416753769,0.6194000244140625,1.8396508693695068,10000,59707.91237425804,0.8180468678474426,0.9298332929611206,0.7390599846839905,1.263730764389038,50000 -5502.972692966461,5.1483683586120605,54655.402338027954,119674,0,54655.402338027954,0.6140000224113464,1.8454985618591309,10000,60169.75206565857,0.8340820074081421,0.8498026728630066,0.7407199740409851,1.2424492835998535,50000 -5542.008673429489,5.199536323547363,55075.56313490868,120593,0,55075.56313490868,0.6239000558853149,1.833344459533692,10000,60629.047945261,0.8211132884025574,0.9188494682312012,0.7406799793243408,1.249154806137085,50000 -5584.007612705231,5.241910457611084,55495.675506830215,121514,0,55495.675506830215,0.6218000054359436,1.8167213201522827,10000,61091.24942660332,0.8241796493530273,0.8857353925704956,0.7432799935340881,1.2298856973648071,50000 -5628.865107297897,5.285839796066284,55915.97837305069,122435,0,55915.97837305069,0.6230000257492065,1.8468738794326784,10000,61556.50082373619,0.8314648270606995,0.890949010848999,0.7434200048446655,1.2616345882415771,50000 -5671.9865918159485,5.331271648406982,56336.17573904991,123357,0,56336.17573904991,0.6271000504493713,1.825277090072632,10000,62019.912034511566,0.8240624666213989,0.9000077843666077,0.7458999752998352,1.2291886806488037,50000 -5712.451839923859,5.377405405044556,56756.23712944984,124277,0,56756.23712944984,0.6223000288009644,1.840736985206604,10000,62480.5327205658,0.8270507454872131,0.9012371897697448,0.7419599890708923,1.2539701461791992,50000 -5755.837740182877,5.423417568206787,57176.52000498772,125199,0,57176.52000498772,0.629300057888031,1.7952983379364014,10000,62944.29509925842,0.8361327648162842,0.8307301998138428,0.746999979019165,1.1995173692703247,50000 -5800.669831514359,5.46803092956543,57596.76955342293,126119,0,57596.76955342293,0.6274000406265259,1.814552426338196,10000,63409.46954703331,0.8282421827316284,0.8835758566856384,0.7492600083351135,1.218082070350647,50000 -5839.35227394104,5.522529602050781,58016.67089056969,127039,0,58016.67089056969,0.6255000233650208,1.8085092306137085,10000,63868.15580749512,0.8307812213897705,0.8714902400970459,0.7470600008964539,1.2164604663848877,50000 -5883.110562562943,5.568439722061157,58436.60658454895,127959,0,58436.60658454895,0.6254000067710876,1.7854843139648438,10000,64331.94377684593,0.8356054425239563,0.8296651244163513,0.7488399744033813,1.202234983444214,50000 -5924.156987428665,5.612093925476074,58856.884100198746,128881,0,58856.884100198746,0.6290000081062317,1.7977409362792969,10000,64793.35998272896,0.8422460556030273,0.8313546776771545,0.752020001411438,1.212985873222351,50000 -5968.510313987732,5.657716512680054,59276.93350100517,129804,0,59276.93350100517,0.629800021648407,1.786934494972229,10000,65257.85636162758,0.8338280916213989,0.8379896283149719,0.7499600052833557,1.1926852464675903,50000 -6010.993898153305,5.702925443649292,59696.888902425766,130725,0,59696.888902425766,0.6312000155448914,1.7743438482284546,10000,65720.388225317,0.8382812142372131,0.8320725560188293,0.7503199577331543,1.1947848796844482,50000 -6056.08829331398,5.74944281578064,60117.00132703781,131647,0,60117.00132703781,0.6288000345230103,1.7920557260513306,10000,66185.6890695095,0.8476171493530273,0.8037877678871155,0.7534599900245667,1.1936339139938354,50000 -6099.994526147842,5.798482418060303,60537.26391768456,132568,0,60537.26391768456,0.6355000138282776,1.788927674293518,10000,66649.95460033417,0.8365820050239563,0.8467892408370972,0.7517600059509277,1.205349326133728,50000 -6138.977798938751,5.845485210418701,60957.49088454247,133490,0,60957.49088454247,0.6336000561714172,1.7881088256835938,10000,67109.25989794731,0.8384960889816284,0.835568904876709,0.7537399530410767,1.1961098909378052,50000 -6185.34400844574,5.892377853393555,61377.691182136536,134411,0,61377.691182136536,0.6336000561714172,1.7869489192962646,10000,67575.92088413239,0.8448827862739563,0.7935135364532471,0.7524799704551697,1.1886016130447388,50000 -6228.144732713699,5.9372031688690186,61797.867267131805,135329,0,61797.867267131805,0.6373000144958496,1.7893301248550415,10000,68038.98924589157,0.8392773270606995,0.8438341617584229,0.7569599747657776,1.196811318397522,50000 -6270.500958204269,5.982234716415405,62218.01303982735,136252,0,62218.01303982735,0.6366000175476074,1.7862942218780518,10000,68501.58464884758,0.84193354845047,0.8302668333053589,0.758080005645752,1.1913862228393557,50000 -6314.320621013641,6.038216590881348,62637.98712944984,137171,0,62637.98712944984,0.6381000280380249,1.7534481287002563,10000,68965.4830365181,0.8501171469688416,0.7863731980323792,0.7578799724578857,1.1653896570205688,50000 -6357.152781248093,6.084154844284058,63058.27062392235,138092,0,63058.27062392235,0.6454000473022461,1.7836123704910278,10000,69428.69195985794,0.8501366972923279,0.8256229758262634,0.759880006313324,1.1994614601135254,50000 -6403.32272028923,6.133307695388794,63478.28254342079,139011,0,63478.28254342079,0.6362000107765198,1.7635866403579712,10000,69894.97005820274,0.8479101657867432,0.7939903736114502,0.7591599822044373,1.1697139739990234,50000 -6448.088568687439,6.181832790374756,63898.31989693642,139931,0,63898.31989693642,0.6403000354766846,1.7529109716415403,10000,70359.86943912506,0.8531054258346558,0.7984217405319214,0.7590999603271484,1.1816964149475098,50000 -6494.210072994232,6.228420257568359,64318.303194999695,140851,0,64318.303194999695,0.643500030040741,1.751853346824646,10000,70826.06875395775,0.8634960651397705,0.751483142375946,0.7616999745368958,1.166806936264038,50000 -6536.446292638779,6.2816667556762695,64738.19510626793,141770,0,64738.19510626793,0.6408000588417053,1.749077081680298,10000,71288.30664849281,0.849609375,0.7915339469909668,0.7606199979782104,1.1642590761184692,50000 -6584.042410612106,6.332995891571045,65158.11840724945,142689,0,65158.11840724945,0.6442000269889832,1.7491698265075684,10000,71755.92424988747,0.8543359041213989,0.7828488945960999,0.7647799849510193,1.1637057065963743,50000 -6629.484586715698,6.386404514312744,65578.4606962204,143611,0,65578.4606962204,0.643500030040741,1.7458293437957764,10000,72221.8099834919,0.86376953125,0.744042694568634,0.7639600038528442,1.1600041389465332,50000 -6674.7976150512695,6.435186386108398,65998.61735081673,144532,0,65998.61735081673,0.6429000496864319,1.74350905418396,10000,72687.37620162964,0.8549999594688416,0.7654373645782471,0.7643399834632874,1.1485599279403689,50000 -6719.138984918594,6.487768888473511,66418.83790254593,145453,0,66418.83790254593,0.6452000141143799,1.722506761550903,10000,73152.03741383553,0.8571484088897705,0.7539299130439758,0.7656399607658386,1.1399657726287842,50000 -6764.900428771973,6.540493726730347,66838.8134508133,146374,0,66838.8134508133,0.6457000374794006,1.715340256690979,10000,73617.8742146492,0.8620703220367432,0.7350453734397888,0.7666599750518799,1.1349040269851685,50000 -6810.045199871063,6.589900016784668,67258.91440343857,147297,0,67258.91440343857,0.6461000442504883,1.7426823377609253,10000,74083.21685099602,0.8604687452316284,0.7560346722602844,0.766759991645813,1.1506201028823853,50000 -6853.139424085617,6.635733366012573,67679.01679587364,148218,0,67679.01679587364,0.6454000473022461,1.7250229120254517,10000,74546.50698709488,0.8648632764816284,0.7450129389762878,0.766979992389679,1.1421996355056765,50000 -6898.35661482811,7.087894439697266,68098.65920686722,149138,0,68098.65920686722,0.6490000486373901,1.7095143795013428,10000,75011.86696529388,0.8698437213897705,0.7154492139816284,0.768839955329895,1.1309633255004885,50000 -6943.96710062027,7.140713453292847,68518.84233808517,150058,0,68518.84233808517,0.650700032711029,1.709939956665039,10000,75477.76089262962,0.8729296922683716,0.7006752490997314,0.7681399583816528,1.1254583597183228,50000 -6986.741266012192,7.1879754066467285,68939.2911374569,150978,0,68939.2911374569,0.6529000401496887,1.707403540611267,10000,75941.07792234421,0.8682031035423279,0.7284000515937805,0.770039975643158,1.1370346546173096,50000 -7028.515541791916,7.23703145980835,69359.93056178093,151901,0,69359.93056178093,0.6514000296592712,1.7238764762878418,10000,76403.58780241013,0.8696093559265137,0.7336189150810242,0.7696399688720703,1.1401625871658323,50000 -7070.822064638138,7.29412579536438,69779.99265003204,152821,0,69779.99265003204,0.6492000222206116,1.7144511938095093,10000,76866.06092453003,0.8742968440055847,0.6973341703414917,0.7716000080108643,1.1223145723342896,50000 -7112.63879776001,7.344912052154541,70200.02810454369,153740,0,70200.02810454369,0.6533000469207764,1.7147181034088137,10000,77328.01024198532,0.8701757788658142,0.7270835041999817,0.7702800035476685,1.1377615928649902,50000 -7154.707575559616,7.395971059799194,70620.03921198845,154661,0,70620.03921198845,0.659000039100647,1.699967861175537,10000,77790.1883494854,0.873828113079071,0.7049767971038818,0.7745800018310547,1.1225615739822388,50000 -7201.448133707047,7.44741940498352,71040.05707144737,155577,0,71040.05707144737,0.6546000242233276,1.675015568733215,10000,78257.04546570778,0.8754296898841858,0.6761949062347412,0.7723399996757507,1.1067827939987185,50000 -7249.346954345703,7.500328063964844,71460.33196163177,156497,0,71460.33196163177,0.6582000255584717,1.6929787397384644,10000,78725.32055997849,0.8728905916213989,0.7006733417510986,0.775439977645874,1.110991358757019,50000 -7291.650812864304,7.549129009246826,71880.53656959534,157419,0,71880.53656959534,0.65420001745224,1.6954491138458252,10000,79187.9251627922,0.8755663633346558,0.6870554089546204,0.7741000056266785,1.113128900527954,50000 -7333.049576044083,7.60271143913269,72300.74264979362,158341,0,72300.74264979362,0.6583000421524048,1.681814670562744,10000,79649.6304256916,0.8795117139816284,0.6774889230728149,0.7754999995231628,1.105984091758728,50000 -7378.887080192566,7.6535325050354,72721.07917380333,159260,0,72721.07917380333,0.6544000506401062,1.6908684968948364,10000,80115.90304541588,0.8759960532188416,0.6908382773399353,0.775879979133606,1.1087701320648191,50000 -7426.482050657272,7.7148168087005615,73141.11105513573,160179,0,73141.11105513573,0.6558000445365906,1.691611409187317,10000,80583.63886260986,0.8782421946525574,0.6889281868934631,0.7750399708747864,1.1090731620788574,50000 -7472.475435972214,7.764083385467529,73561.23503017426,161101,0,73561.23503017426,0.6583000421524048,1.6969735622406006,10000,81049.8535144329,0.8794531226158142,0.6894750595092773,0.7768999934196472,1.1162922382354736,50000 -7511.793575048447,7.818175554275513,73981.34681868553,162021,0,73981.34681868553,0.6588000059127808,1.671212911605835,10000,81509.38480234146,0.8860155940055847,0.6461138129234314,0.7784199714660645,1.0914040803909302,50000 -7556.840314865112,7.872089624404907,74401.3993074894,162943,0,74401.3993074894,0.6619000434875488,1.6778147220611572,10000,81974.58531212807,0.8807421922683716,0.6805300712585449,0.7788800001144409,1.106682062149048,50000 -7604.018439769745,7.928055763244629,74821.57655787468,163867,0,74821.57655787468,0.6641000509262085,1.6616779565811155,10000,82442.0477347374,0.8832812309265137,0.6668259501457214,0.7802599668502808,1.0908979177474976,50000 -7649.971467733383,7.979364395141602,75241.48506331444,164789,0,75241.48506331444,0.6593000292778015,1.6780301332473757,10000,82908.00784730911,0.887988269329071,0.6550906896591187,0.7813000082969666,1.0903257131576538,50000 -7691.369910478592,8.033177137374878,75661.61553740501,165714,0,75661.61553740501,0.663100004196167,1.6789699792861938,10000,83369.6378827095,0.8838671445846558,0.6711214780807495,0.7797999978065491,1.0984795093536377,50000 -7736.580994844437,8.090880393981934,76081.84141349792,166636,0,76081.84141349792,0.6624000072479248,1.6653751134872437,10000,83835.18001461029,0.8857812285423279,0.6534203290939331,0.7819399833679199,1.0902855396270752,50000 -7780.93232178688,8.14522910118103,76501.81643271446,167556,0,76501.81643271446,0.6653000116348267,1.656856894493103,10000,84299.6081571579,0.8882226347923279,0.6422659754753113,0.7817999720573425,1.0872855186462402,50000 -7826.474480390549,8.208962202072144,76921.93040585518,168477,0,76921.93040585518,0.6647000312805176,1.650307297706604,10000,84765.37529754639,0.8875390291213989,0.6451851725578308,0.7829200029373169,1.0794780254364014,50000 -7871.694277763367,8.263505935668945,77342.00812506676,169398,0,77342.00812506676,0.6653000116348267,1.6556425094604492,10000,85230.77464270592,0.8874218463897705,0.6406536102294922,0.7831999659538269,1.078544020652771,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/measurements.csv deleted file mode 100644 index 7f132061d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1885 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.2923718,6.9077535,,,,,,,,,,,,,, -1,,,0.0009765625,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,35.84359121322632,62.57128691673279,35.84359121322632,26.727606058120728,0.0,0.0 -100,0.33674377,6.9046965,,,,,,,,,,,,,, -200,0.4063701,6.8858366,,,,,,,,,,,,,, -300,0.48926818,6.843835,,,,,,,,,,,,,, -400,0.48969832,6.8154063,,,,,,,,,,,,,, -500,0.60185885,6.82589,,,,,,,,,,,,,, -600,0.66834307,6.7489967,,,,,,,,,,,,,, -700,1.2775778,6.6789613,,,,,,,,,,,,,, -800,0.6667305,6.749059,,,,,,,,,,,,,, -844,,,0.0159765612334013,6.382034778594971,0.015999998897314,6.395730972290039,50000.0,0.0114000001922249,6.438459873199463,10000.0,455.78622579574585,526.145667552948,455.78622579574585,70.29547381401062,0.020416259765625,0.0 -900,0.85562146,6.703063,,,,,,,,,,,,,, -1000,1.1817167,6.5815806,,,,,,,,,,,,,, -1100,1.2813425,6.789425,,,,,,,,,,,,,, -1200,1.2563598,6.5076914,,,,,,,,,,,,,, -1300,1.294651,6.763404,,,,,,,,,,,,,, -1400,1.0464393,6.523894,,,,,,,,,,,,,, -1500,1.0369308,6.5289674,,,,,,,,,,,,,, -1600,1.0911989,6.3522286,,,,,,,,,,,,,, -1700,1.1451918,6.3072,,,,,,,,,,,,,, -1764,,,0.0448632799088954,5.828395366668701,0.0431599989533424,5.85699987411499,50000.0,0.033500000834465,5.981106758117676,10000.0,876.1266677379608,986.7028484344482,876.1266677379608,110.43268370628356,0.0518798828125,0.0 -1800,1.0708145,6.337923,,,,,,,,,,,,,, -1900,1.4957418,6.6198597,,,,,,,,,,,,,, -2000,1.2915622,6.694346,,,,,,,,,,,,,, -2100,1.1788051,6.270634,,,,,,,,,,,,,, -2200,0.8637411,6.4867997,,,,,,,,,,,,,, -2300,1.025616,6.255562,,,,,,,,,,,,,, -2400,1.5580543,6.192541,,,,,,,,,,,,,, -2500,0.6746225,6.6096997,,,,,,,,,,,,,, -2600,1.378699,6.459697,,,,,,,,,,,,,, -2684,,,0.0735937505960464,5.448266983032227,0.0674799978733062,5.489231586456299,50000.0,0.0502000041306018,5.668014049530029,10000.0,1296.264136552811,1450.8591334819794,1296.264136552811,154.37527751922607,0.0791792869567871,0.0 -2700,0.92075664,6.3078847,,,,,,,,,,,,,, -2800,1.3201079,6.1958914,,,,,,,,,,,,,, -2900,0.9286857,6.148087,,,,,,,,,,,,,, -3000,1.2900716,6.098473,,,,,,,,,,,,,, -3100,1.0412908,6.0167084,,,,,,,,,,,,,, -3200,0.8665591,6.673191,,,,,,,,,,,,,, -3300,1.2239974,6.051734,,,,,,,,,,,,,, -3400,0.9604957,6.2802415,,,,,,,,,,,,,, -3500,1.1483765,6.015986,,,,,,,,,,,,,, -3600,1.150696,6.047162,,,,,,,,,,,,,, -3604,,,0.1104687452316284,5.068248271942139,0.1017199978232383,5.126224994659424,50000.0,0.0767000019550323,5.356302261352539,10000.0,1716.3032870292664,1915.3701028823853,1716.3032870292664,198.7751681804657,0.1038756370544433,0.0 -3700,1.1890613,5.9568295,,,,,,,,,,,,,, -3800,1.0220488,5.935439,,,,,,,,,,,,,, -3900,1.084462,5.87119,,,,,,,,,,,,,, -4000,0.9336324,6.5941367,,,,,,,,,,,,,, -4100,1.1681054,6.600436,,,,,,,,,,,,,, -4200,1.0360715,5.9620743,,,,,,,,,,,,,, -4300,0.9179452,6.56942,,,,,,,,,,,,,, -4400,1.008587,5.774909,,,,,,,,,,,,,, -4500,0.932321,6.3558807,,,,,,,,,,,,,, -4522,,,0.1553320288658142,4.668953895568848,0.1380600035190582,4.754825592041016,50000.0,0.0994000062346458,5.049302577972412,10000.0,2136.638533353805,2379.705213546753,2136.638533353805,242.7034797668457,0.1274757385253906,0.0 -4600,1.3031604,5.73704,,,,,,,,,,,,,, -4700,0.9964568,5.8053427,,,,,,,,,,,,,, -4800,1.2845966,5.662038,,,,,,,,,,,,,, -4900,0.98333365,6.1394725,,,,,,,,,,,,,, -5000,0.80472636,6.023843,,,,,,,,,,,,,, -5100,1.1545293,5.747033,,,,,,,,,,,,,, -5200,1.0859199,5.6784835,,,,,,,,,,,,,, -5300,1.2426356,5.601574,,,,,,,,,,,,,, -5400,1.007134,5.569118,,,,,,,,,,,,,, -5441,,,0.1989648342132568,4.309050559997559,0.1846799999475479,4.396907806396484,50000.0,0.1346000134944915,4.739532470703125,10000.0,2556.660418987274,2840.099986076355,2556.660418987274,283.0039734840393,0.1524102687835693,0.0 -5500,1.1989299,5.5507,,,,,,,,,,,,,, -5600,0.86186737,6.4455013,,,,,,,,,,,,,, -5700,0.8330417,6.491646,,,,,,,,,,,,,, -5800,1.1059186,5.5975876,,,,,,,,,,,,,, -5900,1.3161796,5.4064965,,,,,,,,,,,,,, -6000,0.9892391,5.460799,,,,,,,,,,,,,, -6100,0.9370166,5.384869,,,,,,,,,,,,,, -6200,1.1177237,5.6712084,,,,,,,,,,,,,, -6300,0.8866536,6.446792,,,,,,,,,,,,,, -6362,,,0.2446484267711639,3.974132537841797,0.2258399873971939,4.073367118835449,50000.0,0.1720000058412552,4.464690208435059,10000.0,2976.674590110779,3302.356697320938,2976.674590110779,325.1707737445832,0.1802799701690673,0.0 -6400,0.83973044,5.9674263,,,,,,,,,,,,,, -6500,1.3383994,5.4582286,,,,,,,,,,,,,, -6600,0.9380531,5.3648887,,,,,,,,,,,,,, -6700,0.8957404,5.5783157,,,,,,,,,,,,,, -6800,0.9548847,5.439979,,,,,,,,,,,,,, -6900,1.0292357,5.23077,,,,,,,,,,,,,, -7000,0.9817385,5.2977514,,,,,,,,,,,,,, -7100,1.4766906,5.2035823,,,,,,,,,,,,,, -7200,1.1061909,5.23722,,,,,,,,,,,,,, -7286,,,0.2900195121765136,3.6638667583465576,0.2622399926185608,3.812857151031494,50000.0,0.1998000144958496,4.2360687255859375,10000.0,3396.9373681545258,3763.586677074432,3396.9373681545258,366.0614733695984,0.2094607353210449,0.0 -7300,0.9074705,6.4326572,,,,,,,,,,,,,, -7400,1.1805373,5.186237,,,,,,,,,,,,,, -7500,1.0898749,5.0258913,,,,,,,,,,,,,, -7600,1.0321674,5.2630453,,,,,,,,,,,,,, -7700,1.0262991,5.079749,,,,,,,,,,,,,, -7800,1.1920701,5.1778646,,,,,,,,,,,,,, -7900,1.0751644,5.423164,,,,,,,,,,,,,, -8000,0.9377591,6.3007064,,,,,,,,,,,,,, -8100,1.1142006,5.1478624,,,,,,,,,,,,,, -8200,0.7948936,5.6796856,,,,,,,,,,,,,, -8209,,,0.3228124976158142,3.440293550491333,0.3006999790668487,3.549184560775757,50000.0,0.233800008893013,4.013123989105225,10000.0,3817.3410184383392,4223.172071218491,3817.3410184383392,405.1697099208832,0.2350239753723144,0.0 -8300,1.0856239,5.1239257,,,,,,,,,,,,,, -8400,1.172391,4.9724565,,,,,,,,,,,,,, -8500,1.0942605,5.0558023,,,,,,,,,,,,,, -8600,0.8508058,5.590517,,,,,,,,,,,,,, -8700,0.77456146,6.098136,,,,,,,,,,,,,, -8800,1.0247864,4.9231453,,,,,,,,,,,,,, -8900,0.73776543,5.9791937,,,,,,,,,,,,,, -9000,0.89799047,5.179347,,,,,,,,,,,,,, -9100,1.1388535,4.9333806,,,,,,,,,,,,,, -9131,,,0.3623827993869781,3.1926872730255127,0.3360999822616577,3.323276996612549,50000.0,0.2590000033378601,3.83144736289978,10000.0,4237.481600284576,4684.190475702286,4237.481600284576,445.9749083518982,0.2604801654815674,0.0 -9200,1.0120063,4.854613,,,,,,,,,,,,,, -9300,0.80826503,5.804434,,,,,,,,,,,,,, -9400,1.1012442,4.7596607,,,,,,,,,,,,,, -9500,0.7283229,6.2626143,,,,,,,,,,,,,, -9600,0.7687248,6.028411,,,,,,,,,,,,,, -9700,0.6522203,6.105233,,,,,,,,,,,,,, -9800,0.99194324,4.7352743,,,,,,,,,,,,,, -9900,0.7046742,6.188324,,,,,,,,,,,,,, -10000,1.0114464,4.7251167,,,,,,,,,,,,,, -10053,,,0.394843727350235,3.041377305984497,0.3632999956607818,3.2022409439086914,50000.0,0.2822000086307525,3.700833797454834,10000.0,4657.512858867645,5147.245245218277,4657.512858867645,488.92419385910034,0.2864758968353271,0.0 -10100,0.76981896,6.172821,,,,,,,,,,,,,, -10200,0.73234284,5.6774035,,,,,,,,,,,,,, -10300,0.7807722,5.5701795,,,,,,,,,,,,,, -10400,0.88645804,5.2211742,,,,,,,,,,,,,, -10500,0.94791096,5.1953955,,,,,,,,,,,,,, -10600,0.9804063,4.6643753,,,,,,,,,,,,,, -10700,0.95055366,4.704426,,,,,,,,,,,,,, -10800,0.867366,4.6806173,,,,,,,,,,,,,, -10900,0.6998712,5.4808254,,,,,,,,,,,,,, -10977,,,0.4287695288658142,2.809704542160034,0.3992599844932556,2.950516700744629,50000.0,0.3065000176429748,3.502833366394043,10000.0,5077.823751926422,5608.130947113037,5077.823751926422,529.4197680950165,0.3179991245269775,0.0 -11000,0.90920866,4.5562406,,,,,,,,,,,,,, -11100,1.0056124,4.591115,,,,,,,,,,,,,, -11200,0.98145705,4.7128325,,,,,,,,,,,,,, -11300,0.9313666,4.630425,,,,,,,,,,,,,, -11400,0.7350903,6.0470095,,,,,,,,,,,,,, -11500,0.77435917,5.0940123,,,,,,,,,,,,,, -11600,0.9553469,4.587885,,,,,,,,,,,,,, -11700,0.75386643,5.5478096,,,,,,,,,,,,,, -11800,0.9201107,4.510212,,,,,,,,,,,,,, -11899,,,0.461249977350235,2.6188957691192627,0.4248999953269958,2.788549900054932,50000.0,0.3294000029563904,3.33755874633789,10000.0,5497.804342985153,6068.620263576508,5497.804342985153,569.8549757003784,0.3435926437377929,0.0 -11900,0.7631526,6.0211754,,,,,,,,,,,,,, -12000,1.0048965,4.398296,,,,,,,,,,,,,, -12100,0.84749156,4.8807044,,,,,,,,,,,,,, -12200,0.94436425,4.5795264,,,,,,,,,,,,,, -12300,0.9539428,4.4433274,,,,,,,,,,,,,, -12400,0.8693032,4.544551,,,,,,,,,,,,,, -12500,1.0724485,4.4300094,,,,,,,,,,,,,, -12600,0.88223815,4.878549,,,,,,,,,,,,,, -12700,0.8647204,4.4917903,,,,,,,,,,,,,, -12800,0.7752117,5.490868,,,,,,,,,,,,,, -12820,,,0.4747070074081421,2.617384910583496,0.4347599744796753,2.79663348197937,50000.0,0.3380000293254852,3.348545789718628,10000.0,5918.146535158157,6531.268862962723,5918.146535158157,612.0777878761292,0.3792257308959961,0.0 -12900,0.8703473,4.4047575,,,,,,,,,,,,,, -13000,0.96747786,4.3928986,,,,,,,,,,,,,, -13100,0.9273522,4.458118,,,,,,,,,,,,,, -13200,0.9355625,4.3636913,,,,,,,,,,,,,, -13300,0.8947779,4.3138967,,,,,,,,,,,,,, -13400,0.61967665,6.07794,,,,,,,,,,,,,, -13500,0.9112099,4.396241,,,,,,,,,,,,,, -13600,0.67574847,5.4873924,,,,,,,,,,,,,, -13700,0.68379563,5.4195094,,,,,,,,,,,,,, -13743,,,0.4976562261581421,2.4742987155914307,0.4613799750804901,2.645811080932617,50000.0,0.3629000186920166,3.206737995147705,10000.0,6338.404702425003,6990.997879266739,6338.404702425003,651.4733724594116,0.4065427780151367,0.0 -13800,0.78958863,5.1223226,,,,,,,,,,,,,, -13900,0.67408615,5.8126526,,,,,,,,,,,,,, -14000,0.87002075,4.3148184,,,,,,,,,,,,,, -14100,1.0186695,4.333741,,,,,,,,,,,,,, -14200,0.9167864,4.3329034,,,,,,,,,,,,,, -14300,0.76961213,5.3459496,,,,,,,,,,,,,, -14400,0.9185179,4.387329,,,,,,,,,,,,,, -14500,0.86972785,4.223071,,,,,,,,,,,,,, -14600,0.8768453,4.2971225,,,,,,,,,,,,,, -14665,,,0.5162500143051147,2.3231639862060547,0.4807399809360504,2.4985995292663574,50000.0,0.3760000169277191,3.085344791412353,10000.0,6758.347489833832,7452.263117074966,6758.347489833832,692.7202491760254,0.4355850219726562,0.0 -14700,0.97850513,4.291848,,,,,,,,,,,,,, -14800,0.90839267,4.3275723,,,,,,,,,,,,,, -14900,0.70053047,6.0849156,,,,,,,,,,,,,, -15000,0.81981725,4.931574,,,,,,,,,,,,,, -15100,0.90692174,4.256307,,,,,,,,,,,,,, -15200,0.8340263,4.651666,,,,,,,,,,,,,, -15300,0.79075676,4.460001,,,,,,,,,,,,,, -15400,0.89307475,4.2295003,,,,,,,,,,,,,, -15500,0.65514,5.729743,,,,,,,,,,,,,, -15584,,,0.5301562547683716,2.32466983795166,0.4898599982261657,2.5129311084747314,50000.0,0.3820000290870666,3.095673084259033,10000.0,7178.313052892685,7910.607768535614,7178.313052892685,731.0242302417755,0.4632751941680908,0.0 -15600,0.7713592,5.030924,,,,,,,,,,,,,, -15700,0.9066202,4.19041,,,,,,,,,,,,,, -15800,0.69315076,5.9217052,,,,,,,,,,,,,, -15900,0.8882671,4.5573015,,,,,,,,,,,,,, -16000,0.7689,4.814807,,,,,,,,,,,,,, -16100,0.7058724,5.9037657,,,,,,,,,,,,,, -16200,0.91377074,4.2035336,,,,,,,,,,,,,, -16300,0.90990156,4.30245,,,,,,,,,,,,,, -16400,0.93872714,4.2610145,,,,,,,,,,,,,, -16500,0.94362074,4.1719313,,,,,,,,,,,,,, -16507,,,0.564160168170929,2.104198932647705,0.505620002746582,2.3738718032836914,50000.0,0.3920000195503235,2.980851411819458,10000.0,7598.365840673447,8369.278602838516,7598.365840673447,769.5679693222046,0.4897994995117187,0.0 -16600,1.0046117,4.19418,,,,,,,,,,,,,, -16700,0.91781884,4.1669846,,,,,,,,,,,,,, -16800,0.9296272,4.307294,,,,,,,,,,,,,, -16900,0.7795655,5.087185,,,,,,,,,,,,,, -17000,0.919582,4.2397585,,,,,,,,,,,,,, -17100,0.9444323,4.1057606,,,,,,,,,,,,,, -17200,0.88395774,4.196116,,,,,,,,,,,,,, -17300,0.9575526,4.0549865,,,,,,,,,,,,,, -17400,0.7864679,4.7574515,,,,,,,,,,,,,, -17428,,,0.5541601181030273,2.172273874282837,0.5168799757957458,2.334871292114258,50000.0,0.4103000164031982,2.918965101242065,10000.0,8018.321758508682,8827.351588010788,8018.321758508682,807.6075274944305,0.5183203220367432,0.0 -17500,0.87038857,4.173396,,,,,,,,,,,,,, -17600,0.86909974,4.1625886,,,,,,,,,,,,,, -17700,0.9147274,4.1087637,,,,,,,,,,,,,, -17800,0.87619877,4.305215,,,,,,,,,,,,,, -17900,1.0380177,4.1775837,,,,,,,,,,,,,, -18000,0.9199589,4.1956015,,,,,,,,,,,,,, -18100,0.77337164,5.853806,,,,,,,,,,,,,, -18200,0.92530566,4.0709643,,,,,,,,,,,,,, -18300,0.95621777,4.2020216,,,,,,,,,,,,,, -18351,,,0.5694726705551147,2.0965421199798584,0.5231199860572815,2.2878384590148926,50000.0,0.4129000306129455,2.89074444770813,10000.0,8438.259444952011,9290.145105838776,8438.259444952011,850.389208316803,0.544304609298706,0.0 -18400,0.5990389,5.7885094,,,,,,,,,,,,,, -18500,0.87131983,4.113435,,,,,,,,,,,,,, -18600,0.9071818,4.0725975,,,,,,,,,,,,,, -18700,0.8972317,4.1290693,,,,,,,,,,,,,, -18800,0.97096974,4.0888042,,,,,,,,,,,,,, -18900,0.9886203,4.0572553,,,,,,,,,,,,,, -19000,0.9407529,4.036454,,,,,,,,,,,,,, -19100,0.8050175,5.3096128,,,,,,,,,,,,,, -19200,0.8974779,4.2227497,,,,,,,,,,,,,, -19264,,,0.5887304544448853,1.99932599067688,0.5343999862670898,2.2401058673858643,50000.0,0.4234000146389007,2.851503372192383,10000.0,8858.576433181763,9747.34814119339,8858.576433181763,887.2015011310577,0.5706558227539062,0.0 -19300,1.0723212,4.0898614,,,,,,,,,,,,,, -19400,0.9452801,4.3841867,,,,,,,,,,,,,, -19500,0.89925486,3.996438,,,,,,,,,,,,,, -19600,1.239031,4.157234,,,,,,,,,,,,,, -19700,0.98853594,3.9663658,,,,,,,,,,,,,, -19800,0.8291298,4.2461286,,,,,,,,,,,,,, -19900,0.83751434,4.2083907,,,,,,,,,,,,,, -20000,0.9607646,4.015471,,,,,,,,,,,,,, -20100,0.80904275,4.811954,,,,,,,,,,,,,, -20183,,,0.5824609398841858,2.048027515411377,0.540619969367981,2.229803562164306,50000.0,0.426000028848648,2.823843002319336,10000.0,9278.755592107773,10209.370984315872,9278.755592107773,928.9679560661316,0.6004471778869629,0.0 -20200,0.672238,5.820194,,,,,,,,,,,,,, -20300,0.7989529,4.5765185,,,,,,,,,,,,,, -20400,0.6921665,5.7860928,,,,,,,,,,,,,, -20500,0.7213916,5.7907896,,,,,,,,,,,,,, -20600,0.9851684,4.1294923,,,,,,,,,,,,,, -20700,0.9413569,4.032888,,,,,,,,,,,,,, -20800,0.8137386,5.3580856,,,,,,,,,,,,,, -20900,0.8355115,4.6588182,,,,,,,,,,,,,, -21000,0.7417718,5.2892513,,,,,,,,,,,,,, -21100,1.0040201,4.028834,,,,,,,,,,,,,, -21105,,,0.5959179401397705,1.9627938270568848,0.5534799695014954,2.157258987426758,50000.0,0.4392000138759613,2.759177923202514,10000.0,9699.127198934557,10674.098667144775,9699.127198934557,973.2437303066254,0.6334230899810791,0.0 -21200,0.9409609,4.149243,,,,,,,,,,,,,, -21300,0.67530566,5.789447,,,,,,,,,,,,,, -21400,0.82450956,4.555937,,,,,,,,,,,,,, -21500,0.73151,5.7009025,,,,,,,,,,,,,, -21600,0.6891664,5.6466417,,,,,,,,,,,,,, -21700,0.7250268,5.504397,,,,,,,,,,,,,, -21800,1.0437089,4.024001,,,,,,,,,,,,,, -21900,0.9627467,4.0177855,,,,,,,,,,,,,, -22000,0.92353344,3.9924092,,,,,,,,,,,,,, -22026,,,0.6109570264816284,1.83113980293274,0.5616799592971802,2.0698909759521484,50000.0,0.4387000203132629,2.6716549396514893,10000.0,10119.155968666077,11133.25730252266,10119.155968666077,1012.2991769313812,0.6604936122894287,0.0 -22100,0.95308644,4.1106215,,,,,,,,,,,,,, -22200,0.84209543,4.725225,,,,,,,,,,,,,, -22300,0.92599386,3.9409804,,,,,,,,,,,,,, -22400,0.69385386,5.38069,,,,,,,,,,,,,, -22500,0.81483877,4.4936476,,,,,,,,,,,,,, -22600,0.8411021,4.319414,,,,,,,,,,,,,, -22700,1.0727752,3.9811873,,,,,,,,,,,,,, -22800,0.90398806,4.3253503,,,,,,,,,,,,,, -22900,0.9860827,4.1307535,,,,,,,,,,,,,, -22949,,,0.6087695360183716,1.8961181640625,0.5664600133895874,2.086631774902344,50000.0,0.450300008058548,2.7123382091522217,10000.0,10539.468144655228,11591.783080339432,10539.468144655228,1050.435089111328,0.6903214454650879,0.0 -23000,0.9830578,4.000164,,,,,,,,,,,,,, -23100,0.96162516,3.9216933,,,,,,,,,,,,,, -23200,1.085632,3.9343863,,,,,,,,,,,,,, -23300,0.75634855,4.7256484,,,,,,,,,,,,,, -23400,0.97230554,3.922896,,,,,,,,,,,,,, -23500,0.7121981,5.5389934,,,,,,,,,,,,,, -23600,0.8619268,4.3977947,,,,,,,,,,,,,, -23700,0.9055153,3.8189762,,,,,,,,,,,,,, -23800,0.8419929,4.2072783,,,,,,,,,,,,,, -23869,,,0.6216992139816284,1.848590850830078,0.5771799683570862,2.052497148513794,50000.0,0.4593000113964081,2.6749086380004883,10000.0,10959.501294612885,12054.205533742905,10959.501294612885,1092.7479412555697,0.7183361053466797,0.0 -23900,0.92735654,3.9818678,,,,,,,,,,,,,, -24000,0.76759666,4.782651,,,,,,,,,,,,,, -24100,0.9277711,4.111342,,,,,,,,,,,,,, -24200,0.93960136,4.554867,,,,,,,,,,,,,, -24300,0.96571296,3.92684,,,,,,,,,,,,,, -24400,0.8480958,5.158386,,,,,,,,,,,,,, -24500,0.9053514,3.950931,,,,,,,,,,,,,, -24600,0.95884526,3.8965652,,,,,,,,,,,,,, -24700,0.81208414,4.844012,,,,,,,,,,,,,, -24790,,,0.6266406178474426,1.8320873975753784,0.5758000016212463,2.063634157180786,50000.0,0.4606000185012817,2.6689186096191406,10000.0,11379.689157009125,12514.022708654404,11379.689157009125,1132.2981944084167,0.7464418411254883,0.0 -24800,0.7677364,5.64077,,,,,,,,,,,,,, -24900,0.80518705,5.4278736,,,,,,,,,,,,,, -25000,0.7293714,5.73083,,,,,,,,,,,,,, -25100,0.96283007,4.3236623,,,,,,,,,,,,,, -25200,0.7665032,4.813015,,,,,,,,,,,,,, -25300,0.79686326,5.453378,,,,,,,,,,,,,, -25400,0.9432991,3.8491426,,,,,,,,,,,,,, -25500,0.98040676,3.9208298,,,,,,,,,,,,,, -25600,0.80022466,4.9383388,,,,,,,,,,,,,, -25700,0.855012,4.0638247,,,,,,,,,,,,,, -25709,,,0.6539843678474426,1.6860125064849854,0.5881400108337402,1.9701340198516848,50000.0,0.4701000154018402,2.585104465484619,10000.0,11799.745354413986,12974.617196798325,11799.745354413986,1172.748398065567,0.7861979007720947,0.0 -25800,0.9729066,3.9194145,,,,,,,,,,,,,, -25900,0.9184777,4.0405164,,,,,,,,,,,,,, -26000,0.9205554,3.7874188,,,,,,,,,,,,,, -26100,0.7905278,5.6577187,,,,,,,,,,,,,, -26200,0.901103,4.1055474,,,,,,,,,,,,,, -26300,1.0513552,3.833709,,,,,,,,,,,,,, -26400,0.84219706,4.4873323,,,,,,,,,,,,,, -26500,1.0307924,3.8342826,,,,,,,,,,,,,, -26600,1.0402627,3.8663478,,,,,,,,,,,,,, -26630,,,0.6346874833106995,1.760985255241394,0.5895199775695801,1.968305706977844,50000.0,0.4780000150203705,2.573791742324829,10000.0,12219.731608867643,13436.014949560164,12219.731608867643,1214.0835967063904,0.8152399063110352,0.0 -26700,0.95330846,3.8797202,,,,,,,,,,,,,, -26800,0.7985022,5.4535666,,,,,,,,,,,,,, -26900,0.7876157,4.906673,,,,,,,,,,,,,, -27000,0.92464507,3.7540336,,,,,,,,,,,,,, -27100,0.89657295,4.073003,,,,,,,,,,,,,, -27200,0.97424024,3.9535422,,,,,,,,,,,,,, -27300,0.9681357,3.7260406,,,,,,,,,,,,,, -27400,0.84017307,4.977292,,,,,,,,,,,,,, -27500,0.87393266,3.982655,,,,,,,,,,,,,, -27553,,,0.6429687142372131,1.7258527278900146,0.5952999591827393,1.952037453651428,50000.0,0.4765000343322754,2.5679101943969727,10000.0,12639.99943780899,13899.052494764328,12639.99943780899,1256.7773969173431,0.8430163860321045,0.0 -27600,0.8356999,5.5352917,,,,,,,,,,,,,, -27700,0.8228169,4.8473444,,,,,,,,,,,,,, -27800,0.8239697,5.653827,,,,,,,,,,,,,, -27900,0.9808252,3.8664918,,,,,,,,,,,,,, -28000,0.99814296,3.7821145,,,,,,,,,,,,,, -28100,1.0068375,3.930867,,,,,,,,,,,,,, -28200,0.86609,4.459551,,,,,,,,,,,,,, -28300,0.92324424,3.844061,,,,,,,,,,,,,, -28400,0.95600426,3.786477,,,,,,,,,,,,,, -28476,,,0.6584765315055847,1.6636217832565308,0.595579981803894,1.949023962020874,50000.0,0.4812000095844269,2.56640625,10000.0,13060.122702360151,14358.513455867767,13060.122702360151,1296.0321819782257,0.8778200149536133,0.0 -28500,0.98394656,4.1782246,,,,,,,,,,,,,, -28600,0.8900921,3.9346266,,,,,,,,,,,,,, -28700,0.8923419,4.8671613,,,,,,,,,,,,,, -28800,0.8507349,4.9954286,,,,,,,,,,,,,, -28900,0.8545746,4.2163324,,,,,,,,,,,,,, -29000,0.97907823,3.8763428,,,,,,,,,,,,,, -29100,0.80883497,4.851569,,,,,,,,,,,,,, -29200,0.8844495,4.18422,,,,,,,,,,,,,, -29300,0.90084136,3.999199,,,,,,,,,,,,,, -29396,,,0.6520312428474426,1.674699306488037,0.6082800030708313,1.8775641918182373,50000.0,0.4892000257968902,2.5049943923950195,10000.0,13480.046933412552,14819.881961107254,13480.046933412552,1337.395380973816,0.9113750457763672,0.0 -29400,0.94491816,3.9039547,,,,,,,,,,,,,, -29500,0.93535674,3.9482315,,,,,,,,,,,,,, -29600,0.95926136,3.8013408,,,,,,,,,,,,,, -29700,0.9963041,3.7125144,,,,,,,,,,,,,, -29800,0.9736562,3.795403,,,,,,,,,,,,,, -29900,1.0394777,3.799683,,,,,,,,,,,,,, -30000,1.0560741,3.7613604,,,,,,,,,,,,,, -30100,0.9573419,4.1561084,,,,,,,,,,,,,, -30200,0.8442497,4.462986,,,,,,,,,,,,,, -30300,1.0835797,3.879921,,,,,,,,,,,,,, -30317,,,0.6537694931030273,1.6482510566711426,0.6024599671363831,1.874757409095764,50000.0,0.489300012588501,2.4801666736602783,10000.0,13900.2150182724,15285.852749586104,13900.2150182724,1383.1154959201813,0.945807695388794,0.0 -30400,0.86627865,5.3315234,,,,,,,,,,,,,, -30500,0.9124972,4.077395,,,,,,,,,,,,,, -30600,0.97096395,3.794124,,,,,,,,,,,,,, -30700,0.8711931,4.272919,,,,,,,,,,,,,, -30800,0.84135276,5.0818906,,,,,,,,,,,,,, -30900,0.89116883,4.5020447,,,,,,,,,,,,,, -31000,1.0775725,3.6829858,,,,,,,,,,,,,, -31100,1.0197089,3.7933517,,,,,,,,,,,,,, -31200,0.893756,4.8848844,,,,,,,,,,,,,, -31241,,,0.6695898175239563,1.5669158697128296,0.6098799705505371,1.8242343664169312,50000.0,0.4948000311851501,2.441969156265259,10000.0,14320.48095369339,15745.813037872314,14320.48095369339,1422.7271156311035,0.9807932376861572,0.0 -31300,0.7933699,4.7037735,,,,,,,,,,,,,, -31400,0.94098794,3.7388163,,,,,,,,,,,,,, -31500,0.92393345,4.150905,,,,,,,,,,,,,, -31600,1.0364656,3.745242,,,,,,,,,,,,,, -31700,0.9157881,3.8743484,,,,,,,,,,,,,, -31800,1.0423094,3.887841,,,,,,,,,,,,,, -31900,1.0143521,3.74317,,,,,,,,,,,,,, -32000,0.9818798,3.8165383,,,,,,,,,,,,,, -32100,0.8745821,5.5729055,,,,,,,,,,,,,, -32162,,,0.6646093726158142,1.623708724975586,0.616159975528717,1.841352462768555,50000.0,0.5010000467300415,2.439016103744507,10000.0,14740.702543497086,16209.5774371624,14740.702543497086,1466.1945168972015,1.0090515613555908,0.0 -32200,0.9090338,5.6094084,,,,,,,,,,,,,, -32300,0.9127124,3.921599,,,,,,,,,,,,,, -32400,0.920516,3.7816997,,,,,,,,,,,,,, -32500,0.8013152,5.472118,,,,,,,,,,,,,, -32600,1.0690553,3.7227821,,,,,,,,,,,,,, -32700,0.9008676,3.7223516,,,,,,,,,,,,,, -32800,1.0474547,3.6576214,,,,,,,,,,,,,, -32900,1.0309181,3.7643957,,,,,,,,,,,,,, -33000,0.8224677,4.665991,,,,,,,,,,,,,, -33084,,,0.6647070050239563,1.6350743770599363,0.6173399686813354,1.8481569290161133,50000.0,0.494700014591217,2.4732367992401123,10000.0,15160.87868309021,16668.039219379425,15160.87868309021,1504.403111219406,1.0377919673919678,0.0 -33100,1.0293535,3.6504364,,,,,,,,,,,,,, -33200,1.0282149,3.8024578,,,,,,,,,,,,,, -33300,0.997787,3.7318313,,,,,,,,,,,,,, -33400,0.83488524,4.795144,,,,,,,,,,,,,, -33500,0.80305004,4.9158463,,,,,,,,,,,,,, -33600,0.9678751,3.7268896,,,,,,,,,,,,,, -33700,0.9746062,3.7505774,,,,,,,,,,,,,, -33800,0.95736194,3.801805,,,,,,,,,,,,,, -33900,1.0386405,3.7640219,,,,,,,,,,,,,, -34000,1.0887724,3.7957215,,,,,,,,,,,,,, -34004,,,0.6825194954872131,1.5208829641342163,0.6213200092315674,1.7882128953933716,50000.0,0.506600022315979,2.3888542652130127,10000.0,15581.12509894371,17127.568870782852,15581.12509894371,1543.6040349006653,1.0717051029205322,0.0 -34100,0.984344,3.7561932,,,,,,,,,,,,,, -34200,0.9540312,3.851619,,,,,,,,,,,,,, -34300,0.81840533,5.3910575,,,,,,,,,,,,,, -34400,0.98253644,3.773504,,,,,,,,,,,,,, -34500,1.047402,3.7479782,,,,,,,,,,,,,, -34600,1.0434351,3.5834312,,,,,,,,,,,,,, -34700,0.94214743,3.6254888,,,,,,,,,,,,,, -34800,0.84000015,4.656251,,,,,,,,,,,,,, -34900,0.8591827,4.65616,,,,,,,,,,,,,, -34924,,,0.6859960556030273,1.5173721313476562,0.6280999779701233,1.7702747583389282,50000.0,0.5037000179290771,2.400809526443481,10000.0,16001.194901704788,17587.93961954117,16001.194901704788,1583.8245911598206,1.103961944580078,0.0 -35000,0.7855829,5.4178295,,,,,,,,,,,,,, -35100,0.9213272,5.1492386,,,,,,,,,,,,,, -35200,0.9620646,3.684033,,,,,,,,,,,,,, -35300,0.982033,4.012917,,,,,,,,,,,,,, -35400,0.9453423,3.6316018,,,,,,,,,,,,,, -35500,1.1398827,3.672018,,,,,,,,,,,,,, -35600,0.88960034,5.0676937,,,,,,,,,,,,,, -35700,0.9285209,3.6485004,,,,,,,,,,,,,, -35800,0.8695232,5.487759,,,,,,,,,,,,,, -35848,,,0.6784374713897705,1.537048101425171,0.6295599937438965,1.763940691947937,50000.0,0.5037000179290771,2.3965654373168945,10000.0,16421.54888010025,18051.60387778282,16421.54888010025,1627.0497500896454,1.141261339187622,0.0 -35900,1.0080291,3.7487423,,,,,,,,,,,,,, -36000,0.97582585,3.573268,,,,,,,,,,,,,, -36100,0.98876166,3.6262689,,,,,,,,,,,,,, -36200,0.98598576,3.6436265,,,,,,,,,,,,,, -36300,0.9230713,3.6670527,,,,,,,,,,,,,, -36400,1.0268512,3.9013538,,,,,,,,,,,,,, -36500,0.83787936,4.4274597,,,,,,,,,,,,,, -36600,1.0173213,3.6303644,,,,,,,,,,,,,, -36700,1.0073682,3.6528778,,,,,,,,,,,,,, -36770,,,0.6861132383346558,1.495723843574524,0.6327999830245972,1.7372503280639648,50000.0,0.5098000168800354,2.362467765808105,10000.0,16841.60574913025,18511.34856534004,16841.60574913025,1666.6609530448914,1.1702823638916016,0.0 -36800,1.0459404,3.728935,,,,,,,,,,,,,, -36900,0.836694,4.9711018,,,,,,,,,,,,,, -37000,0.9616815,3.6370587,,,,,,,,,,,,,, -37100,0.9558207,3.7206798,,,,,,,,,,,,,, -37200,0.9478351,3.9469357,,,,,,,,,,,,,, -37300,0.9627846,3.5309513,,,,,,,,,,,,,, -37400,0.9465594,3.527974,,,,,,,,,,,,,, -37500,1.0576116,3.7629929,,,,,,,,,,,,,, -37600,1.1030936,3.6901155,,,,,,,,,,,,,, -37690,,,0.7042187452316284,1.429381012916565,0.6326799988746643,1.742835283279419,50000.0,0.5088000297546387,2.347130537033081,10000.0,17261.76012301445,18975.477598905563,17261.76012301445,1710.5507550239563,1.207282543182373,0.0 -37700,0.8982695,3.8566737,,,,,,,,,,,,,, -37800,0.877516,4.5870085,,,,,,,,,,,,,, -37900,0.893903,4.718984,,,,,,,,,,,,,, -38000,1.0666159,3.6018486,,,,,,,,,,,,,, -38100,0.89486516,5.4936247,,,,,,,,,,,,,, -38200,0.9019418,4.1388454,,,,,,,,,,,,,, -38300,0.97277844,3.6820354,,,,,,,,,,,,,, -38400,0.88523674,4.3729687,,,,,,,,,,,,,, -38500,1.1285481,3.619322,,,,,,,,,,,,,, -38600,1.0965149,3.7000642,,,,,,,,,,,,,, -38611,,,0.6838085651397705,1.5368809700012207,0.6318199634552002,1.7639278173446655,50000.0,0.5082000494003296,2.393516540527344,10000.0,17681.86350107193,19432.76448178292,17681.86350107193,1747.6528568267822,1.241055250167847,0.0 -38700,0.8981772,4.5003624,,,,,,,,,,,,,, -38800,0.9860352,3.6525645,,,,,,,,,,,,,, -38900,0.96573067,3.5409818,,,,,,,,,,,,,, -39000,0.8866818,3.9564404,,,,,,,,,,,,,, -39100,1.1029363,3.765924,,,,,,,,,,,,,, -39200,0.91517466,3.9963121,,,,,,,,,,,,,, -39300,0.90490586,4.321754,,,,,,,,,,,,,, -39400,0.84937805,4.34105,,,,,,,,,,,,,, -39500,1.0098504,3.685013,,,,,,,,,,,,,, -39530,,,0.6908202767372131,1.493054986000061,0.6378999948501587,1.7385770082473757,50000.0,0.5134000182151794,2.360651969909668,10000.0,18101.93009710312,19894.123265028,18101.93009710312,1788.8630316257477,1.2756617069244385,0.0 -39600,1.2232146,3.6260808,,,,,,,,,,,,,, -39700,1.0179584,3.7440965,,,,,,,,,,,,,, -39800,1.0224149,3.5478654,,,,,,,,,,,,,, -39900,0.9669202,3.6759708,,,,,,,,,,,,,, -40000,0.91145664,5.0224686,,,,,,,,,,,,,, -40100,1.0687933,3.6659513,,,,,,,,,,,,,, -40200,1.0186373,3.612332,,,,,,,,,,,,,, -40300,1.0593301,3.6077273,,,,,,,,,,,,,, -40400,0.93779314,3.6144836,,,,,,,,,,,,,, -40446,,,0.7030664086341858,1.4640662670135498,0.6423999667167664,1.7443610429763794,50000.0,0.5159000158309937,2.3539485931396484,10000.0,18521.99777531624,20353.072145223618,18521.99777531624,1827.666071653366,1.3061726093292236,0.0 -40500,1.008042,3.5784,,,,,,,,,,,,,, -40600,0.94026333,4.210668,,,,,,,,,,,,,, -40700,1.079574,3.6958313,,,,,,,,,,,,,, -40800,0.98678046,3.7827091,,,,,,,,,,,,,, -40900,0.9239102,5.385243,,,,,,,,,,,,,, -41000,1.0146993,4.022221,,,,,,,,,,,,,, -41100,1.1011425,3.728956,,,,,,,,,,,,,, -41200,0.9797782,3.6227014,,,,,,,,,,,,,, -41300,0.97699934,3.8815484,,,,,,,,,,,,,, -41368,,,0.6928125023841858,1.5066217184066772,0.6416199803352356,1.7189924716949463,50000.0,0.5213000178337097,2.339286088943481,10000.0,18942.32962703705,20815.592352867126,18942.32962703705,1869.7756474018097,1.337547779083252,0.0 -41400,1.0451522,3.542327,,,,,,,,,,,,,, -41500,0.86477125,4.8522215,,,,,,,,,,,,,, -41600,0.9115752,4.947275,,,,,,,,,,,,,, -41700,0.98726654,5.414027,,,,,,,,,,,,,, -41800,1.1387421,3.5770054,,,,,,,,,,,,,, -41900,0.8759198,5.025327,,,,,,,,,,,,,, -42000,1.0447327,3.6775014,,,,,,,,,,,,,, -42100,0.92819434,4.7154875,,,,,,,,,,,,,, -42200,1.0817763,3.6781096,,,,,,,,,,,,,, -42290,,,0.6962890625,1.4691346883773804,0.6421599984169006,1.710872769355774,50000.0,0.5215000510215759,2.330148935317993,10000.0,19362.28750777245,21281.26524925232,19362.28750777245,1915.4070200920105,1.3723368644714355,0.0 -42300,1.025257,3.581383,,,,,,,,,,,,,, -42400,0.9168508,4.038185,,,,,,,,,,,,,, -42500,0.94665945,3.6632626,,,,,,,,,,,,,, -42600,0.9389475,4.4508743,,,,,,,,,,,,,, -42700,1.0832095,3.578763,,,,,,,,,,,,,, -42800,1.047525,3.5894928,,,,,,,,,,,,,, -42900,1.0257452,3.617361,,,,,,,,,,,,,, -43000,1.0293553,3.713954,,,,,,,,,,,,,, -43100,1.0314524,3.6532352,,,,,,,,,,,,,, -43200,1.0512232,3.6048198,,,,,,,,,,,,,, -43213,,,0.707324206829071,1.4024572372436523,0.647159993648529,1.6708543300628662,50000.0,0.5273000001907349,2.284856081008911,10000.0,19782.306488990784,21741.00474333763,19782.306488990784,1955.033771038056,1.4100327491760254,0.0 -43300,1.0212415,3.7638638,,,,,,,,,,,,,, -43400,1.011319,3.6144986,,,,,,,,,,,,,, -43500,0.9381486,4.0423055,,,,,,,,,,,,,, -43600,1.02262,3.6876357,,,,,,,,,,,,,, -43700,0.9235426,4.085894,,,,,,,,,,,,,, -43800,0.8541012,4.617455,,,,,,,,,,,,,, -43900,0.93999475,4.935654,,,,,,,,,,,,,, -44000,1.0961871,3.5570252,,,,,,,,,,,,,, -44100,1.0446812,3.555067,,,,,,,,,,,,,, -44135,,,0.6955273151397705,1.4521766901016235,0.6490199565887451,1.6639647483825684,50000.0,0.5218999981880188,2.304472684860229,10000.0,20202.631172180176,22203.795504808422,20202.631172180176,1997.4169154167173,1.4440712928771973,0.0 -44200,1.0098525,3.8103504,,,,,,,,,,,,,, -44300,1.1419846,3.6243622,,,,,,,,,,,,,, -44400,1.1316334,3.6505032,,,,,,,,,,,,,, -44500,1.0036583,3.5997968,,,,,,,,,,,,,, -44600,0.96965855,3.785893,,,,,,,,,,,,,, -44700,1.0381984,3.5534263,,,,,,,,,,,,,, -44800,1.0046806,3.9419117,,,,,,,,,,,,,, -44900,0.9905725,3.710495,,,,,,,,,,,,,, -45000,1.0366004,3.888484,,,,,,,,,,,,,, -45055,,,0.698925793170929,1.434710144996643,0.647379994392395,1.667491793632507,50000.0,0.5235000252723694,2.2860796451568604,10000.0,20622.772392749783,22665.36895251274,20622.772392749783,2038.766408443451,1.4792227745056152,0.0 -45100,1.022242,3.531625,,,,,,,,,,,,,, -45200,1.0215408,5.247601,,,,,,,,,,,,,, -45300,1.0519264,3.6176183,,,,,,,,,,,,,, -45400,1.0736403,3.5507066,,,,,,,,,,,,,, -45500,1.0148185,3.566752,,,,,,,,,,,,,, -45600,0.93370205,4.131648,,,,,,,,,,,,,, -45700,1.0854443,3.6414347,,,,,,,,,,,,,, -45800,1.0862447,3.6617007,,,,,,,,,,,,,, -45900,0.9754441,4.0749683,,,,,,,,,,,,,, -45974,,,0.7076953053474426,1.4150549173355105,0.6478599905967712,1.678601622581482,50000.0,0.5283000469207764,2.2812087535858154,10000.0,21043.10510325432,23129.006335258484,21043.10510325432,2081.991260766983,1.511063575744629,0.0 -46000,0.9384883,3.9114904,,,,,,,,,,,,,, -46100,1.0196205,4.0360656,,,,,,,,,,,,,, -46200,0.90954244,4.4919257,,,,,,,,,,,,,, -46300,1.1338893,3.5891593,,,,,,,,,,,,,, -46400,0.91920656,4.445289,,,,,,,,,,,,,, -46500,0.9667172,3.9486892,,,,,,,,,,,,,, -46600,1.0388296,3.5753207,,,,,,,,,,,,,, -46700,1.0538174,3.5536208,,,,,,,,,,,,,, -46800,1.0814232,3.5837035,,,,,,,,,,,,,, -46895,,,0.7375390529632568,1.29774272441864,0.6542400121688843,1.6448653936386108,50000.0,0.5331000089645386,2.2646377086639404,10000.0,21463.18249320984,23590.88441562653,21463.18249320984,2123.710966825485,1.544229507446289,0.0 -46900,0.9412564,5.0601044,,,,,,,,,,,,,, -47000,1.0739459,3.5723422,,,,,,,,,,,,,, -47100,0.9942695,3.5056736,,,,,,,,,,,,,, -47200,0.984393,5.042634,,,,,,,,,,,,,, -47300,1.071189,3.6526198,,,,,,,,,,,,,, -47400,1.0785097,4.101351,,,,,,,,,,,,,, -47500,1.059444,3.572939,,,,,,,,,,,,,, -47600,1.0260217,3.517007,,,,,,,,,,,,,, -47700,1.0965544,3.4641764,,,,,,,,,,,,,, -47800,0.95434684,4.7723556,,,,,,,,,,,,,, -47818,,,0.7093945145606995,1.4072017669677734,0.6576399803161621,1.6411067247390747,50000.0,0.5327000021934509,2.245063304901123,10000.0,21883.53296923637,24053.848252534863,21883.53296923637,2166.239804506302,1.5809102058410645,0.0 -47900,1.035917,3.560738,,,,,,,,,,,,,, -48000,0.9328243,4.3514156,,,,,,,,,,,,,, -48100,0.95174426,4.7311125,,,,,,,,,,,,,, -48200,1.0907692,3.6015918,,,,,,,,,,,,,, -48300,1.1495748,3.6761727,,,,,,,,,,,,,, -48400,1.0269216,3.570601,,,,,,,,,,,,,, -48500,1.1267378,3.6130495,,,,,,,,,,,,,, -48600,1.1158396,3.4268816,,,,,,,,,,,,,, -48700,1.078547,3.5209918,,,,,,,,,,,,,, -48741,,,0.7172656059265137,1.3772408962249756,0.6559399962425232,1.6439871788024902,50000.0,0.5318000316619873,2.262031078338623,10000.0,22303.59793663025,24516.407782793045,22303.59793663025,2208.6510157585144,1.6165921688079834,0.0 -48800,1.047625,3.577057,,,,,,,,,,,,,, -48900,1.0256468,3.4410636,,,,,,,,,,,,,, -49000,0.920445,4.9723306,,,,,,,,,,,,,, -49100,0.983717,5.3969016,,,,,,,,,,,,,, -49200,1.0614957,3.5099335,,,,,,,,,,,,,, -49300,1.0126867,3.8784251,,,,,,,,,,,,,, -49400,1.0821329,3.585783,,,,,,,,,,,,,, -49500,0.99318063,3.8757858,,,,,,,,,,,,,, -49600,1.0361753,3.5113988,,,,,,,,,,,,,, -49664,,,0.71875,1.425883650779724,0.6511600017547607,1.7147786617279053,50000.0,0.5281000137329102,2.343284845352173,10000.0,22723.800693511963,24980.77202296257,22723.800693511963,2252.7241654396057,1.656675100326538,0.0 -49700,1.0398439,3.5004582,,,,,,,,,,,,,, -49800,0.99867976,5.3014045,,,,,,,,,,,,,, -49900,1.0415244,4.816354,,,,,,,,,,,,,, -50000,1.1291431,3.505469,,,,,,,,,,,,,, -50100,0.95909566,4.5126147,,,,,,,,,,,,,, -50200,1.040036,3.7068584,,,,,,,,,,,,,, -50300,1.0007138,4.5465827,,,,,,,,,,,,,, -50400,1.0274098,4.0308366,,,,,,,,,,,,,, -50500,1.0376289,3.54133,,,,,,,,,,,,,, -50586,,,0.7098437547683716,1.4361015558242798,0.6548399925231934,1.6722317934036257,50000.0,0.5351999998092651,2.278514862060547,10000.0,23143.94828104973,25443.989156007767,23143.94828104973,2295.708538770676,1.69374680519104,0.0 -50600,1.0426499,5.300651,,,,,,,,,,,,,, -50700,0.98454964,4.0355487,,,,,,,,,,,,,, -50800,1.0562493,3.5158534,,,,,,,,,,,,,, -50900,1.0724689,3.541012,,,,,,,,,,,,,, -51000,1.0437623,3.9654546,,,,,,,,,,,,,, -51100,1.0797939,3.5515893,,,,,,,,,,,,,, -51200,1.0806309,3.5083258,,,,,,,,,,,,,, -51300,1.0369309,3.5390153,,,,,,,,,,,,,, -51400,1.022455,3.808971,,,,,,,,,,,,,, -51500,1.0230794,3.6500478,,,,,,,,,,,,,, -51508,,,0.71498042345047,1.3889296054840088,0.6620399951934814,1.6310160160064695,50000.0,0.5435000061988831,2.247091293334961,10000.0,23564.28750729561,25905.883530139923,23564.28750729561,2337.176256418228,1.7330005168914795,0.0 -51600,1.0011572,4.76431,,,,,,,,,,,,,, -51700,1.1200709,3.5502543,,,,,,,,,,,,,, -51800,0.938688,4.2277956,,,,,,,,,,,,,, -51900,0.9942484,4.0618973,,,,,,,,,,,,,, -52000,1.0784318,3.8432899,,,,,,,,,,,,,, -52100,1.0556616,4.3604784,,,,,,,,,,,,,, -52200,1.1305742,3.6424117,,,,,,,,,,,,,, -52300,0.9990394,3.7885218,,,,,,,,,,,,,, -52400,0.9718196,3.6284626,,,,,,,,,,,,,, -52432,,,0.7337304353713989,1.2980376482009888,0.6647999882698059,1.5864545106887815,50000.0,0.5354000329971313,2.2272799015045166,10000.0,23984.374716758728,26366.565184354786,23984.374716758728,2377.6872231960297,1.768537521362305,0.0 -52500,1.0604714,3.5167034,,,,,,,,,,,,,, -52600,0.95359206,3.9392657,,,,,,,,,,,,,, -52700,1.0095295,3.5248404,,,,,,,,,,,,,, -52800,1.0805851,3.5829954,,,,,,,,,,,,,, -52900,0.96998686,4.4652348,,,,,,,,,,,,,, -53000,0.94714385,5.0567536,,,,,,,,,,,,,, -53100,1.2140332,3.4915607,,,,,,,,,,,,,, -53200,1.0666941,3.588875,,,,,,,,,,,,,, -53300,1.174983,5.206907,,,,,,,,,,,,,, -53354,,,0.7135351300239563,1.3821560144424438,0.6652799844741821,1.598161697387695,50000.0,0.532200038433075,2.2386462688446045,10000.0,24404.3652215004,26826.81566309929,24404.3652215004,2417.860870361328,1.8072423934936523,0.0 -53400,1.0511862,3.6557095,,,,,,,,,,,,,, -53500,0.99100846,3.745009,,,,,,,,,,,,,, -53600,1.1167334,3.5174448,,,,,,,,,,,,,, -53700,0.9680105,3.9884093,,,,,,,,,,,,,, -53800,0.96863246,4.026438,,,,,,,,,,,,,, -53900,0.9647689,4.6692095,,,,,,,,,,,,,, -54000,1.041081,4.626522,,,,,,,,,,,,,, -54100,0.9734576,4.051561,,,,,,,,,,,,,, -54200,1.1205972,3.4916809,,,,,,,,,,,,,, -54275,,,0.7189062237739563,1.361011028289795,0.6606799960136414,1.6143720149993896,50000.0,0.5452000498771667,2.222517967224121,10000.0,24824.307297229767,27290.06626176834,24824.307297229767,2461.081460237503,1.847649097442627,0.0 -54300,0.986284,4.3044577,,,,,,,,,,,,,, -54400,1.146738,3.4986823,,,,,,,,,,,,,, -54500,1.0899014,3.5083456,,,,,,,,,,,,,, -54600,0.94870645,3.9361947,,,,,,,,,,,,,, -54700,1.1097426,3.5898595,,,,,,,,,,,,,, -54800,0.9599818,4.6170406,,,,,,,,,,,,,, -54900,1.3014994,3.4864979,,,,,,,,,,,,,, -55000,1.2022698,3.5464697,,,,,,,,,,,,,, -55100,1.0481209,3.4197583,,,,,,,,,,,,,, -55196,,,0.7296484112739563,1.3426628112792969,0.664139986038208,1.6135691404342651,50000.0,0.5435000061988831,2.220703601837158,10000.0,25244.55240249633,27749.66876244545,25244.55240249633,2500.354782104492,1.8835597038269043,0.0 -55200,1.0288831,3.593653,,,,,,,,,,,,,, -55300,1.1752623,3.5195155,,,,,,,,,,,,,, -55400,1.0673895,3.4390554,,,,,,,,,,,,,, -55500,0.98857546,3.4819162,,,,,,,,,,,,,, -55600,1.0747691,3.5769088,,,,,,,,,,,,,, -55700,1.0777652,3.5368068,,,,,,,,,,,,,, -55800,1.099769,3.4825778,,,,,,,,,,,,,, -55900,1.0775437,3.494661,,,,,,,,,,,,,, -56000,1.0266947,3.6585822,,,,,,,,,,,,,, -56100,1.1257468,3.8628056,,,,,,,,,,,,,, -56120,,,0.7380273342132568,1.2842124700546265,0.6672599911689758,1.5820810794830322,50000.0,0.5502000451087952,2.189172744750977,10000.0,25664.7479467392,28215.84439206124,25664.7479467392,2546.253002643585,1.9170572757720947,0.0 -56200,1.0037618,3.4461217,,,,,,,,,,,,,, -56300,1.0864575,3.505896,,,,,,,,,,,,,, -56400,1.2144831,3.559648,,,,,,,,,,,,,, -56500,1.0801703,3.5820148,,,,,,,,,,,,,, -56600,1.0771943,5.2843547,,,,,,,,,,,,,, -56700,1.2095736,3.4536428,,,,,,,,,,,,,, -56800,1.0226383,4.811205,,,,,,,,,,,,,, -56900,0.9531948,4.8642306,,,,,,,,,,,,,, -57000,1.016589,4.832343,,,,,,,,,,,,,, -57042,,,0.7276562452316284,1.3468148708343506,0.6674000024795532,1.603977918624878,50000.0,0.5464000105857849,2.211744546890259,10000.0,26085.00351262093,28679.673223495483,26085.00351262093,2589.74267745018,1.952627658843994,0.0 -57100,1.0300995,4.88324,,,,,,,,,,,,,, -57200,0.9310337,4.739979,,,,,,,,,,,,,, -57300,1.0718856,3.983458,,,,,,,,,,,,,, -57400,1.2007436,3.4458637,,,,,,,,,,,,,, -57500,1.1328282,3.6426816,,,,,,,,,,,,,, -57600,0.98825914,4.223711,,,,,,,,,,,,,, -57700,1.0717454,3.5166144,,,,,,,,,,,,,, -57800,1.141919,3.9292033,,,,,,,,,,,,,, -57900,1.1224256,3.5763078,,,,,,,,,,,,,, -57963,,,0.7269140481948853,1.328726887702942,0.667419970035553,1.5865769386291504,50000.0,0.5430999994277954,2.196040153503418,10000.0,26505.091208696365,29141.39177846909,26505.091208696365,2631.289387464524,1.9895341396331787,0.0 -58000,1.117675,3.4815717,,,,,,,,,,,,,, -58100,1.0538831,4.5112867,,,,,,,,,,,,,, -58200,1.0523468,4.665025,,,,,,,,,,,,,, -58300,1.0775425,3.4687424,,,,,,,,,,,,,, -58400,1.1542857,3.4582925,,,,,,,,,,,,,, -58500,0.94104606,4.272415,,,,,,,,,,,,,, -58600,1.0764782,3.7045255,,,,,,,,,,,,,, -58700,1.1605054,3.470218,,,,,,,,,,,,,, -58800,1.091149,3.4924076,,,,,,,,,,,,,, -58883,,,0.7373241782188416,1.2906686067581177,0.6669600009918213,1.6132616996765137,50000.0,0.5457000136375427,2.201898097991944,10000.0,26925.22168302536,29605.44416475296,26925.22168302536,2675.127780675888,2.025451183319092,0.0 -58900,1.0436375,4.4653115,,,,,,,,,,,,,, -59000,1.0873517,3.4505696,,,,,,,,,,,,,, -59100,1.0742024,3.9404852,,,,,,,,,,,,,, -59200,1.2392129,3.470141,,,,,,,,,,,,,, -59300,1.0420867,5.1805425,,,,,,,,,,,,,, -59400,1.0041856,5.323398,,,,,,,,,,,,,, -59500,1.0195886,5.2628355,,,,,,,,,,,,,, -59600,1.142688,3.3708014,,,,,,,,,,,,,, -59700,1.008062,5.0249867,,,,,,,,,,,,,, -59800,1.0432053,3.7101603,,,,,,,,,,,,,, -59803,,,0.7282031178474426,1.2871874570846558,0.6700199842453003,1.5382176637649536,50000.0,0.5478000044822693,2.153144598007202,10000.0,27345.44739818573,30069.56980085373,27345.44739818573,2718.944101333618,2.061138868331909,0.0 -59900,1.0376517,3.434074,,,,,,,,,,,,,, -60000,1.0680431,4.134341,,,,,,,,,,,,,, -60100,1.148483,3.5125642,,,,,,,,,,,,,, -60200,1.1231016,3.4937372,,,,,,,,,,,,,, -60300,1.081999,4.4497643,,,,,,,,,,,,,, -60400,1.1236137,3.3844914,,,,,,,,,,,,,, -60500,1.1412342,3.523386,,,,,,,,,,,,,, -60600,1.1407835,3.7449722,,,,,,,,,,,,,, -60700,1.0805664,3.5305433,,,,,,,,,,,,,, -60724,,,0.733593761920929,1.3361026048660278,0.6747199892997742,1.584994196891785,50000.0,0.5454000234603882,2.2170250415802,10000.0,27765.79726195336,30533.50876736641,27765.79726195336,2762.44873046875,2.097840070724488,0.0 -60800,1.1019866,3.530758,,,,,,,,,,,,,, -60900,1.0351188,4.180088,,,,,,,,,,,,,, -61000,1.2524092,4.722123,,,,,,,,,,,,,, -61100,1.0975993,3.583278,,,,,,,,,,,,,, -61200,1.0567615,4.4782405,,,,,,,,,,,,,, -61300,1.0677354,3.4942155,,,,,,,,,,,,,, -61400,1.2319041,3.4725792,,,,,,,,,,,,,, -61500,1.1830761,3.4187894,,,,,,,,,,,,,, -61600,1.1092483,3.6595385,,,,,,,,,,,,,, -61645,,,0.7450000047683716,1.2544151544570925,0.6762799620628357,1.5565282106399536,50000.0,0.5508000254631042,2.175717830657959,10000.0,28185.98828697205,30996.8189907074,28185.98828697205,2805.484763622284,2.13291072845459,0.0 -61700,1.0669011,4.454755,,,,,,,,,,,,,, -61800,1.0962467,3.4849901,,,,,,,,,,,,,, -61900,1.1900147,3.4103613,,,,,,,,,,,,,, -62000,1.1149889,3.4258628,,,,,,,,,,,,,, -62100,1.0714196,3.6761112,,,,,,,,,,,,,, -62200,1.169013,3.4738543,,,,,,,,,,,,,, -62300,1.0762534,5.1893697,,,,,,,,,,,,,, -62400,1.0553868,3.5542943,,,,,,,,,,,,,, -62500,1.0760103,3.912464,,,,,,,,,,,,,, -62563,,,0.7283593416213989,1.3276382684707642,0.6751199960708618,1.5634821653366089,50000.0,0.5521000027656555,2.1780121326446533,10000.0,28606.290912389755,31461.618797063828,28606.290912389755,2849.9011821746826,2.1665310859680176,0.0 -62600,1.0363225,3.5302894,,,,,,,,,,,,,, -62700,1.0072653,4.493086,,,,,,,,,,,,,, -62800,1.083914,4.9748173,,,,,,,,,,,,,, -62900,1.1794298,3.4464636,,,,,,,,,,,,,, -63000,1.2193619,3.433454,,,,,,,,,,,,,, -63100,1.1572795,3.4665422,,,,,,,,,,,,,, -63200,1.0102812,4.689685,,,,,,,,,,,,,, -63300,0.9929695,3.762244,,,,,,,,,,,,,, -63400,1.0347672,5.149374,,,,,,,,,,,,,, -63483,,,0.7360742092132568,1.2766711711883545,0.6779199838638306,1.5261484384536743,50000.0,0.557200014591217,2.149172306060791,10000.0,29026.3101747036,31926.710065841675,29026.3101747036,2894.8881330490112,2.20430326461792,0.0 -63500,1.0489838,4.1320076,,,,,,,,,,,,,, -63600,0.9907781,4.738025,,,,,,,,,,,,,, -63700,1.1991727,3.4761105,,,,,,,,,,,,,, -63800,1.006439,4.71028,,,,,,,,,,,,,, -63900,1.147372,4.8051367,,,,,,,,,,,,,, -64000,1.1538723,3.4260526,,,,,,,,,,,,,, -64100,1.1644214,3.655263,,,,,,,,,,,,,, -64200,0.9959658,5.1313505,,,,,,,,,,,,,, -64300,1.1519089,3.382805,,,,,,,,,,,,,, -64400,0.99062246,5.0268683,,,,,,,,,,,,,, -64404,,,0.7432031035423279,1.2548317909240725,0.6798799633979797,1.5307791233062744,50000.0,0.5481000542640686,2.1536831855773926,10000.0,29446.25080990792,32389.77768635749,29446.25080990792,2937.920587062836,2.25028133392334,0.0 -64500,1.1589196,3.456204,,,,,,,,,,,,,, -64600,1.0801237,3.7603824,,,,,,,,,,,,,, -64700,1.2068186,3.445393,,,,,,,,,,,,,, -64800,1.011296,3.7229524,,,,,,,,,,,,,, -64900,1.0556531,4.1047454,,,,,,,,,,,,,, -65000,1.0784738,4.8907948,,,,,,,,,,,,,, -65100,0.93814844,4.1154094,,,,,,,,,,,,,, -65200,1.168841,3.5209074,,,,,,,,,,,,,, -65300,1.1371456,5.1742334,,,,,,,,,,,,,, -65324,,,0.7388671636581421,1.2673910856246948,0.6782000064849854,1.518659234046936,50000.0,0.5658000111579895,2.10740065574646,10000.0,29866.379588842392,32848.34146499634,29866.379588842392,2976.272282600403,2.2853519916534424,0.0 -65400,1.055242,3.4606423,,,,,,,,,,,,,, -65500,1.0991367,3.550088,,,,,,,,,,,,,, -65600,1.1185421,3.6275775,,,,,,,,,,,,,, -65700,1.2762371,3.4273424,,,,,,,,,,,,,, -65800,1.1611301,3.3988845,,,,,,,,,,,,,, -65900,1.1046911,3.862645,,,,,,,,,,,,,, -66000,1.1460862,3.3675466,,,,,,,,,,,,,, -66100,0.9897314,4.662354,,,,,,,,,,,,,, -66200,1.0474976,3.6323116,,,,,,,,,,,,,, -66244,,,0.7400780916213989,1.258586883544922,0.6813799738883972,1.5165966749191284,50000.0,0.5547000169754028,2.1452724933624268,10000.0,30286.45598578453,33312.4827747345,30286.45598578453,3020.252357006073,2.32259488105774,0.0 -66300,1.1308693,3.4423013,,,,,,,,,,,,,, -66400,1.0471874,3.6885662,,,,,,,,,,,,,, -66500,1.1388977,3.3402848,,,,,,,,,,,,,, -66600,1.1620698,3.4938664,,,,,,,,,,,,,, -66700,1.1026592,5.120944,,,,,,,,,,,,,, -66800,1.0797397,3.817369,,,,,,,,,,,,,, -66900,1.194445,4.5075216,,,,,,,,,,,,,, -67000,1.0549593,4.9976454,,,,,,,,,,,,,, -67100,1.1709971,3.4605355,,,,,,,,,,,,,, -67165,,,0.7415820360183716,1.2981585264205933,0.6811999678611755,1.5677835941314695,50000.0,0.5561000108718872,2.169353723526001,10000.0,30706.626230478287,33773.69924163818,30706.626230478287,3061.212740421295,2.3604607582092285,0.0 -67200,1.0751287,4.4183173,,,,,,,,,,,,,, -67300,1.015051,4.4428463,,,,,,,,,,,,,, -67400,1.3169304,3.720451,,,,,,,,,,,,,, -67500,1.1346949,3.48201,,,,,,,,,,,,,, -67600,1.108985,3.5333996,,,,,,,,,,,,,, -67700,1.1365504,3.4201992,,,,,,,,,,,,,, -67800,1.243479,3.435071,,,,,,,,,,,,,, -67900,1.2030283,5.195702,,,,,,,,,,,,,, -68000,1.0560929,3.9231665,,,,,,,,,,,,,, -68086,,,0.7633007764816284,1.1590244770050049,0.6800999641418457,1.5066139698028564,50000.0,0.5576000213623047,2.1319308280944824,10000.0,31126.832008838654,34235.588967084885,31126.832008838654,3102.8121032714844,2.3978052139282227,0.0 -68100,1.087637,3.870198,,,,,,,,,,,,,, -68200,1.113414,3.7716315,,,,,,,,,,,,,, -68300,1.1679984,3.6186652,,,,,,,,,,,,,, -68400,1.0921645,4.1602635,,,,,,,,,,,,,, -68500,1.2350097,3.4193819,,,,,,,,,,,,,, -68600,1.2080069,3.4359708,,,,,,,,,,,,,, -68700,1.1224183,3.4026473,,,,,,,,,,,,,, -68800,1.1076849,5.086182,,,,,,,,,,,,,, -68900,1.3951168,3.544807,,,,,,,,,,,,,, -69000,1.144134,3.3792427,,,,,,,,,,,,,, -69004,,,0.7406054735183716,1.2410019636154177,0.684719979763031,1.4892326593399048,50000.0,0.556600034236908,2.1219170093536377,10000.0,31546.80517745018,34698.98634314537,31546.80517745018,3146.148453235626,2.4373722076416016,0.0 -69100,1.0830148,3.3754597,,,,,,,,,,,,,, -69200,1.1502291,3.3500655,,,,,,,,,,,,,, -69300,1.2157207,3.4067276,,,,,,,,,,,,,, -69400,1.0724484,3.3833039,,,,,,,,,,,,,, -69500,1.1939348,3.5060554,,,,,,,,,,,,,, -69600,1.2231683,3.3764615,,,,,,,,,,,,,, -69700,1.1958759,3.4658523,,,,,,,,,,,,,, -69800,1.1704806,4.201598,,,,,,,,,,,,,, -69900,1.1944596,3.4616294,,,,,,,,,,,,,, -69925,,,0.7484570145606995,1.2337993383407593,0.6854599714279175,1.5051804780960083,50000.0,0.5660000443458557,2.104238510131836,10000.0,31966.791477441788,35162.78013944626,31966.791477441788,3189.873576402664,2.472073554992676,0.0 -70000,1.182789,3.3719542,,,,,,,,,,,,,, -70100,1.1324967,3.4692602,,,,,,,,,,,,,, -70200,1.1685522,3.7175539,,,,,,,,,,,,,, -70300,1.0897851,4.221289,,,,,,,,,,,,,, -70400,1.1395738,3.5144684,,,,,,,,,,,,,, -70500,1.2437819,3.4120295,,,,,,,,,,,,,, -70600,1.2309262,3.4590442,,,,,,,,,,,,,, -70700,1.1588491,3.4416003,,,,,,,,,,,,,, -70800,1.1967076,3.3822503,,,,,,,,,,,,,, -70847,,,0.7518359422683716,1.2066234350204468,0.6805199980735779,1.5176000595092771,50000.0,0.5552999973297119,2.1310625076293945,10000.0,32386.71866321564,35629.093203783035,32386.71866321564,3236.175390481949,2.508272171020508,0.0 -70900,1.0559927,4.3884344,,,,,,,,,,,,,, -71000,1.2318671,5.1584992,,,,,,,,,,,,,, -71100,1.2095765,3.4342034,,,,,,,,,,,,,, -71200,1.0840691,3.5191274,,,,,,,,,,,,,, -71300,1.1068265,3.5411348,,,,,,,,,,,,,, -71400,1.1314675,4.1811695,,,,,,,,,,,,,, -71500,1.1292331,3.588643,,,,,,,,,,,,,, -71600,1.0615861,4.2032475,,,,,,,,,,,,,, -71700,1.1837091,3.316343,,,,,,,,,,,,,, -71770,,,0.747363269329071,1.2343018054962158,0.6899799704551697,1.4829862117767334,50000.0,0.5648000240325928,2.08892560005188,10000.0,32807.06958389282,36087.70830178261,32807.06958389282,3274.3536903858185,2.5455551147460938,0.0 -71800,1.0256971,4.6063643,,,,,,,,,,,,,, -71900,1.1894104,3.3907988,,,,,,,,,,,,,, -72000,1.1718147,3.6959865,,,,,,,,,,,,,, -72100,1.1473712,4.9920607,,,,,,,,,,,,,, -72200,1.1922623,3.4583292,,,,,,,,,,,,,, -72300,1.1390253,3.8986797,,,,,,,,,,,,,, -72400,1.0050192,4.5639696,,,,,,,,,,,,,, -72500,1.1854622,3.4164922,,,,,,,,,,,,,, -72600,1.0717734,4.046699,,,,,,,,,,,,,, -72687,,,0.7515820264816284,1.2129734754562378,0.6888999938964844,1.4854317903518677,50000.0,0.5640000104904175,2.0902061462402344,10000.0,33227.219742536545,36553.0427236557,33227.219742536545,3319.4512207508087,2.5845770835876465,0.0 -72700,1.1248951,3.531977,,,,,,,,,,,,,, -72800,1.0591513,4.7883167,,,,,,,,,,,,,, -72900,1.2192272,5.114199,,,,,,,,,,,,,, -73000,1.270558,4.3773713,,,,,,,,,,,,,, -73100,1.1494474,3.3744285,,,,,,,,,,,,,, -73200,1.1236967,5.120618,,,,,,,,,,,,,, -73300,1.2551215,3.5360014,,,,,,,,,,,,,, -73400,1.0881604,3.5457149,,,,,,,,,,,,,, -73500,1.2824395,3.5231225,,,,,,,,,,,,,, -73600,1.1551002,3.428598,,,,,,,,,,,,,, -73609,,,0.7533984184265137,1.191983938217163,0.6872400045394897,1.485839605331421,50000.0,0.5670000314712524,2.104422807693481,10000.0,33647.15845131874,37015.7394669056,33647.15845131874,3362.12451672554,2.621159076690674,0.0 -73700,1.1890926,3.5964427,,,,,,,,,,,,,, -73800,1.2213073,3.4200897,,,,,,,,,,,,,, -73900,1.1004206,4.996434,,,,,,,,,,,,,, -74000,1.1772451,3.356597,,,,,,,,,,,,,, -74100,1.2353624,5.146737,,,,,,,,,,,,,, -74200,1.1664618,3.4502492,,,,,,,,,,,,,, -74300,1.1888276,3.3972726,,,,,,,,,,,,,, -74400,1.1252105,4.357358,,,,,,,,,,,,,, -74500,1.2339251,3.4077396,,,,,,,,,,,,,, -74530,,,0.7507421970367432,1.195424199104309,0.6930800080299377,1.4571106433868408,50000.0,0.5658000111579895,2.069862604141236,10000.0,34067.44424152374,37478.30951237679,34067.44424152374,3404.320331096649,2.6617767810821533,0.0 -74600,1.0642314,3.7575884,,,,,,,,,,,,,, -74700,1.2767531,5.083767,,,,,,,,,,,,,, -74800,1.2093126,3.3638225,,,,,,,,,,,,,, -74900,1.1498102,4.5065393,,,,,,,,,,,,,, -75000,1.148078,4.93232,,,,,,,,,,,,,, -75100,1.2712588,3.466917,,,,,,,,,,,,,, -75200,1.1390431,3.3239644,,,,,,,,,,,,,, -75300,1.2676065,3.390147,,,,,,,,,,,,,, -75400,1.3776723,3.4366329,,,,,,,,,,,,,, -75453,,,0.7525194883346558,1.1894110441207886,0.6906599998474121,1.454953908920288,50000.0,0.5706000328063965,2.06187105178833,10000.0,34487.681963682175,37940.56576251984,34487.681963682175,3446.2519228458405,2.700667142868042,0.0 -75500,1.2071878,3.4956527,,,,,,,,,,,,,, -75600,1.1656649,3.8422694,,,,,,,,,,,,,, -75700,1.0427567,4.2497134,,,,,,,,,,,,,, -75800,1.0737488,3.567879,,,,,,,,,,,,,, -75900,1.1945487,3.3349068,,,,,,,,,,,,,, -76000,1.1102943,3.5586305,,,,,,,,,,,,,, -76100,1.1225824,3.2679117,,,,,,,,,,,,,, -76200,1.1754726,3.3979516,,,,,,,,,,,,,, -76300,1.1844766,3.4632561,,,,,,,,,,,,,, -76374,,,0.7610741853713989,1.160759210586548,0.6957799792289734,1.4432342052459717,50000.0,0.5688000321388245,2.0531435012817383,10000.0,34908.01672363281,38401.809537410736,34908.01672363281,3487.078436851501,2.736494779586792,0.0 -76400,1.1301446,4.81832,,,,,,,,,,,,,, -76500,1.0571514,4.4179316,,,,,,,,,,,,,, -76600,1.3487948,5.093631,,,,,,,,,,,,,, -76700,1.1427722,3.447279,,,,,,,,,,,,,, -76800,1.0813081,4.3566713,,,,,,,,,,,,,, -76900,1.1186565,4.8241096,,,,,,,,,,,,,, -77000,1.2316861,3.355301,,,,,,,,,,,,,, -77100,1.2078876,3.322207,,,,,,,,,,,,,, -77200,1.1188688,3.3597167,,,,,,,,,,,,,, -77298,,,0.7732812166213989,1.1148550510406494,0.6937199831008911,1.447901725769043,50000.0,0.5685999989509583,2.070373773574829,10000.0,35327.95745301247,38863.893261671066,35327.95745301247,3529.1288471221924,2.78158974647522,0.0 -77300,1.2852606,3.3498902,,,,,,,,,,,,,, -77400,1.1422933,3.5231686,,,,,,,,,,,,,, -77500,1.2440114,3.3231657,,,,,,,,,,,,,, -77600,1.1600401,3.6276798,,,,,,,,,,,,,, -77700,1.2131672,3.4271607,,,,,,,,,,,,,, -77800,1.1430087,5.0251446,,,,,,,,,,,,,, -77900,1.1114753,3.6658554,,,,,,,,,,,,,, -78000,1.2905277,3.3447528,,,,,,,,,,,,,, -78100,1.1947885,3.368299,,,,,,,,,,,,,, -78200,1.2215606,3.4051054,,,,,,,,,,,,,, -78220,,,0.7600781321525574,1.1587432622909546,0.6990599632263184,1.4184653759002686,50000.0,0.5758000016212463,2.019524097442627,10000.0,35748.12657356262,39325.401894807816,35748.12657356262,3570.3713760375977,2.831173181533813,0.0 -78300,1.1636293,4.4440985,,,,,,,,,,,,,, -78400,1.2712642,3.351276,,,,,,,,,,,,,, -78500,1.2388442,3.425613,,,,,,,,,,,,,, -78600,1.130588,3.791162,,,,,,,,,,,,,, -78700,1.2853515,3.371876,,,,,,,,,,,,,, -78800,1.2550284,4.8959656,,,,,,,,,,,,,, -78900,1.2690436,5.094208,,,,,,,,,,,,,, -79000,1.2050928,4.3249154,,,,,,,,,,,,,, -79100,1.2541727,3.379599,,,,,,,,,,,,,, -79142,,,0.7625195384025574,1.148511528968811,0.6945199966430664,1.4347800016403198,50000.0,0.5730000138282776,2.055783987045288,10000.0,36168.48544359207,39787.786851882935,36168.48544359207,3612.3097426891327,2.8709871768951416,0.0 -79200,1.1647234,5.023565,,,,,,,,,,,,,, -79300,1.2575608,3.4111924,,,,,,,,,,,,,, -79400,1.2163938,3.3160682,,,,,,,,,,,,,, -79500,1.1849165,3.312768,,,,,,,,,,,,,, -79600,1.1594696,4.548133,,,,,,,,,,,,,, -79700,1.3242996,4.999246,,,,,,,,,,,,,, -79800,1.2439581,3.4339843,,,,,,,,,,,,,, -79900,1.2184353,3.3606896,,,,,,,,,,,,,, -80000,1.1590344,3.2556298,,,,,,,,,,,,,, -80065,,,0.7742968797683716,1.1092896461486816,0.697380006313324,1.4453190565109253,50000.0,0.5708000063896179,2.067552328109741,10000.0,36588.77942371368,40251.33292555809,36588.77942371368,3655.476496696472,2.9082815647125244,0.0 -80100,1.2795421,3.4969573,,,,,,,,,,,,,, -80200,1.1339589,3.957769,,,,,,,,,,,,,, -80300,1.3590403,3.4200652,,,,,,,,,,,,,, -80400,1.1296954,4.4105353,,,,,,,,,,,,,, -80500,1.3792322,4.8792725,,,,,,,,,,,,,, -80600,1.3123902,3.4240398,,,,,,,,,,,,,, -80700,1.2834727,3.4260006,,,,,,,,,,,,,, -80800,1.2401575,4.261189,,,,,,,,,,,,,, -80900,1.1728944,3.3034718,,,,,,,,,,,,,, -80988,,,0.7635741829872131,1.1445281505584717,0.7017799615859985,1.4090139865875244,50000.0,0.5790000557899475,2.017435073852539,10000.0,37008.99789881706,40713.225462675095,37008.99789881706,3697.0613508224487,2.9498679637908936,0.0 -81000,1.1706538,3.6336856,,,,,,,,,,,,,, -81100,1.2499436,3.2182372,,,,,,,,,,,,,, -81200,1.260956,4.886224,,,,,,,,,,,,,, -81300,1.2994742,3.4370942,,,,,,,,,,,,,, -81400,1.2767652,3.338989,,,,,,,,,,,,,, -81500,1.2071171,3.509609,,,,,,,,,,,,,, -81600,1.2225168,4.6820517,,,,,,,,,,,,,, -81700,1.3076503,5.038406,,,,,,,,,,,,,, -81800,1.2254755,5.004818,,,,,,,,,,,,,, -81900,1.2076523,3.3764098,,,,,,,,,,,,,, -81909,,,0.7632030844688416,1.150298237800598,0.6981399655342102,1.4378381967544556,50000.0,0.5771000385284424,2.034175157546997,10000.0,37429.169956445694,41177.268661022186,37429.169956445694,3740.842449903488,2.9926469326019287,0.0 -82000,1.2112942,4.637289,,,,,,,,,,,,,, -82100,1.2173964,3.3918247,,,,,,,,,,,,,, -82200,1.3358694,5.028734,,,,,,,,,,,,,, -82300,1.1563518,3.401727,,,,,,,,,,,,,, -82400,1.1980363,3.2955518,,,,,,,,,,,,,, -82500,1.159254,3.9793348,,,,,,,,,,,,,, -82600,1.2113848,4.7429333,,,,,,,,,,,,,, -82700,1.2610599,3.3396323,,,,,,,,,,,,,, -82800,1.323984,3.4088583,,,,,,,,,,,,,, -82831,,,0.7751367092132568,1.1169837713241575,0.7000600099563599,1.4290475845336914,50000.0,0.5746000409126282,2.0344953536987305,10000.0,37849.39163827896,41643.815761089325,37849.39163827896,3787.0803208351135,3.0316107273101807,0.0 -82900,1.2330446,3.31761,,,,,,,,,,,,,, -83000,1.1465437,4.187418,,,,,,,,,,,,,, -83100,1.2176706,3.3106804,,,,,,,,,,,,,, -83200,1.3407745,3.4729342,,,,,,,,,,,,,, -83300,1.2528981,4.8210278,,,,,,,,,,,,,, -83400,1.338161,5.004161,,,,,,,,,,,,,, -83500,1.2149427,3.308109,,,,,,,,,,,,,, -83600,1.1704334,3.9430919,,,,,,,,,,,,,, -83700,1.2389984,4.6486998,,,,,,,,,,,,,, -83754,,,0.7632226347923279,1.1648662090301514,0.6999799609184265,1.4325613975524902,50000.0,0.5769000053405762,2.039259433746338,10000.0,38269.435145139694,42108.586940288544,38269.435145139694,3831.722299337387,3.0690486431121826,0.0 -83800,1.1510509,3.5227408,,,,,,,,,,,,,, -83900,1.3754411,3.353579,,,,,,,,,,,,,, -84000,1.1490682,3.5843806,,,,,,,,,,,,,, -84100,1.310898,3.325478,,,,,,,,,,,,,, -84200,1.2466928,3.3733406,,,,,,,,,,,,,, -84300,1.2210674,4.1751842,,,,,,,,,,,,,, -84400,1.3720924,3.325886,,,,,,,,,,,,,, -84500,1.2615938,3.4427788,,,,,,,,,,,,,, -84600,1.1970608,3.5437846,,,,,,,,,,,,,, -84676,,,0.7675585746765137,1.1792206764221191,0.7015199661254883,1.4529019594192505,50000.0,0.5788000226020813,2.06295108795166,10000.0,38689.39722657204,42567.65654802322,38689.39722657204,3870.7454376220703,3.106007099151612,0.0 -84700,1.2928979,3.3408675,,,,,,,,,,,,,, -84800,1.2335504,3.313293,,,,,,,,,,,,,, -84900,1.3987528,5.005512,,,,,,,,,,,,,, -85000,1.2270538,3.6091797,,,,,,,,,,,,,, -85100,1.2723609,3.3761268,,,,,,,,,,,,,, -85200,1.2931285,3.3055398,,,,,,,,,,,,,, -85300,1.4305923,4.99183,,,,,,,,,,,,,, -85400,1.2313813,3.3033478,,,,,,,,,,,,,, -85500,1.2812179,3.4521985,,,,,,,,,,,,,, -85598,,,0.77259761095047,1.1098675727844238,0.7027599811553955,1.4117202758789062,50000.0,0.5811000466346741,2.0271904468536377,10000.0,39109.64110136032,43032.52892065048,39109.64110136032,3915.285789489746,3.146677255630493,0.0 -85600,1.3111486,3.3875299,,,,,,,,,,,,,, -85700,1.3249608,3.3576872,,,,,,,,,,,,,, -85800,1.1128892,3.5570717,,,,,,,,,,,,,, -85900,1.2293264,3.394774,,,,,,,,,,,,,, -86000,1.3612602,3.4034677,,,,,,,,,,,,,, -86100,1.0984526,3.5343738,,,,,,,,,,,,,, -86200,1.3262918,4.9281387,,,,,,,,,,,,,, -86300,1.3051203,4.641877,,,,,,,,,,,,,, -86400,1.5705873,4.8899527,,,,,,,,,,,,,, -86500,1.2368997,3.2090094,,,,,,,,,,,,,, -86519,,,0.7715820074081421,1.1164155006408691,0.7065399885177612,1.3979099988937378,50000.0,0.5836000442504883,2.0163145065307617,10000.0,39529.84063959122,43499.65197634697,39529.84063959122,3962.117981433869,3.1898202896118164,0.0 -86600,1.2504507,4.0815964,,,,,,,,,,,,,, -86700,1.2277749,3.3186219,,,,,,,,,,,,,, -86800,1.218995,3.7108886,,,,,,,,,,,,,, -86900,1.1606557,3.9409883,,,,,,,,,,,,,, -87000,1.1555876,3.4401116,,,,,,,,,,,,,, -87100,1.229491,3.3226037,,,,,,,,,,,,,, -87200,1.2819136,3.3162332,,,,,,,,,,,,,, -87300,1.1721423,3.4942122,,,,,,,,,,,,,, -87400,1.3284653,3.311325,,,,,,,,,,,,,, -87441,,,0.769726574420929,1.133750319480896,0.7074599862098694,1.397431492805481,50000.0,0.5821000337600708,2.0062777996063232,10000.0,39949.914578437805,43964.02792263031,39949.914578437805,4006.3306062221527,3.2322468757629395,0.0 -87500,1.3275462,3.3424232,,,,,,,,,,,,,, -87600,1.3012401,4.2293334,,,,,,,,,,,,,, -87700,1.2308308,4.213423,,,,,,,,,,,,,, -87800,1.1464669,4.4142413,,,,,,,,,,,,,, -87900,1.3882312,3.3245893,,,,,,,,,,,,,, -88000,1.2074593,4.41085,,,,,,,,,,,,,, -88100,1.30531,3.306695,,,,,,,,,,,,,, -88200,1.2531028,4.7445726,,,,,,,,,,,,,, -88300,1.2796189,3.2495966,,,,,,,,,,,,,, -88363,,,0.7766796946525574,1.0799094438552856,0.7064200043678284,1.3791130781173706,50000.0,0.5845000147819519,1.9934600591659544,10000.0,40369.85162806511,44424.50291514397,40369.85162806511,4046.780064105988,3.2729523181915283,0.0 -88400,1.403959,4.844777,,,,,,,,,,,,,, -88500,1.1831833,4.0867043,,,,,,,,,,,,,, -88600,1.2921103,3.3084798,,,,,,,,,,,,,, -88700,1.278558,3.3413427,,,,,,,,,,,,,, -88800,1.3508666,3.4464972,,,,,,,,,,,,,, -88900,1.2170582,3.5822487,,,,,,,,,,,,,, -89000,1.2283869,3.2593973,,,,,,,,,,,,,, -89100,1.3364553,3.375767,,,,,,,,,,,,,, -89200,1.2009872,3.5308971,,,,,,,,,,,,,, -89279,,,0.7904296517372131,1.0638716220855713,0.7069599628448486,1.4218652248382568,50000.0,0.5841000080108643,2.0293197631835938,10000.0,40790.13759255409,44888.84012579918,40790.13759255409,4090.746128797531,3.310702085494995,0.0 -89300,1.2056613,3.5870016,,,,,,,,,,,,,, -89400,1.4108475,4.603187,,,,,,,,,,,,,, -89500,1.1527216,4.1579576,,,,,,,,,,,,,, -89600,1.2877716,3.3573375,,,,,,,,,,,,,, -89700,1.3413178,3.5933313,,,,,,,,,,,,,, -89800,1.4135104,4.860231,,,,,,,,,,,,,, -89900,1.2567211,3.2608907,,,,,,,,,,,,,, -90000,1.3298748,3.2934961,,,,,,,,,,,,,, -90100,1.4290365,4.961713,,,,,,,,,,,,,, -90199,,,0.7699609398841858,1.117898941040039,0.7085599899291992,1.3838273286819458,50000.0,0.5837000012397766,1.990799069404602,10000.0,41210.48593831062,45350.90107703209,41210.48593831062,4132.355852603912,3.3663439750671387,0.0 -90200,1.2580739,3.8776605,,,,,,,,,,,,,, -90300,1.3264643,3.79361,,,,,,,,,,,,,, -90400,1.2577592,3.2977257,,,,,,,,,,,,,, -90500,1.2664841,3.274405,,,,,,,,,,,,,, -90600,1.1951998,4.2900743,,,,,,,,,,,,,, -90700,1.3161576,3.2502358,,,,,,,,,,,,,, -90800,1.3359709,3.6760495,,,,,,,,,,,,,, -90900,1.2344797,3.44779,,,,,,,,,,,,,, -91000,1.304355,3.9153733,,,,,,,,,,,,,, -91100,1.2569206,3.307149,,,,,,,,,,,,,, -91120,,,0.7765820026397705,1.0714341402053833,0.7109000086784363,1.358513593673706,50000.0,0.5908000469207764,1.9569255113601685,10000.0,41630.54138350487,45808.34008026123,41630.54138350487,4169.649255514145,3.409395456314087,0.0 -91200,1.3325118,3.2165217,,,,,,,,,,,,,, -91300,1.241066,3.589428,,,,,,,,,,,,,, -91400,1.3318052,3.240762,,,,,,,,,,,,,, -91500,1.2828817,4.8081694,,,,,,,,,,,,,, -91600,1.3275405,3.2951763,,,,,,,,,,,,,, -91700,1.2926381,3.266372,,,,,,,,,,,,,, -91800,1.2604147,4.707119,,,,,,,,,,,,,, -91900,1.1648517,3.5024238,,,,,,,,,,,,,, -92000,1.3322983,3.3821683,,,,,,,,,,,,,, -92041,,,0.7859570384025574,1.049735188484192,0.7099999785423279,1.3754807710647583,50000.0,0.5905000567436218,1.978685021400452,10000.0,42050.45993804932,46272.82831025124,42050.45993804932,4214.131242513657,3.4492197036743164,0.0 -92100,1.3626288,4.818143,,,,,,,,,,,,,, -92200,1.2514039,3.308105,,,,,,,,,,,,,, -92300,1.3700598,4.874931,,,,,,,,,,,,,, -92400,1.2934682,4.443997,,,,,,,,,,,,,, -92500,1.2268909,4.2052855,,,,,,,,,,,,,, -92600,1.3335077,3.2854476,,,,,,,,,,,,,, -92700,1.3186581,4.6625447,,,,,,,,,,,,,, -92800,1.3214598,3.3204355,,,,,,,,,,,,,, -92900,1.2344842,4.095031,,,,,,,,,,,,,, -92962,,,0.7772070169448853,1.0873217582702637,0.7098199725151062,1.380346417427063,50000.0,0.5901000499725342,1.974666714668274,10000.0,42470.80727481842,46736.64430379868,42470.80727481842,4257.513018131256,3.4884932041168213,0.0 -93000,1.2444785,3.894682,,,,,,,,,,,,,, -93100,1.4302647,3.279571,,,,,,,,,,,,,, -93200,1.2853973,3.8743787,,,,,,,,,,,,,, -93300,1.2466334,3.3092086,,,,,,,,,,,,,, -93400,1.3352286,3.8009124,,,,,,,,,,,,,, -93500,1.2629083,4.0356655,,,,,,,,,,,,,, -93600,1.232672,3.4788284,,,,,,,,,,,,,, -93700,1.1911902,3.3768969,,,,,,,,,,,,,, -93800,1.2103841,3.294121,,,,,,,,,,,,,, -93883,,,0.781445324420929,1.0885802507400513,0.7143999934196472,1.3779401779174805,50000.0,0.5901000499725342,1.9860336780548096,10000.0,42890.88827753067,47200.08113312721,42890.88827753067,4300.7799434661865,3.5301735401153564,0.0 -93900,1.266102,4.3511662,,,,,,,,,,,,,, -94000,1.2568796,3.2332845,,,,,,,,,,,,,, -94100,1.5553918,4.85256,,,,,,,,,,,,,, -94200,1.314561,4.708643,,,,,,,,,,,,,, -94300,1.3481812,4.730819,,,,,,,,,,,,,, -94400,1.4029474,3.421955,,,,,,,,,,,,,, -94500,1.269344,3.6041484,,,,,,,,,,,,,, -94600,1.425401,3.294002,,,,,,,,,,,,,, -94700,1.3765011,3.5264406,,,,,,,,,,,,,, -94800,1.285809,3.419744,,,,,,,,,,,,,, -94806,,,0.7893944978713989,1.0502266883850098,0.7142399549484253,1.3713054656982422,50000.0,0.5893000364303589,1.966513991355896,10000.0,43311.09178900719,47663.47908735275,43311.09178900719,4343.885124206543,3.571937799453736,0.0 -94900,1.3682961,4.900446,,,,,,,,,,,,,, -95000,1.4193145,4.3277636,,,,,,,,,,,,,, -95100,1.341115,3.249474,,,,,,,,,,,,,, -95200,1.3899302,3.3089495,,,,,,,,,,,,,, -95300,1.3566569,3.2739902,,,,,,,,,,,,,, -95400,1.4417849,4.6483703,,,,,,,,,,,,,, -95500,1.4361298,4.8789115,,,,,,,,,,,,,, -95600,1.3174198,3.4279513,,,,,,,,,,,,,, -95700,1.1974914,3.8561068,,,,,,,,,,,,,, -95724,,,0.7759374976158142,1.094131588935852,0.7149199843406677,1.3668842315673828,50000.0,0.5923000574111938,1.9620274305343628,10000.0,43731.27753829956,48129.95305871964,43731.27753829956,4390.079038143158,3.618546485900879,0.0 -95800,1.3940831,4.7493706,,,,,,,,,,,,,, -95900,1.2756143,3.4334927,,,,,,,,,,,,,, -96000,1.4222507,3.289673,,,,,,,,,,,,,, -96100,1.4347748,4.9508457,,,,,,,,,,,,,, -96200,1.2991832,3.5506647,,,,,,,,,,,,,, -96300,1.339666,3.2191823,,,,,,,,,,,,,, -96400,1.4075245,4.927518,,,,,,,,,,,,,, -96500,1.309411,3.2613063,,,,,,,,,,,,,, -96600,1.3401563,3.268042,,,,,,,,,,,,,, -96646,,,0.7854687571525574,1.0612901449203491,0.7143399715423584,1.3597755432128906,50000.0,0.5936000347137451,1.9557676315307613,10000.0,44151.39954543114,48594.39071774483,44151.39954543114,4434.304823875427,3.6608870029449454,0.0 -96700,1.4054599,3.169264,,,,,,,,,,,,,, -96800,1.4195563,3.2519934,,,,,,,,,,,,,, -96900,1.4223056,3.3037512,,,,,,,,,,,,,, -97000,1.3451498,3.4716787,,,,,,,,,,,,,, -97100,1.4398968,4.846284,,,,,,,,,,,,,, -97200,1.224211,3.8077939,,,,,,,,,,,,,, -97300,1.2665857,4.561594,,,,,,,,,,,,,, -97400,1.3716145,4.8624573,,,,,,,,,,,,,, -97500,1.271478,3.5166426,,,,,,,,,,,,,, -97568,,,0.7908984422683716,1.0401599407196045,0.7178399562835693,1.3506696224212646,50000.0,0.594700038433075,1.9495747089385984,10000.0,44571.57318592072,49057.39753556252,44571.57318592072,4477.041138410568,3.7091317176818848,0.0 -97600,1.2385587,3.4034307,,,,,,,,,,,,,, -97700,1.400246,3.5647306,,,,,,,,,,,,,, -97800,1.3962458,4.230703,,,,,,,,,,,,,, -97900,1.5644532,3.225418,,,,,,,,,,,,,, -98000,1.3402462,3.2354727,,,,,,,,,,,,,, -98100,1.2930542,3.3458502,,,,,,,,,,,,,, -98200,1.3988063,3.2475007,,,,,,,,,,,,,, -98300,1.4183128,3.3431005,,,,,,,,,,,,,, -98400,1.5189059,4.816232,,,,,,,,,,,,,, -98488,,,0.8031054735183716,0.9849762320518494,0.7144799828529358,1.3556915521621704,50000.0,0.5902000069618225,1.973228454589844,10000.0,44991.559143066406,49522.340443611145,44991.559143066406,4521.90961265564,3.7491395473480233,0.0 -98500,1.5240811,4.9450364,,,,,,,,,,,,,, -98600,1.4207435,4.858019,,,,,,,,,,,,,, -98700,1.4101789,3.2046642,,,,,,,,,,,,,, -98800,1.3450446,3.1712794,,,,,,,,,,,,,, -98900,1.5209267,3.3333547,,,,,,,,,,,,,, -99000,1.4010806,4.851996,,,,,,,,,,,,,, -99100,1.381971,3.2727017,,,,,,,,,,,,,, -99200,1.459443,4.790359,,,,,,,,,,,,,, -99300,1.5215666,3.3039513,,,,,,,,,,,,,, -99400,1.3916532,3.1761475,,,,,,,,,,,,,, -99411,,,0.7883203029632568,1.0426161289215088,0.7203399538993835,1.3304067850112915,50000.0,0.5999000072479248,1.932340145111084,10000.0,45411.88533329964,49986.21019554138,45411.88533329964,4565.364971160889,3.7899091243743896,0.0 -99500,1.4687185,4.731519,,,,,,,,,,,,,, -99600,1.3977622,3.57555,,,,,,,,,,,,,, -99700,1.2723114,4.438478,,,,,,,,,,,,,, -99800,1.2701695,3.957604,,,,,,,,,,,,,, -99900,1.3239707,4.500735,,,,,,,,,,,,,, -100000,1.2744431,3.1738765,,,,,,,,,,,,,, -100100,1.4485383,4.8405666,,,,,,,,,,,,,, -100200,1.3844005,4.6640873,,,,,,,,,,,,,, -100300,1.346863,4.492906,,,,,,,,,,,,,, -100332,,,0.7882617115974426,1.0813069343566897,0.7168599963188171,1.389996886253357,50000.0,0.5913000106811523,1.9910321235656736,10000.0,45831.8272600174,50447.46520733833,45831.8272600174,4606.589636087418,3.83087944984436,0.0 -100400,1.3336246,3.3055546,,,,,,,,,,,,,, -100500,1.5948521,3.2203703,,,,,,,,,,,,,, -100600,1.303138,3.9968164,,,,,,,,,,,,,, -100700,1.4276245,3.1781902,,,,,,,,,,,,,, -100800,1.436132,3.221406,,,,,,,,,,,,,, -100900,1.3907559,4.4481196,,,,,,,,,,,,,, -101000,1.515183,3.1238878,,,,,,,,,,,,,, -101100,1.2683102,3.7346878,,,,,,,,,,,,,, -101200,1.3492757,3.4402332,,,,,,,,,,,,,, -101255,,,0.8056640625,1.0057238340377808,0.7194799780845642,1.3632384538650513,50000.0,0.5976000428199768,1.9517590999603271,10000.0,46251.89824414253,50907.4188015461,46251.89824414253,4646.38002371788,3.875702142715454,0.0 -101300,1.4706814,3.1696148,,,,,,,,,,,,,, -101400,1.5025856,4.8848534,,,,,,,,,,,,,, -101500,1.4203656,3.220691,,,,,,,,,,,,,, -101600,1.332167,4.208975,,,,,,,,,,,,,, -101700,1.3321017,3.3546965,,,,,,,,,,,,,, -101800,1.5181459,4.4434733,,,,,,,,,,,,,, -101900,1.3382504,3.3557343,,,,,,,,,,,,,, -102000,1.390447,3.2608845,,,,,,,,,,,,,, -102100,1.2733766,4.007679,,,,,,,,,,,,,, -102178,,,0.785839855670929,1.0791743993759155,0.7178199887275696,1.3750110864639282,50000.0,0.5955000519752502,1.9908602237701416,10000.0,46672.21717500687,51366.9196677208,46672.21717500687,4685.471276283264,3.91915512084961,0.0 -102200,1.2735908,3.8582788,,,,,,,,,,,,,, -102300,1.4116354,3.7390661,,,,,,,,,,,,,, -102400,1.3687528,3.755065,,,,,,,,,,,,,, -102500,1.4091666,3.192528,,,,,,,,,,,,,, -102600,1.4949237,3.1836329,,,,,,,,,,,,,, -102700,1.331994,3.3970454,,,,,,,,,,,,,, -102800,1.5179687,4.883755,,,,,,,,,,,,,, -102900,1.4843173,4.7782617,,,,,,,,,,,,,, -103000,1.4852655,4.6999397,,,,,,,,,,,,,, -103099,,,0.8001952767372131,1.0049448013305664,0.7236599922180176,1.3200238943099976,50000.0,0.6014000177383423,1.91221821308136,10000.0,47092.51544976234,51831.56687259674,47092.51544976234,4729.725798368454,3.966155767440796,0.0 -103100,1.4139351,3.1687162,,,,,,,,,,,,,, -103200,1.4117321,4.1718016,,,,,,,,,,,,,, -103300,1.4165405,3.1291027,,,,,,,,,,,,,, -103400,1.3535017,3.2838967,,,,,,,,,,,,,, -103500,1.3413289,3.8504047,,,,,,,,,,,,,, -103600,1.5258048,3.184104,,,,,,,,,,,,,, -103700,1.641918,3.1635222,,,,,,,,,,,,,, -103800,1.491459,4.6957035,,,,,,,,,,,,,, -103900,1.3049213,4.331925,,,,,,,,,,,,,, -104000,1.4129169,3.145186,,,,,,,,,,,,,, -104021,,,0.802050769329071,1.0305854082107544,0.723580002784729,1.3682783842086792,50000.0,0.6020000576972961,1.969410419464112,10000.0,47512.693658828735,52291.53897356987,47512.693658828735,4769.427984714508,4.009620189666748,0.0 -104100,1.3842821,3.4032116,,,,,,,,,,,,,, -104200,1.4415903,4.619898,,,,,,,,,,,,,, -104300,1.5587792,3.1592607,,,,,,,,,,,,,, -104400,1.4305727,3.0668013,,,,,,,,,,,,,, -104500,1.3792945,3.1095762,,,,,,,,,,,,,, -104600,1.3030899,3.8462915,,,,,,,,,,,,,, -104700,1.5056355,3.1833162,,,,,,,,,,,,,, -104800,1.4667684,3.8698025,,,,,,,,,,,,,, -104900,1.510914,3.2164583,,,,,,,,,,,,,, -104942,,,0.7951952815055847,1.0146361589431765,0.7224000096321106,1.322086215019226,50000.0,0.6025000214576721,1.9171737432479856,10000.0,47932.840673685074,52757.84145927429,47932.840673685074,4815.495166063309,4.050985097885132,0.0 -105000,1.4518193,3.17635,,,,,,,,,,,,,, -105100,1.4957458,3.1545699,,,,,,,,,,,,,, -105200,1.4729823,4.46281,,,,,,,,,,,,,, -105300,1.4718065,3.1507137,,,,,,,,,,,,,, -105400,1.2596301,3.7835877,,,,,,,,,,,,,, -105500,1.4410782,3.21017,,,,,,,,,,,,,, -105600,1.4384077,4.4024706,,,,,,,,,,,,,, -105700,1.392685,3.3638613,,,,,,,,,,,,,, -105800,1.4147259,3.2062964,,,,,,,,,,,,,, -105860,,,0.79212886095047,1.0310273170471191,0.7222200036048889,1.32786226272583,50000.0,0.6021000146865845,1.924709677696228,10000.0,48352.49347019196,53217.91585254669,48352.49347019196,4855.456739187241,4.464348077774048,0.0 -105900,1.3597728,4.159672,,,,,,,,,,,,,, -106000,1.3152294,4.0978394,,,,,,,,,,,,,, -106100,1.4470719,3.4555454,,,,,,,,,,,,,, -106200,1.3935586,3.1473799,,,,,,,,,,,,,, -106300,1.4782449,3.3934546,,,,,,,,,,,,,, -106400,1.6958072,4.818651,,,,,,,,,,,,,, -106500,1.4713204,3.2058203,,,,,,,,,,,,,, -106600,1.4147977,3.1427922,,,,,,,,,,,,,, -106700,1.439128,3.1605237,,,,,,,,,,,,,, -106780,,,0.8042968511581421,0.9922462105751038,0.7252999544143677,1.32797110080719,50000.0,0.6099000573158264,1.9118989706039429,10000.0,48772.86058592796,53681.318464279175,48772.86058592796,4898.399055242538,4.509274482727051,0.0 -106800,1.3874974,3.6345131,,,,,,,,,,,,,, -106900,1.5124098,4.7168922,,,,,,,,,,,,,, -107000,1.6319427,3.2975106,,,,,,,,,,,,,, -107100,1.4010316,4.1261063,,,,,,,,,,,,,, -107200,1.4613036,3.9731362,,,,,,,,,,,,,, -107300,1.3902746,4.5455647,,,,,,,,,,,,,, -107400,1.3669991,3.3683197,,,,,,,,,,,,,, -107500,1.4763527,3.1445973,,,,,,,,,,,,,, -107600,1.3932047,3.2127278,,,,,,,,,,,,,, -107699,,,0.802734375,0.9875910878181458,0.7257199883460999,1.3120582103729248,50000.0,0.6027000546455383,1.9079450368881223,10000.0,49192.9281001091,54145.81826877594,49192.9281001091,4942.739294528961,4.553928375244141,0.0 -107700,1.3211211,3.5941887,,,,,,,,,,,,,, -107800,1.60057,4.934921,,,,,,,,,,,,,, -107900,1.6014117,4.585443,,,,,,,,,,,,,, -108000,1.3189563,3.1124516,,,,,,,,,,,,,, -108100,1.3431555,3.7167354,,,,,,,,,,,,,, -108200,1.4742477,3.2207296,,,,,,,,,,,,,, -108300,1.4471931,3.1166248,,,,,,,,,,,,,, -108400,1.3657523,3.8225307,,,,,,,,,,,,,, -108500,1.4758124,3.3958054,,,,,,,,,,,,,, -108600,1.4676238,3.1718326,,,,,,,,,,,,,, -108619,,,0.8020898103713989,0.997575342655182,0.7286199927330017,1.3085391521453855,50000.0,0.6051000356674194,1.904710292816162,10000.0,49613.25364589691,54608.0376534462,49613.25364589691,4984.543601036072,4.5956196784973145,0.0 -108700,1.5393628,3.1667042,,,,,,,,,,,,,, -108800,1.5494606,3.1440954,,,,,,,,,,,,,, -108900,1.4032838,3.3467562,,,,,,,,,,,,,, -109000,1.3193636,3.3934221,,,,,,,,,,,,,, -109100,1.4895709,3.171541,,,,,,,,,,,,,, -109200,1.5237324,3.0766673,,,,,,,,,,,,,, -109300,1.5412472,4.47618,,,,,,,,,,,,,, -109400,1.5625826,3.2191148,,,,,,,,,,,,,, -109500,1.5134927,3.1488185,,,,,,,,,,,,,, -109540,,,0.8037304282188416,0.9704073071479796,0.7299799919128418,1.2849246263504028,50000.0,0.6074000000953674,1.8863749504089355,10000.0,50033.40543913841,55069.19483280182,50033.40543913841,5025.457714557648,4.6387939453125,0.0 -109600,1.801103,4.7436624,,,,,,,,,,,,,, -109700,1.4859401,3.8557935,,,,,,,,,,,,,, -109800,1.4919181,3.1558278,,,,,,,,,,,,,, -109900,1.4171994,3.4673817,,,,,,,,,,,,,, -110000,1.518379,3.5303266,,,,,,,,,,,,,, -110100,1.4278129,3.5965185,,,,,,,,,,,,,, -110200,1.522875,3.3755655,,,,,,,,,,,,,, -110300,1.4193163,3.1333761,,,,,,,,,,,,,, -110400,1.5281891,3.1451082,,,,,,,,,,,,,, -110460,,,0.8193163871765137,0.9186491370201112,0.7305399775505066,1.2858415842056274,50000.0,0.6077000498771667,1.892633080482483,10000.0,50453.32793235779,55534.84929966927,50453.32793235779,5071.100483894348,4.6809704303741455,0.0 -110500,1.4790773,4.5165153,,,,,,,,,,,,,, -110600,1.5389758,3.1175096,,,,,,,,,,,,,, -110700,1.5546272,3.1962934,,,,,,,,,,,,,, -110800,1.6313682,3.3005545,,,,,,,,,,,,,, -110900,1.7505877,4.7458587,,,,,,,,,,,,,, -111000,1.3466681,3.4057345,,,,,,,,,,,,,, -111100,1.3860469,3.9576209,,,,,,,,,,,,,, -111200,1.5154362,4.138013,,,,,,,,,,,,,, -111300,1.515561,3.093771,,,,,,,,,,,,,, -111382,,,0.8045117259025574,0.9946995973587036,0.7307999730110168,1.30134117603302,50000.0,0.6093000173568726,1.9110552072525024,10000.0,50873.48181271553,56000.36197376251,50873.48181271553,5116.349600315094,4.732409715652466,0.0 -111400,1.6456789,4.3815837,,,,,,,,,,,,,, -111500,1.4408107,4.3263884,,,,,,,,,,,,,, -111600,1.5681177,3.2391548,,,,,,,,,,,,,, -111700,1.4548872,3.065221,,,,,,,,,,,,,, -111800,1.4834158,3.288125,,,,,,,,,,,,,, -111900,1.5076046,3.532464,,,,,,,,,,,,,, -112000,1.5234874,4.5157003,,,,,,,,,,,,,, -112100,1.5486815,4.2985992,,,,,,,,,,,,,, -112200,1.5685654,3.0956187,,,,,,,,,,,,,, -112300,1.4498043,3.478149,,,,,,,,,,,,,, -112305,,,0.8092968463897705,0.9513814449310304,0.7310999631881714,1.288879632949829,50000.0,0.609000027179718,1.8916889429092407,10000.0,51293.81174302101,56465.59639286995,51293.81174302101,5161.157242059708,4.781898736953735,0.0 -112400,1.6411464,3.1986208,,,,,,,,,,,,,, -112500,1.6972764,3.1992388,,,,,,,,,,,,,, -112600,1.3741498,3.6898708,,,,,,,,,,,,,, -112700,1.4272484,3.4032993,,,,,,,,,,,,,, -112800,1.4985058,3.1791174,,,,,,,,,,,,,, -112900,1.4909571,3.2195036,,,,,,,,,,,,,, -113000,1.4205327,3.118713,,,,,,,,,,,,,, -113100,1.563716,3.1482763,,,,,,,,,,,,,, -113200,1.4910144,3.0743513,,,,,,,,,,,,,, -113227,,,0.81849604845047,0.898280680179596,0.7363399863243103,1.2454630136489868,50000.0,0.6177000403404236,1.8420090675354004,10000.0,51713.83294534683,56929.53366589546,51713.83294534683,5204.9848573207855,4.822944164276123,0.0 -113300,1.5263354,3.1510277,,,,,,,,,,,,,, -113400,1.6571372,4.4680805,,,,,,,,,,,,,, -113500,1.4549586,3.5321424,,,,,,,,,,,,,, -113600,1.4687215,3.2084935,,,,,,,,,,,,,, -113700,1.4945128,3.1462898,,,,,,,,,,,,,, -113800,1.4791874,3.1152182,,,,,,,,,,,,,, -113900,1.6028308,3.2722392,,,,,,,,,,,,,, -114000,1.4111768,3.8106332,,,,,,,,,,,,,, -114100,1.5510234,3.069053,,,,,,,,,,,,,, -114148,,,0.8068945407867432,0.9638850688934326,0.7357199788093567,1.2680931091308594,50000.0,0.6145000457763672,1.8680979013442995,10000.0,52133.946326971054,57389.301633358,52133.946326971054,5244.5484964847565,4.866366386413574,0.0 -114200,1.5017569,3.070502,,,,,,,,,,,,,, -114300,1.3898773,3.0327117,,,,,,,,,,,,,, -114400,1.4613706,3.5954235,,,,,,,,,,,,,, -114500,1.6384248,3.1408274,,,,,,,,,,,,,, -114600,1.4473621,3.301516,,,,,,,,,,,,,, -114700,1.4755183,3.1804883,,,,,,,,,,,,,, -114800,1.7993543,4.808819,,,,,,,,,,,,,, -114900,1.4878771,3.3819258,,,,,,,,,,,,,, -115000,1.7277074,3.575259,,,,,,,,,,,,,, -115071,,,0.81201171875,0.9675384163856506,0.7343999743461609,1.2889907360076904,50000.0,0.614300012588501,1.8860503435134888,10000.0,52554.23753380776,57851.095802783966,52554.23753380776,5285.962988376617,4.9072349071502686,0.0 -115100,1.4987047,3.1022253,,,,,,,,,,,,,, -115200,1.3733456,3.8480513,,,,,,,,,,,,,, -115300,1.5645787,3.1331716,,,,,,,,,,,,,, -115400,1.4789929,3.064733,,,,,,,,,,,,,, -115500,1.6287067,3.1816244,,,,,,,,,,,,,, -115600,1.6318629,3.1948087,,,,,,,,,,,,,, -115700,1.5619737,3.830779,,,,,,,,,,,,,, -115800,1.5680461,3.1074169,,,,,,,,,,,,,, -115900,1.5895717,4.687319,,,,,,,,,,,,,, -115991,,,0.8161523342132568,0.93476402759552,0.7350599765777588,1.2736767530441284,50000.0,0.6172000169754028,1.8698878288269043,10000.0,52974.53467059136,58315.42613291741,52974.53467059136,5329.898587703705,4.956961631774902,0.0 -116000,1.5664045,3.1350715,,,,,,,,,,,,,, -116100,1.5731977,4.1691065,,,,,,,,,,,,,, -116200,1.6551259,3.150439,,,,,,,,,,,,,, -116300,1.5694921,3.1878703,,,,,,,,,,,,,, -116400,1.5453631,3.2527251,,,,,,,,,,,,,, -116500,1.486524,3.8489366,,,,,,,,,,,,,, -116600,1.6801957,3.1260912,,,,,,,,,,,,,, -116700,1.6838872,3.0981295,,,,,,,,,,,,,, -116800,1.5655364,3.1130376,,,,,,,,,,,,,, -116900,1.7240651,3.1498005,,,,,,,,,,,,,, -116909,,,0.8122069835662842,0.9754149317741394,0.7371199727058411,1.2887877225875854,50000.0,0.6185000538825989,1.8774621486663816,10000.0,53394.7826769352,58777.340618133545,53394.7826769352,5371.47057056427,5.00386118888855,0.0 -117000,1.6362379,3.2027557,,,,,,,,,,,,,, -117100,1.6927348,3.1921575,,,,,,,,,,,,,, -117200,1.527406,3.7983797,,,,,,,,,,,,,, -117300,1.6777648,3.1216023,,,,,,,,,,,,,, -117400,1.8721554,4.754442,,,,,,,,,,,,,, -117500,1.6887505,4.5031157,,,,,,,,,,,,,, -117600,1.6518693,4.5201426,,,,,,,,,,,,,, -117700,1.6369548,3.0296578,,,,,,,,,,,,,, -117800,1.6088376,3.1833124,,,,,,,,,,,,,, -117830,,,0.81898432970047,0.918323814868927,0.7381199598312378,1.2621748447418213,50000.0,0.617900013923645,1.8480632305145264,10000.0,53814.80868077278,59242.60025715828,53814.80868077278,5416.606198310852,5.0547590255737305,0.0 -117900,1.7466258,4.45515,,,,,,,,,,,,,, -118000,1.8054785,4.753069,,,,,,,,,,,,,, -118100,1.5816644,4.4627624,,,,,,,,,,,,,, -118200,1.4255702,3.8222682,,,,,,,,,,,,,, -118300,1.5805744,3.7135196,,,,,,,,,,,,,, -118400,1.7907536,3.1226346,,,,,,,,,,,,,, -118500,1.5356246,4.0137906,,,,,,,,,,,,,, -118600,1.485087,3.3701222,,,,,,,,,,,,,, -118700,1.5891411,3.1239424,,,,,,,,,,,,,, -118753,,,0.8180468678474426,0.9298332929611206,0.7390599846839905,1.263730764389038,50000.0,0.6194000244140625,1.8396508693695068,10000.0,54235.02416753769,59707.91237425804,54235.02416753769,5461.611189365387,5.096725225448608,0.0 -118800,1.5539026,3.341034,,,,,,,,,,,,,, -118900,1.4689798,2.9987564,,,,,,,,,,,,,, -119000,1.6334082,3.2386868,,,,,,,,,,,,,, -119100,1.6216799,3.638418,,,,,,,,,,,,,, -119200,1.5957137,4.155032,,,,,,,,,,,,,, -119300,1.5979283,3.0861306,,,,,,,,,,,,,, -119400,1.5642983,3.944942,,,,,,,,,,,,,, -119500,1.609608,3.1850333,,,,,,,,,,,,,, -119600,1.7290545,3.9162374,,,,,,,,,,,,,, -119674,,,0.8340820074081421,0.8498026728630066,0.7407199740409851,1.2424492835998535,50000.0,0.6140000224113464,1.8454985618591309,10000.0,54655.402338027954,60169.75206565857,54655.402338027954,5502.972692966461,5.1483683586120605,0.0 -119700,1.7768466,4.442534,,,,,,,,,,,,,, -119800,1.5466148,3.606401,,,,,,,,,,,,,, -119900,1.594549,3.0908127,,,,,,,,,,,,,, -120000,1.6031162,3.2642899,,,,,,,,,,,,,, -120100,1.5754915,3.11819,,,,,,,,,,,,,, -120200,1.6071464,3.316476,,,,,,,,,,,,,, -120300,1.7230463,3.538486,,,,,,,,,,,,,, -120400,1.6595192,3.1111083,,,,,,,,,,,,,, -120500,1.6459509,3.192588,,,,,,,,,,,,,, -120593,,,0.8211132884025574,0.9188494682312012,0.7406799793243408,1.249154806137085,50000.0,0.6239000558853149,1.833344459533692,10000.0,55075.56313490868,60629.047945261,55075.56313490868,5542.008673429489,5.199536323547363,0.0 -120600,1.6852908,3.142271,,,,,,,,,,,,,, -120700,1.7905442,3.0972743,,,,,,,,,,,,,, -120800,1.5980463,3.8461268,,,,,,,,,,,,,, -120900,1.5852224,3.2601993,,,,,,,,,,,,,, -121000,1.6731796,4.3044157,,,,,,,,,,,,,, -121100,1.7688788,3.0999725,,,,,,,,,,,,,, -121200,1.615426,4.2436657,,,,,,,,,,,,,, -121300,1.5656362,3.4430115,,,,,,,,,,,,,, -121400,1.7769024,4.2498655,,,,,,,,,,,,,, -121500,1.6309401,3.0339568,,,,,,,,,,,,,, -121514,,,0.8241796493530273,0.8857353925704956,0.7432799935340881,1.2298856973648071,50000.0,0.6218000054359436,1.8167213201522827,10000.0,55495.675506830215,61091.24942660332,55495.675506830215,5584.007612705231,5.241910457611084,0.0 -121600,1.5476118,3.091674,,,,,,,,,,,,,, -121700,1.6512783,3.0894492,,,,,,,,,,,,,, -121800,1.5513718,3.3557127,,,,,,,,,,,,,, -121900,1.688811,3.0815067,,,,,,,,,,,,,, -122000,1.7329043,3.1381867,,,,,,,,,,,,,, -122100,1.6707584,3.0703652,,,,,,,,,,,,,, -122200,1.594399,3.353365,,,,,,,,,,,,,, -122300,1.6036577,3.3922563,,,,,,,,,,,,,, -122400,1.7288567,4.3283644,,,,,,,,,,,,,, -122435,,,0.8314648270606995,0.890949010848999,0.7434200048446655,1.2616345882415771,50000.0,0.6230000257492065,1.8468738794326784,10000.0,55915.97837305069,61556.50082373619,55915.97837305069,5628.865107297897,5.285839796066284,0.0 -122500,1.5691069,3.0956411,,,,,,,,,,,,,, -122600,1.7038689,3.0525224,,,,,,,,,,,,,, -122700,1.6454397,3.072925,,,,,,,,,,,,,, -122800,1.8013281,3.1070392,,,,,,,,,,,,,, -122900,1.6208061,3.2752151,,,,,,,,,,,,,, -123000,1.5431154,3.520147,,,,,,,,,,,,,, -123100,1.5786725,3.828548,,,,,,,,,,,,,, -123200,1.763427,3.1415586,,,,,,,,,,,,,, -123300,1.7197287,3.229298,,,,,,,,,,,,,, -123357,,,0.8240624666213989,0.9000077843666077,0.7458999752998352,1.2291886806488037,50000.0,0.6271000504493713,1.825277090072632,10000.0,56336.17573904991,62019.912034511566,56336.17573904991,5671.9865918159485,5.331271648406982,0.0 -123400,1.5883871,3.082423,,,,,,,,,,,,,, -123500,1.7714331,3.1764565,,,,,,,,,,,,,, -123600,1.8106049,4.4432063,,,,,,,,,,,,,, -123700,1.6369944,3.1069133,,,,,,,,,,,,,, -123800,1.7637393,4.2599196,,,,,,,,,,,,,, -123900,1.6630877,4.087774,,,,,,,,,,,,,, -124000,1.7754165,3.068221,,,,,,,,,,,,,, -124100,1.7150798,3.2572556,,,,,,,,,,,,,, -124200,1.8702905,3.086051,,,,,,,,,,,,,, -124277,,,0.8270507454872131,0.9012371897697448,0.7419599890708923,1.2539701461791992,50000.0,0.6223000288009644,1.840736985206604,10000.0,56756.23712944984,62480.5327205658,56756.23712944984,5712.451839923859,5.377405405044556,0.0 -124300,1.5986716,3.0957863,,,,,,,,,,,,,, -124400,1.6124505,3.7578013,,,,,,,,,,,,,, -124500,1.6765873,3.8327305,,,,,,,,,,,,,, -124600,1.7056848,3.5986223,,,,,,,,,,,,,, -124700,1.6511362,3.8949156,,,,,,,,,,,,,, -124800,1.7864934,3.0306056,,,,,,,,,,,,,, -124900,1.6050591,3.0329576,,,,,,,,,,,,,, -125000,1.7838068,3.1639643,,,,,,,,,,,,,, -125100,1.8488271,3.010415,,,,,,,,,,,,,, -125199,,,0.8361327648162842,0.8307301998138428,0.746999979019165,1.1995173692703247,50000.0,0.629300057888031,1.7952983379364014,10000.0,57176.52000498772,62944.29509925842,57176.52000498772,5755.837740182877,5.423417568206787,0.0 -125200,1.7481437,4.170062,,,,,,,,,,,,,, -125300,1.5901443,3.1955125,,,,,,,,,,,,,, -125400,1.7356783,2.9938002,,,,,,,,,,,,,, -125500,1.689761,2.989167,,,,,,,,,,,,,, -125600,1.8215566,4.3479366,,,,,,,,,,,,,, -125700,1.6146932,3.858625,,,,,,,,,,,,,, -125800,1.6561383,3.262933,,,,,,,,,,,,,, -125900,1.9771876,4.6050887,,,,,,,,,,,,,, -126000,1.7818214,3.0104935,,,,,,,,,,,,,, -126100,1.5881488,3.3158073,,,,,,,,,,,,,, -126119,,,0.8282421827316284,0.8835758566856384,0.7492600083351135,1.218082070350647,50000.0,0.6274000406265259,1.814552426338196,10000.0,57596.76955342293,63409.46954703331,57596.76955342293,5800.669831514359,5.46803092956543,0.0 -126200,1.7573847,3.3984632,,,,,,,,,,,,,, -126300,2.0363286,4.629043,,,,,,,,,,,,,, -126400,1.9195449,4.5109406,,,,,,,,,,,,,, -126500,1.7692784,2.9838147,,,,,,,,,,,,,, -126600,1.788684,4.3617344,,,,,,,,,,,,,, -126700,1.6698284,3.0965395,,,,,,,,,,,,,, -126800,1.5692377,3.09269,,,,,,,,,,,,,, -126900,1.5686247,3.844462,,,,,,,,,,,,,, -127000,1.7830522,3.1396377,,,,,,,,,,,,,, -127039,,,0.8307812213897705,0.8714902400970459,0.7470600008964539,1.2164604663848877,50000.0,0.6255000233650208,1.8085092306137085,10000.0,58016.67089056969,63868.15580749512,58016.67089056969,5839.35227394104,5.522529602050781,0.0 -127100,1.647273,3.0031009,,,,,,,,,,,,,, -127200,1.6141224,3.2629225,,,,,,,,,,,,,, -127300,1.7081646,3.1301217,,,,,,,,,,,,,, -127400,1.6826512,2.9928594,,,,,,,,,,,,,, -127500,1.6900929,3.0311165,,,,,,,,,,,,,, -127600,2.02011,3.0903718,,,,,,,,,,,,,, -127700,1.6755822,3.1999571,,,,,,,,,,,,,, -127800,1.5929792,3.0624673,,,,,,,,,,,,,, -127900,1.7903975,4.1327996,,,,,,,,,,,,,, -127959,,,0.8356054425239563,0.8296651244163513,0.7488399744033813,1.202234983444214,50000.0,0.6254000067710876,1.7854843139648438,10000.0,58436.60658454895,64331.94377684593,58436.60658454895,5883.110562562943,5.568439722061157,0.0 -128000,1.7575668,2.9949098,,,,,,,,,,,,,, -128100,1.7612678,3.4210732,,,,,,,,,,,,,, -128200,1.8158771,3.10466,,,,,,,,,,,,,, -128300,1.7860037,3.0895581,,,,,,,,,,,,,, -128400,1.8971411,4.5381627,,,,,,,,,,,,,, -128500,1.8066018,3.9627094,,,,,,,,,,,,,, -128600,1.6055279,3.6375158,,,,,,,,,,,,,, -128700,1.6865755,3.8556867,,,,,,,,,,,,,, -128800,1.8339884,2.9977992,,,,,,,,,,,,,, -128881,,,0.8422460556030273,0.8313546776771545,0.752020001411438,1.212985873222351,50000.0,0.6290000081062317,1.7977409362792969,10000.0,58856.884100198746,64793.35998272896,58856.884100198746,5924.156987428665,5.612093925476074,0.0 -128900,1.7431244,3.2379296,,,,,,,,,,,,,, -129000,1.6394947,3.8294365,,,,,,,,,,,,,, -129100,1.8770722,3.7203069,,,,,,,,,,,,,, -129200,1.7659543,3.520037,,,,,,,,,,,,,, -129300,1.8875246,3.0591037,,,,,,,,,,,,,, -129400,1.8218826,2.9939613,,,,,,,,,,,,,, -129500,1.69327,3.463568,,,,,,,,,,,,,, -129600,1.6518954,3.1190245,,,,,,,,,,,,,, -129700,1.9177953,4.5833936,,,,,,,,,,,,,, -129800,1.7140096,2.9919693,,,,,,,,,,,,,, -129804,,,0.8338280916213989,0.8379896283149719,0.7499600052833557,1.1926852464675903,50000.0,0.629800021648407,1.786934494972229,10000.0,59276.93350100517,65257.85636162758,59276.93350100517,5968.510313987732,5.657716512680054,0.0 -129900,2.076426,4.566325,,,,,,,,,,,,,, -130000,1.7115724,3.3577151,,,,,,,,,,,,,, -130100,1.7515715,3.7715337,,,,,,,,,,,,,, -130200,1.8559024,4.0870705,,,,,,,,,,,,,, -130300,1.6930516,3.004113,,,,,,,,,,,,,, -130400,1.7864872,4.1830535,,,,,,,,,,,,,, -130500,1.7745168,2.9794066,,,,,,,,,,,,,, -130600,1.9180546,4.178173,,,,,,,,,,,,,, -130700,1.8110771,3.0088675,,,,,,,,,,,,,, -130725,,,0.8382812142372131,0.8320725560188293,0.7503199577331543,1.1947848796844482,50000.0,0.6312000155448914,1.7743438482284546,10000.0,59696.888902425766,65720.388225317,59696.888902425766,6010.993898153305,5.702925443649292,0.0 -130800,1.7613238,3.7595987,,,,,,,,,,,,,, -130900,1.8930615,4.2385244,,,,,,,,,,,,,, -131000,1.6872758,2.9472513,,,,,,,,,,,,,, -131100,1.8493137,3.935151,,,,,,,,,,,,,, -131200,1.7717184,3.0903976,,,,,,,,,,,,,, -131300,1.9066368,4.34576,,,,,,,,,,,,,, -131400,1.8725597,3.96967,,,,,,,,,,,,,, -131500,2.2326925,4.5265417,,,,,,,,,,,,,, -131600,1.7349182,3.2142148,,,,,,,,,,,,,, -131647,,,0.8476171493530273,0.8037877678871155,0.7534599900245667,1.1936339139938354,50000.0,0.6288000345230103,1.7920557260513306,10000.0,60117.00132703781,66185.6890695095,60117.00132703781,6056.08829331398,5.74944281578064,0.0 -131700,1.7782922,2.9835536,,,,,,,,,,,,,, -131800,1.8691543,3.1138365,,,,,,,,,,,,,, -131900,1.743306,3.0083213,,,,,,,,,,,,,, -132000,1.7860829,3.109576,,,,,,,,,,,,,, -132100,1.8215286,3.0694377,,,,,,,,,,,,,, -132200,1.9280361,4.369454,,,,,,,,,,,,,, -132300,1.8110645,3.025589,,,,,,,,,,,,,, -132400,1.8561059,3.1152596,,,,,,,,,,,,,, -132500,1.9177196,3.1829581,,,,,,,,,,,,,, -132568,,,0.8365820050239563,0.8467892408370972,0.7517600059509277,1.205349326133728,50000.0,0.6355000138282776,1.788927674293518,10000.0,60537.26391768456,66649.95460033417,60537.26391768456,6099.994526147842,5.798482418060303,0.0 -132600,1.7108558,3.0263064,,,,,,,,,,,,,, -132700,1.9750493,2.9644697,,,,,,,,,,,,,, -132800,1.8212984,3.826482,,,,,,,,,,,,,, -132900,1.9925843,3.0806696,,,,,,,,,,,,,, -133000,1.7418418,3.243733,,,,,,,,,,,,,, -133100,1.8580991,2.9796057,,,,,,,,,,,,,, -133200,1.8328643,2.9638836,,,,,,,,,,,,,, -133300,1.8644388,3.7894464,,,,,,,,,,,,,, -133400,1.8049915,3.0755026,,,,,,,,,,,,,, -133490,,,0.8384960889816284,0.835568904876709,0.7537399530410767,1.1961098909378052,50000.0,0.6336000561714172,1.7881088256835938,10000.0,60957.49088454247,67109.25989794731,60957.49088454247,6138.977798938751,5.845485210418701,0.0 -133500,1.966014,3.0844612,,,,,,,,,,,,,, -133600,2.1025739,4.4144115,,,,,,,,,,,,,, -133700,1.8122442,4.061293,,,,,,,,,,,,,, -133800,1.8649,3.03259,,,,,,,,,,,,,, -133900,1.8314433,3.1033638,,,,,,,,,,,,,, -134000,1.6902663,3.2825253,,,,,,,,,,,,,, -134100,1.7113198,3.3839345,,,,,,,,,,,,,, -134200,1.9323914,3.6174572,,,,,,,,,,,,,, -134300,2.161522,4.467758,,,,,,,,,,,,,, -134400,1.6386467,3.2575068,,,,,,,,,,,,,, -134411,,,0.8448827862739563,0.7935135364532471,0.7524799704551697,1.1886016130447388,50000.0,0.6336000561714172,1.7869489192962646,10000.0,61377.691182136536,67575.92088413239,61377.691182136536,6185.34400844574,5.892377853393555,0.0 -134500,1.6999782,3.1475701,,,,,,,,,,,,,, -134600,1.7349126,2.95542,,,,,,,,,,,,,, -134700,1.7573867,3.4938323,,,,,,,,,,,,,, -134800,1.9028093,3.6578135,,,,,,,,,,,,,, -134900,1.7965089,3.1073408,,,,,,,,,,,,,, -135000,1.650497,3.5312936,,,,,,,,,,,,,, -135100,1.8964003,2.9872687,,,,,,,,,,,,,, -135200,1.7666881,2.9220402,,,,,,,,,,,,,, -135300,1.7697821,3.067948,,,,,,,,,,,,,, -135329,,,0.8392773270606995,0.8438341617584229,0.7569599747657776,1.196811318397522,50000.0,0.6373000144958496,1.7893301248550415,10000.0,61797.867267131805,68038.98924589157,61797.867267131805,6228.144732713699,5.9372031688690186,0.0 -135400,1.8046012,2.9777,,,,,,,,,,,,,, -135500,1.8659792,3.0101914,,,,,,,,,,,,,, -135600,1.7760154,3.0451832,,,,,,,,,,,,,, -135700,1.8639814,3.2390711,,,,,,,,,,,,,, -135800,1.7187173,3.1000063,,,,,,,,,,,,,, -135900,1.816649,3.1896493,,,,,,,,,,,,,, -136000,1.8195986,3.4128275,,,,,,,,,,,,,, -136100,1.655007,3.3481374,,,,,,,,,,,,,, -136200,1.7844901,3.0199804,,,,,,,,,,,,,, -136252,,,0.84193354845047,0.8302668333053589,0.758080005645752,1.1913862228393557,50000.0,0.6366000175476074,1.7862942218780518,10000.0,62218.01303982735,68501.58464884758,62218.01303982735,6270.500958204269,5.982234716415405,0.0 -136300,1.8141068,3.751333,,,,,,,,,,,,,, -136400,1.8450251,2.9839532,,,,,,,,,,,,,, -136500,1.9749148,3.0602307,,,,,,,,,,,,,, -136600,1.8163174,2.9769588,,,,,,,,,,,,,, -136700,2.1018133,2.9544168,,,,,,,,,,,,,, -136800,1.8975478,3.0538495,,,,,,,,,,,,,, -136900,1.9730711,2.9970908,,,,,,,,,,,,,, -137000,1.8545952,2.908634,,,,,,,,,,,,,, -137100,1.8091235,4.0860558,,,,,,,,,,,,,, -137171,,,0.8501171469688416,0.7863731980323792,0.7578799724578857,1.1653896570205688,50000.0,0.6381000280380249,1.7534481287002563,10000.0,62637.98712944984,68965.4830365181,62637.98712944984,6314.320621013641,6.038216590881348,0.0 -137200,1.8066819,3.050126,,,,,,,,,,,,,, -137300,1.830559,2.9727488,,,,,,,,,,,,,, -137400,1.920402,3.016283,,,,,,,,,,,,,, -137500,2.017954,4.313198,,,,,,,,,,,,,, -137600,1.9166788,3.0436199,,,,,,,,,,,,,, -137700,1.7759142,2.8932457,,,,,,,,,,,,,, -137800,1.8887162,2.9458961,,,,,,,,,,,,,, -137900,1.8781611,2.9506493,,,,,,,,,,,,,, -138000,1.956487,2.9591904,,,,,,,,,,,,,, -138092,,,0.8501366972923279,0.8256229758262634,0.759880006313324,1.1994614601135254,50000.0,0.6454000473022461,1.7836123704910278,10000.0,63058.27062392235,69428.69195985794,63058.27062392235,6357.152781248093,6.084154844284058,0.0 -138100,1.8359487,3.7368972,,,,,,,,,,,,,, -138200,1.9348522,2.9837909,,,,,,,,,,,,,, -138300,1.9274651,2.9368443,,,,,,,,,,,,,, -138400,2.2149131,4.299489,,,,,,,,,,,,,, -138500,1.8996699,2.9936285,,,,,,,,,,,,,, -138600,2.03451,2.9957507,,,,,,,,,,,,,, -138700,1.8793641,2.9093094,,,,,,,,,,,,,, -138800,1.8008958,2.9399822,,,,,,,,,,,,,, -138900,1.9747747,2.9978006,,,,,,,,,,,,,, -139000,2.1623147,3.9296691,,,,,,,,,,,,,, -139011,,,0.8479101657867432,0.7939903736114502,0.7591599822044373,1.1697139739990234,50000.0,0.6362000107765198,1.7635866403579712,10000.0,63478.28254342079,69894.97005820274,63478.28254342079,6403.32272028923,6.133307695388794,0.0 -139100,1.8429465,2.974731,,,,,,,,,,,,,, -139200,2.0015721,3.010094,,,,,,,,,,,,,, -139300,1.8339422,3.156442,,,,,,,,,,,,,, -139400,1.971199,4.0233393,,,,,,,,,,,,,, -139500,2.0832562,4.306247,,,,,,,,,,,,,, -139600,2.028631,2.9707046,,,,,,,,,,,,,, -139700,1.8975059,2.901395,,,,,,,,,,,,,, -139800,1.9447337,3.8801959,,,,,,,,,,,,,, -139900,1.9962614,3.2706263,,,,,,,,,,,,,, -139931,,,0.8531054258346558,0.7984217405319214,0.7590999603271484,1.1816964149475098,50000.0,0.6403000354766846,1.7529109716415403,10000.0,63898.31989693642,70359.86943912506,63898.31989693642,6448.088568687439,6.181832790374756,0.0 -140000,1.8952821,3.7374845,,,,,,,,,,,,,, -140100,1.8968018,2.7965302,,,,,,,,,,,,,, -140200,1.7818264,2.9479494,,,,,,,,,,,,,, -140300,1.7407191,3.0061977,,,,,,,,,,,,,, -140400,2.0844877,4.1344156,,,,,,,,,,,,,, -140500,1.7911662,3.560368,,,,,,,,,,,,,, -140600,2.0046327,3.0275288,,,,,,,,,,,,,, -140700,1.9622954,3.0118318,,,,,,,,,,,,,, -140800,1.8751919,2.9935784,,,,,,,,,,,,,, -140851,,,0.8634960651397705,0.751483142375946,0.7616999745368958,1.166806936264038,50000.0,0.643500030040741,1.751853346824646,10000.0,64318.303194999695,70826.06875395775,64318.303194999695,6494.210072994232,6.228420257568359,0.0 -140900,2.0875669,2.968359,,,,,,,,,,,,,, -141000,2.139411,4.2829685,,,,,,,,,,,,,, -141100,1.9737418,3.2704666,,,,,,,,,,,,,, -141200,2.0779378,2.9816885,,,,,,,,,,,,,, -141300,1.9116188,2.9046488,,,,,,,,,,,,,, -141400,1.8175594,2.9939473,,,,,,,,,,,,,, -141500,2.1441233,4.1904416,,,,,,,,,,,,,, -141600,2.250171,2.92441,,,,,,,,,,,,,, -141700,1.9978142,3.000349,,,,,,,,,,,,,, -141770,,,0.849609375,0.7915339469909668,0.7606199979782104,1.1642590761184692,50000.0,0.6408000588417053,1.749077081680298,10000.0,64738.19510626793,71288.30664849281,64738.19510626793,6536.446292638779,6.2816667556762695,0.0 -141800,1.9571224,3.5429583,,,,,,,,,,,,,, -141900,1.9310905,2.9575896,,,,,,,,,,,,,, -142000,1.9499627,3.978919,,,,,,,,,,,,,, -142100,1.9699972,3.9389799,,,,,,,,,,,,,, -142200,2.1176083,2.9701626,,,,,,,,,,,,,, -142300,2.0431333,2.9746594,,,,,,,,,,,,,, -142400,2.206366,4.2996855,,,,,,,,,,,,,, -142500,2.054417,3.8131118,,,,,,,,,,,,,, -142600,1.9493104,2.9412456,,,,,,,,,,,,,, -142689,,,0.8543359041213989,0.7828488945960999,0.7647799849510193,1.1637057065963743,50000.0,0.6442000269889832,1.7491698265075684,10000.0,65158.11840724945,71755.92424988747,65158.11840724945,6584.042410612106,6.332995891571045,0.0 -142700,2.0424745,2.9180887,,,,,,,,,,,,,, -142800,2.2798252,3.918193,,,,,,,,,,,,,, -142900,2.117705,2.9817195,,,,,,,,,,,,,, -143000,2.2978663,4.194566,,,,,,,,,,,,,, -143100,1.9749317,3.1245365,,,,,,,,,,,,,, -143200,2.0390565,4.0818825,,,,,,,,,,,,,, -143300,1.9121054,3.1322114,,,,,,,,,,,,,, -143400,1.9656188,2.9598022,,,,,,,,,,,,,, -143500,2.046633,2.9261293,,,,,,,,,,,,,, -143600,1.867515,3.4210699,,,,,,,,,,,,,, -143611,,,0.86376953125,0.744042694568634,0.7639600038528442,1.1600041389465332,50000.0,0.643500030040741,1.7458293437957764,10000.0,65578.4606962204,72221.8099834919,65578.4606962204,6629.484586715698,6.386404514312744,0.0 -143700,2.028085,3.1304624,,,,,,,,,,,,,, -143800,2.1180396,3.0643249,,,,,,,,,,,,,, -143900,1.9773972,2.902999,,,,,,,,,,,,,, -144000,1.9531999,2.9749403,,,,,,,,,,,,,, -144100,2.1964667,3.995825,,,,,,,,,,,,,, -144200,1.994021,3.3499577,,,,,,,,,,,,,, -144300,2.049598,3.2100127,,,,,,,,,,,,,, -144400,2.0455184,2.871349,,,,,,,,,,,,,, -144500,1.9614855,2.9841568,,,,,,,,,,,,,, -144532,,,0.8549999594688416,0.7654373645782471,0.7643399834632874,1.1485599279403689,50000.0,0.6429000496864319,1.74350905418396,10000.0,65998.61735081673,72687.37620162964,65998.61735081673,6674.7976150512695,6.435186386108398,0.0 -144600,2.1599102,3.7855706,,,,,,,,,,,,,, -144700,2.0035393,3.3809655,,,,,,,,,,,,,, -144800,2.069762,3.2016542,,,,,,,,,,,,,, -144900,1.9803793,3.599033,,,,,,,,,,,,,, -145000,1.8623179,3.2639194,,,,,,,,,,,,,, -145100,1.9438987,3.1765995,,,,,,,,,,,,,, -145200,2.0377162,3.017375,,,,,,,,,,,,,, -145300,2.1640348,3.1462028,,,,,,,,,,,,,, -145400,2.160408,4.077531,,,,,,,,,,,,,, -145453,,,0.8571484088897705,0.7539299130439758,0.7656399607658386,1.1399657726287842,50000.0,0.6452000141143799,1.722506761550903,10000.0,66418.83790254593,73152.03741383553,66418.83790254593,6719.138984918594,6.487768888473511,0.0 -145500,1.97998,3.2183554,,,,,,,,,,,,,, -145600,2.7017877,4.4628506,,,,,,,,,,,,,, -145700,2.083306,2.9611535,,,,,,,,,,,,,, -145800,2.0054293,2.8799362,,,,,,,,,,,,,, -145900,2.004087,2.948021,,,,,,,,,,,,,, -146000,2.1399271,2.9278624,,,,,,,,,,,,,, -146100,1.9869769,3.6942534,,,,,,,,,,,,,, -146200,2.2777066,4.355592,,,,,,,,,,,,,, -146300,1.9725262,2.8266673,,,,,,,,,,,,,, -146374,,,0.8620703220367432,0.7350453734397888,0.7666599750518799,1.1349040269851685,50000.0,0.6457000374794006,1.715340256690979,10000.0,66838.8134508133,73617.8742146492,66838.8134508133,6764.900428771973,6.540493726730347,0.0 -146400,2.4463217,4.4021444,,,,,,,,,,,,,, -146500,2.220702,2.9126403,,,,,,,,,,,,,, -146600,2.318095,4.2556534,,,,,,,,,,,,,, -146700,2.0716765,2.9351432,,,,,,,,,,,,,, -146800,2.045705,2.893371,,,,,,,,,,,,,, -146900,2.1872869,2.8945785,,,,,,,,,,,,,, -147000,1.9363421,3.0682564,,,,,,,,,,,,,, -147100,2.1663737,2.9121916,,,,,,,,,,,,,, -147200,2.195369,3.385835,,,,,,,,,,,,,, -147297,,,0.8604687452316284,0.7560346722602844,0.766759991645813,1.1506201028823853,50000.0,0.6461000442504883,1.7426823377609253,10000.0,67258.91440343857,74083.21685099602,67258.91440343857,6810.045199871063,6.589900016784668,0.0 -147300,1.9672402,2.861823,,,,,,,,,,,,,, -147400,2.03237,3.5675893,,,,,,,,,,,,,, -147500,2.4019122,4.0884247,,,,,,,,,,,,,, -147600,1.9499258,3.4966946,,,,,,,,,,,,,, -147700,2.2319522,3.246724,,,,,,,,,,,,,, -147800,2.1188283,2.9364052,,,,,,,,,,,,,, -147900,2.1259353,2.8799524,,,,,,,,,,,,,, -148000,2.1104956,3.1551423,,,,,,,,,,,,,, -148100,2.3598628,4.054304,,,,,,,,,,,,,, -148200,2.131731,2.9258876,,,,,,,,,,,,,, -148218,,,0.8648632764816284,0.7450129389762878,0.766979992389679,1.1421996355056765,50000.0,0.6454000473022461,1.7250229120254517,10000.0,67679.01679587364,74546.50698709488,67679.01679587364,6853.139424085617,6.635733366012573,0.0 -148300,2.105492,2.9143336,,,,,,,,,,,,,, -148400,2.1639862,2.9387329,,,,,,,,,,,,,, -148500,2.2273445,4.2650013,,,,,,,,,,,,,, -148600,2.1144905,2.9768732,,,,,,,,,,,,,, -148700,2.03324,3.181364,,,,,,,,,,,,,, -148800,2.087462,2.9349666,,,,,,,,,,,,,, -148900,2.1339703,3.233653,,,,,,,,,,,,,, -149000,2.579271,4.287931,,,,,,,,,,,,,, -149100,2.0967379,2.9384851,,,,,,,,,,,,,, -149138,,,0.8698437213897705,0.7154492139816284,0.768839955329895,1.1309633255004885,50000.0,0.6490000486373901,1.7095143795013428,10000.0,68098.65920686722,75011.86696529388,68098.65920686722,6898.35661482811,7.087894439697266,0.0 -149200,1.9589629,3.0811856,,,,,,,,,,,,,, -149300,2.0873482,3.5407739,,,,,,,,,,,,,, -149400,2.03235,3.2981286,,,,,,,,,,,,,, -149500,2.4228904,4.130301,,,,,,,,,,,,,, -149600,2.2233195,3.8499951,,,,,,,,,,,,,, -149700,2.1555007,2.974531,,,,,,,,,,,,,, -149800,2.1206503,2.8437443,,,,,,,,,,,,,, -149900,2.1266692,2.9536595,,,,,,,,,,,,,, -150000,2.2196124,2.9554238,,,,,,,,,,,,,, -150058,,,0.8729296922683716,0.7006752490997314,0.7681399583816528,1.1254583597183228,50000.0,0.650700032711029,1.709939956665039,10000.0,68518.84233808517,75477.76089262962,68518.84233808517,6943.96710062027,7.140713453292847,0.0 -150100,2.1872993,2.9193497,,,,,,,,,,,,,, -150200,2.55167,4.0853853,,,,,,,,,,,,,, -150300,2.665245,4.3244443,,,,,,,,,,,,,, -150400,2.1270893,2.9957218,,,,,,,,,,,,,, -150500,2.0585601,2.9168715,,,,,,,,,,,,,, -150600,2.3070805,2.996108,,,,,,,,,,,,,, -150700,2.5224922,4.0768414,,,,,,,,,,,,,, -150800,2.120166,3.1940203,,,,,,,,,,,,,, -150900,2.4506946,4.0418367,,,,,,,,,,,,,, -150978,,,0.8682031035423279,0.7284000515937805,0.770039975643158,1.1370346546173096,50000.0,0.6529000401496887,1.707403540611267,10000.0,68939.2911374569,75941.07792234421,68939.2911374569,6986.741266012192,7.1879754066467285,0.0 -151000,2.1020155,2.9051213,,,,,,,,,,,,,, -151100,2.2983942,2.8449633,,,,,,,,,,,,,, -151200,2.15614,2.9573753,,,,,,,,,,,,,, -151300,2.089453,3.0455434,,,,,,,,,,,,,, -151400,2.1735296,2.8111944,,,,,,,,,,,,,, -151500,2.14513,3.1626825,,,,,,,,,,,,,, -151600,2.278865,3.3187287,,,,,,,,,,,,,, -151700,2.119038,3.2684631,,,,,,,,,,,,,, -151800,2.302698,3.0211864,,,,,,,,,,,,,, -151900,2.384753,3.9126792,,,,,,,,,,,,,, -151901,,,0.8696093559265137,0.7336189150810242,0.7696399688720703,1.1401625871658323,50000.0,0.6514000296592712,1.7238764762878418,10000.0,69359.93056178093,76403.58780241013,69359.93056178093,7028.515541791916,7.23703145980835,0.0 -152000,2.1304688,2.8958337,,,,,,,,,,,,,, -152100,2.0343285,3.4954667,,,,,,,,,,,,,, -152200,2.6745756,4.02057,,,,,,,,,,,,,, -152300,2.0446897,2.9198828,,,,,,,,,,,,,, -152400,2.1920283,2.8303068,,,,,,,,,,,,,, -152500,2.255705,2.8535864,,,,,,,,,,,,,, -152600,2.3830009,2.8745067,,,,,,,,,,,,,, -152700,2.6222677,4.170237,,,,,,,,,,,,,, -152800,2.0128758,3.345957,,,,,,,,,,,,,, -152821,,,0.8742968440055847,0.6973341703414917,0.7716000080108643,1.1223145723342896,50000.0,0.6492000222206116,1.7144511938095093,10000.0,69779.99265003204,76866.06092453003,69779.99265003204,7070.822064638138,7.29412579536438,0.0 -152900,2.0908496,2.896894,,,,,,,,,,,,,, -153000,2.1113727,3.3105848,,,,,,,,,,,,,, -153100,2.2345166,2.8842993,,,,,,,,,,,,,, -153200,2.1908972,2.9013646,,,,,,,,,,,,,, -153300,2.1476974,2.9752886,,,,,,,,,,,,,, -153400,2.4277866,2.9490762,,,,,,,,,,,,,, -153500,2.2883244,2.9402137,,,,,,,,,,,,,, -153600,2.2456818,2.8178568,,,,,,,,,,,,,, -153700,3.2369301,4.2155,,,,,,,,,,,,,, -153740,,,0.8701757788658142,0.7270835041999817,0.7702800035476685,1.1377615928649902,50000.0,0.6533000469207764,1.7147181034088137,10000.0,70200.02810454369,77328.01024198532,70200.02810454369,7112.63879776001,7.344912052154541,0.0 -153800,2.5732455,4.033617,,,,,,,,,,,,,, -153900,2.2462854,3.1182501,,,,,,,,,,,,,, -154000,2.3039465,2.8902683,,,,,,,,,,,,,, -154100,2.369002,2.8647587,,,,,,,,,,,,,, -154200,2.0183325,2.9508798,,,,,,,,,,,,,, -154300,2.2050745,2.8976536,,,,,,,,,,,,,, -154400,2.3436704,3.4074018,,,,,,,,,,,,,, -154500,2.4881215,3.6888466,,,,,,,,,,,,,, -154600,2.2178667,3.8368173,,,,,,,,,,,,,, -154661,,,0.873828113079071,0.7049767971038818,0.7745800018310547,1.1225615739822388,50000.0,0.659000039100647,1.699967861175537,10000.0,70620.03921198845,77790.1883494854,70620.03921198845,7154.707575559616,7.395971059799194,0.0 -154700,2.1529062,2.802483,,,,,,,,,,,,,, -154800,2.2890143,3.125982,,,,,,,,,,,,,, -154900,2.2915516,2.9402976,,,,,,,,,,,,,, -155000,2.3228376,2.8556101,,,,,,,,,,,,,, -155100,2.246933,2.873556,,,,,,,,,,,,,, -155200,2.1451335,3.0025892,,,,,,,,,,,,,, -155300,2.5955384,4.3163295,,,,,,,,,,,,,, -155400,2.2245035,2.8368883,,,,,,,,,,,,,, -155500,2.1317985,3.2464306,,,,,,,,,,,,,, -155577,,,0.8754296898841858,0.6761949062347412,0.7723399996757507,1.1067827939987185,50000.0,0.6546000242233276,1.675015568733215,10000.0,71040.05707144737,78257.04546570778,71040.05707144737,7201.448133707047,7.44741940498352,0.0 -155600,2.7830286,4.36758,,,,,,,,,,,,,, -155700,2.203053,3.501046,,,,,,,,,,,,,, -155800,2.437775,4.0201817,,,,,,,,,,,,,, -155900,2.3262758,3.2350912,,,,,,,,,,,,,, -156000,2.29427,3.1079566,,,,,,,,,,,,,, -156100,2.158679,2.8647473,,,,,,,,,,,,,, -156200,2.145486,2.7633283,,,,,,,,,,,,,, -156300,2.1770444,2.8168488,,,,,,,,,,,,,, -156400,2.3619485,2.9041557,,,,,,,,,,,,,, -156497,,,0.8728905916213989,0.7006733417510986,0.775439977645874,1.110991358757019,50000.0,0.6582000255584717,1.6929787397384644,10000.0,71460.33196163177,78725.32055997849,71460.33196163177,7249.346954345703,7.500328063964844,0.0 -156500,2.1067934,3.666547,,,,,,,,,,,,,, -156600,2.3662336,3.2057576,,,,,,,,,,,,,, -156700,2.2134,3.0658298,,,,,,,,,,,,,, -156800,2.811549,4.2615347,,,,,,,,,,,,,, -156900,2.2399724,2.8123317,,,,,,,,,,,,,, -157000,2.2906947,3.1638882,,,,,,,,,,,,,, -157100,2.2697065,2.919076,,,,,,,,,,,,,, -157200,2.1348293,3.3651724,,,,,,,,,,,,,, -157300,2.3043165,2.8833156,,,,,,,,,,,,,, -157400,2.2702904,3.4312997,,,,,,,,,,,,,, -157419,,,0.8755663633346558,0.6870554089546204,0.7741000056266785,1.113128900527954,50000.0,0.65420001745224,1.6954491138458252,10000.0,71880.53656959534,79187.9251627922,71880.53656959534,7291.650812864304,7.549129009246826,0.0 -157500,2.3110611,2.8634338,,,,,,,,,,,,,, -157600,2.5940986,3.9825826,,,,,,,,,,,,,, -157700,2.258865,2.8582728,,,,,,,,,,,,,, -157800,2.3529687,2.863533,,,,,,,,,,,,,, -157900,2.2180157,2.7711658,,,,,,,,,,,,,, -158000,2.1711707,2.7815652,,,,,,,,,,,,,, -158100,2.2872338,3.6881156,,,,,,,,,,,,,, -158200,2.2318738,2.9632306,,,,,,,,,,,,,, -158300,2.2206542,2.822152,,,,,,,,,,,,,, -158341,,,0.8795117139816284,0.6774889230728149,0.7754999995231628,1.105984091758728,50000.0,0.6583000421524048,1.681814670562744,10000.0,72300.74264979362,79649.6304256916,72300.74264979362,7333.049576044083,7.60271143913269,0.0 -158400,2.4902916,2.814622,,,,,,,,,,,,,, -158500,2.7811213,4.2044606,,,,,,,,,,,,,, -158600,2.2736626,3.0112052,,,,,,,,,,,,,, -158700,2.7375386,4.0592923,,,,,,,,,,,,,, -158800,2.1572819,3.232707,,,,,,,,,,,,,, -158900,2.2784274,2.7822857,,,,,,,,,,,,,, -159000,2.3132317,3.6145425,,,,,,,,,,,,,, -159100,2.3856347,3.847734,,,,,,,,,,,,,, -159200,2.176671,2.9143355,,,,,,,,,,,,,, -159260,,,0.8759960532188416,0.6908382773399353,0.775879979133606,1.1087701320648191,50000.0,0.6544000506401062,1.6908684968948364,10000.0,72721.07917380333,80115.90304541588,72721.07917380333,7378.887080192566,7.6535325050354,0.0 -159300,2.426841,2.9177628,,,,,,,,,,,,,, -159400,2.2651646,3.0904977,,,,,,,,,,,,,, -159500,2.3747857,2.7874148,,,,,,,,,,,,,, -159600,2.261821,2.8033605,,,,,,,,,,,,,, -159700,2.5573165,2.8153534,,,,,,,,,,,,,, -159800,2.482643,2.8986633,,,,,,,,,,,,,, -159900,2.3562725,3.7353024,,,,,,,,,,,,,, -160000,2.3820994,3.417375,,,,,,,,,,,,,, -160100,2.5152721,2.8215387,,,,,,,,,,,,,, -160179,,,0.8782421946525574,0.6889281868934631,0.7750399708747864,1.1090731620788574,50000.0,0.6558000445365906,1.691611409187317,10000.0,73141.11105513573,80583.63886260986,73141.11105513573,7426.482050657272,7.7148168087005615,0.0 -160200,2.3955958,3.1557097,,,,,,,,,,,,,, -160300,2.470175,3.8526576,,,,,,,,,,,,,, -160400,2.3122048,2.7902348,,,,,,,,,,,,,, -160500,2.864253,3.3001318,,,,,,,,,,,,,, -160600,2.3374455,3.421135,,,,,,,,,,,,,, -160700,2.5221567,3.7188287,,,,,,,,,,,,,, -160800,2.7283769,3.9817338,,,,,,,,,,,,,, -160900,2.3218582,2.7871237,,,,,,,,,,,,,, -161000,2.2958777,3.461269,,,,,,,,,,,,,, -161100,2.4480827,2.908873,,,,,,,,,,,,,, -161101,,,0.8794531226158142,0.6894750595092773,0.7768999934196472,1.1162922382354736,50000.0,0.6583000421524048,1.6969735622406006,10000.0,73561.23503017426,81049.8535144329,73561.23503017426,7472.475435972214,7.764083385467529,0.0 -161200,2.289637,2.7783058,,,,,,,,,,,,,, -161300,2.374333,2.8518894,,,,,,,,,,,,,, -161400,2.8359776,3.8767474,,,,,,,,,,,,,, -161500,2.2495325,2.7855556,,,,,,,,,,,,,, -161600,2.366574,2.86828,,,,,,,,,,,,,, -161700,2.417197,3.099364,,,,,,,,,,,,,, -161800,2.6097515,3.9111412,,,,,,,,,,,,,, -161900,2.1587455,2.818484,,,,,,,,,,,,,, -162000,2.5647593,3.4301846,,,,,,,,,,,,,, -162021,,,0.8860155940055847,0.6461138129234314,0.7784199714660645,1.0914040803909302,50000.0,0.6588000059127808,1.671212911605835,10000.0,73981.34681868553,81509.38480234146,73981.34681868553,7511.793575048447,7.818175554275513,0.0 -162100,2.4301183,2.7986925,,,,,,,,,,,,,, -162200,2.365014,2.8114386,,,,,,,,,,,,,, -162300,3.0692935,4.1962805,,,,,,,,,,,,,, -162400,2.4903848,2.8339417,,,,,,,,,,,,,, -162500,2.3969703,2.7401862,,,,,,,,,,,,,, -162600,2.559425,3.0208669,,,,,,,,,,,,,, -162700,2.3345935,3.171551,,,,,,,,,,,,,, -162800,2.984662,4.3409595,,,,,,,,,,,,,, -162900,2.4894967,2.8062775,,,,,,,,,,,,,, -162943,,,0.8807421922683716,0.6805300712585449,0.7788800001144409,1.106682062149048,50000.0,0.6619000434875488,1.6778147220611572,10000.0,74401.3993074894,81974.58531212807,74401.3993074894,7556.840314865112,7.872089624404907,0.0 -163000,2.632937,2.8658023,,,,,,,,,,,,,, -163100,2.3157153,2.7780752,,,,,,,,,,,,,, -163200,2.4555078,3.0351477,,,,,,,,,,,,,, -163300,2.3213632,3.4904897,,,,,,,,,,,,,, -163400,2.2398496,3.058636,,,,,,,,,,,,,, -163500,2.33821,2.7685225,,,,,,,,,,,,,, -163600,2.893083,4.1940784,,,,,,,,,,,,,, -163700,2.37861,2.8238404,,,,,,,,,,,,,, -163800,3.472524,2.8053222,,,,,,,,,,,,,, -163867,,,0.8832812309265137,0.6668259501457214,0.7802599668502808,1.0908979177474976,50000.0,0.6641000509262085,1.6616779565811155,10000.0,74821.57655787468,82442.0477347374,74821.57655787468,7604.018439769745,7.928055763244629,0.0 -163900,2.2360888,2.8115637,,,,,,,,,,,,,, -164000,2.5359173,2.82004,,,,,,,,,,,,,, -164100,2.6283836,3.6196034,,,,,,,,,,,,,, -164200,2.5378175,2.8966188,,,,,,,,,,,,,, -164300,2.5607612,2.8208117,,,,,,,,,,,,,, -164400,2.7130852,3.708105,,,,,,,,,,,,,, -164500,2.543474,2.7875664,,,,,,,,,,,,,, -164600,2.3848517,2.816094,,,,,,,,,,,,,, -164700,2.4905915,2.7953143,,,,,,,,,,,,,, -164789,,,0.887988269329071,0.6550906896591187,0.7813000082969666,1.0903257131576538,50000.0,0.6593000292778015,1.6780301332473757,10000.0,75241.48506331444,82908.00784730911,75241.48506331444,7649.971467733383,7.979364395141602,0.0 -164800,2.3687959,2.7736435,,,,,,,,,,,,,, -164900,2.242003,2.925533,,,,,,,,,,,,,, -165000,2.3536482,2.8394854,,,,,,,,,,,,,, -165100,2.7037766,2.8696465,,,,,,,,,,,,,, -165200,2.3952475,2.8411105,,,,,,,,,,,,,, -165300,3.2687118,4.2293253,,,,,,,,,,,,,, -165400,2.637173,2.8068943,,,,,,,,,,,,,, -165500,2.3063593,3.120184,,,,,,,,,,,,,, -165600,2.5142295,2.8944538,,,,,,,,,,,,,, -165700,3.045177,2.80574,,,,,,,,,,,,,, -165714,,,0.8838671445846558,0.6711214780807495,0.7797999978065491,1.0984795093536377,50000.0,0.663100004196167,1.6789699792861938,10000.0,75661.61553740501,83369.6378827095,75661.61553740501,7691.369910478592,8.033177137374878,0.0 -165800,2.8461185,4.0835843,,,,,,,,,,,,,, -165900,3.2182257,4.3265944,,,,,,,,,,,,,, -166000,2.4574547,2.7876678,,,,,,,,,,,,,, -166100,2.514409,2.8901854,,,,,,,,,,,,,, -166200,2.461523,2.8120778,,,,,,,,,,,,,, -166300,2.375675,2.8433948,,,,,,,,,,,,,, -166400,2.9057293,4.070734,,,,,,,,,,,,,, -166500,2.7007406,2.8601465,,,,,,,,,,,,,, -166600,3.5091183,4.2833066,,,,,,,,,,,,,, -166636,,,0.8857812285423279,0.6534203290939331,0.7819399833679199,1.0902855396270752,50000.0,0.6624000072479248,1.6653751134872437,10000.0,76081.84141349792,83835.18001461029,76081.84141349792,7736.580994844437,8.090880393981934,0.0 -166700,2.5640116,2.7177315,,,,,,,,,,,,,, -166800,2.8919907,3.2732782,,,,,,,,,,,,,, -166900,2.3619826,3.2339566,,,,,,,,,,,,,, -167000,2.473737,2.7674146,,,,,,,,,,,,,, -167100,2.842168,3.097828,,,,,,,,,,,,,, -167200,2.5391212,2.8213768,,,,,,,,,,,,,, -167300,2.3647544,3.254026,,,,,,,,,,,,,, -167400,2.421854,3.4289327,,,,,,,,,,,,,, -167500,2.4890285,2.7337353,,,,,,,,,,,,,, -167556,,,0.8882226347923279,0.6422659754753113,0.7817999720573425,1.0872855186462402,50000.0,0.6653000116348267,1.656856894493103,10000.0,76501.81643271446,84299.6081571579,76501.81643271446,7780.93232178688,8.14522910118103,0.0 -167600,2.566124,2.886598,,,,,,,,,,,,,, -167700,3.3376415,4.2717505,,,,,,,,,,,,,, -167800,2.437671,2.7852592,,,,,,,,,,,,,, -167900,2.433059,2.882795,,,,,,,,,,,,,, -168000,2.4430623,2.7270577,,,,,,,,,,,,,, -168100,2.630074,3.054757,,,,,,,,,,,,,, -168200,2.6230788,2.8446949,,,,,,,,,,,,,, -168300,2.6795275,2.773734,,,,,,,,,,,,,, -168400,3.0986116,4.12706,,,,,,,,,,,,,, -168477,,,0.8875390291213989,0.6451851725578308,0.7829200029373169,1.0794780254364014,50000.0,0.6647000312805176,1.650307297706604,10000.0,76921.93040585518,84765.37529754639,76921.93040585518,7826.474480390549,8.208962202072144,0.0 -168500,2.528168,2.888667,,,,,,,,,,,,,, -168600,2.4804065,2.9173868,,,,,,,,,,,,,, -168700,2.826037,3.8386233,,,,,,,,,,,,,, -168800,3.25054,4.2389402,,,,,,,,,,,,,, -168900,2.5405247,2.7453747,,,,,,,,,,,,,, -169000,3.1227221,4.118157,,,,,,,,,,,,,, -169100,3.3289933,4.204841,,,,,,,,,,,,,, -169200,2.5244112,2.8043156,,,,,,,,,,,,,, -169300,2.4857123,2.8706741,,,,,,,,,,,,,, -169398,,,0.8874218463897705,0.6406536102294922,0.7831999659538269,1.078544020652771,50000.0,0.6653000116348267,1.6556425094604492,10000.0,77342.00812506676,85230.77464270592,77342.00812506676,7871.694277763367,8.263505935668945,0.0 -169400,2.6170006,2.7751117,,,,,,,,,,,,,, -169500,2.7427251,3.6144245,,,,,,,,,,,,,, -169600,2.3983839,2.7268941,,,,,,,,,,,,,, -169700,2.3932989,2.8192587,,,,,,,,,,,,,, -169795,,,,,,,,,,,77520.01207590103,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 8847b8fd4..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -29.060741901397705,0.0,37.03280329704285,1,0,37.03280329704285,0.0010000000474974,6.907756805419922,10000,66.09364891052246,0.0009765625,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -72.75973510742188,0.0193541049957275,457.21945309638977,860,0,457.21945309638977,0.0108000002801418,6.47079610824585,10000,530.0452859401703,0.0125195309519767,6.41873836517334,0.0127199999988079,6.431779384613037,50000 -122.92927050590517,0.0499892234802246,877.4289219379425,1773,0,877.4289219379425,0.0303000006824731,6.007748603820801,10000,1000.501962184906,0.0384765602648258,5.859482288360596,0.0354399979114532,5.89152193069458,50000 -169.69586300849917,0.0778777599334716,1297.60599899292,2692,0,1297.60599899292,0.0446000024676322,5.647171974182129,10000,1467.5219156742096,0.0623632781207561,5.426368713378906,0.0602199994027614,5.460067272186279,50000 -204.232607126236,0.1020910739898681,1717.723210811615,3612,0,1717.723210811615,0.0625,5.394311428070068,10000,1922.2475941181185,0.0909960940480232,5.120298385620117,0.0842399969696998,5.156952857971191,50000 -249.70977568626404,0.1325874328613281,2138.045210123062,4511,0,2138.045210123062,0.0897000059485435,5.0951080322265625,10000,2388.123862028122,0.1286914050579071,4.747590065002441,0.119279995560646,4.806865692138672,50000 -296.49737071990967,0.1598045825958252,2558.279573202133,5431,0,2558.279573202133,0.1129000037908554,4.798599243164063,10000,2855.2203080654144,0.1690820306539535,4.369624137878418,0.1549399942159652,4.4540791511535645,50000 -335.4525935649872,0.1912572383880615,2978.6933076381683,6353,0,2978.6933076381683,0.14410001039505,4.515932083129883,10000,3314.668624162674,0.2099804580211639,4.027159690856934,0.1937799900770187,4.129310607910156,50000 -380.9338798522949,0.2187426090240478,3398.987615585327,7274,0,3398.987615585327,0.1762000024318695,4.282338619232178,10000,3780.519155740738,0.2498242110013961,3.753761768341065,0.2315599918365478,3.850430011749268,50000 -427.0207488536835,0.2449104785919189,3819.009551525116,8196,0,3819.009551525116,0.2035000026226043,4.090365886688232,10000,4246.702522277832,0.2862304747104645,3.5035860538482666,0.2628999948501587,3.623043775558472,50000 -472.77794551849365,0.2752766609191894,4239.054158687592,9114,0,4239.054158687592,0.2339000105857849,3.833681344985962,10000,4712.58321595192,0.3378515541553497,3.160243511199951,0.3047399818897247,3.333974599838257,50000 -519.3172528743744,0.3021481037139892,4659.012718915939,10037,0,4659.012718915939,0.2568000257015228,3.693822860717773,10000,5179.155911445618,0.3565820157527923,3.0313377380371094,0.3318600058555603,3.1532087326049805,50000 -565.0101907253265,0.3308203220367431,5079.202866315842,10960,0,5079.202866315842,0.2639000117778778,3.631510019302368,10000,5645.115402698517,0.3778125047683716,2.9234459400177,0.3477599918842315,3.085547685623169,50000 -606.9534866809845,0.3557870388031006,5499.275423049927,11881,0,5499.275423049927,0.2891000211238861,3.4765114784240723,10000,6107.204093456268,0.4100195169448852,2.730732679367065,0.3714399933815002,2.922337532043457,50000 -653.2269532680511,0.3829104900360107,5919.309090852737,12800,0,5919.309090852737,0.299200028181076,3.3898091316223145,10000,6573.585469007492,0.4242968559265136,2.650355815887451,0.3928999900817871,2.803166151046753,50000 -697.1185910701752,0.8888986110687256,6338.901314973831,13720,0,6338.901314973831,0.3100000023841858,3.3641951084136963,10000,7037.623072147369,0.4307031035423279,2.6292409896850586,0.3979800045490265,2.789198398590088,50000 -733.4783155918121,0.914440393447876,6759.135899543762,14641,0,6759.135899543762,0.3162000179290771,3.266141653060913,10000,7494.290160655975,0.4467968642711639,2.5055689811706543,0.4116399884223938,2.6887874603271484,50000 -770.1984448432922,0.9485998153686525,7179.2973692417145,15556,0,7179.2973692417145,0.3327000141143799,3.171525955200196,10000,7951.254838228226,0.4633203148841858,2.4123597145080566,0.4307799935340881,2.577693462371826,50000 -819.353661775589,0.9791200160980223,7599.65398979187,16474,0,7599.65398979187,0.3393000066280365,3.1288845539093018,10000,8420.844632863998,0.4699999988079071,2.36291766166687,0.4378999769687652,2.5307412147521973,50000 -861.3643939495087,1.0059688091278076,8019.693618297577,17397,0,8019.693618297577,0.3379000127315521,3.118520498275757,10000,8882.96993303299,0.47802734375,2.3279898166656494,0.4379799962043762,2.5215203762054443,50000 -905.2010207176208,1.0355210304260254,8439.754869222641,18317,0,8439.754869222641,0.3397000133991241,3.14664888381958,10000,9346.944638490677,0.4999414086341858,2.242021322250366,0.4409199953079223,2.529330015182495,50000 -950.6946420669556,1.0625567436218262,8860.098046064377,19236,0,8860.098046064377,0.3473000228404999,3.072669744491577,10000,9812.855647802353,0.4896093606948852,2.2664780616760254,0.4534199833869934,2.4464144706726074,50000 -992.1189947128296,1.0902252197265625,9280.475875854492,20159,0,9280.475875854492,0.3519000113010406,3.0374844074249268,10000,10274.733105897903,0.4975976347923279,2.2221951484680176,0.461139976978302,2.4042558670043945,50000 -1038.3150515556335,1.1201841831207275,9700.750081539154,21080,0,9700.750081539154,0.3665000200271606,2.948438882827759,10000,10741.280704021454,0.5208203196525574,2.0787808895111084,0.4711199998855591,2.326721429824829,50000 -1084.7075653076172,1.152724266052246,10121.19642019272,22001,0,10121.19642019272,0.3775000274181366,2.8965604305267334,10000,11208.20026922226,0.5149999856948853,2.116328239440918,0.4795799851417541,2.289320707321167,50000 -1129.569720506668,1.1803789138793943,10541.374686002731,22922,0,10541.374686002731,0.3761000037193298,2.9281606674194336,10000,11673.316876888275,0.521191418170929,2.095719814300537,0.485179990530014,2.2728850841522217,50000 -1174.8933689594269,1.210268497467041,10961.407228469849,23841,0,10961.407228469849,0.3845000267028808,2.868086576461792,10000,12138.750408172607,0.5325781106948853,2.0169644355773926,0.4902399778366089,2.243086338043213,50000 -1222.1330163478851,1.244988203048706,11381.509302854538,24762,0,11381.509302854538,0.3880000114440918,2.834979295730591,10000,12606.17445731163,0.5335351228713989,2.0236427783966064,0.4941200017929077,2.203388214111328,50000 -1264.5435791015625,1.27274751663208,11801.537393569946,25685,0,11801.537393569946,0.3906000256538391,2.8306357860565186,10000,13068.69039440155,0.5403710603713989,2.013492584228516,0.5018999576568604,2.1954407691955566,50000 -1309.7079060077667,1.3029565811157229,12221.88210463524,26604,0,12221.88210463524,0.3939000070095062,2.7954912185668945,10000,13534.276743412018,0.5479491949081421,1.9419466257095337,0.5062199831008911,2.1494569778442383,50000 -1349.8739371299744,1.3320088386535645,12642.213397979736,27527,0,12642.213397979736,0.4063000082969665,2.749185562133789,10000,13994.85052037239,0.5578905940055847,1.9007580280303955,0.51419997215271,2.116635322570801,50000 -1395.4143552780151,1.3635451793670654,13062.25780081749,28448,0,13062.25780081749,0.4115000069141388,2.720576763153076,10000,14460.51496386528,0.5541796684265137,1.9155114889144893,0.5175999999046326,2.0890777111053467,50000 -1440.4584770202637,1.391657829284668,13482.429998636246,29369,0,13482.429998636246,0.4068000316619873,2.7158117294311523,10000,14925.807185173036,0.5678515434265137,1.865001916885376,0.5196200013160706,2.0869064331054688,50000 -1486.9760718345642,1.4241070747375488,13902.37486410141,30290,0,13902.37486410141,0.4129000306129455,2.727597951889038,10000,15392.34959602356,0.5793359279632568,1.8263520002365112,0.5206999778747559,2.103342294692993,50000 -1530.2789916992188,1.456043720245361,14322.354754447935,31212,0,14322.354754447935,0.4142000079154968,2.687079906463623,10000,15855.711465358734,0.564160168170929,1.8538535833358765,0.527899980545044,2.041722297668457,50000 -1575.9266781806946,1.4843404293060305,14742.48341703415,32134,0,14742.48341703415,0.4193000197410583,2.66968321800232,10000,16321.563993692398,0.5700390338897705,1.8478858470916748,0.531279981136322,2.034501314163208,50000 -1622.5681648254397,1.513521432876587,15162.777776002884,33052,0,15162.777776002884,0.4118000268936157,2.7029848098754883,10000,16788.576245307922,0.5787890553474426,1.8214852809906008,0.5297799706459045,2.06502366065979,50000 -1668.3824818134308,1.5456140041351318,15582.711037397385,33973,0,15582.711037397385,0.4244000315666199,2.666775941848755,10000,17254.40386199951,0.5707616806030273,1.8383631706237795,0.5380600094795227,2.0063157081604004,50000 -1714.5101835727692,1.5746049880981443,16003.069641828535,34894,0,16003.069641828535,0.4220000207424164,2.628847360610962,10000,17720.966168165207,0.5763476490974426,1.8116711378097528,0.5345199704170227,2.015291452407837,50000 -1762.5099685192108,1.608259916305542,16423.170471429825,35814,0,16423.170471429825,0.4217000305652618,2.6568825244903564,10000,18189.147683382034,0.581250011920929,1.7759336233139038,0.5402399897575378,1.991693377494812,50000 -1806.7723808288567,1.6376621723175049,16843.52310347557,36737,0,16843.52310347557,0.4239000082015991,2.632972478866577,10000,18653.83980345726,0.5838086009025574,1.7920989990234375,0.5418800115585327,1.9881590604782104,50000 -1853.3759117126465,1.6815898418426514,17263.79153227806,37657,0,17263.79153227806,0.4320000112056732,2.602356672286988,10000,19120.8027510643,0.5879687070846558,1.7495322227478027,0.5480200052261353,1.946392297744751,50000 -1897.910856246948,1.71612548828125,17684.12568449974,38580,0,17684.12568449974,0.4335000216960907,2.5650787353515625,10000,19585.754362106323,0.5942773222923279,1.714832067489624,0.5480599999427795,1.926216721534729,50000 -1944.2633888721464,1.7545132637023926,18104.34602546692,39501,0,18104.34602546692,0.4365000128746032,2.580235958099365,10000,20052.413615226746,0.6201366782188416,1.614801287651062,0.5507599711418152,1.94295072555542,50000 -1992.8587412834167,1.7895410060882568,18524.72751617432,40422,0,18524.72751617432,0.430400013923645,2.604454517364502,10000,20521.472969293594,0.5883983969688416,1.758009433746338,0.5488199591636658,1.9514132738113403,50000 -2036.9278779029848,1.8252899646759035,18944.705970048904,41344,0,18944.705970048904,0.4415000081062317,2.542169570922852,10000,20985.605131864548,0.6026757955551147,1.6814193725585938,0.5562199950218201,1.8999054431915283,50000 -2085.58514547348,1.8638558387756348,19364.84474134445,42267,0,19364.84474134445,0.442300021648407,2.526688814163208,10000,21454.48744463921,0.6179882884025574,1.5960801839828491,0.560479998588562,1.867920994758606,50000 -2130.889495611191,1.8945605754852293,19785.094200372696,43187,0,19785.094200372696,0.4412000179290771,2.5503687858581543,10000,21920.119605779648,0.5931640267372131,1.7232056856155396,0.5553199648857117,1.9169193506240845,50000 -2179.0115325450897,1.9270873069763184,20205.310017347336,44108,0,20205.310017347336,0.4472000300884247,2.5132040977478027,10000,22388.53720092773,0.6003515720367432,1.6775927543640137,0.5586400032043457,1.8800852298736568,50000 -2225.3049223423004,1.9600763320922847,20625.409114599228,45030,0,20625.409114599228,0.4420000314712524,2.5277490615844727,10000,22855.010596990585,0.6131640672683716,1.6342071294784546,0.5626999735832214,1.8697541952133176,50000 -2271.9245131015778,1.9946684837341309,21045.712456464767,45952,0,21045.712456464767,0.4379000067710876,2.550467729568481,10000,23322.01553273201,0.6003320217132568,1.708828091621399,0.5588200092315674,1.894186973571777,50000 -2320.201394081116,2.025162935256958,21465.93386626244,46874,0,21465.93386626244,0.4422000348567962,2.517657518386841,10000,23790.59200644493,0.6097265481948853,1.6467933654785156,0.5633599758148193,1.8516621589660645,50000 -2364.2496979236603,2.061084032058716,21886.19406223297,47793,0,21886.19406223297,0.4451000094413757,2.5339388847351074,10000,24254.98417496681,0.6098241806030273,1.677674412727356,0.5647000074386597,1.89734959602356,50000 -2412.188717842102,2.0992519855499268,22306.34263277054,48711,0,22306.34263277054,0.443200021982193,2.5331673622131348,10000,24723.15763783455,0.6181249618530273,1.6265902519226074,0.5595600008964539,1.886021375656128,50000 -2459.936644077301,2.133064985275269,22726.50920295716,49631,0,22726.50920295716,0.444100022315979,2.522622585296631,10000,25191.15389060974,0.6079882383346558,1.6691138744354248,0.5669800043106079,1.862606406211853,50000 -2506.8026852607727,2.166099071502685,23146.88272571564,50554,0,23146.88272571564,0.4585000276565552,2.471269369125366,10000,25658.474204063416,0.6143358945846558,1.6266354322433472,0.5703999996185303,1.8336740732192995,50000 -2550.7530856132507,2.205763101577759,23566.82369351387,51474,0,23566.82369351387,0.4506000280380249,2.467916250228882,10000,26122.45309472084,0.6382616758346558,1.5179091691970823,0.5744999647140503,1.8161741495132449,50000 -2597.088151693344,2.242366075515747,23987.047943353653,52393,0,23987.047943353653,0.4531000256538391,2.4591097831726074,10000,26589.09598493576,0.615527331829071,1.6107457876205444,0.5764999985694885,1.7930006980895996,50000 -2645.788686990738,2.2893104553222656,24407.272254943848,53314,0,24407.272254943848,0.4622000157833099,2.461902141571045,10000,27058.116228818893,0.6172460913658142,1.6188424825668335,0.5728799700737,1.8209742307662964,50000 -2693.9905047416687,2.3210394382476807,24827.344648122787,54234,0,24827.344648122787,0.4579000174999237,2.464775800704956,10000,27526.469512939453,0.6295117139816284,1.57720947265625,0.572219967842102,1.8321927785873413,50000 -2742.075278520584,2.35662841796875,25247.712097883224,55152,0,25247.712097883224,0.46670001745224,2.4314701557159424,10000,27995.00447773933,0.6147655844688416,1.6305612325668335,0.5796599984169006,1.788099765777588,50000 -2785.8978073596954,2.391852140426636,25667.8112885952,56070,0,25667.8112885952,0.4645000100135803,2.429868221282959,10000,28459.00898051262,0.6223242282867432,1.5823334455490112,0.5771600008010864,1.789430856704712,50000 -2834.294125318527,2.429672718048096,26088.068242549896,56989,0,26088.068242549896,0.457800030708313,2.4639346599578857,10000,28927.74728178978,0.6273437142372131,1.5826400518417358,0.5783799886703491,1.807153582572937,50000 -2880.826591491699,2.475525379180908,26508.35567474365,57909,0,26508.35567474365,0.4639000296592712,2.410048246383667,10000,29394.660277605057,0.6226171851158142,1.5738970041275024,0.578719973564148,1.7816462516784668,50000 -2927.505961894989,2.509864330291748,26928.282299280167,58827,0,26928.282299280167,0.4725000262260437,2.381237268447876,10000,29861.34847187996,0.6302929520606995,1.5352495908737185,0.5891000032424927,1.726948857307434,50000 -2976.0984501838684,2.5432207584381104,27348.234403848648,59747,0,27348.234403848648,0.4588000178337097,2.400780200958252,10000,30329.973765850067,0.6289257407188416,1.5401300191879272,0.581559956073761,1.7661179304122925,50000 -3022.2227504253387,2.578721761703491,27768.30704021454,60667,0,27768.30704021454,0.4660000205039978,2.385641574859619,10000,30796.25399804116,0.6617382764816284,1.412854790687561,0.5889999866485596,1.731900691986084,50000 -3070.1183915138245,2.6172053813934326,28188.32996058464,61583,0,28188.32996058464,0.4729000329971313,2.363012313842773,10000,31264.258637428284,0.6290038824081421,1.5284565687179563,0.5866400003433228,1.7359672784805298,50000 -3116.9010808467865,2.6603078842163086,28608.93250131607,62501,0,28608.93250131607,0.4690000116825104,2.3870527744293213,10000,31731.734622240067,0.6339648365974426,1.5344921350479126,0.5881800055503845,1.743263602256775,50000 -3165.640180826187,2.699095726013184,29029.071456193924,63421,0,29029.071456193924,0.4687000215053558,2.4003686904907227,10000,32200.69903111458,0.6504101157188416,1.453066349029541,0.5853399634361267,1.746640682220459,50000 -3212.613987445832,2.732761144638061,29449.20278072357,64341,0,29449.20278072357,0.4739000201225281,2.3589258193969727,10000,32667.884612083435,0.6320703029632568,1.5366462469100952,0.588979959487915,1.7273471355438232,50000 -3258.559402942657,2.7715091705322266,29869.36788988113,65261,0,29869.36788988113,0.4754000306129455,2.375885486602783,10000,33134.08188343048,0.6381250023841858,1.5205624103546145,0.5915200114250183,1.7391788959503174,50000 -3304.079433441162,2.809138774871826,30289.691901922222,66184,0,30289.691901922222,0.4763000309467315,2.3564023971557617,10000,33600.01197576523,0.6523046493530273,1.447373628616333,0.595259964466095,1.7073568105697632,50000 -3351.523953676224,2.8487424850463867,30710.14184713364,67101,0,30710.14184713364,0.4821000099182129,2.307164669036865,10000,34067.99260735512,0.6419531106948853,1.4803026914596558,0.6005600094795227,1.6780091524124146,50000 -3397.4480471611023,2.8839402198791504,31130.390008687973,68020,0,31130.390008687973,0.4836000204086303,2.322556495666504,10000,34534.2478351593,0.6422265768051147,1.4910420179367063,0.5968599915504456,1.6952730417251587,50000 -3442.4561920166016,2.921594381332397,31550.531436681747,68940,0,31550.531436681747,0.4816000163555145,2.361830711364746,10000,34999.48217225075,0.646289050579071,1.4966928958892822,0.5961999893188477,1.7258793115615845,50000 -3491.027411222458,2.957694292068481,31970.809402942657,69862,0,31970.809402942657,0.4838000237941742,2.3317134380340576,10000,35468.415583610535,0.6507031321525574,1.4504530429840088,0.6015200018882751,1.6813979148864746,50000 -3537.5296635627747,2.994217157363892,32390.82547283173,70783,0,32390.82547283173,0.4803000092506408,2.354217767715454,10000,35935.01786708832,0.6480273008346558,1.50771164894104,0.600879967212677,1.7230969667434692,50000 -3585.316800355912,3.033379316329956,32810.79419326782,71703,0,32810.79419326782,0.4830000102519989,2.303077459335327,10000,36402.86121177673,0.656054675579071,1.4267289638519287,0.6021999716758728,1.6627541780471802,50000 -3632.553422927856,3.0752978324890137,33231.07704329491,72620,0,33231.07704329491,0.4906000196933746,2.331911087036133,10000,36870.46979093552,0.6702929735183716,1.386479735374451,0.605679988861084,1.6824146509170532,50000 -3679.482887744904,3.113532781600952,33651.26502633095,73540,0,33651.26502633095,0.4826000332832336,2.3274142742156982,10000,37337.673639297485,0.6487694978713989,1.481659770011902,0.6056999564170837,1.6786365509033203,50000 -3729.273057460785,3.151975631713867,34071.27830410004,74460,0,34071.27830410004,0.4852000176906585,2.282309293746948,10000,37807.56340956688,0.650585949420929,1.4430195093154907,0.6067399978637695,1.657494306564331,50000 -3776.994529485704,3.188735008239746,34491.6043074131,75382,0,34491.6043074131,0.486700028181076,2.301242113113404,10000,38275.6949763298,0.6629296541213989,1.3945977687835691,0.6041199564933777,1.6651949882507324,50000 -3824.521594762802,3.2313475608825684,34911.860114097595,76303,0,34911.860114097595,0.4861000180244446,2.312741279602051,10000,38743.56843471527,0.6522851586341858,1.4518017768859863,0.6080399751663208,1.6486347913742063,50000 -3871.440359354019,3.267023801803589,35331.959324359894,77223,0,35331.959324359894,0.4901000261306762,2.283908605575561,10000,39210.66960573197,0.6526171565055847,1.4243426322937012,0.605459988117218,1.6445677280426023,50000 -3920.543641090393,3.3071439266204834,35752.245290756226,78143,0,35752.245290756226,0.4957000315189361,2.271117210388184,10000,39680.14678049088,0.6633203029632568,1.4052098989486694,0.6102399826049805,1.643385887145996,50000 -3968.289653062821,3.3433430194854736,36172.618824243546,79065,0,36172.618824243546,0.4962000250816345,2.260464668273926,10000,40148.34972453117,0.6604882478713989,1.4064991474151611,0.6160799860954285,1.6271092891693115,50000 -4017.691013813019,3.384563684463501,36592.68895483017,79985,0,36592.68895483017,0.4923000335693359,2.274275779724121,10000,40617.90951514244,0.6540429592132568,1.444997787475586,0.6112200021743774,1.641564965248108,50000 -4066.0741584301,3.4339022636413574,37012.93760061264,80905,0,37012.93760061264,0.4937000274658203,2.255664110183716,10000,41086.63825082779,0.6605077981948853,1.3962000608444214,0.6113799810409546,1.6233501434326172,50000 -4114.506975889206,3.4718174934387207,37433.07783913613,81826,0,37433.07783913613,0.503600001335144,2.2272136211395264,10000,41555.29655838013,0.6890624761581421,1.2909141778945925,0.6201599836349487,1.59787118434906,50000 -4157.234240293503,3.51223087310791,37853.40079331398,82746,0,37853.40079331398,0.4896000325679779,2.2967171669006348,10000,42018.43554711342,0.6591405868530273,1.4403164386749268,0.6159799695014954,1.641635537147522,50000 -4206.504290103912,3.5590643882751465,38273.69846391678,83666,0,38273.69846391678,0.4989000260829925,2.235353469848633,10000,42488.09752130509,0.663769543170929,1.3689274787902832,0.616320013999939,1.5968698263168335,50000 -4253.889084339142,3.6006920337677,38693.85767745972,84587,0,38693.85767745972,0.4974000155925751,2.218472480773926,10000,42955.73055052757,0.6851366758346558,1.2849160432815552,0.6199399828910828,1.5707097053527832,50000 -4301.529177188873,3.6426284313201904,39114.19425010681,85507,0,39114.19425010681,0.4985000193119049,2.248077869415283,10000,43423.79737305641,0.6639453172683716,1.3985720872879028,0.6188200116157532,1.608765721321106,50000 -4349.79114151001,3.6794145107269287,39534.38470721245,86427,0,39534.38470721245,0.5034000277519226,2.218822717666626,10000,43892.33448171616,0.673535168170929,1.353287935256958,0.6249600052833557,1.573027729988098,50000 -4398.92951130867,3.722033023834229,39954.31283664704,87345,0,39954.31283664704,0.5062000155448914,2.2032463550567627,10000,44361.49378180504,0.6821679472923279,1.2937649488449097,0.6266599893569946,1.55876624584198,50000 -4446.9581387043,3.759573459625244,40374.37486696243,88262,0,40374.37486696243,0.510200023651123,2.181927442550659,10000,44829.66952109337,0.6767578125,1.3330546617507937,0.6308199763298035,1.5420633554458618,50000 -4495.342826843262,3.800304889678955,40794.69830536842,89181,0,40794.69830536842,0.509600043296814,2.215367317199707,10000,45298.46662902832,0.6698632836341858,1.3340215682983398,0.6282399892807007,1.5469225645065308,50000 -4540.608873128891,3.844949960708618,41214.659896850586,90099,0,41214.659896850586,0.5063000321388245,2.2054131031036377,10000,45763.786215782166,0.68359375,1.305981159210205,0.6265999674797058,1.5628786087036133,50000 -4589.280221223831,3.891319274902344,41635.00652647018,91018,0,41635.00652647018,0.5110000371932983,2.183170795440674,10000,46232.89801955223,0.6847070455551147,1.3044037818908691,0.6290199756622314,1.5476787090301514,50000 -4636.324400186539,3.93449640274048,42055.1633245945,91936,0,42055.1633245945,0.5093000531196594,2.173914194107056,10000,46700.18941116333,0.6799609065055847,1.3173344135284424,0.6291599869728088,1.534510850906372,50000 -4685.3404994010925,3.972641944885254,42475.19194078445,92858,0,42475.19194078445,0.5106000304222107,2.1962108612060547,10000,47169.31934523583,0.6845703125,1.30815589427948,0.6358000040054321,1.5391371250152588,50000 -4731.815329551697,4.010065317153931,42895.57374858856,93780,0,42895.57374858856,0.5103000402450562,2.147897481918335,10000,47636.26103925705,0.7090038657188416,1.1813215017318726,0.6354999542236328,1.5112653970718384,50000 -4780.633121490479,4.056293964385986,43315.627083063126,94699,0,43315.627083063126,0.516700029373169,2.158723831176758,10000,48105.22648501396,0.6853711009025574,1.2983242273330688,0.634939968585968,1.5220338106155396,50000 -4830.998948812485,4.0974836349487305,43735.60219955444,95619,0,43735.60219955444,0.5182000398635864,2.203728199005127,10000,48575.65565729141,0.6918359398841858,1.3312885761260986,0.6355999708175659,1.5759185552597046,50000 -4876.413905143738,4.138431787490845,44155.92951416969,96541,0,44155.92951416969,0.5205000042915344,2.156029462814331,10000,49041.48910880089,0.6978124976158142,1.2508113384246826,0.6395999789237976,1.5194545984268188,50000 -4923.843191146851,4.181923389434815,44575.93083691597,97461,0,44575.93083691597,0.5254999995231628,2.106327772140503,10000,49509.01126098633,0.6885156035423279,1.2628841400146484,0.642579972743988,1.4694026708602903,50000 -4971.78942322731,4.224483251571655,44995.94634652138,98381,0,44995.94634652138,0.5218999981880188,2.118631362915039,10000,49977.063490867615,0.6943945288658142,1.2435778379440308,0.6418399810791016,1.479137659072876,50000 -5018.474480867386,4.268015146255493,45416.24973320961,99296,0,45416.24973320961,0.5139000415802002,2.170624017715454,10000,50444.14269065857,0.6937890648841858,1.2605602741241455,0.6408799886703491,1.5083057880401611,50000 -5067.025764942169,4.322429895401001,45836.35974335671,100214,0,45836.35974335671,0.5177000164985657,2.1122829914093018,10000,50912.906153678894,0.6874804496765137,1.2632285356521606,0.6426999568939209,1.4662885665893557,50000 -5111.740765094757,4.362881422042847,46256.391678094864,101134,0,46256.391678094864,0.5227000117301941,2.1116693019866943,10000,51377.74165701866,0.6952343583106995,1.2298552989959717,0.646399974822998,1.4554753303527832,50000 -5158.299484729767,4.410746574401856,46676.57629132271,102056,0,46676.57629132271,0.5200000405311584,2.172455310821533,10000,51844.58192944527,0.6937499642372131,1.2900238037109375,0.638759970664978,1.5294249057769775,50000 -5205.239888191223,4.460630416870117,47096.54494357109,102974,0,47096.54494357109,0.5229000449180603,2.0942981243133545,10000,52311.58908891678,0.7155859470367432,1.145095944404602,0.6472600102424622,1.4536617994308472,50000 -5252.287913560867,4.502202272415161,47516.91737222672,103892,0,47516.91737222672,0.52510005235672,2.1233255863189697,10000,52779.09874868393,0.6991210579872131,1.2539443969726562,0.6495400071144104,1.4771441221237185,50000 -5294.593217134476,4.5446178913116455,47936.990965127945,104814,0,47936.990965127945,0.5250000357627869,2.094282627105713,10000,53241.568251132965,0.7024218440055847,1.2268569469451904,0.6485799551010132,1.4664111137390137,50000 -5340.061425924301,4.586939811706543,48357.00198101997,105735,0,48357.00198101997,0.5297000408172607,2.06118106842041,10000,53707.13777804375,0.7244726419448853,1.1140944957733154,0.6542800068855286,1.4219244718551636,50000 -5388.473347187042,4.627787828445435,48777.18662452698,106651,0,48777.18662452698,0.5336000323295593,2.058079481124878,10000,54175.82275629044,0.7056054472923279,1.192114233970642,0.6538000106811523,1.427022099494934,50000 -5438.257611513138,4.671726226806641,49197.227964401245,107568,0,49197.227964401245,0.5356000065803528,2.0526721477508545,10000,54645.73945307732,0.7154492139816284,1.1752427816390991,0.6553599834442139,1.4241199493408203,50000 -5487.164425611496,4.715782403945923,49617.28289413452,108488,0,49617.28289413452,0.5313000082969666,2.080390214920044,10000,55114.79280400276,0.71839839220047,1.1570578813552856,0.6541000008583069,1.4482253789901731,50000 -5532.3830144405365,4.7658116817474365,50037.39028072357,109409,0,50037.39028072357,0.5364000201225281,2.062147617340088,10000,55580.21659255028,0.7098046541213989,1.2095328569412231,0.6589800119400024,1.4317922592163086,50000 -5579.963124036789,4.810678005218506,50457.57064390183,110328,0,50457.57064390183,0.5294000506401062,2.0938665866851807,10000,56048.070026397705,0.705078125,1.216767191886902,0.6532399654388428,1.4524482488632202,50000 -5630.232651948929,4.852938175201416,50877.94385957718,111247,0,50877.94385957718,0.5353000164031982,2.035445213317871,10000,56518.80230307579,0.7220507860183716,1.1261377334594729,0.6609999537467957,1.4034225940704346,50000 -5678.924888134003,4.904115915298462,51298.08949494362,112167,0,51298.08949494362,0.5349000096321106,2.034170866012573,10000,56987.73841023445,0.7176757454872131,1.1667044162750244,0.6600599884986877,1.414770007133484,50000 -5727.290082454681,4.955888032913208,51718.244585990906,113085,0,51718.244585990906,0.5451000332832336,2.005274534225464,10000,57456.35862541199,0.7177538871765137,1.1266690492630005,0.6662999987602234,1.3710778951644895,50000 -5772.027301549912,5.000272750854492,52138.60879659653,114004,0,52138.60879659653,0.5439000129699707,2.008410692214966,10000,57921.55140066147,0.7226366996765137,1.1332825422286987,0.6693199872970581,1.3838796615600586,50000 -5822.523226261139,5.0422186851501465,52559.03491163254,114925,0,52559.03491163254,0.5440000295639038,1.9908446073532104,10000,58392.56290578842,0.7480273246765137,1.0093517303466797,0.6700999736785889,1.3534212112426758,50000 -5868.125700950623,5.087374925613403,52979.48181724548,115842,0,52979.48181724548,0.5459000468254089,1.989838004112244,10000,58858.70378828049,0.7229687571525574,1.115723729133606,0.6693199872970581,1.3546879291534424,50000 -5916.103356122971,5.1391777992248535,53399.61646103859,116758,0,53399.61646103859,0.5434000492095947,1.9787174463272093,10000,59326.91538286209,0.7300195097923279,1.0898079872131348,0.671019971370697,1.3476372957229614,50000 -5962.832585573196,5.19250226020813,53819.964002370834,117679,0,53819.964002370834,0.544700026512146,2.0176634788513184,10000,59794.09345436096,0.7401171922683716,1.0645672082901,0.6676999926567078,1.3838788270950315,50000 -6011.4386677742,5.236751317977905,54240.32156014442,118597,0,54240.32156014442,0.5469000339508057,1.9880914688110352,10000,60263.14970517159,0.7274804711341858,1.0996066331863403,0.6718199849128723,1.3471410274505615,50000 -6060.143052816391,5.284197807312012,54660.49790549278,119518,0,54660.49790549278,0.5469000339508057,1.9718530178070068,10000,60732.12470769882,0.7304491996765137,1.0636045932769775,0.672760009765625,1.335329294204712,50000 -6109.130714178085,5.328948974609375,55080.59421658516,120438,0,55080.59421658516,0.54830002784729,1.983000874519348,10000,61201.30026555061,0.7368554472923279,1.0701631307601929,0.675879955291748,1.3530094623565674,50000 -6155.618450164795,5.383184432983398,55500.81260108948,121354,0,55500.81260108948,0.5560000538825989,1.955164909362793,10000,61668.1078953743,0.7291015386581421,1.0866446495056152,0.6756199598312378,1.3281562328338623,50000 -6201.618116378784,5.425713300704956,55921.02682614327,122273,0,55921.02682614327,0.5575000047683716,1.926209807395935,10000,62134.41156554222,0.7377538681030273,1.046796798706055,0.6814199686050415,1.3041400909423828,50000 -6248.677495479584,5.474867820739746,56341.15030384064,123189,0,56341.15030384064,0.5567000508308411,1.947321414947509,10000,62601.69093823433,0.7463476657867432,1.0369484424591064,0.6841599941253662,1.3181054592132568,50000 -6299.499836921692,5.528214454650879,56761.34786653519,124108,0,56761.34786653519,0.5599000453948975,1.9192627668380733,10000,63072.82502889633,0.7437499761581421,1.0200892686843872,0.6828199625015259,1.2920786142349243,50000 -6343.759425878525,5.581441879272461,57181.43518590927,125025,0,57181.43518590927,0.5612000226974487,1.9244365692138672,10000,63537.27227139473,0.7412695288658142,1.0353752374649048,0.6846599578857422,1.2898250818252563,50000 -6391.28493309021,5.625512361526489,57601.69718050957,125945,0,57601.69718050957,0.5624000430107117,1.912021517753601,10000,64005.15271496773,0.7484374642372131,0.9919676780700684,0.6880999803543091,1.2748708724975586,50000 -6439.025028705597,5.6715407371521,58021.72929406166,126864,0,58021.72929406166,0.5654000043869019,1.9312610626220703,10000,64473.019273757935,0.76025390625,0.985008418560028,0.687559962272644,1.3039485216140747,50000 -6486.063357114792,5.723682403564453,58441.85705113411,127784,0,58441.85705113411,0.5622000098228455,1.8965699672698968,10000,64940.28579545021,0.7442578077316284,1.032114863395691,0.6894199848175049,1.2755590677261353,50000 -6533.113852500916,5.769906282424927,58862.21596360207,128703,0,58862.21596360207,0.5627000331878662,1.8980499505996704,10000,65407.78814959526,0.7488671541213989,1.0089884996414185,0.6900999546051025,1.273664474487305,50000 -6581.099862098694,5.814780950546265,59282.16807126999,129621,0,59282.16807126999,0.5667999982833862,1.8743077516555784,10000,65875.81766748428,0.7646874785423279,0.932016670703888,0.6924799680709839,1.245804786682129,50000 -6630.546304941177,5.8607916831970215,59702.48772835732,130539,0,59702.48772835732,0.5671000480651855,1.879859447479248,10000,66345.67736411095,0.7518749833106995,0.9943748116493224,0.6937199831008911,1.2452759742736816,50000 -6678.470661401749,5.9110212326049805,60122.77754402161,131459,0,60122.77754402161,0.5730000138282776,1.8423492908477783,10000,66813.98977065086,0.7592187523841858,0.9559024572372437,0.6979999542236328,1.223264455795288,50000 -6728.100863933563,5.9576075077056885,60542.986839056015,132379,0,60542.986839056015,0.5699000358581543,1.8848967552185056,10000,67283.923807621,0.7625390291213989,0.9606295824050904,0.6951199769973755,1.261569857597351,50000 -6775.558528184891,6.004410266876221,60963.19054841995,133298,0,60963.19054841995,0.5729000568389893,1.8487434387207031,10000,67751.67906785011,0.7527148127555847,0.982952117919922,0.7013199925422668,1.2232753038406372,50000 -6821.74352645874,6.052812814712524,61383.45324897766,134216,0,61383.45324897766,0.5751000046730042,1.8352450132369995,10000,68218.22201561928,0.7622656226158142,0.935705542564392,0.701259970664978,1.2122747898101809,50000 -6868.752334356308,6.101391077041626,61803.36758232117,135131,0,61803.36758232117,0.5815000534057617,1.8275467157363887,10000,68685.2418153286,0.7727343440055847,0.8989397287368774,0.7053200006484985,1.198777675628662,50000 -6917.71401143074,6.14823055267334,62223.56702518463,136050,0,62223.56702518463,0.5791000127792358,1.836230993270874,10000,69154.49729061127,0.77685546875,0.8904039859771729,0.7037000060081482,1.207800030708313,50000 -6966.49352145195,6.19426703453064,62643.62982487679,136970,0,62643.62982487679,0.579300045967102,1.815743088722229,10000,69623.43274188042,0.7667577862739563,0.9236660003662108,0.7049799561500549,1.1936697959899902,50000 -7013.5392434597015,6.243442058563232,63063.79121661186,137890,0,63063.79121661186,0.5799000263214111,1.819128155708313,10000,70090.73695373535,0.772753894329071,0.902352213859558,0.706279993057251,1.1953895092010498,50000 -7062.826720952988,6.297896862030029,63483.9806933403,138808,0,63483.9806933403,0.5788000226020813,1.7924264669418335,10000,70560.31663990021,0.78382807970047,0.8481549024581909,0.7096999883651733,1.173912525177002,50000 -7110.006555318832,6.345866918563843,63903.96551418304,139726,0,63903.96551418304,0.5897000432014465,1.7686362266540527,10000,71027.57639861107,0.7791601419448853,0.8667029142379761,0.7123000025749207,1.1546515226364136,50000 -7158.72931265831,6.395110607147217,64324.39351463318,140642,0,64324.39351463318,0.5906000137329102,1.7629472017288208,10000,71496.82323789597,0.7782421708106995,0.8562284708023071,0.7149199843406677,1.147608757019043,50000 -7206.5280418396,6.445488691329956,64744.67516851425,141561,0,64744.67516851425,0.588200032711029,1.8064215183258057,10000,71965.00109434128,0.7863867282867432,0.8447074890136719,0.7131199836730957,1.16819167137146,50000 -7255.32564163208,6.493313312530518,65164.6888320446,142474,0,65164.6888320446,0.5914000272750854,1.762101411819458,10000,72433.90739750862,0.7731249928474426,0.8849400281906128,0.7135399580001831,1.1557042598724363,50000 -7305.463745594025,6.540487051010132,65585.04565286636,143390,0,65585.04565286636,0.5877000093460083,1.783122181892395,10000,72904.4974284172,0.7816015481948853,0.868894100189209,0.716219961643219,1.1561012268066406,50000 -7354.108882188797,6.597029685974121,66004.958309412,144308,0,66004.958309412,0.5929000377655029,1.7589880228042605,10000,73373.15925145149,0.79212886095047,0.8323768377304077,0.7192999720573425,1.145952582359314,50000 -7401.779304265976,6.644883871078491,66425.3852956295,145228,0,66425.3852956295,0.593000054359436,1.7642781734466553,10000,73841.35216474533,0.7823046445846558,0.8637334108352661,0.7188799977302551,1.1443239450454712,50000 -7452.07506108284,6.691014051437378,66845.65575814247,146149,0,66845.65575814247,0.5923000574111938,1.7559378147125244,10000,74312.01202607155,0.7882617115974426,0.8386130928993225,0.7200599908828735,1.130061388015747,50000 -7501.210858821869,6.740800857543945,67265.79447817802,147070,0,67265.79447817802,0.5910000205039978,1.7555046081542969,10000,74781.38368654251,0.7914062142372131,0.8205158710479736,0.7212799787521362,1.122092366218567,50000 -7550.349389076233,6.793242692947388,67685.98072981834,147991,0,67685.98072981834,0.6055999994277954,1.719403624534607,10000,75250.80831742287,0.8050585985183716,0.7596969604492188,0.724079966545105,1.102674126625061,50000 -7600.08563709259,6.847251892089844,68106.26304864883,148910,0,68106.26304864883,0.5976999998092651,1.717188000679016,10000,75720.92788505554,0.79408198595047,0.8039225339889526,0.7251399755477905,1.0975062847137451,50000 -7649.241825342178,6.897162199020386,68526.58949446678,149831,0,68526.58949446678,0.5993000268936157,1.7362432479858398,10000,76190.50789570808,0.7937890291213989,0.8100752830505371,0.7240999937057495,1.1172319650650024,50000 -7697.823964357376,6.947785139083862,68946.71733403206,150752,0,68946.71733403206,0.6023000478744507,1.6958317756652832,10000,76659.31583476067,0.8075780868530273,0.7453005909919739,0.7288599610328674,1.078874945640564,50000 -7747.169746875763,6.998458623886108,69366.65239262581,151672,0,69366.65239262581,0.6053000092506409,1.7195252180099487,10000,77128.69496536255,0.7996875047683716,0.7857888340950012,0.7291399836540222,1.0955058336257937,50000 -7798.037398815155,7.052371025085449,69787.05628466606,152590,0,69787.05628466606,0.6079000234603882,1.6825690269470217,10000,77600.06755399704,0.8040820360183716,0.7654426693916321,0.7314599752426147,1.074555277824402,50000 -7851.229133844376,7.106428384780884,70207.12128043175,153511,0,70207.12128043175,0.6045000553131104,1.678421974182129,10000,78073.42539596558,0.8078711032867432,0.7358626127243042,0.7321999669075012,1.064021110534668,50000 -7901.632784366608,7.164010286331177,70627.35976624489,154432,0,70627.35976624489,0.6107000112533569,1.7023465633392334,10000,78544.17288303375,0.8040429353713989,0.7815302014350891,0.7331399917602539,1.0862562656402588,50000 -7953.746181964874,7.211568593978882,71047.48124408722,155351,0,71047.48124408722,0.6113000512123108,1.681421399116516,10000,79016.50269341469,0.804492175579071,0.760875940322876,0.7349199652671814,1.0727885961532593,50000 -8003.304003477097,7.269599676132202,71467.49888920784,156269,0,71467.49888920784,0.6067000031471252,1.6834540367126465,10000,79486.18399739265,0.8096483945846558,0.7489880323410034,0.73499995470047,1.0660606622695925,50000 -8052.717717885971,7.328284025192261,71887.46869874,157189,0,71887.46869874,0.617400050163269,1.6608270406723022,10000,79955.67472600937,0.8141406178474426,0.7198129296302795,0.7380399703979492,1.0487289428710938,50000 -8099.809978485107,7.386165380477905,72307.4182343483,158105,0,72307.4182343483,0.6190000176429749,1.6595739126205444,10000,80422.82166051865,0.813769519329071,0.7237128019332886,0.7416200041770935,1.0368276834487915,50000 -8149.245040655136,7.438138723373413,72727.65055274963,159025,0,72727.65055274963,0.6154000163078308,1.6334420442581177,10000,80892.5890173912,0.8179101347923279,0.6971430778503418,0.7404400110244751,1.028631567955017,50000 -8198.83842921257,7.497429847717285,73147.87590813637,159945,0,73147.87590813637,0.6166000366210938,1.6675461530685425,10000,81362.51534724236,0.8225781321525574,0.6903254389762878,0.7412799596786499,1.0446081161499023,50000 -8248.3914706707,7.553764581680298,73568.24109864235,160856,0,73568.24109864235,0.622700035572052,1.641365885734558,10000,81832.5364947319,0.8190624713897705,0.7006577253341675,0.7416999936103821,1.023085355758667,50000 -8299.826081991196,7.611063718795776,73988.53427219391,161774,0,73988.53427219391,0.6225000023841858,1.6433871984481812,10000,82304.36846089363,0.8226171731948853,0.693584680557251,0.7432399988174438,1.023402214050293,50000 -8350.1466858387,7.661406278610229,74408.44389462471,162692,0,74408.44389462471,0.6247000098228455,1.6267750263214111,10000,82774.69685792923,0.8324609398841858,0.6406580805778503,0.7445999979972839,1.0042402744293213,50000 -8399.791483402252,7.7152698040008545,74828.47229576111,163612,0,74828.47229576111,0.6284000277519226,1.6180665493011477,10000,83244.47105193138,0.8236718773841858,0.6884275078773499,0.7479199767112732,1.008457899093628,50000 -8450.2347574234,7.767882108688354,75248.7256937027,164533,0,75248.7256937027,0.6251000165939331,1.6135456562042236,10000,83715.26717567444,0.8290624618530273,0.662371039390564,0.7500999569892883,1.003365993499756,50000 -8499.674641370773,7.819833278656006,75669.07854604721,165452,0,75669.07854604721,0.6278000473976135,1.5892136096954346,10000,84185.15949583054,0.8309765458106995,0.6480200886726379,0.7496799826622009,0.990441918373108,50000 -8551.956509113312,7.872750520706177,76089.07453298569,166372,0,76089.07453298569,0.6302000284194946,1.583455204963684,10000,84657.53774857521,0.8332812190055847,0.6423313617706299,0.7530800104141235,0.9780200719833374,50000 -8601.107754945755,7.928579807281494,76509.43893957138,167291,0,76509.43893957138,0.6345000267028809,1.580045223236084,10000,85127.15686106682,0.8354296684265137,0.6281799077987671,0.7523199915885925,0.9751542210578918,50000 -8651.759412050247,7.982419967651367,76929.56784844398,168211,0,76929.56784844398,0.6341000199317932,1.5775206089019775,10000,85598.03895044327,0.8359179496765137,0.6287804841995239,0.7562599778175354,0.970455288887024,50000 -8697.523322582245,8.035626888275146,77349.9605910778,169131,0,77349.9605910778,0.6381000280380249,1.56676185131073,10000,86064.29647517204,0.8357812166213989,0.6224377751350403,0.7562800049781799,0.9682385921478271,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/measurements.csv deleted file mode 100644 index 72e637f69..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1883 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36552602,6.907756,,,,,,,,,,,,,, -1,,,0.0009765625,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,37.03280329704285,66.09364891052246,37.03280329704285,29.060741901397705,0.0,0.0 -100,0.39008653,6.905988,,,,,,,,,,,,,, -200,0.47955623,6.8929963,,,,,,,,,,,,,, -300,0.65255666,6.850131,,,,,,,,,,,,,, -400,0.7248288,6.8068423,,,,,,,,,,,,,, -500,0.97842395,6.8205166,,,,,,,,,,,,,, -600,1.1528931,6.723776,,,,,,,,,,,,,, -700,1.9653043,6.6295214,,,,,,,,,,,,,, -800,1.0488639,6.704882,,,,,,,,,,,,,, -860,,,0.0125195309519767,6.41873836517334,0.0127199999988079,6.431779384613037,50000.0,0.0108000002801418,6.47079610824585,10000.0,457.21945309638977,530.0452859401703,457.21945309638977,72.75973510742188,0.0193541049957275,0.0 -900,1.7167344,6.6598234,,,,,,,,,,,,,, -1000,1.916771,6.46307,,,,,,,,,,,,,, -1100,2.232525,6.737201,,,,,,,,,,,,,, -1200,2.7787821,6.3848114,,,,,,,,,,,,,, -1300,1.880311,6.7043686,,,,,,,,,,,,,, -1400,1.912319,6.3954363,,,,,,,,,,,,,, -1500,1.8849802,6.42163,,,,,,,,,,,,,, -1600,1.88706,6.1552396,,,,,,,,,,,,,, -1700,2.4732106,6.104148,,,,,,,,,,,,,, -1773,,,0.0384765602648258,5.859482288360596,0.0354399979114532,5.89152193069458,50000.0,0.0303000006824731,6.007748603820801,10000.0,877.4289219379425,1000.501962184906,877.4289219379425,122.92927050590517,0.0499892234802246,0.0 -1800,2.0567112,6.1106586,,,,,,,,,,,,,, -1900,2.1731675,6.463726,,,,,,,,,,,,,, -2000,2.2821965,6.60379,,,,,,,,,,,,,, -2100,2.1683931,5.99454,,,,,,,,,,,,,, -2200,1.9409584,6.293905,,,,,,,,,,,,,, -2300,2.2504556,5.9796863,,,,,,,,,,,,,, -2400,2.0990067,5.862177,,,,,,,,,,,,,, -2500,1.7505331,6.4937563,,,,,,,,,,,,,, -2600,2.1047888,6.2387295,,,,,,,,,,,,,, -2692,,,0.0623632781207561,5.426368713378906,0.0602199994027614,5.460067272186279,50000.0,0.0446000024676322,5.647171974182129,10000.0,1297.60599899292,1467.5219156742096,1297.60599899292,169.69586300849917,0.0778777599334716,0.0 -2700,1.6361024,6.044617,,,,,,,,,,,,,, -2800,2.0956662,5.8406787,,,,,,,,,,,,,, -2900,2.1842756,5.8203263,,,,,,,,,,,,,, -3000,2.0592813,5.720662,,,,,,,,,,,,,, -3100,1.8947067,5.6558037,,,,,,,,,,,,,, -3200,1.6823628,6.578264,,,,,,,,,,,,,, -3300,1.7787066,5.693294,,,,,,,,,,,,,, -3400,1.6713102,5.9920444,,,,,,,,,,,,,, -3500,2.1070602,5.637171,,,,,,,,,,,,,, -3600,1.7556126,5.6762033,,,,,,,,,,,,,, -3612,,,0.0909960940480232,5.120298385620117,0.0842399969696998,5.156952857971191,50000.0,0.0625,5.394311428070068,10000.0,1717.723210811615,1922.2475941181185,1717.723210811615,204.232607126236,0.1020910739898681,0.0 -3700,2.3192267,5.5615644,,,,,,,,,,,,,, -3800,1.8830137,5.5569396,,,,,,,,,,,,,, -3900,2.3448324,5.491231,,,,,,,,,,,,,, -4000,1.587608,6.462715,,,,,,,,,,,,,, -4100,1.5054697,6.442871,,,,,,,,,,,,,, -4200,1.5710758,5.604182,,,,,,,,,,,,,, -4300,1.51323,6.411308,,,,,,,,,,,,,, -4400,1.857078,5.39197,,,,,,,,,,,,,, -4500,1.7055796,6.1777,,,,,,,,,,,,,, -4511,,,0.1286914050579071,4.747590065002441,0.119279995560646,4.806865692138672,50000.0,0.0897000059485435,5.0951080322265625,10000.0,2138.045210123062,2388.123862028122,2138.045210123062,249.70977568626404,0.1325874328613281,0.0 -4600,1.9388326,5.287646,,,,,,,,,,,,,, -4700,1.8111713,5.3895354,,,,,,,,,,,,,, -4800,2.0013542,5.1896143,,,,,,,,,,,,,, -4900,1.7650732,5.9034333,,,,,,,,,,,,,, -5000,1.7082665,5.7580457,,,,,,,,,,,,,, -5100,1.9031978,5.327183,,,,,,,,,,,,,, -5200,2.114166,5.196054,,,,,,,,,,,,,, -5300,2.255323,5.106659,,,,,,,,,,,,,, -5400,1.9696153,5.0470037,,,,,,,,,,,,,, -5431,,,0.1690820306539535,4.369624137878418,0.1549399942159652,4.4540791511535645,50000.0,0.1129000037908554,4.798599243164063,10000.0,2558.279573202133,2855.2203080654144,2558.279573202133,296.49737071990967,0.1598045825958252,0.0 -5500,1.7945029,5.0400968,,,,,,,,,,,,,, -5600,1.4920208,6.301398,,,,,,,,,,,,,, -5700,1.245773,6.2904644,,,,,,,,,,,,,, -5800,1.9661821,5.1195817,,,,,,,,,,,,,, -5900,2.0045052,4.846854,,,,,,,,,,,,,, -6000,1.7821273,4.9151025,,,,,,,,,,,,,, -6100,2.0059078,4.8716154,,,,,,,,,,,,,, -6200,1.8330775,5.2531624,,,,,,,,,,,,,, -6300,1.4815375,6.2598643,,,,,,,,,,,,,, -6353,,,0.2099804580211639,4.027159690856934,0.1937799900770187,4.129310607910156,50000.0,0.14410001039505,4.515932083129883,10000.0,2978.6933076381683,3314.668624162674,2978.6933076381683,335.4525935649872,0.1912572383880615,0.0 -6400,1.4656065,5.694708,,,,,,,,,,,,,, -6500,2.2451687,4.9399996,,,,,,,,,,,,,, -6600,1.9618341,4.820117,,,,,,,,,,,,,, -6700,1.8100519,5.1010656,,,,,,,,,,,,,, -6800,1.5920477,4.8567944,,,,,,,,,,,,,, -6900,2.1513758,4.6220117,,,,,,,,,,,,,, -7000,1.8509058,4.6679935,,,,,,,,,,,,,, -7100,2.0047345,4.519043,,,,,,,,,,,,,, -7200,1.8351923,4.5805616,,,,,,,,,,,,,, -7274,,,0.2498242110013961,3.753761768341065,0.2315599918365478,3.850430011749268,50000.0,0.1762000024318695,4.282338619232178,10000.0,3398.987615585327,3780.519155740738,3398.987615585327,380.9338798522949,0.2187426090240478,0.0 -7300,1.6024605,6.2275553,,,,,,,,,,,,,, -7400,2.323168,4.5518,,,,,,,,,,,,,, -7500,1.9243822,4.4095592,,,,,,,,,,,,,, -7600,1.8360813,4.7513776,,,,,,,,,,,,,, -7700,1.9881337,4.3839364,,,,,,,,,,,,,, -7800,2.6307766,4.534712,,,,,,,,,,,,,, -7900,2.2315395,4.8546467,,,,,,,,,,,,,, -8000,1.5594833,6.043805,,,,,,,,,,,,,, -8100,1.9945524,4.4721127,,,,,,,,,,,,,, -8196,,,0.2862304747104645,3.5035860538482666,0.2628999948501587,3.623043775558472,50000.0,0.2035000026226043,4.090365886688232,10000.0,3819.009551525116,4246.702522277832,3819.009551525116,427.0207488536835,0.2449104785919189,0.0 -8200,1.4758291,5.296319,,,,,,,,,,,,,, -8300,1.8336091,4.474757,,,,,,,,,,,,,, -8400,1.817822,4.2266054,,,,,,,,,,,,,, -8500,1.996276,4.309154,,,,,,,,,,,,,, -8600,1.5278294,5.1480217,,,,,,,,,,,,,, -8700,1.1498657,5.836867,,,,,,,,,,,,,, -8800,2.0177119,4.157614,,,,,,,,,,,,,, -8900,1.5790981,5.7444153,,,,,,,,,,,,,, -9000,1.951412,4.604743,,,,,,,,,,,,,, -9100,2.0677462,4.1485467,,,,,,,,,,,,,, -9114,,,0.3378515541553497,3.160243511199951,0.3047399818897247,3.333974599838257,50000.0,0.2339000105857849,3.833681344985962,10000.0,4239.054158687592,4712.58321595192,4239.054158687592,472.77794551849365,0.2752766609191894,0.0 -9200,2.2007868,4.0617123,,,,,,,,,,,,,, -9300,1.6489081,5.469297,,,,,,,,,,,,,, -9400,1.8656776,3.910431,,,,,,,,,,,,,, -9500,1.6767907,6.0897436,,,,,,,,,,,,,, -9600,1.3723356,5.774539,,,,,,,,,,,,,, -9700,1.3701199,5.890383,,,,,,,,,,,,,, -9800,2.2022064,3.9788928,,,,,,,,,,,,,, -9900,1.4567922,5.9376965,,,,,,,,,,,,,, -10000,1.8343496,3.8615108,,,,,,,,,,,,,, -10037,,,0.3565820157527923,3.0313377380371094,0.3318600058555603,3.1532087326049805,50000.0,0.2568000257015228,3.693822860717773,10000.0,4659.012718915939,5179.155911445618,4659.012718915939,519.3172528743744,0.3021481037139892,0.0 -10100,1.4691333,5.9953694,,,,,,,,,,,,,, -10200,1.4805492,5.3002295,,,,,,,,,,,,,, -10300,1.5690749,5.1935215,,,,,,,,,,,,,, -10400,1.5421004,4.69626,,,,,,,,,,,,,, -10500,1.7227494,4.6753182,,,,,,,,,,,,,, -10600,1.9983864,3.8824365,,,,,,,,,,,,,, -10700,1.9526334,3.8717391,,,,,,,,,,,,,, -10800,1.5507962,3.8341436,,,,,,,,,,,,,, -10900,1.3425605,5.0538883,,,,,,,,,,,,,, -10960,,,0.3778125047683716,2.9234459400177,0.3477599918842315,3.085547685623169,50000.0,0.2639000117778778,3.631510019302368,10000.0,5079.202866315842,5645.115402698517,5079.202866315842,565.0101907253265,0.3308203220367431,0.0 -11000,1.9622469,3.785567,,,,,,,,,,,,,, -11100,1.8290771,3.7895288,,,,,,,,,,,,,, -11200,1.9562143,3.9750473,,,,,,,,,,,,,, -11300,1.6656753,3.8363566,,,,,,,,,,,,,, -11400,1.4100717,5.776902,,,,,,,,,,,,,, -11500,1.3461405,4.5564895,,,,,,,,,,,,,, -11600,1.9948401,3.7528594,,,,,,,,,,,,,, -11700,1.4315015,5.1552434,,,,,,,,,,,,,, -11800,1.8958392,3.681351,,,,,,,,,,,,,, -11881,,,0.4100195169448852,2.730732679367065,0.3714399933815002,2.922337532043457,50000.0,0.2891000211238861,3.4765114784240723,10000.0,5499.275423049927,6107.204093456268,5499.275423049927,606.9534866809845,0.3557870388031006,0.0 -11900,1.2814753,5.7852225,,,,,,,,,,,,,, -12000,2.326289,3.4948874,,,,,,,,,,,,,, -12100,1.4610468,4.2354603,,,,,,,,,,,,,, -12200,1.5129292,3.8746212,,,,,,,,,,,,,, -12300,1.8864331,3.6281905,,,,,,,,,,,,,, -12400,1.515899,3.7467332,,,,,,,,,,,,,, -12500,2.0453463,3.5901387,,,,,,,,,,,,,, -12600,1.8757948,4.283826,,,,,,,,,,,,,, -12700,1.831816,3.6591194,,,,,,,,,,,,,, -12800,,,0.4242968559265136,2.650355815887451,0.3928999900817871,2.803166151046753,50000.0,0.299200028181076,3.3898091316223145,10000.0,5919.309090852737,6573.585469007492,5919.309090852737,653.2269532680511,0.3829104900360107,0.0 -12800,1.4755554,5.0981526,,,,,,,,,,,,,, -12900,1.6644996,3.537203,,,,,,,,,,,,,, -13000,1.7837776,3.6083744,,,,,,,,,,,,,, -13100,1.5982773,3.7343104,,,,,,,,,,,,,, -13200,1.9458288,3.5561407,,,,,,,,,,,,,, -13300,1.5829701,3.4836192,,,,,,,,,,,,,, -13400,1.0631478,5.868905,,,,,,,,,,,,,, -13500,1.8125823,3.536257,,,,,,,,,,,,,, -13600,1.3747492,5.183559,,,,,,,,,,,,,, -13700,1.2230415,5.1024284,,,,,,,,,,,,,, -13720,,,0.4307031035423279,2.6292409896850586,0.3979800045490265,2.789198398590088,50000.0,0.3100000023841858,3.3641951084136963,10000.0,6338.901314973831,7037.623072147369,6338.901314973831,697.1185910701752,0.8888986110687256,0.0 -13800,1.371775,4.6491857,,,,,,,,,,,,,, -13900,1.1616871,5.557545,,,,,,,,,,,,,, -14000,1.7681847,3.4933965,,,,,,,,,,,,,, -14100,1.735432,3.4871857,,,,,,,,,,,,,, -14200,1.6282288,3.5232244,,,,,,,,,,,,,, -14300,1.2978019,4.9160304,,,,,,,,,,,,,, -14400,1.6247646,3.6098163,,,,,,,,,,,,,, -14500,1.6793202,3.4189842,,,,,,,,,,,,,, -14600,1.790755,3.4702368,,,,,,,,,,,,,, -14641,,,0.4467968642711639,2.5055689811706543,0.4116399884223938,2.6887874603271484,50000.0,0.3162000179290771,3.266141653060913,10000.0,6759.135899543762,7494.290160655975,6759.135899543762,733.4783155918121,0.914440393447876,0.0 -14700,1.7177815,3.3753617,,,,,,,,,,,,,, -14800,1.5560503,3.4478302,,,,,,,,,,,,,, -14900,1.1083677,5.8969746,,,,,,,,,,,,,, -15000,1.610141,4.3972745,,,,,,,,,,,,,, -15100,1.658225,3.350339,,,,,,,,,,,,,, -15200,1.3863041,4.066707,,,,,,,,,,,,,, -15300,1.5756726,3.832176,,,,,,,,,,,,,, -15400,1.6891857,3.402017,,,,,,,,,,,,,, -15500,1.1157521,5.4655805,,,,,,,,,,,,,, -15556,,,0.4633203148841858,2.4123597145080566,0.4307799935340881,2.577693462371826,50000.0,0.3327000141143799,3.171525955200196,10000.0,7179.2973692417145,7951.254838228226,7179.2973692417145,770.1984448432922,0.9485998153686525,0.0 -15600,1.2561176,4.587532,,,,,,,,,,,,,, -15700,1.6435269,3.4091918,,,,,,,,,,,,,, -15800,0.9922013,5.688113,,,,,,,,,,,,,, -15900,1.5045857,3.8222153,,,,,,,,,,,,,, -16000,1.2918413,4.2905955,,,,,,,,,,,,,, -16100,1.1415894,5.669033,,,,,,,,,,,,,, -16200,1.7269298,3.4051452,,,,,,,,,,,,,, -16300,1.7105482,3.5650158,,,,,,,,,,,,,, -16400,1.4763536,3.402027,,,,,,,,,,,,,, -16474,,,0.4699999988079071,2.36291766166687,0.4378999769687652,2.5307412147521973,50000.0,0.3393000066280365,3.1288845539093018,10000.0,7599.65398979187,8420.844632863998,7599.65398979187,819.353661775589,0.9791200160980223,0.0 -16500,1.5601791,3.3080847,,,,,,,,,,,,,, -16600,1.8097001,3.4545717,,,,,,,,,,,,,, -16700,1.5341386,3.2893054,,,,,,,,,,,,,, -16800,1.6873552,3.520803,,,,,,,,,,,,,, -16900,1.342665,4.773001,,,,,,,,,,,,,, -17000,1.5228106,3.4488134,,,,,,,,,,,,,, -17100,1.5710597,3.27456,,,,,,,,,,,,,, -17200,1.6070057,3.433056,,,,,,,,,,,,,, -17300,1.6999712,3.2092786,,,,,,,,,,,,,, -17397,,,0.47802734375,2.3279898166656494,0.4379799962043762,2.5215203762054443,50000.0,0.3379000127315521,3.118520498275757,10000.0,8019.693618297577,8882.96993303299,8019.693618297577,861.3643939495087,1.0059688091278076,0.0 -17400,1.3717732,4.241037,,,,,,,,,,,,,, -17500,1.5934323,3.3913198,,,,,,,,,,,,,, -17600,1.93978,3.3966331,,,,,,,,,,,,,, -17700,1.5145538,3.3904514,,,,,,,,,,,,,, -17800,1.4872519,3.573759,,,,,,,,,,,,,, -17900,1.7800176,3.3681943,,,,,,,,,,,,,, -18000,1.895231,3.3423905,,,,,,,,,,,,,, -18100,1.0818391,5.6638765,,,,,,,,,,,,,, -18200,1.7920002,3.235725,,,,,,,,,,,,,, -18300,1.6185899,3.482361,,,,,,,,,,,,,, -18317,,,0.4999414086341858,2.242021322250366,0.4409199953079223,2.529330015182495,50000.0,0.3397000133991241,3.14664888381958,10000.0,8439.754869222641,9346.944638490677,8439.754869222641,905.2010207176208,1.0355210304260254,0.0 -18400,0.99846447,5.6588106,,,,,,,,,,,,,, -18500,1.5228429,3.336954,,,,,,,,,,,,,, -18600,1.4368821,3.2915986,,,,,,,,,,,,,, -18700,1.8263589,3.3593268,,,,,,,,,,,,,, -18800,1.7503023,3.275544,,,,,,,,,,,,,, -18900,1.6372927,3.2311018,,,,,,,,,,,,,, -19000,1.5522921,3.2429795,,,,,,,,,,,,,, -19100,1.3911656,5.0338993,,,,,,,,,,,,,, -19200,1.6115499,3.5945125,,,,,,,,,,,,,, -19236,,,0.4896093606948852,2.2664780616760254,0.4534199833869934,2.4464144706726074,50000.0,0.3473000228404999,3.072669744491577,10000.0,8860.098046064377,9812.855647802353,8860.098046064377,950.6946420669556,1.0625567436218262,0.0 -19300,1.4888062,3.3057823,,,,,,,,,,,,,, -19400,1.3569398,3.7027335,,,,,,,,,,,,,, -19500,1.6661146,3.2385325,,,,,,,,,,,,,, -19600,1.7011048,3.3963404,,,,,,,,,,,,,, -19700,1.5910559,3.2417092,,,,,,,,,,,,,, -19800,1.4711194,3.6058168,,,,,,,,,,,,,, -19900,1.4568812,3.548015,,,,,,,,,,,,,, -20000,1.5855868,3.2150054,,,,,,,,,,,,,, -20100,1.3486063,4.333437,,,,,,,,,,,,,, -20159,,,0.4975976347923279,2.2221951484680176,0.461139976978302,2.4042558670043945,50000.0,0.3519000113010406,3.0374844074249268,10000.0,9280.475875854492,10274.733105897903,9280.475875854492,992.1189947128296,1.0902252197265625,0.0 -20200,1.1715683,5.696171,,,,,,,,,,,,,, -20300,1.2631532,4.0387135,,,,,,,,,,,,,, -20400,1.0945015,5.621058,,,,,,,,,,,,,, -20500,1.0116153,5.669755,,,,,,,,,,,,,, -20600,1.5626562,3.3469574,,,,,,,,,,,,,, -20700,1.6174133,3.137344,,,,,,,,,,,,,, -20800,1.3163596,5.100176,,,,,,,,,,,,,, -20900,1.3229494,4.0817184,,,,,,,,,,,,,, -21000,0.9760472,5.0381684,,,,,,,,,,,,,, -21080,,,0.5208203196525574,2.0787808895111084,0.4711199998855591,2.326721429824829,50000.0,0.3665000200271606,2.948438882827759,10000.0,9700.750081539154,10741.280704021454,9700.750081539154,1038.3150515556335,1.1201841831207275,0.0 -21100,1.4581493,3.275386,,,,,,,,,,,,,, -21200,1.6943707,3.390757,,,,,,,,,,,,,, -21300,1.0190563,5.651692,,,,,,,,,,,,,, -21400,1.3571098,3.967297,,,,,,,,,,,,,, -21500,1.0005404,5.4983425,,,,,,,,,,,,,, -21600,1.0311517,5.4936767,,,,,,,,,,,,,, -21700,1.0209283,5.3611965,,,,,,,,,,,,,, -21800,1.395581,3.1843438,,,,,,,,,,,,,, -21900,1.5769571,3.2824736,,,,,,,,,,,,,, -22000,1.5545045,3.1672478,,,,,,,,,,,,,, -22001,,,0.5149999856948853,2.116328239440918,0.4795799851417541,2.289320707321167,50000.0,0.3775000274181366,2.8965604305267334,10000.0,10121.19642019272,11208.20026922226,10121.19642019272,1084.7075653076172,1.152724266052246,0.0 -22100,1.6268682,3.357673,,,,,,,,,,,,,, -22200,1.3200532,4.250637,,,,,,,,,,,,,, -22300,1.5768195,3.0963733,,,,,,,,,,,,,, -22400,1.1831752,5.1647787,,,,,,,,,,,,,, -22500,1.2841787,3.949987,,,,,,,,,,,,,, -22600,1.4018606,3.652618,,,,,,,,,,,,,, -22700,1.7403554,3.159059,,,,,,,,,,,,,, -22800,1.2919949,3.7279484,,,,,,,,,,,,,, -22900,1.5550158,3.3560622,,,,,,,,,,,,,, -22922,,,0.521191418170929,2.095719814300537,0.485179990530014,2.2728850841522217,50000.0,0.3761000037193298,2.9281606674194336,10000.0,10541.374686002731,11673.316876888275,10541.374686002731,1129.569720506668,1.1803789138793943,0.0 -23000,1.5790269,3.1463075,,,,,,,,,,,,,, -23100,1.6747861,3.082694,,,,,,,,,,,,,, -23200,1.6861807,3.0918937,,,,,,,,,,,,,, -23300,1.169439,4.3181105,,,,,,,,,,,,,, -23400,1.6681398,3.1399074,,,,,,,,,,,,,, -23500,1.1809323,5.390969,,,,,,,,,,,,,, -23600,1.3584343,3.81572,,,,,,,,,,,,,, -23700,1.6247123,3.0063562,,,,,,,,,,,,,, -23800,1.633398,3.4320364,,,,,,,,,,,,,, -23841,,,0.5325781106948853,2.0169644355773926,0.4902399778366089,2.243086338043213,50000.0,0.3845000267028808,2.868086576461792,10000.0,10961.407228469849,12138.750408172607,10961.407228469849,1174.8933689594269,1.210268497467041,0.0 -23900,1.4822066,3.0969803,,,,,,,,,,,,,, -24000,1.3822126,4.3700843,,,,,,,,,,,,,, -24100,1.7226757,3.3802595,,,,,,,,,,,,,, -24200,1.4689721,4.055129,,,,,,,,,,,,,, -24300,1.8502221,3.0905468,,,,,,,,,,,,,, -24400,1.4286337,4.886263,,,,,,,,,,,,,, -24500,1.4984621,3.1161792,,,,,,,,,,,,,, -24600,1.6227542,3.1090314,,,,,,,,,,,,,, -24700,1.4213254,4.4577084,,,,,,,,,,,,,, -24762,,,0.5335351228713989,2.0236427783966064,0.4941200017929077,2.203388214111328,50000.0,0.3880000114440918,2.834979295730591,10000.0,11381.509302854538,12606.17445731163,11381.509302854538,1222.1330163478851,1.244988203048706,0.0 -24800,1.1553794,5.4722013,,,,,,,,,,,,,, -24900,1.2701389,5.20416,,,,,,,,,,,,,, -25000,1.2761781,5.670542,,,,,,,,,,,,,, -25100,1.4374713,3.6292286,,,,,,,,,,,,,, -25200,1.2214456,4.435803,,,,,,,,,,,,,, -25300,1.1467543,5.303752,,,,,,,,,,,,,, -25400,1.5337374,3.0083485,,,,,,,,,,,,,, -25500,1.5134763,3.0439403,,,,,,,,,,,,,, -25600,1.0878034,4.6107674,,,,,,,,,,,,,, -25685,,,0.5403710603713989,2.013492584228516,0.5018999576568604,2.1954407691955566,50000.0,0.3906000256538391,2.8306357860565186,10000.0,11801.537393569946,13068.69039440155,11801.537393569946,1264.5435791015625,1.27274751663208,0.0 -25700,1.39291,3.2967105,,,,,,,,,,,,,, -25800,1.472323,3.0473115,,,,,,,,,,,,,, -25900,1.5851303,3.3038645,,,,,,,,,,,,,, -26000,1.5728676,2.895053,,,,,,,,,,,,,, -26100,1.2938132,5.5589876,,,,,,,,,,,,,, -26200,1.4977192,3.3884928,,,,,,,,,,,,,, -26300,1.705243,2.9383001,,,,,,,,,,,,,, -26400,1.4734672,3.9375772,,,,,,,,,,,,,, -26500,1.6459943,2.9127567,,,,,,,,,,,,,, -26600,1.667345,2.9956236,,,,,,,,,,,,,, -26604,,,0.5479491949081421,1.9419466257095337,0.5062199831008911,2.1494569778442383,50000.0,0.3939000070095062,2.7954912185668945,10000.0,12221.88210463524,13534.276743412018,12221.88210463524,1309.7079060077667,1.3029565811157229,0.0 -26700,1.591852,3.1107142,,,,,,,,,,,,,, -26800,1.0904485,5.232888,,,,,,,,,,,,,, -26900,1.1993822,4.5330415,,,,,,,,,,,,,, -27000,1.5849243,2.8367276,,,,,,,,,,,,,, -27100,1.4968798,3.312243,,,,,,,,,,,,,, -27200,1.7500467,3.1360624,,,,,,,,,,,,,, -27300,1.6726441,2.8095927,,,,,,,,,,,,,, -27400,1.2023071,4.622897,,,,,,,,,,,,,, -27500,1.4046534,3.1860104,,,,,,,,,,,,,, -27527,,,0.5578905940055847,1.9007580280303955,0.51419997215271,2.116635322570801,50000.0,0.4063000082969665,2.749185562133789,10000.0,12642.213397979736,13994.85052037239,12642.213397979736,1349.8739371299744,1.3320088386535645,0.0 -27600,1.2997314,5.4525914,,,,,,,,,,,,,, -27700,1.0709499,4.518221,,,,,,,,,,,,,, -27800,1.104881,5.529444,,,,,,,,,,,,,, -27900,1.4187793,3.07281,,,,,,,,,,,,,, -28000,1.6190587,2.8957286,,,,,,,,,,,,,, -28100,1.610739,3.0778186,,,,,,,,,,,,,, -28200,1.4800899,3.9392693,,,,,,,,,,,,,, -28300,1.6051935,3.0212743,,,,,,,,,,,,,, -28400,1.51606,2.905063,,,,,,,,,,,,,, -28448,,,0.5541796684265137,1.9155114889144893,0.5175999999046326,2.0890777111053467,50000.0,0.4115000069141388,2.720576763153076,10000.0,13062.25780081749,14460.51496386528,13062.25780081749,1395.4143552780151,1.3635451793670654,0.0 -28500,1.4889953,3.4752274,,,,,,,,,,,,,, -28600,1.3042765,3.1596289,,,,,,,,,,,,,, -28700,1.3854011,4.4367275,,,,,,,,,,,,,, -28800,1.2613645,4.7219257,,,,,,,,,,,,,, -28900,1.440287,3.5242953,,,,,,,,,,,,,, -29000,1.6386055,2.9986494,,,,,,,,,,,,,, -29100,1.3205681,4.4594183,,,,,,,,,,,,,, -29200,1.4099879,3.4783056,,,,,,,,,,,,,, -29300,1.3585912,3.2691944,,,,,,,,,,,,,, -29369,,,0.5678515434265137,1.865001916885376,0.5196200013160706,2.0869064331054688,50000.0,0.4068000316619873,2.7158117294311523,10000.0,13482.429998636246,14925.807185173036,13482.429998636246,1440.4584770202637,1.391657829284668,0.0 -29400,1.6083077,3.0933864,,,,,,,,,,,,,, -29500,1.5520047,3.0996418,,,,,,,,,,,,,, -29600,1.5873905,2.888986,,,,,,,,,,,,,, -29700,1.6678886,2.7858458,,,,,,,,,,,,,, -29800,1.5205991,2.9228063,,,,,,,,,,,,,, -29900,1.7119455,2.9510913,,,,,,,,,,,,,, -30000,1.5000618,2.86871,,,,,,,,,,,,,, -30100,1.3760438,3.467982,,,,,,,,,,,,,, -30200,1.338263,3.9678454,,,,,,,,,,,,,, -30290,,,0.5793359279632568,1.8263520002365112,0.5206999778747559,2.103342294692993,50000.0,0.4129000306129455,2.727597951889038,10000.0,13902.37486410141,15392.34959602356,13902.37486410141,1486.9760718345642,1.4241070747375488,0.0 -30300,1.6556891,2.9448776,,,,,,,,,,,,,, -30400,1.2627013,5.1556106,,,,,,,,,,,,,, -30500,1.5366812,3.3681138,,,,,,,,,,,,,, -30600,1.4719551,2.9339585,,,,,,,,,,,,,, -30700,1.4315612,3.6546135,,,,,,,,,,,,,, -30800,1.2371017,4.8725553,,,,,,,,,,,,,, -30900,1.5220891,3.890379,,,,,,,,,,,,,, -31000,1.6881944,2.7721877,,,,,,,,,,,,,, -31100,1.6661409,2.907672,,,,,,,,,,,,,, -31200,1.3059924,4.567189,,,,,,,,,,,,,, -31212,,,0.564160168170929,1.8538535833358765,0.527899980545044,2.041722297668457,50000.0,0.4142000079154968,2.687079906463623,10000.0,14322.354754447935,15855.711465358734,14322.354754447935,1530.2789916992188,1.456043720245361,0.0 -31300,1.3749157,4.3471,,,,,,,,,,,,,, -31400,1.6980655,2.8741484,,,,,,,,,,,,,, -31500,1.382139,3.5400093,,,,,,,,,,,,,, -31600,1.6310942,2.8060112,,,,,,,,,,,,,, -31700,1.647379,3.0229197,,,,,,,,,,,,,, -31800,1.5734357,3.0413811,,,,,,,,,,,,,, -31900,1.6066958,2.8811855,,,,,,,,,,,,,, -32000,1.585879,2.9451592,,,,,,,,,,,,,, -32100,1.3721108,5.480081,,,,,,,,,,,,,, -32134,,,0.5700390338897705,1.8478858470916748,0.531279981136322,2.034501314163208,50000.0,0.4193000197410583,2.66968321800232,10000.0,14742.48341703415,16321.563993692398,14742.48341703415,1575.9266781806946,1.4843404293060305,0.0 -32200,1.2810197,5.504006,,,,,,,,,,,,,, -32300,1.335074,3.1323495,,,,,,,,,,,,,, -32400,1.4624264,2.9605088,,,,,,,,,,,,,, -32500,1.1574111,5.3734527,,,,,,,,,,,,,, -32600,1.4445024,2.8427217,,,,,,,,,,,,,, -32700,1.8505974,2.8249671,,,,,,,,,,,,,, -32800,1.6555134,2.7420144,,,,,,,,,,,,,, -32900,1.7275627,2.8777885,,,,,,,,,,,,,, -33000,1.3050058,4.3040147,,,,,,,,,,,,,, -33052,,,0.5787890553474426,1.8214852809906008,0.5297799706459045,2.06502366065979,50000.0,0.4118000268936157,2.7029848098754883,10000.0,15162.777776002884,16788.576245307922,15162.777776002884,1622.5681648254397,1.513521432876587,0.0 -33100,1.6669887,2.79593,,,,,,,,,,,,,, -33200,1.5760658,2.9686944,,,,,,,,,,,,,, -33300,1.8343748,2.8661866,,,,,,,,,,,,,, -33400,1.2258774,4.4434032,,,,,,,,,,,,,, -33500,1.221843,4.690252,,,,,,,,,,,,,, -33600,1.7087793,2.8809483,,,,,,,,,,,,,, -33700,1.5884894,2.9690742,,,,,,,,,,,,,, -33800,1.5154842,2.942231,,,,,,,,,,,,,, -33900,1.6874725,2.8455272,,,,,,,,,,,,,, -33973,,,0.5707616806030273,1.8383631706237795,0.5380600094795227,2.0063157081604004,50000.0,0.4244000315666199,2.666775941848755,10000.0,15582.711037397385,17254.40386199951,15582.711037397385,1668.3824818134308,1.5456140041351318,0.0 -34000,1.7024132,2.9217262,,,,,,,,,,,,,, -34100,1.6708571,2.8722591,,,,,,,,,,,,,, -34200,1.5609974,3.065403,,,,,,,,,,,,,, -34300,1.3486888,5.311224,,,,,,,,,,,,,, -34400,1.5661044,2.923806,,,,,,,,,,,,,, -34500,1.6627232,2.9270804,,,,,,,,,,,,,, -34600,1.5674567,2.6689868,,,,,,,,,,,,,, -34700,1.6704152,2.7293515,,,,,,,,,,,,,, -34800,1.1978196,4.2095475,,,,,,,,,,,,,, -34894,,,0.5763476490974426,1.8116711378097528,0.5345199704170227,2.015291452407837,50000.0,0.4220000207424164,2.628847360610962,10000.0,16003.069641828535,17720.966168165207,16003.069641828535,1714.5101835727692,1.5746049880981443,0.0 -34900,1.4593663,4.2395163,,,,,,,,,,,,,, -35000,1.1376575,5.307648,,,,,,,,,,,,,, -35100,1.2633755,4.954355,,,,,,,,,,,,,, -35200,1.5203207,2.8468735,,,,,,,,,,,,,, -35300,1.5895588,3.345256,,,,,,,,,,,,,, -35400,1.6141021,2.7123322,,,,,,,,,,,,,, -35500,1.718364,2.8431692,,,,,,,,,,,,,, -35600,1.2038347,4.847417,,,,,,,,,,,,,, -35700,1.758841,2.909725,,,,,,,,,,,,,, -35800,1.3653245,5.4100103,,,,,,,,,,,,,, -35814,,,0.581250011920929,1.7759336233139038,0.5402399897575378,1.991693377494812,50000.0,0.4217000305652618,2.6568825244903564,10000.0,16423.170471429825,18189.147683382034,16423.170471429825,1762.5099685192108,1.608259916305542,0.0 -35900,1.6633798,2.905393,,,,,,,,,,,,,, -36000,1.6675427,2.7400653,,,,,,,,,,,,,, -36100,1.7672094,2.7447724,,,,,,,,,,,,,, -36200,1.5520179,2.779314,,,,,,,,,,,,,, -36300,1.5892695,2.850189,,,,,,,,,,,,,, -36400,1.6531702,3.1864464,,,,,,,,,,,,,, -36500,1.4407043,3.985126,,,,,,,,,,,,,, -36600,1.7515364,2.7645943,,,,,,,,,,,,,, -36700,1.5204059,2.7721605,,,,,,,,,,,,,, -36737,,,0.5838086009025574,1.7920989990234375,0.5418800115585327,1.9881590604782104,50000.0,0.4239000082015991,2.632972478866577,10000.0,16843.52310347557,18653.83980345726,16843.52310347557,1806.7723808288567,1.6376621723175049,0.0 -36800,1.5291772,2.8037095,,,,,,,,,,,,,, -36900,1.239175,4.780972,,,,,,,,,,,,,, -37000,1.6568928,2.89363,,,,,,,,,,,,,, -37100,1.6711183,2.8392658,,,,,,,,,,,,,, -37200,1.6694977,3.2689643,,,,,,,,,,,,,, -37300,1.6179752,2.6158347,,,,,,,,,,,,,, -37400,1.5250096,2.6358385,,,,,,,,,,,,,, -37500,1.640432,2.9298499,,,,,,,,,,,,,, -37600,1.8178316,2.7132776,,,,,,,,,,,,,, -37657,,,0.5879687070846558,1.7495322227478027,0.5480200052261353,1.946392297744751,50000.0,0.4320000112056732,2.602356672286988,10000.0,17263.79153227806,19120.8027510643,17263.79153227806,1853.3759117126465,1.6815898418426514,0.0 -37700,1.5579271,3.126336,,,,,,,,,,,,,, -37800,1.3183312,4.2078605,,,,,,,,,,,,,, -37900,1.2723485,4.3523703,,,,,,,,,,,,,, -38000,1.5926336,2.699776,,,,,,,,,,,,,, -38100,1.1905986,5.396302,,,,,,,,,,,,,, -38200,1.4765114,3.5182881,,,,,,,,,,,,,, -38300,1.8381456,2.8584397,,,,,,,,,,,,,, -38400,1.4066617,3.9041405,,,,,,,,,,,,,, -38500,1.651307,2.716845,,,,,,,,,,,,,, -38580,,,0.5942773222923279,1.714832067489624,0.5480599999427795,1.926216721534729,50000.0,0.4335000216960907,2.5650787353515625,10000.0,17684.12568449974,19585.754362106323,17684.12568449974,1897.910856246948,1.71612548828125,0.0 -38600,1.8298742,2.8783529,,,,,,,,,,,,,, -38700,1.216278,4.081986,,,,,,,,,,,,,, -38800,1.6698794,2.7722812,,,,,,,,,,,,,, -38900,1.6877928,2.7379992,,,,,,,,,,,,,, -39000,1.4649918,3.289513,,,,,,,,,,,,,, -39100,1.714932,2.9347024,,,,,,,,,,,,,, -39200,1.5980093,3.2566001,,,,,,,,,,,,,, -39300,1.5788448,3.9140785,,,,,,,,,,,,,, -39400,1.4565529,3.828706,,,,,,,,,,,,,, -39500,1.9396158,2.8472517,,,,,,,,,,,,,, -39501,,,0.6201366782188416,1.614801287651062,0.5507599711418152,1.94295072555542,50000.0,0.4365000128746032,2.580235958099365,10000.0,18104.34602546692,20052.413615226746,18104.34602546692,1944.2633888721464,1.7545132637023926,0.0 -39600,1.7647103,2.8671343,,,,,,,,,,,,,, -39700,1.6397853,2.8989654,,,,,,,,,,,,,, -39800,1.6767244,2.6782696,,,,,,,,,,,,,, -39900,1.4493269,2.7388637,,,,,,,,,,,,,, -40000,1.2193147,4.8150797,,,,,,,,,,,,,, -40100,1.712349,2.7996902,,,,,,,,,,,,,, -40200,1.6846709,2.6332889,,,,,,,,,,,,,, -40300,1.7455902,2.7036412,,,,,,,,,,,,,, -40400,1.8244452,2.7392642,,,,,,,,,,,,,, -40422,,,0.5883983969688416,1.758009433746338,0.5488199591636658,1.9514132738113403,50000.0,0.430400013923645,2.604454517364502,10000.0,18524.72751617432,20521.472969293594,18524.72751617432,1992.8587412834167,1.7895410060882568,0.0 -40500,1.635957,2.7502909,,,,,,,,,,,,,, -40600,1.3940613,3.6712928,,,,,,,,,,,,,, -40700,1.545337,2.8174014,,,,,,,,,,,,,, -40800,1.6046752,3.0043519,,,,,,,,,,,,,, -40900,1.215743,5.2527266,,,,,,,,,,,,,, -41000,1.7600212,3.3336916,,,,,,,,,,,,,, -41100,1.7381364,2.842447,,,,,,,,,,,,,, -41200,1.5372627,2.7366846,,,,,,,,,,,,,, -41300,1.6508458,3.1718094,,,,,,,,,,,,,, -41344,,,0.6026757955551147,1.6814193725585938,0.5562199950218201,1.8999054431915283,50000.0,0.4415000081062317,2.542169570922852,10000.0,18944.705970048904,20985.605131864548,18944.705970048904,2036.9278779029848,1.8252899646759035,0.0 -41400,1.6089509,2.6343231,,,,,,,,,,,,,, -41500,1.2166919,4.558399,,,,,,,,,,,,,, -41600,1.435166,4.7487035,,,,,,,,,,,,,, -41700,1.317216,5.3389754,,,,,,,,,,,,,, -41800,1.9412863,2.6919048,,,,,,,,,,,,,, -41900,1.1936324,4.8626933,,,,,,,,,,,,,, -42000,1.6502585,2.7640135,,,,,,,,,,,,,, -42100,1.4061719,4.3690305,,,,,,,,,,,,,, -42200,1.6505864,2.746408,,,,,,,,,,,,,, -42267,,,0.6179882884025574,1.5960801839828491,0.560479998588562,1.867920994758606,50000.0,0.442300021648407,2.526688814163208,10000.0,19364.84474134445,21454.48744463921,19364.84474134445,2085.58514547348,1.8638558387756348,0.0 -42300,1.6045246,2.737645,,,,,,,,,,,,,, -42400,1.5101049,3.3751192,,,,,,,,,,,,,, -42500,1.7054868,2.9084296,,,,,,,,,,,,,, -42600,1.2535486,3.9749646,,,,,,,,,,,,,, -42700,1.962066,2.6956391,,,,,,,,,,,,,, -42800,1.58087,2.6739259,,,,,,,,,,,,,, -42900,1.6651448,2.7050788,,,,,,,,,,,,,, -43000,1.6332985,2.8296807,,,,,,,,,,,,,, -43100,1.8219367,2.7658036,,,,,,,,,,,,,, -43187,,,0.5931640267372131,1.7232056856155396,0.5553199648857117,1.9169193506240845,50000.0,0.4412000179290771,2.5503687858581543,10000.0,19785.094200372696,21920.119605779648,19785.094200372696,2130.889495611191,1.8945605754852293,0.0 -43200,1.7610015,2.7801542,,,,,,,,,,,,,, -43300,1.691426,2.9970326,,,,,,,,,,,,,, -43400,1.5750289,2.72829,,,,,,,,,,,,,, -43500,1.4552252,3.4474027,,,,,,,,,,,,,, -43600,1.8421963,2.8761148,,,,,,,,,,,,,, -43700,1.5097835,3.4620228,,,,,,,,,,,,,, -43800,1.287828,4.273528,,,,,,,,,,,,,, -43900,1.555085,4.769706,,,,,,,,,,,,,, -44000,1.6709888,2.690671,,,,,,,,,,,,,, -44100,1.6200387,2.683463,,,,,,,,,,,,,, -44108,,,0.6003515720367432,1.6775927543640137,0.5586400032043457,1.8800852298736568,50000.0,0.4472000300884247,2.5132040977478027,10000.0,20205.310017347336,22388.53720092773,20205.310017347336,2179.0115325450897,1.9270873069763184,0.0 -44200,1.558547,3.0469017,,,,,,,,,,,,,, -44300,1.877537,2.764672,,,,,,,,,,,,,, -44400,1.9028131,2.7836218,,,,,,,,,,,,,, -44500,1.847049,2.8028905,,,,,,,,,,,,,, -44600,1.6242838,3.0288327,,,,,,,,,,,,,, -44700,1.6179779,2.7206118,,,,,,,,,,,,,, -44800,1.6147726,3.2223024,,,,,,,,,,,,,, -44900,1.70226,2.9079547,,,,,,,,,,,,,, -45000,1.5228059,3.1199656,,,,,,,,,,,,,, -45030,,,0.6131640672683716,1.6342071294784546,0.5626999735832214,1.8697541952133176,50000.0,0.4420000314712524,2.5277490615844727,10000.0,20625.409114599228,22855.010596990585,20625.409114599228,2225.3049223423004,1.9600763320922847,0.0 -45100,1.7876662,2.698992,,,,,,,,,,,,,, -45200,1.3668771,5.1397433,,,,,,,,,,,,,, -45300,1.9472896,2.7846096,,,,,,,,,,,,,, -45400,1.7695317,2.715885,,,,,,,,,,,,,, -45500,1.810922,2.6479688,,,,,,,,,,,,,, -45600,1.5995855,3.461706,,,,,,,,,,,,,, -45700,1.7266271,2.825005,,,,,,,,,,,,,, -45800,1.8260835,2.7307281,,,,,,,,,,,,,, -45900,1.5754424,3.4347813,,,,,,,,,,,,,, -45952,,,0.6003320217132568,1.708828091621399,0.5588200092315674,1.894186973571777,50000.0,0.4379000067710876,2.550467729568481,10000.0,21045.712456464767,23322.01553273201,21045.712456464767,2271.9245131015778,1.9946684837341309,0.0 -46000,1.5653274,3.1821032,,,,,,,,,,,,,, -46100,1.5662303,3.3684044,,,,,,,,,,,,,, -46200,1.2885958,4.063463,,,,,,,,,,,,,, -46300,1.7713866,2.6900823,,,,,,,,,,,,,, -46400,1.3200682,4.0153046,,,,,,,,,,,,,, -46500,1.5884477,3.313675,,,,,,,,,,,,,, -46600,2.2853894,2.7251606,,,,,,,,,,,,,, -46700,1.6522781,2.6363451,,,,,,,,,,,,,, -46800,1.7861228,2.8096852,,,,,,,,,,,,,, -46874,,,0.6097265481948853,1.6467933654785156,0.5633599758148193,1.8516621589660645,50000.0,0.4422000348567962,2.517657518386841,10000.0,21465.93386626244,23790.59200644493,21465.93386626244,2320.201394081116,2.025162935256958,0.0 -46900,1.4508466,4.957243,,,,,,,,,,,,,, -47000,1.531621,2.6022696,,,,,,,,,,,,,, -47100,1.760965,2.6301477,,,,,,,,,,,,,, -47200,1.3595288,4.9247684,,,,,,,,,,,,,, -47300,1.7376444,2.7812924,,,,,,,,,,,,,, -47400,1.5900934,3.5190063,,,,,,,,,,,,,, -47500,1.6065078,2.696062,,,,,,,,,,,,,, -47600,1.6482764,2.6402333,,,,,,,,,,,,,, -47700,1.7135717,2.6619027,,,,,,,,,,,,,, -47793,,,0.6098241806030273,1.677674412727356,0.5647000074386597,1.89734959602356,50000.0,0.4451000094413757,2.5339388847351074,10000.0,21886.19406223297,24254.98417496681,21886.19406223297,2364.2496979236603,2.061084032058716,0.0 -47800,1.499064,4.4917836,,,,,,,,,,,,,, -47900,1.7922322,2.655634,,,,,,,,,,,,,, -48000,1.3881574,3.8716068,,,,,,,,,,,,,, -48100,1.4172366,4.493483,,,,,,,,,,,,,, -48200,1.9459842,2.7498837,,,,,,,,,,,,,, -48300,1.5210937,2.8131511,,,,,,,,,,,,,, -48400,1.8132362,2.6946359,,,,,,,,,,,,,, -48500,1.8338164,2.8035398,,,,,,,,,,,,,, -48600,1.9140404,2.5247195,,,,,,,,,,,,,, -48700,1.7285563,2.6440392,,,,,,,,,,,,,, -48711,,,0.6181249618530273,1.6265902519226074,0.5595600008964539,1.886021375656128,50000.0,0.443200021982193,2.5331673622131348,10000.0,22306.34263277054,24723.15763783455,22306.34263277054,2412.188717842102,2.0992519855499268,0.0 -48800,1.7374299,2.6082206,,,,,,,,,,,,,, -48900,1.7154055,2.5406828,,,,,,,,,,,,,, -49000,1.5480566,4.908635,,,,,,,,,,,,,, -49100,1.4405524,5.364021,,,,,,,,,,,,,, -49200,1.8799394,2.5654566,,,,,,,,,,,,,, -49300,1.7405195,3.1424677,,,,,,,,,,,,,, -49400,1.7197332,2.648619,,,,,,,,,,,,,, -49500,1.5631793,3.114627,,,,,,,,,,,,,, -49600,1.8318237,2.6506646,,,,,,,,,,,,,, -49631,,,0.6079882383346558,1.6691138744354248,0.5669800043106079,1.862606406211853,50000.0,0.444100022315979,2.522622585296631,10000.0,22726.50920295716,25191.15389060974,22726.50920295716,2459.936644077301,2.133064985275269,0.0 -49700,1.854581,2.5788286,,,,,,,,,,,,,, -49800,1.3590194,5.2309256,,,,,,,,,,,,,, -49900,1.2365518,4.5355167,,,,,,,,,,,,,, -50000,1.6505798,2.5772028,,,,,,,,,,,,,, -50100,1.2304131,4.103357,,,,,,,,,,,,,, -50200,1.6110617,2.9005573,,,,,,,,,,,,,, -50300,1.3795109,4.191928,,,,,,,,,,,,,, -50400,1.4594984,3.3783178,,,,,,,,,,,,,, -50500,1.6454623,2.6365936,,,,,,,,,,,,,, -50554,,,0.6143358945846558,1.6266354322433472,0.5703999996185303,1.8336740732192995,50000.0,0.4585000276565552,2.471269369125366,10000.0,23146.88272571564,25658.474204063416,23146.88272571564,2506.8026852607727,2.166099071502685,0.0 -50600,1.3768042,5.2018,,,,,,,,,,,,,, -50700,1.4260474,3.3915687,,,,,,,,,,,,,, -50800,1.8370321,2.5478008,,,,,,,,,,,,,, -50900,1.8124207,2.6548727,,,,,,,,,,,,,, -51000,1.5362443,3.2575936,,,,,,,,,,,,,, -51100,1.8233521,2.699181,,,,,,,,,,,,,, -51200,1.7463224,2.590456,,,,,,,,,,,,,, -51300,1.7495028,2.645889,,,,,,,,,,,,,, -51400,1.578599,3.1093578,,,,,,,,,,,,,, -51474,,,0.6382616758346558,1.5179091691970823,0.5744999647140503,1.8161741495132449,50000.0,0.4506000280380249,2.467916250228882,10000.0,23566.82369351387,26122.45309472084,23566.82369351387,2550.7530856132507,2.205763101577759,0.0 -51500,1.6837587,2.7782478,,,,,,,,,,,,,, -51600,1.6123911,4.4869037,,,,,,,,,,,,,, -51700,1.8046869,2.567341,,,,,,,,,,,,,, -51800,1.487796,3.668038,,,,,,,,,,,,,, -51900,1.435046,3.5039358,,,,,,,,,,,,,, -52000,1.5001476,3.1080613,,,,,,,,,,,,,, -52100,1.3489299,3.893174,,,,,,,,,,,,,, -52200,1.7348924,2.7399673,,,,,,,,,,,,,, -52300,1.59345,2.994236,,,,,,,,,,,,,, -52393,,,0.615527331829071,1.6107457876205444,0.5764999985694885,1.7930006980895996,50000.0,0.4531000256538391,2.4591097831726074,10000.0,23987.047943353653,26589.09598493576,23987.047943353653,2597.088151693344,2.242366075515747,0.0 -52400,1.6267282,2.7888062,,,,,,,,,,,,,, -52500,1.7381097,2.5773542,,,,,,,,,,,,,, -52600,1.6172459,3.2932444,,,,,,,,,,,,,, -52700,2.0711324,2.6428356,,,,,,,,,,,,,, -52800,1.8438051,2.6991525,,,,,,,,,,,,,, -52900,1.4971467,4.086401,,,,,,,,,,,,,, -53000,1.5298194,4.9798083,,,,,,,,,,,,,, -53100,1.8447706,2.6236162,,,,,,,,,,,,,, -53200,1.6212941,2.6919813,,,,,,,,,,,,,, -53300,1.3973664,5.1323876,,,,,,,,,,,,,, -53314,,,0.6172460913658142,1.6188424825668335,0.5728799700737,1.8209742307662964,50000.0,0.4622000157833099,2.461902141571045,10000.0,24407.272254943848,27058.116228818893,24407.272254943848,2645.788686990738,2.2893104553222656,0.0 -53400,1.7496438,2.9086823,,,,,,,,,,,,,, -53500,1.4831655,2.9438744,,,,,,,,,,,,,, -53600,1.9042718,2.5497725,,,,,,,,,,,,,, -53700,1.4780006,3.2894273,,,,,,,,,,,,,, -53800,1.5680083,3.475532,,,,,,,,,,,,,, -53900,1.4521471,4.3913355,,,,,,,,,,,,,, -54000,1.2606878,4.3440104,,,,,,,,,,,,,, -54100,1.5980666,3.5316234,,,,,,,,,,,,,, -54200,1.6515354,2.6564689,,,,,,,,,,,,,, -54234,,,0.6295117139816284,1.57720947265625,0.572219967842102,1.8321927785873413,50000.0,0.4579000174999237,2.464775800704956,10000.0,24827.344648122787,27526.469512939453,24827.344648122787,2693.9905047416687,2.3210394382476807,0.0 -54300,1.4088329,3.826002,,,,,,,,,,,,,, -54400,1.8572016,2.648033,,,,,,,,,,,,,, -54500,1.7173818,2.5973692,,,,,,,,,,,,,, -54600,1.585849,3.2216594,,,,,,,,,,,,,, -54700,1.8368608,2.6577263,,,,,,,,,,,,,, -54800,1.363298,4.342125,,,,,,,,,,,,,, -54900,1.7862446,2.5221264,,,,,,,,,,,,,, -55000,2.152441,2.753241,,,,,,,,,,,,,, -55100,1.7175382,2.5314562,,,,,,,,,,,,,, -55152,,,0.6147655844688416,1.6305612325668335,0.5796599984169006,1.788099765777588,50000.0,0.46670001745224,2.4314701557159424,10000.0,25247.712097883224,27995.00447773933,25247.712097883224,2742.075278520584,2.35662841796875,0.0 -55200,1.8308531,2.7682185,,,,,,,,,,,,,, -55300,1.7293048,2.6504703,,,,,,,,,,,,,, -55400,1.8156949,2.4985402,,,,,,,,,,,,,, -55500,1.9173115,2.607765,,,,,,,,,,,,,, -55600,1.7301699,2.6977577,,,,,,,,,,,,,, -55700,1.6290559,2.6160705,,,,,,,,,,,,,, -55800,1.7380266,2.6153123,,,,,,,,,,,,,, -55900,1.7563374,2.623086,,,,,,,,,,,,,, -56000,1.7599705,2.8648448,,,,,,,,,,,,,, -56070,,,0.6223242282867432,1.5823334455490112,0.5771600008010864,1.789430856704712,50000.0,0.4645000100135803,2.429868221282959,10000.0,25667.8112885952,28459.00898051262,25667.8112885952,2785.8978073596954,2.391852140426636,0.0 -56100,1.6725608,3.071228,,,,,,,,,,,,,, -56200,1.9197217,2.5343037,,,,,,,,,,,,,, -56300,1.6625129,2.5686908,,,,,,,,,,,,,, -56400,1.6821891,2.726592,,,,,,,,,,,,,, -56500,1.7759378,2.7382734,,,,,,,,,,,,,, -56600,1.5907332,5.3633566,,,,,,,,,,,,,, -56700,1.756573,2.4974458,,,,,,,,,,,,,, -56800,1.2902043,4.6046114,,,,,,,,,,,,,, -56900,1.4028974,4.611592,,,,,,,,,,,,,, -56989,,,0.6273437142372131,1.5826400518417358,0.5783799886703491,1.807153582572937,50000.0,0.457800030708313,2.4639346599578857,10000.0,26088.068242549896,28927.74728178978,26088.068242549896,2834.294125318527,2.429672718048096,0.0 -57000,1.4395919,4.6355476,,,,,,,,,,,,,, -57100,1.2692777,4.606334,,,,,,,,,,,,,, -57200,1.5127195,4.4725437,,,,,,,,,,,,,, -57300,1.7611417,3.4548626,,,,,,,,,,,,,, -57400,1.790142,2.4660676,,,,,,,,,,,,,, -57500,1.7999033,2.8362603,,,,,,,,,,,,,, -57600,1.5836586,3.688465,,,,,,,,,,,,,, -57700,1.9762081,2.6158009,,,,,,,,,,,,,, -57800,1.4571105,3.2743654,,,,,,,,,,,,,, -57900,1.7022502,2.6802716,,,,,,,,,,,,,, -57909,,,0.6226171851158142,1.5738970041275024,0.578719973564148,1.7816462516784668,50000.0,0.4639000296592712,2.410048246383667,10000.0,26508.35567474365,29394.660277605057,26508.35567474365,2880.826591491699,2.475525379180908,0.0 -58000,1.7900238,2.5431616,,,,,,,,,,,,,, -58100,1.678855,4.152913,,,,,,,,,,,,,, -58200,1.5005498,4.3261757,,,,,,,,,,,,,, -58300,1.6864189,2.5711136,,,,,,,,,,,,,, -58400,1.9214479,2.5265255,,,,,,,,,,,,,, -58500,1.5016137,3.8204114,,,,,,,,,,,,,, -58600,1.7695441,2.9272227,,,,,,,,,,,,,, -58700,1.6885055,2.653877,,,,,,,,,,,,,, -58800,2.211619,2.6315644,,,,,,,,,,,,,, -58827,,,0.6302929520606995,1.5352495908737185,0.5891000032424927,1.726948857307434,50000.0,0.4725000262260437,2.381237268447876,10000.0,26928.282299280167,29861.34847187996,26928.282299280167,2927.505961894989,2.509864330291748,0.0 -58900,1.5540164,4.1243606,,,,,,,,,,,,,, -59000,1.763447,2.5506818,,,,,,,,,,,,,, -59100,1.533033,3.3189921,,,,,,,,,,,,,, -59200,1.8864795,2.4729664,,,,,,,,,,,,,, -59300,1.2852418,5.17265,,,,,,,,,,,,,, -59400,1.43842,5.329682,,,,,,,,,,,,,, -59500,1.5745777,5.2174816,,,,,,,,,,,,,, -59600,1.7654749,2.527661,,,,,,,,,,,,,, -59700,1.2856768,4.959182,,,,,,,,,,,,,, -59747,,,0.6289257407188416,1.5401300191879272,0.581559956073761,1.7661179304122925,50000.0,0.4588000178337097,2.400780200958252,10000.0,27348.234403848648,30329.973765850067,27348.234403848648,2976.0984501838684,2.5432207584381104,0.0 -59800,1.5552862,2.9312065,,,,,,,,,,,,,, -59900,1.7559135,2.5250895,,,,,,,,,,,,,, -60000,1.4937239,3.5555308,,,,,,,,,,,,,, -60100,1.7810823,2.5795932,,,,,,,,,,,,,, -60200,1.8952651,2.5391145,,,,,,,,,,,,,, -60300,1.6770684,4.1002245,,,,,,,,,,,,,, -60400,2.0714345,2.5356848,,,,,,,,,,,,,, -60500,2.0251613,2.6834047,,,,,,,,,,,,,, -60600,1.7435814,3.005959,,,,,,,,,,,,,, -60667,,,0.6617382764816284,1.412854790687561,0.5889999866485596,1.731900691986084,50000.0,0.4660000205039978,2.385641574859619,10000.0,27768.30704021454,30796.25399804116,27768.30704021454,3022.2227504253387,2.578721761703491,0.0 -60700,1.7239915,2.664824,,,,,,,,,,,,,, -60800,1.9077394,2.6739426,,,,,,,,,,,,,, -60900,1.6171769,3.7032228,,,,,,,,,,,,,, -61000,1.7242415,4.484942,,,,,,,,,,,,,, -61100,1.878987,2.7084854,,,,,,,,,,,,,, -61200,1.3778706,4.1031046,,,,,,,,,,,,,, -61300,1.7570353,2.5744014,,,,,,,,,,,,,, -61400,1.866592,2.5783703,,,,,,,,,,,,,, -61500,2.3746204,2.519454,,,,,,,,,,,,,, -61583,,,0.6290038824081421,1.5284565687179563,0.5866400003433228,1.7359672784805298,50000.0,0.4729000329971313,2.363012313842773,10000.0,28188.32996058464,31264.258637428284,28188.32996058464,3070.1183915138245,2.6172053813934326,0.0 -61600,1.8307029,2.890956,,,,,,,,,,,,,, -61700,1.4551417,4.0429807,,,,,,,,,,,,,, -61800,1.7078623,2.6747046,,,,,,,,,,,,,, -61900,1.8353164,2.5675032,,,,,,,,,,,,,, -62000,1.7949133,2.589387,,,,,,,,,,,,,, -62100,1.8629413,2.8492362,,,,,,,,,,,,,, -62200,1.8262289,2.5234625,,,,,,,,,,,,,, -62300,1.5331472,5.183671,,,,,,,,,,,,,, -62400,1.6531166,2.6083965,,,,,,,,,,,,,, -62500,1.9168137,3.2992957,,,,,,,,,,,,,, -62501,,,0.6339648365974426,1.5344921350479126,0.5881800055503845,1.743263602256775,50000.0,0.4690000116825104,2.3870527744293213,10000.0,28608.93250131607,31731.734622240067,28608.93250131607,3116.9010808467865,2.6603078842163086,0.0 -62600,1.6944035,2.6736755,,,,,,,,,,,,,, -62700,1.5752989,4.17009,,,,,,,,,,,,,, -62800,1.3793075,4.8634562,,,,,,,,,,,,,, -62900,1.8212509,2.4938874,,,,,,,,,,,,,, -63000,1.960373,2.5844336,,,,,,,,,,,,,, -63100,1.7705263,2.4970481,,,,,,,,,,,,,, -63200,1.525664,4.4037027,,,,,,,,,,,,,, -63300,1.6945956,3.00378,,,,,,,,,,,,,, -63400,1.3674582,5.142801,,,,,,,,,,,,,, -63421,,,0.6504101157188416,1.453066349029541,0.5853399634361267,1.746640682220459,50000.0,0.4687000215053558,2.4003686904907227,10000.0,29029.071456193924,32200.69903111458,29029.071456193924,3165.640180826187,2.699095726013184,0.0 -63500,1.6072698,3.56136,,,,,,,,,,,,,, -63600,1.552408,4.468804,,,,,,,,,,,,,, -63700,1.7840029,2.5510135,,,,,,,,,,,,,, -63800,1.3975196,4.435207,,,,,,,,,,,,,, -63900,1.861427,4.7188377,,,,,,,,,,,,,, -64000,1.740351,2.5480373,,,,,,,,,,,,,, -64100,1.9191921,2.847892,,,,,,,,,,,,,, -64200,1.5082177,5.0937195,,,,,,,,,,,,,, -64300,1.8603538,2.4473963,,,,,,,,,,,,,, -64341,,,0.6320703029632568,1.5366462469100952,0.588979959487915,1.7273471355438232,50000.0,0.4739000201225281,2.3589258193969727,10000.0,29449.20278072357,32667.884612083435,29449.20278072357,3212.613987445832,2.732761144638061,0.0 -64400,1.4941133,4.900006,,,,,,,,,,,,,, -64500,1.9165795,2.553666,,,,,,,,,,,,,, -64600,1.8225418,2.9597745,,,,,,,,,,,,,, -64700,1.8076783,2.5064492,,,,,,,,,,,,,, -64800,1.5945908,2.9469883,,,,,,,,,,,,,, -64900,1.8582774,3.5226226,,,,,,,,,,,,,, -65000,1.5149754,4.7719674,,,,,,,,,,,,,, -65100,1.4763155,3.6208022,,,,,,,,,,,,,, -65200,1.8819145,2.6082628,,,,,,,,,,,,,, -65261,,,0.6381250023841858,1.5205624103546145,0.5915200114250183,1.7391788959503174,50000.0,0.4754000306129455,2.375885486602783,10000.0,29869.36788988113,33134.08188343048,29869.36788988113,3258.559402942657,2.7715091705322266,0.0 -65300,1.5482659,5.1355677,,,,,,,,,,,,,, -65400,1.7950221,2.5520678,,,,,,,,,,,,,, -65500,1.7760917,2.5895665,,,,,,,,,,,,,, -65600,1.797239,2.8689485,,,,,,,,,,,,,, -65700,2.0204017,2.5588856,,,,,,,,,,,,,, -65800,1.7511375,2.5173273,,,,,,,,,,,,,, -65900,1.8078674,3.1727805,,,,,,,,,,,,,, -66000,1.9341292,2.4627807,,,,,,,,,,,,,, -66100,1.4807547,4.5221405,,,,,,,,,,,,,, -66184,,,0.6523046493530273,1.447373628616333,0.595259964466095,1.7073568105697632,50000.0,0.4763000309467315,2.3564023971557617,10000.0,30289.691901922222,33600.01197576523,30289.691901922222,3304.079433441162,2.809138774871826,0.0 -66200,1.7286925,2.8819833,,,,,,,,,,,,,, -66300,1.79976,2.4424338,,,,,,,,,,,,,, -66400,1.8095671,2.9924695,,,,,,,,,,,,,, -66500,2.0458643,2.4515157,,,,,,,,,,,,,, -66600,1.8163748,2.5902271,,,,,,,,,,,,,, -66700,1.6050891,5.089095,,,,,,,,,,,,,, -66800,1.8582728,3.0517786,,,,,,,,,,,,,, -66900,1.562229,4.2049437,,,,,,,,,,,,,, -67000,1.5236118,4.889596,,,,,,,,,,,,,, -67100,1.7840924,2.5059345,,,,,,,,,,,,,, -67101,,,0.6419531106948853,1.4803026914596558,0.6005600094795227,1.6780091524124146,50000.0,0.4821000099182129,2.307164669036865,10000.0,30710.14184713364,34067.99260735512,30710.14184713364,3351.523953676224,2.8487424850463867,0.0 -67200,1.5126866,3.9872472,,,,,,,,,,,,,, -67300,1.5869576,4.0751224,,,,,,,,,,,,,, -67400,1.9947438,2.9046988,,,,,,,,,,,,,, -67500,2.1914113,2.6236415,,,,,,,,,,,,,, -67600,1.8524307,2.709061,,,,,,,,,,,,,, -67700,2.0544877,2.554618,,,,,,,,,,,,,, -67800,1.9030606,2.521593,,,,,,,,,,,,,, -67900,1.8080528,5.221381,,,,,,,,,,,,,, -68000,1.7200323,3.210474,,,,,,,,,,,,,, -68020,,,0.6422265768051147,1.4910420179367063,0.5968599915504456,1.6952730417251587,50000.0,0.4836000204086303,2.322556495666504,10000.0,31130.390008687973,34534.2478351593,31130.390008687973,3397.4480471611023,2.8839402198791504,0.0 -68100,1.6709703,3.2026782,,,,,,,,,,,,,, -68200,1.5778613,3.0014143,,,,,,,,,,,,,, -68300,1.7089128,2.8581512,,,,,,,,,,,,,, -68400,1.6099155,3.619467,,,,,,,,,,,,,, -68500,2.064507,2.5596814,,,,,,,,,,,,,, -68600,1.8316236,2.4875336,,,,,,,,,,,,,, -68700,1.839761,2.4006999,,,,,,,,,,,,,, -68800,1.4987534,5.0025887,,,,,,,,,,,,,, -68900,1.927319,2.6706626,,,,,,,,,,,,,, -68940,,,0.646289050579071,1.4966928958892822,0.5961999893188477,1.7258793115615845,50000.0,0.4816000163555145,2.361830711364746,10000.0,31550.531436681747,34999.48217225075,31550.531436681747,3442.4561920166016,2.921594381332397,0.0 -69000,2.0283968,2.4401712,,,,,,,,,,,,,, -69100,1.8435129,2.5097637,,,,,,,,,,,,,, -69200,1.9478865,2.436508,,,,,,,,,,,,,, -69300,1.832449,2.5195355,,,,,,,,,,,,,, -69400,2.0116603,2.5170016,,,,,,,,,,,,,, -69500,1.9050989,2.5952477,,,,,,,,,,,,,, -69600,1.8320013,2.5268962,,,,,,,,,,,,,, -69700,1.7517561,2.5417318,,,,,,,,,,,,,, -69800,1.8531897,3.6734638,,,,,,,,,,,,,, -69862,,,0.6507031321525574,1.4504530429840088,0.6015200018882751,1.6813979148864746,50000.0,0.4838000237941742,2.3317134380340576,10000.0,31970.809402942657,35468.415583610535,31970.809402942657,3491.027411222458,2.957694292068481,0.0 -69900,1.882474,2.5903547,,,,,,,,,,,,,, -70000,1.8806832,2.4556956,,,,,,,,,,,,,, -70100,1.7670432,2.5579734,,,,,,,,,,,,,, -70200,1.8745306,2.9019032,,,,,,,,,,,,,, -70300,1.7563356,3.756946,,,,,,,,,,,,,, -70400,1.6345509,2.6999326,,,,,,,,,,,,,, -70500,1.8277037,2.5571716,,,,,,,,,,,,,, -70600,2.1954114,2.5492783,,,,,,,,,,,,,, -70700,1.9400688,2.5914013,,,,,,,,,,,,,, -70783,,,0.6480273008346558,1.50771164894104,0.600879967212677,1.7230969667434692,50000.0,0.4803000092506408,2.354217767715454,10000.0,32390.82547283173,35935.01786708832,32390.82547283173,3537.5296635627747,2.994217157363892,0.0 -70800,1.8179154,2.4944642,,,,,,,,,,,,,, -70900,1.5090839,3.9961758,,,,,,,,,,,,,, -71000,1.8205814,5.160287,,,,,,,,,,,,,, -71100,1.8566165,2.4901204,,,,,,,,,,,,,, -71200,2.0418243,2.684828,,,,,,,,,,,,,, -71300,2.02031,2.7403002,,,,,,,,,,,,,, -71400,1.4825642,3.7745225,,,,,,,,,,,,,, -71500,1.8436042,2.8571603,,,,,,,,,,,,,, -71600,1.7033674,3.7945676,,,,,,,,,,,,,, -71700,2.0082178,2.3989081,,,,,,,,,,,,,, -71703,,,0.656054675579071,1.4267289638519287,0.6021999716758728,1.6627541780471802,50000.0,0.4830000102519989,2.303077459335327,10000.0,32810.79419326782,36402.86121177673,32810.79419326782,3585.316800355912,3.033379316329956,0.0 -71800,1.4714214,4.371188,,,,,,,,,,,,,, -71900,1.8787645,2.4200695,,,,,,,,,,,,,, -72000,1.8262161,2.9046772,,,,,,,,,,,,,, -72100,1.5766335,4.9479675,,,,,,,,,,,,,, -72200,1.7688107,2.5848424,,,,,,,,,,,,,, -72300,1.6289006,3.1918886,,,,,,,,,,,,,, -72400,1.3283429,4.227413,,,,,,,,,,,,,, -72500,1.8385171,2.559494,,,,,,,,,,,,,, -72600,1.6342688,3.4769936,,,,,,,,,,,,,, -72620,,,0.6702929735183716,1.386479735374451,0.605679988861084,1.6824146509170532,50000.0,0.4906000196933746,2.331911087036133,10000.0,33231.07704329491,36870.46979093552,33231.07704329491,3632.553422927856,3.0752978324890137,0.0 -72700,1.7533803,2.7042985,,,,,,,,,,,,,, -72800,1.4249303,4.599486,,,,,,,,,,,,,, -72900,1.5235181,5.0816875,,,,,,,,,,,,,, -73000,1.7882519,4.026247,,,,,,,,,,,,,, -73100,1.9268925,2.4537132,,,,,,,,,,,,,, -73200,1.7421777,5.095711,,,,,,,,,,,,,, -73300,2.0150688,2.702648,,,,,,,,,,,,,, -73400,1.8498518,2.6560984,,,,,,,,,,,,,, -73500,2.017287,2.7125936,,,,,,,,,,,,,, -73540,,,0.6487694978713989,1.481659770011902,0.6056999564170837,1.6786365509033203,50000.0,0.4826000332832336,2.3274142742156982,10000.0,33651.26502633095,37337.673639297485,33651.26502633095,3679.482887744904,3.113532781600952,0.0 -73600,1.9730028,2.514688,,,,,,,,,,,,,, -73700,1.854787,2.7507825,,,,,,,,,,,,,, -73800,1.9956374,2.4157588,,,,,,,,,,,,,, -73900,1.4881942,4.982486,,,,,,,,,,,,,, -74000,2.0094616,2.47301,,,,,,,,,,,,,, -74100,2.149066,5.227218,,,,,,,,,,,,,, -74200,2.0794067,2.5923893,,,,,,,,,,,,,, -74300,1.9953216,2.4196699,,,,,,,,,,,,,, -74400,1.928469,4.008633,,,,,,,,,,,,,, -74460,,,0.650585949420929,1.4430195093154907,0.6067399978637695,1.657494306564331,50000.0,0.4852000176906585,2.282309293746948,10000.0,34071.27830410004,37807.56340956688,34071.27830410004,3729.273057460785,3.151975631713867,0.0 -74500,1.8398126,2.4514136,,,,,,,,,,,,,, -74600,1.6992092,3.0805519,,,,,,,,,,,,,, -74700,1.9466535,5.102112,,,,,,,,,,,,,, -74800,1.8510339,2.4627905,,,,,,,,,,,,,, -74900,1.6596125,4.1831675,,,,,,,,,,,,,, -75000,1.472621,4.7512846,,,,,,,,,,,,,, -75100,1.797337,2.618376,,,,,,,,,,,,,, -75200,2.0047069,2.4007375,,,,,,,,,,,,,, -75300,1.8942719,2.4573386,,,,,,,,,,,,,, -75382,,,0.6629296541213989,1.3945977687835691,0.6041199564933777,1.6651949882507324,50000.0,0.486700028181076,2.301242113113404,10000.0,34491.6043074131,38275.6949763298,34491.6043074131,3776.994529485704,3.188735008239746,0.0 -75400,2.05396,2.5203354,,,,,,,,,,,,,, -75500,1.8502282,2.6157951,,,,,,,,,,,,,, -75600,1.8226442,3.145233,,,,,,,,,,,,,, -75700,1.6117902,3.7808137,,,,,,,,,,,,,, -75800,2.034318,2.7566288,,,,,,,,,,,,,, -75900,1.9814107,2.4488926,,,,,,,,,,,,,, -76000,1.9270967,2.7558992,,,,,,,,,,,,,, -76100,1.9556551,2.3673086,,,,,,,,,,,,,, -76200,1.9641732,2.4851086,,,,,,,,,,,,,, -76300,2.1362643,2.5409048,,,,,,,,,,,,,, -76303,,,0.6522851586341858,1.4518017768859863,0.6080399751663208,1.6486347913742063,50000.0,0.4861000180244446,2.312741279602051,10000.0,34911.860114097595,38743.56843471527,34911.860114097595,3824.521594762802,3.2313475608825684,0.0 -76400,1.625725,4.640031,,,,,,,,,,,,,, -76500,1.782538,4.0334873,,,,,,,,,,,,,, -76600,1.7694525,4.9700475,,,,,,,,,,,,,, -76700,1.8724791,2.5991013,,,,,,,,,,,,,, -76800,1.8597145,3.9854677,,,,,,,,,,,,,, -76900,1.6232072,4.646215,,,,,,,,,,,,,, -77000,2.1170058,2.3941274,,,,,,,,,,,,,, -77100,2.1180692,2.4121804,,,,,,,,,,,,,, -77200,1.9623309,2.4482856,,,,,,,,,,,,,, -77223,,,0.6526171565055847,1.4243426322937012,0.605459988117218,1.6445677280426023,50000.0,0.4901000261306762,2.283908605575561,10000.0,35331.959324359894,39210.66960573197,35331.959324359894,3871.440359354019,3.267023801803589,0.0 -77300,1.9309728,2.505776,,,,,,,,,,,,,, -77400,1.81563,2.6540608,,,,,,,,,,,,,, -77500,1.9235576,2.4173663,,,,,,,,,,,,,, -77600,1.854867,2.891266,,,,,,,,,,,,,, -77700,1.9835598,2.5701954,,,,,,,,,,,,,, -77800,1.6006515,4.9290357,,,,,,,,,,,,,, -77900,1.7271774,2.9271057,,,,,,,,,,,,,, -78000,1.9028804,2.3486338,,,,,,,,,,,,,, -78100,2.0576432,2.5040286,,,,,,,,,,,,,, -78143,,,0.6633203029632568,1.4052098989486694,0.6102399826049805,1.643385887145996,50000.0,0.4957000315189361,2.271117210388184,10000.0,35752.245290756226,39680.14678049088,35752.245290756226,3920.543641090393,3.3071439266204834,0.0 -78200,1.9903902,2.4816256,,,,,,,,,,,,,, -78300,1.6295245,4.075487,,,,,,,,,,,,,, -78400,2.1651084,2.4273598,,,,,,,,,,,,,, -78500,1.8979071,2.4650962,,,,,,,,,,,,,, -78600,1.8284781,3.0857646,,,,,,,,,,,,,, -78700,2.1025252,2.430925,,,,,,,,,,,,,, -78800,1.3926082,4.790651,,,,,,,,,,,,,, -78900,1.7748023,5.0339737,,,,,,,,,,,,,, -79000,1.8253263,3.9374614,,,,,,,,,,,,,, -79065,,,0.6604882478713989,1.4064991474151611,0.6160799860954285,1.6271092891693115,50000.0,0.4962000250816345,2.260464668273926,10000.0,36172.618824243546,40148.34972453117,36172.618824243546,3968.289653062821,3.3433430194854736,0.0 -79100,1.965138,2.4188719,,,,,,,,,,,,,, -79200,1.6139007,5.0174813,,,,,,,,,,,,,, -79300,1.9853276,2.5299482,,,,,,,,,,,,,, -79400,1.8553773,2.3928225,,,,,,,,,,,,,, -79500,1.9969493,2.4002807,,,,,,,,,,,,,, -79600,1.5706731,4.2832103,,,,,,,,,,,,,, -79700,1.6005203,4.970613,,,,,,,,,,,,,, -79800,2.0084713,2.548378,,,,,,,,,,,,,, -79900,1.892204,2.4359097,,,,,,,,,,,,,, -79985,,,0.6540429592132568,1.444997787475586,0.6112200021743774,1.641564965248108,50000.0,0.4923000335693359,2.274275779724121,10000.0,36592.68895483017,40617.90951514244,36592.68895483017,4017.691013813019,3.384563684463501,0.0 -80000,1.9894764,2.2920363,,,,,,,,,,,,,, -80100,1.8560992,2.6800842,,,,,,,,,,,,,, -80200,1.7242606,3.390758,,,,,,,,,,,,,, -80300,1.8673797,2.437901,,,,,,,,,,,,,, -80400,1.6749946,4.0421543,,,,,,,,,,,,,, -80500,1.7259198,4.803577,,,,,,,,,,,,,, -80600,1.9426982,2.4962842,,,,,,,,,,,,,, -80700,2.080089,2.5275817,,,,,,,,,,,,,, -80800,1.902166,3.7935882,,,,,,,,,,,,,, -80900,1.9477541,2.3910542,,,,,,,,,,,,,, -80905,,,0.6605077981948853,1.3962000608444214,0.6113799810409546,1.6233501434326172,50000.0,0.4937000274658203,2.255664110183716,10000.0,37012.93760061264,41086.63825082779,37012.93760061264,4066.0741584301,3.4339022636413574,0.0 -81000,1.760256,2.8770556,,,,,,,,,,,,,, -81100,2.0307403,2.2897995,,,,,,,,,,,,,, -81200,1.5901405,4.8129277,,,,,,,,,,,,,, -81300,2.1301174,2.4486399,,,,,,,,,,,,,, -81400,2.0291848,2.3785193,,,,,,,,,,,,,, -81500,2.076869,2.659973,,,,,,,,,,,,,, -81600,1.8007036,4.387332,,,,,,,,,,,,,, -81700,1.7908384,4.9812016,,,,,,,,,,,,,, -81800,1.6050649,4.9494357,,,,,,,,,,,,,, -81826,,,0.6890624761581421,1.2909141778945925,0.6201599836349487,1.59787118434906,50000.0,0.503600001335144,2.2272136211395264,10000.0,37433.07783913613,41555.29655838013,37433.07783913613,4114.506975889206,3.4718174934387207,0.0 -81900,2.1724854,2.496075,,,,,,,,,,,,,, -82000,1.9196502,4.394187,,,,,,,,,,,,,, -82100,2.2145839,2.5132902,,,,,,,,,,,,,, -82200,1.8231064,5.0990644,,,,,,,,,,,,,, -82300,1.8830509,2.5321796,,,,,,,,,,,,,, -82400,2.0947502,2.3656857,,,,,,,,,,,,,, -82500,1.6449096,3.3421905,,,,,,,,,,,,,, -82600,1.6085862,4.6107664,,,,,,,,,,,,,, -82700,1.8785883,2.3110301,,,,,,,,,,,,,, -82746,,,0.6591405868530273,1.4403164386749268,0.6159799695014954,1.641635537147522,50000.0,0.4896000325679779,2.2967171669006348,10000.0,37853.40079331398,42018.43554711342,37853.40079331398,4157.234240293503,3.51223087310791,0.0 -82800,2.2036269,2.5219214,,,,,,,,,,,,,, -82900,1.8466988,2.3452208,,,,,,,,,,,,,, -83000,1.7728174,3.7067335,,,,,,,,,,,,,, -83100,2.1690686,2.3348384,,,,,,,,,,,,,, -83200,2.0848556,2.653099,,,,,,,,,,,,,, -83300,1.7415899,4.8397784,,,,,,,,,,,,,, -83400,2.1090314,4.9809685,,,,,,,,,,,,,, -83500,2.0311399,2.3242364,,,,,,,,,,,,,, -83600,1.6820469,3.3599536,,,,,,,,,,,,,, -83666,,,0.663769543170929,1.3689274787902832,0.616320013999939,1.5968698263168335,50000.0,0.4989000260829925,2.235353469848633,10000.0,38273.69846391678,42488.09752130509,38273.69846391678,4206.504290103912,3.5590643882751465,0.0 -83700,1.5922394,4.4309635,,,,,,,,,,,,,, -83800,1.9433477,2.743048,,,,,,,,,,,,,, -83900,2.2922094,2.4184809,,,,,,,,,,,,,, -84000,1.9102346,2.8724666,,,,,,,,,,,,,, -84100,2.0988383,2.4109704,,,,,,,,,,,,,, -84200,2.0483592,2.406075,,,,,,,,,,,,,, -84300,1.7611654,3.6903265,,,,,,,,,,,,,, -84400,2.2206385,2.4356844,,,,,,,,,,,,,, -84500,1.9471103,2.6067042,,,,,,,,,,,,,, -84587,,,0.6851366758346558,1.2849160432815552,0.6199399828910828,1.5707097053527832,50000.0,0.4974000155925751,2.218472480773926,10000.0,38693.85767745972,42955.73055052757,38693.85767745972,4253.889084339142,3.6006920337677,0.0 -84600,1.8273246,2.7007327,,,,,,,,,,,,,, -84700,2.225342,2.415936,,,,,,,,,,,,,, -84800,2.0108037,2.3940485,,,,,,,,,,,,,, -84900,1.8387983,4.9232864,,,,,,,,,,,,,, -85000,2.1157835,2.7788434,,,,,,,,,,,,,, -85100,2.100675,2.4471498,,,,,,,,,,,,,, -85200,2.2848868,2.3676438,,,,,,,,,,,,,, -85300,1.8574789,4.9305706,,,,,,,,,,,,,, -85400,2.0128317,2.367749,,,,,,,,,,,,,, -85500,1.95495,2.5460038,,,,,,,,,,,,,, -85507,,,0.6639453172683716,1.3985720872879028,0.6188200116157532,1.608765721321106,50000.0,0.4985000193119049,2.248077869415283,10000.0,39114.19425010681,43423.79737305641,39114.19425010681,4301.529177188873,3.6426284313201904,0.0 -85600,2.017825,2.4501173,,,,,,,,,,,,,, -85700,2.01269,2.3756678,,,,,,,,,,,,,, -85800,1.9260957,2.7405431,,,,,,,,,,,,,, -85900,1.9557894,2.4222758,,,,,,,,,,,,,, -86000,1.9281251,2.4740791,,,,,,,,,,,,,, -86100,1.7603227,2.704643,,,,,,,,,,,,,, -86200,1.9492056,4.842457,,,,,,,,,,,,,, -86300,2.2583606,4.4693527,,,,,,,,,,,,,, -86400,1.7729179,4.834999,,,,,,,,,,,,,, -86427,,,0.673535168170929,1.353287935256958,0.6249600052833557,1.573027729988098,50000.0,0.5034000277519226,2.218822717666626,10000.0,39534.38470721245,43892.33448171616,39534.38470721245,4349.79114151001,3.6794145107269287,0.0 -86500,1.8737808,2.3054545,,,,,,,,,,,,,, -86600,1.75286,3.5277588,,,,,,,,,,,,,, -86700,1.9605873,2.379194,,,,,,,,,,,,,, -86800,1.7834747,2.9562578,,,,,,,,,,,,,, -86900,1.8091979,3.3399758,,,,,,,,,,,,,, -87000,2.0573673,2.5752401,,,,,,,,,,,,,, -87100,1.9351119,2.4175878,,,,,,,,,,,,,, -87200,2.052997,2.343007,,,,,,,,,,,,,, -87300,2.24199,2.7712855,,,,,,,,,,,,,, -87345,,,0.6821679472923279,1.2937649488449097,0.6266599893569946,1.55876624584198,50000.0,0.5062000155448914,2.2032463550567627,10000.0,39954.31283664704,44361.49378180504,39954.31283664704,4398.92951130867,3.722033023834229,0.0 -87400,2.0779994,2.3652778,,,,,,,,,,,,,, -87500,2.2184408,2.4335728,,,,,,,,,,,,,, -87600,1.7453115,3.791473,,,,,,,,,,,,,, -87700,1.7999669,3.7311814,,,,,,,,,,,,,, -87800,1.6509048,4.0869894,,,,,,,,,,,,,, -87900,2.0503347,2.383314,,,,,,,,,,,,,, -88000,1.8957862,4.0440207,,,,,,,,,,,,,, -88100,2.0661647,2.4160383,,,,,,,,,,,,,, -88200,1.730503,4.5693054,,,,,,,,,,,,,, -88262,,,0.6767578125,1.3330546617507937,0.6308199763298035,1.5420633554458618,50000.0,0.510200023651123,2.181927442550659,10000.0,40374.37486696243,44829.66952109337,40374.37486696243,4446.9581387043,3.759573459625244,0.0 -88300,2.5332587,2.3743913,,,,,,,,,,,,,, -88400,1.9882778,4.7743015,,,,,,,,,,,,,, -88500,1.7412769,3.529978,,,,,,,,,,,,,, -88600,2.0197403,2.388839,,,,,,,,,,,,,, -88700,2.2647524,2.3602831,,,,,,,,,,,,,, -88800,1.93836,2.5611174,,,,,,,,,,,,,, -88900,1.8487356,2.8271732,,,,,,,,,,,,,, -89000,2.0856464,2.2823286,,,,,,,,,,,,,, -89100,2.138976,2.5124655,,,,,,,,,,,,,, -89181,,,0.6698632836341858,1.3340215682983398,0.6282399892807007,1.5469225645065308,50000.0,0.509600043296814,2.215367317199707,10000.0,40794.69830536842,45298.46662902832,40794.69830536842,4495.342826843262,3.800304889678955,0.0 -89200,1.9073482,2.7056289,,,,,,,,,,,,,, -89300,1.816897,2.8438094,,,,,,,,,,,,,, -89400,1.6435686,4.333924,,,,,,,,,,,,,, -89500,1.7060841,3.583942,,,,,,,,,,,,,, -89600,1.8945425,2.430616,,,,,,,,,,,,,, -89700,2.0837884,2.8019881,,,,,,,,,,,,,, -89800,1.9336904,4.7629023,,,,,,,,,,,,,, -89900,2.1170912,2.2420237,,,,,,,,,,,,,, -90000,2.1573002,2.3139024,,,,,,,,,,,,,, -90099,,,0.68359375,1.305981159210205,0.6265999674797058,1.5628786087036133,50000.0,0.5063000321388245,2.2054131031036377,10000.0,41214.659896850586,45763.786215782166,41214.659896850586,4540.608873128891,3.844949960708618,0.0 -90100,1.6325645,4.9249525,,,,,,,,,,,,,, -90200,1.7517551,3.1586819,,,,,,,,,,,,,, -90300,1.811761,3.071293,,,,,,,,,,,,,, -90400,2.316995,2.2825718,,,,,,,,,,,,,, -90500,2.2934773,2.3834553,,,,,,,,,,,,,, -90600,1.7405369,3.8687024,,,,,,,,,,,,,, -90700,2.000504,2.2490053,,,,,,,,,,,,,, -90800,1.8094822,2.9155016,,,,,,,,,,,,,, -90900,2.0862994,2.556591,,,,,,,,,,,,,, -91000,2.0014443,3.303508,,,,,,,,,,,,,, -91018,,,0.6847070455551147,1.3044037818908691,0.6290199756622314,1.5476787090301514,50000.0,0.5110000371932983,2.183170795440674,10000.0,41635.00652647018,46232.89801955223,41635.00652647018,4589.280221223831,3.891319274902344,0.0 -91100,2.4054666,2.4148254,,,,,,,,,,,,,, -91200,2.0335014,2.2285728,,,,,,,,,,,,,, -91300,2.1772928,2.8148217,,,,,,,,,,,,,, -91400,2.0493731,2.2813315,,,,,,,,,,,,,, -91500,2.2252994,4.712999,,,,,,,,,,,,,, -91600,2.048114,2.3207262,,,,,,,,,,,,,, -91700,2.1856062,2.3047533,,,,,,,,,,,,,, -91800,1.5306566,4.6224165,,,,,,,,,,,,,, -91900,1.9236654,2.6528807,,,,,,,,,,,,,, -91936,,,0.6799609065055847,1.3173344135284424,0.6291599869728088,1.534510850906372,50000.0,0.5093000531196594,2.173914194107056,10000.0,42055.1633245945,46700.18941116333,42055.1633245945,4636.324400186539,3.93449640274048,0.0 -92000,2.0571923,2.4762616,,,,,,,,,,,,,, -92100,1.9480057,4.622311,,,,,,,,,,,,,, -92200,2.2179568,2.3710585,,,,,,,,,,,,,, -92300,1.8745099,4.83384,,,,,,,,,,,,,, -92400,1.772188,4.1715503,,,,,,,,,,,,,, -92500,1.8195561,3.7528555,,,,,,,,,,,,,, -92600,2.0808938,2.363534,,,,,,,,,,,,,, -92700,1.7668142,4.5141726,,,,,,,,,,,,,, -92800,2.38706,2.4557643,,,,,,,,,,,,,, -92858,,,0.6845703125,1.30815589427948,0.6358000040054321,1.5391371250152588,50000.0,0.5106000304222107,2.1962108612060547,10000.0,42475.19194078445,47169.31934523583,42475.19194078445,4685.3404994010925,3.972641944885254,0.0 -92900,1.9601052,3.6181695,,,,,,,,,,,,,, -93000,1.8527241,3.244084,,,,,,,,,,,,,, -93100,2.1960177,2.3821926,,,,,,,,,,,,,, -93200,2.0603728,3.2733817,,,,,,,,,,,,,, -93300,2.0998666,2.382287,,,,,,,,,,,,,, -93400,1.7784649,3.0895338,,,,,,,,,,,,,, -93500,2.0952852,3.5397499,,,,,,,,,,,,,, -93600,1.9497294,2.6505527,,,,,,,,,,,,,, -93700,2.0691848,2.4336202,,,,,,,,,,,,,, -93780,,,0.7090038657188416,1.1813215017318726,0.6354999542236328,1.5112653970718384,50000.0,0.5103000402450562,2.147897481918335,10000.0,42895.57374858856,47636.26103925705,42895.57374858856,4731.815329551697,4.010065317153931,0.0 -93800,2.0531356,2.357059,,,,,,,,,,,,,, -93900,1.7816943,3.9366522,,,,,,,,,,,,,, -94000,2.2905636,2.2420638,,,,,,,,,,,,,, -94100,1.9459333,4.781909,,,,,,,,,,,,,, -94200,1.7986323,4.435515,,,,,,,,,,,,,, -94300,1.7943345,4.615101,,,,,,,,,,,,,, -94400,1.9661461,2.554089,,,,,,,,,,,,,, -94500,2.1542373,2.8069916,,,,,,,,,,,,,, -94600,2.3317702,2.3687713,,,,,,,,,,,,,, -94699,,,0.6853711009025574,1.2983242273330688,0.634939968585968,1.5220338106155396,50000.0,0.516700029373169,2.158723831176758,10000.0,43315.627083063126,48105.22648501396,43315.627083063126,4780.633121490479,4.056293964385986,0.0 -94700,1.924417,2.6755,,,,,,,,,,,,,, -94800,2.0454738,2.609596,,,,,,,,,,,,,, -94900,2.068147,4.848771,,,,,,,,,,,,,, -95000,2.1051147,3.9526317,,,,,,,,,,,,,, -95100,1.9568849,2.21925,,,,,,,,,,,,,, -95200,2.3102388,2.330661,,,,,,,,,,,,,, -95300,2.3518028,2.3195963,,,,,,,,,,,,,, -95400,2.3191872,4.50824,,,,,,,,,,,,,, -95500,1.9661441,4.8390226,,,,,,,,,,,,,, -95600,2.0927286,2.4527566,,,,,,,,,,,,,, -95619,,,0.6918359398841858,1.3312885761260986,0.6355999708175659,1.5759185552597046,50000.0,0.5182000398635864,2.203728199005127,10000.0,43735.60219955444,48575.65565729141,43735.60219955444,4830.998948812485,4.0974836349487305,0.0 -95700,1.8707311,3.2904444,,,,,,,,,,,,,, -95800,1.8604124,4.6245785,,,,,,,,,,,,,, -95900,2.0903423,2.644699,,,,,,,,,,,,,, -96000,2.1872973,2.272718,,,,,,,,,,,,,, -96100,1.9055974,4.8346763,,,,,,,,,,,,,, -96200,1.97461,2.7986155,,,,,,,,,,,,,, -96300,2.056607,2.2689965,,,,,,,,,,,,,, -96400,1.8381681,4.8554316,,,,,,,,,,,,,, -96500,2.209197,2.2619677,,,,,,,,,,,,,, -96541,,,0.6978124976158142,1.2508113384246826,0.6395999789237976,1.5194545984268188,50000.0,0.5205000042915344,2.156029462814331,10000.0,44155.92951416969,49041.48910880089,44155.92951416969,4876.413905143738,4.138431787490845,0.0 -96600,2.4518077,2.2241225,,,,,,,,,,,,,, -96700,2.24322,2.186445,,,,,,,,,,,,,, -96800,2.3420975,2.3266299,,,,,,,,,,,,,, -96900,2.2221642,2.39673,,,,,,,,,,,,,, -97000,1.9259835,2.5907278,,,,,,,,,,,,,, -97100,1.9768871,4.681796,,,,,,,,,,,,,, -97200,1.810064,3.157763,,,,,,,,,,,,,, -97300,1.8449551,4.3537827,,,,,,,,,,,,,, -97400,1.8198496,4.822159,,,,,,,,,,,,,, -97461,,,0.6885156035423279,1.2628841400146484,0.642579972743988,1.4694026708602903,50000.0,0.5254999995231628,2.106327772140503,10000.0,44575.93083691597,49509.01126098633,44575.93083691597,4923.843191146851,4.181923389434815,0.0 -97500,2.0571468,2.6378865,,,,,,,,,,,,,, -97600,1.9583685,2.4810586,,,,,,,,,,,,,, -97700,2.149396,2.7161498,,,,,,,,,,,,,, -97800,1.9072034,3.799499,,,,,,,,,,,,,, -97900,2.6697497,2.2539532,,,,,,,,,,,,,, -98000,2.3261425,2.2598424,,,,,,,,,,,,,, -98100,2.0945387,2.4943535,,,,,,,,,,,,,, -98200,2.4071667,2.2548409,,,,,,,,,,,,,, -98300,2.1776347,2.4563954,,,,,,,,,,,,,, -98381,,,0.6943945288658142,1.2435778379440308,0.6418399810791016,1.479137659072876,50000.0,0.5218999981880188,2.118631362915039,10000.0,44995.94634652138,49977.063490867615,44995.94634652138,4971.78942322731,4.224483251571655,0.0 -98400,1.803744,4.715511,,,,,,,,,,,,,, -98500,2.3716264,4.93209,,,,,,,,,,,,,, -98600,1.9463009,4.7872887,,,,,,,,,,,,,, -98700,2.1120565,2.1220908,,,,,,,,,,,,,, -98800,2.4008737,2.1943986,,,,,,,,,,,,,, -98900,2.0474222,2.3925726,,,,,,,,,,,,,, -99000,2.1891317,4.7796907,,,,,,,,,,,,,, -99100,2.185161,2.2302113,,,,,,,,,,,,,, -99200,1.9500506,4.616529,,,,,,,,,,,,,, -99296,,,0.6937890648841858,1.2605602741241455,0.6408799886703491,1.5083057880401611,50000.0,0.5139000415802002,2.170624017715454,10000.0,45416.24973320961,50444.14269065857,45416.24973320961,5018.474480867386,4.268015146255493,0.0 -99300,2.22505,2.3719351,,,,,,,,,,,,,, -99400,2.2089903,2.1429203,,,,,,,,,,,,,, -99500,1.9800365,4.7171097,,,,,,,,,,,,,, -99600,2.0751517,2.772773,,,,,,,,,,,,,, -99700,1.7542386,4.080444,,,,,,,,,,,,,, -99800,2.2058413,3.4258118,,,,,,,,,,,,,, -99900,1.8014168,4.270205,,,,,,,,,,,,,, -100000,2.2569525,2.09846,,,,,,,,,,,,,, -100100,1.8096749,4.798671,,,,,,,,,,,,,, -100200,2.0832977,4.5402565,,,,,,,,,,,,,, -100214,,,0.6874804496765137,1.2632285356521606,0.6426999568939209,1.4662885665893557,50000.0,0.5177000164985657,2.1122829914093018,10000.0,45836.35974335671,50912.906153678894,45836.35974335671,5067.025764942169,4.322429895401001,0.0 -100300,2.034794,4.235285,,,,,,,,,,,,,, -100400,2.2910287,2.3653357,,,,,,,,,,,,,, -100500,2.4665809,2.2295883,,,,,,,,,,,,,, -100600,1.9112643,3.4262614,,,,,,,,,,,,,, -100700,2.3523982,2.1714847,,,,,,,,,,,,,, -100800,2.5008512,2.26887,,,,,,,,,,,,,, -100900,1.9218035,4.199357,,,,,,,,,,,,,, -101000,2.2624943,2.0649583,,,,,,,,,,,,,, -101100,2.2286668,3.05479,,,,,,,,,,,,,, -101134,,,0.6952343583106995,1.2298552989959717,0.646399974822998,1.4554753303527832,50000.0,0.5227000117301941,2.1116693019866943,10000.0,46256.391678094864,51377.74165701866,46256.391678094864,5111.740765094757,4.362881422042847,0.0 -101200,2.3023975,2.6142304,,,,,,,,,,,,,, -101300,2.2250257,2.1048253,,,,,,,,,,,,,, -101400,2.2242439,4.8118715,,,,,,,,,,,,,, -101500,2.307708,2.1472793,,,,,,,,,,,,,, -101600,1.7894403,3.7670112,,,,,,,,,,,,,, -101700,2.0938063,2.4903097,,,,,,,,,,,,,, -101800,1.8713421,4.1342077,,,,,,,,,,,,,, -101900,2.0066013,2.3870666,,,,,,,,,,,,,, -102000,2.1381278,2.283609,,,,,,,,,,,,,, -102056,,,0.6937499642372131,1.2900238037109375,0.638759970664978,1.5294249057769775,50000.0,0.5200000405311584,2.172455310821533,10000.0,46676.57629132271,51844.58192944527,46676.57629132271,5158.299484729767,4.410746574401856,0.0 -102100,1.9964958,3.4522545,,,,,,,,,,,,,, -102200,1.8735214,3.283818,,,,,,,,,,,,,, -102300,2.228386,3.033631,,,,,,,,,,,,,, -102400,2.2741168,3.0177875,,,,,,,,,,,,,, -102500,2.1216588,2.1601315,,,,,,,,,,,,,, -102600,2.413987,2.1921887,,,,,,,,,,,,,, -102700,2.346031,2.6198022,,,,,,,,,,,,,, -102800,2.026588,4.7698007,,,,,,,,,,,,,, -102900,2.0984085,4.671361,,,,,,,,,,,,,, -102974,,,0.7155859470367432,1.145095944404602,0.6472600102424622,1.4536617994308472,50000.0,0.5229000449180603,2.0942981243133545,10000.0,47096.54494357109,52311.58908891678,47096.54494357109,5205.239888191223,4.460630416870117,0.0 -103000,1.9319469,4.5968103,,,,,,,,,,,,,, -103100,2.3288503,2.1079388,,,,,,,,,,,,,, -103200,1.8419535,3.671959,,,,,,,,,,,,,, -103300,2.5108385,2.1156738,,,,,,,,,,,,,, -103400,2.2390015,2.3295856,,,,,,,,,,,,,, -103500,2.0526574,3.1747856,,,,,,,,,,,,,, -103600,2.4272935,2.0896215,,,,,,,,,,,,,, -103700,2.2844093,2.1352782,,,,,,,,,,,,,, -103800,2.0187936,4.5419908,,,,,,,,,,,,,, -103892,,,0.6991210579872131,1.2539443969726562,0.6495400071144104,1.4771441221237185,50000.0,0.52510005235672,2.1233255863189697,10000.0,47516.91737222672,52779.09874868393,47516.91737222672,5252.287913560867,4.502202272415161,0.0 -103900,2.053063,4.026136,,,,,,,,,,,,,, -104000,2.2614396,2.1518276,,,,,,,,,,,,,, -104100,2.0940204,2.467565,,,,,,,,,,,,,, -104200,1.9942243,4.435075,,,,,,,,,,,,,, -104300,2.1141844,2.1793938,,,,,,,,,,,,,, -104400,2.2439988,1.9435859,,,,,,,,,,,,,, -104500,2.2195165,2.0949273,,,,,,,,,,,,,, -104600,2.1560843,3.1616018,,,,,,,,,,,,,, -104700,2.4844172,2.1494257,,,,,,,,,,,,,, -104800,2.1459439,3.2880986,,,,,,,,,,,,,, -104814,,,0.7024218440055847,1.2268569469451904,0.6485799551010132,1.4664111137390137,50000.0,0.5250000357627869,2.094282627105713,10000.0,47936.990965127945,53241.568251132965,47936.990965127945,5294.593217134476,4.5446178913116455,0.0 -104900,2.3613439,2.199733,,,,,,,,,,,,,, -105000,2.5059347,2.2495694,,,,,,,,,,,,,, -105100,2.8286638,2.1535711,,,,,,,,,,,,,, -105200,1.7814142,4.135318,,,,,,,,,,,,,, -105300,2.4107172,2.109812,,,,,,,,,,,,,, -105400,2.022014,3.0936103,,,,,,,,,,,,,, -105500,2.4916084,2.1936677,,,,,,,,,,,,,, -105600,1.8975648,4.0328603,,,,,,,,,,,,,, -105700,2.363262,2.475117,,,,,,,,,,,,,, -105735,,,0.7244726419448853,1.1140944957733154,0.6542800068855286,1.4219244718551636,50000.0,0.5297000408172607,2.06118106842041,10000.0,48357.00198101997,53707.13777804375,48357.00198101997,5340.061425924301,4.586939811706543,0.0 -105800,2.4845045,2.106698,,,,,,,,,,,,,, -105900,2.3318305,3.7441106,,,,,,,,,,,,,, -106000,2.1691058,3.6616368,,,,,,,,,,,,,, -106100,2.4229472,2.531867,,,,,,,,,,,,,, -106200,2.2598572,2.130122,,,,,,,,,,,,,, -106300,2.0266554,2.4488027,,,,,,,,,,,,,, -106400,2.26822,4.6880927,,,,,,,,,,,,,, -106500,2.3023186,2.2026439,,,,,,,,,,,,,, -106600,2.5703266,2.1725526,,,,,,,,,,,,,, -106651,,,0.7056054472923279,1.192114233970642,0.6538000106811523,1.427022099494934,50000.0,0.5336000323295593,2.058079481124878,10000.0,48777.18662452698,54175.82275629044,48777.18662452698,5388.473347187042,4.627787828445435,0.0 -106700,2.307031,2.1987264,,,,,,,,,,,,,, -106800,1.8975793,2.801071,,,,,,,,,,,,,, -106900,2.1121352,4.6323123,,,,,,,,,,,,,, -107000,2.2036247,2.3380387,,,,,,,,,,,,,, -107100,2.1395786,3.725881,,,,,,,,,,,,,, -107200,2.0822077,3.375897,,,,,,,,,,,,,, -107300,2.0459256,4.3928814,,,,,,,,,,,,,, -107400,2.3483965,2.4359608,,,,,,,,,,,,,, -107500,2.2661448,2.1330266,,,,,,,,,,,,,, -107568,,,0.7154492139816284,1.1752427816390991,0.6553599834442139,1.4241199493408203,50000.0,0.5356000065803528,2.0526721477508545,10000.0,49197.227964401245,54645.73945307732,49197.227964401245,5438.257611513138,4.671726226806641,0.0 -107600,2.2915645,2.195799,,,,,,,,,,,,,, -107700,2.0974553,2.742837,,,,,,,,,,,,,, -107800,2.4268985,4.913276,,,,,,,,,,,,,, -107900,1.8443767,4.427885,,,,,,,,,,,,,, -108000,2.4165413,2.0469618,,,,,,,,,,,,,, -108100,2.0783308,3.0130343,,,,,,,,,,,,,, -108200,2.494243,2.3040926,,,,,,,,,,,,,, -108300,2.4523685,2.094214,,,,,,,,,,,,,, -108400,1.970175,3.213468,,,,,,,,,,,,,, -108488,,,0.71839839220047,1.1570578813552856,0.6541000008583069,1.4482253789901731,50000.0,0.5313000082969666,2.080390214920044,10000.0,49617.28289413452,55114.79280400276,49617.28289413452,5487.164425611496,4.715782403945923,0.0 -108500,2.3465228,2.4472055,,,,,,,,,,,,,, -108600,2.28206,2.1919546,,,,,,,,,,,,,, -108700,2.6118524,2.149187,,,,,,,,,,,,,, -108800,2.6072524,2.1727858,,,,,,,,,,,,,, -108900,2.1919353,2.3757012,,,,,,,,,,,,,, -109000,2.3424475,2.5363867,,,,,,,,,,,,,, -109100,2.5512586,2.1223774,,,,,,,,,,,,,, -109200,2.3233335,2.0016134,,,,,,,,,,,,,, -109300,2.5042408,4.3294744,,,,,,,,,,,,,, -109400,2.520401,2.251246,,,,,,,,,,,,,, -109409,,,0.7098046541213989,1.2095328569412231,0.6589800119400024,1.4317922592163086,50000.0,0.5364000201225281,2.062147617340088,10000.0,50037.39028072357,55580.21659255028,50037.39028072357,5532.3830144405365,4.7658116817474365,0.0 -109500,2.414218,2.032833,,,,,,,,,,,,,, -109600,2.2045105,4.7418904,,,,,,,,,,,,,, -109700,2.061814,3.1762526,,,,,,,,,,,,,, -109800,2.3302858,2.1262727,,,,,,,,,,,,,, -109900,2.186822,2.644873,,,,,,,,,,,,,, -110000,2.5263302,2.7274475,,,,,,,,,,,,,, -110100,2.133383,2.81248,,,,,,,,,,,,,, -110200,2.2955296,2.431789,,,,,,,,,,,,,, -110300,2.4093258,2.113152,,,,,,,,,,,,,, -110328,,,0.705078125,1.216767191886902,0.6532399654388428,1.4524482488632202,50000.0,0.5294000506401062,2.0938665866851807,10000.0,50457.57064390183,56048.070026397705,50457.57064390183,5579.963124036789,4.810678005218506,0.0 -110400,2.8098292,2.100286,,,,,,,,,,,,,, -110500,2.183812,4.2547154,,,,,,,,,,,,,, -110600,2.5597925,2.0937562,,,,,,,,,,,,,, -110700,2.5541847,2.1697989,,,,,,,,,,,,,, -110800,2.2678509,2.2980146,,,,,,,,,,,,,, -110900,2.153146,4.697588,,,,,,,,,,,,,, -111000,2.1831427,2.531083,,,,,,,,,,,,,, -111100,2.2275512,3.467537,,,,,,,,,,,,,, -111200,2.1326525,3.6873007,,,,,,,,,,,,,, -111247,,,0.7220507860183716,1.1261377334594729,0.6609999537467957,1.4034225940704346,50000.0,0.5353000164031982,2.035445213317871,10000.0,50877.94385957718,56518.80230307579,50877.94385957718,5630.232651948929,4.852938175201416,0.0 -111300,2.313082,2.0996966,,,,,,,,,,,,,, -111400,2.3290837,4.020654,,,,,,,,,,,,,, -111500,2.1633284,4.0311837,,,,,,,,,,,,,, -111600,2.5686219,2.2065759,,,,,,,,,,,,,, -111700,2.3567362,2.05081,,,,,,,,,,,,,, -111800,2.5365777,2.256818,,,,,,,,,,,,,, -111900,2.189684,2.694779,,,,,,,,,,,,,, -112000,1.9793602,4.191553,,,,,,,,,,,,,, -112100,2.0949092,3.892826,,,,,,,,,,,,,, -112167,,,0.7176757454872131,1.1667044162750244,0.6600599884986877,1.414770007133484,50000.0,0.5349000096321106,2.034170866012573,10000.0,51298.08949494362,56987.73841023445,51298.08949494362,5678.924888134003,4.904115915298462,0.0 -112200,2.6730855,2.0388992,,,,,,,,,,,,,, -112300,2.2495098,2.6408353,,,,,,,,,,,,,, -112400,2.5191703,2.2009094,,,,,,,,,,,,,, -112500,2.4664333,2.1577642,,,,,,,,,,,,,, -112600,2.0886626,2.9695399,,,,,,,,,,,,,, -112700,2.266798,2.522231,,,,,,,,,,,,,, -112800,2.7965527,2.1632302,,,,,,,,,,,,,, -112900,2.4271393,2.264779,,,,,,,,,,,,,, -113000,2.4162135,2.0142195,,,,,,,,,,,,,, -113085,,,0.7177538871765137,1.1266690492630005,0.6662999987602234,1.3710778951644895,50000.0,0.5451000332832336,2.005274534225464,10000.0,51718.244585990906,57456.35862541199,51718.244585990906,5727.290082454681,4.955888032913208,0.0 -113100,2.3597665,2.0601923,,,,,,,,,,,,,, -113200,2.4061484,1.998842,,,,,,,,,,,,,, -113300,2.4718125,2.0837066,,,,,,,,,,,,,, -113400,2.1557863,4.212306,,,,,,,,,,,,,, -113500,2.5053232,2.7580047,,,,,,,,,,,,,, -113600,2.817163,2.2501192,,,,,,,,,,,,,, -113700,2.6361663,2.1769114,,,,,,,,,,,,,, -113800,2.3274963,2.0123742,,,,,,,,,,,,,, -113900,2.3890283,2.2453647,,,,,,,,,,,,,, -114000,2.1433303,3.1774273,,,,,,,,,,,,,, -114004,,,0.7226366996765137,1.1332825422286987,0.6693199872970581,1.3838796615600586,50000.0,0.5439000129699707,2.008410692214966,10000.0,52138.60879659653,57921.55140066147,52138.60879659653,5772.027301549912,5.000272750854492,0.0 -114100,2.741265,2.0405734,,,,,,,,,,,,,, -114200,2.918992,2.0137017,,,,,,,,,,,,,, -114300,2.50169,1.9717488,,,,,,,,,,,,,, -114400,2.4086506,2.7827823,,,,,,,,,,,,,, -114500,2.7115107,2.1188407,,,,,,,,,,,,,, -114600,2.484031,2.32911,,,,,,,,,,,,,, -114700,2.67585,2.1264994,,,,,,,,,,,,,, -114800,2.5821874,4.6770515,,,,,,,,,,,,,, -114900,2.4334984,2.4760828,,,,,,,,,,,,,, -114925,,,0.7480273246765137,1.0093517303466797,0.6700999736785889,1.3534212112426758,50000.0,0.5440000295639038,1.9908446073532104,10000.0,52559.03491163254,58392.56290578842,52559.03491163254,5822.523226261139,5.0422186851501465,0.0 -115000,2.3707237,2.7154055,,,,,,,,,,,,,, -115100,2.4924967,2.0148234,,,,,,,,,,,,,, -115200,2.3808987,3.292058,,,,,,,,,,,,,, -115300,2.3741143,2.1116374,,,,,,,,,,,,,, -115400,2.7904603,2.02701,,,,,,,,,,,,,, -115500,2.499317,2.1173773,,,,,,,,,,,,,, -115600,2.4649904,2.1697516,,,,,,,,,,,,,, -115700,2.1362977,3.233309,,,,,,,,,,,,,, -115800,2.7417362,2.056453,,,,,,,,,,,,,, -115842,,,0.7229687571525574,1.115723729133606,0.6693199872970581,1.3546879291534424,50000.0,0.5459000468254089,1.989838004112244,10000.0,52979.48181724548,58858.70378828049,52979.48181724548,5868.125700950623,5.087374925613403,0.0 -115900,2.5252693,4.5362377,,,,,,,,,,,,,, -116000,2.499296,2.13435,,,,,,,,,,,,,, -116100,2.138287,3.7174635,,,,,,,,,,,,,, -116200,2.569764,2.1518102,,,,,,,,,,,,,, -116300,2.4791296,2.1613355,,,,,,,,,,,,,, -116400,2.7470565,2.3604622,,,,,,,,,,,,,, -116500,2.0986679,3.2846608,,,,,,,,,,,,,, -116600,2.5179992,2.026207,,,,,,,,,,,,,, -116700,2.5670438,2.0253384,,,,,,,,,,,,,, -116758,,,0.7300195097923279,1.0898079872131348,0.671019971370697,1.3476372957229614,50000.0,0.5434000492095947,1.9787174463272093,10000.0,53399.61646103859,59326.91538286209,53399.61646103859,5916.103356122971,5.1391777992248535,0.0 -116800,2.6522958,2.1167293,,,,,,,,,,,,,, -116900,2.5663867,2.1999125,,,,,,,,,,,,,, -117000,2.5282626,2.141559,,,,,,,,,,,,,, -117100,2.6289067,2.153368,,,,,,,,,,,,,, -117200,2.3124924,3.1416514,,,,,,,,,,,,,, -117300,2.7052624,2.0021853,,,,,,,,,,,,,, -117400,2.4112287,4.6539235,,,,,,,,,,,,,, -117500,2.4038713,4.21707,,,,,,,,,,,,,, -117600,2.3815777,4.2082696,,,,,,,,,,,,,, -117679,,,0.7401171922683716,1.0645672082901,0.6676999926567078,1.3838788270950315,50000.0,0.544700026512146,2.0176634788513184,10000.0,53819.964002370834,59794.09345436096,53819.964002370834,5962.832585573196,5.19250226020813,0.0 -117700,2.6296535,1.9760275,,,,,,,,,,,,,, -117800,2.5864785,2.0739226,,,,,,,,,,,,,, -117900,2.370646,4.2247934,,,,,,,,,,,,,, -118000,2.4374769,4.6974025,,,,,,,,,,,,,, -118100,2.5949807,4.1389623,,,,,,,,,,,,,, -118200,2.248126,3.225503,,,,,,,,,,,,,, -118300,2.2596822,2.9887662,,,,,,,,,,,,,, -118400,2.8986785,2.104454,,,,,,,,,,,,,, -118500,2.321492,3.4729924,,,,,,,,,,,,,, -118597,,,0.7274804711341858,1.0996066331863403,0.6718199849128723,1.3471410274505615,50000.0,0.5469000339508057,1.9880914688110352,10000.0,54240.32156014442,60263.14970517159,54240.32156014442,6011.4386677742,5.236751317977905,0.0 -118600,2.450903,2.465352,,,,,,,,,,,,,, -118700,2.650125,2.0292907,,,,,,,,,,,,,, -118800,2.5020003,2.3819146,,,,,,,,,,,,,, -118900,2.3906493,1.9752545,,,,,,,,,,,,,, -119000,2.7167282,2.3165696,,,,,,,,,,,,,, -119100,2.3481166,2.9474397,,,,,,,,,,,,,, -119200,2.6860223,3.6457083,,,,,,,,,,,,,, -119300,2.510998,1.9075648,,,,,,,,,,,,,, -119400,2.4237306,3.306431,,,,,,,,,,,,,, -119500,2.616179,2.1672509,,,,,,,,,,,,,, -119518,,,0.7304491996765137,1.0636045932769775,0.672760009765625,1.335329294204712,50000.0,0.5469000339508057,1.9718530178070068,10000.0,54660.49790549278,60732.12470769882,54660.49790549278,6060.143052816391,5.284197807312012,0.0 -119600,2.32121,3.3481092,,,,,,,,,,,,,, -119700,2.7170932,4.240637,,,,,,,,,,,,,, -119800,2.275267,2.8394837,,,,,,,,,,,,,, -119900,2.6386108,1.9925334,,,,,,,,,,,,,, -120000,2.5140986,2.2956457,,,,,,,,,,,,,, -120100,2.9528108,2.1222491,,,,,,,,,,,,,, -120200,2.4453583,2.315321,,,,,,,,,,,,,, -120300,2.327039,2.698643,,,,,,,,,,,,,, -120400,2.7223375,2.044313,,,,,,,,,,,,,, -120438,,,0.7368554472923279,1.0701631307601929,0.675879955291748,1.3530094623565674,50000.0,0.54830002784729,1.983000874519348,10000.0,55080.59421658516,61201.30026555061,55080.59421658516,6109.130714178085,5.328948974609375,0.0 -120500,2.7757142,2.133061,,,,,,,,,,,,,, -120600,3.1226268,2.073123,,,,,,,,,,,,,, -120700,2.658304,1.9769979,,,,,,,,,,,,,, -120800,2.486303,3.2438655,,,,,,,,,,,,,, -120900,2.7522383,2.2459002,,,,,,,,,,,,,, -121000,2.6224985,3.981721,,,,,,,,,,,,,, -121100,2.6761446,2.038231,,,,,,,,,,,,,, -121200,2.2450824,3.795838,,,,,,,,,,,,,, -121300,2.356185,2.6156955,,,,,,,,,,,,,, -121354,,,0.7291015386581421,1.0866446495056152,0.6756199598312378,1.3281562328338623,50000.0,0.5560000538825989,1.955164909362793,10000.0,55500.81260108948,61668.1078953743,55500.81260108948,6155.618450164795,5.383184432983398,0.0 -121400,2.3194103,3.793734,,,,,,,,,,,,,, -121500,2.946829,1.9479811,,,,,,,,,,,,,, -121600,2.4638898,1.9432827,,,,,,,,,,,,,, -121700,2.683227,2.0250866,,,,,,,,,,,,,, -121800,2.7940512,2.476775,,,,,,,,,,,,,, -121900,2.9650981,2.029975,,,,,,,,,,,,,, -122000,2.9472184,2.0831718,,,,,,,,,,,,,, -122100,2.8492336,1.9554425,,,,,,,,,,,,,, -122200,2.40793,2.4879045,,,,,,,,,,,,,, -122273,,,0.7377538681030273,1.046796798706055,0.6814199686050415,1.3041400909423828,50000.0,0.5575000047683716,1.926209807395935,10000.0,55921.02682614327,62134.41156554222,55921.02682614327,6201.618116378784,5.425713300704956,0.0 -122300,2.751186,2.5210373,,,,,,,,,,,,,, -122400,2.3927343,3.9614482,,,,,,,,,,,,,, -122500,2.746611,2.0534542,,,,,,,,,,,,,, -122600,2.6841218,1.96194,,,,,,,,,,,,,, -122700,2.6303666,1.9848789,,,,,,,,,,,,,, -122800,3.09448,2.0197904,,,,,,,,,,,,,, -122900,2.5396066,2.3226123,,,,,,,,,,,,,, -123000,2.7131608,2.7084062,,,,,,,,,,,,,, -123100,2.3146589,3.274472,,,,,,,,,,,,,, -123189,,,0.7463476657867432,1.0369484424591064,0.6841599941253662,1.3181054592132568,50000.0,0.5567000508308411,1.947321414947509,10000.0,56341.15030384064,62601.69093823433,56341.15030384064,6248.677495479584,5.474867820739746,0.0 -123200,2.7685835,2.0527966,,,,,,,,,,,,,, -123300,2.694618,2.2288291,,,,,,,,,,,,,, -123400,2.959295,1.995882,,,,,,,,,,,,,, -123500,2.731062,2.173432,,,,,,,,,,,,,, -123600,2.6057255,4.1142664,,,,,,,,,,,,,, -123700,2.720747,2.0595577,,,,,,,,,,,,,, -123800,2.7109315,3.9428265,,,,,,,,,,,,,, -123900,2.6604276,3.6135259,,,,,,,,,,,,,, -124000,2.7940314,1.9588859,,,,,,,,,,,,,, -124100,2.93669,2.2257607,,,,,,,,,,,,,, -124108,,,0.7437499761581421,1.0200892686843872,0.6828199625015259,1.2920786142349243,50000.0,0.5599000453948975,1.9192627668380733,10000.0,56761.34786653519,63072.82502889633,56761.34786653519,6299.499836921692,5.528214454650879,0.0 -124200,3.3525918,2.0249405,,,,,,,,,,,,,, -124300,2.5776982,2.0345972,,,,,,,,,,,,,, -124400,2.567429,3.062624,,,,,,,,,,,,,, -124500,2.2945638,3.2325218,,,,,,,,,,,,,, -124600,2.5571194,2.7717836,,,,,,,,,,,,,, -124700,2.4903069,3.3019056,,,,,,,,,,,,,, -124800,3.0630102,1.9181768,,,,,,,,,,,,,, -124900,2.6405783,1.9303987,,,,,,,,,,,,,, -125000,2.7389867,2.1647334,,,,,,,,,,,,,, -125025,,,0.7412695288658142,1.0353752374649048,0.6846599578857422,1.2898250818252563,50000.0,0.5612000226974487,1.9244365692138672,10000.0,57181.43518590927,63537.27227139473,57181.43518590927,6343.759425878525,5.581441879272461,0.0 -125100,3.3907297,1.8666793,,,,,,,,,,,,,, -125200,2.720689,3.7132792,,,,,,,,,,,,,, -125300,2.5451272,2.1291819,,,,,,,,,,,,,, -125400,3.0808144,1.88818,,,,,,,,,,,,,, -125500,2.768266,1.9361088,,,,,,,,,,,,,, -125600,2.6105132,4.04137,,,,,,,,,,,,,, -125700,2.6488402,3.2523031,,,,,,,,,,,,,, -125800,2.738296,2.2413898,,,,,,,,,,,,,, -125900,2.7621713,4.413058,,,,,,,,,,,,,, -125945,,,0.7484374642372131,0.9919676780700684,0.6880999803543091,1.2748708724975586,50000.0,0.5624000430107117,1.912021517753601,10000.0,57601.69718050957,64005.15271496773,57601.69718050957,6391.28493309021,5.625512361526489,0.0 -126000,3.1458426,1.9419131,,,,,,,,,,,,,, -126100,2.5261405,2.3654883,,,,,,,,,,,,,, -126200,2.696609,2.465472,,,,,,,,,,,,,, -126300,2.8397765,4.5127845,,,,,,,,,,,,,, -126400,2.4806561,4.348605,,,,,,,,,,,,,, -126500,2.9521701,1.8577261,,,,,,,,,,,,,, -126600,2.6262152,4.0328083,,,,,,,,,,,,,, -126700,2.7408133,2.1324077,,,,,,,,,,,,,, -126800,2.6202238,1.9438092,,,,,,,,,,,,,, -126864,,,0.76025390625,0.985008418560028,0.687559962272644,1.3039485216140747,50000.0,0.5654000043869019,1.9312610626220703,10000.0,58021.72929406166,64473.019273757935,58021.72929406166,6439.025028705597,5.6715407371521,0.0 -126900,2.5589154,3.2404854,,,,,,,,,,,,,, -127000,2.973908,2.110366,,,,,,,,,,,,,, -127100,2.92232,1.8142042,,,,,,,,,,,,,, -127200,2.6551037,2.3199532,,,,,,,,,,,,,, -127300,2.790574,2.0487006,,,,,,,,,,,,,, -127400,3.011279,1.8589337,,,,,,,,,,,,,, -127500,2.7218757,1.9025813,,,,,,,,,,,,,, -127600,3.2918854,2.0438323,,,,,,,,,,,,,, -127700,2.7620418,2.1823812,,,,,,,,,,,,,, -127784,,,0.7442578077316284,1.032114863395691,0.6894199848175049,1.2755590677261353,50000.0,0.5622000098228455,1.8965699672698968,10000.0,58441.85705113411,64940.28579545021,58441.85705113411,6486.063357114792,5.723682403564453,0.0 -127800,3.0408626,2.0220227,,,,,,,,,,,,,, -127900,2.4366257,3.6275835,,,,,,,,,,,,,, -128000,2.8305323,1.9477668,,,,,,,,,,,,,, -128100,2.7754886,2.557908,,,,,,,,,,,,,, -128200,2.8044853,2.0020409,,,,,,,,,,,,,, -128300,3.1670172,2.0142312,,,,,,,,,,,,,, -128400,2.7063262,4.349627,,,,,,,,,,,,,, -128500,2.6407685,3.425061,,,,,,,,,,,,,, -128600,2.5418966,2.8303266,,,,,,,,,,,,,, -128700,2.4863086,3.3136826,,,,,,,,,,,,,, -128703,,,0.7488671541213989,1.0089884996414185,0.6900999546051025,1.273664474487305,50000.0,0.5627000331878662,1.8980499505996704,10000.0,58862.21596360207,65407.78814959526,58862.21596360207,6533.113852500916,5.769906282424927,0.0 -128800,3.0233474,1.8635607,,,,,,,,,,,,,, -128900,3.0945392,2.2349637,,,,,,,,,,,,,, -129000,2.724022,3.1851025,,,,,,,,,,,,,, -129100,2.7409506,2.8942506,,,,,,,,,,,,,, -129200,2.7531106,2.6593144,,,,,,,,,,,,,, -129300,3.045559,1.9449829,,,,,,,,,,,,,, -129400,3.125828,1.8767543,,,,,,,,,,,,,, -129500,2.5650885,2.5823927,,,,,,,,,,,,,, -129600,3.0619059,2.0678377,,,,,,,,,,,,,, -129621,,,0.7646874785423279,0.932016670703888,0.6924799680709839,1.245804786682129,50000.0,0.5667999982833862,1.8743077516555784,10000.0,59282.16807126999,65875.81766748428,59282.16807126999,6581.099862098694,5.814780950546265,0.0 -129700,2.6064065,4.4643235,,,,,,,,,,,,,, -129800,2.9311671,1.8120472,,,,,,,,,,,,,, -129900,2.8234081,4.48714,,,,,,,,,,,,,, -130000,2.8640532,2.5695806,,,,,,,,,,,,,, -130100,2.6872647,3.0814273,,,,,,,,,,,,,, -130200,2.7471359,3.56032,,,,,,,,,,,,,, -130300,3.2285218,1.9162108,,,,,,,,,,,,,, -130400,2.8193088,3.7655895,,,,,,,,,,,,,, -130500,2.8698387,1.8584051,,,,,,,,,,,,,, -130539,,,0.7518749833106995,0.9943748116493224,0.6937199831008911,1.2452759742736816,50000.0,0.5671000480651855,1.879859447479248,10000.0,59702.48772835732,66345.67736411095,59702.48772835732,6630.546304941177,5.8607916831970215,0.0 -130600,2.6561635,3.7811506,,,,,,,,,,,,,, -130700,2.9837916,1.8842776,,,,,,,,,,,,,, -130800,2.489008,3.0796502,,,,,,,,,,,,,, -130900,3.2484102,3.9051237,,,,,,,,,,,,,, -131000,2.782349,1.8517013,,,,,,,,,,,,,, -131100,2.8763704,3.4577003,,,,,,,,,,,,,, -131200,2.9616349,1.9693639,,,,,,,,,,,,,, -131300,3.1607215,3.978379,,,,,,,,,,,,,, -131400,3.0977218,3.373805,,,,,,,,,,,,,, -131459,,,0.7592187523841858,0.9559024572372437,0.6979999542236328,1.223264455795288,50000.0,0.5730000138282776,1.8423492908477783,10000.0,60122.77754402161,66813.98977065086,60122.77754402161,6678.470661401749,5.9110212326049805,0.0 -131500,2.9208174,4.3091307,,,,,,,,,,,,,, -131600,2.8672743,2.249496,,,,,,,,,,,,,, -131700,2.8911595,1.8867863,,,,,,,,,,,,,, -131800,3.0116365,2.0044372,,,,,,,,,,,,,, -131900,3.0769265,1.83256,,,,,,,,,,,,,, -132000,3.314306,2.0979857,,,,,,,,,,,,,, -132100,3.1287422,1.9653981,,,,,,,,,,,,,, -132200,3.3519971,4.109476,,,,,,,,,,,,,, -132300,3.0488646,1.9421136,,,,,,,,,,,,,, -132379,,,0.7625390291213989,0.9606295824050904,0.6951199769973755,1.261569857597351,50000.0,0.5699000358581543,1.8848967552185056,10000.0,60542.986839056015,67283.923807621,60542.986839056015,6728.100863933563,5.9576075077056885,0.0 -132400,3.333085,1.98661,,,,,,,,,,,,,, -132500,3.124548,2.1344416,,,,,,,,,,,,,, -132600,3.01737,1.824412,,,,,,,,,,,,,, -132700,3.3697228,1.845341,,,,,,,,,,,,,, -132800,2.7757385,3.1930914,,,,,,,,,,,,,, -132900,3.0454872,1.9659925,,,,,,,,,,,,,, -133000,2.9765263,2.218493,,,,,,,,,,,,,, -133100,3.4420967,1.9210063,,,,,,,,,,,,,, -133200,3.134686,1.789742,,,,,,,,,,,,,, -133298,,,0.7527148127555847,0.982952117919922,0.7013199925422668,1.2232753038406372,50000.0,0.5729000568389893,1.8487434387207031,10000.0,60963.19054841995,67751.67906785011,60963.19054841995,6775.558528184891,6.004410266876221,0.0 -133300,2.7511935,3.0911713,,,,,,,,,,,,,, -133400,3.1676223,1.9980865,,,,,,,,,,,,,, -133500,3.263329,1.9794276,,,,,,,,,,,,,, -133600,3.0398521,4.118261,,,,,,,,,,,,,, -133700,3.0740128,3.5276828,,,,,,,,,,,,,, -133800,3.2571783,1.8908811,,,,,,,,,,,,,, -133900,2.9688613,1.9914184,,,,,,,,,,,,,, -134000,2.7760377,2.414723,,,,,,,,,,,,,, -134100,2.8367198,2.4940238,,,,,,,,,,,,,, -134200,3.4814491,2.84465,,,,,,,,,,,,,, -134216,,,0.7622656226158142,0.935705542564392,0.701259970664978,1.2122747898101809,50000.0,0.5751000046730042,1.8352450132369995,10000.0,61383.45324897766,68218.22201561928,61383.45324897766,6821.74352645874,6.052812814712524,0.0 -134300,3.3948917,4.283088,,,,,,,,,,,,,, -134400,2.975632,2.3238776,,,,,,,,,,,,,, -134500,3.1591558,2.1487622,,,,,,,,,,,,,, -134600,3.2247732,1.8434191,,,,,,,,,,,,,, -134700,2.7718546,2.6637797,,,,,,,,,,,,,, -134800,2.8324938,2.910553,,,,,,,,,,,,,, -134900,3.0586348,2.0752034,,,,,,,,,,,,,, -135000,2.7501745,2.6924016,,,,,,,,,,,,,, -135100,3.201216,1.7973589,,,,,,,,,,,,,, -135131,,,0.7727343440055847,0.8989397287368774,0.7053200006484985,1.198777675628662,50000.0,0.5815000534057617,1.8275467157363887,10000.0,61803.36758232117,68685.2418153286,61803.36758232117,6868.752334356308,6.101391077041626,0.0 -135200,3.2844646,1.7502936,,,,,,,,,,,,,, -135300,3.3209891,1.9829267,,,,,,,,,,,,,, -135400,3.341568,1.7598684,,,,,,,,,,,,,, -135500,3.4811747,1.9568298,,,,,,,,,,,,,, -135600,3.10337,2.015767,,,,,,,,,,,,,, -135700,3.2273152,2.1632972,,,,,,,,,,,,,, -135800,4.1905985,2.0061326,,,,,,,,,,,,,, -135900,3.3219097,2.122885,,,,,,,,,,,,,, -136000,3.181255,2.4832664,,,,,,,,,,,,,, -136050,,,0.77685546875,0.8904039859771729,0.7037000060081482,1.207800030708313,50000.0,0.5791000127792358,1.836230993270874,10000.0,62223.56702518463,69154.49729061127,62223.56702518463,6917.71401143074,6.14823055267334,0.0 -136100,3.038761,2.4473152,,,,,,,,,,,,,, -136200,3.5132928,1.8551435,,,,,,,,,,,,,, -136300,3.0772476,3.093696,,,,,,,,,,,,,, -136400,3.2206204,1.8176615,,,,,,,,,,,,,, -136500,3.6047392,1.886403,,,,,,,,,,,,,, -136600,3.0838926,1.8083439,,,,,,,,,,,,,, -136700,3.454201,1.8503904,,,,,,,,,,,,,, -136800,3.2662055,1.9554281,,,,,,,,,,,,,, -136900,3.2943702,1.7820855,,,,,,,,,,,,,, -136970,,,0.7667577862739563,0.9236660003662108,0.7049799561500549,1.1936697959899902,50000.0,0.579300045967102,1.815743088722229,10000.0,62643.62982487679,69623.43274188042,62643.62982487679,6966.49352145195,6.19426703453064,0.0 -137000,3.3791757,1.7436054,,,,,,,,,,,,,, -137100,2.9767346,3.6111054,,,,,,,,,,,,,, -137200,3.2603564,1.9917141,,,,,,,,,,,,,, -137300,3.6190517,1.8371334,,,,,,,,,,,,,, -137400,3.4431942,1.9215957,,,,,,,,,,,,,, -137500,3.3353918,3.9554586,,,,,,,,,,,,,, -137600,3.2249913,1.9866021,,,,,,,,,,,,,, -137700,3.1931958,1.7867078,,,,,,,,,,,,,, -137800,3.5042367,1.7613487,,,,,,,,,,,,,, -137890,,,0.772753894329071,0.902352213859558,0.706279993057251,1.1953895092010498,50000.0,0.5799000263214111,1.819128155708313,10000.0,63063.79121661186,70090.73695373535,63063.79121661186,7013.5392434597015,6.243442058563232,0.0 -137900,3.3270004,1.7818993,,,,,,,,,,,,,, -138000,4.020118,1.8584542,,,,,,,,,,,,,, -138100,3.1183705,3.057457,,,,,,,,,,,,,, -138200,3.2880101,1.8074301,,,,,,,,,,,,,, -138300,4.246114,1.7283144,,,,,,,,,,,,,, -138400,3.0998318,3.9576893,,,,,,,,,,,,,, -138500,3.221896,1.8581185,,,,,,,,,,,,,, -138600,3.3372173,1.8517267,,,,,,,,,,,,,, -138700,3.1404955,1.7364151,,,,,,,,,,,,,, -138800,3.2724652,1.8065481,,,,,,,,,,,,,, -138808,,,0.78382807970047,0.8481549024581909,0.7096999883651733,1.173912525177002,50000.0,0.5788000226020813,1.7924264669418335,10000.0,63483.9806933403,70560.31663990021,63483.9806933403,7062.826720952988,6.297896862030029,0.0 -138900,3.5972447,1.8564119,,,,,,,,,,,,,, -139000,3.380555,3.365274,,,,,,,,,,,,,, -139100,3.2053983,1.7659321,,,,,,,,,,,,,, -139200,3.4638047,1.7858946,,,,,,,,,,,,,, -139300,3.5944443,2.160381,,,,,,,,,,,,,, -139400,2.9181232,3.4328468,,,,,,,,,,,,,, -139500,3.1801198,3.8813076,,,,,,,,,,,,,, -139600,3.4924417,1.8586044,,,,,,,,,,,,,, -139700,3.4492385,1.7871956,,,,,,,,,,,,,, -139726,,,0.7791601419448853,0.8667029142379761,0.7123000025749207,1.1546515226364136,50000.0,0.5897000432014465,1.7686362266540527,10000.0,63903.96551418304,71027.57639861107,63903.96551418304,7110.006555318832,6.345866918563843,0.0 -139800,3.3738189,3.2860088,,,,,,,,,,,,,, -139900,3.2494483,2.3187335,,,,,,,,,,,,,, -140000,3.1214688,3.1082385,,,,,,,,,,,,,, -140100,3.636519,1.64029,,,,,,,,,,,,,, -140200,3.596481,1.7616116,,,,,,,,,,,,,, -140300,3.5784125,1.9084704,,,,,,,,,,,,,, -140400,3.22736,3.6435142,,,,,,,,,,,,,, -140500,3.1068916,2.778686,,,,,,,,,,,,,, -140600,3.5719788,1.8373188,,,,,,,,,,,,,, -140642,,,0.7782421708106995,0.8562284708023071,0.7149199843406677,1.147608757019043,50000.0,0.5906000137329102,1.7629472017288208,10000.0,64324.39351463318,71496.82323789597,64324.39351463318,7158.72931265831,6.395110607147217,0.0 -140700,3.412288,1.8659639,,,,,,,,,,,,,, -140800,3.592115,1.8447096,,,,,,,,,,,,,, -140900,3.7368598,1.8175578,,,,,,,,,,,,,, -141000,3.2275684,3.8187203,,,,,,,,,,,,,, -141100,3.027631,2.2525241,,,,,,,,,,,,,, -141200,3.570252,1.8181636,,,,,,,,,,,,,, -141300,3.6635864,1.7210451,,,,,,,,,,,,,, -141400,3.4421012,1.8341596,,,,,,,,,,,,,, -141500,3.3401005,3.8313584,,,,,,,,,,,,,, -141561,,,0.7863867282867432,0.8447074890136719,0.7131199836730957,1.16819167137146,50000.0,0.588200032711029,1.8064215183258057,10000.0,64744.67516851425,71965.00109434128,64744.67516851425,7206.5280418396,6.445488691329956,0.0 -141600,3.6427069,1.7619548,,,,,,,,,,,,,, -141700,3.636693,1.9006901,,,,,,,,,,,,,, -141800,3.5284636,2.7317219,,,,,,,,,,,,,, -141900,3.561645,1.753759,,,,,,,,,,,,,, -142000,3.10432,3.3811786,,,,,,,,,,,,,, -142100,3.212291,3.3344517,,,,,,,,,,,,,, -142200,3.5304885,1.8301613,,,,,,,,,,,,,, -142300,3.8151374,1.8286349,,,,,,,,,,,,,, -142400,3.7534158,3.962431,,,,,,,,,,,,,, -142474,,,0.7731249928474426,0.8849400281906128,0.7135399580001831,1.1557042598724363,50000.0,0.5914000272750854,1.762101411819458,10000.0,65164.6888320446,72433.90739750862,65164.6888320446,7255.32564163208,6.493313312530518,0.0 -142500,3.3521883,3.0848904,,,,,,,,,,,,,, -142600,3.6284494,1.7534386,,,,,,,,,,,,,, -142700,3.902366,1.7152663,,,,,,,,,,,,,, -142800,3.925217,3.2487643,,,,,,,,,,,,,, -142900,3.7694056,1.8684486,,,,,,,,,,,,,, -143000,3.4207802,3.8118157,,,,,,,,,,,,,, -143100,3.7617397,2.032164,,,,,,,,,,,,,, -143200,3.708654,3.6191485,,,,,,,,,,,,,, -143300,3.5380447,2.1172771,,,,,,,,,,,,,, -143390,,,0.7816015481948853,0.868894100189209,0.716219961643219,1.1561012268066406,50000.0,0.5877000093460083,1.783122181892395,10000.0,65585.04565286636,72904.4974284172,65585.04565286636,7305.463745594025,6.540487051010132,0.0 -143400,4.044438,1.7896775,,,,,,,,,,,,,, -143500,3.8727071,1.7906823,,,,,,,,,,,,,, -143600,3.17541,2.5461848,,,,,,,,,,,,,, -143700,3.5763483,2.0299118,,,,,,,,,,,,,, -143800,3.9817634,1.9335626,,,,,,,,,,,,,, -143900,3.6819215,1.7024574,,,,,,,,,,,,,, -144000,3.8659058,1.9182136,,,,,,,,,,,,,, -144100,3.553923,3.5001967,,,,,,,,,,,,,, -144200,3.5443132,2.3855789,,,,,,,,,,,,,, -144300,3.46574,2.1308162,,,,,,,,,,,,,, -144308,,,0.79212886095047,0.8323768377304077,0.7192999720573425,1.145952582359314,50000.0,0.5929000377655029,1.7589880228042605,10000.0,66004.958309412,73373.15925145149,66004.958309412,7354.108882188797,6.597029685974121,0.0 -144400,3.6535301,1.6833485,,,,,,,,,,,,,, -144500,3.6887143,1.8874204,,,,,,,,,,,,,, -144600,3.747896,3.1033509,,,,,,,,,,,,,, -144700,3.3459687,2.4343758,,,,,,,,,,,,,, -144800,3.4578848,2.1330452,,,,,,,,,,,,,, -144900,3.7287836,2.8827908,,,,,,,,,,,,,, -145000,3.9394586,2.3282886,,,,,,,,,,,,,, -145100,3.6395168,2.061342,,,,,,,,,,,,,, -145200,3.9734802,1.9166698,,,,,,,,,,,,,, -145228,,,0.7823046445846558,0.8637334108352661,0.7188799977302551,1.1443239450454712,50000.0,0.593000054359436,1.7642781734466553,10000.0,66425.3852956295,73841.35216474533,66425.3852956295,7401.779304265976,6.644883871078491,0.0 -145300,3.8524315,2.0423126,,,,,,,,,,,,,, -145400,3.6180415,3.5588174,,,,,,,,,,,,,, -145500,3.5785809,2.2040427,,,,,,,,,,,,,, -145600,4.368666,4.197726,,,,,,,,,,,,,, -145700,3.5924962,1.8012226,,,,,,,,,,,,,, -145800,4.129752,1.6668799,,,,,,,,,,,,,, -145900,3.6801803,1.7340817,,,,,,,,,,,,,, -146000,4.029292,1.7259176,,,,,,,,,,,,,, -146100,3.6504092,2.964161,,,,,,,,,,,,,, -146149,,,0.7882617115974426,0.8386130928993225,0.7200599908828735,1.130061388015747,50000.0,0.5923000574111938,1.7559378147125244,10000.0,66845.65575814247,74312.01202607155,66845.65575814247,7452.07506108284,6.691014051437378,0.0 -146200,3.5997095,3.9639935,,,,,,,,,,,,,, -146300,3.960226,1.5498104,,,,,,,,,,,,,, -146400,4.079673,4.1521497,,,,,,,,,,,,,, -146500,3.9642518,1.657956,,,,,,,,,,,,,, -146600,3.9217527,3.892364,,,,,,,,,,,,,, -146700,4.638708,1.7707002,,,,,,,,,,,,,, -146800,4.201485,1.6529019,,,,,,,,,,,,,, -146900,3.705106,1.628701,,,,,,,,,,,,,, -147000,3.975561,2.011771,,,,,,,,,,,,,, -147070,,,0.7914062142372131,0.8205158710479736,0.7212799787521362,1.122092366218567,50000.0,0.5910000205039978,1.7555046081542969,10000.0,67265.79447817802,74781.38368654251,67265.79447817802,7501.210858821869,6.740800857543945,0.0 -147100,4.1178617,1.6649475,,,,,,,,,,,,,, -147200,3.7112944,2.438682,,,,,,,,,,,,,, -147300,3.9792407,1.6409678,,,,,,,,,,,,,, -147400,3.3311262,2.7451503,,,,,,,,,,,,,, -147500,4.2758617,3.6161194,,,,,,,,,,,,,, -147600,3.6616194,2.7023454,,,,,,,,,,,,,, -147700,4.071072,2.2166932,,,,,,,,,,,,,, -147800,3.9761803,1.7336086,,,,,,,,,,,,,, -147900,3.6329317,1.6071799,,,,,,,,,,,,,, -147991,,,0.8050585985183716,0.7596969604492188,0.724079966545105,1.102674126625061,50000.0,0.6055999994277954,1.719403624534607,10000.0,67685.98072981834,75250.80831742287,67685.98072981834,7550.349389076233,6.793242692947388,0.0 -148000,4.2198787,2.0647206,,,,,,,,,,,,,, -148100,3.9446478,3.5780497,,,,,,,,,,,,,, -148200,4.323691,1.6298139,,,,,,,,,,,,,, -148300,3.664089,1.7288473,,,,,,,,,,,,,, -148400,4.291796,1.7237182,,,,,,,,,,,,,, -148500,4.0764885,3.8657312,,,,,,,,,,,,,, -148600,4.3874664,1.8120106,,,,,,,,,,,,,, -148700,3.7566535,2.0895176,,,,,,,,,,,,,, -148800,4.401804,1.8653426,,,,,,,,,,,,,, -148900,3.9679759,2.2155879,,,,,,,,,,,,,, -148910,,,0.79408198595047,0.8039225339889526,0.7251399755477905,1.0975062847137451,50000.0,0.5976999998092651,1.717188000679016,10000.0,68106.26304864883,75720.92788505554,68106.26304864883,7600.08563709259,6.847251892089844,0.0 -149000,4.729565,3.9646544,,,,,,,,,,,,,, -149100,3.8374617,1.7340205,,,,,,,,,,,,,, -149200,3.759051,1.9678159,,,,,,,,,,,,,, -149300,3.6103857,2.6691232,,,,,,,,,,,,,, -149400,3.6830242,2.2415364,,,,,,,,,,,,,, -149500,4.144462,3.6662972,,,,,,,,,,,,,, -149600,3.9787605,3.2435675,,,,,,,,,,,,,, -149700,4.4450817,1.8441828,,,,,,,,,,,,,, -149800,4.154763,1.5789901,,,,,,,,,,,,,, -149831,,,0.7937890291213989,0.8100752830505371,0.7240999937057495,1.1172319650650024,50000.0,0.5993000268936157,1.7362432479858398,10000.0,68526.58949446678,76190.50789570808,68526.58949446678,7649.241825342178,6.897162199020386,0.0 -149900,4.0499167,1.7410902,,,,,,,,,,,,,, -150000,4.162111,1.8021661,,,,,,,,,,,,,, -150100,4.4465947,1.6708347,,,,,,,,,,,,,, -150200,4.4187384,3.5837405,,,,,,,,,,,,,, -150300,4.550783,3.9946885,,,,,,,,,,,,,, -150400,4.1812773,1.842498,,,,,,,,,,,,,, -150500,4.007715,1.7202778,,,,,,,,,,,,,, -150600,4.129707,1.908577,,,,,,,,,,,,,, -150700,4.3587575,3.5980651,,,,,,,,,,,,,, -150752,,,0.8075780868530273,0.7453005909919739,0.7288599610328674,1.078874945640564,50000.0,0.6023000478744507,1.6958317756652832,10000.0,68946.71733403206,76659.31583476067,68946.71733403206,7697.823964357376,6.947785139083862,0.0 -150800,4.1104035,2.0728111,,,,,,,,,,,,,, -150900,4.370633,3.49104,,,,,,,,,,,,,, -151000,4.4041996,1.6556258,,,,,,,,,,,,,, -151100,4.2141285,1.6086282,,,,,,,,,,,,,, -151200,4.2387466,1.8117421,,,,,,,,,,,,,, -151300,4.2723227,1.9721009,,,,,,,,,,,,,, -151400,4.108684,1.5650777,,,,,,,,,,,,,, -151500,4.2705235,2.079704,,,,,,,,,,,,,, -151600,4.3558836,2.4455698,,,,,,,,,,,,,, -151672,,,0.7996875047683716,0.7857888340950012,0.7291399836540222,1.0955058336257937,50000.0,0.6053000092506409,1.7195252180099487,10000.0,69366.65239262581,77128.69496536255,69366.65239262581,7747.169746875763,6.998458623886108,0.0 -151700,3.9704423,2.2894769,,,,,,,,,,,,,, -151800,4.2616577,1.8941492,,,,,,,,,,,,,, -151900,3.797516,3.291678,,,,,,,,,,,,,, -152000,4.1386995,1.6003106,,,,,,,,,,,,,, -152100,3.916546,2.617878,,,,,,,,,,,,,, -152200,4.5087767,3.4633636,,,,,,,,,,,,,, -152300,4.271555,1.8002391,,,,,,,,,,,,,, -152400,4.4968314,1.5625973,,,,,,,,,,,,,, -152500,4.721171,1.579136,,,,,,,,,,,,,, -152590,,,0.8040820360183716,0.7654426693916321,0.7314599752426147,1.074555277824402,50000.0,0.6079000234603882,1.6825690269470217,10000.0,69787.05628466606,77600.06755399704,69787.05628466606,7798.037398815155,7.052371025085449,0.0 -152600,4.565183,1.7014256,,,,,,,,,,,,,, -152700,4.6729426,3.72607,,,,,,,,,,,,,, -152800,3.7875044,2.3935485,,,,,,,,,,,,,, -152900,4.3575063,1.6340837,,,,,,,,,,,,,, -153000,4.150306,2.3907666,,,,,,,,,,,,,, -153100,4.6678724,1.6441574,,,,,,,,,,,,,, -153200,4.2577357,1.5803173,,,,,,,,,,,,,, -153300,4.3479257,1.8215318,,,,,,,,,,,,,, -153400,4.3352013,1.7121047,,,,,,,,,,,,,, -153500,4.33766,1.7343837,,,,,,,,,,,,,, -153511,,,0.8078711032867432,0.7358626127243042,0.7321999669075012,1.064021110534668,50000.0,0.6045000553131104,1.678421974182129,10000.0,70207.12128043175,78073.42539596558,70207.12128043175,7851.229133844376,7.106428384780884,0.0 -153600,5.070888,1.6096406,,,,,,,,,,,,,, -153700,5.0849404,3.8157403,,,,,,,,,,,,,, -153800,4.9614534,3.5175605,,,,,,,,,,,,,, -153900,4.4569616,1.9535873,,,,,,,,,,,,,, -154000,4.8756547,1.6953028,,,,,,,,,,,,,, -154100,4.4697833,1.6205485,,,,,,,,,,,,,, -154200,4.390603,1.7947353,,,,,,,,,,,,,, -154300,4.4689465,1.652148,,,,,,,,,,,,,, -154400,5.0368505,2.415085,,,,,,,,,,,,,, -154432,,,0.8040429353713989,0.7815302014350891,0.7331399917602539,1.0862562656402588,50000.0,0.6107000112533569,1.7023465633392334,10000.0,70627.35976624489,78544.17288303375,70627.35976624489,7901.632784366608,7.164010286331177,0.0 -154500,4.0939813,2.9932175,,,,,,,,,,,,,, -154600,4.393245,3.2266016,,,,,,,,,,,,,, -154700,4.7116446,1.507281,,,,,,,,,,,,,, -154800,4.176491,2.0396972,,,,,,,,,,,,,, -154900,5.4387536,1.7556936,,,,,,,,,,,,,, -155000,4.4563165,1.5796473,,,,,,,,,,,,,, -155100,4.49778,1.622685,,,,,,,,,,,,,, -155200,4.3124375,1.8022258,,,,,,,,,,,,,, -155300,4.983905,3.932715,,,,,,,,,,,,,, -155351,,,0.804492175579071,0.760875940322876,0.7349199652671814,1.0727885961532593,50000.0,0.6113000512123108,1.681421399116516,10000.0,71047.48124408722,79016.50269341469,71047.48124408722,7953.746181964874,7.211568593978882,0.0 -155400,4.4440303,1.5426893,,,,,,,,,,,,,, -155500,4.2256203,2.1899,,,,,,,,,,,,,, -155600,4.9449773,3.9882376,,,,,,,,,,,,,, -155700,3.9046097,2.6643403,,,,,,,,,,,,,, -155800,4.565972,3.4293554,,,,,,,,,,,,,, -155900,4.54543,2.2373948,,,,,,,,,,,,,, -156000,4.5038257,2.0007207,,,,,,,,,,,,,, -156100,5.2833014,1.6275741,,,,,,,,,,,,,, -156200,4.593781,1.4983634,,,,,,,,,,,,,, -156269,,,0.8096483945846558,0.7489880323410034,0.73499995470047,1.0660606622695925,50000.0,0.6067000031471252,1.6834540367126465,10000.0,71467.49888920784,79486.18399739265,71467.49888920784,8003.304003477097,7.269599676132202,0.0 -156300,5.139183,1.5339607,,,,,,,,,,,,,, -156400,4.6104875,1.6908839,,,,,,,,,,,,,, -156500,4.6521125,2.962753,,,,,,,,,,,,,, -156600,4.6093507,2.1453652,,,,,,,,,,,,,, -156700,4.906398,1.9244021,,,,,,,,,,,,,, -156800,5.120648,3.8301516,,,,,,,,,,,,,, -156900,4.756276,1.5261801,,,,,,,,,,,,,, -157000,4.3434243,2.1120157,,,,,,,,,,,,,, -157100,4.3104863,1.6585791,,,,,,,,,,,,,, -157189,,,0.8141406178474426,0.7198129296302795,0.7380399703979492,1.0487289428710938,50000.0,0.617400050163269,1.6608270406723022,10000.0,71887.46869874,79955.67472600937,71887.46869874,8052.717717885971,7.328284025192261,0.0 -157200,4.9141045,2.510078,,,,,,,,,,,,,, -157300,4.747442,1.5844753,,,,,,,,,,,,,, -157400,4.8040967,2.5790224,,,,,,,,,,,,,, -157500,5.53979,1.615738,,,,,,,,,,,,,, -157600,5.1086287,3.4141502,,,,,,,,,,,,,, -157700,4.801928,1.5683997,,,,,,,,,,,,,, -157800,4.893815,1.6409149,,,,,,,,,,,,,, -157900,5.094948,1.4735413,,,,,,,,,,,,,, -158000,4.80248,1.465798,,,,,,,,,,,,,, -158100,4.5685725,2.8798542,,,,,,,,,,,,,, -158105,,,0.813769519329071,0.7237128019332886,0.7416200041770935,1.0368276834487915,50000.0,0.6190000176429749,1.6595739126205444,10000.0,72307.4182343483,80422.82166051865,72307.4182343483,8099.809978485107,7.386165380477905,0.0 -158200,5.1695004,1.7427226,,,,,,,,,,,,,, -158300,4.470933,1.5614948,,,,,,,,,,,,,, -158400,4.84463,1.566567,,,,,,,,,,,,,, -158500,5.3902054,3.848052,,,,,,,,,,,,,, -158600,4.691293,1.9050862,,,,,,,,,,,,,, -158700,5.3389163,3.565616,,,,,,,,,,,,,, -158800,4.498363,2.1701117,,,,,,,,,,,,,, -158900,4.575175,1.439934,,,,,,,,,,,,,, -159000,4.7580376,2.7688918,,,,,,,,,,,,,, -159025,,,0.8179101347923279,0.6971430778503418,0.7404400110244751,1.028631567955017,50000.0,0.6154000163078308,1.6334420442581177,10000.0,72727.65055274963,80892.5890173912,72727.65055274963,8149.245040655136,7.438138723373413,0.0 -159100,4.742898,3.1936243,,,,,,,,,,,,,, -159200,4.826342,1.7023658,,,,,,,,,,,,,, -159300,5.020297,1.6174166,,,,,,,,,,,,,, -159400,4.845443,1.9897779,,,,,,,,,,,,,, -159500,4.9419093,1.4619246,,,,,,,,,,,,,, -159600,5.4880056,1.5185009,,,,,,,,,,,,,, -159700,5.387671,1.5671335,,,,,,,,,,,,,, -159800,5.2569575,1.6352004,,,,,,,,,,,,,, -159900,5.667787,2.988666,,,,,,,,,,,,,, -159945,,,0.8225781321525574,0.6903254389762878,0.7412799596786499,1.0446081161499023,50000.0,0.6166000366210938,1.6675461530685425,10000.0,73147.87590813637,81362.51534724236,73147.87590813637,8198.83842921257,7.497429847717285,0.0 -160000,4.853368,2.4854698,,,,,,,,,,,,,, -160100,4.9519315,1.5786319,,,,,,,,,,,,,, -160200,5.2313976,2.1226525,,,,,,,,,,,,,, -160300,4.810093,3.1912444,,,,,,,,,,,,,, -160400,5.3586087,1.5038495,,,,,,,,,,,,,, -160500,4.649758,2.314806,,,,,,,,,,,,,, -160600,5.1027474,2.5289812,,,,,,,,,,,,,, -160700,4.9584394,2.985771,,,,,,,,,,,,,, -160800,4.8694987,3.3780246,,,,,,,,,,,,,, -160856,,,0.8190624713897705,0.7006577253341675,0.7416999936103821,1.023085355758667,50000.0,0.622700035572052,1.641365885734558,10000.0,73568.24109864235,81832.5364947319,73568.24109864235,8248.3914706707,7.553764581680298,0.0 -160900,5.6313124,1.4962443,,,,,,,,,,,,,, -161000,4.726449,2.5557718,,,,,,,,,,,,,, -161100,6.2479563,1.6316065,,,,,,,,,,,,,, -161200,5.2781844,1.4435836,,,,,,,,,,,,,, -161300,5.3704715,1.6136377,,,,,,,,,,,,,, -161400,5.157618,3.1652498,,,,,,,,,,,,,, -161500,5.6259,1.4329069,,,,,,,,,,,,,, -161600,4.963749,1.5229998,,,,,,,,,,,,,, -161700,5.490807,1.910062,,,,,,,,,,,,,, -161774,,,0.8226171731948853,0.693584680557251,0.7432399988174438,1.023402214050293,50000.0,0.6225000023841858,1.6433871984481812,10000.0,73988.53427219391,82304.36846089363,73988.53427219391,8299.826081991196,7.611063718795776,0.0 -161800,5.434681,3.239142,,,,,,,,,,,,,, -161900,5.294899,1.5566076,,,,,,,,,,,,,, -162000,5.252085,2.4990342,,,,,,,,,,,,,, -162100,5.366401,1.4539069,,,,,,,,,,,,,, -162200,5.509292,1.5702138,,,,,,,,,,,,,, -162300,5.9499216,3.6611044,,,,,,,,,,,,,, -162400,5.899906,1.5692315,,,,,,,,,,,,,, -162500,5.444803,1.4318339,,,,,,,,,,,,,, -162600,5.4576097,1.7462178,,,,,,,,,,,,,, -162692,,,0.8324609398841858,0.6406580805778503,0.7445999979972839,1.0042402744293213,50000.0,0.6247000098228455,1.6267750263214111,10000.0,74408.44389462471,82774.69685792923,74408.44389462471,8350.1466858387,7.661406278610229,0.0 -162700,5.3101473,2.1156456,,,,,,,,,,,,,, -162800,5.952078,3.9481945,,,,,,,,,,,,,, -162900,5.4023643,1.4867516,,,,,,,,,,,,,, -163000,5.6181383,1.5717157,,,,,,,,,,,,,, -163100,5.322709,1.4971728,,,,,,,,,,,,,, -163200,4.794698,1.7714698,,,,,,,,,,,,,, -163300,5.372728,2.5813444,,,,,,,,,,,,,, -163400,4.8119736,1.953465,,,,,,,,,,,,,, -163500,5.1683035,1.3652085,,,,,,,,,,,,,, -163600,6.039439,3.7590384,,,,,,,,,,,,,, -163612,,,0.8236718773841858,0.6884275078773499,0.7479199767112732,1.008457899093628,50000.0,0.6284000277519226,1.6180665493011477,10000.0,74828.47229576111,83244.47105193138,74828.47229576111,8399.791483402252,7.7152698040008545,0.0 -163700,5.180855,1.5056269,,,,,,,,,,,,,, -163800,5.338652,1.4982435,,,,,,,,,,,,,, -163900,5.0997205,1.4867227,,,,,,,,,,,,,, -164000,5.366918,1.5309519,,,,,,,,,,,,,, -164100,4.8644924,2.780995,,,,,,,,,,,,,, -164200,5.782103,1.6977313,,,,,,,,,,,,,, -164300,5.178328,1.5036817,,,,,,,,,,,,,, -164400,5.203119,2.9028826,,,,,,,,,,,,,, -164500,5.6791234,1.4839346,,,,,,,,,,,,,, -164533,,,0.8290624618530273,0.662371039390564,0.7500999569892883,1.003365993499756,50000.0,0.6251000165939331,1.6135456562042236,10000.0,75248.7256937027,83715.26717567444,75248.7256937027,8450.2347574234,7.767882108688354,0.0 -164600,5.6250453,1.5131025,,,,,,,,,,,,,, -164700,5.6453485,1.5277154,,,,,,,,,,,,,, -164800,5.6512856,1.4599103,,,,,,,,,,,,,, -164900,5.5741134,1.6827391,,,,,,,,,,,,,, -165000,5.036254,1.4910175,,,,,,,,,,,,,, -165100,5.500958,1.5008143,,,,,,,,,,,,,, -165200,5.6338778,1.517099,,,,,,,,,,,,,, -165300,7.335957,3.8685722,,,,,,,,,,,,,, -165400,5.8361955,1.4874704,,,,,,,,,,,,,, -165452,,,0.8309765458106995,0.6480200886726379,0.7496799826622009,0.990441918373108,50000.0,0.6278000473976135,1.5892136096954346,10000.0,75669.07854604721,84185.15949583054,75669.07854604721,8499.674641370773,7.819833278656006,0.0 -165500,4.670807,1.9734162,,,,,,,,,,,,,, -165600,5.8788977,1.6019075,,,,,,,,,,,,,, -165700,5.774631,1.4413366,,,,,,,,,,,,,, -165800,6.0524817,3.5523462,,,,,,,,,,,,,, -165900,6.0635514,3.885696,,,,,,,,,,,,,, -166000,5.79853,1.4574587,,,,,,,,,,,,,, -166100,5.1992693,1.6865572,,,,,,,,,,,,,, -166200,5.8285484,1.4776026,,,,,,,,,,,,,, -166300,5.484957,1.6094776,,,,,,,,,,,,,, -166372,,,0.8332812190055847,0.6423313617706299,0.7530800104141235,0.9780200719833374,50000.0,0.6302000284194946,1.583455204963684,10000.0,76089.07453298569,84657.53774857521,76089.07453298569,8551.956509113312,7.872750520706177,0.0 -166400,6.034084,3.49137,,,,,,,,,,,,,, -166500,6.603238,1.5944238,,,,,,,,,,,,,, -166600,7.0229506,3.8647594,,,,,,,,,,,,,, -166700,5.8669224,1.3547574,,,,,,,,,,,,,, -166800,5.820767,2.1827464,,,,,,,,,,,,,, -166900,5.6322837,2.2594755,,,,,,,,,,,,,, -167000,5.4075136,1.4137533,,,,,,,,,,,,,, -167100,6.067582,1.939938,,,,,,,,,,,,,, -167200,5.9751315,1.5322798,,,,,,,,,,,,,, -167291,,,0.8354296684265137,0.6281799077987671,0.7523199915885925,0.9751542210578918,50000.0,0.6345000267028809,1.580045223236084,10000.0,76509.43893957138,85127.15686106682,76509.43893957138,8601.107754945755,7.928579807281494,0.0 -167300,5.0049295,2.2618861,,,,,,,,,,,,,, -167400,5.1846294,2.5065742,,,,,,,,,,,,,, -167500,5.648644,1.3699129,,,,,,,,,,,,,, -167600,5.827778,1.5160851,,,,,,,,,,,,,, -167700,6.1765413,3.860837,,,,,,,,,,,,,, -167800,5.7661257,1.4542978,,,,,,,,,,,,,, -167900,5.8749294,1.6155365,,,,,,,,,,,,,, -168000,5.4471254,1.3859241,,,,,,,,,,,,,, -168100,5.595221,1.8530886,,,,,,,,,,,,,, -168200,6.1457653,1.458392,,,,,,,,,,,,,, -168211,,,0.8359179496765137,0.6287804841995239,0.7562599778175354,0.970455288887024,50000.0,0.6341000199317932,1.5775206089019775,10000.0,76929.56784844398,85598.03895044327,76929.56784844398,8651.759412050247,7.982419967651367,0.0 -168300,5.4759088,1.3829772,,,,,,,,,,,,,, -168400,6.8108487,3.6147115,,,,,,,,,,,,,, -168500,5.943125,1.602445,,,,,,,,,,,,,, -168600,5.9163175,1.6100663,,,,,,,,,,,,,, -168700,6.288059,3.123259,,,,,,,,,,,,,, -168800,6.331787,3.7075138,,,,,,,,,,,,,, -168900,6.341096,1.387862,,,,,,,,,,,,,, -169000,6.5919375,3.5644472,,,,,,,,,,,,,, -169100,6.674572,3.6808238,,,,,,,,,,,,,, -169131,,,0.8357812166213989,0.6224377751350403,0.7562800049781799,0.9682385921478271,50000.0,0.6381000280380249,1.56676185131073,10000.0,77349.9605910778,86064.29647517204,77349.9605910778,8697.523322582245,8.035626888275146,0.0 -169200,6.2785344,1.3916951,,,,,,,,,,,,,, -169300,6.1464124,1.5603893,,,,,,,,,,,,,, -169400,6.384837,1.3706756,,,,,,,,,,,,,, -169500,5.90237,2.6905344,,,,,,,,,,,,,, -169511,,,,,,,,,,,77520.38380241394,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 85c4bec90..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -29.620367527008057,0.0,36.15118789672852,1,0,36.15118789672852,0.0010000000474974,6.907756805419922,10000,65.77166390419006,0.0009570312104187,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -75.90391826629639,0.0172796249389648,456.1083743572235,861,0,456.1083743572235,0.0265000015497207,6.062020778656006,10000,532.0756976604462,0.0339648425579071,5.889985084533691,0.0312199983745813,5.931602478027344,50000 -124.1122748851776,0.0530307292938232,876.1209945678711,1774,0,876.1209945678711,0.0535000041127204,5.638273239135742,10000,1000.3806068897248,0.0719531252980232,5.346263408660889,0.0669599995017051,5.416959285736084,50000 -177.97129225730896,0.0809743404388427,1296.1530268192291,2694,0,1296.1530268192291,0.0766000002622604,5.27877950668335,10000,1474.3475806713104,0.1066601574420929,4.961225986480713,0.0988399982452392,5.016429424285889,50000 -226.3655271530152,0.1077020168304443,1716.2821543216703,3614,0,1716.2821543216703,0.0958000048995018,5.121191024780273,10000,1942.9452345371249,0.1257812529802322,4.762739181518555,0.1183599978685379,4.825972557067871,50000 -271.0300600528717,0.1323871612548828,2136.238264322281,4529,0,2136.238264322281,0.1157000064849853,4.830338478088379,10000,2407.637862443924,0.1682031154632568,4.37608528137207,0.1546199917793274,4.472805976867676,50000 -320.3178098201752,0.1565053462982177,2556.621108531952,5445,0,2556.621108531952,0.1330000013113021,4.718831539154053,10000,2877.380870580673,0.1838281154632568,4.257326126098633,0.1714800000190735,4.334697246551514,50000 -370.2815811634064,0.1853024959564209,2976.629909515381,6364,0,2976.629909515381,0.1616000086069107,4.43961763381958,10000,3347.430968284607,0.2262499928474426,3.872288942337036,0.2107999920845031,3.977044343948364,50000 -420.16905403137207,0.2169387340545654,3396.913996696472,7280,0,3396.913996696472,0.1700000017881393,4.424200534820557,10000,3817.68110537529,0.2319921851158142,3.833322286605835,0.212359994649887,3.964530467987061,50000 -468.6132698059082,0.2487246990203857,3816.8563516139984,8200,0,3816.8563516139984,0.1718000024557113,4.393675327301025,10000,4286.148283243179,0.2553320229053497,3.758265256881714,0.2200799882411956,3.949136257171631,50000 -514.4226930141449,0.2796754837036133,4236.834023237228,9118,0,4236.834023237228,0.1919000148773193,4.206649303436279,10000,4752.014056444168,0.267871081829071,3.58646821975708,0.2466599941253662,3.7260239124298096,50000 -562.69517993927,0.3083176612854004,4656.83242058754,10034,0,4656.83242058754,0.2006000131368637,4.112731456756592,10000,5220.361520528793,0.2818945348262787,3.467654228210449,0.2607599794864654,3.606190204620361,50000 -609.6357562541962,0.3355536460876465,5076.973096132278,10952,0,5076.973096132278,0.1982000023126602,4.1670637130737305,10000,5687.518299818039,0.294726550579071,3.4312899112701416,0.259660005569458,3.6437604427337646,50000 -658.7836654186249,0.3622803688049316,5497.2170152664185,11870,0,5497.2170152664185,0.2055000066757202,4.087541103363037,10000,6156.984508275986,0.2876171767711639,3.442207336425781,0.2721000015735626,3.556260108947754,50000 -702.2066161632538,0.3930139541625976,5917.322836399078,12789,0,5917.322836399078,0.2110000103712082,4.108162879943848,10000,6620.591611146927,0.2928906083106994,3.4402520656585693,0.2694199979305267,3.5773630142211914,50000 -750.3280100822449,0.4236247539520263,6337.505370855331,13707,0,6337.505370855331,0.2140000164508819,4.01960039138794,10000,7088.974200248718,0.3098046779632568,3.3494856357574463,0.2825199961662292,3.517797946929932,50000 -798.8193001747131,0.4494767189025879,6757.771499872208,14625,0,6757.771499872208,0.2201000154018402,4.021054744720459,10000,7557.805434465408,0.3052343726158142,3.417556047439575,0.2822999954223633,3.541224956512451,50000 -846.682626247406,0.4787595272064209,7177.714713335037,15539,0,7177.714713335037,0.226500004529953,3.976407527923584,10000,8025.688539028168,0.3110156059265136,3.333627939224243,0.290719985961914,3.459397315979004,50000 -896.20419049263,0.5077810287475586,7597.749115467071,16458,0,7597.749115467071,0.2324000149965286,3.872088432312012,10000,8495.327637910843,0.3293554484844208,3.1784443855285645,0.3029599785804748,3.3324337005615234,50000 -946.04585814476,0.5358819961547852,8017.729243755341,17377,0,8017.729243755341,0.2331000119447708,3.922811985015869,10000,8965.22508430481,0.3253906071186065,3.243004083633423,0.3007999956607818,3.385110855102539,50000 -995.0433971881866,0.5683629512786865,8438.02011179924,18295,0,8438.02011179924,0.2424000054597854,3.785786390304565,10000,9434.594047546389,0.3461523354053497,3.063856363296509,0.3191399872303009,3.214115619659424,50000 -1046.9851393699646,0.5980954170227051,8858.02958726883,19213,0,8858.02958726883,0.2399000078439712,3.832469701766968,10000,9906.623125314713,0.3359765410423279,3.1254475116729736,0.3057000041007995,3.2792983055114746,50000 -1096.579419374466,0.625751256942749,9278.196017742155,20132,0,9278.196017742155,0.241100013256073,3.792826414108277,10000,10376.46831059456,0.3737109303474426,2.931885004043579,0.3187199831008911,3.236717939376831,50000 -1148.0055141448977,0.6522841453552246,9698.436318159103,21052,0,9698.436318159103,0.2482000142335891,3.769993543624878,10000,10848.208649396896,0.3433007597923279,3.070847988128662,0.3197799921035766,3.210434913635254,50000 -1197.1896917819977,0.6798815727233887,10118.706419706345,21964,0,10118.706419706345,0.2415000051259994,3.799169778823853,10000,11317.737473249435,0.3420703113079071,3.094113349914551,0.3227799832820892,3.226070165634156,50000 -1245.8942565917969,0.7139434814453125,10538.827253580092,22880,0,10538.827253580092,0.2574000060558319,3.756904125213623,10000,11786.644594669342,0.36865234375,2.973356008529663,0.3253600001335144,3.211477041244507,50000 -1295.278584241867,0.7482728958129883,10958.82033610344,23795,0,10958.82033610344,0.2494000047445297,3.7716026306152335,10000,12256.103848218918,0.3464648425579071,3.0866899490356445,0.3257800042629242,3.2104380130767822,50000 -1344.6447319984436,0.7822833061218262,11378.957997083664,24713,0,11378.957997083664,0.2598000168800354,3.707412242889404,10000,12725.690438747406,0.3595702946186065,2.995157718658448,0.333079993724823,3.153878688812256,50000 -1394.3293986320496,0.8127717971801758,11799.026058912275,25628,0,11799.026058912275,0.2576000094413757,3.701142311096192,10000,13195.52159357071,0.3726952970027923,2.9379703998565674,0.3400200009346008,3.115710020065308,50000 -1444.3692378997805,0.843348503112793,12219.03768491745,26542,0,12219.03768491745,0.2600000202655792,3.686853408813477,10000,13665.651197195051,0.3614648282527923,2.9840753078460693,0.3368600010871887,3.117708444595337,50000 -1494.0119626522064,0.8733620643615723,12639.100955963137,27462,0,12639.100955963137,0.2497000098228454,3.807725191116333,10000,14135.43448138237,0.3483007848262787,3.1170191764831543,0.3240199983119964,3.26419997215271,50000 -1545.383987903595,0.9063313007354736,13059.38454055786,28380,0,13059.38454055786,0.2630999982357025,3.679342269897461,10000,14607.170650720596,0.368945300579071,2.9343044757843018,0.338619977235794,3.106697797775269,50000 -1596.0251424312592,0.9394416809082032,13479.487174987791,29298,0,13479.487174987791,0.2648000121116638,3.661438465118408,10000,15077.99607181549,0.3706835806369781,2.9274771213531494,0.3446199893951416,3.07129430770874,50000 -1645.9867305755615,0.9761998653411864,13899.69462966919,30217,0,13899.69462966919,0.2742000222206116,3.627936124801636,10000,15548.25105714798,0.3766210973262787,2.902115821838379,0.3519800007343292,3.029473066329956,50000 -1695.7626497745514,1.0101087093353271,14319.800573825836,31135,0,14319.800573825836,0.2705000042915344,3.6340298652648926,10000,16018.214548110962,0.3855859339237213,2.8547158241271973,0.3534599840641022,3.0370798110961914,50000 -1743.906052350998,1.0428571701049805,14739.811116695404,32055,0,14739.811116695404,0.2671000063419342,3.660425901412964,10000,16486.448800325394,0.4124804735183716,2.7808024883270264,0.3519199788570404,3.0841052532196045,50000 -1792.961481332779,1.0791571140289309,15160.1337788105,32972,0,15160.1337788105,0.2773000001907348,3.537525653839112,10000,16955.91092300415,0.38671875,2.805215120315552,0.3612799942493438,2.94804048538208,50000 -1839.4359893798828,1.1127994060516355,15580.551282167437,33891,0,15580.551282167437,0.2771000266075134,3.5811426639556885,10000,17422.885044813156,0.3903124928474426,2.800220966339112,0.363999992609024,2.969886541366577,50000 -1888.6756381988523,1.1519043445587158,16000.717082977297,34809,0,16000.717082977297,0.2684000134468078,3.636655569076538,10000,17892.378203630447,0.3956054747104645,2.79667067527771,0.3571199774742126,3.022451400756836,50000 -1938.52412891388,1.1824181079864502,16421.02041387558,35725,0,16421.02041387558,0.2833000123500824,3.526043176651001,10000,18362.607944726944,0.390937477350235,2.7831339836120605,0.3679399788379669,2.916210174560547,50000 -1987.9012160301208,1.2189600467681885,16841.228150367737,36639,0,16841.228150367737,0.2743000090122223,3.578441619873047,10000,18832.27735710144,0.3834374845027923,2.84735369682312,0.3583599925041199,2.9957029819488525,50000 -2039.933144569397,1.262007474899292,17261.399400949478,37555,0,17261.399400949478,0.2765000164508819,3.5944924354553223,10000,19304.57182574272,0.3958398401737213,2.8110783100128174,0.3609800040721893,3.0116312503814697,50000 -2088.265805721283,1.299102783203125,17681.434158086777,38473,0,17681.434158086777,0.2942000031471252,3.5065810680389404,10000,19773.024648189545,0.4010546803474426,2.771629571914673,0.37567999958992,2.9098598957061768,50000 -2137.838715553284,1.3325514793395996,18101.503092050552,39390,0,18101.503092050552,0.2830000221729278,3.5438942909240723,10000,20242.74775648117,0.3926562368869781,2.793609142303467,0.3681399822235107,2.933591842651367,50000 -2187.1993803977966,1.3684730529785156,18521.69138717652,40307,0,18521.69138717652,0.2861000001430511,3.549136161804199,10000,20712.3809030056,0.4039843678474426,2.782717227935791,0.3682200014591217,2.971782684326172,50000 -2240.424861431122,1.3985400199890137,18941.72615456581,41227,0,18941.72615456581,0.2959000170230865,3.476227045059204,10000,21185.7188167572,0.4065038859844208,2.7331039905548096,0.3854199945926666,2.870001792907715,50000 -2289.662645339966,1.4311163425445557,19362.027856588364,42145,0,19362.027856588364,0.2925000190734863,3.437443971633911,10000,21655.337735414505,0.4131640493869781,2.6345086097717285,0.381520003080368,2.8117265701293945,50000 -2338.235716342926,1.4649834632873535,19782.140295267105,43061,0,19782.140295267105,0.3064000010490417,3.392443418502808,10000,22124.10464024544,0.4229101538658142,2.6399965286254883,0.3895399868488312,2.8090786933898926,50000 -2388.211321830749,1.503532886505127,20202.28423190117,43978,0,20202.28423190117,0.2950000166893005,3.4772708415985107,10000,22594.31020140648,0.4477148354053497,2.5128729343414307,0.3838799893856048,2.8571627140045166,50000 -2438.9718708992004,1.5404622554779053,20622.277943134308,44895,0,20622.277943134308,0.3067000210285187,3.367946147918701,10000,23065.153266191483,0.4213085770606994,2.619121551513672,0.3991200029850006,2.753256559371948,50000 -2489.8106729984283,1.576117992401123,21042.53992986679,45814,0,21042.53992986679,0.292600005865097,3.459824323654175,10000,23536.338141679764,0.411914050579071,2.6856353282928467,0.384579986333847,2.840383052825928,50000 -2540.581508398056,1.613168239593506,21462.657836198807,46727,0,21462.657836198807,0.3050000071525574,3.409043550491333,10000,24007.31149339676,0.4382031261920929,2.527977466583252,0.3956199884414673,2.7620689868927,50000 -2591.4612271785736,1.6478898525238037,21882.891329288483,47643,0,21882.891329288483,0.3105000257492065,3.3404970169067383,10000,24478.50642466545,0.4313085973262787,2.561516046524048,0.4053599834442138,2.709469079971313,50000 -2641.5904109478,1.6836578845977783,22302.96504020691,48561,0,22302.96504020691,0.317900002002716,3.279659509658813,10000,24948.79293680191,0.4389062523841858,2.5233049392700195,0.4129000008106231,2.669914960861206,50000 -2692.083906888962,1.7213480472564695,22723.290951013565,49474,0,22723.290951013565,0.3184000253677368,3.2931008338928223,10000,25419.6971578598,0.449531227350235,2.4620165824890137,0.4113599956035614,2.675554275512696,50000 -2738.572528839112,1.7537884712219238,23143.52159333229,50386,0,23143.52159333229,0.3139000236988067,3.3136818408966064,10000,25886.4966814518,0.4435351490974426,2.5222740173339844,0.4126800000667572,2.66820764541626,50000 -2788.143606901169,1.7920589447021484,23563.82721590996,51300,0,23563.82721590996,0.3272000253200531,3.2604527473449707,10000,26356.45927453041,0.4445898234844208,2.477378845214844,0.4203200042247772,2.636164903640747,50000 -2836.9467310905457,1.830620288848877,23983.90644145012,52216,0,23983.90644145012,0.3122000098228454,3.3333580493927,10000,26825.42811512947,0.438789039850235,2.5263781547546387,0.4053199887275696,2.713552474975586,50000 -2885.250978946686,1.871053695678711,24404.01684069633,53133,0,24404.01684069633,0.3252000212669372,3.2333431243896484,10000,27293.93130230904,0.4564843773841858,2.438621759414673,0.4238399863243103,2.607288122177124,50000 -2936.507098197937,1.9068207740783687,24824.36661410332,54054,0,24824.36661410332,0.3118000030517578,3.3063266277313232,10000,27765.620665311813,0.4422265589237213,2.535531759262085,0.412200003862381,2.7084314823150635,50000 -2986.406905412674,1.9439399242401123,25244.459810972214,54973,0,25244.459810972214,0.3169000148773193,3.3160958290100098,10000,28235.69842457772,0.4440820217132568,2.5492329597473145,0.4105399847030639,2.712507724761963,50000 -3036.34405207634,1.9801304340362549,25664.43350338936,55889,0,25664.43350338936,0.316100001335144,3.364776611328125,10000,28705.693928956985,0.4723632633686065,2.383117198944092,0.4037799835205078,2.740209579467773,50000 -3086.3352172374725,2.0175254344940186,26084.50500845909,56807,0,26084.50500845909,0.3288000226020813,3.2641468048095703,10000,29175.84260368347,0.4465624988079071,2.5229108333587646,0.4216800034046173,2.661752700805664,50000 -3135.060056447983,2.0510499477386475,26504.70529055596,57728,0,26504.70529055596,0.3342000246047973,3.217836141586304,10000,29644.84923171997,0.4511913955211639,2.4399993419647217,0.4245599806308746,2.596042394638061,50000 -3185.6205384731293,2.085714817047119,26924.90839576721,58645,0,26924.90839576721,0.3331000208854675,3.236710786819458,10000,30115.69566130638,0.4740820229053497,2.372753143310547,0.4254199862480163,2.6239047050476074,50000 -3235.580410003662,2.121338367462158,27344.886063098907,59558,0,27344.886063098907,0.3296000063419342,3.227254867553711,10000,30585.716195106503,0.4577734172344208,2.429400682449341,0.4283799827098846,2.5803024768829346,50000 -3284.088232278824,2.162022352218628,27764.8627755642,60474,0,27764.8627755642,0.3315000236034393,3.199613809585572,10000,31054.29045343399,0.4648632705211639,2.3890140056610107,0.4313199818134308,2.562873363494873,50000 -3333.0700438022614,2.200597047805786,28185.063593387604,61391,0,28185.063593387604,0.3352000117301941,3.2316999435424805,10000,31523.559364318848,0.467089831829071,2.3804032802581787,0.4293199777603149,2.585036277770996,50000 -3383.3891365528107,2.243925094604492,28604.98014640808,62310,0,28604.98014640808,0.3335000276565552,3.20181655883789,10000,31993.887003183365,0.4643359184265136,2.393752336502075,0.4336199760437011,2.553832054138184,50000 -3432.985024452209,2.2831294536590576,29025.158362150192,63227,0,29025.158362150192,0.3427000045776367,3.135342836380005,10000,32463.748566627502,0.4670117199420929,2.3691554069519043,0.4369799792766571,2.539468050003052,50000 -3483.0788888931274,2.325901746749878,29445.27424812317,64144,0,29445.27424812317,0.3273000121116638,3.223234176635742,10000,32934.04869699478,0.4650976359844208,2.3967294692993164,0.4300599992275238,2.5977704524993896,50000 -3533.843167066574,2.3687186241149902,29865.45224094391,65062,0,29865.45224094391,0.3362000286579132,3.1964430809021,10000,33405.08270573616,0.4678320288658142,2.389435291290283,0.4369399845600128,2.564899682998657,50000 -3584.863637447357,2.4079158306121826,30285.408737182617,65982,0,30285.408737182617,0.3450000286102295,3.2141637802124023,10000,33876.14655208588,0.4744921624660492,2.419340133666992,0.4359599947929382,2.5876500606536865,50000 -3635.3954322338095,2.4538557529449463,30705.5859375,66901,0,30705.5859375,0.340800017118454,3.17934513092041,10000,34346.95048165321,0.4737695157527923,2.3729798793792725,0.44200000166893,2.548090696334839,50000 -3682.5514616966248,2.488304853439331,31125.7574505806,67819,0,31125.7574505806,0.3449000120162964,3.149144172668457,10000,34814.36068201065,0.5081835985183716,2.193055391311645,0.4444599747657776,2.5196900367736816,50000 -3731.446517467499,2.525161027908325,31545.98058009148,68738,0,31545.98058009148,0.3493000268936157,3.140624761581421,10000,35283.56374049187,0.474414050579071,2.358579397201538,0.4458999931812286,2.502725124359131,50000 -3781.5328879356384,2.565286159515381,31966.211909532547,69655,0,31966.211909532547,0.35630002617836,3.0855069160461426,10000,35753.97013711929,0.488085925579071,2.284952402114868,0.4509799778461456,2.4512698650360107,50000 -3829.8283665180206,2.6060354709625244,32386.34180402756,70571,0,32386.34180402756,0.362600028514862,3.076205253601074,10000,36222.48383450508,0.5066015720367432,2.156141519546509,0.4574399888515472,2.423490285873413,50000 -3879.22545838356,2.655019521713257,32806.55907249451,71488,0,32806.55907249451,0.3591000139713287,3.0786125659942627,10000,36692.19506430626,0.4865429699420929,2.26175856590271,0.4581999778747558,2.418018102645874,50000 -3929.594414949417,2.6921746730804443,33226.725727796555,72406,0,33226.725727796555,0.3544000089168548,3.063166618347168,10000,37162.81544137001,0.4870898425579071,2.2414968013763428,0.4551199972629547,2.416773796081543,50000 -3980.205307483673,2.726961135864258,33647.000801324844,73322,0,33647.000801324844,0.3473000228404999,3.1173243522644043,10000,37633.78448009491,0.4894921779632568,2.2488605976104736,0.4532599747180938,2.4476559162139893,50000 -4030.269171476364,2.76204776763916,34067.28379058838,74239,0,34067.28379058838,0.3549000024795532,3.087660074234009,10000,38104.21386909485,0.491503894329071,2.291461944580078,0.4583199918270111,2.455028772354126,50000 -4081.058485507965,2.7998476028442383,34487.50742006302,75158,0,34487.50742006302,0.3680000305175781,3.005544662475586,10000,38575.3131480217,0.5007616877555847,2.170815944671631,0.4699599742889404,2.350716114044189,50000 -4132.322255373001,2.8405230045318604,34907.619475364685,76076,0,34907.619475364685,0.3609000146389007,3.055748462677002,10000,39046.77811384201,0.5032812356948853,2.190779447555542,0.464819997549057,2.4005751609802246,50000 -4181.173963069916,2.8811960220336914,35327.92823219299,76992,0,35327.92823219299,0.3664000034332275,3.04853892326355,10000,39516.02749085426,0.5121288895606995,2.185343027114868,0.4690199792385101,2.402076244354248,50000 -4229.867544412613,2.919562578201294,35747.85677528381,77907,0,35747.85677528381,0.3689000308513641,2.9804775714874268,10000,39984.73551940918,0.5077343583106995,2.158487558364868,0.4768199920654297,2.3143656253814697,50000 -4285.362320184708,2.961005449295044,36168.04482078552,78822,0,36168.04482078552,0.3671000301837921,3.0031578540802,10000,40460.50757360458,0.5114843845367432,2.1231250762939453,0.4710799753665924,2.3353586196899414,50000 -4334.415132522583,3.0005173683166504,36587.97202754021,79740,0,36587.97202754021,0.3686000108718872,3.0116066932678223,10000,40929.57435941696,0.5294140577316284,2.067649602890014,0.465859979391098,2.3905460834503174,50000 -4385.123164892197,3.0408191680908203,37007.93005943298,80656,0,37007.93005943298,0.3669000267982483,3.037405252456665,10000,41400.32828879357,0.4966796636581421,2.233603954315185,0.4697999954223633,2.378688335418701,50000 -4434.969736337662,3.080196142196656,37427.89789867401,81572,0,37427.89789867401,0.3721000254154205,2.969887018203736,10000,41870.229408979416,0.512890636920929,2.137596368789673,0.4765599966049194,2.3293349742889404,50000 -4485.095364332199,3.1229937076568604,37847.978172302246,82491,0,37847.978172302246,0.3731000125408172,2.9452946186065674,10000,42340.52619123459,0.531054675579071,2.0348589420318604,0.4846400022506714,2.283027648925781,50000 -4532.843173027039,3.1619906425476074,38268.08063173294,83408,0,38268.08063173294,0.3836000263690948,2.9029996395111084,10000,42808.46298861504,0.519726574420929,2.083502531051636,0.4893999993801117,2.2483296394348145,50000 -4581.449735164642,3.20169997215271,38688.25997066498,84324,0,38688.25997066498,0.3802000284194946,2.933024883270264,10000,43277.33517932892,0.5274999737739563,2.044956684112549,0.4887399971485138,2.2427096366882324,50000 -4631.371387481689,3.2427725791931152,39108.579641819,85241,0,39108.579641819,0.3765000104904175,2.978649139404297,10000,43747.66602993012,0.5252148509025574,2.1249876022338867,0.4815999865531921,2.3384721279144287,50000 -4684.744953393936,3.281177997589112,39528.60836338997,86158,0,39528.60836338997,0.3702000081539154,2.974332332611084,10000,44221.15460038185,0.515820324420929,2.156745672225952,0.4821399748325348,2.332691192626953,50000 -4734.239243745804,3.324136257171631,39948.69438958168,87075,0,39948.69438958168,0.3924000263214111,2.8767480850219727,10000,44690.825795173645,0.5296288728713989,2.0581154823303223,0.4989999830722809,2.215408563613892,50000 -4786.001025438309,3.3661766052246094,40368.87440466881,87991,0,40368.87440466881,0.3921000063419342,2.863675594329834,10000,45162.857728004456,0.5358788967132568,1.9901221990585327,0.5009599924087524,2.1941733360290527,50000 -4835.186659097672,3.407886266708374,40788.85287356377,88907,0,40788.85287356377,0.3799000084400177,2.949227809906006,10000,45632.11156868935,0.5342382788658142,2.0812058448791504,0.4899799823760986,2.29469895362854,50000 -4883.906123161316,3.450183391571045,41209.05537772179,89821,0,41209.05537772179,0.394400030374527,2.8692476749420166,10000,46101.12329792976,0.5366601347923279,2.041024923324585,0.5012800097465515,2.20486831665039,50000 -4933.602746248245,3.4877588748931885,41629.10823345184,90736,0,41629.10823345184,0.398900032043457,2.8334524631500244,10000,46570.9590845108,0.5470117330551147,1.968786239624024,0.5073800086975098,2.1694068908691406,50000 -4983.74448466301,3.526984214782715,42049.48864984512,91649,0,42049.48864984512,0.3973000049591064,2.799119234085083,10000,47041.567656993866,0.5741991996765137,1.818586349487305,0.5086199641227722,2.1520354747772217,50000 -5034.174078941345,3.568467617034912,42469.72681379318,92566,0,42469.72681379318,0.3943000137805938,2.86732816696167,10000,47512.32449412346,0.5272265672683716,2.088000059127808,0.5004400014877319,2.219738483428955,50000 -5083.820999860764,3.623884439468384,42889.70433783531,93480,0,42889.70433783531,0.393200010061264,2.861738681793213,10000,47982.05200815201,0.5433398485183716,2.01934814453125,0.5062999725341797,2.205066442489624,50000 -5134.835416078568,3.667936563491821,43309.709506988525,94396,0,43309.709506988525,0.3976000249385834,2.820587158203125,10000,48453.16395688057,0.5606250166893005,1.9081244468688965,0.5089399814605713,2.156013965606689,50000 -5185.579132556915,3.7109029293060303,43729.9824256897,95311,0,43729.9824256897,0.4028000235557556,2.7761130332946777,10000,48924.271253585815,0.5488671660423279,1.9369208812713623,0.5136199593544006,2.113115072250366,50000 -5233.184512376785,3.750455856323242,44150.20173883438,96228,0,44150.20173883438,0.4028000235557556,2.821324825286865,10000,49392.18372750282,0.5465624928474426,1.977414846420288,0.5113599896430969,2.1627683639526367,50000 -5282.874450683594,3.79244875907898,44570.48270201683,97146,0,44570.48270201683,0.4019000232219696,2.7724106311798096,10000,49862.24414396286,0.5612499713897705,1.8850711584091189,0.5123400092124939,2.129922866821289,50000 -5328.955847263336,3.831274509429932,44990.81266307831,98058,0,44990.81266307831,0.4170000255107879,2.7410056591033936,10000,50328.74145245552,0.5537304282188416,1.9444429874420168,0.524179995059967,2.098775148391724,50000 -5379.665320634842,3.875601530075073,45410.88789892197,98972,0,45410.88789892197,0.4125000238418579,2.7286581993103027,10000,50799.61802506447,0.56494140625,1.8888856172561648,0.5273999571800232,2.079842090606689,50000 -5429.00506401062,3.9204533100128174,45831.25173521042,99889,0,45831.25173521042,0.4125000238418579,2.748320817947388,10000,51269.41478252411,0.5637304782867432,1.909096121788025,0.5228599905967712,2.1150269508361816,50000 -5477.274906158447,3.9645845890045166,46251.20902395248,100805,0,46251.20902395248,0.42330002784729,2.701802968978882,10000,51737.73347878456,0.5720898509025574,1.8558332920074463,0.532260000705719,2.0450856685638428,50000 -5525.338563919067,4.004071235656738,46671.40980911255,101721,0,46671.40980911255,0.4095000326633453,2.745497703552246,10000,52206.08445620537,0.5673046708106995,1.848193645477295,0.5283600091934204,2.050722360610962,50000 -5574.214259624481,4.048406600952148,47091.384001493454,102640,0,47091.384001493454,0.4227000176906585,2.689002752304077,10000,52675.02685260773,0.5785741806030273,1.818680047988892,0.5371400117874146,2.024113655090332,50000 -5623.107983827591,4.101131916046143,47511.33321595192,103558,0,47511.33321595192,0.41880002617836,2.720669269561768,10000,53143.97084522247,0.5989648103713989,1.7452768087387085,0.5286200046539307,2.076588153839112,50000 -5672.216294765472,4.143033742904663,47931.56638360024,104477,0,47931.56638360024,0.4228000342845917,2.7130990028381348,10000,53613.40177822113,0.5718554854393005,1.8921626806259155,0.5343199968338013,2.075522184371948,50000 -5723.759583473206,4.190583944320679,48351.92719125748,105394,0,48351.92719125748,0.4191000163555145,2.7071774005889893,10000,54085.40121340752,0.5820507407188416,1.837495803833008,0.5375800132751465,2.045927047729492,50000 -5774.5703637599945,4.240931749343872,48772.00656723976,106312,0,48772.00656723976,0.4317000210285187,2.648791074752808,10000,54556.38950634003,0.5904687643051147,1.7310149669647217,0.5409199595451355,1.9848403930664065,50000 -5823.475866556168,4.285040616989136,49192.05952215195,107228,0,49192.05952215195,0.425100028514862,2.690009593963623,10000,55025.439949035645,0.5759570002555847,1.833008050918579,0.537880003452301,2.0171401500701904,50000 -5873.19217467308,4.326111793518066,49612.333136081696,108146,0,49612.333136081696,0.4388000071048736,2.611959934234619,10000,55495.51818156242,0.5908398032188416,1.7584213018417358,0.5523999929428101,1.949994444847107,50000 -5922.949047088623,4.369261980056763,50032.38190603256,109061,0,50032.38190603256,0.4354000091552734,2.6368744373321533,10000,55965.41650009155,0.6046093702316284,1.7196390628814695,0.5508399605751038,1.968966007232666,50000 -5971.910776138306,4.410284757614136,50452.712889909744,109981,0,50452.712889909744,0.4374000132083893,2.614704370498657,10000,56434.7984893322,0.5870702862739563,1.7547787427902222,0.5485199689865112,1.9464657306671145,50000 -6022.969138383865,4.453753709793091,50872.9928958416,110896,0,50872.9928958416,0.4438000321388244,2.5637691020965576,10000,56906.227376937866,0.6010546684265137,1.6854904890060425,0.556879997253418,1.8993265628814693,50000 -6073.349613189697,4.510779142379761,51293.23732328415,111813,0,51293.23732328415,0.4487000107765198,2.552631378173828,10000,57376.957426548,0.606640636920929,1.6578865051269531,0.5602999925613403,1.8957290649414065,50000 -6122.37894487381,4.555319547653198,51713.58249878883,112731,0,51713.58249878883,0.4410000145435333,2.5727462768554688,10000,57846.42418789864,0.6123046875,1.6747193336486816,0.5610399842262268,1.8994790315628047,50000 -6172.862322568893,4.599631071090698,52133.86265182495,113647,0,52133.86265182495,0.4446000158786773,2.5511436462402344,10000,58317.28026175499,0.6024609208106995,1.6888262033462524,0.5626800060272217,1.876355528831482,50000 -6221.115981340408,4.645940542221069,52554.00845098496,114565,0,52554.00845098496,0.4391000270843506,2.578800678253174,10000,58785.77437710762,0.5999413728713989,1.7239134311676023,0.5591199994087219,1.9196312427520752,50000 -6270.487004041672,4.691897869110107,52974.09374523163,115483,0,52974.09374523163,0.4597000181674957,2.50107741355896,10000,59255.32407426834,0.6393163800239563,1.5017653703689575,0.5700199604034424,1.8257381916046145,50000 -6324.160755395889,4.736301422119141,53394.46939063072,116400,0,53394.46939063072,0.4602000117301941,2.501185655593872,10000,59729.46551704407,0.6116992235183716,1.6490353345870972,0.5755999684333801,1.843803882598877,50000 -6375.838165998459,4.7819108963012695,53814.67329096794,117319,0,53814.67329096794,0.4551000297069549,2.4943559169769287,10000,60201.44076490402,0.6172265410423279,1.627258539199829,0.5745399594306946,1.8362805843353271,50000 -6425.695611715317,4.8346474170684814,54234.85326814652,118236,0,54234.85326814652,0.4593000113964081,2.519318103790283,10000,60671.578468084335,0.6234374642372131,1.591255784034729,0.5689799785614014,1.854089856147766,50000 -6475.814856529236,4.881854772567749,54655.12280344963,119150,0,54655.12280344963,0.4622000157833099,2.4543488025665283,10000,61142.06141138077,0.6201952695846558,1.594254732131958,0.5825200080871582,1.778774380683899,50000 -6523.415584564209,4.927415132522583,55075.43723917008,120069,0,55075.43723917008,0.4643000364303589,2.4546687602996826,10000,61610.07016038895,0.6280273199081421,1.5842721462249756,0.5796799659729004,1.7919082641601562,50000 -6573.364508867264,4.976737022399902,55495.7010948658,120987,0,55495.7010948658,0.4634000360965729,2.4435477256774902,10000,62080.37963700295,0.6349999904632568,1.535889744758606,0.584879994392395,1.7815570831298828,50000 -6621.707966804504,5.029550790786743,55915.98540139198,121903,0,55915.98540139198,0.4704000353813171,2.42789888381958,10000,62549.107691049576,0.6355859041213989,1.5495035648345947,0.5896599888801575,1.7570979595184326,50000 -6671.663818836212,5.077637434005737,56336.27511167526,122819,0,56336.27511167526,0.4693000316619873,2.4274723529815674,10000,63019.44883728027,0.6341015696525574,1.5348260402679443,0.5896399617195129,1.748339056968689,50000 -6722.052555799484,5.1239094734191895,56756.223601818085,123736,0,56756.223601818085,0.4792000353336334,2.365442991256714,10000,63489.87978386879,0.6442577838897705,1.493449330329895,0.594539999961853,1.7315579652786257,50000 -6773.475342273712,5.171019077301025,57176.31259250641,124653,0,57176.31259250641,0.4780000150203705,2.39933180809021,10000,63961.48591351509,0.6554492115974426,1.4739404916763306,0.5974400043487549,1.7426960468292236,50000 -6822.102419376373,5.218945741653442,57596.52352762222,125568,0,57596.52352762222,0.4763000309467315,2.396895408630371,10000,64430.41939020157,0.6419726610183716,1.5183041095733645,0.5974000096321106,1.7342219352722168,50000 -6870.285463809967,5.281857252120972,58016.79306435585,126482,0,58016.79306435585,0.4836000204086303,2.356315851211548,10000,64898.98254442215,0.6576171517372131,1.45563805103302,0.6087200045585632,1.6954721212387085,50000 -6917.643349409103,5.328955173492432,58437.100652217865,127400,0,58437.100652217865,0.4881000220775604,2.3263015747070312,10000,65366.74352836609,0.6700586080551147,1.3599334955215454,0.6076599955558777,1.6682051420211792,50000 -6966.18169260025,5.3842246532440186,58857.16923952103,128316,0,58857.16923952103,0.4864000082015991,2.318061113357544,10000,65835.45329618454,0.6534960865974426,1.4608618021011353,0.6089000105857849,1.6636296510696411,50000 -7016.280867099762,5.430881500244141,59277.42556810379,129234,0,59277.42556810379,0.4950000345706939,2.2942965030670166,10000,66305.90271615982,0.6603124737739563,1.4079570770263672,0.6110399961471558,1.6467926502227783,50000 -7064.444483995438,5.477983713150024,59697.67466640472,130154,0,59697.67466640472,0.4908000230789184,2.339881896972656,10000,66774.41073536873,0.6708202958106995,1.3997917175292969,0.6119999885559082,1.674870252609253,50000 -7112.693949460983,5.525101900100708,60117.79187011719,131074,0,60117.79187011719,0.494700014591217,2.2643074989318848,10000,67242.87277555466,0.6654687523841858,1.3992410898208618,0.6183599829673767,1.6063157320022583,50000 -7161.525184869766,5.570846319198608,60537.90219092369,131991,0,60537.90219092369,0.4973000288009643,2.2747392654418945,10000,67711.90890932083,0.6669335961341858,1.3916473388671875,0.6188600063323975,1.623449206352234,50000 -7210.376707792282,5.624185085296631,60957.94920253754,132910,0,60957.94920253754,0.4974000155925751,2.2841172218322754,10000,68180.9085547924,0.6772265434265137,1.361317157745361,0.620959997177124,1.627536654472351,50000 -7261.099878549576,5.673567771911621,61378.31641602516,133828,0,61378.31641602516,0.501300036907196,2.24080491065979,10000,68652.09668803215,0.6712890267372131,1.365814447402954,0.6225999593734741,1.5908536911010742,50000 -7312.5014128685,5.723673582077026,61798.59408092499,134743,0,61798.59408092499,0.4987000226974487,2.2690556049346924,10000,69123.8744187355,0.6692187190055847,1.3907537460327148,0.6218199729919434,1.610404133796692,50000 -7360.451724529266,5.769453287124634,62218.69186377525,135622,0,62218.69186377525,0.5057000517845154,2.21661114692688,10000,69592.01400113106,0.6849804520606995,1.3148226737976074,0.6287999749183655,1.5665632486343384,50000 -7410.717103242874,5.824112415313721,62638.98313713074,136534,0,62638.98313713074,0.5059000253677368,2.23021936416626,10000,70062.6733353138,0.6810937523841858,1.350616216659546,0.6312800049781799,1.5826791524887085,50000 -7459.330714464188,5.873344898223877,63059.25512361527,137447,0,63059.25512361527,0.511900007724762,2.209479093551636,10000,70531.656188488,0.6771484017372131,1.363439917564392,0.6318599581718445,1.5701099634170532,50000 -7508.541209936142,5.927834272384644,63479.41015410423,138358,0,63479.41015410423,0.51500004529953,2.1850335597991943,10000,71001.12347507477,0.6902929544448853,1.3037259578704834,0.6375799775123596,1.5405707359313965,50000 -7560.420810461044,5.9827258586883545,63899.503361940384,139272,0,63899.503361940384,0.5205000042915344,2.161705732345581,10000,71473.19856882095,0.7201562523841858,1.162137746810913,0.6441999673843384,1.5044505596160889,50000 -7610.637903690338,6.034387826919556,64319.68514704704,140189,0,64319.68514704704,0.52510005235672,2.134714841842652,10000,71943.69709396362,0.6912499666213989,1.2617908716201782,0.6451199650764465,1.4823774099349976,50000 -7661.540041446686,6.084909915924072,64739.68468880653,141109,0,64739.68468880653,0.5254999995231628,2.115793943405152,10000,72414.69781684875,0.7030664086341858,1.2258492708206177,0.6487999558448792,1.4680681228637695,50000 -7710.902366638184,6.13739275932312,65159.91322731972,142026,0,65159.91322731972,0.5259000062942505,2.140204906463623,10000,72884.38906359673,0.712695300579071,1.1884727478027344,0.6464200019836426,1.488558292388916,50000 -7759.4170508384705,6.186307430267334,65580.24869775772,142943,0,65580.24869775772,0.5232000350952148,2.109903573989868,10000,73353.33665847778,0.7005664110183716,1.2247933149337769,0.6492999792098999,1.4643806219100952,50000 -7807.887068033218,6.239961385726929,66000.38722920418,143859,0,66000.38722920418,0.5297000408172607,2.113292932510376,10000,73822.04728531837,0.7032226324081421,1.2361243963241575,0.6534000039100647,1.4610118865966797,50000 -7858.054362535477,6.291686773300171,66420.68016719818,144776,0,66420.68016719818,0.5299000144004822,2.11555290222168,10000,74292.60705327988,0.7193945050239563,1.1607738733291626,0.654259979724884,1.447080135345459,50000 -7905.415101289749,6.340944766998291,66840.85182523727,145690,0,66840.85182523727,0.5415000319480896,2.0482189655303955,10000,74760.23641347885,0.713671863079071,1.1829522848129272,0.6609799861907959,1.417777180671692,50000 -7955.480447292328,6.393074035644531,67260.99752354622,146605,0,67260.99752354622,0.542900025844574,2.0617923736572266,10000,75230.54712605476,0.7144335508346558,1.191360592842102,0.6621999740600586,1.4254510402679443,50000 -8003.546766996384,6.440312385559082,67681.06917715073,147522,0,67681.06917715073,0.5448000431060791,2.0465872287750244,10000,75698.780200243,0.7241796851158142,1.1309359073638916,0.6642000079154968,1.4086828231811523,50000 -8053.31559920311,6.489897012710571,68101.14979958534,148440,0,68101.14979958534,0.5482000112533569,2.010650396347046,10000,76168.72648119926,0.7292773127555847,1.113935470581055,0.6719399690628052,1.3779354095458984,50000 -8103.527039289474,6.53973126411438,68521.11470293999,149356,0,68521.11470293999,0.5427000522613525,2.032515525817871,10000,76639.01660203934,0.7234765291213989,1.1463912725448608,0.6699000000953674,1.3762966394424438,50000 -8155.512540578842,6.593385457992554,68941.08154058456,150270,0,68941.08154058456,0.5507000088691711,1.992889165878296,10000,77111.07012796402,0.7346875071525574,1.091256856918335,0.6761800050735474,1.3612442016601562,50000 -8206.480012178421,6.650951862335205,69361.19805550575,151188,0,69361.19805550575,0.5574000477790833,1.9849852323532104,10000,77582.26147294044,0.75537109375,1.024010181427002,0.6801199913024902,1.350976824760437,50000 -8255.09532880783,6.703702688217163,69781.2301557064,152104,0,69781.2301557064,0.556600034236908,1.947913408279419,10000,78051.009329319,0.7370507717132568,1.0689661502838137,0.681659996509552,1.3200819492340088,50000 -8303.366862535477,6.754567861557007,70201.16539907455,153020,0,70201.16539907455,0.5614000558853149,1.9452199935913088,10000,78519.31473040581,0.7427343726158142,1.0332081317901611,0.6832399964332581,1.3128416538238523,50000 -8353.252356290817,6.802517414093018,70621.41844844818,153933,0,70621.41844844818,0.5550000071525574,1.9567975997924805,10000,78989.54914736748,0.7536913752555847,1.0018309354782104,0.6841399669647217,1.3143833875656128,50000 -8402.332137584686,6.852922439575195,71041.63038349152,154850,0,71041.63038349152,0.5651000142097473,1.925018310546875,10000,79458.93891525269,0.74916011095047,1.0158547163009644,0.688539981842041,1.2876722812652588,50000 -8452.610144615173,6.913311243057251,71461.77147507668,155765,0,71461.77147507668,0.5700000524520874,1.9121655225753784,10000,79929.46589922905,0.7540234327316284,0.9922083616256714,0.6906799674034119,1.2663934230804443,50000 -8501.413291931152,6.968406200408936,71881.82072901726,156680,0,71881.82072901726,0.5699000358581543,1.89507257938385,10000,80398.42122769356,0.7607421875,0.967012107372284,0.6918999552726746,1.274028182029724,50000 -8550.524432182312,7.022604942321777,72301.9625506401,157595,0,72301.9625506401,0.5837000012397766,1.8638801574707031,10000,80867.77632331848,0.7591796517372131,0.9821126461029052,0.6990199685096741,1.2495800256729126,50000 -8599.948066473007,7.076277017593384,72722.21183228493,158510,0,72722.21183228493,0.57750004529953,1.8460190296173096,10000,81337.55066609383,0.7641015648841858,0.9493613243103028,0.7014600038528442,1.2222778797149658,50000 -8651.003014564514,7.126868724822998,73142.33656525612,159425,0,73142.33656525612,0.5818000435829163,1.855764389038086,10000,81808.82894182205,0.7665234208106995,0.938743770122528,0.7026199698448181,1.223870873451233,50000 -8700.326808214188,7.178597688674927,73562.57827568054,160338,0,73562.57827568054,0.5848000049591064,1.84183931350708,10000,82278.49391198158,0.7646874785423279,0.9522948861122132,0.7034400105476379,1.2250497341156006,50000 -8750.995544433594,7.237669229507446,73982.7239575386,161254,0,73982.7239575386,0.5820000171661377,1.835293412208557,10000,82749.41592168808,0.7704296708106995,0.934205174446106,0.7053999900817871,1.2130885124206543,50000 -8801.265083551407,7.292815208435059,74402.66515946388,162170,0,74402.66515946388,0.5907000303268433,1.818153619766236,10000,83219.72941589355,0.7777734398841858,0.8995881676673889,0.7101399898529053,1.1974650621414185,50000 -8849.700415611267,7.344009160995483,74822.65410661697,163085,0,74822.65410661697,0.5903000235557556,1.810662150382996,10000,83688.25251555443,0.7899999618530273,0.8373759984970093,0.7140799760818481,1.1738150119781494,50000 -8898.0752389431,7.395813226699829,75243.34757304192,164001,0,75243.34757304192,0.5924000144004822,1.779520034790039,10000,84157.42079758644,0.7776171565055847,0.8835697174072266,0.7133199572563171,1.1664485931396484,50000 -8947.53544473648,7.44543981552124,75663.35647845268,164920,0,75663.35647845268,0.5914000272750854,1.803371548652649,10000,84626.98712992668,0.7840234041213989,0.8810456395149231,0.7141799926757812,1.1800652742385864,50000 -8997.781407117844,7.497044086456299,76083.54992222786,165837,0,76083.54992222786,0.5946000218391418,1.770185470581055,10000,85097.52592563629,0.7928906083106995,0.8252041339874268,0.7172600030899048,1.158810257911682,50000 -9048.111221551895,7.553990364074707,76503.78957104683,166753,0,76503.78957104683,0.6046000123023987,1.7537925243377686,10000,85568.19960737228,0.7876757383346558,0.8453723788261414,0.7234599590301514,1.131539225578308,50000 -9098.598886728289,7.613273620605469,76924.13359379768,167671,0,76924.13359379768,0.6041000485420227,1.7299996614456177,10000,86039.13908457756,0.7952343821525574,0.8193172812461853,0.7249599695205688,1.1146950721740725,50000 -9149.558176994324,7.669828414916992,77344.50248265266,168589,0,77344.50248265266,0.6045000553131104,1.7418795824050903,10000,86510.57143831253,0.8006640672683716,0.809798002243042,0.7245999574661255,1.1315559148788452,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/measurements.csv deleted file mode 100644 index 91799d3a8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1877 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36554998,6.907756,,,,,,,,,,,,,, -1,,,0.0009570312104187,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,36.15118789672852,65.77166390419006,36.15118789672852,29.620367527008057,0.0,0.0 -100,0.6376575,6.8474193,,,,,,,,,,,,,, -200,0.9279347,6.70249,,,,,,,,,,,,,, -300,1.3625321,6.573474,,,,,,,,,,,,,, -400,1.0964962,6.534064,,,,,,,,,,,,,, -500,0.8318119,6.685946,,,,,,,,,,,,,, -600,0.91517615,6.439058,,,,,,,,,,,,,, -700,0.98713154,6.2851467,,,,,,,,,,,,,, -800,0.9208433,6.6215487,,,,,,,,,,,,,, -861,,,0.0339648425579071,5.889985084533691,0.0312199983745813,5.931602478027344,50000.0,0.0265000015497207,6.062020778656006,10000.0,456.1083743572235,532.0756976604462,456.1083743572235,75.90391826629639,0.0172796249389648,0.0 -900,0.61613315,6.4734955,,,,,,,,,,,,,, -1000,0.65019625,6.099498,,,,,,,,,,,,,, -1100,0.6537315,6.756523,,,,,,,,,,,,,, -1200,0.60398716,6.0911007,,,,,,,,,,,,,, -1300,0.92375195,6.799857,,,,,,,,,,,,,, -1400,0.73920393,6.235941,,,,,,,,,,,,,, -1500,0.6416048,6.285389,,,,,,,,,,,,,, -1600,0.6248527,5.9017515,,,,,,,,,,,,,, -1700,0.689465,5.8360276,,,,,,,,,,,,,, -1774,,,0.0719531252980232,5.346263408660889,0.0669599995017051,5.416959285736084,50000.0,0.0535000041127204,5.638273239135742,10000.0,876.1209945678711,1000.3806068897248,876.1209945678711,124.1122748851776,0.0530307292938232,0.0 -1800,0.6179777,5.966546,,,,,,,,,,,,,, -1900,0.48420343,6.4138803,,,,,,,,,,,,,, -2000,0.4727131,6.5348234,,,,,,,,,,,,,, -2100,0.486954,5.756413,,,,,,,,,,,,,, -2200,0.38613865,6.2226305,,,,,,,,,,,,,, -2300,0.5189235,5.791856,,,,,,,,,,,,,, -2400,0.59566104,5.6011972,,,,,,,,,,,,,, -2500,0.42762977,6.4163656,,,,,,,,,,,,,, -2600,0.43809092,6.1235876,,,,,,,,,,,,,, -2694,,,0.1066601574420929,4.961225986480713,0.0988399982452392,5.016429424285889,50000.0,0.0766000002622604,5.27877950668335,10000.0,1296.1530268192291,1474.3475806713104,1296.1530268192291,177.97129225730896,0.0809743404388427,0.0 -2700,0.48627895,5.894249,,,,,,,,,,,,,, -2800,0.4877201,5.537541,,,,,,,,,,,,,, -2900,0.49895513,5.608197,,,,,,,,,,,,,, -3000,0.5830241,5.4431953,,,,,,,,,,,,,, -3100,0.5260329,5.3606043,,,,,,,,,,,,,, -3200,0.38235047,6.4976096,,,,,,,,,,,,,, -3300,0.5685757,5.426673,,,,,,,,,,,,,, -3400,0.608595,5.845279,,,,,,,,,,,,,, -3500,0.66932297,5.449258,,,,,,,,,,,,,, -3600,0.5080203,5.394008,,,,,,,,,,,,,, -3614,,,0.1257812529802322,4.762739181518555,0.1183599978685379,4.825972557067871,50000.0,0.0958000048995018,5.121191024780273,10000.0,1716.2821543216703,1942.9452345371249,1716.2821543216703,226.3655271530152,0.1077020168304443,0.0 -3700,0.63862216,5.3676987,,,,,,,,,,,,,, -3800,0.4736885,5.2934566,,,,,,,,,,,,,, -3900,0.80076975,5.220756,,,,,,,,,,,,,, -4000,0.567715,6.3822045,,,,,,,,,,,,,, -4100,0.7429167,6.3986564,,,,,,,,,,,,,, -4200,0.613958,5.3319845,,,,,,,,,,,,,, -4300,0.5605512,6.336531,,,,,,,,,,,,,, -4400,0.7766954,5.1601143,,,,,,,,,,,,,, -4500,0.54219013,5.993781,,,,,,,,,,,,,, -4529,,,0.1682031154632568,4.37608528137207,0.1546199917793274,4.472805976867676,50000.0,0.1157000064849853,4.830338478088379,10000.0,2136.238264322281,2407.637862443924,2136.238264322281,271.0300600528717,0.1323871612548828,0.0 -4600,0.7793062,5.0792885,,,,,,,,,,,,,, -4700,0.61015135,5.175837,,,,,,,,,,,,,, -4800,0.67751074,4.907221,,,,,,,,,,,,,, -4900,0.6693233,5.769417,,,,,,,,,,,,,, -5000,0.7370275,5.597152,,,,,,,,,,,,,, -5100,0.7471035,5.181202,,,,,,,,,,,,,, -5200,0.65447176,4.9436164,,,,,,,,,,,,,, -5300,0.7492952,4.9493437,,,,,,,,,,,,,, -5400,0.6197304,4.926573,,,,,,,,,,,,,, -5445,,,0.1838281154632568,4.257326126098633,0.1714800000190735,4.334697246551514,50000.0,0.1330000013113021,4.718831539154053,10000.0,2556.621108531952,2877.380870580673,2556.621108531952,320.3178098201752,0.1565053462982177,0.0 -5500,0.6934272,4.836215,,,,,,,,,,,,,, -5600,0.57345986,6.25273,,,,,,,,,,,,,, -5700,0.5569326,6.269748,,,,,,,,,,,,,, -5800,0.9343688,5.1015196,,,,,,,,,,,,,, -5900,0.8220937,4.6870136,,,,,,,,,,,,,, -6000,0.87090427,4.952626,,,,,,,,,,,,,, -6100,0.8615638,4.830052,,,,,,,,,,,,,, -6200,0.8438536,5.193898,,,,,,,,,,,,,, -6300,0.8845249,6.3138328,,,,,,,,,,,,,, -6364,,,0.2262499928474426,3.872288942337036,0.2107999920845031,3.977044343948364,50000.0,0.1616000086069107,4.43961763381958,10000.0,2976.629909515381,3347.430968284607,2976.629909515381,370.2815811634064,0.1853024959564209,0.0 -6400,0.8292652,5.724516,,,,,,,,,,,,,, -6500,0.81489164,4.970718,,,,,,,,,,,,,, -6600,0.7192458,4.909964,,,,,,,,,,,,,, -6700,0.7408788,5.1461706,,,,,,,,,,,,,, -6800,0.80557775,4.953672,,,,,,,,,,,,,, -6900,0.69694966,4.706662,,,,,,,,,,,,,, -7000,1.0280602,4.829719,,,,,,,,,,,,,, -7100,0.8235598,4.585063,,,,,,,,,,,,,, -7200,0.88808393,4.7722235,,,,,,,,,,,,,, -7280,,,0.2319921851158142,3.833322286605835,0.212359994649887,3.964530467987061,50000.0,0.1700000017881393,4.424200534820557,10000.0,3396.913996696472,3817.68110537529,3396.913996696472,420.16905403137207,0.2169387340545654,0.0 -7300,0.77992857,6.3360076,,,,,,,,,,,,,, -7400,1.0178779,4.7342834,,,,,,,,,,,,,, -7500,0.9958114,4.5851235,,,,,,,,,,,,,, -7600,0.6923848,4.8333707,,,,,,,,,,,,,, -7700,0.7708591,4.6020355,,,,,,,,,,,,,, -7800,0.80691594,4.6075206,,,,,,,,,,,,,, -7900,0.7939662,5.004179,,,,,,,,,,,,,, -8000,0.9527178,6.2454686,,,,,,,,,,,,,, -8100,0.837816,4.7599344,,,,,,,,,,,,,, -8200,,,0.2553320229053497,3.758265256881714,0.2200799882411956,3.949136257171631,50000.0,0.1718000024557113,4.393675327301025,10000.0,3816.8563516139984,4286.148283243179,3816.8563516139984,468.6132698059082,0.2487246990203857,0.0 -8200,0.78018,5.4983625,,,,,,,,,,,,,, -8300,0.9451501,4.6773314,,,,,,,,,,,,,, -8400,0.771885,4.5388927,,,,,,,,,,,,,, -8500,0.97966254,4.5537815,,,,,,,,,,,,,, -8600,0.73140335,5.3341527,,,,,,,,,,,,,, -8700,0.740714,5.981064,,,,,,,,,,,,,, -8800,0.75722945,4.4553633,,,,,,,,,,,,,, -8900,0.65536946,5.8713694,,,,,,,,,,,,,, -9000,0.6928795,4.7842875,,,,,,,,,,,,,, -9100,0.98346287,4.551713,,,,,,,,,,,,,, -9118,,,0.267871081829071,3.58646821975708,0.2466599941253662,3.7260239124298096,50000.0,0.1919000148773193,4.206649303436279,10000.0,4236.834023237228,4752.014056444168,4236.834023237228,514.4226930141449,0.2796754837036133,0.0 -9200,0.98473793,4.5248203,,,,,,,,,,,,,, -9300,0.64025974,5.6616974,,,,,,,,,,,,,, -9400,1.0258626,4.4380293,,,,,,,,,,,,,, -9500,0.7587928,6.2031302,,,,,,,,,,,,,, -9600,0.6402052,5.9192233,,,,,,,,,,,,,, -9700,0.7487326,6.0509233,,,,,,,,,,,,,, -9800,0.95399666,4.401967,,,,,,,,,,,,,, -9900,0.68491197,6.1075554,,,,,,,,,,,,,, -10000,0.92754054,4.38716,,,,,,,,,,,,,, -10034,,,0.2818945348262787,3.467654228210449,0.2607599794864654,3.606190204620361,50000.0,0.2006000131368637,4.112731456756592,10000.0,4656.83242058754,5220.361520528793,4656.83242058754,562.69517993927,0.3083176612854004,0.0 -10100,0.8426502,6.1239104,,,,,,,,,,,,,, -10200,0.63765955,5.505297,,,,,,,,,,,,,, -10300,0.8334407,5.4704733,,,,,,,,,,,,,, -10400,0.81664777,5.0901194,,,,,,,,,,,,,, -10500,0.81568384,4.9763765,,,,,,,,,,,,,, -10600,0.79321957,4.400794,,,,,,,,,,,,,, -10700,1.2955295,4.557459,,,,,,,,,,,,,, -10800,0.9925716,4.4072495,,,,,,,,,,,,,, -10900,0.685854,5.3608055,,,,,,,,,,,,,, -10952,,,0.294726550579071,3.4312899112701416,0.259660005569458,3.6437604427337646,50000.0,0.1982000023126602,4.1670637130737305,10000.0,5076.973096132278,5687.518299818039,5076.973096132278,609.6357562541962,0.3355536460876465,0.0 -11000,0.86458015,4.3052683,,,,,,,,,,,,,, -11100,0.97355413,4.309908,,,,,,,,,,,,,, -11200,0.86818385,4.5307183,,,,,,,,,,,,,, -11300,0.7634065,4.351643,,,,,,,,,,,,,, -11400,0.6709948,6.0298457,,,,,,,,,,,,,, -11500,0.9221734,5.0856075,,,,,,,,,,,,,, -11600,0.9248375,4.3577847,,,,,,,,,,,,,, -11700,0.8079269,5.466876,,,,,,,,,,,,,, -11800,0.9284423,4.228136,,,,,,,,,,,,,, -11870,,,0.2876171767711639,3.442207336425781,0.2721000015735626,3.556260108947754,50000.0,0.2055000066757202,4.087541103363037,10000.0,5497.2170152664185,6156.984508275986,5497.2170152664185,658.7836654186249,0.3622803688049316,0.0 -11900,0.63965183,6.012583,,,,,,,,,,,,,, -12000,0.9309958,4.202846,,,,,,,,,,,,,, -12100,1.0304483,4.8660064,,,,,,,,,,,,,, -12200,0.7830263,4.385379,,,,,,,,,,,,,, -12300,0.97423893,4.289025,,,,,,,,,,,,,, -12400,0.8531801,4.4195633,,,,,,,,,,,,,, -12500,0.89893407,4.2261267,,,,,,,,,,,,,, -12600,1.0137416,4.7468843,,,,,,,,,,,,,, -12700,1.0907513,4.3592668,,,,,,,,,,,,,, -12789,,,0.2928906083106994,3.4402520656585693,0.2694199979305267,3.5773630142211914,50000.0,0.2110000103712082,4.108162879943848,10000.0,5917.322836399078,6620.591611146927,5917.322836399078,702.2066161632538,0.3930139541625976,0.0 -12800,0.88607335,5.51733,,,,,,,,,,,,,, -12900,0.7769704,4.1547985,,,,,,,,,,,,,, -13000,1.0288947,4.281943,,,,,,,,,,,,,, -13100,0.80523103,4.3066893,,,,,,,,,,,,,, -13200,1.0475564,4.339219,,,,,,,,,,,,,, -13300,0.8405435,4.195071,,,,,,,,,,,,,, -13400,0.82688797,6.182556,,,,,,,,,,,,,, -13500,0.8488722,4.227752,,,,,,,,,,,,,, -13600,0.81642157,5.5879235,,,,,,,,,,,,,, -13700,0.6264162,5.4182014,,,,,,,,,,,,,, -13707,,,0.3098046779632568,3.3494856357574463,0.2825199961662292,3.517797946929932,50000.0,0.2140000164508819,4.01960039138794,10000.0,6337.505370855331,7088.974200248718,6337.505370855331,750.3280100822449,0.4236247539520263,0.0 -13800,0.8537352,5.1045055,,,,,,,,,,,,,, -13900,0.6541437,5.841221,,,,,,,,,,,,,, -14000,0.9575193,4.238019,,,,,,,,,,,,,, -14100,1.1056448,4.3412256,,,,,,,,,,,,,, -14200,0.95427877,4.223213,,,,,,,,,,,,,, -14300,0.7688453,5.3547397,,,,,,,,,,,,,, -14400,0.93181854,4.4184904,,,,,,,,,,,,,, -14500,0.89799273,4.094774,,,,,,,,,,,,,, -14600,0.8695027,4.2086835,,,,,,,,,,,,,, -14625,,,0.3052343726158142,3.417556047439575,0.2822999954223633,3.541224956512451,50000.0,0.2201000154018402,4.021054744720459,10000.0,6757.771499872208,7557.805434465408,6757.771499872208,798.8193001747131,0.4494767189025879,0.0 -14700,0.7814907,4.041311,,,,,,,,,,,,,, -14800,0.85130286,4.244982,,,,,,,,,,,,,, -14900,0.7253838,6.2244124,,,,,,,,,,,,,, -15000,0.83228326,4.9402804,,,,,,,,,,,,,, -15100,0.8544488,4.133041,,,,,,,,,,,,,, -15200,0.68997055,4.6237664,,,,,,,,,,,,,, -15300,0.7884593,4.5076094,,,,,,,,,,,,,, -15400,0.8385951,4.0942993,,,,,,,,,,,,,, -15500,0.75073266,5.818883,,,,,,,,,,,,,, -15539,,,0.3110156059265136,3.333627939224243,0.290719985961914,3.459397315979004,50000.0,0.226500004529953,3.976407527923584,10000.0,7177.714713335037,8025.688539028168,7177.714713335037,846.682626247406,0.4787595272064209,0.0 -15600,0.76988745,5.1039557,,,,,,,,,,,,,, -15700,0.809323,4.128466,,,,,,,,,,,,,, -15800,0.77995104,6.005169,,,,,,,,,,,,,, -15900,0.88015157,4.557086,,,,,,,,,,,,,, -16000,0.8360379,4.881337,,,,,,,,,,,,,, -16100,0.67848563,5.987008,,,,,,,,,,,,,, -16200,0.8881146,4.272724,,,,,,,,,,,,,, -16300,0.8461256,4.3618007,,,,,,,,,,,,,, -16400,1.3164814,4.29789,,,,,,,,,,,,,, -16458,,,0.3293554484844208,3.1784443855285645,0.3029599785804748,3.3324337005615234,50000.0,0.2324000149965286,3.872088432312012,10000.0,7597.749115467071,8495.327637910843,7597.749115467071,896.20419049263,0.5077810287475586,0.0 -16500,0.845236,4.0134068,,,,,,,,,,,,,, -16600,1.1483454,4.283595,,,,,,,,,,,,,, -16700,0.8213644,4.0484467,,,,,,,,,,,,,, -16800,1.0365736,4.333249,,,,,,,,,,,,,, -16900,0.76535636,5.212039,,,,,,,,,,,,,, -17000,0.883125,4.193965,,,,,,,,,,,,,, -17100,0.9606233,4.0497227,,,,,,,,,,,,,, -17200,0.9626182,4.3467493,,,,,,,,,,,,,, -17300,1.1423011,4.126547,,,,,,,,,,,,,, -17377,,,0.3253906071186065,3.243004083633423,0.3007999956607818,3.385110855102539,50000.0,0.2331000119447708,3.922811985015869,10000.0,8017.729243755341,8965.22508430481,8017.729243755341,946.04585814476,0.5358819961547852,0.0 -17400,0.89564437,4.8603177,,,,,,,,,,,,,, -17500,1.0819707,4.128632,,,,,,,,,,,,,, -17600,0.8433224,4.1304564,,,,,,,,,,,,,, -17700,0.9610574,4.202442,,,,,,,,,,,,,, -17800,1.1958411,4.3832746,,,,,,,,,,,,,, -17900,0.8539858,4.2055798,,,,,,,,,,,,,, -18000,0.9481638,4.2328095,,,,,,,,,,,,,, -18100,0.7551279,6.008916,,,,,,,,,,,,,, -18200,1.1383364,4.107272,,,,,,,,,,,,,, -18295,,,0.3461523354053497,3.063856363296509,0.3191399872303009,3.214115619659424,50000.0,0.2424000054597854,3.785786390304565,10000.0,8438.02011179924,9434.594047546389,8438.02011179924,995.0433971881866,0.5683629512786865,0.0 -18300,0.938798,4.217078,,,,,,,,,,,,,, -18400,0.88347775,6.063612,,,,,,,,,,,,,, -18500,1.0941256,4.212214,,,,,,,,,,,,,, -18600,0.99921143,4.097374,,,,,,,,,,,,,, -18700,0.9384707,4.2645617,,,,,,,,,,,,,, -18800,1.1569686,4.1012344,,,,,,,,,,,,,, -18900,0.9095814,4.082232,,,,,,,,,,,,,, -19000,0.86148,3.9459352,,,,,,,,,,,,,, -19100,0.9296887,5.5647225,,,,,,,,,,,,,, -19200,0.86188763,4.3070273,,,,,,,,,,,,,, -19213,,,0.3359765410423279,3.1254475116729736,0.3057000041007995,3.2792983055114746,50000.0,0.2399000078439712,3.832469701766968,10000.0,8858.02958726883,9906.623125314713,8858.02958726883,1046.9851393699646,0.5980954170227051,0.0 -19300,0.90426815,4.0965858,,,,,,,,,,,,,, -19400,0.91844064,4.47684,,,,,,,,,,,,,, -19500,1.3079354,4.1750255,,,,,,,,,,,,,, -19600,1.1571714,4.276084,,,,,,,,,,,,,, -19700,0.90952826,4.0777974,,,,,,,,,,,,,, -19800,0.79575783,4.2479615,,,,,,,,,,,,,, -19900,0.969355,4.2685966,,,,,,,,,,,,,, -20000,0.93140143,3.9859586,,,,,,,,,,,,,, -20100,0.88576573,4.928349,,,,,,,,,,,,,, -20132,,,0.3737109303474426,2.931885004043579,0.3187199831008911,3.236717939376831,50000.0,0.241100013256073,3.792826414108277,10000.0,9278.196017742155,10376.46831059456,9278.196017742155,1096.579419374466,0.625751256942749,0.0 -20200,0.74257046,6.0954413,,,,,,,,,,,,,, -20300,0.87882143,4.768694,,,,,,,,,,,,,, -20400,0.7791446,6.0453477,,,,,,,,,,,,,, -20500,0.6605992,5.9920793,,,,,,,,,,,,,, -20600,0.880063,4.035149,,,,,,,,,,,,,, -20700,0.8505324,3.9588237,,,,,,,,,,,,,, -20800,0.81649566,5.597704,,,,,,,,,,,,,, -20900,0.79024994,4.780726,,,,,,,,,,,,,, -21000,0.7250906,5.5351233,,,,,,,,,,,,,, -21052,,,0.3433007597923279,3.070847988128662,0.3197799921035766,3.210434913635254,50000.0,0.2482000142335891,3.769993543624878,10000.0,9698.436318159103,10848.208649396896,9698.436318159103,1148.0055141448977,0.6522841453552246,0.0 -21100,0.9515346,4.043995,,,,,,,,,,,,,, -21200,0.90108573,4.2084723,,,,,,,,,,,,,, -21300,0.66345674,6.1109376,,,,,,,,,,,,,, -21400,0.82205,4.700605,,,,,,,,,,,,,, -21500,0.7605357,6.0159016,,,,,,,,,,,,,, -21600,0.7031479,5.8962936,,,,,,,,,,,,,, -21700,0.66869664,5.793374,,,,,,,,,,,,,, -21800,0.9158125,4.1911454,,,,,,,,,,,,,, -21900,1.0263909,4.276331,,,,,,,,,,,,,, -21964,,,0.3420703113079071,3.094113349914551,0.3227799832820892,3.226070165634156,50000.0,0.2415000051259994,3.799169778823853,10000.0,10118.706419706345,11317.737473249435,10118.706419706345,1197.1896917819977,0.6798815727233887,0.0 -22000,0.8872069,4.0828466,,,,,,,,,,,,,, -22100,0.93815464,4.245524,,,,,,,,,,,,,, -22200,0.7837553,4.8687625,,,,,,,,,,,,,, -22300,1.0074724,4.001991,,,,,,,,,,,,,, -22400,0.7902895,5.6796803,,,,,,,,,,,,,, -22500,0.87028843,4.6818824,,,,,,,,,,,,,, -22600,0.7784167,4.3817387,,,,,,,,,,,,,, -22700,0.9006351,4.001855,,,,,,,,,,,,,, -22800,0.8595797,4.4804335,,,,,,,,,,,,,, -22880,,,0.36865234375,2.973356008529663,0.3253600001335144,3.211477041244507,50000.0,0.2574000060558319,3.756904125213623,10000.0,10538.827253580092,11786.644594669342,10538.827253580092,1245.8942565917969,0.7139434814453125,0.0 -22900,0.8665457,4.258939,,,,,,,,,,,,,, -23000,0.9754269,4.0098386,,,,,,,,,,,,,, -23100,0.9364333,4.0146065,,,,,,,,,,,,,, -23200,0.9273937,4.058566,,,,,,,,,,,,,, -23300,0.7488321,5.0356917,,,,,,,,,,,,,, -23400,1.1181095,4.0734677,,,,,,,,,,,,,, -23500,1.0191234,5.96643,,,,,,,,,,,,,, -23600,0.96609384,4.6283693,,,,,,,,,,,,,, -23700,0.9533028,3.864932,,,,,,,,,,,,,, -23795,,,0.3464648425579071,3.0866899490356445,0.3257800042629242,3.2104380130767822,50000.0,0.2494000047445297,3.7716026306152335,10000.0,10958.82033610344,12256.103848218918,10958.82033610344,1295.278584241867,0.7482728958129883,0.0 -23800,0.9531725,4.291996,,,,,,,,,,,,,, -23900,0.9207655,4.084735,,,,,,,,,,,,,, -24000,0.8794355,5.0585318,,,,,,,,,,,,,, -24100,0.94469947,4.217495,,,,,,,,,,,,,, -24200,0.99414665,4.7696023,,,,,,,,,,,,,, -24300,0.9532767,4.0238905,,,,,,,,,,,,,, -24400,0.7135678,5.4240165,,,,,,,,,,,,,, -24500,1.1119447,4.0695405,,,,,,,,,,,,,, -24600,0.87578785,3.9665277,,,,,,,,,,,,,, -24700,1.2010821,5.2279453,,,,,,,,,,,,,, -24713,,,0.3595702946186065,2.995157718658448,0.333079993724823,3.153878688812256,50000.0,0.2598000168800354,3.707412242889404,10000.0,11378.957997083664,12725.690438747406,11378.957997083664,1344.6447319984436,0.7822833061218262,0.0 -24800,0.7712962,6.011901,,,,,,,,,,,,,, -24900,0.84921944,5.75626,,,,,,,,,,,,,, -25000,0.70611435,6.113511,,,,,,,,,,,,,, -25100,0.89422584,4.5345488,,,,,,,,,,,,,, -25200,0.8161853,5.0252957,,,,,,,,,,,,,, -25300,0.7684935,5.824777,,,,,,,,,,,,,, -25400,0.98879784,4.021473,,,,,,,,,,,,,, -25500,1.107937,4.0618553,,,,,,,,,,,,,, -25600,0.7509744,5.2239184,,,,,,,,,,,,,, -25628,,,0.3726952970027923,2.9379703998565674,0.3400200009346008,3.115710020065308,50000.0,0.2576000094413757,3.701142311096192,10000.0,11799.026058912275,13195.52159357071,11799.026058912275,1394.3293986320496,0.8127717971801758,0.0 -25700,0.9697219,4.2147512,,,,,,,,,,,,,, -25800,0.95539826,4.1109796,,,,,,,,,,,,,, -25900,0.9415759,4.220256,,,,,,,,,,,,,, -26000,1.0481058,3.848192,,,,,,,,,,,,,, -26100,0.8834505,6.088268,,,,,,,,,,,,,, -26200,1.0292395,4.3063626,,,,,,,,,,,,,, -26300,1.0025582,3.8942864,,,,,,,,,,,,,, -26400,0.8437461,4.6495037,,,,,,,,,,,,,, -26500,1.0411756,4.0202284,,,,,,,,,,,,,, -26542,,,0.3614648282527923,2.9840753078460693,0.3368600010871887,3.117708444595337,50000.0,0.2600000202655792,3.686853408813477,10000.0,12219.03768491745,13665.651197195051,12219.03768491745,1444.3692378997805,0.843348503112793,0.0 -26600,0.9338968,3.898096,,,,,,,,,,,,,, -26700,0.998999,4.0065737,,,,,,,,,,,,,, -26800,0.7876776,5.8525343,,,,,,,,,,,,,, -26900,0.7659978,5.194877,,,,,,,,,,,,,, -27000,0.9518101,3.7954566,,,,,,,,,,,,,, -27100,0.9505633,4.300941,,,,,,,,,,,,,, -27200,0.87775105,4.151347,,,,,,,,,,,,,, -27300,1.1292099,3.8847134,,,,,,,,,,,,,, -27400,0.81806296,5.284588,,,,,,,,,,,,,, -27462,,,0.3483007848262787,3.1170191764831543,0.3240199983119964,3.26419997215271,50000.0,0.2497000098228454,3.807725191116333,10000.0,12639.100955963137,14135.43448138237,12639.100955963137,1494.0119626522064,0.8733620643615723,0.0 -27500,0.86278284,4.0449433,,,,,,,,,,,,,, -27600,0.77391094,5.96972,,,,,,,,,,,,,, -27700,0.88276666,5.231979,,,,,,,,,,,,,, -27800,0.8438078,6.0308237,,,,,,,,,,,,,, -27900,0.9802267,4.063492,,,,,,,,,,,,,, -28000,0.9636325,3.9381971,,,,,,,,,,,,,, -28100,1.206313,4.06375,,,,,,,,,,,,,, -28200,0.8105347,4.730583,,,,,,,,,,,,,, -28300,0.983904,3.9198718,,,,,,,,,,,,,, -28380,,,0.368945300579071,2.9343044757843018,0.338619977235794,3.106697797775269,50000.0,0.2630999982357025,3.679342269897461,10000.0,13059.38454055786,14607.170650720596,13059.38454055786,1545.383987903595,0.9063313007354736,0.0 -28400,0.9707431,3.933516,,,,,,,,,,,,,, -28500,0.85835814,4.367571,,,,,,,,,,,,,, -28600,0.95805484,4.112112,,,,,,,,,,,,,, -28700,0.80878365,5.1526875,,,,,,,,,,,,,, -28800,0.77297425,5.3313756,,,,,,,,,,,,,, -28900,0.8894185,4.400491,,,,,,,,,,,,,, -29000,0.9330435,3.983539,,,,,,,,,,,,,, -29100,0.8681721,5.1241784,,,,,,,,,,,,,, -29200,0.8249556,4.358349,,,,,,,,,,,,,, -29298,,,0.3706835806369781,2.9274771213531494,0.3446199893951416,3.07129430770874,50000.0,0.2648000121116638,3.661438465118408,10000.0,13479.487174987791,15077.99607181549,13479.487174987791,1596.0251424312592,0.9394416809082032,0.0 -29300,0.9782758,4.1240697,,,,,,,,,,,,,, -29400,0.89696455,4.012604,,,,,,,,,,,,,, -29500,0.8928116,4.018538,,,,,,,,,,,,,, -29600,1.2665273,3.908413,,,,,,,,,,,,,, -29700,0.9642732,3.9757273,,,,,,,,,,,,,, -29800,1.1535435,3.950909,,,,,,,,,,,,,, -29900,1.052159,3.9608564,,,,,,,,,,,,,, -30000,0.97592074,3.8988495,,,,,,,,,,,,,, -30100,1.1700959,4.3302674,,,,,,,,,,,,,, -30200,0.77945244,4.639443,,,,,,,,,,,,,, -30217,,,0.3766210973262787,2.902115821838379,0.3519800007343292,3.029473066329956,50000.0,0.2742000222206116,3.627936124801636,10000.0,13899.69462966919,15548.25105714798,13899.69462966919,1645.9867305755615,0.9761998653411864,0.0 -30300,0.99566483,3.9351645,,,,,,,,,,,,,, -30400,0.79645145,5.784069,,,,,,,,,,,,,, -30500,0.9268369,4.217522,,,,,,,,,,,,,, -30600,0.9298227,3.9458349,,,,,,,,,,,,,, -30700,0.8062129,4.490538,,,,,,,,,,,,,, -30800,0.76324993,5.4574747,,,,,,,,,,,,,, -30900,0.9053555,4.7280335,,,,,,,,,,,,,, -31000,0.9252263,3.7468283,,,,,,,,,,,,,, -31100,0.8670486,3.8567085,,,,,,,,,,,,,, -31135,,,0.3855859339237213,2.8547158241271973,0.3534599840641022,3.0370798110961914,50000.0,0.2705000042915344,3.6340298652648926,10000.0,14319.800573825836,16018.214548110962,14319.800573825836,1695.7626497745514,1.0101087093353271,0.0 -31200,0.9114246,5.280366,,,,,,,,,,,,,, -31300,0.8228166,5.038291,,,,,,,,,,,,,, -31400,1.0824959,3.8467257,,,,,,,,,,,,,, -31500,0.9362341,4.3543806,,,,,,,,,,,,,, -31600,1.0498282,3.8728778,,,,,,,,,,,,,, -31700,0.9556361,3.9720306,,,,,,,,,,,,,, -31800,0.93775105,3.9823325,,,,,,,,,,,,,, -31900,1.1863238,3.9079766,,,,,,,,,,,,,, -32000,0.92639023,3.9355671,,,,,,,,,,,,,, -32055,,,0.4124804735183716,2.7808024883270264,0.3519199788570404,3.0841052532196045,50000.0,0.2671000063419342,3.660425901412964,10000.0,14739.811116695404,16486.448800325394,14739.811116695404,1743.906052350998,1.0428571701049805,0.0 -32100,0.84413004,6.0463395,,,,,,,,,,,,,, -32200,0.8531856,6.041004,,,,,,,,,,,,,, -32300,0.90912926,4.0626845,,,,,,,,,,,,,, -32400,1.2668827,4.0662203,,,,,,,,,,,,,, -32500,0.7503442,5.912663,,,,,,,,,,,,,, -32600,1.0043337,3.9163477,,,,,,,,,,,,,, -32700,0.94334394,3.8924468,,,,,,,,,,,,,, -32800,1.0250753,3.8017511,,,,,,,,,,,,,, -32900,1.0778482,3.9262633,,,,,,,,,,,,,, -32972,,,0.38671875,2.805215120315552,0.3612799942493438,2.94804048538208,50000.0,0.2773000001907348,3.537525653839112,10000.0,15160.1337788105,16955.91092300415,15160.1337788105,1792.961481332779,1.0791571140289309,0.0 -33000,1.0820218,5.142108,,,,,,,,,,,,,, -33100,1.011948,3.8401341,,,,,,,,,,,,,, -33200,1.0804527,3.8721704,,,,,,,,,,,,,, -33300,1.088784,3.9565334,,,,,,,,,,,,,, -33400,0.6999451,5.0904417,,,,,,,,,,,,,, -33500,0.73439777,5.3328223,,,,,,,,,,,,,, -33600,1.0347521,3.8143098,,,,,,,,,,,,,, -33700,1.0165184,3.9084668,,,,,,,,,,,,,, -33800,0.8804476,3.7918067,,,,,,,,,,,,,, -33891,,,0.3903124928474426,2.800220966339112,0.363999992609024,2.969886541366577,50000.0,0.2771000266075134,3.5811426639556885,10000.0,15580.551282167437,17422.885044813156,15580.551282167437,1839.4359893798828,1.1127994060516355,0.0 -33900,0.982601,3.9136443,,,,,,,,,,,,,, -34000,1.0864767,3.9332006,,,,,,,,,,,,,, -34100,0.9256942,3.8367817,,,,,,,,,,,,,, -34200,0.9155215,4.0424356,,,,,,,,,,,,,, -34300,0.93503606,5.897776,,,,,,,,,,,,,, -34400,1.0211271,3.9315789,,,,,,,,,,,,,, -34500,1.2764599,4.04308,,,,,,,,,,,,,, -34600,1.3106836,3.7581391,,,,,,,,,,,,,, -34700,1.1402143,3.8074548,,,,,,,,,,,,,, -34800,0.72734743,4.960705,,,,,,,,,,,,,, -34809,,,0.3956054747104645,2.79667067527771,0.3571199774742126,3.022451400756836,50000.0,0.2684000134468078,3.636655569076538,10000.0,16000.717082977297,17892.378203630447,16000.717082977297,1888.6756381988523,1.1519043445587158,0.0 -34900,0.8565786,4.913135,,,,,,,,,,,,,, -35000,0.7735865,5.8497677,,,,,,,,,,,,,, -35100,0.87377167,5.655754,,,,,,,,,,,,,, -35200,0.9111198,3.8337903,,,,,,,,,,,,,, -35300,1.0074416,4.1477695,,,,,,,,,,,,,, -35400,1.3028622,3.7858365,,,,,,,,,,,,,, -35500,1.3409085,3.924534,,,,,,,,,,,,,, -35600,1.0373164,5.629124,,,,,,,,,,,,,, -35700,1.008452,3.909542,,,,,,,,,,,,,, -35725,,,0.390937477350235,2.7831339836120605,0.3679399788379669,2.916210174560547,50000.0,0.2833000123500824,3.526043176651001,10000.0,16421.02041387558,18362.607944726944,16421.02041387558,1938.52412891388,1.1824181079864502,0.0 -35800,0.97869575,5.986896,,,,,,,,,,,,,, -35900,0.98221874,3.8264937,,,,,,,,,,,,,, -36000,1.1257263,3.8754525,,,,,,,,,,,,,, -36100,0.95629466,3.6614044,,,,,,,,,,,,,, -36200,1.1214575,3.7760694,,,,,,,,,,,,,, -36300,1.1432321,3.7868338,,,,,,,,,,,,,, -36400,1.0547885,4.170247,,,,,,,,,,,,,, -36500,0.8570074,4.6732264,,,,,,,,,,,,,, -36600,1.0241061,3.8245664,,,,,,,,,,,,,, -36639,,,0.3834374845027923,2.84735369682312,0.3583599925041199,2.9957029819488525,50000.0,0.2743000090122223,3.578441619873047,10000.0,16841.228150367737,18832.27735710144,16841.228150367737,1987.9012160301208,1.2189600467681885,0.0 -36700,0.9675542,3.7961597,,,,,,,,,,,,,, -36800,0.9675787,3.7458115,,,,,,,,,,,,,, -36900,0.83424896,5.488601,,,,,,,,,,,,,, -37000,1.0058205,3.9183834,,,,,,,,,,,,,, -37100,0.94003946,3.8636913,,,,,,,,,,,,,, -37200,0.9465392,4.1746144,,,,,,,,,,,,,, -37300,1.2983147,3.7188063,,,,,,,,,,,,,, -37400,1.0360903,3.6711826,,,,,,,,,,,,,, -37500,1.0637861,3.9168887,,,,,,,,,,,,,, -37555,,,0.3958398401737213,2.8110783100128174,0.3609800040721893,3.0116312503814697,50000.0,0.2765000164508819,3.5944924354553223,10000.0,17261.399400949478,19304.57182574272,17261.399400949478,2039.933144569397,1.262007474899292,0.0 -37600,1.0930492,3.8426263,,,,,,,,,,,,,, -37700,1.62197,4.2773485,,,,,,,,,,,,,, -37800,0.9042305,4.904513,,,,,,,,,,,,,, -37900,0.8433196,5.0695667,,,,,,,,,,,,,, -38000,0.96637636,3.76522,,,,,,,,,,,,,, -38100,0.8453654,5.966834,,,,,,,,,,,,,, -38200,1.2040266,4.362403,,,,,,,,,,,,,, -38300,0.89644814,3.8555892,,,,,,,,,,,,,, -38400,0.8598058,4.6303587,,,,,,,,,,,,,, -38473,,,0.4010546803474426,2.771629571914673,0.37567999958992,2.9098598957061768,50000.0,0.2942000031471252,3.5065810680389404,10000.0,17681.434158086777,19773.024648189545,17681.434158086777,2088.265805721283,1.299102783203125,0.0 -38500,1.0955659,3.6403248,,,,,,,,,,,,,, -38600,1.0507705,3.8246505,,,,,,,,,,,,,, -38700,1.0914787,4.900394,,,,,,,,,,,,,, -38800,1.10723,3.818284,,,,,,,,,,,,,, -38900,1.1294576,3.8721654,,,,,,,,,,,,,, -39000,1.1138201,4.2278724,,,,,,,,,,,,,, -39100,0.98214704,3.805429,,,,,,,,,,,,,, -39200,0.8384758,4.073676,,,,,,,,,,,,,, -39300,0.8149646,4.5951486,,,,,,,,,,,,,, -39390,,,0.3926562368869781,2.793609142303467,0.3681399822235107,2.933591842651367,50000.0,0.2830000221729278,3.5438942909240723,10000.0,18101.503092050552,20242.74775648117,18101.503092050552,2137.838715553284,1.3325514793395996,0.0 -39400,0.85704803,4.5793133,,,,,,,,,,,,,, -39500,1.0396423,3.8725505,,,,,,,,,,,,,, -39600,1.0260253,3.753574,,,,,,,,,,,,,, -39700,0.9587211,3.8968241,,,,,,,,,,,,,, -39800,1.1329135,3.7315369,,,,,,,,,,,,,, -39900,0.9480439,3.708399,,,,,,,,,,,,,, -40000,0.78171283,5.441061,,,,,,,,,,,,,, -40100,1.5887533,3.9496891,,,,,,,,,,,,,, -40200,0.9978264,3.6994336,,,,,,,,,,,,,, -40300,1.1549879,3.8274531,,,,,,,,,,,,,, -40307,,,0.4039843678474426,2.782717227935791,0.3682200014591217,2.971782684326172,50000.0,0.2861000001430511,3.549136161804199,10000.0,18521.69138717652,20712.3809030056,18521.69138717652,2187.1993803977966,1.3684730529785156,0.0 -40400,0.97189033,3.7040339,,,,,,,,,,,,,, -40500,1.1072656,3.6531277,,,,,,,,,,,,,, -40600,0.80234027,4.4251285,,,,,,,,,,,,,, -40700,1.1414173,3.7393317,,,,,,,,,,,,,, -40800,0.86194074,3.8437054,,,,,,,,,,,,,, -40900,0.73089844,5.881457,,,,,,,,,,,,,, -41000,1.0895548,4.2235174,,,,,,,,,,,,,, -41100,1.1635264,3.9232216,,,,,,,,,,,,,, -41200,1.0692937,3.6681004,,,,,,,,,,,,,, -41227,,,0.4065038859844208,2.7331039905548096,0.3854199945926666,2.870001792907715,50000.0,0.2959000170230865,3.476227045059204,10000.0,18941.72615456581,21185.7188167572,18941.72615456581,2240.424861431122,1.3985400199890137,0.0 -41300,0.9968707,4.0607986,,,,,,,,,,,,,, -41400,1.082902,3.7270494,,,,,,,,,,,,,, -41500,0.7098315,5.262603,,,,,,,,,,,,,, -41600,0.70570356,5.3796215,,,,,,,,,,,,,, -41700,0.7802073,5.9585333,,,,,,,,,,,,,, -41800,1.0230335,3.6591651,,,,,,,,,,,,,, -41900,0.76492214,5.5575514,,,,,,,,,,,,,, -42000,1.056822,3.771801,,,,,,,,,,,,,, -42100,0.92242944,5.023393,,,,,,,,,,,,,, -42145,,,0.4131640493869781,2.6345086097717285,0.381520003080368,2.8117265701293945,50000.0,0.2925000190734863,3.437443971633911,10000.0,19362.027856588364,21655.337735414505,19362.027856588364,2289.662645339966,1.4311163425445557,0.0 -42200,0.96500826,3.6655443,,,,,,,,,,,,,, -42300,1.156904,3.7976928,,,,,,,,,,,,,, -42400,0.98739177,4.2072034,,,,,,,,,,,,,, -42500,1.0563005,3.7456145,,,,,,,,,,,,,, -42600,0.733975,4.6269054,,,,,,,,,,,,,, -42700,1.0887353,3.7235026,,,,,,,,,,,,,, -42800,1.035191,3.7134924,,,,,,,,,,,,,, -42900,0.9948729,3.710372,,,,,,,,,,,,,, -43000,1.007915,3.8879113,,,,,,,,,,,,,, -43061,,,0.4229101538658142,2.6399965286254883,0.3895399868488312,2.8090786933898926,50000.0,0.3064000010490417,3.392443418502808,10000.0,19782.140295267105,22124.10464024544,19782.140295267105,2338.235716342926,1.4649834632873535,0.0 -43100,1.08491,3.8070223,,,,,,,,,,,,,, -43200,1.1117657,3.8247998,,,,,,,,,,,,,, -43300,0.94959384,3.962989,,,,,,,,,,,,,, -43400,1.163154,3.6696224,,,,,,,,,,,,,, -43500,0.8054071,4.185935,,,,,,,,,,,,,, -43600,1.1199054,3.7744641,,,,,,,,,,,,,, -43700,0.83438677,4.2579136,,,,,,,,,,,,,, -43800,1.027319,5.082136,,,,,,,,,,,,,, -43900,0.79593337,5.4058,,,,,,,,,,,,,, -43978,,,0.4477148354053497,2.5128729343414307,0.3838799893856048,2.8571627140045166,50000.0,0.2950000166893005,3.4772708415985107,10000.0,20202.28423190117,22594.31020140648,20202.28423190117,2388.211321830749,1.503532886505127,0.0 -44000,1.1889887,3.7708616,,,,,,,,,,,,,, -44100,0.89869523,3.668364,,,,,,,,,,,,,, -44200,1.1068287,3.9300065,,,,,,,,,,,,,, -44300,1.1325103,3.822878,,,,,,,,,,,,,, -44400,0.9420378,3.6813767,,,,,,,,,,,,,, -44500,1.1455494,3.8120906,,,,,,,,,,,,,, -44600,0.925273,3.8947973,,,,,,,,,,,,,, -44700,0.92467535,3.626821,,,,,,,,,,,,,, -44800,0.91791403,4.0210013,,,,,,,,,,,,,, -44895,,,0.4213085770606994,2.619121551513672,0.3991200029850006,2.753256559371948,50000.0,0.3067000210285187,3.367946147918701,10000.0,20622.277943134308,23065.153266191483,20622.277943134308,2438.9718708992004,1.5404622554779053,0.0 -44900,1.0442622,3.7462614,,,,,,,,,,,,,, -45000,0.809095,3.9335299,,,,,,,,,,,,,, -45100,1.0147616,3.6850672,,,,,,,,,,,,,, -45200,0.86529267,5.797097,,,,,,,,,,,,,, -45300,1.1113877,3.7320333,,,,,,,,,,,,,, -45400,1.0591971,3.649451,,,,,,,,,,,,,, -45500,1.0338701,3.634503,,,,,,,,,,,,,, -45600,0.9559275,4.248224,,,,,,,,,,,,,, -45700,1.0826311,3.7097342,,,,,,,,,,,,,, -45800,1.1026635,3.7243786,,,,,,,,,,,,,, -45814,,,0.411914050579071,2.6856353282928467,0.384579986333847,2.840383052825928,50000.0,0.292600005865097,3.459824323654175,10000.0,21042.53992986679,23536.338141679764,21042.53992986679,2489.8106729984283,1.576117992401123,0.0 -45900,0.91676086,4.180902,,,,,,,,,,,,,, -46000,0.93083656,3.975063,,,,,,,,,,,,,, -46100,0.8445713,4.1097093,,,,,,,,,,,,,, -46200,0.86110437,4.8293533,,,,,,,,,,,,,, -46300,1.0285473,3.7068348,,,,,,,,,,,,,, -46400,0.95158374,4.6963024,,,,,,,,,,,,,, -46500,1.1324941,4.202257,,,,,,,,,,,,,, -46600,1.1510557,3.608306,,,,,,,,,,,,,, -46700,0.98680866,3.6204731,,,,,,,,,,,,,, -46727,,,0.4382031261920929,2.527977466583252,0.3956199884414673,2.7620689868927,50000.0,0.3050000071525574,3.409043550491333,10000.0,21462.657836198807,24007.31149339676,21462.657836198807,2540.581508398056,1.613168239593506,0.0 -46800,0.95855904,3.8050933,,,,,,,,,,,,,, -46900,0.9524563,5.6277742,,,,,,,,,,,,,, -47000,1.1807864,3.575391,,,,,,,,,,,,,, -47100,1.1935576,3.6173537,,,,,,,,,,,,,, -47200,1.0798218,5.620653,,,,,,,,,,,,,, -47300,1.1688898,3.7726192,,,,,,,,,,,,,, -47400,1.0207005,4.3323174,,,,,,,,,,,,,, -47500,1.2154106,3.691714,,,,,,,,,,,,,, -47600,1.1457517,3.7733116,,,,,,,,,,,,,, -47643,,,0.4313085973262787,2.561516046524048,0.4053599834442138,2.709469079971313,50000.0,0.3105000257492065,3.3404970169067383,10000.0,21882.891329288483,24478.50642466545,21882.891329288483,2591.4612271785736,1.6478898525238037,0.0 -47700,1.0219799,3.6209965,,,,,,,,,,,,,, -47800,0.89646065,5.215052,,,,,,,,,,,,,, -47900,1.0670937,3.5463376,,,,,,,,,,,,,, -48000,0.884008,4.606067,,,,,,,,,,,,,, -48100,0.9478785,5.171317,,,,,,,,,,,,,, -48200,1.0785784,3.7618182,,,,,,,,,,,,,, -48300,1.3522366,3.8797927,,,,,,,,,,,,,, -48400,1.0459378,3.6231766,,,,,,,,,,,,,, -48500,1.0100082,3.7124429,,,,,,,,,,,,,, -48561,,,0.4389062523841858,2.5233049392700195,0.4129000008106231,2.669914960861206,50000.0,0.317900002002716,3.279659509658813,10000.0,22302.96504020691,24948.79293680191,22302.96504020691,2641.5904109478,1.6836578845977783,0.0 -48600,1.0273364,3.4523046,,,,,,,,,,,,,, -48700,0.9594688,3.5930395,,,,,,,,,,,,,, -48800,1.1795957,3.58572,,,,,,,,,,,,,, -48900,0.9485995,3.5182397,,,,,,,,,,,,,, -49000,0.8645698,5.512331,,,,,,,,,,,,,, -49100,0.7482713,5.9925346,,,,,,,,,,,,,, -49200,1.1868929,3.6340141,,,,,,,,,,,,,, -49300,0.91693574,3.9384537,,,,,,,,,,,,,, -49400,1.1453803,3.6546304,,,,,,,,,,,,,, -49474,,,0.449531227350235,2.4620165824890137,0.4113599956035614,2.675554275512696,50000.0,0.3184000253677368,3.2931008338928223,10000.0,22723.290951013565,25419.6971578598,22723.290951013565,2692.083906888962,1.7213480472564695,0.0 -49500,0.9650322,3.9765282,,,,,,,,,,,,,, -49600,1.0264274,3.5417843,,,,,,,,,,,,,, -49700,1.0349295,3.5607119,,,,,,,,,,,,,, -49800,0.9259342,5.918178,,,,,,,,,,,,,, -49900,0.9152614,5.283085,,,,,,,,,,,,,, -50000,1.0679086,3.5756745,,,,,,,,,,,,,, -50100,0.98453087,4.847432,,,,,,,,,,,,,, -50200,0.9069005,3.796087,,,,,,,,,,,,,, -50300,0.98775774,4.9168277,,,,,,,,,,,,,, -50386,,,0.4435351490974426,2.5222740173339844,0.4126800000667572,2.66820764541626,50000.0,0.3139000236988067,3.3136818408966064,10000.0,23143.52159333229,25886.4966814518,23143.52159333229,2738.572528839112,1.7537884712219238,0.0 -50400,0.9750499,4.2227845,,,,,,,,,,,,,, -50500,1.049773,3.5462284,,,,,,,,,,,,,, -50600,0.84084255,5.8115454,,,,,,,,,,,,,, -50700,0.9375983,4.149451,,,,,,,,,,,,,, -50800,1.4055743,3.5368938,,,,,,,,,,,,,, -50900,1.2679359,3.6304505,,,,,,,,,,,,,, -51000,1.0491961,4.1838827,,,,,,,,,,,,,, -51100,1.0215945,3.7104745,,,,,,,,,,,,,, -51200,1.1164234,3.549036,,,,,,,,,,,,,, -51300,,,0.4445898234844208,2.477378845214844,0.4203200042247772,2.636164903640747,50000.0,0.3272000253200531,3.2604527473449707,10000.0,23563.82721590996,26356.45927453041,23563.82721590996,2788.143606901169,1.7920589447021484,0.0 -51300,0.93763757,3.4718487,,,,,,,,,,,,,, -51400,0.9046262,3.966613,,,,,,,,,,,,,, -51500,1.0000324,3.6353326,,,,,,,,,,,,,, -51600,0.9738179,5.2163506,,,,,,,,,,,,,, -51700,1.0748291,3.5988119,,,,,,,,,,,,,, -51800,1.0129597,4.433087,,,,,,,,,,,,,, -51900,0.95405424,4.2669168,,,,,,,,,,,,,, -52000,0.97694457,3.9734812,,,,,,,,,,,,,, -52100,0.85896516,4.5926924,,,,,,,,,,,,,, -52200,1.2854096,3.7383904,,,,,,,,,,,,,, -52216,,,0.438789039850235,2.5263781547546387,0.4053199887275696,2.713552474975586,50000.0,0.3122000098228454,3.3333580493927,10000.0,23983.90644145012,26825.42811512947,23983.90644145012,2836.9467310905457,1.830620288848877,0.0 -52300,0.96194935,3.7868743,,,,,,,,,,,,,, -52400,0.8762144,3.6556914,,,,,,,,,,,,,, -52500,1.23013,3.7047753,,,,,,,,,,,,,, -52600,0.93404406,4.1204576,,,,,,,,,,,,,, -52700,0.9055395,3.5172675,,,,,,,,,,,,,, -52800,1.0565703,3.5832443,,,,,,,,,,,,,, -52900,0.8652562,4.7225103,,,,,,,,,,,,,, -53000,0.93645227,5.650799,,,,,,,,,,,,,, -53100,1.2774501,3.6537783,,,,,,,,,,,,,, -53133,,,0.4564843773841858,2.438621759414673,0.4238399863243103,2.607288122177124,50000.0,0.3252000212669372,3.2333431243896484,10000.0,24404.01684069633,27293.93130230904,24404.01684069633,2885.250978946686,1.871053695678711,0.0 -53200,1.1767265,3.6361837,,,,,,,,,,,,,, -53300,0.8301804,5.7635045,,,,,,,,,,,,,, -53400,1.0884875,3.6248355,,,,,,,,,,,,,, -53500,0.9830501,3.825982,,,,,,,,,,,,,, -53600,1.2126169,3.6503654,,,,,,,,,,,,,, -53700,0.9410669,4.1007643,,,,,,,,,,,,,, -53800,1.034185,4.1870527,,,,,,,,,,,,,, -53900,0.88742495,5.068784,,,,,,,,,,,,,, -54000,0.6857884,4.9572244,,,,,,,,,,,,,, -54054,,,0.4422265589237213,2.535531759262085,0.412200003862381,2.7084314823150635,50000.0,0.3118000030517578,3.3063266277313232,10000.0,24824.36661410332,27765.620665311813,24824.36661410332,2936.507098197937,1.9068207740783687,0.0 -54100,0.9163449,4.2407045,,,,,,,,,,,,,, -54200,0.96152896,3.523913,,,,,,,,,,,,,, -54300,0.9789556,4.535842,,,,,,,,,,,,,, -54400,1.1825732,3.578185,,,,,,,,,,,,,, -54500,1.2866849,3.7488098,,,,,,,,,,,,,, -54600,1.0201594,4.1046157,,,,,,,,,,,,,, -54700,1.1604898,3.6541276,,,,,,,,,,,,,, -54800,0.8356115,5.0217786,,,,,,,,,,,,,, -54900,1.0746852,3.5173569,,,,,,,,,,,,,, -54973,,,0.4440820217132568,2.5492329597473145,0.4105399847030639,2.712507724761963,50000.0,0.3169000148773193,3.3160958290100098,10000.0,25244.459810972214,28235.69842457772,25244.459810972214,2986.406905412674,1.9439399242401123,0.0 -55000,1.0170684,3.5656867,,,,,,,,,,,,,, -55100,1.3540459,3.520227,,,,,,,,,,,,,, -55200,0.8973911,3.6201675,,,,,,,,,,,,,, -55300,1.111048,3.657592,,,,,,,,,,,,,, -55400,1.0549878,3.4045308,,,,,,,,,,,,,, -55500,1.0929323,3.4665198,,,,,,,,,,,,,, -55600,1.354095,3.6606534,,,,,,,,,,,,,, -55700,1.1669384,3.5261369,,,,,,,,,,,,,, -55800,1.0022929,3.598332,,,,,,,,,,,,,, -55889,,,0.4723632633686065,2.383117198944092,0.4037799835205078,2.740209579467773,50000.0,0.316100001335144,3.364776611328125,10000.0,25664.43350338936,28705.693928956985,25664.43350338936,3036.34405207634,1.9801304340362549,0.0 -55900,1.1322564,3.6073408,,,,,,,,,,,,,, -56000,1.0212265,3.7658508,,,,,,,,,,,,,, -56100,1.0047818,3.9496667,,,,,,,,,,,,,, -56200,0.99294734,3.5028505,,,,,,,,,,,,,, -56300,1.1289036,3.4182463,,,,,,,,,,,,,, -56400,1.2843435,3.723871,,,,,,,,,,,,,, -56500,0.9718157,3.6078808,,,,,,,,,,,,,, -56600,0.8449167,5.872051,,,,,,,,,,,,,, -56700,1.1012151,3.4226282,,,,,,,,,,,,,, -56800,0.7585949,5.2807164,,,,,,,,,,,,,, -56807,,,0.4465624988079071,2.5229108333587646,0.4216800034046173,2.661752700805664,50000.0,0.3288000226020813,3.2641468048095703,10000.0,26084.50500845909,29175.84260368347,26084.50500845909,3086.3352172374725,2.0175254344940186,0.0 -56900,0.7723489,5.2367015,,,,,,,,,,,,,, -57000,0.6744967,5.2822742,,,,,,,,,,,,,, -57100,0.83310944,5.258626,,,,,,,,,,,,,, -57200,0.76929384,5.1776886,,,,,,,,,,,,,, -57300,1.1738359,4.23849,,,,,,,,,,,,,, -57400,1.074505,3.4521308,,,,,,,,,,,,,, -57500,0.91650033,3.6617737,,,,,,,,,,,,,, -57600,0.8850001,4.370293,,,,,,,,,,,,,, -57700,0.9838125,3.5403383,,,,,,,,,,,,,, -57728,,,0.4511913955211639,2.4399993419647217,0.4245599806308746,2.596042394638061,50000.0,0.3342000246047973,3.217836141586304,10000.0,26504.70529055596,29644.84923171997,26504.70529055596,3135.060056447983,2.0510499477386475,0.0 -57800,0.9428274,4.004897,,,,,,,,,,,,,, -57900,1.0292697,3.6235936,,,,,,,,,,,,,, -58000,1.3167397,3.6015131,,,,,,,,,,,,,, -58100,0.9983833,4.849959,,,,,,,,,,,,,, -58200,0.7214663,4.986675,,,,,,,,,,,,,, -58300,1.2319946,3.4508688,,,,,,,,,,,,,, -58400,1.169317,3.5470953,,,,,,,,,,,,,, -58500,0.84531945,4.441426,,,,,,,,,,,,,, -58600,1.0975555,3.8463087,,,,,,,,,,,,,, -58645,,,0.4740820229053497,2.372753143310547,0.4254199862480163,2.6239047050476074,50000.0,0.3331000208854675,3.236710786819458,10000.0,26924.90839576721,30115.69566130638,26924.90839576721,3185.6205384731293,2.085714817047119,0.0 -58700,1.2048208,3.6426291,,,,,,,,,,,,,, -58800,1.1154327,3.5567,,,,,,,,,,,,,, -58900,0.7813886,4.763135,,,,,,,,,,,,,, -59000,1.0615709,3.4471107,,,,,,,,,,,,,, -59100,1.1912751,4.131693,,,,,,,,,,,,,, -59200,1.1537808,3.4757946,,,,,,,,,,,,,, -59300,0.7342795,5.7616663,,,,,,,,,,,,,, -59400,0.7277457,5.888978,,,,,,,,,,,,,, -59500,1.0104781,5.827072,,,,,,,,,,,,,, -59558,,,0.4577734172344208,2.429400682449341,0.4283799827098846,2.5803024768829346,50000.0,0.3296000063419342,3.227254867553711,10000.0,27344.886063098907,30585.716195106503,27344.886063098907,3235.580410003662,2.121338367462158,0.0 -59600,1.0729278,3.474533,,,,,,,,,,,,,, -59700,0.7152781,5.5225334,,,,,,,,,,,,,, -59800,1.1819416,3.7963696,,,,,,,,,,,,,, -59900,1.071338,3.4518795,,,,,,,,,,,,,, -60000,1.023417,4.2702446,,,,,,,,,,,,,, -60100,1.0200981,3.4386382,,,,,,,,,,,,,, -60200,0.98030025,3.4543328,,,,,,,,,,,,,, -60300,1.0198596,4.741507,,,,,,,,,,,,,, -60400,1.0701276,3.3961015,,,,,,,,,,,,,, -60474,,,0.4648632705211639,2.3890140056610107,0.4313199818134308,2.562873363494873,50000.0,0.3315000236034393,3.199613809585572,10000.0,27764.8627755642,31054.29045343399,27764.8627755642,3284.088232278824,2.162022352218628,0.0 -60500,1.3301505,3.6689103,,,,,,,,,,,,,, -60600,1.0569843,3.8455777,,,,,,,,,,,,,, -60700,1.0138946,3.5579202,,,,,,,,,,,,,, -60800,1.291676,3.5554347,,,,,,,,,,,,,, -60900,0.99030954,4.460952,,,,,,,,,,,,,, -61000,0.879993,5.1397705,,,,,,,,,,,,,, -61100,1.0328472,3.642576,,,,,,,,,,,,,, -61200,0.88677055,4.8288364,,,,,,,,,,,,,, -61300,0.98417616,3.3812125,,,,,,,,,,,,,, -61391,,,0.467089831829071,2.3804032802581787,0.4293199777603149,2.585036277770996,50000.0,0.3352000117301941,3.2316999435424805,10000.0,28185.063593387604,31523.559364318848,28185.063593387604,3333.0700438022614,2.200597047805786,0.0 -61400,1.1726544,3.4843976,,,,,,,,,,,,,, -61500,0.9426984,3.412158,,,,,,,,,,,,,, -61600,1.1353536,3.728514,,,,,,,,,,,,,, -61700,0.926685,4.7019944,,,,,,,,,,,,,, -61800,1.1157961,3.524183,,,,,,,,,,,,,, -61900,1.0040598,3.4015684,,,,,,,,,,,,,, -62000,1.0683132,3.567574,,,,,,,,,,,,,, -62100,0.9576161,3.5483716,,,,,,,,,,,,,, -62200,1.0786787,3.472426,,,,,,,,,,,,,, -62300,0.94190043,5.82098,,,,,,,,,,,,,, -62310,,,0.4643359184265136,2.393752336502075,0.4336199760437011,2.553832054138184,50000.0,0.3335000276565552,3.20181655883789,10000.0,28604.98014640808,31993.887003183365,28604.98014640808,3383.3891365528107,2.243925094604492,0.0 -62400,1.1419193,3.5667672,,,,,,,,,,,,,, -62500,1.0354494,4.0755644,,,,,,,,,,,,,, -62600,1.0516206,3.5994058,,,,,,,,,,,,,, -62700,0.93222535,4.8365803,,,,,,,,,,,,,, -62800,0.93023777,5.615429,,,,,,,,,,,,,, -62900,1.1161941,3.3754177,,,,,,,,,,,,,, -63000,1.0394768,3.5142596,,,,,,,,,,,,,, -63100,1.0676333,3.4272628,,,,,,,,,,,,,, -63200,0.7814022,5.0171175,,,,,,,,,,,,,, -63227,,,0.4670117199420929,2.3691554069519043,0.4369799792766571,2.539468050003052,50000.0,0.3427000045776367,3.135342836380005,10000.0,29025.158362150192,32463.748566627502,29025.158362150192,3432.985024452209,2.2831294536590576,0.0 -63300,0.962591,3.8334856,,,,,,,,,,,,,, -63400,0.835587,5.751251,,,,,,,,,,,,,, -63500,0.96969604,4.299368,,,,,,,,,,,,,, -63600,0.7601577,5.126894,,,,,,,,,,,,,, -63700,1.0148733,3.4227395,,,,,,,,,,,,,, -63800,1.029909,5.108992,,,,,,,,,,,,,, -63900,0.97135293,5.2852345,,,,,,,,,,,,,, -64000,0.99521035,3.4140413,,,,,,,,,,,,,, -64100,1.0645347,3.7360418,,,,,,,,,,,,,, -64144,,,0.4650976359844208,2.3967294692993164,0.4300599992275238,2.5977704524993896,50000.0,0.3273000121116638,3.223234176635742,10000.0,29445.27424812317,32934.04869699478,29445.27424812317,3483.0788888931274,2.325901746749878,0.0 -64200,0.8235906,5.6785355,,,,,,,,,,,,,, -64300,1.1026084,3.322059,,,,,,,,,,,,,, -64400,0.73030305,5.5337825,,,,,,,,,,,,,, -64500,1.1819004,3.5145113,,,,,,,,,,,,,, -64600,0.9762738,3.7165558,,,,,,,,,,,,,, -64700,1.149501,3.4374752,,,,,,,,,,,,,, -64800,1.1507943,3.7896552,,,,,,,,,,,,,, -64900,0.91929114,4.288972,,,,,,,,,,,,,, -65000,0.81358504,5.434526,,,,,,,,,,,,,, -65062,,,0.4678320288658142,2.389435291290283,0.4369399845600128,2.564899682998657,50000.0,0.3362000286579132,3.1964430809021,10000.0,29865.45224094391,33405.08270573616,29865.45224094391,3533.843167066574,2.3687186241149902,0.0 -65100,0.9151859,4.3725204,,,,,,,,,,,,,, -65200,1.026765,3.4241502,,,,,,,,,,,,,, -65300,0.8568125,5.771651,,,,,,,,,,,,,, -65400,1.027612,3.4695952,,,,,,,,,,,,,, -65500,1.0473301,3.5538821,,,,,,,,,,,,,, -65600,0.9869904,3.693788,,,,,,,,,,,,,, -65700,1.0024505,3.4169264,,,,,,,,,,,,,, -65800,1.5408307,3.563355,,,,,,,,,,,,,, -65900,1.0014551,3.9543495,,,,,,,,,,,,,, -65982,,,0.4744921624660492,2.419340133666992,0.4359599947929382,2.5876500606536865,50000.0,0.3450000286102295,3.2141637802124023,10000.0,30285.408737182617,33876.14655208588,30285.408737182617,3584.863637447357,2.4079158306121826,0.0 -66000,0.9955453,3.2341688,,,,,,,,,,,,,, -66100,0.9474115,5.2254953,,,,,,,,,,,,,, -66200,1.0315953,3.6147556,,,,,,,,,,,,,, -66300,1.3364333,3.3909488,,,,,,,,,,,,,, -66400,0.95341706,3.7407806,,,,,,,,,,,,,, -66500,0.9942095,3.3522801,,,,,,,,,,,,,, -66600,0.99927205,3.4635181,,,,,,,,,,,,,, -66700,0.92548996,5.7173324,,,,,,,,,,,,,, -66800,0.8721764,3.814887,,,,,,,,,,,,,, -66900,0.94078726,4.8738494,,,,,,,,,,,,,, -66901,,,0.4737695157527923,2.3729798793792725,0.44200000166893,2.548090696334839,50000.0,0.340800017118454,3.17934513092041,10000.0,30705.5859375,34346.95048165321,30705.5859375,3635.3954322338095,2.4538557529449463,0.0 -67000,0.9274899,5.5848455,,,,,,,,,,,,,, -67100,1.1472625,3.3675418,,,,,,,,,,,,,, -67200,0.8101528,4.638985,,,,,,,,,,,,,, -67300,1.0456744,4.844603,,,,,,,,,,,,,, -67400,1.0879736,3.710201,,,,,,,,,,,,,, -67500,1.1316478,3.4205213,,,,,,,,,,,,,, -67600,0.9577945,3.4742806,,,,,,,,,,,,,, -67700,1.1444672,3.5168674,,,,,,,,,,,,,, -67800,1.0985835,3.3113396,,,,,,,,,,,,,, -67819,,,0.5081835985183716,2.193055391311645,0.4444599747657776,2.5196900367736816,50000.0,0.3449000120162964,3.149144172668457,10000.0,31125.7574505806,34814.36068201065,31125.7574505806,3682.5514616966248,2.488304853439331,0.0 -67900,0.89763045,5.7572737,,,,,,,,,,,,,, -68000,0.99157757,4.050592,,,,,,,,,,,,,, -68100,1.02236,3.9193804,,,,,,,,,,,,,, -68200,0.9468494,3.783991,,,,,,,,,,,,,, -68300,0.9674977,3.6476576,,,,,,,,,,,,,, -68400,0.9479458,4.391162,,,,,,,,,,,,,, -68500,1.019044,3.4398131,,,,,,,,,,,,,, -68600,0.99556136,3.338807,,,,,,,,,,,,,, -68700,1.0194508,3.2514,,,,,,,,,,,,,, -68738,,,0.474414050579071,2.358579397201538,0.4458999931812286,2.502725124359131,50000.0,0.3493000268936157,3.140624761581421,10000.0,31545.98058009148,35283.56374049187,31545.98058009148,3731.446517467499,2.525161027908325,0.0 -68800,0.8169252,5.6039352,,,,,,,,,,,,,, -68900,1.0039718,3.4115975,,,,,,,,,,,,,, -69000,1.1054986,3.261751,,,,,,,,,,,,,, -69100,1.0886296,3.3213205,,,,,,,,,,,,,, -69200,1.0365411,3.284594,,,,,,,,,,,,,, -69300,1.1785436,3.4429862,,,,,,,,,,,,,, -69400,1.072809,3.4039927,,,,,,,,,,,,,, -69500,1.1874928,3.59021,,,,,,,,,,,,,, -69600,1.1358697,3.3686168,,,,,,,,,,,,,, -69655,,,0.488085925579071,2.284952402114868,0.4509799778461456,2.4512698650360107,50000.0,0.35630002617836,3.0855069160461426,10000.0,31966.211909532547,35753.97013711929,31966.211909532547,3781.5328879356384,2.565286159515381,0.0 -69700,0.97604185,3.3385506,,,,,,,,,,,,,, -69800,0.92751175,4.354225,,,,,,,,,,,,,, -69900,1.0929469,3.588341,,,,,,,,,,,,,, -70000,1.5205513,3.385963,,,,,,,,,,,,,, -70100,1.1502838,3.4716415,,,,,,,,,,,,,, -70200,1.0533097,3.7408478,,,,,,,,,,,,,, -70300,0.89010364,4.4139323,,,,,,,,,,,,,, -70400,1.1087947,3.5132947,,,,,,,,,,,,,, -70500,1.2098922,3.3036966,,,,,,,,,,,,,, -70571,,,0.5066015720367432,2.156141519546509,0.4574399888515472,2.423490285873413,50000.0,0.362600028514862,3.076205253601074,10000.0,32386.34180402756,36222.48383450508,32386.34180402756,3829.8283665180206,2.6060354709625244,0.0 -70600,1.0256675,3.4704971,,,,,,,,,,,,,, -70700,1.1288921,3.4255562,,,,,,,,,,,,,, -70800,1.1075569,3.2853749,,,,,,,,,,,,,, -70900,0.8943685,4.654378,,,,,,,,,,,,,, -71000,0.9787409,5.7806373,,,,,,,,,,,,,, -71100,1.1737251,3.3855667,,,,,,,,,,,,,, -71200,1.2731515,3.5602531,,,,,,,,,,,,,, -71300,1.2206081,3.5206814,,,,,,,,,,,,,, -71400,0.93797,4.4025555,,,,,,,,,,,,,, -71488,,,0.4865429699420929,2.26175856590271,0.4581999778747558,2.418018102645874,50000.0,0.3591000139713287,3.0786125659942627,10000.0,32806.55907249451,36692.19506430626,32806.55907249451,3879.22545838356,2.655019521713257,0.0 -71500,0.87344134,3.6226716,,,,,,,,,,,,,, -71600,0.9880126,4.3539753,,,,,,,,,,,,,, -71700,1.0617193,3.233439,,,,,,,,,,,,,, -71800,1.0126989,5.0887866,,,,,,,,,,,,,, -71900,1.4851303,3.2981205,,,,,,,,,,,,,, -72000,0.972248,3.6495519,,,,,,,,,,,,,, -72100,0.8138961,5.5479507,,,,,,,,,,,,,, -72200,1.2610456,3.363789,,,,,,,,,,,,,, -72300,1.0007215,3.9578674,,,,,,,,,,,,,, -72400,0.8676603,4.970248,,,,,,,,,,,,,, -72406,,,0.4870898425579071,2.2414968013763428,0.4551199972629547,2.416773796081543,50000.0,0.3544000089168548,3.063166618347168,10000.0,33226.725727796555,37162.81544137001,33226.725727796555,3929.594414949417,2.6921746730804443,0.0 -72500,1.0095212,3.2633877,,,,,,,,,,,,,, -72600,0.9730206,4.231148,,,,,,,,,,,,,, -72700,1.041008,3.526145,,,,,,,,,,,,,, -72800,0.932608,5.2686043,,,,,,,,,,,,,, -72900,0.8053775,5.695489,,,,,,,,,,,,,, -73000,0.9221475,4.6753573,,,,,,,,,,,,,, -73100,1.3364472,3.3502426,,,,,,,,,,,,,, -73200,0.8273032,5.7428207,,,,,,,,,,,,,, -73300,1.1159718,3.4904113,,,,,,,,,,,,,, -73322,,,0.4894921779632568,2.2488605976104736,0.4532599747180938,2.4476559162139893,50000.0,0.3473000228404999,3.1173243522644043,10000.0,33647.000801324844,37633.78448009491,33647.000801324844,3980.205307483673,2.726961135864258,0.0 -73400,1.1331606,3.4660535,,,,,,,,,,,,,, -73500,1.1117059,3.4264996,,,,,,,,,,,,,, -73600,1.0571918,3.342888,,,,,,,,,,,,,, -73700,1.0741816,3.5455403,,,,,,,,,,,,,, -73800,1.3647274,3.3589616,,,,,,,,,,,,,, -73900,0.9094315,5.5703955,,,,,,,,,,,,,, -74000,1.0656928,3.3295162,,,,,,,,,,,,,, -74100,1.1140327,5.863957,,,,,,,,,,,,,, -74200,1.2157888,3.4252565,,,,,,,,,,,,,, -74239,,,0.491503894329071,2.291461944580078,0.4583199918270111,2.455028772354126,50000.0,0.3549000024795532,3.087660074234009,10000.0,34067.28379058838,38104.21386909485,34067.28379058838,4030.269171476364,2.76204776763916,0.0 -74300,1.0971551,3.2522366,,,,,,,,,,,,,, -74400,0.95186436,4.6432633,,,,,,,,,,,,,, -74500,1.0624573,3.2608726,,,,,,,,,,,,,, -74600,0.9765137,3.7127044,,,,,,,,,,,,,, -74700,0.9789596,5.751251,,,,,,,,,,,,,, -74800,1.1233805,3.2460265,,,,,,,,,,,,,, -74900,0.90531754,4.8263626,,,,,,,,,,,,,, -75000,0.82120436,5.3608284,,,,,,,,,,,,,, -75100,1.0218801,3.3398664,,,,,,,,,,,,,, -75158,,,0.5007616877555847,2.170815944671631,0.4699599742889404,2.350716114044189,50000.0,0.3680000305175781,3.005544662475586,10000.0,34487.50742006302,38575.3131480217,34487.50742006302,4081.058485507965,2.7998476028442383,0.0 -75200,1.1087509,3.2202706,,,,,,,,,,,,,, -75300,1.3544569,3.3590717,,,,,,,,,,,,,, -75400,1.204404,3.3857317,,,,,,,,,,,,,, -75500,1.0497661,3.2357452,,,,,,,,,,,,,, -75600,0.93637526,3.8237848,,,,,,,,,,,,,, -75700,0.86409885,4.4182305,,,,,,,,,,,,,, -75800,1.0761917,3.433669,,,,,,,,,,,,,, -75900,1.2978286,3.2371457,,,,,,,,,,,,,, -76000,1.1130607,3.5539508,,,,,,,,,,,,,, -76076,,,0.5032812356948853,2.190779447555542,0.464819997549057,2.4005751609802246,50000.0,0.3609000146389007,3.055748462677002,10000.0,34907.619475364685,39046.77811384201,34907.619475364685,4132.322255373001,2.8405230045318604,0.0 -76100,1.1103475,3.2208533,,,,,,,,,,,,,, -76200,1.0958915,3.327836,,,,,,,,,,,,,, -76300,1.2276622,3.4426932,,,,,,,,,,,,,, -76400,0.8689346,5.28126,,,,,,,,,,,,,, -76500,0.86471367,4.6523986,,,,,,,,,,,,,, -76600,1.2396327,5.7184057,,,,,,,,,,,,,, -76700,1.1546191,3.3945327,,,,,,,,,,,,,, -76800,0.8295889,4.613745,,,,,,,,,,,,,, -76900,0.97382814,5.350559,,,,,,,,,,,,,, -76992,,,0.5121288895606995,2.185343027114868,0.4690199792385101,2.402076244354248,50000.0,0.3664000034332275,3.04853892326355,10000.0,35327.92823219299,39516.02749085426,35327.92823219299,4181.173963069916,2.8811960220336914,0.0 -77000,1.4623944,3.3831306,,,,,,,,,,,,,, -77100,1.0568027,3.173308,,,,,,,,,,,,,, -77200,1.1462268,3.1985219,,,,,,,,,,,,,, -77300,1.1163427,3.285587,,,,,,,,,,,,,, -77400,1.0257783,3.4072466,,,,,,,,,,,,,, -77500,1.0953596,3.2225823,,,,,,,,,,,,,, -77600,0.9867946,3.5639615,,,,,,,,,,,,,, -77700,1.0846701,3.3757899,,,,,,,,,,,,,, -77800,0.93588555,5.587442,,,,,,,,,,,,,, -77900,0.9761745,3.7834358,,,,,,,,,,,,,, -77907,,,0.5077343583106995,2.158487558364868,0.4768199920654297,2.3143656253814697,50000.0,0.3689000308513641,2.9804775714874268,10000.0,35747.85677528381,39984.73551940918,35747.85677528381,4229.867544412613,2.919562578201294,0.0 -78000,1.1380363,3.2131517,,,,,,,,,,,,,, -78100,1.081739,3.2523048,,,,,,,,,,,,,, -78200,1.1503662,3.4525785,,,,,,,,,,,,,, -78300,0.77574074,4.7387557,,,,,,,,,,,,,, -78400,1.0332477,3.1711285,,,,,,,,,,,,,, -78500,1.2521929,3.2800488,,,,,,,,,,,,,, -78600,1.021204,3.7522106,,,,,,,,,,,,,, -78700,1.1531315,3.2006373,,,,,,,,,,,,,, -78800,0.82304734,5.3827477,,,,,,,,,,,,,, -78822,,,0.5114843845367432,2.1231250762939453,0.4710799753665924,2.3353586196899414,50000.0,0.3671000301837921,3.0031578540802,10000.0,36168.04482078552,40460.50757360458,36168.04482078552,4285.362320184708,2.961005449295044,0.0 -78900,0.845028,5.6611757,,,,,,,,,,,,,, -79000,0.8881687,4.603716,,,,,,,,,,,,,, -79100,1.1036868,3.3372116,,,,,,,,,,,,,, -79200,0.7685314,5.640656,,,,,,,,,,,,,, -79300,1.0837865,3.3536577,,,,,,,,,,,,,, -79400,0.98165965,3.2144282,,,,,,,,,,,,,, -79500,1.2038735,3.1905997,,,,,,,,,,,,,, -79600,0.8414791,4.913557,,,,,,,,,,,,,, -79700,0.83490855,5.5667934,,,,,,,,,,,,,, -79740,,,0.5294140577316284,2.067649602890014,0.465859979391098,2.3905460834503174,50000.0,0.3686000108718872,3.0116066932678223,10000.0,36587.97202754021,40929.57435941696,36587.97202754021,4334.415132522583,3.0005173683166504,0.0 -79800,1.0554793,3.4318213,,,,,,,,,,,,,, -79900,1.0025249,3.2633731,,,,,,,,,,,,,, -80000,1.1337855,3.2185476,,,,,,,,,,,,,, -80100,1.4429392,3.4772403,,,,,,,,,,,,,, -80200,0.8801783,3.9571276,,,,,,,,,,,,,, -80300,1.4040271,3.3664405,,,,,,,,,,,,,, -80400,0.86319244,4.7067,,,,,,,,,,,,,, -80500,0.88179463,5.488451,,,,,,,,,,,,,, -80600,1.1385493,3.2473273,,,,,,,,,,,,,, -80656,,,0.4966796636581421,2.233603954315185,0.4697999954223633,2.378688335418701,50000.0,0.3669000267982483,3.037405252456665,10000.0,37007.93005943298,41400.32828879357,37007.93005943298,4385.123164892197,3.0408191680908203,0.0 -80700,1.3738483,3.4414237,,,,,,,,,,,,,, -80800,0.92390114,4.44893,,,,,,,,,,,,,, -80900,1.0625238,3.113422,,,,,,,,,,,,,, -81000,1.085368,3.5981853,,,,,,,,,,,,,, -81100,1.8548467,3.124561,,,,,,,,,,,,,, -81200,0.97008663,5.5219593,,,,,,,,,,,,,, -81300,1.2615479,3.1869326,,,,,,,,,,,,,, -81400,1.2213075,3.2747881,,,,,,,,,,,,,, -81500,1.1839548,3.4089427,,,,,,,,,,,,,, -81572,,,0.512890636920929,2.137596368789673,0.4765599966049194,2.3293349742889404,50000.0,0.3721000254154205,2.969887018203736,10000.0,37427.89789867401,41870.229408979416,37427.89789867401,4434.969736337662,3.080196142196656,0.0 -81600,0.9760924,5.0453515,,,,,,,,,,,,,, -81700,1.2166779,5.791533,,,,,,,,,,,,,, -81800,0.7296898,5.5593586,,,,,,,,,,,,,, -81900,1.0357741,3.2766004,,,,,,,,,,,,,, -82000,1.1578333,5.0891504,,,,,,,,,,,,,, -82100,1.3044312,3.352506,,,,,,,,,,,,,, -82200,1.1951613,5.7819977,,,,,,,,,,,,,, -82300,1.0836564,3.2718546,,,,,,,,,,,,,, -82400,1.1624908,3.1946049,,,,,,,,,,,,,, -82491,,,0.531054675579071,2.0348589420318604,0.4846400022506714,2.283027648925781,50000.0,0.3731000125408172,2.9452946186065674,10000.0,37847.978172302246,42340.52619123459,37847.978172302246,4485.095364332199,3.1229937076568604,0.0 -82500,1.0561078,4.058468,,,,,,,,,,,,,, -82600,0.9524169,5.237011,,,,,,,,,,,,,, -82700,1.0288637,3.1125941,,,,,,,,,,,,,, -82800,1.1767949,3.279204,,,,,,,,,,,,,, -82900,1.0354326,3.1641378,,,,,,,,,,,,,, -83000,0.88831866,4.3189874,,,,,,,,,,,,,, -83100,1.1931418,3.1959066,,,,,,,,,,,,,, -83200,1.0403935,3.3169763,,,,,,,,,,,,,, -83300,0.78555375,5.430136,,,,,,,,,,,,,, -83400,1.3519591,5.7027817,,,,,,,,,,,,,, -83408,,,0.519726574420929,2.083502531051636,0.4893999993801117,2.2483296394348145,50000.0,0.3836000263690948,2.9029996395111084,10000.0,38268.08063173294,42808.46298861504,38268.08063173294,4532.843173027039,3.1619906425476074,0.0 -83500,1.2220076,3.1823635,,,,,,,,,,,,,, -83600,1.0018239,3.9840882,,,,,,,,,,,,,, -83700,0.83590484,5.0494804,,,,,,,,,,,,,, -83800,1.113183,3.501216,,,,,,,,,,,,,, -83900,1.0964537,3.1566167,,,,,,,,,,,,,, -84000,1.1440064,3.6312504,,,,,,,,,,,,,, -84100,1.2567225,3.2522762,,,,,,,,,,,,,, -84200,1.0460119,3.134544,,,,,,,,,,,,,, -84300,0.9325708,4.305947,,,,,,,,,,,,,, -84324,,,0.5274999737739563,2.044956684112549,0.4887399971485138,2.2427096366882324,50000.0,0.3802000284194946,2.933024883270264,10000.0,38688.25997066498,43277.33517932892,38688.25997066498,4581.449735164642,3.20169997215271,0.0 -84400,1.0140475,3.3001332,,,,,,,,,,,,,, -84500,1.073227,3.318574,,,,,,,,,,,,,, -84600,1.1399267,3.3868673,,,,,,,,,,,,,, -84700,1.2483506,3.2465982,,,,,,,,,,,,,, -84800,1.1076732,3.1679251,,,,,,,,,,,,,, -84900,1.0430672,5.588085,,,,,,,,,,,,,, -85000,1.1048652,3.4883947,,,,,,,,,,,,,, -85100,1.1222105,3.2280855,,,,,,,,,,,,,, -85200,1.0862647,3.1229699,,,,,,,,,,,,,, -85241,,,0.5252148509025574,2.1249876022338867,0.4815999865531921,2.3384721279144287,50000.0,0.3765000104904175,2.978649139404297,10000.0,39108.579641819,43747.66602993012,39108.579641819,4631.371387481689,3.2427725791931152,0.0 -85300,0.81944674,5.5719686,,,,,,,,,,,,,, -85400,1.1249624,3.1473548,,,,,,,,,,,,,, -85500,1.1134292,3.293788,,,,,,,,,,,,,, -85600,1.1291646,3.152142,,,,,,,,,,,,,, -85700,1.2231907,3.2597928,,,,,,,,,,,,,, -85800,1.0574226,3.3848004,,,,,,,,,,,,,, -85900,1.1441596,3.2740097,,,,,,,,,,,,,, -86000,1.2047791,3.202754,,,,,,,,,,,,,, -86100,1.2038206,3.478154,,,,,,,,,,,,,, -86158,,,0.515820324420929,2.156745672225952,0.4821399748325348,2.332691192626953,50000.0,0.3702000081539154,2.974332332611084,10000.0,39528.60836338997,44221.15460038185,39528.60836338997,4684.744953393936,3.281177997589112,0.0 -86200,0.90574265,5.438243,,,,,,,,,,,,,, -86300,1.1186787,5.0867453,,,,,,,,,,,,,, -86400,0.9029755,5.496128,,,,,,,,,,,,,, -86500,1.1039636,3.033525,,,,,,,,,,,,,, -86600,0.92189276,4.083867,,,,,,,,,,,,,, -86700,1.0564897,3.1617975,,,,,,,,,,,,,, -86800,1.1312412,3.6582162,,,,,,,,,,,,,, -86900,1.1361146,3.9267867,,,,,,,,,,,,,, -87000,1.1040546,3.307971,,,,,,,,,,,,,, -87075,,,0.5296288728713989,2.0581154823303223,0.4989999830722809,2.215408563613892,50000.0,0.3924000263214111,2.8767480850219727,10000.0,39948.69438958168,44690.825795173645,39948.69438958168,4734.239243745804,3.324136257171631,0.0 -87100,1.0270679,3.1578343,,,,,,,,,,,,,, -87200,1.1782578,3.1027956,,,,,,,,,,,,,, -87300,1.0639706,3.4089463,,,,,,,,,,,,,, -87400,1.1513704,3.1866493,,,,,,,,,,,,,, -87500,1.1828914,3.1308885,,,,,,,,,,,,,, -87600,0.88460237,4.4647713,,,,,,,,,,,,,, -87700,0.9650232,4.3776126,,,,,,,,,,,,,, -87800,0.9578976,4.7413383,,,,,,,,,,,,,, -87900,1.0709974,3.123575,,,,,,,,,,,,,, -87991,,,0.5358788967132568,1.9901221990585327,0.5009599924087524,2.1941733360290527,50000.0,0.3921000063419342,2.863675594329834,10000.0,40368.87440466881,45162.857728004456,40368.87440466881,4786.001025438309,3.3661766052246094,0.0 -88000,1.0438669,4.6689687,,,,,,,,,,,,,, -88100,1.3486449,3.2629747,,,,,,,,,,,,,, -88200,0.84780604,5.2512484,,,,,,,,,,,,,, -88300,1.2790518,3.1730623,,,,,,,,,,,,,, -88400,0.8452546,5.4085436,,,,,,,,,,,,,, -88500,0.94495624,4.1991763,,,,,,,,,,,,,, -88600,1.0500053,3.1763592,,,,,,,,,,,,,, -88700,1.2128519,3.1183207,,,,,,,,,,,,,, -88800,1.1469842,3.2353234,,,,,,,,,,,,,, -88900,1.0261341,3.400574,,,,,,,,,,,,,, -88907,,,0.5342382788658142,2.0812058448791504,0.4899799823760986,2.29469895362854,50000.0,0.3799000084400177,2.949227809906006,10000.0,40788.85287356377,45632.11156868935,40788.85287356377,4835.186659097672,3.407886266708374,0.0 -89000,1.2759364,3.126399,,,,,,,,,,,,,, -89100,1.1223613,3.295693,,,,,,,,,,,,,, -89200,1.0684772,3.3977048,,,,,,,,,,,,,, -89300,1.0494411,3.5052578,,,,,,,,,,,,,, -89400,0.88888454,4.9912806,,,,,,,,,,,,,, -89500,1.0620012,4.2293534,,,,,,,,,,,,,, -89600,1.1335403,3.1849809,,,,,,,,,,,,,, -89700,1.0203111,3.432192,,,,,,,,,,,,,, -89800,1.0140308,5.422461,,,,,,,,,,,,,, -89821,,,0.5366601347923279,2.041024923324585,0.5012800097465515,2.20486831665039,50000.0,0.394400030374527,2.8692476749420166,10000.0,41209.05537772179,46101.12329792976,41209.05537772179,4883.906123161316,3.450183391571045,0.0 -89900,1.1494128,2.971726,,,,,,,,,,,,,, -90000,1.3229262,3.1810977,,,,,,,,,,,,,, -90100,0.8395103,5.4962225,,,,,,,,,,,,,, -90200,0.9553162,3.8288736,,,,,,,,,,,,,, -90300,1.0111576,3.7379463,,,,,,,,,,,,,, -90400,1.3403822,3.1356058,,,,,,,,,,,,,, -90500,1.0391617,3.1122923,,,,,,,,,,,,,, -90600,1.0874072,4.525238,,,,,,,,,,,,,, -90700,1.1869161,3.04081,,,,,,,,,,,,,, -90736,,,0.5470117330551147,1.968786239624024,0.5073800086975098,2.1694068908691406,50000.0,0.398900032043457,2.8334524631500244,10000.0,41629.10823345184,46570.9590845108,41629.10823345184,4933.602746248245,3.4877588748931885,0.0 -90800,1.0591363,3.5862238,,,,,,,,,,,,,, -90900,1.1687934,3.3208034,,,,,,,,,,,,,, -91000,1.115602,3.9243772,,,,,,,,,,,,,, -91100,1.2706758,3.1150703,,,,,,,,,,,,,, -91200,1.1384282,2.9819634,,,,,,,,,,,,,, -91300,1.0714427,3.4001083,,,,,,,,,,,,,, -91400,1.2958065,2.9764004,,,,,,,,,,,,,, -91500,0.92338985,5.2563434,,,,,,,,,,,,,, -91600,1.1922892,3.0820198,,,,,,,,,,,,,, -91649,,,0.5741991996765137,1.818586349487305,0.5086199641227722,2.1520354747772217,50000.0,0.3973000049591064,2.799119234085083,10000.0,42049.48864984512,47041.567656993866,42049.48864984512,4983.74448466301,3.526984214782715,0.0 -91700,1.1937941,3.1155682,,,,,,,,,,,,,, -91800,0.8257889,5.241748,,,,,,,,,,,,,, -91900,1.0652498,3.3650343,,,,,,,,,,,,,, -92000,1.2571871,3.2841823,,,,,,,,,,,,,, -92100,1.0630411,5.319709,,,,,,,,,,,,,, -92200,1.1719563,3.170908,,,,,,,,,,,,,, -92300,0.92577785,5.4631553,,,,,,,,,,,,,, -92400,0.93017083,4.735108,,,,,,,,,,,,,, -92500,0.88490504,4.3455114,,,,,,,,,,,,,, -92566,,,0.5272265672683716,2.088000059127808,0.5004400014877319,2.219738483428955,50000.0,0.3943000137805938,2.86732816696167,10000.0,42469.72681379318,47512.32449412346,42469.72681379318,5034.174078941345,3.568467617034912,0.0 -92600,0.98952687,3.085132,,,,,,,,,,,,,, -92700,0.88007575,5.091714,,,,,,,,,,,,,, -92800,1.202318,3.209047,,,,,,,,,,,,,, -92900,0.98364294,4.2218223,,,,,,,,,,,,,, -93000,1.0808035,3.9154809,,,,,,,,,,,,,, -93100,1.4494553,3.1102707,,,,,,,,,,,,,, -93200,1.0307426,3.8108416,,,,,,,,,,,,,, -93300,1.0709541,3.1411216,,,,,,,,,,,,,, -93400,0.97126544,3.7505393,,,,,,,,,,,,,, -93480,,,0.5433398485183716,2.01934814453125,0.5062999725341797,2.205066442489624,50000.0,0.393200010061264,2.861738681793213,10000.0,42889.70433783531,47982.05200815201,42889.70433783531,5083.820999860764,3.623884439468384,0.0 -93500,1.0345905,4.1042256,,,,,,,,,,,,,, -93600,1.2099334,3.3727164,,,,,,,,,,,,,, -93700,1.0818292,3.122455,,,,,,,,,,,,,, -93800,1.164272,3.2174666,,,,,,,,,,,,,, -93900,0.98507774,4.501733,,,,,,,,,,,,,, -94000,1.2578992,2.992657,,,,,,,,,,,,,, -94100,1.0460669,5.45409,,,,,,,,,,,,,, -94200,1.0397733,5.1903667,,,,,,,,,,,,,, -94300,0.8826734,5.2193656,,,,,,,,,,,,,, -94396,,,0.5606250166893005,1.9081244468688965,0.5089399814605713,2.156013965606689,50000.0,0.3976000249385834,2.820587158203125,10000.0,43309.709506988525,48453.16395688057,43309.709506988525,5134.835416078568,3.667936563491821,0.0 -94400,1.1654152,3.2217925,,,,,,,,,,,,,, -94500,1.143763,3.395741,,,,,,,,,,,,,, -94600,1.1694304,3.1318035,,,,,,,,,,,,,, -94700,1.1177197,3.3685584,,,,,,,,,,,,,, -94800,1.1348985,3.294805,,,,,,,,,,,,,, -94900,1.0782725,5.525538,,,,,,,,,,,,,, -95000,0.95082176,4.524354,,,,,,,,,,,,,, -95100,1.2437286,3.0374732,,,,,,,,,,,,,, -95200,1.1203883,3.0683656,,,,,,,,,,,,,, -95300,1.2565538,3.1311355,,,,,,,,,,,,,, -95311,,,0.5488671660423279,1.9369208812713623,0.5136199593544006,2.113115072250366,50000.0,0.4028000235557556,2.7761130332946777,10000.0,43729.9824256897,48924.271253585815,43729.9824256897,5185.579132556915,3.7109029293060303,0.0 -95400,1.0792408,5.168098,,,,,,,,,,,,,, -95500,0.93598354,5.52093,,,,,,,,,,,,,, -95600,1.1716644,3.1559849,,,,,,,,,,,,,, -95700,0.9506558,3.8500142,,,,,,,,,,,,,, -95800,0.9136772,5.274949,,,,,,,,,,,,,, -95900,1.2033626,3.3372102,,,,,,,,,,,,,, -96000,1.1882434,3.035385,,,,,,,,,,,,,, -96100,0.9197937,5.365866,,,,,,,,,,,,,, -96200,1.2181766,3.4410467,,,,,,,,,,,,,, -96228,,,0.5465624928474426,1.977414846420288,0.5113599896430969,2.1627683639526367,50000.0,0.4028000235557556,2.821324825286865,10000.0,44150.20173883438,49392.18372750282,44150.20173883438,5233.184512376785,3.750455856323242,0.0 -96300,1.2735866,3.0072389,,,,,,,,,,,,,, -96400,0.8352648,5.481724,,,,,,,,,,,,,, -96500,1.1615863,3.0441256,,,,,,,,,,,,,, -96600,1.1929334,3.0098088,,,,,,,,,,,,,, -96700,1.0778214,2.8638904,,,,,,,,,,,,,, -96800,1.253605,3.0031888,,,,,,,,,,,,,, -96900,1.1376696,3.126567,,,,,,,,,,,,,, -97000,1.0244328,3.183669,,,,,,,,,,,,,, -97100,1.2314014,5.4560018,,,,,,,,,,,,,, -97146,,,0.5612499713897705,1.8850711584091189,0.5123400092124939,2.129922866821289,50000.0,0.4019000232219696,2.7724106311798096,10000.0,44570.48270201683,49862.24414396286,44570.48270201683,5282.874450683594,3.79244875907898,0.0 -97200,1.0208889,3.727046,,,,,,,,,,,,,, -97300,1.1474582,5.055504,,,,,,,,,,,,,, -97400,1.0105377,5.4257283,,,,,,,,,,,,,, -97500,1.1708522,3.2764697,,,,,,,,,,,,,, -97600,1.1257075,3.0584793,,,,,,,,,,,,,, -97700,1.0644221,3.271185,,,,,,,,,,,,,, -97800,1.0257204,4.4330807,,,,,,,,,,,,,, -97900,1.2209101,3.011612,,,,,,,,,,,,,, -98000,1.3043158,3.0688007,,,,,,,,,,,,,, -98058,,,0.5537304282188416,1.9444429874420168,0.524179995059967,2.098775148391724,50000.0,0.4170000255107879,2.7410056591033936,10000.0,44990.81266307831,50328.74145245552,44990.81266307831,5328.955847263336,3.831274509429932,0.0 -98100,1.2284012,3.2013645,,,,,,,,,,,,,, -98200,1.2646568,2.9757907,,,,,,,,,,,,,, -98300,1.135831,3.1345851,,,,,,,,,,,,,, -98400,0.8838612,5.435555,,,,,,,,,,,,,, -98500,1.0661091,5.495361,,,,,,,,,,,,,, -98600,1.0225573,5.4109526,,,,,,,,,,,,,, -98700,1.1819028,2.8918035,,,,,,,,,,,,,, -98800,1.1742082,2.9612553,,,,,,,,,,,,,, -98900,1.156512,3.1134105,,,,,,,,,,,,,, -98972,,,0.56494140625,1.8888856172561648,0.5273999571800232,2.079842090606689,50000.0,0.4125000238418579,2.7286581993103027,10000.0,45410.88789892197,50799.61802506447,45410.88789892197,5379.665320634842,3.875601530075073,0.0 -99000,0.9908819,5.4305778,,,,,,,,,,,,,, -99100,1.202619,3.0391216,,,,,,,,,,,,,, -99200,1.0351211,5.310262,,,,,,,,,,,,,, -99300,1.2022609,3.014783,,,,,,,,,,,,,, -99400,1.140535,2.8869514,,,,,,,,,,,,,, -99500,0.8892379,5.2818885,,,,,,,,,,,,,, -99600,1.1225004,3.4164767,,,,,,,,,,,,,, -99700,0.89415604,4.656433,,,,,,,,,,,,,, -99800,1.2055491,4.020897,,,,,,,,,,,,,, -99889,,,0.5637304782867432,1.909096121788025,0.5228599905967712,2.1150269508361816,50000.0,0.4125000238418579,2.748320817947388,10000.0,45831.25173521042,51269.41478252411,45831.25173521042,5429.00506401062,3.9204533100128174,0.0 -99900,0.80967754,4.779552,,,,,,,,,,,,,, -100000,1.1578709,2.784881,,,,,,,,,,,,,, -100100,1.019366,5.434597,,,,,,,,,,,,,, -100200,1.0700574,5.1734037,,,,,,,,,,,,,, -100300,1.0057791,4.798854,,,,,,,,,,,,,, -100400,1.1789714,3.0309904,,,,,,,,,,,,,, -100500,1.2561305,2.8980908,,,,,,,,,,,,,, -100600,0.9672581,4.0192714,,,,,,,,,,,,,, -100700,1.1056792,2.8469517,,,,,,,,,,,,,, -100800,1.2617713,2.9582686,,,,,,,,,,,,,, -100805,,,0.5720898509025574,1.8558332920074463,0.532260000705719,2.0450856685638428,50000.0,0.42330002784729,2.701802968978882,10000.0,46251.20902395248,51737.73347878456,46251.20902395248,5477.274906158447,3.9645845890045166,0.0 -100900,0.86683,4.822637,,,,,,,,,,,,,, -101000,1.2240617,2.822279,,,,,,,,,,,,,, -101100,1.0622195,3.600254,,,,,,,,,,,,,, -101200,1.3167148,3.2626138,,,,,,,,,,,,,, -101300,1.3536017,2.880283,,,,,,,,,,,,,, -101400,1.1298516,5.5003347,,,,,,,,,,,,,, -101500,1.2484349,2.956373,,,,,,,,,,,,,, -101600,1.0123483,4.373563,,,,,,,,,,,,,, -101700,1.2257005,3.1546726,,,,,,,,,,,,,, -101721,,,0.5673046708106995,1.848193645477295,0.5283600091934204,2.050722360610962,50000.0,0.4095000326633453,2.745497703552246,10000.0,46671.40980911255,52206.08445620537,46671.40980911255,5525.338563919067,4.004071235656738,0.0 -101800,0.97360766,4.7857265,,,,,,,,,,,,,, -101900,1.1745924,3.071049,,,,,,,,,,,,,, -102000,1.2014266,2.9482434,,,,,,,,,,,,,, -102100,1.1621768,4.054964,,,,,,,,,,,,,, -102200,0.9650763,3.9200456,,,,,,,,,,,,,, -102300,1.1215582,3.6578872,,,,,,,,,,,,,, -102400,1.0319828,3.6769686,,,,,,,,,,,,,, -102500,1.2271676,2.880997,,,,,,,,,,,,,, -102600,1.2325419,2.8791394,,,,,,,,,,,,,, -102640,,,0.5785741806030273,1.818680047988892,0.5371400117874146,2.024113655090332,50000.0,0.4227000176906585,2.689002752304077,10000.0,47091.384001493454,52675.02685260773,47091.384001493454,5574.214259624481,4.048406600952148,0.0 -102700,1.2614453,3.2480218,,,,,,,,,,,,,, -102800,0.88376176,5.4111547,,,,,,,,,,,,,, -102900,0.93731683,5.3701735,,,,,,,,,,,,,, -103000,0.95893365,5.2113743,,,,,,,,,,,,,, -103100,1.2682859,2.880889,,,,,,,,,,,,,, -103200,1.140404,4.306112,,,,,,,,,,,,,, -103300,1.2809054,2.80668,,,,,,,,,,,,,, -103400,1.1589819,3.0131927,,,,,,,,,,,,,, -103500,0.98375344,3.6796865,,,,,,,,,,,,,, -103558,,,0.5989648103713989,1.7452768087387085,0.5286200046539307,2.076588153839112,50000.0,0.41880002617836,2.720669269561768,10000.0,47511.33321595192,53143.97084522247,47511.33321595192,5623.107983827591,4.101131916046143,0.0 -103600,1.3987923,2.7621071,,,,,,,,,,,,,, -103700,1.1154834,2.814837,,,,,,,,,,,,,, -103800,1.0858209,5.235933,,,,,,,,,,,,,, -103900,1.009174,4.651586,,,,,,,,,,,,,, -104000,1.6202824,2.8831418,,,,,,,,,,,,,, -104100,1.1778735,3.174378,,,,,,,,,,,,,, -104200,1.0835199,5.1456685,,,,,,,,,,,,,, -104300,1.254723,2.8846323,,,,,,,,,,,,,, -104400,1.1789608,2.7017617,,,,,,,,,,,,,, -104477,,,0.5718554854393005,1.8921626806259155,0.5343199968338013,2.075522184371948,50000.0,0.4228000342845917,2.7130990028381348,10000.0,47931.56638360024,53613.40177822113,47931.56638360024,5672.216294765472,4.143033742904663,0.0 -104500,1.339537,2.7562742,,,,,,,,,,,,,, -104600,1.1085774,3.7372632,,,,,,,,,,,,,, -104700,1.303056,2.908227,,,,,,,,,,,,,, -104800,1.0716108,3.8611114,,,,,,,,,,,,,, -104900,1.2122717,2.8643274,,,,,,,,,,,,,, -105000,1.2679715,2.8521056,,,,,,,,,,,,,, -105100,1.2825754,2.8678427,,,,,,,,,,,,,, -105200,0.9379612,4.680742,,,,,,,,,,,,,, -105300,1.2352175,2.8874636,,,,,,,,,,,,,, -105394,,,0.5820507407188416,1.837495803833008,0.5375800132751465,2.045927047729492,50000.0,0.4191000163555145,2.7071774005889893,10000.0,48351.92719125748,54085.40121340752,48351.92719125748,5723.759583473206,4.190583944320679,0.0 -105400,1.019332,3.733363,,,,,,,,,,,,,, -105500,1.2361661,2.9178805,,,,,,,,,,,,,, -105600,1.0434679,4.640808,,,,,,,,,,,,,, -105700,1.2516495,3.16505,,,,,,,,,,,,,, -105800,1.2783761,2.7574341,,,,,,,,,,,,,, -105900,1.0783963,4.2665863,,,,,,,,,,,,,, -106000,1.0601349,4.214996,,,,,,,,,,,,,, -106100,1.191478,3.1403031,,,,,,,,,,,,,, -106200,1.3016824,2.814723,,,,,,,,,,,,,, -106300,1.1706418,3.0862224,,,,,,,,,,,,,, -106312,,,0.5904687643051147,1.7310149669647217,0.5409199595451355,1.9848403930664065,50000.0,0.4317000210285187,2.648791074752808,10000.0,48772.00656723976,54556.38950634003,48772.00656723976,5774.5703637599945,4.240931749343872,0.0 -106400,1.12936,5.386484,,,,,,,,,,,,,, -106500,1.2670435,2.8540907,,,,,,,,,,,,,, -106600,1.3204277,2.862022,,,,,,,,,,,,,, -106700,1.2959665,2.8421905,,,,,,,,,,,,,, -106800,1.1497179,3.409346,,,,,,,,,,,,,, -106900,0.94993293,5.253581,,,,,,,,,,,,,, -107000,1.1165495,2.9344568,,,,,,,,,,,,,, -107100,1.1545541,4.301124,,,,,,,,,,,,,, -107200,1.0355955,3.9380193,,,,,,,,,,,,,, -107228,,,0.5759570002555847,1.833008050918579,0.537880003452301,2.0171401500701904,50000.0,0.425100028514862,2.690009593963623,10000.0,49192.05952215195,55025.439949035645,49192.05952215195,5823.475866556168,4.285040616989136,0.0 -107300,0.9594739,4.9569273,,,,,,,,,,,,,, -107400,1.1775261,3.1087186,,,,,,,,,,,,,, -107500,1.3139807,2.861394,,,,,,,,,,,,,, -107600,1.1995883,2.846101,,,,,,,,,,,,,, -107700,1.1113491,3.40741,,,,,,,,,,,,,, -107800,1.220098,5.545696,,,,,,,,,,,,,, -107900,0.8540976,4.9880986,,,,,,,,,,,,,, -108000,1.2984045,2.8448555,,,,,,,,,,,,,, -108100,1.0505313,3.5480642,,,,,,,,,,,,,, -108146,,,0.5908398032188416,1.7584213018417358,0.5523999929428101,1.949994444847107,50000.0,0.4388000071048736,2.611959934234619,10000.0,49612.333136081696,55495.51818156242,49612.333136081696,5873.19217467308,4.326111793518066,0.0 -108200,1.1550568,2.8150797,,,,,,,,,,,,,, -108300,1.1608646,2.709409,,,,,,,,,,,,,, -108400,1.1326727,3.7669196,,,,,,,,,,,,,, -108500,1.5107872,3.1381483,,,,,,,,,,,,,, -108600,1.2910634,2.8517964,,,,,,,,,,,,,, -108700,1.2421197,2.7844272,,,,,,,,,,,,,, -108800,1.2735033,2.783486,,,,,,,,,,,,,, -108900,1.2158557,2.957055,,,,,,,,,,,,,, -109000,1.1736392,3.083315,,,,,,,,,,,,,, -109061,,,0.6046093702316284,1.7196390628814695,0.5508399605751038,1.968966007232666,50000.0,0.4354000091552734,2.6368744373321533,10000.0,50032.38190603256,55965.41650009155,50032.38190603256,5922.949047088623,4.369261980056763,0.0 -109100,1.2398015,2.76023,,,,,,,,,,,,,, -109200,1.2488351,2.69924,,,,,,,,,,,,,, -109300,1.0662147,4.880578,,,,,,,,,,,,,, -109400,1.2777067,2.883962,,,,,,,,,,,,,, -109500,1.2639918,2.680699,,,,,,,,,,,,,, -109600,1.1395673,5.4122286,,,,,,,,,,,,,, -109700,1.0622787,3.7176898,,,,,,,,,,,,,, -109800,1.3189752,2.8291724,,,,,,,,,,,,,, -109900,1.0984601,3.2325587,,,,,,,,,,,,,, -109981,,,0.5870702862739563,1.7547787427902222,0.5485199689865112,1.9464657306671145,50000.0,0.4374000132083893,2.614704370498657,10000.0,50452.712889909744,56434.7984893322,50452.712889909744,5971.910776138306,4.410284757614136,0.0 -110000,1.1473361,3.3173878,,,,,,,,,,,,,, -110100,1.0472118,3.3478835,,,,,,,,,,,,,, -110200,1.1336884,3.039132,,,,,,,,,,,,,, -110300,1.3406203,2.8193882,,,,,,,,,,,,,, -110400,1.2988893,2.7770166,,,,,,,,,,,,,, -110500,1.0041583,4.8395224,,,,,,,,,,,,,, -110600,1.1907413,2.7092617,,,,,,,,,,,,,, -110700,1.4104198,2.7805505,,,,,,,,,,,,,, -110800,1.3045064,3.0757625,,,,,,,,,,,,,, -110896,,,0.6010546684265137,1.6854904890060425,0.556879997253418,1.8993265628814693,50000.0,0.4438000321388244,2.5637691020965576,10000.0,50872.9928958416,56906.227376937866,50872.9928958416,6022.969138383865,4.453753709793091,0.0 -110900,0.9786993,5.358777,,,,,,,,,,,,,, -111000,1.0635633,3.0994523,,,,,,,,,,,,,, -111100,1.0058866,3.9149814,,,,,,,,,,,,,, -111200,1.0080541,4.23971,,,,,,,,,,,,,, -111300,1.3045516,2.6784322,,,,,,,,,,,,,, -111400,1.0089173,4.5980206,,,,,,,,,,,,,, -111500,0.99063665,4.5629616,,,,,,,,,,,,,, -111600,1.2086998,2.9033005,,,,,,,,,,,,,, -111700,1.362229,2.7150936,,,,,,,,,,,,,, -111800,1.2535536,2.92984,,,,,,,,,,,,,, -111813,,,0.606640636920929,1.6578865051269531,0.5602999925613403,1.8957290649414065,50000.0,0.4487000107765198,2.552631378173828,10000.0,51293.23732328415,57376.957426548,51293.23732328415,6073.349613189697,4.510779142379761,0.0 -111900,1.1202179,3.2140336,,,,,,,,,,,,,, -112000,0.97536314,4.833217,,,,,,,,,,,,,, -112100,1.025683,4.458943,,,,,,,,,,,,,, -112200,1.4140924,2.7685592,,,,,,,,,,,,,, -112300,1.1858792,3.1981003,,,,,,,,,,,,,, -112400,1.1337469,2.866878,,,,,,,,,,,,,, -112500,1.3953861,2.8558278,,,,,,,,,,,,,, -112600,1.2391603,3.609572,,,,,,,,,,,,,, -112700,1.2260983,3.1086204,,,,,,,,,,,,,, -112731,,,0.6123046875,1.6747193336486816,0.5610399842262268,1.8994790315628047,50000.0,0.4410000145435333,2.5727462768554688,10000.0,51713.58249878883,57846.42418789864,51713.58249878883,6122.37894487381,4.555319547653198,0.0 -112800,1.4300076,2.8920677,,,,,,,,,,,,,, -112900,1.2701296,2.8746603,,,,,,,,,,,,,, -113000,1.1902853,2.646508,,,,,,,,,,,,,, -113100,1.2725261,2.7248447,,,,,,,,,,,,,, -113200,1.2494636,2.6107116,,,,,,,,,,,,,, -113300,1.3495667,2.8226113,,,,,,,,,,,,,, -113400,1.2612551,4.930781,,,,,,,,,,,,,, -113500,1.1287673,3.2738774,,,,,,,,,,,,,, -113600,1.2179263,2.8906915,,,,,,,,,,,,,, -113647,,,0.6024609208106995,1.6888262033462524,0.5626800060272217,1.876355528831482,50000.0,0.4446000158786773,2.5511436462402344,10000.0,52133.86265182495,58317.28026175499,52133.86265182495,6172.862322568893,4.599631071090698,0.0 -113700,1.2938683,2.7360668,,,,,,,,,,,,,, -113800,1.4164095,2.773395,,,,,,,,,,,,,, -113900,1.2617662,2.8669171,,,,,,,,,,,,,, -114000,1.0596323,3.7667327,,,,,,,,,,,,,, -114100,1.2632153,2.6792798,,,,,,,,,,,,,, -114200,1.310912,2.708134,,,,,,,,,,,,,, -114300,1.2322282,2.594045,,,,,,,,,,,,,, -114400,1.1398965,3.3683946,,,,,,,,,,,,,, -114500,1.336897,2.7242465,,,,,,,,,,,,,, -114565,,,0.5999413728713989,1.7239134311676023,0.5591199994087219,1.9196312427520752,50000.0,0.4391000270843506,2.578800678253174,10000.0,52554.00845098496,58785.77437710762,52554.00845098496,6221.115981340408,4.645940542221069,0.0 -114600,1.3007991,2.9523115,,,,,,,,,,,,,, -114700,1.3745948,2.747729,,,,,,,,,,,,,, -114800,1.1801713,5.294643,,,,,,,,,,,,,, -114900,1.2349265,3.1108162,,,,,,,,,,,,,, -115000,1.1030878,3.3317993,,,,,,,,,,,,,, -115100,1.4021554,2.6613078,,,,,,,,,,,,,, -115200,1.0726675,3.7897716,,,,,,,,,,,,,, -115300,1.3925544,2.7441823,,,,,,,,,,,,,, -115400,1.4511137,2.6577406,,,,,,,,,,,,,, -115483,,,0.6393163800239563,1.5017653703689575,0.5700199604034424,1.8257381916046145,50000.0,0.4597000181674957,2.50107741355896,10000.0,52974.09374523163,59255.32407426834,52974.09374523163,6270.487004041672,4.691897869110107,0.0 -115500,1.3895975,2.746747,,,,,,,,,,,,,, -115600,1.2328221,2.8213413,,,,,,,,,,,,,, -115700,1.0733362,3.7561448,,,,,,,,,,,,,, -115800,1.34639,2.6625512,,,,,,,,,,,,,, -115900,1.1786662,5.2000475,,,,,,,,,,,,,, -116000,1.4711312,2.7643764,,,,,,,,,,,,,, -116100,1.2300212,4.25914,,,,,,,,,,,,,, -116200,1.5645028,2.8394997,,,,,,,,,,,,,, -116300,1.2477081,2.7280354,,,,,,,,,,,,,, -116400,,,0.6116992235183716,1.6490353345870972,0.5755999684333801,1.843803882598877,50000.0,0.4602000117301941,2.501185655593872,10000.0,53394.46939063072,59729.46551704407,53394.46939063072,6324.160755395889,4.736301422119141,0.0 -116400,1.4901093,3.0000648,,,,,,,,,,,,,, -116500,1.1624563,3.7815022,,,,,,,,,,,,,, -116600,1.3894913,2.6548662,,,,,,,,,,,,,, -116700,1.3024157,2.656059,,,,,,,,,,,,,, -116800,1.3843817,2.6570206,,,,,,,,,,,,,, -116900,1.231146,2.7595758,,,,,,,,,,,,,, -117000,1.3548669,2.6889467,,,,,,,,,,,,,, -117100,1.31497,2.8135557,,,,,,,,,,,,,, -117200,1.1314305,3.7008724,,,,,,,,,,,,,, -117300,1.3520416,2.6145666,,,,,,,,,,,,,, -117319,,,0.6172265410423279,1.627258539199829,0.5745399594306946,1.8362805843353271,50000.0,0.4551000297069549,2.4943559169769287,10000.0,53814.67329096794,60201.44076490402,53814.67329096794,6375.838165998459,4.7819108963012695,0.0 -117400,1.2856263,5.30151,,,,,,,,,,,,,, -117500,1.1623119,4.875556,,,,,,,,,,,,,, -117600,1.107557,4.8058095,,,,,,,,,,,,,, -117700,1.3477634,2.5751872,,,,,,,,,,,,,, -117800,1.4065473,2.7039032,,,,,,,,,,,,,, -117900,1.1093916,4.854533,,,,,,,,,,,,,, -118000,1.0302739,5.2268267,,,,,,,,,,,,,, -118100,1.1593646,4.7395782,,,,,,,,,,,,,, -118200,1.1242998,3.655637,,,,,,,,,,,,,, -118236,,,0.6234374642372131,1.591255784034729,0.5689799785614014,1.854089856147766,50000.0,0.4593000113964081,2.519318103790283,10000.0,54234.85326814652,60671.578468084335,54234.85326814652,6425.695611715317,4.8346474170684814,0.0 -118300,1.1599402,3.5031078,,,,,,,,,,,,,, -118400,1.4212428,2.6948037,,,,,,,,,,,,,, -118500,1.0592276,3.9572752,,,,,,,,,,,,,, -118600,1.2728064,3.0076666,,,,,,,,,,,,,, -118700,1.5191363,2.6685464,,,,,,,,,,,,,, -118800,1.25596,2.9819832,,,,,,,,,,,,,, -118900,1.415309,2.5485744,,,,,,,,,,,,,, -119000,1.3494607,2.9102135,,,,,,,,,,,,,, -119100,1.1112486,3.3850448,,,,,,,,,,,,,, -119150,,,0.6201952695846558,1.594254732131958,0.5825200080871582,1.778774380683899,50000.0,0.4622000157833099,2.4543488025665283,10000.0,54655.12280344963,61142.06141138077,54655.12280344963,6475.814856529236,4.881854772567749,0.0 -119200,1.0881739,4.2263103,,,,,,,,,,,,,, -119300,1.3383963,2.616737,,,,,,,,,,,,,, -119400,1.0869998,3.9025269,,,,,,,,,,,,,, -119500,1.431331,2.8078218,,,,,,,,,,,,,, -119600,1.1512569,3.8762138,,,,,,,,,,,,,, -119700,1.1088576,4.8241324,,,,,,,,,,,,,, -119800,1.1343315,3.364413,,,,,,,,,,,,,, -119900,1.3398279,2.5554924,,,,,,,,,,,,,, -120000,1.3486706,2.9297595,,,,,,,,,,,,,, -120069,,,0.6280273199081421,1.5842721462249756,0.5796799659729004,1.7919082641601562,50000.0,0.4643000364303589,2.4546687602996826,10000.0,55075.43723917008,61610.07016038895,55075.43723917008,6523.415584564209,4.927415132522583,0.0 -120100,1.5330944,2.729847,,,,,,,,,,,,,, -120200,1.2491136,2.8530805,,,,,,,,,,,,,, -120300,1.1705657,3.1930215,,,,,,,,,,,,,, -120400,1.4049888,2.724791,,,,,,,,,,,,,, -120500,1.2884167,2.7149916,,,,,,,,,,,,,, -120600,1.469169,2.6274672,,,,,,,,,,,,,, -120700,1.3473563,2.6009264,,,,,,,,,,,,,, -120800,1.2135882,3.7719798,,,,,,,,,,,,,, -120900,1.4513799,2.8786774,,,,,,,,,,,,,, -120987,,,0.6349999904632568,1.535889744758606,0.584879994392395,1.7815570831298828,50000.0,0.4634000360965729,2.4435477256774902,10000.0,55495.7010948658,62080.37963700295,55495.7010948658,6573.364508867264,4.976737022399902,0.0 -121000,1.224214,4.520016,,,,,,,,,,,,,, -121100,1.4349226,2.607977,,,,,,,,,,,,,, -121200,1.1008104,4.397361,,,,,,,,,,,,,, -121300,1.3028772,3.140253,,,,,,,,,,,,,, -121400,1.3150673,4.3517284,,,,,,,,,,,,,, -121500,1.3231955,2.5217056,,,,,,,,,,,,,, -121600,1.5511751,2.6281765,,,,,,,,,,,,,, -121700,1.3336409,2.635692,,,,,,,,,,,,,, -121800,1.4033959,2.9476953,,,,,,,,,,,,,, -121900,1.5758977,2.670938,,,,,,,,,,,,,, -121903,,,0.6355859041213989,1.5495035648345947,0.5896599888801575,1.7570979595184326,50000.0,0.4704000353813171,2.42789888381958,10000.0,55915.98540139198,62549.107691049576,55915.98540139198,6621.707966804504,5.029550790786743,0.0 -122000,1.4940263,2.6937249,,,,,,,,,,,,,, -122100,1.3822596,2.6029406,,,,,,,,,,,,,, -122200,1.2474232,2.9939735,,,,,,,,,,,,,, -122300,1.3351939,3.093294,,,,,,,,,,,,,, -122400,1.1080317,4.456364,,,,,,,,,,,,,, -122500,1.4539069,2.635924,,,,,,,,,,,,,, -122600,1.3679819,2.5335333,,,,,,,,,,,,,, -122700,1.7362509,2.5765762,,,,,,,,,,,,,, -122800,1.3996202,2.6749396,,,,,,,,,,,,,, -122819,,,0.6341015696525574,1.5348260402679443,0.5896399617195129,1.748339056968689,50000.0,0.4693000316619873,2.4274723529815674,10000.0,56336.27511167526,63019.44883728027,56336.27511167526,6671.663818836212,5.077637434005737,0.0 -122900,1.2705919,2.80081,,,,,,,,,,,,,, -123000,1.2177382,3.189772,,,,,,,,,,,,,, -123100,1.2180681,3.747539,,,,,,,,,,,,,, -123200,1.3501072,2.5917177,,,,,,,,,,,,,, -123300,1.4340737,2.7291942,,,,,,,,,,,,,, -123400,1.4147173,2.5622382,,,,,,,,,,,,,, -123500,1.3999357,2.6860137,,,,,,,,,,,,,, -123600,1.285791,4.7394104,,,,,,,,,,,,,, -123700,1.4280963,2.5713525,,,,,,,,,,,,,, -123736,,,0.6442577838897705,1.493449330329895,0.594539999961853,1.7315579652786257,50000.0,0.4792000353336334,2.365442991256714,10000.0,56756.223601818085,63489.87978386879,56756.223601818085,6722.052555799484,5.1239094734191895,0.0 -123800,1.1703299,4.4922714,,,,,,,,,,,,,, -123900,1.2493362,4.1328173,,,,,,,,,,,,,, -124000,1.5102721,2.5626488,,,,,,,,,,,,,, -124100,1.3803158,2.7263246,,,,,,,,,,,,,, -124200,1.423286,2.6988182,,,,,,,,,,,,,, -124300,1.4038044,2.624113,,,,,,,,,,,,,, -124400,1.2546206,3.5953135,,,,,,,,,,,,,, -124500,1.1684623,3.7726116,,,,,,,,,,,,,, -124600,1.2307336,3.2928953,,,,,,,,,,,,,, -124653,,,0.6554492115974426,1.4739404916763306,0.5974400043487549,1.7426960468292236,50000.0,0.4780000150203705,2.39933180809021,10000.0,57176.31259250641,63961.48591351509,57176.31259250641,6773.475342273712,5.171019077301025,0.0 -124700,1.1558492,3.839777,,,,,,,,,,,,,, -124800,1.5725591,2.5155215,,,,,,,,,,,,,, -124900,1.3999475,2.521327,,,,,,,,,,,,,, -125000,1.4720905,2.6518538,,,,,,,,,,,,,, -125100,1.4115351,2.4811385,,,,,,,,,,,,,, -125200,1.2371991,4.1950693,,,,,,,,,,,,,, -125300,1.5227847,2.7357712,,,,,,,,,,,,,, -125400,1.4913366,2.4758363,,,,,,,,,,,,,, -125500,1.464649,2.5911098,,,,,,,,,,,,,, -125568,,,0.6419726610183716,1.5183041095733645,0.5974000096321106,1.7342219352722168,50000.0,0.4763000309467315,2.396895408630371,10000.0,57596.52352762222,64430.41939020157,57596.52352762222,6822.102419376373,5.218945741653442,0.0 -125600,1.2119285,4.663387,,,,,,,,,,,,,, -125700,1.2059052,3.7636151,,,,,,,,,,,,,, -125800,1.4370754,2.8143358,,,,,,,,,,,,,, -125900,1.2888267,5.0904565,,,,,,,,,,,,,, -126000,1.7576662,2.414129,,,,,,,,,,,,,, -126100,1.276872,2.9092557,,,,,,,,,,,,,, -126200,1.3840508,2.9914198,,,,,,,,,,,,,, -126300,1.2546396,5.016047,,,,,,,,,,,,,, -126400,1.1597676,4.9242487,,,,,,,,,,,,,, -126482,,,0.6576171517372131,1.45563805103302,0.6087200045585632,1.6954721212387085,50000.0,0.4836000204086303,2.356315851211548,10000.0,58016.79306435585,64898.98254442215,58016.79306435585,6870.285463809967,5.281857252120972,0.0 -126500,1.5908304,2.4410305,,,,,,,,,,,,,, -126600,1.1936071,4.5611863,,,,,,,,,,,,,, -126700,1.3983415,2.642115,,,,,,,,,,,,,, -126800,1.4528879,2.5068486,,,,,,,,,,,,,, -126900,1.2580304,3.72545,,,,,,,,,,,,,, -127000,1.3754768,2.625873,,,,,,,,,,,,,, -127100,1.4753944,2.41542,,,,,,,,,,,,,, -127200,1.4612635,2.8092651,,,,,,,,,,,,,, -127300,1.506041,2.5655909,,,,,,,,,,,,,, -127400,,,0.6700586080551147,1.3599334955215454,0.6076599955558777,1.6682051420211792,50000.0,0.4881000220775604,2.3263015747070312,10000.0,58437.100652217865,65366.74352836609,58437.100652217865,6917.643349409103,5.328955173492432,0.0 -127400,1.4510837,2.3974326,,,,,,,,,,,,,, -127500,1.5092989,2.5660326,,,,,,,,,,,,,, -127600,1.5646185,2.6050043,,,,,,,,,,,,,, -127700,1.2701769,2.674583,,,,,,,,,,,,,, -127800,1.3691893,2.5396283,,,,,,,,,,,,,, -127900,1.2000884,4.186804,,,,,,,,,,,,,, -128000,1.420282,2.510512,,,,,,,,,,,,,, -128100,1.280776,3.016441,,,,,,,,,,,,,, -128200,1.6476531,2.6052897,,,,,,,,,,,,,, -128300,1.5364959,2.5319157,,,,,,,,,,,,,, -128316,,,0.6534960865974426,1.4608618021011353,0.6089000105857849,1.6636296510696411,50000.0,0.4864000082015991,2.318061113357544,10000.0,58857.16923952103,65835.45329618454,58857.16923952103,6966.18169260025,5.3842246532440186,0.0 -128400,1.2608806,4.8925905,,,,,,,,,,,,,, -128500,1.2023767,3.9128103,,,,,,,,,,,,,, -128600,1.3042159,3.3590803,,,,,,,,,,,,,, -128700,1.2223989,3.834158,,,,,,,,,,,,,, -128800,1.4357834,2.4278858,,,,,,,,,,,,,, -128900,1.5060383,2.7600133,,,,,,,,,,,,,, -129000,1.3489753,3.6857896,,,,,,,,,,,,,, -129100,1.312011,3.413288,,,,,,,,,,,,,, -129200,1.40949,3.245743,,,,,,,,,,,,,, -129234,,,0.6603124737739563,1.4079570770263672,0.6110399961471558,1.6467926502227783,50000.0,0.4950000345706939,2.2942965030670166,10000.0,59277.42556810379,66305.90271615982,59277.42556810379,7016.280867099762,5.430881500244141,0.0 -129300,1.5805216,2.4227748,,,,,,,,,,,,,, -129400,1.5066088,2.3899374,,,,,,,,,,,,,, -129500,1.3317819,3.0404143,,,,,,,,,,,,,, -129600,1.5417944,2.5488021,,,,,,,,,,,,,, -129700,1.1467465,5.0506115,,,,,,,,,,,,,, -129800,1.4835933,2.2989492,,,,,,,,,,,,,, -129900,1.2419916,5.0569205,,,,,,,,,,,,,, -130000,1.437712,3.0155437,,,,,,,,,,,,,, -130100,1.4189476,3.5640779,,,,,,,,,,,,,, -130154,,,0.6708202958106995,1.3997917175292969,0.6119999885559082,1.674870252609253,50000.0,0.4908000230789184,2.339881896972656,10000.0,59697.67466640472,66774.41073536873,59697.67466640472,7064.444483995438,5.477983713150024,0.0 -130200,1.3032411,4.1100163,,,,,,,,,,,,,, -130300,1.6511985,2.4291482,,,,,,,,,,,,,, -130400,1.4345894,4.320045,,,,,,,,,,,,,, -130500,1.6009201,2.3546107,,,,,,,,,,,,,, -130600,1.2116508,4.3538923,,,,,,,,,,,,,, -130700,1.8145437,2.4187713,,,,,,,,,,,,,, -130800,1.3589631,3.6104848,,,,,,,,,,,,,, -130900,1.4009484,4.5531783,,,,,,,,,,,,,, -131000,1.4388456,2.3452687,,,,,,,,,,,,,, -131074,,,0.6654687523841858,1.3992410898208618,0.6183599829673767,1.6063157320022583,50000.0,0.494700014591217,2.2643074989318848,10000.0,60117.79187011719,67242.87277555466,60117.79187011719,7112.693949460983,5.525101900100708,0.0 -131100,1.3974383,3.8781815,,,,,,,,,,,,,, -131200,1.567416,2.5008664,,,,,,,,,,,,,, -131300,1.2342247,4.5215864,,,,,,,,,,,,,, -131400,1.3090606,3.8436666,,,,,,,,,,,,,, -131500,1.2799354,4.9395776,,,,,,,,,,,,,, -131600,1.5833626,2.761108,,,,,,,,,,,,,, -131700,1.5217028,2.4713883,,,,,,,,,,,,,, -131800,1.4910951,2.525275,,,,,,,,,,,,,, -131900,1.5265715,2.3510475,,,,,,,,,,,,,, -131991,,,0.6669335961341858,1.3916473388671875,0.6188600063323975,1.623449206352234,50000.0,0.4973000288009643,2.2747392654418945,10000.0,60537.90219092369,67711.90890932083,60537.90219092369,7161.525184869766,5.570846319198608,0.0 -132000,1.569303,2.5551023,,,,,,,,,,,,,, -132100,1.5565239,2.4561996,,,,,,,,,,,,,, -132200,1.42753,4.6669044,,,,,,,,,,,,,, -132300,1.4784262,2.4250238,,,,,,,,,,,,,, -132400,1.5797694,2.556764,,,,,,,,,,,,,, -132500,1.4376202,2.6459186,,,,,,,,,,,,,, -132600,1.6203957,2.309926,,,,,,,,,,,,,, -132700,1.6489012,2.40938,,,,,,,,,,,,,, -132800,1.343829,3.7207725,,,,,,,,,,,,,, -132900,1.5815489,2.5509312,,,,,,,,,,,,,, -132910,,,0.6772265434265137,1.361317157745361,0.620959997177124,1.627536654472351,50000.0,0.4974000155925751,2.2841172218322754,10000.0,60957.94920253754,68180.9085547924,60957.94920253754,7210.376707792282,5.624185085296631,0.0 -133000,1.5043283,2.687019,,,,,,,,,,,,,, -133100,1.6626142,2.4342067,,,,,,,,,,,,,, -133200,1.7479178,2.2765484,,,,,,,,,,,,,, -133300,1.274004,3.4897854,,,,,,,,,,,,,, -133400,1.7646288,2.5890253,,,,,,,,,,,,,, -133500,1.5622356,2.5799234,,,,,,,,,,,,,, -133600,1.4796791,4.703372,,,,,,,,,,,,,, -133700,1.3483279,4.0499945,,,,,,,,,,,,,, -133800,1.6576439,2.4289496,,,,,,,,,,,,,, -133828,,,0.6712890267372131,1.365814447402954,0.6225999593734741,1.5908536911010742,50000.0,0.501300036907196,2.24080491065979,10000.0,61378.31641602516,68652.09668803215,61378.31641602516,7261.099878549576,5.673567771911621,0.0 -133900,1.552389,2.5312376,,,,,,,,,,,,,, -134000,1.4469186,2.8506157,,,,,,,,,,,,,, -134100,1.3413448,2.8882785,,,,,,,,,,,,,, -134200,1.2326477,3.2916226,,,,,,,,,,,,,, -134300,1.5311993,4.83767,,,,,,,,,,,,,, -134400,1.4550228,2.7933967,,,,,,,,,,,,,, -134500,1.477325,2.622659,,,,,,,,,,,,,, -134600,1.5929064,2.3939335,,,,,,,,,,,,,, -134700,1.428383,3.1036038,,,,,,,,,,,,,, -134743,,,0.6692187190055847,1.3907537460327148,0.6218199729919434,1.610404133796692,50000.0,0.4987000226974487,2.2690556049346924,10000.0,61798.59408092499,69123.8744187355,61798.59408092499,7312.5014128685,5.723673582077026,0.0 -134800,1.378892,3.3457415,,,,,,,,,,,,,, -134900,1.4585007,2.4628875,,,,,,,,,,,,,, -135000,1.4335871,3.1253507,,,,,,,,,,,,,, -135100,1.5732633,2.3388503,,,,,,,,,,,,,, -135200,1.6800563,2.28952,,,,,,,,,,,,,, -135300,1.7614919,2.4705293,,,,,,,,,,,,,, -135400,1.6488701,2.316188,,,,,,,,,,,,,, -135500,1.5530027,2.3987777,,,,,,,,,,,,,, -135600,1.8381873,2.5006073,,,,,,,,,,,,,, -135622,,,0.6849804520606995,1.3148226737976074,0.6287999749183655,1.5665632486343384,50000.0,0.5057000517845154,2.21661114692688,10000.0,62218.69186377525,69592.01400113106,62218.69186377525,7360.451724529266,5.769453287124634,0.0 -135700,1.4993637,2.6391582,,,,,,,,,,,,,, -135800,1.4863702,2.454681,,,,,,,,,,,,,, -135900,1.5957142,2.6743932,,,,,,,,,,,,,, -136000,1.6259297,2.938339,,,,,,,,,,,,,, -136100,1.4932468,2.9563656,,,,,,,,,,,,,, -136200,1.6908855,2.3262174,,,,,,,,,,,,,, -136300,1.5099527,3.502095,,,,,,,,,,,,,, -136400,1.615465,2.3180187,,,,,,,,,,,,,, -136500,1.6417522,2.4486375,,,,,,,,,,,,,, -136534,,,0.6810937523841858,1.350616216659546,0.6312800049781799,1.5826791524887085,50000.0,0.5059000253677368,2.23021936416626,10000.0,62638.98313713074,70062.6733353138,62638.98313713074,7410.717103242874,5.824112415313721,0.0 -136600,1.635045,2.3139668,,,,,,,,,,,,,, -136700,1.8165052,2.377051,,,,,,,,,,,,,, -136800,1.6457044,2.3897555,,,,,,,,,,,,,, -136900,1.6415348,2.2540822,,,,,,,,,,,,,, -137000,1.5716251,2.2649703,,,,,,,,,,,,,, -137100,1.2471161,4.1413193,,,,,,,,,,,,,, -137200,1.6954645,2.4860811,,,,,,,,,,,,,, -137300,1.7818433,2.3520179,,,,,,,,,,,,,, -137400,1.511564,2.479478,,,,,,,,,,,,,, -137447,,,0.6771484017372131,1.363439917564392,0.6318599581718445,1.5701099634170532,50000.0,0.511900007724762,2.209479093551636,10000.0,63059.25512361527,70531.656188488,63059.25512361527,7459.330714464188,5.873344898223877,0.0 -137500,1.4433972,4.4242363,,,,,,,,,,,,,, -137600,1.6585661,2.443277,,,,,,,,,,,,,, -137700,1.6467806,2.3082974,,,,,,,,,,,,,, -137800,1.8648441,2.24307,,,,,,,,,,,,,, -137900,1.7503303,2.2150474,,,,,,,,,,,,,, -138000,1.7204541,2.3663993,,,,,,,,,,,,,, -138100,1.4996799,3.4533446,,,,,,,,,,,,,, -138200,1.7166047,2.398608,,,,,,,,,,,,,, -138300,1.799978,2.2292027,,,,,,,,,,,,,, -138358,,,0.6902929544448853,1.3037259578704834,0.6375799775123596,1.5405707359313965,50000.0,0.51500004529953,2.1850335597991943,10000.0,63479.41015410423,71001.12347507477,63479.41015410423,7508.541209936142,5.927834272384644,0.0 -138400,1.4159758,4.4762344,,,,,,,,,,,,,, -138500,1.6952636,2.3426404,,,,,,,,,,,,,, -138600,1.6706774,2.3066344,,,,,,,,,,,,,, -138700,1.5538491,2.1595654,,,,,,,,,,,,,, -138800,1.7047881,2.3444853,,,,,,,,,,,,,, -138900,1.6487443,2.2617776,,,,,,,,,,,,,, -139000,1.4270494,3.7920065,,,,,,,,,,,,,, -139100,1.8645161,2.2833116,,,,,,,,,,,,,, -139200,1.7095445,2.3568597,,,,,,,,,,,,,, -139272,,,0.7201562523841858,1.162137746810913,0.6441999673843384,1.5044505596160889,50000.0,0.5205000042915344,2.161705732345581,10000.0,63899.503361940384,71473.19856882095,63899.503361940384,7560.420810461044,5.9827258586883545,0.0 -139300,1.6680878,2.5879004,,,,,,,,,,,,,, -139400,1.4054269,3.8967843,,,,,,,,,,,,,, -139500,1.4828032,4.4013176,,,,,,,,,,,,,, -139600,1.6505727,2.3388157,,,,,,,,,,,,,, -139700,1.760051,2.2769985,,,,,,,,,,,,,, -139800,1.4651386,3.638495,,,,,,,,,,,,,, -139900,1.649155,2.7638524,,,,,,,,,,,,,, -140000,1.4902557,3.4650092,,,,,,,,,,,,,, -140100,1.7179755,2.046024,,,,,,,,,,,,,, -140189,,,0.6912499666213989,1.2617908716201782,0.6451199650764465,1.4823774099349976,50000.0,0.52510005235672,2.134714841842652,10000.0,64319.68514704704,71943.69709396362,64319.68514704704,7610.637903690338,6.034387826919556,0.0 -140200,1.6892799,2.29314,,,,,,,,,,,,,, -140300,1.5506433,2.3527014,,,,,,,,,,,,,, -140400,1.4245579,4.1701603,,,,,,,,,,,,,, -140500,1.6090914,3.184696,,,,,,,,,,,,,, -140600,1.6092262,2.260986,,,,,,,,,,,,,, -140700,1.6764044,2.3288953,,,,,,,,,,,,,, -140800,1.7133385,2.2481668,,,,,,,,,,,,,, -140900,1.7965486,2.2933607,,,,,,,,,,,,,, -141000,1.6105814,4.317459,,,,,,,,,,,,,, -141100,1.5109574,2.635922,,,,,,,,,,,,,, -141109,,,0.7030664086341858,1.2258492708206177,0.6487999558448792,1.4680681228637695,50000.0,0.5254999995231628,2.115793943405152,10000.0,64739.68468880653,72414.69781684875,64739.68468880653,7661.540041446686,6.084909915924072,0.0 -141200,1.7952496,2.3635945,,,,,,,,,,,,,, -141300,1.812559,2.1928144,,,,,,,,,,,,,, -141400,1.5685512,2.2369137,,,,,,,,,,,,,, -141500,1.4140451,4.296265,,,,,,,,,,,,,, -141600,1.7660011,2.2097435,,,,,,,,,,,,,, -141700,1.7162956,2.3248258,,,,,,,,,,,,,, -141800,1.4717458,3.1152737,,,,,,,,,,,,,, -141900,1.6931483,2.2174098,,,,,,,,,,,,,, -142000,1.554809,3.8276153,,,,,,,,,,,,,, -142026,,,0.712695300579071,1.1884727478027344,0.6464200019836426,1.488558292388916,50000.0,0.5259000062942505,2.140204906463623,10000.0,65159.91322731972,72884.38906359673,65159.91322731972,7710.902366638184,6.13739275932312,0.0 -142100,1.4573323,3.782024,,,,,,,,,,,,,, -142200,1.7691408,2.2295005,,,,,,,,,,,,,, -142300,1.9485667,2.2362864,,,,,,,,,,,,,, -142400,1.4586574,4.416421,,,,,,,,,,,,,, -142500,1.5074966,3.5056608,,,,,,,,,,,,,, -142600,1.8610427,2.2979882,,,,,,,,,,,,,, -142700,1.7719365,2.1611087,,,,,,,,,,,,,, -142800,1.5347795,3.6591706,,,,,,,,,,,,,, -142900,1.7530043,2.3179097,,,,,,,,,,,,,, -142943,,,0.7005664110183716,1.2247933149337769,0.6492999792098999,1.4643806219100952,50000.0,0.5232000350952148,2.109903573989868,10000.0,65580.24869775772,73353.33665847778,65580.24869775772,7759.4170508384705,6.186307430267334,0.0 -143000,1.542436,4.374678,,,,,,,,,,,,,, -143100,1.7187275,2.4621124,,,,,,,,,,,,,, -143200,1.7759262,4.2023997,,,,,,,,,,,,,, -143300,1.7397048,2.490524,,,,,,,,,,,,,, -143400,1.7680526,2.2328641,,,,,,,,,,,,,, -143500,1.9069356,2.232253,,,,,,,,,,,,,, -143600,1.6474435,2.938618,,,,,,,,,,,,,, -143700,1.8469859,2.4739642,,,,,,,,,,,,,, -143800,1.7131609,2.369737,,,,,,,,,,,,,, -143859,,,0.7032226324081421,1.2361243963241575,0.6534000039100647,1.4610118865966797,50000.0,0.5297000408172607,2.113292932510376,10000.0,66000.38722920418,73822.04728531837,66000.38722920418,7807.887068033218,6.239961385726929,0.0 -143900,1.8333229,2.2109566,,,,,,,,,,,,,, -144000,1.7826713,2.3018928,,,,,,,,,,,,,, -144100,1.5714971,3.917407,,,,,,,,,,,,,, -144200,1.6389775,2.808281,,,,,,,,,,,,,, -144300,1.7526146,2.6518734,,,,,,,,,,,,,, -144400,1.8749907,2.1811848,,,,,,,,,,,,,, -144500,1.8871914,2.3317785,,,,,,,,,,,,,, -144600,1.6249506,3.566054,,,,,,,,,,,,,, -144700,1.7051667,2.8018863,,,,,,,,,,,,,, -144776,,,0.7193945050239563,1.1607738733291626,0.654259979724884,1.447080135345459,50000.0,0.5299000144004822,2.11555290222168,10000.0,66420.68016719818,74292.60705327988,66420.68016719818,7858.054362535477,6.291686773300171,0.0 -144800,1.7326535,2.5361598,,,,,,,,,,,,,, -144900,1.5756742,3.1872385,,,,,,,,,,,,,, -145000,1.7272753,2.6647084,,,,,,,,,,,,,, -145100,1.7124788,2.5109515,,,,,,,,,,,,,, -145200,1.690971,2.271937,,,,,,,,,,,,,, -145300,1.6808048,2.4535818,,,,,,,,,,,,,, -145400,1.5255935,3.972972,,,,,,,,,,,,,, -145500,1.6594871,2.549554,,,,,,,,,,,,,, -145600,1.8011346,4.6781497,,,,,,,,,,,,,, -145690,,,0.713671863079071,1.1829522848129272,0.6609799861907959,1.417777180671692,50000.0,0.5415000319480896,2.0482189655303955,10000.0,66840.85182523727,74760.23641347885,66840.85182523727,7905.415101289749,6.340944766998291,0.0 -145700,1.9216595,2.21134,,,,,,,,,,,,,, -145800,1.7971878,2.148365,,,,,,,,,,,,,, -145900,1.8079607,2.1509943,,,,,,,,,,,,,, -146000,1.8375504,2.1902835,,,,,,,,,,,,,, -146100,1.542265,3.3352509,,,,,,,,,,,,,, -146200,1.7758547,4.503374,,,,,,,,,,,,,, -146300,1.8682983,1.9348531,,,,,,,,,,,,,, -146400,1.6997889,4.651106,,,,,,,,,,,,,, -146500,1.8586608,2.0672586,,,,,,,,,,,,,, -146600,1.7018726,4.394954,,,,,,,,,,,,,, -146605,,,0.7144335508346558,1.191360592842102,0.6621999740600586,1.4254510402679443,50000.0,0.542900025844574,2.0617923736572266,10000.0,67260.99752354622,75230.54712605476,67260.99752354622,7955.480447292328,6.393074035644531,0.0 -146700,1.9462831,2.237731,,,,,,,,,,,,,, -146800,1.8603586,2.0468063,,,,,,,,,,,,,, -146900,1.9500576,2.0384793,,,,,,,,,,,,,, -147000,1.772419,2.3821342,,,,,,,,,,,,,, -147100,1.9538878,2.1533182,,,,,,,,,,,,,, -147200,1.7466384,2.804682,,,,,,,,,,,,,, -147300,1.9111619,2.0035245,,,,,,,,,,,,,, -147400,1.8186058,3.1627126,,,,,,,,,,,,,, -147500,1.9161748,4.151925,,,,,,,,,,,,,, -147522,,,0.7241796851158142,1.1309359073638916,0.6642000079154968,1.4086828231811523,50000.0,0.5448000431060791,2.0465872287750244,10000.0,67681.06917715073,75698.780200243,67681.06917715073,8003.546766996384,6.440312385559082,0.0 -147600,1.5756623,3.0479312,,,,,,,,,,,,,, -147700,1.8218707,2.5736876,,,,,,,,,,,,,, -147800,1.9177823,2.2480445,,,,,,,,,,,,,, -147900,1.8528515,2.0817502,,,,,,,,,,,,,, -148000,1.8724647,2.4977942,,,,,,,,,,,,,, -148100,1.6925796,4.0846105,,,,,,,,,,,,,, -148200,1.8952405,2.1195364,,,,,,,,,,,,,, -148300,1.8366768,2.1234288,,,,,,,,,,,,,, -148400,2.013297,2.172323,,,,,,,,,,,,,, -148440,,,0.7292773127555847,1.113935470581055,0.6719399690628052,1.3779354095458984,50000.0,0.5482000112533569,2.010650396347046,10000.0,68101.14979958534,76168.72648119926,68101.14979958534,8053.31559920311,6.489897012710571,0.0 -148500,1.6916353,4.2588167,,,,,,,,,,,,,, -148600,1.8786,2.1672173,,,,,,,,,,,,,, -148700,1.9598609,2.4563324,,,,,,,,,,,,,, -148800,1.9772264,2.2111175,,,,,,,,,,,,,, -148900,1.6835386,2.574465,,,,,,,,,,,,,, -149000,1.9040729,4.4124956,,,,,,,,,,,,,, -149100,2.0479882,2.2032697,,,,,,,,,,,,,, -149200,1.7243154,2.4050517,,,,,,,,,,,,,, -149300,1.6089697,3.033975,,,,,,,,,,,,,, -149356,,,0.7234765291213989,1.1463912725448608,0.6699000000953674,1.3762966394424438,50000.0,0.5427000522613525,2.032515525817871,10000.0,68521.11470293999,76639.01660203934,68521.11470293999,8103.527039289474,6.53973126411438,0.0 -149400,1.7178849,2.6109743,,,,,,,,,,,,,, -149500,1.716613,4.112822,,,,,,,,,,,,,, -149600,1.6489742,3.620186,,,,,,,,,,,,,, -149700,2.1593516,2.2140574,,,,,,,,,,,,,, -149800,2.0262766,1.9362115,,,,,,,,,,,,,, -149900,1.9239573,2.1510317,,,,,,,,,,,,,, -150000,2.0591946,2.1972187,,,,,,,,,,,,,, -150100,2.0383883,2.079699,,,,,,,,,,,,,, -150200,1.9483622,4.081066,,,,,,,,,,,,,, -150270,,,0.7346875071525574,1.091256856918335,0.6761800050735474,1.3612442016601562,50000.0,0.5507000088691711,1.992889165878296,10000.0,68941.08154058456,77111.07012796402,68941.08154058456,8155.512540578842,6.593385457992554,0.0 -150300,1.9259125,4.545467,,,,,,,,,,,,,, -150400,1.9179543,2.2135954,,,,,,,,,,,,,, -150500,2.156578,2.129471,,,,,,,,,,,,,, -150600,2.0541887,2.2600415,,,,,,,,,,,,,, -150700,1.863306,4.0633264,,,,,,,,,,,,,, -150800,2.199458,2.4810765,,,,,,,,,,,,,, -150900,1.861233,3.9384317,,,,,,,,,,,,,, -151000,2.0244884,2.1483686,,,,,,,,,,,,,, -151100,2.1680393,2.0259516,,,,,,,,,,,,,, -151188,,,0.75537109375,1.024010181427002,0.6801199913024902,1.350976824760437,50000.0,0.5574000477790833,1.9849852323532104,10000.0,69361.19805550575,77582.26147294044,69361.19805550575,8206.480012178421,6.650951862335205,0.0 -151200,1.9955969,2.2082467,,,,,,,,,,,,,, -151300,2.1594064,2.3073864,,,,,,,,,,,,,, -151400,2.2452178,2.0265703,,,,,,,,,,,,,, -151500,1.9530776,2.424089,,,,,,,,,,,,,, -151600,2.017546,2.8310564,,,,,,,,,,,,,, -151700,1.7468668,2.6408045,,,,,,,,,,,,,, -151800,2.056536,2.228015,,,,,,,,,,,,,, -151900,1.8367689,3.7118614,,,,,,,,,,,,,, -152000,2.051963,1.9727194,,,,,,,,,,,,,, -152100,1.8322599,3.0029142,,,,,,,,,,,,,, -152104,,,0.7370507717132568,1.0689661502838137,0.681659996509552,1.3200819492340088,50000.0,0.556600034236908,1.947913408279419,10000.0,69781.2301557064,78051.009329319,69781.2301557064,8255.09532880783,6.703702688217163,0.0 -152200,2.1543405,3.860938,,,,,,,,,,,,,, -152300,2.0210726,2.102229,,,,,,,,,,,,,, -152400,2.1970842,1.9742601,,,,,,,,,,,,,, -152500,2.1114984,1.9566879,,,,,,,,,,,,,, -152600,2.1793823,2.0555158,,,,,,,,,,,,,, -152700,2.012076,4.1958685,,,,,,,,,,,,,, -152800,2.0197852,2.7251263,,,,,,,,,,,,,, -152900,2.3061974,2.0168223,,,,,,,,,,,,,, -153000,1.8600628,2.7271304,,,,,,,,,,,,,, -153020,,,0.7427343726158142,1.0332081317901611,0.6832399964332581,1.3128416538238523,50000.0,0.5614000558853149,1.9452199935913088,10000.0,70201.16539907455,78519.31473040581,70201.16539907455,8303.366862535477,6.754567861557007,0.0 -153100,2.064611,1.9903057,,,,,,,,,,,,,, -153200,2.1100092,2.0521107,,,,,,,,,,,,,, -153300,2.022376,2.2064607,,,,,,,,,,,,,, -153400,2.0734036,2.0762107,,,,,,,,,,,,,, -153500,2.1238592,2.110048,,,,,,,,,,,,,, -153600,2.172035,1.999629,,,,,,,,,,,,,, -153700,1.9610325,4.2944884,,,,,,,,,,,,,, -153800,2.0107293,3.977041,,,,,,,,,,,,,, -153900,1.9197152,2.2922635,,,,,,,,,,,,,, -153933,,,0.7536913752555847,1.0018309354782104,0.6841399669647217,1.3143833875656128,50000.0,0.5550000071525574,1.9567975997924805,10000.0,70621.41844844818,78989.54914736748,70621.41844844818,8353.252356290817,6.802517414093018,0.0 -154000,2.338619,2.0822656,,,,,,,,,,,,,, -154100,2.2076712,1.9734778,,,,,,,,,,,,,, -154200,2.0161893,2.1139345,,,,,,,,,,,,,, -154300,2.195426,2.0428123,,,,,,,,,,,,,, -154400,1.8997781,2.7062097,,,,,,,,,,,,,, -154500,2.0299914,3.297466,,,,,,,,,,,,,, -154600,2.0185099,3.588445,,,,,,,,,,,,,, -154700,2.143009,1.9162079,,,,,,,,,,,,,, -154800,2.0177653,2.3187017,,,,,,,,,,,,,, -154850,,,0.74916011095047,1.0158547163009644,0.688539981842041,1.2876722812652588,50000.0,0.5651000142097473,1.925018310546875,10000.0,71041.63038349152,79458.93891525269,71041.63038349152,8402.332137584686,6.852922439575195,0.0 -154900,2.1504936,2.1258035,,,,,,,,,,,,,, -155000,2.2708974,1.9578999,,,,,,,,,,,,,, -155100,2.1568708,1.9474418,,,,,,,,,,,,,, -155200,2.0602982,2.1890693,,,,,,,,,,,,,, -155300,1.9781343,4.3264184,,,,,,,,,,,,,, -155400,2.2729542,1.949036,,,,,,,,,,,,,, -155500,1.9607921,2.5694838,,,,,,,,,,,,,, -155600,2.1583242,4.466099,,,,,,,,,,,,,, -155700,2.164873,2.9762025,,,,,,,,,,,,,, -155765,,,0.7540234327316284,0.9922083616256714,0.6906799674034119,1.2663934230804443,50000.0,0.5700000524520874,1.9121655225753784,10000.0,71461.77147507668,79929.46589922905,71461.77147507668,8452.610144615173,6.913311243057251,0.0 -155800,1.9731978,3.7776728,,,,,,,,,,,,,, -155900,2.1786091,2.5223014,,,,,,,,,,,,,, -156000,2.120243,2.337052,,,,,,,,,,,,,, -156100,2.4205062,1.9789363,,,,,,,,,,,,,, -156200,2.1787937,1.8287964,,,,,,,,,,,,,, -156300,2.2353785,1.9557956,,,,,,,,,,,,,, -156400,2.2185426,2.0524046,,,,,,,,,,,,,, -156500,1.858136,3.2576787,,,,,,,,,,,,,, -156600,1.9433621,2.491393,,,,,,,,,,,,,, -156680,,,0.7607421875,0.967012107372284,0.6918999552726746,1.274028182029724,50000.0,0.5699000358581543,1.89507257938385,10000.0,71881.82072901726,80398.42122769356,71881.82072901726,8501.413291931152,6.968406200408936,0.0 -156700,2.239082,2.2425199,,,,,,,,,,,,,, -156800,2.2860289,4.326155,,,,,,,,,,,,,, -156900,2.2644455,1.9381177,,,,,,,,,,,,,, -157000,2.1549163,2.3907924,,,,,,,,,,,,,, -157100,2.2622077,1.993206,,,,,,,,,,,,,, -157200,1.9629714,2.867285,,,,,,,,,,,,,, -157300,2.3773816,1.9288374,,,,,,,,,,,,,, -157400,2.0232599,2.8239365,,,,,,,,,,,,,, -157500,2.224733,1.9453517,,,,,,,,,,,,,, -157595,,,0.7591796517372131,0.9821126461029052,0.6990199685096741,1.2495800256729126,50000.0,0.5837000012397766,1.8638801574707031,10000.0,72301.9625506401,80867.77632331848,72301.9625506401,8550.524432182312,7.022604942321777,0.0 -157600,2.344774,3.8397717,,,,,,,,,,,,,, -157700,2.4814997,1.8979069,,,,,,,,,,,,,, -157800,2.1480963,1.9477915,,,,,,,,,,,,,, -157900,2.1371775,1.7156566,,,,,,,,,,,,,, -158000,2.2380383,1.88348,,,,,,,,,,,,,, -158100,2.1056073,3.2300382,,,,,,,,,,,,,, -158200,2.1818268,2.096196,,,,,,,,,,,,,, -158300,2.497387,1.87856,,,,,,,,,,,,,, -158400,2.5479302,1.9599298,,,,,,,,,,,,,, -158500,2.4610875,4.3065023,,,,,,,,,,,,,, -158510,,,0.7641015648841858,0.9493613243103028,0.7014600038528442,1.2222778797149658,50000.0,0.57750004529953,1.8460190296173096,10000.0,72722.21183228493,81337.55066609383,72722.21183228493,8599.948066473007,7.076277017593384,0.0 -158600,2.2743425,2.1532753,,,,,,,,,,,,,, -158700,2.6652665,3.945246,,,,,,,,,,,,,, -158800,2.2089443,2.425045,,,,,,,,,,,,,, -158900,2.2971566,1.8829607,,,,,,,,,,,,,, -159000,1.911951,3.0791392,,,,,,,,,,,,,, -159100,2.088954,3.5231442,,,,,,,,,,,,,, -159200,2.3056495,1.9921976,,,,,,,,,,,,,, -159300,2.5001357,1.989685,,,,,,,,,,,,,, -159400,2.3263872,2.3103993,,,,,,,,,,,,,, -159425,,,0.7665234208106995,0.938743770122528,0.7026199698448181,1.223870873451233,50000.0,0.5818000435829163,1.855764389038086,10000.0,73142.33656525612,81808.82894182205,73142.33656525612,8651.003014564514,7.126868724822998,0.0 -159500,2.4855516,1.8219991,,,,,,,,,,,,,, -159600,2.3883417,1.8415792,,,,,,,,,,,,,, -159700,2.5963345,1.840642,,,,,,,,,,,,,, -159800,2.3073206,1.9390508,,,,,,,,,,,,,, -159900,2.2190769,3.3076148,,,,,,,,,,,,,, -160000,2.077183,2.7465835,,,,,,,,,,,,,, -160100,2.477448,1.8831307,,,,,,,,,,,,,, -160200,2.2820108,2.36607,,,,,,,,,,,,,, -160300,2.2615252,3.535597,,,,,,,,,,,,,, -160338,,,0.7646874785423279,0.9522948861122132,0.7034400105476379,1.2250497341156006,50000.0,0.5848000049591064,1.84183931350708,10000.0,73562.57827568054,82278.49391198158,73562.57827568054,8700.326808214188,7.178597688674927,0.0 -160400,2.3869436,1.8079536,,,,,,,,,,,,,, -160500,2.002642,2.5271738,,,,,,,,,,,,,, -160600,2.2354007,2.7744405,,,,,,,,,,,,,, -160700,2.4132261,3.3483222,,,,,,,,,,,,,, -160800,2.3776832,3.751851,,,,,,,,,,,,,, -160900,2.5507917,1.8181654,,,,,,,,,,,,,, -161000,2.1496778,2.7776568,,,,,,,,,,,,,, -161100,2.6456802,1.9726137,,,,,,,,,,,,,, -161200,2.3060708,1.692596,,,,,,,,,,,,,, -161254,,,0.7704296708106995,0.934205174446106,0.7053999900817871,1.2130885124206543,50000.0,0.5820000171661377,1.835293412208557,10000.0,73982.7239575386,82749.41592168808,73982.7239575386,8750.995544433594,7.237669229507446,0.0 -161300,2.603789,1.9841471,,,,,,,,,,,,,, -161400,2.1885896,3.5189934,,,,,,,,,,,,,, -161500,2.3953488,1.7725322,,,,,,,,,,,,,, -161600,2.6210876,1.8749332,,,,,,,,,,,,,, -161700,2.47206,2.2723076,,,,,,,,,,,,,, -161800,2.2118618,3.632029,,,,,,,,,,,,,, -161900,2.275966,1.7896624,,,,,,,,,,,,,, -162000,2.4187808,2.77489,,,,,,,,,,,,,, -162100,2.6049957,1.7865508,,,,,,,,,,,,,, -162170,,,0.7777734398841858,0.8995881676673889,0.7101399898529053,1.1974650621414185,50000.0,0.5907000303268433,1.818153619766236,10000.0,74402.66515946388,83219.72941589355,74402.66515946388,8801.265083551407,7.292815208435059,0.0 -162200,2.7328444,1.8779495,,,,,,,,,,,,,, -162300,2.5166733,4.0637584,,,,,,,,,,,,,, -162400,2.4568949,1.8321458,,,,,,,,,,,,,, -162500,2.4782977,1.7157905,,,,,,,,,,,,,, -162600,2.349177,1.9937334,,,,,,,,,,,,,, -162700,2.3457046,2.337895,,,,,,,,,,,,,, -162800,2.6845815,4.362858,,,,,,,,,,,,,, -162900,2.8051062,1.7626681,,,,,,,,,,,,,, -163000,2.6295338,1.8903689,,,,,,,,,,,,,, -163085,,,0.7899999618530273,0.8373759984970093,0.7140799760818481,1.1738150119781494,50000.0,0.5903000235557556,1.810662150382996,10000.0,74822.65410661697,83688.25251555443,74822.65410661697,8849.700415611267,7.344009160995483,0.0 -163100,2.4173987,1.728438,,,,,,,,,,,,,, -163200,2.3410459,2.0820649,,,,,,,,,,,,,, -163300,2.3262691,2.8973567,,,,,,,,,,,,,, -163400,2.3286893,2.2034223,,,,,,,,,,,,,, -163500,2.615786,1.7027504,,,,,,,,,,,,,, -163600,2.7169049,4.1996193,,,,,,,,,,,,,, -163700,2.5941734,1.7432508,,,,,,,,,,,,,, -163800,2.7388477,1.7345273,,,,,,,,,,,,,, -163900,2.595188,1.8034139,,,,,,,,,,,,,, -164000,2.5957491,1.824266,,,,,,,,,,,,,, -164001,,,0.7776171565055847,0.8835697174072266,0.7133199572563171,1.1664485931396484,50000.0,0.5924000144004822,1.779520034790039,10000.0,75243.34757304192,84157.42079758644,75243.34757304192,8898.0752389431,7.395813226699829,0.0 -164100,2.4580786,3.075214,,,,,,,,,,,,,, -164200,2.4625978,1.9291109,,,,,,,,,,,,,, -164300,2.6664355,1.803834,,,,,,,,,,,,,, -164400,2.2992914,3.2541451,,,,,,,,,,,,,, -164500,2.625481,1.7995038,,,,,,,,,,,,,, -164600,2.669624,1.7616262,,,,,,,,,,,,,, -164700,2.7083955,1.7645769,,,,,,,,,,,,,, -164800,2.6546645,1.74018,,,,,,,,,,,,,, -164900,2.4821353,1.9370835,,,,,,,,,,,,,, -164920,,,0.7840234041213989,0.8810456395149231,0.7141799926757812,1.1800652742385864,50000.0,0.5914000272750854,1.803371548652649,10000.0,75663.35647845268,84626.98712992668,75663.35647845268,8947.53544473648,7.44543981552124,0.0 -165000,2.5039155,1.8140222,,,,,,,,,,,,,, -165100,2.7497802,1.8208723,,,,,,,,,,,,,, -165200,2.4946814,1.8018124,,,,,,,,,,,,,, -165300,2.9181507,4.268771,,,,,,,,,,,,,, -165400,2.936193,1.7350973,,,,,,,,,,,,,, -165500,2.451844,2.2142673,,,,,,,,,,,,,, -165600,2.6992006,1.8301938,,,,,,,,,,,,,, -165700,3.0182629,1.7496002,,,,,,,,,,,,,, -165800,2.739568,3.928339,,,,,,,,,,,,,, -165837,,,0.7928906083106995,0.8252041339874268,0.7172600030899048,1.158810257911682,50000.0,0.5946000218391418,1.770185470581055,10000.0,76083.54992222786,85097.52592563629,76083.54992222786,8997.781407117844,7.497044086456299,0.0 -165900,3.0417554,4.3098264,,,,,,,,,,,,,, -166000,2.8189218,1.7645121,,,,,,,,,,,,,, -166100,2.6640532,1.9032894,,,,,,,,,,,,,, -166200,2.745977,1.7108057,,,,,,,,,,,,,, -166300,2.7436624,1.8433701,,,,,,,,,,,,,, -166400,2.6472418,3.7585812,,,,,,,,,,,,,, -166500,2.882904,1.8402681,,,,,,,,,,,,,, -166600,2.9161985,4.2519493,,,,,,,,,,,,,, -166700,2.7724164,1.63224,,,,,,,,,,,,,, -166753,,,0.7876757383346558,0.8453723788261414,0.7234599590301514,1.131539225578308,50000.0,0.6046000123023987,1.7537925243377686,10000.0,76503.78957104683,85568.19960737228,76503.78957104683,9048.111221551895,7.553990364074707,0.0 -166800,2.7856207,2.3786447,,,,,,,,,,,,,, -166900,2.6028018,2.5234265,,,,,,,,,,,,,, -167000,2.7219808,1.748086,,,,,,,,,,,,,, -167100,2.799055,2.1755066,,,,,,,,,,,,,, -167200,2.9520125,1.7964939,,,,,,,,,,,,,, -167300,2.5082152,2.4738762,,,,,,,,,,,,,, -167400,2.6229308,2.8202274,,,,,,,,,,,,,, -167500,2.9396546,1.6282717,,,,,,,,,,,,,, -167600,2.780817,1.7512974,,,,,,,,,,,,,, -167671,,,0.7952343821525574,0.8193172812461853,0.7249599695205688,1.1146950721740725,50000.0,0.6041000485420227,1.7299996614456177,10000.0,76924.13359379768,86039.13908457756,76924.13359379768,9098.598886728289,7.613273620605469,0.0 -167700,3.0909076,4.2641454,,,,,,,,,,,,,, -167800,3.0642662,1.76266,,,,,,,,,,,,,, -167900,2.960337,1.8919021,,,,,,,,,,,,,, -168000,2.7200615,1.6400018,,,,,,,,,,,,,, -168100,2.852095,2.112506,,,,,,,,,,,,,, -168200,3.0655599,1.7800003,,,,,,,,,,,,,, -168300,2.8241947,1.6759242,,,,,,,,,,,,,, -168400,2.9795063,3.9572263,,,,,,,,,,,,,, -168500,2.6633434,1.82706,,,,,,,,,,,,,, -168589,,,0.8006640672683716,0.809798002243042,0.7245999574661255,1.1315559148788452,50000.0,0.6045000553131104,1.7418795824050903,10000.0,77344.50248265266,86510.57143831253,77344.50248265266,9149.558176994324,7.669828414916992,0.0 -168600,3.1019254,1.8952466,,,,,,,,,,,,,, -168700,2.6455653,3.4057236,,,,,,,,,,,,,, -168800,3.0682666,4.0765676,,,,,,,,,,,,,, -168900,3.0344133,1.5829786,,,,,,,,,,,,,, -168978,,,,,,,,,,,77520.03267264366,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 962686b06..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -53.20211386680603,0.0,40.29682993888855,1,0,40.29682993888855,0.0010000000474974,6.907756805419922,10000,93.49904775619508,0.0009374999790452,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -103.95526385307312,0.0180015563964843,460.35428380966187,850,0,460.35428380966187,0.0224000010639429,6.085418224334717,10000,564.3717305660248,0.0340039059519767,5.898688793182373,0.0292599983513355,5.964993476867676,50000 -153.4896755218506,0.0557641983032226,880.7335257530212,1755,0,880.7335257530212,0.0504000037908554,5.601574420928955,10000,1034.3701009750366,0.0697851553559303,5.347999572753906,0.0648000016808509,5.401942253112793,50000 -205.0133285522461,0.0906887054443359,1301.019949913025,2667,0,1301.019949913025,0.0861000046133995,5.135005950927734,10000,1506.262699842453,0.116367183625698,4.784284591674805,0.1093599945306778,4.85033655166626,50000 -255.52188277244568,0.1169803142547607,1721.0184772014618,3575,0,1721.0184772014618,0.1219000071287155,4.7613372802734375,10000,1976.84295463562,0.1799414008855819,4.237734794616699,0.1611000001430511,4.358086585998535,50000 -307.60889863967896,0.1440844535827636,2141.226948738098,4481,0,2141.226948738098,0.1535000056028366,4.470614910125732,10000,2449.213127851486,0.2168945223093032,3.952517986297608,0.2021399885416031,4.050927639007568,50000 -359.7699909210205,0.1761953830718994,2561.221724510193,5390,0,2561.221724510193,0.1933000087738037,4.191379547119141,10000,2921.4478216171265,0.2662499845027923,3.6010043621063232,0.2474399954080581,3.722931385040283,50000 -413.4065380096436,0.2065181732177734,2981.236423969269,6302,0,2981.236423969269,0.2187000066041946,3.96230673789978,10000,3395.177416563034,0.3132421672344208,3.299769163131714,0.2857999801635742,3.4638993740081787,50000 -464.2115104198456,0.2327077388763427,3401.2242062091827,7214,0,3401.2242062091827,0.2492000162601471,3.800960302352905,10000,3866.043285608292,0.3377148509025574,3.136587858200073,0.3125399947166443,3.273046493530273,50000 -513.9831821918488,0.2652533054351806,3821.3774077892303,8125,0,3821.3774077892303,0.268200010061264,3.6436169147491455,10000,4336.0475380420685,0.3757226467132568,2.8953447341918945,0.3463599979877472,3.064788341522217,50000 -565.3924582004547,0.2942073345184326,4241.45069026947,9031,0,4241.45069026947,0.2786000072956085,3.57048773765564,10000,4807.605703353882,0.3907421827316284,2.777474880218506,0.3562799990177154,2.988455295562744,50000 -617.273538351059,0.3234410285949707,4661.742277622223,9936,0,4661.742277622223,0.2880000174045563,3.4860610961914062,10000,5279.853937387466,0.3995312452316284,2.739250659942627,0.373879998922348,2.8733174800872803,50000 -670.9042701721191,0.3543593883514404,5081.717084169388,10843,0,5081.717084169388,0.2933000028133392,3.4927845001220703,10000,5753.537490606308,0.4029882848262787,2.765798807144165,0.3730599880218506,2.929877519607544,50000 -722.1547908782959,0.3821694850921631,5501.863046884537,11750,0,5501.863046884537,0.3113000094890594,3.375816583633423,10000,6225.0086896419525,0.4351171851158142,2.5774996280670166,0.3994199931621551,2.767735004425049,50000 -774.9626307487488,0.4158406257629394,5921.816805839539,12656,0,5921.816805839539,0.3148000240325928,3.308043956756592,10000,6697.850959300995,0.4435742199420929,2.502842903137207,0.4075399935245514,2.688591718673706,50000 -826.442524433136,0.4475893974304199,6342.030628442764,13561,0,6342.030628442764,0.3383000195026397,3.179518222808838,10000,7169.624086141586,0.4646874964237213,2.3886172771453857,0.4336400032043457,2.552778720855713,50000 -879.3584513664246,0.4796981811523437,6762.030198812485,14465,0,6762.030198812485,0.3439000248908996,3.1078310012817383,10000,7642.619294166565,0.4827734231948852,2.269944429397583,0.4453199803829193,2.4657223224639893,50000 -930.4304752349854,0.5069384574890137,7182.221163272858,15371,0,7182.221163272858,0.344400018453598,3.1176199913024902,10000,8113.95642209053,0.5088281035423279,2.144819974899292,0.4482799768447876,2.458869457244873,50000 -981.7441575527192,0.5383155345916748,7602.466048240662,16277,0,7602.466048240662,0.3543000221252441,3.049616813659668,10000,8585.593609571457,0.4963085949420929,2.2050812244415283,0.457999974489212,2.4059407711029053,50000 -1033.7704393863678,0.5695366859436035,8022.563579797745,17183,0,8022.563579797745,0.3633000254631042,3.038583993911743,10000,9057.795295476912,0.499804675579071,2.1979124546051025,0.4614399969577789,2.4079489707946777,50000 -1086.3128879070282,0.6030721664428711,8442.695451498032,18085,0,8442.695451498032,0.3718000054359436,2.949524641036988,10000,9530.55030632019,0.5384179353713989,1.9806084632873533,0.4762399792671203,2.3108084201812744,50000 -1137.469367980957,0.6330897808074951,8862.955168247223,18992,0,8862.955168247223,0.380700021982193,2.9152309894561768,10000,10002.043234586716,0.5228710770606995,2.079188823699951,0.487419992685318,2.272768020629883,50000 -1190.1830968856812,0.6724350452423096,9283.24651813507,19898,0,9283.24651813507,0.3789000213146209,2.931926727294922,10000,10475.134246349337,0.5240820050239563,2.092133045196533,0.4813199937343597,2.309021234512329,50000 -1241.9485597610474,0.7021317481994629,9703.740409374235,20801,0,9703.740409374235,0.3815000057220459,2.893972873687744,10000,10947.46983218193,0.5454296469688416,1.9585974216461184,0.4882199764251709,2.2411868572235107,50000 -1294.5543503761292,0.7387468814849854,10123.801826238632,21708,0,10123.801826238632,0.3945000171661377,2.8344380855560303,10000,11420.220302343369,0.5394921898841858,1.981076717376709,0.4965799748897552,2.200188636779785,50000 -1345.811369419098,0.7716727256774902,10543.785840034485,22610,0,10543.785840034485,0.380700021982193,2.866044759750366,10000,11891.540809392927,0.5367578268051147,1.97980535030365,0.4986799955368042,2.189040422439575,50000 -1397.873296737671,0.8026630878448486,10963.787488222122,23511,0,10963.787488222122,0.3975000083446502,2.843322515487671,10000,12363.682228326796,0.5564843416213989,1.8949991464614868,0.5028799772262573,2.1828277111053467,50000 -1450.3069911003113,0.8335375785827637,11383.82677412033,24413,0,11383.82677412033,0.399800032377243,2.798171281814575,10000,12836.232869625092,0.5447851419448853,1.9759751558303835,0.5110799670219421,2.1550962924957275,50000 -1502.8061113357544,0.8648619651794434,11803.81246328354,25317,0,11803.81246328354,0.4092000126838684,2.747539758682251,10000,13308.797145843506,0.5570703148841858,1.8775259256362915,0.5131399631500244,2.10324501991272,50000 -1555.772926568985,0.895737886428833,12223.884135246277,26218,0,12223.884135246277,0.4030000269412994,2.811911106109619,10000,13781.913187265396,0.5672265291213989,1.894603490829468,0.5156599879264832,2.146013259887696,50000 -1606.7444801330566,0.9290673732757568,12643.856747865677,27116,0,12643.856747865677,0.4043000340461731,2.7777791023254395,10000,14252.93661761284,0.5529491901397705,1.901785969734192,0.5184800028800964,2.0907881259918213,50000 -1658.8222217559814,0.9629182815551758,13064.020081281662,28016,0,13064.020081281662,0.4104000329971313,2.7968037128448486,10000,14725.257707118988,0.5609570145606995,1.945162057876587,0.5189399719238281,2.1581993103027344,50000 -1710.7781717777252,0.99446702003479,13484.096655845642,28920,0,13484.096655845642,0.4128000140190124,2.717691898345948,10000,15197.368015289308,0.5832226276397705,1.7956725358963013,0.5298199653625488,2.0622482299804688,50000 -1764.337646484375,1.0288941860198977,13904.335749864578,29821,0,13904.335749864578,0.4171000123023987,2.6709110736846924,10000,15671.247034311296,0.5748437643051147,1.802386283874512,0.538159966468811,1.9932000637054443,50000 -1817.6779038906093,1.0666768550872805,14324.31052160263,30719,0,14324.31052160263,0.4170000255107879,2.6758742332458496,10000,16144.646352529526,0.5781054496765137,1.7854104042053225,0.5322200059890747,2.0117881298065186,50000 -1870.412175655365,1.103344440460205,14744.462392568588,31622,0,14744.462392568588,0.4200000166893005,2.6826159954071045,10000,16617.61550116539,0.5849804282188416,1.757678747177124,0.5351200103759766,2.018251895904541,50000 -1923.5166292190552,1.1434593200683594,15164.471905231476,32523,0,15164.471905231476,0.4262000322341919,2.658798217773437,10000,17090.816098451614,0.5832421779632568,1.8074359893798828,0.5415599942207336,2.009346008300781,50000 -1975.7538046836853,1.178342580795288,15584.505164146423,33427,0,15584.505164146423,0.4164000153541565,2.6936354637146,10000,17563.168486595154,0.583300769329071,1.8000158071517944,0.5366399884223938,2.02001428604126,50000 -2027.1375963687897,1.2206003665924072,16004.51957321167,34328,0,16004.51957321167,0.4330000281333923,2.5797040462493896,10000,18034.655037164688,0.6021679639816284,1.6713885068893433,0.5493800044059753,1.922240614891052,50000 -2077.322345972061,1.2598974704742432,16424.51560664177,35225,0,16424.51560664177,0.4413000345230102,2.5836009979248047,10000,18504.921072244644,0.6086132526397705,1.6429468393325806,0.5566200017929077,1.910609126091004,50000 -2128.4921691417694,1.2963852882385254,16844.621319770813,36125,0,16844.621319770813,0.4359000325202942,2.588874578475952,10000,18976.279500722885,0.5963281393051147,1.712787389755249,0.5558199882507324,1.907938957214356,50000 -2179.108320236206,1.329272985458374,17264.70506668091,37028,0,17264.70506668091,0.4410000145435333,2.56365966796875,10000,19447.05951809883,0.6057812571525574,1.6403157711029053,0.5555599927902222,1.885812759399414,50000 -2232.1497917175293,1.3694918155670166,17684.651047706604,37930,0,17684.651047706604,0.4363000094890594,2.5774965286254883,10000,19920.133969783783,0.6322460770606995,1.546537160873413,0.5542399883270264,1.9083751440048216,50000 -2284.9851546287537,1.4052977561950684,18104.869975566864,38831,0,18104.869975566864,0.4299000203609466,2.600237846374512,10000,20393.270917892456,0.5987695455551147,1.698283076286316,0.5547999739646912,1.91600239276886,50000 -2349.030675649643,1.4428019523620603,18524.83588886261,39728,0,18524.83588886261,0.4417000114917755,2.5730700492858887,10000,20877.36520934105,0.609570324420929,1.6623051166534424,0.5617200136184692,1.9099894762039185,50000 -2400.5719459056854,1.515040397644043,18944.7302005291,40622,0,18944.7302005291,0.4377000331878662,2.548011302947998,10000,21348.91973233223,0.6304491758346558,1.5252255201339722,0.5624600052833557,1.8603116273880005,50000 -2453.674109697342,1.5481061935424805,19365.09273838997,41519,0,19365.09273838997,0.4422000348567962,2.518946886062622,10000,21822.46431851387,0.6089648008346558,1.633178949356079,0.568839967250824,1.8405061960220337,50000 -2505.828282117844,1.584458351135254,19785.181124925613,42423,0,19785.181124925613,0.4497000277042389,2.5079567432403564,10000,22294.79103422165,0.6158398389816284,1.5993508100509644,0.5668599605560303,1.833824157714844,50000 -2557.823818206787,1.628956317901611,20205.273720502853,43326,0,20205.273720502853,0.4489000141620636,2.537725687026977,10000,22766.971638917923,0.6241992115974426,1.6059962511062622,0.5625,1.9052478075027464,50000 -2609.916860103607,1.6631481647491455,20625.553758859634,44229,0,20625.553758859634,0.4519000351428985,2.4976508617401123,10000,23239.426404237747,0.6158398389816284,1.622150421142578,0.5743799805641174,1.831883668899536,50000 -2664.1930780410767,2.1820499897003174,21044.99322247505,45133,0,21044.99322247505,0.4542000293731689,2.4946107864379883,10000,23713.70925736428,0.6198632717132568,1.6021332740783691,0.5716800093650818,1.8257529735565183,50000 -2716.657520532608,2.2180991172790527,21465.308529138565,46038,0,21465.308529138565,0.4570000171661377,2.525801658630371,10000,24186.573442697525,0.6359961032867432,1.5632343292236328,0.5751199722290039,1.8546125888824463,50000 -2770.924865722656,2.2558722496032715,21885.264724254608,46941,0,21885.264724254608,0.4579000174999237,2.460035562515259,10000,24660.88339877129,0.6237695217132568,1.593305587768555,0.583840012550354,1.7873661518096924,50000 -2824.276191473007,2.2912559509277344,22305.34062218666,47845,0,22305.34062218666,0.4599000215530395,2.456450223922729,10000,25134.393942832947,0.6243554353713989,1.5742669105529783,0.5797199606895447,1.8012193441390991,50000 -2876.451560020447,2.335695505142212,22725.4294102192,48753,0,22725.4294102192,0.4649000167846679,2.4156606197357178,10000,25606.750625133514,0.6461523175239563,1.4468914270401,0.5879200100898743,1.7366979122161863,50000 -2928.817025184632,2.37554931640625,23145.645364046097,49661,0,23145.645364046097,0.4748000204563141,2.404741048812866,10000,26079.421197891235,0.6366406083106995,1.5108044147491455,0.5891199707984924,1.732394456863403,50000 -2981.225423812866,2.413803577423096,23565.832825899124,50559,0,23565.832825899124,0.4726000130176544,2.380942344665528,10000,26552.102893590927,0.637988269329071,1.5000442266464231,0.5877199769020081,1.7394182682037354,50000 -3033.735938310623,2.457144498825073,23985.97656941414,51462,0,23985.97656941414,0.46670001745224,2.434500217437744,10000,27024.848866462708,0.6421484351158142,1.497378706932068,0.5862199664115906,1.781178593635559,50000 -3086.1749098300934,2.498885154724121,24406.12998008728,52366,0,24406.12998008728,0.4722000360488891,2.3861453533172607,10000,27497.53017377853,0.6407226324081421,1.4778684377670288,0.5956599712371826,1.697656512260437,50000 -3138.3990700244904,2.5419416427612305,24826.21729016304,53267,0,24826.21729016304,0.4702000319957733,2.407470703125,10000,27969.93226337433,0.6386523246765137,1.511078119277954,0.5876399874687195,1.755619764328003,50000 -3190.0559356212616,2.578366041183472,25246.371876716614,54170,0,25246.371876716614,0.4725000262260437,2.3761520385742188,10000,28441.82798075676,0.6471874713897705,1.4455610513687134,0.5920599699020386,1.721648097038269,50000 -3241.9795260429382,2.6154439449310303,25666.634063720703,55074,0,25666.634063720703,0.4713000357151031,2.3855762481689453,10000,28914.098296403885,0.6447460651397705,1.47289776802063,0.5917400121688843,1.7395529747009275,50000 -3294.8521118164062,2.6539435386657715,26086.77654409409,55970,0,26086.77654409409,0.4783000349998474,2.387146711349488,10000,29387.199322223663,0.6394140720367432,1.5037568807601929,0.5911999940872192,1.7273492813110352,50000 -3347.73197889328,2.6959104537963867,26506.76273608208,56872,0,26506.76273608208,0.4745000302791595,2.365806579589844,10000,29860.154549121857,0.6453320384025574,1.4583547115325928,0.5988999605178833,1.701833724975586,50000 -3399.9325094223022,2.73496150970459,26926.744647026066,57778,0,26926.744647026066,0.4857000112533569,2.3210160732269287,10000,30332.42511534691,0.6798242330551147,1.318408727645874,0.6038199663162231,1.671668529510498,50000 -3450.595221042633,2.778135299682617,27346.96013259888,58680,0,27346.96013259888,0.4785000085830688,2.356879711151123,10000,30803.39335131645,0.6500585675239563,1.4647369384765625,0.5982599854469299,1.699963092803955,50000 -3503.4003579616547,2.815491199493408,27767.20039820671,59582,0,27767.20039820671,0.4764000177383423,2.356093168258667,10000,31276.52426123619,0.6515820026397705,1.4490962028503418,0.6025399565696716,1.689075589179993,50000 -3555.011140346527,2.856693983078003,28187.20550107956,60486,0,28187.20550107956,0.4871000349521637,2.3113863468170166,10000,31748.22876477241,0.6818945407867432,1.308190941810608,0.6083399653434753,1.650006651878357,50000 -3606.759976387024,2.89612340927124,28607.25713896752,61388,0,28607.25713896752,0.4802000224590301,2.3440380096435547,10000,32220.116106033325,0.6479101181030273,1.457031011581421,0.6025999784469604,1.6835922002792358,50000 -3656.720737695694,2.934617042541504,29027.53623533249,62289,0,29027.53623533249,0.4809000194072723,2.35693907737732,10000,32690.442022562027,0.6520702838897705,1.4345024824142456,0.6013000011444092,1.6837353706359863,50000 -3709.118235349655,2.974860906600952,29447.884727954865,63194,0,29447.884727954865,0.4918000102043152,2.2952914237976074,10000,33163.27663230896,0.6779687404632568,1.3029876947402954,0.6114599704742432,1.6231791973114014,50000 -3761.9516339302063,3.0155563354492188,29868.319049596783,64090,0,29868.319049596783,0.4866000115871429,2.292811155319214,10000,33636.6328959465,0.6555468440055847,1.3943846225738523,0.6055799722671509,1.6383819580078125,50000 -3813.739384889602,3.059473991394043,30288.43496155739,64992,0,30288.43496155739,0.4844000339508056,2.3205599784851074,10000,34108.627766132355,0.6604882478713989,1.4079033136367798,0.6082199811935425,1.6526474952697754,50000 -3867.965404987335,3.1049630641937256,30708.749766349792,65896,0,30708.749766349792,0.4891000092029571,2.299282550811768,10000,34583.26203036308,0.6743749976158142,1.3156598806381226,0.6094799637794495,1.6173152923583984,50000 -3921.2101068496704,3.146657943725586,31128.717413187027,66799,0,31128.717413187027,0.4841000139713287,2.2967209815979004,10000,35056.56464314461,0.6649413704872131,1.3905673027038574,0.613599956035614,1.6307467222213743,50000 -3974.759640932083,3.186652421951294,31548.65906047821,67700,0,31548.65906047821,0.4953000247478485,2.2668533325195312,10000,35530.14267539978,0.666308581829071,1.380321025848389,0.6187199950218201,1.6084468364715576,50000 -4025.041212797165,3.2343225479125977,31968.73847270012,68601,0,31968.73847270012,0.491100013256073,2.2899532318115234,10000,36000.59927082062,0.6759960651397705,1.306284785270691,0.6125800013542175,1.6162729263305664,50000 -4079.8373177051535,3.27199387550354,32388.73214316368,69503,0,32388.73214316368,0.495600014925003,2.2810490131378174,10000,36475.47456264496,0.6673437356948853,1.392255187034607,0.6157999634742737,1.6166682243347168,50000 -4133.789508104324,3.3121135234832764,32808.83074641228,70398,0,32808.83074641228,0.4962000250816345,2.304437160491944,10000,36949.61338496208,0.6661718487739563,1.4046850204467771,0.6162399649620056,1.6451646089553833,50000 -4186.583588838577,3.3608882427215576,33228.976645469666,71300,0,33228.976645469666,0.4915000200271606,2.279672861099243,10000,37422.65373325348,0.671875,1.3597897291183472,0.6108399629592896,1.6405349969863892,50000 -4238.585715770721,3.405482292175293,33649.24731469154,72203,0,33649.24731469154,0.5045000314712524,2.2146520614624023,10000,37895.018824100494,0.6790234446525574,1.29867684841156,0.6268599629402161,1.5448366403579712,50000 -4291.450464725494,3.447964191436768,34069.50347137451,73109,0,34069.50347137451,0.4961000382900238,2.272181510925293,10000,38368.23072266579,0.6642382740974426,1.3958392143249512,0.6126999855041504,1.638340711593628,50000 -4344.784727811813,3.4918909072875977,34489.668355703354,74011,0,34489.668355703354,0.5003000497817993,2.2567830085754395,10000,38841.82125282288,0.6847460865974426,1.295527458190918,0.6244399547576904,1.5823724269866943,50000 -4396.087966918945,3.528670072555542,34909.960586071014,74915,0,34909.960586071014,0.499500036239624,2.259209156036377,10000,39313.50108551979,0.6771484017372131,1.3254854679107666,0.6175999641418457,1.5920697450637815,50000 -4446.619699478149,3.570764780044556,35330.13725447655,75819,0,35330.13725447655,0.4980000257492065,2.25400972366333,10000,39784.29939389229,0.6713671684265137,1.3427538871765137,0.6200799942016602,1.5841830968856812,50000 -4498.87525844574,3.6172690391540527,35750.3324637413,76720,0,35750.3324637413,0.5045000314712524,2.233535051345825,10000,40256.8454182148,0.6835741996765137,1.2883057594299316,0.6281999945640564,1.5592522621154783,50000 -4554.121444702148,3.6560237407684326,36170.52016687393,77628,0,36170.52016687393,0.5058000087738037,2.190584659576416,10000,40732.36606168747,0.7032226324081421,1.1986013650894165,0.6326799988746643,1.526818037033081,50000 -4606.165371894836,3.701697587966919,36590.65875458717,78530,0,36590.65875458717,0.5034000277519226,2.228025197982788,10000,41204.64127254486,0.6783398389816284,1.3336890935897827,0.6248399615287781,1.5802377462387085,50000 -4657.113859415054,3.745133638381958,37011.00729799271,79433,0,37011.00729799271,0.5078999996185303,2.239119291305542,10000,41676.02999925613,0.6779101490974426,1.3415353298187256,0.6250999569892883,1.593327283859253,50000 -4711.764025211334,3.782468557357788,37431.40174984932,80333,0,37431.40174984932,0.5133000016212463,2.168611526489258,10000,42151.15966916084,0.712890625,1.150959014892578,0.6348599791526794,1.51292884349823,50000 -4763.166525602341,3.8283190727233887,37851.49418973923,81236,0,37851.49418973923,0.4996000230312347,2.2001304626464844,10000,42622.74819707871,0.6807616949081421,1.2854599952697754,0.6318599581718445,1.5201311111450195,50000 -4815.883181333542,3.870955228805542,38271.61984539032,82140,0,38271.61984539032,0.5154000520706177,2.1858279705047607,10000,43095.682314157486,0.6885156035423279,1.2505003213882446,0.6342399716377258,1.5106966495513916,50000 -4868.257662296295,3.9152021408081055,38691.78195428848,83040,0,38691.78195428848,0.5186000466346741,2.158198595046997,10000,43568.31066441536,0.7089062333106995,1.1777950525283811,0.6371200084686279,1.510980486869812,50000 -4919.264448881149,3.956256151199341,39111.83612918854,83940,0,39111.83612918854,0.5134000182151794,2.157160758972168,10000,44039.46033358574,0.6898242235183716,1.2514363527297974,0.6396399736404419,1.4833993911743164,50000 -4971.682218551636,4.003744602203369,39531.96011543274,84841,0,39531.96011543274,0.5060000419616699,2.2265491485595703,10000,44512.09685301781,0.6862499713897705,1.308943748474121,0.6337400078773499,1.5579991340637207,50000 -5022.644863128662,4.050050735473633,39952.28334736824,85741,0,39952.28334736824,0.5157000422477722,2.157416820526123,10000,44983.476644039154,0.7101757526397705,1.163679122924805,0.6434400081634521,1.4765359163284302,50000 -5075.555291175842,4.092345476150513,40372.55580735207,86646,0,40372.55580735207,0.5220000147819519,2.1490726470947266,10000,45456.74989652634,0.6906445026397705,1.2567678689956665,0.6412999629974365,1.5003139972686768,50000 -5129.692433595657,4.137173175811768,40792.581500291824,87543,0,40792.581500291824,0.5128000378608704,2.205008029937744,10000,45931.0048494339,0.6939257383346558,1.2720842361450195,0.6372399926185608,1.5326356887817385,50000 -5181.660806417465,4.181311845779419,41212.59487867355,88444,0,41212.59487867355,0.5198000073432922,2.147436141967773,10000,46403.07826638222,0.7095507383346558,1.1825847625732422,0.6438800096511841,1.4855278730392456,50000 -5232.368206501007,4.225497245788574,41632.62812304497,89347,0,41632.62812304497,0.5243000388145447,2.1224191188812256,10000,46873.910395145416,0.697949230670929,1.2049740552902222,0.646120011806488,1.454306960105896,50000 -5283.711639404297,4.278661251068115,42052.653123140335,90248,0,42052.653123140335,0.5238000154495239,2.1029837131500244,10000,47345.378999471664,0.7050976157188416,1.1831070184707642,0.6489799618721008,1.4485565423965454,50000 -5336.162372112274,4.322976589202881,42472.69245290756,91150,0,42472.69245290756,0.523300051689148,2.1205849647521973,10000,47817.961052656174,0.7136132717132568,1.1587448120117188,0.645859956741333,1.4699233770370483,50000 -5388.570991754532,4.366878986358643,42893.01210975647,92050,0,42893.01210975647,0.5268000364303589,2.090376138687134,10000,48290.78095889092,0.7071288824081421,1.196994662284851,0.6503399610519409,1.4454437494277954,50000 -5440.473873138428,4.412621974945068,43313.314470767975,92948,0,43313.314470767975,0.5206000208854675,2.143444061279297,10000,48763.07976198197,0.7015038728713989,1.22853684425354,0.6442599892616272,1.4875376224517822,50000 -5492.004205703735,4.4586100578308105,43733.2725212574,93845,0,43733.2725212574,0.5323000550270081,2.060444593429565,10000,49234.66075634956,0.7256445288658142,1.0961700677871704,0.6585800051689148,1.40444016456604,50000 -5542.949748277664,4.503470420837402,44153.533059597015,94745,0,44153.533059597015,0.5218999981880188,2.109593629837036,10000,49705.9593732357,0.7081640362739563,1.187019944190979,0.6531999707221985,1.43330717086792,50000 -5594.092439651489,4.546373128890991,44573.71209073067,95645,0,44573.71209073067,0.5276000499725342,2.1062769889831543,10000,50177.37110543251,0.7108983993530273,1.1773436069488523,0.6543799638748169,1.4492274522781372,50000 -5647.579032897949,4.593611240386963,44993.845116853714,96550,0,44993.845116853714,0.5288000106811523,2.1328086853027344,10000,50651.08602452278,0.7174413800239563,1.1969013214111328,0.6540799736976624,1.4859381914138794,50000 -5698.86651301384,4.638423681259155,45413.95650100708,97448,0,45413.95650100708,0.534500002861023,2.111708402633667,10000,51122.57685351372,0.7165625095367432,1.188768744468689,0.6569199562072754,1.463334083557129,50000 -5749.567130565643,4.687403678894043,45834.25590658188,98347,0,45834.25590658188,0.5304000377655029,2.0756752490997314,10000,51593.67358827591,0.7182226181030273,1.1286863088607788,0.6610599756240845,1.4054856300354004,50000 -5802.036701917648,4.73491358757019,46254.56909441948,99248,0,46254.56909441948,0.535800039768219,2.052008390426636,10000,52066.551471710205,0.7224413752555847,1.107817769050598,0.660319983959198,1.3991767168045044,50000 -5855.741189956665,4.792380809783936,46674.90282559395,100151,0,46674.90282559395,0.5397000312805176,2.0368974208831787,10000,52540.69532442093,0.7393554449081421,1.0557329654693604,0.6615599989891052,1.396098256111145,50000 -5907.954854726791,4.839926719665527,47095.12478327751,101053,0,47095.12478327751,0.5364000201225281,2.0542080402374268,10000,53013.22602438927,0.7187304496765137,1.1348941326141355,0.6573399901390076,1.40870201587677,50000 -5958.642646789551,4.885095834732056,47515.08821105957,101956,0,47515.08821105957,0.5386000275611877,2.0662755966186523,10000,53483.97033381462,0.7291796803474426,1.140833616256714,0.6621400117874146,1.432441234588623,50000 -6011.463626861572,4.929841756820679,47935.08392548561,102858,0,47935.08392548561,0.5406000018119812,2.0302770137786865,10000,53956.87895011902,0.7438867092132568,1.0300265550613403,0.6672199964523315,1.3864587545394895,50000 -6062.30232834816,4.975308418273926,48355.07038998604,103763,0,48355.07038998604,0.5421000123023987,2.0381054878234863,10000,54427.79730439186,0.723437488079071,1.120510816574097,0.66975998878479,1.3746145963668823,50000 -6113.727089166641,5.020384788513184,48775.350959062576,104666,0,48775.350959062576,0.5452000498771667,1.9986283779144287,10000,54899.59563159943,0.7344530820846558,1.072016358375549,0.6725599765777588,1.3530502319335938,50000 -6166.253857374191,5.067118167877197,49195.26766204834,105559,0,49195.26766204834,0.5433000326156616,2.0017428398132324,10000,55372.13287234306,0.7469140291213989,0.9951573610305786,0.6689199805259705,1.3489623069763184,50000 -6218.311373949051,5.114095211029053,49615.44148349762,106458,0,49615.44148349762,0.5472000241279602,2.037179470062256,10000,55844.475420475006,0.730273425579071,1.1089922189712524,0.6736199855804443,1.3686671257019043,50000 -6269.040015935898,5.159966707229614,50035.62106966972,107364,0,50035.62106966972,0.5441000461578369,2.013227939605713,10000,56315.47707438469,0.731640636920929,1.0827327966690063,0.6683599948883057,1.369884967803955,50000 -6321.7311725616455,5.206001281738281,50455.64405369759,108262,0,50455.64405369759,0.5437000393867493,2.018178701400757,10000,56788.2848212719,0.7414648532867432,1.0510865449905396,0.6689800024032593,1.3708804845809937,50000 -6374.832296609879,5.255663156509399,50875.999675273895,109164,0,50875.999675273895,0.5509999990463257,1.997771978378296,10000,57261.83912968636,0.7342773079872131,1.0697695016860962,0.6750999689102173,1.3431190252304075,50000 -6425.622935056686,5.307147979736328,51296.14013576508,110071,0,51296.14013576508,0.550000011920929,2.0083322525024414,10000,57732.86963033676,0.7425194978713989,1.0702048540115356,0.6782400012016296,1.3490768671035769,50000 -6477.19561457634,5.351512670516968,51716.33648109436,110978,0,51716.33648109436,0.5574000477790833,1.9583526849746704,10000,58204.73041367531,0.7479101419448853,1.0002694129943848,0.6759999990463257,1.3279269933700562,50000 -6528.562636137009,5.398941516876221,52136.71265649796,111880,0,52136.71265649796,0.5534999966621399,2.0046257972717285,10000,58676.56856536865,0.7358593344688416,1.0901877880096436,0.6753999590873718,1.3613321781158447,50000 -6580.2001440525055,5.444950819015503,52557.00045776367,112785,0,52557.00045776367,0.5529000163078308,1.9855674505233765,10000,59148.58823752403,0.7383984327316284,1.064749836921692,0.6764199733734131,1.340150237083435,50000 -6634.199034690857,5.4898645877838135,52976.95989322662,113682,0,52976.95989322662,0.5527999997138977,1.991071343421936,10000,59622.63963222504,0.7493749856948853,1.0180168151855469,0.6759999990463257,1.3451844453811646,50000 -6687.553718566895,5.537170886993408,53397.26622343063,114585,0,53397.26622343063,0.5619000196456909,1.937618613243103,10000,60096.39663481712,0.74609375,1.0239732265472412,0.6865999698638916,1.2994922399520874,50000 -6738.60989689827,5.5869903564453125,53817.85648846626,115490,0,53817.85648846626,0.562000036239624,1.9419201612472528,10000,60568.140437603,0.7510351538658142,1.0083297491073608,0.687279999256134,1.2991552352905271,50000 -6790.61523938179,5.63244891166687,54237.85001659393,116390,0,54237.85001659393,0.5644000172615051,1.9282153844833367,10000,61040.232503175735,0.7574414014816284,0.960770845413208,0.6850199699401855,1.277678370475769,50000 -6841.529457330704,5.67902684211731,54657.93375110626,117290,0,54657.93375110626,0.5618000030517578,1.951360106468201,10000,61511.32391309738,0.7455663681030273,1.0378332138061523,0.6839199662208557,1.3181474208831787,50000 -6893.628536224365,5.7277820110321045,55078.15105628967,118185,0,55078.15105628967,0.5634000301361084,1.9548426866531368,10000,61983.7370262146,0.75244140625,1.0247377157211304,0.6880599856376648,1.310234546661377,50000 -6947.353275775909,5.781080007553101,55498.36799407005,119087,0,55498.36799407005,0.5605000257492065,1.9243587255477903,10000,62457.77957677841,0.758105456829071,0.9552485346794128,0.6878600120544434,1.2778600454330444,50000 -6999.096809387207,5.830645799636841,55918.59027314186,119991,0,55918.59027314186,0.5639000535011292,1.8964364528656008,10000,62929.843616724014,0.7644335627555847,0.937131941318512,0.6890999674797058,1.2613743543624878,50000 -7049.592758893967,5.876109838485718,56338.974491119385,120894,0,56338.974491119385,0.566100001335144,1.8966219425201416,10000,63400.81737446785,0.7629492282867432,0.9582368731498718,0.6942200064659119,1.2611232995986938,50000 -7102.412714481354,5.923561811447144,56759.05333805084,121797,0,56759.05333805084,0.5713000297546387,1.8828102350234983,10000,63873.81185340881,0.7608593702316284,0.9513280391693116,0.6955199837684631,1.255292892456055,50000 -7153.010915279388,5.974011421203613,57179.4742333889,122701,0,57179.4742333889,0.5670000314712524,1.9046696424484253,10000,64344.930874586105,0.7781640291213989,0.8763917088508606,0.6922599673271179,1.2633213996887207,50000 -7204.862086057663,6.020724296569824,57599.657051086426,123603,0,57599.657051086426,0.5670000314712524,1.907639145851136,10000,64817.05897903442,0.7577733993530273,0.9707837700843812,0.6915799975395203,1.2661999464035034,50000 -7256.597772836685,6.068363428115845,58019.7577214241,124503,0,58019.7577214241,0.5719000101089478,1.917148470878601,10000,65288.98992753029,0.7660741806030273,0.9606534838676452,0.6967200040817261,1.2727292776107788,50000 -7307.421809196472,6.114432096481323,58439.809049129486,125405,0,58439.809049129486,0.5746000409126282,1.8689360618591309,10000,65759.95871829987,0.7822851538658142,0.8623960018157959,0.6998800039291382,1.230411410331726,50000 -7358.625300168991,6.164425373077393,58860.01943874359,126304,0,58860.01943874359,0.5796000361442566,1.875036239624024,10000,66231.4698665142,0.764843761920929,0.9532604217529296,0.6993199586868286,1.2452449798583984,50000 -7410.638778209686,6.2095441818237305,59280.34904909134,127206,0,59280.34904909134,0.5750000476837158,1.863745212554932,10000,66703.90618491173,0.7702929377555847,0.9010846614837646,0.7008799910545349,1.2191319465637207,50000 -7462.659845113754,6.260125875473023,59700.57011055946,128109,0,59700.57011055946,0.579300045967102,1.843841552734375,10000,67176.2463812828,0.7816601395606995,0.8531544208526611,0.7025600075721741,1.214359164237976,50000 -7516.386365413666,6.315888404846191,60120.5441262722,129007,0,60120.5441262722,0.5808000564575195,1.853320598602295,10000,67650.05022072792,0.7702538967132568,0.9246523976325988,0.7058799862861633,1.214882493019104,50000 -7567.649868488312,6.368206024169922,60540.73557567597,129909,0,60540.73557567597,0.5808000564575195,1.8382052183151243,10000,68121.60476827621,0.7751367092132568,0.8928495049476624,0.7060399651527405,1.207680106163025,50000 -7619.496718406677,6.896216869354248,60960.21865844727,130812,0,60960.21865844727,0.5781000256538391,1.8685678243637085,10000,68593.51099419594,0.7800390720367432,0.8839324712753296,0.7036600112915039,1.2200918197631836,50000 -7673.678128242493,6.953572034835815,61380.37313580513,131714,0,61380.37313580513,0.5771000385284424,1.840156316757202,10000,69067.9515554905,0.7735546827316284,0.887130081653595,0.7057799696922302,1.1913654804229736,50000 -7727.336859464645,7.005480766296387,61800.565898656845,132613,0,61800.565898656845,0.5889000296592712,1.8052948713302608,10000,69541.9029185772,0.7814843654632568,0.8491562604904175,0.7097600102424622,1.1710385084152222,50000 -7780.478232383728,7.055763244628906,62220.73200464249,133511,0,62220.73200464249,0.5838000178337097,1.8079313039779663,10000,70015.30837655067,0.793652355670929,0.8182350993156433,0.7105199694633484,1.1751962900161743,50000 -7830.6907432079315,7.107162237167358,62640.94975614548,134412,0,62640.94975614548,0.591200053691864,1.7911345958709717,10000,70485.8378021717,0.7824413776397705,0.8449131846427917,0.7128199934959412,1.1581400632858276,50000 -7881.830750465393,7.161731958389282,63061.04555249214,135312,0,63061.04555249214,0.589900016784668,1.7887169122695925,10000,70957.17614507675,0.782910168170929,0.8488849401473999,0.7114799618721008,1.170685648918152,50000 -7933.790218830109,7.2181336879730225,63480.95566868782,136213,0,63480.95566868782,0.5837000012397766,1.847142100334168,10000,71429.1494038105,0.7852538824081421,0.8555783033370972,0.7083799839019775,1.213556170463562,50000 -7985.396743774414,7.266888856887817,63901.08681106568,137114,0,63901.08681106568,0.5837000012397766,1.829952836036682,10000,71900.98412513733,0.7789257764816284,0.8748505711555481,0.7115199565887451,1.1856586933135986,50000 -8036.531369924545,7.317673444747925,64321.17126393318,138009,0,64321.17126393318,0.5907000303268433,1.7869914770126345,10000,72372.30072402954,0.79212886095047,0.8283264636993408,0.717960000038147,1.1578574180603027,50000 -8088.132665634155,7.369465112686157,64741.42377257347,138908,0,64741.42377257347,0.5958000421524048,1.7757651805877686,10000,72844.25393557549,0.7974609136581421,0.8029530048370361,0.7170799970626831,1.1547406911849976,50000 -8140.616198778152,7.422507762908935,65161.54459476471,139808,0,65161.54459476471,0.5945000052452087,1.7900747060775757,10000,73316.9589650631,0.7875585556030273,0.8376666903495789,0.7172799706459045,1.1552940607070925,50000 -8193.26209139824,7.47014856338501,65581.90678691864,140711,0,65581.90678691864,0.6028000116348267,1.7677903175354004,10000,73790.06270742416,0.7987695336341858,0.7999251484870911,0.7211399674415588,1.141826629638672,50000 -8245.082573652267,7.520330667495727,66002.23467111588,141614,0,66002.23467111588,0.595300018787384,1.7889913320541382,10000,74262.30921554565,0.8019140362739563,0.7914997935295105,0.7195999622344971,1.156431794166565,50000 -8295.933718919754,7.569155216217041,66422.37011957169,142514,0,66422.37011957169,0.6007000207901001,1.7685530185699463,10000,74733.39208197594,0.8023241758346558,0.7893538475036621,0.7219600081443787,1.13852059841156,50000 -8347.408848524094,7.634051322937012,66842.30435395241,143414,0,66842.30435395241,0.5990000367164612,1.7969648838043213,10000,75204.91482305527,0.7995898127555847,0.8224471211433411,0.7217999696731567,1.164766550064087,50000 -8400.72471165657,7.686464548110962,67262.48490691185,144311,0,67262.48490691185,0.5958000421524048,1.7974989414215088,10000,75678.5110874176,0.8031640648841858,0.8152992725372314,0.7218199968338013,1.169253706932068,50000 -8452.127160549164,7.742059707641602,67682.51352453232,145214,0,67682.51352453232,0.6038000583648682,1.7650368213653564,10000,76150.04582834244,0.8151953220367432,0.7412189841270447,0.7258599996566772,1.138275146484375,50000 -8503.738919973373,7.794044256210327,68102.61263489723,146115,0,68102.61263489723,0.6037000417709351,1.7594192028045654,10000,76621.85620999336,0.8047655820846558,0.7920248508453369,0.726099967956543,1.126613736152649,50000 -8555.59926533699,7.842405796051025,68522.7866191864,147015,0,68522.7866191864,0.6011000275611877,1.7695281505584717,10000,77093.98651838303,0.8040234446525574,0.7940880656242371,0.7256399989128113,1.1346133947372437,50000 -8608.788478851318,7.894062995910644,68942.73813343048,147914,0,68942.73813343048,0.6055000424385071,1.7449946403503418,10000,77567.22704386711,0.8194531202316284,0.7242646813392639,0.7303400039672852,1.1180957555770874,50000 -8660.702143192291,7.949257135391235,69363.04661417007,148809,0,69363.04661417007,0.6076000332832336,1.71618390083313,10000,78039.55141115189,0.8100390434265137,0.7443209886550903,0.7328599691390991,1.0836445093154907,50000 -8714.796729803085,8.007672309875488,69783.11700677872,149710,0,69783.11700677872,0.6105000376701355,1.738780498504639,10000,78513.82325577736,0.8106640577316284,0.7712754011154175,0.7304799556732178,1.1216198205947876,50000 -8767.916935682297,8.061208963394165,70203.25068831444,150609,0,70203.25068831444,0.6121000051498413,1.7156203985214231,10000,78987.17783665657,0.8217577934265137,0.6940465569496155,0.7341199517250061,1.0832279920578003,50000 -8820.977076530457,8.114773273468018,70623.22665309906,151509,0,70623.22665309906,0.6114000082015991,1.724048733711243,10000,79460.3146841526,0.8132226467132568,0.7564131021499634,0.7326799631118774,1.1018426418304443,50000 -8871.157362699509,8.176369428634644,71043.51707410812,152411,0,71043.51707410812,0.6057000160217285,1.7319797277450562,10000,79930.89526510239,0.8139452934265137,0.7501934170722961,0.7333599925041199,1.1036591529846191,50000 -8923.512953042984,8.234338998794556,71463.45511341095,153312,0,71463.45511341095,0.6135000586509705,1.6989619731903076,10000,80403.29509854317,0.826464831829071,0.6860719919204712,0.7383999824523926,1.0645508766174316,50000 -8975.040938138962,8.287301778793335,71883.42968702316,154210,0,71883.42968702316,0.6165000200271606,1.6909252405166626,10000,80874.89892697334,0.8174414038658142,0.7186430096626282,0.7394199967384338,1.062025547027588,50000 -9026.914111852646,8.34177279472351,72303.53037238121,155111,0,72303.53037238121,0.6168000102043152,1.6916823387145996,10000,81346.9751830101,0.8227929472923279,0.6901466250419617,0.7375199794769287,1.0573351383209229,50000 -9080.40405368805,8.39595341682434,72723.77445316315,156013,0,72723.77445316315,0.6097000241279602,1.709358811378479,10000,81820.81177711487,0.8246484398841858,0.6990688443183899,0.7361199855804443,1.0750834941864014,50000 -9131.963080406187,8.450865745544434,73143.83418130875,156915,0,73143.83418130875,0.6189000010490417,1.6846977472305298,10000,82292.53384494781,0.8210351467132568,0.7016533613204956,0.74263995885849,1.0572545528411863,50000 -9182.99531197548,8.513504981994629,73563.9743270874,157815,0,73563.9743270874,0.615600049495697,1.6868818998336792,10000,82763.81664323807,0.8214843273162842,0.6942675709724426,0.7416399717330933,1.0510855913162231,50000 -9234.988682985306,8.569389820098877,73984.35371112823,158714,0,73984.35371112823,0.6203000545501709,1.66560959815979,10000,83236.29381513596,0.8269335627555847,0.649540901184082,0.7415800094604492,1.031722068786621,50000 -9286.535109996796,8.62211298942566,74404.43561291695,159614,0,74404.43561291695,0.6204000115394592,1.683318853378296,10000,83708.0224416256,0.8246288895606995,0.6832771897315979,0.7423799633979797,1.0461320877075195,50000 -9337.866757631302,8.684773206710815,74824.56441307068,160509,0,74824.56441307068,0.6167000532150269,1.6868876218795776,10000,84179.59306025505,0.8296093344688416,0.6793026328086853,0.744159996509552,1.0463520288467407,50000 -9391.32567691803,8.740829706192017,75244.89977121353,161411,0,75244.89977121353,0.6142000555992126,1.708626389503479,10000,84653.49061632156,0.8265038728713989,0.695087730884552,0.7425000071525574,1.076271653175354,50000 -9442.36303639412,8.792722463607788,75665.09418177605,162313,0,75665.09418177605,0.6253000497817993,1.6685092449188232,10000,85124.82177829742,0.8310937285423279,0.6638551950454712,0.7471999526023865,1.0294922590255735,50000 -9493.877099275587,8.844767093658447,76085.4808113575,163214,0,76085.4808113575,0.6261000037193298,1.664926290512085,10000,85596.82547187805,0.8335351347923279,0.658676266670227,0.7470600008964539,1.034615993499756,50000 -9544.983451128006,8.900535106658936,76505.46322965622,164114,0,76505.46322965622,0.6234000325202942,1.6698811054229736,10000,86068.02240657806,0.8374804258346558,0.6364269256591797,0.747439980506897,1.0335808992385864,50000 -9596.894824266434,8.953065156936646,76925.65456914902,165017,0,76925.65456914902,0.6254000067710876,1.6488336324691772,10000,86540.2251765728,0.838183581829071,0.639176070690155,0.7498399615287781,1.0234603881835938,50000 -9649.550794363022,9.01101303100586,77345.65645003319,165916,0,77345.65645003319,0.6272000074386597,1.657045841217041,10000,87012.98839354515,0.8354687094688416,0.6519188284873962,0.7498599886894226,1.0220907926559448,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/measurements.csv deleted file mode 100644 index 5c3126177..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1850 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.34849083,6.907756,,,,,,,,,,,,,, -1,,,0.0009374999790452,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,40.29682993888855,93.49904775619508,40.29682993888855,53.20211386680603,0.0,0.0 -100,0.4443306,6.882812,,,,,,,,,,,,,, -200,0.7352036,6.80726,,,,,,,,,,,,,, -300,0.7514103,6.675519,,,,,,,,,,,,,, -400,0.9284753,6.5976152,,,,,,,,,,,,,, -500,0.95151836,6.701248,,,,,,,,,,,,,, -600,1.3488687,6.509419,,,,,,,,,,,,,, -700,1.3643855,6.347097,,,,,,,,,,,,,, -800,0.968558,6.563069,,,,,,,,,,,,,, -850,,,0.0340039059519767,5.898688793182373,0.0292599983513355,5.964993476867676,50000.0,0.0224000010639429,6.085418224334717,10000.0,460.35428380966187,564.3717305660248,460.35428380966187,103.95526385307312,0.0180015563964843,0.0 -900,1.7616351,6.4907074,,,,,,,,,,,,,, -1000,1.2228092,6.0826826,,,,,,,,,,,,,, -1100,1.378355,6.7012773,,,,,,,,,,,,,, -1200,1.2821903,6.041514,,,,,,,,,,,,,, -1300,1.8654741,6.73869,,,,,,,,,,,,,, -1400,1.0578218,6.1379414,,,,,,,,,,,,,, -1500,0.9764123,6.1871905,,,,,,,,,,,,,, -1600,0.9599773,5.780385,,,,,,,,,,,,,, -1700,1.0374498,5.712042,,,,,,,,,,,,,, -1755,,,0.0697851553559303,5.347999572753906,0.0648000016808509,5.401942253112793,50000.0,0.0504000037908554,5.601574420928955,10000.0,880.7335257530212,1034.3701009750366,880.7335257530212,153.4896755218506,0.0557641983032226,0.0 -1800,0.931982,5.8168073,,,,,,,,,,,,,, -1900,1.0327837,6.3913755,,,,,,,,,,,,,, -2000,0.8433645,6.545988,,,,,,,,,,,,,, -2100,1.0915482,5.647473,,,,,,,,,,,,,, -2200,0.8109934,6.1455503,,,,,,,,,,,,,, -2300,0.8435701,5.6287546,,,,,,,,,,,,,, -2400,1.3954287,5.495467,,,,,,,,,,,,,, -2500,0.62587804,6.3627167,,,,,,,,,,,,,, -2600,0.9031124,6.068066,,,,,,,,,,,,,, -2667,,,0.116367183625698,4.784284591674805,0.1093599945306778,4.85033655166626,50000.0,0.0861000046133995,5.135005950927734,10000.0,1301.019949913025,1506.262699842453,1301.019949913025,205.0133285522461,0.0906887054443359,0.0 -2700,0.8020141,5.7380643,,,,,,,,,,,,,, -2800,0.9835427,5.433894,,,,,,,,,,,,,, -2900,0.9074565,5.461616,,,,,,,,,,,,,, -3000,1.0494418,5.283935,,,,,,,,,,,,,, -3100,0.9560899,5.1937137,,,,,,,,,,,,,, -3200,0.8134841,6.498567,,,,,,,,,,,,,, -3300,0.7585902,5.239679,,,,,,,,,,,,,, -3400,0.9399389,5.69285,,,,,,,,,,,,,, -3500,0.8435208,5.173619,,,,,,,,,,,,,, -3575,,,0.1799414008855819,4.237734794616699,0.1611000001430511,4.358086585998535,50000.0,0.1219000071287155,4.7613372802734375,10000.0,1721.0184772014618,1976.84295463562,1721.0184772014618,255.52188277244568,0.1169803142547607,0.0 -3600,1.1445765,5.2352695,,,,,,,,,,,,,, -3700,0.8547691,5.0759273,,,,,,,,,,,,,, -3800,0.9461258,5.0614657,,,,,,,,,,,,,, -3900,0.92978287,4.9388657,,,,,,,,,,,,,, -4000,0.9120948,6.361823,,,,,,,,,,,,,, -4100,0.71929115,6.2600584,,,,,,,,,,,,,, -4200,0.7548473,5.1207395,,,,,,,,,,,,,, -4300,0.6368809,6.2363396,,,,,,,,,,,,,, -4400,0.89165556,4.8076725,,,,,,,,,,,,,, -4481,,,0.2168945223093032,3.952517986297608,0.2021399885416031,4.050927639007568,50000.0,0.1535000056028366,4.470614910125732,10000.0,2141.226948738098,2449.213127851486,2141.226948738098,307.60889863967896,0.1440844535827636,0.0 -4500,0.6711258,5.855978,,,,,,,,,,,,,, -4600,0.9121632,4.667151,,,,,,,,,,,,,, -4700,0.8307198,4.8389916,,,,,,,,,,,,,, -4800,0.8443097,4.5668774,,,,,,,,,,,,,, -4900,0.69733644,5.551006,,,,,,,,,,,,,, -5000,0.71797925,5.3108687,,,,,,,,,,,,,, -5100,1.0755372,4.8711524,,,,,,,,,,,,,, -5200,0.7367353,4.642479,,,,,,,,,,,,,, -5300,1.1470411,4.574643,,,,,,,,,,,,,, -5390,,,0.2662499845027923,3.6010043621063232,0.2474399954080581,3.722931385040283,50000.0,0.1933000087738037,4.191379547119141,10000.0,2561.221724510193,2921.4478216171265,2561.221724510193,359.7699909210205,0.1761953830718994,0.0 -5400,0.8621529,4.641099,,,,,,,,,,,,,, -5500,0.871458,4.462156,,,,,,,,,,,,,, -5600,0.7234158,6.097019,,,,,,,,,,,,,, -5700,0.56961346,6.137297,,,,,,,,,,,,,, -5800,0.82227045,4.6566215,,,,,,,,,,,,,, -5900,0.8204822,4.1617517,,,,,,,,,,,,,, -6000,0.7489298,4.372951,,,,,,,,,,,,,, -6100,0.78805757,4.320787,,,,,,,,,,,,,, -6200,0.71898615,4.7354474,,,,,,,,,,,,,, -6300,0.51781565,6.115749,,,,,,,,,,,,,, -6302,,,0.3132421672344208,3.299769163131714,0.2857999801635742,3.4638993740081787,50000.0,0.2187000066041946,3.96230673789978,10000.0,2981.236423969269,3395.177416563034,2981.236423969269,413.4065380096436,0.2065181732177734,0.0 -6400,0.54109955,5.3354826,,,,,,,,,,,,,, -6500,0.82053167,4.5196033,,,,,,,,,,,,,, -6600,0.8007403,4.439628,,,,,,,,,,,,,, -6700,0.71822006,4.7494316,,,,,,,,,,,,,, -6800,0.7026888,4.487993,,,,,,,,,,,,,, -6900,0.793103,4.162724,,,,,,,,,,,,,, -7000,0.8399517,4.223076,,,,,,,,,,,,,, -7100,0.94433385,4.0886626,,,,,,,,,,,,,, -7200,0.7694431,4.2483416,,,,,,,,,,,,,, -7214,,,0.3377148509025574,3.136587858200073,0.3125399947166443,3.273046493530273,50000.0,0.2492000162601471,3.800960302352905,10000.0,3401.2242062091827,3866.043285608292,3401.2242062091827,464.2115104198456,0.2327077388763427,0.0 -7300,0.635819,6.100253,,,,,,,,,,,,,, -7400,0.9480798,4.133012,,,,,,,,,,,,,, -7500,0.740777,3.9600048,,,,,,,,,,,,,, -7600,0.7352216,4.3503056,,,,,,,,,,,,,, -7700,0.7983889,4.0614395,,,,,,,,,,,,,, -7800,0.79838824,4.1439238,,,,,,,,,,,,,, -7900,0.824501,4.5775123,,,,,,,,,,,,,, -8000,0.8377623,5.929782,,,,,,,,,,,,,, -8100,0.8963583,4.098383,,,,,,,,,,,,,, -8125,,,0.3757226467132568,2.8953447341918945,0.3463599979877472,3.064788341522217,50000.0,0.268200010061264,3.6436169147491455,10000.0,3821.3774077892303,4336.0475380420685,3821.3774077892303,513.9831821918488,0.2652533054351806,0.0 -8200,0.6258759,5.102156,,,,,,,,,,,,,, -8300,1.0483307,4.1589646,,,,,,,,,,,,,, -8400,0.9431401,3.9290943,,,,,,,,,,,,,, -8500,0.94626355,3.9367218,,,,,,,,,,,,,, -8600,0.7231335,4.9335427,,,,,,,,,,,,,, -8700,0.5518955,5.7032127,,,,,,,,,,,,,, -8800,1.0394399,3.8839886,,,,,,,,,,,,,, -8900,0.6781293,5.5741887,,,,,,,,,,,,,, -9000,0.76665634,4.322622,,,,,,,,,,,,,, -9031,,,0.3907421827316284,2.777474880218506,0.3562799990177154,2.988455295562744,50000.0,0.2786000072956085,3.57048773765564,10000.0,4241.45069026947,4807.605703353882,4241.45069026947,565.3924582004547,0.2942073345184326,0.0 -9100,1.0786437,3.8787222,,,,,,,,,,,,,, -9200,0.9016974,3.7941403,,,,,,,,,,,,,, -9300,0.66394216,5.3313665,,,,,,,,,,,,,, -9400,1.0742974,3.7233307,,,,,,,,,,,,,, -9500,0.70571357,5.9301496,,,,,,,,,,,,,, -9600,0.62388927,5.602517,,,,,,,,,,,,,, -9700,0.63262373,5.7620106,,,,,,,,,,,,,, -9800,0.8991698,3.7268045,,,,,,,,,,,,,, -9900,0.6353771,5.8739758,,,,,,,,,,,,,, -9936,,,0.3995312452316284,2.739250659942627,0.373879998922348,2.8733174800872803,50000.0,0.2880000174045563,3.4860610961914062,10000.0,4661.742277622223,5279.853937387466,4661.742277622223,617.273538351059,0.3234410285949707,0.0 -10000,0.8948127,3.6189036,,,,,,,,,,,,,, -10100,0.78268147,5.8798656,,,,,,,,,,,,,, -10200,0.7680445,5.194078,,,,,,,,,,,,,, -10300,0.65280503,5.0339227,,,,,,,,,,,,,, -10400,0.7647078,4.5496254,,,,,,,,,,,,,, -10500,0.72008926,4.4980073,,,,,,,,,,,,,, -10600,0.8897404,3.617826,,,,,,,,,,,,,, -10700,0.936233,3.7394829,,,,,,,,,,,,,, -10800,0.85198265,3.7789235,,,,,,,,,,,,,, -10843,,,0.4029882848262787,2.765798807144165,0.3730599880218506,2.929877519607544,50000.0,0.2933000028133392,3.4927845001220703,10000.0,5081.717084169388,5753.537490606308,5081.717084169388,670.9042701721191,0.3543593883514404,0.0 -10900,0.64360285,4.915125,,,,,,,,,,,,,, -11000,0.8911729,3.5333214,,,,,,,,,,,,,, -11100,1.0186527,3.599619,,,,,,,,,,,,,, -11200,0.8681889,3.7987041,,,,,,,,,,,,,, -11300,0.9616859,3.606094,,,,,,,,,,,,,, -11400,0.7993257,5.7109966,,,,,,,,,,,,,, -11500,0.78860885,4.4550247,,,,,,,,,,,,,, -11600,0.90986824,3.6579468,,,,,,,,,,,,,, -11700,0.7039792,5.085369,,,,,,,,,,,,,, -11750,,,0.4351171851158142,2.5774996280670166,0.3994199931621551,2.767735004425049,50000.0,0.3113000094890594,3.375816583633423,10000.0,5501.863046884537,6225.0086896419525,5501.863046884537,722.1547908782959,0.3821694850921631,0.0 -11800,0.9690445,3.5527248,,,,,,,,,,,,,, -11900,0.8885082,5.7120795,,,,,,,,,,,,,, -12000,0.975726,3.3633554,,,,,,,,,,,,,, -12100,0.9297657,4.1537576,,,,,,,,,,,,,, -12200,0.84221125,3.6468592,,,,,,,,,,,,,, -12300,1.0051352,3.5333488,,,,,,,,,,,,,, -12400,0.85600907,3.698825,,,,,,,,,,,,,, -12500,1.1239744,3.476401,,,,,,,,,,,,,, -12600,0.95804644,4.241269,,,,,,,,,,,,,, -12656,,,0.4435742199420929,2.502842903137207,0.4075399935245514,2.688591718673706,50000.0,0.3148000240325928,3.308043956756592,10000.0,5921.816805839539,6697.850959300995,5921.816805839539,774.9626307487488,0.4158406257629394,0.0 -12700,0.91382754,3.485703,,,,,,,,,,,,,, -12800,0.7039283,5.048729,,,,,,,,,,,,,, -12900,0.9087301,3.4486573,,,,,,,,,,,,,, -13000,1.0020807,3.4589767,,,,,,,,,,,,,, -13100,0.9468751,3.5956185,,,,,,,,,,,,,, -13200,0.9694759,3.418402,,,,,,,,,,,,,, -13300,1.0300007,3.373278,,,,,,,,,,,,,, -13400,0.68398786,5.810048,,,,,,,,,,,,,, -13500,0.954297,3.4188085,,,,,,,,,,,,,, -13561,,,0.4646874964237213,2.3886172771453857,0.4336400032043457,2.552778720855713,50000.0,0.3383000195026397,3.179518222808838,10000.0,6342.030628442764,7169.624086141586,6342.030628442764,826.442524433136,0.4475893974304199,0.0 -13600,0.8066623,5.060335,,,,,,,,,,,,,, -13700,0.72619694,5.0077877,,,,,,,,,,,,,, -13800,0.7640293,4.563603,,,,,,,,,,,,,, -13900,0.6969559,5.44791,,,,,,,,,,,,,, -14000,1.0989634,3.4499414,,,,,,,,,,,,,, -14100,1.0485281,3.3886364,,,,,,,,,,,,,, -14200,1.1424155,3.4183774,,,,,,,,,,,,,, -14300,0.8973796,4.9426117,,,,,,,,,,,,,, -14400,1.1225215,3.517126,,,,,,,,,,,,,, -14465,,,0.4827734231948852,2.269944429397583,0.4453199803829193,2.4657223224639893,50000.0,0.3439000248908996,3.1078310012817383,10000.0,6762.030198812485,7642.619294166565,6762.030198812485,879.3584513664246,0.4796981811523437,0.0 -14500,0.96208954,3.2400565,,,,,,,,,,,,,, -14600,1.00046,3.3484747,,,,,,,,,,,,,, -14700,1.0232434,3.2920134,,,,,,,,,,,,,, -14800,1.0351564,3.3827107,,,,,,,,,,,,,, -14900,0.81533974,5.8458123,,,,,,,,,,,,,, -15000,0.8824767,4.2979884,,,,,,,,,,,,,, -15100,0.9695208,3.3092222,,,,,,,,,,,,,, -15200,0.80328614,3.955645,,,,,,,,,,,,,, -15300,0.8916852,3.7584984,,,,,,,,,,,,,, -15371,,,0.5088281035423279,2.144819974899292,0.4482799768447876,2.458869457244873,50000.0,0.344400018453598,3.1176199913024902,10000.0,7182.221163272858,8113.95642209053,7182.221163272858,930.4304752349854,0.5069384574890137,0.0 -15400,0.937875,3.3005283,,,,,,,,,,,,,, -15500,0.6906818,5.400544,,,,,,,,,,,,,, -15600,0.7244038,4.5304675,,,,,,,,,,,,,, -15700,0.9225302,3.2474713,,,,,,,,,,,,,, -15800,0.76194376,5.642768,,,,,,,,,,,,,, -15900,0.8782641,3.7337213,,,,,,,,,,,,,, -16000,0.85017174,4.192633,,,,,,,,,,,,,, -16100,0.7946743,5.6242666,,,,,,,,,,,,,, -16200,0.8916633,3.397703,,,,,,,,,,,,,, -16277,,,0.4963085949420929,2.2050812244415283,0.457999974489212,2.4059407711029053,50000.0,0.3543000221252441,3.049616813659668,10000.0,7602.466048240662,8585.593609571457,7602.466048240662,981.7441575527192,0.5383155345916748,0.0 -16300,1.1173893,3.4570317,,,,,,,,,,,,,, -16400,1.0941745,3.3100443,,,,,,,,,,,,,, -16500,1.0498028,3.1685886,,,,,,,,,,,,,, -16600,1.2556398,3.3591967,,,,,,,,,,,,,, -16700,1.0421902,3.2026827,,,,,,,,,,,,,, -16800,1.0603592,3.3999293,,,,,,,,,,,,,, -16900,0.78629255,4.6743336,,,,,,,,,,,,,, -17000,1.0314736,3.3620524,,,,,,,,,,,,,, -17100,1.0764056,3.1686215,,,,,,,,,,,,,, -17183,,,0.499804675579071,2.1979124546051025,0.4614399969577789,2.4079489707946777,50000.0,0.3633000254631042,3.038583993911743,10000.0,8022.563579797745,9057.795295476912,8022.563579797745,1033.7704393863678,0.5695366859436035,0.0 -17200,1.1675489,3.3498192,,,,,,,,,,,,,, -17300,1.0107642,3.1255977,,,,,,,,,,,,,, -17400,0.8856311,4.161525,,,,,,,,,,,,,, -17500,1.1000599,3.2690036,,,,,,,,,,,,,, -17600,1.2138036,3.2610326,,,,,,,,,,,,,, -17700,1.0136739,3.1916163,,,,,,,,,,,,,, -17800,1.0491089,3.4987469,,,,,,,,,,,,,, -17900,1.0235662,3.2256098,,,,,,,,,,,,,, -18000,1.0669837,3.2665737,,,,,,,,,,,,,, -18085,,,0.5384179353713989,1.9806084632873533,0.4762399792671203,2.3108084201812744,50000.0,0.3718000054359436,2.949524641036988,10000.0,8442.695451498032,9530.55030632019,8442.695451498032,1086.3128879070282,0.6030721664428711,0.0 -18100,0.7718257,5.601017,,,,,,,,,,,,,, -18200,0.9714069,3.1348507,,,,,,,,,,,,,, -18300,1.081531,3.3756423,,,,,,,,,,,,,, -18400,0.76443094,5.5828004,,,,,,,,,,,,,, -18500,0.97446996,3.2613554,,,,,,,,,,,,,, -18600,1.0895864,3.2161934,,,,,,,,,,,,,, -18700,0.96703863,3.300268,,,,,,,,,,,,,, -18800,1.256185,3.1410074,,,,,,,,,,,,,, -18900,1.2190143,3.1437213,,,,,,,,,,,,,, -18992,,,0.5228710770606995,2.079188823699951,0.487419992685318,2.272768020629883,50000.0,0.380700021982193,2.9152309894561768,10000.0,8862.955168247223,10002.043234586716,8862.955168247223,1137.469367980957,0.6330897808074951,0.0 -19000,1.1161317,3.1336148,,,,,,,,,,,,,, -19100,0.9458953,4.985527,,,,,,,,,,,,,, -19200,0.91733193,3.3421602,,,,,,,,,,,,,, -19300,1.0453357,3.1454027,,,,,,,,,,,,,, -19400,0.8977299,3.6115627,,,,,,,,,,,,,, -19500,1.0482838,3.1093698,,,,,,,,,,,,,, -19600,1.1358376,3.2784796,,,,,,,,,,,,,, -19700,1.0416274,3.136512,,,,,,,,,,,,,, -19800,0.99599826,3.5061471,,,,,,,,,,,,,, -19898,,,0.5240820050239563,2.092133045196533,0.4813199937343597,2.309021234512329,50000.0,0.3789000213146209,2.931926727294922,10000.0,9283.24651813507,10475.134246349337,9283.24651813507,1190.1830968856812,0.6724350452423096,0.0 -19900,0.9270892,3.5359874,,,,,,,,,,,,,, -20000,1.1284126,3.1652942,,,,,,,,,,,,,, -20100,1.0580264,4.2660775,,,,,,,,,,,,,, -20200,0.75657064,5.617144,,,,,,,,,,,,,, -20300,0.8634832,3.9515526,,,,,,,,,,,,,, -20400,1.0710996,5.619597,,,,,,,,,,,,,, -20500,0.8103847,5.6033945,,,,,,,,,,,,,, -20600,1.0675588,3.2406719,,,,,,,,,,,,,, -20700,1.0802896,3.0500727,,,,,,,,,,,,,, -20800,1.0412254,5.0150127,,,,,,,,,,,,,, -20801,,,0.5454296469688416,1.9585974216461184,0.4882199764251709,2.2411868572235107,50000.0,0.3815000057220459,2.893972873687744,10000.0,9703.740409374235,10947.46983218193,9703.740409374235,1241.9485597610474,0.7021317481994629,0.0 -20900,0.9739087,4.0339184,,,,,,,,,,,,,, -21000,0.74675256,4.930768,,,,,,,,,,,,,, -21100,1.1672946,3.172959,,,,,,,,,,,,,, -21200,1.0591465,3.2928417,,,,,,,,,,,,,, -21300,0.7514843,5.611923,,,,,,,,,,,,,, -21400,0.8535476,3.8842084,,,,,,,,,,,,,, -21500,0.7539666,5.488266,,,,,,,,,,,,,, -21600,0.76625913,5.460301,,,,,,,,,,,,,, -21700,0.7290415,5.307366,,,,,,,,,,,,,, -21708,,,0.5394921898841858,1.981076717376709,0.4965799748897552,2.200188636779785,50000.0,0.3945000171661377,2.8344380855560303,10000.0,10123.801826238632,11420.220302343369,10123.801826238632,1294.5543503761292,0.7387468814849854,0.0 -21800,1.0304661,3.1887045,,,,,,,,,,,,,, -21900,1.1626835,3.234624,,,,,,,,,,,,,, -22000,1.1709396,3.1551006,,,,,,,,,,,,,, -22100,1.0341693,3.277284,,,,,,,,,,,,,, -22200,0.9659013,4.219419,,,,,,,,,,,,,, -22300,1.0993294,3.1292999,,,,,,,,,,,,,, -22400,0.7346588,5.1285057,,,,,,,,,,,,,, -22500,1.0206919,3.903864,,,,,,,,,,,,,, -22600,0.9243336,3.6138449,,,,,,,,,,,,,, -22610,,,0.5367578268051147,1.97980535030365,0.4986799955368042,2.189040422439575,50000.0,0.380700021982193,2.866044759750366,10000.0,10543.785840034485,11891.540809392927,10543.785840034485,1345.811369419098,0.7716727256774902,0.0 -22700,1.1924273,3.1240704,,,,,,,,,,,,,, -22800,0.98039865,3.7251427,,,,,,,,,,,,,, -22900,1.1759207,3.384228,,,,,,,,,,,,,, -23000,1.1277888,3.0959747,,,,,,,,,,,,,, -23100,1.1279985,3.0529518,,,,,,,,,,,,,, -23200,1.1249439,3.1419432,,,,,,,,,,,,,, -23300,0.78212714,4.3424673,,,,,,,,,,,,,, -23400,1.0287756,3.0170557,,,,,,,,,,,,,, -23500,0.7659596,5.2883615,,,,,,,,,,,,,, -23511,,,0.5564843416213989,1.8949991464614868,0.5028799772262573,2.1828277111053467,50000.0,0.3975000083446502,2.843322515487671,10000.0,10963.787488222122,12363.682228326796,10963.787488222122,1397.873296737671,0.8026630878448486,0.0 -23600,1.1786404,3.7612681,,,,,,,,,,,,,, -23700,1.1272784,2.8796089,,,,,,,,,,,,,, -23800,1.0970424,3.5241728,,,,,,,,,,,,,, -23900,1.0837765,3.1783528,,,,,,,,,,,,,, -24000,0.8846114,4.3048635,,,,,,,,,,,,,, -24100,1.1157097,3.4045658,,,,,,,,,,,,,, -24200,0.9625013,3.9965749,,,,,,,,,,,,,, -24300,1.1013107,3.0820272,,,,,,,,,,,,,, -24400,1.1831138,4.9407067,,,,,,,,,,,,,, -24413,,,0.5447851419448853,1.9759751558303835,0.5110799670219421,2.1550962924957275,50000.0,0.399800032377243,2.798171281814575,10000.0,11383.82677412033,12836.232869625092,11383.82677412033,1450.3069911003113,0.8335375785827637,0.0 -24500,1.0138633,3.1273255,,,,,,,,,,,,,, -24600,1.1867331,3.1250427,,,,,,,,,,,,,, -24700,0.9618902,4.443325,,,,,,,,,,,,,, -24800,0.8236581,5.459691,,,,,,,,,,,,,, -24900,0.9518952,5.2105255,,,,,,,,,,,,,, -25000,0.81024563,5.61,,,,,,,,,,,,,, -25100,0.97962475,3.6606832,,,,,,,,,,,,,, -25200,0.901455,4.3643246,,,,,,,,,,,,,, -25300,0.93130666,5.225534,,,,,,,,,,,,,, -25317,,,0.5570703148841858,1.8775259256362915,0.5131399631500244,2.10324501991272,50000.0,0.4092000126838684,2.747539758682251,10000.0,11803.81246328354,13308.797145843506,11803.81246328354,1502.8061113357544,0.8648619651794434,0.0 -25400,1.121771,3.0548377,,,,,,,,,,,,,, -25500,1.0353297,3.0156548,,,,,,,,,,,,,, -25600,0.86030626,4.563039,,,,,,,,,,,,,, -25700,0.98118854,3.2838285,,,,,,,,,,,,,, -25800,1.1456379,3.0690784,,,,,,,,,,,,,, -25900,1.1963693,3.3093183,,,,,,,,,,,,,, -26000,1.3979365,2.9416146,,,,,,,,,,,,,, -26100,0.92368484,5.5292435,,,,,,,,,,,,,, -26200,1.0804803,3.367162,,,,,,,,,,,,,, -26218,,,0.5672265291213989,1.894603490829468,0.5156599879264832,2.146013259887696,50000.0,0.4030000269412994,2.811911106109619,10000.0,12223.884135246277,13781.913187265396,12223.884135246277,1555.772926568985,0.895737886428833,0.0 -26300,1.1282974,2.9541485,,,,,,,,,,,,,, -26400,0.9807191,3.8947318,,,,,,,,,,,,,, -26500,1.0798126,2.9270797,,,,,,,,,,,,,, -26600,1.1089404,2.9596622,,,,,,,,,,,,,, -26700,1.104547,3.062709,,,,,,,,,,,,,, -26800,0.805526,5.18433,,,,,,,,,,,,,, -26900,0.9060673,4.5167046,,,,,,,,,,,,,, -27000,1.1508203,2.8619583,,,,,,,,,,,,,, -27100,1.1246166,3.3392017,,,,,,,,,,,,,, -27116,,,0.5529491901397705,1.901785969734192,0.5184800028800964,2.0907881259918213,50000.0,0.4043000340461731,2.7777791023254395,10000.0,12643.856747865677,14252.93661761284,12643.856747865677,1606.7444801330566,0.9290673732757568,0.0 -27200,1.1220349,3.1462646,,,,,,,,,,,,,, -27300,1.1277862,2.7504084,,,,,,,,,,,,,, -27400,0.91113496,4.620497,,,,,,,,,,,,,, -27500,1.0143393,3.1635697,,,,,,,,,,,,,, -27600,0.89556694,5.4109497,,,,,,,,,,,,,, -27700,0.798762,4.487503,,,,,,,,,,,,,, -27800,0.8605928,5.4966526,,,,,,,,,,,,,, -27900,1.2549003,3.0577307,,,,,,,,,,,,,, -28000,1.1522425,2.902376,,,,,,,,,,,,,, -28016,,,0.5609570145606995,1.945162057876587,0.5189399719238281,2.1581993103027344,50000.0,0.4104000329971313,2.7968037128448486,10000.0,13064.020081281662,14725.257707118988,13064.020081281662,1658.8222217559814,0.9629182815551758,0.0 -28100,0.9946857,3.0646193,,,,,,,,,,,,,, -28200,0.9976774,3.9877088,,,,,,,,,,,,,, -28300,1.2510515,3.0680504,,,,,,,,,,,,,, -28400,1.0422565,2.858689,,,,,,,,,,,,,, -28500,1.1854593,3.5155692,,,,,,,,,,,,,, -28600,1.0108098,3.149994,,,,,,,,,,,,,, -28700,1.0643245,4.471058,,,,,,,,,,,,,, -28800,1.101119,4.7625723,,,,,,,,,,,,,, -28900,1.002415,3.5291872,,,,,,,,,,,,,, -28920,,,0.5832226276397705,1.7956725358963013,0.5298199653625488,2.0622482299804688,50000.0,0.4128000140190124,2.717691898345948,10000.0,13484.096655845642,15197.368015289308,13484.096655845642,1710.7781717777252,0.99446702003479,0.0 -29000,1.1382022,3.0207658,,,,,,,,,,,,,, -29100,0.8870618,4.525935,,,,,,,,,,,,,, -29200,0.8265731,3.5371928,,,,,,,,,,,,,, -29300,1.0846591,3.2151656,,,,,,,,,,,,,, -29400,1.0224975,3.080461,,,,,,,,,,,,,, -29500,1.0413125,3.0607905,,,,,,,,,,,,,, -29600,1.1686877,2.9258227,,,,,,,,,,,,,, -29700,1.1737826,2.8707554,,,,,,,,,,,,,, -29800,1.0317429,2.8616893,,,,,,,,,,,,,, -29821,,,0.5748437643051147,1.802386283874512,0.538159966468811,1.9932000637054443,50000.0,0.4171000123023987,2.6709110736846924,10000.0,13904.335749864578,15671.247034311296,13904.335749864578,1764.337646484375,1.0288941860198977,0.0 -29900,1.7931921,2.931974,,,,,,,,,,,,,, -30000,1.1844609,2.846205,,,,,,,,,,,,,, -30100,1.0082248,3.421433,,,,,,,,,,,,,, -30200,0.848236,3.9016209,,,,,,,,,,,,,, -30300,1.083163,2.9320078,,,,,,,,,,,,,, -30400,1.0963762,5.1842027,,,,,,,,,,,,,, -30500,1.0819052,3.4040477,,,,,,,,,,,,,, -30600,1.1011859,2.9046893,,,,,,,,,,,,,, -30700,1.0328454,3.6849644,,,,,,,,,,,,,, -30719,,,0.5781054496765137,1.7854104042053225,0.5322200059890747,2.0117881298065186,50000.0,0.4170000255107879,2.6758742332458496,10000.0,14324.31052160263,16144.646352529526,14324.31052160263,1817.6779038906093,1.0666768550872805,0.0 -30800,0.8245394,4.8419757,,,,,,,,,,,,,, -30900,1.019473,3.913994,,,,,,,,,,,,,, -31000,1.2618616,2.8070333,,,,,,,,,,,,,, -31100,1.0686338,2.8696702,,,,,,,,,,,,,, -31200,0.99737006,4.5777397,,,,,,,,,,,,,, -31300,1.033038,4.363845,,,,,,,,,,,,,, -31400,1.1689996,2.811573,,,,,,,,,,,,,, -31500,0.96647346,3.5080545,,,,,,,,,,,,,, -31600,1.0420835,2.8448257,,,,,,,,,,,,,, -31622,,,0.5849804282188416,1.757678747177124,0.5351200103759766,2.018251895904541,50000.0,0.4200000166893005,2.6826159954071045,10000.0,14744.462392568588,16617.61550116539,14744.462392568588,1870.412175655365,1.103344440460205,0.0 -31700,1.1343465,3.1291804,,,,,,,,,,,,,, -31800,1.0556238,3.048271,,,,,,,,,,,,,, -31900,1.2688606,2.8821635,,,,,,,,,,,,,, -32000,1.0828218,2.9314802,,,,,,,,,,,,,, -32100,1.1850439,5.4827557,,,,,,,,,,,,,, -32200,0.93503976,5.441985,,,,,,,,,,,,,, -32300,1.0320854,3.1747644,,,,,,,,,,,,,, -32400,1.1927608,3.0184743,,,,,,,,,,,,,, -32500,0.84223884,5.3877482,,,,,,,,,,,,,, -32523,,,0.5832421779632568,1.8074359893798828,0.5415599942207336,2.009346008300781,50000.0,0.4262000322341919,2.658798217773437,10000.0,15164.471905231476,17090.816098451614,15164.471905231476,1923.5166292190552,1.1434593200683594,0.0 -32600,1.1222118,2.8570037,,,,,,,,,,,,,, -32700,1.1572076,2.7999687,,,,,,,,,,,,,, -32800,1.0390922,2.7806256,,,,,,,,,,,,,, -32900,1.079821,2.8286633,,,,,,,,,,,,,, -33000,0.9547325,4.2652593,,,,,,,,,,,,,, -33100,1.1218317,2.733697,,,,,,,,,,,,,, -33200,1.2514726,2.906446,,,,,,,,,,,,,, -33300,1.1685443,2.838942,,,,,,,,,,,,,, -33400,0.8752509,4.4129868,,,,,,,,,,,,,, -33427,,,0.583300769329071,1.8000158071517944,0.5366399884223938,2.02001428604126,50000.0,0.4164000153541565,2.6936354637146,10000.0,15584.505164146423,17563.168486595154,15584.505164146423,1975.7538046836853,1.178342580795288,0.0 -33500,0.7840975,4.6238446,,,,,,,,,,,,,, -33600,1.1370703,2.8802392,,,,,,,,,,,,,, -33700,1.0878259,2.9004948,,,,,,,,,,,,,, -33800,1.1006458,2.9955606,,,,,,,,,,,,,, -33900,1.1712863,2.8824751,,,,,,,,,,,,,, -34000,1.1631105,2.9256837,,,,,,,,,,,,,, -34100,1.2243518,2.8398354,,,,,,,,,,,,,, -34200,1.2433724,3.0327358,,,,,,,,,,,,,, -34300,1.0241319,5.323407,,,,,,,,,,,,,, -34328,,,0.6021679639816284,1.6713885068893433,0.5493800044059753,1.922240614891052,50000.0,0.4330000281333923,2.5797040462493896,10000.0,16004.51957321167,18034.655037164688,16004.51957321167,2027.1375963687897,1.2206003665924072,0.0 -34400,1.0488526,2.923873,,,,,,,,,,,,,, -34500,1.2307694,2.8948202,,,,,,,,,,,,,, -34600,1.1361516,2.6936748,,,,,,,,,,,,,, -34700,1.1188233,2.7458787,,,,,,,,,,,,,, -34800,0.8543444,4.197951,,,,,,,,,,,,,, -34900,0.9840557,4.2000856,,,,,,,,,,,,,, -35000,0.9227489,5.3057933,,,,,,,,,,,,,, -35100,0.9593927,4.949767,,,,,,,,,,,,,, -35200,1.0476319,2.790627,,,,,,,,,,,,,, -35225,,,0.6086132526397705,1.6429468393325806,0.5566200017929077,1.910609126091004,50000.0,0.4413000345230102,2.5836009979248047,10000.0,16424.51560664177,18504.921072244644,16424.51560664177,2077.322345972061,1.2598974704742432,0.0 -35300,1.0084574,3.2868643,,,,,,,,,,,,,, -35400,1.0580374,2.6822948,,,,,,,,,,,,,, -35500,1.2267259,2.820411,,,,,,,,,,,,,, -35600,1.030077,4.852878,,,,,,,,,,,,,, -35700,1.200533,2.774024,,,,,,,,,,,,,, -35800,0.9458052,5.3714566,,,,,,,,,,,,,, -35900,1.1453891,2.8955522,,,,,,,,,,,,,, -36000,1.1667143,2.7246697,,,,,,,,,,,,,, -36100,1.0185535,2.7431107,,,,,,,,,,,,,, -36125,,,0.5963281393051147,1.712787389755249,0.5558199882507324,1.907938957214356,50000.0,0.4359000325202942,2.588874578475952,10000.0,16844.621319770813,18976.279500722885,16844.621319770813,2128.4921691417694,1.2963852882385254,0.0 -36200,1.0960879,2.7673912,,,,,,,,,,,,,, -36300,1.0738765,2.778564,,,,,,,,,,,,,, -36400,0.98774064,3.137389,,,,,,,,,,,,,, -36500,1.0217899,3.9904182,,,,,,,,,,,,,, -36600,1.0973132,2.7587547,,,,,,,,,,,,,, -36700,1.1353073,2.7415674,,,,,,,,,,,,,, -36800,1.2386618,2.8395858,,,,,,,,,,,,,, -36900,0.9000441,4.76402,,,,,,,,,,,,,, -37000,1.1273332,2.870947,,,,,,,,,,,,,, -37028,,,0.6057812571525574,1.6403157711029053,0.5555599927902222,1.885812759399414,50000.0,0.4410000145435333,2.56365966796875,10000.0,17264.70506668091,19447.05951809883,17264.70506668091,2179.108320236206,1.329272985458374,0.0 -37100,1.317245,2.8025403,,,,,,,,,,,,,, -37200,1.0620605,3.255684,,,,,,,,,,,,,, -37300,1.1526132,2.5672143,,,,,,,,,,,,,, -37400,1.1926044,2.6236787,,,,,,,,,,,,,, -37500,1.2014288,2.9605484,,,,,,,,,,,,,, -37600,1.1723772,2.7430258,,,,,,,,,,,,,, -37700,1.2173724,3.1419144,,,,,,,,,,,,,, -37800,0.9143556,4.177885,,,,,,,,,,,,,, -37900,0.94650733,4.337064,,,,,,,,,,,,,, -37930,,,0.6322460770606995,1.546537160873413,0.5542399883270264,1.9083751440048216,50000.0,0.4363000094890594,2.5774965286254883,10000.0,17684.651047706604,19920.133969783783,17684.651047706604,2232.1497917175293,1.3694918155670166,0.0 -38000,1.1685024,2.7343693,,,,,,,,,,,,,, -38100,0.938628,5.381082,,,,,,,,,,,,,, -38200,1.0099195,3.5367239,,,,,,,,,,,,,, -38300,1.1535782,2.799279,,,,,,,,,,,,,, -38400,0.9413564,3.900082,,,,,,,,,,,,,, -38500,1.2362952,2.710474,,,,,,,,,,,,,, -38600,1.1949755,2.8174057,,,,,,,,,,,,,, -38700,0.8731194,4.0412407,,,,,,,,,,,,,, -38800,1.1526008,2.7308517,,,,,,,,,,,,,, -38831,,,0.5987695455551147,1.698283076286316,0.5547999739646912,1.91600239276886,50000.0,0.4299000203609466,2.600237846374512,10000.0,18104.869975566864,20393.270917892456,18104.869975566864,2284.9851546287537,1.4052977561950684,0.0 -38900,1.2709877,2.6602676,,,,,,,,,,,,,, -39000,1.0420883,3.3634977,,,,,,,,,,,,,, -39100,1.3801379,2.931635,,,,,,,,,,,,,, -39200,1.0740453,3.2818635,,,,,,,,,,,,,, -39300,0.9205869,3.8511715,,,,,,,,,,,,,, -39400,0.9259073,3.7547097,,,,,,,,,,,,,, -39500,1.158804,2.8317206,,,,,,,,,,,,,, -39600,1.3437564,2.7485936,,,,,,,,,,,,,, -39700,1.1812599,2.9877343,,,,,,,,,,,,,, -39728,,,0.609570324420929,1.6623051166534424,0.5617200136184692,1.9099894762039185,50000.0,0.4417000114917755,2.5730700492858887,10000.0,18524.83588886261,20877.36520934105,18524.83588886261,2349.030675649643,1.4428019523620603,0.0 -39800,1.2226255,2.6373036,,,,,,,,,,,,,, -39900,1.1806732,2.8146963,,,,,,,,,,,,,, -40000,0.902299,4.7838316,,,,,,,,,,,,,, -40100,1.3211712,2.792802,,,,,,,,,,,,,, -40200,1.1962823,2.6789815,,,,,,,,,,,,,, -40300,1.0906937,2.7374177,,,,,,,,,,,,,, -40400,1.1244042,2.6885128,,,,,,,,,,,,,, -40500,1.3882867,2.7501535,,,,,,,,,,,,,, -40600,0.97342974,3.6642833,,,,,,,,,,,,,, -40622,,,0.6304491758346558,1.5252255201339722,0.5624600052833557,1.8603116273880005,50000.0,0.4377000331878662,2.548011302947998,10000.0,18944.7302005291,21348.91973233223,18944.7302005291,2400.5719459056854,1.515040397644043,0.0 -40700,1.0987028,2.778509,,,,,,,,,,,,,, -40800,1.0407121,2.9756835,,,,,,,,,,,,,, -40900,0.87749314,5.244214,,,,,,,,,,,,,, -41000,1.1256043,3.3199317,,,,,,,,,,,,,, -41100,1.2078189,2.8617344,,,,,,,,,,,,,, -41200,1.1888967,2.6841974,,,,,,,,,,,,,, -41300,1.1889193,3.1773329,,,,,,,,,,,,,, -41400,1.0880549,2.5743885,,,,,,,,,,,,,, -41500,0.92587954,4.568495,,,,,,,,,,,,,, -41519,,,0.6089648008346558,1.633178949356079,0.568839967250824,1.8405061960220337,50000.0,0.4422000348567962,2.518946886062622,10000.0,19365.09273838997,21822.46431851387,19365.09273838997,2453.674109697342,1.5481061935424805,0.0 -41600,0.93742305,4.7432528,,,,,,,,,,,,,, -41700,0.92009807,5.3330655,,,,,,,,,,,,,, -41800,1.2165452,2.5960388,,,,,,,,,,,,,, -41900,0.90147924,4.834522,,,,,,,,,,,,,, -42000,1.0874445,2.7802203,,,,,,,,,,,,,, -42100,0.95748335,4.402976,,,,,,,,,,,,,, -42200,1.1773762,2.7524323,,,,,,,,,,,,,, -42300,1.1415262,2.6302063,,,,,,,,,,,,,, -42400,1.1165067,3.327847,,,,,,,,,,,,,, -42423,,,0.6158398389816284,1.5993508100509644,0.5668599605560303,1.833824157714844,50000.0,0.4497000277042389,2.5079567432403564,10000.0,19785.181124925613,22294.79103422165,19785.181124925613,2505.828282117844,1.584458351135254,0.0 -42500,1.0537744,2.8185756,,,,,,,,,,,,,, -42600,1.0128584,4.0112085,,,,,,,,,,,,,, -42700,1.0921096,2.6881008,,,,,,,,,,,,,, -42800,1.2526143,2.6854339,,,,,,,,,,,,,, -42900,1.3335773,2.7587144,,,,,,,,,,,,,, -43000,1.2431086,2.873466,,,,,,,,,,,,,, -43100,1.2112993,2.7360606,,,,,,,,,,,,,, -43200,1.2229105,2.7459688,,,,,,,,,,,,,, -43300,1.270217,3.0015028,,,,,,,,,,,,,, -43326,,,0.6241992115974426,1.6059962511062622,0.5625,1.9052478075027464,50000.0,0.4489000141620636,2.537725687026977,10000.0,20205.273720502853,22766.971638917923,20205.273720502853,2557.823818206787,1.628956317901611,0.0 -43400,1.1044735,2.6837635,,,,,,,,,,,,,, -43500,1.0162871,3.4268415,,,,,,,,,,,,,, -43600,1.1790841,2.8345325,,,,,,,,,,,,,, -43700,1.0042152,3.4479597,,,,,,,,,,,,,, -43800,0.91000324,4.2235804,,,,,,,,,,,,,, -43900,1.0579392,4.727058,,,,,,,,,,,,,, -44000,1.1132,2.6529288,,,,,,,,,,,,,, -44100,1.1646298,2.6329193,,,,,,,,,,,,,, -44200,1.1846414,3.0268402,,,,,,,,,,,,,, -44229,,,0.6158398389816284,1.622150421142578,0.5743799805641174,1.831883668899536,50000.0,0.4519000351428985,2.4976508617401123,10000.0,20625.553758859634,23239.426404237747,20625.553758859634,2609.916860103607,1.6631481647491455,0.0 -44300,1.144713,2.737261,,,,,,,,,,,,,, -44400,1.1200504,2.792587,,,,,,,,,,,,,, -44500,1.3504722,2.7437177,,,,,,,,,,,,,, -44600,1.1596849,3.0447292,,,,,,,,,,,,,, -44700,1.1441033,2.6635237,,,,,,,,,,,,,, -44800,1.079852,3.222365,,,,,,,,,,,,,, -44900,1.0462697,2.903251,,,,,,,,,,,,,, -45000,1.1516206,3.1309218,,,,,,,,,,,,,, -45100,1.1898308,2.616238,,,,,,,,,,,,,, -45133,,,0.6198632717132568,1.6021332740783691,0.5716800093650818,1.8257529735565183,50000.0,0.4542000293731689,2.4946107864379883,10000.0,21044.99322247505,23713.70925736428,21044.99322247505,2664.1930780410767,2.1820499897003174,0.0 -45200,0.962863,5.13145,,,,,,,,,,,,,, -45300,1.2248404,2.6488147,,,,,,,,,,,,,, -45400,1.1599709,2.6466484,,,,,,,,,,,,,, -45500,1.0853652,2.6728594,,,,,,,,,,,,,, -45600,0.99616975,3.4823544,,,,,,,,,,,,,, -45700,1.1093569,2.7579281,,,,,,,,,,,,,, -45800,1.1113678,2.753408,,,,,,,,,,,,,, -45900,1.0312599,3.3823602,,,,,,,,,,,,,, -46000,0.96287453,3.1106062,,,,,,,,,,,,,, -46038,,,0.6359961032867432,1.5632343292236328,0.5751199722290039,1.8546125888824463,50000.0,0.4570000171661377,2.525801658630371,10000.0,21465.308529138565,24186.573442697525,21465.308529138565,2716.657520532608,2.2180991172790527,0.0 -46100,1.1049482,3.341744,,,,,,,,,,,,,, -46200,1.0103083,4.0894737,,,,,,,,,,,,,, -46300,1.0994298,2.744278,,,,,,,,,,,,,, -46400,0.9658959,3.9749336,,,,,,,,,,,,,, -46500,1.0435332,3.2144215,,,,,,,,,,,,,, -46600,1.1802548,2.682899,,,,,,,,,,,,,, -46700,1.1815846,2.6454573,,,,,,,,,,,,,, -46800,1.2748785,2.7217782,,,,,,,,,,,,,, -46900,1.1521957,4.937114,,,,,,,,,,,,,, -46941,,,0.6237695217132568,1.593305587768555,0.583840012550354,1.7873661518096924,50000.0,0.4579000174999237,2.460035562515259,10000.0,21885.264724254608,24660.88339877129,21885.264724254608,2770.924865722656,2.2558722496032715,0.0 -47000,1.1844765,2.6556742,,,,,,,,,,,,,, -47100,1.2509977,2.604765,,,,,,,,,,,,,, -47200,0.9069642,4.9066663,,,,,,,,,,,,,, -47300,1.182786,2.7242384,,,,,,,,,,,,,, -47400,1.0561243,3.4977462,,,,,,,,,,,,,, -47500,1.1707263,2.7863252,,,,,,,,,,,,,, -47600,1.3109585,2.6889954,,,,,,,,,,,,,, -47700,1.2542287,2.536963,,,,,,,,,,,,,, -47800,1.0048093,4.4914055,,,,,,,,,,,,,, -47845,,,0.6243554353713989,1.5742669105529783,0.5797199606895447,1.8012193441390991,50000.0,0.4599000215530395,2.456450223922729,10000.0,22305.34062218666,25134.393942832947,22305.34062218666,2824.276191473007,2.2912559509277344,0.0 -47900,1.2600613,2.6188312,,,,,,,,,,,,,, -48000,0.98619163,3.8804078,,,,,,,,,,,,,, -48100,0.9848935,4.412178,,,,,,,,,,,,,, -48200,1.3004899,2.7179265,,,,,,,,,,,,,, -48300,1.0810556,2.8450434,,,,,,,,,,,,,, -48400,1.1044549,2.616222,,,,,,,,,,,,,, -48500,1.2720064,2.7249506,,,,,,,,,,,,,, -48600,1.1709532,2.4411669,,,,,,,,,,,,,, -48700,1.2001551,2.591934,,,,,,,,,,,,,, -48753,,,0.6461523175239563,1.4468914270401,0.5879200100898743,1.7366979122161863,50000.0,0.4649000167846679,2.4156606197357178,10000.0,22725.4294102192,25606.750625133514,22725.4294102192,2876.451560020447,2.335695505142212,0.0 -48800,1.1508139,2.602172,,,,,,,,,,,,,, -48900,1.1732996,2.4761798,,,,,,,,,,,,,, -49000,1.0755126,4.8766885,,,,,,,,,,,,,, -49100,0.9512854,5.3261404,,,,,,,,,,,,,, -49200,1.4544686,2.6451705,,,,,,,,,,,,,, -49300,1.108628,3.064887,,,,,,,,,,,,,, -49400,1.1120446,2.6358087,,,,,,,,,,,,,, -49500,1.0978917,3.1125484,,,,,,,,,,,,,, -49600,1.1636224,2.590482,,,,,,,,,,,,,, -49661,,,0.6366406083106995,1.5108044147491455,0.5891199707984924,1.732394456863403,50000.0,0.4748000204563141,2.404741048812866,10000.0,23145.645364046097,26079.421197891235,23145.645364046097,2928.817025184632,2.37554931640625,0.0 -49700,1.243177,2.5958886,,,,,,,,,,,,,, -49800,1.1106181,5.2469935,,,,,,,,,,,,,, -49900,0.9631252,4.583314,,,,,,,,,,,,,, -50000,1.140199,2.5517197,,,,,,,,,,,,,, -50100,0.9648954,4.10553,,,,,,,,,,,,,, -50200,1.091547,2.8447156,,,,,,,,,,,,,, -50300,0.9978571,4.1900964,,,,,,,,,,,,,, -50400,1.0565536,3.3763745,,,,,,,,,,,,,, -50500,1.11003,2.5897954,,,,,,,,,,,,,, -50559,,,0.637988269329071,1.5000442266464231,0.5877199769020081,1.7394182682037354,50000.0,0.4726000130176544,2.380942344665528,10000.0,23565.832825899124,26552.102893590927,23565.832825899124,2981.225423812866,2.413803577423096,0.0 -50600,0.9900041,5.174776,,,,,,,,,,,,,, -50700,1.0910842,3.3546002,,,,,,,,,,,,,, -50800,1.1756188,2.4713082,,,,,,,,,,,,,, -50900,1.2247837,2.667298,,,,,,,,,,,,,, -51000,1.2195655,3.2629633,,,,,,,,,,,,,, -51100,1.3470459,2.6664777,,,,,,,,,,,,,, -51200,1.3332767,2.5890856,,,,,,,,,,,,,, -51300,1.2218624,2.6032014,,,,,,,,,,,,,, -51400,1.1870493,3.056457,,,,,,,,,,,,,, -51462,,,0.6421484351158142,1.497378706932068,0.5862199664115906,1.781178593635559,50000.0,0.46670001745224,2.434500217437744,10000.0,23985.97656941414,27024.848866462708,23985.97656941414,3033.735938310623,2.457144498825073,0.0 -51500,1.2691127,2.777162,,,,,,,,,,,,,, -51600,1.0596371,4.4767013,,,,,,,,,,,,,, -51700,1.2971017,2.5924003,,,,,,,,,,,,,, -51800,1.0477862,3.6281247,,,,,,,,,,,,,, -51900,1.0423256,3.4843454,,,,,,,,,,,,,, -52000,1.022578,3.0703084,,,,,,,,,,,,,, -52100,0.9439658,3.8329785,,,,,,,,,,,,,, -52200,1.1751195,2.7642803,,,,,,,,,,,,,, -52300,1.1527321,2.9698887,,,,,,,,,,,,,, -52366,,,0.6407226324081421,1.4778684377670288,0.5956599712371826,1.697656512260437,50000.0,0.4722000360488891,2.3861453533172607,10000.0,24406.12998008728,27497.53017377853,24406.12998008728,3086.1749098300934,2.498885154724121,0.0 -52400,1.0979978,2.7816188,,,,,,,,,,,,,, -52500,1.3079219,2.571219,,,,,,,,,,,,,, -52600,1.0362041,3.2254145,,,,,,,,,,,,,, -52700,1.2398823,2.5344214,,,,,,,,,,,,,, -52800,1.1767572,2.6491823,,,,,,,,,,,,,, -52900,1.1097146,3.9961648,,,,,,,,,,,,,, -53000,1.0898614,4.9738073,,,,,,,,,,,,,, -53100,1.2345821,2.6134408,,,,,,,,,,,,,, -53200,1.2302934,2.7260408,,,,,,,,,,,,,, -53267,,,0.6386523246765137,1.511078119277954,0.5876399874687195,1.755619764328003,50000.0,0.4702000319957733,2.407470703125,10000.0,24826.21729016304,27969.93226337433,24826.21729016304,3138.3990700244904,2.5419416427612305,0.0 -53300,0.98228794,5.1253996,,,,,,,,,,,,,, -53400,1.1809603,2.7671402,,,,,,,,,,,,,, -53500,1.0663486,2.9713988,,,,,,,,,,,,,, -53600,1.3468755,2.5176306,,,,,,,,,,,,,, -53700,1.1175761,3.2962775,,,,,,,,,,,,,, -53800,1.0287379,3.4544787,,,,,,,,,,,,,, -53900,0.96640986,4.309713,,,,,,,,,,,,,, -54000,0.9611887,4.316653,,,,,,,,,,,,,, -54100,1.0865937,3.4489293,,,,,,,,,,,,,, -54170,,,0.6471874713897705,1.4455610513687134,0.5920599699020386,1.721648097038269,50000.0,0.4725000262260437,2.3761520385742188,10000.0,25246.371876716614,28441.82798075676,25246.371876716614,3190.0559356212616,2.578366041183472,0.0 -54200,1.3601322,2.604629,,,,,,,,,,,,,, -54300,1.0078264,3.7860625,,,,,,,,,,,,,, -54400,1.282842,2.5715227,,,,,,,,,,,,,, -54500,1.288707,2.6157975,,,,,,,,,,,,,, -54600,1.0364949,3.2696543,,,,,,,,,,,,,, -54700,1.2162663,2.6158872,,,,,,,,,,,,,, -54800,1.0513278,4.3656116,,,,,,,,,,,,,, -54900,1.1867838,2.547098,,,,,,,,,,,,,, -55000,1.2405643,2.6536894,,,,,,,,,,,,,, -55074,,,0.6447460651397705,1.47289776802063,0.5917400121688843,1.7395529747009275,50000.0,0.4713000357151031,2.3855762481689453,10000.0,25666.634063720703,28914.098296403885,25666.634063720703,3241.9795260429382,2.6154439449310303,0.0 -55100,1.1692289,2.5476873,,,,,,,,,,,,,, -55200,1.1706998,2.7573144,,,,,,,,,,,,,, -55300,1.3251265,2.558076,,,,,,,,,,,,,, -55400,1.1561202,2.4439428,,,,,,,,,,,,,, -55500,1.1417439,2.564032,,,,,,,,,,,,,, -55600,1.3508905,2.6432476,,,,,,,,,,,,,, -55700,1.2823226,2.5603294,,,,,,,,,,,,,, -55800,1.3095082,2.6437168,,,,,,,,,,,,,, -55900,1.2949227,2.638536,,,,,,,,,,,,,, -55970,,,0.6394140720367432,1.5037568807601929,0.5911999940872192,1.7273492813110352,50000.0,0.4783000349998474,2.387146711349488,10000.0,26086.77654409409,29387.199322223663,26086.77654409409,3294.8521118164062,2.6539435386657715,0.0 -56000,1.1722289,2.8328817,,,,,,,,,,,,,, -56100,1.3048415,3.1418066,,,,,,,,,,,,,, -56200,1.1559749,2.506458,,,,,,,,,,,,,, -56300,1.2796171,2.5557077,,,,,,,,,,,,,, -56400,1.2372555,2.7250586,,,,,,,,,,,,,, -56500,1.2028837,2.6733534,,,,,,,,,,,,,, -56600,1.0765911,5.2774696,,,,,,,,,,,,,, -56700,1.2010003,2.4449248,,,,,,,,,,,,,, -56800,0.9841873,4.5768304,,,,,,,,,,,,,, -56872,,,0.6453320384025574,1.4583547115325928,0.5988999605178833,1.701833724975586,50000.0,0.4745000302791595,2.365806579589844,10000.0,26506.76273608208,29860.154549121857,26506.76273608208,3347.73197889328,2.6959104537963867,0.0 -56900,0.99374956,4.627356,,,,,,,,,,,,,, -57000,0.9981979,4.576369,,,,,,,,,,,,,, -57100,0.9793474,4.672213,,,,,,,,,,,,,, -57200,0.92942953,4.453847,,,,,,,,,,,,,, -57300,1.1871748,3.390582,,,,,,,,,,,,,, -57400,1.2548147,2.4320953,,,,,,,,,,,,,, -57500,1.2686018,2.7862322,,,,,,,,,,,,,, -57600,1.0294732,3.651391,,,,,,,,,,,,,, -57700,1.2867078,2.5821805,,,,,,,,,,,,,, -57778,,,0.6798242330551147,1.318408727645874,0.6038199663162231,1.671668529510498,50000.0,0.4857000112533569,2.3210160732269287,10000.0,26926.744647026066,30332.42511534691,26926.744647026066,3399.9325094223022,2.73496150970459,0.0 -57800,1.019792,3.2251606,,,,,,,,,,,,,, -57900,1.2509553,2.6372197,,,,,,,,,,,,,, -58000,1.2300485,2.5538714,,,,,,,,,,,,,, -58100,1.1471488,4.113748,,,,,,,,,,,,,, -58200,0.9461318,4.2998037,,,,,,,,,,,,,, -58300,1.1709981,2.531231,,,,,,,,,,,,,, -58400,1.3251023,2.5727565,,,,,,,,,,,,,, -58500,1.1934739,3.7973003,,,,,,,,,,,,,, -58600,1.1878562,2.9497278,,,,,,,,,,,,,, -58680,,,0.6500585675239563,1.4647369384765625,0.5982599854469299,1.699963092803955,50000.0,0.4785000085830688,2.356879711151123,10000.0,27346.96013259888,30803.39335131645,27346.96013259888,3450.595221042633,2.778135299682617,0.0 -58700,1.2331252,2.5305169,,,,,,,,,,,,,, -58800,1.2705355,2.6022506,,,,,,,,,,,,,, -58900,0.9600449,4.0119667,,,,,,,,,,,,,, -59000,1.4650952,2.4797604,,,,,,,,,,,,,, -59100,1.4566737,3.2931929,,,,,,,,,,,,,, -59200,1.2687235,2.51197,,,,,,,,,,,,,, -59300,0.90564466,5.102538,,,,,,,,,,,,,, -59400,0.9951512,5.3069873,,,,,,,,,,,,,, -59500,1.118023,5.126647,,,,,,,,,,,,,, -59582,,,0.6515820026397705,1.4490962028503418,0.6025399565696716,1.689075589179993,50000.0,0.4764000177383423,2.356093168258667,10000.0,27767.20039820671,31276.52426123619,27767.20039820671,3503.4003579616547,2.815491199493408,0.0 -59600,1.1382269,2.4542727,,,,,,,,,,,,,, -59700,0.9297465,4.8940086,,,,,,,,,,,,,, -59800,1.1116986,2.9365807,,,,,,,,,,,,,, -59900,1.2645167,2.4706063,,,,,,,,,,,,,, -60000,1.056127,3.4590063,,,,,,,,,,,,,, -60100,1.2210362,2.579693,,,,,,,,,,,,,, -60200,1.2946377,2.5338311,,,,,,,,,,,,,, -60300,1.0936873,4.046583,,,,,,,,,,,,,, -60400,1.3374661,2.5426912,,,,,,,,,,,,,, -60486,,,0.6818945407867432,1.308190941810608,0.6083399653434753,1.650006651878357,50000.0,0.4871000349521637,2.3113863468170166,10000.0,28187.20550107956,31748.22876477241,28187.20550107956,3555.011140346527,2.856693983078003,0.0 -60500,1.4529138,2.5745144,,,,,,,,,,,,,, -60600,1.1530545,2.9982696,,,,,,,,,,,,,, -60700,1.1543102,2.5701916,,,,,,,,,,,,,, -60800,1.2669061,2.6866663,,,,,,,,,,,,,, -60900,1.1190436,3.686316,,,,,,,,,,,,,, -61000,1.151921,4.4786825,,,,,,,,,,,,,, -61100,1.2157133,2.6635973,,,,,,,,,,,,,, -61200,0.95521766,4.084901,,,,,,,,,,,,,, -61300,1.271279,2.492612,,,,,,,,,,,,,, -61388,,,0.6479101181030273,1.457031011581421,0.6025999784469604,1.6835922002792358,50000.0,0.4802000224590301,2.3440380096435547,10000.0,28607.25713896752,32220.116106033325,28607.25713896752,3606.759976387024,2.89612340927124,0.0 -61400,1.4319098,2.5242546,,,,,,,,,,,,,, -61500,1.375633,2.5078735,,,,,,,,,,,,,, -61600,1.2860432,2.9006376,,,,,,,,,,,,,, -61700,1.1056907,4.0284157,,,,,,,,,,,,,, -61800,1.1865865,2.5762677,,,,,,,,,,,,,, -61900,1.203618,2.5038047,,,,,,,,,,,,,, -62000,1.2258317,2.5854475,,,,,,,,,,,,,, -62100,1.2102922,2.7951765,,,,,,,,,,,,,, -62200,1.3172604,2.5209332,,,,,,,,,,,,,, -62289,,,0.6520702838897705,1.4345024824142456,0.6013000011444092,1.6837353706359863,50000.0,0.4809000194072723,2.35693907737732,10000.0,29027.53623533249,32690.442022562027,29027.53623533249,3656.720737695694,2.934617042541504,0.0 -62300,1.0589728,5.095432,,,,,,,,,,,,,, -62400,1.098274,2.6041822,,,,,,,,,,,,,, -62500,1.3014414,3.2449975,,,,,,,,,,,,,, -62600,1.0624971,2.6683524,,,,,,,,,,,,,, -62700,1.077305,4.1449423,,,,,,,,,,,,,, -62800,1.1168642,4.8815002,,,,,,,,,,,,,, -62900,1.2052968,2.4517303,,,,,,,,,,,,,, -63000,1.3345386,2.5006123,,,,,,,,,,,,,, -63100,1.3498138,2.4777954,,,,,,,,,,,,,, -63194,,,0.6779687404632568,1.3029876947402954,0.6114599704742432,1.6231791973114014,50000.0,0.4918000102043152,2.2952914237976074,10000.0,29447.884727954865,33163.27663230896,29447.884727954865,3709.118235349655,2.974860906600952,0.0 -63200,1.096317,4.3975167,,,,,,,,,,,,,, -63300,1.1167971,3.005075,,,,,,,,,,,,,, -63400,1.0578792,5.0713797,,,,,,,,,,,,,, -63500,1.1336198,3.648084,,,,,,,,,,,,,, -63600,0.9680802,4.4546685,,,,,,,,,,,,,, -63700,1.2477535,2.4264083,,,,,,,,,,,,,, -63800,0.96829504,4.448554,,,,,,,,,,,,,, -63900,1.1674781,4.62156,,,,,,,,,,,,,, -64000,1.3147874,2.4964724,,,,,,,,,,,,,, -64090,,,0.6555468440055847,1.3943846225738523,0.6055799722671509,1.6383819580078125,50000.0,0.4866000115871429,2.292811155319214,10000.0,29868.319049596783,33636.6328959465,29868.319049596783,3761.9516339302063,3.0155563354492188,0.0 -64100,1.2451307,2.8004267,,,,,,,,,,,,,, -64200,1.1161962,5.0459127,,,,,,,,,,,,,, -64300,1.1235532,2.3963118,,,,,,,,,,,,,, -64400,0.9831942,4.884793,,,,,,,,,,,,,, -64500,1.1746944,2.544019,,,,,,,,,,,,,, -64600,1.2161547,2.904422,,,,,,,,,,,,,, -64700,1.2835698,2.57195,,,,,,,,,,,,,, -64800,1.1450241,2.9389167,,,,,,,,,,,,,, -64900,1.1986235,3.553226,,,,,,,,,,,,,, -64992,,,0.6604882478713989,1.4079033136367798,0.6082199811935425,1.6526474952697754,50000.0,0.4844000339508056,2.3205599784851074,10000.0,30288.43496155739,34108.627766132355,30288.43496155739,3813.739384889602,3.059473991394043,0.0 -65000,1.2096443,4.7563467,,,,,,,,,,,,,, -65100,1.0329965,3.5957313,,,,,,,,,,,,,, -65200,1.2653129,2.5949166,,,,,,,,,,,,,, -65300,1.1240672,5.1365666,,,,,,,,,,,,,, -65400,1.177009,2.533421,,,,,,,,,,,,,, -65500,1.2446206,2.6063213,,,,,,,,,,,,,, -65600,1.2049901,2.848548,,,,,,,,,,,,,, -65700,1.3704773,2.5105405,,,,,,,,,,,,,, -65800,1.5457156,2.513243,,,,,,,,,,,,,, -65896,,,0.6743749976158142,1.3156598806381226,0.6094799637794495,1.6173152923583984,50000.0,0.4891000092029571,2.299282550811768,10000.0,30708.749766349792,34583.26203036308,30708.749766349792,3867.965404987335,3.1049630641937256,0.0 -65900,1.12344,3.1859763,,,,,,,,,,,,,, -66000,1.2639765,2.3962572,,,,,,,,,,,,,, -66100,1.0338273,4.5174165,,,,,,,,,,,,,, -66200,1.253046,2.8497233,,,,,,,,,,,,,, -66300,1.3945811,2.519852,,,,,,,,,,,,,, -66400,1.0914178,2.9247022,,,,,,,,,,,,,, -66500,1.2793531,2.392457,,,,,,,,,,,,,, -66600,1.1326174,2.569742,,,,,,,,,,,,,, -66700,1.1093001,5.0374813,,,,,,,,,,,,,, -66799,,,0.6649413704872131,1.3905673027038574,0.613599956035614,1.6307467222213743,50000.0,0.4841000139713287,2.2967209815979004,10000.0,31128.717413187027,35056.56464314461,31128.717413187027,3921.2101068496704,3.146657943725586,0.0 -66800,1.1988627,3.053912,,,,,,,,,,,,,, -66900,1.2061535,4.1509733,,,,,,,,,,,,,, -67000,0.9826975,4.834099,,,,,,,,,,,,,, -67100,1.2626214,2.5272923,,,,,,,,,,,,,, -67200,1.0910181,3.9401271,,,,,,,,,,,,,, -67300,1.3325577,4.0485067,,,,,,,,,,,,,, -67400,1.2334913,2.9085844,,,,,,,,,,,,,, -67500,1.2996252,2.5171149,,,,,,,,,,,,,, -67600,1.2546893,2.6044028,,,,,,,,,,,,,, -67700,,,0.666308581829071,1.380321025848389,0.6187199950218201,1.6084468364715576,50000.0,0.4953000247478485,2.2668533325195312,10000.0,31548.65906047821,35530.14267539978,31548.65906047821,3974.759640932083,3.186652421951294,0.0 -67700,1.2732835,2.4987805,,,,,,,,,,,,,, -67800,1.2998642,2.450993,,,,,,,,,,,,,, -67900,1.1510891,5.170457,,,,,,,,,,,,,, -68000,1.1500791,3.211639,,,,,,,,,,,,,, -68100,1.1162095,3.1593366,,,,,,,,,,,,,, -68200,1.1477953,2.9906387,,,,,,,,,,,,,, -68300,1.1698415,2.8027744,,,,,,,,,,,,,, -68400,1.1075535,3.6053991,,,,,,,,,,,,,, -68500,1.3898569,2.4849553,,,,,,,,,,,,,, -68600,1.4371914,2.4579263,,,,,,,,,,,,,, -68601,,,0.6759960651397705,1.306284785270691,0.6125800013542175,1.6162729263305664,50000.0,0.491100013256073,2.2899532318115234,10000.0,31968.73847270012,36000.59927082062,31968.73847270012,4025.041212797165,3.2343225479125977,0.0 -68700,1.2645428,2.4332628,,,,,,,,,,,,,, -68800,1.0867491,4.9622083,,,,,,,,,,,,,, -68900,1.4362596,2.5762625,,,,,,,,,,,,,, -69000,1.2605625,2.3918462,,,,,,,,,,,,,, -69100,1.5729225,2.4999852,,,,,,,,,,,,,, -69200,1.3269477,2.3998988,,,,,,,,,,,,,, -69300,1.244372,2.4880774,,,,,,,,,,,,,, -69400,1.3824911,2.4310768,,,,,,,,,,,,,, -69500,1.2918837,2.5873876,,,,,,,,,,,,,, -69503,,,0.6673437356948853,1.392255187034607,0.6157999634742737,1.6166682243347168,50000.0,0.495600014925003,2.2810490131378174,10000.0,32388.73214316368,36475.47456264496,32388.73214316368,4079.8373177051535,3.27199387550354,0.0 -69600,1.346257,2.3961449,,,,,,,,,,,,,, -69700,1.2170148,2.4523249,,,,,,,,,,,,,, -69800,1.2206186,3.6388617,,,,,,,,,,,,,, -69900,1.4002696,2.5930436,,,,,,,,,,,,,, -70000,1.2816784,2.4324565,,,,,,,,,,,,,, -70100,1.2371716,2.5701513,,,,,,,,,,,,,, -70200,1.4036533,2.9580042,,,,,,,,,,,,,, -70300,1.1233549,3.7096472,,,,,,,,,,,,,, -70398,,,0.6661718487739563,1.4046850204467771,0.6162399649620056,1.6451646089553833,50000.0,0.4962000250816345,2.304437160491944,10000.0,32808.83074641228,36949.61338496208,32808.83074641228,4133.789508104324,3.3121135234832764,0.0 -70400,1.2191657,2.6513963,,,,,,,,,,,,,, -70500,1.1854275,2.4604497,,,,,,,,,,,,,, -70600,1.4270524,2.5163002,,,,,,,,,,,,,, -70700,1.2485437,2.536599,,,,,,,,,,,,,, -70800,1.2198372,2.3693225,,,,,,,,,,,,,, -70900,1.0919261,3.9641476,,,,,,,,,,,,,, -71000,1.0964704,5.1223187,,,,,,,,,,,,,, -71100,1.2391293,2.469298,,,,,,,,,,,,,, -71200,1.3477107,2.6689205,,,,,,,,,,,,,, -71300,,,0.671875,1.3597897291183472,0.6108399629592896,1.6405349969863892,50000.0,0.4915000200271606,2.279672861099243,10000.0,33228.976645469666,37422.65373325348,33228.976645469666,4186.583588838577,3.3608882427215576,0.0 -71300,1.3131051,2.6880598,,,,,,,,,,,,,, -71400,1.0990007,3.6836746,,,,,,,,,,,,,, -71500,1.1922177,2.8827367,,,,,,,,,,,,,, -71600,1.1373901,3.6771278,,,,,,,,,,,,,, -71700,1.2957847,2.3039055,,,,,,,,,,,,,, -71800,0.9932348,4.2709823,,,,,,,,,,,,,, -71900,1.4644312,2.4206219,,,,,,,,,,,,,, -72000,1.2771955,2.8982892,,,,,,,,,,,,,, -72100,1.0194398,4.8928604,,,,,,,,,,,,,, -72200,1.3160281,2.495253,,,,,,,,,,,,,, -72203,,,0.6790234446525574,1.29867684841156,0.6268599629402161,1.5448366403579712,50000.0,0.5045000314712524,2.2146520614624023,10000.0,33649.24731469154,37895.018824100494,33649.24731469154,4238.585715770721,3.405482292175293,0.0 -72300,1.1587089,3.213356,,,,,,,,,,,,,, -72400,0.9804038,4.2710996,,,,,,,,,,,,,, -72500,1.2591256,2.4483526,,,,,,,,,,,,,, -72600,1.1434832,3.4511888,,,,,,,,,,,,,, -72700,1.1885142,2.599995,,,,,,,,,,,,,, -72800,0.99756664,4.545849,,,,,,,,,,,,,, -72900,1.0922025,5.071219,,,,,,,,,,,,,, -73000,1.1081297,3.9729838,,,,,,,,,,,,,, -73100,1.2238387,2.353355,,,,,,,,,,,,,, -73109,,,0.6642382740974426,1.3958392143249512,0.6126999855041504,1.638340711593628,50000.0,0.4961000382900238,2.272181510925293,10000.0,34069.50347137451,38368.23072266579,34069.50347137451,4291.450464725494,3.447964191436768,0.0 -73200,1.0433081,5.0519238,,,,,,,,,,,,,, -73300,1.3853704,2.6689312,,,,,,,,,,,,,, -73400,1.2250627,2.6241043,,,,,,,,,,,,,, -73500,1.2749493,2.559022,,,,,,,,,,,,,, -73600,1.258745,2.5283327,,,,,,,,,,,,,, -73700,1.2006563,2.7393868,,,,,,,,,,,,,, -73800,1.2838554,2.4260478,,,,,,,,,,,,,, -73900,1.0271655,4.858547,,,,,,,,,,,,,, -74000,1.2309964,2.3654528,,,,,,,,,,,,,, -74011,,,0.6847460865974426,1.295527458190918,0.6244399547576904,1.5823724269866943,50000.0,0.5003000497817993,2.2567830085754395,10000.0,34489.668355703354,38841.82125282288,34489.668355703354,4344.784727811813,3.4918909072875977,0.0 -74100,1.181017,5.1252785,,,,,,,,,,,,,, -74200,1.4305762,2.5301623,,,,,,,,,,,,,, -74300,1.3045816,2.425611,,,,,,,,,,,,,, -74400,1.129352,3.8974624,,,,,,,,,,,,,, -74500,1.268332,2.4189515,,,,,,,,,,,,,, -74600,1.2187141,3.1071024,,,,,,,,,,,,,, -74700,1.3106899,5.033443,,,,,,,,,,,,,, -74800,1.2442858,2.4825828,,,,,,,,,,,,,, -74900,1.0987157,4.176252,,,,,,,,,,,,,, -74915,,,0.6771484017372131,1.3254854679107666,0.6175999641418457,1.5920697450637815,50000.0,0.499500036239624,2.259209156036377,10000.0,34909.960586071014,39313.50108551979,34909.960586071014,4396.087966918945,3.528670072555542,0.0 -75000,1.1898857,4.6816697,,,,,,,,,,,,,, -75100,1.2309498,2.468992,,,,,,,,,,,,,, -75200,1.3336568,2.3476691,,,,,,,,,,,,,, -75300,1.5117251,2.3812459,,,,,,,,,,,,,, -75400,1.5029595,2.4489684,,,,,,,,,,,,,, -75500,1.3832401,2.5698268,,,,,,,,,,,,,, -75600,1.2725312,3.14284,,,,,,,,,,,,,, -75700,1.0773933,3.7094603,,,,,,,,,,,,,, -75800,1.4414423,2.7168489,,,,,,,,,,,,,, -75819,,,0.6713671684265137,1.3427538871765137,0.6200799942016602,1.5841830968856812,50000.0,0.4980000257492065,2.25400972366333,10000.0,35330.13725447655,39784.29939389229,35330.13725447655,4446.619699478149,3.570764780044556,0.0 -75900,1.2829481,2.3520145,,,,,,,,,,,,,, -76000,1.2910056,2.7959433,,,,,,,,,,,,,, -76100,1.3101853,2.2484088,,,,,,,,,,,,,, -76200,1.3302927,2.435497,,,,,,,,,,,,,, -76300,1.3258326,2.4825468,,,,,,,,,,,,,, -76400,1.0398631,4.583538,,,,,,,,,,,,,, -76500,1.0518892,4.01755,,,,,,,,,,,,,, -76600,1.1823933,4.925175,,,,,,,,,,,,,, -76700,1.2529134,2.5615497,,,,,,,,,,,,,, -76720,,,0.6835741996765137,1.2883057594299316,0.6281999945640564,1.5592522621154783,50000.0,0.5045000314712524,2.233535051345825,10000.0,35750.3324637413,40256.8454182148,35750.3324637413,4498.87525844574,3.6172690391540527,0.0 -76800,1.1113405,3.9408598,,,,,,,,,,,,,, -76900,1.1537569,4.5877695,,,,,,,,,,,,,, -77000,1.4402119,2.357883,,,,,,,,,,,,,, -77100,1.3315258,2.2893746,,,,,,,,,,,,,, -77200,1.2611964,2.3772588,,,,,,,,,,,,,, -77300,1.2800549,2.4436738,,,,,,,,,,,,,, -77400,1.188305,2.6566687,,,,,,,,,,,,,, -77500,1.3823124,2.332466,,,,,,,,,,,,,, -77600,1.1924548,2.8385715,,,,,,,,,,,,,, -77628,,,0.7032226324081421,1.1986013650894165,0.6326799988746643,1.526818037033081,50000.0,0.5058000087738037,2.190584659576416,10000.0,36170.52016687393,40732.36606168747,36170.52016687393,4554.121444702148,3.6560237407684326,0.0 -77700,1.3494016,2.559764,,,,,,,,,,,,,, -77800,1.0544297,4.8891745,,,,,,,,,,,,,, -77900,1.2660707,2.9429326,,,,,,,,,,,,,, -78000,1.3731138,2.3621035,,,,,,,,,,,,,, -78100,1.3403715,2.4308608,,,,,,,,,,,,,, -78200,1.3686244,2.4600942,,,,,,,,,,,,,, -78300,1.2151299,4.041031,,,,,,,,,,,,,, -78400,1.5090624,2.350827,,,,,,,,,,,,,, -78500,1.3343965,2.3946705,,,,,,,,,,,,,, -78530,,,0.6783398389816284,1.3336890935897827,0.6248399615287781,1.5802377462387085,50000.0,0.5034000277519226,2.228025197982788,10000.0,36590.65875458717,41204.64127254486,36590.65875458717,4606.165371894836,3.701697587966919,0.0 -78600,1.1241463,3.0454247,,,,,,,,,,,,,, -78700,1.3819932,2.4006915,,,,,,,,,,,,,, -78800,1.0293512,4.729495,,,,,,,,,,,,,, -78900,1.0089321,5.000269,,,,,,,,,,,,,, -79000,1.0210443,3.9528258,,,,,,,,,,,,,, -79100,1.3658649,2.3902855,,,,,,,,,,,,,, -79200,1.0491257,5.0069838,,,,,,,,,,,,,, -79300,1.7094076,2.4738047,,,,,,,,,,,,,, -79400,1.4032753,2.341528,,,,,,,,,,,,,, -79433,,,0.6779101490974426,1.3415353298187256,0.6250999569892883,1.593327283859253,50000.0,0.5078999996185303,2.239119291305542,10000.0,37011.00729799271,41676.02999925613,37011.00729799271,4657.113859415054,3.745133638381958,0.0 -79500,1.2710267,2.3230906,,,,,,,,,,,,,, -79600,1.0325037,4.2774997,,,,,,,,,,,,,, -79700,1.2825149,4.897076,,,,,,,,,,,,,, -79800,1.3825079,2.5018911,,,,,,,,,,,,,, -79900,1.3842525,2.4257298,,,,,,,,,,,,,, -80000,1.513257,2.2743614,,,,,,,,,,,,,, -80100,1.2876314,2.6042445,,,,,,,,,,,,,, -80200,1.1231045,3.3349972,,,,,,,,,,,,,, -80300,1.4034637,2.3840995,,,,,,,,,,,,,, -80333,,,0.712890625,1.150959014892578,0.6348599791526794,1.51292884349823,50000.0,0.5133000016212463,2.168611526489258,10000.0,37431.40174984932,42151.15966916084,37431.40174984932,4711.764025211334,3.782468557357788,0.0 -80400,1.1779888,4.0107565,,,,,,,,,,,,,, -80500,1.3177826,4.7957177,,,,,,,,,,,,,, -80600,1.6599455,2.459232,,,,,,,,,,,,,, -80700,1.3712145,2.4363027,,,,,,,,,,,,,, -80800,1.2215575,3.7676306,,,,,,,,,,,,,, -80900,1.4034344,2.3056858,,,,,,,,,,,,,, -81000,1.3310696,2.801651,,,,,,,,,,,,,, -81100,1.4217057,2.1477165,,,,,,,,,,,,,, -81200,1.231137,4.815932,,,,,,,,,,,,,, -81236,,,0.6807616949081421,1.2854599952697754,0.6318599581718445,1.5201311111450195,50000.0,0.4996000230312347,2.2001304626464844,10000.0,37851.49418973923,42622.74819707871,37851.49418973923,4763.166525602341,3.8283190727233887,0.0 -81300,1.4381582,2.3926213,,,,,,,,,,,,,, -81400,1.5863651,2.3633723,,,,,,,,,,,,,, -81500,1.2768297,2.628611,,,,,,,,,,,,,, -81600,1.2386489,4.4319434,,,,,,,,,,,,,, -81700,1.18039,4.9149046,,,,,,,,,,,,,, -81800,1.038,4.9220862,,,,,,,,,,,,,, -81900,1.3752767,2.4611418,,,,,,,,,,,,,, -82000,1.3200337,4.321279,,,,,,,,,,,,,, -82100,1.2955983,2.4391253,,,,,,,,,,,,,, -82140,,,0.6885156035423279,1.2505003213882446,0.6342399716377258,1.5106966495513916,50000.0,0.5154000520706177,2.1858279705047607,10000.0,38271.61984539032,43095.682314157486,38271.61984539032,4815.883181333542,3.870955228805542,0.0 -82200,1.2797397,5.016423,,,,,,,,,,,,,, -82300,1.2905725,2.4607513,,,,,,,,,,,,,, -82400,1.3735826,2.2937918,,,,,,,,,,,,,, -82500,1.2507509,3.3169243,,,,,,,,,,,,,, -82600,1.0404552,4.5426383,,,,,,,,,,,,,, -82700,1.2926645,2.3002512,,,,,,,,,,,,,, -82800,1.6005989,2.48471,,,,,,,,,,,,,, -82900,1.263858,2.3271174,,,,,,,,,,,,,, -83000,1.1840807,3.6160178,,,,,,,,,,,,,, -83040,,,0.7089062333106995,1.1777950525283811,0.6371200084686279,1.510980486869812,50000.0,0.5186000466346741,2.158198595046997,10000.0,38691.78195428848,43568.31066441536,38691.78195428848,4868.257662296295,3.9152021408081055,0.0 -83100,1.3908756,2.3488266,,,,,,,,,,,,,, -83200,1.4284714,2.5746746,,,,,,,,,,,,,, -83300,1.1234003,4.7492795,,,,,,,,,,,,,, -83400,1.3530482,4.925543,,,,,,,,,,,,,, -83500,1.3097801,2.2621672,,,,,,,,,,,,,, -83600,1.3749633,3.3015592,,,,,,,,,,,,,, -83700,1.1278995,4.3822365,,,,,,,,,,,,,, -83800,1.3599948,2.6519423,,,,,,,,,,,,,, -83900,1.4991256,2.3511498,,,,,,,,,,,,,, -83940,,,0.6898242235183716,1.2514363527297974,0.6396399736404419,1.4833993911743164,50000.0,0.5134000182151794,2.157160758972168,10000.0,39111.83612918854,44039.46033358574,39111.83612918854,4919.264448881149,3.956256151199341,0.0 -84000,1.303438,2.7968864,,,,,,,,,,,,,, -84100,1.4423027,2.3019662,,,,,,,,,,,,,, -84200,1.3524553,2.362874,,,,,,,,,,,,,, -84300,1.1745495,3.7020586,,,,,,,,,,,,,, -84400,1.5991731,2.391603,,,,,,,,,,,,,, -84500,1.3276222,2.5687559,,,,,,,,,,,,,, -84600,1.3442959,2.67843,,,,,,,,,,,,,, -84700,1.4674762,2.3820605,,,,,,,,,,,,,, -84800,1.28315,2.327803,,,,,,,,,,,,,, -84841,,,0.6862499713897705,1.308943748474121,0.6337400078773499,1.5579991340637207,50000.0,0.5060000419616699,2.2265491485595703,10000.0,39531.96011543274,44512.09685301781,39531.96011543274,4971.682218551636,4.003744602203369,0.0 -84900,1.3209364,4.8557253,,,,,,,,,,,,,, -85000,1.3693029,2.7077713,,,,,,,,,,,,,, -85100,1.3499427,2.394483,,,,,,,,,,,,,, -85200,1.4585587,2.3179173,,,,,,,,,,,,,, -85300,1.111306,4.9037538,,,,,,,,,,,,,, -85400,1.379055,2.3653827,,,,,,,,,,,,,, -85500,1.4015396,2.5239894,,,,,,,,,,,,,, -85600,1.3641413,2.4060347,,,,,,,,,,,,,, -85700,1.3255694,2.2856355,,,,,,,,,,,,,, -85741,,,0.7101757526397705,1.163679122924805,0.6434400081634521,1.4765359163284302,50000.0,0.5157000422477722,2.157416820526123,10000.0,39952.28334736824,44983.476644039154,39952.28334736824,5022.644863128662,4.050050735473633,0.0 -85800,1.1489358,2.6690054,,,,,,,,,,,,,, -85900,1.2941417,2.417416,,,,,,,,,,,,,, -86000,1.4420577,2.4580626,,,,,,,,,,,,,, -86100,1.3289961,2.6835372,,,,,,,,,,,,,, -86200,1.3850938,4.8285127,,,,,,,,,,,,,, -86300,1.7201856,4.43573,,,,,,,,,,,,,, -86400,1.2080538,4.779981,,,,,,,,,,,,,, -86500,1.3483047,2.2167227,,,,,,,,,,,,,, -86600,1.2456049,3.4663742,,,,,,,,,,,,,, -86646,,,0.6906445026397705,1.2567678689956665,0.6412999629974365,1.5003139972686768,50000.0,0.5220000147819519,2.1490726470947266,10000.0,40372.55580735207,45456.74989652634,40372.55580735207,5075.555291175842,4.092345476150513,0.0 -86700,1.3779963,2.2955518,,,,,,,,,,,,,, -86800,1.188184,2.909912,,,,,,,,,,,,,, -86900,1.2844748,3.281737,,,,,,,,,,,,,, -87000,1.5319813,2.502334,,,,,,,,,,,,,, -87100,1.3523206,2.3101616,,,,,,,,,,,,,, -87200,1.3794641,2.2426515,,,,,,,,,,,,,, -87300,1.2448225,2.7043219,,,,,,,,,,,,,, -87400,1.3227248,2.3128507,,,,,,,,,,,,,, -87500,1.3993518,2.2842536,,,,,,,,,,,,,, -87543,,,0.6939257383346558,1.2720842361450195,0.6372399926185608,1.5326356887817385,50000.0,0.5128000378608704,2.205008029937744,10000.0,40792.581500291824,45931.0048494339,40792.581500291824,5129.692433595657,4.137173175811768,0.0 -87600,1.3490795,3.7669508,,,,,,,,,,,,,, -87700,1.2012812,3.7205884,,,,,,,,,,,,,, -87800,1.2582419,4.081649,,,,,,,,,,,,,, -87900,1.3344902,2.2971423,,,,,,,,,,,,,, -88000,1.3624152,4.0432296,,,,,,,,,,,,,, -88100,1.420815,2.381884,,,,,,,,,,,,,, -88200,1.2031457,4.5429044,,,,,,,,,,,,,, -88300,1.4765048,2.3345728,,,,,,,,,,,,,, -88400,1.1836745,4.706715,,,,,,,,,,,,,, -88444,,,0.7095507383346558,1.1825847625732422,0.6438800096511841,1.4855278730392456,50000.0,0.5198000073432922,2.147436141967773,10000.0,41212.59487867355,46403.07826638222,41212.59487867355,5181.660806417465,4.181311845779419,0.0 -88500,1.1682171,3.5027578,,,,,,,,,,,,,, -88600,1.2815939,2.2848024,,,,,,,,,,,,,, -88700,1.4197832,2.289482,,,,,,,,,,,,,, -88800,1.5558496,2.5309136,,,,,,,,,,,,,, -88900,1.1916397,2.7104268,,,,,,,,,,,,,, -89000,1.4082811,2.2089705,,,,,,,,,,,,,, -89100,1.4765106,2.4745045,,,,,,,,,,,,,, -89200,1.3456057,2.6294782,,,,,,,,,,,,,, -89300,1.2763673,2.75554,,,,,,,,,,,,,, -89347,,,0.697949230670929,1.2049740552902222,0.646120011806488,1.454306960105896,50000.0,0.5243000388145447,2.1224191188812256,10000.0,41632.62812304497,46873.910395145416,41632.62812304497,5232.368206501007,4.225497245788574,0.0 -89400,1.2837529,4.3208733,,,,,,,,,,,,,, -89500,1.0576904,3.5724957,,,,,,,,,,,,,, -89600,1.3932611,2.3464627,,,,,,,,,,,,,, -89700,1.2721591,2.696485,,,,,,,,,,,,,, -89800,1.4610273,4.715677,,,,,,,,,,,,,, -89900,1.3366601,2.2277641,,,,,,,,,,,,,, -90000,1.4475988,2.3279488,,,,,,,,,,,,,, -90100,1.1429038,4.8372746,,,,,,,,,,,,,, -90200,1.2376716,3.2033186,,,,,,,,,,,,,, -90248,,,0.7050976157188416,1.1831070184707642,0.6489799618721008,1.4485565423965454,50000.0,0.5238000154495239,2.1029837131500244,10000.0,42052.653123140335,47345.378999471664,42052.653123140335,5283.711639404297,4.278661251068115,0.0 -90300,1.31286,3.0811212,,,,,,,,,,,,,, -90400,1.5548875,2.321167,,,,,,,,,,,,,, -90500,1.5166633,2.36763,,,,,,,,,,,,,, -90600,1.1814011,3.8292153,,,,,,,,,,,,,, -90700,1.3649008,2.1970592,,,,,,,,,,,,,, -90800,1.4163591,2.8642297,,,,,,,,,,,,,, -90900,1.3875407,2.5134544,,,,,,,,,,,,,, -91000,1.3481215,3.269525,,,,,,,,,,,,,, -91100,1.5115715,2.377652,,,,,,,,,,,,,, -91150,,,0.7136132717132568,1.1587448120117188,0.645859956741333,1.4699233770370483,50000.0,0.523300051689148,2.1205849647521973,10000.0,42472.69245290756,47817.961052656174,42472.69245290756,5336.162372112274,4.322976589202881,0.0 -91200,1.534936,2.1808333,,,,,,,,,,,,,, -91300,1.4253623,2.744159,,,,,,,,,,,,,, -91400,1.4632427,2.1956842,,,,,,,,,,,,,, -91500,1.4551672,4.6346827,,,,,,,,,,,,,, -91600,1.5036765,2.2833457,,,,,,,,,,,,,, -91700,1.5092369,2.2949462,,,,,,,,,,,,,, -91800,1.1040684,4.5926523,,,,,,,,,,,,,, -91900,1.2886838,2.5960946,,,,,,,,,,,,,, -92000,1.4942001,2.472852,,,,,,,,,,,,,, -92050,,,0.7071288824081421,1.196994662284851,0.6503399610519409,1.4454437494277954,50000.0,0.5268000364303589,2.090376138687134,10000.0,42893.01210975647,48290.78095889092,42893.01210975647,5388.570991754532,4.366878986358643,0.0 -92100,1.3055528,4.629155,,,,,,,,,,,,,, -92200,1.4690012,2.3599908,,,,,,,,,,,,,, -92300,1.1756036,4.7610807,,,,,,,,,,,,,, -92400,1.1779555,4.1495814,,,,,,,,,,,,,, -92500,1.3153346,3.728034,,,,,,,,,,,,,, -92600,1.4411316,2.2997942,,,,,,,,,,,,,, -92700,1.1025366,4.4007683,,,,,,,,,,,,,, -92800,1.4556278,2.4014595,,,,,,,,,,,,,, -92900,1.2151872,3.580529,,,,,,,,,,,,,, -92948,,,0.7015038728713989,1.22853684425354,0.6442599892616272,1.4875376224517822,50000.0,0.5206000208854675,2.143444061279297,10000.0,43313.314470767975,48763.07976198197,43313.314470767975,5440.473873138428,4.412621974945068,0.0 -93000,1.4402263,3.2503767,,,,,,,,,,,,,, -93100,1.4968023,2.2804723,,,,,,,,,,,,,, -93200,1.2786093,3.1450822,,,,,,,,,,,,,, -93300,1.3397145,2.3314834,,,,,,,,,,,,,, -93400,1.2444803,3.1100276,,,,,,,,,,,,,, -93500,1.3285248,3.4781427,,,,,,,,,,,,,, -93600,1.3190056,2.6632986,,,,,,,,,,,,,, -93700,1.4026414,2.3871608,,,,,,,,,,,,,, -93800,1.3813208,2.3870447,,,,,,,,,,,,,, -93845,,,0.7256445288658142,1.0961700677871704,0.6585800051689148,1.40444016456604,50000.0,0.5323000550270081,2.060444593429565,10000.0,43733.2725212574,49234.66075634956,43733.2725212574,5492.004205703735,4.4586100578308105,0.0 -93900,1.2108086,3.879017,,,,,,,,,,,,,, -94000,1.4465905,2.224271,,,,,,,,,,,,,, -94100,1.1613541,4.730617,,,,,,,,,,,,,, -94200,1.2767556,4.435846,,,,,,,,,,,,,, -94300,1.2462345,4.5105753,,,,,,,,,,,,,, -94400,1.3942422,2.5542004,,,,,,,,,,,,,, -94500,1.2980027,2.7443285,,,,,,,,,,,,,, -94600,1.4007876,2.2647722,,,,,,,,,,,,,, -94700,1.3460284,2.6324868,,,,,,,,,,,,,, -94745,,,0.7081640362739563,1.187019944190979,0.6531999707221985,1.43330717086792,50000.0,0.5218999981880188,2.109593629837036,10000.0,44153.533059597015,49705.9593732357,44153.533059597015,5542.949748277664,4.503470420837402,0.0 -94800,1.4073662,2.589421,,,,,,,,,,,,,, -94900,1.4330608,4.768134,,,,,,,,,,,,,, -95000,1.34486,3.9020123,,,,,,,,,,,,,, -95100,1.4461927,2.261674,,,,,,,,,,,,,, -95200,1.4904964,2.3397536,,,,,,,,,,,,,, -95300,1.4769996,2.251679,,,,,,,,,,,,,, -95400,1.5296577,4.4381413,,,,,,,,,,,,,, -95500,1.1624862,4.7886257,,,,,,,,,,,,,, -95600,1.3213574,2.3965287,,,,,,,,,,,,,, -95645,,,0.7108983993530273,1.1773436069488523,0.6543799638748169,1.4492274522781372,50000.0,0.5276000499725342,2.1062769889831543,10000.0,44573.71209073067,50177.37110543251,44573.71209073067,5594.092439651489,4.546373128890991,0.0 -95700,1.2704519,3.2078714,,,,,,,,,,,,,, -95800,1.1131271,4.564024,,,,,,,,,,,,,, -95900,1.3955588,2.6049051,,,,,,,,,,,,,, -96000,1.4430066,2.2605712,,,,,,,,,,,,,, -96100,1.1272136,4.7527514,,,,,,,,,,,,,, -96200,1.2294749,2.7127256,,,,,,,,,,,,,, -96300,1.4342084,2.2532876,,,,,,,,,,,,,, -96400,1.3300741,4.8397317,,,,,,,,,,,,,, -96500,1.5023437,2.1818643,,,,,,,,,,,,,, -96550,,,0.7174413800239563,1.1969013214111328,0.6540799736976624,1.4859381914138794,50000.0,0.5288000106811523,2.1328086853027344,10000.0,44993.845116853714,50651.08602452278,44993.845116853714,5647.579032897949,4.593611240386963,0.0 -96600,1.4530458,2.1997514,,,,,,,,,,,,,, -96700,1.3950167,2.160376,,,,,,,,,,,,,, -96800,1.3739363,2.250995,,,,,,,,,,,,,, -96900,1.4518821,2.3279862,,,,,,,,,,,,,, -97000,1.4403174,2.516071,,,,,,,,,,,,,, -97100,1.3406202,4.691406,,,,,,,,,,,,,, -97200,1.2316102,3.0776029,,,,,,,,,,,,,, -97300,1.237155,4.3220854,,,,,,,,,,,,,, -97400,1.3378308,4.78494,,,,,,,,,,,,,, -97448,,,0.7165625095367432,1.188768744468689,0.6569199562072754,1.463334083557129,50000.0,0.534500002861023,2.111708402633667,10000.0,45413.95650100708,51122.57685351372,45413.95650100708,5698.86651301384,4.638423681259155,0.0 -97500,1.4813472,2.5648706,,,,,,,,,,,,,, -97600,1.2828928,2.476362,,,,,,,,,,,,,, -97700,1.2784656,2.6725676,,,,,,,,,,,,,, -97800,1.241582,3.737714,,,,,,,,,,,,,, -97900,1.6617664,2.2203486,,,,,,,,,,,,,, -98000,1.4245042,2.2264583,,,,,,,,,,,,,, -98100,1.4190437,2.4138672,,,,,,,,,,,,,, -98200,1.4963046,2.1853495,,,,,,,,,,,,,, -98300,1.4185423,2.3988392,,,,,,,,,,,,,, -98347,,,0.7182226181030273,1.1286863088607788,0.6610599756240845,1.4054856300354004,50000.0,0.5304000377655029,2.0756752490997314,10000.0,45834.25590658188,51593.67358827591,45834.25590658188,5749.567130565643,4.687403678894043,0.0 -98400,1.2290486,4.7076774,,,,,,,,,,,,,, -98500,1.279146,4.8011265,,,,,,,,,,,,,, -98600,1.2315402,4.729507,,,,,,,,,,,,,, -98700,1.4332666,2.1262949,,,,,,,,,,,,,, -98800,1.5095994,2.1355367,,,,,,,,,,,,,, -98900,1.4768051,2.3680625,,,,,,,,,,,,,, -99000,1.3212994,4.6686945,,,,,,,,,,,,,, -99100,1.3872479,2.2152398,,,,,,,,,,,,,, -99200,1.3211912,4.529902,,,,,,,,,,,,,, -99248,,,0.7224413752555847,1.107817769050598,0.660319983959198,1.3991767168045044,50000.0,0.535800039768219,2.052008390426636,10000.0,46254.56909441948,52066.551471710205,46254.56909441948,5802.036701917648,4.73491358757019,0.0 -99300,1.6145159,2.2856622,,,,,,,,,,,,,, -99400,1.4591018,2.0733962,,,,,,,,,,,,,, -99500,1.2294574,4.635017,,,,,,,,,,,,,, -99600,1.3351189,2.7201996,,,,,,,,,,,,,, -99700,1.1879638,4.0177045,,,,,,,,,,,,,, -99800,1.3908195,3.3028998,,,,,,,,,,,,,, -99900,1.2245063,4.1893053,,,,,,,,,,,,,, -100000,1.47858,2.0559602,,,,,,,,,,,,,, -100100,1.241801,4.7214375,,,,,,,,,,,,,, -100151,,,0.7393554449081421,1.0557329654693604,0.6615599989891052,1.396098256111145,50000.0,0.5397000312805176,2.0368974208831787,10000.0,46674.90282559395,52540.69532442093,46674.90282559395,5855.741189956665,4.792380809783936,0.0 -100200,1.2819285,4.4195547,,,,,,,,,,,,,, -100300,1.3484372,4.141281,,,,,,,,,,,,,, -100400,1.5662589,2.3300092,,,,,,,,,,,,,, -100500,1.5417343,2.1826856,,,,,,,,,,,,,, -100600,1.3540663,3.4015126,,,,,,,,,,,,,, -100700,1.5428998,2.1271644,,,,,,,,,,,,,, -100800,1.5266023,2.1494474,,,,,,,,,,,,,, -100900,1.2690171,4.154403,,,,,,,,,,,,,, -101000,1.7537146,2.0019076,,,,,,,,,,,,,, -101053,,,0.7187304496765137,1.1348941326141355,0.6573399901390076,1.40870201587677,50000.0,0.5364000201225281,2.0542080402374268,10000.0,47095.12478327751,53013.22602438927,47095.12478327751,5907.954854726791,4.839926719665527,0.0 -101100,1.3857422,2.9376028,,,,,,,,,,,,,, -101200,1.4256675,2.5422332,,,,,,,,,,,,,, -101300,1.634068,2.0532167,,,,,,,,,,,,,, -101400,1.3910098,4.768256,,,,,,,,,,,,,, -101500,1.6365929,2.161908,,,,,,,,,,,,,, -101600,1.251744,3.7823098,,,,,,,,,,,,,, -101700,1.4300869,2.4350417,,,,,,,,,,,,,, -101800,1.343007,4.0650034,,,,,,,,,,,,,, -101900,1.6787755,2.3996658,,,,,,,,,,,,,, -101956,,,0.7291796803474426,1.140833616256714,0.6621400117874146,1.432441234588623,50000.0,0.5386000275611877,2.0662755966186523,10000.0,47515.08821105957,53483.97033381462,47515.08821105957,5958.642646789551,4.885095834732056,0.0 -102000,1.4753487,2.192264,,,,,,,,,,,,,, -102100,1.2333732,3.3824627,,,,,,,,,,,,,, -102200,1.3349831,3.2405295,,,,,,,,,,,,,, -102300,1.3234762,3.0297468,,,,,,,,,,,,,, -102400,1.4211658,2.9860702,,,,,,,,,,,,,, -102500,1.5523518,2.1401348,,,,,,,,,,,,,, -102600,1.5495265,2.1177096,,,,,,,,,,,,,, -102700,1.4904962,2.4800074,,,,,,,,,,,,,, -102800,1.3882499,4.7848353,,,,,,,,,,,,,, -102858,,,0.7438867092132568,1.0300265550613403,0.6672199964523315,1.3864587545394895,50000.0,0.5406000018119812,2.0302770137786865,10000.0,47935.08392548561,53956.87895011902,47935.08392548561,6011.463626861572,4.929841756820679,0.0 -102900,1.2459522,4.6774282,,,,,,,,,,,,,, -103000,1.2816552,4.4900146,,,,,,,,,,,,,, -103100,1.5997131,2.058398,,,,,,,,,,,,,, -103200,1.2672585,3.7092643,,,,,,,,,,,,,, -103300,1.598554,2.0482225,,,,,,,,,,,,,, -103400,1.5563928,2.2666361,,,,,,,,,,,,,, -103500,1.2854668,3.0949793,,,,,,,,,,,,,, -103600,1.5318369,2.030815,,,,,,,,,,,,,, -103700,1.5511671,2.0658348,,,,,,,,,,,,,, -103763,,,0.723437488079071,1.120510816574097,0.66975998878479,1.3746145963668823,50000.0,0.5421000123023987,2.0381054878234863,10000.0,48355.07038998604,54427.79730439186,48355.07038998604,6062.30232834816,4.975308418273926,0.0 -103800,1.386046,4.51952,,,,,,,,,,,,,, -103900,1.4121767,4.0140834,,,,,,,,,,,,,, -104000,1.4877398,2.0428727,,,,,,,,,,,,,, -104100,1.4040444,2.463751,,,,,,,,,,,,,, -104200,1.55518,4.410178,,,,,,,,,,,,,, -104300,1.4666698,2.154542,,,,,,,,,,,,,, -104400,1.5363748,1.929115,,,,,,,,,,,,,, -104500,1.5279716,2.006433,,,,,,,,,,,,,, -104600,1.3558235,3.138921,,,,,,,,,,,,,, -104666,,,0.7344530820846558,1.072016358375549,0.6725599765777588,1.3530502319335938,50000.0,0.5452000498771667,1.9986283779144287,10000.0,48775.350959062576,54899.59563159943,48775.350959062576,6113.727089166641,5.020384788513184,0.0 -104700,1.6700151,2.0325613,,,,,,,,,,,,,, -104800,1.4353664,3.2287347,,,,,,,,,,,,,, -104900,1.6242994,2.2031784,,,,,,,,,,,,,, -105000,1.615763,2.149009,,,,,,,,,,,,,, -105100,1.5420332,2.1514766,,,,,,,,,,,,,, -105200,1.3093729,4.0601807,,,,,,,,,,,,,, -105300,1.5834911,2.043019,,,,,,,,,,,,,, -105400,1.2874417,3.0701578,,,,,,,,,,,,,, -105500,1.7213383,2.1634705,,,,,,,,,,,,,, -105559,,,0.7469140291213989,0.9951573610305786,0.6689199805259705,1.3489623069763184,50000.0,0.5433000326156616,2.0017428398132324,10000.0,49195.26766204834,55372.13287234306,49195.26766204834,6166.253857374191,5.067118167877197,0.0 -105600,1.3087689,4.0547304,,,,,,,,,,,,,, -105700,1.525202,2.45964,,,,,,,,,,,,,, -105800,1.5753194,2.1128151,,,,,,,,,,,,,, -105900,1.3065257,3.635406,,,,,,,,,,,,,, -106000,1.3532205,3.5565538,,,,,,,,,,,,,, -106100,1.5058726,2.4610724,,,,,,,,,,,,,, -106200,1.5194054,2.0872836,,,,,,,,,,,,,, -106300,1.5492873,2.380986,,,,,,,,,,,,,, -106400,1.6432685,4.6520944,,,,,,,,,,,,,, -106458,,,0.730273425579071,1.1089922189712524,0.6736199855804443,1.3686671257019043,50000.0,0.5472000241279602,2.037179470062256,10000.0,49615.44148349762,55844.475420475006,49615.44148349762,6218.311373949051,5.114095211029053,0.0 -106500,1.5550002,2.113976,,,,,,,,,,,,,, -106600,1.98031,2.1310668,,,,,,,,,,,,,, -106700,1.4962327,2.1047223,,,,,,,,,,,,,, -106800,1.3345598,2.7676501,,,,,,,,,,,,,, -106900,1.5369807,4.5869308,,,,,,,,,,,,,, -107000,1.546864,2.312242,,,,,,,,,,,,,, -107100,1.4770828,3.6954591,,,,,,,,,,,,,, -107200,1.5434413,3.3537588,,,,,,,,,,,,,, -107300,1.3688825,4.2615657,,,,,,,,,,,,,, -107364,,,0.731640636920929,1.0827327966690063,0.6683599948883057,1.369884967803955,50000.0,0.5441000461578369,2.013227939605713,10000.0,50035.62106966972,56315.47707438469,50035.62106966972,6269.040015935898,5.159966707229614,0.0 -107400,1.4811367,2.4343383,,,,,,,,,,,,,, -107500,1.5906621,2.056623,,,,,,,,,,,,,, -107600,1.5971477,2.1790051,,,,,,,,,,,,,, -107700,1.4848775,2.7839835,,,,,,,,,,,,,, -107800,1.5519453,4.8098574,,,,,,,,,,,,,, -107900,1.2608296,4.3342147,,,,,,,,,,,,,, -108000,1.5999312,2.091027,,,,,,,,,,,,,, -108100,1.5488092,3.0153818,,,,,,,,,,,,,, -108200,1.4656358,2.1711147,,,,,,,,,,,,,, -108262,,,0.7414648532867432,1.0510865449905396,0.6689800024032593,1.3708804845809937,50000.0,0.5437000393867493,2.018178701400757,10000.0,50455.64405369759,56788.2848212719,50455.64405369759,6321.7311725616455,5.206001281738281,0.0 -108300,1.5755537,2.0470736,,,,,,,,,,,,,, -108400,1.4435546,3.125463,,,,,,,,,,,,,, -108500,1.563143,2.3759127,,,,,,,,,,,,,, -108600,1.7259005,2.1528869,,,,,,,,,,,,,, -108700,1.5898552,2.118012,,,,,,,,,,,,,, -108800,1.634497,1.9937087,,,,,,,,,,,,,, -108900,1.5685554,2.3199675,,,,,,,,,,,,,, -109000,1.6211847,2.497299,,,,,,,,,,,,,, -109100,1.5816737,2.0327008,,,,,,,,,,,,,, -109164,,,0.7342773079872131,1.0697695016860962,0.6750999689102173,1.3431190252304075,50000.0,0.5509999990463257,1.997771978378296,10000.0,50875.999675273895,57261.83912968636,50875.999675273895,6374.832296609879,5.255663156509399,0.0 -109200,1.6248642,1.963835,,,,,,,,,,,,,, -109300,1.5602106,4.2061024,,,,,,,,,,,,,, -109400,1.6719706,2.1370416,,,,,,,,,,,,,, -109500,1.6519817,2.0197053,,,,,,,,,,,,,, -109600,1.435459,4.6842194,,,,,,,,,,,,,, -109700,1.523908,3.1510532,,,,,,,,,,,,,, -109800,1.6212752,2.083203,,,,,,,,,,,,,, -109900,1.534946,2.5972862,,,,,,,,,,,,,, -110000,1.4678596,2.666895,,,,,,,,,,,,,, -110071,,,0.7425194978713989,1.0702048540115356,0.6782400012016296,1.3490768671035769,50000.0,0.550000011920929,2.0083322525024414,10000.0,51296.14013576508,57732.86963033676,51296.14013576508,6425.622935056686,5.307147979736328,0.0 -110100,1.4765439,2.8107083,,,,,,,,,,,,,, -110200,1.5202878,2.4273458,,,,,,,,,,,,,, -110300,1.5518007,2.106073,,,,,,,,,,,,,, -110400,1.6829886,2.0134125,,,,,,,,,,,,,, -110500,1.4267322,4.2297964,,,,,,,,,,,,,, -110600,1.5237772,2.0319176,,,,,,,,,,,,,, -110700,1.5475625,2.0837502,,,,,,,,,,,,,, -110800,1.5847305,2.3682013,,,,,,,,,,,,,, -110900,1.5691326,4.633339,,,,,,,,,,,,,, -110978,,,0.7479101419448853,1.0002694129943848,0.6759999990463257,1.3279269933700562,50000.0,0.5574000477790833,1.9583526849746704,10000.0,51716.33648109436,58204.73041367531,51716.33648109436,6477.19561457634,5.351512670516968,0.0 -111000,1.5252953,2.5220962,,,,,,,,,,,,,, -111100,1.465226,3.4351704,,,,,,,,,,,,,, -111200,1.3422451,3.6730084,,,,,,,,,,,,,, -111300,1.5434835,2.021834,,,,,,,,,,,,,, -111400,1.4328455,3.954116,,,,,,,,,,,,,, -111500,1.3531641,3.9638863,,,,,,,,,,,,,, -111600,1.9867929,2.214653,,,,,,,,,,,,,, -111700,1.6937106,1.9743425,,,,,,,,,,,,,, -111800,1.6187449,2.2833714,,,,,,,,,,,,,, -111880,,,0.7358593344688416,1.0901877880096436,0.6753999590873718,1.3613321781158447,50000.0,0.5534999966621399,2.0046257972717285,10000.0,52136.71265649796,58676.56856536865,52136.71265649796,6528.562636137009,5.398941516876221,0.0 -111900,1.4947052,2.6791306,,,,,,,,,,,,,, -112000,1.3791231,4.1740313,,,,,,,,,,,,,, -112100,1.5482981,3.927456,,,,,,,,,,,,,, -112200,1.7705615,2.0622082,,,,,,,,,,,,,, -112300,1.4139762,2.568209,,,,,,,,,,,,,, -112400,1.5730557,2.1702976,,,,,,,,,,,,,, -112500,1.6088641,2.0380478,,,,,,,,,,,,,, -112600,1.3446008,2.9642055,,,,,,,,,,,,,, -112700,1.6874593,2.4535966,,,,,,,,,,,,,, -112785,,,0.7383984327316284,1.064749836921692,0.6764199733734131,1.340150237083435,50000.0,0.5529000163078308,1.9855674505233765,10000.0,52557.00045776367,59148.58823752403,52557.00045776367,6580.2001440525055,5.444950819015503,0.0 -112800,1.7772261,2.0898771,,,,,,,,,,,,,, -112900,1.5786505,2.2047758,,,,,,,,,,,,,, -113000,1.8673522,1.9506083,,,,,,,,,,,,,, -113100,1.6704462,2.000366,,,,,,,,,,,,,, -113200,1.6009046,1.9965701,,,,,,,,,,,,,, -113300,1.8897436,2.1157117,,,,,,,,,,,,,, -113400,1.5893568,4.200406,,,,,,,,,,,,,, -113500,1.4871082,2.7366993,,,,,,,,,,,,,, -113600,1.4466568,2.2072144,,,,,,,,,,,,,, -113682,,,0.7493749856948853,1.0180168151855469,0.6759999990463257,1.3451844453811646,50000.0,0.5527999997138977,1.991071343421936,10000.0,52976.95989322662,59622.63963222504,52976.95989322662,6634.199034690857,5.4898645877838135,0.0 -113700,1.6345245,2.1303363,,,,,,,,,,,,,, -113800,1.6627753,1.9558892,,,,,,,,,,,,,, -113900,1.6570755,2.2566514,,,,,,,,,,,,,, -114000,1.5102754,3.125235,,,,,,,,,,,,,, -114100,1.6623701,1.9560915,,,,,,,,,,,,,, -114200,1.5807725,1.9757824,,,,,,,,,,,,,, -114300,1.6135482,1.9653795,,,,,,,,,,,,,, -114400,1.5071405,2.727613,,,,,,,,,,,,,, -114500,1.7568243,2.0566683,,,,,,,,,,,,,, -114585,,,0.74609375,1.0239732265472412,0.6865999698638916,1.2994922399520874,50000.0,0.5619000196456909,1.937618613243103,10000.0,53397.26622343063,60096.39663481712,53397.26622343063,6687.553718566895,5.537170886993408,0.0 -114600,1.6170822,2.294289,,,,,,,,,,,,,, -114700,1.7505732,2.151593,,,,,,,,,,,,,, -114800,1.5612946,4.6527567,,,,,,,,,,,,,, -114900,1.6958953,2.4036334,,,,,,,,,,,,,, -115000,1.5334854,2.7306302,,,,,,,,,,,,,, -115100,1.6409495,1.9408216,,,,,,,,,,,,,, -115200,1.5040795,3.213705,,,,,,,,,,,,,, -115300,1.6374961,2.0898516,,,,,,,,,,,,,, -115400,1.6077719,2.0014396,,,,,,,,,,,,,, -115490,,,0.7510351538658142,1.0083297491073608,0.687279999256134,1.2991552352905271,50000.0,0.562000036239624,1.9419201612472528,10000.0,53817.85648846626,60568.140437603,53817.85648846626,6738.60989689827,5.5869903564453125,0.0 -115500,1.6000713,2.036161,,,,,,,,,,,,,, -115600,1.7094628,2.1726618,,,,,,,,,,,,,, -115700,1.3990462,3.1767437,,,,,,,,,,,,,, -115800,1.5958309,1.9919593,,,,,,,,,,,,,, -115900,1.5001702,4.5255895,,,,,,,,,,,,,, -116000,1.7705501,2.0666754,,,,,,,,,,,,,, -116100,1.5914493,3.6749668,,,,,,,,,,,,,, -116200,1.6367577,2.0963843,,,,,,,,,,,,,, -116300,1.6227655,2.0741136,,,,,,,,,,,,,, -116390,,,0.7574414014816284,0.960770845413208,0.6850199699401855,1.277678370475769,50000.0,0.5644000172615051,1.9282153844833367,10000.0,54237.85001659393,61040.232503175735,54237.85001659393,6790.61523938179,5.63244891166687,0.0 -116400,1.6632309,2.3058414,,,,,,,,,,,,,, -116500,1.3940784,3.26727,,,,,,,,,,,,,, -116600,1.6595771,1.9854119,,,,,,,,,,,,,, -116700,1.6845585,1.9973844,,,,,,,,,,,,,, -116800,1.7764397,2.0311244,,,,,,,,,,,,,, -116900,1.8634707,2.0152733,,,,,,,,,,,,,, -117000,1.7075768,2.1030607,,,,,,,,,,,,,, -117100,1.8197613,2.1248689,,,,,,,,,,,,,, -117200,1.5282624,3.1680434,,,,,,,,,,,,,, -117290,,,0.7455663681030273,1.0378332138061523,0.6839199662208557,1.3181474208831787,50000.0,0.5618000030517578,1.951360106468201,10000.0,54657.93375110626,61511.32391309738,54657.93375110626,6841.529457330704,5.67902684211731,0.0 -117300,1.6337422,1.9703836,,,,,,,,,,,,,, -117400,1.5929663,4.5859504,,,,,,,,,,,,,, -117500,1.5982778,4.212229,,,,,,,,,,,,,, -117600,1.4787717,4.159676,,,,,,,,,,,,,, -117700,1.5760056,1.8958802,,,,,,,,,,,,,, -117800,1.7288373,2.0317676,,,,,,,,,,,,,, -117900,1.5630534,4.170822,,,,,,,,,,,,,, -118000,1.5566455,4.564907,,,,,,,,,,,,,, -118100,1.5513271,4.100118,,,,,,,,,,,,,, -118185,,,0.75244140625,1.0247377157211304,0.6880599856376648,1.310234546661377,50000.0,0.5634000301361084,1.9548426866531368,10000.0,55078.15105628967,61983.7370262146,55078.15105628967,6893.628536224365,5.7277820110321045,0.0 -118200,1.5177486,3.126921,,,,,,,,,,,,,, -118300,1.5749629,2.950705,,,,,,,,,,,,,, -118400,1.7295163,2.0185857,,,,,,,,,,,,,, -118500,1.429745,3.4499962,,,,,,,,,,,,,, -118600,1.585352,2.4784818,,,,,,,,,,,,,, -118700,1.7810146,1.9886329,,,,,,,,,,,,,, -118800,1.5600184,2.3517742,,,,,,,,,,,,,, -118900,2.0329607,1.9566714,,,,,,,,,,,,,, -119000,1.5478445,2.291977,,,,,,,,,,,,,, -119087,,,0.758105456829071,0.9552485346794128,0.6878600120544434,1.2778600454330444,50000.0,0.5605000257492065,1.9243587255477903,10000.0,55498.36799407005,62457.77957677841,55498.36799407005,6947.353275775909,5.781080007553101,0.0 -119100,1.6338623,2.8803005,,,,,,,,,,,,,, -119200,1.473757,3.6943536,,,,,,,,,,,,,, -119300,1.7708186,1.9452323,,,,,,,,,,,,,, -119400,1.5451968,3.345552,,,,,,,,,,,,,, -119500,1.8893849,2.189724,,,,,,,,,,,,,, -119600,1.5473453,3.3145683,,,,,,,,,,,,,, -119700,1.749104,4.2093344,,,,,,,,,,,,,, -119800,1.7268125,2.8270452,,,,,,,,,,,,,, -119900,1.8262029,1.941597,,,,,,,,,,,,,, -119991,,,0.7644335627555847,0.937131941318512,0.6890999674797058,1.2613743543624878,50000.0,0.5639000535011292,1.8964364528656008,10000.0,55918.59027314186,62929.843616724014,55918.59027314186,6999.096809387207,5.830645799636841,0.0 -120000,1.7714925,2.2812421,,,,,,,,,,,,,, -120100,1.6579645,2.0254002,,,,,,,,,,,,,, -120200,1.5644904,2.2657883,,,,,,,,,,,,,, -120300,1.7126447,2.656459,,,,,,,,,,,,,, -120400,1.7956529,2.0340352,,,,,,,,,,,,,, -120500,1.9375093,2.1018121,,,,,,,,,,,,,, -120600,2.0698192,2.0094247,,,,,,,,,,,,,, -120700,1.7410659,1.921515,,,,,,,,,,,,,, -120800,1.5396359,3.190006,,,,,,,,,,,,,, -120894,,,0.7629492282867432,0.9582368731498718,0.6942200064659119,1.2611232995986938,50000.0,0.566100001335144,1.8966219425201416,10000.0,56338.974491119385,63400.81737446785,56338.974491119385,7049.592758893967,5.876109838485718,0.0 -120900,1.6873578,2.186094,,,,,,,,,,,,,, -121000,1.8225935,3.942614,,,,,,,,,,,,,, -121100,1.8080006,1.9993544,,,,,,,,,,,,,, -121200,1.577287,3.8072684,,,,,,,,,,,,,, -121300,1.6202888,2.5690234,,,,,,,,,,,,,, -121400,1.7265024,3.7459145,,,,,,,,,,,,,, -121500,1.7531506,1.9506698,,,,,,,,,,,,,, -121600,1.7759696,1.982733,,,,,,,,,,,,,, -121700,1.7344526,1.9712224,,,,,,,,,,,,,, -121797,,,0.7608593702316284,0.9513280391693116,0.6955199837684631,1.255292892456055,50000.0,0.5713000297546387,1.8828102350234983,10000.0,56759.05333805084,63873.81185340881,56759.05333805084,7102.412714481354,5.923561811447144,0.0 -121800,1.5472237,2.3900151,,,,,,,,,,,,,, -121900,1.7720563,1.984065,,,,,,,,,,,,,, -122000,1.9415435,2.0738063,,,,,,,,,,,,,, -122100,1.7705373,1.9371761,,,,,,,,,,,,,, -122200,1.5618144,2.4565957,,,,,,,,,,,,,, -122300,1.6791238,2.4664438,,,,,,,,,,,,,, -122400,1.7602386,3.9371562,,,,,,,,,,,,,, -122500,1.8104875,2.023802,,,,,,,,,,,,,, -122600,1.7061331,1.9315159,,,,,,,,,,,,,, -122700,1.935515,1.9437418,,,,,,,,,,,,,, -122701,,,0.7781640291213989,0.8763917088508606,0.6922599673271179,1.2633213996887207,50000.0,0.5670000314712524,1.9046696424484253,10000.0,57179.4742333889,64344.930874586105,57179.4742333889,7153.010915279388,5.974011421203613,0.0 -122800,1.6568576,1.9219131,,,,,,,,,,,,,, -122900,1.7695874,2.2328377,,,,,,,,,,,,,, -123000,1.579335,2.6637292,,,,,,,,,,,,,, -123100,1.6530108,3.2053232,,,,,,,,,,,,,, -123200,1.9036722,2.0263422,,,,,,,,,,,,,, -123300,2.004397,2.1977575,,,,,,,,,,,,,, -123400,1.7628653,1.9221525,,,,,,,,,,,,,, -123500,2.013056,2.1277165,,,,,,,,,,,,,, -123600,1.6393207,4.0687795,,,,,,,,,,,,,, -123603,,,0.7577733993530273,0.9707837700843812,0.6915799975395203,1.2661999464035034,50000.0,0.5670000314712524,1.907639145851136,10000.0,57599.657051086426,64817.05897903442,57599.657051086426,7204.862086057663,6.020724296569824,0.0 -123700,1.78845,1.960318,,,,,,,,,,,,,, -123800,1.7139053,3.879263,,,,,,,,,,,,,, -123900,1.5959569,3.6092596,,,,,,,,,,,,,, -124000,1.9990599,1.8925884,,,,,,,,,,,,,, -124100,1.6596286,2.116476,,,,,,,,,,,,,, -124200,1.9564993,2.0133877,,,,,,,,,,,,,, -124300,1.8149807,2.0167365,,,,,,,,,,,,,, -124400,1.6518974,3.040203,,,,,,,,,,,,,, -124500,1.5803366,3.2065835,,,,,,,,,,,,,, -124503,,,0.7660741806030273,0.9606534838676452,0.6967200040817261,1.2727292776107788,50000.0,0.5719000101089478,1.917148470878601,10000.0,58019.7577214241,65288.98992753029,58019.7577214241,7256.597772836685,6.068363428115845,0.0 -124600,1.711065,2.7920213,,,,,,,,,,,,,, -124700,1.5589545,3.3425918,,,,,,,,,,,,,, -124800,2.0106566,1.8538268,,,,,,,,,,,,,, -124900,1.6838574,1.9075074,,,,,,,,,,,,,, -125000,1.836077,2.1196074,,,,,,,,,,,,,, -125100,2.0184808,1.8510867,,,,,,,,,,,,,, -125200,1.7196269,3.6702542,,,,,,,,,,,,,, -125300,1.537206,2.107503,,,,,,,,,,,,,, -125400,1.8807582,1.8565897,,,,,,,,,,,,,, -125405,,,0.7822851538658142,0.8623960018157959,0.6998800039291382,1.230411410331726,50000.0,0.5746000409126282,1.8689360618591309,10000.0,58439.809049129486,65759.95871829987,58439.809049129486,7307.421809196472,6.114432096481323,0.0 -125500,1.7317909,1.8454876,,,,,,,,,,,,,, -125600,1.6510038,4.0506897,,,,,,,,,,,,,, -125700,1.6326563,3.245666,,,,,,,,,,,,,, -125800,1.6945794,2.2750168,,,,,,,,,,,,,, -125900,1.7626834,4.4218774,,,,,,,,,,,,,, -126000,1.7521777,1.7784303,,,,,,,,,,,,,, -126100,1.6584508,2.3246899,,,,,,,,,,,,,, -126200,1.8768753,2.4148219,,,,,,,,,,,,,, -126300,1.7670931,4.42505,,,,,,,,,,,,,, -126304,,,0.764843761920929,0.9532604217529296,0.6993199586868286,1.2452449798583984,50000.0,0.5796000361442566,1.875036239624024,10000.0,58860.01943874359,66231.4698665142,58860.01943874359,7358.625300168991,6.164425373077393,0.0 -126400,1.5661108,4.3099775,,,,,,,,,,,,,, -126500,1.7286427,1.739254,,,,,,,,,,,,,, -126600,1.6503483,4.006409,,,,,,,,,,,,,, -126700,1.8099521,2.0073977,,,,,,,,,,,,,, -126800,1.752814,1.9383776,,,,,,,,,,,,,, -126900,1.7340274,3.1865203,,,,,,,,,,,,,, -127000,1.926231,2.0071344,,,,,,,,,,,,,, -127100,1.8669732,1.7809608,,,,,,,,,,,,,, -127200,1.8508263,2.3264494,,,,,,,,,,,,,, -127206,,,0.7702929377555847,0.9010846614837646,0.7008799910545349,1.2191319465637207,50000.0,0.5750000476837158,1.863745212554932,10000.0,59280.34904909134,66703.90618491173,59280.34904909134,7410.638778209686,6.2095441818237305,0.0 -127300,1.7237885,2.038709,,,,,,,,,,,,,, -127400,1.7776964,1.7805586,,,,,,,,,,,,,, -127500,1.9606053,1.8592162,,,,,,,,,,,,,, -127600,1.9740851,1.9931548,,,,,,,,,,,,,, -127700,1.7893769,2.169053,,,,,,,,,,,,,, -127800,1.7913858,2.0059133,,,,,,,,,,,,,, -127900,2.0410864,3.6481917,,,,,,,,,,,,,, -128000,1.8911521,1.8571746,,,,,,,,,,,,,, -128100,1.8277355,2.5476964,,,,,,,,,,,,,, -128109,,,0.7816601395606995,0.8531544208526611,0.7025600075721741,1.214359164237976,50000.0,0.579300045967102,1.843841552734375,10000.0,59700.57011055946,67176.2463812828,59700.57011055946,7462.659845113754,6.260125875473023,0.0 -128200,1.7304021,1.9065498,,,,,,,,,,,,,, -128300,1.9678397,1.9903009,,,,,,,,,,,,,, -128400,1.8144964,4.3262825,,,,,,,,,,,,,, -128500,1.7650002,3.4018843,,,,,,,,,,,,,, -128600,1.6737479,2.9091008,,,,,,,,,,,,,, -128700,1.777057,3.3422694,,,,,,,,,,,,,, -128800,1.9641813,1.8452103,,,,,,,,,,,,,, -128900,1.7697284,2.1615863,,,,,,,,,,,,,, -129000,1.7803514,3.199478,,,,,,,,,,,,,, -129007,,,0.7702538967132568,0.9246523976325988,0.7058799862861633,1.214882493019104,50000.0,0.5808000564575195,1.853320598602295,10000.0,60120.5441262722,67650.05022072792,60120.5441262722,7516.386365413666,6.315888404846191,0.0 -129100,1.6561364,2.885164,,,,,,,,,,,,,, -129200,1.7154058,2.640494,,,,,,,,,,,,,, -129300,2.075075,1.8626039,,,,,,,,,,,,,, -129400,2.0615466,1.821706,,,,,,,,,,,,,, -129500,1.7339959,2.505853,,,,,,,,,,,,,, -129600,2.1361122,2.0547683,,,,,,,,,,,,,, -129700,1.669126,4.473903,,,,,,,,,,,,,, -129800,1.910696,1.8033522,,,,,,,,,,,,,, -129900,1.899946,4.4084344,,,,,,,,,,,,,, -129909,,,0.7751367092132568,0.8928495049476624,0.7060399651527405,1.207680106163025,50000.0,0.5808000564575195,1.8382052183151243,10000.0,60540.73557567597,68121.60476827621,60540.73557567597,7567.649868488312,6.368206024169922,0.0 -130000,2.011744,2.5616183,,,,,,,,,,,,,, -130100,1.8151324,3.0234447,,,,,,,,,,,,,, -130200,1.7589909,3.5955186,,,,,,,,,,,,,, -130300,1.890026,1.8051319,,,,,,,,,,,,,, -130400,1.7804666,3.8128762,,,,,,,,,,,,,, -130500,1.9714538,1.7563019,,,,,,,,,,,,,, -130600,1.8217345,3.7819743,,,,,,,,,,,,,, -130700,1.8206141,1.8602722,,,,,,,,,,,,,, -130800,1.6409436,3.087154,,,,,,,,,,,,,, -130812,,,0.7800390720367432,0.8839324712753296,0.7036600112915039,1.2200918197631836,50000.0,0.5781000256538391,1.8685678243637085,10000.0,60960.21865844727,68593.51099419594,60960.21865844727,7619.496718406677,6.896216869354248,0.0 -130900,1.9524404,3.8536978,,,,,,,,,,,,,, -131000,1.7935393,1.7771451,,,,,,,,,,,,,, -131100,1.9533274,3.398289,,,,,,,,,,,,,, -131200,1.9066082,1.9157659,,,,,,,,,,,,,, -131300,1.9853309,3.9943874,,,,,,,,,,,,,, -131400,1.7847263,3.3704033,,,,,,,,,,,,,, -131500,1.6378498,4.344617,,,,,,,,,,,,,, -131600,1.9275393,2.1903696,,,,,,,,,,,,,, -131700,1.8110054,1.8092033,,,,,,,,,,,,,, -131714,,,0.7735546827316284,0.887130081653595,0.7057799696922302,1.1913654804229736,50000.0,0.5771000385284424,1.840156316757202,10000.0,61380.37313580513,69067.9515554905,61380.37313580513,7673.678128242493,6.953572034835815,0.0 -131800,2.0909407,1.9456528,,,,,,,,,,,,,, -131900,1.8435938,1.8506931,,,,,,,,,,,,,, -132000,2.0456626,1.9975836,,,,,,,,,,,,,, -132100,1.890138,1.9281561,,,,,,,,,,,,,, -132200,1.9941354,4.064794,,,,,,,,,,,,,, -132300,2.031362,1.864753,,,,,,,,,,,,,, -132400,1.9932636,1.9100589,,,,,,,,,,,,,, -132500,1.8647755,2.070197,,,,,,,,,,,,,, -132600,1.7864827,1.7856419,,,,,,,,,,,,,, -132613,,,0.7814843654632568,0.8491562604904175,0.7097600102424622,1.1710385084152222,50000.0,0.5889000296592712,1.8052948713302608,10000.0,61800.565898656845,69541.9029185772,61800.565898656845,7727.336859464645,7.005480766296387,0.0 -132700,1.8549677,1.752801,,,,,,,,,,,,,, -132800,1.8262469,3.2055018,,,,,,,,,,,,,, -132900,1.9486881,1.9517183,,,,,,,,,,,,,, -133000,1.890329,2.227369,,,,,,,,,,,,,, -133100,2.0210655,1.8085318,,,,,,,,,,,,,, -133200,2.0792866,1.7615756,,,,,,,,,,,,,, -133300,1.8357626,3.0542004,,,,,,,,,,,,,, -133400,1.894469,1.964912,,,,,,,,,,,,,, -133500,2.0523784,1.9592266,,,,,,,,,,,,,, -133511,,,0.793652355670929,0.8182350993156433,0.7105199694633484,1.1751962900161743,50000.0,0.5838000178337097,1.8079313039779663,10000.0,62220.73200464249,70015.30837655067,62220.73200464249,7780.478232383728,7.055763244628906,0.0 -133600,1.6897621,4.119447,,,,,,,,,,,,,, -133700,1.8917946,3.5593762,,,,,,,,,,,,,, -133800,2.0060015,1.8903582,,,,,,,,,,,,,, -133900,1.8275744,2.0092494,,,,,,,,,,,,,, -134000,1.7380999,2.3939977,,,,,,,,,,,,,, -134100,1.7779721,2.4283166,,,,,,,,,,,,,, -134200,1.8546137,2.7867796,,,,,,,,,,,,,, -134300,1.9751724,4.2888956,,,,,,,,,,,,,, -134400,1.7289662,2.2211745,,,,,,,,,,,,,, -134412,,,0.7824413776397705,0.8449131846427917,0.7128199934959412,1.1581400632858276,50000.0,0.591200053691864,1.7911345958709717,10000.0,62640.94975614548,70485.8378021717,62640.94975614548,7830.6907432079315,7.107162237167358,0.0 -134500,1.8778296,2.076176,,,,,,,,,,,,,, -134600,1.9347496,1.815492,,,,,,,,,,,,,, -134700,1.7922432,2.5949771,,,,,,,,,,,,,, -134800,1.8181694,2.929833,,,,,,,,,,,,,, -134900,2.0266104,1.9882467,,,,,,,,,,,,,, -135000,1.815765,2.6725023,,,,,,,,,,,,,, -135100,2.0341568,1.8050513,,,,,,,,,,,,,, -135200,2.091166,1.6828332,,,,,,,,,,,,,, -135300,2.1313784,1.9400222,,,,,,,,,,,,,, -135312,,,0.782910168170929,0.8488849401473999,0.7114799618721008,1.170685648918152,50000.0,0.589900016784668,1.7887169122695925,10000.0,63061.04555249214,70957.17614507675,63061.04555249214,7881.830750465393,7.161731958389282,0.0 -135400,1.9634416,1.7286934,,,,,,,,,,,,,, -135500,2.0957086,1.8743691,,,,,,,,,,,,,, -135600,1.9809555,1.8705761,,,,,,,,,,,,,, -135700,1.8906382,2.139977,,,,,,,,,,,,,, -135800,1.82728,1.9851228,,,,,,,,,,,,,, -135900,1.884126,2.126862,,,,,,,,,,,,,, -136000,1.9151855,2.455168,,,,,,,,,,,,,, -136100,1.8060529,2.4515467,,,,,,,,,,,,,, -136200,2.038428,1.7928594,,,,,,,,,,,,,, -136213,,,0.7852538824081421,0.8555783033370972,0.7083799839019775,1.213556170463562,50000.0,0.5837000012397766,1.847142100334168,10000.0,63480.95566868782,71429.1494038105,63480.95566868782,7933.790218830109,7.2181336879730225,0.0 -136300,1.6585732,3.0399766,,,,,,,,,,,,,, -136400,2.0468698,1.7530637,,,,,,,,,,,,,, -136500,2.182841,1.8986144,,,,,,,,,,,,,, -136600,2.0048597,1.7567462,,,,,,,,,,,,,, -136700,2.0118585,1.7561574,,,,,,,,,,,,,, -136800,1.9809462,1.8814994,,,,,,,,,,,,,, -136900,2.0389097,1.7696185,,,,,,,,,,,,,, -137000,1.9877027,1.7432795,,,,,,,,,,,,,, -137100,1.895546,3.6232855,,,,,,,,,,,,,, -137114,,,0.7789257764816284,0.8748505711555481,0.7115199565887451,1.1856586933135986,50000.0,0.5837000012397766,1.829952836036682,10000.0,63901.08681106568,71900.98412513733,63901.08681106568,7985.396743774414,7.266888856887817,0.0 -137200,2.057234,1.9465687,,,,,,,,,,,,,, -137300,2.0447266,1.8181436,,,,,,,,,,,,,, -137400,2.0029655,1.8878026,,,,,,,,,,,,,, -137500,1.8675497,3.8443723,,,,,,,,,,,,,, -137600,2.0844529,1.9128596,,,,,,,,,,,,,, -137700,1.9879018,1.7891207,,,,,,,,,,,,,, -137800,1.934085,1.7253938,,,,,,,,,,,,,, -137900,1.9577407,1.7296666,,,,,,,,,,,,,, -138000,2.1302438,1.8371317,,,,,,,,,,,,,, -138009,,,0.79212886095047,0.8283264636993408,0.717960000038147,1.1578574180603027,50000.0,0.5907000303268433,1.7869914770126345,10000.0,64321.17126393318,72372.30072402954,64321.17126393318,8036.531369924545,7.317673444747925,0.0 -138100,1.9285944,3.0311596,,,,,,,,,,,,,, -138200,2.1766334,1.8155272,,,,,,,,,,,,,, -138300,2.1564991,1.661749,,,,,,,,,,,,,, -138400,1.8991162,3.964685,,,,,,,,,,,,,, -138500,1.9609387,1.7919642,,,,,,,,,,,,,, -138600,2.1056943,1.7830883,,,,,,,,,,,,,, -138700,2.0440092,1.6569768,,,,,,,,,,,,,, -138800,2.0552466,1.7258914,,,,,,,,,,,,,, -138900,2.1622622,1.8322023,,,,,,,,,,,,,, -138908,,,0.7974609136581421,0.8029530048370361,0.7170799970626831,1.1547406911849976,50000.0,0.5958000421524048,1.7757651805877686,10000.0,64741.42377257347,72844.25393557549,64741.42377257347,8088.132665634155,7.369465112686157,0.0 -139000,2.062015,3.3424542,,,,,,,,,,,,,, -139100,1.8953849,1.7678201,,,,,,,,,,,,,, -139200,2.1699347,1.738126,,,,,,,,,,,,,, -139300,2.223654,2.1232677,,,,,,,,,,,,,, -139400,1.9195681,3.4373937,,,,,,,,,,,,,, -139500,1.8534288,3.885899,,,,,,,,,,,,,, -139600,2.1901996,1.7784392,,,,,,,,,,,,,, -139700,1.9729997,1.6501245,,,,,,,,,,,,,, -139800,1.9896172,3.2431026,,,,,,,,,,,,,, -139808,,,0.7875585556030273,0.8376666903495789,0.7172799706459045,1.1552940607070925,50000.0,0.5945000052452087,1.7900747060775757,10000.0,65161.54459476471,73316.9589650631,65161.54459476471,8140.616198778152,7.422507762908935,0.0 -139900,2.1387708,2.313905,,,,,,,,,,,,,, -140000,1.8130717,3.0586445,,,,,,,,,,,,,, -140100,1.9716649,1.5235236,,,,,,,,,,,,,, -140200,2.0423052,1.7997254,,,,,,,,,,,,,, -140300,1.9486433,1.8817695,,,,,,,,,,,,,, -140400,2.147474,3.673144,,,,,,,,,,,,,, -140500,1.9651784,2.7472167,,,,,,,,,,,,,, -140600,2.1724687,1.7974248,,,,,,,,,,,,,, -140700,2.12652,1.8170836,,,,,,,,,,,,,, -140711,,,0.7987695336341858,0.7999251484870911,0.7211399674415588,1.141826629638672,50000.0,0.6028000116348267,1.7677903175354004,10000.0,65581.90678691864,73790.06270742416,65581.90678691864,8193.26209139824,7.47014856338501,0.0 -140800,1.9778818,1.7747298,,,,,,,,,,,,,, -140900,2.2398705,1.7063745,,,,,,,,,,,,,, -141000,1.976068,3.9257987,,,,,,,,,,,,,, -141100,1.8804193,2.2419403,,,,,,,,,,,,,, -141200,2.4364016,1.7605574,,,,,,,,,,,,,, -141300,2.237137,1.6814516,,,,,,,,,,,,,, -141400,2.0814478,1.8645259,,,,,,,,,,,,,, -141500,1.9107859,3.8164597,,,,,,,,,,,,,, -141600,2.4648097,1.7298007,,,,,,,,,,,,,, -141614,,,0.8019140362739563,0.7914997935295105,0.7195999622344971,1.156431794166565,50000.0,0.595300018787384,1.7889913320541382,10000.0,66002.23467111588,74262.30921554565,66002.23467111588,8245.082573652267,7.520330667495727,0.0 -141700,2.106392,1.8466115,,,,,,,,,,,,,, -141800,1.9377998,2.7319531,,,,,,,,,,,,,, -141900,2.1542857,1.7175839,,,,,,,,,,,,,, -142000,2.03042,3.423691,,,,,,,,,,,,,, -142100,1.8322418,3.3523643,,,,,,,,,,,,,, -142200,2.0300982,1.7676384,,,,,,,,,,,,,, -142300,2.184677,1.7589931,,,,,,,,,,,,,, -142400,2.0208664,3.978087,,,,,,,,,,,,,, -142500,1.8305295,3.135499,,,,,,,,,,,,,, -142514,,,0.8023241758346558,0.7893538475036621,0.7219600081443787,1.13852059841156,50000.0,0.6007000207901001,1.7685530185699463,10000.0,66422.37011957169,74733.39208197594,66422.37011957169,8295.933718919754,7.569155216217041,0.0 -142600,2.017315,1.7378316,,,,,,,,,,,,,, -142700,2.0463264,1.6954672,,,,,,,,,,,,,, -142800,2.2464154,3.2013159,,,,,,,,,,,,,, -142900,2.27188,1.8434246,,,,,,,,,,,,,, -143000,2.0200655,3.8374972,,,,,,,,,,,,,, -143100,2.1048357,2.0358467,,,,,,,,,,,,,, -143200,2.1480067,3.655641,,,,,,,,,,,,,, -143300,1.8326846,2.0548358,,,,,,,,,,,,,, -143400,2.120206,1.8060387,,,,,,,,,,,,,, -143414,,,0.7995898127555847,0.8224471211433411,0.7217999696731567,1.164766550064087,50000.0,0.5990000367164612,1.7969648838043213,10000.0,66842.30435395241,75204.91482305527,66842.30435395241,8347.408848524094,7.634051322937012,0.0 -143500,2.3991082,1.7150646,,,,,,,,,,,,,, -143600,1.8533348,2.5660098,,,,,,,,,,,,,, -143700,2.195296,2.0408854,,,,,,,,,,,,,, -143800,2.0912964,1.85112,,,,,,,,,,,,,, -143900,2.3804972,1.620621,,,,,,,,,,,,,, -144000,2.3953185,1.8698294,,,,,,,,,,,,,, -144100,2.0136049,3.502253,,,,,,,,,,,,,, -144200,2.1194072,2.4257474,,,,,,,,,,,,,, -144300,2.0638342,2.1457405,,,,,,,,,,,,,, -144311,,,0.8031640648841858,0.8152992725372314,0.7218199968338013,1.169253706932068,50000.0,0.5958000421524048,1.7974989414215088,10000.0,67262.48490691185,75678.5110874176,67262.48490691185,8400.72471165657,7.686464548110962,0.0 -144400,2.503836,1.6028957,,,,,,,,,,,,,, -144500,2.0499098,1.8021958,,,,,,,,,,,,,, -144600,2.2751474,3.1415415,,,,,,,,,,,,,, -144700,2.2281666,2.4532058,,,,,,,,,,,,,, -144800,2.3717499,2.1508949,,,,,,,,,,,,,, -144900,2.1305602,2.8553643,,,,,,,,,,,,,, -145000,2.124334,2.308219,,,,,,,,,,,,,, -145100,2.0686736,2.0232303,,,,,,,,,,,,,, -145200,2.1825027,1.8706362,,,,,,,,,,,,,, -145214,,,0.8151953220367432,0.7412189841270447,0.7258599996566772,1.138275146484375,50000.0,0.6038000583648682,1.7650368213653564,10000.0,67682.51352453232,76150.04582834244,67682.51352453232,8452.127160549164,7.742059707641602,0.0 -145300,2.2204115,1.9952035,,,,,,,,,,,,,, -145400,2.3187337,3.596229,,,,,,,,,,,,,, -145500,2.1786566,2.194857,,,,,,,,,,,,,, -145600,2.126788,4.1400566,,,,,,,,,,,,,, -145700,2.1585243,1.7835239,,,,,,,,,,,,,, -145800,2.1546235,1.6082474,,,,,,,,,,,,,, -145900,2.063841,1.7423532,,,,,,,,,,,,,, -146000,2.2791595,1.7327355,,,,,,,,,,,,,, -146100,2.3025558,2.9460182,,,,,,,,,,,,,, -146115,,,0.8047655820846558,0.7920248508453369,0.726099967956543,1.126613736152649,50000.0,0.6037000417709351,1.7594192028045654,10000.0,68102.61263489723,76621.85620999336,68102.61263489723,8503.738919973373,7.794044256210327,0.0 -146200,2.2233202,4.0488305,,,,,,,,,,,,,, -146300,2.176033,1.5487889,,,,,,,,,,,,,, -146400,2.1631122,4.132457,,,,,,,,,,,,,, -146500,2.2221317,1.6382846,,,,,,,,,,,,,, -146600,2.3035743,3.9558094,,,,,,,,,,,,,, -146700,2.510154,1.7297411,,,,,,,,,,,,,, -146800,2.2519457,1.6143663,,,,,,,,,,,,,, -146900,2.326871,1.585344,,,,,,,,,,,,,, -147000,2.1837556,1.9121764,,,,,,,,,,,,,, -147015,,,0.8040234446525574,0.7940880656242371,0.7256399989128113,1.1346133947372437,50000.0,0.6011000275611877,1.7695281505584717,10000.0,68522.7866191864,77093.98651838303,68522.7866191864,8555.59926533699,7.842405796051025,0.0 -147100,2.602448,1.6367496,,,,,,,,,,,,,, -147200,2.2560835,2.421273,,,,,,,,,,,,,, -147300,2.3112733,1.6102381,,,,,,,,,,,,,, -147400,2.1176019,2.7196329,,,,,,,,,,,,,, -147500,2.3897028,3.6818774,,,,,,,,,,,,,, -147600,2.0234628,2.7133498,,,,,,,,,,,,,, -147700,2.2746315,2.1896353,,,,,,,,,,,,,, -147800,2.3326855,1.6988084,,,,,,,,,,,,,, -147900,2.2202344,1.587163,,,,,,,,,,,,,, -147914,,,0.8194531202316284,0.7242646813392639,0.7303400039672852,1.1180957555770874,50000.0,0.6055000424385071,1.7449946403503418,10000.0,68942.73813343048,77567.22704386711,68942.73813343048,8608.788478851318,7.894062995910644,0.0 -148000,1.9788755,2.034015,,,,,,,,,,,,,, -148100,2.161157,3.665174,,,,,,,,,,,,,, -148200,2.2218974,1.6799784,,,,,,,,,,,,,, -148300,2.3786623,1.7056433,,,,,,,,,,,,,, -148400,2.1973836,1.7156297,,,,,,,,,,,,,, -148500,2.1996775,3.8853827,,,,,,,,,,,,,, -148600,2.2407737,1.7924795,,,,,,,,,,,,,, -148700,2.2487504,2.1000462,,,,,,,,,,,,,, -148800,2.528713,1.7615496,,,,,,,,,,,,,, -148809,,,0.8100390434265137,0.7443209886550903,0.7328599691390991,1.0836445093154907,50000.0,0.6076000332832336,1.71618390083313,10000.0,69363.04661417007,78039.55141115189,69363.04661417007,8660.702143192291,7.949257135391235,0.0 -148900,2.2066517,2.2145321,,,,,,,,,,,,,, -149000,2.1082308,3.9723487,,,,,,,,,,,,,, -149100,2.4741278,1.6918688,,,,,,,,,,,,,, -149200,2.0353558,2.0114927,,,,,,,,,,,,,, -149300,1.9515789,2.615871,,,,,,,,,,,,,, -149400,2.5021753,2.3055515,,,,,,,,,,,,,, -149500,2.2154763,3.6977303,,,,,,,,,,,,,, -149600,2.1746843,3.281671,,,,,,,,,,,,,, -149700,2.6620052,1.838944,,,,,,,,,,,,,, -149710,,,0.8106640577316284,0.7712754011154175,0.7304799556732178,1.1216198205947876,50000.0,0.6105000376701355,1.738780498504639,10000.0,69783.11700677872,78513.82325577736,69783.11700677872,8714.796729803085,8.007672309875488,0.0 -149800,2.692527,1.5857196,,,,,,,,,,,,,, -149900,2.4006107,1.7665702,,,,,,,,,,,,,, -150000,2.8671257,1.77514,,,,,,,,,,,,,, -150100,2.3101814,1.6363337,,,,,,,,,,,,,, -150200,2.5146031,3.6268098,,,,,,,,,,,,,, -150300,2.3005955,4.064559,,,,,,,,,,,,,, -150400,2.2806945,1.7963367,,,,,,,,,,,,,, -150500,2.2934608,1.6992913,,,,,,,,,,,,,, -150600,2.4826396,1.8432176,,,,,,,,,,,,,, -150609,,,0.8217577934265137,0.6940465569496155,0.7341199517250061,1.0832279920578003,50000.0,0.6121000051498413,1.7156203985214231,10000.0,70203.25068831444,78987.17783665657,70203.25068831444,8767.916935682297,8.061208963394165,0.0 -150700,2.4739265,3.641234,,,,,,,,,,,,,, -150800,2.3423069,2.0779152,,,,,,,,,,,,,, -150900,2.3938522,3.5677829,,,,,,,,,,,,,, -151000,2.2948563,1.6825329,,,,,,,,,,,,,, -151100,2.4538777,1.6050606,,,,,,,,,,,,,, -151200,2.2901354,1.7492985,,,,,,,,,,,,,, -151300,2.2989528,1.918362,,,,,,,,,,,,,, -151400,2.3653824,1.5569811,,,,,,,,,,,,,, -151500,2.529001,2.0702078,,,,,,,,,,,,,, -151509,,,0.8132226467132568,0.7564131021499634,0.7326799631118774,1.1018426418304443,50000.0,0.6114000082015991,1.724048733711243,10000.0,70623.22665309906,79460.3146841526,70623.22665309906,8820.977076530457,8.114773273468018,0.0 -151600,2.6142619,2.3971868,,,,,,,,,,,,,, -151700,2.2699647,2.2800322,,,,,,,,,,,,,, -151800,2.4854605,1.8713319,,,,,,,,,,,,,, -151900,2.0786974,3.3017623,,,,,,,,,,,,,, -152000,2.4974978,1.5923827,,,,,,,,,,,,,, -152100,1.9886003,2.656921,,,,,,,,,,,,,, -152200,2.2063615,3.4930158,,,,,,,,,,,,,, -152300,2.508277,1.75465,,,,,,,,,,,,,, -152400,3.4622507,1.5496396,,,,,,,,,,,,,, -152411,,,0.8139452934265137,0.7501934170722961,0.7333599925041199,1.1036591529846191,50000.0,0.6057000160217285,1.7319797277450562,10000.0,71043.51707410812,79930.89526510239,71043.51707410812,8871.157362699509,8.176369428634644,0.0 -152500,2.3372838,1.5379162,,,,,,,,,,,,,, -152600,2.5093572,1.6505141,,,,,,,,,,,,,, -152700,2.33841,3.82657,,,,,,,,,,,,,, -152800,2.0517688,2.4361255,,,,,,,,,,,,,, -152900,2.563026,1.7112893,,,,,,,,,,,,,, -153000,2.6332767,2.4262915,,,,,,,,,,,,,, -153100,2.685082,1.5932522,,,,,,,,,,,,,, -153200,2.4066625,1.586751,,,,,,,,,,,,,, -153300,2.371555,1.8151304,,,,,,,,,,,,,, -153312,,,0.826464831829071,0.6860719919204712,0.7383999824523926,1.0645508766174316,50000.0,0.6135000586509705,1.6989619731903076,10000.0,71463.45511341095,80403.29509854317,71463.45511341095,8923.512953042984,8.234338998794556,0.0 -153400,2.4066718,1.7408447,,,,,,,,,,,,,, -153500,2.8332064,1.7020943,,,,,,,,,,,,,, -153600,2.8594925,1.5662154,,,,,,,,,,,,,, -153700,2.3379078,3.8112774,,,,,,,,,,,,,, -153800,3.5097504,3.5712764,,,,,,,,,,,,,, -153900,2.3628652,1.9686551,,,,,,,,,,,,,, -154000,2.209534,1.66046,,,,,,,,,,,,,, -154100,2.4253883,1.5974447,,,,,,,,,,,,,, -154200,2.1068454,1.7579525,,,,,,,,,,,,,, -154210,,,0.8174414038658142,0.7186430096626282,0.7394199967384338,1.062025547027588,50000.0,0.6165000200271606,1.6909252405166626,10000.0,71883.42968702316,80874.89892697334,71883.42968702316,8975.040938138962,8.287301778793335,0.0 -154300,2.3762655,1.651077,,,,,,,,,,,,,, -154400,2.2924302,2.4692953,,,,,,,,,,,,,, -154500,2.2501576,2.9973907,,,,,,,,,,,,,, -154600,2.4645636,3.2946718,,,,,,,,,,,,,, -154700,2.5144813,1.4995824,,,,,,,,,,,,,, -154800,2.2174513,2.04306,,,,,,,,,,,,,, -154900,2.278796,1.7915992,,,,,,,,,,,,,, -155000,2.7733495,1.5858243,,,,,,,,,,,,,, -155100,2.3047643,1.5843779,,,,,,,,,,,,,, -155111,,,0.8227929472923279,0.6901466250419617,0.7375199794769287,1.0573351383209229,50000.0,0.6168000102043152,1.6916823387145996,10000.0,72303.53037238121,81346.9751830101,72303.53037238121,9026.914111852646,8.34177279472351,0.0 -155200,2.289108,1.8006268,,,,,,,,,,,,,, -155300,2.3560853,3.9573119,,,,,,,,,,,,,, -155400,2.327758,1.5585009,,,,,,,,,,,,,, -155500,2.3504536,2.2050583,,,,,,,,,,,,,, -155600,2.642028,4.0765557,,,,,,,,,,,,,, -155700,2.6844018,2.6982498,,,,,,,,,,,,,, -155800,2.261515,3.4561193,,,,,,,,,,,,,, -155900,2.3959935,2.1811557,,,,,,,,,,,,,, -156000,2.5357425,2.0472748,,,,,,,,,,,,,, -156013,,,0.8246484398841858,0.6990688443183899,0.7361199855804443,1.0750834941864014,50000.0,0.6097000241279602,1.709358811378479,10000.0,72723.77445316315,81820.81177711487,72723.77445316315,9080.40405368805,8.39595341682434,0.0 -156100,2.3985522,1.6193824,,,,,,,,,,,,,, -156200,2.4797482,1.4391016,,,,,,,,,,,,,, -156300,2.6533837,1.4987471,,,,,,,,,,,,,, -156400,2.580717,1.6650953,,,,,,,,,,,,,, -156500,2.6002727,2.9910245,,,,,,,,,,,,,, -156600,2.2960353,2.1970763,,,,,,,,,,,,,, -156700,2.4177113,1.9311537,,,,,,,,,,,,,, -156800,2.5815177,3.9563472,,,,,,,,,,,,,, -156900,2.2417219,1.5071312,,,,,,,,,,,,,, -156915,,,0.8210351467132568,0.7016533613204956,0.74263995885849,1.0572545528411863,50000.0,0.6189000010490417,1.6846977472305298,10000.0,73143.83418130875,82292.53384494781,73143.83418130875,9131.963080406187,8.450865745544434,0.0 -157000,2.514379,2.1030483,,,,,,,,,,,,,, -157100,2.870117,1.6861498,,,,,,,,,,,,,, -157200,2.412934,2.5065546,,,,,,,,,,,,,, -157300,2.542269,1.6075398,,,,,,,,,,,,,, -157400,2.3371696,2.5546677,,,,,,,,,,,,,, -157500,2.4257078,1.6282041,,,,,,,,,,,,,, -157600,2.5607111,3.5044515,,,,,,,,,,,,,, -157700,2.9283037,1.5728514,,,,,,,,,,,,,, -157800,2.9303544,1.6205883,,,,,,,,,,,,,, -157815,,,0.8214843273162842,0.6942675709724426,0.7416399717330933,1.0510855913162231,50000.0,0.615600049495697,1.6868818998336792,10000.0,73563.9743270874,82763.81664323807,73563.9743270874,9182.99531197548,8.513504981994629,0.0 -157900,2.396234,1.4288359,,,,,,,,,,,,,, -158000,2.535592,1.5073571,,,,,,,,,,,,,, -158100,2.5090313,2.915856,,,,,,,,,,,,,, -158200,2.6927934,1.755554,,,,,,,,,,,,,, -158300,2.6301243,1.5443248,,,,,,,,,,,,,, -158400,2.71295,1.5593506,,,,,,,,,,,,,, -158500,2.5867364,3.9450166,,,,,,,,,,,,,, -158600,2.4690797,1.8562107,,,,,,,,,,,,,, -158700,2.564951,3.5847876,,,,,,,,,,,,,, -158714,,,0.8269335627555847,0.649540901184082,0.7415800094604492,1.031722068786621,50000.0,0.6203000545501709,1.66560959815979,10000.0,73984.35371112823,83236.29381513596,73984.35371112823,9234.988682985306,8.569389820098877,0.0 -158800,2.4422073,2.1569977,,,,,,,,,,,,,, -158900,2.4477384,1.443963,,,,,,,,,,,,,, -159000,2.4281783,2.8281906,,,,,,,,,,,,,, -159100,2.401861,3.1974502,,,,,,,,,,,,,, -159200,2.3062418,1.6951399,,,,,,,,,,,,,, -159300,2.8885157,1.6376868,,,,,,,,,,,,,, -159400,2.4561481,2.0243568,,,,,,,,,,,,,, -159500,2.5627396,1.4410845,,,,,,,,,,,,,, -159600,2.8390713,1.4937221,,,,,,,,,,,,,, -159614,,,0.8246288895606995,0.6832771897315979,0.7423799633979797,1.0461320877075195,50000.0,0.6204000115394592,1.683318853378296,10000.0,74404.43561291695,83708.0224416256,74404.43561291695,9286.535109996796,8.62211298942566,0.0 -159700,2.6311996,1.5236858,,,,,,,,,,,,,, -159800,3.3497338,1.62451,,,,,,,,,,,,,, -159900,2.403446,3.0466623,,,,,,,,,,,,,, -160000,2.5637727,2.4782019,,,,,,,,,,,,,, -160100,2.5301747,1.4844089,,,,,,,,,,,,,, -160200,2.7597885,2.12181,,,,,,,,,,,,,, -160300,2.7116306,3.262834,,,,,,,,,,,,,, -160400,2.4379916,1.5119462,,,,,,,,,,,,,, -160500,2.3944042,2.3402798,,,,,,,,,,,,,, -160509,,,0.8296093344688416,0.6793026328086853,0.744159996509552,1.0463520288467407,50000.0,0.6167000532150269,1.6868876218795776,10000.0,74824.56441307068,84179.59306025505,74824.56441307068,9337.866757631302,8.684773206710815,0.0 -160600,2.6471891,2.5341449,,,,,,,,,,,,,, -160700,2.510086,3.0833788,,,,,,,,,,,,,, -160800,2.5131829,3.454294,,,,,,,,,,,,,, -160900,2.922422,1.5566049,,,,,,,,,,,,,, -161000,2.2582388,2.5414333,,,,,,,,,,,,,, -161100,2.9965584,1.6316572,,,,,,,,,,,,,, -161200,2.448305,1.4763148,,,,,,,,,,,,,, -161300,2.5905833,1.6353518,,,,,,,,,,,,,, -161400,2.5686433,3.2748923,,,,,,,,,,,,,, -161411,,,0.8265038728713989,0.695087730884552,0.7425000071525574,1.076271653175354,50000.0,0.6142000555992126,1.708626389503479,10000.0,75244.89977121353,84653.49061632156,75244.89977121353,9391.32567691803,8.740829706192017,0.0 -161500,2.4561954,1.3952296,,,,,,,,,,,,,, -161600,2.4309022,1.5400598,,,,,,,,,,,,,, -161700,2.4963074,1.9514474,,,,,,,,,,,,,, -161800,2.438403,3.3221433,,,,,,,,,,,,,, -161900,2.5808592,1.5705452,,,,,,,,,,,,,, -162000,2.766999,2.4840431,,,,,,,,,,,,,, -162100,2.777399,1.4908551,,,,,,,,,,,,,, -162200,2.7408671,1.5816,,,,,,,,,,,,,, -162300,2.6169226,3.7556796,,,,,,,,,,,,,, -162313,,,0.8310937285423279,0.6638551950454712,0.7471999526023865,1.0294922590255735,50000.0,0.6253000497817993,1.6685092449188232,10000.0,75665.09418177605,85124.82177829742,75665.09418177605,9442.36303639412,8.792722463607788,0.0 -162400,2.88329,1.5514132,,,,,,,,,,,,,, -162500,2.7763233,1.4523903,,,,,,,,,,,,,, -162600,2.7175982,1.8016808,,,,,,,,,,,,,, -162700,2.4147043,2.1022947,,,,,,,,,,,,,, -162800,2.6136942,4.082072,,,,,,,,,,,,,, -162900,2.617125,1.4801276,,,,,,,,,,,,,, -163000,2.421431,1.5615728,,,,,,,,,,,,,, -163100,2.465435,1.4848266,,,,,,,,,,,,,, -163200,2.583624,1.8781238,,,,,,,,,,,,,, -163214,,,0.8335351347923279,0.658676266670227,0.7470600008964539,1.034615993499756,50000.0,0.6261000037193298,1.664926290512085,10000.0,76085.4808113575,85596.82547187805,76085.4808113575,9493.877099275587,8.844767093658447,0.0 -163300,2.914221,2.6734643,,,,,,,,,,,,,, -163400,2.3505764,1.9241948,,,,,,,,,,,,,, -163500,2.3410099,1.366329,,,,,,,,,,,,,, -163600,2.692702,3.9005282,,,,,,,,,,,,,, -163700,2.4836822,1.4703202,,,,,,,,,,,,,, -163800,3.1972256,1.4724115,,,,,,,,,,,,,, -163900,2.798485,1.5445753,,,,,,,,,,,,,, -164000,2.7194889,1.5249323,,,,,,,,,,,,,, -164100,2.7243376,2.8584166,,,,,,,,,,,,,, -164114,,,0.8374804258346558,0.6364269256591797,0.747439980506897,1.0335808992385864,50000.0,0.6234000325202942,1.6698811054229736,10000.0,76505.46322965622,86068.02240657806,76505.46322965622,9544.983451128006,8.900535106658936,0.0 -164200,2.9229183,1.6875461,,,,,,,,,,,,,, -164300,2.6903286,1.4931257,,,,,,,,,,,,,, -164400,2.3078864,2.9756265,,,,,,,,,,,,,, -164500,2.8175387,1.5262253,,,,,,,,,,,,,, -164600,2.6350217,1.5058739,,,,,,,,,,,,,, -164700,3.2430947,1.5741634,,,,,,,,,,,,,, -164800,2.7246158,1.4825513,,,,,,,,,,,,,, -164900,2.5897815,1.7180583,,,,,,,,,,,,,, -165000,2.8028657,1.5438765,,,,,,,,,,,,,, -165017,,,0.838183581829071,0.639176070690155,0.7498399615287781,1.0234603881835938,50000.0,0.6254000067710876,1.6488336324691772,10000.0,76925.65456914902,86540.2251765728,76925.65456914902,9596.894824266434,8.953065156936646,0.0 -165100,2.872313,1.5378318,,,,,,,,,,,,,, -165200,2.792662,1.5804656,,,,,,,,,,,,,, -165300,2.648106,3.966325,,,,,,,,,,,,,, -165400,3.2266858,1.5157187,,,,,,,,,,,,,, -165500,2.6895883,1.998046,,,,,,,,,,,,,, -165600,2.5182629,1.555622,,,,,,,,,,,,,, -165700,2.7817726,1.4384241,,,,,,,,,,,,,, -165800,2.6563132,3.7063475,,,,,,,,,,,,,, -165900,2.6417875,4.0407805,,,,,,,,,,,,,, -165916,,,0.8354687094688416,0.6519188284873962,0.7498599886894226,1.0220907926559448,50000.0,0.6272000074386597,1.657045841217041,10000.0,77345.65645003319,87012.98839354515,77345.65645003319,9649.550794363022,9.01101303100586,0.0 -166000,3.022616,1.4790392,,,,,,,,,,,,,, -166100,2.8304138,1.6995149,,,,,,,,,,,,,, -166200,2.666725,1.4840987,,,,,,,,,,,,,, -166298,,,,,,,,,,,77520.36937975883,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index f6e47206d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -177.5883333683014,0.0,62.33279633522034,1,0,62.33279633522034,30.871214,2472,1.190014827453131,239.9211916923523,31.690332,1.115215625845949,30.757864,5348,1.1779352558965794 -286.587749004364,0.046248435974121,1502.6152069568634,1748,0,1502.6152069568634,5.9551396,2472,0.899579550301627,1789.3208968639374,5.9307604,0.944635537887994,6.0122747,5348,0.8966179750330672 -406.619413614273,0.0941307544708252,2942.7661283016205,3538,0,2942.7661283016205,3.388314,2472,0.7071273332927102,3349.630561113357,3.3182616,0.719244599525877,3.7128878,5348,0.7624858800698997 -542.1664445400238,0.1464216709136963,4383.109746217728,5287,0,4383.109746217728,0.8411037,2472,0.2719517396867954,4925.651102542877,0.7781294,0.2672749603344034,1.1462481,5348,0.3319366268573139 -677.7943549156189,0.1984179019927978,5823.235984802246,7026,0,5823.235984802246,0.58280253,2472,0.1943208823350192,6501.536516189575,0.5286893,0.1846181235534983,0.84069234,5348,0.2531643125404288 -811.9875221252441,0.2500596046447754,7263.73501253128,8791,0,7263.73501253128,0.4925594,2472,0.1663721487620092,8076.360379934311,0.44677088,0.159434864858988,0.7451185,5348,0.2265271247477722 -948.816588640213,0.3072614669799804,8703.784867048264,10539,0,8703.784867048264,0.4373133,2472,0.1492088639733512,9653.374977111816,0.4121761,0.1454353404356338,0.67398584,5348,0.2053930891993396 -1088.169734954834,0.355985164642334,10143.811375379562,12248,0,10143.811375379562,0.411901,2472,0.1403936384132594,11232.8799161911,0.36185822,0.1293897432625018,0.6489348,5348,0.1973217992411442 -1223.8755309581757,0.4083857536315918,11584.330972671509,13977,0,11584.330972671509,0.3806148,2472,0.1301362906993277,12809.234954595566,0.30664057,0.1145063574195897,0.6024696,5348,0.1828687836102609 -1358.2513551712036,0.4628069400787353,13024.30504345894,15712,0,13024.30504345894,0.36267507,2472,0.1221132167448662,14383.716886281967,0.28140113,0.1030010267650516,0.5813166,5348,0.1775683790802977 -1492.9678757190704,0.5146021842956543,14464.859641313553,17428,0,14464.859641313553,0.35867026,2472,0.1208132756484471,15959.115698099136,0.27844656,0.1055253744630417,0.5747308,5348,0.1757339949988897 -1630.2590498924255,0.5669634342193604,15904.858520269394,19133,0,15904.858520269394,0.33965674,2472,0.1150650986127191,17536.534660100937,0.2847909,0.1031915174912081,0.5575726,5348,0.1692557227956013 -1768.594096660614,0.622377872467041,17345.28639960289,20853,0,17345.28639960289,0.32976592,2472,0.1125261511587756,19115.42906999588,0.27910382,0.1014535573904555,0.5391951,5348,0.1646697625920812 -1904.43694972992,0.6713097095489502,18785.17260837555,22545,0,18785.17260837555,0.313901,2472,0.1081591615379928,20691.28274512291,0.2605467,0.0943512707068628,0.52861744,5348,0.1601224210008013 -2039.88334441185,0.7225484848022461,20225.37378954888,24263,0,20225.37378954888,0.31107822,2472,0.1044827656246826,22267.05878567696,0.2430587,0.0905367171994425,0.5190067,5348,0.1562702144298444 -2174.472556114197,0.7736678123474121,21665.884313106537,25979,0,21665.884313106537,0.3014969,2472,0.1013954055206873,23842.28617978096,0.22930108,0.0865750280393333,0.51106757,5348,0.153431746430192 -2313.663006067276,0.8280179500579834,23106.38948559761,27691,0,23106.38948559761,0.2954965,2472,0.0986939654296914,25422.112336874008,0.22043964,0.0796043816673655,0.49129182,5348,0.1478706662676076 -2449.174416303634,0.8814880847930908,24547.143973588943,29397,0,24547.143973588943,0.28681365,2472,0.0966628074665366,26998.507881879807,0.23616512,0.0882897815912636,0.48435923,5348,0.1464224683086013 -2584.9278602600098,0.9351787567138672,25987.65677118301,31119,0,25987.65677118301,0.28513002,2472,0.0958097211220116,28574.90484571457,0.21892151,0.0790563160950412,0.47866896,5348,0.1437481294109696 -2723.3589627742767,0.9874227046966552,27427.88367891312,32826,0,27427.88367891312,0.27610952,2472,0.093636382101436,30153.691357135773,0.22863874,0.0829305128191956,0.47382867,5348,0.1411317184316981 -2858.290730953217,1.050553798675537,28867.750519037247,34538,0,28867.750519037247,0.26683512,2472,0.0901427904048097,31728.63215994835,0.21834032,0.077597323017573,0.46095514,5348,0.1372602025546212 -2998.3748049736023,1.1091015338897705,30308.22143220901,36249,0,30308.22143220901,0.26671857,2472,0.0898381167103365,33309.32222747803,0.17301418,0.0649682769082541,0.4529431,5348,0.1348368846365505 -3134.772925376892,1.1662836074829102,31748.56366539001,37952,0,31748.56366539001,0.25399256,2472,0.0842524323116608,34886.196565151215,0.19108309,0.0690904099174251,0.43196315,5348,0.1302798883922106 -3270.4538078308105,1.2276201248168943,33188.59181380272,39671,0,33188.59181380272,0.24960111,2472,0.0832977880689781,36462.04642629624,0.23817594,0.0868124298259412,0.43086866,5348,0.1289668555760448 -3403.1156027317047,1.2870965003967283,34628.68599200249,41382,0,34628.68599200249,0.24681735,2472,0.0825665712022424,38034.93848752976,0.2418203,0.0883383202784547,0.4284742,5348,0.1280110449231007 -3535.2889487743378,1.3462481498718262,36068.79007220268,43104,0,36068.79007220268,0.23967984,2472,0.0812463185261917,39607.35124826431,0.2755499,0.1017133564588683,0.4144993,5348,0.1238691987603425 -3669.527673482895,1.4051265716552734,37508.72058033943,44819,0,37508.72058033943,0.23447508,2472,0.0793573416204578,41181.65663433075,0.23959982,0.0853581710813722,0.40491655,5348,0.120325941087307 -3803.516398906708,1.4624698162078855,38949.01678466797,46534,0,38949.01678466797,0.23017377,2472,0.0778339731480917,42756.07475566864,0.22094382,0.0819313462416354,0.40488222,5348,0.1207990190872491 -3939.021583557129,1.5191435813903809,40389.19332933426,48243,0,40389.19332933426,0.22285704,2472,0.0749294172607803,44331.88853049278,0.18676206,0.0701747707968547,0.39040998,5348,0.1165316624347104 -4073.134298324585,1.579833984375,41829.21695446968,49950,0,41829.21695446968,0.21055159,2472,0.0707249202770499,45906.16408348084,0.20000587,0.0746332694281664,0.3872141,5348,0.1135290653330372 -4207.490744113922,1.6347463130950928,43269.77072787285,51667,0,43269.77072787285,0.20984745,2472,0.0708264781752076,47481.20555949211,0.17628975,0.0666766878543205,0.3812173,5348,0.1120615580679108 -4343.593742609024,1.6926517486572266,44709.96631407738,53392,0,44709.96631407738,0.20121741,2472,0.0685109580972112,49057.63871669769,0.17431833,0.0662895019472981,0.36423317,5348,0.1081224596194135 -4480.270308256149,1.7505762577056885,46150.56253623962,55121,0,46150.56253623962,0.20039488,2472,0.0675360022748969,50635.04708957672,0.16579323,0.0624706694271911,0.36580062,5348,0.108238315456134 -4612.914803504944,1.8098342418670648,47590.833641052246,56851,0,47590.833641052246,0.19666456,2472,0.066317307497004,52208.10142946243,0.16412912,0.061519178599636,0.35546932,5348,0.1046467845177983 -4749.212601184845,1.87066912651062,49030.73461127281,58559,0,49030.73461127281,0.19008121,2472,0.0627424694818516,53784.43896389008,0.16294,0.0612672758361559,0.34518048,5348,0.1016538420691852 -4885.57377076149,1.9312732219696045,50470.98560571671,60285,0,50470.98560571671,0.18431841,2472,0.060386326244592,55361.18985915184,0.13095543,0.0508440379312663,0.33887717,5348,0.099317416028655 -5021.4548535346985,1.996178150177002,51911.11315703392,62016,0,51911.11315703392,0.17523469,2472,0.0594519935815408,56937.3402159214,0.12969963,0.0491762128611199,0.3313753,5348,0.0969713353350647 -5155.799608707428,2.057328224182129,53351.861990213394,63722,0,53351.861990213394,0.17089857,2472,0.0575020819369122,58512.570706129074,0.11556399,0.0442518859081373,0.32159543,5348,0.0928874170906668 -5290.980396270752,2.120563507080078,54792.02154159546,65443,0,54792.02154159546,0.1703582,2472,0.056384945057177,60088.04986286163,0.11451922,0.0434646458320285,0.3201817,5348,0.0917481680295818 -5424.689416408539,2.1815385818481445,56232.27780032158,67169,0,56232.27780032158,0.16590431,2472,0.0544147218329169,61662.15571832657,0.1024924,0.0387681402473111,0.31268138,5348,0.0893055408053911 -5557.467604875565,2.247718572616577,57672.29211544991,68876,0,57672.29211544991,0.1626802,2472,0.0536428818069181,63235.09067606926,0.09694151,0.0373078157687922,0.30519754,5348,0.0869208414995607 -5691.591963291168,2.307018518447876,59112.44225859642,70607,0,59112.44225859642,0.16090697,2472,0.0543131639347592,64809.5007917881,0.07881192,0.0299819644603755,0.30443206,5348,0.0855788447242148 -5824.119750261307,2.368692636489868,60552.34362244606,72323,0,60552.34362244606,0.15847631,2472,0.051266426990027016,66382.07028150558,0.07315466,0.027833001988071572,0.2974776,5348,0.08360929549996621 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index 5f0e32734..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,775 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,42.525078,32.340534,,,,,,,,,,,,,, -1,,,31.690332,1.115215625845949,30.757864,1.1779352558965794,5348.0,30.871214,1.190014827453131,2472.0,62.33279633522034,239.9211916923523,62.33279633522034,177.5883333683014,0.0,0.0 -100,27.41089,7.3506327,,,,,,,,,,,,,, -200,2.985887,6.135008,,,,,,,,,,,,,, -300,0.36946872,5.8703218,,,,,,,,,,,,,, -400,0.2716549,5.8151875,,,,,,,,,,,,,, -500,0.4552,5.808858,,,,,,,,,,,,,, -600,0.3759365,5.806507,,,,,,,,,,,,,, -700,0.4537559,5.773753,,,,,,,,,,,,,, -800,0.7711817,5.803458,,,,,,,,,,,,,, -900,0.41134396,5.7702417,,,,,,,,,,,,,, -1000,0.3028858,5.772437,,,,,,,,,,,,,, -1100,0.42878667,5.7979026,,,,,,,,,,,,,, -1200,0.25215468,5.7946343,,,,,,,,,,,,,, -1300,0.88243115,5.7801585,,,,,,,,,,,,,, -1400,1.8957255,5.723971,,,,,,,,,,,,,, -1500,3.02559,5.588041,,,,,,,,,,,,,, -1600,0.94793516,5.451843,,,,,,,,,,,,,, -1700,1.7213606,5.3253098,,,,,,,,,,,,,, -1748,,,5.9307604,0.944635537887994,6.0122747,0.8966179750330672,5348.0,5.9551396,0.899579550301627,2472.0,1502.6152069568634,1789.3208968639374,1502.6152069568634,286.587749004364,0.046248435974121,0.0 -1800,2.1194727,4.868768,,,,,,,,,,,,,, -1900,0.79285747,4.3611236,,,,,,,,,,,,,, -2000,1.0489984,3.940273,,,,,,,,,,,,,, -2100,0.8391045,3.7355275,,,,,,,,,,,,,, -2200,1.0049694,3.5099237,,,,,,,,,,,,,, -2300,1.0498129,3.3364792,,,,,,,,,,,,,, -2400,1.0317239,3.1993945,,,,,,,,,,,,,, -2500,1.1312522,3.105492,,,,,,,,,,,,,, -2600,1.6410557,2.9914873,,,,,,,,,,,,,, -2700,1.2395813,2.9838057,,,,,,,,,,,,,, -2800,1.1348026,2.861239,,,,,,,,,,,,,, -2900,1.2104957,2.7846498,,,,,,,,,,,,,, -3000,1.2204347,2.733573,,,,,,,,,,,,,, -3100,1.3100306,2.7234063,,,,,,,,,,,,,, -3200,1.1349946,2.6616726,,,,,,,,,,,,,, -3300,1.006934,2.5885773,,,,,,,,,,,,,, -3400,0.97036594,2.5879855,,,,,,,,,,,,,, -3500,1.1153752,2.4921448,,,,,,,,,,,,,, -3538,,,3.3182616,0.719244599525877,3.7128878,0.7624858800698997,5348.0,3.388314,0.7071273332927102,2472.0,2942.7661283016205,3349.630561113357,2942.7661283016205,406.619413614273,0.0941307544708252,0.0 -3600,1.0210419,2.4506838,,,,,,,,,,,,,, -3700,1.0718224,2.4505365,,,,,,,,,,,,,, -3800,0.98513085,2.3649335,,,,,,,,,,,,,, -3900,0.9294442,2.421522,,,,,,,,,,,,,, -4000,1.4426502,2.4170384,,,,,,,,,,,,,, -4100,1.0274006,2.3105297,,,,,,,,,,,,,, -4200,0.876999,2.297287,,,,,,,,,,,,,, -4300,0.92507964,2.2027795,,,,,,,,,,,,,, -4400,0.8819118,2.1707175,,,,,,,,,,,,,, -4500,0.9703253,2.1316695,,,,,,,,,,,,,, -4600,0.79822147,2.0721836,,,,,,,,,,,,,, -4700,0.78390956,2.093941,,,,,,,,,,,,,, -4800,0.8955137,2.1239693,,,,,,,,,,,,,, -4900,0.8903907,2.0648363,,,,,,,,,,,,,, -5000,0.8776193,2.0512767,,,,,,,,,,,,,, -5100,0.8850569,2.0460472,,,,,,,,,,,,,, -5200,0.8286507,2.013826,,,,,,,,,,,,,, -5287,,,0.7781294,0.2672749603344034,1.1462481,0.3319366268573139,5348.0,0.8411037,0.2719517396867954,2472.0,4383.109746217728,4925.651102542877,4383.109746217728,542.1664445400238,0.1464216709136963,0.0 -5300,0.87694603,1.978351,,,,,,,,,,,,,, -5400,0.8903149,1.9457814,,,,,,,,,,,,,, -5500,0.81592375,1.9348,,,,,,,,,,,,,, -5600,0.8397042,1.9317075,,,,,,,,,,,,,, -5700,0.818583,1.9414173,,,,,,,,,,,,,, -5800,0.88924235,1.9251812,,,,,,,,,,,,,, -5900,0.7669879,1.9372883,,,,,,,,,,,,,, -6000,1.0236658,1.909238,,,,,,,,,,,,,, -6100,0.8799917,1.7938198,,,,,,,,,,,,,, -6200,0.9180934,1.8936881,,,,,,,,,,,,,, -6300,0.8244866,1.8285553,,,,,,,,,,,,,, -6400,0.7979285,1.8204994,,,,,,,,,,,,,, -6500,0.7961064,1.7995458,,,,,,,,,,,,,, -6600,0.8250268,1.8626326,,,,,,,,,,,,,, -6700,0.7453776,1.8114331,,,,,,,,,,,,,, -6800,0.7514188,1.7782844,,,,,,,,,,,,,, -6900,0.7006417,1.7619518,,,,,,,,,,,,,, -7000,0.8083574,1.7677506,,,,,,,,,,,,,, -7026,,,0.5286893,0.1846181235534983,0.84069234,0.2531643125404288,5348.0,0.58280253,0.1943208823350192,2472.0,5823.235984802246,6501.536516189575,5823.235984802246,677.7943549156189,0.1984179019927978,0.0 -7100,0.7613477,1.7735558,,,,,,,,,,,,,, -7200,0.7831846,1.7408745,,,,,,,,,,,,,, -7300,0.80755806,1.7489482,,,,,,,,,,,,,, -7400,0.7587738,1.8029264,,,,,,,,,,,,,, -7500,0.7792135,1.7482251,,,,,,,,,,,,,, -7600,0.67004275,1.7170056,,,,,,,,,,,,,, -7700,0.6292826,1.7060577,,,,,,,,,,,,,, -7800,0.68693614,1.715799,,,,,,,,,,,,,, -7900,0.89480525,1.7164603,,,,,,,,,,,,,, -8000,0.7830219,1.7431313,,,,,,,,,,,,,, -8100,0.7265518,1.7020477,,,,,,,,,,,,,, -8200,0.7071579,1.6911454,,,,,,,,,,,,,, -8300,0.6874141,1.7138708,,,,,,,,,,,,,, -8400,0.7284955,1.7214243,,,,,,,,,,,,,, -8500,0.87889063,1.7103618,,,,,,,,,,,,,, -8600,0.8163687,1.6332359,,,,,,,,,,,,,, -8700,0.7312891,1.6908305,,,,,,,,,,,,,, -8791,,,0.44677088,0.159434864858988,0.7451185,0.2265271247477722,5348.0,0.4925594,0.1663721487620092,2472.0,7263.73501253128,8076.360379934311,7263.73501253128,811.9875221252441,0.2500596046447754,0.0 -8800,0.903504,1.6738129,,,,,,,,,,,,,, -8900,0.6360498,1.6761938,,,,,,,,,,,,,, -9000,0.68216544,1.6501206,,,,,,,,,,,,,, -9100,0.731653,1.6561022,,,,,,,,,,,,,, -9200,0.7692998,1.5883323,,,,,,,,,,,,,, -9300,0.75264794,1.618034,,,,,,,,,,,,,, -9400,0.71042913,1.6069592,,,,,,,,,,,,,, -9500,0.7053621,1.6691471,,,,,,,,,,,,,, -9600,0.7743356,1.6141111,,,,,,,,,,,,,, -9700,0.6296932,1.6364533,,,,,,,,,,,,,, -9800,0.6554393,1.6246436,,,,,,,,,,,,,, -9900,0.62127155,1.6805745,,,,,,,,,,,,,, -10000,0.6635452,1.6477392,,,,,,,,,,,,,, -10100,0.7748278,1.5852369,,,,,,,,,,,,,, -10200,0.64453256,1.6380804,,,,,,,,,,,,,, -10300,0.80177605,1.604669,,,,,,,,,,,,,, -10400,0.8088547,1.5960184,,,,,,,,,,,,,, -10500,0.6331517,1.5504185,,,,,,,,,,,,,, -10539,,,0.4121761,0.1454353404356338,0.67398584,0.2053930891993396,5348.0,0.4373133,0.1492088639733512,2472.0,8703.784867048264,9653.374977111816,8703.784867048264,948.816588640213,0.3072614669799804,0.0 -10600,0.67650586,1.5616491,,,,,,,,,,,,,, -10700,0.6302884,1.5940876,,,,,,,,,,,,,, -10800,0.6971849,1.5965849,,,,,,,,,,,,,, -10900,0.634451,1.6042074,,,,,,,,,,,,,, -11000,0.5931638,1.4950103,,,,,,,,,,,,,, -11100,0.71336395,1.5806898,,,,,,,,,,,,,, -11200,0.7104953,1.6334293,,,,,,,,,,,,,, -11300,0.6833037,1.5508507,,,,,,,,,,,,,, -11400,0.6613432,1.5856783,,,,,,,,,,,,,, -11500,0.6368841,1.4485941,,,,,,,,,,,,,, -11600,0.71955764,1.4778069,,,,,,,,,,,,,, -11700,0.8074399,1.5497866,,,,,,,,,,,,,, -11800,0.629428,1.5351492,,,,,,,,,,,,,, -11900,0.60416174,1.52555,,,,,,,,,,,,,, -12000,0.86722195,1.584731,,,,,,,,,,,,,, -12100,0.69548297,1.5560594,,,,,,,,,,,,,, -12200,0.6866737,1.516217,,,,,,,,,,,,,, -12248,,,0.36185822,0.1293897432625018,0.6489348,0.1973217992411442,5348.0,0.411901,0.1403936384132594,2472.0,10143.811375379562,11232.8799161911,10143.811375379562,1088.169734954834,0.355985164642334,0.0 -12300,0.85802454,1.5753084,,,,,,,,,,,,,, -12400,0.68773705,1.5171494,,,,,,,,,,,,,, -12500,0.61123,1.5104089,,,,,,,,,,,,,, -12600,0.7517943,1.5519046,,,,,,,,,,,,,, -12700,0.5583406,1.515472,,,,,,,,,,,,,, -12800,0.84060454,1.5074304,,,,,,,,,,,,,, -12900,0.6632181,1.5159254,,,,,,,,,,,,,, -13000,0.6520058,1.4335374,,,,,,,,,,,,,, -13100,0.60985297,1.5174593,,,,,,,,,,,,,, -13200,0.5642421,1.5092762,,,,,,,,,,,,,, -13300,0.6145579,1.5131412,,,,,,,,,,,,,, -13400,0.6397633,1.4369586,,,,,,,,,,,,,, -13500,0.6761977,1.4354331,,,,,,,,,,,,,, -13600,0.6229076,1.541834,,,,,,,,,,,,,, -13700,0.81431144,1.4946159,,,,,,,,,,,,,, -13800,0.58265746,1.483322,,,,,,,,,,,,,, -13900,0.64465076,1.468347,,,,,,,,,,,,,, -13977,,,0.30664057,0.1145063574195897,0.6024696,0.1828687836102609,5348.0,0.3806148,0.1301362906993277,2472.0,11584.330972671509,12809.234954595566,11584.330972671509,1223.8755309581757,0.4083857536315918,0.0 -14000,0.6349198,1.4764149,,,,,,,,,,,,,, -14100,0.629891,1.448917,,,,,,,,,,,,,, -14200,0.68855774,1.496419,,,,,,,,,,,,,, -14300,0.8867103,1.5418147,,,,,,,,,,,,,, -14400,0.6843044,1.5095788,,,,,,,,,,,,,, -14500,0.7960741,1.4492047,,,,,,,,,,,,,, -14600,0.6772844,1.4595308,,,,,,,,,,,,,, -14700,0.7016332,1.468982,,,,,,,,,,,,,, -14800,0.9664271,1.4983894,,,,,,,,,,,,,, -14900,0.68946034,1.4859788,,,,,,,,,,,,,, -15000,0.6671635,1.461488,,,,,,,,,,,,,, -15100,0.6402296,1.4332796,,,,,,,,,,,,,, -15200,0.6010946,1.4359589,,,,,,,,,,,,,, -15300,0.7214873,1.5446442,,,,,,,,,,,,,, -15400,0.6895794,1.5073605,,,,,,,,,,,,,, -15500,0.7290965,1.4428568,,,,,,,,,,,,,, -15600,0.61382,1.4608594,,,,,,,,,,,,,, -15700,0.58984554,1.3879924,,,,,,,,,,,,,, -15712,,,0.28140113,0.1030010267650516,0.5813166,0.1775683790802977,5348.0,0.36267507,0.1221132167448662,2472.0,13024.30504345894,14383.716886281967,13024.30504345894,1358.2513551712036,0.4628069400787353,0.0 -15800,0.66258454,1.4701204,,,,,,,,,,,,,, -15900,0.71034694,1.4307394,,,,,,,,,,,,,, -16000,0.6638378,1.4987391,,,,,,,,,,,,,, -16100,0.5996323,1.437057,,,,,,,,,,,,,, -16200,0.7175058,1.4325762,,,,,,,,,,,,,, -16300,0.9194444,1.4848478,,,,,,,,,,,,,, -16400,0.6839324,1.4923443,,,,,,,,,,,,,, -16500,0.68452954,1.422442,,,,,,,,,,,,,, -16600,0.72471845,1.4265043,,,,,,,,,,,,,, -16700,0.7374366,1.474597,,,,,,,,,,,,,, -16800,0.6685376,1.4595925,,,,,,,,,,,,,, -16900,0.7345797,1.4087547,,,,,,,,,,,,,, -17000,0.5728051,1.4491445,,,,,,,,,,,,,, -17100,0.64172953,1.4260415,,,,,,,,,,,,,, -17200,0.67535335,1.4132596,,,,,,,,,,,,,, -17300,0.93315417,1.447543,,,,,,,,,,,,,, -17400,0.65523607,1.4539319,,,,,,,,,,,,,, -17428,,,0.27844656,0.1055253744630417,0.5747308,0.1757339949988897,5348.0,0.35867026,0.1208132756484471,2472.0,14464.859641313553,15959.115698099136,14464.859641313553,1492.9678757190704,0.5146021842956543,0.0 -17500,0.7708474,1.4489726,,,,,,,,,,,,,, -17600,0.73013985,1.4632909,,,,,,,,,,,,,, -17700,0.747329,1.4691441,,,,,,,,,,,,,, -17800,0.66233677,1.374242,,,,,,,,,,,,,, -17900,0.78987825,1.4043577,,,,,,,,,,,,,, -18000,0.7599321,1.4350281,,,,,,,,,,,,,, -18100,0.7943564,1.3739784,,,,,,,,,,,,,, -18200,0.61764956,1.4353515,,,,,,,,,,,,,, -18300,0.6530398,1.4537189,,,,,,,,,,,,,, -18400,0.6178988,1.4366282,,,,,,,,,,,,,, -18500,0.6522559,1.4446888,,,,,,,,,,,,,, -18600,0.79476064,1.3444197,,,,,,,,,,,,,, -18700,0.6257711,1.3519466,,,,,,,,,,,,,, -18800,0.67903143,1.4397645,,,,,,,,,,,,,, -18900,0.6488216,1.3298118,,,,,,,,,,,,,, -19000,0.67685014,1.4059786,,,,,,,,,,,,,, -19100,0.6058924,1.4097373,,,,,,,,,,,,,, -19133,,,0.2847909,0.1031915174912081,0.5575726,0.1692557227956013,5348.0,0.33965674,0.1150650986127191,2472.0,15904.858520269394,17536.534660100937,15904.858520269394,1630.2590498924255,0.5669634342193604,0.0 -19200,0.63094276,1.4054681,,,,,,,,,,,,,, -19300,0.71036536,1.4244289,,,,,,,,,,,,,, -19400,0.7468902,1.4874426,,,,,,,,,,,,,, -19500,0.66755897,1.4212843,,,,,,,,,,,,,, -19600,0.7611548,1.3989382,,,,,,,,,,,,,, -19700,0.824516,1.3801866,,,,,,,,,,,,,, -19800,0.74253607,1.3902566,,,,,,,,,,,,,, -19900,0.7902397,1.3915379,,,,,,,,,,,,,, -20000,0.63846755,1.3794097,,,,,,,,,,,,,, -20100,0.67128456,1.3404584,,,,,,,,,,,,,, -20200,0.692735,1.3565227,,,,,,,,,,,,,, -20300,0.6475328,1.3971497,,,,,,,,,,,,,, -20400,0.67693645,1.4524903,,,,,,,,,,,,,, -20500,0.75675964,1.4208649,,,,,,,,,,,,,, -20600,0.7194002,1.329979,,,,,,,,,,,,,, -20700,0.66030985,1.328702,,,,,,,,,,,,,, -20800,0.60292596,1.3712682,,,,,,,,,,,,,, -20853,,,0.27910382,0.1014535573904555,0.5391951,0.1646697625920812,5348.0,0.32976592,0.1125261511587756,2472.0,17345.28639960289,19115.42906999588,17345.28639960289,1768.594096660614,0.622377872467041,0.0 -20900,0.6777336,1.3850055,,,,,,,,,,,,,, -21000,0.75744224,1.3837643,,,,,,,,,,,,,, -21100,0.68388283,1.3221377,,,,,,,,,,,,,, -21200,0.8165781,1.390056,,,,,,,,,,,,,, -21300,0.7606935,1.3753808,,,,,,,,,,,,,, -21400,0.96662617,1.3790796,,,,,,,,,,,,,, -21500,0.58575946,1.3767939,,,,,,,,,,,,,, -21600,0.7104824,1.3938487,,,,,,,,,,,,,, -21700,0.6954826,1.377426,,,,,,,,,,,,,, -21800,0.8411045,1.3476889,,,,,,,,,,,,,, -21900,0.9111315,1.37075,,,,,,,,,,,,,, -22000,0.76053303,1.3892992,,,,,,,,,,,,,, -22100,0.72930115,1.3854351,,,,,,,,,,,,,, -22200,0.78499264,1.3158367,,,,,,,,,,,,,, -22300,0.60037273,1.3607633,,,,,,,,,,,,,, -22400,0.64410406,1.3886261,,,,,,,,,,,,,, -22500,0.61589396,1.3794464,,,,,,,,,,,,,, -22545,,,0.2605467,0.0943512707068628,0.52861744,0.1601224210008013,5348.0,0.313901,0.1081591615379928,2472.0,18785.17260837555,20691.28274512291,18785.17260837555,1904.43694972992,0.6713097095489502,0.0 -22600,0.67075104,1.4425733,,,,,,,,,,,,,, -22700,0.62078464,1.3652332,,,,,,,,,,,,,, -22800,0.6940381,1.3524343,,,,,,,,,,,,,, -22900,0.65916437,1.3467134,,,,,,,,,,,,,, -23000,0.75009006,1.3848951,,,,,,,,,,,,,, -23100,0.6652139,1.3105423,,,,,,,,,,,,,, -23200,0.846114,1.3804752,,,,,,,,,,,,,, -23300,0.63316005,1.3353233,,,,,,,,,,,,,, -23400,0.66972345,1.3630911,,,,,,,,,,,,,, -23500,0.67673284,1.2989761,,,,,,,,,,,,,, -23600,0.67098737,1.3376441,,,,,,,,,,,,,, -23700,0.62742877,1.2924438,,,,,,,,,,,,,, -23800,0.6776282,1.2891659,,,,,,,,,,,,,, -23900,0.8966541,1.3127266,,,,,,,,,,,,,, -24000,0.663567,1.3209835,,,,,,,,,,,,,, -24100,0.7091203,1.3852853,,,,,,,,,,,,,, -24200,0.73548293,1.3173496,,,,,,,,,,,,,, -24263,,,0.2430587,0.0905367171994425,0.5190067,0.1562702144298444,5348.0,0.31107822,0.1044827656246826,2472.0,20225.37378954888,22267.05878567696,20225.37378954888,2039.88334441185,0.7225484848022461,0.0 -24300,0.5827775,1.2774246,,,,,,,,,,,,,, -24400,0.6489622,1.326741,,,,,,,,,,,,,, -24500,0.7052378,1.3101181,,,,,,,,,,,,,, -24600,0.72259074,1.3753557,,,,,,,,,,,,,, -24700,0.6784053,1.309767,,,,,,,,,,,,,, -24800,0.72777075,1.3460039,,,,,,,,,,,,,, -24900,0.6425148,1.2946113,,,,,,,,,,,,,, -25000,0.69626015,1.3234633,,,,,,,,,,,,,, -25100,0.6073204,1.2862146,,,,,,,,,,,,,, -25200,0.8420933,1.3714956,,,,,,,,,,,,,, -25300,0.75950795,1.3705635,,,,,,,,,,,,,, -25400,0.61283034,1.3171268,,,,,,,,,,,,,, -25500,0.6561485,1.3671585,,,,,,,,,,,,,, -25600,0.7352251,1.3696347,,,,,,,,,,,,,, -25700,0.8187825,1.327128,,,,,,,,,,,,,, -25800,0.8329214,1.2904203,,,,,,,,,,,,,, -25900,0.7528057,1.2518649,,,,,,,,,,,,,, -25979,,,0.22930108,0.0865750280393333,0.51106757,0.153431746430192,5348.0,0.3014969,0.1013954055206873,2472.0,21665.884313106537,23842.28617978096,21665.884313106537,2174.472556114197,0.7736678123474121,0.0 -26000,0.79257447,1.2894149,,,,,,,,,,,,,, -26100,0.7336763,1.3104405,,,,,,,,,,,,,, -26200,0.67640686,1.3585988,,,,,,,,,,,,,, -26300,0.81842697,1.3157065,,,,,,,,,,,,,, -26400,0.6203249,1.2957482,,,,,,,,,,,,,, -26500,0.67208207,1.2719007,,,,,,,,,,,,,, -26600,0.7511589,1.2855859,,,,,,,,,,,,,, -26700,0.8236297,1.3296746,,,,,,,,,,,,,, -26800,0.6490396,1.2679752,,,,,,,,,,,,,, -26900,0.7052075,1.3081888,,,,,,,,,,,,,, -27000,0.8219254,1.311355,,,,,,,,,,,,,, -27100,0.65815926,1.2874647,,,,,,,,,,,,,, -27200,0.7007243,1.3411052,,,,,,,,,,,,,, -27300,0.79618734,1.3577542,,,,,,,,,,,,,, -27400,0.7233063,1.3237337,,,,,,,,,,,,,, -27500,0.67833966,1.354761,,,,,,,,,,,,,, -27600,0.78889114,1.3350545,,,,,,,,,,,,,, -27691,,,0.22043964,0.0796043816673655,0.49129182,0.1478706662676076,5348.0,0.2954965,0.0986939654296914,2472.0,23106.38948559761,25422.112336874008,23106.38948559761,2313.663006067276,0.8280179500579834,0.0 -27700,0.76581514,1.3324856,,,,,,,,,,,,,, -27800,1.0099497,1.3393847,,,,,,,,,,,,,, -27900,0.704316,1.2290417,,,,,,,,,,,,,, -28000,0.5649327,1.2604208,,,,,,,,,,,,,, -28100,0.65488577,1.2675437,,,,,,,,,,,,,, -28200,0.7180008,1.3030659,,,,,,,,,,,,,, -28300,0.69551617,1.332069,,,,,,,,,,,,,, -28400,0.62211835,1.3021077,,,,,,,,,,,,,, -28500,0.92631376,1.3262162,,,,,,,,,,,,,, -28600,0.73987323,1.3948311,,,,,,,,,,,,,, -28700,0.75252986,1.3187624,,,,,,,,,,,,,, -28800,0.7338685,1.285709,,,,,,,,,,,,,, -28900,0.8411006,1.2622514,,,,,,,,,,,,,, -29000,0.67919105,1.280183,,,,,,,,,,,,,, -29100,0.6735059,1.2758497,,,,,,,,,,,,,, -29200,0.69406265,1.2567983,,,,,,,,,,,,,, -29300,0.69007456,1.3018129,,,,,,,,,,,,,, -29397,,,0.23616512,0.0882897815912636,0.48435923,0.1464224683086013,5348.0,0.28681365,0.0966628074665366,2472.0,24547.143973588943,26998.507881879807,24547.143973588943,2449.174416303634,0.8814880847930908,0.0 -29400,0.679657,1.2645792,,,,,,,,,,,,,, -29500,0.6688042,1.3333709,,,,,,,,,,,,,, -29600,0.6449133,1.2903216,,,,,,,,,,,,,, -29700,0.7060782,1.2446193,,,,,,,,,,,,,, -29800,0.72928894,1.3437042,,,,,,,,,,,,,, -29900,0.66989684,1.2843555,,,,,,,,,,,,,, -30000,0.7022911,1.2767998,,,,,,,,,,,,,, -30100,0.7813895,1.2646891,,,,,,,,,,,,,, -30200,0.676281,1.3263224,,,,,,,,,,,,,, -30300,0.5880114,1.2964734,,,,,,,,,,,,,, -30400,0.69099367,1.2500993,,,,,,,,,,,,,, -30500,0.85042113,1.2624066,,,,,,,,,,,,,, -30600,0.724407,1.2675353,,,,,,,,,,,,,, -30700,0.8992986,1.3330088,,,,,,,,,,,,,, -30800,0.76468295,1.2545618,,,,,,,,,,,,,, -30900,0.8692729,1.3017043,,,,,,,,,,,,,, -31000,0.77764386,1.3069029,,,,,,,,,,,,,, -31100,0.73306984,1.2662228,,,,,,,,,,,,,, -31119,,,0.21892151,0.0790563160950412,0.47866896,0.1437481294109696,5348.0,0.28513002,0.0958097211220116,2472.0,25987.65677118301,28574.90484571457,25987.65677118301,2584.9278602600098,0.9351787567138672,0.0 -31200,0.6728823,1.2556598,,,,,,,,,,,,,, -31300,0.81326014,1.2833146,,,,,,,,,,,,,, -31400,0.71054256,1.2424498,,,,,,,,,,,,,, -31500,0.7076191,1.2614564,,,,,,,,,,,,,, -31600,0.6885779,1.2567667,,,,,,,,,,,,,, -31700,0.692914,1.2640398,,,,,,,,,,,,,, -31800,0.77905977,1.2854123,,,,,,,,,,,,,, -31900,0.6791264,1.2947679,,,,,,,,,,,,,, -32000,0.6996264,1.2410575,,,,,,,,,,,,,, -32100,0.685878,1.2047701,,,,,,,,,,,,,, -32200,0.741509,1.3062931,,,,,,,,,,,,,, -32300,0.74759686,1.1853801,,,,,,,,,,,,,, -32400,0.7671275,1.2794669,,,,,,,,,,,,,, -32500,0.74828094,1.2754383,,,,,,,,,,,,,, -32600,0.78794837,1.3428457,,,,,,,,,,,,,, -32700,0.7196764,1.2582175,,,,,,,,,,,,,, -32800,0.88993466,1.2691253,,,,,,,,,,,,,, -32826,,,0.22863874,0.0829305128191956,0.47382867,0.1411317184316981,5348.0,0.27610952,0.093636382101436,2472.0,27427.88367891312,30153.691357135773,27427.88367891312,2723.3589627742767,0.9874227046966552,0.0 -32900,0.77217567,1.2692248,,,,,,,,,,,,,, -33000,0.6235415,1.2316068,,,,,,,,,,,,,, -33100,0.63943356,1.2539934,,,,,,,,,,,,,, -33200,0.7470538,1.236482,,,,,,,,,,,,,, -33300,0.7969663,1.2976938,,,,,,,,,,,,,, -33400,0.6904364,1.2742352,,,,,,,,,,,,,, -33500,0.8335603,1.3082315,,,,,,,,,,,,,, -33600,0.6559369,1.2322885,,,,,,,,,,,,,, -33700,0.79544836,1.2393576,,,,,,,,,,,,,, -33800,0.67373323,1.2976153,,,,,,,,,,,,,, -33900,0.6603475,1.2443557,,,,,,,,,,,,,, -34000,0.7105501,1.2134674,,,,,,,,,,,,,, -34100,0.78034216,1.2360797,,,,,,,,,,,,,, -34200,0.7639972,1.2301508,,,,,,,,,,,,,, -34300,0.7472997,1.2966228,,,,,,,,,,,,,, -34400,0.72050065,1.2433174,,,,,,,,,,,,,, -34500,0.80687416,1.2444792,,,,,,,,,,,,,, -34538,,,0.21834032,0.077597323017573,0.46095514,0.1372602025546212,5348.0,0.26683512,0.0901427904048097,2472.0,28867.750519037247,31728.63215994835,28867.750519037247,2858.290730953217,1.050553798675537,0.0 -34600,0.749585,1.2740966,,,,,,,,,,,,,, -34700,0.76572067,1.2205616,,,,,,,,,,,,,, -34800,0.83205336,1.2518175,,,,,,,,,,,,,, -34900,0.84792674,1.2379622,,,,,,,,,,,,,, -35000,0.7188042,1.3044409,,,,,,,,,,,,,, -35100,0.782596,1.2605655,,,,,,,,,,,,,, -35200,0.66484755,1.2247725,,,,,,,,,,,,,, -35300,0.8299708,1.244805,,,,,,,,,,,,,, -35400,0.6851922,1.2077944,,,,,,,,,,,,,, -35500,0.89216185,1.206559,,,,,,,,,,,,,, -35600,0.7528281,1.2561598,,,,,,,,,,,,,, -35700,0.66744214,1.2110695,,,,,,,,,,,,,, -35800,0.7710812,1.1996709,,,,,,,,,,,,,, -35900,0.72294044,1.2429541,,,,,,,,,,,,,, -36000,0.74651605,1.2803204,,,,,,,,,,,,,, -36100,0.6750094,1.1896291,,,,,,,,,,,,,, -36200,0.7235647,1.2567499,,,,,,,,,,,,,, -36249,,,0.17301418,0.0649682769082541,0.4529431,0.1348368846365505,5348.0,0.26671857,0.0898381167103365,2472.0,30308.22143220901,33309.32222747803,30308.22143220901,2998.3748049736023,1.1091015338897705,0.0 -36300,0.7476482,1.2447664,,,,,,,,,,,,,, -36400,0.85658664,1.1818117,,,,,,,,,,,,,, -36500,0.70269436,1.239956,,,,,,,,,,,,,, -36600,0.7223771,1.2221385,,,,,,,,,,,,,, -36700,0.9052443,1.2144611,,,,,,,,,,,,,, -36800,0.747634,1.2192439,,,,,,,,,,,,,, -36900,0.7670277,1.1891303,,,,,,,,,,,,,, -37000,0.6983262,1.3295919,,,,,,,,,,,,,, -37100,0.9051548,1.2199204,,,,,,,,,,,,,, -37200,0.67006254,1.2357548,,,,,,,,,,,,,, -37300,0.6099799,1.204821,,,,,,,,,,,,,, -37400,0.82900333,1.1827176,,,,,,,,,,,,,, -37500,0.7688889,1.2291478,,,,,,,,,,,,,, -37600,0.68591106,1.129728,,,,,,,,,,,,,, -37700,0.6699265,1.2407166,,,,,,,,,,,,,, -37800,0.6261045,1.2038356,,,,,,,,,,,,,, -37900,0.8340644,1.2043518,,,,,,,,,,,,,, -37952,,,0.19108309,0.0690904099174251,0.43196315,0.1302798883922106,5348.0,0.25399256,0.0842524323116608,2472.0,31748.56366539001,34886.196565151215,31748.56366539001,3134.772925376892,1.1662836074829102,0.0 -38000,0.6965922,1.2234735,,,,,,,,,,,,,, -38100,0.80952746,1.234104,,,,,,,,,,,,,, -38200,0.6266599,1.1956624,,,,,,,,,,,,,, -38300,0.8614739,1.2261215,,,,,,,,,,,,,, -38400,0.7481023,1.281858,,,,,,,,,,,,,, -38500,0.75192356,1.2404006,,,,,,,,,,,,,, -38600,0.62679845,1.1531031,,,,,,,,,,,,,, -38700,0.79979736,1.2131492,,,,,,,,,,,,,, -38800,0.77666587,1.2140801,,,,,,,,,,,,,, -38900,0.78575647,1.2218671,,,,,,,,,,,,,, -39000,0.7880033,1.2556325,,,,,,,,,,,,,, -39100,0.78701615,1.1866282,,,,,,,,,,,,,, -39200,0.74897027,1.2260705,,,,,,,,,,,,,, -39300,0.6354719,1.2321373,,,,,,,,,,,,,, -39400,0.8701678,1.2173598,,,,,,,,,,,,,, -39500,0.8585838,1.2092043,,,,,,,,,,,,,, -39600,0.70415896,1.1629846,,,,,,,,,,,,,, -39671,,,0.23817594,0.0868124298259412,0.43086866,0.1289668555760448,5348.0,0.24960111,0.0832977880689781,2472.0,33188.59181380272,36462.04642629624,33188.59181380272,3270.4538078308105,1.2276201248168943,0.0 -39700,0.682254,1.1924309,,,,,,,,,,,,,, -39800,0.8854155,1.1487197,,,,,,,,,,,,,, -39900,0.8687687,1.2482437,,,,,,,,,,,,,, -40000,0.9390361,1.2340128,,,,,,,,,,,,,, -40100,0.79087865,1.1690174,,,,,,,,,,,,,, -40200,0.70005894,1.2322055,,,,,,,,,,,,,, -40300,0.72864676,1.1895434,,,,,,,,,,,,,, -40400,0.75424135,1.1857393,,,,,,,,,,,,,, -40500,0.79273194,1.1514149,,,,,,,,,,,,,, -40600,0.7294973,1.1726884,,,,,,,,,,,,,, -40700,0.94256663,1.2332376,,,,,,,,,,,,,, -40800,0.7769335,1.2191594,,,,,,,,,,,,,, -40900,0.7453344,1.1961184,,,,,,,,,,,,,, -41000,0.938417,1.2465465,,,,,,,,,,,,,, -41100,0.76735544,1.1839255,,,,,,,,,,,,,, -41200,0.8269171,1.226946,,,,,,,,,,,,,, -41300,0.7754208,1.1662457,,,,,,,,,,,,,, -41382,,,0.2418203,0.0883383202784547,0.4284742,0.1280110449231007,5348.0,0.24681735,0.0825665712022424,2472.0,34628.68599200249,38034.93848752976,34628.68599200249,3403.1156027317047,1.2870965003967283,0.0 -41400,0.84433436,1.1592877,,,,,,,,,,,,,, -41500,0.74794555,1.1949329,,,,,,,,,,,,,, -41600,0.785486,1.1903745,,,,,,,,,,,,,, -41700,0.71447784,1.2019715,,,,,,,,,,,,,, -41800,0.89750266,1.1750312,,,,,,,,,,,,,, -41900,0.7717154,1.1863497,,,,,,,,,,,,,, -42000,0.8229836,1.1818955,,,,,,,,,,,,,, -42100,0.7925433,1.2041783,,,,,,,,,,,,,, -42200,0.74849606,1.2471956,,,,,,,,,,,,,, -42300,0.73089695,1.2204114,,,,,,,,,,,,,, -42400,0.727402,1.1458375,,,,,,,,,,,,,, -42500,0.757533,1.1983941,,,,,,,,,,,,,, -42600,0.75487316,1.2181863,,,,,,,,,,,,,, -42700,0.85913056,1.1266681,,,,,,,,,,,,,, -42800,0.712805,1.1830786,,,,,,,,,,,,,, -42900,0.767924,1.1966586,,,,,,,,,,,,,, -43000,0.73363334,1.1731437,,,,,,,,,,,,,, -43100,0.817617,1.1302309,,,,,,,,,,,,,, -43104,,,0.2755499,0.1017133564588683,0.4144993,0.1238691987603425,5348.0,0.23967984,0.0812463185261917,2472.0,36068.79007220268,39607.35124826431,36068.79007220268,3535.2889487743378,1.3462481498718262,0.0 -43200,0.9254233,1.11247,,,,,,,,,,,,,, -43300,0.83417404,1.1341618,,,,,,,,,,,,,, -43400,0.80104613,1.1562423,,,,,,,,,,,,,, -43500,0.90341455,1.1902944,,,,,,,,,,,,,, -43600,0.82933164,1.1983967,,,,,,,,,,,,,, -43700,0.7640085,1.1549124,,,,,,,,,,,,,, -43800,0.9064648,1.1413844,,,,,,,,,,,,,, -43900,0.65699935,1.1415582,,,,,,,,,,,,,, -44000,0.7597639,1.1584759,,,,,,,,,,,,,, -44100,0.77114195,1.1807201,,,,,,,,,,,,,, -44200,0.80208945,1.1204025,,,,,,,,,,,,,, -44300,0.8555353,1.2162648,,,,,,,,,,,,,, -44400,0.8489752,1.1851423,,,,,,,,,,,,,, -44500,0.9610152,1.1261767,,,,,,,,,,,,,, -44600,0.72604114,1.2310455,,,,,,,,,,,,,, -44700,0.75433195,1.1259171,,,,,,,,,,,,,, -44800,0.8793706,1.1401848,,,,,,,,,,,,,, -44819,,,0.23959982,0.0853581710813722,0.40491655,0.120325941087307,5348.0,0.23447508,0.0793573416204578,2472.0,37508.72058033943,41181.65663433075,37508.72058033943,3669.527673482895,1.4051265716552734,0.0 -44900,0.84065294,1.1812723,,,,,,,,,,,,,, -45000,0.91228145,1.1745409,,,,,,,,,,,,,, -45100,0.7545146,1.1830574,,,,,,,,,,,,,, -45200,0.9182517,1.1888155,,,,,,,,,,,,,, -45300,0.78045577,1.1624295,,,,,,,,,,,,,, -45400,0.66178143,1.1238465,,,,,,,,,,,,,, -45500,0.70951223,1.1465617,,,,,,,,,,,,,, -45600,0.7284717,1.1307862,,,,,,,,,,,,,, -45700,0.76782066,1.1569052,,,,,,,,,,,,,, -45800,0.91894656,1.1656572,,,,,,,,,,,,,, -45900,0.98241603,1.1449479,,,,,,,,,,,,,, -46000,0.7439406,1.1151861,,,,,,,,,,,,,, -46100,0.8190247,1.1103199,,,,,,,,,,,,,, -46200,0.77804434,1.1517308,,,,,,,,,,,,,, -46300,0.8228455,1.1617891,,,,,,,,,,,,,, -46400,0.86166143,1.1638682,,,,,,,,,,,,,, -46500,0.7497151,1.1707166,,,,,,,,,,,,,, -46534,,,0.22094382,0.0819313462416354,0.40488222,0.1207990190872491,5348.0,0.23017377,0.0778339731480917,2472.0,38949.01678466797,42756.07475566864,38949.01678466797,3803.516398906708,1.4624698162078855,0.0 -46600,0.7732558,1.1648111,,,,,,,,,,,,,, -46700,0.7973712,1.1108702,,,,,,,,,,,,,, -46800,0.7171082,1.1324958,,,,,,,,,,,,,, -46900,0.9534844,1.1365161,,,,,,,,,,,,,, -47000,0.7879024,1.1319923,,,,,,,,,,,,,, -47100,0.80701715,1.1435691,,,,,,,,,,,,,, -47200,0.8012681,1.1235421,,,,,,,,,,,,,, -47300,0.7132016,1.1679008,,,,,,,,,,,,,, -47400,0.8369312,1.1513253,,,,,,,,,,,,,, -47500,0.8106727,1.0877227,,,,,,,,,,,,,, -47600,0.7774314,1.0648072,,,,,,,,,,,,,, -47700,0.76245433,1.1378677,,,,,,,,,,,,,, -47800,0.8477877,1.1896818,,,,,,,,,,,,,, -47900,0.77997816,1.1092118,,,,,,,,,,,,,, -48000,0.6843855,1.1005404,,,,,,,,,,,,,, -48100,0.93434685,1.1977419,,,,,,,,,,,,,, -48200,0.9631234,1.1316054,,,,,,,,,,,,,, -48243,,,0.18676206,0.0701747707968547,0.39040998,0.1165316624347104,5348.0,0.22285704,0.0749294172607803,2472.0,40389.19332933426,44331.88853049278,40389.19332933426,3939.021583557129,1.5191435813903809,0.0 -48300,0.84460825,1.1285938,,,,,,,,,,,,,, -48400,0.8310586,1.1381183,,,,,,,,,,,,,, -48500,0.81441253,1.1345263,,,,,,,,,,,,,, -48600,0.82686013,1.1386558,,,,,,,,,,,,,, -48700,0.8034368,1.1305686,,,,,,,,,,,,,, -48800,0.8852743,1.1391968,,,,,,,,,,,,,, -48900,0.92603636,1.1489295,,,,,,,,,,,,,, -49000,0.891815,1.1668136,,,,,,,,,,,,,, -49100,1.3950828,1.0897138,,,,,,,,,,,,,, -49200,0.7926254,1.1169398,,,,,,,,,,,,,, -49300,1.0045085,1.0912758,,,,,,,,,,,,,, -49400,0.9302204,1.0728494,,,,,,,,,,,,,, -49500,0.8718718,1.0869595,,,,,,,,,,,,,, -49600,0.8644204,1.0824784,,,,,,,,,,,,,, -49700,0.92622495,1.1235936,,,,,,,,,,,,,, -49800,0.71358854,1.1091658,,,,,,,,,,,,,, -49900,0.90163493,1.0980719,,,,,,,,,,,,,, -49950,,,0.20000587,0.0746332694281664,0.3872141,0.1135290653330372,5348.0,0.21055159,0.0707249202770499,2472.0,41829.21695446968,45906.16408348084,41829.21695446968,4073.134298324585,1.579833984375,0.0 -50000,0.8593642,1.1272403,,,,,,,,,,,,,, -50100,0.83928317,1.0678872,,,,,,,,,,,,,, -50200,0.76090825,1.1353011,,,,,,,,,,,,,, -50300,1.2502522,1.0854642,,,,,,,,,,,,,, -50400,0.8756032,1.1058401,,,,,,,,,,,,,, -50500,1.0927467,1.0865936,,,,,,,,,,,,,, -50600,0.7614249,1.0873523,,,,,,,,,,,,,, -50700,0.8053245,1.0884871,,,,,,,,,,,,,, -50800,0.8327334,1.1448774,,,,,,,,,,,,,, -50900,0.81688535,1.0710233,,,,,,,,,,,,,, -51000,0.8575185,1.1174121,,,,,,,,,,,,,, -51100,0.87717867,1.0937169,,,,,,,,,,,,,, -51200,0.9216813,1.1022635,,,,,,,,,,,,,, -51300,0.8856544,1.0756251,,,,,,,,,,,,,, -51400,1.0462416,1.1289859,,,,,,,,,,,,,, -51500,0.96043026,1.1428924,,,,,,,,,,,,,, -51600,0.8931301,1.1085547,,,,,,,,,,,,,, -51667,,,0.17628975,0.0666766878543205,0.3812173,0.1120615580679108,5348.0,0.20984745,0.0708264781752076,2472.0,43269.77072787285,47481.20555949211,43269.77072787285,4207.490744113922,1.6347463130950928,0.0 -51700,0.7871799,1.0765206,,,,,,,,,,,,,, -51800,0.800695,1.1079966,,,,,,,,,,,,,, -51900,0.9629984,1.1062142,,,,,,,,,,,,,, -52000,1.0410028,1.1245737,,,,,,,,,,,,,, -52100,0.8192639,1.086683,,,,,,,,,,,,,, -52200,1.0860164,1.1190324,,,,,,,,,,,,,, -52300,0.7450001,1.0159038,,,,,,,,,,,,,, -52400,0.8241703,1.0724125,,,,,,,,,,,,,, -52500,1.0035313,1.1136663,,,,,,,,,,,,,, -52600,0.8988152,1.0673426,,,,,,,,,,,,,, -52700,0.9173837,1.092892,,,,,,,,,,,,,, -52800,0.9346498,1.0147803,,,,,,,,,,,,,, -52900,0.7880412,0.9905575,,,,,,,,,,,,,, -53000,1.1504683,1.0653732,,,,,,,,,,,,,, -53100,0.89340913,1.0857574,,,,,,,,,,,,,, -53200,0.86864483,1.0178854,,,,,,,,,,,,,, -53300,0.8131616,1.0618607,,,,,,,,,,,,,, -53392,,,0.17431833,0.0662895019472981,0.36423317,0.1081224596194135,5348.0,0.20121741,0.0685109580972112,2472.0,44709.96631407738,49057.63871669769,44709.96631407738,4343.593742609024,1.6926517486572266,0.0 -53400,0.83774304,1.0930207,,,,,,,,,,,,,, -53500,0.87406796,1.1353112,,,,,,,,,,,,,, -53600,0.8260946,1.0379022,,,,,,,,,,,,,, -53700,0.80728805,1.0307976,,,,,,,,,,,,,, -53800,0.89740354,1.0091226,,,,,,,,,,,,,, -53900,0.90985686,1.0693154,,,,,,,,,,,,,, -54000,0.92234457,1.0837344,,,,,,,,,,,,,, -54100,0.9546362,1.0489142,,,,,,,,,,,,,, -54200,0.95123124,1.0623444,,,,,,,,,,,,,, -54300,0.92004126,1.0263065,,,,,,,,,,,,,, -54400,0.78302467,1.0270683,,,,,,,,,,,,,, -54500,0.88598824,1.1343999,,,,,,,,,,,,,, -54600,0.96392715,1.0773528,,,,,,,,,,,,,, -54700,1.0256461,1.0725483,,,,,,,,,,,,,, -54800,1.0681523,1.0818944,,,,,,,,,,,,,, -54900,0.89568466,1.111479,,,,,,,,,,,,,, -55000,0.9065399,1.0715425,,,,,,,,,,,,,, -55100,0.92930126,1.008556,,,,,,,,,,,,,, -55121,,,0.16579323,0.0624706694271911,0.36580062,0.108238315456134,5348.0,0.20039488,0.0675360022748969,2472.0,46150.56253623962,50635.04708957672,46150.56253623962,4480.270308256149,1.7505762577056885,0.0 -55200,0.90231013,1.0871897,,,,,,,,,,,,,, -55300,1.2968884,1.0260363,,,,,,,,,,,,,, -55400,0.8171209,1.0516998,,,,,,,,,,,,,, -55500,1.0612301,1.0480222,,,,,,,,,,,,,, -55600,1.0704697,1.0616784,,,,,,,,,,,,,, -55700,0.8074067,0.9897111,,,,,,,,,,,,,, -55800,0.77411324,1.0227972,,,,,,,,,,,,,, -55900,0.93679065,1.0236715,,,,,,,,,,,,,, -56000,0.8096705,1.074231,,,,,,,,,,,,,, -56100,0.8621739,1.0222995,,,,,,,,,,,,,, -56200,0.9335031,1.0478432,,,,,,,,,,,,,, -56300,1.1646048,1.0473053,,,,,,,,,,,,,, -56400,0.9205963,1.0310594,,,,,,,,,,,,,, -56500,0.8879599,0.9929638,,,,,,,,,,,,,, -56600,0.98302627,1.0846531,,,,,,,,,,,,,, -56700,1.097947,1.0237119,,,,,,,,,,,,,, -56800,1.111342,0.99007744,,,,,,,,,,,,,, -56851,,,0.16412912,0.061519178599636,0.35546932,0.1046467845177983,5348.0,0.19666456,0.066317307497004,2472.0,47590.833641052246,52208.10142946243,47590.833641052246,4612.914803504944,1.8098342418670648,0.0 -56900,1.1456507,1.0285147,,,,,,,,,,,,,, -57000,0.9241428,1.018057,,,,,,,,,,,,,, -57100,0.9460255,1.035678,,,,,,,,,,,,,, -57200,0.89024246,1.0136055,,,,,,,,,,,,,, -57300,1.0046799,1.0130714,,,,,,,,,,,,,, -57400,1.0557474,1.0016078,,,,,,,,,,,,,, -57500,0.96476114,1.0536978,,,,,,,,,,,,,, -57600,0.9238401,1.0679061,,,,,,,,,,,,,, -57700,1.077522,1.0463641,,,,,,,,,,,,,, -57800,0.95010626,1.0104444,,,,,,,,,,,,,, -57900,1.0818676,0.9898182,,,,,,,,,,,,,, -58000,0.95756567,0.98840386,,,,,,,,,,,,,, -58100,0.9790016,1.0094122,,,,,,,,,,,,,, -58200,0.9521132,1.0012585,,,,,,,,,,,,,, -58300,0.9376379,1.0378405,,,,,,,,,,,,,, -58400,1.0155959,0.99073994,,,,,,,,,,,,,, -58500,0.99932563,1.0210067,,,,,,,,,,,,,, -58559,,,0.16294,0.0612672758361559,0.34518048,0.1016538420691852,5348.0,0.19008121,0.0627424694818516,2472.0,49030.73461127281,53784.43896389008,49030.73461127281,4749.212601184845,1.87066912651062,0.0 -58600,0.9174493,0.98229235,,,,,,,,,,,,,, -58700,0.93468225,1.0268984,,,,,,,,,,,,,, -58800,1.2656727,0.99597764,,,,,,,,,,,,,, -58900,0.8829031,1.0147614,,,,,,,,,,,,,, -59000,0.899374,0.9745622,,,,,,,,,,,,,, -59100,1.1968403,1.0348266,,,,,,,,,,,,,, -59200,1.0306184,1.067263,,,,,,,,,,,,,, -59300,1.0497888,0.94701415,,,,,,,,,,,,,, -59400,0.85685325,1.0306517,,,,,,,,,,,,,, -59500,1.0284523,1.0342689,,,,,,,,,,,,,, -59600,0.99403137,1.0306034,,,,,,,,,,,,,, -59700,1.0192746,1.0214279,,,,,,,,,,,,,, -59800,0.96740675,1.0310231,,,,,,,,,,,,,, -59900,0.94405866,0.96382433,,,,,,,,,,,,,, -60000,0.984786,1.0152063,,,,,,,,,,,,,, -60100,1.3870201,0.9703571,,,,,,,,,,,,,, -60200,1.0473491,0.99222064,,,,,,,,,,,,,, -60285,,,0.13095543,0.0508440379312663,0.33887717,0.099317416028655,5348.0,0.18431841,0.060386326244592,2472.0,50470.98560571671,55361.18985915184,50470.98560571671,4885.57377076149,1.9312732219696045,0.0 -60300,1.121216,0.98758984,,,,,,,,,,,,,, -60400,0.918182,1.006884,,,,,,,,,,,,,, -60500,1.103157,1.0298352,,,,,,,,,,,,,, -60600,1.1187005,0.99000686,,,,,,,,,,,,,, -60700,1.0932453,0.97065127,,,,,,,,,,,,,, -60800,0.91486174,0.8998368,,,,,,,,,,,,,, -60900,1.082987,0.99853176,,,,,,,,,,,,,, -61000,1.0116074,0.9779355,,,,,,,,,,,,,, -61100,1.0657259,1.0092405,,,,,,,,,,,,,, -61200,1.0982571,0.9870815,,,,,,,,,,,,,, -61300,1.279292,0.93828505,,,,,,,,,,,,,, -61400,0.92617804,0.9603761,,,,,,,,,,,,,, -61500,1.0405508,1.0381088,,,,,,,,,,,,,, -61600,1.0268683,0.9370214,,,,,,,,,,,,,, -61700,1.3107921,0.9568773,,,,,,,,,,,,,, -61800,1.077506,1.0277609,,,,,,,,,,,,,, -61900,0.937521,1.0185438,,,,,,,,,,,,,, -62000,1.0966537,0.94803876,,,,,,,,,,,,,, -62016,,,0.12969963,0.0491762128611199,0.3313753,0.0969713353350647,5348.0,0.17523469,0.0594519935815408,2472.0,51911.11315703392,56937.3402159214,51911.11315703392,5021.4548535346985,1.996178150177002,0.0 -62100,1.4696779,0.94270235,,,,,,,,,,,,,, -62200,0.89316773,0.977631,,,,,,,,,,,,,, -62300,1.3157086,0.9173648,,,,,,,,,,,,,, -62400,0.94597507,0.95923424,,,,,,,,,,,,,, -62500,1.1324966,1.0171163,,,,,,,,,,,,,, -62600,0.99652344,0.9019579,,,,,,,,,,,,,, -62700,1.0236259,0.96052855,,,,,,,,,,,,,, -62800,1.048555,0.97667235,,,,,,,,,,,,,, -62900,0.97392744,0.91863555,,,,,,,,,,,,,, -63000,1.101305,0.9817351,,,,,,,,,,,,,, -63100,1.0468103,0.99746674,,,,,,,,,,,,,, -63200,1.0607768,0.94160575,,,,,,,,,,,,,, -63300,1.0841362,0.9643805,,,,,,,,,,,,,, -63400,1.2989242,0.96181244,,,,,,,,,,,,,, -63500,1.174638,0.9383343,,,,,,,,,,,,,, -63600,1.3677272,0.9607228,,,,,,,,,,,,,, -63700,0.9770346,0.94181156,,,,,,,,,,,,,, -63722,,,0.11556399,0.0442518859081373,0.32159543,0.0928874170906668,5348.0,0.17089857,0.0575020819369122,2472.0,53351.861990213394,58512.570706129074,53351.861990213394,5155.799608707428,2.057328224182129,0.0 -63800,1.2215962,0.95348173,,,,,,,,,,,,,, -63900,1.1428629,0.95156693,,,,,,,,,,,,,, -64000,1.1036844,0.9788633,,,,,,,,,,,,,, -64100,1.0726703,0.9583403,,,,,,,,,,,,,, -64200,1.2284683,0.9602607,,,,,,,,,,,,,, -64300,1.1469592,0.9591859,,,,,,,,,,,,,, -64400,1.104631,0.9203296,,,,,,,,,,,,,, -64500,1.0315053,0.93376446,,,,,,,,,,,,,, -64600,1.1584802,0.9629896,,,,,,,,,,,,,, -64700,1.3971162,0.9199938,,,,,,,,,,,,,, -64800,1.4781598,0.9338602,,,,,,,,,,,,,, -64900,0.9388573,0.9566746,,,,,,,,,,,,,, -65000,1.1600227,0.8994771,,,,,,,,,,,,,, -65100,1.1459229,0.9191287,,,,,,,,,,,,,, -65200,1.3331556,0.93980175,,,,,,,,,,,,,, -65300,1.0729301,0.94375545,,,,,,,,,,,,,, -65400,1.1776575,0.9084361,,,,,,,,,,,,,, -65443,,,0.11451922,0.0434646458320285,0.3201817,0.0917481680295818,5348.0,0.1703582,0.056384945057177,2472.0,54792.02154159546,60088.04986286163,54792.02154159546,5290.980396270752,2.120563507080078,0.0 -65500,1.066921,0.9452578,,,,,,,,,,,,,, -65600,1.2968221,0.92670846,,,,,,,,,,,,,, -65700,1.0614936,0.9199157,,,,,,,,,,,,,, -65800,1.1543491,0.8912728,,,,,,,,,,,,,, -65900,1.0646774,0.9346426,,,,,,,,,,,,,, -66000,1.216677,0.9113672,,,,,,,,,,,,,, -66100,1.2883025,0.9524037,,,,,,,,,,,,,, -66200,1.0590259,0.9386498,,,,,,,,,,,,,, -66300,1.1870857,0.8777069,,,,,,,,,,,,,, -66400,1.0861764,0.9199934,,,,,,,,,,,,,, -66500,1.128361,0.9400294,,,,,,,,,,,,,, -66600,1.3816354,0.91435343,,,,,,,,,,,,,, -66700,1.3231272,0.90306926,,,,,,,,,,,,,, -66800,1.0923009,0.9644095,,,,,,,,,,,,,, -66900,1.2915846,0.9035473,,,,,,,,,,,,,, -67000,1.1941179,0.902817,,,,,,,,,,,,,, -67100,1.243128,0.94224876,,,,,,,,,,,,,, -67169,,,0.1024924,0.0387681402473111,0.31268138,0.0893055408053911,5348.0,0.16590431,0.0544147218329169,2472.0,56232.27780032158,61662.15571832657,56232.27780032158,5424.689416408539,2.1815385818481445,0.0 -67200,1.1246388,0.9347576,,,,,,,,,,,,,, -67300,1.1434208,0.89897466,,,,,,,,,,,,,, -67400,1.1370605,0.9256037,,,,,,,,,,,,,, -67500,1.0746207,0.90831304,,,,,,,,,,,,,, -67600,1.1862054,0.9038916,,,,,,,,,,,,,, -67700,1.4981018,0.90317726,,,,,,,,,,,,,, -67800,1.6267413,0.88496906,,,,,,,,,,,,,, -67900,1.205409,0.9439915,,,,,,,,,,,,,, -68000,1.1223927,0.8621645,,,,,,,,,,,,,, -68100,1.0654817,0.9235322,,,,,,,,,,,,,, -68200,1.0486078,0.87165064,,,,,,,,,,,,,, -68300,1.1231998,0.8985571,,,,,,,,,,,,,, -68400,1.3386486,0.91527194,,,,,,,,,,,,,, -68500,1.0339545,0.8958048,,,,,,,,,,,,,, -68600,1.2845709,0.9352036,,,,,,,,,,,,,, -68700,1.3737782,0.8971117,,,,,,,,,,,,,, -68800,1.1444011,0.87160677,,,,,,,,,,,,,, -68876,,,0.09694151,0.0373078157687922,0.30519754,0.0869208414995607,5348.0,0.1626802,0.0536428818069181,2472.0,57672.29211544991,63235.09067606926,57672.29211544991,5557.467604875565,2.247718572616577,0.0 -68900,1.2620337,0.91143453,,,,,,,,,,,,,, -69000,1.172761,0.89709413,,,,,,,,,,,,,, -69100,1.3621768,0.9203969,,,,,,,,,,,,,, -69200,1.1489724,0.9282253,,,,,,,,,,,,,, -69300,1.3583772,0.8896434,,,,,,,,,,,,,, -69400,1.2202183,0.86253613,,,,,,,,,,,,,, -69500,1.5518188,0.9074996,,,,,,,,,,,,,, -69600,1.1018116,0.91334796,,,,,,,,,,,,,, -69700,1.8361585,0.89083207,,,,,,,,,,,,,, -69800,1.3638496,0.81880355,,,,,,,,,,,,,, -69900,1.1598189,0.89278,,,,,,,,,,,,,, -70000,1.3510163,0.9454928,,,,,,,,,,,,,, -70100,1.4424636,0.87791866,,,,,,,,,,,,,, -70200,1.1584315,0.86551833,,,,,,,,,,,,,, -70300,1.138484,0.8966166,,,,,,,,,,,,,, -70400,1.1685781,0.88211924,,,,,,,,,,,,,, -70500,1.2646822,0.88381463,,,,,,,,,,,,,, -70600,1.2408721,0.8676036,,,,,,,,,,,,,, -70607,,,0.07881192,0.0299819644603755,0.30443206,0.0855788447242148,5348.0,0.16090697,0.0543131639347592,2472.0,59112.44225859642,64809.5007917881,59112.44225859642,5691.591963291168,2.307018518447876,0.0 -70700,1.5190082,0.8646028,,,,,,,,,,,,,, -70800,1.3904067,0.8722625,,,,,,,,,,,,,, -70900,1.0634089,0.8747431,,,,,,,,,,,,,, -71000,1.1615072,0.8944696,,,,,,,,,,,,,, -71100,1.4158528,0.88004345,,,,,,,,,,,,,, -71200,1.3753291,0.85969806,,,,,,,,,,,,,, -71300,1.5512326,0.89505047,,,,,,,,,,,,,, -71400,1.0961992,0.8877216,,,,,,,,,,,,,, -71500,1.2889434,0.92092496,,,,,,,,,,,,,, -71600,1.1409413,0.87164164,,,,,,,,,,,,,, -71700,1.3996326,0.86558515,,,,,,,,,,,,,, -71800,1.0709144,0.8066353,,,,,,,,,,,,,, -71900,1.7052099,0.8927943,,,,,,,,,,,,,, -72000,1.422342,0.8652067,,,,,,,,,,,,,, -72100,1.9013461,0.8923627,,,,,,,,,,,,,, -72200,1.3176248,0.83741516,,,,,,,,,,,,,, -72300,1.3266685,0.8381321,,,,,,,,,,,,,, -72323,,,0.07315466,0.0278330019880715,0.2974776,0.0836092954999662,5348.0,0.15847631,0.051266426990027,2472.0,60552.34362244606,66382.07028150558,60552.34362244606,5824.119750261307,2.368692636489868,0.0 -72400,1.1387275,0.86703867,,,,,,,,,,,,,, -72500,1.6547081,0.84745073,,,,,,,,,,,,,, -72600,1.2987764,0.83918077,,,,,,,,,,,,,, -72700,1.3201114,0.834628,,,,,,,,,,,,,, -72800,1.8488113,0.83748245,,,,,,,,,,,,,, -72900,1.4290572,0.8366103,,,,,,,,,,,,,, -72950,,,,,,,,,,,61068.561847925186,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 81e14e007..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -138.79355025291443,0.0,36.7762336730957,1,0,36.7762336730957,30.871214,2472,1.1899945158734997,175.56985116004944,32.058502,1.1703602515325648,30.757864,5348,1.1779352558965794 -251.8186011314392,0.0316855907440185,1477.254374742508,1706,0,1477.254374742508,6.156824,2472,0.8923486279527959,1729.1772775650024,6.2620625,0.9306712172923776,6.2103653,5348,0.8912886065439238 -385.1414303779602,0.0872964859008789,2917.648540019989,3416,0,2917.648540019989,2.3015385,2472,0.5043974569902301,3303.025156736374,2.681486,0.573457583199163,2.6309824,5348,0.5461540689535321 -520.9090373516083,0.1316628456115722,4357.607877492905,5113,0,4357.607877492905,0.699066,2472,0.2274084455548108,4878.8678069114685,0.8880565,0.2816791786266284,0.9799351,5348,0.2902188709848711 -656.9221382141113,0.1820619106292724,5797.983711004257,6814,0,5797.983711004257,0.5053952,2472,0.1702313488920033,6455.382749795914,0.61878496,0.2115188436201421,0.7577799,5348,0.229925562624907 -794.4483127593994,0.2337381839752197,7237.903148651123,8528,0,7237.903148651123,0.43440384,2472,0.1470964596916702,8032.956871747971,0.5105811,0.1751914005955296,0.6722918,5348,0.2040317831178736 -954.7578556537628,0.2836422920227051,8678.238431215286,10237,0,8678.238431215286,0.38682497,2472,0.1322080718217455,9633.72482395172,0.30627134,0.1130175565380037,0.61254525,5348,0.1879374764667831 -1090.755385637283,0.3366787433624267,10118.417578220367,11951,0,10118.417578220367,0.3629714,2472,0.1239006357524424,11210.029403686523,0.2763798,0.1026946434415796,0.579581,5348,0.1773559767129768 -1226.897890329361,0.393744945526123,11558.427836418152,13670,0,11558.427836418152,0.3431652,2472,0.1186399366278715,12786.317202329636,0.26991197,0.0994529453738206,0.5541362,5348,0.1695260530812825 -1364.6433284282684,0.4485950469970703,12998.425715208054,15354,0,12998.425715208054,0.32349566,2472,0.1106168626734101,14364.191707611084,0.2340234,0.09087778612572,0.5342885,5348,0.1642063392451992 -1498.5190238952637,0.4984886646270752,14438.438889980316,17048,0,14438.438889980316,0.30940318,2472,0.1046249466821034,15938.20537996292,0.23105614,0.0864152774149522,0.51330495,5348,0.1552371665524199 -1633.5538535118103,0.5498590469360352,15878.782112121582,18776,0,15878.782112121582,0.29983425,2472,0.1006032539150569,17513.712234973907,0.2122344,0.083547928665306,0.5012511,5348,0.1528910858588296 -1767.913744688034,0.6073198318481445,17318.944922447205,20469,0,17318.944922447205,0.28335324,2472,0.0950175695163813,19088.36824822426,0.22761327,0.0824963516181646,0.48038745,5348,0.1461521380229201 -1905.30117058754,0.6659457683563232,18759.727390527725,22173,0,18759.727390527725,0.27434734,2472,0.0939004326366461,20666.671327352524,0.20327859,0.0780693384223918,0.4765248,5348,0.1447425586761539 -2041.6354887485504,0.7222182750701904,20200.300904989243,23892,0,20200.300904989243,0.26593608,2472,0.090894318851177,22243.712045431137,0.16723728,0.0660113386024076,0.4556515,5348,0.1391911331666296 -2176.9085042476654,0.7782742977142334,21641.20259666443,25590,0,21641.20259666443,0.26481235,2472,0.0906505798955984,23820.0189204216,0.16852893,0.0654229545988174,0.45363906,5348,0.1365747221873582 -2313.249319076538,0.829658031463623,23081.45346689224,27297,0,23081.45346689224,0.25956357,2472,0.0863039018544472,25396.73770880699,0.15395947,0.062930848772081,0.4432258,5348,0.1348755032487907 -2450.9862122535706,0.8819196224212646,24521.913326740265,29000,0,24521.913326740265,0.25002927,2472,0.086222655535921,26975.06228876114,0.16537838,0.0638170060280613,0.43855268,5348,0.1327128609633412 -2585.5083949565887,0.9355514049530028,25962.687898874283,30703,0,25962.687898874283,0.24562462,2472,0.0828103101578209,28550.48955798149,0.16171478,0.0622558439090314,0.43173057,5348,0.1284841229230427 -2723.458404779434,0.9899494647979736,27402.558747529984,32426,0,27402.558747529984,0.24090798,2472,0.0830946722726626,30128.4415204525,0.14997573,0.0594484991749729,0.4245383,5348,0.127624858800699 -2861.7312412261963,1.050461769104004,28842.477757692337,34142,0,28842.477757692337,0.2346872,2472,0.0791339142445108,31706.7717628479,0.14274807,0.0562218694164557,0.4185331,5348,0.1253849792907692 -2996.345301389694,1.1065824031829834,30282.776348114014,35838,0,30282.776348114014,0.23062919,2472,0.0768387057461458,33281.81516170502,0.12617776,0.0508775722171549,0.40873176,5348,0.120547998107688 -3133.003934860229,1.1698052883148191,31722.89506983757,37553,0,31722.89506983757,0.22272642,2472,0.0752747141145166,34858.73185968399,0.1250977,0.0486143405405681,0.40274152,5348,0.1209245295770296 -3269.946902036667,1.2371857166290283,33163.29486012459,39264,0,33163.29486012459,0.22493432,2472,0.0736904109032559,36436.21641421318,0.12867208,0.0503787059078121,0.39843825,5348,0.118279154638578 -3408.116780996322,1.292286396026611,34603.79007267952,40958,0,34603.79007267952,0.21982695,2472,0.07413726565515,38015.01168775559,0.11844589,0.0466740389277173,0.38874224,5348,0.1154310319858655 -3544.602229118347,1.3495960235595703,36044.39460206032,42686,0,36044.39460206032,0.2124593,2472,0.0698921455121564,39592.237605810165,0.11153542,0.0453985841797476,0.38364777,5348,0.1128628942718943 -3683.409082174301,1.406475305557251,37485.06210923195,44363,0,37485.06210923195,0.20779921,2472,0.0695671602380517,41171.84330153465,0.1322927,0.0478370932916387,0.37878135,5348,0.1110864381088465 -3819.99338889122,1.465172290802002,38925.02504873276,46050,0,38925.02504873276,0.20512715,2472,0.0687953202120529,42748.52485728264,0.08896961,0.0359333018848575,0.37348276,5348,0.1105457775374842 -3958.7835640907288,1.533278226852417,40365.21565961838,47755,0,40365.21565961838,0.2026361,2472,0.0677391180712124,44327.65276861191,0.09621476,0.0385807713017612,0.3708817,5348,0.1084120992112148 -4095.601159334183,1.597571611404419,41805.53027248383,49431,0,41805.53027248383,0.1993707,2472,0.0656267137895314,45904.92424035072,0.114910506,0.0454800272605359,0.3653458,5348,0.1066163337420469 -4230.847477197647,1.6536619663238523,43246.34427714348,51120,0,43246.34427714348,0.19543357,2472,0.0643877074320069,47481.11702299118,0.11720094,0.0459911610910562,0.35821724,5348,0.1047336763953387 -4362.7331802845,1.71537446975708,44687.03170180321,52839,0,44687.03170180321,0.19422725,2472,0.0630877663355879,49053.82754635811,0.1289726,0.0527311119361425,0.3622812,5348,0.1027930911302702 -4499.063733100891,1.7805514335632324,46126.97074270248,54523,0,46126.97074270248,0.19085734,2472,0.0624174842077468,50630.23694562912,0.11021236,0.042651719685409,0.34967783,5348,0.1018662444365061 -4637.565958499908,1.8369412422180176,47568.22066044808,56223,0,47568.22066044808,0.1872067,2472,0.0610566083724331,52210.12111449242,0.09992043,0.0402115093926129,0.3473121,5348,0.1006304488448207 -4771.489331007004,1.899439573287964,49008.23717498779,57923,0,49008.23717498779,0.18778616,2472,0.061462839965064,53784.19967198372,0.08032824,0.0322207382815245,0.34661222,5348,0.0986705542736321 -4906.509536027908,1.959836483001709,50448.39874601364,59600,0,50448.39874601364,0.18204756,2472,0.0583754798610687,55359.5160779953,0.09178022,0.0368883750354486,0.3443059,5348,0.0984484972532512 -5047.088932514191,2.025007963180542,51888.93256998062,61297,0,51888.93256998062,0.18233697,2472,0.0587207767148051,56940.76946043968,0.076783,0.0299433215698855,0.34047252,5348,0.0963148189269818 -5181.807159900665,2.087443351745605,53328.91833233833,62996,0,53328.91833233833,0.17784591,2472,0.0574614587776491,58515.613005161285,0.07512333,0.0306181933086717,0.33476064,5348,0.0945190534578139 -5317.002858400345,2.147326707839966,54768.97867703438,64691,0,54768.97867703438,0.17620729,2472,0.0559990250441776,60091.00317811966,0.068822555,0.0274565048181808,0.33348805,5348,0.0934956602334495 -5449.299804925919,2.210102796554565,56209.32621002197,66393,0,56209.32621002197,0.17585711,2472,0.0568317998090711,61663.78510403633,0.07160717,0.0279594590714634,0.33296782,5348,0.0930225822335074 -5584.706095695496,2.277143955230713,57649.28861045837,68079,0,57649.28861045837,0.17321478,2472,0.0558974671460199,63239.29641199112,0.07398825,0.0287359782769362,0.33303624,5348,0.0930515461926875 -5720.492845773697,2.3368923664093018,59090.46134185791,69778,0,59090.46134185791,0.17336021,2472,0.055531858712652,64816.39191651344,0.061255727,0.0239954952694754,0.32672223,5348,0.0911592341929192 -5855.735311031342,2.396700859069824,60530.481770038605,71500,0,60530.481770038605,0.17124374,2472,0.05437409867365385,66391.79093980789,0.062214013,0.024718881054426472,0.32732484,5348,0.09039651660117594 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/measurements.csv deleted file mode 100644 index 1274776f1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/measurements.csv +++ /dev/null @@ -1,767 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,42.587708,31.960152,,,,,,,,,,,,,, -1,,,32.058502,1.1703602515325648,30.757864,1.1779352558965794,5348.0,30.871214,1.1899945158734997,2472.0,36.7762336730957,175.56985116004944,36.7762336730957,138.79355025291443,0.0,0.0 -100,2.1679137,6.5368686,,,,,,,,,,,,,, -200,0.6594303,5.9166384,,,,,,,,,,,,,, -300,0.54308134,5.8423276,,,,,,,,,,,,,, -400,5.53243,5.8345423,,,,,,,,,,,,,, -500,5.927961,5.836406,,,,,,,,,,,,,, -600,6.0273433,5.844461,,,,,,,,,,,,,, -700,0.44281265,5.794849,,,,,,,,,,,,,, -800,0.7825109,5.798374,,,,,,,,,,,,,, -900,1.7148659,5.794052,,,,,,,,,,,,,, -1000,0.69575745,5.7630763,,,,,,,,,,,,,, -1100,3.0030005,5.729012,,,,,,,,,,,,,, -1200,0.345625,5.5551467,,,,,,,,,,,,,, -1300,1.2724117,5.4706836,,,,,,,,,,,,,, -1400,2.659447,5.317225,,,,,,,,,,,,,, -1500,0.83086985,4.7000914,,,,,,,,,,,,,, -1600,2.6495013,4.1761546,,,,,,,,,,,,,, -1700,1.4203461,3.8417554,,,,,,,,,,,,,, -1706,,,6.2620625,0.9306712172923776,6.2103653,0.8912886065439238,5348.0,6.156824,0.8923486279527959,2472.0,1477.254374742508,1729.1772775650024,1477.254374742508,251.8186011314392,0.0316855907440185,0.0 -1800,1.6049523,3.6524723,,,,,,,,,,,,,, -1900,1.6441811,3.4518487,,,,,,,,,,,,,, -2000,1.5833392,3.314888,,,,,,,,,,,,,, -2100,1.9745584,3.198912,,,,,,,,,,,,,, -2200,1.4019674,3.0861747,,,,,,,,,,,,,, -2300,1.1916562,3.0750384,,,,,,,,,,,,,, -2400,1.6654437,2.8821104,,,,,,,,,,,,,, -2500,1.1377739,2.8174303,,,,,,,,,,,,,, -2600,0.95153445,2.8154922,,,,,,,,,,,,,, -2700,0.9500308,2.718359,,,,,,,,,,,,,, -2800,0.9759306,2.67755,,,,,,,,,,,,,, -2900,0.8720946,2.5822523,,,,,,,,,,,,,, -3000,0.98257095,2.6292243,,,,,,,,,,,,,, -3100,1.3676621,2.470106,,,,,,,,,,,,,, -3200,0.9889243,2.5134425,,,,,,,,,,,,,, -3300,1.7655983,2.4417388,,,,,,,,,,,,,, -3400,0.8949383,2.371038,,,,,,,,,,,,,, -3416,,,2.681486,0.573457583199163,2.6309824,0.5461540689535321,5348.0,2.3015385,0.5043974569902301,2472.0,2917.648540019989,3303.025156736374,2917.648540019989,385.1414303779602,0.0872964859008789,0.0 -3500,0.93096834,2.409657,,,,,,,,,,,,,, -3600,0.7850849,2.26995,,,,,,,,,,,,,, -3700,0.8606612,2.2079196,,,,,,,,,,,,,, -3800,0.7224233,2.1812344,,,,,,,,,,,,,, -3900,0.97750753,2.240332,,,,,,,,,,,,,, -4000,0.8671828,2.1598296,,,,,,,,,,,,,, -4100,0.80033,2.1578567,,,,,,,,,,,,,, -4200,1.2367799,2.0495343,,,,,,,,,,,,,, -4300,1.1098723,1.9767427,,,,,,,,,,,,,, -4400,0.96037346,1.9803503,,,,,,,,,,,,,, -4500,1.4047383,1.992275,,,,,,,,,,,,,, -4600,0.777972,1.9143279,,,,,,,,,,,,,, -4700,0.7659651,1.9188426,,,,,,,,,,,,,, -4800,0.7803557,1.8938494,,,,,,,,,,,,,, -4900,0.7457156,1.9074311,,,,,,,,,,,,,, -5000,0.8705866,1.8922353,,,,,,,,,,,,,, -5100,0.65561,1.8980151,,,,,,,,,,,,,, -5113,,,0.8880565,0.2816791786266284,0.9799351,0.2902188709848711,5348.0,0.699066,0.2274084455548108,2472.0,4357.607877492905,4878.8678069114685,4357.607877492905,520.9090373516083,0.1316628456115722,0.0 -5200,0.87073845,1.8995419,,,,,,,,,,,,,, -5300,1.0228812,1.8538699,,,,,,,,,,,,,, -5400,0.829654,1.8065262,,,,,,,,,,,,,, -5500,0.7612463,1.809647,,,,,,,,,,,,,, -5600,0.76576984,1.7638866,,,,,,,,,,,,,, -5700,0.6867443,1.8257685,,,,,,,,,,,,,, -5800,0.9574746,1.8302168,,,,,,,,,,,,,, -5900,0.83717155,1.7687498,,,,,,,,,,,,,, -6000,0.94563776,1.7845073,,,,,,,,,,,,,, -6100,0.63091576,1.6904569,,,,,,,,,,,,,, -6200,0.70107806,1.7087959,,,,,,,,,,,,,, -6300,0.5917444,1.7193115,,,,,,,,,,,,,, -6400,0.6978045,1.7161168,,,,,,,,,,,,,, -6500,0.63121426,1.7201668,,,,,,,,,,,,,, -6600,0.7468717,1.6409839,,,,,,,,,,,,,, -6700,0.76625246,1.6903907,,,,,,,,,,,,,, -6800,0.669717,1.6497921,,,,,,,,,,,,,, -6814,,,0.61878496,0.2115188436201421,0.7577799,0.229925562624907,5348.0,0.5053952,0.1702313488920033,2472.0,5797.983711004257,6455.382749795914,5797.983711004257,656.9221382141113,0.1820619106292724,0.0 -6900,0.72395176,1.6954107,,,,,,,,,,,,,, -7000,0.5894425,1.6330994,,,,,,,,,,,,,, -7100,0.67942804,1.6372067,,,,,,,,,,,,,, -7200,0.6319641,1.698771,,,,,,,,,,,,,, -7300,0.59730136,1.617103,,,,,,,,,,,,,, -7400,0.758127,1.6425977,,,,,,,,,,,,,, -7500,0.73288095,1.6036419,,,,,,,,,,,,,, -7600,0.6209434,1.5802599,,,,,,,,,,,,,, -7700,0.70146257,1.6058317,,,,,,,,,,,,,, -7800,0.56059325,1.577353,,,,,,,,,,,,,, -7900,0.80812377,1.6732019,,,,,,,,,,,,,, -8000,0.65859973,1.5962933,,,,,,,,,,,,,, -8100,0.6401792,1.6003108,,,,,,,,,,,,,, -8200,0.6460671,1.5595611,,,,,,,,,,,,,, -8300,0.65050083,1.5857913,,,,,,,,,,,,,, -8400,0.55312234,1.5322884,,,,,,,,,,,,,, -8500,0.858231,1.582378,,,,,,,,,,,,,, -8528,,,0.5105811,0.1751914005955296,0.6722918,0.2040317831178736,5348.0,0.43440384,0.1470964596916702,2472.0,7237.903148651123,8032.956871747971,7237.903148651123,794.4483127593994,0.2337381839752197,0.0 -8600,0.6801678,1.5628436,,,,,,,,,,,,,, -8700,0.60891616,1.5547528,,,,,,,,,,,,,, -8800,0.5671761,1.5781206,,,,,,,,,,,,,, -8900,0.7299678,1.5753886,,,,,,,,,,,,,, -9000,0.69964767,1.5768346,,,,,,,,,,,,,, -9100,0.611624,1.5459918,,,,,,,,,,,,,, -9200,0.6957778,1.5610656,,,,,,,,,,,,,, -9300,0.53716844,1.5275273,,,,,,,,,,,,,, -9400,0.66296524,1.5174751,,,,,,,,,,,,,, -9500,0.81286573,1.5616698,,,,,,,,,,,,,, -9600,0.9438636,1.5116956,,,,,,,,,,,,,, -9700,0.6780317,1.4941663,,,,,,,,,,,,,, -9800,0.646469,1.454679,,,,,,,,,,,,,, -9900,0.6197824,1.5395762,,,,,,,,,,,,,, -10000,0.67687494,1.53021,,,,,,,,,,,,,, -10100,0.5808559,1.459968,,,,,,,,,,,,,, -10200,0.7759921,1.5099458,,,,,,,,,,,,,, -10237,,,0.30627134,0.1130175565380037,0.61254525,0.1879374764667831,5348.0,0.38682497,0.1322080718217455,2472.0,8678.238431215286,9633.72482395172,8678.238431215286,954.7578556537628,0.2836422920227051,0.0 -10300,0.70658815,1.513116,,,,,,,,,,,,,, -10400,0.6832271,1.4500297,,,,,,,,,,,,,, -10500,0.5750352,1.4390157,,,,,,,,,,,,,, -10600,0.5500688,1.4801719,,,,,,,,,,,,,, -10700,0.60064685,1.5050205,,,,,,,,,,,,,, -10800,0.7780868,1.4871072,,,,,,,,,,,,,, -10900,0.68363,1.5226046,,,,,,,,,,,,,, -11000,0.54841876,1.4514986,,,,,,,,,,,,,, -11100,0.6254904,1.4590553,,,,,,,,,,,,,, -11200,0.73014903,1.5073055,,,,,,,,,,,,,, -11300,0.6173702,1.4920005,,,,,,,,,,,,,, -11400,0.6257127,1.4473617,,,,,,,,,,,,,, -11500,0.57654136,1.4119585,,,,,,,,,,,,,, -11600,0.58488786,1.3965102,,,,,,,,,,,,,, -11700,0.672885,1.426747,,,,,,,,,,,,,, -11800,0.69484967,1.4198546,,,,,,,,,,,,,, -11900,0.5694604,1.4566497,,,,,,,,,,,,,, -11951,,,0.2763798,0.1026946434415796,0.579581,0.1773559767129768,5348.0,0.3629714,0.1239006357524424,2472.0,10118.417578220367,11210.029403686523,10118.417578220367,1090.755385637283,0.3366787433624267,0.0 -12000,0.56877136,1.3948252,,,,,,,,,,,,,, -12100,0.7434194,1.4188215,,,,,,,,,,,,,, -12200,0.57991266,1.4246641,,,,,,,,,,,,,, -12300,0.6687335,1.4249337,,,,,,,,,,,,,, -12400,0.52600694,1.3722863,,,,,,,,,,,,,, -12500,0.7659344,1.4110075,,,,,,,,,,,,,, -12600,0.5831013,1.3997786,,,,,,,,,,,,,, -12700,0.6242613,1.4115435,,,,,,,,,,,,,, -12800,0.63747996,1.3506367,,,,,,,,,,,,,, -12900,0.67042,1.3870578,,,,,,,,,,,,,, -13000,0.5756627,1.399942,,,,,,,,,,,,,, -13100,0.65817547,1.4608305,,,,,,,,,,,,,, -13200,0.66780365,1.3970668,,,,,,,,,,,,,, -13300,0.6473275,1.4114327,,,,,,,,,,,,,, -13400,0.58102447,1.4027048,,,,,,,,,,,,,, -13500,0.9792726,1.3650256,,,,,,,,,,,,,, -13600,0.6196069,1.402912,,,,,,,,,,,,,, -13670,,,0.26991197,0.0994529453738206,0.5541362,0.1695260530812825,5348.0,0.3431652,0.1186399366278715,2472.0,11558.427836418152,12786.317202329636,11558.427836418152,1226.897890329361,0.393744945526123,0.0 -13700,0.5588053,1.3722675,,,,,,,,,,,,,, -13800,0.5214336,1.380196,,,,,,,,,,,,,, -13900,0.59859455,1.3907889,,,,,,,,,,,,,, -14000,0.49199045,1.3982924,,,,,,,,,,,,,, -14100,0.5619031,1.3308308,,,,,,,,,,,,,, -14200,0.6402675,1.4185486,,,,,,,,,,,,,, -14300,0.5874968,1.4138217,,,,,,,,,,,,,, -14400,0.61604875,1.4182756,,,,,,,,,,,,,, -14500,0.62921965,1.361329,,,,,,,,,,,,,, -14600,0.53993344,1.3379227,,,,,,,,,,,,,, -14700,0.59016633,1.409427,,,,,,,,,,,,,, -14800,0.64620745,1.3220898,,,,,,,,,,,,,, -14900,0.694821,1.3406684,,,,,,,,,,,,,, -15000,0.65164745,1.401195,,,,,,,,,,,,,, -15100,0.59358793,1.3734981,,,,,,,,,,,,,, -15200,0.543562,1.3161136,,,,,,,,,,,,,, -15300,0.7534332,1.4031866,,,,,,,,,,,,,, -15354,,,0.2340234,0.09087778612572,0.5342885,0.1642063392451992,5348.0,0.32349566,0.1106168626734101,2472.0,12998.425715208054,14364.191707611084,12998.425715208054,1364.6433284282684,0.4485950469970703,0.0 -15400,0.5829388,1.3439677,,,,,,,,,,,,,, -15500,0.6186061,1.316474,,,,,,,,,,,,,, -15600,0.63867766,1.2186611,,,,,,,,,,,,,, -15700,0.76328295,1.3905832,,,,,,,,,,,,,, -15800,0.63979185,1.3542088,,,,,,,,,,,,,, -15900,0.6264479,1.2914459,,,,,,,,,,,,,, -16000,0.65766937,1.3728275,,,,,,,,,,,,,, -16100,0.604211,1.2953981,,,,,,,,,,,,,, -16200,0.5641908,1.2902081,,,,,,,,,,,,,, -16300,0.51593626,1.3898773,,,,,,,,,,,,,, -16400,0.62237775,1.281029,,,,,,,,,,,,,, -16500,0.68585885,1.3059001,,,,,,,,,,,,,, -16600,0.5601139,1.3175253,,,,,,,,,,,,,, -16700,0.5243256,1.3289535,,,,,,,,,,,,,, -16800,0.59882593,1.3470545,,,,,,,,,,,,,, -16900,0.5067787,1.3061631,,,,,,,,,,,,,, -17000,0.6996454,1.3213164,,,,,,,,,,,,,, -17048,,,0.23105614,0.0864152774149522,0.51330495,0.1552371665524199,5348.0,0.30940318,0.1046249466821034,2472.0,14438.438889980316,15938.20537996292,14438.438889980316,1498.5190238952637,0.4984886646270752,0.0 -17100,0.62918794,1.2903836,,,,,,,,,,,,,, -17200,0.5757809,1.3689481,,,,,,,,,,,,,, -17300,0.5922283,1.3413645,,,,,,,,,,,,,, -17400,0.7322941,1.3123133,,,,,,,,,,,,,, -17500,0.7139623,1.2927648,,,,,,,,,,,,,, -17600,0.49930912,1.2877918,,,,,,,,,,,,,, -17700,0.6208343,1.2830952,,,,,,,,,,,,,, -17800,0.64693993,1.2965404,,,,,,,,,,,,,, -17900,0.5686936,1.3287845,,,,,,,,,,,,,, -18000,0.63157743,1.3040464,,,,,,,,,,,,,, -18100,0.8363512,1.2906733,,,,,,,,,,,,,, -18200,0.59754497,1.3195288,,,,,,,,,,,,,, -18300,0.60689104,1.2728571,,,,,,,,,,,,,, -18400,0.6530643,1.3548802,,,,,,,,,,,,,, -18500,0.5621728,1.3991147,,,,,,,,,,,,,, -18600,0.6560406,1.2716436,,,,,,,,,,,,,, -18700,0.6131505,1.2948309,,,,,,,,,,,,,, -18776,,,0.2122344,0.083547928665306,0.5012511,0.1528910858588296,5348.0,0.29983425,0.1006032539150569,2472.0,15878.782112121582,17513.712234973907,15878.782112121582,1633.5538535118103,0.5498590469360352,0.0 -18800,0.727144,1.2891474,,,,,,,,,,,,,, -18900,0.673423,1.2398411,,,,,,,,,,,,,, -19000,0.5949094,1.3121811,,,,,,,,,,,,,, -19100,0.57735074,1.2989547,,,,,,,,,,,,,, -19200,0.61337036,1.2838049,,,,,,,,,,,,,, -19300,0.5784434,1.3006163,,,,,,,,,,,,,, -19400,0.51446384,1.3316803,,,,,,,,,,,,,, -19500,0.7096051,1.2590097,,,,,,,,,,,,,, -19600,0.58472514,1.2759088,,,,,,,,,,,,,, -19700,0.70900846,1.236012,,,,,,,,,,,,,, -19800,0.5594016,1.2977818,,,,,,,,,,,,,, -19900,0.6054242,1.2708877,,,,,,,,,,,,,, -20000,0.5076783,1.2856219,,,,,,,,,,,,,, -20100,0.5209162,1.2302899,,,,,,,,,,,,,, -20200,0.7063804,1.2864927,,,,,,,,,,,,,, -20300,0.56362754,1.3207529,,,,,,,,,,,,,, -20400,1.489865,1.3171525,,,,,,,,,,,,,, -20469,,,0.22761327,0.0824963516181646,0.48038745,0.1461521380229201,5348.0,0.28335324,0.0950175695163813,2472.0,17318.944922447205,19088.36824822426,17318.944922447205,1767.913744688034,0.6073198318481445,0.0 -20500,0.74007714,1.2877483,,,,,,,,,,,,,, -20600,0.5116306,1.2194552,,,,,,,,,,,,,, -20700,0.7597873,1.232943,,,,,,,,,,,,,, -20800,0.54674214,1.2892467,,,,,,,,,,,,,, -20900,0.7564501,1.2638674,,,,,,,,,,,,,, -21000,0.5613137,1.2962627,,,,,,,,,,,,,, -21100,0.9315203,1.2178051,,,,,,,,,,,,,, -21200,0.5901528,1.267761,,,,,,,,,,,,,, -21300,0.6993615,1.231626,,,,,,,,,,,,,, -21400,0.61934376,1.2895193,,,,,,,,,,,,,, -21500,0.61058605,1.2704422,,,,,,,,,,,,,, -21600,0.50719607,1.1972541,,,,,,,,,,,,,, -21700,0.5809477,1.2268727,,,,,,,,,,,,,, -21800,0.5465508,1.2712961,,,,,,,,,,,,,, -21900,0.5688271,1.2733779,,,,,,,,,,,,,, -22000,0.4967672,1.2330177,,,,,,,,,,,,,, -22100,0.55208427,1.2751483,,,,,,,,,,,,,, -22173,,,0.20327859,0.0780693384223918,0.4765248,0.1447425586761539,5348.0,0.27434734,0.0939004326366461,2472.0,18759.727390527725,20666.671327352524,18759.727390527725,1905.30117058754,0.6659457683563232,0.0 -22200,0.5732327,1.2573831,,,,,,,,,,,,,, -22300,0.52347034,1.2625494,,,,,,,,,,,,,, -22400,0.58790684,1.2918646,,,,,,,,,,,,,, -22500,0.67927265,1.2776496,,,,,,,,,,,,,, -22600,0.55819935,1.3107927,,,,,,,,,,,,,, -22700,0.619501,1.2246708,,,,,,,,,,,,,, -22800,0.5238234,1.2224859,,,,,,,,,,,,,, -22900,0.5563605,1.2341928,,,,,,,,,,,,,, -23000,0.56347,1.2984459,,,,,,,,,,,,,, -23100,0.73156935,1.208738,,,,,,,,,,,,,, -23200,0.5722804,1.2850336,,,,,,,,,,,,,, -23300,0.6153872,1.2292086,,,,,,,,,,,,,, -23400,0.5541971,1.2317871,,,,,,,,,,,,,, -23500,0.6719874,1.2319572,,,,,,,,,,,,,, -23600,0.5857843,1.2127813,,,,,,,,,,,,,, -23700,0.70649064,1.2394434,,,,,,,,,,,,,, -23800,0.5312295,1.1969694,,,,,,,,,,,,,, -23892,,,0.16723728,0.0660113386024076,0.4556515,0.1391911331666296,5348.0,0.26593608,0.090894318851177,2472.0,20200.300904989243,22243.712045431137,20200.300904989243,2041.6354887485504,0.7222182750701904,0.0 -23900,0.525354,1.1747144,,,,,,,,,,,,,, -24000,0.7221761,1.2476666,,,,,,,,,,,,,, -24100,0.5935401,1.2270606,,,,,,,,,,,,,, -24200,0.67437017,1.2493397,,,,,,,,,,,,,, -24300,0.6326306,1.1887763,,,,,,,,,,,,,, -24400,0.6242219,1.2370466,,,,,,,,,,,,,, -24500,0.6184985,1.2247893,,,,,,,,,,,,,, -24600,0.5937656,1.2221184,,,,,,,,,,,,,, -24700,0.5371537,1.2791424,,,,,,,,,,,,,, -24800,0.6196798,1.1828703,,,,,,,,,,,,,, -24900,0.72827405,1.1792831,,,,,,,,,,,,,, -25000,0.5793952,1.1959513,,,,,,,,,,,,,, -25100,0.5511277,1.1818964,,,,,,,,,,,,,, -25200,0.58629096,1.2023593,,,,,,,,,,,,,, -25300,0.5927563,1.2619773,,,,,,,,,,,,,, -25400,0.898693,1.1959282,,,,,,,,,,,,,, -25500,0.59027773,1.1919229,,,,,,,,,,,,,, -25590,,,0.16852893,0.0654229545988174,0.45363906,0.1365747221873582,5348.0,0.26481235,0.0906505798955984,2472.0,21641.20259666443,23820.0189204216,21641.20259666443,2176.9085042476654,0.7782742977142334,0.0 -25600,0.62051094,1.2349391,,,,,,,,,,,,,, -25700,0.62134653,1.2137196,,,,,,,,,,,,,, -25800,0.55015755,1.16846,,,,,,,,,,,,,, -25900,0.62106234,1.2151238,,,,,,,,,,,,,, -26000,0.63514405,1.1920487,,,,,,,,,,,,,, -26100,0.6137814,1.165283,,,,,,,,,,,,,, -26200,0.6752108,1.2256639,,,,,,,,,,,,,, -26300,0.5356195,1.2037227,,,,,,,,,,,,,, -26400,0.56028295,1.1892076,,,,,,,,,,,,,, -26500,0.64935106,1.251858,,,,,,,,,,,,,, -26600,0.618204,1.1459769,,,,,,,,,,,,,, -26700,1.2034948,1.2114309,,,,,,,,,,,,,, -26800,0.5925285,1.1154835,,,,,,,,,,,,,, -26900,0.60786974,1.1789391,,,,,,,,,,,,,, -27000,0.5831115,1.2034626,,,,,,,,,,,,,, -27100,0.630153,1.1697952,,,,,,,,,,,,,, -27200,0.5611697,1.2097093,,,,,,,,,,,,,, -27297,,,0.15395947,0.062930848772081,0.4432258,0.1348755032487907,5348.0,0.25956357,0.0863039018544472,2472.0,23081.45346689224,25396.73770880699,23081.45346689224,2313.249319076538,0.829658031463623,0.0 -27300,0.64122355,1.1988212,,,,,,,,,,,,,, -27400,0.56694704,1.1770023,,,,,,,,,,,,,, -27500,0.62386197,1.1997788,,,,,,,,,,,,,, -27600,0.54206586,1.1794096,,,,,,,,,,,,,, -27700,0.59456533,1.201101,,,,,,,,,,,,,, -27800,0.68244916,1.1747102,,,,,,,,,,,,,, -27900,0.6467541,1.1373384,,,,,,,,,,,,,, -28000,0.789982,1.1526036,,,,,,,,,,,,,, -28100,0.5676608,1.1861494,,,,,,,,,,,,,, -28200,0.57148445,1.229289,,,,,,,,,,,,,, -28300,0.5546318,1.204281,,,,,,,,,,,,,, -28400,0.6398486,1.2375938,,,,,,,,,,,,,, -28500,0.5989923,1.2487265,,,,,,,,,,,,,, -28600,0.57193846,1.1370707,,,,,,,,,,,,,, -28700,0.70843947,1.2163754,,,,,,,,,,,,,, -28800,0.5540685,1.2075826,,,,,,,,,,,,,, -28900,0.657463,1.1953769,,,,,,,,,,,,,, -29000,,,0.16537838,0.0638170060280613,0.43855268,0.1327128609633412,5348.0,0.25002927,0.086222655535921,2472.0,24521.913326740265,26975.06228876114,24521.913326740265,2450.9862122535706,0.8819196224212646,0.0 -29000,0.62576246,1.1611425,,,,,,,,,,,,,, -29100,0.5793099,1.1932154,,,,,,,,,,,,,, -29200,0.57330805,1.1842124,,,,,,,,,,,,,, -29300,0.56337214,1.2228546,,,,,,,,,,,,,, -29400,0.5418928,1.134592,,,,,,,,,,,,,, -29500,0.557687,1.1817082,,,,,,,,,,,,,, -29600,0.51657516,1.18396,,,,,,,,,,,,,, -29700,0.62354654,1.1773243,,,,,,,,,,,,,, -29800,0.65983415,1.2022079,,,,,,,,,,,,,, -29900,0.53397936,1.1486528,,,,,,,,,,,,,, -30000,0.5526275,1.1347244,,,,,,,,,,,,,, -30100,0.59283006,1.2173473,,,,,,,,,,,,,, -30200,0.56373936,1.1665015,,,,,,,,,,,,,, -30300,0.508959,1.1453336,,,,,,,,,,,,,, -30400,0.5480002,1.1322976,,,,,,,,,,,,,, -30500,0.68474907,1.1503762,,,,,,,,,,,,,, -30600,0.67837757,1.1557728,,,,,,,,,,,,,, -30700,0.6590004,1.1533918,,,,,,,,,,,,,, -30703,,,0.16171478,0.0622558439090314,0.43173057,0.1284841229230427,5348.0,0.24562462,0.0828103101578209,2472.0,25962.687898874283,28550.48955798149,25962.687898874283,2585.5083949565887,0.9355514049530028,0.0 -30800,0.55518925,1.1782655,,,,,,,,,,,,,, -30900,0.63437784,1.1321083,,,,,,,,,,,,,, -31000,0.6860947,1.1551578,,,,,,,,,,,,,, -31100,0.6414584,1.181093,,,,,,,,,,,,,, -31200,0.78007597,1.2194145,,,,,,,,,,,,,, -31300,0.8686797,1.1798594,,,,,,,,,,,,,, -31400,0.59837455,1.1397244,,,,,,,,,,,,,, -31500,0.8013723,1.163114,,,,,,,,,,,,,, -31600,0.5709141,1.1259259,,,,,,,,,,,,,, -31700,0.6482577,1.1810472,,,,,,,,,,,,,, -31800,0.6344936,1.1227016,,,,,,,,,,,,,, -31900,0.7027573,1.1455282,,,,,,,,,,,,,, -32000,0.6942395,1.158735,,,,,,,,,,,,,, -32100,0.6126607,1.1468246,,,,,,,,,,,,,, -32200,0.6955578,1.1797274,,,,,,,,,,,,,, -32300,0.57182634,1.1034817,,,,,,,,,,,,,, -32400,0.6644084,1.153964,,,,,,,,,,,,,, -32426,,,0.14997573,0.0594484991749729,0.4245383,0.127624858800699,5348.0,0.24090798,0.0830946722726626,2472.0,27402.558747529984,30128.4415204525,27402.558747529984,2723.458404779434,0.9899494647979736,0.0 -32500,0.6264268,1.1797702,,,,,,,,,,,,,, -32600,1.1164771,1.1161397,,,,,,,,,,,,,, -32700,0.6238384,1.1487955,,,,,,,,,,,,,, -32800,0.80732876,1.1796286,,,,,,,,,,,,,, -32900,0.6241998,1.1074718,,,,,,,,,,,,,, -33000,0.63326335,1.1860625,,,,,,,,,,,,,, -33100,0.5391466,1.1432527,,,,,,,,,,,,,, -33200,1.2828262,1.0627631,,,,,,,,,,,,,, -33300,0.55947816,1.1743493,,,,,,,,,,,,,, -33400,0.65876216,1.1190443,,,,,,,,,,,,,, -33500,0.6498789,1.1250587,,,,,,,,,,,,,, -33600,0.6723739,1.0417749,,,,,,,,,,,,,, -33700,0.5478636,1.1491151,,,,,,,,,,,,,, -33800,0.53251207,1.1382692,,,,,,,,,,,,,, -33900,0.56347644,1.142749,,,,,,,,,,,,,, -34000,0.5252664,1.1158329,,,,,,,,,,,,,, -34100,0.5630034,1.0998615,,,,,,,,,,,,,, -34142,,,0.14274807,0.0562218694164557,0.4185331,0.1253849792907692,5348.0,0.2346872,0.0791339142445108,2472.0,28842.477757692337,31706.7717628479,28842.477757692337,2861.7312412261963,1.050461769104004,0.0 -34200,0.750858,1.0659415,,,,,,,,,,,,,, -34300,0.629409,1.1681975,,,,,,,,,,,,,, -34400,0.7373336,1.1674345,,,,,,,,,,,,,, -34500,0.62663865,1.177879,,,,,,,,,,,,,, -34600,0.6186469,1.1295735,,,,,,,,,,,,,, -34700,0.6508367,1.1418098,,,,,,,,,,,,,, -34800,0.6537351,1.0790523,,,,,,,,,,,,,, -34900,0.55715907,1.0933945,,,,,,,,,,,,,, -35000,0.6675577,1.1971202,,,,,,,,,,,,,, -35100,0.58448803,1.114942,,,,,,,,,,,,,, -35200,0.6277496,1.1081836,,,,,,,,,,,,,, -35300,0.6121323,1.1202374,,,,,,,,,,,,,, -35400,0.5648027,1.1679415,,,,,,,,,,,,,, -35500,0.6060984,1.0981001,,,,,,,,,,,,,, -35600,0.6262847,1.1427602,,,,,,,,,,,,,, -35700,0.70306605,1.089689,,,,,,,,,,,,,, -35800,0.5614551,1.0878469,,,,,,,,,,,,,, -35838,,,0.12617776,0.0508775722171549,0.40873176,0.120547998107688,5348.0,0.23062919,0.0768387057461458,2472.0,30282.776348114014,33281.81516170502,30282.776348114014,2996.345301389694,1.1065824031829834,0.0 -35900,0.6442965,1.1282865,,,,,,,,,,,,,, -36000,0.5757429,1.0835042,,,,,,,,,,,,,, -36100,0.61338776,1.0659803,,,,,,,,,,,,,, -36200,0.864448,1.0967212,,,,,,,,,,,,,, -36300,0.65352184,1.1470166,,,,,,,,,,,,,, -36400,0.6739895,1.1139565,,,,,,,,,,,,,, -36500,0.6957286,1.1200787,,,,,,,,,,,,,, -36600,0.573288,1.1619028,,,,,,,,,,,,,, -36700,0.7017576,1.120003,,,,,,,,,,,,,, -36800,0.64014804,1.1417075,,,,,,,,,,,,,, -36900,0.70498955,1.1164457,,,,,,,,,,,,,, -37000,0.5946534,1.1307282,,,,,,,,,,,,,, -37100,0.5993681,1.0712687,,,,,,,,,,,,,, -37200,0.8539306,1.0787508,,,,,,,,,,,,,, -37300,0.613909,1.0941372,,,,,,,,,,,,,, -37400,0.68624115,1.0833445,,,,,,,,,,,,,, -37500,0.62935185,1.0534514,,,,,,,,,,,,,, -37553,,,0.1250977,0.0486143405405681,0.40274152,0.1209245295770296,5348.0,0.22272642,0.0752747141145166,2472.0,31722.89506983757,34858.73185968399,31722.89506983757,3133.003934860229,1.1698052883148191,0.0 -37600,0.5818122,1.0834244,,,,,,,,,,,,,, -37700,0.5765405,1.0894951,,,,,,,,,,,,,, -37800,0.6750507,1.1557615,,,,,,,,,,,,,, -37900,1.2523661,1.1081448,,,,,,,,,,,,,, -38000,0.7170005,1.0797578,,,,,,,,,,,,,, -38100,0.58107907,1.0875083,,,,,,,,,,,,,, -38200,0.65114355,1.0976429,,,,,,,,,,,,,, -38300,0.6196107,1.1233459,,,,,,,,,,,,,, -38400,0.61386794,1.0610574,,,,,,,,,,,,,, -38500,0.7082529,1.0940654,,,,,,,,,,,,,, -38600,0.5477315,1.0644324,,,,,,,,,,,,,, -38700,0.6629418,1.073117,,,,,,,,,,,,,, -38800,0.58198625,1.1059742,,,,,,,,,,,,,, -38900,0.62255806,1.0543993,,,,,,,,,,,,,, -39000,0.62842643,1.1095189,,,,,,,,,,,,,, -39100,0.7277115,1.130517,,,,,,,,,,,,,, -39200,0.62179536,1.0492321,,,,,,,,,,,,,, -39264,,,0.12867208,0.0503787059078121,0.39843825,0.118279154638578,5348.0,0.22493432,0.0736904109032559,2472.0,33163.29486012459,36436.21641421318,33163.29486012459,3269.946902036667,1.2371857166290283,0.0 -39300,0.7695071,1.094708,,,,,,,,,,,,,, -39400,0.5546678,1.0724597,,,,,,,,,,,,,, -39500,0.66739887,1.1156609,,,,,,,,,,,,,, -39600,0.6017025,1.0088778,,,,,,,,,,,,,, -39700,0.60568607,1.0575687,,,,,,,,,,,,,, -39800,0.58997613,1.0820729,,,,,,,,,,,,,, -39900,0.5770412,1.087832,,,,,,,,,,,,,, -40000,0.73959947,1.098176,,,,,,,,,,,,,, -40100,0.7809902,1.1011014,,,,,,,,,,,,,, -40200,0.62424064,1.053567,,,,,,,,,,,,,, -40300,0.6029852,1.0433294,,,,,,,,,,,,,, -40400,0.6380811,1.0389488,,,,,,,,,,,,,, -40500,0.7308915,1.0928565,,,,,,,,,,,,,, -40600,0.6637421,1.0085874,,,,,,,,,,,,,, -40700,0.58845145,1.016908,,,,,,,,,,,,,, -40800,0.6857954,1.0877869,,,,,,,,,,,,,, -40900,0.6175294,1.0482163,,,,,,,,,,,,,, -40958,,,0.11844589,0.0466740389277173,0.38874224,0.1154310319858655,5348.0,0.21982695,0.07413726565515,2472.0,34603.79007267952,38015.01168775559,34603.79007267952,3408.116780996322,1.292286396026611,0.0 -41000,0.6521465,1.0840753,,,,,,,,,,,,,, -41100,0.57825,1.1059653,,,,,,,,,,,,,, -41200,0.63362616,1.0955508,,,,,,,,,,,,,, -41300,0.6183019,1.080053,,,,,,,,,,,,,, -41400,0.6691277,1.0528129,,,,,,,,,,,,,, -41500,0.88808316,1.0622989,,,,,,,,,,,,,, -41600,0.6913227,1.0764198,,,,,,,,,,,,,, -41700,0.6310688,1.1020129,,,,,,,,,,,,,, -41800,0.61362535,1.0809208,,,,,,,,,,,,,, -41900,0.52718353,1.0633547,,,,,,,,,,,,,, -42000,0.7745061,1.0415876,,,,,,,,,,,,,, -42100,0.68643606,1.0662918,,,,,,,,,,,,,, -42200,0.6192143,1.0947472,,,,,,,,,,,,,, -42300,1.0758743,1.0878272,,,,,,,,,,,,,, -42400,0.58471406,1.0436002,,,,,,,,,,,,,, -42500,0.5208756,1.0163724,,,,,,,,,,,,,, -42600,0.5398119,1.025718,,,,,,,,,,,,,, -42686,,,0.11153542,0.0453985841797476,0.38364777,0.1128628942718943,5348.0,0.2124593,0.0698921455121564,2472.0,36044.39460206032,39592.237605810165,36044.39460206032,3544.602229118347,1.3495960235595703,0.0 -42700,0.6609615,1.0649863,,,,,,,,,,,,,, -42800,0.59331703,1.0433531,,,,,,,,,,,,,, -42900,0.58889765,1.0821745,,,,,,,,,,,,,, -43000,0.6111531,1.054014,,,,,,,,,,,,,, -43100,0.7831781,0.98415476,,,,,,,,,,,,,, -43200,0.61585826,1.0369626,,,,,,,,,,,,,, -43300,0.77566713,1.0633352,,,,,,,,,,,,,, -43400,0.67027485,1.044771,,,,,,,,,,,,,, -43500,0.5931841,1.0450374,,,,,,,,,,,,,, -43600,1.2545798,1.0704675,,,,,,,,,,,,,, -43700,0.56625825,1.0243896,,,,,,,,,,,,,, -43800,0.7478904,1.0159779,,,,,,,,,,,,,, -43900,0.5751941,1.0242915,,,,,,,,,,,,,, -44000,0.62860894,1.0299146,,,,,,,,,,,,,, -44100,0.78446704,1.0314292,,,,,,,,,,,,,, -44200,0.671738,1.0499227,,,,,,,,,,,,,, -44300,0.69981575,1.0795401,,,,,,,,,,,,,, -44363,,,0.1322927,0.0478370932916387,0.37878135,0.1110864381088465,5348.0,0.20779921,0.0695671602380517,2472.0,37485.06210923195,41171.84330153465,37485.06210923195,3683.409082174301,1.406475305557251,0.0 -44400,0.65824765,1.0385545,,,,,,,,,,,,,, -44500,0.7354649,1.048066,,,,,,,,,,,,,, -44600,0.62599164,1.0563643,,,,,,,,,,,,,, -44700,0.6460903,1.0338783,,,,,,,,,,,,,, -44800,0.6196037,1.0404994,,,,,,,,,,,,,, -44900,0.6338814,1.0541626,,,,,,,,,,,,,, -45000,0.9082579,1.0571107,,,,,,,,,,,,,, -45100,0.6163911,1.048975,,,,,,,,,,,,,, -45200,0.922752,1.0612702,,,,,,,,,,,,,, -45300,0.586379,0.9736728,,,,,,,,,,,,,, -45400,0.6718039,1.0073351,,,,,,,,,,,,,, -45500,0.75438744,1.0084758,,,,,,,,,,,,,, -45600,0.6726055,1.074567,,,,,,,,,,,,,, -45700,0.61236393,1.0265015,,,,,,,,,,,,,, -45800,0.6598681,1.0866957,,,,,,,,,,,,,, -45900,0.74277246,1.0280392,,,,,,,,,,,,,, -46000,0.62762964,1.0375526,,,,,,,,,,,,,, -46050,,,0.08896961,0.0359333018848575,0.37348276,0.1105457775374842,5348.0,0.20512715,0.0687953202120529,2472.0,38925.02504873276,42748.52485728264,38925.02504873276,3819.99338889122,1.465172290802002,0.0 -46100,0.6640685,1.0157999,,,,,,,,,,,,,, -46200,0.68946624,1.0465693,,,,,,,,,,,,,, -46300,0.5371999,1.0457569,,,,,,,,,,,,,, -46400,0.67171663,1.0162464,,,,,,,,,,,,,, -46500,0.8277615,1.0319898,,,,,,,,,,,,,, -46600,0.60586846,1.0432894,,,,,,,,,,,,,, -46700,0.7758619,0.97445273,,,,,,,,,,,,,, -46800,0.631537,1.025238,,,,,,,,,,,,,, -46900,0.7774274,1.0161179,,,,,,,,,,,,,, -47000,0.6580258,1.015731,,,,,,,,,,,,,, -47100,0.6459709,1.0201937,,,,,,,,,,,,,, -47200,0.6676773,1.0288893,,,,,,,,,,,,,, -47300,0.77535695,1.0495248,,,,,,,,,,,,,, -47400,0.62304467,1.0342858,,,,,,,,,,,,,, -47500,0.67333895,0.99090767,,,,,,,,,,,,,, -47600,0.68198556,0.990533,,,,,,,,,,,,,, -47700,0.8437823,1.0620352,,,,,,,,,,,,,, -47755,,,0.09621476,0.0385807713017612,0.3708817,0.1084120992112148,5348.0,0.2026361,0.0677391180712124,2472.0,40365.21565961838,44327.65276861191,40365.21565961838,3958.7835640907288,1.533278226852417,0.0 -47800,0.7795956,1.0474869,,,,,,,,,,,,,, -47900,0.8258997,0.9911021,,,,,,,,,,,,,, -48000,1.1197673,1.0350667,,,,,,,,,,,,,, -48100,0.76355565,1.0525007,,,,,,,,,,,,,, -48200,0.60670763,1.0207639,,,,,,,,,,,,,, -48300,1.0509305,1.0452422,,,,,,,,,,,,,, -48400,0.70493835,1.0273117,,,,,,,,,,,,,, -48500,0.86512315,0.98715204,,,,,,,,,,,,,, -48600,0.630021,0.9701911,,,,,,,,,,,,,, -48700,0.6772338,0.9841268,,,,,,,,,,,,,, -48800,0.72007257,1.0136166,,,,,,,,,,,,,, -48900,0.63167906,1.0301116,,,,,,,,,,,,,, -49000,0.6814259,1.0222874,,,,,,,,,,,,,, -49100,0.5859585,0.99613255,,,,,,,,,,,,,, -49200,1.0135272,1.0405641,,,,,,,,,,,,,, -49300,0.85538596,0.96958333,,,,,,,,,,,,,, -49400,0.685165,0.96627736,,,,,,,,,,,,,, -49431,,,0.114910506,0.0454800272605359,0.3653458,0.1066163337420469,5348.0,0.1993707,0.0656267137895314,2472.0,41805.53027248383,45904.92424035072,41805.53027248383,4095.601159334183,1.597571611404419,0.0 -49500,0.65068096,1.0191379,,,,,,,,,,,,,, -49600,0.6403527,0.97165257,,,,,,,,,,,,,, -49700,0.6904791,1.0089164,,,,,,,,,,,,,, -49800,0.7431177,1.0294707,,,,,,,,,,,,,, -49900,0.67318577,0.9924644,,,,,,,,,,,,,, -50000,0.70787966,1.0001729,,,,,,,,,,,,,, -50100,0.600518,1.000597,,,,,,,,,,,,,, -50200,0.8052907,0.96845365,,,,,,,,,,,,,, -50300,0.6897219,0.9774196,,,,,,,,,,,,,, -50400,0.6750511,1.0423739,,,,,,,,,,,,,, -50500,0.75257885,0.9785338,,,,,,,,,,,,,, -50600,0.66631055,0.9799451,,,,,,,,,,,,,, -50700,0.6638561,0.9687007,,,,,,,,,,,,,, -50800,0.7134483,1.0471387,,,,,,,,,,,,,, -50900,0.9059783,0.9487596,,,,,,,,,,,,,, -51000,0.6380542,1.0111564,,,,,,,,,,,,,, -51100,0.70445734,1.0026641,,,,,,,,,,,,,, -51120,,,0.11720094,0.0459911610910562,0.35821724,0.1047336763953387,5348.0,0.19543357,0.0643877074320069,2472.0,43246.34427714348,47481.11702299118,43246.34427714348,4230.847477197647,1.6536619663238523,0.0 -51200,0.9225606,1.0377417,,,,,,,,,,,,,, -51300,0.70825595,0.94531035,,,,,,,,,,,,,, -51400,0.8071866,0.9767209,,,,,,,,,,,,,, -51500,0.69416755,0.9968497,,,,,,,,,,,,,, -51600,0.6781038,0.99416596,,,,,,,,,,,,,, -51700,0.6725321,1.0276918,,,,,,,,,,,,,, -51800,0.6819832,0.9527296,,,,,,,,,,,,,, -51900,0.6414867,0.94404614,,,,,,,,,,,,,, -52000,0.7060375,0.9875132,,,,,,,,,,,,,, -52100,0.64508146,0.996324,,,,,,,,,,,,,, -52200,0.6727191,0.98283875,,,,,,,,,,,,,, -52300,0.7252228,0.9756631,,,,,,,,,,,,,, -52400,0.6917812,0.97578764,,,,,,,,,,,,,, -52500,0.7439761,0.9958245,,,,,,,,,,,,,, -52600,0.67565465,0.97011477,,,,,,,,,,,,,, -52700,0.6442874,0.9839663,,,,,,,,,,,,,, -52800,0.6375092,0.9335084,,,,,,,,,,,,,, -52839,,,0.1289726,0.0527311119361425,0.3622812,0.1027930911302702,5348.0,0.19422725,0.0630877663355879,2472.0,44687.03170180321,49053.82754635811,44687.03170180321,4362.7331802845,1.71537446975708,0.0 -52900,0.62655413,0.91353536,,,,,,,,,,,,,, -53000,0.6154582,0.94462955,,,,,,,,,,,,,, -53100,0.9288864,0.9721477,,,,,,,,,,,,,, -53200,0.9652006,0.9507126,,,,,,,,,,,,,, -53300,0.59956014,0.95977706,,,,,,,,,,,,,, -53400,0.6212718,0.9733361,,,,,,,,,,,,,, -53500,0.71288973,0.9816872,,,,,,,,,,,,,, -53600,0.6284264,0.957717,,,,,,,,,,,,,, -53700,0.6395422,0.9454445,,,,,,,,,,,,,, -53800,0.81430185,0.9216216,,,,,,,,,,,,,, -53900,0.6643478,0.9951349,,,,,,,,,,,,,, -54000,0.62972325,0.95178217,,,,,,,,,,,,,, -54100,0.78867394,0.9570793,,,,,,,,,,,,,, -54200,0.79708314,1.0048051,,,,,,,,,,,,,, -54300,0.67282957,0.9733947,,,,,,,,,,,,,, -54400,0.66132957,0.94033,,,,,,,,,,,,,, -54500,0.6477445,0.98581946,,,,,,,,,,,,,, -54523,,,0.11021236,0.042651719685409,0.34967783,0.1018662444365061,5348.0,0.19085734,0.0624174842077468,2472.0,46126.97074270248,50630.23694562912,46126.97074270248,4499.063733100891,1.7805514335632324,0.0 -54600,0.6580473,0.99014103,,,,,,,,,,,,,, -54700,0.7066077,1.0109159,,,,,,,,,,,,,, -54800,0.661037,0.9693772,,,,,,,,,,,,,, -54900,0.77885664,0.98129,,,,,,,,,,,,,, -55000,0.74625707,0.9548172,,,,,,,,,,,,,, -55100,0.7024037,0.910109,,,,,,,,,,,,,, -55200,0.7134644,0.9903199,,,,,,,,,,,,,, -55300,0.77589303,0.9796632,,,,,,,,,,,,,, -55400,0.74698055,0.95306426,,,,,,,,,,,,,, -55500,0.7754618,0.9355711,,,,,,,,,,,,,, -55600,0.86941123,0.952577,,,,,,,,,,,,,, -55700,0.72042423,0.98092055,,,,,,,,,,,,,, -55800,0.72281843,0.97361606,,,,,,,,,,,,,, -55900,0.6174673,0.90906364,,,,,,,,,,,,,, -56000,0.74645585,0.9651133,,,,,,,,,,,,,, -56100,0.67359126,0.9045245,,,,,,,,,,,,,, -56200,0.71586174,0.9343428,,,,,,,,,,,,,, -56223,,,0.09992043,0.0402115093926129,0.3473121,0.1006304488448207,5348.0,0.1872067,0.0610566083724331,2472.0,47568.22066044808,52210.12111449242,47568.22066044808,4637.565958499908,1.8369412422180176,0.0 -56300,0.7965398,0.9478667,,,,,,,,,,,,,, -56400,0.95568204,0.9318567,,,,,,,,,,,,,, -56500,0.8169665,0.9278915,,,,,,,,,,,,,, -56600,0.5886433,0.9657692,,,,,,,,,,,,,, -56700,0.6611285,0.9727345,,,,,,,,,,,,,, -56800,0.9068897,0.98768306,,,,,,,,,,,,,, -56900,0.6771665,0.92503226,,,,,,,,,,,,,, -57000,0.7737611,0.91250813,,,,,,,,,,,,,, -57100,1.0699081,0.9498071,,,,,,,,,,,,,, -57200,0.7180641,0.90518373,,,,,,,,,,,,,, -57300,0.8021782,0.9126729,,,,,,,,,,,,,, -57400,0.6407906,0.87939763,,,,,,,,,,,,,, -57500,0.68207115,0.961987,,,,,,,,,,,,,, -57600,0.71602917,0.97193164,,,,,,,,,,,,,, -57700,0.6469986,0.93655825,,,,,,,,,,,,,, -57800,0.7761762,0.92862326,,,,,,,,,,,,,, -57900,1.0950079,0.92603827,,,,,,,,,,,,,, -57923,,,0.08032824,0.0322207382815245,0.34661222,0.0986705542736321,5348.0,0.18778616,0.061462839965064,2472.0,49008.23717498779,53784.19967198372,49008.23717498779,4771.489331007004,1.899439573287964,0.0 -58000,1.0880439,0.90043133,,,,,,,,,,,,,, -58100,0.6823492,0.89762366,,,,,,,,,,,,,, -58200,0.7776731,0.930579,,,,,,,,,,,,,, -58300,1.0645761,0.92217606,,,,,,,,,,,,,, -58400,0.6570299,0.8948423,,,,,,,,,,,,,, -58500,0.77664053,0.9168711,,,,,,,,,,,,,, -58600,0.7467788,0.96068865,,,,,,,,,,,,,, -58700,0.83314556,0.9720455,,,,,,,,,,,,,, -58800,0.77509415,0.9456904,,,,,,,,,,,,,, -58900,0.6473761,0.89855164,,,,,,,,,,,,,, -59000,1.0544714,0.88535935,,,,,,,,,,,,,, -59100,0.6914843,0.9452285,,,,,,,,,,,,,, -59200,1.0113829,0.9534292,,,,,,,,,,,,,, -59300,0.60200864,0.9129936,,,,,,,,,,,,,, -59400,0.7257048,0.9011245,,,,,,,,,,,,,, -59500,0.72017384,0.8706252,,,,,,,,,,,,,, -59600,,,0.09178022,0.0368883750354486,0.3443059,0.0984484972532512,5348.0,0.18204756,0.0583754798610687,2472.0,50448.39874601364,55359.5160779953,50448.39874601364,4906.509536027908,1.959836483001709,0.0 -59600,0.78376794,0.9321819,,,,,,,,,,,,,, -59700,0.72680247,0.9519326,,,,,,,,,,,,,, -59800,0.61246514,0.8920209,,,,,,,,,,,,,, -59900,0.7340077,0.8760612,,,,,,,,,,,,,, -60000,0.68724173,0.94866645,,,,,,,,,,,,,, -60100,0.7109624,0.9309786,,,,,,,,,,,,,, -60200,0.99527055,0.90433663,,,,,,,,,,,,,, -60300,0.92159426,0.9145282,,,,,,,,,,,,,, -60400,0.65664065,0.9499142,,,,,,,,,,,,,, -60500,1.0115899,0.91794467,,,,,,,,,,,,,, -60600,0.8021299,0.901382,,,,,,,,,,,,,, -60700,0.67448306,0.91773105,,,,,,,,,,,,,, -60800,0.9978954,0.8957701,,,,,,,,,,,,,, -60900,0.7079425,0.87331194,,,,,,,,,,,,,, -61000,0.7736043,0.93117493,,,,,,,,,,,,,, -61100,0.7243999,0.93173057,,,,,,,,,,,,,, -61200,0.73443735,0.91532546,,,,,,,,,,,,,, -61297,,,0.076783,0.0299433215698855,0.34047252,0.0963148189269818,5348.0,0.18233697,0.0587207767148051,2472.0,51888.93256998062,56940.76946043968,51888.93256998062,5047.088932514191,2.025007963180542,0.0 -61300,0.7569644,0.9414163,,,,,,,,,,,,,, -61400,1.6928804,0.90477383,,,,,,,,,,,,,, -61500,0.8767077,0.92823255,,,,,,,,,,,,,, -61600,0.8426652,0.8257298,,,,,,,,,,,,,, -61700,0.6933526,0.8984412,,,,,,,,,,,,,, -61800,0.848236,0.90403193,,,,,,,,,,,,,, -61900,0.7033577,0.9243876,,,,,,,,,,,,,, -62000,1.376644,0.8981764,,,,,,,,,,,,,, -62100,0.7421606,0.90738684,,,,,,,,,,,,,, -62200,0.7383927,0.8963453,,,,,,,,,,,,,, -62300,0.7059286,0.8925812,,,,,,,,,,,,,, -62400,0.7386778,0.91821295,,,,,,,,,,,,,, -62500,0.8328243,0.94856477,,,,,,,,,,,,,, -62600,0.67175484,0.86683184,,,,,,,,,,,,,, -62700,0.7941354,0.90650445,,,,,,,,,,,,,, -62800,0.7963468,0.90390456,,,,,,,,,,,,,, -62900,0.6691566,0.8959915,,,,,,,,,,,,,, -62996,,,0.07512333,0.0306181933086717,0.33476064,0.0945190534578139,5348.0,0.17784591,0.0574614587776491,2472.0,53328.91833233833,58515.613005161285,53328.91833233833,5181.807159900665,2.087443351745605,0.0 -63000,0.81354773,0.886476,,,,,,,,,,,,,, -63100,0.85463226,0.88179857,,,,,,,,,,,,,, -63200,0.7496493,0.91627544,,,,,,,,,,,,,, -63300,0.69912374,0.89877504,,,,,,,,,,,,,, -63400,0.7552503,0.8961311,,,,,,,,,,,,,, -63500,0.8095395,0.92409474,,,,,,,,,,,,,, -63600,0.68812335,0.86291367,,,,,,,,,,,,,, -63700,0.78725845,0.8748681,,,,,,,,,,,,,, -63800,0.9842257,0.9111457,,,,,,,,,,,,,, -63900,0.6890869,0.8581596,,,,,,,,,,,,,, -64000,0.84921247,0.8704224,,,,,,,,,,,,,, -64100,0.6743111,0.8767179,,,,,,,,,,,,,, -64200,0.689315,0.9034922,,,,,,,,,,,,,, -64300,0.6729949,0.91481847,,,,,,,,,,,,,, -64400,0.6709927,0.8706539,,,,,,,,,,,,,, -64500,0.89182717,0.89256245,,,,,,,,,,,,,, -64600,0.6913587,0.89799255,,,,,,,,,,,,,, -64691,,,0.068822555,0.0274565048181808,0.33348805,0.0934956602334495,5348.0,0.17620729,0.0559990250441776,2472.0,54768.97867703438,60091.00317811966,54768.97867703438,5317.002858400345,2.147326707839966,0.0 -64700,0.7505153,0.86310476,,,,,,,,,,,,,, -64800,0.72133946,0.841698,,,,,,,,,,,,,, -64900,1.1025498,0.8751014,,,,,,,,,,,,,, -65000,1.1698165,0.8702249,,,,,,,,,,,,,, -65100,0.8055768,0.9037042,,,,,,,,,,,,,, -65200,0.8971743,0.894874,,,,,,,,,,,,,, -65300,0.79766536,0.88908565,,,,,,,,,,,,,, -65400,0.77770877,0.8510755,,,,,,,,,,,,,, -65500,0.8285824,0.89260316,,,,,,,,,,,,,, -65600,0.65895975,0.89310825,,,,,,,,,,,,,, -65700,0.8060633,0.8681073,,,,,,,,,,,,,, -65800,0.7727525,0.8273281,,,,,,,,,,,,,, -65900,1.251449,0.88071275,,,,,,,,,,,,,, -66000,0.8780228,0.867358,,,,,,,,,,,,,, -66100,1.1575919,0.87056684,,,,,,,,,,,,,, -66200,0.80967724,0.8811394,,,,,,,,,,,,,, -66300,0.6810929,0.8768535,,,,,,,,,,,,,, -66393,,,0.07160717,0.0279594590714634,0.33296782,0.0930225822335074,5348.0,0.17585711,0.0568317998090711,2472.0,56209.32621002197,61663.78510403633,56209.32621002197,5449.299804925919,2.210102796554565,0.0 -66400,0.7834923,0.91825634,,,,,,,,,,,,,, -66500,1.2258343,0.8913166,,,,,,,,,,,,,, -66600,0.67190605,0.8181993,,,,,,,,,,,,,, -66700,0.80354685,0.9057628,,,,,,,,,,,,,, -66800,0.79834086,0.8640553,,,,,,,,,,,,,, -66900,0.82822907,0.88538253,,,,,,,,,,,,,, -67000,0.83818394,0.85842544,,,,,,,,,,,,,, -67100,0.86060435,0.86654603,,,,,,,,,,,,,, -67200,0.98128057,0.91685003,,,,,,,,,,,,,, -67300,0.7502013,0.8443765,,,,,,,,,,,,,, -67400,0.73107934,0.84722054,,,,,,,,,,,,,, -67500,0.80471814,0.8970473,,,,,,,,,,,,,, -67600,0.9854027,0.851809,,,,,,,,,,,,,, -67700,1.1110724,0.8707796,,,,,,,,,,,,,, -67800,0.6863576,0.8462818,,,,,,,,,,,,,, -67900,0.93996215,0.91661334,,,,,,,,,,,,,, -68000,0.78729296,0.8143172,,,,,,,,,,,,,, -68079,,,0.07398825,0.0287359782769362,0.33303624,0.0930515461926875,5348.0,0.17321478,0.0558974671460199,2472.0,57649.28861045837,63239.29641199112,57649.28861045837,5584.706095695496,2.277143955230713,0.0 -68100,0.66996336,0.81226665,,,,,,,,,,,,,, -68200,0.77393454,0.8363476,,,,,,,,,,,,,, -68300,1.0007539,0.8383426,,,,,,,,,,,,,, -68400,2.0779445,0.8608805,,,,,,,,,,,,,, -68500,0.73295623,0.8733535,,,,,,,,,,,,,, -68600,0.85900766,0.8913394,,,,,,,,,,,,,, -68700,0.62612855,0.84294486,,,,,,,,,,,,,, -68800,0.91571265,0.892331,,,,,,,,,,,,,, -68900,0.8925061,0.85550094,,,,,,,,,,,,,, -69000,0.7307365,0.91143453,,,,,,,,,,,,,, -69100,0.98989314,0.8360425,,,,,,,,,,,,,, -69200,0.8932787,0.8729568,,,,,,,,,,,,,, -69300,0.9328048,0.8516736,,,,,,,,,,,,,, -69400,0.79042107,0.8392247,,,,,,,,,,,,,, -69500,0.9350728,0.8453261,,,,,,,,,,,,,, -69600,0.81782013,0.88765,,,,,,,,,,,,,, -69700,0.7437823,0.85819376,,,,,,,,,,,,,, -69778,,,0.061255727,0.0239954952694754,0.32672223,0.0911592341929192,5348.0,0.17336021,0.055531858712652,2472.0,59090.46134185791,64816.39191651344,59090.46134185791,5720.492845773697,2.3368923664093018,0.0 -69800,0.7782196,0.83967936,,,,,,,,,,,,,, -69900,0.8116748,0.84681594,,,,,,,,,,,,,, -70000,0.9885286,0.8455474,,,,,,,,,,,,,, -70100,0.8244331,0.8733234,,,,,,,,,,,,,, -70200,0.69894016,0.87398064,,,,,,,,,,,,,, -70300,1.2151052,0.85774523,,,,,,,,,,,,,, -70400,0.87210846,0.8155772,,,,,,,,,,,,,, -70500,0.8186183,0.88597035,,,,,,,,,,,,,, -70600,0.8013145,0.81612706,,,,,,,,,,,,,, -70700,0.77333087,0.8755482,,,,,,,,,,,,,, -70800,0.8326608,0.8817169,,,,,,,,,,,,,, -70900,0.76268816,0.8623344,,,,,,,,,,,,,, -71000,0.74721676,0.8411685,,,,,,,,,,,,,, -71100,0.69615424,0.851138,,,,,,,,,,,,,, -71200,0.7239543,0.85208607,,,,,,,,,,,,,, -71300,0.7190743,0.8857732,,,,,,,,,,,,,, -71400,0.7625032,0.83825654,,,,,,,,,,,,,, -71500,,,0.062214013,0.0247188810544264,0.32732484,0.0903965166011759,5348.0,0.17124374,0.0543740986736538,2472.0,60530.481770038605,66391.79093980789,60530.481770038605,5855.735311031342,2.396700859069824,0.0 -71500,1.0860437,0.85355365,,,,,,,,,,,,,, -71600,0.75912637,0.85910404,,,,,,,,,,,,,, -71700,0.80206555,0.864092,,,,,,,,,,,,,, -71800,0.8377436,0.81395334,,,,,,,,,,,,,, -71900,0.7234861,0.86448747,,,,,,,,,,,,,, -72000,1.0223613,0.8413328,,,,,,,,,,,,,, -72100,0.8341502,0.864328,,,,,,,,,,,,,, -72135,,,,,,,,,,,61068.56172633171,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 220242459..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -140.7980694770813,0.0,35.92231845855713,1,0,35.92231845855713,30.871214,2472,1.190014827453131,176.72045016288757,32.398228,1.135258711138108,30.757864,5348,1.1779352558965794 -251.5450747013092,0.0311036109924316,1475.9753279685974,1679,0,1475.9753279685974,6.0640683,2472,0.899579550301627,1727.6220135688782,6.0571713,0.9391896477614642,6.13424,5348,0.8966179750330672 -369.2691359519959,0.0840563774108886,2916.449270963669,3379,0,2916.449270963669,3.4898498,2472,0.754087705400849,3285.948108673096,3.853232,0.8365734501347709,3.802879,5348,0.7978219102696544 -510.6455659866333,0.1421759128570556,4357.660209178925,5053,0,4357.660209178925,1.0181005,2472,0.3211666971340361,4868.666774988174,1.3261416,0.3929169378299152,1.3397774,5348,0.3746005387296408 -649.3420903682709,0.194298505783081,5798.374348640442,6745,0,5798.374348640442,0.596215,2472,0.199540958300327,6448.203431367874,0.7336726,0.2415297145928209,0.87219507,5348,0.2599418789885785 -785.0801012516022,0.249565839767456,7238.520888566971,8450,0,7238.520888566971,0.51278436,2472,0.1729734121422623,8024.218991994858,0.61011267,0.2042115852891749,0.7665455,5348,0.2304662231962694 -919.9031093120576,0.3922488689422607,8678.477567434311,10145,0,8678.477567434311,0.44745144,2472,0.1521540430199256,9599.215919017792,0.5839462,0.1999528479289575,0.6855125,5348,0.2087335991581142 -1058.5496191978457,0.4467132091522217,10118.87825846672,11861,0,10118.87825846672,0.41517326,2472,0.1424044847967826,11178.391873121262,0.50584203,0.1745268731560112,0.65349907,5348,0.198181063363488 -1194.934509754181,0.5001969337463379,11559.40121126175,13572,0,11559.40121126175,0.39062923,2472,0.1336298823959539,12755.429028272629,0.49853423,0.1681780210600365,0.62621295,5348,0.1889801789972677 -1331.903869152069,0.5556769371032715,12999.642688512802,15270,0,12999.642688512802,0.36876667,2472,0.1259724168748603,14332.771677017212,0.41710535,0.1461387617891493,0.58882713,5348,0.1800399702636685 -1466.1310760974884,0.6140539646148682,14440.339807987211,16979,0,14440.339807987211,0.36074033,2472,0.1234740925801799,15907.82964539528,0.4158837,0.1474966401676499,0.572075,5348,0.1737258271624009 -1598.8436088562012,0.6656575202941895,15880.586597919464,18678,0,15880.586597919464,0.3428115,2472,0.1159588081165072,17480.91488957405,0.39677075,0.1402287030102502,0.55821544,5348,0.1698350019792038 -1745.0095345973969,0.7188072204589844,17321.084460258484,20395,0,17321.084460258484,0.3305695,2472,0.111652753234619,19067.70646309853,0.26701897,0.0973689141082947,0.53861165,5348,0.1639842822248182 -1883.2937757968905,0.7721741199493408,18761.54176592827,22125,0,18761.54176592827,0.32201567,2472,0.1093575447362541,20646.577522277832,0.25013727,0.0908888957771509,0.5297374,5348,0.1605472257354432 -2023.181410074234,0.8276913166046143,20202.07715177536,23843,0,20202.07715177536,0.31246486,2472,0.1053561635488392,22227.13118314743,0.24616796,0.0902991392319867,0.5182299,5348,0.1571294785521882 -2151.7861466407776,0.880361795425415,21642.539937257767,25659,0,21642.539937257767,0.30169642,2472,0.1013954055206873,23796.32274031639,0.23114382,0.0857162323641013,0.5024098,5348,0.1519256205528254 -2280.3359375,0.935565948486328,23082.5641746521,27534,0,23082.5641746521,0.29015937,2472,0.0991814433408486,25365.021073818207,0.22649139,0.0850088484254067,0.49297327,5348,0.150226401614258 -2409.5551302433014,0.989698886871338,24523.00218605995,29414,0,24523.00218605995,0.28078052,2472,0.0932098389291735,26934.803947925568,0.20505889,0.0763605881515613,0.4759323,5348,0.1427923187580254 -2536.493052005768,1.041748285293579,25962.93151783943,31290,0,25962.93151783943,0.2828514,2472,0.095119127414539,28501.7939991951,0.24940795,0.0886371916240273,0.47313604,5348,0.14384467594157 -2664.0866878032684,1.0995814800262451,27403.008170366287,33168,0,27403.008170366287,0.26978076,2472,0.0908740072715455,30069.59238386154,0.20455477,0.0765743657943205,0.46298337,5348,0.1396931751257518 -2793.5609545707703,1.1512136459350586,28843.57023668289,35039,0,28843.57023668289,0.26430643,2472,0.0884772408750228,31639.75147986412,0.18673988,0.070781465569247,0.44775584,5348,0.1340065844733869 -2922.4401302337646,1.2019224166870115,30284.19582605362,36922,0,30284.19582605362,0.25393718,2472,0.0839680701968192,33209.37820768356,0.17774187,0.0659739017195099,0.44245794,5348,0.1318246328818174 -3050.8781111240387,1.257685661315918,31724.767939329147,38794,0,31724.767939329147,0.24846429,2472,0.0833790343875043,34778.51620674133,0.16964507,0.0657891903531438,0.43418646,5348,0.1291116753719455 -3180.488514661789,1.3088583946228027,33164.96550655365,40666,0,33164.96550655365,0.24406633,2472,0.0825868827818739,36348.4490673542,0.17991582,0.0673838558744427,0.4272925,5348,0.1278951890863801 -3309.295679807663,1.3665189743041992,34604.94077825546,42536,0,34604.94077825546,0.24332765,2472,0.0804135437612983,37917.36299943924,0.17591874,0.065198274999605,0.42120722,5348,0.1239367813317628 -3439.328738212585,1.4259233474731443,36044.85700464249,44398,0,36044.85700464249,0.23402373,2472,0.0777121036703024,39487.44536948204,0.1598213,0.0607650590540974,0.4153037,5348,0.1243615860664047 -3569.7060022354126,1.4790055751800537,37485.1546869278,46259,0,37485.1546869278,0.22412066,2472,0.0761887351979363,41058.24700117111,0.15301697,0.0589560528388082,0.40488496,5348,0.1207507458219488 -3700.0979936122894,1.5350711345672607,38925.27982378006,48120,0,38925.27982378006,0.21684836,2472,0.0725529624438892,42628.89527773857,0.13564813,0.050873950467339,0.39561796,5348,0.1174295451692943 -3833.651977300644,1.5921378135681152,40365.31974339485,49986,0,40365.31974339485,0.21637918,2472,0.0717404992586273,44202.62142920494,0.13670734,0.0511201691746605,0.38398,5348,0.1154310319858655 -3962.5165877342224,1.646233320236206,41805.59288787842,51848,0,41805.59288787842,0.21018149,2472,0.0695468486584201,45771.889213085175,0.13435693,0.050879230706896,0.37354812,5348,0.1099471890477615 -4092.0777394771576,1.704374074935913,43245.52307033539,53703,0,43245.52307033539,0.20350519,2472,0.0677594296508439,47341.51375102997,0.1186261,0.0446542677961897,0.36797065,5348,0.1078038560684321 -4223.207216978073,1.7623958587646484,44685.69279670715,55563,0,44685.69279670715,0.19598241,2472,0.0648954969227956,48912.9459066391,0.11973902,0.0472273716241127,0.36250752,5348,0.1065101325583865 -4353.911198377609,1.8190128803253167,46125.87220811844,57412,0,46125.87220811844,0.19371364,2472,0.0641439684764284,50483.96151971817,0.13343582,0.048016934638889,0.34958616,5348,0.1037102831709742 -4484.790504455566,1.8767881393432613,47566.122009038925,59270,0,47566.122009038925,0.184741,2472,0.0619706294558527,52055.225329875946,0.08705975,0.0339719108044844,0.34289178,5348,0.1005628662734004 -4616.315147399902,1.937981128692627,49006.60219502449,61129,0,49006.60219502449,0.17703222,2472,0.0590051388296467,53627.36594891548,0.088854894,0.0341538013399359,0.33372676,5348,0.0980816204369696 -4747.2524383068085,1.996875286102295,50447.22417402268,62978,0,50447.22417402268,0.17833579,2472,0.0587410882944366,55199.06018662453,0.107303254,0.0416680418495659,0.33158988,5348,0.0967106597024436 -4877.739913702011,2.052243232727051,51887.12818932533,64826,0,51887.12818932533,0.16995475,2472,0.0560193366238092,56769.58395195007,0.101906076,0.0387572157318086,0.31840345,5348,0.0927425972947662 -5006.407078027725,2.1070005893707275,53327.30903625488,66672,0,53327.30903625488,0.1668592,2472,0.0543944102532854,58338.56297969818,0.11785042,0.0457300787328425,0.30807304,5348,0.0889290093360495 -5135.344121932983,2.1641712188720703,54767.38990926743,68524,0,54767.38990926743,0.1638089,2472,0.0525257449271829,59907.71535515785,0.092671074,0.0340948951196127,0.30833793,5348,0.0882918022340867 -5263.575304508209,2.2221508026123047,56207.60083270073,70376,0,56207.60083270073,0.1597395,2472,0.0516523470030264,61476.29270792008,0.08907641,0.034411574026683,0.30346766,5348,0.0865829286424592 -5394.533260583878,2.282556772232056,57647.67440390587,72210,0,57647.67440390587,0.15752408,2472,0.0510633111937115,63047.46157693863,0.066054076,0.0250100868816741,0.2977439,5348,0.0852988597854736 -5527.584671020508,2.346060037612915,59088.03487086296,74052,0,59088.03487086296,0.15448758,2472,0.04960087746024,64621.01479768753,0.078413114,0.0301152721978346,0.29437685,5348,0.0843913223978296 -5660.3452179431915,2.4161722660064697,60528.329884529114,75894,0,60528.329884529114,0.15408525,2472,0.048768102695346614,66194.21731305122,0.06854972,0.025352037408149633,0.29016793,5348,0.08304932562248375 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/measurements.csv deleted file mode 100644 index 5e761d637..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/measurements.csv +++ /dev/null @@ -1,811 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,42.290123,31.43843,,,,,,,,,,,,,, -1,,,32.398228,1.135258711138108,30.757864,1.1779352558965794,5348.0,30.871214,1.190014827453131,2472.0,35.92231845855713,176.72045016288757,35.92231845855713,140.7980694770813,0.0,0.0 -100,27.528732,7.335051,,,,,,,,,,,,,, -200,3.0503874,6.13145,,,,,,,,,,,,,, -300,0.40863764,5.855577,,,,,,,,,,,,,, -400,0.4015501,5.8107643,,,,,,,,,,,,,, -500,0.4725317,5.7828555,,,,,,,,,,,,,, -600,0.25514457,5.8039923,,,,,,,,,,,,,, -700,0.5692391,5.7931705,,,,,,,,,,,,,, -800,0.3332073,5.7883058,,,,,,,,,,,,,, -900,0.35689464,5.809726,,,,,,,,,,,,,, -1000,0.2920481,5.7789702,,,,,,,,,,,,,, -1100,0.4651395,5.7573323,,,,,,,,,,,,,, -1200,0.4742976,5.7777677,,,,,,,,,,,,,, -1300,0.2788835,5.7864156,,,,,,,,,,,,,, -1400,2.034634,5.7203965,,,,,,,,,,,,,, -1500,0.41400576,5.556803,,,,,,,,,,,,,, -1600,1.0933832,5.4591346,,,,,,,,,,,,,, -1679,,,6.0571713,0.9391896477614642,6.13424,0.8966179750330672,5348.0,6.0640683,0.899579550301627,2472.0,1475.9753279685974,1727.6220135688782,1475.9753279685974,251.5450747013092,0.0311036109924316,0.0 -1700,2.4329207,5.375963,,,,,,,,,,,,,, -1800,1.4875903,4.9825754,,,,,,,,,,,,,, -1900,0.77271163,4.448923,,,,,,,,,,,,,, -2000,1.3467507,4.018181,,,,,,,,,,,,,, -2100,1.0312144,3.750053,,,,,,,,,,,,,, -2200,1.1669005,3.521702,,,,,,,,,,,,,, -2300,1.1415144,3.4305687,,,,,,,,,,,,,, -2400,1.3135147,3.2970698,,,,,,,,,,,,,, -2500,1.5464432,3.2183144,,,,,,,,,,,,,, -2600,1.8668762,3.0826218,,,,,,,,,,,,,, -2700,1.151943,3.0202012,,,,,,,,,,,,,, -2800,1.1324735,2.9048872,,,,,,,,,,,,,, -2900,1.3365039,2.8996167,,,,,,,,,,,,,, -3000,1.1041831,2.8490849,,,,,,,,,,,,,, -3100,1.169428,2.835411,,,,,,,,,,,,,, -3200,0.9815049,2.6829748,,,,,,,,,,,,,, -3300,1.0025095,2.7077243,,,,,,,,,,,,,, -3379,,,3.853232,0.8365734501347709,3.802879,0.7978219102696544,5348.0,3.4898498,0.754087705400849,2472.0,2916.449270963669,3285.948108673096,2916.449270963669,369.2691359519959,0.0840563774108886,0.0 -3400,1.1277499,2.6327043,,,,,,,,,,,,,, -3500,1.2178351,2.60299,,,,,,,,,,,,,, -3600,0.9604911,2.606831,,,,,,,,,,,,,, -3700,1.4058292,2.513371,,,,,,,,,,,,,, -3800,1.0332165,2.427441,,,,,,,,,,,,,, -3900,0.9986699,2.452695,,,,,,,,,,,,,, -4000,0.9403914,2.4348774,,,,,,,,,,,,,, -4100,0.92526615,2.36221,,,,,,,,,,,,,, -4200,0.883022,2.2727208,,,,,,,,,,,,,, -4300,0.97720265,2.302027,,,,,,,,,,,,,, -4400,0.91043055,2.2222166,,,,,,,,,,,,,, -4500,0.8562314,2.2107685,,,,,,,,,,,,,, -4600,1.2324126,2.2071984,,,,,,,,,,,,,, -4700,1.0589583,2.175306,,,,,,,,,,,,,, -4800,1.0858363,2.0766408,,,,,,,,,,,,,, -4900,0.8542726,2.1111555,,,,,,,,,,,,,, -5000,1.6755131,2.1511607,,,,,,,,,,,,,, -5053,,,1.3261416,0.3929169378299152,1.3397774,0.3746005387296408,5348.0,1.0181005,0.3211666971340361,2472.0,4357.660209178925,4868.666774988174,4357.660209178925,510.6455659866333,0.1421759128570556,0.0 -5100,0.81705064,2.133006,,,,,,,,,,,,,, -5200,0.84275,2.0636625,,,,,,,,,,,,,, -5300,0.90073556,1.9482601,,,,,,,,,,,,,, -5400,1.040381,2.0428448,,,,,,,,,,,,,, -5500,0.78563577,1.9888215,,,,,,,,,,,,,, -5600,1.1071113,1.9667712,,,,,,,,,,,,,, -5700,0.8218348,1.9298021,,,,,,,,,,,,,, -5800,0.7822499,1.9268495,,,,,,,,,,,,,, -5900,0.78031754,1.8682913,,,,,,,,,,,,,, -6000,0.7338767,1.869928,,,,,,,,,,,,,, -6100,0.82738864,1.8623765,,,,,,,,,,,,,, -6200,0.70377636,1.9048767,,,,,,,,,,,,,, -6300,0.7893164,1.8497566,,,,,,,,,,,,,, -6400,0.9306888,1.7876009,,,,,,,,,,,,,, -6500,0.736001,1.8031998,,,,,,,,,,,,,, -6600,0.66010684,1.781207,,,,,,,,,,,,,, -6700,0.8956347,1.8327987,,,,,,,,,,,,,, -6745,,,0.7336726,0.2415297145928209,0.87219507,0.2599418789885785,5348.0,0.596215,0.199540958300327,2472.0,5798.374348640442,6448.203431367874,5798.374348640442,649.3420903682709,0.194298505783081,0.0 -6800,0.73647213,1.793342,,,,,,,,,,,,,, -6900,0.7792092,1.7674373,,,,,,,,,,,,,, -7000,0.7807924,1.7757868,,,,,,,,,,,,,, -7100,0.7539056,1.7966701,,,,,,,,,,,,,, -7200,0.67776614,1.7717731,,,,,,,,,,,,,, -7300,0.7392372,1.7574027,,,,,,,,,,,,,, -7400,0.6945283,1.6667286,,,,,,,,,,,,,, -7500,0.68424535,1.7261888,,,,,,,,,,,,,, -7600,0.64707774,1.7231413,,,,,,,,,,,,,, -7700,0.709893,1.6980002,,,,,,,,,,,,,, -7800,0.73837626,1.7686465,,,,,,,,,,,,,, -7900,0.77557015,1.7585322,,,,,,,,,,,,,, -8000,0.7543948,1.7501537,,,,,,,,,,,,,, -8100,0.69861305,1.6871837,,,,,,,,,,,,,, -8200,0.7452908,1.7275531,,,,,,,,,,,,,, -8300,0.7630697,1.6674131,,,,,,,,,,,,,, -8400,0.6648023,1.6818187,,,,,,,,,,,,,, -8450,,,0.61011267,0.2042115852891749,0.7665455,0.2304662231962694,5348.0,0.51278436,0.1729734121422623,2472.0,7238.520888566971,8024.218991994858,7238.520888566971,785.0801012516022,0.249565839767456,0.0 -8500,0.67457664,1.7251128,,,,,,,,,,,,,, -8600,0.6851994,1.6582289,,,,,,,,,,,,,, -8700,0.70248574,1.6750026,,,,,,,,,,,,,, -8800,0.7176679,1.695174,,,,,,,,,,,,,, -8900,0.66515285,1.6151388,,,,,,,,,,,,,, -9000,0.67829347,1.6192911,,,,,,,,,,,,,, -9100,0.719845,1.7002621,,,,,,,,,,,,,, -9200,0.6816742,1.6427957,,,,,,,,,,,,,, -9300,0.63111985,1.647549,,,,,,,,,,,,,, -9400,0.6436275,1.6007963,,,,,,,,,,,,,, -9500,0.7110431,1.6677588,,,,,,,,,,,,,, -9600,0.6937712,1.6127417,,,,,,,,,,,,,, -9700,0.6735477,1.676089,,,,,,,,,,,,,, -9800,0.7629146,1.5808539,,,,,,,,,,,,,, -9900,0.78252506,1.6647528,,,,,,,,,,,,,, -10000,0.8895164,1.6293044,,,,,,,,,,,,,, -10100,0.73694575,1.6050974,,,,,,,,,,,,,, -10145,,,0.5839462,0.1999528479289575,0.6855125,0.2087335991581142,5348.0,0.44745144,0.1521540430199256,2472.0,8678.477567434311,9599.215919017792,8678.477567434311,919.9031093120576,0.3922488689422607,0.0 -10200,0.7329075,1.6284858,,,,,,,,,,,,,, -10300,0.6052306,1.5735219,,,,,,,,,,,,,, -10400,0.669713,1.580095,,,,,,,,,,,,,, -10500,0.7074447,1.5438936,,,,,,,,,,,,,, -10600,0.6648802,1.579466,,,,,,,,,,,,,, -10700,0.6312191,1.5712235,,,,,,,,,,,,,, -10800,0.5903329,1.564744,,,,,,,,,,,,,, -10900,0.5971906,1.6081424,,,,,,,,,,,,,, -11000,0.6044624,1.5523878,,,,,,,,,,,,,, -11100,0.7548086,1.5111384,,,,,,,,,,,,,, -11200,0.5978002,1.5725676,,,,,,,,,,,,,, -11300,0.6524346,1.5544161,,,,,,,,,,,,,, -11400,0.5878955,1.5302104,,,,,,,,,,,,,, -11500,0.5721269,1.526609,,,,,,,,,,,,,, -11600,0.6787415,1.4736173,,,,,,,,,,,,,, -11700,0.6066799,1.4860288,,,,,,,,,,,,,, -11800,0.6561988,1.5031948,,,,,,,,,,,,,, -11861,,,0.50584203,0.1745268731560112,0.65349907,0.198181063363488,5348.0,0.41517326,0.1424044847967826,2472.0,10118.87825846672,11178.391873121262,10118.87825846672,1058.5496191978457,0.4467132091522217,0.0 -11900,0.5635231,1.5269713,,,,,,,,,,,,,, -12000,0.58368266,1.5352095,,,,,,,,,,,,,, -12100,0.5491286,1.5592002,,,,,,,,,,,,,, -12200,0.5588003,1.53188,,,,,,,,,,,,,, -12300,0.6168115,1.5496104,,,,,,,,,,,,,, -12400,0.68302506,1.4752057,,,,,,,,,,,,,, -12500,0.5749675,1.4991534,,,,,,,,,,,,,, -12600,0.67154473,1.5389413,,,,,,,,,,,,,, -12700,0.69316167,1.4893736,,,,,,,,,,,,,, -12800,0.5765326,1.4376298,,,,,,,,,,,,,, -12900,0.6527337,1.5197768,,,,,,,,,,,,,, -13000,0.64635825,1.5305562,,,,,,,,,,,,,, -13100,0.61801493,1.4828993,,,,,,,,,,,,,, -13200,0.7153703,1.43682,,,,,,,,,,,,,, -13300,0.7547834,1.533427,,,,,,,,,,,,,, -13400,0.619493,1.4028891,,,,,,,,,,,,,, -13500,0.6426565,1.4507003,,,,,,,,,,,,,, -13572,,,0.49853423,0.1681780210600365,0.62621295,0.1889801789972677,5348.0,0.39062923,0.1336298823959539,2472.0,11559.40121126175,12755.429028272629,11559.40121126175,1194.934509754181,0.5001969337463379,0.0 -13600,0.66865,1.5175419,,,,,,,,,,,,,, -13700,0.5770279,1.5031085,,,,,,,,,,,,,, -13800,0.6434481,1.4880028,,,,,,,,,,,,,, -13900,0.57108307,1.4944036,,,,,,,,,,,,,, -14000,0.69046,1.4459673,,,,,,,,,,,,,, -14100,0.6282362,1.5171412,,,,,,,,,,,,,, -14200,0.69758564,1.4710155,,,,,,,,,,,,,, -14300,0.5458021,1.5503124,,,,,,,,,,,,,, -14400,0.5976816,1.4777025,,,,,,,,,,,,,, -14500,0.71175295,1.453401,,,,,,,,,,,,,, -14600,0.7238355,1.455485,,,,,,,,,,,,,, -14700,0.64746016,1.4399805,,,,,,,,,,,,,, -14800,0.7199216,1.4979084,,,,,,,,,,,,,, -14900,0.6694257,1.4733505,,,,,,,,,,,,,, -15000,0.5977422,1.4110006,,,,,,,,,,,,,, -15100,0.6759942,1.4622562,,,,,,,,,,,,,, -15200,0.773009,1.4358013,,,,,,,,,,,,,, -15270,,,0.41710535,0.1461387617891493,0.58882713,0.1800399702636685,5348.0,0.36876667,0.1259724168748603,2472.0,12999.642688512802,14332.771677017212,12999.642688512802,1331.903869152069,0.5556769371032715,0.0 -15300,0.73751223,1.5066698,,,,,,,,,,,,,, -15400,0.5957108,1.5162507,,,,,,,,,,,,,, -15500,0.5794882,1.4485607,,,,,,,,,,,,,, -15600,0.67958367,1.4795802,,,,,,,,,,,,,, -15700,0.5311117,1.468126,,,,,,,,,,,,,, -15800,0.64665234,1.4534763,,,,,,,,,,,,,, -15900,0.6671011,1.4813081,,,,,,,,,,,,,, -16000,0.69545496,1.4834131,,,,,,,,,,,,,, -16100,0.6289571,1.3872522,,,,,,,,,,,,,, -16200,0.6305627,1.4431024,,,,,,,,,,,,,, -16300,0.66075075,1.4594556,,,,,,,,,,,,,, -16400,0.61742413,1.440809,,,,,,,,,,,,,, -16500,0.6183563,1.4299144,,,,,,,,,,,,,, -16600,0.66778654,1.4104892,,,,,,,,,,,,,, -16700,0.7475284,1.4550276,,,,,,,,,,,,,, -16800,0.64564055,1.4666793,,,,,,,,,,,,,, -16900,0.5986582,1.4035298,,,,,,,,,,,,,, -16979,,,0.4158837,0.1474966401676499,0.572075,0.1737258271624009,5348.0,0.36074033,0.1234740925801799,2472.0,14440.339807987211,15907.82964539528,14440.339807987211,1466.1310760974884,0.6140539646148682,0.0 -17000,0.5393336,1.4752555,,,,,,,,,,,,,, -17100,0.596478,1.4391606,,,,,,,,,,,,,, -17200,0.8223939,1.4412388,,,,,,,,,,,,,, -17300,0.78568524,1.3604363,,,,,,,,,,,,,, -17400,0.5829436,1.4562644,,,,,,,,,,,,,, -17500,0.82024014,1.461452,,,,,,,,,,,,,, -17600,0.7048167,1.4415607,,,,,,,,,,,,,, -17700,0.69897467,1.4546733,,,,,,,,,,,,,, -17800,0.86463934,1.39237,,,,,,,,,,,,,, -17900,0.6238633,1.4595317,,,,,,,,,,,,,, -18000,0.6937502,1.464735,,,,,,,,,,,,,, -18100,0.7831234,1.4280462,,,,,,,,,,,,,, -18200,0.88055015,1.4289604,,,,,,,,,,,,,, -18300,0.69654894,1.416511,,,,,,,,,,,,,, -18400,0.7234278,1.468528,,,,,,,,,,,,,, -18500,0.6533745,1.4035374,,,,,,,,,,,,,, -18600,0.58153963,1.3253224,,,,,,,,,,,,,, -18678,,,0.39677075,0.1402287030102502,0.55821544,0.1698350019792038,5348.0,0.3428115,0.1159588081165072,2472.0,15880.586597919464,17480.91488957405,15880.586597919464,1598.8436088562012,0.6656575202941895,0.0 -18700,0.7047321,1.3955525,,,,,,,,,,,,,, -18800,0.6887988,1.3970867,,,,,,,,,,,,,, -18900,0.65514416,1.3565503,,,,,,,,,,,,,, -19000,0.7292319,1.4089793,,,,,,,,,,,,,, -19100,0.81517285,1.4594425,,,,,,,,,,,,,, -19200,0.5988373,1.4267226,,,,,,,,,,,,,, -19300,0.6203704,1.3734083,,,,,,,,,,,,,, -19400,0.59101725,1.4535462,,,,,,,,,,,,,, -19500,0.73042977,1.3890133,,,,,,,,,,,,,, -19600,0.66161394,1.3652087,,,,,,,,,,,,,, -19700,0.55576444,1.3700795,,,,,,,,,,,,,, -19800,0.75454974,1.3898358,,,,,,,,,,,,,, -19900,0.69567764,1.3808777,,,,,,,,,,,,,, -20000,0.6879813,1.3428029,,,,,,,,,,,,,, -20100,0.7599076,1.3826628,,,,,,,,,,,,,, -20200,0.6327963,1.3413254,,,,,,,,,,,,,, -20300,0.5970582,1.4365913,,,,,,,,,,,,,, -20395,,,0.26701897,0.0973689141082947,0.53861165,0.1639842822248182,5348.0,0.3305695,0.111652753234619,2472.0,17321.084460258484,19067.70646309853,17321.084460258484,1745.0095345973969,0.7188072204589844,0.0 -20400,0.7400009,1.4209553,,,,,,,,,,,,,, -20500,0.6950381,1.3531386,,,,,,,,,,,,,, -20600,0.6577782,1.3218426,,,,,,,,,,,,,, -20700,0.5575299,1.3482779,,,,,,,,,,,,,, -20800,0.6775738,1.3702526,,,,,,,,,,,,,, -20900,0.70115465,1.4317583,,,,,,,,,,,,,, -21000,0.7137291,1.3884857,,,,,,,,,,,,,, -21100,0.57094,1.3298259,,,,,,,,,,,,,, -21200,0.7171504,1.3494374,,,,,,,,,,,,,, -21300,0.7660911,1.4032998,,,,,,,,,,,,,, -21400,0.703141,1.3882971,,,,,,,,,,,,,, -21500,0.7075198,1.382075,,,,,,,,,,,,,, -21600,0.6276935,1.3461657,,,,,,,,,,,,,, -21700,0.65003496,1.3758591,,,,,,,,,,,,,, -21800,0.7054726,1.3987021,,,,,,,,,,,,,, -21900,0.61926544,1.3529679,,,,,,,,,,,,,, -22000,0.6901061,1.360996,,,,,,,,,,,,,, -22100,0.7958336,1.4227939,,,,,,,,,,,,,, -22125,,,0.25013727,0.0908888957771509,0.5297374,0.1605472257354432,5348.0,0.32201567,0.1093575447362541,2472.0,18761.54176592827,20646.577522277832,18761.54176592827,1883.2937757968905,0.7721741199493408,0.0 -22200,0.66090876,1.3450959,,,,,,,,,,,,,, -22300,0.80734116,1.343235,,,,,,,,,,,,,, -22400,0.7096839,1.4094601,,,,,,,,,,,,,, -22500,0.7381149,1.3544043,,,,,,,,,,,,,, -22600,0.6102702,1.3817078,,,,,,,,,,,,,, -22700,0.66788507,1.3264731,,,,,,,,,,,,,, -22800,0.618112,1.3629262,,,,,,,,,,,,,, -22900,0.72052,1.3778708,,,,,,,,,,,,,, -23000,0.752791,1.4673834,,,,,,,,,,,,,, -23100,0.7061305,1.3189673,,,,,,,,,,,,,, -23200,0.6655501,1.3511108,,,,,,,,,,,,,, -23300,0.62102467,1.3318987,,,,,,,,,,,,,, -23400,0.8454889,1.2951139,,,,,,,,,,,,,, -23500,0.65380895,1.3268342,,,,,,,,,,,,,, -23600,0.66880596,1.3429217,,,,,,,,,,,,,, -23700,0.74223006,1.316651,,,,,,,,,,,,,, -23800,0.6945261,1.3413833,,,,,,,,,,,,,, -23843,,,0.24616796,0.0902991392319867,0.5182299,0.1571294785521882,5348.0,0.31246486,0.1053561635488392,2472.0,20202.07715177536,22227.13118314743,20202.07715177536,2023.181410074234,0.8276913166046143,0.0 -23900,0.7037927,1.281345,,,,,,,,,,,,,, -24000,0.654366,1.2846652,,,,,,,,,,,,,, -24100,0.58218277,1.3091404,,,,,,,,,,,,,, -24200,0.7408485,1.4057558,,,,,,,,,,,,,, -24300,0.76476943,1.3876086,,,,,,,,,,,,,, -24400,0.78166395,1.3350563,,,,,,,,,,,,,, -24500,0.6583998,1.3191221,,,,,,,,,,,,,, -24600,0.6697755,1.3366636,,,,,,,,,,,,,, -24700,0.65726185,1.346269,,,,,,,,,,,,,, -24800,0.5511576,1.3265699,,,,,,,,,,,,,, -24900,0.6603741,1.2940761,,,,,,,,,,,,,, -25000,0.634105,1.3016717,,,,,,,,,,,,,, -25100,0.57106966,1.2673496,,,,,,,,,,,,,, -25200,0.639901,1.354671,,,,,,,,,,,,,, -25300,0.7066547,1.3478079,,,,,,,,,,,,,, -25400,0.5926883,1.2735343,,,,,,,,,,,,,, -25500,0.61951673,1.3444374,,,,,,,,,,,,,, -25600,0.63613814,1.3831085,,,,,,,,,,,,,, -25659,,,0.23114382,0.0857162323641013,0.5024098,0.1519256205528254,5348.0,0.30169642,0.1013954055206873,2472.0,21642.539937257767,23796.32274031639,21642.539937257767,2151.7861466407776,0.880361795425415,0.0 -25700,0.69621074,1.2868179,,,,,,,,,,,,,, -25800,0.6841633,1.2969906,,,,,,,,,,,,,, -25900,0.64078516,1.3293815,,,,,,,,,,,,,, -26000,0.66446424,1.2969514,,,,,,,,,,,,,, -26100,0.63778025,1.2850897,,,,,,,,,,,,,, -26200,0.8450199,1.366396,,,,,,,,,,,,,, -26300,0.7107547,1.3164654,,,,,,,,,,,,,, -26400,0.65166044,1.2922758,,,,,,,,,,,,,, -26500,0.94068485,1.3054693,,,,,,,,,,,,,, -26600,0.766703,1.2833583,,,,,,,,,,,,,, -26700,0.5993786,1.3151909,,,,,,,,,,,,,, -26800,0.5578881,1.2572408,,,,,,,,,,,,,, -26900,0.712542,1.3095638,,,,,,,,,,,,,, -27000,0.6324024,1.2546116,,,,,,,,,,,,,, -27100,0.8970673,1.2970675,,,,,,,,,,,,,, -27200,0.57879245,1.2712963,,,,,,,,,,,,,, -27300,0.75012666,1.2905395,,,,,,,,,,,,,, -27400,0.9794683,1.3072087,,,,,,,,,,,,,, -27500,0.7221001,1.276228,,,,,,,,,,,,,, -27534,,,0.22649139,0.0850088484254067,0.49297327,0.150226401614258,5348.0,0.29015937,0.0991814433408486,2472.0,23082.5641746521,25365.021073818207,23082.5641746521,2280.3359375,0.935565948486328,0.0 -27600,0.61400205,1.2677168,,,,,,,,,,,,,, -27700,0.6931194,1.29453,,,,,,,,,,,,,, -27800,0.59297013,1.2960223,,,,,,,,,,,,,, -27900,0.86120445,1.2703806,,,,,,,,,,,,,, -28000,0.6629317,1.2500364,,,,,,,,,,,,,, -28100,0.7613835,1.3599036,,,,,,,,,,,,,, -28200,0.6740895,1.3584658,,,,,,,,,,,,,, -28300,0.8697842,1.3224571,,,,,,,,,,,,,, -28400,0.651515,1.3285054,,,,,,,,,,,,,, -28500,0.7691136,1.3216822,,,,,,,,,,,,,, -28600,0.6452589,1.3106753,,,,,,,,,,,,,, -28700,0.7310135,1.2959713,,,,,,,,,,,,,, -28800,0.65891033,1.2919956,,,,,,,,,,,,,, -28900,0.58908653,1.2228552,,,,,,,,,,,,,, -29000,0.6720991,1.2941265,,,,,,,,,,,,,, -29100,0.6521589,1.2838655,,,,,,,,,,,,,, -29200,0.68085825,1.3164148,,,,,,,,,,,,,, -29300,0.6265952,1.2858962,,,,,,,,,,,,,, -29400,0.5847771,1.256889,,,,,,,,,,,,,, -29414,,,0.20505889,0.0763605881515613,0.4759323,0.1427923187580254,5348.0,0.28078052,0.0932098389291735,2472.0,24523.00218605995,26934.803947925568,24523.00218605995,2409.5551302433014,0.989698886871338,0.0 -29500,0.70135593,1.3265357,,,,,,,,,,,,,, -29600,0.71925116,1.2874509,,,,,,,,,,,,,, -29700,0.7214335,1.24623,,,,,,,,,,,,,, -29800,0.7014208,1.3177782,,,,,,,,,,,,,, -29900,0.64788157,1.2388443,,,,,,,,,,,,,, -30000,0.6427084,1.2745489,,,,,,,,,,,,,, -30100,0.7079259,1.3106593,,,,,,,,,,,,,, -30200,0.71893805,1.278129,,,,,,,,,,,,,, -30300,0.7514759,1.311683,,,,,,,,,,,,,, -30400,0.6111921,1.280621,,,,,,,,,,,,,, -30500,0.65168655,1.3277074,,,,,,,,,,,,,, -30600,0.6204156,1.2878915,,,,,,,,,,,,,, -30700,0.7193186,1.2722619,,,,,,,,,,,,,, -30800,0.69204676,1.3137294,,,,,,,,,,,,,, -30900,0.70750374,1.2603543,,,,,,,,,,,,,, -31000,0.6546605,1.2404562,,,,,,,,,,,,,, -31100,0.6690622,1.2729234,,,,,,,,,,,,,, -31200,0.7141405,1.3054929,,,,,,,,,,,,,, -31290,,,0.24940795,0.0886371916240273,0.47313604,0.14384467594157,5348.0,0.2828514,0.095119127414539,2472.0,25962.93151783943,28501.7939991951,25962.93151783943,2536.493052005768,1.041748285293579,0.0 -31300,0.7157777,1.2947631,,,,,,,,,,,,,, -31400,0.83905834,1.2594582,,,,,,,,,,,,,, -31500,0.638585,1.2068113,,,,,,,,,,,,,, -31600,0.6919557,1.2668496,,,,,,,,,,,,,, -31700,0.6722349,1.2863338,,,,,,,,,,,,,, -31800,0.62665886,1.2684679,,,,,,,,,,,,,, -31900,0.60385317,1.2819471,,,,,,,,,,,,,, -32000,0.65376353,1.2250814,,,,,,,,,,,,,, -32100,0.79370475,1.2390395,,,,,,,,,,,,,, -32200,0.7768178,1.2712865,,,,,,,,,,,,,, -32300,0.5742007,1.2020681,,,,,,,,,,,,,, -32400,0.5190186,1.2112719,,,,,,,,,,,,,, -32500,0.6791407,1.2370815,,,,,,,,,,,,,, -32600,0.6317014,1.2732215,,,,,,,,,,,,,, -32700,0.7122878,1.2500192,,,,,,,,,,,,,, -32800,0.65313,1.293017,,,,,,,,,,,,,, -32900,0.601326,1.2461519,,,,,,,,,,,,,, -33000,0.72133285,1.2459985,,,,,,,,,,,,,, -33100,0.7930845,1.3002423,,,,,,,,,,,,,, -33168,,,0.20455477,0.0765743657943205,0.46298337,0.1396931751257518,5348.0,0.26978076,0.0908740072715455,2472.0,27403.008170366287,30069.59238386154,27403.008170366287,2664.0866878032684,1.0995814800262451,0.0 -33200,0.7096338,1.1806203,,,,,,,,,,,,,, -33300,0.7008545,1.2647759,,,,,,,,,,,,,, -33400,0.66844314,1.26859,,,,,,,,,,,,,, -33500,0.62189025,1.2989931,,,,,,,,,,,,,, -33600,0.7127594,1.276009,,,,,,,,,,,,,, -33700,0.7304514,1.3045473,,,,,,,,,,,,,, -33800,0.70469415,1.3163664,,,,,,,,,,,,,, -33900,0.68787074,1.2462262,,,,,,,,,,,,,, -34000,0.7308013,1.2995995,,,,,,,,,,,,,, -34100,0.8058815,1.2463833,,,,,,,,,,,,,, -34200,0.68321085,1.1773171,,,,,,,,,,,,,, -34300,0.61415267,1.2071675,,,,,,,,,,,,,, -34400,0.73854434,1.2436309,,,,,,,,,,,,,, -34500,0.76180094,1.2297841,,,,,,,,,,,,,, -34600,0.8350382,1.2865769,,,,,,,,,,,,,, -34700,0.7102624,1.2063556,,,,,,,,,,,,,, -34800,0.6342452,1.2327142,,,,,,,,,,,,,, -34900,0.6670668,1.2250787,,,,,,,,,,,,,, -35000,0.7667671,1.264874,,,,,,,,,,,,,, -35039,,,0.18673988,0.070781465569247,0.44775584,0.1340065844733869,5348.0,0.26430643,0.0884772408750228,2472.0,28843.57023668289,31639.75147986412,28843.57023668289,2793.5609545707703,1.1512136459350586,0.0 -35100,0.6480618,1.2563344,,,,,,,,,,,,,, -35200,0.6502942,1.2193006,,,,,,,,,,,,,, -35300,0.7519408,1.2228884,,,,,,,,,,,,,, -35400,0.6891857,1.2921892,,,,,,,,,,,,,, -35500,0.69267714,1.2043877,,,,,,,,,,,,,, -35600,0.736822,1.2675173,,,,,,,,,,,,,, -35700,0.6718229,1.1804402,,,,,,,,,,,,,, -35800,0.69823503,1.1950269,,,,,,,,,,,,,, -35900,0.629109,1.1904652,,,,,,,,,,,,,, -36000,0.668801,1.2469157,,,,,,,,,,,,,, -36100,0.6698596,1.2107497,,,,,,,,,,,,,, -36200,0.6332615,1.2238947,,,,,,,,,,,,,, -36300,0.6873796,1.2011715,,,,,,,,,,,,,, -36400,0.74236995,1.2288388,,,,,,,,,,,,,, -36500,0.66271865,1.2244143,,,,,,,,,,,,,, -36600,0.623509,1.2262433,,,,,,,,,,,,,, -36700,0.72983474,1.2342955,,,,,,,,,,,,,, -36800,0.617646,1.1690443,,,,,,,,,,,,,, -36900,0.731971,1.184098,,,,,,,,,,,,,, -36922,,,0.17774187,0.0659739017195099,0.44245794,0.1318246328818174,5348.0,0.25393718,0.0839680701968192,2472.0,30284.19582605362,33209.37820768356,30284.19582605362,2922.4401302337646,1.2019224166870115,0.0 -37000,0.69827837,1.2512237,,,,,,,,,,,,,, -37100,0.6367336,1.1997362,,,,,,,,,,,,,, -37200,0.7380319,1.253748,,,,,,,,,,,,,, -37300,0.7301571,1.2628121,,,,,,,,,,,,,, -37400,0.6815397,1.1905235,,,,,,,,,,,,,, -37500,0.6968535,1.2035671,,,,,,,,,,,,,, -37600,0.6378319,1.1440319,,,,,,,,,,,,,, -37700,0.7719328,1.1892058,,,,,,,,,,,,,, -37800,0.8077153,1.1960307,,,,,,,,,,,,,, -37900,0.68459046,1.234529,,,,,,,,,,,,,, -38000,0.83322924,1.2691119,,,,,,,,,,,,,, -38100,0.8285446,1.2570955,,,,,,,,,,,,,, -38200,0.6982385,1.2086971,,,,,,,,,,,,,, -38300,0.82234627,1.2355788,,,,,,,,,,,,,, -38400,0.77350134,1.2422177,,,,,,,,,,,,,, -38500,0.7732206,1.2311312,,,,,,,,,,,,,, -38600,0.82582146,1.197312,,,,,,,,,,,,,, -38700,0.6716096,1.2137798,,,,,,,,,,,,,, -38794,,,0.16964507,0.0657891903531438,0.43418646,0.1291116753719455,5348.0,0.24846429,0.0833790343875043,2472.0,31724.767939329147,34778.51620674133,31724.767939329147,3050.8781111240387,1.257685661315918,0.0 -38800,0.7339001,1.260191,,,,,,,,,,,,,, -38900,0.68879586,1.1737144,,,,,,,,,,,,,, -39000,0.6727731,1.2263683,,,,,,,,,,,,,, -39100,0.65614533,1.1721203,,,,,,,,,,,,,, -39200,0.7434236,1.2675893,,,,,,,,,,,,,, -39300,0.686637,1.2350651,,,,,,,,,,,,,, -39400,0.6964,1.2061361,,,,,,,,,,,,,, -39500,0.8411998,1.1634555,,,,,,,,,,,,,, -39600,1.1357206,1.1828969,,,,,,,,,,,,,, -39700,0.69818646,1.193681,,,,,,,,,,,,,, -39800,0.7318955,1.2178074,,,,,,,,,,,,,, -39900,0.7547629,1.2219716,,,,,,,,,,,,,, -40000,0.8207997,1.2406566,,,,,,,,,,,,,, -40100,0.6392998,1.1285031,,,,,,,,,,,,,, -40200,0.739315,1.1326437,,,,,,,,,,,,,, -40300,0.6931354,1.1834154,,,,,,,,,,,,,, -40400,0.88644814,1.2043247,,,,,,,,,,,,,, -40500,0.6408633,1.1445066,,,,,,,,,,,,,, -40600,1.0790354,1.1377062,,,,,,,,,,,,,, -40666,,,0.17991582,0.0673838558744427,0.4272925,0.1278951890863801,5348.0,0.24406633,0.0825868827818739,2472.0,33164.96550655365,36348.4490673542,33164.96550655365,3180.488514661789,1.3088583946228027,0.0 -40700,0.75124645,1.1929051,,,,,,,,,,,,,, -40800,0.7343215,1.2070376,,,,,,,,,,,,,, -40900,0.73453695,1.2486355,,,,,,,,,,,,,, -41000,0.79592854,1.2278594,,,,,,,,,,,,,, -41100,0.72319883,1.2426636,,,,,,,,,,,,,, -41200,0.6745698,1.1827524,,,,,,,,,,,,,, -41300,0.6845741,1.179422,,,,,,,,,,,,,, -41400,0.6905844,1.157775,,,,,,,,,,,,,, -41500,0.66745204,1.1579576,,,,,,,,,,,,,, -41600,0.83568245,1.2382625,,,,,,,,,,,,,, -41700,0.91036546,1.2103988,,,,,,,,,,,,,, -41800,0.85356593,1.1324445,,,,,,,,,,,,,, -41900,0.666913,1.1782285,,,,,,,,,,,,,, -42000,0.76182145,1.1476538,,,,,,,,,,,,,, -42100,0.7157599,1.2072228,,,,,,,,,,,,,, -42200,0.70882374,1.2029065,,,,,,,,,,,,,, -42300,0.6903735,1.1913791,,,,,,,,,,,,,, -42400,0.7818971,1.1666135,,,,,,,,,,,,,, -42500,0.6829868,1.2063955,,,,,,,,,,,,,, -42536,,,0.17591874,0.065198274999605,0.42120722,0.1239367813317628,5348.0,0.24332765,0.0804135437612983,2472.0,34604.94077825546,37917.36299943924,34604.94077825546,3309.295679807663,1.3665189743041992,0.0 -42600,0.7837043,1.1752726,,,,,,,,,,,,,, -42700,0.7473615,1.1322823,,,,,,,,,,,,,, -42800,0.8413916,1.2012619,,,,,,,,,,,,,, -42900,0.7717091,1.1679939,,,,,,,,,,,,,, -43000,0.7251274,1.1264964,,,,,,,,,,,,,, -43100,1.0938286,1.1666207,,,,,,,,,,,,,, -43200,0.8917963,1.1357487,,,,,,,,,,,,,, -43300,0.6780526,1.1780683,,,,,,,,,,,,,, -43400,0.6684245,1.1770552,,,,,,,,,,,,,, -43500,0.7817321,1.1553543,,,,,,,,,,,,,, -43600,0.7174248,1.1549866,,,,,,,,,,,,,, -43700,0.7221889,1.142928,,,,,,,,,,,,,, -43800,0.7934199,1.1760844,,,,,,,,,,,,,, -43900,0.78238934,1.1624453,,,,,,,,,,,,,, -44000,0.90362495,1.176329,,,,,,,,,,,,,, -44100,0.796005,1.2181754,,,,,,,,,,,,,, -44200,0.7760202,1.0992775,,,,,,,,,,,,,, -44300,0.78716743,1.1628549,,,,,,,,,,,,,, -44398,,,0.1598213,0.0607650590540974,0.4153037,0.1243615860664047,5348.0,0.23402373,0.0777121036703024,2472.0,36044.85700464249,39487.44536948204,36044.85700464249,3439.328738212585,1.4259233474731443,0.0 -44400,0.6700384,1.173798,,,,,,,,,,,,,, -44500,0.90172684,1.1299398,,,,,,,,,,,,,, -44600,0.80190104,1.1602591,,,,,,,,,,,,,, -44700,0.7940263,1.18223,,,,,,,,,,,,,, -44800,0.7995196,1.1838696,,,,,,,,,,,,,, -44900,0.7994406,1.1608276,,,,,,,,,,,,,, -45000,0.68832684,1.1419578,,,,,,,,,,,,,, -45100,0.7768772,1.1915116,,,,,,,,,,,,,, -45200,0.7373601,1.1799296,,,,,,,,,,,,,, -45300,0.81368273,1.1558834,,,,,,,,,,,,,, -45400,0.7770932,1.1313452,,,,,,,,,,,,,, -45500,0.85188895,1.0930874,,,,,,,,,,,,,, -45600,0.904087,1.1145594,,,,,,,,,,,,,, -45700,0.74664956,1.1409756,,,,,,,,,,,,,, -45800,0.9019362,1.1537342,,,,,,,,,,,,,, -45900,0.88889325,1.1688911,,,,,,,,,,,,,, -46000,0.83903825,1.1474137,,,,,,,,,,,,,, -46100,1.253321,1.1240058,,,,,,,,,,,,,, -46200,0.84867144,1.1259134,,,,,,,,,,,,,, -46259,,,0.15301697,0.0589560528388082,0.40488496,0.1207507458219488,5348.0,0.22412066,0.0761887351979363,2472.0,37485.1546869278,41058.24700117111,37485.1546869278,3569.7060022354126,1.4790055751800537,0.0 -46300,0.6744496,1.0945642,,,,,,,,,,,,,, -46400,0.8457907,1.1610625,,,,,,,,,,,,,, -46500,0.7700271,1.1488115,,,,,,,,,,,,,, -46600,0.7575428,1.2191107,,,,,,,,,,,,,, -46700,0.73101836,1.0859859,,,,,,,,,,,,,, -46800,0.85682887,1.1437463,,,,,,,,,,,,,, -46900,0.8344765,1.135787,,,,,,,,,,,,,, -47000,0.8624994,1.2073101,,,,,,,,,,,,,, -47100,0.7419462,1.1929052,,,,,,,,,,,,,, -47200,0.8197152,1.1180872,,,,,,,,,,,,,, -47300,0.7581027,1.1659951,,,,,,,,,,,,,, -47400,0.8602353,1.103556,,,,,,,,,,,,,, -47500,0.73609245,1.1119722,,,,,,,,,,,,,, -47600,0.79466546,1.1143994,,,,,,,,,,,,,, -47700,0.79104066,1.1207912,,,,,,,,,,,,,, -47800,0.7240884,1.1315472,,,,,,,,,,,,,, -47900,0.9956946,1.1143146,,,,,,,,,,,,,, -48000,0.83760786,1.1126577,,,,,,,,,,,,,, -48100,0.8350878,1.1705735,,,,,,,,,,,,,, -48120,,,0.13564813,0.050873950467339,0.39561796,0.1174295451692943,5348.0,0.21684836,0.0725529624438892,2472.0,38925.27982378006,42628.89527773857,38925.27982378006,3700.0979936122894,1.5350711345672607,0.0 -48200,0.976769,1.0841274,,,,,,,,,,,,,, -48300,0.7603248,1.1095067,,,,,,,,,,,,,, -48400,0.6928938,1.0840806,,,,,,,,,,,,,, -48500,1.0426768,1.1354417,,,,,,,,,,,,,, -48600,0.68469936,1.1020757,,,,,,,,,,,,,, -48700,0.8473246,1.1079448,,,,,,,,,,,,,, -48800,0.7850594,1.0890276,,,,,,,,,,,,,, -48900,0.76798105,1.1604644,,,,,,,,,,,,,, -49000,0.8195077,1.0521313,,,,,,,,,,,,,, -49100,0.9647511,1.1348901,,,,,,,,,,,,,, -49200,0.78230625,1.1202987,,,,,,,,,,,,,, -49300,0.72816235,1.1206471,,,,,,,,,,,,,, -49400,0.79463696,1.0849837,,,,,,,,,,,,,, -49500,0.7387045,1.174986,,,,,,,,,,,,,, -49600,0.88081175,1.1056367,,,,,,,,,,,,,, -49700,0.8866866,1.061899,,,,,,,,,,,,,, -49800,0.8931708,1.0939982,,,,,,,,,,,,,, -49900,0.76818264,1.1281458,,,,,,,,,,,,,, -49986,,,0.13670734,0.0511201691746605,0.38398,0.1154310319858655,5348.0,0.21637918,0.0717404992586273,2472.0,40365.31974339485,44202.62142920494,40365.31974339485,3833.651977300644,1.5921378135681152,0.0 -50000,0.760556,1.1011626,,,,,,,,,,,,,, -50100,0.8186327,1.1572143,,,,,,,,,,,,,, -50200,0.9263151,1.1501124,,,,,,,,,,,,,, -50300,0.77125275,1.1215764,,,,,,,,,,,,,, -50400,0.8218098,1.1250228,,,,,,,,,,,,,, -50500,0.85254073,1.0976818,,,,,,,,,,,,,, -50600,0.95304525,1.0927815,,,,,,,,,,,,,, -50700,0.75723886,1.1014848,,,,,,,,,,,,,, -50800,0.8666034,1.1950955,,,,,,,,,,,,,, -50900,0.81169385,1.0784936,,,,,,,,,,,,,, -51000,0.8305637,1.0841467,,,,,,,,,,,,,, -51100,0.86511,1.1061342,,,,,,,,,,,,,, -51200,0.79776776,1.0979079,,,,,,,,,,,,,, -51300,0.8263811,1.0834366,,,,,,,,,,,,,, -51400,0.83262825,1.1145424,,,,,,,,,,,,,, -51500,0.9665315,1.0618305,,,,,,,,,,,,,, -51600,0.8798803,1.1580468,,,,,,,,,,,,,, -51700,0.7727317,1.0698578,,,,,,,,,,,,,, -51800,0.86179566,1.0602803,,,,,,,,,,,,,, -51848,,,0.13435693,0.050879230706896,0.37354812,0.1099471890477615,5348.0,0.21018149,0.0695468486584201,2472.0,41805.59288787842,45771.889213085175,41805.59288787842,3962.5165877342224,1.646233320236206,0.0 -51900,1.2669607,1.0586507,,,,,,,,,,,,,, -52000,0.8240803,1.1183357,,,,,,,,,,,,,, -52100,0.88039804,1.096184,,,,,,,,,,,,,, -52200,0.8549243,1.1133538,,,,,,,,,,,,,, -52300,0.8343646,1.0827284,,,,,,,,,,,,,, -52400,0.7716173,1.0782624,,,,,,,,,,,,,, -52500,0.7515451,1.0663575,,,,,,,,,,,,,, -52600,0.79602903,1.0605465,,,,,,,,,,,,,, -52700,0.86076176,1.0766969,,,,,,,,,,,,,, -52800,0.8145795,1.0466541,,,,,,,,,,,,,, -52900,0.8878991,1.0751077,,,,,,,,,,,,,, -53000,0.94270456,1.0453136,,,,,,,,,,,,,, -53100,0.92607534,1.0544682,,,,,,,,,,,,,, -53200,0.848143,1.0329057,,,,,,,,,,,,,, -53300,0.9464233,1.0655483,,,,,,,,,,,,,, -53400,0.9287192,1.1046882,,,,,,,,,,,,,, -53500,0.83431107,1.0991983,,,,,,,,,,,,,, -53600,0.85778534,1.0618116,,,,,,,,,,,,,, -53700,1.0190938,1.045821,,,,,,,,,,,,,, -53703,,,0.1186261,0.0446542677961897,0.36797065,0.1078038560684321,5348.0,0.20350519,0.0677594296508439,2472.0,43245.52307033539,47341.51375102997,43245.52307033539,4092.0777394771576,1.704374074935913,0.0 -53800,0.82728714,1.0170072,,,,,,,,,,,,,, -53900,0.97916937,1.1262554,,,,,,,,,,,,,, -54000,0.8645289,1.110099,,,,,,,,,,,,,, -54100,0.9137816,1.0716814,,,,,,,,,,,,,, -54200,1.0219004,1.0852101,,,,,,,,,,,,,, -54300,0.9583195,1.0624181,,,,,,,,,,,,,, -54400,0.9146891,1.0579281,,,,,,,,,,,,,, -54500,0.851833,1.13755,,,,,,,,,,,,,, -54600,0.8857905,1.096183,,,,,,,,,,,,,, -54700,0.85483253,1.091421,,,,,,,,,,,,,, -54800,0.97397876,1.0698341,,,,,,,,,,,,,, -54900,0.8699609,1.1166939,,,,,,,,,,,,,, -55000,0.92822754,1.0681648,,,,,,,,,,,,,, -55100,0.9675532,1.0502318,,,,,,,,,,,,,, -55200,1.0318843,1.0353607,,,,,,,,,,,,,, -55300,0.77224857,1.0188878,,,,,,,,,,,,,, -55400,0.75240433,0.9992951,,,,,,,,,,,,,, -55500,0.8391271,1.0234493,,,,,,,,,,,,,, -55563,,,0.11973902,0.0472273716241127,0.36250752,0.1065101325583865,5348.0,0.19598241,0.0648954969227956,2472.0,44685.69279670715,48912.9459066391,44685.69279670715,4223.207216978073,1.7623958587646484,0.0 -55600,0.9119577,1.0204298,,,,,,,,,,,,,, -55700,1.0049174,1.0630665,,,,,,,,,,,,,, -55800,0.8009221,1.0207219,,,,,,,,,,,,,, -55900,0.9054953,1.019052,,,,,,,,,,,,,, -56000,0.99202263,1.0684302,,,,,,,,,,,,,, -56100,0.83784443,1.0173533,,,,,,,,,,,,,, -56200,0.98456854,1.0399127,,,,,,,,,,,,,, -56300,1.0424025,1.0012764,,,,,,,,,,,,,, -56400,0.83540475,1.0241327,,,,,,,,,,,,,, -56500,0.854674,1.0200293,,,,,,,,,,,,,, -56600,1.2998573,1.0216697,,,,,,,,,,,,,, -56700,0.9044252,1.0268145,,,,,,,,,,,,,, -56800,0.8284965,0.9822658,,,,,,,,,,,,,, -56900,0.8358606,1.0120189,,,,,,,,,,,,,, -57000,0.93253094,1.012379,,,,,,,,,,,,,, -57100,1.0681342,1.0491871,,,,,,,,,,,,,, -57200,0.9585054,0.98008853,,,,,,,,,,,,,, -57300,0.82171714,1.0317808,,,,,,,,,,,,,, -57400,1.1020805,1.0111203,,,,,,,,,,,,,, -57412,,,0.13343582,0.048016934638889,0.34958616,0.1037102831709742,5348.0,0.19371364,0.0641439684764284,2472.0,46125.87220811844,50483.96151971817,46125.87220811844,4353.911198377609,1.8190128803253167,0.0 -57500,0.9739952,1.067157,,,,,,,,,,,,,, -57600,0.94359016,1.0290678,,,,,,,,,,,,,, -57700,0.9204681,1.005016,,,,,,,,,,,,,, -57800,0.96442175,1.0360552,,,,,,,,,,,,,, -57900,0.99548835,0.9988481,,,,,,,,,,,,,, -58000,1.0074735,1.0109663,,,,,,,,,,,,,, -58100,1.0320752,1.0264401,,,,,,,,,,,,,, -58200,0.87885,1.0041994,,,,,,,,,,,,,, -58300,0.8505463,1.0152013,,,,,,,,,,,,,, -58400,0.9882148,1.0109438,,,,,,,,,,,,,, -58500,0.8730608,1.0176518,,,,,,,,,,,,,, -58600,1.0962553,1.0481236,,,,,,,,,,,,,, -58700,0.8329692,0.9898536,,,,,,,,,,,,,, -58800,1.0935533,1.0280733,,,,,,,,,,,,,, -58900,1.1268246,1.027978,,,,,,,,,,,,,, -59000,1.19295,1.0227183,,,,,,,,,,,,,, -59100,0.99634266,1.0160975,,,,,,,,,,,,,, -59200,0.9328141,1.0514207,,,,,,,,,,,,,, -59270,,,0.08705975,0.0339719108044844,0.34289178,0.1005628662734004,5348.0,0.184741,0.0619706294558527,2472.0,47566.122009038925,52055.225329875946,47566.122009038925,4484.790504455566,1.8767881393432613,0.0 -59300,0.86535084,0.9854123,,,,,,,,,,,,,, -59400,0.90002406,1.007217,,,,,,,,,,,,,, -59500,0.9832513,0.9513953,,,,,,,,,,,,,, -59600,0.9855431,1.0336776,,,,,,,,,,,,,, -59700,0.940607,1.014761,,,,,,,,,,,,,, -59800,1.0818216,0.99280244,,,,,,,,,,,,,, -59900,1.0112953,0.9681106,,,,,,,,,,,,,, -60000,0.89394075,1.0213772,,,,,,,,,,,,,, -60100,0.92168826,0.97828656,,,,,,,,,,,,,, -60200,0.9432335,0.94825643,,,,,,,,,,,,,, -60300,1.014676,0.96792454,,,,,,,,,,,,,, -60400,1.1173054,0.9729353,,,,,,,,,,,,,, -60500,1.0894989,0.9892479,,,,,,,,,,,,,, -60600,0.9167111,0.9808141,,,,,,,,,,,,,, -60700,0.97428983,0.9787736,,,,,,,,,,,,,, -60800,0.9642918,0.9460953,,,,,,,,,,,,,, -60900,1.0970298,0.96940607,,,,,,,,,,,,,, -61000,0.9666675,0.9792064,,,,,,,,,,,,,, -61100,1.1133437,1.0480351,,,,,,,,,,,,,, -61129,,,0.088854894,0.0341538013399359,0.33372676,0.0980816204369696,5348.0,0.17703222,0.0590051388296467,2472.0,49006.60219502449,53627.36594891548,49006.60219502449,4616.315147399902,1.937981128692627,0.0 -61200,1.2474737,0.98337734,,,,,,,,,,,,,, -61300,0.9574431,0.9280401,,,,,,,,,,,,,, -61400,0.9673637,0.94832647,,,,,,,,,,,,,, -61500,1.0905378,0.99207324,,,,,,,,,,,,,, -61600,0.91187376,1.0222116,,,,,,,,,,,,,, -61700,1.2391704,0.9644501,,,,,,,,,,,,,, -61800,1.1043274,0.9425179,,,,,,,,,,,,,, -61900,1.0715624,1.0060563,,,,,,,,,,,,,, -62000,1.0995007,0.9918034,,,,,,,,,,,,,, -62100,1.1018473,0.942602,,,,,,,,,,,,,, -62200,1.1514932,0.92391855,,,,,,,,,,,,,, -62300,0.9733825,0.9408399,,,,,,,,,,,,,, -62400,1.0151277,0.9394443,,,,,,,,,,,,,, -62500,1.1214559,0.9962428,,,,,,,,,,,,,, -62600,1.0179315,0.94878423,,,,,,,,,,,,,, -62700,1.0672916,0.96656346,,,,,,,,,,,,,, -62800,1.0608948,0.9633689,,,,,,,,,,,,,, -62900,1.0187948,0.9537843,,,,,,,,,,,,,, -62978,,,0.107303254,0.0416680418495659,0.33158988,0.0967106597024436,5348.0,0.17833579,0.0587410882944366,2472.0,50447.22417402268,55199.06018662453,50447.22417402268,4747.2524383068085,1.996875286102295,0.0 -63000,1.151972,0.94631696,,,,,,,,,,,,,, -63100,1.141763,0.9344741,,,,,,,,,,,,,, -63200,1.0066955,0.9626098,,,,,,,,,,,,,, -63300,1.2495751,0.92053264,,,,,,,,,,,,,, -63400,1.0728176,0.99209315,,,,,,,,,,,,,, -63500,1.296249,0.95167387,,,,,,,,,,,,,, -63600,1.0476792,0.9318857,,,,,,,,,,,,,, -63700,1.0523057,0.92015314,,,,,,,,,,,,,, -63800,1.1751052,0.92667544,,,,,,,,,,,,,, -63900,1.197691,0.9233489,,,,,,,,,,,,,, -64000,1.0630724,0.95278704,,,,,,,,,,,,,, -64100,1.0458035,0.936265,,,,,,,,,,,,,, -64200,1.1450094,0.9677794,,,,,,,,,,,,,, -64300,1.1256536,0.9448277,,,,,,,,,,,,,, -64400,1.0680857,0.92065054,,,,,,,,,,,,,, -64500,1.0689102,0.97633195,,,,,,,,,,,,,, -64600,1.0974528,0.95618486,,,,,,,,,,,,,, -64700,1.1115836,0.9438037,,,,,,,,,,,,,, -64800,1.1877625,0.91771626,,,,,,,,,,,,,, -64826,,,0.101906076,0.0387572157318086,0.31840345,0.0927425972947662,5348.0,0.16995475,0.0560193366238092,2472.0,51887.12818932533,56769.58395195007,51887.12818932533,4877.739913702011,2.052243232727051,0.0 -64900,1.2254033,0.96731615,,,,,,,,,,,,,, -65000,1.2004836,0.9269421,,,,,,,,,,,,,, -65100,1.0586159,0.9088363,,,,,,,,,,,,,, -65200,1.0321677,0.89347035,,,,,,,,,,,,,, -65300,1.3227143,0.986801,,,,,,,,,,,,,, -65400,0.9569573,0.9098754,,,,,,,,,,,,,, -65500,1.1628735,0.9383092,,,,,,,,,,,,,, -65600,1.0731275,0.9200547,,,,,,,,,,,,,, -65700,1.2976028,0.8908075,,,,,,,,,,,,,, -65800,1.0079374,0.8978127,,,,,,,,,,,,,, -65900,1.2789179,0.9140206,,,,,,,,,,,,,, -66000,2.0222642,0.91977996,,,,,,,,,,,,,, -66100,1.1494591,0.92556757,,,,,,,,,,,,,, -66200,1.1463344,0.9131948,,,,,,,,,,,,,, -66300,1.0813541,0.9372725,,,,,,,,,,,,,, -66400,1.0656028,0.9525871,,,,,,,,,,,,,, -66500,1.1063933,0.96113443,,,,,,,,,,,,,, -66600,1.192613,0.9135751,,,,,,,,,,,,,, -66672,,,0.11785042,0.0457300787328425,0.30807304,0.0889290093360495,5348.0,0.1668592,0.0543944102532854,2472.0,53327.30903625488,58338.56297969818,53327.30903625488,5006.407078027725,2.1070005893707275,0.0 -66700,1.067945,0.92007816,,,,,,,,,,,,,, -66800,1.1271951,0.9178225,,,,,,,,,,,,,, -66900,1.1316298,0.92673206,,,,,,,,,,,,,, -67000,1.2658551,0.8751219,,,,,,,,,,,,,, -67100,1.1856086,0.9369301,,,,,,,,,,,,,, -67200,1.143278,0.93120414,,,,,,,,,,,,,, -67300,1.1063108,0.98089087,,,,,,,,,,,,,, -67400,1.1810064,0.93916535,,,,,,,,,,,,,, -67500,1.3458533,0.9389377,,,,,,,,,,,,,, -67600,1.3763844,0.86162394,,,,,,,,,,,,,, -67700,1.1642569,0.8749896,,,,,,,,,,,,,, -67800,1.1132342,0.88647956,,,,,,,,,,,,,, -67900,1.1057471,0.937825,,,,,,,,,,,,,, -68000,1.0783134,0.90288746,,,,,,,,,,,,,, -68100,1.1258736,0.8563966,,,,,,,,,,,,,, -68200,1.4677999,0.8915264,,,,,,,,,,,,,, -68300,0.99457514,0.8561946,,,,,,,,,,,,,, -68400,1.1261129,0.8770566,,,,,,,,,,,,,, -68500,1.3390036,0.8763556,,,,,,,,,,,,,, -68524,,,0.092671074,0.0340948951196127,0.30833793,0.0882918022340867,5348.0,0.1638089,0.0525257449271829,2472.0,54767.38990926743,59907.71535515785,54767.38990926743,5135.344121932983,2.1641712188720703,0.0 -68600,1.2948494,0.9319779,,,,,,,,,,,,,, -68700,1.2723607,0.9001643,,,,,,,,,,,,,, -68800,1.2750084,0.8850672,,,,,,,,,,,,,, -68900,1.3512001,0.8827934,,,,,,,,,,,,,, -69000,1.1606445,0.9175013,,,,,,,,,,,,,, -69100,1.1062081,0.88156784,,,,,,,,,,,,,, -69200,1.1558028,0.8697975,,,,,,,,,,,,,, -69300,1.269063,0.8877808,,,,,,,,,,,,,, -69400,1.3759973,0.88312364,,,,,,,,,,,,,, -69500,1.1513379,0.8657584,,,,,,,,,,,,,, -69600,1.319811,0.9105979,,,,,,,,,,,,,, -69700,1.374258,0.8999183,,,,,,,,,,,,,, -69800,1.0593234,0.857819,,,,,,,,,,,,,, -69900,1.0415391,0.86455435,,,,,,,,,,,,,, -70000,0.99864745,0.91483164,,,,,,,,,,,,,, -70100,1.6244727,0.9054942,,,,,,,,,,,,,, -70200,1.1297554,0.9097047,,,,,,,,,,,,,, -70300,1.2852823,0.89250356,,,,,,,,,,,,,, -70376,,,0.08907641,0.034411574026683,0.30346766,0.0865829286424592,5348.0,0.1597395,0.0516523470030264,2472.0,56207.60083270073,61476.29270792008,56207.60083270073,5263.575304508209,2.2221508026123047,0.0 -70400,1.1181045,0.8804676,,,,,,,,,,,,,, -70500,1.1155288,0.8972317,,,,,,,,,,,,,, -70600,1.395139,0.8421903,,,,,,,,,,,,,, -70700,1.1638032,0.8540772,,,,,,,,,,,,,, -70800,1.0703892,0.8837582,,,,,,,,,,,,,, -70900,1.1346613,0.9244051,,,,,,,,,,,,,, -71000,1.4208622,0.8908788,,,,,,,,,,,,,, -71100,1.5084691,0.87189084,,,,,,,,,,,,,, -71200,1.379163,0.8623955,,,,,,,,,,,,,, -71300,1.0931116,0.85950327,,,,,,,,,,,,,, -71400,1.0288069,0.87605727,,,,,,,,,,,,,, -71500,1.3846803,0.8569242,,,,,,,,,,,,,, -71600,1.23368,0.8879693,,,,,,,,,,,,,, -71700,1.1900635,0.85543495,,,,,,,,,,,,,, -71800,1.1933126,0.8169889,,,,,,,,,,,,,, -71900,1.5055547,0.8670423,,,,,,,,,,,,,, -72000,1.355629,0.8938689,,,,,,,,,,,,,, -72100,1.2432024,0.88720423,,,,,,,,,,,,,, -72200,1.216091,0.87694526,,,,,,,,,,,,,, -72210,,,0.066054076,0.0250100868816741,0.2977439,0.0852988597854736,5348.0,0.15752408,0.0510633111937115,2472.0,57647.67440390587,63047.46157693863,57647.67440390587,5394.533260583878,2.282556772232056,0.0 -72300,1.3101698,0.8868782,,,,,,,,,,,,,, -72400,1.1115645,0.8387385,,,,,,,,,,,,,, -72500,1.2653172,0.9097033,,,,,,,,,,,,,, -72600,1.2227204,0.9319262,,,,,,,,,,,,,, -72700,1.3576012,0.8217966,,,,,,,,,,,,,, -72800,1.3304958,0.8722767,,,,,,,,,,,,,, -72900,1.3188,0.882716,,,,,,,,,,,,,, -73000,1.2031428,0.9827153,,,,,,,,,,,,,, -73100,1.2945614,0.85768276,,,,,,,,,,,,,, -73200,1.2695332,0.86504126,,,,,,,,,,,,,, -73300,1.2411431,0.8067214,,,,,,,,,,,,,, -73400,1.399786,0.8221307,,,,,,,,,,,,,, -73500,1.3871609,0.8822938,,,,,,,,,,,,,, -73600,1.5614713,0.82883316,,,,,,,,,,,,,, -73700,1.4363273,0.85203326,,,,,,,,,,,,,, -73800,1.3366874,0.8700972,,,,,,,,,,,,,, -73900,1.2695795,0.88400656,,,,,,,,,,,,,, -74000,1.3987193,0.90518975,,,,,,,,,,,,,, -74052,,,0.078413114,0.0301152721978346,0.29437685,0.0843913223978296,5348.0,0.15448758,0.04960087746024,2472.0,59088.03487086296,64621.01479768753,59088.03487086296,5527.584671020508,2.346060037612915,0.0 -74100,1.1042829,0.79856205,,,,,,,,,,,,,, -74200,1.1684921,0.830141,,,,,,,,,,,,,, -74300,1.1302284,0.8187664,,,,,,,,,,,,,, -74400,1.3099989,0.8453651,,,,,,,,,,,,,, -74500,1.2838192,0.8365215,,,,,,,,,,,,,, -74600,1.0784625,0.81330734,,,,,,,,,,,,,, -74700,1.4476607,0.8887517,,,,,,,,,,,,,, -74800,1.1731993,0.8413179,,,,,,,,,,,,,, -74900,1.179455,0.8660777,,,,,,,,,,,,,, -75000,1.3180317,0.85860336,,,,,,,,,,,,,, -75100,1.6303312,0.8666158,,,,,,,,,,,,,, -75200,1.3394989,0.81863385,,,,,,,,,,,,,, -75300,1.1899605,0.8876046,,,,,,,,,,,,,, -75400,1.4103333,0.86306775,,,,,,,,,,,,,, -75500,1.1874187,0.87762624,,,,,,,,,,,,,, -75600,1.3786224,0.8664213,,,,,,,,,,,,,, -75700,1.2646706,0.80594563,,,,,,,,,,,,,, -75800,1.5457615,0.86095876,,,,,,,,,,,,,, -75894,,,0.06854972,0.0253520374081496,0.29016793,0.0830493256224837,5348.0,0.15408525,0.0487681026953466,2472.0,60528.32988452912,66194.21731305122,60528.32988452912,5660.345217943192,2.4161722660064697,0.0 -75900,1.1465151,0.8365229,,,,,,,,,,,,,, -76000,1.4513053,0.82864,,,,,,,,,,,,,, -76100,1.1530479,0.8400319,,,,,,,,,,,,,, -76200,1.3088113,0.8385746,,,,,,,,,,,,,, -76300,1.2691154,0.8546907,,,,,,,,,,,,,, -76400,1.2564814,0.81405026,,,,,,,,,,,,,, -76500,1.1554366,0.84695524,,,,,,,,,,,,,, -76600,,,,,,,,,,,61068.26130890846,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/eval_measurements.csv deleted file mode 100644 index b209b510f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -132.29344940185547,0.0,34.868218183517456,1,0,34.868218183517456,30.871214,2472,1.1899945158734997,167.16175413131714,31.446594,1.1269105188367123,30.757864,5348,1.1779352558965794 -243.40519452095032,0.0271918773651123,1474.8262186050415,1822,0,1474.8262186050415,4.465194,2472,0.8685231450449902,1718.332086801529,4.786604,0.9100190787680568,4.6146393,5348,0.8706855769137936 -369.1273944377899,0.0727007389068603,2914.829460382461,3672,0,2914.829460382461,2.4753306,2472,0.5817845753864278,3284.181474685669,2.9874406,0.6618090910126309,2.8403525,5348,0.6263166533110632 -498.098771572113,0.1232173442840576,4354.856928110123,5527,0,4354.856928110123,1.5014671,2472,0.4370442589320171,4853.308263778687,1.8403991,0.5122606049758368,1.9006361,5348,0.4987593770817846 -627.3559744358063,0.1754987239837646,5795.416927576065,7370,0,5795.416927576065,1.1917499,2472,0.3648162817622326,6423.257106304169,1.4814537,0.4351280226418841,1.549442,5348,0.4238875426011566 -757.7280685901642,0.2269151210784912,7236.158366441727,9210,0,7236.158366441727,1.1388015,2472,0.3562854183169825,7994.499317407608,1.5415139,0.4593235089347876,1.4839153,5348,0.4172934145611477 -887.9946537017822,0.2738659381866455,8676.491871833801,11039,0,8676.491871833801,0.9740689,2472,0.3123514715739443,9565.224183321,1.1786427,0.368609909825751,1.3118448,5348,0.3775065893007134 -1017.7202196121216,0.3288073539733886,10116.50751399994,12891,0,10116.50751399994,0.8644295,2472,0.2827778116304105,11135.100909948347,1.1353468,0.3522497808195933,1.1930368,5348,0.3462834413045367 -1146.2477378845217,0.3858742713928222,11556.709911346436,14741,0,11556.709911346436,0.8089821,2472,0.2671785184733817,12703.967337846756,1.091014,0.3399796344439577,1.137504,5348,0.3313090744084111 -1277.6439950466156,0.4783322811126709,12997.059435367584,16567,0,12997.059435367584,0.77649456,2472,0.2576117644669226,14275.884579896929,1.0301901,0.326469812233411,1.10055,5348,0.3198200372669608 -1407.5817940235138,0.5285263061523438,14437.434829473495,18413,0,14437.434829473495,0.8797922,2472,0.281091950520992,15846.326422691343,1.1991059,0.3634266049870247,1.2338159,5348,0.3535340857526284 -1538.6586077213287,0.5804932117462158,15878.150474071505,20256,0,15878.150474071505,0.7080025,2472,0.2387626185688461,17418.24999141693,0.88449436,0.288827478657715,1.0200983,5348,0.3036098747791498 -1669.7016806602478,0.6308321952819824,17318.73151898384,22103,0,17318.73151898384,0.67560107,2472,0.2254585339101822,18990.00296020508,0.90416443,0.2953098019868589,0.9769954,5348,0.2892244417196868 -1799.7231032848358,0.6856436729431152,18758.80566763878,23950,0,18758.80566763878,0.64889586,2472,0.2191010094855076,20560.23240017891,0.83565265,0.2703343074093714,0.94258255,5348,0.2817903588634542 -1931.4437320232391,0.7362205982208252,20198.753033638,25774,0,20198.753033638,0.62141186,2472,0.208884284930839,22132.027775764465,0.83382195,0.2674261006030476,0.90188086,5348,0.2706006159668652 -2063.012087583542,0.7946665287017822,21638.79001927376,27620,0,21638.79001927376,0.6013053,2472,0.2044157374118985,23703.77072691917,0.7256441,0.2424350146842476,0.8796342,5348,0.2647788601716597 -2193.7856407165527,0.8524501323699951,23079.2897040844,29464,0,23079.2897040844,0.5760872,2472,0.1957426929092275,25275.1791806221,0.7058413,0.2416801166096091,0.8496082,5348,0.2589764136825743 -2323.7597975730896,0.9073889255523682,24519.32209968567,31322,0,24519.32209968567,0.5572826,2472,0.1906851095809721,26845.321516752243,0.71590084,0.2423552059780039,0.8296148,5348,0.2496210548673933 -2465.121926546097,0.965153694152832,25959.6188287735,33176,0,25959.6188287735,0.54166,2472,0.187049336826925,28427.117254018784,0.5189762,0.1857237276933245,0.810009,5348,0.2471687729901426 -2596.715493917465,1.015394926071167,27400.167891979218,35020,0,27400.167891979218,0.5160036,2472,0.1775638291389921,29999.38902139664,0.4539365,0.1625599436555531,0.78835493,5348,0.2402656960522123 -2730.4466235637665,1.0714967250823977,28840.363377332687,36863,0,28840.363377332687,0.5011267,2472,0.1698048057197408,31573.44949054718,0.44536912,0.1580918412839946,0.7607473,5348,0.2320206223389362 -2862.9936952590942,1.1281471252441406,30280.498059034348,38718,0,30280.498059034348,0.4796191,2472,0.1635691507728556,33146.26831173897,0.41135126,0.1500194003983548,0.7361761,5348,0.2231962694420576 -2992.8555755615234,1.185434341430664,31720.88926625252,40574,0,31720.88926625252,0.45901236,2472,0.1562366705258668,34716.65667510033,0.40231338,0.1454793648395642,0.70659876,5348,0.2144781177288394 -3124.930632591248,1.2468442916870115,33161.03370857239,42424,0,33161.03370857239,0.4475745,2472,0.1520524851217679,36289.01644325256,0.37856817,0.1399152124411051,0.6976467,5348,0.214381571198239 -3255.2952768802643,1.3067612648010254,34601.290254592896,44268,0,34601.290254592896,0.43433073,2472,0.1464464891434607,37859.7775645256,0.42395383,0.1483157612824591,0.67980665,5348,0.207304710505228 -3389.268036603928,1.362666368484497,36041.80644035339,46107,0,36041.80644035339,0.4211136,2472,0.1439684764284118,39434.40048265457,0.3622536,0.1347455764944699,0.65977854,5348,0.2022360176487058 -3521.39735031128,1.421417236328125,37482.298840522766,47963,0,37482.298840522766,0.40312463,2472,0.1372453435703694,41007.16001605988,0.33330858,0.1254129063868123,0.6382336,5348,0.195207430220995 -3652.845793962479,1.4763050079345703,38922.43569779396,49822,0,38922.43569779396,0.38547108,2472,0.1314362317957467,42578.880200862885,0.31033447,0.1149222233388164,0.6196946,5348,0.1900711547930525 -3785.52684879303,1.538241624832153,40362.72778439522,51669,0,40362.72778439522,0.3817462,2472,0.1289785306603294,44151.99367618561,0.29647183,0.1116216549116634,0.6016999,5348,0.1810826727941531 -3917.9606053829193,1.5975022315979004,41802.975400447845,53518,0,41802.975400447845,0.35301486,2472,0.1198383198261328,45724.815037965775,0.30620503,0.1127560213666556,0.57299715,5348,0.1762456916110719 -4051.69429731369,1.6556503772735596,43243.05828857422,55366,0,43243.05828857422,0.33865485,2472,0.115654134422034,47298.76901054382,0.28244948,0.1038226074516651,0.55347747,5348,0.1705397916525869 -4184.260848999023,1.7116804122924805,44683.50492525101,57223,0,44683.50492525101,0.3345989,2472,0.1149026059756667,48871.916709423065,0.27888992,0.1028839746470579,0.54089737,5348,0.1678364887957751 -4315.326943397522,1.769678831100464,46124.06595420837,59078,0,46124.06595420837,0.31451392,2472,0.1089513131436231,50443.68307852745,0.24648692,0.0924602903006723,0.51485807,5348,0.1599196732865404 -4446.514243841171,1.8272485733032229,47564.641932964325,60931,0,47564.641932964325,0.3011645,2472,0.1023500497633701,52015.58261036873,0.22372997,0.0836315461675692,0.4961345,5348,0.1519352752058854 -4580.05917096138,1.886429786682129,49005.04101896286,62775,0,49005.04101896286,0.2855554,2472,0.0995064286149533,53589.66577386856,0.20581685,0.0759066700481866,0.47948486,5348,0.1463066124718808 -4711.07154917717,1.944622278213501,50445.12503552437,64623,0,50445.12503552437,0.2775251,2472,0.0937176284199622,55160.89828634262,0.206059,0.0774030837838861,0.46793726,5348,0.1434971084314085 -4842.50137925148,2.015546560287476,51885.24637913704,66485,0,51885.24637913704,0.2626538,2472,0.0902443483029675,56732.60203003883,0.17852962,0.0660985968496317,0.44406882,5348,0.1364781756567577 -4975.279711008072,2.079103946685791,53325.28448152542,68347,0,53325.28448152542,0.25467312,2472,0.0852070765543436,58305.56083631516,0.1791338,0.0678948165873385,0.4300354,5348,0.1310426059839539 -5107.127786159515,2.1438803672790527,54765.86633563042,70199,0,54765.86633563042,0.24319124,2472,0.082627505941137,59878.135501384735,0.1901347,0.0665341131428934,0.4159782,5348,0.1261766608416926 -5242.33988404274,2.204277276992798,56206.35363817215,72051,0,56206.35363817215,0.23223098,2472,0.0790932910852477,61453.973516225815,0.13278739,0.0494465020682123,0.3995478,5348,0.1224789287196964 -5375.872570037842,2.272150993347168,57646.68882584572,73899,0,57646.68882584572,0.2275426,2472,0.0774277415554607,63027.98972511292,0.13270302,0.0492665924636954,0.39282784,5348,0.119418403699663 -5507.54785490036,2.3289687633514404,59086.815969944,75757,0,59086.815969944,0.22376499,2472,0.0750512867385696,64599.93062663078,0.16273986,0.0621603130729865,0.38568026,5348,0.1166088996591907 -5638.74435043335,2.393007516860962,60527.34537887573,77618,0,60527.34537887573,0.21958135,2472,0.07427944671257083,66171.80233621597,0.1638349,0.06019522480349154,0.3825297,5348,0.1155468878225861 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/measurements.csv deleted file mode 100644 index 7923aa668..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/measurements.csv +++ /dev/null @@ -1,829 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,39.550594,32.219704,,,,,,,,,,,,,, -1,,,31.446594,1.1269105188367123,30.757864,1.1779352558965794,5348.0,30.871214,1.1899945158734997,2472.0,34.868218183517456,167.16175413131714,34.868218183517456,132.29344940185547,0.0,0.0 -100,5.3064513,5.8377485,,,,,,,,,,,,,, -200,3.723013,5.7664237,,,,,,,,,,,,,, -300,1.3693209,5.5976954,,,,,,,,,,,,,, -400,1.5431001,5.5890565,,,,,,,,,,,,,, -500,1.1457114,5.549234,,,,,,,,,,,,,, -600,0.43792978,5.5002666,,,,,,,,,,,,,, -700,3.7343125,5.558849,,,,,,,,,,,,,, -800,0.36215448,5.495203,,,,,,,,,,,,,, -900,3.0857573,5.485403,,,,,,,,,,,,,, -1000,1.9484832,5.4634166,,,,,,,,,,,,,, -1100,1.5805905,5.3904877,,,,,,,,,,,,,, -1200,2.7703466,5.1770563,,,,,,,,,,,,,, -1300,1.8790433,4.729324,,,,,,,,,,,,,, -1400,1.7646728,4.392812,,,,,,,,,,,,,, -1500,1.2208544,4.2224402,,,,,,,,,,,,,, -1600,0.63997644,4.0704856,,,,,,,,,,,,,, -1700,2.4097078,3.991147,,,,,,,,,,,,,, -1800,0.96450394,3.9017794,,,,,,,,,,,,,, -1822,,,4.786604,0.9100190787680568,4.6146393,0.8706855769137936,5348.0,4.465194,0.8685231450449902,2472.0,1474.8262186050415,1718.332086801529,1474.8262186050415,243.40519452095032,0.0271918773651123,0.0 -1900,1.2285758,3.8038757,,,,,,,,,,,,,, -2000,1.7948021,3.7669017,,,,,,,,,,,,,, -2100,1.949466,3.7195523,,,,,,,,,,,,,, -2200,1.9319893,3.5958264,,,,,,,,,,,,,, -2300,1.7120742,3.5806181,,,,,,,,,,,,,, -2400,1.2646648,3.3973625,,,,,,,,,,,,,, -2500,0.6941887,3.3156805,,,,,,,,,,,,,, -2600,0.9639467,3.3295538,,,,,,,,,,,,,, -2700,0.6150938,3.2545118,,,,,,,,,,,,,, -2800,1.8777021,3.1842496,,,,,,,,,,,,,, -2900,1.0908813,3.1826196,,,,,,,,,,,,,, -3000,1.1241616,3.1596067,,,,,,,,,,,,,, -3100,1.1309047,3.1333337,,,,,,,,,,,,,, -3200,1.9047579,3.1001444,,,,,,,,,,,,,, -3300,1.1693367,3.005895,,,,,,,,,,,,,, -3400,1.6794442,3.0057473,,,,,,,,,,,,,, -3500,0.61537164,3.0163808,,,,,,,,,,,,,, -3600,0.9261984,2.9821398,,,,,,,,,,,,,, -3672,,,2.9874406,0.6618090910126309,2.8403525,0.6263166533110632,5348.0,2.4753306,0.5817845753864278,2472.0,2914.829460382461,3284.181474685669,2914.829460382461,369.1273944377899,0.0727007389068603,0.0 -3700,1.2029264,2.9696116,,,,,,,,,,,,,, -3800,0.9390143,2.8918467,,,,,,,,,,,,,, -3900,1.6752945,2.904835,,,,,,,,,,,,,, -4000,0.9396365,2.9547327,,,,,,,,,,,,,, -4100,0.69464886,2.8107896,,,,,,,,,,,,,, -4200,1.4341198,2.8181593,,,,,,,,,,,,,, -4300,1.0151165,2.8839562,,,,,,,,,,,,,, -4400,0.79567516,2.714461,,,,,,,,,,,,,, -4500,1.2651824,2.7711086,,,,,,,,,,,,,, -4600,1.3854821,2.8706892,,,,,,,,,,,,,, -4700,0.77849615,2.8379228,,,,,,,,,,,,,, -4800,0.5295031,2.7285292,,,,,,,,,,,,,, -4900,1.479526,2.7062683,,,,,,,,,,,,,, -5000,1.7491404,2.817684,,,,,,,,,,,,,, -5100,0.5976361,2.7485528,,,,,,,,,,,,,, -5200,1.160448,2.7724166,,,,,,,,,,,,,, -5300,2.3325875,2.7343485,,,,,,,,,,,,,, -5400,1.2982974,2.6857507,,,,,,,,,,,,,, -5500,0.60902,2.7494714,,,,,,,,,,,,,, -5527,,,1.8403991,0.5122606049758368,1.9006361,0.4987593770817846,5348.0,1.5014671,0.4370442589320171,2472.0,4354.856928110123,4853.308263778687,4354.856928110123,498.098771572113,0.1232173442840576,0.0 -5600,1.4048495,2.6650999,,,,,,,,,,,,,, -5700,1.1954173,2.6728792,,,,,,,,,,,,,, -5800,0.8777982,2.6741076,,,,,,,,,,,,,, -5900,0.8164663,2.5999887,,,,,,,,,,,,,, -6000,0.8480012,2.6361766,,,,,,,,,,,,,, -6100,1.2391503,2.5725265,,,,,,,,,,,,,, -6200,1.6369361,2.6732552,,,,,,,,,,,,,, -6300,0.91002977,2.6320527,,,,,,,,,,,,,, -6400,0.8164535,2.58036,,,,,,,,,,,,,, -6500,1.4531652,2.5805042,,,,,,,,,,,,,, -6600,0.7341854,2.5365567,,,,,,,,,,,,,, -6700,0.5924628,2.5367541,,,,,,,,,,,,,, -6800,1.5696155,2.6002247,,,,,,,,,,,,,, -6900,0.969463,2.6064088,,,,,,,,,,,,,, -7000,0.68015575,2.493282,,,,,,,,,,,,,, -7100,1.2361164,2.4633439,,,,,,,,,,,,,, -7200,1.0439534,2.4988384,,,,,,,,,,,,,, -7300,1.2401998,2.4404638,,,,,,,,,,,,,, -7370,,,1.4814537,0.4351280226418841,1.549442,0.4238875426011566,5348.0,1.1917499,0.3648162817622326,2472.0,5795.416927576065,6423.257106304169,5795.416927576065,627.3559744358063,0.1754987239837646,0.0 -7400,1.2838535,2.4395912,,,,,,,,,,,,,, -7500,1.4841845,2.534463,,,,,,,,,,,,,, -7600,0.5258479,2.4083424,,,,,,,,,,,,,, -7700,0.6485867,2.4291067,,,,,,,,,,,,,, -7800,1.3382347,2.450111,,,,,,,,,,,,,, -7900,1.2627907,2.4999158,,,,,,,,,,,,,, -8000,2.0401568,2.4773142,,,,,,,,,,,,,, -8100,1.0779102,2.4049475,,,,,,,,,,,,,, -8200,0.7794374,2.388849,,,,,,,,,,,,,, -8300,0.65187407,2.3953938,,,,,,,,,,,,,, -8400,2.2589922,2.5390968,,,,,,,,,,,,,, -8500,1.2696365,2.4441128,,,,,,,,,,,,,, -8600,1.2406045,2.3597865,,,,,,,,,,,,,, -8700,0.8876284,2.410219,,,,,,,,,,,,,, -8800,1.1915835,2.4305296,,,,,,,,,,,,,, -8900,0.9390555,2.2841613,,,,,,,,,,,,,, -9000,1.1099185,2.3241243,,,,,,,,,,,,,, -9100,0.8242513,2.4013116,,,,,,,,,,,,,, -9200,1.3095474,2.3578336,,,,,,,,,,,,,, -9210,,,1.5415139,0.4593235089347876,1.4839153,0.4172934145611477,5348.0,1.1388015,0.3562854183169825,2472.0,7236.158366441727,7994.499317407608,7236.158366441727,757.7280685901642,0.2269151210784912,0.0 -9300,0.6610012,2.390754,,,,,,,,,,,,,, -9400,1.2368197,2.3789308,,,,,,,,,,,,,, -9500,0.9119992,2.4013624,,,,,,,,,,,,,, -9600,0.57508105,2.3337502,,,,,,,,,,,,,, -9700,0.79891586,2.330086,,,,,,,,,,,,,, -9800,0.7007428,2.282985,,,,,,,,,,,,,, -9900,0.8383224,2.348607,,,,,,,,,,,,,, -10000,1.1322047,2.435026,,,,,,,,,,,,,, -10100,0.56057435,2.331833,,,,,,,,,,,,,, -10200,0.73942804,2.3640518,,,,,,,,,,,,,, -10300,0.79112834,2.2809622,,,,,,,,,,,,,, -10400,0.65062666,2.2335465,,,,,,,,,,,,,, -10500,1.7698325,2.3278081,,,,,,,,,,,,,, -10600,1.0002769,2.2929864,,,,,,,,,,,,,, -10700,0.58207995,2.2968762,,,,,,,,,,,,,, -10800,0.65909123,2.2668128,,,,,,,,,,,,,, -10900,0.54604685,2.2340925,,,,,,,,,,,,,, -11000,1.3229052,2.181547,,,,,,,,,,,,,, -11039,,,1.1786427,0.368609909825751,1.3118448,0.3775065893007134,5348.0,0.9740689,0.3123514715739443,2472.0,8676.491871833801,9565.224183321,8676.491871833801,887.9946537017822,0.2738659381866455,0.0 -11100,0.72128254,2.274934,,,,,,,,,,,,,, -11200,0.56627476,2.2690132,,,,,,,,,,,,,, -11300,0.75501204,2.2598832,,,,,,,,,,,,,, -11400,1.0241991,2.2717297,,,,,,,,,,,,,, -11500,0.5880801,2.226361,,,,,,,,,,,,,, -11600,1.1388465,2.1622698,,,,,,,,,,,,,, -11700,0.5263289,2.1570294,,,,,,,,,,,,,, -11800,0.92801505,2.204756,,,,,,,,,,,,,, -11900,0.5056254,2.2069745,,,,,,,,,,,,,, -12000,0.70068944,2.1808543,,,,,,,,,,,,,, -12100,1.242634,2.2507281,,,,,,,,,,,,,, -12200,0.75142354,2.1187272,,,,,,,,,,,,,, -12300,0.85245216,2.1318936,,,,,,,,,,,,,, -12400,0.57203054,2.1734838,,,,,,,,,,,,,, -12500,0.90308356,2.1206524,,,,,,,,,,,,,, -12600,0.7473082,2.2183514,,,,,,,,,,,,,, -12700,0.69777566,2.1390543,,,,,,,,,,,,,, -12800,0.61295485,2.129695,,,,,,,,,,,,,, -12891,,,1.1353468,0.3522497808195933,1.1930368,0.3462834413045367,5348.0,0.8644295,0.2827778116304105,2472.0,10116.50751399994,11135.100909948347,10116.50751399994,1017.7202196121216,0.3288073539733886,0.0 -12900,0.6459562,2.1153896,,,,,,,,,,,,,, -13000,0.90732753,2.081261,,,,,,,,,,,,,, -13100,1.456399,2.1553736,,,,,,,,,,,,,, -13200,0.84231186,2.1011946,,,,,,,,,,,,,, -13300,0.70707905,2.0954673,,,,,,,,,,,,,, -13400,1.0591294,2.128496,,,,,,,,,,,,,, -13500,0.9947137,2.1157014,,,,,,,,,,,,,, -13600,0.7198215,2.232559,,,,,,,,,,,,,, -13700,0.56265205,2.1469405,,,,,,,,,,,,,, -13800,0.90310144,2.130083,,,,,,,,,,,,,, -13900,1.0336281,2.1656547,,,,,,,,,,,,,, -14000,0.7986821,2.1285865,,,,,,,,,,,,,, -14100,0.8948344,2.1220703,,,,,,,,,,,,,, -14200,0.7101128,2.1088333,,,,,,,,,,,,,, -14300,1.0185729,2.1404383,,,,,,,,,,,,,, -14400,0.6491746,2.1779447,,,,,,,,,,,,,, -14500,0.51341367,2.0602643,,,,,,,,,,,,,, -14600,0.9008968,2.0613077,,,,,,,,,,,,,, -14700,1.0537843,2.1140494,,,,,,,,,,,,,, -14741,,,1.091014,0.3399796344439577,1.137504,0.3313090744084111,5348.0,0.8089821,0.2671785184733817,2472.0,11556.709911346436,12703.967337846756,11556.709911346436,1146.2477378845217,0.3858742713928222,0.0 -14800,0.65372616,2.0069,,,,,,,,,,,,,, -14900,0.6813357,2.083054,,,,,,,,,,,,,, -15000,1.130462,2.084042,,,,,,,,,,,,,, -15100,0.9009434,2.0819998,,,,,,,,,,,,,, -15200,0.90302134,2.0287373,,,,,,,,,,,,,, -15300,0.56336623,2.0730855,,,,,,,,,,,,,, -15400,1.1045482,2.0387163,,,,,,,,,,,,,, -15500,0.74360883,2.067439,,,,,,,,,,,,,, -15600,0.9476809,2.0399837,,,,,,,,,,,,,, -15700,0.52932674,2.0582762,,,,,,,,,,,,,, -15800,0.82848126,2.0689085,,,,,,,,,,,,,, -15900,1.006844,2.0850604,,,,,,,,,,,,,, -16000,1.0630212,2.0349798,,,,,,,,,,,,,, -16100,0.8149728,2.0666957,,,,,,,,,,,,,, -16200,0.5917758,2.0374234,,,,,,,,,,,,,, -16300,0.76264256,2.0682778,,,,,,,,,,,,,, -16400,0.9765174,2.0364861,,,,,,,,,,,,,, -16500,0.5820055,2.034402,,,,,,,,,,,,,, -16567,,,1.0301901,0.326469812233411,1.10055,0.3198200372669608,5348.0,0.77649456,0.2576117644669226,2472.0,12997.059435367584,14275.884579896929,12997.059435367584,1277.6439950466156,0.4783322811126709,0.0 -16600,0.62319964,2.0048337,,,,,,,,,,,,,, -16700,0.87101257,1.9687852,,,,,,,,,,,,,, -16800,0.72918457,2.059888,,,,,,,,,,,,,, -16900,0.5476906,1.9753867,,,,,,,,,,,,,, -17000,0.70912504,1.9949405,,,,,,,,,,,,,, -17100,0.6984801,1.9968511,,,,,,,,,,,,,, -17200,1.046294,2.0556839,,,,,,,,,,,,,, -17300,0.71655226,2.013692,,,,,,,,,,,,,, -17400,0.5002393,1.9884598,,,,,,,,,,,,,, -17500,0.78872234,1.9437102,,,,,,,,,,,,,, -17600,0.61400473,1.9892669,,,,,,,,,,,,,, -17700,1.1660029,2.0284681,,,,,,,,,,,,,, -17800,0.7121152,1.9786395,,,,,,,,,,,,,, -17900,0.6803032,2.023457,,,,,,,,,,,,,, -18000,0.6031041,2.0198696,,,,,,,,,,,,,, -18100,0.58047515,1.9828401,,,,,,,,,,,,,, -18200,1.3374615,2.060221,,,,,,,,,,,,,, -18300,0.55151826,1.9620292,,,,,,,,,,,,,, -18400,0.6241441,1.9986035,,,,,,,,,,,,,, -18413,,,1.1991059,0.3634266049870247,1.2338159,0.3535340857526284,5348.0,0.8797922,0.281091950520992,2472.0,14437.434829473495,15846.326422691343,14437.434829473495,1407.5817940235138,0.5285263061523438,0.0 -18500,1.0240175,2.0121136,,,,,,,,,,,,,, -18600,0.6734064,1.9541906,,,,,,,,,,,,,, -18700,0.7820144,1.9879541,,,,,,,,,,,,,, -18800,1.5726249,1.964764,,,,,,,,,,,,,, -18900,0.74107474,1.9302555,,,,,,,,,,,,,, -19000,0.7460172,2.0433192,,,,,,,,,,,,,, -19100,1.0159135,1.9761555,,,,,,,,,,,,,, -19200,0.7732463,1.9945388,,,,,,,,,,,,,, -19300,0.98112726,1.9799502,,,,,,,,,,,,,, -19400,0.73257214,2.0091627,,,,,,,,,,,,,, -19500,0.71192795,1.9401426,,,,,,,,,,,,,, -19600,0.9337178,1.9466518,,,,,,,,,,,,,, -19700,0.7741991,1.9000075,,,,,,,,,,,,,, -19800,0.92312413,1.9385134,,,,,,,,,,,,,, -19900,0.6484752,2.0175169,,,,,,,,,,,,,, -20000,0.8976198,1.9457252,,,,,,,,,,,,,, -20100,0.97242904,1.9288595,,,,,,,,,,,,,, -20200,0.6929891,1.9541196,,,,,,,,,,,,,, -20256,,,0.88449436,0.288827478657715,1.0200983,0.3036098747791498,5348.0,0.7080025,0.2387626185688461,2472.0,15878.150474071505,17418.24999141693,15878.150474071505,1538.6586077213287,0.5804932117462158,0.0 -20300,1.173129,2.0331833,,,,,,,,,,,,,, -20400,0.691438,2.0113719,,,,,,,,,,,,,, -20500,0.6892476,1.985394,,,,,,,,,,,,,, -20600,0.8989007,1.9939685,,,,,,,,,,,,,, -20700,0.8471918,1.9239546,,,,,,,,,,,,,, -20800,1.0062128,1.9871142,,,,,,,,,,,,,, -20900,0.9032702,1.9932837,,,,,,,,,,,,,, -21000,0.78303945,1.9410998,,,,,,,,,,,,,, -21100,0.48740256,1.9378977,,,,,,,,,,,,,, -21200,0.53743035,1.9795759,,,,,,,,,,,,,, -21300,0.8467962,1.9636015,,,,,,,,,,,,,, -21400,1.0054728,1.9730583,,,,,,,,,,,,,, -21500,0.74851406,1.9235524,,,,,,,,,,,,,, -21600,0.93957347,1.9209678,,,,,,,,,,,,,, -21700,0.86844146,1.9266835,,,,,,,,,,,,,, -21800,0.66488725,1.8938018,,,,,,,,,,,,,, -21900,0.7531136,1.9322163,,,,,,,,,,,,,, -22000,0.6375026,1.8974206,,,,,,,,,,,,,, -22100,0.6569458,1.8519542,,,,,,,,,,,,,, -22103,,,0.90416443,0.2953098019868589,0.9769954,0.2892244417196868,5348.0,0.67560107,0.2254585339101822,2472.0,17318.73151898384,18990.00296020508,17318.73151898384,1669.7016806602478,0.6308321952819824,0.0 -22200,1.0011492,1.9303385,,,,,,,,,,,,,, -22300,0.94317436,1.9572034,,,,,,,,,,,,,, -22400,0.60729945,1.9581039,,,,,,,,,,,,,, -22500,0.54255795,2.0011725,,,,,,,,,,,,,, -22600,0.58676827,1.9334213,,,,,,,,,,,,,, -22700,1.0520875,1.8700784,,,,,,,,,,,,,, -22800,0.62923414,1.9436938,,,,,,,,,,,,,, -22900,0.6927154,1.9392159,,,,,,,,,,,,,, -23000,0.98207426,1.8840942,,,,,,,,,,,,,, -23100,0.704883,1.8987076,,,,,,,,,,,,,, -23200,0.70036167,1.90166,,,,,,,,,,,,,, -23300,0.9785137,1.8948457,,,,,,,,,,,,,, -23400,0.7860101,1.9111402,,,,,,,,,,,,,, -23500,0.5141234,1.9254299,,,,,,,,,,,,,, -23600,0.51699275,1.9193596,,,,,,,,,,,,,, -23700,0.8412578,1.887677,,,,,,,,,,,,,, -23800,0.87837887,1.8849941,,,,,,,,,,,,,, -23900,0.7242739,1.7969171,,,,,,,,,,,,,, -23950,,,0.83565265,0.2703343074093714,0.94258255,0.2817903588634542,5348.0,0.64889586,0.2191010094855076,2472.0,18758.80566763878,20560.23240017891,18758.80566763878,1799.7231032848358,0.6856436729431152,0.0 -24000,0.9350784,1.8588237,,,,,,,,,,,,,, -24100,0.95157784,1.8933134,,,,,,,,,,,,,, -24200,0.53964293,1.8295869,,,,,,,,,,,,,, -24300,0.9455451,1.8889989,,,,,,,,,,,,,, -24400,0.7411234,1.8988645,,,,,,,,,,,,,, -24500,0.5903552,1.7738965,,,,,,,,,,,,,, -24600,0.64302343,1.9215654,,,,,,,,,,,,,, -24700,0.71748465,1.9112718,,,,,,,,,,,,,, -24800,0.59052646,1.8695116,,,,,,,,,,,,,, -24900,0.5198293,1.7937063,,,,,,,,,,,,,, -25000,0.6159112,1.8996534,,,,,,,,,,,,,, -25100,0.695884,1.798456,,,,,,,,,,,,,, -25200,0.6288899,1.820506,,,,,,,,,,,,,, -25300,0.6624283,1.811315,,,,,,,,,,,,,, -25400,0.5339246,1.7683953,,,,,,,,,,,,,, -25500,0.7470896,1.8808019,,,,,,,,,,,,,, -25600,1.0131935,1.8785483,,,,,,,,,,,,,, -25700,0.7581015,1.8459184,,,,,,,,,,,,,, -25774,,,0.83382195,0.2674261006030476,0.90188086,0.2706006159668652,5348.0,0.62141186,0.208884284930839,2472.0,20198.753033638,22132.027775764465,20198.753033638,1931.4437320232391,0.7362205982208252,0.0 -25800,0.67241377,1.8339422,,,,,,,,,,,,,, -25900,0.63061184,1.8011961,,,,,,,,,,,,,, -26000,1.0564071,1.8704094,,,,,,,,,,,,,, -26100,0.5374898,1.7537756,,,,,,,,,,,,,, -26200,0.70431435,1.821936,,,,,,,,,,,,,, -26300,0.51565933,1.8297144,,,,,,,,,,,,,, -26400,0.4687543,1.744071,,,,,,,,,,,,,, -26500,0.73953754,1.7971882,,,,,,,,,,,,,, -26600,0.4827238,1.7839311,,,,,,,,,,,,,, -26700,0.6116471,1.8175381,,,,,,,,,,,,,, -26800,0.9767289,1.7992799,,,,,,,,,,,,,, -26900,0.5634514,1.789628,,,,,,,,,,,,,, -27000,0.585845,1.7484007,,,,,,,,,,,,,, -27100,0.54535306,1.8590478,,,,,,,,,,,,,, -27200,0.6084375,1.7994591,,,,,,,,,,,,,, -27300,0.73477703,1.8493567,,,,,,,,,,,,,, -27400,0.576098,1.7944896,,,,,,,,,,,,,, -27500,0.8474095,1.7618412,,,,,,,,,,,,,, -27600,0.7073307,1.8003024,,,,,,,,,,,,,, -27620,,,0.7256441,0.2424350146842476,0.8796342,0.2647788601716597,5348.0,0.6013053,0.2044157374118985,2472.0,21638.79001927376,23703.77072691917,21638.79001927376,2063.012087583542,0.7946665287017822,0.0 -27700,0.8435535,1.7619996,,,,,,,,,,,,,, -27800,0.7515245,1.7720244,,,,,,,,,,,,,, -27900,0.57116735,1.8034213,,,,,,,,,,,,,, -28000,0.54262066,1.7974982,,,,,,,,,,,,,, -28100,0.7447726,1.7891291,,,,,,,,,,,,,, -28200,0.7575212,1.8525189,,,,,,,,,,,,,, -28300,0.5438994,1.788396,,,,,,,,,,,,,, -28400,0.95539916,1.8187256,,,,,,,,,,,,,, -28500,0.5503793,1.7925842,,,,,,,,,,,,,, -28600,0.61232966,1.7908361,,,,,,,,,,,,,, -28700,0.69572264,1.7853379,,,,,,,,,,,,,, -28800,0.77106047,1.7668675,,,,,,,,,,,,,, -28900,0.7195387,1.7027241,,,,,,,,,,,,,, -29000,0.8248612,1.8111023,,,,,,,,,,,,,, -29100,0.7497535,1.7745316,,,,,,,,,,,,,, -29200,0.86001605,1.7705417,,,,,,,,,,,,,, -29300,0.7882125,1.7847024,,,,,,,,,,,,,, -29400,0.7256404,1.7950109,,,,,,,,,,,,,, -29464,,,0.7058413,0.2416801166096091,0.8496082,0.2589764136825743,5348.0,0.5760872,0.1957426929092275,2472.0,23079.2897040844,25275.1791806221,23079.2897040844,2193.7856407165527,0.8524501323699951,0.0 -29500,0.8925212,1.8258481,,,,,,,,,,,,,, -29600,0.63178575,1.7329243,,,,,,,,,,,,,, -29700,0.87805355,1.8139724,,,,,,,,,,,,,, -29800,0.6316626,1.7641791,,,,,,,,,,,,,, -29900,0.8346553,1.787634,,,,,,,,,,,,,, -30000,0.6840506,1.7570895,,,,,,,,,,,,,, -30100,0.66503656,1.7742743,,,,,,,,,,,,,, -30200,0.5545476,1.8215635,,,,,,,,,,,,,, -30300,0.526566,1.7314193,,,,,,,,,,,,,, -30400,0.5346506,1.7333245,,,,,,,,,,,,,, -30500,0.84788555,1.8348663,,,,,,,,,,,,,, -30600,0.68761134,1.7826402,,,,,,,,,,,,,, -30700,0.6359146,1.768116,,,,,,,,,,,,,, -30800,0.6663406,1.7880467,,,,,,,,,,,,,, -30900,0.6292365,1.7265457,,,,,,,,,,,,,, -31000,0.81525075,1.740643,,,,,,,,,,,,,, -31100,0.9822072,1.7286215,,,,,,,,,,,,,, -31200,0.54840237,1.7534093,,,,,,,,,,,,,, -31300,0.608124,1.6958718,,,,,,,,,,,,,, -31322,,,0.71590084,0.2423552059780039,0.8296148,0.2496210548673933,5348.0,0.5572826,0.1906851095809721,2472.0,24519.32209968567,26845.321516752243,24519.32209968567,2323.7597975730896,0.9073889255523682,0.0 -31400,0.7904003,1.7697195,,,,,,,,,,,,,, -31500,0.64286864,1.7970172,,,,,,,,,,,,,, -31600,0.69094414,1.7174793,,,,,,,,,,,,,, -31700,0.6411223,1.7870393,,,,,,,,,,,,,, -31800,0.6865558,1.7549382,,,,,,,,,,,,,, -31900,0.87048084,1.7438762,,,,,,,,,,,,,, -32000,0.5940178,1.7977008,,,,,,,,,,,,,, -32100,0.67580503,1.6703762,,,,,,,,,,,,,, -32200,0.51263523,1.7436695,,,,,,,,,,,,,, -32300,0.8801638,1.7185017,,,,,,,,,,,,,, -32400,0.6652186,1.7548289,,,,,,,,,,,,,, -32500,0.74556273,1.7197411,,,,,,,,,,,,,, -32600,0.6173228,1.7218099,,,,,,,,,,,,,, -32700,0.597372,1.6873902,,,,,,,,,,,,,, -32800,0.6157814,1.8202633,,,,,,,,,,,,,, -32900,0.52067345,1.7260382,,,,,,,,,,,,,, -33000,0.67036545,1.7424221,,,,,,,,,,,,,, -33100,0.9016411,1.7698358,,,,,,,,,,,,,, -33176,,,0.5189762,0.1857237276933245,0.810009,0.2471687729901426,5348.0,0.54166,0.187049336826925,2472.0,25959.6188287735,28427.117254018784,25959.6188287735,2465.121926546097,0.965153694152832,0.0 -33200,0.71217674,1.6565064,,,,,,,,,,,,,, -33300,0.7688876,1.7436725,,,,,,,,,,,,,, -33400,0.5601228,1.7598193,,,,,,,,,,,,,, -33500,0.5542472,1.7175049,,,,,,,,,,,,,, -33600,0.6769914,1.6595695,,,,,,,,,,,,,, -33700,0.7115834,1.7843795,,,,,,,,,,,,,, -33800,0.86133987,1.7061034,,,,,,,,,,,,,, -33900,0.64836293,1.6958739,,,,,,,,,,,,,, -34000,0.72217023,1.6907531,,,,,,,,,,,,,, -34100,0.7007852,1.6960167,,,,,,,,,,,,,, -34200,0.5798257,1.6770113,,,,,,,,,,,,,, -34300,0.56797016,1.7488801,,,,,,,,,,,,,, -34400,0.63641155,1.6734465,,,,,,,,,,,,,, -34500,0.61240077,1.6851821,,,,,,,,,,,,,, -34600,0.6303085,1.6963241,,,,,,,,,,,,,, -34700,0.55337965,1.6785839,,,,,,,,,,,,,, -34800,0.98665464,1.666137,,,,,,,,,,,,,, -34900,0.79382235,1.636563,,,,,,,,,,,,,, -35000,0.5938071,1.7767446,,,,,,,,,,,,,, -35020,,,0.4539365,0.1625599436555531,0.78835493,0.2402656960522123,5348.0,0.5160036,0.1775638291389921,2472.0,27400.167891979218,29999.38902139664,27400.167891979218,2596.715493917465,1.015394926071167,0.0 -35100,0.72539544,1.6583346,,,,,,,,,,,,,, -35200,0.64330345,1.6360222,,,,,,,,,,,,,, -35300,0.69006026,1.6997384,,,,,,,,,,,,,, -35400,0.80944824,1.6766744,,,,,,,,,,,,,, -35500,0.6232861,1.7115415,,,,,,,,,,,,,, -35600,0.86026233,1.6522126,,,,,,,,,,,,,, -35700,0.6251313,1.6043768,,,,,,,,,,,,,, -35800,0.5794859,1.6502106,,,,,,,,,,,,,, -35900,0.72135484,1.6613281,,,,,,,,,,,,,, -36000,0.5503356,1.6735353,,,,,,,,,,,,,, -36100,0.7113036,1.6200612,,,,,,,,,,,,,, -36200,0.69338244,1.6331285,,,,,,,,,,,,,, -36300,0.60639864,1.6404598,,,,,,,,,,,,,, -36400,0.5779376,1.6661105,,,,,,,,,,,,,, -36500,0.6027106,1.6454291,,,,,,,,,,,,,, -36600,0.8173215,1.721428,,,,,,,,,,,,,, -36700,0.6439328,1.682893,,,,,,,,,,,,,, -36800,0.66700983,1.6306775,,,,,,,,,,,,,, -36863,,,0.44536912,0.1580918412839946,0.7607473,0.2320206223389362,5348.0,0.5011267,0.1698048057197408,2472.0,28840.363377332687,31573.44949054718,28840.363377332687,2730.4466235637665,1.0714967250823977,0.0 -36900,0.7996556,1.623744,,,,,,,,,,,,,, -37000,0.58889544,1.7332094,,,,,,,,,,,,,, -37100,0.6941535,1.6166933,,,,,,,,,,,,,, -37200,0.7248189,1.6203885,,,,,,,,,,,,,, -37300,0.89744085,1.7136507,,,,,,,,,,,,,, -37400,0.6337225,1.5870622,,,,,,,,,,,,,, -37500,0.6896412,1.6488962,,,,,,,,,,,,,, -37600,0.66509753,1.6349051,,,,,,,,,,,,,, -37700,0.6169227,1.6558101,,,,,,,,,,,,,, -37800,0.59187216,1.6369771,,,,,,,,,,,,,, -37900,0.6531769,1.67015,,,,,,,,,,,,,, -38000,0.57675046,1.6613914,,,,,,,,,,,,,, -38100,0.63261914,1.659769,,,,,,,,,,,,,, -38200,0.5770907,1.6628491,,,,,,,,,,,,,, -38300,0.5304195,1.7321135,,,,,,,,,,,,,, -38400,0.6323228,1.6719693,,,,,,,,,,,,,, -38500,0.868311,1.709422,,,,,,,,,,,,,, -38600,0.7916696,1.6567405,,,,,,,,,,,,,, -38700,0.73469764,1.6619102,,,,,,,,,,,,,, -38718,,,0.41135126,0.1500194003983548,0.7361761,0.2231962694420576,5348.0,0.4796191,0.1635691507728556,2472.0,30280.498059034348,33146.26831173897,30280.498059034348,2862.9936952590942,1.1281471252441406,0.0 -38800,0.71002024,1.640579,,,,,,,,,,,,,, -38900,0.5877658,1.5660493,,,,,,,,,,,,,, -39000,0.5827138,1.6280779,,,,,,,,,,,,,, -39100,0.738109,1.6286534,,,,,,,,,,,,,, -39200,0.7719378,1.6215528,,,,,,,,,,,,,, -39300,0.6519159,1.6253452,,,,,,,,,,,,,, -39400,0.71073484,1.6369282,,,,,,,,,,,,,, -39500,0.56874526,1.5575397,,,,,,,,,,,,,, -39600,0.6189989,1.6166062,,,,,,,,,,,,,, -39700,0.61987394,1.6125896,,,,,,,,,,,,,, -39800,0.64098865,1.5874628,,,,,,,,,,,,,, -39900,0.57081884,1.6508434,,,,,,,,,,,,,, -40000,0.69561505,1.5636706,,,,,,,,,,,,,, -40100,0.599533,1.5757908,,,,,,,,,,,,,, -40200,0.96846294,1.6070216,,,,,,,,,,,,,, -40300,0.716032,1.6301556,,,,,,,,,,,,,, -40400,0.63984805,1.5536294,,,,,,,,,,,,,, -40500,0.54644734,1.5420078,,,,,,,,,,,,,, -40574,,,0.40231338,0.1454793648395642,0.70659876,0.2144781177288394,5348.0,0.45901236,0.1562366705258668,2472.0,31720.88926625252,34716.65667510033,31720.88926625252,2992.8555755615234,1.185434341430664,0.0 -40600,0.700066,1.6325212,,,,,,,,,,,,,, -40700,0.60577685,1.5089246,,,,,,,,,,,,,, -40800,0.71383023,1.6946548,,,,,,,,,,,,,, -40900,0.7196324,1.6460922,,,,,,,,,,,,,, -41000,0.58973724,1.5936892,,,,,,,,,,,,,, -41100,0.69335103,1.5997615,,,,,,,,,,,,,, -41200,0.61607707,1.6224109,,,,,,,,,,,,,, -41300,0.69929993,1.5845072,,,,,,,,,,,,,, -41400,0.7806133,1.600275,,,,,,,,,,,,,, -41500,0.6711699,1.5764377,,,,,,,,,,,,,, -41600,0.6549175,1.5679923,,,,,,,,,,,,,, -41700,0.61515105,1.5509883,,,,,,,,,,,,,, -41800,0.64811593,1.5744928,,,,,,,,,,,,,, -41900,0.59980977,1.5986228,,,,,,,,,,,,,, -42000,0.59784454,1.535893,,,,,,,,,,,,,, -42100,0.6388639,1.61641,,,,,,,,,,,,,, -42200,0.5989913,1.6343981,,,,,,,,,,,,,, -42300,0.84168327,1.6034005,,,,,,,,,,,,,, -42400,0.7299498,1.618905,,,,,,,,,,,,,, -42424,,,0.37856817,0.1399152124411051,0.6976467,0.214381571198239,5348.0,0.4475745,0.1520524851217679,2472.0,33161.03370857239,36289.01644325256,33161.03370857239,3124.930632591248,1.2468442916870115,0.0 -42500,0.80674726,1.545546,,,,,,,,,,,,,, -42600,0.65224475,1.5776926,,,,,,,,,,,,,, -42700,0.7209218,1.5416676,,,,,,,,,,,,,, -42800,0.78907895,1.5537101,,,,,,,,,,,,,, -42900,0.62079537,1.5421797,,,,,,,,,,,,,, -43000,0.7355025,1.6036631,,,,,,,,,,,,,, -43100,0.7344759,1.5608557,,,,,,,,,,,,,, -43200,0.69648844,1.5618571,,,,,,,,,,,,,, -43300,0.696879,1.5652612,,,,,,,,,,,,,, -43400,0.87414956,1.5397837,,,,,,,,,,,,,, -43500,0.59661543,1.5396922,,,,,,,,,,,,,, -43600,0.72091264,1.5864016,,,,,,,,,,,,,, -43700,0.58731586,1.5480348,,,,,,,,,,,,,, -43800,0.68325585,1.5482994,,,,,,,,,,,,,, -43900,0.66670954,1.5152733,,,,,,,,,,,,,, -44000,0.63436913,1.602809,,,,,,,,,,,,,, -44100,0.71217257,1.4962074,,,,,,,,,,,,,, -44200,0.83320624,1.5210987,,,,,,,,,,,,,, -44268,,,0.42395383,0.1483157612824591,0.67980665,0.207304710505228,5348.0,0.43433073,0.1464464891434607,2472.0,34601.290254592896,37859.7775645256,34601.290254592896,3255.2952768802643,1.3067612648010254,0.0 -44300,0.72958463,1.5629559,,,,,,,,,,,,,, -44400,0.61536217,1.5358734,,,,,,,,,,,,,, -44500,0.5087456,1.5248617,,,,,,,,,,,,,, -44600,0.6618101,1.6159961,,,,,,,,,,,,,, -44700,0.63501894,1.5204773,,,,,,,,,,,,,, -44800,0.5114561,1.4788433,,,,,,,,,,,,,, -44900,0.61033475,1.5163329,,,,,,,,,,,,,, -45000,0.62098217,1.5847417,,,,,,,,,,,,,, -45100,0.6277794,1.5737065,,,,,,,,,,,,,, -45200,0.61279726,1.5791616,,,,,,,,,,,,,, -45300,0.62314796,1.4982568,,,,,,,,,,,,,, -45400,0.5407949,1.5382206,,,,,,,,,,,,,, -45500,0.76705945,1.5333312,,,,,,,,,,,,,, -45600,0.6949182,1.5370637,,,,,,,,,,,,,, -45700,0.89685816,1.5147126,,,,,,,,,,,,,, -45800,0.8086463,1.5587084,,,,,,,,,,,,,, -45900,0.7469556,1.5306432,,,,,,,,,,,,,, -46000,0.5565277,1.4854517,,,,,,,,,,,,,, -46100,0.6009767,1.4962782,,,,,,,,,,,,,, -46107,,,0.3622536,0.1347455764944699,0.65977854,0.2022360176487058,5348.0,0.4211136,0.1439684764284118,2472.0,36041.80644035339,39434.40048265457,36041.80644035339,3389.268036603928,1.362666368484497,0.0 -46200,0.66435355,1.5139056,,,,,,,,,,,,,, -46300,0.73451996,1.5356328,,,,,,,,,,,,,, -46400,0.64135844,1.4906545,,,,,,,,,,,,,, -46500,0.6383321,1.5408087,,,,,,,,,,,,,, -46600,0.8193317,1.5655156,,,,,,,,,,,,,, -46700,0.6274105,1.4788822,,,,,,,,,,,,,, -46800,0.59832054,1.5743204,,,,,,,,,,,,,, -46900,0.66558355,1.519496,,,,,,,,,,,,,, -47000,0.77795374,1.5625166,,,,,,,,,,,,,, -47100,0.6844051,1.5887282,,,,,,,,,,,,,, -47200,0.69246936,1.5192065,,,,,,,,,,,,,, -47300,0.69475305,1.5488875,,,,,,,,,,,,,, -47400,0.57561815,1.5325531,,,,,,,,,,,,,, -47500,0.6473039,1.4665393,,,,,,,,,,,,,, -47600,0.65506035,1.4562163,,,,,,,,,,,,,, -47700,0.74730533,1.4821658,,,,,,,,,,,,,, -47800,0.62923807,1.4806254,,,,,,,,,,,,,, -47900,0.61779225,1.4452198,,,,,,,,,,,,,, -47963,,,0.33330858,0.1254129063868123,0.6382336,0.195207430220995,5348.0,0.40312463,0.1372453435703694,2472.0,37482.298840522766,41007.16001605988,37482.298840522766,3521.39735031128,1.421417236328125,0.0 -48000,0.72829664,1.5379691,,,,,,,,,,,,,, -48100,0.70559776,1.5045708,,,,,,,,,,,,,, -48200,0.6206104,1.4774265,,,,,,,,,,,,,, -48300,0.7306031,1.5273362,,,,,,,,,,,,,, -48400,0.63132644,1.5013134,,,,,,,,,,,,,, -48500,0.57533437,1.5255427,,,,,,,,,,,,,, -48600,0.7633123,1.4890915,,,,,,,,,,,,,, -48700,0.7273132,1.5024604,,,,,,,,,,,,,, -48800,0.66916066,1.4845817,,,,,,,,,,,,,, -48900,0.74027914,1.4836668,,,,,,,,,,,,,, -49000,0.8425069,1.4763783,,,,,,,,,,,,,, -49100,0.58188236,1.4548552,,,,,,,,,,,,,, -49200,0.77630454,1.4827837,,,,,,,,,,,,,, -49300,0.90525055,1.4953896,,,,,,,,,,,,,, -49400,0.6069859,1.4713488,,,,,,,,,,,,,, -49500,0.7738265,1.4776924,,,,,,,,,,,,,, -49600,0.72548014,1.4762714,,,,,,,,,,,,,, -49700,0.61671704,1.4686742,,,,,,,,,,,,,, -49800,0.70645285,1.4914839,,,,,,,,,,,,,, -49822,,,0.31033447,0.1149222233388164,0.6196946,0.1900711547930525,5348.0,0.38547108,0.1314362317957467,2472.0,38922.43569779396,42578.880200862885,38922.43569779396,3652.845793962479,1.4763050079345703,0.0 -49900,0.7423766,1.5017141,,,,,,,,,,,,,, -50000,0.64842236,1.4668998,,,,,,,,,,,,,, -50100,0.64007354,1.4785689,,,,,,,,,,,,,, -50200,0.7629869,1.4453886,,,,,,,,,,,,,, -50300,0.6742171,1.5072944,,,,,,,,,,,,,, -50400,0.6395108,1.4861796,,,,,,,,,,,,,, -50500,0.62715083,1.4358845,,,,,,,,,,,,,, -50600,0.60696614,1.4750414,,,,,,,,,,,,,, -50700,0.63317764,1.4625931,,,,,,,,,,,,,, -50800,0.71852386,1.4907882,,,,,,,,,,,,,, -50900,0.5938553,1.395477,,,,,,,,,,,,,, -51000,0.69598013,1.4781744,,,,,,,,,,,,,, -51100,0.61908615,1.511212,,,,,,,,,,,,,, -51200,0.6827611,1.4386427,,,,,,,,,,,,,, -51300,0.7628414,1.4221661,,,,,,,,,,,,,, -51400,0.73767775,1.4973594,,,,,,,,,,,,,, -51500,0.74488133,1.4044483,,,,,,,,,,,,,, -51600,0.6869849,1.4634029,,,,,,,,,,,,,, -51669,,,0.29647183,0.1116216549116634,0.6016999,0.1810826727941531,5348.0,0.3817462,0.1289785306603294,2472.0,40362.72778439522,44151.99367618561,40362.72778439522,3785.52684879303,1.538241624832153,0.0 -51700,0.6983575,1.4124202,,,,,,,,,,,,,, -51800,0.7006516,1.4618045,,,,,,,,,,,,,, -51900,0.78214884,1.3885485,,,,,,,,,,,,,, -52000,0.7153401,1.3666393,,,,,,,,,,,,,, -52100,0.65697825,1.4992493,,,,,,,,,,,,,, -52200,0.6623849,1.4504534,,,,,,,,,,,,,, -52300,0.6573242,1.4168564,,,,,,,,,,,,,, -52400,0.6152126,1.4941912,,,,,,,,,,,,,, -52500,0.71808237,1.4265579,,,,,,,,,,,,,, -52600,0.69908077,1.4337714,,,,,,,,,,,,,, -52700,0.6946559,1.460248,,,,,,,,,,,,,, -52800,0.6384547,1.362993,,,,,,,,,,,,,, -52900,0.6476799,1.4302332,,,,,,,,,,,,,, -53000,0.7081899,1.3964541,,,,,,,,,,,,,, -53100,0.65813875,1.4196992,,,,,,,,,,,,,, -53200,0.9248075,1.3653346,,,,,,,,,,,,,, -53300,0.6078083,1.397601,,,,,,,,,,,,,, -53400,0.6350057,1.4175776,,,,,,,,,,,,,, -53500,0.6930179,1.4204491,,,,,,,,,,,,,, -53518,,,0.30620503,0.1127560213666556,0.57299715,0.1762456916110719,5348.0,0.35301486,0.1198383198261328,2472.0,41802.975400447845,45724.815037965775,41802.975400447845,3917.9606053829193,1.5975022315979004,0.0 -53600,0.6003255,1.3863475,,,,,,,,,,,,,, -53700,0.83158153,1.4273696,,,,,,,,,,,,,, -53800,0.68131495,1.4223621,,,,,,,,,,,,,, -53900,0.61892307,1.3650723,,,,,,,,,,,,,, -54000,0.73890436,1.3928628,,,,,,,,,,,,,, -54100,0.678462,1.4070823,,,,,,,,,,,,,, -54200,0.8100922,1.4007292,,,,,,,,,,,,,, -54300,0.66541326,1.3334,,,,,,,,,,,,,, -54400,0.7224562,1.3745915,,,,,,,,,,,,,, -54500,0.9149801,1.4571835,,,,,,,,,,,,,, -54600,0.78781635,1.3993831,,,,,,,,,,,,,, -54700,0.715706,1.387436,,,,,,,,,,,,,, -54800,0.9457056,1.393625,,,,,,,,,,,,,, -54900,0.6165685,1.4171672,,,,,,,,,,,,,, -55000,0.68708956,1.415846,,,,,,,,,,,,,, -55100,0.6834706,1.3772222,,,,,,,,,,,,,, -55200,0.7505147,1.4234546,,,,,,,,,,,,,, -55300,0.9401795,1.3902981,,,,,,,,,,,,,, -55366,,,0.28244948,0.1038226074516651,0.55347747,0.1705397916525869,5348.0,0.33865485,0.115654134422034,2472.0,43243.05828857422,47298.76901054382,43243.05828857422,4051.69429731369,1.6556503772735596,0.0 -55400,0.8408465,1.4334239,,,,,,,,,,,,,, -55500,0.6479648,1.3517478,,,,,,,,,,,,,, -55600,0.70053476,1.4110782,,,,,,,,,,,,,, -55700,0.634661,1.3885719,,,,,,,,,,,,,, -55800,0.5718055,1.3904299,,,,,,,,,,,,,, -55900,0.80394447,1.3477464,,,,,,,,,,,,,, -56000,0.7330284,1.3429224,,,,,,,,,,,,,, -56100,0.74713266,1.3692894,,,,,,,,,,,,,, -56200,0.87824523,1.4096766,,,,,,,,,,,,,, -56300,0.66310936,1.3637961,,,,,,,,,,,,,, -56400,0.6727142,1.4488434,,,,,,,,,,,,,, -56500,0.73309803,1.3479847,,,,,,,,,,,,,, -56600,0.87957776,1.4011959,,,,,,,,,,,,,, -56700,0.8336241,1.3270005,,,,,,,,,,,,,, -56800,0.6707721,1.3962508,,,,,,,,,,,,,, -56900,0.65307873,1.3091649,,,,,,,,,,,,,, -57000,0.74970096,1.3755584,,,,,,,,,,,,,, -57100,0.67356,1.3999527,,,,,,,,,,,,,, -57200,0.6757816,1.3430375,,,,,,,,,,,,,, -57223,,,0.27888992,0.1028839746470579,0.54089737,0.1678364887957751,5348.0,0.3345989,0.1149026059756667,2472.0,44683.50492525101,48871.916709423065,44683.50492525101,4184.260848999023,1.7116804122924805,0.0 -57300,0.7074818,1.3610398,,,,,,,,,,,,,, -57400,0.6871236,1.3447194,,,,,,,,,,,,,, -57500,0.83803767,1.3963351,,,,,,,,,,,,,, -57600,0.716017,1.3866228,,,,,,,,,,,,,, -57700,0.78948385,1.3270669,,,,,,,,,,,,,, -57800,0.7546094,1.3252225,,,,,,,,,,,,,, -57900,0.8288526,1.3655266,,,,,,,,,,,,,, -58000,0.76404643,1.3470863,,,,,,,,,,,,,, -58100,0.74088806,1.3457468,,,,,,,,,,,,,, -58200,0.81416494,1.382623,,,,,,,,,,,,,, -58300,0.7223473,1.3176186,,,,,,,,,,,,,, -58400,0.7545978,1.3121349,,,,,,,,,,,,,, -58500,0.8334565,1.3201195,,,,,,,,,,,,,, -58600,0.73125637,1.3553843,,,,,,,,,,,,,, -58700,0.6020319,1.3144503,,,,,,,,,,,,,, -58800,0.847088,1.334531,,,,,,,,,,,,,, -58900,0.69077057,1.3046567,,,,,,,,,,,,,, -59000,0.6767703,1.294775,,,,,,,,,,,,,, -59078,,,0.24648692,0.0924602903006723,0.51485807,0.1599196732865404,5348.0,0.31451392,0.1089513131436231,2472.0,46124.06595420837,50443.68307852745,46124.06595420837,4315.326943397522,1.769678831100464,0.0 -59100,0.77598035,1.2865908,,,,,,,,,,,,,, -59200,0.8596197,1.3659283,,,,,,,,,,,,,, -59300,0.8176308,1.3103906,,,,,,,,,,,,,, -59400,0.91550684,1.348114,,,,,,,,,,,,,, -59500,0.66053236,1.2813478,,,,,,,,,,,,,, -59600,0.8075119,1.3280666,,,,,,,,,,,,,, -59700,0.7931349,1.354949,,,,,,,,,,,,,, -59800,0.7847441,1.243489,,,,,,,,,,,,,, -59900,0.78661114,1.3217943,,,,,,,,,,,,,, -60000,0.76823163,1.3215017,,,,,,,,,,,,,, -60100,0.7171549,1.2898774,,,,,,,,,,,,,, -60200,0.71634334,1.2660744,,,,,,,,,,,,,, -60300,0.72272044,1.3118894,,,,,,,,,,,,,, -60400,0.7853376,1.3164026,,,,,,,,,,,,,, -60500,0.85147107,1.3492328,,,,,,,,,,,,,, -60600,0.7217233,1.3126191,,,,,,,,,,,,,, -60700,0.7223115,1.297996,,,,,,,,,,,,,, -60800,0.68179244,1.2834378,,,,,,,,,,,,,, -60900,0.7488297,1.2802577,,,,,,,,,,,,,, -60931,,,0.22372997,0.0836315461675692,0.4961345,0.1519352752058854,5348.0,0.3011645,0.1023500497633701,2472.0,47564.641932964325,52015.58261036873,47564.641932964325,4446.514243841171,1.8272485733032229,0.0 -61000,0.95486265,1.2725724,,,,,,,,,,,,,, -61100,0.66428155,1.2961364,,,,,,,,,,,,,, -61200,0.7719129,1.2724631,,,,,,,,,,,,,, -61300,0.85765404,1.2728996,,,,,,,,,,,,,, -61400,0.8022362,1.2573886,,,,,,,,,,,,,, -61500,0.7257493,1.3060501,,,,,,,,,,,,,, -61600,0.821982,1.2686211,,,,,,,,,,,,,, -61700,0.8098081,1.2921994,,,,,,,,,,,,,, -61800,0.85221803,1.2692313,,,,,,,,,,,,,, -61900,0.753563,1.2467446,,,,,,,,,,,,,, -62000,0.810338,1.2865821,,,,,,,,,,,,,, -62100,0.70080256,1.2533453,,,,,,,,,,,,,, -62200,0.7349611,1.2240592,,,,,,,,,,,,,, -62300,0.6893751,1.3109506,,,,,,,,,,,,,, -62400,0.81045806,1.2697263,,,,,,,,,,,,,, -62500,0.85080296,1.3236037,,,,,,,,,,,,,, -62600,0.7061587,1.2116402,,,,,,,,,,,,,, -62700,0.69291806,1.2555246,,,,,,,,,,,,,, -62775,,,0.20581685,0.0759066700481866,0.47948486,0.1463066124718808,5348.0,0.2855554,0.0995064286149533,2472.0,49005.04101896286,53589.66577386856,49005.04101896286,4580.05917096138,1.886429786682129,0.0 -62800,0.94900495,1.3117814,,,,,,,,,,,,,, -62900,0.8462754,1.256847,,,,,,,,,,,,,, -63000,0.73670495,1.2671108,,,,,,,,,,,,,, -63100,0.7491948,1.2587621,,,,,,,,,,,,,, -63200,0.84951127,1.296201,,,,,,,,,,,,,, -63300,0.8405752,1.2403258,,,,,,,,,,,,,, -63400,0.77906144,1.2293559,,,,,,,,,,,,,, -63500,0.85279614,1.2607925,,,,,,,,,,,,,, -63600,0.8417362,1.2439817,,,,,,,,,,,,,, -63700,0.7923113,1.2094988,,,,,,,,,,,,,, -63800,0.8128863,1.2760776,,,,,,,,,,,,,, -63900,1.0326331,1.2389692,,,,,,,,,,,,,, -64000,1.0481529,1.2259785,,,,,,,,,,,,,, -64100,0.8631455,1.228765,,,,,,,,,,,,,, -64200,0.85018915,1.2256025,,,,,,,,,,,,,, -64300,0.6985715,1.1952791,,,,,,,,,,,,,, -64400,0.77013737,1.2646831,,,,,,,,,,,,,, -64500,0.68154955,1.235596,,,,,,,,,,,,,, -64600,0.75385153,1.2420353,,,,,,,,,,,,,, -64623,,,0.206059,0.0774030837838861,0.46793726,0.1434971084314085,5348.0,0.2775251,0.0937176284199622,2472.0,50445.12503552437,55160.89828634262,50445.12503552437,4711.07154917717,1.944622278213501,0.0 -64700,0.81141216,1.1894144,,,,,,,,,,,,,, -64800,0.9249159,1.2591275,,,,,,,,,,,,,, -64900,0.8581002,1.2339208,,,,,,,,,,,,,, -65000,0.70774,1.189878,,,,,,,,,,,,,, -65100,0.68162775,1.1880717,,,,,,,,,,,,,, -65200,0.9164706,1.220607,,,,,,,,,,,,,, -65300,0.8592545,1.2289534,,,,,,,,,,,,,, -65400,0.8930399,1.2397923,,,,,,,,,,,,,, -65500,0.85854274,1.2149866,,,,,,,,,,,,,, -65600,0.7861038,1.2551218,,,,,,,,,,,,,, -65700,0.68621594,1.178648,,,,,,,,,,,,,, -65800,0.8392032,1.1818451,,,,,,,,,,,,,, -65900,0.8230587,1.2473228,,,,,,,,,,,,,, -66000,0.7949669,1.1972519,,,,,,,,,,,,,, -66100,0.83259505,1.2039623,,,,,,,,,,,,,, -66200,0.762301,1.1752219,,,,,,,,,,,,,, -66300,0.83548504,1.1810737,,,,,,,,,,,,,, -66400,0.8467813,1.2350088,,,,,,,,,,,,,, -66485,,,0.17852962,0.0660985968496317,0.44406882,0.1364781756567577,5348.0,0.2626538,0.0902443483029675,2472.0,51885.24637913704,56732.60203003883,51885.24637913704,4842.50137925148,2.015546560287476,0.0 -66500,0.8143145,1.226178,,,,,,,,,,,,,, -66600,0.8235508,1.153719,,,,,,,,,,,,,, -66700,0.89767945,1.2198712,,,,,,,,,,,,,, -66800,0.86298436,1.2041738,,,,,,,,,,,,,, -66900,0.8542018,1.1575756,,,,,,,,,,,,,, -67000,0.7939006,1.1879505,,,,,,,,,,,,,, -67100,0.8496114,1.1773027,,,,,,,,,,,,,, -67200,0.8084698,1.1906408,,,,,,,,,,,,,, -67300,0.82305443,1.1905435,,,,,,,,,,,,,, -67400,0.7917573,1.2052107,,,,,,,,,,,,,, -67500,0.8438505,1.2013685,,,,,,,,,,,,,, -67600,1.0430397,1.1274279,,,,,,,,,,,,,, -67700,0.8557302,1.2075082,,,,,,,,,,,,,, -67800,0.87806195,1.2121987,,,,,,,,,,,,,, -67900,0.8657973,1.1606055,,,,,,,,,,,,,, -68000,0.82108366,1.137907,,,,,,,,,,,,,, -68100,0.85045934,1.1953912,,,,,,,,,,,,,, -68200,0.9364278,1.1957426,,,,,,,,,,,,,, -68300,0.83660245,1.1363311,,,,,,,,,,,,,, -68347,,,0.1791338,0.0678948165873385,0.4300354,0.1310426059839539,5348.0,0.25467312,0.0852070765543436,2472.0,53325.28448152542,58305.56083631516,53325.28448152542,4975.279711008072,2.079103946685791,0.0 -68400,1.0559468,1.2044309,,,,,,,,,,,,,, -68500,0.97756827,1.1715004,,,,,,,,,,,,,, -68600,0.77869916,1.1543043,,,,,,,,,,,,,, -68700,0.8217143,1.1421869,,,,,,,,,,,,,, -68800,0.82269126,1.1315373,,,,,,,,,,,,,, -68900,0.82779086,1.1865207,,,,,,,,,,,,,, -69000,0.8625908,1.1905329,,,,,,,,,,,,,, -69100,0.73700666,1.1474484,,,,,,,,,,,,,, -69200,0.7681069,1.1158268,,,,,,,,,,,,,, -69300,1.0330086,1.1282384,,,,,,,,,,,,,, -69400,0.81771797,1.12499,,,,,,,,,,,,,, -69500,0.8262797,1.1841334,,,,,,,,,,,,,, -69600,0.8218091,1.1394515,,,,,,,,,,,,,, -69700,0.8967854,1.1714263,,,,,,,,,,,,,, -69800,0.9407914,1.0873076,,,,,,,,,,,,,, -69900,0.82993126,1.1151544,,,,,,,,,,,,,, -70000,0.9017647,1.1355624,,,,,,,,,,,,,, -70100,0.7725584,1.0986984,,,,,,,,,,,,,, -70199,,,0.1901347,0.0665341131428934,0.4159782,0.1261766608416926,5348.0,0.24319124,0.082627505941137,2472.0,54765.86633563042,59878.135501384735,54765.86633563042,5107.127786159515,2.1438803672790527,0.0 -70200,0.8263014,1.1638501,,,,,,,,,,,,,, -70300,0.8955203,1.1561027,,,,,,,,,,,,,, -70400,0.9681806,1.1210233,,,,,,,,,,,,,, -70500,0.78355974,1.1322042,,,,,,,,,,,,,, -70600,0.85764915,1.087835,,,,,,,,,,,,,, -70700,0.9485182,1.1375327,,,,,,,,,,,,,, -70800,0.98149544,1.1273775,,,,,,,,,,,,,, -70900,0.87114394,1.1232044,,,,,,,,,,,,,, -71000,0.9337178,1.115716,,,,,,,,,,,,,, -71100,1.0371205,1.094567,,,,,,,,,,,,,, -71200,0.79922813,1.098161,,,,,,,,,,,,,, -71300,0.8892563,1.1327558,,,,,,,,,,,,,, -71400,1.0237144,1.1409541,,,,,,,,,,,,,, -71500,0.97517407,1.1594213,,,,,,,,,,,,,, -71600,0.9531388,1.1393803,,,,,,,,,,,,,, -71700,1.019828,1.1643201,,,,,,,,,,,,,, -71800,0.8323882,1.0738196,,,,,,,,,,,,,, -71900,0.86624354,1.1692373,,,,,,,,,,,,,, -72000,1.0379431,1.1209172,,,,,,,,,,,,,, -72051,,,0.13278739,0.0494465020682123,0.3995478,0.1224789287196964,5348.0,0.23223098,0.0790932910852477,2472.0,56206.35363817215,61453.973516225815,56206.35363817215,5242.33988404274,2.204277276992798,0.0 -72100,0.92261535,1.0694726,,,,,,,,,,,,,, -72200,1.0118967,1.1375555,,,,,,,,,,,,,, -72300,0.9114291,1.1040182,,,,,,,,,,,,,, -72400,0.8167216,1.0784215,,,,,,,,,,,,,, -72500,0.8023329,1.121857,,,,,,,,,,,,,, -72600,0.9304153,1.146476,,,,,,,,,,,,,, -72700,0.9179664,1.0348638,,,,,,,,,,,,,, -72800,0.87423474,1.1076872,,,,,,,,,,,,,, -72900,0.824804,1.1153136,,,,,,,,,,,,,, -73000,1.0529387,1.1338689,,,,,,,,,,,,,, -73100,0.9742207,1.1389288,,,,,,,,,,,,,, -73200,0.8463871,1.0707021,,,,,,,,,,,,,, -73300,0.7903538,1.0519806,,,,,,,,,,,,,, -73400,0.9220399,1.0805082,,,,,,,,,,,,,, -73500,0.92214006,1.1538068,,,,,,,,,,,,,, -73600,0.9605646,1.054401,,,,,,,,,,,,,, -73700,0.81992024,1.1483923,,,,,,,,,,,,,, -73800,1.016013,1.1072989,,,,,,,,,,,,,, -73899,,,0.13270302,0.0492665924636954,0.39282784,0.119418403699663,5348.0,0.2275426,0.0774277415554607,2472.0,57646.68882584572,63027.98972511292,57646.68882584572,5375.872570037842,2.272150993347168,0.0 -73900,0.8968008,1.0545483,,,,,,,,,,,,,, -74000,1.202573,1.1270822,,,,,,,,,,,,,, -74100,0.88788974,1.0607212,,,,,,,,,,,,,, -74200,0.84836227,1.0169048,,,,,,,,,,,,,, -74300,0.9623189,1.0870425,,,,,,,,,,,,,, -74400,1.1308345,1.1077582,,,,,,,,,,,,,, -74500,0.73416847,1.0520853,,,,,,,,,,,,,, -74600,0.98472255,1.0767212,,,,,,,,,,,,,, -74700,0.9041686,1.0726821,,,,,,,,,,,,,, -74800,1.0549859,1.0393493,,,,,,,,,,,,,, -74900,0.91002303,1.0582347,,,,,,,,,,,,,, -75000,0.97154826,1.0808904,,,,,,,,,,,,,, -75100,0.79845744,1.1140522,,,,,,,,,,,,,, -75200,0.99976224,1.0590861,,,,,,,,,,,,,, -75300,1.2892927,1.0647689,,,,,,,,,,,,,, -75400,1.0862819,1.0657659,,,,,,,,,,,,,, -75500,1.0328903,1.0169976,,,,,,,,,,,,,, -75600,0.84311324,1.0372,,,,,,,,,,,,,, -75700,0.9041549,1.0237985,,,,,,,,,,,,,, -75757,,,0.16273986,0.0621603130729865,0.38568026,0.1166088996591907,5348.0,0.22376499,0.0750512867385696,2472.0,59086.815969944,64599.93062663078,59086.815969944,5507.54785490036,2.3289687633514404,0.0 -75800,1.2073352,1.0463054,,,,,,,,,,,,,, -75900,0.9629314,1.0673989,,,,,,,,,,,,,, -76000,1.0420585,1.0684277,,,,,,,,,,,,,, -76100,0.9173332,1.0530572,,,,,,,,,,,,,, -76200,0.9151423,1.0525738,,,,,,,,,,,,,, -76300,0.960457,1.0434052,,,,,,,,,,,,,, -76400,0.8906695,1.1093519,,,,,,,,,,,,,, -76500,0.8405554,1.063422,,,,,,,,,,,,,, -76600,0.9361579,1.1467472,,,,,,,,,,,,,, -76700,0.91067064,1.0821114,,,,,,,,,,,,,, -76800,1.0272374,1.071723,,,,,,,,,,,,,, -76900,0.8450375,1.0857761,,,,,,,,,,,,,, -77000,0.9028136,1.0709401,,,,,,,,,,,,,, -77100,0.896337,1.0758643,,,,,,,,,,,,,, -77200,0.9967114,1.1020977,,,,,,,,,,,,,, -77300,1.0465937,1.08707,,,,,,,,,,,,,, -77400,0.898162,1.0533926,,,,,,,,,,,,,, -77500,0.9409304,1.0634621,,,,,,,,,,,,,, -77600,1.0263275,1.0873237,,,,,,,,,,,,,, -77618,,,0.1638349,0.0601952248034915,0.3825297,0.1155468878225861,5348.0,0.21958135,0.0742794467125708,2472.0,60527.34537887573,66171.80233621597,60527.34537887573,5638.74435043335,2.393007516860962,0.0 -77700,0.8138798,1.0376806,,,,,,,,,,,,,, -77800,0.8319489,1.0513468,,,,,,,,,,,,,, -77900,0.92579097,1.0668012,,,,,,,,,,,,,, -78000,1.3186176,1.0902498,,,,,,,,,,,,,, -78100,1.264804,1.0981886,,,,,,,,,,,,,, -78200,0.902832,1.0593559,,,,,,,,,,,,,, -78300,0.829372,1.0422231,,,,,,,,,,,,,, -78320,,,,,,,,,,,61068.04541969299,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 4dead2938..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -134.41353964805603,0.0,36.66109991073608,1,0,36.66109991073608,30.871214,2472,1.190014827453131,171.07474374771118,32.05279,1.151012912416401,30.757864,5348,1.1779352558965794 -257.50309681892395,0.0270421504974365,1477.4896006584167,1820,0,1477.4896006584167,3.0782773,2472,0.5873296366258404,1735.0940651893616,3.4769754,0.6580758933017746,3.3963308,5348,0.6373519217586916 -389.10982155799866,0.081702709197998,2917.4977061748505,3665,0,2917.4977061748505,0.64782745,2472,0.2096155017975748,3306.841879844665,0.88411194,0.2777697504927147,0.9202193,5348,0.2683221178446952 -522.0387523174286,0.1281547546386718,4358.005998134613,5510,0,4358.005998134613,0.4876269,2472,0.1638535128876972,4880.403652906418,0.5478368,0.1874115335970506,0.74291396,5348,0.222394933238074 -653.3796038627625,0.1764125823974609,5798.597426176071,7342,0,5798.597426176071,0.42051786,2472,0.1432778827209392,6452.462328910828,0.55028814,0.1848111126959918,0.6664708,5348,0.201840176873244 -787.4491305351257,0.2273602485656738,7238.834161758423,9177,0,7238.834161758423,0.39491493,2472,0.1330814697459021,8026.899383306503,0.45161825,0.1562252391663874,0.6241024,5348,0.1880726416096237 -920.1078622341156,0.2839655876159668,8678.961176633835,11007,0,8678.961176633835,0.3724855,2472,0.1259114821359657,9599.821209192276,0.43340543,0.1545875990998041,0.60133076,5348,0.1809088890390723 -1051.5222551822662,0.340630292892456,10119.012127161026,12854,0,10119.012127161026,0.35132352,2472,0.1199601893039221,11171.424641609192,0.40785065,0.1410327868852459,0.5672202,5348,0.1705784102648271 -1182.161565065384,0.389937162399292,11558.991518497469,14692,0,11558.991518497469,0.33908087,2472,0.1147807364978774,12742.170778512957,0.4037296,0.1419416377749473,0.55786395,5348,0.1679716539386157 -1314.5952589511871,0.441178560256958,12999.028420209885,16515,0,12999.028420209885,0.32374918,2472,0.1109621595271464,14314.772800445557,0.399308,0.1375804774441912,0.535952,5348,0.161435453816967 -1447.1639330387115,0.4920070171356201,14439.1499106884,18354,0,14439.1499106884,0.31622794,2472,0.1068389088619422,15887.59229850769,0.37713492,0.132674836900189,0.5242192,5348,0.1588480067968757 -1578.2720866203308,0.5422773361206055,15879.57132577896,20191,0,15879.57132577896,0.3005866,2472,0.102979708731948,17459.250715255737,0.37280694,0.1348939676380699,0.50043136,5348,0.1525917916139683 -1710.5015604496002,0.5935218334197998,17319.888331651688,22036,0,17319.888331651688,0.29445583,2472,0.1005016960168992,19031.927735090256,0.31283918,0.1112585406676173,0.49085632,5348,0.1479865221043282 -1842.077701330185,0.6459090709686279,18759.982160806656,23861,0,18759.982160806656,0.28720966,2472,0.095565982166433,20603.729377031326,0.32723373,0.1149733224814613,0.48551804,5348,0.1441922434517315 -1973.528375864029,0.6977200508117676,20200.02801823616,25690,0,20200.02801823616,0.27616546,2472,0.0951394389941705,22175.355345249176,0.31858635,0.1120703817632829,0.47096306,5348,0.1414889405949197 -2105.8720936775208,0.7506282329559326,21640.408204317093,27515,0,21640.408204317093,0.27005804,2472,0.0916052241382812,23748.209998607635,0.31003976,0.1098124938211935,0.46048242,5348,0.1372891665138013 -2237.2667417526245,0.8045666217803955,23080.796627759933,29360,0,23080.796627759933,0.26472273,2472,0.0887006682509698,25320.12511134148,0.27465776,0.0989900013511687,0.45722303,5348,0.1353389265956728 -2367.49915266037,0.8600101470947266,24520.75093364716,31195,0,24520.75093364716,0.25443456,2472,0.0857351776247638,26890.44734716416,0.24588323,0.0895053800268287,0.4484083,5348,0.1336203983509852 -2497.8841466903687,0.9160919189453124,25961.20705795288,33018,0,25961.20705795288,0.24990562,2472,0.0821197164503483,28461.4237473011,0.27327314,0.098914586620965,0.44101134,5348,0.1304633268003514 -2630.312391757965,0.9704453945159912,27401.429488182068,34855,0,27401.429488182068,0.24126007,2472,0.0799666890094042,30034.20651745796,0.24580301,0.0894540231484832,0.42582294,5348,0.1269876516987362 -2761.6267426013947,1.024658441543579,28841.924087047577,36692,0,28841.924087047577,0.24345164,2472,0.0814291227428757,31606.14753127098,0.26628497,0.0925580900906949,0.42301342,5348,0.1271903994129971 -2893.059848546982,1.086231708526611,30282.230488061905,38534,0,30282.230488061905,0.2337808,2472,0.0773668068165661,33178.02899599075,0.2095532,0.076352120071763,0.40643167,5348,0.120103884066926 -3024.0696020126343,1.1442553997039795,31722.74143695832,40365,0,31722.74143695832,0.22827744,2472,0.0759653078219893,34749.68506407738,0.19563828,0.0730934437328656,0.40345013,5348,0.1190322175772613 -3154.5400528907776,1.197132587432861,33163.28744673729,42198,0,33163.28744673729,0.21821883,2472,0.0725732740235208,36320.83214735985,0.21608451,0.0794167166053134,0.39799383,5348,0.1173136893325738 -3296.9027593135834,1.2538504600524902,34603.37959957123,44043,0,34603.37959957123,0.21474697,2472,0.0706030507992606,37903.424902677536,0.14231463,0.0532641934465834,0.38264957,5348,0.1126601465576334 -3430.021687746048,1.3127069473266602,36043.25439476967,45893,0,36043.25439476967,0.21339078,2472,0.0698921455121564,39476.55716729164,0.12276159,0.0459216859633146,0.3812691,5348,0.1105264682313641 -3564.374750375748,1.379469633102417,37483.68469452858,47735,0,37483.68469452858,0.2057504,2472,0.067800052810107,41051.488491773605,0.12371109,0.0474424386832767,0.36838114,5348,0.1077845467623121 -3697.955447912216,1.4431352615356443,38924.20606184006,49571,0,38924.20606184006,0.20391922,2472,0.0671094591026344,42625.732850551605,0.12066799,0.0457725052918445,0.36527857,5348,0.1063460034563658 -3829.9755721092224,1.496940851211548,40364.20062971115,51403,0,40364.20062971115,0.19339007,2472,0.0655048443117421,44197.879747867584,0.10996965,0.0426972796731894,0.3555751,5348,0.1044054181912973 -3964.7778856754303,1.5501763820648191,41804.38695335388,53238,0,41804.38695335388,0.18965232,2472,0.0626002884244307,45773.00025463104,0.10054823,0.0394028957698239,0.35113725,5348,0.1017600432528457 -4095.675683498383,1.6069636344909668,43244.36760210991,55087,0,43244.36760210991,0.18984833,2472,0.0627018463225885,47344.01453351975,0.114656754,0.0430929397423648,0.34877175,5348,0.0997615300694169 -4227.535162687302,1.6658551692962646,44684.72154283524,56927,0,44684.72154283524,0.18326183,2472,0.0600207178112241,48916.36607122421,0.094018914,0.0355911965754031,0.33511576,5348,0.0966334224779632 -4360.501399755478,1.725421667098999,46124.952325344086,58747,0,46124.952325344086,0.18139437,2472,0.0579692482684378,50489.70118045807,0.09163833,0.0359684636692946,0.33312958,5348,0.0941328673354123 -4493.0639543533325,1.7812588214874268,47565.06146478653,60586,0,47565.06146478653,0.17375073,2472,0.0578473787906485,52062.507381916046,0.075750425,0.0300035989499534,0.32939577,5348,0.0937466812130106 -4626.207343816757,1.843952894210816,49005.45624756813,62424,0,49005.45624756813,0.17216443,2472,0.0564661913757032,53636.18750405312,0.070418306,0.0283815178793507,0.3231744,5348,0.0918640238663023 -4757.616107225418,1.9643621444702148,50445.885021448135,64266,0,50445.885021448135,0.1681303,2472,0.0537850628643389,55208.22303843498,0.07730179,0.0303223574687047,0.31241328,5348,0.0885524778667078 -4889.189654827118,2.029456615447998,51886.28319978714,66097,0,51886.28319978714,0.16740415,2472,0.0537038165458127,56780.337996959686,0.07038693,0.027615892312107,0.3174673,5348,0.0877125230504841 -5022.44410276413,2.0852842330932617,53326.39934182167,67925,0,53326.39934182167,0.16228499,2472,0.0524444986086568,58353.84253954888,0.06680417,0.0261072042993986,0.30674285,5348,0.0854726435405543 -5155.602071285248,2.1447830200195312,54766.46932411194,69755,0,54766.46932411194,0.1607685,2472,0.0518351512197103,59927.20699119568,0.05866672,0.0228105070404748,0.3071572,5348,0.0856753912548152 -5289.01722574234,2.2092363834381104,56206.59651684761,71598,0,56206.59651684761,0.15960225,2472,0.0507992606585014,61500.89174199104,0.053217243,0.0202194390030425,0.30390048,5348,0.084024445581548 -5421.555196762085,2.2709219455718994,57646.99338531494,73439,0,57646.99338531494,0.15717694,2472,0.050230536428818,63073.967522382736,0.05957009,0.0217736158018626,0.30104917,5348,0.0834065477857053 -5552.972937345505,2.3312904834747314,59087.44711709023,75268,0,59087.44711709023,0.15624015,2472,0.0497024353583978,64645.97832798958,0.056394346,0.0210327421923502,0.29899248,5348,0.0824700464388812 -5685.517039775848,2.3913497924804688,60527.67941856384,77109,0,60527.67941856384,0.15573071,2472,0.049499319562082346,66218.89422011375,0.05739345,0.021495527344796185,0.2971615,5348,0.0818038753777383 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/measurements.csv deleted file mode 100644 index 5fca84222..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/measurements.csv +++ /dev/null @@ -1,824 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,40.054005,32.430748,,,,,,,,,,,,,, -1,,,32.05279,1.151012912416401,30.757864,1.1779352558965794,5348.0,30.871214,1.190014827453131,2472.0,36.66109991073608,171.07474374771118,36.66109991073608,134.41353964805603,0.0,0.0 -100,0.86787117,5.9211755,,,,,,,,,,,,,, -200,1.3994772,5.844816,,,,,,,,,,,,,, -300,1.1309179,5.80712,,,,,,,,,,,,,, -400,1.342169,5.7899237,,,,,,,,,,,,,, -500,0.38046578,5.8046374,,,,,,,,,,,,,, -600,1.1254737,5.74806,,,,,,,,,,,,,, -700,0.6958833,5.5582376,,,,,,,,,,,,,, -800,1.8781445,5.4639816,,,,,,,,,,,,,, -900,2.2855213,4.691918,,,,,,,,,,,,,, -1000,1.0702143,3.7810383,,,,,,,,,,,,,, -1100,1.4159647,3.4544344,,,,,,,,,,,,,, -1200,1.2351469,3.1977916,,,,,,,,,,,,,, -1300,0.6949178,3.0547774,,,,,,,,,,,,,, -1400,1.0988444,2.795135,,,,,,,,,,,,,, -1500,0.70693326,2.6875315,,,,,,,,,,,,,, -1600,0.6352713,2.6784797,,,,,,,,,,,,,, -1700,0.83065474,2.5250218,,,,,,,,,,,,,, -1800,0.9893713,2.4568946,,,,,,,,,,,,,, -1820,,,3.4769754,0.6580758933017746,3.3963308,0.6373519217586916,5348.0,3.0782773,0.5873296366258404,2472.0,1477.4896006584167,1735.0940651893616,1477.4896006584167,257.50309681892395,0.0270421504974365,0.0 -1900,0.6396137,2.369795,,,,,,,,,,,,,, -2000,0.6139592,2.269916,,,,,,,,,,,,,, -2100,1.0443202,2.2600148,,,,,,,,,,,,,, -2200,1.048399,2.2689147,,,,,,,,,,,,,, -2300,0.7374526,2.195173,,,,,,,,,,,,,, -2400,0.8675659,2.068522,,,,,,,,,,,,,, -2500,0.65946215,2.0900593,,,,,,,,,,,,,, -2600,0.57099795,2.079349,,,,,,,,,,,,,, -2700,0.5505874,2.0413623,,,,,,,,,,,,,, -2800,0.55602306,1.9749429,,,,,,,,,,,,,, -2900,0.5325472,1.939666,,,,,,,,,,,,,, -3000,0.51601315,1.972864,,,,,,,,,,,,,, -3100,0.6586059,1.9553022,,,,,,,,,,,,,, -3200,0.5328167,1.9048091,,,,,,,,,,,,,, -3300,0.5376849,1.848808,,,,,,,,,,,,,, -3400,0.90769047,1.8863325,,,,,,,,,,,,,, -3500,0.8665699,1.8516741,,,,,,,,,,,,,, -3600,0.7695097,1.8010722,,,,,,,,,,,,,, -3665,,,0.88411194,0.2777697504927147,0.9202193,0.2683221178446952,5348.0,0.64782745,0.2096155017975748,2472.0,2917.4977061748505,3306.841879844665,2917.4977061748505,389.10982155799866,0.081702709197998,0.0 -3700,0.56038135,1.8857012,,,,,,,,,,,,,, -3800,0.49767435,1.8338882,,,,,,,,,,,,,, -3900,0.5044699,1.8178127,,,,,,,,,,,,,, -4000,0.57792836,1.7818881,,,,,,,,,,,,,, -4100,0.64517057,1.840824,,,,,,,,,,,,,, -4200,0.6792081,1.8115078,,,,,,,,,,,,,, -4300,0.616004,1.7647158,,,,,,,,,,,,,, -4400,0.5678123,1.7052523,,,,,,,,,,,,,, -4500,0.50144833,1.7257274,,,,,,,,,,,,,, -4600,0.49466732,1.7240711,,,,,,,,,,,,,, -4700,0.6593073,1.7392294,,,,,,,,,,,,,, -4800,0.4498648,1.7571301,,,,,,,,,,,,,, -4900,0.46477848,1.7649542,,,,,,,,,,,,,, -5000,0.509262,1.7440035,,,,,,,,,,,,,, -5100,0.58629555,1.7226006,,,,,,,,,,,,,, -5200,0.45675197,1.6892185,,,,,,,,,,,,,, -5300,0.41298994,1.6268873,,,,,,,,,,,,,, -5400,0.66220516,1.71508,,,,,,,,,,,,,, -5500,0.52090406,1.678426,,,,,,,,,,,,,, -5510,,,0.5478368,0.1874115335970506,0.74291396,0.222394933238074,5348.0,0.4876269,0.1638535128876972,2472.0,4358.005998134613,4880.403652906418,4358.005998134613,522.0387523174286,0.1281547546386718,0.0 -5600,0.5193321,1.6462605,,,,,,,,,,,,,, -5700,0.81458825,1.7225289,,,,,,,,,,,,,, -5800,0.52552086,1.684839,,,,,,,,,,,,,, -5900,0.69461155,1.6771905,,,,,,,,,,,,,, -6000,0.5947555,1.6350092,,,,,,,,,,,,,, -6100,0.5491196,1.5891452,,,,,,,,,,,,,, -6200,0.57590204,1.677196,,,,,,,,,,,,,, -6300,0.5292258,1.663692,,,,,,,,,,,,,, -6400,0.5267032,1.6371428,,,,,,,,,,,,,, -6500,0.44888845,1.6131349,,,,,,,,,,,,,, -6600,0.50006336,1.6090777,,,,,,,,,,,,,, -6700,0.48041308,1.6003548,,,,,,,,,,,,,, -6800,0.6016054,1.6523373,,,,,,,,,,,,,, -6900,0.49184734,1.5918939,,,,,,,,,,,,,, -7000,0.5358269,1.578899,,,,,,,,,,,,,, -7100,0.43022764,1.5920035,,,,,,,,,,,,,, -7200,0.54420894,1.615948,,,,,,,,,,,,,, -7300,0.50362766,1.5633881,,,,,,,,,,,,,, -7342,,,0.55028814,0.1848111126959918,0.6664708,0.201840176873244,5348.0,0.42051786,0.1432778827209392,2472.0,5798.597426176071,6452.462328910828,5798.597426176071,653.3796038627625,0.1764125823974609,0.0 -7400,0.50640255,1.5631448,,,,,,,,,,,,,, -7500,0.4770142,1.599591,,,,,,,,,,,,,, -7600,0.7299109,1.5902377,,,,,,,,,,,,,, -7700,0.6322739,1.5337337,,,,,,,,,,,,,, -7800,0.59434795,1.5594249,,,,,,,,,,,,,, -7900,0.4939353,1.5999649,,,,,,,,,,,,,, -8000,0.4688147,1.6053185,,,,,,,,,,,,,, -8100,0.72523695,1.5731971,,,,,,,,,,,,,, -8200,0.52210724,1.5541428,,,,,,,,,,,,,, -8300,0.70812684,1.5180931,,,,,,,,,,,,,, -8400,0.46469447,1.5793499,,,,,,,,,,,,,, -8500,0.5814556,1.5882083,,,,,,,,,,,,,, -8600,0.4651068,1.5269468,,,,,,,,,,,,,, -8700,0.5943937,1.613956,,,,,,,,,,,,,, -8800,0.5679697,1.5567833,,,,,,,,,,,,,, -8900,0.38799584,1.5514938,,,,,,,,,,,,,, -9000,0.55704314,1.5035038,,,,,,,,,,,,,, -9100,0.4682057,1.5890431,,,,,,,,,,,,,, -9177,,,0.45161825,0.1562252391663874,0.6241024,0.1880726416096237,5348.0,0.39491493,0.1330814697459021,2472.0,7238.834161758423,8026.899383306503,7238.834161758423,787.4491305351257,0.2273602485656738,0.0 -9200,0.47940075,1.5613925,,,,,,,,,,,,,, -9300,0.5081962,1.5406215,,,,,,,,,,,,,, -9400,0.497451,1.4998605,,,,,,,,,,,,,, -9500,0.6455264,1.5598414,,,,,,,,,,,,,, -9600,0.48044103,1.5267,,,,,,,,,,,,,, -9700,0.47191304,1.5682468,,,,,,,,,,,,,, -9800,0.52391326,1.5057099,,,,,,,,,,,,,, -9900,0.5313485,1.5635402,,,,,,,,,,,,,, -10000,0.63232696,1.5692036,,,,,,,,,,,,,, -10100,0.5131987,1.4932767,,,,,,,,,,,,,, -10200,0.5522921,1.5522761,,,,,,,,,,,,,, -10300,0.46875936,1.4968345,,,,,,,,,,,,,, -10400,0.57048017,1.4689153,,,,,,,,,,,,,, -10500,0.5413277,1.4926593,,,,,,,,,,,,,, -10600,0.44170117,1.4829062,,,,,,,,,,,,,, -10700,0.57828313,1.5243785,,,,,,,,,,,,,, -10800,0.48022786,1.4652857,,,,,,,,,,,,,, -10900,0.4635662,1.488042,,,,,,,,,,,,,, -11000,0.55880773,1.4286888,,,,,,,,,,,,,, -11007,,,0.43340543,0.1545875990998041,0.60133076,0.1809088890390723,5348.0,0.3724855,0.1259114821359657,2472.0,8678.961176633835,9599.821209192276,8678.961176633835,920.1078622341156,0.2839655876159668,0.0 -11100,0.50786525,1.4329774,,,,,,,,,,,,,, -11200,0.6183548,1.5157428,,,,,,,,,,,,,, -11300,0.4625067,1.5160578,,,,,,,,,,,,,, -11400,0.5828083,1.4458039,,,,,,,,,,,,,, -11500,0.5616337,1.5304478,,,,,,,,,,,,,, -11600,0.52548903,1.4298023,,,,,,,,,,,,,, -11700,0.49675047,1.433459,,,,,,,,,,,,,, -11800,0.46553427,1.4405996,,,,,,,,,,,,,, -11900,0.5052406,1.4804751,,,,,,,,,,,,,, -12000,0.64699644,1.4159483,,,,,,,,,,,,,, -12100,0.552574,1.4694406,,,,,,,,,,,,,, -12200,0.43592975,1.4431901,,,,,,,,,,,,,, -12300,0.6061976,1.4459257,,,,,,,,,,,,,, -12400,0.81126195,1.479743,,,,,,,,,,,,,, -12500,0.4526121,1.4485645,,,,,,,,,,,,,, -12600,0.5185242,1.5006638,,,,,,,,,,,,,, -12700,0.39087608,1.4388866,,,,,,,,,,,,,, -12800,0.37506777,1.4186625,,,,,,,,,,,,,, -12854,,,0.40785065,0.1410327868852459,0.5672202,0.1705784102648271,5348.0,0.35132352,0.1199601893039221,2472.0,10119.012127161026,11171.424641609192,10119.012127161026,1051.5222551822662,0.340630292892456,0.0 -12900,0.5205893,1.4511648,,,,,,,,,,,,,, -13000,0.5002004,1.4651916,,,,,,,,,,,,,, -13100,0.592962,1.4458218,,,,,,,,,,,,,, -13200,0.6401119,1.4098235,,,,,,,,,,,,,, -13300,0.5304758,1.4431162,,,,,,,,,,,,,, -13400,0.5934094,1.4813899,,,,,,,,,,,,,, -13500,0.5508969,1.33997,,,,,,,,,,,,,, -13600,0.5282834,1.4229758,,,,,,,,,,,,,, -13700,0.6805935,1.4587806,,,,,,,,,,,,,, -13800,0.53891164,1.4686478,,,,,,,,,,,,,, -13900,0.6127596,1.4393224,,,,,,,,,,,,,, -14000,0.46150213,1.4840297,,,,,,,,,,,,,, -14100,0.5930339,1.4170991,,,,,,,,,,,,,, -14200,0.6002119,1.4656934,,,,,,,,,,,,,, -14300,0.6370007,1.4580336,,,,,,,,,,,,,, -14400,0.5016254,1.4224591,,,,,,,,,,,,,, -14500,0.6129039,1.4523492,,,,,,,,,,,,,, -14600,0.464738,1.3406495,,,,,,,,,,,,,, -14692,,,0.4037296,0.1419416377749473,0.55786395,0.1679716539386157,5348.0,0.33908087,0.1147807364978774,2472.0,11558.991518497469,12742.170778512957,11558.991518497469,1182.161565065384,0.389937162399292,0.0 -14700,0.45596927,1.4087974,,,,,,,,,,,,,, -14800,0.49403688,1.3767811,,,,,,,,,,,,,, -14900,0.47884464,1.4428579,,,,,,,,,,,,,, -15000,0.51872927,1.436728,,,,,,,,,,,,,, -15100,0.48653394,1.3962916,,,,,,,,,,,,,, -15200,0.46298814,1.4379363,,,,,,,,,,,,,, -15300,0.55310094,1.4879236,,,,,,,,,,,,,, -15400,0.5220247,1.4929903,,,,,,,,,,,,,, -15500,0.45236418,1.3733171,,,,,,,,,,,,,, -15600,0.48289147,1.413524,,,,,,,,,,,,,, -15700,0.5400831,1.3960592,,,,,,,,,,,,,, -15800,0.4924354,1.4667056,,,,,,,,,,,,,, -15900,0.47346103,1.4054604,,,,,,,,,,,,,, -16000,0.51429784,1.4455477,,,,,,,,,,,,,, -16100,0.42142287,1.4066633,,,,,,,,,,,,,, -16200,0.62670565,1.4141363,,,,,,,,,,,,,, -16300,0.4814045,1.3940241,,,,,,,,,,,,,, -16400,0.5154961,1.3955592,,,,,,,,,,,,,, -16500,0.4909776,1.4309702,,,,,,,,,,,,,, -16515,,,0.399308,0.1375804774441912,0.535952,0.161435453816967,5348.0,0.32374918,0.1109621595271464,2472.0,12999.028420209885,14314.772800445557,12999.028420209885,1314.5952589511871,0.441178560256958,0.0 -16600,0.5141962,1.3395438,,,,,,,,,,,,,, -16700,0.49748716,1.4243832,,,,,,,,,,,,,, -16800,0.48889372,1.3978752,,,,,,,,,,,,,, -16900,0.5118833,1.3935957,,,,,,,,,,,,,, -17000,0.56698215,1.4172965,,,,,,,,,,,,,, -17100,0.5074909,1.3795055,,,,,,,,,,,,,, -17200,0.5085744,1.4178392,,,,,,,,,,,,,, -17300,0.5528053,1.409152,,,,,,,,,,,,,, -17400,0.47707677,1.3910358,,,,,,,,,,,,,, -17500,0.5136587,1.3801205,,,,,,,,,,,,,, -17600,0.45491698,1.3757219,,,,,,,,,,,,,, -17700,0.5614682,1.4584862,,,,,,,,,,,,,, -17800,0.54313874,1.3367559,,,,,,,,,,,,,, -17900,0.5186098,1.4485292,,,,,,,,,,,,,, -18000,0.5503097,1.4033831,,,,,,,,,,,,,, -18100,0.45523605,1.3325224,,,,,,,,,,,,,, -18200,0.4985848,1.393488,,,,,,,,,,,,,, -18300,0.5109418,1.332491,,,,,,,,,,,,,, -18354,,,0.37713492,0.132674836900189,0.5242192,0.1588480067968757,5348.0,0.31622794,0.1068389088619422,2472.0,14439.1499106884,15887.59229850769,14439.1499106884,1447.1639330387115,0.4920070171356201,0.0 -18400,0.5470941,1.3923457,,,,,,,,,,,,,, -18500,0.48844692,1.4054302,,,,,,,,,,,,,, -18600,0.50771624,1.357607,,,,,,,,,,,,,, -18700,0.57098806,1.364195,,,,,,,,,,,,,, -18800,0.46863344,1.3569608,,,,,,,,,,,,,, -18900,0.49897397,1.3113867,,,,,,,,,,,,,, -19000,0.6019496,1.4038556,,,,,,,,,,,,,, -19100,0.6574063,1.4486456,,,,,,,,,,,,,, -19200,0.5555425,1.4201133,,,,,,,,,,,,,, -19300,0.59265363,1.3665874,,,,,,,,,,,,,, -19400,0.5539409,1.3876176,,,,,,,,,,,,,, -19500,0.46736863,1.3577198,,,,,,,,,,,,,, -19600,0.50627285,1.3956897,,,,,,,,,,,,,, -19700,0.50232416,1.3680199,,,,,,,,,,,,,, -19800,0.6635635,1.2956995,,,,,,,,,,,,,, -19900,0.57590634,1.3974261,,,,,,,,,,,,,, -20000,0.54292315,1.323567,,,,,,,,,,,,,, -20100,0.6263136,1.3666247,,,,,,,,,,,,,, -20191,,,0.37280694,0.1348939676380699,0.50043136,0.1525917916139683,5348.0,0.3005866,0.102979708731948,2472.0,15879.57132577896,17459.250715255737,15879.57132577896,1578.2720866203308,0.5422773361206055,0.0 -20200,0.43695116,1.3333269,,,,,,,,,,,,,, -20300,0.5170655,1.3819813,,,,,,,,,,,,,, -20400,0.61222744,1.4127378,,,,,,,,,,,,,, -20500,0.47654653,1.3543558,,,,,,,,,,,,,, -20600,0.5132263,1.3269566,,,,,,,,,,,,,, -20700,0.65606165,1.3079133,,,,,,,,,,,,,, -20800,0.4905546,1.3527745,,,,,,,,,,,,,, -20900,0.6409442,1.4032654,,,,,,,,,,,,,, -21000,0.5977623,1.393566,,,,,,,,,,,,,, -21100,0.43614736,1.3284812,,,,,,,,,,,,,, -21200,0.5265273,1.3825226,,,,,,,,,,,,,, -21300,0.4919974,1.3367425,,,,,,,,,,,,,, -21400,0.5647937,1.330099,,,,,,,,,,,,,, -21500,0.5960511,1.3507041,,,,,,,,,,,,,, -21600,0.44269806,1.3401923,,,,,,,,,,,,,, -21700,0.5157749,1.2956574,,,,,,,,,,,,,, -21800,0.45011818,1.3221363,,,,,,,,,,,,,, -21900,0.49358788,1.3280839,,,,,,,,,,,,,, -22000,0.45736587,1.3218195,,,,,,,,,,,,,, -22036,,,0.31283918,0.1112585406676173,0.49085632,0.1479865221043282,5348.0,0.29445583,0.1005016960168992,2472.0,17319.888331651688,19031.927735090256,17319.888331651688,1710.5015604496002,0.5935218334197998,0.0 -22100,0.4376982,1.3195142,,,,,,,,,,,,,, -22200,0.49739763,1.3366896,,,,,,,,,,,,,, -22300,0.48396423,1.317873,,,,,,,,,,,,,, -22400,0.5964747,1.3603889,,,,,,,,,,,,,, -22500,0.5167078,1.3976556,,,,,,,,,,,,,, -22600,0.4180003,1.3360735,,,,,,,,,,,,,, -22700,0.6918073,1.3457668,,,,,,,,,,,,,, -22800,0.6198089,1.3106078,,,,,,,,,,,,,, -22900,0.5398295,1.3597628,,,,,,,,,,,,,, -23000,0.46060872,1.3271711,,,,,,,,,,,,,, -23100,0.4981128,1.3190497,,,,,,,,,,,,,, -23200,0.5494125,1.3813859,,,,,,,,,,,,,, -23300,0.59744596,1.281305,,,,,,,,,,,,,, -23400,0.4854157,1.2989738,,,,,,,,,,,,,, -23500,0.47787952,1.3312333,,,,,,,,,,,,,, -23600,0.54448646,1.3819999,,,,,,,,,,,,,, -23700,0.53754616,1.3239503,,,,,,,,,,,,,, -23800,0.44865492,1.2887118,,,,,,,,,,,,,, -23861,,,0.32723373,0.1149733224814613,0.48551804,0.1441922434517315,5348.0,0.28720966,0.095565982166433,2472.0,18759.982160806656,20603.729377031326,18759.982160806656,1842.077701330185,0.6459090709686279,0.0 -23900,0.56934,1.295155,,,,,,,,,,,,,, -24000,0.59637284,1.3898015,,,,,,,,,,,,,, -24100,0.5718747,1.2946686,,,,,,,,,,,,,, -24200,0.5990261,1.304861,,,,,,,,,,,,,, -24300,0.4970443,1.2969381,,,,,,,,,,,,,, -24400,0.57867366,1.3138201,,,,,,,,,,,,,, -24500,0.58181804,1.2999928,,,,,,,,,,,,,, -24600,0.44142845,1.3256704,,,,,,,,,,,,,, -24700,0.4840862,1.3312062,,,,,,,,,,,,,, -24800,0.49848646,1.338264,,,,,,,,,,,,,, -24900,0.7310281,1.3258916,,,,,,,,,,,,,, -25000,0.55898184,1.3238953,,,,,,,,,,,,,, -25100,0.4861886,1.3494639,,,,,,,,,,,,,, -25200,0.66411865,1.2648724,,,,,,,,,,,,,, -25300,0.45731515,1.2965784,,,,,,,,,,,,,, -25400,0.54806703,1.3240249,,,,,,,,,,,,,, -25500,0.4993544,1.3433473,,,,,,,,,,,,,, -25600,0.43680194,1.2946478,,,,,,,,,,,,,, -25690,,,0.31858635,0.1120703817632829,0.47096306,0.1414889405949197,5348.0,0.27616546,0.0951394389941705,2472.0,20200.02801823616,22175.355345249176,20200.02801823616,1973.528375864029,0.6977200508117676,0.0 -25700,0.5346118,1.2679049,,,,,,,,,,,,,, -25800,0.5538668,1.2787989,,,,,,,,,,,,,, -25900,0.47413492,1.275293,,,,,,,,,,,,,, -26000,0.5040651,1.2779356,,,,,,,,,,,,,, -26100,0.5266499,1.2957933,,,,,,,,,,,,,, -26200,0.5705999,1.3091611,,,,,,,,,,,,,, -26300,0.6831459,1.3394057,,,,,,,,,,,,,, -26400,0.568162,1.2966245,,,,,,,,,,,,,, -26500,0.4865162,1.3373718,,,,,,,,,,,,,, -26600,0.64759356,1.3023237,,,,,,,,,,,,,, -26700,0.7633718,1.2886502,,,,,,,,,,,,,, -26800,0.5576255,1.2167689,,,,,,,,,,,,,, -26900,0.48898023,1.2951751,,,,,,,,,,,,,, -27000,0.57635826,1.2870008,,,,,,,,,,,,,, -27100,0.60949785,1.2623358,,,,,,,,,,,,,, -27200,0.5935604,1.3029265,,,,,,,,,,,,,, -27300,0.53110975,1.2951294,,,,,,,,,,,,,, -27400,0.52033675,1.2477795,,,,,,,,,,,,,, -27500,0.5210854,1.3329793,,,,,,,,,,,,,, -27515,,,0.31003976,0.1098124938211935,0.46048242,0.1372891665138013,5348.0,0.27005804,0.0916052241382812,2472.0,21640.408204317093,23748.209998607635,21640.408204317093,2105.8720936775208,0.7506282329559326,0.0 -27600,0.5572831,1.2652304,,,,,,,,,,,,,, -27700,0.5508173,1.329027,,,,,,,,,,,,,, -27800,0.5209931,1.3141313,,,,,,,,,,,,,, -27900,0.6256838,1.2246372,,,,,,,,,,,,,, -28000,0.57263947,1.2308873,,,,,,,,,,,,,, -28100,0.47205105,1.2994466,,,,,,,,,,,,,, -28200,0.5999875,1.3234818,,,,,,,,,,,,,, -28300,0.5216021,1.2588706,,,,,,,,,,,,,, -28400,0.54790866,1.2815583,,,,,,,,,,,,,, -28500,0.5974798,1.3182269,,,,,,,,,,,,,, -28600,0.583677,1.2883517,,,,,,,,,,,,,, -28700,0.5893778,1.2574166,,,,,,,,,,,,,, -28800,0.52053887,1.2730198,,,,,,,,,,,,,, -28900,0.52129436,1.2930288,,,,,,,,,,,,,, -29000,0.4830082,1.2499017,,,,,,,,,,,,,, -29100,0.52318543,1.2787126,,,,,,,,,,,,,, -29200,0.52655524,1.2926522,,,,,,,,,,,,,, -29300,0.5586227,1.3121971,,,,,,,,,,,,,, -29360,,,0.27465776,0.0989900013511687,0.45722303,0.1353389265956728,5348.0,0.26472273,0.0887006682509698,2472.0,23080.796627759933,25320.12511134148,23080.796627759933,2237.2667417526245,0.8045666217803955,0.0 -29400,0.52284163,1.2546297,,,,,,,,,,,,,, -29500,0.5865267,1.2961723,,,,,,,,,,,,,, -29600,0.61615133,1.2242826,,,,,,,,,,,,,, -29700,0.5395969,1.2778136,,,,,,,,,,,,,, -29800,0.47757858,1.2707677,,,,,,,,,,,,,, -29900,0.59041953,1.3138051,,,,,,,,,,,,,, -30000,0.5901737,1.277805,,,,,,,,,,,,,, -30100,0.7576494,1.3077419,,,,,,,,,,,,,, -30200,0.5748315,1.2925078,,,,,,,,,,,,,, -30300,0.5389415,1.2708797,,,,,,,,,,,,,, -30400,0.50287473,1.2216108,,,,,,,,,,,,,, -30500,0.51213044,1.2985306,,,,,,,,,,,,,, -30600,0.6835635,1.2695339,,,,,,,,,,,,,, -30700,0.6482434,1.2622933,,,,,,,,,,,,,, -30800,0.5562435,1.2261873,,,,,,,,,,,,,, -30900,0.58490026,1.2744267,,,,,,,,,,,,,, -31000,0.5932747,1.2579827,,,,,,,,,,,,,, -31100,0.5239964,1.3108661,,,,,,,,,,,,,, -31195,,,0.24588323,0.0895053800268287,0.4484083,0.1336203983509852,5348.0,0.25443456,0.0857351776247638,2472.0,24520.75093364716,26890.44734716416,24520.75093364716,2367.49915266037,0.8600101470947266,0.0 -31200,0.46596035,1.2270746,,,,,,,,,,,,,, -31300,0.52637184,1.2029285,,,,,,,,,,,,,, -31400,0.5418324,1.2527117,,,,,,,,,,,,,, -31500,0.71751463,1.3102248,,,,,,,,,,,,,, -31600,0.5148313,1.2626538,,,,,,,,,,,,,, -31700,0.58788645,1.2454191,,,,,,,,,,,,,, -31800,0.52692986,1.2299953,,,,,,,,,,,,,, -31900,0.45024347,1.2605015,,,,,,,,,,,,,, -32000,0.6790445,1.2608682,,,,,,,,,,,,,, -32100,0.66196275,1.2459134,,,,,,,,,,,,,, -32200,0.54738593,1.2809035,,,,,,,,,,,,,, -32300,0.5884903,1.231695,,,,,,,,,,,,,, -32400,0.50911033,1.2605852,,,,,,,,,,,,,, -32500,0.53698504,1.2398028,,,,,,,,,,,,,, -32600,0.6624473,1.2787561,,,,,,,,,,,,,, -32700,0.47982603,1.2210017,,,,,,,,,,,,,, -32800,0.48658657,1.2318106,,,,,,,,,,,,,, -32900,0.6367525,1.2416188,,,,,,,,,,,,,, -33000,0.5250923,1.2406162,,,,,,,,,,,,,, -33018,,,0.27327314,0.098914586620965,0.44101134,0.1304633268003514,5348.0,0.24990562,0.0821197164503483,2472.0,25961.20705795288,28461.4237473011,25961.20705795288,2497.8841466903687,0.9160919189453124,0.0 -33100,0.5327647,1.2159663,,,,,,,,,,,,,, -33200,0.48673397,1.2036867,,,,,,,,,,,,,, -33300,0.63889426,1.281402,,,,,,,,,,,,,, -33400,0.6629967,1.2497092,,,,,,,,,,,,,, -33500,0.6024158,1.2457794,,,,,,,,,,,,,, -33600,0.5958953,1.2435899,,,,,,,,,,,,,, -33700,0.6180345,1.2767926,,,,,,,,,,,,,, -33800,0.5553262,1.2718146,,,,,,,,,,,,,, -33900,0.53711426,1.2514637,,,,,,,,,,,,,, -34000,0.5222271,1.2032354,,,,,,,,,,,,,, -34100,0.5593179,1.2393466,,,,,,,,,,,,,, -34200,0.61804867,1.2449934,,,,,,,,,,,,,, -34300,0.47412825,1.2517219,,,,,,,,,,,,,, -34400,0.4892844,1.2463107,,,,,,,,,,,,,, -34500,0.67809325,1.1953653,,,,,,,,,,,,,, -34600,0.6452926,1.2523798,,,,,,,,,,,,,, -34700,0.68779314,1.2534782,,,,,,,,,,,,,, -34800,0.55235934,1.241978,,,,,,,,,,,,,, -34855,,,0.24580301,0.0894540231484832,0.42582294,0.1269876516987362,5348.0,0.24126007,0.0799666890094042,2472.0,27401.429488182068,30034.20651745796,27401.429488182068,2630.312391757965,0.9704453945159912,0.0 -34900,0.62860715,1.211157,,,,,,,,,,,,,, -35000,0.52953404,1.2436703,,,,,,,,,,,,,, -35100,0.5255296,1.22733,,,,,,,,,,,,,, -35200,0.52613807,1.2294528,,,,,,,,,,,,,, -35300,0.61285734,1.2868336,,,,,,,,,,,,,, -35400,0.5248562,1.2665762,,,,,,,,,,,,,, -35500,0.54970294,1.210672,,,,,,,,,,,,,, -35600,0.51328486,1.2459875,,,,,,,,,,,,,, -35700,0.58242786,1.1841059,,,,,,,,,,,,,, -35800,0.5551075,1.1928853,,,,,,,,,,,,,, -35900,0.63958895,1.1905737,,,,,,,,,,,,,, -36000,0.5638377,1.2190871,,,,,,,,,,,,,, -36100,0.5048702,1.2041278,,,,,,,,,,,,,, -36200,0.68849975,1.2177155,,,,,,,,,,,,,, -36300,0.612023,1.2258567,,,,,,,,,,,,,, -36400,0.51424927,1.2104033,,,,,,,,,,,,,, -36500,0.49630094,1.1549817,,,,,,,,,,,,,, -36600,0.56202745,1.2206148,,,,,,,,,,,,,, -36692,,,0.26628497,0.0925580900906949,0.42301342,0.1271903994129971,5348.0,0.24345164,0.0814291227428757,2472.0,28841.924087047577,31606.14753127098,28841.924087047577,2761.6267426013947,1.024658441543579,0.0 -36700,0.5813151,1.1864334,,,,,,,,,,,,,, -36800,0.6063907,1.1799186,,,,,,,,,,,,,, -36900,0.6360657,1.1980408,,,,,,,,,,,,,, -37000,0.5450829,1.200948,,,,,,,,,,,,,, -37100,0.5262162,1.1417611,,,,,,,,,,,,,, -37200,0.51107764,1.1811574,,,,,,,,,,,,,, -37300,0.6124897,1.2376361,,,,,,,,,,,,,, -37400,0.53544396,1.2047279,,,,,,,,,,,,,, -37500,0.6126241,1.2026529,,,,,,,,,,,,,, -37600,0.5355609,1.106101,,,,,,,,,,,,,, -37700,0.56564075,1.2295994,,,,,,,,,,,,,, -37800,0.54674923,1.2431208,,,,,,,,,,,,,, -37900,0.783973,1.1897511,,,,,,,,,,,,,, -38000,0.5244495,1.1731904,,,,,,,,,,,,,, -38100,0.5395722,1.2359217,,,,,,,,,,,,,, -38200,0.51237714,1.1673931,,,,,,,,,,,,,, -38300,0.6462599,1.2114524,,,,,,,,,,,,,, -38400,0.6055872,1.1298469,,,,,,,,,,,,,, -38500,0.5472921,1.168892,,,,,,,,,,,,,, -38534,,,0.2095532,0.076352120071763,0.40643167,0.120103884066926,5348.0,0.2337808,0.0773668068165661,2472.0,30282.230488061905,33178.02899599075,30282.230488061905,2893.059848546982,1.086231708526611,0.0 -38600,0.55056137,1.1256514,,,,,,,,,,,,,, -38700,0.6545294,1.1635908,,,,,,,,,,,,,, -38800,0.6342907,1.220904,,,,,,,,,,,,,, -38900,0.58292454,1.1665682,,,,,,,,,,,,,, -39000,0.5622495,1.1990926,,,,,,,,,,,,,, -39100,0.5610412,1.1932509,,,,,,,,,,,,,, -39200,0.50813264,1.1754589,,,,,,,,,,,,,, -39300,0.6335941,1.2152102,,,,,,,,,,,,,, -39400,0.48829573,1.1592531,,,,,,,,,,,,,, -39500,0.5226082,1.115842,,,,,,,,,,,,,, -39600,0.50060886,1.1736983,,,,,,,,,,,,,, -39700,0.5485666,1.1339467,,,,,,,,,,,,,, -39800,0.6138593,1.195108,,,,,,,,,,,,,, -39900,0.6236417,1.1430428,,,,,,,,,,,,,, -40000,0.57835305,1.1920532,,,,,,,,,,,,,, -40100,0.49548596,1.189247,,,,,,,,,,,,,, -40200,0.5970434,1.1648815,,,,,,,,,,,,,, -40300,0.72140425,1.1713604,,,,,,,,,,,,,, -40365,,,0.19563828,0.0730934437328656,0.40345013,0.1190322175772613,5348.0,0.22827744,0.0759653078219893,2472.0,31722.74143695832,34749.68506407738,31722.74143695832,3024.0696020126343,1.1442553997039795,0.0 -40400,0.5153281,1.131845,,,,,,,,,,,,,, -40500,0.66603965,1.1216222,,,,,,,,,,,,,, -40600,0.53647625,1.1391647,,,,,,,,,,,,,, -40700,0.55604076,1.1311113,,,,,,,,,,,,,, -40800,0.52739906,1.1722442,,,,,,,,,,,,,, -40900,0.7565968,1.1760906,,,,,,,,,,,,,, -41000,0.48761037,1.1621003,,,,,,,,,,,,,, -41100,0.5322868,1.1403092,,,,,,,,,,,,,, -41200,0.5335765,1.2128733,,,,,,,,,,,,,, -41300,0.7063631,1.128421,,,,,,,,,,,,,, -41400,0.55585223,1.1256348,,,,,,,,,,,,,, -41500,0.55741584,1.1436887,,,,,,,,,,,,,, -41600,0.5488334,1.2277107,,,,,,,,,,,,,, -41700,0.6184176,1.1797117,,,,,,,,,,,,,, -41800,0.47159123,1.1398429,,,,,,,,,,,,,, -41900,0.63896084,1.1213607,,,,,,,,,,,,,, -42000,0.53579545,1.1322342,,,,,,,,,,,,,, -42100,0.63284093,1.1499747,,,,,,,,,,,,,, -42198,,,0.21608451,0.0794167166053134,0.39799383,0.1173136893325738,5348.0,0.21821883,0.0725732740235208,2472.0,33163.28744673729,36320.83214735985,33163.28744673729,3154.5400528907776,1.197132587432861,0.0 -42200,0.6267701,1.1436839,,,,,,,,,,,,,, -42300,0.6300361,1.1631191,,,,,,,,,,,,,, -42400,0.5367777,1.1538568,,,,,,,,,,,,,, -42500,0.5908184,1.1380873,,,,,,,,,,,,,, -42600,0.5959623,1.110309,,,,,,,,,,,,,, -42700,0.5906137,1.1511654,,,,,,,,,,,,,, -42800,0.5636152,1.179668,,,,,,,,,,,,,, -42900,0.72627306,1.1370387,,,,,,,,,,,,,, -43000,0.6678894,1.1357555,,,,,,,,,,,,,, -43100,0.5493719,1.0879498,,,,,,,,,,,,,, -43200,0.63042754,1.1439574,,,,,,,,,,,,,, -43300,0.4969707,1.1249696,,,,,,,,,,,,,, -43400,0.567291,1.140392,,,,,,,,,,,,,, -43500,0.5467421,1.121321,,,,,,,,,,,,,, -43600,0.69984156,1.1678348,,,,,,,,,,,,,, -43700,0.6459318,1.1147376,,,,,,,,,,,,,, -43800,0.61600024,1.1274,,,,,,,,,,,,,, -43900,0.48712882,1.10018,,,,,,,,,,,,,, -44000,0.97537476,1.1742738,,,,,,,,,,,,,, -44043,,,0.14231463,0.0532641934465834,0.38264957,0.1126601465576334,5348.0,0.21474697,0.0706030507992606,2472.0,34603.37959957123,37903.424902677536,34603.37959957123,3296.9027593135834,1.2538504600524902,0.0 -44100,0.5925071,1.1089033,,,,,,,,,,,,,, -44200,0.53675514,1.1470637,,,,,,,,,,,,,, -44300,0.6378376,1.133516,,,,,,,,,,,,,, -44400,0.5506119,1.1593481,,,,,,,,,,,,,, -44500,0.6361831,1.0867684,,,,,,,,,,,,,, -44600,0.5758166,1.1399189,,,,,,,,,,,,,, -44700,0.5359217,1.1164005,,,,,,,,,,,,,, -44800,0.53021806,1.1125643,,,,,,,,,,,,,, -44900,0.5294907,1.1384026,,,,,,,,,,,,,, -45000,0.52378356,1.1025085,,,,,,,,,,,,,, -45100,0.60674953,1.1768568,,,,,,,,,,,,,, -45200,0.5877531,1.1615891,,,,,,,,,,,,,, -45300,0.5332529,1.1134437,,,,,,,,,,,,,, -45400,0.5361114,1.0880858,,,,,,,,,,,,,, -45500,0.560895,1.1063678,,,,,,,,,,,,,, -45600,0.73528963,1.2084848,,,,,,,,,,,,,, -45700,0.51298636,1.0950571,,,,,,,,,,,,,, -45800,0.6361666,1.1194757,,,,,,,,,,,,,, -45893,,,0.12276159,0.0459216859633146,0.3812691,0.1105264682313641,5348.0,0.21339078,0.0698921455121564,2472.0,36043.25439476967,39476.55716729164,36043.25439476967,3430.021687746048,1.3127069473266602,0.0 -45900,0.6074009,1.1257977,,,,,,,,,,,,,, -46000,0.60993433,1.1291614,,,,,,,,,,,,,, -46100,0.66949683,1.121102,,,,,,,,,,,,,, -46200,0.65050435,1.0958521,,,,,,,,,,,,,, -46300,0.5873796,1.1469041,,,,,,,,,,,,,, -46400,0.5858593,1.1178676,,,,,,,,,,,,,, -46500,0.59656394,1.0841882,,,,,,,,,,,,,, -46600,0.5509933,1.1039971,,,,,,,,,,,,,, -46700,0.57362735,1.0827991,,,,,,,,,,,,,, -46800,0.5698136,1.0994928,,,,,,,,,,,,,, -46900,0.70575297,1.0944667,,,,,,,,,,,,,, -47000,0.5584715,1.1285604,,,,,,,,,,,,,, -47100,0.4870344,1.1933692,,,,,,,,,,,,,, -47200,0.5250773,1.0718474,,,,,,,,,,,,,, -47300,0.5456566,1.1334734,,,,,,,,,,,,,, -47400,0.7096623,1.1375803,,,,,,,,,,,,,, -47500,0.49042466,1.0999411,,,,,,,,,,,,,, -47600,0.61016613,1.083543,,,,,,,,,,,,,, -47700,0.6598335,1.1129947,,,,,,,,,,,,,, -47735,,,0.12371109,0.0474424386832767,0.36838114,0.1077845467623121,5348.0,0.2057504,0.067800052810107,2472.0,37483.68469452858,41051.488491773605,37483.68469452858,3564.374750375748,1.379469633102417,0.0 -47800,0.5356412,1.1335019,,,,,,,,,,,,,, -47900,0.56634116,1.1446767,,,,,,,,,,,,,, -48000,0.56559974,1.1174835,,,,,,,,,,,,,, -48100,0.5646383,1.0898361,,,,,,,,,,,,,, -48200,0.6069763,1.0648078,,,,,,,,,,,,,, -48300,0.67154413,1.120441,,,,,,,,,,,,,, -48400,0.5341808,1.0432037,,,,,,,,,,,,,, -48500,0.527867,1.1222985,,,,,,,,,,,,,, -48600,0.5896042,1.1144422,,,,,,,,,,,,,, -48700,0.5569264,1.05357,,,,,,,,,,,,,, -48800,0.7082181,1.1388558,,,,,,,,,,,,,, -48900,0.55477864,1.103897,,,,,,,,,,,,,, -49000,0.5361841,1.0932487,,,,,,,,,,,,,, -49100,0.5497908,1.0618746,,,,,,,,,,,,,, -49200,0.5856465,1.0933756,,,,,,,,,,,,,, -49300,0.523727,1.0741045,,,,,,,,,,,,,, -49400,0.55502445,1.0748785,,,,,,,,,,,,,, -49500,0.75398225,1.0979631,,,,,,,,,,,,,, -49571,,,0.12066799,0.0457725052918445,0.36527857,0.1063460034563658,5348.0,0.20391922,0.0671094591026344,2472.0,38924.20606184006,42625.732850551605,38924.20606184006,3697.955447912216,1.4431352615356443,0.0 -49600,0.57577777,1.0693958,,,,,,,,,,,,,, -49700,0.58702356,1.0277213,,,,,,,,,,,,,, -49800,0.54016155,1.0820519,,,,,,,,,,,,,, -49900,0.6091563,1.12256,,,,,,,,,,,,,, -50000,0.5525601,1.0853002,,,,,,,,,,,,,, -50100,0.64306915,1.0406413,,,,,,,,,,,,,, -50200,0.672712,1.1297112,,,,,,,,,,,,,, -50300,0.6038232,1.071721,,,,,,,,,,,,,, -50400,0.59255016,1.1365453,,,,,,,,,,,,,, -50500,0.57145375,1.0761611,,,,,,,,,,,,,, -50600,0.5881163,1.07365,,,,,,,,,,,,,, -50700,0.5728353,1.0517235,,,,,,,,,,,,,, -50800,0.6407559,1.0937254,,,,,,,,,,,,,, -50900,0.59202987,1.0040522,,,,,,,,,,,,,, -51000,0.56378204,1.0742491,,,,,,,,,,,,,, -51100,0.648314,1.1113144,,,,,,,,,,,,,, -51200,0.682348,1.0765972,,,,,,,,,,,,,, -51300,0.5796312,1.0861611,,,,,,,,,,,,,, -51400,0.5497829,1.0952204,,,,,,,,,,,,,, -51403,,,0.10996965,0.0426972796731894,0.3555751,0.1044054181912973,5348.0,0.19339007,0.0655048443117421,2472.0,40364.20062971115,44197.879747867584,40364.20062971115,3829.9755721092224,1.496940851211548,0.0 -51500,0.5993868,1.0734044,,,,,,,,,,,,,, -51600,0.572032,1.0812477,,,,,,,,,,,,,, -51700,0.49584645,1.0698085,,,,,,,,,,,,,, -51800,0.526282,1.0737861,,,,,,,,,,,,,, -51900,0.7025693,1.0641652,,,,,,,,,,,,,, -52000,0.68990797,1.1047965,,,,,,,,,,,,,, -52100,0.5794248,1.0788965,,,,,,,,,,,,,, -52200,0.6326944,1.0521184,,,,,,,,,,,,,, -52300,0.546752,1.0289761,,,,,,,,,,,,,, -52400,0.5146633,1.0342048,,,,,,,,,,,,,, -52500,0.6665292,1.0927075,,,,,,,,,,,,,, -52600,0.66114646,0.988851,,,,,,,,,,,,,, -52700,0.5244751,1.0400928,,,,,,,,,,,,,, -52800,0.62657833,0.9991752,,,,,,,,,,,,,, -52900,0.7388991,1.0685031,,,,,,,,,,,,,, -53000,0.6155367,1.0207393,,,,,,,,,,,,,, -53100,0.742444,1.0615268,,,,,,,,,,,,,, -53200,0.6323821,1.064041,,,,,,,,,,,,,, -53238,,,0.10054823,0.0394028957698239,0.35113725,0.1017600432528457,5348.0,0.18965232,0.0626002884244307,2472.0,41804.38695335388,45773.00025463104,41804.38695335388,3964.7778856754303,1.5501763820648191,0.0 -53300,0.7137004,1.0545408,,,,,,,,,,,,,, -53400,0.5536286,1.075935,,,,,,,,,,,,,, -53500,0.63459235,1.0503229,,,,,,,,,,,,,, -53600,0.6014696,1.0670533,,,,,,,,,,,,,, -53700,0.7141222,1.0208776,,,,,,,,,,,,,, -53800,0.6525429,1.0122868,,,,,,,,,,,,,, -53900,0.56639445,1.0012052,,,,,,,,,,,,,, -54000,0.5738022,1.0317963,,,,,,,,,,,,,, -54100,0.6267463,1.0174891,,,,,,,,,,,,,, -54200,0.5568229,1.083527,,,,,,,,,,,,,, -54300,0.5686618,1.0377303,,,,,,,,,,,,,, -54400,0.6938647,1.0112373,,,,,,,,,,,,,, -54500,0.5492226,1.0712346,,,,,,,,,,,,,, -54600,0.5767532,1.0299376,,,,,,,,,,,,,, -54700,0.8569828,1.0647807,,,,,,,,,,,,,, -54800,0.63466996,0.99391234,,,,,,,,,,,,,, -54900,0.6312409,1.063508,,,,,,,,,,,,,, -55000,0.63834697,1.0687488,,,,,,,,,,,,,, -55087,,,0.114656754,0.0430929397423648,0.34877175,0.0997615300694169,5348.0,0.18984833,0.0627018463225885,2472.0,43244.36760210991,47344.01453351975,43244.36760210991,4095.675683498383,1.6069636344909668,0.0 -55100,0.61133945,0.99678487,,,,,,,,,,,,,, -55200,0.57680786,1.0167764,,,,,,,,,,,,,, -55300,0.66108096,1.0393094,,,,,,,,,,,,,, -55400,0.73634857,1.0666945,,,,,,,,,,,,,, -55500,0.664899,1.0135458,,,,,,,,,,,,,, -55600,0.67761266,1.0489783,,,,,,,,,,,,,, -55700,0.77566916,1.0366845,,,,,,,,,,,,,, -55800,0.5796931,1.053995,,,,,,,,,,,,,, -55900,0.60607994,1.0036849,,,,,,,,,,,,,, -56000,0.5255393,0.9821308,,,,,,,,,,,,,, -56100,0.58356535,1.0068562,,,,,,,,,,,,,, -56200,0.5527827,0.99438006,,,,,,,,,,,,,, -56300,0.66083527,1.0677515,,,,,,,,,,,,,, -56400,0.6213452,1.0121781,,,,,,,,,,,,,, -56500,0.5633268,1.0027661,,,,,,,,,,,,,, -56600,0.5973848,1.01425,,,,,,,,,,,,,, -56700,0.7235302,1.0147766,,,,,,,,,,,,,, -56800,0.66178393,1.00092,,,,,,,,,,,,,, -56900,0.6251652,0.97890943,,,,,,,,,,,,,, -56927,,,0.094018914,0.0355911965754031,0.33511576,0.0966334224779632,5348.0,0.18326183,0.0600207178112241,2472.0,44684.72154283524,48916.36607122421,44684.72154283524,4227.535162687302,1.6658551692962646,0.0 -57000,0.7540664,1.0036335,,,,,,,,,,,,,, -57100,0.5922785,0.98354775,,,,,,,,,,,,,, -57200,0.6090404,1.0048953,,,,,,,,,,,,,, -57300,0.5970247,0.9960327,,,,,,,,,,,,,, -57400,0.6587327,0.9988903,,,,,,,,,,,,,, -57500,0.5576745,0.9887859,,,,,,,,,,,,,, -57600,1.0301687,1.002001,,,,,,,,,,,,,, -57700,0.6547654,1.02453,,,,,,,,,,,,,, -57800,0.70750505,0.97148407,,,,,,,,,,,,,, -57900,0.6761914,0.9752554,,,,,,,,,,,,,, -58000,0.5328376,0.9667233,,,,,,,,,,,,,, -58100,0.66602194,1.0038384,,,,,,,,,,,,,, -58200,0.7702936,1.0004855,,,,,,,,,,,,,, -58300,0.8254814,1.0177076,,,,,,,,,,,,,, -58400,0.65034395,0.994654,,,,,,,,,,,,,, -58500,0.5562851,0.99080896,,,,,,,,,,,,,, -58600,0.56249493,0.9663001,,,,,,,,,,,,,, -58700,0.5918941,0.9875783,,,,,,,,,,,,,, -58747,,,0.09163833,0.0359684636692946,0.33312958,0.0941328673354123,5348.0,0.18139437,0.0579692482684378,2472.0,46124.952325344086,50489.70118045807,46124.952325344086,4360.501399755478,1.725421667098999,0.0 -58800,0.75205666,1.0050814,,,,,,,,,,,,,, -58900,0.518384,1.0083795,,,,,,,,,,,,,, -59000,0.66752034,0.95881385,,,,,,,,,,,,,, -59100,0.62021774,1.0154755,,,,,,,,,,,,,, -59200,0.6082474,0.9962067,,,,,,,,,,,,,, -59300,0.66379666,0.9737995,,,,,,,,,,,,,, -59400,0.687265,1.0029061,,,,,,,,,,,,,, -59500,0.66740525,0.9516844,,,,,,,,,,,,,, -59600,0.665452,1.0218552,,,,,,,,,,,,,, -59700,0.8084717,0.9926994,,,,,,,,,,,,,, -59800,0.5821782,0.9565812,,,,,,,,,,,,,, -59900,0.60764515,1.0045382,,,,,,,,,,,,,, -60000,0.8119188,0.9645272,,,,,,,,,,,,,, -60100,0.7653128,0.9910251,,,,,,,,,,,,,, -60200,0.7713162,0.9355469,,,,,,,,,,,,,, -60300,0.5490178,0.94069517,,,,,,,,,,,,,, -60400,0.64552736,1.0188107,,,,,,,,,,,,,, -60500,0.57558227,0.9914707,,,,,,,,,,,,,, -60586,,,0.075750425,0.0300035989499534,0.32939577,0.0937466812130106,5348.0,0.17375073,0.0578473787906485,2472.0,47565.06146478653,52062.507381916046,47565.06146478653,4493.0639543533325,1.7812588214874268,0.0 -60600,0.6057669,0.9834392,,,,,,,,,,,,,, -60700,0.6955156,0.97446585,,,,,,,,,,,,,, -60800,0.5829779,0.93465555,,,,,,,,,,,,,, -60900,0.7763421,0.9870503,,,,,,,,,,,,,, -61000,0.70735925,1.0045125,,,,,,,,,,,,,, -61100,0.60472673,0.96669865,,,,,,,,,,,,,, -61200,0.7838619,0.97177947,,,,,,,,,,,,,, -61300,0.6564444,0.92401814,,,,,,,,,,,,,, -61400,0.7687286,0.96533006,,,,,,,,,,,,,, -61500,0.6763781,0.95402825,,,,,,,,,,,,,, -61600,0.72676617,0.960622,,,,,,,,,,,,,, -61700,0.77913135,0.9720821,,,,,,,,,,,,,, -61800,0.6103461,0.9866961,,,,,,,,,,,,,, -61900,0.83993256,0.96557456,,,,,,,,,,,,,, -62000,0.6269415,0.94774765,,,,,,,,,,,,,, -62100,0.65658253,0.9541569,,,,,,,,,,,,,, -62200,0.61155397,0.9184081,,,,,,,,,,,,,, -62300,0.62148935,0.92890775,,,,,,,,,,,,,, -62400,0.72643334,0.951257,,,,,,,,,,,,,, -62424,,,0.070418306,0.0283815178793507,0.3231744,0.0918640238663023,5348.0,0.17216443,0.0564661913757032,2472.0,49005.45624756813,53636.18750405312,49005.45624756813,4626.207343816757,1.843952894210816,0.0 -62500,1.0504031,0.9558207,,,,,,,,,,,,,, -62600,0.8023315,0.94691956,,,,,,,,,,,,,, -62700,0.6020609,0.9750443,,,,,,,,,,,,,, -62800,0.5669864,0.97149,,,,,,,,,,,,,, -62900,0.7009345,0.9267346,,,,,,,,,,,,,, -63000,0.7100157,0.97312087,,,,,,,,,,,,,, -63100,0.61643887,0.9062258,,,,,,,,,,,,,, -63200,0.6256794,0.96587664,,,,,,,,,,,,,, -63300,0.60729855,0.980196,,,,,,,,,,,,,, -63400,0.6845013,0.93814325,,,,,,,,,,,,,, -63500,0.58257914,0.95551956,,,,,,,,,,,,,, -63600,0.824135,0.9689467,,,,,,,,,,,,,, -63700,0.7355585,0.9524044,,,,,,,,,,,,,, -63800,0.7201597,0.95645744,,,,,,,,,,,,,, -63900,0.7353422,0.93496895,,,,,,,,,,,,,, -64000,0.5671,0.9517724,,,,,,,,,,,,,, -64100,0.59840894,0.93337715,,,,,,,,,,,,,, -64200,0.63476616,0.9922751,,,,,,,,,,,,,, -64266,,,0.07730179,0.0303223574687047,0.31241328,0.0885524778667078,5348.0,0.1681303,0.0537850628643389,2472.0,50445.885021448135,55208.22303843498,50445.885021448135,4757.616107225418,1.9643621444702148,0.0 -64300,0.73501945,0.9326766,,,,,,,,,,,,,, -64400,0.6254743,0.95431614,,,,,,,,,,,,,, -64500,0.65991026,0.9291458,,,,,,,,,,,,,, -64600,0.60697484,0.97901946,,,,,,,,,,,,,, -64700,0.5810749,0.9028822,,,,,,,,,,,,,, -64800,0.6562352,0.89014614,,,,,,,,,,,,,, -64900,0.6355793,0.96374816,,,,,,,,,,,,,, -65000,0.6637367,0.95974296,,,,,,,,,,,,,, -65100,0.82597923,0.93531007,,,,,,,,,,,,,, -65200,1.032697,0.90503174,,,,,,,,,,,,,, -65300,0.7349927,0.9687863,,,,,,,,,,,,,, -65400,0.6471677,0.92174685,,,,,,,,,,,,,, -65500,0.7047416,0.9809646,,,,,,,,,,,,,, -65600,0.6211663,0.9362683,,,,,,,,,,,,,, -65700,0.96761423,0.919658,,,,,,,,,,,,,, -65800,0.7028739,0.93197614,,,,,,,,,,,,,, -65900,0.80488336,0.9429913,,,,,,,,,,,,,, -66000,0.6517861,0.94336456,,,,,,,,,,,,,, -66097,,,0.07038693,0.027615892312107,0.3174673,0.0877125230504841,5348.0,0.16740415,0.0537038165458127,2472.0,51886.28319978714,56780.337996959686,51886.28319978714,4889.189654827118,2.029456615447998,0.0 -66100,0.8116859,0.9135892,,,,,,,,,,,,,, -66200,0.6817418,0.8859909,,,,,,,,,,,,,, -66300,0.70103914,0.93447846,,,,,,,,,,,,,, -66400,0.75988865,0.97918206,,,,,,,,,,,,,, -66500,0.7519779,0.9479174,,,,,,,,,,,,,, -66600,0.6505963,0.8945204,,,,,,,,,,,,,, -66700,0.64411545,0.9277175,,,,,,,,,,,,,, -66800,0.6036001,0.9463521,,,,,,,,,,,,,, -66900,0.67491686,0.9238028,,,,,,,,,,,,,, -67000,0.59898597,0.94513714,,,,,,,,,,,,,, -67100,0.7037681,0.9508043,,,,,,,,,,,,,, -67200,0.9315871,0.8986395,,,,,,,,,,,,,, -67300,0.58798265,0.9412032,,,,,,,,,,,,,, -67400,0.8794484,0.9488014,,,,,,,,,,,,,, -67500,0.7690282,0.95488065,,,,,,,,,,,,,, -67600,0.8269347,0.9438679,,,,,,,,,,,,,, -67700,0.6784447,0.90612406,,,,,,,,,,,,,, -67800,0.678505,0.9199088,,,,,,,,,,,,,, -67900,0.9271438,0.9550247,,,,,,,,,,,,,, -67925,,,0.06680417,0.0261072042993986,0.30674285,0.0854726435405543,5348.0,0.16228499,0.0524444986086568,2472.0,53326.39934182167,58353.84253954888,53326.39934182167,5022.44410276413,2.0852842330932617,0.0 -68000,0.6303442,0.9159269,,,,,,,,,,,,,, -68100,0.84519887,0.91119635,,,,,,,,,,,,,, -68200,0.7522616,0.9136424,,,,,,,,,,,,,, -68300,0.71430016,0.968497,,,,,,,,,,,,,, -68400,0.7034463,0.93403524,,,,,,,,,,,,,, -68500,0.96781707,0.90964615,,,,,,,,,,,,,, -68600,0.79749244,0.91327834,,,,,,,,,,,,,, -68700,0.6675391,0.90802103,,,,,,,,,,,,,, -68800,0.6754824,0.9346722,,,,,,,,,,,,,, -68900,0.7079615,0.884932,,,,,,,,,,,,,, -69000,1.0053611,0.91557264,,,,,,,,,,,,,, -69100,0.6818965,0.894204,,,,,,,,,,,,,, -69200,0.6581243,0.8976334,,,,,,,,,,,,,, -69300,0.8394691,0.88773584,,,,,,,,,,,,,, -69400,0.67563736,0.8789858,,,,,,,,,,,,,, -69500,1.0512713,0.92381126,,,,,,,,,,,,,, -69600,0.7063981,0.9063907,,,,,,,,,,,,,, -69700,0.86853564,0.8711316,,,,,,,,,,,,,, -69755,,,0.05866672,0.0228105070404748,0.3071572,0.0856753912548152,5348.0,0.1607685,0.0518351512197103,2472.0,54766.46932411194,59927.20699119568,54766.46932411194,5155.602071285248,2.1447830200195312,0.0 -69800,0.7934328,0.88055855,,,,,,,,,,,,,, -69900,0.763594,0.8933644,,,,,,,,,,,,,, -70000,0.6765368,0.9240441,,,,,,,,,,,,,, -70100,0.5850229,0.85304886,,,,,,,,,,,,,, -70200,0.6819096,0.9341951,,,,,,,,,,,,,, -70300,0.5744965,0.87971544,,,,,,,,,,,,,, -70400,0.922294,0.9032218,,,,,,,,,,,,,, -70500,0.7598607,0.91563034,,,,,,,,,,,,,, -70600,0.7937552,0.88663167,,,,,,,,,,,,,, -70700,0.8324725,0.8988621,,,,,,,,,,,,,, -70800,0.7904172,0.8929275,,,,,,,,,,,,,, -70900,0.74248517,0.9094372,,,,,,,,,,,,,, -71000,0.80460346,0.91832316,,,,,,,,,,,,,, -71100,0.7022747,0.8818811,,,,,,,,,,,,,, -71200,0.7203736,0.90108,,,,,,,,,,,,,, -71300,0.6947776,0.8953156,,,,,,,,,,,,,, -71400,0.9789495,0.89683986,,,,,,,,,,,,,, -71500,1.2216822,0.88896227,,,,,,,,,,,,,, -71598,,,0.053217243,0.0202194390030425,0.30390048,0.084024445581548,5348.0,0.15960225,0.0507992606585014,2472.0,56206.59651684761,61500.89174199104,56206.59651684761,5289.01722574234,2.2092363834381104,0.0 -71600,0.6994014,0.89297056,,,,,,,,,,,,,, -71700,0.89829445,0.9303972,,,,,,,,,,,,,, -71800,0.96454614,0.90125227,,,,,,,,,,,,,, -71900,0.63994855,0.93212,,,,,,,,,,,,,, -72000,0.7011781,0.92404807,,,,,,,,,,,,,, -72100,0.8360339,0.8452704,,,,,,,,,,,,,, -72200,0.669246,0.89634216,,,,,,,,,,,,,, -72300,0.6223918,0.9047947,,,,,,,,,,,,,, -72400,0.7752789,0.8740458,,,,,,,,,,,,,, -72500,0.9455729,0.9234639,,,,,,,,,,,,,, -72600,0.7422028,0.93602616,,,,,,,,,,,,,, -72700,0.81355417,0.9340972,,,,,,,,,,,,,, -72800,0.62799734,0.86621404,,,,,,,,,,,,,, -72900,0.71929127,0.9244973,,,,,,,,,,,,,, -73000,0.8736004,0.9270004,,,,,,,,,,,,,, -73100,0.7900187,0.92421,,,,,,,,,,,,,, -73200,0.7943166,0.88220507,,,,,,,,,,,,,, -73300,0.87713146,0.8267024,,,,,,,,,,,,,, -73400,0.73457587,0.83812296,,,,,,,,,,,,,, -73439,,,0.05957009,0.0217736158018626,0.30104917,0.0834065477857053,5348.0,0.15717694,0.050230536428818,2472.0,57646.99338531494,63073.967522382736,57646.99338531494,5421.555196762085,2.2709219455718994,0.0 -73500,0.8320067,0.899795,,,,,,,,,,,,,, -73600,0.73744637,0.8938368,,,,,,,,,,,,,, -73700,0.6923317,0.89016783,,,,,,,,,,,,,, -73800,0.7064615,0.88924056,,,,,,,,,,,,,, -73900,0.70382327,0.8978123,,,,,,,,,,,,,, -74000,0.7540026,0.9383217,,,,,,,,,,,,,, -74100,0.8533187,0.89819545,,,,,,,,,,,,,, -74200,0.61298007,0.86426955,,,,,,,,,,,,,, -74300,0.80956876,0.8817732,,,,,,,,,,,,,, -74400,0.8266268,0.90020674,,,,,,,,,,,,,, -74500,0.61131775,0.8740283,,,,,,,,,,,,,, -74600,0.7340278,0.8983685,,,,,,,,,,,,,, -74700,0.7372241,0.8839979,,,,,,,,,,,,,, -74800,1.0390513,0.89898705,,,,,,,,,,,,,, -74900,0.8893454,0.8963035,,,,,,,,,,,,,, -75000,0.61407596,0.8262876,,,,,,,,,,,,,, -75100,0.705627,0.9483829,,,,,,,,,,,,,, -75200,0.75884706,0.8540951,,,,,,,,,,,,,, -75268,,,0.056394346,0.0210327421923502,0.29899248,0.0824700464388812,5348.0,0.15624015,0.0497024353583978,2472.0,59087.44711709023,64645.97832798958,59087.44711709023,5552.972937345505,2.3312904834747314,0.0 -75300,0.63498765,0.8564254,,,,,,,,,,,,,, -75400,0.7396645,0.8670416,,,,,,,,,,,,,, -75500,0.9577805,0.8929642,,,,,,,,,,,,,, -75600,0.710592,0.90466714,,,,,,,,,,,,,, -75700,0.7948819,0.89209163,,,,,,,,,,,,,, -75800,0.73171115,0.9154316,,,,,,,,,,,,,, -75900,0.8259642,0.88753283,,,,,,,,,,,,,, -76000,0.6123865,0.8940393,,,,,,,,,,,,,, -76100,0.63363737,0.9144395,,,,,,,,,,,,,, -76200,0.8853267,0.9282677,,,,,,,,,,,,,, -76300,0.7669663,0.8228525,,,,,,,,,,,,,, -76400,0.67558664,0.8831153,,,,,,,,,,,,,, -76500,0.8560798,0.86187387,,,,,,,,,,,,,, -76600,0.640312,0.8747958,,,,,,,,,,,,,, -76700,0.68280256,0.9289016,,,,,,,,,,,,,, -76800,0.6560624,0.8951741,,,,,,,,,,,,,, -76900,0.7154675,0.8726064,,,,,,,,,,,,,, -77000,0.684266,0.85671306,,,,,,,,,,,,,, -77100,0.6274369,0.85947233,,,,,,,,,,,,,, -77109,,,0.05739345,0.0214955273447961,0.2971615,0.0818038753777383,5348.0,0.15573071,0.0494993195620823,2472.0,60527.67941856384,66218.89422011375,60527.67941856384,5685.517039775848,2.3913497924804688,0.0 -77200,0.86410093,0.9160483,,,,,,,,,,,,,, -77300,0.7358985,0.8820133,,,,,,,,,,,,,, -77400,0.5924075,0.86457455,,,,,,,,,,,,,, -77500,0.903603,0.91080785,,,,,,,,,,,,,, -77600,0.86829334,0.86610806,,,,,,,,,,,,,, -77700,0.85195833,0.8757688,,,,,,,,,,,,,, -77800,0.83870494,0.8702856,,,,,,,,,,,,,, -77813,,,,,,,,,,,61068.168311834335,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index ea33dfa61..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,30 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -226.2646632194519,0.0,44.15716743469238,1,0,44.15716743469238,30.657606,2472,3.232161355188593,270.42189955711365,31.563702,3.280985152947546,30.570293,5348,2.911601996582253 -333.1237199306488,0.041421890258789,1485.011241197586,1767,0,1485.011241197586,6.3765306,2472,0.899579550301627,1818.2476708889008,6.473987,0.944635537887994,6.4356875,5348,0.8966179750330672 -453.0073957443237,0.0823078155517578,2925.4576818943024,3537,0,2925.4576818943024,3.662109,2472,0.7429772713423923,3378.6899876594543,3.7205336,0.7672056042404709,4.0462813,5348,0.7830116724755496 -585.1130557060242,0.1241567134857177,4365.597179412842,5280,0,4365.597179412842,0.671695,2472,0.2141246724757784,4951.050794363022,0.60861146,0.2074207610523375,1.0145419,5348,0.285150178128349 -729.5503969192505,0.1606092453002929,5805.970959663391,7030,0,5805.970959663391,0.5282411,2472,0.1707797615420551,6535.9713497161865,0.49595514,0.1668359706353611,0.8424785,5348,0.2426503953580428 -865.7520883083344,0.2023007869720459,7246.0776579380035,8781,0,7246.0776579380035,0.47641897,2472,0.1522149777588203,8112.394110202789,0.4343019,0.1476276935788993,0.7638279,5348,0.2213232667484094 -1002.0370271205902,0.246028184890747,8686.643969297409,10527,0,8686.643969297409,0.44031197,2472,0.1406983121077326,9689.364002227783,0.42308322,0.1412369099124571,0.724271,5348,0.2080191548316711 -1134.8351573944092,0.2853608131408691,10126.57528614998,12254,0,10126.57528614998,0.42263985,2472,0.13685942355737,11262.206323862076,0.3849505,0.1299161933933167,0.70631933,5348,0.2039835098525734 -1277.226976394653,0.3252630233764648,11566.618283510208,13977,0,11566.618283510208,0.3976388,2472,0.1286941685454878,12844.752668619156,0.32947904,0.1149223848774019,0.66567683,5348,0.1920696679764812 -1409.226472377777,0.3636348247528076,13007.737249851229,15733,0,13007.737249851229,0.3880893,2472,0.1242865557654418,14417.983313083649,0.3160166,0.1087040948241465,0.65010995,5348,0.188130569527984 -1546.4209856987,0.4011971950531006,14448.331030368803,17470,0,14448.331030368803,0.370601,2472,0.1191274145390287,15995.882573604584,0.3091924,0.1084159663640215,0.62598205,5348,0.1816040240593954 -1679.376972436905,0.4410090446472168,15888.437488555908,19191,0,15888.437488555908,0.35781944,2472,0.1147401133386143,17569.057680606842,0.31268734,0.106248182130675,0.60505414,5348,0.1773077034476766 -1815.8983781337736,0.4839413166046142,17328.820681095123,20942,0,17328.820681095123,0.34689486,2472,0.1130339406495643,19146.07883024216,0.31071246,0.1065924499432822,0.5931073,5348,0.1737354818154609 -1950.612141370773,0.5289254188537598,18768.853582143784,22661,0,18768.853582143784,0.33769685,2472,0.1081185383787297,20720.94268655777,0.29349634,0.1005453235929622,0.57543117,5348,0.168493005203858 -2092.538419485092,0.5695466995239258,20209.092022895813,24399,0,20209.092022895813,0.325007,2472,0.1048280624784189,22303.220635175705,0.26594213,0.0915554664076686,0.5630687,5348,0.163482240265696 -2236.258577108383,0.6100704669952393,21649.144829034805,26147,0,21649.144829034805,0.3190817,2472,0.102126622387423,23887.106551885605,0.25006896,0.0878218002556143,0.5523809,5348,0.1610009944292652 -2377.8051433563232,0.6545779705047607,23089.323447704315,27855,0,23089.323447704315,0.30643246,2472,0.0981049296203765,25468.94979000092,0.23860557,0.0828977917787483,0.5355645,5348,0.1553530223891404 -2511.6333718299866,0.694037675857544,24529.774584531784,29591,0,24529.774584531784,0.293739,2472,0.0932504620884366,27043.3413450718,0.25175303,0.0871955711561795,0.5208304,5348,0.1510856657366017 -2644.3974990844727,0.7355659008026123,25970.473658800125,31338,0,25970.473658800125,0.2868147,2472,0.0914427315012288,28616.918856859207,0.22907293,0.0768696520328379,0.50683457,5348,0.1489712967164524 -2785.990561962128,0.7848289012908936,27410.98273205757,33063,0,27410.98273205757,0.27371708,2472,0.0880913208620234,30199.143298387527,0.23234938,0.0781942682193877,0.49369305,5348,0.1442212074109116 -2931.6980545520782,0.8314599990844727,28850.900230646133,34798,0,28850.900230646133,0.26734018,2472,0.0863851481729734,31784.88902115822,0.21853557,0.0723009047603823,0.47827327,5348,0.1389883854523687 -3065.7418246269226,0.8742971420288086,30290.972578525543,36544,0,30290.972578525543,0.25568554,2472,0.0814291227428757,33359.12208724022,0.16265668,0.0572860613142277,0.46007198,5348,0.1338424553713662 -3202.054293870926,0.916539430618286,31731.00056815148,38279,0,31731.00056815148,0.24500446,2472,0.0778339731480917,34935.57740569115,0.17948455,0.0625910751620874,0.44596314,5348,0.1296040626780076 -3333.839416027069,0.960141897201538,33171.87407159805,40011,0,33171.87407159805,0.23872232,2472,0.0754575183312006,36508.35352563858,0.22586584,0.0771792604806821,0.43340683,5348,0.1258580572907112 -3467.3239908218384,1.0061976909637451,34611.97883796692,41741,0,34611.97883796692,0.22935602,2472,0.0725529624438892,38082.061220407486,0.22595026,0.0774299560889767,0.4262832,5348,0.1229326974135184 -3601.6576216220856,1.0495140552520752,36052.042571783066,43467,0,36052.042571783066,0.22580484,2472,0.0708671013344707,39656.57435369492,0.25700453,0.0887505776509058,0.41875762,5348,0.1210307307606901 -3733.827961683273,1.090902328491211,37492.52834439278,45193,0,37492.52834439278,0.22298104,2472,0.0698108991936302,41229.34570097923,0.22935137,0.0772333072317175,0.41290858,5348,0.119177037373162 -3866.790652275085,1.1397252082824707,38932.92411899567,46925,0,38932.92411899567,0.22249565,2472,0.0699530802510511,42802.82615447045,0.21121906,0.0731408593410753,0.4120718,5348,0.11874257798546 -4000.894311904907,1.186645746231079,39802.3655128479,48000,0,39802.3655128479,0.22241092,2472,0.0701561960473666,43806.46564793587,0.19107874,0.06727098717767457,0.4123398,5348,0.11909980014868166 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index 73ffc5731..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,511 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,19.473045,32.968338,,,,,,,,,,,,,, -1,,,31.563702,3.280985152947546,30.570293,2.911601996582253,5348.0,30.657606,3.232161355188593,2472.0,44.15716743469238,270.42189955711365,44.15716743469238,226.2646632194519,0.0,0.0 -100,6.9535117,9.512773,,,,,,,,,,,,,, -200,2.7786052,6.6726274,,,,,,,,,,,,,, -300,0.9673515,5.9280505,,,,,,,,,,,,,, -400,0.51677,5.8589106,,,,,,,,,,,,,, -500,0.600669,5.8317685,,,,,,,,,,,,,, -600,0.4826936,5.797192,,,,,,,,,,,,,, -700,0.37846547,5.7728367,,,,,,,,,,,,,, -800,0.42708656,5.6803846,,,,,,,,,,,,,, -900,0.4224222,5.5707808,,,,,,,,,,,,,, -1000,0.9037691,5.494487,,,,,,,,,,,,,, -1100,0.6559108,5.2960625,,,,,,,,,,,,,, -1200,1.2580963,4.9911532,,,,,,,,,,,,,, -1300,1.379294,4.5440645,,,,,,,,,,,,,, -1400,1.532748,4.1023836,,,,,,,,,,,,,, -1500,1.9314594,3.7746987,,,,,,,,,,,,,, -1600,2.2932472,3.5489352,,,,,,,,,,,,,, -1700,2.8955314,3.3429492,,,,,,,,,,,,,, -1767,,,6.473987,0.944635537887994,6.4356875,0.8966179750330672,5348.0,6.3765306,0.899579550301627,2472.0,1485.011241197586,1818.2476708889008,1485.011241197586,333.1237199306488,0.041421890258789,0.0 -1800,2.6431448,3.1910057,,,,,,,,,,,,,, -1900,2.7440944,3.0495732,,,,,,,,,,,,,, -2000,2.8505328,2.9230926,,,,,,,,,,,,,, -2100,3.4145076,2.8555024,,,,,,,,,,,,,, -2200,2.5926738,2.7946978,,,,,,,,,,,,,, -2300,3.1311786,2.7415636,,,,,,,,,,,,,, -2400,2.0639474,2.6803267,,,,,,,,,,,,,, -2500,3.936512,2.6281726,,,,,,,,,,,,,, -2600,2.5105917,2.485751,,,,,,,,,,,,,, -2700,2.8480914,2.4357066,,,,,,,,,,,,,, -2800,3.7132866,2.510907,,,,,,,,,,,,,, -2900,4.012956,2.402615,,,,,,,,,,,,,, -3000,3.2598128,2.3322027,,,,,,,,,,,,,, -3100,2.6755326,2.321126,,,,,,,,,,,,,, -3200,4.650546,2.2541013,,,,,,,,,,,,,, -3300,4.5164385,2.2264717,,,,,,,,,,,,,, -3400,3.780532,2.2484834,,,,,,,,,,,,,, -3500,4.6168237,2.1617088,,,,,,,,,,,,,, -3537,,,3.7205336,0.7672056042404709,4.0462813,0.7830116724755496,5348.0,3.662109,0.7429772713423923,2472.0,2925.4576818943024,3378.6899876594543,2925.4576818943024,453.0073957443237,0.0823078155517578,0.0 -3600,2.8412793,2.1211777,,,,,,,,,,,,,, -3700,3.1103244,2.1528163,,,,,,,,,,,,,, -3800,2.9706886,2.125999,,,,,,,,,,,,,, -3900,3.7705991,2.0797124,,,,,,,,,,,,,, -4000,5.749728,2.0935614,,,,,,,,,,,,,, -4100,4.8220043,2.0914688,,,,,,,,,,,,,, -4200,4.2103114,1.991512,,,,,,,,,,,,,, -4300,2.6960185,2.0569494,,,,,,,,,,,,,, -4400,4.2794943,1.9657122,,,,,,,,,,,,,, -4500,2.6757808,2.0420554,,,,,,,,,,,,,, -4600,4.427333,1.9828222,,,,,,,,,,,,,, -4700,3.059097,2.002001,,,,,,,,,,,,,, -4800,3.530097,1.9443716,,,,,,,,,,,,,, -4900,4.6303287,2.0045092,,,,,,,,,,,,,, -5000,2.795675,1.8764658,,,,,,,,,,,,,, -5100,3.513572,1.8870101,,,,,,,,,,,,,, -5200,3.5944393,1.9028224,,,,,,,,,,,,,, -5280,,,0.60861146,0.2074207610523375,1.0145419,0.285150178128349,5348.0,0.671695,0.2141246724757784,2472.0,4365.597179412842,4951.050794363022,4365.597179412842,585.1130557060242,0.1241567134857177,0.0 -5300,5.1312113,1.8532381,,,,,,,,,,,,,, -5400,2.7616825,1.9055338,,,,,,,,,,,,,, -5500,2.9781415,1.9307739,,,,,,,,,,,,,, -5600,3.03058,1.852001,,,,,,,,,,,,,, -5700,2.8229573,1.8651222,,,,,,,,,,,,,, -5800,3.253351,1.7972897,,,,,,,,,,,,,, -5900,3.659773,1.8383682,,,,,,,,,,,,,, -6000,3.427371,1.84901,,,,,,,,,,,,,, -6100,2.521181,1.8323463,,,,,,,,,,,,,, -6200,3.9587007,1.7662078,,,,,,,,,,,,,, -6300,4.709691,1.7868919,,,,,,,,,,,,,, -6400,4.2482862,1.7740729,,,,,,,,,,,,,, -6500,3.5967767,1.6899431,,,,,,,,,,,,,, -6600,3.0014253,1.7879112,,,,,,,,,,,,,, -6700,2.5312824,1.8177854,,,,,,,,,,,,,, -6800,2.1554823,1.7090726,,,,,,,,,,,,,, -6900,3.0177302,1.7131697,,,,,,,,,,,,,, -7000,3.3434927,1.8547496,,,,,,,,,,,,,, -7030,,,0.49595514,0.1668359706353611,0.8424785,0.2426503953580428,5348.0,0.5282411,0.1707797615420551,2472.0,5805.970959663391,6535.9713497161865,5805.970959663391,729.5503969192505,0.1606092453002929,0.0 -7100,2.687036,1.7679818,,,,,,,,,,,,,, -7200,2.4890633,1.8103464,,,,,,,,,,,,,, -7300,2.1057403,1.7849455,,,,,,,,,,,,,, -7400,2.4340641,1.7200866,,,,,,,,,,,,,, -7500,2.9974208,1.7407523,,,,,,,,,,,,,, -7600,3.9325457,1.7232037,,,,,,,,,,,,,, -7700,2.2726336,1.7164423,,,,,,,,,,,,,, -7800,2.5313153,1.7124901,,,,,,,,,,,,,, -7900,2.162084,1.6620882,,,,,,,,,,,,,, -8000,3.2129276,1.6681794,,,,,,,,,,,,,, -8100,6.582836,1.6535839,,,,,,,,,,,,,, -8200,3.2176077,1.7677864,,,,,,,,,,,,,, -8300,3.6711361,1.7244065,,,,,,,,,,,,,, -8400,2.5827546,1.6849699,,,,,,,,,,,,,, -8500,4.2197824,1.6447337,,,,,,,,,,,,,, -8600,2.2726264,1.654274,,,,,,,,,,,,,, -8700,2.550263,1.6770093,,,,,,,,,,,,,, -8781,,,0.4343019,0.1476276935788993,0.7638279,0.2213232667484094,5348.0,0.47641897,0.1522149777588203,2472.0,7246.0776579380035,8112.394110202789,7246.0776579380035,865.7520883083344,0.2023007869720459,0.0 -8800,2.6056824,1.6416377,,,,,,,,,,,,,, -8900,3.851344,1.7024949,,,,,,,,,,,,,, -9000,2.4293175,1.681653,,,,,,,,,,,,,, -9100,2.5723581,1.750404,,,,,,,,,,,,,, -9200,2.4136984,1.6923096,,,,,,,,,,,,,, -9300,2.3046074,1.6613562,,,,,,,,,,,,,, -9400,2.8100443,1.6958619,,,,,,,,,,,,,, -9500,2.8213181,1.678485,,,,,,,,,,,,,, -9600,2.563567,1.6627548,,,,,,,,,,,,,, -9700,2.138356,1.713762,,,,,,,,,,,,,, -9800,2.409897,1.6266795,,,,,,,,,,,,,, -9900,3.905684,1.6366432,,,,,,,,,,,,,, -10000,3.1380978,1.6553563,,,,,,,,,,,,,, -10100,3.212415,1.6067958,,,,,,,,,,,,,, -10200,3.1651788,1.6380363,,,,,,,,,,,,,, -10300,2.286719,1.6530005,,,,,,,,,,,,,, -10400,3.0094829,1.597597,,,,,,,,,,,,,, -10500,2.789282,1.6286433,,,,,,,,,,,,,, -10527,,,0.42308322,0.1412369099124571,0.724271,0.2080191548316711,5348.0,0.44031197,0.1406983121077326,2472.0,8686.643969297409,9689.364002227783,8686.643969297409,1002.0370271205902,0.246028184890747,0.0 -10600,1.8996099,1.5836529,,,,,,,,,,,,,, -10700,3.2717266,1.6270082,,,,,,,,,,,,,, -10800,3.9573534,1.6484945,,,,,,,,,,,,,, -10900,2.1632767,1.6651431,,,,,,,,,,,,,, -11000,2.2546146,1.6274388,,,,,,,,,,,,,, -11100,2.721501,1.6152401,,,,,,,,,,,,,, -11200,2.9065633,1.6004522,,,,,,,,,,,,,, -11300,2.1360898,1.6012919,,,,,,,,,,,,,, -11400,4.756043,1.5759604,,,,,,,,,,,,,, -11500,3.592877,1.6992748,,,,,,,,,,,,,, -11600,3.3077302,1.6591132,,,,,,,,,,,,,, -11700,2.979795,1.5587596,,,,,,,,,,,,,, -11800,2.1502185,1.64165,,,,,,,,,,,,,, -11900,2.7813685,1.6334081,,,,,,,,,,,,,, -12000,3.4678605,1.6129704,,,,,,,,,,,,,, -12100,3.831001,1.5963411,,,,,,,,,,,,,, -12200,2.7099712,1.6415131,,,,,,,,,,,,,, -12254,,,0.3849505,0.1299161933933167,0.70631933,0.2039835098525734,5348.0,0.42263985,0.13685942355737,2472.0,10126.57528614998,11262.206323862076,10126.57528614998,1134.8351573944092,0.2853608131408691,0.0 -12300,2.8463154,1.6512764,,,,,,,,,,,,,, -12400,2.3913548,1.5936016,,,,,,,,,,,,,, -12500,2.639102,1.607133,,,,,,,,,,,,,, -12600,2.546253,1.5726393,,,,,,,,,,,,,, -12700,1.764143,1.5909306,,,,,,,,,,,,,, -12800,3.2992034,1.6136415,,,,,,,,,,,,,, -12900,4.1761675,1.6032428,,,,,,,,,,,,,, -13000,3.6153994,1.63157,,,,,,,,,,,,,, -13100,2.9171443,1.6196595,,,,,,,,,,,,,, -13200,2.8490133,1.5917798,,,,,,,,,,,,,, -13300,5.177733,1.592229,,,,,,,,,,,,,, -13400,2.9205573,1.57188,,,,,,,,,,,,,, -13500,2.4703567,1.5843233,,,,,,,,,,,,,, -13600,2.7398934,1.5971043,,,,,,,,,,,,,, -13700,4.792127,1.5623455,,,,,,,,,,,,,, -13800,2.7897756,1.6124337,,,,,,,,,,,,,, -13900,4.0020485,1.5841523,,,,,,,,,,,,,, -13977,,,0.32947904,0.1149223848774019,0.66567683,0.1920696679764812,5348.0,0.3976388,0.1286941685454878,2472.0,11566.618283510208,12844.752668619156,11566.618283510208,1277.226976394653,0.3252630233764648,0.0 -14000,2.6998134,1.6169924,,,,,,,,,,,,,, -14100,1.9535018,1.6417259,,,,,,,,,,,,,, -14200,2.9383438,1.5735183,,,,,,,,,,,,,, -14300,3.0932896,1.683875,,,,,,,,,,,,,, -14400,2.5689905,1.5823988,,,,,,,,,,,,,, -14500,2.3665235,1.5462978,,,,,,,,,,,,,, -14600,6.843161,1.5429828,,,,,,,,,,,,,, -14700,2.1957998,1.5337684,,,,,,,,,,,,,, -14800,4.103471,1.5628574,,,,,,,,,,,,,, -14900,2.8817713,1.5666858,,,,,,,,,,,,,, -15000,2.9262855,1.586649,,,,,,,,,,,,,, -15100,2.6850128,1.4790286,,,,,,,,,,,,,, -15200,2.7175212,1.4792712,,,,,,,,,,,,,, -15300,3.6789532,1.5937377,,,,,,,,,,,,,, -15400,2.899261,1.5321833,,,,,,,,,,,,,, -15500,2.2540119,1.5277902,,,,,,,,,,,,,, -15600,3.9077928,1.5712227,,,,,,,,,,,,,, -15700,3.2662525,1.5551286,,,,,,,,,,,,,, -15733,,,0.3160166,0.1087040948241465,0.65010995,0.188130569527984,5348.0,0.3880893,0.1242865557654418,2472.0,13007.737249851229,14417.983313083649,13007.737249851229,1409.226472377777,0.3636348247528076,0.0 -15800,3.0569036,1.5357641,,,,,,,,,,,,,, -15900,2.9395957,1.5309283,,,,,,,,,,,,,, -16000,2.1867166,1.5227288,,,,,,,,,,,,,, -16100,3.4406738,1.5582501,,,,,,,,,,,,,, -16200,2.2444198,1.540488,,,,,,,,,,,,,, -16300,2.533617,1.5509328,,,,,,,,,,,,,, -16400,2.5005355,1.5151395,,,,,,,,,,,,,, -16500,2.5500104,1.4779068,,,,,,,,,,,,,, -16600,2.9321496,1.5191625,,,,,,,,,,,,,, -16700,2.6951427,1.5236609,,,,,,,,,,,,,, -16800,3.4807026,1.5662757,,,,,,,,,,,,,, -16900,2.971635,1.4896116,,,,,,,,,,,,,, -17000,2.319685,1.4989469,,,,,,,,,,,,,, -17100,2.5336347,1.4973015,,,,,,,,,,,,,, -17200,2.9211423,1.518313,,,,,,,,,,,,,, -17300,2.6586633,1.5619963,,,,,,,,,,,,,, -17400,3.8311443,1.5046569,,,,,,,,,,,,,, -17470,,,0.3091924,0.1084159663640215,0.62598205,0.1816040240593954,5348.0,0.370601,0.1191274145390287,2472.0,14448.331030368803,15995.882573604584,14448.331030368803,1546.4209856987,0.4011971950531006,0.0 -17500,2.8354173,1.5557531,,,,,,,,,,,,,, -17600,2.6474416,1.548256,,,,,,,,,,,,,, -17700,2.6268492,1.5625753,,,,,,,,,,,,,, -17800,3.4635136,1.5330812,,,,,,,,,,,,,, -17900,2.273455,1.5885689,,,,,,,,,,,,,, -18000,2.8140364,1.5327904,,,,,,,,,,,,,, -18100,2.5356994,1.5271165,,,,,,,,,,,,,, -18200,3.1824696,1.5409493,,,,,,,,,,,,,, -18300,2.7371962,1.5044855,,,,,,,,,,,,,, -18400,4.130111,1.5360461,,,,,,,,,,,,,, -18500,4.408894,1.550284,,,,,,,,,,,,,, -18600,2.9804459,1.5257152,,,,,,,,,,,,,, -18700,2.1206076,1.527637,,,,,,,,,,,,,, -18800,3.2011983,1.4841306,,,,,,,,,,,,,, -18900,3.0307488,1.4988757,,,,,,,,,,,,,, -19000,2.0298386,1.5122668,,,,,,,,,,,,,, -19100,2.3782334,1.4495407,,,,,,,,,,,,,, -19191,,,0.31268734,0.106248182130675,0.60505414,0.1773077034476766,5348.0,0.35781944,0.1147401133386143,2472.0,15888.437488555908,17569.057680606842,15888.437488555908,1679.376972436905,0.4410090446472168,0.0 -19200,3.2799067,1.5702068,,,,,,,,,,,,,, -19300,2.7953603,1.483119,,,,,,,,,,,,,, -19400,3.3404505,1.5178703,,,,,,,,,,,,,, -19500,3.3309188,1.4917268,,,,,,,,,,,,,, -19600,2.6029983,1.5516582,,,,,,,,,,,,,, -19700,2.6812403,1.5067134,,,,,,,,,,,,,, -19800,1.8869237,1.4482049,,,,,,,,,,,,,, -19900,2.982912,1.5839473,,,,,,,,,,,,,, -20000,2.1451101,1.4988422,,,,,,,,,,,,,, -20100,2.2114978,1.4409219,,,,,,,,,,,,,, -20200,3.4029443,1.4482486,,,,,,,,,,,,,, -20300,2.6302545,1.4958823,,,,,,,,,,,,,, -20400,3.3166664,1.4664617,,,,,,,,,,,,,, -20500,4.9756327,1.5994712,,,,,,,,,,,,,, -20600,2.8643544,1.452646,,,,,,,,,,,,,, -20700,1.9528257,1.4647069,,,,,,,,,,,,,, -20800,2.2724822,1.4515252,,,,,,,,,,,,,, -20900,2.88492,1.5134443,,,,,,,,,,,,,, -20942,,,0.31071246,0.1065924499432822,0.5931073,0.1737354818154609,5348.0,0.34689486,0.1130339406495643,2472.0,17328.820681095123,19146.07883024216,17328.820681095123,1815.8983781337736,0.4839413166046142,0.0 -21000,2.9642453,1.4559547,,,,,,,,,,,,,, -21100,2.8075402,1.5042764,,,,,,,,,,,,,, -21200,3.2357073,1.4218631,,,,,,,,,,,,,, -21300,2.7890546,1.4907361,,,,,,,,,,,,,, -21400,2.8590536,1.4916226,,,,,,,,,,,,,, -21500,2.8563132,1.4502345,,,,,,,,,,,,,, -21600,4.1570725,1.4812924,,,,,,,,,,,,,, -21700,3.5288503,1.3834679,,,,,,,,,,,,,, -21800,2.7974942,1.4852908,,,,,,,,,,,,,, -21900,2.9204295,1.4300122,,,,,,,,,,,,,, -22000,2.102264,1.4604574,,,,,,,,,,,,,, -22100,2.6620278,1.4710666,,,,,,,,,,,,,, -22200,3.7181456,1.4225216,,,,,,,,,,,,,, -22300,1.9746708,1.4047586,,,,,,,,,,,,,, -22400,2.407556,1.436089,,,,,,,,,,,,,, -22500,3.379407,1.4414326,,,,,,,,,,,,,, -22600,4.07431,1.4664469,,,,,,,,,,,,,, -22661,,,0.29349634,0.1005453235929622,0.57543117,0.168493005203858,5348.0,0.33769685,0.1081185383787297,2472.0,18768.853582143784,20720.94268655777,18768.853582143784,1950.612141370773,0.5289254188537598,0.0 -22700,2.2584198,1.408855,,,,,,,,,,,,,, -22800,3.3522458,1.4579747,,,,,,,,,,,,,, -22900,3.8585384,1.4498823,,,,,,,,,,,,,, -23000,2.9430218,1.4868549,,,,,,,,,,,,,, -23100,1.739266,1.4839625,,,,,,,,,,,,,, -23200,3.1414714,1.402161,,,,,,,,,,,,,, -23300,4.0919733,1.4132175,,,,,,,,,,,,,, -23400,5.1061764,1.4728698,,,,,,,,,,,,,, -23500,3.3936055,1.4279575,,,,,,,,,,,,,, -23600,2.2062764,1.4870923,,,,,,,,,,,,,, -23700,3.9729989,1.416,,,,,,,,,,,,,, -23800,2.6731422,1.423028,,,,,,,,,,,,,, -23900,1.8364004,1.4229074,,,,,,,,,,,,,, -24000,2.9761422,1.4389402,,,,,,,,,,,,,, -24100,3.1608686,1.4460809,,,,,,,,,,,,,, -24200,5.579229,1.4073203,,,,,,,,,,,,,, -24300,2.2631364,1.3857597,,,,,,,,,,,,,, -24399,,,0.26594213,0.0915554664076686,0.5630687,0.163482240265696,5348.0,0.325007,0.1048280624784189,2472.0,20209.092022895813,22303.220635175705,20209.092022895813,2092.538419485092,0.5695466995239258,0.0 -24400,2.6681778,1.4061543,,,,,,,,,,,,,, -24500,2.4149213,1.4501095,,,,,,,,,,,,,, -24600,2.8140733,1.406144,,,,,,,,,,,,,, -24700,2.9497347,1.4645858,,,,,,,,,,,,,, -24800,3.9533377,1.4593694,,,,,,,,,,,,,, -24900,3.125584,1.4371121,,,,,,,,,,,,,, -25000,3.0044672,1.4106977,,,,,,,,,,,,,, -25100,2.4548528,1.4792429,,,,,,,,,,,,,, -25200,3.1767147,1.4591606,,,,,,,,,,,,,, -25300,2.629584,1.4117563,,,,,,,,,,,,,, -25400,3.3206317,1.4741294,,,,,,,,,,,,,, -25500,2.4725633,1.4221729,,,,,,,,,,,,,, -25600,2.9125445,1.4155815,,,,,,,,,,,,,, -25700,3.9820943,1.4414393,,,,,,,,,,,,,, -25800,2.738249,1.423616,,,,,,,,,,,,,, -25900,3.0008965,1.2970918,,,,,,,,,,,,,, -26000,2.8608608,1.3930043,,,,,,,,,,,,,, -26100,3.6863906,1.3915795,,,,,,,,,,,,,, -26147,,,0.25006896,0.0878218002556143,0.5523809,0.1610009944292652,5348.0,0.3190817,0.102126622387423,2472.0,21649.144829034805,23887.106551885605,21649.144829034805,2236.258577108383,0.6100704669952393,0.0 -26200,2.9692104,1.3997004,,,,,,,,,,,,,, -26300,3.445747,1.4187522,,,,,,,,,,,,,, -26400,2.6487222,1.3682439,,,,,,,,,,,,,, -26500,3.6746843,1.389462,,,,,,,,,,,,,, -26600,3.026982,1.4162189,,,,,,,,,,,,,, -26700,2.4810429,1.4145523,,,,,,,,,,,,,, -26800,2.3346689,1.4185144,,,,,,,,,,,,,, -26900,3.0573618,1.4158947,,,,,,,,,,,,,, -27000,2.8378518,1.4176472,,,,,,,,,,,,,, -27100,2.6296685,1.442471,,,,,,,,,,,,,, -27200,2.8219714,1.3380756,,,,,,,,,,,,,, -27300,3.005007,1.431502,,,,,,,,,,,,,, -27400,2.5603127,1.376673,,,,,,,,,,,,,, -27500,3.62704,1.37332,,,,,,,,,,,,,, -27600,2.7153618,1.4601862,,,,,,,,,,,,,, -27700,3.2421687,1.4400554,,,,,,,,,,,,,, -27800,2.312075,1.3952892,,,,,,,,,,,,,, -27855,,,0.23860557,0.0828977917787483,0.5355645,0.1553530223891404,5348.0,0.30643246,0.0981049296203765,2472.0,23089.323447704315,25468.94979000092,23089.323447704315,2377.8051433563232,0.6545779705047607,0.0 -27900,1.9617211,1.3450061,,,,,,,,,,,,,, -28000,2.6290371,1.3348149,,,,,,,,,,,,,, -28100,2.7481248,1.3686148,,,,,,,,,,,,,, -28200,2.3638368,1.4051263,,,,,,,,,,,,,, -28300,3.2001178,1.4432659,,,,,,,,,,,,,, -28400,3.0057588,1.3362285,,,,,,,,,,,,,, -28500,3.246362,1.3635122,,,,,,,,,,,,,, -28600,2.9293425,1.410387,,,,,,,,,,,,,, -28700,3.5660362,1.3515196,,,,,,,,,,,,,, -28800,2.3356664,1.3664013,,,,,,,,,,,,,, -28900,2.906681,1.3082219,,,,,,,,,,,,,, -29000,2.6756167,1.3269488,,,,,,,,,,,,,, -29100,2.7748246,1.3616874,,,,,,,,,,,,,, -29200,5.1018305,1.286907,,,,,,,,,,,,,, -29300,3.012744,1.4119489,,,,,,,,,,,,,, -29400,3.0713983,1.3359039,,,,,,,,,,,,,, -29500,2.7077575,1.3629388,,,,,,,,,,,,,, -29591,,,0.25175303,0.0871955711561795,0.5208304,0.1510856657366017,5348.0,0.293739,0.0932504620884366,2472.0,24529.774584531784,27043.3413450718,24529.774584531784,2511.6333718299866,0.694037675857544,0.0 -29600,2.622313,1.3263842,,,,,,,,,,,,,, -29700,3.0509713,1.370684,,,,,,,,,,,,,, -29800,3.3946257,1.372145,,,,,,,,,,,,,, -29900,2.084175,1.3832855,,,,,,,,,,,,,, -30000,2.2968462,1.3705137,,,,,,,,,,,,,, -30100,2.4455726,1.3262742,,,,,,,,,,,,,, -30200,2.3168204,1.3620194,,,,,,,,,,,,,, -30300,2.5644534,1.2785313,,,,,,,,,,,,,, -30400,3.1417096,1.3316967,,,,,,,,,,,,,, -30500,2.6724055,1.32944,,,,,,,,,,,,,, -30600,2.2739005,1.3244507,,,,,,,,,,,,,, -30700,2.8282955,1.3468546,,,,,,,,,,,,,, -30800,2.8620255,1.300859,,,,,,,,,,,,,, -30900,2.4128354,1.3126844,,,,,,,,,,,,,, -31000,5.583099,1.3693309,,,,,,,,,,,,,, -31100,5.180118,1.3238306,,,,,,,,,,,,,, -31200,2.65675,1.2809554,,,,,,,,,,,,,, -31300,3.4794664,1.3163223,,,,,,,,,,,,,, -31338,,,0.22907293,0.0768696520328379,0.50683457,0.1489712967164524,5348.0,0.2868147,0.0914427315012288,2472.0,25970.473658800125,28616.918856859207,25970.473658800125,2644.3974990844727,0.7355659008026123,0.0 -31400,1.9990544,1.2685164,,,,,,,,,,,,,, -31500,2.2907214,1.2984111,,,,,,,,,,,,,, -31600,2.1513724,1.2889531,,,,,,,,,,,,,, -31700,3.480705,1.3161876,,,,,,,,,,,,,, -31800,2.8401124,1.3032333,,,,,,,,,,,,,, -31900,2.1415148,1.342192,,,,,,,,,,,,,, -32000,2.1486695,1.2661254,,,,,,,,,,,,,, -32100,2.317701,1.3419299,,,,,,,,,,,,,, -32200,3.128578,1.329806,,,,,,,,,,,,,, -32300,2.5155065,1.3293478,,,,,,,,,,,,,, -32400,2.6224382,1.3034801,,,,,,,,,,,,,, -32500,2.289473,1.3899704,,,,,,,,,,,,,, -32600,4.844293,1.2397038,,,,,,,,,,,,,, -32700,2.1424105,1.2916311,,,,,,,,,,,,,, -32800,4.1060767,1.2225207,,,,,,,,,,,,,, -32900,6.5082107,1.3485986,,,,,,,,,,,,,, -33000,2.791465,1.2804847,,,,,,,,,,,,,, -33063,,,0.23234938,0.0781942682193877,0.49369305,0.1442212074109116,5348.0,0.27371708,0.0880913208620234,2472.0,27410.98273205757,30199.143298387527,27410.98273205757,2785.990561962128,0.7848289012908936,0.0 -33100,2.0394711,1.2791482,,,,,,,,,,,,,, -33200,2.2647169,1.2851173,,,,,,,,,,,,,, -33300,2.8866985,1.2538004,,,,,,,,,,,,,, -33400,2.9484994,1.2927094,,,,,,,,,,,,,, -33500,3.2452948,1.253569,,,,,,,,,,,,,, -33600,2.8443425,1.2901058,,,,,,,,,,,,,, -33700,2.0036118,1.2936203,,,,,,,,,,,,,, -33800,2.694808,1.2439935,,,,,,,,,,,,,, -33900,3.344452,1.2807734,,,,,,,,,,,,,, -34000,1.9728005,1.2563077,,,,,,,,,,,,,, -34100,2.9467108,1.3435208,,,,,,,,,,,,,, -34200,2.7472281,1.2472062,,,,,,,,,,,,,, -34300,3.2390873,1.310446,,,,,,,,,,,,,, -34400,2.1792212,1.2808758,,,,,,,,,,,,,, -34500,2.884527,1.2829825,,,,,,,,,,,,,, -34600,2.6972973,1.2566327,,,,,,,,,,,,,, -34700,3.510596,1.3029586,,,,,,,,,,,,,, -34798,,,0.21853557,0.0723009047603823,0.47827327,0.1389883854523687,5348.0,0.26734018,0.0863851481729734,2472.0,28850.900230646133,31784.88902115822,28850.900230646133,2931.6980545520782,0.8314599990844727,0.0 -34800,2.3027446,1.1757462,,,,,,,,,,,,,, -34900,3.2901912,1.2197984,,,,,,,,,,,,,, -35000,5.290783,1.314693,,,,,,,,,,,,,, -35100,5.5574985,1.2414261,,,,,,,,,,,,,, -35200,2.4503953,1.2825687,,,,,,,,,,,,,, -35300,2.3409882,1.265002,,,,,,,,,,,,,, -35400,3.367438,1.3071125,,,,,,,,,,,,,, -35500,3.604647,1.2670889,,,,,,,,,,,,,, -35600,3.8930695,1.236544,,,,,,,,,,,,,, -35700,1.9747038,1.2688318,,,,,,,,,,,,,, -35800,2.2901242,1.2295489,,,,,,,,,,,,,, -35900,2.9311404,1.3035816,,,,,,,,,,,,,, -36000,3.2252166,1.2401004,,,,,,,,,,,,,, -36100,2.4850903,1.2631385,,,,,,,,,,,,,, -36200,2.295188,1.2040197,,,,,,,,,,,,,, -36300,2.2902963,1.2546362,,,,,,,,,,,,,, -36400,3.4569058,1.2467307,,,,,,,,,,,,,, -36500,2.5841434,1.248954,,,,,,,,,,,,,, -36544,,,0.16265668,0.0572860613142277,0.46007198,0.1338424553713662,5348.0,0.25568554,0.0814291227428757,2472.0,30290.972578525543,33359.12208724022,30290.972578525543,3065.7418246269226,0.8742971420288086,0.0 -36600,6.0826526,1.2331493,,,,,,,,,,,,,, -36700,3.071612,1.2292153,,,,,,,,,,,,,, -36800,2.9856231,1.1965156,,,,,,,,,,,,,, -36900,2.9190373,1.2685637,,,,,,,,,,,,,, -37000,2.5196836,1.2318618,,,,,,,,,,,,,, -37100,3.5471354,1.262705,,,,,,,,,,,,,, -37200,2.1999133,1.2343951,,,,,,,,,,,,,, -37300,2.6187363,1.2023433,,,,,,,,,,,,,, -37400,3.228742,1.2413476,,,,,,,,,,,,,, -37500,2.339581,1.2328259,,,,,,,,,,,,,, -37600,2.2328143,1.231847,,,,,,,,,,,,,, -37700,2.4978206,1.2271618,,,,,,,,,,,,,, -37800,2.3370254,1.2318594,,,,,,,,,,,,,, -37900,3.567681,1.231084,,,,,,,,,,,,,, -38000,2.318111,1.2034572,,,,,,,,,,,,,, -38100,4.2309785,1.1936864,,,,,,,,,,,,,, -38200,5.002096,1.2340355,,,,,,,,,,,,,, -38279,,,0.17948455,0.0625910751620874,0.44596314,0.1296040626780076,5348.0,0.24500446,0.0778339731480917,2472.0,31731.00056815148,34935.57740569115,31731.00056815148,3202.054293870926,0.916539430618286,0.0 -38300,2.4540362,1.2110279,,,,,,,,,,,,,, -38400,3.061014,1.2776378,,,,,,,,,,,,,, -38500,2.0447443,1.1933551,,,,,,,,,,,,,, -38600,2.6994407,1.2117013,,,,,,,,,,,,,, -38700,4.640989,1.235962,,,,,,,,,,,,,, -38800,2.3469787,1.1771516,,,,,,,,,,,,,, -38900,2.0023088,1.1833415,,,,,,,,,,,,,, -39000,3.2004387,1.2119342,,,,,,,,,,,,,, -39100,3.7395744,1.1925658,,,,,,,,,,,,,, -39200,2.5036845,1.1782457,,,,,,,,,,,,,, -39300,4.369343,1.2051771,,,,,,,,,,,,,, -39400,4.192931,1.139362,,,,,,,,,,,,,, -39500,2.6285186,1.1460289,,,,,,,,,,,,,, -39600,4.3503346,1.1402626,,,,,,,,,,,,,, -39700,5.4981294,1.1923167,,,,,,,,,,,,,, -39800,2.1241333,1.228899,,,,,,,,,,,,,, -39900,2.9301333,1.1423795,,,,,,,,,,,,,, -40000,2.5094514,1.2591897,,,,,,,,,,,,,, -40011,,,0.22586584,0.0771792604806821,0.43340683,0.1258580572907112,5348.0,0.23872232,0.0754575183312006,2472.0,33171.87407159805,36508.35352563858,33171.87407159805,3333.839416027069,0.960141897201538,0.0 -40100,2.930879,1.1769029,,,,,,,,,,,,,, -40200,3.4648707,1.1374241,,,,,,,,,,,,,, -40300,2.1627488,1.154166,,,,,,,,,,,,,, -40400,2.058491,1.1517098,,,,,,,,,,,,,, -40500,3.0251768,1.1974852,,,,,,,,,,,,,, -40600,3.5114017,1.1533762,,,,,,,,,,,,,, -40700,3.4058492,1.2207679,,,,,,,,,,,,,, -40800,2.6878552,1.1901397,,,,,,,,,,,,,, -40900,3.7881148,1.1645082,,,,,,,,,,,,,, -41000,2.4954588,1.1824653,,,,,,,,,,,,,, -41100,2.6014888,1.1595216,,,,,,,,,,,,,, -41200,4.42464,1.207358,,,,,,,,,,,,,, -41300,2.4993696,1.1319695,,,,,,,,,,,,,, -41400,3.3255947,1.1563072,,,,,,,,,,,,,, -41500,3.0741386,1.1263539,,,,,,,,,,,,,, -41600,2.870723,1.1418791,,,,,,,,,,,,,, -41700,4.2314734,1.165146,,,,,,,,,,,,,, -41741,,,0.22595026,0.0774299560889767,0.4262832,0.1229326974135184,5348.0,0.22935602,0.0725529624438892,2472.0,34611.97883796692,38082.061220407486,34611.97883796692,3467.3239908218384,1.0061976909637451,0.0 -41800,2.6711578,1.1992155,,,,,,,,,,,,,, -41900,2.1235135,1.1660966,,,,,,,,,,,,,, -42000,2.4595444,1.1385146,,,,,,,,,,,,,, -42100,3.0819874,1.1776451,,,,,,,,,,,,,, -42200,3.8875282,1.1376492,,,,,,,,,,,,,, -42300,3.3006938,1.1425755,,,,,,,,,,,,,, -42400,2.8872132,1.1752217,,,,,,,,,,,,,, -42500,3.882451,1.1265416,,,,,,,,,,,,,, -42600,4.409626,1.1689663,,,,,,,,,,,,,, -42700,4.0265946,1.1823527,,,,,,,,,,,,,, -42800,3.5801375,1.1365832,,,,,,,,,,,,,, -42900,5.309421,1.2115905,,,,,,,,,,,,,, -43000,3.263469,1.1103227,,,,,,,,,,,,,, -43100,3.3265088,1.1215786,,,,,,,,,,,,,, -43200,2.3350148,1.1496419,,,,,,,,,,,,,, -43300,3.5162146,1.1553031,,,,,,,,,,,,,, -43400,2.6319993,1.21091,,,,,,,,,,,,,, -43467,,,0.25700453,0.0887505776509058,0.41875762,0.1210307307606901,5348.0,0.22580484,0.0708671013344707,2472.0,36052.042571783066,39656.57435369492,36052.042571783066,3601.6576216220856,1.0495140552520752,0.0 -43500,2.6424692,1.1589301,,,,,,,,,,,,,, -43600,6.58885,1.1681452,,,,,,,,,,,,,, -43700,2.7628212,1.1225097,,,,,,,,,,,,,, -43800,4.6934924,1.1323301,,,,,,,,,,,,,, -43900,2.7859702,1.112424,,,,,,,,,,,,,, -44000,2.5670629,1.1583743,,,,,,,,,,,,,, -44100,4.9845433,1.1522328,,,,,,,,,,,,,, -44200,3.4600668,1.1265408,,,,,,,,,,,,,, -44300,2.8512287,1.1514573,,,,,,,,,,,,,, -44400,3.3035421,1.1015515,,,,,,,,,,,,,, -44500,4.088741,1.1298566,,,,,,,,,,,,,, -44600,3.5406582,1.1661096,,,,,,,,,,,,,, -44700,3.4848347,1.1737458,,,,,,,,,,,,,, -44800,3.952312,1.113825,,,,,,,,,,,,,, -44900,2.7267766,1.1668669,,,,,,,,,,,,,, -45000,3.477577,1.1397711,,,,,,,,,,,,,, -45100,3.0088553,1.1297435,,,,,,,,,,,,,, -45193,,,0.22935137,0.0772333072317175,0.41290858,0.119177037373162,5348.0,0.22298104,0.0698108991936302,2472.0,37492.52834439278,41229.34570097923,37492.52834439278,3733.827961683273,1.090902328491211,0.0 -45200,3.1390893,1.1680017,,,,,,,,,,,,,, -45300,3.4304216,1.1207303,,,,,,,,,,,,,, -45400,5.916537,1.1648191,,,,,,,,,,,,,, -45500,3.4845665,1.1247002,,,,,,,,,,,,,, -45600,2.9387693,1.1083219,,,,,,,,,,,,,, -45700,2.639355,1.12755,,,,,,,,,,,,,, -45800,3.8680372,1.1408652,,,,,,,,,,,,,, -45900,2.9754233,1.1581833,,,,,,,,,,,,,, -46000,5.6245823,1.1318902,,,,,,,,,,,,,, -46100,2.6100497,1.1253254,,,,,,,,,,,,,, -46200,4.9462533,1.134572,,,,,,,,,,,,,, -46300,4.855156,1.1136358,,,,,,,,,,,,,, -46400,12.245806,1.1499991,,,,,,,,,,,,,, -46500,3.2885487,1.1065173,,,,,,,,,,,,,, -46600,3.493686,1.1556406,,,,,,,,,,,,,, -46700,2.7600076,1.1425848,,,,,,,,,,,,,, -46800,2.3189957,1.1270525,,,,,,,,,,,,,, -46900,2.9203556,1.1590859,,,,,,,,,,,,,, -46925,,,0.21121906,0.0731408593410753,0.4120718,0.11874257798546,5348.0,0.22249565,0.0699530802510511,2472.0,38932.92411899567,42802.82615447045,38932.92411899567,3866.790652275085,1.1397252082824707,0.0 -47000,2.9701395,1.0734375,,,,,,,,,,,,,, -47100,3.1440537,1.1739132,,,,,,,,,,,,,, -47200,3.549214,1.139631,,,,,,,,,,,,,, -47300,3.693375,1.1334723,,,,,,,,,,,,,, -47400,4.995905,1.1784284,,,,,,,,,,,,,, -47500,3.4774053,1.1848395,,,,,,,,,,,,,, -47600,3.273864,1.1100976,,,,,,,,,,,,,, -47700,3.2373576,1.1603566,,,,,,,,,,,,,, -47800,3.8737817,1.0448327,,,,,,,,,,,,,, -47900,3.986093,1.1565086,,,,,,,,,,,,,, -48000,,,0.19107874,0.0672709871776745,0.4123398,0.1190998001486816,5348.0,0.22241092,0.0701561960473666,2472.0,39802.3655128479,43806.46564793587,39802.3655128479,4000.894311904907,1.186645746231079,0.0 -48000,,,,,,,,,,,39802.3655128479,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 1ff031863..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,30 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -179.37396621704102,0.0,15.712034225463867,1,0,15.712034225463867,30.657555,2472,3.2318363699144883,195.0860705375672,31.744686,3.086809647688824,30.570248,5348,2.9115054500516524 -289.1729745864868,0.0313925743103027,1456.2849533557892,1755,0,1456.2849533557892,6.4480104,2472,0.8991123839701014,1745.5630271434784,6.613655,0.9419290070434632,6.4508014,5348,0.8961545516861852 -412.4362471103668,0.0835540294647216,2896.8062427043915,3499,0,2896.8062427043915,3.2973654,2472,0.6442629943330693,3309.478348255157,3.93307,0.757276603130865,3.7532241,5348,0.7071840273419775 -546.3971445560455,0.1400220394134521,4337.244078874588,5216,0,4337.244078874588,0.6439003,2472,0.2018564783783235,4884.011987924576,0.9103338,0.2712767425810904,0.9936676,5348,0.2741631829460208 -677.8204755783081,0.1918311119079589,5777.536859035492,6949,0,5777.536859035492,0.50204223,2472,0.1613551885930168,6455.857183456421,0.6986057,0.220888705744995,0.8042088,5348,0.230369676665669 -815.2142312526703,0.2425811290740966,7217.844273805618,8679,0,7217.844273805618,0.4397295,2472,0.1438466069506225,8033.686626672745,0.59292424,0.1918774981213038,0.7251276,5348,0.2104907460150419 -949.255347251892,0.2980294227600097,8658.25753068924,10387,0,8658.25753068924,0.55918086,2472,0.1789653281335689,9608.27274942398,0.77348566,0.2425348039545362,0.9241919,5348,0.2576344169072284 -1082.87002658844,0.3533177375793457,10098.338337421415,12106,0,10098.338337421415,0.37781128,2472,0.1232506652042329,11182.101341962814,0.49431863,0.1640443156405458,0.6491058,5348,0.1874064705484808 -1217.89963889122,0.4057838916778564,11538.243397474287,13829,0,11538.243397474287,0.3647178,2472,0.1195133345520281,12757.166206598282,0.45966643,0.1526637057536937,0.6343328,5348,0.1831004952837019 -1353.5177383422852,0.4561934471130371,12978.453676700592,15544,0,12978.453676700592,0.34296244,2472,0.1111449637438303,14333.121833562853,0.4444199,0.145305374152371,0.5959213,5348,0.1713025092443303 -1488.453429937363,0.5051791667938232,14418.44036602974,17258,0,14418.44036602974,0.3266319,2472,0.1062295614729957,15908.17206454277,0.41884962,0.1392690235012633,0.58758664,5348,0.1685798970813983 -1624.9235422611237,0.5569183826446533,15858.892181396484,18996,0,15858.892181396484,0.31985068,2472,0.105945199358154,17485.223326921463,0.41031313,0.1389371617059429,0.5695679,5348,0.163482240265696 -1758.560397386551,0.6087489128112793,17299.33094882965,20715,0,17299.33094882965,0.30855164,2472,0.1005016960168992,19059.429202079773,0.36195385,0.1216861816105341,0.5525102,5348,0.1612809793680064 -1892.2922401428225,0.6600890159606934,18739.416509628296,22428,0,18739.416509628296,0.30128074,2472,0.0965003148294842,20633.377140283585,0.3208047,0.1074748092640736,0.53712636,5348,0.1548992536953184 -2025.5153098106384,0.7153546810150146,20179.82127547264,24166,0,20179.82127547264,0.289528,2472,0.093412954725489,22207.138612270355,0.36535513,0.1245470043013423,0.52740115,5348,0.1513463413692229 -2159.8889672756195,0.7683267593383789,21620.019072532654,25875,0,21620.019072532654,0.2792349,2472,0.0916864704568074,23781.841695070267,0.32537457,0.1094823663253697,0.51202977,5348,0.1477065371655869 -2297.339144229889,0.822613000869751,23060.151054382324,27568,0,23060.151054382324,0.2763063,2472,0.0893506387991794,25359.555035591125,0.31351584,0.1057912797410919,0.5010615,5348,0.1443370632476322 -2430.969462156296,0.8745872974395752,24501.43314099312,29314,0,24501.43314099312,0.26795527,2472,0.0859789165803424,26934.59812092781,0.28522658,0.0990947228968867,0.48595044,5348,0.140871042799077 -2564.800893306732,0.9301862716674804,25941.97016787529,31036,0,25941.97016787529,0.26138502,2472,0.083886823878293,28509.099543333054,0.26872805,0.0943131977384197,0.47574943,5348,0.1365940314934783 -2698.872343540192,0.9823896884918212,27382.09087133408,32757,0,27382.09087133408,0.24835972,2472,0.0814494343225072,30083.422049045563,0.2560898,0.0891313082388943,0.46562022,5348,0.1335914343918051 -2857.1478266716003,1.0384981632232666,28822.69719481468,34518,0,28822.69719481468,0.2446342,2472,0.0796010805760364,31682.440099477768,0.15375972,0.0547525930964121,0.45650312,5348,0.130608146596252 -2992.0138483047485,1.0968949794769287,30263.501261234283,36248,0,30263.501261234283,0.23961341,2472,0.0775902341925131,33258.248254299164,0.14782521,0.0519759160073353,0.44794834,5348,0.1287254892495438 -3128.327612876892,1.143108367919922,31703.481546401978,37941,0,31703.481546401978,0.23499914,2472,0.0761684236183048,34834.663165569305,0.15493704,0.0548058881048653,0.44182682,5348,0.1258580572907112 -3263.819913864136,1.1951828002929688,33144.05407190323,39661,0,33144.05407190323,0.23229034,2472,0.0749294172607803,36410.86017107964,0.13454847,0.0484128474830954,0.4356194,5348,0.1246512256582059 -3398.2206456661224,1.250486135482788,34584.85236620903,41377,0,34584.85236620903,0.2291814,2472,0.0734060487884142,37986.19295430184,0.14412792,0.0515004984779097,0.43128958,5348,0.1231161358216592 -3533.5720529556274,1.3105106353759766,36025.78731369972,43071,0,36025.78731369972,0.22690861,2472,0.0733451140495196,39562.61579680443,0.1331379,0.0482064816694626,0.42728096,5348,0.1221989437809552 -3669.3792510032654,1.3633804321289062,37466.45243930817,44804,0,37466.45243930817,0.22530259,2472,0.0731216866735726,41139.22075676918,0.15674452,0.0525206026268349,0.42669344,5348,0.1222665263523755 -3807.438556671143,1.4167790412902832,38906.67197370529,46503,0,38906.67197370529,0.22481138,2472,0.0725326508642577,42717.63082480431,0.14168745,0.0509011874469889,0.42561495,5348,0.1211658959035307 -3942.461863040924,1.4705607891082764,40174.220403671265,48000,0,40174.220403671265,0.22480996,2472,0.07304044035504641,44120.32517743111,0.121190354,0.04460789943515802,0.42536923,5348,0.12134933431167151 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/measurements.csv deleted file mode 100644 index 628661d06..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/measurements.csv +++ /dev/null @@ -1,511 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,19.677786,33.113293,,,,,,,,,,,,,, -1,,,31.744686,3.086809647688824,30.570248,2.9115054500516524,5348.0,30.657555,3.2318363699144883,2472.0,15.712034225463867,195.0860705375672,15.712034225463867,179.37396621704102,0.0,0.0 -100,2.828709,7.5532155,,,,,,,,,,,,,, -200,1.2264386,5.994002,,,,,,,,,,,,,, -300,0.7438927,5.8499146,,,,,,,,,,,,,, -400,0.55252963,5.825764,,,,,,,,,,,,,, -500,1.8384137,5.79629,,,,,,,,,,,,,, -600,0.63845867,5.6488624,,,,,,,,,,,,,, -700,0.708019,5.50197,,,,,,,,,,,,,, -800,1.6105295,5.2290354,,,,,,,,,,,,,, -900,1.7680719,4.784316,,,,,,,,,,,,,, -1000,1.9115158,4.3328743,,,,,,,,,,,,,, -1100,2.0945034,3.981099,,,,,,,,,,,,,, -1200,3.1197064,3.6544752,,,,,,,,,,,,,, -1300,2.698316,3.4632182,,,,,,,,,,,,,, -1400,2.9865665,3.3184807,,,,,,,,,,,,,, -1500,2.400276,3.1751165,,,,,,,,,,,,,, -1600,3.7757635,3.0617611,,,,,,,,,,,,,, -1700,2.6201658,2.9204938,,,,,,,,,,,,,, -1755,,,6.613655,0.9419290070434632,6.4508014,0.8961545516861852,5348.0,6.4480104,0.8991123839701014,2472.0,1456.2849533557892,1745.5630271434784,1456.2849533557892,289.1729745864868,0.0313925743103027,0.0 -1800,3.4410028,2.8721375,,,,,,,,,,,,,, -1900,2.0985782,2.754826,,,,,,,,,,,,,, -2000,2.7927964,2.6798851,,,,,,,,,,,,,, -2100,3.2281964,2.5789962,,,,,,,,,,,,,, -2200,4.1877575,2.578802,,,,,,,,,,,,,, -2300,2.539874,2.5235887,,,,,,,,,,,,,, -2400,2.158878,2.4076874,,,,,,,,,,,,,, -2500,2.993849,2.459443,,,,,,,,,,,,,, -2600,2.9981806,2.3664196,,,,,,,,,,,,,, -2700,3.9879615,2.3459938,,,,,,,,,,,,,, -2800,2.8452532,2.2829437,,,,,,,,,,,,,, -2900,2.6873164,2.2446597,,,,,,,,,,,,,, -3000,2.9936445,2.181462,,,,,,,,,,,,,, -3100,3.2759442,2.2156901,,,,,,,,,,,,,, -3200,3.2302985,2.1734312,,,,,,,,,,,,,, -3300,3.4949105,2.1224887,,,,,,,,,,,,,, -3400,9.99699,2.2165022,,,,,,,,,,,,,, -3499,,,3.93307,0.757276603130865,3.7532241,0.7071840273419775,5348.0,3.2973654,0.6442629943330693,2472.0,2896.8062427043915,3309.478348255157,2896.8062427043915,412.4362471103668,0.0835540294647216,0.0 -3500,3.6754167,2.0115163,,,,,,,,,,,,,, -3600,3.6798744,2.0883558,,,,,,,,,,,,,, -3700,3.563343,2.0300689,,,,,,,,,,,,,, -3800,1.9958775,2.035046,,,,,,,,,,,,,, -3900,2.360428,2.0066648,,,,,,,,,,,,,, -4000,2.8578434,1.9933659,,,,,,,,,,,,,, -4100,3.647795,2.0408714,,,,,,,,,,,,,, -4200,2.5748074,1.967792,,,,,,,,,,,,,, -4300,2.4153259,1.9267977,,,,,,,,,,,,,, -4400,4.5044117,1.8968658,,,,,,,,,,,,,, -4500,2.5975096,1.9656748,,,,,,,,,,,,,, -4600,2.6205206,1.9313674,,,,,,,,,,,,,, -4700,2.3881202,1.9323721,,,,,,,,,,,,,, -4800,3.3698835,1.9234667,,,,,,,,,,,,,, -4900,4.368038,1.8696454,,,,,,,,,,,,,, -5000,2.7817852,1.8448032,,,,,,,,,,,,,, -5100,3.094404,1.8611306,,,,,,,,,,,,,, -5200,2.0683815,1.8018689,,,,,,,,,,,,,, -5216,,,0.9103338,0.2712767425810904,0.9936676,0.2741631829460208,5348.0,0.6439003,0.2018564783783235,2472.0,4337.244078874588,4884.011987924576,4337.244078874588,546.3971445560455,0.1400220394134521,0.0 -5300,2.910766,1.7334154,,,,,,,,,,,,,, -5400,4.0617547,1.814947,,,,,,,,,,,,,, -5500,3.623837,1.8338654,,,,,,,,,,,,,, -5600,2.6921537,1.8152596,,,,,,,,,,,,,, -5700,4.3820844,1.7683051,,,,,,,,,,,,,, -5800,3.8817642,1.8469367,,,,,,,,,,,,,, -5900,2.9042428,1.8146663,,,,,,,,,,,,,, -6000,4.07275,1.8052889,,,,,,,,,,,,,, -6100,3.607179,1.8160723,,,,,,,,,,,,,, -6200,3.0581558,1.7114855,,,,,,,,,,,,,, -6300,2.582204,1.7424655,,,,,,,,,,,,,, -6400,3.427089,1.721512,,,,,,,,,,,,,, -6500,3.1314352,1.6928271,,,,,,,,,,,,,, -6600,4.5555625,1.7102487,,,,,,,,,,,,,, -6700,2.8786407,1.7158235,,,,,,,,,,,,,, -6800,2.888451,1.7098264,,,,,,,,,,,,,, -6900,2.3518672,1.7457937,,,,,,,,,,,,,, -6949,,,0.6986057,0.220888705744995,0.8042088,0.230369676665669,5348.0,0.50204223,0.1613551885930168,2472.0,5777.536859035492,6455.857183456421,5777.536859035492,677.8204755783081,0.1918311119079589,0.0 -7000,2.8376625,1.7253087,,,,,,,,,,,,,, -7100,2.1902308,1.6207807,,,,,,,,,,,,,, -7200,2.4919493,1.7747682,,,,,,,,,,,,,, -7300,2.6630163,1.6948832,,,,,,,,,,,,,, -7400,4.0960097,1.6344655,,,,,,,,,,,,,, -7500,2.8951998,1.6535102,,,,,,,,,,,,,, -7600,3.5569346,1.6427115,,,,,,,,,,,,,, -7700,2.5945961,1.6629019,,,,,,,,,,,,,, -7800,2.3481655,1.6321409,,,,,,,,,,,,,, -7900,3.5978038,1.6604527,,,,,,,,,,,,,, -8000,1.8803147,1.6393743,,,,,,,,,,,,,, -8100,5.1544952,1.6948421,,,,,,,,,,,,,, -8200,2.1010678,1.5913533,,,,,,,,,,,,,, -8300,2.476037,1.7102789,,,,,,,,,,,,,, -8400,2.1151152,1.6610848,,,,,,,,,,,,,, -8500,3.2749522,1.5891598,,,,,,,,,,,,,, -8600,2.416307,1.605862,,,,,,,,,,,,,, -8679,,,0.59292424,0.1918774981213038,0.7251276,0.2104907460150419,5348.0,0.4397295,0.1438466069506225,2472.0,7217.844273805618,8033.686626672745,7217.844273805618,815.2142312526703,0.2425811290740966,0.0 -8700,2.799989,1.6319022,,,,,,,,,,,,,, -8800,2.6533926,1.654345,,,,,,,,,,,,,, -8900,3.1612532,1.6628052,,,,,,,,,,,,,, -9000,2.9863837,1.6138437,,,,,,,,,,,,,, -9100,2.7629364,1.5802381,,,,,,,,,,,,,, -9200,2.5824172,1.6109743,,,,,,,,,,,,,, -9300,2.7699096,1.6156554,,,,,,,,,,,,,, -9400,2.5944512,1.6505808,,,,,,,,,,,,,, -9500,2.882657,1.6321241,,,,,,,,,,,,,, -9600,2.5693228,1.4754847,,,,,,,,,,,,,, -9700,2.886591,1.6278926,,,,,,,,,,,,,, -9800,3.215513,1.5719508,,,,,,,,,,,,,, -9900,3.0641105,1.5121939,,,,,,,,,,,,,, -10000,3.1508648,1.5863936,,,,,,,,,,,,,, -10100,2.7210362,1.5689584,,,,,,,,,,,,,, -10200,2.879592,1.5665672,,,,,,,,,,,,,, -10300,4.2972627,1.5648971,,,,,,,,,,,,,, -10387,,,0.77348566,0.2425348039545362,0.9241919,0.2576344169072284,5348.0,0.55918086,0.1789653281335689,2472.0,8658.25753068924,9608.27274942398,8658.25753068924,949.255347251892,0.2980294227600097,0.0 -10400,2.0187957,1.8136827,,,,,,,,,,,,,, -10500,2.639669,1.6324365,,,,,,,,,,,,,, -10600,2.2579389,1.6101044,,,,,,,,,,,,,, -10700,1.8786635,1.6901581,,,,,,,,,,,,,, -10800,2.449849,1.6380996,,,,,,,,,,,,,, -10900,2.3720977,1.5503936,,,,,,,,,,,,,, -11000,3.0653162,1.6100415,,,,,,,,,,,,,, -11100,3.0399497,1.5536163,,,,,,,,,,,,,, -11200,1.6165869,1.5974101,,,,,,,,,,,,,, -11300,2.3634574,1.5911164,,,,,,,,,,,,,, -11400,1.7929463,1.563373,,,,,,,,,,,,,, -11500,2.8841417,1.5778171,,,,,,,,,,,,,, -11600,2.3764687,1.5812062,,,,,,,,,,,,,, -11700,2.1056128,1.475165,,,,,,,,,,,,,, -11800,2.1221116,1.6393331,,,,,,,,,,,,,, -11900,2.5312479,1.5308322,,,,,,,,,,,,,, -12000,2.1642787,1.5394818,,,,,,,,,,,,,, -12100,2.7224538,1.5634809,,,,,,,,,,,,,, -12106,,,0.49431863,0.1640443156405458,0.6491058,0.1874064705484808,5348.0,0.37781128,0.1232506652042329,2472.0,10098.338337421415,11182.101341962814,10098.338337421415,1082.87002658844,0.3533177375793457,0.0 -12200,3.12497,1.5233773,,,,,,,,,,,,,, -12300,2.156312,1.5193655,,,,,,,,,,,,,, -12400,3.1866994,1.5147282,,,,,,,,,,,,,, -12500,2.9084055,1.4963932,,,,,,,,,,,,,, -12600,2.5644937,1.5068563,,,,,,,,,,,,,, -12700,2.9253602,1.5854989,,,,,,,,,,,,,, -12800,2.441733,1.552332,,,,,,,,,,,,,, -12900,3.9497333,1.5319285,,,,,,,,,,,,,, -13000,3.2603476,1.4426527,,,,,,,,,,,,,, -13100,1.5423512,1.4583951,,,,,,,,,,,,,, -13200,2.9026625,1.570402,,,,,,,,,,,,,, -13300,2.4794185,1.4809426,,,,,,,,,,,,,, -13400,1.8377651,1.5009577,,,,,,,,,,,,,, -13500,2.9324636,1.444431,,,,,,,,,,,,,, -13600,3.023348,1.4405404,,,,,,,,,,,,,, -13700,3.6851246,1.4746114,,,,,,,,,,,,,, -13800,3.7507732,1.469154,,,,,,,,,,,,,, -13829,,,0.45966643,0.1526637057536937,0.6343328,0.1831004952837019,5348.0,0.3647178,0.1195133345520281,2472.0,11538.243397474287,12757.166206598282,11538.243397474287,1217.89963889122,0.4057838916778564,0.0 -13900,2.8877535,1.577677,,,,,,,,,,,,,, -14000,2.3800805,1.5151404,,,,,,,,,,,,,, -14100,3.2288244,1.4430664,,,,,,,,,,,,,, -14200,2.70631,1.4305296,,,,,,,,,,,,,, -14300,3.7291436,1.5058826,,,,,,,,,,,,,, -14400,2.1525342,1.5190816,,,,,,,,,,,,,, -14500,2.9267583,1.4531478,,,,,,,,,,,,,, -14600,2.6933837,1.4877455,,,,,,,,,,,,,, -14700,2.3140726,1.4918388,,,,,,,,,,,,,, -14800,2.9076602,1.5082002,,,,,,,,,,,,,, -14900,2.3223014,1.4874111,,,,,,,,,,,,,, -15000,2.6315625,1.5388402,,,,,,,,,,,,,, -15100,4.498767,1.4613117,,,,,,,,,,,,,, -15200,2.2295272,1.4739305,,,,,,,,,,,,,, -15300,2.6934829,1.4974176,,,,,,,,,,,,,, -15400,2.0164037,1.479274,,,,,,,,,,,,,, -15500,2.3907993,1.4511578,,,,,,,,,,,,,, -15544,,,0.4444199,0.145305374152371,0.5959213,0.1713025092443303,5348.0,0.34296244,0.1111449637438303,2472.0,12978.453676700592,14333.121833562853,12978.453676700592,1353.5177383422852,0.4561934471130371,0.0 -15600,2.5387244,1.4765787,,,,,,,,,,,,,, -15700,2.5395107,1.5253259,,,,,,,,,,,,,, -15800,2.3069162,1.4544992,,,,,,,,,,,,,, -15900,1.848484,1.4919292,,,,,,,,,,,,,, -16000,2.0700507,1.4799707,,,,,,,,,,,,,, -16100,2.602657,1.399082,,,,,,,,,,,,,, -16200,2.2710001,1.4159137,,,,,,,,,,,,,, -16300,2.3057828,1.4808236,,,,,,,,,,,,,, -16400,2.957977,1.3925705,,,,,,,,,,,,,, -16500,2.370047,1.4321516,,,,,,,,,,,,,, -16600,3.5080922,1.447274,,,,,,,,,,,,,, -16700,2.2558155,1.4535959,,,,,,,,,,,,,, -16800,2.620765,1.4221115,,,,,,,,,,,,,, -16900,5.1187315,1.4006114,,,,,,,,,,,,,, -17000,1.9260873,1.4445913,,,,,,,,,,,,,, -17100,2.8313518,1.389603,,,,,,,,,,,,,, -17200,1.9363438,1.4426816,,,,,,,,,,,,,, -17258,,,0.41884962,0.1392690235012633,0.58758664,0.1685798970813983,5348.0,0.3266319,0.1062295614729957,2472.0,14418.44036602974,15908.17206454277,14418.44036602974,1488.453429937363,0.5051791667938232,0.0 -17300,2.8557518,1.4203938,,,,,,,,,,,,,, -17400,2.3860717,1.4251715,,,,,,,,,,,,,, -17500,2.6525662,1.4453863,,,,,,,,,,,,,, -17600,1.9907801,1.4145257,,,,,,,,,,,,,, -17700,1.4688741,1.4214739,,,,,,,,,,,,,, -17800,2.062885,1.4250816,,,,,,,,,,,,,, -17900,2.587472,1.3722831,,,,,,,,,,,,,, -18000,2.4650385,1.4547697,,,,,,,,,,,,,, -18100,3.0191069,1.3866061,,,,,,,,,,,,,, -18200,2.701054,1.4614997,,,,,,,,,,,,,, -18300,2.7116752,1.4102967,,,,,,,,,,,,,, -18400,2.9427507,1.3920008,,,,,,,,,,,,,, -18500,2.8539536,1.4660437,,,,,,,,,,,,,, -18600,3.4413602,1.4117785,,,,,,,,,,,,,, -18700,3.0512896,1.3946037,,,,,,,,,,,,,, -18800,2.60187,1.3900162,,,,,,,,,,,,,, -18900,3.2388737,1.4375252,,,,,,,,,,,,,, -18996,,,0.41031313,0.1389371617059429,0.5695679,0.163482240265696,5348.0,0.31985068,0.105945199358154,2472.0,15858.892181396484,17485.223326921463,15858.892181396484,1624.9235422611237,0.5569183826446533,0.0 -19000,2.0895255,1.3624859,,,,,,,,,,,,,, -19100,2.670027,1.415209,,,,,,,,,,,,,, -19200,2.6165023,1.4696637,,,,,,,,,,,,,, -19300,1.6435561,1.3850341,,,,,,,,,,,,,, -19400,1.7738789,1.4012065,,,,,,,,,,,,,, -19500,2.9815211,1.4197357,,,,,,,,,,,,,, -19600,3.163011,1.4304323,,,,,,,,,,,,,, -19700,2.5143466,1.3791935,,,,,,,,,,,,,, -19800,1.8005875,1.3663874,,,,,,,,,,,,,, -19900,2.9029176,1.4217498,,,,,,,,,,,,,, -20000,2.2003133,1.4005036,,,,,,,,,,,,,, -20100,1.9090357,1.3345087,,,,,,,,,,,,,, -20200,4.0750265,1.3834927,,,,,,,,,,,,,, -20300,3.0226488,1.4207976,,,,,,,,,,,,,, -20400,2.757402,1.4399939,,,,,,,,,,,,,, -20500,2.3410497,1.4369836,,,,,,,,,,,,,, -20600,1.6501148,1.30576,,,,,,,,,,,,,, -20700,1.9857874,1.3901963,,,,,,,,,,,,,, -20715,,,0.36195385,0.1216861816105341,0.5525102,0.1612809793680064,5348.0,0.30855164,0.1005016960168992,2472.0,17299.33094882965,19059.429202079773,17299.33094882965,1758.560397386551,0.6087489128112793,0.0 -20800,3.0568383,1.372151,,,,,,,,,,,,,, -20900,2.6085498,1.3397197,,,,,,,,,,,,,, -21000,2.743869,1.3480089,,,,,,,,,,,,,, -21100,2.3050885,1.3568681,,,,,,,,,,,,,, -21200,2.9238958,1.3683783,,,,,,,,,,,,,, -21300,2.0591457,1.4016834,,,,,,,,,,,,,, -21400,2.9719076,1.4138743,,,,,,,,,,,,,, -21500,3.3536298,1.3979635,,,,,,,,,,,,,, -21600,1.635939,1.3627174,,,,,,,,,,,,,, -21700,3.700651,1.3075383,,,,,,,,,,,,,, -21800,2.389947,1.3747578,,,,,,,,,,,,,, -21900,3.2936969,1.3644531,,,,,,,,,,,,,, -22000,2.6713157,1.4575042,,,,,,,,,,,,,, -22100,3.0761003,1.3391739,,,,,,,,,,,,,, -22200,2.2485368,1.3660501,,,,,,,,,,,,,, -22300,3.2506747,1.4724302,,,,,,,,,,,,,, -22400,2.312461,1.3802922,,,,,,,,,,,,,, -22428,,,0.3208047,0.1074748092640736,0.53712636,0.1548992536953184,5348.0,0.30128074,0.0965003148294842,2472.0,18739.416509628296,20633.377140283585,18739.416509628296,1892.2922401428225,0.6600890159606934,0.0 -22500,2.1530025,1.3607806,,,,,,,,,,,,,, -22600,1.5858432,1.4101988,,,,,,,,,,,,,, -22700,2.5993834,1.3068322,,,,,,,,,,,,,, -22800,2.2161164,1.4173663,,,,,,,,,,,,,, -22900,2.3050752,1.3687017,,,,,,,,,,,,,, -23000,1.7631385,1.2900056,,,,,,,,,,,,,, -23100,1.9942255,1.3500438,,,,,,,,,,,,,, -23200,3.395731,1.3624701,,,,,,,,,,,,,, -23300,2.9572773,1.3439021,,,,,,,,,,,,,, -23400,2.680253,1.3271712,,,,,,,,,,,,,, -23500,2.9328637,1.349669,,,,,,,,,,,,,, -23600,3.0029483,1.3811555,,,,,,,,,,,,,, -23700,4.947315,1.3603315,,,,,,,,,,,,,, -23800,2.2942212,1.351987,,,,,,,,,,,,,, -23900,2.017989,1.3034978,,,,,,,,,,,,,, -24000,1.7391828,1.2932526,,,,,,,,,,,,,, -24100,1.8457447,1.3148241,,,,,,,,,,,,,, -24166,,,0.36535513,0.1245470043013423,0.52740115,0.1513463413692229,5348.0,0.289528,0.093412954725489,2472.0,20179.82127547264,22207.138612270355,20179.82127547264,2025.5153098106384,0.7153546810150146,0.0 -24200,1.7392578,1.334034,,,,,,,,,,,,,, -24300,1.8500317,1.321653,,,,,,,,,,,,,, -24400,2.2664907,1.3104942,,,,,,,,,,,,,, -24500,4.0119567,1.3781865,,,,,,,,,,,,,, -24600,1.9544178,1.3324808,,,,,,,,,,,,,, -24700,4.1731973,1.3886558,,,,,,,,,,,,,, -24800,2.5806081,1.347228,,,,,,,,,,,,,, -24900,2.4006548,1.3226829,,,,,,,,,,,,,, -25000,2.3073115,1.3269982,,,,,,,,,,,,,, -25100,3.9439926,1.3520396,,,,,,,,,,,,,, -25200,2.1771622,1.3505595,,,,,,,,,,,,,, -25300,2.8502119,1.3596206,,,,,,,,,,,,,, -25400,1.7914598,1.3203515,,,,,,,,,,,,,, -25500,2.3670175,1.3390763,,,,,,,,,,,,,, -25600,1.9442538,1.2836717,,,,,,,,,,,,,, -25700,1.9639853,1.3527873,,,,,,,,,,,,,, -25800,2.6662235,1.2913002,,,,,,,,,,,,,, -25875,,,0.32537457,0.1094823663253697,0.51202977,0.1477065371655869,5348.0,0.2792349,0.0916864704568074,2472.0,21620.019072532654,23781.841695070267,21620.019072532654,2159.8889672756195,0.7683267593383789,0.0 -25900,2.006784,1.2182527,,,,,,,,,,,,,, -26000,4.2730813,1.2838829,,,,,,,,,,,,,, -26100,2.2443237,1.3233148,,,,,,,,,,,,,, -26200,2.7729642,1.3496261,,,,,,,,,,,,,, -26300,2.6964674,1.3272083,,,,,,,,,,,,,, -26400,1.6569674,1.2714618,,,,,,,,,,,,,, -26500,3.5487933,1.3578414,,,,,,,,,,,,,, -26600,2.209597,1.2912453,,,,,,,,,,,,,, -26700,2.6586246,1.336998,,,,,,,,,,,,,, -26800,1.9079564,1.2977656,,,,,,,,,,,,,, -26900,5.0043077,1.2827859,,,,,,,,,,,,,, -27000,2.1826682,1.3105835,,,,,,,,,,,,,, -27100,1.7912506,1.3710911,,,,,,,,,,,,,, -27200,1.8299321,1.302163,,,,,,,,,,,,,, -27300,3.32904,1.3154553,,,,,,,,,,,,,, -27400,4.0361056,1.2993432,,,,,,,,,,,,,, -27500,2.3981442,1.2795919,,,,,,,,,,,,,, -27568,,,0.31351584,0.1057912797410919,0.5010615,0.1443370632476322,5348.0,0.2763063,0.0893506387991794,2472.0,23060.151054382324,25359.555035591125,23060.151054382324,2297.339144229889,0.822613000869751,0.0 -27600,1.4044367,1.2794192,,,,,,,,,,,,,, -27700,2.0743456,1.310129,,,,,,,,,,,,,, -27800,1.8473401,1.252351,,,,,,,,,,,,,, -27900,2.3203986,1.3101698,,,,,,,,,,,,,, -28000,2.6193135,1.2575133,,,,,,,,,,,,,, -28100,2.0412323,1.2849247,,,,,,,,,,,,,, -28200,2.7488024,1.2935765,,,,,,,,,,,,,, -28300,1.8157102,1.2981541,,,,,,,,,,,,,, -28400,3.5412428,1.2540247,,,,,,,,,,,,,, -28500,2.3532171,1.2952158,,,,,,,,,,,,,, -28600,2.9830997,1.2147781,,,,,,,,,,,,,, -28700,2.9903433,1.3441592,,,,,,,,,,,,,, -28800,3.064296,1.234493,,,,,,,,,,,,,, -28900,2.7000701,1.2753017,,,,,,,,,,,,,, -29000,2.9506445,1.2491214,,,,,,,,,,,,,, -29100,3.6067243,1.2529385,,,,,,,,,,,,,, -29200,2.2858162,1.3012503,,,,,,,,,,,,,, -29300,2.4448938,1.2717534,,,,,,,,,,,,,, -29314,,,0.28522658,0.0990947228968867,0.48595044,0.140871042799077,5348.0,0.26795527,0.0859789165803424,2472.0,24501.43314099312,26934.59812092781,24501.43314099312,2430.969462156296,0.8745872974395752,0.0 -29400,1.9649708,1.2649113,,,,,,,,,,,,,, -29500,2.718895,1.2705957,,,,,,,,,,,,,, -29600,2.182384,1.2191149,,,,,,,,,,,,,, -29700,2.5936244,1.1747241,,,,,,,,,,,,,, -29800,2.9841356,1.3107427,,,,,,,,,,,,,, -29900,1.644564,1.2498055,,,,,,,,,,,,,, -30000,2.1201935,1.2644664,,,,,,,,,,,,,, -30100,2.0871453,1.2577584,,,,,,,,,,,,,, -30200,1.9328455,1.2339371,,,,,,,,,,,,,, -30300,2.2941322,1.228933,,,,,,,,,,,,,, -30400,1.9754772,1.2506948,,,,,,,,,,,,,, -30500,2.4581163,1.2259432,,,,,,,,,,,,,, -30600,1.855134,1.2094412,,,,,,,,,,,,,, -30700,5.8239393,1.1955278,,,,,,,,,,,,,, -30800,3.7600849,1.2151266,,,,,,,,,,,,,, -30900,2.5180266,1.2568642,,,,,,,,,,,,,, -31000,2.490877,1.2471211,,,,,,,,,,,,,, -31036,,,0.26872805,0.0943131977384197,0.47574943,0.1365940314934783,5348.0,0.26138502,0.083886823878293,2472.0,25941.97016787529,28509.099543333054,25941.97016787529,2564.800893306732,0.9301862716674804,0.0 -31100,2.2070723,1.2351797,,,,,,,,,,,,,, -31200,2.6851544,1.2408744,,,,,,,,,,,,,, -31300,2.182011,1.2058624,,,,,,,,,,,,,, -31400,2.2666125,1.1765794,,,,,,,,,,,,,, -31500,2.3668938,1.3055893,,,,,,,,,,,,,, -31600,2.7634695,1.1874517,,,,,,,,,,,,,, -31700,2.7751143,1.2874365,,,,,,,,,,,,,, -31800,3.0271814,1.2832391,,,,,,,,,,,,,, -31900,2.610038,1.2964228,,,,,,,,,,,,,, -32000,1.954692,1.2305503,,,,,,,,,,,,,, -32100,2.5255275,1.2683666,,,,,,,,,,,,,, -32200,2.1273642,1.2437878,,,,,,,,,,,,,, -32300,1.756294,1.233752,,,,,,,,,,,,,, -32400,2.778715,1.3152357,,,,,,,,,,,,,, -32500,1.8537961,1.2625442,,,,,,,,,,,,,, -32600,3.3941383,1.2362934,,,,,,,,,,,,,, -32700,3.9201958,1.2223374,,,,,,,,,,,,,, -32757,,,0.2560898,0.0891313082388943,0.46562022,0.1335914343918051,5348.0,0.24835972,0.0814494343225072,2472.0,27382.09087133408,30083.422049045563,27382.09087133408,2698.872343540192,0.9823896884918212,0.0 -32800,1.7622874,1.1960759,,,,,,,,,,,,,, -32900,1.9302188,1.2199285,,,,,,,,,,,,,, -33000,2.3176563,1.1769675,,,,,,,,,,,,,, -33100,2.1296203,1.2206368,,,,,,,,,,,,,, -33200,6.433339,1.1950264,,,,,,,,,,,,,, -33300,3.2041185,1.1957014,,,,,,,,,,,,,, -33400,3.0343373,1.2489513,,,,,,,,,,,,,, -33500,2.9898226,1.2279528,,,,,,,,,,,,,, -33600,2.7267597,1.1796787,,,,,,,,,,,,,, -33700,3.904666,1.2175316,,,,,,,,,,,,,, -33800,2.7354374,1.2029886,,,,,,,,,,,,,, -33900,1.680082,1.2069846,,,,,,,,,,,,,, -34000,1.7035795,1.1999657,,,,,,,,,,,,,, -34100,2.2520874,1.2222174,,,,,,,,,,,,,, -34200,2.3003216,1.1890824,,,,,,,,,,,,,, -34300,4.2745714,1.2323695,,,,,,,,,,,,,, -34400,2.7436097,1.2201434,,,,,,,,,,,,,, -34500,2.0959666,1.2285955,,,,,,,,,,,,,, -34518,,,0.15375972,0.0547525930964121,0.45650312,0.130608146596252,5348.0,0.2446342,0.0796010805760364,2472.0,28822.69719481468,31682.440099477768,28822.69719481468,2857.1478266716003,1.0384981632232666,0.0 -34600,2.0523877,1.2312292,,,,,,,,,,,,,, -34700,2.2496676,1.2624984,,,,,,,,,,,,,, -34800,3.0492187,1.1960006,,,,,,,,,,,,,, -34900,2.7330089,1.2616429,,,,,,,,,,,,,, -35000,2.2866533,1.1966909,,,,,,,,,,,,,, -35100,1.9520347,1.211751,,,,,,,,,,,,,, -35200,2.139972,1.2054648,,,,,,,,,,,,,, -35300,1.7852027,1.2067751,,,,,,,,,,,,,, -35400,3.7459424,1.2603335,,,,,,,,,,,,,, -35500,3.4377553,1.2081766,,,,,,,,,,,,,, -35600,1.9000264,1.1918893,,,,,,,,,,,,,, -35700,3.862821,1.1718869,,,,,,,,,,,,,, -35800,2.3876765,1.1615814,,,,,,,,,,,,,, -35900,3.2884266,1.1831576,,,,,,,,,,,,,, -36000,2.766355,1.1540161,,,,,,,,,,,,,, -36100,3.6370106,1.205811,,,,,,,,,,,,,, -36200,2.2817726,1.1475383,,,,,,,,,,,,,, -36248,,,0.14782521,0.0519759160073353,0.44794834,0.1287254892495438,5348.0,0.23961341,0.0775902341925131,2472.0,30263.501261234283,33258.248254299164,30263.501261234283,2992.0138483047485,1.0968949794769287,0.0 -36300,1.693608,1.1627016,,,,,,,,,,,,,, -36400,3.0187438,1.1603668,,,,,,,,,,,,,, -36500,2.8245068,1.1670382,,,,,,,,,,,,,, -36600,3.6122942,1.2025479,,,,,,,,,,,,,, -36700,3.5649426,1.161759,,,,,,,,,,,,,, -36800,2.380795,1.1773251,,,,,,,,,,,,,, -36900,3.0716648,1.2376653,,,,,,,,,,,,,, -37000,1.7465283,1.1949425,,,,,,,,,,,,,, -37100,2.317199,1.1529287,,,,,,,,,,,,,, -37200,1.9969575,1.217336,,,,,,,,,,,,,, -37300,2.119053,1.1553979,,,,,,,,,,,,,, -37400,1.9857525,1.1714679,,,,,,,,,,,,,, -37500,2.0749114,1.1580691,,,,,,,,,,,,,, -37600,2.6056485,1.1980516,,,,,,,,,,,,,, -37700,3.0055475,1.2254889,,,,,,,,,,,,,, -37800,4.0745296,1.1963385,,,,,,,,,,,,,, -37900,2.8634748,1.1747413,,,,,,,,,,,,,, -37941,,,0.15493704,0.0548058881048653,0.44182682,0.1258580572907112,5348.0,0.23499914,0.0761684236183048,2472.0,31703.481546401978,34834.663165569305,31703.481546401978,3128.327612876892,1.143108367919922,0.0 -38000,1.8767668,1.1841701,,,,,,,,,,,,,, -38100,2.3001397,1.210617,,,,,,,,,,,,,, -38200,3.352397,1.174121,,,,,,,,,,,,,, -38300,2.939682,1.1898988,,,,,,,,,,,,,, -38400,2.6899526,1.1131219,,,,,,,,,,,,,, -38500,3.5703099,1.1890149,,,,,,,,,,,,,, -38600,2.042814,1.1340309,,,,,,,,,,,,,, -38700,2.0722773,1.2038547,,,,,,,,,,,,,, -38800,2.6679542,1.1251189,,,,,,,,,,,,,, -38900,4.4969177,1.1588305,,,,,,,,,,,,,, -39000,3.7190776,1.1635467,,,,,,,,,,,,,, -39100,2.0629425,1.1439172,,,,,,,,,,,,,, -39200,3.3802817,1.1562597,,,,,,,,,,,,,, -39300,1.7898023,1.1996313,,,,,,,,,,,,,, -39400,3.1570823,1.1667489,,,,,,,,,,,,,, -39500,3.2792854,1.1589392,,,,,,,,,,,,,, -39600,3.2538266,1.1630696,,,,,,,,,,,,,, -39661,,,0.13454847,0.0484128474830954,0.4356194,0.1246512256582059,5348.0,0.23229034,0.0749294172607803,2472.0,33144.05407190323,36410.86017107964,33144.05407190323,3263.819913864136,1.1951828002929688,0.0 -39700,4.527709,1.1520054,,,,,,,,,,,,,, -39800,2.0774755,1.1917791,,,,,,,,,,,,,, -39900,2.32698,1.1841928,,,,,,,,,,,,,, -40000,2.825967,1.1648488,,,,,,,,,,,,,, -40100,2.3899896,1.1943864,,,,,,,,,,,,,, -40200,4.505771,1.1757989,,,,,,,,,,,,,, -40300,3.123366,1.178076,,,,,,,,,,,,,, -40400,2.0851355,1.104455,,,,,,,,,,,,,, -40500,2.4554503,1.1581758,,,,,,,,,,,,,, -40600,2.9737346,1.1410432,,,,,,,,,,,,,, -40700,2.102324,1.1650169,,,,,,,,,,,,,, -40800,1.9767921,1.1260352,,,,,,,,,,,,,, -40900,2.7988207,1.1733682,,,,,,,,,,,,,, -41000,2.6893225,1.154521,,,,,,,,,,,,,, -41100,3.45044,1.1792402,,,,,,,,,,,,,, -41200,2.9880052,1.1432894,,,,,,,,,,,,,, -41300,2.8017704,1.1457691,,,,,,,,,,,,,, -41377,,,0.14412792,0.0515004984779097,0.43128958,0.1231161358216592,5348.0,0.2291814,0.0734060487884142,2472.0,34584.85236620903,37986.19295430184,34584.85236620903,3398.2206456661224,1.250486135482788,0.0 -41400,2.963161,1.1753623,,,,,,,,,,,,,, -41500,2.3701043,1.1420034,,,,,,,,,,,,,, -41600,2.1592164,1.1675076,,,,,,,,,,,,,, -41700,2.2625675,1.2001165,,,,,,,,,,,,,, -41800,3.8886406,1.1161727,,,,,,,,,,,,,, -41900,3.2150476,1.1425517,,,,,,,,,,,,,, -42000,2.0829155,1.1704243,,,,,,,,,,,,,, -42100,3.9280903,1.1024579,,,,,,,,,,,,,, -42200,2.087334,1.1702354,,,,,,,,,,,,,, -42300,4.3771343,1.173804,,,,,,,,,,,,,, -42400,5.6381693,1.0948641,,,,,,,,,,,,,, -42500,2.484128,1.0945407,,,,,,,,,,,,,, -42600,3.3416321,1.1067656,,,,,,,,,,,,,, -42700,4.0635056,1.1240046,,,,,,,,,,,,,, -42800,2.4374409,1.1432427,,,,,,,,,,,,,, -42900,2.4050279,1.1567479,,,,,,,,,,,,,, -43000,1.7716779,1.1545099,,,,,,,,,,,,,, -43071,,,0.1331379,0.0482064816694626,0.42728096,0.1221989437809552,5348.0,0.22690861,0.0733451140495196,2472.0,36025.78731369972,39562.61579680443,36025.78731369972,3533.5720529556274,1.3105106353759766,0.0 -43100,2.930886,1.1247603,,,,,,,,,,,,,, -43200,2.5510142,1.1603621,,,,,,,,,,,,,, -43300,2.534363,1.0901939,,,,,,,,,,,,,, -43400,3.502024,1.1496015,,,,,,,,,,,,,, -43500,2.1119583,1.1549693,,,,,,,,,,,,,, -43600,4.856941,1.1436503,,,,,,,,,,,,,, -43700,4.3173685,1.1227853,,,,,,,,,,,,,, -43800,1.9166491,1.0756475,,,,,,,,,,,,,, -43900,1.9922235,1.108905,,,,,,,,,,,,,, -44000,2.2061577,1.1492361,,,,,,,,,,,,,, -44100,3.0110443,1.0979716,,,,,,,,,,,,,, -44200,2.4204175,1.1217386,,,,,,,,,,,,,, -44300,2.3947482,1.1065089,,,,,,,,,,,,,, -44400,2.8540254,1.1173837,,,,,,,,,,,,,, -44500,3.4610064,1.1566064,,,,,,,,,,,,,, -44600,2.3699334,1.0834271,,,,,,,,,,,,,, -44700,2.7329245,1.1396023,,,,,,,,,,,,,, -44800,2.493805,1.18474,,,,,,,,,,,,,, -44804,,,0.15674452,0.0525206026268349,0.42669344,0.1222665263523755,5348.0,0.22530259,0.0731216866735726,2472.0,37466.45243930817,41139.22075676918,37466.45243930817,3669.3792510032654,1.3633804321289062,0.0 -44900,3.3723922,1.1389805,,,,,,,,,,,,,, -45000,4.5910797,1.1245973,,,,,,,,,,,,,, -45100,3.3293629,1.1158532,,,,,,,,,,,,,, -45200,3.078691,1.1355038,,,,,,,,,,,,,, -45300,2.8113585,1.063671,,,,,,,,,,,,,, -45400,4.4553895,1.0769902,,,,,,,,,,,,,, -45500,3.4241774,1.1282324,,,,,,,,,,,,,, -45600,2.9931836,1.1175148,,,,,,,,,,,,,, -45700,2.9983234,1.132207,,,,,,,,,,,,,, -45800,2.2624884,1.1298081,,,,,,,,,,,,,, -45900,2.8804047,1.0454237,,,,,,,,,,,,,, -46000,1.7736039,1.1034064,,,,,,,,,,,,,, -46100,2.7680092,1.1479163,,,,,,,,,,,,,, -46200,1.8407696,1.1723324,,,,,,,,,,,,,, -46300,3.703181,1.1314385,,,,,,,,,,,,,, -46400,2.5387676,1.0808685,,,,,,,,,,,,,, -46500,3.113151,1.1016003,,,,,,,,,,,,,, -46503,,,0.14168745,0.0509011874469889,0.42561495,0.1211658959035307,5348.0,0.22481138,0.0725326508642577,2472.0,38906.67197370529,42717.63082480431,38906.67197370529,3807.438556671143,1.4167790412902832,0.0 -46600,2.6117225,1.1691922,,,,,,,,,,,,,, -46700,3.6088145,1.0989245,,,,,,,,,,,,,, -46800,2.8322306,1.1371219,,,,,,,,,,,,,, -46900,2.5285964,1.1321725,,,,,,,,,,,,,, -47000,3.2924998,1.1064353,,,,,,,,,,,,,, -47100,3.2915833,1.1182278,,,,,,,,,,,,,, -47200,2.075617,1.1661794,,,,,,,,,,,,,, -47300,3.9702613,1.1717371,,,,,,,,,,,,,, -47400,1.96521,1.1296867,,,,,,,,,,,,,, -47500,2.8531976,1.1400111,,,,,,,,,,,,,, -47600,1.8557409,1.1160165,,,,,,,,,,,,,, -47700,2.9828517,1.13885,,,,,,,,,,,,,, -47800,3.9596126,1.135943,,,,,,,,,,,,,, -47900,2.5150294,1.1483119,,,,,,,,,,,,,, -48000,,,0.121190354,0.044607899435158,0.42536923,0.1213493343116715,5348.0,0.22480996,0.0730404403550464,2472.0,40174.22040367127,44120.32517743111,40174.22040367127,3942.461863040924,1.4705607891082764,0.0 -48000,,,,,,,,,,,40174.220403671265,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 88119f34f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -182.46350383758545,0.0,15.480498790740969,1,0,15.480498790740969,30.6577,2472,3.2319176162330145,197.9440755844116,31.647242,3.052888001878375,30.570385,5348,2.911727507072033 -295.8599810600281,0.0336084365844726,1455.57701587677,1709,0,1455.57701587677,6.3034024,2472,0.8991123839701014,1751.5442821979525,6.3669677,0.9411137440758294,6.3865676,5348,0.8960773144617048 -409.7762682437897,0.0885474681854248,2895.7558012008667,3436,0,2895.7558012008667,4.3991666,2472,0.8860723498466476,3305.7746090888977,4.5963387,0.9283475590008252,4.687805,5348,0.8887301234830126 -544.7885777950287,0.1420388221740722,4336.122656345367,5133,0,4336.122656345367,0.77849424,2472,0.2492027704994617,4881.283031463623,0.764981,0.2510452520218764,1.1377617,5348,0.3193373046139587 -682.6388504505157,0.1931395530700683,5776.731394290924,6831,0,5776.731394290924,0.5582751,2472,0.179005951292832,6459.871329784393,0.5411437,0.1812576618889239,0.8500217,5348,0.2485880069899688 -817.924996137619,0.3204481601715088,7216.925114154816,8559,0,7216.925114154816,0.4910314,2472,0.1571100684500233,8035.558212518692,0.45767567,0.1520767509732723,0.78322417,5348,0.2247313592786043 -955.059031009674,0.4464774131774902,8656.972178220749,10255,0,8656.972178220749,0.4505234,2472,0.1460605691304612,9612.94243979454,0.3818851,0.130847625625425,0.7319703,5348,0.2101238691987603 -1093.622872591019,0.5056447982788086,10097.09136915207,11957,0,10097.09136915207,0.42062852,2472,0.1345439034793736,11191.76549577713,0.3600609,0.1224991440237527,0.6953865,5348,0.1999575195265358 -1228.071349620819,0.5557355880737305,11537.651688098907,13684,0,11537.651688098907,0.40535486,2472,0.1304003412345378,12766.904225826263,0.37431747,0.1254869075957585,0.67228556,5348,0.1934985566293675 -1365.2708704471588,0.6095635890960693,12978.08913564682,15387,0,12978.08913564682,0.3900358,2472,0.1256474316007556,14344.672752857208,0.33916324,0.1129862559990412,0.6459791,5348,0.1861610203037353 -1501.8507792949677,0.659881591796875,14418.107741355896,17085,0,14418.107741355896,0.37440822,2472,0.121402311457762,15921.399963855743,0.32999718,0.1126295270339591,0.6329018,5348,0.1827336184674203 -1639.8546781539917,0.7098729610443115,15858.09096622467,18812,0,15858.09096622467,0.36095637,2472,0.1158572502183494,17499.5164706707,0.35147545,0.1125434761798398,0.6095573,5348,0.175338154223428 -1777.8929901123047,0.7604987621307373,17298.066687345505,20513,0,17298.066687345505,0.35536996,2472,0.1155322649442447,19077.658970832825,0.2695812,0.0923303320548844,0.59979016,5348,0.1738803016113616 -1915.6497299671173,0.8108630180358887,18738.8721241951,22217,0,18738.8721241951,0.3375406,2472,0.1102512542400422,20656.351114034653,0.2810426,0.0961853131681713,0.58093655,5348,0.1685702424283383 -2051.2442207336426,0.8616993427276611,20179.33757257461,23927,0,20179.33757257461,0.32571787,2472,0.1042999614079987,22232.538994073868,0.3692097,0.1231945390990832,0.56566906,5348,0.164853201000222 -2185.283809185028,0.9100189208984376,21619.9885225296,25613,0,21619.9885225296,0.31971592,2472,0.1032640708467897,23807.35521101952,0.3844082,0.1260792174737577,0.54311615,5348,0.157988742674532 -2320.938892841339,0.965810775756836,23059.87901854515,27319,0,23059.87901854515,0.3093404,2472,0.0984096033148498,25383.03518056869,0.42835096,0.1418925190272879,0.540189,5348,0.1571198238991281 -2454.164674520493,1.0208978652954102,24500.38377356529,29031,0,24500.38377356529,0.29937524,2472,0.0949972579367497,26956.89998602867,0.3626228,0.1182944007512618,0.52205914,5348,0.1520800950017861 -2589.162698030472,1.0792968273162842,25941.25056886673,30729,0,25941.25056886673,0.28334433,2472,0.0904880872585461,28532.90132164955,0.32588735,0.1100028658286459,0.50159395,5348,0.1470210567983239 -2726.269201755524,1.1319363117218018,27381.333554029465,32452,0,27381.333554029465,0.27825376,2472,0.0895131314362318,30110.22183823585,0.2819919,0.0966568077732058,0.48701313,5348,0.1421840756152427 -2861.870703935623,1.1866295337677002,28821.81571483612,34156,0,28821.81571483612,0.26531848,2472,0.0852273881339751,31686.440123796463,0.30927896,0.1050533365327981,0.46904442,5348,0.136545758228178 -2998.3368847370148,1.2433674335479736,30262.3197350502,35850,0,30262.3197350502,0.25412038,2472,0.0814494343225072,33263.54372572899,0.26929548,0.0921024489359426,0.45808133,5348,0.1342286414937679 -3133.3894832134247,1.3041977882385254,31702.28162097931,37571,0,31702.28162097931,0.24617566,2472,0.0782808278999857,34838.69946312904,0.2542912,0.0884822202003802,0.44569063,5348,0.1310136420247738 -3268.299718618393,1.3554682731628418,33142.84040021896,39261,0,33142.84040021896,0.2394587,2472,0.0759653078219893,36414.29718565941,0.2523249,0.0850499358672177,0.4367461,5348,0.1274703843517383 -3405.947791576385,1.4116151332855225,34583.702951192856,40968,0,34583.702951192856,0.23126832,2472,0.0734466719476773,37992.941672325134,0.24387786,0.0842229176835395,0.42268944,5348,0.1245064058623053 -3539.2807710170746,1.4675979614257812,36023.71224784851,42693,0,36023.71224784851,0.22636868,2472,0.0724107813864684,39566.42193007469,0.24088885,0.0808483381965696,0.41713747,5348,0.1221410158625949 -3674.257806539536,1.5213351249694824,37463.88521766663,44390,0,37463.88521766663,0.22205788,2472,0.070765543436313,41141.70326757431,0.21339326,0.0741663739706858,0.4088802,5348,0.1195439141894436 -3806.234555482864,1.577507495880127,38904.08066558838,46095,0,38904.08066558838,0.22028401,2472,0.0703390002640505,42714.010046720505,0.21221867,0.0744158716991566,0.4071756,5348,0.1189067070874808 -3940.92617559433,1.6310737133026123,40343.99628758431,47825,0,40343.99628758431,0.2197189,2472,0.0701155728881035,44288.75029063225,0.20782495,0.0720470484289579,0.40657282,5348,0.1190901454956216 -4064.7738120555878,1.6917657852172852,40475.96862196922,48000,0,40475.96862196922,0.2197332,2472,0.07007494972884042,44544.644236803055,0.21455996,0.07427663280772798,0.40663552,5348,0.11909980014868166 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/measurements.csv deleted file mode 100644 index c074618cf..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,18.901436,32.93538,,,,,,,,,,,,,, -1,,,31.647242,3.052888001878375,30.570385,2.911727507072033,5348.0,30.6577,3.2319176162330145,2472.0,15.480498790740969,197.9440755844116,15.480498790740969,182.46350383758545,0.0,0.0 -100,6.479203,9.542411,,,,,,,,,,,,,, -200,1.5030441,6.6132507,,,,,,,,,,,,,, -300,1.0582416,5.9563384,,,,,,,,,,,,,, -400,0.8683003,5.8752947,,,,,,,,,,,,,, -500,0.5737672,5.8132076,,,,,,,,,,,,,, -600,0.28706062,5.8251915,,,,,,,,,,,,,, -700,0.38802978,5.7696505,,,,,,,,,,,,,, -800,0.34527248,5.678323,,,,,,,,,,,,,, -900,0.4101484,5.5413117,,,,,,,,,,,,,, -1000,0.77750826,5.471586,,,,,,,,,,,,,, -1100,0.7421144,5.279803,,,,,,,,,,,,,, -1200,0.9911707,4.857293,,,,,,,,,,,,,, -1300,1.1893673,4.4132853,,,,,,,,,,,,,, -1400,2.3530414,4.041403,,,,,,,,,,,,,, -1500,1.2850446,3.7141333,,,,,,,,,,,,,, -1600,4.7443943,3.5507574,,,,,,,,,,,,,, -1700,2.2735937,3.346539,,,,,,,,,,,,,, -1709,,,6.3669677,0.9411137440758294,6.3865676,0.8960773144617048,5348.0,6.3034024,0.8991123839701014,2472.0,1455.57701587677,1751.5442821979525,1455.57701587677,295.8599810600281,0.0336084365844726,0.0 -1800,2.7594614,3.1909292,,,,,,,,,,,,,, -1900,3.6892316,3.0836265,,,,,,,,,,,,,, -2000,3.6727688,2.9494739,,,,,,,,,,,,,, -2100,2.4869974,2.859063,,,,,,,,,,,,,, -2200,3.6516309,2.8180602,,,,,,,,,,,,,, -2300,2.9745295,2.6783905,,,,,,,,,,,,,, -2400,2.9702623,2.646635,,,,,,,,,,,,,, -2500,2.7226632,2.5723248,,,,,,,,,,,,,, -2600,4.16173,2.518753,,,,,,,,,,,,,, -2700,3.72872,2.4158974,,,,,,,,,,,,,, -2800,4.3629045,2.4020307,,,,,,,,,,,,,, -2900,4.529335,2.4114661,,,,,,,,,,,,,, -3000,3.5939906,2.3292875,,,,,,,,,,,,,, -3100,5.9555945,2.3581724,,,,,,,,,,,,,, -3200,3.1391609,2.279485,,,,,,,,,,,,,, -3300,3.275765,2.2643895,,,,,,,,,,,,,, -3400,4.3223,2.227366,,,,,,,,,,,,,, -3436,,,4.5963387,0.9283475590008252,4.687805,0.8887301234830126,5348.0,4.3991666,0.8860723498466476,2472.0,2895.7558012008667,3305.7746090888977,2895.7558012008667,409.7762682437897,0.0885474681854248,0.0 -3500,2.5389245,2.0964859,,,,,,,,,,,,,, -3600,4.7113237,2.244817,,,,,,,,,,,,,, -3700,3.7162783,2.1197653,,,,,,,,,,,,,, -3800,3.1088238,2.1070359,,,,,,,,,,,,,, -3900,2.795718,2.0863743,,,,,,,,,,,,,, -4000,3.908377,2.1119795,,,,,,,,,,,,,, -4100,3.427303,2.078576,,,,,,,,,,,,,, -4200,4.416883,2.0563376,,,,,,,,,,,,,, -4300,4.362059,2.0913377,,,,,,,,,,,,,, -4400,3.995113,1.9992751,,,,,,,,,,,,,, -4500,3.2635672,2.0812345,,,,,,,,,,,,,, -4600,3.2043474,2.0646913,,,,,,,,,,,,,, -4700,3.2853048,1.9817524,,,,,,,,,,,,,, -4800,3.2638943,2.0055969,,,,,,,,,,,,,, -4900,4.7679033,2.6500716,,,,,,,,,,,,,, -5000,2.2083657,2.152081,,,,,,,,,,,,,, -5100,2.4305153,2.0646152,,,,,,,,,,,,,, -5133,,,0.764981,0.2510452520218764,1.1377617,0.3193373046139587,5348.0,0.77849424,0.2492027704994617,2472.0,4336.122656345367,4881.283031463623,4336.122656345367,544.7885777950287,0.1420388221740722,0.0 -5200,2.9780557,1.9330634,,,,,,,,,,,,,, -5300,1.9688345,1.9623879,,,,,,,,,,,,,, -5400,3.3325894,1.9283892,,,,,,,,,,,,,, -5500,2.1439683,2.0422676,,,,,,,,,,,,,, -5600,2.4491436,1.9638262,,,,,,,,,,,,,, -5700,1.9628522,1.9139838,,,,,,,,,,,,,, -5800,2.6342788,1.8966937,,,,,,,,,,,,,, -5900,2.0960097,1.8412398,,,,,,,,,,,,,, -6000,2.041629,1.89423,,,,,,,,,,,,,, -6100,1.9791746,1.8807405,,,,,,,,,,,,,, -6200,2.8209414,1.8297662,,,,,,,,,,,,,, -6300,1.7884985,1.8696053,,,,,,,,,,,,,, -6400,2.2643864,1.8525157,,,,,,,,,,,,,, -6500,2.0769553,1.7629958,,,,,,,,,,,,,, -6600,2.4735909,1.8083161,,,,,,,,,,,,,, -6700,2.3959322,1.769381,,,,,,,,,,,,,, -6800,2.5112572,1.7823764,,,,,,,,,,,,,, -6831,,,0.5411437,0.1812576618889239,0.8500217,0.2485880069899688,5348.0,0.5582751,0.179005951292832,2472.0,5776.731394290924,6459.871329784393,5776.731394290924,682.6388504505157,0.1931395530700683,0.0 -6900,1.9542776,1.8634946,,,,,,,,,,,,,, -7000,2.6705306,1.8657657,,,,,,,,,,,,,, -7100,2.5169463,1.7547649,,,,,,,,,,,,,, -7200,1.9985746,1.7856368,,,,,,,,,,,,,, -7300,2.4186532,1.7534014,,,,,,,,,,,,,, -7400,2.314776,1.7816589,,,,,,,,,,,,,, -7500,2.6404307,1.7529005,,,,,,,,,,,,,, -7600,2.9103572,1.7628247,,,,,,,,,,,,,, -7700,3.0677805,1.723283,,,,,,,,,,,,,, -7800,1.7872002,1.7744335,,,,,,,,,,,,,, -7900,3.01516,1.7498138,,,,,,,,,,,,,, -8000,2.1571972,1.7174674,,,,,,,,,,,,,, -8100,2.676554,1.7162046,,,,,,,,,,,,,, -8200,2.441028,1.7637441,,,,,,,,,,,,,, -8300,1.4082092,1.685673,,,,,,,,,,,,,, -8400,2.1915307,1.8180777,,,,,,,,,,,,,, -8500,4.922859,1.674497,,,,,,,,,,,,,, -8559,,,0.45767567,0.1520767509732723,0.78322417,0.2247313592786043,5348.0,0.4910314,0.1571100684500233,2472.0,7216.925114154816,8035.558212518692,7216.925114154816,817.924996137619,0.3204481601715088,0.0 -8600,1.8906398,1.670705,,,,,,,,,,,,,, -8700,2.667448,1.7024628,,,,,,,,,,,,,, -8800,3.2446144,1.676232,,,,,,,,,,,,,, -8900,1.9932336,1.744071,,,,,,,,,,,,,, -9000,4.124909,1.717085,,,,,,,,,,,,,, -9100,1.7636561,1.7074696,,,,,,,,,,,,,, -9200,2.1059475,1.7161596,,,,,,,,,,,,,, -9300,2.6215038,1.6989627,,,,,,,,,,,,,, -9400,2.7773836,1.7266941,,,,,,,,,,,,,, -9500,1.9516503,1.6960844,,,,,,,,,,,,,, -9600,2.3550014,1.6473581,,,,,,,,,,,,,, -9700,2.1120534,1.6904601,,,,,,,,,,,,,, -9800,2.49923,1.6496284,,,,,,,,,,,,,, -9900,2.125485,1.6270106,,,,,,,,,,,,,, -10000,2.0969193,1.6055112,,,,,,,,,,,,,, -10100,2.4381368,1.6323621,,,,,,,,,,,,,, -10200,2.0011642,1.7286801,,,,,,,,,,,,,, -10255,,,0.3818851,0.130847625625425,0.7319703,0.2101238691987603,5348.0,0.4505234,0.1460605691304612,2472.0,8656.972178220749,9612.94243979454,8656.972178220749,955.059031009674,0.4464774131774902,0.0 -10300,1.6925328,1.6277361,,,,,,,,,,,,,, -10400,1.9307162,1.6926982,,,,,,,,,,,,,, -10500,2.6126165,1.6385419,,,,,,,,,,,,,, -10600,2.843604,1.6398982,,,,,,,,,,,,,, -10700,2.3250964,1.6161084,,,,,,,,,,,,,, -10800,1.9390689,1.6143593,,,,,,,,,,,,,, -10900,2.5218422,1.6287419,,,,,,,,,,,,,, -11000,2.3188236,1.6065186,,,,,,,,,,,,,, -11100,2.2125597,1.6438748,,,,,,,,,,,,,, -11200,2.9845917,1.5760885,,,,,,,,,,,,,, -11300,2.1392555,1.5686166,,,,,,,,,,,,,, -11400,3.5415883,1.6068516,,,,,,,,,,,,,, -11500,2.2177858,1.6888565,,,,,,,,,,,,,, -11600,2.3603227,1.5640192,,,,,,,,,,,,,, -11700,2.0467732,1.5919355,,,,,,,,,,,,,, -11800,2.3305376,1.683834,,,,,,,,,,,,,, -11900,3.1417482,1.6149969,,,,,,,,,,,,,, -11957,,,0.3600609,0.1224991440237527,0.6953865,0.1999575195265358,5348.0,0.42062852,0.1345439034793736,2472.0,10097.09136915207,11191.76549577713,10097.09136915207,1093.622872591019,0.5056447982788086,0.0 -12000,3.1743734,1.6720804,,,,,,,,,,,,,, -12100,2.167931,1.5821111,,,,,,,,,,,,,, -12200,2.789503,1.6284931,,,,,,,,,,,,,, -12300,2.6600666,1.5614954,,,,,,,,,,,,,, -12400,2.729877,1.6083932,,,,,,,,,,,,,, -12500,3.672004,1.5870122,,,,,,,,,,,,,, -12600,2.1336758,1.6225924,,,,,,,,,,,,,, -12700,2.4810472,1.5242124,,,,,,,,,,,,,, -12800,2.6303325,1.593808,,,,,,,,,,,,,, -12900,2.8656087,1.6049119,,,,,,,,,,,,,, -13000,2.1789985,1.5417646,,,,,,,,,,,,,, -13100,2.2787719,1.5350671,,,,,,,,,,,,,, -13200,3.0751169,1.5842005,,,,,,,,,,,,,, -13300,2.9450834,1.5948919,,,,,,,,,,,,,, -13400,2.5655003,1.5597703,,,,,,,,,,,,,, -13500,2.4493012,1.5818764,,,,,,,,,,,,,, -13600,2.5498261,1.6425023,,,,,,,,,,,,,, -13684,,,0.37431747,0.1254869075957585,0.67228556,0.1934985566293675,5348.0,0.40535486,0.1304003412345378,2472.0,11537.651688098907,12766.904225826263,11537.651688098907,1228.071349620819,0.5557355880737305,0.0 -13700,3.566835,1.5573868,,,,,,,,,,,,,, -13800,2.4287276,1.5567263,,,,,,,,,,,,,, -13900,2.7557902,1.5770025,,,,,,,,,,,,,, -14000,2.6914246,1.5640454,,,,,,,,,,,,,, -14100,3.1594872,1.5511934,,,,,,,,,,,,,, -14200,2.175956,1.5538019,,,,,,,,,,,,,, -14300,3.1537461,1.574142,,,,,,,,,,,,,, -14400,3.9069831,1.543317,,,,,,,,,,,,,, -14500,2.720939,1.5740699,,,,,,,,,,,,,, -14600,1.9947697,1.5302659,,,,,,,,,,,,,, -14700,2.0114703,1.4986453,,,,,,,,,,,,,, -14800,2.105779,1.5157093,,,,,,,,,,,,,, -14900,2.7736676,1.5716573,,,,,,,,,,,,,, -15000,2.1683645,1.5914754,,,,,,,,,,,,,, -15100,1.9751098,1.5201164,,,,,,,,,,,,,, -15200,2.3740742,1.6119431,,,,,,,,,,,,,, -15300,2.9167838,1.5809914,,,,,,,,,,,,,, -15387,,,0.33916324,0.1129862559990412,0.6459791,0.1861610203037353,5348.0,0.3900358,0.1256474316007556,2472.0,12978.08913564682,14344.672752857208,12978.08913564682,1365.2708704471588,0.6095635890960693,0.0 -15400,2.6368098,1.5582733,,,,,,,,,,,,,, -15500,3.3011758,1.5698063,,,,,,,,,,,,,, -15600,3.1607122,1.5605325,,,,,,,,,,,,,, -15700,2.378544,1.5050347,,,,,,,,,,,,,, -15800,4.810314,1.5275059,,,,,,,,,,,,,, -15900,2.3767009,1.5611055,,,,,,,,,,,,,, -16000,3.1869364,1.5246843,,,,,,,,,,,,,, -16100,2.1552725,1.531593,,,,,,,,,,,,,, -16200,2.6097093,1.498917,,,,,,,,,,,,,, -16300,4.4440856,1.5461056,,,,,,,,,,,,,, -16400,3.502866,1.5289904,,,,,,,,,,,,,, -16500,2.109666,1.5133508,,,,,,,,,,,,,, -16600,2.0236766,1.5613034,,,,,,,,,,,,,, -16700,3.4461563,1.502877,,,,,,,,,,,,,, -16800,2.4073968,1.5835114,,,,,,,,,,,,,, -16900,2.409738,1.4701042,,,,,,,,,,,,,, -17000,2.7950003,1.5173696,,,,,,,,,,,,,, -17085,,,0.32999718,0.1126295270339591,0.6329018,0.1827336184674203,5348.0,0.37440822,0.121402311457762,2472.0,14418.107741355896,15921.399963855743,14418.107741355896,1501.8507792949677,0.659881591796875,0.0 -17100,1.8878123,1.5021577,,,,,,,,,,,,,, -17200,2.571138,1.5517831,,,,,,,,,,,,,, -17300,2.226431,1.5508784,,,,,,,,,,,,,, -17400,2.1999433,1.5481493,,,,,,,,,,,,,, -17500,2.3013892,1.6016376,,,,,,,,,,,,,, -17600,2.1500058,1.4759222,,,,,,,,,,,,,, -17700,2.3278916,1.5255564,,,,,,,,,,,,,, -17800,2.233745,1.4957945,,,,,,,,,,,,,, -17900,2.758183,1.5280831,,,,,,,,,,,,,, -18000,2.4131134,1.5772874,,,,,,,,,,,,,, -18100,2.7015824,1.4553336,,,,,,,,,,,,,, -18200,2.958978,1.4800613,,,,,,,,,,,,,, -18300,1.9821205,1.5222052,,,,,,,,,,,,,, -18400,1.984117,1.4940481,,,,,,,,,,,,,, -18500,2.1169336,1.533196,,,,,,,,,,,,,, -18600,3.2515562,1.5230153,,,,,,,,,,,,,, -18700,3.108241,1.4807427,,,,,,,,,,,,,, -18800,3.1231856,1.4489585,,,,,,,,,,,,,, -18812,,,0.35147545,0.1125434761798398,0.6095573,0.175338154223428,5348.0,0.36095637,0.1158572502183494,2472.0,15858.09096622467,17499.5164706707,15858.09096622467,1639.8546781539917,0.7098729610443115,0.0 -18900,2.1283453,1.5437034,,,,,,,,,,,,,, -19000,2.1143386,1.5145679,,,,,,,,,,,,,, -19100,2.3566034,1.4758573,,,,,,,,,,,,,, -19200,2.7622373,1.5446405,,,,,,,,,,,,,, -19300,2.135489,1.5185066,,,,,,,,,,,,,, -19400,3.0929363,1.496225,,,,,,,,,,,,,, -19500,2.5169468,1.4870017,,,,,,,,,,,,,, -19600,2.4326313,1.5197419,,,,,,,,,,,,,, -19700,2.9755168,1.5775862,,,,,,,,,,,,,, -19800,3.1737123,1.4601114,,,,,,,,,,,,,, -19900,3.0140471,1.5120939,,,,,,,,,,,,,, -20000,1.8459017,1.4238944,,,,,,,,,,,,,, -20100,3.5204268,1.4674468,,,,,,,,,,,,,, -20200,2.0152524,1.517025,,,,,,,,,,,,,, -20300,2.5589387,1.5269045,,,,,,,,,,,,,, -20400,2.9342692,1.5148693,,,,,,,,,,,,,, -20500,3.3250227,1.4885466,,,,,,,,,,,,,, -20513,,,0.2695812,0.0923303320548844,0.59979016,0.1738803016113616,5348.0,0.35536996,0.1155322649442447,2472.0,17298.066687345505,19077.658970832825,17298.066687345505,1777.8929901123047,0.7604987621307373,0.0 -20600,2.0158474,1.4861377,,,,,,,,,,,,,, -20700,3.5171263,1.4200155,,,,,,,,,,,,,, -20800,2.9953117,1.4861954,,,,,,,,,,,,,, -20900,2.270504,1.4834416,,,,,,,,,,,,,, -21000,2.3767223,1.4822818,,,,,,,,,,,,,, -21100,2.0193605,1.4520383,,,,,,,,,,,,,, -21200,5.445841,1.447926,,,,,,,,,,,,,, -21300,2.2414377,1.4713184,,,,,,,,,,,,,, -21400,2.9600317,1.463406,,,,,,,,,,,,,, -21500,2.3546324,1.4852874,,,,,,,,,,,,,, -21600,4.068005,1.4604878,,,,,,,,,,,,,, -21700,2.262097,1.4334934,,,,,,,,,,,,,, -21800,2.839019,1.4396878,,,,,,,,,,,,,, -21900,2.4226959,1.4394,,,,,,,,,,,,,, -22000,2.1523364,1.4934332,,,,,,,,,,,,,, -22100,2.2191396,1.4053488,,,,,,,,,,,,,, -22200,1.8848623,1.4555091,,,,,,,,,,,,,, -22217,,,0.2810426,0.0961853131681713,0.58093655,0.1685702424283383,5348.0,0.3375406,0.1102512542400422,2472.0,18738.8721241951,20656.351114034653,18738.8721241951,1915.6497299671173,0.8108630180358887,0.0 -22300,2.8775706,1.5121446,,,,,,,,,,,,,, -22400,2.751803,1.5186074,,,,,,,,,,,,,, -22500,2.7241082,1.430704,,,,,,,,,,,,,, -22600,1.9946266,1.4341836,,,,,,,,,,,,,, -22700,2.2636533,1.39097,,,,,,,,,,,,,, -22800,2.6040263,1.4274284,,,,,,,,,,,,,, -22900,3.2434142,1.4129791,,,,,,,,,,,,,, -23000,2.7196364,1.4544886,,,,,,,,,,,,,, -23100,3.0806892,1.4489305,,,,,,,,,,,,,, -23200,2.5440392,1.3888752,,,,,,,,,,,,,, -23300,2.0124435,1.4334817,,,,,,,,,,,,,, -23400,2.1625855,1.4502573,,,,,,,,,,,,,, -23500,2.5199275,1.4529797,,,,,,,,,,,,,, -23600,2.8734,1.4425799,,,,,,,,,,,,,, -23700,2.662792,1.4248359,,,,,,,,,,,,,, -23800,3.424297,1.4258258,,,,,,,,,,,,,, -23900,1.7608335,1.4024476,,,,,,,,,,,,,, -23927,,,0.3692097,0.1231945390990832,0.56566906,0.164853201000222,5348.0,0.32571787,0.1042999614079987,2472.0,20179.33757257461,22232.538994073868,20179.33757257461,2051.2442207336426,0.8616993427276611,0.0 -24000,1.8602314,1.389882,,,,,,,,,,,,,, -24100,3.0885348,1.4495227,,,,,,,,,,,,,, -24200,2.6371336,1.3999803,,,,,,,,,,,,,, -24300,3.1134467,1.4234061,,,,,,,,,,,,,, -24400,2.4083195,1.3880044,,,,,,,,,,,,,, -24500,3.8731813,1.4692084,,,,,,,,,,,,,, -24600,2.4349115,1.4465256,,,,,,,,,,,,,, -24700,2.947736,1.4560318,,,,,,,,,,,,,, -24800,2.5395281,1.3982977,,,,,,,,,,,,,, -24900,3.8893268,1.4160477,,,,,,,,,,,,,, -25000,2.4157076,1.387316,,,,,,,,,,,,,, -25100,3.2606719,1.4729995,,,,,,,,,,,,,, -25200,3.3717937,1.3936181,,,,,,,,,,,,,, -25300,2.529843,1.435978,,,,,,,,,,,,,, -25400,1.9404501,1.3896246,,,,,,,,,,,,,, -25500,3.1665215,1.4155434,,,,,,,,,,,,,, -25600,2.816979,1.4168912,,,,,,,,,,,,,, -25613,,,0.3844082,0.1260792174737577,0.54311615,0.157988742674532,5348.0,0.31971592,0.1032640708467897,2472.0,21619.9885225296,23807.35521101952,21619.9885225296,2185.283809185028,0.9100189208984376,0.0 -25700,2.474414,1.431116,,,,,,,,,,,,,, -25800,3.2000973,1.4224881,,,,,,,,,,,,,, -25900,2.3521307,1.3323737,,,,,,,,,,,,,, -26000,2.077272,1.3336256,,,,,,,,,,,,,, -26100,3.0204701,1.4397757,,,,,,,,,,,,,, -26200,2.8553925,1.398203,,,,,,,,,,,,,, -26300,2.8017912,1.3825525,,,,,,,,,,,,,, -26400,2.1199377,1.4084918,,,,,,,,,,,,,, -26500,3.4023528,1.3496041,,,,,,,,,,,,,, -26600,2.1841664,1.4137199,,,,,,,,,,,,,, -26700,2.8295712,1.3305452,,,,,,,,,,,,,, -26800,2.1250324,1.3590871,,,,,,,,,,,,,, -26900,2.2030256,1.3600706,,,,,,,,,,,,,, -27000,2.5394998,1.3682717,,,,,,,,,,,,,, -27100,2.015262,1.3533533,,,,,,,,,,,,,, -27200,2.308584,1.3955314,,,,,,,,,,,,,, -27300,3.228885,1.4369587,,,,,,,,,,,,,, -27319,,,0.42835096,0.1418925190272879,0.540189,0.1571198238991281,5348.0,0.3093404,0.0984096033148498,2472.0,23059.87901854515,25383.03518056869,23059.87901854515,2320.938892841339,0.965810775756836,0.0 -27400,2.0987492,1.4013246,,,,,,,,,,,,,, -27500,2.884002,1.3601478,,,,,,,,,,,,,, -27600,2.196211,1.3998133,,,,,,,,,,,,,, -27700,2.655043,1.4160962,,,,,,,,,,,,,, -27800,2.2320144,1.3754045,,,,,,,,,,,,,, -27900,2.687901,1.3713187,,,,,,,,,,,,,, -28000,1.8180116,1.3473213,,,,,,,,,,,,,, -28100,2.2633784,1.3162867,,,,,,,,,,,,,, -28200,2.0051916,1.3629749,,,,,,,,,,,,,, -28300,2.8222878,1.3483447,,,,,,,,,,,,,, -28400,2.4947865,1.3654126,,,,,,,,,,,,,, -28500,2.8263116,1.3383467,,,,,,,,,,,,,, -28600,2.8269627,1.2806429,,,,,,,,,,,,,, -28700,2.4755945,1.3265562,,,,,,,,,,,,,, -28800,2.0690877,1.3079059,,,,,,,,,,,,,, -28900,1.8646684,1.349537,,,,,,,,,,,,,, -29000,2.6036305,1.335649,,,,,,,,,,,,,, -29031,,,0.3626228,0.1182944007512618,0.52205914,0.1520800950017861,5348.0,0.29937524,0.0949972579367497,2472.0,24500.38377356529,26956.89998602867,24500.38377356529,2454.164674520493,1.0208978652954102,0.0 -29100,2.5671585,1.3006638,,,,,,,,,,,,,, -29200,2.5379062,1.3327485,,,,,,,,,,,,,, -29300,2.0504348,1.3788111,,,,,,,,,,,,,, -29400,2.4222279,1.3208147,,,,,,,,,,,,,, -29500,2.801497,1.3606093,,,,,,,,,,,,,, -29600,2.7708802,1.3374102,,,,,,,,,,,,,, -29700,2.548212,1.3307443,,,,,,,,,,,,,, -29800,2.0369992,1.335321,,,,,,,,,,,,,, -29900,2.1546474,1.3560314,,,,,,,,,,,,,, -30000,3.4777877,1.3351223,,,,,,,,,,,,,, -30100,2.2281537,1.2612044,,,,,,,,,,,,,, -30200,2.6762288,1.3209481,,,,,,,,,,,,,, -30300,3.4700549,1.3324386,,,,,,,,,,,,,, -30400,2.1648123,1.3333476,,,,,,,,,,,,,, -30500,2.1422365,1.3001556,,,,,,,,,,,,,, -30600,2.3433177,1.3344637,,,,,,,,,,,,,, -30700,2.3413858,1.251257,,,,,,,,,,,,,, -30729,,,0.32588735,0.1100028658286459,0.50159395,0.1470210567983239,5348.0,0.28334433,0.0904880872585461,2472.0,25941.25056886673,28532.90132164955,25941.25056886673,2589.162698030472,1.0792968273162842,0.0 -30800,2.9466887,1.2729588,,,,,,,,,,,,,, -30900,2.1253521,1.3459325,,,,,,,,,,,,,, -31000,3.571402,1.357669,,,,,,,,,,,,,, -31100,2.7354157,1.2205163,,,,,,,,,,,,,, -31200,2.62328,1.3327817,,,,,,,,,,,,,, -31300,2.1224475,1.3017368,,,,,,,,,,,,,, -31400,2.0550823,1.2792872,,,,,,,,,,,,,, -31500,2.5902824,1.3229318,,,,,,,,,,,,,, -31600,2.182172,1.263895,,,,,,,,,,,,,, -31700,2.135242,1.3458292,,,,,,,,,,,,,, -31800,2.142158,1.259104,,,,,,,,,,,,,, -31900,2.401364,1.299761,,,,,,,,,,,,,, -32000,3.3411677,1.2525799,,,,,,,,,,,,,, -32100,3.132139,1.3175293,,,,,,,,,,,,,, -32200,1.9284045,1.2711157,,,,,,,,,,,,,, -32300,2.7926905,1.355851,,,,,,,,,,,,,, -32400,9.101859,1.2800108,,,,,,,,,,,,,, -32452,,,0.2819919,0.0966568077732058,0.48701313,0.1421840756152427,5348.0,0.27825376,0.0895131314362318,2472.0,27381.333554029465,30110.22183823585,27381.333554029465,2726.269201755524,1.1319363117218018,0.0 -32500,2.2394466,1.2871244,,,,,,,,,,,,,, -32600,2.537132,1.357343,,,,,,,,,,,,,, -32700,2.4794261,1.2357883,,,,,,,,,,,,,, -32800,2.2331796,1.2909628,,,,,,,,,,,,,, -32900,2.7541776,1.3392748,,,,,,,,,,,,,, -33000,2.0320563,1.2950008,,,,,,,,,,,,,, -33100,2.168469,1.2806473,,,,,,,,,,,,,, -33200,3.6306167,1.3035023,,,,,,,,,,,,,, -33300,2.3794982,1.2447044,,,,,,,,,,,,,, -33400,5.509019,1.2908214,,,,,,,,,,,,,, -33500,2.096427,1.3497932,,,,,,,,,,,,,, -33600,2.0851233,1.2400727,,,,,,,,,,,,,, -33700,2.1239593,1.2796326,,,,,,,,,,,,,, -33800,3.293143,1.2894499,,,,,,,,,,,,,, -33900,3.3726757,1.2488798,,,,,,,,,,,,,, -34000,1.9663848,1.250814,,,,,,,,,,,,,, -34100,2.3492706,1.2974404,,,,,,,,,,,,,, -34156,,,0.30927896,0.1050533365327981,0.46904442,0.136545758228178,5348.0,0.26531848,0.0852273881339751,2472.0,28821.81571483612,31686.440123796463,28821.81571483612,2861.870703935623,1.1866295337677002,0.0 -34200,2.4755495,1.2209032,,,,,,,,,,,,,, -34300,2.8029287,1.2678019,,,,,,,,,,,,,, -34400,4.068766,1.2623658,,,,,,,,,,,,,, -34500,2.2873638,1.2264225,,,,,,,,,,,,,, -34600,2.969883,1.2919718,,,,,,,,,,,,,, -34700,1.6533363,1.2939698,,,,,,,,,,,,,, -34800,2.070615,1.2053177,,,,,,,,,,,,,, -34900,3.2466223,1.2557935,,,,,,,,,,,,,, -35000,2.2757359,1.206745,,,,,,,,,,,,,, -35100,2.3780708,1.2225885,,,,,,,,,,,,,, -35200,2.2230687,1.2665609,,,,,,,,,,,,,, -35300,3.0512269,1.2673306,,,,,,,,,,,,,, -35400,2.82455,1.2619756,,,,,,,,,,,,,, -35500,3.5065074,1.2659014,,,,,,,,,,,,,, -35600,2.4957416,1.2529639,,,,,,,,,,,,,, -35700,3.8416958,1.2709434,,,,,,,,,,,,,, -35800,2.1403956,1.2479907,,,,,,,,,,,,,, -35850,,,0.26929548,0.0921024489359426,0.45808133,0.1342286414937679,5348.0,0.25412038,0.0814494343225072,2472.0,30262.3197350502,33263.54372572899,30262.3197350502,2998.3368847370148,1.2433674335479736,0.0 -35900,2.393849,1.2082772,,,,,,,,,,,,,, -36000,2.4258015,1.2456541,,,,,,,,,,,,,, -36100,3.4983962,1.2173994,,,,,,,,,,,,,, -36200,2.4286172,1.2129378,,,,,,,,,,,,,, -36300,3.1184251,1.2185379,,,,,,,,,,,,,, -36400,3.5991917,1.2746872,,,,,,,,,,,,,, -36500,2.9399252,1.2192146,,,,,,,,,,,,,, -36600,3.3523903,1.2475163,,,,,,,,,,,,,, -36700,2.8130684,1.2276661,,,,,,,,,,,,,, -36800,2.8824215,1.2578009,,,,,,,,,,,,,, -36900,2.2079868,1.2467712,,,,,,,,,,,,,, -37000,2.7277298,1.1357498,,,,,,,,,,,,,, -37100,2.5574315,1.2106067,,,,,,,,,,,,,, -37200,6.005655,1.2369031,,,,,,,,,,,,,, -37300,2.404818,1.261654,,,,,,,,,,,,,, -37400,2.700865,1.1680797,,,,,,,,,,,,,, -37500,2.891621,1.2228297,,,,,,,,,,,,,, -37571,,,0.2542912,0.0884822202003802,0.44569063,0.1310136420247738,5348.0,0.24617566,0.0782808278999857,2472.0,31702.28162097931,34838.69946312904,31702.28162097931,3133.3894832134247,1.3041977882385254,0.0 -37600,2.676945,1.1863632,,,,,,,,,,,,,, -37700,2.6018465,1.2253697,,,,,,,,,,,,,, -37800,2.7679577,1.1900486,,,,,,,,,,,,,, -37900,3.7342675,1.2676926,,,,,,,,,,,,,, -38000,3.2692804,1.2387519,,,,,,,,,,,,,, -38100,3.0338318,1.192773,,,,,,,,,,,,,, -38200,2.4304845,1.2246798,,,,,,,,,,,,,, -38300,2.728959,1.2183845,,,,,,,,,,,,,, -38400,2.1525912,1.1580957,,,,,,,,,,,,,, -38500,2.3435485,1.2113005,,,,,,,,,,,,,, -38600,4.1523795,1.2208415,,,,,,,,,,,,,, -38700,4.089396,1.202893,,,,,,,,,,,,,, -38800,3.5341363,1.1746325,,,,,,,,,,,,,, -38900,4.602003,1.2076818,,,,,,,,,,,,,, -39000,2.6622348,1.1892259,,,,,,,,,,,,,, -39100,6.2077837,1.1821234,,,,,,,,,,,,,, -39200,2.5431294,1.1517534,,,,,,,,,,,,,, -39261,,,0.2523249,0.0850499358672177,0.4367461,0.1274703843517383,5348.0,0.2394587,0.0759653078219893,2472.0,33142.84040021896,36414.29718565941,33142.84040021896,3268.299718618393,1.3554682731628418,0.0 -39300,3.3637698,1.1624675,,,,,,,,,,,,,, -39400,2.702959,1.1779531,,,,,,,,,,,,,, -39500,2.6333468,1.1691664,,,,,,,,,,,,,, -39600,5.3447456,1.1342647,,,,,,,,,,,,,, -39700,2.4886076,1.1916815,,,,,,,,,,,,,, -39800,2.134258,1.2049751,,,,,,,,,,,,,, -39900,2.787912,1.1629165,,,,,,,,,,,,,, -40000,2.773001,1.17578,,,,,,,,,,,,,, -40100,3.6937215,1.2227911,,,,,,,,,,,,,, -40200,1.955499,1.13413,,,,,,,,,,,,,, -40300,2.4813845,1.1384202,,,,,,,,,,,,,, -40400,2.2996397,1.1373158,,,,,,,,,,,,,, -40500,4.453646,1.1673594,,,,,,,,,,,,,, -40600,2.498249,1.1188834,,,,,,,,,,,,,, -40700,3.9420717,1.18137,,,,,,,,,,,,,, -40800,2.3730862,1.1516119,,,,,,,,,,,,,, -40900,3.132842,1.1614995,,,,,,,,,,,,,, -40968,,,0.24387786,0.0842229176835395,0.42268944,0.1245064058623053,5348.0,0.23126832,0.0734466719476773,2472.0,34583.702951192856,37992.941672325134,34583.702951192856,3405.947791576385,1.4116151332855225,0.0 -41000,2.5212834,1.1580217,,,,,,,,,,,,,, -41100,3.9036744,1.145086,,,,,,,,,,,,,, -41200,2.793451,1.0988559,,,,,,,,,,,,,, -41300,3.6768522,1.1800252,,,,,,,,,,,,,, -41400,4.055247,1.1632297,,,,,,,,,,,,,, -41500,2.5756483,1.1443627,,,,,,,,,,,,,, -41600,2.4246778,1.1228931,,,,,,,,,,,,,, -41700,2.4690118,1.1575441,,,,,,,,,,,,,, -41800,2.7376394,1.1487244,,,,,,,,,,,,,, -41900,1.9773506,1.1544313,,,,,,,,,,,,,, -42000,3.580573,1.1472776,,,,,,,,,,,,,, -42100,2.3695064,1.1209028,,,,,,,,,,,,,, -42200,2.169464,1.1565228,,,,,,,,,,,,,, -42300,4.0542107,1.1938447,,,,,,,,,,,,,, -42400,3.2134957,1.1659063,,,,,,,,,,,,,, -42500,3.0182507,1.0917549,,,,,,,,,,,,,, -42600,2.3448312,1.1859977,,,,,,,,,,,,,, -42693,,,0.24088885,0.0808483381965696,0.41713747,0.1221410158625949,5348.0,0.22636868,0.0724107813864684,2472.0,36023.71224784851,39566.42193007469,36023.71224784851,3539.2807710170746,1.4675979614257812,0.0 -42700,2.4004693,1.1309664,,,,,,,,,,,,,, -42800,2.2218375,1.1582013,,,,,,,,,,,,,, -42900,3.3266678,1.1080992,,,,,,,,,,,,,, -43000,2.7350075,1.1253079,,,,,,,,,,,,,, -43100,2.2343695,1.1461731,,,,,,,,,,,,,, -43200,2.7941165,1.1176615,,,,,,,,,,,,,, -43300,2.3758733,1.0347769,,,,,,,,,,,,,, -43400,3.1213374,1.1674988,,,,,,,,,,,,,, -43500,2.3518646,1.1216402,,,,,,,,,,,,,, -43600,3.0689685,1.1757021,,,,,,,,,,,,,, -43700,6.4670215,1.1420267,,,,,,,,,,,,,, -43800,2.0571964,1.1269064,,,,,,,,,,,,,, -43900,2.324975,1.0611401,,,,,,,,,,,,,, -44000,3.3033922,1.0653503,,,,,,,,,,,,,, -44100,3.5650377,1.1626825,,,,,,,,,,,,,, -44200,2.6505542,1.1112169,,,,,,,,,,,,,, -44300,3.1096442,1.060209,,,,,,,,,,,,,, -44390,,,0.21339326,0.0741663739706858,0.4088802,0.1195439141894436,5348.0,0.22205788,0.070765543436313,2472.0,37463.88521766663,41141.70326757431,37463.88521766663,3674.257806539536,1.5213351249694824,0.0 -44400,2.033157,1.0495491,,,,,,,,,,,,,, -44500,4.454913,1.1238507,,,,,,,,,,,,,, -44600,2.604383,1.105651,,,,,,,,,,,,,, -44700,3.1065679,1.1132025,,,,,,,,,,,,,, -44800,2.2921588,1.1374612,,,,,,,,,,,,,, -44900,2.901928,1.1421179,,,,,,,,,,,,,, -45000,2.587357,1.0960802,,,,,,,,,,,,,, -45100,2.8146982,1.0926727,,,,,,,,,,,,,, -45200,2.712977,1.0873238,,,,,,,,,,,,,, -45300,2.8752491,1.1018504,,,,,,,,,,,,,, -45400,3.0947015,1.0741367,,,,,,,,,,,,,, -45500,3.9815524,1.0953875,,,,,,,,,,,,,, -45600,4.030051,1.1222174,,,,,,,,,,,,,, -45700,3.0099156,1.1860784,,,,,,,,,,,,,, -45800,2.6202154,1.1374451,,,,,,,,,,,,,, -45900,2.6312685,1.1581925,,,,,,,,,,,,,, -46000,2.6492956,1.1899725,,,,,,,,,,,,,, -46095,,,0.21221867,0.0744158716991566,0.4071756,0.1189067070874808,5348.0,0.22028401,0.0703390002640505,2472.0,38904.08066558838,42714.010046720505,38904.08066558838,3806.234555482864,1.577507495880127,0.0 -46100,3.2896132,1.0936987,,,,,,,,,,,,,, -46200,2.5250964,1.1669449,,,,,,,,,,,,,, -46300,3.4525852,1.1270602,,,,,,,,,,,,,, -46400,2.4617646,1.1094006,,,,,,,,,,,,,, -46500,2.8580065,1.0940261,,,,,,,,,,,,,, -46600,5.0300403,1.0896337,,,,,,,,,,,,,, -46700,2.52104,1.1481966,,,,,,,,,,,,,, -46800,5.493582,1.1445172,,,,,,,,,,,,,, -46900,2.7925463,1.1371826,,,,,,,,,,,,,, -47000,2.08541,1.1084725,,,,,,,,,,,,,, -47100,2.4104853,1.1260891,,,,,,,,,,,,,, -47200,2.667624,1.0858552,,,,,,,,,,,,,, -47300,2.8599634,1.1412953,,,,,,,,,,,,,, -47400,2.8153806,1.1326326,,,,,,,,,,,,,, -47500,2.284735,1.1456435,,,,,,,,,,,,,, -47600,5.589726,1.1463665,,,,,,,,,,,,,, -47700,2.4260638,1.108376,,,,,,,,,,,,,, -47800,2.7531497,1.1238908,,,,,,,,,,,,,, -47825,,,0.20782495,0.0720470484289579,0.40657282,0.1190901454956216,5348.0,0.2197189,0.0701155728881035,2472.0,40343.99628758431,44288.75029063225,40343.99628758431,3940.92617559433,1.6310737133026123,0.0 -47900,2.635236,1.1020505,,,,,,,,,,,,,, -48000,,,0.21455996,0.0742766328077279,0.40663552,0.1190998001486816,5348.0,0.2197332,0.0700749497288404,2472.0,40475.96862196922,44544.644236803055,40475.96862196922,4064.773812055588,1.6917657852172852,0.0 -48000,,,,,,,,,,,40475.96862196922,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/eval_measurements.csv deleted file mode 100644 index ba3258140..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -173.78581547737122,0.0,15.822179079055786,1,0,15.822179079055786,30.657564,2472,3.231897304653383,189.60808277130127,32.005363,2.979481132075472,30.57026,5348,2.911621305888373 -307.7569811344147,0.0297317504882812,1455.7558295726776,1705,0,1455.7558295726776,1.3494298,2472,0.3802530822822091,1763.6183669567108,1.7329582,0.4562996089686897,1.7820381,5348,0.4402135609256881 -439.76512384414673,0.0895667076110839,2896.13534617424,3421,0,2896.13534617424,0.7284901,2472,0.2250523023175512,3336.1467773914337,0.9513836,0.2833366333833055,1.08201,5348,0.2937428193517866 -573.4874782562256,0.1315174102783203,4337.026045560837,5110,0,4337.026045560837,0.64543873,2472,0.2012674425690085,4910.876049280167,0.8250965,0.2466646851028317,0.96908915,5348,0.2676849107427325 -710.4288213253021,0.1832258701324463,5777.149932384491,6803,0,5777.149932384491,0.6189214,2472,0.1907866674791298,6488.072125196457,0.8884563,0.267509416380854,0.9415473,5348,0.2631858424167527 -846.5350136756897,0.2336657047271728,7217.49794960022,8533,0,7217.49794960022,0.57711726,2472,0.1780106838908862,8064.655959129333,0.7801244,0.2347124218934048,0.90787,5348,0.2515809494385819 -983.403307914734,0.3625540733337402,8657.429280042648,10220,0,8657.429280042648,0.5482055,2472,0.1682814372473747,9641.66242980957,0.78038126,0.2305007804084504,0.8497801,5348,0.2387499155217857 -1120.4950244426727,0.4233393669128418,10097.492525815964,11925,0,10097.492525815964,0.5307556,2472,0.1668190035139032,11218.957754611967,0.6961746,0.213959671084794,0.8327891,5348,0.2342411925427459 -1255.964623451233,0.4713277816772461,11537.983419895172,13649,0,11537.983419895172,0.49814272,2472,0.1586334369223894,12795.047011137009,0.68453616,0.2114985991207489,0.8041042,5348,0.2259961188294698 -1393.445172548294,0.5251693725585938,12978.61863040924,15342,0,12978.61863040924,0.49917656,2472,0.1561554242073406,14373.29497885704,0.68597907,0.2103933557639641,0.7896287,5348,0.2207439875648068 -1539.2309548854828,0.5735268592834473,14418.936494112017,17076,0,14418.936494112017,0.4785223,2472,0.1510978408790851,15959.526646614077,0.43890128,0.1444019164369262,0.7677971,5348,0.2151732527491624 -1674.8570148944857,0.6237502098083496,15859.42936182022,18793,0,15859.42936182022,0.45743635,2472,0.1456137143785672,17535.774045705795,0.4114516,0.1333188679830135,0.739945,5348,0.2078164071174102 -1812.569813489914,0.6772727966308594,17300.21093249321,20505,0,17300.21093249321,0.43603286,2472,0.1374078362074218,19114.40056157112,0.38734055,0.129115201993911,0.71387607,5348,0.2016277745059231 -1949.4109256267548,0.7298746109008789,18740.341069221497,22215,0,18740.341069221497,0.41533288,2472,0.1301566022789592,20691.503726243973,0.3631994,0.1221305633005597,0.67903066,5348,0.1932861542620466 -2085.8726336956024,0.7837283611297607,20180.55277776718,23940,0,20180.55277776718,0.39538053,2472,0.1259521052952288,22268.30971980095,0.357722,0.1192458582972644,0.66510963,5348,0.190418722303214 -2221.237764120102,0.8373799324035645,21620.54094862938,25627,0,21620.54094862938,0.3858443,2472,0.1221132167448662,23843.795273303986,0.327648,0.1109453339589517,0.6450211,5348,0.1839114861407455 -2356.1602787971497,0.893054723739624,23060.78440976143,27322,0,23060.78440976143,0.36358392,2472,0.1171571913147685,25419.094624519348,0.35698888,0.115199530197328,0.6116994,5348,0.1745947459378047 -2494.747106552124,0.9459781646728516,24500.67453694344,29032,0,24500.67453694344,0.3473988,2472,0.1140495196311417,26997.705008029938,0.30421615,0.1036155133015636,0.59509784,5348,0.1709066684688685 -2632.291247367859,0.9961137771606444,25941.396134853363,30728,0,25941.396134853363,0.3357012,2472,0.1066967278045213,28576.09837579727,0.26352453,0.0905633455560276,0.5693506,5348,0.1632891472044952 -2767.3991503715515,1.0550148487091064,27382.316435575485,32447,0,27382.316435575485,0.31943408,2472,0.1029187739930534,30152.26672244072,0.25451177,0.086225316926675,0.54807425,5348,0.1576797937766106 -2903.598759412765,1.106881618499756,28822.94229197502,34160,0,28822.94229197502,0.30711773,2472,0.0979627485629557,31729.22262334824,0.2348298,0.0816698966408268,0.53372824,5348,0.1530745242669704 -3038.3013293743134,1.1584124565124512,30263.509187936783,35863,0,30263.509187936783,0.28727466,2472,0.091747405195702,33304.622517347336,0.23932621,0.0797680012210462,0.5034849,5348,0.144308099288452 -3176.7409851551056,1.2146975994110107,31703.80317473412,37591,0,31703.80317473412,0.27169424,2472,0.0879491398046026,34883.49290180206,0.21954021,0.0745288632163148,0.4808648,5348,0.1389304575340085 -3313.524995803833,1.2724831104278564,33143.80793213844,39304,0,33143.80793213844,0.2551368,2472,0.0820790932910852,36460.42004203797,0.19475468,0.0674533633356491,0.45533058,5348,0.131293626963515 -3448.9083466529846,1.3277525901794434,34583.80059170723,41016,0,34583.80059170723,0.24059556,2472,0.0772043141795137,38035.92982244492,0.18020418,0.0614644881696766,0.43848684,5348,0.126273207372293 -3583.6617407798767,1.3814823627471924,36024.716413497925,42751,0,36024.716413497925,0.23369418,2472,0.0746247435663071,39611.73369860649,0.15269303,0.0525743253947298,0.4185573,5348,0.1192349652915222 -3719.5914764404297,1.4356229305267334,37464.84016394615,44456,0,37464.84016394615,0.22438748,2472,0.0719029918956797,41187.92023229599,0.15371695,0.0517177970517024,0.4076263,5348,0.1159234192919277 -3854.6277837753296,1.4938104152679443,38905.040621995926,46165,0,38905.040621995926,0.22044119,2472,0.0706233623788922,42763.2930085659,0.16266528,0.0549525944734069,0.40298778,5348,0.114620041128822 -3990.876124382019,1.5495550632476809,40345.54191684723,47904,0,40345.54191684723,0.21938264,2472,0.0707046086974184,44340.17746424675,0.14715919,0.0487956127626761,0.40185603,5348,0.1147841702308427 -4121.736053466797,1.6115646362304688,40417.566965818405,48000,0,40417.566965818405,0.21940319,2472,0.0706639855381553,44543.13734269142,0.1594734,0.054616546975526925,0.401908,5348,0.11480347953696284 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/measurements.csv deleted file mode 100644 index ef7af3bef..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,19.243675,33.472027,,,,,,,,,,,,,, -1,,,32.005363,2.979481132075472,30.57026,2.911621305888373,5348.0,30.657564,3.231897304653383,2472.0,15.822179079055786,189.60808277130127,15.822179079055786,173.78581547737122,0.0,0.0 -100,1.1079513,5.760642,,,,,,,,,,,,,, -200,4.900049,5.1235147,,,,,,,,,,,,,, -300,2.6073303,3.7083008,,,,,,,,,,,,,, -400,3.928571,3.237816,,,,,,,,,,,,,, -500,1.9085493,2.9476702,,,,,,,,,,,,,, -600,3.147137,2.8430276,,,,,,,,,,,,,, -700,4.634906,2.8220923,,,,,,,,,,,,,, -800,2.8927975,2.7382066,,,,,,,,,,,,,, -900,2.7479181,2.6702375,,,,,,,,,,,,,, -1000,3.3134794,2.6732452,,,,,,,,,,,,,, -1100,2.3523486,2.531987,,,,,,,,,,,,,, -1200,2.0655124,2.4331582,,,,,,,,,,,,,, -1300,3.4064422,2.4035723,,,,,,,,,,,,,, -1400,1.8196212,2.292174,,,,,,,,,,,,,, -1500,1.6934532,2.2981863,,,,,,,,,,,,,, -1600,2.9950223,2.2358115,,,,,,,,,,,,,, -1700,2.8711288,2.2990398,,,,,,,,,,,,,, -1705,,,1.7329582,0.4562996089686897,1.7820381,0.4402135609256881,5348.0,1.3494298,0.3802530822822091,2472.0,1455.7558295726776,1763.6183669567108,1455.7558295726776,307.7569811344147,0.0297317504882812,0.0 -1800,3.2318988,2.2521794,,,,,,,,,,,,,, -1900,2.3761935,2.2239585,,,,,,,,,,,,,, -2000,3.7848012,2.2072997,,,,,,,,,,,,,, -2100,2.6247985,2.2383907,,,,,,,,,,,,,, -2200,2.564129,2.2393994,,,,,,,,,,,,,, -2300,2.1024277,2.1923149,,,,,,,,,,,,,, -2400,5.0031595,2.1481915,,,,,,,,,,,,,, -2500,2.8083682,2.152179,,,,,,,,,,,,,, -2600,2.4666717,2.080598,,,,,,,,,,,,,, -2700,2.7932496,2.1028209,,,,,,,,,,,,,, -2800,5.116051,2.0830543,,,,,,,,,,,,,, -2900,2.754926,2.082984,,,,,,,,,,,,,, -3000,4.919745,2.031355,,,,,,,,,,,,,, -3100,3.7268841,2.0427995,,,,,,,,,,,,,, -3200,3.2733843,2.0406046,,,,,,,,,,,,,, -3300,14.0865,2.0546212,,,,,,,,,,,,,, -3400,4.908954,2.0035703,,,,,,,,,,,,,, -3421,,,0.9513836,0.2833366333833055,1.08201,0.2937428193517866,5348.0,0.7284901,0.2250523023175512,2472.0,2896.13534617424,3336.1467773914337,2896.13534617424,439.76512384414673,0.0895667076110839,0.0 -3500,3.7394092,2.019057,,,,,,,,,,,,,, -3600,3.3099165,2.0895677,,,,,,,,,,,,,, -3700,4.8726954,2.0615792,,,,,,,,,,,,,, -3800,3.1523993,2.1219704,,,,,,,,,,,,,, -3900,4.6662974,2.062872,,,,,,,,,,,,,, -4000,2.910119,2.0120392,,,,,,,,,,,,,, -4100,3.1303012,2.0403202,,,,,,,,,,,,,, -4200,2.893442,2.0187254,,,,,,,,,,,,,, -4300,3.278835,2.0222604,,,,,,,,,,,,,, -4400,2.123718,1.9753394,,,,,,,,,,,,,, -4500,3.6106174,2.0630584,,,,,,,,,,,,,, -4600,12.722773,2.029737,,,,,,,,,,,,,, -4700,8.384468,2.4478993,,,,,,,,,,,,,, -4800,2.2864404,2.0189424,,,,,,,,,,,,,, -4900,2.8476763,1.9673331,,,,,,,,,,,,,, -5000,2.2325144,2.080452,,,,,,,,,,,,,, -5100,3.2607605,2.0567863,,,,,,,,,,,,,, -5110,,,0.8250965,0.2466646851028317,0.96908915,0.2676849107427325,5348.0,0.64543873,0.2012674425690085,2472.0,4337.026045560837,4910.876049280167,4337.026045560837,573.4874782562256,0.1315174102783203,0.0 -5200,1.5799602,1.9307587,,,,,,,,,,,,,, -5300,2.0802479,1.9317505,,,,,,,,,,,,,, -5400,2.6910255,2.0089154,,,,,,,,,,,,,, -5500,3.1400692,2.0021982,,,,,,,,,,,,,, -5600,3.3728042,2.0042503,,,,,,,,,,,,,, -5700,3.2475395,1.9750023,,,,,,,,,,,,,, -5800,3.2797067,1.9764079,,,,,,,,,,,,,, -5900,2.4889135,1.9400826,,,,,,,,,,,,,, -6000,7.348606,1.996896,,,,,,,,,,,,,, -6100,2.8697531,1.9266456,,,,,,,,,,,,,, -6200,2.634992,1.8839086,,,,,,,,,,,,,, -6300,2.8511696,1.8889843,,,,,,,,,,,,,, -6400,2.6579037,1.9595851,,,,,,,,,,,,,, -6500,2.7395597,2.0359435,,,,,,,,,,,,,, -6600,2.5789845,1.9498347,,,,,,,,,,,,,, -6700,2.6732726,2.0011144,,,,,,,,,,,,,, -6800,1.9441117,1.9194064,,,,,,,,,,,,,, -6803,,,0.8884563,0.267509416380854,0.9415473,0.2631858424167527,5348.0,0.6189214,0.1907866674791298,2472.0,5777.149932384491,6488.072125196457,5777.149932384491,710.4288213253021,0.1832258701324463,0.0 -6900,4.4257607,1.9770205,,,,,,,,,,,,,, -7000,2.7620707,1.9461744,,,,,,,,,,,,,, -7100,3.8432124,1.9008195,,,,,,,,,,,,,, -7200,3.6582766,1.9379456,,,,,,,,,,,,,, -7300,2.2383463,1.9717442,,,,,,,,,,,,,, -7400,3.141465,1.917374,,,,,,,,,,,,,, -7500,2.5700407,1.9017433,,,,,,,,,,,,,, -7600,4.5234284,1.8666531,,,,,,,,,,,,,, -7700,3.8023598,1.8812625,,,,,,,,,,,,,, -7800,2.232039,1.9355519,,,,,,,,,,,,,, -7900,3.4233224,1.8847792,,,,,,,,,,,,,, -8000,5.1613545,1.9177169,,,,,,,,,,,,,, -8100,2.3574393,1.8838178,,,,,,,,,,,,,, -8200,3.3437636,1.8827922,,,,,,,,,,,,,, -8300,4.2190585,1.9847661,,,,,,,,,,,,,, -8400,3.522221,1.8875803,,,,,,,,,,,,,, -8500,3.4254408,1.8351208,,,,,,,,,,,,,, -8533,,,0.7801244,0.2347124218934048,0.90787,0.2515809494385819,5348.0,0.57711726,0.1780106838908862,2472.0,7217.49794960022,8064.655959129333,7217.49794960022,846.5350136756897,0.2336657047271728,0.0 -8600,2.0373607,1.8327141,,,,,,,,,,,,,, -8700,2.1770551,1.8229849,,,,,,,,,,,,,, -8800,2.687739,1.8451507,,,,,,,,,,,,,, -8900,3.2154098,1.8724537,,,,,,,,,,,,,, -9000,2.552245,1.8747267,,,,,,,,,,,,,, -9100,9.052183,1.8864915,,,,,,,,,,,,,, -9200,1.7337558,1.8982313,,,,,,,,,,,,,, -9300,2.0396671,1.9132066,,,,,,,,,,,,,, -9400,2.2185273,1.9307549,,,,,,,,,,,,,, -9500,3.4950955,1.832586,,,,,,,,,,,,,, -9600,2.9822571,1.8337988,,,,,,,,,,,,,, -9700,3.6931746,1.9153365,,,,,,,,,,,,,, -9800,2.2657661,1.8768601,,,,,,,,,,,,,, -9900,3.1431623,1.7783741,,,,,,,,,,,,,, -10000,3.4998498,1.7867204,,,,,,,,,,,,,, -10100,2.746277,1.8156894,,,,,,,,,,,,,, -10200,2.7286315,1.8040047,,,,,,,,,,,,,, -10220,,,0.78038126,0.2305007804084504,0.8497801,0.2387499155217857,5348.0,0.5482055,0.1682814372473747,2472.0,8657.429280042648,9641.66242980957,8657.429280042648,983.403307914734,0.3625540733337402,0.0 -10300,2.514994,1.8834136,,,,,,,,,,,,,, -10400,3.8975296,1.8710234,,,,,,,,,,,,,, -10500,2.1352577,1.7622215,,,,,,,,,,,,,, -10600,2.0198648,1.761915,,,,,,,,,,,,,, -10700,4.383542,1.8197255,,,,,,,,,,,,,, -10800,3.6963716,1.852248,,,,,,,,,,,,,, -10900,4.818033,1.8891618,,,,,,,,,,,,,, -11000,2.746999,1.7551352,,,,,,,,,,,,,, -11100,2.7497706,1.839093,,,,,,,,,,,,,, -11200,5.181009,1.769002,,,,,,,,,,,,,, -11300,2.2258813,1.8098443,,,,,,,,,,,,,, -11400,3.3542123,1.8244523,,,,,,,,,,,,,, -11500,1.8577198,1.8406484,,,,,,,,,,,,,, -11600,2.194944,1.8393677,,,,,,,,,,,,,, -11700,4.5585794,1.8056606,,,,,,,,,,,,,, -11800,2.2787905,1.863144,,,,,,,,,,,,,, -11900,3.0868196,1.7939044,,,,,,,,,,,,,, -11925,,,0.6961746,0.213959671084794,0.8327891,0.2342411925427459,5348.0,0.5307556,0.1668190035139032,2472.0,10097.492525815964,11218.957754611967,10097.492525815964,1120.4950244426727,0.4233393669128418,0.0 -12000,2.4238389,1.7805461,,,,,,,,,,,,,, -12100,3.689647,1.7690079,,,,,,,,,,,,,, -12200,3.0902932,1.8790984,,,,,,,,,,,,,, -12300,3.3023121,1.8426224,,,,,,,,,,,,,, -12400,3.050288,1.7830137,,,,,,,,,,,,,, -12500,3.5507836,1.7596418,,,,,,,,,,,,,, -12600,3.5854714,1.8103663,,,,,,,,,,,,,, -12700,2.3093166,1.7817591,,,,,,,,,,,,,, -12800,3.3575637,1.8567857,,,,,,,,,,,,,, -12900,2.9030664,1.7713838,,,,,,,,,,,,,, -13000,3.58052,1.7640415,,,,,,,,,,,,,, -13100,2.7499292,1.7695448,,,,,,,,,,,,,, -13200,1.9718301,1.8774377,,,,,,,,,,,,,, -13300,2.2932694,1.7926953,,,,,,,,,,,,,, -13400,2.2949336,1.8015667,,,,,,,,,,,,,, -13500,3.2902622,1.8337576,,,,,,,,,,,,,, -13600,3.4868417,1.8097667,,,,,,,,,,,,,, -13649,,,0.68453616,0.2114985991207489,0.8041042,0.2259961188294698,5348.0,0.49814272,0.1586334369223894,2472.0,11537.983419895172,12795.047011137009,11537.983419895172,1255.964623451233,0.4713277816772461,0.0 -13700,2.1976721,1.7841864,,,,,,,,,,,,,, -13800,1.8376575,1.7670693,,,,,,,,,,,,,, -13900,4.076993,1.8694824,,,,,,,,,,,,,, -14000,2.167554,1.8169954,,,,,,,,,,,,,, -14100,2.2042334,1.7620928,,,,,,,,,,,,,, -14200,3.5798538,1.7990091,,,,,,,,,,,,,, -14300,5.36412,1.8143514,,,,,,,,,,,,,, -14400,2.8972325,1.767251,,,,,,,,,,,,,, -14500,3.1821823,1.782392,,,,,,,,,,,,,, -14600,2.8984427,1.7308048,,,,,,,,,,,,,, -14700,3.6189287,1.7696668,,,,,,,,,,,,,, -14800,2.2648494,1.7300668,,,,,,,,,,,,,, -14900,5.89981,1.8489611,,,,,,,,,,,,,, -15000,3.2960153,1.8261812,,,,,,,,,,,,,, -15100,3.6878502,1.7270893,,,,,,,,,,,,,, -15200,2.3152757,1.7514259,,,,,,,,,,,,,, -15300,4.677623,1.7866586,,,,,,,,,,,,,, -15342,,,0.68597907,0.2103933557639641,0.7896287,0.2207439875648068,5348.0,0.49917656,0.1561554242073406,2472.0,12978.61863040924,14373.29497885704,12978.61863040924,1393.445172548294,0.5251693725585938,0.0 -15400,2.5413692,1.7444775,,,,,,,,,,,,,, -15500,4.273728,1.8192728,,,,,,,,,,,,,, -15600,2.4807615,1.7831454,,,,,,,,,,,,,, -15700,2.351215,1.7809316,,,,,,,,,,,,,, -15800,2.7910707,1.8270891,,,,,,,,,,,,,, -15900,2.448021,1.7716064,,,,,,,,,,,,,, -16000,3.3948026,1.8089806,,,,,,,,,,,,,, -16100,3.6492767,1.8314867,,,,,,,,,,,,,, -16200,4.836199,1.7613004,,,,,,,,,,,,,, -16300,4.025828,1.8392231,,,,,,,,,,,,,, -16400,2.2675307,1.7580874,,,,,,,,,,,,,, -16500,2.4825222,1.6945443,,,,,,,,,,,,,, -16600,3.2059557,1.7654074,,,,,,,,,,,,,, -16700,2.25285,1.77363,,,,,,,,,,,,,, -16800,2.662095,1.7425523,,,,,,,,,,,,,, -16900,1.7266736,1.6931723,,,,,,,,,,,,,, -17000,2.8012972,1.6573769,,,,,,,,,,,,,, -17076,,,0.43890128,0.1444019164369262,0.7677971,0.2151732527491624,5348.0,0.4785223,0.1510978408790851,2472.0,14418.936494112017,15959.526646614077,14418.936494112017,1539.2309548854828,0.5735268592834473,0.0 -17100,3.2533562,1.7562304,,,,,,,,,,,,,, -17200,1.8408694,1.7151008,,,,,,,,,,,,,, -17300,4.950689,1.7943717,,,,,,,,,,,,,, -17400,1.8637385,1.6982554,,,,,,,,,,,,,, -17500,1.8382027,1.7303191,,,,,,,,,,,,,, -17600,2.5975997,1.6996214,,,,,,,,,,,,,, -17700,2.2133913,1.8133644,,,,,,,,,,,,,, -17800,2.4302883,1.738193,,,,,,,,,,,,,, -17900,2.4797769,1.7246398,,,,,,,,,,,,,, -18000,4.9507527,1.7259446,,,,,,,,,,,,,, -18100,3.6104417,1.7259597,,,,,,,,,,,,,, -18200,2.1016583,1.748288,,,,,,,,,,,,,, -18300,1.6431315,1.738276,,,,,,,,,,,,,, -18400,2.548339,1.6438773,,,,,,,,,,,,,, -18500,3.0521433,1.7618401,,,,,,,,,,,,,, -18600,2.5252297,1.6546375,,,,,,,,,,,,,, -18700,3.0612645,1.6894758,,,,,,,,,,,,,, -18793,,,0.4114516,0.1333188679830135,0.739945,0.2078164071174102,5348.0,0.45743635,0.1456137143785672,2472.0,15859.42936182022,17535.774045705795,15859.42936182022,1674.8570148944857,0.6237502098083496,0.0 -18800,2.9631388,1.7168458,,,,,,,,,,,,,, -18900,2.7329166,1.7132831,,,,,,,,,,,,,, -19000,2.4912388,1.6824417,,,,,,,,,,,,,, -19100,2.7649603,1.7384468,,,,,,,,,,,,,, -19200,3.244262,1.7174214,,,,,,,,,,,,,, -19300,3.2906408,1.605741,,,,,,,,,,,,,, -19400,2.497279,1.6766002,,,,,,,,,,,,,, -19500,1.7976153,1.6038837,,,,,,,,,,,,,, -19600,3.0327055,1.7038418,,,,,,,,,,,,,, -19700,1.663602,1.7446885,,,,,,,,,,,,,, -19800,2.767272,1.6141092,,,,,,,,,,,,,, -19900,5.27775,1.730458,,,,,,,,,,,,,, -20000,3.5759373,1.658571,,,,,,,,,,,,,, -20100,1.9925082,1.6511077,,,,,,,,,,,,,, -20200,2.1526513,1.6362407,,,,,,,,,,,,,, -20300,2.6714494,1.6546698,,,,,,,,,,,,,, -20400,3.436356,1.6515875,,,,,,,,,,,,,, -20500,2.061499,1.7551378,,,,,,,,,,,,,, -20505,,,0.38734055,0.129115201993911,0.71387607,0.2016277745059231,5348.0,0.43603286,0.1374078362074218,2472.0,17300.21093249321,19114.40056157112,17300.21093249321,1812.569813489914,0.6772727966308594,0.0 -20600,2.6146774,1.5984042,,,,,,,,,,,,,, -20700,2.233839,1.6479061,,,,,,,,,,,,,, -20800,2.57188,1.6394619,,,,,,,,,,,,,, -20900,2.2228475,1.6426651,,,,,,,,,,,,,, -21000,2.0044441,1.6231892,,,,,,,,,,,,,, -21100,4.1916823,1.7122731,,,,,,,,,,,,,, -21200,2.5572267,1.6566057,,,,,,,,,,,,,, -21300,3.3909807,1.6660593,,,,,,,,,,,,,, -21400,2.7563155,1.6792008,,,,,,,,,,,,,, -21500,1.6719413,1.6134294,,,,,,,,,,,,,, -21600,2.3817809,1.661626,,,,,,,,,,,,,, -21700,2.6780024,1.5997833,,,,,,,,,,,,,, -21800,2.5664136,1.6633072,,,,,,,,,,,,,, -21900,2.0462582,1.6333566,,,,,,,,,,,,,, -22000,2.8689585,1.6762934,,,,,,,,,,,,,, -22100,1.7744571,1.6534402,,,,,,,,,,,,,, -22200,1.3200102,1.6379923,,,,,,,,,,,,,, -22215,,,0.3631994,0.1221305633005597,0.67903066,0.1932861542620466,5348.0,0.41533288,0.1301566022789592,2472.0,18740.341069221497,20691.503726243973,18740.341069221497,1949.4109256267548,0.7298746109008789,0.0 -22300,5.263276,1.603741,,,,,,,,,,,,,, -22400,2.027572,1.6347165,,,,,,,,,,,,,, -22500,1.9091636,1.648654,,,,,,,,,,,,,, -22600,1.6334373,1.6323955,,,,,,,,,,,,,, -22700,3.25098,1.5747235,,,,,,,,,,,,,, -22800,2.0867915,1.5910826,,,,,,,,,,,,,, -22900,2.918485,1.5826871,,,,,,,,,,,,,, -23000,2.3231149,1.5882963,,,,,,,,,,,,,, -23100,3.3896284,1.6313635,,,,,,,,,,,,,, -23200,2.4056864,1.5663459,,,,,,,,,,,,,, -23300,2.195092,1.5707726,,,,,,,,,,,,,, -23400,2.6937761,1.579089,,,,,,,,,,,,,, -23500,6.331902,1.6400034,,,,,,,,,,,,,, -23600,3.3331993,1.5340675,,,,,,,,,,,,,, -23700,3.051637,1.6267713,,,,,,,,,,,,,, -23800,1.6866243,1.5748976,,,,,,,,,,,,,, -23900,2.1610227,1.6115123,,,,,,,,,,,,,, -23940,,,0.357722,0.1192458582972644,0.66510963,0.190418722303214,5348.0,0.39538053,0.1259521052952288,2472.0,20180.55277776718,22268.30971980095,20180.55277776718,2085.8726336956024,0.7837283611297607,0.0 -24000,3.3651295,1.588162,,,,,,,,,,,,,, -24100,2.1322906,1.5938941,,,,,,,,,,,,,, -24200,1.8343744,1.5974157,,,,,,,,,,,,,, -24300,2.591088,1.5186175,,,,,,,,,,,,,, -24400,1.5723505,1.5448824,,,,,,,,,,,,,, -24500,3.9902868,1.67243,,,,,,,,,,,,,, -24600,2.0071144,1.6096518,,,,,,,,,,,,,, -24700,1.9499764,1.6399935,,,,,,,,,,,,,, -24800,2.2638679,1.5927535,,,,,,,,,,,,,, -24900,2.5455472,1.5975416,,,,,,,,,,,,,, -25000,3.0906198,1.5962888,,,,,,,,,,,,,, -25100,4.318281,1.6214998,,,,,,,,,,,,,, -25200,2.3033128,1.62034,,,,,,,,,,,,,, -25300,3.6034477,1.6149403,,,,,,,,,,,,,, -25400,3.662015,1.6293377,,,,,,,,,,,,,, -25500,4.162584,1.6084199,,,,,,,,,,,,,, -25600,3.586491,1.594653,,,,,,,,,,,,,, -25627,,,0.327648,0.1109453339589517,0.6450211,0.1839114861407455,5348.0,0.3858443,0.1221132167448662,2472.0,21620.54094862938,23843.795273303986,21620.54094862938,2221.237764120102,0.8373799324035645,0.0 -25700,3.0671592,1.5932069,,,,,,,,,,,,,, -25800,5.3820834,1.6529833,,,,,,,,,,,,,, -25900,1.8025608,1.6328251,,,,,,,,,,,,,, -26000,2.2826433,1.6096679,,,,,,,,,,,,,, -26100,3.1368058,1.5250971,,,,,,,,,,,,,, -26200,2.332703,1.5827532,,,,,,,,,,,,,, -26300,2.3716025,1.5005102,,,,,,,,,,,,,, -26400,3.3107386,1.5031693,,,,,,,,,,,,,, -26500,2.7627773,1.5635053,,,,,,,,,,,,,, -26600,3.636431,1.61881,,,,,,,,,,,,,, -26700,2.5181355,1.5479292,,,,,,,,,,,,,, -26800,2.6380763,1.4705809,,,,,,,,,,,,,, -26900,2.0085995,1.5217522,,,,,,,,,,,,,, -27000,2.7055147,1.5505738,,,,,,,,,,,,,, -27100,1.869721,1.5612715,,,,,,,,,,,,,, -27200,2.2147942,1.4642805,,,,,,,,,,,,,, -27300,2.4141273,1.5351843,,,,,,,,,,,,,, -27322,,,0.35698888,0.115199530197328,0.6116994,0.1745947459378047,5348.0,0.36358392,0.1171571913147685,2472.0,23060.78440976143,25419.094624519348,23060.78440976143,2356.1602787971497,0.893054723739624,0.0 -27400,2.450503,1.4956841,,,,,,,,,,,,,, -27500,3.1495452,1.5239035,,,,,,,,,,,,,, -27600,2.4963439,1.5823798,,,,,,,,,,,,,, -27700,2.5619838,1.5924755,,,,,,,,,,,,,, -27800,3.5423765,1.5278627,,,,,,,,,,,,,, -27900,2.1062143,1.4944177,,,,,,,,,,,,,, -28000,2.223428,1.4735302,,,,,,,,,,,,,, -28100,2.4934103,1.4958528,,,,,,,,,,,,,, -28200,2.172254,1.4919628,,,,,,,,,,,,,, -28300,2.2085166,1.5379496,,,,,,,,,,,,,, -28400,2.2512603,1.5355287,,,,,,,,,,,,,, -28500,2.3877978,1.5077686,,,,,,,,,,,,,, -28600,1.6962228,1.4727926,,,,,,,,,,,,,, -28700,2.373364,1.4810795,,,,,,,,,,,,,, -28800,3.0599504,1.4498498,,,,,,,,,,,,,, -28900,2.3524976,1.490793,,,,,,,,,,,,,, -29000,1.6807419,1.480154,,,,,,,,,,,,,, -29032,,,0.30421615,0.1036155133015636,0.59509784,0.1709066684688685,5348.0,0.3473988,0.1140495196311417,2472.0,24500.67453694344,26997.705008029938,24500.67453694344,2494.747106552124,0.9459781646728516,0.0 -29100,2.3017309,1.4851928,,,,,,,,,,,,,, -29200,2.2956681,1.4812927,,,,,,,,,,,,,, -29300,2.8507984,1.5546323,,,,,,,,,,,,,, -29400,3.683166,1.4620153,,,,,,,,,,,,,, -29500,2.0445635,1.4786606,,,,,,,,,,,,,, -29600,1.8455869,1.4701556,,,,,,,,,,,,,, -29700,2.5614343,1.4492298,,,,,,,,,,,,,, -29800,3.7330356,1.4879911,,,,,,,,,,,,,, -29900,1.9590671,1.4569803,,,,,,,,,,,,,, -30000,2.5759728,1.5137556,,,,,,,,,,,,,, -30100,1.7184008,1.4428234,,,,,,,,,,,,,, -30200,2.3021958,1.4170809,,,,,,,,,,,,,, -30300,1.7275308,1.4719355,,,,,,,,,,,,,, -30400,2.0928562,1.482963,,,,,,,,,,,,,, -30500,2.9945264,1.4492362,,,,,,,,,,,,,, -30600,1.630651,1.4116203,,,,,,,,,,,,,, -30700,2.285664,1.4479998,,,,,,,,,,,,,, -30728,,,0.26352453,0.0905633455560276,0.5693506,0.1632891472044952,5348.0,0.3357012,0.1066967278045213,2472.0,25941.396134853363,28576.09837579727,25941.396134853363,2632.291247367859,0.9961137771606444,0.0 -30800,2.7069132,1.4910564,,,,,,,,,,,,,, -30900,2.632244,1.4668124,,,,,,,,,,,,,, -31000,2.080625,1.4849348,,,,,,,,,,,,,, -31100,2.680619,1.5013938,,,,,,,,,,,,,, -31200,2.2811964,1.504311,,,,,,,,,,,,,, -31300,2.4850917,1.4346205,,,,,,,,,,,,,, -31400,1.6954525,1.4014215,,,,,,,,,,,,,, -31500,2.2037523,1.4387689,,,,,,,,,,,,,, -31600,2.654143,1.4017849,,,,,,,,,,,,,, -31700,2.2571814,1.4708115,,,,,,,,,,,,,, -31800,2.4723983,1.4410864,,,,,,,,,,,,,, -31900,2.207026,1.4066331,,,,,,,,,,,,,, -32000,2.5009959,1.4116186,,,,,,,,,,,,,, -32100,2.0915167,1.4230497,,,,,,,,,,,,,, -32200,2.4522076,1.4703195,,,,,,,,,,,,,, -32300,1.4224918,1.4103999,,,,,,,,,,,,,, -32400,3.921313,1.487346,,,,,,,,,,,,,, -32447,,,0.25451177,0.086225316926675,0.54807425,0.1576797937766106,5348.0,0.31943408,0.1029187739930534,2472.0,27382.316435575485,30152.26672244072,27382.316435575485,2767.3991503715515,1.0550148487091064,0.0 -32500,2.731427,1.4304184,,,,,,,,,,,,,, -32600,1.8302292,1.4308867,,,,,,,,,,,,,, -32700,3.6353333,1.396647,,,,,,,,,,,,,, -32800,2.3901763,1.3870183,,,,,,,,,,,,,, -32900,3.1793394,1.399233,,,,,,,,,,,,,, -33000,1.3808236,1.3586053,,,,,,,,,,,,,, -33100,2.2073014,1.4106703,,,,,,,,,,,,,, -33200,2.7071652,1.458095,,,,,,,,,,,,,, -33300,2.0143116,1.4174227,,,,,,,,,,,,,, -33400,2.5366747,1.4489638,,,,,,,,,,,,,, -33500,2.9808433,1.3628538,,,,,,,,,,,,,, -33600,1.9046886,1.3861108,,,,,,,,,,,,,, -33700,1.5829432,1.4761921,,,,,,,,,,,,,, -33800,1.5325724,1.410719,,,,,,,,,,,,,, -33900,2.8729486,1.3653245,,,,,,,,,,,,,, -34000,3.8169084,1.3453877,,,,,,,,,,,,,, -34100,1.5275381,1.4619894,,,,,,,,,,,,,, -34160,,,0.2348298,0.0816698966408268,0.53372824,0.1530745242669704,5348.0,0.30711773,0.0979627485629557,2472.0,28822.94229197502,31729.22262334824,28822.94229197502,2903.598759412765,1.106881618499756,0.0 -34200,1.5607598,1.3774923,,,,,,,,,,,,,, -34300,2.713557,1.4378229,,,,,,,,,,,,,, -34400,1.9199717,1.3827363,,,,,,,,,,,,,, -34500,2.4880104,1.3805735,,,,,,,,,,,,,, -34600,3.4461222,1.4134299,,,,,,,,,,,,,, -34700,1.4360083,1.3893603,,,,,,,,,,,,,, -34800,2.3082545,1.2937025,,,,,,,,,,,,,, -34900,2.4057243,1.3718097,,,,,,,,,,,,,, -35000,2.180039,1.3746114,,,,,,,,,,,,,, -35100,2.4158812,1.2677823,,,,,,,,,,,,,, -35200,1.753404,1.3584336,,,,,,,,,,,,,, -35300,2.1490395,1.368797,,,,,,,,,,,,,, -35400,1.8237243,1.4370311,,,,,,,,,,,,,, -35500,2.9401777,1.3468363,,,,,,,,,,,,,, -35600,2.3630104,1.3175681,,,,,,,,,,,,,, -35700,1.502706,1.3376933,,,,,,,,,,,,,, -35800,1.807184,1.3220574,,,,,,,,,,,,,, -35863,,,0.23932621,0.0797680012210462,0.5034849,0.144308099288452,5348.0,0.28727466,0.091747405195702,2472.0,30263.509187936783,33304.622517347336,30263.509187936783,3038.3013293743134,1.1584124565124512,0.0 -35900,2.2265198,1.3518198,,,,,,,,,,,,,, -36000,1.6307052,1.3485819,,,,,,,,,,,,,, -36100,2.2646613,1.3200704,,,,,,,,,,,,,, -36200,1.7050059,1.3118168,,,,,,,,,,,,,, -36300,2.6630044,1.3087568,,,,,,,,,,,,,, -36400,2.4169886,1.2827942,,,,,,,,,,,,,, -36500,2.3521934,1.3559197,,,,,,,,,,,,,, -36600,2.1542757,1.3486106,,,,,,,,,,,,,, -36700,1.5704567,1.3146584,,,,,,,,,,,,,, -36800,3.9297056,1.3340544,,,,,,,,,,,,,, -36900,2.2192006,1.3374163,,,,,,,,,,,,,, -37000,0.9911604,1.298452,,,,,,,,,,,,,, -37100,2.17697,1.3243911,,,,,,,,,,,,,, -37200,1.4068524,1.3062644,,,,,,,,,,,,,, -37300,3.191694,1.2978021,,,,,,,,,,,,,, -37400,1.7601453,1.2841811,,,,,,,,,,,,,, -37500,2.8351254,1.2900727,,,,,,,,,,,,,, -37591,,,0.21954021,0.0745288632163148,0.4808648,0.1389304575340085,5348.0,0.27169424,0.0879491398046026,2472.0,31703.80317473412,34883.49290180206,31703.80317473412,3176.7409851551056,1.2146975994110107,0.0 -37600,2.769735,1.3002645,,,,,,,,,,,,,, -37700,2.2245781,1.268361,,,,,,,,,,,,,, -37800,1.3911418,1.2757654,,,,,,,,,,,,,, -37900,2.9171474,1.300876,,,,,,,,,,,,,, -38000,2.8048902,1.3113594,,,,,,,,,,,,,, -38100,2.4895961,1.3193628,,,,,,,,,,,,,, -38200,2.8764486,1.3593677,,,,,,,,,,,,,, -38300,4.3514743,1.2890509,,,,,,,,,,,,,, -38400,1.4697804,1.2812188,,,,,,,,,,,,,, -38500,2.4451456,1.3049824,,,,,,,,,,,,,, -38600,6.220402,1.2550367,,,,,,,,,,,,,, -38700,2.3564656,1.2673875,,,,,,,,,,,,,, -38800,3.2433126,1.2296654,,,,,,,,,,,,,, -38900,3.4128942,1.232964,,,,,,,,,,,,,, -39000,1.3809916,1.2279934,,,,,,,,,,,,,, -39100,2.8193216,1.2588912,,,,,,,,,,,,,, -39200,4.031445,1.2018049,,,,,,,,,,,,,, -39300,1.7888882,1.2468771,,,,,,,,,,,,,, -39304,,,0.19475468,0.0674533633356491,0.45533058,0.131293626963515,5348.0,0.2551368,0.0820790932910852,2472.0,33143.80793213844,36460.42004203797,33143.80793213844,3313.524995803833,1.2724831104278564,0.0 -39400,1.6080906,1.278664,,,,,,,,,,,,,, -39500,1.8715001,1.2737777,,,,,,,,,,,,,, -39600,2.9169714,1.2350792,,,,,,,,,,,,,, -39700,3.3389313,1.2482567,,,,,,,,,,,,,, -39800,1.5169506,1.297164,,,,,,,,,,,,,, -39900,1.7298071,1.224886,,,,,,,,,,,,,, -40000,4.143226,1.243832,,,,,,,,,,,,,, -40100,1.9241328,1.2764189,,,,,,,,,,,,,, -40200,2.0293221,1.236347,,,,,,,,,,,,,, -40300,1.9744819,1.2109426,,,,,,,,,,,,,, -40400,2.73464,1.1703398,,,,,,,,,,,,,, -40500,2.1573243,1.2871683,,,,,,,,,,,,,, -40600,2.6221473,1.2224984,,,,,,,,,,,,,, -40700,3.0795581,1.2379328,,,,,,,,,,,,,, -40800,2.188524,1.2251003,,,,,,,,,,,,,, -40900,3.5523674,1.2021528,,,,,,,,,,,,,, -41000,3.0752273,1.2377447,,,,,,,,,,,,,, -41016,,,0.18020418,0.0614644881696766,0.43848684,0.126273207372293,5348.0,0.24059556,0.0772043141795137,2472.0,34583.80059170723,38035.92982244492,34583.80059170723,3448.9083466529846,1.3277525901794434,0.0 -41100,2.903432,1.2342591,,,,,,,,,,,,,, -41200,3.757362,1.2458153,,,,,,,,,,,,,, -41300,1.98019,1.2437003,,,,,,,,,,,,,, -41400,2.9501567,1.2026503,,,,,,,,,,,,,, -41500,2.6033115,1.2159535,,,,,,,,,,,,,, -41600,1.761016,1.2025571,,,,,,,,,,,,,, -41700,4.803617,1.1786659,,,,,,,,,,,,,, -41800,1.5169625,1.2138706,,,,,,,,,,,,,, -41900,2.621081,1.2285634,,,,,,,,,,,,,, -42000,2.2743971,1.2196563,,,,,,,,,,,,,, -42100,3.2579198,1.2015783,,,,,,,,,,,,,, -42200,2.5566955,1.179695,,,,,,,,,,,,,, -42300,2.1710532,1.2384914,,,,,,,,,,,,,, -42400,3.1151712,1.2024113,,,,,,,,,,,,,, -42500,1.7061226,1.1313456,,,,,,,,,,,,,, -42600,2.4670815,1.2094271,,,,,,,,,,,,,, -42700,1.6454314,1.1975068,,,,,,,,,,,,,, -42751,,,0.15269303,0.0525743253947298,0.4185573,0.1192349652915222,5348.0,0.23369418,0.0746247435663071,2472.0,36024.716413497925,39611.73369860649,36024.716413497925,3583.6617407798767,1.3814823627471924,0.0 -42800,3.297846,1.1732638,,,,,,,,,,,,,, -42900,13.603172,1.2280641,,,,,,,,,,,,,, -43000,1.9336363,1.140233,,,,,,,,,,,,,, -43100,3.4579947,1.1934901,,,,,,,,,,,,,, -43200,2.5880048,1.1818806,,,,,,,,,,,,,, -43300,1.7000995,1.1315626,,,,,,,,,,,,,, -43400,3.0529583,1.2721919,,,,,,,,,,,,,, -43500,2.1343696,1.1846716,,,,,,,,,,,,,, -43600,2.3072035,1.1754713,,,,,,,,,,,,,, -43700,1.8746635,1.1503303,,,,,,,,,,,,,, -43800,2.745939,1.1347855,,,,,,,,,,,,,, -43900,2.7749245,1.1739687,,,,,,,,,,,,,, -44000,2.5337791,1.1139332,,,,,,,,,,,,,, -44100,3.3038645,1.1486454,,,,,,,,,,,,,, -44200,1.4931571,1.242465,,,,,,,,,,,,,, -44300,1.2217971,1.192453,,,,,,,,,,,,,, -44400,2.9213932,1.1294478,,,,,,,,,,,,,, -44456,,,0.15371695,0.0517177970517024,0.4076263,0.1159234192919277,5348.0,0.22438748,0.0719029918956797,2472.0,37464.84016394615,41187.92023229599,37464.84016394615,3719.5914764404297,1.4356229305267334,0.0 -44500,2.7322783,1.1770387,,,,,,,,,,,,,, -44600,2.509926,1.1485034,,,,,,,,,,,,,, -44700,2.182186,1.1387228,,,,,,,,,,,,,, -44800,2.822122,1.1617293,,,,,,,,,,,,,, -44900,2.9822726,1.1680406,,,,,,,,,,,,,, -45000,3.205277,1.1429617,,,,,,,,,,,,,, -45100,3.1828706,1.1577094,,,,,,,,,,,,,, -45200,5.3801036,1.1134222,,,,,,,,,,,,,, -45300,1.5005809,1.1321629,,,,,,,,,,,,,, -45400,2.3976088,1.0801406,,,,,,,,,,,,,, -45500,2.2359617,1.1761878,,,,,,,,,,,,,, -45600,2.9775655,1.1193669,,,,,,,,,,,,,, -45700,2.270889,1.145896,,,,,,,,,,,,,, -45800,1.692924,1.1873107,,,,,,,,,,,,,, -45900,3.616464,1.1406645,,,,,,,,,,,,,, -46000,2.4044135,1.15358,,,,,,,,,,,,,, -46100,2.780727,1.136773,,,,,,,,,,,,,, -46165,,,0.16266528,0.0549525944734069,0.40298778,0.114620041128822,5348.0,0.22044119,0.0706233623788922,2472.0,38905.040621995926,42763.2930085659,38905.040621995926,3854.6277837753296,1.4938104152679443,0.0 -46200,1.8117995,1.1820247,,,,,,,,,,,,,, -46300,2.013323,1.1461089,,,,,,,,,,,,,, -46400,3.2652428,1.100212,,,,,,,,,,,,,, -46500,1.9070668,1.1578181,,,,,,,,,,,,,, -46600,2.1295598,1.1499774,,,,,,,,,,,,,, -46700,2.1504536,1.089201,,,,,,,,,,,,,, -46800,2.2642148,1.1632426,,,,,,,,,,,,,, -46900,1.7842135,1.1545304,,,,,,,,,,,,,, -47000,3.2147343,1.14714,,,,,,,,,,,,,, -47100,2.9678557,1.1939447,,,,,,,,,,,,,, -47200,2.5118077,1.1360933,,,,,,,,,,,,,, -47300,3.5811121,1.11874,,,,,,,,,,,,,, -47400,1.7995752,1.1291643,,,,,,,,,,,,,, -47500,4.848804,1.1480521,,,,,,,,,,,,,, -47600,2.637509,1.1781895,,,,,,,,,,,,,, -47700,2.3646815,1.1563934,,,,,,,,,,,,,, -47800,2.3536563,1.1033428,,,,,,,,,,,,,, -47900,4.4969754,1.1376561,,,,,,,,,,,,,, -47904,,,0.14715919,0.0487956127626761,0.40185603,0.1147841702308427,5348.0,0.21938264,0.0707046086974184,2472.0,40345.54191684723,44340.17746424675,40345.54191684723,3990.876124382019,1.5495550632476809,0.0 -48000,,,0.1594734,0.0546165469755269,0.401908,0.1148034795369628,5348.0,0.21940319,0.0706639855381553,2472.0,40417.566965818405,44543.13734269142,40417.566965818405,4121.736053466797,1.6115646362304688,0.0 -48000,,,,,,,,,,,40417.566965818405,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/eval_measurements.csv deleted file mode 100644 index ae26575b0..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,31 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -179.9624104499817,0.0,18.64958238601685,1,0,18.64958238601685,30.657665,2472,3.2318769930737514,198.61209559440613,31.628937,3.1009376811671285,30.570345,5348,2.911437867480232 -308.8790779113769,0.0292551517486572,1458.985486984253,1712,0,1458.985486984253,2.5435536,2472,0.5706741413279711,1767.9684970378876,2.5453644,0.5917621452917745,3.0406559,5348,0.6442163800843816 -445.87190437316895,0.0791728496551513,2899.323889017105,3454,0,2899.323889017105,0.66275734,2472,0.2053703816545812,3345.4317281246185,0.6002691,0.1935208688776996,1.0050833,5348,0.2805642179248289 -580.9096684455872,0.1311507225036621,4340.835663795471,5151,0,4340.835663795471,0.542488,2472,0.1678955172343753,4922.1106123924255,0.6795138,0.2118001694225331,0.8538837,5348,0.2422835185417612 -718.7179839611053,0.1858899593353271,5780.938099622726,6877,0,5780.938099622726,0.48380572,2472,0.1544695630979221,6500.155675172806,0.62634224,0.1970878434523252,0.78297794,5348,0.2219604738503721 -851.6845552921295,0.2392699718475341,7221.522592782974,8624,0,7221.522592782974,0.44854873,2472,0.1430950785042552,8073.839767456055,0.69155705,0.2157953410595257,0.738769,5348,0.2113982834026859 -989.5500612258912,0.2970449924468994,8661.632459640503,10329,0,8661.632459640503,0.42991805,2472,0.137895314118579,9651.952268838882,0.58510447,0.1821902111638566,0.71844894,5348,0.2058178939339814 -1125.454788684845,0.3500871658325195,10101.663102388382,12066,0,10101.663102388382,0.4067004,2472,0.1325533686754819,11228.021137475967,0.58945584,0.1888911636582841,0.697573,5348,0.1997258078530948 -1262.816997051239,0.4042665958404541,11541.959456205368,13788,0,11541.959456205368,0.39410502,2472,0.1272723579712794,12805.813822984695,0.44308084,0.1463969658659924,0.6531635,5348,0.1873871612423607 -1397.6854872703552,0.454848051071167,12982.570106267927,15492,0,12982.570106267927,0.37502876,2472,0.1206101598521317,14381.421175718307,0.5043031,0.163937207851208,0.64487046,5348,0.1834673720999836 -1535.3861787319183,0.5043253898620605,14422.980370283129,17205,0,14422.980370283129,0.3541107,2472,0.1146385554404566,15959.662130594254,0.44294888,0.1443313293253173,0.6099825,5348,0.1756567577744093 -1670.8955328464508,0.5572524070739746,15863.009141683578,18938,0,15863.009141683578,0.3469349,2472,0.112688643795828,17535.333825826645,0.43057463,0.1437458643843492,0.59888196,5348,0.1724417583054153 -1805.8819556236267,0.6184689998626709,17303.21779513359,20641,0,17303.21779513359,0.33290595,2472,0.1072248288749416,19110.668027877808,0.4117053,0.1353284273644044,0.5848249,5348,0.1684061133263176 -1940.645127773285,0.6647696495056152,18743.48190689087,22360,0,18743.48190689087,0.3231798,2472,0.1041984035098409,20685.82173585892,0.40284434,0.1338424391521736,0.5693361,5348,0.1629319250412736 -2074.248395442962,0.7133445739746094,20183.945833444595,24075,0,20183.945833444595,0.31203714,2472,0.1026953466171064,22260.01620745659,0.43493432,0.138860300698049,0.5609899,5348,0.1619761143883294 -2210.4984588623047,0.7659845352172852,21624.523320913315,25757,0,21624.523320913315,0.30561492,2472,0.0986939654296914,23836.974450588223,0.361608,0.1208505050839415,0.5388162,5348,0.1562702144298444 -2344.8034658432007,0.8201003074645996,23064.93439102173,27487,0,23064.93439102173,0.29226407,2472,0.0947738305608027,25411.825412988663,0.37259036,0.1272394987236017,0.5171766,5348,0.1500526178591772 -2478.910748243332,0.8709836006164551,24504.82792067528,29212,0,24504.82792067528,0.2780522,2472,0.0901834135640728,26985.95388007164,0.30891111,0.1043790285799066,0.5069906,5348,0.1449549610434749 -2611.3780493736267,0.9297785758972168,25945.360632419583,30904,0,25945.360632419583,0.27133593,2472,0.0875429082119716,28559.09006714821,0.3336694,0.1106648943691846,0.5003634,5348,0.1430916130028867 -2745.85751581192,0.98575758934021,27385.716521024704,32614,0,27385.716521024704,0.26732162,2472,0.084963337598765,30134.061345100403,0.3215798,0.1096156143514594,0.4891848,5348,0.1402531450032343 -2880.9082431793213,1.0413026809692385,28825.82073378563,34328,0,28825.82073378563,0.2541731,2472,0.0815916153799281,31709.3510825634,0.30576283,0.103337584099204,0.4679304,5348,0.1343638066366085 -3014.781930685044,1.0968976020812988,30266.26204228401,36023,0,30266.26204228401,0.24383034,2472,0.0783620742185119,33283.799632787704,0.24863774,0.0858738576102899,0.45736533,5348,0.1318342875348774 -3149.2583301067352,1.1505632400512695,31706.707235336304,37719,0,31706.707235336304,0.24005848,2472,0.0765340320516726,34858.85399699211,0.22751224,0.0775546242095184,0.44469142,5348,0.1273352192088977 -3284.214804172516,1.206776142120361,33147.02962565422,39437,0,33147.02962565422,0.23413983,2472,0.0745638088274125,36434.26735687256,0.2554167,0.0892404637499159,0.43711883,5348,0.1243326221072245 -3417.66161775589,1.2691385746002195,34586.98317909241,41134,0,34586.98317909241,0.23012903,2472,0.07413726565515,38007.80613279343,0.2320539,0.0801694666651623,0.42751628,5348,0.1214362261892118 -3551.85942363739,1.322899580001831,36027.01313138008,42832,0,36027.01313138008,0.225074,2472,0.0719639266345743,39582.16630578041,0.24211791,0.0808709346414912,0.42340276,5348,0.1203452503934271 -3689.00634765625,1.3793542385101318,37467.38977432251,44550,0,37467.38977432251,0.22411768,2472,0.0716998760993642,41159.824618816376,0.2000132,0.0698117025523792,0.4208656,5348,0.119881827046545 -3824.576935768128,1.4339289665222168,38907.78763270378,46233,0,38907.78763270378,0.22205608,2472,0.0704202465825767,42735.924476861954,0.20435266,0.0718254876884155,0.41903222,5348,0.1190032536180812 -3960.401820898056,1.495056390762329,40348.03506684303,47961,0,40348.03506684303,0.22187033,2472,0.0707452318566815,44312.138027668,0.22109444,0.0784157378244303,0.41872,5348,0.1186267221487395 -4093.375575065613,1.5548663139343262,40377.33175325394,48000,0,40377.33175325394,0.2218685,2472,0.0707452318566815,44474.47600841522,0.14898208,0.052723391806529206,0.4187202,5348,0.11863637680179963 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/measurements.csv deleted file mode 100644 index 5448bdc50..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/measurements.csv +++ /dev/null @@ -1,512 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,17.643763,32.94795,,,,,,,,,,,,,, -1,,,31.628937,3.1009376811671285,30.570345,2.911437867480232,5348.0,30.657665,3.2318769930737514,2472.0,18.64958238601685,198.61209559440613,18.64958238601685,179.9624104499817,0.0,0.0 -100,1.2905042,6.0244274,,,,,,,,,,,,,, -200,0.7758316,5.8357086,,,,,,,,,,,,,, -300,0.30706945,5.7282553,,,,,,,,,,,,,, -400,1.7293211,5.4851093,,,,,,,,,,,,,, -500,1.4110036,4.8273373,,,,,,,,,,,,,, -600,1.5390569,3.959859,,,,,,,,,,,,,, -700,3.3612623,3.4966152,,,,,,,,,,,,,, -800,2.3696413,3.2559922,,,,,,,,,,,,,, -900,2.2213326,3.0608444,,,,,,,,,,,,,, -1000,3.1279452,2.9521115,,,,,,,,,,,,,, -1100,2.5945444,2.732706,,,,,,,,,,,,,, -1200,3.0844545,2.632289,,,,,,,,,,,,,, -1300,2.728989,2.6051102,,,,,,,,,,,,,, -1400,2.2248597,2.464379,,,,,,,,,,,,,, -1500,2.528255,2.4408908,,,,,,,,,,,,,, -1600,3.0368373,2.526615,,,,,,,,,,,,,, -1700,2.0361235,2.4041543,,,,,,,,,,,,,, -1712,,,2.5453644,0.5917621452917745,3.0406559,0.6442163800843816,5348.0,2.5435536,0.5706741413279711,2472.0,1458.985486984253,1767.9684970378876,1458.985486984253,308.8790779113769,0.0292551517486572,0.0 -1800,2.0640166,2.2666306,,,,,,,,,,,,,, -1900,2.1477098,2.2792602,,,,,,,,,,,,,, -2000,2.7720025,2.267765,,,,,,,,,,,,,, -2100,2.1742141,2.1715038,,,,,,,,,,,,,, -2200,2.5565288,2.2861362,,,,,,,,,,,,,, -2300,2.7373626,2.271649,,,,,,,,,,,,,, -2400,1.6853975,2.1071358,,,,,,,,,,,,,, -2500,2.6247127,2.1402144,,,,,,,,,,,,,, -2600,2.5106692,2.0723705,,,,,,,,,,,,,, -2700,1.8255651,2.0695407,,,,,,,,,,,,,, -2800,2.6548944,2.0751019,,,,,,,,,,,,,, -2900,3.2086992,2.055704,,,,,,,,,,,,,, -3000,1.5861479,2.0392563,,,,,,,,,,,,,, -3100,2.4723706,2.0658326,,,,,,,,,,,,,, -3200,4.5997925,2.0648606,,,,,,,,,,,,,, -3300,3.4149373,1.9201952,,,,,,,,,,,,,, -3400,3.6899815,1.9943337,,,,,,,,,,,,,, -3454,,,0.6002691,0.1935208688776996,1.0050833,0.2805642179248289,5348.0,0.66275734,0.2053703816545812,2472.0,2899.323889017105,3345.4317281246185,2899.323889017105,445.87190437316895,0.0791728496551513,0.0 -3500,3.0240116,1.964744,,,,,,,,,,,,,, -3600,9.929198,1.972284,,,,,,,,,,,,,, -3700,2.397228,1.9081552,,,,,,,,,,,,,, -3800,3.3416705,1.9936558,,,,,,,,,,,,,, -3900,1.919976,1.9710709,,,,,,,,,,,,,, -4000,1.8729178,1.8805426,,,,,,,,,,,,,, -4100,2.6349137,1.9243377,,,,,,,,,,,,,, -4200,2.6554396,1.9555675,,,,,,,,,,,,,, -4300,2.6023536,1.8940436,,,,,,,,,,,,,, -4400,3.3864913,1.8320698,,,,,,,,,,,,,, -4500,3.5458345,1.8722434,,,,,,,,,,,,,, -4600,2.075119,1.939197,,,,,,,,,,,,,, -4700,2.4863482,2.0150533,,,,,,,,,,,,,, -4800,2.1067703,1.8172338,,,,,,,,,,,,,, -4900,2.1686134,1.865531,,,,,,,,,,,,,, -5000,1.9573001,1.8023221,,,,,,,,,,,,,, -5100,3.6511097,1.8167005,,,,,,,,,,,,,, -5151,,,0.6795138,0.2118001694225331,0.8538837,0.2422835185417612,5348.0,0.542488,0.1678955172343753,2472.0,4340.835663795471,4922.1106123924255,4340.835663795471,580.9096684455872,0.1311507225036621,0.0 -5200,2.1704836,1.7861905,,,,,,,,,,,,,, -5300,2.5088794,1.8141342,,,,,,,,,,,,,, -5400,1.8885033,1.8506523,,,,,,,,,,,,,, -5500,2.8141034,1.8175873,,,,,,,,,,,,,, -5600,2.9431882,1.8891754,,,,,,,,,,,,,, -5700,1.718806,1.8234546,,,,,,,,,,,,,, -5800,3.4389443,1.8185421,,,,,,,,,,,,,, -5900,2.270771,1.7909175,,,,,,,,,,,,,, -6000,3.0105999,1.8048565,,,,,,,,,,,,,, -6100,2.4794495,1.7708088,,,,,,,,,,,,,, -6200,2.4878316,1.7615819,,,,,,,,,,,,,, -6300,2.1549594,1.8131188,,,,,,,,,,,,,, -6400,2.6095686,1.7471524,,,,,,,,,,,,,, -6500,2.8431194,1.7469049,,,,,,,,,,,,,, -6600,3.4927773,1.7243525,,,,,,,,,,,,,, -6700,4.165955,1.7572329,,,,,,,,,,,,,, -6800,4.323061,1.7466012,,,,,,,,,,,,,, -6877,,,0.62634224,0.1970878434523252,0.78297794,0.2219604738503721,5348.0,0.48380572,0.1544695630979221,2472.0,5780.938099622726,6500.155675172806,5780.938099622726,718.7179839611053,0.1858899593353271,0.0 -6900,2.0393653,1.7881141,,,,,,,,,,,,,, -7000,3.997975,2.1735446,,,,,,,,,,,,,, -7100,3.3297222,1.7355745,,,,,,,,,,,,,, -7200,2.3697894,1.791733,,,,,,,,,,,,,, -7300,5.21071,1.6825811,,,,,,,,,,,,,, -7400,2.3788323,1.72584,,,,,,,,,,,,,, -7500,3.231616,1.6828526,,,,,,,,,,,,,, -7600,3.2314394,1.7246994,,,,,,,,,,,,,, -7700,3.8124933,1.7581226,,,,,,,,,,,,,, -7800,2.770284,1.7213558,,,,,,,,,,,,,, -7900,3.2509828,1.7817897,,,,,,,,,,,,,, -8000,1.9803467,1.6338847,,,,,,,,,,,,,, -8100,3.187216,1.738893,,,,,,,,,,,,,, -8200,2.7061934,1.6865023,,,,,,,,,,,,,, -8300,2.825566,1.7086998,,,,,,,,,,,,,, -8400,2.4519556,1.7173581,,,,,,,,,,,,,, -8500,2.2942424,1.7320751,,,,,,,,,,,,,, -8600,2.6345074,1.6538697,,,,,,,,,,,,,, -8624,,,0.69155705,0.2157953410595257,0.738769,0.2113982834026859,5348.0,0.44854873,0.1430950785042552,2472.0,7221.522592782974,8073.839767456055,7221.522592782974,851.6845552921295,0.2392699718475341,0.0 -8700,3.0074441,1.6359715,,,,,,,,,,,,,, -8800,1.7338351,1.6908398,,,,,,,,,,,,,, -8900,2.3105698,1.6279457,,,,,,,,,,,,,, -9000,3.3855872,1.6894084,,,,,,,,,,,,,, -9100,2.7122266,1.6638945,,,,,,,,,,,,,, -9200,3.4261918,1.6804487,,,,,,,,,,,,,, -9300,3.304474,1.6739293,,,,,,,,,,,,,, -9400,2.5250201,1.6597812,,,,,,,,,,,,,, -9500,2.7550733,1.7109627,,,,,,,,,,,,,, -9600,2.5540438,1.604389,,,,,,,,,,,,,, -9700,2.4725385,1.7121178,,,,,,,,,,,,,, -9800,2.6291122,1.7070354,,,,,,,,,,,,,, -9900,2.001876,1.6181087,,,,,,,,,,,,,, -10000,2.2138252,1.6862645,,,,,,,,,,,,,, -10100,3.226315,1.6726515,,,,,,,,,,,,,, -10200,2.762483,1.6537555,,,,,,,,,,,,,, -10300,3.7620275,1.6536614,,,,,,,,,,,,,, -10329,,,0.58510447,0.1821902111638566,0.71844894,0.2058178939339814,5348.0,0.42991805,0.137895314118579,2472.0,8661.632459640503,9651.952268838882,8661.632459640503,989.5500612258912,0.2970449924468994,0.0 -10400,1.9683406,1.6153286,,,,,,,,,,,,,, -10500,4.453317,1.6129278,,,,,,,,,,,,,, -10600,3.0724354,1.5862924,,,,,,,,,,,,,, -10700,3.9175408,1.7002543,,,,,,,,,,,,,, -10800,2.6080787,1.6903577,,,,,,,,,,,,,, -10900,2.705009,1.6027194,,,,,,,,,,,,,, -11000,2.1912663,1.6544237,,,,,,,,,,,,,, -11100,2.5445225,1.6500158,,,,,,,,,,,,,, -11200,2.2750158,1.6588055,,,,,,,,,,,,,, -11300,4.934863,1.6352096,,,,,,,,,,,,,, -11400,3.1892328,1.6266643,,,,,,,,,,,,,, -11500,2.675254,1.677878,,,,,,,,,,,,,, -11600,2.7180617,1.701605,,,,,,,,,,,,,, -11700,4.471758,1.5975174,,,,,,,,,,,,,, -11800,5.7302427,1.6781867,,,,,,,,,,,,,, -11900,3.8237221,1.608381,,,,,,,,,,,,,, -12000,2.225581,1.5553603,,,,,,,,,,,,,, -12066,,,0.58945584,0.1888911636582841,0.697573,0.1997258078530948,5348.0,0.4067004,0.1325533686754819,2472.0,10101.663102388382,11228.021137475967,10101.663102388382,1125.454788684845,0.3500871658325195,0.0 -12100,1.726515,1.5558534,,,,,,,,,,,,,, -12200,4.0832167,1.7038335,,,,,,,,,,,,,, -12300,2.5220866,1.6573284,,,,,,,,,,,,,, -12400,3.41983,1.604792,,,,,,,,,,,,,, -12500,3.8276277,1.518466,,,,,,,,,,,,,, -12600,1.8602858,1.6020939,,,,,,,,,,,,,, -12700,2.0674832,1.5638181,,,,,,,,,,,,,, -12800,3.1699522,1.5730877,,,,,,,,,,,,,, -12900,3.7798867,1.5635427,,,,,,,,,,,,,, -13000,3.71751,1.5243903,,,,,,,,,,,,,, -13100,2.3784683,1.6370244,,,,,,,,,,,,,, -13200,2.9511194,1.6007289,,,,,,,,,,,,,, -13300,3.6502542,1.5316986,,,,,,,,,,,,,, -13400,1.6008233,1.5806457,,,,,,,,,,,,,, -13500,2.1910205,1.5503575,,,,,,,,,,,,,, -13600,3.6428692,1.6302875,,,,,,,,,,,,,, -13700,2.574834,1.5182335,,,,,,,,,,,,,, -13788,,,0.44308084,0.1463969658659924,0.6531635,0.1873871612423607,5348.0,0.39410502,0.1272723579712794,2472.0,11541.959456205368,12805.813822984695,11541.959456205368,1262.816997051239,0.4042665958404541,0.0 -13800,4.3045774,1.5465931,,,,,,,,,,,,,, -13900,2.1053517,1.6415894,,,,,,,,,,,,,, -14000,3.3391824,1.5578887,,,,,,,,,,,,,, -14100,2.07261,1.5620434,,,,,,,,,,,,,, -14200,3.8852632,1.5860184,,,,,,,,,,,,,, -14300,1.9547772,1.6586168,,,,,,,,,,,,,, -14400,2.2473645,1.5866387,,,,,,,,,,,,,, -14500,2.150304,1.5601405,,,,,,,,,,,,,, -14600,2.1902182,1.5126526,,,,,,,,,,,,,, -14700,2.7292242,1.510629,,,,,,,,,,,,,, -14800,4.2159176,1.5865891,,,,,,,,,,,,,, -14900,3.7896888,1.5822258,,,,,,,,,,,,,, -15000,4.4583397,1.6059315,,,,,,,,,,,,,, -15100,3.0192492,1.5590482,,,,,,,,,,,,,, -15200,2.2565022,1.4709673,,,,,,,,,,,,,, -15300,4.3617134,1.578189,,,,,,,,,,,,,, -15400,3.2675061,1.5227596,,,,,,,,,,,,,, -15492,,,0.5043031,0.163937207851208,0.64487046,0.1834673720999836,5348.0,0.37502876,0.1206101598521317,2472.0,12982.570106267927,14381.421175718307,12982.570106267927,1397.6854872703552,0.454848051071167,0.0 -15500,3.0355163,1.5806564,,,,,,,,,,,,,, -15600,2.0091355,1.5505803,,,,,,,,,,,,,, -15700,4.1298876,1.5983115,,,,,,,,,,,,,, -15800,3.3531637,1.5984575,,,,,,,,,,,,,, -15900,2.4538004,1.5739144,,,,,,,,,,,,,, -16000,3.4837332,1.5624437,,,,,,,,,,,,,, -16100,3.4874508,1.5866308,,,,,,,,,,,,,, -16200,3.010638,1.5437554,,,,,,,,,,,,,, -16300,2.184427,1.5487307,,,,,,,,,,,,,, -16400,5.5630803,1.5184486,,,,,,,,,,,,,, -16500,1.7899528,1.5082177,,,,,,,,,,,,,, -16600,2.242072,1.5313969,,,,,,,,,,,,,, -16700,1.9989625,1.5013016,,,,,,,,,,,,,, -16800,2.3057215,1.5857308,,,,,,,,,,,,,, -16900,2.9466279,1.526644,,,,,,,,,,,,,, -17000,2.9205174,1.5605286,,,,,,,,,,,,,, -17100,2.7971966,1.5040789,,,,,,,,,,,,,, -17200,2.750985,1.5454056,,,,,,,,,,,,,, -17205,,,0.44294888,0.1443313293253173,0.6099825,0.1756567577744093,5348.0,0.3541107,0.1146385554404566,2472.0,14422.980370283129,15959.662130594254,14422.980370283129,1535.3861787319183,0.5043253898620605,0.0 -17300,4.4010777,1.4793363,,,,,,,,,,,,,, -17400,3.656222,1.5675442,,,,,,,,,,,,,, -17500,2.1854823,1.5378056,,,,,,,,,,,,,, -17600,2.096793,1.5302423,,,,,,,,,,,,,, -17700,2.6188757,1.5179738,,,,,,,,,,,,,, -17800,2.300718,1.5194851,,,,,,,,,,,,,, -17900,2.8073168,1.4970019,,,,,,,,,,,,,, -18000,3.159364,1.4919921,,,,,,,,,,,,,, -18100,2.3313868,1.4880301,,,,,,,,,,,,,, -18200,2.6939173,1.526702,,,,,,,,,,,,,, -18300,3.4757025,1.5911932,,,,,,,,,,,,,, -18400,3.7255437,1.4700652,,,,,,,,,,,,,, -18500,3.5668173,1.5130087,,,,,,,,,,,,,, -18600,3.3393428,1.5178086,,,,,,,,,,,,,, -18700,2.2923906,1.4083315,,,,,,,,,,,,,, -18800,2.457505,1.5021938,,,,,,,,,,,,,, -18900,2.575922,1.4834025,,,,,,,,,,,,,, -18938,,,0.43057463,0.1437458643843492,0.59888196,0.1724417583054153,5348.0,0.3469349,0.112688643795828,2472.0,15863.009141683578,17535.333825826645,15863.009141683578,1670.8955328464508,0.5572524070739746,0.0 -19000,1.9328477,1.513635,,,,,,,,,,,,,, -19100,2.1838324,1.4990678,,,,,,,,,,,,,, -19200,2.4123497,1.5659096,,,,,,,,,,,,,, -19300,3.2509053,1.4792428,,,,,,,,,,,,,, -19400,3.2682843,1.4933627,,,,,,,,,,,,,, -19500,2.3763182,1.4780176,,,,,,,,,,,,,, -19600,2.4189472,1.4799277,,,,,,,,,,,,,, -19700,3.2423975,1.5075638,,,,,,,,,,,,,, -19800,1.7674271,1.4235286,,,,,,,,,,,,,, -19900,2.5578742,1.5003544,,,,,,,,,,,,,, -20000,2.577964,1.4546746,,,,,,,,,,,,,, -20100,2.816328,1.4847425,,,,,,,,,,,,,, -20200,4.06072,1.5143533,,,,,,,,,,,,,, -20300,3.0246432,1.4550353,,,,,,,,,,,,,, -20400,2.7907712,1.5099621,,,,,,,,,,,,,, -20500,4.125346,1.516901,,,,,,,,,,,,,, -20600,2.914922,1.4177685,,,,,,,,,,,,,, -20641,,,0.4117053,0.1353284273644044,0.5848249,0.1684061133263176,5348.0,0.33290595,0.1072248288749416,2472.0,17303.21779513359,19110.668027877808,17303.21779513359,1805.8819556236267,0.6184689998626709,0.0 -20700,2.4680543,1.4282708,,,,,,,,,,,,,, -20800,3.9765549,1.4624608,,,,,,,,,,,,,, -20900,3.2458012,1.4605892,,,,,,,,,,,,,, -21000,2.5769658,1.4309385,,,,,,,,,,,,,, -21100,2.0814795,1.4759451,,,,,,,,,,,,,, -21200,2.569613,1.4447211,,,,,,,,,,,,,, -21300,3.678592,1.4811952,,,,,,,,,,,,,, -21400,2.6118736,1.5135412,,,,,,,,,,,,,, -21500,2.5318787,1.515238,,,,,,,,,,,,,, -21600,2.0976555,1.4089346,,,,,,,,,,,,,, -21700,2.5415437,1.3747054,,,,,,,,,,,,,, -21800,2.9548626,1.4544313,,,,,,,,,,,,,, -21900,2.3605053,1.3717813,,,,,,,,,,,,,, -22000,4.1046658,1.4838159,,,,,,,,,,,,,, -22100,2.2469645,1.4906996,,,,,,,,,,,,,, -22200,3.2453473,1.4488587,,,,,,,,,,,,,, -22300,2.9284556,1.4901426,,,,,,,,,,,,,, -22360,,,0.40284434,0.1338424391521736,0.5693361,0.1629319250412736,5348.0,0.3231798,0.1041984035098409,2472.0,18743.48190689087,20685.82173585892,18743.48190689087,1940.645127773285,0.6647696495056152,0.0 -22400,3.0177894,1.448665,,,,,,,,,,,,,, -22500,2.3592987,1.4344561,,,,,,,,,,,,,, -22600,1.4515202,1.4454288,,,,,,,,,,,,,, -22700,3.7851667,1.3930787,,,,,,,,,,,,,, -22800,1.8587747,1.5924829,,,,,,,,,,,,,, -22900,2.5708704,1.4316274,,,,,,,,,,,,,, -23000,2.4900188,1.4703014,,,,,,,,,,,,,, -23100,4.414086,1.3901914,,,,,,,,,,,,,, -23200,2.091103,1.3854152,,,,,,,,,,,,,, -23300,1.5320128,1.4528738,,,,,,,,,,,,,, -23400,1.7970518,1.4141226,,,,,,,,,,,,,, -23500,4.427203,1.4517602,,,,,,,,,,,,,, -23600,2.4104092,1.4184248,,,,,,,,,,,,,, -23700,2.4033248,1.4429132,,,,,,,,,,,,,, -23800,3.6410334,1.443665,,,,,,,,,,,,,, -23900,2.2178562,1.4070739,,,,,,,,,,,,,, -24000,2.0302978,1.4577302,,,,,,,,,,,,,, -24075,,,0.43493432,0.138860300698049,0.5609899,0.1619761143883294,5348.0,0.31203714,0.1026953466171064,2472.0,20183.945833444595,22260.01620745659,20183.945833444595,2074.248395442962,0.7133445739746094,0.0 -24100,2.4223084,1.4195491,,,,,,,,,,,,,, -24200,3.0282855,1.355144,,,,,,,,,,,,,, -24300,2.7288923,1.3964356,,,,,,,,,,,,,, -24400,2.1800845,1.4028659,,,,,,,,,,,,,, -24500,1.7078965,1.4274509,,,,,,,,,,,,,, -24600,3.4451015,1.44244,,,,,,,,,,,,,, -24700,1.8575695,1.408935,,,,,,,,,,,,,, -24800,1.779169,1.359853,,,,,,,,,,,,,, -24900,3.942451,1.4579166,,,,,,,,,,,,,, -25000,2.7704198,1.4695488,,,,,,,,,,,,,, -25100,1.8832663,1.4473208,,,,,,,,,,,,,, -25200,1.9465151,1.3621743,,,,,,,,,,,,,, -25300,3.0985146,1.4876819,,,,,,,,,,,,,, -25400,1.7288582,1.395996,,,,,,,,,,,,,, -25500,2.1641622,1.432336,,,,,,,,,,,,,, -25600,3.1234872,1.3742055,,,,,,,,,,,,,, -25700,2.2592316,1.4981575,,,,,,,,,,,,,, -25757,,,0.361608,0.1208505050839415,0.5388162,0.1562702144298444,5348.0,0.30561492,0.0986939654296914,2472.0,21624.523320913315,23836.974450588223,21624.523320913315,2210.4984588623047,0.7659845352172852,0.0 -25800,2.297072,1.4125239,,,,,,,,,,,,,, -25900,1.619115,1.3916334,,,,,,,,,,,,,, -26000,1.5775898,1.3466592,,,,,,,,,,,,,, -26100,3.8032873,1.4443276,,,,,,,,,,,,,, -26200,1.9556262,1.3921679,,,,,,,,,,,,,, -26300,2.335743,1.4482431,,,,,,,,,,,,,, -26400,2.4291947,1.3914706,,,,,,,,,,,,,, -26500,1.7052732,1.4112784,,,,,,,,,,,,,, -26600,2.2763357,1.4051481,,,,,,,,,,,,,, -26700,9.18308,1.4222796,,,,,,,,,,,,,, -26800,2.3011835,1.3598189,,,,,,,,,,,,,, -26900,3.0176275,1.3841733,,,,,,,,,,,,,, -27000,1.6900948,1.3764112,,,,,,,,,,,,,, -27100,3.4725814,1.360174,,,,,,,,,,,,,, -27200,1.8841788,1.3798923,,,,,,,,,,,,,, -27300,1.6883776,1.397174,,,,,,,,,,,,,, -27400,3.4845002,1.3678157,,,,,,,,,,,,,, -27487,,,0.37259036,0.1272394987236017,0.5171766,0.1500526178591772,5348.0,0.29226407,0.0947738305608027,2472.0,23064.93439102173,25411.825412988663,23064.93439102173,2344.8034658432007,0.8201003074645996,0.0 -27500,3.422883,1.2904173,,,,,,,,,,,,,, -27600,2.3876765,1.3769957,,,,,,,,,,,,,, -27700,2.0715477,1.3923739,,,,,,,,,,,,,, -27800,3.02215,1.4246385,,,,,,,,,,,,,, -27900,3.121164,1.392402,,,,,,,,,,,,,, -28000,2.494826,1.3156729,,,,,,,,,,,,,, -28100,2.8908615,1.3490454,,,,,,,,,,,,,, -28200,2.2783399,1.3874544,,,,,,,,,,,,,, -28300,2.6253085,1.3538479,,,,,,,,,,,,,, -28400,2.3654068,1.3549675,,,,,,,,,,,,,, -28500,2.1110518,1.3620814,,,,,,,,,,,,,, -28600,1.869286,1.3660975,,,,,,,,,,,,,, -28700,2.0614917,1.3802696,,,,,,,,,,,,,, -28800,3.2029238,1.292813,,,,,,,,,,,,,, -28900,4.9998164,1.3803817,,,,,,,,,,,,,, -29000,3.173962,1.3784496,,,,,,,,,,,,,, -29100,2.6169453,1.3521144,,,,,,,,,,,,,, -29200,2.4810326,1.2983044,,,,,,,,,,,,,, -29212,,,0.30891111,0.1043790285799066,0.5069906,0.1449549610434749,5348.0,0.2780522,0.0901834135640728,2472.0,24504.82792067528,26985.95388007164,24504.82792067528,2478.910748243332,0.8709836006164551,0.0 -29300,2.7375576,1.3332751,,,,,,,,,,,,,, -29400,1.7273302,1.3244056,,,,,,,,,,,,,, -29500,2.3634093,1.3336606,,,,,,,,,,,,,, -29600,2.8183208,1.3674177,,,,,,,,,,,,,, -29700,2.966894,1.2632389,,,,,,,,,,,,,, -29800,2.2571259,1.3440001,,,,,,,,,,,,,, -29900,3.503938,1.3344891,,,,,,,,,,,,,, -30000,2.5546086,1.3781646,,,,,,,,,,,,,, -30100,3.285717,1.3353652,,,,,,,,,,,,,, -30200,2.3623424,1.2982166,,,,,,,,,,,,,, -30300,2.5248573,1.3324381,,,,,,,,,,,,,, -30400,1.5959208,1.3355238,,,,,,,,,,,,,, -30500,1.8390731,1.2564707,,,,,,,,,,,,,, -30600,2.6979837,1.289243,,,,,,,,,,,,,, -30700,2.0685375,1.2491264,,,,,,,,,,,,,, -30800,4.4151254,1.2938359,,,,,,,,,,,,,, -30900,2.9004405,1.339847,,,,,,,,,,,,,, -30904,,,0.3336694,0.1106648943691846,0.5003634,0.1430916130028867,5348.0,0.27133593,0.0875429082119716,2472.0,25945.360632419583,28559.09006714821,25945.360632419583,2611.3780493736267,0.9297785758972168,0.0 -31000,1.7499745,1.3056216,,,,,,,,,,,,,, -31100,2.6845317,1.2912021,,,,,,,,,,,,,, -31200,2.4741273,1.2788326,,,,,,,,,,,,,, -31300,2.0924332,1.2615496,,,,,,,,,,,,,, -31400,2.9010992,1.2711463,,,,,,,,,,,,,, -31500,2.0311573,1.2884667,,,,,,,,,,,,,, -31600,3.1765525,1.3112013,,,,,,,,,,,,,, -31700,4.2129855,1.3143663,,,,,,,,,,,,,, -31800,1.9183481,1.2838017,,,,,,,,,,,,,, -31900,1.5388061,1.3337852,,,,,,,,,,,,,, -32000,2.1237562,1.2542367,,,,,,,,,,,,,, -32100,2.7071435,1.3210001,,,,,,,,,,,,,, -32200,2.1501043,1.2836559,,,,,,,,,,,,,, -32300,3.453486,1.2587147,,,,,,,,,,,,,, -32400,3.4070811,1.3418872,,,,,,,,,,,,,, -32500,2.1681964,1.2923735,,,,,,,,,,,,,, -32600,5.3134837,1.2690823,,,,,,,,,,,,,, -32614,,,0.3215798,0.1096156143514594,0.4891848,0.1402531450032343,5348.0,0.26732162,0.084963337598765,2472.0,27385.716521024704,30134.061345100403,27385.716521024704,2745.85751581192,0.98575758934021,0.0 -32700,4.5203876,1.3103151,,,,,,,,,,,,,, -32800,2.7854955,1.2561584,,,,,,,,,,,,,, -32900,1.7988371,1.2846125,,,,,,,,,,,,,, -33000,1.8748662,1.2339853,,,,,,,,,,,,,, -33100,2.9607232,1.309509,,,,,,,,,,,,,, -33200,2.363251,1.3115753,,,,,,,,,,,,,, -33300,3.1086435,1.2658063,,,,,,,,,,,,,, -33400,2.2621646,1.2429836,,,,,,,,,,,,,, -33500,3.0728452,1.243752,,,,,,,,,,,,,, -33600,3.6560879,1.3085915,,,,,,,,,,,,,, -33700,2.8419774,1.2921351,,,,,,,,,,,,,, -33800,3.2763872,1.2982321,,,,,,,,,,,,,, -33900,2.0743906,1.2379104,,,,,,,,,,,,,, -34000,1.8926464,1.2189646,,,,,,,,,,,,,, -34100,1.944517,1.2962472,,,,,,,,,,,,,, -34200,2.4957623,1.2567964,,,,,,,,,,,,,, -34300,2.7393005,1.2492063,,,,,,,,,,,,,, -34328,,,0.30576283,0.103337584099204,0.4679304,0.1343638066366085,5348.0,0.2541731,0.0815916153799281,2472.0,28825.82073378563,31709.3510825634,28825.82073378563,2880.9082431793213,1.0413026809692385,0.0 -34400,2.3073123,1.2176849,,,,,,,,,,,,,, -34500,3.2668676,1.2318709,,,,,,,,,,,,,, -34600,2.9382885,1.3587935,,,,,,,,,,,,,, -34700,2.8268054,1.2416941,,,,,,,,,,,,,, -34800,2.0569634,1.2212116,,,,,,,,,,,,,, -34900,1.7240378,1.2663138,,,,,,,,,,,,,, -35000,2.3466518,1.3149445,,,,,,,,,,,,,, -35100,2.2750986,1.2347609,,,,,,,,,,,,,, -35200,2.0933788,1.2486509,,,,,,,,,,,,,, -35300,1.7779214,1.2222449,,,,,,,,,,,,,, -35400,1.7731366,1.321599,,,,,,,,,,,,,, -35500,2.8425329,1.2728467,,,,,,,,,,,,,, -35600,2.256974,1.2327468,,,,,,,,,,,,,, -35700,4.005249,1.2373058,,,,,,,,,,,,,, -35800,1.6644627,1.2233506,,,,,,,,,,,,,, -35900,2.5046082,1.3034897,,,,,,,,,,,,,, -36000,1.3684922,1.1756003,,,,,,,,,,,,,, -36023,,,0.24863774,0.0858738576102899,0.45736533,0.1318342875348774,5348.0,0.24383034,0.0783620742185119,2472.0,30266.26204228401,33283.799632787704,30266.26204228401,3014.781930685044,1.0968976020812988,0.0 -36100,1.66841,1.2614559,,,,,,,,,,,,,, -36200,3.6780632,1.1955787,,,,,,,,,,,,,, -36300,2.5860047,1.2232914,,,,,,,,,,,,,, -36400,2.5500407,1.2530943,,,,,,,,,,,,,, -36500,2.1209903,1.2689474,,,,,,,,,,,,,, -36600,2.2418394,1.200424,,,,,,,,,,,,,, -36700,1.9009545,1.1863576,,,,,,,,,,,,,, -36800,8.007783,1.1884316,,,,,,,,,,,,,, -36900,3.2688668,1.237666,,,,,,,,,,,,,, -37000,4.387223,1.2358632,,,,,,,,,,,,,, -37100,2.0729203,1.2704333,,,,,,,,,,,,,, -37200,2.5048652,1.2339417,,,,,,,,,,,,,, -37300,1.6210994,1.1568027,,,,,,,,,,,,,, -37400,1.850986,1.2331284,,,,,,,,,,,,,, -37500,4.552544,1.200725,,,,,,,,,,,,,, -37600,1.523679,1.2175021,,,,,,,,,,,,,, -37700,2.333252,1.2004685,,,,,,,,,,,,,, -37719,,,0.22751224,0.0775546242095184,0.44469142,0.1273352192088977,5348.0,0.24005848,0.0765340320516726,2472.0,31706.707235336304,34858.85399699211,31706.707235336304,3149.2583301067352,1.1505632400512695,0.0 -37800,2.1442494,1.2531867,,,,,,,,,,,,,, -37900,2.7311788,1.2537072,,,,,,,,,,,,,, -38000,2.1913772,1.2285326,,,,,,,,,,,,,, -38100,2.9046419,1.217014,,,,,,,,,,,,,, -38200,2.0261533,1.1995392,,,,,,,,,,,,,, -38300,4.5956225,1.2367858,,,,,,,,,,,,,, -38400,2.4005551,1.2218022,,,,,,,,,,,,,, -38500,2.3528373,1.2580017,,,,,,,,,,,,,, -38600,2.6554494,1.1650081,,,,,,,,,,,,,, -38700,2.803607,1.2062744,,,,,,,,,,,,,, -38800,2.0250723,1.2137089,,,,,,,,,,,,,, -38900,2.006204,1.1926128,,,,,,,,,,,,,, -39000,2.3692694,1.2152333,,,,,,,,,,,,,, -39100,2.389249,1.2219963,,,,,,,,,,,,,, -39200,2.846703,1.2478969,,,,,,,,,,,,,, -39300,3.3333035,1.1528847,,,,,,,,,,,,,, -39400,1.9309202,1.1589952,,,,,,,,,,,,,, -39437,,,0.2554167,0.0892404637499159,0.43711883,0.1243326221072245,5348.0,0.23413983,0.0745638088274125,2472.0,33147.02962565422,36434.26735687256,33147.02962565422,3284.214804172516,1.206776142120361,0.0 -39500,2.5312848,1.2118425,,,,,,,,,,,,,, -39600,3.7952173,1.1936289,,,,,,,,,,,,,, -39700,3.5871713,1.2124304,,,,,,,,,,,,,, -39800,3.2722185,1.1543584,,,,,,,,,,,,,, -39900,1.9571007,1.2041439,,,,,,,,,,,,,, -40000,2.290141,1.191494,,,,,,,,,,,,,, -40100,2.134536,1.202128,,,,,,,,,,,,,, -40200,4.6667047,1.1506741,,,,,,,,,,,,,, -40300,2.1615674,1.1555729,,,,,,,,,,,,,, -40400,2.0174875,1.119246,,,,,,,,,,,,,, -40500,3.0722136,1.2114428,,,,,,,,,,,,,, -40600,1.3998482,1.1450883,,,,,,,,,,,,,, -40700,3.0106099,1.1977664,,,,,,,,,,,,,, -40800,2.095177,1.1714375,,,,,,,,,,,,,, -40900,3.606102,1.2064334,,,,,,,,,,,,,, -41000,2.7977939,1.1743715,,,,,,,,,,,,,, -41100,2.355592,1.1943952,,,,,,,,,,,,,, -41134,,,0.2320539,0.0801694666651623,0.42751628,0.1214362261892118,5348.0,0.23012903,0.07413726565515,2472.0,34586.98317909241,38007.80613279343,34586.98317909241,3417.66161775589,1.2691385746002195,0.0 -41200,3.772259,1.147287,,,,,,,,,,,,,, -41300,5.7286434,1.1354872,,,,,,,,,,,,,, -41400,3.282774,1.1808136,,,,,,,,,,,,,, -41500,2.1543653,1.1355418,,,,,,,,,,,,,, -41600,1.9087523,1.1646903,,,,,,,,,,,,,, -41700,2.298044,1.1751798,,,,,,,,,,,,,, -41800,2.0589724,1.1875831,,,,,,,,,,,,,, -41900,2.7918782,1.1866941,,,,,,,,,,,,,, -42000,2.0121186,1.1386627,,,,,,,,,,,,,, -42100,2.9019313,1.170217,,,,,,,,,,,,,, -42200,3.6805875,1.1977036,,,,,,,,,,,,,, -42300,1.7718693,1.2049848,,,,,,,,,,,,,, -42400,2.72365,1.1754214,,,,,,,,,,,,,, -42500,2.0342057,1.1509681,,,,,,,,,,,,,, -42600,2.6830485,1.1519161,,,,,,,,,,,,,, -42700,1.5348516,1.1731558,,,,,,,,,,,,,, -42800,1.6375452,1.1859809,,,,,,,,,,,,,, -42832,,,0.24211791,0.0808709346414912,0.42340276,0.1203452503934271,5348.0,0.225074,0.0719639266345743,2472.0,36027.01313138008,39582.16630578041,36027.01313138008,3551.85942363739,1.322899580001831,0.0 -42900,2.9673238,1.191998,,,,,,,,,,,,,, -43000,2.6160605,1.1188197,,,,,,,,,,,,,, -43100,2.5713747,1.1267529,,,,,,,,,,,,,, -43200,1.9885579,1.1646702,,,,,,,,,,,,,, -43300,1.7891163,1.1461515,,,,,,,,,,,,,, -43400,2.687062,1.2056495,,,,,,,,,,,,,, -43500,2.3717663,1.1602384,,,,,,,,,,,,,, -43600,2.1228476,1.18974,,,,,,,,,,,,,, -43700,2.1115677,1.174139,,,,,,,,,,,,,, -43800,2.4652283,1.1425015,,,,,,,,,,,,,, -43900,4.5383177,1.1215017,,,,,,,,,,,,,, -44000,5.0948286,1.1990303,,,,,,,,,,,,,, -44100,2.2755442,1.1789887,,,,,,,,,,,,,, -44200,1.3271744,1.0896486,,,,,,,,,,,,,, -44300,1.9794769,1.1814002,,,,,,,,,,,,,, -44400,2.590187,1.1249416,,,,,,,,,,,,,, -44500,2.4807684,1.126992,,,,,,,,,,,,,, -44550,,,0.2000132,0.0698117025523792,0.4208656,0.119881827046545,5348.0,0.22411768,0.0716998760993642,2472.0,37467.38977432251,41159.824618816376,37467.38977432251,3689.00634765625,1.3793542385101318,0.0 -44600,2.7864914,1.1298082,,,,,,,,,,,,,, -44700,2.43474,1.1607352,,,,,,,,,,,,,, -44800,1.6225997,1.1220149,,,,,,,,,,,,,, -44900,2.5069323,1.1737405,,,,,,,,,,,,,, -45000,1.918609,1.1254709,,,,,,,,,,,,,, -45100,2.8321567,1.1626766,,,,,,,,,,,,,, -45200,1.7065239,1.1394788,,,,,,,,,,,,,, -45300,2.8153183,1.0868064,,,,,,,,,,,,,, -45400,1.8422242,1.066611,,,,,,,,,,,,,, -45500,4.206991,1.1697434,,,,,,,,,,,,,, -45600,4.403618,1.1417897,,,,,,,,,,,,,, -45700,1.9714476,1.189083,,,,,,,,,,,,,, -45800,2.0244951,1.2286575,,,,,,,,,,,,,, -45900,6.0410986,1.114196,,,,,,,,,,,,,, -46000,3.4429586,1.1871176,,,,,,,,,,,,,, -46100,3.3357801,1.1619506,,,,,,,,,,,,,, -46200,3.728532,1.1866401,,,,,,,,,,,,,, -46233,,,0.20435266,0.0718254876884155,0.41903222,0.1190032536180812,5348.0,0.22205608,0.0704202465825767,2472.0,38907.78763270378,42735.924476861954,38907.78763270378,3824.576935768128,1.4339289665222168,0.0 -46300,1.7601196,1.169534,,,,,,,,,,,,,, -46400,2.0516071,1.1653374,,,,,,,,,,,,,, -46500,4.234861,1.1211479,,,,,,,,,,,,,, -46600,3.14692,1.1985924,,,,,,,,,,,,,, -46700,2.6189115,1.1481397,,,,,,,,,,,,,, -46800,2.0113142,1.1651821,,,,,,,,,,,,,, -46900,1.5827571,1.1261779,,,,,,,,,,,,,, -47000,1.6572095,1.1731961,,,,,,,,,,,,,, -47100,2.1569927,1.1591283,,,,,,,,,,,,,, -47200,4.07708,1.1362168,,,,,,,,,,,,,, -47300,2.1000214,1.1115694,,,,,,,,,,,,,, -47400,3.6087468,1.1565562,,,,,,,,,,,,,, -47500,1.9806669,1.1776345,,,,,,,,,,,,,, -47600,4.979935,1.0926243,,,,,,,,,,,,,, -47700,3.7453156,1.1808792,,,,,,,,,,,,,, -47800,2.3193614,1.135165,,,,,,,,,,,,,, -47900,3.7583153,1.1031457,,,,,,,,,,,,,, -47961,,,0.22109444,0.0784157378244303,0.41872,0.1186267221487395,5348.0,0.22187033,0.0707452318566815,2472.0,40348.03506684303,44312.138027668,40348.03506684303,3960.401820898056,1.495056390762329,0.0 -48000,,,0.14898208,0.0527233918065292,0.4187202,0.1186363768017996,5348.0,0.2218685,0.0707452318566815,2472.0,40377.33175325394,44474.47600841522,40377.33175325394,4093.375575065613,1.5548663139343262,0.0 -48000,,,,,,,,,,,40377.33175325394,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 8f191ae7e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -305.463401556015,0.0,18.12672519683838,1,0,18.12672519683838,0.5256852507591248,0.7376683354377747,0.0260269185740518,43793,323.5901656150818,0.5290504693984985,0.7363465428352356,0.0208464848778457,0.5270814895629883,0.737440824508667,0.024032072784418,43793 -424.9269685745239,0.0321342945098876,258.0806932449341,738,0,258.0806932449341,0.983142077922821,0.0811450853943824,0.041607852548778,43793,683.0615284442902,0.986726939678192,0.0696189478039741,0.0375416835773601,0.9841179251670836,0.0782666727900505,0.039824265408219,43793 -545.8983333110809,0.0606129169464111,498.13424348831177,1472,0,498.13424348831177,0.9831724166870116,0.0650959089398384,0.074120522459401,43793,1044.1351709365845,0.9868612289428712,0.0518169067800045,0.0714212744243948,0.9841642379760742,0.0616440325975418,0.0751438053838872,43793 -668.7412447929382,0.0881714820861816,738.378705739975,2221,0,738.378705739975,0.9841352701187134,0.0557597503066062,0.1304286933554119,43793,1407.2695829868317,0.9878098964691162,0.043958980590105,0.136974567316324,0.9851465821266174,0.0528057515621185,0.1295014808454525,43793 -790.6891014575958,0.1168539524078369,978.384375333786,2968,0,978.384375333786,0.9843475222587584,0.0542657747864723,0.1536343785086584,43793,1769.2708656787872,0.9880942702293396,0.0420986711978912,0.1608620839752452,0.985255777835846,0.0515386573970317,0.1466299149649001,43793 -916.2338988780976,0.1446871757507324,1218.5362486839294,3714,0,1218.5362486839294,0.9845543503761292,0.0525040216743946,0.1697330694428799,43793,2135.0145077705383,0.988378643989563,0.0402594655752182,0.1823391436100207,0.9854835271835328,0.0498136468231678,0.1700115195174815,43793 -1045.9301965236664,0.1723625659942627,1458.7867727279663,4458,0,1458.7867727279663,0.9847325086593628,0.0519967675209045,0.1873038039271249,43793,2505.007899045944,0.9884053468704224,0.0399228520691394,0.2100107487256066,0.9856386184692384,0.0493864566087722,0.1833988949672098,43793 -1176.5565106868744,0.1991899013519287,1698.7766785621643,5192,0,1698.7766785621643,0.9848563075065612,0.0509334243834018,0.1923857107825827,43793,2875.6716067790985,0.9888152480125428,0.0387150831520557,0.2292613295651795,0.9857624173164368,0.048279769718647,0.1916392368220283,43793 -1304.0002155303955,0.228431224822998,1938.92862033844,5929,0,1938.92862033844,0.9850403666496276,0.0503462068736553,0.208156916642346,43793,3243.315950632096,0.9890048503875732,0.037666168063879,0.2492122623022253,0.9859474897384644,0.0476326122879982,0.2099731371247263,43793 -1431.8058450222015,0.2576940059661865,2179.0981862545013,6671,0,2179.0981862545013,0.9851086139678956,0.0502110980451107,0.2054259037859834,43793,3611.34085059166,0.9889360070228576,0.0377796180546283,0.2379796961105802,0.9859402179718018,0.0476204343140125,0.2081422313743727,43793 -1561.7417376041412,0.2840974330902099,2419.294604063034,7435,0,2419.294604063034,0.9852202534675598,0.0493000671267509,0.2102784090723167,43793,3981.51959657669,0.988992154598236,0.0374467149376869,0.2587613004191066,0.9860936403274536,0.0468455292284488,0.2130963062876768,43793 -1689.6956989765167,0.3128774166107178,2659.541307926178,8200,0,2659.541307926178,0.9851284027099608,0.0495839193463325,0.2186114076168888,43793,4349.7693428993225,0.9891512393951416,0.0365913212299346,0.2610685178132049,0.9860554933547974,0.0468535237014293,0.2194025693009693,43793 -1818.4839661121368,0.3447468280792236,2899.5621926784515,8952,0,2899.5621926784515,0.9851734638214112,0.0489959381520748,0.2270890372536049,43793,4718.630972146988,0.9891273975372314,0.0366468429565429,0.2791102145414673,0.98610258102417,0.0462144687771797,0.2229995826347963,43793 -1948.37777876854,0.3719336986541748,3139.516256332397,9706,0,3139.516256332397,0.9853731393814088,0.0486707761883735,0.2265520511640969,43793,5088.528156280518,0.98965722322464,0.0352459549903869,0.2876529415246712,0.9862491488456726,0.0461195521056652,0.2190752261508913,43793 -2074.412222862244,0.4013481140136719,3379.5207164287567,10468,0,3379.5207164287567,0.9854143857955932,0.0490829050540924,0.2329932907320756,43793,5454.616222381592,0.9896705150604248,0.0348237752914428,0.3068134647972326,0.9862880706787108,0.0463917814195156,0.2411961410225304,43793 -2202.437658548355,0.4295799732208252,3619.7391617298126,11226,0,3619.7391617298126,0.9855487942695618,0.0485002435743808,0.239486379839807,43793,5822.907917261124,0.9897632002830504,0.0342228785157203,0.3330268928440666,0.9864967465400696,0.0456630326807498,0.2484119453619031,43793 -2333.4328026771545,0.4608142375946045,3859.737271785736,11982,0,3859.737271785736,0.9855656027793884,0.0477055162191391,0.2389553858798118,43793,6193.9528086185455,0.9901514053344728,0.0327607914805412,0.3584733182330887,0.9865888953208924,0.0448321178555488,0.2500146851650296,43793 -2464.989315032959,0.4908986091613769,4099.829087495804,12739,0,4099.829087495804,0.985632598400116,0.0476379878818988,0.2430544044938929,43793,6565.651482105255,0.9903623461723328,0.0321791432797908,0.3770886565685112,0.9865893125534058,0.0448479317128658,0.2501933144468553,43793 -2594.4998412132263,0.5199365615844727,4339.98765039444,13493,0,4339.98765039444,0.985736608505249,0.0479441359639167,0.2448625925982244,43793,6935.369527339935,0.9904111623764038,0.0319937393069267,0.3685707194392107,0.986591339111328,0.0450877733528614,0.2536282798796875,43793 -2724.34804224968,0.5486664772033691,4580.077013969421,14249,0,4580.077013969421,0.9857930541038512,0.0473803281784057,0.2434473027563511,43793,7305.356061458588,0.9904477000236512,0.0317926034331321,0.3795072526894285,0.9866713285446168,0.0447172634303569,0.2544789447205896,43793 -2855.0291335582733,0.5781416893005371,4820.203650474548,15004,0,4820.203650474548,0.9858074188232422,0.0476044304668903,0.2454498017573989,43793,7676.213287353516,0.9902809858322144,0.0322612114250659,0.3739686437256189,0.9866579174995422,0.0448336601257324,0.2554795498068286,43793 -2991.6243121624,0.6067461967468262,5060.340332508087,15757,0,5060.340332508087,0.98580402135849,0.0471926145255565,0.2571249985640839,43793,8052.993452072144,0.9904695749282836,0.0315851792693138,0.3780887865883426,0.9866855144500732,0.044472336769104,0.2637727440673564,43793 -3126.3872005939484,0.6358902454376221,5300.3963787555695,16510,0,5300.3963787555695,0.985867202281952,0.047187402844429,0.2516207197326865,43793,8427.862237930298,0.9904604554176332,0.0316490679979324,0.3682752869847712,0.9867634773254396,0.0443042665719985,0.2625507900616484,43793 -3253.9894936084747,0.6687760353088379,5540.545610666275,17255,0,5540.545610666275,0.9859405159950256,0.04707632958889,0.2522397872860027,43793,8795.66866350174,0.9904911518096924,0.0313327424228191,0.3850487983348396,0.9867411255836488,0.044379997998476,0.2592182022289485,43793 -3387.650359153748,0.7024412155151367,5780.521510839462,18009,0,5780.521510839462,0.9858688712120056,0.0477351546287536,0.2433850683436001,43793,9169.35992527008,0.9906929135322572,0.0308549869805574,0.3971799953290061,0.9866570830345154,0.0449607670307159,0.2552737329943473,43793 -3518.769593477249,0.7336812019348145,6020.512645244598,18759,0,6020.512645244598,0.985846996307373,0.0472518391907215,0.2473320580292265,43793,9540.521861076357,0.990839421749115,0.030348252505064,0.4093167248324422,0.9866721034049988,0.0445089861750602,0.2656968890905586,43793 -3647.493248462677,0.7661452293395996,6260.5007147789,19514,0,6260.5007147789,0.9858187437057496,0.0476085133850574,0.2507867971943465,43793,9909.286611318588,0.9907936453819276,0.0300093758851289,0.4283451389177612,0.9867143034934998,0.0449393689632415,0.2623142473408286,43793 -3778.8022241592407,0.7969620227813721,6500.6417491436005,20272,0,6500.6417491436005,0.9859198331832886,0.0473172441124916,0.2484191606510649,43793,10280.78795480728,0.991016924381256,0.0292759668081998,0.4346414677193345,0.9868003726005554,0.0445684269070625,0.2706530486331538,43793 -3906.517117500305,0.8276827335357666,6740.721883535385,21028,0,6740.721883535385,0.9858166575431824,0.0471398457884788,0.2538452476150781,43793,10648.634120225906,0.9911913275718688,0.029066402465105,0.4319467944312771,0.9865812063217164,0.0446751601994037,0.2618761943758421,43793 -4039.047303438186,0.8594973087310791,6980.779272079468,21774,0,6980.779272079468,0.9858962893486024,0.0469451062381267,0.2553447340116264,43793,11021.275067090988,0.991000235080719,0.0297731887549161,0.4357048122346068,0.9867650866508484,0.0443596877157688,0.2604349182296928,43793 -4167.759459018707,0.8915479183197021,7220.825093746185,22527,0,7220.825093746185,0.985925316810608,0.0474839806556701,0.2561323133732616,43793,11390.0852560997,0.990811824798584,0.0301263500005006,0.3952604718650014,0.9867122769355774,0.0446335412561893,0.2590375873548066,43793 -4299.94361448288,0.9270291328430176,7460.819598436356,23280,0,7460.819598436356,0.985958993434906,0.0472457595169544,0.2608490687676737,43793,11762.319630622864,0.9908799529075624,0.0300232395529747,0.4227056232213048,0.9867395162582396,0.0447255373001098,0.2640613190295243,43793 -4432.592004299164,1.2428011894226074,7700.661962509155,24031,0,7700.661962509155,0.9859577417373656,0.04722660779953,0.2516795713976436,43793,12135.146509170532,0.9909188747406006,0.0297171231359243,0.4331208281078121,0.9867057800292968,0.0446142293512821,0.2611251490209201,43793 -4566.1248388290405,1.2736265659332275,7940.79802775383,24785,0,7940.79802775383,0.9858036041259766,0.0478490367531776,0.2503198834780422,43793,12508.866287469864,0.9908357858657836,0.0298755336552858,0.4064477900468354,0.9866595268249512,0.0450131893157959,0.2573716657473429,43793 -4701.938496828079,1.304905891418457,8180.888298749924,25537,0,8180.888298749924,0.9858790040016174,0.0471441224217414,0.2556667009203458,43793,12884.822046756744,0.9911275506019592,0.0291652176529169,0.4410087981079901,0.9867557287216188,0.0444139316678047,0.271666663964825,43793 -4835.052654981613,1.3399152755737305,8421.127045869827,26288,0,8421.127045869827,0.985999882221222,0.0471678562462329,0.2552766763990095,43793,13258.230972528458,0.9913343191146852,0.0283323358744382,0.4558309209604774,0.986750066280365,0.0446289516985416,0.2663130309218301,43793 -4968.371803283691,1.371005296707153,8661.194565296173,27044,0,8661.194565296173,0.9859737753868104,0.0477073155343532,0.2546293091749866,43793,13631.669520616531,0.9913906455039978,0.0281543508172035,0.4644263463446575,0.9868559837341307,0.0448550656437873,0.2704016450281102,43793 -5099.115159749985,1.403883695602417,8901.284093856812,27801,0,8901.284093856812,0.9859451055526732,0.0478653758764267,0.2571030521372473,43793,14002.555571317673,0.9913938045501708,0.0278185661882162,0.476967536096138,0.986710250377655,0.0451094470918178,0.2647917434751276,43793 -5230.611914157867,1.4391045570373535,9141.320052146912,28546,0,9141.320052146912,0.9859514236450196,0.0475336760282516,0.2583526410611838,43793,14374.145962715147,0.9915661811828612,0.0271521862596273,0.48831206860766,0.986932337284088,0.0444307848811149,0.2773346118327409,43793 -5365.830714225769,1.4701387882232666,9381.337366342545,29299,0,9381.337366342545,0.9858566522598268,0.047407079488039,0.2526362549918176,43793,14749.433589935305,0.9914606809616088,0.0279100723564624,0.4554978870592862,0.9867382645606996,0.044542621821165,0.2711044475252064,43793 -5498.725798130035,1.5060889720916748,9621.466262102129,30045,0,9621.466262102129,0.9859737753868104,0.0471785925328731,0.2610579508779325,43793,15122.514045715332,0.9913190603256226,0.0283074267208576,0.4533270602357647,0.986749231815338,0.044565699994564,0.2735216258339702,43793 -5629.058895349503,1.540019989013672,9861.666239261627,30794,0,9861.666239261627,0.9858756065368652,0.0477723777294158,0.2537742032421881,43793,15493.1018307209,0.9913336038589478,0.0282726194709539,0.4520790403454369,0.986777663230896,0.0448727048933506,0.2668618929323761,43793 -5765.29677939415,1.5725555419921875,10101.732964754105,31550,0,10101.732964754105,0.9859691262245178,0.0474783778190612,0.263347388682702,43793,15869.459534406662,0.9914684891700744,0.0279049314558506,0.46250529010033,0.986751675605774,0.0447242632508277,0.2758591528506365,43793 -5896.408312559128,1.604877471923828,10341.703769683838,32282,0,10341.703769683838,0.985917329788208,0.0468419790267944,0.2613756747206935,43793,16240.596867084503,0.9915327429771424,0.0276380106806755,0.4701164224800756,0.9868023991584778,0.0441828817129135,0.2724851975870768,43793 -6031.045778036118,1.6387574672698977,10581.765675783156,33014,0,10581.765675783156,0.9860011339187622,0.0471965111792087,0.2568954076632428,43793,16615.352819919586,0.9916948676109314,0.026967754587531,0.4762917360420234,0.9868279695510864,0.0444532893598079,0.2725076791228633,43793 -6161.0498831272125,1.6721007823944092,10821.961535930634,33766,0,10821.961535930634,0.9859564900398254,0.0474421940743923,0.2586274867507777,43793,16985.606071949005,0.9919641613960266,0.0263283308595418,0.4970465806059808,0.9867817163467408,0.0448424257338047,0.2703013692188287,43793 -6289.505045890808,1.7057373523712158,11062.15025663376,34522,0,11062.15025663376,0.9859733581542968,0.0475214347243309,0.2607594029126369,43793,17354.303694963455,0.9921725392341614,0.0255570504814386,0.527538363625475,0.986916482448578,0.0447726920247077,0.2722659590388679,43793 -6419.514434099197,1.7393901348114014,11302.103286266329,35277,0,11302.103286266329,0.9859846830368042,0.0482017584145069,0.2530558781775611,43793,17724.31993317604,0.9921831488609314,0.0251275803893804,0.532644755953247,0.9869319200515748,0.0451381579041481,0.2747726059200759,43793 -6547.149572849274,1.7725646495819092,11542.11351633072,36035,0,11542.11351633072,0.9860104322433472,0.047689463943243,0.2610775836816156,43793,18092.01842617989,0.9919872879981996,0.0259499680250883,0.5178907389430005,0.9868706464767456,0.0447487831115722,0.2739533269282222,43793 -6677.873930931091,1.8053655624389648,11782.1854326725,36789,0,11782.1854326725,0.985889494419098,0.0480927638709545,0.2523359419552278,43793,18462.86779975891,0.9918658137321472,0.026509465649724,0.4845961109288035,0.9867196083068848,0.0451682284474372,0.2734755096168028,43793 -6806.229717254639,1.8392269611358645,12022.233164787292,37526,0,12022.233164787292,0.9860280752182008,0.0478722341358661,0.2616498358073807,43793,18831.325871944427,0.9918019771575928,0.0265244841575622,0.4839325470592795,0.9868669509887696,0.0447966158390045,0.2795964258277546,43793 -6936.348298549652,1.873117446899414,12262.27923464775,38275,0,12262.27923464775,0.9859615564346312,0.0483447611331939,0.255402046353583,43793,19201.54480075836,0.9918152093887328,0.0264482665807008,0.4996172600643373,0.9869120121002196,0.0453627854585647,0.2735382884365154,43793 -7063.6506524086,1.907481670379639,12502.245649576187,39022,0,12502.245649576187,0.98598051071167,0.0478197634220123,0.2580725558252392,43793,19568.8684835434,0.9920036792755128,0.025848040357232,0.501135744574851,0.9868023991584778,0.0449364073574543,0.2725329770589841,43793 -7195.441838502884,1.942678689956665,12742.505630731584,39763,0,12742.505630731584,0.9859746098518372,0.0481580346822738,0.2584333189610739,43793,19940.97686815262,0.9920294880867004,0.0255630780011415,0.5300515347756143,0.9869067668914796,0.0451199188828468,0.2812240884105178,43793 -7323.763496160507,1.9770872592926023,12982.455196619034,40508,0,12982.455196619034,0.9860655665397644,0.048062939196825,0.2669514125648769,43793,20309.30371117592,0.9923818707466124,0.0246172044426202,0.531692979601829,0.9868564009666444,0.0451633147895336,0.2772046084427561,43793 -7454.994160413742,2.011096239089966,13222.612513780594,41256,0,13222.612513780594,0.9858545660972596,0.0481712892651557,0.2574186207001034,43793,20680.74651002884,0.9925772547721864,0.0241183917969465,0.5373471320030823,0.986707866191864,0.0452901497483253,0.2748573775242192,43793 -7586.161582946777,2.044940948486328,13462.853985786438,42004,0,13462.853985786438,0.9858461618423462,0.0480821169912815,0.2624563335373316,43793,21052.20994591713,0.99293053150177,0.0230482406914234,0.5792572161274165,0.986726939678192,0.0451919101178646,0.2777320618341087,43793 -7717.086632013321,2.079592227935791,13703.067671060562,42744,0,13703.067671060562,0.98598051071167,0.0482771880924701,0.2657243764067328,43793,21423.40425777436,0.9928200840950012,0.0233676191419363,0.5596890151487883,0.9868795275688172,0.0452751107513904,0.2799710175790842,43793 -7845.803071022034,2.1141724586486816,13943.23703622818,43492,0,13943.23703622818,0.9860390424728394,0.0483165048062801,0.2653350767236811,43793,21792.34558987617,0.9925847053527832,0.0239696539938449,0.5528363329090942,0.986907124519348,0.0454422645270824,0.2818758340362011,43793 -7975.708652496338,2.1499216556549072,14183.470051765442,44245,0,14183.470051765442,0.9859662055969238,0.0486884862184524,0.2590255891925674,43793,22162.540717601776,0.9922870993614196,0.0249149054288864,0.5299943259670108,0.9868738651275636,0.0457279868423938,0.2731604479649976,43793 -8099.870770931244,2.1851043701171875,14423.669402837751,45001,0,14423.669402837751,0.9859042763710022,0.0485112071037292,0.2633550920425738,43793,22526.95771765709,0.9922727942466736,0.0247032344341278,0.5314834003758118,0.9867748022079468,0.0455899201333522,0.2845712141946099,43793 -8225.82084441185,2.220059871673584,14663.82266998291,45751,0,14663.82266998291,0.9860045313835144,0.0489345826208591,0.2583359577865653,43793,22893.1158246994,0.992500066757202,0.0239628087729215,0.5480373339681623,0.9867601990699768,0.0460610948503017,0.2770109295318826,43793 -8354.959605455399,2.2551698684692383,14903.953869581224,46501,0,14903.953869581224,0.9858596324920654,0.0496632941067218,0.2589811391864166,43793,23262.441887378693,0.9925301671028136,0.0238752625882625,0.5455075725286589,0.9867683053016664,0.0463589653372764,0.2753539428337459,43793 -8485.91120505333,2.2899508476257324,15144.071580171583,47254,0,15144.071580171583,0.9858705401420592,0.0494574196636676,0.254589348944634,43793,23633.566210269928,0.9927741289138794,0.0230254717171192,0.5728452454270465,0.9868178367614746,0.0460526496171951,0.2810175965239018,43793 -8611.761420249939,2.3253207206726074,15384.28563117981,48014,0,15384.28563117981,0.9859164953231812,0.0495819374918937,0.2659727477785244,43793,23999.686259269714,0.993111252784729,0.0219979658722877,0.6028270478090696,0.9867350459098816,0.0465046465396881,0.2760659886726114,43793 -8734.362447023392,2.3620529174804688,15624.298836946487,48776,0,15624.298836946487,0.9858773350715636,0.0503807589411735,0.2556186798312245,43793,24362.35687804222,0.9933006167411804,0.021312803030014,0.6074992806664145,0.9868174195289612,0.0469700321555137,0.271830096179202,43793 -8860.920503616333,2.397468328475952,15864.43145275116,49522,0,15864.43145275116,0.9858027696609496,0.0497420094907283,0.2576226855179493,43793,24729.10475111008,0.9936766624450684,0.0204411093145608,0.62012147547647,0.9867147207260132,0.0466574504971504,0.2757375957739105,43793 -8989.623442411423,2.433577060699463,16104.392055511476,50268,0,16104.392055511476,0.9857581257820128,0.0497113391757011,0.257039206586833,43793,25097.82476592064,0.9936593770980836,0.0206755194813013,0.6200574946558495,0.9865986108779908,0.0468205362558364,0.2705907601746179,43793 -9113.758713245392,2.470756769180298,16344.416038274763,51022,0,16344.416038274763,0.9858331084251404,0.0501727536320686,0.2515554863987931,43793,25462.04162096977,0.9933759570121764,0.0212582051753997,0.6098886010964394,0.9867504835128784,0.0470260456204414,0.2726649447568852,43793 -9241.55316233635,2.508890390396118,16584.653742313385,51772,0,16584.653742313385,0.9858819246292114,0.0503813885152339,0.257556443992949,43793,25830.131724357605,0.9933111667633056,0.0214274190366268,0.6009365828350484,0.9867772459983826,0.0471513681113719,0.2797023395820893,43793 -9371.00605082512,2.5452442169189453,16824.87526488304,52521,0,16824.87526488304,0.985834777355194,0.0509601235389709,0.2574955043291971,43793,26199.8629732132,0.9929652214050292,0.0222934540361166,0.5837654219526973,0.986805260181427,0.0476638115942478,0.2768003173815692,43793 -9496.152726888657,2.5827507972717285,17064.86261534691,53276,0,17064.86261534691,0.9858364462852478,0.0508538261055946,0.2604646324893419,43793,26565.05463194847,0.9930841326713562,0.0219203252345323,0.5777379497760418,0.986806869506836,0.0475819483399391,0.2802084358791125,43793 -9622.10262274742,2.618824481964112,17304.836101531982,54023,0,17304.836101531982,0.985871434211731,0.0509433448314666,0.2601165040001111,43793,26931.03513693809,0.9934099316596984,0.0208245944231748,0.6189938976762239,0.9866924285888672,0.0480832867324352,0.2718280150950659,43793 -9751.109506607056,2.6560540199279785,17545.082458496094,54777,0,17545.082458496094,0.9857589602470398,0.0514470785856246,0.2557963810743766,43793,27300.34612417221,0.993585765361786,0.0203476287424564,0.6222725297509415,0.9866132736206056,0.048272106796503,0.2787696840403332,43793 -9877.802606344225,2.69838547706604,17785.06728363037,55511,0,17785.06728363037,0.985806941986084,0.052032433450222,0.253222930490168,43793,27667.090751171112,0.9937710165977478,0.0195806194096803,0.6551641993310239,0.9865986108779908,0.0490063689649105,0.2737162962179191,43793 -9996.913838386536,2.7351033687591557,18025.3019797802,56263,0,18025.3019797802,0.9857622981071472,0.0521343909204006,0.2585453414761834,43793,28026.49404001236,0.994708776473999,0.0172035414725542,0.6975938120108676,0.9865763187408448,0.0490662939846515,0.2698884790954958,43793 -10129.031205415726,2.776209592819214,18265.25046825409,56999,0,18265.25046825409,0.9856982827186584,0.05246062949299812,0.2558035961147663,43793,28398.62367272377,0.994670569896698,0.01724269986152649,0.6762894919081208,0.9865738749504089,0.04920833185315132,0.2711901601753668,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index cad639751..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,656 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.305959,0.7359338,,,,,,,,,,,,,,,,, -1,,,0.5290504693984985,0.7363465428352356,0.0208464848778457,0.5270814895629883,0.737440824508667,0.024032072784418,43793.0,0.5256852507591248,0.7376683354377747,0.0260269185740518,43793.0,18.12672519683838,323.5901656150818,18.12672519683838,305.463401556015,0.0,0.0 -100,0.602501,0.4509123,,,,,,,,,,,,,,,,, -200,0.36551422,0.33057827,,,,,,,,,,,,,,,,, -300,0.2652188,0.2370993,,,,,,,,,,,,,,,,, -400,0.1740951,0.16176589,,,,,,,,,,,,,,,,, -500,0.106396325,0.11636761,,,,,,,,,,,,,,,,, -600,0.06633883,0.088116735,,,,,,,,,,,,,,,,, -700,0.043617,0.07043259,,,,,,,,,,,,,,,,, -738,,,0.986726939678192,0.0696189478039741,0.0375416835773601,0.9841179251670836,0.0782666727900505,0.039824265408219,43793.0,0.983142077922821,0.0811450853943824,0.041607852548778,43793.0,258.0806932449341,683.0615284442902,258.0806932449341,424.9269685745239,0.0321342945098876,0.0 -800,0.07282408,0.068949625,,,,,,,,,,,,,,,,, -900,0.35891122,0.05988559,,,,,,,,,,,,,,,,, -1000,0.18956766,0.05538078,,,,,,,,,,,,,,,,, -1100,0.2274813,0.050127164,,,,,,,,,,,,,,,,, -1200,0.17472644,0.052812506,,,,,,,,,,,,,,,,, -1300,0.1676615,0.047847264,,,,,,,,,,,,,,,,, -1400,0.11826427,0.052951373,,,,,,,,,,,,,,,,, -1472,,,0.9868612289428712,0.0518169067800045,0.0714212744243948,0.9841642379760742,0.0616440325975418,0.0751438053838872,43793.0,0.9831724166870116,0.0650959089398384,0.074120522459401,43793.0,498.13424348831177,1044.1351709365845,498.13424348831177,545.8983333110809,0.0606129169464111,0.0 -1500,0.092025265,0.05878996,,,,,,,,,,,,,,,,, -1600,0.23107824,0.050965782,,,,,,,,,,,,,,,,, -1700,0.14124836,0.049557902,,,,,,,,,,,,,,,,, -1800,0.055839233,0.051149003,,,,,,,,,,,,,,,,, -1900,0.069842964,0.05170739,,,,,,,,,,,,,,,,, -2000,0.14490569,0.047500458,,,,,,,,,,,,,,,,, -2100,0.17302439,0.0476358,,,,,,,,,,,,,,,,, -2200,0.11109031,0.047356855,,,,,,,,,,,,,,,,, -2221,,,0.9878098964691162,0.043958980590105,0.136974567316324,0.9851465821266174,0.0528057515621185,0.1295014808454525,43793.0,0.9841352701187134,0.0557597503066062,0.1304286933554119,43793.0,738.378705739975,1407.2695829868317,738.378705739975,668.7412447929382,0.0881714820861816,0.0 -2300,0.08574845,0.043842334,,,,,,,,,,,,,,,,, -2400,0.07295684,0.04717003,,,,,,,,,,,,,,,,, -2500,0.17090803,0.04806999,,,,,,,,,,,,,,,,, -2600,0.09688294,0.04246784,,,,,,,,,,,,,,,,, -2700,0.1143397,0.047789153,,,,,,,,,,,,,,,,, -2800,0.10518013,0.04677713,,,,,,,,,,,,,,,,, -2900,0.1010825,0.04361195,,,,,,,,,,,,,,,,, -2968,,,0.9880942702293396,0.0420986711978912,0.1608620839752452,0.985255777835846,0.0515386573970317,0.1466299149649001,43793.0,0.9843475222587584,0.0542657747864723,0.1536343785086584,43793.0,978.384375333786,1769.2708656787872,978.384375333786,790.6891014575958,0.1168539524078369,0.0 -3000,0.061159454,0.04121742,,,,,,,,,,,,,,,,, -3100,0.07749809,0.046617355,,,,,,,,,,,,,,,,, -3200,0.06094758,0.040623214,,,,,,,,,,,,,,,,, -3300,0.073697835,0.043831185,,,,,,,,,,,,,,,,, -3400,0.107740805,0.039825737,,,,,,,,,,,,,,,,, -3500,0.08223354,0.0466459,,,,,,,,,,,,,,,,, -3600,0.11345597,0.04417498,,,,,,,,,,,,,,,,, -3700,0.10969038,0.04126257,,,,,,,,,,,,,,,,, -3714,,,0.988378643989563,0.0402594655752182,0.1823391436100207,0.9854835271835328,0.0498136468231678,0.1700115195174815,43793.0,0.9845543503761292,0.0525040216743946,0.1697330694428799,43793.0,1218.5362486839294,2135.0145077705383,1218.5362486839294,916.2338988780976,0.1446871757507324,0.0 -3800,0.05346959,0.040660724,,,,,,,,,,,,,,,,, -3900,0.07729036,0.043436106,,,,,,,,,,,,,,,,, -4000,0.06037007,0.043977648,,,,,,,,,,,,,,,,, -4100,0.059038922,0.039318323,,,,,,,,,,,,,,,,, -4200,0.11085991,0.042250667,,,,,,,,,,,,,,,,, -4300,0.04774929,0.042160224,,,,,,,,,,,,,,,,, -4400,0.04909496,0.042228755,,,,,,,,,,,,,,,,, -4458,,,0.9884053468704224,0.0399228520691394,0.2100107487256066,0.9856386184692384,0.0493864566087722,0.1833988949672098,43793.0,0.9847325086593628,0.0519967675209045,0.1873038039271249,43793.0,1458.7867727279663,2505.007899045944,1458.7867727279663,1045.9301965236664,0.1723625659942627,0.0 -4500,0.034369703,0.03727929,,,,,,,,,,,,,,,,, -4600,0.045355413,0.042503078,,,,,,,,,,,,,,,,, -4700,0.10628274,0.04949545,,,,,,,,,,,,,,,,, -4800,0.049862828,0.037762444,,,,,,,,,,,,,,,,, -4900,0.046844788,0.0422274,,,,,,,,,,,,,,,,, -5000,0.025696458,0.038919896,,,,,,,,,,,,,,,,, -5100,0.056826737,0.039586786,,,,,,,,,,,,,,,,, -5192,,,0.9888152480125428,0.0387150831520557,0.2292613295651795,0.9857624173164368,0.048279769718647,0.1916392368220283,43793.0,0.9848563075065612,0.0509334243834018,0.1923857107825827,43793.0,1698.7766785621643,2875.6716067790985,1698.7766785621643,1176.5565106868744,0.1991899013519287,0.0 -5200,0.03734681,0.040969424,,,,,,,,,,,,,,,,, -5300,0.06738583,0.04075547,,,,,,,,,,,,,,,,, -5400,0.06993628,0.042048205,,,,,,,,,,,,,,,,, -5500,0.07955788,0.04288098,,,,,,,,,,,,,,,,, -5600,0.061606433,0.040509824,,,,,,,,,,,,,,,,, -5700,0.0525437,0.042796314,,,,,,,,,,,,,,,,, -5800,0.039386984,0.039183684,,,,,,,,,,,,,,,,, -5900,0.039926644,0.03942458,,,,,,,,,,,,,,,,, -5929,,,0.9890048503875732,0.037666168063879,0.2492122623022253,0.9859474897384644,0.0476326122879982,0.2099731371247263,43793.0,0.9850403666496276,0.0503462068736553,0.208156916642346,43793.0,1938.92862033844,3243.315950632096,1938.92862033844,1304.0002155303955,0.228431224822998,0.0 -6000,0.04698095,0.03883712,,,,,,,,,,,,,,,,, -6100,0.06188108,0.04108072,,,,,,,,,,,,,,,,, -6200,0.05451571,0.04168801,,,,,,,,,,,,,,,,, -6300,0.03294317,0.039874174,,,,,,,,,,,,,,,,, -6400,0.06481031,0.038348474,,,,,,,,,,,,,,,,, -6500,0.033674125,0.04140754,,,,,,,,,,,,,,,,, -6600,0.120273255,0.041419543,,,,,,,,,,,,,,,,, -6671,,,0.9889360070228576,0.0377796180546283,0.2379796961105802,0.9859402179718018,0.0476204343140125,0.2081422313743727,43793.0,0.9851086139678956,0.0502110980451107,0.2054259037859834,43793.0,2179.0981862545013,3611.34085059166,2179.0981862545013,1431.8058450222015,0.2576940059661865,0.0 -6700,0.03469106,0.037372705,,,,,,,,,,,,,,,,, -6800,0.022515722,0.0377814,,,,,,,,,,,,,,,,, -6900,0.026119946,0.041093037,,,,,,,,,,,,,,,,, -7000,0.030386731,0.03912495,,,,,,,,,,,,,,,,, -7100,0.030894337,0.041010585,,,,,,,,,,,,,,,,, -7200,0.041554697,0.04050669,,,,,,,,,,,,,,,,, -7300,0.027826529,0.043813102,,,,,,,,,,,,,,,,, -7400,0.047152445,0.040902786,,,,,,,,,,,,,,,,, -7435,,,0.988992154598236,0.0374467149376869,0.2587613004191066,0.9860936403274536,0.0468455292284488,0.2130963062876768,43793.0,0.9852202534675598,0.0493000671267509,0.2102784090723167,43793.0,2419.294604063034,3981.51959657669,2419.294604063034,1561.7417376041412,0.2840974330902099,0.0 -7500,0.0222725,0.03943813,,,,,,,,,,,,,,,,, -7600,0.04427028,0.04154034,,,,,,,,,,,,,,,,, -7700,0.036491398,0.041324344,,,,,,,,,,,,,,,,, -7800,0.029278072,0.04033684,,,,,,,,,,,,,,,,, -7900,0.040302698,0.037165612,,,,,,,,,,,,,,,,, -8000,0.02070745,0.037089575,,,,,,,,,,,,,,,,, -8100,0.028944122,0.042165957,,,,,,,,,,,,,,,,, -8200,,,0.9891512393951416,0.0365913212299346,0.2610685178132049,0.9860554933547974,0.0468535237014293,0.2194025693009693,43793.0,0.9851284027099608,0.0495839193463325,0.2186114076168888,43793.0,2659.541307926178,4349.7693428993225,2659.541307926178,1689.6956989765167,0.3128774166107178,0.0 -8200,0.029113641,0.039200824,,,,,,,,,,,,,,,,, -8300,0.023435581,0.037064824,,,,,,,,,,,,,,,,, -8400,0.022859948,0.038437687,,,,,,,,,,,,,,,,, -8500,0.022414241,0.03927794,,,,,,,,,,,,,,,,, -8600,0.025479553,0.036575112,,,,,,,,,,,,,,,,, -8700,0.024763986,0.038926344,,,,,,,,,,,,,,,,, -8800,0.019348001,0.038770553,,,,,,,,,,,,,,,,, -8900,0.023141053,0.042666838,,,,,,,,,,,,,,,,, -8952,,,0.9891273975372314,0.0366468429565429,0.2791102145414673,0.98610258102417,0.0462144687771797,0.2229995826347963,43793.0,0.9851734638214112,0.0489959381520748,0.2270890372536049,43793.0,2899.5621926784515,4718.630972146988,2899.5621926784515,1818.4839661121368,0.3447468280792236,0.0 -9000,0.032405514,0.04136744,,,,,,,,,,,,,,,,, -9100,0.032768775,0.03837514,,,,,,,,,,,,,,,,, -9200,0.02595819,0.035796627,,,,,,,,,,,,,,,,, -9300,0.022824308,0.038895026,,,,,,,,,,,,,,,,, -9400,0.04191608,0.03941328,,,,,,,,,,,,,,,,, -9500,0.019603932,0.037712663,,,,,,,,,,,,,,,,, -9600,0.026229313,0.041768868,,,,,,,,,,,,,,,,, -9700,0.042415008,0.037701745,,,,,,,,,,,,,,,,, -9706,,,0.98965722322464,0.0352459549903869,0.2876529415246712,0.9862491488456726,0.0461195521056652,0.2190752261508913,43793.0,0.9853731393814088,0.0486707761883735,0.2265520511640969,43793.0,3139.516256332397,5088.528156280518,3139.516256332397,1948.37777876854,0.3719336986541748,0.0 -9800,0.030849984,0.03806647,,,,,,,,,,,,,,,,, -9900,0.035649996,0.04088165,,,,,,,,,,,,,,,,, -10000,0.023316192,0.0371916,,,,,,,,,,,,,,,,, -10100,0.023753794,0.03463267,,,,,,,,,,,,,,,,, -10200,0.021373006,0.03580541,,,,,,,,,,,,,,,,, -10300,0.024758235,0.038538456,,,,,,,,,,,,,,,,, -10400,0.029672412,0.036037564,,,,,,,,,,,,,,,,, -10468,,,0.9896705150604248,0.0348237752914428,0.3068134647972326,0.9862880706787108,0.0463917814195156,0.2411961410225304,43793.0,0.9854143857955932,0.0490829050540924,0.2329932907320756,43793.0,3379.5207164287567,5454.616222381592,3379.5207164287567,2074.412222862244,0.4013481140136719,0.0 -10500,0.032071166,0.038343064,,,,,,,,,,,,,,,,, -10600,0.022290604,0.034742024,,,,,,,,,,,,,,,,, -10700,0.03502222,0.038389347,,,,,,,,,,,,,,,,, -10800,0.026013592,0.036611572,,,,,,,,,,,,,,,,, -10900,0.045773417,0.039725445,,,,,,,,,,,,,,,,, -11000,0.024538863,0.03825865,,,,,,,,,,,,,,,,, -11100,0.028144766,0.038507327,,,,,,,,,,,,,,,,, -11200,0.02419887,0.03973857,,,,,,,,,,,,,,,,, -11226,,,0.9897632002830504,0.0342228785157203,0.3330268928440666,0.9864967465400696,0.0456630326807498,0.2484119453619031,43793.0,0.9855487942695618,0.0485002435743808,0.239486379839807,43793.0,3619.7391617298126,5822.907917261124,3619.7391617298126,2202.437658548355,0.4295799732208252,0.0 -11300,0.027082158,0.036885504,,,,,,,,,,,,,,,,, -11400,0.04088108,0.036772776,,,,,,,,,,,,,,,,, -11500,0.028293362,0.038645256,,,,,,,,,,,,,,,,, -11600,0.03631538,0.038807776,,,,,,,,,,,,,,,,, -11700,0.034311388,0.03746475,,,,,,,,,,,,,,,,, -11800,0.03338478,0.036397975,,,,,,,,,,,,,,,,, -11900,0.034031626,0.036607362,,,,,,,,,,,,,,,,, -11982,,,0.9901514053344728,0.0327607914805412,0.3584733182330887,0.9865888953208924,0.0448321178555488,0.2500146851650296,43793.0,0.9855656027793884,0.0477055162191391,0.2389553858798118,43793.0,3859.737271785736,6193.9528086185455,3859.737271785736,2333.4328026771545,0.4608142375946045,0.0 -12000,0.03167455,0.03616707,,,,,,,,,,,,,,,,, -12100,0.029071707,0.036580164,,,,,,,,,,,,,,,,, -12200,0.037451282,0.037108622,,,,,,,,,,,,,,,,, -12300,0.035550665,0.036103677,,,,,,,,,,,,,,,,, -12400,0.04593637,0.037698682,,,,,,,,,,,,,,,,, -12500,0.051345162,0.034791123,,,,,,,,,,,,,,,,, -12600,0.028449608,0.038673606,,,,,,,,,,,,,,,,, -12700,0.028771585,0.038124263,,,,,,,,,,,,,,,,, -12739,,,0.9903623461723328,0.0321791432797908,0.3770886565685112,0.9865893125534058,0.0448479317128658,0.2501933144468553,43793.0,0.985632598400116,0.0476379878818988,0.2430544044938929,43793.0,4099.829087495804,6565.651482105255,4099.829087495804,2464.989315032959,0.4908986091613769,0.0 -12800,0.03941936,0.037263095,,,,,,,,,,,,,,,,, -12900,0.03302703,0.035521317,,,,,,,,,,,,,,,,, -13000,0.030140981,0.032457512,,,,,,,,,,,,,,,,, -13100,0.04571728,0.03610181,,,,,,,,,,,,,,,,, -13200,0.035797175,0.03301134,,,,,,,,,,,,,,,,, -13300,0.030198686,0.03448622,,,,,,,,,,,,,,,,, -13400,0.031955943,0.034407653,,,,,,,,,,,,,,,,, -13493,,,0.9904111623764038,0.0319937393069267,0.3685707194392107,0.986591339111328,0.0450877733528614,0.2536282798796875,43793.0,0.985736608505249,0.0479441359639167,0.2448625925982244,43793.0,4339.98765039444,6935.369527339935,4339.98765039444,2594.4998412132263,0.5199365615844727,0.0 -13500,0.04187304,0.035362836,,,,,,,,,,,,,,,,, -13600,0.03447148,0.036549605,,,,,,,,,,,,,,,,, -13700,0.03264469,0.03538265,,,,,,,,,,,,,,,,, -13800,0.042189095,0.03282719,,,,,,,,,,,,,,,,, -13900,0.048270386,0.036475264,,,,,,,,,,,,,,,,, -14000,0.03606413,0.035669725,,,,,,,,,,,,,,,,, -14100,0.03605698,0.037683897,,,,,,,,,,,,,,,,, -14200,0.041657753,0.036075387,,,,,,,,,,,,,,,,, -14249,,,0.9904477000236512,0.0317926034331321,0.3795072526894285,0.9866713285446168,0.0447172634303569,0.2544789447205896,43793.0,0.9857930541038512,0.0473803281784057,0.2434473027563511,43793.0,4580.077013969421,7305.356061458588,4580.077013969421,2724.34804224968,0.5486664772033691,0.0 -14300,0.042247478,0.036645748,,,,,,,,,,,,,,,,, -14400,0.04851496,0.03701158,,,,,,,,,,,,,,,,, -14500,0.048395,0.03377022,,,,,,,,,,,,,,,,, -14600,0.0514204,0.038120195,,,,,,,,,,,,,,,,, -14700,0.059999123,0.03650911,,,,,,,,,,,,,,,,, -14800,0.06606277,0.032731045,,,,,,,,,,,,,,,,, -14900,0.045139924,0.034757074,,,,,,,,,,,,,,,,, -15000,0.06140277,0.03535928,,,,,,,,,,,,,,,,, -15004,,,0.9902809858322144,0.0322612114250659,0.3739686437256189,0.9866579174995422,0.0448336601257324,0.2554795498068286,43793.0,0.9858074188232422,0.0476044304668903,0.2454498017573989,43793.0,4820.203650474548,7676.213287353516,4820.203650474548,2855.0291335582733,0.5781416893005371,0.0 -15100,0.045024134,0.033520266,,,,,,,,,,,,,,,,, -15200,0.043407753,0.035907965,,,,,,,,,,,,,,,,, -15300,0.04433698,0.03567745,,,,,,,,,,,,,,,,, -15400,0.045795575,0.03438732,,,,,,,,,,,,,,,,, -15500,0.042788778,0.03602673,,,,,,,,,,,,,,,,, -15600,0.059710998,0.03641318,,,,,,,,,,,,,,,,, -15700,0.047633346,0.03523736,,,,,,,,,,,,,,,,, -15757,,,0.9904695749282836,0.0315851792693138,0.3780887865883426,0.9866855144500732,0.044472336769104,0.2637727440673564,43793.0,0.98580402135849,0.0471926145255565,0.2571249985640839,43793.0,5060.340332508087,8052.993452072144,5060.340332508087,2991.6243121624,0.6067461967468262,0.0 -15800,0.05367186,0.03601341,,,,,,,,,,,,,,,,, -15900,0.0458294,0.037073344,,,,,,,,,,,,,,,,, -16000,0.04472853,0.033853278,,,,,,,,,,,,,,,,, -16100,0.03829414,0.03446303,,,,,,,,,,,,,,,,, -16200,0.06795414,0.03574186,,,,,,,,,,,,,,,,, -16300,0.062385812,0.037334222,,,,,,,,,,,,,,,,, -16400,0.050453175,0.035912596,,,,,,,,,,,,,,,,, -16500,0.077204816,0.038043067,,,,,,,,,,,,,,,,, -16510,,,0.9904604554176332,0.0316490679979324,0.3682752869847712,0.9867634773254396,0.0443042665719985,0.2625507900616484,43793.0,0.985867202281952,0.047187402844429,0.2516207197326865,43793.0,5300.3963787555695,8427.862237930298,5300.3963787555695,3126.3872005939484,0.6358902454376221,0.0 -16600,0.05411357,0.037633546,,,,,,,,,,,,,,,,, -16700,0.055889875,0.03712166,,,,,,,,,,,,,,,,, -16800,0.07847564,0.034967463,,,,,,,,,,,,,,,,, -16900,0.06015618,0.033639107,,,,,,,,,,,,,,,,, -17000,0.0607914,0.036332667,,,,,,,,,,,,,,,,, -17100,0.07865646,0.034770627,,,,,,,,,,,,,,,,, -17200,0.05234165,0.03516688,,,,,,,,,,,,,,,,, -17255,,,0.9904911518096924,0.0313327424228191,0.3850487983348396,0.9867411255836488,0.044379997998476,0.2592182022289485,43793.0,0.9859405159950256,0.04707632958889,0.2522397872860027,43793.0,5540.545610666275,8795.66866350174,5540.545610666275,3253.9894936084747,0.6687760353088379,0.0 -17300,0.047234293,0.033653658,,,,,,,,,,,,,,,,, -17400,0.048665017,0.0335407,,,,,,,,,,,,,,,,, -17500,0.058452334,0.0359513,,,,,,,,,,,,,,,,, -17600,0.082966864,0.03153133,,,,,,,,,,,,,,,,, -17700,0.044014506,0.032319628,,,,,,,,,,,,,,,,, -17800,0.05067025,0.034008745,,,,,,,,,,,,,,,,, -17900,0.059663326,0.0344277,,,,,,,,,,,,,,,,, -18000,0.054860517,0.033330932,,,,,,,,,,,,,,,,, -18009,,,0.9906929135322572,0.0308549869805574,0.3971799953290061,0.9866570830345154,0.0449607670307159,0.2552737329943473,43793.0,0.9858688712120056,0.0477351546287536,0.2433850683436001,43793.0,5780.521510839462,9169.35992527008,5780.521510839462,3387.650359153748,0.7024412155151367,0.0 -18100,0.059314627,0.03905811,,,,,,,,,,,,,,,,, -18200,0.051806428,0.032617897,,,,,,,,,,,,,,,,, -18300,0.0725428,0.035138514,,,,,,,,,,,,,,,,, -18400,0.055210743,0.03565929,,,,,,,,,,,,,,,,, -18500,0.067606285,0.034719486,,,,,,,,,,,,,,,,, -18600,0.0641496,0.0348411,,,,,,,,,,,,,,,,, -18700,0.069647275,0.036575433,,,,,,,,,,,,,,,,, -18759,,,0.990839421749115,0.030348252505064,0.4093167248324422,0.9866721034049988,0.0445089861750602,0.2656968890905586,43793.0,0.985846996307373,0.0472518391907215,0.2473320580292265,43793.0,6020.512645244598,9540.521861076357,6020.512645244598,3518.769593477249,0.7336812019348145,0.0 -18800,0.06225705,0.031548828,,,,,,,,,,,,,,,,, -18900,0.049611233,0.035366017,,,,,,,,,,,,,,,,, -19000,0.0624401,0.03695222,,,,,,,,,,,,,,,,, -19100,0.057412557,0.030174596,,,,,,,,,,,,,,,,, -19200,0.05945721,0.035395917,,,,,,,,,,,,,,,,, -19300,0.05949602,0.033453282,,,,,,,,,,,,,,,,, -19400,0.09476891,0.034088925,,,,,,,,,,,,,,,,, -19500,0.058327865,0.035078663,,,,,,,,,,,,,,,,, -19514,,,0.9907936453819276,0.0300093758851289,0.4283451389177612,0.9867143034934998,0.0449393689632415,0.2623142473408286,43793.0,0.9858187437057496,0.0476085133850574,0.2507867971943465,43793.0,6260.5007147789,9909.286611318588,6260.5007147789,3647.493248462677,0.7661452293395996,0.0 -19600,0.061866753,0.035637975,,,,,,,,,,,,,,,,, -19700,0.07498042,0.03722527,,,,,,,,,,,,,,,,, -19800,0.05758464,0.033531476,,,,,,,,,,,,,,,,, -19900,0.06107921,0.0341629,,,,,,,,,,,,,,,,, -20000,0.06669183,0.03555912,,,,,,,,,,,,,,,,, -20100,0.08147286,0.035734676,,,,,,,,,,,,,,,,, -20200,0.064154886,0.037195448,,,,,,,,,,,,,,,,, -20272,,,0.991016924381256,0.0292759668081998,0.4346414677193345,0.9868003726005554,0.0445684269070625,0.2706530486331538,43793.0,0.9859198331832886,0.0473172441124916,0.2484191606510649,43793.0,6500.6417491436005,10280.78795480728,6500.6417491436005,3778.8022241592407,0.7969620227813721,0.0 -20300,0.08053258,0.033431996,,,,,,,,,,,,,,,,, -20400,0.066035435,0.032787275,,,,,,,,,,,,,,,,, -20500,0.07257827,0.035928555,,,,,,,,,,,,,,,,, -20600,0.07943596,0.03653247,,,,,,,,,,,,,,,,, -20700,0.11656121,0.036159795,,,,,,,,,,,,,,,,, -20800,0.06761183,0.035238516,,,,,,,,,,,,,,,,, -20900,0.060840655,0.03182381,,,,,,,,,,,,,,,,, -21000,0.07069084,0.04140362,,,,,,,,,,,,,,,,, -21028,,,0.9911913275718688,0.029066402465105,0.4319467944312771,0.9865812063217164,0.0446751601994037,0.2618761943758421,43793.0,0.9858166575431824,0.0471398457884788,0.2538452476150781,43793.0,6740.721883535385,10648.634120225906,6740.721883535385,3906.517117500305,0.8276827335357666,0.0 -21100,0.085142165,0.03616114,,,,,,,,,,,,,,,,, -21200,0.09749503,0.033624645,,,,,,,,,,,,,,,,, -21300,0.07436079,0.033896785,,,,,,,,,,,,,,,,, -21400,0.06420853,0.033166282,,,,,,,,,,,,,,,,, -21500,0.06509911,0.032903597,,,,,,,,,,,,,,,,, -21600,0.06300718,0.033385433,,,,,,,,,,,,,,,,, -21700,0.08612926,0.033915423,,,,,,,,,,,,,,,,, -21774,,,0.991000235080719,0.0297731887549161,0.4357048122346068,0.9867650866508484,0.0443596877157688,0.2604349182296928,43793.0,0.9858962893486024,0.0469451062381267,0.2553447340116264,43793.0,6980.779272079468,11021.275067090988,6980.779272079468,4039.047303438186,0.8594973087310791,0.0 -21800,0.101466954,0.034171555,,,,,,,,,,,,,,,,, -21900,0.080543265,0.033910304,,,,,,,,,,,,,,,,, -22000,0.07990755,0.032324024,,,,,,,,,,,,,,,,, -22100,0.08575789,0.035338398,,,,,,,,,,,,,,,,, -22200,0.065296985,0.03398075,,,,,,,,,,,,,,,,, -22300,0.077160604,0.031156853,,,,,,,,,,,,,,,,, -22400,0.058521003,0.032760017,,,,,,,,,,,,,,,,, -22500,0.066027634,0.03630116,,,,,,,,,,,,,,,,, -22527,,,0.990811824798584,0.0301263500005006,0.3952604718650014,0.9867122769355774,0.0446335412561893,0.2590375873548066,43793.0,0.985925316810608,0.0474839806556701,0.2561323133732616,43793.0,7220.825093746185,11390.0852560997,7220.825093746185,4167.759459018707,0.8915479183197021,0.0 -22600,0.12188498,0.036526926,,,,,,,,,,,,,,,,, -22700,0.07724035,0.03181358,,,,,,,,,,,,,,,,, -22800,0.06851863,0.031569462,,,,,,,,,,,,,,,,, -22900,0.08025553,0.03155223,,,,,,,,,,,,,,,,, -23000,0.08505307,0.037112493,,,,,,,,,,,,,,,,, -23100,0.080062814,0.032717865,,,,,,,,,,,,,,,,, -23200,0.06920158,0.035921585,,,,,,,,,,,,,,,,, -23280,,,0.9908799529075624,0.0300232395529747,0.4227056232213048,0.9867395162582396,0.0447255373001098,0.2640613190295243,43793.0,0.985958993434906,0.0472457595169544,0.2608490687676737,43793.0,7460.819598436356,11762.319630622864,7460.819598436356,4299.94361448288,0.9270291328430176,0.0 -23300,0.09720096,0.03739285,,,,,,,,,,,,,,,,, -23400,0.05965735,0.032447882,,,,,,,,,,,,,,,,, -23500,0.06984687,0.033221997,,,,,,,,,,,,,,,,, -23600,0.09588093,0.033199187,,,,,,,,,,,,,,,,, -23700,0.09247708,0.033606507,,,,,,,,,,,,,,,,, -23800,0.06975483,0.03142772,,,,,,,,,,,,,,,,, -23900,0.06793735,0.033144414,,,,,,,,,,,,,,,,, -24000,0.07584876,0.03804818,,,,,,,,,,,,,,,,, -24031,,,0.9909188747406006,0.0297171231359243,0.4331208281078121,0.9867057800292968,0.0446142293512821,0.2611251490209201,43793.0,0.9859577417373656,0.04722660779953,0.2516795713976436,43793.0,7700.661962509155,12135.146509170532,7700.661962509155,4432.592004299164,1.2428011894226074,0.0 -24100,0.065889694,0.03534159,,,,,,,,,,,,,,,,, -24200,0.07774942,0.034092095,,,,,,,,,,,,,,,,, -24300,0.09835467,0.034291826,,,,,,,,,,,,,,,,, -24400,0.09736616,0.033846673,,,,,,,,,,,,,,,,, -24500,0.09782401,0.035784226,,,,,,,,,,,,,,,,, -24600,0.09678921,0.035611805,,,,,,,,,,,,,,,,, -24700,0.08783776,0.033997577,,,,,,,,,,,,,,,,, -24785,,,0.9908357858657836,0.0298755336552858,0.4064477900468354,0.9866595268249512,0.0450131893157959,0.2573716657473429,43793.0,0.9858036041259766,0.0478490367531776,0.2503198834780422,43793.0,7940.79802775383,12508.866287469864,7940.79802775383,4566.1248388290405,1.2736265659332275,0.0 -24800,0.06618781,0.03249248,,,,,,,,,,,,,,,,, -24900,0.07399307,0.03319317,,,,,,,,,,,,,,,,, -25000,0.09597904,0.031050725,,,,,,,,,,,,,,,,, -25100,0.07271909,0.033215385,,,,,,,,,,,,,,,,, -25200,0.07585993,0.033128135,,,,,,,,,,,,,,,,, -25300,0.0859686,0.035552185,,,,,,,,,,,,,,,,, -25400,0.07099074,0.036633592,,,,,,,,,,,,,,,,, -25500,0.07154639,0.03078394,,,,,,,,,,,,,,,,, -25537,,,0.9911275506019592,0.0291652176529169,0.4410087981079901,0.9867557287216188,0.0444139316678047,0.271666663964825,43793.0,0.9858790040016174,0.0471441224217414,0.2556667009203458,43793.0,8180.888298749924,12884.822046756744,8180.888298749924,4701.938496828079,1.304905891418457,0.0 -25600,0.09931754,0.0335399,,,,,,,,,,,,,,,,, -25700,0.06282259,0.031706765,,,,,,,,,,,,,,,,, -25800,0.06677589,0.03326147,,,,,,,,,,,,,,,,, -25900,0.079205684,0.031757325,,,,,,,,,,,,,,,,, -26000,0.082264595,0.033507794,,,,,,,,,,,,,,,,, -26100,0.10925739,0.03416699,,,,,,,,,,,,,,,,, -26200,0.07642108,0.0335184,,,,,,,,,,,,,,,,, -26288,,,0.9913343191146852,0.0283323358744382,0.4558309209604774,0.986750066280365,0.0446289516985416,0.2663130309218301,43793.0,0.985999882221222,0.0471678562462329,0.2552766763990095,43793.0,8421.127045869827,13258.230972528458,8421.127045869827,4835.052654981613,1.3399152755737305,0.0 -26300,0.093353435,0.035641156,,,,,,,,,,,,,,,,, -26400,0.07047642,0.03108424,,,,,,,,,,,,,,,,, -26500,0.06919893,0.03319655,,,,,,,,,,,,,,,,, -26600,0.09663079,0.037824795,,,,,,,,,,,,,,,,, -26700,0.076522104,0.03279491,,,,,,,,,,,,,,,,, -26800,0.07863149,0.030536735,,,,,,,,,,,,,,,,, -26900,0.0948133,0.035776634,,,,,,,,,,,,,,,,, -27000,0.13303286,0.033238053,,,,,,,,,,,,,,,,, -27044,,,0.9913906455039978,0.0281543508172035,0.4644263463446575,0.9868559837341307,0.0448550656437873,0.2704016450281102,43793.0,0.9859737753868104,0.0477073155343532,0.2546293091749866,43793.0,8661.194565296173,13631.669520616531,8661.194565296173,4968.371803283691,1.371005296707153,0.0 -27100,0.08728693,0.030906023,,,,,,,,,,,,,,,,, -27200,0.09667586,0.036986306,,,,,,,,,,,,,,,,, -27300,0.086327314,0.033738453,,,,,,,,,,,,,,,,, -27400,0.09385709,0.031869568,,,,,,,,,,,,,,,,, -27500,0.06766407,0.03014037,,,,,,,,,,,,,,,,, -27600,0.077099904,0.032086767,,,,,,,,,,,,,,,,, -27700,0.066485114,0.032791913,,,,,,,,,,,,,,,,, -27800,0.07762378,0.031893246,,,,,,,,,,,,,,,,, -27801,,,0.9913938045501708,0.0278185661882162,0.476967536096138,0.986710250377655,0.0451094470918178,0.2647917434751276,43793.0,0.9859451055526732,0.0478653758764267,0.2571030521372473,43793.0,8901.284093856812,14002.555571317673,8901.284093856812,5099.115159749985,1.403883695602417,0.0 -27900,0.089216426,0.032601878,,,,,,,,,,,,,,,,, -28000,0.08217676,0.033630524,,,,,,,,,,,,,,,,, -28100,0.075389855,0.03340178,,,,,,,,,,,,,,,,, -28200,0.06878594,0.03050423,,,,,,,,,,,,,,,,, -28300,0.068508804,0.031648222,,,,,,,,,,,,,,,,, -28400,0.07909197,0.032359906,,,,,,,,,,,,,,,,, -28500,0.07125651,0.032210648,,,,,,,,,,,,,,,,, -28546,,,0.9915661811828612,0.0271521862596273,0.48831206860766,0.986932337284088,0.0444307848811149,0.2773346118327409,43793.0,0.9859514236450196,0.0475336760282516,0.2583526410611838,43793.0,9141.320052146912,14374.145962715147,9141.320052146912,5230.611914157867,1.4391045570373535,0.0 -28600,0.081869304,0.030577434,,,,,,,,,,,,,,,,, -28700,0.08846496,0.034057364,,,,,,,,,,,,,,,,, -28800,0.07249788,0.03126474,,,,,,,,,,,,,,,,, -28900,0.08650623,0.029909132,,,,,,,,,,,,,,,,, -29000,0.08902541,0.033216346,,,,,,,,,,,,,,,,, -29100,0.08096798,0.031315245,,,,,,,,,,,,,,,,, -29200,0.0755614,0.032140892,,,,,,,,,,,,,,,,, -29299,,,0.9914606809616088,0.0279100723564624,0.4554978870592862,0.9867382645606996,0.044542621821165,0.2711044475252064,43793.0,0.9858566522598268,0.047407079488039,0.2526362549918176,43793.0,9381.337366342545,14749.433589935305,9381.337366342545,5365.830714225769,1.4701387882232666,0.0 -29300,0.077233665,0.030284464,,,,,,,,,,,,,,,,, -29400,0.06830639,0.032059472,,,,,,,,,,,,,,,,, -29500,0.06667482,0.031052405,,,,,,,,,,,,,,,,, -29600,0.07932112,0.0321009,,,,,,,,,,,,,,,,, -29700,0.08212544,0.031788807,,,,,,,,,,,,,,,,, -29800,0.10591904,0.035991363,,,,,,,,,,,,,,,,, -29900,0.07639844,0.033325553,,,,,,,,,,,,,,,,, -30000,0.097181775,0.034760825,,,,,,,,,,,,,,,,, -30045,,,0.9913190603256226,0.0283074267208576,0.4533270602357647,0.986749231815338,0.044565699994564,0.2735216258339702,43793.0,0.9859737753868104,0.0471785925328731,0.2610579508779325,43793.0,9621.466262102129,15122.514045715332,9621.466262102129,5498.725798130035,1.5060889720916748,0.0 -30100,0.09051034,0.032278217,,,,,,,,,,,,,,,,, -30200,0.0776867,0.031425633,,,,,,,,,,,,,,,,, -30300,0.07282895,0.03203665,,,,,,,,,,,,,,,,, -30400,0.088811554,0.031639222,,,,,,,,,,,,,,,,, -30500,0.080187716,0.032292772,,,,,,,,,,,,,,,,, -30600,0.08184625,0.03242679,,,,,,,,,,,,,,,,, -30700,0.060041495,0.029972004,,,,,,,,,,,,,,,,, -30794,,,0.9913336038589478,0.0282726194709539,0.4520790403454369,0.986777663230896,0.0448727048933506,0.2668618929323761,43793.0,0.9858756065368652,0.0477723777294158,0.2537742032421881,43793.0,9861.666239261627,15493.1018307209,9861.666239261627,5629.058895349503,1.540019989013672,0.0 -30800,0.0971825,0.03279804,,,,,,,,,,,,,,,,, -30900,0.07730635,0.032333497,,,,,,,,,,,,,,,,, -31000,0.10874148,0.032622162,,,,,,,,,,,,,,,,, -31100,0.08671239,0.031848926,,,,,,,,,,,,,,,,, -31200,0.08588086,0.033181034,,,,,,,,,,,,,,,,, -31300,0.09070739,0.032230083,,,,,,,,,,,,,,,,, -31400,0.07821236,0.0324996,,,,,,,,,,,,,,,,, -31500,0.08282912,0.029434051,,,,,,,,,,,,,,,,, -31550,,,0.9914684891700744,0.0279049314558506,0.46250529010033,0.986751675605774,0.0447242632508277,0.2758591528506365,43793.0,0.9859691262245178,0.0474783778190612,0.263347388682702,43793.0,10101.732964754105,15869.459534406662,10101.732964754105,5765.29677939415,1.5725555419921875,0.0 -31600,0.083660446,0.03382211,,,,,,,,,,,,,,,,, -31700,0.07404836,0.028612645,,,,,,,,,,,,,,,,, -31800,0.091363244,0.034848236,,,,,,,,,,,,,,,,, -31900,0.08704276,0.034580395,,,,,,,,,,,,,,,,, -32000,0.09569921,0.030484298,,,,,,,,,,,,,,,,, -32100,0.12740268,0.03128179,,,,,,,,,,,,,,,,, -32200,0.083113566,0.03321716,,,,,,,,,,,,,,,,, -32282,,,0.9915327429771424,0.0276380106806755,0.4701164224800756,0.9868023991584778,0.0441828817129135,0.2724851975870768,43793.0,0.985917329788208,0.0468419790267944,0.2613756747206935,43793.0,10341.703769683838,16240.596867084503,10341.703769683838,5896.408312559128,1.604877471923828,0.0 -32300,0.08616044,0.034834843,,,,,,,,,,,,,,,,, -32400,0.09998764,0.03583501,,,,,,,,,,,,,,,,, -32500,0.10223625,0.030774722,,,,,,,,,,,,,,,,, -32600,0.094468236,0.034576695,,,,,,,,,,,,,,,,, -32700,0.09756434,0.032847103,,,,,,,,,,,,,,,,, -32800,0.08273942,0.03283232,,,,,,,,,,,,,,,,, -32900,0.09558079,0.033435963,,,,,,,,,,,,,,,,, -33000,0.07881524,0.030610055,,,,,,,,,,,,,,,,, -33014,,,0.9916948676109314,0.026967754587531,0.4762917360420234,0.9868279695510864,0.0444532893598079,0.2725076791228633,43793.0,0.9860011339187622,0.0471965111792087,0.2568954076632428,43793.0,10581.765675783156,16615.352819919586,10581.765675783156,6031.045778036118,1.6387574672698977,0.0 -33100,0.0992214,0.031930227,,,,,,,,,,,,,,,,, -33200,0.07501914,0.03373052,,,,,,,,,,,,,,,,, -33300,0.1101242,0.033568475,,,,,,,,,,,,,,,,, -33400,0.087716356,0.03021945,,,,,,,,,,,,,,,,, -33500,0.080268145,0.028029725,,,,,,,,,,,,,,,,, -33600,0.07896368,0.032055683,,,,,,,,,,,,,,,,, -33700,0.07911641,0.03142926,,,,,,,,,,,,,,,,, -33766,,,0.9919641613960266,0.0263283308595418,0.4970465806059808,0.9867817163467408,0.0448424257338047,0.2703013692188287,43793.0,0.9859564900398254,0.0474421940743923,0.2586274867507777,43793.0,10821.961535930634,16985.606071949005,10821.961535930634,6161.0498831272125,1.6721007823944092,0.0 -33800,0.09849553,0.03357244,,,,,,,,,,,,,,,,, -33900,0.08582092,0.03177789,,,,,,,,,,,,,,,,, -34000,0.1189628,0.03546065,,,,,,,,,,,,,,,,, -34100,0.097083986,0.030267468,,,,,,,,,,,,,,,,, -34200,0.09459457,0.03307054,,,,,,,,,,,,,,,,, -34300,0.11988617,0.030158577,,,,,,,,,,,,,,,,, -34400,0.07699112,0.027701383,,,,,,,,,,,,,,,,, -34500,0.07961607,0.03232669,,,,,,,,,,,,,,,,, -34522,,,0.9921725392341614,0.0255570504814386,0.527538363625475,0.986916482448578,0.0447726920247077,0.2722659590388679,43793.0,0.9859733581542968,0.0475214347243309,0.2607594029126369,43793.0,11062.15025663376,17354.303694963455,11062.15025663376,6289.505045890808,1.7057373523712158,0.0 -34600,0.07553768,0.030488333,,,,,,,,,,,,,,,,, -34700,0.098253295,0.0322974,,,,,,,,,,,,,,,,, -34800,0.08636513,0.031216936,,,,,,,,,,,,,,,,, -34900,0.09298869,0.03166018,,,,,,,,,,,,,,,,, -35000,0.0851985,0.03106649,,,,,,,,,,,,,,,,, -35100,0.09423361,0.031683676,,,,,,,,,,,,,,,,, -35200,0.086095676,0.029184127,,,,,,,,,,,,,,,,, -35277,,,0.9921831488609314,0.0251275803893804,0.532644755953247,0.9869319200515748,0.0451381579041481,0.2747726059200759,43793.0,0.9859846830368042,0.0482017584145069,0.2530558781775611,43793.0,11302.103286266329,17724.31993317604,11302.103286266329,6419.514434099197,1.7393901348114014,0.0 -35300,0.10521321,0.031915616,,,,,,,,,,,,,,,,, -35400,0.091741845,0.03147219,,,,,,,,,,,,,,,,, -35500,0.08209981,0.03379824,,,,,,,,,,,,,,,,, -35600,0.14511938,0.030507669,,,,,,,,,,,,,,,,, -35700,0.09027137,0.0321843,,,,,,,,,,,,,,,,, -35800,0.11268397,0.030164743,,,,,,,,,,,,,,,,, -35900,0.12222001,0.033035185,,,,,,,,,,,,,,,,, -36000,0.09077768,0.030236382,,,,,,,,,,,,,,,,, -36035,,,0.9919872879981996,0.0259499680250883,0.5178907389430005,0.9868706464767456,0.0447487831115722,0.2739533269282222,43793.0,0.9860104322433472,0.047689463943243,0.2610775836816156,43793.0,11542.11351633072,18092.01842617989,11542.11351633072,6547.149572849274,1.7725646495819092,0.0 -36100,0.09534647,0.03244661,,,,,,,,,,,,,,,,, -36200,0.089885816,0.031757135,,,,,,,,,,,,,,,,, -36300,0.11166523,0.03154682,,,,,,,,,,,,,,,,, -36400,0.08876623,0.029843194,,,,,,,,,,,,,,,,, -36500,0.097908355,0.034002654,,,,,,,,,,,,,,,,, -36600,0.08508101,0.0313352,,,,,,,,,,,,,,,,, -36700,0.13053584,0.030704482,,,,,,,,,,,,,,,,, -36789,,,0.9918658137321472,0.026509465649724,0.4845961109288035,0.9867196083068848,0.0451682284474372,0.2734755096168028,43793.0,0.985889494419098,0.0480927638709545,0.2523359419552278,43793.0,11782.1854326725,18462.86779975891,11782.1854326725,6677.873930931091,1.8053655624389648,0.0 -36800,0.09597087,0.02946942,,,,,,,,,,,,,,,,, -36900,0.094874986,0.031761125,,,,,,,,,,,,,,,,, -37000,0.098188885,0.03172635,,,,,,,,,,,,,,,,, -37100,0.08129565,0.027280007,,,,,,,,,,,,,,,,, -37200,0.11350527,0.03112822,,,,,,,,,,,,,,,,, -37300,0.17868282,0.035565093,,,,,,,,,,,,,,,,, -37400,0.095645726,0.031256665,,,,,,,,,,,,,,,,, -37500,0.10149633,0.030905709,,,,,,,,,,,,,,,,, -37526,,,0.9918019771575928,0.0265244841575622,0.4839325470592795,0.9868669509887696,0.0447966158390045,0.2795964258277546,43793.0,0.9860280752182008,0.0478722341358661,0.2616498358073807,43793.0,12022.233164787292,18831.325871944427,12022.233164787292,6806.229717254639,1.8392269611358645,0.0 -37600,0.11184104,0.032754853,,,,,,,,,,,,,,,,, -37700,0.1099479,0.030090082,,,,,,,,,,,,,,,,, -37800,0.11073539,0.031510457,,,,,,,,,,,,,,,,, -37900,0.0843509,0.02888715,,,,,,,,,,,,,,,,, -38000,0.10683715,0.031751107,,,,,,,,,,,,,,,,, -38100,0.09749458,0.031167293,,,,,,,,,,,,,,,,, -38200,0.10230914,0.03151045,,,,,,,,,,,,,,,,, -38275,,,0.9918152093887328,0.0264482665807008,0.4996172600643373,0.9869120121002196,0.0453627854585647,0.2735382884365154,43793.0,0.9859615564346312,0.0483447611331939,0.255402046353583,43793.0,12262.27923464775,19201.54480075836,12262.27923464775,6936.348298549652,1.873117446899414,0.0 -38300,0.08859125,0.03011567,,,,,,,,,,,,,,,,, -38400,0.10658007,0.029562434,,,,,,,,,,,,,,,,, -38500,0.104536936,0.027040161,,,,,,,,,,,,,,,,, -38600,0.08781239,0.03270317,,,,,,,,,,,,,,,,, -38700,0.08335384,0.027177181,,,,,,,,,,,,,,,,, -38800,0.0934161,0.027865365,,,,,,,,,,,,,,,,, -38900,0.09657476,0.02987542,,,,,,,,,,,,,,,,, -39000,0.07438104,0.028340723,,,,,,,,,,,,,,,,, -39022,,,0.9920036792755128,0.025848040357232,0.501135744574851,0.9868023991584778,0.0449364073574543,0.2725329770589841,43793.0,0.98598051071167,0.0478197634220123,0.2580725558252392,43793.0,12502.245649576187,19568.8684835434,12502.245649576187,7063.6506524086,1.907481670379639,0.0 -39100,0.088041715,0.028951049,,,,,,,,,,,,,,,,, -39200,0.09282489,0.03177314,,,,,,,,,,,,,,,,, -39300,0.10227943,0.03253094,,,,,,,,,,,,,,,,, -39400,0.11410766,0.029482214,,,,,,,,,,,,,,,,, -39500,0.0944694,0.02922974,,,,,,,,,,,,,,,,, -39600,0.095833085,0.03055839,,,,,,,,,,,,,,,,, -39700,0.14000654,0.028950453,,,,,,,,,,,,,,,,, -39763,,,0.9920294880867004,0.0255630780011415,0.5300515347756143,0.9869067668914796,0.0451199188828468,0.2812240884105178,43793.0,0.9859746098518372,0.0481580346822738,0.2584333189610739,43793.0,12742.505630731584,19940.97686815262,12742.505630731584,7195.441838502884,1.942678689956665,0.0 -39800,0.09369291,0.030758547,,,,,,,,,,,,,,,,, -39900,0.08589245,0.030168094,,,,,,,,,,,,,,,,, -40000,0.121090256,0.030997714,,,,,,,,,,,,,,,,, -40100,0.09138567,0.028750213,,,,,,,,,,,,,,,,, -40200,0.14983554,0.0316478,,,,,,,,,,,,,,,,, -40300,0.11572537,0.030090442,,,,,,,,,,,,,,,,, -40400,0.08274931,0.028730446,,,,,,,,,,,,,,,,, -40500,0.0971411,0.03195737,,,,,,,,,,,,,,,,, -40508,,,0.9923818707466124,0.0246172044426202,0.531692979601829,0.9868564009666444,0.0451633147895336,0.2772046084427561,43793.0,0.9860655665397644,0.048062939196825,0.2669514125648769,43793.0,12982.455196619034,20309.30371117592,12982.455196619034,7323.763496160507,1.9770872592926023,0.0 -40600,0.12304198,0.032706503,,,,,,,,,,,,,,,,, -40700,0.10092341,0.02849891,,,,,,,,,,,,,,,,, -40800,0.10694681,0.029562088,,,,,,,,,,,,,,,,, -40900,0.099057354,0.030900663,,,,,,,,,,,,,,,,, -41000,0.09278912,0.028686179,,,,,,,,,,,,,,,,, -41100,0.106728226,0.027934363,,,,,,,,,,,,,,,,, -41200,0.1131646,0.027790489,,,,,,,,,,,,,,,,, -41256,,,0.9925772547721864,0.0241183917969465,0.5373471320030823,0.986707866191864,0.0452901497483253,0.2748573775242192,43793.0,0.9858545660972596,0.0481712892651557,0.2574186207001034,43793.0,13222.612513780594,20680.74651002884,13222.612513780594,7454.994160413742,2.011096239089966,0.0 -41300,0.091823675,0.031196425,,,,,,,,,,,,,,,,, -41400,0.096558824,0.032316156,,,,,,,,,,,,,,,,, -41500,0.108309515,0.030625561,,,,,,,,,,,,,,,,, -41600,0.09774461,0.02862452,,,,,,,,,,,,,,,,, -41700,0.116148986,0.028833462,,,,,,,,,,,,,,,,, -41800,0.13774481,0.02779074,,,,,,,,,,,,,,,,, -41900,0.09863604,0.031582136,,,,,,,,,,,,,,,,, -42000,0.1176404,0.029743709,,,,,,,,,,,,,,,,, -42004,,,0.99293053150177,0.0230482406914234,0.5792572161274165,0.986726939678192,0.0451919101178646,0.2777320618341087,43793.0,0.9858461618423462,0.0480821169912815,0.2624563335373316,43793.0,13462.853985786438,21052.20994591713,13462.853985786438,7586.161582946777,2.044940948486328,0.0 -42100,0.09449284,0.029776134,,,,,,,,,,,,,,,,, -42200,0.09398756,0.029640034,,,,,,,,,,,,,,,,, -42300,0.11546482,0.029948113,,,,,,,,,,,,,,,,, -42400,0.11203164,0.028508112,,,,,,,,,,,,,,,,, -42500,0.10847079,0.026789736,,,,,,,,,,,,,,,,, -42600,0.10226139,0.029940901,,,,,,,,,,,,,,,,, -42700,0.09915038,0.02985135,,,,,,,,,,,,,,,,, -42744,,,0.9928200840950012,0.0233676191419363,0.5596890151487883,0.9868795275688172,0.0452751107513904,0.2799710175790842,43793.0,0.98598051071167,0.0482771880924701,0.2657243764067328,43793.0,13703.067671060562,21423.40425777436,13703.067671060562,7717.086632013321,2.079592227935791,0.0 -42800,0.10655007,0.02967897,,,,,,,,,,,,,,,,, -42900,0.12554045,0.03297668,,,,,,,,,,,,,,,,, -43000,0.12091503,0.029100828,,,,,,,,,,,,,,,,, -43100,0.101493195,0.031006247,,,,,,,,,,,,,,,,, -43200,0.11186372,0.030123204,,,,,,,,,,,,,,,,, -43300,0.1324868,0.028281147,,,,,,,,,,,,,,,,, -43400,0.101685986,0.029366907,,,,,,,,,,,,,,,,, -43492,,,0.9925847053527832,0.0239696539938449,0.5528363329090942,0.986907124519348,0.0454422645270824,0.2818758340362011,43793.0,0.9860390424728394,0.0483165048062801,0.2653350767236811,43793.0,13943.23703622818,21792.34558987617,13943.23703622818,7845.803071022034,2.1141724586486816,0.0 -43500,0.100477956,0.030017328,,,,,,,,,,,,,,,,, -43600,0.100730106,0.03026038,,,,,,,,,,,,,,,,, -43700,0.11754683,0.028429816,,,,,,,,,,,,,,,,, -43800,0.12726517,0.031002035,,,,,,,,,,,,,,,,, -43900,0.105826765,0.03012323,,,,,,,,,,,,,,,,, -44000,0.13318296,0.031493008,,,,,,,,,,,,,,,,, -44100,0.11582503,0.028817154,,,,,,,,,,,,,,,,, -44200,0.10474339,0.028076867,,,,,,,,,,,,,,,,, -44245,,,0.9922870993614196,0.0249149054288864,0.5299943259670108,0.9868738651275636,0.0457279868423938,0.2731604479649976,43793.0,0.9859662055969238,0.0486884862184524,0.2590255891925674,43793.0,14183.470051765442,22162.540717601776,14183.470051765442,7975.708652496338,2.1499216556549072,0.0 -44300,0.10205388,0.027504189,,,,,,,,,,,,,,,,, -44400,0.1070146,0.031847507,,,,,,,,,,,,,,,,, -44500,0.111248784,0.029814538,,,,,,,,,,,,,,,,, -44600,0.120283924,0.027524246,,,,,,,,,,,,,,,,, -44700,0.10274856,0.029657142,,,,,,,,,,,,,,,,, -44800,0.1163637,0.03053561,,,,,,,,,,,,,,,,, -44900,0.112351894,0.027903346,,,,,,,,,,,,,,,,, -45000,0.119156495,0.030515576,,,,,,,,,,,,,,,,, -45001,,,0.9922727942466736,0.0247032344341278,0.5314834003758118,0.9867748022079468,0.0455899201333522,0.2845712141946099,43793.0,0.9859042763710022,0.0485112071037292,0.2633550920425738,43793.0,14423.669402837751,22526.95771765709,14423.669402837751,8099.870770931244,2.1851043701171875,0.0 -45100,0.11336126,0.031212231,,,,,,,,,,,,,,,,, -45200,0.112220384,0.027411306,,,,,,,,,,,,,,,,, -45300,0.099510476,0.029144801,,,,,,,,,,,,,,,,, -45400,0.12354467,0.028508803,,,,,,,,,,,,,,,,, -45500,0.1423303,0.029786266,,,,,,,,,,,,,,,,, -45600,0.10821549,0.027560286,,,,,,,,,,,,,,,,, -45700,0.13746823,0.030109754,,,,,,,,,,,,,,,,, -45751,,,0.992500066757202,0.0239628087729215,0.5480373339681623,0.9867601990699768,0.0460610948503017,0.2770109295318826,43793.0,0.9860045313835144,0.0489345826208591,0.2583359577865653,43793.0,14663.82266998291,22893.1158246994,14663.82266998291,8225.82084441185,2.220059871673584,0.0 -45800,0.12238146,0.02760539,,,,,,,,,,,,,,,,, -45900,0.12490188,0.029048435,,,,,,,,,,,,,,,,, -46000,0.11784009,0.02929558,,,,,,,,,,,,,,,,, -46100,0.10334254,0.029911028,,,,,,,,,,,,,,,,, -46200,0.10751649,0.024621606,,,,,,,,,,,,,,,,, -46300,0.112644956,0.029959995,,,,,,,,,,,,,,,,, -46400,0.1153991,0.028825477,,,,,,,,,,,,,,,,, -46500,0.11345452,0.027495902,,,,,,,,,,,,,,,,, -46501,,,0.9925301671028136,0.0238752625882625,0.5455075725286589,0.9867683053016664,0.0463589653372764,0.2753539428337459,43793.0,0.9858596324920654,0.0496632941067218,0.2589811391864166,43793.0,14903.953869581224,23262.441887378693,14903.953869581224,8354.959605455399,2.2551698684692383,0.0 -46600,0.110066965,0.029139549,,,,,,,,,,,,,,,,, -46700,0.12429306,0.02730856,,,,,,,,,,,,,,,,, -46800,0.1060554,0.027686357,,,,,,,,,,,,,,,,, -46900,0.1162381,0.02843319,,,,,,,,,,,,,,,,, -47000,0.13327384,0.02915243,,,,,,,,,,,,,,,,, -47100,0.12868598,0.029649049,,,,,,,,,,,,,,,,, -47200,0.11981621,0.02812942,,,,,,,,,,,,,,,,, -47254,,,0.9927741289138794,0.0230254717171192,0.5728452454270465,0.9868178367614746,0.0460526496171951,0.2810175965239018,43793.0,0.9858705401420592,0.0494574196636676,0.254589348944634,43793.0,15144.071580171583,23633.566210269928,15144.071580171583,8485.91120505333,2.2899508476257324,0.0 -47300,0.11030925,0.029788291,,,,,,,,,,,,,,,,, -47400,0.1129143,0.027693175,,,,,,,,,,,,,,,,, -47500,0.10830981,0.028761486,,,,,,,,,,,,,,,,, -47600,0.11094495,0.026557632,,,,,,,,,,,,,,,,, -47700,0.11546848,0.026507521,,,,,,,,,,,,,,,,, -47800,0.1448475,0.029223997,,,,,,,,,,,,,,,,, -47900,0.13537234,0.031265467,,,,,,,,,,,,,,,,, -48000,0.12364784,0.026514446,,,,,,,,,,,,,,,,, -48014,,,0.993111252784729,0.0219979658722877,0.6028270478090696,0.9867350459098816,0.0465046465396881,0.2760659886726114,43793.0,0.9859164953231812,0.0495819374918937,0.2659727477785244,43793.0,15384.28563117981,23999.686259269714,15384.28563117981,8611.761420249939,2.3253207206726074,0.0 -48100,0.11587298,0.026828729,,,,,,,,,,,,,,,,, -48200,0.12768619,0.028752457,,,,,,,,,,,,,,,,, -48300,0.12881236,0.02824268,,,,,,,,,,,,,,,,, -48400,0.14707544,0.026789656,,,,,,,,,,,,,,,,, -48500,0.11918811,0.02791954,,,,,,,,,,,,,,,,, -48600,0.13729419,0.028972529,,,,,,,,,,,,,,,,, -48700,0.13117455,0.027890947,,,,,,,,,,,,,,,,, -48776,,,0.9933006167411804,0.021312803030014,0.6074992806664145,0.9868174195289612,0.0469700321555137,0.271830096179202,43793.0,0.9858773350715636,0.0503807589411735,0.2556186798312245,43793.0,15624.298836946487,24362.35687804222,15624.298836946487,8734.362447023392,2.3620529174804688,0.0 -48800,0.12032273,0.026039572,,,,,,,,,,,,,,,,, -48900,0.17008103,0.03232695,,,,,,,,,,,,,,,,, -49000,0.12437358,0.027233798,,,,,,,,,,,,,,,,, -49100,0.13946727,0.02947123,,,,,,,,,,,,,,,,, -49200,0.12469816,0.027084619,,,,,,,,,,,,,,,,, -49300,0.12435105,0.029357072,,,,,,,,,,,,,,,,, -49400,0.13579519,0.028166724,,,,,,,,,,,,,,,,, -49500,0.123446874,0.026645273,,,,,,,,,,,,,,,,, -49522,,,0.9936766624450684,0.0204411093145608,0.62012147547647,0.9867147207260132,0.0466574504971504,0.2757375957739105,43793.0,0.9858027696609496,0.0497420094907283,0.2576226855179493,43793.0,15864.43145275116,24729.10475111008,15864.43145275116,8860.920503616333,2.397468328475952,0.0 -49600,0.15264082,0.027361063,,,,,,,,,,,,,,,,, -49700,0.13634837,0.027078904,,,,,,,,,,,,,,,,, -49800,0.13937233,0.028405027,,,,,,,,,,,,,,,,, -49900,0.15240575,0.029066758,,,,,,,,,,,,,,,,, -50000,0.13271442,0.02828824,,,,,,,,,,,,,,,,, -50100,0.13011445,0.027126815,,,,,,,,,,,,,,,,, -50200,0.110500716,0.026609194,,,,,,,,,,,,,,,,, -50268,,,0.9936593770980836,0.0206755194813013,0.6200574946558495,0.9865986108779908,0.0468205362558364,0.2705907601746179,43793.0,0.9857581257820128,0.0497113391757011,0.257039206586833,43793.0,16104.392055511476,25097.82476592064,16104.392055511476,8989.623442411423,2.433577060699463,0.0 -50300,0.12223343,0.025255164,,,,,,,,,,,,,,,,, -50400,0.12556843,0.02603888,,,,,,,,,,,,,,,,, -50500,0.1700352,0.029224692,,,,,,,,,,,,,,,,, -50600,0.14246058,0.027697217,,,,,,,,,,,,,,,,, -50700,0.12710544,0.027861996,,,,,,,,,,,,,,,,, -50800,0.1230721,0.026744781,,,,,,,,,,,,,,,,, -50900,0.12585711,0.025823547,,,,,,,,,,,,,,,,, -51000,0.11106533,0.02586918,,,,,,,,,,,,,,,,, -51022,,,0.9933759570121764,0.0212582051753997,0.6098886010964394,0.9867504835128784,0.0470260456204414,0.2726649447568852,43793.0,0.9858331084251404,0.0501727536320686,0.2515554863987931,43793.0,16344.416038274763,25462.04162096977,16344.416038274763,9113.758713245392,2.470756769180298,0.0 -51100,0.12988748,0.026775718,,,,,,,,,,,,,,,,, -51200,0.12061265,0.023879465,,,,,,,,,,,,,,,,, -51300,0.147634,0.026589254,,,,,,,,,,,,,,,,, -51400,0.1531493,0.026385387,,,,,,,,,,,,,,,,, -51500,0.1616796,0.029007899,,,,,,,,,,,,,,,,, -51600,0.13820066,0.02588533,,,,,,,,,,,,,,,,, -51700,0.14123832,0.025535712,,,,,,,,,,,,,,,,, -51772,,,0.9933111667633056,0.0214274190366268,0.6009365828350484,0.9867772459983826,0.0471513681113719,0.2797023395820893,43793.0,0.9858819246292114,0.0503813885152339,0.257556443992949,43793.0,16584.653742313385,25830.131724357605,16584.653742313385,9241.55316233635,2.508890390396118,0.0 -51800,0.13325693,0.026754675,,,,,,,,,,,,,,,,, -51900,0.15347126,0.026874924,,,,,,,,,,,,,,,,, -52000,0.13385837,0.027130665,,,,,,,,,,,,,,,,, -52100,0.1538251,0.027756693,,,,,,,,,,,,,,,,, -52200,0.15392406,0.025740726,,,,,,,,,,,,,,,,, -52300,0.12672843,0.02475846,,,,,,,,,,,,,,,,, -52400,0.14895985,0.02457128,,,,,,,,,,,,,,,,, -52500,0.14393467,0.026737727,,,,,,,,,,,,,,,,, -52521,,,0.9929652214050292,0.0222934540361166,0.5837654219526973,0.986805260181427,0.0476638115942478,0.2768003173815692,43793.0,0.985834777355194,0.0509601235389709,0.2574955043291971,43793.0,16824.87526488304,26199.8629732132,16824.87526488304,9371.00605082512,2.5452442169189453,0.0 -52600,0.14020832,0.024188114,,,,,,,,,,,,,,,,, -52700,0.17316417,0.027221225,,,,,,,,,,,,,,,,, -52800,0.11856834,0.024977185,,,,,,,,,,,,,,,,, -52900,0.14603002,0.028669875,,,,,,,,,,,,,,,,, -53000,0.1263398,0.024291841,,,,,,,,,,,,,,,,, -53100,0.16914742,0.02381099,,,,,,,,,,,,,,,,, -53200,0.15254377,0.024709688,,,,,,,,,,,,,,,,, -53276,,,0.9930841326713562,0.0219203252345323,0.5777379497760418,0.986806869506836,0.0475819483399391,0.2802084358791125,43793.0,0.9858364462852478,0.0508538261055946,0.2604646324893419,43793.0,17064.86261534691,26565.05463194847,17064.86261534691,9496.152726888657,2.5827507972717285,0.0 -53300,0.16019242,0.027061015,,,,,,,,,,,,,,,,, -53400,0.14600512,0.026349097,,,,,,,,,,,,,,,,, -53500,0.14697851,0.026653964,,,,,,,,,,,,,,,,, -53600,0.1402003,0.027396558,,,,,,,,,,,,,,,,, -53700,0.13875735,0.024381893,,,,,,,,,,,,,,,,, -53800,0.15593067,0.0261336,,,,,,,,,,,,,,,,, -53900,0.1518037,0.027768731,,,,,,,,,,,,,,,,, -54000,0.16219465,0.027708331,,,,,,,,,,,,,,,,, -54023,,,0.9934099316596984,0.0208245944231748,0.6189938976762239,0.9866924285888672,0.0480832867324352,0.2718280150950659,43793.0,0.985871434211731,0.0509433448314666,0.2601165040001111,43793.0,17304.836101531982,26931.03513693809,17304.836101531982,9622.10262274742,2.618824481964112,0.0 -54100,0.14792246,0.027664585,,,,,,,,,,,,,,,,, -54200,0.15495177,0.027463486,,,,,,,,,,,,,,,,, -54300,0.15796164,0.023337224,,,,,,,,,,,,,,,,, -54400,0.18737197,0.02575057,,,,,,,,,,,,,,,,, -54500,0.14881843,0.024916852,,,,,,,,,,,,,,,,, -54600,0.14618748,0.027300568,,,,,,,,,,,,,,,,, -54700,0.15086578,0.02732929,,,,,,,,,,,,,,,,, -54777,,,0.993585765361786,0.0203476287424564,0.6222725297509415,0.9866132736206056,0.048272106796503,0.2787696840403332,43793.0,0.9857589602470398,0.0514470785856246,0.2557963810743766,43793.0,17545.082458496094,27300.34612417221,17545.082458496094,9751.109506607056,2.6560540199279785,0.0 -54800,0.17362124,0.025805173,,,,,,,,,,,,,,,,, -54900,0.14158626,0.025141455,,,,,,,,,,,,,,,,, -55000,0.1582151,0.024965234,,,,,,,,,,,,,,,,, -55100,0.14584419,0.02396284,,,,,,,,,,,,,,,,, -55200,0.1485809,0.02451804,,,,,,,,,,,,,,,,, -55300,0.16587211,0.025340285,,,,,,,,,,,,,,,,, -55400,0.15723957,0.024225142,,,,,,,,,,,,,,,,, -55500,0.16204847,0.02724758,,,,,,,,,,,,,,,,, -55511,,,0.9937710165977478,0.0195806194096803,0.6551641993310239,0.9865986108779908,0.0490063689649105,0.2737162962179191,43793.0,0.985806941986084,0.052032433450222,0.253222930490168,43793.0,17785.06728363037,27667.090751171112,17785.06728363037,9877.802606344225,2.69838547706604,0.0 -55600,0.15302062,0.02528883,,,,,,,,,,,,,,,,, -55700,0.1709988,0.024412163,,,,,,,,,,,,,,,,, -55800,0.16619104,0.026682649,,,,,,,,,,,,,,,,, -55900,0.14307235,0.022738343,,,,,,,,,,,,,,,,, -56000,0.14032355,0.025829691,,,,,,,,,,,,,,,,, -56100,0.16738638,0.023612142,,,,,,,,,,,,,,,,, -56200,0.18048477,0.026058644,,,,,,,,,,,,,,,,, -56263,,,0.994708776473999,0.0172035414725542,0.6975938120108676,0.9865763187408448,0.0490662939846515,0.2698884790954958,43793.0,0.9857622981071472,0.0521343909204006,0.2585453414761834,43793.0,18025.3019797802,28026.49404001236,18025.3019797802,9996.913838386536,2.735103368759156,0.0 -56300,0.15273964,0.0225294,,,,,,,,,,,,,,,,, -56400,0.16499424,0.023369495,,,,,,,,,,,,,,,,, -56500,0.16058655,0.024887301,,,,,,,,,,,,,,,,, -56600,0.18036443,0.023746029,,,,,,,,,,,,,,,,, -56700,0.18681607,0.02462669,,,,,,,,,,,,,,,,, -56800,0.18146534,0.022927986,,,,,,,,,,,,,,,,, -56900,0.16727898,0.023796344,,,,,,,,,,,,,,,,, -56999,,,0.994670569896698,0.0172426998615264,0.6762894919081208,0.9865738749504088,0.0492083318531513,0.2711901601753668,43793.0,0.9856982827186584,0.0524606294929981,0.2558035961147663,43793.0,18265.25046825409,28398.62367272377,18265.25046825409,10129.031205415726,2.776209592819214,0.0 -57000,0.17361985,0.023832299,,,,,,,,,,,,,,,,, -57100,0.15965658,0.024900954,,,,,,,,,,,,,,,,, -57200,0.1734536,0.023346717,,,,,,,,,,,,,,,,, -57300,0.16132447,0.023713,,,,,,,,,,,,,,,,, -57400,0.19128641,0.022553906,,,,,,,,,,,,,,,,, -57500,0.17787845,0.025158081,,,,,,,,,,,,,,,,, -57600,0.1594591,0.02384227,,,,,,,,,,,,,,,,, -57666,,,,,,,,,,,,,,18477.228969812393,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 1878d8305..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -127.20457720756532,0.0,12.126285314559937,1,0,12.126285314559937,0.5256850719451904,0.7376683354377747,0.0260218475243196,43793,139.33090662956238,0.5289902091026306,0.7364099025726318,0.0206864879251886,0.5270814895629883,0.737440824508667,0.0240625484422362,43793 -253.0918610095977,0.021273136138916,252.2958436012268,750,0,252.2958436012268,0.983142077922821,0.0760092660784721,0.0461748143897109,43793,505.4296803474426,0.986754298210144,0.0641855522990226,0.043571310424089,0.9841179251670836,0.0730567350983619,0.0453327088891681,43793 -381.9180512428284,0.0487275123596191,492.5110273361206,1491,0,492.5110273361206,0.9834765195846558,0.0612092986702919,0.0846994099290799,43793,874.5190241336823,0.9871442317962646,0.0485940016806125,0.0847953123676511,0.9844759702682496,0.0579231418669223,0.0876923435601555,43793 -510.0526442527771,0.0775880813598632,732.7416000366211,2236,0,732.7416000366211,0.9837313294410706,0.0580892078578472,0.1110899480813678,43793,1242.9329607486725,0.9874035120010376,0.0458655469119548,0.1138696637483815,0.9847410321235656,0.0549195483326911,0.1122611004225512,43793 -637.340523481369,0.1042702198028564,972.7662193775176,2985,0,972.7662193775176,0.9838774800300598,0.0565657056868076,0.1329894472602042,43793,1610.2927725315094,0.9875199794769288,0.0446250885725021,0.1425213550215882,0.984859585762024,0.0536448583006858,0.1345195998544093,43793 -760.508437871933,0.131772756576538,1212.8425545692444,3730,0,1212.8425545692444,0.983906090259552,0.0564019158482551,0.1527795488106242,43793,1973.5845963954928,0.9874981641769408,0.0441562943160533,0.16691992686249,0.9848372340202332,0.0532684922218322,0.1536197460118933,43793 -882.3497655391693,0.1592152118682861,1453.064716339111,4480,0,1453.064716339111,0.9844414591789246,0.053200889378786,0.1681703076887086,43793,2335.695539712906,0.9881991744041444,0.0413440726697444,0.1970566464858755,0.985330045223236,0.0503430180251598,0.1702478411418039,43793 -1006.7336690425872,0.1871747970581054,1693.1650228500366,5232,0,1693.1650228500366,0.9842683672904968,0.0540040172636508,0.1897039477965048,43793,2700.22809290886,0.9880143404006958,0.041274219751358,0.2060017191467205,0.9851628541946412,0.0510475561022758,0.1887682563880464,43793 -1135.106214761734,0.2173802852630615,1933.403118848801,5975,0,1933.403118848801,0.9848635196685792,0.0512248836457729,0.2039977342576277,43793,3068.890548229217,0.9888017773628236,0.0386166907846927,0.2381257512652334,0.9857409000396729,0.048397846519947,0.2011913112182664,43793 -1257.4715340137482,0.246328592300415,2173.590786933899,6721,0,2173.590786933899,0.9846937656402588,0.051630537956953,0.2059085480747178,43793,3431.493139743805,0.9884495735168456,0.0391944348812103,0.2413052924407402,0.9856499433517456,0.0487793684005737,0.2044096955630416,43793 -1379.377456188202,0.2736272811889648,2413.612592458725,7477,0,2413.612592458725,0.9850782752037048,0.0501323044300079,0.2198003057995466,43793,3793.469251871109,0.9889029264450072,0.037515502423048,0.2757240612573762,0.9859803915023804,0.0474647879600524,0.2206059988797787,43793 -1500.782978773117,0.3023748397827148,2653.793397426605,8225,0,2653.793397426605,0.9852569103240968,0.0494446270167827,0.2252269584982761,43793,4155.104791641235,0.9894757270812988,0.0359002090990543,0.3072628440876122,0.9861716032028198,0.0467517748475074,0.2247504688285445,43793 -1625.9661529064178,0.3305733203887939,2893.755373477936,8984,0,2893.755373477936,0.9854257702827454,0.048779260367155,0.2395137327731072,43793,4520.299135923386,0.9894742965698242,0.0352791287004947,0.3132704878213992,0.986291766166687,0.0461629033088684,0.235990866792487,43793 -1749.5614104270935,0.358844518661499,3133.7946379184723,9733,0,3133.7946379184723,0.9853495359420776,0.0491264685988426,0.2371165836262182,43793,4883.982634544373,0.9896693229675292,0.0342782400548458,0.3510216167629603,0.9862142205238342,0.0464232116937637,0.2376936664874003,43793 -1871.8764510154724,0.3905558586120605,3374.0041666030884,10481,0,3374.0041666030884,0.9854468703269958,0.0490616187453269,0.2472147852661594,43793,5246.561218261719,0.9896564483642578,0.0339212678372859,0.3735877274136109,0.9863014817237854,0.0462760962545871,0.2454329941619501,43793 -1993.6837046146395,0.4195625782012939,3614.227289676666,11236,0,3614.227289676666,0.9854127168655396,0.0493291094899177,0.2470103021327772,43793,5608.641200304031,0.989663541316986,0.0340158343315124,0.3581437473261031,0.9863518476486206,0.0462555959820747,0.2520758598756026,43793 -2113.992784023285,0.4489624500274658,3854.4807589054094,11992,0,3854.4807589054094,0.9854961037635804,0.0490399077534675,0.2538872922350826,43793,5969.253441572189,0.9899739623069764,0.0330843292176723,0.3703011366378856,0.9864541292190552,0.046054221689701,0.2510109312493945,43793 -2232.345431089401,0.4781973361968994,4094.628182649613,12742,0,4094.628182649613,0.9856456518173218,0.0483121685683727,0.2498462761455225,43793,6327.803616523743,0.9902027249336244,0.0324925333261489,0.3828772272923211,0.9865673780441284,0.0455458313226699,0.2547032048499188,43793 -2352.0891876220703,0.5074634552001953,4334.628254652023,13491,0,4334.628254652023,0.98575097322464,0.0479115359485149,0.2501149207185102,43793,6687.5966901779175,0.9905490279197692,0.0317074432969093,0.3998465120748556,0.9866088032722472,0.0450855679810047,0.255800612872904,43793 -2473.1447324752808,0.5425519943237305,4574.708070993424,14243,0,4574.708070993424,0.9858794212341307,0.0477812215685844,0.2592707509360255,43793,7048.78892493248,0.9906753897666932,0.0309940185397863,0.4288764369674755,0.9867565631866456,0.0450419187545776,0.2566420787519,43793 -2594.674001932144,0.5714969635009766,4814.695953845978,14988,0,4814.695953845978,0.985731601715088,0.0484391562640666,0.2535050643187877,43793,7410.355037212372,0.9906166195869446,0.0303849708288908,0.4262323265330509,0.986648976802826,0.045590728521347,0.2593278819248042,43793 -2715.398867845536,0.6034946441650391,5054.773876190186,15725,0,5054.773876190186,0.985811173915863,0.0480678342282772,0.2574324454887617,43793,7771.213989019394,0.9909983277320862,0.0295489635318517,0.454112727750228,0.9866875410079956,0.0453388169407844,0.2552772530468821,43793 -2835.131489276886,0.632915735244751,5294.818076848984,16472,0,5294.818076848984,0.985897958278656,0.0482565090060234,0.256947923922455,43793,8131.040618658066,0.991483211517334,0.0277794189751148,0.4892964067033922,0.986763834953308,0.045558076351881,0.2630052954278285,43793 -2953.16513299942,0.6627569198608398,5535.046224355698,17216,0,5535.046224355698,0.9859485030174256,0.0482847541570663,0.2594389000289132,43793,8489.352545261383,0.9915978908538818,0.0274420864880085,0.5071297348406607,0.9867764711380004,0.0454589314758777,0.259824773748504,43793 -3069.9476771354675,0.691856861114502,5775.2991716861725,17970,0,5775.2991716861725,0.985820472240448,0.0484392642974853,0.2553436713257808,43793,8846.437492847443,0.991574227809906,0.027520490810275,0.5063725356236961,0.9866810441017152,0.0457698628306388,0.2549757899168097,43793 -3193.9604799747467,0.723224401473999,6015.247898340225,18712,0,6015.247898340225,0.985951006412506,0.0483011044561862,0.2602272453178065,43793,9210.452338218687,0.9916909337043762,0.0276644323021173,0.4879586830864612,0.9867812991142272,0.045619148761034,0.2659859739068194,43793 -3317.048745393753,0.7530508041381836,6255.323559045792,19462,0,6255.323559045792,0.9858710169792176,0.0486629828810691,0.2597261426049351,43793,9573.666657924652,0.9915679693222046,0.0275299493223428,0.4931503201414768,0.9867374897003174,0.0458053909242153,0.2623122251543744,43793 -3435.202912569046,0.7837088108062744,6495.576878070831,20217,0,6495.576878070831,0.985770344734192,0.0491240657866001,0.2538815944606575,43793,9932.12539958954,0.9917171597480774,0.0272450819611549,0.5011942091495836,0.986620545387268,0.0463441610336303,0.2602996367965026,43793 -3553.2706336975098,0.817469596862793,6735.836624145508,20964,0,6735.836624145508,0.985842764377594,0.0492424480617046,0.2505229961525075,43793,10290.507292747498,0.9916337728500366,0.0269696675240993,0.5147379353860131,0.9867159724235536,0.0463829524815082,0.2640476810927308,43793 -3680.445751905441,0.8478615283966064,6975.841460704804,21709,0,6975.841460704804,0.985846996307373,0.0488092526793479,0.2529326603844326,43793,10657.739182472227,0.992107629776001,0.0259189009666442,0.5224654679997216,0.9866802096366882,0.0461787469685077,0.2573069386431443,43793 -3795.818098068237,0.8788893222808838,7216.057116985321,22456,0,7216.057116985321,0.9858326315879822,0.0496953874826431,0.2527972181515401,43793,11013.379020690918,0.9921752214431764,0.0250938981771469,0.5612627482316623,0.9866765737533568,0.0467959716916084,0.2537034977859143,43793 -3919.440933704376,0.9106287956237792,7456.157476663589,23203,0,7456.157476663589,0.9858015179634094,0.0494990982115268,0.2530047210423624,43793,11377.154458284378,0.9926006197929382,0.0240978058427572,0.5666674766675607,0.9866806268692015,0.0464501082897186,0.2576747314424948,43793 -4039.667632818222,0.9410545825958252,7696.350813627243,23953,0,7696.350813627243,0.9857783317565918,0.0499153286218643,0.2435153176655401,43793,11737.6252348423,0.9929395318031312,0.0229268539696931,0.6068382807187294,0.9866469502449036,0.0467442981898784,0.2506215024728762,43793 -4162.690213441849,0.9763381481170654,7936.4745626449585,24697,0,7936.4745626449585,0.9857467412948608,0.0501218922436237,0.2449126202635368,43793,12100.827523469923,0.9930366277694702,0.0227992050349712,0.6002876704147919,0.9866778254508972,0.0469138175249099,0.2580225307631216,43793 -4282.287209033966,1.0089423656463623,8176.495717287064,25443,0,8176.495717287064,0.9858074188232422,0.050314363092184,0.2511761806975714,43793,12460.498522043228,0.9928334951400756,0.0233095176517963,0.5824658225387833,0.9866371750831604,0.0471940413117408,0.2585839289400465,43793 -4402.311491250992,1.04034686088562,8416.455756187439,26181,0,8416.455756187439,0.9857892990112304,0.0504956133663654,0.2471158389841096,43793,12820.535581588743,0.9925329089164734,0.0240434017032384,0.5762234810496102,0.986629068851471,0.0475363433361053,0.255861604009055,43793 -4520.522451400757,1.071984052658081,8656.649072170258,26930,0,8656.649072170258,0.98576021194458,0.0513136461377143,0.2501128544462677,43793,13178.992216587068,0.9923691749572754,0.0243989583104848,0.5442668976545151,0.986559271812439,0.0484216548502445,0.2523562677872928,43793 -4643.076948165894,1.1032321453094482,8896.665783882141,27681,0,8896.665783882141,0.9857446551322936,0.0510604418814182,0.2448402843569965,43793,13541.615369558334,0.992576003074646,0.0239151343703269,0.5780123560081694,0.9865665435791016,0.0480653345584869,0.2504831156838596,43793 -4764.155343532562,1.1362247467041016,9136.74174809456,28428,0,9136.74174809456,0.9858065247535706,0.0515264011919498,0.2475043142655151,43793,13902.822404623032,0.9928958415985109,0.022750936448574,0.5812031928977817,0.9866591095924376,0.0483444854617118,0.258367335005388,43793 -4883.786524772644,1.1681454181671145,9376.930389642715,29172,0,9376.930389642715,0.9856435656547546,0.0517445616424083,0.2452964177865293,43793,14262.695072174072,0.9930545687675476,0.0223341267555952,0.615808792548507,0.9865154027938844,0.0486091189086437,0.248205262006874,43793 -5001.006668329239,1.2020776271820068,9617.14767241478,29922,0,9617.14767241478,0.9856886267662048,0.0519146211445331,0.2413855139103063,43793,14620.18716263771,0.9933610558509828,0.0212800167500972,0.6281966779933524,0.986520290374756,0.0487982630729675,0.2524698596022855,43793 -5124.642949104309,1.2355139255523682,9857.341331720352,30664,0,9857.341331720352,0.9856064915657043,0.0523036830127239,0.2375251315964577,43793,14984.070784330368,0.9941012263298036,0.0193631164729595,0.6699394839401119,0.9864890575408936,0.0492533333599567,0.2432590919025162,43793 -5241.653692960739,1.2681159973144531,10097.41719198227,31412,0,10097.41719198227,0.9856364130973816,0.0520030371844768,0.2407133325040985,43793,15341.210273504255,0.9941602349281312,0.0192818939685821,0.6668670900993225,0.9864894151687622,0.0489224307239055,0.2506274620268583,43793 -5363.469901323319,1.3005335330963137,10337.642447710035,32146,0,10337.642447710035,0.9856595396995544,0.0528297610580921,0.2389180610951756,43793,15703.304780244827,0.9937401413917542,0.0202498193830251,0.6502916140923689,0.9865487217903136,0.0496782660484313,0.2499031904886978,43793 -5488.026214838028,1.334458589553833,10577.850949764252,32902,0,10577.850949764252,0.9856982827186584,0.0526974648237228,0.2426612461020473,43793,16068.123711824415,0.9936198592185974,0.0206041522324085,0.6400930897954951,0.986531674861908,0.0495291985571384,0.2453435833052644,43793 -5605.228426933289,1.3684918880462646,10817.950261116028,33652,0,10817.950261116028,0.9855508804321288,0.0528176799416542,0.2418377184419295,43793,16425.47945213318,0.9934507012367249,0.0209422651678323,0.6310022435507462,0.9864732027053832,0.0497406870126724,0.2459634658972307,43793 -5726.372263431549,1.402308702468872,11058.12816143036,34387,0,11058.12816143036,0.985524356365204,0.0534192509949207,0.2359514915263177,43793,16786.856697797775,0.9934825897216796,0.0207439642399549,0.6385106217173827,0.9863489866256714,0.0502095408737659,0.2464846540457237,43793 -5840.427979707718,1.437326431274414,11298.363315105438,35138,0,11298.363315105438,0.9854733943939208,0.0535249002277851,0.2315473117127911,43793,17141.203130722046,0.9936950206756592,0.0203303676098585,0.6362148804038725,0.9863278865814208,0.0502837263047695,0.2430508544160199,43793 -5960.965605020523,1.471550226211548,11538.582442760468,35887,0,11538.582442760468,0.9854927659034728,0.0545136667788028,0.2386775963454686,43793,17502.014199733734,0.993681788444519,0.019934918731451,0.6535335108823093,0.98631489276886,0.0514221414923667,0.2430296632439309,43793 -6084.737009048462,1.5058512687683103,11778.559102535248,36621,0,11778.559102535248,0.9854645133018494,0.0545359924435615,0.2343355525502071,43793,17865.818506240845,0.9945728182792664,0.0176827441900968,0.688483399506173,0.986333966255188,0.051441915333271,0.244130999515316,43793 -6205.075866937637,1.5408875942230225,12018.643469572067,37372,0,12018.643469572067,0.9855479598045348,0.0552063584327697,0.235362061924285,43793,18226.29783797264,0.9952541589736938,0.0158701855689287,0.7463612898447716,0.9863250255584716,0.0521195009350776,0.2393038359651531,43793 -6320.522824525833,1.5767641067504885,12258.72385263443,38121,0,12258.72385263443,0.9855525493621826,0.0550990290939807,0.2385763605440669,43793,18581.881452083588,0.994975447654724,0.016621870920062,0.7227896262201021,0.986448049545288,0.0518963262438774,0.2398712386816698,43793 -6437.69078040123,1.609722137451172,12498.797527074814,38874,0,12498.797527074814,0.9855129718780518,0.0561452880501747,0.2376918945252976,43793,18939.17649841309,0.994772493839264,0.0167333744466304,0.714931953560515,0.986291766166687,0.0532053150236606,0.2356094752372903,43793 -6553.50083398819,1.6491947174072266,12738.750935316086,39625,0,12738.750935316086,0.9854274988174438,0.0564156621694564,0.2301816212423666,43793,19294.999740600582,0.994320273399353,0.0178232826292514,0.705524199624799,0.986319363117218,0.0530956424772739,0.2403834761443907,43793 -6664.366796016693,1.6828031539916992,12978.82855129242,40377,0,12978.82855129242,0.9854485392570496,0.0566561445593833,0.2315639609968926,43793,19645.99722838401,0.9941571950912476,0.0183549374341964,0.6876441021637686,0.986297845840454,0.0534391738474369,0.2370363956850179,43793 -6781.759689092636,1.7166087627410889,13219.0130982399,41136,0,13219.0130982399,0.9851621389389038,0.055863220244646,0.2326064034439338,43793,20003.628672599792,0.9942678809165956,0.0182990487664937,0.6918804230446601,0.9859828352928162,0.0525840632617473,0.2357247872052208,43793 -6894.503623247147,1.7502388954162598,13459.264628887177,41890,0,13459.264628887177,0.9853798747062684,0.0568219721317291,0.2324982355114241,43793,20356.678092479706,0.9942914843559264,0.0179653763771057,0.701986704581986,0.9862337112426758,0.05374501273036,0.2348137809159835,43793 -7009.93562078476,1.7863051891326904,13699.364458560944,42636,0,13699.364458560944,0.9854514598846436,0.058204285800457,0.2309237291602419,43793,20712.2668569088,0.9941781163215636,0.0178943015635013,0.6990086529737705,0.9864057898521424,0.0544170401990413,0.2398192574255114,43793 -7126.545880794525,1.8220069408416748,13939.344955205916,43393,0,13939.344955205916,0.985403060913086,0.0582301169633865,0.2280993431328843,43793,21068.913420438766,0.9946008920669556,0.0168354269117116,0.7236670943979682,0.9863424897193908,0.0547343268990516,0.2367951793593295,43793 -7235.889031648636,1.866541862487793,14179.578085184095,44147,0,14179.578085184095,0.9852429628372192,0.0583256147801876,0.2253432139066692,43793,21418.554672002792,0.9953388571739196,0.0151414508000016,0.7568423681622725,0.9860754013061525,0.054790049791336,0.2288448749028145,43793 -7351.393758773804,1.9024038314819336,14419.814072847366,44884,0,14419.814072847366,0.9854169487953186,0.0582167506217956,0.2274132334812646,43793,21774.352262735367,0.996155858039856,0.0133711593225598,0.7954491680973936,0.986195147037506,0.0549307949841022,0.2399991624461255,43793 -7465.55695939064,1.9380853176116943,14659.792273283005,45620,0,14659.792273283005,0.9852564930915833,0.0593939758837223,0.2220551116809077,43793,22128.552355766296,0.996436893939972,0.0128021948039531,0.7997349957824464,0.986088752746582,0.055894199758768,0.2335008345089164,43793 -7581.692788124084,1.9741096496582031,14899.860787391664,46380,0,14899.860787391664,0.9850963950157166,0.05946921184659,0.2228981628786851,43793,22484.813533067703,0.995314359664917,0.0151610001921653,0.7500370780843769,0.9860250353813172,0.0560168400406837,0.2258241513802438,43793 -7694.41745185852,2.01121473312378,15139.825226068497,47129,0,15139.825226068497,0.9853002429008484,0.0599371455609798,0.2316235445308457,43793,22837.559860944748,0.9956495761871338,0.0142398113384842,0.7698141151166427,0.9861208200454712,0.0565310157835483,0.2288986155071729,43793 -7809.593261241913,2.0465872287750244,15379.947883367538,47879,0,15379.947883367538,0.9852320551872252,0.0598737373948097,0.2277907845115211,43793,23192.9143447876,0.9952120780944824,0.015117822214961,0.7573643665429037,0.9860530495643616,0.0564565919339656,0.2272830550792519,43793 -7922.388898611069,2.09068250656128,15620.005237102509,48625,0,15620.005237102509,0.9852830171585084,0.0603629611432552,0.2280650358780248,43793,23545.832807779312,0.9950177669525146,0.0155489360913634,0.7415567667563101,0.9860798716545104,0.0569931715726852,0.2300238933443485,43793 -8041.968768358231,2.127811908721924,15860.014887571337,49368,0,15860.014887571337,0.9852412939071656,0.0611497797071933,0.2257478494693963,43793,23905.479996204376,0.9946566820144652,0.0164572186768054,0.7364141469888428,0.9861857891082764,0.057639967650175,0.2242057574321107,43793 -8160.495388507843,2.1722218990325928,16100.006301641464,50121,0,16100.006301641464,0.9851734638214112,0.0609352216124534,0.2285361877579136,43793,24264.06285619736,0.9957500696182252,0.0139283929020166,0.7854547276249997,0.9860916137695312,0.0574779324233531,0.2286573693495823,43793 -8273.594465494156,2.208230495452881,16340.236045360563,50869,0,16340.236045360563,0.985240876674652,0.0616643354296684,0.2263295055390688,43793,24617.448274374008,0.99623841047287,0.0125467721372842,0.8122539594282422,0.9861472249031068,0.0581727661192417,0.2258159763884001,43793 -8386.665300130844,2.2444097995758057,16580.56310081482,51620,0,16580.56310081482,0.9852067828178406,0.0629700496792793,0.2218210113962008,43793,24970.902344703674,0.9967470765113832,0.0114964265376329,0.8367234392536422,0.986141562461853,0.0593289285898208,0.2274897693187031,43793 -8495.671817302704,2.2801904678344727,16820.774089574814,52381,0,16820.774089574814,0.9852383732795716,0.0631648600101471,0.223539776319986,43793,25320.17582321167,0.9970282316207886,0.0108794895932078,0.8544680011223145,0.9861334562301636,0.0595301464200019,0.2277542538640574,43793 -8608.965829610825,2.3182740211486816,17060.759781360626,53126,0,17060.759781360626,0.9852033853530884,0.0632070749998092,0.2219429456533263,43793,25673.513555049896,0.9972330331802368,0.0105284303426742,0.8540056525446815,0.9861581921577454,0.0593896433711051,0.227870746548491,43793 -8723.215174674988,2.3571720123291016,17300.893683433533,53873,0,17300.893683433533,0.9851911664009094,0.0638005584478378,0.2149521823725836,43793,26027.95757389069,0.9966502785682678,0.011681037954986,0.835379206205047,0.9860693216323853,0.0602785646915435,0.2249466474666006,43793 -8836.655643463135,2.39902138710022,17540.84679889679,54620,0,17540.84679889679,0.9852109551429749,0.0639358758926391,0.2158959223624439,43793,26381.413821458817,0.9948896765708924,0.0155108226463198,0.7515258127315602,0.9860928654670716,0.0601283572614192,0.22509716084751,43793 -8945.55486869812,2.4369421005249023,17780.836613178253,55367,0,17780.836613178253,0.985027313232422,0.0636545941233635,0.2154933151832759,43793,26730.361434221268,0.9954902529716492,0.0140686891973018,0.7815939752419798,0.9859552383422852,0.0598899014294147,0.2264501280562338,43793 -9059.463624238968,2.475022077560425,18021.015577316284,56126,0,18021.015577316284,0.9852176904678344,0.0648197308182716,0.2176275776241821,43793,27084.50729203224,0.9959081411361694,0.0127335209399461,0.8261949531777921,0.986108660697937,0.0612089298665523,0.2227735685321576,43793 -9171.793118715286,2.5136139392852783,18261.217477321625,56866,0,18261.217477321625,0.985187828540802,0.06561759114265442,0.2167473575816602,43793,27437.100608110428,0.9954772591590881,0.013762353919446468,0.8133653761740869,0.9861249327659607,0.06183765456080437,0.2242386545209901,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/measurements.csv deleted file mode 100644 index e9943dd3c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/measurements.csv +++ /dev/null @@ -1,655 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3050804,0.73579973,,,,,,,,,,,,,,,,, -1,,,0.5289902091026306,0.7364099025726318,0.0206864879251886,0.5270814895629883,0.737440824508667,0.0240625484422362,43793.0,0.5256850719451904,0.7376683354377747,0.0260218475243196,43793.0,12.126285314559937,139.33090662956238,12.126285314559937,127.20457720756532,0.0,0.0 -100,0.45551386,0.40146494,,,,,,,,,,,,,,,,, -200,0.32764107,0.29363418,,,,,,,,,,,,,,,,, -300,0.23025993,0.20027761,,,,,,,,,,,,,,,,, -400,0.14745148,0.13419878,,,,,,,,,,,,,,,,, -500,0.08871685,0.099133454,,,,,,,,,,,,,,,,, -600,0.07955796,0.077774175,,,,,,,,,,,,,,,,, -700,0.037942693,0.06396359,,,,,,,,,,,,,,,,, -750,,,0.986754298210144,0.0641855522990226,0.043571310424089,0.9841179251670836,0.0730567350983619,0.0453327088891681,43793.0,0.983142077922821,0.0760092660784721,0.0461748143897109,43793.0,252.2958436012268,505.4296803474426,252.2958436012268,253.0918610095977,0.021273136138916,0.0 -800,0.1069227,0.06539791,,,,,,,,,,,,,,,,, -900,0.10761505,0.056760382,,,,,,,,,,,,,,,,, -1000,0.05383116,0.05364513,,,,,,,,,,,,,,,,, -1100,0.057074413,0.048420697,,,,,,,,,,,,,,,,, -1200,0.03885037,0.053245302,,,,,,,,,,,,,,,,, -1300,0.04625007,0.049689446,,,,,,,,,,,,,,,,, -1400,0.041510668,0.053376045,,,,,,,,,,,,,,,,, -1491,,,0.9871442317962646,0.0485940016806125,0.0847953123676511,0.9844759702682496,0.0579231418669223,0.0876923435601555,43793.0,0.9834765195846558,0.0612092986702919,0.0846994099290799,43793.0,492.5110273361206,874.5190241336823,492.5110273361206,381.9180512428284,0.0487275123596191,0.0 -1500,0.034243513,0.05671674,,,,,,,,,,,,,,,,, -1600,0.05296692,0.05095119,,,,,,,,,,,,,,,,, -1700,0.023551712,0.053852223,,,,,,,,,,,,,,,,, -1800,0.031809375,0.052751467,,,,,,,,,,,,,,,,, -1900,0.027549485,0.053167902,,,,,,,,,,,,,,,,, -2000,0.03321226,0.049967237,,,,,,,,,,,,,,,,, -2100,0.026390025,0.05018376,,,,,,,,,,,,,,,,, -2200,0.033986904,0.0503423,,,,,,,,,,,,,,,,, -2236,,,0.9874035120010376,0.0458655469119548,0.1138696637483815,0.9847410321235656,0.0549195483326911,0.1122611004225512,43793.0,0.9837313294410706,0.0580892078578472,0.1110899480813678,43793.0,732.7416000366211,1242.9329607486725,732.7416000366211,510.0526442527771,0.0775880813598632,0.0 -2300,0.036457606,0.046172,,,,,,,,,,,,,,,,, -2400,0.026548445,0.049182907,,,,,,,,,,,,,,,,, -2500,0.058099676,0.05060228,,,,,,,,,,,,,,,,, -2600,0.07484243,0.045715563,,,,,,,,,,,,,,,,, -2700,0.04894025,0.05003473,,,,,,,,,,,,,,,,, -2800,0.029364113,0.048828274,,,,,,,,,,,,,,,,, -2900,0.066683754,0.04793973,,,,,,,,,,,,,,,,, -2985,,,0.9875199794769288,0.0446250885725021,0.1425213550215882,0.984859585762024,0.0536448583006858,0.1345195998544093,43793.0,0.9838774800300598,0.0565657056868076,0.1329894472602042,43793.0,972.7662193775176,1610.2927725315094,972.7662193775176,637.340523481369,0.1042702198028564,0.0 -3000,0.026698124,0.04471589,,,,,,,,,,,,,,,,, -3100,0.03181428,0.04886238,,,,,,,,,,,,,,,,, -3200,0.04231537,0.04429995,,,,,,,,,,,,,,,,, -3300,0.02457644,0.046883866,,,,,,,,,,,,,,,,, -3400,0.043453388,0.044385947,,,,,,,,,,,,,,,,, -3500,0.021044038,0.04824136,,,,,,,,,,,,,,,,, -3600,0.041711885,0.046907667,,,,,,,,,,,,,,,,, -3700,0.03205072,0.043969702,,,,,,,,,,,,,,,,, -3730,,,0.9874981641769408,0.0441562943160533,0.16691992686249,0.9848372340202332,0.0532684922218322,0.1536197460118933,43793.0,0.983906090259552,0.0564019158482551,0.1527795488106242,43793.0,1212.8425545692444,1973.5845963954928,1212.8425545692444,760.508437871933,0.131772756576538,0.0 -3800,0.054954853,0.059457984,,,,,,,,,,,,,,,,, -3900,0.048689935,0.047674134,,,,,,,,,,,,,,,,, -4000,0.03214997,0.047185093,,,,,,,,,,,,,,,,, -4100,0.027059894,0.04363025,,,,,,,,,,,,,,,,, -4200,0.047976285,0.045386124,,,,,,,,,,,,,,,,, -4300,0.054428905,0.049128402,,,,,,,,,,,,,,,,, -4400,0.021769183,0.04558294,,,,,,,,,,,,,,,,, -4480,,,0.9881991744041444,0.0413440726697444,0.1970566464858755,0.985330045223236,0.0503430180251598,0.1702478411418039,43793.0,0.9844414591789246,0.053200889378786,0.1681703076887086,43793.0,1453.064716339111,2335.695539712906,1453.064716339111,882.3497655391693,0.1592152118682861,0.0 -4500,0.011402607,0.040306345,,,,,,,,,,,,,,,,, -4600,0.016126279,0.04462785,,,,,,,,,,,,,,,,, -4700,0.01991858,0.050891846,,,,,,,,,,,,,,,,, -4800,0.016101353,0.04011064,,,,,,,,,,,,,,,,, -4900,0.024194613,0.04525669,,,,,,,,,,,,,,,,, -5000,0.010730064,0.041516162,,,,,,,,,,,,,,,,, -5100,0.018281886,0.04235119,,,,,,,,,,,,,,,,, -5200,0.015959693,0.043226987,,,,,,,,,,,,,,,,, -5232,,,0.9880143404006958,0.041274219751358,0.2060017191467205,0.9851628541946412,0.0510475561022758,0.1887682563880464,43793.0,0.9842683672904968,0.0540040172636508,0.1897039477965048,43793.0,1693.1650228500366,2700.22809290886,1693.1650228500366,1006.7336690425872,0.1871747970581054,0.0 -5300,0.02729364,0.043042295,,,,,,,,,,,,,,,,, -5400,0.018740352,0.043929845,,,,,,,,,,,,,,,,, -5500,0.03072395,0.04498469,,,,,,,,,,,,,,,,, -5600,0.028725876,0.04180368,,,,,,,,,,,,,,,,, -5700,0.018485893,0.04478857,,,,,,,,,,,,,,,,, -5800,0.020130947,0.041711338,,,,,,,,,,,,,,,,, -5900,0.022949811,0.042256955,,,,,,,,,,,,,,,,, -5975,,,0.9888017773628236,0.0386166907846927,0.2381257512652334,0.9857409000396729,0.048397846519947,0.2011913112182664,43793.0,0.9848635196685792,0.0512248836457729,0.2039977342576277,43793.0,1933.403118848801,3068.890548229217,1933.403118848801,1135.106214761734,0.2173802852630615,0.0 -6000,0.016170138,0.04093352,,,,,,,,,,,,,,,,, -6100,0.029766634,0.0434131,,,,,,,,,,,,,,,,, -6200,0.020136854,0.04341522,,,,,,,,,,,,,,,,, -6300,0.018818175,0.042024527,,,,,,,,,,,,,,,,, -6400,0.024010124,0.040693987,,,,,,,,,,,,,,,,, -6500,0.014920604,0.04263138,,,,,,,,,,,,,,,,, -6600,0.037385795,0.042872384,,,,,,,,,,,,,,,,, -6700,0.017708119,0.038892917,,,,,,,,,,,,,,,,, -6721,,,0.9884495735168456,0.0391944348812103,0.2413052924407402,0.9856499433517456,0.0487793684005737,0.2044096955630416,43793.0,0.9846937656402588,0.051630537956953,0.2059085480747178,43793.0,2173.590786933899,3431.493139743805,2173.590786933899,1257.4715340137482,0.246328592300415,0.0 -6800,0.011360957,0.040328868,,,,,,,,,,,,,,,,, -6900,0.011673983,0.042820033,,,,,,,,,,,,,,,,, -7000,0.014704465,0.041037474,,,,,,,,,,,,,,,,, -7100,0.017889274,0.043274682,,,,,,,,,,,,,,,,, -7200,0.014752509,0.04149049,,,,,,,,,,,,,,,,, -7300,0.015229591,0.045950096,,,,,,,,,,,,,,,,, -7400,0.018138787,0.042100623,,,,,,,,,,,,,,,,, -7477,,,0.9889029264450072,0.037515502423048,0.2757240612573762,0.9859803915023804,0.0474647879600524,0.2206059988797787,43793.0,0.9850782752037048,0.0501323044300079,0.2198003057995466,43793.0,2413.612592458725,3793.469251871109,2413.612592458725,1379.377456188202,0.2736272811889648,0.0 -7500,0.014963686,0.04117051,,,,,,,,,,,,,,,,, -7600,0.022542067,0.043293614,,,,,,,,,,,,,,,,, -7700,0.017289264,0.042998035,,,,,,,,,,,,,,,,, -7800,0.018044323,0.042449623,,,,,,,,,,,,,,,,, -7900,0.016710177,0.03847995,,,,,,,,,,,,,,,,, -8000,0.013356859,0.03879589,,,,,,,,,,,,,,,,, -8100,0.01992386,0.04315818,,,,,,,,,,,,,,,,, -8200,0.018227587,0.04117391,,,,,,,,,,,,,,,,, -8225,,,0.9894757270812988,0.0359002090990543,0.3072628440876122,0.9861716032028198,0.0467517748475074,0.2247504688285445,43793.0,0.9852569103240968,0.0494446270167827,0.2252269584982761,43793.0,2653.793397426605,4155.104791641235,2653.793397426605,1500.782978773117,0.3023748397827148,0.0 -8300,0.013498155,0.038243003,,,,,,,,,,,,,,,,, -8400,0.015046779,0.04036945,,,,,,,,,,,,,,,,, -8500,0.019451104,0.04157226,,,,,,,,,,,,,,,,, -8600,0.017865736,0.039242618,,,,,,,,,,,,,,,,, -8700,0.013131846,0.041650895,,,,,,,,,,,,,,,,, -8800,0.014615037,0.040477626,,,,,,,,,,,,,,,,, -8900,0.018124096,0.045514606,,,,,,,,,,,,,,,,, -8984,,,0.9894742965698242,0.0352791287004947,0.3132704878213992,0.986291766166687,0.0461629033088684,0.235990866792487,43793.0,0.9854257702827454,0.048779260367155,0.2395137327731072,43793.0,2893.755373477936,4520.299135923386,2893.755373477936,1625.9661529064178,0.3305733203887939,0.0 -9000,0.01810156,0.042578742,,,,,,,,,,,,,,,,, -9100,0.023668127,0.0407538,,,,,,,,,,,,,,,,, -9200,0.01866452,0.037600875,,,,,,,,,,,,,,,,, -9300,0.018210696,0.04017688,,,,,,,,,,,,,,,,, -9400,0.023688557,0.04111379,,,,,,,,,,,,,,,,, -9500,0.016786829,0.039084308,,,,,,,,,,,,,,,,, -9600,0.016641738,0.044066615,,,,,,,,,,,,,,,,, -9700,0.020552192,0.039085913,,,,,,,,,,,,,,,,, -9733,,,0.9896693229675292,0.0342782400548458,0.3510216167629603,0.9862142205238342,0.0464232116937637,0.2376936664874003,43793.0,0.9853495359420776,0.0491264685988426,0.2371165836262182,43793.0,3133.7946379184723,4883.982634544373,3133.7946379184723,1749.5614104270935,0.358844518661499,0.0 -9800,0.013894429,0.039349116,,,,,,,,,,,,,,,,, -9900,0.029786874,0.043522246,,,,,,,,,,,,,,,,, -10000,0.01535322,0.038987856,,,,,,,,,,,,,,,,, -10100,0.014292835,0.036724027,,,,,,,,,,,,,,,,, -10200,0.017473495,0.0385877,,,,,,,,,,,,,,,,, -10300,0.015619233,0.040467154,,,,,,,,,,,,,,,,, -10400,0.013233175,0.03740311,,,,,,,,,,,,,,,,, -10481,,,0.9896564483642578,0.0339212678372859,0.3735877274136109,0.9863014817237854,0.0462760962545871,0.2454329941619501,43793.0,0.9854468703269958,0.0490616187453269,0.2472147852661594,43793.0,3374.0041666030884,5246.561218261719,3374.0041666030884,1871.8764510154724,0.3905558586120605,0.0 -10500,0.015522066,0.039854877,,,,,,,,,,,,,,,,, -10600,0.01231677,0.03650451,,,,,,,,,,,,,,,,, -10700,0.015834587,0.039028585,,,,,,,,,,,,,,,,, -10800,0.014680378,0.038274955,,,,,,,,,,,,,,,,, -10900,0.017095923,0.040643066,,,,,,,,,,,,,,,,, -11000,0.015916122,0.039642114,,,,,,,,,,,,,,,,, -11100,0.019988,0.040198136,,,,,,,,,,,,,,,,, -11200,0.018562252,0.041706152,,,,,,,,,,,,,,,,, -11236,,,0.989663541316986,0.0340158343315124,0.3581437473261031,0.9863518476486206,0.0462555959820747,0.2520758598756026,43793.0,0.9854127168655396,0.0493291094899177,0.2470103021327772,43793.0,3614.227289676666,5608.641200304031,3614.227289676666,1993.6837046146395,0.4195625782012939,0.0 -11300,0.018470768,0.03899065,,,,,,,,,,,,,,,,, -11400,0.018623484,0.039099745,,,,,,,,,,,,,,,,, -11500,0.0175985,0.040365245,,,,,,,,,,,,,,,,, -11600,0.01706694,0.040017094,,,,,,,,,,,,,,,,, -11700,0.015059777,0.03898074,,,,,,,,,,,,,,,,, -11800,0.018523095,0.038743027,,,,,,,,,,,,,,,,, -11900,0.0188339,0.03872688,,,,,,,,,,,,,,,,, -11992,,,0.9899739623069764,0.0330843292176723,0.3703011366378856,0.9864541292190552,0.046054221689701,0.2510109312493945,43793.0,0.9854961037635804,0.0490399077534675,0.2538872922350826,43793.0,3854.4807589054094,5969.253441572189,3854.4807589054094,2113.992784023285,0.4489624500274658,0.0 -12000,0.020818807,0.037839163,,,,,,,,,,,,,,,,, -12100,0.018767193,0.038445417,,,,,,,,,,,,,,,,, -12200,0.020396233,0.038065027,,,,,,,,,,,,,,,,, -12300,0.018192349,0.037586246,,,,,,,,,,,,,,,,, -12400,0.014686275,0.039032537,,,,,,,,,,,,,,,,, -12500,0.016265517,0.036770705,,,,,,,,,,,,,,,,, -12600,0.01841239,0.04007139,,,,,,,,,,,,,,,,, -12700,0.019433955,0.039555836,,,,,,,,,,,,,,,,, -12742,,,0.9902027249336244,0.0324925333261489,0.3828772272923211,0.9865673780441284,0.0455458313226699,0.2547032048499188,43793.0,0.9856456518173218,0.0483121685683727,0.2498462761455225,43793.0,4094.628182649613,6327.803616523743,4094.628182649613,2232.345431089401,0.4781973361968994,0.0 -12800,0.020100806,0.039446343,,,,,,,,,,,,,,,,, -12900,0.01931611,0.037207,,,,,,,,,,,,,,,,, -13000,0.015200509,0.034494214,,,,,,,,,,,,,,,,, -13100,0.01593613,0.037335545,,,,,,,,,,,,,,,,, -13200,0.015508545,0.035281233,,,,,,,,,,,,,,,,, -13300,0.015252384,0.037402324,,,,,,,,,,,,,,,,, -13400,0.016675603,0.036102567,,,,,,,,,,,,,,,,, -13491,,,0.9905490279197692,0.0317074432969093,0.3998465120748556,0.9866088032722472,0.0450855679810047,0.255800612872904,43793.0,0.98575097322464,0.0479115359485149,0.2501149207185102,43793.0,4334.628254652023,6687.5966901779175,4334.628254652023,2352.0891876220703,0.5074634552001953,0.0 -13500,0.020038083,0.03765475,,,,,,,,,,,,,,,,, -13600,0.018968813,0.038634054,,,,,,,,,,,,,,,,, -13700,0.017571196,0.036826037,,,,,,,,,,,,,,,,, -13800,0.018071678,0.034056187,,,,,,,,,,,,,,,,, -13900,0.01592665,0.03753956,,,,,,,,,,,,,,,,, -14000,0.019384619,0.037176702,,,,,,,,,,,,,,,,, -14100,0.016726881,0.038977668,,,,,,,,,,,,,,,,, -14200,0.02107708,0.03658416,,,,,,,,,,,,,,,,, -14243,,,0.9906753897666932,0.0309940185397863,0.4288764369674755,0.9867565631866456,0.0450419187545776,0.2566420787519,43793.0,0.9858794212341307,0.0477812215685844,0.2592707509360255,43793.0,4574.708070993424,7048.78892493248,4574.708070993424,2473.1447324752808,0.5425519943237305,0.0 -14300,0.019568356,0.03796185,,,,,,,,,,,,,,,,, -14400,0.017068088,0.037466146,,,,,,,,,,,,,,,,, -14500,0.01556192,0.034511857,,,,,,,,,,,,,,,,, -14600,0.0178966,0.037517473,,,,,,,,,,,,,,,,, -14700,0.022212034,0.036922228,,,,,,,,,,,,,,,,, -14800,0.018780245,0.033434585,,,,,,,,,,,,,,,,, -14900,0.020112388,0.035723154,,,,,,,,,,,,,,,,, -14988,,,0.9906166195869446,0.0303849708288908,0.4262323265330509,0.986648976802826,0.045590728521347,0.2593278819248042,43793.0,0.985731601715088,0.0484391562640666,0.2535050643187877,43793.0,4814.695953845978,7410.355037212372,4814.695953845978,2594.674001932144,0.5714969635009766,0.0 -15000,0.020786678,0.037264284,,,,,,,,,,,,,,,,, -15100,0.016851023,0.034956787,,,,,,,,,,,,,,,,, -15200,0.022459276,0.03812817,,,,,,,,,,,,,,,,, -15300,0.022974778,0.035575576,,,,,,,,,,,,,,,,, -15400,0.019130832,0.03553083,,,,,,,,,,,,,,,,, -15500,0.019201672,0.03654228,,,,,,,,,,,,,,,,, -15600,0.022254165,0.03720373,,,,,,,,,,,,,,,,, -15700,0.023206161,0.035961352,,,,,,,,,,,,,,,,, -15725,,,0.9909983277320862,0.0295489635318517,0.454112727750228,0.9866875410079956,0.0453388169407844,0.2552772530468821,43793.0,0.985811173915863,0.0480678342282772,0.2574324454887617,43793.0,5054.773876190186,7771.213989019394,5054.773876190186,2715.398867845536,0.6034946441650391,0.0 -15800,0.017015692,0.035892818,,,,,,,,,,,,,,,,, -15900,0.01930677,0.038364768,,,,,,,,,,,,,,,,, -16000,0.018004138,0.03597727,,,,,,,,,,,,,,,,, -16100,0.019792415,0.03689308,,,,,,,,,,,,,,,,, -16200,0.023861613,0.036779523,,,,,,,,,,,,,,,,, -16300,0.02368044,0.03743101,,,,,,,,,,,,,,,,, -16400,0.02680528,0.03748867,,,,,,,,,,,,,,,,, -16472,,,0.991483211517334,0.0277794189751148,0.4892964067033922,0.986763834953308,0.045558076351881,0.2630052954278285,43793.0,0.985897958278656,0.0482565090060234,0.256947923922455,43793.0,5294.818076848984,8131.040618658066,5294.818076848984,2835.131489276886,0.632915735244751,0.0 -16500,0.02323125,0.03883541,,,,,,,,,,,,,,,,, -16600,0.020788139,0.038289405,,,,,,,,,,,,,,,,, -16700,0.024964307,0.03772577,,,,,,,,,,,,,,,,, -16800,0.01969048,0.035859335,,,,,,,,,,,,,,,,, -16900,0.026373256,0.035348874,,,,,,,,,,,,,,,,, -17000,0.028052289,0.0367091,,,,,,,,,,,,,,,,, -17100,0.02054665,0.03443784,,,,,,,,,,,,,,,,, -17200,0.01993079,0.03574957,,,,,,,,,,,,,,,,, -17216,,,0.9915978908538818,0.0274420864880085,0.5071297348406607,0.9867764711380004,0.0454589314758777,0.259824773748504,43793.0,0.9859485030174256,0.0482847541570663,0.2594389000289132,43793.0,5535.046224355698,8489.352545261383,5535.046224355698,2953.16513299942,0.6627569198608398,0.0 -17300,0.018983735,0.0342315,,,,,,,,,,,,,,,,, -17400,0.020810204,0.03394183,,,,,,,,,,,,,,,,, -17500,0.023589455,0.03706359,,,,,,,,,,,,,,,,, -17600,0.028167557,0.033554044,,,,,,,,,,,,,,,,, -17700,0.018709792,0.033542097,,,,,,,,,,,,,,,,, -17800,0.029454492,0.035019174,,,,,,,,,,,,,,,,, -17900,0.023690617,0.035437863,,,,,,,,,,,,,,,,, -17970,,,0.991574227809906,0.027520490810275,0.5063725356236961,0.9866810441017152,0.0457698628306388,0.2549757899168097,43793.0,0.985820472240448,0.0484392642974853,0.2553436713257808,43793.0,5775.2991716861725,8846.437492847443,5775.2991716861725,3069.9476771354675,0.691856861114502,0.0 -18000,0.022601636,0.03420733,,,,,,,,,,,,,,,,, -18100,0.028824864,0.0377649,,,,,,,,,,,,,,,,, -18200,0.020621264,0.033774275,,,,,,,,,,,,,,,,, -18300,0.027622944,0.035217203,,,,,,,,,,,,,,,,, -18400,0.019695599,0.03540585,,,,,,,,,,,,,,,,, -18500,0.020854758,0.03516108,,,,,,,,,,,,,,,,, -18600,0.027689587,0.03541314,,,,,,,,,,,,,,,,, -18700,0.028363595,0.036295675,,,,,,,,,,,,,,,,, -18712,,,0.9916909337043762,0.0276644323021173,0.4879586830864612,0.9867812991142272,0.045619148761034,0.2659859739068194,43793.0,0.985951006412506,0.0483011044561862,0.2602272453178065,43793.0,6015.247898340225,9210.452338218687,6015.247898340225,3193.9604799747467,0.723224401473999,0.0 -18800,0.020306166,0.03196559,,,,,,,,,,,,,,,,, -18900,0.02240101,0.035772793,,,,,,,,,,,,,,,,, -19000,0.02538959,0.037087455,,,,,,,,,,,,,,,,, -19100,0.022857426,0.031196764,,,,,,,,,,,,,,,,, -19200,0.021789804,0.03593295,,,,,,,,,,,,,,,,, -19300,0.02183555,0.03385428,,,,,,,,,,,,,,,,, -19400,0.027497482,0.03329554,,,,,,,,,,,,,,,,, -19462,,,0.9915679693222046,0.0275299493223428,0.4931503201414768,0.9867374897003174,0.0458053909242153,0.2623122251543744,43793.0,0.9858710169792176,0.0486629828810691,0.2597261426049351,43793.0,6255.323559045792,9573.666657924652,6255.323559045792,3317.048745393753,0.7530508041381836,0.0 -19500,0.024119893,0.034291774,,,,,,,,,,,,,,,,, -19600,0.02760949,0.036050886,,,,,,,,,,,,,,,,, -19700,0.028253302,0.036302548,,,,,,,,,,,,,,,,, -19800,0.024177194,0.034179665,,,,,,,,,,,,,,,,, -19900,0.022409359,0.03428642,,,,,,,,,,,,,,,,, -20000,0.027906267,0.03494006,,,,,,,,,,,,,,,,, -20100,0.03053135,0.036042232,,,,,,,,,,,,,,,,, -20200,0.02676497,0.036995053,,,,,,,,,,,,,,,,, -20217,,,0.9917171597480774,0.0272450819611549,0.5011942091495836,0.986620545387268,0.0463441610336303,0.2602996367965026,43793.0,0.985770344734192,0.0491240657866001,0.2538815944606575,43793.0,6495.576878070831,9932.12539958954,6495.576878070831,3435.202912569046,0.7837088108062744,0.0 -20300,0.026351212,0.032940254,,,,,,,,,,,,,,,,, -20400,0.027200153,0.033906255,,,,,,,,,,,,,,,,, -20500,0.030070301,0.03542132,,,,,,,,,,,,,,,,, -20600,0.026156906,0.036158714,,,,,,,,,,,,,,,,, -20700,0.032303132,0.03507515,,,,,,,,,,,,,,,,, -20800,0.029517192,0.034777474,,,,,,,,,,,,,,,,, -20900,0.029492667,0.03287051,,,,,,,,,,,,,,,,, -20964,,,0.9916337728500366,0.0269696675240993,0.5147379353860131,0.9867159724235536,0.0463829524815082,0.2640476810927308,43793.0,0.985842764377594,0.0492424480617046,0.2505229961525075,43793.0,6735.836624145508,10290.507292747498,6735.836624145508,3553.2706336975098,0.817469596862793,0.0 -21000,0.03411523,0.039938763,,,,,,,,,,,,,,,,, -21100,0.028889908,0.035424802,,,,,,,,,,,,,,,,, -21200,0.032442957,0.032803666,,,,,,,,,,,,,,,,, -21300,0.026334291,0.03405286,,,,,,,,,,,,,,,,, -21400,0.03184388,0.033964805,,,,,,,,,,,,,,,,, -21500,0.031738866,0.034029827,,,,,,,,,,,,,,,,, -21600,0.025383282,0.032369375,,,,,,,,,,,,,,,,, -21700,0.027360689,0.032589123,,,,,,,,,,,,,,,,, -21709,,,0.992107629776001,0.0259189009666442,0.5224654679997216,0.9866802096366882,0.0461787469685077,0.2573069386431443,43793.0,0.985846996307373,0.0488092526793479,0.2529326603844326,43793.0,6975.841460704804,10657.739182472227,6975.841460704804,3680.445751905441,0.8478615283966064,0.0 -21800,0.027851127,0.03328685,,,,,,,,,,,,,,,,, -21900,0.029831793,0.032922745,,,,,,,,,,,,,,,,, -22000,0.028004123,0.032873455,,,,,,,,,,,,,,,,, -22100,0.032993317,0.034361914,,,,,,,,,,,,,,,,, -22200,0.03207831,0.033119325,,,,,,,,,,,,,,,,, -22300,0.030998845,0.030863445,,,,,,,,,,,,,,,,, -22400,0.029721804,0.033242024,,,,,,,,,,,,,,,,, -22456,,,0.9921752214431764,0.0250938981771469,0.5612627482316623,0.9866765737533568,0.0467959716916084,0.2537034977859143,43793.0,0.9858326315879822,0.0496953874826431,0.2527972181515401,43793.0,7216.057116985321,11013.379020690918,7216.057116985321,3795.818098068237,0.8788893222808838,0.0 -22500,0.031545714,0.034913305,,,,,,,,,,,,,,,,, -22600,0.036489815,0.03625433,,,,,,,,,,,,,,,,, -22700,0.028068818,0.032647096,,,,,,,,,,,,,,,,, -22800,0.027882986,0.030983422,,,,,,,,,,,,,,,,, -22900,0.0296304,0.031203728,,,,,,,,,,,,,,,,, -23000,0.035935767,0.03571935,,,,,,,,,,,,,,,,, -23100,0.034500875,0.032769892,,,,,,,,,,,,,,,,, -23200,0.03522504,0.034845043,,,,,,,,,,,,,,,,, -23203,,,0.9926006197929382,0.0240978058427572,0.5666674766675607,0.9866806268692015,0.0464501082897186,0.2576747314424948,43793.0,0.9858015179634094,0.0494990982115268,0.2530047210423624,43793.0,7456.157476663589,11377.154458284378,7456.157476663589,3919.440933704376,0.9106287956237792,0.0 -23300,0.039763764,0.03604966,,,,,,,,,,,,,,,,, -23400,0.02938447,0.031786386,,,,,,,,,,,,,,,,, -23500,0.032478046,0.031187952,,,,,,,,,,,,,,,,, -23600,0.032203775,0.031921986,,,,,,,,,,,,,,,,, -23700,0.033581436,0.032366708,,,,,,,,,,,,,,,,, -23800,0.033647038,0.031973287,,,,,,,,,,,,,,,,, -23900,0.03828894,0.03327192,,,,,,,,,,,,,,,,, -23953,,,0.9929395318031312,0.0229268539696931,0.6068382807187294,0.9866469502449036,0.0467442981898784,0.2506215024728762,43793.0,0.9857783317565918,0.0499153286218643,0.2435153176655401,43793.0,7696.350813627243,11737.6252348423,7696.350813627243,4039.667632818222,0.9410545825958252,0.0 -24000,0.03956934,0.036232244,,,,,,,,,,,,,,,,, -24100,0.03148213,0.033421263,,,,,,,,,,,,,,,,, -24200,0.032426633,0.033068087,,,,,,,,,,,,,,,,, -24300,0.039487835,0.033418383,,,,,,,,,,,,,,,,, -24400,0.037718445,0.032252423,,,,,,,,,,,,,,,,, -24500,0.042189498,0.034807768,,,,,,,,,,,,,,,,, -24600,0.035165235,0.03303477,,,,,,,,,,,,,,,,, -24697,,,0.9930366277694702,0.0227992050349712,0.6002876704147919,0.9866778254508972,0.0469138175249099,0.2580225307631216,43793.0,0.9857467412948608,0.0501218922436237,0.2449126202635368,43793.0,7936.4745626449585,12100.827523469923,7936.4745626449585,4162.690213441849,0.9763381481170654,0.0 -24700,0.034266785,0.03317762,,,,,,,,,,,,,,,,, -24800,0.033358634,0.03242633,,,,,,,,,,,,,,,,, -24900,0.054492652,0.03220028,,,,,,,,,,,,,,,,, -25000,0.03792135,0.030615237,,,,,,,,,,,,,,,,, -25100,0.03801484,0.03279332,,,,,,,,,,,,,,,,, -25200,0.03566383,0.032700773,,,,,,,,,,,,,,,,, -25300,0.034893386,0.03409219,,,,,,,,,,,,,,,,, -25400,0.038329624,0.03475141,,,,,,,,,,,,,,,,, -25443,,,0.9928334951400756,0.0233095176517963,0.5824658225387833,0.9866371750831604,0.0471940413117408,0.2585839289400465,43793.0,0.9858074188232422,0.050314363092184,0.2511761806975714,43793.0,8176.495717287064,12460.498522043228,8176.495717287064,4282.287209033966,1.0089423656463623,0.0 -25500,0.032066394,0.02950352,,,,,,,,,,,,,,,,, -25600,0.038706016,0.03215794,,,,,,,,,,,,,,,,, -25700,0.034388673,0.031441227,,,,,,,,,,,,,,,,, -25800,0.03528788,0.032172974,,,,,,,,,,,,,,,,, -25900,0.04249428,0.0312954,,,,,,,,,,,,,,,,, -26000,0.046610933,0.032338683,,,,,,,,,,,,,,,,, -26100,0.03892949,0.032169577,,,,,,,,,,,,,,,,, -26181,,,0.9925329089164734,0.0240434017032384,0.5762234810496102,0.986629068851471,0.0475363433361053,0.255861604009055,43793.0,0.9857892990112304,0.0504956133663654,0.2471158389841096,43793.0,8416.455756187439,12820.535581588743,8416.455756187439,4402.311491250992,1.04034686088562,0.0 -26200,0.04112539,0.032828856,,,,,,,,,,,,,,,,, -26300,0.05154078,0.034566045,,,,,,,,,,,,,,,,, -26400,0.03663109,0.03049987,,,,,,,,,,,,,,,,, -26500,0.044006296,0.032763522,,,,,,,,,,,,,,,,, -26600,0.0455543,0.034788206,,,,,,,,,,,,,,,,, -26700,0.038844965,0.03189223,,,,,,,,,,,,,,,,, -26800,0.03383069,0.029602502,,,,,,,,,,,,,,,,, -26900,0.039339058,0.03345758,,,,,,,,,,,,,,,,, -26930,,,0.9923691749572754,0.0243989583104848,0.5442668976545151,0.986559271812439,0.0484216548502445,0.2523562677872928,43793.0,0.98576021194458,0.0513136461377143,0.2501128544462677,43793.0,8656.649072170258,13178.992216587068,8656.649072170258,4520.522451400757,1.071984052658081,0.0 -27000,0.04470867,0.03179097,,,,,,,,,,,,,,,,, -27100,0.03967177,0.030375406,,,,,,,,,,,,,,,,, -27200,0.04134804,0.034475278,,,,,,,,,,,,,,,,, -27300,0.045205228,0.032185238,,,,,,,,,,,,,,,,, -27400,0.03636833,0.030284664,,,,,,,,,,,,,,,,, -27500,0.033686507,0.02984493,,,,,,,,,,,,,,,,, -27600,0.039265938,0.03011774,,,,,,,,,,,,,,,,, -27681,,,0.992576003074646,0.0239151343703269,0.5780123560081694,0.9865665435791016,0.0480653345584869,0.2504831156838596,43793.0,0.9857446551322936,0.0510604418814182,0.2448402843569965,43793.0,8896.665783882141,13541.615369558334,8896.665783882141,4643.076948165894,1.1032321453094482,0.0 -27700,0.035421778,0.031077292,,,,,,,,,,,,,,,,, -27800,0.03986422,0.029795006,,,,,,,,,,,,,,,,, -27900,0.05042524,0.031603217,,,,,,,,,,,,,,,,, -28000,0.043708228,0.03262327,,,,,,,,,,,,,,,,, -28100,0.04637787,0.032342635,,,,,,,,,,,,,,,,, -28200,0.040547766,0.030161934,,,,,,,,,,,,,,,,, -28300,0.038509075,0.029638425,,,,,,,,,,,,,,,,, -28400,0.044981416,0.031144176,,,,,,,,,,,,,,,,, -28428,,,0.9928958415985109,0.022750936448574,0.5812031928977817,0.9866591095924376,0.0483444854617118,0.258367335005388,43793.0,0.9858065247535706,0.0515264011919498,0.2475043142655151,43793.0,9136.74174809456,13902.822404623032,9136.74174809456,4764.155343532562,1.1362247467041016,0.0 -28500,0.035402518,0.030273419,,,,,,,,,,,,,,,,, -28600,0.03983274,0.03089994,,,,,,,,,,,,,,,,, -28700,0.047606625,0.031019524,,,,,,,,,,,,,,,,, -28800,0.04226357,0.031146854,,,,,,,,,,,,,,,,, -28900,0.046390966,0.029319003,,,,,,,,,,,,,,,,, -29000,0.046153124,0.031263437,,,,,,,,,,,,,,,,, -29100,0.043435868,0.030494107,,,,,,,,,,,,,,,,, -29172,,,0.9930545687675476,0.0223341267555952,0.615808792548507,0.9865154027938844,0.0486091189086437,0.248205262006874,43793.0,0.9856435656547546,0.0517445616424083,0.2452964177865293,43793.0,9376.930389642715,14262.695072174072,9376.930389642715,4883.786524772644,1.1681454181671145,0.0 -29200,0.04055125,0.030690674,,,,,,,,,,,,,,,,, -29300,0.050557327,0.030266758,,,,,,,,,,,,,,,,, -29400,0.045185465,0.030250577,,,,,,,,,,,,,,,,, -29500,0.04510053,0.030428693,,,,,,,,,,,,,,,,, -29600,0.04870483,0.030514864,,,,,,,,,,,,,,,,, -29700,0.048829284,0.030199174,,,,,,,,,,,,,,,,, -29800,0.05051511,0.03247656,,,,,,,,,,,,,,,,, -29900,0.043838732,0.031270646,,,,,,,,,,,,,,,,, -29922,,,0.9933610558509828,0.0212800167500972,0.6281966779933524,0.986520290374756,0.0487982630729675,0.2524698596022855,43793.0,0.9856886267662048,0.0519146211445331,0.2413855139103063,43793.0,9617.14767241478,14620.18716263771,9617.14767241478,5001.006668329239,1.2020776271820068,0.0 -30000,0.044609085,0.031572927,,,,,,,,,,,,,,,,, -30100,0.05127194,0.030413968,,,,,,,,,,,,,,,,, -30200,0.056610823,0.031215172,,,,,,,,,,,,,,,,, -30300,0.042344134,0.029973054,,,,,,,,,,,,,,,,, -30400,0.046904076,0.029690541,,,,,,,,,,,,,,,,, -30500,0.052126322,0.03194178,,,,,,,,,,,,,,,,, -30600,0.04902084,0.03142374,,,,,,,,,,,,,,,,, -30664,,,0.9941012263298036,0.0193631164729595,0.6699394839401119,0.9864890575408936,0.0492533333599567,0.2432590919025162,43793.0,0.9856064915657043,0.0523036830127239,0.2375251315964577,43793.0,9857.341331720352,14984.070784330368,9857.341331720352,5124.642949104309,1.2355139255523682,0.0 -30700,0.050462257,0.029234448,,,,,,,,,,,,,,,,, -30800,0.048335046,0.030360987,,,,,,,,,,,,,,,,, -30900,0.047883514,0.030466372,,,,,,,,,,,,,,,,, -31000,0.04943042,0.031133875,,,,,,,,,,,,,,,,, -31100,0.051660165,0.02937498,,,,,,,,,,,,,,,,, -31200,0.05062955,0.03136946,,,,,,,,,,,,,,,,, -31300,0.046960834,0.029139709,,,,,,,,,,,,,,,,, -31400,0.050560407,0.030171046,,,,,,,,,,,,,,,,, -31412,,,0.9941602349281312,0.0192818939685821,0.6668670900993225,0.9864894151687622,0.0489224307239055,0.2506274620268583,43793.0,0.9856364130973816,0.0520030371844768,0.2407133325040985,43793.0,10097.41719198227,15341.210273504255,10097.41719198227,5241.653692960739,1.2681159973144531,0.0 -31500,0.05334119,0.028173178,,,,,,,,,,,,,,,,, -31600,0.04922351,0.031107046,,,,,,,,,,,,,,,,, -31700,0.044456944,0.02715386,,,,,,,,,,,,,,,,, -31800,0.054218903,0.032101456,,,,,,,,,,,,,,,,, -31900,0.04837386,0.031535603,,,,,,,,,,,,,,,,, -32000,0.04611227,0.028371084,,,,,,,,,,,,,,,,, -32100,0.0526958,0.029917112,,,,,,,,,,,,,,,,, -32146,,,0.9937401413917542,0.0202498193830251,0.6502916140923689,0.9865487217903136,0.0496782660484313,0.2499031904886978,43793.0,0.9856595396995544,0.0528297610580921,0.2389180610951756,43793.0,10337.642447710035,15703.304780244827,10337.642447710035,5363.469901323319,1.3005335330963137,0.0 -32200,0.05053887,0.030438943,,,,,,,,,,,,,,,,, -32300,0.05674058,0.032937676,,,,,,,,,,,,,,,,, -32400,0.06898654,0.032900512,,,,,,,,,,,,,,,,, -32500,0.054084763,0.029198673,,,,,,,,,,,,,,,,, -32600,0.051563647,0.031091338,,,,,,,,,,,,,,,,, -32700,0.054621894,0.030644529,,,,,,,,,,,,,,,,, -32800,0.047951803,0.02992608,,,,,,,,,,,,,,,,, -32900,0.0569149,0.030543812,,,,,,,,,,,,,,,,, -32902,,,0.9936198592185974,0.0206041522324085,0.6400930897954951,0.986531674861908,0.0495291985571384,0.2453435833052644,43793.0,0.9856982827186584,0.0526974648237228,0.2426612461020473,43793.0,10577.850949764252,16068.123711824415,10577.850949764252,5488.026214838028,1.334458589553833,0.0 -33000,0.059014257,0.029684106,,,,,,,,,,,,,,,,, -33100,0.04829948,0.029138654,,,,,,,,,,,,,,,,, -33200,0.07015564,0.03102638,,,,,,,,,,,,,,,,, -33300,0.064863056,0.031256225,,,,,,,,,,,,,,,,, -33400,0.06066113,0.029572863,,,,,,,,,,,,,,,,, -33500,0.055346735,0.027253166,,,,,,,,,,,,,,,,, -33600,0.05822426,0.029651517,,,,,,,,,,,,,,,,, -33652,,,0.9934507012367249,0.0209422651678323,0.6310022435507462,0.9864732027053832,0.0497406870126724,0.2459634658972307,43793.0,0.9855508804321288,0.0528176799416542,0.2418377184419295,43793.0,10817.950261116028,16425.47945213318,10817.950261116028,5605.228426933289,1.3684918880462646,0.0 -33700,0.081602186,0.030467918,,,,,,,,,,,,,,,,, -33800,0.057243258,0.0306555,,,,,,,,,,,,,,,,, -33900,0.053377412,0.028444968,,,,,,,,,,,,,,,,, -34000,0.049096506,0.03113742,,,,,,,,,,,,,,,,, -34100,0.06043673,0.029273136,,,,,,,,,,,,,,,,, -34200,0.05480994,0.029559227,,,,,,,,,,,,,,,,, -34300,0.06165821,0.028216967,,,,,,,,,,,,,,,,, -34387,,,0.9934825897216796,0.0207439642399549,0.6385106217173827,0.9863489866256714,0.0502095408737659,0.2464846540457237,43793.0,0.985524356365204,0.0534192509949207,0.2359514915263177,43793.0,11058.12816143036,16786.856697797775,11058.12816143036,5726.372263431549,1.402308702468872,0.0 -34400,0.051622406,0.026754845,,,,,,,,,,,,,,,,, -34500,0.065582484,0.030161627,,,,,,,,,,,,,,,,, -34600,0.062304858,0.029454269,,,,,,,,,,,,,,,,, -34700,0.068828575,0.029394928,,,,,,,,,,,,,,,,, -34800,0.056933638,0.0286374,,,,,,,,,,,,,,,,, -34900,0.059876584,0.028888673,,,,,,,,,,,,,,,,, -35000,0.055445474,0.028965373,,,,,,,,,,,,,,,,, -35100,0.06874684,0.029898282,,,,,,,,,,,,,,,,, -35138,,,0.9936950206756592,0.0203303676098585,0.6362148804038725,0.9863278865814208,0.0502837263047695,0.2430508544160199,43793.0,0.9854733943939208,0.0535249002277851,0.2315473117127911,43793.0,11298.363315105438,17141.203130722046,11298.363315105438,5840.427979707718,1.437326431274414,0.0 -35200,0.053833954,0.026801985,,,,,,,,,,,,,,,,, -35300,0.059287492,0.030433927,,,,,,,,,,,,,,,,, -35400,0.057166617,0.028241618,,,,,,,,,,,,,,,,, -35500,0.070387736,0.030994477,,,,,,,,,,,,,,,,, -35600,0.069406584,0.027031805,,,,,,,,,,,,,,,,, -35700,0.06189213,0.029921645,,,,,,,,,,,,,,,,, -35800,0.05568505,0.027499285,,,,,,,,,,,,,,,,, -35887,,,0.993681788444519,0.019934918731451,0.6535335108823093,0.98631489276886,0.0514221414923667,0.2430296632439309,43793.0,0.9854927659034728,0.0545136667788028,0.2386775963454686,43793.0,11538.582442760468,17502.014199733734,11538.582442760468,5960.965605020523,1.471550226211548,0.0 -35900,0.062381256,0.029602528,,,,,,,,,,,,,,,,, -36000,0.06530128,0.028588062,,,,,,,,,,,,,,,,, -36100,0.06913522,0.028858347,,,,,,,,,,,,,,,,, -36200,0.059979975,0.02899201,,,,,,,,,,,,,,,,, -36300,0.05678974,0.028699407,,,,,,,,,,,,,,,,, -36400,0.065693595,0.028306386,,,,,,,,,,,,,,,,, -36500,0.0691558,0.031161318,,,,,,,,,,,,,,,,, -36600,0.05995297,0.029295195,,,,,,,,,,,,,,,,, -36621,,,0.9945728182792664,0.0176827441900968,0.688483399506173,0.986333966255188,0.051441915333271,0.244130999515316,43793.0,0.9854645133018494,0.0545359924435615,0.2343355525502071,43793.0,11778.559102535248,17865.818506240845,11778.559102535248,6084.737009048462,1.5058512687683103,0.0 -36700,0.05676802,0.027798137,,,,,,,,,,,,,,,,, -36800,0.063754365,0.027563347,,,,,,,,,,,,,,,,, -36900,0.06308863,0.02962742,,,,,,,,,,,,,,,,, -37000,0.06016212,0.027648194,,,,,,,,,,,,,,,,, -37100,0.060909625,0.02695221,,,,,,,,,,,,,,,,, -37200,0.063938394,0.029168012,,,,,,,,,,,,,,,,, -37300,0.09607778,0.030303115,,,,,,,,,,,,,,,,, -37372,,,0.9952541589736938,0.0158701855689287,0.7463612898447716,0.9863250255584716,0.0521195009350776,0.2393038359651531,43793.0,0.9855479598045348,0.0552063584327697,0.235362061924285,43793.0,12018.643469572067,18226.29783797264,12018.643469572067,6205.075866937637,1.5408875942230225,0.0 -37400,0.06847924,0.029260056,,,,,,,,,,,,,,,,, -37500,0.06083018,0.028041707,,,,,,,,,,,,,,,,, -37600,0.066452,0.028495941,,,,,,,,,,,,,,,,, -37700,0.058340747,0.027537601,,,,,,,,,,,,,,,,, -37800,0.07775353,0.028990306,,,,,,,,,,,,,,,,, -37900,0.06421923,0.0269745,,,,,,,,,,,,,,,,, -38000,0.06970971,0.02913537,,,,,,,,,,,,,,,,, -38100,0.06797527,0.028132565,,,,,,,,,,,,,,,,, -38121,,,0.994975447654724,0.016621870920062,0.7227896262201021,0.986448049545288,0.0518963262438774,0.2398712386816698,43793.0,0.9855525493621826,0.0550990290939807,0.2385763605440669,43793.0,12258.72385263443,18581.881452083588,12258.72385263443,6320.522824525833,1.5767641067504885,0.0 -38200,0.07369253,0.028874937,,,,,,,,,,,,,,,,, -38300,0.071115196,0.027286733,,,,,,,,,,,,,,,,, -38400,0.06470241,0.02742684,,,,,,,,,,,,,,,,, -38500,0.06530871,0.02544442,,,,,,,,,,,,,,,,, -38600,0.070833325,0.028259108,,,,,,,,,,,,,,,,, -38700,0.075545095,0.026103519,,,,,,,,,,,,,,,,, -38800,0.063071944,0.026560431,,,,,,,,,,,,,,,,, -38874,,,0.994772493839264,0.0167333744466304,0.714931953560515,0.986291766166687,0.0532053150236606,0.2356094752372903,43793.0,0.9855129718780518,0.0561452880501747,0.2376918945252976,43793.0,12498.797527074814,18939.17649841309,12498.797527074814,6437.69078040123,1.609722137451172,0.0 -38900,0.07720718,0.027196815,,,,,,,,,,,,,,,,, -39000,0.07268668,0.026808495,,,,,,,,,,,,,,,,, -39100,0.06803102,0.026946135,,,,,,,,,,,,,,,,, -39200,0.07214293,0.027666984,,,,,,,,,,,,,,,,, -39300,0.07312157,0.028406456,,,,,,,,,,,,,,,,, -39400,0.077106,0.027043859,,,,,,,,,,,,,,,,, -39500,0.07394126,0.026616063,,,,,,,,,,,,,,,,, -39600,0.076236896,0.02857884,,,,,,,,,,,,,,,,, -39625,,,0.994320273399353,0.0178232826292514,0.705524199624799,0.986319363117218,0.0530956424772739,0.2403834761443907,43793.0,0.9854274988174438,0.0564156621694564,0.2301816212423666,43793.0,12738.750935316086,19294.999740600582,12738.750935316086,6553.50083398819,1.6491947174072266,0.0 -39700,0.082146905,0.026981914,,,,,,,,,,,,,,,,, -39800,0.15305391,0.030165395,,,,,,,,,,,,,,,,, -39900,0.07256932,0.02761526,,,,,,,,,,,,,,,,, -40000,0.07071108,0.027166022,,,,,,,,,,,,,,,,, -40100,0.07030934,0.026736729,,,,,,,,,,,,,,,,, -40200,0.072485626,0.028355766,,,,,,,,,,,,,,,,, -40300,0.06281087,0.026637308,,,,,,,,,,,,,,,,, -40377,,,0.9941571950912476,0.0183549374341964,0.6876441021637686,0.986297845840454,0.0534391738474369,0.2370363956850179,43793.0,0.9854485392570496,0.0566561445593833,0.2315639609968926,43793.0,12978.82855129242,19645.99722838401,12978.82855129242,6664.366796016693,1.6828031539916992,0.0 -40400,0.07351127,0.026923299,,,,,,,,,,,,,,,,, -40500,0.074320234,0.028031148,,,,,,,,,,,,,,,,, -40600,0.07634843,0.028763017,,,,,,,,,,,,,,,,, -40700,0.081170686,0.026347836,,,,,,,,,,,,,,,,, -40800,0.07086652,0.02676453,,,,,,,,,,,,,,,,, -40900,0.070990436,0.028197704,,,,,,,,,,,,,,,,, -41000,0.07959488,0.02638378,,,,,,,,,,,,,,,,, -41100,0.07897372,0.025454612,,,,,,,,,,,,,,,,, -41136,,,0.9942678809165956,0.0182990487664937,0.6918804230446601,0.9859828352928162,0.0525840632617473,0.2357247872052208,43793.0,0.9851621389389038,0.055863220244646,0.2326064034439338,43793.0,13219.0130982399,20003.628672599792,13219.0130982399,6781.759689092636,1.7166087627410889,0.0 -41200,0.07352398,0.025533024,,,,,,,,,,,,,,,,, -41300,0.08000225,0.027719872,,,,,,,,,,,,,,,,, -41400,0.098088175,0.028930737,,,,,,,,,,,,,,,,, -41500,0.080975205,0.027423965,,,,,,,,,,,,,,,,, -41600,0.084855184,0.02610054,,,,,,,,,,,,,,,,, -41700,0.07566962,0.026658911,,,,,,,,,,,,,,,,, -41800,0.07954737,0.026212795,,,,,,,,,,,,,,,,, -41890,,,0.9942914843559264,0.0179653763771057,0.701986704581986,0.9862337112426758,0.05374501273036,0.2348137809159835,43793.0,0.9853798747062684,0.0568219721317291,0.2324982355114241,43793.0,13459.264628887177,20356.678092479706,13459.264628887177,6894.503623247147,1.7502388954162598,0.0 -41900,0.07674712,0.027993603,,,,,,,,,,,,,,,,, -42000,0.078413405,0.026041074,,,,,,,,,,,,,,,,, -42100,0.070883,0.027300578,,,,,,,,,,,,,,,,, -42200,0.07566156,0.027091792,,,,,,,,,,,,,,,,, -42300,0.08215933,0.027526515,,,,,,,,,,,,,,,,, -42400,0.07238214,0.0260107,,,,,,,,,,,,,,,,, -42500,0.078260794,0.025301093,,,,,,,,,,,,,,,,, -42600,0.08238443,0.026600169,,,,,,,,,,,,,,,,, -42636,,,0.9941781163215636,0.0178943015635013,0.6990086529737705,0.9864057898521424,0.0544170401990413,0.2398192574255114,43793.0,0.9854514598846436,0.058204285800457,0.2309237291602419,43793.0,13699.364458560944,20712.2668569088,13699.364458560944,7009.93562078476,1.7863051891326904,0.0 -42700,0.07470922,0.026382845,,,,,,,,,,,,,,,,, -42800,0.08311963,0.026719386,,,,,,,,,,,,,,,,, -42900,0.084144965,0.028447771,,,,,,,,,,,,,,,,, -43000,0.07421604,0.026140308,,,,,,,,,,,,,,,,, -43100,0.10352311,0.027572311,,,,,,,,,,,,,,,,, -43200,0.07524329,0.02659991,,,,,,,,,,,,,,,,, -43300,0.080508344,0.026042921,,,,,,,,,,,,,,,,, -43393,,,0.9946008920669556,0.0168354269117116,0.7236670943979682,0.9863424897193908,0.0547343268990516,0.2367951793593295,43793.0,0.985403060913086,0.0582301169633865,0.2280993431328843,43793.0,13939.344955205916,21068.913420438766,13939.344955205916,7126.545880794525,1.8220069408416748,0.0 -43400,0.09484634,0.026384206,,,,,,,,,,,,,,,,, -43500,0.079214334,0.026844751,,,,,,,,,,,,,,,,, -43600,0.093275234,0.02772149,,,,,,,,,,,,,,,,, -43700,0.0754582,0.024977019,,,,,,,,,,,,,,,,, -43800,0.08458179,0.027570253,,,,,,,,,,,,,,,,, -43900,0.09090638,0.027118789,,,,,,,,,,,,,,,,, -44000,0.10377314,0.02679177,,,,,,,,,,,,,,,,, -44100,0.07874482,0.025769548,,,,,,,,,,,,,,,,, -44147,,,0.9953388571739196,0.0151414508000016,0.7568423681622725,0.9860754013061525,0.054790049791336,0.2288448749028145,43793.0,0.9852429628372192,0.0583256147801876,0.2253432139066692,43793.0,14179.578085184095,21418.554672002792,14179.578085184095,7235.889031648636,1.866541862487793,0.0 -44200,0.07511849,0.025422035,,,,,,,,,,,,,,,,, -44300,0.09172794,0.025806556,,,,,,,,,,,,,,,,, -44400,0.0828691,0.027515624,,,,,,,,,,,,,,,,, -44500,0.088691354,0.026495663,,,,,,,,,,,,,,,,, -44600,0.07987135,0.02513071,,,,,,,,,,,,,,,,, -44700,0.07877775,0.02587164,,,,,,,,,,,,,,,,, -44800,0.09153914,0.026899634,,,,,,,,,,,,,,,,, -44884,,,0.996155858039856,0.0133711593225598,0.7954491680973936,0.986195147037506,0.0549307949841022,0.2399991624461255,43793.0,0.9854169487953186,0.0582167506217956,0.2274132334812646,43793.0,14419.814072847366,21774.352262735367,14419.814072847366,7351.393758773804,1.9024038314819336,0.0 -44900,0.069565915,0.025197318,,,,,,,,,,,,,,,,, -45000,0.107469946,0.027783146,,,,,,,,,,,,,,,,, -45100,0.08388229,0.027366191,,,,,,,,,,,,,,,,, -45200,0.08300631,0.025651995,,,,,,,,,,,,,,,,, -45300,0.09632407,0.0256099,,,,,,,,,,,,,,,,, -45400,0.08608835,0.02570038,,,,,,,,,,,,,,,,, -45500,0.09378766,0.0251899,,,,,,,,,,,,,,,,, -45600,0.085164614,0.02426286,,,,,,,,,,,,,,,,, -45620,,,0.996436893939972,0.0128021948039531,0.7997349957824464,0.986088752746582,0.055894199758768,0.2335008345089164,43793.0,0.9852564930915833,0.0593939758837223,0.2220551116809077,43793.0,14659.792273283005,22128.552355766296,14659.792273283005,7465.55695939064,1.9380853176116943,0.0 -45700,0.09125164,0.025691103,,,,,,,,,,,,,,,,, -45800,0.08179398,0.025114547,,,,,,,,,,,,,,,,, -45900,0.09176662,0.024884108,,,,,,,,,,,,,,,,, -46000,0.08105781,0.025680853,,,,,,,,,,,,,,,,, -46100,0.08210017,0.025718337,,,,,,,,,,,,,,,,, -46200,0.08578084,0.023431536,,,,,,,,,,,,,,,,, -46300,0.07817081,0.026396755,,,,,,,,,,,,,,,,, -46380,,,0.995314359664917,0.0151610001921653,0.7500370780843769,0.9860250353813172,0.0560168400406837,0.2258241513802438,43793.0,0.9850963950157166,0.05946921184659,0.2228981628786851,43793.0,14899.860787391664,22484.813533067703,14899.860787391664,7581.692788124084,1.9741096496582031,0.0 -46400,0.08717627,0.025483787,,,,,,,,,,,,,,,,, -46500,0.086934686,0.02480655,,,,,,,,,,,,,,,,, -46600,0.08026773,0.02511676,,,,,,,,,,,,,,,,, -46700,0.08799622,0.025183065,,,,,,,,,,,,,,,,, -46800,0.08035992,0.024716947,,,,,,,,,,,,,,,,, -46900,0.08105121,0.02514385,,,,,,,,,,,,,,,,, -47000,0.082895085,0.025501821,,,,,,,,,,,,,,,,, -47100,0.0861091,0.0252945,,,,,,,,,,,,,,,,, -47129,,,0.9956495761871338,0.0142398113384842,0.7698141151166427,0.9861208200454712,0.0565310157835483,0.2288986155071729,43793.0,0.9853002429008484,0.0599371455609798,0.2316235445308457,43793.0,15139.825226068497,22837.559860944748,15139.825226068497,7694.41745185852,2.01121473312378,0.0 -47200,0.08554657,0.024658823,,,,,,,,,,,,,,,,, -47300,0.09210719,0.026046928,,,,,,,,,,,,,,,,, -47400,0.08912134,0.024915699,,,,,,,,,,,,,,,,, -47500,0.07800409,0.025453873,,,,,,,,,,,,,,,,, -47600,0.08964256,0.024213357,,,,,,,,,,,,,,,,, -47700,0.095364824,0.02453294,,,,,,,,,,,,,,,,, -47800,0.08737711,0.025019364,,,,,,,,,,,,,,,,, -47879,,,0.9952120780944824,0.015117822214961,0.7573643665429037,0.9860530495643616,0.0564565919339656,0.2272830550792519,43793.0,0.9852320551872252,0.0598737373948097,0.2277907845115211,43793.0,15379.947883367538,23192.9143447876,15379.947883367538,7809.593261241913,2.0465872287750244,0.0 -47900,0.0915774,0.026477892,,,,,,,,,,,,,,,,, -48000,0.07986718,0.023829196,,,,,,,,,,,,,,,,, -48100,0.10211986,0.024739493,,,,,,,,,,,,,,,,, -48200,0.09896483,0.025071776,,,,,,,,,,,,,,,,, -48300,0.094872385,0.025721023,,,,,,,,,,,,,,,,, -48400,0.08189562,0.024032662,,,,,,,,,,,,,,,,, -48500,0.08173989,0.024797091,,,,,,,,,,,,,,,,, -48600,0.09529244,0.025380788,,,,,,,,,,,,,,,,, -48625,,,0.9950177669525146,0.0155489360913634,0.7415567667563101,0.9860798716545104,0.0569931715726852,0.2300238933443485,43793.0,0.9852830171585084,0.0603629611432552,0.2280650358780248,43793.0,15620.005237102509,23545.832807779312,15620.005237102509,7922.388898611069,2.09068250656128,0.0 -48700,0.07913523,0.025005043,,,,,,,,,,,,,,,,, -48800,0.07711375,0.024208233,,,,,,,,,,,,,,,,, -48900,0.09728167,0.027561024,,,,,,,,,,,,,,,,, -49000,0.08616992,0.024763212,,,,,,,,,,,,,,,,, -49100,0.089908674,0.026042936,,,,,,,,,,,,,,,,, -49200,0.08959636,0.02445629,,,,,,,,,,,,,,,,, -49300,0.0992151,0.025274409,,,,,,,,,,,,,,,,, -49368,,,0.9946566820144652,0.0164572186768054,0.7364141469888428,0.9861857891082764,0.057639967650175,0.2242057574321107,43793.0,0.9852412939071656,0.0611497797071933,0.2257478494693963,43793.0,15860.014887571337,23905.479996204376,15860.014887571337,8041.968768358231,2.127811908721924,0.0 -49400,0.087635614,0.024790214,,,,,,,,,,,,,,,,, -49500,0.09397706,0.024651503,,,,,,,,,,,,,,,,, -49600,0.10687463,0.024220435,,,,,,,,,,,,,,,,, -49700,0.0840943,0.02414346,,,,,,,,,,,,,,,,, -49800,0.08580417,0.02448154,,,,,,,,,,,,,,,,, -49900,0.08493313,0.025241103,,,,,,,,,,,,,,,,, -50000,0.08513395,0.024187645,,,,,,,,,,,,,,,,, -50100,0.097203255,0.024409646,,,,,,,,,,,,,,,,, -50121,,,0.9957500696182252,0.0139283929020166,0.7854547276249997,0.9860916137695312,0.0574779324233531,0.2286573693495823,43793.0,0.9851734638214112,0.0609352216124534,0.2285361877579136,43793.0,16100.006301641464,24264.06285619736,16100.006301641464,8160.495388507843,2.1722218990325928,0.0 -50200,0.097806334,0.023637034,,,,,,,,,,,,,,,,, -50300,0.09120374,0.0233372,,,,,,,,,,,,,,,,, -50400,0.09166185,0.022822525,,,,,,,,,,,,,,,,, -50500,0.079593614,0.025337417,,,,,,,,,,,,,,,,, -50600,0.094288416,0.024785385,,,,,,,,,,,,,,,,, -50700,0.08621185,0.02484056,,,,,,,,,,,,,,,,, -50800,0.09310931,0.024264997,,,,,,,,,,,,,,,,, -50869,,,0.99623841047287,0.0125467721372842,0.8122539594282422,0.9861472249031068,0.0581727661192417,0.2258159763884001,43793.0,0.985240876674652,0.0616643354296684,0.2263295055390688,43793.0,16340.236045360563,24617.448274374008,16340.236045360563,8273.594465494156,2.208230495452881,0.0 -50900,0.07966488,0.023101524,,,,,,,,,,,,,,,,, -51000,0.08418196,0.02447747,,,,,,,,,,,,,,,,, -51100,0.0777879,0.023514943,,,,,,,,,,,,,,,,, -51200,0.077449486,0.022287678,,,,,,,,,,,,,,,,, -51300,0.07637907,0.02443909,,,,,,,,,,,,,,,,, -51400,0.08279913,0.02351591,,,,,,,,,,,,,,,,, -51500,0.10051569,0.024540935,,,,,,,,,,,,,,,,, -51600,0.095395595,0.024025336,,,,,,,,,,,,,,,,, -51620,,,0.9967470765113832,0.0114964265376329,0.8367234392536422,0.986141562461853,0.0593289285898208,0.2274897693187031,43793.0,0.9852067828178406,0.0629700496792793,0.2218210113962008,43793.0,16580.56310081482,24970.902344703674,16580.56310081482,8386.665300130844,2.2444097995758057,0.0 -51700,0.09981086,0.023449685,,,,,,,,,,,,,,,,, -51800,0.08156553,0.023109028,,,,,,,,,,,,,,,,, -51900,0.09649786,0.02398552,,,,,,,,,,,,,,,,, -52000,0.09558176,0.02428505,,,,,,,,,,,,,,,,, -52100,0.08942554,0.024020612,,,,,,,,,,,,,,,,, -52200,0.084230535,0.02266952,,,,,,,,,,,,,,,,, -52300,0.104660064,0.023015026,,,,,,,,,,,,,,,,, -52381,,,0.9970282316207886,0.0108794895932078,0.8544680011223145,0.9861334562301636,0.0595301464200019,0.2277542538640574,43793.0,0.9852383732795716,0.0631648600101471,0.223539776319986,43793.0,16820.774089574814,25320.17582321167,16820.774089574814,8495.671817302704,2.2801904678344727,0.0 -52400,0.086249836,0.022793652,,,,,,,,,,,,,,,,, -52500,0.10051536,0.02385584,,,,,,,,,,,,,,,,, -52600,0.14096174,0.023610845,,,,,,,,,,,,,,,,, -52700,0.076738186,0.023502283,,,,,,,,,,,,,,,,, -52800,0.081347264,0.022572262,,,,,,,,,,,,,,,,, -52900,0.09955562,0.025859937,,,,,,,,,,,,,,,,, -53000,0.07799416,0.02243636,,,,,,,,,,,,,,,,, -53100,0.083359286,0.021908367,,,,,,,,,,,,,,,,, -53126,,,0.9972330331802368,0.0105284303426742,0.8540056525446815,0.9861581921577454,0.0593896433711051,0.227870746548491,43793.0,0.9852033853530884,0.0632070749998092,0.2219429456533263,43793.0,17060.759781360626,25673.513555049896,17060.759781360626,8608.965829610825,2.3182740211486816,0.0 -53200,0.08545674,0.022557946,,,,,,,,,,,,,,,,, -53300,0.12083102,0.024232078,,,,,,,,,,,,,,,,, -53400,0.10172915,0.024289472,,,,,,,,,,,,,,,,, -53500,0.08644103,0.023839343,,,,,,,,,,,,,,,,, -53600,0.085603625,0.024239894,,,,,,,,,,,,,,,,, -53700,0.086165085,0.022421747,,,,,,,,,,,,,,,,, -53800,0.08303787,0.023537051,,,,,,,,,,,,,,,,, -53873,,,0.9966502785682678,0.011681037954986,0.835379206205047,0.9860693216323853,0.0602785646915435,0.2249466474666006,43793.0,0.9851911664009094,0.0638005584478378,0.2149521823725836,43793.0,17300.893683433533,26027.95757389069,17300.893683433533,8723.215174674988,2.3571720123291016,0.0 -53900,0.08420102,0.023848783,,,,,,,,,,,,,,,,, -54000,0.102120295,0.024169771,,,,,,,,,,,,,,,,, -54100,0.08715389,0.023523659,,,,,,,,,,,,,,,,, -54200,0.09495542,0.024251584,,,,,,,,,,,,,,,,, -54300,0.091471694,0.021747595,,,,,,,,,,,,,,,,, -54400,0.09503355,0.023236908,,,,,,,,,,,,,,,,, -54500,0.081001274,0.023446413,,,,,,,,,,,,,,,,, -54600,0.09852863,0.023257993,,,,,,,,,,,,,,,,, -54620,,,0.9948896765708924,0.0155108226463198,0.7515258127315602,0.9860928654670716,0.0601283572614192,0.22509716084751,43793.0,0.9852109551429749,0.0639358758926391,0.2158959223624439,43793.0,17540.84679889679,26381.413821458817,17540.84679889679,8836.655643463135,2.39902138710022,0.0 -54700,0.09782594,0.02375496,,,,,,,,,,,,,,,,, -54800,0.09374601,0.022324806,,,,,,,,,,,,,,,,, -54900,0.094350874,0.022061512,,,,,,,,,,,,,,,,, -55000,0.09312722,0.023152893,,,,,,,,,,,,,,,,, -55100,0.08306823,0.02226713,,,,,,,,,,,,,,,,, -55200,0.08290864,0.022322977,,,,,,,,,,,,,,,,, -55300,0.087813824,0.022909448,,,,,,,,,,,,,,,,, -55367,,,0.9954902529716492,0.0140686891973018,0.7815939752419798,0.9859552383422852,0.0598899014294147,0.2264501280562338,43793.0,0.985027313232422,0.0636545941233635,0.2154933151832759,43793.0,17780.836613178253,26730.361434221268,17780.836613178253,8945.55486869812,2.4369421005249023,0.0 -55400,0.08207027,0.021694677,,,,,,,,,,,,,,,,, -55500,0.08713838,0.02412149,,,,,,,,,,,,,,,,, -55600,0.08350408,0.022846343,,,,,,,,,,,,,,,,, -55700,0.08152761,0.022125527,,,,,,,,,,,,,,,,, -55800,0.09505553,0.023573104,,,,,,,,,,,,,,,,, -55900,0.08921535,0.021943148,,,,,,,,,,,,,,,,, -56000,0.09410923,0.024301922,,,,,,,,,,,,,,,,, -56100,0.079573706,0.022032931,,,,,,,,,,,,,,,,, -56126,,,0.9959081411361694,0.0127335209399461,0.8261949531777921,0.986108660697937,0.0612089298665523,0.2227735685321576,43793.0,0.9852176904678344,0.0648197308182716,0.2176275776241821,43793.0,18021.015577316284,27084.50729203224,18021.015577316284,9059.463624238968,2.475022077560425,0.0 -56200,0.089596905,0.023203868,,,,,,,,,,,,,,,,, -56300,0.09437663,0.022093616,,,,,,,,,,,,,,,,, -56400,0.0776937,0.021611018,,,,,,,,,,,,,,,,, -56500,0.104470275,0.023929205,,,,,,,,,,,,,,,,, -56600,0.08986239,0.021840785,,,,,,,,,,,,,,,,, -56700,0.08366497,0.022377048,,,,,,,,,,,,,,,,, -56800,0.06712424,0.020538712,,,,,,,,,,,,,,,,, -56866,,,0.995477259159088,0.0137623539194464,0.8133653761740869,0.9861249327659608,0.0618376545608043,0.2242386545209901,43793.0,0.985187828540802,0.0656175911426544,0.2167473575816602,43793.0,18261.217477321625,27437.100608110428,18261.217477321625,9171.793118715286,2.5136139392852783,0.0 -56900,0.082542166,0.02212639,,,,,,,,,,,,,,,,, -57000,0.08734202,0.02193411,,,,,,,,,,,,,,,,, -57100,0.09662248,0.022642277,,,,,,,,,,,,,,,,, -57200,0.07827202,0.021583844,,,,,,,,,,,,,,,,, -57300,0.08814766,0.021673435,,,,,,,,,,,,,,,,, -57400,0.10474334,0.020977153,,,,,,,,,,,,,,,,, -57500,0.08637113,0.022516236,,,,,,,,,,,,,,,,, -57545,,,,,,,,,,,,,,18477.311544656754,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 134cdaa27..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -114.73655438423155,0.0,11.743421077728271,1,0,11.743421077728271,0.5256850719451904,0.7376683354377747,0.0260580168197227,43793,126.48002123832704,0.5291368365287781,0.7363722324371338,0.0218029747063758,0.5270814895629883,0.737440824508667,0.0240470210742928,43793 -231.53080677986145,0.0203430652618408,252.02333188056943,749,0,252.02333188056943,0.983142077922821,0.0799547135829925,0.0433793901311975,43793,483.5957524776459,0.9866331815719604,0.0689497143030166,0.0381990878470789,0.9841179251670836,0.0770890638232231,0.0413221289476577,43793 -345.62128949165344,0.0478668212890625,492.2781195640564,1499,0,492.2781195640564,0.9835198521614076,0.0610214844346046,0.0858427855891884,43793,837.9886548519135,0.9871978163719176,0.0483209080994129,0.0859351007741268,0.984529972076416,0.0576369874179363,0.0872120753564924,43793 -466.60114884376526,0.0754346847534179,732.3831737041473,2248,0,732.3831737041473,0.9840531349182128,0.0561034195125103,0.1367317304202627,43793,1199.1216876506803,0.98789244890213,0.0435101650655269,0.1318084182003098,0.9850869178771972,0.0531092658638954,0.1295160089173986,43793 -575.1102995872498,0.1052591800689697,972.3318812847136,2997,0,972.3318812847136,0.9842986464500428,0.053968284279108,0.1536918482128193,43793,1547.6310951709747,0.9880146980285645,0.0421840101480484,0.1713988685245202,0.9852290153503418,0.0512479990720748,0.154959509498418,43793 -689.58882188797,0.1334693431854248,1212.3737390041351,3750,0,1212.3737390041351,0.9845029711723328,0.0533266179263591,0.1682972779868397,43793,1902.1999762058256,0.988412380218506,0.0405141748487949,0.1719516109836918,0.9853410124778748,0.0503965727984905,0.1698219331341683,43793 -808.075345993042,0.1630210876464843,1452.4729554653168,4503,0,1452.4729554653168,0.9845539331436156,0.0520947463810443,0.1769383818230964,43793,2260.835620880127,0.9884085059165956,0.0401086173951625,0.2116189163327401,0.9854055643081664,0.0493852011859416,0.1808792961795491,43793 -924.7717523574828,0.190324068069458,1692.5093190670011,5253,0,1692.5093190670011,0.9849397540092468,0.0509831346571445,0.1961209567121776,43793,2617.616218805313,0.9887412786483764,0.0384508743882179,0.2200065332386008,0.9858241081237792,0.048249889165163,0.1957211458247129,43793 -1036.267024755478,0.2197639942169189,1932.4763493537905,6003,0,1932.4763493537905,0.9850517511367798,0.0505339317023754,0.2111077347749315,43793,2969.128954410553,0.9889976382255554,0.0376361981034278,0.2488204824404136,0.9858983755111694,0.0479212068021297,0.2057879012354354,43793 -1149.54962849617,0.2551395893096924,2172.550704717636,6755,0,2172.550704717636,0.9851566553115844,0.0493510365486145,0.2218740276827641,43793,3322.5414004325867,0.9892213940620422,0.0366696007549762,0.2634163786222182,0.986061990261078,0.0467621460556983,0.2146012390955433,43793 -1267.2843120098114,0.2840569019317627,2412.738084077835,7516,0,2412.738084077835,0.9853032231330872,0.0493544600903987,0.2166553914108559,43793,3680.512838125229,0.989296853542328,0.0359546542167663,0.278105170758098,0.9860575199127196,0.0468306243419647,0.2125293204426107,43793 -1377.429646730423,0.3117320537567138,2652.844264984131,8273,0,2652.844264984131,0.9853959083557128,0.0489659383893013,0.2248276800619779,43793,4030.8121387958527,0.9894575476646424,0.035493679344654,0.280584196580796,0.9862965941429138,0.0462246499955654,0.2248349779677364,43793 -1489.488579750061,0.3412954807281494,2892.8968245983124,9024,0,2892.8968245983124,0.9855276942253112,0.0486106649041175,0.2302574675518472,43793,4382.973264217377,0.9894021153450012,0.0356401763856411,0.2816182237666848,0.9863436818122864,0.0458505451679229,0.2299495299738932,43793 -1602.8503172397614,0.3689990043640136,3132.904645681381,9776,0,3132.904645681381,0.985284686088562,0.0488116517663002,0.2275440479991647,43793,4736.391147851944,0.989449679851532,0.0354800149798393,0.2902229885165782,0.9861695766448976,0.0460140667855739,0.228143906935731,43793 -1719.3369302749634,0.4020838737487793,3373.1482582092285,10529,0,3373.1482582092285,0.985558032989502,0.0482427552342414,0.232955356440201,43793,5093.1759622097015,0.9898571372032166,0.0340229794383049,0.3148685575602764,0.9864293336868286,0.0454701818525791,0.2401039852908579,43793 -1828.657984495163,0.4322450160980224,3613.232980489731,11283,0,3613.232980489731,0.9854072332382202,0.0483103096485137,0.2369131371307541,43793,5442.6319761276245,0.989729344844818,0.0340922735631465,0.316281491512823,0.9861927032470704,0.0455793216824531,0.2445057253378528,43793 -1943.490604162216,0.4619700908660888,3853.240893602371,12043,0,3853.240893602371,0.985649049282074,0.0477396845817565,0.2417897996837219,43793,5797.5224277973175,0.9900323748588562,0.0331622846424579,0.341318229949142,0.986455738544464,0.045129720121622,0.2422295153032972,43793 -2055.456913471222,0.8892319202423096,4092.884313106537,12792,0,4092.884313106537,0.9857446551322936,0.047577828168869,0.2455561227362303,43793,6149.5794270038605,0.990278661251068,0.0320963263511657,0.3793567029038292,0.9866027235984802,0.0446693301200866,0.2507951353642788,43793 -2166.441266775131,0.9181814193725586,4333.080197811127,13540,0,4333.080197811127,0.9855934381484984,0.0475713983178138,0.2468264667660303,43793,6500.808378696442,0.9903260469436646,0.0318269245326519,0.3630988382990801,0.986552357673645,0.0445600450038909,0.2593600780039394,43793 -2280.971666574478,0.9479734897613524,4573.193847417831,14285,0,4573.193847417831,0.9856544733047484,0.0473392717540264,0.2469929637029501,43793,6855.502467632294,0.990658700466156,0.0308063067495822,0.3933822209818916,0.9864890575408936,0.0447542667388916,0.2566857614723372,43793 -2391.64150595665,0.9773366451263428,4813.296638011932,15030,0,4813.296638011932,0.9856852293014526,0.0476922467350959,0.2472861083323727,43793,7206.325021743774,0.99084734916687,0.030131122097373,0.3993562358528172,0.9865454435348512,0.0447263419628143,0.2587401186460871,43793 -2501.61195063591,1.0068624019622805,5053.413534879684,15774,0,5053.413534879684,0.985791802406311,0.0475449040532112,0.2527092747568265,43793,7556.462839126587,0.990815281867981,0.0303503256291151,0.3971719559810763,0.9866440892219543,0.0446596071124076,0.265979384637311,43793 -2609.4342000484467,1.0378234386444092,5293.414917469025,16528,0,5293.414917469025,0.9855433106422424,0.0480347089469432,0.2441168465974053,43793,7904.33753156662,0.990627944469452,0.0308685638010501,0.3913606469411551,0.9864423274993896,0.0450369007885456,0.2502382753996813,43793 -2719.6714749336243,1.0690538883209229,5533.490139722824,17281,0,5533.490139722824,0.9857496619224548,0.0472060106694698,0.2529896059917935,43793,8254.701673984528,0.9905187487602234,0.0311160124838352,0.3874007734505545,0.9865190982818604,0.0445390082895755,0.2598120879231033,43793 -2833.4104483127594,1.101609468460083,5773.559222221375,18027,0,5773.559222221375,0.9858617186546326,0.0472342632710933,0.2590980910827544,43793,8608.562299251556,0.990691065788269,0.0305870901793241,0.3921542590010484,0.9866855144500732,0.0443942174315452,0.2646887457085707,43793 -2942.844471931457,1.133225440979004,6013.634890794754,18772,0,6013.634890794754,0.9856511354446412,0.048032097518444,0.2462150144339345,43793,8958.123777866364,0.990775227546692,0.0302655268460512,0.3949178652503551,0.98653244972229,0.0453040190041065,0.2542964945301081,43793 -3053.710742712021,1.1633639335632324,6253.68617773056,19524,0,6253.68617773056,0.9857951998710632,0.0473557077348232,0.2504987704034213,43793,9309.091790914536,0.990888774394989,0.0298353638499975,0.4150541904359339,0.9865884780883788,0.044611282646656,0.257674980871371,43793 -3162.178370714188,1.1930761337280271,6493.699534893036,20281,0,6493.699534893036,0.985710084438324,0.0476002022624015,0.2503044683984956,43793,9657.622642278671,0.99105304479599,0.0290999766439199,0.4372819758899958,0.986669659614563,0.0448590330779552,0.2627159908102456,43793 -3271.0463457107544,1.224475383758545,6733.768709421158,21042,0,6733.768709421158,0.9857838153839112,0.0475145317614078,0.2609902777836348,43793,10006.611500024796,0.9911764860153198,0.0285094995051622,0.4408113644699721,0.9867488145828248,0.044682178646326,0.2665122163037616,43793 -3381.8595008850098,1.2595610618591309,6974.013297080994,21776,0,6974.013297080994,0.985862135887146,0.0472691245377063,0.2573478156117922,43793,10357.728278398514,0.991248607635498,0.0282602701336145,0.4538668844774671,0.9866753816604614,0.0444318726658821,0.2672081493576128,43793 -3493.4931180477142,1.289865255355835,7213.988672018051,22529,0,7213.988672018051,0.9856717586517334,0.047910563647747,0.2484504127180709,43793,10709.38787341118,0.9914287328720092,0.0280359983444213,0.4523613957712486,0.9865572452545166,0.0449055433273315,0.2573994457154161,43793 -3604.5464034080505,1.3261759281158447,7453.943880081177,23263,0,7453.943880081177,0.9859034419059752,0.0473452322185039,0.2580197824040367,43793,11060.453907966614,0.9912769198417664,0.0283569395542144,0.4405505491753948,0.986663579940796,0.0446783117949962,0.2679886704370666,43793 -3712.1588644981375,1.3611788749694824,7694.173835515976,23997,0,7694.173835515976,0.9859076142311096,0.0473978109657764,0.2571818691088361,43793,11408.35333752632,0.9911851286888124,0.0288587547838687,0.4346443075859728,0.9867857694625854,0.0445174388587474,0.2709604029205279,43793 -3818.455191135408,1.3941354751586914,7934.371356964111,24753,0,7934.371356964111,0.9857720136642456,0.0477791726589202,0.2502596400985658,43793,11754.900280475616,0.9910281896591188,0.0290977265685796,0.4323543772737795,0.9866266250610352,0.0450372099876403,0.2616482551665179,43793 -3924.865197896957,1.4252452850341797,8174.356585264206,25506,0,8174.356585264206,0.985776662826538,0.048466894775629,0.2461963545937001,43793,12101.346946954727,0.990942358970642,0.0292593203485012,0.430151878859003,0.986642062664032,0.0455658473074436,0.257154364719039,43793 -4037.365210533142,1.4570987224578855,8414.532056331635,26254,0,8414.532056331635,0.9858777523040771,0.048056062310934,0.2581640623044151,43793,12454.075824022291,0.9912765026092528,0.0283831879496574,0.4462405974678892,0.9867240786552428,0.0452526807785034,0.2645874325719716,43793 -4140.939670324326,1.4889216423034668,8654.747866868973,27008,0,8654.747866868973,0.985755980014801,0.0479071885347366,0.2624967264405229,43793,12797.918403863909,0.99150550365448,0.0274326615035533,0.4741881005743563,0.9866209626197816,0.0449338443577289,0.2674430369710101,43793 -4246.946130990982,1.519975185394287,8894.826835393906,27762,0,8894.826835393906,0.9856725931167604,0.0478430762887001,0.2581436799319148,43793,13144.054966926577,0.9915337562561036,0.0273089949041605,0.4780035881167493,0.9865288138389589,0.044875830411911,0.263436139508458,43793 -4354.20730304718,1.55245041847229,9134.948380470276,28519,0,9134.948380470276,0.9858773350715636,0.0482346639037132,0.2525917458842993,43793,13491.490109682083,0.9918634295463562,0.0262583419680595,0.4980736352194875,0.9867614507675172,0.0451503284275531,0.2686315066777437,43793 -4460.907576322556,1.5855414867401123,9375.02670264244,29267,0,9375.02670264244,0.9858684539794922,0.0482171401381492,0.25379463137131,43793,13838.322686195374,0.9919530749320984,0.0258277095854282,0.5062438255238915,0.9866485595703124,0.0451736338436603,0.2622896558243855,43793 -4564.600377559662,1.6183860301971436,9615.066091775894,30027,0,9615.066091775894,0.98581200838089,0.0479614846408367,0.2577138416244921,43793,14182.107945919037,0.9917265176773072,0.0266038626432418,0.4825484781293242,0.9867070317268372,0.0451933033764362,0.2677666322082109,43793 -4672.233599424362,1.6509528160095217,9855.15208029747,30789,0,9855.15208029747,0.9856843948364258,0.0481309331953525,0.2610056338665237,43793,14529.88004231453,0.9917187094688416,0.0268257912248373,0.4753152896559408,0.9866449236869812,0.0451675727963447,0.2709646793278654,43793 -4776.200966119766,1.6840364933013916,10095.158047437668,31549,0,10095.158047437668,0.9858478307724,0.0483252517879009,0.2621210626835716,43793,14873.906448364258,0.9915143251419068,0.0273069459944963,0.4664824860010296,0.9866051077842712,0.0454580970108509,0.2700092224648855,43793 -4882.224649429321,1.7169504165649414,10335.340276241302,32302,0,10335.340276241302,0.9857808351516724,0.0479646921157836,0.2607573388347303,43793,15220.165620326996,0.991627275943756,0.0269815512001514,0.4847153296311458,0.9866538643836976,0.0450129844248294,0.2784450487358846,43793 -4989.512982130051,1.749946117401123,10575.47698545456,33058,0,10575.47698545456,0.9859063625335692,0.0485273450613021,0.2565631284387224,43793,15567.64379477501,0.9915615320205688,0.026972159743309,0.4788424207599614,0.9867366552352904,0.0455386377871036,0.2693432926444069,43793 -5100.989178657532,1.7831127643585205,10815.58456158638,33804,0,10815.58456158638,0.9856860637664796,0.0485675148665905,0.2549646396946223,43793,15919.281561613085,0.991463303565979,0.0272666439414024,0.4701334986106348,0.986585259437561,0.0456389114260673,0.2751634706236099,43793 -5206.580208301544,1.8159537315368648,11055.76077890396,34562,0,11055.76077890396,0.9858149886131288,0.0488108955323696,0.2672977632378215,43793,16265.101999282835,0.9917646646499634,0.0262292195111513,0.4959750658599631,0.9866027235984802,0.0459834299981594,0.2699241965319351,43793 -5313.920674800873,1.8492765426635744,11295.954808950424,35322,0,11295.954808950424,0.9858158230781556,0.0485912635922431,0.2602659377411682,43793,16612.69092822075,0.992120325565338,0.0252374149858951,0.5261174470746846,0.9866863489151,0.0453717298805713,0.2695367499370574,43793 -5417.184643983841,1.882709264755249,11535.98846077919,36076,0,11535.98846077919,0.9856621026992798,0.0485945083200931,0.2601393911286229,43793,16956.04179906845,0.9921783804893494,0.0249777026474475,0.5172506535598527,0.9864655137062072,0.0458004400134086,0.2712302921497364,43793 -5521.864846467972,1.9163761138916016,11776.233745574951,36835,0,11776.233745574951,0.985743761062622,0.0490642413496971,0.2600800066140546,43793,17301.020915985107,0.9925355911254884,0.0236619692295789,0.5555088251116154,0.9866177439689636,0.0460899136960506,0.2748257784845323,43793 -5633.303159475327,1.9569110870361328,12016.22732925415,37588,0,12016.22732925415,0.985876441001892,0.0493538789451122,0.2592533850321455,43793,17652.513543844223,0.9923935532569884,0.0240364968776702,0.550742001196731,0.986711084842682,0.0462623983621597,0.2691423838406784,43793 -5744.66099023819,1.9948585033416748,12256.457571268082,38335,0,12256.457571268082,0.985694706439972,0.0492275729775428,0.264713702862164,43793,18004.159906864166,0.9923391938209534,0.0244967173784971,0.5252410837212944,0.986572265625,0.0459901019930839,0.2755666996524653,43793 -5847.613451719284,2.028700351715088,12496.689729452131,39088,0,12496.689729452131,0.985687792301178,0.04935147985816,0.2622586007059652,43793,18347.39897227288,0.9921321868896484,0.0250871740281581,0.5257421666684701,0.9865466952323914,0.0463662147521972,0.2729468275313599,43793 -5955.988241195679,2.06325101852417,12736.806010484695,39848,0,12736.806010484695,0.985620379447937,0.0494289062917232,0.2571113858861016,43793,18695.944691181183,0.992060124874115,0.0253172293305397,0.5124504340697817,0.9864711761474608,0.0464315824210643,0.2747143159924758,43793 -6058.107916593552,2.097200632095337,12977.04095864296,40588,0,12977.04095864296,0.9857370257377625,0.0491687096655368,0.2609491486758925,43793,19038.35589647293,0.9921411871910096,0.0250214058905839,0.5354935396143286,0.9864374995231628,0.0461788550019264,0.2703810326860346,43793 -6166.109060764313,2.139084815979004,13217.228226184843,41348,0,13217.228226184843,0.9857257008552552,0.0491464734077453,0.2626485602912387,43793,19386.606401205063,0.992215633392334,0.0246634781360626,0.5279078039982317,0.9865621328353882,0.046393159776926,0.2734804765352344,43793 -6276.40532040596,2.1731395721435547,13457.440511703491,42093,0,13457.440511703491,0.9857661128044128,0.0498609170317649,0.2591178976031317,43793,19737.170392751694,0.9923676252365112,0.0241897907108068,0.5447008755982579,0.9866648316383362,0.0466361306607723,0.2701977304227742,43793 -6388.139190912247,2.2093663215637207,13697.419059038162,42853,0,13697.419059038162,0.985605239868164,0.0498588271439075,0.264570208509632,43793,20088.93888497353,0.9926583170890808,0.0231785699725151,0.5682554598312384,0.9865312576293944,0.0467051304876804,0.2762466894173453,43793 -6490.929327249527,2.244922876358032,13937.618394374847,43608,0,13937.618394374847,0.9856485724449158,0.0495678298175334,0.2643410263914481,43793,20431.984637498856,0.9930390119552612,0.0220589619129896,0.610813742315912,0.986449658870697,0.0466724410653114,0.277682788359816,43793 -6595.918229103088,2.280789852142334,14177.737513303757,44364,0,14177.737513303757,0.9856389164924622,0.0498519428074359,0.2585699227760166,43793,20777.149724960327,0.9931018352508544,0.0217983778566122,0.5935135327960487,0.9863875508308412,0.0471132732927799,0.2702936182597056,43793 -6700.786629199982,2.320936918258667,14417.723826646805,45119,0,14417.723826646805,0.9856823086738586,0.0505120158195495,0.2607793049964524,43793,21122.0652821064,0.9935007691383362,0.0207743290811777,0.6282470077284659,0.986407458782196,0.0473083592951297,0.2746351986985576,43793 -6803.560876607895,2.355963468551636,14657.829684019089,45872,0,14657.829684019089,0.9856418371200562,0.0508385933935642,0.2581153166498968,43793,21465.00151848793,0.993179976940155,0.021613860502839,0.5867398796278861,0.9865272045135498,0.0474358759820461,0.270803415758193,43793 -6907.516575574875,2.390967845916748,14897.901216506958,46624,0,14897.901216506958,0.9857025146484376,0.051041230559349,0.2624502477888986,43793,21809.08435201645,0.9928723573684692,0.0224040485918521,0.5792297219554194,0.9865121841430664,0.0478705167770385,0.2754371891996536,43793 -7011.414033174515,2.427185297012329,15137.97015786171,47381,0,15137.97015786171,0.9857825636863708,0.0509413219988346,0.2635603124742287,43793,22153.10691165924,0.9928503632545472,0.0223389249294996,0.5829663625984829,0.9865494966506958,0.0480351597070694,0.2717624907753252,43793 -7121.453207015991,2.461944580078125,15378.18701696396,48134,0,15378.18701696396,0.9855150580406188,0.0513287670910358,0.2564390795803031,43793,22503.418353796005,0.9928312301635742,0.0224734991788864,0.581797263235984,0.9864045977592468,0.0480904653668403,0.2763301327778938,43793 -7231.157967567444,2.4997594356536865,15618.224890708923,48888,0,15618.224890708923,0.9853625893592834,0.0514854378998279,0.255840481288772,43793,22853.219148159027,0.9929319024086,0.022101666778326,0.5932613878542279,0.9862564206123352,0.0482456870377063,0.2651477654096575,43793 -7335.19739151001,2.960538148880005,15857.749560117722,49643,0,15857.749560117722,0.9856106638908386,0.0522529184818267,0.2538569310413312,43793,23197.264449834824,0.9931454062461852,0.0213271249085664,0.5954668398015054,0.9864569902420044,0.0487939976155757,0.2757709988394968,43793 -7444.359100818634,2.996345281600952,16097.697563171389,50393,0,16097.697563171389,0.985623300075531,0.0526586323976516,0.2555738823624148,43793,23546.42974758148,0.9933032393455504,0.0206730123609304,0.6219227461660122,0.986491084098816,0.0492160953581333,0.271150679389319,43793 -7546.4050052165985,3.041268825531006,16337.699612855911,51151,0,16337.699612855911,0.9853950142860411,0.0527043342590332,0.2526182830014674,43793,23888.542701005936,0.9937353134155272,0.0196653474122285,0.649720284927805,0.9862414002418518,0.0492569468915462,0.2753046787889687,43793 -7654.490574836731,3.076837539672852,16577.81442141533,51910,0,16577.81442141533,0.98554664850235,0.0531106442213058,0.2518272298354042,43793,24236.79912090301,0.9942708015441896,0.0178426392376422,0.6942757210473101,0.9864622354507446,0.0494240075349807,0.2712660328003022,43793 -7763.898961544037,3.1142618656158447,16817.855093955994,52662,0,16817.855093955994,0.9852699637413024,0.05378008633852,0.2500704661699077,43793,24586.30546784401,0.994036078453064,0.018522597849369,0.6743744810593322,0.9860635995864868,0.0501335337758064,0.266329344424858,43793 -7870.911877632141,3.1511764526367188,17057.9955971241,53406,0,17057.9955971241,0.985526442527771,0.0545035935938358,0.2508009860190624,43793,24933.515946626663,0.9941920042037964,0.0179198160767555,0.6768201170825074,0.986491858959198,0.050784520804882,0.2638165742392505,43793 -7978.640509128571,3.1888413429260254,17298.101543664932,54155,0,17298.101543664932,0.9853095412254332,0.0547344386577606,0.2534831283282912,43793,25281.40832829476,0.9939531683921814,0.0188170745968818,0.660413799421119,0.9861224889755248,0.0512725599110126,0.2609240909249675,43793 -8083.218180656433,3.224980354309082,17538.212922811508,54910,0,17538.212922811508,0.9854388236999512,0.0555348582565784,0.2517655666803668,43793,25626.154417037964,0.9937155842781068,0.0192698650062084,0.6451781842926773,0.9864431619644164,0.0515469796955585,0.263833178461923,43793 -8189.279809951782,3.261626958847046,17778.314692020416,55667,0,17778.314692020416,0.9853293299674988,0.0559036470949649,0.2508333407109575,43793,25972.37487721443,0.9937472343444824,0.0190456174314022,0.666710952563394,0.9862653613090516,0.0520232208073139,0.2650312759193009,43793 -8297.278610229492,3.2980382442474365,18018.523869276047,56428,0,18018.523869276047,0.9853870272636414,0.0560189709067344,0.2533901667486278,43793,26320.63973712921,0.9937534928321838,0.0190392918884754,0.653744196203391,0.9862361550331116,0.0522097609937191,0.2656088692381051,43793 -8407.633393764496,3.338200807571411,18258.58660340309,57170,0,18258.58660340309,0.9852640628814697,0.05669456720352173,0.24714238903356686,43793,26671.117931365967,0.993995189666748,0.018129203468561172,0.6793042671348181,0.9860672354698181,0.05306480824947357,0.25908438815895707,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/measurements.csv deleted file mode 100644 index 7f315d38b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/measurements.csv +++ /dev/null @@ -1,658 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3069103,0.7360678,,,,,,,,,,,,,,,,, -1,,,0.5291368365287781,0.7363722324371338,0.0218029747063758,0.5270814895629883,0.737440824508667,0.0240470210742928,43793.0,0.5256850719451904,0.7376683354377747,0.0260580168197227,43793.0,11.743421077728271,126.48002123832704,11.743421077728271,114.73655438423155,0.0,0.0 -100,0.5977209,0.45140156,,,,,,,,,,,,,,,,, -200,0.36484256,0.33138102,,,,,,,,,,,,,,,,, -300,0.26505953,0.23799436,,,,,,,,,,,,,,,,, -400,0.1741102,0.16270784,,,,,,,,,,,,,,,,, -500,0.10657344,0.117706284,,,,,,,,,,,,,,,,, -600,0.067157336,0.08904933,,,,,,,,,,,,,,,,, -700,0.045562927,0.070899606,,,,,,,,,,,,,,,,, -749,,,0.9866331815719604,0.0689497143030166,0.0381990878470789,0.9841179251670836,0.0770890638232231,0.0413221289476577,43793.0,0.983142077922821,0.0799547135829925,0.0433793901311975,43793.0,252.02333188056943,483.5957524776459,252.02333188056943,231.53080677986145,0.0203430652618408,0.0 -800,0.1353868,0.069817774,,,,,,,,,,,,,,,,, -900,0.56478304,0.06056622,,,,,,,,,,,,,,,,, -1000,0.16200195,0.05518084,,,,,,,,,,,,,,,,, -1100,0.2109137,0.04945036,,,,,,,,,,,,,,,,, -1200,0.18658803,0.052263677,,,,,,,,,,,,,,,,, -1300,0.23013349,0.04766509,,,,,,,,,,,,,,,,, -1400,0.15933093,0.052277435,,,,,,,,,,,,,,,,, -1499,,,0.9871978163719176,0.0483209080994129,0.0859351007741268,0.984529972076416,0.0576369874179363,0.0872120753564924,43793.0,0.9835198521614076,0.0610214844346046,0.0858427855891884,43793.0,492.2781195640564,837.9886548519135,492.2781195640564,345.62128949165344,0.0478668212890625,0.0 -1500,0.26642793,0.057241432,,,,,,,,,,,,,,,,, -1600,0.30220777,0.050244052,,,,,,,,,,,,,,,,, -1700,0.17959517,0.047968682,,,,,,,,,,,,,,,,, -1800,0.10014963,0.050227743,,,,,,,,,,,,,,,,, -1900,0.08796599,0.0507884,,,,,,,,,,,,,,,,, -2000,0.121838704,0.04525522,,,,,,,,,,,,,,,,, -2100,0.24418488,0.04629742,,,,,,,,,,,,,,,,, -2200,0.14637055,0.045533843,,,,,,,,,,,,,,,,, -2248,,,0.98789244890213,0.0435101650655269,0.1318084182003098,0.9850869178771972,0.0531092658638954,0.1295160089173986,43793.0,0.9840531349182128,0.0561034195125103,0.1367317304202627,43793.0,732.3831737041473,1199.1216876506803,732.3831737041473,466.60114884376526,0.0754346847534179,0.0 -2300,0.09934877,0.04188651,,,,,,,,,,,,,,,,, -2400,0.08451623,0.044854213,,,,,,,,,,,,,,,,, -2500,0.18704475,0.046695072,,,,,,,,,,,,,,,,, -2600,0.090265036,0.039685935,,,,,,,,,,,,,,,,, -2700,0.13227062,0.04605916,,,,,,,,,,,,,,,,, -2800,0.10641513,0.045020092,,,,,,,,,,,,,,,,, -2900,0.11909342,0.041570775,,,,,,,,,,,,,,,,, -2997,,,0.9880146980285645,0.0421840101480484,0.1713988685245202,0.9852290153503418,0.0512479990720748,0.154959509498418,43793.0,0.9842986464500428,0.053968284279108,0.1536918482128193,43793.0,972.3318812847136,1547.6310951709747,972.3318812847136,575.1102995872498,0.1052591800689697,0.0 -3000,0.059398588,0.03824162,,,,,,,,,,,,,,,,, -3100,0.08849929,0.044341642,,,,,,,,,,,,,,,,, -3200,0.081638806,0.038783696,,,,,,,,,,,,,,,,, -3300,0.077074856,0.041007847,,,,,,,,,,,,,,,,, -3400,0.12794818,0.03825378,,,,,,,,,,,,,,,,, -3500,0.075508274,0.044554085,,,,,,,,,,,,,,,,, -3600,0.14333622,0.04283009,,,,,,,,,,,,,,,,, -3700,0.10270188,0.037887886,,,,,,,,,,,,,,,,, -3750,,,0.988412380218506,0.0405141748487949,0.1719516109836918,0.9853410124778748,0.0503965727984905,0.1698219331341683,43793.0,0.9845029711723328,0.0533266179263591,0.1682972779868397,43793.0,1212.3737390041351,1902.1999762058256,1212.3737390041351,689.58882188797,0.1334693431854248,0.0 -3800,0.06476213,0.03866514,,,,,,,,,,,,,,,,, -3900,0.12148537,0.0415388,,,,,,,,,,,,,,,,, -4000,0.06759086,0.042332932,,,,,,,,,,,,,,,,, -4100,0.06940226,0.036519628,,,,,,,,,,,,,,,,, -4200,0.14266683,0.039879885,,,,,,,,,,,,,,,,, -4300,0.046915136,0.03947696,,,,,,,,,,,,,,,,, -4400,0.057575814,0.04019922,,,,,,,,,,,,,,,,, -4500,0.044264693,0.034529828,,,,,,,,,,,,,,,,, -4503,,,0.9884085059165956,0.0401086173951625,0.2116189163327401,0.9854055643081664,0.0493852011859416,0.1808792961795491,43793.0,0.9845539331436156,0.0520947463810443,0.1769383818230964,43793.0,1452.4729554653168,2260.835620880127,1452.4729554653168,808.075345993042,0.1630210876464843,0.0 -4600,0.056426883,0.04037038,,,,,,,,,,,,,,,,, -4700,0.114991665,0.04685814,,,,,,,,,,,,,,,,, -4800,0.05920896,0.033751804,,,,,,,,,,,,,,,,, -4900,0.06187173,0.039364677,,,,,,,,,,,,,,,,, -5000,0.03319209,0.036689863,,,,,,,,,,,,,,,,, -5100,0.061753523,0.03678751,,,,,,,,,,,,,,,,, -5200,0.038576026,0.037744563,,,,,,,,,,,,,,,,, -5253,,,0.9887412786483764,0.0384508743882179,0.2200065332386008,0.9858241081237792,0.048249889165163,0.1957211458247129,43793.0,0.9849397540092468,0.0509831346571445,0.1961209567121776,43793.0,1692.5093190670011,2617.616218805313,1692.5093190670011,924.7717523574828,0.190324068069458,0.0 -5300,0.069789015,0.03798268,,,,,,,,,,,,,,,,, -5400,0.078643404,0.038951058,,,,,,,,,,,,,,,,, -5500,0.09193181,0.040098347,,,,,,,,,,,,,,,,, -5600,0.07889473,0.03730408,,,,,,,,,,,,,,,,, -5700,0.070316225,0.0406548,,,,,,,,,,,,,,,,, -5800,0.04235111,0.036421,,,,,,,,,,,,,,,,, -5900,0.050870545,0.03659881,,,,,,,,,,,,,,,,, -6000,0.06761097,0.036077738,,,,,,,,,,,,,,,,, -6003,,,0.9889976382255554,0.0376361981034278,0.2488204824404136,0.9858983755111694,0.0479212068021297,0.2057879012354354,43793.0,0.9850517511367798,0.0505339317023754,0.2111077347749315,43793.0,1932.4763493537905,2969.128954410553,1932.4763493537905,1036.267024755478,0.2197639942169189,0.0 -6100,0.041538477,0.03764328,,,,,,,,,,,,,,,,, -6200,0.054510716,0.0400627,,,,,,,,,,,,,,,,, -6300,0.038720515,0.035960473,,,,,,,,,,,,,,,,, -6400,0.08844911,0.035577405,,,,,,,,,,,,,,,,, -6500,0.039290648,0.038600314,,,,,,,,,,,,,,,,, -6600,0.15978101,0.03997962,,,,,,,,,,,,,,,,, -6700,0.03835214,0.03410417,,,,,,,,,,,,,,,,, -6755,,,0.9892213940620422,0.0366696007549762,0.2634163786222182,0.986061990261078,0.0467621460556983,0.2146012390955433,43793.0,0.9851566553115844,0.0493510365486145,0.2218740276827641,43793.0,2172.550704717636,3322.5414004325867,2172.550704717636,1149.54962849617,0.2551395893096924,0.0 -6800,0.028023468,0.03542205,,,,,,,,,,,,,,,,, -6900,0.033806145,0.039195303,,,,,,,,,,,,,,,,, -7000,0.028328326,0.03592991,,,,,,,,,,,,,,,,, -7100,0.034682136,0.038101736,,,,,,,,,,,,,,,,, -7200,0.051441506,0.037009656,,,,,,,,,,,,,,,,, -7300,0.03686903,0.041611608,,,,,,,,,,,,,,,,, -7400,0.0396447,0.03818717,,,,,,,,,,,,,,,,, -7500,0.029591344,0.03605585,,,,,,,,,,,,,,,,, -7516,,,0.989296853542328,0.0359546542167663,0.278105170758098,0.9860575199127196,0.0468306243419647,0.2125293204426107,43793.0,0.9853032231330872,0.0493544600903987,0.2166553914108559,43793.0,2412.738084077835,3680.512838125229,2412.738084077835,1267.2843120098114,0.2840569019317627,0.0 -7600,0.036217235,0.038033795,,,,,,,,,,,,,,,,, -7700,0.03493824,0.038502555,,,,,,,,,,,,,,,,, -7800,0.04045588,0.037976235,,,,,,,,,,,,,,,,, -7900,0.050704286,0.033576954,,,,,,,,,,,,,,,,, -8000,0.026936483,0.0339868,,,,,,,,,,,,,,,,, -8100,0.036686648,0.039905578,,,,,,,,,,,,,,,,, -8200,0.029528605,0.03581684,,,,,,,,,,,,,,,,, -8273,,,0.9894575476646424,0.035493679344654,0.280584196580796,0.9862965941429138,0.0462246499955654,0.2248349779677364,43793.0,0.9853959083557128,0.0489659383893013,0.2248276800619779,43793.0,2652.844264984131,4030.8121387958527,2652.844264984131,1377.429646730423,0.3117320537567138,0.0 -8300,0.030706616,0.03295863,,,,,,,,,,,,,,,,, -8400,0.024898915,0.0347639,,,,,,,,,,,,,,,,, -8500,0.034142528,0.036920574,,,,,,,,,,,,,,,,, -8600,0.024388067,0.03374456,,,,,,,,,,,,,,,,, -8700,0.03213183,0.03586349,,,,,,,,,,,,,,,,, -8800,0.03058035,0.035618998,,,,,,,,,,,,,,,,, -8900,0.029497527,0.039824914,,,,,,,,,,,,,,,,, -9000,0.04674735,0.038691964,,,,,,,,,,,,,,,,, -9024,,,0.9894021153450012,0.0356401763856411,0.2816182237666848,0.9863436818122864,0.0458505451679229,0.2299495299738932,43793.0,0.9855276942253112,0.0486106649041175,0.2302574675518472,43793.0,2892.8968245983124,4382.973264217377,2892.8968245983124,1489.488579750061,0.3412954807281494,0.0 -9100,0.040850785,0.03486967,,,,,,,,,,,,,,,,, -9200,0.03216612,0.03190459,,,,,,,,,,,,,,,,, -9300,0.025859447,0.03544497,,,,,,,,,,,,,,,,, -9400,0.05087671,0.035673633,,,,,,,,,,,,,,,,, -9500,0.03345791,0.034142982,,,,,,,,,,,,,,,,, -9600,0.03246846,0.038826026,,,,,,,,,,,,,,,,, -9700,0.036188323,0.033655822,,,,,,,,,,,,,,,,, -9776,,,0.989449679851532,0.0354800149798393,0.2902229885165782,0.9861695766448976,0.0460140667855739,0.228143906935731,43793.0,0.985284686088562,0.0488116517663002,0.2275440479991647,43793.0,3132.904645681381,4736.391147851944,3132.904645681381,1602.8503172397614,0.3689990043640136,0.0 -9800,0.04492598,0.035405114,,,,,,,,,,,,,,,,, -9900,0.058550313,0.037957214,,,,,,,,,,,,,,,,, -10000,0.029795358,0.033699803,,,,,,,,,,,,,,,,, -10100,0.030292029,0.031270236,,,,,,,,,,,,,,,,, -10200,0.0333693,0.031985007,,,,,,,,,,,,,,,,, -10300,0.030211229,0.035597775,,,,,,,,,,,,,,,,, -10400,0.03367724,0.032567326,,,,,,,,,,,,,,,,, -10500,0.029636526,0.03502823,,,,,,,,,,,,,,,,, -10529,,,0.9898571372032166,0.0340229794383049,0.3148685575602764,0.9864293336868286,0.0454701818525791,0.2401039852908579,43793.0,0.985558032989502,0.0482427552342414,0.232955356440201,43793.0,3373.1482582092285,5093.1759622097015,3373.1482582092285,1719.3369302749634,0.4020838737487793,0.0 -10600,0.026613722,0.03095515,,,,,,,,,,,,,,,,, -10700,0.040332537,0.035066847,,,,,,,,,,,,,,,,, -10800,0.033279978,0.032398753,,,,,,,,,,,,,,,,, -10900,0.044536136,0.03603456,,,,,,,,,,,,,,,,, -11000,0.031656496,0.034454253,,,,,,,,,,,,,,,,, -11100,0.039132606,0.0352566,,,,,,,,,,,,,,,,, -11200,0.035498597,0.036684796,,,,,,,,,,,,,,,,, -11283,,,0.989729344844818,0.0340922735631465,0.316281491512823,0.9861927032470704,0.0455793216824531,0.2445057253378528,43793.0,0.9854072332382202,0.0483103096485137,0.2369131371307541,43793.0,3613.232980489731,5442.6319761276245,3613.232980489731,1828.657984495163,0.4322450160980224,0.0 -11300,0.04638389,0.033509698,,,,,,,,,,,,,,,,, -11400,0.036510035,0.03261976,,,,,,,,,,,,,,,,, -11500,0.048639458,0.035530176,,,,,,,,,,,,,,,,, -11600,0.0338635,0.034965727,,,,,,,,,,,,,,,,, -11700,0.03274986,0.033895597,,,,,,,,,,,,,,,,, -11800,0.053227782,0.033216726,,,,,,,,,,,,,,,,, -11900,0.050185386,0.033781342,,,,,,,,,,,,,,,,, -12000,0.032658413,0.03184481,,,,,,,,,,,,,,,,, -12043,,,0.9900323748588562,0.0331622846424579,0.341318229949142,0.986455738544464,0.045129720121622,0.2422295153032972,43793.0,0.985649049282074,0.0477396845817565,0.2417897996837219,43793.0,3853.240893602371,5797.5224277973175,3853.240893602371,1943.490604162216,0.4619700908660888,0.0 -12100,0.045097277,0.03310557,,,,,,,,,,,,,,,,, -12200,0.05145082,0.033565268,,,,,,,,,,,,,,,,, -12300,0.05335579,0.03252876,,,,,,,,,,,,,,,,, -12400,0.04191538,0.03366888,,,,,,,,,,,,,,,,, -12500,0.04643912,0.030820182,,,,,,,,,,,,,,,,, -12600,0.036898833,0.034337807,,,,,,,,,,,,,,,,, -12700,0.03421089,0.034607384,,,,,,,,,,,,,,,,, -12792,,,0.990278661251068,0.0320963263511657,0.3793567029038292,0.9866027235984802,0.0446693301200866,0.2507951353642788,43793.0,0.9857446551322936,0.047577828168869,0.2455561227362303,43793.0,4092.884313106537,6149.5794270038605,4092.884313106537,2055.456913471222,0.8892319202423096,0.0 -12800,0.052345634,0.033954296,,,,,,,,,,,,,,,,, -12900,0.03804923,0.03227029,,,,,,,,,,,,,,,,, -13000,0.03441896,0.028723136,,,,,,,,,,,,,,,,, -13100,0.04917992,0.03139254,,,,,,,,,,,,,,,,, -13200,0.038384296,0.029348867,,,,,,,,,,,,,,,,, -13300,0.039842267,0.03090015,,,,,,,,,,,,,,,,, -13400,0.043231204,0.030964142,,,,,,,,,,,,,,,,, -13500,0.04323141,0.030659635,,,,,,,,,,,,,,,,, -13540,,,0.9903260469436646,0.0318269245326519,0.3630988382990801,0.986552357673645,0.0445600450038909,0.2593600780039394,43793.0,0.9855934381484984,0.0475713983178138,0.2468264667660303,43793.0,4333.080197811127,6500.808378696442,4333.080197811127,2166.441266775131,0.9181814193725586,0.0 -13600,0.04702151,0.0326841,,,,,,,,,,,,,,,,, -13700,0.038971953,0.030985825,,,,,,,,,,,,,,,,, -13800,0.051694423,0.028505806,,,,,,,,,,,,,,,,, -13900,0.04855902,0.031592827,,,,,,,,,,,,,,,,, -14000,0.04506469,0.031889416,,,,,,,,,,,,,,,,, -14100,0.042500075,0.033997264,,,,,,,,,,,,,,,,, -14200,0.049607042,0.031473726,,,,,,,,,,,,,,,,, -14285,,,0.990658700466156,0.0308063067495822,0.3933822209818916,0.9864890575408936,0.0447542667388916,0.2566857614723372,43793.0,0.9856544733047484,0.0473392717540264,0.2469929637029501,43793.0,4573.193847417831,6855.502467632294,4573.193847417831,2280.971666574478,0.9479734897613524,0.0 -14300,0.054102156,0.03329901,,,,,,,,,,,,,,,,, -14400,0.06673991,0.033112332,,,,,,,,,,,,,,,,, -14500,0.04704691,0.029283194,,,,,,,,,,,,,,,,, -14600,0.08140309,0.03249224,,,,,,,,,,,,,,,,, -14700,0.050905738,0.031111334,,,,,,,,,,,,,,,,, -14800,0.08062568,0.02820569,,,,,,,,,,,,,,,,, -14900,0.06818839,0.030681511,,,,,,,,,,,,,,,,, -15000,0.06062121,0.031525105,,,,,,,,,,,,,,,,, -15030,,,0.99084734916687,0.030131122097373,0.3993562358528172,0.9865454435348512,0.0447263419628143,0.2587401186460871,43793.0,0.9856852293014526,0.0476922467350959,0.2472861083323727,43793.0,4813.296638011932,7206.325021743774,4813.296638011932,2391.64150595665,0.9773366451263428,0.0 -15100,0.05186234,0.03034566,,,,,,,,,,,,,,,,, -15200,0.052472737,0.03220125,,,,,,,,,,,,,,,,, -15300,0.11372656,0.030752026,,,,,,,,,,,,,,,,, -15400,0.07179315,0.030212516,,,,,,,,,,,,,,,,, -15500,0.055093482,0.03191312,,,,,,,,,,,,,,,,, -15600,0.06392266,0.032940086,,,,,,,,,,,,,,,,, -15700,0.058522027,0.031114887,,,,,,,,,,,,,,,,, -15774,,,0.990815281867981,0.0303503256291151,0.3971719559810763,0.9866440892219543,0.0446596071124076,0.265979384637311,43793.0,0.985791802406311,0.0475449040532112,0.2527092747568265,43793.0,5053.413534879684,7556.462839126587,5053.413534879684,2501.61195063591,1.0068624019622805,0.0 -15800,0.054738194,0.031060942,,,,,,,,,,,,,,,,, -15900,0.05772934,0.032700222,,,,,,,,,,,,,,,,, -16000,0.0628304,0.03015907,,,,,,,,,,,,,,,,, -16100,0.059624407,0.029771172,,,,,,,,,,,,,,,,, -16200,0.07248819,0.03256995,,,,,,,,,,,,,,,,, -16300,0.062660895,0.033483405,,,,,,,,,,,,,,,,, -16400,0.061288178,0.031639073,,,,,,,,,,,,,,,,, -16500,0.12982629,0.034766663,,,,,,,,,,,,,,,,, -16528,,,0.990627944469452,0.0308685638010501,0.3913606469411551,0.9864423274993896,0.0450369007885456,0.2502382753996813,43793.0,0.9855433106422424,0.0480347089469432,0.2441168465974053,43793.0,5293.414917469025,7904.33753156662,5293.414917469025,2609.4342000484467,1.0378234386444092,0.0 -16600,0.056959476,0.033055905,,,,,,,,,,,,,,,,, -16700,0.074861676,0.033883672,,,,,,,,,,,,,,,,, -16800,0.052436046,0.030758169,,,,,,,,,,,,,,,,, -16900,0.09952594,0.029610572,,,,,,,,,,,,,,,,, -17000,0.09160679,0.03258648,,,,,,,,,,,,,,,,, -17100,0.07490917,0.028978078,,,,,,,,,,,,,,,,, -17200,0.05910042,0.030312018,,,,,,,,,,,,,,,,, -17281,,,0.9905187487602234,0.0311160124838352,0.3874007734505545,0.9865190982818604,0.0445390082895755,0.2598120879231033,43793.0,0.9857496619224548,0.0472060106694698,0.2529896059917935,43793.0,5533.490139722824,8254.701673984528,5533.490139722824,2719.6714749336243,1.0690538883209229,0.0 -17300,0.06749697,0.029294703,,,,,,,,,,,,,,,,, -17400,0.06223943,0.029184027,,,,,,,,,,,,,,,,, -17500,0.06523788,0.03164773,,,,,,,,,,,,,,,,, -17600,0.13020742,0.02822819,,,,,,,,,,,,,,,,, -17700,0.05892545,0.029532101,,,,,,,,,,,,,,,,, -17800,0.06438745,0.030481042,,,,,,,,,,,,,,,,, -17900,0.070800744,0.031182734,,,,,,,,,,,,,,,,, -18000,0.066799365,0.029282756,,,,,,,,,,,,,,,,, -18027,,,0.990691065788269,0.0305870901793241,0.3921542590010484,0.9866855144500732,0.0443942174315452,0.2646887457085707,43793.0,0.9858617186546326,0.0472342632710933,0.2590980910827544,43793.0,5773.559222221375,8608.562299251556,5773.559222221375,2833.4104483127594,1.101609468460083,0.0 -18100,0.081939965,0.033920284,,,,,,,,,,,,,,,,, -18200,0.06428919,0.029091142,,,,,,,,,,,,,,,,, -18300,0.06601522,0.029829564,,,,,,,,,,,,,,,,, -18400,0.064528786,0.031073792,,,,,,,,,,,,,,,,, -18500,0.0841622,0.031105716,,,,,,,,,,,,,,,,, -18600,0.097090594,0.030861823,,,,,,,,,,,,,,,,, -18700,0.08774071,0.03361878,,,,,,,,,,,,,,,,, -18772,,,0.990775227546692,0.0302655268460512,0.3949178652503551,0.98653244972229,0.0453040190041065,0.2542964945301081,43793.0,0.9856511354446412,0.048032097518444,0.2462150144339345,43793.0,6013.634890794754,8958.123777866364,6013.634890794754,2942.844471931457,1.133225440979004,0.0 -18800,0.101088166,0.027383655,,,,,,,,,,,,,,,,, -18900,0.066910096,0.030538172,,,,,,,,,,,,,,,,, -19000,0.072519146,0.033233766,,,,,,,,,,,,,,,,, -19100,0.06275699,0.024380255,,,,,,,,,,,,,,,,, -19200,0.08270278,0.03163188,,,,,,,,,,,,,,,,, -19300,0.08745536,0.02961025,,,,,,,,,,,,,,,,, -19400,0.10547148,0.02938238,,,,,,,,,,,,,,,,, -19500,0.08866343,0.029700976,,,,,,,,,,,,,,,,, -19524,,,0.990888774394989,0.0298353638499975,0.4150541904359339,0.9865884780883788,0.044611282646656,0.257674980871371,43793.0,0.9857951998710632,0.0473557077348232,0.2504987704034213,43793.0,6253.68617773056,9309.091790914536,6253.68617773056,3053.710742712021,1.1633639335632324,0.0 -19600,0.097121194,0.031412393,,,,,,,,,,,,,,,,, -19700,0.08843191,0.032751266,,,,,,,,,,,,,,,,, -19800,0.078801274,0.028809618,,,,,,,,,,,,,,,,, -19900,0.07605705,0.030203907,,,,,,,,,,,,,,,,, -20000,0.09137616,0.031200003,,,,,,,,,,,,,,,,, -20100,0.0857185,0.032297328,,,,,,,,,,,,,,,,, -20200,0.08783948,0.033038042,,,,,,,,,,,,,,,,, -20281,,,0.99105304479599,0.0290999766439199,0.4372819758899958,0.986669659614563,0.0448590330779552,0.2627159908102456,43793.0,0.985710084438324,0.0476002022624015,0.2503044683984956,43793.0,6493.699534893036,9657.622642278671,6493.699534893036,3162.178370714188,1.1930761337280271,0.0 -20300,0.07984065,0.029351646,,,,,,,,,,,,,,,,, -20400,0.084779955,0.028589938,,,,,,,,,,,,,,,,, -20500,0.09204747,0.032806017,,,,,,,,,,,,,,,,, -20600,0.08272575,0.030614037,,,,,,,,,,,,,,,,, -20700,0.14061844,0.031669334,,,,,,,,,,,,,,,,, -20800,0.10342668,0.031184783,,,,,,,,,,,,,,,,, -20900,0.078502014,0.027346032,,,,,,,,,,,,,,,,, -21000,0.08788689,0.0377612,,,,,,,,,,,,,,,,, -21042,,,0.9911764860153198,0.0285094995051622,0.4408113644699721,0.9867488145828248,0.044682178646326,0.2665122163037616,43793.0,0.9857838153839112,0.0475145317614078,0.2609902777836348,43793.0,6733.768709421158,10006.611500024796,6733.768709421158,3271.0463457107544,1.224475383758545,0.0 -21100,0.084685296,0.030586159,,,,,,,,,,,,,,,,, -21200,0.08804631,0.029469198,,,,,,,,,,,,,,,,, -21300,0.08322337,0.029768227,,,,,,,,,,,,,,,,, -21400,0.08485589,0.029260678,,,,,,,,,,,,,,,,, -21500,0.074526876,0.029097976,,,,,,,,,,,,,,,,, -21600,0.08247158,0.028988145,,,,,,,,,,,,,,,,, -21700,0.09687512,0.030337712,,,,,,,,,,,,,,,,, -21776,,,0.991248607635498,0.0282602701336145,0.4538668844774671,0.9866753816604614,0.0444318726658821,0.2672081493576128,43793.0,0.985862135887146,0.0472691245377063,0.2573478156117922,43793.0,6974.013297080994,10357.728278398514,6974.013297080994,3381.8595008850098,1.2595610618591309,0.0 -21800,0.0676149,0.028929986,,,,,,,,,,,,,,,,, -21900,0.12029122,0.02874984,,,,,,,,,,,,,,,,, -22000,0.08913657,0.028543899,,,,,,,,,,,,,,,,, -22100,0.095287226,0.030418644,,,,,,,,,,,,,,,,, -22200,0.11726759,0.029280216,,,,,,,,,,,,,,,,, -22300,0.071875975,0.02684356,,,,,,,,,,,,,,,,, -22400,0.07776488,0.02780754,,,,,,,,,,,,,,,,, -22500,0.08523522,0.032061897,,,,,,,,,,,,,,,,, -22529,,,0.9914287328720092,0.0280359983444213,0.4523613957712486,0.9865572452545166,0.0449055433273315,0.2573994457154161,43793.0,0.9856717586517334,0.047910563647747,0.2484504127180709,43793.0,7213.988672018051,10709.38787341118,7213.988672018051,3493.4931180477142,1.289865255355835,0.0 -22600,0.117682844,0.03394799,,,,,,,,,,,,,,,,, -22700,0.07822185,0.027940568,,,,,,,,,,,,,,,,, -22800,0.07784418,0.02684133,,,,,,,,,,,,,,,,, -22900,0.080089815,0.027449496,,,,,,,,,,,,,,,,, -23000,0.10343522,0.032295685,,,,,,,,,,,,,,,,, -23100,0.10305397,0.028499253,,,,,,,,,,,,,,,,, -23200,0.11415697,0.032200277,,,,,,,,,,,,,,,,, -23263,,,0.9912769198417664,0.0283569395542144,0.4405505491753948,0.986663579940796,0.0446783117949962,0.2679886704370666,43793.0,0.9859034419059752,0.0473452322185039,0.2580197824040367,43793.0,7453.943880081177,11060.453907966614,7453.943880081177,3604.5464034080505,1.3261759281158447,0.0 -23300,0.09830501,0.032610863,,,,,,,,,,,,,,,,, -23400,0.11442902,0.029075213,,,,,,,,,,,,,,,,, -23500,0.076001756,0.028540106,,,,,,,,,,,,,,,,, -23600,0.09826491,0.027825944,,,,,,,,,,,,,,,,, -23700,0.114776306,0.030018622,,,,,,,,,,,,,,,,, -23800,0.073993444,0.026946306,,,,,,,,,,,,,,,,, -23900,0.12101347,0.029954394,,,,,,,,,,,,,,,,, -23997,,,0.9911851286888124,0.0288587547838687,0.4346443075859728,0.9867857694625854,0.0445174388587474,0.2709604029205279,43793.0,0.9859076142311096,0.0473978109657764,0.2571818691088361,43793.0,7694.173835515976,11408.35333752632,7694.173835515976,3712.1588644981375,1.3611788749694824,0.0 -24000,0.085632116,0.03297972,,,,,,,,,,,,,,,,, -24100,0.10942139,0.030254962,,,,,,,,,,,,,,,,, -24200,0.09650137,0.02930423,,,,,,,,,,,,,,,,, -24300,0.088074185,0.029363155,,,,,,,,,,,,,,,,, -24400,0.08302353,0.02846315,,,,,,,,,,,,,,,,, -24500,0.09464058,0.030702362,,,,,,,,,,,,,,,,, -24600,0.100489825,0.030739775,,,,,,,,,,,,,,,,, -24700,0.074430615,0.02785677,,,,,,,,,,,,,,,,, -24753,,,0.9910281896591188,0.0290977265685796,0.4323543772737795,0.9866266250610352,0.0450372099876403,0.2616482551665179,43793.0,0.9857720136642456,0.0477791726589202,0.2502596400985658,43793.0,7934.371356964111,11754.900280475616,7934.371356964111,3818.455191135408,1.3941354751586914,0.0 -24800,0.08364569,0.028507765,,,,,,,,,,,,,,,,, -24900,0.10482167,0.029235745,,,,,,,,,,,,,,,,, -25000,0.11800132,0.027301645,,,,,,,,,,,,,,,,, -25100,0.09431963,0.029485695,,,,,,,,,,,,,,,,, -25200,0.08470066,0.028431514,,,,,,,,,,,,,,,,, -25300,0.13957217,0.029753748,,,,,,,,,,,,,,,,, -25400,0.1198618,0.031970877,,,,,,,,,,,,,,,,, -25500,0.07992968,0.026567584,,,,,,,,,,,,,,,,, -25506,,,0.990942358970642,0.0292593203485012,0.430151878859003,0.986642062664032,0.0455658473074436,0.257154364719039,43793.0,0.985776662826538,0.048466894775629,0.2461963545937001,43793.0,8174.356585264206,12101.346946954727,8174.356585264206,3924.865197896957,1.4252452850341797,0.0 -25600,0.11283668,0.027895879,,,,,,,,,,,,,,,,, -25700,0.08425332,0.02696386,,,,,,,,,,,,,,,,, -25800,0.084599465,0.028591473,,,,,,,,,,,,,,,,, -25900,0.10437494,0.026983507,,,,,,,,,,,,,,,,, -26000,0.099556044,0.028145287,,,,,,,,,,,,,,,,, -26100,0.120695196,0.028972458,,,,,,,,,,,,,,,,, -26200,0.088491894,0.02775576,,,,,,,,,,,,,,,,, -26254,,,0.9912765026092528,0.0283831879496574,0.4462405974678892,0.9867240786552428,0.0452526807785034,0.2645874325719716,43793.0,0.9858777523040771,0.048056062310934,0.2581640623044151,43793.0,8414.532056331635,12454.075824022291,8414.532056331635,4037.365210533142,1.4570987224578855,0.0 -26300,0.11813148,0.031711984,,,,,,,,,,,,,,,,, -26400,0.109236695,0.025696637,,,,,,,,,,,,,,,,, -26500,0.088480964,0.028093273,,,,,,,,,,,,,,,,, -26600,0.093396716,0.033125505,,,,,,,,,,,,,,,,, -26700,0.07629175,0.028736897,,,,,,,,,,,,,,,,, -26800,0.08255243,0.025163665,,,,,,,,,,,,,,,,, -26900,0.13656537,0.032467417,,,,,,,,,,,,,,,,, -27000,0.1272827,0.028777879,,,,,,,,,,,,,,,,, -27008,,,0.99150550365448,0.0274326615035533,0.4741881005743563,0.9866209626197816,0.0449338443577289,0.2674430369710101,43793.0,0.985755980014801,0.0479071885347366,0.2624967264405229,43793.0,8654.747866868973,12797.918403863909,8654.747866868973,4140.939670324326,1.4889216423034668,0.0 -27100,0.110800125,0.026279787,,,,,,,,,,,,,,,,, -27200,0.095307745,0.032493234,,,,,,,,,,,,,,,,, -27300,0.10720525,0.02965799,,,,,,,,,,,,,,,,, -27400,0.09162046,0.026899977,,,,,,,,,,,,,,,,, -27500,0.09043713,0.026190912,,,,,,,,,,,,,,,,, -27600,0.0926724,0.027341947,,,,,,,,,,,,,,,,, -27700,0.08672258,0.028306562,,,,,,,,,,,,,,,,, -27762,,,0.9915337562561036,0.0273089949041605,0.4780035881167493,0.9865288138389589,0.044875830411911,0.263436139508458,43793.0,0.9856725931167604,0.0478430762887001,0.2581436799319148,43793.0,8894.826835393906,13144.054966926577,8894.826835393906,4246.946130990982,1.519975185394287,0.0 -27800,0.08590714,0.02635031,,,,,,,,,,,,,,,,, -27900,0.11190542,0.027686274,,,,,,,,,,,,,,,,, -28000,0.11862475,0.029112814,,,,,,,,,,,,,,,,, -28100,0.093681484,0.028858082,,,,,,,,,,,,,,,,, -28200,0.087001696,0.02542301,,,,,,,,,,,,,,,,, -28300,0.08273743,0.027624542,,,,,,,,,,,,,,,,, -28400,0.10959003,0.028274607,,,,,,,,,,,,,,,,, -28500,0.10129916,0.02733032,,,,,,,,,,,,,,,,, -28519,,,0.9918634295463562,0.0262583419680595,0.4980736352194875,0.9867614507675172,0.0451503284275531,0.2686315066777437,43793.0,0.9858773350715636,0.0482346639037132,0.2525917458842993,43793.0,9134.948380470276,13491.490109682083,9134.948380470276,4354.20730304718,1.55245041847229,0.0 -28600,0.0945073,0.02583739,,,,,,,,,,,,,,,,, -28700,0.10076583,0.029286513,,,,,,,,,,,,,,,,, -28800,0.08073778,0.025825094,,,,,,,,,,,,,,,,, -28900,0.09672669,0.025505165,,,,,,,,,,,,,,,,, -29000,0.1014231,0.030250682,,,,,,,,,,,,,,,,, -29100,0.09671839,0.02661185,,,,,,,,,,,,,,,,, -29200,0.08874584,0.027345663,,,,,,,,,,,,,,,,, -29267,,,0.9919530749320984,0.0258277095854282,0.5062438255238915,0.9866485595703124,0.0451736338436603,0.2622896558243855,43793.0,0.9858684539794922,0.0482171401381492,0.25379463137131,43793.0,9375.02670264244,13838.322686195374,9375.02670264244,4460.907576322556,1.5855414867401123,0.0 -29300,0.100071706,0.026674328,,,,,,,,,,,,,,,,, -29400,0.088865586,0.028280467,,,,,,,,,,,,,,,,, -29500,0.09269502,0.026866192,,,,,,,,,,,,,,,,, -29600,0.1055987,0.027530577,,,,,,,,,,,,,,,,, -29700,0.090242624,0.026357353,,,,,,,,,,,,,,,,, -29800,0.12704395,0.030588705,,,,,,,,,,,,,,,,, -29900,0.09158186,0.02846958,,,,,,,,,,,,,,,,, -30000,0.14805385,0.029896265,,,,,,,,,,,,,,,,, -30027,,,0.9917265176773072,0.0266038626432418,0.4825484781293242,0.9867070317268372,0.0451933033764362,0.2677666322082109,43793.0,0.98581200838089,0.0479614846408367,0.2577138416244921,43793.0,9615.066091775894,14182.107945919037,9615.066091775894,4564.600377559662,1.6183860301971436,0.0 -30100,0.09292038,0.027366841,,,,,,,,,,,,,,,,, -30200,0.09696997,0.025322262,,,,,,,,,,,,,,,,, -30300,0.094756395,0.02721671,,,,,,,,,,,,,,,,, -30400,0.11073782,0.027948622,,,,,,,,,,,,,,,,, -30500,0.09722612,0.028296096,,,,,,,,,,,,,,,,, -30600,0.10126927,0.028468259,,,,,,,,,,,,,,,,, -30700,0.101663545,0.025737092,,,,,,,,,,,,,,,,, -30789,,,0.9917187094688416,0.0268257912248373,0.4753152896559408,0.9866449236869812,0.0451675727963447,0.2709646793278654,43793.0,0.9856843948364258,0.0481309331953525,0.2610056338665237,43793.0,9855.15208029747,14529.88004231453,9855.15208029747,4672.233599424362,1.6509528160095217,0.0 -30800,0.098840795,0.027769076,,,,,,,,,,,,,,,,, -30900,0.15667474,0.027518597,,,,,,,,,,,,,,,,, -31000,0.10230838,0.027746823,,,,,,,,,,,,,,,,, -31100,0.1412059,0.027284669,,,,,,,,,,,,,,,,, -31200,0.10144151,0.028901802,,,,,,,,,,,,,,,,, -31300,0.10598586,0.027491532,,,,,,,,,,,,,,,,, -31400,0.09942765,0.02857135,,,,,,,,,,,,,,,,, -31500,0.116689034,0.024562566,,,,,,,,,,,,,,,,, -31549,,,0.9915143251419068,0.0273069459944963,0.4664824860010296,0.9866051077842712,0.0454580970108509,0.2700092224648855,43793.0,0.9858478307724,0.0483252517879009,0.2621210626835716,43793.0,10095.158047437668,14873.906448364258,10095.158047437668,4776.200966119766,1.6840364933013916,0.0 -31600,0.12272128,0.03103256,,,,,,,,,,,,,,,,, -31700,0.08826358,0.024601743,,,,,,,,,,,,,,,,, -31800,0.11099016,0.030147491,,,,,,,,,,,,,,,,, -31900,0.08742511,0.028972073,,,,,,,,,,,,,,,,, -32000,0.12467171,0.02544123,,,,,,,,,,,,,,,,, -32100,0.09581071,0.025974229,,,,,,,,,,,,,,,,, -32200,0.10690536,0.028534297,,,,,,,,,,,,,,,,, -32300,0.10367481,0.02981013,,,,,,,,,,,,,,,,, -32302,,,0.991627275943756,0.0269815512001514,0.4847153296311458,0.9866538643836976,0.0450129844248294,0.2784450487358846,43793.0,0.9857808351516724,0.0479646921157836,0.2607573388347303,43793.0,10335.340276241302,15220.165620326996,10335.340276241302,4882.224649429321,1.7169504165649414,0.0 -32400,0.1277138,0.03076765,,,,,,,,,,,,,,,,, -32500,0.09876351,0.025497807,,,,,,,,,,,,,,,,, -32600,0.103834115,0.028824657,,,,,,,,,,,,,,,,, -32700,0.13197845,0.027661767,,,,,,,,,,,,,,,,, -32800,0.10789533,0.028685868,,,,,,,,,,,,,,,,, -32900,0.12675145,0.028967153,,,,,,,,,,,,,,,,, -33000,0.13015594,0.026005816,,,,,,,,,,,,,,,,, -33058,,,0.9915615320205688,0.026972159743309,0.4788424207599614,0.9867366552352904,0.0455386377871036,0.2693432926444069,43793.0,0.9859063625335692,0.0485273450613021,0.2565631284387224,43793.0,10575.47698545456,15567.64379477501,10575.47698545456,4989.512982130051,1.749946117401123,0.0 -33100,0.10354966,0.026872134,,,,,,,,,,,,,,,,, -33200,0.13128679,0.029566772,,,,,,,,,,,,,,,,, -33300,0.12247961,0.028257947,,,,,,,,,,,,,,,,, -33400,0.098649554,0.024799258,,,,,,,,,,,,,,,,, -33500,0.10400909,0.023169834,,,,,,,,,,,,,,,,, -33600,0.10379091,0.026567828,,,,,,,,,,,,,,,,, -33700,0.09773329,0.026450362,,,,,,,,,,,,,,,,, -33800,0.10066786,0.027503569,,,,,,,,,,,,,,,,, -33804,,,0.991463303565979,0.0272666439414024,0.4701334986106348,0.986585259437561,0.0456389114260673,0.2751634706236099,43793.0,0.9856860637664796,0.0485675148665905,0.2549646396946223,43793.0,10815.58456158638,15919.281561613085,10815.58456158638,5100.989178657532,1.7831127643585205,0.0 -33900,0.09851678,0.026350748,,,,,,,,,,,,,,,,, -34000,0.13750334,0.03101851,,,,,,,,,,,,,,,,, -34100,0.09718377,0.0253439,,,,,,,,,,,,,,,,, -34200,0.09253953,0.026542839,,,,,,,,,,,,,,,,, -34300,0.11435591,0.025542807,,,,,,,,,,,,,,,,, -34400,0.08569829,0.023128113,,,,,,,,,,,,,,,,, -34500,0.12875359,0.02803044,,,,,,,,,,,,,,,,, -34562,,,0.9917646646499634,0.0262292195111513,0.4959750658599631,0.9866027235984802,0.0459834299981594,0.2699241965319351,43793.0,0.9858149886131288,0.0488108955323696,0.2672977632378215,43793.0,11055.76077890396,16265.101999282835,11055.76077890396,5206.580208301544,1.8159537315368648,0.0 -34600,0.0795386,0.02453753,,,,,,,,,,,,,,,,, -34700,0.107804105,0.027870506,,,,,,,,,,,,,,,,, -34800,0.117363155,0.026186172,,,,,,,,,,,,,,,,, -34900,0.100058645,0.026216134,,,,,,,,,,,,,,,,, -35000,0.10428933,0.026633821,,,,,,,,,,,,,,,,, -35100,0.13366286,0.027233029,,,,,,,,,,,,,,,,, -35200,0.101562195,0.023945322,,,,,,,,,,,,,,,,, -35300,0.14504099,0.027873594,,,,,,,,,,,,,,,,, -35322,,,0.992120325565338,0.0252374149858951,0.5261174470746846,0.9866863489151,0.0453717298805713,0.2695367499370574,43793.0,0.9858158230781556,0.0485912635922431,0.2602659377411682,43793.0,11295.954808950424,16612.69092822075,11295.954808950424,5313.920674800873,1.8492765426635744,0.0 -35400,0.11179623,0.026610138,,,,,,,,,,,,,,,,, -35500,0.12364779,0.029509135,,,,,,,,,,,,,,,,, -35600,0.10624648,0.025301255,,,,,,,,,,,,,,,,, -35700,0.10089898,0.02666609,,,,,,,,,,,,,,,,, -35800,0.100714974,0.02456599,,,,,,,,,,,,,,,,, -35900,0.12387836,0.027135564,,,,,,,,,,,,,,,,, -36000,0.10766964,0.024314644,,,,,,,,,,,,,,,,, -36076,,,0.9921783804893494,0.0249777026474475,0.5172506535598527,0.9864655137062072,0.0458004400134086,0.2712302921497364,43793.0,0.9856621026992798,0.0485945083200931,0.2601393911286229,43793.0,11535.98846077919,16956.04179906845,11535.98846077919,5417.184643983841,1.882709264755249,0.0 -36100,0.12672804,0.0261685,,,,,,,,,,,,,,,,, -36200,0.14213733,0.026402006,,,,,,,,,,,,,,,,, -36300,0.13173315,0.026830735,,,,,,,,,,,,,,,,, -36400,0.12516563,0.025675863,,,,,,,,,,,,,,,,, -36500,0.124122895,0.028188787,,,,,,,,,,,,,,,,, -36600,0.11356223,0.024908826,,,,,,,,,,,,,,,,, -36700,0.09584326,0.024596302,,,,,,,,,,,,,,,,, -36800,0.11529377,0.024936091,,,,,,,,,,,,,,,,, -36835,,,0.9925355911254884,0.0236619692295789,0.5555088251116154,0.9866177439689636,0.0460899136960506,0.2748257784845323,43793.0,0.985743761062622,0.0490642413496971,0.2600800066140546,43793.0,11776.233745574951,17301.020915985107,11776.233745574951,5521.864846467972,1.9163761138916016,0.0 -36900,0.12864201,0.026761966,,,,,,,,,,,,,,,,, -37000,0.10968411,0.02695885,,,,,,,,,,,,,,,,, -37100,0.11105558,0.021667061,,,,,,,,,,,,,,,,, -37200,0.17927203,0.027805708,,,,,,,,,,,,,,,,, -37300,0.16629647,0.02872264,,,,,,,,,,,,,,,,, -37400,0.12538166,0.02616758,,,,,,,,,,,,,,,,, -37500,0.1194111,0.026239544,,,,,,,,,,,,,,,,, -37588,,,0.9923935532569884,0.0240364968776702,0.550742001196731,0.986711084842682,0.0462623983621597,0.2691423838406784,43793.0,0.985876441001892,0.0493538789451122,0.2592533850321455,43793.0,12016.22732925415,17652.513543844223,12016.22732925415,5633.303159475327,1.9569110870361328,0.0 -37600,0.11160369,0.028069524,,,,,,,,,,,,,,,,, -37700,0.09764793,0.024383668,,,,,,,,,,,,,,,,, -37800,0.15691324,0.02625672,,,,,,,,,,,,,,,,, -37900,0.13072045,0.024381904,,,,,,,,,,,,,,,,, -38000,0.11132943,0.026143061,,,,,,,,,,,,,,,,, -38100,0.116699606,0.027033607,,,,,,,,,,,,,,,,, -38200,0.11972089,0.02632819,,,,,,,,,,,,,,,,, -38300,0.11309234,0.02424325,,,,,,,,,,,,,,,,, -38335,,,0.9923391938209534,0.0244967173784971,0.5252410837212944,0.986572265625,0.0459901019930839,0.2755666996524653,43793.0,0.985694706439972,0.0492275729775428,0.264713702862164,43793.0,12256.457571268082,18004.159906864166,12256.457571268082,5744.66099023819,1.9948585033416748,0.0 -38400,0.164798,0.024609743,,,,,,,,,,,,,,,,, -38500,0.1373905,0.021495387,,,,,,,,,,,,,,,,, -38600,0.11828345,0.027535353,,,,,,,,,,,,,,,,, -38700,0.11855073,0.022374269,,,,,,,,,,,,,,,,, -38800,0.102344655,0.022737982,,,,,,,,,,,,,,,,, -38900,0.11224851,0.0249236,,,,,,,,,,,,,,,,, -39000,0.11732254,0.02328035,,,,,,,,,,,,,,,,, -39088,,,0.9921321868896484,0.0250871740281581,0.5257421666684701,0.9865466952323914,0.0463662147521972,0.2729468275313599,43793.0,0.985687792301178,0.04935147985816,0.2622586007059652,43793.0,12496.689729452131,18347.39897227288,12496.689729452131,5847.613451719284,2.028700351715088,0.0 -39100,0.111568,0.024112668,,,,,,,,,,,,,,,,, -39200,0.10919571,0.0266755,,,,,,,,,,,,,,,,, -39300,0.13041551,0.028467562,,,,,,,,,,,,,,,,, -39400,0.12656307,0.024999497,,,,,,,,,,,,,,,,, -39500,0.13331045,0.024074804,,,,,,,,,,,,,,,,, -39600,0.14435485,0.025719633,,,,,,,,,,,,,,,,, -39700,0.15111578,0.022359561,,,,,,,,,,,,,,,,, -39800,0.13947028,0.025359644,,,,,,,,,,,,,,,,, -39848,,,0.992060124874115,0.0253172293305397,0.5124504340697817,0.9864711761474608,0.0464315824210643,0.2747143159924758,43793.0,0.985620379447937,0.0494289062917232,0.2571113858861016,43793.0,12736.806010484695,18695.944691181183,12736.806010484695,5955.988241195679,2.06325101852417,0.0 -39900,0.11385319,0.02473904,,,,,,,,,,,,,,,,, -40000,0.17131834,0.025870001,,,,,,,,,,,,,,,,, -40100,0.14648142,0.025539912,,,,,,,,,,,,,,,,, -40200,0.15754192,0.027267808,,,,,,,,,,,,,,,,, -40300,0.1510589,0.024362596,,,,,,,,,,,,,,,,, -40400,0.114310555,0.022913676,,,,,,,,,,,,,,,,, -40500,0.124557756,0.026979145,,,,,,,,,,,,,,,,, -40588,,,0.9921411871910096,0.0250214058905839,0.5354935396143286,0.9864374995231628,0.0461788550019264,0.2703810326860346,43793.0,0.9857370257377625,0.0491687096655368,0.2609491486758925,43793.0,12977.04095864296,19038.35589647293,12977.04095864296,6058.107916593552,2.097200632095337,0.0 -40600,0.13127768,0.026719099,,,,,,,,,,,,,,,,, -40700,0.12187339,0.023234036,,,,,,,,,,,,,,,,, -40800,0.1487418,0.02471199,,,,,,,,,,,,,,,,, -40900,0.11327993,0.025291726,,,,,,,,,,,,,,,,, -41000,0.11546061,0.024000958,,,,,,,,,,,,,,,,, -41100,0.11387761,0.022541944,,,,,,,,,,,,,,,,, -41200,0.13212201,0.0227324,,,,,,,,,,,,,,,,, -41300,0.13584355,0.026379507,,,,,,,,,,,,,,,,, -41348,,,0.992215633392334,0.0246634781360626,0.5279078039982317,0.9865621328353882,0.046393159776926,0.2734804765352344,43793.0,0.9857257008552552,0.0491464734077453,0.2626485602912387,43793.0,13217.228226184843,19386.606401205063,13217.228226184843,6166.109060764313,2.139084815979004,0.0 -41400,0.13013335,0.02666385,,,,,,,,,,,,,,,,, -41500,0.13312295,0.025633065,,,,,,,,,,,,,,,,, -41600,0.124119796,0.023086868,,,,,,,,,,,,,,,,, -41700,0.124279015,0.022445338,,,,,,,,,,,,,,,,, -41800,0.12356909,0.021686496,,,,,,,,,,,,,,,,, -41900,0.10431718,0.025123984,,,,,,,,,,,,,,,,, -42000,0.11993185,0.023767153,,,,,,,,,,,,,,,,, -42093,,,0.9923676252365112,0.0241897907108068,0.5447008755982579,0.9866648316383362,0.0466361306607723,0.2701977304227742,43793.0,0.9857661128044128,0.0498609170317649,0.2591178976031317,43793.0,13457.440511703491,19737.170392751694,13457.440511703491,6276.40532040596,2.1731395721435547,0.0 -42100,0.13766916,0.024966128,,,,,,,,,,,,,,,,, -42200,0.12362029,0.025118383,,,,,,,,,,,,,,,,, -42300,0.119845666,0.023997756,,,,,,,,,,,,,,,,, -42400,0.12144512,0.023709249,,,,,,,,,,,,,,,,, -42500,0.13480917,0.021585556,,,,,,,,,,,,,,,,, -42600,0.13566239,0.024323674,,,,,,,,,,,,,,,,, -42700,0.14701766,0.024072917,,,,,,,,,,,,,,,,, -42800,0.14617711,0.024058655,,,,,,,,,,,,,,,,, -42853,,,0.9926583170890808,0.0231785699725151,0.5682554598312384,0.9865312576293944,0.0467051304876804,0.2762466894173453,43793.0,0.985605239868164,0.0498588271439075,0.264570208509632,43793.0,13697.419059038162,20088.93888497353,13697.419059038162,6388.139190912247,2.2093663215637207,0.0 -42900,0.16447535,0.02675493,,,,,,,,,,,,,,,,, -43000,0.16336766,0.024568968,,,,,,,,,,,,,,,,, -43100,0.12301532,0.025683666,,,,,,,,,,,,,,,,, -43200,0.13748686,0.024041336,,,,,,,,,,,,,,,,, -43300,0.13679145,0.022590104,,,,,,,,,,,,,,,,, -43400,0.14359443,0.023091171,,,,,,,,,,,,,,,,, -43500,0.12562172,0.023860926,,,,,,,,,,,,,,,,, -43600,0.12924956,0.024636695,,,,,,,,,,,,,,,,, -43608,,,0.9930390119552612,0.0220589619129896,0.610813742315912,0.986449658870697,0.0466724410653114,0.277682788359816,43793.0,0.9856485724449158,0.0495678298175334,0.2643410263914481,43793.0,13937.618394374847,20431.984637498856,13937.618394374847,6490.929327249527,2.244922876358032,0.0 -43700,0.1412249,0.0236886,,,,,,,,,,,,,,,,, -43800,0.14702265,0.027299881,,,,,,,,,,,,,,,,, -43900,0.12652765,0.024161082,,,,,,,,,,,,,,,,, -44000,0.16531153,0.024775757,,,,,,,,,,,,,,,,, -44100,0.13470751,0.024534844,,,,,,,,,,,,,,,,, -44200,0.12202839,0.022940515,,,,,,,,,,,,,,,,, -44300,0.13298187,0.021430003,,,,,,,,,,,,,,,,, -44364,,,0.9931018352508544,0.0217983778566122,0.5935135327960487,0.9863875508308412,0.0471132732927799,0.2702936182597056,43793.0,0.9856389164924622,0.0498519428074359,0.2585699227760166,43793.0,14177.737513303757,20777.149724960327,14177.737513303757,6595.918229103088,2.280789852142334,0.0 -44400,0.16960178,0.027203912,,,,,,,,,,,,,,,,, -44500,0.13392118,0.023749694,,,,,,,,,,,,,,,,, -44600,0.12845087,0.021369874,,,,,,,,,,,,,,,,, -44700,0.120907456,0.023625879,,,,,,,,,,,,,,,,, -44800,0.13481444,0.024823286,,,,,,,,,,,,,,,,, -44900,0.12976708,0.021361511,,,,,,,,,,,,,,,,, -45000,0.15428726,0.024731347,,,,,,,,,,,,,,,,, -45100,0.15365851,0.026300812,,,,,,,,,,,,,,,,, -45119,,,0.9935007691383362,0.0207743290811777,0.6282470077284659,0.986407458782196,0.0473083592951297,0.2746351986985576,43793.0,0.9856823086738586,0.0505120158195495,0.2607793049964524,43793.0,14417.723826646805,21122.0652821064,14417.723826646805,6700.786629199982,2.320936918258667,0.0 -45200,0.14199375,0.021081373,,,,,,,,,,,,,,,,, -45300,0.13985237,0.023611462,,,,,,,,,,,,,,,,, -45400,0.12342511,0.022027165,,,,,,,,,,,,,,,,, -45500,0.16798797,0.024003234,,,,,,,,,,,,,,,,, -45600,0.13447118,0.021751026,,,,,,,,,,,,,,,,, -45700,0.15976296,0.024267735,,,,,,,,,,,,,,,,, -45800,0.1562958,0.022552976,,,,,,,,,,,,,,,,, -45872,,,0.993179976940155,0.021613860502839,0.5867398796278861,0.9865272045135498,0.0474358759820461,0.270803415758193,43793.0,0.9856418371200562,0.0508385933935642,0.2581153166498968,43793.0,14657.829684019089,21465.00151848793,14657.829684019089,6803.560876607895,2.355963468551636,0.0 -45900,0.16530009,0.023398897,,,,,,,,,,,,,,,,, -46000,0.12333356,0.022684483,,,,,,,,,,,,,,,,, -46100,0.14591876,0.023858428,,,,,,,,,,,,,,,,, -46200,0.12999152,0.018735383,,,,,,,,,,,,,,,,, -46300,0.15172893,0.024908856,,,,,,,,,,,,,,,,, -46400,0.15506482,0.021645222,,,,,,,,,,,,,,,,, -46500,0.16518393,0.020629793,,,,,,,,,,,,,,,,, -46600,0.15516807,0.023332965,,,,,,,,,,,,,,,,, -46624,,,0.9928723573684692,0.0224040485918521,0.5792297219554194,0.9865121841430664,0.0478705167770385,0.2754371891996536,43793.0,0.9857025146484376,0.051041230559349,0.2624502477888986,43793.0,14897.901216506958,21809.08435201645,14897.901216506958,6907.516575574875,2.390967845916748,0.0 -46700,0.12975226,0.022618512,,,,,,,,,,,,,,,,, -46800,0.13447419,0.022270191,,,,,,,,,,,,,,,,, -46900,0.14151008,0.02252225,,,,,,,,,,,,,,,,, -47000,0.144242,0.021900047,,,,,,,,,,,,,,,,, -47100,0.16925849,0.024692044,,,,,,,,,,,,,,,,, -47200,0.15780902,0.022040192,,,,,,,,,,,,,,,,, -47300,0.15511659,0.024329172,,,,,,,,,,,,,,,,, -47381,,,0.9928503632545472,0.0223389249294996,0.5829663625984829,0.9865494966506958,0.0480351597070694,0.2717624907753252,43793.0,0.9857825636863708,0.0509413219988346,0.2635603124742287,43793.0,15137.97015786171,22153.10691165924,15137.97015786171,7011.414033174515,2.427185297012329,0.0 -47400,0.15494558,0.022227526,,,,,,,,,,,,,,,,, -47500,0.18015188,0.025176471,,,,,,,,,,,,,,,,, -47600,0.15434813,0.020384016,,,,,,,,,,,,,,,,, -47700,0.1529346,0.020674307,,,,,,,,,,,,,,,,, -47800,0.15573569,0.022724649,,,,,,,,,,,,,,,,, -47900,0.170884,0.02518358,,,,,,,,,,,,,,,,, -48000,0.13257363,0.02032955,,,,,,,,,,,,,,,,, -48100,0.1401221,0.021324866,,,,,,,,,,,,,,,,, -48134,,,0.9928312301635742,0.0224734991788864,0.581797263235984,0.9864045977592468,0.0480904653668403,0.2763301327778938,43793.0,0.9855150580406188,0.0513287670910358,0.2564390795803031,43793.0,15378.18701696396,22503.418353796005,15378.18701696396,7121.453207015991,2.461944580078125,0.0 -48200,0.15978824,0.02209176,,,,,,,,,,,,,,,,, -48300,0.15417987,0.021225343,,,,,,,,,,,,,,,,, -48400,0.18201159,0.022097426,,,,,,,,,,,,,,,,, -48500,0.16306275,0.021691613,,,,,,,,,,,,,,,,, -48600,0.16634531,0.024615506,,,,,,,,,,,,,,,,, -48700,0.15179536,0.021731501,,,,,,,,,,,,,,,,, -48800,0.15830104,0.020207861,,,,,,,,,,,,,,,,, -48888,,,0.9929319024086,0.022101666778326,0.5932613878542279,0.9862564206123352,0.0482456870377063,0.2651477654096575,43793.0,0.9853625893592834,0.0514854378998279,0.255840481288772,43793.0,15618.224890708923,22853.219148159027,15618.224890708923,7231.157967567444,2.4997594356536865,0.0 -48900,0.23653616,0.02671538,,,,,,,,,,,,,,,,, -49000,0.16424996,0.021483572,,,,,,,,,,,,,,,,, -49100,0.16800202,0.022675684,,,,,,,,,,,,,,,,, -49200,0.14950608,0.021241328,,,,,,,,,,,,,,,,, -49300,0.1674796,0.023057848,,,,,,,,,,,,,,,,, -49400,0.16319296,0.021297112,,,,,,,,,,,,,,,,, -49500,0.15156122,0.020367214,,,,,,,,,,,,,,,,, -49600,0.1430372,0.020811655,,,,,,,,,,,,,,,,, -49643,,,0.9931454062461852,0.0213271249085664,0.5954668398015054,0.9864569902420044,0.0487939976155757,0.2757709988394968,43793.0,0.9856106638908386,0.0522529184818267,0.2538569310413312,43793.0,15857.749560117722,23197.264449834824,15857.749560117722,7335.19739151001,2.960538148880005,0.0 -49700,0.1508891,0.0202528,,,,,,,,,,,,,,,,, -49800,0.16912623,0.022335192,,,,,,,,,,,,,,,,, -49900,0.19982857,0.021927824,,,,,,,,,,,,,,,,, -50000,0.17098358,0.0214459,,,,,,,,,,,,,,,,, -50100,0.21808873,0.021898467,,,,,,,,,,,,,,,,, -50200,0.16713145,0.020688992,,,,,,,,,,,,,,,,, -50300,0.17432459,0.019383363,,,,,,,,,,,,,,,,, -50393,,,0.9933032393455504,0.0206730123609304,0.6219227461660122,0.986491084098816,0.0492160953581333,0.271150679389319,43793.0,0.985623300075531,0.0526586323976516,0.2555738823624148,43793.0,16097.697563171389,23546.42974758148,16097.697563171389,7444.359100818634,2.996345281600952,0.0 -50400,0.14779559,0.01957625,,,,,,,,,,,,,,,,, -50500,0.2207782,0.023409346,,,,,,,,,,,,,,,,, -50600,0.16628462,0.021987392,,,,,,,,,,,,,,,,, -50700,0.16225109,0.022141151,,,,,,,,,,,,,,,,, -50800,0.2007335,0.020907553,,,,,,,,,,,,,,,,, -50900,0.16163929,0.020123873,,,,,,,,,,,,,,,,, -51000,0.16488974,0.019711304,,,,,,,,,,,,,,,,, -51100,0.16393971,0.020410463,,,,,,,,,,,,,,,,, -51151,,,0.9937353134155272,0.0196653474122285,0.649720284927805,0.9862414002418518,0.0492569468915462,0.2753046787889687,43793.0,0.9853950142860411,0.0527043342590332,0.2526182830014674,43793.0,16337.699612855911,23888.542701005936,16337.699612855911,7546.4050052165985,3.041268825531006,0.0 -51200,0.14239824,0.017369324,,,,,,,,,,,,,,,,, -51300,0.1925156,0.019731207,,,,,,,,,,,,,,,,, -51400,0.17883244,0.02051295,,,,,,,,,,,,,,,,, -51500,0.15325557,0.021199103,,,,,,,,,,,,,,,,, -51600,0.17682414,0.02043788,,,,,,,,,,,,,,,,, -51700,0.17506513,0.019187713,,,,,,,,,,,,,,,,, -51800,0.16619001,0.021043101,,,,,,,,,,,,,,,,, -51900,0.17574091,0.020595293,,,,,,,,,,,,,,,,, -51910,,,0.9942708015441896,0.0178426392376422,0.6942757210473101,0.9864622354507446,0.0494240075349807,0.2712660328003022,43793.0,0.98554664850235,0.0531106442213058,0.2518272298354042,43793.0,16577.81442141533,24236.79912090301,16577.81442141533,7654.490574836731,3.076837539672852,0.0 -52000,0.16486195,0.020086305,,,,,,,,,,,,,,,,, -52100,0.18429518,0.02152332,,,,,,,,,,,,,,,,, -52200,0.17237617,0.018818356,,,,,,,,,,,,,,,,, -52300,0.15159002,0.018071938,,,,,,,,,,,,,,,,, -52400,0.19561747,0.017718975,,,,,,,,,,,,,,,,, -52500,0.188838,0.020221546,,,,,,,,,,,,,,,,, -52600,0.20074739,0.018507456,,,,,,,,,,,,,,,,, -52662,,,0.994036078453064,0.018522597849369,0.6743744810593322,0.9860635995864868,0.0501335337758064,0.266329344424858,43793.0,0.9852699637413024,0.05378008633852,0.2500704661699077,43793.0,16817.855093955994,24586.30546784401,16817.855093955994,7763.898961544037,3.1142618656158447,0.0 -52700,0.18369378,0.020766607,,,,,,,,,,,,,,,,, -52800,0.14959896,0.018357255,,,,,,,,,,,,,,,,, -52900,0.18991661,0.021855863,,,,,,,,,,,,,,,,, -53000,0.16704653,0.018628156,,,,,,,,,,,,,,,,, -53100,0.18559253,0.017009333,,,,,,,,,,,,,,,,, -53200,0.18265188,0.017542914,,,,,,,,,,,,,,,,, -53300,0.20863143,0.021000758,,,,,,,,,,,,,,,,, -53400,0.20309775,0.019943366,,,,,,,,,,,,,,,,, -53406,,,0.9941920042037964,0.0179198160767555,0.6768201170825074,0.986491858959198,0.050784520804882,0.2638165742392505,43793.0,0.985526442527771,0.0545035935938358,0.2508009860190624,43793.0,17057.9955971241,24933.515946626663,17057.9955971241,7870.911877632141,3.1511764526367188,0.0 -53500,0.17608604,0.020775566,,,,,,,,,,,,,,,,, -53600,0.19802086,0.021204052,,,,,,,,,,,,,,,,, -53700,0.17071405,0.017456166,,,,,,,,,,,,,,,,, -53800,0.19478032,0.020703377,,,,,,,,,,,,,,,,, -53900,0.20239955,0.020637445,,,,,,,,,,,,,,,,, -54000,0.19852336,0.020766586,,,,,,,,,,,,,,,,, -54100,0.20194247,0.021267124,,,,,,,,,,,,,,,,, -54155,,,0.9939531683921814,0.0188170745968818,0.660413799421119,0.9861224889755248,0.0512725599110126,0.2609240909249675,43793.0,0.9853095412254332,0.0547344386577606,0.2534831283282912,43793.0,17298.101543664932,25281.40832829476,17298.101543664932,7978.640509128571,3.1888413429260254,0.0 -54200,0.21025446,0.019611962,,,,,,,,,,,,,,,,, -54300,0.17088749,0.01673612,,,,,,,,,,,,,,,,, -54400,0.19701041,0.019092582,,,,,,,,,,,,,,,,, -54500,0.2333262,0.018640745,,,,,,,,,,,,,,,,, -54600,0.19809434,0.019745672,,,,,,,,,,,,,,,,, -54700,0.20268591,0.021513533,,,,,,,,,,,,,,,,, -54800,0.22537586,0.018746844,,,,,,,,,,,,,,,,, -54900,0.21358216,0.020014765,,,,,,,,,,,,,,,,, -54910,,,0.9937155842781068,0.0192698650062084,0.6451781842926773,0.9864431619644164,0.0515469796955585,0.263833178461923,43793.0,0.9854388236999512,0.0555348582565784,0.2517655666803668,43793.0,17538.212922811508,25626.154417037964,17538.212922811508,8083.218180656433,3.224980354309082,0.0 -55000,0.25068614,0.018680707,,,,,,,,,,,,,,,,, -55100,0.20019093,0.018308036,,,,,,,,,,,,,,,,, -55200,0.19046815,0.018446527,,,,,,,,,,,,,,,,, -55300,0.2194151,0.019099494,,,,,,,,,,,,,,,,, -55400,0.17969105,0.016672134,,,,,,,,,,,,,,,,, -55500,0.20238104,0.020607738,,,,,,,,,,,,,,,,, -55600,0.19124965,0.017958421,,,,,,,,,,,,,,,,, -55667,,,0.9937472343444824,0.0190456174314022,0.666710952563394,0.9862653613090516,0.0520232208073139,0.2650312759193009,43793.0,0.9853293299674988,0.0559036470949649,0.2508333407109575,43793.0,17778.314692020416,25972.37487721443,17778.314692020416,8189.279809951782,3.261626958847046,0.0 -55700,0.21659128,0.01780956,,,,,,,,,,,,,,,,, -55800,0.21080603,0.019384284,,,,,,,,,,,,,,,,, -55900,0.17786063,0.016453134,,,,,,,,,,,,,,,,, -56000,0.2180359,0.020533817,,,,,,,,,,,,,,,,, -56100,0.21469492,0.01814537,,,,,,,,,,,,,,,,, -56200,0.23408139,0.018635033,,,,,,,,,,,,,,,,, -56300,0.22136165,0.014911694,,,,,,,,,,,,,,,,, -56400,0.19714266,0.016029682,,,,,,,,,,,,,,,,, -56428,,,0.9937534928321838,0.0190392918884754,0.653744196203391,0.9862361550331116,0.0522097609937191,0.2656088692381051,43793.0,0.9853870272636414,0.0560189709067344,0.2533901667486278,43793.0,18018.523869276047,26320.63973712921,18018.523869276047,8297.278610229492,3.2980382442474365,0.0 -56500,0.20265043,0.018596474,,,,,,,,,,,,,,,,, -56600,0.22479953,0.016931046,,,,,,,,,,,,,,,,, -56700,0.2115247,0.017407484,,,,,,,,,,,,,,,,, -56800,0.21537693,0.016403895,,,,,,,,,,,,,,,,, -56900,0.19630647,0.016214643,,,,,,,,,,,,,,,,, -57000,0.21703148,0.01666189,,,,,,,,,,,,,,,,, -57100,0.22520971,0.018191759,,,,,,,,,,,,,,,,, -57170,,,0.993995189666748,0.0181292034685611,0.6793042671348181,0.986067235469818,0.0530648082494735,0.259084388158957,43793.0,0.9852640628814696,0.0566945672035217,0.2471423890335668,43793.0,18258.58660340309,26671.117931365967,18258.58660340309,8407.633393764496,3.338200807571411,0.0 -57200,0.21258417,0.01583249,,,,,,,,,,,,,,,,, -57300,0.21638983,0.016934017,,,,,,,,,,,,,,,,, -57400,0.21658994,0.015880916,,,,,,,,,,,,,,,,, -57500,0.21564665,0.016746657,,,,,,,,,,,,,,,,, -57600,0.22056952,0.016352933,,,,,,,,,,,,,,,,, -57700,0.22676784,0.017997239,,,,,,,,,,,,,,,,, -57800,0.2595732,0.016917225,,,,,,,,,,,,,,,,, -57853,,,,,,,,,,,,,,18477.1279194355,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 5f111d106..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -109.55257105827332,0.0,11.894293308258057,1,0,11.894293308258057,0.5256850719451904,0.7376683354377747,0.0260457203743644,43793,121.44691801071168,0.5288742184638977,0.7363438606262207,0.0209708392711852,0.5270814895629883,0.7374407649040222,0.0240884203701428,43793 -215.14411735534668,0.0215587615966796,251.9131505489349,743,0,251.9131505489349,0.9831504821777344,0.066874660551548,0.0423844425126298,43793,467.0992650985718,0.9868398904800416,0.0536873005330562,0.0423955400895439,0.9841423034667968,0.0637450590729713,0.0409373969587783,43793 -320.01709842681885,0.049159288406372,492.1413643360138,1499,0,492.1413643360138,0.9832835793495178,0.0638415142893791,0.0626250946805319,43793,812.2493708133698,0.9869152903556824,0.050368096679449,0.0641939864066242,0.9842705726623536,0.0605609454214572,0.0625524167338166,43793 -431.74619340896606,0.0765466690063476,732.1600904464722,2256,0,732.1600904464722,0.9835636615753174,0.0608636774122715,0.0900609695372506,43793,1164.0452189445496,0.9873570799827576,0.0469028390944004,0.0943435755349976,0.9845478534698486,0.0576092824339866,0.0900378077916038,43793 -534.4365320205688,0.1034972667694091,972.113499403,3014,0,972.113499403,0.9837780594825744,0.0581569485366344,0.1073112922785335,43793,1506.7364542484283,0.987488567829132,0.0455315560102462,0.1164872494530039,0.9847686290740968,0.0550459511578083,0.1060626290043858,43793 -642.7269651889801,0.1341676712036132,1212.357487201691,3771,0,1212.357487201691,0.9840333461761476,0.0553927794098854,0.1286881339106711,43793,1855.3214933872223,0.9876707196235656,0.0435436479747295,0.1393367840583442,0.9849566221237184,0.0524537749588489,0.1315201944459623,43793 -747.282881975174,0.161816120147705,1452.5180218219757,4529,0,1452.5180218219757,0.9841015338897704,0.0561741665005683,0.1380425379492896,43793,2200.0856885910034,0.9878915548324584,0.0427244156599044,0.1588165990742979,0.9850808382034302,0.0530340410768985,0.1406192140263638,43793 -856.7676658630371,0.1903581619262695,1692.6388084888458,5280,0,1692.6388084888458,0.983956217765808,0.054508913308382,0.1441801118123625,43793,2549.740065097809,0.9880184531211852,0.0421564318239688,0.1746008317249671,0.9848993420600892,0.0518106520175933,0.1456738089220982,43793 -961.2165246009828,0.2178926467895507,1932.7112345695496,6032,0,1932.7112345695496,0.9843584895133972,0.0538691580295562,0.1498930955587417,43793,2894.309236764908,0.9882644414901732,0.041140217334032,0.1706003199822776,0.985255002975464,0.0511212050914764,0.1506775152204858,43793 -1067.2840287685394,0.2464246749877929,2172.968799352646,6787,0,2172.968799352646,0.9844317436218262,0.0527786277234554,0.1560394778303248,43793,3240.6831436157227,0.9882400631904602,0.0409719496965408,0.1882719291767858,0.9853280186653136,0.050181183964014,0.1575803344208639,43793 -1173.091017961502,0.2735772132873535,2412.973567724228,7541,0,2412.973567724228,0.9845829606056212,0.0522934198379516,0.163858113154755,43793,3586.542414188385,0.9883561730384828,0.0403928533196449,0.1836572749367792,0.9855983853340148,0.0495022647082805,0.1682590390890557,43793 -1280.0739023685455,0.3043797016143799,2653.181145429611,8290,0,2653.181145429611,0.9844376444816588,0.052938163280487,0.1628862646202331,43793,3933.7839448452,0.9883766174316406,0.040410179644823,0.186537344706435,0.985425055027008,0.0500553324818611,0.1676451153843103,43793 -1382.3609237670898,0.3331387042999267,2893.182389497757,9043,0,2893.182389497757,0.9845362305641174,0.0523326396942138,0.1688662002125683,43793,4276.121250391007,0.9884548783302308,0.0400790758430957,0.2033323095092237,0.985439658164978,0.0495700985193252,0.1732071321595093,43793 -1487.1310422420502,0.3607993125915527,3133.343291997909,9800,0,3133.343291997909,0.9847678542137146,0.0524194948375225,0.1706226099469405,43793,4621.099257946014,0.9884865283966064,0.039673525840044,0.2058117402778865,0.9856600761413574,0.0494489483535289,0.1700376509465785,43793 -1594.379231929779,0.3902661800384521,3373.4267842769623,10563,0,3373.4267842769623,0.9847211241722108,0.052160020917654,0.1650606484348421,43793,4968.480097293854,0.9885579347610474,0.0394711606204509,0.1919396949243059,0.9855850338935852,0.049318790435791,0.1746983318028851,43793 -1696.071894645691,0.4180216789245605,3613.5157945156097,11336,0,3613.5157945156097,0.9846722483634948,0.0525008179247379,0.1774870521222304,43793,5310.309653043747,0.9885560870170592,0.0393308736383914,0.2063082084535478,0.9856479167938232,0.0493979267776012,0.1842653055696553,43793 -1801.2887377738953,0.4466302394866943,3853.641400814056,12103,0,3853.641400814056,0.9847754836082458,0.0514159128069877,0.1753369128481019,43793,5655.700120210648,0.988620102405548,0.0392267331480979,0.2103504437889654,0.9856515526771544,0.0486785285174846,0.1772119822089247,43793 -1902.346410989761,0.4743480682373047,4093.658228158951,12866,0,4093.658228158951,0.984644055366516,0.0521092340350151,0.164051053742787,43793,5996.82173871994,0.988673210144043,0.0391677431762218,0.2106503490344172,0.985570788383484,0.0492812134325504,0.1662361976664409,43793 -2010.393649101257,0.5054383277893066,4333.704358816147,13629,0,4333.704358816147,0.9846861958503724,0.0520230717957019,0.1755181138502504,43793,6344.965814590454,0.98865807056427,0.0386559441685676,0.2166301543980013,0.9856945872306824,0.0490887686610221,0.1791409499424729,43793 -2116.118327856064,0.5356552600860596,4573.699645042419,14389,0,4573.699645042419,0.9843934178352356,0.0527944862842559,0.1688209091514003,43793,6690.735986948013,0.9884803295135498,0.0401361621916294,0.1923698670874865,0.9853804111480712,0.0498729534447193,0.1726102830664344,43793 -2226.663529396057,0.5656332969665527,4813.877220630646,15152,0,4813.877220630646,0.984649121761322,0.0525808483362197,0.1737742618310206,43793,7041.50847029686,0.9884754419326782,0.0394728332757949,0.201157464733156,0.9856117963790894,0.0496006943285465,0.1733149277449493,43793 -2326.930092334748,0.5987019538879395,5054.121758937836,15903,0,5054.121758937836,0.9847834706306458,0.051879771053791,0.1804877903662326,43793,7382.07443356514,0.9885097742080688,0.039356917142868,0.2035128769114544,0.9857429265975952,0.0487789139151573,0.1836194772418177,43793 -2433.7707934379578,0.6277382373809814,5294.265455007553,16665,0,5294.265455007553,0.9848668575286864,0.0515738539397716,0.1752540595629578,43793,7729.107240438461,0.9887725710868835,0.0389116667211055,0.2065929856554009,0.9857416749000548,0.0486304834485054,0.1823984798059077,43793 -2541.077990293503,0.6633737087249756,5534.256805181503,17425,0,5534.256805181503,0.9847135543823242,0.0520097613334655,0.1743933827146704,43793,8076.462526798248,0.9885989427566528,0.039196029305458,0.2020360716367713,0.9855594038963318,0.0490655675530433,0.1677639326494566,43793 -2646.999152421952,0.6947915554046631,5774.4022517204285,18183,0,5774.4022517204285,0.9847097396850586,0.0518496893346309,0.1760573149142341,43793,8422.579911470413,0.9885631799697876,0.0392050370573997,0.2111582382971683,0.9856604933738708,0.0490625202655792,0.1755662687161741,43793 -2751.677888393402,0.729445219039917,6014.416128873825,18935,0,6014.416128873825,0.9846933484077454,0.0519953034818172,0.1730764760968453,43793,8767.328412532806,0.9886395931243896,0.0389874204993248,0.2070608752482068,0.9856272339820862,0.0490880906581878,0.1780778120628973,43793 -2852.4364881515503,0.760200023651123,6254.527894496918,19686,0,6254.527894496918,0.984761118888855,0.0513607002794742,0.1776437460243953,43793,9108.249343633652,0.9886992573738098,0.0385549515485763,0.2199436476022474,0.9856789708137512,0.0485288426280021,0.1826200142812672,43793 -2958.607923746109,0.7917554378509521,6494.588192939758,20446,0,6494.588192939758,0.984872341156006,0.0517776682972908,0.178504747280965,43793,9454.533122062683,0.9888023138046264,0.0385515615344047,0.2108082699550368,0.9857754111289978,0.0487221740186214,0.1826559964014653,43793 -3062.7118458747864,0.8235526084899902,6734.835546016693,21201,0,6734.835546016693,0.9848698377609252,0.0513353645801544,0.182898133811869,43793,9798.936812639236,0.988953709602356,0.0377000719308853,0.230811797165299,0.9858095049858092,0.048275701701641,0.187712481522891,43793 -3165.2085721492767,0.85434889793396,6975.096371412277,21957,0,6975.096371412277,0.9847274422645568,0.0518092028796672,0.1798832976796487,43793,10141.745084524156,0.988852560520172,0.0381752923130989,0.2194745163208711,0.985712468624115,0.0487480387091636,0.1864697991138633,43793 -3270.7840468883514,0.8855433464050293,7215.198048114777,22721,0,7215.198048114777,0.9849422574043274,0.051137737929821,0.1806988098203659,43793,10487.473489522934,0.9887853860855104,0.0382936522364616,0.225787219799659,0.9858614802360536,0.0484032183885574,0.1871781066362978,43793 -3377.4589943885803,0.9166724681854248,7455.259896755218,23481,0,7455.259896755218,0.9848819971084596,0.0512367002665996,0.1817081629190941,43793,10834.261020421982,0.9887216091156006,0.0385473296046257,0.2161478474097,0.9858009815216064,0.0483117811381816,0.1869150763786361,43793 -3478.77467918396,0.9493060111999512,7695.373823165894,24238,0,7695.373823165894,0.9849464893341064,0.0510138981044292,0.1909638553261999,43793,11175.742981672289,0.9888251423835754,0.038296639919281,0.2175519109860235,0.9858407378196716,0.0481756515800952,0.190224826472304,43793 -3581.462842464447,0.980283260345459,7935.566670894623,24996,0,7935.566670894623,0.9849563837051392,0.0509523451328277,0.1876386925962422,43793,11518.674816608427,0.9887945652008056,0.0382853634655475,0.2241156935598468,0.9858716130256652,0.0480336621403694,0.1920055872143085,43793 -3687.1374304294586,1.0166258811950684,8175.762567043304,25759,0,8175.762567043304,0.9849376082420348,0.0512047596275806,0.1794732225786249,43793,11864.60201716423,0.9889598488807678,0.0382597297430038,0.2139574837879134,0.9858368635177612,0.0482598207890987,0.1868899929087874,43793 -3790.8824088573456,1.0475599765777588,8415.960547924042,26525,0,8415.960547924042,0.9847586154937744,0.0519197210669517,0.1759232424762905,43793,12208.59658241272,0.9887128472328186,0.0387284010648727,0.212833312132566,0.9856792092323304,0.048808928579092,0.1811582477184639,43793 -3899.773766040802,1.078413486480713,8655.956233978271,27288,0,8655.956233978271,0.9847981929779052,0.0514655895531177,0.1775487268339715,43793,12557.534547328947,0.9887645244598388,0.0383030474185943,0.2247528497563965,0.9856621623039246,0.0485771968960762,0.183385786420843,43793 -4005.359260320664,1.1128215789794922,8896.014830827713,28035,0,8896.014830827713,0.9848756790161132,0.0512337498366832,0.1824875723104822,43793,12903.236117601396,0.9887883067131042,0.0383867248892784,0.2266030029702564,0.9857709407806396,0.0483795665204525,0.1852712474437164,43793 -4108.936638832092,1.145425796508789,9136.247852563858,28789,0,9136.247852563858,0.9850125908851624,0.051252357661724,0.1844789854736912,43793,13247.099108457563,0.9888795018196106,0.0379779003560543,0.2251258474426496,0.9859036803245544,0.0483595877885818,0.1953504302792705,43793 -4209.517950057983,1.1778452396392822,9376.250790834429,29556,0,9376.250790834429,0.9849578142166138,0.0510977432131767,0.1866836963508372,43793,13587.735835075378,0.9890247583389282,0.0373383872210979,0.2374429432051383,0.9858587980270386,0.0481251999735832,0.1895831564274472,43793 -4313.439270734787,1.2094299793243408,9616.483031272888,30318,0,9616.483031272888,0.9849645495414734,0.0513634011149406,0.1865411181259032,43793,13931.94085597992,0.9889353513717652,0.0375845544040203,0.2431385472982229,0.9859227538108826,0.0482968352735042,0.1986930130203873,43793 -4418.80347275734,1.241889476776123,9856.629340171814,31076,0,9856.629340171814,0.9844574332237244,0.051728893071413,0.1833039313442641,43793,14277.50398659706,0.9885992407798768,0.0386416278779506,0.2223428213165376,0.9853118062019348,0.0490008220076561,0.1907836688823532,43793 -4527.166420221329,1.280625581741333,10096.608315229416,31840,0,10096.608315229416,0.9849127531051636,0.0506469495594501,0.1879387186659543,43793,14625.904658794405,0.9889352321624756,0.0378828234970569,0.2303038092934673,0.9858736395835876,0.047849740833044,0.1935717885093315,43793 -4634.336527824402,1.3167879581451416,10336.679065942764,32593,0,10336.679065942764,0.9849587082862854,0.0513384491205215,0.1829017924658921,43793,14973.201986551285,0.9887871742248536,0.0383941009640693,0.2274084218291437,0.985841989517212,0.0483883619308471,0.1839510717008973,43793 -4737.503067016602,1.3489949703216553,10576.746666193008,33347,0,10576.746666193008,0.9849936366081238,0.0513042733073234,0.1857491169451275,43793,15316.488739967346,0.9890311360359192,0.0375706255435943,0.2313390646531625,0.9859755039215088,0.0481745265424251,0.1941218011418622,43793 -4844.645184755325,1.3820130825042725,10817.016341924667,34102,0,10817.016341924667,0.9850218892097472,0.0503968968987464,0.193202887268333,43793,15663.953579187391,0.9890766143798828,0.0374147966504097,0.2217049939529483,0.9859133958816528,0.0475398078560829,0.1953037639585747,43793 -4944.97594666481,1.415419578552246,11057.10607290268,34861,0,11057.10607290268,0.9849645495414734,0.0513733327388763,0.1845622175386213,43793,16004.42927980423,0.9888927936553956,0.0379797182977199,0.2367710387399423,0.9859304428100586,0.0483159534633159,0.1910359411764158,43793 -5048.230092048645,1.4496972560882568,11297.28503227234,35619,0,11297.28503227234,0.985063135623932,0.0505873449146747,0.1908041635108264,43793,16347.916496515274,0.9888798594474792,0.0377690196037292,0.2419741886575905,0.9859288334846495,0.0476657561957836,0.195245566004814,43793 -5151.097851753235,1.4821081161499023,11537.255279302595,36364,0,11537.255279302595,0.9849388599395752,0.0511537045240402,0.1846787561870888,43793,16690.80647611618,0.988921880722046,0.0374002493917942,0.2393158036051269,0.985909342765808,0.0482694059610366,0.1960000096477773,43793 -5252.974688053131,1.514685869216919,11777.304328680038,37122,0,11777.304328680038,0.9850353598594666,0.0502888746559619,0.185580799291037,43793,17032.784563302994,0.9891963601112366,0.0366889312863349,0.2415403166293018,0.9859312772750854,0.0474190339446067,0.1927813767994909,43793 -5356.93186378479,1.5475928783416748,12017.266623020172,37884,0,12017.266623020172,0.9851309657096864,0.0503636747598648,0.1916574145193226,43793,17376.75620341301,0.9892786145210266,0.0363476835191249,0.2639222874446258,0.9859499335289,0.0476463437080383,0.1962621683676298,43793 -5460.476921081543,1.582322359085083,12257.309017419817,38639,0,12257.309017419817,0.985082507133484,0.0503891482949256,0.1913735749887168,43793,17720.39829516411,0.989258587360382,0.0369451642036438,0.240270671925092,0.9859738945961,0.0476714670658111,0.1958262650743369,43793 -5564.655128240585,1.6185622215270996,12497.395084142683,39400,0,12497.395084142683,0.9850707054138184,0.0503598414361476,0.1969670857838198,43793,18064.71885228157,0.9892501831054688,0.0367242135107517,0.2511990354285915,0.9859767556190492,0.0474646426737308,0.2000518758374433,43793 -5667.876608371735,1.6552250385284424,12737.522000074388,40142,0,12737.522000074388,0.9850610494613647,0.0504180826246738,0.1941491034924058,43793,18408.12707543373,0.9890367984771729,0.0373102724552154,0.2378807756548885,0.9860019087791444,0.0474509969353675,0.1982368226275608,43793 -5774.92215013504,1.6892304420471191,12977.745255231855,40902,0,12977.745255231855,0.985190749168396,0.0495871491730213,0.1979586610062238,43793,18755.44972276688,0.989230453968048,0.0369149968028068,0.2448759926264851,0.9860546588897704,0.0470158010721206,0.1998473375823047,43793 -5875.849026441574,1.7248921394348145,13217.90337395668,41664,0,13217.90337395668,0.9852290749549866,0.049860768020153,0.1985576710004122,43793,19096.59023118019,0.9892215132713318,0.0366970710456371,0.2535574392803493,0.986094892024994,0.0470614470541477,0.1968372069270044,43793 -5978.910071611404,1.7608411312103271,13457.914669513702,42422,0,13457.914669513702,0.9852202534675598,0.0496694259345531,0.2027795862341287,43793,19439.718099355698,0.9893390536308287,0.0362650416791439,0.2505960639193763,0.9860355854034424,0.0468550026416778,0.2100582771150316,43793 -6081.750519990921,1.803160190582276,13697.976322174072,43185,0,13697.976322174072,0.9850420951843262,0.0502906031906604,0.1873192221946274,43793,19782.68240070343,0.9892221689224244,0.0366501882672309,0.2538185712336535,0.9859641790390016,0.0472843907773494,0.2026995334999386,43793 -6181.76887345314,1.839141607284546,13938.120023965836,43929,0,13938.120023965836,0.9852215051651,0.0496812611818313,0.1970707128760987,43793,20122.90183091164,0.9891306161880492,0.0365006737411022,0.2575394367368832,0.9861127138137816,0.0469780974090099,0.2011696650770483,43793 -6291.385230064392,1.874242067337036,14178.33187842369,44683,0,14178.33187842369,0.985240876674652,0.0498799122869968,0.1917108982459069,43793,20472.78648281097,0.9893518686294556,0.0362337455153465,0.2565196842884465,0.9861443638801576,0.0469116233289241,0.2061833071985306,43793 -6389.235752105713,1.912820100784301,14418.279221534727,45441,0,14418.279221534727,0.9853171110153198,0.0496277213096618,0.1984186097203279,43793,20810.643618822098,0.9894936084747314,0.0357696004211902,0.2657211170336563,0.9861301779747008,0.0469798222184181,0.2082862425106177,43793 -6495.324959516525,1.9479403495788568,14658.41047000885,46209,0,14658.41047000885,0.985183596611023,0.0496970899403095,0.1975784774136722,43793,21156.919238328934,0.9894861578941344,0.0355414748191833,0.2656732267817507,0.9860729575157166,0.046764601022005,0.2034237219377546,43793 -6599.232083559036,1.9829089641571045,14898.544739961624,46976,0,14898.544739961624,0.985218584537506,0.0495537556707859,0.1997976356789462,43793,21501.015640735623,0.9894858598709106,0.0355452746152877,0.2714054095701027,0.9861480593681335,0.0468971356749534,0.2085097651598301,43793 -6703.9477796554565,2.0181338787078857,15138.784644126892,47733,0,15138.784644126892,0.9853790402412416,0.0494207330048084,0.2046231602859208,43793,21846.0258705616,0.9893301129341124,0.0360425375401973,0.2642958725828656,0.9861845970153807,0.0468552336096763,0.2086522317130877,43793 -6805.587907791138,2.05293345451355,15378.88271021843,48488,0,15378.88271021843,0.985231637954712,0.0496157556772232,0.2004303623092527,43793,22187.8186275959,0.9893452525138856,0.0367339625954628,0.2576216190027671,0.9859787821769714,0.047152355313301,0.2089292845674148,43793 -6909.61977314949,2.088233470916748,15618.961620092392,49237,0,15618.961620092392,0.9853769540786744,0.0493928268551826,0.1981841083007525,43793,22531.986197948456,0.989425539970398,0.0356992445886135,0.2673758242081083,0.9861699938774108,0.0465958379209041,0.2079270938996903,43793 -7015.207123994827,2.1260316371917725,15859.194140434263,49998,0,15859.194140434263,0.9853028059005736,0.0492511764168739,0.2006010894362136,43793,22877.863034963608,0.9894838929176332,0.0357767902314662,0.2691995550619271,0.9862085580825806,0.0462435148656368,0.21020474569661,43793 -7117.35399389267,2.1619765758514404,16099.170221567154,50746,0,16099.170221567154,0.9852480292320251,0.0494797006249427,0.2006904166219487,43793,23220.043934583664,0.9894869923591614,0.0357800237834453,0.2574080154233004,0.9860761761665344,0.0467116720974445,0.2097283898741605,43793 -7223.010741472244,2.197328329086304,16339.22156381607,51507,0,16339.22156381607,0.9854000806808472,0.0490581057965755,0.2058574711568974,43793,23565.807424545288,0.989623725414276,0.0349702797830104,0.2895929079046951,0.986312448978424,0.0461284630000591,0.2125804910287261,43793 -7324.651810646057,2.236160516738892,16579.318229198456,52257,0,16579.318229198456,0.9854320883750916,0.0492740757763385,0.2077359707988157,43793,23907.607943296432,0.989595115184784,0.0349764414131641,0.2768938511577066,0.986238956451416,0.0464691221714019,0.2151909716914597,43793 -7426.342364788055,2.273671865463257,16819.287693738937,53018,0,16819.287693738937,0.9853954315185548,0.0490460842847824,0.2080408723529798,43793,24249.325675964355,0.989634335041046,0.0349128283560276,0.2853589611694491,0.98628568649292,0.0463243946433067,0.2164889520117149,43793 -7528.076402425766,2.310476779937744,17059.330296278,53768,0,17059.330296278,0.9854085445404052,0.0492413267493248,0.210766323279644,43793,24591.159697771072,0.9898422956466676,0.0343345664441585,0.2906300054524572,0.9863603711128236,0.0461886413395404,0.2171539752387562,43793 -7631.970919847488,2.348414182662964,17299.522981405258,54516,0,17299.522981405258,0.9854624271392822,0.0491330437362194,0.2111562767797493,43793,24935.30538392067,0.9898531436920166,0.034180212765932,0.2880437519159216,0.9863904118537904,0.0460393577814102,0.2220284243850361,43793 -7739.218340873718,2.38744592666626,17539.468029499054,55281,0,17539.468029499054,0.985429346561432,0.048757079988718,0.2053719159509973,43793,25282.556941986084,0.9897654056549072,0.0345088131725788,0.284018912475966,0.986370086669922,0.0458903796970844,0.2217395258628096,43793 -7841.314211845398,2.4286468029022217,17779.500022172928,56038,0,17779.500022172928,0.9853495359420776,0.0489507801830768,0.2030077043569287,43793,25624.747854471207,0.989662766456604,0.034886036068201,0.2843685079643878,0.9862491488456726,0.0460207685828208,0.2158384060379358,43793 -7944.941290616989,2.466990232467652,18019.45709347725,56795,0,18019.45709347725,0.9854097962379456,0.0487010665237903,0.2089325028808694,43793,25968.38992166519,0.9896244406700134,0.0349636748433113,0.279530086905652,0.9862730503082277,0.045942336320877,0.2164093587546455,43793 -8046.498164176941,2.5046470165252686,18259.68850851059,57561,0,18259.68850851059,0.985299825668335,0.048982568085193634,0.2117912201186689,43793,26310.2358212471,0.9898160099983215,0.03474228456616402,0.28430747955029,0.9861565828323364,0.046272970736026764,0.22056289875778998,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/measurements.csv deleted file mode 100644 index b675b074b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/measurements.csv +++ /dev/null @@ -1,662 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.3069108,0.7360678,,,,,,,,,,,,,,,,, -1,,,0.5288742184638977,0.7363438606262207,0.0209708392711852,0.5270814895629883,0.7374407649040222,0.0240884203701428,43793.0,0.5256850719451904,0.7376683354377747,0.0260457203743644,43793.0,11.894293308258057,121.44691801071168,11.894293308258057,109.55257105827332,0.0,0.0 -100,0.10308685,0.11775342,,,,,,,,,,,,,,,,, -200,0.010197889,0.05845621,,,,,,,,,,,,,,,,, -300,0.010209921,0.053938165,,,,,,,,,,,,,,,,, -400,0.010257046,0.053276177,,,,,,,,,,,,,,,,, -500,0.016969169,0.062173676,,,,,,,,,,,,,,,,, -600,0.042544253,0.05675369,,,,,,,,,,,,,,,,, -700,0.012070019,0.05042211,,,,,,,,,,,,,,,,, -743,,,0.9868398904800416,0.0536873005330562,0.0423955400895439,0.9841423034667968,0.0637450590729713,0.0409373969587783,43793.0,0.9831504821777344,0.066874660551548,0.0423844425126298,43793.0,251.9131505489349,467.0992650985718,251.9131505489349,215.14411735534668,0.0215587615966796,0.0 -800,0.021527508,0.060305387,,,,,,,,,,,,,,,,, -900,0.014284473,0.051464584,,,,,,,,,,,,,,,,, -1000,0.0043638595,0.051817384,,,,,,,,,,,,,,,,, -1100,0.010431802,0.046127263,,,,,,,,,,,,,,,,, -1200,0.014400263,0.051810406,,,,,,,,,,,,,,,,, -1300,0.008682702,0.047563985,,,,,,,,,,,,,,,,, -1400,0.015221537,0.05514168,,,,,,,,,,,,,,,,, -1499,,,0.9869152903556824,0.050368096679449,0.0641939864066242,0.9842705726623536,0.0605609454214572,0.0625524167338166,43793.0,0.9832835793495178,0.0638415142893791,0.0626250946805319,43793.0,492.1413643360138,812.2493708133698,492.1413643360138,320.01709842681885,0.049159288406372,0.0 -1500,0.0066469815,0.05897718,,,,,,,,,,,,,,,,, -1600,0.015194059,0.051780485,,,,,,,,,,,,,,,,, -1700,0.008079142,0.052956887,,,,,,,,,,,,,,,,, -1800,0.0080063995,0.0525817,,,,,,,,,,,,,,,,, -1900,0.0065752394,0.053044107,,,,,,,,,,,,,,,,, -2000,0.024478339,0.049947724,,,,,,,,,,,,,,,,, -2100,0.019550944,0.048801336,,,,,,,,,,,,,,,,, -2200,0.009019425,0.049574833,,,,,,,,,,,,,,,,, -2256,,,0.9873570799827576,0.0469028390944004,0.0943435755349976,0.9845478534698486,0.0576092824339866,0.0900378077916038,43793.0,0.9835636615753174,0.0608636774122715,0.0900609695372506,43793.0,732.1600904464722,1164.0452189445496,732.1600904464722,431.74619340896606,0.0765466690063476,0.0 -2300,0.034427,0.04683706,,,,,,,,,,,,,,,,, -2400,0.015382051,0.049922455,,,,,,,,,,,,,,,,, -2500,0.030708868,0.04909372,,,,,,,,,,,,,,,,, -2600,0.016006583,0.04336854,,,,,,,,,,,,,,,,, -2700,0.015501457,0.04800704,,,,,,,,,,,,,,,,, -2800,0.014787887,0.049721207,,,,,,,,,,,,,,,,, -2900,0.040422898,0.047065243,,,,,,,,,,,,,,,,, -3000,0.020678166,0.042907212,,,,,,,,,,,,,,,,, -3014,,,0.987488567829132,0.0455315560102462,0.1164872494530039,0.9847686290740968,0.0550459511578083,0.1060626290043858,43793.0,0.9837780594825744,0.0581569485366344,0.1073112922785335,43793.0,972.113499403,1506.7364542484283,972.113499403,534.4365320205688,0.1034972667694091,0.0 -3100,0.029478597,0.047931794,,,,,,,,,,,,,,,,, -3200,0.043495376,0.043803934,,,,,,,,,,,,,,,,, -3300,0.026082357,0.046035726,,,,,,,,,,,,,,,,, -3400,0.019340783,0.040448174,,,,,,,,,,,,,,,,, -3500,0.019337887,0.048225574,,,,,,,,,,,,,,,,, -3600,0.033712905,0.04576987,,,,,,,,,,,,,,,,, -3700,0.035341132,0.040965058,,,,,,,,,,,,,,,,, -3771,,,0.9876707196235656,0.0435436479747295,0.1393367840583442,0.9849566221237184,0.0524537749588489,0.1315201944459623,43793.0,0.9840333461761476,0.0553927794098854,0.1286881339106711,43793.0,1212.357487201691,1855.3214933872223,1212.357487201691,642.7269651889801,0.1341676712036132,0.0 -3800,0.02287248,0.04132059,,,,,,,,,,,,,,,,, -3900,0.0384494,0.045268193,,,,,,,,,,,,,,,,, -4000,0.031652585,0.045505814,,,,,,,,,,,,,,,,, -4100,0.021919092,0.040960867,,,,,,,,,,,,,,,,, -4200,0.0243679,0.043625176,,,,,,,,,,,,,,,,, -4300,0.022896176,0.042253476,,,,,,,,,,,,,,,,, -4400,0.04672966,0.04326633,,,,,,,,,,,,,,,,, -4500,0.025907919,0.038320653,,,,,,,,,,,,,,,,, -4529,,,0.9878915548324584,0.0427244156599044,0.1588165990742979,0.9850808382034302,0.0530340410768985,0.1406192140263638,43793.0,0.9841015338897704,0.0561741665005683,0.1380425379492896,43793.0,1452.5180218219757,2200.0856885910034,1452.5180218219757,747.282881975174,0.161816120147705,0.0 -4600,0.031122688,0.043808408,,,,,,,,,,,,,,,,, -4700,0.108133264,0.051889356,,,,,,,,,,,,,,,,, -4800,0.02578594,0.037119534,,,,,,,,,,,,,,,,, -4900,0.03793108,0.04259322,,,,,,,,,,,,,,,,, -5000,0.022719627,0.039677083,,,,,,,,,,,,,,,,, -5100,0.04373152,0.039862342,,,,,,,,,,,,,,,,, -5200,0.055217084,0.042732846,,,,,,,,,,,,,,,,, -5280,,,0.9880184531211852,0.0421564318239688,0.1746008317249671,0.9848993420600892,0.0518106520175933,0.1456738089220982,43793.0,0.983956217765808,0.054508913308382,0.1441801118123625,43793.0,1692.6388084888458,2549.740065097809,1692.6388084888458,856.7676658630371,0.1903581619262695,0.0 -5300,0.07668885,0.041886054,,,,,,,,,,,,,,,,, -5400,0.04179135,0.042741112,,,,,,,,,,,,,,,,, -5500,0.089741774,0.043523535,,,,,,,,,,,,,,,,, -5600,0.03775847,0.04039143,,,,,,,,,,,,,,,,, -5700,0.025505558,0.043028723,,,,,,,,,,,,,,,,, -5800,0.018364254,0.039741397,,,,,,,,,,,,,,,,, -5900,0.036377873,0.040752556,,,,,,,,,,,,,,,,, -6000,0.040665727,0.03786823,,,,,,,,,,,,,,,,, -6032,,,0.9882644414901732,0.041140217334032,0.1706003199822776,0.985255002975464,0.0511212050914764,0.1506775152204858,43793.0,0.9843584895133972,0.0538691580295562,0.1498930955587417,43793.0,1932.7112345695496,2894.309236764908,1932.7112345695496,961.2165246009828,0.2178926467895507,0.0 -6100,0.06069628,0.04336228,,,,,,,,,,,,,,,,, -6200,0.03411033,0.041857667,,,,,,,,,,,,,,,,, -6300,0.041157402,0.040955972,,,,,,,,,,,,,,,,, -6400,0.06560417,0.03796781,,,,,,,,,,,,,,,,, -6500,0.042723138,0.040855464,,,,,,,,,,,,,,,,, -6600,0.07177208,0.03910956,,,,,,,,,,,,,,,,, -6700,0.027615098,0.036301192,,,,,,,,,,,,,,,,, -6787,,,0.9882400631904602,0.0409719496965408,0.1882719291767858,0.9853280186653136,0.050181183964014,0.1575803344208639,43793.0,0.9844317436218262,0.0527786277234554,0.1560394778303248,43793.0,2172.968799352646,3240.6831436157227,2172.968799352646,1067.2840287685394,0.2464246749877929,0.0 -6800,0.021529103,0.03755516,,,,,,,,,,,,,,,,, -6900,0.053475168,0.043288995,,,,,,,,,,,,,,,,, -7000,0.0329147,0.037723865,,,,,,,,,,,,,,,,, -7100,0.025500286,0.041582335,,,,,,,,,,,,,,,,, -7200,0.07943519,0.040810164,,,,,,,,,,,,,,,,, -7300,0.06962063,0.04712294,,,,,,,,,,,,,,,,, -7400,0.10229809,0.041914534,,,,,,,,,,,,,,,,, -7500,0.10194337,0.039862867,,,,,,,,,,,,,,,,, -7541,,,0.9883561730384828,0.0403928533196449,0.1836572749367792,0.9855983853340148,0.0495022647082805,0.1682590390890557,43793.0,0.9845829606056212,0.0522934198379516,0.163858113154755,43793.0,2412.973567724228,3586.542414188385,2412.973567724228,1173.091017961502,0.2735772132873535,0.0 -7600,0.07168323,0.04220718,,,,,,,,,,,,,,,,, -7700,0.050682835,0.045147277,,,,,,,,,,,,,,,,, -7800,0.030969933,0.040904008,,,,,,,,,,,,,,,,, -7900,0.052315738,0.03725497,,,,,,,,,,,,,,,,, -8000,0.02399677,0.038986523,,,,,,,,,,,,,,,,, -8100,0.051015303,0.044006016,,,,,,,,,,,,,,,,, -8200,0.025344972,0.04032494,,,,,,,,,,,,,,,,, -8290,,,0.9883766174316406,0.040410179644823,0.186537344706435,0.985425055027008,0.0500553324818611,0.1676451153843103,43793.0,0.9844376444816588,0.052938163280487,0.1628862646202331,43793.0,2653.181145429611,3933.7839448452,2653.181145429611,1280.0739023685455,0.3043797016143799,0.0 -8300,0.0316727,0.036564115,,,,,,,,,,,,,,,,, -8400,0.046745587,0.038619574,,,,,,,,,,,,,,,,, -8500,0.03496598,0.041207936,,,,,,,,,,,,,,,,, -8600,0.03688793,0.036333725,,,,,,,,,,,,,,,,, -8700,0.03966713,0.040794633,,,,,,,,,,,,,,,,, -8800,0.04225741,0.03941251,,,,,,,,,,,,,,,,, -8900,0.040177956,0.045811214,,,,,,,,,,,,,,,,, -9000,0.07804496,0.043498464,,,,,,,,,,,,,,,,, -9043,,,0.9884548783302308,0.0400790758430957,0.2033323095092237,0.985439658164978,0.0495700985193252,0.1732071321595093,43793.0,0.9845362305641174,0.0523326396942138,0.1688662002125683,43793.0,2893.182389497757,4276.121250391007,2893.182389497757,1382.3609237670898,0.3331387042999267,0.0 -9100,0.057947934,0.039373357,,,,,,,,,,,,,,,,, -9200,0.076212786,0.03654371,,,,,,,,,,,,,,,,, -9300,0.096826196,0.039812528,,,,,,,,,,,,,,,,, -9400,0.028042568,0.03986728,,,,,,,,,,,,,,,,, -9500,0.020498024,0.03957833,,,,,,,,,,,,,,,,, -9600,0.04771875,0.04692169,,,,,,,,,,,,,,,,, -9700,0.062141947,0.03961836,,,,,,,,,,,,,,,,, -9800,,,0.9884865283966064,0.039673525840044,0.2058117402778865,0.9856600761413574,0.0494489483535289,0.1700376509465785,43793.0,0.9847678542137146,0.0524194948375225,0.1706226099469405,43793.0,3133.343291997909,4621.099257946014,3133.343291997909,1487.1310422420502,0.3607993125915527,0.0 -9800,0.022473924,0.037130423,,,,,,,,,,,,,,,,, -9900,0.06635717,0.041879892,,,,,,,,,,,,,,,,, -10000,0.034397636,0.038740885,,,,,,,,,,,,,,,,, -10100,0.084389314,0.037142653,,,,,,,,,,,,,,,,, -10200,0.104156144,0.03961447,,,,,,,,,,,,,,,,, -10300,0.043692578,0.040761374,,,,,,,,,,,,,,,,, -10400,0.03056778,0.03713002,,,,,,,,,,,,,,,,, -10500,0.04166804,0.039498143,,,,,,,,,,,,,,,,, -10563,,,0.9885579347610474,0.0394711606204509,0.1919396949243059,0.9855850338935852,0.049318790435791,0.1746983318028851,43793.0,0.9847211241722108,0.052160020917654,0.1650606484348421,43793.0,3373.4267842769623,4968.480097293854,3373.4267842769623,1594.379231929779,0.3902661800384521,0.0 -10600,0.030292023,0.036000744,,,,,,,,,,,,,,,,, -10700,0.066312104,0.040452782,,,,,,,,,,,,,,,,, -10800,0.045398284,0.039914675,,,,,,,,,,,,,,,,, -10900,0.052446835,0.043341,,,,,,,,,,,,,,,,, -11000,0.07870627,0.03988249,,,,,,,,,,,,,,,,, -11100,0.101765126,0.041040495,,,,,,,,,,,,,,,,, -11200,0.05992825,0.043554522,,,,,,,,,,,,,,,,, -11300,0.054362893,0.04034744,,,,,,,,,,,,,,,,, -11336,,,0.9885560870170592,0.0393308736383914,0.2063082084535478,0.9856479167938232,0.0493979267776012,0.1842653055696553,43793.0,0.9846722483634948,0.0525008179247379,0.1774870521222304,43793.0,3613.5157945156097,5310.309653043747,3613.5157945156097,1696.071894645691,0.4180216789245605,0.0 -11400,0.044622302,0.039111335,,,,,,,,,,,,,,,,, -11500,0.06647734,0.041695345,,,,,,,,,,,,,,,,, -11600,0.032311372,0.040840972,,,,,,,,,,,,,,,,, -11700,0.026011499,0.03899219,,,,,,,,,,,,,,,,, -11800,0.0829419,0.040419903,,,,,,,,,,,,,,,,, -11900,0.04640195,0.039634373,,,,,,,,,,,,,,,,, -12000,0.039074633,0.03933835,,,,,,,,,,,,,,,,, -12100,0.04683775,0.03951585,,,,,,,,,,,,,,,,, -12103,,,0.988620102405548,0.0392267331480979,0.2103504437889654,0.9856515526771544,0.0486785285174846,0.1772119822089247,43793.0,0.9847754836082458,0.0514159128069877,0.1753369128481019,43793.0,3853.641400814056,5655.700120210648,3853.641400814056,1801.2887377738953,0.4466302394866943,0.0 -12200,0.050222997,0.038349025,,,,,,,,,,,,,,,,, -12300,0.04476761,0.038549174,,,,,,,,,,,,,,,,, -12400,0.0355109,0.039397035,,,,,,,,,,,,,,,,, -12500,0.07030017,0.038334396,,,,,,,,,,,,,,,,, -12600,0.028419647,0.04275635,,,,,,,,,,,,,,,,, -12700,0.03179685,0.04102362,,,,,,,,,,,,,,,,, -12800,0.032861948,0.041462522,,,,,,,,,,,,,,,,, -12866,,,0.988673210144043,0.0391677431762218,0.2106503490344172,0.985570788383484,0.0492812134325504,0.1662361976664409,43793.0,0.984644055366516,0.0521092340350151,0.164051053742787,43793.0,4093.658228158951,5996.82173871994,4093.658228158951,1902.346410989761,0.4743480682373047,0.0 -12900,0.03741814,0.037069842,,,,,,,,,,,,,,,,, -13000,0.027624002,0.033971015,,,,,,,,,,,,,,,,, -13100,0.046429116,0.03849825,,,,,,,,,,,,,,,,, -13200,0.053482704,0.035350673,,,,,,,,,,,,,,,,, -13300,0.07177347,0.036013808,,,,,,,,,,,,,,,,, -13400,0.048858646,0.03694587,,,,,,,,,,,,,,,,, -13500,0.075965285,0.036896434,,,,,,,,,,,,,,,,, -13600,0.034869295,0.0388018,,,,,,,,,,,,,,,,, -13629,,,0.98865807056427,0.0386559441685676,0.2166301543980013,0.9856945872306824,0.0490887686610221,0.1791409499424729,43793.0,0.9846861958503724,0.0520230717957019,0.1755181138502504,43793.0,4333.704358816147,6344.965814590454,4333.704358816147,2010.393649101257,0.5054383277893066,0.0 -13700,0.045012966,0.039276455,,,,,,,,,,,,,,,,, -13800,0.03170905,0.034730412,,,,,,,,,,,,,,,,, -13900,0.028329149,0.040814184,,,,,,,,,,,,,,,,, -14000,0.030276947,0.038937777,,,,,,,,,,,,,,,,, -14100,0.029761123,0.04053815,,,,,,,,,,,,,,,,, -14200,0.056511547,0.04053184,,,,,,,,,,,,,,,,, -14300,0.027466431,0.039946355,,,,,,,,,,,,,,,,, -14389,,,0.9884803295135498,0.0401361621916294,0.1923698670874865,0.9853804111480712,0.0498729534447193,0.1726102830664344,43793.0,0.9843934178352356,0.0527944862842559,0.1688209091514003,43793.0,4573.699645042419,6690.735986948013,4573.699645042419,2116.118327856064,0.5356552600860596,0.0 -14400,0.05728669,0.04281507,,,,,,,,,,,,,,,,, -14500,0.06221482,0.036197975,,,,,,,,,,,,,,,,, -14600,0.11396109,0.042406797,,,,,,,,,,,,,,,,, -14700,0.07099814,0.04056152,,,,,,,,,,,,,,,,, -14800,0.07424833,0.03340265,,,,,,,,,,,,,,,,, -14900,0.07641686,0.03665123,,,,,,,,,,,,,,,,, -15000,0.048711386,0.04173762,,,,,,,,,,,,,,,,, -15100,0.030667385,0.035930634,,,,,,,,,,,,,,,,, -15152,,,0.9884754419326782,0.0394728332757949,0.201157464733156,0.9856117963790894,0.0496006943285465,0.1733149277449493,43793.0,0.984649121761322,0.0525808483362197,0.1737742618310206,43793.0,4813.877220630646,7041.50847029686,4813.877220630646,2226.663529396057,0.5656332969665527,0.0 -15200,0.057923447,0.04237372,,,,,,,,,,,,,,,,, -15300,0.055747125,0.038580272,,,,,,,,,,,,,,,,, -15400,0.038114112,0.037984923,,,,,,,,,,,,,,,,, -15500,0.097539656,0.03896339,,,,,,,,,,,,,,,,, -15600,0.034471504,0.041661836,,,,,,,,,,,,,,,,, -15700,0.058191337,0.038381785,,,,,,,,,,,,,,,,, -15800,0.033314563,0.0383157,,,,,,,,,,,,,,,,, -15900,0.04394106,0.04018091,,,,,,,,,,,,,,,,, -15903,,,0.9885097742080688,0.039356917142868,0.2035128769114544,0.9857429265975952,0.0487789139151573,0.1836194772418177,43793.0,0.9847834706306458,0.051879771053791,0.1804877903662326,43793.0,5054.121758937836,7382.07443356514,5054.121758937836,2326.930092334748,0.5987019538879395,0.0 -16000,0.041481007,0.037734862,,,,,,,,,,,,,,,,, -16100,0.03364255,0.039708417,,,,,,,,,,,,,,,,, -16200,0.06510343,0.04013676,,,,,,,,,,,,,,,,, -16300,0.031571172,0.04225795,,,,,,,,,,,,,,,,, -16400,0.055680156,0.03957242,,,,,,,,,,,,,,,,, -16500,0.07244237,0.043670505,,,,,,,,,,,,,,,,, -16600,0.026304351,0.041427325,,,,,,,,,,,,,,,,, -16665,,,0.9887725710868835,0.0389116667211055,0.2065929856554009,0.9857416749000548,0.0486304834485054,0.1823984798059077,43793.0,0.9848668575286864,0.0515738539397716,0.1752540595629578,43793.0,5294.265455007553,7729.107240438461,5294.265455007553,2433.7707934379578,0.6277382373809814,0.0 -16700,0.06418429,0.04302406,,,,,,,,,,,,,,,,, -16800,0.06976347,0.03865644,,,,,,,,,,,,,,,,, -16900,0.055519987,0.03774971,,,,,,,,,,,,,,,,, -17000,0.09124298,0.042532135,,,,,,,,,,,,,,,,, -17100,0.07829273,0.037881583,,,,,,,,,,,,,,,,, -17200,0.032979287,0.038644947,,,,,,,,,,,,,,,,, -17300,0.037335407,0.035649113,,,,,,,,,,,,,,,,, -17400,0.031567857,0.037081692,,,,,,,,,,,,,,,,, -17425,,,0.9885989427566528,0.039196029305458,0.2020360716367713,0.9855594038963318,0.0490655675530433,0.1677639326494566,43793.0,0.9847135543823242,0.0520097613334655,0.1743933827146704,43793.0,5534.256805181503,8076.462526798248,5534.256805181503,2541.077990293503,0.6633737087249756,0.0 -17500,0.032832053,0.041491292,,,,,,,,,,,,,,,,, -17600,0.092486754,0.036144566,,,,,,,,,,,,,,,,, -17700,0.032492902,0.03542047,,,,,,,,,,,,,,,,, -17800,0.04183221,0.03766093,,,,,,,,,,,,,,,,, -17900,0.056409903,0.038775012,,,,,,,,,,,,,,,,, -18000,0.032808997,0.03903832,,,,,,,,,,,,,,,,, -18100,0.029965753,0.043412253,,,,,,,,,,,,,,,,, -18183,,,0.9885631799697876,0.0392050370573997,0.2111582382971683,0.9856604933738708,0.0490625202655792,0.1755662687161741,43793.0,0.9847097396850586,0.0518496893346309,0.1760573149142341,43793.0,5774.4022517204285,8422.579911470413,5774.4022517204285,2646.999152421952,0.6947915554046631,0.0 -18200,0.043850884,0.03660165,,,,,,,,,,,,,,,,, -18300,0.042908724,0.04113717,,,,,,,,,,,,,,,,, -18400,0.059850037,0.039632756,,,,,,,,,,,,,,,,, -18500,0.0475542,0.038326185,,,,,,,,,,,,,,,,, -18600,0.12546466,0.040914092,,,,,,,,,,,,,,,,, -18700,0.058257576,0.043282524,,,,,,,,,,,,,,,,, -18800,0.04314983,0.03515864,,,,,,,,,,,,,,,,, -18900,0.056620732,0.03855864,,,,,,,,,,,,,,,,, -18935,,,0.9886395931243896,0.0389874204993248,0.2070608752482068,0.9856272339820862,0.0490880906581878,0.1780778120628973,43793.0,0.9846933484077454,0.0519953034818172,0.1730764760968453,43793.0,6014.416128873825,8767.328412532806,6014.416128873825,2751.677888393402,0.729445219039917,0.0 -19000,0.046237465,0.04113611,,,,,,,,,,,,,,,,, -19100,0.037638556,0.034426186,,,,,,,,,,,,,,,,, -19200,0.11748136,0.04059694,,,,,,,,,,,,,,,,, -19300,0.034378566,0.036585446,,,,,,,,,,,,,,,,, -19400,0.056235876,0.036981054,,,,,,,,,,,,,,,,, -19500,0.09719603,0.040542092,,,,,,,,,,,,,,,,, -19600,0.109885365,0.039782155,,,,,,,,,,,,,,,,, -19686,,,0.9886992573738098,0.0385549515485763,0.2199436476022474,0.9856789708137512,0.0485288426280021,0.1826200142812672,43793.0,0.984761118888855,0.0513607002794742,0.1776437460243953,43793.0,6254.527894496918,9108.249343633652,6254.527894496918,2852.4364881515503,0.760200023651123,0.0 -19700,0.08363413,0.045504738,,,,,,,,,,,,,,,,, -19800,0.025586773,0.038015082,,,,,,,,,,,,,,,,, -19900,0.03675608,0.038129453,,,,,,,,,,,,,,,,, -20000,0.041334573,0.037578717,,,,,,,,,,,,,,,,, -20100,0.075002074,0.04151332,,,,,,,,,,,,,,,,, -20200,0.040180705,0.041732803,,,,,,,,,,,,,,,,, -20300,0.03839986,0.03608468,,,,,,,,,,,,,,,,, -20400,0.051977895,0.03772379,,,,,,,,,,,,,,,,, -20446,,,0.9888023138046264,0.0385515615344047,0.2108082699550368,0.9857754111289978,0.0487221740186214,0.1826559964014653,43793.0,0.984872341156006,0.0517776682972908,0.178504747280965,43793.0,6494.588192939758,9454.533122062683,6494.588192939758,2958.607923746109,0.7917554378509521,0.0 -20500,0.05959749,0.040238753,,,,,,,,,,,,,,,,, -20600,0.031281445,0.041322056,,,,,,,,,,,,,,,,, -20700,0.108212866,0.042796608,,,,,,,,,,,,,,,,, -20800,0.051685773,0.040508077,,,,,,,,,,,,,,,,, -20900,0.08505464,0.038002364,,,,,,,,,,,,,,,,, -21000,0.07897712,0.0470736,,,,,,,,,,,,,,,,, -21100,0.046627253,0.0421092,,,,,,,,,,,,,,,,, -21200,0.034727518,0.03939471,,,,,,,,,,,,,,,,, -21201,,,0.988953709602356,0.0377000719308853,0.230811797165299,0.9858095049858092,0.048275701701641,0.187712481522891,43793.0,0.9848698377609252,0.0513353645801544,0.182898133811869,43793.0,6734.835546016693,9798.936812639236,6734.835546016693,3062.7118458747864,0.8235526084899902,0.0 -21300,0.06324209,0.03866183,,,,,,,,,,,,,,,,, -21400,0.04620128,0.038297474,,,,,,,,,,,,,,,,, -21500,0.1302833,0.038936295,,,,,,,,,,,,,,,,, -21600,0.083523735,0.03686819,,,,,,,,,,,,,,,,, -21700,0.043798078,0.03906604,,,,,,,,,,,,,,,,, -21800,0.035027,0.036074422,,,,,,,,,,,,,,,,, -21900,0.046258505,0.038090732,,,,,,,,,,,,,,,,, -21957,,,0.988852560520172,0.0381752923130989,0.2194745163208711,0.985712468624115,0.0487480387091636,0.1864697991138633,43793.0,0.9847274422645568,0.0518092028796672,0.1798832976796487,43793.0,6975.096371412277,10141.745084524156,6975.096371412277,3165.2085721492767,0.85434889793396,0.0 -22000,0.1293118,0.0374378,,,,,,,,,,,,,,,,, -22100,0.09352771,0.039418772,,,,,,,,,,,,,,,,, -22200,0.036178768,0.037529018,,,,,,,,,,,,,,,,, -22300,0.056863904,0.035370667,,,,,,,,,,,,,,,,, -22400,0.081876084,0.036374845,,,,,,,,,,,,,,,,, -22500,0.066945344,0.04123262,,,,,,,,,,,,,,,,, -22600,0.033702057,0.041483503,,,,,,,,,,,,,,,,, -22700,0.031218233,0.03524396,,,,,,,,,,,,,,,,, -22721,,,0.9887853860855104,0.0382936522364616,0.225787219799659,0.9858614802360536,0.0484032183885574,0.1871781066362978,43793.0,0.9849422574043274,0.051137737929821,0.1806988098203659,43793.0,7215.198048114777,10487.473489522934,7215.198048114777,3270.7840468883514,0.8855433464050293,0.0 -22800,0.052364577,0.0358646,,,,,,,,,,,,,,,,, -22900,0.028403914,0.03545642,,,,,,,,,,,,,,,,, -23000,0.05682157,0.042115208,,,,,,,,,,,,,,,,, -23100,0.03953883,0.037856005,,,,,,,,,,,,,,,,, -23200,0.041816022,0.04291806,,,,,,,,,,,,,,,,, -23300,0.040251713,0.04198905,,,,,,,,,,,,,,,,, -23400,0.08043726,0.037063006,,,,,,,,,,,,,,,,, -23481,,,0.9887216091156006,0.0385473296046257,0.2161478474097,0.9858009815216064,0.0483117811381816,0.1869150763786361,43793.0,0.9848819971084596,0.0512367002665996,0.1817081629190941,43793.0,7455.259896755218,10834.261020421982,7455.259896755218,3377.4589943885803,0.9166724681854248,0.0 -23500,0.033323746,0.036753353,,,,,,,,,,,,,,,,, -23600,0.05362233,0.03827748,,,,,,,,,,,,,,,,, -23700,0.03361591,0.037503723,,,,,,,,,,,,,,,,, -23800,0.044500876,0.035501216,,,,,,,,,,,,,,,,, -23900,0.04466218,0.038725723,,,,,,,,,,,,,,,,, -24000,0.12010726,0.045006208,,,,,,,,,,,,,,,,, -24100,0.059926827,0.041224383,,,,,,,,,,,,,,,,, -24200,0.04587216,0.039095484,,,,,,,,,,,,,,,,, -24238,,,0.9888251423835754,0.038296639919281,0.2175519109860235,0.9858407378196716,0.0481756515800952,0.190224826472304,43793.0,0.9849464893341064,0.0510138981044292,0.1909638553261999,43793.0,7695.373823165894,11175.742981672289,7695.373823165894,3478.77467918396,0.9493060111999512,0.0 -24300,0.056088585,0.039362006,,,,,,,,,,,,,,,,, -24400,0.07615084,0.038208548,,,,,,,,,,,,,,,,, -24500,0.10098622,0.043885846,,,,,,,,,,,,,,,,, -24600,0.07833368,0.042549826,,,,,,,,,,,,,,,,, -24700,0.09476277,0.03924113,,,,,,,,,,,,,,,,, -24800,0.033997174,0.037844468,,,,,,,,,,,,,,,,, -24900,0.05162353,0.03849124,,,,,,,,,,,,,,,,, -24996,,,0.9887945652008056,0.0382853634655475,0.2241156935598468,0.9858716130256652,0.0480336621403694,0.1920055872143085,43793.0,0.9849563837051392,0.0509523451328277,0.1876386925962422,43793.0,7935.566670894623,11518.674816608427,7935.566670894623,3581.462842464447,0.980283260345459,0.0 -25000,0.04275795,0.036212534,,,,,,,,,,,,,,,,, -25100,0.051227964,0.03990833,,,,,,,,,,,,,,,,, -25200,0.08077122,0.041094296,,,,,,,,,,,,,,,,, -25300,0.053658806,0.041416034,,,,,,,,,,,,,,,,, -25400,0.040217534,0.042434797,,,,,,,,,,,,,,,,, -25500,0.061905164,0.034995984,,,,,,,,,,,,,,,,, -25600,0.049756464,0.039425302,,,,,,,,,,,,,,,,, -25700,0.037684437,0.03670455,,,,,,,,,,,,,,,,, -25759,,,0.9889598488807678,0.0382597297430038,0.2139574837879134,0.9858368635177612,0.0482598207890987,0.1868899929087874,43793.0,0.9849376082420348,0.0512047596275806,0.1794732225786249,43793.0,8175.762567043304,11864.60201716423,8175.762567043304,3687.1374304294586,1.0166258811950684,0.0 -25800,0.051386643,0.03840654,,,,,,,,,,,,,,,,, -25900,0.04201104,0.036175944,,,,,,,,,,,,,,,,, -26000,0.027436582,0.036325585,,,,,,,,,,,,,,,,, -26100,0.055128805,0.039160423,,,,,,,,,,,,,,,,, -26200,0.038526896,0.039477244,,,,,,,,,,,,,,,,, -26300,0.054302324,0.042880993,,,,,,,,,,,,,,,,, -26400,0.03420579,0.035223976,,,,,,,,,,,,,,,,, -26500,0.03711632,0.037774697,,,,,,,,,,,,,,,,, -26525,,,0.9887128472328186,0.0387284010648727,0.212833312132566,0.9856792092323304,0.048808928579092,0.1811582477184639,43793.0,0.9847586154937744,0.0519197210669517,0.1759232424762905,43793.0,8415.960547924042,12208.59658241272,8415.960547924042,3790.8824088573456,1.0475599765777588,0.0 -26600,0.09452299,0.046094984,,,,,,,,,,,,,,,,, -26700,0.034083363,0.038571283,,,,,,,,,,,,,,,,, -26800,0.04303221,0.034378976,,,,,,,,,,,,,,,,, -26900,0.06490978,0.04076269,,,,,,,,,,,,,,,,, -27000,0.05941696,0.03897034,,,,,,,,,,,,,,,,, -27100,0.046414126,0.03536911,,,,,,,,,,,,,,,,, -27200,0.07425863,0.043199718,,,,,,,,,,,,,,,,, -27288,,,0.9887645244598388,0.0383030474185943,0.2247528497563965,0.9856621623039246,0.0485771968960762,0.183385786420843,43793.0,0.9847981929779052,0.0514655895531177,0.1775487268339715,43793.0,8655.956233978271,12557.534547328947,8655.956233978271,3899.773766040802,1.078413486480713,0.0 -27300,0.07173501,0.039999373,,,,,,,,,,,,,,,,, -27400,0.08950888,0.03762395,,,,,,,,,,,,,,,,, -27500,0.06039103,0.034919143,,,,,,,,,,,,,,,,, -27600,0.049716495,0.03915512,,,,,,,,,,,,,,,,, -27700,0.08405084,0.0392222,,,,,,,,,,,,,,,,, -27800,0.061233386,0.03673111,,,,,,,,,,,,,,,,, -27900,0.03418359,0.038986556,,,,,,,,,,,,,,,,, -28000,0.056018792,0.039698523,,,,,,,,,,,,,,,,, -28035,,,0.9887883067131042,0.0383867248892784,0.2266030029702564,0.9857709407806396,0.0483795665204525,0.1852712474437164,43793.0,0.9848756790161132,0.0512337498366832,0.1824875723104822,43793.0,8896.014830827713,12903.236117601396,8896.014830827713,4005.359260320664,1.1128215789794922,0.0 -28100,0.091491885,0.03961862,,,,,,,,,,,,,,,,, -28200,0.118173435,0.036492966,,,,,,,,,,,,,,,,, -28300,0.07097897,0.035704505,,,,,,,,,,,,,,,,, -28400,0.03784964,0.038866844,,,,,,,,,,,,,,,,, -28500,0.06907158,0.037647888,,,,,,,,,,,,,,,,, -28600,0.034989957,0.03618495,,,,,,,,,,,,,,,,, -28700,0.07183583,0.040603373,,,,,,,,,,,,,,,,, -28789,,,0.9888795018196106,0.0379779003560543,0.2251258474426496,0.9859036803245544,0.0483595877885818,0.1953504302792705,43793.0,0.9850125908851624,0.051252357661724,0.1844789854736912,43793.0,9136.247852563858,13247.099108457563,9136.247852563858,4108.936638832092,1.145425796508789,0.0 -28800,0.04597618,0.0355917,,,,,,,,,,,,,,,,, -28900,0.119660474,0.034681134,,,,,,,,,,,,,,,,, -29000,0.06026225,0.039980516,,,,,,,,,,,,,,,,, -29100,0.116815716,0.03599765,,,,,,,,,,,,,,,,, -29200,0.051627643,0.036278997,,,,,,,,,,,,,,,,, -29300,0.039115176,0.03488345,,,,,,,,,,,,,,,,, -29400,0.058740977,0.036532376,,,,,,,,,,,,,,,,, -29500,0.0482627,0.036340352,,,,,,,,,,,,,,,,, -29556,,,0.9890247583389282,0.0373383872210979,0.2374429432051383,0.9858587980270386,0.0481251999735832,0.1895831564274472,43793.0,0.9849578142166138,0.0510977432131767,0.1866836963508372,43793.0,9376.250790834429,13587.735835075378,9376.250790834429,4209.517950057983,1.1778452396392822,0.0 -29600,0.056654803,0.03758813,,,,,,,,,,,,,,,,, -29700,0.06815676,0.03714967,,,,,,,,,,,,,,,,, -29800,0.065766886,0.0449776,,,,,,,,,,,,,,,,, -29900,0.032729726,0.03704353,,,,,,,,,,,,,,,,, -30000,0.044481862,0.040303588,,,,,,,,,,,,,,,,, -30100,0.0340149,0.0371251,,,,,,,,,,,,,,,,, -30200,0.08938069,0.036895446,,,,,,,,,,,,,,,,, -30300,0.0727887,0.036279608,,,,,,,,,,,,,,,,, -30318,,,0.9889353513717652,0.0375845544040203,0.2431385472982229,0.9859227538108826,0.0482968352735042,0.1986930130203873,43793.0,0.9849645495414734,0.0513634011149406,0.1865411181259032,43793.0,9616.483031272888,13931.94085597992,9616.483031272888,4313.439270734787,1.2094299793243408,0.0 -30400,0.06869497,0.03760801,,,,,,,,,,,,,,,,, -30500,0.03993496,0.037428737,,,,,,,,,,,,,,,,, -30600,0.08646808,0.03687066,,,,,,,,,,,,,,,,, -30700,0.05567519,0.03520008,,,,,,,,,,,,,,,,, -30800,0.04398107,0.0380493,,,,,,,,,,,,,,,,, -30900,0.057362374,0.038119074,,,,,,,,,,,,,,,,, -31000,0.08289781,0.03884128,,,,,,,,,,,,,,,,, -31076,,,0.9885992407798768,0.0386416278779506,0.2223428213165376,0.9853118062019348,0.0490008220076561,0.1907836688823532,43793.0,0.9844574332237244,0.051728893071413,0.1833039313442641,43793.0,9856.629340171814,14277.50398659706,9856.629340171814,4418.80347275734,1.241889476776123,0.0 -31100,0.04506334,0.036430527,,,,,,,,,,,,,,,,, -31200,0.06067539,0.04035419,,,,,,,,,,,,,,,,, -31300,0.09172569,0.0389661,,,,,,,,,,,,,,,,, -31400,0.060565352,0.03880622,,,,,,,,,,,,,,,,, -31500,0.08470721,0.033319507,,,,,,,,,,,,,,,,, -31600,0.05798414,0.042866286,,,,,,,,,,,,,,,,, -31700,0.047313556,0.032791045,,,,,,,,,,,,,,,,, -31800,0.054398205,0.041308284,,,,,,,,,,,,,,,,, -31840,,,0.9889352321624756,0.0378828234970569,0.2303038092934673,0.9858736395835876,0.047849740833044,0.1935717885093315,43793.0,0.9849127531051636,0.0506469495594501,0.1879387186659543,43793.0,10096.608315229416,14625.904658794405,10096.608315229416,4527.166420221329,1.280625581741333,0.0 -31900,0.04506254,0.04190918,,,,,,,,,,,,,,,,, -32000,0.068008706,0.035037212,,,,,,,,,,,,,,,,, -32100,0.050973676,0.034984596,,,,,,,,,,,,,,,,, -32200,0.05341448,0.03770159,,,,,,,,,,,,,,,,, -32300,0.044122048,0.042789962,,,,,,,,,,,,,,,,, -32400,0.10849291,0.04444972,,,,,,,,,,,,,,,,, -32500,0.041487582,0.03491358,,,,,,,,,,,,,,,,, -32593,,,0.9887871742248536,0.0383941009640693,0.2274084218291437,0.985841989517212,0.0483883619308471,0.1839510717008973,43793.0,0.9849587082862854,0.0513384491205215,0.1829017924658921,43793.0,10336.679065942764,14973.201986551285,10336.679065942764,4634.336527824402,1.3167879581451416,0.0 -32600,0.04660214,0.040226452,,,,,,,,,,,,,,,,, -32700,0.04827029,0.039724745,,,,,,,,,,,,,,,,, -32800,0.05613549,0.03994754,,,,,,,,,,,,,,,,, -32900,0.06718428,0.039787594,,,,,,,,,,,,,,,,, -33000,0.06689325,0.035461556,,,,,,,,,,,,,,,,, -33100,0.04101071,0.036549684,,,,,,,,,,,,,,,,, -33200,0.04201583,0.039403543,,,,,,,,,,,,,,,,, -33300,0.09432399,0.04099142,,,,,,,,,,,,,,,,, -33347,,,0.9890311360359192,0.0375706255435943,0.2313390646531625,0.9859755039215088,0.0481745265424251,0.1941218011418622,43793.0,0.9849936366081238,0.0513042733073234,0.1857491169451275,43793.0,10576.746666193008,15316.488739967346,10576.746666193008,4737.503067016602,1.3489949703216553,0.0 -33400,0.037941553,0.03707849,,,,,,,,,,,,,,,,, -33500,0.06210032,0.032262042,,,,,,,,,,,,,,,,, -33600,0.032905754,0.037620187,,,,,,,,,,,,,,,,, -33700,0.04281924,0.039037757,,,,,,,,,,,,,,,,, -33800,0.096859,0.040055394,,,,,,,,,,,,,,,,, -33900,0.03506893,0.03706554,,,,,,,,,,,,,,,,, -34000,0.09381833,0.042306703,,,,,,,,,,,,,,,,, -34100,0.0805538,0.038066503,,,,,,,,,,,,,,,,, -34102,,,0.9890766143798828,0.0374147966504097,0.2217049939529483,0.9859133958816528,0.0475398078560829,0.1953037639585747,43793.0,0.9850218892097472,0.0503968968987464,0.193202887268333,43793.0,10817.016341924667,15663.953579187391,10817.016341924667,4844.645184755325,1.3820130825042725,0.0 -34200,0.04334488,0.03808037,,,,,,,,,,,,,,,,, -34300,0.06491461,0.034876596,,,,,,,,,,,,,,,,, -34400,0.05920271,0.03301385,,,,,,,,,,,,,,,,, -34500,0.039029814,0.037812762,,,,,,,,,,,,,,,,, -34600,0.055777714,0.03690226,,,,,,,,,,,,,,,,, -34700,0.08875724,0.038775347,,,,,,,,,,,,,,,,, -34800,0.057209417,0.03767681,,,,,,,,,,,,,,,,, -34861,,,0.9888927936553956,0.0379797182977199,0.2367710387399423,0.9859304428100586,0.0483159534633159,0.1910359411764158,43793.0,0.9849645495414734,0.0513733327388763,0.1845622175386213,43793.0,11057.10607290268,16004.42927980423,11057.10607290268,4944.97594666481,1.415419578552246,0.0 -34900,0.072371975,0.037017282,,,,,,,,,,,,,,,,, -35000,0.042412106,0.036177024,,,,,,,,,,,,,,,,, -35100,0.064765185,0.0406628,,,,,,,,,,,,,,,,, -35200,0.057943862,0.033433333,,,,,,,,,,,,,,,,, -35300,0.1069282,0.04043464,,,,,,,,,,,,,,,,, -35400,0.05113984,0.036958225,,,,,,,,,,,,,,,,, -35500,0.041355457,0.041252904,,,,,,,,,,,,,,,,, -35600,0.07880948,0.036575582,,,,,,,,,,,,,,,,, -35619,,,0.9888798594474792,0.0377690196037292,0.2419741886575905,0.9859288334846495,0.0476657561957836,0.195245566004814,43793.0,0.985063135623932,0.0505873449146747,0.1908041635108264,43793.0,11297.28503227234,16347.916496515274,11297.28503227234,5048.230092048645,1.4496972560882568,0.0 -35700,0.052448366,0.038045004,,,,,,,,,,,,,,,,, -35800,0.04575062,0.03710997,,,,,,,,,,,,,,,,, -35900,0.04206331,0.0394805,,,,,,,,,,,,,,,,, -36000,0.068858914,0.037122212,,,,,,,,,,,,,,,,, -36100,0.057111237,0.038036667,,,,,,,,,,,,,,,,, -36200,0.08769913,0.039407156,,,,,,,,,,,,,,,,, -36300,0.076684386,0.040379375,,,,,,,,,,,,,,,,, -36364,,,0.988921880722046,0.0374002493917942,0.2393158036051269,0.985909342765808,0.0482694059610366,0.1960000096477773,43793.0,0.9849388599395752,0.0511537045240402,0.1846787561870888,43793.0,11537.255279302595,16690.80647611618,11537.255279302595,5151.097851753235,1.4821081161499023,0.0 -36400,0.030490901,0.035360914,,,,,,,,,,,,,,,,, -36500,0.086497314,0.04237939,,,,,,,,,,,,,,,,, -36600,0.11227863,0.040347435,,,,,,,,,,,,,,,,, -36700,0.05727972,0.0373932,,,,,,,,,,,,,,,,, -36800,0.06417121,0.03678712,,,,,,,,,,,,,,,,, -36900,0.070259854,0.03899168,,,,,,,,,,,,,,,,, -37000,0.08170905,0.03928404,,,,,,,,,,,,,,,,, -37100,0.038990907,0.032699626,,,,,,,,,,,,,,,,, -37122,,,0.9891963601112366,0.0366889312863349,0.2415403166293018,0.9859312772750854,0.0474190339446067,0.1927813767994909,43793.0,0.9850353598594666,0.0502888746559619,0.185580799291037,43793.0,11777.304328680038,17032.784563302994,11777.304328680038,5252.974688053131,1.514685869216919,0.0 -37200,0.061661948,0.035482444,,,,,,,,,,,,,,,,, -37300,0.1417917,0.04404179,,,,,,,,,,,,,,,,, -37400,0.048240628,0.039172,,,,,,,,,,,,,,,,, -37500,0.040433504,0.0362207,,,,,,,,,,,,,,,,, -37600,0.051411133,0.03814786,,,,,,,,,,,,,,,,, -37700,0.07771938,0.035679866,,,,,,,,,,,,,,,,, -37800,0.070537,0.041456584,,,,,,,,,,,,,,,,, -37884,,,0.9892786145210266,0.0363476835191249,0.2639222874446258,0.9859499335289,0.0476463437080383,0.1962621683676298,43793.0,0.9851309657096864,0.0503636747598648,0.1916574145193226,43793.0,12017.266623020172,17376.75620341301,12017.266623020172,5356.93186378479,1.5475928783416748,0.0 -37900,0.03503598,0.03530987,,,,,,,,,,,,,,,,, -38000,0.039018318,0.039519694,,,,,,,,,,,,,,,,, -38100,0.055639487,0.039157398,,,,,,,,,,,,,,,,, -38200,0.040546767,0.038385384,,,,,,,,,,,,,,,,, -38300,0.10332928,0.035174314,,,,,,,,,,,,,,,,, -38400,0.112444684,0.038283996,,,,,,,,,,,,,,,,, -38500,0.032607988,0.032417256,,,,,,,,,,,,,,,,, -38600,0.054764364,0.040518314,,,,,,,,,,,,,,,,, -38639,,,0.989258587360382,0.0369451642036438,0.240270671925092,0.9859738945961,0.0476714670658111,0.1958262650743369,43793.0,0.985082507133484,0.0503891482949256,0.1913735749887168,43793.0,12257.309017419817,17720.39829516411,12257.309017419817,5460.476921081543,1.582322359085083,0.0 -38700,0.07323122,0.033264074,,,,,,,,,,,,,,,,, -38800,0.07068543,0.0359182,,,,,,,,,,,,,,,,, -38900,0.043360017,0.035147607,,,,,,,,,,,,,,,,, -39000,0.038617685,0.033822868,,,,,,,,,,,,,,,,, -39100,0.06262474,0.03438813,,,,,,,,,,,,,,,,, -39200,0.041689094,0.036738303,,,,,,,,,,,,,,,,, -39300,0.049582507,0.04100351,,,,,,,,,,,,,,,,, -39400,,,0.9892501831054688,0.0367242135107517,0.2511990354285915,0.9859767556190492,0.0474646426737308,0.2000518758374433,43793.0,0.9850707054138184,0.0503598414361476,0.1969670857838198,43793.0,12497.395084142683,18064.71885228157,12497.395084142683,5564.655128240585,1.6185622215270996,0.0 -39400,0.06580944,0.03656732,,,,,,,,,,,,,,,,, -39500,0.09082903,0.03500076,,,,,,,,,,,,,,,,, -39600,0.043018322,0.037229467,,,,,,,,,,,,,,,,, -39700,0.049968034,0.03559006,,,,,,,,,,,,,,,,, -39800,0.04178966,0.038132288,,,,,,,,,,,,,,,,, -39900,0.09027063,0.036543712,,,,,,,,,,,,,,,,, -40000,0.14472115,0.038093474,,,,,,,,,,,,,,,,, -40100,0.081382886,0.03664865,,,,,,,,,,,,,,,,, -40142,,,0.9890367984771729,0.0373102724552154,0.2378807756548885,0.9860019087791444,0.0474509969353675,0.1982368226275608,43793.0,0.9850610494613647,0.0504180826246738,0.1941491034924058,43793.0,12737.522000074388,18408.12707543373,12737.522000074388,5667.876608371735,1.6552250385284424,0.0 -40200,0.16853186,0.042218044,,,,,,,,,,,,,,,,, -40300,0.03899729,0.034732126,,,,,,,,,,,,,,,,, -40400,0.072863586,0.034902927,,,,,,,,,,,,,,,,, -40500,0.06347589,0.03949675,,,,,,,,,,,,,,,,, -40600,0.04553629,0.040362243,,,,,,,,,,,,,,,,, -40700,0.04868039,0.035677563,,,,,,,,,,,,,,,,, -40800,0.050667193,0.03703947,,,,,,,,,,,,,,,,, -40900,0.11166436,0.038908467,,,,,,,,,,,,,,,,, -40902,,,0.989230453968048,0.0369149968028068,0.2448759926264851,0.9860546588897704,0.0470158010721206,0.1998473375823047,43793.0,0.985190749168396,0.0495871491730213,0.1979586610062238,43793.0,12977.745255231855,18755.44972276688,12977.745255231855,5774.92215013504,1.6892304420471191,0.0 -41000,0.053287253,0.034413,,,,,,,,,,,,,,,,, -41100,0.076571405,0.03374644,,,,,,,,,,,,,,,,, -41200,0.10766653,0.033497967,,,,,,,,,,,,,,,,, -41300,0.058871385,0.039120432,,,,,,,,,,,,,,,,, -41400,0.07695556,0.04062382,,,,,,,,,,,,,,,,, -41500,0.054181464,0.0379822,,,,,,,,,,,,,,,,, -41600,0.09738934,0.034448523,,,,,,,,,,,,,,,,, -41664,,,0.9892215132713318,0.0366970710456371,0.2535574392803493,0.986094892024994,0.0470614470541477,0.1968372069270044,43793.0,0.9852290749549866,0.049860768020153,0.1985576710004122,43793.0,13217.90337395668,19096.59023118019,13217.90337395668,5875.849026441574,1.7248921394348145,0.0 -41700,0.038072303,0.03515253,,,,,,,,,,,,,,,,, -41800,0.0800056,0.033703696,,,,,,,,,,,,,,,,, -41900,0.062696934,0.03961587,,,,,,,,,,,,,,,,, -42000,0.1309359,0.035381816,,,,,,,,,,,,,,,,, -42100,0.08591978,0.03841223,,,,,,,,,,,,,,,,, -42200,0.04573601,0.03600306,,,,,,,,,,,,,,,,, -42300,0.041110065,0.03842771,,,,,,,,,,,,,,,,, -42400,0.03595663,0.037345205,,,,,,,,,,,,,,,,, -42422,,,0.9893390536308287,0.0362650416791439,0.2505960639193763,0.9860355854034424,0.0468550026416778,0.2100582771150316,43793.0,0.9852202534675598,0.0496694259345531,0.2027795862341287,43793.0,13457.914669513702,19439.718099355698,13457.914669513702,5978.910071611404,1.7608411312103271,0.0 -42500,0.05306524,0.032225084,,,,,,,,,,,,,,,,, -42600,0.056034334,0.036884427,,,,,,,,,,,,,,,,, -42700,0.055405177,0.037055396,,,,,,,,,,,,,,,,, -42800,0.09288147,0.036408007,,,,,,,,,,,,,,,,, -42900,0.055921454,0.044016674,,,,,,,,,,,,,,,,, -43000,0.1370969,0.04013812,,,,,,,,,,,,,,,,, -43100,0.045535192,0.03831421,,,,,,,,,,,,,,,,, -43185,,,0.9892221689224244,0.0366501882672309,0.2538185712336535,0.9859641790390016,0.0472843907773494,0.2026995334999386,43793.0,0.9850420951843262,0.0502906031906604,0.1873192221946274,43793.0,13697.976322174072,19782.68240070343,13697.976322174072,6081.750519990921,1.803160190582276,0.0 -43200,0.13315669,0.040090367,,,,,,,,,,,,,,,,, -43300,0.04223978,0.036788207,,,,,,,,,,,,,,,,, -43400,0.06600581,0.03589589,,,,,,,,,,,,,,,,, -43500,0.08249431,0.036277864,,,,,,,,,,,,,,,,, -43600,0.070701815,0.037856866,,,,,,,,,,,,,,,,, -43700,0.1558468,0.03792875,,,,,,,,,,,,,,,,, -43800,0.060936045,0.039855022,,,,,,,,,,,,,,,,, -43900,0.06854598,0.039182153,,,,,,,,,,,,,,,,, -43929,,,0.9891306161880492,0.0365006737411022,0.2575394367368832,0.9861127138137816,0.0469780974090099,0.2011696650770483,43793.0,0.9852215051651,0.0496812611818313,0.1970707128760987,43793.0,13938.120023965836,20122.90183091164,13938.120023965836,6181.76887345314,1.839141607284546,0.0 -44000,0.08208859,0.03929339,,,,,,,,,,,,,,,,, -44100,0.12821689,0.03456166,,,,,,,,,,,,,,,,, -44200,0.0778136,0.034270883,,,,,,,,,,,,,,,,, -44300,0.04988577,0.033958912,,,,,,,,,,,,,,,,, -44400,0.085729174,0.04156739,,,,,,,,,,,,,,,,, -44500,0.05849542,0.039489895,,,,,,,,,,,,,,,,, -44600,0.10255275,0.032909248,,,,,,,,,,,,,,,,, -44683,,,0.9893518686294556,0.0362337455153465,0.2565196842884465,0.9861443638801576,0.0469116233289241,0.2061833071985306,43793.0,0.985240876674652,0.0498799122869968,0.1917108982459069,43793.0,14178.33187842369,20472.78648281097,14178.33187842369,6291.385230064392,1.874242067337036,0.0 -44700,0.056735966,0.03698265,,,,,,,,,,,,,,,,, -44800,0.06546196,0.038099878,,,,,,,,,,,,,,,,, -44900,0.07054286,0.037194066,,,,,,,,,,,,,,,,, -45000,0.056077465,0.03965208,,,,,,,,,,,,,,,,, -45100,0.07191596,0.041129734,,,,,,,,,,,,,,,,, -45200,0.06499606,0.034206755,,,,,,,,,,,,,,,,, -45300,0.087026715,0.038022473,,,,,,,,,,,,,,,,, -45400,0.049566176,0.03622679,,,,,,,,,,,,,,,,, -45441,,,0.9894936084747314,0.0357696004211902,0.2657211170336563,0.9861301779747008,0.0469798222184181,0.2082862425106177,43793.0,0.9853171110153198,0.0496277213096618,0.1984186097203279,43793.0,14418.279221534727,20810.643618822098,14418.279221534727,6389.235752105713,1.912820100784301,0.0 -45500,0.055652454,0.034747425,,,,,,,,,,,,,,,,, -45600,0.057046235,0.033871833,,,,,,,,,,,,,,,,, -45700,0.04906808,0.037512753,,,,,,,,,,,,,,,,, -45800,0.056087326,0.035752553,,,,,,,,,,,,,,,,, -45900,0.08003013,0.0352714,,,,,,,,,,,,,,,,, -46000,0.0571211,0.03649424,,,,,,,,,,,,,,,,, -46100,0.042312615,0.037824698,,,,,,,,,,,,,,,,, -46200,0.07370897,0.02936723,,,,,,,,,,,,,,,,, -46209,,,0.9894861578941344,0.0355414748191833,0.2656732267817507,0.9860729575157166,0.046764601022005,0.2034237219377546,43793.0,0.985183596611023,0.0496970899403095,0.1975784774136722,43793.0,14658.41047000885,21156.919238328934,14658.41047000885,6495.324959516525,1.9479403495788568,0.0 -46300,0.121337876,0.03865522,,,,,,,,,,,,,,,,, -46400,0.047752455,0.035933807,,,,,,,,,,,,,,,,, -46500,0.1119985,0.036692627,,,,,,,,,,,,,,,,, -46600,0.05856296,0.03669379,,,,,,,,,,,,,,,,, -46700,0.08868686,0.038124405,,,,,,,,,,,,,,,,, -46800,0.06726062,0.03380443,,,,,,,,,,,,,,,,, -46900,0.061366588,0.03580955,,,,,,,,,,,,,,,,, -46976,,,0.9894858598709106,0.0355452746152877,0.2714054095701027,0.9861480593681335,0.0468971356749534,0.2085097651598301,43793.0,0.985218584537506,0.0495537556707859,0.1997976356789462,43793.0,14898.544739961624,21501.015640735623,14898.544739961624,6599.232083559036,1.9829089641571045,0.0 -47000,0.056643393,0.035148345,,,,,,,,,,,,,,,,, -47100,0.08915733,0.037719943,,,,,,,,,,,,,,,,, -47200,0.05647426,0.036987845,,,,,,,,,,,,,,,,, -47300,0.106808454,0.03943732,,,,,,,,,,,,,,,,, -47400,0.053713497,0.035231482,,,,,,,,,,,,,,,,, -47500,0.07652828,0.039229903,,,,,,,,,,,,,,,,, -47600,0.050632153,0.03307459,,,,,,,,,,,,,,,,, -47700,0.06398856,0.034837134,,,,,,,,,,,,,,,,, -47733,,,0.9893301129341124,0.0360425375401973,0.2642958725828656,0.9861845970153807,0.0468552336096763,0.2086522317130877,43793.0,0.9853790402412416,0.0494207330048084,0.2046231602859208,43793.0,15138.784644126892,21846.0258705616,15138.784644126892,6703.9477796554565,2.0181338787078857,0.0 -47800,0.10006428,0.0369333,,,,,,,,,,,,,,,,, -47900,0.11679552,0.042826265,,,,,,,,,,,,,,,,, -48000,0.083376676,0.033282775,,,,,,,,,,,,,,,,, -48100,0.061014235,0.035325516,,,,,,,,,,,,,,,,, -48200,0.077497296,0.037510797,,,,,,,,,,,,,,,,, -48300,0.06281374,0.03896522,,,,,,,,,,,,,,,,, -48400,0.08910286,0.03508895,,,,,,,,,,,,,,,,, -48488,,,0.9893452525138856,0.0367339625954628,0.2576216190027671,0.9859787821769714,0.047152355313301,0.2089292845674148,43793.0,0.985231637954712,0.0496157556772232,0.2004303623092527,43793.0,15378.88271021843,22187.8186275959,15378.88271021843,6805.587907791138,2.05293345451355,0.0 -48500,0.090465575,0.036665525,,,,,,,,,,,,,,,,, -48600,0.07029349,0.040920474,,,,,,,,,,,,,,,,, -48700,0.066444084,0.036442254,,,,,,,,,,,,,,,,, -48800,0.080100715,0.037633628,,,,,,,,,,,,,,,,, -48900,0.1497317,0.04289184,,,,,,,,,,,,,,,,, -49000,0.064518206,0.0352752,,,,,,,,,,,,,,,,, -49100,0.07150586,0.03923223,,,,,,,,,,,,,,,,, -49200,0.08956062,0.03624821,,,,,,,,,,,,,,,,, -49237,,,0.989425539970398,0.0356992445886135,0.2673758242081083,0.9861699938774108,0.0465958379209041,0.2079270938996903,43793.0,0.9853769540786744,0.0493928268551826,0.1981841083007525,43793.0,15618.961620092392,22531.986197948456,15618.961620092392,6909.61977314949,2.088233470916748,0.0 -49300,0.10272979,0.038165092,,,,,,,,,,,,,,,,, -49400,0.21622115,0.036756173,,,,,,,,,,,,,,,,, -49500,0.05457542,0.036979664,,,,,,,,,,,,,,,,, -49600,0.1195137,0.034875173,,,,,,,,,,,,,,,,, -49700,0.10796082,0.032982454,,,,,,,,,,,,,,,,, -49800,0.09542282,0.037108146,,,,,,,,,,,,,,,,, -49900,0.07313003,0.039746623,,,,,,,,,,,,,,,,, -49998,,,0.9894838929176332,0.0357767902314662,0.2691995550619271,0.9862085580825806,0.0462435148656368,0.21020474569661,43793.0,0.9853028059005736,0.0492511764168739,0.2006010894362136,43793.0,15859.194140434263,22877.863034963608,15859.194140434263,7015.207123994827,2.1260316371917725,0.0 -50000,0.056426544,0.03805737,,,,,,,,,,,,,,,,, -50100,0.098630615,0.03698839,,,,,,,,,,,,,,,,, -50200,0.0737299,0.033002883,,,,,,,,,,,,,,,,, -50300,0.050795153,0.035532847,,,,,,,,,,,,,,,,, -50400,0.1526369,0.033765968,,,,,,,,,,,,,,,,, -50500,0.080388315,0.040334698,,,,,,,,,,,,,,,,, -50600,0.060210165,0.03850567,,,,,,,,,,,,,,,,, -50700,0.071939394,0.03805623,,,,,,,,,,,,,,,,, -50746,,,0.9894869923591614,0.0357800237834453,0.2574080154233004,0.9860761761665344,0.0467116720974445,0.2097283898741605,43793.0,0.9852480292320251,0.0494797006249427,0.2006904166219487,43793.0,16099.170221567154,23220.043934583664,16099.170221567154,7117.35399389267,2.1619765758514404,0.0 -50800,0.06560708,0.035528798,,,,,,,,,,,,,,,,, -50900,0.05792928,0.034479246,,,,,,,,,,,,,,,,, -51000,0.08831135,0.03539307,,,,,,,,,,,,,,,,, -51100,0.060862537,0.03568076,,,,,,,,,,,,,,,,, -51200,0.09432317,0.031088868,,,,,,,,,,,,,,,,, -51300,0.08166249,0.03605457,,,,,,,,,,,,,,,,, -51400,0.09024346,0.035826106,,,,,,,,,,,,,,,,, -51500,0.087231874,0.03859815,,,,,,,,,,,,,,,,, -51507,,,0.989623725414276,0.0349702797830104,0.2895929079046951,0.986312448978424,0.0461284630000591,0.2125804910287261,43793.0,0.9854000806808472,0.0490581057965755,0.2058574711568974,43793.0,16339.22156381607,23565.807424545288,16339.22156381607,7223.010741472244,2.197328329086304,0.0 -51600,0.08176344,0.034788027,,,,,,,,,,,,,,,,, -51700,0.08435576,0.033178307,,,,,,,,,,,,,,,,, -51800,0.115303524,0.034108862,,,,,,,,,,,,,,,,, -51900,0.1010294,0.03630296,,,,,,,,,,,,,,,,, -52000,0.09047104,0.039102778,,,,,,,,,,,,,,,,, -52100,0.05824798,0.03897659,,,,,,,,,,,,,,,,, -52200,0.058101397,0.03245311,,,,,,,,,,,,,,,,, -52257,,,0.989595115184784,0.0349764414131641,0.2768938511577066,0.986238956451416,0.0464691221714019,0.2151909716914597,43793.0,0.9854320883750916,0.0492740757763385,0.2077359707988157,43793.0,16579.318229198456,23907.607943296432,16579.318229198456,7324.651810646057,2.236160516738892,0.0 -52300,0.077377744,0.03139,,,,,,,,,,,,,,,,, -52400,0.05245097,0.032283317,,,,,,,,,,,,,,,,, -52500,0.06902958,0.035135962,,,,,,,,,,,,,,,,, -52600,0.11142395,0.030546866,,,,,,,,,,,,,,,,, -52700,0.0810783,0.03763711,,,,,,,,,,,,,,,,, -52800,0.09922578,0.0322399,,,,,,,,,,,,,,,,, -52900,0.07584402,0.040832337,,,,,,,,,,,,,,,,, -53000,0.06511543,0.032947622,,,,,,,,,,,,,,,,, -53018,,,0.989634335041046,0.0349128283560276,0.2853589611694491,0.98628568649292,0.0463243946433067,0.2164889520117149,43793.0,0.9853954315185548,0.0490460842847824,0.2080408723529798,43793.0,16819.287693738937,24249.325675964355,16819.287693738937,7426.342364788055,2.273671865463257,0.0 -53100,0.06426112,0.031838745,,,,,,,,,,,,,,,,, -53200,0.0933905,0.032142065,,,,,,,,,,,,,,,,, -53300,0.09617607,0.03494774,,,,,,,,,,,,,,,,, -53400,0.11115303,0.037541904,,,,,,,,,,,,,,,,, -53500,0.0786357,0.037274867,,,,,,,,,,,,,,,,, -53600,0.09202931,0.040294234,,,,,,,,,,,,,,,,, -53700,0.14632578,0.031883772,,,,,,,,,,,,,,,,, -53768,,,0.9898422956466676,0.0343345664441585,0.2906300054524572,0.9863603711128236,0.0461886413395404,0.2171539752387562,43793.0,0.9854085445404052,0.0492413267493248,0.210766323279644,43793.0,17059.330296278,24591.159697771072,17059.330296278,7528.076402425766,2.310476779937744,0.0 -53800,0.15279327,0.036346115,,,,,,,,,,,,,,,,, -53900,0.096763544,0.037031986,,,,,,,,,,,,,,,,, -54000,0.08598362,0.03739958,,,,,,,,,,,,,,,,, -54100,0.06295095,0.03935087,,,,,,,,,,,,,,,,, -54200,0.07292707,0.037066594,,,,,,,,,,,,,,,,, -54300,0.092045136,0.031261656,,,,,,,,,,,,,,,,, -54400,0.085957356,0.03422762,,,,,,,,,,,,,,,,, -54500,0.08441129,0.035183944,,,,,,,,,,,,,,,,, -54516,,,0.9898531436920166,0.034180212765932,0.2880437519159216,0.9863904118537904,0.0460393577814102,0.2220284243850361,43793.0,0.9854624271392822,0.0491330437362194,0.2111562767797493,43793.0,17299.522981405258,24935.30538392067,17299.522981405258,7631.970919847488,2.348414182662964,0.0 -54600,0.06821673,0.039576415,,,,,,,,,,,,,,,,, -54700,0.08068359,0.037289627,,,,,,,,,,,,,,,,, -54800,0.085043766,0.0347119,,,,,,,,,,,,,,,,, -54900,0.06574007,0.033299603,,,,,,,,,,,,,,,,, -55000,0.13046575,0.035758402,,,,,,,,,,,,,,,,, -55100,0.07791745,0.032709703,,,,,,,,,,,,,,,,, -55200,0.058898922,0.03187015,,,,,,,,,,,,,,,,, -55281,,,0.9897654056549072,0.0345088131725788,0.284018912475966,0.986370086669922,0.0458903796970844,0.2217395258628096,43793.0,0.985429346561432,0.048757079988718,0.2053719159509973,43793.0,17539.468029499054,25282.556941986084,17539.468029499054,7739.218340873718,2.38744592666626,0.0 -55300,0.14616713,0.038151518,,,,,,,,,,,,,,,,, -55400,0.14080755,0.031125521,,,,,,,,,,,,,,,,, -55500,0.072806396,0.037448585,,,,,,,,,,,,,,,,, -55600,0.059921157,0.03532398,,,,,,,,,,,,,,,,, -55700,0.0928509,0.03403815,,,,,,,,,,,,,,,,, -55800,0.0794539,0.036876757,,,,,,,,,,,,,,,,, -55900,0.12056307,0.030476758,,,,,,,,,,,,,,,,, -56000,0.087106965,0.03746855,,,,,,,,,,,,,,,,, -56038,,,0.989662766456604,0.034886036068201,0.2843685079643878,0.9862491488456726,0.0460207685828208,0.2158384060379358,43793.0,0.9853495359420776,0.0489507801830768,0.2030077043569287,43793.0,17779.500022172928,25624.747854471207,17779.500022172928,7841.314211845398,2.4286468029022217,0.0 -56100,0.0718053,0.03462376,,,,,,,,,,,,,,,,, -56200,0.12215198,0.037889056,,,,,,,,,,,,,,,,, -56300,0.10080029,0.032301836,,,,,,,,,,,,,,,,, -56400,0.07444165,0.03420947,,,,,,,,,,,,,,,,, -56500,0.10535595,0.034733117,,,,,,,,,,,,,,,,, -56600,0.0964139,0.035694472,,,,,,,,,,,,,,,,, -56700,0.12052881,0.03600343,,,,,,,,,,,,,,,,, -56795,,,0.9896244406700134,0.0349636748433113,0.279530086905652,0.9862730503082277,0.045942336320877,0.2164093587546455,43793.0,0.9854097962379456,0.0487010665237903,0.2089325028808694,43793.0,18019.45709347725,25968.38992166519,18019.45709347725,7944.941290616989,2.466990232467652,0.0 -56800,0.088275656,0.03182714,,,,,,,,,,,,,,,,, -56900,0.10643354,0.03459491,,,,,,,,,,,,,,,,, -57000,0.08662338,0.035979293,,,,,,,,,,,,,,,,, -57100,0.10945864,0.03390678,,,,,,,,,,,,,,,,, -57200,0.10198492,0.033562884,,,,,,,,,,,,,,,,, -57300,0.100802004,0.034313608,,,,,,,,,,,,,,,,, -57400,0.103216216,0.031757288,,,,,,,,,,,,,,,,, -57500,0.14496742,0.038493376,,,,,,,,,,,,,,,,, -57561,,,0.9898160099983216,0.034742284566164,0.28430747955029,0.9861565828323364,0.0462729707360267,0.2205628987577899,43793.0,0.985299825668335,0.0489825680851936,0.2117912201186689,43793.0,18259.68850851059,26310.2358212471,18259.68850851059,8046.498164176941,2.504647016525269,0.0 -57600,0.11436214,0.03444787,,,,,,,,,,,,,,,,, -57700,0.080235824,0.037231516,,,,,,,,,,,,,,,,, -57800,0.10681933,0.034105726,,,,,,,,,,,,,,,,, -57900,0.15417346,0.03630937,,,,,,,,,,,,,,,,, -58000,0.16075099,0.03668728,,,,,,,,,,,,,,,,, -58100,0.11085169,0.0317833,,,,,,,,,,,,,,,,, -58200,0.10300039,0.03452971,,,,,,,,,,,,,,,,, -58253,,,,,,,,,,,,,,18477.287540912628,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/eval_measurements.csv deleted file mode 100644 index f4d4d47c9..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -103.042014837265,0.0,13.132712841033936,1,0,13.132712841033936,0.5256850719451904,0.7376683354377747,0.026033115042489,43793,116.17477297782898,0.5288358926773071,0.7364843487739563,0.0208787674226513,0.5270814895629883,0.737440824508667,0.024115497491982,43793 -205.7111649513245,0.3030984401702881,252.96980333328247,749,0,252.96980333328247,0.983132779598236,0.0632855519652366,0.0582553299126998,43793,459.00523805618286,0.9869271516799928,0.0506599321961402,0.0599720103019728,0.984121561050415,0.0600958317518234,0.0569143727465115,43793 -313.2293353080749,0.3306779861450195,493.0213305950165,1503,0,493.0213305950165,0.983751118183136,0.060521088540554,0.1071025414931932,43793,806.6227686405182,0.98733389377594,0.0471424348652362,0.1090940666267015,0.9847106337547302,0.0569747574627399,0.1081028403749525,43793 -416.8628304004669,0.3571460247039795,733.1634314060211,2259,0,733.1634314060211,0.984021544456482,0.0551568306982517,0.1416593044334367,43793,1150.4456820487976,0.9877690076828004,0.0432558394968509,0.1465678209953282,0.9849586486816406,0.0523457862436771,0.1377446736224435,43793 -519.2329788208008,0.3839807510375976,973.2205414772034,3014,0,973.2205414772034,0.9841373562812804,0.0538092143833637,0.1561581495213794,43793,1492.9210562705994,0.9881166219711304,0.0416135974228382,0.1762406938747769,0.9850816130638124,0.0510498061776161,0.1529185079755586,43793 -619.2677090167999,0.4104807376861572,1213.301376581192,3762,0,1213.301376581192,0.984641969203949,0.0520867556333541,0.1823691982921597,43793,1833.0843858718872,0.9884961247444152,0.0397477447986602,0.1983987575133168,0.9855943322181702,0.0491466596722602,0.1765521805688691,43793 -719.900552034378,0.438244104385376,1453.3801944255829,4510,0,1453.3801944255829,0.9848470687866212,0.0507810413837432,0.1934886227088521,43793,2173.8443462848663,0.9887267351150512,0.0388303436338901,0.2203952174799551,0.9857790470123292,0.0481136739253997,0.1902895326825564,43793 -823.4393961429596,0.4656791687011719,1693.555429935455,5261,0,1693.555429935455,0.984906017780304,0.0503714494407177,0.2053310842785677,43793,2517.606040239334,0.988910675048828,0.0377405993640422,0.237520259220091,0.985815167427063,0.047681551426649,0.2013130685938968,43793 -930.9390976428986,0.4940738677978515,1933.570939540863,6010,0,1933.570939540863,0.985258162021637,0.0491888448596,0.2155477305191308,43793,2865.170319318772,0.9891675114631652,0.0366061218082904,0.2542204334360372,0.9861252903938292,0.046566005796194,0.216148982664498,43793 -1034.9742851257324,0.5242447853088379,2173.825298309326,6751,0,2173.825298309326,0.9853482842445374,0.0490752831101417,0.2274546894836015,43793,3209.513146877289,0.9891411662101746,0.0364784225821495,0.2799573602779195,0.9861935377120972,0.0465248227119445,0.220708275832851,43793 -1142.034281015396,0.5527384281158447,2414.03213095665,7506,0,2414.03213095665,0.98521226644516,0.0492332838475704,0.2200865516961716,43793,3556.829462051392,0.9896240234375,0.0350513271987438,0.2970691115022231,0.9860489964485168,0.0466834381222724,0.2174827924270316,43793 -1248.3411090373993,0.5830700397491455,2654.151238441468,8241,0,2654.151238441468,0.9855904579162598,0.0480321943759918,0.235439831915132,43793,3903.309905767441,0.9899398684501648,0.0338945463299751,0.3269750978417046,0.9864467978477478,0.0452142804861068,0.238935644503845,43793 -1359.987991809845,0.6116423606872559,2894.265217065811,8992,0,2894.265217065811,0.9855828881263732,0.0478238090872764,0.2394972606796361,43793,4255.120206356049,0.990164875984192,0.0329089388251304,0.343574061270163,0.9863510131835938,0.0453560687601566,0.2322570774352511,43793 -1463.3309633731842,0.6411969661712646,3134.242573261261,9742,0,3134.242573261261,0.9856751561164856,0.0479455590248107,0.2425333064582841,43793,4598.49095082283,0.9900648593902588,0.0330399274826049,0.3367835881506164,0.9864760637283324,0.0452936850488185,0.2370356089909642,43793 -1567.4261529445648,0.6686229705810547,3374.392449617386,10495,0,3374.392449617386,0.9855492115020752,0.0476983599364757,0.2441579029718845,43793,4942.783852100372,0.9899468421936036,0.0334078483283519,0.3359655048713717,0.986441969871521,0.0449944920837879,0.2485300402087053,43793 -1669.9343152046204,0.6964986324310303,3614.5693922042847,11255,0,3614.5693922042847,0.9857833981513976,0.0474206693470478,0.2541108142727972,43793,5285.517816543579,0.9901042580604552,0.032832533121109,0.350322015439398,0.9866904020309448,0.0445432774722576,0.2615323398364682,43793 -1777.4200673103333,0.7287240028381348,3854.664434194565,11991,0,3854.664434194565,0.985854983329773,0.0473280772566795,0.2575606978341206,43793,5633.155611276627,0.9903876781463624,0.0318429432809352,0.3595448267341654,0.9868012070655824,0.0443866401910781,0.2597861086541841,43793 -1880.0096390247345,0.7573575973510742,4094.871671199799,12741,0,4094.871671199799,0.985881507396698,0.0474580600857734,0.2509873600028932,43793,5976.001766443253,0.9904323220252992,0.0317389890551567,0.3702695672831442,0.986777663230896,0.044543270021677,0.2593790380554532,43793 -1982.690937757492,0.785717248916626,4334.968943119049,13487,0,4334.968943119049,0.9858739376068116,0.0471687763929367,0.2534053815363039,43793,6318.829391956329,0.9905805587768556,0.031044103205204,0.3950207597065057,0.9867857694625854,0.0442429110407829,0.2639529675838791,43793 -2086.5756731033325,0.8285675048828125,4575.041334629059,14240,0,4575.041334629059,0.9859964847564696,0.0471503436565399,0.2576448284120592,43793,6662.850385427475,0.990578591823578,0.0308920368552207,0.3902965871620472,0.9868515133857728,0.0444140061736106,0.2635190960018144,43793 -2186.236001253128,0.8597097396850586,4815.053724527359,14992,0,4815.053724527359,0.9859017133712769,0.0474303290247917,0.2545299474361538,43793,7002.574951410294,0.9906534552574158,0.0304636172950267,0.4054393168607071,0.98679918050766,0.0444273576140403,0.2602008816337027,43793 -2290.80570602417,0.8886802196502686,5055.218120336533,15745,0,5055.218120336533,0.985859215259552,0.0472582839429378,0.2598296965103563,43793,7347.359417915344,0.9910573363304138,0.0293477196246385,0.4316723557821035,0.9867650866508484,0.0445890054106712,0.2701452169981196,43793 -2392.0925753116608,0.9176356792449952,5295.226839065552,16505,0,5295.226839065552,0.985924482345581,0.0468704588711261,0.2595190000831163,43793,7688.704651594162,0.9910054802894592,0.0295385904610157,0.4293138323079307,0.9866920113563538,0.044278547167778,0.2718293568793558,43793 -2490.3105919361115,0.9467637538909912,5535.3997938632965,17266,0,5535.3997938632965,0.9859097599983216,0.0468415208160877,0.2596608132025662,43793,8027.145152568817,0.991085171699524,0.0294013731181621,0.4287108022495856,0.9867752194404602,0.0442826114594936,0.2713982672671105,43793 -2593.659286260605,0.9766237735748292,5775.622433185577,18021,0,5775.622433185577,0.9859362840652466,0.046722188591957,0.2607173361411814,43793,8370.766879558563,0.9907681941986084,0.030180849134922,0.4159931679730896,0.9868023991584778,0.0440882481634616,0.2719252917849784,43793 -2694.3736753463745,1.0064642429351809,6015.79421544075,18785,0,6015.79421544075,0.9860011339187622,0.0467405170202255,0.2595897761150846,43793,8711.703551054,0.99088716506958,0.0299148950725793,0.4029454179103736,0.9867849946022034,0.0442803464829921,0.2694011256778948,43793 -2796.258656024933,1.0368270874023438,6255.988671064377,19526,0,6255.988671064377,0.9861102104187012,0.0469335913658142,0.2687963123652824,43793,9053.836593389511,0.9911262392997742,0.0291009750217199,0.4268990125422792,0.9868763089179992,0.0442834086716175,0.2757139397117077,43793 -2896.8540201187134,1.066905498504639,6496.154449701309,20280,0,6496.154449701309,0.986026406288147,0.0471141971647739,0.2621612672994907,43793,9394.648380041122,0.9910481572151184,0.0292404498904943,0.4309862850495052,0.9868617057800292,0.0443498492240905,0.2763822320770312,43793 -3000.609070301056,1.0970373153686523,6736.26745891571,21033,0,6736.26745891571,0.9860247373580932,0.0469842851161956,0.2614326363612846,43793,9738.567348957062,0.991021692752838,0.0290874261409044,0.4451907895185464,0.98687082529068,0.0442472547292709,0.2728488769441149,43793 -3104.023591041565,1.1280357837677002,6976.427305936813,21799,0,6976.427305936813,0.9859337210655212,0.0468870475888252,0.264711273784717,43793,10082.193821668625,0.9912761449813844,0.0282179340720176,0.4582781325632959,0.9867464303970336,0.0442386865615844,0.2674546311890728,43793 -3205.8467667102814,1.1613752841949463,7216.409387588501,22545,0,7216.409387588501,0.9859463572502136,0.0472444929182529,0.2622843075814916,43793,10424.05526447296,0.9913682341575624,0.0281646307557821,0.4644050467140248,0.9868109226226808,0.0444370955228805,0.2704546050163061,43793 -3308.191586256027,1.1935296058654783,7456.590592622757,23296,0,7456.590592622757,0.9861654043197632,0.0465232543647289,0.2689258721748647,43793,10766.633274793625,0.9915869235992432,0.0273259282112121,0.470983335381583,0.9869566559791564,0.0440373495221138,0.2790912138407179,43793 -3407.847371816635,1.2242672443389893,7696.545921325684,24052,0,7696.545921325684,0.9862087965011596,0.0469167828559875,0.2769371208521102,43793,11106.29555439949,0.991644561290741,0.0270146057009696,0.4884058360080612,0.9869595170021056,0.0441711880266666,0.2820494999826886,43793 -3514.256602048874,1.25538969039917,7936.503590583801,24798,0,7936.503590583801,0.9860756993293762,0.0467779636383056,0.2672637130743255,43793,11452.717252254486,0.991566836833954,0.0274626016616821,0.4743264384794057,0.9868316650390624,0.0442666038870811,0.2701833417002129,43793 -3614.790126800537,1.2863295078277588,8176.548576593399,25556,0,8176.548576593399,0.9861093759536744,0.0466210879385471,0.2705003644585464,43793,11793.34719634056,0.9913595914840698,0.028126923367381,0.4622975459881457,0.9869400262832642,0.0441190637648105,0.2790331791350799,43793 -3713.144189834594,1.3176674842834473,8416.634189844131,26314,0,8416.634189844131,0.9859468340873718,0.0466155149042606,0.2698534271230323,43793,12131.83881664276,0.9912744164466858,0.0281446538865566,0.4479110439168735,0.9868324398994446,0.0439815558493137,0.2810698441376711,43793 -3816.244790554047,1.3499679565429688,8656.794304132462,27065,0,8656.794304132462,0.9860487580299376,0.0471839122474193,0.2670339478514961,43793,12475.151976585388,0.9913526177406312,0.0279363002628088,0.4632723640628128,0.986867368221283,0.0443237386643886,0.2782958896685469,43793 -3917.5429599285126,1.381486415863037,8896.875655651093,27820,0,8896.875655651093,0.9861077070236206,0.0469717867672443,0.264931771800272,43793,12816.583455085754,0.9914987683296204,0.0274489261209964,0.4683983500487964,0.9869043231010436,0.0442365556955337,0.2760711966392069,43793 -4023.892383813858,1.4207780361175537,9137.05087184906,28577,0,9137.05087184906,0.9861435294151306,0.0470503196120262,0.2718404436297,43793,13163.168231010435,0.9915719032287598,0.0271757207810878,0.4733769339022719,0.9869562983512878,0.044232428073883,0.2856657416394459,43793 -4126.997245788574,1.452960968017578,9377.265213012695,29337,0,9377.265213012695,0.9861460328102112,0.0468155331909656,0.2694635318031406,43793,13506.539906024933,0.991813600063324,0.0265144873410463,0.4933331247934246,0.987015962600708,0.0440828762948513,0.2803936676206833,43793 -4232.770231723785,1.485065221786499,9617.43070077896,30096,0,9617.43070077896,0.9861797094345092,0.0468039661645889,0.2741908935948261,43793,13852.530488491058,0.9918859601020812,0.0261133890599012,0.5055414899153353,0.9869863390922546,0.0437818579375743,0.2895730354147849,43793 -4334.496691703796,1.5186164379119873,9857.432433843613,30853,0,9857.432433843613,0.9860929846763612,0.0471954904496669,0.2743780890275217,43793,14194.312682628632,0.9918944835662842,0.0259779468178749,0.5109262523099191,0.986935555934906,0.0444079004228115,0.2870528055093321,43793 -4440.5671174526215,1.551112174987793,10097.557705879211,31607,0,10097.557705879211,0.9863023161888124,0.0465546734631061,0.2826208197222706,43793,14540.564175128937,0.9922711849212646,0.0249787494540214,0.526230375162185,0.9870455861091614,0.0440018251538276,0.2825439937763267,43793 -4541.132849693298,1.584326982498169,10337.713238954544,32365,0,10337.713238954544,0.986120343208313,0.0476447828114032,0.2724481736220283,43793,14881.338361740112,0.9920186400413512,0.0256517603993415,0.5237306076253523,0.9870054125785828,0.0448300689458847,0.2819343967433677,43793 -4642.83878827095,1.6173911094665527,10577.716572523115,33121,0,10577.716572523115,0.9861447811126708,0.0467436574399471,0.2786746065075245,43793,15223.10135102272,0.9920905828475952,0.0256081577390432,0.5157125620277023,0.9869648218154908,0.0443299151957035,0.2834494862434448,43793 -4751.999447107315,1.6501054763793943,10817.91072845459,33874,0,10817.91072845459,0.9861708879470824,0.0472316332161426,0.2744900853244321,43793,15572.50892305374,0.9917319416999816,0.0265057682991027,0.4951007638370917,0.9868974089622498,0.0445404797792434,0.284593376840681,43793 -4849.44939661026,1.683269739151001,11058.15099453926,34632,0,11058.15099453926,0.9862428903579712,0.0472697019577026,0.2789517823758971,43793,15910.252793550491,0.9918310642242432,0.0261690374463796,0.4915393588454686,0.987043559551239,0.0445980280637741,0.2865059627513893,43793 -4948.712647199631,1.7171683311462402,11298.339334487917,35378,0,11298.339334487917,0.9859455227851868,0.0472081154584884,0.2693484981572296,43793,16249.758395671844,0.9919923543930054,0.0257877148687839,0.5099326315816797,0.9868158102035522,0.0444606207311153,0.2835898925780593,43793 -5049.473174333572,1.7545788288116455,11538.558787345886,36126,0,11538.558787345886,0.986178457736969,0.0475702285766601,0.2711337537526851,43793,16590.798290252686,0.992000699043274,0.0254850517958402,0.5258720739264008,0.987106442451477,0.0446226224303245,0.2882472900207319,43793 -5152.5869171619415,1.787407159805298,11778.575371980667,36889,0,11778.575371980667,0.986159086227417,0.0470145195722579,0.2754729484835849,43793,16933.982120752335,0.9921327829360962,0.0251639336347579,0.5333610794699605,0.986931085586548,0.044403463602066,0.2876226417928679,43793 -5256.769897699356,2.2428689002990723,12018.111292123796,37648,0,12018.111292123796,0.9862471222877502,0.0473320521414279,0.2776725133597497,43793,17278.17639899254,0.9923455119132996,0.0244872458279132,0.5460171638954645,0.9870200157165528,0.0446493402123451,0.2832415157203511,43793 -5356.201946020126,2.277024269104004,12258.103625297546,38400,0,12258.103625297546,0.9862862825393676,0.0478410758078098,0.2788159578705024,43793,17617.65689277649,0.9923316836357116,0.0241330750286579,0.5530907366085315,0.9871101379394532,0.0448829308152198,0.2893611700860986,43793 -5457.145552873611,2.3104984760284424,12498.116117477415,39153,0,12498.116117477415,0.9862328171730042,0.047747578471899,0.2744675293708428,43793,17958.667691469193,0.9926602840423584,0.0234204418957233,0.5637032087035159,0.987121880054474,0.0448669232428073,0.2899202695664321,43793 -5556.928096294403,2.3550827503204346,12738.242744922638,39908,0,12738.242744922638,0.9862656593322754,0.0473434291779994,0.2774957648688163,43793,18298.64096140861,0.9927948713302612,0.0230580810457468,0.5783965721104753,0.9870395064353944,0.0444376207888126,0.286592025599736,43793 -5658.854161739349,2.401482343673706,12978.360262870789,40666,0,12978.360262870789,0.9861329793930054,0.0476587414741516,0.2742626805934189,43793,18640.750893592834,0.992606282234192,0.0236912108957767,0.5603386086431281,0.986970067024231,0.0448714643716812,0.2884174564599778,43793 -5767.809935808182,2.435475349426269,13218.463328838348,41428,0,13218.463328838348,0.9861510992050172,0.0476103760302066,0.2792500978473953,43793,18989.86360406876,0.9925146102905272,0.023902615532279,0.5628586654561019,0.9869022965431212,0.0450378134846687,0.2864156006446135,43793 -5865.882692337036,2.4739890098571777,13458.669250249864,42176,0,13458.669250249864,0.986019253730774,0.0477226413786411,0.2770095178376321,43793,19328.203336000443,0.9924070835113524,0.024378603324294,0.5376271174259066,0.986899435520172,0.0449063070118427,0.2886609224145899,43793 -5969.3086223602295,2.510117053985596,13698.63403391838,42932,0,13698.63403391838,0.9862534403800964,0.0477136820554733,0.2787048486685619,43793,19671.65065932274,0.9924079179763794,0.024059934541583,0.5503674235235126,0.9870553016662598,0.0449348129332065,0.2921576254088969,43793 -6071.934442520142,2.5530261993408203,13938.807181596756,43675,0,13938.807181596756,0.9862707257270812,0.048059307038784,0.275700904360155,43793,20014.515516281128,0.9927031993865968,0.0233249384909868,0.5674737634256353,0.9871320724487304,0.0450026728212833,0.2952275795871997,43793 -6173.127116203308,2.588895082473755,14179.057677268982,44433,0,14179.057677268982,0.9861220121383668,0.0477507822215557,0.2761335127109555,43793,20356.0144674778,0.9927995800971984,0.0229585859924554,0.5733450457802596,0.98695707321167,0.0449227802455425,0.2942502749410342,43793 -6276.180570602417,2.632185935974121,14419.24270439148,45198,0,14419.24270439148,0.9862256646156312,0.0481463596224784,0.2737167666026969,43793,20699.316056489944,0.992863118648529,0.0225877240300178,0.5935786349699248,0.9871166348457336,0.0451795570552349,0.2876991151416316,43793 -6375.948750257492,2.666360378265381,14659.305247306824,45945,0,14659.305247306824,0.9862942695617676,0.0487545803189277,0.2747491794169562,43793,21039.20338702202,0.9931178092956544,0.021580209955573,0.6017364886534111,0.9870991706848145,0.0457594506442546,0.2840771218195006,43793 -6475.143660068512,2.705780267715454,14899.513046503069,46709,0,14899.513046503069,0.9862281680107116,0.048151209950447,0.2759195730682534,43793,21378.66554641724,0.9932405352592468,0.0215317849069833,0.607432035924848,0.9870342016220092,0.0452977679669857,0.2873679751801644,43793 -6577.043631315231,2.747018337249756,15139.652928829191,47468,0,15139.652928829191,0.986265242099762,0.0487691946327686,0.2753911116802991,43793,21720.766462802887,0.9933563470840454,0.0209414884448051,0.6339597542179121,0.9871182441711426,0.0456337742507457,0.2866939484275668,43793 -6674.879307746887,2.782480478286743,15379.684081554413,48225,0,15379.684081554413,0.9862454533576964,0.0489246025681495,0.2751573313937183,43793,22058.68881917,0.9933629035949708,0.021108966320753,0.6215161265846207,0.9870553016662598,0.045908585190773,0.2829216172937069,43793 -6775.858198404312,2.8195629119873047,15619.90098786354,48987,0,15619.90098786354,0.9862496256828308,0.0487202107906341,0.2773312555186638,43793,22399.94140410424,0.9932666420936584,0.0212262161076068,0.611152309972192,0.9870601892471312,0.0456889979541301,0.290932891249044,43793 -6875.612053394318,2.857351303100586,15860.069697856905,49745,0,15860.069697856905,0.9861186742782592,0.0488905385136604,0.278546482883604,43793,22739.921632766724,0.9932006001472472,0.021596472710371,0.6055980733194052,0.9869509935379028,0.0460079833865165,0.2892799181159078,43793 -6974.607265710831,2.895205497741699,16100.29632639885,50487,0,16100.29632639885,0.9861944913864136,0.0490744151175022,0.2775945821958022,43793,23079.203650951385,0.9930036664009094,0.0220139380544424,0.593615504286974,0.9869891405105592,0.046247225254774,0.2876874796115182,43793 -7078.8386816978455,2.932920217514038,16340.390928268433,51253,0,16340.390928268433,0.986276626586914,0.0492494367063045,0.2809767090009311,43793,23423.58703923225,0.9931820631027222,0.0213425699621438,0.6134943919388604,0.9870630502700806,0.0462197475135326,0.2916114859958877,43793 -7177.91805267334,2.969144821166992,16580.394956111908,52016,0,16580.394956111908,0.986294686794281,0.0494212321937084,0.2847148471225769,43793,23762.726235628128,0.993344247341156,0.0208707619458436,0.6242266281073867,0.9870747923851012,0.0464723855257034,0.2907870160289012,43793 -7282.37335062027,3.005022048950196,16820.457956552505,52779,0,16820.457956552505,0.9862024784088136,0.0493457950651645,0.2773575915352462,43793,24107.30019426346,0.993465006351471,0.0205740891396999,0.6301589141153873,0.987104833126068,0.046230211853981,0.2900999925999507,43793 -7380.0420706272125,3.041551351547241,17060.670438051224,53539,0,17060.670438051224,0.9862854480743408,0.0497939959168434,0.2795073140182851,43793,24445.238028764725,0.9936485290527344,0.0197937209159135,0.6516988153813621,0.9871584177017212,0.0465914942324161,0.2926621584624282,43793 -7483.484494686127,3.0773909091949463,17300.793816804886,54287,0,17300.793816804886,0.9861018061637878,0.0499334521591663,0.2787303739599055,43793,24788.86217021942,0.993804931640625,0.0193524714559316,0.6605385215433868,0.9869335293769836,0.0468230247497558,0.2932988620035959,43793 -7586.947088718414,3.1195337772369385,17540.84943985939,55034,0,17540.84943985939,0.9861464500427246,0.0501525029540061,0.2725020764620905,43793,25132.44391345977,0.9940738677978516,0.018616709858179,0.6824790188145755,0.9870139360427856,0.0467959530651569,0.291244495393672,43793 -7687.15647649765,3.157621622085572,17780.89176774025,55785,0,17780.89176774025,0.9862500429153442,0.0502575188875198,0.2784436858642199,43793,25472.753759860992,0.994262993335724,0.0181721411645412,0.6914053178116203,0.9870870113372804,0.0470523163676261,0.2891383001431945,43793 -7789.804910898209,3.195686817169189,18020.839967250824,56537,0,18020.839967250824,0.986182689666748,0.0502708926796913,0.2769247715519446,43793,25815.408913373947,0.994239866733551,0.0183276608586311,0.685328157042368,0.9870309829711914,0.0470251254737377,0.2896792143578579,43793 -7888.045190811157,3.2336459159851074,18260.91561436653,57296,0,18260.91561436653,0.9863001704216003,0.05074793100357056,0.2777486826024884,43793,26153.78284072876,0.9940085411071777,0.01863333210349083,0.6795497812086909,0.9870337843894958,0.04753246530890465,0.2885810856821342,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/measurements.csv deleted file mode 100644 index 2fca571a5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/measurements.csv +++ /dev/null @@ -1,659 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.8872375,0.7344205,,,,,,,,,,,,,,,,, -1,,,0.5288358926773071,0.7364843487739563,0.0208787674226513,0.5270814895629883,0.737440824508667,0.024115497491982,43793.0,0.5256850719451904,0.7376683354377747,0.026033115042489,43793.0,13.132712841033936,116.17477297782898,13.132712841033936,103.042014837265,0.0,0.0 -100,0.28229994,0.2644097,,,,,,,,,,,,,,,,, -200,0.09104837,0.10725577,,,,,,,,,,,,,,,,, -300,0.03161377,0.067231245,,,,,,,,,,,,,,,,, -400,0.021831676,0.057807364,,,,,,,,,,,,,,,,, -500,0.037039593,0.061670296,,,,,,,,,,,,,,,,, -600,0.019142315,0.05476305,,,,,,,,,,,,,,,,, -700,0.01672886,0.049029566,,,,,,,,,,,,,,,,, -749,,,0.9869271516799928,0.0506599321961402,0.0599720103019728,0.984121561050415,0.0600958317518234,0.0569143727465115,43793.0,0.983132779598236,0.0632855519652366,0.0582553299126998,43793.0,252.96980333328247,459.00523805618286,252.96980333328247,205.7111649513245,0.3030984401702881,0.0 -800,0.03211605,0.05555944,,,,,,,,,,,,,,,,, -900,0.020857353,0.04917189,,,,,,,,,,,,,,,,, -1000,0.022346692,0.049923457,,,,,,,,,,,,,,,,, -1100,0.02265306,0.043555737,,,,,,,,,,,,,,,,, -1200,0.048484147,0.048713338,,,,,,,,,,,,,,,,, -1300,0.035756722,0.04497696,,,,,,,,,,,,,,,,, -1400,0.022501249,0.049350254,,,,,,,,,,,,,,,,, -1500,0.024555128,0.05350736,,,,,,,,,,,,,,,,, -1503,,,0.98733389377594,0.0471424348652362,0.1090940666267015,0.9847106337547302,0.0569747574627399,0.1081028403749525,43793.0,0.983751118183136,0.060521088540554,0.1071025414931932,43793.0,493.0213305950165,806.6227686405182,493.0213305950165,313.2293353080749,0.3306779861450195,0.0 -1600,0.026684927,0.048454598,,,,,,,,,,,,,,,,, -1700,0.018571533,0.047423888,,,,,,,,,,,,,,,,, -1800,0.017066937,0.04927708,,,,,,,,,,,,,,,,, -1900,0.015868302,0.051108174,,,,,,,,,,,,,,,,, -2000,0.015154433,0.044937477,,,,,,,,,,,,,,,,, -2100,0.0131393755,0.045363788,,,,,,,,,,,,,,,,, -2200,0.0147870295,0.044672508,,,,,,,,,,,,,,,,, -2259,,,0.9877690076828004,0.0432558394968509,0.1465678209953282,0.9849586486816406,0.0523457862436771,0.1377446736224435,43793.0,0.984021544456482,0.0551568306982517,0.1416593044334367,43793.0,733.1634314060211,1150.4456820487976,733.1634314060211,416.8628304004669,0.3571460247039795,0.0 -2300,0.014691583,0.042257052,,,,,,,,,,,,,,,,, -2400,0.01646396,0.046212487,,,,,,,,,,,,,,,,, -2500,0.021930981,0.047135234,,,,,,,,,,,,,,,,, -2600,0.028436627,0.040099625,,,,,,,,,,,,,,,,, -2700,0.021845868,0.047131438,,,,,,,,,,,,,,,,, -2800,0.017089363,0.045878123,,,,,,,,,,,,,,,,, -2900,0.015437418,0.042125296,,,,,,,,,,,,,,,,, -3000,0.010316322,0.038349643,,,,,,,,,,,,,,,,, -3014,,,0.9881166219711304,0.0416135974228382,0.1762406938747769,0.9850816130638124,0.0510498061776161,0.1529185079755586,43793.0,0.9841373562812804,0.0538092143833637,0.1561581495213794,43793.0,973.2205414772034,1492.9210562705994,973.2205414772034,519.2329788208008,0.3839807510375976,0.0 -3100,0.015926022,0.044541612,,,,,,,,,,,,,,,,, -3200,0.026264325,0.0398983,,,,,,,,,,,,,,,,, -3300,0.012014005,0.042099334,,,,,,,,,,,,,,,,, -3400,0.012565531,0.037721924,,,,,,,,,,,,,,,,, -3500,0.015304007,0.04526736,,,,,,,,,,,,,,,,, -3600,0.019025212,0.042266496,,,,,,,,,,,,,,,,, -3700,0.019871539,0.038601987,,,,,,,,,,,,,,,,, -3762,,,0.9884961247444152,0.0397477447986602,0.1983987575133168,0.9855943322181702,0.0491466596722602,0.1765521805688691,43793.0,0.984641969203949,0.0520867556333541,0.1823691982921597,43793.0,1213.301376581192,1833.0843858718872,1213.301376581192,619.2677090167999,0.4104807376861572,0.0 -3800,0.011475528,0.039165802,,,,,,,,,,,,,,,,, -3900,0.013387885,0.040729173,,,,,,,,,,,,,,,,, -4000,0.011713068,0.041350238,,,,,,,,,,,,,,,,, -4100,0.012952226,0.036852855,,,,,,,,,,,,,,,,, -4200,0.020642908,0.039762888,,,,,,,,,,,,,,,,, -4300,0.012498501,0.040130164,,,,,,,,,,,,,,,,, -4400,0.010406245,0.04026993,,,,,,,,,,,,,,,,, -4500,0.010003535,0.03475768,,,,,,,,,,,,,,,,, -4510,,,0.9887267351150512,0.0388303436338901,0.2203952174799551,0.9857790470123292,0.0481136739253997,0.1902895326825564,43793.0,0.9848470687866212,0.0507810413837432,0.1934886227088521,43793.0,1453.3801944255829,2173.8443462848663,1453.3801944255829,719.900552034378,0.438244104385376,0.0 -4600,0.0112053845,0.04002699,,,,,,,,,,,,,,,,, -4700,0.019208023,0.04658735,,,,,,,,,,,,,,,,, -4800,0.01048442,0.033851296,,,,,,,,,,,,,,,,, -4900,0.014246771,0.04004497,,,,,,,,,,,,,,,,, -5000,0.010414839,0.036717664,,,,,,,,,,,,,,,,, -5100,0.012728004,0.036787525,,,,,,,,,,,,,,,,, -5200,0.012142495,0.037607994,,,,,,,,,,,,,,,,, -5261,,,0.988910675048828,0.0377405993640422,0.237520259220091,0.985815167427063,0.047681551426649,0.2013130685938968,43793.0,0.984906017780304,0.0503714494407177,0.2053310842785677,43793.0,1693.555429935455,2517.606040239334,1693.555429935455,823.4393961429596,0.4656791687011719,0.0 -5300,0.015345512,0.037765622,,,,,,,,,,,,,,,,, -5400,0.014786914,0.03933934,,,,,,,,,,,,,,,,, -5500,0.01965823,0.038571656,,,,,,,,,,,,,,,,, -5600,0.0113170305,0.03614555,,,,,,,,,,,,,,,,, -5700,0.011031,0.040022053,,,,,,,,,,,,,,,,, -5800,0.018066129,0.036627688,,,,,,,,,,,,,,,,, -5900,0.013783631,0.03563941,,,,,,,,,,,,,,,,, -6000,0.0140257515,0.034663334,,,,,,,,,,,,,,,,, -6010,,,0.9891675114631652,0.0366061218082904,0.2542204334360372,0.9861252903938292,0.046566005796194,0.216148982664498,43793.0,0.985258162021637,0.0491888448596,0.2155477305191308,43793.0,1933.570939540863,2865.170319318772,1933.570939540863,930.9390976428986,0.4940738677978515,0.0 -6100,0.01462496,0.037545513,,,,,,,,,,,,,,,,, -6200,0.015750414,0.039153002,,,,,,,,,,,,,,,,, -6300,0.011553227,0.03594365,,,,,,,,,,,,,,,,, -6400,0.015157121,0.034064125,,,,,,,,,,,,,,,,, -6500,0.012778345,0.03919976,,,,,,,,,,,,,,,,, -6600,0.047233388,0.037869163,,,,,,,,,,,,,,,,, -6700,0.010939436,0.03311363,,,,,,,,,,,,,,,,, -6751,,,0.9891411662101746,0.0364784225821495,0.2799573602779195,0.9861935377120972,0.0465248227119445,0.220708275832851,43793.0,0.9853482842445374,0.0490752831101417,0.2274546894836015,43793.0,2173.825298309326,3209.513146877289,2173.825298309326,1034.9742851257324,0.5242447853088379,0.0 -6800,0.010812062,0.034811623,,,,,,,,,,,,,,,,, -6900,0.01627334,0.039404213,,,,,,,,,,,,,,,,, -7000,0.010254026,0.035460044,,,,,,,,,,,,,,,,, -7100,0.019770714,0.03750555,,,,,,,,,,,,,,,,, -7200,0.02374922,0.03678213,,,,,,,,,,,,,,,,, -7300,0.016348768,0.041256223,,,,,,,,,,,,,,,,, -7400,0.021271128,0.037961267,,,,,,,,,,,,,,,,, -7500,0.0151063055,0.0349484,,,,,,,,,,,,,,,,, -7506,,,0.9896240234375,0.0350513271987438,0.2970691115022231,0.9860489964485168,0.0466834381222724,0.2174827924270316,43793.0,0.98521226644516,0.0492332838475704,0.2200865516961716,43793.0,2414.03213095665,3556.829462051392,2414.03213095665,1142.034281015396,0.5527384281158447,0.0 -7600,0.016864827,0.037175417,,,,,,,,,,,,,,,,, -7700,0.016946442,0.038066443,,,,,,,,,,,,,,,,, -7800,0.018335085,0.036907725,,,,,,,,,,,,,,,,, -7900,0.015273083,0.0316976,,,,,,,,,,,,,,,,, -8000,0.01315429,0.032753784,,,,,,,,,,,,,,,,, -8100,0.02146206,0.03850077,,,,,,,,,,,,,,,,, -8200,0.02228666,0.03535281,,,,,,,,,,,,,,,,, -8241,,,0.9899398684501648,0.0338945463299751,0.3269750978417046,0.9864467978477478,0.0452142804861068,0.238935644503845,43793.0,0.9855904579162598,0.0480321943759918,0.235439831915132,43793.0,2654.151238441468,3903.309905767441,2654.151238441468,1248.3411090373993,0.5830700397491455,0.0 -8300,0.015061823,0.03247486,,,,,,,,,,,,,,,,, -8400,0.015847418,0.035603516,,,,,,,,,,,,,,,,, -8500,0.021450281,0.036910746,,,,,,,,,,,,,,,,, -8600,0.013043511,0.03340092,,,,,,,,,,,,,,,,, -8700,0.01734431,0.035671785,,,,,,,,,,,,,,,,, -8800,0.015576391,0.034411702,,,,,,,,,,,,,,,,, -8900,0.018329676,0.03878926,,,,,,,,,,,,,,,,, -8992,,,0.990164875984192,0.0329089388251304,0.343574061270163,0.9863510131835938,0.0453560687601566,0.2322570774352511,43793.0,0.9855828881263732,0.0478238090872764,0.2394972606796361,43793.0,2894.265217065811,4255.120206356049,2894.265217065811,1359.987991809845,0.6116423606872559,0.0 -9000,0.0254718,0.038154215,,,,,,,,,,,,,,,,, -9100,0.025449498,0.034639295,,,,,,,,,,,,,,,,, -9200,0.016032243,0.03110881,,,,,,,,,,,,,,,,, -9300,0.01685656,0.03540154,,,,,,,,,,,,,,,,, -9400,0.030608628,0.03535014,,,,,,,,,,,,,,,,, -9500,0.018593814,0.03425137,,,,,,,,,,,,,,,,, -9600,0.018987903,0.039226297,,,,,,,,,,,,,,,,, -9700,0.02680728,0.033179548,,,,,,,,,,,,,,,,, -9742,,,0.9900648593902588,0.0330399274826049,0.3367835881506164,0.9864760637283324,0.0452936850488185,0.2370356089909642,43793.0,0.9856751561164856,0.0479455590248107,0.2425333064582841,43793.0,3134.242573261261,4598.49095082283,3134.242573261261,1463.3309633731842,0.6411969661712646,0.0 -9800,0.017200649,0.034213096,,,,,,,,,,,,,,,,, -9900,0.031695757,0.037518665,,,,,,,,,,,,,,,,, -10000,0.017700681,0.03417122,,,,,,,,,,,,,,,,, -10100,0.014795807,0.030657688,,,,,,,,,,,,,,,,, -10200,0.018287458,0.0326664,,,,,,,,,,,,,,,,, -10300,0.020244109,0.035216983,,,,,,,,,,,,,,,,, -10400,0.027631247,0.033607025,,,,,,,,,,,,,,,,, -10495,,,0.9899468421936036,0.0334078483283519,0.3359655048713717,0.986441969871521,0.0449944920837879,0.2485300402087053,43793.0,0.9855492115020752,0.0476983599364757,0.2441579029718845,43793.0,3374.392449617386,4942.783852100372,3374.392449617386,1567.4261529445648,0.6686229705810547,0.0 -10500,0.022729926,0.035399254,,,,,,,,,,,,,,,,, -10600,0.017215637,0.031219102,,,,,,,,,,,,,,,,, -10700,0.026343178,0.03412699,,,,,,,,,,,,,,,,, -10800,0.021091707,0.032279246,,,,,,,,,,,,,,,,, -10900,0.028530773,0.036050357,,,,,,,,,,,,,,,,, -11000,0.019593379,0.034391157,,,,,,,,,,,,,,,,, -11100,0.024615657,0.03473306,,,,,,,,,,,,,,,,, -11200,0.021833997,0.03809314,,,,,,,,,,,,,,,,, -11255,,,0.9901042580604552,0.032832533121109,0.350322015439398,0.9866904020309448,0.0445432774722576,0.2615323398364682,43793.0,0.9857833981513976,0.0474206693470478,0.2541108142727972,43793.0,3614.5693922042847,5285.517816543579,3614.5693922042847,1669.9343152046204,0.6964986324310303,0.0 -11300,0.030097421,0.032961253,,,,,,,,,,,,,,,,, -11400,0.028080963,0.033164278,,,,,,,,,,,,,,,,, -11500,0.026674872,0.035189692,,,,,,,,,,,,,,,,, -11600,0.02242379,0.03521454,,,,,,,,,,,,,,,,, -11700,0.023681495,0.03373937,,,,,,,,,,,,,,,,, -11800,0.03627676,0.033167284,,,,,,,,,,,,,,,,, -11900,0.023166945,0.03352572,,,,,,,,,,,,,,,,, -11991,,,0.9903876781463624,0.0318429432809352,0.3595448267341654,0.9868012070655824,0.0443866401910781,0.2597861086541841,43793.0,0.985854983329773,0.0473280772566795,0.2575606978341206,43793.0,3854.664434194565,5633.155611276627,3854.664434194565,1777.4200673103333,0.7287240028381348,0.0 -12000,0.028640779,0.03235442,,,,,,,,,,,,,,,,, -12100,0.024314187,0.032894768,,,,,,,,,,,,,,,,, -12200,0.0371274,0.032939445,,,,,,,,,,,,,,,,, -12300,0.028861618,0.03250436,,,,,,,,,,,,,,,,, -12400,0.026530964,0.03349507,,,,,,,,,,,,,,,,, -12500,0.027335418,0.030952914,,,,,,,,,,,,,,,,, -12600,0.027611718,0.034314312,,,,,,,,,,,,,,,,, -12700,0.027028603,0.03541786,,,,,,,,,,,,,,,,, -12741,,,0.9904323220252992,0.0317389890551567,0.3702695672831442,0.986777663230896,0.044543270021677,0.2593790380554532,43793.0,0.985881507396698,0.0474580600857734,0.2509873600028932,43793.0,4094.871671199799,5976.001766443253,4094.871671199799,1880.0096390247345,0.7573575973510742,0.0 -12800,0.031986307,0.035744168,,,,,,,,,,,,,,,,, -12900,0.026953949,0.032539252,,,,,,,,,,,,,,,,, -13000,0.021677205,0.02914968,,,,,,,,,,,,,,,,, -13100,0.039742187,0.03181629,,,,,,,,,,,,,,,,, -13200,0.023874614,0.030268446,,,,,,,,,,,,,,,,, -13300,0.02938496,0.031827156,,,,,,,,,,,,,,,,, -13400,0.027824018,0.031061469,,,,,,,,,,,,,,,,, -13487,,,0.9905805587768556,0.031044103205204,0.3950207597065057,0.9867857694625854,0.0442429110407829,0.2639529675838791,43793.0,0.9858739376068116,0.0471687763929367,0.2534053815363039,43793.0,4334.968943119049,6318.829391956329,4334.968943119049,1982.690937757492,0.785717248916626,0.0 -13500,0.038513828,0.030964535,,,,,,,,,,,,,,,,, -13600,0.028539767,0.032876868,,,,,,,,,,,,,,,,, -13700,0.02850886,0.03128901,,,,,,,,,,,,,,,,, -13800,0.029960766,0.02891972,,,,,,,,,,,,,,,,, -13900,0.034604654,0.032944903,,,,,,,,,,,,,,,,, -14000,0.03063587,0.03224384,,,,,,,,,,,,,,,,, -14100,0.035243366,0.034860715,,,,,,,,,,,,,,,,, -14200,0.042233903,0.032350313,,,,,,,,,,,,,,,,, -14240,,,0.990578591823578,0.0308920368552207,0.3902965871620472,0.9868515133857728,0.0444140061736106,0.2635190960018144,43793.0,0.9859964847564696,0.0471503436565399,0.2576448284120592,43793.0,4575.041334629059,6662.850385427475,4575.041334629059,2086.5756731033325,0.8285675048828125,0.0 -14300,0.031819396,0.032770157,,,,,,,,,,,,,,,,, -14400,0.034309953,0.033358715,,,,,,,,,,,,,,,,, -14500,0.031321634,0.029257754,,,,,,,,,,,,,,,,, -14600,0.04094714,0.034407936,,,,,,,,,,,,,,,,, -14700,0.041623,0.03258118,,,,,,,,,,,,,,,,, -14800,0.037745148,0.028152294,,,,,,,,,,,,,,,,, -14900,0.042674005,0.030612925,,,,,,,,,,,,,,,,, -14992,,,0.9906534552574158,0.0304636172950267,0.4054393168607071,0.98679918050766,0.0444273576140403,0.2602008816337027,43793.0,0.9859017133712769,0.0474303290247917,0.2545299474361538,43793.0,4815.053724527359,7002.574951410294,4815.053724527359,2186.236001253128,0.8597097396850586,0.0 -15000,0.043873496,0.033818085,,,,,,,,,,,,,,,,, -15100,0.03355113,0.029926,,,,,,,,,,,,,,,,, -15200,0.037376698,0.033243414,,,,,,,,,,,,,,,,, -15300,0.04217709,0.030847054,,,,,,,,,,,,,,,,, -15400,0.038729116,0.03132544,,,,,,,,,,,,,,,,, -15500,0.032891765,0.03304213,,,,,,,,,,,,,,,,, -15600,0.035637446,0.033820543,,,,,,,,,,,,,,,,, -15700,0.035120357,0.032346383,,,,,,,,,,,,,,,,, -15745,,,0.9910573363304138,0.0293477196246385,0.4316723557821035,0.9867650866508484,0.0445890054106712,0.2701452169981196,43793.0,0.985859215259552,0.0472582839429378,0.2598296965103563,43793.0,5055.218120336533,7347.359417915344,5055.218120336533,2290.80570602417,0.8886802196502686,0.0 -15800,0.038007915,0.032071725,,,,,,,,,,,,,,,,, -15900,0.033862423,0.0322532,,,,,,,,,,,,,,,,, -16000,0.032345314,0.03125682,,,,,,,,,,,,,,,,, -16100,0.03557028,0.031258024,,,,,,,,,,,,,,,,, -16200,0.03797316,0.032352816,,,,,,,,,,,,,,,,, -16300,0.044266846,0.03347484,,,,,,,,,,,,,,,,, -16400,0.04069014,0.031923406,,,,,,,,,,,,,,,,, -16500,0.049298756,0.035101812,,,,,,,,,,,,,,,,, -16505,,,0.9910054802894592,0.0295385904610157,0.4293138323079307,0.9866920113563538,0.044278547167778,0.2718293568793558,43793.0,0.985924482345581,0.0468704588711261,0.2595190000831163,43793.0,5295.226839065552,7688.704651594162,5295.226839065552,2392.0925753116608,0.9176356792449952,0.0 -16600,0.031788677,0.033547893,,,,,,,,,,,,,,,,, -16700,0.051495083,0.03385684,,,,,,,,,,,,,,,,, -16800,0.035639714,0.031244973,,,,,,,,,,,,,,,,, -16900,0.056149688,0.030843528,,,,,,,,,,,,,,,,, -17000,0.045689795,0.03287616,,,,,,,,,,,,,,,,, -17100,0.036261234,0.029590376,,,,,,,,,,,,,,,,, -17200,0.035017785,0.032448772,,,,,,,,,,,,,,,,, -17266,,,0.991085171699524,0.0294013731181621,0.4287108022495856,0.9867752194404602,0.0442826114594936,0.2713982672671105,43793.0,0.9859097599983216,0.0468415208160877,0.2596608132025662,43793.0,5535.3997938632965,8027.145152568817,5535.3997938632965,2490.3105919361115,0.9467637538909912,0.0 -17300,0.03379623,0.029464407,,,,,,,,,,,,,,,,, -17400,0.034058332,0.02981072,,,,,,,,,,,,,,,,, -17500,0.04012508,0.032177977,,,,,,,,,,,,,,,,, -17600,0.042445082,0.028115207,,,,,,,,,,,,,,,,, -17700,0.0375254,0.029150076,,,,,,,,,,,,,,,,, -17800,0.034432117,0.030213766,,,,,,,,,,,,,,,,, -17900,0.048935737,0.03155386,,,,,,,,,,,,,,,,, -18000,0.03425087,0.029024124,,,,,,,,,,,,,,,,, -18021,,,0.9907681941986084,0.030180849134922,0.4159931679730896,0.9868023991584778,0.0440882481634616,0.2719252917849784,43793.0,0.9859362840652466,0.046722188591957,0.2607173361411814,43793.0,5775.622433185577,8370.766879558563,5775.622433185577,2593.659286260605,0.9766237735748292,0.0 -18100,0.051839758,0.035066094,,,,,,,,,,,,,,,,, -18200,0.040205076,0.030485895,,,,,,,,,,,,,,,,, -18300,0.04111637,0.031267468,,,,,,,,,,,,,,,,, -18400,0.038427908,0.031093149,,,,,,,,,,,,,,,,, -18500,0.04110864,0.030831233,,,,,,,,,,,,,,,,, -18600,0.051618867,0.0317946,,,,,,,,,,,,,,,,, -18700,0.053321,0.034115274,,,,,,,,,,,,,,,,, -18785,,,0.99088716506958,0.0299148950725793,0.4029454179103736,0.9867849946022034,0.0442803464829921,0.2694011256778948,43793.0,0.9860011339187622,0.0467405170202255,0.2595897761150846,43793.0,6015.79421544075,8711.703551054,6015.79421544075,2694.3736753463745,1.0064642429351809,0.0 -18800,0.045181632,0.028453926,,,,,,,,,,,,,,,,, -18900,0.038321435,0.031370282,,,,,,,,,,,,,,,,, -19000,0.03892838,0.033838354,,,,,,,,,,,,,,,,, -19100,0.03754049,0.026178278,,,,,,,,,,,,,,,,, -19200,0.04605924,0.032688364,,,,,,,,,,,,,,,,, -19300,0.042136755,0.029233467,,,,,,,,,,,,,,,,, -19400,0.047944628,0.029521324,,,,,,,,,,,,,,,,, -19500,0.047051404,0.031275246,,,,,,,,,,,,,,,,, -19526,,,0.9911262392997742,0.0291009750217199,0.4268990125422792,0.9868763089179992,0.0442834086716175,0.2757139397117077,43793.0,0.9861102104187012,0.0469335913658142,0.2687963123652824,43793.0,6255.988671064377,9053.836593389511,6255.988671064377,2796.258656024933,1.0368270874023438,0.0 -19600,0.06393356,0.032913048,,,,,,,,,,,,,,,,, -19700,0.051183473,0.03506659,,,,,,,,,,,,,,,,, -19800,0.04351963,0.031550307,,,,,,,,,,,,,,,,, -19900,0.037993737,0.030155107,,,,,,,,,,,,,,,,, -20000,0.0422609,0.03114385,,,,,,,,,,,,,,,,, -20100,0.04777898,0.032847773,,,,,,,,,,,,,,,,, -20200,0.039187692,0.03393702,,,,,,,,,,,,,,,,, -20280,,,0.9910481572151184,0.0292404498904943,0.4309862850495052,0.9868617057800292,0.0443498492240905,0.2763822320770312,43793.0,0.986026406288147,0.0471141971647739,0.2621612672994907,43793.0,6496.154449701309,9394.648380041122,6496.154449701309,2896.8540201187134,1.066905498504639,0.0 -20300,0.040119167,0.029355256,,,,,,,,,,,,,,,,, -20400,0.054263387,0.030882025,,,,,,,,,,,,,,,,, -20500,0.046637326,0.03347717,,,,,,,,,,,,,,,,, -20600,0.04167824,0.03257213,,,,,,,,,,,,,,,,, -20700,0.06028395,0.033416037,,,,,,,,,,,,,,,,, -20800,0.047546677,0.03171068,,,,,,,,,,,,,,,,, -20900,0.045617424,0.028515631,,,,,,,,,,,,,,,,, -21000,0.051254738,0.037282642,,,,,,,,,,,,,,,,, -21033,,,0.991021692752838,0.0290874261409044,0.4451907895185464,0.98687082529068,0.0442472547292709,0.2728488769441149,43793.0,0.9860247373580932,0.0469842851161956,0.2614326363612846,43793.0,6736.26745891571,9738.567348957062,6736.26745891571,3000.609070301056,1.0970373153686523,0.0 -21100,0.054335907,0.03280738,,,,,,,,,,,,,,,,, -21200,0.043746464,0.030655205,,,,,,,,,,,,,,,,, -21300,0.052522894,0.0311859,,,,,,,,,,,,,,,,, -21400,0.04186761,0.029775232,,,,,,,,,,,,,,,,, -21500,0.045299757,0.030282754,,,,,,,,,,,,,,,,, -21600,0.044642977,0.02948786,,,,,,,,,,,,,,,,, -21700,0.05127597,0.030269124,,,,,,,,,,,,,,,,, -21799,,,0.9912761449813844,0.0282179340720176,0.4582781325632959,0.9867464303970336,0.0442386865615844,0.2674546311890728,43793.0,0.9859337210655212,0.0468870475888252,0.264711273784717,43793.0,6976.427305936813,10082.193821668625,6976.427305936813,3104.023591041565,1.1280357837677002,0.0 -21800,0.045509294,0.029263416,,,,,,,,,,,,,,,,, -21900,0.05047755,0.028854527,,,,,,,,,,,,,,,,, -22000,0.043773092,0.029313426,,,,,,,,,,,,,,,,, -22100,0.06243185,0.031803187,,,,,,,,,,,,,,,,, -22200,0.05241746,0.03021062,,,,,,,,,,,,,,,,, -22300,0.04603039,0.027240245,,,,,,,,,,,,,,,,, -22400,0.042723488,0.029145606,,,,,,,,,,,,,,,,, -22500,0.043487962,0.032315854,,,,,,,,,,,,,,,,, -22545,,,0.9913682341575624,0.0281646307557821,0.4644050467140248,0.9868109226226808,0.0444370955228805,0.2704546050163061,43793.0,0.9859463572502136,0.0472444929182529,0.2622843075814916,43793.0,7216.409387588501,10424.05526447296,7216.409387588501,3205.8467667102814,1.1613752841949463,0.0 -22600,0.06252997,0.034468584,,,,,,,,,,,,,,,,, -22700,0.046409823,0.028541729,,,,,,,,,,,,,,,,, -22800,0.046865184,0.027486237,,,,,,,,,,,,,,,,, -22900,0.04992317,0.028480437,,,,,,,,,,,,,,,,, -23000,0.055386562,0.03341396,,,,,,,,,,,,,,,,, -23100,0.06966646,0.029379454,,,,,,,,,,,,,,,,, -23200,0.05066584,0.032295723,,,,,,,,,,,,,,,,, -23296,,,0.9915869235992432,0.0273259282112121,0.470983335381583,0.9869566559791564,0.0440373495221138,0.2790912138407179,43793.0,0.9861654043197632,0.0465232543647289,0.2689258721748647,43793.0,7456.590592622757,10766.633274793625,7456.590592622757,3308.191586256027,1.1935296058654783,0.0 -23300,0.051462945,0.034191456,,,,,,,,,,,,,,,,, -23400,0.046581313,0.02947656,,,,,,,,,,,,,,,,, -23500,0.053248897,0.02927549,,,,,,,,,,,,,,,,, -23600,0.060492743,0.02840757,,,,,,,,,,,,,,,,, -23700,0.06169141,0.030382484,,,,,,,,,,,,,,,,, -23800,0.046656564,0.027667545,,,,,,,,,,,,,,,,, -23900,0.0575336,0.030597482,,,,,,,,,,,,,,,,, -24000,0.05308128,0.03394575,,,,,,,,,,,,,,,,, -24052,,,0.991644561290741,0.0270146057009696,0.4884058360080612,0.9869595170021056,0.0441711880266666,0.2820494999826886,43793.0,0.9862087965011596,0.0469167828559875,0.2769371208521102,43793.0,7696.545921325684,11106.29555439949,7696.545921325684,3407.847371816635,1.2242672443389893,0.0 -24100,0.053170905,0.032412335,,,,,,,,,,,,,,,,, -24200,0.05587867,0.03127319,,,,,,,,,,,,,,,,, -24300,0.052696556,0.03224343,,,,,,,,,,,,,,,,, -24400,0.050502934,0.029771753,,,,,,,,,,,,,,,,, -24500,0.06693552,0.032710712,,,,,,,,,,,,,,,,, -24600,0.05083015,0.030761253,,,,,,,,,,,,,,,,, -24700,0.05084523,0.030080596,,,,,,,,,,,,,,,,, -24798,,,0.991566836833954,0.0274626016616821,0.4743264384794057,0.9868316650390624,0.0442666038870811,0.2701833417002129,43793.0,0.9860756993293762,0.0467779636383056,0.2672637130743255,43793.0,7936.503590583801,11452.717252254486,7936.503590583801,3514.256602048874,1.25538969039917,0.0 -24800,0.047372013,0.028792648,,,,,,,,,,,,,,,,, -24900,0.052006498,0.029266262,,,,,,,,,,,,,,,,, -25000,0.07940573,0.02856622,,,,,,,,,,,,,,,,, -25100,0.051767636,0.030078456,,,,,,,,,,,,,,,,, -25200,0.05267878,0.028869629,,,,,,,,,,,,,,,,, -25300,0.055686682,0.031598393,,,,,,,,,,,,,,,,, -25400,0.05823597,0.03222127,,,,,,,,,,,,,,,,, -25500,0.04693961,0.026938364,,,,,,,,,,,,,,,,, -25556,,,0.9913595914840698,0.028126923367381,0.4622975459881457,0.9869400262832642,0.0441190637648105,0.2790331791350799,43793.0,0.9861093759536744,0.0466210879385471,0.2705003644585464,43793.0,8176.548576593399,11793.34719634056,8176.548576593399,3614.790126800537,1.2863295078277588,0.0 -25600,0.060278535,0.029621478,,,,,,,,,,,,,,,,, -25700,0.04607278,0.028817946,,,,,,,,,,,,,,,,, -25800,0.054326627,0.0295806,,,,,,,,,,,,,,,,, -25900,0.059996493,0.029078947,,,,,,,,,,,,,,,,, -26000,0.061429735,0.028673766,,,,,,,,,,,,,,,,, -26100,0.058368333,0.02988681,,,,,,,,,,,,,,,,, -26200,0.05889978,0.030051528,,,,,,,,,,,,,,,,, -26300,0.05683282,0.032302447,,,,,,,,,,,,,,,,, -26314,,,0.9912744164466858,0.0281446538865566,0.4479110439168735,0.9868324398994446,0.0439815558493137,0.2810698441376711,43793.0,0.9859468340873718,0.0466155149042606,0.2698534271230323,43793.0,8416.634189844131,12131.83881664276,8416.634189844131,3713.144189834594,1.3176674842834473,0.0 -26400,0.051817924,0.026385553,,,,,,,,,,,,,,,,, -26500,0.052686878,0.029865278,,,,,,,,,,,,,,,,, -26600,0.06206098,0.035052046,,,,,,,,,,,,,,,,, -26700,0.043498762,0.02959601,,,,,,,,,,,,,,,,, -26800,0.048809525,0.025311312,,,,,,,,,,,,,,,,, -26900,0.06440774,0.032378685,,,,,,,,,,,,,,,,, -27000,0.08026186,0.02929744,,,,,,,,,,,,,,,,, -27065,,,0.9913526177406312,0.0279363002628088,0.4632723640628128,0.986867368221283,0.0443237386643886,0.2782958896685469,43793.0,0.9860487580299376,0.0471839122474193,0.2670339478514961,43793.0,8656.794304132462,12475.151976585388,8656.794304132462,3816.244790554047,1.3499679565429688,0.0 -27100,0.056534313,0.02693093,,,,,,,,,,,,,,,,, -27200,0.055738777,0.034295555,,,,,,,,,,,,,,,,, -27300,0.0558803,0.030335763,,,,,,,,,,,,,,,,, -27400,0.05568225,0.027317265,,,,,,,,,,,,,,,,, -27500,0.059207514,0.026241362,,,,,,,,,,,,,,,,, -27600,0.05420743,0.02907873,,,,,,,,,,,,,,,,, -27700,0.045187283,0.029052127,,,,,,,,,,,,,,,,, -27800,0.05203309,0.028147306,,,,,,,,,,,,,,,,, -27820,,,0.9914987683296204,0.0274489261209964,0.4683983500487964,0.9869043231010436,0.0442365556955337,0.2760711966392069,43793.0,0.9861077070236206,0.0469717867672443,0.264931771800272,43793.0,8896.875655651093,12816.583455085754,8896.875655651093,3917.5429599285126,1.381486415863037,0.0 -27900,0.054994397,0.028294677,,,,,,,,,,,,,,,,, -28000,0.075187646,0.029909741,,,,,,,,,,,,,,,,, -28100,0.055309203,0.030441096,,,,,,,,,,,,,,,,, -28200,0.04788161,0.026104046,,,,,,,,,,,,,,,,, -28300,0.055218168,0.0280187,,,,,,,,,,,,,,,,, -28400,0.05222668,0.028442344,,,,,,,,,,,,,,,,, -28500,0.054392904,0.02889376,,,,,,,,,,,,,,,,, -28577,,,0.9915719032287598,0.0271757207810878,0.4733769339022719,0.9869562983512878,0.044232428073883,0.2856657416394459,43793.0,0.9861435294151306,0.0470503196120262,0.2718404436297,43793.0,9137.05087184906,13163.168231010435,9137.05087184906,4023.892383813858,1.4207780361175537,0.0 -28600,0.049709298,0.027566794,,,,,,,,,,,,,,,,, -28700,0.057698928,0.029030336,,,,,,,,,,,,,,,,, -28800,0.051055644,0.027155865,,,,,,,,,,,,,,,,, -28900,0.05755033,0.026547411,,,,,,,,,,,,,,,,, -29000,0.052957214,0.030183094,,,,,,,,,,,,,,,,, -29100,0.06087676,0.02721628,,,,,,,,,,,,,,,,, -29200,0.052670836,0.027765697,,,,,,,,,,,,,,,,, -29300,0.046530787,0.026850509,,,,,,,,,,,,,,,,, -29337,,,0.991813600063324,0.0265144873410463,0.4933331247934246,0.987015962600708,0.0440828762948513,0.2803936676206833,43793.0,0.9861460328102112,0.0468155331909656,0.2694635318031406,43793.0,9377.265213012695,13506.539906024933,9377.265213012695,4126.997245788574,1.452960968017578,0.0 -29400,0.052980274,0.02896015,,,,,,,,,,,,,,,,, -29500,0.054846745,0.027608769,,,,,,,,,,,,,,,,, -29600,0.07270521,0.02870557,,,,,,,,,,,,,,,,, -29700,0.07429186,0.027773576,,,,,,,,,,,,,,,,, -29800,0.06986174,0.033011973,,,,,,,,,,,,,,,,, -29900,0.0489099,0.029424138,,,,,,,,,,,,,,,,, -30000,0.07576813,0.03125146,,,,,,,,,,,,,,,,, -30096,,,0.9918859601020812,0.0261133890599012,0.5055414899153353,0.9869863390922546,0.0437818579375743,0.2895730354147849,43793.0,0.9861797094345092,0.0468039661645889,0.2741908935948261,43793.0,9617.43070077896,13852.530488491058,9617.43070077896,4232.770231723785,1.485065221786499,0.0 -30100,0.05560451,0.029737337,,,,,,,,,,,,,,,,, -30200,0.05661956,0.0272982,,,,,,,,,,,,,,,,, -30300,0.051871795,0.027248226,,,,,,,,,,,,,,,,, -30400,0.060952384,0.026409058,,,,,,,,,,,,,,,,, -30500,0.0665474,0.02978332,,,,,,,,,,,,,,,,, -30600,0.067811675,0.029340655,,,,,,,,,,,,,,,,, -30700,0.054536883,0.02723301,,,,,,,,,,,,,,,,, -30800,0.07877838,0.029410021,,,,,,,,,,,,,,,,, -30853,,,0.9918944835662842,0.0259779468178749,0.5109262523099191,0.986935555934906,0.0444079004228115,0.2870528055093321,43793.0,0.9860929846763612,0.0471954904496669,0.2743780890275217,43793.0,9857.432433843613,14194.312682628632,9857.432433843613,4334.496691703796,1.5186164379119873,0.0 -30900,0.055344325,0.029502966,,,,,,,,,,,,,,,,, -31000,0.07593305,0.029115219,,,,,,,,,,,,,,,,, -31100,0.06261238,0.027281197,,,,,,,,,,,,,,,,, -31200,0.07125377,0.030079704,,,,,,,,,,,,,,,,, -31300,0.070370615,0.028632209,,,,,,,,,,,,,,,,, -31400,0.06096845,0.02891601,,,,,,,,,,,,,,,,, -31500,0.06497662,0.025383657,,,,,,,,,,,,,,,,, -31600,0.057557803,0.031391926,,,,,,,,,,,,,,,,, -31607,,,0.9922711849212646,0.0249787494540214,0.526230375162185,0.9870455861091614,0.0440018251538276,0.2825439937763267,43793.0,0.9863023161888124,0.0465546734631061,0.2826208197222706,43793.0,10097.557705879211,14540.564175128937,10097.557705879211,4440.5671174526215,1.551112174987793,0.0 -31700,0.06134048,0.024918294,,,,,,,,,,,,,,,,, -31800,0.06893582,0.03155397,,,,,,,,,,,,,,,,, -31900,0.05615293,0.029872322,,,,,,,,,,,,,,,,, -32000,0.059244186,0.026772937,,,,,,,,,,,,,,,,, -32100,0.08238887,0.028000025,,,,,,,,,,,,,,,,, -32200,0.055254545,0.029600395,,,,,,,,,,,,,,,,, -32300,0.061287347,0.030918812,,,,,,,,,,,,,,,,, -32365,,,0.9920186400413512,0.0256517603993415,0.5237306076253523,0.9870054125785828,0.0448300689458847,0.2819343967433677,43793.0,0.986120343208313,0.0476447828114032,0.2724481736220283,43793.0,10337.713238954544,14881.338361740112,10337.713238954544,4541.132849693298,1.584326982498169,0.0 -32400,0.07607377,0.033711385,,,,,,,,,,,,,,,,, -32500,0.06883486,0.02753343,,,,,,,,,,,,,,,,, -32600,0.06597186,0.030592274,,,,,,,,,,,,,,,,, -32700,0.06649501,0.029192071,,,,,,,,,,,,,,,,, -32800,0.056967642,0.029122079,,,,,,,,,,,,,,,,, -32900,0.065362364,0.03073229,,,,,,,,,,,,,,,,, -33000,0.069279455,0.026690677,,,,,,,,,,,,,,,,, -33100,0.06396434,0.02842745,,,,,,,,,,,,,,,,, -33121,,,0.9920905828475952,0.0256081577390432,0.5157125620277023,0.9869648218154908,0.0443299151957035,0.2834494862434448,43793.0,0.9861447811126708,0.0467436574399471,0.2786746065075245,43793.0,10577.716572523115,15223.10135102272,10577.716572523115,4642.83878827095,1.6173911094665527,0.0 -33200,0.07825178,0.03078606,,,,,,,,,,,,,,,,, -33300,0.073834516,0.030480864,,,,,,,,,,,,,,,,, -33400,0.059459932,0.027182734,,,,,,,,,,,,,,,,, -33500,0.054015066,0.024431348,,,,,,,,,,,,,,,,, -33600,0.061351188,0.028870517,,,,,,,,,,,,,,,,, -33700,0.069210514,0.02765286,,,,,,,,,,,,,,,,, -33800,0.067961715,0.028421802,,,,,,,,,,,,,,,,, -33874,,,0.9917319416999816,0.0265057682991027,0.4951007638370917,0.9868974089622498,0.0445404797792434,0.284593376840681,43793.0,0.9861708879470824,0.0472316332161426,0.2744900853244321,43793.0,10817.91072845459,15572.50892305374,10817.91072845459,4751.999447107315,1.6501054763793943,0.0 -33900,0.061072297,0.027418228,,,,,,,,,,,,,,,,, -34000,0.071033545,0.032521218,,,,,,,,,,,,,,,,, -34100,0.06679286,0.025767282,,,,,,,,,,,,,,,,, -34200,0.068356186,0.029876625,,,,,,,,,,,,,,,,, -34300,0.059252158,0.026247393,,,,,,,,,,,,,,,,, -34400,0.0527971,0.023865724,,,,,,,,,,,,,,,,, -34500,0.07489057,0.02912043,,,,,,,,,,,,,,,,, -34600,0.081520885,0.027232004,,,,,,,,,,,,,,,,, -34632,,,0.9918310642242432,0.0261690374463796,0.4915393588454686,0.987043559551239,0.0445980280637741,0.2865059627513893,43793.0,0.9862428903579712,0.0472697019577026,0.2789517823758971,43793.0,11058.15099453926,15910.252793550491,11058.15099453926,4849.44939661026,1.683269739151001,0.0 -34700,0.06590462,0.029244412,,,,,,,,,,,,,,,,, -34800,0.06331464,0.027169777,,,,,,,,,,,,,,,,, -34900,0.064076394,0.026341878,,,,,,,,,,,,,,,,, -35000,0.06314522,0.02719968,,,,,,,,,,,,,,,,, -35100,0.07042484,0.027421225,,,,,,,,,,,,,,,,, -35200,0.060879353,0.025297064,,,,,,,,,,,,,,,,, -35300,0.06663118,0.02851169,,,,,,,,,,,,,,,,, -35378,,,0.9919923543930054,0.0257877148687839,0.5099326315816797,0.9868158102035522,0.0444606207311153,0.2835898925780593,43793.0,0.9859455227851868,0.0472081154584884,0.2693484981572296,43793.0,11298.339334487917,16249.758395671844,11298.339334487917,4948.712647199631,1.7171683311462402,0.0 -35400,0.06418819,0.028403029,,,,,,,,,,,,,,,,, -35500,0.06904648,0.031672586,,,,,,,,,,,,,,,,, -35600,0.0819008,0.026826894,,,,,,,,,,,,,,,,, -35700,0.06615642,0.028780337,,,,,,,,,,,,,,,,, -35800,0.06342214,0.025229396,,,,,,,,,,,,,,,,, -35900,0.08349824,0.028897664,,,,,,,,,,,,,,,,, -36000,0.06889889,0.026024122,,,,,,,,,,,,,,,,, -36100,0.07908233,0.027876908,,,,,,,,,,,,,,,,, -36126,,,0.992000699043274,0.0254850517958402,0.5258720739264008,0.987106442451477,0.0446226224303245,0.2882472900207319,43793.0,0.986178457736969,0.0475702285766601,0.2711337537526851,43793.0,11538.558787345886,16590.798290252686,11538.558787345886,5049.473174333572,1.7545788288116455,0.0 -36200,0.07119404,0.027783306,,,,,,,,,,,,,,,,, -36300,0.08811392,0.029225366,,,,,,,,,,,,,,,,, -36400,0.059283916,0.0263951,,,,,,,,,,,,,,,,, -36500,0.084790714,0.031413972,,,,,,,,,,,,,,,,, -36600,0.06177201,0.026169984,,,,,,,,,,,,,,,,, -36700,0.07538556,0.025979154,,,,,,,,,,,,,,,,, -36800,0.07491824,0.027215054,,,,,,,,,,,,,,,,, -36889,,,0.9921327829360962,0.0251639336347579,0.5333610794699605,0.986931085586548,0.044403463602066,0.2876226417928679,43793.0,0.986159086227417,0.0470145195722579,0.2754729484835849,43793.0,11778.575371980667,16933.982120752335,11778.575371980667,5152.5869171619415,1.787407159805298,0.0 -36900,0.065613866,0.027880946,,,,,,,,,,,,,,,,, -37000,0.077277675,0.02861731,,,,,,,,,,,,,,,,, -37100,0.063661605,0.023619885,,,,,,,,,,,,,,,,, -37200,0.07915021,0.02830316,,,,,,,,,,,,,,,,, -37300,0.09316482,0.031931132,,,,,,,,,,,,,,,,, -37400,0.06798577,0.027806884,,,,,,,,,,,,,,,,, -37500,0.06882188,0.02618197,,,,,,,,,,,,,,,,, -37600,0.08798067,0.028791215,,,,,,,,,,,,,,,,, -37648,,,0.9923455119132996,0.0244872458279132,0.5460171638954645,0.9870200157165528,0.0446493402123451,0.2832415157203511,43793.0,0.9862471222877502,0.0473320521414279,0.2776725133597497,43793.0,12018.111292123796,17278.17639899254,12018.111292123796,5256.769897699356,2.2428689002990723,0.0 -37700,0.06694174,0.02587853,,,,,,,,,,,,,,,,, -37800,0.07288295,0.027811661,,,,,,,,,,,,,,,,, -37900,0.06445939,0.026056934,,,,,,,,,,,,,,,,, -38000,0.083298124,0.029164875,,,,,,,,,,,,,,,,, -38100,0.071040176,0.028186576,,,,,,,,,,,,,,,,, -38200,0.07620975,0.027379777,,,,,,,,,,,,,,,,, -38300,0.07261551,0.026027333,,,,,,,,,,,,,,,,, -38400,,,0.9923316836357116,0.0241330750286579,0.5530907366085315,0.9871101379394532,0.0448829308152198,0.2893611700860986,43793.0,0.9862862825393676,0.0478410758078098,0.2788159578705024,43793.0,12258.103625297546,17617.65689277649,12258.103625297546,5356.201946020126,2.277024269104004,0.0 -38400,0.0843068,0.026146568,,,,,,,,,,,,,,,,, -38500,0.06477447,0.023018256,,,,,,,,,,,,,,,,, -38600,0.06323562,0.028764186,,,,,,,,,,,,,,,,, -38700,0.06393758,0.023169192,,,,,,,,,,,,,,,,, -38800,0.08197141,0.02596784,,,,,,,,,,,,,,,,, -38900,0.06686661,0.026180612,,,,,,,,,,,,,,,,, -39000,0.062020738,0.024402836,,,,,,,,,,,,,,,,, -39100,0.071491845,0.02539368,,,,,,,,,,,,,,,,, -39153,,,0.9926602840423584,0.0234204418957233,0.5637032087035159,0.987121880054474,0.0448669232428073,0.2899202695664321,43793.0,0.9862328171730042,0.047747578471899,0.2744675293708428,43793.0,12498.116117477415,17958.667691469193,12498.116117477415,5457.145552873611,2.3104984760284424,0.0 -39200,0.07505207,0.027611082,,,,,,,,,,,,,,,,, -39300,0.07120761,0.029557528,,,,,,,,,,,,,,,,, -39400,0.07677443,0.026916754,,,,,,,,,,,,,,,,, -39500,0.07270806,0.026193622,,,,,,,,,,,,,,,,, -39600,0.06536177,0.026600989,,,,,,,,,,,,,,,,, -39700,0.08449059,0.025462164,,,,,,,,,,,,,,,,, -39800,0.06765312,0.027200146,,,,,,,,,,,,,,,,, -39900,0.08149138,0.026612328,,,,,,,,,,,,,,,,, -39908,,,0.9927948713302612,0.0230580810457468,0.5783965721104753,0.9870395064353944,0.0444376207888126,0.286592025599736,43793.0,0.9862656593322754,0.0473434291779994,0.2774957648688163,43793.0,12738.242744922638,18298.64096140861,12738.242744922638,5556.928096294403,2.3550827503204346,0.0 -40000,0.074420005,0.027072256,,,,,,,,,,,,,,,,, -40100,0.06914817,0.025764372,,,,,,,,,,,,,,,,, -40200,0.10172809,0.029186532,,,,,,,,,,,,,,,,, -40300,0.0728595,0.025729332,,,,,,,,,,,,,,,,, -40400,0.057231653,0.024139453,,,,,,,,,,,,,,,,, -40500,0.07030474,0.027966918,,,,,,,,,,,,,,,,, -40600,0.08156126,0.029392958,,,,,,,,,,,,,,,,, -40666,,,0.992606282234192,0.0236912108957767,0.5603386086431281,0.986970067024231,0.0448714643716812,0.2884174564599778,43793.0,0.9861329793930054,0.0476587414741516,0.2742626805934189,43793.0,12978.360262870789,18640.750893592834,12978.360262870789,5658.854161739349,2.401482343673706,0.0 -40700,0.06033966,0.024815109,,,,,,,,,,,,,,,,, -40800,0.06855115,0.026981952,,,,,,,,,,,,,,,,, -40900,0.08774033,0.027696159,,,,,,,,,,,,,,,,, -41000,0.06762152,0.025427176,,,,,,,,,,,,,,,,, -41100,0.07926761,0.024029253,,,,,,,,,,,,,,,,, -41200,0.08130864,0.02347615,,,,,,,,,,,,,,,,, -41300,0.07917619,0.027785389,,,,,,,,,,,,,,,,, -41400,0.0663514,0.027979989,,,,,,,,,,,,,,,,, -41428,,,0.9925146102905272,0.023902615532279,0.5628586654561019,0.9869022965431212,0.0450378134846687,0.2864156006446135,43793.0,0.9861510992050172,0.0476103760302066,0.2792500978473953,43793.0,13218.463328838348,18989.86360406876,13218.463328838348,5767.809935808182,2.435475349426269,0.0 -41500,0.09883573,0.027445046,,,,,,,,,,,,,,,,, -41600,0.06514094,0.025087032,,,,,,,,,,,,,,,,, -41700,0.066963285,0.024896681,,,,,,,,,,,,,,,,, -41800,0.0752981,0.023574324,,,,,,,,,,,,,,,,, -41900,0.07627228,0.027873974,,,,,,,,,,,,,,,,, -42000,0.087394044,0.02733303,,,,,,,,,,,,,,,,, -42100,0.07945077,0.026397567,,,,,,,,,,,,,,,,, -42176,,,0.9924070835113524,0.024378603324294,0.5376271174259066,0.986899435520172,0.0449063070118427,0.2886609224145899,43793.0,0.986019253730774,0.0477226413786411,0.2770095178376321,43793.0,13458.669250249864,19328.203336000443,13458.669250249864,5865.882692337036,2.4739890098571777,0.0 -42200,0.0709763,0.0261735,,,,,,,,,,,,,,,,, -42300,0.077002324,0.027324742,,,,,,,,,,,,,,,,, -42400,0.067138314,0.025265642,,,,,,,,,,,,,,,,, -42500,0.0691928,0.022837795,,,,,,,,,,,,,,,,, -42600,0.08468722,0.027511027,,,,,,,,,,,,,,,,, -42700,0.07175239,0.0262144,,,,,,,,,,,,,,,,, -42800,0.091313876,0.026944058,,,,,,,,,,,,,,,,, -42900,0.096658505,0.029076736,,,,,,,,,,,,,,,,, -42932,,,0.9924079179763794,0.024059934541583,0.5503674235235126,0.9870553016662598,0.0449348129332065,0.2921576254088969,43793.0,0.9862534403800964,0.0477136820554733,0.2787048486685619,43793.0,13698.63403391838,19671.65065932274,13698.63403391838,5969.3086223602295,2.510117053985596,0.0 -43000,0.0778659,0.02459097,,,,,,,,,,,,,,,,, -43100,0.08319354,0.028148625,,,,,,,,,,,,,,,,, -43200,0.07824332,0.026387896,,,,,,,,,,,,,,,,, -43300,0.0763942,0.024941724,,,,,,,,,,,,,,,,, -43400,0.084480345,0.026022501,,,,,,,,,,,,,,,,, -43500,0.06950775,0.025546595,,,,,,,,,,,,,,,,, -43600,0.08579944,0.026259732,,,,,,,,,,,,,,,,, -43675,,,0.9927031993865968,0.0233249384909868,0.5674737634256353,0.9871320724487304,0.0450026728212833,0.2952275795871997,43793.0,0.9862707257270812,0.048059307038784,0.275700904360155,43793.0,13938.807181596756,20014.515516281128,13938.807181596756,6071.934442520142,2.5530261993408203,0.0 -43700,0.08646498,0.02573403,,,,,,,,,,,,,,,,, -43800,0.08781888,0.02833531,,,,,,,,,,,,,,,,, -43900,0.08937235,0.02670967,,,,,,,,,,,,,,,,, -44000,0.0846478,0.026982626,,,,,,,,,,,,,,,,, -44100,0.08388243,0.025429765,,,,,,,,,,,,,,,,, -44200,0.0662382,0.02426963,,,,,,,,,,,,,,,,, -44300,0.0799373,0.023754478,,,,,,,,,,,,,,,,, -44400,0.092362806,0.0293897,,,,,,,,,,,,,,,,, -44433,,,0.9927995800971984,0.0229585859924554,0.5733450457802596,0.98695707321167,0.0449227802455425,0.2942502749410342,43793.0,0.9861220121383668,0.0477507822215557,0.2761335127109555,43793.0,14179.057677268982,20356.0144674778,14179.057677268982,6173.127116203308,2.588895082473755,0.0 -44500,0.09408794,0.026611513,,,,,,,,,,,,,,,,, -44600,0.06763538,0.023645902,,,,,,,,,,,,,,,,, -44700,0.08013109,0.027778072,,,,,,,,,,,,,,,,, -44800,0.079776004,0.027592925,,,,,,,,,,,,,,,,, -44900,0.0778934,0.023531757,,,,,,,,,,,,,,,,, -45000,0.094224654,0.028064122,,,,,,,,,,,,,,,,, -45100,0.07286275,0.028079908,,,,,,,,,,,,,,,,, -45198,,,0.992863118648529,0.0225877240300178,0.5935786349699248,0.9871166348457336,0.0451795570552349,0.2876991151416316,43793.0,0.9862256646156312,0.0481463596224784,0.2737167666026969,43793.0,14419.24270439148,20699.316056489944,14419.24270439148,6276.180570602417,2.632185935974121,0.0 -45200,0.07675138,0.024225058,,,,,,,,,,,,,,,,, -45300,0.10158419,0.027168734,,,,,,,,,,,,,,,,, -45400,0.077956125,0.02439711,,,,,,,,,,,,,,,,, -45500,0.08022661,0.025467403,,,,,,,,,,,,,,,,, -45600,0.07879134,0.024037212,,,,,,,,,,,,,,,,, -45700,0.0996986,0.026963945,,,,,,,,,,,,,,,,, -45800,0.076109484,0.02499917,,,,,,,,,,,,,,,,, -45900,0.084338166,0.025757976,,,,,,,,,,,,,,,,, -45945,,,0.9931178092956544,0.021580209955573,0.6017364886534111,0.9870991706848145,0.0457594506442546,0.2840771218195006,43793.0,0.9862942695617676,0.0487545803189277,0.2747491794169562,43793.0,14659.305247306824,21039.20338702202,14659.305247306824,6375.948750257492,2.666360378265381,0.0 -46000,0.0882246,0.02622733,,,,,,,,,,,,,,,,, -46100,0.08072629,0.026096655,,,,,,,,,,,,,,,,, -46200,0.07367178,0.02061778,,,,,,,,,,,,,,,,, -46300,0.09141264,0.027032707,,,,,,,,,,,,,,,,, -46400,0.08522222,0.023790447,,,,,,,,,,,,,,,,, -46500,0.084904425,0.024431411,,,,,,,,,,,,,,,,, -46600,0.085057914,0.024517596,,,,,,,,,,,,,,,,, -46700,0.07879739,0.024211945,,,,,,,,,,,,,,,,, -46709,,,0.9932405352592468,0.0215317849069833,0.607432035924848,0.9870342016220092,0.0452977679669857,0.2873679751801644,43793.0,0.9862281680107116,0.048151209950447,0.2759195730682534,43793.0,14899.513046503069,21378.66554641724,14899.513046503069,6475.143660068512,2.705780267715454,0.0 -46800,0.08366704,0.024218412,,,,,,,,,,,,,,,,, -46900,0.08410231,0.02471675,,,,,,,,,,,,,,,,, -47000,0.082662806,0.024516976,,,,,,,,,,,,,,,,, -47100,0.09389187,0.026842047,,,,,,,,,,,,,,,,, -47200,0.088380836,0.024339627,,,,,,,,,,,,,,,,, -47300,0.096451946,0.026285384,,,,,,,,,,,,,,,,, -47400,0.082449324,0.024346171,,,,,,,,,,,,,,,,, -47468,,,0.9933563470840454,0.0209414884448051,0.6339597542179121,0.9871182441711426,0.0456337742507457,0.2866939484275668,43793.0,0.986265242099762,0.0487691946327686,0.2753911116802991,43793.0,15139.652928829191,21720.766462802887,15139.652928829191,6577.043631315231,2.747018337249756,0.0 -47500,0.09803543,0.026043613,,,,,,,,,,,,,,,,, -47600,0.09848143,0.023121644,,,,,,,,,,,,,,,,, -47700,0.09236847,0.024095545,,,,,,,,,,,,,,,,, -47800,0.09262111,0.025248392,,,,,,,,,,,,,,,,, -47900,0.09710582,0.028835122,,,,,,,,,,,,,,,,, -48000,0.078594446,0.022640256,,,,,,,,,,,,,,,,, -48100,0.08675035,0.02455608,,,,,,,,,,,,,,,,, -48200,0.07644296,0.025356442,,,,,,,,,,,,,,,,, -48225,,,0.9933629035949708,0.021108966320753,0.6215161265846207,0.9870553016662598,0.045908585190773,0.2829216172937069,43793.0,0.9862454533576964,0.0489246025681495,0.2751573313937183,43793.0,15379.684081554413,22058.68881917,15379.684081554413,6674.879307746887,2.782480478286743,0.0 -48300,0.08746413,0.025179585,,,,,,,,,,,,,,,,, -48400,0.11204299,0.023308776,,,,,,,,,,,,,,,,, -48500,0.07509385,0.023635471,,,,,,,,,,,,,,,,, -48600,0.08893545,0.026673535,,,,,,,,,,,,,,,,, -48700,0.09589293,0.025470925,,,,,,,,,,,,,,,,, -48800,0.10294751,0.023316277,,,,,,,,,,,,,,,,, -48900,0.12082295,0.029761096,,,,,,,,,,,,,,,,, -48987,,,0.9932666420936584,0.0212262161076068,0.611152309972192,0.9870601892471312,0.0456889979541301,0.290932891249044,43793.0,0.9862496256828308,0.0487202107906341,0.2773312555186638,43793.0,15619.90098786354,22399.94140410424,15619.90098786354,6775.858198404312,2.8195629119873047,0.0 -49000,0.09741661,0.024073444,,,,,,,,,,,,,,,,, -49100,0.10358415,0.025394518,,,,,,,,,,,,,,,,, -49200,0.098717265,0.025702445,,,,,,,,,,,,,,,,, -49300,0.08292712,0.02508735,,,,,,,,,,,,,,,,, -49400,0.090624064,0.024557738,,,,,,,,,,,,,,,,, -49500,0.07828407,0.02390759,,,,,,,,,,,,,,,,, -49600,0.089436226,0.02468201,,,,,,,,,,,,,,,,, -49700,0.08878493,0.022656431,,,,,,,,,,,,,,,,, -49745,,,0.9932006001472472,0.021596472710371,0.6055980733194052,0.9869509935379028,0.0460079833865165,0.2892799181159078,43793.0,0.9861186742782592,0.0488905385136604,0.278546482883604,43793.0,15860.069697856905,22739.921632766724,15860.069697856905,6875.612053394318,2.857351303100586,0.0 -49800,0.1067358,0.025511887,,,,,,,,,,,,,,,,, -49900,0.10105734,0.025219657,,,,,,,,,,,,,,,,, -50000,0.086563684,0.024599167,,,,,,,,,,,,,,,,, -50100,0.10168818,0.024301015,,,,,,,,,,,,,,,,, -50200,0.10058904,0.024407346,,,,,,,,,,,,,,,,, -50300,0.10184451,0.023133038,,,,,,,,,,,,,,,,, -50400,0.0930021,0.02232906,,,,,,,,,,,,,,,,, -50487,,,0.9930036664009094,0.0220139380544424,0.593615504286974,0.9869891405105592,0.046247225254774,0.2876874796115182,43793.0,0.9861944913864136,0.0490744151175022,0.2775945821958022,43793.0,16100.29632639885,23079.203650951385,16100.29632639885,6974.607265710831,2.895205497741699,0.0 -50500,0.115881324,0.026954953,,,,,,,,,,,,,,,,, -50600,0.08975903,0.024470594,,,,,,,,,,,,,,,,, -50700,0.08748326,0.026607426,,,,,,,,,,,,,,,,, -50800,0.103413515,0.024344902,,,,,,,,,,,,,,,,, -50900,0.10032648,0.023565333,,,,,,,,,,,,,,,,, -51000,0.087040834,0.022276646,,,,,,,,,,,,,,,,, -51100,0.085700944,0.022612663,,,,,,,,,,,,,,,,, -51200,0.09432035,0.021100758,,,,,,,,,,,,,,,,, -51253,,,0.9931820631027222,0.0213425699621438,0.6134943919388604,0.9870630502700806,0.0462197475135326,0.2916114859958877,43793.0,0.986276626586914,0.0492494367063045,0.2809767090009311,43793.0,16340.390928268433,23423.58703923225,16340.390928268433,7078.8386816978455,2.932920217514038,0.0 -51300,0.10498254,0.023079123,,,,,,,,,,,,,,,,, -51400,0.08838641,0.023136681,,,,,,,,,,,,,,,,, -51500,0.09034569,0.02609218,,,,,,,,,,,,,,,,, -51600,0.09163225,0.022365807,,,,,,,,,,,,,,,,, -51700,0.0854076,0.02181275,,,,,,,,,,,,,,,,, -51800,0.09614662,0.023884548,,,,,,,,,,,,,,,,, -51900,0.1002846,0.02404534,,,,,,,,,,,,,,,,, -52000,0.10998601,0.024739804,,,,,,,,,,,,,,,,, -52016,,,0.993344247341156,0.0208707619458436,0.6242266281073867,0.9870747923851012,0.0464723855257034,0.2907870160289012,43793.0,0.986294686794281,0.0494212321937084,0.2847148471225769,43793.0,16580.394956111908,23762.726235628128,16580.394956111908,7177.91805267334,2.969144821166992,0.0 -52100,0.11357634,0.026081946,,,,,,,,,,,,,,,,, -52200,0.0839162,0.021109696,,,,,,,,,,,,,,,,, -52300,0.09444642,0.02141187,,,,,,,,,,,,,,,,, -52400,0.09328995,0.021212647,,,,,,,,,,,,,,,,, -52500,0.104486786,0.023033498,,,,,,,,,,,,,,,,, -52600,0.102913335,0.020923251,,,,,,,,,,,,,,,,, -52700,0.110111006,0.026050473,,,,,,,,,,,,,,,,, -52779,,,0.993465006351471,0.0205740891396999,0.6301589141153873,0.987104833126068,0.046230211853981,0.2900999925999507,43793.0,0.9862024784088136,0.0493457950651645,0.2773575915352462,43793.0,16820.457956552505,24107.30019426346,16820.457956552505,7282.37335062027,3.005022048950196,0.0 -52800,0.09975637,0.02093048,,,,,,,,,,,,,,,,, -52900,0.10604876,0.026424201,,,,,,,,,,,,,,,,, -53000,0.08450646,0.02155314,,,,,,,,,,,,,,,,, -53100,0.10446744,0.020054774,,,,,,,,,,,,,,,,, -53200,0.1059084,0.022198431,,,,,,,,,,,,,,,,, -53300,0.11367432,0.02487739,,,,,,,,,,,,,,,,, -53400,0.11112951,0.02378203,,,,,,,,,,,,,,,,, -53500,0.1069569,0.025321042,,,,,,,,,,,,,,,,, -53539,,,0.9936485290527344,0.0197937209159135,0.6516988153813621,0.9871584177017212,0.0465914942324161,0.2926621584624282,43793.0,0.9862854480743408,0.0497939959168434,0.2795073140182851,43793.0,17060.670438051224,24445.238028764725,17060.670438051224,7380.0420706272125,3.041551351547241,0.0 -53600,0.111963,0.025175406,,,,,,,,,,,,,,,,, -53700,0.0959528,0.019621618,,,,,,,,,,,,,,,,, -53800,0.106434464,0.023640411,,,,,,,,,,,,,,,,, -53900,0.12119891,0.024898125,,,,,,,,,,,,,,,,, -54000,0.103723735,0.024056783,,,,,,,,,,,,,,,,, -54100,0.103084125,0.025608692,,,,,,,,,,,,,,,,, -54200,0.10005304,0.023972258,,,,,,,,,,,,,,,,, -54287,,,0.993804931640625,0.0193524714559316,0.6605385215433868,0.9869335293769836,0.0468230247497558,0.2932988620035959,43793.0,0.9861018061637878,0.0499334521591663,0.2787303739599055,43793.0,17300.793816804886,24788.86217021942,17300.793816804886,7483.484494686127,3.0773909091949463,0.0 -54300,0.089442074,0.020360127,,,,,,,,,,,,,,,,, -54400,0.11430348,0.022736153,,,,,,,,,,,,,,,,, -54500,0.096443854,0.022164013,,,,,,,,,,,,,,,,, -54600,0.107562296,0.024989175,,,,,,,,,,,,,,,,, -54700,0.10332594,0.024553398,,,,,,,,,,,,,,,,, -54800,0.09850698,0.02306423,,,,,,,,,,,,,,,,, -54900,0.11437499,0.023705596,,,,,,,,,,,,,,,,, -55000,0.104244135,0.022366337,,,,,,,,,,,,,,,,, -55034,,,0.9940738677978516,0.018616709858179,0.6824790188145755,0.9870139360427856,0.0467959530651569,0.291244495393672,43793.0,0.9861464500427246,0.0501525029540061,0.2725020764620905,43793.0,17540.84943985939,25132.44391345977,17540.84943985939,7586.947088718414,3.1195337772369385,0.0 -55100,0.0961994,0.021715643,,,,,,,,,,,,,,,,, -55200,0.103676565,0.02226936,,,,,,,,,,,,,,,,, -55300,0.124726325,0.024476357,,,,,,,,,,,,,,,,, -55400,0.09216361,0.020323463,,,,,,,,,,,,,,,,, -55500,0.10634796,0.025724767,,,,,,,,,,,,,,,,, -55600,0.09830875,0.022933219,,,,,,,,,,,,,,,,, -55700,0.110456675,0.02180558,,,,,,,,,,,,,,,,, -55785,,,0.994262993335724,0.0181721411645412,0.6914053178116203,0.9870870113372804,0.0470523163676261,0.2891383001431945,43793.0,0.9862500429153442,0.0502575188875198,0.2784436858642199,43793.0,17780.89176774025,25472.753759860992,17780.89176774025,7687.15647649765,3.157621622085572,0.0 -55800,0.104871176,0.024929436,,,,,,,,,,,,,,,,, -55900,0.09968613,0.019887492,,,,,,,,,,,,,,,,, -56000,0.11198479,0.02400962,,,,,,,,,,,,,,,,, -56100,0.1025875,0.021562781,,,,,,,,,,,,,,,,, -56200,0.116913125,0.023270613,,,,,,,,,,,,,,,,, -56300,0.10633685,0.021055086,,,,,,,,,,,,,,,,, -56400,0.10632865,0.021695439,,,,,,,,,,,,,,,,, -56500,0.098956555,0.020049691,,,,,,,,,,,,,,,,, -56537,,,0.994239866733551,0.0183276608586311,0.685328157042368,0.9870309829711914,0.0470251254737377,0.2896792143578579,43793.0,0.986182689666748,0.0502708926796913,0.2769247715519446,43793.0,18020.839967250824,25815.408913373947,18020.839967250824,7789.804910898209,3.195686817169189,0.0 -56600,0.12222444,0.02129477,,,,,,,,,,,,,,,,, -56700,0.11281088,0.021453347,,,,,,,,,,,,,,,,, -56800,0.10954922,0.020088067,,,,,,,,,,,,,,,,, -56900,0.1262467,0.022257024,,,,,,,,,,,,,,,,, -57000,0.100579865,0.020578314,,,,,,,,,,,,,,,,, -57100,0.10964593,0.022438634,,,,,,,,,,,,,,,,, -57200,0.12722427,0.021555694,,,,,,,,,,,,,,,,, -57296,,,0.9940085411071776,0.0186333321034908,0.6795497812086909,0.9870337843894958,0.0475324653089046,0.2885810856821342,43793.0,0.9863001704216005,0.0507479310035705,0.2777486826024884,43793.0,18260.91561436653,26153.78284072876,18260.91561436653,7888.045190811157,3.2336459159851074,0.0 -57300,0.10563655,0.021338264,,,,,,,,,,,,,,,,, -57400,0.11799816,0.020356946,,,,,,,,,,,,,,,,, -57500,0.11783178,0.023237253,,,,,,,,,,,,,,,,, -57600,0.120142795,0.021376397,,,,,,,,,,,,,,,,, -57700,0.12612247,0.023848051,,,,,,,,,,,,,,,,, -57800,0.10702263,0.020564778,,,,,,,,,,,,,,,,, -57900,0.12633717,0.022879353,,,,,,,,,,,,,,,,, -57983,,,,,,,,,,,,,,18477.23125720024,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 7f2f2b8cc..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -859.2842590808868,0.0,36.66873526573181,1,0,36.66873526573181,0.0007088489946909,0.0,11.191027641296388,3003,895.9530372619629,0.0005978592089377,0.0,11.188164710998535,0.0004835649742744,0.0,11.190281867980955,3000 -1538.283482313156,0.0305800437927246,876.663633108139,2396,0,876.663633108139,0.3806286752223968,7.894892890419983,4.3253560066223145,3003,2415.0562007427216,0.4108909666538238,14.135664486209082,4.017301082611084,0.3961265087127685,9.774756744618458,4.122348785400391,3000 -1997.2269372940063,0.0612609386444091,1716.6430037021637,4790,0,1716.6430037021637,0.5423159599304199,18.71880292122399,2.7538504600524902,3003,3714.0869665145874,0.542180597782135,24.602004108236525,2.732574224472046,0.5455852746963501,20.36131793008979,2.710276365280152,3000 -2430.385052204132,0.0902647972106933,2556.738527774811,7185,0,2556.738527774811,0.5881703495979309,21.78799174427264,2.31225848197937,3003,4987.44571518898,0.5838678479194641,27.234345079265907,2.3458170890808105,0.5861551761627197,23.17941764435876,2.3145346641540527,3000 -2872.007763624192,0.1170134544372558,3396.931195497513,9582,0,3396.931195497513,0.6104235649108887,23.303168552867703,2.106617450714112,3003,6269.367269515991,0.5891788601875305,27.89985486351619,2.2676243782043457,0.6069980263710022,24.639766045857712,2.1352217197418213,3000 -3450.5933167934418,0.1438968181610107,4237.160478591919,11980,0,4237.160478591919,0.6257509589195251,24.60483532445832,1.974539041519165,3003,7688.285516738892,0.6027758717536926,28.891173233400416,2.144322872161865,0.6202526688575745,25.276258137532324,2.017131328582764,3000 -3891.816184282303,0.1714715957641601,5077.2845776081085,14380,0,5077.2845776081085,0.637022852897644,25.595357216731195,1.8806644678115845,3003,8969.737039804459,0.6137219071388245,29.763904979834205,2.0440056324005127,0.6301347613334656,26.04915640493588,1.92502510547638,3000 -4331.687678337097,0.1997363567352295,5917.195971488953,16776,0,5917.195971488953,0.6476788520812988,25.957416144448224,1.806699275970459,3003,10249.626316785812,0.6152881979942322,29.83333990536737,2.0378637313842773,0.6370410919189453,26.930205823477618,1.862482070922852,3000 -4788.225700378418,0.2278189659118652,6757.092479228973,19173,0,6757.092479228973,0.6542908549308777,26.770263755815385,1.7641087770462036,3003,11546.166801214218,0.6377730965614319,31.384704872693035,1.8624444007873533,0.6448649168014526,27.341544220468872,1.817959427833557,3000 -5232.208755493164,0.2553091049194336,7597.14856171608,21572,0,7597.14856171608,0.6562082767486572,26.91141067893076,1.7340083122253418,3003,12830.313049316406,0.6281996965408325,30.43993556295968,1.9296544790267944,0.647431492805481,27.470066942691965,1.7907466888427734,3000 -5685.127123594284,0.2881033420562744,8437.255508899689,23972,0,8437.255508899689,0.6590785384178162,26.735188246892022,1.7163658142089844,3003,14123.446489810944,0.6257244944572449,30.647498348593707,1.9461606740951536,0.6491302251815796,27.639039633589352,1.7739145755767822,3000 -6189.89656496048,0.316986083984375,9277.202247619629,26373,0,9277.202247619629,0.6620068550109863,27.445354021922327,1.7058236598968506,3003,15468.266919612885,0.6318857073783875,31.12798198722012,1.8993377685546875,0.6520439982414246,27.978040845419148,1.7544596195220947,3000 -6676.355022907257,0.3454797267913818,10117.357129573822,28775,0,10117.357129573822,0.2897216975688934,0.0742939463471604,4.496841907501221,3003,16794.98263859749,0.3327521383762359,0.5231126249610653,3.951723337173462,0.2997359037399292,0.1150917293941524,4.32647705078125,3000 -7171.487823009491,0.3740403652191162,10957.487800359726,31177,0,10957.487800359726,0.662657618522644,27.054145258705955,1.6983014345169067,3003,18130.35057020188,0.6317818760871887,30.54583083674687,1.9045122861862185,0.6527383327484131,27.74894091426883,1.7534050941467283,3000 -7665.472952365875,0.4036171436309814,11797.682716608047,33579,0,11797.682716608047,0.6678984761238098,27.42061605427625,1.6765241622924805,3003,19464.63546180725,0.6354771852493286,30.563934941582254,1.879430055618286,0.6564084887504578,28.160680413584807,1.7356117963790894,3000 -8156.066877841949,0.4347925186157226,12637.7005712986,35981,0,12637.7005712986,0.6671082377433777,27.54092155263876,1.6758179664611816,3003,20795.35339331627,0.6375195384025574,31.009726888605,1.8644064664840696,0.6567308306694031,28.174935173674044,1.7263554334640503,3000 -8625.987797498703,0.4648962020874023,13477.855075120926,38382,0,13477.855075120926,0.6678752303123474,27.693778275595207,1.6616512537002563,3003,22105.53423690796,0.6450705528259277,32.15894108712716,1.804628610610962,0.6593098640441895,28.36643041451481,1.7143281698226929,3000 -9132.55176949501,0.4963889122009277,14317.929590463638,40783,0,14317.929590463638,0.6715008020401001,27.98273853137321,1.6438584327697754,3003,23452.281606912613,0.6396812200546265,31.35065259778344,1.8572800159454343,0.6599918007850647,28.27201281310921,1.7028095722198486,3000 -9649.80432677269,0.5288918018341064,15158.089678287506,43185,0,15158.089678287506,0.6713729500770569,27.85697410575625,1.6449847221374512,3003,24809.804277181625,0.6362661123275757,31.457801930764408,1.8771040439605715,0.6599298119544983,28.35290610957772,1.7009010314941406,3000 -10285.620803833008,0.5597467422485352,15998.09845662117,45589,0,15998.09845662117,0.5968043804168701,19.25898748849468,2.1455793380737305,3003,26285.73772263527,0.590805172920227,24.934241026593256,2.153980493545532,0.5927143096923828,20.67777602591015,2.137026071548462,3000 -10755.301027297974,0.5908377170562744,16838.03125667572,47991,0,16838.03125667572,0.6726977229118347,27.737813681491367,1.6335320472717283,3003,27595.45713019371,0.6427844762802124,31.89963562097901,1.8317044973373413,0.6606737375259399,28.24034997826204,1.6950204372406006,3000 -11244.286452054976,0.6243085861206055,17677.970749616623,50392,0,17677.970749616623,0.6735227704048157,27.932242256488045,1.623759388923645,3003,28924.49044728279,0.6582806706428528,32.246968132861625,1.7210654020309448,0.6620624661445618,28.434917155790025,1.6892980337142944,3000 -11790.28213953972,0.6552319526672363,18517.93385744095,52793,0,18517.93385744095,0.6762535572052002,28.25260805810742,1.611672282218933,3003,30310.55475568772,0.6478515863418579,31.51827536648588,1.799898624420166,0.6649638414382935,28.84308670692472,1.6738790273666382,3000 -12287.342983961104,0.6947894096374512,19358.101722240448,55195,0,19358.101722240448,0.6748591065406799,28.13716320799326,1.615303874015808,3003,31647.902009248734,0.6423264145851135,31.520352554833806,1.831187605857849,0.6630296111106873,28.5659265759097,1.678435444831848,3000 -12749.21016407013,0.7266678810119629,20198.159299850464,57598,0,20198.159299850464,0.6766951680183411,28.10985823697703,1.596364974975586,3003,32949.93321561813,0.6510561108589172,31.74470780612312,1.7641416788101196,0.6652490496635437,28.623343659957573,1.6621553897857666,3000 -13262.024190664291,0.7604336738586426,21038.290306568146,59999,0,21038.290306568146,0.6790773272514343,28.48931735603496,1.5836652517318726,3003,34302.990511894226,0.644768476486206,31.910825778247982,1.8109740018844604,0.6663773655891418,28.56648746349024,1.6560256481170654,3000 -13727.225650072098,0.7934637069702148,21878.360268354416,62401,0,21878.360268354416,0.6824356913566589,28.67643381041361,1.577749729156494,3003,35608.370332956314,0.6491748094558716,32.01965766582803,1.7910131216049194,0.669477105140686,29.233560910622813,1.6406055688858032,3000 -14303.382984161375,0.8290464878082275,22718.57821083069,64804,0,22718.57821083069,0.6805996298789978,28.65766199623945,1.5786640644073486,3003,37024.85886883736,0.6520285606384277,31.971719271704508,1.760977268218994,0.6713989973068237,29.0899984850159,1.6386438608169556,3000 -14788.661399126053,0.8624668121337891,23558.55596637726,67206,0,23558.55596637726,0.6822613477706909,28.72514626657771,1.5671225786209106,3003,38350.22704553604,0.6519597768783569,31.91697570224685,1.7757725715637207,0.670419454574585,29.23616891361002,1.638580083847046,3000 -15286.969474315643,0.8974158763885498,24398.73073887825,69609,0,24398.73073887825,0.6841090321540833,29.071967186401448,1.5585025548934937,3003,39688.82049059868,0.6634093523025513,32.762389538704355,1.6923391819000244,0.6716221570968628,29.347775662744024,1.6273186206817627,3000 -15781.841740846634,0.9326050281524658,25238.861780643463,72012,0,25238.861780643463,0.6856196522712708,28.909133725017643,1.5455368757247925,3003,41023.935676813126,0.6599642634391785,32.44681045615368,1.7183440923690796,0.6714857816696167,29.12857090191241,1.6171070337295532,3000 -16306.87161040306,0.9680724143981934,26079.09641242028,74415,0,26079.09641242028,0.6896287202835083,29.456325345785103,1.534721612930298,3003,42389.31352448464,0.6568201780319214,32.56595782076231,1.7489362955093384,0.6759618520736694,29.63801741455089,1.6061208248138428,3000 -16784.96683216095,1.0114808082580566,26919.07650828361,76817,0,26919.07650828361,0.6898146867752075,29.37583258832015,1.5289326906204224,3003,43707.50957036018,0.661069929599762,32.60740521012439,1.7087233066558838,0.6762098073959351,29.3602712239802,1.603691577911377,3000 -17255.74615764618,1.0472896099090576,27759.00503540039,79219,0,27759.00503540039,0.6900470852851868,29.416101277503813,1.530118107795715,3003,45018.331510305405,0.6592926383018494,32.536791630899906,1.724799871444702,0.6757262945175171,29.62497735225051,1.6033855676651,3000 -17825.267672777176,1.0836036205291748,28599.207001686096,81622,0,28599.207001686096,0.6895938515663147,29.071133285384303,1.5182905197143557,3003,46428.168867349625,0.6859181523323059,34.2604911187917,1.5658107995986938,0.6783672571182251,29.177389345403792,1.58946430683136,3000 -18324.43095898628,1.1216096878051758,29439.43800854683,84025,0,29439.43800854683,0.6921852231025696,29.47494867004409,1.5084973573684692,3003,47767.677359580994,0.664563000202179,33.1192903855946,1.696059226989746,0.6789996027946472,29.565914486020333,1.582720160484314,3000 -18846.27249503136,1.1595826148986816,30279.43471693993,86427,0,30279.43471693993,0.6930103302001953,29.577308672706145,1.5063138008117676,3003,49129.6320040226,0.6629806756973267,32.61050182170082,1.7091134786605835,0.6798179745674133,29.637436093481103,1.5794235467910769,3000 -19347.71030664444,1.1959412097930908,31119.40724849701,88829,0,31119.40724849701,0.6959851384162903,29.98966318985136,1.4950261116027832,3003,50471.155947208405,0.6717724204063416,33.690016325414604,1.6399476528167725,0.6823846101760864,30.108637798690488,1.570753574371338,3000 -19810.24594926834,1.2348592281341553,31959.536118268967,91231,0,31959.536118268967,0.6975887417793274,30.00201244760764,1.4849611520767212,3003,51773.93786597252,0.6652762293815613,33.516918895580865,1.684821844100952,0.682136595249176,29.97718632659504,1.5662659406661987,3000 -20335.630873441696,1.280862808227539,32799.64951753616,93633,0,32799.64951753616,0.6991226673126221,30.245471742788,1.4706521034240725,3003,53139.55933403969,0.6700052618980408,33.97726970739959,1.6654683351516724,0.6840088963508606,30.201155231440644,1.5550308227539062,3000 -20851.50264811516,1.319817066192627,33639.54993915558,96035,0,33639.54993915558,0.6990413069725037,30.28542712697332,1.4698675870895386,3003,54495.446476221085,0.6742818355560303,33.98130465498545,1.6293132305145264,0.6850255727767944,30.02152235059844,1.5457754135131836,3000 -21343.85585975647,1.3581314086914062,34479.54863166809,98436,0,34479.54863166809,0.7021788358688354,30.44932657753321,1.4656221866607666,3003,55827.9131128788,0.6713905930519104,33.99387977317795,1.650544047355652,0.685149610042572,30.28188354970151,1.544081687927246,3000 -21868.93807411194,1.398036003112793,35319.508835315704,100838,0,35319.508835315704,0.7017489075660706,30.560988154441777,1.4601340293884275,3003,57193.07353281975,0.6884165406227112,35.2754136989264,1.5480619668960571,0.6868482828140259,30.528489437300816,1.538313865661621,3000 -22355.186529397964,1.44234037399292,36159.6839966774,103240,0,36159.6839966774,0.7038870453834534,30.51891889328535,1.4475866556167605,3003,58519.62125611305,0.6779972314834595,34.50147922318023,1.604690432548523,0.6882741451263428,30.41491065908115,1.5299122333526611,3000 -22838.996133089066,1.4821085929870603,36999.760825634,105642,0,36999.760825634,0.7051188349723816,30.512686812294906,1.4421643018722534,3003,59843.622671842575,0.6837121248245239,34.452289799932295,1.5825554132461548,0.6886957287788391,30.520724395170017,1.523957371711731,3000 -23336.82561659813,1.5219931602478027,37839.9554643631,108044,0,37839.9554643631,0.7067108154296875,30.80866748847693,1.433881402015686,3003,61181.76550936699,0.6904653310775757,35.465164598367224,1.5444891452789309,0.6911135315895081,30.564646606616183,1.5180892944335938,3000 -23839.350203037266,1.570399045944214,38680.11610341072,110446,0,38680.11610341072,0.7068270444869995,30.79432213794177,1.429835557937622,3003,62524.57753229141,0.6886968612670898,34.8873148481868,1.552962303161621,0.6909399628639221,30.795585694647222,1.516938328742981,3000 -24334.75069284439,1.6107723712921145,39520.2310898304,112846,0,39520.2310898304,0.7065481543540955,30.935883540712304,1.431631565093994,3003,63860.211265563965,0.7077298760414124,36.72570097184537,1.4527026414871216,0.6913243532180786,30.76716783706688,1.5172661542892456,3000 -24817.92786431313,1.6504244804382324,40360.24406290054,115248,0,40360.24406290054,0.70927894115448,30.84711171958969,1.4228955507278442,3003,65183.5163064003,0.69832843542099,35.53555504677267,1.5016064643859863,0.693717360496521,30.621397086007867,1.5081539154052734,3000 -25323.58639740944,1.6914031505584717,41200.15984940529,117649,0,41200.15984940529,0.7097670435905457,30.92984988614016,1.4213274717330933,3003,66529.20644688606,0.6946881413459778,35.48660590814062,1.516922116279602,0.6925890445709229,31.07297394971343,1.506855607032776,3000 -25808.70779204369,1.7335550785064695,42040.353865385056,120050,0,42040.353865385056,0.7109523415565491,31.28327925826771,1.4153249263763428,3003,67854.64235019684,0.7077162861824036,35.699938968051285,1.4448779821395874,0.693221390247345,30.88221735153581,1.5050368309020996,3000 -26298.369471549988,1.7748847007751465,42880.47454190254,122450,0,42880.47454190254,0.7108477354049683,31.20686190559411,1.4121458530426023,3003,69184.54448390007,0.7041738033294678,36.04297490759636,1.4591567516326904,0.6942257285118103,31.002648152503227,1.5022872686386108,3000 -26783.07488465309,1.8175811767578125,43720.612193107605,124852,0,43720.612193107605,0.7114055156707764,31.294057267940783,1.4117215871810913,3003,70509.50488901138,0.7027409076690674,36.27220736720769,1.4704632759094238,0.6941513419151306,31.01403960420787,1.5010319948196411,3000 -27269.56863617897,1.865821361541748,44560.81832766533,127254,0,44560.81832766533,0.7116495370864868,31.268713611605445,1.409090876579285,3003,71836.32944989204,0.7076531052589417,36.59637497627496,1.4463974237442017,0.6947712898254395,31.142788279699538,1.4987682104110718,3000 -27762.5166027546,1.908890962600708,45400.99208855629,129656,0,45400.99208855629,0.7116960287094116,31.130467907561723,1.4084151983261108,3003,73169.57006430626,0.7090178728103638,36.79773448540247,1.4380403757095337,0.6948084831237793,30.993740811222963,1.4995105266571045,3000 -28244.09170150757,1.952728033065796,46240.98077011109,132058,0,46240.98077011109,0.7120911478996277,31.249781981597465,1.4084433317184448,3003,74491.25281620026,0.7075369954109192,36.62875430446796,1.453471302986145,0.6945481300354004,31.076465832705026,1.4992551803588867,3000 -28729.57975912094,1.9961745738983154,46686.89723396301,133333,0,46686.89723396301,0.7120795249938965,31.198964185453555,1.408614993095398,3003,75422.7423479557,0.7088993191719055,36.586798307686536,1.4459348917007446,0.6944240927696228,31.059765499889604,1.499434232711792,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 44abffe83..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.4194074,11.15856,,,,,,,,,,,,,,,,, -1,,,0.0005978592089377,11.188164710998535,0.0,0.0004835649742744,11.190281867980955,0.0,3000.0,0.0007088489946909,11.191027641296388,0.0,3003.0,36.66873526573181,895.9530372619629,36.66873526573181,859.2842590808868,0.0,0.0 -100,0.42528227,8.962848,,,,,,,,,,,,,,,,, -200,0.1663825,8.640404,,,,,,,,,,,,,,,,, -300,0.20045772,8.420865,,,,,,,,,,,,,,,,, -400,0.25424585,8.015001,,,,,,,,,,,,,,,,, -500,0.37435165,7.6830664,,,,,,,,,,,,,,,,, -600,0.4818032,7.4660316,,,,,,,,,,,,,,,,, -700,0.5534128,7.244294,,,,,,,,,,,,,,,,, -800,0.5590087,7.0125585,,,,,,,,,,,,,,,,, -900,0.48544762,6.849398,,,,,,,,,,,,,,,,, -1000,0.5167057,6.5558643,,,,,,,,,,,,,,,,, -1100,0.63844675,6.40058,,,,,,,,,,,,,,,,, -1200,0.74035233,6.199862,,,,,,,,,,,,,,,,, -1300,0.7884056,6.103892,,,,,,,,,,,,,,,,, -1400,0.59029144,5.9702883,,,,,,,,,,,,,,,,, -1500,0.64305013,5.8134623,,,,,,,,,,,,,,,,, -1600,0.6377042,5.681713,,,,,,,,,,,,,,,,, -1700,0.751525,5.615146,,,,,,,,,,,,,,,,, -1800,0.70100164,5.437358,,,,,,,,,,,,,,,,, -1900,0.80482936,5.3843746,,,,,,,,,,,,,,,,, -2000,1.3536358,5.324806,,,,,,,,,,,,,,,,, -2100,1.108948,5.095525,,,,,,,,,,,,,,,,, -2200,0.9458641,5.067303,,,,,,,,,,,,,,,,, -2300,0.6677108,4.962763,,,,,,,,,,,,,,,,, -2396,,,0.4108909666538238,4.017301082611084,14.135664486209082,0.3961265087127685,4.122348785400391,9.774756744618458,3000.0,0.3806286752223968,4.3253560066223145,7.894892890419983,3003.0,876.663633108139,2415.0562007427216,876.663633108139,1538.283482313156,0.0305800437927246,0.0 -2400,0.91265917,4.8684697,,,,,,,,,,,,,,,,, -2500,0.8074229,4.807572,,,,,,,,,,,,,,,,, -2600,0.7062898,4.692281,,,,,,,,,,,,,,,,, -2700,1.1238914,4.6830773,,,,,,,,,,,,,,,,, -2800,0.8473366,4.485372,,,,,,,,,,,,,,,,, -2900,0.9035134,4.5503306,,,,,,,,,,,,,,,,, -3000,0.7354458,4.463273,,,,,,,,,,,,,,,,, -3100,0.994355,4.3445187,,,,,,,,,,,,,,,,, -3200,0.81429094,4.2241187,,,,,,,,,,,,,,,,, -3300,0.7298565,4.1783237,,,,,,,,,,,,,,,,, -3400,0.5972297,4.194804,,,,,,,,,,,,,,,,, -3500,0.7686181,4.0663304,,,,,,,,,,,,,,,,, -3600,0.7743767,4.082444,,,,,,,,,,,,,,,,, -3700,0.6672373,3.9689033,,,,,,,,,,,,,,,,, -3800,0.72484684,4.0259905,,,,,,,,,,,,,,,,, -3900,0.5413583,4.0734,,,,,,,,,,,,,,,,, -4000,0.8572595,3.9899547,,,,,,,,,,,,,,,,, -4100,0.8928013,3.942252,,,,,,,,,,,,,,,,, -4200,0.5906884,3.9767284,,,,,,,,,,,,,,,,, -4300,0.5915397,3.9129205,,,,,,,,,,,,,,,,, -4400,0.6671742,3.926551,,,,,,,,,,,,,,,,, -4500,0.5354587,3.8102558,,,,,,,,,,,,,,,,, -4600,0.75565207,3.9019446,,,,,,,,,,,,,,,,, -4700,0.5723802,3.7799127,,,,,,,,,,,,,,,,, -4790,,,0.542180597782135,2.732574224472046,24.602004108236525,0.5455852746963501,2.710276365280152,20.36131793008979,3000.0,0.5423159599304199,2.7538504600524902,18.71880292122399,3003.0,1716.6430037021637,3714.0869665145874,1716.6430037021637,1997.2269372940063,0.0612609386444091,0.0 -4800,0.53480554,3.7215567,,,,,,,,,,,,,,,,, -4900,0.5471739,3.766754,,,,,,,,,,,,,,,,, -5000,0.5498984,3.7036247,,,,,,,,,,,,,,,,, -5100,0.65240335,3.725021,,,,,,,,,,,,,,,,, -5200,0.5699418,3.697103,,,,,,,,,,,,,,,,, -5300,0.54722834,3.6671038,,,,,,,,,,,,,,,,, -5400,0.57698727,3.729301,,,,,,,,,,,,,,,,, -5500,0.6858416,3.6577835,,,,,,,,,,,,,,,,, -5600,0.56958026,3.7140908,,,,,,,,,,,,,,,,, -5700,0.5260539,3.6029003,,,,,,,,,,,,,,,,, -5800,0.5058509,3.558176,,,,,,,,,,,,,,,,, -5900,0.49426094,3.6651728,,,,,,,,,,,,,,,,, -6000,0.46574938,3.5054169,,,,,,,,,,,,,,,,, -6100,0.47613457,3.5620577,,,,,,,,,,,,,,,,, -6200,0.48184446,3.6429398,,,,,,,,,,,,,,,,, -6300,0.5226581,3.5601523,,,,,,,,,,,,,,,,, -6400,0.45214602,3.5794024,,,,,,,,,,,,,,,,, -6500,0.5410308,3.5853992,,,,,,,,,,,,,,,,, -6600,0.49857983,3.5734,,,,,,,,,,,,,,,,, -6700,0.44248325,3.5846355,,,,,,,,,,,,,,,,, -6800,0.42461026,3.4926066,,,,,,,,,,,,,,,,, -6900,0.38765317,3.5297809,,,,,,,,,,,,,,,,, -7000,0.44581667,3.491657,,,,,,,,,,,,,,,,, -7100,0.4662662,3.4871078,,,,,,,,,,,,,,,,, -7185,,,0.5838678479194641,2.3458170890808105,27.234345079265907,0.5861551761627197,2.3145346641540527,23.17941764435876,3000.0,0.5881703495979309,2.31225848197937,21.78799174427264,3003.0,2556.738527774811,4987.44571518898,2556.738527774811,2430.385052204132,0.0902647972106933,0.0 -7200,0.37892708,3.4750495,,,,,,,,,,,,,,,,, -7300,0.44536558,3.443595,,,,,,,,,,,,,,,,, -7400,0.42507556,3.393884,,,,,,,,,,,,,,,,, -7500,0.44164243,3.4545288,,,,,,,,,,,,,,,,, -7600,0.39904538,3.4314003,,,,,,,,,,,,,,,,, -7700,0.4311406,3.4200056,,,,,,,,,,,,,,,,, -7800,0.44015446,3.4194708,,,,,,,,,,,,,,,,, -7900,0.40709043,3.4015532,,,,,,,,,,,,,,,,, -8000,0.38347575,3.363804,,,,,,,,,,,,,,,,, -8100,0.42317095,3.4079716,,,,,,,,,,,,,,,,, -8200,0.39093116,3.310473,,,,,,,,,,,,,,,,, -8300,0.36498675,3.457147,,,,,,,,,,,,,,,,, -8400,0.38027596,3.350937,,,,,,,,,,,,,,,,, -8500,0.41256148,3.3661163,,,,,,,,,,,,,,,,, -8600,0.36227122,3.3794422,,,,,,,,,,,,,,,,, -8700,0.33223423,3.389224,,,,,,,,,,,,,,,,, -8800,0.38269198,3.3245716,,,,,,,,,,,,,,,,, -8900,0.47024593,3.3904402,,,,,,,,,,,,,,,,, -9000,0.34822673,3.396903,,,,,,,,,,,,,,,,, -9100,0.36236075,3.3302562,,,,,,,,,,,,,,,,, -9200,0.32539213,3.4247665,,,,,,,,,,,,,,,,, -9300,0.3095316,3.2939103,,,,,,,,,,,,,,,,, -9400,0.30037495,3.3291497,,,,,,,,,,,,,,,,, -9500,0.34119767,3.3212583,,,,,,,,,,,,,,,,, -9582,,,0.5891788601875305,2.2676243782043457,27.89985486351619,0.6069980263710022,2.1352217197418213,24.639766045857712,3000.0,0.6104235649108887,2.106617450714112,23.303168552867703,3003.0,3396.931195497513,6269.367269515991,3396.931195497513,2872.007763624192,0.1170134544372558,0.0 -9600,0.3091865,3.3186982,,,,,,,,,,,,,,,,, -9700,0.30058935,3.254854,,,,,,,,,,,,,,,,, -9800,0.3273663,3.3013902,,,,,,,,,,,,,,,,, -9900,0.290918,3.2615993,,,,,,,,,,,,,,,,, -10000,0.29553938,3.249652,,,,,,,,,,,,,,,,, -10100,0.2773566,3.245557,,,,,,,,,,,,,,,,, -10200,0.31084087,3.2742853,,,,,,,,,,,,,,,,, -10300,0.29429132,3.2880263,,,,,,,,,,,,,,,,, -10400,0.29746807,3.2148225,,,,,,,,,,,,,,,,, -10500,0.29367056,3.3390713,,,,,,,,,,,,,,,,, -10600,0.2634024,3.2966726,,,,,,,,,,,,,,,,, -10700,0.3554486,3.3833153,,,,,,,,,,,,,,,,, -10800,0.31658137,3.30618,,,,,,,,,,,,,,,,, -10900,0.33166918,3.236732,,,,,,,,,,,,,,,,, -11000,0.25541994,3.3151455,,,,,,,,,,,,,,,,, -11100,0.2641542,3.2197952,,,,,,,,,,,,,,,,, -11200,0.2915626,3.2983782,,,,,,,,,,,,,,,,, -11300,0.25499612,3.2117715,,,,,,,,,,,,,,,,, -11400,0.2708567,3.206081,,,,,,,,,,,,,,,,, -11500,0.3133298,3.2422802,,,,,,,,,,,,,,,,, -11600,0.27364892,3.2354054,,,,,,,,,,,,,,,,, -11700,0.2702984,3.3683856,,,,,,,,,,,,,,,,, -11800,0.270074,3.1991484,,,,,,,,,,,,,,,,, -11900,0.29270828,3.2553442,,,,,,,,,,,,,,,,, -11980,,,0.6027758717536926,2.144322872161865,28.891173233400416,0.6202526688575745,2.017131328582764,25.276258137532324,3000.0,0.6257509589195251,1.974539041519165,24.60483532445832,3003.0,4237.160478591919,7688.285516738892,4237.160478591919,3450.5933167934418,0.1438968181610107,0.0 -12000,0.24769425,3.2747502,,,,,,,,,,,,,,,,, -12100,0.26362398,3.1402285,,,,,,,,,,,,,,,,, -12200,0.267762,3.1976335,,,,,,,,,,,,,,,,, -12300,0.25190267,3.238458,,,,,,,,,,,,,,,,, -12400,0.23565365,3.18849,,,,,,,,,,,,,,,,, -12500,0.26276875,3.2187998,,,,,,,,,,,,,,,,, -12600,0.25096655,3.2067606,,,,,,,,,,,,,,,,, -12700,0.25492552,3.2115479,,,,,,,,,,,,,,,,, -12800,0.25047433,3.148068,,,,,,,,,,,,,,,,, -12900,0.25555736,3.2165477,,,,,,,,,,,,,,,,, -13000,0.34942403,3.261744,,,,,,,,,,,,,,,,, -13100,0.23655266,3.2216015,,,,,,,,,,,,,,,,, -13200,0.25277933,3.2442422,,,,,,,,,,,,,,,,, -13300,0.23422498,3.1504936,,,,,,,,,,,,,,,,, -13400,0.22672053,3.212481,,,,,,,,,,,,,,,,, -13500,0.26511565,3.1633322,,,,,,,,,,,,,,,,, -13600,0.28695723,3.1529207,,,,,,,,,,,,,,,,, -13700,0.26070964,3.182811,,,,,,,,,,,,,,,,, -13800,0.24832873,3.2047396,,,,,,,,,,,,,,,,, -13900,0.26263,3.1998317,,,,,,,,,,,,,,,,, -14000,0.2531719,3.1302903,,,,,,,,,,,,,,,,, -14100,0.29162496,3.1390998,,,,,,,,,,,,,,,,, -14200,0.30591425,3.173442,,,,,,,,,,,,,,,,, -14300,0.24919423,3.0673227,,,,,,,,,,,,,,,,, -14380,,,0.6137219071388245,2.0440056324005127,29.763904979834205,0.6301347613334656,1.92502510547638,26.04915640493588,3000.0,0.637022852897644,1.8806644678115845,25.595357216731195,3003.0,5077.2845776081085,8969.737039804459,5077.2845776081085,3891.816184282303,0.1714715957641601,0.0 -14400,0.26077422,3.1573985,,,,,,,,,,,,,,,,, -14500,0.24482962,3.1194324,,,,,,,,,,,,,,,,, -14600,0.25213644,3.1521356,,,,,,,,,,,,,,,,, -14700,0.25548968,3.0929549,,,,,,,,,,,,,,,,, -14800,0.2446095,3.142793,,,,,,,,,,,,,,,,, -14900,0.23545122,3.1423094,,,,,,,,,,,,,,,,, -15000,0.31463227,3.0634105,,,,,,,,,,,,,,,,, -15100,0.29505306,3.0831187,,,,,,,,,,,,,,,,, -15200,0.2451519,3.240284,,,,,,,,,,,,,,,,, -15300,0.26497003,3.1529448,,,,,,,,,,,,,,,,, -15400,0.28218654,3.0399306,,,,,,,,,,,,,,,,, -15500,0.26117694,3.1575065,,,,,,,,,,,,,,,,, -15600,0.26393533,3.1612265,,,,,,,,,,,,,,,,, -15700,0.34186298,3.0698538,,,,,,,,,,,,,,,,, -15800,0.28408217,3.1444895,,,,,,,,,,,,,,,,, -15900,0.2833318,3.1840465,,,,,,,,,,,,,,,,, -16000,0.27434525,3.1293209,,,,,,,,,,,,,,,,, -16100,0.26580152,3.1788895,,,,,,,,,,,,,,,,, -16200,0.31162608,3.0538802,,,,,,,,,,,,,,,,, -16300,0.28431216,3.08129,,,,,,,,,,,,,,,,, -16400,0.2836126,3.0849624,,,,,,,,,,,,,,,,, -16500,0.33224472,3.1020718,,,,,,,,,,,,,,,,, -16600,0.25372732,3.0719335,,,,,,,,,,,,,,,,, -16700,0.26528853,3.0423148,,,,,,,,,,,,,,,,, -16776,,,0.6152881979942322,2.0378637313842773,29.83333990536737,0.6370410919189453,1.862482070922852,26.930205823477618,3000.0,0.6476788520812988,1.806699275970459,25.957416144448224,3003.0,5917.195971488953,10249.626316785812,5917.195971488953,4331.687678337097,0.1997363567352295,0.0 -16800,0.34303018,3.136628,,,,,,,,,,,,,,,,, -16900,0.28918457,3.0609138,,,,,,,,,,,,,,,,, -17000,0.32654727,3.141428,,,,,,,,,,,,,,,,, -17100,0.34120947,3.1171489,,,,,,,,,,,,,,,,, -17200,0.37823996,3.041304,,,,,,,,,,,,,,,,, -17300,0.3172903,3.118267,,,,,,,,,,,,,,,,, -17400,0.31417263,3.0422091,,,,,,,,,,,,,,,,, -17500,0.29476133,3.0582397,,,,,,,,,,,,,,,,, -17600,0.2877927,3.1181746,,,,,,,,,,,,,,,,, -17700,0.3458433,3.010891,,,,,,,,,,,,,,,,, -17800,0.37457022,3.0939136,,,,,,,,,,,,,,,,, -17900,0.40223598,3.0754607,,,,,,,,,,,,,,,,, -18000,0.3212312,3.0761154,,,,,,,,,,,,,,,,, -18100,0.30683467,3.0961797,,,,,,,,,,,,,,,,, -18200,0.35429567,3.0764375,,,,,,,,,,,,,,,,, -18300,0.29006195,3.0569713,,,,,,,,,,,,,,,,, -18400,0.29769394,3.0508852,,,,,,,,,,,,,,,,, -18500,0.35239136,3.0741813,,,,,,,,,,,,,,,,, -18600,0.38958612,3.0714834,,,,,,,,,,,,,,,,, -18700,0.3476743,3.0484338,,,,,,,,,,,,,,,,, -18800,0.36745343,3.0787427,,,,,,,,,,,,,,,,, -18900,0.3596079,2.9991145,,,,,,,,,,,,,,,,, -19000,0.33963192,3.071118,,,,,,,,,,,,,,,,, -19100,0.34334463,3.073323,,,,,,,,,,,,,,,,, -19173,,,0.6377730965614319,1.8624444007873533,31.384704872693035,0.6448649168014526,1.817959427833557,27.341544220468872,3000.0,0.6542908549308777,1.7641087770462036,26.770263755815385,3003.0,6757.092479228973,11546.166801214218,6757.092479228973,4788.225700378418,0.2278189659118652,0.0 -19200,0.3008093,3.0157077,,,,,,,,,,,,,,,,, -19300,0.29319152,3.018927,,,,,,,,,,,,,,,,, -19400,0.33993986,3.1085129,,,,,,,,,,,,,,,,, -19500,0.36164168,3.1157913,,,,,,,,,,,,,,,,, -19600,0.3555902,3.0531883,,,,,,,,,,,,,,,,, -19700,0.35599646,3.1454463,,,,,,,,,,,,,,,,, -19800,0.29391587,3.033497,,,,,,,,,,,,,,,,, -19900,0.33217973,3.060626,,,,,,,,,,,,,,,,, -20000,0.4012116,3.108803,,,,,,,,,,,,,,,,, -20100,0.34815124,3.000384,,,,,,,,,,,,,,,,, -20200,0.34517136,3.0677457,,,,,,,,,,,,,,,,, -20300,0.35411313,3.0795808,,,,,,,,,,,,,,,,, -20400,0.3091237,3.1398525,,,,,,,,,,,,,,,,, -20500,0.3898695,3.0303638,,,,,,,,,,,,,,,,, -20600,0.3733378,3.0654082,,,,,,,,,,,,,,,,, -20700,0.31239796,3.0140657,,,,,,,,,,,,,,,,, -20800,0.31830728,3.0426106,,,,,,,,,,,,,,,,, -20900,0.33978015,2.9699688,,,,,,,,,,,,,,,,, -21000,0.31035104,3.0354223,,,,,,,,,,,,,,,,, -21100,0.3524188,3.1043484,,,,,,,,,,,,,,,,, -21200,0.36926103,3.0920756,,,,,,,,,,,,,,,,, -21300,0.4997481,3.0672636,,,,,,,,,,,,,,,,, -21400,0.35077387,3.0863106,,,,,,,,,,,,,,,,, -21500,0.34789854,3.1083636,,,,,,,,,,,,,,,,, -21572,,,0.6281996965408325,1.9296544790267944,30.43993556295968,0.647431492805481,1.7907466888427734,27.470066942691965,3000.0,0.6562082767486572,1.7340083122253418,26.91141067893076,3003.0,7597.14856171608,12830.313049316406,7597.14856171608,5232.208755493164,0.2553091049194336,0.0 -21600,0.42602554,2.9867573,,,,,,,,,,,,,,,,, -21700,0.43310338,3.0244396,,,,,,,,,,,,,,,,, -21800,0.41940734,3.023382,,,,,,,,,,,,,,,,, -21900,0.40671486,3.103252,,,,,,,,,,,,,,,,, -22000,0.3894094,3.0191927,,,,,,,,,,,,,,,,, -22100,0.34132126,3.1134503,,,,,,,,,,,,,,,,, -22200,0.36494926,3.0719137,,,,,,,,,,,,,,,,, -22300,0.34757754,2.9804277,,,,,,,,,,,,,,,,, -22400,0.43021765,2.9939804,,,,,,,,,,,,,,,,, -22500,0.37833795,3.0026536,,,,,,,,,,,,,,,,, -22600,0.3342034,2.9949272,,,,,,,,,,,,,,,,, -22700,0.45939445,3.1056771,,,,,,,,,,,,,,,,, -22800,0.38309735,3.012552,,,,,,,,,,,,,,,,, -22900,0.38167074,3.0341973,,,,,,,,,,,,,,,,, -23000,0.360441,3.0173182,,,,,,,,,,,,,,,,, -23100,0.34016553,3.0373387,,,,,,,,,,,,,,,,, -23200,0.3326298,3.0430052,,,,,,,,,,,,,,,,, -23300,0.31476614,3.0591402,,,,,,,,,,,,,,,,, -23400,0.38523316,3.0848436,,,,,,,,,,,,,,,,, -23500,0.34270403,2.9561145,,,,,,,,,,,,,,,,, -23600,0.5044901,3.0935364,,,,,,,,,,,,,,,,, -23700,0.3825779,3.0552235,,,,,,,,,,,,,,,,, -23800,0.31954476,3.0053024,,,,,,,,,,,,,,,,, -23900,0.36296168,2.9764905,,,,,,,,,,,,,,,,, -23972,,,0.6257244944572449,1.9461606740951536,30.647498348593707,0.6491302251815796,1.7739145755767822,27.639039633589352,3000.0,0.6590785384178162,1.7163658142089844,26.735188246892022,3003.0,8437.255508899689,14123.446489810944,8437.255508899689,5685.127123594284,0.2881033420562744,0.0 -24000,0.3967652,2.9890523,,,,,,,,,,,,,,,,, -24100,0.34007272,2.9972775,,,,,,,,,,,,,,,,, -24200,0.3944685,3.0349782,,,,,,,,,,,,,,,,, -24300,0.3912773,3.0026238,,,,,,,,,,,,,,,,, -24400,0.3515919,2.9945354,,,,,,,,,,,,,,,,, -24500,0.3816071,2.9701412,,,,,,,,,,,,,,,,, -24600,0.3471481,2.9399655,,,,,,,,,,,,,,,,, -24700,0.47478172,2.990425,,,,,,,,,,,,,,,,, -24800,0.35098302,2.9734187,,,,,,,,,,,,,,,,, -24900,0.4315271,3.0328934,,,,,,,,,,,,,,,,, -25000,0.3849291,3.0283437,,,,,,,,,,,,,,,,, -25100,0.32938904,3.002583,,,,,,,,,,,,,,,,, -25200,0.34174728,2.952782,,,,,,,,,,,,,,,,, -25300,0.4118844,3.132526,,,,,,,,,,,,,,,,, -25400,0.39089516,3.0291739,,,,,,,,,,,,,,,,, -25500,0.45790032,3.0153239,,,,,,,,,,,,,,,,, -25600,0.38086346,2.9983087,,,,,,,,,,,,,,,,, -25700,0.37643123,3.002805,,,,,,,,,,,,,,,,, -25800,0.43844607,3.080972,,,,,,,,,,,,,,,,, -25900,0.37425083,3.014807,,,,,,,,,,,,,,,,, -26000,0.55374396,2.9682074,,,,,,,,,,,,,,,,, -26100,0.31615067,2.9938707,,,,,,,,,,,,,,,,, -26200,0.47701684,3.0024855,,,,,,,,,,,,,,,,, -26300,0.33360398,3.0011659,,,,,,,,,,,,,,,,, -26373,,,0.6318857073783875,1.8993377685546875,31.12798198722012,0.6520439982414246,1.7544596195220947,27.978040845419148,3000.0,0.6620068550109863,1.7058236598968506,27.445354021922327,3003.0,9277.202247619629,15468.266919612885,9277.202247619629,6189.89656496048,0.316986083984375,0.0 -26400,0.33576933,3.013734,,,,,,,,,,,,,,,,, -26500,0.45601264,2.9580157,,,,,,,,,,,,,,,,, -26600,0.36575937,2.9518278,,,,,,,,,,,,,,,,, -26700,0.35549057,3.0283034,,,,,,,,,,,,,,,,, -26800,0.41483843,3.0458329,,,,,,,,,,,,,,,,, -26900,0.40073797,3.0008945,,,,,,,,,,,,,,,,, -27000,0.32365465,2.972515,,,,,,,,,,,,,,,,, -27100,0.37877053,2.965573,,,,,,,,,,,,,,,,, -27200,0.3828146,2.9654615,,,,,,,,,,,,,,,,, -27300,0.3197186,2.9327617,,,,,,,,,,,,,,,,, -27400,0.36019477,2.9871025,,,,,,,,,,,,,,,,, -27500,0.5293545,3.001053,,,,,,,,,,,,,,,,, -27600,0.41428608,3.0117211,,,,,,,,,,,,,,,,, -27700,0.3580779,2.9968884,,,,,,,,,,,,,,,,, -27800,0.3365702,3.0146303,,,,,,,,,,,,,,,,, -27900,0.3574115,2.9423287,,,,,,,,,,,,,,,,, -28000,0.3939906,3.12065,,,,,,,,,,,,,,,,, -28100,0.3673827,2.920911,,,,,,,,,,,,,,,,, -28200,0.37813717,2.9672616,,,,,,,,,,,,,,,,, -28300,0.40727618,2.9867694,,,,,,,,,,,,,,,,, -28400,0.37223753,2.9662032,,,,,,,,,,,,,,,,, -28500,2.5226636,2.935749,,,,,,,,,,,,,,,,, -28600,0.4059371,3.085087,,,,,,,,,,,,,,,,, -28700,0.4548862,4.881657,,,,,,,,,,,,,,,,, -28775,,,0.3327521383762359,3.951723337173462,0.5231126249610653,0.2997359037399292,4.32647705078125,0.1150917293941524,3000.0,0.2897216975688934,4.496841907501221,0.0742939463471604,3003.0,10117.357129573822,16794.98263859749,10117.357129573822,6676.355022907257,0.3454797267913818,0.0 -28800,0.45022675,4.7929845,,,,,,,,,,,,,,,,, -28900,1.0211961,4.7193236,,,,,,,,,,,,,,,,, -29000,0.7067185,4.7025466,,,,,,,,,,,,,,,,, -29100,0.5418512,4.6628222,,,,,,,,,,,,,,,,, -29200,0.67145926,4.6906447,,,,,,,,,,,,,,,,, -29300,1.0712739,4.6793323,,,,,,,,,,,,,,,,, -29400,0.69011164,4.6197815,,,,,,,,,,,,,,,,, -29500,1.5590514,4.5755453,,,,,,,,,,,,,,,,, -29600,0.555885,3.2224627,,,,,,,,,,,,,,,,, -29700,0.38568258,3.0743248,,,,,,,,,,,,,,,,, -29800,0.39630565,3.038542,,,,,,,,,,,,,,,,, -29900,0.38019168,3.001736,,,,,,,,,,,,,,,,, -30000,0.33035654,3.0157242,,,,,,,,,,,,,,,,, -30100,0.3526381,3.025136,,,,,,,,,,,,,,,,, -30200,0.36079922,2.9722853,,,,,,,,,,,,,,,,, -30300,0.34444737,2.9741707,,,,,,,,,,,,,,,,, -30400,0.35648635,3.0370111,,,,,,,,,,,,,,,,, -30500,0.3949797,3.020733,,,,,,,,,,,,,,,,, -30600,0.34651518,2.939868,,,,,,,,,,,,,,,,, -30700,0.36744022,3.016074,,,,,,,,,,,,,,,,, -30800,0.36259824,2.9482863,,,,,,,,,,,,,,,,, -30900,0.3596482,2.9506288,,,,,,,,,,,,,,,,, -31000,0.42260394,2.978978,,,,,,,,,,,,,,,,, -31100,0.43996745,2.981469,,,,,,,,,,,,,,,,, -31177,,,0.6317818760871887,1.9045122861862185,30.54583083674687,0.6527383327484131,1.7534050941467283,27.74894091426883,3000.0,0.662657618522644,1.6983014345169067,27.054145258705955,3003.0,10957.487800359726,18130.35057020188,10957.487800359726,7171.487823009491,0.3740403652191162,0.0 -31200,0.3672789,2.9482946,,,,,,,,,,,,,,,,, -31300,0.39339402,3.0193374,,,,,,,,,,,,,,,,, -31400,0.38898605,2.9988444,,,,,,,,,,,,,,,,, -31500,0.41722688,2.9453523,,,,,,,,,,,,,,,,, -31600,0.43245795,2.9446006,,,,,,,,,,,,,,,,, -31700,0.3731996,3.0371833,,,,,,,,,,,,,,,,, -31800,0.37808597,3.0029612,,,,,,,,,,,,,,,,, -31900,0.32564446,2.9815273,,,,,,,,,,,,,,,,, -32000,0.36629602,2.9808185,,,,,,,,,,,,,,,,, -32100,0.44342566,2.9913602,,,,,,,,,,,,,,,,, -32200,0.344024,2.979103,,,,,,,,,,,,,,,,, -32300,0.36446905,2.918357,,,,,,,,,,,,,,,,, -32400,0.3771775,3.0981839,,,,,,,,,,,,,,,,, -32500,0.41393355,2.9559336,,,,,,,,,,,,,,,,, -32600,0.42108065,2.9382734,,,,,,,,,,,,,,,,, -32700,0.3654599,2.9606402,,,,,,,,,,,,,,,,, -32800,0.38021752,3.0677629,,,,,,,,,,,,,,,,, -32900,0.40541977,2.9817605,,,,,,,,,,,,,,,,, -33000,0.44361156,2.9980311,,,,,,,,,,,,,,,,, -33100,0.34991527,3.010003,,,,,,,,,,,,,,,,, -33200,0.3651487,3.0656395,,,,,,,,,,,,,,,,, -33300,0.3311187,2.9641085,,,,,,,,,,,,,,,,, -33400,0.38141936,2.9767141,,,,,,,,,,,,,,,,, -33500,0.39897048,3.0624993,,,,,,,,,,,,,,,,, -33579,,,0.6354771852493286,1.879430055618286,30.563934941582254,0.6564084887504578,1.7356117963790894,28.160680413584807,3000.0,0.6678984761238098,1.6765241622924805,27.42061605427625,3003.0,11797.682716608047,19464.63546180725,11797.682716608047,7665.472952365875,0.4036171436309814,0.0 -33600,0.41779417,2.9650662,,,,,,,,,,,,,,,,, -33700,0.3396748,2.9477718,,,,,,,,,,,,,,,,, -33800,0.33788988,2.966268,,,,,,,,,,,,,,,,, -33900,0.351573,3.0066562,,,,,,,,,,,,,,,,, -34000,0.37832552,2.9783144,,,,,,,,,,,,,,,,, -34100,0.40421623,2.956696,,,,,,,,,,,,,,,,, -34200,0.42254424,3.0615613,,,,,,,,,,,,,,,,, -34300,0.34012735,3.0304995,,,,,,,,,,,,,,,,, -34400,0.41777512,2.9257784,,,,,,,,,,,,,,,,, -34500,0.37270284,3.0228426,,,,,,,,,,,,,,,,, -34600,0.35844356,2.958135,,,,,,,,,,,,,,,,, -34700,0.32890686,2.9371471,,,,,,,,,,,,,,,,, -34800,0.35167527,3.0066156,,,,,,,,,,,,,,,,, -34900,0.38633633,3.046371,,,,,,,,,,,,,,,,, -35000,0.36603835,2.9616122,,,,,,,,,,,,,,,,, -35100,0.36086068,3.0077503,,,,,,,,,,,,,,,,, -35200,0.4167997,3.009339,,,,,,,,,,,,,,,,, -35300,0.39522076,2.9595547,,,,,,,,,,,,,,,,, -35400,0.35452688,2.9655268,,,,,,,,,,,,,,,,, -35500,0.36087814,3.039496,,,,,,,,,,,,,,,,, -35600,0.405044,2.9824996,,,,,,,,,,,,,,,,, -35700,0.37737036,3.0100708,,,,,,,,,,,,,,,,, -35800,0.33525416,2.9229007,,,,,,,,,,,,,,,,, -35900,0.3702704,2.9699187,,,,,,,,,,,,,,,,, -35981,,,0.6375195384025574,1.8644064664840696,31.009726888605,0.6567308306694031,1.7263554334640503,28.174935173674044,3000.0,0.6671082377433777,1.6758179664611816,27.54092155263876,3003.0,12637.7005712986,20795.35339331627,12637.7005712986,8156.066877841949,0.4347925186157226,0.0 -36000,0.33324292,2.9229689,,,,,,,,,,,,,,,,, -36100,0.40767404,3.0204332,,,,,,,,,,,,,,,,, -36200,0.43227053,2.958876,,,,,,,,,,,,,,,,, -36300,0.3777408,2.9527388,,,,,,,,,,,,,,,,, -36400,0.36283186,2.9484165,,,,,,,,,,,,,,,,, -36500,0.37362716,3.0104876,,,,,,,,,,,,,,,,, -36600,0.51323545,2.9564555,,,,,,,,,,,,,,,,, -36700,0.45460874,2.9413693,,,,,,,,,,,,,,,,, -36800,0.37424257,2.943585,,,,,,,,,,,,,,,,, -36900,0.40750197,3.0228343,,,,,,,,,,,,,,,,, -37000,0.3987169,2.998808,,,,,,,,,,,,,,,,, -37100,0.3967357,2.8843415,,,,,,,,,,,,,,,,, -37200,0.36741337,2.9808424,,,,,,,,,,,,,,,,, -37300,0.42013022,2.9563859,,,,,,,,,,,,,,,,, -37400,0.41514793,2.9891653,,,,,,,,,,,,,,,,, -37500,0.3585671,3.0132036,,,,,,,,,,,,,,,,, -37600,0.38881007,2.95092,,,,,,,,,,,,,,,,, -37700,0.33680677,2.9556942,,,,,,,,,,,,,,,,, -37800,0.3691126,2.9943428,,,,,,,,,,,,,,,,, -37900,0.40161783,3.0252385,,,,,,,,,,,,,,,,, -38000,0.34495574,2.9772882,,,,,,,,,,,,,,,,, -38100,0.40363353,2.979368,,,,,,,,,,,,,,,,, -38200,0.3859937,2.9525056,,,,,,,,,,,,,,,,, -38300,0.3682496,2.9800675,,,,,,,,,,,,,,,,, -38382,,,0.6450705528259277,1.804628610610962,32.15894108712716,0.6593098640441895,1.7143281698226929,28.36643041451481,3000.0,0.6678752303123474,1.6616512537002563,27.693778275595207,3003.0,13477.855075120926,22105.53423690796,13477.855075120926,8625.987797498703,0.4648962020874023,0.0 -38400,0.3971868,2.9484434,,,,,,,,,,,,,,,,, -38500,0.44992268,2.9544973,,,,,,,,,,,,,,,,, -38600,0.5929089,3.0279052,,,,,,,,,,,,,,,,, -38700,0.3677456,2.9034863,,,,,,,,,,,,,,,,, -38800,0.37065896,2.9408557,,,,,,,,,,,,,,,,, -38900,0.33518422,2.922355,,,,,,,,,,,,,,,,, -39000,0.42062548,3.0014298,,,,,,,,,,,,,,,,, -39100,0.46152714,2.9287465,,,,,,,,,,,,,,,,, -39200,0.35856825,3.0141125,,,,,,,,,,,,,,,,, -39300,0.36664146,2.9518566,,,,,,,,,,,,,,,,, -39400,0.37603888,2.9793937,,,,,,,,,,,,,,,,, -39500,0.34051973,2.9215696,,,,,,,,,,,,,,,,, -39600,0.33228716,2.9518728,,,,,,,,,,,,,,,,, -39700,0.32911438,2.9361537,,,,,,,,,,,,,,,,, -39800,0.37532154,2.9817505,,,,,,,,,,,,,,,,, -39900,0.439474,2.943122,,,,,,,,,,,,,,,,, -40000,0.40042624,2.9967935,,,,,,,,,,,,,,,,, -40100,0.39040914,2.9214752,,,,,,,,,,,,,,,,, -40200,0.34669605,2.9145098,,,,,,,,,,,,,,,,, -40300,0.33821505,2.9609575,,,,,,,,,,,,,,,,, -40400,0.39161363,2.954513,,,,,,,,,,,,,,,,, -40500,0.36181247,2.8819602,,,,,,,,,,,,,,,,, -40600,0.38166627,2.9785326,,,,,,,,,,,,,,,,, -40700,0.37074178,2.948669,,,,,,,,,,,,,,,,, -40783,,,0.6396812200546265,1.8572800159454343,31.35065259778344,0.6599918007850647,1.7028095722198486,28.27201281310921,3000.0,0.6715008020401001,1.6438584327697754,27.98273853137321,3003.0,14317.929590463638,23452.281606912613,14317.929590463638,9132.55176949501,0.4963889122009277,0.0 -40800,0.3895141,2.888696,,,,,,,,,,,,,,,,, -40900,0.37996328,2.9756393,,,,,,,,,,,,,,,,, -41000,0.36510828,2.9603238,,,,,,,,,,,,,,,,, -41100,0.36199087,2.9776487,,,,,,,,,,,,,,,,, -41200,0.3765194,2.9675474,,,,,,,,,,,,,,,,, -41300,0.36920306,2.9288979,,,,,,,,,,,,,,,,, -41400,0.42410615,2.8581133,,,,,,,,,,,,,,,,, -41500,0.35942864,2.9153223,,,,,,,,,,,,,,,,, -41600,0.3260799,2.9848623,,,,,,,,,,,,,,,,, -41700,0.4175135,2.9771855,,,,,,,,,,,,,,,,, -41800,0.35056704,2.9199827,,,,,,,,,,,,,,,,, -41900,0.3651758,2.966501,,,,,,,,,,,,,,,,, -42000,0.34394383,2.9350386,,,,,,,,,,,,,,,,, -42100,0.3533721,2.9186268,,,,,,,,,,,,,,,,, -42200,0.37589473,3.0611506,,,,,,,,,,,,,,,,, -42300,0.38320255,2.987235,,,,,,,,,,,,,,,,, -42400,0.35670733,2.9218209,,,,,,,,,,,,,,,,, -42500,0.3603802,2.9741607,,,,,,,,,,,,,,,,, -42600,0.35843122,2.93777,,,,,,,,,,,,,,,,, -42700,0.342264,2.997262,,,,,,,,,,,,,,,,, -42800,0.34933126,2.9633722,,,,,,,,,,,,,,,,, -42900,0.38104838,2.9133403,,,,,,,,,,,,,,,,, -43000,0.33852664,2.8763044,,,,,,,,,,,,,,,,, -43100,0.43945643,2.947332,,,,,,,,,,,,,,,,, -43185,,,0.6362661123275757,1.8771040439605715,31.457801930764408,0.6599298119544983,1.7009010314941406,28.35290610957772,3000.0,0.6713729500770569,1.6449847221374512,27.85697410575625,3003.0,15158.089678287506,24809.804277181625,15158.089678287506,9649.80432677269,0.5288918018341064,0.0 -43200,0.38787392,2.9231489,,,,,,,,,,,,,,,,, -43300,0.35083142,2.905592,,,,,,,,,,,,,,,,, -43400,0.39410156,2.948029,,,,,,,,,,,,,,,,, -43500,0.33478236,2.8930254,,,,,,,,,,,,,,,,, -43600,0.38229513,2.9706967,,,,,,,,,,,,,,,,, -43700,0.36906824,2.9894953,,,,,,,,,,,,,,,,, -43800,0.4359549,2.9019766,,,,,,,,,,,,,,,,, -43900,0.35687476,2.923898,,,,,,,,,,,,,,,,, -44000,0.35439575,2.8908207,,,,,,,,,,,,,,,,, -44100,0.4001958,2.9597564,,,,,,,,,,,,,,,,, -44200,0.33911675,2.9433095,,,,,,,,,,,,,,,,, -44300,0.43122953,2.8506298,,,,,,,,,,,,,,,,, -44400,0.37381896,2.8913221,,,,,,,,,,,,,,,,, -44500,0.35367954,2.958842,,,,,,,,,,,,,,,,, -44600,1.2682959,4.9308815,,,,,,,,,,,,,,,,, -44700,0.46775684,4.7700744,,,,,,,,,,,,,,,,, -44800,0.79125714,4.704562,,,,,,,,,,,,,,,,, -44900,1.5305868,4.660212,,,,,,,,,,,,,,,,, -45000,0.5050659,4.622547,,,,,,,,,,,,,,,,, -45100,0.62876415,4.5593796,,,,,,,,,,,,,,,,, -45200,0.5787408,4.601517,,,,,,,,,,,,,,,,, -45300,2.7181678,4.549339,,,,,,,,,,,,,,,,, -45400,1.5630704,4.540361,,,,,,,,,,,,,,,,, -45500,1.1183954,3.8352802,,,,,,,,,,,,,,,,, -45589,,,0.590805172920227,2.153980493545532,24.934241026593256,0.5927143096923828,2.137026071548462,20.67777602591015,3000.0,0.5968043804168701,2.1455793380737305,19.25898748849468,3003.0,15998.09845662117,26285.73772263527,15998.09845662117,10285.620803833008,0.5597467422485352,0.0 -45600,2.457657,3.2300556,,,,,,,,,,,,,,,,, -45700,0.56317717,3.1190798,,,,,,,,,,,,,,,,, -45800,0.43596855,2.9824347,,,,,,,,,,,,,,,,, -45900,0.48345956,2.9323218,,,,,,,,,,,,,,,,, -46000,0.4580907,2.9412634,,,,,,,,,,,,,,,,, -46100,0.38138327,2.9891858,,,,,,,,,,,,,,,,, -46200,0.3777598,2.9106083,,,,,,,,,,,,,,,,, -46300,0.45130742,2.950134,,,,,,,,,,,,,,,,, -46400,0.4327606,3.0221982,,,,,,,,,,,,,,,,, -46500,0.36915272,3.0092332,,,,,,,,,,,,,,,,, -46600,0.42065394,3.0121708,,,,,,,,,,,,,,,,, -46700,0.36746863,2.9921284,,,,,,,,,,,,,,,,, -46800,0.38869125,2.9675927,,,,,,,,,,,,,,,,, -46900,0.3781778,2.9320228,,,,,,,,,,,,,,,,, -47000,0.35793328,2.9403737,,,,,,,,,,,,,,,,, -47100,0.37070116,2.9047656,,,,,,,,,,,,,,,,, -47200,0.43884152,2.9549565,,,,,,,,,,,,,,,,, -47300,0.38875255,2.9102838,,,,,,,,,,,,,,,,, -47400,0.33255926,2.951872,,,,,,,,,,,,,,,,, -47500,0.4442291,2.9160554,,,,,,,,,,,,,,,,, -47600,0.35000327,2.941833,,,,,,,,,,,,,,,,, -47700,0.4961145,2.980406,,,,,,,,,,,,,,,,, -47800,0.38037708,2.9399939,,,,,,,,,,,,,,,,, -47900,0.3620346,2.9425397,,,,,,,,,,,,,,,,, -47991,,,0.6427844762802124,1.8317044973373413,31.89963562097901,0.6606737375259399,1.6950204372406006,28.24034997826204,3000.0,0.6726977229118347,1.6335320472717283,27.737813681491367,3003.0,16838.03125667572,27595.45713019371,16838.03125667572,10755.301027297974,0.5908377170562744,0.0 -48000,0.40519324,2.9698699,,,,,,,,,,,,,,,,, -48100,0.40592238,2.941477,,,,,,,,,,,,,,,,, -48200,0.37084863,2.994805,,,,,,,,,,,,,,,,, -48300,0.34636033,2.9461617,,,,,,,,,,,,,,,,, -48400,0.3670839,2.903393,,,,,,,,,,,,,,,,, -48500,0.36945975,2.9188647,,,,,,,,,,,,,,,,, -48600,0.37471518,3.0028207,,,,,,,,,,,,,,,,, -48700,0.37493214,2.9634602,,,,,,,,,,,,,,,,, -48800,0.39281225,2.8894064,,,,,,,,,,,,,,,,, -48900,0.36442468,2.880361,,,,,,,,,,,,,,,,, -49000,0.38678074,2.9518716,,,,,,,,,,,,,,,,, -49100,0.4453393,2.9198728,,,,,,,,,,,,,,,,, -49200,0.37992266,2.8694177,,,,,,,,,,,,,,,,, -49300,0.3272671,2.8729753,,,,,,,,,,,,,,,,, -49400,0.38794535,2.9126983,,,,,,,,,,,,,,,,, -49500,0.3472973,2.9758525,,,,,,,,,,,,,,,,, -49600,0.39915213,2.9155703,,,,,,,,,,,,,,,,, -49700,0.38352096,2.9979367,,,,,,,,,,,,,,,,, -49800,0.3576228,2.887781,,,,,,,,,,,,,,,,, -49900,0.45690373,2.9640348,,,,,,,,,,,,,,,,, -50000,0.3913766,2.9255035,,,,,,,,,,,,,,,,, -50100,0.41821668,3.033079,,,,,,,,,,,,,,,,, -50200,0.42967495,2.9340487,,,,,,,,,,,,,,,,, -50300,0.4330957,2.8856072,,,,,,,,,,,,,,,,, -50392,,,0.6582806706428528,1.7210654020309448,32.246968132861625,0.6620624661445618,1.6892980337142944,28.434917155790025,3000.0,0.6735227704048157,1.623759388923645,27.932242256488045,3003.0,17677.970749616623,28924.49044728279,17677.970749616623,11244.286452054976,0.6243085861206055,0.0 -50400,0.37989604,2.9765162,,,,,,,,,,,,,,,,, -50500,0.3785628,2.9607263,,,,,,,,,,,,,,,,, -50600,0.3409919,2.8957152,,,,,,,,,,,,,,,,, -50700,0.37907448,2.912894,,,,,,,,,,,,,,,,, -50800,0.36026317,2.9402509,,,,,,,,,,,,,,,,, -50900,0.47789472,2.9724958,,,,,,,,,,,,,,,,, -51000,0.38247523,2.9312131,,,,,,,,,,,,,,,,, -51100,0.3303449,2.9068985,,,,,,,,,,,,,,,,, -51200,0.41551995,2.9751348,,,,,,,,,,,,,,,,, -51300,0.3457701,2.9493573,,,,,,,,,,,,,,,,, -51400,0.3939502,2.9018037,,,,,,,,,,,,,,,,, -51500,0.34387046,2.9316843,,,,,,,,,,,,,,,,, -51600,0.39263225,3.001893,,,,,,,,,,,,,,,,, -51700,0.4209225,2.9301176,,,,,,,,,,,,,,,,, -51800,0.38460523,3.049595,,,,,,,,,,,,,,,,, -51900,0.37849388,2.8744478,,,,,,,,,,,,,,,,, -52000,0.47246793,2.8941162,,,,,,,,,,,,,,,,, -52100,0.39107952,2.994984,,,,,,,,,,,,,,,,, -52200,0.35053384,2.9208348,,,,,,,,,,,,,,,,, -52300,0.38074195,2.9182272,,,,,,,,,,,,,,,,, -52400,0.40070876,2.971415,,,,,,,,,,,,,,,,, -52500,0.38671732,2.9974706,,,,,,,,,,,,,,,,, -52600,0.36264327,2.960334,,,,,,,,,,,,,,,,, -52700,0.38002318,2.956329,,,,,,,,,,,,,,,,, -52793,,,0.6478515863418579,1.799898624420166,31.51827536648588,0.6649638414382935,1.6738790273666382,28.84308670692472,3000.0,0.6762535572052002,1.611672282218933,28.25260805810742,3003.0,18517.93385744095,30310.55475568772,18517.93385744095,11790.28213953972,0.6552319526672363,0.0 -52800,0.36140162,2.9699392,,,,,,,,,,,,,,,,, -52900,0.38097143,2.9808517,,,,,,,,,,,,,,,,, -53000,0.3790047,2.9723454,,,,,,,,,,,,,,,,, -53100,0.38045847,2.9538374,,,,,,,,,,,,,,,,, -53200,0.39349627,2.9159124,,,,,,,,,,,,,,,,, -53300,0.4105472,2.882951,,,,,,,,,,,,,,,,, -53400,0.3545389,2.9444895,,,,,,,,,,,,,,,,, -53500,0.36600283,2.9004745,,,,,,,,,,,,,,,,, -53600,0.39557025,2.9532313,,,,,,,,,,,,,,,,, -53700,0.40956098,2.9287271,,,,,,,,,,,,,,,,, -53800,0.34521416,2.961523,,,,,,,,,,,,,,,,, -53900,0.39269575,2.8731592,,,,,,,,,,,,,,,,, -54000,0.35611886,2.8880434,,,,,,,,,,,,,,,,, -54100,0.39465967,2.8482718,,,,,,,,,,,,,,,,, -54200,0.41927558,2.9300995,,,,,,,,,,,,,,,,, -54300,0.4243763,2.8574467,,,,,,,,,,,,,,,,, -54400,0.35605115,2.8623314,,,,,,,,,,,,,,,,, -54500,0.35322353,2.9137793,,,,,,,,,,,,,,,,, -54600,0.3844006,2.9380744,,,,,,,,,,,,,,,,, -54700,0.39156264,2.9085476,,,,,,,,,,,,,,,,, -54800,0.36167476,2.8615284,,,,,,,,,,,,,,,,, -54900,0.36547494,2.9324987,,,,,,,,,,,,,,,,, -55000,0.35864958,2.8874767,,,,,,,,,,,,,,,,, -55100,0.40848634,2.9722714,,,,,,,,,,,,,,,,, -55195,,,0.6423264145851135,1.831187605857849,31.520352554833806,0.6630296111106873,1.678435444831848,28.5659265759097,3000.0,0.6748591065406799,1.615303874015808,28.13716320799326,3003.0,19358.101722240448,31647.902009248734,19358.101722240448,12287.342983961104,0.6947894096374512,0.0 -55200,0.38432005,2.8544939,,,,,,,,,,,,,,,,, -55300,0.3687967,2.9518776,,,,,,,,,,,,,,,,, -55400,0.37050524,2.9173412,,,,,,,,,,,,,,,,, -55500,0.37369418,2.8858395,,,,,,,,,,,,,,,,, -55600,0.39428782,2.9473948,,,,,,,,,,,,,,,,, -55700,1.6293,3.0187602,,,,,,,,,,,,,,,,, -55800,2.047851,3.42446,,,,,,,,,,,,,,,,, -55900,2.2278125,3.1522658,,,,,,,,,,,,,,,,, -56000,0.4218152,2.9111972,,,,,,,,,,,,,,,,, -56100,0.36544123,2.8672466,,,,,,,,,,,,,,,,, -56200,0.40168872,2.969303,,,,,,,,,,,,,,,,, -56300,0.3965823,2.9565136,,,,,,,,,,,,,,,,, -56400,0.33992794,2.8829982,,,,,,,,,,,,,,,,, -56500,0.3408613,2.930265,,,,,,,,,,,,,,,,, -56600,0.33819887,2.8660514,,,,,,,,,,,,,,,,, -56700,0.3849892,2.971985,,,,,,,,,,,,,,,,, -56800,0.3895445,2.9786286,,,,,,,,,,,,,,,,, -56900,0.36169425,3.00358,,,,,,,,,,,,,,,,, -57000,0.40720132,2.8768518,,,,,,,,,,,,,,,,, -57100,0.3535477,2.8671002,,,,,,,,,,,,,,,,, -57200,0.36930272,2.9106755,,,,,,,,,,,,,,,,, -57300,0.38280502,2.9115381,,,,,,,,,,,,,,,,, -57400,0.37738603,2.9018815,,,,,,,,,,,,,,,,, -57500,0.33228007,3.0030742,,,,,,,,,,,,,,,,, -57598,,,0.6510561108589172,1.7641416788101196,31.74470780612312,0.6652490496635437,1.6621553897857666,28.623343659957573,3000.0,0.6766951680183411,1.596364974975586,28.10985823697703,3003.0,20198.159299850464,32949.93321561813,20198.159299850464,12749.21016407013,0.7266678810119629,0.0 -57600,0.3666736,2.9206524,,,,,,,,,,,,,,,,, -57700,0.3584428,2.975852,,,,,,,,,,,,,,,,, -57800,0.3526135,2.9658597,,,,,,,,,,,,,,,,, -57900,0.33515623,2.8122406,,,,,,,,,,,,,,,,, -58000,0.36178917,2.8679996,,,,,,,,,,,,,,,,, -58100,0.37318382,2.880636,,,,,,,,,,,,,,,,, -58200,0.37046182,2.8271554,,,,,,,,,,,,,,,,, -58300,0.3412272,2.8824406,,,,,,,,,,,,,,,,, -58400,0.3437402,2.8857722,,,,,,,,,,,,,,,,, -58500,0.35865507,2.9259799,,,,,,,,,,,,,,,,, -58600,0.35025075,2.9615533,,,,,,,,,,,,,,,,, -58700,0.37111634,2.905604,,,,,,,,,,,,,,,,, -58800,0.35524434,2.8822062,,,,,,,,,,,,,,,,, -58900,0.36043823,2.8340352,,,,,,,,,,,,,,,,, -59000,0.34235746,2.9462306,,,,,,,,,,,,,,,,, -59100,0.35315052,2.902164,,,,,,,,,,,,,,,,, -59200,0.35908857,2.9062316,,,,,,,,,,,,,,,,, -59300,0.33853564,2.9375942,,,,,,,,,,,,,,,,, -59400,0.35902804,2.8911033,,,,,,,,,,,,,,,,, -59500,0.36842567,2.9021544,,,,,,,,,,,,,,,,, -59600,0.39034107,2.910972,,,,,,,,,,,,,,,,, -59700,0.3890139,2.8726454,,,,,,,,,,,,,,,,, -59800,0.34031025,2.860434,,,,,,,,,,,,,,,,, -59900,0.4113187,2.8665798,,,,,,,,,,,,,,,,, -59999,,,0.644768476486206,1.8109740018844604,31.910825778247982,0.6663773655891418,1.6560256481170654,28.56648746349024,3000.0,0.6790773272514343,1.5836652517318726,28.48931735603496,3003.0,21038.290306568146,34302.990511894226,21038.290306568146,13262.024190664291,0.7604336738586426,0.0 -60000,0.37713623,2.9261913,,,,,,,,,,,,,,,,, -60100,0.34631428,2.9170787,,,,,,,,,,,,,,,,, -60200,0.42547455,2.8624415,,,,,,,,,,,,,,,,, -60300,0.36322066,2.9094248,,,,,,,,,,,,,,,,, -60400,0.36848214,2.9287646,,,,,,,,,,,,,,,,, -60500,0.37354487,2.945878,,,,,,,,,,,,,,,,, -60600,0.35094044,2.9139495,,,,,,,,,,,,,,,,, -60700,0.33922356,2.9162872,,,,,,,,,,,,,,,,, -60800,0.33847195,2.9137573,,,,,,,,,,,,,,,,, -60900,0.37173456,2.962829,,,,,,,,,,,,,,,,, -61000,0.3683173,2.9163578,,,,,,,,,,,,,,,,, -61100,0.40514112,2.856899,,,,,,,,,,,,,,,,, -61200,0.3915669,2.897583,,,,,,,,,,,,,,,,, -61300,0.3637861,2.9401972,,,,,,,,,,,,,,,,, -61400,0.42357987,2.8946688,,,,,,,,,,,,,,,,, -61500,0.36146662,2.8894672,,,,,,,,,,,,,,,,, -61600,0.3618919,2.8580546,,,,,,,,,,,,,,,,, -61700,0.35720998,2.8740227,,,,,,,,,,,,,,,,, -61800,0.35362384,2.9375687,,,,,,,,,,,,,,,,, -61900,0.35209566,2.8914754,,,,,,,,,,,,,,,,, -62000,0.3680736,2.9527092,,,,,,,,,,,,,,,,, -62100,0.48881033,2.84685,,,,,,,,,,,,,,,,, -62200,0.3545049,2.911336,,,,,,,,,,,,,,,,, -62300,0.41217563,2.8268406,,,,,,,,,,,,,,,,, -62400,0.3567804,2.850663,,,,,,,,,,,,,,,,, -62401,,,0.6491748094558716,1.7910131216049194,32.01965766582803,0.669477105140686,1.6406055688858032,29.233560910622813,3000.0,0.6824356913566589,1.577749729156494,28.67643381041361,3003.0,21878.360268354416,35608.370332956314,21878.360268354416,13727.225650072098,0.7934637069702148,0.0 -62500,0.41211757,2.8760772,,,,,,,,,,,,,,,,, -62600,0.3644631,2.8862283,,,,,,,,,,,,,,,,, -62700,0.37273633,2.8462367,,,,,,,,,,,,,,,,, -62800,0.37902772,2.9120796,,,,,,,,,,,,,,,,, -62900,0.3606702,2.8194456,,,,,,,,,,,,,,,,, -63000,0.37440953,2.9468398,,,,,,,,,,,,,,,,, -63100,0.36223888,2.8473296,,,,,,,,,,,,,,,,, -63200,0.44847345,2.884867,,,,,,,,,,,,,,,,, -63300,0.361128,2.9208686,,,,,,,,,,,,,,,,, -63400,0.38417834,2.9856877,,,,,,,,,,,,,,,,, -63500,0.34090057,2.8601494,,,,,,,,,,,,,,,,, -63600,0.36769253,2.8447337,,,,,,,,,,,,,,,,, -63700,0.37147707,2.8406744,,,,,,,,,,,,,,,,, -63800,0.36875296,2.986109,,,,,,,,,,,,,,,,, -63900,0.37289253,2.889038,,,,,,,,,,,,,,,,, -64000,0.42813942,2.9144998,,,,,,,,,,,,,,,,, -64100,0.3916176,2.972261,,,,,,,,,,,,,,,,, -64200,0.35443708,2.9444928,,,,,,,,,,,,,,,,, -64300,0.357451,2.8483007,,,,,,,,,,,,,,,,, -64400,0.3777198,2.9191148,,,,,,,,,,,,,,,,, -64500,0.4043931,2.9168215,,,,,,,,,,,,,,,,, -64600,0.36476254,2.8279169,,,,,,,,,,,,,,,,, -64700,0.43115437,2.8163657,,,,,,,,,,,,,,,,, -64800,0.3848181,2.879482,,,,,,,,,,,,,,,,, -64804,,,0.6520285606384277,1.760977268218994,31.971719271704508,0.6713989973068237,1.6386438608169556,29.0899984850159,3000.0,0.6805996298789978,1.5786640644073486,28.65766199623945,3003.0,22718.57821083069,37024.85886883736,22718.57821083069,14303.382984161375,0.8290464878082275,0.0 -64900,0.42018625,2.8605406,,,,,,,,,,,,,,,,, -65000,0.37618634,2.9193428,,,,,,,,,,,,,,,,, -65100,0.37203994,2.878746,,,,,,,,,,,,,,,,, -65200,0.34895444,2.923446,,,,,,,,,,,,,,,,, -65300,0.3920442,2.8884647,,,,,,,,,,,,,,,,, -65400,0.36730203,2.93469,,,,,,,,,,,,,,,,, -65500,0.3520051,2.843573,,,,,,,,,,,,,,,,, -65600,0.41986915,2.9067442,,,,,,,,,,,,,,,,, -65700,0.35613248,2.878689,,,,,,,,,,,,,,,,, -65800,0.39280882,2.8719695,,,,,,,,,,,,,,,,, -65900,0.3796239,2.859275,,,,,,,,,,,,,,,,, -66000,0.37283188,2.8875678,,,,,,,,,,,,,,,,, -66100,0.3837938,2.8610282,,,,,,,,,,,,,,,,, -66200,0.3587189,2.8566859,,,,,,,,,,,,,,,,, -66300,0.43069986,2.7964616,,,,,,,,,,,,,,,,, -66400,0.35696152,2.8240569,,,,,,,,,,,,,,,,, -66500,0.40717283,2.9058325,,,,,,,,,,,,,,,,, -66600,0.36975962,2.8385139,,,,,,,,,,,,,,,,, -66700,0.37796426,2.8549688,,,,,,,,,,,,,,,,, -66800,0.35495886,2.7954655,,,,,,,,,,,,,,,,, -66900,0.3601138,2.891377,,,,,,,,,,,,,,,,, -67000,0.37330654,2.908404,,,,,,,,,,,,,,,,, -67100,0.3530785,2.7952232,,,,,,,,,,,,,,,,, -67200,0.3634841,2.8339946,,,,,,,,,,,,,,,,, -67206,,,0.6519597768783569,1.7757725715637207,31.91697570224685,0.670419454574585,1.638580083847046,29.23616891361002,3000.0,0.6822613477706909,1.5671225786209106,28.72514626657771,3003.0,23558.55596637726,38350.22704553604,23558.55596637726,14788.661399126053,0.8624668121337891,0.0 -67300,0.40800744,2.852639,,,,,,,,,,,,,,,,, -67400,0.3622253,2.9234447,,,,,,,,,,,,,,,,, -67500,0.38296196,2.84693,,,,,,,,,,,,,,,,, -67600,0.37866196,2.8787668,,,,,,,,,,,,,,,,, -67700,0.40337822,2.9397774,,,,,,,,,,,,,,,,, -67800,0.38020372,2.910445,,,,,,,,,,,,,,,,, -67900,0.3695485,2.8401725,,,,,,,,,,,,,,,,, -68000,0.36697575,2.801948,,,,,,,,,,,,,,,,, -68100,0.42137614,2.918217,,,,,,,,,,,,,,,,, -68200,0.4024838,2.919333,,,,,,,,,,,,,,,,, -68300,0.3741112,2.8410182,,,,,,,,,,,,,,,,, -68400,0.3812651,2.903903,,,,,,,,,,,,,,,,, -68500,0.37959364,2.8599265,,,,,,,,,,,,,,,,, -68600,0.37475643,2.8692417,,,,,,,,,,,,,,,,, -68700,0.3934335,2.9149098,,,,,,,,,,,,,,,,, -68800,0.37125057,2.9443192,,,,,,,,,,,,,,,,, -68900,0.35630518,2.7813795,,,,,,,,,,,,,,,,, -69000,0.4437898,2.9059553,,,,,,,,,,,,,,,,, -69100,0.38340428,2.8806226,,,,,,,,,,,,,,,,, -69200,0.37127447,2.8696437,,,,,,,,,,,,,,,,, -69300,0.37773642,2.8577406,,,,,,,,,,,,,,,,, -69400,0.3830887,2.8145232,,,,,,,,,,,,,,,,, -69500,0.34612614,2.857713,,,,,,,,,,,,,,,,, -69600,0.36024493,2.8084316,,,,,,,,,,,,,,,,, -69609,,,0.6634093523025513,1.6923391819000244,32.762389538704355,0.6716221570968628,1.6273186206817627,29.347775662744024,3000.0,0.6841090321540833,1.5585025548934937,29.071967186401448,3003.0,24398.73073887825,39688.82049059868,24398.73073887825,15286.969474315643,0.8974158763885498,0.0 -69700,0.36050922,2.8567793,,,,,,,,,,,,,,,,, -69800,0.44470492,2.8712626,,,,,,,,,,,,,,,,, -69900,0.36565223,2.9003298,,,,,,,,,,,,,,,,, -70000,0.3660957,2.8126235,,,,,,,,,,,,,,,,, -70100,0.38387138,2.8335185,,,,,,,,,,,,,,,,, -70200,0.4130197,2.8654401,,,,,,,,,,,,,,,,, -70300,0.4161787,2.8527443,,,,,,,,,,,,,,,,, -70400,0.36983657,2.8536608,,,,,,,,,,,,,,,,, -70500,0.37767914,2.8104355,,,,,,,,,,,,,,,,, -70600,0.4137597,2.881133,,,,,,,,,,,,,,,,, -70700,0.39697993,2.8962104,,,,,,,,,,,,,,,,, -70800,0.38718507,2.8412347,,,,,,,,,,,,,,,,, -70900,0.38342264,2.9569259,,,,,,,,,,,,,,,,, -71000,0.36970055,2.872888,,,,,,,,,,,,,,,,, -71100,0.37862495,2.8094437,,,,,,,,,,,,,,,,, -71200,0.38666838,2.8663356,,,,,,,,,,,,,,,,, -71300,0.41526806,2.9131966,,,,,,,,,,,,,,,,, -71400,0.383447,2.810556,,,,,,,,,,,,,,,,, -71500,0.38773158,2.8124943,,,,,,,,,,,,,,,,, -71600,0.41876093,2.847669,,,,,,,,,,,,,,,,, -71700,0.362664,2.889362,,,,,,,,,,,,,,,,, -71800,0.3467567,2.9431553,,,,,,,,,,,,,,,,, -71900,0.35823026,2.8095434,,,,,,,,,,,,,,,,, -72000,0.3789634,2.884638,,,,,,,,,,,,,,,,, -72012,,,0.6599642634391785,1.7183440923690796,32.44681045615368,0.6714857816696167,1.6171070337295532,29.12857090191241,3000.0,0.6856196522712708,1.5455368757247925,28.909133725017643,3003.0,25238.861780643463,41023.935676813126,25238.861780643463,15781.841740846634,0.9326050281524658,0.0 -72100,0.36797825,2.8762605,,,,,,,,,,,,,,,,, -72200,0.352117,2.8351197,,,,,,,,,,,,,,,,, -72300,0.38330716,2.8249228,,,,,,,,,,,,,,,,, -72400,0.35407022,2.8283567,,,,,,,,,,,,,,,,, -72500,0.384995,2.879914,,,,,,,,,,,,,,,,, -72600,0.35712028,2.7329295,,,,,,,,,,,,,,,,, -72700,0.3643987,2.891464,,,,,,,,,,,,,,,,, -72800,0.4126452,2.8426318,,,,,,,,,,,,,,,,, -72900,0.40285027,2.857994,,,,,,,,,,,,,,,,, -73000,0.39211047,2.8490593,,,,,,,,,,,,,,,,, -73100,0.3779544,2.8126109,,,,,,,,,,,,,,,,, -73200,0.3931422,2.860169,,,,,,,,,,,,,,,,, -73300,0.385137,2.799856,,,,,,,,,,,,,,,,, -73400,0.41325462,2.864978,,,,,,,,,,,,,,,,, -73500,0.3810034,2.8204284,,,,,,,,,,,,,,,,, -73600,0.37754515,2.829031,,,,,,,,,,,,,,,,, -73700,0.3938583,2.8675046,,,,,,,,,,,,,,,,, -73800,0.38477954,2.8300085,,,,,,,,,,,,,,,,, -73900,0.37135726,2.8607602,,,,,,,,,,,,,,,,, -74000,0.402154,2.8325713,,,,,,,,,,,,,,,,, -74100,0.40690196,2.8345203,,,,,,,,,,,,,,,,, -74200,0.3449089,2.8693264,,,,,,,,,,,,,,,,, -74300,0.37565902,2.8131242,,,,,,,,,,,,,,,,, -74400,0.3930473,2.9188046,,,,,,,,,,,,,,,,, -74415,,,0.6568201780319214,1.7489362955093384,32.56595782076231,0.6759618520736694,1.6061208248138428,29.63801741455089,3000.0,0.6896287202835083,1.534721612930298,29.456325345785103,3003.0,26079.09641242028,42389.31352448464,26079.09641242028,16306.87161040306,0.9680724143981934,0.0 -74500,0.37936327,2.8720152,,,,,,,,,,,,,,,,, -74600,0.35641012,2.8595958,,,,,,,,,,,,,,,,, -74700,0.41859096,2.8811095,,,,,,,,,,,,,,,,, -74800,0.40289512,2.8157713,,,,,,,,,,,,,,,,, -74900,0.39894095,2.7611213,,,,,,,,,,,,,,,,, -75000,0.39055166,2.834992,,,,,,,,,,,,,,,,, -75100,0.3752063,2.822003,,,,,,,,,,,,,,,,, -75200,0.39406332,2.8852344,,,,,,,,,,,,,,,,, -75300,0.3930735,2.791623,,,,,,,,,,,,,,,,, -75400,0.3586393,2.7366328,,,,,,,,,,,,,,,,, -75500,0.39374727,2.7775018,,,,,,,,,,,,,,,,, -75600,0.3794093,2.8811743,,,,,,,,,,,,,,,,, -75700,0.38609463,2.7975633,,,,,,,,,,,,,,,,, -75800,0.41326788,2.8338664,,,,,,,,,,,,,,,,, -75900,0.39487225,2.8644884,,,,,,,,,,,,,,,,, -76000,0.39395857,2.8537865,,,,,,,,,,,,,,,,, -76100,0.48051068,2.8596146,,,,,,,,,,,,,,,,, -76200,0.38712698,2.8552687,,,,,,,,,,,,,,,,, -76300,0.39477208,2.9090645,,,,,,,,,,,,,,,,, -76400,0.38190827,2.900368,,,,,,,,,,,,,,,,, -76500,0.39534643,2.8495991,,,,,,,,,,,,,,,,, -76600,0.41553435,2.8383992,,,,,,,,,,,,,,,,, -76700,0.3761164,2.7067235,,,,,,,,,,,,,,,,, -76800,0.38069767,2.8568625,,,,,,,,,,,,,,,,, -76817,,,0.661069929599762,1.7087233066558838,32.60740521012439,0.6762098073959351,1.603691577911377,29.3602712239802,3000.0,0.6898146867752075,1.5289326906204224,29.37583258832015,3003.0,26919.07650828361,43707.50957036018,26919.07650828361,16784.96683216095,1.0114808082580566,0.0 -76900,0.37276068,2.865256,,,,,,,,,,,,,,,,, -77000,0.43681145,2.8881273,,,,,,,,,,,,,,,,, -77100,0.39037275,2.8375552,,,,,,,,,,,,,,,,, -77200,0.40291208,2.8621275,,,,,,,,,,,,,,,,, -77300,0.38833085,2.8895535,,,,,,,,,,,,,,,,, -77400,0.40276316,2.8675263,,,,,,,,,,,,,,,,, -77500,0.41434872,2.8225272,,,,,,,,,,,,,,,,, -77600,0.38218322,2.7924035,,,,,,,,,,,,,,,,, -77700,0.37443292,2.8048189,,,,,,,,,,,,,,,,, -77800,0.39501035,2.7902143,,,,,,,,,,,,,,,,, -77900,0.3974878,2.9080682,,,,,,,,,,,,,,,,, -78000,0.40935254,2.8464973,,,,,,,,,,,,,,,,, -78100,0.39704216,2.8041599,,,,,,,,,,,,,,,,, -78200,0.37557575,2.8313513,,,,,,,,,,,,,,,,, -78300,0.3726534,2.8348012,,,,,,,,,,,,,,,,, -78400,0.40900165,2.775706,,,,,,,,,,,,,,,,, -78500,0.42039582,2.7994773,,,,,,,,,,,,,,,,, -78600,0.3939,2.9068582,,,,,,,,,,,,,,,,, -78700,0.40138853,2.868829,,,,,,,,,,,,,,,,, -78800,0.41401213,2.8395422,,,,,,,,,,,,,,,,, -78900,0.4125141,2.827562,,,,,,,,,,,,,,,,, -79000,0.4127027,2.8786426,,,,,,,,,,,,,,,,, -79100,0.395847,2.815606,,,,,,,,,,,,,,,,, -79200,0.3642507,2.8093448,,,,,,,,,,,,,,,,, -79219,,,0.6592926383018494,1.724799871444702,32.536791630899906,0.6757262945175171,1.6033855676651,29.62497735225051,3000.0,0.6900470852851868,1.530118107795715,29.416101277503813,3003.0,27759.00503540039,45018.331510305405,27759.00503540039,17255.74615764618,1.0472896099090576,0.0 -79300,0.38443586,2.8324962,,,,,,,,,,,,,,,,, -79400,0.41394836,2.815664,,,,,,,,,,,,,,,,, -79500,0.4027334,2.7736294,,,,,,,,,,,,,,,,, -79600,0.38358018,2.8155844,,,,,,,,,,,,,,,,, -79700,0.39862025,2.8716748,,,,,,,,,,,,,,,,, -79800,0.39586943,2.8295105,,,,,,,,,,,,,,,,, -79900,0.41900158,2.8478184,,,,,,,,,,,,,,,,, -80000,0.40871555,2.775917,,,,,,,,,,,,,,,,, -80100,0.40062597,2.8631773,,,,,,,,,,,,,,,,, -80200,0.38116467,2.8255193,,,,,,,,,,,,,,,,, -80300,0.4266376,2.8545048,,,,,,,,,,,,,,,,, -80400,0.40405172,2.8329883,,,,,,,,,,,,,,,,, -80500,0.38384956,2.824126,,,,,,,,,,,,,,,,, -80600,0.38576454,2.8438725,,,,,,,,,,,,,,,,, -80700,0.40063345,2.8466258,,,,,,,,,,,,,,,,, -80800,0.38474855,2.787246,,,,,,,,,,,,,,,,, -80900,0.3948708,2.883628,,,,,,,,,,,,,,,,, -81000,0.397134,2.847214,,,,,,,,,,,,,,,,, -81100,0.39357293,2.7805204,,,,,,,,,,,,,,,,, -81200,0.39201707,2.7757885,,,,,,,,,,,,,,,,, -81300,0.43774784,2.7399995,,,,,,,,,,,,,,,,, -81400,0.39109156,2.9476285,,,,,,,,,,,,,,,,, -81500,0.3865184,2.787323,,,,,,,,,,,,,,,,, -81600,0.39542955,2.8627949,,,,,,,,,,,,,,,,, -81622,,,0.6859181523323059,1.5658107995986938,34.2604911187917,0.6783672571182251,1.58946430683136,29.177389345403792,3000.0,0.6895938515663147,1.5182905197143557,29.071133285384303,3003.0,28599.207001686096,46428.168867349625,28599.207001686096,17825.267672777176,1.0836036205291748,0.0 -81700,0.41977108,2.795591,,,,,,,,,,,,,,,,, -81800,0.39184728,2.854391,,,,,,,,,,,,,,,,, -81900,0.4028396,2.8145134,,,,,,,,,,,,,,,,, -82000,0.41273186,2.783441,,,,,,,,,,,,,,,,, -82100,0.39678016,2.803919,,,,,,,,,,,,,,,,, -82200,0.43331242,2.7731426,,,,,,,,,,,,,,,,, -82300,0.4343233,2.8757489,,,,,,,,,,,,,,,,, -82400,0.38130713,2.8075867,,,,,,,,,,,,,,,,, -82500,0.39515188,2.7786312,,,,,,,,,,,,,,,,, -82600,0.4016654,2.7945201,,,,,,,,,,,,,,,,, -82700,0.40652224,2.8105845,,,,,,,,,,,,,,,,, -82800,0.44774586,2.7707567,,,,,,,,,,,,,,,,, -82900,0.38798106,2.8195243,,,,,,,,,,,,,,,,, -83000,0.4054765,2.7641177,,,,,,,,,,,,,,,,, -83100,0.4264959,2.776888,,,,,,,,,,,,,,,,, -83200,0.42694727,2.804317,,,,,,,,,,,,,,,,, -83300,0.40207514,2.775402,,,,,,,,,,,,,,,,, -83400,0.44176048,2.7963698,,,,,,,,,,,,,,,,, -83500,0.41705608,2.7970028,,,,,,,,,,,,,,,,, -83600,0.43025076,2.820886,,,,,,,,,,,,,,,,, -83700,0.4441405,2.8588986,,,,,,,,,,,,,,,,, -83800,0.40705666,2.8705983,,,,,,,,,,,,,,,,, -83900,0.39591298,2.846517,,,,,,,,,,,,,,,,, -84000,0.41183445,2.8012702,,,,,,,,,,,,,,,,, -84025,,,0.664563000202179,1.696059226989746,33.1192903855946,0.6789996027946472,1.582720160484314,29.565914486020333,3000.0,0.6921852231025696,1.5084973573684692,29.47494867004409,3003.0,29439.43800854683,47767.677359580994,29439.43800854683,18324.43095898628,1.1216096878051758,0.0 -84100,0.44390166,2.8106804,,,,,,,,,,,,,,,,, -84200,0.43011728,2.8201396,,,,,,,,,,,,,,,,, -84300,0.40642616,2.8685133,,,,,,,,,,,,,,,,, -84400,0.42803693,2.847126,,,,,,,,,,,,,,,,, -84500,0.41297802,2.861915,,,,,,,,,,,,,,,,, -84600,0.40086398,2.8171384,,,,,,,,,,,,,,,,, -84700,0.39930543,2.7568896,,,,,,,,,,,,,,,,, -84800,0.44893855,2.7980094,,,,,,,,,,,,,,,,, -84900,0.4017056,2.8021812,,,,,,,,,,,,,,,,, -85000,0.42957464,2.826542,,,,,,,,,,,,,,,,, -85100,0.42567334,2.8019707,,,,,,,,,,,,,,,,, -85200,0.40982535,2.8068073,,,,,,,,,,,,,,,,, -85300,0.4074695,2.8780658,,,,,,,,,,,,,,,,, -85400,0.4233961,2.8160176,,,,,,,,,,,,,,,,, -85500,0.42603084,2.8815727,,,,,,,,,,,,,,,,, -85600,0.42870057,2.8476682,,,,,,,,,,,,,,,,, -85700,0.43140012,2.8170931,,,,,,,,,,,,,,,,, -85800,0.45753095,2.8074024,,,,,,,,,,,,,,,,, -85900,0.40526482,2.8276424,,,,,,,,,,,,,,,,, -86000,0.45461422,2.858211,,,,,,,,,,,,,,,,, -86100,0.4472673,2.8638587,,,,,,,,,,,,,,,,, -86200,0.45659485,2.7912197,,,,,,,,,,,,,,,,, -86300,0.41329435,2.8178632,,,,,,,,,,,,,,,,, -86400,0.43040097,2.8256855,,,,,,,,,,,,,,,,, -86427,,,0.6629806756973267,1.7091134786605835,32.61050182170082,0.6798179745674133,1.5794235467910769,29.637436093481103,3000.0,0.6930103302001953,1.5063138008117676,29.577308672706145,3003.0,30279.43471693993,49129.6320040226,30279.43471693993,18846.27249503136,1.1595826148986816,0.0 -86500,0.4177315,2.8335629,,,,,,,,,,,,,,,,, -86600,0.4350835,2.8415656,,,,,,,,,,,,,,,,, -86700,0.4497818,2.815661,,,,,,,,,,,,,,,,, -86800,0.4473321,2.7514591,,,,,,,,,,,,,,,,, -86900,0.42724732,2.7803447,,,,,,,,,,,,,,,,, -87000,0.4472679,2.7640662,,,,,,,,,,,,,,,,, -87100,0.43769473,2.8223672,,,,,,,,,,,,,,,,, -87200,0.42584348,2.7694666,,,,,,,,,,,,,,,,, -87300,0.41418242,2.7668278,,,,,,,,,,,,,,,,, -87400,0.4282063,2.7816973,,,,,,,,,,,,,,,,, -87500,0.4316271,2.730644,,,,,,,,,,,,,,,,, -87600,0.45421815,2.8192172,,,,,,,,,,,,,,,,, -87700,0.40737426,2.7716396,,,,,,,,,,,,,,,,, -87800,0.41224208,2.7332609,,,,,,,,,,,,,,,,, -87900,0.44200405,2.801889,,,,,,,,,,,,,,,,, -88000,0.41390884,2.7904873,,,,,,,,,,,,,,,,, -88100,0.4378626,2.8139658,,,,,,,,,,,,,,,,, -88200,0.43135214,2.812934,,,,,,,,,,,,,,,,, -88300,0.42111114,2.78831,,,,,,,,,,,,,,,,, -88400,0.4195963,2.7614443,,,,,,,,,,,,,,,,, -88500,0.42955372,2.7658575,,,,,,,,,,,,,,,,, -88600,0.4205667,2.8294618,,,,,,,,,,,,,,,,, -88700,0.45886216,2.8362634,,,,,,,,,,,,,,,,, -88800,0.42337683,2.7964163,,,,,,,,,,,,,,,,, -88829,,,0.6717724204063416,1.6399476528167725,33.690016325414604,0.6823846101760864,1.570753574371338,30.108637798690488,3000.0,0.6959851384162903,1.4950261116027832,29.98966318985136,3003.0,31119.40724849701,50471.155947208405,31119.40724849701,19347.71030664444,1.1959412097930908,0.0 -88900,0.43366346,2.7797852,,,,,,,,,,,,,,,,, -89000,0.4442103,2.751321,,,,,,,,,,,,,,,,, -89100,0.43487698,2.7677705,,,,,,,,,,,,,,,,, -89200,0.43333918,2.816158,,,,,,,,,,,,,,,,, -89300,0.41318345,2.7125974,,,,,,,,,,,,,,,,, -89400,0.4298697,2.709432,,,,,,,,,,,,,,,,, -89500,0.44983807,2.8398843,,,,,,,,,,,,,,,,, -89600,0.44499794,2.8047318,,,,,,,,,,,,,,,,, -89700,0.41862768,2.8185284,,,,,,,,,,,,,,,,, -89800,0.44757366,2.7653272,,,,,,,,,,,,,,,,, -89900,0.45129234,2.8362656,,,,,,,,,,,,,,,,, -90000,0.46479094,2.8110693,,,,,,,,,,,,,,,,, -90100,0.4213883,2.8359036,,,,,,,,,,,,,,,,, -90200,0.44590768,2.7980292,,,,,,,,,,,,,,,,, -90300,0.44107595,2.8424332,,,,,,,,,,,,,,,,, -90400,0.48884478,2.812661,,,,,,,,,,,,,,,,, -90500,0.4409175,2.7550485,,,,,,,,,,,,,,,,, -90600,0.48537615,2.8286772,,,,,,,,,,,,,,,,, -90700,0.42349297,2.7923844,,,,,,,,,,,,,,,,, -90800,0.43983108,2.8416867,,,,,,,,,,,,,,,,, -90900,0.43138266,2.775792,,,,,,,,,,,,,,,,, -91000,0.4347928,2.7758796,,,,,,,,,,,,,,,,, -91100,0.4483545,2.7214692,,,,,,,,,,,,,,,,, -91200,0.44051695,2.7934616,,,,,,,,,,,,,,,,, -91231,,,0.6652762293815613,1.684821844100952,33.516918895580865,0.682136595249176,1.5662659406661987,29.97718632659504,3000.0,0.6975887417793274,1.4849611520767212,30.00201244760764,3003.0,31959.536118268967,51773.93786597252,31959.536118268967,19810.24594926834,1.2348592281341553,0.0 -91300,0.45473713,2.7376812,,,,,,,,,,,,,,,,, -91400,0.46374878,2.7596974,,,,,,,,,,,,,,,,, -91500,0.43107033,2.7570415,,,,,,,,,,,,,,,,, -91600,0.46181405,2.828515,,,,,,,,,,,,,,,,, -91700,0.46528327,2.7708807,,,,,,,,,,,,,,,,, -91800,0.44179463,2.750282,,,,,,,,,,,,,,,,, -91900,0.47295427,2.7203968,,,,,,,,,,,,,,,,, -92000,0.48273903,2.8106222,,,,,,,,,,,,,,,,, -92100,0.44679385,2.806037,,,,,,,,,,,,,,,,, -92200,0.47359976,2.7737122,,,,,,,,,,,,,,,,, -92300,0.44977218,2.7361014,,,,,,,,,,,,,,,,, -92400,0.47891158,2.8427596,,,,,,,,,,,,,,,,, -92500,0.45810533,2.7719219,,,,,,,,,,,,,,,,, -92600,0.43729368,2.7536292,,,,,,,,,,,,,,,,, -92700,0.46842447,2.740328,,,,,,,,,,,,,,,,, -92800,0.46642864,2.7677383,,,,,,,,,,,,,,,,, -92900,0.44674477,2.7402089,,,,,,,,,,,,,,,,, -93000,0.45068645,2.822376,,,,,,,,,,,,,,,,, -93100,0.48226586,2.83917,,,,,,,,,,,,,,,,, -93200,0.4623614,2.7999227,,,,,,,,,,,,,,,,, -93300,0.46092805,2.7845585,,,,,,,,,,,,,,,,, -93400,0.47482014,2.7766807,,,,,,,,,,,,,,,,, -93500,0.4888861,2.728939,,,,,,,,,,,,,,,,, -93600,0.5012913,2.7521756,,,,,,,,,,,,,,,,, -93633,,,0.6700052618980408,1.6654683351516724,33.97726970739959,0.6840088963508606,1.5550308227539062,30.201155231440644,3000.0,0.6991226673126221,1.4706521034240725,30.245471742788,3003.0,32799.64951753616,53139.55933403969,32799.64951753616,20335.630873441696,1.280862808227539,0.0 -93700,0.4570308,2.72571,,,,,,,,,,,,,,,,, -93800,0.45400673,2.6962817,,,,,,,,,,,,,,,,, -93900,0.5051881,2.80506,,,,,,,,,,,,,,,,, -94000,0.4674129,2.7270741,,,,,,,,,,,,,,,,, -94100,0.45568207,2.7361581,,,,,,,,,,,,,,,,, -94200,0.45869714,2.7738085,,,,,,,,,,,,,,,,, -94300,0.48697302,2.7848237,,,,,,,,,,,,,,,,, -94400,0.46141616,2.768245,,,,,,,,,,,,,,,,, -94500,0.47066286,2.788486,,,,,,,,,,,,,,,,, -94600,0.49120945,2.803831,,,,,,,,,,,,,,,,, -94700,0.46589458,2.7549882,,,,,,,,,,,,,,,,, -94800,0.45270598,2.7374954,,,,,,,,,,,,,,,,, -94900,0.47681412,2.7510202,,,,,,,,,,,,,,,,, -95000,0.46935788,2.7224743,,,,,,,,,,,,,,,,, -95100,0.49864987,2.7853882,,,,,,,,,,,,,,,,, -95200,0.47338447,2.8038173,,,,,,,,,,,,,,,,, -95300,0.48943892,2.7829425,,,,,,,,,,,,,,,,, -95400,0.4679482,2.6857555,,,,,,,,,,,,,,,,, -95500,0.4644346,2.6971288,,,,,,,,,,,,,,,,, -95600,0.47730815,2.72687,,,,,,,,,,,,,,,,, -95700,0.4951262,2.826854,,,,,,,,,,,,,,,,, -95800,0.4724002,2.803525,,,,,,,,,,,,,,,,, -95900,0.5505526,2.8417041,,,,,,,,,,,,,,,,, -96000,0.4769632,2.732728,,,,,,,,,,,,,,,,, -96035,,,0.6742818355560303,1.6293132305145264,33.98130465498545,0.6850255727767944,1.5457754135131836,30.02152235059844,3000.0,0.6990413069725037,1.4698675870895386,30.28542712697332,3003.0,33639.54993915558,54495.446476221085,33639.54993915558,20851.50264811516,1.319817066192627,0.0 -96100,0.49402446,2.7295375,,,,,,,,,,,,,,,,, -96200,0.4870203,2.7176995,,,,,,,,,,,,,,,,, -96300,0.48189545,2.7940466,,,,,,,,,,,,,,,,, -96400,0.51786363,2.8038747,,,,,,,,,,,,,,,,, -96500,0.5243857,2.7243936,,,,,,,,,,,,,,,,, -96600,0.5034324,2.8110156,,,,,,,,,,,,,,,,, -96700,0.5009374,2.7147982,,,,,,,,,,,,,,,,, -96800,0.47684592,2.8101282,,,,,,,,,,,,,,,,, -96900,0.5114732,2.7601814,,,,,,,,,,,,,,,,, -97000,0.48195,2.7601664,,,,,,,,,,,,,,,,, -97100,0.47375658,2.704816,,,,,,,,,,,,,,,,, -97200,0.49035904,2.7031963,,,,,,,,,,,,,,,,, -97300,0.5097282,2.76061,,,,,,,,,,,,,,,,, -97400,0.47550294,2.7478662,,,,,,,,,,,,,,,,, -97500,0.4860234,2.7585793,,,,,,,,,,,,,,,,, -97600,0.49724287,2.7642667,,,,,,,,,,,,,,,,, -97700,0.4982654,2.731635,,,,,,,,,,,,,,,,, -97800,0.4922959,2.7864256,,,,,,,,,,,,,,,,, -97900,0.48416916,2.725025,,,,,,,,,,,,,,,,, -98000,0.49385998,2.8127828,,,,,,,,,,,,,,,,, -98100,0.49305692,2.743254,,,,,,,,,,,,,,,,, -98200,0.46978623,2.7702448,,,,,,,,,,,,,,,,, -98300,0.526261,2.807713,,,,,,,,,,,,,,,,, -98400,0.46991834,2.6807704,,,,,,,,,,,,,,,,, -98436,,,0.6713905930519104,1.650544047355652,33.99387977317795,0.685149610042572,1.544081687927246,30.28188354970151,3000.0,0.7021788358688354,1.4656221866607666,30.44932657753321,3003.0,34479.54863166809,55827.9131128788,34479.54863166809,21343.85585975647,1.3581314086914062,0.0 -98500,0.5071025,2.7562096,,,,,,,,,,,,,,,,, -98600,0.47738433,2.730547,,,,,,,,,,,,,,,,, -98700,0.5210321,2.7108867,,,,,,,,,,,,,,,,, -98800,0.47694644,2.810567,,,,,,,,,,,,,,,,, -98900,0.5252654,2.7960994,,,,,,,,,,,,,,,,, -99000,0.4785103,2.7439978,,,,,,,,,,,,,,,,, -99100,0.46128064,2.725279,,,,,,,,,,,,,,,,, -99200,0.5186166,2.7065713,,,,,,,,,,,,,,,,, -99300,0.48531634,2.7287934,,,,,,,,,,,,,,,,, -99400,0.48416477,2.709542,,,,,,,,,,,,,,,,, -99500,0.51094073,2.727368,,,,,,,,,,,,,,,,, -99600,0.50490534,2.7330837,,,,,,,,,,,,,,,,, -99700,0.48267406,2.6626127,,,,,,,,,,,,,,,,, -99800,0.50479877,2.7515001,,,,,,,,,,,,,,,,, -99900,0.5123803,2.682423,,,,,,,,,,,,,,,,, -100000,0.50514746,2.6907976,,,,,,,,,,,,,,,,, -100100,0.50371236,2.7734098,,,,,,,,,,,,,,,,, -100200,0.50009,2.7265587,,,,,,,,,,,,,,,,, -100300,0.47383443,2.692106,,,,,,,,,,,,,,,,, -100400,0.5092716,2.7538595,,,,,,,,,,,,,,,,, -100500,0.5451765,2.7321126,,,,,,,,,,,,,,,,, -100600,0.49041662,2.7113922,,,,,,,,,,,,,,,,, -100700,0.50865704,2.702961,,,,,,,,,,,,,,,,, -100800,0.52530617,2.7059546,,,,,,,,,,,,,,,,, -100838,,,0.6884165406227112,1.5480619668960571,35.2754136989264,0.6868482828140259,1.538313865661621,30.528489437300816,3000.0,0.7017489075660706,1.4601340293884275,30.560988154441777,3003.0,35319.508835315704,57193.07353281975,35319.508835315704,21868.93807411194,1.398036003112793,0.0 -100900,0.497965,2.7429504,,,,,,,,,,,,,,,,, -101000,0.52676916,2.7588415,,,,,,,,,,,,,,,,, -101100,0.54365194,2.7390919,,,,,,,,,,,,,,,,, -101200,0.55802315,2.7539685,,,,,,,,,,,,,,,,, -101300,0.5207652,2.7867293,,,,,,,,,,,,,,,,, -101400,0.52342254,2.7492678,,,,,,,,,,,,,,,,, -101500,0.52632624,2.736694,,,,,,,,,,,,,,,,, -101600,0.5179421,2.7133574,,,,,,,,,,,,,,,,, -101700,0.5362984,2.7097487,,,,,,,,,,,,,,,,, -101800,0.5092855,2.745149,,,,,,,,,,,,,,,,, -101900,0.53256565,2.7928066,,,,,,,,,,,,,,,,, -102000,0.52434766,2.742701,,,,,,,,,,,,,,,,, -102100,0.5199669,2.7139425,,,,,,,,,,,,,,,,, -102200,0.5427394,2.7543497,,,,,,,,,,,,,,,,, -102300,0.5124066,2.6912477,,,,,,,,,,,,,,,,, -102400,0.54005355,2.7259083,,,,,,,,,,,,,,,,, -102500,0.52059615,2.7177296,,,,,,,,,,,,,,,,, -102600,0.5247467,2.7159228,,,,,,,,,,,,,,,,, -102700,0.55860716,2.8239076,,,,,,,,,,,,,,,,, -102800,0.53785825,2.7481284,,,,,,,,,,,,,,,,, -102900,0.53683805,2.7391307,,,,,,,,,,,,,,,,, -103000,0.52891093,2.7247372,,,,,,,,,,,,,,,,, -103100,0.5317012,2.7051961,,,,,,,,,,,,,,,,, -103200,0.5358518,2.7800586,,,,,,,,,,,,,,,,, -103240,,,0.6779972314834595,1.604690432548523,34.50147922318023,0.6882741451263428,1.5299122333526611,30.41491065908115,3000.0,0.7038870453834534,1.4475866556167605,30.51891889328535,3003.0,36159.6839966774,58519.62125611305,36159.6839966774,22355.186529397964,1.44234037399292,0.0 -103300,0.549167,2.768599,,,,,,,,,,,,,,,,, -103400,0.52927595,2.7472904,,,,,,,,,,,,,,,,, -103500,0.5282406,2.6877844,,,,,,,,,,,,,,,,, -103600,0.5755715,2.7008505,,,,,,,,,,,,,,,,, -103700,0.56216246,2.6654103,,,,,,,,,,,,,,,,, -103800,0.5447666,2.7621324,,,,,,,,,,,,,,,,, -103900,0.5419125,2.7205071,,,,,,,,,,,,,,,,, -104000,0.55378664,2.7388504,,,,,,,,,,,,,,,,, -104100,0.5385183,2.655596,,,,,,,,,,,,,,,,, -104200,0.54394716,2.722493,,,,,,,,,,,,,,,,, -104300,0.5447037,2.7001677,,,,,,,,,,,,,,,,, -104400,0.5709779,2.7694507,,,,,,,,,,,,,,,,, -104500,0.54953647,2.7215252,,,,,,,,,,,,,,,,, -104600,0.52428937,2.6718526,,,,,,,,,,,,,,,,, -104700,0.5344486,2.6942396,,,,,,,,,,,,,,,,, -104800,0.5547314,2.6926615,,,,,,,,,,,,,,,,, -104900,0.55626875,2.7004478,,,,,,,,,,,,,,,,, -105000,0.5446863,2.6654987,,,,,,,,,,,,,,,,, -105100,0.5512468,2.718969,,,,,,,,,,,,,,,,, -105200,0.57297987,2.7291238,,,,,,,,,,,,,,,,, -105300,0.5595602,2.722842,,,,,,,,,,,,,,,,, -105400,0.5976902,2.6774428,,,,,,,,,,,,,,,,, -105500,0.5739713,2.709121,,,,,,,,,,,,,,,,, -105600,0.54702675,2.6856577,,,,,,,,,,,,,,,,, -105642,,,0.6837121248245239,1.5825554132461548,34.452289799932295,0.6886957287788391,1.523957371711731,30.520724395170017,3000.0,0.7051188349723816,1.4421643018722534,30.512686812294906,3003.0,36999.760825634,59843.622671842575,36999.760825634,22838.996133089066,1.4821085929870603,0.0 -105700,0.5697279,2.6501129,,,,,,,,,,,,,,,,, -105800,0.5629196,2.7140129,,,,,,,,,,,,,,,,, -105900,0.5637109,2.7635114,,,,,,,,,,,,,,,,, -106000,0.5529147,2.724962,,,,,,,,,,,,,,,,, -106100,0.5754676,2.6827722,,,,,,,,,,,,,,,,, -106200,0.5588002,2.692272,,,,,,,,,,,,,,,,, -106300,0.5614282,2.7182078,,,,,,,,,,,,,,,,, -106400,0.57249004,2.6759734,,,,,,,,,,,,,,,,, -106500,0.5861711,2.6964836,,,,,,,,,,,,,,,,, -106600,0.5579023,2.649177,,,,,,,,,,,,,,,,, -106700,0.55703187,2.6688404,,,,,,,,,,,,,,,,, -106800,0.54718816,2.69172,,,,,,,,,,,,,,,,, -106900,0.5701777,2.662765,,,,,,,,,,,,,,,,, -107000,0.59011346,2.7857857,,,,,,,,,,,,,,,,, -107100,0.59853786,2.6863537,,,,,,,,,,,,,,,,, -107200,0.5759992,2.6939118,,,,,,,,,,,,,,,,, -107300,0.5790329,2.7185037,,,,,,,,,,,,,,,,, -107400,0.5878894,2.7401414,,,,,,,,,,,,,,,,, -107500,0.5676039,2.6171036,,,,,,,,,,,,,,,,, -107600,0.5784912,2.6836658,,,,,,,,,,,,,,,,, -107700,0.56393296,2.6284122,,,,,,,,,,,,,,,,, -107800,0.5797488,2.7093637,,,,,,,,,,,,,,,,, -107900,0.5933814,2.6376991,,,,,,,,,,,,,,,,, -108000,0.5941821,2.6948864,,,,,,,,,,,,,,,,, -108044,,,0.6904653310775757,1.5444891452789309,35.465164598367224,0.6911135315895081,1.5180892944335938,30.564646606616183,3000.0,0.7067108154296875,1.433881402015686,30.80866748847693,3003.0,37839.9554643631,61181.76550936699,37839.9554643631,23336.82561659813,1.5219931602478027,0.0 -108100,0.5807115,2.6774669,,,,,,,,,,,,,,,,, -108200,0.5986567,2.6988518,,,,,,,,,,,,,,,,, -108300,0.6244123,2.7083993,,,,,,,,,,,,,,,,, -108400,0.5871438,2.6954038,,,,,,,,,,,,,,,,, -108500,0.6283383,2.6171556,,,,,,,,,,,,,,,,, -108600,0.58971936,2.6955917,,,,,,,,,,,,,,,,, -108700,0.6169514,2.6465302,,,,,,,,,,,,,,,,, -108800,0.5979446,2.6132898,,,,,,,,,,,,,,,,, -108900,0.6146346,2.667989,,,,,,,,,,,,,,,,, -109000,0.60074687,2.6730354,,,,,,,,,,,,,,,,, -109100,0.61581177,2.6491337,,,,,,,,,,,,,,,,, -109200,0.60825574,2.6762505,,,,,,,,,,,,,,,,, -109300,0.6256353,2.6620078,,,,,,,,,,,,,,,,, -109400,0.62542087,2.7052007,,,,,,,,,,,,,,,,, -109500,0.59754086,2.7284355,,,,,,,,,,,,,,,,, -109600,0.5903589,2.6235633,,,,,,,,,,,,,,,,, -109700,0.59838486,2.6848867,,,,,,,,,,,,,,,,, -109800,0.6201928,2.6650352,,,,,,,,,,,,,,,,, -109900,0.61618525,2.6848614,,,,,,,,,,,,,,,,, -110000,0.61000824,2.6701727,,,,,,,,,,,,,,,,, -110100,0.60812104,2.6598046,,,,,,,,,,,,,,,,, -110200,0.64326936,2.7057514,,,,,,,,,,,,,,,,, -110300,0.60777044,2.6130092,,,,,,,,,,,,,,,,, -110400,0.63677365,2.6569767,,,,,,,,,,,,,,,,, -110446,,,0.6886968612670898,1.552962303161621,34.8873148481868,0.6909399628639221,1.516938328742981,30.795585694647222,3000.0,0.7068270444869995,1.429835557937622,30.79432213794177,3003.0,38680.11610341072,62524.57753229141,38680.11610341072,23839.350203037266,1.570399045944214,0.0 -110500,0.62288475,2.7445426,,,,,,,,,,,,,,,,, -110600,0.62067723,2.6463735,,,,,,,,,,,,,,,,, -110700,0.6208816,2.7147582,,,,,,,,,,,,,,,,, -110800,0.6176912,2.675053,,,,,,,,,,,,,,,,, -110900,0.6299062,2.6975129,,,,,,,,,,,,,,,,, -111000,0.6206457,2.7302854,,,,,,,,,,,,,,,,, -111100,0.6004942,2.6108983,,,,,,,,,,,,,,,,, -111200,0.63811815,2.5861814,,,,,,,,,,,,,,,,, -111300,0.6335778,2.6993396,,,,,,,,,,,,,,,,, -111400,0.66222954,2.6494083,,,,,,,,,,,,,,,,, -111500,0.6231872,2.6860442,,,,,,,,,,,,,,,,, -111600,0.6472174,2.6382282,,,,,,,,,,,,,,,,, -111700,0.6542183,2.613972,,,,,,,,,,,,,,,,, -111800,0.6371312,2.6765738,,,,,,,,,,,,,,,,, -111900,0.66189915,2.6275027,,,,,,,,,,,,,,,,, -112000,0.62983173,2.6312075,,,,,,,,,,,,,,,,, -112100,0.65630066,2.6557584,,,,,,,,,,,,,,,,, -112200,0.62822074,2.6409411,,,,,,,,,,,,,,,,, -112300,0.65022993,2.6776085,,,,,,,,,,,,,,,,, -112400,0.6226096,2.6585286,,,,,,,,,,,,,,,,, -112500,0.6363279,2.7254896,,,,,,,,,,,,,,,,, -112600,0.67699164,2.7126188,,,,,,,,,,,,,,,,, -112700,0.6322625,2.655568,,,,,,,,,,,,,,,,, -112800,0.63814884,2.5951562,,,,,,,,,,,,,,,,, -112846,,,0.7077298760414124,1.4527026414871216,36.72570097184537,0.6913243532180786,1.5172661542892456,30.76716783706688,3000.0,0.7065481543540955,1.431631565093994,30.935883540712304,3003.0,39520.2310898304,63860.211265563965,39520.2310898304,24334.75069284439,1.6107723712921145,0.0 -112900,0.66301244,2.658156,,,,,,,,,,,,,,,,, -113000,0.65189034,2.6345663,,,,,,,,,,,,,,,,, -113100,0.6984388,2.6850421,,,,,,,,,,,,,,,,, -113200,0.64200574,2.6261322,,,,,,,,,,,,,,,,, -113300,0.70565647,2.664636,,,,,,,,,,,,,,,,, -113400,0.6625402,2.628398,,,,,,,,,,,,,,,,, -113500,0.6667883,2.6990786,,,,,,,,,,,,,,,,, -113600,0.6557849,2.6508236,,,,,,,,,,,,,,,,, -113700,0.6549825,2.6545446,,,,,,,,,,,,,,,,, -113800,0.6748792,2.6837552,,,,,,,,,,,,,,,,, -113900,0.6712679,2.7177303,,,,,,,,,,,,,,,,, -114000,0.66579455,2.573896,,,,,,,,,,,,,,,,, -114100,0.68039554,2.6809084,,,,,,,,,,,,,,,,, -114200,0.6546648,2.6476874,,,,,,,,,,,,,,,,, -114300,0.66449016,2.6751363,,,,,,,,,,,,,,,,, -114400,0.6590635,2.6257002,,,,,,,,,,,,,,,,, -114500,0.636567,2.612426,,,,,,,,,,,,,,,,, -114600,0.6665772,2.6440806,,,,,,,,,,,,,,,,, -114700,0.69660646,2.650956,,,,,,,,,,,,,,,,, -114800,0.67891353,2.65414,,,,,,,,,,,,,,,,, -114900,0.67742574,2.6036503,,,,,,,,,,,,,,,,, -115000,0.67005026,2.6157155,,,,,,,,,,,,,,,,, -115100,0.65383285,2.6077101,,,,,,,,,,,,,,,,, -115200,0.69984525,2.6022017,,,,,,,,,,,,,,,,, -115248,,,0.69832843542099,1.5016064643859863,35.53555504677267,0.693717360496521,1.5081539154052734,30.621397086007867,3000.0,0.70927894115448,1.4228955507278442,30.84711171958969,3003.0,40360.24406290054,65183.5163064003,40360.24406290054,24817.92786431313,1.6504244804382324,0.0 -115300,0.7169471,2.7123249,,,,,,,,,,,,,,,,, -115400,0.69329095,2.6540654,,,,,,,,,,,,,,,,, -115500,0.68787205,2.6486807,,,,,,,,,,,,,,,,, -115600,0.7134921,2.6494877,,,,,,,,,,,,,,,,, -115700,0.67555565,2.626847,,,,,,,,,,,,,,,,, -115800,0.6962353,2.6110594,,,,,,,,,,,,,,,,, -115900,0.66944355,2.69147,,,,,,,,,,,,,,,,, -116000,0.69871396,2.5765529,,,,,,,,,,,,,,,,, -116100,0.68077815,2.6006405,,,,,,,,,,,,,,,,, -116200,0.7085855,2.655516,,,,,,,,,,,,,,,,, -116300,0.6828673,2.633842,,,,,,,,,,,,,,,,, -116400,0.68517125,2.5793688,,,,,,,,,,,,,,,,, -116500,0.67601657,2.6483893,,,,,,,,,,,,,,,,, -116600,0.69783133,2.557034,,,,,,,,,,,,,,,,, -116700,0.7213045,2.6346717,,,,,,,,,,,,,,,,, -116800,0.69766,2.670346,,,,,,,,,,,,,,,,, -116900,0.6861703,2.685975,,,,,,,,,,,,,,,,, -117000,0.69266844,2.6810248,,,,,,,,,,,,,,,,, -117100,0.72273934,2.664562,,,,,,,,,,,,,,,,, -117200,0.7168865,2.6536539,,,,,,,,,,,,,,,,, -117300,0.72468555,2.5920687,,,,,,,,,,,,,,,,, -117400,0.7117956,2.6990108,,,,,,,,,,,,,,,,, -117500,0.6942996,2.5817754,,,,,,,,,,,,,,,,, -117600,0.7087283,2.595974,,,,,,,,,,,,,,,,, -117649,,,0.6946881413459778,1.516922116279602,35.48660590814062,0.6925890445709229,1.506855607032776,31.07297394971343,3000.0,0.7097670435905457,1.4213274717330933,30.92984988614016,3003.0,41200.15984940529,66529.20644688606,41200.15984940529,25323.58639740944,1.6914031505584717,0.0 -117700,0.7029101,2.656537,,,,,,,,,,,,,,,,, -117800,0.7001099,2.63289,,,,,,,,,,,,,,,,, -117900,0.72024083,2.6197264,,,,,,,,,,,,,,,,, -118000,0.7029961,2.5694072,,,,,,,,,,,,,,,,, -118100,0.74459314,2.7008028,,,,,,,,,,,,,,,,, -118200,0.72048783,2.6253998,,,,,,,,,,,,,,,,, -118300,0.7035351,2.651759,,,,,,,,,,,,,,,,, -118400,0.73057973,2.6308663,,,,,,,,,,,,,,,,, -118500,0.746598,2.6375847,,,,,,,,,,,,,,,,, -118600,0.7248506,2.6027122,,,,,,,,,,,,,,,,, -118700,0.7475158,2.688984,,,,,,,,,,,,,,,,, -118800,0.72475696,2.6245346,,,,,,,,,,,,,,,,, -118900,0.7189694,2.603588,,,,,,,,,,,,,,,,, -119000,0.7237539,2.6763988,,,,,,,,,,,,,,,,, -119100,0.7250754,2.6662238,,,,,,,,,,,,,,,,, -119200,0.7636131,2.6495178,,,,,,,,,,,,,,,,, -119300,0.71448916,2.611675,,,,,,,,,,,,,,,,, -119400,0.71099824,2.5915396,,,,,,,,,,,,,,,,, -119500,0.75694746,2.5734446,,,,,,,,,,,,,,,,, -119600,0.91639715,2.6007912,,,,,,,,,,,,,,,,, -119700,0.7134903,2.60877,,,,,,,,,,,,,,,,, -119800,0.737762,2.5767226,,,,,,,,,,,,,,,,, -119900,0.7405086,2.6405857,,,,,,,,,,,,,,,,, -120000,0.7397038,2.6115837,,,,,,,,,,,,,,,,, -120050,,,0.7077162861824036,1.4448779821395874,35.699938968051285,0.693221390247345,1.5050368309020996,30.88221735153581,3000.0,0.7109523415565491,1.4153249263763428,31.28327925826771,3003.0,42040.353865385056,67854.64235019684,42040.353865385056,25808.70779204369,1.7335550785064695,0.0 -120100,0.76977944,2.6454341,,,,,,,,,,,,,,,,, -120200,0.7216861,2.6356955,,,,,,,,,,,,,,,,, -120300,0.7218419,2.6412504,,,,,,,,,,,,,,,,, -120400,0.73870826,2.5744824,,,,,,,,,,,,,,,,, -120500,0.78756124,2.6576738,,,,,,,,,,,,,,,,, -120600,0.72851795,2.66256,,,,,,,,,,,,,,,,, -120700,0.7186808,2.579487,,,,,,,,,,,,,,,,, -120800,0.69118047,2.6099615,,,,,,,,,,,,,,,,, -120900,0.7500014,2.6060157,,,,,,,,,,,,,,,,, -121000,0.7239236,2.625853,,,,,,,,,,,,,,,,, -121100,0.75339764,2.6820545,,,,,,,,,,,,,,,,, -121200,0.71981865,2.645824,,,,,,,,,,,,,,,,, -121300,0.71210825,2.6381035,,,,,,,,,,,,,,,,, -121400,0.7328974,2.6072674,,,,,,,,,,,,,,,,, -121500,0.7405379,2.5985317,,,,,,,,,,,,,,,,, -121600,0.77350426,2.6102612,,,,,,,,,,,,,,,,, -121700,0.73101103,2.6050506,,,,,,,,,,,,,,,,, -121800,0.7717187,2.6370075,,,,,,,,,,,,,,,,, -121900,0.7622969,2.656632,,,,,,,,,,,,,,,,, -122000,0.7684124,2.5997736,,,,,,,,,,,,,,,,, -122100,0.7300491,2.546295,,,,,,,,,,,,,,,,, -122200,0.7509826,2.6326067,,,,,,,,,,,,,,,,, -122300,0.74229765,2.5681317,,,,,,,,,,,,,,,,, -122400,0.7697368,2.6370032,,,,,,,,,,,,,,,,, -122450,,,0.7041738033294678,1.4591567516326904,36.04297490759636,0.6942257285118103,1.5022872686386108,31.002648152503227,3000.0,0.7108477354049683,1.4121458530426023,31.20686190559411,3003.0,42880.47454190254,69184.54448390007,42880.47454190254,26298.369471549988,1.7748847007751465,0.0 -122500,0.7535972,2.5880613,,,,,,,,,,,,,,,,, -122600,0.74825567,2.633867,,,,,,,,,,,,,,,,, -122700,0.7238712,2.5544586,,,,,,,,,,,,,,,,, -122800,0.74472624,2.6519067,,,,,,,,,,,,,,,,, -122900,0.76341057,2.61628,,,,,,,,,,,,,,,,, -123000,0.77067655,2.6510952,,,,,,,,,,,,,,,,, -123100,0.7651439,2.560746,,,,,,,,,,,,,,,,, -123200,0.77542555,2.6568666,,,,,,,,,,,,,,,,, -123300,0.7582647,2.5866437,,,,,,,,,,,,,,,,, -123400,0.795741,2.6020803,,,,,,,,,,,,,,,,, -123500,0.76281536,2.6251154,,,,,,,,,,,,,,,,, -123600,0.7479954,2.6270828,,,,,,,,,,,,,,,,, -123700,0.77975905,2.6464036,,,,,,,,,,,,,,,,, -123800,0.7606339,2.5542033,,,,,,,,,,,,,,,,, -123900,0.7632922,2.612336,,,,,,,,,,,,,,,,, -124000,0.7745922,2.6066153,,,,,,,,,,,,,,,,, -124100,0.7479524,2.6261888,,,,,,,,,,,,,,,,, -124200,0.73859745,2.5536346,,,,,,,,,,,,,,,,, -124300,0.7620058,2.6303065,,,,,,,,,,,,,,,,, -124400,0.7258154,2.601488,,,,,,,,,,,,,,,,, -124500,0.72218853,2.5964446,,,,,,,,,,,,,,,,, -124600,0.76153034,2.568071,,,,,,,,,,,,,,,,, -124700,0.7630687,2.560258,,,,,,,,,,,,,,,,, -124800,0.7941944,2.645778,,,,,,,,,,,,,,,,, -124852,,,0.7027409076690674,1.4704632759094238,36.27220736720769,0.6941513419151306,1.5010319948196411,31.01403960420787,3000.0,0.7114055156707764,1.4117215871810913,31.294057267940783,3003.0,43720.612193107605,70509.50488901138,43720.612193107605,26783.07488465309,1.8175811767578125,0.0 -124900,0.76422423,2.5980477,,,,,,,,,,,,,,,,, -125000,0.7658137,2.5935357,,,,,,,,,,,,,,,,, -125100,0.8003266,2.6630597,,,,,,,,,,,,,,,,, -125200,0.75023407,2.5701551,,,,,,,,,,,,,,,,, -125300,0.74995637,2.5957673,,,,,,,,,,,,,,,,, -125400,0.80561507,2.659847,,,,,,,,,,,,,,,,, -125500,0.74646735,2.6281626,,,,,,,,,,,,,,,,, -125600,0.7584905,2.5947971,,,,,,,,,,,,,,,,, -125700,0.7686745,2.6074283,,,,,,,,,,,,,,,,, -125800,0.7458227,2.611261,,,,,,,,,,,,,,,,, -125900,0.7756706,2.503927,,,,,,,,,,,,,,,,, -126000,0.77830833,2.6270542,,,,,,,,,,,,,,,,, -126100,0.75337356,2.6149523,,,,,,,,,,,,,,,,, -126200,0.77070886,2.5391212,,,,,,,,,,,,,,,,, -126300,0.7704174,2.6447406,,,,,,,,,,,,,,,,, -126400,0.7463947,2.6130853,,,,,,,,,,,,,,,,, -126500,0.78059703,2.5273154,,,,,,,,,,,,,,,,, -126600,0.75385225,2.6167786,,,,,,,,,,,,,,,,, -126700,0.79367673,2.5378196,,,,,,,,,,,,,,,,, -126800,0.7856839,2.5866685,,,,,,,,,,,,,,,,, -126900,0.7502787,2.6528468,,,,,,,,,,,,,,,,, -127000,0.74822426,2.562553,,,,,,,,,,,,,,,,, -127100,0.7674596,2.6123126,,,,,,,,,,,,,,,,, -127200,0.75042117,2.6225371,,,,,,,,,,,,,,,,, -127254,,,0.7076531052589417,1.4463974237442017,36.59637497627496,0.6947712898254395,1.4987682104110718,31.142788279699538,3000.0,0.7116495370864868,1.409090876579285,31.268713611605445,3003.0,44560.81832766533,71836.32944989204,44560.81832766533,27269.56863617897,1.865821361541748,0.0 -127300,0.79000133,2.5477047,,,,,,,,,,,,,,,,, -127400,0.76789576,2.5646458,,,,,,,,,,,,,,,,, -127500,0.7666229,2.6168206,,,,,,,,,,,,,,,,, -127600,0.75915337,2.643883,,,,,,,,,,,,,,,,, -127700,0.75471216,2.6550672,,,,,,,,,,,,,,,,, -127800,0.79224795,2.5920494,,,,,,,,,,,,,,,,, -127900,0.76084834,2.6297386,,,,,,,,,,,,,,,,, -128000,0.7916762,2.5813167,,,,,,,,,,,,,,,,, -128100,0.7609367,2.6832623,,,,,,,,,,,,,,,,, -128200,0.7454855,2.613519,,,,,,,,,,,,,,,,, -128300,0.76909983,2.6198347,,,,,,,,,,,,,,,,, -128400,0.7754774,2.5539784,,,,,,,,,,,,,,,,, -128500,0.7640288,2.5560243,,,,,,,,,,,,,,,,, -128600,0.7587619,2.5923662,,,,,,,,,,,,,,,,, -128700,0.7471555,2.5979977,,,,,,,,,,,,,,,,, -128800,0.77983034,2.6215217,,,,,,,,,,,,,,,,, -128900,0.7576534,2.6047642,,,,,,,,,,,,,,,,, -129000,0.76295185,2.588451,,,,,,,,,,,,,,,,, -129100,0.76367193,2.6497233,,,,,,,,,,,,,,,,, -129200,0.77573335,2.622287,,,,,,,,,,,,,,,,, -129300,0.72569776,2.658707,,,,,,,,,,,,,,,,, -129400,0.7873101,2.6445763,,,,,,,,,,,,,,,,, -129500,0.761438,2.5677457,,,,,,,,,,,,,,,,, -129600,0.7636316,2.6427193,,,,,,,,,,,,,,,,, -129656,,,0.7090178728103638,1.4380403757095337,36.79773448540247,0.6948084831237793,1.4995105266571045,30.993740811222963,3000.0,0.7116960287094116,1.4084151983261108,31.130467907561723,3003.0,45400.99208855629,73169.57006430626,45400.99208855629,27762.5166027546,1.908890962600708,0.0 -129700,0.77385217,2.6600685,,,,,,,,,,,,,,,,, -129800,0.76598334,2.5898786,,,,,,,,,,,,,,,,, -129900,0.76385707,2.581184,,,,,,,,,,,,,,,,, -130000,0.75501096,2.5981445,,,,,,,,,,,,,,,,, -130100,0.7645582,2.6317832,,,,,,,,,,,,,,,,, -130200,0.7580004,2.6090472,,,,,,,,,,,,,,,,, -130300,0.75614583,2.5871813,,,,,,,,,,,,,,,,, -130400,0.7930359,2.5367424,,,,,,,,,,,,,,,,, -130500,0.75372,2.5921385,,,,,,,,,,,,,,,,, -130600,0.7923927,2.5339398,,,,,,,,,,,,,,,,, -130700,0.76468545,2.6018806,,,,,,,,,,,,,,,,, -130800,0.7790236,2.6346455,,,,,,,,,,,,,,,,, -130900,0.7592385,2.5191033,,,,,,,,,,,,,,,,, -131000,0.7646776,2.5278356,,,,,,,,,,,,,,,,, -131100,0.7485796,2.538356,,,,,,,,,,,,,,,,, -131200,0.7651307,2.5954356,,,,,,,,,,,,,,,,, -131300,0.7535169,2.5996494,,,,,,,,,,,,,,,,, -131400,0.75107396,2.6153142,,,,,,,,,,,,,,,,, -131500,0.75814354,2.618488,,,,,,,,,,,,,,,,, -131600,0.7532807,2.5790617,,,,,,,,,,,,,,,,, -131700,0.76037663,2.596194,,,,,,,,,,,,,,,,, -131800,0.7772366,2.6330948,,,,,,,,,,,,,,,,, -131900,0.73807144,2.5638695,,,,,,,,,,,,,,,,, -132000,0.74193466,2.5698457,,,,,,,,,,,,,,,,, -132058,,,0.7075369954109192,1.453471302986145,36.62875430446796,0.6945481300354004,1.4992551803588867,31.076465832705026,3000.0,0.7120911478996277,1.4084433317184448,31.249781981597465,3003.0,46240.98077011109,74491.25281620026,46240.98077011109,28244.09170150757,1.952728033065796,0.0 -132100,0.76791835,2.6219995,,,,,,,,,,,,,,,,, -132200,0.7669015,2.6212785,,,,,,,,,,,,,,,,, -132300,0.74497753,2.6074312,,,,,,,,,,,,,,,,, -132400,0.7656326,2.6392322,,,,,,,,,,,,,,,,, -132500,0.74030215,2.5422883,,,,,,,,,,,,,,,,, -132600,0.755969,2.6145077,,,,,,,,,,,,,,,,, -132700,0.76389915,2.642939,,,,,,,,,,,,,,,,, -132800,0.73718566,2.5539534,,,,,,,,,,,,,,,,, -132900,0.7371659,2.5630503,,,,,,,,,,,,,,,,, -133000,0.7529944,2.5897288,,,,,,,,,,,,,,,,, -133100,0.79299337,2.5626788,,,,,,,,,,,,,,,,, -133200,0.77211314,2.5252943,,,,,,,,,,,,,,,,, -133300,0.7707744,2.6668828,,,,,,,,,,,,,,,,, -133333,,,0.7088993191719055,1.4459348917007446,36.586798307686536,0.6944240927696228,1.499434232711792,31.059765499889604,3000.0,0.7120795249938965,1.408614993095398,31.198964185453555,3003.0,46686.89723396301,75422.7423479557,46686.89723396301,28729.57975912094,1.9961745738983157,0.0 -133333,,,,,,,,,,,,,,46686.89723396301,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/eval_measurements.csv deleted file mode 100644 index ddf86a796..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -838.8422293663025,0.0,28.51171493530273,1,0,28.51171493530273,0.0007088489946909,0.0,11.191027641296388,3003,867.353991985321,0.0006397879915311,0.0,11.187285423278809,0.0004835649742744,0.0,11.190281867980955,3000 -1419.8872134685516,0.0192539691925048,868.5572006702423,2396,0,868.5572006702423,0.3910173773765564,8.25107154164634,4.271887302398682,3003,2288.539109945297,0.4204213917255401,14.03031101364103,3.955871820449829,0.408426433801651,9.933766386758254,4.070903778076172,3000 -1907.1950154304504,0.0457632541656494,1708.5745656490326,4791,0,1708.5745656490326,0.5451164841651917,18.830678245727768,2.857901334762573,3003,3615.966960668564,0.5435932278633118,24.044350532229807,2.867501735687256,0.5418655872344971,20.27386090307613,2.839674472808838,3000 -2371.394764184952,0.0773811340332031,2548.571653842926,7188,0,2548.571653842926,0.5944570302963257,21.934132082215903,2.381966590881348,3003,4920.271017313004,0.5837926268577576,26.822264155984897,2.4588887691497803,0.5924786925315857,23.340748577255788,2.391460657119751,3000 -2810.153936624527,0.1032743453979492,3388.733800888061,9587,0,3388.733800888061,0.6207309365272522,23.86926771835117,2.1605913639068604,3003,6199.2928557395935,0.5973837971687317,28.25564876003111,2.325868606567383,0.6162601709365845,25.123763093156985,2.1848175525665283,3000 -3298.432733297348,0.1302118301391601,4228.941308021545,11986,0,4228.941308021545,0.638847291469574,25.131617057030173,2.0232722759246826,3003,7527.883685111999,0.6148414611816406,29.49347492495337,2.191822528839112,0.629031240940094,26.15469182732652,2.0681145191192627,3000 -3745.225136041641,0.1623759269714355,5068.914443492889,14384,0,5068.914443492889,0.6496427059173584,26.02713434983809,1.9267456531524656,3003,8814.757400751114,0.6251729726791382,30.452753434872445,2.1058154106140137,0.6403516530990601,26.710132243429697,1.9786962270736688,3000 -4193.632272481918,0.1892008781433105,5909.078606367111,16782,0,5909.078606367111,0.658288300037384,26.765356617369413,1.8623210191726685,3003,10103.429570436478,0.6332809329032898,30.834592006989155,2.025113582611084,0.6474562883377075,27.531752272070463,1.917392373085022,3000 -4647.600977182388,0.220649242401123,6749.185294866562,19181,0,6749.185294866562,0.6626227498054504,27.056916774320545,1.812300086021424,3003,11397.61303448677,0.6485809087753296,31.71264530159373,1.9150187969207764,0.6535567045211792,27.791618921827165,1.8694406747817995,3000 -5112.02565741539,0.2493281364440918,7589.185811042786,21580,0,7589.185811042786,0.6678287386894226,27.37082354685312,1.778388500213623,3003,12702.14163517952,0.6398290395736694,31.557177824770733,1.9732670783996584,0.6579831838607788,28.18313877713356,1.8404289484024048,3000 -5562.247340202332,0.2781414985656738,8429.109383583069,23978,0,8429.109383583069,0.6701877117156982,27.80662204576,1.7438005208969116,3003,13992.390851259232,0.6370953321456909,31.35145588205602,1.9664952754974363,0.6594586372375488,28.52456600789127,1.80987560749054,3000 -6028.958881616592,0.308734655380249,9269.108746528624,26376,0,9269.108746528624,0.6761141419410706,28.11392985066348,1.732893466949463,3003,15299.209090471268,0.6522204279899597,31.961837474519832,1.8794294595718384,0.6629427671432495,28.61563434415713,1.797146201133728,3000 -6510.161110162735,0.3377220630645752,10109.169246673584,28774,0,10109.169246673584,0.6750682592391968,27.884347883711857,1.7232335805892944,3003,16620.57497549057,0.648067057132721,31.576812118177703,1.904816508293152,0.6638355255126953,28.26089184313604,1.7895036935806274,3000 -7003.654671907425,0.3678188323974609,10949.117016077042,31172,0,10949.117016077042,0.6779385209083557,28.39342145211166,1.6976670026779177,3003,17954.123304367065,0.6439923644065857,31.66790219778987,1.927431583404541,0.6673568487167358,28.57217885812917,1.7629570960998535,3000 -7469.372006177902,0.3968191146850586,11789.085807323456,33571,0,11789.085807323456,0.6815641522407532,28.767894892189727,1.701885461807251,3003,19259.912284851074,0.6535011529922485,31.83566635295377,1.888905644416809,0.6689687371253967,29.0990490788574,1.7727051973342896,3000 -7951.277729272842,0.4293038845062256,12629.10321187973,35970,0,12629.10321187973,0.6820057034492493,28.82065666908729,1.6754212379455566,3003,20581.94299149513,0.6479287147521973,32.29340009964613,1.900919795036316,0.6697747111320496,29.200182035206225,1.7474881410598757,3000 -8408.473896980286,0.4591190814971924,13469.029118299484,38367,0,13469.029118299484,0.6850618720054626,29.158325918477285,1.644993782043457,3003,21879.16938900948,0.66447913646698,33.572362901384416,1.7844882011413574,0.6720189452171326,29.411130481189385,1.7210848331451416,3000 -8879.810878276825,0.489011287689209,14309.25852894783,40765,0,14309.25852894783,0.685631275177002,29.24589113563988,1.6500742435455322,3003,23190.842586278915,0.6568623781204224,32.29865399133282,1.846411228179932,0.6735936403274536,29.367416114376645,1.7242670059204102,3000 -9380.485354423525,0.5209271907806396,15149.16226387024,43162,0,15149.16226387024,0.6875603199005127,28.950136975521986,1.6370383501052856,3003,24531.530041217804,0.6540481448173523,31.76641579588908,1.848538517951965,0.674411952495575,29.176469017738977,1.7126529216766355,3000 -9989.577900886536,0.5572404861450195,15989.111895561218,45563,0,15989.111895561218,0.3135901391506195,0.1555851035313861,4.377139091491699,3003,25980.68439888954,0.3642924129962921,1.166315163018871,3.769196033477783,0.324819266796112,0.1726707275755886,4.211953163146973,3000 -10442.567687034609,0.5890069007873535,16829.141446828842,47961,0,16829.141446828842,0.6885131597518921,29.155636425634672,1.6407841444015503,3003,27273.812072753903,0.6598178744316101,32.72687596383734,1.826006293296814,0.6744863390922546,29.72216502720672,1.7195638418197632,3000 -11092.072012662888,0.622807502746582,17669.30721282959,50361,0,17669.30721282959,0.6898379325866699,29.56486532290584,1.6269718408584597,3003,28763.591370821,0.6838412880897522,33.974547470472466,1.6663483381271362,0.6764330267906189,29.85961521752938,1.7032217979431152,3000 -11646.070712327955,0.6562457084655762,18509.54104423523,52761,0,18509.54104423523,0.6924641132354736,29.45363869235563,1.6087028980255127,3003,30157.93383049965,0.6685400009155273,32.919607532370215,1.7587262392044067,0.677734911441803,29.581894504613768,1.6866081953048706,3000 -12118.026804924011,0.6894242763519287,19349.585392713547,55160,0,19349.585392713547,0.6906862258911133,29.408274314513523,1.611272215843201,3003,31470.043719291687,0.6608024835586548,32.987301429211975,1.809146761894226,0.6783052682876587,30.181948989674293,1.6888394355773926,3000 -12565.69912815094,0.722074031829834,20189.579274892807,57559,0,20189.579274892807,0.6937656402587891,29.84514057531192,1.5925296545028689,3003,32757.81765937805,0.6720222234725952,33.59649108157888,1.7211450338363647,0.6793467998504639,29.910391819317304,1.676062822341919,3000 -13053.144666194916,0.7552039623260498,21029.69254183769,59957,0,21029.69254183769,0.6962756514549255,29.90336671635429,1.587439775466919,3003,34085.48671579361,0.6659659147262573,33.462413155122306,1.7736477851867676,0.6806362867355347,30.146721231820003,1.6755925416946411,3000 -13501.640924453735,0.7891860008239746,21869.68085551262,62356,0,21869.68085551262,0.6968218088150024,29.91820395719109,1.5876115560531616,3003,35374.082418203354,0.6661973595619202,33.18605866469831,1.785617709159851,0.6803635358810425,29.76717653235124,1.675857663154602,3000 -13985.39007115364,0.8220915794372559,22709.793394088745,64756,0,22709.793394088745,0.696577787399292,30.02749926620833,1.567109227180481,3003,36698.05253005028,0.6719292402267456,33.31623273945491,1.71939218044281,0.6818762421607971,30.07580176805558,1.6507298946380615,3000 -14433.178615570068,0.8567020893096924,23549.83774733544,67155,0,23549.83774733544,0.6964499354362488,29.772931959805444,1.5692729949951172,3003,37985.99623680115,0.6676704287528992,32.8873355357405,1.753592014312744,0.6813802719116211,29.90753576601113,1.6592007875442505,3000 -14926.024219036102,0.891709566116333,24389.92008113861,69554,0,24389.92008113861,0.6994596719741821,30.31156286628029,1.5576194524765017,3003,39319.03587079048,0.6829418540000916,34.84525919786184,1.6505063772201538,0.6819506287574768,30.26084758035537,1.6456681489944458,3000 -15408.027456998823,0.9275662899017334,25229.987417936325,71953,0,25229.987417936325,0.6987624168395996,30.00881668864493,1.561434030532837,3003,40641.21749544144,0.6751496195793152,34.044371473527406,1.7136439085006714,0.683302104473114,30.32466585747712,1.647717833518982,3000 -15889.227751493454,0.9696736335754396,26069.960821390152,74352,0,26069.960821390152,0.6999825835227966,30.17969249679205,1.5527477264404297,3003,41962.50837039948,0.6767979860305786,33.86340278617494,1.709175705909729,0.6831409335136414,30.18935325973564,1.6428673267364502,3000 -16368.497725009918,1.0062716007232666,26910.131754636765,76750,0,26910.131754636765,0.6988554000854492,30.020957346602792,1.5575097799301147,3003,43282.06278991699,0.6845331788063049,34.86974028404848,1.6683518886566162,0.683984100818634,30.00869039508164,1.6397377252578735,3000 -16835.83377790451,1.0420207977294922,27750.265002012253,79149,0,27750.265002012253,0.7008192539215088,30.329277062653787,1.5451098680496216,3003,44589.64325428009,0.6779137253761292,33.970007069879514,1.6993038654327393,0.684653639793396,30.250622985921563,1.6361262798309326,3000 -17292.853385925293,1.0856754779815674,28590.42280459404,81549,0,28590.42280459404,0.701830267906189,30.135994300813103,1.5316359996795654,3003,45886.93843173981,0.7079644799232483,36.10748172508141,1.5214544534683228,0.6850627660751343,30.422390302489923,1.627657413482666,3000 -17747.714488744736,1.1382238864898682,29430.36559510231,83947,0,29430.36559510231,0.7033990025520325,30.45241823104937,1.5269720554351809,3003,47181.868070364,0.6867976188659668,34.510879718808624,1.6339534521102903,0.6870218515396118,30.32294582121328,1.6197893619537354,3000 -18245.310992002487,1.1753811836242676,30270.410665035248,86346,0,30270.410665035248,0.7028993368148804,30.48136215527729,1.5305290222167969,3003,48519.62218332291,0.6832473278045654,34.83272741944827,1.6666216850280762,0.6869474649429321,30.546641891474536,1.6207863092422483,3000 -18713.81881380081,1.2181808948516846,31110.53629755973,88746,0,31110.53629755973,0.7047004699707031,30.563865944400103,1.5223288536071775,3003,49828.37257575989,0.7003049254417419,35.16525340135614,1.5541876554489136,0.6866498589515686,30.46771344521242,1.6179324388504028,3000 -19169.279168844223,1.2637646198272705,31950.433703422543,91145,0,31950.433703422543,0.7067689299583435,30.68588993681621,1.5179078578948977,3003,51123.85071182251,0.6862742900848389,34.99582119196632,1.6273705959320068,0.6887701153755188,30.406371078037647,1.612417221069336,3000 -19610.351407289505,1.3041329383850098,32790.42817759514,93544,0,32790.42817759514,0.7058160901069641,30.48390357612902,1.51614511013031,3003,52405.03318023682,0.6899000406265259,34.90819228231661,1.6128188371658323,0.6882493495941162,30.37207243465497,1.6095622777938845,3000 -20073.95977115631,1.3421132564544678,33630.38622021675,95943,0,33630.38622021675,0.7060717344284058,30.849107356658006,1.5157976150512695,3003,53708.71432638168,0.697588324546814,35.739230605844355,1.5702698230743408,0.6895760893821716,30.57486585549891,1.610058069229126,3000 -20543.905728816982,1.3811841011047363,34470.52443647385,98343,0,34470.52443647385,0.7058973908424377,30.59627386107932,1.5149260759353638,3003,55018.911603450775,0.6986111998558044,35.27395572303037,1.5686570405960083,0.6885965466499329,30.290472276833565,1.6070998907089231,3000 -21003.31357169152,1.4205613136291504,35310.65348362923,100742,0,35310.65348362923,0.7069548964500427,30.70333007489312,1.5113070011138916,3003,56318.56501626968,0.710821807384491,37.11803812727856,1.5044238567352295,0.6893652677536011,30.65314764916486,1.6094584465026855,3000 -21465.7024269104,1.4596493244171145,36150.85769796372,103142,0,36150.85769796372,0.7081401348114014,30.88696074093417,1.5082166194915771,3003,57621.27230811119,0.7058877348899841,36.64237829046024,1.533977508544922,0.6901340484619141,30.61682211553897,1.6044197082519531,3000 -21925.09859728813,1.5006978511810305,36991.28224873543,105542,0,36991.28224873543,0.7084190249443054,30.523381697580785,1.5058610439300537,3003,58921.207073926926,0.7012184858322144,36.010873962839284,1.5534294843673706,0.6888568997383118,30.25613399959102,1.603817582130432,3000 -22370.46321105957,1.5430285930633545,37831.47907757759,107941,0,37831.47907757759,0.7079658508300781,30.685208275329504,1.504045844078064,3003,60206.88630771637,0.7100006937980652,36.99780640058015,1.5051331520080566,0.689811646938324,30.53258891829982,1.60414719581604,3000 -22827.366498708725,1.5851430892944336,38671.50674414635,110340,0,38671.50674414635,0.7091511487960815,30.69295551513549,1.5065263509750366,3003,61503.93488526344,0.7088562846183777,35.960513350985146,1.5129475593566897,0.6904191970825195,30.43087989291936,1.6018781661987305,3000 -23285.7987074852,1.6249699592590332,39511.70456314087,112740,0,39511.70456314087,0.7088257670402527,30.616537234812025,1.500476598739624,3003,62802.67998576164,0.7243646383285522,37.83614617089411,1.4386584758758545,0.6897496581077576,30.64392620751296,1.5978442430496216,3000 -23739.756477594376,1.674572467803955,40351.63377261162,115139,0,40351.63377261162,0.7087792754173279,30.776656235656343,1.4993302822113037,3003,64096.69275712967,0.7164209485054016,37.26717590462045,1.4740687608718872,0.6905679702758789,30.91084348148604,1.598555564880371,3000 -24188.48369383812,1.718095302581787,41191.59611129761,117538,0,41191.59611129761,0.7088141441345215,30.903441013207743,1.501481056213379,3003,65385.50169849396,0.7153463959693909,36.84388016626823,1.4766026735305786,0.6916590929031372,30.74155176517884,1.6009763479232788,3000 -24656.691277503967,1.7594947814941406,42031.71455454826,119937,0,42031.71455454826,0.7088955044746399,30.883720879674414,1.4994897842407229,3003,66693.94711279869,0.7232077717781067,37.65100512157615,1.438293695449829,0.6920310854911804,30.75141096290305,1.6015228033065796,3000 -25128.7705988884,1.8893790245056152,42871.72847676277,122336,0,42871.72847676277,0.70961594581604,30.907040212287704,1.4969851970672607,3003,68006.24787807465,0.7199275493621826,37.4064410559461,1.4584884643554688,0.6925642490386963,30.63594776497589,1.5980294942855835,3000 -25586.507713079453,1.9442377090454104,43711.629980802536,124734,0,43711.629980802536,0.7094067931175232,30.645121997646555,1.498136281967163,3003,69304.01886892319,0.7197436094284058,37.86308040490159,1.4523401260375977,0.6918451189994812,30.577418563951102,1.6004658937454224,3000 -26056.940609931946,1.98885178565979,44551.60100340843,127133,0,44551.60100340843,0.70961594581604,30.935873794008984,1.49761700630188,3003,70614.54214811325,0.723660409450531,37.82017963906209,1.441123127937317,0.6915971040725708,30.58672769907038,1.5994668006896973,3000 -26523.145493984222,2.03778076171875,45391.557307481766,129532,0,45391.557307481766,0.709732174873352,30.904820565100355,1.4976327419281006,3003,71920.82846522331,0.7209928035736084,37.57740041402509,1.453560709953308,0.6917583346366882,30.57694225255877,1.5988129377365112,3000 -26983.98964881897,2.0844168663024902,46231.56835961342,131931,0,46231.56835961342,0.7099761962890625,30.837799961632037,1.4979794025421145,3003,73221.8060324192,0.7224260568618774,37.93193093486411,1.4456599950790403,0.6920434832572937,30.504024003063837,1.5994871854782104,3000 -27447.8187623024,2.1351213455200195,46722.17216897011,133333,0,46722.17216897011,0.7098832130432129,30.838390296957208,1.4978277683258057,3003,74176.3343679905,0.7234098315238953,37.50861974519182,1.439003586769104,0.6921550631523132,30.494101371783113,1.5993489027023315,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/measurements.csv deleted file mode 100644 index 5c4a71193..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.9793863,11.153433,,,,,,,,,,,,,,,,, -1,,,0.0006397879915311,11.187285423278809,0.0,0.0004835649742744,11.190281867980955,0.0,3000.0,0.0007088489946909,11.191027641296388,0.0,3003.0,28.51171493530273,867.353991985321,28.51171493530273,838.8422293663025,0.0,0.0 -100,0.28308716,9.076077,,,,,,,,,,,,,,,,, -200,0.20622914,8.751754,,,,,,,,,,,,,,,,, -300,0.85481673,8.385589,,,,,,,,,,,,,,,,, -400,0.70552874,8.078027,,,,,,,,,,,,,,,,, -500,0.8811238,7.8795543,,,,,,,,,,,,,,,,, -600,0.8429286,7.7052493,,,,,,,,,,,,,,,,, -700,0.56609964,7.503654,,,,,,,,,,,,,,,,, -800,0.5068932,7.3083787,,,,,,,,,,,,,,,,, -900,0.8021007,7.2056375,,,,,,,,,,,,,,,,, -1000,0.56624043,6.953925,,,,,,,,,,,,,,,,, -1100,0.6298307,6.843849,,,,,,,,,,,,,,,,, -1200,0.513474,6.6927433,,,,,,,,,,,,,,,,, -1300,0.7590303,6.647438,,,,,,,,,,,,,,,,, -1400,0.6347119,6.5417304,,,,,,,,,,,,,,,,, -1500,0.5585836,6.4066234,,,,,,,,,,,,,,,,, -1600,0.6429976,6.3130355,,,,,,,,,,,,,,,,, -1700,0.6418276,6.2541823,,,,,,,,,,,,,,,,, -1800,0.72755104,6.132977,,,,,,,,,,,,,,,,, -1900,0.59484565,6.031875,,,,,,,,,,,,,,,,, -2000,0.7339063,5.9784026,,,,,,,,,,,,,,,,, -2100,0.67892367,5.789961,,,,,,,,,,,,,,,,, -2200,0.5377997,5.725469,,,,,,,,,,,,,,,,, -2300,0.6752447,5.6639843,,,,,,,,,,,,,,,,, -2396,,,0.4204213917255401,3.955871820449829,14.03031101364103,0.408426433801651,4.070903778076172,9.933766386758254,3000.0,0.3910173773765564,4.271887302398682,8.25107154164634,3003.0,868.5572006702423,2288.539109945297,868.5572006702423,1419.8872134685516,0.0192539691925048,0.0 -2400,0.57645226,5.581534,,,,,,,,,,,,,,,,, -2500,0.54556096,5.5360904,,,,,,,,,,,,,,,,, -2600,0.6013534,5.44068,,,,,,,,,,,,,,,,, -2700,0.7028097,5.449978,,,,,,,,,,,,,,,,, -2800,0.520051,5.2542396,,,,,,,,,,,,,,,,, -2900,0.5108324,5.313332,,,,,,,,,,,,,,,,, -3000,0.6266188,5.264731,,,,,,,,,,,,,,,,, -3100,0.5019331,5.1522555,,,,,,,,,,,,,,,,, -3200,0.69292164,5.097206,,,,,,,,,,,,,,,,, -3300,0.4918401,5.0078325,,,,,,,,,,,,,,,,, -3400,0.54398423,5.053867,,,,,,,,,,,,,,,,, -3500,0.505278,4.9402504,,,,,,,,,,,,,,,,, -3600,0.441977,4.9529567,,,,,,,,,,,,,,,,, -3700,0.43984687,4.865982,,,,,,,,,,,,,,,,, -3800,0.47163188,4.918844,,,,,,,,,,,,,,,,, -3900,0.50097835,4.9720006,,,,,,,,,,,,,,,,, -4000,0.42881155,4.8833966,,,,,,,,,,,,,,,,, -4100,0.44617587,4.8410177,,,,,,,,,,,,,,,,, -4200,0.40296322,4.8707843,,,,,,,,,,,,,,,,, -4300,0.39647388,4.831126,,,,,,,,,,,,,,,,, -4400,0.41150352,4.837865,,,,,,,,,,,,,,,,, -4500,0.4020457,4.744564,,,,,,,,,,,,,,,,, -4600,0.38454765,4.82504,,,,,,,,,,,,,,,,, -4700,0.37995505,4.720276,,,,,,,,,,,,,,,,, -4791,,,0.5435932278633118,2.867501735687256,24.044350532229807,0.5418655872344971,2.839674472808838,20.27386090307613,3000.0,0.5451164841651917,2.857901334762573,18.830678245727768,3003.0,1708.5745656490326,3615.966960668564,1708.5745656490326,1907.1950154304504,0.0457632541656494,0.0 -4800,0.41858727,4.6902294,,,,,,,,,,,,,,,,, -4900,0.37659895,4.7100315,,,,,,,,,,,,,,,,, -5000,0.35276368,4.6518197,,,,,,,,,,,,,,,,, -5100,0.3721263,4.680577,,,,,,,,,,,,,,,,, -5200,0.3522225,4.6502724,,,,,,,,,,,,,,,,, -5300,0.35199064,4.6242046,,,,,,,,,,,,,,,,, -5400,0.33797598,4.6839013,,,,,,,,,,,,,,,,, -5500,0.36676785,4.623737,,,,,,,,,,,,,,,,, -5600,0.3854676,4.677263,,,,,,,,,,,,,,,,, -5700,0.34671107,4.570304,,,,,,,,,,,,,,,,, -5800,0.31157255,4.5265512,,,,,,,,,,,,,,,,, -5900,0.33161804,4.6229897,,,,,,,,,,,,,,,,, -6000,0.3238063,4.4889708,,,,,,,,,,,,,,,,, -6100,0.29906103,4.5373163,,,,,,,,,,,,,,,,, -6200,0.29719618,4.6099405,,,,,,,,,,,,,,,,, -6300,0.29022473,4.537184,,,,,,,,,,,,,,,,, -6400,0.31707415,4.562203,,,,,,,,,,,,,,,,, -6500,0.30601478,4.554979,,,,,,,,,,,,,,,,, -6600,0.2804476,4.5399065,,,,,,,,,,,,,,,,, -6700,0.26094544,4.549295,,,,,,,,,,,,,,,,, -6800,0.31309348,4.4838777,,,,,,,,,,,,,,,,, -6900,0.25017318,4.4940424,,,,,,,,,,,,,,,,, -7000,0.2589379,4.46261,,,,,,,,,,,,,,,,, -7100,0.27188975,4.4607034,,,,,,,,,,,,,,,,, -7188,,,0.5837926268577576,2.4588887691497803,26.822264155984897,0.5924786925315857,2.391460657119751,23.340748577255788,3000.0,0.5944570302963257,2.381966590881348,21.934132082215903,3003.0,2548.571653842926,4920.271017313004,2548.571653842926,2371.394764184952,0.0773811340332031,0.0 -7200,0.23791309,4.446186,,,,,,,,,,,,,,,,, -7300,0.23529807,4.407566,,,,,,,,,,,,,,,,, -7400,0.24056777,4.3735104,,,,,,,,,,,,,,,,, -7500,0.2709292,4.4180727,,,,,,,,,,,,,,,,, -7600,0.23474102,4.4008336,,,,,,,,,,,,,,,,, -7700,0.2475145,4.3882875,,,,,,,,,,,,,,,,, -7800,0.21958897,4.383618,,,,,,,,,,,,,,,,, -7900,0.24165915,4.3592596,,,,,,,,,,,,,,,,, -8000,0.21648507,4.333078,,,,,,,,,,,,,,,,, -8100,0.21393782,4.36673,,,,,,,,,,,,,,,,, -8200,0.21392396,4.293405,,,,,,,,,,,,,,,,, -8300,0.21744365,4.4180264,,,,,,,,,,,,,,,,, -8400,0.20725209,4.327676,,,,,,,,,,,,,,,,, -8500,0.2018986,4.323656,,,,,,,,,,,,,,,,, -8600,0.2295744,4.343627,,,,,,,,,,,,,,,,, -8700,0.21719332,4.3447957,,,,,,,,,,,,,,,,, -8800,0.2240338,4.289812,,,,,,,,,,,,,,,,, -8900,0.26129773,4.3422465,,,,,,,,,,,,,,,,, -9000,0.19522627,4.3563952,,,,,,,,,,,,,,,,, -9100,0.22207665,4.2863255,,,,,,,,,,,,,,,,, -9200,0.1867775,4.375924,,,,,,,,,,,,,,,,, -9300,0.1924791,4.260812,,,,,,,,,,,,,,,,, -9400,0.19721042,4.2858243,,,,,,,,,,,,,,,,, -9500,0.20110308,4.289108,,,,,,,,,,,,,,,,, -9587,,,0.5973837971687317,2.325868606567383,28.25564876003111,0.6162601709365845,2.1848175525665283,25.123763093156985,3000.0,0.6207309365272522,2.1605913639068604,23.86926771835117,3003.0,3388.733800888061,6199.2928557395935,3388.733800888061,2810.153936624527,0.1032743453979492,0.0 -9600,0.20935374,4.2855,,,,,,,,,,,,,,,,, -9700,0.20172694,4.222354,,,,,,,,,,,,,,,,, -9800,0.18916766,4.2630506,,,,,,,,,,,,,,,,, -9900,0.17844597,4.236443,,,,,,,,,,,,,,,,, -10000,0.18910426,4.2198935,,,,,,,,,,,,,,,,, -10100,0.1850432,4.221271,,,,,,,,,,,,,,,,, -10200,0.21582696,4.246985,,,,,,,,,,,,,,,,, -10300,0.17452951,4.2492256,,,,,,,,,,,,,,,,, -10400,0.17987071,4.19552,,,,,,,,,,,,,,,,, -10500,0.21474041,4.2984447,,,,,,,,,,,,,,,,, -10600,0.17413609,4.2501087,,,,,,,,,,,,,,,,, -10700,0.19836168,4.342918,,,,,,,,,,,,,,,,, -10800,0.18016534,4.263479,,,,,,,,,,,,,,,,, -10900,0.1702199,4.208224,,,,,,,,,,,,,,,,, -11000,0.17915182,4.275029,,,,,,,,,,,,,,,,, -11100,0.16405743,4.19266,,,,,,,,,,,,,,,,, -11200,0.18615547,4.2592807,,,,,,,,,,,,,,,,, -11300,0.17960429,4.1850343,,,,,,,,,,,,,,,,, -11400,0.18444468,4.1758413,,,,,,,,,,,,,,,,, -11500,0.17989159,4.2091246,,,,,,,,,,,,,,,,, -11600,0.18810737,4.198193,,,,,,,,,,,,,,,,, -11700,0.20731646,4.3227224,,,,,,,,,,,,,,,,, -11800,0.18943344,4.1798077,,,,,,,,,,,,,,,,, -11900,0.17411307,4.2159495,,,,,,,,,,,,,,,,, -11986,,,0.6148414611816406,2.191822528839112,29.49347492495337,0.629031240940094,2.0681145191192627,26.15469182732652,3000.0,0.638847291469574,2.0232722759246826,25.131617057030173,3003.0,4228.941308021545,7527.883685111999,4228.941308021545,3298.432733297348,0.1302118301391601,0.0 -12000,0.18071151,4.233079,,,,,,,,,,,,,,,,, -12100,0.164395,4.1195507,,,,,,,,,,,,,,,,, -12200,0.19772571,4.169731,,,,,,,,,,,,,,,,, -12300,0.17200671,4.205206,,,,,,,,,,,,,,,,, -12400,0.18225425,4.160146,,,,,,,,,,,,,,,,, -12500,0.18193926,4.1807656,,,,,,,,,,,,,,,,, -12600,0.1714537,4.1697474,,,,,,,,,,,,,,,,, -12700,0.17013726,4.1811485,,,,,,,,,,,,,,,,, -12800,0.15881142,4.126411,,,,,,,,,,,,,,,,, -12900,0.15558934,4.183957,,,,,,,,,,,,,,,,, -13000,0.19292822,4.2200956,,,,,,,,,,,,,,,,, -13100,0.16955844,4.183307,,,,,,,,,,,,,,,,, -13200,0.1738953,4.211741,,,,,,,,,,,,,,,,, -13300,0.15601146,4.128059,,,,,,,,,,,,,,,,, -13400,0.1636788,4.1753416,,,,,,,,,,,,,,,,, -13500,0.16436835,4.135151,,,,,,,,,,,,,,,,, -13600,0.15949892,4.1143856,,,,,,,,,,,,,,,,, -13700,0.17042401,4.1470833,,,,,,,,,,,,,,,,, -13800,0.16728458,4.166496,,,,,,,,,,,,,,,,, -13900,0.1843589,4.1725864,,,,,,,,,,,,,,,,, -14000,0.17765947,4.1109986,,,,,,,,,,,,,,,,, -14100,0.1589539,4.117954,,,,,,,,,,,,,,,,, -14200,0.18239178,4.1440554,,,,,,,,,,,,,,,,, -14300,0.15793031,4.0527453,,,,,,,,,,,,,,,,, -14384,,,0.6251729726791382,2.1058154106140137,30.452753434872445,0.6403516530990601,1.9786962270736688,26.710132243429697,3000.0,0.6496427059173584,1.9267456531524656,26.02713434983809,3003.0,5068.914443492889,8814.757400751114,5068.914443492889,3745.225136041641,0.1623759269714355,0.0 -14400,0.1591921,4.1306214,,,,,,,,,,,,,,,,, -14500,0.16094281,4.097742,,,,,,,,,,,,,,,,, -14600,0.18584257,4.128318,,,,,,,,,,,,,,,,, -14700,0.15775234,4.084554,,,,,,,,,,,,,,,,, -14800,0.14964919,4.126914,,,,,,,,,,,,,,,,, -14900,0.15817267,4.124673,,,,,,,,,,,,,,,,, -15000,0.20682637,4.0548687,,,,,,,,,,,,,,,,, -15100,0.17034337,4.070092,,,,,,,,,,,,,,,,, -15200,0.16607279,4.1993456,,,,,,,,,,,,,,,,, -15300,0.15856108,4.133731,,,,,,,,,,,,,,,,, -15400,0.17537424,4.0312104,,,,,,,,,,,,,,,,, -15500,0.1890999,4.1392527,,,,,,,,,,,,,,,,, -15600,0.15955146,4.1450806,,,,,,,,,,,,,,,,, -15700,0.17310151,4.063243,,,,,,,,,,,,,,,,, -15800,0.1758614,4.1158113,,,,,,,,,,,,,,,,, -15900,0.16687639,4.1702175,,,,,,,,,,,,,,,,, -16000,0.17630562,4.1069403,,,,,,,,,,,,,,,,, -16100,0.15884407,4.153007,,,,,,,,,,,,,,,,, -16200,0.17057636,4.0430965,,,,,,,,,,,,,,,,, -16300,0.15728341,4.0699406,,,,,,,,,,,,,,,,, -16400,0.15590307,4.0728984,,,,,,,,,,,,,,,,, -16500,0.16539794,4.08131,,,,,,,,,,,,,,,,, -16600,0.17620516,4.0636625,,,,,,,,,,,,,,,,, -16700,0.21121638,4.038431,,,,,,,,,,,,,,,,, -16782,,,0.6332809329032898,2.025113582611084,30.834592006989155,0.6474562883377075,1.917392373085022,27.531752272070463,3000.0,0.658288300037384,1.8623210191726685,26.765356617369413,3003.0,5909.078606367111,10103.429570436478,5909.078606367111,4193.632272481918,0.1892008781433105,0.0 -16800,0.23392442,4.121644,,,,,,,,,,,,,,,,, -16900,0.15207165,4.0542083,,,,,,,,,,,,,,,,, -17000,0.21927267,4.126695,,,,,,,,,,,,,,,,, -17100,0.21484841,4.10452,,,,,,,,,,,,,,,,, -17200,0.16632897,4.035848,,,,,,,,,,,,,,,,, -17300,0.17006491,4.0952797,,,,,,,,,,,,,,,,, -17400,0.18972886,4.0401397,,,,,,,,,,,,,,,,, -17500,0.15697049,4.05741,,,,,,,,,,,,,,,,, -17600,0.1608362,4.1004696,,,,,,,,,,,,,,,,, -17700,0.15386249,4.0060506,,,,,,,,,,,,,,,,, -17800,0.2432756,4.076232,,,,,,,,,,,,,,,,, -17900,0.17083637,4.057,,,,,,,,,,,,,,,,, -18000,0.17659521,4.0652223,,,,,,,,,,,,,,,,, -18100,0.1736357,4.082292,,,,,,,,,,,,,,,,, -18200,0.25749552,4.0669165,,,,,,,,,,,,,,,,, -18300,0.21964718,4.047198,,,,,,,,,,,,,,,,, -18400,0.15676929,4.03821,,,,,,,,,,,,,,,,, -18500,0.20041372,4.0665793,,,,,,,,,,,,,,,,, -18600,0.21157902,4.063512,,,,,,,,,,,,,,,,, -18700,0.19125426,4.0376186,,,,,,,,,,,,,,,,, -18800,0.23057471,4.066479,,,,,,,,,,,,,,,,, -18900,0.16639169,3.9882135,,,,,,,,,,,,,,,,, -19000,0.15863287,4.060006,,,,,,,,,,,,,,,,, -19100,0.16690822,4.063904,,,,,,,,,,,,,,,,, -19181,,,0.6485809087753296,1.9150187969207764,31.71264530159373,0.6535567045211792,1.8694406747817995,27.791618921827165,3000.0,0.6626227498054504,1.812300086021424,27.056916774320545,3003.0,6749.185294866562,11397.61303448677,6749.185294866562,4647.600977182388,0.220649242401123,0.0 -19200,0.16347072,4.0114136,,,,,,,,,,,,,,,,, -19300,0.16580337,4.018551,,,,,,,,,,,,,,,,, -19400,0.20688197,4.096822,,,,,,,,,,,,,,,,, -19500,0.17268626,4.09243,,,,,,,,,,,,,,,,, -19600,0.15765193,4.041841,,,,,,,,,,,,,,,,, -19700,0.14956053,4.1208105,,,,,,,,,,,,,,,,, -19800,0.17170075,4.0297656,,,,,,,,,,,,,,,,, -19900,0.16811347,4.044556,,,,,,,,,,,,,,,,, -20000,0.2517398,4.0860734,,,,,,,,,,,,,,,,, -20100,0.16186853,3.9883294,,,,,,,,,,,,,,,,, -20200,0.20292884,4.053085,,,,,,,,,,,,,,,,, -20300,0.16063674,4.0609784,,,,,,,,,,,,,,,,, -20400,0.16035593,4.116659,,,,,,,,,,,,,,,,, -20500,0.23302543,4.024647,,,,,,,,,,,,,,,,, -20600,0.19807228,4.0517554,,,,,,,,,,,,,,,,, -20700,0.18188041,4.000985,,,,,,,,,,,,,,,,, -20800,0.16813236,4.024017,,,,,,,,,,,,,,,,, -20900,0.17265253,3.9646373,,,,,,,,,,,,,,,,, -21000,0.1616368,4.0249844,,,,,,,,,,,,,,,,, -21100,0.17666695,4.07677,,,,,,,,,,,,,,,,, -21200,0.17119731,4.0598245,,,,,,,,,,,,,,,,, -21300,0.28974956,4.05105,,,,,,,,,,,,,,,,, -21400,0.20836486,4.054217,,,,,,,,,,,,,,,,, -21500,0.17534208,4.0839095,,,,,,,,,,,,,,,,, -21580,,,0.6398290395736694,1.9732670783996584,31.557177824770733,0.6579831838607788,1.8404289484024048,28.18313877713356,3000.0,0.6678287386894226,1.778388500213623,27.37082354685312,3003.0,7589.185811042786,12702.14163517952,7589.185811042786,5112.02565741539,0.2493281364440918,0.0 -21600,0.19470419,3.9730644,,,,,,,,,,,,,,,,, -21700,0.1705651,3.9719725,,,,,,,,,,,,,,,,, -21800,0.17410244,4.0001273,,,,,,,,,,,,,,,,, -21900,0.18031226,4.0716486,,,,,,,,,,,,,,,,, -22000,0.17886573,3.9923265,,,,,,,,,,,,,,,,, -22100,0.23284522,4.0852923,,,,,,,,,,,,,,,,, -22200,0.17752717,4.0452824,,,,,,,,,,,,,,,,, -22300,0.15508519,3.9624264,,,,,,,,,,,,,,,,, -22400,0.26761016,3.9858677,,,,,,,,,,,,,,,,, -22500,0.19787122,3.988224,,,,,,,,,,,,,,,,, -22600,0.16953568,3.9814122,,,,,,,,,,,,,,,,, -22700,0.19645658,4.0781746,,,,,,,,,,,,,,,,, -22800,0.2103503,3.991538,,,,,,,,,,,,,,,,, -22900,0.22016203,4.005991,,,,,,,,,,,,,,,,, -23000,0.17568994,3.999373,,,,,,,,,,,,,,,,, -23100,0.20998569,4.005884,,,,,,,,,,,,,,,,, -23200,0.18062086,4.023234,,,,,,,,,,,,,,,,, -23300,0.18353651,4.0334716,,,,,,,,,,,,,,,,, -23400,0.18512695,4.058509,,,,,,,,,,,,,,,,, -23500,0.17671081,3.9430637,,,,,,,,,,,,,,,,, -23600,0.30574617,4.061396,,,,,,,,,,,,,,,,, -23700,0.18499194,4.033817,,,,,,,,,,,,,,,,, -23800,0.22359782,3.9894166,,,,,,,,,,,,,,,,, -23900,0.18302004,3.9574323,,,,,,,,,,,,,,,,, -23978,,,0.6370953321456909,1.9664952754974363,31.35145588205602,0.6594586372375488,1.80987560749054,28.52456600789127,3000.0,0.6701877117156982,1.7438005208969116,27.80662204576,3003.0,8429.109383583069,13992.390851259232,8429.109383583069,5562.247340202332,0.2781414985656738,0.0 -24000,0.16586584,3.9618587,,,,,,,,,,,,,,,,, -24100,0.17305025,3.9750676,,,,,,,,,,,,,,,,, -24200,0.2386976,3.997244,,,,,,,,,,,,,,,,, -24300,0.18972328,3.9811008,,,,,,,,,,,,,,,,, -24400,0.18987255,3.9672177,,,,,,,,,,,,,,,,, -24500,0.18025729,3.958938,,,,,,,,,,,,,,,,, -24600,0.17178464,3.9297898,,,,,,,,,,,,,,,,, -24700,0.246925,3.9702005,,,,,,,,,,,,,,,,, -24800,0.18683822,3.9611464,,,,,,,,,,,,,,,,, -24900,0.18594842,4.0016527,,,,,,,,,,,,,,,,, -25000,0.20348266,3.9996684,,,,,,,,,,,,,,,,, -25100,0.23279072,3.9714134,,,,,,,,,,,,,,,,, -25200,0.20435053,3.9333358,,,,,,,,,,,,,,,,, -25300,0.20308277,4.086254,,,,,,,,,,,,,,,,, -25400,0.16760962,3.9964254,,,,,,,,,,,,,,,,, -25500,0.20940962,3.9734867,,,,,,,,,,,,,,,,, -25600,0.19554225,3.9693036,,,,,,,,,,,,,,,,, -25700,0.17882977,3.9815087,,,,,,,,,,,,,,,,, -25800,0.25720474,4.0468316,,,,,,,,,,,,,,,,, -25900,0.22412105,3.9819438,,,,,,,,,,,,,,,,, -26000,0.2691343,3.9422486,,,,,,,,,,,,,,,,, -26100,0.27467433,3.9689445,,,,,,,,,,,,,,,,, -26200,0.20436373,3.9716177,,,,,,,,,,,,,,,,, -26300,0.19185443,3.9683979,,,,,,,,,,,,,,,,, -26376,,,0.6522204279899597,1.8794294595718384,31.961837474519832,0.6629427671432495,1.797146201133728,28.61563434415713,3000.0,0.6761141419410706,1.732893466949463,28.11392985066348,3003.0,9269.108746528624,15299.209090471268,9269.108746528624,6028.958881616592,0.308734655380249,0.0 -26400,0.23559421,3.989599,,,,,,,,,,,,,,,,, -26500,0.26128188,3.9432256,,,,,,,,,,,,,,,,, -26600,0.22181806,3.9285338,,,,,,,,,,,,,,,,, -26700,0.16070423,3.99252,,,,,,,,,,,,,,,,, -26800,0.2683718,4.009258,,,,,,,,,,,,,,,,, -26900,0.17778227,3.9686859,,,,,,,,,,,,,,,,, -27000,0.17506808,3.9486952,,,,,,,,,,,,,,,,, -27100,0.19031288,3.9387555,,,,,,,,,,,,,,,,, -27200,0.2164778,3.941546,,,,,,,,,,,,,,,,, -27300,0.18021666,3.9072528,,,,,,,,,,,,,,,,, -27400,0.1959131,3.9583642,,,,,,,,,,,,,,,,, -27500,0.20488206,3.9564397,,,,,,,,,,,,,,,,, -27600,0.2573527,3.9747128,,,,,,,,,,,,,,,,, -27700,0.20876442,3.9622142,,,,,,,,,,,,,,,,, -27800,0.20637186,3.9800558,,,,,,,,,,,,,,,,, -27900,0.23719819,3.9186258,,,,,,,,,,,,,,,,, -28000,0.21070921,4.0672145,,,,,,,,,,,,,,,,, -28100,0.21334599,3.9028904,,,,,,,,,,,,,,,,, -28200,0.21744277,3.9378104,,,,,,,,,,,,,,,,, -28300,0.22336075,3.9511514,,,,,,,,,,,,,,,,, -28400,0.23183858,3.9341955,,,,,,,,,,,,,,,,, -28500,0.17921986,3.9030397,,,,,,,,,,,,,,,,, -28600,0.2690738,4.037204,,,,,,,,,,,,,,,,, -28700,0.20520636,3.9694405,,,,,,,,,,,,,,,,, -28774,,,0.648067057132721,1.904816508293152,31.576812118177703,0.6638355255126953,1.7895036935806274,28.26089184313604,3000.0,0.6750682592391968,1.7232335805892944,27.884347883711857,3003.0,10109.169246673584,16620.57497549057,10109.169246673584,6510.161110162735,0.3377220630645752,0.0 -28800,0.2209079,3.9852173,,,,,,,,,,,,,,,,, -28900,0.24021383,4.0244417,,,,,,,,,,,,,,,,, -29000,0.1943898,3.9201467,,,,,,,,,,,,,,,,, -29100,0.19836901,3.9302082,,,,,,,,,,,,,,,,, -29200,0.235318,4.0218973,,,,,,,,,,,,,,,,, -29300,0.21543233,3.9540002,,,,,,,,,,,,,,,,, -29400,0.18872099,4.025523,,,,,,,,,,,,,,,,, -29500,0.24675032,3.8887198,,,,,,,,,,,,,,,,, -29600,0.19858888,3.9342105,,,,,,,,,,,,,,,,, -29700,0.28536594,3.9523742,,,,,,,,,,,,,,,,, -29800,0.21049713,3.9604418,,,,,,,,,,,,,,,,, -29900,0.22478615,3.9303179,,,,,,,,,,,,,,,,, -30000,0.20712087,3.9360332,,,,,,,,,,,,,,,,, -30100,0.22251719,3.9731688,,,,,,,,,,,,,,,,, -30200,0.26793897,3.9185784,,,,,,,,,,,,,,,,, -30300,0.2077922,3.922257,,,,,,,,,,,,,,,,, -30400,0.23812121,3.9901261,,,,,,,,,,,,,,,,, -30500,0.27539676,3.962757,,,,,,,,,,,,,,,,, -30600,0.2887969,3.900778,,,,,,,,,,,,,,,,, -30700,0.19975837,3.9647083,,,,,,,,,,,,,,,,, -30800,0.17374425,3.904178,,,,,,,,,,,,,,,,, -30900,0.2048912,3.9076555,,,,,,,,,,,,,,,,, -31000,0.20547156,3.9278858,,,,,,,,,,,,,,,,, -31100,0.2241023,3.9330823,,,,,,,,,,,,,,,,, -31172,,,0.6439923644065857,1.927431583404541,31.66790219778987,0.6673568487167358,1.7629570960998535,28.57217885812917,3000.0,0.6779385209083557,1.6976670026779177,28.39342145211166,3003.0,10949.117016077042,17954.123304367065,10949.117016077042,7003.654671907425,0.3678188323974609,0.0 -31200,0.1992653,3.9040315,,,,,,,,,,,,,,,,, -31300,0.21522638,3.9664867,,,,,,,,,,,,,,,,, -31400,0.2096703,3.953812,,,,,,,,,,,,,,,,, -31500,0.19431098,3.9040415,,,,,,,,,,,,,,,,, -31600,0.21219082,3.9090204,,,,,,,,,,,,,,,,, -31700,0.22913656,3.9782867,,,,,,,,,,,,,,,,, -31800,0.2633317,3.9573696,,,,,,,,,,,,,,,,, -31900,0.21732345,3.9260008,,,,,,,,,,,,,,,,, -32000,0.19758767,3.9386923,,,,,,,,,,,,,,,,, -32100,0.2059511,3.9476254,,,,,,,,,,,,,,,,, -32200,0.25617626,3.9399889,,,,,,,,,,,,,,,,, -32300,0.20243354,3.87163,,,,,,,,,,,,,,,,, -32400,0.22962236,4.0281243,,,,,,,,,,,,,,,,, -32500,0.20402528,3.9094193,,,,,,,,,,,,,,,,, -32600,0.22564805,3.904729,,,,,,,,,,,,,,,,, -32700,0.21208124,3.9203136,,,,,,,,,,,,,,,,, -32800,0.21818084,4.0069833,,,,,,,,,,,,,,,,, -32900,0.18386386,3.92561,,,,,,,,,,,,,,,,, -33000,0.20219193,3.9458585,,,,,,,,,,,,,,,,, -33100,0.21071623,3.9607117,,,,,,,,,,,,,,,,, -33200,0.23541656,4.0032334,,,,,,,,,,,,,,,,, -33300,0.21263573,3.9223099,,,,,,,,,,,,,,,,, -33400,0.21929796,3.9211488,,,,,,,,,,,,,,,,, -33500,0.2269757,4.004682,,,,,,,,,,,,,,,,, -33571,,,0.6535011529922485,1.888905644416809,31.83566635295377,0.6689687371253967,1.7727051973342896,29.0990490788574,3000.0,0.6815641522407532,1.701885461807251,28.767894892189727,3003.0,11789.085807323456,19259.912284851074,11789.085807323456,7469.372006177902,0.3968191146850586,0.0 -33600,0.25997174,3.918284,,,,,,,,,,,,,,,,, -33700,0.21348962,3.9031122,,,,,,,,,,,,,,,,, -33800,0.21011932,3.915789,,,,,,,,,,,,,,,,, -33900,0.23957543,3.9485118,,,,,,,,,,,,,,,,, -34000,0.23530065,3.9298906,,,,,,,,,,,,,,,,, -34100,1.4118567,4.000487,,,,,,,,,,,,,,,,, -34200,0.21990341,4.0000033,,,,,,,,,,,,,,,,, -34300,0.23537731,3.9678,,,,,,,,,,,,,,,,, -34400,0.24200962,3.88332,,,,,,,,,,,,,,,,, -34500,0.2533975,3.9641042,,,,,,,,,,,,,,,,, -34600,0.21048394,3.9126575,,,,,,,,,,,,,,,,, -34700,0.21275163,3.896408,,,,,,,,,,,,,,,,, -34800,0.21840394,3.9536085,,,,,,,,,,,,,,,,, -34900,0.23324923,3.9884222,,,,,,,,,,,,,,,,, -35000,0.23132965,3.9063182,,,,,,,,,,,,,,,,, -35100,0.22553118,3.940766,,,,,,,,,,,,,,,,, -35200,0.23445936,3.940757,,,,,,,,,,,,,,,,, -35300,0.23088436,3.899485,,,,,,,,,,,,,,,,, -35400,0.21343476,3.9125643,,,,,,,,,,,,,,,,, -35500,0.22172366,3.979881,,,,,,,,,,,,,,,,, -35600,0.2842291,3.9281619,,,,,,,,,,,,,,,,, -35700,0.2531823,3.9478896,,,,,,,,,,,,,,,,, -35800,0.22827254,3.8785815,,,,,,,,,,,,,,,,, -35900,0.22168334,3.9154086,,,,,,,,,,,,,,,,, -35970,,,0.6479287147521973,1.900919795036316,32.29340009964613,0.6697747111320496,1.7474881410598757,29.200182035206225,3000.0,0.6820057034492493,1.6754212379455566,28.82065666908729,3003.0,12629.10321187973,20581.94299149513,12629.10321187973,7951.277729272842,0.4293038845062256,0.0 -36000,0.20959437,3.8734655,,,,,,,,,,,,,,,,, -36100,0.25630036,3.9652863,,,,,,,,,,,,,,,,, -36200,0.28570938,3.9062138,,,,,,,,,,,,,,,,, -36300,0.21929008,3.903222,,,,,,,,,,,,,,,,, -36400,0.20323217,3.9017336,,,,,,,,,,,,,,,,, -36500,0.23318036,3.9575887,,,,,,,,,,,,,,,,, -36600,0.2839059,3.9034681,,,,,,,,,,,,,,,,, -36700,0.24486674,3.8931398,,,,,,,,,,,,,,,,, -36800,0.25837022,3.897808,,,,,,,,,,,,,,,,, -36900,0.23393773,3.958345,,,,,,,,,,,,,,,,, -37000,0.2531004,3.9456794,,,,,,,,,,,,,,,,, -37100,0.26040807,3.8463228,,,,,,,,,,,,,,,,, -37200,0.27314427,3.9228091,,,,,,,,,,,,,,,,, -37300,0.25507358,3.909131,,,,,,,,,,,,,,,,, -37400,0.25300786,3.9321494,,,,,,,,,,,,,,,,, -37500,0.22655986,3.9518251,,,,,,,,,,,,,,,,, -37600,0.3260843,3.903852,,,,,,,,,,,,,,,,, -37700,0.29779437,3.9070337,,,,,,,,,,,,,,,,, -37800,0.23064242,3.948924,,,,,,,,,,,,,,,,, -37900,0.27043122,3.9600325,,,,,,,,,,,,,,,,, -38000,0.20719153,3.922887,,,,,,,,,,,,,,,,, -38100,0.22996601,3.9224021,,,,,,,,,,,,,,,,, -38200,0.22404031,3.8995154,,,,,,,,,,,,,,,,, -38300,0.24380867,3.9171453,,,,,,,,,,,,,,,,, -38367,,,0.66447913646698,1.7844882011413574,33.572362901384416,0.6720189452171326,1.7210848331451416,29.411130481189385,3000.0,0.6850618720054626,1.644993782043457,29.158325918477285,3003.0,13469.029118299484,21879.16938900948,13469.029118299484,8408.473896980286,0.4591190814971924,0.0 -38400,0.22390836,3.9025533,,,,,,,,,,,,,,,,, -38500,0.24771419,3.9029043,,,,,,,,,,,,,,,,, -38600,0.26174504,3.9592073,,,,,,,,,,,,,,,,, -38700,0.21195543,3.8494978,,,,,,,,,,,,,,,,, -38800,0.23989235,3.8849738,,,,,,,,,,,,,,,,, -38900,0.21573688,3.8777378,,,,,,,,,,,,,,,,, -39000,0.22306854,3.9365947,,,,,,,,,,,,,,,,, -39100,0.23556913,3.8732228,,,,,,,,,,,,,,,,, -39200,0.21990551,3.9482782,,,,,,,,,,,,,,,,, -39300,0.21772994,3.9084246,,,,,,,,,,,,,,,,, -39400,0.25198454,3.919216,,,,,,,,,,,,,,,,, -39500,0.23250428,3.8786628,,,,,,,,,,,,,,,,, -39600,0.30584192,3.9021962,,,,,,,,,,,,,,,,, -39700,0.26352224,3.8856792,,,,,,,,,,,,,,,,, -39800,0.24279541,3.925272,,,,,,,,,,,,,,,,, -39900,0.2728644,3.874369,,,,,,,,,,,,,,,,, -40000,0.27151188,3.938139,,,,,,,,,,,,,,,,, -40100,0.27877137,3.8696609,,,,,,,,,,,,,,,,, -40200,0.22414535,3.8674622,,,,,,,,,,,,,,,,, -40300,0.24950151,3.9072247,,,,,,,,,,,,,,,,, -40400,0.23557058,3.8940363,,,,,,,,,,,,,,,,, -40500,0.24725115,3.8403409,,,,,,,,,,,,,,,,, -40600,0.24990942,3.9262323,,,,,,,,,,,,,,,,, -40700,0.25658485,3.8943822,,,,,,,,,,,,,,,,, -40765,,,0.6568623781204224,1.846411228179932,32.29865399133282,0.6735936403274536,1.7242670059204102,29.367416114376645,3000.0,0.685631275177002,1.6500742435455322,29.24589113563988,3003.0,14309.25852894783,23190.842586278915,14309.25852894783,8879.810878276825,0.489011287689209,0.0 -40800,0.22726344,3.8380334,,,,,,,,,,,,,,,,, -40900,0.21409288,3.906471,,,,,,,,,,,,,,,,, -41000,0.24562576,3.9067008,,,,,,,,,,,,,,,,, -41100,0.23249127,3.9282117,,,,,,,,,,,,,,,,, -41200,0.23378037,3.908043,,,,,,,,,,,,,,,,, -41300,0.2664613,3.8823838,,,,,,,,,,,,,,,,, -41400,0.25100005,3.8141096,,,,,,,,,,,,,,,,, -41500,0.250061,3.8648288,,,,,,,,,,,,,,,,, -41600,0.22962537,3.9222708,,,,,,,,,,,,,,,,, -41700,0.28941402,3.9206536,,,,,,,,,,,,,,,,, -41800,0.2556326,3.865707,,,,,,,,,,,,,,,,, -41900,0.24352357,3.9045842,,,,,,,,,,,,,,,,, -42000,0.22699332,3.8823345,,,,,,,,,,,,,,,,, -42100,0.26771247,3.8630862,,,,,,,,,,,,,,,,, -42200,0.25342095,3.9936461,,,,,,,,,,,,,,,,, -42300,0.24219424,3.9181287,,,,,,,,,,,,,,,,, -42400,0.27276152,3.8715603,,,,,,,,,,,,,,,,, -42500,0.25749502,3.9198074,,,,,,,,,,,,,,,,, -42600,0.25714043,3.875933,,,,,,,,,,,,,,,,, -42700,0.2785797,3.939097,,,,,,,,,,,,,,,,, -42800,0.27076477,3.9060075,,,,,,,,,,,,,,,,, -42900,0.24407865,3.8553033,,,,,,,,,,,,,,,,, -43000,0.2454505,3.8234754,,,,,,,,,,,,,,,,, -43100,0.27168465,3.8878314,,,,,,,,,,,,,,,,, -43162,,,0.6540481448173523,1.848538517951965,31.76641579588908,0.674411952495575,1.7126529216766355,29.176469017738977,3000.0,0.6875603199005127,1.6370383501052856,28.950136975521986,3003.0,15149.16226387024,24531.530041217804,15149.16226387024,9380.485354423525,0.5209271907806396,0.0 -43200,0.25774458,3.871628,,,,,,,,,,,,,,,,, -43300,0.24632972,3.848533,,,,,,,,,,,,,,,,, -43400,0.26142922,3.8953836,,,,,,,,,,,,,,,,, -43500,0.25506085,3.847335,,,,,,,,,,,,,,,,, -43600,0.2509547,3.9170783,,,,,,,,,,,,,,,,, -43700,0.28935552,3.9247415,,,,,,,,,,,,,,,,, -43800,0.2535453,3.8410025,,,,,,,,,,,,,,,,, -43900,0.22415714,3.8613038,,,,,,,,,,,,,,,,, -44000,0.24849796,3.840971,,,,,,,,,,,,,,,,, -44100,0.25266537,3.9006329,,,,,,,,,,,,,,,,, -44200,0.23071879,3.8774357,,,,,,,,,,,,,,,,, -44300,0.27654934,3.8080223,,,,,,,,,,,,,,,,, -44400,0.2660741,3.8406665,,,,,,,,,,,,,,,,, -44500,0.22931626,3.8983254,,,,,,,,,,,,,,,,, -44600,0.2293714,3.8903916,,,,,,,,,,,,,,,,, -44700,1.3969446,7.1588616,,,,,,,,,,,,,,,,, -44800,0.6324692,5.618645,,,,,,,,,,,,,,,,, -44900,0.4785485,5.5239353,,,,,,,,,,,,,,,,, -45000,1.0477802,5.498213,,,,,,,,,,,,,,,,, -45100,0.28017196,5.4408336,,,,,,,,,,,,,,,,, -45200,0.8077594,5.446749,,,,,,,,,,,,,,,,, -45300,0.50901,5.4159584,,,,,,,,,,,,,,,,, -45400,0.5441701,5.436289,,,,,,,,,,,,,,,,, -45500,0.2992488,5.403451,,,,,,,,,,,,,,,,, -45563,,,0.3642924129962921,3.769196033477783,1.166315163018871,0.324819266796112,4.211953163146973,0.1726707275755886,3000.0,0.3135901391506195,4.377139091491699,0.1555851035313861,3003.0,15989.111895561218,25980.68439888954,15989.111895561218,9989.577900886536,0.5572404861450195,0.0 -45600,1.2150499,5.404718,,,,,,,,,,,,,,,,, -45700,0.4860091,4.0753455,,,,,,,,,,,,,,,,, -45800,0.24992166,3.8961172,,,,,,,,,,,,,,,,, -45900,0.2940683,3.860371,,,,,,,,,,,,,,,,, -46000,0.23011598,3.8727887,,,,,,,,,,,,,,,,, -46100,0.22694445,3.9152255,,,,,,,,,,,,,,,,, -46200,0.25547868,3.8529987,,,,,,,,,,,,,,,,, -46300,0.22853407,3.8834958,,,,,,,,,,,,,,,,, -46400,0.30068398,3.9468951,,,,,,,,,,,,,,,,, -46500,0.31216037,3.9380755,,,,,,,,,,,,,,,,, -46600,0.25144482,3.9191046,,,,,,,,,,,,,,,,, -46700,0.21511671,3.9159918,,,,,,,,,,,,,,,,, -46800,0.25128132,3.8974578,,,,,,,,,,,,,,,,, -46900,0.24753985,3.8700335,,,,,,,,,,,,,,,,, -47000,0.24791948,3.8802438,,,,,,,,,,,,,,,,, -47100,0.37597668,3.8481462,,,,,,,,,,,,,,,,, -47200,0.2833214,3.8898222,,,,,,,,,,,,,,,,, -47300,0.23451331,3.8453326,,,,,,,,,,,,,,,,, -47400,0.23526226,3.876654,,,,,,,,,,,,,,,,, -47500,0.26163706,3.85581,,,,,,,,,,,,,,,,, -47600,0.25535855,3.8808818,,,,,,,,,,,,,,,,, -47700,0.2833526,3.910405,,,,,,,,,,,,,,,,, -47800,0.2800193,3.8744926,,,,,,,,,,,,,,,,, -47900,0.26854444,3.8618796,,,,,,,,,,,,,,,,, -47961,,,0.6598178744316101,1.826006293296814,32.72687596383734,0.6744863390922546,1.7195638418197632,29.72216502720672,3000.0,0.6885131597518921,1.6407841444015503,29.155636425634672,3003.0,16829.141446828842,27273.812072753903,16829.141446828842,10442.567687034609,0.5890069007873535,0.0 -48000,0.29866767,3.8975418,,,,,,,,,,,,,,,,, -48100,0.3403175,3.878272,,,,,,,,,,,,,,,,, -48200,0.23817182,3.9215512,,,,,,,,,,,,,,,,, -48300,0.2751369,3.8769338,,,,,,,,,,,,,,,,, -48400,0.257367,3.8382604,,,,,,,,,,,,,,,,, -48500,0.26162723,3.8665688,,,,,,,,,,,,,,,,, -48600,0.24341543,3.9269946,,,,,,,,,,,,,,,,, -48700,0.23998454,3.8986673,,,,,,,,,,,,,,,,, -48800,0.25691172,3.8218303,,,,,,,,,,,,,,,,, -48900,0.2528683,3.8254516,,,,,,,,,,,,,,,,, -49000,0.2715916,3.883153,,,,,,,,,,,,,,,,, -49100,0.3399203,3.8550384,,,,,,,,,,,,,,,,, -49200,0.27517724,3.8206587,,,,,,,,,,,,,,,,, -49300,0.24906886,3.8193645,,,,,,,,,,,,,,,,, -49400,0.30111286,3.8470562,,,,,,,,,,,,,,,,, -49500,0.2839908,3.9157474,,,,,,,,,,,,,,,,, -49600,0.39986366,3.86418,,,,,,,,,,,,,,,,, -49700,0.2699401,3.9260387,,,,,,,,,,,,,,,,, -49800,0.26715186,3.8396313,,,,,,,,,,,,,,,,, -49900,0.24699359,3.8959162,,,,,,,,,,,,,,,,, -50000,0.2796536,3.8590312,,,,,,,,,,,,,,,,, -50100,0.32896036,3.951729,,,,,,,,,,,,,,,,, -50200,0.26258686,3.8729496,,,,,,,,,,,,,,,,, -50300,0.29794383,3.8129137,,,,,,,,,,,,,,,,, -50361,,,0.6838412880897522,1.6663483381271362,33.974547470472466,0.6764330267906189,1.7032217979431152,29.85961521752938,3000.0,0.6898379325866699,1.6269718408584597,29.56486532290584,3003.0,17669.30721282959,28763.591370821,17669.30721282959,11092.072012662888,0.622807502746582,0.0 -50400,0.25453916,3.90519,,,,,,,,,,,,,,,,, -50500,0.24286856,3.8989162,,,,,,,,,,,,,,,,, -50600,0.29634503,3.8354037,,,,,,,,,,,,,,,,, -50700,0.2559726,3.850367,,,,,,,,,,,,,,,,, -50800,0.2704564,3.8706973,,,,,,,,,,,,,,,,, -50900,0.26627168,3.8967469,,,,,,,,,,,,,,,,, -51000,0.32653204,3.8793445,,,,,,,,,,,,,,,,, -51100,0.23174472,3.8642433,,,,,,,,,,,,,,,,, -51200,0.2838489,3.9215345,,,,,,,,,,,,,,,,, -51300,0.24060972,3.887468,,,,,,,,,,,,,,,,, -51400,0.2705256,3.8514287,,,,,,,,,,,,,,,,, -51500,0.23818898,3.8768566,,,,,,,,,,,,,,,,, -51600,0.25558487,3.9416375,,,,,,,,,,,,,,,,, -51700,0.23711067,3.8777409,,,,,,,,,,,,,,,,, -51800,9.358762,4.6832223,,,,,,,,,,,,,,,,, -51900,0.23383054,3.8472059,,,,,,,,,,,,,,,,, -52000,0.25387296,3.8474844,,,,,,,,,,,,,,,,, -52100,0.2653298,3.9341643,,,,,,,,,,,,,,,,, -52200,0.22439891,3.8552885,,,,,,,,,,,,,,,,, -52300,0.2532725,3.8549361,,,,,,,,,,,,,,,,, -52400,0.26029274,3.9019237,,,,,,,,,,,,,,,,, -52500,0.24356703,3.928137,,,,,,,,,,,,,,,,, -52600,0.26203725,3.8914971,,,,,,,,,,,,,,,,, -52700,0.24731396,3.888574,,,,,,,,,,,,,,,,, -52761,,,0.6685400009155273,1.7587262392044067,32.919607532370215,0.677734911441803,1.6866081953048706,29.581894504613768,3000.0,0.6924641132354736,1.6087028980255127,29.45363869235563,3003.0,18509.54104423523,30157.93383049965,18509.54104423523,11646.070712327955,0.6562457084655762,0.0 -52800,0.2432878,3.8999257,,,,,,,,,,,,,,,,, -52900,0.2875909,3.9146473,,,,,,,,,,,,,,,,, -53000,0.24922043,3.90193,,,,,,,,,,,,,,,,, -53100,0.26772305,3.8870378,,,,,,,,,,,,,,,,, -53200,0.23836008,3.8458903,,,,,,,,,,,,,,,,, -53300,0.29219553,3.8263762,,,,,,,,,,,,,,,,, -53400,0.2627175,3.881389,,,,,,,,,,,,,,,,, -53500,0.30250636,3.843916,,,,,,,,,,,,,,,,, -53600,0.2781701,3.886772,,,,,,,,,,,,,,,,, -53700,0.25047544,3.8590205,,,,,,,,,,,,,,,,, -53800,0.30561703,3.9003716,,,,,,,,,,,,,,,,, -53900,0.24660644,3.8206892,,,,,,,,,,,,,,,,, -54000,0.24795903,3.8269255,,,,,,,,,,,,,,,,, -54100,0.25091487,3.8013456,,,,,,,,,,,,,,,,, -54200,0.25627264,3.864263,,,,,,,,,,,,,,,,, -54300,0.37677628,3.8024182,,,,,,,,,,,,,,,,, -54400,0.2576914,3.803591,,,,,,,,,,,,,,,,, -54500,0.28226307,3.8513198,,,,,,,,,,,,,,,,, -54600,0.24773948,3.8767345,,,,,,,,,,,,,,,,, -54700,0.30119312,3.8541403,,,,,,,,,,,,,,,,, -54800,0.26598722,3.811965,,,,,,,,,,,,,,,,, -54900,0.27533975,3.8770103,,,,,,,,,,,,,,,,, -55000,0.26821154,3.8302882,,,,,,,,,,,,,,,,, -55100,0.30489892,3.899926,,,,,,,,,,,,,,,,, -55160,,,0.6608024835586548,1.809146761894226,32.987301429211975,0.6783052682876587,1.6888394355773926,30.181948989674293,3000.0,0.6906862258911133,1.611272215843201,29.408274314513523,3003.0,19349.585392713547,31470.043719291687,19349.585392713547,12118.026804924011,0.6894242763519287,0.0 -55200,0.2568926,3.8023424,,,,,,,,,,,,,,,,, -55300,0.26279595,3.8875208,,,,,,,,,,,,,,,,, -55400,0.26251072,3.8585675,,,,,,,,,,,,,,,,, -55500,0.24477503,3.830803,,,,,,,,,,,,,,,,, -55600,0.3037454,3.8757074,,,,,,,,,,,,,,,,, -55700,0.24451135,3.8107119,,,,,,,,,,,,,,,,, -55800,0.25226745,3.8165386,,,,,,,,,,,,,,,,, -55900,0.30203557,3.9054036,,,,,,,,,,,,,,,,, -56000,0.26959497,3.815875,,,,,,,,,,,,,,,,, -56100,0.2644166,3.8102233,,,,,,,,,,,,,,,,, -56200,0.26595724,3.896327,,,,,,,,,,,,,,,,, -56300,0.34287995,3.8847744,,,,,,,,,,,,,,,,, -56400,0.24941929,3.827832,,,,,,,,,,,,,,,,, -56500,0.26958126,3.8538861,,,,,,,,,,,,,,,,, -56600,0.28762403,3.8160026,,,,,,,,,,,,,,,,, -56700,0.30110446,3.8974314,,,,,,,,,,,,,,,,, -56800,0.26166487,3.9093692,,,,,,,,,,,,,,,,, -56900,0.37930363,3.9369078,,,,,,,,,,,,,,,,, -57000,0.2554092,3.8167667,,,,,,,,,,,,,,,,, -57100,0.26699308,3.8191721,,,,,,,,,,,,,,,,, -57200,0.25055772,3.859081,,,,,,,,,,,,,,,,, -57300,0.24628551,3.8546946,,,,,,,,,,,,,,,,, -57400,0.26096508,3.8407037,,,,,,,,,,,,,,,,, -57500,0.25283366,3.9362783,,,,,,,,,,,,,,,,, -57559,,,0.6720222234725952,1.7211450338363647,33.59649108157888,0.6793467998504639,1.676062822341919,29.910391819317304,3000.0,0.6937656402587891,1.5925296545028689,29.84514057531192,3003.0,20189.579274892807,32757.81765937805,20189.579274892807,12565.69912815094,0.722074031829834,0.0 -57600,0.2822575,3.8690164,,,,,,,,,,,,,,,,, -57700,0.27513155,3.8997586,,,,,,,,,,,,,,,,, -57800,0.26501343,3.900924,,,,,,,,,,,,,,,,, -57900,0.24402983,3.7755628,,,,,,,,,,,,,,,,, -58000,0.42464197,3.820827,,,,,,,,,,,,,,,,, -58100,0.26683533,3.838157,,,,,,,,,,,,,,,,, -58200,0.25932217,3.7808597,,,,,,,,,,,,,,,,, -58300,0.24167843,3.8238132,,,,,,,,,,,,,,,,, -58400,0.2581717,3.835771,,,,,,,,,,,,,,,,, -58500,0.25411382,3.8606842,,,,,,,,,,,,,,,,, -58600,0.27180633,3.8890438,,,,,,,,,,,,,,,,, -58700,0.26651198,3.840267,,,,,,,,,,,,,,,,, -58800,0.2758994,3.8326821,,,,,,,,,,,,,,,,, -58900,0.26413533,3.775493,,,,,,,,,,,,,,,,, -59000,0.26890555,3.8723352,,,,,,,,,,,,,,,,, -59100,0.26503932,3.8449903,,,,,,,,,,,,,,,,, -59200,0.29642168,3.842021,,,,,,,,,,,,,,,,, -59300,0.30569717,3.8720121,,,,,,,,,,,,,,,,, -59400,0.26563013,3.8324943,,,,,,,,,,,,,,,,, -59500,0.26659665,3.8319685,,,,,,,,,,,,,,,,, -59600,0.31077263,3.8516138,,,,,,,,,,,,,,,,, -59700,0.29838163,3.8188965,,,,,,,,,,,,,,,,, -59800,0.2881954,3.8153396,,,,,,,,,,,,,,,,, -59900,0.3220459,3.8094192,,,,,,,,,,,,,,,,, -59957,,,0.6659659147262573,1.7736477851867676,33.462413155122306,0.6806362867355347,1.6755925416946411,30.146721231820003,3000.0,0.6962756514549255,1.587439775466919,29.90336671635429,3003.0,21029.69254183769,34085.48671579361,21029.69254183769,13053.144666194916,0.7552039623260498,0.0 -60000,0.26273885,3.8619795,,,,,,,,,,,,,,,,, -60100,0.28849146,3.8549924,,,,,,,,,,,,,,,,, -60200,0.30289075,3.8007836,,,,,,,,,,,,,,,,, -60300,0.47755665,4.058439,,,,,,,,,,,,,,,,, -60400,0.28252715,3.8877292,,,,,,,,,,,,,,,,, -60500,0.2806847,3.882089,,,,,,,,,,,,,,,,, -60600,0.258053,3.8473725,,,,,,,,,,,,,,,,, -60700,0.25821874,3.8432238,,,,,,,,,,,,,,,,, -60800,0.24782088,3.8489563,,,,,,,,,,,,,,,,, -60900,0.2542057,3.8805776,,,,,,,,,,,,,,,,, -61000,0.27610567,3.8417861,,,,,,,,,,,,,,,,, -61100,0.2598098,3.7839339,,,,,,,,,,,,,,,,, -61200,0.24519984,3.8281825,,,,,,,,,,,,,,,,, -61300,0.26965263,3.8693795,,,,,,,,,,,,,,,,, -61400,0.30249277,3.8302727,,,,,,,,,,,,,,,,, -61500,0.2562273,3.8310053,,,,,,,,,,,,,,,,, -61600,0.25591853,3.7993443,,,,,,,,,,,,,,,,, -61700,0.25854024,3.8118036,,,,,,,,,,,,,,,,, -61800,0.3566124,3.8667095,,,,,,,,,,,,,,,,, -61900,0.32075602,3.8262522,,,,,,,,,,,,,,,,, -62000,0.2717216,3.8657916,,,,,,,,,,,,,,,,, -62100,0.25398958,3.7767854,,,,,,,,,,,,,,,,, -62200,0.29602942,3.8451958,,,,,,,,,,,,,,,,, -62300,0.29409644,3.767007,,,,,,,,,,,,,,,,, -62356,,,0.6661973595619202,1.785617709159851,33.18605866469831,0.6803635358810425,1.675857663154602,29.76717653235124,3000.0,0.6968218088150024,1.5876115560531616,29.91820395719109,3003.0,21869.68085551262,35374.082418203354,21869.68085551262,13501.640924453735,0.7891860008239746,0.0 -62400,0.24702486,3.7893682,,,,,,,,,,,,,,,,, -62500,0.29563636,3.8096743,,,,,,,,,,,,,,,,, -62600,0.2993174,3.814311,,,,,,,,,,,,,,,,, -62700,0.31286013,3.7899,,,,,,,,,,,,,,,,, -62800,0.2680117,3.842303,,,,,,,,,,,,,,,,, -62900,0.2509497,3.7615612,,,,,,,,,,,,,,,,, -63000,0.28603536,3.877984,,,,,,,,,,,,,,,,, -63100,0.26417807,3.7842588,,,,,,,,,,,,,,,,, -63200,0.28370687,3.8267481,,,,,,,,,,,,,,,,, -63300,0.27803534,3.8570042,,,,,,,,,,,,,,,,, -63400,0.33283043,3.9137654,,,,,,,,,,,,,,,,, -63500,0.2670257,3.816608,,,,,,,,,,,,,,,,, -63600,0.29148987,3.791512,,,,,,,,,,,,,,,,, -63700,0.26819235,3.7945,,,,,,,,,,,,,,,,, -63800,0.3223262,3.91402,,,,,,,,,,,,,,,,, -63900,0.25115457,3.834349,,,,,,,,,,,,,,,,, -64000,0.31358203,3.840465,,,,,,,,,,,,,,,,, -64100,0.2950527,3.9020972,,,,,,,,,,,,,,,,, -64200,0.2812823,3.8705819,,,,,,,,,,,,,,,,, -64300,0.27591386,3.8082058,,,,,,,,,,,,,,,,, -64400,0.2967991,3.8551762,,,,,,,,,,,,,,,,, -64500,0.27549717,3.843357,,,,,,,,,,,,,,,,, -64600,0.27051473,3.7775013,,,,,,,,,,,,,,,,, -64700,0.35059246,3.7626545,,,,,,,,,,,,,,,,, -64756,,,0.6719292402267456,1.71939218044281,33.31623273945491,0.6818762421607971,1.6507298946380615,30.07580176805558,3000.0,0.696577787399292,1.567109227180481,30.02749926620833,3003.0,22709.793394088745,36698.05253005028,22709.793394088745,13985.39007115364,0.8220915794372559,0.0 -64800,0.25937682,3.8252604,,,,,,,,,,,,,,,,, -64900,0.31659073,3.8053713,,,,,,,,,,,,,,,,, -65000,0.29413867,3.8462527,,,,,,,,,,,,,,,,, -65100,0.26424032,3.8109539,,,,,,,,,,,,,,,,, -65200,0.27729017,3.8533754,,,,,,,,,,,,,,,,, -65300,0.27764922,3.8201044,,,,,,,,,,,,,,,,, -65400,0.31542382,3.8689938,,,,,,,,,,,,,,,,, -65500,0.26388133,3.7946205,,,,,,,,,,,,,,,,, -65600,0.31575778,3.8421428,,,,,,,,,,,,,,,,, -65700,0.28019357,3.8186133,,,,,,,,,,,,,,,,, -65800,0.26726678,3.8102446,,,,,,,,,,,,,,,,, -65900,0.27997944,3.8028557,,,,,,,,,,,,,,,,, -66000,0.26190123,3.8183005,,,,,,,,,,,,,,,,, -66100,0.28285295,3.800263,,,,,,,,,,,,,,,,, -66200,0.33111787,3.7976758,,,,,,,,,,,,,,,,, -66300,0.32621047,3.7479422,,,,,,,,,,,,,,,,, -66400,0.25599197,3.7709503,,,,,,,,,,,,,,,,, -66500,0.2674576,3.8340406,,,,,,,,,,,,,,,,, -66600,0.2582287,3.7946365,,,,,,,,,,,,,,,,, -66700,0.27072293,3.8050296,,,,,,,,,,,,,,,,, -66800,0.25789702,3.7553003,,,,,,,,,,,,,,,,, -66900,0.28048402,3.8328395,,,,,,,,,,,,,,,,, -67000,0.26294768,3.8477767,,,,,,,,,,,,,,,,, -67100,0.27937123,3.7598672,,,,,,,,,,,,,,,,, -67155,,,0.6676704287528992,1.753592014312744,32.8873355357405,0.6813802719116211,1.6592007875442505,29.90753576601113,3000.0,0.6964499354362488,1.5692729949951172,29.772931959805444,3003.0,23549.83774733544,37985.99623680115,23549.83774733544,14433.178615570068,0.8567020893096924,0.0 -67200,0.30430734,3.7803762,,,,,,,,,,,,,,,,, -67300,0.2766429,3.7799501,,,,,,,,,,,,,,,,, -67400,0.2774531,3.85923,,,,,,,,,,,,,,,,, -67500,0.31576884,3.790229,,,,,,,,,,,,,,,,, -67600,0.30320954,3.810415,,,,,,,,,,,,,,,,, -67700,0.28232014,3.8729534,,,,,,,,,,,,,,,,, -67800,0.28035253,3.851918,,,,,,,,,,,,,,,,, -67900,0.27706718,3.7809877,,,,,,,,,,,,,,,,, -68000,0.27567756,3.7507014,,,,,,,,,,,,,,,,, -68100,0.28600752,3.839242,,,,,,,,,,,,,,,,, -68200,0.2876087,3.842384,,,,,,,,,,,,,,,,, -68300,0.27162313,3.7792625,,,,,,,,,,,,,,,,, -68400,0.27684334,3.8340461,,,,,,,,,,,,,,,,, -68500,0.278625,3.7978055,,,,,,,,,,,,,,,,, -68600,0.2956519,3.813085,,,,,,,,,,,,,,,,, -68700,0.29654995,3.8475175,,,,,,,,,,,,,,,,, -68800,0.28239414,3.876857,,,,,,,,,,,,,,,,, -68900,0.2677967,3.7334352,,,,,,,,,,,,,,,,, -69000,0.28781822,3.8246577,,,,,,,,,,,,,,,,, -69100,0.2593597,3.8094609,,,,,,,,,,,,,,,,, -69200,0.2797255,3.809297,,,,,,,,,,,,,,,,, -69300,0.28260675,3.7932305,,,,,,,,,,,,,,,,, -69400,0.26822937,3.7590513,,,,,,,,,,,,,,,,, -69500,0.268042,3.7995932,,,,,,,,,,,,,,,,, -69554,,,0.6829418540000916,1.6505063772201538,34.84525919786184,0.6819506287574768,1.6456681489944458,30.26084758035537,3000.0,0.6994596719741821,1.5576194524765017,30.31156286628029,3003.0,24389.92008113861,39319.03587079048,24389.92008113861,14926.024219036102,0.891709566116333,0.0 -69600,0.2666105,3.7562485,,,,,,,,,,,,,,,,, -69700,0.26694047,3.8042305,,,,,,,,,,,,,,,,, -69800,0.27453795,3.8085027,,,,,,,,,,,,,,,,, -69900,0.29344383,3.8291621,,,,,,,,,,,,,,,,, -70000,0.2822352,3.7682092,,,,,,,,,,,,,,,,, -70100,0.2620981,3.7691236,,,,,,,,,,,,,,,,, -70200,0.2970275,3.809002,,,,,,,,,,,,,,,,, -70300,0.30310062,3.782517,,,,,,,,,,,,,,,,, -70400,0.26153067,3.798574,,,,,,,,,,,,,,,,, -70500,0.29194975,3.7612221,,,,,,,,,,,,,,,,, -70600,0.2740642,3.8179278,,,,,,,,,,,,,,,,, -70700,0.29763067,3.8394375,,,,,,,,,,,,,,,,, -70800,0.26490587,3.7872694,,,,,,,,,,,,,,,,, -70900,0.30275708,3.8769557,,,,,,,,,,,,,,,,, -71000,0.27990815,3.8125238,,,,,,,,,,,,,,,,, -71100,0.2735693,3.7610607,,,,,,,,,,,,,,,,, -71200,0.2845376,3.7983563,,,,,,,,,,,,,,,,, -71300,0.29488552,3.8488557,,,,,,,,,,,,,,,,, -71400,0.26227713,3.7459798,,,,,,,,,,,,,,,,, -71500,0.305874,3.7484338,,,,,,,,,,,,,,,,, -71600,0.29102203,3.785411,,,,,,,,,,,,,,,,, -71700,0.26621377,3.8061895,,,,,,,,,,,,,,,,, -71800,0.27269658,3.8686204,,,,,,,,,,,,,,,,, -71900,0.27935243,3.7583256,,,,,,,,,,,,,,,,, -71953,,,0.6751496195793152,1.7136439085006714,34.044371473527406,0.683302104473114,1.647717833518982,30.32466585747712,3000.0,0.6987624168395996,1.561434030532837,30.00881668864493,3003.0,25229.987417936325,40641.21749544144,25229.987417936325,15408.027456998823,0.9275662899017334,0.0 -72000,0.27738985,3.818932,,,,,,,,,,,,,,,,, -72100,0.2775816,3.814783,,,,,,,,,,,,,,,,, -72200,0.27316102,3.7748754,,,,,,,,,,,,,,,,, -72300,0.2653427,3.7732415,,,,,,,,,,,,,,,,, -72400,0.291276,3.7761662,,,,,,,,,,,,,,,,, -72500,0.29714912,3.8124943,,,,,,,,,,,,,,,,, -72600,0.2600926,3.6923723,,,,,,,,,,,,,,,,, -72700,0.28616557,3.8318918,,,,,,,,,,,,,,,,, -72800,0.31833735,3.7999265,,,,,,,,,,,,,,,,, -72900,0.30749276,3.7907915,,,,,,,,,,,,,,,,, -73000,0.30857155,3.801901,,,,,,,,,,,,,,,,, -73100,0.31257546,3.7629092,,,,,,,,,,,,,,,,, -73200,0.28162998,3.7969,,,,,,,,,,,,,,,,, -73300,0.28375843,3.7462053,,,,,,,,,,,,,,,,, -73400,0.27384278,3.8083973,,,,,,,,,,,,,,,,, -73500,0.2721026,3.7627692,,,,,,,,,,,,,,,,, -73600,0.29805064,3.7693043,,,,,,,,,,,,,,,,, -73700,0.3088017,3.787443,,,,,,,,,,,,,,,,, -73800,0.26561952,3.771445,,,,,,,,,,,,,,,,, -73900,0.28570062,3.796341,,,,,,,,,,,,,,,,, -74000,0.2963619,3.7787387,,,,,,,,,,,,,,,,, -74100,0.278066,3.781478,,,,,,,,,,,,,,,,, -74200,0.2923922,3.8084023,,,,,,,,,,,,,,,,, -74300,0.29127628,3.7558784,,,,,,,,,,,,,,,,, -74352,,,0.6767979860305786,1.709175705909729,33.86340278617494,0.6831409335136414,1.6428673267364502,30.18935325973564,3000.0,0.6999825835227966,1.5527477264404297,30.17969249679205,3003.0,26069.960821390152,41962.50837039948,26069.960821390152,15889.227751493454,0.9696736335754396,0.0 -74400,0.27070543,3.8490484,,,,,,,,,,,,,,,,, -74500,0.26219597,3.7976472,,,,,,,,,,,,,,,,, -74600,0.26684237,3.79782,,,,,,,,,,,,,,,,, -74700,0.28847295,3.8079612,,,,,,,,,,,,,,,,, -74800,0.2796099,3.7616906,,,,,,,,,,,,,,,,, -74900,0.28157467,3.7117386,,,,,,,,,,,,,,,,, -75000,0.28304097,3.7729654,,,,,,,,,,,,,,,,, -75100,0.28859192,3.7642336,,,,,,,,,,,,,,,,, -75200,0.28834984,3.8101346,,,,,,,,,,,,,,,,, -75300,0.28687885,3.7390125,,,,,,,,,,,,,,,,, -75400,0.27826744,3.695401,,,,,,,,,,,,,,,,, -75500,0.2733219,3.7239046,,,,,,,,,,,,,,,,, -75600,0.29190606,3.8139534,,,,,,,,,,,,,,,,, -75700,0.26239935,3.738528,,,,,,,,,,,,,,,,, -75800,0.29636288,3.7755656,,,,,,,,,,,,,,,,, -75900,0.35959977,3.8082323,,,,,,,,,,,,,,,,, -76000,0.28467563,3.7884176,,,,,,,,,,,,,,,,, -76100,0.3633282,3.7983348,,,,,,,,,,,,,,,,, -76200,0.28107184,3.7980707,,,,,,,,,,,,,,,,, -76300,0.29808003,3.8367214,,,,,,,,,,,,,,,,, -76400,0.30058536,3.8299878,,,,,,,,,,,,,,,,, -76500,0.3056483,3.7948031,,,,,,,,,,,,,,,,, -76600,0.33378214,3.7756956,,,,,,,,,,,,,,,,, -76700,0.2732786,3.6799798,,,,,,,,,,,,,,,,, -76750,,,0.6845331788063049,1.6683518886566162,34.86974028404848,0.683984100818634,1.6397377252578735,30.00869039508164,3000.0,0.6988554000854492,1.5575097799301147,30.020957346602792,3003.0,26910.131754636765,43282.06278991699,26910.131754636765,16368.497725009918,1.0062716007232666,0.0 -76800,0.28421932,3.8019114,,,,,,,,,,,,,,,,, -76900,0.277604,3.807478,,,,,,,,,,,,,,,,, -77000,0.2998724,3.82106,,,,,,,,,,,,,,,,, -77100,0.30683136,3.7827652,,,,,,,,,,,,,,,,, -77200,0.28516725,3.7976024,,,,,,,,,,,,,,,,, -77300,0.30435365,3.8238726,,,,,,,,,,,,,,,,, -77400,0.3088973,3.799093,,,,,,,,,,,,,,,,, -77500,0.27970928,3.767721,,,,,,,,,,,,,,,,, -77600,0.307098,3.7371001,,,,,,,,,,,,,,,,, -77700,0.2905563,3.7498884,,,,,,,,,,,,,,,,, -77800,0.29184306,3.7340248,,,,,,,,,,,,,,,,, -77900,0.29782164,3.834337,,,,,,,,,,,,,,,,, -78000,0.28704116,3.7845316,,,,,,,,,,,,,,,,, -78100,0.29558453,3.7568774,,,,,,,,,,,,,,,,, -78200,0.27903375,3.7712836,,,,,,,,,,,,,,,,, -78300,0.2837206,3.7684343,,,,,,,,,,,,,,,,, -78400,0.28459987,3.7249215,,,,,,,,,,,,,,,,, -78500,0.27971134,3.7475674,,,,,,,,,,,,,,,,, -78600,0.2964153,3.8361168,,,,,,,,,,,,,,,,, -78700,0.31607994,3.8010406,,,,,,,,,,,,,,,,, -78800,0.2887645,3.778059,,,,,,,,,,,,,,,,, -78900,0.275027,3.7734394,,,,,,,,,,,,,,,,, -79000,0.29163706,3.8047178,,,,,,,,,,,,,,,,, -79100,0.2846315,3.7685387,,,,,,,,,,,,,,,,, -79149,,,0.6779137253761292,1.6993038654327393,33.970007069879514,0.684653639793396,1.6361262798309326,30.250622985921563,3000.0,0.7008192539215088,1.5451098680496216,30.329277062653787,3003.0,27750.265002012253,44589.64325428009,27750.265002012253,16835.83377790451,1.0420207977294922,0.0 -79200,0.2786895,3.7579954,,,,,,,,,,,,,,,,, -79300,0.2739799,3.7726436,,,,,,,,,,,,,,,,, -79400,0.3057973,3.759753,,,,,,,,,,,,,,,,, -79500,0.2940297,3.7242684,,,,,,,,,,,,,,,,, -79600,0.2917377,3.7449238,,,,,,,,,,,,,,,,, -79700,0.27712166,3.811197,,,,,,,,,,,,,,,,, -79800,0.28586337,3.7517984,,,,,,,,,,,,,,,,, -79900,0.2805293,3.7863266,,,,,,,,,,,,,,,,, -80000,0.3129311,3.7239115,,,,,,,,,,,,,,,,, -80100,0.33122793,3.7951515,,,,,,,,,,,,,,,,, -80200,0.28957146,3.7712946,,,,,,,,,,,,,,,,, -80300,0.28728056,3.7913349,,,,,,,,,,,,,,,,, -80400,0.28592664,3.7657878,,,,,,,,,,,,,,,,, -80500,0.28476346,3.7690384,,,,,,,,,,,,,,,,, -80600,0.29941776,3.7837832,,,,,,,,,,,,,,,,, -80700,0.28879097,3.7891533,,,,,,,,,,,,,,,,, -80800,0.2942315,3.7382205,,,,,,,,,,,,,,,,, -80900,0.29296765,3.8119419,,,,,,,,,,,,,,,,, -81000,0.28486377,3.7848318,,,,,,,,,,,,,,,,, -81100,0.28310126,3.7278063,,,,,,,,,,,,,,,,, -81200,0.27526563,3.7260036,,,,,,,,,,,,,,,,, -81300,0.2861855,3.7002523,,,,,,,,,,,,,,,,, -81400,0.30590957,3.863876,,,,,,,,,,,,,,,,, -81500,0.27880785,3.737994,,,,,,,,,,,,,,,,, -81549,,,0.7079644799232483,1.5214544534683228,36.10748172508141,0.6850627660751343,1.627657413482666,30.422390302489923,3000.0,0.701830267906189,1.5316359996795654,30.135994300813103,3003.0,28590.42280459404,45886.93843173981,28590.42280459404,17292.853385925293,1.0856754779815674,0.0 -81600,0.3122119,3.789076,,,,,,,,,,,,,,,,, -81700,0.30363867,3.7361221,,,,,,,,,,,,,,,,, -81800,0.30842307,3.792188,,,,,,,,,,,,,,,,, -81900,0.27699485,3.7552798,,,,,,,,,,,,,,,,, -82000,0.31385258,3.7275345,,,,,,,,,,,,,,,,, -82100,0.29076356,3.7452197,,,,,,,,,,,,,,,,, -82200,0.29913068,3.715428,,,,,,,,,,,,,,,,, -82300,0.3183922,3.810677,,,,,,,,,,,,,,,,, -82400,0.3063583,3.7468388,,,,,,,,,,,,,,,,, -82500,0.29217917,3.726456,,,,,,,,,,,,,,,,, -82600,0.29924762,3.7390087,,,,,,,,,,,,,,,,, -82700,0.28392076,3.7568386,,,,,,,,,,,,,,,,, -82800,0.3249446,3.7225897,,,,,,,,,,,,,,,,, -82900,0.29826874,3.774185,,,,,,,,,,,,,,,,, -83000,0.30535007,3.71723,,,,,,,,,,,,,,,,, -83100,0.27484277,3.7300026,,,,,,,,,,,,,,,,, -83200,0.3176328,3.7500725,,,,,,,,,,,,,,,,, -83300,0.28939936,3.729928,,,,,,,,,,,,,,,,, -83400,0.31085992,3.7408643,,,,,,,,,,,,,,,,, -83500,0.31861743,3.738977,,,,,,,,,,,,,,,,, -83600,0.29938915,3.7653568,,,,,,,,,,,,,,,,, -83700,0.2904432,3.8031166,,,,,,,,,,,,,,,,, -83800,0.30836123,3.802354,,,,,,,,,,,,,,,,, -83900,0.29401067,3.7825732,,,,,,,,,,,,,,,,, -83947,,,0.6867976188659668,1.6339534521102903,34.510879718808624,0.6870218515396118,1.6197893619537354,30.32294582121328,3000.0,0.7033990025520325,1.5269720554351809,30.45241823104937,3003.0,29430.36559510231,47181.868070364,29430.36559510231,17747.714488744736,1.1382238864898682,0.0 -84000,0.32918596,3.758812,,,,,,,,,,,,,,,,, -84100,0.31402576,3.7623296,,,,,,,,,,,,,,,,, -84200,0.32671908,3.7676077,,,,,,,,,,,,,,,,, -84300,0.32353762,3.79667,,,,,,,,,,,,,,,,, -84400,0.31809676,3.775208,,,,,,,,,,,,,,,,, -84500,0.32608575,3.7977166,,,,,,,,,,,,,,,,, -84600,0.30728698,3.7566097,,,,,,,,,,,,,,,,, -84700,0.302056,3.7052038,,,,,,,,,,,,,,,,, -84800,0.31287387,3.740299,,,,,,,,,,,,,,,,, -84900,0.30595422,3.7406642,,,,,,,,,,,,,,,,, -85000,0.3153772,3.7605066,,,,,,,,,,,,,,,,, -85100,0.29707053,3.7440462,,,,,,,,,,,,,,,,, -85200,0.307483,3.7523246,,,,,,,,,,,,,,,,, -85300,0.29036534,3.8120546,,,,,,,,,,,,,,,,, -85400,0.30490723,3.7648256,,,,,,,,,,,,,,,,, -85500,0.31001922,3.810722,,,,,,,,,,,,,,,,, -85600,0.31140295,3.7662055,,,,,,,,,,,,,,,,, -85700,0.33147252,3.7651153,,,,,,,,,,,,,,,,, -85800,0.32633924,3.7492306,,,,,,,,,,,,,,,,, -85900,0.28340545,3.7656188,,,,,,,,,,,,,,,,, -86000,0.35667658,3.794966,,,,,,,,,,,,,,,,, -86100,0.3218991,3.7980895,,,,,,,,,,,,,,,,, -86200,0.3278935,3.7401228,,,,,,,,,,,,,,,,, -86300,0.28570977,3.7586033,,,,,,,,,,,,,,,,, -86346,,,0.6832473278045654,1.6666216850280762,34.83272741944827,0.6869474649429321,1.6207863092422483,30.546641891474536,3000.0,0.7028993368148804,1.5305290222167969,30.48136215527729,3003.0,30270.410665035248,48519.62218332291,30270.410665035248,18245.310992002487,1.1753811836242676,0.0 -86400,0.36158195,3.767042,,,,,,,,,,,,,,,,, -86500,0.33219856,3.7647195,,,,,,,,,,,,,,,,, -86600,0.3131436,3.7641335,,,,,,,,,,,,,,,,, -86700,0.34723222,3.7478516,,,,,,,,,,,,,,,,, -86800,0.299004,3.7025535,,,,,,,,,,,,,,,,, -86900,0.3233328,3.7329302,,,,,,,,,,,,,,,,, -87000,0.293264,3.7124019,,,,,,,,,,,,,,,,, -87100,0.31030723,3.7544727,,,,,,,,,,,,,,,,, -87200,0.28897402,3.7220924,,,,,,,,,,,,,,,,, -87300,0.30262607,3.706585,,,,,,,,,,,,,,,,, -87400,0.30806798,3.7227578,,,,,,,,,,,,,,,,, -87500,0.30334207,3.6838946,,,,,,,,,,,,,,,,, -87600,0.32315782,3.7566767,,,,,,,,,,,,,,,,, -87700,0.3148348,3.720151,,,,,,,,,,,,,,,,, -87800,0.2982507,3.689038,,,,,,,,,,,,,,,,, -87900,0.3019324,3.7453918,,,,,,,,,,,,,,,,, -88000,0.308287,3.73642,,,,,,,,,,,,,,,,, -88100,0.30618823,3.750843,,,,,,,,,,,,,,,,, -88200,0.31422108,3.7610767,,,,,,,,,,,,,,,,, -88300,0.33487877,3.740419,,,,,,,,,,,,,,,,, -88400,0.30958256,3.7188482,,,,,,,,,,,,,,,,, -88500,0.30920908,3.7210221,,,,,,,,,,,,,,,,, -88600,0.32073194,3.769011,,,,,,,,,,,,,,,,, -88700,0.36247095,3.7735772,,,,,,,,,,,,,,,,, -88746,,,0.7003049254417419,1.5541876554489136,35.16525340135614,0.6866498589515686,1.6179324388504028,30.46771344521242,3000.0,0.7047004699707031,1.5223288536071775,30.563865944400103,3003.0,31110.53629755973,49828.37257575989,31110.53629755973,18713.81881380081,1.2181808948516846,0.0 -88800,0.30559143,3.7386007,,,,,,,,,,,,,,,,, -88900,0.3101192,3.725838,,,,,,,,,,,,,,,,, -89000,0.30706084,3.70138,,,,,,,,,,,,,,,,, -89100,0.31565848,3.7190962,,,,,,,,,,,,,,,,, -89200,0.3111298,3.7601917,,,,,,,,,,,,,,,,, -89300,0.2956893,3.6696603,,,,,,,,,,,,,,,,, -89400,0.2861574,3.6693296,,,,,,,,,,,,,,,,, -89500,0.32707042,3.7748094,,,,,,,,,,,,,,,,, -89600,0.36249155,3.7512107,,,,,,,,,,,,,,,,, -89700,0.30547762,3.763575,,,,,,,,,,,,,,,,, -89800,0.3132756,3.720048,,,,,,,,,,,,,,,,, -89900,0.31833854,3.770832,,,,,,,,,,,,,,,,, -90000,0.29798087,3.756782,,,,,,,,,,,,,,,,, -90100,0.31005338,3.7722194,,,,,,,,,,,,,,,,, -90200,0.30793074,3.7327416,,,,,,,,,,,,,,,,, -90300,0.3429956,3.780282,,,,,,,,,,,,,,,,, -90400,0.34596667,3.755532,,,,,,,,,,,,,,,,, -90500,0.3194334,3.7096503,,,,,,,,,,,,,,,,, -90600,0.33687544,3.77026,,,,,,,,,,,,,,,,, -90700,0.30561274,3.7276227,,,,,,,,,,,,,,,,, -90800,0.3081863,3.785145,,,,,,,,,,,,,,,,, -90900,0.3181205,3.7162175,,,,,,,,,,,,,,,,, -91000,0.31078434,3.7267206,,,,,,,,,,,,,,,,, -91100,0.31704956,3.6661325,,,,,,,,,,,,,,,,, -91145,,,0.6862742900848389,1.6273705959320068,34.99582119196632,0.6887701153755188,1.612417221069336,30.406371078037647,3000.0,0.7067689299583435,1.5179078578948977,30.68588993681621,3003.0,31950.433703422543,51123.85071182251,31950.433703422543,19169.279168844223,1.2637646198272705,0.0 -91200,0.318402,3.7377603,,,,,,,,,,,,,,,,, -91300,0.3181758,3.6902065,,,,,,,,,,,,,,,,, -91400,0.33485496,3.705072,,,,,,,,,,,,,,,,, -91500,0.32414028,3.704048,,,,,,,,,,,,,,,,, -91600,0.3211741,3.7716453,,,,,,,,,,,,,,,,, -91700,0.3125961,3.718521,,,,,,,,,,,,,,,,, -91800,0.30078465,3.7083113,,,,,,,,,,,,,,,,, -91900,0.36689955,3.6809726,,,,,,,,,,,,,,,,, -92000,0.35085824,3.7555037,,,,,,,,,,,,,,,,, -92100,0.30872723,3.752651,,,,,,,,,,,,,,,,, -92200,0.33645898,3.7226503,,,,,,,,,,,,,,,,, -92300,0.3099797,3.6933649,,,,,,,,,,,,,,,,, -92400,0.33641425,3.7774544,,,,,,,,,,,,,,,,, -92500,0.32671982,3.7123353,,,,,,,,,,,,,,,,, -92600,0.3293212,3.711197,,,,,,,,,,,,,,,,, -92700,0.3075014,3.6939542,,,,,,,,,,,,,,,,, -92800,0.3218294,3.7115552,,,,,,,,,,,,,,,,, -92900,0.30037168,3.692259,,,,,,,,,,,,,,,,, -93000,0.31302688,3.760218,,,,,,,,,,,,,,,,, -93100,0.32718828,3.7732663,,,,,,,,,,,,,,,,, -93200,0.3221502,3.7387881,,,,,,,,,,,,,,,,, -93300,0.32989544,3.729125,,,,,,,,,,,,,,,,, -93400,0.32559782,3.7175915,,,,,,,,,,,,,,,,, -93500,0.3357018,3.681769,,,,,,,,,,,,,,,,, -93544,,,0.6899000406265259,1.6128188371658323,34.90819228231661,0.6882493495941162,1.6095622777938845,30.37207243465497,3000.0,0.7058160901069641,1.51614511013031,30.48390357612902,3003.0,32790.42817759514,52405.03318023682,32790.42817759514,19610.351407289505,1.3041329383850098,0.0 -93600,0.3384742,3.6971416,,,,,,,,,,,,,,,,, -93700,0.30213687,3.679249,,,,,,,,,,,,,,,,, -93800,0.31738928,3.6578493,,,,,,,,,,,,,,,,, -93900,0.3319649,3.7294185,,,,,,,,,,,,,,,,, -94000,0.32752004,3.6798465,,,,,,,,,,,,,,,,, -94100,0.33169568,3.6880834,,,,,,,,,,,,,,,,, -94200,0.32307905,3.7189808,,,,,,,,,,,,,,,,, -94300,0.33275488,3.7295983,,,,,,,,,,,,,,,,, -94400,0.31879944,3.7192576,,,,,,,,,,,,,,,,, -94500,0.32812,3.7330463,,,,,,,,,,,,,,,,, -94600,0.31815284,3.7445114,,,,,,,,,,,,,,,,, -94700,0.31768423,3.7005115,,,,,,,,,,,,,,,,, -94800,0.31159526,3.6874604,,,,,,,,,,,,,,,,, -94900,0.31957904,3.7026372,,,,,,,,,,,,,,,,, -95000,0.325797,3.6835895,,,,,,,,,,,,,,,,, -95100,0.33679163,3.7318404,,,,,,,,,,,,,,,,, -95200,0.31084105,3.7473717,,,,,,,,,,,,,,,,, -95300,0.33956227,3.7290177,,,,,,,,,,,,,,,,, -95400,0.33128652,3.6516457,,,,,,,,,,,,,,,,, -95500,0.3265758,3.656774,,,,,,,,,,,,,,,,, -95600,0.31633252,3.692401,,,,,,,,,,,,,,,,, -95700,0.34271657,3.7673671,,,,,,,,,,,,,,,,, -95800,0.3385883,3.7442405,,,,,,,,,,,,,,,,, -95900,0.34833008,3.766902,,,,,,,,,,,,,,,,, -95943,,,0.697588324546814,1.5702698230743408,35.739230605844355,0.6895760893821716,1.610058069229126,30.57486585549891,3000.0,0.7060717344284058,1.5157976150512695,30.849107356658006,3003.0,33630.38622021675,53708.71432638168,33630.38622021675,20073.95977115631,1.3421132564544678,0.0 -96000,0.33379623,3.691605,,,,,,,,,,,,,,,,, -96100,0.31337053,3.6856744,,,,,,,,,,,,,,,,, -96200,0.3377152,3.672749,,,,,,,,,,,,,,,,, -96300,0.32014686,3.7376761,,,,,,,,,,,,,,,,, -96400,0.36087093,3.7479415,,,,,,,,,,,,,,,,, -96500,0.32046166,3.687497,,,,,,,,,,,,,,,,, -96600,0.3943041,3.750204,,,,,,,,,,,,,,,,, -96700,0.33141115,3.6774156,,,,,,,,,,,,,,,,, -96800,0.34723097,3.7567027,,,,,,,,,,,,,,,,, -96900,0.33216923,3.7067006,,,,,,,,,,,,,,,,, -97000,0.34151682,3.711203,,,,,,,,,,,,,,,,, -97100,0.3140573,3.6700628,,,,,,,,,,,,,,,,, -97200,0.34046045,3.6686425,,,,,,,,,,,,,,,,, -97300,0.3380678,3.71152,,,,,,,,,,,,,,,,, -97400,0.34332073,3.7137346,,,,,,,,,,,,,,,,, -97500,0.3377951,3.7168558,,,,,,,,,,,,,,,,, -97600,0.32798702,3.7032065,,,,,,,,,,,,,,,,, -97700,0.33750072,3.6883373,,,,,,,,,,,,,,,,, -97800,0.3560994,3.7315571,,,,,,,,,,,,,,,,, -97900,0.3451451,3.6829698,,,,,,,,,,,,,,,,, -98000,0.34651038,3.759617,,,,,,,,,,,,,,,,, -98100,0.340398,3.6952944,,,,,,,,,,,,,,,,, -98200,0.3252098,3.72274,,,,,,,,,,,,,,,,, -98300,0.34469643,3.748184,,,,,,,,,,,,,,,,, -98343,,,0.6986111998558044,1.5686570405960083,35.27395572303037,0.6885965466499329,1.6070998907089231,30.290472276833565,3000.0,0.7058973908424377,1.5149260759353638,30.59627386107932,3003.0,34470.52443647385,55018.911603450775,34470.52443647385,20543.905728816982,1.3811841011047363,0.0 -98400,0.33399817,3.642625,,,,,,,,,,,,,,,,, -98500,0.32195145,3.7028809,,,,,,,,,,,,,,,,, -98600,0.33641925,3.6923056,,,,,,,,,,,,,,,,, -98700,0.3335925,3.666046,,,,,,,,,,,,,,,,, -98800,0.35757902,3.7602644,,,,,,,,,,,,,,,,, -98900,0.3392601,3.7421937,,,,,,,,,,,,,,,,, -99000,0.3419593,3.6960618,,,,,,,,,,,,,,,,, -99100,0.33669347,3.6835673,,,,,,,,,,,,,,,,, -99200,0.33820927,3.6642408,,,,,,,,,,,,,,,,, -99300,0.32146528,3.688496,,,,,,,,,,,,,,,,, -99400,0.33667612,3.672687,,,,,,,,,,,,,,,,, -99500,0.33409905,3.6842592,,,,,,,,,,,,,,,,, -99600,0.33290383,3.688814,,,,,,,,,,,,,,,,, -99700,0.32923847,3.6394222,,,,,,,,,,,,,,,,, -99800,0.34835514,3.710738,,,,,,,,,,,,,,,,, -99900,0.3444603,3.6399004,,,,,,,,,,,,,,,,, -100000,0.34079078,3.6518812,,,,,,,,,,,,,,,,, -100100,0.34925538,3.7243972,,,,,,,,,,,,,,,,, -100200,0.33073002,3.6853647,,,,,,,,,,,,,,,,, -100300,0.35115466,3.6551385,,,,,,,,,,,,,,,,, -100400,0.34651297,3.6975307,,,,,,,,,,,,,,,,, -100500,0.34952414,3.690619,,,,,,,,,,,,,,,,, -100600,0.37192506,3.6745949,,,,,,,,,,,,,,,,, -100700,0.34616244,3.666658,,,,,,,,,,,,,,,,, -100742,,,0.710821807384491,1.5044238567352295,37.11803812727856,0.6893652677536011,1.6094584465026855,30.65314764916486,3000.0,0.7069548964500427,1.5113070011138916,30.70333007489312,3003.0,35310.65348362923,56318.56501626968,35310.65348362923,21003.31357169152,1.4205613136291504,0.0 -100800,0.3576803,3.6689868,,,,,,,,,,,,,,,,, -100900,0.32488272,3.6952257,,,,,,,,,,,,,,,,, -101000,0.36725432,3.7179575,,,,,,,,,,,,,,,,, -101100,0.3351234,3.6971176,,,,,,,,,,,,,,,,, -101200,0.3797909,3.7103693,,,,,,,,,,,,,,,,, -101300,0.36019585,3.7414284,,,,,,,,,,,,,,,,, -101400,0.35007372,3.7118585,,,,,,,,,,,,,,,,, -101500,0.32846436,3.6952538,,,,,,,,,,,,,,,,, -101600,0.3473851,3.6760824,,,,,,,,,,,,,,,,, -101700,0.36106536,3.676316,,,,,,,,,,,,,,,,, -101800,0.35603222,3.7050414,,,,,,,,,,,,,,,,, -101900,0.3403672,3.7469604,,,,,,,,,,,,,,,,, -102000,0.34124443,3.705977,,,,,,,,,,,,,,,,, -102100,0.33238992,3.6762516,,,,,,,,,,,,,,,,, -102200,0.35871303,3.7108655,,,,,,,,,,,,,,,,, -102300,0.35554025,3.6586518,,,,,,,,,,,,,,,,, -102400,0.33947814,3.6865458,,,,,,,,,,,,,,,,, -102500,0.34007815,3.6751533,,,,,,,,,,,,,,,,, -102600,0.3528516,3.6749082,,,,,,,,,,,,,,,,, -102700,0.3763921,3.7642488,,,,,,,,,,,,,,,,, -102800,0.35223427,3.7034824,,,,,,,,,,,,,,,,, -102900,0.3566442,3.6951857,,,,,,,,,,,,,,,,, -103000,0.3344976,3.682671,,,,,,,,,,,,,,,,, -103100,0.35794544,3.6660244,,,,,,,,,,,,,,,,, -103142,,,0.7058877348899841,1.533977508544922,36.64237829046024,0.6901340484619141,1.6044197082519531,30.61682211553897,3000.0,0.7081401348114014,1.5082166194915771,30.88696074093417,3003.0,36150.85769796372,57621.27230811119,36150.85769796372,21465.7024269104,1.4596493244171145,0.0 -103200,0.35169843,3.7305954,,,,,,,,,,,,,,,,, -103300,0.34487692,3.7175252,,,,,,,,,,,,,,,,, -103400,0.35904795,3.7024105,,,,,,,,,,,,,,,,, -103500,0.32583725,3.656802,,,,,,,,,,,,,,,,, -103600,0.34700572,3.6618004,,,,,,,,,,,,,,,,, -103700,0.33673513,3.6342173,,,,,,,,,,,,,,,,, -103800,0.35458753,3.7192376,,,,,,,,,,,,,,,,, -103900,0.3368621,3.6834533,,,,,,,,,,,,,,,,, -104000,0.35603678,3.6968172,,,,,,,,,,,,,,,,, -104100,0.3443255,3.6295977,,,,,,,,,,,,,,,,, -104200,0.3624349,3.685126,,,,,,,,,,,,,,,,, -104300,0.34240633,3.66572,,,,,,,,,,,,,,,,, -104400,0.36074904,3.7218275,,,,,,,,,,,,,,,,, -104500,0.36431718,3.6852436,,,,,,,,,,,,,,,,, -104600,0.35552394,3.644061,,,,,,,,,,,,,,,,, -104700,0.33728346,3.6651444,,,,,,,,,,,,,,,,, -104800,0.37277603,3.6630683,,,,,,,,,,,,,,,,, -104900,0.35261786,3.6657264,,,,,,,,,,,,,,,,, -105000,0.35334802,3.63041,,,,,,,,,,,,,,,,, -105100,0.35361525,3.6841242,,,,,,,,,,,,,,,,, -105200,0.34762922,3.6962667,,,,,,,,,,,,,,,,, -105300,0.34422794,3.6859016,,,,,,,,,,,,,,,,, -105400,0.34857634,3.646323,,,,,,,,,,,,,,,,, -105500,0.38970122,3.6680357,,,,,,,,,,,,,,,,, -105542,,,0.7012184858322144,1.5534294843673706,36.010873962839284,0.6888568997383118,1.603817582130432,30.25613399959102,3000.0,0.7084190249443054,1.5058610439300537,30.523381697580785,3003.0,36991.28224873543,58921.207073926926,36991.28224873543,21925.09859728813,1.5006978511810305,0.0 -105600,0.36074242,3.6532822,,,,,,,,,,,,,,,,, -105700,0.36227182,3.625073,,,,,,,,,,,,,,,,, -105800,0.3970789,3.6744695,,,,,,,,,,,,,,,,, -105900,0.37448987,3.7131448,,,,,,,,,,,,,,,,, -106000,0.35968682,3.6809132,,,,,,,,,,,,,,,,, -106100,0.34300604,3.646347,,,,,,,,,,,,,,,,, -106200,0.3882209,3.6589823,,,,,,,,,,,,,,,,, -106300,0.3691702,3.6781802,,,,,,,,,,,,,,,,, -106400,0.3504517,3.6478286,,,,,,,,,,,,,,,,, -106500,0.3592409,3.6538541,,,,,,,,,,,,,,,,, -106600,0.3428946,3.621333,,,,,,,,,,,,,,,,, -106700,0.3449254,3.6399198,,,,,,,,,,,,,,,,, -106800,0.356723,3.6618376,,,,,,,,,,,,,,,,, -106900,0.360749,3.6305933,,,,,,,,,,,,,,,,, -107000,0.37622514,3.743108,,,,,,,,,,,,,,,,, -107100,0.36797825,3.6486497,,,,,,,,,,,,,,,,, -107200,0.36030194,3.660451,,,,,,,,,,,,,,,,, -107300,0.38210452,3.6856143,,,,,,,,,,,,,,,,, -107400,0.37280712,3.7016878,,,,,,,,,,,,,,,,, -107500,0.35725188,3.5988495,,,,,,,,,,,,,,,,, -107600,0.3754002,3.6539953,,,,,,,,,,,,,,,,, -107700,0.354225,3.6162336,,,,,,,,,,,,,,,,, -107800,0.3498626,3.6804802,,,,,,,,,,,,,,,,, -107900,0.3558837,3.6187418,,,,,,,,,,,,,,,,, -107941,,,0.7100006937980652,1.5051331520080566,36.99780640058015,0.689811646938324,1.60414719581604,30.53258891829982,3000.0,0.7079658508300781,1.504045844078064,30.685208275329504,3003.0,37831.47907757759,60206.88630771637,37831.47907757759,22370.46321105957,1.5430285930633545,0.0 -108000,0.3807998,3.6661866,,,,,,,,,,,,,,,,, -108100,0.35654956,3.6528103,,,,,,,,,,,,,,,,, -108200,0.3633523,3.671673,,,,,,,,,,,,,,,,, -108300,0.35705605,3.683162,,,,,,,,,,,,,,,,, -108400,0.3986245,3.659108,,,,,,,,,,,,,,,,, -108500,0.37911355,3.602964,,,,,,,,,,,,,,,,, -108600,0.35978106,3.6697035,,,,,,,,,,,,,,,,, -108700,0.3634743,3.6237867,,,,,,,,,,,,,,,,, -108800,0.36531425,3.5848873,,,,,,,,,,,,,,,,, -108900,0.3592927,3.642789,,,,,,,,,,,,,,,,, -109000,0.36861852,3.6447766,,,,,,,,,,,,,,,,, -109100,0.35629767,3.6263216,,,,,,,,,,,,,,,,, -109200,0.38443896,3.6466694,,,,,,,,,,,,,,,,, -109300,0.36899242,3.6356642,,,,,,,,,,,,,,,,, -109400,0.38173142,3.6713703,,,,,,,,,,,,,,,,, -109500,0.38252905,3.6910634,,,,,,,,,,,,,,,,, -109600,0.37068418,3.6082144,,,,,,,,,,,,,,,,, -109700,0.35750648,3.6576152,,,,,,,,,,,,,,,,, -109800,0.3720738,3.640754,,,,,,,,,,,,,,,,, -109900,0.37342522,3.6586401,,,,,,,,,,,,,,,,, -110000,0.3715471,3.6430783,,,,,,,,,,,,,,,,, -110100,0.3771239,3.641489,,,,,,,,,,,,,,,,, -110200,0.36125097,3.6689348,,,,,,,,,,,,,,,,, -110300,0.36909592,3.594199,,,,,,,,,,,,,,,,, -110340,,,0.7088562846183777,1.5129475593566897,35.960513350985146,0.6904191970825195,1.6018781661987305,30.43087989291936,3000.0,0.7091511487960815,1.5065263509750366,30.69295551513549,3003.0,38671.50674414635,61503.93488526344,38671.50674414635,22827.366498708725,1.5851430892944336,0.0 -110400,0.38533598,3.637964,,,,,,,,,,,,,,,,, -110500,0.39044562,3.7089558,,,,,,,,,,,,,,,,, -110600,0.3857575,3.62547,,,,,,,,,,,,,,,,, -110700,0.37378183,3.6851048,,,,,,,,,,,,,,,,, -110800,0.36273932,3.653078,,,,,,,,,,,,,,,,, -110900,0.38365704,3.6677132,,,,,,,,,,,,,,,,, -111000,0.36961484,3.6983907,,,,,,,,,,,,,,,,, -111100,0.3631925,3.6039603,,,,,,,,,,,,,,,,, -111200,0.35673156,3.5811422,,,,,,,,,,,,,,,,, -111300,0.3695339,3.6710887,,,,,,,,,,,,,,,,, -111400,0.40088323,3.6333737,,,,,,,,,,,,,,,,, -111500,0.37943077,3.6623678,,,,,,,,,,,,,,,,, -111600,0.37749568,3.6169877,,,,,,,,,,,,,,,,, -111700,0.38099694,3.5997195,,,,,,,,,,,,,,,,, -111800,0.36522043,3.6511033,,,,,,,,,,,,,,,,, -111900,0.37490347,3.6103358,,,,,,,,,,,,,,,,, -112000,0.38169816,3.6159277,,,,,,,,,,,,,,,,, -112100,0.38445377,3.6350772,,,,,,,,,,,,,,,,, -112200,0.37852615,3.6295972,,,,,,,,,,,,,,,,, -112300,0.35626587,3.6492434,,,,,,,,,,,,,,,,, -112400,0.37785652,3.6331756,,,,,,,,,,,,,,,,, -112500,0.40001452,3.6917589,,,,,,,,,,,,,,,,, -112600,0.39587337,3.6799564,,,,,,,,,,,,,,,,, -112700,0.3679755,3.629842,,,,,,,,,,,,,,,,, -112740,,,0.7243646383285522,1.4386584758758545,37.83614617089411,0.6897496581077576,1.5978442430496216,30.64392620751296,3000.0,0.7088257670402527,1.500476598739624,30.616537234812025,3003.0,39511.70456314087,62802.67998576164,39511.70456314087,23285.7987074852,1.6249699592590332,0.0 -112800,0.37971961,3.5895002,,,,,,,,,,,,,,,,, -112900,0.37540144,3.6384552,,,,,,,,,,,,,,,,, -113000,0.3684945,3.6205862,,,,,,,,,,,,,,,,, -113100,0.38244727,3.6639512,,,,,,,,,,,,,,,,, -113200,0.37379628,3.60971,,,,,,,,,,,,,,,,, -113300,0.39878267,3.6370916,,,,,,,,,,,,,,,,, -113400,0.37465927,3.6114113,,,,,,,,,,,,,,,,, -113500,0.36982673,3.6767845,,,,,,,,,,,,,,,,, -113600,0.37796074,3.6394432,,,,,,,,,,,,,,,,, -113700,0.3754869,3.6316464,,,,,,,,,,,,,,,,, -113800,0.38573116,3.6612482,,,,,,,,,,,,,,,,, -113900,0.39551666,3.6872463,,,,,,,,,,,,,,,,, -114000,0.3668624,3.5778396,,,,,,,,,,,,,,,,, -114100,0.3734673,3.6638083,,,,,,,,,,,,,,,,, -114200,0.3809846,3.6323621,,,,,,,,,,,,,,,,, -114300,0.38426647,3.6544933,,,,,,,,,,,,,,,,, -114400,0.37055272,3.6123567,,,,,,,,,,,,,,,,, -114500,0.36305904,3.608463,,,,,,,,,,,,,,,,, -114600,0.39999792,3.622753,,,,,,,,,,,,,,,,, -114700,0.39677733,3.6294997,,,,,,,,,,,,,,,,, -114800,0.37684578,3.638529,,,,,,,,,,,,,,,,, -114900,0.37856168,3.59865,,,,,,,,,,,,,,,,, -115000,0.38013622,3.606969,,,,,,,,,,,,,,,,, -115100,0.3896995,3.6058218,,,,,,,,,,,,,,,,, -115139,,,0.7164209485054016,1.4740687608718872,37.26717590462045,0.6905679702758789,1.598555564880371,30.91084348148604,3000.0,0.7087792754173279,1.4993302822113037,30.776656235656343,3003.0,40351.63377261162,64096.69275712967,40351.63377261162,23739.756477594376,1.674572467803955,0.0 -115200,0.39996237,3.5941434,,,,,,,,,,,,,,,,, -115300,0.41234407,3.6829844,,,,,,,,,,,,,,,,, -115400,0.37978387,3.6376843,,,,,,,,,,,,,,,,, -115500,0.40439898,3.633159,,,,,,,,,,,,,,,,, -115600,0.37053818,3.6340842,,,,,,,,,,,,,,,,, -115700,0.38350147,3.6227717,,,,,,,,,,,,,,,,, -115800,0.38936406,3.609992,,,,,,,,,,,,,,,,, -115900,0.380146,3.6776044,,,,,,,,,,,,,,,,, -116000,0.3807597,3.5708063,,,,,,,,,,,,,,,,, -116100,0.37798938,3.594624,,,,,,,,,,,,,,,,, -116200,0.39129817,3.641533,,,,,,,,,,,,,,,,, -116300,0.39006463,3.6196506,,,,,,,,,,,,,,,,, -116400,0.37378675,3.582781,,,,,,,,,,,,,,,,, -116500,0.38538584,3.6298442,,,,,,,,,,,,,,,,, -116600,0.38852698,3.560669,,,,,,,,,,,,,,,,, -116700,0.38943008,3.6263137,,,,,,,,,,,,,,,,, -116800,0.39138663,3.6536925,,,,,,,,,,,,,,,,, -116900,0.39496043,3.6672926,,,,,,,,,,,,,,,,, -117000,0.38495868,3.6698885,,,,,,,,,,,,,,,,, -117100,0.3927277,3.6484833,,,,,,,,,,,,,,,,, -117200,0.38877547,3.6474485,,,,,,,,,,,,,,,,, -117300,0.39561737,3.5923204,,,,,,,,,,,,,,,,, -117400,0.41392827,3.6720593,,,,,,,,,,,,,,,,, -117500,0.36088762,3.585177,,,,,,,,,,,,,,,,, -117538,,,0.7153463959693909,1.4766026735305786,36.84388016626823,0.6916590929031372,1.6009763479232788,30.74155176517884,3000.0,0.7088141441345215,1.501481056213379,30.903441013207743,3003.0,41191.59611129761,65385.50169849396,41191.59611129761,24188.48369383812,1.718095302581787,0.0 -117600,0.37311262,3.5950227,,,,,,,,,,,,,,,,, -117700,0.38632447,3.64122,,,,,,,,,,,,,,,,, -117800,0.3919341,3.6262884,,,,,,,,,,,,,,,,, -117900,0.40099937,3.6110728,,,,,,,,,,,,,,,,, -118000,0.37903106,3.5688736,,,,,,,,,,,,,,,,, -118100,0.40460062,3.675954,,,,,,,,,,,,,,,,, -118200,0.4016811,3.6154194,,,,,,,,,,,,,,,,, -118300,0.39196002,3.636394,,,,,,,,,,,,,,,,, -118400,0.37509367,3.6220715,,,,,,,,,,,,,,,,, -118500,0.38357428,3.6268933,,,,,,,,,,,,,,,,, -118600,0.3915116,3.6054251,,,,,,,,,,,,,,,,, -118700,0.40163255,3.6664622,,,,,,,,,,,,,,,,, -118800,0.4029542,3.6162698,,,,,,,,,,,,,,,,, -118900,0.38449395,3.5997782,,,,,,,,,,,,,,,,, -119000,0.3954399,3.6614802,,,,,,,,,,,,,,,,, -119100,0.40130657,3.6502786,,,,,,,,,,,,,,,,, -119200,0.40211272,3.6437178,,,,,,,,,,,,,,,,, -119300,0.38859588,3.6129267,,,,,,,,,,,,,,,,, -119400,0.38866356,3.589285,,,,,,,,,,,,,,,,, -119500,0.39851215,3.5743573,,,,,,,,,,,,,,,,, -119600,0.41062585,3.6029668,,,,,,,,,,,,,,,,, -119700,0.39188433,3.6085773,,,,,,,,,,,,,,,,, -119800,0.40479228,3.581828,,,,,,,,,,,,,,,,, -119900,0.3902159,3.634164,,,,,,,,,,,,,,,,, -119937,,,0.7232077717781067,1.438293695449829,37.65100512157615,0.6920310854911804,1.6015228033065796,30.75141096290305,3000.0,0.7088955044746399,1.4994897842407229,30.883720879674414,3003.0,42031.71455454826,66693.94711279869,42031.71455454826,24656.691277503967,1.7594947814941406,0.0 -120000,0.4018959,3.610453,,,,,,,,,,,,,,,,, -120100,0.41416326,3.6383069,,,,,,,,,,,,,,,,, -120200,0.3948126,3.6298718,,,,,,,,,,,,,,,,, -120300,0.38518512,3.6346288,,,,,,,,,,,,,,,,, -120400,0.39772674,3.5870826,,,,,,,,,,,,,,,,, -120500,0.40424472,3.6481583,,,,,,,,,,,,,,,,, -120600,0.3869653,3.6549432,,,,,,,,,,,,,,,,, -120700,0.3964739,3.5854118,,,,,,,,,,,,,,,,, -120800,0.38125557,3.6185017,,,,,,,,,,,,,,,,, -120900,0.39570028,3.6074858,,,,,,,,,,,,,,,,, -121000,0.40177628,3.625602,,,,,,,,,,,,,,,,, -121100,0.3959622,3.6729808,,,,,,,,,,,,,,,,, -121200,0.3820722,3.6425211,,,,,,,,,,,,,,,,, -121300,0.39243072,3.6353056,,,,,,,,,,,,,,,,, -121400,0.38542947,3.604057,,,,,,,,,,,,,,,,, -121500,0.39446425,3.5987291,,,,,,,,,,,,,,,,, -121600,0.4029112,3.6108239,,,,,,,,,,,,,,,,, -121700,0.3843715,3.6079712,,,,,,,,,,,,,,,,, -121800,0.40268856,3.6317916,,,,,,,,,,,,,,,,, -121900,0.4094695,3.6538994,,,,,,,,,,,,,,,,, -122000,0.4099176,3.6016934,,,,,,,,,,,,,,,,, -122100,0.3888113,3.562914,,,,,,,,,,,,,,,,, -122200,0.4062343,3.6345887,,,,,,,,,,,,,,,,, -122300,0.39743015,3.5772815,,,,,,,,,,,,,,,,, -122336,,,0.7199275493621826,1.4584884643554688,37.4064410559461,0.6925642490386963,1.5980294942855835,30.63594776497589,3000.0,0.70961594581604,1.4969851970672607,30.907040212287704,3003.0,42871.72847676277,68006.24787807465,42871.72847676277,25128.7705988884,1.8893790245056152,0.0 -122400,0.4011079,3.6384063,,,,,,,,,,,,,,,,, -122500,0.401906,3.5907185,,,,,,,,,,,,,,,,, -122600,0.3936414,3.636018,,,,,,,,,,,,,,,,, -122700,0.3728628,3.5713816,,,,,,,,,,,,,,,,, -122800,0.38881597,3.6437654,,,,,,,,,,,,,,,,, -122900,0.38678247,3.6242893,,,,,,,,,,,,,,,,, -123000,0.40679365,3.647403,,,,,,,,,,,,,,,,, -123100,0.3989342,3.5790503,,,,,,,,,,,,,,,,, -123200,0.3975521,3.651165,,,,,,,,,,,,,,,,, -123300,0.38762882,3.590284,,,,,,,,,,,,,,,,, -123400,0.3939668,3.6054077,,,,,,,,,,,,,,,,, -123500,0.40556034,3.6233766,,,,,,,,,,,,,,,,, -123600,0.3953379,3.6252813,,,,,,,,,,,,,,,,, -123700,0.3960405,3.6361217,,,,,,,,,,,,,,,,, -123800,0.3878425,3.567404,,,,,,,,,,,,,,,,, -123900,0.40028092,3.6174126,,,,,,,,,,,,,,,,, -124000,0.3990941,3.6137526,,,,,,,,,,,,,,,,, -124100,0.39125532,3.6293511,,,,,,,,,,,,,,,,, -124200,0.38798627,3.5668695,,,,,,,,,,,,,,,,, -124300,0.39684945,3.6352139,,,,,,,,,,,,,,,,, -124400,0.3898068,3.6056073,,,,,,,,,,,,,,,,, -124500,0.39339575,3.6190898,,,,,,,,,,,,,,,,, -124600,0.39924434,3.5763166,,,,,,,,,,,,,,,,, -124700,0.3903509,3.5782301,,,,,,,,,,,,,,,,, -124734,,,0.7197436094284058,1.4523401260375977,37.86308040490159,0.6918451189994812,1.6004658937454224,30.577418563951102,3000.0,0.7094067931175232,1.498136281967163,30.645121997646555,3003.0,43711.629980802536,69304.01886892319,43711.629980802536,25586.507713079453,1.9442377090454104,0.0 -124800,0.41225785,3.634661,,,,,,,,,,,,,,,,, -124900,0.38818356,3.6029365,,,,,,,,,,,,,,,,, -125000,0.38752303,3.598089,,,,,,,,,,,,,,,,, -125100,0.3994243,3.660062,,,,,,,,,,,,,,,,, -125200,0.39842966,3.5826366,,,,,,,,,,,,,,,,, -125300,0.3829066,3.5971208,,,,,,,,,,,,,,,,, -125400,0.4107871,3.655233,,,,,,,,,,,,,,,,, -125500,0.40631494,3.6355405,,,,,,,,,,,,,,,,, -125600,0.3879253,3.6000423,,,,,,,,,,,,,,,,, -125700,0.40125304,3.6142843,,,,,,,,,,,,,,,,, -125800,0.3926576,3.6161714,,,,,,,,,,,,,,,,, -125900,0.3796891,3.5218587,,,,,,,,,,,,,,,,, -126000,0.4125813,3.6352792,,,,,,,,,,,,,,,,, -126100,0.38278717,3.6264946,,,,,,,,,,,,,,,,, -126200,0.40464833,3.5621605,,,,,,,,,,,,,,,,, -126300,0.41495964,3.6478767,,,,,,,,,,,,,,,,, -126400,0.38643575,3.6277592,,,,,,,,,,,,,,,,, -126500,0.3901138,3.5531576,,,,,,,,,,,,,,,,, -126600,0.39720193,3.6285284,,,,,,,,,,,,,,,,, -126700,0.4087804,3.560195,,,,,,,,,,,,,,,,, -126800,0.39570555,3.6033802,,,,,,,,,,,,,,,,, -126900,0.3929281,3.6582289,,,,,,,,,,,,,,,,, -127000,0.37673035,3.5799706,,,,,,,,,,,,,,,,, -127100,0.40367255,3.6246655,,,,,,,,,,,,,,,,, -127133,,,0.723660409450531,1.441123127937317,37.82017963906209,0.6915971040725708,1.5994668006896973,30.58672769907038,3000.0,0.70961594581604,1.49761700630188,30.935873794008984,3003.0,44551.60100340843,70614.54214811325,44551.60100340843,26056.940609931946,1.98885178565979,0.0 -127200,0.40453768,3.6330187,,,,,,,,,,,,,,,,, -127300,0.3880589,3.569075,,,,,,,,,,,,,,,,, -127400,0.41585413,3.5830464,,,,,,,,,,,,,,,,, -127500,0.41115576,3.6284602,,,,,,,,,,,,,,,,, -127600,0.41236392,3.6440957,,,,,,,,,,,,,,,,, -127700,0.39057237,3.660212,,,,,,,,,,,,,,,,, -127800,0.38536987,3.6020343,,,,,,,,,,,,,,,,, -127900,0.40540552,3.6434367,,,,,,,,,,,,,,,,, -128000,0.39995894,3.5914617,,,,,,,,,,,,,,,,, -128100,0.40050018,3.6836548,,,,,,,,,,,,,,,,, -128200,0.3904177,3.6241968,,,,,,,,,,,,,,,,, -128300,0.40749383,3.6329846,,,,,,,,,,,,,,,,, -128400,0.39840326,3.577959,,,,,,,,,,,,,,,,, -128500,0.3892837,3.5723512,,,,,,,,,,,,,,,,, -128600,0.39295435,3.607802,,,,,,,,,,,,,,,,, -128700,0.39749646,3.613231,,,,,,,,,,,,,,,,, -128800,0.39670244,3.6240327,,,,,,,,,,,,,,,,, -128900,0.3931852,3.61251,,,,,,,,,,,,,,,,, -129000,0.3983548,3.602097,,,,,,,,,,,,,,,,, -129100,0.39015964,3.6646056,,,,,,,,,,,,,,,,, -129200,0.40864882,3.6353984,,,,,,,,,,,,,,,,, -129300,0.39542454,3.667044,,,,,,,,,,,,,,,,, -129400,0.41189674,3.6519356,,,,,,,,,,,,,,,,, -129500,0.39497378,3.5879076,,,,,,,,,,,,,,,,, -129532,,,0.7209928035736084,1.453560709953308,37.57740041402509,0.6917583346366882,1.5988129377365112,30.57694225255877,3000.0,0.709732174873352,1.4976327419281006,30.904820565100355,3003.0,45391.557307481766,71920.82846522331,45391.557307481766,26523.145493984222,2.03778076171875,0.0 -129600,0.39979184,3.6498594,,,,,,,,,,,,,,,,, -129700,0.41080967,3.662706,,,,,,,,,,,,,,,,, -129800,0.38495362,3.6048105,,,,,,,,,,,,,,,,, -129900,0.39313942,3.5989437,,,,,,,,,,,,,,,,, -130000,0.3799858,3.6182153,,,,,,,,,,,,,,,,, -130100,0.3981643,3.6454108,,,,,,,,,,,,,,,,, -130200,0.38277677,3.623148,,,,,,,,,,,,,,,,, -130300,0.39479998,3.6049273,,,,,,,,,,,,,,,,, -130400,0.4086161,3.560307,,,,,,,,,,,,,,,,, -130500,0.39392468,3.603871,,,,,,,,,,,,,,,,, -130600,0.38916543,3.561702,,,,,,,,,,,,,,,,, -130700,0.40126246,3.6148334,,,,,,,,,,,,,,,,, -130800,0.40585184,3.644369,,,,,,,,,,,,,,,,, -130900,0.37843585,3.549255,,,,,,,,,,,,,,,,, -131000,0.39495826,3.556795,,,,,,,,,,,,,,,,, -131100,0.38317722,3.563312,,,,,,,,,,,,,,,,, -131200,0.393987,3.6070848,,,,,,,,,,,,,,,,, -131300,0.39849785,3.6134179,,,,,,,,,,,,,,,,, -131400,0.3916584,3.6300051,,,,,,,,,,,,,,,,, -131500,0.40593043,3.6288135,,,,,,,,,,,,,,,,, -131600,0.38148704,3.592493,,,,,,,,,,,,,,,,, -131700,0.39692092,3.6126764,,,,,,,,,,,,,,,,, -131800,0.39456314,3.641684,,,,,,,,,,,,,,,,, -131900,0.38126868,3.5817182,,,,,,,,,,,,,,,,, -131931,,,0.7224260568618774,1.4456599950790403,37.93193093486411,0.6920434832572937,1.5994871854782104,30.50402400306384,3000.0,0.7099761962890625,1.4979794025421145,30.83779996163204,3003.0,46231.56835961342,73221.8060324192,46231.56835961342,26983.98964881897,2.0844168663024902,0.0 -132000,0.39743653,3.5885253,,,,,,,,,,,,,,,,, -132100,0.40488282,3.631298,,,,,,,,,,,,,,,,, -132200,0.40164894,3.6296194,,,,,,,,,,,,,,,,, -132300,0.40155166,3.6213424,,,,,,,,,,,,,,,,, -132400,0.40540743,3.6518433,,,,,,,,,,,,,,,,, -132500,0.39001036,3.5753026,,,,,,,,,,,,,,,,, -132600,0.41180292,3.635432,,,,,,,,,,,,,,,,, -132700,0.4139279,3.6578858,,,,,,,,,,,,,,,,, -132800,0.39383534,3.5810487,,,,,,,,,,,,,,,,, -132900,0.40364954,3.584775,,,,,,,,,,,,,,,,, -133000,0.39012173,3.6099665,,,,,,,,,,,,,,,,, -133100,0.3959219,3.5880196,,,,,,,,,,,,,,,,, -133200,0.39316025,3.553531,,,,,,,,,,,,,,,,, -133300,0.38905936,3.6705372,,,,,,,,,,,,,,,,, -133333,,,0.7234098315238953,1.439003586769104,37.50861974519182,0.6921550631523132,1.5993489027023315,30.494101371783117,3000.0,0.7098832130432129,1.4978277683258057,30.838390296957208,3003.0,46722.17216897011,74176.3343679905,46722.17216897011,27447.8187623024,2.1351213455200195,0.0 -133333,,,,,,,,,,,,,,46722.17216897011,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/eval_measurements.csv deleted file mode 100644 index f1cdc68b1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -839.9496130943298,0.0,27.4393572807312,1,0,27.4393572807312,0.0007088489946909,0.0,11.191027641296388,3003,867.3890187740326,0.0006245528929866,0.0,11.190462112426758,0.0004835649742744,0.0,11.190281867980955,3000 -1566.579701423645,0.0190737247467041,867.5393142700195,2397,0,867.5393142700195,0.3817093670368194,7.806855004409972,4.273231029510498,3003,2434.214878797531,0.4117948710918426,14.546588761941903,3.941603660583496,0.3955065608024597,9.821189947778166,4.074873924255371,3000 -2026.0431470870967,0.0441141128540039,1707.4862928390503,4792,0,1707.4862928390503,0.5415606498718262,18.69656506174673,2.6677498817443848,3003,3733.7244775295258,0.5426854491233826,24.877218045562664,2.662497043609619,0.5426343083381653,20.313745733230583,2.6319589614868164,3000 -2506.784266471863,0.0735259056091308,2547.5802307128906,7189,0,2547.5802307128906,0.5870431661605835,21.947922210608603,2.226987838745117,3003,5054.663430452347,0.5803403854370117,27.177722377753945,2.2744669914245605,0.585894763469696,23.69241279942485,2.227734565734864,3000 -2947.9040400981903,0.0987389087677002,3387.8171026706696,9587,0,3387.8171026706696,0.6092615127563477,23.120191808823755,2.007220268249512,3003,6336.122858285904,0.5935688018798828,28.43446243478617,2.1319386959075928,0.604071855545044,24.43123456805673,2.039424180984497,3000 -3681.867109060288,0.1276702880859375,4227.878267049789,11984,0,4227.878267049789,0.6242752075195312,24.5813471409335,1.8770338296890257,3003,7910.251308679581,0.6023198962211609,28.64165997028332,2.052376747131348,0.6189383864402771,25.24411501245022,1.9206935167312624,3000 -4101.813674926758,0.1563193798065185,5067.965177059174,14383,0,5067.965177059174,0.6378014087677002,25.51934415460533,1.7801294326782229,3003,9170.390238761902,0.6155896186828613,29.66881520325619,1.9335036277771,0.629217267036438,26.167779487244733,1.829842209815979,3000 -4562.548948049545,0.185197114944458,5908.19517326355,16781,0,5908.19517326355,0.6479228734970093,26.094720521677864,1.70824933052063,3003,10471.462261676788,0.6175857782363892,30.239096733865782,1.9224374294281008,0.6388513445854187,26.84825213597496,1.768695831298828,3000 -5157.970972537994,0.2151799201965332,6748.225657224655,19181,0,6748.225657224655,0.6533612608909607,26.776162031607065,1.6571162939071655,3003,11907.02089357376,0.6363543272018433,31.287817954743225,1.7792028188705444,0.6457824110984802,27.271873368897552,1.717884540557861,3000 -5742.355647802353,0.2489650249481201,7588.451881408691,21582,0,7588.451881408691,0.6578351259231567,26.86156974459988,1.6361119747161863,3003,13331.740639448166,0.6294018626213074,30.38335666819069,1.8236947059631348,0.6469727754592896,27.204733443076528,1.6930538415908811,3000 -6233.2626214027405,0.2775824069976806,8428.454867362976,23982,0,8428.454867362976,0.6590552926063538,26.74998812857728,1.6155636310577393,3003,14662.757348537443,0.6340451240539551,30.86761748790029,1.8066679239273071,0.6483242511749268,27.5362496664734,1.6817162036895752,3000 -6764.1959137916565,0.3061375617980957,9268.692746162416,26383,0,9268.692746162416,0.6617047190666199,26.932696320278875,1.606836199760437,3003,16034.033752679825,0.6357397437095642,30.912190908003147,1.7810266017913818,0.6519448161125183,27.51485381411525,1.6675649881362915,3000 -7257.295704364777,0.3352222442626953,10108.606403827667,28783,0,10108.606403827667,0.6621811985969543,26.89210431803698,1.59303081035614,3003,17367.151869773865,0.6305766105651855,30.657916208514465,1.8134877681732176,0.6531599164009094,27.818056014519808,1.6515889167785645,3000 -7744.511089324951,0.3644375801086426,10948.571624994278,31184,0,10948.571624994278,0.6656208634376526,27.360726859055152,1.5788705348968506,3003,18694.43784260749,0.6314585208892822,30.78093214753829,1.8155176639556885,0.6558877229690552,27.954856722763047,1.641958475112915,3000 -8278.055241584778,0.3943896293640136,11788.665621519089,33583,0,11788.665621519089,0.6670966148376465,27.57975050809596,1.5727591514587402,3003,20068.18571949005,0.6337587833404541,31.308028349740688,1.78791081905365,0.6552677750587463,28.209447885022723,1.638217806816101,3000 -8734.487503767014,0.4250726699829101,12628.895520687103,35984,0,12628.895520687103,0.6696415543556213,27.97125458605869,1.5568478107452393,3003,21364.958856105804,0.6375605463981628,30.86603327927146,1.7774264812469482,0.6572268009185791,28.12284049602379,1.6223275661468506,3000 -9242.88492488861,0.455068826675415,13469.028195142746,38385,0,13469.028195142746,0.669780969619751,27.636952264917564,1.5530155897140503,3003,22713.596930027008,0.6480917930603027,31.93014213013636,1.692330002784729,0.6595950126647949,28.11108498556029,1.6177055835723877,3000 -9768.36000418663,0.4869167804718017,14309.030643939972,40787,0,14309.030643939972,0.6694207191467285,27.76162602770913,1.5531401634216309,3003,24079.183569192886,0.6398926973342896,31.34876197331421,1.7486668825149536,0.6608969569206238,28.497071212994896,1.61137592792511,3000 -10269.365025520325,0.5226850509643555,15148.924597978592,43189,0,15148.924597978592,0.6715589165687561,28.04822035841773,1.5377026796340942,3003,25420.193618297577,0.6379035115242004,31.63035083470584,1.7647839784622192,0.6619136929512024,28.62113660665732,1.5984268188476562,3000 -10806.36244225502,0.5551190376281738,15989.009542703629,45591,0,15989.009542703629,0.6716750860214233,27.95658945953593,1.53374445438385,3003,26797.38473558426,0.645022988319397,31.1812755361264,1.712433695793152,0.6615293025970459,28.51240267229416,1.5967493057250977,3000 -11319.686262845991,0.5872867107391357,16829.091136455536,47993,0,16829.091136455536,0.6764046549797058,27.780123571158285,1.518065333366394,3003,28150.89732050896,0.6419897079467773,31.19208884900772,1.7327611446380615,0.6634387373924255,28.441895240859274,1.5877983570098877,3000 -11998.36673951149,0.6197023391723633,17669.225746631622,50395,0,17669.225746631622,0.6767880916595459,28.356528811797137,1.508074164390564,3003,29669.82399916649,0.6616469025611877,32.69124701519562,1.6047558784484863,0.6644430756568909,28.24018160352546,1.5778286457061768,3000 -12520.904906749724,0.6526608467102051,18509.355145454407,52798,0,18509.355145454407,0.6748591065406799,28.116153582756937,1.510031819343567,3003,31032.60077619553,0.6459488868713379,31.285477518950703,1.7108194828033447,0.6637115478515625,28.79268462181212,1.577656865119934,3000 -13096.69297504425,0.6853039264678955,19349.54166293144,55200,0,19349.54166293144,0.6769275665283203,28.07042551889989,1.498882532119751,3003,32448.686230182648,0.6462436318397522,31.66185194613176,1.7085494995117188,0.6645422577857971,28.52198308530932,1.571349024772644,3000 -13655.178454637527,0.7172021865844727,20189.47763967514,57602,0,20189.47763967514,0.6784266233444214,28.403415924124,1.494862079620361,3003,33847.21442198753,0.6559958457946777,31.92104954047709,1.6449164152145386,0.6649266481399536,28.65088127558128,1.5654634237289429,3000 -14164.931086301804,0.7506864070892334,21029.704471111298,60005,0,21029.704471111298,0.6812503933906555,28.37237056623999,1.4764920473098757,3003,35197.30215525627,0.6456266641616821,32.09485250326785,1.7141486406326294,0.6683239936828613,28.846310976430967,1.5589754581451416,3000 -14683.995180606842,0.7836964130401611,21869.76788258553,62407,0,21869.76788258553,0.6835861206054688,28.85938976173801,1.4696898460388184,3003,36556.538370132446,0.6473643183708191,31.688075634827747,1.701988935470581,0.6715973615646362,28.90730774156349,1.5442315340042114,3000 -15260.428012609482,0.8176324367523193,22710.00342679024,64810,0,22710.00342679024,0.682877242565155,28.835831003764778,1.4688782691955566,3003,37973.3166179657,0.6546890735626221,31.84385511875302,1.657198429107666,0.6724157333374023,29.134471317994127,1.5407817363739014,3000 -15787.419231653214,0.9314663410186768,23549.912605524063,67212,0,23549.912605524063,0.6839579343795776,28.923713488856706,1.4663946628570557,3003,39340.40902590752,0.6506806015968323,31.947221106489987,1.6883158683776855,0.6726140975952148,29.3831373186652,1.5377442836761477,3000 -16320.727837085724,0.9683539867401124,24389.983780145645,69614,0,24389.983780145645,0.6869211792945862,29.16384456800778,1.4553210735321045,3003,40713.90342450142,0.6651747822761536,32.645817061232485,1.5777531862258911,0.6728744506835938,29.578541313522177,1.529529690742493,3000 -16861.239188194275,1.0111916065216064,25229.91768527031,72016,0,25229.91768527031,0.6847946047782898,28.99919040780581,1.4506778717041016,3003,42094.46844172478,0.6537835001945496,32.53735942164496,1.6537264585494995,0.673072874546051,29.34280734448664,1.5194854736328125,3000 -17390.422934532166,1.0483558177947998,26069.80594444275,74418,0,26069.80594444275,0.6886991262435913,29.56834137146062,1.439121961593628,3003,43463.65474176407,0.6587724685668945,32.378273362747734,1.636032223701477,0.6745235323905945,29.499814008006418,1.510649561882019,3000 -17931.327827215195,1.0869407653808594,26909.785906791687,76820,0,26909.785906791687,0.6883272528648376,29.255923845563185,1.435570240020752,3003,44844.65379357338,0.6653993725776672,32.9791006066284,1.5844321250915527,0.6778588891029358,29.33848707278825,1.5037970542907717,3000 -18454.98840594292,1.123607158660889,27749.74374437332,79222,0,27749.74374437332,0.6905118823051453,29.21426525986103,1.4266692399978638,3003,46208.38493394852,0.6607912182807922,33.057591815314304,1.6076757907867432,0.6763338446617126,29.76166169576297,1.5020564794540403,3000 -18973.0031478405,1.1630005836486816,28589.722110033035,81624,0,28589.722110033035,0.6915345191955566,29.30637132482035,1.4193542003631592,3003,47566.49417424202,0.6834427714347839,34.51069994560347,1.4710636138916016,0.6774373650550842,29.57352650655724,1.4913984537124634,3000 -19553.72599005699,1.1998803615570068,29429.614921092987,84026,0,29429.614921092987,0.6908488869667053,29.32195293579914,1.4123516082763672,3003,48987.2216424942,0.6655948162078857,32.66122393028471,1.5831010341644287,0.675899863243103,29.64316058337916,1.4887354373931885,3000 -20071.82431006432,1.2375590801239014,30269.60338878632,86428,0,30269.60338878632,0.6907791495323181,29.4747702020092,1.4073405265808103,3003,50345.42161464691,0.667698323726654,33.187502389890206,1.5811502933502195,0.6793344020843506,29.7653434375351,1.481445074081421,3000 -20641.15272283554,1.275867938995361,31109.718526124954,88828,0,31109.718526124954,0.6958921551704407,29.70958983317945,1.3935723304748535,3003,51754.98230290413,0.6753185987472534,33.65455010911434,1.5222102403640747,0.6817150115966797,30.07307836815615,1.4678661823272705,3000 -21144.83209013939,1.315791130065918,31949.66242647171,91229,0,31949.66242647171,0.697681725025177,30.023480420024,1.3859319686889648,3003,53098.72299027443,0.6722777485847473,33.57607223090198,1.5488667488098145,0.6819506287574768,30.27933128209323,1.465559005737305,3000 -21668.04584169388,1.3538849353790283,32789.70082950592,93631,0,32789.70082950592,0.7006217241287231,30.23283622999288,1.3737126588821411,3003,54462.08965873718,0.6707874536514282,33.5174534912928,1.561752438545227,0.6827317476272583,30.06944812642009,1.4605400562286377,3000 -22205.843585014343,1.392345666885376,33629.65852046013,96033,0,33629.65852046013,0.6993434429168701,30.150542856499595,1.370069980621338,3003,55839.95797109604,0.6781882047653198,34.07090496618493,1.512836456298828,0.6862531304359436,30.29679901504323,1.4463741779327393,3000 -22739.84860920906,1.4309487342834473,34469.67932343483,98435,0,34469.67932343483,0.7017953991889954,30.47622243073185,1.3655441999435425,3003,57214.09830498696,0.6749489307403564,33.67612239774069,1.5274507999420166,0.6862407326698303,30.31411470974103,1.4482619762420654,3000 -23232.463057994843,1.4708812236785889,35309.65514802933,100836,0,35309.65514802933,0.7013770341873169,30.38689364950769,1.360811710357666,3003,58546.80599832535,0.6923024654388428,35.541552632794804,1.423957109451294,0.6876789927482605,30.672736533118787,1.4401698112487793,3000 -23779.142338991165,1.5111699104309082,36149.77913951874,103238,0,36149.77913951874,0.7033408880233765,30.52896680762646,1.354027271270752,3003,59933.7256834507,0.6829273700714111,34.769351800976146,1.4789576530456543,0.6884477734565735,30.403672649418127,1.4318045377731323,3000 -24277.60093665123,1.552889108657837,36989.887207746506,105640,0,36989.887207746506,0.7058508992195129,30.580010769399586,1.3480515480041504,3003,61272.41058278084,0.6826414465904236,34.5595366890984,1.486212134361267,0.689315676689148,30.46267469124044,1.428796410560608,3000 -24773.52156305313,1.6027710437774658,37829.96246767044,108043,0,37829.96246767044,0.7048399448394775,30.62602068083253,1.3408323526382446,3003,62608.53202152252,0.6920853853225708,35.49696534352309,1.4408038854599,0.6909027695655823,30.654077905274068,1.4207086563110352,3000 -25285.58715200424,1.6442315578460691,38670.09367990494,110446,0,38670.09367990494,0.7077334523200989,30.601082920116674,1.335753321647644,3003,63960.84578132629,0.6905348300933838,34.91691907999978,1.441590428352356,0.6907168030738831,30.70651994536399,1.420137882232666,3000 -25818.71938562393,1.6956148147583008,39509.97726178169,112848,0,39509.97726178169,0.7069665193557739,30.83949578884752,1.3334004878997805,3003,65333.98978304863,0.7087449431419373,36.47203853513563,1.3407204151153564,0.691969096660614,31.080921156801807,1.416434407234192,3000 -26345.901841640472,1.7358145713806152,40349.93993067741,115250,0,40349.93993067741,0.7082214951515198,30.625093720903884,1.327876091003418,3003,66701.25009989738,0.6991199254989624,36.12728996133859,1.3918790817260742,0.693630576133728,31.00381458821544,1.4105188846588137,3000 -26863.21662259102,1.7834343910217283,41189.89118170738,117651,0,41189.89118170738,0.7084422707557678,30.823761511515347,1.3249000310897827,3003,68058.64119935036,0.6992983222007751,35.97326922225457,1.395785570144653,0.6929610371589661,31.041739597257934,1.409582495689392,3000 -27391.812509059902,1.8251049518585205,42029.80330038071,120052,0,42029.80330038071,0.7096391916275024,30.95885245135478,1.3247169256210327,3003,69427.26662182808,0.7118774056434631,36.789703266642434,1.3286157846450806,0.6940025687217712,30.93815497869649,1.410871505737305,3000 -27920.49778151512,1.8757200241088867,42869.90353536606,122454,0,42869.90353536606,0.7113706469535828,30.76013239179441,1.3193117380142212,3003,70796.17909526825,0.707240104675293,36.60618759390239,1.3504246473312378,0.6941017508506775,31.11960561212031,1.4063982963562012,3000 -28439.25016117096,1.9196293354034424,43710.02218127251,124856,0,43710.02218127251,0.711661159992218,30.829254280160164,1.3174381256103516,3003,72155.17084789276,0.7027835845947266,36.60843246209514,1.3786816596984863,0.6942257285118103,31.05294464668553,1.4070228338241575,3000 -28948.995346546173,1.9638566970825195,44549.97412252426,127257,0,44549.97412252426,0.7116146683692932,30.96810588775437,1.3157283067703247,3003,73504.98891401291,0.7099313735961914,36.71479864667537,1.3409501314163208,0.6950068473815918,31.12469228435073,1.405134677886963,3000 -29457.83890938759,2.007380247116089,45390.17083954811,129659,0,45390.17083954811,0.7120562791824341,30.971645040855503,1.3163964748382568,3003,74854.14743614197,0.7101288437843323,36.566009419018926,1.3376935720443726,0.6951928734779358,31.22496853474892,1.4060457944869995,3000 -29957.364458322525,2.050343990325928,46230.35232543945,132060,0,46230.35232543945,0.7124397158622742,30.99279959879872,1.315784215927124,3003,76193.97706127167,0.7102059721946716,36.37749326792103,1.3387534618377686,0.6951432824134827,31.32584139390445,1.4058610200881958,3000 -30464.457669734955,2.0986831188201904,46675.375801324844,133333,0,46675.375801324844,0.7124048471450806,30.969127426676877,1.3156583309173584,3003,77146.18306612968,0.7138813138008118,36.535556096203386,1.3141133785247803,0.6952672600746155,31.22940684635679,1.4057084321975708,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/measurements.csv deleted file mode 100644 index 2316989ae..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.8829894,11.16369,,,,,,,,,,,,,,,,, -1,,,0.0006245528929866,11.190462112426758,0.0,0.0004835649742744,11.190281867980955,0.0,3000.0,0.0007088489946909,11.191027641296388,0.0,3003.0,27.4393572807312,867.3890187740326,27.4393572807312,839.9496130943298,0.0,0.0 -100,0.47063193,8.721491,,,,,,,,,,,,,,,,, -200,0.18907042,8.367546,,,,,,,,,,,,,,,,, -300,0.23125242,8.113975,,,,,,,,,,,,,,,,, -400,0.29426527,7.651283,,,,,,,,,,,,,,,,, -500,0.41446602,7.27998,,,,,,,,,,,,,,,,, -600,0.5810669,7.035875,,,,,,,,,,,,,,,,, -700,0.6190925,6.7818184,,,,,,,,,,,,,,,,, -800,0.6298009,6.5157795,,,,,,,,,,,,,,,,, -900,0.5615819,6.326859,,,,,,,,,,,,,,,,, -1000,0.60840523,5.989688,,,,,,,,,,,,,,,,, -1100,0.74087346,5.806454,,,,,,,,,,,,,,,,, -1200,0.73365974,5.5667467,,,,,,,,,,,,,,,,, -1300,1.0276644,5.4554224,,,,,,,,,,,,,,,,, -1400,0.58676517,5.289326,,,,,,,,,,,,,,,,, -1500,1.2533529,5.1178026,,,,,,,,,,,,,,,,, -1600,0.6876035,4.9449253,,,,,,,,,,,,,,,,, -1700,0.81475866,4.86411,,,,,,,,,,,,,,,,, -1800,0.86823994,4.6496425,,,,,,,,,,,,,,,,, -1900,1.3522755,4.5989327,,,,,,,,,,,,,,,,, -2000,1.1559585,4.49623,,,,,,,,,,,,,,,,, -2100,0.9250083,4.231876,,,,,,,,,,,,,,,,, -2200,0.88505894,4.1914983,,,,,,,,,,,,,,,,, -2300,0.9125306,4.0864553,,,,,,,,,,,,,,,,, -2397,,,0.4117948710918426,3.941603660583496,14.546588761941903,0.3955065608024597,4.074873924255371,9.821189947778166,3000.0,0.3817093670368194,4.273231029510498,7.806855004409972,3003.0,867.5393142700195,2434.214878797531,867.5393142700195,1566.579701423645,0.0190737247467041,0.0 -2400,0.77590483,3.9582236,,,,,,,,,,,,,,,,, -2500,1.0806755,3.8930845,,,,,,,,,,,,,,,,, -2600,0.83090377,3.7563152,,,,,,,,,,,,,,,,, -2700,0.87832844,3.7273984,,,,,,,,,,,,,,,,, -2800,0.98649085,3.5132306,,,,,,,,,,,,,,,,, -2900,0.73792034,3.575509,,,,,,,,,,,,,,,,, -3000,0.8712969,3.4831362,,,,,,,,,,,,,,,,, -3100,0.98668844,3.3339922,,,,,,,,,,,,,,,,, -3200,1.314854,3.2094152,,,,,,,,,,,,,,,,, -3300,0.92239875,3.1493056,,,,,,,,,,,,,,,,, -3400,0.70775825,3.1721117,,,,,,,,,,,,,,,,, -3500,1.0843437,3.0244448,,,,,,,,,,,,,,,,, -3600,1.0661601,3.047989,,,,,,,,,,,,,,,,, -3700,0.70450974,2.8967485,,,,,,,,,,,,,,,,, -3800,0.85974675,2.9723325,,,,,,,,,,,,,,,,, -3900,0.8841952,3.0301175,,,,,,,,,,,,,,,,, -4000,0.6709667,2.9151392,,,,,,,,,,,,,,,,, -4100,0.9795652,2.8688781,,,,,,,,,,,,,,,,, -4200,0.67318094,2.9061596,,,,,,,,,,,,,,,,, -4300,0.7403491,2.8368747,,,,,,,,,,,,,,,,, -4400,0.6746263,2.8435276,,,,,,,,,,,,,,,,, -4500,0.64222836,2.7183874,,,,,,,,,,,,,,,,, -4600,0.66964275,2.808038,,,,,,,,,,,,,,,,, -4700,0.71332467,2.6766124,,,,,,,,,,,,,,,,, -4792,,,0.5426854491233826,2.662497043609619,24.877218045562664,0.5426343083381653,2.6319589614868164,20.313745733230583,3000.0,0.5415606498718262,2.6677498817443848,18.69656506174673,3003.0,1707.4862928390503,3733.7244775295258,1707.4862928390503,2026.0431470870967,0.0441141128540039,0.0 -4800,0.62656665,2.6143355,,,,,,,,,,,,,,,,, -4900,0.67616296,2.6573546,,,,,,,,,,,,,,,,, -5000,0.619391,2.5854144,,,,,,,,,,,,,,,,, -5100,0.6364688,2.609963,,,,,,,,,,,,,,,,, -5200,0.6485516,2.5664518,,,,,,,,,,,,,,,,, -5300,0.7147272,2.5413167,,,,,,,,,,,,,,,,, -5400,0.6169868,2.6086602,,,,,,,,,,,,,,,,, -5500,0.595288,2.5276926,,,,,,,,,,,,,,,,, -5600,0.70929706,2.6003091,,,,,,,,,,,,,,,,, -5700,0.74939185,2.4764643,,,,,,,,,,,,,,,,, -5800,0.577338,2.4172916,,,,,,,,,,,,,,,,, -5900,0.6859408,2.547986,,,,,,,,,,,,,,,,, -6000,0.54237306,2.3518255,,,,,,,,,,,,,,,,, -6100,0.51811063,2.4254193,,,,,,,,,,,,,,,,, -6200,0.58732593,2.5145662,,,,,,,,,,,,,,,,, -6300,0.52202934,2.41911,,,,,,,,,,,,,,,,, -6400,0.6282639,2.4511483,,,,,,,,,,,,,,,,, -6500,0.6407616,2.4505763,,,,,,,,,,,,,,,,, -6600,0.4965836,2.4272568,,,,,,,,,,,,,,,,, -6700,0.48798016,2.4404292,,,,,,,,,,,,,,,,, -6800,0.5409683,2.3522246,,,,,,,,,,,,,,,,, -6900,0.5055803,2.3869305,,,,,,,,,,,,,,,,, -7000,0.45982954,2.345891,,,,,,,,,,,,,,,,, -7100,0.47231007,2.3299994,,,,,,,,,,,,,,,,, -7189,,,0.5803403854370117,2.2744669914245605,27.177722377753945,0.585894763469696,2.227734565734864,23.69241279942485,3000.0,0.5870431661605835,2.226987838745117,21.947922210608603,3003.0,2547.5802307128906,5054.663430452347,2547.5802307128906,2506.784266471863,0.0735259056091308,0.0 -7200,0.4940782,2.3289547,,,,,,,,,,,,,,,,, -7300,0.41980946,2.2880569,,,,,,,,,,,,,,,,, -7400,0.5877674,2.239927,,,,,,,,,,,,,,,,, -7500,0.5194919,2.3054101,,,,,,,,,,,,,,,,, -7600,0.80334955,2.2888985,,,,,,,,,,,,,,,,, -7700,0.4494245,2.2649777,,,,,,,,,,,,,,,,, -7800,0.4238925,2.2674954,,,,,,,,,,,,,,,,, -7900,0.46080363,2.2310863,,,,,,,,,,,,,,,,, -8000,0.40318784,2.1985543,,,,,,,,,,,,,,,,, -8100,0.41042444,2.2460837,,,,,,,,,,,,,,,,, -8200,0.4044199,2.1460443,,,,,,,,,,,,,,,,, -8300,0.40898383,2.3071184,,,,,,,,,,,,,,,,, -8400,0.38648707,2.185678,,,,,,,,,,,,,,,,, -8500,0.47928345,2.1968129,,,,,,,,,,,,,,,,, -8600,0.47696075,2.2271123,,,,,,,,,,,,,,,,, -8700,0.41491133,2.2285793,,,,,,,,,,,,,,,,, -8800,0.4366208,2.1617098,,,,,,,,,,,,,,,,, -8900,0.47305262,2.2275133,,,,,,,,,,,,,,,,, -9000,0.43691167,2.2386258,,,,,,,,,,,,,,,,, -9100,0.4348764,2.157909,,,,,,,,,,,,,,,,, -9200,0.36964622,2.2773995,,,,,,,,,,,,,,,,, -9300,0.34938386,2.119808,,,,,,,,,,,,,,,,, -9400,0.39589867,2.1692402,,,,,,,,,,,,,,,,, -9500,0.37685674,2.1709032,,,,,,,,,,,,,,,,, -9587,,,0.5935688018798828,2.1319386959075928,28.43446243478617,0.604071855545044,2.039424180984497,24.43123456805673,3000.0,0.6092615127563477,2.007220268249512,23.120191808823755,3003.0,3387.8171026706696,6336.122858285904,3387.8171026706696,2947.9040400981903,0.0987389087677002,0.0 -9600,0.3659065,2.1642354,,,,,,,,,,,,,,,,, -9700,0.3347875,2.085092,,,,,,,,,,,,,,,,, -9800,0.3193693,2.1256108,,,,,,,,,,,,,,,,, -9900,0.31999847,2.1027448,,,,,,,,,,,,,,,,, -10000,0.34064627,2.0752964,,,,,,,,,,,,,,,,, -10100,0.3387349,2.0751336,,,,,,,,,,,,,,,,, -10200,0.3428443,2.1062858,,,,,,,,,,,,,,,,, -10300,0.33449134,2.1216083,,,,,,,,,,,,,,,,, -10400,0.31359404,2.043808,,,,,,,,,,,,,,,,, -10500,0.3315193,2.1797667,,,,,,,,,,,,,,,,, -10600,0.3408848,2.1275952,,,,,,,,,,,,,,,,, -10700,0.3002039,2.2277188,,,,,,,,,,,,,,,,, -10800,0.35868824,2.1440744,,,,,,,,,,,,,,,,, -10900,0.33001655,2.0690258,,,,,,,,,,,,,,,,, -11000,0.28021967,2.1555636,,,,,,,,,,,,,,,,, -11100,0.29698157,2.0460408,,,,,,,,,,,,,,,,, -11200,0.34712663,2.134265,,,,,,,,,,,,,,,,, -11300,0.28064358,2.0417216,,,,,,,,,,,,,,,,, -11400,0.32105082,2.0296292,,,,,,,,,,,,,,,,, -11500,0.30710432,2.0685377,,,,,,,,,,,,,,,,, -11600,0.3002459,2.063652,,,,,,,,,,,,,,,,, -11700,0.2854681,2.222447,,,,,,,,,,,,,,,,, -11800,0.32326853,2.0306704,,,,,,,,,,,,,,,,, -11900,0.32347712,2.0944026,,,,,,,,,,,,,,,,, -11984,,,0.6023198962211609,2.052376747131348,28.64165997028332,0.6189383864402771,1.9206935167312624,25.24411501245022,3000.0,0.6242752075195312,1.8770338296890257,24.5813471409335,3003.0,4227.878267049789,7910.251308679581,4227.878267049789,3681.867109060288,0.1276702880859375,0.0 -12000,0.28283083,2.1152182,,,,,,,,,,,,,,,,, -12100,0.3127158,1.9594108,,,,,,,,,,,,,,,,, -12200,0.2696322,2.0279667,,,,,,,,,,,,,,,,, -12300,0.2775196,2.0615056,,,,,,,,,,,,,,,,, -12400,0.28766155,2.010132,,,,,,,,,,,,,,,,, -12500,0.28016496,2.0472307,,,,,,,,,,,,,,,,, -12600,0.29517785,2.0304153,,,,,,,,,,,,,,,,, -12700,0.26985353,2.0379171,,,,,,,,,,,,,,,,, -12800,0.27127513,1.9695181,,,,,,,,,,,,,,,,, -12900,0.2800339,2.052916,,,,,,,,,,,,,,,,, -13000,0.3156882,2.0909147,,,,,,,,,,,,,,,,, -13100,0.2679443,2.060716,,,,,,,,,,,,,,,,, -13200,0.32993114,2.093401,,,,,,,,,,,,,,,,, -13300,0.29565766,1.9886467,,,,,,,,,,,,,,,,, -13400,0.27844518,2.0389898,,,,,,,,,,,,,,,,, -13500,0.32640338,2.0034628,,,,,,,,,,,,,,,,, -13600,0.2871352,1.973017,,,,,,,,,,,,,,,,, -13700,0.33737314,2.005351,,,,,,,,,,,,,,,,, -13800,0.29096416,2.0386367,,,,,,,,,,,,,,,,, -13900,0.288577,2.0429764,,,,,,,,,,,,,,,,, -14000,0.32216164,1.9557259,,,,,,,,,,,,,,,,, -14100,0.29603443,1.9654688,,,,,,,,,,,,,,,,, -14200,0.34364793,1.9985527,,,,,,,,,,,,,,,,, -14300,0.266411,1.882274,,,,,,,,,,,,,,,,, -14383,,,0.6155896186828613,1.9335036277771,29.66881520325619,0.629217267036438,1.829842209815979,26.167779487244733,3000.0,0.6378014087677002,1.7801294326782229,25.51934415460533,3003.0,5067.965177059174,9170.390238761902,5067.965177059174,4101.813674926758,0.1563193798065185,0.0 -14400,0.30777702,1.9788713,,,,,,,,,,,,,,,,, -14500,0.29435366,1.9434524,,,,,,,,,,,,,,,,, -14600,0.33061382,1.9819223,,,,,,,,,,,,,,,,, -14700,0.28507057,1.9165856,,,,,,,,,,,,,,,,, -14800,0.27670467,1.9703016,,,,,,,,,,,,,,,,, -14900,0.33540997,1.974933,,,,,,,,,,,,,,,,, -15000,0.3649299,1.8801949,,,,,,,,,,,,,,,,, -15100,0.27917835,1.9036173,,,,,,,,,,,,,,,,, -15200,0.29955846,2.0745304,,,,,,,,,,,,,,,,, -15300,0.31795365,1.9831553,,,,,,,,,,,,,,,,, -15400,0.35103738,1.8547908,,,,,,,,,,,,,,,,, -15500,0.32101235,1.9911953,,,,,,,,,,,,,,,,, -15600,0.29308882,1.9909867,,,,,,,,,,,,,,,,, -15700,0.36924225,1.8915665,,,,,,,,,,,,,,,,, -15800,0.32197267,1.9635853,,,,,,,,,,,,,,,,, -15900,0.2721554,2.0161562,,,,,,,,,,,,,,,,, -16000,0.34141073,1.9538281,,,,,,,,,,,,,,,,, -16100,0.29799718,2.0133047,,,,,,,,,,,,,,,,, -16200,0.3085321,1.8703921,,,,,,,,,,,,,,,,, -16300,0.28812194,1.8975391,,,,,,,,,,,,,,,,, -16400,0.3446185,1.9081709,,,,,,,,,,,,,,,,, -16500,0.36085042,1.9330902,,,,,,,,,,,,,,,,, -16600,0.29630157,1.8969135,,,,,,,,,,,,,,,,, -16700,0.32188794,1.8573394,,,,,,,,,,,,,,,,, -16781,,,0.6175857782363892,1.9224374294281008,30.239096733865782,0.6388513445854187,1.768695831298828,26.84825213597496,3000.0,0.6479228734970093,1.70824933052063,26.094720521677864,3003.0,5908.19517326355,10471.462261676788,5908.19517326355,4562.548948049545,0.185197114944458,0.0 -16800,0.35147238,1.9555131,,,,,,,,,,,,,,,,, -16900,0.32163408,1.8826826,,,,,,,,,,,,,,,,, -17000,0.44069812,1.9861349,,,,,,,,,,,,,,,,, -17100,0.32321614,1.9454218,,,,,,,,,,,,,,,,, -17200,0.33534878,1.8525887,,,,,,,,,,,,,,,,, -17300,0.36233944,1.9422079,,,,,,,,,,,,,,,,, -17400,0.41273892,1.8610597,,,,,,,,,,,,,,,,, -17500,0.3567059,1.8909961,,,,,,,,,,,,,,,,, -17600,0.32270247,1.949429,,,,,,,,,,,,,,,,, -17700,0.35213068,1.8283333,,,,,,,,,,,,,,,,, -17800,0.46880433,1.9192809,,,,,,,,,,,,,,,,, -17900,0.34062275,1.8944623,,,,,,,,,,,,,,,,, -18000,0.33205935,1.9019381,,,,,,,,,,,,,,,,, -18100,0.3292328,1.9185008,,,,,,,,,,,,,,,,, -18200,0.53171295,1.8995782,,,,,,,,,,,,,,,,, -18300,0.37374675,1.8764136,,,,,,,,,,,,,,,,, -18400,0.3074759,1.8675307,,,,,,,,,,,,,,,,, -18500,0.39384714,1.9041507,,,,,,,,,,,,,,,,, -18600,0.3720667,1.8982754,,,,,,,,,,,,,,,,, -18700,0.45296392,1.8741642,,,,,,,,,,,,,,,,, -18800,0.42545363,1.9047713,,,,,,,,,,,,,,,,, -18900,0.42368394,1.8140297,,,,,,,,,,,,,,,,, -19000,0.35919353,1.9015834,,,,,,,,,,,,,,,,, -19100,0.40282708,1.8961846,,,,,,,,,,,,,,,,, -19181,,,0.6363543272018433,1.7792028188705444,31.287817954743225,0.6457824110984802,1.717884540557861,27.271873368897552,3000.0,0.6533612608909607,1.6571162939071655,26.776162031607065,3003.0,6748.225657224655,11907.02089357376,6748.225657224655,5157.970972537994,0.2151799201965332,0.0 -19200,0.35474887,1.8322594,,,,,,,,,,,,,,,,, -19300,0.3426311,1.8457083,,,,,,,,,,,,,,,,, -19400,0.3386116,1.9409306,,,,,,,,,,,,,,,,, -19500,0.4072713,1.9415089,,,,,,,,,,,,,,,,, -19600,0.3352693,1.872372,,,,,,,,,,,,,,,,, -19700,0.35767192,1.9806657,,,,,,,,,,,,,,,,, -19800,0.36539084,1.8528073,,,,,,,,,,,,,,,,, -19900,0.3592763,1.8838835,,,,,,,,,,,,,,,,, -20000,0.47031957,1.934931,,,,,,,,,,,,,,,,, -20100,0.4045487,1.808282,,,,,,,,,,,,,,,,, -20200,0.40750134,1.8904481,,,,,,,,,,,,,,,,, -20300,0.37353584,1.9069246,,,,,,,,,,,,,,,,, -20400,0.4225713,1.9836144,,,,,,,,,,,,,,,,, -20500,0.40322644,1.8540179,,,,,,,,,,,,,,,,, -20600,0.3884167,1.8934431,,,,,,,,,,,,,,,,, -20700,0.34554502,1.8319788,,,,,,,,,,,,,,,,, -20800,0.38944694,1.8672837,,,,,,,,,,,,,,,,, -20900,0.38775733,1.7817284,,,,,,,,,,,,,,,,, -21000,0.4201164,1.8637873,,,,,,,,,,,,,,,,, -21100,0.46448466,1.93592,,,,,,,,,,,,,,,,, -21200,0.40297905,1.9185079,,,,,,,,,,,,,,,,, -21300,0.44910824,1.893541,,,,,,,,,,,,,,,,, -21400,0.36525774,1.91284,,,,,,,,,,,,,,,,, -21500,0.49180314,1.9442976,,,,,,,,,,,,,,,,, -21582,,,0.6294018626213074,1.8236947059631348,30.38335666819069,0.6469727754592896,1.6930538415908811,27.204733443076528,3000.0,0.6578351259231567,1.6361119747161863,26.86156974459988,3003.0,7588.451881408691,13331.740639448166,7588.451881408691,5742.355647802353,0.2489650249481201,0.0 -21600,0.4592947,1.7925258,,,,,,,,,,,,,,,,, -21700,0.41491652,1.7875174,,,,,,,,,,,,,,,,, -21800,0.36266905,1.8297648,,,,,,,,,,,,,,,,, -21900,0.51473033,1.9322735,,,,,,,,,,,,,,,,, -22000,0.47283834,1.837931,,,,,,,,,,,,,,,,, -22100,0.38316685,1.9496253,,,,,,,,,,,,,,,,, -22200,0.44614315,1.9032347,,,,,,,,,,,,,,,,, -22300,0.39687386,1.793874,,,,,,,,,,,,,,,,, -22400,0.57144874,1.8134502,,,,,,,,,,,,,,,,, -22500,0.4208528,1.8261404,,,,,,,,,,,,,,,,, -22600,0.34356585,1.8060619,,,,,,,,,,,,,,,,, -22700,0.54947406,1.9429085,,,,,,,,,,,,,,,,, -22800,0.5243982,1.8326695,,,,,,,,,,,,,,,,, -22900,0.40542755,1.8450314,,,,,,,,,,,,,,,,, -23000,0.3830743,1.8378583,,,,,,,,,,,,,,,,, -23100,0.40794274,1.8579365,,,,,,,,,,,,,,,,, -23200,0.38914683,1.8662568,,,,,,,,,,,,,,,,, -23300,0.49155638,1.8932854,,,,,,,,,,,,,,,,, -23400,0.36708236,1.9134454,,,,,,,,,,,,,,,,, -23500,0.3815842,1.7668947,,,,,,,,,,,,,,,,, -23600,0.5564276,1.9331446,,,,,,,,,,,,,,,,, -23700,0.45132238,1.890878,,,,,,,,,,,,,,,,, -23800,0.38376263,1.8249898,,,,,,,,,,,,,,,,, -23900,0.40402898,1.7885348,,,,,,,,,,,,,,,,, -23982,,,0.6340451240539551,1.8066679239273071,30.86761748790029,0.6483242511749268,1.6817162036895752,27.5362496664734,3000.0,0.6590552926063538,1.6155636310577393,26.74998812857728,3003.0,8428.454867362976,14662.757348537443,8428.454867362976,6233.2626214027405,0.2775824069976806,0.0 -24000,0.38018477,1.7919582,,,,,,,,,,,,,,,,, -24100,0.4158105,1.8157091,,,,,,,,,,,,,,,,, -24200,0.42616105,1.8464882,,,,,,,,,,,,,,,,, -24300,0.45077327,1.8112906,,,,,,,,,,,,,,,,, -24400,0.36408767,1.8015103,,,,,,,,,,,,,,,,, -24500,0.39626852,1.7851698,,,,,,,,,,,,,,,,, -24600,0.35218063,1.7499437,,,,,,,,,,,,,,,,, -24700,0.43517134,1.806724,,,,,,,,,,,,,,,,, -24800,0.38984573,1.7984133,,,,,,,,,,,,,,,,, -24900,0.37894228,1.8521147,,,,,,,,,,,,,,,,, -25000,0.40058383,1.844677,,,,,,,,,,,,,,,,, -25100,0.37920249,1.8138059,,,,,,,,,,,,,,,,, -25200,0.3742449,1.7617714,,,,,,,,,,,,,,,,, -25300,0.4205041,1.9653596,,,,,,,,,,,,,,,,, -25400,0.43879333,1.8474461,,,,,,,,,,,,,,,,, -25500,0.44756603,1.8352602,,,,,,,,,,,,,,,,, -25600,0.38883644,1.8233532,,,,,,,,,,,,,,,,, -25700,0.41077545,1.8274548,,,,,,,,,,,,,,,,, -25800,0.4806645,1.9096825,,,,,,,,,,,,,,,,, -25900,0.45376816,1.8295236,,,,,,,,,,,,,,,,, -26000,0.5246255,1.7801341,,,,,,,,,,,,,,,,, -26100,0.37897006,1.8147218,,,,,,,,,,,,,,,,, -26200,0.45048398,1.8250744,,,,,,,,,,,,,,,,, -26300,0.4078918,1.8136208,,,,,,,,,,,,,,,,, -26383,,,0.6357397437095642,1.7810266017913818,30.912190908003147,0.6519448161125183,1.6675649881362915,27.51485381411525,3000.0,0.6617047190666199,1.606836199760437,26.932696320278875,3003.0,9268.692746162416,16034.033752679825,9268.692746162416,6764.1959137916565,0.3061375617980957,0.0 -26400,0.3993,1.8351302,,,,,,,,,,,,,,,,, -26500,0.40665618,1.7735213,,,,,,,,,,,,,,,,, -26600,0.59919167,1.7615547,,,,,,,,,,,,,,,,, -26700,0.42989382,1.8498926,,,,,,,,,,,,,,,,, -26800,0.45788103,1.8751683,,,,,,,,,,,,,,,,, -26900,0.45104408,1.8219048,,,,,,,,,,,,,,,,, -27000,0.3714533,1.7855738,,,,,,,,,,,,,,,,, -27100,0.42426184,1.7747461,,,,,,,,,,,,,,,,, -27200,0.43300042,1.7850952,,,,,,,,,,,,,,,,, -27300,0.42534158,1.736626,,,,,,,,,,,,,,,,, -27400,0.41260412,1.80313,,,,,,,,,,,,,,,,, -27500,0.40771592,1.7996829,,,,,,,,,,,,,,,,, -27600,0.5934062,1.8275647,,,,,,,,,,,,,,,,, -27700,0.41152906,1.8103147,,,,,,,,,,,,,,,,, -27800,0.77242607,1.8510727,,,,,,,,,,,,,,,,, -27900,0.8521705,1.7573822,,,,,,,,,,,,,,,,, -28000,0.4407233,1.9540077,,,,,,,,,,,,,,,,, -28100,0.43611863,1.7318634,,,,,,,,,,,,,,,,, -28200,0.4378506,1.7799084,,,,,,,,,,,,,,,,, -28300,0.4577,1.7969856,,,,,,,,,,,,,,,,, -28400,0.4265847,1.7775652,,,,,,,,,,,,,,,,, -28500,0.48884434,1.7404498,,,,,,,,,,,,,,,,, -28600,0.43875578,1.9135659,,,,,,,,,,,,,,,,, -28700,0.4015442,1.8106438,,,,,,,,,,,,,,,,, -28783,,,0.6305766105651855,1.8134877681732176,30.657916208514465,0.6531599164009094,1.6515889167785645,27.818056014519808,3000.0,0.6621811985969543,1.59303081035614,26.89210431803698,3003.0,10108.606403827667,17367.151869773865,10108.606403827667,7257.295704364777,0.3352222442626953,0.0 -28800,0.35105777,1.837627,,,,,,,,,,,,,,,,, -28900,0.3936862,1.9050777,,,,,,,,,,,,,,,,, -29000,0.37501106,1.7616405,,,,,,,,,,,,,,,,, -29100,0.4654813,1.7783694,,,,,,,,,,,,,,,,, -29200,0.43471685,1.8923699,,,,,,,,,,,,,,,,, -29300,0.43092883,1.8090596,,,,,,,,,,,,,,,,, -29400,0.41903955,1.905065,,,,,,,,,,,,,,,,, -29500,0.45033103,1.7173368,,,,,,,,,,,,,,,,, -29600,0.3898208,1.7866102,,,,,,,,,,,,,,,,, -29700,0.4089688,1.8069857,,,,,,,,,,,,,,,,, -29800,0.46792036,1.8098752,,,,,,,,,,,,,,,,, -29900,0.4651122,1.7780818,,,,,,,,,,,,,,,,, -30000,0.41193357,1.8003675,,,,,,,,,,,,,,,,, -30100,0.3999035,1.8306943,,,,,,,,,,,,,,,,, -30200,0.41947275,1.7686707,,,,,,,,,,,,,,,,, -30300,0.3984788,1.7671745,,,,,,,,,,,,,,,,, -30400,0.45910117,1.8505318,,,,,,,,,,,,,,,,, -30500,0.40516126,1.8237486,,,,,,,,,,,,,,,,, -30600,0.36878198,1.7366263,,,,,,,,,,,,,,,,, -30700,0.45335838,1.8261715,,,,,,,,,,,,,,,,, -30800,0.3851244,1.7484683,,,,,,,,,,,,,,,,, -30900,0.5850111,1.7579191,,,,,,,,,,,,,,,,, -31000,0.39281029,1.7747366,,,,,,,,,,,,,,,,, -31100,0.37338057,1.7828768,,,,,,,,,,,,,,,,, -31184,,,0.6314585208892822,1.8155176639556885,30.78093214753829,0.6558877229690552,1.641958475112915,27.954856722763047,3000.0,0.6656208634376526,1.5788705348968506,27.360726859055152,3003.0,10948.571624994278,18694.43784260749,10948.571624994278,7744.511089324951,0.3644375801086426,0.0 -31200,0.38916764,1.7462451,,,,,,,,,,,,,,,,, -31300,0.39073926,1.8297076,,,,,,,,,,,,,,,,, -31400,0.4193744,1.8078163,,,,,,,,,,,,,,,,, -31500,0.41126737,1.7471868,,,,,,,,,,,,,,,,, -31600,0.43035102,1.7556676,,,,,,,,,,,,,,,,, -31700,0.43139482,1.8457695,,,,,,,,,,,,,,,,, -31800,0.42056218,1.8108677,,,,,,,,,,,,,,,,, -31900,0.3847342,1.7865156,,,,,,,,,,,,,,,,, -32000,0.47050953,1.7957689,,,,,,,,,,,,,,,,, -32100,0.3985469,1.8066369,,,,,,,,,,,,,,,,, -32200,0.36379844,1.789744,,,,,,,,,,,,,,,,, -32300,0.38798445,1.710597,,,,,,,,,,,,,,,,, -32400,0.47108537,1.917722,,,,,,,,,,,,,,,,, -32500,0.38962397,1.7588849,,,,,,,,,,,,,,,,, -32600,0.43620884,1.7476969,,,,,,,,,,,,,,,,, -32700,0.42784637,1.7797035,,,,,,,,,,,,,,,,, -32800,0.40397292,1.8818307,,,,,,,,,,,,,,,,, -32900,0.451338,1.793152,,,,,,,,,,,,,,,,, -33000,0.3821141,1.811429,,,,,,,,,,,,,,,,, -33100,0.40375066,1.8329268,,,,,,,,,,,,,,,,, -33200,0.43087128,1.8783687,,,,,,,,,,,,,,,,, -33300,0.38430375,1.7743443,,,,,,,,,,,,,,,,, -33400,0.42888078,1.7835017,,,,,,,,,,,,,,,,, -33500,0.48022464,1.8800688,,,,,,,,,,,,,,,,, -33583,,,0.6337587833404541,1.78791081905365,31.308028349740688,0.6552677750587463,1.638217806816101,28.209447885022723,3000.0,0.6670966148376465,1.5727591514587402,27.57975050809596,3003.0,11788.665621519089,20068.18571949005,11788.665621519089,8278.055241584778,0.3943896293640136,0.0 -33600,0.556778,1.7791348,,,,,,,,,,,,,,,,, -33700,0.35845265,1.7513825,,,,,,,,,,,,,,,,, -33800,0.37498552,1.7755415,,,,,,,,,,,,,,,,, -33900,0.4258196,1.8225956,,,,,,,,,,,,,,,,, -34000,0.4353971,1.7912966,,,,,,,,,,,,,,,,, -34100,0.4020473,1.7609975,,,,,,,,,,,,,,,,, -34200,0.43347397,1.8742188,,,,,,,,,,,,,,,,, -34300,0.40797588,1.8441099,,,,,,,,,,,,,,,,, -34400,0.45349386,1.7290893,,,,,,,,,,,,,,,,, -34500,0.40102187,1.8323187,,,,,,,,,,,,,,,,, -34600,0.41069427,1.760427,,,,,,,,,,,,,,,,, -34700,0.42627263,1.7475365,,,,,,,,,,,,,,,,, -34800,0.45124358,1.8295124,,,,,,,,,,,,,,,,, -34900,0.45963982,1.8757341,,,,,,,,,,,,,,,,, -35000,0.3808322,1.7673218,,,,,,,,,,,,,,,,, -35100,0.40951318,1.814531,,,,,,,,,,,,,,,,, -35200,0.39410567,1.8084797,,,,,,,,,,,,,,,,, -35300,0.4118322,1.7616894,,,,,,,,,,,,,,,,, -35400,0.42756134,1.7709861,,,,,,,,,,,,,,,,, -35500,0.38523602,1.856645,,,,,,,,,,,,,,,,, -35600,0.47802457,1.7912533,,,,,,,,,,,,,,,,, -35700,0.4186177,1.8178036,,,,,,,,,,,,,,,,, -35800,0.42946973,1.7187552,,,,,,,,,,,,,,,,, -35900,0.41093946,1.7852377,,,,,,,,,,,,,,,,, -35984,,,0.6375605463981628,1.7774264812469482,30.86603327927146,0.6572268009185791,1.6223275661468506,28.12284049602379,3000.0,0.6696415543556213,1.5568478107452393,27.97125458605869,3003.0,12628.895520687103,21364.958856105804,12628.895520687103,8734.487503767014,0.4250726699829101,0.0 -36000,0.3897278,1.73013,,,,,,,,,,,,,,,,, -36100,0.48318225,1.8377372,,,,,,,,,,,,,,,,, -36200,0.46942878,1.7645227,,,,,,,,,,,,,,,,, -36300,0.40950462,1.7587932,,,,,,,,,,,,,,,,, -36400,0.3834377,1.7602359,,,,,,,,,,,,,,,,, -36500,0.43866134,1.8354133,,,,,,,,,,,,,,,,, -36600,0.5263871,1.7650245,,,,,,,,,,,,,,,,, -36700,0.5218618,1.744106,,,,,,,,,,,,,,,,, -36800,0.46531838,1.7527722,,,,,,,,,,,,,,,,, -36900,0.45615038,1.846012,,,,,,,,,,,,,,,,, -37000,0.43632355,1.8152052,,,,,,,,,,,,,,,,, -37100,0.39545485,1.6806328,,,,,,,,,,,,,,,,, -37200,0.43258232,1.7959838,,,,,,,,,,,,,,,,, -37300,0.43795636,1.7707013,,,,,,,,,,,,,,,,, -37400,0.3812672,1.8041278,,,,,,,,,,,,,,,,, -37500,0.4200254,1.8303058,,,,,,,,,,,,,,,,, -37600,0.40211472,1.7673773,,,,,,,,,,,,,,,,, -37700,0.40351892,1.7724487,,,,,,,,,,,,,,,,, -37800,0.41774875,1.8116647,,,,,,,,,,,,,,,,, -37900,0.41284716,1.8412064,,,,,,,,,,,,,,,,, -38000,0.35942563,1.7904496,,,,,,,,,,,,,,,,, -38100,0.44375452,1.7972703,,,,,,,,,,,,,,,,, -38200,0.42030904,1.7570702,,,,,,,,,,,,,,,,, -38300,0.4063331,1.7924386,,,,,,,,,,,,,,,,, -38385,,,0.6480917930603027,1.692330002784729,31.93014213013636,0.6595950126647949,1.6177055835723877,28.11108498556029,3000.0,0.669780969619751,1.5530155897140503,27.636952264917564,3003.0,13469.028195142746,22713.596930027008,13469.028195142746,9242.88492488861,0.455068826675415,0.0 -38400,0.4240293,1.7599564,,,,,,,,,,,,,,,,, -38500,0.4255255,1.7769592,,,,,,,,,,,,,,,,, -38600,0.4462286,1.8371056,,,,,,,,,,,,,,,,, -38700,0.44120768,1.6994869,,,,,,,,,,,,,,,,, -38800,0.4173905,1.7473066,,,,,,,,,,,,,,,,, -38900,0.38370275,1.731282,,,,,,,,,,,,,,,,, -39000,0.38464215,1.8135579,,,,,,,,,,,,,,,,, -39100,0.4440694,1.7298137,,,,,,,,,,,,,,,,, -39200,0.3922505,1.8334256,,,,,,,,,,,,,,,,, -39300,0.43069437,1.7707888,,,,,,,,,,,,,,,,, -39400,0.4127459,1.7915931,,,,,,,,,,,,,,,,, -39500,0.4204103,1.7297002,,,,,,,,,,,,,,,,, -39600,0.3958307,1.7713096,,,,,,,,,,,,,,,,, -39700,0.3632641,1.739767,,,,,,,,,,,,,,,,, -39800,0.43456063,1.7913347,,,,,,,,,,,,,,,,, -39900,0.44519684,1.7453154,,,,,,,,,,,,,,,,, -40000,0.39806926,1.8133639,,,,,,,,,,,,,,,,, -40100,0.37991416,1.7171756,,,,,,,,,,,,,,,,, -40200,0.3868043,1.7181592,,,,,,,,,,,,,,,,, -40300,0.40598628,1.7647259,,,,,,,,,,,,,,,,, -40400,0.41409844,1.7586238,,,,,,,,,,,,,,,,, -40500,0.47418794,1.6882749,,,,,,,,,,,,,,,,, -40600,0.39111492,1.8032833,,,,,,,,,,,,,,,,, -40700,0.38712168,1.7525309,,,,,,,,,,,,,,,,, -40787,,,0.6398926973342896,1.7486668825149536,31.34876197331421,0.6608969569206238,1.61137592792511,28.497071212994896,3000.0,0.6694207191467285,1.5531401634216309,27.76162602770913,3003.0,14309.030643939972,24079.183569192886,14309.030643939972,9768.36000418663,0.4869167804718017,0.0 -40800,0.5016987,1.6966116,,,,,,,,,,,,,,,,, -40900,0.38550317,1.780714,,,,,,,,,,,,,,,,, -41000,0.40793055,1.7690896,,,,,,,,,,,,,,,,, -41100,0.42857292,1.7973367,,,,,,,,,,,,,,,,, -41200,0.39693284,1.7718263,,,,,,,,,,,,,,,,, -41300,0.40904757,1.732009,,,,,,,,,,,,,,,,, -41400,0.44343606,1.648867,,,,,,,,,,,,,,,,, -41500,0.37317353,1.7212737,,,,,,,,,,,,,,,,, -41600,0.4482885,1.799299,,,,,,,,,,,,,,,,, -41700,0.4595106,1.7837418,,,,,,,,,,,,,,,,, -41800,0.42135617,1.7266479,,,,,,,,,,,,,,,,, -41900,0.39672118,1.7756889,,,,,,,,,,,,,,,,, -42000,0.3932979,1.7470497,,,,,,,,,,,,,,,,, -42100,0.40101624,1.7338468,,,,,,,,,,,,,,,,, -42200,0.41207817,1.8863722,,,,,,,,,,,,,,,,, -42300,0.4306944,1.7996072,,,,,,,,,,,,,,,,, -42400,0.436758,1.7248769,,,,,,,,,,,,,,,,, -42500,0.4059611,1.7840039,,,,,,,,,,,,,,,,, -42600,0.4113106,1.73349,,,,,,,,,,,,,,,,, -42700,0.37176013,1.819147,,,,,,,,,,,,,,,,, -42800,0.39268172,1.7720461,,,,,,,,,,,,,,,,, -42900,0.37475398,1.7195343,,,,,,,,,,,,,,,,, -43000,0.35707745,1.6626879,,,,,,,,,,,,,,,,, -43100,0.50971085,1.7600209,,,,,,,,,,,,,,,,, -43189,,,0.6379035115242004,1.7647839784622192,31.63035083470584,0.6619136929512024,1.5984268188476562,28.62113660665732,3000.0,0.6715589165687561,1.5377026796340942,28.04822035841773,3003.0,15148.924597978592,25420.193618297577,15148.924597978592,10269.365025520325,0.5226850509643555,0.0 -43200,0.4322658,1.7226794,,,,,,,,,,,,,,,,, -43300,0.42082414,1.709461,,,,,,,,,,,,,,,,, -43400,0.44771314,1.7558553,,,,,,,,,,,,,,,,, -43500,0.3862041,1.6946062,,,,,,,,,,,,,,,,, -43600,0.37153375,1.7846637,,,,,,,,,,,,,,,,, -43700,0.41569477,1.8021212,,,,,,,,,,,,,,,,, -43800,0.4696332,1.6992809,,,,,,,,,,,,,,,,, -43900,0.5154319,1.750956,,,,,,,,,,,,,,,,, -44000,0.5897429,1.6954848,,,,,,,,,,,,,,,,, -44100,0.4216983,1.7768917,,,,,,,,,,,,,,,,, -44200,0.43569705,1.7557713,,,,,,,,,,,,,,,,, -44300,0.47854,1.6481745,,,,,,,,,,,,,,,,, -44400,0.4581458,1.6966277,,,,,,,,,,,,,,,,, -44500,0.45724452,1.76087,,,,,,,,,,,,,,,,, -44600,0.41987157,1.756359,,,,,,,,,,,,,,,,, -44700,0.42054522,1.7226413,,,,,,,,,,,,,,,,, -44800,0.42015633,1.7949781,,,,,,,,,,,,,,,,, -44900,0.40263718,1.6860809,,,,,,,,,,,,,,,,, -45000,0.38625664,1.6610905,,,,,,,,,,,,,,,,, -45100,0.43398002,1.6950216,,,,,,,,,,,,,,,,, -45200,0.37550142,1.7411712,,,,,,,,,,,,,,,,, -45300,0.35499078,1.7469432,,,,,,,,,,,,,,,,, -45400,0.36564252,1.7243156,,,,,,,,,,,,,,,,, -45500,0.39901906,1.8225923,,,,,,,,,,,,,,,,, -45591,,,0.645022988319397,1.712433695793152,31.1812755361264,0.6615293025970459,1.5967493057250977,28.51240267229416,3000.0,0.6716750860214233,1.53374445438385,27.95658945953593,3003.0,15989.009542703629,26797.38473558426,15989.009542703629,10806.36244225502,0.5551190376281738,0.0 -45600,0.3837442,1.7244841,,,,,,,,,,,,,,,,, -45700,0.44022354,1.8297721,,,,,,,,,,,,,,,,, -45800,0.40496767,1.6993207,,,,,,,,,,,,,,,,, -45900,0.39682347,1.6881621,,,,,,,,,,,,,,,,, -46000,0.4130224,1.7175888,,,,,,,,,,,,,,,,, -46100,0.42867362,1.788752,,,,,,,,,,,,,,,,, -46200,0.3641042,1.6977991,,,,,,,,,,,,,,,,, -46300,0.3867662,1.7430726,,,,,,,,,,,,,,,,, -46400,0.43308824,1.8303292,,,,,,,,,,,,,,,,, -46500,0.37908208,1.8228601,,,,,,,,,,,,,,,,, -46600,0.4953173,1.8138686,,,,,,,,,,,,,,,,, -46700,0.43896386,1.8056074,,,,,,,,,,,,,,,,, -46800,0.40207383,1.7610314,,,,,,,,,,,,,,,,, -46900,0.34493136,1.7302171,,,,,,,,,,,,,,,,, -47000,0.37234214,1.7457846,,,,,,,,,,,,,,,,, -47100,0.4328731,1.6983888,,,,,,,,,,,,,,,,, -47200,0.44603232,1.7591903,,,,,,,,,,,,,,,,, -47300,0.44277552,1.7032976,,,,,,,,,,,,,,,,, -47400,0.38229457,1.7454128,,,,,,,,,,,,,,,,, -47500,0.45458505,1.7080244,,,,,,,,,,,,,,,,, -47600,0.43480715,1.7464105,,,,,,,,,,,,,,,,, -47700,0.5894698,1.7928072,,,,,,,,,,,,,,,,, -47800,0.42898318,1.7449272,,,,,,,,,,,,,,,,, -47900,0.39143318,1.7176805,,,,,,,,,,,,,,,,, -47993,,,0.6419897079467773,1.7327611446380615,31.19208884900772,0.6634387373924255,1.5877983570098877,28.441895240859274,3000.0,0.6764046549797058,1.518065333366394,27.780123571158285,3003.0,16829.091136455536,28150.89732050896,16829.091136455536,11319.686262845991,0.5872867107391357,0.0 -48000,0.48565835,1.77492,,,,,,,,,,,,,,,,, -48100,0.48232466,1.7443944,,,,,,,,,,,,,,,,, -48200,0.40406367,1.8052322,,,,,,,,,,,,,,,,, -48300,0.42512855,1.7465389,,,,,,,,,,,,,,,,, -48400,0.40338328,1.6885849,,,,,,,,,,,,,,,,, -48500,0.47553846,1.7117752,,,,,,,,,,,,,,,,, -48600,0.38510972,1.810916,,,,,,,,,,,,,,,,, -48700,0.44470868,1.7751882,,,,,,,,,,,,,,,,, -48800,0.39536506,1.6799341,,,,,,,,,,,,,,,,, -48900,0.38401645,1.6786625,,,,,,,,,,,,,,,,, -49000,0.41580695,1.7540419,,,,,,,,,,,,,,,,, -49100,0.41859186,1.7087991,,,,,,,,,,,,,,,,, -49200,0.45790654,1.6631114,,,,,,,,,,,,,,,,, -49300,0.37109175,1.6666648,,,,,,,,,,,,,,,,, -49400,0.4152782,1.70487,,,,,,,,,,,,,,,,, -49500,0.3797682,1.7881218,,,,,,,,,,,,,,,,, -49600,0.4503626,1.7167234,,,,,,,,,,,,,,,,, -49700,0.40297458,1.8113673,,,,,,,,,,,,,,,,, -49800,0.40348345,1.6913396,,,,,,,,,,,,,,,,, -49900,0.39240065,1.7648312,,,,,,,,,,,,,,,,, -50000,0.4385289,1.7156363,,,,,,,,,,,,,,,,, -50100,0.4630798,1.8468114,,,,,,,,,,,,,,,,, -50200,0.37583205,1.7371224,,,,,,,,,,,,,,,,, -50300,0.4486689,1.67651,,,,,,,,,,,,,,,,, -50395,,,0.6616469025611877,1.6047558784484863,32.69124701519562,0.6644430756568909,1.5778286457061768,28.24018160352546,3000.0,0.6767880916595459,1.508074164390564,28.356528811797137,3003.0,17669.225746631622,29669.82399916649,17669.225746631622,11998.36673951149,0.6197023391723633,0.0 -50400,0.42942017,1.7831248,,,,,,,,,,,,,,,,, -50500,0.3722666,1.7702973,,,,,,,,,,,,,,,,, -50600,0.42111427,1.7050935,,,,,,,,,,,,,,,,, -50700,0.42474303,1.7256188,,,,,,,,,,,,,,,,, -50800,0.45106912,1.7479773,,,,,,,,,,,,,,,,, -50900,0.41264576,1.7691104,,,,,,,,,,,,,,,,, -51000,0.40775305,1.7188486,,,,,,,,,,,,,,,,, -51100,0.35682368,1.6979015,,,,,,,,,,,,,,,,, -51200,0.39212248,1.775921,,,,,,,,,,,,,,,,, -51300,0.39160857,1.7437365,,,,,,,,,,,,,,,,, -51400,0.41881853,1.6912667,,,,,,,,,,,,,,,,, -51500,0.37762573,1.7197952,,,,,,,,,,,,,,,,, -51600,0.39896426,1.8110772,,,,,,,,,,,,,,,,, -51700,0.39177334,1.7179563,,,,,,,,,,,,,,,,, -51800,0.4205552,1.870194,,,,,,,,,,,,,,,,, -51900,0.3637161,1.6619477,,,,,,,,,,,,,,,,, -52000,0.4742402,1.6930494,,,,,,,,,,,,,,,,, -52100,0.39052364,1.8087215,,,,,,,,,,,,,,,,, -52200,0.37940714,1.7194922,,,,,,,,,,,,,,,,, -52300,0.3962629,1.7183893,,,,,,,,,,,,,,,,, -52400,0.41215774,1.7807426,,,,,,,,,,,,,,,,, -52500,0.37629852,1.804229,,,,,,,,,,,,,,,,, -52600,0.40412948,1.7719358,,,,,,,,,,,,,,,,, -52700,0.37902066,1.7711353,,,,,,,,,,,,,,,,, -52798,,,0.6459488868713379,1.7108194828033447,31.285477518950703,0.6637115478515625,1.577656865119934,28.79268462181212,3000.0,0.6748591065406799,1.510031819343567,28.116153582756937,3003.0,18509.355145454407,31032.60077619553,18509.355145454407,12520.904906749724,0.6526608467102051,0.0 -52800,0.425208,1.7822196,,,,,,,,,,,,,,,,, -52900,0.41882864,1.8040302,,,,,,,,,,,,,,,,, -53000,0.4284482,1.7869978,,,,,,,,,,,,,,,,, -53100,0.40042445,1.7587556,,,,,,,,,,,,,,,,, -53200,0.4077577,1.714053,,,,,,,,,,,,,,,,, -53300,0.4410229,1.6895128,,,,,,,,,,,,,,,,, -53400,0.39509347,1.7500869,,,,,,,,,,,,,,,,, -53500,0.37573636,1.7033327,,,,,,,,,,,,,,,,, -53600,0.3963803,1.7616099,,,,,,,,,,,,,,,,, -53700,0.3792299,1.7334815,,,,,,,,,,,,,,,,, -53800,0.39553824,1.7637242,,,,,,,,,,,,,,,,, -53900,0.4394331,1.6704345,,,,,,,,,,,,,,,,, -54000,0.4097069,1.6729974,,,,,,,,,,,,,,,,, -54100,0.37867293,1.6473749,,,,,,,,,,,,,,,,, -54200,0.38971072,1.7396455,,,,,,,,,,,,,,,,, -54300,0.41864243,1.6569958,,,,,,,,,,,,,,,,, -54400,0.3707237,1.6580213,,,,,,,,,,,,,,,,, -54500,0.4357461,1.7161843,,,,,,,,,,,,,,,,, -54600,0.4088687,1.7480055,,,,,,,,,,,,,,,,, -54700,0.46882924,1.7132841,,,,,,,,,,,,,,,,, -54800,0.3924081,1.6650097,,,,,,,,,,,,,,,,, -54900,0.40530977,1.7401203,,,,,,,,,,,,,,,,, -55000,0.3789468,1.686426,,,,,,,,,,,,,,,,, -55100,0.42663705,1.782467,,,,,,,,,,,,,,,,, -55200,,,0.6462436318397522,1.7085494995117188,31.66185194613176,0.6645422577857971,1.571349024772644,28.52198308530932,3000.0,0.6769275665283203,1.498882532119751,28.07042551889989,3003.0,19349.54166293144,32448.686230182648,19349.54166293144,13096.69297504425,0.6853039264678955,0.0 -55200,0.37282997,1.6504434,,,,,,,,,,,,,,,,, -55300,0.3995002,1.7630192,,,,,,,,,,,,,,,,, -55400,0.39202392,1.7237668,,,,,,,,,,,,,,,,, -55500,0.38171798,1.6862454,,,,,,,,,,,,,,,,, -55600,0.42281786,1.7490731,,,,,,,,,,,,,,,,, -55700,0.4255705,1.6554888,,,,,,,,,,,,,,,,, -55800,1.7092851,1.7397233,,,,,,,,,,,,,,,,, -55900,0.5846451,1.8086029,,,,,,,,,,,,,,,,, -56000,0.4019548,1.6749685,,,,,,,,,,,,,,,,, -56100,0.43219903,1.6715318,,,,,,,,,,,,,,,,, -56200,0.42251858,1.7778875,,,,,,,,,,,,,,,,, -56300,0.4329902,1.7676378,,,,,,,,,,,,,,,,, -56400,0.3931889,1.6816747,,,,,,,,,,,,,,,,, -56500,0.3747636,1.7363056,,,,,,,,,,,,,,,,, -56600,0.39758593,1.6732925,,,,,,,,,,,,,,,,, -56700,0.4617449,1.7856363,,,,,,,,,,,,,,,,, -56800,0.37483433,1.7983159,,,,,,,,,,,,,,,,, -56900,0.42820805,1.8254255,,,,,,,,,,,,,,,,, -57000,0.3561622,1.6742834,,,,,,,,,,,,,,,,, -57100,0.3609525,1.6636459,,,,,,,,,,,,,,,,, -57200,0.39799967,1.7185891,,,,,,,,,,,,,,,,, -57300,0.3976837,1.7127779,,,,,,,,,,,,,,,,, -57400,0.4465866,1.6954055,,,,,,,,,,,,,,,,, -57500,0.4083479,1.8216597,,,,,,,,,,,,,,,,, -57600,0.376191,1.7287179,,,,,,,,,,,,,,,,, -57602,,,0.6559958457946777,1.6449164152145386,31.92104954047709,0.6649266481399536,1.5654634237289429,28.65088127558128,3000.0,0.6784266233444214,1.494862079620361,28.403415924124,3003.0,20189.47763967514,33847.21442198753,20189.47763967514,13655.178454637527,0.7172021865844727,0.0 -57700,0.39609653,1.769116,,,,,,,,,,,,,,,,, -57800,0.40718144,1.7803968,,,,,,,,,,,,,,,,, -57900,0.42205623,1.6024146,,,,,,,,,,,,,,,,, -58000,0.42942646,1.6646348,,,,,,,,,,,,,,,,, -58100,0.3788719,1.6783352,,,,,,,,,,,,,,,,, -58200,0.39483047,1.6194,,,,,,,,,,,,,,,,, -58300,0.39186516,1.6748744,,,,,,,,,,,,,,,,, -58400,0.40371948,1.6875209,,,,,,,,,,,,,,,,, -58500,0.35927388,1.7279004,,,,,,,,,,,,,,,,, -58600,0.37676787,1.7628094,,,,,,,,,,,,,,,,, -58700,0.43302482,1.7055138,,,,,,,,,,,,,,,,, -58800,0.39566776,1.6879278,,,,,,,,,,,,,,,,, -58900,0.41326302,1.6241072,,,,,,,,,,,,,,,,, -59000,0.3946629,1.7428898,,,,,,,,,,,,,,,,, -59100,0.4066716,1.709317,,,,,,,,,,,,,,,,, -59200,0.4258089,1.7038058,,,,,,,,,,,,,,,,, -59300,0.41724467,1.7534711,,,,,,,,,,,,,,,,, -59400,0.39547464,1.6958221,,,,,,,,,,,,,,,,, -59500,0.42951518,1.7145327,,,,,,,,,,,,,,,,, -59600,0.40149656,1.7208923,,,,,,,,,,,,,,,,, -59700,0.3942971,1.6680248,,,,,,,,,,,,,,,,, -59800,0.4276816,1.6583942,,,,,,,,,,,,,,,,, -59900,0.42504886,1.6681212,,,,,,,,,,,,,,,,, -60000,0.4327881,1.7345179,,,,,,,,,,,,,,,,, -60005,,,0.6456266641616821,1.7141486406326294,32.09485250326785,0.6683239936828613,1.5589754581451416,28.846310976430967,3000.0,0.6812503933906555,1.4764920473098757,28.37237056623999,3003.0,21029.704471111298,35197.30215525627,21029.704471111298,14164.931086301804,0.7506864070892334,0.0 -60100,0.42797223,1.719846,,,,,,,,,,,,,,,,, -60200,0.4212079,1.6545475,,,,,,,,,,,,,,,,, -60300,0.35812992,1.7135687,,,,,,,,,,,,,,,,, -60400,0.4168857,1.7374759,,,,,,,,,,,,,,,,, -60500,0.41301566,1.7619666,,,,,,,,,,,,,,,,, -60600,0.40807432,1.7166916,,,,,,,,,,,,,,,,, -60700,0.41158783,1.717918,,,,,,,,,,,,,,,,, -60800,0.3831934,1.7140743,,,,,,,,,,,,,,,,, -60900,0.395588,1.7643559,,,,,,,,,,,,,,,,, -61000,0.41835904,1.7137994,,,,,,,,,,,,,,,,, -61100,0.3966793,1.6399364,,,,,,,,,,,,,,,,, -61200,0.43701765,1.6988802,,,,,,,,,,,,,,,,, -61300,0.3971767,1.7459222,,,,,,,,,,,,,,,,, -61400,0.39915168,1.6980087,,,,,,,,,,,,,,,,, -61500,0.3909017,1.6969941,,,,,,,,,,,,,,,,, -61600,0.38590303,1.6568322,,,,,,,,,,,,,,,,, -61700,0.39651412,1.6760368,,,,,,,,,,,,,,,,, -61800,0.382373,1.7520406,,,,,,,,,,,,,,,,, -61900,0.39206842,1.6926919,,,,,,,,,,,,,,,,, -62000,0.439441,1.7608764,,,,,,,,,,,,,,,,, -62100,0.41995677,1.6315396,,,,,,,,,,,,,,,,, -62200,0.42017114,1.7129503,,,,,,,,,,,,,,,,, -62300,0.41099674,1.6172948,,,,,,,,,,,,,,,,, -62400,0.39649493,1.6427795,,,,,,,,,,,,,,,,, -62407,,,0.6473643183708191,1.701988935470581,31.688075634827747,0.6715973615646362,1.5442315340042114,28.90730774156349,3000.0,0.6835861206054688,1.4696898460388184,28.85938976173801,3003.0,21869.76788258553,36556.538370132446,21869.76788258553,14683.995180606842,0.7836964130401611,0.0 -62500,0.4426213,1.6733315,,,,,,,,,,,,,,,,, -62600,0.37462124,1.6786791,,,,,,,,,,,,,,,,, -62700,0.4477879,1.6395588,,,,,,,,,,,,,,,,, -62800,0.41570696,1.7033598,,,,,,,,,,,,,,,,, -62900,0.41117066,1.6157768,,,,,,,,,,,,,,,,, -63000,0.45908302,1.7499111,,,,,,,,,,,,,,,,, -63100,0.4061249,1.6478914,,,,,,,,,,,,,,,,, -63200,0.40180656,1.6818995,,,,,,,,,,,,,,,,, -63300,0.40399483,1.7357004,,,,,,,,,,,,,,,,, -63400,0.42566356,1.802872,,,,,,,,,,,,,,,,, -63500,0.42303005,1.663978,,,,,,,,,,,,,,,,, -63600,0.38276616,1.6342708,,,,,,,,,,,,,,,,, -63700,0.40215233,1.6329516,,,,,,,,,,,,,,,,, -63800,0.41526923,1.8000866,,,,,,,,,,,,,,,,, -63900,0.397411,1.6933652,,,,,,,,,,,,,,,,, -64000,0.43008265,1.7039566,,,,,,,,,,,,,,,,, -64100,0.439,1.7880899,,,,,,,,,,,,,,,,, -64200,0.41245043,1.7479521,,,,,,,,,,,,,,,,, -64300,0.39267588,1.647569,,,,,,,,,,,,,,,,, -64400,0.41663393,1.7133154,,,,,,,,,,,,,,,,, -64500,0.4163571,1.7105783,,,,,,,,,,,,,,,,, -64600,0.42047924,1.6223377,,,,,,,,,,,,,,,,, -64700,0.8653832,1.6069052,,,,,,,,,,,,,,,,, -64800,0.4231623,1.680329,,,,,,,,,,,,,,,,, -64810,,,0.6546890735626221,1.657198429107666,31.84385511875302,0.6724157333374023,1.5407817363739014,29.134471317994127,3000.0,0.682877242565155,1.4688782691955566,28.835831003764778,3003.0,22710.00342679024,37973.3166179657,22710.00342679024,15260.428012609482,0.8176324367523193,0.0 -64900,0.43335578,1.6580632,,,,,,,,,,,,,,,,, -65000,0.39639255,1.7155374,,,,,,,,,,,,,,,,, -65100,0.42906496,1.6806273,,,,,,,,,,,,,,,,, -65200,0.39680016,1.7274925,,,,,,,,,,,,,,,,, -65300,0.41438112,1.686956,,,,,,,,,,,,,,,,, -65400,0.43141517,1.7477752,,,,,,,,,,,,,,,,, -65500,0.40014553,1.6484928,,,,,,,,,,,,,,,,, -65600,0.44620985,1.7281663,,,,,,,,,,,,,,,,, -65700,0.39065793,1.6765776,,,,,,,,,,,,,,,,, -65800,0.41078755,1.6762092,,,,,,,,,,,,,,,,, -65900,0.38264468,1.6602964,,,,,,,,,,,,,,,,, -66000,0.4088113,1.6719106,,,,,,,,,,,,,,,,, -66100,0.42236823,1.6593884,,,,,,,,,,,,,,,,, -66200,0.40484944,1.6527346,,,,,,,,,,,,,,,,, -66300,0.44137433,1.5884688,,,,,,,,,,,,,,,,, -66400,0.412847,1.6166719,,,,,,,,,,,,,,,,, -66500,0.4337935,1.7056874,,,,,,,,,,,,,,,,, -66600,0.4136177,1.6320164,,,,,,,,,,,,,,,,, -66700,0.39031243,1.6530002,,,,,,,,,,,,,,,,, -66800,0.41530123,1.5874295,,,,,,,,,,,,,,,,, -66900,0.4274566,1.6967999,,,,,,,,,,,,,,,,, -67000,0.41360402,1.7142646,,,,,,,,,,,,,,,,, -67100,0.38939443,1.5922632,,,,,,,,,,,,,,,,, -67200,0.42054036,1.6314179,,,,,,,,,,,,,,,,, -67212,,,0.6506806015968323,1.6883158683776855,31.947221106489987,0.6726140975952148,1.5377442836761477,29.3831373186652,3000.0,0.6839579343795776,1.4663946628570557,28.923713488856706,3003.0,23549.912605524063,39340.40902590752,23549.912605524063,15787.419231653214,0.9314663410186768,0.0 -67300,0.41978583,1.6440241,,,,,,,,,,,,,,,,, -67400,0.40006116,1.734223,,,,,,,,,,,,,,,,, -67500,0.5753886,1.6558137,,,,,,,,,,,,,,,,, -67600,0.39665282,1.6869075,,,,,,,,,,,,,,,,, -67700,2.1150804,1.7565224,,,,,,,,,,,,,,,,, -67800,0.4197492,1.7270173,,,,,,,,,,,,,,,,, -67900,0.4190344,1.6361601,,,,,,,,,,,,,,,,, -68000,0.38784027,1.59396,,,,,,,,,,,,,,,,, -68100,0.4393542,1.7249943,,,,,,,,,,,,,,,,, -68200,0.42700493,1.7211926,,,,,,,,,,,,,,,,, -68300,0.40181148,1.6346759,,,,,,,,,,,,,,,,, -68400,0.43550077,1.7098013,,,,,,,,,,,,,,,,, -68500,0.40922245,1.6585397,,,,,,,,,,,,,,,,, -68600,0.40734014,1.6726938,,,,,,,,,,,,,,,,, -68700,0.41796616,1.7249249,,,,,,,,,,,,,,,,, -68800,0.42782825,1.7529461,,,,,,,,,,,,,,,,, -68900,0.3918451,1.572151,,,,,,,,,,,,,,,,, -69000,0.45499906,1.7220532,,,,,,,,,,,,,,,,, -69100,0.42137253,1.6791497,,,,,,,,,,,,,,,,, -69200,0.40005714,1.6696388,,,,,,,,,,,,,,,,, -69300,0.3957357,1.652661,,,,,,,,,,,,,,,,, -69400,0.39800802,1.6050891,,,,,,,,,,,,,,,,, -69500,0.37987277,1.6533731,,,,,,,,,,,,,,,,, -69600,0.39952454,1.6040978,,,,,,,,,,,,,,,,, -69614,,,0.6651747822761536,1.5777531862258911,32.645817061232485,0.6728744506835938,1.529529690742493,29.578541313522177,3000.0,0.6869211792945862,1.4553210735321045,29.16384456800778,3003.0,24389.983780145645,40713.90342450142,24389.983780145645,16320.727837085724,0.9683539867401124,0.0 -69700,0.43708575,1.6567607,,,,,,,,,,,,,,,,, -69800,0.45587653,1.6778071,,,,,,,,,,,,,,,,, -69900,0.4185777,1.7018496,,,,,,,,,,,,,,,,, -70000,0.41980064,1.6094197,,,,,,,,,,,,,,,,, -70100,0.41466406,1.6216342,,,,,,,,,,,,,,,,, -70200,0.4336759,1.668586,,,,,,,,,,,,,,,,, -70300,0.48940134,1.6353158,,,,,,,,,,,,,,,,, -70400,0.39972973,1.6528343,,,,,,,,,,,,,,,,, -70500,0.43602967,1.6030145,,,,,,,,,,,,,,,,, -70600,0.40201727,1.680637,,,,,,,,,,,,,,,,, -70700,0.41402504,1.7070118,,,,,,,,,,,,,,,,, -70800,0.4359051,1.6366324,,,,,,,,,,,,,,,,, -70900,0.4254628,1.7627385,,,,,,,,,,,,,,,,, -71000,0.3891442,1.6758788,,,,,,,,,,,,,,,,, -71100,0.4134795,1.6049848,,,,,,,,,,,,,,,,, -71200,0.42339784,1.6574683,,,,,,,,,,,,,,,,, -71300,0.46682447,1.7276757,,,,,,,,,,,,,,,,, -71400,0.39848873,1.5886818,,,,,,,,,,,,,,,,, -71500,0.44680572,1.5974827,,,,,,,,,,,,,,,,, -71600,0.41874793,1.6516412,,,,,,,,,,,,,,,,, -71700,0.4388524,1.6754047,,,,,,,,,,,,,,,,, -71800,0.38380277,1.7468101,,,,,,,,,,,,,,,,, -71900,0.39997077,1.6016772,,,,,,,,,,,,,,,,, -72000,0.4039926,1.6768804,,,,,,,,,,,,,,,,, -72016,,,0.6537835001945496,1.6537264585494995,32.53735942164496,0.673072874546051,1.5194854736328125,29.34280734448664,3000.0,0.6847946047782898,1.4506778717041016,28.99919040780581,3003.0,25229.91768527031,42094.46844172478,25229.91768527031,16861.239188194275,1.0111916065216064,0.0 -72100,0.4576635,1.6849463,,,,,,,,,,,,,,,,, -72200,0.42176172,1.6221958,,,,,,,,,,,,,,,,, -72300,0.40315852,1.6238228,,,,,,,,,,,,,,,,, -72400,0.40370205,1.6308615,,,,,,,,,,,,,,,,, -72500,0.40088302,1.6795958,,,,,,,,,,,,,,,,, -72600,0.40335193,1.5143592,,,,,,,,,,,,,,,,, -72700,0.39002317,1.7005184,,,,,,,,,,,,,,,,, -72800,0.41303664,1.6393037,,,,,,,,,,,,,,,,, -72900,0.40729818,1.6555332,,,,,,,,,,,,,,,,, -73000,0.40270016,1.6556934,,,,,,,,,,,,,,,,, -73100,0.38734052,1.6097852,,,,,,,,,,,,,,,,, -73200,0.4475739,1.6564527,,,,,,,,,,,,,,,,, -73300,0.4736467,1.583696,,,,,,,,,,,,,,,,, -73400,0.42811278,1.6706847,,,,,,,,,,,,,,,,, -73500,0.40697286,1.6120176,,,,,,,,,,,,,,,,, -73600,0.39847708,1.6194669,,,,,,,,,,,,,,,,, -73700,1.9588772,1.647972,,,,,,,,,,,,,,,,, -73800,0.4044812,1.6315314,,,,,,,,,,,,,,,,, -73900,0.4333189,1.6560892,,,,,,,,,,,,,,,,, -74000,0.41311967,1.6313447,,,,,,,,,,,,,,,,, -74100,0.42601717,1.6370066,,,,,,,,,,,,,,,,, -74200,0.4172693,1.6717314,,,,,,,,,,,,,,,,, -74300,0.45797944,1.6031556,,,,,,,,,,,,,,,,, -74400,0.40507337,1.7254404,,,,,,,,,,,,,,,,, -74418,,,0.6587724685668945,1.636032223701477,32.378273362747734,0.6745235323905945,1.510649561882019,29.499814008006418,3000.0,0.6886991262435913,1.439121961593628,29.56834137146062,3003.0,26069.80594444275,43463.65474176407,26069.80594444275,17390.422934532166,1.0483558177947998,0.0 -74500,0.41796416,1.6716475,,,,,,,,,,,,,,,,, -74600,0.41539764,1.6664238,,,,,,,,,,,,,,,,, -74700,0.4119368,1.6763866,,,,,,,,,,,,,,,,, -74800,0.40423986,1.6119783,,,,,,,,,,,,,,,,, -74900,0.42409122,1.54939,,,,,,,,,,,,,,,,, -75000,0.4281751,1.6280389,,,,,,,,,,,,,,,,, -75100,0.41640037,1.6169848,,,,,,,,,,,,,,,,, -75200,0.40520132,1.6825732,,,,,,,,,,,,,,,,, -75300,0.3819273,1.5817512,,,,,,,,,,,,,,,,, -75400,0.42527735,1.518882,,,,,,,,,,,,,,,,, -75500,0.42718527,1.5753831,,,,,,,,,,,,,,,,, -75600,0.42204124,1.6876426,,,,,,,,,,,,,,,,, -75700,0.45773092,1.5883231,,,,,,,,,,,,,,,,, -75800,0.43728846,1.6260777,,,,,,,,,,,,,,,,, -75900,0.43521434,1.6665245,,,,,,,,,,,,,,,,, -76000,0.43004316,1.6500024,,,,,,,,,,,,,,,,, -76100,0.48599344,1.6622708,,,,,,,,,,,,,,,,, -76200,0.44169396,1.6679317,,,,,,,,,,,,,,,,, -76300,0.43245152,1.7132668,,,,,,,,,,,,,,,,, -76400,0.46021312,1.7080786,,,,,,,,,,,,,,,,, -76500,0.4218559,1.637026,,,,,,,,,,,,,,,,, -76600,0.45296296,1.6272991,,,,,,,,,,,,,,,,, -76700,0.41580096,1.4914275,,,,,,,,,,,,,,,,, -76800,0.41294476,1.6584266,,,,,,,,,,,,,,,,, -76820,,,0.6653993725776672,1.5844321250915527,32.9791006066284,0.6778588891029358,1.5037970542907717,29.33848707278825,3000.0,0.6883272528648376,1.435570240020752,29.255923845563185,3003.0,26909.785906791687,44844.65379357338,26909.785906791687,17931.327827215195,1.0869407653808594,0.0 -76900,0.40195757,1.6615826,,,,,,,,,,,,,,,,, -77000,0.46575713,1.6970639,,,,,,,,,,,,,,,,, -77100,0.42720783,1.638026,,,,,,,,,,,,,,,,, -77200,0.42885518,1.6612966,,,,,,,,,,,,,,,,, -77300,0.43949395,1.6802453,,,,,,,,,,,,,,,,, -77400,0.43193492,1.6675992,,,,,,,,,,,,,,,,, -77500,0.42997417,1.6141263,,,,,,,,,,,,,,,,, -77600,0.44675803,1.58505,,,,,,,,,,,,,,,,, -77700,0.40903494,1.6004834,,,,,,,,,,,,,,,,, -77800,0.42753556,1.5708748,,,,,,,,,,,,,,,,, -77900,0.44620088,1.702282,,,,,,,,,,,,,,,,, -78000,0.4313734,1.650474,,,,,,,,,,,,,,,,, -78100,0.4500542,1.6065866,,,,,,,,,,,,,,,,, -78200,0.42843205,1.6360024,,,,,,,,,,,,,,,,, -78300,0.40703508,1.631355,,,,,,,,,,,,,,,,, -78400,0.4176927,1.5749822,,,,,,,,,,,,,,,,, -78500,0.41941208,1.5989319,,,,,,,,,,,,,,,,, -78600,0.43005377,1.7124908,,,,,,,,,,,,,,,,, -78700,0.44004843,1.6778172,,,,,,,,,,,,,,,,, -78800,0.44949993,1.6361375,,,,,,,,,,,,,,,,, -78900,0.45149684,1.639536,,,,,,,,,,,,,,,,, -79000,0.42542648,1.6830521,,,,,,,,,,,,,,,,, -79100,0.41674873,1.6143093,,,,,,,,,,,,,,,,, -79200,0.41372988,1.6055405,,,,,,,,,,,,,,,,, -79222,,,0.6607912182807922,1.6076757907867432,33.057591815314304,0.6763338446617126,1.5020564794540403,29.76166169576297,3000.0,0.6905118823051453,1.4266692399978638,29.21426525986103,3003.0,27749.74374437332,46208.38493394852,27749.74374437332,18454.98840594292,1.123607158660889,0.0 -79300,0.40705213,1.6262547,,,,,,,,,,,,,,,,, -79400,0.4467883,1.6085083,,,,,,,,,,,,,,,,, -79500,0.41449732,1.5674297,,,,,,,,,,,,,,,,, -79600,0.40970552,1.6007547,,,,,,,,,,,,,,,,, -79700,0.4469304,1.6746764,,,,,,,,,,,,,,,,, -79800,0.45541054,1.6026802,,,,,,,,,,,,,,,,, -79900,0.47386572,1.6479398,,,,,,,,,,,,,,,,, -80000,0.47707304,1.5670687,,,,,,,,,,,,,,,,, -80100,0.44258782,1.6609668,,,,,,,,,,,,,,,,, -80200,0.42896453,1.6211858,,,,,,,,,,,,,,,,, -80300,0.45402548,1.6496576,,,,,,,,,,,,,,,,, -80400,0.4298478,1.6275839,,,,,,,,,,,,,,,,, -80500,0.43153635,1.616732,,,,,,,,,,,,,,,,, -80600,0.42110312,1.6416477,,,,,,,,,,,,,,,,, -80700,0.43129012,1.6534196,,,,,,,,,,,,,,,,, -80800,0.44535163,1.5844129,,,,,,,,,,,,,,,,, -80900,0.4188301,1.6985862,,,,,,,,,,,,,,,,, -81000,0.4281894,1.648888,,,,,,,,,,,,,,,,, -81100,0.4231183,1.5672985,,,,,,,,,,,,,,,,, -81200,0.44037828,1.5683545,,,,,,,,,,,,,,,,, -81300,0.42918566,1.5269283,,,,,,,,,,,,,,,,, -81400,0.4418061,1.7513924,,,,,,,,,,,,,,,,, -81500,0.44306943,1.5768502,,,,,,,,,,,,,,,,, -81600,0.44417772,1.6516181,,,,,,,,,,,,,,,,, -81624,,,0.6834427714347839,1.4710636138916016,34.51069994560347,0.6774373650550842,1.4913984537124634,29.57352650655724,3000.0,0.6915345191955566,1.4193542003631592,29.30637132482035,3003.0,28589.722110033035,47566.49417424202,28589.722110033035,18973.0031478405,1.1630005836486816,0.0 -81700,0.42933255,1.5870098,,,,,,,,,,,,,,,,, -81800,0.44234765,1.6590663,,,,,,,,,,,,,,,,, -81900,0.4537094,1.6061441,,,,,,,,,,,,,,,,, -82000,0.47517973,1.565782,,,,,,,,,,,,,,,,, -82100,0.4391425,1.5946103,,,,,,,,,,,,,,,,, -82200,0.45685405,1.5514737,,,,,,,,,,,,,,,,, -82300,0.44367784,1.6723874,,,,,,,,,,,,,,,,, -82400,0.4462709,1.5914497,,,,,,,,,,,,,,,,, -82500,0.46490118,1.5716195,,,,,,,,,,,,,,,,, -82600,0.4172457,1.5780543,,,,,,,,,,,,,,,,, -82700,0.40979728,1.5993656,,,,,,,,,,,,,,,,, -82800,0.43696743,1.5588247,,,,,,,,,,,,,,,,, -82900,0.43523636,1.6271285,,,,,,,,,,,,,,,,, -83000,0.42770246,1.5479343,,,,,,,,,,,,,,,,, -83100,0.42932475,1.5671761,,,,,,,,,,,,,,,,, -83200,0.45742077,1.5922655,,,,,,,,,,,,,,,,, -83300,0.45160145,1.5644823,,,,,,,,,,,,,,,,, -83400,0.47406736,1.5893921,,,,,,,,,,,,,,,,, -83500,0.48156643,1.5822845,,,,,,,,,,,,,,,,, -83600,0.47362134,1.6210909,,,,,,,,,,,,,,,,, -83700,0.49904197,1.6593919,,,,,,,,,,,,,,,,, -83800,0.4473737,1.6785941,,,,,,,,,,,,,,,,, -83900,0.44033325,1.6454167,,,,,,,,,,,,,,,,, -84000,0.4748629,1.6014086,,,,,,,,,,,,,,,,, -84026,,,0.6655948162078857,1.5831010341644287,32.66122393028471,0.675899863243103,1.4887354373931885,29.64316058337916,3000.0,0.6908488869667053,1.4123516082763672,29.32195293579914,3003.0,29429.614921092987,48987.2216424942,29429.614921092987,19553.72599005699,1.1998803615570068,0.0 -84100,0.48941082,1.6055621,,,,,,,,,,,,,,,,, -84200,0.4724427,1.6142051,,,,,,,,,,,,,,,,, -84300,0.8916174,1.666428,,,,,,,,,,,,,,,,, -84400,0.4786105,1.6315132,,,,,,,,,,,,,,,,, -84500,0.4581935,1.6684071,,,,,,,,,,,,,,,,, -84600,0.45777562,1.6042911,,,,,,,,,,,,,,,,, -84700,0.42738375,1.5415429,,,,,,,,,,,,,,,,, -84800,0.47004223,1.592544,,,,,,,,,,,,,,,,, -84900,0.4352758,1.587717,,,,,,,,,,,,,,,,, -85000,0.46764106,1.6240664,,,,,,,,,,,,,,,,, -85100,0.44446096,1.5912291,,,,,,,,,,,,,,,,, -85200,0.46558064,1.6068383,,,,,,,,,,,,,,,,, -85300,0.46392176,1.6744392,,,,,,,,,,,,,,,,, -85400,0.46985328,1.6157428,,,,,,,,,,,,,,,,, -85500,0.44168675,1.6795853,,,,,,,,,,,,,,,,, -85600,0.49241832,1.6251795,,,,,,,,,,,,,,,,, -85700,0.4608516,1.6181649,,,,,,,,,,,,,,,,, -85800,0.47090232,1.6003021,,,,,,,,,,,,,,,,, -85900,0.44917277,1.6275254,,,,,,,,,,,,,,,,, -86000,0.49077222,1.6593972,,,,,,,,,,,,,,,,, -86100,0.46409762,1.6663487,,,,,,,,,,,,,,,,, -86200,0.46627578,1.5781775,,,,,,,,,,,,,,,,, -86300,0.439986,1.6184399,,,,,,,,,,,,,,,,, -86400,0.47903904,1.6200107,,,,,,,,,,,,,,,,, -86428,,,0.667698323726654,1.5811502933502195,33.187502389890206,0.6793344020843506,1.481445074081421,29.7653434375351,3000.0,0.6907791495323181,1.4073405265808103,29.4747702020092,3003.0,30269.60338878632,50345.42161464691,30269.60338878632,20071.82431006432,1.2375590801239014,0.0 -86500,0.47604007,1.6280266,,,,,,,,,,,,,,,,, -86600,0.46398607,1.6260293,,,,,,,,,,,,,,,,, -86700,0.46694136,1.6031197,,,,,,,,,,,,,,,,, -86800,0.46756792,1.5409685,,,,,,,,,,,,,,,,, -86900,0.47723484,1.5734589,,,,,,,,,,,,,,,,, -87000,0.46340865,1.5590031,,,,,,,,,,,,,,,,, -87100,0.44814175,1.6134434,,,,,,,,,,,,,,,,, -87200,0.45902005,1.5614972,,,,,,,,,,,,,,,,, -87300,0.4480046,1.5485903,,,,,,,,,,,,,,,,, -87400,0.45841354,1.5669068,,,,,,,,,,,,,,,,, -87500,0.45747554,1.5148892,,,,,,,,,,,,,,,,, -87600,0.45690045,1.6043687,,,,,,,,,,,,,,,,, -87700,0.43298584,1.5593894,,,,,,,,,,,,,,,,, -87800,0.45871562,1.5216948,,,,,,,,,,,,,,,,, -87900,0.4683964,1.6008446,,,,,,,,,,,,,,,,, -88000,0.46833628,1.5865906,,,,,,,,,,,,,,,,, -88100,0.46422237,1.6061885,,,,,,,,,,,,,,,,, -88200,0.48286873,1.6141167,,,,,,,,,,,,,,,,, -88300,0.41901177,1.5851117,,,,,,,,,,,,,,,,, -88400,0.4507235,1.5599155,,,,,,,,,,,,,,,,, -88500,0.5042842,1.5546676,,,,,,,,,,,,,,,,, -88600,0.46447814,1.6232877,,,,,,,,,,,,,,,,, -88700,0.51513857,1.6343359,,,,,,,,,,,,,,,,, -88800,0.45931768,1.5877441,,,,,,,,,,,,,,,,, -88828,,,0.6753185987472534,1.5222102403640747,33.65455010911434,0.6817150115966797,1.4678661823272705,30.07307836815615,3000.0,0.6958921551704407,1.3935723304748535,29.70958983317945,3003.0,31109.718526124954,51754.98230290413,31109.718526124954,20641.15272283554,1.275867938995361,0.0 -88900,0.49419126,1.5692078,,,,,,,,,,,,,,,,, -89000,0.45819634,1.5337212,,,,,,,,,,,,,,,,, -89100,0.50038147,1.5619053,,,,,,,,,,,,,,,,, -89200,0.44499752,1.6143792,,,,,,,,,,,,,,,,, -89300,0.44325644,1.4861407,,,,,,,,,,,,,,,,, -89400,0.47860688,1.4895089,,,,,,,,,,,,,,,,, -89500,0.48033535,1.6393545,,,,,,,,,,,,,,,,, -89600,0.48335478,1.6055111,,,,,,,,,,,,,,,,, -89700,0.46414456,1.6154456,,,,,,,,,,,,,,,,, -89800,0.48864296,1.5503705,,,,,,,,,,,,,,,,, -89900,0.4998957,1.638138,,,,,,,,,,,,,,,,, -90000,0.4773964,1.6078273,,,,,,,,,,,,,,,,, -90100,0.4704186,1.6252372,,,,,,,,,,,,,,,,, -90200,0.5202693,1.6041638,,,,,,,,,,,,,,,,, -90300,0.4915556,1.6411366,,,,,,,,,,,,,,,,, -90400,0.52176934,1.6121805,,,,,,,,,,,,,,,,, -90500,0.4888724,1.5479708,,,,,,,,,,,,,,,,, -90600,0.4836115,1.6222683,,,,,,,,,,,,,,,,, -90700,0.4676693,1.5816413,,,,,,,,,,,,,,,,, -90800,0.47782627,1.6470705,,,,,,,,,,,,,,,,, -90900,0.47253385,1.5638326,,,,,,,,,,,,,,,,, -91000,0.47235793,1.5752635,,,,,,,,,,,,,,,,, -91100,0.5092913,1.4926989,,,,,,,,,,,,,,,,, -91200,0.44702098,1.5882863,,,,,,,,,,,,,,,,, -91229,,,0.6722777485847473,1.5488667488098145,33.57607223090198,0.6819506287574768,1.465559005737305,30.27933128209323,3000.0,0.697681725025177,1.3859319686889648,30.023480420024,3003.0,31949.66242647171,53098.72299027443,31949.66242647171,21144.83209013939,1.315791130065918,0.0 -91300,0.46858078,1.5243534,,,,,,,,,,,,,,,,, -91400,0.47285363,1.5424039,,,,,,,,,,,,,,,,, -91500,0.47948852,1.5441473,,,,,,,,,,,,,,,,, -91600,0.48657462,1.6210599,,,,,,,,,,,,,,,,, -91700,0.5121561,1.5614464,,,,,,,,,,,,,,,,, -91800,0.4517248,1.5403469,,,,,,,,,,,,,,,,, -91900,0.50003403,1.5045767,,,,,,,,,,,,,,,,, -92000,0.71348155,1.6134688,,,,,,,,,,,,,,,,, -92100,0.47467598,1.5966581,,,,,,,,,,,,,,,,, -92200,0.52640504,1.569224,,,,,,,,,,,,,,,,, -92300,0.49393374,1.5280397,,,,,,,,,,,,,,,,, -92400,0.5003412,1.6360284,,,,,,,,,,,,,,,,, -92500,0.5087976,1.5477529,,,,,,,,,,,,,,,,, -92600,0.47797284,1.5426389,,,,,,,,,,,,,,,,, -92700,0.48515093,1.5282087,,,,,,,,,,,,,,,,, -92800,0.48726198,1.5477581,,,,,,,,,,,,,,,,, -92900,0.5128632,1.5186208,,,,,,,,,,,,,,,,, -93000,0.523805,1.6145185,,,,,,,,,,,,,,,,, -93100,0.5210866,1.6387875,,,,,,,,,,,,,,,,, -93200,0.4659102,1.5865432,,,,,,,,,,,,,,,,, -93300,0.4887472,1.5742317,,,,,,,,,,,,,,,,, -93400,0.4896127,1.566896,,,,,,,,,,,,,,,,, -93500,0.48690122,1.5166503,,,,,,,,,,,,,,,,, -93600,0.49534228,1.5304646,,,,,,,,,,,,,,,,, -93631,,,0.6707874536514282,1.561752438545227,33.5174534912928,0.6827317476272583,1.4605400562286377,30.06944812642009,3000.0,0.7006217241287231,1.3737126588821411,30.23283622999288,3003.0,32789.70082950592,54462.08965873718,32789.70082950592,21668.04584169388,1.3538849353790283,0.0 -93700,0.4837681,1.5086355,,,,,,,,,,,,,,,,, -93800,0.49462876,1.4821436,,,,,,,,,,,,,,,,, -93900,0.5233626,1.5872155,,,,,,,,,,,,,,,,, -94000,0.48871917,1.5105027,,,,,,,,,,,,,,,,, -94100,0.4931721,1.5227234,,,,,,,,,,,,,,,,, -94200,0.5148041,1.5657552,,,,,,,,,,,,,,,,, -94300,0.5008045,1.5792212,,,,,,,,,,,,,,,,, -94400,0.47661033,1.5658116,,,,,,,,,,,,,,,,, -94500,0.5158281,1.5828098,,,,,,,,,,,,,,,,, -94600,0.493346,1.6041965,,,,,,,,,,,,,,,,, -94700,0.4948694,1.5450045,,,,,,,,,,,,,,,,, -94800,0.5048483,1.5221364,,,,,,,,,,,,,,,,, -94900,0.51010936,1.5335665,,,,,,,,,,,,,,,,, -95000,0.5102763,1.5069833,,,,,,,,,,,,,,,,, -95100,0.5342491,1.581847,,,,,,,,,,,,,,,,, -95200,0.4931545,1.6000285,,,,,,,,,,,,,,,,, -95300,0.5275528,1.573764,,,,,,,,,,,,,,,,, -95400,0.52002877,1.4725833,,,,,,,,,,,,,,,,, -95500,0.526297,1.482047,,,,,,,,,,,,,,,,, -95600,0.5107044,1.5109622,,,,,,,,,,,,,,,,, -95700,0.53043807,1.6294998,,,,,,,,,,,,,,,,, -95800,0.50344753,1.5980775,,,,,,,,,,,,,,,,, -95900,0.5535344,1.635565,,,,,,,,,,,,,,,,, -96000,0.5245822,1.5257281,,,,,,,,,,,,,,,,, -96033,,,0.6781882047653198,1.512836456298828,34.07090496618493,0.6862531304359436,1.4463741779327393,30.29679901504323,3000.0,0.6993434429168701,1.370069980621338,30.150542856499595,3003.0,33629.65852046013,55839.95797109604,33629.65852046013,22205.843585014343,1.392345666885376,0.0 -96100,0.51776266,1.5153849,,,,,,,,,,,,,,,,, -96200,0.5232622,1.5011983,,,,,,,,,,,,,,,,, -96300,0.52203083,1.5916334,,,,,,,,,,,,,,,,, -96400,0.54626274,1.5969892,,,,,,,,,,,,,,,,, -96500,0.52191603,1.5064113,,,,,,,,,,,,,,,,, -96600,0.54880583,1.5991564,,,,,,,,,,,,,,,,, -96700,0.53166497,1.507018,,,,,,,,,,,,,,,,, -96800,0.55684,1.609376,,,,,,,,,,,,,,,,, -96900,0.5242482,1.5531735,,,,,,,,,,,,,,,,, -97000,0.5259813,1.547529,,,,,,,,,,,,,,,,, -97100,0.51737,1.4897195,,,,,,,,,,,,,,,,, -97200,0.5446058,1.4915329,,,,,,,,,,,,,,,,, -97300,0.53886175,1.5507965,,,,,,,,,,,,,,,,, -97400,0.5285696,1.5420992,,,,,,,,,,,,,,,,, -97500,0.5117733,1.5497544,,,,,,,,,,,,,,,,, -97600,0.52071846,1.5445948,,,,,,,,,,,,,,,,, -97700,0.5385548,1.5249408,,,,,,,,,,,,,,,,, -97800,0.5437283,1.5810975,,,,,,,,,,,,,,,,, -97900,0.51706606,1.5102375,,,,,,,,,,,,,,,,, -98000,0.54131496,1.6088388,,,,,,,,,,,,,,,,, -98100,0.5388451,1.5289294,,,,,,,,,,,,,,,,, -98200,0.52753955,1.5648555,,,,,,,,,,,,,,,,, -98300,0.56353486,1.5951986,,,,,,,,,,,,,,,,, -98400,0.5193629,1.4572908,,,,,,,,,,,,,,,,, -98435,,,0.6749489307403564,1.5274507999420166,33.67612239774069,0.6862407326698303,1.4482619762420654,30.31411470974103,3000.0,0.7017953991889954,1.3655441999435425,30.47622243073185,3003.0,34469.67932343483,57214.09830498696,34469.67932343483,22739.84860920906,1.4309487342834473,0.0 -98500,0.5513843,1.54378,,,,,,,,,,,,,,,,, -98600,0.52694684,1.518908,,,,,,,,,,,,,,,,, -98700,0.5526949,1.4930009,,,,,,,,,,,,,,,,, -98800,0.5495006,1.6120636,,,,,,,,,,,,,,,,, -98900,0.5373943,1.5858929,,,,,,,,,,,,,,,,, -99000,0.52089584,1.5304406,,,,,,,,,,,,,,,,, -99100,0.5278962,1.5128084,,,,,,,,,,,,,,,,, -99200,0.5509051,1.4867966,,,,,,,,,,,,,,,,, -99300,0.5368672,1.5195271,,,,,,,,,,,,,,,,, -99400,0.5601646,1.4947059,,,,,,,,,,,,,,,,, -99500,0.5518117,1.5191392,,,,,,,,,,,,,,,,, -99600,0.5294422,1.5200062,,,,,,,,,,,,,,,,, -99700,0.5637377,1.4492582,,,,,,,,,,,,,,,,, -99800,0.56834775,1.5475637,,,,,,,,,,,,,,,,, -99900,0.56823534,1.4556836,,,,,,,,,,,,,,,,, -100000,0.5451533,1.4732783,,,,,,,,,,,,,,,,, -100100,0.5668577,1.56796,,,,,,,,,,,,,,,,, -100200,0.53611875,1.5089772,,,,,,,,,,,,,,,,, -100300,0.5541115,1.4728391,,,,,,,,,,,,,,,,, -100400,0.59289855,1.5476084,,,,,,,,,,,,,,,,, -100500,0.60660785,1.522862,,,,,,,,,,,,,,,,, -100600,0.5451213,1.4983782,,,,,,,,,,,,,,,,, -100700,0.5763707,1.4947468,,,,,,,,,,,,,,,,, -100800,0.56907517,1.4891374,,,,,,,,,,,,,,,,, -100836,,,0.6923024654388428,1.423957109451294,35.541552632794804,0.6876789927482605,1.4401698112487793,30.672736533118787,3000.0,0.7013770341873169,1.360811710357666,30.38689364950769,3003.0,35309.65514802933,58546.80599832535,35309.65514802933,23232.463057994843,1.4708812236785889,0.0 -100900,0.56942785,1.52863,,,,,,,,,,,,,,,,, -101000,0.585568,1.5462574,,,,,,,,,,,,,,,,, -101100,0.5453172,1.5263944,,,,,,,,,,,,,,,,, -101200,0.6041024,1.5446568,,,,,,,,,,,,,,,,, -101300,0.5797443,1.5814188,,,,,,,,,,,,,,,,, -101400,0.5627269,1.5386337,,,,,,,,,,,,,,,,, -101500,0.55149597,1.5270361,,,,,,,,,,,,,,,,, -101600,0.58530664,1.4946628,,,,,,,,,,,,,,,,, -101700,0.59112704,1.4899547,,,,,,,,,,,,,,,,, -101800,0.5574225,1.5365292,,,,,,,,,,,,,,,,, -101900,0.57594633,1.5979933,,,,,,,,,,,,,,,,, -102000,0.56503665,1.5307662,,,,,,,,,,,,,,,,, -102100,0.59128404,1.5029709,,,,,,,,,,,,,,,,, -102200,0.59599704,1.542681,,,,,,,,,,,,,,,,, -102300,0.5801518,1.4766355,,,,,,,,,,,,,,,,, -102400,0.5924122,1.508791,,,,,,,,,,,,,,,,, -102500,0.5734434,1.4988122,,,,,,,,,,,,,,,,, -102600,0.57842803,1.5096447,,,,,,,,,,,,,,,,, -102700,0.6050437,1.6215932,,,,,,,,,,,,,,,,, -102800,0.5901435,1.5378906,,,,,,,,,,,,,,,,, -102900,0.6040997,1.5285212,,,,,,,,,,,,,,,,, -103000,0.58044714,1.5138363,,,,,,,,,,,,,,,,, -103100,0.60466516,1.4900681,,,,,,,,,,,,,,,,, -103200,0.58322835,1.5694739,,,,,,,,,,,,,,,,, -103238,,,0.6829273700714111,1.4789576530456543,34.769351800976146,0.6884477734565735,1.4318045377731323,30.403672649418127,3000.0,0.7033408880233765,1.354027271270752,30.52896680762646,3003.0,36149.77913951874,59933.7256834507,36149.77913951874,23779.142338991165,1.5111699104309082,0.0 -103300,0.5974767,1.5661598,,,,,,,,,,,,,,,,, -103400,0.58888346,1.5366577,,,,,,,,,,,,,,,,, -103500,0.5999317,1.4693464,,,,,,,,,,,,,,,,, -103600,0.5884165,1.4843037,,,,,,,,,,,,,,,,, -103700,0.5902553,1.4428978,,,,,,,,,,,,,,,,, -103800,0.58742535,1.5519749,,,,,,,,,,,,,,,,, -103900,0.6058536,1.5026685,,,,,,,,,,,,,,,,, -104000,0.58248377,1.5248104,,,,,,,,,,,,,,,,, -104100,0.6048497,1.4330719,,,,,,,,,,,,,,,,, -104200,0.6186172,1.5138327,,,,,,,,,,,,,,,,, -104300,0.6175197,1.4808166,,,,,,,,,,,,,,,,, -104400,0.60556847,1.5545231,,,,,,,,,,,,,,,,, -104500,0.6033859,1.5054047,,,,,,,,,,,,,,,,, -104600,0.57088405,1.4507254,,,,,,,,,,,,,,,,, -104700,0.59916306,1.474071,,,,,,,,,,,,,,,,, -104800,0.63001466,1.4806383,,,,,,,,,,,,,,,,, -104900,0.6230243,1.4789395,,,,,,,,,,,,,,,,, -105000,0.6140163,1.441425,,,,,,,,,,,,,,,,, -105100,0.6172969,1.5100851,,,,,,,,,,,,,,,,, -105200,0.61219275,1.5205796,,,,,,,,,,,,,,,,, -105300,0.6125925,1.5093099,,,,,,,,,,,,,,,,, -105400,0.6143299,1.461317,,,,,,,,,,,,,,,,, -105500,0.64494044,1.4926102,,,,,,,,,,,,,,,,, -105600,0.6163681,1.4593757,,,,,,,,,,,,,,,,, -105640,,,0.6826414465904236,1.486212134361267,34.5595366890984,0.689315676689148,1.428796410560608,30.46267469124044,3000.0,0.7058508992195129,1.3480515480041504,30.580010769399586,3003.0,36989.887207746506,61272.41058278084,36989.887207746506,24277.60093665123,1.552889108657837,0.0 -105700,0.6460087,1.437668,,,,,,,,,,,,,,,,, -105800,0.62424284,1.4939321,,,,,,,,,,,,,,,,, -105900,0.6218511,1.5547688,,,,,,,,,,,,,,,,, -106000,0.63554096,1.5094191,,,,,,,,,,,,,,,,, -106100,0.5913273,1.4679432,,,,,,,,,,,,,,,,, -106200,0.6419844,1.4738792,,,,,,,,,,,,,,,,, -106300,0.6343086,1.499486,,,,,,,,,,,,,,,,, -106400,0.6288139,1.4624691,,,,,,,,,,,,,,,,, -106500,0.6565029,1.4775653,,,,,,,,,,,,,,,,, -106600,0.6058529,1.4255264,,,,,,,,,,,,,,,,, -106700,0.6146883,1.4511285,,,,,,,,,,,,,,,,, -106800,0.6202058,1.475826,,,,,,,,,,,,,,,,, -106900,0.6675904,1.4394442,,,,,,,,,,,,,,,,, -107000,0.65274644,1.579461,,,,,,,,,,,,,,,,, -107100,0.6704296,1.4710302,,,,,,,,,,,,,,,,, -107200,0.6303753,1.4787359,,,,,,,,,,,,,,,,, -107300,0.6586799,1.5008332,,,,,,,,,,,,,,,,, -107400,0.64157724,1.5267171,,,,,,,,,,,,,,,,, -107500,0.6455147,1.3903162,,,,,,,,,,,,,,,,, -107600,0.6360936,1.4614633,,,,,,,,,,,,,,,,, -107700,0.62093973,1.4016148,,,,,,,,,,,,,,,,, -107800,0.626804,1.4939997,,,,,,,,,,,,,,,,, -107900,0.65636235,1.4185439,,,,,,,,,,,,,,,,, -108000,0.65455776,1.4823478,,,,,,,,,,,,,,,,, -108043,,,0.6920853853225708,1.4408038854599,35.49696534352309,0.6909027695655823,1.4207086563110352,30.654077905274068,3000.0,0.7048399448394775,1.3408323526382446,30.62602068083253,3003.0,37829.96246767044,62608.53202152252,37829.96246767044,24773.52156305313,1.6027710437774658,0.0 -108100,0.6605352,1.4502633,,,,,,,,,,,,,,,,, -108200,0.669048,1.4834146,,,,,,,,,,,,,,,,, -108300,0.64293605,1.4972757,,,,,,,,,,,,,,,,, -108400,0.647019,1.4710165,,,,,,,,,,,,,,,,, -108500,0.66559607,1.391551,,,,,,,,,,,,,,,,, -108600,0.64943,1.4791231,,,,,,,,,,,,,,,,, -108700,0.66315186,1.4255575,,,,,,,,,,,,,,,,, -108800,0.65368193,1.380663,,,,,,,,,,,,,,,,, -108900,0.6793937,1.4436997,,,,,,,,,,,,,,,,, -109000,0.6704162,1.4530543,,,,,,,,,,,,,,,,, -109100,0.66642433,1.4236004,,,,,,,,,,,,,,,,, -109200,0.6768409,1.4517657,,,,,,,,,,,,,,,,, -109300,0.68854094,1.4384626,,,,,,,,,,,,,,,,, -109400,0.6864266,1.4940131,,,,,,,,,,,,,,,,, -109500,0.6844321,1.5116016,,,,,,,,,,,,,,,,, -109600,0.644859,1.4017882,,,,,,,,,,,,,,,,, -109700,0.6285076,1.467244,,,,,,,,,,,,,,,,, -109800,0.6775138,1.4408352,,,,,,,,,,,,,,,,, -109900,0.6542153,1.4670944,,,,,,,,,,,,,,,,, -110000,0.68680507,1.4504029,,,,,,,,,,,,,,,,, -110100,0.69170725,1.4411304,,,,,,,,,,,,,,,,, -110200,0.6941817,1.4893312,,,,,,,,,,,,,,,,, -110300,0.7081185,1.3864913,,,,,,,,,,,,,,,,, -110400,0.68926144,1.4360139,,,,,,,,,,,,,,,,, -110446,,,0.6905348300933838,1.441590428352356,34.91691907999978,0.6907168030738831,1.420137882232666,30.70651994536399,3000.0,0.7077334523200989,1.335753321647644,30.601082920116674,3003.0,38670.09367990494,63960.84578132629,38670.09367990494,25285.58715200424,1.6442315578460691,0.0 -110500,0.6878336,1.5353631,,,,,,,,,,,,,,,,, -110600,0.69460243,1.4207052,,,,,,,,,,,,,,,,, -110700,0.7173557,1.5011284,,,,,,,,,,,,,,,,, -110800,0.7164285,1.4549961,,,,,,,,,,,,,,,,, -110900,0.6985066,1.4809103,,,,,,,,,,,,,,,,, -111000,0.7068184,1.5202639,,,,,,,,,,,,,,,,, -111100,0.65588826,1.3854917,,,,,,,,,,,,,,,,, -111200,0.683952,1.3546858,,,,,,,,,,,,,,,,, -111300,0.68864214,1.475733,,,,,,,,,,,,,,,,, -111400,0.7221097,1.4338288,,,,,,,,,,,,,,,,, -111500,0.6993352,1.4702965,,,,,,,,,,,,,,,,, -111600,0.7140003,1.4151044,,,,,,,,,,,,,,,,, -111700,0.709949,1.385487,,,,,,,,,,,,,,,,, -111800,0.6854851,1.4527953,,,,,,,,,,,,,,,,, -111900,0.7286061,1.4019353,,,,,,,,,,,,,,,,, -112000,0.68084997,1.4075916,,,,,,,,,,,,,,,,, -112100,0.70028305,1.4278895,,,,,,,,,,,,,,,,, -112200,0.69345987,1.416699,,,,,,,,,,,,,,,,, -112300,0.68723255,1.4590777,,,,,,,,,,,,,,,,, -112400,0.69497657,1.4397515,,,,,,,,,,,,,,,,, -112500,0.74532133,1.5136006,,,,,,,,,,,,,,,,, -112600,0.7480454,1.4968299,,,,,,,,,,,,,,,,, -112700,0.7038567,1.4320372,,,,,,,,,,,,,,,,, -112800,0.7045451,1.3712043,,,,,,,,,,,,,,,,, -112848,,,0.7087449431419373,1.3407204151153564,36.47203853513563,0.691969096660614,1.416434407234192,31.080921156801807,3000.0,0.7069665193557739,1.3334004878997805,30.83949578884752,3003.0,39509.97726178169,65333.98978304863,39509.97726178169,25818.71938562393,1.6956148147583008,0.0 -112900,0.74737155,1.4380887,,,,,,,,,,,,,,,,, -113000,0.7249533,1.416148,,,,,,,,,,,,,,,,, -113100,0.7526048,1.4738033,,,,,,,,,,,,,,,,, -113200,0.7190989,1.3996013,,,,,,,,,,,,,,,,, -113300,0.737481,1.4402211,,,,,,,,,,,,,,,,, -113400,0.72427106,1.3989029,,,,,,,,,,,,,,,,, -113500,0.7330233,1.4840965,,,,,,,,,,,,,,,,, -113600,0.75079453,1.4293824,,,,,,,,,,,,,,,,, -113700,0.7400086,1.4266967,,,,,,,,,,,,,,,,, -113800,0.7718094,1.4679363,,,,,,,,,,,,,,,,, -113900,0.72692657,1.5042433,,,,,,,,,,,,,,,,, -114000,0.741008,1.3466587,,,,,,,,,,,,,,,,, -114100,0.74312454,1.463852,,,,,,,,,,,,,,,,, -114200,0.73431677,1.4224474,,,,,,,,,,,,,,,,, -114300,0.7392117,1.4541125,,,,,,,,,,,,,,,,, -114400,0.7219465,1.3985703,,,,,,,,,,,,,,,,, -114500,0.7030131,1.3821359,,,,,,,,,,,,,,,,, -114600,0.73670584,1.4182609,,,,,,,,,,,,,,,,, -114700,0.7419988,1.4303616,,,,,,,,,,,,,,,,, -114800,0.74019396,1.4274623,,,,,,,,,,,,,,,,, -114900,0.7175606,1.3797306,,,,,,,,,,,,,,,,, -115000,0.76284075,1.4006505,,,,,,,,,,,,,,,,, -115100,0.7374233,1.3801097,,,,,,,,,,,,,,,,, -115200,0.7641409,1.3660319,,,,,,,,,,,,,,,,, -115250,,,0.6991199254989624,1.3918790817260742,36.12728996133859,0.693630576133728,1.4105188846588137,31.00381458821544,3000.0,0.7082214951515198,1.327876091003418,30.625093720903884,3003.0,40349.93993067741,66701.25009989738,40349.93993067741,26345.901841640472,1.7358145713806152,0.0 -115300,0.74742234,1.4975839,,,,,,,,,,,,,,,,, -115400,0.7701842,1.4310772,,,,,,,,,,,,,,,,, -115500,0.7700296,1.4236773,,,,,,,,,,,,,,,,, -115600,0.73799294,1.4249616,,,,,,,,,,,,,,,,, -115700,0.7437402,1.4055982,,,,,,,,,,,,,,,,, -115800,0.7682733,1.3768901,,,,,,,,,,,,,,,,, -115900,0.7236362,1.4726095,,,,,,,,,,,,,,,,, -116000,1.8283889,1.3423862,,,,,,,,,,,,,,,,, -116100,0.7794309,1.3713199,,,,,,,,,,,,,,,,, -116200,0.78012574,1.4301682,,,,,,,,,,,,,,,,, -116300,0.7463946,1.4041572,,,,,,,,,,,,,,,,, -116400,0.7416127,1.3467484,,,,,,,,,,,,,,,,, -116500,0.7801263,1.4227784,,,,,,,,,,,,,,,,, -116600,0.76293826,1.3216674,,,,,,,,,,,,,,,,, -116700,0.7458425,1.41129,,,,,,,,,,,,,,,,, -116800,0.7695754,1.4492376,,,,,,,,,,,,,,,,, -116900,0.77990717,1.4645922,,,,,,,,,,,,,,,,, -117000,0.7778511,1.4622703,,,,,,,,,,,,,,,,, -117100,0.76213574,1.4396205,,,,,,,,,,,,,,,,, -117200,0.7820056,1.4311153,,,,,,,,,,,,,,,,, -117300,0.7733877,1.360896,,,,,,,,,,,,,,,,, -117400,1.0575488,1.4811336,,,,,,,,,,,,,,,,, -117500,0.75433874,1.3541855,,,,,,,,,,,,,,,,, -117600,0.7829238,1.3643433,,,,,,,,,,,,,,,,, -117651,,,0.6992983222007751,1.395785570144653,35.97326922225457,0.6929610371589661,1.409582495689392,31.041739597257934,3000.0,0.7084422707557678,1.3249000310897827,30.823761511515347,3003.0,41189.89118170738,68058.64119935036,41189.89118170738,26863.21662259102,1.7834343910217283,0.0 -117700,0.7810761,1.4309103,,,,,,,,,,,,,,,,, -117800,0.75411445,1.4110725,,,,,,,,,,,,,,,,, -117900,0.81017184,1.3935755,,,,,,,,,,,,,,,,, -118000,0.7740415,1.3325824,,,,,,,,,,,,,,,,, -118100,0.7936315,1.4745966,,,,,,,,,,,,,,,,, -118200,0.8258927,1.3979748,,,,,,,,,,,,,,,,, -118300,0.7738131,1.4323398,,,,,,,,,,,,,,,,, -118400,0.79693675,1.4073874,,,,,,,,,,,,,,,,, -118500,0.80604786,1.4137478,,,,,,,,,,,,,,,,, -118600,0.78801376,1.3796421,,,,,,,,,,,,,,,,, -118700,0.80272347,1.4695104,,,,,,,,,,,,,,,,, -118800,0.7994475,1.4014938,,,,,,,,,,,,,,,,, -118900,0.7866531,1.3759248,,,,,,,,,,,,,,,,, -119000,0.83225715,1.4606694,,,,,,,,,,,,,,,,, -119100,0.8276349,1.4434813,,,,,,,,,,,,,,,,, -119200,0.83171755,1.4297829,,,,,,,,,,,,,,,,, -119300,0.8031882,1.3873575,,,,,,,,,,,,,,,,, -119400,0.7868779,1.3615682,,,,,,,,,,,,,,,,, -119500,0.8185843,1.3443958,,,,,,,,,,,,,,,,, -119600,0.79724455,1.3727539,,,,,,,,,,,,,,,,, -119700,0.7967516,1.3825518,,,,,,,,,,,,,,,,, -119800,0.7805871,1.3413379,,,,,,,,,,,,,,,,, -119900,0.7952764,1.4157013,,,,,,,,,,,,,,,,, -120000,0.817104,1.3850391,,,,,,,,,,,,,,,,, -120052,,,0.7118774056434631,1.3286157846450806,36.789703266642434,0.6940025687217712,1.410871505737305,30.93815497869649,3000.0,0.7096391916275024,1.3247169256210327,30.95885245135478,3003.0,42029.80330038071,69427.26662182808,42029.80330038071,27391.812509059902,1.8251049518585205,0.0 -120100,0.80309993,1.421806,,,,,,,,,,,,,,,,, -120200,0.8110753,1.4134804,,,,,,,,,,,,,,,,, -120300,0.8429394,1.4168077,,,,,,,,,,,,,,,,, -120400,0.8247471,1.344393,,,,,,,,,,,,,,,,, -120500,0.85209197,1.4332247,,,,,,,,,,,,,,,,, -120600,0.804901,1.4370273,,,,,,,,,,,,,,,,, -120700,0.8058287,1.3439897,,,,,,,,,,,,,,,,, -120800,0.8104022,1.3869328,,,,,,,,,,,,,,,,, -120900,0.806617,1.3812141,,,,,,,,,,,,,,,,, -121000,0.80219936,1.398323,,,,,,,,,,,,,,,,, -121100,0.8338159,1.4603047,,,,,,,,,,,,,,,,, -121200,0.83362824,1.4274145,,,,,,,,,,,,,,,,, -121300,0.7828645,1.4133346,,,,,,,,,,,,,,,,, -121400,0.80365074,1.3804032,,,,,,,,,,,,,,,,, -121500,0.8263809,1.3705928,,,,,,,,,,,,,,,,, -121600,0.8233464,1.381179,,,,,,,,,,,,,,,,, -121700,0.80573964,1.3734182,,,,,,,,,,,,,,,,, -121800,0.85989404,1.410012,,,,,,,,,,,,,,,,, -121900,0.8255038,1.4382255,,,,,,,,,,,,,,,,, -122000,0.8658783,1.3682685,,,,,,,,,,,,,,,,, -122100,0.83083856,1.3204107,,,,,,,,,,,,,,,,, -122200,0.8171739,1.4121797,,,,,,,,,,,,,,,,, -122300,0.82926434,1.3344663,,,,,,,,,,,,,,,,, -122400,0.8385155,1.4133258,,,,,,,,,,,,,,,,, -122454,,,0.707240104675293,1.3504246473312378,36.60618759390239,0.6941017508506775,1.4063982963562012,31.11960561212031,3000.0,0.7113706469535828,1.3193117380142212,30.76013239179441,3003.0,42869.90353536606,70796.17909526825,42869.90353536606,27920.49778151512,1.8757200241088867,0.0 -122500,0.83399326,1.3565642,,,,,,,,,,,,,,,,, -122600,0.8193206,1.4090339,,,,,,,,,,,,,,,,, -122700,0.8010302,1.3293043,,,,,,,,,,,,,,,,, -122800,0.81510514,1.429319,,,,,,,,,,,,,,,,, -122900,0.8236286,1.3937529,,,,,,,,,,,,,,,,, -123000,0.856118,1.4309407,,,,,,,,,,,,,,,,, -123100,0.80238926,1.3276238,,,,,,,,,,,,,,,,, -123200,0.8346543,1.4341393,,,,,,,,,,,,,,,,, -123300,0.8481417,1.350522,,,,,,,,,,,,,,,,, -123400,0.85366714,1.370007,,,,,,,,,,,,,,,,, -123500,0.8322561,1.3965021,,,,,,,,,,,,,,,,, -123600,0.8286695,1.400866,,,,,,,,,,,,,,,,, -123700,0.8678803,1.4190599,,,,,,,,,,,,,,,,, -123800,0.8222572,1.3204246,,,,,,,,,,,,,,,,, -123900,0.8626611,1.3787118,,,,,,,,,,,,,,,,, -124000,0.8299937,1.3815567,,,,,,,,,,,,,,,,, -124100,0.8458643,1.4025195,,,,,,,,,,,,,,,,, -124200,0.8240874,1.3150554,,,,,,,,,,,,,,,,, -124300,0.80685157,1.4080222,,,,,,,,,,,,,,,,, -124400,0.82515466,1.3713349,,,,,,,,,,,,,,,,, -124500,0.829858,1.3695676,,,,,,,,,,,,,,,,, -124600,0.80392003,1.3319799,,,,,,,,,,,,,,,,, -124700,0.85473555,1.3244067,,,,,,,,,,,,,,,,, -124800,0.8591814,1.4222565,,,,,,,,,,,,,,,,, -124856,,,0.7027835845947266,1.3786816596984863,36.60843246209514,0.6942257285118103,1.4070228338241575,31.05294464668553,3000.0,0.711661159992218,1.3174381256103516,30.829254280160164,3003.0,43710.02218127251,72155.17084789276,43710.02218127251,28439.25016117096,1.9196293354034424,0.0 -124900,0.8306935,1.3740283,,,,,,,,,,,,,,,,, -125000,0.8367332,1.3595508,,,,,,,,,,,,,,,,, -125100,0.8556557,1.4377137,,,,,,,,,,,,,,,,, -125200,0.8401275,1.336053,,,,,,,,,,,,,,,,, -125300,0.8107104,1.3659569,,,,,,,,,,,,,,,,, -125400,0.85497385,1.4347253,,,,,,,,,,,,,,,,, -125500,0.82233393,1.406336,,,,,,,,,,,,,,,,, -125600,0.8263696,1.3607497,,,,,,,,,,,,,,,,, -125700,0.8369458,1.3834801,,,,,,,,,,,,,,,,, -125800,0.80712336,1.3865477,,,,,,,,,,,,,,,,, -125900,0.86324817,1.2602906,,,,,,,,,,,,,,,,, -126000,0.87317175,1.4029024,,,,,,,,,,,,,,,,, -126100,0.8371494,1.3921281,,,,,,,,,,,,,,,,, -126200,0.8555797,1.3027183,,,,,,,,,,,,,,,,, -126300,0.8558547,1.4244463,,,,,,,,,,,,,,,,, -126400,0.8541231,1.3795854,,,,,,,,,,,,,,,,, -126500,0.8278933,1.2898428,,,,,,,,,,,,,,,,, -126600,0.8256708,1.393011,,,,,,,,,,,,,,,,, -126700,0.87490237,1.3042401,,,,,,,,,,,,,,,,, -126800,0.8559103,1.3560367,,,,,,,,,,,,,,,,, -126900,0.832504,1.4301274,,,,,,,,,,,,,,,,, -127000,0.8554487,1.3227997,,,,,,,,,,,,,,,,, -127100,0.8498808,1.3794965,,,,,,,,,,,,,,,,, -127200,0.85384804,1.397481,,,,,,,,,,,,,,,,, -127257,,,0.7099313735961914,1.3409501314163208,36.71479864667537,0.6950068473815918,1.405134677886963,31.12469228435073,3000.0,0.7116146683692932,1.3157283067703247,30.96810588775437,3003.0,44549.97412252426,73504.98891401291,44549.97412252426,28948.995346546173,1.9638566970825195,0.0 -127300,0.8380878,1.3092438,,,,,,,,,,,,,,,,, -127400,0.8313776,1.3318412,,,,,,,,,,,,,,,,, -127500,0.85247356,1.391019,,,,,,,,,,,,,,,,, -127600,0.8766121,1.4180377,,,,,,,,,,,,,,,,, -127700,0.84022385,1.4362532,,,,,,,,,,,,,,,,, -127800,0.84345275,1.3527067,,,,,,,,,,,,,,,,, -127900,0.83730876,1.3997736,,,,,,,,,,,,,,,,, -128000,0.860483,1.3475308,,,,,,,,,,,,,,,,, -128100,0.8304709,1.4600353,,,,,,,,,,,,,,,,, -128200,0.8214742,1.3899734,,,,,,,,,,,,,,,,, -128300,0.8723387,1.3994884,,,,,,,,,,,,,,,,, -128400,0.85601074,1.3300908,,,,,,,,,,,,,,,,, -128500,0.84521055,1.3162702,,,,,,,,,,,,,,,,, -128600,0.82060546,1.3597182,,,,,,,,,,,,,,,,, -128700,0.836699,1.3763851,,,,,,,,,,,,,,,,, -128800,0.843079,1.3955818,,,,,,,,,,,,,,,,, -128900,0.83331454,1.3743869,,,,,,,,,,,,,,,,, -129000,0.8436863,1.3525711,,,,,,,,,,,,,,,,, -129100,0.8340946,1.4298333,,,,,,,,,,,,,,,,, -129200,0.8807164,1.3987046,,,,,,,,,,,,,,,,, -129300,0.8391364,1.4391468,,,,,,,,,,,,,,,,, -129400,0.8457449,1.4169763,,,,,,,,,,,,,,,,, -129500,0.8676963,1.3330681,,,,,,,,,,,,,,,,, -129600,0.8435983,1.4188534,,,,,,,,,,,,,,,,, -129659,,,0.7101288437843323,1.3376935720443726,36.566009419018926,0.6951928734779358,1.4060457944869995,31.22496853474892,3000.0,0.7120562791824341,1.3163964748382568,30.971645040855503,3003.0,45390.17083954811,74854.14743614197,45390.17083954811,29457.83890938759,2.007380247116089,0.0 -129700,0.84791905,1.434081,,,,,,,,,,,,,,,,, -129800,0.83318055,1.3614748,,,,,,,,,,,,,,,,, -129900,0.8524017,1.3538693,,,,,,,,,,,,,,,,, -130000,0.8045576,1.3674159,,,,,,,,,,,,,,,,, -130100,0.85284257,1.4076965,,,,,,,,,,,,,,,,, -130200,0.8219044,1.3784114,,,,,,,,,,,,,,,,, -130300,0.83719844,1.3515744,,,,,,,,,,,,,,,,, -130400,0.8422267,1.2964019,,,,,,,,,,,,,,,,, -130500,0.8369664,1.3596851,,,,,,,,,,,,,,,,, -130600,0.8288282,1.2954699,,,,,,,,,,,,,,,,, -130700,0.839193,1.374803,,,,,,,,,,,,,,,,, -130800,0.84896344,1.4043069,,,,,,,,,,,,,,,,, -130900,0.85177636,1.279088,,,,,,,,,,,,,,,,, -131000,0.84428495,1.282631,,,,,,,,,,,,,,,,, -131100,0.8113351,1.3086402,,,,,,,,,,,,,,,,, -131200,0.8173577,1.3633459,,,,,,,,,,,,,,,,, -131300,0.8397461,1.3638889,,,,,,,,,,,,,,,,, -131400,0.82026106,1.3974276,,,,,,,,,,,,,,,,, -131500,0.84849775,1.3873482,,,,,,,,,,,,,,,,, -131600,0.83240783,1.347479,,,,,,,,,,,,,,,,, -131700,0.827629,1.3629206,,,,,,,,,,,,,,,,, -131800,0.85134035,1.4083885,,,,,,,,,,,,,,,,, -131900,0.8233073,1.3311667,,,,,,,,,,,,,,,,, -132000,0.8431301,1.3350251,,,,,,,,,,,,,,,,, -132060,,,0.7102059721946716,1.3387534618377686,36.37749326792103,0.6951432824134827,1.4058610200881958,31.32584139390445,3000.0,0.7124397158622742,1.315784215927124,30.99279959879872,3003.0,46230.35232543945,76193.97706127167,46230.35232543945,29957.364458322525,2.050343990325928,0.0 -132100,0.8635229,1.391247,,,,,,,,,,,,,,,,, -132200,0.8259952,1.3883196,,,,,,,,,,,,,,,,, -132300,0.84292424,1.380575,,,,,,,,,,,,,,,,, -132400,0.88064194,1.4208524,,,,,,,,,,,,,,,,, -132500,0.82335657,1.3126972,,,,,,,,,,,,,,,,, -132600,0.84139115,1.3959558,,,,,,,,,,,,,,,,, -132700,0.8464824,1.4133679,,,,,,,,,,,,,,,,, -132800,0.8117151,1.3211056,,,,,,,,,,,,,,,,, -132900,0.80562204,1.3242257,,,,,,,,,,,,,,,,, -133000,0.81887573,1.3607439,,,,,,,,,,,,,,,,, -133100,0.84387,1.3290178,,,,,,,,,,,,,,,,, -133200,0.8243986,1.2882496,,,,,,,,,,,,,,,,, -133300,0.8721327,1.4440203,,,,,,,,,,,,,,,,, -133333,,,0.7138813138008118,1.3141133785247805,36.535556096203386,0.6952672600746155,1.4057084321975708,31.22940684635679,3000.0,0.7124048471450806,1.3156583309173584,30.96912742667688,3003.0,46675.37580132485,77146.18306612968,46675.37580132485,30464.457669734955,2.0986831188201904,0.0 -133333,,,,,,,,,,,,,,46675.375801324844,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 6428375c2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -838.6662838459015,0.0,27.63993668556213,1,0,27.63993668556213,0.0007088489946909,0.0,11.191027641296388,3003,866.306262254715,0.0004915465833619,0.0,11.19037628173828,0.0004835649742744,0.0,11.190281867980955,3000 -1295.6509058475494,0.0188579559326171,868.045458316803,2401,0,868.045458316803,0.536575436592102,18.380749662236845,2.5754520893096924,3003,2163.79090499878,0.5356904864311218,23.428443626878675,2.585400342941284,0.5370299220085144,19.524479217978104,2.548142910003662,3000 -1774.7314743995669,0.044546365737915,1708.1538841724396,4803,0,1708.1538841724396,0.6001161932945251,22.17443894805829,2.041588306427002,3003,3483.080261707306,0.5805889964103699,27.33088719134872,2.216299295425415,0.5971283316612244,23.725006938623974,2.0716443061828613,3000 -2270.4233243465424,0.0717079639434814,2548.2274305820465,7205,0,2548.2274305820465,0.6096101403236389,23.194267112786346,1.9771912097930908,3003,4818.948817253113,0.5939295291900635,27.58276559377453,2.099783182144165,0.6013316512107849,24.02121224259647,2.019782781600952,3000 -2764.5327610969543,0.0976550579071044,3388.435542821884,9607,0,3388.435542821884,0.6173610091209412,23.49182647289432,1.931437611579895,3003,6153.368245840073,0.5930240750312805,28.102238158911632,2.122828960418701,0.610048234462738,24.52583438503706,1.98089861869812,3000 -3368.605906009674,0.1242556571960449,4228.665568351746,12009,0,4228.665568351746,0.6177793145179749,23.448882076309648,1.9158536195755005,3003,7597.773768186569,0.5922099947929382,28.130188839500367,2.122587919235229,0.6139291524887085,25.084190063223115,1.9412788152694704,3000 -4191.980131864548,0.1532905101776123,5068.869497776032,14411,0,5068.869497776032,0.6252977848052979,23.781386907356108,1.8740822076797483,3003,9261.455813646317,0.5989102721214294,27.91491412501564,2.0769054889678955,0.61870276927948,25.03533095862288,1.9237356185913088,3000 -4739.305437326431,0.1807782649993896,5908.888803958893,16811,0,5908.888803958893,0.6220208406448364,23.640765822774625,1.8677488565444944,3003,10648.903832674026,0.5971294641494751,28.010472281108505,2.082860946655273,0.6194343566894531,25.16562009613152,1.906266689300537,3000 -5261.734477043152,0.2103769779205322,6748.833927154541,19212,0,6748.833927154541,0.6279588937759399,24.208979117206603,1.8503223657608032,3003,12011.384460687636,0.6081152558326721,29.55355088171276,1.988739967346192,0.6222489476203918,25.59932936158585,1.901801586151123,3000 -5764.907237768173,0.2401816844940185,7588.848192691803,21613,0,7588.848192691803,0.6319214701652527,25.061395584676017,1.8304063081741333,3003,13354.67732167244,0.6023226380348206,28.267704243990025,2.051201105117798,0.6232532858848572,25.35422343622899,1.8913344144821167,3000 -6240.782923460007,0.2702040672302246,8428.94619846344,24014,0,8428.94619846344,0.6272500157356262,24.59423009944916,1.8352961540222168,3003,14670.757593154907,0.601838231086731,28.35724378331637,2.06695556640625,0.622211754322052,25.447011922234104,1.888195872306824,3000 -6728.776242017746,0.2997102737426758,9269.041829109192,26415,0,9269.041829109192,0.6306548118591309,24.450720016161004,1.823778748512268,3003,15998.954328775406,0.6068945527076721,29.215509208704106,2.0167477130889893,0.6261794567108154,26.067529651612723,1.86761474609375,3000 -7361.727405786514,0.3307063579559326,10109.0739672184,28816,0,10109.0739672184,0.6310731768608093,24.502543676121828,1.814279556274414,3003,17472.045060634613,0.6076061129570007,28.61913311876729,2.0170085430145264,0.6236252188682556,25.54049461004584,1.8711217641830444,3000 -7889.079682350159,0.3619084358215332,10949.304612398148,31218,0,10949.304612398148,0.630387544631958,24.071335933790778,1.82143235206604,3003,18839.73539876938,0.6038236021995544,28.42866717212713,2.032376766204834,0.6241955757141113,25.288898614166587,1.8636945486068728,3000 -8387.22826910019,0.3919835090637207,11789.44777727127,33620,0,11789.44777727127,0.6363604664802551,24.920774993304907,1.7914565801620483,3003,20178.132102012634,0.609279990196228,28.38434953125899,1.99933660030365,0.6294032335281372,25.79646487757145,1.85016667842865,3000 -8959.241186380386,0.4224326610565185,12629.596199512482,36022,0,12629.596199512482,0.6338969469070435,24.793330930167617,1.800976037979126,3003,21590.39995884896,0.6051180958747864,28.50333226972837,2.022434711456299,0.6277913451194763,25.85400422027908,1.8456335067749023,3000 -9454.756899356842,0.452505350112915,13469.625989198685,38423,0,13469.625989198685,0.6375108957290649,25.07801187053838,1.7810250520706177,3003,22926.05318045616,0.6163857579231262,29.14323735552401,1.9339035749435425,0.6316102743148804,26.215031467429903,1.8279942274093628,3000 -9978.96815109253,0.4831573963165283,14309.583218336104,40824,0,14309.583218336104,0.6380454301834106,24.74312727864808,1.7811000347137451,3003,24290.329845905304,0.6103520393371582,29.142409587723943,1.9843041896820068,0.6277045607566833,25.61308660868012,1.833596467971801,3000 -10813.402871608734,0.5145649909973145,15149.78208065033,43226,0,15149.78208065033,0.6379989981651306,24.556187006987688,1.769517183303833,3003,25965.07247161865,0.6160485148429871,29.20263873077558,1.9517220258712769,0.6300479769706726,22.43720782495505,1.814311981201172,3000 -11449.814586162567,0.5513138771057129,15989.976817846298,45628,0,15989.976817846298,0.6361513137817383,24.85202517879502,1.7746611833572388,3003,27441.79203939438,0.6102014183998108,29.17744358049914,1.9810327291488647,0.6297379732131958,25.96525554826419,1.8217298984527588,3000 -12047.255705833437,0.5844166278839111,16830.157977104187,48030,0,16830.157977104187,0.6418104767799377,25.610204331052604,1.7534323930740356,3003,28879.521926641464,0.6131874322891235,29.10088251542185,1.9688488245010376,0.6333833336830139,26.4319809165706,1.80632483959198,3000 -12541.4399433136,0.6211183071136475,17670.106913089752,50431,0,17670.106913089752,0.6455406546592712,25.79031645404289,1.725481033325195,3003,30213.769419670105,0.6287590861320496,30.29873589696989,1.834293246269226,0.6369046568870544,26.45663032630946,1.7885897159576416,3000 -13070.50362753868,0.6545243263244629,18510.01056957245,52832,0,18510.01056957245,0.6466097235679626,25.517370648362107,1.7135682106018066,3003,31582.84491109848,0.6180623173713684,29.478945279105627,1.9361435174942017,0.6362723112106323,26.286487484829664,1.7838091850280762,3000 -13590.657889842989,0.6880850791931152,19349.917988538746,55233,0,19349.917988538746,0.6476556062698364,25.627177441123937,1.7083313465118408,3003,32943.01463651657,0.6131559014320374,29.613710414488946,1.9650341272354128,0.6389381289482117,26.565091468583024,1.767750859260559,3000 -14172.742085933683,0.7207436561584473,20189.931703090668,57634,0,20189.931703090668,0.6502469778060913,26.16670611175656,1.700022578239441,3003,34365.222650527954,0.623281717300415,30.03921252766637,1.8893014192581177,0.639272928237915,26.7444395678232,1.763260841369629,3000 -14631.463346242905,0.7540163993835449,21030.09056377411,60036,0,21030.09056377411,0.6512114405632019,25.748610033015733,1.682421326637268,3003,35664.21192359924,0.6204534769058228,30.193340853845427,1.9200528860092163,0.6428934335708618,26.78540750934711,1.7423086166381836,3000 -15134.199719667437,0.7899153232574463,21870.185875177383,62438,0,21870.185875177383,0.6535122990608215,26.31763856807123,1.6790618896484375,3003,37007.15480089188,0.6210818886756897,29.44747130858823,1.9125397205352783,0.6429926156997681,27.028646288240783,1.7286237478256226,3000 -15622.89288187027,0.8244888782501221,22710.087017774586,64838,0,22710.087017774586,0.653872549533844,26.29416092890208,1.662144660949707,3003,38335.86119651794,0.6214345097541809,29.87296898313337,1.899621844291687,0.6450632810592651,26.96320628106604,1.725380539894104,3000 -16216.848822593687,0.8594467639923096,23550.073257923126,67239,0,23550.073257923126,0.6557550430297852,26.504229713711247,1.6552984714508057,3003,39769.913810014725,0.6219140887260437,29.898942772496444,1.8977916240692136,0.6444805264472961,27.21427000478681,1.7192747592926023,3000 -16739.037479639053,0.9037342071533204,24390.02137088776,69640,0,24390.02137088776,0.6588344573974609,26.86768368161496,1.639184832572937,3003,41132.171432971954,0.6379268765449524,30.41896362964424,1.781161189079285,0.649774968624115,27.45221541734196,1.6948308944702148,3000 -17305.21870613098,0.9400687217712402,25229.988456487656,72042,0,25229.988456487656,0.6625297665596008,27.263224574075057,1.6213717460632324,3003,42538.432683467865,0.6333906650543213,30.65259166358259,1.8417229652404783,0.6506924629211426,27.47268478695876,1.6889138221740725,3000 -17854.567930936813,0.9772353172302246,26069.895349264145,74443,0,26069.895349264145,0.662494957447052,26.571996580797872,1.607891082763672,3003,43927.80186915398,0.6327632665634155,30.128673062601425,1.840661883354187,0.6526143550872803,27.47545997818604,1.6699992418289185,3000 -18422.05643105507,1.0228424072265625,26909.92446255684,76844,0,26909.92446255684,0.6651095151901245,27.0778855815688,1.5921093225479126,3003,45335.44073653221,0.6360828280448914,30.947704964285027,1.7997148036956787,0.6537302732467651,27.672505117892413,1.655122995376587,3000 -18953.771744966507,1.0598394870758057,27749.865286827087,79245,0,27749.865286827087,0.666852593421936,27.10222512757891,1.5803241729736328,3003,46707.21108341217,0.6354596018791199,30.61698412674778,1.8094682693481443,0.6570656299591064,27.837613650228427,1.6474754810333252,3000 -19483.31466794014,1.0974700450897217,28589.836373090744,81645,0,28589.836373090744,0.6707803606987,27.857890901754264,1.570162057876587,3003,48076.842170238495,0.6552977561950684,32.10947252151286,1.6481138467788696,0.6575492024421692,27.99446188450933,1.6359890699386597,3000 -20069.23208117485,1.141160249710083,29429.99524831772,84048,0,29429.99524831772,0.6720934510231018,27.58812304327058,1.5535181760787964,3003,49503.03884482384,0.6372441649436951,31.231938001811475,1.792824149131775,0.6590867042541504,28.22790193040609,1.6225743293762207,3000 -20759.51155924797,1.185424566268921,30269.94028711319,86449,0,30269.94028711319,0.6744524240493774,27.97908133401334,1.5341390371322632,3003,51033.38291668892,0.6410945057868958,31.41908621569645,1.7650405168533323,0.6631659865379333,28.64836691973624,1.604163408279419,3000 -21304.90257930756,1.2225456237792969,31109.83723974228,88851,0,31109.83723974228,0.6743943095207214,27.792506126010068,1.5237705707550049,3003,52418.78368616104,0.6504623293876648,31.84581123567893,1.689774990081787,0.6630296111106873,28.532335408202147,1.5953129529953003,3000 -21846.217509269714,1.2612571716308594,31949.88634347916,91254,0,31949.88634347916,0.6779152750968933,28.25653450368869,1.5046041011810305,3003,53800.26221823692,0.6479614973068237,31.76366900205668,1.720964789390564,0.665571391582489,28.862445004881646,1.5849394798278809,3000 -22403.69824290276,1.300079584121704,32789.87203192711,93656,0,32789.87203192711,0.6798908114433289,28.425006038283964,1.4960421323776243,3003,55197.84358978272,0.6468232274055481,31.68318684415901,1.723307490348816,0.6688323616981506,29.080144950437283,1.5702892541885376,3000 -22949.19913959503,1.338247776031494,33629.953258514404,96058,0,33629.953258514404,0.6813665628433228,28.80327442841708,1.4757575988769531,3003,56583.53949093819,0.6511693596839905,31.714046536608137,1.6873743534088137,0.6717833280563354,29.245886051656544,1.547297716140747,3000 -23462.70897555352,1.385880470275879,34470.01172947884,98461,0,34470.01172947884,0.6865260601043701,29.0640595971544,1.456973910331726,3003,57937.23023700714,0.6557013392448425,32.65956777295396,1.67685866355896,0.6736308336257935,29.419734290930183,1.5366019010543823,3000 -24032.816915750504,1.42657732963562,35309.92882537842,100863,0,35309.92882537842,0.6878159642219543,29.20659989243813,1.442979335784912,3003,59347.370725631714,0.6670910716056824,33.74426266233717,1.576935052871704,0.6752303242683411,29.383349787408346,1.5253691673278809,3000 -24612.61792945861,1.474708080291748,36149.880427360535,103265,0,36149.880427360535,0.6917320489883423,29.722389977868755,1.421906352043152,3003,60767.247957229614,0.6593321561813354,32.99818121622395,1.6377383470535278,0.676978588104248,29.280341283493588,1.510822296142578,3000 -25177.135360479355,1.5139610767364502,36990.06665062904,105668,0,36990.06665062904,0.6930451393127441,29.341508113582247,1.4107818603515625,3003,62172.06532788277,0.6658817529678345,32.95288097209237,1.6051431894302368,0.6797559857368469,29.482852243566366,1.4939372539520264,3000 -25700.941499471664,1.5608172416687012,37830.27451658249,108070,0,37830.27451658249,0.6957992315292358,29.608836478521752,1.3981857299804688,3003,63536.203300237656,0.6727055311203003,33.9455520747962,1.5570759773254397,0.6823598146438599,29.72807262910545,1.4784127473831177,3000 -26255.110206842422,1.6010394096374512,38670.3262925148,110471,0,38670.3262925148,0.6989832520484924,29.73907053500585,1.3790864944458008,3003,64930.54170131683,0.6731343865394592,33.786392688007275,1.5567851066589355,0.6843188405036926,30.15936458541903,1.4721633195877075,3000 -26827.57296895981,1.6411964893341064,39510.36837506294,112873,0,39510.36837506294,0.6995874643325806,30.10767739633273,1.3684998750686646,3003,66343.16135883331,0.6936594247817993,35.508655574660985,1.431340217590332,0.685050368309021,30.06185313900829,1.4566941261291504,3000 -27465.951172590256,1.6827702522277832,40350.29947733879,115275,0,40350.29947733879,0.7013421654701233,30.101626542932724,1.357330322265625,3003,67821.58649492264,0.6838340163230896,34.26181205626829,1.4911880493164062,0.6880261898040771,30.37748177908537,1.4454679489135742,3000 -28020.668506383896,1.7252159118652344,41190.48154783249,117677,0,41190.48154783249,0.7044913172721863,30.434828744976627,1.3434144258499146,3003,69216.60397958755,0.6806775331497192,34.7215799842478,1.5084309577941897,0.6893280744552612,30.50947906381349,1.4370354413986206,3000 -28575.47600364685,1.7752346992492676,42030.51136517525,120079,0,42030.51136517525,0.7062111496925354,30.584510235579657,1.3397371768951416,3003,70611.56590008736,0.6928339004516602,35.37524108993434,1.4373910427093506,0.690034806728363,30.701788225822256,1.4303141832351685,3000 -29121.44729280472,1.8173811435699463,42870.60580945015,122481,0,42870.60580945015,0.7073150873184204,30.721067569448703,1.328243374824524,3003,71997.75002217293,0.688599705696106,35.63986574927191,1.4655152559280396,0.6912251710891724,30.711851471746595,1.4246001243591309,3000 -29661.2113044262,1.8602821826934808,43710.570806741714,124883,0,43710.570806741714,0.7088490128517151,30.963558849323057,1.3179659843444824,3003,73377.59749126434,0.6913896799087524,35.22434842341379,1.4493789672851562,0.6933205723762512,30.862479457424147,1.415845274925232,3000 -30192.34838557244,1.9036564826965328,44550.7475271225,127286,0,44550.7475271225,0.7101272344589233,30.82745926563664,1.3143149614334106,3003,74749.03038787842,0.6950476765632629,35.723502797572735,1.428101658821106,0.6939281225204468,30.916063388296088,1.4126336574554443,3000 -30699.643965244293,1.9506874084472656,45390.82706856728,129688,0,45390.82706856728,0.7101272344589233,31.011806972890582,1.311076045036316,3003,76096.52921843529,0.6956238150596619,35.90949567414748,1.4236267805099487,0.6945853233337402,30.926567537057483,1.410330057144165,3000 -31225.78633785248,1.9961497783660889,46230.71160531044,132089,0,46230.71160531044,0.7106966376304626,30.939606512181268,1.3105497360229492,3003,77462.6783709526,0.7016637921333313,35.79594700241151,1.386156439781189,0.694386899471283,30.910722132760736,1.4098578691482544,3000 -31743.91850733757,2.04841947555542,46665.6048541069,133333,0,46665.6048541069,0.7108826041221619,30.902745284398407,1.3107471466064453,3003,78415.79715752602,0.6967169642448425,35.835157411181875,1.42378568649292,0.69430011510849,30.869154358846924,1.410101294517517,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/measurements.csv deleted file mode 100644 index 6c362deec..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.8829894,11.16369,,,,,,,,,,,,,,,,, -1,,,0.0004915465833619,11.19037628173828,0.0,0.0004835649742744,11.190281867980955,0.0,3000.0,0.0007088489946909,11.191027641296388,0.0,3003.0,27.63993668556213,866.306262254715,27.63993668556213,838.6662838459015,0.0,0.0 -100,0.7193015,7.6090302,,,,,,,,,,,,,,,,, -200,0.6340204,6.695078,,,,,,,,,,,,,,,,, -300,0.7479122,5.9650474,,,,,,,,,,,,,,,,, -400,0.65326035,5.4351273,,,,,,,,,,,,,,,,, -500,0.680016,5.116533,,,,,,,,,,,,,,,,, -600,0.6123816,4.8370442,,,,,,,,,,,,,,,,, -700,0.5395627,4.473989,,,,,,,,,,,,,,,,, -800,0.66514146,4.2458773,,,,,,,,,,,,,,,,, -900,0.41849777,3.9969108,,,,,,,,,,,,,,,,, -1000,0.5168086,3.800872,,,,,,,,,,,,,,,,, -1100,0.36535215,3.631072,,,,,,,,,,,,,,,,, -1200,0.41841188,3.5184464,,,,,,,,,,,,,,,,, -1300,0.377614,3.4442353,,,,,,,,,,,,,,,,, -1400,0.34311792,3.328879,,,,,,,,,,,,,,,,, -1500,0.3182164,3.16109,,,,,,,,,,,,,,,,, -1600,0.31747648,3.1040888,,,,,,,,,,,,,,,,, -1700,0.20649077,2.981342,,,,,,,,,,,,,,,,, -1800,0.2673987,2.8660893,,,,,,,,,,,,,,,,, -1900,0.25400835,2.8744411,,,,,,,,,,,,,,,,, -2000,0.19078232,2.8157282,,,,,,,,,,,,,,,,, -2100,0.19430041,2.6960568,,,,,,,,,,,,,,,,, -2200,0.22886778,2.6833525,,,,,,,,,,,,,,,,, -2300,0.22640194,2.6434317,,,,,,,,,,,,,,,,, -2400,0.24623276,2.6225646,,,,,,,,,,,,,,,,, -2401,,,0.5356904864311218,2.585400342941284,23.428443626878675,0.5370299220085144,2.548142910003662,19.524479217978104,3000.0,0.536575436592102,2.5754520893096924,18.380749662236845,3003.0,868.045458316803,2163.79090499878,868.045458316803,1295.6509058475494,0.0188579559326171,0.0 -2500,0.27137583,2.6439872,,,,,,,,,,,,,,,,, -2600,0.23822749,2.5441656,,,,,,,,,,,,,,,,, -2700,0.1954355,2.5420756,,,,,,,,,,,,,,,,, -2800,0.16879533,2.3519745,,,,,,,,,,,,,,,,, -2900,0.20660149,2.5108376,,,,,,,,,,,,,,,,, -3000,0.21344821,2.4606972,,,,,,,,,,,,,,,,, -3100,0.25600433,2.3798664,,,,,,,,,,,,,,,,, -3200,0.31468746,2.2963068,,,,,,,,,,,,,,,,, -3300,0.48200667,2.2792294,,,,,,,,,,,,,,,,, -3400,0.312465,2.326338,,,,,,,,,,,,,,,,, -3500,0.28832453,2.233379,,,,,,,,,,,,,,,,, -3600,0.24294989,2.280182,,,,,,,,,,,,,,,,, -3700,0.46167058,2.1865463,,,,,,,,,,,,,,,,, -3800,0.27445742,2.2363727,,,,,,,,,,,,,,,,, -3900,0.20571579,2.3234477,,,,,,,,,,,,,,,,, -4000,0.44365278,2.281809,,,,,,,,,,,,,,,,, -4100,0.42484483,2.2425694,,,,,,,,,,,,,,,,, -4200,0.25601313,2.2847254,,,,,,,,,,,,,,,,, -4300,0.352862,2.2678516,,,,,,,,,,,,,,,,, -4400,0.3927038,2.2960064,,,,,,,,,,,,,,,,, -4500,0.30086267,2.1929693,,,,,,,,,,,,,,,,, -4600,0.5900669,2.3110561,,,,,,,,,,,,,,,,, -4700,0.41310564,2.2029526,,,,,,,,,,,,,,,,, -4800,0.58356136,2.1689126,,,,,,,,,,,,,,,,, -4803,,,0.5805889964103699,2.216299295425415,27.33088719134872,0.5971283316612244,2.0716443061828613,23.725006938623974,3000.0,0.6001161932945251,2.041588306427002,22.17443894805829,3003.0,1708.1538841724396,3483.080261707306,1708.1538841724396,1774.7314743995669,0.044546365737915,0.0 -4900,0.37206438,2.212171,,,,,,,,,,,,,,,,, -5000,0.2090131,2.1679187,,,,,,,,,,,,,,,,, -5100,0.4222025,2.1906817,,,,,,,,,,,,,,,,, -5200,0.31346753,2.1696796,,,,,,,,,,,,,,,,, -5300,0.2756263,2.1576612,,,,,,,,,,,,,,,,, -5400,0.32541353,2.2422404,,,,,,,,,,,,,,,,, -5500,0.31210104,2.1590953,,,,,,,,,,,,,,,,, -5600,0.4192217,2.2248874,,,,,,,,,,,,,,,,, -5700,0.64795333,2.1374505,,,,,,,,,,,,,,,,, -5800,0.34833363,2.1088762,,,,,,,,,,,,,,,,, -5900,0.28659263,2.2247314,,,,,,,,,,,,,,,,, -6000,0.34942922,2.0657365,,,,,,,,,,,,,,,,, -6100,0.577994,2.1383204,,,,,,,,,,,,,,,,, -6200,0.5044883,2.2454789,,,,,,,,,,,,,,,,, -6300,0.543911,2.1667712,,,,,,,,,,,,,,,,, -6400,0.3493749,2.1950808,,,,,,,,,,,,,,,,, -6500,0.347769,2.2092226,,,,,,,,,,,,,,,,, -6600,0.29470164,2.1975074,,,,,,,,,,,,,,,,, -6700,0.47619712,2.2299023,,,,,,,,,,,,,,,,, -6800,0.5590005,2.1395779,,,,,,,,,,,,,,,,, -6900,0.28900266,2.1926036,,,,,,,,,,,,,,,,, -7000,0.681238,2.1714935,,,,,,,,,,,,,,,,, -7100,0.33873403,2.1550047,,,,,,,,,,,,,,,,, -7200,0.40421006,2.1547875,,,,,,,,,,,,,,,,, -7205,,,0.5939295291900635,2.099783182144165,27.58276559377453,0.6013316512107849,2.019782781600952,24.02121224259647,3000.0,0.6096101403236389,1.9771912097930908,23.194267112786346,3003.0,2548.2274305820465,4818.948817253113,2548.2274305820465,2270.4233243465424,0.0717079639434814,0.0 -7300,0.40429214,2.119563,,,,,,,,,,,,,,,,, -7400,0.30317023,2.0850394,,,,,,,,,,,,,,,,, -7500,0.38321123,2.1656914,,,,,,,,,,,,,,,,, -7600,0.38133237,2.1292777,,,,,,,,,,,,,,,,, -7700,0.40345094,2.1180105,,,,,,,,,,,,,,,,, -7800,0.30496222,2.1369276,,,,,,,,,,,,,,,,, -7900,0.43539888,2.102462,,,,,,,,,,,,,,,,, -8000,0.3997565,2.0696847,,,,,,,,,,,,,,,,, -8100,0.5320926,2.1369028,,,,,,,,,,,,,,,,, -8200,0.7054833,2.0505438,,,,,,,,,,,,,,,,, -8300,0.6180346,2.2171342,,,,,,,,,,,,,,,,, -8400,0.25242603,2.0867388,,,,,,,,,,,,,,,,, -8500,0.40306947,2.0991473,,,,,,,,,,,,,,,,, -8600,0.4424814,2.1310039,,,,,,,,,,,,,,,,, -8700,0.4010813,2.1385868,,,,,,,,,,,,,,,,, -8800,0.31041968,2.076581,,,,,,,,,,,,,,,,, -8900,0.63324296,2.1392329,,,,,,,,,,,,,,,,, -9000,0.35411766,2.182796,,,,,,,,,,,,,,,,, -9100,0.34272155,2.0898352,,,,,,,,,,,,,,,,, -9200,0.40148,2.2146473,,,,,,,,,,,,,,,,, -9300,0.25410774,2.0749564,,,,,,,,,,,,,,,,, -9400,0.34959668,2.121759,,,,,,,,,,,,,,,,, -9500,0.7388337,2.1239603,,,,,,,,,,,,,,,,, -9600,0.28153816,2.10897,,,,,,,,,,,,,,,,, -9607,,,0.5930240750312805,2.122828960418701,28.102238158911632,0.610048234462738,1.98089861869812,24.52583438503706,3000.0,0.6173610091209412,1.931437611579895,23.49182647289432,3003.0,3388.435542821884,6153.368245840073,3388.435542821884,2764.5327610969543,0.0976550579071044,0.0 -9700,0.47755128,2.0546732,,,,,,,,,,,,,,,,, -9800,0.47117558,2.1015844,,,,,,,,,,,,,,,,, -9900,0.25133514,2.0609782,,,,,,,,,,,,,,,,, -10000,0.27464437,2.0427165,,,,,,,,,,,,,,,,, -10100,0.46626738,2.051321,,,,,,,,,,,,,,,,, -10200,0.40677154,2.1017725,,,,,,,,,,,,,,,,, -10300,0.3218557,2.1067982,,,,,,,,,,,,,,,,, -10400,0.3357541,2.0371628,,,,,,,,,,,,,,,,, -10500,0.2881717,2.1713033,,,,,,,,,,,,,,,,, -10600,0.51391894,2.1211715,,,,,,,,,,,,,,,,, -10700,0.61388206,2.2324502,,,,,,,,,,,,,,,,, -10800,0.45345852,2.1395464,,,,,,,,,,,,,,,,, -10900,0.2496212,2.0616393,,,,,,,,,,,,,,,,, -11000,0.25756723,2.1529982,,,,,,,,,,,,,,,,, -11100,0.36143398,2.0637953,,,,,,,,,,,,,,,,, -11200,0.429218,2.1464972,,,,,,,,,,,,,,,,, -11300,0.689008,2.0728931,,,,,,,,,,,,,,,,, -11400,0.75012606,2.0529647,,,,,,,,,,,,,,,,, -11500,0.25224203,2.08326,,,,,,,,,,,,,,,,, -11600,0.5732191,2.0978181,,,,,,,,,,,,,,,,, -11700,0.33199993,2.2473497,,,,,,,,,,,,,,,,, -11800,0.2869732,2.059303,,,,,,,,,,,,,,,,, -11900,0.48096007,2.1249025,,,,,,,,,,,,,,,,, -12000,0.2917539,2.1560762,,,,,,,,,,,,,,,,, -12009,,,0.5922099947929382,2.122587919235229,28.130188839500367,0.6139291524887085,1.9412788152694704,25.084190063223115,3000.0,0.6177793145179749,1.9158536195755005,23.448882076309648,3003.0,4228.665568351746,7597.773768186569,4228.665568351746,3368.605906009674,0.1242556571960449,0.0 -12100,0.60186136,2.0153067,,,,,,,,,,,,,,,,, -12200,0.6012896,2.0909047,,,,,,,,,,,,,,,,, -12300,0.5381772,2.118443,,,,,,,,,,,,,,,,, -12400,0.46032715,2.0664833,,,,,,,,,,,,,,,,, -12500,0.41973925,2.1006343,,,,,,,,,,,,,,,,, -12600,0.43591446,2.079689,,,,,,,,,,,,,,,,, -12700,0.3380724,2.096079,,,,,,,,,,,,,,,,, -12800,0.49802214,2.028244,,,,,,,,,,,,,,,,, -12900,0.29054558,2.1108987,,,,,,,,,,,,,,,,, -13000,0.4176313,2.1576772,,,,,,,,,,,,,,,,, -13100,0.4627378,2.1248033,,,,,,,,,,,,,,,,, -13200,0.2657371,2.161139,,,,,,,,,,,,,,,,, -13300,0.29272297,2.0453296,,,,,,,,,,,,,,,,, -13400,0.8839353,2.1356335,,,,,,,,,,,,,,,,, -13500,0.38053188,2.0702896,,,,,,,,,,,,,,,,, -13600,0.86476827,2.0479724,,,,,,,,,,,,,,,,, -13700,0.38458067,2.092019,,,,,,,,,,,,,,,,, -13800,0.71586,2.1188424,,,,,,,,,,,,,,,,, -13900,0.32172444,2.1277769,,,,,,,,,,,,,,,,, -14000,0.29207024,2.04202,,,,,,,,,,,,,,,,, -14100,0.5345838,2.0553398,,,,,,,,,,,,,,,,, -14200,0.45643455,2.0979125,,,,,,,,,,,,,,,,, -14300,0.38914496,1.9800622,,,,,,,,,,,,,,,,, -14400,0.2898615,2.0687943,,,,,,,,,,,,,,,,, -14411,,,0.5989102721214294,2.0769054889678955,27.91491412501564,0.61870276927948,1.9237356185913088,25.03533095862288,3000.0,0.6252977848052979,1.8740822076797483,23.781386907356108,3003.0,5068.869497776032,9261.455813646317,5068.869497776032,4191.980131864548,0.1532905101776123,0.0 -14500,0.48833495,2.0469153,,,,,,,,,,,,,,,,, -14600,0.58552825,2.0830107,,,,,,,,,,,,,,,,, -14700,0.3841571,2.0334096,,,,,,,,,,,,,,,,, -14800,0.2900622,2.0819426,,,,,,,,,,,,,,,,, -14900,0.36698878,2.1048539,,,,,,,,,,,,,,,,, -15000,0.41211164,1.9908372,,,,,,,,,,,,,,,,, -15100,0.4350433,2.01831,,,,,,,,,,,,,,,,, -15200,0.51618016,2.19288,,,,,,,,,,,,,,,,, -15300,0.49427575,2.1167462,,,,,,,,,,,,,,,,, -15400,0.4374482,1.9699562,,,,,,,,,,,,,,,,, -15500,0.6225112,2.1241271,,,,,,,,,,,,,,,,, -15600,0.3459636,2.116844,,,,,,,,,,,,,,,,, -15700,0.4532516,2.0277953,,,,,,,,,,,,,,,,, -15800,0.60063833,2.0984483,,,,,,,,,,,,,,,,, -15900,0.25980395,2.155886,,,,,,,,,,,,,,,,, -16000,0.25450012,2.089826,,,,,,,,,,,,,,,,, -16100,0.52632576,2.1591594,,,,,,,,,,,,,,,,, -16200,0.41894802,2.0078585,,,,,,,,,,,,,,,,, -16300,0.45639977,2.04665,,,,,,,,,,,,,,,,, -16400,0.2934332,2.062143,,,,,,,,,,,,,,,,, -16500,0.6752283,2.0658798,,,,,,,,,,,,,,,,, -16600,0.29015887,2.036742,,,,,,,,,,,,,,,,, -16700,0.31445026,1.9929384,,,,,,,,,,,,,,,,, -16800,0.34188715,2.1149983,,,,,,,,,,,,,,,,, -16811,,,0.5971294641494751,2.082860946655273,28.010472281108505,0.6194343566894531,1.906266689300537,25.16562009613152,3000.0,0.6220208406448364,1.8677488565444944,23.640765822774625,3003.0,5908.888803958893,10648.903832674026,5908.888803958893,4739.305437326431,0.1807782649993896,0.0 -16900,0.80194414,2.040208,,,,,,,,,,,,,,,,, -17000,0.34039018,2.1276069,,,,,,,,,,,,,,,,, -17100,0.57246566,2.1064165,,,,,,,,,,,,,,,,, -17200,0.29532245,2.0091696,,,,,,,,,,,,,,,,, -17300,0.33803716,2.101095,,,,,,,,,,,,,,,,, -17400,0.65270644,2.0200388,,,,,,,,,,,,,,,,, -17500,0.7869984,2.0557547,,,,,,,,,,,,,,,,, -17600,0.43108585,2.112697,,,,,,,,,,,,,,,,, -17700,0.2686219,1.9864299,,,,,,,,,,,,,,,,, -17800,0.29374182,2.0737817,,,,,,,,,,,,,,,,, -17900,0.37158728,2.0668192,,,,,,,,,,,,,,,,, -18000,0.35252368,2.0723,,,,,,,,,,,,,,,,, -18100,0.3217755,2.0966165,,,,,,,,,,,,,,,,, -18200,0.28175613,2.0577173,,,,,,,,,,,,,,,,, -18300,0.51759595,2.0386493,,,,,,,,,,,,,,,,, -18400,0.3411342,2.031595,,,,,,,,,,,,,,,,, -18500,0.48267236,2.0868907,,,,,,,,,,,,,,,,, -18600,0.8833571,2.0843973,,,,,,,,,,,,,,,,, -18700,0.3428137,2.0428133,,,,,,,,,,,,,,,,, -18800,0.6194317,2.0669422,,,,,,,,,,,,,,,,, -18900,0.588263,1.980073,,,,,,,,,,,,,,,,, -19000,0.51525533,2.075227,,,,,,,,,,,,,,,,, -19100,0.8431568,2.0831258,,,,,,,,,,,,,,,,, -19200,0.25548935,2.007339,,,,,,,,,,,,,,,,, -19212,,,0.6081152558326721,1.988739967346192,29.55355088171276,0.6222489476203918,1.901801586151123,25.59932936158585,3000.0,0.6279588937759399,1.8503223657608032,24.208979117206603,3003.0,6748.833927154541,12011.384460687636,6748.833927154541,5261.734477043152,0.2103769779205322,0.0 -19300,0.32570827,2.022308,,,,,,,,,,,,,,,,, -19400,0.3852347,2.1252956,,,,,,,,,,,,,,,,, -19500,0.52764446,2.1265533,,,,,,,,,,,,,,,,, -19600,0.24666643,2.057468,,,,,,,,,,,,,,,,, -19700,0.33127826,2.1719296,,,,,,,,,,,,,,,,, -19800,0.39718157,2.039321,,,,,,,,,,,,,,,,, -19900,0.38887307,2.0709994,,,,,,,,,,,,,,,,, -20000,0.30471936,2.131378,,,,,,,,,,,,,,,,, -20100,0.27715024,1.992918,,,,,,,,,,,,,,,,, -20200,0.34542838,2.0798473,,,,,,,,,,,,,,,,, -20300,0.46787217,2.0968218,,,,,,,,,,,,,,,,, -20400,0.3003503,2.1651824,,,,,,,,,,,,,,,,, -20500,0.34828657,2.040154,,,,,,,,,,,,,,,,, -20600,0.34758192,2.0792298,,,,,,,,,,,,,,,,, -20700,0.4869142,2.0148559,,,,,,,,,,,,,,,,, -20800,0.6048626,2.0563252,,,,,,,,,,,,,,,,, -20900,0.27312216,1.9718637,,,,,,,,,,,,,,,,, -21000,0.47426322,2.0445468,,,,,,,,,,,,,,,,, -21100,0.48150182,2.132621,,,,,,,,,,,,,,,,, -21200,0.740979,2.1357398,,,,,,,,,,,,,,,,, -21300,0.45288724,2.0929008,,,,,,,,,,,,,,,,, -21400,0.28309438,2.1061704,,,,,,,,,,,,,,,,, -21500,0.2626512,2.1243508,,,,,,,,,,,,,,,,, -21600,0.42975205,1.9886309,,,,,,,,,,,,,,,,, -21613,,,0.6023226380348206,2.051201105117798,28.267704243990025,0.6232532858848572,1.8913344144821167,25.35422343622899,3000.0,0.6319214701652527,1.8304063081741333,25.061395584676017,3003.0,7588.848192691803,13354.67732167244,7588.848192691803,5764.907237768173,0.2401816844940185,0.0 -21700,0.28783998,1.983058,,,,,,,,,,,,,,,,, -21800,0.62333673,2.0310404,,,,,,,,,,,,,,,,, -21900,0.61149794,2.1288643,,,,,,,,,,,,,,,,, -22000,0.5274383,2.0305228,,,,,,,,,,,,,,,,, -22100,0.3052809,2.1611931,,,,,,,,,,,,,,,,, -22200,0.29846823,2.0978124,,,,,,,,,,,,,,,,, -22300,0.44251785,1.9924874,,,,,,,,,,,,,,,,, -22400,0.7771456,2.0206811,,,,,,,,,,,,,,,,, -22500,0.7088218,2.0273948,,,,,,,,,,,,,,,,, -22600,0.46035865,2.023674,,,,,,,,,,,,,,,,, -22700,0.39214832,2.1510818,,,,,,,,,,,,,,,,, -22800,0.6429327,2.0221856,,,,,,,,,,,,,,,,, -22900,0.33439234,2.0459466,,,,,,,,,,,,,,,,, -23000,0.54105467,2.031531,,,,,,,,,,,,,,,,, -23100,0.5411619,2.054462,,,,,,,,,,,,,,,,, -23200,0.48234046,2.062212,,,,,,,,,,,,,,,,, -23300,0.3112842,2.09256,,,,,,,,,,,,,,,,, -23400,0.477327,2.1367505,,,,,,,,,,,,,,,,, -23500,0.4753483,1.9815363,,,,,,,,,,,,,,,,, -23600,0.37704015,2.1356514,,,,,,,,,,,,,,,,, -23700,0.32700214,2.0900962,,,,,,,,,,,,,,,,, -23800,0.3023886,2.0306108,,,,,,,,,,,,,,,,, -23900,0.3411729,1.9901057,,,,,,,,,,,,,,,,, -24000,0.42131704,1.9965558,,,,,,,,,,,,,,,,, -24014,,,0.601838231086731,2.06695556640625,28.35724378331637,0.622211754322052,1.888195872306824,25.447011922234104,3000.0,0.6272500157356262,1.8352961540222168,24.59423009944916,3003.0,8428.94619846344,14670.757593154907,8428.94619846344,6240.782923460007,0.2702040672302246,0.0 -24100,0.49056956,2.0250456,,,,,,,,,,,,,,,,, -24200,0.3330151,2.065101,,,,,,,,,,,,,,,,, -24300,0.39755479,2.0353107,,,,,,,,,,,,,,,,, -24400,0.24745649,1.9906974,,,,,,,,,,,,,,,,, -24500,0.4360429,1.9889817,,,,,,,,,,,,,,,,, -24600,0.3487525,1.962061,,,,,,,,,,,,,,,,, -24700,0.46044075,2.0013735,,,,,,,,,,,,,,,,, -24800,0.4475845,2.0174127,,,,,,,,,,,,,,,,, -24900,0.32873565,2.0667973,,,,,,,,,,,,,,,,, -25000,0.5525304,2.0599244,,,,,,,,,,,,,,,,, -25100,0.3700336,2.0189738,,,,,,,,,,,,,,,,, -25200,0.43748719,1.9768612,,,,,,,,,,,,,,,,, -25300,0.27558634,2.1667008,,,,,,,,,,,,,,,,, -25400,0.64300984,2.0631814,,,,,,,,,,,,,,,,, -25500,0.4007564,2.037705,,,,,,,,,,,,,,,,, -25600,0.6403209,2.032174,,,,,,,,,,,,,,,,, -25700,0.44853306,2.0384,,,,,,,,,,,,,,,,, -25800,0.4562905,2.1178172,,,,,,,,,,,,,,,,, -25900,0.45313025,2.0465603,,,,,,,,,,,,,,,,, -26000,0.5694518,1.9838761,,,,,,,,,,,,,,,,, -26100,0.45666572,2.0360029,,,,,,,,,,,,,,,,, -26200,0.39236993,2.0426517,,,,,,,,,,,,,,,,, -26300,0.444916,2.0288885,,,,,,,,,,,,,,,,, -26400,0.35308748,2.048941,,,,,,,,,,,,,,,,, -26415,,,0.6068945527076721,2.0167477130889893,29.215509208704106,0.6261794567108154,1.86761474609375,26.067529651612723,3000.0,0.6306548118591309,1.823778748512268,24.450720016161004,3003.0,9269.041829109192,15998.954328775406,9269.041829109192,6728.776242017746,0.2997102737426758,0.0 -26500,0.42918894,1.9838014,,,,,,,,,,,,,,,,, -26600,0.36928102,1.9776611,,,,,,,,,,,,,,,,, -26700,0.35834372,2.060543,,,,,,,,,,,,,,,,, -26800,0.25749147,2.0793366,,,,,,,,,,,,,,,,, -26900,0.41334352,2.0336933,,,,,,,,,,,,,,,,, -27000,0.30221105,2.013345,,,,,,,,,,,,,,,,, -27100,0.45082164,2.0026152,,,,,,,,,,,,,,,,, -27200,0.3916912,1.9790019,,,,,,,,,,,,,,,,, -27300,0.4106076,1.9400007,,,,,,,,,,,,,,,,, -27400,0.29444298,2.0008786,,,,,,,,,,,,,,,,, -27500,0.2586188,2.016364,,,,,,,,,,,,,,,,, -27600,0.28598708,2.0612571,,,,,,,,,,,,,,,,, -27700,0.42710105,2.010861,,,,,,,,,,,,,,,,, -27800,0.36982456,2.0556893,,,,,,,,,,,,,,,,, -27900,0.3001121,1.9729723,,,,,,,,,,,,,,,,, -28000,0.64761746,2.1864579,,,,,,,,,,,,,,,,, -28100,0.4727173,1.9443699,,,,,,,,,,,,,,,,, -28200,0.5539298,1.9996455,,,,,,,,,,,,,,,,, -28300,0.4228601,2.0200577,,,,,,,,,,,,,,,,, -28400,0.2747987,1.9812535,,,,,,,,,,,,,,,,, -28500,0.4083005,1.9580218,,,,,,,,,,,,,,,,, -28600,0.36004585,2.1295094,,,,,,,,,,,,,,,,, -28700,0.6646688,2.0340145,,,,,,,,,,,,,,,,, -28800,0.30795702,2.0634274,,,,,,,,,,,,,,,,, -28816,,,0.6076061129570007,2.0170085430145264,28.61913311876729,0.6236252188682556,1.8711217641830444,25.54049461004584,3000.0,0.6310731768608093,1.814279556274414,24.502543676121828,3003.0,10109.0739672184,17472.045060634613,10109.0739672184,7361.727405786514,0.3307063579559326,0.0 -28900,0.4889186,2.118036,,,,,,,,,,,,,,,,, -29000,0.26564646,1.9661002,,,,,,,,,,,,,,,,, -29100,0.3272672,1.9913893,,,,,,,,,,,,,,,,, -29200,0.27396986,2.099998,,,,,,,,,,,,,,,,, -29300,0.66106015,2.0364377,,,,,,,,,,,,,,,,, -29400,0.32190102,2.1145587,,,,,,,,,,,,,,,,, -29500,0.29628453,1.9278473,,,,,,,,,,,,,,,,, -29600,0.56592214,1.9941736,,,,,,,,,,,,,,,,, -29700,0.6796316,2.0264564,,,,,,,,,,,,,,,,, -29800,0.23903501,2.024767,,,,,,,,,,,,,,,,, -29900,0.27055603,1.9860407,,,,,,,,,,,,,,,,, -30000,0.6640029,2.02843,,,,,,,,,,,,,,,,, -30100,0.26683068,2.0499012,,,,,,,,,,,,,,,,, -30200,0.7042154,1.9778868,,,,,,,,,,,,,,,,, -30300,0.25716895,1.9816959,,,,,,,,,,,,,,,,, -30400,0.22252835,2.05737,,,,,,,,,,,,,,,,, -30500,0.44940683,2.0434563,,,,,,,,,,,,,,,,, -30600,0.37330654,1.950097,,,,,,,,,,,,,,,,, -30700,0.32132414,2.07328,,,,,,,,,,,,,,,,, -30800,0.6676084,1.9642433,,,,,,,,,,,,,,,,, -30900,0.39794606,1.9687072,,,,,,,,,,,,,,,,, -31000,0.44214177,1.9990567,,,,,,,,,,,,,,,,, -31100,0.26147583,2.0003684,,,,,,,,,,,,,,,,, -31200,0.60963583,1.9592059,,,,,,,,,,,,,,,,, -31218,,,0.6038236021995544,2.032376766204834,28.42866717212713,0.6241955757141113,1.8636945486068728,25.288898614166587,3000.0,0.630387544631958,1.82143235206604,24.071335933790778,3003.0,10949.304612398148,18839.73539876938,10949.304612398148,7889.079682350159,0.3619084358215332,0.0 -31300,0.37287924,2.0532002,,,,,,,,,,,,,,,,, -31400,0.32454157,2.0354364,,,,,,,,,,,,,,,,, -31500,0.428875,1.9631792,,,,,,,,,,,,,,,,, -31600,0.25842538,1.9624197,,,,,,,,,,,,,,,,, -31700,0.32814574,2.0707,,,,,,,,,,,,,,,,, -31800,0.30272782,2.0428615,,,,,,,,,,,,,,,,, -31900,0.42122695,2.0029323,,,,,,,,,,,,,,,,, -32000,0.3596898,2.0096452,,,,,,,,,,,,,,,,, -32100,0.6573592,2.0259995,,,,,,,,,,,,,,,,, -32200,0.4887073,2.0177732,,,,,,,,,,,,,,,,, -32300,0.3790971,1.9241077,,,,,,,,,,,,,,,,, -32400,0.29687482,2.140698,,,,,,,,,,,,,,,,, -32500,0.7667806,1.9865352,,,,,,,,,,,,,,,,, -32600,0.33668324,1.962053,,,,,,,,,,,,,,,,, -32700,0.37500918,1.9851097,,,,,,,,,,,,,,,,, -32800,0.27327552,2.112956,,,,,,,,,,,,,,,,, -32900,0.33622387,2.00343,,,,,,,,,,,,,,,,, -33000,0.39372763,2.0224113,,,,,,,,,,,,,,,,, -33100,0.33037823,2.0423436,,,,,,,,,,,,,,,,, -33200,0.23074958,2.1053631,,,,,,,,,,,,,,,,, -33300,0.28180686,1.9935602,,,,,,,,,,,,,,,,, -33400,0.6717878,1.9958608,,,,,,,,,,,,,,,,, -33500,0.3116565,2.1088517,,,,,,,,,,,,,,,,, -33600,0.3525532,2.0003083,,,,,,,,,,,,,,,,, -33620,,,0.609279990196228,1.99933660030365,28.38434953125899,0.6294032335281372,1.85016667842865,25.79646487757145,3000.0,0.6363604664802551,1.7914565801620483,24.920774993304907,3003.0,11789.44777727127,20178.132102012634,11789.44777727127,8387.22826910019,0.3919835090637207,0.0 -33700,0.28858846,1.9684789,,,,,,,,,,,,,,,,, -33800,0.6609573,2.00337,,,,,,,,,,,,,,,,, -33900,0.55906993,2.0522492,,,,,,,,,,,,,,,,, -34000,0.25986078,2.0090292,,,,,,,,,,,,,,,,, -34100,0.31993508,1.9847974,,,,,,,,,,,,,,,,, -34200,0.46316344,2.0995457,,,,,,,,,,,,,,,,, -34300,0.49076724,2.0744987,,,,,,,,,,,,,,,,, -34400,0.4670001,1.9435897,,,,,,,,,,,,,,,,, -34500,0.26305193,2.055708,,,,,,,,,,,,,,,,, -34600,0.37461323,1.9796684,,,,,,,,,,,,,,,,, -34700,0.3964216,1.9665645,,,,,,,,,,,,,,,,, -34800,0.30960646,2.0381913,,,,,,,,,,,,,,,,, -34900,0.28187385,2.0891454,,,,,,,,,,,,,,,,, -35000,0.4982625,2.0089211,,,,,,,,,,,,,,,,, -35100,0.30505174,2.0468411,,,,,,,,,,,,,,,,, -35200,0.61552715,2.0357423,,,,,,,,,,,,,,,,, -35300,0.3404277,1.9742677,,,,,,,,,,,,,,,,, -35400,0.28507504,1.9849591,,,,,,,,,,,,,,,,, -35500,0.4129648,2.0805402,,,,,,,,,,,,,,,,, -35600,0.55216914,2.006622,,,,,,,,,,,,,,,,, -35700,0.2781737,2.0384362,,,,,,,,,,,,,,,,, -35800,0.38754562,1.9241223,,,,,,,,,,,,,,,,, -35900,0.45546958,1.9937102,,,,,,,,,,,,,,,,, -36000,0.3360178,1.9423959,,,,,,,,,,,,,,,,, -36022,,,0.6051180958747864,2.022434711456299,28.50333226972837,0.6277913451194763,1.8456335067749023,25.85400422027908,3000.0,0.6338969469070435,1.800976037979126,24.793330930167617,3003.0,12629.596199512482,21590.39995884896,12629.596199512482,8959.241186380386,0.4224326610565185,0.0 -36100,0.3028476,2.068459,,,,,,,,,,,,,,,,, -36200,0.3089059,1.995585,,,,,,,,,,,,,,,,, -36300,0.37869555,1.982545,,,,,,,,,,,,,,,,, -36400,0.3301818,1.9849756,,,,,,,,,,,,,,,,, -36500,0.43477833,2.0491555,,,,,,,,,,,,,,,,, -36600,0.27602255,1.9662695,,,,,,,,,,,,,,,,, -36700,0.6579313,1.9922916,,,,,,,,,,,,,,,,, -36800,0.67234874,1.990557,,,,,,,,,,,,,,,,, -36900,0.802835,2.0546517,,,,,,,,,,,,,,,,, -37000,0.28793862,2.0388987,,,,,,,,,,,,,,,,, -37100,0.3550197,1.8979511,,,,,,,,,,,,,,,,, -37200,0.3480587,2.0002983,,,,,,,,,,,,,,,,, -37300,0.33095482,1.9979191,,,,,,,,,,,,,,,,, -37400,0.26004818,2.027279,,,,,,,,,,,,,,,,, -37500,0.38978663,2.051245,,,,,,,,,,,,,,,,, -37600,0.54747385,1.9919094,,,,,,,,,,,,,,,,, -37700,0.42363915,1.9846634,,,,,,,,,,,,,,,,, -37800,0.31695315,2.0357473,,,,,,,,,,,,,,,,, -37900,0.8199223,2.072045,,,,,,,,,,,,,,,,, -38000,0.41124728,2.0247848,,,,,,,,,,,,,,,,, -38100,0.3035282,2.0123873,,,,,,,,,,,,,,,,, -38200,0.35359937,1.9756762,,,,,,,,,,,,,,,,, -38300,0.57438874,2.0276034,,,,,,,,,,,,,,,,, -38400,0.46002567,1.9735068,,,,,,,,,,,,,,,,, -38423,,,0.6163857579231262,1.9339035749435425,29.14323735552401,0.6316102743148804,1.8279942274093628,26.215031467429903,3000.0,0.6375108957290649,1.7810250520706177,25.07801187053838,3003.0,13469.625989198685,22926.05318045616,13469.625989198685,9454.756899356842,0.452505350112915,0.0 -38500,0.47936386,1.986563,,,,,,,,,,,,,,,,, -38600,0.4471497,2.066067,,,,,,,,,,,,,,,,, -38700,0.32611012,1.9137337,,,,,,,,,,,,,,,,, -38800,0.5404707,1.9716843,,,,,,,,,,,,,,,,, -38900,0.3954352,1.9574587,,,,,,,,,,,,,,,,, -39000,0.2582295,2.02992,,,,,,,,,,,,,,,,, -39100,0.28904712,1.9484484,,,,,,,,,,,,,,,,, -39200,0.2845193,2.078942,,,,,,,,,,,,,,,,, -39300,0.49939615,1.9963522,,,,,,,,,,,,,,,,, -39400,0.3572837,2.033409,,,,,,,,,,,,,,,,, -39500,0.41275212,1.9685458,,,,,,,,,,,,,,,,, -39600,0.6281918,1.981816,,,,,,,,,,,,,,,,, -39700,0.37664592,1.9766771,,,,,,,,,,,,,,,,, -39800,0.5316004,2.0298183,,,,,,,,,,,,,,,,, -39900,0.5469354,1.9736778,,,,,,,,,,,,,,,,, -40000,0.41118914,2.0371456,,,,,,,,,,,,,,,,, -40100,0.3335677,1.9406329,,,,,,,,,,,,,,,,, -40200,0.6459389,1.9452418,,,,,,,,,,,,,,,,, -40300,0.5404739,1.9941736,,,,,,,,,,,,,,,,, -40400,0.27156025,1.9763628,,,,,,,,,,,,,,,,, -40500,0.6136922,1.9233146,,,,,,,,,,,,,,,,, -40600,0.49447605,2.0149636,,,,,,,,,,,,,,,,, -40700,0.8512884,1.9845271,,,,,,,,,,,,,,,,, -40800,0.3580114,1.8958627,,,,,,,,,,,,,,,,, -40824,,,0.6103520393371582,1.9843041896820068,29.142409587723943,0.6277045607566833,1.833596467971801,25.61308660868012,3000.0,0.6380454301834106,1.7811000347137451,24.74312727864808,3003.0,14309.583218336104,24290.329845905304,14309.583218336104,9978.96815109253,0.4831573963165283,0.0 -40900,0.56940556,1.9993527,,,,,,,,,,,,,,,,, -41000,0.25078717,1.9809874,,,,,,,,,,,,,,,,, -41100,0.76418376,2.008616,,,,,,,,,,,,,,,,, -41200,0.4584978,1.9971532,,,,,,,,,,,,,,,,, -41300,0.62115234,1.9474212,,,,,,,,,,,,,,,,, -41400,0.40938357,1.8670639,,,,,,,,,,,,,,,,, -41500,0.31213284,1.9463845,,,,,,,,,,,,,,,,, -41600,0.45815584,2.0256073,,,,,,,,,,,,,,,,, -41700,0.42215976,2.0042098,,,,,,,,,,,,,,,,, -41800,0.30919057,1.9448487,,,,,,,,,,,,,,,,, -41900,0.46676737,1.9947186,,,,,,,,,,,,,,,,, -42000,0.32574055,1.966653,,,,,,,,,,,,,,,,, -42100,0.28740644,1.9468279,,,,,,,,,,,,,,,,, -42200,0.6255049,2.1298785,,,,,,,,,,,,,,,,, -42300,0.38068452,2.0176682,,,,,,,,,,,,,,,,, -42400,0.35088336,1.9642954,,,,,,,,,,,,,,,,, -42500,0.35917816,1.9991947,,,,,,,,,,,,,,,,, -42600,0.77085453,1.9464136,,,,,,,,,,,,,,,,, -42700,0.8151125,2.039836,,,,,,,,,,,,,,,,, -42800,0.668221,2.0125659,,,,,,,,,,,,,,,,, -42900,0.30555415,1.9379636,,,,,,,,,,,,,,,,, -43000,0.30738762,1.8816733,,,,,,,,,,,,,,,,, -43100,0.27820972,1.9865398,,,,,,,,,,,,,,,,, -43200,0.50528294,1.9447801,,,,,,,,,,,,,,,,, -43226,,,0.6160485148429871,1.9517220258712769,29.20263873077558,0.6300479769706726,1.814311981201172,22.43720782495505,3000.0,0.6379989981651306,1.769517183303833,24.556187006987688,3003.0,15149.78208065033,25965.07247161865,15149.78208065033,10813.402871608734,0.5145649909973145,0.0 -43300,0.41226864,1.93686,,,,,,,,,,,,,,,,, -43400,0.36228198,1.9724113,,,,,,,,,,,,,,,,, -43500,0.55928916,1.903503,,,,,,,,,,,,,,,,, -43600,0.31765938,2.0001214,,,,,,,,,,,,,,,,, -43700,0.5206715,2.029423,,,,,,,,,,,,,,,,, -43800,0.6485302,1.9388243,,,,,,,,,,,,,,,,, -43900,0.40286997,1.9443036,,,,,,,,,,,,,,,,, -44000,0.41367644,1.9043194,,,,,,,,,,,,,,,,, -44100,0.46610215,1.9931433,,,,,,,,,,,,,,,,, -44200,0.29495382,1.9569834,,,,,,,,,,,,,,,,, -44300,0.3886598,1.8606106,,,,,,,,,,,,,,,,, -44400,0.7320322,1.905761,,,,,,,,,,,,,,,,, -44500,0.39451888,1.9915574,,,,,,,,,,,,,,,,, -44600,0.38987613,1.9709345,,,,,,,,,,,,,,,,, -44700,0.35385883,1.9479518,,,,,,,,,,,,,,,,, -44800,0.2792071,2.0282502,,,,,,,,,,,,,,,,, -44900,0.36266872,1.9122244,,,,,,,,,,,,,,,,, -45000,0.32530263,1.8794398,,,,,,,,,,,,,,,,, -45100,0.516697,1.9139566,,,,,,,,,,,,,,,,, -45200,0.3103022,1.9558996,,,,,,,,,,,,,,,,, -45300,0.4904587,1.9732559,,,,,,,,,,,,,,,,, -45400,0.45502827,1.9574753,,,,,,,,,,,,,,,,, -45500,0.5384887,2.0428123,,,,,,,,,,,,,,,,, -45600,0.2962484,1.94246,,,,,,,,,,,,,,,,, -45628,,,0.6102014183998108,1.9810327291488647,29.17744358049914,0.6297379732131958,1.8217298984527588,25.96525554826419,3000.0,0.6361513137817383,1.7746611833572388,24.85202517879502,3003.0,15989.976817846298,27441.79203939438,15989.976817846298,11449.814586162567,0.5513138771057129,0.0 -45700,0.40020338,2.0643659,,,,,,,,,,,,,,,,, -45800,0.44669664,1.956012,,,,,,,,,,,,,,,,, -45900,0.4758813,1.9176103,,,,,,,,,,,,,,,,, -46000,0.3309236,1.9297191,,,,,,,,,,,,,,,,, -46100,0.48618126,2.0113573,,,,,,,,,,,,,,,,, -46200,0.4549017,1.9248462,,,,,,,,,,,,,,,,, -46300,0.49131623,1.9637287,,,,,,,,,,,,,,,,, -46400,0.4339269,2.0518508,,,,,,,,,,,,,,,,, -46500,0.29214877,2.0407321,,,,,,,,,,,,,,,,, -46600,0.38785166,2.0348558,,,,,,,,,,,,,,,,, -46700,0.27860445,2.013859,,,,,,,,,,,,,,,,, -46800,0.31021526,1.9833874,,,,,,,,,,,,,,,,, -46900,0.34793007,1.9560018,,,,,,,,,,,,,,,,, -47000,0.4705133,1.9428372,,,,,,,,,,,,,,,,, -47100,0.6090247,1.9201312,,,,,,,,,,,,,,,,, -47200,0.5544916,1.9887457,,,,,,,,,,,,,,,,, -47300,0.35899374,1.9158881,,,,,,,,,,,,,,,,, -47400,0.35391217,1.9806515,,,,,,,,,,,,,,,,, -47500,0.47397044,1.9264314,,,,,,,,,,,,,,,,, -47600,0.24210154,1.9559784,,,,,,,,,,,,,,,,, -47700,0.32479656,2.0184891,,,,,,,,,,,,,,,,, -47800,0.27395055,1.9500116,,,,,,,,,,,,,,,,, -47900,0.62499064,1.9248874,,,,,,,,,,,,,,,,, -48000,0.3744512,1.9968355,,,,,,,,,,,,,,,,, -48030,,,0.6131874322891235,1.9688488245010376,29.10088251542185,0.6333833336830139,1.80632483959198,26.4319809165706,3000.0,0.6418104767799377,1.7534323930740356,25.610204331052604,3003.0,16830.157977104187,28879.521926641464,16830.157977104187,12047.255705833437,0.5844166278839111,0.0 -48100,0.49591893,1.9625545,,,,,,,,,,,,,,,,, -48200,0.28419095,2.0229259,,,,,,,,,,,,,,,,, -48300,0.36123067,1.9628166,,,,,,,,,,,,,,,,, -48400,0.78293204,1.909979,,,,,,,,,,,,,,,,, -48500,0.32030684,1.9228694,,,,,,,,,,,,,,,,, -48600,0.3010352,2.0461004,,,,,,,,,,,,,,,,, -48700,0.26597828,1.9853334,,,,,,,,,,,,,,,,, -48800,0.38449544,1.885788,,,,,,,,,,,,,,,,, -48900,0.39614677,1.8944092,,,,,,,,,,,,,,,,, -49000,0.3545913,1.9812542,,,,,,,,,,,,,,,,, -49100,0.24885501,1.928084,,,,,,,,,,,,,,,,, -49200,0.35451564,1.8789364,,,,,,,,,,,,,,,,, -49300,0.25763112,1.8784612,,,,,,,,,,,,,,,,, -49400,0.3106102,1.928648,,,,,,,,,,,,,,,,, -49500,0.31769663,1.9986322,,,,,,,,,,,,,,,,, -49600,0.49675167,1.9411402,,,,,,,,,,,,,,,,, -49700,0.38257155,2.023859,,,,,,,,,,,,,,,,, -49800,0.46924964,1.8951786,,,,,,,,,,,,,,,,, -49900,0.5429772,2.0204668,,,,,,,,,,,,,,,,, -50000,0.39255995,1.9455783,,,,,,,,,,,,,,,,, -50100,0.5127731,2.0571265,,,,,,,,,,,,,,,,, -50200,0.3978618,1.9620653,,,,,,,,,,,,,,,,, -50300,0.3105312,1.8808937,,,,,,,,,,,,,,,,, -50400,0.32448313,1.9939137,,,,,,,,,,,,,,,,, -50431,,,0.6287590861320496,1.834293246269226,30.29873589696989,0.6369046568870544,1.7885897159576416,26.45663032630946,3000.0,0.6455406546592712,1.725481033325195,25.79031645404289,3003.0,17670.106913089752,30213.769419670105,17670.106913089752,12541.4399433136,0.6211183071136475,0.0 -50500,0.26980507,1.9857856,,,,,,,,,,,,,,,,, -50600,0.30715653,1.907972,,,,,,,,,,,,,,,,, -50700,0.30909684,1.9278086,,,,,,,,,,,,,,,,, -50800,0.35537952,1.9500386,,,,,,,,,,,,,,,,, -50900,0.2805746,1.9845022,,,,,,,,,,,,,,,,, -51000,0.41699016,1.9243817,,,,,,,,,,,,,,,,, -51100,0.26375413,1.9236435,,,,,,,,,,,,,,,,, -51200,0.37416217,1.999529,,,,,,,,,,,,,,,,, -51300,0.4839078,1.9665765,,,,,,,,,,,,,,,,, -51400,0.26850617,1.8907757,,,,,,,,,,,,,,,,, -51500,0.29015726,1.9313257,,,,,,,,,,,,,,,,, -51600,0.45452222,2.037998,,,,,,,,,,,,,,,,, -51700,0.33671805,1.9253931,,,,,,,,,,,,,,,,, -51800,0.27713978,2.0918891,,,,,,,,,,,,,,,,, -51900,0.348625,1.871754,,,,,,,,,,,,,,,,, -52000,0.33328694,1.9044679,,,,,,,,,,,,,,,,, -52100,0.3775946,2.03164,,,,,,,,,,,,,,,,, -52200,0.30153382,1.9312239,,,,,,,,,,,,,,,,, -52300,0.3960283,1.9204836,,,,,,,,,,,,,,,,, -52400,0.28152397,1.9959136,,,,,,,,,,,,,,,,, -52500,0.3025933,2.0169022,,,,,,,,,,,,,,,,, -52600,0.32139528,1.9800031,,,,,,,,,,,,,,,,, -52700,0.37329543,1.9911366,,,,,,,,,,,,,,,,, -52800,0.45549914,2.0026455,,,,,,,,,,,,,,,,, -52832,,,0.6180623173713684,1.9361435174942017,29.478945279105627,0.6362723112106323,1.7838091850280762,26.286487484829664,3000.0,0.6466097235679626,1.7135682106018066,25.517370648362107,3003.0,18510.01056957245,31582.84491109848,18510.01056957245,13070.50362753868,0.6545243263244629,0.0 -52900,0.7034001,2.0311503,,,,,,,,,,,,,,,,, -53000,0.5650996,2.001087,,,,,,,,,,,,,,,,, -53100,0.29576975,1.974165,,,,,,,,,,,,,,,,, -53200,0.40918666,1.9198048,,,,,,,,,,,,,,,,, -53300,0.42886075,1.884842,,,,,,,,,,,,,,,,, -53400,0.2449814,1.9530158,,,,,,,,,,,,,,,,, -53500,0.5047098,1.9202888,,,,,,,,,,,,,,,,, -53600,0.4231844,1.9684579,,,,,,,,,,,,,,,,, -53700,0.3736451,1.957103,,,,,,,,,,,,,,,,, -53800,0.3979953,1.9912167,,,,,,,,,,,,,,,,, -53900,0.29696125,1.8613605,,,,,,,,,,,,,,,,, -54000,0.51496977,1.8952705,,,,,,,,,,,,,,,,, -54100,0.28133044,1.8412335,,,,,,,,,,,,,,,,, -54200,0.31453782,1.9473201,,,,,,,,,,,,,,,,, -54300,0.39626518,1.8647012,,,,,,,,,,,,,,,,, -54400,0.3011108,1.8683816,,,,,,,,,,,,,,,,, -54500,0.3351337,1.9285252,,,,,,,,,,,,,,,,, -54600,0.40240705,1.9765805,,,,,,,,,,,,,,,,, -54700,0.4735485,1.9295267,,,,,,,,,,,,,,,,, -54800,0.2788375,1.8618789,,,,,,,,,,,,,,,,, -54900,0.40311164,1.9461335,,,,,,,,,,,,,,,,, -55000,0.37752572,1.8889054,,,,,,,,,,,,,,,,, -55100,0.4198434,1.9955058,,,,,,,,,,,,,,,,, -55200,0.35375854,1.8605875,,,,,,,,,,,,,,,,, -55233,,,0.6131559014320374,1.9650341272354128,29.613710414488946,0.6389381289482117,1.767750859260559,26.565091468583024,3000.0,0.6476556062698364,1.7083313465118408,25.627177441123937,3003.0,19349.917988538746,32943.01463651657,19349.917988538746,13590.657889842989,0.6880850791931152,0.0 -55300,0.2677528,1.9826655,,,,,,,,,,,,,,,,, -55400,0.34878898,1.9330071,,,,,,,,,,,,,,,,, -55500,0.28049222,1.8969442,,,,,,,,,,,,,,,,, -55600,0.41802529,1.9570498,,,,,,,,,,,,,,,,, -55700,0.29261237,1.8684473,,,,,,,,,,,,,,,,, -55800,0.45052713,1.8855846,,,,,,,,,,,,,,,,, -55900,0.49314398,2.010452,,,,,,,,,,,,,,,,, -56000,0.30812836,1.8765653,,,,,,,,,,,,,,,,, -56100,0.725706,1.8788612,,,,,,,,,,,,,,,,, -56200,0.2665874,1.9875765,,,,,,,,,,,,,,,,, -56300,0.3614445,1.9937863,,,,,,,,,,,,,,,,, -56400,0.38980323,1.8849564,,,,,,,,,,,,,,,,, -56500,0.3075535,1.9485672,,,,,,,,,,,,,,,,, -56600,0.32554108,1.8693961,,,,,,,,,,,,,,,,, -56700,0.30235776,2.001355,,,,,,,,,,,,,,,,, -56800,0.2537688,2.0049963,,,,,,,,,,,,,,,,, -56900,0.32414246,2.0450969,,,,,,,,,,,,,,,,, -57000,0.4344899,1.8786244,,,,,,,,,,,,,,,,, -57100,0.42567694,1.8649316,,,,,,,,,,,,,,,,, -57200,0.5197837,1.9163935,,,,,,,,,,,,,,,,, -57300,0.27990586,1.913797,,,,,,,,,,,,,,,,, -57400,0.56623775,1.9095103,,,,,,,,,,,,,,,,, -57500,0.3519881,2.034838,,,,,,,,,,,,,,,,, -57600,0.37984943,1.9414679,,,,,,,,,,,,,,,,, -57634,,,0.623281717300415,1.8893014192581177,30.03921252766637,0.639272928237915,1.763260841369629,26.7444395678232,3000.0,0.6502469778060913,1.700022578239441,26.16670611175656,3003.0,20189.931703090668,34365.222650527954,20189.931703090668,14172.742085933683,0.7207436561584473,0.0 -57700,0.2733544,1.9800731,,,,,,,,,,,,,,,,, -57800,0.36055297,1.991925,,,,,,,,,,,,,,,,, -57900,0.42905512,1.8021085,,,,,,,,,,,,,,,,, -58000,0.29840302,1.8708751,,,,,,,,,,,,,,,,, -58100,0.46789166,1.8876708,,,,,,,,,,,,,,,,, -58200,0.46166578,1.812376,,,,,,,,,,,,,,,,, -58300,0.33167452,1.8768519,,,,,,,,,,,,,,,,, -58400,0.44828293,1.8945748,,,,,,,,,,,,,,,,, -58500,0.4420239,1.9499575,,,,,,,,,,,,,,,,, -58600,0.39426464,1.9688573,,,,,,,,,,,,,,,,, -58700,0.39005235,1.9264965,,,,,,,,,,,,,,,,, -58800,0.26829287,1.8848655,,,,,,,,,,,,,,,,, -58900,0.6204145,1.8118666,,,,,,,,,,,,,,,,, -59000,0.41248754,1.9606202,,,,,,,,,,,,,,,,, -59100,0.51550776,1.9193782,,,,,,,,,,,,,,,,, -59200,0.2943339,1.9103937,,,,,,,,,,,,,,,,, -59300,0.34688783,1.9726988,,,,,,,,,,,,,,,,, -59400,0.58964944,1.9090604,,,,,,,,,,,,,,,,, -59500,0.27801603,1.9083532,,,,,,,,,,,,,,,,, -59600,0.38217244,1.9290435,,,,,,,,,,,,,,,,, -59700,0.31120837,1.885513,,,,,,,,,,,,,,,,, -59800,0.30384666,1.8604879,,,,,,,,,,,,,,,,, -59900,0.31344432,1.8587704,,,,,,,,,,,,,,,,, -60000,0.27923355,1.9395509,,,,,,,,,,,,,,,,, -60036,,,0.6204534769058228,1.9200528860092163,30.193340853845427,0.6428934335708618,1.7423086166381836,26.78540750934711,3000.0,0.6512114405632019,1.682421326637268,25.748610033015733,3003.0,21030.09056377411,35664.21192359924,21030.09056377411,14631.463346242905,0.7540163993835449,0.0 -60100,0.31841937,1.9289591,,,,,,,,,,,,,,,,, -60200,0.33938348,1.86195,,,,,,,,,,,,,,,,, -60300,0.31561175,1.9187052,,,,,,,,,,,,,,,,, -60400,0.30033264,1.9403777,,,,,,,,,,,,,,,,, -60500,0.28589123,1.9542273,,,,,,,,,,,,,,,,, -60600,0.34190762,1.9244971,,,,,,,,,,,,,,,,, -60700,0.43376017,1.923214,,,,,,,,,,,,,,,,, -60800,0.30856994,1.9244714,,,,,,,,,,,,,,,,, -60900,0.3562543,1.9687117,,,,,,,,,,,,,,,,, -61000,0.3047939,1.9440566,,,,,,,,,,,,,,,,, -61100,0.26987433,1.8461298,,,,,,,,,,,,,,,,, -61200,0.34976608,1.9053167,,,,,,,,,,,,,,,,, -61300,0.27826852,1.9641706,,,,,,,,,,,,,,,,, -61400,0.25634566,1.8988403,,,,,,,,,,,,,,,,, -61500,0.6532284,1.9032199,,,,,,,,,,,,,,,,, -61600,0.3110563,1.8637167,,,,,,,,,,,,,,,,, -61700,0.38779163,1.8805853,,,,,,,,,,,,,,,,, -61800,0.36188233,1.9619468,,,,,,,,,,,,,,,,, -61900,0.27950498,1.8994853,,,,,,,,,,,,,,,,, -62000,0.342047,1.9632856,,,,,,,,,,,,,,,,, -62100,0.2577655,1.823509,,,,,,,,,,,,,,,,, -62200,0.2633622,1.9171183,,,,,,,,,,,,,,,,, -62300,0.2535054,1.8176872,,,,,,,,,,,,,,,,, -62400,0.26979285,1.8365971,,,,,,,,,,,,,,,,, -62438,,,0.6210818886756897,1.9125397205352783,29.44747130858823,0.6429926156997681,1.7286237478256226,27.028646288240783,3000.0,0.6535122990608215,1.6790618896484375,26.31763856807123,3003.0,21870.185875177383,37007.15480089188,21870.185875177383,15134.199719667437,0.7899153232574463,0.0 -62500,0.25597385,1.8711902,,,,,,,,,,,,,,,,, -62600,0.25604677,1.8792348,,,,,,,,,,,,,,,,, -62700,0.27526435,1.8464022,,,,,,,,,,,,,,,,, -62800,0.3104898,1.9078308,,,,,,,,,,,,,,,,, -62900,0.3180263,1.8034116,,,,,,,,,,,,,,,,, -63000,0.53488684,1.9632461,,,,,,,,,,,,,,,,, -63100,0.413498,1.860219,,,,,,,,,,,,,,,,, -63200,0.26958632,1.8959979,,,,,,,,,,,,,,,,, -63300,0.35657117,1.9426523,,,,,,,,,,,,,,,,, -63400,0.32593757,2.0039499,,,,,,,,,,,,,,,,, -63500,0.34867606,1.8524529,,,,,,,,,,,,,,,,, -63600,0.33907768,1.843704,,,,,,,,,,,,,,,,, -63700,0.3690438,1.8314404,,,,,,,,,,,,,,,,, -63800,0.3136176,2.0067003,,,,,,,,,,,,,,,,, -63900,0.3392965,1.9005015,,,,,,,,,,,,,,,,, -64000,0.31920478,1.9089539,,,,,,,,,,,,,,,,, -64100,0.324723,1.9969786,,,,,,,,,,,,,,,,, -64200,0.36312047,1.9388533,,,,,,,,,,,,,,,,, -64300,0.29292282,1.834525,,,,,,,,,,,,,,,,, -64400,0.30591187,1.9102997,,,,,,,,,,,,,,,,, -64500,0.2683548,1.90492,,,,,,,,,,,,,,,,, -64600,0.3298352,1.8180785,,,,,,,,,,,,,,,,, -64700,0.276541,1.7887589,,,,,,,,,,,,,,,,, -64800,0.30505654,1.8800131,,,,,,,,,,,,,,,,, -64838,,,0.6214345097541809,1.899621844291687,29.87296898313337,0.6450632810592651,1.725380539894104,26.96320628106604,3000.0,0.653872549533844,1.662144660949707,26.29416092890208,3003.0,22710.087017774586,38335.86119651794,22710.087017774586,15622.89288187027,0.8244888782501221,0.0 -64900,0.2599083,1.8526127,,,,,,,,,,,,,,,,, -65000,0.33405817,1.9206237,,,,,,,,,,,,,,,,, -65100,0.3079488,1.8766752,,,,,,,,,,,,,,,,, -65200,0.32551157,1.9306706,,,,,,,,,,,,,,,,, -65300,0.33667788,1.8824202,,,,,,,,,,,,,,,,, -65400,0.33126438,1.953956,,,,,,,,,,,,,,,,, -65500,0.3587005,1.8364496,,,,,,,,,,,,,,,,, -65600,0.37251696,1.9111624,,,,,,,,,,,,,,,,, -65700,0.28774223,1.8848512,,,,,,,,,,,,,,,,, -65800,0.30108604,1.8755119,,,,,,,,,,,,,,,,, -65900,0.3727806,1.8586563,,,,,,,,,,,,,,,,, -66000,0.37527213,1.8782219,,,,,,,,,,,,,,,,, -66100,0.28525397,1.8578784,,,,,,,,,,,,,,,,, -66200,0.33043373,1.840644,,,,,,,,,,,,,,,,, -66300,0.2630493,1.7766849,,,,,,,,,,,,,,,,, -66400,0.31212786,1.8105494,,,,,,,,,,,,,,,,, -66500,0.30036753,1.902085,,,,,,,,,,,,,,,,, -66600,0.26284903,1.8394945,,,,,,,,,,,,,,,,, -66700,0.3796159,1.8542738,,,,,,,,,,,,,,,,, -66800,0.32415813,1.7886705,,,,,,,,,,,,,,,,, -66900,0.333115,1.9006824,,,,,,,,,,,,,,,,, -67000,0.27224383,1.9111627,,,,,,,,,,,,,,,,, -67100,0.3359932,1.7911447,,,,,,,,,,,,,,,,, -67200,0.33370435,1.8353314,,,,,,,,,,,,,,,,, -67239,,,0.6219140887260437,1.8977916240692136,29.898942772496444,0.6444805264472961,1.7192747592926023,27.21427000478681,3000.0,0.6557550430297852,1.6552984714508057,26.504229713711247,3003.0,23550.073257923126,39769.913810014725,23550.073257923126,16216.848822593687,0.8594467639923096,0.0 -67300,0.27677965,1.8324982,,,,,,,,,,,,,,,,, -67400,0.3631429,1.9345582,,,,,,,,,,,,,,,,, -67500,0.35391748,1.8392565,,,,,,,,,,,,,,,,, -67600,0.29294202,1.8707442,,,,,,,,,,,,,,,,, -67700,0.32349166,1.9497322,,,,,,,,,,,,,,,,, -67800,0.29766357,1.9247679,,,,,,,,,,,,,,,,, -67900,0.30052236,1.8418845,,,,,,,,,,,,,,,,, -68000,0.31473893,1.7850093,,,,,,,,,,,,,,,,, -68100,0.29912132,1.9241941,,,,,,,,,,,,,,,,, -68200,0.30649954,1.9290179,,,,,,,,,,,,,,,,, -68300,0.33850908,1.8321664,,,,,,,,,,,,,,,,, -68400,0.26715153,1.906467,,,,,,,,,,,,,,,,, -68500,0.25396943,1.8473725,,,,,,,,,,,,,,,,, -68600,0.31525254,1.8663968,,,,,,,,,,,,,,,,, -68700,0.28549817,1.9150974,,,,,,,,,,,,,,,,, -68800,0.280017,1.9730182,,,,,,,,,,,,,,,,, -68900,0.35063067,1.7550672,,,,,,,,,,,,,,,,, -69000,0.2814723,1.8776058,,,,,,,,,,,,,,,,, -69100,0.28350323,1.8609617,,,,,,,,,,,,,,,,, -69200,0.2692866,1.871496,,,,,,,,,,,,,,,,, -69300,0.30455902,1.8528144,,,,,,,,,,,,,,,,, -69400,0.3289366,1.7998034,,,,,,,,,,,,,,,,, -69500,0.2769242,1.8550199,,,,,,,,,,,,,,,,, -69600,0.28547168,1.7796997,,,,,,,,,,,,,,,,, -69640,,,0.6379268765449524,1.781161189079285,30.41896362964424,0.649774968624115,1.6948308944702148,27.45221541734196,3000.0,0.6588344573974609,1.639184832572937,26.86768368161496,3003.0,24390.02137088776,41132.171432971954,24390.02137088776,16739.037479639053,0.9037342071533204,0.0 -69700,0.4522233,1.8515956,,,,,,,,,,,,,,,,, -69800,0.2717694,1.8670315,,,,,,,,,,,,,,,,, -69900,0.2772858,1.9032266,,,,,,,,,,,,,,,,, -70000,0.3210478,1.7936423,,,,,,,,,,,,,,,,, -70100,0.34307355,1.7995174,,,,,,,,,,,,,,,,, -70200,0.32327417,1.8617595,,,,,,,,,,,,,,,,, -70300,0.39894077,1.8195451,,,,,,,,,,,,,,,,, -70400,0.28499365,1.8413903,,,,,,,,,,,,,,,,, -70500,0.25448433,1.7981611,,,,,,,,,,,,,,,,, -70600,0.35183704,1.8749999,,,,,,,,,,,,,,,,, -70700,0.30065006,1.8899273,,,,,,,,,,,,,,,,, -70800,0.28319964,1.8269957,,,,,,,,,,,,,,,,, -70900,0.3402538,1.9672607,,,,,,,,,,,,,,,,, -71000,0.3088324,1.8753688,,,,,,,,,,,,,,,,, -71100,0.31393197,1.7941749,,,,,,,,,,,,,,,,, -71200,0.27999482,1.8554415,,,,,,,,,,,,,,,,, -71300,0.4000972,1.9187468,,,,,,,,,,,,,,,,, -71400,0.31787258,1.7701043,,,,,,,,,,,,,,,,, -71500,0.30701217,1.7937349,,,,,,,,,,,,,,,,, -71600,0.34591517,1.839754,,,,,,,,,,,,,,,,, -71700,0.39966536,1.8726792,,,,,,,,,,,,,,,,, -71800,0.2621369,1.944061,,,,,,,,,,,,,,,,, -71900,0.31128752,1.7950637,,,,,,,,,,,,,,,,, -72000,0.2848189,1.8675184,,,,,,,,,,,,,,,,, -72042,,,0.6333906650543213,1.8417229652404783,30.65259166358259,0.6506924629211426,1.6889138221740725,27.47268478695876,3000.0,0.6625297665596008,1.6213717460632324,27.263224574075057,3003.0,25229.988456487656,42538.432683467865,25229.988456487656,17305.21870613098,0.9400687217712402,0.0 -72100,0.43366385,1.8788838,,,,,,,,,,,,,,,,, -72200,0.28820628,1.804801,,,,,,,,,,,,,,,,, -72300,0.29785877,1.8104854,,,,,,,,,,,,,,,,, -72400,0.35068876,1.8146477,,,,,,,,,,,,,,,,, -72500,0.30887318,1.8602525,,,,,,,,,,,,,,,,, -72600,0.36624578,1.6956632,,,,,,,,,,,,,,,,, -72700,0.2699637,1.8882576,,,,,,,,,,,,,,,,, -72800,0.2900485,1.8389555,,,,,,,,,,,,,,,,, -72900,0.2794862,1.8451877,,,,,,,,,,,,,,,,, -73000,0.29969078,1.8452041,,,,,,,,,,,,,,,,, -73100,0.25519314,1.7882837,,,,,,,,,,,,,,,,, -73200,0.37367025,1.8489699,,,,,,,,,,,,,,,,, -73300,0.29092112,1.7719039,,,,,,,,,,,,,,,,, -73400,0.30345878,1.8683759,,,,,,,,,,,,,,,,, -73500,0.29281867,1.8014163,,,,,,,,,,,,,,,,, -73600,0.2437653,1.7866267,,,,,,,,,,,,,,,,, -73700,0.35038853,1.8390167,,,,,,,,,,,,,,,,, -73800,0.31454453,1.8165741,,,,,,,,,,,,,,,,, -73900,0.31106052,1.8304355,,,,,,,,,,,,,,,,, -74000,0.3135632,1.8404673,,,,,,,,,,,,,,,,, -74100,0.2848171,1.8279186,,,,,,,,,,,,,,,,, -74200,0.27120212,1.8620151,,,,,,,,,,,,,,,,, -74300,0.32389647,1.7929556,,,,,,,,,,,,,,,,, -74400,0.29662788,1.9003949,,,,,,,,,,,,,,,,, -74443,,,0.6327632665634155,1.840661883354187,30.128673062601425,0.6526143550872803,1.6699992418289185,27.47545997818604,3000.0,0.662494957447052,1.607891082763672,26.571996580797872,3003.0,26069.895349264145,43927.80186915398,26069.895349264145,17854.567930936813,0.9772353172302246,0.0 -74500,0.30983317,1.8426975,,,,,,,,,,,,,,,,, -74600,0.32561347,1.8466364,,,,,,,,,,,,,,,,, -74700,0.2751831,1.8684378,,,,,,,,,,,,,,,,, -74800,0.3016151,1.7996533,,,,,,,,,,,,,,,,, -74900,0.3299102,1.722877,,,,,,,,,,,,,,,,, -75000,0.34382707,1.8191309,,,,,,,,,,,,,,,,, -75100,0.3726055,1.7984874,,,,,,,,,,,,,,,,, -75200,0.30245766,1.8638443,,,,,,,,,,,,,,,,, -75300,0.29143476,1.7581477,,,,,,,,,,,,,,,,, -75400,0.2695427,1.698586,,,,,,,,,,,,,,,,, -75500,0.28950188,1.7514946,,,,,,,,,,,,,,,,, -75600,0.2964772,1.8708268,,,,,,,,,,,,,,,,, -75700,0.26287484,1.769049,,,,,,,,,,,,,,,,, -75800,0.29550543,1.8103311,,,,,,,,,,,,,,,,, -75900,0.28708178,1.8481653,,,,,,,,,,,,,,,,, -76000,0.3038173,1.8316065,,,,,,,,,,,,,,,,, -76100,0.2851945,1.840243,,,,,,,,,,,,,,,,, -76200,0.37561285,1.8339634,,,,,,,,,,,,,,,,, -76300,0.30178824,1.9004669,,,,,,,,,,,,,,,,, -76400,0.29905647,1.896454,,,,,,,,,,,,,,,,, -76500,0.32758355,1.8193177,,,,,,,,,,,,,,,,, -76600,0.28134444,1.8106307,,,,,,,,,,,,,,,,, -76700,0.24886675,1.6623081,,,,,,,,,,,,,,,,, -76800,0.3932708,1.8396477,,,,,,,,,,,,,,,,, -76844,,,0.6360828280448914,1.7997148036956787,30.947704964285027,0.6537302732467651,1.655122995376587,27.672505117892413,3000.0,0.6651095151901245,1.5921093225479126,27.0778855815688,3003.0,26909.92446255684,45335.44073653221,26909.92446255684,18422.05643105507,1.0228424072265625,0.0 -76900,0.32412317,1.8481557,,,,,,,,,,,,,,,,, -77000,0.38501188,1.9053434,,,,,,,,,,,,,,,,, -77100,0.37301126,1.8235375,,,,,,,,,,,,,,,,, -77200,0.28803515,1.8385189,,,,,,,,,,,,,,,,, -77300,0.29850206,1.8815302,,,,,,,,,,,,,,,,, -77400,0.31943372,1.8431913,,,,,,,,,,,,,,,,, -77500,0.28630376,1.7962087,,,,,,,,,,,,,,,,, -77600,0.30354846,1.77879,,,,,,,,,,,,,,,,, -77700,0.27615717,1.7747,,,,,,,,,,,,,,,,, -77800,0.3481666,1.7494478,,,,,,,,,,,,,,,,, -77900,0.2900968,1.9019567,,,,,,,,,,,,,,,,, -78000,0.32087824,1.8344481,,,,,,,,,,,,,,,,, -78100,0.302071,1.7812464,,,,,,,,,,,,,,,,, -78200,0.29449257,1.8129455,,,,,,,,,,,,,,,,, -78300,0.2755643,1.8068831,,,,,,,,,,,,,,,,, -78400,0.33390728,1.7521232,,,,,,,,,,,,,,,,, -78500,0.32677552,1.7720897,,,,,,,,,,,,,,,,, -78600,0.27105427,1.9031578,,,,,,,,,,,,,,,,, -78700,0.28759763,1.8618066,,,,,,,,,,,,,,,,, -78800,0.3378545,1.8251219,,,,,,,,,,,,,,,,, -78900,0.3031837,1.8099957,,,,,,,,,,,,,,,,, -79000,0.30619544,1.860893,,,,,,,,,,,,,,,,, -79100,0.27270982,1.7865546,,,,,,,,,,,,,,,,, -79200,0.2654168,1.7810897,,,,,,,,,,,,,,,,, -79245,,,0.6354596018791199,1.8094682693481443,30.61698412674778,0.6570656299591064,1.6474754810333252,27.837613650228427,3000.0,0.666852593421936,1.5803241729736328,27.10222512757891,3003.0,27749.865286827087,46707.21108341217,27749.865286827087,18953.771744966507,1.0598394870758057,0.0 -79300,0.2507872,1.8040192,,,,,,,,,,,,,,,,, -79400,0.34144178,1.7798203,,,,,,,,,,,,,,,,, -79500,0.29116324,1.7366008,,,,,,,,,,,,,,,,, -79600,0.3135403,1.7762823,,,,,,,,,,,,,,,,, -79700,0.31301272,1.8626162,,,,,,,,,,,,,,,,, -79800,0.35351777,1.7794688,,,,,,,,,,,,,,,,, -79900,0.28914532,1.8205065,,,,,,,,,,,,,,,,, -80000,0.30711338,1.7420406,,,,,,,,,,,,,,,,, -80100,0.2987264,1.841894,,,,,,,,,,,,,,,,, -80200,0.2925537,1.7942061,,,,,,,,,,,,,,,,, -80300,0.28068918,1.8230959,,,,,,,,,,,,,,,,, -80400,0.27095336,1.8049922,,,,,,,,,,,,,,,,, -80500,0.29723096,1.804889,,,,,,,,,,,,,,,,, -80600,0.36919335,1.8128012,,,,,,,,,,,,,,,,, -80700,0.2714122,1.8291483,,,,,,,,,,,,,,,,, -80800,0.30073836,1.7639388,,,,,,,,,,,,,,,,, -80900,0.32169425,1.8861587,,,,,,,,,,,,,,,,, -81000,0.31427032,1.8342056,,,,,,,,,,,,,,,,, -81100,0.31756514,1.7439544,,,,,,,,,,,,,,,,, -81200,0.30167586,1.7380822,,,,,,,,,,,,,,,,, -81300,0.29491594,1.7069801,,,,,,,,,,,,,,,,, -81400,0.3101068,1.9451927,,,,,,,,,,,,,,,,, -81500,0.28487095,1.7454728,,,,,,,,,,,,,,,,, -81600,0.29599717,1.8303813,,,,,,,,,,,,,,,,, -81645,,,0.6552977561950684,1.6481138467788696,32.10947252151286,0.6575492024421692,1.6359890699386597,27.99446188450933,3000.0,0.6707803606987,1.570162057876587,27.857890901754264,3003.0,28589.836373090744,48076.842170238495,28589.836373090744,19483.31466794014,1.0974700450897217,0.0 -81700,0.2975123,1.762664,,,,,,,,,,,,,,,,, -81800,0.27820462,1.828337,,,,,,,,,,,,,,,,, -81900,0.30889517,1.7812028,,,,,,,,,,,,,,,,, -82000,0.3261193,1.7402337,,,,,,,,,,,,,,,,, -82100,0.30415067,1.7784667,,,,,,,,,,,,,,,,, -82200,0.2785057,1.736193,,,,,,,,,,,,,,,,, -82300,0.28433365,1.8552029,,,,,,,,,,,,,,,,, -82400,0.2912336,1.7668607,,,,,,,,,,,,,,,,, -82500,0.28758097,1.7495346,,,,,,,,,,,,,,,,, -82600,0.28611884,1.7556418,,,,,,,,,,,,,,,,, -82700,0.31215197,1.7637028,,,,,,,,,,,,,,,,, -82800,0.31478694,1.73487,,,,,,,,,,,,,,,,, -82900,0.30278,1.8025897,,,,,,,,,,,,,,,,, -83000,0.37534213,1.7220961,,,,,,,,,,,,,,,,, -83100,0.32107866,1.7324646,,,,,,,,,,,,,,,,, -83200,0.2825255,1.768715,,,,,,,,,,,,,,,,, -83300,0.29292977,1.7367429,,,,,,,,,,,,,,,,, -83400,0.31079435,1.7642236,,,,,,,,,,,,,,,,, -83500,0.36941287,1.7638791,,,,,,,,,,,,,,,,, -83600,0.28226614,1.7990338,,,,,,,,,,,,,,,,, -83700,0.3109576,1.8466002,,,,,,,,,,,,,,,,, -83800,0.30658877,1.8567982,,,,,,,,,,,,,,,,, -83900,0.3036583,1.808724,,,,,,,,,,,,,,,,, -84000,0.31206095,1.7589971,,,,,,,,,,,,,,,,, -84048,,,0.6372441649436951,1.792824149131775,31.231938001811475,0.6590867042541504,1.6225743293762207,28.22790193040609,3000.0,0.6720934510231018,1.5535181760787964,27.58812304327058,3003.0,29429.99524831772,49503.03884482384,29429.99524831772,20069.23208117485,1.141160249710083,0.0 -84100,0.32435066,1.7803569,,,,,,,,,,,,,,,,, -84200,0.33035108,1.7930609,,,,,,,,,,,,,,,,, -84300,0.2776483,1.8433663,,,,,,,,,,,,,,,,, -84400,0.30054402,1.8151029,,,,,,,,,,,,,,,,, -84500,0.30542186,1.843251,,,,,,,,,,,,,,,,, -84600,0.29441583,1.7850251,,,,,,,,,,,,,,,,, -84700,0.28866592,1.710718,,,,,,,,,,,,,,,,, -84800,0.33141965,1.7639537,,,,,,,,,,,,,,,,, -84900,0.2845682,1.7605116,,,,,,,,,,,,,,,,, -85000,0.29132372,1.8049706,,,,,,,,,,,,,,,,, -85100,0.30512542,1.7601384,,,,,,,,,,,,,,,,, -85200,0.25993648,1.7784342,,,,,,,,,,,,,,,,, -85300,0.38991913,1.8509699,,,,,,,,,,,,,,,,, -85400,0.31630847,1.7854952,,,,,,,,,,,,,,,,, -85500,0.31659028,1.8554105,,,,,,,,,,,,,,,,, -85600,0.31662413,1.7833382,,,,,,,,,,,,,,,,, -85700,0.28778338,1.777295,,,,,,,,,,,,,,,,, -85800,0.2893565,1.7744591,,,,,,,,,,,,,,,,, -85900,0.29436752,1.7975321,,,,,,,,,,,,,,,,, -86000,0.3584604,1.8377274,,,,,,,,,,,,,,,,, -86100,0.33537364,1.8419989,,,,,,,,,,,,,,,,, -86200,0.29821146,1.7584921,,,,,,,,,,,,,,,,, -86300,0.32124665,1.7939587,,,,,,,,,,,,,,,,, -86400,0.32698357,1.8007288,,,,,,,,,,,,,,,,, -86449,,,0.6410945057868958,1.7650405168533323,31.41908621569645,0.6631659865379333,1.604163408279419,28.64836691973624,3000.0,0.6744524240493774,1.5341390371322632,27.97908133401334,3003.0,30269.94028711319,51033.38291668892,30269.94028711319,20759.51155924797,1.185424566268921,0.0 -86500,0.30369383,1.8072623,,,,,,,,,,,,,,,,, -86600,0.44401094,1.8050153,,,,,,,,,,,,,,,,, -86700,0.2879453,1.7734861,,,,,,,,,,,,,,,,, -86800,0.31261984,1.6999844,,,,,,,,,,,,,,,,, -86900,0.30178976,1.7504101,,,,,,,,,,,,,,,,, -87000,0.30444106,1.7194782,,,,,,,,,,,,,,,,, -87100,0.31163386,1.7771133,,,,,,,,,,,,,,,,, -87200,0.29361933,1.7175918,,,,,,,,,,,,,,,,, -87300,0.2949448,1.7055221,,,,,,,,,,,,,,,,, -87400,0.342695,1.7364699,,,,,,,,,,,,,,,,, -87500,0.30144563,1.6810786,,,,,,,,,,,,,,,,, -87600,0.31631923,1.7759701,,,,,,,,,,,,,,,,, -87700,0.30346,1.723271,,,,,,,,,,,,,,,,, -87800,0.3408492,1.6862805,,,,,,,,,,,,,,,,, -87900,0.35994312,1.7634178,,,,,,,,,,,,,,,,, -88000,0.30804655,1.7424976,,,,,,,,,,,,,,,,, -88100,0.30993432,1.7769978,,,,,,,,,,,,,,,,, -88200,0.30708194,1.7868111,,,,,,,,,,,,,,,,, -88300,0.2976217,1.7564834,,,,,,,,,,,,,,,,, -88400,0.29175973,1.7186937,,,,,,,,,,,,,,,,, -88500,0.3015547,1.7116045,,,,,,,,,,,,,,,,, -88600,0.29864684,1.7880133,,,,,,,,,,,,,,,,, -88700,0.3017106,1.8049537,,,,,,,,,,,,,,,,, -88800,0.29252,1.7513511,,,,,,,,,,,,,,,,, -88851,,,0.6504623293876648,1.689774990081787,31.84581123567893,0.6630296111106873,1.5953129529953003,28.532335408202147,3000.0,0.6743943095207214,1.5237705707550049,27.792506126010068,3003.0,31109.83723974228,52418.78368616104,31109.83723974228,21304.90257930756,1.2225456237792969,0.0 -88900,0.31843442,1.7357901,,,,,,,,,,,,,,,,, -89000,0.39207157,1.6966461,,,,,,,,,,,,,,,,, -89100,0.34491453,1.7235734,,,,,,,,,,,,,,,,, -89200,0.43464398,1.7788749,,,,,,,,,,,,,,,,, -89300,0.2724809,1.6502397,,,,,,,,,,,,,,,,, -89400,0.27272606,1.6423724,,,,,,,,,,,,,,,,, -89500,0.3368174,1.8081957,,,,,,,,,,,,,,,,, -89600,0.30889034,1.7746321,,,,,,,,,,,,,,,,, -89700,0.3284954,1.7867037,,,,,,,,,,,,,,,,, -89800,0.29699033,1.7178823,,,,,,,,,,,,,,,,, -89900,0.3158781,1.8118718,,,,,,,,,,,,,,,,, -90000,0.3021261,1.7767545,,,,,,,,,,,,,,,,, -90100,0.33127993,1.7963473,,,,,,,,,,,,,,,,, -90200,0.5081197,1.7603405,,,,,,,,,,,,,,,,, -90300,0.3936242,1.7940327,,,,,,,,,,,,,,,,, -90400,0.29739094,1.7739185,,,,,,,,,,,,,,,,, -90500,0.3256298,1.7022337,,,,,,,,,,,,,,,,, -90600,0.29875782,1.78413,,,,,,,,,,,,,,,,, -90700,0.31025606,1.7402139,,,,,,,,,,,,,,,,, -90800,0.31282926,1.8075922,,,,,,,,,,,,,,,,, -90900,0.30719307,1.7208338,,,,,,,,,,,,,,,,, -91000,0.3249035,1.7266505,,,,,,,,,,,,,,,,, -91100,0.30396873,1.6544776,,,,,,,,,,,,,,,,, -91200,0.27179745,1.7469614,,,,,,,,,,,,,,,,, -91254,,,0.6479614973068237,1.720964789390564,31.76366900205668,0.665571391582489,1.5849394798278809,28.862445004881646,3000.0,0.6779152750968933,1.5046041011810305,28.25653450368869,3003.0,31949.88634347916,53800.26221823692,31949.88634347916,21846.217509269714,1.2612571716308594,0.0 -91300,0.28147167,1.6823138,,,,,,,,,,,,,,,,, -91400,0.32210624,1.6943544,,,,,,,,,,,,,,,,, -91500,0.28090835,1.7071599,,,,,,,,,,,,,,,,, -91600,0.3477525,1.784613,,,,,,,,,,,,,,,,, -91700,0.301081,1.7190816,,,,,,,,,,,,,,,,, -91800,0.29282877,1.6961881,,,,,,,,,,,,,,,,, -91900,0.31711128,1.666302,,,,,,,,,,,,,,,,, -92000,0.3373231,1.776889,,,,,,,,,,,,,,,,, -92100,0.28354713,1.7589364,,,,,,,,,,,,,,,,, -92200,0.3454023,1.731811,,,,,,,,,,,,,,,,, -92300,0.32122606,1.6778606,,,,,,,,,,,,,,,,, -92400,0.3147666,1.790138,,,,,,,,,,,,,,,,, -92500,0.34815705,1.712104,,,,,,,,,,,,,,,,, -92600,0.2870848,1.7080332,,,,,,,,,,,,,,,,, -92700,0.29379404,1.6818018,,,,,,,,,,,,,,,,, -92800,0.28741297,1.7021022,,,,,,,,,,,,,,,,, -92900,0.29778865,1.6848258,,,,,,,,,,,,,,,,, -93000,0.36681384,1.7769475,,,,,,,,,,,,,,,,, -93100,0.30934444,1.7969452,,,,,,,,,,,,,,,,, -93200,0.28815997,1.7480336,,,,,,,,,,,,,,,,, -93300,0.2970033,1.725935,,,,,,,,,,,,,,,,, -93400,0.3307317,1.7296611,,,,,,,,,,,,,,,,, -93500,0.30026567,1.6718177,,,,,,,,,,,,,,,,, -93600,0.33750284,1.7080953,,,,,,,,,,,,,,,,, -93656,,,0.6468232274055481,1.723307490348816,31.68318684415901,0.6688323616981506,1.5702892541885376,29.080144950437283,3000.0,0.6798908114433289,1.4960421323776243,28.425006038283964,3003.0,32789.87203192711,55197.84358978272,32789.87203192711,22403.69824290276,1.300079584121704,0.0 -93700,0.2859037,1.6596495,,,,,,,,,,,,,,,,, -93800,0.2838414,1.6285291,,,,,,,,,,,,,,,,, -93900,0.29476207,1.7576325,,,,,,,,,,,,,,,,, -94000,0.31665176,1.6579407,,,,,,,,,,,,,,,,, -94100,0.3209586,1.6742151,,,,,,,,,,,,,,,,, -94200,0.3159236,1.7194514,,,,,,,,,,,,,,,,, -94300,0.29230118,1.7422726,,,,,,,,,,,,,,,,, -94400,0.29824707,1.71518,,,,,,,,,,,,,,,,, -94500,0.31193525,1.7248915,,,,,,,,,,,,,,,,, -94600,0.3321651,1.7599297,,,,,,,,,,,,,,,,, -94700,0.31369686,1.6947166,,,,,,,,,,,,,,,,, -94800,0.29460976,1.6758711,,,,,,,,,,,,,,,,, -94900,0.291881,1.691349,,,,,,,,,,,,,,,,, -95000,0.29494265,1.6593933,,,,,,,,,,,,,,,,, -95100,0.320376,1.7270235,,,,,,,,,,,,,,,,, -95200,0.297804,1.7618439,,,,,,,,,,,,,,,,, -95300,0.31010586,1.739292,,,,,,,,,,,,,,,,, -95400,0.29407632,1.6170409,,,,,,,,,,,,,,,,, -95500,0.30162993,1.6308966,,,,,,,,,,,,,,,,, -95600,0.31603009,1.6626213,,,,,,,,,,,,,,,,, -95700,0.3338539,1.7790887,,,,,,,,,,,,,,,,, -95800,0.28726077,1.7549354,,,,,,,,,,,,,,,,, -95900,0.35069016,1.7817584,,,,,,,,,,,,,,,,, -96000,0.33413494,1.6806146,,,,,,,,,,,,,,,,, -96058,,,0.6511693596839905,1.6873743534088137,31.714046536608137,0.6717833280563354,1.547297716140747,29.245886051656544,3000.0,0.6813665628433228,1.4757575988769531,28.80327442841708,3003.0,33629.953258514404,56583.53949093819,33629.953258514404,22949.19913959503,1.338247776031494,0.0 -96100,0.29567418,1.6616485,,,,,,,,,,,,,,,,, -96200,0.3036067,1.6518495,,,,,,,,,,,,,,,,, -96300,0.29803187,1.7522942,,,,,,,,,,,,,,,,, -96400,0.32741094,1.7542624,,,,,,,,,,,,,,,,, -96500,0.30578774,1.6668339,,,,,,,,,,,,,,,,, -96600,0.372088,1.7573484,,,,,,,,,,,,,,,,, -96700,0.30472776,1.663755,,,,,,,,,,,,,,,,, -96800,0.2901916,1.754118,,,,,,,,,,,,,,,,, -96900,0.30051306,1.7095122,,,,,,,,,,,,,,,,, -97000,0.28978547,1.69876,,,,,,,,,,,,,,,,, -97100,0.30811283,1.6344867,,,,,,,,,,,,,,,,, -97200,0.2978489,1.6295227,,,,,,,,,,,,,,,,, -97300,0.30682796,1.6969507,,,,,,,,,,,,,,,,, -97400,0.32378522,1.6815951,,,,,,,,,,,,,,,,, -97500,0.3499467,1.6944575,,,,,,,,,,,,,,,,, -97600,0.31949708,1.6968719,,,,,,,,,,,,,,,,, -97700,0.2875651,1.6703783,,,,,,,,,,,,,,,,, -97800,0.3080418,1.7253764,,,,,,,,,,,,,,,,, -97900,0.29446846,1.6490036,,,,,,,,,,,,,,,,, -98000,0.32819727,1.7571765,,,,,,,,,,,,,,,,, -98100,0.29661965,1.6684909,,,,,,,,,,,,,,,,, -98200,0.32685113,1.7149246,,,,,,,,,,,,,,,,, -98300,0.31789425,1.7573215,,,,,,,,,,,,,,,,, -98400,0.30418572,1.6028632,,,,,,,,,,,,,,,,, -98461,,,0.6557013392448425,1.67685866355896,32.65956777295396,0.6736308336257935,1.5366019010543823,29.419734290930183,3000.0,0.6865260601043701,1.456973910331726,29.0640595971544,3003.0,34470.01172947884,57937.23023700714,34470.01172947884,23462.70897555352,1.385880470275879,0.0 -98500,0.31924608,1.6899794,,,,,,,,,,,,,,,,, -98600,0.324208,1.6596606,,,,,,,,,,,,,,,,, -98700,0.28887683,1.6308408,,,,,,,,,,,,,,,,, -98800,0.3223048,1.767389,,,,,,,,,,,,,,,,, -98900,0.30344558,1.7343363,,,,,,,,,,,,,,,,, -99000,0.28925136,1.6809161,,,,,,,,,,,,,,,,, -99100,0.29513273,1.6594536,,,,,,,,,,,,,,,,, -99200,0.29478845,1.6330254,,,,,,,,,,,,,,,,, -99300,0.31987405,1.6711689,,,,,,,,,,,,,,,,, -99400,0.30440792,1.6343466,,,,,,,,,,,,,,,,, -99500,0.28519136,1.6594993,,,,,,,,,,,,,,,,, -99600,0.2940878,1.6533594,,,,,,,,,,,,,,,,, -99700,0.3023654,1.5846865,,,,,,,,,,,,,,,,, -99800,0.3333609,1.6941682,,,,,,,,,,,,,,,,, -99900,0.28331485,1.6040597,,,,,,,,,,,,,,,,, -100000,0.31472546,1.6075255,,,,,,,,,,,,,,,,, -100100,0.3111099,1.7214533,,,,,,,,,,,,,,,,, -100200,0.31814456,1.6580178,,,,,,,,,,,,,,,,, -100300,0.33640623,1.6177164,,,,,,,,,,,,,,,,, -100400,0.33754492,1.6843175,,,,,,,,,,,,,,,,, -100500,0.30166718,1.668883,,,,,,,,,,,,,,,,, -100600,0.28863502,1.636951,,,,,,,,,,,,,,,,, -100700,0.33202714,1.6308674,,,,,,,,,,,,,,,,, -100800,0.3148379,1.6323421,,,,,,,,,,,,,,,,, -100863,,,0.6670910716056824,1.576935052871704,33.74426266233717,0.6752303242683411,1.5253691673278809,29.383349787408346,3000.0,0.6878159642219543,1.442979335784912,29.20659989243813,3003.0,35309.92882537842,59347.370725631714,35309.92882537842,24032.816915750504,1.42657732963562,0.0 -100900,0.3009846,1.6726344,,,,,,,,,,,,,,,,, -101000,0.3085289,1.6853441,,,,,,,,,,,,,,,,, -101100,0.32160565,1.6641436,,,,,,,,,,,,,,,,, -101200,0.32381967,1.6877939,,,,,,,,,,,,,,,,, -101300,0.32012594,1.7198063,,,,,,,,,,,,,,,,, -101400,0.30869,1.6732459,,,,,,,,,,,,,,,,, -101500,0.30583245,1.6675646,,,,,,,,,,,,,,,,, -101600,0.3094702,1.6356171,,,,,,,,,,,,,,,,, -101700,0.3240416,1.6348046,,,,,,,,,,,,,,,,, -101800,0.28724873,1.6689197,,,,,,,,,,,,,,,,, -101900,0.30108568,1.7274189,,,,,,,,,,,,,,,,, -102000,0.31261075,1.6687075,,,,,,,,,,,,,,,,, -102100,0.32242268,1.6436373,,,,,,,,,,,,,,,,, -102200,0.32429796,1.690748,,,,,,,,,,,,,,,,, -102300,0.3169287,1.6087087,,,,,,,,,,,,,,,,, -102400,0.34047544,1.6510926,,,,,,,,,,,,,,,,, -102500,0.33612427,1.638691,,,,,,,,,,,,,,,,, -102600,0.3657837,1.6457872,,,,,,,,,,,,,,,,, -102700,0.3216113,1.7678821,,,,,,,,,,,,,,,,, -102800,0.32523692,1.6808519,,,,,,,,,,,,,,,,, -102900,0.33574924,1.6658285,,,,,,,,,,,,,,,,, -103000,0.3322626,1.642459,,,,,,,,,,,,,,,,, -103100,0.31520852,1.6225718,,,,,,,,,,,,,,,,, -103200,0.31050938,1.7135477,,,,,,,,,,,,,,,,, -103265,,,0.6593321561813354,1.6377383470535278,32.99818121622395,0.676978588104248,1.510822296142578,29.280341283493588,3000.0,0.6917320489883423,1.421906352043152,29.722389977868755,3003.0,36149.880427360535,60767.247957229614,36149.880427360535,24612.61792945861,1.474708080291748,0.0 -103300,0.29789245,1.6996179,,,,,,,,,,,,,,,,, -103400,0.32027432,1.6707921,,,,,,,,,,,,,,,,, -103500,0.3122577,1.6079552,,,,,,,,,,,,,,,,, -103600,0.33521238,1.6222972,,,,,,,,,,,,,,,,, -103700,0.3040008,1.576892,,,,,,,,,,,,,,,,, -103800,0.32384062,1.6892918,,,,,,,,,,,,,,,,, -103900,0.30260375,1.6295643,,,,,,,,,,,,,,,,, -104000,0.32063726,1.6714168,,,,,,,,,,,,,,,,, -104100,0.3099531,1.5592171,,,,,,,,,,,,,,,,, -104200,0.31518003,1.6470984,,,,,,,,,,,,,,,,, -104300,0.3082732,1.6181039,,,,,,,,,,,,,,,,, -104400,0.3259358,1.7016604,,,,,,,,,,,,,,,,, -104500,0.3150681,1.6421621,,,,,,,,,,,,,,,,, -104600,0.30278575,1.5800347,,,,,,,,,,,,,,,,, -104700,0.30358613,1.6013203,,,,,,,,,,,,,,,,, -104800,0.3098702,1.6150948,,,,,,,,,,,,,,,,, -104900,0.3330489,1.6230271,,,,,,,,,,,,,,,,, -105000,0.33056515,1.575636,,,,,,,,,,,,,,,,, -105100,0.32557762,1.639345,,,,,,,,,,,,,,,,, -105200,0.3111478,1.6571759,,,,,,,,,,,,,,,,, -105300,0.32613763,1.6382946,,,,,,,,,,,,,,,,, -105400,0.33530712,1.5816696,,,,,,,,,,,,,,,,, -105500,0.32341084,1.6262385,,,,,,,,,,,,,,,,, -105600,0.3279287,1.5950382,,,,,,,,,,,,,,,,, -105668,,,0.6658817529678345,1.6051431894302368,32.95288097209237,0.6797559857368469,1.4939372539520264,29.482852243566366,3000.0,0.6930451393127441,1.4107818603515625,29.341508113582247,3003.0,36990.06665062904,62172.06532788277,36990.06665062904,25177.135360479355,1.5139610767364502,0.0 -105700,0.30652702,1.5485195,,,,,,,,,,,,,,,,, -105800,0.32691664,1.6238313,,,,,,,,,,,,,,,,, -105900,0.34768936,1.6792035,,,,,,,,,,,,,,,,, -106000,0.31582874,1.6436435,,,,,,,,,,,,,,,,, -106100,0.31513685,1.6031384,,,,,,,,,,,,,,,,, -106200,0.32716924,1.5999544,,,,,,,,,,,,,,,,, -106300,0.34206104,1.6344513,,,,,,,,,,,,,,,,, -106400,0.33308104,1.595442,,,,,,,,,,,,,,,,, -106500,0.345929,1.608431,,,,,,,,,,,,,,,,, -106600,0.31561443,1.5471542,,,,,,,,,,,,,,,,, -106700,0.3217556,1.5776719,,,,,,,,,,,,,,,,, -106800,0.3050859,1.6010938,,,,,,,,,,,,,,,,, -106900,0.33333868,1.5658786,,,,,,,,,,,,,,,,, -107000,0.32621694,1.7099384,,,,,,,,,,,,,,,,, -107100,0.33627206,1.5944989,,,,,,,,,,,,,,,,, -107200,0.3516596,1.6021724,,,,,,,,,,,,,,,,, -107300,0.34044376,1.6470107,,,,,,,,,,,,,,,,, -107400,0.34301963,1.6594285,,,,,,,,,,,,,,,,, -107500,0.31886402,1.5109667,,,,,,,,,,,,,,,,, -107600,0.3545683,1.599294,,,,,,,,,,,,,,,,, -107700,0.3141829,1.5317794,,,,,,,,,,,,,,,,, -107800,0.32101303,1.6178538,,,,,,,,,,,,,,,,, -107900,0.33037367,1.5351108,,,,,,,,,,,,,,,,, -108000,0.31829312,1.6131902,,,,,,,,,,,,,,,,, -108070,,,0.6727055311203003,1.5570759773254397,33.9455520747962,0.6823598146438599,1.4784127473831177,29.72807262910545,3000.0,0.6957992315292358,1.3981857299804688,29.608836478521752,3003.0,37830.27451658249,63536.203300237656,37830.27451658249,25700.941499471664,1.5608172416687012,0.0 -108100,0.32618335,1.5775238,,,,,,,,,,,,,,,,, -108200,0.3431465,1.6069356,,,,,,,,,,,,,,,,, -108300,0.34514078,1.6225479,,,,,,,,,,,,,,,,, -108400,0.33633444,1.6022828,,,,,,,,,,,,,,,,, -108500,0.3218109,1.5211576,,,,,,,,,,,,,,,,, -108600,0.33433086,1.5979518,,,,,,,,,,,,,,,,, -108700,0.30364913,1.5401695,,,,,,,,,,,,,,,,, -108800,0.32454607,1.5083174,,,,,,,,,,,,,,,,, -108900,0.3484009,1.5756011,,,,,,,,,,,,,,,,, -109000,0.30669108,1.5803826,,,,,,,,,,,,,,,,, -109100,0.30819127,1.553341,,,,,,,,,,,,,,,,, -109200,0.33164626,1.5734619,,,,,,,,,,,,,,,,, -109300,0.31532264,1.5616968,,,,,,,,,,,,,,,,, -109400,0.3353146,1.6248472,,,,,,,,,,,,,,,,, -109500,0.31594092,1.6450987,,,,,,,,,,,,,,,,, -109600,0.33724943,1.5160335,,,,,,,,,,,,,,,,, -109700,0.3284607,1.586124,,,,,,,,,,,,,,,,, -109800,0.33561715,1.5590273,,,,,,,,,,,,,,,,, -109900,0.33299607,1.5877794,,,,,,,,,,,,,,,,, -110000,0.3047249,1.5666304,,,,,,,,,,,,,,,,, -110100,0.32218194,1.5520047,,,,,,,,,,,,,,,,, -110200,0.32523346,1.6115539,,,,,,,,,,,,,,,,, -110300,0.362506,1.4932005,,,,,,,,,,,,,,,,, -110400,0.3422587,1.5565177,,,,,,,,,,,,,,,,, -110471,,,0.6731343865394592,1.5567851066589355,33.786392688007275,0.6843188405036926,1.4721633195877075,30.15936458541903,3000.0,0.6989832520484924,1.3790864944458008,29.73907053500585,3003.0,38670.3262925148,64930.54170131683,38670.3262925148,26255.110206842422,1.6010394096374512,0.0 -110500,0.33636743,1.6544336,,,,,,,,,,,,,,,,, -110600,0.33851743,1.5395514,,,,,,,,,,,,,,,,, -110700,0.33003685,1.6255606,,,,,,,,,,,,,,,,, -110800,0.3397345,1.5743109,,,,,,,,,,,,,,,,, -110900,0.30807072,1.6013905,,,,,,,,,,,,,,,,, -111000,0.34793675,1.6343281,,,,,,,,,,,,,,,,, -111100,0.335741,1.4958745,,,,,,,,,,,,,,,,, -111200,0.33984384,1.4661244,,,,,,,,,,,,,,,,, -111300,0.35361114,1.5998986,,,,,,,,,,,,,,,,, -111400,0.3300453,1.5472544,,,,,,,,,,,,,,,,, -111500,0.32740003,1.5841224,,,,,,,,,,,,,,,,, -111600,0.35211718,1.5396196,,,,,,,,,,,,,,,,, -111700,0.32923013,1.5008795,,,,,,,,,,,,,,,,, -111800,0.3230648,1.5704494,,,,,,,,,,,,,,,,, -111900,0.3505411,1.5224077,,,,,,,,,,,,,,,,, -112000,0.31902134,1.5186002,,,,,,,,,,,,,,,,, -112100,0.32991368,1.5460694,,,,,,,,,,,,,,,,, -112200,0.32666284,1.5331358,,,,,,,,,,,,,,,,, -112300,0.31896386,1.5746827,,,,,,,,,,,,,,,,, -112400,0.3365466,1.53655,,,,,,,,,,,,,,,,, -112500,0.3471737,1.6270171,,,,,,,,,,,,,,,,, -112600,0.38180414,1.6241051,,,,,,,,,,,,,,,,, -112700,0.3287512,1.5390843,,,,,,,,,,,,,,,,, -112800,0.32762796,1.4796177,,,,,,,,,,,,,,,,, -112873,,,0.6936594247817993,1.431340217590332,35.508655574660985,0.685050368309021,1.4566941261291504,30.06185313900829,3000.0,0.6995874643325806,1.3684998750686646,30.10767739633273,3003.0,39510.36837506294,66343.16135883331,39510.36837506294,26827.57296895981,1.6411964893341064,0.0 -112900,0.34693307,1.5528413,,,,,,,,,,,,,,,,, -113000,0.3340348,1.5150753,,,,,,,,,,,,,,,,, -113100,0.36750084,1.5845027,,,,,,,,,,,,,,,,, -113200,0.33996162,1.5054011,,,,,,,,,,,,,,,,, -113300,0.35834286,1.550688,,,,,,,,,,,,,,,,, -113400,0.3689207,1.5051115,,,,,,,,,,,,,,,,, -113500,0.3441974,1.5968301,,,,,,,,,,,,,,,,, -113600,0.3478649,1.5312078,,,,,,,,,,,,,,,,, -113700,0.33696723,1.5326185,,,,,,,,,,,,,,,,, -113800,0.32658973,1.5736378,,,,,,,,,,,,,,,,, -113900,0.34913054,1.6148697,,,,,,,,,,,,,,,,, -114000,0.33829603,1.4608727,,,,,,,,,,,,,,,,, -114100,0.33761185,1.5739189,,,,,,,,,,,,,,,,, -114200,0.34374025,1.5340302,,,,,,,,,,,,,,,,, -114300,0.36480808,1.5663395,,,,,,,,,,,,,,,,, -114400,0.36174393,1.506775,,,,,,,,,,,,,,,,, -114500,0.34948266,1.495893,,,,,,,,,,,,,,,,, -114600,0.35234904,1.5196408,,,,,,,,,,,,,,,,, -114700,0.34132054,1.5382333,,,,,,,,,,,,,,,,, -114800,0.35338968,1.5372049,,,,,,,,,,,,,,,,, -114900,0.31909284,1.4798,,,,,,,,,,,,,,,,, -115000,0.32888374,1.4983588,,,,,,,,,,,,,,,,, -115100,0.3257627,1.4794297,,,,,,,,,,,,,,,,, -115200,0.3482012,1.4719652,,,,,,,,,,,,,,,,, -115275,,,0.6838340163230896,1.4911880493164062,34.26181205626829,0.6880261898040771,1.4454679489135742,30.37748177908537,3000.0,0.7013421654701233,1.357330322265625,30.101626542932724,3003.0,40350.29947733879,67821.58649492264,40350.29947733879,27465.951172590256,1.6827702522277832,0.0 -115300,0.3605469,1.5892085,,,,,,,,,,,,,,,,, -115400,0.3575065,1.5384123,,,,,,,,,,,,,,,,, -115500,0.36271968,1.5288482,,,,,,,,,,,,,,,,, -115600,0.34377316,1.5334979,,,,,,,,,,,,,,,,, -115700,0.3437543,1.5035263,,,,,,,,,,,,,,,,, -115800,0.33425003,1.4872159,,,,,,,,,,,,,,,,, -115900,0.36028203,1.5874203,,,,,,,,,,,,,,,,, -116000,0.33127472,1.4399122,,,,,,,,,,,,,,,,, -116100,0.3559902,1.4697896,,,,,,,,,,,,,,,,, -116200,0.3496144,1.5377066,,,,,,,,,,,,,,,,, -116300,0.35350984,1.5114709,,,,,,,,,,,,,,,,, -116400,0.3555929,1.4497976,,,,,,,,,,,,,,,,, -116500,0.33792156,1.526618,,,,,,,,,,,,,,,,, -116600,0.3674729,1.4203098,,,,,,,,,,,,,,,,, -116700,0.35407388,1.519624,,,,,,,,,,,,,,,,, -116800,0.34768927,1.550977,,,,,,,,,,,,,,,,, -116900,0.35487893,1.5648336,,,,,,,,,,,,,,,,, -117000,0.38090435,1.5716794,,,,,,,,,,,,,,,,, -117100,0.36891407,1.5432508,,,,,,,,,,,,,,,,, -117200,0.35210025,1.5313115,,,,,,,,,,,,,,,,, -117300,0.3571683,1.4598545,,,,,,,,,,,,,,,,, -117400,0.36449733,1.5784874,,,,,,,,,,,,,,,,, -117500,0.34500697,1.4421191,,,,,,,,,,,,,,,,, -117600,0.3449158,1.464603,,,,,,,,,,,,,,,,, -117677,,,0.6806775331497192,1.5084309577941897,34.7215799842478,0.6893280744552612,1.4370354413986206,30.50947906381349,3000.0,0.7044913172721863,1.3434144258499146,30.434828744976627,3003.0,41190.48154783249,69216.60397958755,41190.48154783249,28020.668506383896,1.7252159118652344,0.0 -117700,0.34380725,1.5305481,,,,,,,,,,,,,,,,, -117800,0.37078905,1.508099,,,,,,,,,,,,,,,,, -117900,0.36370677,1.4954963,,,,,,,,,,,,,,,,, -118000,0.37392485,1.4251976,,,,,,,,,,,,,,,,, -118100,0.38106066,1.5839387,,,,,,,,,,,,,,,,, -118200,0.37216663,1.5034252,,,,,,,,,,,,,,,,, -118300,0.37103447,1.5239816,,,,,,,,,,,,,,,,, -118400,0.38041866,1.5020701,,,,,,,,,,,,,,,,, -118500,0.35369098,1.5128388,,,,,,,,,,,,,,,,, -118600,0.35767233,1.4750143,,,,,,,,,,,,,,,,, -118700,0.35473886,1.5677352,,,,,,,,,,,,,,,,, -118800,0.37096387,1.4908746,,,,,,,,,,,,,,,,, -118900,0.35181847,1.4661599,,,,,,,,,,,,,,,,, -119000,0.3855484,1.5572973,,,,,,,,,,,,,,,,, -119100,0.3507259,1.5349681,,,,,,,,,,,,,,,,, -119200,0.3873678,1.524081,,,,,,,,,,,,,,,,, -119300,0.3635311,1.4777881,,,,,,,,,,,,,,,,, -119400,0.3469042,1.4529263,,,,,,,,,,,,,,,,, -119500,0.38514045,1.432828,,,,,,,,,,,,,,,,, -119600,0.35327047,1.4589629,,,,,,,,,,,,,,,,, -119700,0.35917243,1.4725212,,,,,,,,,,,,,,,,, -119800,0.37402475,1.430017,,,,,,,,,,,,,,,,, -119900,0.43620124,1.514366,,,,,,,,,,,,,,,,, -120000,0.35509974,1.474806,,,,,,,,,,,,,,,,, -120079,,,0.6928339004516602,1.4373910427093506,35.37524108993434,0.690034806728363,1.4303141832351685,30.701788225822256,3000.0,0.7062111496925354,1.3397371768951416,30.584510235579657,3003.0,42030.51136517525,70611.56590008736,42030.51136517525,28575.47600364685,1.7752346992492676,0.0 -120100,0.38373938,1.5063146,,,,,,,,,,,,,,,,, -120200,0.36444455,1.4973785,,,,,,,,,,,,,,,,, -120300,0.3797098,1.5094829,,,,,,,,,,,,,,,,, -120400,0.36582544,1.4282911,,,,,,,,,,,,,,,,, -120500,0.39529616,1.5234523,,,,,,,,,,,,,,,,, -120600,0.35626096,1.5349057,,,,,,,,,,,,,,,,, -120700,0.37678528,1.4360441,,,,,,,,,,,,,,,,, -120800,0.34767488,1.4754934,,,,,,,,,,,,,,,,, -120900,0.35671893,1.4651333,,,,,,,,,,,,,,,,, -121000,0.36050057,1.4856503,,,,,,,,,,,,,,,,, -121100,0.38569307,1.5458928,,,,,,,,,,,,,,,,, -121200,0.3695576,1.5104651,,,,,,,,,,,,,,,,, -121300,0.37301067,1.502848,,,,,,,,,,,,,,,,, -121400,0.36597404,1.4631487,,,,,,,,,,,,,,,,, -121500,0.36812454,1.446039,,,,,,,,,,,,,,,,, -121600,0.37972313,1.4675418,,,,,,,,,,,,,,,,, -121700,0.37622237,1.4653779,,,,,,,,,,,,,,,,, -121800,0.3681108,1.4985646,,,,,,,,,,,,,,,,, -121900,0.36833048,1.5239269,,,,,,,,,,,,,,,,, -122000,0.39554983,1.4599355,,,,,,,,,,,,,,,,, -122100,0.34839827,1.3969126,,,,,,,,,,,,,,,,, -122200,0.36155418,1.491281,,,,,,,,,,,,,,,,, -122300,0.3725847,1.4216872,,,,,,,,,,,,,,,,, -122400,0.37815,1.4913982,,,,,,,,,,,,,,,,, -122481,,,0.688599705696106,1.4655152559280396,35.63986574927191,0.6912251710891724,1.4246001243591309,30.711851471746595,3000.0,0.7073150873184204,1.328243374824524,30.721067569448703,3003.0,42870.60580945015,71997.75002217293,42870.60580945015,29121.44729280472,1.8173811435699463,0.0 -122500,0.38828483,1.4404918,,,,,,,,,,,,,,,,, -122600,0.38083184,1.4983706,,,,,,,,,,,,,,,,, -122700,0.36972615,1.4011667,,,,,,,,,,,,,,,,, -122800,0.37982664,1.5071201,,,,,,,,,,,,,,,,, -122900,0.38276446,1.4741805,,,,,,,,,,,,,,,,, -123000,0.41222516,1.510311,,,,,,,,,,,,,,,,, -123100,0.37473267,1.4061834,,,,,,,,,,,,,,,,, -123200,0.4073046,1.5161488,,,,,,,,,,,,,,,,, -123300,0.38594463,1.4275326,,,,,,,,,,,,,,,,, -123400,0.373172,1.4616578,,,,,,,,,,,,,,,,, -123500,0.3688234,1.4769993,,,,,,,,,,,,,,,,, -123600,0.3800431,1.4842124,,,,,,,,,,,,,,,,, -123700,0.3990994,1.5055816,,,,,,,,,,,,,,,,, -123800,0.3723604,1.392455,,,,,,,,,,,,,,,,, -123900,0.3651024,1.4561191,,,,,,,,,,,,,,,,, -124000,0.3827491,1.4576812,,,,,,,,,,,,,,,,, -124100,0.40236938,1.4745063,,,,,,,,,,,,,,,,, -124200,0.39025885,1.3880708,,,,,,,,,,,,,,,,, -124300,0.37832698,1.4939663,,,,,,,,,,,,,,,,, -124400,0.36348656,1.4401859,,,,,,,,,,,,,,,,, -124500,0.37203115,1.4464597,,,,,,,,,,,,,,,,, -124600,0.37203598,1.4097863,,,,,,,,,,,,,,,,, -124700,0.3968283,1.4049757,,,,,,,,,,,,,,,,, -124800,0.40055463,1.4979696,,,,,,,,,,,,,,,,, -124883,,,0.6913896799087524,1.4493789672851562,35.22434842341379,0.6933205723762512,1.415845274925232,30.862479457424147,3000.0,0.7088490128517151,1.3179659843444824,30.963558849323057,3003.0,43710.570806741714,73377.59749126434,43710.570806741714,29661.2113044262,1.8602821826934808,0.0 -124900,0.37601146,1.4480772,,,,,,,,,,,,,,,,, -125000,0.3792984,1.4362901,,,,,,,,,,,,,,,,, -125100,0.39502385,1.5296812,,,,,,,,,,,,,,,,, -125200,0.38440478,1.420562,,,,,,,,,,,,,,,,, -125300,0.3859216,1.4395161,,,,,,,,,,,,,,,,, -125400,0.36785656,1.511643,,,,,,,,,,,,,,,,, -125500,0.38360435,1.4866679,,,,,,,,,,,,,,,,, -125600,0.3809767,1.4433001,,,,,,,,,,,,,,,,, -125700,0.39662603,1.4545469,,,,,,,,,,,,,,,,, -125800,0.38957763,1.4609513,,,,,,,,,,,,,,,,, -125900,0.37156242,1.3317741,,,,,,,,,,,,,,,,, -126000,0.38685107,1.4726909,,,,,,,,,,,,,,,,, -126100,0.37633568,1.4602875,,,,,,,,,,,,,,,,, -126200,0.39830703,1.3762665,,,,,,,,,,,,,,,,, -126300,0.3878074,1.4910915,,,,,,,,,,,,,,,,, -126400,0.38718745,1.455365,,,,,,,,,,,,,,,,, -126500,0.37986585,1.3578345,,,,,,,,,,,,,,,,, -126600,0.4016884,1.4649819,,,,,,,,,,,,,,,,, -126700,0.37294418,1.3731643,,,,,,,,,,,,,,,,, -126800,0.37597802,1.4320799,,,,,,,,,,,,,,,,, -126900,0.3903547,1.5029267,,,,,,,,,,,,,,,,, -127000,0.38364175,1.3920063,,,,,,,,,,,,,,,,, -127100,0.3973077,1.4521558,,,,,,,,,,,,,,,,, -127200,0.3909938,1.46788,,,,,,,,,,,,,,,,, -127286,,,0.6950476765632629,1.428101658821106,35.723502797572735,0.6939281225204468,1.4126336574554443,30.916063388296088,3000.0,0.7101272344589233,1.3143149614334106,30.82745926563664,3003.0,44550.7475271225,74749.03038787842,44550.7475271225,30192.34838557244,1.9036564826965328,0.0 -127300,0.37620354,1.3845147,,,,,,,,,,,,,,,,, -127400,0.39286277,1.3962728,,,,,,,,,,,,,,,,, -127500,0.3870032,1.4644266,,,,,,,,,,,,,,,,, -127600,0.38437554,1.4826385,,,,,,,,,,,,,,,,, -127700,0.3802589,1.5061111,,,,,,,,,,,,,,,,, -127800,0.3954217,1.4238517,,,,,,,,,,,,,,,,, -127900,0.3884344,1.4744372,,,,,,,,,,,,,,,,, -128000,0.39134645,1.4168987,,,,,,,,,,,,,,,,, -128100,0.38979518,1.5349568,,,,,,,,,,,,,,,,, -128200,0.3783019,1.455519,,,,,,,,,,,,,,,,, -128300,0.39719778,1.4701385,,,,,,,,,,,,,,,,, -128400,0.3709226,1.3957841,,,,,,,,,,,,,,,,, -128500,0.3710729,1.3826504,,,,,,,,,,,,,,,,, -128600,0.37572727,1.4284245,,,,,,,,,,,,,,,,, -128700,0.374121,1.4375676,,,,,,,,,,,,,,,,, -128800,0.38427964,1.4559214,,,,,,,,,,,,,,,,, -128900,0.38070163,1.43696,,,,,,,,,,,,,,,,, -129000,0.39940736,1.4253094,,,,,,,,,,,,,,,,, -129100,0.3837274,1.49771,,,,,,,,,,,,,,,,, -129200,0.39099777,1.4732455,,,,,,,,,,,,,,,,, -129300,0.38489056,1.5004334,,,,,,,,,,,,,,,,, -129400,0.39216274,1.4849037,,,,,,,,,,,,,,,,, -129500,0.37829965,1.4040459,,,,,,,,,,,,,,,,, -129600,0.39058086,1.4810082,,,,,,,,,,,,,,,,, -129688,,,0.6956238150596619,1.4236267805099487,35.90949567414748,0.6945853233337402,1.410330057144165,30.926567537057483,3000.0,0.7101272344589233,1.311076045036316,31.011806972890582,3003.0,45390.82706856728,76096.52921843529,45390.82706856728,30699.643965244293,1.9506874084472656,0.0 -129700,0.4081833,1.5038632,,,,,,,,,,,,,,,,, -129800,0.37433672,1.4281813,,,,,,,,,,,,,,,,, -129900,0.38163364,1.4103824,,,,,,,,,,,,,,,,, -130000,0.373322,1.4323671,,,,,,,,,,,,,,,,, -130100,0.3989731,1.4841478,,,,,,,,,,,,,,,,, -130200,0.3769825,1.4472798,,,,,,,,,,,,,,,,, -130300,0.38435575,1.411662,,,,,,,,,,,,,,,,, -130400,0.38014123,1.3605084,,,,,,,,,,,,,,,,, -130500,0.38099205,1.4235259,,,,,,,,,,,,,,,,, -130600,0.37431347,1.3596225,,,,,,,,,,,,,,,,, -130700,0.372599,1.443989,,,,,,,,,,,,,,,,, -130800,0.3899614,1.4748902,,,,,,,,,,,,,,,,, -130900,0.37759578,1.3402104,,,,,,,,,,,,,,,,, -131000,0.38112682,1.3510737,,,,,,,,,,,,,,,,, -131100,0.36861125,1.3679442,,,,,,,,,,,,,,,,, -131200,0.39679724,1.4276932,,,,,,,,,,,,,,,,, -131300,0.36297873,1.4330238,,,,,,,,,,,,,,,,, -131400,0.362013,1.4562855,,,,,,,,,,,,,,,,, -131500,0.3860138,1.4535447,,,,,,,,,,,,,,,,, -131600,0.37421802,1.407863,,,,,,,,,,,,,,,,, -131700,0.38181034,1.4310224,,,,,,,,,,,,,,,,, -131800,0.40068823,1.4758563,,,,,,,,,,,,,,,,, -131900,0.37716115,1.3958571,,,,,,,,,,,,,,,,, -132000,0.37877408,1.3974261,,,,,,,,,,,,,,,,, -132089,,,0.7016637921333313,1.386156439781189,35.79594700241151,0.694386899471283,1.4098578691482544,30.910722132760736,3000.0,0.7106966376304626,1.3105497360229492,30.939606512181268,3003.0,46230.71160531044,77462.6783709526,46230.71160531044,31225.78633785248,1.9961497783660889,0.0 -132100,0.3836858,1.453876,,,,,,,,,,,,,,,,, -132200,0.38823667,1.4553663,,,,,,,,,,,,,,,,, -132300,0.3826069,1.4315059,,,,,,,,,,,,,,,,, -132400,0.40224284,1.4775723,,,,,,,,,,,,,,,,, -132500,0.3848329,1.3715096,,,,,,,,,,,,,,,,, -132600,0.39640155,1.4661884,,,,,,,,,,,,,,,,, -132700,0.39891323,1.4886827,,,,,,,,,,,,,,,,, -132800,0.3807556,1.3801243,,,,,,,,,,,,,,,,, -132900,0.3793223,1.3821796,,,,,,,,,,,,,,,,, -133000,0.37180945,1.4235581,,,,,,,,,,,,,,,,, -133100,0.37284404,1.3943084,,,,,,,,,,,,,,,,, -133200,0.39485297,1.3437575,,,,,,,,,,,,,,,,, -133300,0.38302398,1.5054262,,,,,,,,,,,,,,,,, -133333,,,0.6967169642448425,1.42378568649292,35.835157411181875,0.69430011510849,1.410101294517517,30.869154358846924,3000.0,0.7108826041221619,1.310747146606445,30.902745284398407,3003.0,46665.6048541069,78415.79715752602,46665.6048541069,31743.91850733757,2.04841947555542,0.0 -133333,,,,,,,,,,,,,,46665.6048541069,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/eval_measurements.csv deleted file mode 100644 index f2503dd0a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,59 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -841.7417721748352,0.0,31.995970487594604,1,0,31.995970487594604,0.0007088489946909,0.0,11.191027641296388,3003,873.7377886772156,0.0005809449939988,0.0,11.19240951538086,0.0004835649742744,0.0,11.190281867980955,3000 -1379.5385262966156,0.0206320285797119,872.1930792331696,2343,0,872.1930792331696,0.5100807547569275,17.119625855590364,2.911456823348999,3003,2251.830168247223,0.5090352892875671,22.760860535397587,2.874778985977173,0.509181559085846,18.645719298845435,2.877437829971313,3000 -1839.9097967147827,0.0460519790649414,1712.265777349472,4687,0,1712.265777349472,0.5936436057090759,22.072443389635552,2.1284549236297607,3003,3552.374214172364,0.5762688517570496,27.03187390595071,2.2632884979248047,0.5884490013122559,23.44131993452736,2.1593129634857178,3000 -2280.1172001361847,0.0716917514801025,2552.231223344803,7032,0,2552.231223344803,0.6223810315132141,23.71615737939061,1.891671061515808,3003,4832.6473343372345,0.6044877767562866,29.150223633210214,2.010396242141724,0.6167561411857605,25.14535870776252,1.9335185289382928,3000 -2709.7009241580963,0.0972914695739746,3392.4043912887573,9378,0,3392.4043912887573,0.637569010257721,25.21416047743482,1.7603546380996704,3003,6102.503938674927,0.6121419668197632,29.574920246863407,1.9549943208694456,0.6282377243041992,26.20106139743751,1.8250943422317505,3000 -3187.856215953827,0.1250281333923339,4232.526937961578,11725,0,4232.526937961578,0.6475510001182556,25.765832245270712,1.691665768623352,3003,7420.881967782974,0.6145750284194946,29.91781370647072,1.9365715980529783,0.637115478515625,26.47492510250528,1.762860894203186,3000 -3639.9922440052032,0.1513416767120361,5072.729242563248,14071,0,5072.729242563248,0.6514670848846436,26.01923512162698,1.646348476409912,3003,8713.320725440979,0.6287801861763,30.57407513299764,1.8265225887298584,0.6427942514419556,27.20454930807229,1.715195655822754,3000 -4137.209697246552,0.1788101196289062,5912.799705505371,16415,0,5912.799705505371,0.6555575132369995,25.974975993637944,1.6271188259124756,3003,10050.71216583252,0.6283307075500488,30.64059289155258,1.83135998249054,0.6439225673675537,26.862268504080657,1.6987574100494385,3000 -4697.838307380676,0.2108919620513916,6752.885207414627,18760,0,6752.885207414627,0.6605310440063477,26.84595365712397,1.594834566116333,3003,11451.534358024595,0.6313144564628601,30.3301787042636,1.8021981716156008,0.6500601172447205,27.40282885917218,1.666564702987671,3000 -5348.135931015015,0.2389867305755615,7592.865230083466,21106,0,7592.865230083466,0.6628551483154297,26.80578598178156,1.5742295980453491,3003,12941.91342139244,0.6347932815551758,31.21414025678416,1.7635672092437744,0.6517339944839478,27.60810916818256,1.6439547538757324,3000 -5836.484181642532,0.2676777839660644,8432.95629954338,23452,0,8432.95629954338,0.6636337637901306,27.270381340846026,1.5664596557617188,3003,14270.456376552582,0.6305269598960876,30.61069211527252,1.790299415588379,0.6529738903045654,27.64664612249136,1.6360857486724854,3000 -6356.156872987747,0.2959790229797363,9273.077693939207,25798,0,9273.077693939207,0.6658416390419006,27.35816931986309,1.5436705350875854,3003,15630.35198712349,0.642456591129303,31.55244803474688,1.7020819187164309,0.6566688418388367,28.153017370131508,1.6203370094299316,3000 -6800.711962461472,0.3239438533782959,10113.140311479568,28144,0,10113.140311479568,0.6691185832023621,27.576740566813548,1.5272434949874878,3003,16915.07120656967,0.6392359733581543,31.22346015081715,1.7401916980743408,0.6564580798149109,27.81124516471783,1.6131483316421509,3000 -7284.199323177338,0.3592972755432129,10953.335748672484,30490,0,10953.335748672484,0.6619836091995239,26.57815801463406,1.5729281902313232,3003,18238.866058826447,0.6276158094406128,29.77989999011368,1.8166896104812624,0.6502957344055176,26.92453743681419,1.6435110569000244,3000 -7821.921438455582,0.3904638290405273,11793.542217969894,32836,0,11793.542217969894,0.6721515655517578,27.20909834912664,1.512355923652649,3003,19616.90053844452,0.6437987089157104,31.70166116605699,1.7077805995941162,0.6607481837272644,28.223498811079704,1.5932337045669556,3000 -8442.923906326294,0.4212970733642578,12633.712213754654,35182,0,12633.712213754654,0.6742432117462158,27.942357040831503,1.504552960395813,3003,21078.17825651169,0.644922137260437,31.362425620829725,1.709308624267578,0.6612069010734558,28.02851951476708,1.584161639213562,3000 -8934.070060491562,0.4525520801544189,13473.850688695908,37528,0,13473.850688695908,0.6755679845809937,27.54270861029842,1.4972996711730957,3003,22409.57164978981,0.6506013870239258,32.2079672733254,1.6590782403945925,0.6631659865379333,28.21083540443361,1.5785056352615356,3000 -9511.859144210815,0.4880993366241455,14313.951741695404,39873,0,14313.951741695404,0.6739526987075806,27.546707690715404,1.494259476661682,3003,23827.573871850967,0.6454939842224121,31.815291168303688,1.692597508430481,0.6619632840156555,28.208675423715423,1.576145052909851,3000 -10027.034233808516,0.5258986949920654,15154.16379904747,42219,0,15154.16379904747,0.6774853467941284,27.98596060024852,1.4813222885131836,3003,25183.07438015937,0.6474224925041199,31.77898337621006,1.68617844581604,0.6637239456176758,28.69145620875173,1.564249873161316,3000 -10544.088871002195,0.5563359260559082,15994.344855546951,44564,0,15994.344855546951,0.6764511466026306,27.830552615538394,1.4741398096084597,3003,26540.41818094253,0.6538242101669312,32.21654232899015,1.6220608949661257,0.6659061908721924,28.64157721411009,1.5523436069488523,3000 -11064.661159992218,0.5887689590454102,16834.357803106308,46910,0,16834.357803106308,0.6811225414276123,28.346774455501244,1.4597563743591309,3003,27901.11219573021,0.6494661569595337,31.6695985126084,1.669074296951294,0.6667740941047668,28.77795598635,1.5472114086151123,3000 -11535.024050235748,0.6217913627624512,17674.36550784111,49256,0,17674.36550784111,0.6809831261634827,28.45048140232037,1.4532852172851562,3003,29211.590218305588,0.6462664604187012,31.678658230344087,1.6856175661087036,0.6664517521858215,28.842956598526044,1.5448410511016846,3000 -12118.698055744171,0.6539661884307861,18514.27074623108,51601,0,18514.27074623108,0.6818430423736572,28.224619330030805,1.4524085521697998,3003,30635.27508306504,0.6567646265029907,32.20008454597276,1.613392949104309,0.668398380279541,28.79298934338836,1.533761501312256,3000 -12716.516110658646,0.6878187656402588,19354.333768606182,53947,0,19354.333768606182,0.6840276718139648,28.765265595968746,1.4374014139175415,3003,32073.26548433304,0.6505134701728821,31.867379464977144,1.660967469215393,0.6687207818031311,28.76152895014794,1.5278306007385254,3000 -13297.87374830246,0.7198841571807861,20194.51906824112,56293,0,20194.51906824112,0.6826448440551758,28.341602362376708,1.4398152828216553,3003,33494.9159014225,0.6667506098747253,33.09192433508889,1.5589523315429688,0.6687827706336975,28.63368948315072,1.5208386182785034,3000 -13941.65684556961,0.7570688724517822,21034.617770195007,58639,0,21034.617770195007,0.6835163831710815,28.599861865485742,1.4296272993087769,3003,34978.909477710724,0.655954897403717,32.188351151900314,1.618160605430603,0.6726884841918945,29.07218011022572,1.5124878883361816,3000 -14425.952094316484,0.7913601398468018,21874.773897647858,60985,0,21874.773897647858,0.6864215135574341,29.007746720886413,1.420398235321045,3003,36303.46973657608,0.6542706489562988,32.1757717300823,1.6405000686645508,0.6717957258224487,29.062029670361643,1.5104609727859497,3000 -14934.84242773056,0.8252365589141846,22714.68177628517,63330,0,22714.68177628517,0.6854802370071411,28.617106854066893,1.4173305034637451,3003,37652.37822747231,0.6675124168395996,32.92653022817832,1.544390082359314,0.6732712388038635,29.13346058303088,1.505146026611328,3000 -15550.918023347856,0.8605771064758301,23554.65200853348,65676,0,23554.65200853348,0.6907327175140381,29.165284759238112,1.3964228630065918,3003,39108.53287887573,0.6562884449958801,32.68944922577982,1.612204670906067,0.6751435399055481,29.30038312727096,1.491057515144348,3000 -16098.42095041275,0.8969278335571289,24394.624833345413,68022,0,24394.624833345413,0.6890709400177002,29.19693387188205,1.3990836143493652,3003,40496.11887073517,0.6579658389091492,32.35109639577502,1.613708734512329,0.6769165992736816,29.62679347416672,1.4885090589523315,3000 -16816.858570575714,0.9333441257476808,25234.659603357315,70368,0,25234.659603357315,0.6903492212295532,28.78012899563756,1.3912490606307983,3003,42054.70087099075,0.6629226207733154,32.78073814231867,1.5664702653884888,0.6758130788803101,29.19225225701453,1.479118824005127,3000 -17393.892706871033,0.9688091278076172,26074.649913072582,72714,0,26074.649913072582,0.6935099959373474,29.474507063397496,1.380598545074463,3003,43471.833458423615,0.6606534719467163,32.511623246878834,1.593955636024475,0.6776605248451233,29.498651552431973,1.4730331897735596,3000 -17876.356063604355,1.006338119506836,26914.73360776901,75061,0,26914.73360776901,0.6918134093284607,29.14106138373952,1.3813815116882324,3003,44794.491381406784,0.6858612298965454,34.89402750189174,1.4403479099273682,0.677734911441803,29.468598279995035,1.4711121320724487,3000 -18526.739768743515,1.0434327125549316,27754.92539286613,77406,0,27754.92539286613,0.694428026676178,29.284875716363665,1.3683438301086426,3003,46285.17871427536,0.665233850479126,33.34801089317181,1.553189396858215,0.6795451641082764,29.55528459819339,1.4581555128097534,3000 -19097.646564006805,1.0817084312438965,28595.10469722748,79751,0,28595.10469722748,0.6931613683700562,29.269250344436912,1.3743035793304443,3003,47696.377888441086,0.6607846617698669,33.05226990077929,1.5785906314849854,0.6790120601654053,29.44278157741172,1.4651505947113037,3000 -19642.142485380173,1.1191542148590088,29434.99210190773,82097,0,29434.99210190773,0.6955435872077942,29.42719219404512,1.3564544916152954,3003,49080.874106407166,0.6755383014678955,33.659196883772914,1.49465811252594,0.6815910339355469,29.689568520414724,1.449007272720337,3000 -20309.78834581375,1.156053066253662,30274.89132285118,84442,0,30274.89132285118,0.697402834892273,29.5544296637964,1.344928741455078,3003,50588.53385710716,0.6705901026725769,33.776368870019915,1.5163570642471311,0.6826077699661255,29.743246304174004,1.4386560916900637,3000 -20825.03607749939,1.1935102939605713,31114.93718099594,86788,0,31114.93718099594,0.6986927390098572,29.75376935276304,1.338726043701172,3003,51943.9401807785,0.6693463325500488,33.519672285560624,1.5310808420181274,0.682632565498352,30.200129200007808,1.43792986869812,3000 -21425.26732730865,1.2301530838012695,31955.03034901619,89133,0,31955.03034901619,0.6976584792137146,29.52895939744145,1.333077311515808,3003,53384.378826379776,0.6757782101631165,34.190875248660575,1.485595941543579,0.6854719519615173,29.772108184170435,1.4247512817382812,3000 -22052.010696411133,1.276489496231079,32794.96316599846,91479,0,32794.96316599846,0.7033525109291077,30.261509158958816,1.3221434354782104,3003,54851.17655205727,0.6764565110206604,33.79142281553153,1.4868851900100708,0.6856207251548767,30.081463864085272,1.4216893911361694,3000 -22633.36717224121,1.315727710723877,33634.999903678894,93825,0,33634.999903678894,0.7028295993804932,30.10693494571689,1.319978952407837,3003,56272.6827609539,0.690422534942627,34.952902904917686,1.4100230932235718,0.685050368309021,30.129531487347982,1.4175060987472534,3000 -23432.17221236229,1.3558223247528076,34475.117656469345,96171,0,34475.117656469345,0.7031084895133972,30.01868414744477,1.311916708946228,3003,57911.7241795063,0.6796119809150696,34.32243432293985,1.4651366472244265,0.6871210336685181,30.200335772061464,1.4065537452697754,3000 -23966.787534713745,1.394505500793457,35315.19370055199,98517,0,35315.19370055199,0.7058160901069641,30.582308832194634,1.301902413368225,3003,59286.53086400032,0.6749527454376221,33.76268337396561,1.4922711849212646,0.6880509853363037,30.562362559763177,1.4029872417449951,3000 -24508.896410226826,1.4340672492980957,36155.15829825401,100862,0,36155.15829825401,0.7054442167282104,30.24920047046108,1.299425721168518,3003,60668.71771264076,0.6901666522026062,34.54494358756864,1.3958227634429932,0.6898488402366638,30.523842063106216,1.3990224599838257,3000 -25077.136883974075,1.4827287197113037,36995.03522968292,103207,0,36995.03522968292,0.708686351776123,30.83900984318424,1.2871750593185425,3003,62076.9599506855,0.6826054453849792,34.78132076309567,1.445312738418579,0.6900224089622498,30.54686521049929,1.391385555267334,3000 -25654.313775777817,1.522960901260376,37835.16900038719,105552,0,37835.16900038719,0.706187903881073,30.671070000288665,1.2858394384384155,3003,63494.38958978653,0.6831424832344055,34.30571893606282,1.4466559886932373,0.6902456283569336,30.71525097595463,1.3884084224700928,3000 -26250.866422891617,1.5631628036499023,38675.34786558151,107899,0,38675.34786558151,0.7079077363014221,30.42644247752539,1.2825238704681396,3003,64931.23558783531,0.6916837692260742,35.40145042693628,1.393021583557129,0.691299557685852,30.558788910901555,1.383009910583496,3000 -26872.388967752457,1.603606939315796,39515.29098367691,110245,0,39515.29098367691,0.7103015780448914,30.49992268140961,1.2744451761245728,3003,66392.81823420525,0.6882879137992859,35.19981584636805,1.4097886085510254,0.6923038959503174,30.56347426094465,1.3791635036468506,3000 -27474.955493688583,1.654353380203247,40355.34690570831,112591,0,40355.34690570831,0.7099180817604065,30.61117920436823,1.2725780010223389,3003,67835.56606578827,0.6941087245941162,35.41065961993999,1.3747334480285645,0.6923410892486572,30.704859704017625,1.3770735263824463,3000 -28219.04938960076,1.7018797397613523,41195.55699467659,114937,0,41195.55699467659,0.7107896208763123,30.542745113871938,1.2660441398620603,3003,69419.99466729164,0.6922072172164917,35.2817522823305,1.393733263015747,0.693717360496521,30.42187428646859,1.3703649044036863,3000 -28810.52793598175,1.7455673217773438,42035.75880694389,117283,0,42035.75880694389,0.7114868760108948,30.76111965834509,1.2633055448532104,3003,70851.79516196251,0.690304696559906,35.20293102881795,1.3954484462738037,0.69430011510849,30.820244597158045,1.370716571807861,3000 -29512.847000598907,1.788480281829834,42875.923796892166,119629,0,42875.923796892166,0.7119168043136597,30.859965177480586,1.2605042457580566,3003,72394.39634847641,0.6958056688308716,35.7102083223731,1.3685253858566284,0.6936181783676147,30.65897634626175,1.36842942237854,3000 -30084.16514778137,1.832347393035889,43716.03023672104,121974,0,43716.03023672104,0.7125210762023926,30.773637579945547,1.2577069997787476,3003,73805.94238114357,0.6915479302406311,35.36310249116494,1.388578176498413,0.6942381262779236,30.68772665448068,1.3653781414031982,3000 -30701.053248643875,1.876137018203736,44556.06577014923,124318,0,44556.06577014923,0.7131369709968567,30.84595524052872,1.2557777166366575,3003,75262.98744797707,0.6958664655685425,35.70916601186907,1.3725371360778809,0.6946473121643066,30.91054033191716,1.364732027053833,3000 -31320.80184316635,1.9225962162017824,45396.00702667236,126663,0,45396.00702667236,0.7128929495811462,30.87501995821473,1.2559876441955566,3003,76722.7999753952,0.6963878273963928,35.3685362955417,1.3677339553833008,0.6950316429138184,30.77219468305123,1.3643397092819214,3000 -31981.52465200424,1.9672255516052248,46236.22338271141,129010,0,46236.22338271141,0.7133809924125671,30.837718893970138,1.2536063194274902,3003,78223.86001873016,0.6947799324989319,35.44969392992167,1.3747010231018066,0.6953044533729553,30.786053998723386,1.3621292114257812,3000 -32636.82600307465,2.012143135070801,47076.33667135239,131357,0,47076.33667135239,0.7132763862609863,30.857126776061257,1.2537087202072144,3003,79719.39379668236,0.6980651617050171,35.745533977811,1.3623944520950315,0.6951928734779358,30.816823338684078,1.3623720407485962,3000 -33270.594187021255,2.060373544692993,47783.83299946785,133333,0,47783.83299946785,0.7132763862609863,30.852166892997438,1.253868579864502,3003,81060.76957821846,0.6977786421775818,35.85408950071874,1.358901858329773,0.6951680779457092,30.784543671591038,1.3623775243759155,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/measurements.csv deleted file mode 100644 index b97959b60..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1394 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.7888966,11.135162,,,,,,,,,,,,,,,,, -1,,,0.0005809449939988,11.19240951538086,0.0,0.0004835649742744,11.190281867980955,0.0,3000.0,0.0007088489946909,11.191027641296388,0.0,3003.0,31.995970487594604,873.7377886772156,31.995970487594604,841.7417721748352,0.0,0.0 -100,0.18685812,8.258305,,,,,,,,,,,,,,,,, -200,0.96242,7.5216694,,,,,,,,,,,,,,,,, -300,0.37346396,6.8933144,,,,,,,,,,,,,,,,, -400,0.5321248,6.3386793,,,,,,,,,,,,,,,,, -500,0.38655487,5.9003644,,,,,,,,,,,,,,,,, -600,0.5496124,5.615564,,,,,,,,,,,,,,,,, -700,0.62727594,5.336079,,,,,,,,,,,,,,,,, -800,0.52702785,5.06124,,,,,,,,,,,,,,,,, -900,0.4881948,4.889184,,,,,,,,,,,,,,,,, -1000,0.74749297,4.609756,,,,,,,,,,,,,,,,, -1100,0.68140274,4.337661,,,,,,,,,,,,,,,,, -1200,0.6069384,4.1046906,,,,,,,,,,,,,,,,, -1300,0.52492005,3.9454908,,,,,,,,,,,,,,,,, -1400,0.46002233,3.7745183,,,,,,,,,,,,,,,,, -1500,0.52612484,3.6036823,,,,,,,,,,,,,,,,, -1600,0.51676446,3.5289445,,,,,,,,,,,,,,,,, -1700,0.52693444,3.4295611,,,,,,,,,,,,,,,,, -1800,0.5184435,3.2781405,,,,,,,,,,,,,,,,, -1900,0.5212904,3.2703674,,,,,,,,,,,,,,,,, -2000,0.44549286,3.2209134,,,,,,,,,,,,,,,,, -2100,0.45205802,3.0577016,,,,,,,,,,,,,,,,, -2200,0.38764402,3.0401943,,,,,,,,,,,,,,,,, -2300,0.42471823,3.014045,,,,,,,,,,,,,,,,, -2343,,,0.5090352892875671,2.874778985977173,22.760860535397587,0.509181559085846,2.877437829971313,18.645719298845435,3000.0,0.5100807547569275,2.911456823348999,17.119625855590364,3003.0,872.1930792331696,2251.830168247223,872.1930792331696,1379.5385262966156,0.0206320285797119,0.0 -2400,0.38534495,2.9670136,,,,,,,,,,,,,,,,, -2500,0.30748245,2.955076,,,,,,,,,,,,,,,,, -2600,0.3189297,2.8719404,,,,,,,,,,,,,,,,, -2700,0.29831645,2.8704045,,,,,,,,,,,,,,,,, -2800,0.2687663,2.6922119,,,,,,,,,,,,,,,,, -2900,0.2936276,2.8190198,,,,,,,,,,,,,,,,, -3000,0.27218738,2.7449796,,,,,,,,,,,,,,,,, -3100,0.24613503,2.6449745,,,,,,,,,,,,,,,,, -3200,0.26772994,2.5420196,,,,,,,,,,,,,,,,, -3300,0.31884214,2.5197976,,,,,,,,,,,,,,,,, -3400,0.2794606,2.5824146,,,,,,,,,,,,,,,,, -3500,0.20613354,2.4527884,,,,,,,,,,,,,,,,, -3600,0.31341657,2.4917493,,,,,,,,,,,,,,,,, -3700,0.22207393,2.3891866,,,,,,,,,,,,,,,,, -3800,0.27629593,2.4627666,,,,,,,,,,,,,,,,, -3900,0.2222386,2.5314398,,,,,,,,,,,,,,,,, -4000,0.20482816,2.4508045,,,,,,,,,,,,,,,,, -4100,0.20348887,2.416131,,,,,,,,,,,,,,,,, -4200,0.18092579,2.442455,,,,,,,,,,,,,,,,, -4300,0.18419796,2.4227068,,,,,,,,,,,,,,,,, -4400,0.1812954,2.4428048,,,,,,,,,,,,,,,,, -4500,0.159813,2.3343306,,,,,,,,,,,,,,,,, -4600,0.1606258,2.4408965,,,,,,,,,,,,,,,,, -4687,,,0.5762688517570496,2.2632884979248047,27.03187390595071,0.5884490013122559,2.1593129634857178,23.44131993452736,3000.0,0.5936436057090759,2.1284549236297607,22.072443389635552,3003.0,1712.265777349472,3552.374214172364,1712.265777349472,1839.9097967147827,0.0460519790649414,0.0 -4700,0.17878298,2.3124852,,,,,,,,,,,,,,,,, -4800,0.20168667,2.2772775,,,,,,,,,,,,,,,,, -4900,0.21811934,2.3324416,,,,,,,,,,,,,,,,, -5000,0.16706789,2.2739477,,,,,,,,,,,,,,,,, -5100,0.15344226,2.2805712,,,,,,,,,,,,,,,,, -5200,0.17500062,2.2611787,,,,,,,,,,,,,,,,, -5300,0.16838075,2.2458963,,,,,,,,,,,,,,,,, -5400,0.15672554,2.3171299,,,,,,,,,,,,,,,,, -5500,0.19391814,2.2544098,,,,,,,,,,,,,,,,, -5600,0.21046458,2.3024116,,,,,,,,,,,,,,,,, -5700,0.1839806,2.200963,,,,,,,,,,,,,,,,, -5800,0.16498119,2.1547043,,,,,,,,,,,,,,,,, -5900,0.1686302,2.274933,,,,,,,,,,,,,,,,, -6000,0.15669721,2.106619,,,,,,,,,,,,,,,,, -6100,0.18010496,2.169141,,,,,,,,,,,,,,,,, -6200,0.17973836,2.2718005,,,,,,,,,,,,,,,,, -6300,0.16840094,2.1866355,,,,,,,,,,,,,,,,, -6400,0.14328931,2.221851,,,,,,,,,,,,,,,,, -6500,0.15303056,2.2233531,,,,,,,,,,,,,,,,, -6600,0.17749336,2.2111084,,,,,,,,,,,,,,,,, -6700,0.18292576,2.2364404,,,,,,,,,,,,,,,,, -6800,0.1884104,2.1450524,,,,,,,,,,,,,,,,, -6900,0.1521944,2.1917436,,,,,,,,,,,,,,,,, -7000,0.13756655,2.1504664,,,,,,,,,,,,,,,,, -7032,,,0.6044877767562866,2.010396242141724,29.150223633210214,0.6167561411857605,1.9335185289382928,25.14535870776252,3000.0,0.6223810315132141,1.891671061515808,23.71615737939061,3003.0,2552.231223344803,4832.6473343372345,2552.231223344803,2280.1172001361847,0.0716917514801025,0.0 -7100,0.15741427,2.1359131,,,,,,,,,,,,,,,,, -7200,0.17119983,2.1363087,,,,,,,,,,,,,,,,, -7300,0.15575351,2.1074588,,,,,,,,,,,,,,,,, -7400,0.14584774,2.04824,,,,,,,,,,,,,,,,, -7500,0.17599548,2.119853,,,,,,,,,,,,,,,,, -7600,0.16963542,2.0993292,,,,,,,,,,,,,,,,, -7700,0.15262309,2.095829,,,,,,,,,,,,,,,,, -7800,0.18618609,2.0974147,,,,,,,,,,,,,,,,, -7900,0.16317701,2.0685186,,,,,,,,,,,,,,,,, -8000,0.1550961,2.0399647,,,,,,,,,,,,,,,,, -8100,0.17438753,2.0847538,,,,,,,,,,,,,,,,, -8200,0.15510274,1.9871409,,,,,,,,,,,,,,,,, -8300,0.18711115,2.1581063,,,,,,,,,,,,,,,,, -8400,0.15225714,2.0363278,,,,,,,,,,,,,,,,, -8500,0.18264143,2.048,,,,,,,,,,,,,,,,, -8600,0.15705176,2.07308,,,,,,,,,,,,,,,,, -8700,0.15137643,2.0829723,,,,,,,,,,,,,,,,, -8800,0.16348211,2.0173433,,,,,,,,,,,,,,,,, -8900,0.20630124,2.0726657,,,,,,,,,,,,,,,,, -9000,0.15900132,2.1057343,,,,,,,,,,,,,,,,, -9100,0.27442443,2.0262861,,,,,,,,,,,,,,,,, -9200,0.15003072,2.1361496,,,,,,,,,,,,,,,,, -9300,0.17086332,1.9988544,,,,,,,,,,,,,,,,, -9378,,,0.6121419668197632,1.9549943208694456,29.574920246863407,0.6282377243041992,1.8250943422317505,26.20106139743751,3000.0,0.637569010257721,1.7603546380996704,25.21416047743482,3003.0,3392.4043912887573,6102.503938674927,3392.4043912887573,2709.7009241580963,0.0972914695739746,0.0 -9400,0.15948269,2.034294,,,,,,,,,,,,,,,,, -9500,0.15606415,2.0415466,,,,,,,,,,,,,,,,, -9600,0.19949298,2.0316892,,,,,,,,,,,,,,,,, -9700,0.15628844,1.9503477,,,,,,,,,,,,,,,,, -9800,0.1517764,2.010788,,,,,,,,,,,,,,,,, -9900,0.16330072,1.9813185,,,,,,,,,,,,,,,,, -10000,0.19039865,1.9610829,,,,,,,,,,,,,,,,, -10100,0.16356577,1.9509408,,,,,,,,,,,,,,,,, -10200,0.24637455,2.0077899,,,,,,,,,,,,,,,,, -10300,0.18859163,2.0000997,,,,,,,,,,,,,,,,, -10400,0.18146315,1.9353911,,,,,,,,,,,,,,,,, -10500,0.16627498,2.0649445,,,,,,,,,,,,,,,,, -10600,0.18481876,2.002929,,,,,,,,,,,,,,,,, -10700,0.16961901,2.1234114,,,,,,,,,,,,,,,,, -10800,0.21337038,2.0339208,,,,,,,,,,,,,,,,, -10900,0.2755026,1.9592687,,,,,,,,,,,,,,,,, -11000,0.19262335,2.0565,,,,,,,,,,,,,,,,, -11100,0.16552594,1.945133,,,,,,,,,,,,,,,,, -11200,0.26976967,2.028636,,,,,,,,,,,,,,,,, -11300,0.16098833,1.9432701,,,,,,,,,,,,,,,,, -11400,0.2088008,1.9318589,,,,,,,,,,,,,,,,, -11500,0.20093682,1.9709991,,,,,,,,,,,,,,,,, -11600,0.24372311,1.9679165,,,,,,,,,,,,,,,,, -11700,0.16943625,2.131793,,,,,,,,,,,,,,,,, -11725,,,0.6145750284194946,1.9365715980529783,29.91781370647072,0.637115478515625,1.762860894203186,26.47492510250528,3000.0,0.6475510001182556,1.691665768623352,25.765832245270712,3003.0,4232.526937961578,7420.881967782974,4232.526937961578,3187.856215953827,0.1250281333923339,0.0 -11800,0.19299786,1.9435652,,,,,,,,,,,,,,,,, -11900,0.19177203,1.9948642,,,,,,,,,,,,,,,,, -12000,0.19380726,2.0223718,,,,,,,,,,,,,,,,, -12100,0.32464722,1.8829199,,,,,,,,,,,,,,,,, -12200,0.3261016,1.9506898,,,,,,,,,,,,,,,,, -12300,0.18007453,1.9872506,,,,,,,,,,,,,,,,, -12400,0.1867361,1.9227327,,,,,,,,,,,,,,,,, -12500,0.18372163,1.9609296,,,,,,,,,,,,,,,,, -12600,0.2474267,1.947903,,,,,,,,,,,,,,,,, -12700,0.18589826,1.9663815,,,,,,,,,,,,,,,,, -12800,0.21002194,1.8960721,,,,,,,,,,,,,,,,, -12900,0.22048198,1.9725013,,,,,,,,,,,,,,,,, -13000,0.22932272,2.0149887,,,,,,,,,,,,,,,,, -13100,0.1650691,1.9739808,,,,,,,,,,,,,,,,, -13200,0.18398994,2.0148518,,,,,,,,,,,,,,,,, -13300,0.16636169,1.908425,,,,,,,,,,,,,,,,, -13400,0.18250243,1.9699087,,,,,,,,,,,,,,,,, -13500,0.17617689,1.9234947,,,,,,,,,,,,,,,,, -13600,0.22459672,1.8980935,,,,,,,,,,,,,,,,, -13700,0.20858337,1.939027,,,,,,,,,,,,,,,,, -13800,0.23441917,1.9721227,,,,,,,,,,,,,,,,, -13900,0.21487305,1.9795762,,,,,,,,,,,,,,,,, -14000,0.19251344,1.8942231,,,,,,,,,,,,,,,,, -14071,,,0.6287801861763,1.8265225887298584,30.57407513299764,0.6427942514419556,1.715195655822754,27.20454930807229,3000.0,0.6514670848846436,1.646348476409912,26.01923512162698,3003.0,5072.729242563248,8713.320725440979,5072.729242563248,3639.9922440052032,0.1513416767120361,0.0 -14100,0.2259567,1.9074196,,,,,,,,,,,,,,,,, -14200,0.19489041,1.9385623,,,,,,,,,,,,,,,,, -14300,0.19261551,1.8345453,,,,,,,,,,,,,,,,, -14400,0.18551949,1.9216746,,,,,,,,,,,,,,,,, -14500,0.25239912,1.9042337,,,,,,,,,,,,,,,,, -14600,0.20123573,1.9235452,,,,,,,,,,,,,,,,, -14700,0.20805867,1.8697629,,,,,,,,,,,,,,,,, -14800,0.22569942,1.918207,,,,,,,,,,,,,,,,, -14900,0.18403183,1.9384322,,,,,,,,,,,,,,,,, -15000,0.27031374,1.8427008,,,,,,,,,,,,,,,,, -15100,0.43428195,1.8668869,,,,,,,,,,,,,,,,, -15200,0.25541812,2.0331514,,,,,,,,,,,,,,,,, -15300,0.18441011,1.9454823,,,,,,,,,,,,,,,,, -15400,0.27476567,1.8144937,,,,,,,,,,,,,,,,, -15500,0.18747443,1.9472835,,,,,,,,,,,,,,,,, -15600,0.2178138,1.9471588,,,,,,,,,,,,,,,,, -15700,0.22743402,1.8543656,,,,,,,,,,,,,,,,, -15800,0.22701001,1.9239228,,,,,,,,,,,,,,,,, -15900,0.19915795,1.9929588,,,,,,,,,,,,,,,,, -16000,0.20067991,1.919977,,,,,,,,,,,,,,,,, -16100,0.18112351,1.9777583,,,,,,,,,,,,,,,,, -16200,0.22324155,1.8369496,,,,,,,,,,,,,,,,, -16300,0.18751463,1.8707792,,,,,,,,,,,,,,,,, -16400,0.22801748,1.8889257,,,,,,,,,,,,,,,,, -16415,,,0.6283307075500488,1.83135998249054,30.64059289155258,0.6439225673675537,1.6987574100494385,26.862268504080657,3000.0,0.6555575132369995,1.6271188259124756,25.974975993637944,3003.0,5912.799705505371,10050.71216583252,5912.799705505371,4137.209697246552,0.1788101196289062,0.0 -16500,0.19168392,1.8944546,,,,,,,,,,,,,,,,, -16600,0.18028854,1.8714925,,,,,,,,,,,,,,,,, -16700,0.21673459,1.827632,,,,,,,,,,,,,,,,, -16800,0.32905844,1.9375628,,,,,,,,,,,,,,,,, -16900,0.17095484,1.8538553,,,,,,,,,,,,,,,,, -17000,0.30413932,1.9551129,,,,,,,,,,,,,,,,, -17100,0.18689881,1.920831,,,,,,,,,,,,,,,,, -17200,0.25903174,1.835474,,,,,,,,,,,,,,,,, -17300,0.18089585,1.919566,,,,,,,,,,,,,,,,, -17400,0.22248247,1.8451779,,,,,,,,,,,,,,,,, -17500,0.21495196,1.8623242,,,,,,,,,,,,,,,,, -17600,0.17660442,1.9294347,,,,,,,,,,,,,,,,, -17700,0.1857388,1.8086149,,,,,,,,,,,,,,,,, -17800,0.20398095,1.8966993,,,,,,,,,,,,,,,,, -17900,0.21636428,1.881545,,,,,,,,,,,,,,,,, -18000,0.18332046,1.8806286,,,,,,,,,,,,,,,,, -18100,0.19860761,1.9025455,,,,,,,,,,,,,,,,, -18200,0.32581905,1.881779,,,,,,,,,,,,,,,,, -18300,0.19998701,1.8637553,,,,,,,,,,,,,,,,, -18400,0.23852296,1.8522774,,,,,,,,,,,,,,,,, -18500,0.21204522,1.8945534,,,,,,,,,,,,,,,,, -18600,0.3600626,1.8898067,,,,,,,,,,,,,,,,, -18700,0.23509778,1.8485281,,,,,,,,,,,,,,,,, -18760,,,0.6313144564628601,1.8021981716156008,30.3301787042636,0.6500601172447205,1.666564702987671,27.40282885917218,3000.0,0.6605310440063477,1.594834566116333,26.84595365712397,3003.0,6752.885207414627,11451.534358024595,6752.885207414627,4697.838307380676,0.2108919620513916,0.0 -18800,0.26309657,1.8952817,,,,,,,,,,,,,,,,, -18900,0.16933553,1.792977,,,,,,,,,,,,,,,,, -19000,0.19731203,1.89367,,,,,,,,,,,,,,,,, -19100,0.20182426,1.886162,,,,,,,,,,,,,,,,, -19200,0.18071012,1.8231096,,,,,,,,,,,,,,,,, -19300,0.18525589,1.8299216,,,,,,,,,,,,,,,,, -19400,0.17096913,1.9370931,,,,,,,,,,,,,,,,, -19500,0.18194193,1.9279586,,,,,,,,,,,,,,,,, -19600,0.16404556,1.873807,,,,,,,,,,,,,,,,, -19700,0.19296554,1.974808,,,,,,,,,,,,,,,,, -19800,0.18945457,1.8424404,,,,,,,,,,,,,,,,, -19900,0.2651709,1.8850687,,,,,,,,,,,,,,,,, -20000,0.2382924,1.9226617,,,,,,,,,,,,,,,,, -20100,0.20669115,1.8042585,,,,,,,,,,,,,,,,, -20200,0.17859894,1.883004,,,,,,,,,,,,,,,,, -20300,0.19017923,1.8991706,,,,,,,,,,,,,,,,, -20400,0.18295076,1.969417,,,,,,,,,,,,,,,,, -20500,0.20950745,1.8446984,,,,,,,,,,,,,,,,, -20600,0.20509118,1.8861467,,,,,,,,,,,,,,,,, -20700,0.19239403,1.8261312,,,,,,,,,,,,,,,,, -20800,0.17433311,1.8557794,,,,,,,,,,,,,,,,, -20900,0.19420642,1.7814207,,,,,,,,,,,,,,,,, -21000,0.18902075,1.8581011,,,,,,,,,,,,,,,,, -21100,0.28606868,1.9276793,,,,,,,,,,,,,,,,, -21106,,,0.6347932815551758,1.7635672092437744,31.21414025678416,0.6517339944839478,1.6439547538757324,27.60810916818256,3000.0,0.6628551483154297,1.5742295980453491,26.80578598178156,3003.0,7592.865230083466,12941.91342139244,7592.865230083466,5348.135931015015,0.2389867305755615,0.0 -21200,0.23836564,1.9169735,,,,,,,,,,,,,,,,, -21300,0.3339751,1.8891654,,,,,,,,,,,,,,,,, -21400,0.2091005,1.9094594,,,,,,,,,,,,,,,,, -21500,0.20648941,1.9347696,,,,,,,,,,,,,,,,, -21600,0.23036756,1.7959913,,,,,,,,,,,,,,,,, -21700,0.22284392,1.789917,,,,,,,,,,,,,,,,, -21800,0.2357067,1.8270645,,,,,,,,,,,,,,,,, -21900,0.2357323,1.921053,,,,,,,,,,,,,,,,, -22000,0.3156456,1.8220754,,,,,,,,,,,,,,,,, -22100,0.2289592,1.9504869,,,,,,,,,,,,,,,,, -22200,0.20287976,1.9047054,,,,,,,,,,,,,,,,, -22300,0.1946413,1.7941718,,,,,,,,,,,,,,,,, -22400,0.2060197,1.8091567,,,,,,,,,,,,,,,,, -22500,0.19046071,1.8156615,,,,,,,,,,,,,,,,, -22600,0.22814527,1.8110029,,,,,,,,,,,,,,,,, -22700,0.20869534,1.9390515,,,,,,,,,,,,,,,,, -22800,0.31685546,1.8307841,,,,,,,,,,,,,,,,, -22900,0.18177359,1.8436993,,,,,,,,,,,,,,,,, -23000,0.23401515,1.840375,,,,,,,,,,,,,,,,, -23100,0.22222683,1.8536413,,,,,,,,,,,,,,,,, -23200,0.18510558,1.8739874,,,,,,,,,,,,,,,,, -23300,0.2238487,1.8930633,,,,,,,,,,,,,,,,, -23400,6.3827443,2.4910705,,,,,,,,,,,,,,,,, -23452,,,0.6305269598960876,1.790299415588379,30.61069211527252,0.6529738903045654,1.6360857486724854,27.64664612249136,3000.0,0.6636337637901306,1.5664596557617188,27.270381340846026,3003.0,8432.95629954338,14270.456376552582,8432.95629954338,5836.484181642532,0.2676777839660644,0.0 -23500,0.17045893,1.7641804,,,,,,,,,,,,,,,,, -23600,0.23988204,1.9193158,,,,,,,,,,,,,,,,, -23700,0.21657166,1.8794422,,,,,,,,,,,,,,,,, -23800,0.20052099,1.8281133,,,,,,,,,,,,,,,,, -23900,0.16866231,1.785902,,,,,,,,,,,,,,,,, -24000,0.21708563,1.7933279,,,,,,,,,,,,,,,,, -24100,0.19791766,1.8165357,,,,,,,,,,,,,,,,, -24200,0.19598746,1.8536882,,,,,,,,,,,,,,,,, -24300,0.21639067,1.8194848,,,,,,,,,,,,,,,,, -24400,0.20462598,1.8030485,,,,,,,,,,,,,,,,, -24500,0.2203179,1.7952309,,,,,,,,,,,,,,,,, -24600,0.21170707,1.7590357,,,,,,,,,,,,,,,,, -24700,0.21424776,1.812214,,,,,,,,,,,,,,,,, -24800,0.20844056,1.8054489,,,,,,,,,,,,,,,,, -24900,0.2321738,1.860269,,,,,,,,,,,,,,,,, -25000,0.18752253,1.8269426,,,,,,,,,,,,,,,,, -25100,0.19087537,1.8221353,,,,,,,,,,,,,,,,, -25200,0.22619385,1.7682761,,,,,,,,,,,,,,,,, -25300,0.31541345,1.968707,,,,,,,,,,,,,,,,, -25400,0.21681146,1.8550248,,,,,,,,,,,,,,,,, -25500,0.27070764,1.824899,,,,,,,,,,,,,,,,, -25600,0.17383806,1.8195752,,,,,,,,,,,,,,,,, -25700,0.22176531,1.8360695,,,,,,,,,,,,,,,,, -25798,,,0.642456591129303,1.7020819187164309,31.55244803474688,0.6566688418388367,1.6203370094299316,28.153017370131508,3000.0,0.6658416390419006,1.5436705350875854,27.35816931986309,3003.0,9273.077693939207,15630.35198712349,9273.077693939207,6356.156872987747,0.2959790229797363,0.0 -25800,0.20706162,1.9117371,,,,,,,,,,,,,,,,, -25900,0.24203022,1.8343604,,,,,,,,,,,,,,,,, -26000,0.4100331,1.791838,,,,,,,,,,,,,,,,, -26100,0.17836483,1.8179181,,,,,,,,,,,,,,,,, -26200,0.22213076,1.8291157,,,,,,,,,,,,,,,,, -26300,0.19323055,1.8190931,,,,,,,,,,,,,,,,, -26400,0.18291235,1.8443815,,,,,,,,,,,,,,,,, -26500,0.3473539,1.7880852,,,,,,,,,,,,,,,,, -26600,0.24834861,1.7640507,,,,,,,,,,,,,,,,, -26700,0.21725781,1.8565086,,,,,,,,,,,,,,,,, -26800,0.21435867,1.8730544,,,,,,,,,,,,,,,,, -26900,0.21718098,1.8192893,,,,,,,,,,,,,,,,, -27000,0.22169137,1.8015774,,,,,,,,,,,,,,,,, -27100,0.18295865,1.7836173,,,,,,,,,,,,,,,,, -27200,0.27652103,1.7829406,,,,,,,,,,,,,,,,, -27300,0.21197686,1.7476237,,,,,,,,,,,,,,,,, -27400,0.19959001,1.8090566,,,,,,,,,,,,,,,,, -27500,0.19807385,1.8131373,,,,,,,,,,,,,,,,, -27600,0.22288099,1.8328172,,,,,,,,,,,,,,,,, -27700,0.17498775,1.8157561,,,,,,,,,,,,,,,,, -27800,0.19066246,1.8289697,,,,,,,,,,,,,,,,, -27900,0.21260749,1.7640342,,,,,,,,,,,,,,,,, -28000,0.2154235,1.9628257,,,,,,,,,,,,,,,,, -28100,0.17567438,1.739371,,,,,,,,,,,,,,,,, -28144,,,0.6392359733581543,1.7401916980743408,31.22346015081715,0.6564580798149109,1.6131483316421509,27.81124516471783,3000.0,0.6691185832023621,1.5272434949874878,27.576740566813548,3003.0,10113.140311479568,16915.07120656967,10113.140311479568,6800.711962461472,0.3239438533782959,0.0 -28200,0.18076892,1.7793554,,,,,,,,,,,,,,,,, -28300,0.21029496,1.8100529,,,,,,,,,,,,,,,,, -28400,0.22605468,1.7805591,,,,,,,,,,,,,,,,, -28500,3.174215,1.7447762,,,,,,,,,,,,,,,,, -28600,0.23555714,1.930266,,,,,,,,,,,,,,,,, -28700,0.20933232,1.8221877,,,,,,,,,,,,,,,,, -28800,0.19860414,1.846551,,,,,,,,,,,,,,,,, -28900,0.25466603,1.9100012,,,,,,,,,,,,,,,,, -29000,0.18691987,1.7661314,,,,,,,,,,,,,,,,, -29100,0.1827215,1.783478,,,,,,,,,,,,,,,,, -29200,0.263723,1.8935307,,,,,,,,,,,,,,,,, -29300,0.23065884,1.8233191,,,,,,,,,,,,,,,,, -29400,0.20848933,1.9110045,,,,,,,,,,,,,,,,, -29500,0.21788648,1.7372627,,,,,,,,,,,,,,,,, -29600,0.18156002,1.7840326,,,,,,,,,,,,,,,,, -29700,0.29384616,1.8260708,,,,,,,,,,,,,,,,, -29800,0.19003904,1.8212433,,,,,,,,,,,,,,,,, -29900,0.20062125,1.7856021,,,,,,,,,,,,,,,,, -30000,0.2493965,1.8017997,,,,,,,,,,,,,,,,, -30100,0.19754262,1.8368188,,,,,,,,,,,,,,,,, -30200,0.2444474,1.7743813,,,,,,,,,,,,,,,,, -30300,0.21561329,1.7699265,,,,,,,,,,,,,,,,, -30400,0.19283074,1.8522072,,,,,,,,,,,,,,,,, -30490,,,0.6276158094406128,1.8166896104812624,29.77989999011368,0.6502957344055176,1.6435110569000244,26.92453743681419,3000.0,0.6619836091995239,1.5729281902313232,26.57815801463406,3003.0,10953.335748672484,18238.866058826447,10953.335748672484,7284.199323177338,0.3592972755432129,0.0 -30500,5.5681243,1.9024951,,,,,,,,,,,,,,,,, -30600,0.1762477,1.745751,,,,,,,,,,,,,,,,, -30700,0.20330803,1.8262799,,,,,,,,,,,,,,,,, -30800,0.18275273,1.7535019,,,,,,,,,,,,,,,,, -30900,0.20094572,1.7594544,,,,,,,,,,,,,,,,, -31000,0.23705593,1.7864937,,,,,,,,,,,,,,,,, -31100,0.2525908,1.7903414,,,,,,,,,,,,,,,,, -31200,0.19757006,1.7577715,,,,,,,,,,,,,,,,, -31300,0.19945024,1.835347,,,,,,,,,,,,,,,,, -31400,0.20428573,1.8171418,,,,,,,,,,,,,,,,, -31500,0.28779924,1.7696549,,,,,,,,,,,,,,,,, -31600,0.21218912,1.7576557,,,,,,,,,,,,,,,,, -31700,0.23372601,1.8551126,,,,,,,,,,,,,,,,, -31800,0.6996106,1.845931,,,,,,,,,,,,,,,,, -31900,0.1745987,1.7898285,,,,,,,,,,,,,,,,, -32000,0.2042335,1.8009244,,,,,,,,,,,,,,,,, -32100,0.2672747,1.8215456,,,,,,,,,,,,,,,,, -32200,0.20484968,1.8060519,,,,,,,,,,,,,,,,, -32300,0.18854423,1.7169956,,,,,,,,,,,,,,,,, -32400,0.20492482,1.9315645,,,,,,,,,,,,,,,,, -32500,0.24452825,1.766243,,,,,,,,,,,,,,,,, -32600,0.20817724,1.7570751,,,,,,,,,,,,,,,,, -32700,0.2853433,1.7759101,,,,,,,,,,,,,,,,, -32800,0.21414846,1.8871819,,,,,,,,,,,,,,,,, -32836,,,0.6437987089157104,1.7077805995941162,31.70166116605699,0.6607481837272644,1.5932337045669556,28.223498811079704,3000.0,0.6721515655517578,1.512355923652649,27.20909834912664,3003.0,11793.542217969894,19616.90053844452,11793.542217969894,7821.921438455582,0.3904638290405273,0.0 -32900,0.18479995,1.79666,,,,,,,,,,,,,,,,, -33000,0.19966385,1.812956,,,,,,,,,,,,,,,,, -33100,0.20532452,1.8430041,,,,,,,,,,,,,,,,, -33200,0.19837329,1.908799,,,,,,,,,,,,,,,,, -33300,0.20254552,1.7861573,,,,,,,,,,,,,,,,, -33400,0.25749224,1.8102032,,,,,,,,,,,,,,,,, -33500,0.26306403,1.8874007,,,,,,,,,,,,,,,,, -33600,0.21605009,1.7826391,,,,,,,,,,,,,,,,, -33700,0.19652596,1.7577999,,,,,,,,,,,,,,,,, -33800,0.19982381,1.7869406,,,,,,,,,,,,,,,,, -33900,0.21045688,1.828163,,,,,,,,,,,,,,,,, -34000,0.19987884,1.7865385,,,,,,,,,,,,,,,,, -34100,0.22584057,1.76977,,,,,,,,,,,,,,,,, -34200,0.21825661,1.8805916,,,,,,,,,,,,,,,,, -34300,0.20363195,1.8579702,,,,,,,,,,,,,,,,, -34400,0.20549515,1.7433318,,,,,,,,,,,,,,,,, -34500,0.20960517,1.8462212,,,,,,,,,,,,,,,,, -34600,0.20955914,1.76705,,,,,,,,,,,,,,,,, -34700,0.17861855,1.7584953,,,,,,,,,,,,,,,,, -34800,0.1748026,1.8298705,,,,,,,,,,,,,,,,, -34900,0.29693308,1.8720503,,,,,,,,,,,,,,,,, -35000,0.19638358,1.774249,,,,,,,,,,,,,,,,, -35100,0.22838004,1.8214552,,,,,,,,,,,,,,,,, -35182,,,0.644922137260437,1.709308624267578,31.362425620829725,0.6612069010734558,1.584161639213562,28.02851951476708,3000.0,0.6742432117462158,1.504552960395813,27.942357040831503,3003.0,12633.712213754654,21078.17825651169,12633.712213754654,8442.923906326294,0.4212970733642578,0.0 -35200,0.19877875,1.8223343,,,,,,,,,,,,,,,,, -35300,0.20674354,1.7589338,,,,,,,,,,,,,,,,, -35400,0.19075869,1.7660543,,,,,,,,,,,,,,,,, -35500,0.2086413,1.8752773,,,,,,,,,,,,,,,,, -35600,0.20241134,1.7918756,,,,,,,,,,,,,,,,, -35700,0.20265692,1.8315095,,,,,,,,,,,,,,,,, -35800,0.18528146,1.7253472,,,,,,,,,,,,,,,,, -35900,0.19548748,1.7911658,,,,,,,,,,,,,,,,, -36000,0.18758243,1.7314721,,,,,,,,,,,,,,,,, -36100,0.22477005,1.8338084,,,,,,,,,,,,,,,,, -36200,0.2283963,1.7811291,,,,,,,,,,,,,,,,, -36300,0.18528877,1.7700027,,,,,,,,,,,,,,,,, -36400,0.18171887,1.763109,,,,,,,,,,,,,,,,, -36500,0.2013964,1.8402996,,,,,,,,,,,,,,,,, -36600,0.26814285,1.7714994,,,,,,,,,,,,,,,,, -36700,0.26776657,1.7551368,,,,,,,,,,,,,,,,, -36800,0.21745001,1.7603782,,,,,,,,,,,,,,,,, -36900,0.26868594,1.8481077,,,,,,,,,,,,,,,,, -37000,0.22690384,1.8222505,,,,,,,,,,,,,,,,, -37100,0.2645881,1.685908,,,,,,,,,,,,,,,,, -37200,0.18875156,1.8020962,,,,,,,,,,,,,,,,, -37300,0.18830413,1.7778585,,,,,,,,,,,,,,,,, -37400,0.20083296,1.8148464,,,,,,,,,,,,,,,,, -37500,0.20760328,1.8339949,,,,,,,,,,,,,,,,, -37528,,,0.6506013870239258,1.6590782403945925,32.2079672733254,0.6631659865379333,1.5785056352615356,28.21083540443361,3000.0,0.6755679845809937,1.4972996711730957,27.54270861029842,3003.0,13473.850688695908,22409.57164978981,13473.850688695908,8934.070060491562,0.4525520801544189,0.0 -37600,0.19701831,1.7720242,,,,,,,,,,,,,,,,, -37700,0.21234682,1.7855345,,,,,,,,,,,,,,,,, -37800,0.19926868,1.827204,,,,,,,,,,,,,,,,, -37900,0.1883506,1.8506463,,,,,,,,,,,,,,,,, -38000,0.20681657,1.803083,,,,,,,,,,,,,,,,, -38100,0.1950847,1.8002517,,,,,,,,,,,,,,,,, -38200,0.20846272,1.768793,,,,,,,,,,,,,,,,, -38300,0.19020003,1.7859813,,,,,,,,,,,,,,,,, -38400,0.2426664,1.762155,,,,,,,,,,,,,,,,, -38500,0.20442168,1.7699716,,,,,,,,,,,,,,,,, -38600,0.24533637,1.8402582,,,,,,,,,,,,,,,,, -38700,0.21066166,1.7034053,,,,,,,,,,,,,,,,, -38800,0.23147005,1.7507555,,,,,,,,,,,,,,,,, -38900,0.18041532,1.7361503,,,,,,,,,,,,,,,,, -39000,0.20102099,1.8159641,,,,,,,,,,,,,,,,, -39100,0.20953077,1.7341152,,,,,,,,,,,,,,,,, -39200,0.22536069,1.834832,,,,,,,,,,,,,,,,, -39300,0.19854179,1.7741234,,,,,,,,,,,,,,,,, -39400,0.19237801,1.7960715,,,,,,,,,,,,,,,,, -39500,0.19065367,1.7428921,,,,,,,,,,,,,,,,, -39600,0.19213603,1.771291,,,,,,,,,,,,,,,,, -39700,0.2429054,1.7583648,,,,,,,,,,,,,,,,, -39800,0.19873121,1.8086002,,,,,,,,,,,,,,,,, -39873,,,0.6454939842224121,1.692597508430481,31.815291168303688,0.6619632840156555,1.576145052909851,28.208675423715423,3000.0,0.6739526987075806,1.494259476661682,27.546707690715404,3003.0,14313.951741695404,23827.573871850967,14313.951741695404,9511.859144210815,0.4880993366241455,0.0 -39900,0.7070041,1.7459166,,,,,,,,,,,,,,,,, -40000,0.20769437,1.814438,,,,,,,,,,,,,,,,, -40100,0.21025026,1.7183656,,,,,,,,,,,,,,,,, -40200,0.18001391,1.7274861,,,,,,,,,,,,,,,,, -40300,0.31304365,1.78149,,,,,,,,,,,,,,,,, -40400,0.20909382,1.7711718,,,,,,,,,,,,,,,,, -40500,0.19100606,1.6964699,,,,,,,,,,,,,,,,, -40600,0.20900832,1.8008313,,,,,,,,,,,,,,,,, -40700,0.21194716,1.7663957,,,,,,,,,,,,,,,,, -40800,0.33698422,1.698124,,,,,,,,,,,,,,,,, -40900,0.22047254,1.7949286,,,,,,,,,,,,,,,,, -41000,0.19279675,1.7851318,,,,,,,,,,,,,,,,, -41100,0.18881722,1.8034009,,,,,,,,,,,,,,,,, -41200,0.20473337,1.7928293,,,,,,,,,,,,,,,,, -41300,0.20541938,1.7436736,,,,,,,,,,,,,,,,, -41400,0.2674359,1.6642637,,,,,,,,,,,,,,,,, -41500,0.19854623,1.7270226,,,,,,,,,,,,,,,,, -41600,0.20848256,1.8045132,,,,,,,,,,,,,,,,, -41700,0.22113469,1.7922493,,,,,,,,,,,,,,,,, -41800,0.19534017,1.7284666,,,,,,,,,,,,,,,,, -41900,0.27564687,1.780698,,,,,,,,,,,,,,,,, -42000,0.18859115,1.747492,,,,,,,,,,,,,,,,, -42100,0.20012154,1.7362162,,,,,,,,,,,,,,,,, -42200,0.20398462,1.896294,,,,,,,,,,,,,,,,, -42219,,,0.6474224925041199,1.68617844581604,31.77898337621006,0.6637239456176758,1.564249873161316,28.69145620875173,3000.0,0.6774853467941284,1.4813222885131836,27.98596060024852,3003.0,15154.16379904747,25183.07438015937,15154.16379904747,10027.034233808516,0.5258986949920654,0.0 -42300,0.19051403,1.8014241,,,,,,,,,,,,,,,,, -42400,0.20989494,1.7386587,,,,,,,,,,,,,,,,, -42500,0.19711718,1.7923937,,,,,,,,,,,,,,,,, -42600,0.21676084,1.736239,,,,,,,,,,,,,,,,, -42700,0.20730041,1.8235584,,,,,,,,,,,,,,,,, -42800,0.19443291,1.782323,,,,,,,,,,,,,,,,, -42900,0.20146814,1.7152736,,,,,,,,,,,,,,,,, -43000,0.19936733,1.6662027,,,,,,,,,,,,,,,,, -43100,0.2942023,1.7650461,,,,,,,,,,,,,,,,, -43200,0.20233025,1.737719,,,,,,,,,,,,,,,,, -43300,0.2283864,1.7091507,,,,,,,,,,,,,,,,, -43400,0.19801232,1.767061,,,,,,,,,,,,,,,,, -43500,0.19264267,1.7009982,,,,,,,,,,,,,,,,, -43600,0.19215503,1.7901404,,,,,,,,,,,,,,,,, -43700,0.214196,1.820166,,,,,,,,,,,,,,,,, -43800,0.21613854,1.7007847,,,,,,,,,,,,,,,,, -43900,0.20755063,1.7272736,,,,,,,,,,,,,,,,, -44000,0.20084728,1.7042605,,,,,,,,,,,,,,,,, -44100,0.20490597,1.7775881,,,,,,,,,,,,,,,,, -44200,0.20431563,1.7612734,,,,,,,,,,,,,,,,, -44300,0.21367931,1.654593,,,,,,,,,,,,,,,,, -44400,0.20858835,1.6952144,,,,,,,,,,,,,,,,, -44500,0.18178016,1.7732536,,,,,,,,,,,,,,,,, -44564,,,0.6538242101669312,1.6220608949661257,32.21654232899015,0.6659061908721924,1.5523436069488523,28.64157721411009,3000.0,0.6764511466026306,1.4741398096084597,27.830552615538394,3003.0,15994.344855546951,26540.41818094253,15994.344855546951,10544.088871002195,0.5563359260559082,0.0 -44600,0.19741741,1.75492,,,,,,,,,,,,,,,,, -44700,0.1985329,1.7301571,,,,,,,,,,,,,,,,, -44800,0.21881135,1.8020477,,,,,,,,,,,,,,,,, -44900,0.21190642,1.7016945,,,,,,,,,,,,,,,,, -45000,0.21989934,1.6709588,,,,,,,,,,,,,,,,, -45100,0.19387169,1.6988542,,,,,,,,,,,,,,,,, -45200,0.21402873,1.7542759,,,,,,,,,,,,,,,,, -45300,0.17984243,1.7599623,,,,,,,,,,,,,,,,, -45400,0.1806182,1.7331146,,,,,,,,,,,,,,,,, -45500,0.18596363,1.8239543,,,,,,,,,,,,,,,,, -45600,0.19486865,1.727948,,,,,,,,,,,,,,,,, -45700,0.19846451,1.8331561,,,,,,,,,,,,,,,,, -45800,0.22636126,1.7144921,,,,,,,,,,,,,,,,, -45900,0.23246545,1.6956347,,,,,,,,,,,,,,,,, -46000,0.1898588,1.7211403,,,,,,,,,,,,,,,,, -46100,0.18889643,1.7908279,,,,,,,,,,,,,,,,, -46200,0.19885008,1.7171488,,,,,,,,,,,,,,,,, -46300,0.2557908,1.7414154,,,,,,,,,,,,,,,,, -46400,0.20374617,1.841025,,,,,,,,,,,,,,,,, -46500,0.27864683,1.8322381,,,,,,,,,,,,,,,,, -46600,0.2323756,1.8029763,,,,,,,,,,,,,,,,, -46700,0.20009595,1.8063381,,,,,,,,,,,,,,,,, -46800,0.20646001,1.7727958,,,,,,,,,,,,,,,,, -46900,0.19471413,1.7331661,,,,,,,,,,,,,,,,, -46910,,,0.6494661569595337,1.669074296951294,31.6695985126084,0.6667740941047668,1.5472114086151123,28.77795598635,3000.0,0.6811225414276123,1.4597563743591309,28.346774455501244,3003.0,16834.357803106308,27901.11219573021,16834.357803106308,11064.661159992218,0.5887689590454102,0.0 -47000,0.18364432,1.743984,,,,,,,,,,,,,,,,, -47100,0.21822026,1.699655,,,,,,,,,,,,,,,,, -47200,0.24806747,1.770139,,,,,,,,,,,,,,,,, -47300,0.19636427,1.7085154,,,,,,,,,,,,,,,,, -47400,0.19903426,1.7648561,,,,,,,,,,,,,,,,, -47500,0.20119774,1.7211308,,,,,,,,,,,,,,,,, -47600,0.23944558,1.7561618,,,,,,,,,,,,,,,,, -47700,0.19756468,1.8062766,,,,,,,,,,,,,,,,, -47800,0.21087152,1.7367747,,,,,,,,,,,,,,,,, -47900,0.24434158,1.7356241,,,,,,,,,,,,,,,,, -48000,0.26659983,1.786727,,,,,,,,,,,,,,,,, -48100,0.27428323,1.7428722,,,,,,,,,,,,,,,,, -48200,0.1892437,1.8090808,,,,,,,,,,,,,,,,, -48300,0.2397134,1.7538211,,,,,,,,,,,,,,,,, -48400,0.26183966,1.6988353,,,,,,,,,,,,,,,,, -48500,0.19186594,1.7189766,,,,,,,,,,,,,,,,, -48600,0.19957595,1.8154836,,,,,,,,,,,,,,,,, -48700,0.22315337,1.7746944,,,,,,,,,,,,,,,,, -48800,0.19824918,1.687377,,,,,,,,,,,,,,,,, -48900,0.19685365,1.6789565,,,,,,,,,,,,,,,,, -49000,0.18392594,1.7585793,,,,,,,,,,,,,,,,, -49100,0.24798721,1.7126288,,,,,,,,,,,,,,,,, -49200,0.2019833,1.6643399,,,,,,,,,,,,,,,,, -49256,,,0.6462664604187012,1.6856175661087036,31.678658230344087,0.6664517521858215,1.5448410511016846,28.842956598526044,3000.0,0.6809831261634827,1.4532852172851562,28.45048140232037,3003.0,17674.36550784111,29211.590218305588,17674.36550784111,11535.024050235748,0.6217913627624512,0.0 -49300,0.1836816,1.6734602,,,,,,,,,,,,,,,,, -49400,0.19609287,1.7164161,,,,,,,,,,,,,,,,, -49500,0.20943758,1.8017079,,,,,,,,,,,,,,,,, -49600,0.26799852,1.7281355,,,,,,,,,,,,,,,,, -49700,0.23519796,1.8174388,,,,,,,,,,,,,,,,, -49800,0.20173785,1.6808304,,,,,,,,,,,,,,,,, -49900,0.23209324,1.7698632,,,,,,,,,,,,,,,,, -50000,0.22755112,1.7541535,,,,,,,,,,,,,,,,, -50100,0.22861552,1.8523263,,,,,,,,,,,,,,,,, -50200,0.19493727,1.7403247,,,,,,,,,,,,,,,,, -50300,0.22442436,1.6685697,,,,,,,,,,,,,,,,, -50400,0.19066224,1.7971125,,,,,,,,,,,,,,,,, -50500,0.20128538,1.7780792,,,,,,,,,,,,,,,,, -50600,0.19689558,1.702488,,,,,,,,,,,,,,,,, -50700,0.22529078,1.7143885,,,,,,,,,,,,,,,,, -50800,0.20625712,1.7378389,,,,,,,,,,,,,,,,, -50900,0.24867006,1.7764901,,,,,,,,,,,,,,,,, -51000,0.21541929,1.7288439,,,,,,,,,,,,,,,,, -51100,0.18223806,1.7145962,,,,,,,,,,,,,,,,, -51200,0.22912589,1.7849592,,,,,,,,,,,,,,,,, -51300,0.3618468,1.7474835,,,,,,,,,,,,,,,,, -51400,0.2515323,1.6853342,,,,,,,,,,,,,,,,, -51500,0.19988436,1.7266555,,,,,,,,,,,,,,,,, -51600,0.1904477,1.8125964,,,,,,,,,,,,,,,,, -51601,,,0.6567646265029907,1.613392949104309,32.20008454597276,0.668398380279541,1.533761501312256,28.79298934338836,3000.0,0.6818430423736572,1.4524085521697998,28.224619330030805,3003.0,18514.27074623108,30635.27508306504,18514.27074623108,12118.698055744171,0.6539661884307861,0.0 -51700,0.19027765,1.7248933,,,,,,,,,,,,,,,,, -51800,0.20728323,1.8608805,,,,,,,,,,,,,,,,, -51900,0.21748114,1.6650454,,,,,,,,,,,,,,,,, -52000,0.18828219,1.696093,,,,,,,,,,,,,,,,, -52100,0.19314423,1.8151733,,,,,,,,,,,,,,,,, -52200,1.4686831,1.7270234,,,,,,,,,,,,,,,,, -52300,0.21946101,1.7183499,,,,,,,,,,,,,,,,, -52400,0.20564501,1.7887669,,,,,,,,,,,,,,,,, -52500,0.20508559,1.8190033,,,,,,,,,,,,,,,,, -52600,0.23593242,1.7736719,,,,,,,,,,,,,,,,, -52700,0.40994212,1.7710346,,,,,,,,,,,,,,,,, -52800,0.20273554,1.7811141,,,,,,,,,,,,,,,,, -52900,0.20759727,1.809612,,,,,,,,,,,,,,,,, -53000,0.22257935,1.7883002,,,,,,,,,,,,,,,,, -53100,0.23155266,1.7799939,,,,,,,,,,,,,,,,, -53200,0.20049247,1.7123281,,,,,,,,,,,,,,,,, -53300,0.22995573,1.6893438,,,,,,,,,,,,,,,,, -53400,0.2127266,1.7544296,,,,,,,,,,,,,,,,, -53500,0.2117905,1.7092175,,,,,,,,,,,,,,,,, -53600,0.19592604,1.7657123,,,,,,,,,,,,,,,,, -53700,0.248031,1.7452201,,,,,,,,,,,,,,,,, -53800,0.19147332,1.7702844,,,,,,,,,,,,,,,,, -53900,0.21344607,1.6697849,,,,,,,,,,,,,,,,, -53947,,,0.6505134701728821,1.660967469215393,31.867379464977144,0.6687207818031311,1.5278306007385254,28.76152895014794,3000.0,0.6840276718139648,1.4374014139175415,28.765265595968746,3003.0,19354.333768606182,32073.26548433304,19354.333768606182,12716.516110658646,0.6878187656402588,0.0 -54000,0.19339764,1.6808987,,,,,,,,,,,,,,,,, -54100,0.19333492,1.6455073,,,,,,,,,,,,,,,,, -54200,0.21659659,1.735578,,,,,,,,,,,,,,,,, -54300,0.20477612,1.6600326,,,,,,,,,,,,,,,,, -54400,0.20703858,1.6578405,,,,,,,,,,,,,,,,, -54500,0.19704951,1.7255968,,,,,,,,,,,,,,,,, -54600,0.20112471,1.7614266,,,,,,,,,,,,,,,,, -54700,0.26324415,1.720815,,,,,,,,,,,,,,,,, -54800,0.19343664,1.6636393,,,,,,,,,,,,,,,,, -54900,0.23106357,1.7691197,,,,,,,,,,,,,,,,, -55000,0.20188881,1.7030773,,,,,,,,,,,,,,,,, -55100,0.3721626,1.8521173,,,,,,,,,,,,,,,,, -55200,0.1774189,1.6586928,,,,,,,,,,,,,,,,, -55300,0.18400894,1.7764753,,,,,,,,,,,,,,,,, -55400,0.19019535,1.7279991,,,,,,,,,,,,,,,,, -55500,0.18110539,1.6930563,,,,,,,,,,,,,,,,, -55600,0.20212662,1.7522124,,,,,,,,,,,,,,,,, -55700,0.20424482,1.6677555,,,,,,,,,,,,,,,,, -55800,0.20545022,1.6727662,,,,,,,,,,,,,,,,, -55900,0.20477498,1.8053306,,,,,,,,,,,,,,,,, -56000,0.18938084,1.6784458,,,,,,,,,,,,,,,,, -56100,0.25775185,1.6659547,,,,,,,,,,,,,,,,, -56200,0.19981961,1.7844737,,,,,,,,,,,,,,,,, -56293,,,0.6667506098747253,1.5589523315429688,33.09192433508889,0.6687827706336975,1.5208386182785034,28.63368948315072,3000.0,0.6826448440551758,1.4398152828216553,28.341602362376708,3003.0,20194.51906824112,33494.9159014225,20194.51906824112,13297.87374830246,0.7198841571807861,0.0 -56300,0.25425923,1.7603196,,,,,,,,,,,,,,,,, -56400,0.19906384,1.6985313,,,,,,,,,,,,,,,,, -56500,0.21295628,1.7375036,,,,,,,,,,,,,,,,, -56600,0.2129666,1.6812997,,,,,,,,,,,,,,,,, -56700,0.20825353,1.7835041,,,,,,,,,,,,,,,,, -56800,0.19696067,1.804934,,,,,,,,,,,,,,,,, -56900,0.24828959,1.8408519,,,,,,,,,,,,,,,,, -57000,0.18710314,1.6776344,,,,,,,,,,,,,,,,, -57100,0.19565944,1.6760231,,,,,,,,,,,,,,,,, -57200,0.1912932,1.7251647,,,,,,,,,,,,,,,,, -57300,0.19681072,1.7209276,,,,,,,,,,,,,,,,, -57400,0.2078195,1.6939665,,,,,,,,,,,,,,,,, -57500,0.18848039,1.8293872,,,,,,,,,,,,,,,,, -57600,0.19158578,1.7363129,,,,,,,,,,,,,,,,, -57700,0.19932103,1.7809094,,,,,,,,,,,,,,,,, -57800,0.20549078,1.7708852,,,,,,,,,,,,,,,,, -57900,0.20668447,1.6155442,,,,,,,,,,,,,,,,, -58000,0.19972073,1.671116,,,,,,,,,,,,,,,,, -58100,0.188689,1.6874737,,,,,,,,,,,,,,,,, -58200,0.21678373,1.6184387,,,,,,,,,,,,,,,,, -58300,0.19036691,1.6713543,,,,,,,,,,,,,,,,, -58400,0.22080213,1.6911123,,,,,,,,,,,,,,,,, -58500,0.19529891,1.732369,,,,,,,,,,,,,,,,, -58600,0.18567768,1.7661854,,,,,,,,,,,,,,,,, -58639,,,0.655954897403717,1.618160605430603,32.188351151900314,0.6726884841918945,1.5124878883361816,29.07218011022572,3000.0,0.6835163831710815,1.4296272993087769,28.599861865485742,3003.0,21034.617770195007,34978.909477710724,21034.617770195007,13941.65684556961,0.7570688724517822,0.0 -58700,0.20893873,1.7078501,,,,,,,,,,,,,,,,, -58800,0.18816401,1.6934885,,,,,,,,,,,,,,,,, -58900,0.21172395,1.6186713,,,,,,,,,,,,,,,,, -59000,0.2055356,1.7548127,,,,,,,,,,,,,,,,, -59100,0.20653978,1.7165586,,,,,,,,,,,,,,,,, -59200,0.2265011,1.715002,,,,,,,,,,,,,,,,, -59300,0.23965088,1.7537202,,,,,,,,,,,,,,,,, -59400,0.21260639,1.7024767,,,,,,,,,,,,,,,,, -59500,0.35535043,1.731729,,,,,,,,,,,,,,,,, -59600,0.20369592,1.7168196,,,,,,,,,,,,,,,,, -59700,0.20238107,1.6879959,,,,,,,,,,,,,,,,, -59800,0.18985575,1.6651666,,,,,,,,,,,,,,,,, -59900,0.19565561,1.6614815,,,,,,,,,,,,,,,,, -60000,0.20826653,1.7398663,,,,,,,,,,,,,,,,, -60100,0.20452209,1.719511,,,,,,,,,,,,,,,,, -60200,0.19018954,1.6594476,,,,,,,,,,,,,,,,, -60300,0.20312119,1.7279987,,,,,,,,,,,,,,,,, -60400,0.20033145,1.7419134,,,,,,,,,,,,,,,,, -60500,0.19463196,1.7543554,,,,,,,,,,,,,,,,, -60600,0.19479421,1.7091222,,,,,,,,,,,,,,,,, -60700,0.21117012,1.7241769,,,,,,,,,,,,,,,,, -60800,0.21225077,1.7259178,,,,,,,,,,,,,,,,, -60900,0.20699781,1.7797279,,,,,,,,,,,,,,,,, -60985,,,0.6542706489562988,1.6405000686645508,32.1757717300823,0.6717957258224487,1.5104609727859497,29.062029670361643,3000.0,0.6864215135574341,1.420398235321045,29.007746720886413,3003.0,21874.773897647858,36303.46973657608,21874.773897647858,14425.952094316484,0.7913601398468018,0.0 -61000,0.20335758,1.7213062,,,,,,,,,,,,,,,,, -61100,0.5883748,1.6454862,,,,,,,,,,,,,,,,, -61200,0.23324968,1.698914,,,,,,,,,,,,,,,,, -61300,0.18883853,1.7517372,,,,,,,,,,,,,,,,, -61400,0.20022789,1.7016335,,,,,,,,,,,,,,,,, -61500,0.19389655,1.6990153,,,,,,,,,,,,,,,,, -61600,0.20268925,1.6693107,,,,,,,,,,,,,,,,, -61700,0.21999863,1.6787565,,,,,,,,,,,,,,,,, -61800,0.20190014,1.7487725,,,,,,,,,,,,,,,,, -61900,0.2331536,1.6958808,,,,,,,,,,,,,,,,, -62000,0.23222534,1.759919,,,,,,,,,,,,,,,,, -62100,0.3302368,1.642112,,,,,,,,,,,,,,,,, -62200,0.19910035,1.7231133,,,,,,,,,,,,,,,,, -62300,0.19844022,1.622588,,,,,,,,,,,,,,,,, -62400,0.19211847,1.6397616,,,,,,,,,,,,,,,,, -62500,0.19236855,1.6815218,,,,,,,,,,,,,,,,, -62600,0.18142688,1.678377,,,,,,,,,,,,,,,,, -62700,0.22331046,1.6569908,,,,,,,,,,,,,,,,, -62800,0.2234435,1.7119845,,,,,,,,,,,,,,,,, -62900,0.2105944,1.619251,,,,,,,,,,,,,,,,, -63000,0.23151189,1.7683842,,,,,,,,,,,,,,,,, -63100,0.1986469,1.6526809,,,,,,,,,,,,,,,,, -63200,0.19314043,1.6949041,,,,,,,,,,,,,,,,, -63300,0.2110657,1.7405057,,,,,,,,,,,,,,,,, -63330,,,0.6675124168395996,1.544390082359314,32.92653022817832,0.6732712388038635,1.505146026611328,29.13346058303088,3000.0,0.6854802370071411,1.4173305034637451,28.617106854066893,3003.0,22714.68177628517,37652.37822747231,22714.68177628517,14934.84242773056,0.8252365589141846,0.0 -63400,0.20393622,1.8113303,,,,,,,,,,,,,,,,, -63500,0.2504338,1.6624101,,,,,,,,,,,,,,,,, -63600,0.23182616,1.6480682,,,,,,,,,,,,,,,,, -63700,0.2119964,1.6379212,,,,,,,,,,,,,,,,, -63800,0.24901792,1.8049756,,,,,,,,,,,,,,,,, -63900,0.19008894,1.700013,,,,,,,,,,,,,,,,, -64000,0.26779506,1.71974,,,,,,,,,,,,,,,,, -64100,0.21049097,1.7911949,,,,,,,,,,,,,,,,, -64200,0.31915832,1.7448752,,,,,,,,,,,,,,,,, -64300,0.20430736,1.6417302,,,,,,,,,,,,,,,,, -64400,0.23516156,1.7163297,,,,,,,,,,,,,,,,, -64500,0.22466,1.7053298,,,,,,,,,,,,,,,,, -64600,0.4328278,1.6302121,,,,,,,,,,,,,,,,, -64700,0.29262576,1.6007785,,,,,,,,,,,,,,,,, -64800,0.2017332,1.6946018,,,,,,,,,,,,,,,,, -64900,0.22859478,1.6657826,,,,,,,,,,,,,,,,, -65000,0.1901103,1.7202392,,,,,,,,,,,,,,,,, -65100,0.19387367,1.6850213,,,,,,,,,,,,,,,,, -65200,0.21948712,1.7521348,,,,,,,,,,,,,,,,, -65300,1.7155962,1.6961584,,,,,,,,,,,,,,,,, -65400,0.1980594,1.757016,,,,,,,,,,,,,,,,, -65500,0.19659078,1.6465919,,,,,,,,,,,,,,,,, -65600,0.2087613,1.7165668,,,,,,,,,,,,,,,,, -65676,,,0.6562884449958801,1.612204670906067,32.68944922577982,0.6751435399055481,1.491057515144348,29.30038312727096,3000.0,0.6907327175140381,1.3964228630065918,29.165284759238112,3003.0,23554.65200853348,39108.53287887573,23554.65200853348,15550.918023347856,0.8605771064758301,0.0 -65700,0.21194054,1.6830829,,,,,,,,,,,,,,,,, -65800,0.18686475,1.6707959,,,,,,,,,,,,,,,,, -65900,0.205337,1.6594127,,,,,,,,,,,,,,,,, -66000,0.1980241,1.6796122,,,,,,,,,,,,,,,,, -66100,0.20609951,1.6538725,,,,,,,,,,,,,,,,, -66200,0.20152749,1.6602739,,,,,,,,,,,,,,,,, -66300,0.2249478,1.5978206,,,,,,,,,,,,,,,,, -66400,0.21201953,1.6218619,,,,,,,,,,,,,,,,, -66500,0.2151536,1.714491,,,,,,,,,,,,,,,,, -66600,0.22777312,1.6514498,,,,,,,,,,,,,,,,, -66700,0.2016096,1.6537482,,,,,,,,,,,,,,,,, -66800,0.22036135,1.599921,,,,,,,,,,,,,,,,, -66900,0.20757985,1.6998448,,,,,,,,,,,,,,,,, -67000,0.2247614,1.7258488,,,,,,,,,,,,,,,,, -67100,0.47882044,1.6025802,,,,,,,,,,,,,,,,, -67200,0.19946824,1.638268,,,,,,,,,,,,,,,,, -67300,0.20203812,1.6361479,,,,,,,,,,,,,,,,, -67400,0.1975815,1.7445852,,,,,,,,,,,,,,,,, -67500,0.2270243,1.6442288,,,,,,,,,,,,,,,,, -67600,0.19269794,1.6808468,,,,,,,,,,,,,,,,, -67700,0.2032161,1.7558386,,,,,,,,,,,,,,,,, -67800,0.20540918,1.7269114,,,,,,,,,,,,,,,,, -67900,0.20722185,1.6383746,,,,,,,,,,,,,,,,, -68000,0.18766606,1.591257,,,,,,,,,,,,,,,,, -68022,,,0.6579658389091492,1.613708734512329,32.35109639577502,0.6769165992736816,1.4885090589523315,29.62679347416672,3000.0,0.6890709400177002,1.3990836143493652,29.19693387188205,3003.0,24394.624833345413,40496.11887073517,24394.624833345413,16098.42095041275,0.8969278335571289,0.0 -68100,0.21157627,1.7225629,,,,,,,,,,,,,,,,, -68200,0.21069558,1.7275759,,,,,,,,,,,,,,,,, -68300,0.19922532,1.6334143,,,,,,,,,,,,,,,,, -68400,0.19278586,1.7104349,,,,,,,,,,,,,,,,, -68500,0.18536912,1.6584519,,,,,,,,,,,,,,,,, -68600,0.18222047,1.6730472,,,,,,,,,,,,,,,,, -68700,0.20114781,1.7274334,,,,,,,,,,,,,,,,, -68800,0.19020635,1.7547613,,,,,,,,,,,,,,,,, -68900,0.19153544,1.5799165,,,,,,,,,,,,,,,,, -69000,0.21341012,1.7018672,,,,,,,,,,,,,,,,, -69100,0.18253864,1.682136,,,,,,,,,,,,,,,,, -69200,0.19538869,1.6769338,,,,,,,,,,,,,,,,, -69300,0.1841177,1.6605799,,,,,,,,,,,,,,,,, -69400,0.21083379,1.6115544,,,,,,,,,,,,,,,,, -69500,0.20208521,1.6701297,,,,,,,,,,,,,,,,, -69600,0.20338887,1.6049376,,,,,,,,,,,,,,,,, -69700,0.20450182,1.6742514,,,,,,,,,,,,,,,,, -69800,0.20578213,1.6713912,,,,,,,,,,,,,,,,, -69900,0.20491521,1.7044475,,,,,,,,,,,,,,,,, -70000,0.1996757,1.6078322,,,,,,,,,,,,,,,,, -70100,0.19685264,1.6201463,,,,,,,,,,,,,,,,, -70200,0.22216207,1.672051,,,,,,,,,,,,,,,,, -70300,0.2490882,1.6371918,,,,,,,,,,,,,,,,, -70368,,,0.6629226207733154,1.5664702653884888,32.78073814231867,0.6758130788803101,1.479118824005127,29.19225225701453,3000.0,0.6903492212295532,1.3912490606307983,28.78012899563756,3003.0,25234.659603357315,42054.70087099075,25234.659603357315,16816.858570575714,0.9333441257476808,0.0 -70400,0.19794373,1.6558294,,,,,,,,,,,,,,,,, -70500,0.19645181,1.6115474,,,,,,,,,,,,,,,,, -70600,0.1912906,1.6866047,,,,,,,,,,,,,,,,, -70700,0.20503522,1.7132695,,,,,,,,,,,,,,,,, -70800,0.19936924,1.6519289,,,,,,,,,,,,,,,,, -70900,0.2115548,1.7669991,,,,,,,,,,,,,,,,, -71000,0.20130228,1.6873723,,,,,,,,,,,,,,,,, -71100,0.20846003,1.60061,,,,,,,,,,,,,,,,, -71200,0.20608175,1.6584752,,,,,,,,,,,,,,,,, -71300,0.24580798,1.7318958,,,,,,,,,,,,,,,,, -71400,0.1989512,1.5898911,,,,,,,,,,,,,,,,, -71500,0.22056365,1.602597,,,,,,,,,,,,,,,,, -71600,0.20341161,1.6530501,,,,,,,,,,,,,,,,, -71700,0.34021607,1.679482,,,,,,,,,,,,,,,,, -71800,0.19293405,1.750308,,,,,,,,,,,,,,,,, -71900,0.20297466,1.6076287,,,,,,,,,,,,,,,,, -72000,0.73005056,1.6866434,,,,,,,,,,,,,,,,, -72100,0.20520197,1.6901035,,,,,,,,,,,,,,,,, -72200,0.2000411,1.6306363,,,,,,,,,,,,,,,,, -72300,0.20865288,1.6287462,,,,,,,,,,,,,,,,, -72400,0.19518912,1.6410782,,,,,,,,,,,,,,,,, -72500,0.19537346,1.673521,,,,,,,,,,,,,,,,, -72600,0.20104998,1.516323,,,,,,,,,,,,,,,,, -72700,0.19627273,1.7017467,,,,,,,,,,,,,,,,, -72714,,,0.6606534719467163,1.593955636024475,32.511623246878834,0.6776605248451233,1.4730331897735596,29.498651552431973,3000.0,0.6935099959373474,1.380598545074463,29.474507063397496,3003.0,26074.649913072582,43471.833458423615,26074.649913072582,17393.892706871033,0.9688091278076172,0.0 -72800,0.19580913,1.6477791,,,,,,,,,,,,,,,,, -72900,0.2037936,1.6544925,,,,,,,,,,,,,,,,, -73000,0.194161,1.6620983,,,,,,,,,,,,,,,,, -73100,0.2198389,1.6161718,,,,,,,,,,,,,,,,, -73200,0.20736118,1.6614903,,,,,,,,,,,,,,,,, -73300,0.20262757,1.5835482,,,,,,,,,,,,,,,,, -73400,0.2105127,1.6784036,,,,,,,,,,,,,,,,, -73500,0.19553451,1.616393,,,,,,,,,,,,,,,,, -73600,0.19602206,1.6264695,,,,,,,,,,,,,,,,, -73700,0.21783353,1.6599364,,,,,,,,,,,,,,,,, -73800,0.19735286,1.6290971,,,,,,,,,,,,,,,,, -73900,0.19988342,1.6562773,,,,,,,,,,,,,,,,, -74000,0.20735186,1.6347624,,,,,,,,,,,,,,,,, -74100,0.20039009,1.6411047,,,,,,,,,,,,,,,,, -74200,0.21008839,1.6809986,,,,,,,,,,,,,,,,, -74300,0.2246375,1.6129414,,,,,,,,,,,,,,,,, -74400,0.19815451,1.727943,,,,,,,,,,,,,,,,, -74500,0.18320967,1.6602396,,,,,,,,,,,,,,,,, -74600,0.19121894,1.6693585,,,,,,,,,,,,,,,,, -74700,0.21275583,1.6775906,,,,,,,,,,,,,,,,, -74800,0.19979085,1.6240339,,,,,,,,,,,,,,,,, -74900,0.20722687,1.5520625,,,,,,,,,,,,,,,,, -75000,0.2250261,1.6352918,,,,,,,,,,,,,,,,, -75061,,,0.6858612298965454,1.4403479099273682,34.89402750189174,0.677734911441803,1.4711121320724487,29.468598279995035,3000.0,0.6918134093284607,1.3813815116882324,29.14106138373952,3003.0,26914.73360776901,44794.491381406784,26914.73360776901,17876.356063604355,1.006338119506836,0.0 -75100,0.19162643,1.620102,,,,,,,,,,,,,,,,, -75200,0.2321412,1.6884558,,,,,,,,,,,,,,,,, -75300,0.21776438,1.5915916,,,,,,,,,,,,,,,,, -75400,0.17885025,1.5150841,,,,,,,,,,,,,,,,, -75500,0.20946814,1.577575,,,,,,,,,,,,,,,,, -75600,0.19500424,1.695587,,,,,,,,,,,,,,,,, -75700,0.18857092,1.5881009,,,,,,,,,,,,,,,,, -75800,0.22740635,1.6362215,,,,,,,,,,,,,,,,, -75900,0.21841122,1.6716483,,,,,,,,,,,,,,,,, -76000,0.20265011,1.6548334,,,,,,,,,,,,,,,,, -76100,0.23683158,1.6618853,,,,,,,,,,,,,,,,, -76200,0.20218106,1.6595745,,,,,,,,,,,,,,,,, -76300,0.2018202,1.7195085,,,,,,,,,,,,,,,,, -76400,0.22532246,1.7002364,,,,,,,,,,,,,,,,, -76500,0.19097602,1.6420764,,,,,,,,,,,,,,,,, -76600,0.21162404,1.6320336,,,,,,,,,,,,,,,,, -76700,0.23387733,1.4924242,,,,,,,,,,,,,,,,, -76800,0.19718152,1.6642417,,,,,,,,,,,,,,,,, -76900,0.19911954,1.6678673,,,,,,,,,,,,,,,,, -77000,0.20949735,1.7009358,,,,,,,,,,,,,,,,, -77100,0.19359466,1.649208,,,,,,,,,,,,,,,,, -77200,0.20119902,1.6623306,,,,,,,,,,,,,,,,, -77300,0.20151737,1.6979558,,,,,,,,,,,,,,,,, -77400,0.2023148,1.6731839,,,,,,,,,,,,,,,,, -77406,,,0.665233850479126,1.553189396858215,33.34801089317181,0.6795451641082764,1.4581555128097534,29.55528459819339,3000.0,0.694428026676178,1.3683438301086426,29.284875716363665,3003.0,27754.92539286613,46285.17871427536,27754.92539286613,18526.739768743515,1.0434327125549316,0.0 -77500,0.2115018,1.6227142,,,,,,,,,,,,,,,,, -77600,0.21359515,1.5880574,,,,,,,,,,,,,,,,, -77700,0.18793267,1.6022309,,,,,,,,,,,,,,,,, -77800,0.20716639,1.5771449,,,,,,,,,,,,,,,,, -77900,0.20822166,1.7057661,,,,,,,,,,,,,,,,, -78000,0.23421511,1.6650839,,,,,,,,,,,,,,,,, -78100,0.20372202,1.6214179,,,,,,,,,,,,,,,,, -78200,0.18863337,1.6352044,,,,,,,,,,,,,,,,, -78300,0.21069548,1.6371878,,,,,,,,,,,,,,,,, -78400,0.23178703,1.5835884,,,,,,,,,,,,,,,,, -78500,0.19463827,1.5951327,,,,,,,,,,,,,,,,, -78600,0.2127602,1.7232617,,,,,,,,,,,,,,,,, -78700,0.24580054,1.681875,,,,,,,,,,,,,,,,, -78800,0.20346786,1.6378261,,,,,,,,,,,,,,,,, -78900,0.19956224,1.6373272,,,,,,,,,,,,,,,,, -79000,0.20777613,1.6865247,,,,,,,,,,,,,,,,, -79100,0.1901261,1.618868,,,,,,,,,,,,,,,,, -79200,0.21299596,1.6146541,,,,,,,,,,,,,,,,, -79300,0.19205788,1.6306926,,,,,,,,,,,,,,,,, -79400,0.23471995,1.6320337,,,,,,,,,,,,,,,,, -79500,0.21722001,1.5721344,,,,,,,,,,,,,,,,, -79600,0.22160485,1.5979614,,,,,,,,,,,,,,,,, -79700,2.1776104,1.8054137,,,,,,,,,,,,,,,,, -79751,,,0.6607846617698669,1.5785906314849854,33.05226990077929,0.6790120601654053,1.4651505947113037,29.44278157741172,3000.0,0.6931613683700562,1.3743035793304443,29.269250344436912,3003.0,28595.10469722748,47696.377888441086,28595.10469722748,19097.646564006805,1.0817084312438965,0.0 -79800,0.3377219,1.6168147,,,,,,,,,,,,,,,,, -79900,0.19179524,1.6563346,,,,,,,,,,,,,,,,, -80000,0.20870635,1.5709862,,,,,,,,,,,,,,,,, -80100,0.22783719,1.6697336,,,,,,,,,,,,,,,,, -80200,0.19310573,1.6246611,,,,,,,,,,,,,,,,, -80300,0.20073855,1.6537577,,,,,,,,,,,,,,,,, -80400,0.1987329,1.6331533,,,,,,,,,,,,,,,,, -80500,0.19256784,1.631983,,,,,,,,,,,,,,,,, -80600,0.21194659,1.64427,,,,,,,,,,,,,,,,, -80700,0.21352129,1.6636095,,,,,,,,,,,,,,,,, -80800,0.19643904,1.5982991,,,,,,,,,,,,,,,,, -80900,0.19550648,1.6973267,,,,,,,,,,,,,,,,, -81000,0.19672872,1.6514195,,,,,,,,,,,,,,,,, -81100,0.18863834,1.5772117,,,,,,,,,,,,,,,,, -81200,0.20552287,1.5784546,,,,,,,,,,,,,,,,, -81300,0.20569153,1.5340106,,,,,,,,,,,,,,,,, -81400,0.20195128,1.7593122,,,,,,,,,,,,,,,,, -81500,0.19699188,1.5878516,,,,,,,,,,,,,,,,, -81600,0.20421968,1.662318,,,,,,,,,,,,,,,,, -81700,0.19473337,1.5975738,,,,,,,,,,,,,,,,, -81800,0.19363506,1.6622318,,,,,,,,,,,,,,,,, -81900,0.18817638,1.6112087,,,,,,,,,,,,,,,,, -82000,0.24578963,1.5745468,,,,,,,,,,,,,,,,, -82097,,,0.6755383014678955,1.49465811252594,33.659196883772914,0.6815910339355469,1.449007272720337,29.689568520414724,3000.0,0.6955435872077942,1.3564544916152954,29.42719219404512,3003.0,29434.99210190773,49080.874106407166,29434.99210190773,19642.142485380173,1.1191542148590088,0.0 -82100,0.19733505,1.6050044,,,,,,,,,,,,,,,,, -82200,0.20334621,1.5646198,,,,,,,,,,,,,,,,, -82300,0.20139048,1.6874256,,,,,,,,,,,,,,,,, -82400,0.18826115,1.5990481,,,,,,,,,,,,,,,,, -82500,0.20023496,1.5739695,,,,,,,,,,,,,,,,, -82600,0.2015686,1.5914747,,,,,,,,,,,,,,,,, -82700,0.19383219,1.6108297,,,,,,,,,,,,,,,,, -82800,0.21953021,1.5710902,,,,,,,,,,,,,,,,, -82900,0.20833232,1.6335255,,,,,,,,,,,,,,,,, -83000,0.20676984,1.5544121,,,,,,,,,,,,,,,,, -83100,0.1996211,1.5759397,,,,,,,,,,,,,,,,, -83200,0.21114323,1.6024067,,,,,,,,,,,,,,,,, -83300,0.18988585,1.568895,,,,,,,,,,,,,,,,, -83400,0.22017372,1.599446,,,,,,,,,,,,,,,,, -83500,0.20141724,1.5974894,,,,,,,,,,,,,,,,, -83600,0.21064782,1.6321428,,,,,,,,,,,,,,,,, -83700,0.21082012,1.6787428,,,,,,,,,,,,,,,,, -83800,0.20875837,1.6816955,,,,,,,,,,,,,,,,, -83900,0.19595362,1.6524372,,,,,,,,,,,,,,,,, -84000,0.21022995,1.5929892,,,,,,,,,,,,,,,,, -84100,0.23011562,1.6211478,,,,,,,,,,,,,,,,, -84200,0.20539933,1.6299682,,,,,,,,,,,,,,,,, -84300,0.20977944,1.6663283,,,,,,,,,,,,,,,,, -84400,0.20567556,1.6435964,,,,,,,,,,,,,,,,, -84442,,,0.6705901026725769,1.5163570642471311,33.776368870019915,0.6826077699661255,1.4386560916900637,29.743246304174004,3000.0,0.697402834892273,1.344928741455078,29.5544296637964,3003.0,30274.89132285118,50588.53385710716,30274.89132285118,20309.78834581375,1.156053066253662,0.0 -84500,0.20643502,1.6745163,,,,,,,,,,,,,,,,, -84600,0.19830984,1.6108284,,,,,,,,,,,,,,,,, -84700,0.50554,1.5522901,,,,,,,,,,,,,,,,, -84800,0.22491296,1.6023242,,,,,,,,,,,,,,,,, -84900,0.20251621,1.600063,,,,,,,,,,,,,,,,, -85000,0.2070538,1.6301734,,,,,,,,,,,,,,,,, -85100,0.20741196,1.6023486,,,,,,,,,,,,,,,,, -85200,0.19382247,1.6184461,,,,,,,,,,,,,,,,, -85300,0.21460691,1.6804417,,,,,,,,,,,,,,,,, -85400,0.20525269,1.618245,,,,,,,,,,,,,,,,, -85500,0.20890066,1.6852312,,,,,,,,,,,,,,,,, -85600,0.21088207,1.6311547,,,,,,,,,,,,,,,,, -85700,0.21799448,1.6252373,,,,,,,,,,,,,,,,, -85800,0.2072768,1.6086323,,,,,,,,,,,,,,,,, -85900,0.20004514,1.6418353,,,,,,,,,,,,,,,,, -86000,0.21895996,1.663197,,,,,,,,,,,,,,,,, -86100,0.21475647,1.6810136,,,,,,,,,,,,,,,,, -86200,0.22122316,1.5934986,,,,,,,,,,,,,,,,, -86300,0.20532821,1.6381005,,,,,,,,,,,,,,,,, -86400,0.22169866,1.6300629,,,,,,,,,,,,,,,,, -86500,0.21318032,1.6398151,,,,,,,,,,,,,,,,, -86600,0.21033356,1.6385975,,,,,,,,,,,,,,,,, -86700,0.19825628,1.6048214,,,,,,,,,,,,,,,,, -86788,,,0.6693463325500488,1.5310808420181274,33.519672285560624,0.682632565498352,1.43792986869812,30.200129200007808,3000.0,0.6986927390098572,1.338726043701172,29.75376935276304,3003.0,31114.93718099594,51943.9401807785,31114.93718099594,20825.03607749939,1.1935102939605713,0.0 -86800,0.21128935,1.5472257,,,,,,,,,,,,,,,,, -86900,0.20294185,1.5827385,,,,,,,,,,,,,,,,, -87000,0.21774656,1.5606463,,,,,,,,,,,,,,,,, -87100,0.20511617,1.6167059,,,,,,,,,,,,,,,,, -87200,0.20897698,1.5864989,,,,,,,,,,,,,,,,, -87300,0.21537662,1.5658337,,,,,,,,,,,,,,,,, -87400,0.2032797,1.5810025,,,,,,,,,,,,,,,,, -87500,0.21537879,1.5252997,,,,,,,,,,,,,,,,, -87600,0.20864432,1.615295,,,,,,,,,,,,,,,,, -87700,0.1923474,1.5762237,,,,,,,,,,,,,,,,, -87800,0.20296632,1.5327115,,,,,,,,,,,,,,,,, -87900,0.21039063,1.6076548,,,,,,,,,,,,,,,,, -88000,0.21317825,1.5994786,,,,,,,,,,,,,,,,, -88100,0.20182434,1.6143663,,,,,,,,,,,,,,,,, -88200,0.21098132,1.6311495,,,,,,,,,,,,,,,,, -88300,0.20028485,1.6012857,,,,,,,,,,,,,,,,, -88400,0.21695513,1.5644429,,,,,,,,,,,,,,,,, -88500,0.2151568,1.5691147,,,,,,,,,,,,,,,,, -88600,0.21672632,1.6346318,,,,,,,,,,,,,,,,, -88700,0.233315,1.6479219,,,,,,,,,,,,,,,,, -88800,0.20421554,1.5993572,,,,,,,,,,,,,,,,, -88900,0.2335876,1.5733926,,,,,,,,,,,,,,,,, -89000,0.19735774,1.5492674,,,,,,,,,,,,,,,,, -89100,0.21344392,1.57064,,,,,,,,,,,,,,,,, -89133,,,0.6757782101631165,1.485595941543579,34.190875248660575,0.6854719519615173,1.4247512817382812,29.772108184170435,3000.0,0.6976584792137146,1.333077311515808,29.52895939744145,3003.0,31955.03034901619,53384.378826379776,31955.03034901619,21425.26732730865,1.2301530838012695,0.0 -89200,0.1946488,1.6211587,,,,,,,,,,,,,,,,, -89300,0.21454225,1.5071677,,,,,,,,,,,,,,,,, -89400,0.20419243,1.5031163,,,,,,,,,,,,,,,,, -89500,0.227333,1.6447753,,,,,,,,,,,,,,,,, -89600,0.21633266,1.6220752,,,,,,,,,,,,,,,,, -89700,0.20227373,1.6308193,,,,,,,,,,,,,,,,, -89800,0.22620337,1.5712093,,,,,,,,,,,,,,,,, -89900,0.20834626,1.6487597,,,,,,,,,,,,,,,,, -90000,0.20892096,1.6297317,,,,,,,,,,,,,,,,, -90100,0.21284927,1.6484363,,,,,,,,,,,,,,,,, -90200,0.21313553,1.5987835,,,,,,,,,,,,,,,,, -90300,0.21560101,1.6563181,,,,,,,,,,,,,,,,, -90400,0.2100971,1.6212459,,,,,,,,,,,,,,,,, -90500,0.23108436,1.5565203,,,,,,,,,,,,,,,,, -90600,0.20608054,1.6404141,,,,,,,,,,,,,,,,, -90700,0.21129331,1.5874076,,,,,,,,,,,,,,,,, -90800,0.20120125,1.660101,,,,,,,,,,,,,,,,, -90900,0.2150217,1.5753264,,,,,,,,,,,,,,,,, -91000,0.2080153,1.5737969,,,,,,,,,,,,,,,,, -91100,0.20400092,1.503796,,,,,,,,,,,,,,,,, -91200,0.21221535,1.6066881,,,,,,,,,,,,,,,,, -91300,0.19732949,1.533621,,,,,,,,,,,,,,,,, -91400,0.2080357,1.5513302,,,,,,,,,,,,,,,,, -91479,,,0.6764565110206604,1.4868851900100708,33.79142281553153,0.6856207251548767,1.4216893911361694,30.081463864085272,3000.0,0.7033525109291077,1.3221434354782104,30.261509158958816,3003.0,32794.96316599846,54851.17655205727,32794.96316599846,22052.010696411133,1.276489496231079,0.0 -91500,0.21479647,1.5586699,,,,,,,,,,,,,,,,, -91600,0.21439256,1.6445364,,,,,,,,,,,,,,,,, -91700,0.2056213,1.5753235,,,,,,,,,,,,,,,,, -91800,0.2132279,1.5574057,,,,,,,,,,,,,,,,, -91900,0.23030691,1.5222298,,,,,,,,,,,,,,,,, -92000,0.22892232,1.6261886,,,,,,,,,,,,,,,,, -92100,0.22637036,1.6270553,,,,,,,,,,,,,,,,, -92200,0.2151641,1.5899203,,,,,,,,,,,,,,,,, -92300,0.21667616,1.5309331,,,,,,,,,,,,,,,,, -92400,0.22690694,1.6432455,,,,,,,,,,,,,,,,, -92500,0.22893068,1.5695174,,,,,,,,,,,,,,,,, -92600,0.22411686,1.5587566,,,,,,,,,,,,,,,,, -92700,0.21160945,1.5443764,,,,,,,,,,,,,,,,, -92800,0.21750832,1.5684009,,,,,,,,,,,,,,,,, -92900,0.20006107,1.5361947,,,,,,,,,,,,,,,,, -93000,0.22398926,1.6266189,,,,,,,,,,,,,,,,, -93100,0.21775518,1.6479553,,,,,,,,,,,,,,,,, -93200,0.19944875,1.5934013,,,,,,,,,,,,,,,,, -93300,0.21418795,1.5825748,,,,,,,,,,,,,,,,, -93400,0.20785087,1.5885873,,,,,,,,,,,,,,,,, -93500,0.2183398,1.5313215,,,,,,,,,,,,,,,,, -93600,0.21434315,1.5582634,,,,,,,,,,,,,,,,, -93700,0.19810654,1.5223808,,,,,,,,,,,,,,,,, -93800,0.20016418,1.4963452,,,,,,,,,,,,,,,,, -93825,,,0.690422534942627,1.4100230932235718,34.952902904917686,0.685050368309021,1.4175060987472534,30.129531487347982,3000.0,0.7028295993804932,1.319978952407837,30.10693494571689,3003.0,33634.999903678894,56272.6827609539,33634.999903678894,22633.36717224121,1.315727710723877,0.0 -93900,0.2290803,1.6105094,,,,,,,,,,,,,,,,, -94000,0.21466896,1.5209795,,,,,,,,,,,,,,,,, -94100,0.20118485,1.540332,,,,,,,,,,,,,,,,, -94200,0.21159407,1.5852319,,,,,,,,,,,,,,,,, -94300,0.20594619,1.5994477,,,,,,,,,,,,,,,,, -94400,0.20935057,1.5697106,,,,,,,,,,,,,,,,, -94500,0.21227747,1.589754,,,,,,,,,,,,,,,,, -94600,0.20455614,1.6208771,,,,,,,,,,,,,,,,, -94700,0.22642833,1.5585079,,,,,,,,,,,,,,,,, -94800,0.20894437,1.5324951,,,,,,,,,,,,,,,,, -94900,0.20802993,1.5552818,,,,,,,,,,,,,,,,, -95000,0.21560095,1.5323007,,,,,,,,,,,,,,,,, -95100,0.21968178,1.5895567,,,,,,,,,,,,,,,,, -95200,0.21209049,1.6171787,,,,,,,,,,,,,,,,, -95300,0.21563275,1.5942682,,,,,,,,,,,,,,,,, -95400,0.21793,1.4845285,,,,,,,,,,,,,,,,, -95500,0.21699023,1.4877278,,,,,,,,,,,,,,,,, -95600,0.21423687,1.5333434,,,,,,,,,,,,,,,,, -95700,0.21634524,1.6420176,,,,,,,,,,,,,,,,, -95800,0.21528907,1.6173103,,,,,,,,,,,,,,,,, -95900,0.24171276,1.6421875,,,,,,,,,,,,,,,,, -96000,0.2225096,1.5444032,,,,,,,,,,,,,,,,, -96100,0.20970474,1.5358014,,,,,,,,,,,,,,,,, -96171,,,0.6796119809150696,1.4651366472244265,34.32243432293985,0.6871210336685181,1.4065537452697754,30.200335772061464,3000.0,0.7031084895133972,1.311916708946228,30.01868414744477,3003.0,34475.117656469345,57911.7241795063,34475.117656469345,23432.17221236229,1.3558223247528076,0.0 -96200,0.212465,1.5209494,,,,,,,,,,,,,,,,, -96300,0.22730899,1.6097434,,,,,,,,,,,,,,,,, -96400,0.22525193,1.6145356,,,,,,,,,,,,,,,,, -96500,0.21536016,1.532313,,,,,,,,,,,,,,,,, -96600,0.2332857,1.624813,,,,,,,,,,,,,,,,, -96700,0.27825794,1.5311685,,,,,,,,,,,,,,,,, -96800,0.222635,1.6289489,,,,,,,,,,,,,,,,, -96900,0.2119631,1.5693526,,,,,,,,,,,,,,,,, -97000,0.21729544,1.5676379,,,,,,,,,,,,,,,,, -97100,0.20895809,1.5077704,,,,,,,,,,,,,,,,, -97200,0.21957743,1.5090928,,,,,,,,,,,,,,,,, -97300,0.21804367,1.575578,,,,,,,,,,,,,,,,, -97400,0.21303666,1.55794,,,,,,,,,,,,,,,,, -97500,0.21410511,1.5745909,,,,,,,,,,,,,,,,, -97600,0.22521237,1.5649165,,,,,,,,,,,,,,,,, -97700,0.21499501,1.5447037,,,,,,,,,,,,,,,,, -97800,0.21268657,1.5944618,,,,,,,,,,,,,,,,, -97900,0.21524139,1.5351825,,,,,,,,,,,,,,,,, -98000,0.21651913,1.6294652,,,,,,,,,,,,,,,,, -98100,0.23488961,1.5482422,,,,,,,,,,,,,,,,, -98200,0.21600273,1.5868464,,,,,,,,,,,,,,,,, -98300,0.21764918,1.6287426,,,,,,,,,,,,,,,,, -98400,0.21219665,1.4784428,,,,,,,,,,,,,,,,, -98500,0.22905393,1.5641928,,,,,,,,,,,,,,,,, -98517,,,0.6749527454376221,1.4922711849212646,33.76268337396561,0.6880509853363037,1.4029872417449951,30.562362559763177,3000.0,0.7058160901069641,1.301902413368225,30.582308832194634,3003.0,35315.19370055199,59286.53086400032,35315.19370055199,23966.787534713745,1.394505500793457,0.0 -98600,0.23535396,1.5431846,,,,,,,,,,,,,,,,, -98700,0.22741751,1.5185535,,,,,,,,,,,,,,,,, -98800,0.22026117,1.6407075,,,,,,,,,,,,,,,,, -98900,0.22222136,1.6059221,,,,,,,,,,,,,,,,, -99000,0.2089255,1.5543859,,,,,,,,,,,,,,,,, -99100,0.20102197,1.533915,,,,,,,,,,,,,,,,, -99200,0.23190384,1.5175735,,,,,,,,,,,,,,,,, -99300,0.2224152,1.54155,,,,,,,,,,,,,,,,, -99400,0.21778625,1.5160764,,,,,,,,,,,,,,,,, -99500,0.2206057,1.5353348,,,,,,,,,,,,,,,,, -99600,0.22133507,1.545948,,,,,,,,,,,,,,,,, -99700,0.21167535,1.4745282,,,,,,,,,,,,,,,,, -99800,0.22076972,1.5625352,,,,,,,,,,,,,,,,, -99900,0.2301229,1.4871219,,,,,,,,,,,,,,,,, -100000,0.21911216,1.4933002,,,,,,,,,,,,,,,,, -100100,0.21726245,1.5966141,,,,,,,,,,,,,,,,, -100200,0.22226737,1.5365571,,,,,,,,,,,,,,,,, -100300,0.20507634,1.5036668,,,,,,,,,,,,,,,,, -100400,0.22351244,1.5622137,,,,,,,,,,,,,,,,, -100500,0.22293359,1.5464267,,,,,,,,,,,,,,,,, -100600,0.2156996,1.5205954,,,,,,,,,,,,,,,,, -100700,0.2320441,1.5126468,,,,,,,,,,,,,,,,, -100800,0.23152399,1.5166905,,,,,,,,,,,,,,,,, -100862,,,0.6901666522026062,1.3958227634429932,34.54494358756864,0.6898488402366638,1.3990224599838257,30.523842063106216,3000.0,0.7054442167282104,1.299425721168518,30.24920047046108,3003.0,36155.15829825401,60668.71771264076,36155.15829825401,24508.896410226826,1.4340672492980957,0.0 -100900,0.2166571,1.556816,,,,,,,,,,,,,,,,, -101000,0.22467026,1.5714326,,,,,,,,,,,,,,,,, -101100,0.21049598,1.5454642,,,,,,,,,,,,,,,,, -101200,0.23221526,1.559265,,,,,,,,,,,,,,,,, -101300,0.22928682,1.6017218,,,,,,,,,,,,,,,,, -101400,0.22063819,1.5647175,,,,,,,,,,,,,,,,, -101500,0.21545005,1.5525582,,,,,,,,,,,,,,,,, -101600,0.23238918,1.5286173,,,,,,,,,,,,,,,,, -101700,0.22116858,1.5221356,,,,,,,,,,,,,,,,, -101800,0.22071539,1.5630461,,,,,,,,,,,,,,,,, -101900,0.2090783,1.615913,,,,,,,,,,,,,,,,, -102000,0.21414188,1.5614969,,,,,,,,,,,,,,,,, -102100,0.22260715,1.5305887,,,,,,,,,,,,,,,,, -102200,0.22661746,1.5790395,,,,,,,,,,,,,,,,, -102300,0.2324214,1.4973655,,,,,,,,,,,,,,,,, -102400,0.22045209,1.539569,,,,,,,,,,,,,,,,, -102500,0.22551441,1.5276378,,,,,,,,,,,,,,,,, -102600,0.23074347,1.5300933,,,,,,,,,,,,,,,,, -102700,0.22963871,1.6488988,,,,,,,,,,,,,,,,, -102800,0.22928362,1.568436,,,,,,,,,,,,,,,,, -102900,0.22824487,1.5529324,,,,,,,,,,,,,,,,, -103000,0.22150119,1.5405793,,,,,,,,,,,,,,,,, -103100,0.22912535,1.5178233,,,,,,,,,,,,,,,,, -103200,0.22030103,1.6036136,,,,,,,,,,,,,,,,, -103207,,,0.6826054453849792,1.445312738418579,34.78132076309567,0.6900224089622498,1.391385555267334,30.54686521049929,3000.0,0.708686351776123,1.2871750593185425,30.83900984318424,3003.0,36995.03522968292,62076.9599506855,36995.03522968292,25077.136883974075,1.4827287197113037,0.0 -103300,0.22596149,1.5938772,,,,,,,,,,,,,,,,, -103400,0.21909656,1.5662178,,,,,,,,,,,,,,,,, -103500,0.2161521,1.502473,,,,,,,,,,,,,,,,, -103600,0.21368518,1.5157973,,,,,,,,,,,,,,,,, -103700,0.23290165,1.4764428,,,,,,,,,,,,,,,,, -103800,0.22545171,1.584297,,,,,,,,,,,,,,,,, -103900,0.21457776,1.5358847,,,,,,,,,,,,,,,,, -104000,0.22077332,1.5594444,,,,,,,,,,,,,,,,, -104100,0.207874,1.4601072,,,,,,,,,,,,,,,,, -104200,0.21976994,1.5374938,,,,,,,,,,,,,,,,, -104300,0.21243866,1.5164773,,,,,,,,,,,,,,,,, -104400,0.22215784,1.5992215,,,,,,,,,,,,,,,,, -104500,0.2170451,1.5383756,,,,,,,,,,,,,,,,, -104600,0.21476696,1.4828371,,,,,,,,,,,,,,,,, -104700,0.21853323,1.5166266,,,,,,,,,,,,,,,,, -104800,0.23480886,1.5148578,,,,,,,,,,,,,,,,, -104900,0.22044331,1.5115865,,,,,,,,,,,,,,,,, -105000,0.21720286,1.4722433,,,,,,,,,,,,,,,,, -105100,0.22349545,1.551707,,,,,,,,,,,,,,,,, -105200,0.21521798,1.5599271,,,,,,,,,,,,,,,,, -105300,0.22413322,1.5418336,,,,,,,,,,,,,,,,, -105400,0.233238,1.4913775,,,,,,,,,,,,,,,,, -105500,0.24325345,1.5269258,,,,,,,,,,,,,,,,, -105552,,,0.6831424832344055,1.4466559886932373,34.30571893606282,0.6902456283569336,1.3884084224700928,30.71525097595463,3000.0,0.706187903881073,1.2858394384384155,30.671070000288665,3003.0,37835.16900038719,63494.38958978653,37835.16900038719,25654.313775777817,1.522960901260376,0.0 -105600,0.21531114,1.5046086,,,,,,,,,,,,,,,,, -105700,0.21169387,1.4613429,,,,,,,,,,,,,,,,, -105800,0.22405466,1.5283945,,,,,,,,,,,,,,,,, -105900,0.22332783,1.5903677,,,,,,,,,,,,,,,,, -106000,0.22658665,1.5478656,,,,,,,,,,,,,,,,, -106100,0.2300914,1.4974744,,,,,,,,,,,,,,,,, -106200,0.2278091,1.5113832,,,,,,,,,,,,,,,,, -106300,0.2323255,1.542158,,,,,,,,,,,,,,,,, -106400,0.21865147,1.501892,,,,,,,,,,,,,,,,, -106500,0.21960387,1.5155382,,,,,,,,,,,,,,,,, -106600,0.22020479,1.4551787,,,,,,,,,,,,,,,,, -106700,0.2172845,1.4900831,,,,,,,,,,,,,,,,, -106800,0.21288542,1.5139062,,,,,,,,,,,,,,,,, -106900,0.22423679,1.474293,,,,,,,,,,,,,,,,, -107000,0.2352636,1.6246008,,,,,,,,,,,,,,,,, -107100,0.22916372,1.5088234,,,,,,,,,,,,,,,,, -107200,0.22721222,1.5220306,,,,,,,,,,,,,,,,, -107300,0.24389073,1.5592242,,,,,,,,,,,,,,,,, -107400,0.24398696,1.5679904,,,,,,,,,,,,,,,,, -107500,0.2274668,1.4278054,,,,,,,,,,,,,,,,, -107600,0.24285,1.5027485,,,,,,,,,,,,,,,,, -107700,0.22992098,1.4478806,,,,,,,,,,,,,,,,, -107800,0.21795371,1.5363237,,,,,,,,,,,,,,,,, -107899,,,0.6916837692260742,1.393021583557129,35.40145042693628,0.691299557685852,1.383009910583496,30.558788910901555,3000.0,0.7079077363014221,1.2825238704681396,30.42644247752539,3003.0,38675.34786558151,64931.23558783531,38675.34786558151,26250.866422891617,1.5631628036499023,0.0 -107900,0.2099613,1.4518491,,,,,,,,,,,,,,,,, -108000,0.23288924,1.5165427,,,,,,,,,,,,,,,,, -108100,0.21717165,1.4948075,,,,,,,,,,,,,,,,, -108200,0.24099658,1.5297974,,,,,,,,,,,,,,,,, -108300,0.21954222,1.5355756,,,,,,,,,,,,,,,,, -108400,0.21921603,1.5124702,,,,,,,,,,,,,,,,, -108500,0.22179519,1.4347044,,,,,,,,,,,,,,,,, -108600,0.2262523,1.5152137,,,,,,,,,,,,,,,,, -108700,0.22067209,1.4707941,,,,,,,,,,,,,,,,, -108800,0.2211475,1.4222095,,,,,,,,,,,,,,,,, -108900,0.2246691,1.4853443,,,,,,,,,,,,,,,,, -109000,0.22330597,1.4946326,,,,,,,,,,,,,,,,, -109100,0.21940771,1.4748417,,,,,,,,,,,,,,,,, -109200,0.22713305,1.4974161,,,,,,,,,,,,,,,,, -109300,0.2324908,1.4803878,,,,,,,,,,,,,,,,, -109400,0.23051096,1.543386,,,,,,,,,,,,,,,,, -109500,0.23864554,1.5608709,,,,,,,,,,,,,,,,, -109600,0.21823715,1.4403839,,,,,,,,,,,,,,,,, -109700,0.22497943,1.5065795,,,,,,,,,,,,,,,,, -109800,0.22448118,1.4941889,,,,,,,,,,,,,,,,, -109900,0.2261248,1.5125567,,,,,,,,,,,,,,,,, -110000,0.2215708,1.4952974,,,,,,,,,,,,,,,,, -110100,0.22525187,1.480502,,,,,,,,,,,,,,,,, -110200,0.2389116,1.531275,,,,,,,,,,,,,,,,, -110245,,,0.6882879137992859,1.4097886085510254,35.19981584636805,0.6923038959503174,1.3791635036468506,30.56347426094465,3000.0,0.7103015780448914,1.2744451761245728,30.49992268140961,3003.0,39515.29098367691,66392.81823420525,39515.29098367691,26872.388967752457,1.603606939315796,0.0 -110300,0.22750159,1.4197017,,,,,,,,,,,,,,,,, -110400,0.24315758,1.489272,,,,,,,,,,,,,,,,, -110500,0.22780703,1.5829346,,,,,,,,,,,,,,,,, -110600,0.22353329,1.4704119,,,,,,,,,,,,,,,,, -110700,0.24034764,1.5512398,,,,,,,,,,,,,,,,, -110800,0.24294707,1.5051416,,,,,,,,,,,,,,,,, -110900,0.22409411,1.5315835,,,,,,,,,,,,,,,,, -111000,0.23275912,1.5615045,,,,,,,,,,,,,,,,, -111100,0.22180146,1.4402696,,,,,,,,,,,,,,,,, -111200,0.22499697,1.40247,,,,,,,,,,,,,,,,, -111300,0.22525737,1.5275282,,,,,,,,,,,,,,,,, -111400,0.24090639,1.4738946,,,,,,,,,,,,,,,,, -111500,0.24252668,1.5239171,,,,,,,,,,,,,,,,, -111600,0.23537111,1.4678667,,,,,,,,,,,,,,,,, -111700,0.22792915,1.4370629,,,,,,,,,,,,,,,,, -111800,0.22150813,1.5043677,,,,,,,,,,,,,,,,, -111900,0.23855165,1.4536704,,,,,,,,,,,,,,,,, -112000,0.22830264,1.4587519,,,,,,,,,,,,,,,,, -112100,0.24167897,1.4908562,,,,,,,,,,,,,,,,, -112200,0.22704762,1.4652245,,,,,,,,,,,,,,,,, -112300,0.22345291,1.5170217,,,,,,,,,,,,,,,,, -112400,0.22552903,1.4773278,,,,,,,,,,,,,,,,, -112500,0.2492572,1.5580595,,,,,,,,,,,,,,,,, -112591,,,0.6941087245941162,1.3747334480285645,35.41065961993999,0.6923410892486572,1.3770735263824463,30.704859704017625,3000.0,0.7099180817604065,1.2725780010223389,30.61117920436823,3003.0,40355.34690570831,67835.56606578827,40355.34690570831,27474.955493688583,1.654353380203247,0.0 -112600,0.23456676,1.5472214,,,,,,,,,,,,,,,,, -112700,0.217146,1.4802704,,,,,,,,,,,,,,,,, -112800,0.22089188,1.423032,,,,,,,,,,,,,,,,, -112900,0.23109965,1.4883432,,,,,,,,,,,,,,,,, -113000,0.22427705,1.4598768,,,,,,,,,,,,,,,,, -113100,0.25398955,1.5202686,,,,,,,,,,,,,,,,, -113200,0.23098621,1.4523399,,,,,,,,,,,,,,,,, -113300,0.23076808,1.5033265,,,,,,,,,,,,,,,,, -113400,0.2275931,1.4473451,,,,,,,,,,,,,,,,, -113500,0.23162797,1.5395291,,,,,,,,,,,,,,,,, -113600,0.22574838,1.4855055,,,,,,,,,,,,,,,,, -113700,0.23715815,1.4834162,,,,,,,,,,,,,,,,, -113800,0.24473673,1.5248197,,,,,,,,,,,,,,,,, -113900,0.22980034,1.5503179,,,,,,,,,,,,,,,,, -114000,0.23216188,1.4136912,,,,,,,,,,,,,,,,, -114100,0.23840101,1.5250707,,,,,,,,,,,,,,,,, -114200,0.23402913,1.4733955,,,,,,,,,,,,,,,,, -114300,0.22958517,1.512565,,,,,,,,,,,,,,,,, -114400,0.23690751,1.4591551,,,,,,,,,,,,,,,,, -114500,0.21556237,1.4407805,,,,,,,,,,,,,,,,, -114600,0.23325253,1.4703621,,,,,,,,,,,,,,,,, -114700,0.24044904,1.4873062,,,,,,,,,,,,,,,,, -114800,0.23534356,1.4889947,,,,,,,,,,,,,,,,, -114900,0.22332723,1.4331853,,,,,,,,,,,,,,,,, -114937,,,0.6922072172164917,1.393733263015747,35.2817522823305,0.693717360496521,1.3703649044036863,30.42187428646859,3000.0,0.7107896208763123,1.2660441398620603,30.542745113871938,3003.0,41195.55699467659,69419.99466729164,41195.55699467659,28219.04938960076,1.7018797397613523,0.0 -115000,0.2349152,1.4536656,,,,,,,,,,,,,,,,, -115100,0.23302092,1.4425172,,,,,,,,,,,,,,,,, -115200,0.23878804,1.4311919,,,,,,,,,,,,,,,,, -115300,0.24373084,1.5503608,,,,,,,,,,,,,,,,, -115400,0.23444612,1.4962575,,,,,,,,,,,,,,,,, -115500,0.22851254,1.4861996,,,,,,,,,,,,,,,,, -115600,0.2339696,1.4997215,,,,,,,,,,,,,,,,, -115700,0.22179973,1.4603211,,,,,,,,,,,,,,,,, -115800,0.22780915,1.4467552,,,,,,,,,,,,,,,,, -115900,0.23429573,1.5444891,,,,,,,,,,,,,,,,, -116000,0.23451374,1.400379,,,,,,,,,,,,,,,,, -116100,0.22649352,1.4287641,,,,,,,,,,,,,,,,, -116200,0.22683805,1.4954523,,,,,,,,,,,,,,,,, -116300,0.22679624,1.465934,,,,,,,,,,,,,,,,, -116400,0.22404,1.4088259,,,,,,,,,,,,,,,,, -116500,0.23226497,1.4825892,,,,,,,,,,,,,,,,, -116600,0.23239289,1.3836172,,,,,,,,,,,,,,,,, -116700,0.23704529,1.4781817,,,,,,,,,,,,,,,,, -116800,0.22914532,1.5136809,,,,,,,,,,,,,,,,, -116900,0.2360857,1.5181406,,,,,,,,,,,,,,,,, -117000,0.23322864,1.53026,,,,,,,,,,,,,,,,, -117100,0.23519298,1.50584,,,,,,,,,,,,,,,,, -117200,0.23208548,1.499388,,,,,,,,,,,,,,,,, -117283,,,0.690304696559906,1.3954484462738037,35.20293102881795,0.69430011510849,1.370716571807861,30.820244597158045,3000.0,0.7114868760108948,1.2633055448532104,30.76111965834509,3003.0,42035.75880694389,70851.79516196251,42035.75880694389,28810.52793598175,1.7455673217773438,0.0 -117300,0.23284356,1.4269193,,,,,,,,,,,,,,,,, -117400,0.24839275,1.5543472,,,,,,,,,,,,,,,,, -117500,0.22775386,1.4099962,,,,,,,,,,,,,,,,, -117600,0.22689201,1.4301968,,,,,,,,,,,,,,,,, -117700,0.23418577,1.5043377,,,,,,,,,,,,,,,,, -117800,0.2348802,1.4793807,,,,,,,,,,,,,,,,, -117900,0.23914562,1.4604559,,,,,,,,,,,,,,,,, -118000,0.23429935,1.4033751,,,,,,,,,,,,,,,,, -118100,0.23694764,1.5445663,,,,,,,,,,,,,,,,, -118200,0.22942285,1.477778,,,,,,,,,,,,,,,,, -118300,0.22641954,1.4908284,,,,,,,,,,,,,,,,, -118400,0.22131318,1.4770448,,,,,,,,,,,,,,,,, -118500,0.24087186,1.4854003,,,,,,,,,,,,,,,,, -118600,0.23201397,1.4483745,,,,,,,,,,,,,,,,, -118700,0.23878713,1.5403761,,,,,,,,,,,,,,,,, -118800,0.24465178,1.4580048,,,,,,,,,,,,,,,,, -118900,0.23220491,1.4452413,,,,,,,,,,,,,,,,, -119000,0.24153903,1.5311925,,,,,,,,,,,,,,,,, -119100,0.24303427,1.5223361,,,,,,,,,,,,,,,,, -119200,0.22830011,1.5012329,,,,,,,,,,,,,,,,, -119300,0.22242111,1.4584929,,,,,,,,,,,,,,,,, -119400,0.22949186,1.4313855,,,,,,,,,,,,,,,,, -119500,0.23995894,1.4129362,,,,,,,,,,,,,,,,, -119600,0.2292475,1.4412992,,,,,,,,,,,,,,,,, -119629,,,0.6958056688308716,1.3685253858566284,35.7102083223731,0.6936181783676147,1.36842942237854,30.65897634626175,3000.0,0.7119168043136597,1.2605042457580566,30.859965177480586,3003.0,42875.923796892166,72394.39634847641,42875.923796892166,29512.847000598907,1.788480281829834,0.0 -119700,0.2294474,1.453072,,,,,,,,,,,,,,,,, -119800,0.24372303,1.4194053,,,,,,,,,,,,,,,,, -119900,0.2368791,1.4807514,,,,,,,,,,,,,,,,, -120000,0.23872228,1.4575673,,,,,,,,,,,,,,,,, -120100,0.24073076,1.4913524,,,,,,,,,,,,,,,,, -120200,0.24219936,1.4807216,,,,,,,,,,,,,,,,, -120300,0.23669943,1.4918352,,,,,,,,,,,,,,,,, -120400,0.23672976,1.410567,,,,,,,,,,,,,,,,, -120500,0.25371176,1.5106063,,,,,,,,,,,,,,,,, -120600,0.23698314,1.5157372,,,,,,,,,,,,,,,,, -120700,0.2378684,1.4193364,,,,,,,,,,,,,,,,, -120800,0.23992841,1.4596083,,,,,,,,,,,,,,,,, -120900,0.22346398,1.461515,,,,,,,,,,,,,,,,, -121000,0.23691906,1.4770709,,,,,,,,,,,,,,,,, -121100,0.23646086,1.5360965,,,,,,,,,,,,,,,,, -121200,0.23276812,1.5012002,,,,,,,,,,,,,,,,, -121300,0.23551713,1.4890078,,,,,,,,,,,,,,,,, -121400,0.2283427,1.4418923,,,,,,,,,,,,,,,,, -121500,0.23924924,1.4417379,,,,,,,,,,,,,,,,, -121600,0.24251932,1.4592757,,,,,,,,,,,,,,,,, -121700,0.22903812,1.4557052,,,,,,,,,,,,,,,,, -121800,0.2351552,1.4909413,,,,,,,,,,,,,,,,, -121900,0.2336245,1.5196059,,,,,,,,,,,,,,,,, -121974,,,0.6915479302406311,1.388578176498413,35.36310249116494,0.6942381262779236,1.3653781414031982,30.68772665448068,3000.0,0.7125210762023926,1.2577069997787476,30.773637579945547,3003.0,43716.03023672104,73805.94238114357,43716.03023672104,30084.16514778137,1.832347393035889,0.0 -122000,0.24402435,1.4478855,,,,,,,,,,,,,,,,, -122100,0.22660361,1.3921736,,,,,,,,,,,,,,,,, -122200,0.23766302,1.4913889,,,,,,,,,,,,,,,,, -122300,0.24108678,1.4143661,,,,,,,,,,,,,,,,, -122400,0.24541205,1.4910551,,,,,,,,,,,,,,,,, -122500,0.2374271,1.4396708,,,,,,,,,,,,,,,,, -122600,0.22953111,1.4890759,,,,,,,,,,,,,,,,, -122700,0.22208145,1.4021947,,,,,,,,,,,,,,,,, -122800,0.23151219,1.5051887,,,,,,,,,,,,,,,,, -122900,0.24962352,1.4736578,,,,,,,,,,,,,,,,, -123000,0.25214326,1.5029778,,,,,,,,,,,,,,,,, -123100,0.23119466,1.4134061,,,,,,,,,,,,,,,,, -123200,0.24402782,1.5191435,,,,,,,,,,,,,,,,, -123300,0.23369162,1.431765,,,,,,,,,,,,,,,,, -123400,0.23553851,1.4651467,,,,,,,,,,,,,,,,, -123500,0.2454029,1.488551,,,,,,,,,,,,,,,,, -123600,0.23845057,1.4849594,,,,,,,,,,,,,,,,, -123700,0.24027278,1.5052052,,,,,,,,,,,,,,,,, -123800,0.23252039,1.4060457,,,,,,,,,,,,,,,,, -123900,0.24130283,1.4614906,,,,,,,,,,,,,,,,, -124000,0.23167793,1.4639057,,,,,,,,,,,,,,,,, -124100,0.23239812,1.479845,,,,,,,,,,,,,,,,, -124200,0.22272113,1.4046701,,,,,,,,,,,,,,,,, -124300,0.23209675,1.4974681,,,,,,,,,,,,,,,,, -124318,,,0.6958664655685425,1.3725371360778809,35.70916601186907,0.6946473121643066,1.364732027053833,30.91054033191716,3000.0,0.7131369709968567,1.2557777166366575,30.84595524052872,3003.0,44556.06577014923,75262.98744797707,44556.06577014923,30701.053248643875,1.876137018203736,0.0 -124400,0.23403926,1.4606553,,,,,,,,,,,,,,,,, -124500,0.22235827,1.4585434,,,,,,,,,,,,,,,,, -124600,0.24040341,1.4192528,,,,,,,,,,,,,,,,, -124700,0.23791525,1.4129708,,,,,,,,,,,,,,,,, -124800,0.24043691,1.507719,,,,,,,,,,,,,,,,, -124900,0.23341472,1.4473294,,,,,,,,,,,,,,,,, -125000,0.23579091,1.4501146,,,,,,,,,,,,,,,,, -125100,0.23942418,1.5217501,,,,,,,,,,,,,,,,, -125200,0.2331575,1.4301556,,,,,,,,,,,,,,,,, -125300,0.23973915,1.4497246,,,,,,,,,,,,,,,,, -125400,0.2398029,1.5171161,,,,,,,,,,,,,,,,, -125500,0.23332955,1.4922832,,,,,,,,,,,,,,,,, -125600,0.22865798,1.4512935,,,,,,,,,,,,,,,,, -125700,0.23508066,1.4660505,,,,,,,,,,,,,,,,, -125800,0.24730262,1.4659498,,,,,,,,,,,,,,,,, -125900,0.23264137,1.3490759,,,,,,,,,,,,,,,,, -126000,0.24251801,1.4943436,,,,,,,,,,,,,,,,, -126100,0.2234755,1.4812621,,,,,,,,,,,,,,,,, -126200,0.23583336,1.3957599,,,,,,,,,,,,,,,,, -126300,0.24912585,1.5136313,,,,,,,,,,,,,,,,, -126400,0.2333872,1.475882,,,,,,,,,,,,,,,,, -126500,0.23112905,1.3878393,,,,,,,,,,,,,,,,, -126600,0.24432084,1.4814276,,,,,,,,,,,,,,,,, -126663,,,0.6963878273963928,1.3677339553833008,35.3685362955417,0.6950316429138184,1.3643397092819214,30.77219468305123,3000.0,0.7128929495811462,1.2559876441955566,30.87501995821473,3003.0,45396.00702667236,76722.7999753952,45396.00702667236,31320.80184316635,1.9225962162017824,0.0 -126700,0.22516921,1.3928032,,,,,,,,,,,,,,,,, -126800,0.2343817,1.4464201,,,,,,,,,,,,,,,,, -126900,0.24128747,1.5206759,,,,,,,,,,,,,,,,, -127000,0.2294833,1.410466,,,,,,,,,,,,,,,,, -127100,0.23871407,1.4773302,,,,,,,,,,,,,,,,, -127200,0.23948745,1.4939637,,,,,,,,,,,,,,,,, -127300,0.24233745,1.4035499,,,,,,,,,,,,,,,,, -127400,0.25172332,1.4253914,,,,,,,,,,,,,,,,, -127500,0.2388356,1.4832178,,,,,,,,,,,,,,,,, -127600,0.23887822,1.5033064,,,,,,,,,,,,,,,,, -127700,0.23040986,1.5296183,,,,,,,,,,,,,,,,, -127800,0.23408146,1.4497333,,,,,,,,,,,,,,,,, -127900,0.22866136,1.4984776,,,,,,,,,,,,,,,,, -128000,0.23864982,1.4415188,,,,,,,,,,,,,,,,, -128100,0.23133658,1.5636085,,,,,,,,,,,,,,,,, -128200,0.23688601,1.4790673,,,,,,,,,,,,,,,,, -128300,0.2341617,1.4988688,,,,,,,,,,,,,,,,, -128400,0.23572648,1.4232867,,,,,,,,,,,,,,,,, -128500,0.23276354,1.4115652,,,,,,,,,,,,,,,,, -128600,0.2247555,1.453373,,,,,,,,,,,,,,,,, -128700,0.22669266,1.4663815,,,,,,,,,,,,,,,,, -128800,0.24337028,1.4871867,,,,,,,,,,,,,,,,, -128900,0.22911456,1.4571748,,,,,,,,,,,,,,,,, -129000,0.24391016,1.4484441,,,,,,,,,,,,,,,,, -129010,,,0.6947799324989319,1.3747010231018066,35.44969392992167,0.6953044533729553,1.3621292114257812,30.786053998723386,3000.0,0.7133809924125671,1.2536063194274902,30.837718893970138,3003.0,46236.22338271141,78223.86001873016,46236.22338271141,31981.52465200424,1.9672255516052248,0.0 -129100,0.231198,1.5257868,,,,,,,,,,,,,,,,, -129200,0.23917834,1.5006934,,,,,,,,,,,,,,,,, -129300,0.2396898,1.5363959,,,,,,,,,,,,,,,,, -129400,0.25180173,1.5172535,,,,,,,,,,,,,,,,, -129500,0.23021491,1.4313586,,,,,,,,,,,,,,,,, -129600,0.22946022,1.5167185,,,,,,,,,,,,,,,,, -129700,0.2295412,1.530685,,,,,,,,,,,,,,,,, -129800,0.23333474,1.4507636,,,,,,,,,,,,,,,,, -129900,0.23380363,1.4427872,,,,,,,,,,,,,,,,, -130000,0.22633617,1.4700953,,,,,,,,,,,,,,,,, -130100,0.24452752,1.5061978,,,,,,,,,,,,,,,,, -130200,0.22463976,1.4833918,,,,,,,,,,,,,,,,, -130300,0.2407794,1.4534472,,,,,,,,,,,,,,,,, -130400,0.22640614,1.3911986,,,,,,,,,,,,,,,,, -130500,0.23335373,1.4557195,,,,,,,,,,,,,,,,, -130600,0.24072205,1.3988674,,,,,,,,,,,,,,,,, -130700,0.23806164,1.4710193,,,,,,,,,,,,,,,,, -130800,0.2432479,1.5066171,,,,,,,,,,,,,,,,, -130900,0.24282037,1.3695887,,,,,,,,,,,,,,,,, -131000,0.24300234,1.388023,,,,,,,,,,,,,,,,, -131100,0.22458771,1.3993438,,,,,,,,,,,,,,,,, -131200,0.23947643,1.4638684,,,,,,,,,,,,,,,,, -131300,0.22734816,1.471827,,,,,,,,,,,,,,,,, -131357,,,0.6980651617050171,1.3623944520950315,35.745533977811,0.6951928734779358,1.3623720407485962,30.81682333868408,3000.0,0.7132763862609863,1.2537087202072144,30.857126776061257,3003.0,47076.33667135239,79719.39379668236,47076.33667135239,32636.82600307465,2.012143135070801,0.0 -131400,0.23556186,1.4829019,,,,,,,,,,,,,,,,, -131500,0.23309986,1.4772606,,,,,,,,,,,,,,,,, -131600,0.24226202,1.4369755,,,,,,,,,,,,,,,,, -131700,0.23124965,1.4703295,,,,,,,,,,,,,,,,, -131800,0.23734821,1.5133003,,,,,,,,,,,,,,,,, -131900,0.23060772,1.4255548,,,,,,,,,,,,,,,,, -132000,0.22255813,1.4377238,,,,,,,,,,,,,,,,, -132100,0.24007018,1.4913591,,,,,,,,,,,,,,,,, -132200,0.24552457,1.4826179,,,,,,,,,,,,,,,,, -132300,0.23786439,1.471962,,,,,,,,,,,,,,,,, -132400,0.23707321,1.5127589,,,,,,,,,,,,,,,,, -132500,0.23154138,1.4090691,,,,,,,,,,,,,,,,, -132600,0.24261683,1.494104,,,,,,,,,,,,,,,,, -132700,0.22968712,1.5253218,,,,,,,,,,,,,,,,, -132800,0.22074147,1.4171988,,,,,,,,,,,,,,,,, -132900,0.22354947,1.4206939,,,,,,,,,,,,,,,,, -133000,0.23782404,1.4628651,,,,,,,,,,,,,,,,, -133100,0.23701704,1.4334042,,,,,,,,,,,,,,,,, -133200,0.23096658,1.3800336,,,,,,,,,,,,,,,,, -133300,0.23601046,1.5506622,,,,,,,,,,,,,,,,, -133333,,,0.6977786421775818,1.358901858329773,35.85408950071874,0.6951680779457092,1.3623775243759155,30.78454367159104,3000.0,0.7132763862609863,1.253868579864502,30.85216689299744,3003.0,47783.83299946785,81060.76957821846,47783.83299946785,33270.594187021255,2.060373544692993,0.0 -133333,,,,,,,,,,,,,,47783.83299946785,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 9cd4df208..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -770.6096827983856,0.0,20.892271518707275,1,0,20.892271518707275,1.5194805184210527,95000000,791.5020091533661,1.519693479597943,1.5212185261462023,83274637 -1402.949651002884,0.0265877246856689,1221.243737697601,1623,0,1221.243737697601,0.1287275309518914,95000000,2624.271951198578,0.1242260946042882,0.126126404030062,83274637 -1947.666404247284,0.0520765781402587,2421.242283344269,3236,0,2421.242283344269,0.1276052690583881,95000000,4369.06401515007,0.1254858712917604,0.1251327878424135,83274637 -2479.767473459244,0.0789272785186767,3621.582787752152,4855,0,3621.582787752152,0.1270578274876644,95000000,6101.583913326263,0.1232307549074011,0.1247294805830255,83274637 -2976.7446382045746,0.1071887016296386,4821.726407766342,6483,0,4821.726407766342,0.1265536055407072,95000000,7798.784013032913,0.1239696584126484,0.1241865348903442,83274637 -3418.107749462128,0.1339259147644043,6021.679481744766,8093,0,6021.679481744766,0.1263109411800987,95000000,9440.177884578705,0.1217597047769048,0.1239406133459038,83274637 -3735.948078393936,0.15744614601135254,7221.771959543228,9706,0,7221.771959543228,0.12604977984169408,95000000,10958.185050487518,0.12191009296561187,0.123722453762017,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 176ae62ba..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,113 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,9.0239935,1.5136962,,,,,,,,,,, -1,,,1.519693479597943,1.5212185261462023,83274637.0,1.5194805184210527,95000000.0,20.892271518707275,791.5020091533661,20.892271518707275,770.6096827983856,0.0,0.0 -100,0.59149855,0.1703465,,,,,,,,,,, -200,0.04087948,0.13404705,,,,,,,,,,, -300,0.021523552,0.14619005,,,,,,,,,,, -400,0.03592085,0.12268497,,,,,,,,,,, -500,0.009572293,0.11867895,,,,,,,,,,, -600,0.05511384,0.1342735,,,,,,,,,,, -700,0.03770047,0.13297136,,,,,,,,,,, -800,0.013373003,0.12577246,,,,,,,,,,, -900,0.009553156,0.12183991,,,,,,,,,,, -1000,0.02430809,0.11983143,,,,,,,,,,, -1100,0.011022753,0.13396275,,,,,,,,,,, -1200,0.03245219,0.122072935,,,,,,,,,,, -1300,0.017761579,0.12638302,,,,,,,,,,, -1400,0.012035799,0.127795,,,,,,,,,,, -1500,0.015720312,0.12385158,,,,,,,,,,, -1600,0.03990823,0.12158015,,,,,,,,,,, -1623,,,0.1242260946042882,0.126126404030062,83274637.0,0.1287275309518914,95000000.0,1221.243737697601,2624.271951198578,1221.243737697601,1402.949651002884,0.0265877246856689,0.0 -1700,0.03195682,0.13137847,,,,,,,,,,, -1800,0.025007427,0.13372931,,,,,,,,,,, -1900,0.02280598,0.12515974,,,,,,,,,,, -2000,0.026964666,0.12542129,,,,,,,,,,, -2100,0.026957298,0.13155656,,,,,,,,,,, -2200,0.018338647,0.13025206,,,,,,,,,,, -2300,0.031066718,0.121651344,,,,,,,,,,, -2400,0.00938783,0.121805,,,,,,,,,,, -2500,0.010528534,0.12850595,,,,,,,,,,, -2600,0.0068483697,0.1222548,,,,,,,,,,, -2700,0.0065244962,0.13450946,,,,,,,,,,, -2800,0.016117522,0.12625737,,,,,,,,,,, -2900,0.0334901,0.13021202,,,,,,,,,,, -3000,0.013702958,0.12835746,,,,,,,,,,, -3100,0.03856695,0.13027631,,,,,,,,,,, -3200,0.006470311,0.12004549,,,,,,,,,,, -3236,,,0.1254858712917604,0.1251327878424135,83274637.0,0.1276052690583881,95000000.0,2421.242283344269,4369.06401515007,2421.242283344269,1947.666404247284,0.0520765781402587,0.0 -3300,0.020194132,0.119127356,,,,,,,,,,, -3400,0.008855129,0.11943017,,,,,,,,,,, -3500,0.009954478,0.119130105,,,,,,,,,,, -3600,0.012218357,0.12673001,,,,,,,,,,, -3700,0.032219823,0.12077761,,,,,,,,,,, -3800,0.010829806,0.11760385,,,,,,,,,,, -3900,0.024559315,0.12997384,,,,,,,,,,, -4000,0.014479304,0.12122216,,,,,,,,,,, -4100,0.01986222,0.13290569,,,,,,,,,,, -4200,0.006717561,0.11834721,,,,,,,,,,, -4300,0.009315124,0.1294225,,,,,,,,,,, -4400,0.0068176277,0.12510152,,,,,,,,,,, -4500,0.02488536,0.12558384,,,,,,,,,,, -4600,0.013478547,0.11703863,,,,,,,,,,, -4700,0.010237155,0.12187064,,,,,,,,,,, -4800,0.0095921485,0.11672053,,,,,,,,,,, -4855,,,0.1232307549074011,0.1247294805830255,83274637.0,0.1270578274876644,95000000.0,3621.582787752152,6101.583913326263,3621.582787752152,2479.767473459244,0.0789272785186767,0.0 -4900,0.026922015,0.1269814,,,,,,,,,,, -5000,0.00815737,0.114370674,,,,,,,,,,, -5100,0.008869592,0.122818425,,,,,,,,,,, -5200,0.008722714,0.12702192,,,,,,,,,,, -5300,0.03836268,0.12062959,,,,,,,,,,, -5400,0.014714242,0.124275826,,,,,,,,,,, -5500,0.0069729434,0.11983302,,,,,,,,,,, -5600,0.013342663,0.12564726,,,,,,,,,,, -5700,0.008813415,0.12105765,,,,,,,,,,, -5800,0.008987973,0.12224325,,,,,,,,,,, -5900,0.008721336,0.12186219,,,,,,,,,,, -6000,0.008517333,0.1277755,,,,,,,,,,, -6100,0.012254136,0.12728721,,,,,,,,,,, -6200,0.007195817,0.116240196,,,,,,,,,,, -6300,0.008680577,0.127293,,,,,,,,,,, -6400,0.008285586,0.118738115,,,,,,,,,,, -6483,,,0.1239696584126484,0.1241865348903442,83274637.0,0.1265536055407072,95000000.0,4821.726407766342,7798.784013032913,4821.726407766342,2976.7446382045746,0.1071887016296386,0.0 -6500,0.011000238,0.116373375,,,,,,,,,,, -6600,0.012804175,0.1211956,,,,,,,,,,, -6700,0.009167755,0.122268915,,,,,,,,,,, -6800,0.020086858,0.12017112,,,,,,,,,,, -6900,0.0070550824,0.123872876,,,,,,,,,,, -7000,0.008102303,0.12027957,,,,,,,,,,, -7100,0.0067004887,0.12688595,,,,,,,,,,, -7200,0.0106264725,0.13232486,,,,,,,,,,, -7300,0.007750937,0.11613925,,,,,,,,,,, -7400,0.013071285,0.11827241,,,,,,,,,,, -7500,0.013380685,0.12251297,,,,,,,,,,, -7600,0.011878159,0.118871026,,,,,,,,,,, -7700,0.018131997,0.119680986,,,,,,,,,,, -7800,0.007953035,0.118180044,,,,,,,,,,, -7900,0.010707397,0.12762606,,,,,,,,,,, -8000,0.009229192,0.13282457,,,,,,,,,,, -8093,,,0.1217597047769048,0.1239406133459038,83274637.0,0.1263109411800987,95000000.0,6021.679481744766,9440.177884578705,6021.679481744766,3418.107749462128,0.1339259147644043,0.0 -8100,0.019505901,0.12207686,,,,,,,,,,, -8200,0.007951799,0.12521584,,,,,,,,,,, -8300,0.013263588,0.12134315,,,,,,,,,,, -8400,0.01783518,0.1193119,,,,,,,,,,, -8500,0.011419841,0.12251294,,,,,,,,,,, -8600,0.019996384,0.119056664,,,,,,,,,,, -8700,0.00727815,0.12429921,,,,,,,,,,, -8800,0.012584212,0.1149823,,,,,,,,,,, -8900,0.033738945,0.12326522,,,,,,,,,,, -9000,0.008871404,0.122448474,,,,,,,,,,, -9100,0.01345751,0.12050012,,,,,,,,,,, -9200,0.017661583,0.120287046,,,,,,,,,,, -9300,0.018736262,0.1252799,,,,,,,,,,, -9400,0.007191836,0.122741856,,,,,,,,,,, -9500,0.011903811,0.12667638,,,,,,,,,,, -9600,0.008293518,0.11927164,,,,,,,,,,, -9700,0.008751279,0.12274052,,,,,,,,,,, -9706,,,0.1219100929656118,0.123722453762017,83274637.0,0.126049779841694,95000000.0,7221.771959543228,10958.185050487518,7221.771959543228,3735.948078393936,0.1574461460113525,0.0 -9800,0.013005114,0.124885745,,,,,,,,,,, -9900,0.01358484,0.12052819,,,,,,,,,,, -10000,0.02631137,0.12445898,,,,,,,,,,, -10100,0.016628956,0.11980736,,,,,,,,,,, -10200,0.025504725,0.13037999,,,,,,,,,,, -10300,0.015183276,0.12677215,,,,,,,,,,, -10363,,,,,,,,7703.379245758057,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/eval_measurements.csv deleted file mode 100644 index cb7ad9c4e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -159.1849820613861,0.0,5.877012491226196,1,0,5.877012491226196,1.5194805184210527,95000000,165.06203317642212,1.5198267240944148,1.5212185261462023,83274637 -183.4349095821381,0.0187382698059082,1206.2299313545227,1441,0,1206.2299313545227,0.1293855753186677,95000000,1389.7307302951813,0.1235415176869188,0.1271633574300359,83274637 -207.0030398368836,0.0434603691101074,2406.5499982833862,2903,0,2406.5499982833862,0.1287444120682565,95000000,2613.691510438919,0.1266419379122602,0.1266030428628587,83274637 -229.8640608787537,0.0680716037750244,3606.9658749103546,4372,0,3606.9658749103546,0.127792316899671,95000000,3837.0405600070953,0.1236161842661083,0.1252928578699478,83274637 -253.03732204437256,0.0901024341583252,4807.48178434372,5847,0,4807.48178434372,0.1266098435032894,95000000,5060.800463438034,0.1234146692561653,0.1242218719364546,83274637 -276.6709907054901,0.1133763790130615,6008.018783092499,7325,0,6008.018783092499,0.1265671436060855,95000000,6285.041886091232,0.1219054411792155,0.1241356377437586,83274637 -300.1837706565857,0.13881635665893555,7208.39083981514,8800,0,7208.39083981514,0.12635513429276315,95000000,7508.999789476395,0.1216050774421332,0.12398396416156097,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/measurements.csv deleted file mode 100644 index ed324dada..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/measurements.csv +++ /dev/null @@ -1,103 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,9.141044,1.522038,,,,,,,,,,, -1,,,1.5198267240944148,1.5212185261462023,83274637.0,1.5194805184210527,95000000.0,5.877012491226196,165.06203317642212,5.877012491226196,159.1849820613861,0.0,0.0 -100,0.029509636,0.12898318,,,,,,,,,,, -200,0.16108988,0.1222863,,,,,,,,,,, -300,0.15741962,0.12679121,,,,,,,,,,, -400,0.05136266,0.1268577,,,,,,,,,,, -500,0.022103881,0.12500967,,,,,,,,,,, -600,0.06452728,0.12308533,,,,,,,,,,, -700,0.021651246,0.12363692,,,,,,,,,,, -800,0.16871984,0.12561443,,,,,,,,,,, -900,0.059794873,0.1193475,,,,,,,,,,, -1000,0.06656663,0.12685907,,,,,,,,,,, -1100,0.059472706,0.12826228,,,,,,,,,,, -1200,0.09216242,0.12615025,,,,,,,,,,, -1300,0.031314902,0.13133965,,,,,,,,,,, -1400,0.08692165,0.12122017,,,,,,,,,,, -1441,,,0.1235415176869188,0.1271633574300359,83274637.0,0.1293855753186677,95000000.0,1206.2299313545227,1389.7307302951813,1206.2299313545227,183.4349095821381,0.0187382698059082,0.0 -1500,0.025581548,0.12067516,,,,,,,,,,, -1600,0.17085254,0.1329253,,,,,,,,,,, -1700,0.04083709,0.12270526,,,,,,,,,,, -1800,0.008738771,0.12800668,,,,,,,,,,, -1900,0.046174258,0.13082995,,,,,,,,,,, -2000,0.016886167,0.122373454,,,,,,,,,,, -2100,0.018704599,0.12335165,,,,,,,,,,, -2200,0.005856851,0.119661905,,,,,,,,,,, -2300,0.08850847,0.13011678,,,,,,,,,,, -2400,0.024382558,0.13192332,,,,,,,,,,, -2500,0.014884037,0.12971918,,,,,,,,,,, -2600,0.020081418,0.121881224,,,,,,,,,,, -2700,0.05044572,0.13192363,,,,,,,,,,, -2800,0.032977186,0.12067364,,,,,,,,,,, -2900,0.14176668,0.13593026,,,,,,,,,,, -2903,,,0.1266419379122602,0.1266030428628587,83274637.0,0.1287444120682565,95000000.0,2406.5499982833862,2613.691510438919,2406.5499982833862,207.0030398368836,0.0434603691101074,0.0 -3000,0.023513595,0.122455366,,,,,,,,,,, -3100,0.010552226,0.12366609,,,,,,,,,,, -3200,0.015673682,0.11809164,,,,,,,,,,, -3300,0.06736789,0.12010399,,,,,,,,,,, -3400,0.019637454,0.11710611,,,,,,,,,,, -3500,0.026441215,0.12128248,,,,,,,,,,, -3600,0.036454596,0.120048754,,,,,,,,,,, -3700,0.02028305,0.12850502,,,,,,,,,,, -3800,0.00985035,0.12691496,,,,,,,,,,, -3900,0.06523962,0.121429734,,,,,,,,,,, -4000,0.017974233,0.11903073,,,,,,,,,,, -4100,0.008326593,0.11969763,,,,,,,,,,, -4200,0.028637735,0.13343042,,,,,,,,,,, -4300,0.042858537,0.11929269,,,,,,,,,,, -4372,,,0.1236161842661083,0.1252928578699478,83274637.0,0.127792316899671,95000000.0,3606.9658749103546,3837.0405600070953,3606.9658749103546,229.8640608787537,0.0680716037750244,0.0 -4400,0.01097906,0.12173028,,,,,,,,,,, -4500,0.007891063,0.117966995,,,,,,,,,,, -4600,0.005846583,0.11575361,,,,,,,,,,, -4700,0.03529255,0.11872265,,,,,,,,,,, -4800,0.019800492,0.12481357,,,,,,,,,,, -4900,0.03536206,0.12239019,,,,,,,,,,, -5000,0.008775658,0.124236345,,,,,,,,,,, -5100,0.017312368,0.11828873,,,,,,,,,,, -5200,0.030533997,0.11988267,,,,,,,,,,, -5300,0.026537478,0.1253702,,,,,,,,,,, -5400,0.009561846,0.12137674,,,,,,,,,,, -5500,0.014828515,0.124845885,,,,,,,,,,, -5600,0.010699857,0.118921176,,,,,,,,,,, -5700,0.017329592,0.124656126,,,,,,,,,,, -5800,0.015498795,0.119325444,,,,,,,,,,, -5847,,,0.1234146692561653,0.1242218719364546,83274637.0,0.1266098435032894,95000000.0,4807.48178434372,5060.800463438034,4807.48178434372,253.03732204437256,0.0901024341583252,0.0 -5900,0.018098323,0.12174201,,,,,,,,,,, -6000,0.047630414,0.12585321,,,,,,,,,,, -6100,0.005875857,0.12588681,,,,,,,,,,, -6200,0.016659914,0.12095709,,,,,,,,,,, -6300,0.006972202,0.12697639,,,,,,,,,,, -6400,0.008517646,0.118689746,,,,,,,,,,, -6500,0.009840013,0.1218954,,,,,,,,,,, -6600,0.012757383,0.12390421,,,,,,,,,,, -6700,0.007750559,0.11755807,,,,,,,,,,, -6800,0.018217368,0.11514546,,,,,,,,,,, -6900,0.019511223,0.12425683,,,,,,,,,,, -7000,0.023977311,0.12543635,,,,,,,,,,, -7100,0.007794371,0.12314555,,,,,,,,,,, -7200,0.006280556,0.119545005,,,,,,,,,,, -7300,0.0122814225,0.122455455,,,,,,,,,,, -7325,,,0.1219054411792155,0.1241356377437586,83274637.0,0.1265671436060855,95000000.0,6008.018783092499,6285.041886091232,6008.018783092499,276.6709907054901,0.1133763790130615,0.0 -7400,0.019129666,0.12135615,,,,,,,,,,, -7500,0.010784853,0.12741223,,,,,,,,,,, -7600,0.008081019,0.1184124,,,,,,,,,,, -7700,0.014115081,0.12963015,,,,,,,,,,, -7800,0.013089507,0.12927033,,,,,,,,,,, -7900,0.028797077,0.11962779,,,,,,,,,,, -8000,0.007995351,0.12213321,,,,,,,,,,, -8100,0.0062822173,0.116837524,,,,,,,,,,, -8200,0.0065623317,0.11460176,,,,,,,,,,, -8300,0.0060914773,0.12163823,,,,,,,,,,, -8400,0.008659464,0.12170176,,,,,,,,,,, -8500,0.008099075,0.118481606,,,,,,,,,,, -8600,0.008841545,0.119055465,,,,,,,,,,, -8700,0.006265677,0.12477807,,,,,,,,,,, -8800,,,0.1216050774421332,0.1239839641615609,83274637.0,0.1263551342927631,95000000.0,7208.39083981514,7508.999789476395,7208.39083981514,300.1837706565857,0.1388163566589355,0.0 -8800,0.0101979235,0.12692174,,,,,,,,,,, -8900,0.0069789784,0.11799255,,,,,,,,,,, -9000,0.010040322,0.122477934,,,,,,,,,,, -9100,0.007137316,0.12703854,,,,,,,,,,, -9200,0.007042263,0.12285716,,,,,,,,,,, -9300,0.009686956,0.123919636,,,,,,,,,,, -9353,,,,,,,,7703.8057949543,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/eval_measurements.csv deleted file mode 100644 index cfc3a5391..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -23.47407269477844,0.0,5.458779335021973,1,0,5.458779335021973,1.5194805184210527,95000000,28.932892322540283,1.5213117389558997,1.5212185261462023,83274637 -46.20129990577698,0.0184102058410644,1205.438355922699,1492,0,1205.438355922699,0.1282428620682565,95000000,1251.7059638500214,0.1246110814266234,0.1259296333866036,83274637 -68.91265916824341,0.0441446304321289,2405.723908662796,2968,0,2405.723908662796,0.1277512381373355,95000000,2474.776403188705,0.1217813355461606,0.1253189053416408,83274637 -91.86134266853333,0.0815405845642089,3606.100022315979,4452,0,3606.100022315979,0.1268565652652138,95000000,3698.186697721481,0.1252774153865358,0.1246079937786909,83274637 -114.52473449707033,0.1039731502532959,4806.658087968826,5929,0,4806.658087968826,0.1267495561369243,95000000,4921.478246450424,0.1227557818938351,0.1244513198932497,83274637 -137.27002453804016,0.1286251544952392,6006.600213289261,7390,0,6006.600213289261,0.1264927567125822,95000000,6144.236940383911,0.1221328321492896,0.1241153744086569,83274637 -163.1682367324829,0.1530437469482422,7206.640828609467,8856,0,7206.640828609467,0.12617349113898027,95000000,7370.247405529022,0.12047344727336236,0.12387181979700494,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/measurements.csv deleted file mode 100644 index 3ad36c178..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/measurements.csv +++ /dev/null @@ -1,103 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.940546,1.5185663,,,,,,,,,,, -1,,,1.5213117389558997,1.5212185261462023,83274637.0,1.5194805184210527,95000000.0,5.458779335021973,28.932892322540283,5.458779335021973,23.47407269477844,0.0,0.0 -100,0.42423666,0.16428328,,,,,,,,,,, -200,0.029830093,0.1272587,,,,,,,,,,, -300,0.038851753,0.12128369,,,,,,,,,,, -400,0.011392353,0.12422067,,,,,,,,,,, -500,0.018292004,0.12449877,,,,,,,,,,, -600,0.053326566,0.12571418,,,,,,,,,,, -700,0.09233309,0.1337107,,,,,,,,,,, -800,0.017438654,0.12779509,,,,,,,,,,, -900,0.03641839,0.13114932,,,,,,,,,,, -1000,0.008883786,0.1261205,,,,,,,,,,, -1100,0.021566646,0.118915394,,,,,,,,,,, -1200,0.015532701,0.1312403,,,,,,,,,,, -1300,0.049245767,0.12842114,,,,,,,,,,, -1400,0.030681847,0.12506728,,,,,,,,,,, -1492,,,0.1246110814266234,0.1259296333866036,83274637.0,0.1282428620682565,95000000.0,1205.438355922699,1251.7059638500214,1205.438355922699,46.20129990577698,0.0184102058410644,0.0 -1500,0.038815387,0.12918521,,,,,,,,,,, -1600,0.007928085,0.12504672,,,,,,,,,,, -1700,0.025933279,0.12186091,,,,,,,,,,, -1800,0.008465162,0.124234736,,,,,,,,,,, -1900,0.010194955,0.12743565,,,,,,,,,,, -2000,0.013820967,0.12577355,,,,,,,,,,, -2100,0.03707024,0.123233676,,,,,,,,,,, -2200,0.011483325,0.12168202,,,,,,,,,,, -2300,0.0077533955,0.11896754,,,,,,,,,,, -2400,0.010134891,0.124017484,,,,,,,,,,, -2500,0.028288875,0.12432,,,,,,,,,,, -2600,0.032892574,0.11870153,,,,,,,,,,, -2700,0.025055464,0.12450225,,,,,,,,,,, -2800,0.030068146,0.13703676,,,,,,,,,,, -2900,0.016364375,0.11324992,,,,,,,,,,, -2968,,,0.1217813355461606,0.1253189053416408,83274637.0,0.1277512381373355,95000000.0,2405.723908662796,2474.776403188705,2405.723908662796,68.91265916824341,0.0441446304321289,0.0 -3000,0.04027376,0.12175454,,,,,,,,,,, -3100,0.017414184,0.12067249,,,,,,,,,,, -3200,0.025671285,0.1186437,,,,,,,,,,, -3300,0.033329338,0.12136285,,,,,,,,,,, -3400,0.011007662,0.11594242,,,,,,,,,,, -3500,0.039957684,0.11889977,,,,,,,,,,, -3600,0.010171148,0.12036586,,,,,,,,,,, -3700,0.008768076,0.12586772,,,,,,,,,,, -3800,0.009480739,0.119772084,,,,,,,,,,, -3900,0.018976681,0.12818854,,,,,,,,,,, -4000,0.0073646144,0.13338082,,,,,,,,,,, -4100,0.021816414,0.12161431,,,,,,,,,,, -4200,0.024112197,0.11944003,,,,,,,,,,, -4300,0.0070140306,0.124862365,,,,,,,,,,, -4400,0.0067045754,0.11568033,,,,,,,,,,, -4452,,,0.1252774153865358,0.1246079937786909,83274637.0,0.1268565652652138,95000000.0,3606.100022315979,3698.186697721481,3606.100022315979,91.86134266853333,0.0815405845642089,0.0 -4500,0.007038383,0.124562584,,,,,,,,,,, -4600,0.020638265,0.12487332,,,,,,,,,,, -4700,0.020351037,0.12251453,,,,,,,,,,, -4800,0.023124374,0.12250568,,,,,,,,,,, -4900,0.014623508,0.1172962,,,,,,,,,,, -5000,0.0070752623,0.13285646,,,,,,,,,,, -5100,0.025643012,0.12028953,,,,,,,,,,, -5200,0.022638634,0.11903548,,,,,,,,,,, -5300,0.0063002217,0.1168968,,,,,,,,,,, -5400,0.024137056,0.12164642,,,,,,,,,,, -5500,0.013465724,0.12345033,,,,,,,,,,, -5600,0.026758421,0.123027466,,,,,,,,,,, -5700,0.013569683,0.11983955,,,,,,,,,,, -5800,0.022563944,0.115497336,,,,,,,,,,, -5900,0.011300475,0.12557617,,,,,,,,,,, -5929,,,0.1227557818938351,0.1244513198932497,83274637.0,0.1267495561369243,95000000.0,4806.658087968826,4921.478246450424,4806.658087968826,114.52473449707033,0.1039731502532959,0.0 -6000,0.019101653,0.122896075,,,,,,,,,,, -6100,0.018793043,0.13892554,,,,,,,,,,, -6200,0.016119407,0.1251545,,,,,,,,,,, -6300,0.0067649917,0.1251883,,,,,,,,,,, -6400,0.018077059,0.12586333,,,,,,,,,,, -6500,0.0072056763,0.12781656,,,,,,,,,,, -6600,0.021925278,0.12243565,,,,,,,,,,, -6700,0.008675363,0.12134902,,,,,,,,,,, -6800,0.011822452,0.1216983,,,,,,,,,,, -6900,0.010866141,0.12149359,,,,,,,,,,, -7000,0.009062803,0.11821966,,,,,,,,,,, -7100,0.008510246,0.11680034,,,,,,,,,,, -7200,0.009611678,0.1240183,,,,,,,,,,, -7300,0.010121631,0.13322747,,,,,,,,,,, -7390,,,0.1221328321492896,0.1241153744086569,83274637.0,0.1264927567125822,95000000.0,6006.600213289261,6144.236940383911,6006.600213289261,137.27002453804016,0.1286251544952392,0.0 -7400,0.0074158437,0.12179398,,,,,,,,,,, -7500,0.008030055,0.1280114,,,,,,,,,,, -7600,0.010911648,0.1346384,,,,,,,,,,, -7700,0.008433844,0.12688068,,,,,,,,,,, -7800,0.007538095,0.122660816,,,,,,,,,,, -7900,0.010470952,0.12239016,,,,,,,,,,, -8000,0.01690232,0.12721762,,,,,,,,,,, -8100,0.008596472,0.12176262,,,,,,,,,,, -8200,0.012390718,0.13535866,,,,,,,,,,, -8300,0.009677573,0.11919252,,,,,,,,,,, -8400,0.0079544885,0.11925681,,,,,,,,,,, -8500,0.0069633406,0.118404076,,,,,,,,,,, -8600,0.007858113,0.118908696,,,,,,,,,,, -8700,0.00744415,0.11500428,,,,,,,,,,, -8800,0.014732375,0.116726995,,,,,,,,,,, -8856,,,0.1204734472733623,0.1238718197970049,83274637.0,0.1261734911389802,95000000.0,7206.640828609467,7370.247405529022,7206.640828609467,163.1682367324829,0.1530437469482422,0.0 -8900,0.008766533,0.12900189,,,,,,,,,,, -9000,0.008402598,0.12663768,,,,,,,,,,, -9100,0.026436804,0.12884024,,,,,,,,,,, -9200,0.020576408,0.12263823,,,,,,,,,,, -9300,0.0061635803,0.11844325,,,,,,,,,,, -9395,,,,,,,,7703.400137901306,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/eval_measurements.csv deleted file mode 100644 index e02300999..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -23.48318362236023,0.0,10.086588859558104,1,0,10.086588859558104,1.5194805184210527,95000000,33.56980895996094,1.5201884942984432,1.5212185261462023,83274637 -46.37395644187927,0.0187075138092041,1210.5401918888092,1579,0,1210.5401918888092,0.1288543638980263,95000000,1256.9842765331268,0.1277041979274659,0.1264533591197161,83274637 -69.08802628517151,0.0517487525939941,2410.887097120285,3159,0,2410.887097120285,0.1284411678453947,95000000,2480.129752635956,0.1248658452206437,0.1261587597418767,83274637 -91.64909315109252,0.081200361251831,3611.227801322937,4743,0,3611.227801322937,0.1279277331003289,95000000,3703.112888813019,0.1244182016864512,0.1257831604518882,83274637 -114.43777704238892,0.1066071987152099,4811.301905870438,6324,0,4811.301905870438,0.1278399025801809,95000000,4926.051818847656,0.1229074633440131,0.1256336655136065,83274637 -137.12053418159485,0.1333255767822265,6011.683310031891,7911,0,6011.683310031891,0.12789141796875,95000000,6149.194463253021,0.1258539432812037,0.1255955362195094,83274637 -160.01790714263916,0.16205430030822754,7211.92481803894,9500,0,7211.92481803894,0.1276304653885691,95000000,7372.413366794586,0.12368268950742746,0.12541117594236706,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/measurements.csv deleted file mode 100644 index 547d5fe4a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/measurements.csv +++ /dev/null @@ -1,111 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.981047,1.5221735,,,,,,,,,,, -1,,,1.5201884942984432,1.5212185261462023,83274637.0,1.5194805184210527,95000000.0,10.086588859558104,33.56980895996094,10.086588859558104,23.48318362236023,0.0,0.0 -100,0.23051877,0.13221493,,,,,,,,,,, -200,0.14136243,0.13516688,,,,,,,,,,, -300,0.09770084,0.13236205,,,,,,,,,,, -400,0.04879078,0.13272418,,,,,,,,,,, -500,0.017638559,0.1269241,,,,,,,,,,, -600,0.012218685,0.12089262,,,,,,,,,,, -700,0.019378832,0.12812224,,,,,,,,,,, -800,0.0052524144,0.13203965,,,,,,,,,,, -900,0.03332424,0.12886988,,,,,,,,,,, -1000,0.034133337,0.1285809,,,,,,,,,,, -1100,0.034125924,0.14015102,,,,,,,,,,, -1200,0.054283805,0.1250805,,,,,,,,,,, -1300,0.040841777,0.12289792,,,,,,,,,,, -1400,0.0062227724,0.12123748,,,,,,,,,,, -1500,0.018354256,0.12275437,,,,,,,,,,, -1579,,,0.1277041979274659,0.1264533591197161,83274637.0,0.1288543638980263,95000000.0,1210.5401918888092,1256.9842765331268,1210.5401918888092,46.37395644187927,0.0187075138092041,0.0 -1600,0.026928073,0.12894799,,,,,,,,,,, -1700,0.03921423,0.12621975,,,,,,,,,,, -1800,0.031404268,0.13381925,,,,,,,,,,, -1900,0.010017296,0.12098771,,,,,,,,,,, -2000,0.031284254,0.13086535,,,,,,,,,,, -2100,0.04457005,0.12709957,,,,,,,,,,, -2200,0.020545214,0.13158232,,,,,,,,,,, -2300,0.05017474,0.124821916,,,,,,,,,,, -2400,0.034333702,0.13483042,,,,,,,,,,, -2500,0.027864918,0.11925297,,,,,,,,,,, -2600,0.050252583,0.12108474,,,,,,,,,,, -2700,0.01978998,0.120774165,,,,,,,,,,, -2800,0.038230084,0.120764226,,,,,,,,,,, -2900,0.014381608,0.12597527,,,,,,,,,,, -3000,0.057458606,0.1277229,,,,,,,,,,, -3100,0.037149645,0.11792576,,,,,,,,,,, -3159,,,0.1248658452206437,0.1261587597418767,83274637.0,0.1284411678453947,95000000.0,2410.887097120285,2480.129752635956,2410.887097120285,69.08802628517151,0.0517487525939941,0.0 -3200,0.058816142,0.13339113,,,,,,,,,,, -3300,0.07544793,0.13030338,,,,,,,,,,, -3400,0.038465943,0.11923201,,,,,,,,,,, -3500,0.011779609,0.1272596,,,,,,,,,,, -3600,0.03660985,0.13011757,,,,,,,,,,, -3700,0.07270223,0.12186109,,,,,,,,,,, -3800,0.039539993,0.11966619,,,,,,,,,,, -3900,0.02747423,0.11903575,,,,,,,,,,, -4000,0.048708845,0.13419212,,,,,,,,,,, -4100,0.07252225,0.12455134,,,,,,,,,,, -4200,0.0082644485,0.1241076,,,,,,,,,,, -4300,0.07299084,0.12073078,,,,,,,,,,, -4400,0.016562311,0.124637336,,,,,,,,,,, -4500,0.023966698,0.12517314,,,,,,,,,,, -4600,0.050815362,0.120225035,,,,,,,,,,, -4700,0.019153545,0.12619384,,,,,,,,,,, -4743,,,0.1244182016864512,0.1257831604518882,83274637.0,0.1279277331003289,95000000.0,3611.227801322937,3703.112888813019,3611.227801322937,91.64909315109252,0.081200361251831,0.0 -4800,0.051273674,0.12633452,,,,,,,,,,, -4900,0.00825519,0.12477049,,,,,,,,,,, -5000,0.042130265,0.12221112,,,,,,,,,,, -5100,0.024990344,0.12037953,,,,,,,,,,, -5200,0.035464615,0.124182,,,,,,,,,,, -5300,0.040430363,0.12436932,,,,,,,,,,, -5400,0.04231042,0.12363717,,,,,,,,,,, -5500,0.004310977,0.11371124,,,,,,,,,,, -5600,0.017832872,0.12272009,,,,,,,,,,, -5700,0.01163715,0.11888181,,,,,,,,,,, -5800,0.041959077,0.1211097,,,,,,,,,,, -5900,0.027332926,0.13830003,,,,,,,,,,, -6000,0.041133005,0.12074799,,,,,,,,,,, -6100,0.0139467325,0.1254148,,,,,,,,,,, -6200,0.012781473,0.12184929,,,,,,,,,,, -6300,0.0115625905,0.1264658,,,,,,,,,,, -6324,,,0.1229074633440131,0.1256336655136065,83274637.0,0.1278399025801809,95000000.0,4811.301905870438,4926.051818847656,4811.301905870438,114.43777704238892,0.1066071987152099,0.0 -6400,0.012414002,0.1251323,,,,,,,,,,, -6500,0.035561237,0.11872613,,,,,,,,,,, -6600,0.027760336,0.13606185,,,,,,,,,,, -6700,0.027895784,0.13324061,,,,,,,,,,, -6800,0.019138074,0.12952216,,,,,,,,,,, -6900,0.008442157,0.12168748,,,,,,,,,,, -7000,0.005535665,0.11928214,,,,,,,,,,, -7100,0.033871677,0.123562306,,,,,,,,,,, -7200,0.021305785,0.116445966,,,,,,,,,,, -7300,0.007715897,0.124291666,,,,,,,,,,, -7400,0.019473955,0.12847018,,,,,,,,,,, -7500,0.020462379,0.12874386,,,,,,,,,,, -7600,0.0065142005,0.12940535,,,,,,,,,,, -7700,0.0058914414,0.11429605,,,,,,,,,,, -7800,0.008974799,0.13390099,,,,,,,,,,, -7900,0.011898835,0.122231156,,,,,,,,,,, -7911,,,0.1258539432812037,0.1255955362195094,83274637.0,0.12789141796875,95000000.0,6011.683310031891,6149.194463253021,6011.683310031891,137.12053418159485,0.1333255767822265,0.0 -8000,0.017113924,0.12517564,,,,,,,,,,, -8100,0.0065169404,0.12163262,,,,,,,,,,, -8200,0.010368287,0.12055365,,,,,,,,,,, -8300,0.008870861,0.12713848,,,,,,,,,,, -8400,0.013552831,0.11944378,,,,,,,,,,, -8500,0.013031767,0.12044431,,,,,,,,,,, -8600,0.006035666,0.12172939,,,,,,,,,,, -8700,0.010368514,0.12223463,,,,,,,,,,, -8800,0.010484478,0.12534085,,,,,,,,,,, -8900,0.009311431,0.11927173,,,,,,,,,,, -9000,0.017366946,0.121520996,,,,,,,,,,, -9100,0.006620112,0.1201053,,,,,,,,,,, -9200,0.013876594,0.12912275,,,,,,,,,,, -9300,0.008475684,0.12109014,,,,,,,,,,, -9400,0.010075308,0.12827532,,,,,,,,,,, -9500,,,0.1236826895074274,0.125411175942367,83274637.0,0.1276304653885691,95000000.0,7211.92481803894,7372.413366794586,7211.92481803894,160.01790714263916,0.1620543003082275,0.0 -9500,0.007803174,0.1208032,,,,,,,,,,, -9600,0.011158598,0.120756656,,,,,,,,,,, -9700,0.0069876616,0.12284458,,,,,,,,,,, -9800,0.007335061,0.1252589,,,,,,,,,,, -9900,0.0115204025,0.119439736,,,,,,,,,,, -10000,0.011479246,0.1247024,,,,,,,,,,, -10100,0.006401952,0.11738011,,,,,,,,,,, -10150,,,,,,,,7703.422815322876,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 0b1fca425..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.293355464935303,0.0,6.213699102401733,1,0,6.213699102401733,1.5194805184210527,95000000,28.50710606575012,1.5197636916202568,1.5212185261462023,83274637 -44.59408211708069,0.0184142589569091,1206.1726896762848,1579,0,1206.1726896762848,0.1277996791118421,95000000,1250.8360691070557,0.1243806681449308,0.1254170771614261,83274637 -66.96817874908447,0.0433871746063232,2406.662718772888,3160,0,2406.662718772888,0.1272734294921875,95000000,2473.7763113975525,0.1239324527903922,0.1248393983167407,83274637 -89.11408042907715,0.0679883956909179,3606.842771530152,4744,0,3606.842771530152,0.1266547538342927,95000000,3696.1776309013367,0.1232597236550828,0.1243291537197724,83274637 -111.42054176330566,0.0913968086242675,4807.368099927902,6330,0,4807.368099927902,0.1265950968955592,95000000,4919.084401607513,0.1246934863943723,0.1241852502667619,83274637 -133.62959718704224,0.1174366474151611,6007.8526430130005,7909,0,6007.8526430130005,0.126114404738898,95000000,6141.8548221588135,0.1223530204228635,0.1237904075631845,83274637 -155.8615963459015,0.14365887641906738,7208.395542383194,9492,0,7208.395542383194,0.12595479833470394,95000000,7364.707081079483,0.12048418599667039,0.123623332617319,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/measurements.csv deleted file mode 100644 index c7ee69601..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/measurements.csv +++ /dev/null @@ -1,111 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.973966,1.5170931,,,,,,,,,,, -1,,,1.5197636916202568,1.5212185261462023,83274637.0,1.5194805184210527,95000000.0,6.213699102401733,28.50710606575012,6.213699102401733,22.293355464935303,0.0,0.0 -100,0.053498518,0.12956613,,,,,,,,,,, -200,0.06650594,0.13353443,,,,,,,,,,, -300,0.019169183,0.12707567,,,,,,,,,,, -400,0.053947933,0.12329276,,,,,,,,,,, -500,0.1224703,0.12496351,,,,,,,,,,, -600,0.06473773,0.12629348,,,,,,,,,,, -700,0.033261888,0.12603435,,,,,,,,,,, -800,0.093837544,0.12344693,,,,,,,,,,, -900,0.0225961,0.1194242,,,,,,,,,,, -1000,0.006245036,0.12135349,,,,,,,,,,, -1100,0.09219341,0.13264275,,,,,,,,,,, -1200,0.028173301,0.120468974,,,,,,,,,,, -1300,0.09431877,0.14273179,,,,,,,,,,, -1400,0.072574444,0.13219513,,,,,,,,,,, -1500,0.065208316,0.119486205,,,,,,,,,,, -1579,,,0.1243806681449308,0.1254170771614261,83274637.0,0.1277996791118421,95000000.0,1206.1726896762848,1250.8360691070557,1206.1726896762848,44.59408211708069,0.0184142589569091,0.0 -1600,0.029161142,0.11945993,,,,,,,,,,, -1700,0.03588621,0.13199012,,,,,,,,,,, -1800,0.010051076,0.12800746,,,,,,,,,,, -1900,0.028374264,0.12268944,,,,,,,,,,, -2000,0.05144389,0.12757638,,,,,,,,,,, -2100,0.012299032,0.13462214,,,,,,,,,,, -2200,0.055284135,0.11978168,,,,,,,,,,, -2300,0.007976957,0.12027761,,,,,,,,,,, -2400,0.0062181884,0.12051916,,,,,,,,,,, -2500,0.005858485,0.119816296,,,,,,,,,,, -2600,0.03091986,0.12211567,,,,,,,,,,, -2700,0.016867373,0.12703927,,,,,,,,,,, -2800,0.02144688,0.12155388,,,,,,,,,,, -2900,0.016603671,0.12532316,,,,,,,,,,, -3000,0.02874492,0.12294948,,,,,,,,,,, -3100,0.033358548,0.11605579,,,,,,,,,,, -3160,,,0.1239324527903922,0.1248393983167407,83274637.0,0.1272734294921875,95000000.0,2406.662718772888,2473.7763113975525,2406.662718772888,66.96817874908447,0.0433871746063232,0.0 -3200,0.018224873,0.12200165,,,,,,,,,,, -3300,0.0072889742,0.12241086,,,,,,,,,,, -3400,0.059859764,0.1167591,,,,,,,,,,, -3500,0.024699485,0.1293768,,,,,,,,,,, -3600,0.027884569,0.12179691,,,,,,,,,,, -3700,0.019902397,0.11872314,,,,,,,,,,, -3800,0.017471673,0.121368766,,,,,,,,,,, -3900,0.0077621927,0.13028814,,,,,,,,,,, -4000,0.015475147,0.123037264,,,,,,,,,,, -4100,0.011397916,0.12620795,,,,,,,,,,, -4200,0.031625107,0.13099056,,,,,,,,,,, -4300,0.013971738,0.12915419,,,,,,,,,,, -4400,0.021576554,0.122814104,,,,,,,,,,, -4500,0.016411101,0.120601244,,,,,,,,,,, -4600,0.014893751,0.11890037,,,,,,,,,,, -4700,0.029586751,0.13443221,,,,,,,,,,, -4744,,,0.1232597236550828,0.1243291537197724,83274637.0,0.1266547538342927,95000000.0,3606.842771530152,3696.1776309013367,3606.842771530152,89.11408042907715,0.0679883956909179,0.0 -4800,0.009216117,0.11462125,,,,,,,,,,, -4900,0.014781006,0.12713918,,,,,,,,,,, -5000,0.007128196,0.12615627,,,,,,,,,,, -5100,0.0109896865,0.12395311,,,,,,,,,,, -5200,0.0054136463,0.120842285,,,,,,,,,,, -5300,0.008895298,0.12780298,,,,,,,,,,, -5400,0.006539214,0.12462632,,,,,,,,,,, -5500,0.014138585,0.12445212,,,,,,,,,,, -5600,0.017117197,0.12851267,,,,,,,,,,, -5700,0.008935386,0.13291423,,,,,,,,,,, -5800,0.0065249107,0.11542285,,,,,,,,,,, -5900,0.0067860363,0.12331572,,,,,,,,,,, -6000,0.012329818,0.12274118,,,,,,,,,,, -6100,0.007768651,0.12444204,,,,,,,,,,, -6200,0.0067719,0.12286316,,,,,,,,,,, -6300,0.0076346784,0.11524536,,,,,,,,,,, -6330,,,0.1246934863943723,0.1241852502667619,83274637.0,0.1265950968955592,95000000.0,4807.368099927902,4919.084401607513,4807.368099927902,111.42054176330566,0.0913968086242675,0.0 -6400,0.015630847,0.11853572,,,,,,,,,,, -6500,0.009713167,0.12232152,,,,,,,,,,, -6600,0.0066234344,0.122205675,,,,,,,,,,, -6700,0.014266622,0.1271804,,,,,,,,,,, -6800,0.006459439,0.12587835,,,,,,,,,,, -6900,0.014334217,0.11792603,,,,,,,,,,, -7000,0.006169773,0.12065978,,,,,,,,,,, -7100,0.0055727153,0.12605067,,,,,,,,,,, -7200,0.0070342557,0.11980655,,,,,,,,,,, -7300,0.007821908,0.1179339,,,,,,,,,,, -7400,0.0071241413,0.12638158,,,,,,,,,,, -7500,0.011498554,0.1173363,,,,,,,,,,, -7600,0.007933419,0.13072982,,,,,,,,,,, -7700,0.008320398,0.11553016,,,,,,,,,,, -7800,0.008270065,0.119917646,,,,,,,,,,, -7900,0.0062224427,0.119127,,,,,,,,,,, -7909,,,0.1223530204228635,0.1237904075631845,83274637.0,0.126114404738898,95000000.0,6007.8526430130005,6141.8548221588135,6007.8526430130005,133.62959718704224,0.1174366474151611,0.0 -8000,0.010525943,0.11855522,,,,,,,,,,, -8100,0.0082243765,0.12648197,,,,,,,,,,, -8200,0.016276458,0.1372267,,,,,,,,,,, -8300,0.00745094,0.119880274,,,,,,,,,,, -8400,0.0054612337,0.11818514,,,,,,,,,,, -8500,0.007058272,0.13453929,,,,,,,,,,, -8600,0.008793821,0.11441825,,,,,,,,,,, -8700,0.0057599964,0.116799675,,,,,,,,,,, -8800,0.0066056536,0.12281239,,,,,,,,,,, -8900,0.006012405,0.11930261,,,,,,,,,,, -9000,0.013953436,0.13600689,,,,,,,,,,, -9100,0.0100597935,0.12681714,,,,,,,,,,, -9200,0.008851031,0.12485628,,,,,,,,,,, -9300,0.005798924,0.121154964,,,,,,,,,,, -9400,0.008212067,0.120074734,,,,,,,,,,, -9492,,,0.1204841859966703,0.123623332617319,83274637.0,0.1259547983347039,95000000.0,7208.395542383194,7364.707081079483,7208.395542383194,155.8615963459015,0.1436588764190673,0.0 -9500,0.00646975,0.13295576,,,,,,,,,,, -9600,0.009990905,0.12393096,,,,,,,,,,, -9700,0.010130011,0.11874766,,,,,,,,,,, -9800,0.008642093,0.11714746,,,,,,,,,,, -9900,0.008999875,0.11386283,,,,,,,,,,, -10000,0.008543794,0.121957585,,,,,,,,,,, -10100,0.006604854,0.13081987,,,,,,,,,,, -10145,,,,,,,,7703.6261920928955,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 7f319eab5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -202.42385411262512,0.0,55.91798734664917,1,0,55.91798734664917,0.8949576268413153,3581,0.2836721463431129,258.3423571586609,0.8928326879228864,0.2656774180276053,0.8956671152882316,3554,0.2613217295323051 -206.8602089881897,0.0311276912689209,136.21887683868408,339,0,136.21887683868408,0.3401057529103253,3581,0.6843897125235618,343.12347984313965,0.3207972730909075,0.6871940749032157,0.3388196851808701,3554,0.665303880201006 -210.8728530406952,0.0653471946716308,216.22331738471985,583,0,216.22331738471985,0.3183245708142802,3581,0.7065911014730523,427.18370366096497,0.2993877274649484,0.7098658425467355,0.3165920002110298,3554,0.6883714597548537 -214.89030241966248,0.1000032424926757,296.40825033187866,826,0,296.40825033187866,0.3082724674824595,3581,0.7172323875139626,511.4295752048493,0.2891365800585065,0.7211519650050572,0.3062531874296567,3554,0.6996939792751126 -218.9048209190369,0.1307234764099121,376.5098688602448,1112,0,376.5098688602448,0.3031021199123848,3581,0.7226963378769896,595.5867800712585,0.283478992325919,0.7276527541024345,0.3010855670986916,3554,0.7057840307575971 -222.9181063175201,0.1552250385284423,456.7718312740326,1458,0,456.7718312740326,0.300171648380777,3581,0.7262398198608978,679.8992731571198,0.2809938362666538,0.730339595249721,0.2982144417491734,3554,0.7092974850080894 -226.938090801239,0.1846945285797119,536.8662204742432,1806,0,536.8662204742432,0.2963384155852939,3581,0.7303801203399888,764.0558526515961,0.2769751208169119,0.734905515398298,0.2946085932553988,3554,0.7133297896076604 -230.95373272895813,0.2125940322875976,616.8835611343384,2152,0,616.8835611343384,0.3137618477402436,3581,0.720800208566043,848.1295416355133,0.2946501118796212,0.7232328142438617,0.3114003024211276,3554,0.7039935745814575 -234.9765722751617,0.2394816875457763,696.9776549339294,2498,0,696.9776549339294,0.2952626219544121,3581,0.7322592054855487,932.2862796783448,0.2758675473076956,0.7367127282278878,0.2935951417526906,3554,0.7151943672622397 -238.9914553165436,0.2640409469604492,777.0949683189392,2844,0,777.0949683189392,0.2935845556954237,3581,0.7344461764346552,1016.4557588100432,0.2741975443703787,0.7388166018894741,0.2919819863512063,3554,0.7174660978387029 -243.0097703933716,0.2926876544952392,857.13303399086,3188,0,857.13303399086,0.2936349723366378,3581,0.7352607511868193,1100.553381204605,0.2737694127219064,0.7401871000017438,0.2919790324832055,3554,0.7183199030801561 -247.0268371105194,0.3201866149902344,937.29634308815,3535,0,937.29634308815,0.2923930662698967,3581,0.7353010435937937,1184.7737319469452,0.2726729597364153,0.7400742939540318,0.2907690800638365,3554,0.7181579212023425 -251.04173159599304,0.3457553386688232,1017.441393136978,3881,0,1017.441393136978,0.2927929223898701,3581,0.7356736290491482,1268.9722275733948,0.2732805865151541,0.7403202056884766,0.2912120915693584,3554,0.7185890485412564 -255.06102967262268,0.3724069595336914,1097.5015771389008,4225,0,1097.5015771389008,0.2924918542524783,3581,0.7346395254468026,1353.090767621994,0.2730711357934134,0.7391737529209682,0.2909697026567776,3554,0.717479836759637 -259.0786759853363,0.3975775241851806,1177.474592924118,4572,0,1177.474592924118,0.2918655834395071,3581,0.7362364273902192,1437.1194834709167,0.2717891590935843,0.7415586880275181,0.2903687622551175,3554,0.7190310983223129 -263.0961480140686,0.4233999252319336,1257.6327981948853,4919,0,1257.6327981948853,0.2918626177547298,3581,0.7359213148605487,1521.3335707187653,0.2720399243491037,0.740943159375872,0.2903080705718908,3554,0.7187721883573087 -267.11389446258545,0.4483840465545654,1337.6137821674347,5264,0,1337.6137821674347,0.2923846464521956,3581,0.7353418132373988,1605.3700017929075,0.2729257345199585,0.7398541995457241,0.2908276422143184,3554,0.7182161742271033 -271.1278207302093,0.4733326435089111,1417.68976521492,5607,0,1417.68976521492,0.2903670263303721,3581,0.7379724097231919,1689.497330904007,0.2704073531287057,0.7430362701416016,0.2888049982743915,3554,0.72083666731148 -275.14325308799744,0.4997069835662842,1497.8008234500885,5954,0,1497.8008234500885,0.2902548416337964,3581,0.7374507900856954,1773.6630256175995,0.2703781127929687,0.7426825250898089,0.2887011320321292,3554,0.7202349712691686 -279.15823125839233,0.5257272720336914,1577.8476405143738,6296,0,1577.8476405143738,0.2908137539051592,3581,0.7372403969081611,1857.763218164444,0.2712857723236084,0.742091178894043,0.2893256346831915,3554,0.7200937351619654 -283.1754529476166,0.5520541667938232,1657.8847556114197,6638,0,1657.8847556114197,0.2905435016187866,3581,0.7374012938294122,1941.856261730194,0.2704795088086809,0.7428139277866909,0.2890228975604073,3554,0.720252831866383 -287.19444942474365,0.5787684917449951,1737.8613169193268,6984,0,1737.8613169193268,0.2904899147628106,3581,0.737995862494764,2025.891048192978,0.2706403732299804,0.7429917199271066,0.2890006061611916,3554,0.7208075407990996 -291.2129316329956,0.6057913303375244,1817.926639080048,7331,0,1817.926639080048,0.2893841915775098,3581,0.7387117856176696,2110.0144007205963,0.2696965081351144,0.7436522756304059,0.2879728146597584,3554,0.7215000511087859 -295.2284879684448,0.6318953037261963,1898.0349142551424,7674,0,1898.0349142551424,0.2907413162022654,3581,0.7367458434052988,2194.177098274231,0.2708191360746111,0.742138317653111,0.2893724157089722,3554,0.7195088006031936 -299.24297618865967,0.6634461879730225,1978.042454719544,8020,0,1978.042454719544,0.2896054248442299,3581,0.7389935597598436,2278.243096113205,0.269753132547651,0.744044576372419,0.2882133659916643,3554,0.7217948883520329 -303.2593698501587,0.6898679733276367,2058.131326675415,8365,0,2058.131326675415,0.2896580572269442,3581,0.7383492221315624,2362.3873484134674,0.2699533700942993,0.7434047971452985,0.2882835031830332,3554,0.7211110335625351 -307.2768743038177,0.7161507606506348,2138.2803223133087,8708,0,2138.2803223133087,0.2896519554157358,3581,0.7395078844945546,2446.592636823654,0.2695001704352243,0.7450571060180664,0.2882594944187007,3554,0.7223065944622257 -311.29213285446167,0.7432451248168945,2218.340068101883,9053,0,2218.340068101883,0.2896583640219212,3581,0.7392402910979824,2530.707357406616,0.2695177282605852,0.7446896008082798,0.288230230517111,3554,0.7220156041168402 -315.3143537044525,0.7714033126831055,2298.3692531585693,9398,0,2298.3692531585693,0.2885951148967641,3581,0.7406060741587546,2614.7996389865875,0.2685531207493373,0.7459303992135184,0.2872594898848568,3554,0.7233655904878307 -319.326447725296,0.7993285655975342,2378.3575394153595,9740,0,2378.3575394153595,0.2888772980989074,3581,0.7407723570362678,2698.840270757675,0.2690248829977853,0.7458776746477399,0.2874351591626512,3554,0.7237394265264491 -323.3411955833435,0.8272280693054199,2458.373385190964,10084,0,2458.373385190964,0.2895506108083461,3581,0.7392846741046495,2782.9115421772003,0.2694817951747349,0.7447167805262974,0.288159372032393,3554,0.7220355255521947 -327.36353492736816,0.8536674976348877,2538.4143187999725,10429,0,2538.4143187999725,0.2910769158187308,3581,0.7367384803258518,2867.0138463974,0.2712743282318115,0.7417423384530204,0.2896735385085467,3554,0.7195695953283272 -331.38250207901,0.8842437267303467,2618.481980085373,10772,0,2618.481980085373,0.2895300555448897,3581,0.7387032635349763,2951.143794298172,0.2697212696075439,0.7439629690987724,0.2880775739318813,3554,0.7215402374525183 -335.40252470970154,0.9122757911682128,2698.670265674591,11116,0,2698.670265674591,0.2891032014669435,3581,0.7391950899713767,3035.3925642967224,0.2687648705073765,0.7449778829302106,0.2877332593996201,3554,0.7219717082644556 -339.4211540222168,0.9393706321716307,2778.822010755539,11460,0,2778.822010755539,0.2883605190043807,3581,0.7399028318948967,3119.602737426758,0.2682352236339024,0.7453946386064801,0.2870142501461821,3554,0.7226209409731992 -343.4411287307739,0.9664688110351562,2858.9452958106995,11804,0,2858.9452958106995,0.2882328241173031,3581,0.7403026880148702,3203.785383462906,0.2681718553815569,0.7457089424133301,0.2868413629998769,3554,0.7230791339863534 -347.4586265087128,0.9985432624816896,2938.934836626053,12146,0,2938.934836626053,0.2882499364593514,3581,0.7404194746361002,3287.837063789368,0.2677620989935739,0.7462355749947684,0.2868784065654456,3554,0.7232231178777434 -351.48676466941833,1.028304100036621,3018.917544603348,12492,0,3018.917544603348,0.2879755594849553,3581,0.7405588277323024,3371.8902475833893,0.2678097145898001,0.7459902082170758,0.2866928452645786,3554,0.7231852671505697 -355.50565004348755,1.0569570064544678,3098.916193246841,12834,0,3098.916193246841,0.2885652194306758,3581,0.7409299814777646,3455.9489629268646,0.2683678013937814,0.7463996069771903,0.2872475370236441,3554,0.7237146964687676 -359.51924538612366,1.0859599113464355,3178.969719648361,13177,0,3178.969719648361,0.28811678743935,3581,0.7405494193530089,3540.0576510429382,0.2676967552730015,0.7463410241263253,0.2867473029124314,3554,0.7232975828292065 -363.53816270828247,1.11417555809021,3258.99784898758,13522,0,3258.99784898758,0.2877505083251885,3581,0.7410320419401005,3624.1457312107086,0.2675368956157139,0.7465722220284599,0.2864089819844277,3554,0.7237550202017093 -367.5555286407471,1.1418681144714355,3339.0577857494354,13868,0,3339.0577857494354,0.2880376684236246,3581,0.7413182475652751,3708.263298511505,0.2678717374801636,0.7468410900660923,0.2867557008278524,3554,0.7240890820642234 -371.569610118866,1.1705811023712158,3419.144530057907,14211,0,3419.144530057907,0.2886754610923974,3581,0.7405325115409452,3792.405335187912,0.2678675992148263,0.7465951783316476,0.2872121421285875,3554,0.723388877958814 -375.5900700092316,1.198145866394043,3499.270180702209,14559,0,3499.270180702209,0.2878229119397514,3581,0.7414607367879084,3876.59183716774,0.2675107717514038,0.7471063477652413,0.2865292147162528,3554,0.7242466674873382 -379.6057589054108,1.2259979248046875,3579.3831877708435,14905,0,3579.3831877708435,0.2879113711581088,3581,0.742121709521607,3960.7610342502594,0.2676566668919155,0.7477654048374721,0.2865739349038935,3554,0.7249977742948087 -383.6252100467682,1.2533009052276611,3659.4731526374817,15248,0,3659.4731526374817,0.2878147989170274,3581,0.7418195505576306,4044.9105067253113,0.2673081500189645,0.7476011684962681,0.2864856623368915,3554,0.7246028490125562 -387.6414725780487,1.2822730541229248,3739.480894804001,15592,0,3739.480894804001,0.2878657609715337,3581,0.741601521594003,4128.976114034653,0.2675517116274152,0.7472316878182548,0.2864906942166836,3554,0.7243855679779826 -391.657977104187,1.3116509914398191,3819.596107006073,15936,0,3819.596107006073,0.2876255064162419,3581,0.7426279894102555,4213.149703264236,0.2673044204711914,0.7483240536281041,0.2863633515932752,3554,0.7253718164172411 -395.6702156066895,1.340677261352539,3899.568253993988,16279,0,3899.568253993988,0.2881037997853253,3581,0.7415096876308992,4297.175599813461,0.2675484078271048,0.7473345484052386,0.2866905955162757,3554,0.7243296505697805 -399.6851809024811,1.3709535598754885,3979.7582840919495,16625,0,3979.7582840919495,0.288077960830599,3581,0.7408988929200991,4381.423604011536,0.2676364524023873,0.7466865948268345,0.2866642339617332,3554,0.7237747355532499 -403.7024691104889,1.3988149166107178,4059.7816610336304,16969,0,4059.7816610336304,0.2871888348990331,3581,0.7424319815083077,4465.504558324814,0.2668043034417288,0.7481282779148647,0.2858984608561656,3554,0.7251985686242614 -407.7211818695069,1.4279568195343018,4139.849115371704,17310,0,4139.849115371704,0.28726846523972,3581,0.7423032639713069,4549.632168054581,0.2669447490147182,0.7480014392307827,0.2859115471783554,3554,0.7251391477912211 -411.7394840717316,1.4562327861785889,4219.889711856842,17656,0,4219.889711856842,0.287273135341036,3581,0.7412486391938355,4633.732095003128,0.2665583406175886,0.7473434720720563,0.2859433184330156,3554,0.7239926348392656 -415.75352668762207,1.4845950603485107,4300.033041477203,18003,0,4300.033041477203,0.286959454521258,3581,0.7424689332588662,4717.930500507355,0.2665321145738874,0.7483083861214774,0.2856605886138418,3554,0.7252950158492192 -419.7735161781311,1.5131938457489014,4380.129502534866,18345,0,4380.129502534866,0.2871852897126326,3581,0.7414670090407708,4802.0883066654205,0.2668859277452741,0.7470991952078683,0.285921439201428,3554,0.7242033898863957 -423.7905547618866,1.542210578918457,4460.136431455612,18691,0,4460.136431455612,0.28718869854571,3581,0.743057638731325,4886.153877258301,0.2664386204310826,0.7492048399788993,0.2859228817881261,3554,0.7258765843723621 -427.807642698288,1.5725533962249756,4540.380712032318,19038,0,4540.380712032318,0.2870192454534347,3581,0.7429034912995671,4970.458175897598,0.266427789415632,0.7488509586879185,0.2856958117723867,3554,0.7257310891996693 -431.82620668411255,1.60274076461792,4620.512376785278,19384,0,4620.512376785278,0.2869319452383237,3581,0.7426361024329796,5054.651271343231,0.2664816209248134,0.7485593387058803,0.2856379709152539,3554,0.7254547994996835 -435.8466863632202,1.63158917427063,4700.570680141449,19728,0,4700.570680141449,0.2871433269774504,3581,0.742244086629084,5138.771405696869,0.2663026877811977,0.7484092712402344,0.2859123371663091,3554,0.7250737505275746 -439.86440348625183,1.6642556190490725,4780.684697151184,20072,0,4780.684697151184,0.2868610755986456,3581,0.7426688954071837,5222.948258399963,0.2662466253553118,0.7486628804888044,0.2855280252004783,3554,0.7255211584877954 -443.8861367702484,1.6940970420837402,4860.731207847595,20417,0,4860.731207847595,0.2871418952675579,3581,0.7418312769434167,5307.05890750885,0.2665714366095407,0.7478485788617816,0.2859325677273846,3554,0.7244868238252673 -447.90532636642456,1.7241308689117432,4940.794360637665,20761,0,4940.794360637665,0.2869748965350984,3581,0.7421095058991902,5391.183601856232,0.2659840413502284,0.7485369273594448,0.2856829830549645,3554,0.724897548866594 -451.92284989357,1.7548644542694092,5020.885152101517,21107,0,5020.885152101517,0.2870494818028309,3581,0.7430941132452528,5475.335332632065,0.2664053269795009,0.7491487775530133,0.285788927309018,3554,0.7259544153594542 -455.9429025650024,1.7836058139801023,5100.921221733093,21449,0,5100.921221733093,0.28676784401398,3581,0.7426107407148841,5559.43258523941,0.2660897970199585,0.7486867223467145,0.2854990189036561,3554,0.7254420223032146 -459.9567103385925,1.8128316402435305,5181.017259836197,21791,0,5181.017259836197,0.2870056101211254,3581,0.7427518664042865,5643.584321975708,0.2659718820026943,0.7492446218218122,0.2857284245359542,3554,0.7256659667144415 -463.9753234386444,1.841802835464477,5261.106295824051,22138,0,5261.106295824051,0.2873478910504223,3581,0.7421120284356674,5727.733413934708,0.2667830841881888,0.748035022190639,0.286072601679006,3554,0.7249676173633582 -467.9927349090576,1.872994899749756,5341.142727136612,22483,0,5341.142727136612,0.286731574030037,3581,0.7432496923869031,5811.8311512470245,0.2660031999860491,0.7493149893624442,0.2854326255682418,3554,0.7261346700021103 -472.01139187812805,1.9036543369293213,5421.144332408905,22826,0,5421.144332408905,0.2868267486495567,3581,0.7432502378001955,5895.894639015198,0.2659002372196742,0.7495972088405064,0.2854917201019098,3554,0.7261499202043472 -476.02953267097473,1.9357259273529053,5501.226588726044,23169,0,5501.226588726044,0.2866561024657044,3581,0.742398370414165,5980.0396893024445,0.2657872268131801,0.7486049788338798,0.2853719510586663,3554,0.7252260464661298 -480.0495846271515,1.9672400951385496,5581.201484680176,23514,0,5581.201484680176,0.2866178894469073,3581,0.7437929922027716,6064.078918218613,0.265853796686445,0.7499886240277972,0.2853215979134426,3554,0.7267512727736354 -484.0690758228302,1.9982903003692627,5661.339335203171,23856,0,5661.339335203171,0.2868497923611596,3581,0.7432202400691148,6148.279658317566,0.2660057204110281,0.7493819509233747,0.2854621299009479,3554,0.7261044443760551 -488.0896954536438,2.0329601764678955,5741.316986083984,24199,0,5741.316986083984,0.2867679121906415,3581,0.7429892575397934,6232.325395584106,0.2655493191310337,0.7496466636657715,0.2854341025022422,3554,0.7258055541511326 -492.1099181175232,2.064488172531128,5821.357635498047,24545,0,5821.357635498047,0.2864897173231988,3581,0.7436160055893954,6316.430392026901,0.26558062008449,0.7499784060886928,0.2852051433848744,3554,0.7264893402460256 -496.126501083374,2.0994303226470947,5901.505459070206,24889,0,5901.505459070206,0.2866690901197291,3581,0.743647980443661,6400.642231225967,0.2658586502075195,0.7498251370021275,0.2853666100531531,3554,0.7265470437139491 -500.1464011669159,2.129655838012696,5981.699506759644,25235,0,5981.699506759644,0.2863677151873603,3581,0.7431427232049358,6484.898949146271,0.2650754792349679,0.7498918942042759,0.2850697291454171,3554,0.7259708333699705 -504.1658718585968,2.16015625,6061.817725658417,25580,0,6061.817725658417,0.2863709535787838,3581,0.7433790916905194,6569.079738616943,0.2653120585850307,0.7498783384050641,0.2851230018113393,3554,0.726230086807998 -508.1867387294769,2.191812992095948,6141.91224861145,25925,0,6141.91224861145,0.2863817254913083,3581,0.7434775387897934,6653.2393543720245,0.2654145956039428,0.7497959818158831,0.2851072535732185,3554,0.72636321695185 -512.2066950798035,2.224546432495117,6221.960324764252,26268,0,6221.960324764252,0.2863529208518046,3581,0.7434729027768081,6737.352556943893,0.2649457454681396,0.7502094677516392,0.2850390226571292,3554,0.7263461806898917 -516.2275338172913,2.2554852962493896,6301.933490514755,26613,0,6301.933490514755,0.2863526822334892,3581,0.7435960298275621,6821.390177488327,0.2652312006269182,0.7500839233398438,0.2850700726184405,3554,0.7263961216674873 -520.2461042404175,2.287006378173828,6382.069779396057,26958,0,6382.069779396057,0.2862853236918807,3581,0.7434109301914619,6905.588939666748,0.2652568135942731,0.7498068809509277,0.285042989770549,3554,0.7262263086047411 -524.2069962024689,2.317927598953247,6462.058999300003,27301,0,6462.058999300003,0.2863510459936121,3581,0.7433469804829308,6989.5824983119965,0.2647902795246669,0.7503508159092495,0.2850366183459658,3554,0.7261952586434299 -528.22221159935,2.3489346504211426,6542.14812707901,27646,0,6542.14812707901,0.28623681599719,3581,0.743229239388439,7073.7304475307465,0.2649988617215837,0.7499651908874512,0.2849680439568532,3554,0.7260563581527856 -532.2352705001831,2.3820061683654785,6622.342363357544,27989,0,6622.342363357544,0.2863091855234222,3581,0.7432007415439124,7157.983325958252,0.2651307923453195,0.7497949600219727,0.2850151684556573,3554,0.726063158918648 -536.2536578178406,2.413137912750244,6702.453696250916,28333,0,6702.453696250916,0.2863208778208775,3581,0.7435149677769827,7242.156578779221,0.2647071395601545,0.75055878502982,0.2850119913301913,3554,0.7264404983821047 -540.2722911834717,2.4437241554260254,6782.451961994171,28677,0,6782.451961994171,0.28617555926679,3581,0.7437867199499092,7326.216543912888,0.2647926637104579,0.7505395071847099,0.2848629755590092,3554,0.7266726174512873 -544.2934327125549,2.476641893386841,6862.631316900253,29023,0,6862.631316900253,0.2862706316213174,3581,0.7433983175090757,7410.462491750717,0.2649626220975603,0.7501329013279506,0.2849998839061181,3554,0.7262491839080966 -548.3099186420441,2.50946044921875,6942.597845315933,29366,0,6942.597845315933,0.2862858691051731,3581,0.7433084606691567,7494.490667819977,0.2647256510598318,0.7502886227199009,0.285006427067213,3554,0.7261591939759777 -552.3254547119141,2.542933702468872,7022.648463249207,29711,0,7022.648463249207,0.2862731200694638,3581,0.7435891439847458,7578.602977514267,0.2646856989179338,0.7505989074707031,0.2849455808211258,3554,0.7264604198174592 -556.3397581577301,2.5753839015960693,7102.741122245789,30056,0,7102.741122245789,0.2861787635698827,3581,0.7438693500637042,7662.754905939102,0.2647981473377773,0.7506581715175084,0.2849077472676034,3554,0.7267338930386537 -560.3569049835205,2.6066181659698486,7182.78594827652,30397,0,7182.78594827652,0.2863478076021886,3581,0.7432389204743787,7746.860567092895,0.2648999180112566,0.7500651904514858,0.2850286325981728,3554,0.7260751804744654 -564.373776435852,2.638613700866699,7262.975407361984,30744,0,7262.975407361984,0.2862245441981115,3581,0.7435880531581611,7831.111491441727,0.2645703213555472,0.750626632145473,0.2849193566557927,3554,0.7264586337577378 -568.3998596668243,2.67030930519104,7342.940884590149,31088,0,7342.940884590149,0.2861438230308398,3581,0.7437144526886693,7915.147254228592,0.2646536486489432,0.7506256103515625,0.284842779345236,3554,0.7265625686946047 -572.4185557365417,2.7039809226989746,7423.086068630218,31433,0,7423.086068630218,0.2862593824721621,3581,0.743663320192509,7999.357291936874,0.2646878617150442,0.7506649153573173,0.2849661033342712,3554,0.7265348847689224 -576.4384963512421,2.736294269561768,7503.2064254283905,31779,0,7503.2064254283905,0.2861585832780647,3581,0.7436525482799846,8083.542644500732,0.264353837285723,0.7508582387651715,0.28483960221977,3554,0.7265188102314294 -580.4590227603912,2.7696642875671387,7583.260835409164,32124,0,7583.260835409164,0.2860684878198303,3581,0.743674023928372,8167.663266420364,0.2644423757280622,0.7506969315665108,0.28476238948412,3554,0.7265397620858539 -584.4764456748962,2.801800489425659,7663.235911130905,32467,0,7663.235911130905,0.2860721693595539,3581,0.7436811824778344,8251.700382471085,0.2644823108400617,0.7506919588361468,0.2847786872790781,3554,0.7265296639789673 -588.4945957660675,2.8336546421051025,7743.370946645737,32810,0,7743.370946645737,0.2860847479536093,3581,0.7438340345530229,8335.898382425308,0.2642045361655099,0.7511326926095145,0.2847665283340514,3554,0.7267132846572524 -592.5159690380096,2.867058038711548,7823.340132236481,33155,0,7823.340132236481,0.2860520572443975,3581,0.7435964388875315,8419.934628725052,0.2643196071897234,0.750725269317627,0.2847503507546514,3554,0.7264477800101998 -596.5357365608215,2.898998022079468,7903.390450239181,33500,0,7903.390450239181,0.2860510005061435,3581,0.7438440565222704,8504.049171447754,0.2643869774682181,0.7509395054408482,0.2847576495563977,3554,0.7267172002497186 -600.5509705543518,2.932687520980835,7983.478590726852,33844,0,7983.478590726852,0.2860495006195895,3581,0.7439712741727171,8588.198611021042,0.2641360419137137,0.7513058526175362,0.2847431893421145,3554,0.7268690840206458 -604.5714876651764,2.965465784072876,8063.584829330444,34190,0,8063.584829330444,0.2859617913445092,3581,0.7437399507600879,8672.370778083801,0.2641798939023699,0.7509686606270927,0.2846628166546497,3554,0.7265910082609384 -608.591876745224,2.998674154281616,8143.587386846542,34534,0,8143.587386846542,0.2859943457003979,3581,0.7436853412541887,8756.439602851868,0.2642404351915632,0.7508820125034877,0.2847020241002655,3554,0.726557279210045 -612.6066925525665,3.0308783054351807,8223.553173303604,34875,0,8223.553173303604,0.2859594392496858,3581,0.7437644943582449,8840.46463394165,0.264101539339338,0.7510365758623395,0.2846636409899057,3554,0.7266363467000211 -616.624080657959,3.063483715057373,8303.650809764862,35220,0,8303.650809764862,0.2859407588444219,3581,0.7438080592449735,8924.624583244324,0.2640958513532366,0.7510863712855748,0.2846420365367367,3554,0.7266660227692389 -620.6433687210083,3.099930047988892,8383.785435676575,35566,0,8383.785435676575,0.2859345547682211,3581,0.7438390796259774,9008.82751774788,0.264108487537929,0.7510568073817662,0.2846336729686181,3554,0.7267169254713 -624.6613454818726,3.1323611736297607,8463.990142822266,35909,0,8463.990142822266,0.2859224874991273,3581,0.7438034914086498,9093.09502840042,0.2640914406095232,0.7510451589311872,0.2846232485623593,3554,0.7266632062904473 -628.6836860179901,3.1654698848724365,8528.737991809845,36189,0,8528.737991809845,0.28593407753159034,3581,0.7439900909313041,9161.908516168594,0.2641005516052246,0.7512311254228864,0.2846340336152926,3554,0.7268648249551561 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 08c51b152..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.160697,0.8857512,,,,,,,,,,,,,, -1,,,0.2656774180276053,0.8928326879228864,0.2613217295323051,0.8956671152882316,3554.0,0.2836721463431129,0.8949576268413153,3581.0,55.91798734664917,258.3423571586609,55.91798734664917,202.42385411262512,0.0,0.0 -100,1.0042464,0.38269874,,,,,,,,,,,,,, -200,0.2571165,0.28436524,,,,,,,,,,,,,, -300,0.10312225,0.32610947,,,,,,,,,,,,,, -339,,,0.6871940749032157,0.3207972730909075,0.665303880201006,0.3388196851808701,3554.0,0.6843897125235618,0.3401057529103253,3581.0,136.21887683868408,343.12347984313965,136.21887683868408,206.8602089881897,0.0311276912689209,0.0 -400,0.1257668,0.2991722,,,,,,,,,,,,,, -500,0.16659963,0.30708474,,,,,,,,,,,,,, -583,,,0.7098658425467355,0.2993877274649484,0.6883714597548537,0.3165920002110298,3554.0,0.7065911014730523,0.3183245708142802,3581.0,216.22331738471985,427.18370366096497,216.22331738471985,210.8728530406952,0.0653471946716308,0.0 -600,0.09352814,0.2775395,,,,,,,,,,,,,, -700,0.18253644,0.24091326,,,,,,,,,,,,,, -800,0.14795639,0.3561833,,,,,,,,,,,,,, -826,,,0.7211519650050572,0.2891365800585065,0.6996939792751126,0.3062531874296567,3554.0,0.7172323875139626,0.3082724674824595,3581.0,296.40825033187866,511.4295752048493,296.40825033187866,214.89030241966248,0.1000032424926757,0.0 -900,0.107576,0.36933163,,,,,,,,,,,,,, -1000,0.11945334,0.28406268,,,,,,,,,,,,,, -1100,0.50715363,0.26979584,,,,,,,,,,,,,, -1112,,,0.7276527541024345,0.283478992325919,0.7057840307575971,0.3010855670986916,3554.0,0.7226963378769896,0.3031021199123848,3581.0,376.5098688602448,595.5867800712585,376.5098688602448,218.9048209190369,0.1307234764099121,0.0 -1200,0.13907877,0.26873428,,,,,,,,,,,,,, -1300,0.27435774,0.29432353,,,,,,,,,,,,,, -1400,0.2798384,0.29689264,,,,,,,,,,,,,, -1458,,,0.730339595249721,0.2809938362666538,0.7092974850080894,0.2982144417491734,3554.0,0.7262398198608978,0.300171648380777,3581.0,456.7718312740326,679.8992731571198,456.7718312740326,222.9181063175201,0.1552250385284423,0.0 -1500,0.16771577,0.27619728,,,,,,,,,,,,,, -1600,0.41387638,0.2531028,,,,,,,,,,,,,, -1700,0.17003193,0.4150604,,,,,,,,,,,,,, -1800,0.1116036,0.25742355,,,,,,,,,,,,,, -1806,,,0.734905515398298,0.2769751208169119,0.7133297896076604,0.2946085932553988,3554.0,0.7303801203399888,0.2963384155852939,3581.0,536.8662204742432,764.0558526515961,536.8662204742432,226.938090801239,0.1846945285797119,0.0 -1900,0.082416445,0.3131087,,,,,,,,,,,,,, -2000,0.09490412,0.27362847,,,,,,,,,,,,,, -2100,0.17937307,0.31237322,,,,,,,,,,,,,, -2152,,,0.7232328142438617,0.2946501118796212,0.7039935745814575,0.3114003024211276,3554.0,0.720800208566043,0.3137618477402436,3581.0,616.8835611343384,848.1295416355133,616.8835611343384,230.95373272895813,0.2125940322875976,0.0 -2200,0.60367304,0.2881979,,,,,,,,,,,,,, -2300,0.17214528,0.28092897,,,,,,,,,,,,,, -2400,0.23097578,0.33894622,,,,,,,,,,,,,, -2498,,,0.7367127282278878,0.2758675473076956,0.7151943672622397,0.2935951417526906,3554.0,0.7322592054855487,0.2952626219544121,3581.0,696.9776549339294,932.2862796783448,696.9776549339294,234.9765722751617,0.2394816875457763,0.0 -2500,0.29233947,0.3012068,,,,,,,,,,,,,, -2600,0.10704054,0.24765632,,,,,,,,,,,,,, -2700,0.14505628,0.29977313,,,,,,,,,,,,,, -2800,0.21134044,0.23725085,,,,,,,,,,,,,, -2844,,,0.7388166018894741,0.2741975443703787,0.7174660978387029,0.2919819863512063,3554.0,0.7344461764346552,0.2935845556954237,3581.0,777.0949683189392,1016.4557588100432,777.0949683189392,238.9914553165436,0.2640409469604492,0.0 -2900,0.21890311,0.25156978,,,,,,,,,,,,,, -3000,0.09432188,0.26394603,,,,,,,,,,,,,, -3100,0.18475823,0.2874429,,,,,,,,,,,,,, -3188,,,0.7401871000017438,0.2737694127219064,0.7183199030801561,0.2919790324832055,3554.0,0.7352607511868193,0.2936349723366378,3581.0,857.13303399086,1100.553381204605,857.13303399086,243.0097703933716,0.2926876544952392,0.0 -3200,0.19263089,0.26554966,,,,,,,,,,,,,, -3300,0.07562012,0.2633933,,,,,,,,,,,,,, -3400,0.14303961,0.28577325,,,,,,,,,,,,,, -3500,0.10661849,0.32846478,,,,,,,,,,,,,, -3535,,,0.7400742939540318,0.2726729597364153,0.7181579212023425,0.2907690800638365,3554.0,0.7353010435937937,0.2923930662698967,3581.0,937.29634308815,1184.7737319469452,937.29634308815,247.0268371105194,0.3201866149902344,0.0 -3600,0.096967414,0.24903819,,,,,,,,,,,,,, -3700,0.4245347,0.2559181,,,,,,,,,,,,,, -3800,0.07120775,0.2650801,,,,,,,,,,,,,, -3881,,,0.7403202056884766,0.2732805865151541,0.7185890485412564,0.2912120915693584,3554.0,0.7356736290491482,0.2927929223898701,3581.0,1017.441393136978,1268.9722275733948,1017.441393136978,251.04173159599304,0.3457553386688232,0.0 -3900,0.13351002,0.32221186,,,,,,,,,,,,,, -4000,0.29737294,0.23414296,,,,,,,,,,,,,, -4100,0.116450734,0.24462391,,,,,,,,,,,,,, -4200,0.14830562,0.294327,,,,,,,,,,,,,, -4225,,,0.7391737529209682,0.2730711357934134,0.717479836759637,0.2909697026567776,3554.0,0.7346395254468026,0.2924918542524783,3581.0,1097.5015771389008,1353.090767621994,1097.5015771389008,255.06102967262268,0.3724069595336914,0.0 -4300,0.2209918,0.29089397,,,,,,,,,,,,,, -4400,0.056159373,0.2936174,,,,,,,,,,,,,, -4500,0.11164372,0.25371975,,,,,,,,,,,,,, -4572,,,0.7415586880275181,0.2717891590935843,0.7190310983223129,0.2903687622551175,3554.0,0.7362364273902192,0.2918655834395071,3581.0,1177.474592924118,1437.1194834709167,1177.474592924118,259.0786759853363,0.3975775241851806,0.0 -4600,0.053554863,0.24674736,,,,,,,,,,,,,, -4700,0.13786821,0.2787198,,,,,,,,,,,,,, -4800,0.14424312,0.25193053,,,,,,,,,,,,,, -4900,0.044976797,0.28440303,,,,,,,,,,,,,, -4919,,,0.740943159375872,0.2720399243491037,0.7187721883573087,0.2903080705718908,3554.0,0.7359213148605487,0.2918626177547298,3581.0,1257.6327981948853,1521.3335707187653,1257.6327981948853,263.0961480140686,0.4233999252319336,0.0 -5000,0.22666843,0.2575603,,,,,,,,,,,,,, -5100,0.14599155,0.23257035,,,,,,,,,,,,,, -5200,0.1036556,0.2572719,,,,,,,,,,,,,, -5264,,,0.7398541995457241,0.2729257345199585,0.7182161742271033,0.2908276422143184,3554.0,0.7353418132373988,0.2923846464521956,3581.0,1337.6137821674347,1605.3700017929075,1337.6137821674347,267.11389446258545,0.4483840465545654,0.0 -5300,0.18958907,0.2586218,,,,,,,,,,,,,, -5400,0.09539702,0.3216996,,,,,,,,,,,,,, -5500,0.11499384,0.27451822,,,,,,,,,,,,,, -5600,0.0932003,0.21680972,,,,,,,,,,,,,, -5607,,,0.7430362701416016,0.2704073531287057,0.72083666731148,0.2888049982743915,3554.0,0.7379724097231919,0.2903670263303721,3581.0,1417.68976521492,1689.497330904007,1417.68976521492,271.1278207302093,0.4733326435089111,0.0 -5700,0.14337656,0.25422114,,,,,,,,,,,,,, -5800,0.1340008,0.35856125,,,,,,,,,,,,,, -5900,0.106755316,0.2521864,,,,,,,,,,,,,, -5954,,,0.7426825250898089,0.2703781127929687,0.7202349712691686,0.2887011320321292,3554.0,0.7374507900856954,0.2902548416337964,3581.0,1497.8008234500885,1773.6630256175995,1497.8008234500885,275.14325308799744,0.4997069835662842,0.0 -6000,0.2922615,0.22825944,,,,,,,,,,,,,, -6100,0.09331991,0.34682548,,,,,,,,,,,,,, -6200,0.0844489,0.23744544,,,,,,,,,,,,,, -6296,,,0.742091178894043,0.2712857723236084,0.7200937351619654,0.2893256346831915,3554.0,0.7372403969081611,0.2908137539051592,3581.0,1577.8476405143738,1857.763218164444,1577.8476405143738,279.15823125839233,0.5257272720336914,0.0 -6300,0.19263057,0.2634644,,,,,,,,,,,,,, -6400,0.2832348,0.21498433,,,,,,,,,,,,,, -6500,0.12802143,0.3013669,,,,,,,,,,,,,, -6600,0.26923656,0.33828068,,,,,,,,,,,,,, -6638,,,0.7428139277866909,0.2704795088086809,0.720252831866383,0.2890228975604073,3554.0,0.7374012938294122,0.2905435016187866,3581.0,1657.8847556114197,1941.856261730194,1657.8847556114197,283.1754529476166,0.5520541667938232,0.0 -6700,0.104618445,0.230501,,,,,,,,,,,,,, -6800,0.09704683,0.3652305,,,,,,,,,,,,,, -6900,0.040561378,0.327985,,,,,,,,,,,,,, -6984,,,0.7429917199271066,0.2706403732299804,0.7208075407990996,0.2890006061611916,3554.0,0.737995862494764,0.2904899147628106,3581.0,1737.8613169193268,2025.891048192978,1737.8613169193268,287.19444942474365,0.5787684917449951,0.0 -7000,0.1835466,0.2529319,,,,,,,,,,,,,, -7100,0.11880946,0.23121506,,,,,,,,,,,,,, -7200,0.0695238,0.2667539,,,,,,,,,,,,,, -7300,0.085124224,0.28837654,,,,,,,,,,,,,, -7331,,,0.7436522756304059,0.2696965081351144,0.7215000511087859,0.2879728146597584,3554.0,0.7387117856176696,0.2893841915775098,3581.0,1817.926639080048,2110.0144007205963,1817.926639080048,291.2129316329956,0.6057913303375244,0.0 -7400,0.24265218,0.22021408,,,,,,,,,,,,,, -7500,0.14301114,0.16140318,,,,,,,,,,,,,, -7600,0.15862502,0.2903723,,,,,,,,,,,,,, -7674,,,0.742138317653111,0.2708191360746111,0.7195088006031936,0.2893724157089722,3554.0,0.7367458434052988,0.2907413162022654,3581.0,1898.0349142551424,2194.177098274231,1898.0349142551424,295.2284879684448,0.6318953037261963,0.0 -7700,0.12172638,0.24463987,,,,,,,,,,,,,, -7800,0.13873191,0.30280572,,,,,,,,,,,,,, -7900,0.15177882,0.29474583,,,,,,,,,,,,,, -8000,0.15413769,0.29516256,,,,,,,,,,,,,, -8020,,,0.744044576372419,0.269753132547651,0.7217948883520329,0.2882133659916643,3554.0,0.7389935597598436,0.2896054248442299,3581.0,1978.042454719544,2278.243096113205,1978.042454719544,299.24297618865967,0.6634461879730225,0.0 -8100,0.15633269,0.2556499,,,,,,,,,,,,,, -8200,0.069297306,0.29509392,,,,,,,,,,,,,, -8300,0.18060154,0.21435258,,,,,,,,,,,,,, -8365,,,0.7434047971452985,0.2699533700942993,0.7211110335625351,0.2882835031830332,3554.0,0.7383492221315624,0.2896580572269442,3581.0,2058.131326675415,2362.3873484134674,2058.131326675415,303.2593698501587,0.6898679733276367,0.0 -8400,0.12075894,0.26852754,,,,,,,,,,,,,, -8500,0.19381467,0.26431412,,,,,,,,,,,,,, -8600,0.09719217,0.28867292,,,,,,,,,,,,,, -8700,0.20338535,0.21802297,,,,,,,,,,,,,, -8708,,,0.7450571060180664,0.2695001704352243,0.7223065944622257,0.2882594944187007,3554.0,0.7395078844945546,0.2896519554157358,3581.0,2138.2803223133087,2446.592636823654,2138.2803223133087,307.2768743038177,0.7161507606506348,0.0 -8800,0.12903616,0.34600314,,,,,,,,,,,,,, -8900,0.1442521,0.28313422,,,,,,,,,,,,,, -9000,0.7551182,0.24117884,,,,,,,,,,,,,, -9053,,,0.7446896008082798,0.2695177282605852,0.7220156041168402,0.288230230517111,3554.0,0.7392402910979824,0.2896583640219212,3581.0,2218.340068101883,2530.707357406616,2218.340068101883,311.29213285446167,0.7432451248168945,0.0 -9100,0.17959633,0.28321996,,,,,,,,,,,,,, -9200,0.1374806,0.23753478,,,,,,,,,,,,,, -9300,0.07916075,0.4162172,,,,,,,,,,,,,, -9398,,,0.7459303992135184,0.2685531207493373,0.7233655904878307,0.2872594898848568,3554.0,0.7406060741587546,0.2885951148967641,3581.0,2298.3692531585693,2614.7996389865875,2298.3692531585693,315.3143537044525,0.7714033126831055,0.0 -9400,0.10104478,0.30283606,,,,,,,,,,,,,, -9500,0.14098457,0.27868956,,,,,,,,,,,,,, -9600,0.1185856,0.33302817,,,,,,,,,,,,,, -9700,0.070823714,0.40033928,,,,,,,,,,,,,, -9740,,,0.7458776746477399,0.2690248829977853,0.7237394265264491,0.2874351591626512,3554.0,0.7407723570362678,0.2888772980989074,3581.0,2378.3575394153595,2698.840270757675,2378.3575394153595,319.326447725296,0.7993285655975342,0.0 -9800,0.13998096,0.24627438,,,,,,,,,,,,,, -9900,0.108107254,0.28074643,,,,,,,,,,,,,, -10000,0.072608754,0.29635066,,,,,,,,,,,,,, -10084,,,0.7447167805262974,0.2694817951747349,0.7220355255521947,0.288159372032393,3554.0,0.7392846741046495,0.2895506108083461,3581.0,2458.373385190964,2782.9115421772003,2458.373385190964,323.3411955833435,0.8272280693054199,0.0 -10100,0.22265385,0.24293818,,,,,,,,,,,,,, -10200,0.12078932,0.36512873,,,,,,,,,,,,,, -10300,0.09212808,0.29488572,,,,,,,,,,,,,, -10400,0.33129117,0.2861897,,,,,,,,,,,,,, -10429,,,0.7417423384530204,0.2712743282318115,0.7195695953283272,0.2896735385085467,3554.0,0.7367384803258518,0.2910769158187308,3581.0,2538.4143187999725,2867.0138463974,2538.4143187999725,327.36353492736816,0.8536674976348877,0.0 -10500,0.1290623,0.24649057,,,,,,,,,,,,,, -10600,0.1584398,0.24755692,,,,,,,,,,,,,, -10700,0.19048995,0.2544643,,,,,,,,,,,,,, -10772,,,0.7439629690987724,0.2697212696075439,0.7215402374525183,0.2880775739318813,3554.0,0.7387032635349763,0.2895300555448897,3581.0,2618.481980085373,2951.143794298172,2618.481980085373,331.38250207901,0.8842437267303467,0.0 -10800,0.095460035,0.23375031,,,,,,,,,,,,,, -10900,0.15333661,0.24753289,,,,,,,,,,,,,, -11000,0.12478778,0.26852933,,,,,,,,,,,,,, -11100,0.15594102,0.24528226,,,,,,,,,,,,,, -11116,,,0.7449778829302106,0.2687648705073765,0.7219717082644556,0.2877332593996201,3554.0,0.7391950899713767,0.2891032014669435,3581.0,2698.670265674591,3035.3925642967224,2698.670265674591,335.40252470970154,0.9122757911682128,0.0 -11200,0.14527102,0.3871883,,,,,,,,,,,,,, -11300,0.20473725,0.28170702,,,,,,,,,,,,,, -11400,0.1253241,0.26872957,,,,,,,,,,,,,, -11460,,,0.7453946386064801,0.2682352236339024,0.7226209409731992,0.2870142501461821,3554.0,0.7399028318948967,0.2883605190043807,3581.0,2778.822010755539,3119.602737426758,2778.822010755539,339.4211540222168,0.9393706321716307,0.0 -11500,0.1137302,0.36824518,,,,,,,,,,,,,, -11600,0.20599401,0.25748336,,,,,,,,,,,,,, -11700,0.17094482,0.2407554,,,,,,,,,,,,,, -11800,0.10847823,0.27308646,,,,,,,,,,,,,, -11804,,,0.7457089424133301,0.2681718553815569,0.7230791339863534,0.2868413629998769,3554.0,0.7403026880148702,0.2882328241173031,3581.0,2858.9452958106995,3203.785383462906,2858.9452958106995,343.4411287307739,0.9664688110351562,0.0 -11900,0.16378993,0.31446755,,,,,,,,,,,,,, -12000,0.1481909,0.33943087,,,,,,,,,,,,,, -12100,0.07427148,0.32722574,,,,,,,,,,,,,, -12146,,,0.7462355749947684,0.2677620989935739,0.7232231178777434,0.2868784065654456,3554.0,0.7404194746361002,0.2882499364593514,3581.0,2938.934836626053,3287.837063789368,2938.934836626053,347.4586265087128,0.9985432624816896,0.0 -12200,0.0880428,0.26853317,,,,,,,,,,,,,, -12300,0.08048114,0.22964446,,,,,,,,,,,,,, -12400,0.19051552,0.2648278,,,,,,,,,,,,,, -12492,,,0.7459902082170758,0.2678097145898001,0.7231852671505697,0.2866928452645786,3554.0,0.7405588277323024,0.2879755594849553,3581.0,3018.917544603348,3371.8902475833893,3018.917544603348,351.48676466941833,1.028304100036621,0.0 -12500,0.0658767,0.3471501,,,,,,,,,,,,,, -12600,0.13560927,0.27771008,,,,,,,,,,,,,, -12700,0.14001308,0.3060327,,,,,,,,,,,,,, -12800,0.28516525,0.26396823,,,,,,,,,,,,,, -12834,,,0.7463996069771903,0.2683678013937814,0.7237146964687676,0.2872475370236441,3554.0,0.7409299814777646,0.2885652194306758,3581.0,3098.916193246841,3455.9489629268646,3098.916193246841,355.50565004348755,1.0569570064544678,0.0 -12900,0.18505424,0.30532035,,,,,,,,,,,,,, -13000,0.11735014,0.3116727,,,,,,,,,,,,,, -13100,0.21320385,0.2561605,,,,,,,,,,,,,, -13177,,,0.7463410241263253,0.2676967552730015,0.7232975828292065,0.2867473029124314,3554.0,0.7405494193530089,0.28811678743935,3581.0,3178.969719648361,3540.0576510429382,3178.969719648361,359.51924538612366,1.0859599113464355,0.0 -13200,0.08446617,0.24011794,,,,,,,,,,,,,, -13300,0.12143488,0.24186306,,,,,,,,,,,,,, -13400,0.107449524,0.29406732,,,,,,,,,,,,,, -13500,0.12227946,0.24911337,,,,,,,,,,,,,, -13522,,,0.7465722220284599,0.2675368956157139,0.7237550202017093,0.2864089819844277,3554.0,0.7410320419401005,0.2877505083251885,3581.0,3258.99784898758,3624.1457312107086,3258.99784898758,363.53816270828247,1.11417555809021,0.0 -13600,0.16750707,0.20631424,,,,,,,,,,,,,, -13700,0.119657084,0.28157997,,,,,,,,,,,,,, -13800,0.15380892,0.3287326,,,,,,,,,,,,,, -13868,,,0.7468410900660923,0.2678717374801636,0.7240890820642234,0.2867557008278524,3554.0,0.7413182475652751,0.2880376684236246,3581.0,3339.0577857494354,3708.263298511505,3339.0577857494354,367.5555286407471,1.1418681144714355,0.0 -13900,0.21506718,0.2533751,,,,,,,,,,,,,, -14000,0.10466316,0.24618776,,,,,,,,,,,,,, -14100,0.2608076,0.3197033,,,,,,,,,,,,,, -14200,0.18469876,0.24670446,,,,,,,,,,,,,, -14211,,,0.7465951783316476,0.2678675992148263,0.723388877958814,0.2872121421285875,3554.0,0.7405325115409452,0.2886754610923974,3581.0,3419.144530057907,3792.405335187912,3419.144530057907,371.569610118866,1.1705811023712158,0.0 -14300,0.07051839,0.2756072,,,,,,,,,,,,,, -14400,0.14226225,0.2894485,,,,,,,,,,,,,, -14500,0.23808691,0.3328777,,,,,,,,,,,,,, -14559,,,0.7471063477652413,0.2675107717514038,0.7242466674873382,0.2865292147162528,3554.0,0.7414607367879084,0.2878229119397514,3581.0,3499.270180702209,3876.59183716774,3499.270180702209,375.5900700092316,1.198145866394043,0.0 -14600,0.21083272,0.22095884,,,,,,,,,,,,,, -14700,0.13315904,0.27102622,,,,,,,,,,,,,, -14800,0.11004412,0.28605503,,,,,,,,,,,,,, -14900,0.17465325,0.27972814,,,,,,,,,,,,,, -14905,,,0.7477654048374721,0.2676566668919155,0.7249977742948087,0.2865739349038935,3554.0,0.742121709521607,0.2879113711581088,3581.0,3579.3831877708435,3960.7610342502594,3579.3831877708435,379.6057589054108,1.2259979248046875,0.0 -15000,0.18644825,0.23815186,,,,,,,,,,,,,, -15100,0.18748122,0.19481215,,,,,,,,,,,,,, -15200,0.07273088,0.19491257,,,,,,,,,,,,,, -15248,,,0.7476011684962681,0.2673081500189645,0.7246028490125562,0.2864856623368915,3554.0,0.7418195505576306,0.2878147989170274,3581.0,3659.4731526374817,4044.9105067253113,3659.4731526374817,383.6252100467682,1.2533009052276611,0.0 -15300,0.14827794,0.2275475,,,,,,,,,,,,,, -15400,0.14549421,0.29778114,,,,,,,,,,,,,, -15500,0.14528474,0.26298183,,,,,,,,,,,,,, -15592,,,0.7472316878182548,0.2675517116274152,0.7243855679779826,0.2864906942166836,3554.0,0.741601521594003,0.2878657609715337,3581.0,3739.480894804001,4128.976114034653,3739.480894804001,387.6414725780487,1.2822730541229248,0.0 -15600,0.13145503,0.35948935,,,,,,,,,,,,,, -15700,0.19265896,0.25518307,,,,,,,,,,,,,, -15800,0.08288138,0.2668793,,,,,,,,,,,,,, -15900,0.05666358,0.19008718,,,,,,,,,,,,,, -15936,,,0.7483240536281041,0.2673044204711914,0.7253718164172411,0.2863633515932752,3554.0,0.7426279894102555,0.2876255064162419,3581.0,3819.596107006073,4213.149703264236,3819.596107006073,391.657977104187,1.3116509914398191,0.0 -16000,0.18708162,0.26199558,,,,,,,,,,,,,, -16100,0.12813985,0.24290365,,,,,,,,,,,,,, -16200,0.15554401,0.21397181,,,,,,,,,,,,,, -16279,,,0.7473345484052386,0.2675484078271048,0.7243296505697805,0.2866905955162757,3554.0,0.7415096876308992,0.2881037997853253,3581.0,3899.568253993988,4297.175599813461,3899.568253993988,395.6702156066895,1.340677261352539,0.0 -16300,0.2717237,0.24231821,,,,,,,,,,,,,, -16400,0.09133206,0.28724653,,,,,,,,,,,,,, -16500,0.16587076,0.3338519,,,,,,,,,,,,,, -16600,0.22693153,0.32609722,,,,,,,,,,,,,, -16625,,,0.7466865948268345,0.2676364524023873,0.7237747355532499,0.2866642339617332,3554.0,0.7408988929200991,0.288077960830599,3581.0,3979.7582840919495,4381.423604011536,3979.7582840919495,399.6851809024811,1.3709535598754885,0.0 -16700,0.13091421,0.23359978,,,,,,,,,,,,,, -16800,0.19608565,0.22665298,,,,,,,,,,,,,, -16900,0.096987784,0.4334288,,,,,,,,,,,,,, -16969,,,0.7481282779148647,0.2668043034417288,0.7251985686242614,0.2858984608561656,3554.0,0.7424319815083077,0.2871888348990331,3581.0,4059.7816610336304,4465.504558324814,4059.7816610336304,403.7024691104889,1.3988149166107178,0.0 -17000,0.3629632,0.22565082,,,,,,,,,,,,,, -17100,0.1174817,0.26978406,,,,,,,,,,,,,, -17200,0.1415644,0.27310252,,,,,,,,,,,,,, -17300,0.07684703,0.30829978,,,,,,,,,,,,,, -17310,,,0.7480014392307827,0.2669447490147182,0.7251391477912211,0.2859115471783554,3554.0,0.7423032639713069,0.28726846523972,3581.0,4139.849115371704,4549.632168054581,4139.849115371704,407.7211818695069,1.4279568195343018,0.0 -17400,0.15009664,0.29678568,,,,,,,,,,,,,, -17500,0.08222038,0.28413427,,,,,,,,,,,,,, -17600,0.14174695,0.2755594,,,,,,,,,,,,,, -17656,,,0.7473434720720563,0.2665583406175886,0.7239926348392656,0.2859433184330156,3554.0,0.7412486391938355,0.287273135341036,3581.0,4219.889711856842,4633.732095003128,4219.889711856842,411.7394840717316,1.4562327861785889,0.0 -17700,0.086660564,0.21659839,,,,,,,,,,,,,, -17800,0.26539177,0.30593824,,,,,,,,,,,,,, -17900,0.07536023,0.33997604,,,,,,,,,,,,,, -18000,0.110748574,0.27563834,,,,,,,,,,,,,, -18003,,,0.7483083861214774,0.2665321145738874,0.7252950158492192,0.2856605886138418,3554.0,0.7424689332588662,0.286959454521258,3581.0,4300.033041477203,4717.930500507355,4300.033041477203,415.75352668762207,1.4845950603485107,0.0 -18100,0.14971568,0.22907567,,,,,,,,,,,,,, -18200,0.05756489,0.23663718,,,,,,,,,,,,,, -18300,0.15884255,0.31028253,,,,,,,,,,,,,, -18345,,,0.7470991952078683,0.2668859277452741,0.7242033898863957,0.285921439201428,3554.0,0.7414670090407708,0.2871852897126326,3581.0,4380.129502534866,4802.0883066654205,4380.129502534866,419.7735161781311,1.5131938457489014,0.0 -18400,0.12821802,0.33342308,,,,,,,,,,,,,, -18500,0.2021974,0.25519103,,,,,,,,,,,,,, -18600,0.049530476,0.29364705,,,,,,,,,,,,,, -18691,,,0.7492048399788993,0.2664386204310826,0.7258765843723621,0.2859228817881261,3554.0,0.743057638731325,0.28718869854571,3581.0,4460.136431455612,4886.153877258301,4460.136431455612,423.7905547618866,1.542210578918457,0.0 -18700,0.10558745,0.2786112,,,,,,,,,,,,,, -18800,0.1516203,0.2847116,,,,,,,,,,,,,, -18900,0.14891882,0.2433894,,,,,,,,,,,,,, -19000,0.13251624,0.27284735,,,,,,,,,,,,,, -19038,,,0.7488509586879185,0.266427789415632,0.7257310891996693,0.2856958117723867,3554.0,0.7429034912995671,0.2870192454534347,3581.0,4540.380712032318,4970.458175897598,4540.380712032318,427.807642698288,1.5725533962249756,0.0 -19100,0.1548414,0.30341867,,,,,,,,,,,,,, -19200,0.3279035,0.22564757,,,,,,,,,,,,,, -19300,0.19211619,0.23761334,,,,,,,,,,,,,, -19384,,,0.7485593387058803,0.2664816209248134,0.7254547994996835,0.2856379709152539,3554.0,0.7426361024329796,0.2869319452383237,3581.0,4620.512376785278,5054.651271343231,4620.512376785278,431.82620668411255,1.60274076461792,0.0 -19400,0.08421687,0.22205774,,,,,,,,,,,,,, -19500,0.21605648,0.32932547,,,,,,,,,,,,,, -19600,0.07254297,0.2931977,,,,,,,,,,,,,, -19700,0.2209935,0.3338139,,,,,,,,,,,,,, -19728,,,0.7484092712402344,0.2663026877811977,0.7250737505275746,0.2859123371663091,3554.0,0.742244086629084,0.2871433269774504,3581.0,4700.570680141449,5138.771405696869,4700.570680141449,435.8466863632202,1.63158917427063,0.0 -19800,0.10429266,0.24873672,,,,,,,,,,,,,, -19900,0.06964718,0.23127571,,,,,,,,,,,,,, -20000,0.104429975,0.23414993,,,,,,,,,,,,,, -20072,,,0.7486628804888044,0.2662466253553118,0.7255211584877954,0.2855280252004783,3554.0,0.7426688954071837,0.2868610755986456,3581.0,4780.684697151184,5222.948258399963,4780.684697151184,439.86440348625183,1.6642556190490725,0.0 -20100,0.11172674,0.3526554,,,,,,,,,,,,,, -20200,0.119868174,0.2805819,,,,,,,,,,,,,, -20300,0.16746771,0.26346016,,,,,,,,,,,,,, -20400,0.19745134,0.28103468,,,,,,,,,,,,,, -20417,,,0.7478485788617816,0.2665714366095407,0.7244868238252673,0.2859325677273846,3554.0,0.7418312769434167,0.2871418952675579,3581.0,4860.731207847595,5307.05890750885,4860.731207847595,443.8861367702484,1.6940970420837402,0.0 -20500,0.17541979,0.23485444,,,,,,,,,,,,,, -20600,0.2659572,0.28978774,,,,,,,,,,,,,, -20700,0.16051592,0.26311862,,,,,,,,,,,,,, -20761,,,0.7485369273594448,0.2659840413502284,0.724897548866594,0.2856829830549645,3554.0,0.7421095058991902,0.2869748965350984,3581.0,4940.794360637665,5391.183601856232,4940.794360637665,447.90532636642456,1.7241308689117432,0.0 -20800,0.10290359,0.25952908,,,,,,,,,,,,,, -20900,0.14415513,0.33501387,,,,,,,,,,,,,, -21000,0.11405571,0.30869597,,,,,,,,,,,,,, -21100,0.14176954,0.23351496,,,,,,,,,,,,,, -21107,,,0.7491487775530133,0.2664053269795009,0.7259544153594542,0.285788927309018,3554.0,0.7430941132452528,0.2870494818028309,3581.0,5020.885152101517,5475.335332632065,5020.885152101517,451.92284989357,1.7548644542694092,0.0 -21200,0.16045356,0.2860969,,,,,,,,,,,,,, -21300,0.087782,0.31159922,,,,,,,,,,,,,, -21400,0.1291721,0.26223642,,,,,,,,,,,,,, -21449,,,0.7486867223467145,0.2660897970199585,0.7254420223032146,0.2854990189036561,3554.0,0.7426107407148841,0.28676784401398,3581.0,5100.921221733093,5559.43258523941,5100.921221733093,455.9429025650024,1.7836058139801023,0.0 -21500,0.19738503,0.21698746,,,,,,,,,,,,,, -21600,0.0917469,0.24899468,,,,,,,,,,,,,, -21700,0.1637088,0.32892925,,,,,,,,,,,,,, -21791,,,0.7492446218218122,0.2659718820026943,0.7256659667144415,0.2857284245359542,3554.0,0.7427518664042865,0.2870056101211254,3581.0,5181.017259836197,5643.584321975708,5181.017259836197,459.9567103385925,1.8128316402435305,0.0 -21800,0.44807234,0.2675401,,,,,,,,,,,,,, -21900,0.17930269,0.23010516,,,,,,,,,,,,,, -22000,0.1887371,0.23964265,,,,,,,,,,,,,, -22100,0.49168712,0.3058823,,,,,,,,,,,,,, -22138,,,0.748035022190639,0.2667830841881888,0.7249676173633582,0.286072601679006,3554.0,0.7421120284356674,0.2873478910504223,3581.0,5261.106295824051,5727.733413934708,5261.106295824051,463.9753234386444,1.841802835464477,0.0 -22200,0.16045177,0.2590931,,,,,,,,,,,,,, -22300,0.13558695,0.32886,,,,,,,,,,,,,, -22400,0.13241142,0.22656764,,,,,,,,,,,,,, -22483,,,0.7493149893624442,0.2660031999860491,0.7261346700021103,0.2854326255682418,3554.0,0.7432496923869031,0.286731574030037,3581.0,5341.142727136612,5811.8311512470245,5341.142727136612,467.9927349090576,1.872994899749756,0.0 -22500,0.14610401,0.22516042,,,,,,,,,,,,,, -22600,0.09150885,0.21682373,,,,,,,,,,,,,, -22700,0.2979061,0.2979715,,,,,,,,,,,,,, -22800,0.16230318,0.2646105,,,,,,,,,,,,,, -22826,,,0.7495972088405064,0.2659002372196742,0.7261499202043472,0.2854917201019098,3554.0,0.7432502378001955,0.2868267486495567,3581.0,5421.144332408905,5895.894639015198,5421.144332408905,472.01139187812805,1.9036543369293213,0.0 -22900,0.1609072,0.20586994,,,,,,,,,,,,,, -23000,0.19275357,0.257452,,,,,,,,,,,,,, -23100,0.18917947,0.3162246,,,,,,,,,,,,,, -23169,,,0.7486049788338798,0.2657872268131801,0.7252260464661298,0.2853719510586663,3554.0,0.742398370414165,0.2866561024657044,3581.0,5501.226588726044,5980.0396893024445,5501.226588726044,476.02953267097473,1.9357259273529053,0.0 -23200,0.08427577,0.27302065,,,,,,,,,,,,,, -23300,0.11392591,0.2630392,,,,,,,,,,,,,, -23400,0.0995108,0.3006045,,,,,,,,,,,,,, -23500,0.15753293,0.27163136,,,,,,,,,,,,,, -23514,,,0.7499886240277972,0.265853796686445,0.7267512727736354,0.2853215979134426,3554.0,0.7437929922027716,0.2866178894469073,3581.0,5581.201484680176,6064.078918218613,5581.201484680176,480.0495846271515,1.9672400951385496,0.0 -23600,0.30684772,0.25719574,,,,,,,,,,,,,, -23700,0.13193437,0.28518522,,,,,,,,,,,,,, -23800,0.09913395,0.29175037,,,,,,,,,,,,,, -23856,,,0.7493819509233747,0.2660057204110281,0.7261044443760551,0.2854621299009479,3554.0,0.7432202400691148,0.2868497923611596,3581.0,5661.339335203171,6148.279658317566,5661.339335203171,484.0690758228302,1.9982903003692627,0.0 -23900,0.07557938,0.2283459,,,,,,,,,,,,,, -24000,0.107201934,0.1893141,,,,,,,,,,,,,, -24100,0.06141336,0.34947333,,,,,,,,,,,,,, -24199,,,0.7496466636657715,0.2655493191310337,0.7258055541511326,0.2854341025022422,3554.0,0.7429892575397934,0.2867679121906415,3581.0,5741.316986083984,6232.325395584106,5741.316986083984,488.0896954536438,2.0329601764678955,0.0 -24200,0.14140151,0.27435657,,,,,,,,,,,,,, -24300,0.27093995,0.27077484,,,,,,,,,,,,,, -24400,0.1237081,0.33024997,,,,,,,,,,,,,, -24500,0.10158507,0.30683112,,,,,,,,,,,,,, -24545,,,0.7499784060886928,0.26558062008449,0.7264893402460256,0.2852051433848744,3554.0,0.7436160055893954,0.2864897173231988,3581.0,5821.357635498047,6316.430392026901,5821.357635498047,492.1099181175232,2.064488172531128,0.0 -24600,0.20265457,0.17385434,,,,,,,,,,,,,, -24700,0.10076458,0.231936,,,,,,,,,,,,,, -24800,0.15268764,0.32010144,,,,,,,,,,,,,, -24889,,,0.7498251370021275,0.2658586502075195,0.7265470437139491,0.2853666100531531,3554.0,0.743647980443661,0.2866690901197291,3581.0,5901.505459070206,6400.642231225967,5901.505459070206,496.126501083374,2.0994303226470947,0.0 -24900,0.15461694,0.31626135,,,,,,,,,,,,,, -25000,0.130428,0.25670213,,,,,,,,,,,,,, -25100,0.065960325,0.23670086,,,,,,,,,,,,,, -25200,0.0817384,0.23410535,,,,,,,,,,,,,, -25235,,,0.7498918942042759,0.2650754792349679,0.7259708333699705,0.2850697291454171,3554.0,0.7431427232049358,0.2863677151873603,3581.0,5981.699506759644,6484.898949146271,5981.699506759644,500.1464011669159,2.129655838012696,0.0 -25300,0.113125704,0.2933323,,,,,,,,,,,,,, -25400,0.13443819,0.23132561,,,,,,,,,,,,,, -25500,0.10953554,0.2554031,,,,,,,,,,,,,, -25580,,,0.7498783384050641,0.2653120585850307,0.726230086807998,0.2851230018113393,3554.0,0.7433790916905194,0.2863709535787838,3581.0,6061.817725658417,6569.079738616943,6061.817725658417,504.1658718585968,2.16015625,0.0 -25600,0.0682341,0.22257186,,,,,,,,,,,,,, -25700,0.17300048,0.2698814,,,,,,,,,,,,,, -25800,0.07840691,0.21699964,,,,,,,,,,,,,, -25900,0.21029805,0.28200334,,,,,,,,,,,,,, -25925,,,0.7497959818158831,0.2654145956039428,0.72636321695185,0.2851072535732185,3554.0,0.7434775387897934,0.2863817254913083,3581.0,6141.91224861145,6653.2393543720245,6141.91224861145,508.1867387294769,2.191812992095948,0.0 -26000,0.10669189,0.36076328,,,,,,,,,,,,,, -26100,0.11539027,0.29273877,,,,,,,,,,,,,, -26200,0.09504781,0.28967205,,,,,,,,,,,,,, -26268,,,0.7502094677516392,0.2649457454681396,0.7263461806898917,0.2850390226571292,3554.0,0.7434729027768081,0.2863529208518046,3581.0,6221.960324764252,6737.352556943893,6221.960324764252,512.2066950798035,2.224546432495117,0.0 -26300,0.20401187,0.24435245,,,,,,,,,,,,,, -26400,0.11295939,0.3293309,,,,,,,,,,,,,, -26500,0.13524254,0.28416654,,,,,,,,,,,,,, -26600,0.07682547,0.31447387,,,,,,,,,,,,,, -26613,,,0.7500839233398438,0.2652312006269182,0.7263961216674873,0.2850700726184405,3554.0,0.7435960298275621,0.2863526822334892,3581.0,6301.933490514755,6821.390177488327,6301.933490514755,516.2275338172913,2.2554852962493896,0.0 -26700,0.07687565,0.24760656,,,,,,,,,,,,,, -26800,0.109551884,0.22246183,,,,,,,,,,,,,, -26900,0.0870515,0.24968614,,,,,,,,,,,,,, -26958,,,0.7498068809509277,0.2652568135942731,0.7262263086047411,0.285042989770549,3554.0,0.7434109301914619,0.2862853236918807,3581.0,6382.069779396057,6905.588939666748,6382.069779396057,520.2461042404175,2.287006378173828,0.0 -27000,0.123749584,0.25419965,,,,,,,,,,,,,, -27100,0.3345115,0.28967088,,,,,,,,,,,,,, -27200,0.10864305,0.26621044,,,,,,,,,,,,,, -27300,0.1297907,0.24963671,,,,,,,,,,,,,, -27301,,,0.7503508159092495,0.2647902795246669,0.7261952586434299,0.2850366183459658,3554.0,0.7433469804829308,0.2863510459936121,3581.0,6462.058999300003,6989.5824983119965,6462.058999300003,524.2069962024689,2.317927598953247,0.0 -27400,0.08346716,0.24792624,,,,,,,,,,,,,, -27500,0.1939682,0.36477178,,,,,,,,,,,,,, -27600,0.16167349,0.28339309,,,,,,,,,,,,,, -27646,,,0.7499651908874512,0.2649988617215837,0.7260563581527856,0.2849680439568532,3554.0,0.743229239388439,0.28623681599719,3581.0,6542.14812707901,7073.7304475307465,6542.14812707901,528.22221159935,2.3489346504211426,0.0 -27700,0.07784699,0.30588406,,,,,,,,,,,,,, -27800,0.1021882,0.2670521,,,,,,,,,,,,,, -27900,0.08497104,0.25842777,,,,,,,,,,,,,, -27989,,,0.7497949600219727,0.2651307923453195,0.726063158918648,0.2850151684556573,3554.0,0.7432007415439124,0.2863091855234222,3581.0,6622.342363357544,7157.983325958252,6622.342363357544,532.2352705001831,2.3820061683654785,0.0 -28000,0.11958416,0.17673853,,,,,,,,,,,,,, -28100,0.13885175,0.23875095,,,,,,,,,,,,,, -28200,0.06586921,0.30990353,,,,,,,,,,,,,, -28300,0.14631122,0.30913988,,,,,,,,,,,,,, -28333,,,0.75055878502982,0.2647071395601545,0.7264404983821047,0.2850119913301913,3554.0,0.7435149677769827,0.2863208778208775,3581.0,6702.453696250916,7242.156578779221,6702.453696250916,536.2536578178406,2.413137912750244,0.0 -28400,0.16539213,0.3220885,,,,,,,,,,,,,, -28500,0.09580188,0.28740153,,,,,,,,,,,,,, -28600,0.11598412,0.31702423,,,,,,,,,,,,,, -28677,,,0.7505395071847099,0.2647926637104579,0.7266726174512873,0.2848629755590092,3554.0,0.7437867199499092,0.28617555926679,3581.0,6782.451961994171,7326.216543912888,6782.451961994171,540.2722911834717,2.4437241554260254,0.0 -28700,0.07849877,0.29163414,,,,,,,,,,,,,, -28800,0.12874438,0.23045895,,,,,,,,,,,,,, -28900,0.08862096,0.24936615,,,,,,,,,,,,,, -29000,0.10593884,0.29018503,,,,,,,,,,,,,, -29023,,,0.7501329013279506,0.2649626220975603,0.7262491839080966,0.2849998839061181,3554.0,0.7433983175090757,0.2862706316213174,3581.0,6862.631316900253,7410.462491750717,6862.631316900253,544.2934327125549,2.476641893386841,0.0 -29100,0.13622576,0.23781295,,,,,,,,,,,,,, -29200,0.085793756,0.24620008,,,,,,,,,,,,,, -29300,0.07717182,0.27637708,,,,,,,,,,,,,, -29366,,,0.7502886227199009,0.2647256510598318,0.7261591939759777,0.285006427067213,3554.0,0.7433084606691567,0.2862858691051731,3581.0,6942.597845315933,7494.490667819977,6942.597845315933,548.3099186420441,2.50946044921875,0.0 -29400,0.12271967,0.19695595,,,,,,,,,,,,,, -29500,0.10283769,0.22416006,,,,,,,,,,,,,, -29600,0.14249155,0.26146877,,,,,,,,,,,,,, -29700,0.07289438,0.3033096,,,,,,,,,,,,,, -29711,,,0.7505989074707031,0.2646856989179338,0.7264604198174592,0.2849455808211258,3554.0,0.7435891439847458,0.2862731200694638,3581.0,7022.648463249207,7578.602977514267,7022.648463249207,552.3254547119141,2.542933702468872,0.0 -29800,0.15414083,0.25701606,,,,,,,,,,,,,, -29900,0.098956175,0.23974787,,,,,,,,,,,,,, -30000,0.122166716,0.2948719,,,,,,,,,,,,,, -30056,,,0.7506581715175084,0.2647981473377773,0.7267338930386537,0.2849077472676034,3554.0,0.7438693500637042,0.2861787635698827,3581.0,7102.741122245789,7662.754905939102,7102.741122245789,556.3397581577301,2.5753839015960693,0.0 -30100,0.050707407,0.24938852,,,,,,,,,,,,,, -30200,0.12703222,0.25623626,,,,,,,,,,,,,, -30300,0.090138085,0.17996688,,,,,,,,,,,,,, -30397,,,0.7500651904514858,0.2648999180112566,0.7260751804744654,0.2850286325981728,3554.0,0.7432389204743787,0.2863478076021886,3581.0,7182.78594827652,7746.860567092895,7182.78594827652,560.3569049835205,2.6066181659698486,0.0 -30400,0.11574461,0.25362346,,,,,,,,,,,,,, -30500,0.16163264,0.21690042,,,,,,,,,,,,,, -30600,0.07522563,0.34406853,,,,,,,,,,,,,, -30700,0.045183484,0.28578293,,,,,,,,,,,,,, -30744,,,0.750626632145473,0.2645703213555472,0.7264586337577378,0.2849193566557927,3554.0,0.7435880531581611,0.2862245441981115,3581.0,7262.975407361984,7831.111491441727,7262.975407361984,564.373776435852,2.638613700866699,0.0 -30800,0.059634417,0.3233853,,,,,,,,,,,,,, -30900,0.06216059,0.24804944,,,,,,,,,,,,,, -31000,0.08467583,0.29085365,,,,,,,,,,,,,, -31088,,,0.7506256103515625,0.2646536486489432,0.7265625686946047,0.284842779345236,3554.0,0.7437144526886693,0.2861438230308398,3581.0,7342.940884590149,7915.147254228592,7342.940884590149,568.3998596668243,2.67030930519104,0.0 -31100,0.07067113,0.2032822,,,,,,,,,,,,,, -31200,0.059413575,0.19064875,,,,,,,,,,,,,, -31300,0.05264672,0.30600932,,,,,,,,,,,,,, -31400,0.072354935,0.19001085,,,,,,,,,,,,,, -31433,,,0.7506649153573173,0.2646878617150442,0.7265348847689224,0.2849661033342712,3554.0,0.743663320192509,0.2862593824721621,3581.0,7423.086068630218,7999.357291936874,7423.086068630218,572.4185557365417,2.7039809226989746,0.0 -31500,0.10154784,0.3266253,,,,,,,,,,,,,, -31600,0.08785094,0.22480585,,,,,,,,,,,,,, -31700,0.07197121,0.23043814,,,,,,,,,,,,,, -31779,,,0.7508582387651715,0.264353837285723,0.7265188102314294,0.28483960221977,3554.0,0.7436525482799846,0.2861585832780647,3581.0,7503.2064254283905,8083.542644500732,7503.2064254283905,576.4384963512421,2.736294269561768,0.0 -31800,0.061641425,0.24440353,,,,,,,,,,,,,, -31900,0.06849082,0.29842,,,,,,,,,,,,,, -32000,0.08783429,0.30423027,,,,,,,,,,,,,, -32100,0.053484663,0.3117317,,,,,,,,,,,,,, -32124,,,0.7506969315665108,0.2644423757280622,0.7265397620858539,0.28476238948412,3554.0,0.743674023928372,0.2860684878198303,3581.0,7583.260835409164,8167.663266420364,7583.260835409164,580.4590227603912,2.7696642875671387,0.0 -32200,0.057786517,0.2198379,,,,,,,,,,,,,, -32300,0.066553876,0.27684867,,,,,,,,,,,,,, -32400,0.08118868,0.26105884,,,,,,,,,,,,,, -32467,,,0.7506919588361468,0.2644823108400617,0.7265296639789673,0.2847786872790781,3554.0,0.7436811824778344,0.2860721693595539,3581.0,7663.235911130905,8251.700382471085,7663.235911130905,584.4764456748962,2.801800489425659,0.0 -32500,0.05148777,0.3058834,,,,,,,,,,,,,, -32600,0.06813754,0.2688253,,,,,,,,,,,,,, -32700,0.057371587,0.26086742,,,,,,,,,,,,,, -32800,0.08278305,0.23086272,,,,,,,,,,,,,, -32810,,,0.7511326926095145,0.2642045361655099,0.7267132846572524,0.2847665283340514,3554.0,0.7438340345530229,0.2860847479536093,3581.0,7743.370946645737,8335.898382425308,7743.370946645737,588.4945957660675,2.8336546421051025,0.0 -32900,0.12169177,0.2464498,,,,,,,,,,,,,, -33000,0.07763739,0.32256466,,,,,,,,,,,,,, -33100,0.06423656,0.25606686,,,,,,,,,,,,,, -33155,,,0.750725269317627,0.2643196071897234,0.7264477800101998,0.2847503507546514,3554.0,0.7435964388875315,0.2860520572443975,3581.0,7823.340132236481,8419.934628725052,7823.340132236481,592.5159690380096,2.867058038711548,0.0 -33200,0.043606192,0.32976887,,,,,,,,,,,,,, -33300,0.06586181,0.1971776,,,,,,,,,,,,,, -33400,0.06968609,0.26666552,,,,,,,,,,,,,, -33500,,,0.7509395054408482,0.2643869774682181,0.7267172002497186,0.2847576495563977,3554.0,0.7438440565222704,0.2860510005061435,3581.0,7903.390450239181,8504.049171447754,7903.390450239181,596.5357365608215,2.898998022079468,0.0 -33500,0.06168929,0.2531596,,,,,,,,,,,,,, -33600,0.07207161,0.2853372,,,,,,,,,,,,,, -33700,0.074908406,0.24556185,,,,,,,,,,,,,, -33800,0.100966014,0.3268312,,,,,,,,,,,,,, -33844,,,0.7513058526175362,0.2641360419137137,0.7268690840206458,0.2847431893421145,3554.0,0.7439712741727171,0.2860495006195895,3581.0,7983.478590726852,8588.198611021042,7983.478590726852,600.5509705543518,2.932687520980835,0.0 -33900,0.050031643,0.2135186,,,,,,,,,,,,,, -34000,0.1432645,0.26257655,,,,,,,,,,,,,, -34100,0.06624876,0.22285572,,,,,,,,,,,,,, -34190,,,0.7509686606270927,0.2641798939023699,0.7265910082609384,0.2846628166546497,3554.0,0.7437399507600879,0.2859617913445092,3581.0,8063.584829330444,8672.370778083801,8063.584829330444,604.5714876651764,2.965465784072876,0.0 -34200,0.079842374,0.23149669,,,,,,,,,,,,,, -34300,0.06618465,0.24997535,,,,,,,,,,,,,, -34400,0.054740008,0.18645462,,,,,,,,,,,,,, -34500,0.07637705,0.2484476,,,,,,,,,,,,,, -34534,,,0.7508820125034877,0.2642404351915632,0.726557279210045,0.2847020241002655,3554.0,0.7436853412541887,0.2859943457003979,3581.0,8143.587386846542,8756.439602851868,8143.587386846542,608.591876745224,2.998674154281616,0.0 -34600,0.059182595,0.22409673,,,,,,,,,,,,,, -34700,0.12942708,0.39675763,,,,,,,,,,,,,, -34800,0.10138907,0.34024477,,,,,,,,,,,,,, -34875,,,0.7510365758623395,0.264101539339338,0.7266363467000211,0.2846636409899057,3554.0,0.7437644943582449,0.2859594392496858,3581.0,8223.553173303604,8840.46463394165,8223.553173303604,612.6066925525665,3.0308783054351807,0.0 -34900,0.07138102,0.24535741,,,,,,,,,,,,,, -35000,0.06056174,0.32436273,,,,,,,,,,,,,, -35100,0.09493071,0.28037757,,,,,,,,,,,,,, -35200,0.05836576,0.25357732,,,,,,,,,,,,,, -35220,,,0.7510863712855748,0.2640958513532366,0.7266660227692389,0.2846420365367367,3554.0,0.7438080592449735,0.2859407588444219,3581.0,8303.650809764862,8924.624583244324,8303.650809764862,616.624080657959,3.063483715057373,0.0 -35300,0.06860242,0.3026769,,,,,,,,,,,,,, -35400,0.05657117,0.3092052,,,,,,,,,,,,,, -35500,0.066986345,0.2633338,,,,,,,,,,,,,, -35566,,,0.7510568073817662,0.264108487537929,0.7267169254713,0.2846336729686181,3554.0,0.7438390796259774,0.2859345547682211,3581.0,8383.785435676575,9008.82751774788,8383.785435676575,620.6433687210083,3.099930047988892,0.0 -35600,0.04369277,0.19503021,,,,,,,,,,,,,, -35700,0.06433414,0.30994266,,,,,,,,,,,,,, -35800,0.089018434,0.22954531,,,,,,,,,,,,,, -35900,0.07766331,0.28995532,,,,,,,,,,,,,, -35909,,,0.7510451589311872,0.2640914406095232,0.7266632062904473,0.2846232485623593,3554.0,0.7438034914086498,0.2859224874991273,3581.0,8463.990142822266,9093.09502840042,8463.990142822266,624.6613454818726,3.1323611736297607,0.0 -36000,0.065234624,0.22369924,,,,,,,,,,,,,, -36100,0.05683265,0.396511,,,,,,,,,,,,,, -36189,,,0.7512311254228864,0.2641005516052246,0.7268648249551561,0.2846340336152926,3554.0,0.7439900909313041,0.2859340775315903,3581.0,8528.737991809845,9161.908516168594,8528.737991809845,628.6836860179901,3.1654698848724365,0.0 -36189,,,,,,,,,,,8528.737991809845,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/eval_measurements.csv deleted file mode 100644 index aa5354711..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.9605116844177246,0.0,29.17657399177552,1,0,29.17657399177552,0.8949576268413153,3581,0.2836721463431129,33.137226819992065,0.8928326879228864,0.2656774180276053,0.8956671152882316,3554,0.2613217295323051 -7.972739458084106,0.0181217193603515,109.272686958313,345,0,109.272686958313,0.3249313665548206,3581,0.6965224312124756,117.27683687210084,0.3054479190281459,0.7004432678222656,0.3231287040130838,3554,0.6780527057980444 -11.985413312911987,0.0421442985534668,189.41457438468933,684,0,189.41457438468933,0.3102621692613795,3581,0.7156626198818417,201.46821546554563,0.2908952576773507,0.719583238874163,0.3081439033813133,3554,0.6983924912950197 -16.004467010498047,0.0665562152862548,269.56817269325256,1026,0,269.56817269325256,0.3069072980116413,3581,0.7192106015254119,285.6784076690674,0.2880307946886335,0.7229753902980259,0.305126321134637,3554,0.70181561214037 -20.019661903381348,0.0912399291992187,349.55478286743164,1372,0,349.55478286743164,0.3062297242608559,3581,0.7239998756457693,369.7181885242462,0.2867491585867746,0.7281092916216169,0.3044106607431767,3554,0.7068357451551069 -24.035717964172363,0.1169250011444091,429.7373430728912,1719,0,429.7373430728912,0.2991478394543249,3581,0.728800262561959,453.9560031890869,0.2795459202357701,0.7329938752310616,0.2974773486410558,3554,0.7116876450830051 -28.04653525352478,0.141535997390747,509.9147651195526,2067,0,509.9147651195526,0.3087029349234676,3581,0.7207346226176348,538.1823544502258,0.2891953672681536,0.7248802185058594,0.3067685000065946,3554,0.7033233213236846 -32.063100814819336,0.1654932498931884,590.0707066059113,2415,0,590.0707066059113,0.3006744853480173,3581,0.7260800819428931,622.3923435211182,0.2816903250558035,0.72991943359375,0.2988491112017269,3554,0.7089212446583075 -36.07767248153687,0.1900866031646728,670.0614335536957,2762,0,670.0614335536957,0.2960129061147375,3581,0.7312001492250768,706.4358239173889,0.2765935829707554,0.735365731375558,0.294382897131753,3554,0.713966107730726 -40.0933940410614,0.216646671295166,750.1785328388214,3107,0,750.1785328388214,0.2993197469064158,3581,0.7275827637400517,790.608763217926,0.2801121303013393,0.7322643143790108,0.2975265683253025,3554,0.7105216915579277 -44.107200622558594,0.2426018714904785,830.3153464794159,3455,0,830.3153464794159,0.2944108909217048,3581,0.7320430854684445,874.7988729476929,0.2747059549604143,0.7370850699288505,0.2927326122964441,3554,0.7150101970271173 -48.12011766433716,0.2682778835296631,910.3621096611024,3803,0,910.3621096611024,0.2935579327090896,3581,0.7343342985330564,958.8977530002594,0.2740594318934849,0.7389027050563267,0.2919730217052969,3554,0.7170717221132878 -52.13883900642395,0.293527603149414,990.4565176963806,4149,0,990.4565176963806,0.2925437026035848,3581,0.7352640236665736,1043.0497572422028,0.2729658229010446,0.7399940490722656,0.2910110224614871,3554,0.7180402473445414 -56.15374946594238,0.3187620639801025,1070.5930817127228,4496,0,1070.5930817127228,0.2942437217475914,3581,0.7336459187814158,1127.2398805618286,0.2744659355708531,0.7390618324279785,0.2924563912910629,3554,0.7164873371113534 -60.17170739173889,0.3441455364227295,1150.773977279663,4843,0,1150.773977279663,0.2972117927080424,3581,0.7299596748682281,1211.4776709079742,0.2777554307665144,0.7348019736153739,0.2955745424389772,3554,0.7128869154913478 -64.18590641021729,0.3689992427825928,1230.890167951584,5192,0,1230.890167951584,0.291365269008744,3581,0.7364080962239947,1295.6464822292328,0.2717779534203665,0.7410356657845634,0.2897226894981886,3554,0.7192005679120357 -68.20439505577087,0.3939557075500488,1310.9333474636078,5538,0,1310.9333474636078,0.2921690718483838,3581,0.7357066265533371,1379.7466549873352,0.2726975509098598,0.7400689806256976,0.2905688352912212,3554,0.7185819729969752 -72.22309613227844,0.4202187061309814,1390.929930686951,5887,0,1390.929930686951,0.2909360287476438,3581,0.7382192092379922,1463.8018777370453,0.2713005372456142,0.7430743489946637,0.2893864294083251,3554,0.7210018091411086 -76.23468589782715,0.4449434280395508,1471.0121450424194,6236,0,1471.0121450424194,0.2907333395328644,3581,0.7374493583758028,1547.9339079856873,0.2709931646074567,0.7420650890895298,0.2892046634843662,3554,0.7202885530608117 -80.25021290779114,0.4702737331390381,1551.1771442890167,6579,0,1551.1771442890167,0.2907080800797612,3581,0.7380951277139766,1632.15380692482,0.2706704991204398,0.7435024806431362,0.2891689422899374,3554,0.720853978351857 -84.26713466644287,0.4968934059143066,1631.2922222614288,6928,0,1631.2922222614288,0.2901382936308817,3581,0.7376407302647654,1716.3261096477509,0.2704187972205026,0.7426629747663226,0.2886667160351892,3554,0.720276600199599 -88.28663372993469,0.5245425701141357,1711.3607697486875,7279,0,1711.3607697486875,0.2898458498411757,3581,0.7389450179768221,1800.455798149109,0.2699455533708845,0.7441093581063407,0.2883482821952378,3554,0.7217107374613112 -92.30220317840576,0.5551633834838867,1791.4479806423187,7623,0,1791.4479806423187,0.2998467865885053,3581,0.7279741659539933,1884.60189819336,0.2796070405415126,0.7333408083234515,0.2982994513224535,3554,0.7108606307373734 -96.32317352294922,0.5827598571777344,1871.6037764549253,7970,0,1871.6037764549253,0.2914284687739982,3581,0.737044593536198,1968.8186626434328,0.2711213656834194,0.7422810282026019,0.2898536557619935,3554,0.7199205560635903 -100.34027457237244,0.6097853183746338,1951.662151813507,8317,0,1951.662151813507,0.2901830516091874,3581,0.7393386700205948,2052.933934688568,0.2702023983001709,0.7444945062909808,0.2886703911965391,3554,0.7221898136342854 -104.35645294189452,0.6381211280822754,2031.719271659851,8661,0,2031.719271659851,0.2982390104675719,3581,0.7288023078618053,2137.048096179962,0.2788037402289254,0.732374940599714,0.2963294274497046,3554,0.7117548971009777 -108.3740575313568,0.664588451385498,2111.8080835342407,9010,0,2111.8080835342407,0.2911604663174567,3581,0.7392850149879573,2221.1938774585724,0.2709218774523054,0.7443432807922363,0.289655952689751,3554,0.7221872032393079 -112.39150047302246,0.6922500133514404,2191.930754899978,9358,0,2191.930754899978,0.2904191814764556,3581,0.7372185121998045,2305.374571084976,0.2705019201551165,0.7424185616629464,0.289040895546831,3554,0.7198816749173467 -116.40704345703124,0.7198827266693115,2271.9885540008545,9704,0,2271.9885540008545,0.2890612387317613,3581,0.7396080360103672,2389.488274335861,0.2691502230507986,0.7448130335126605,0.2876758307101154,3554,0.7223689004686621 -120.42432022094728,0.7467184066772461,2351.9646701812744,10051,0,2351.9646701812744,0.2891233476704307,3581,0.7401323827143256,2473.5214030742645,0.2687918969563075,0.7456962721688407,0.28772884577127,3554,0.7229699095649268 -124.384624004364,0.773712158203125,2432.060645341873,10401,0,2432.060645341873,0.2891068830066671,3581,0.739315762662315,2557.617630958557,0.2693566083908081,0.7442936897277832,0.2877130803594981,3554,0.722090962098164 -128.40318179130554,0.8037357330322266,2512.202274799347,10748,0,2512.202274799347,0.2888093600556758,3581,0.7397670921617565,2641.820770740509,0.2689313207353864,0.7448418481009347,0.2874220900141126,3554,0.7225382326691756 -132.41613030433655,0.836331844329834,2592.4162380695343,11093,0,2592.4162380695343,0.289260280495148,3581,0.7395489950214674,2726.093177318573,0.2687351363045828,0.7452192306518555,0.2877938480409397,3554,0.7224126589318374 -136.42907905578613,0.8633337020874023,2672.483766078949,11443,0,2672.483766078949,0.2898771429288257,3581,0.7395929007915037,2810.2135293483734,0.2697348083768572,0.7452154840741839,0.2883542586258441,3554,0.7223850437007597 -140.4444704055786,0.8953886032104492,2752.4405829906464,11789,0,2752.4405829906464,0.288991903066968,3581,0.7399063089046356,2894.230272054672,0.2692332097462245,0.7447735241481236,0.2875390082312623,3554,0.7227943261553883 -144.46209335327148,0.9224934577941896,2832.651171684265,12137,0,2832.651171684265,0.288678972190467,3581,0.7404592216297822,2978.4996314048767,0.2685306753431047,0.7456636428833008,0.2873849434066369,3554,0.7233629800928532 -148.47949981689453,0.949455499649048,2912.848475217819,12485,0,2912.848475217819,0.2889450316121544,3581,0.7399579186374267,3062.7552905082703,0.2689953531537737,0.7451457296098981,0.287498426893553,3554,0.722774129941615 -152.49664402008057,0.9763846397399902,2992.999326467514,12834,0,2992.999326467514,0.2885619810392523,3581,0.7405707586480732,3146.96438741684,0.2684512138366699,0.7460318974086216,0.2871742742277627,3554,0.7233857180069991 -156.51189756393433,1.0034832954406738,3072.988017320633,13179,0,3072.988017320633,0.2888663216563983,3581,0.7396684405324979,3231.008982419968,0.2685987268175397,0.745182854788644,0.2874176076911579,3554,0.7225593219128095 -160.52774262428284,1.0317604541778564,3153.030436277389,13526,0,3153.030436277389,0.2885537316632051,3581,0.7403992943442823,3315.10955786705,0.2681436708995274,0.7462560108729771,0.2871634720011782,3554,0.7231661700504713 -164.54671454429626,1.0586469173431396,3233.1198103427887,13876,0,3233.1198103427887,0.2886610758168109,3581,0.7413314056609537,3399.258532524109,0.2683169841766357,0.7469626835414341,0.287253771059018,3554,0.7242131445202589 -168.5634524822235,1.0875027179718018,3313.156766176224,14220,0,3313.156766176224,0.2895751203181723,3581,0.7383425408187309,3483.354829549789,0.2695116315569196,0.7435752323695591,0.2880564846882474,3554,0.721224379660242 -172.5811402797699,1.114877223968506,3393.3027589321136,14567,0,3393.3027589321136,0.287807469925911,3581,0.7413343372574002,3567.560054063797,0.2674135140010288,0.7471098899841309,0.2864020266557048,3554,0.7242367754642656 -176.59819197654724,1.142707586288452,3473.428725004196,14914,0,3473.428725004196,0.2878112878189577,3581,0.7402175353645979,3651.744841337204,0.2676483733313424,0.7454605102539062,0.2864320290242948,3554,0.7230939720209623 -180.614520072937,1.1707472801208496,3553.576901912689,15260,0,3553.576901912689,0.2877523490950502,3581,0.7417507603061295,3735.9517509937286,0.2672039951596941,0.7476382255554199,0.2863704958321609,3554,0.7246152827360017 -184.6301848888397,1.1982817649841309,3633.779043912888,15610,0,3633.779043912888,0.287978763788048,3581,0.7408511010803547,3820.211359500885,0.2675726413726806,0.7465495382036481,0.2865305542610439,3554,0.7237017818830894 -188.6462926864624,1.226337432861328,3713.819554805756,15958,0,3713.819554805756,0.2877550420731813,3581,0.7408671907724798,3904.3096396923065,0.2675352437155587,0.7465058735438755,0.2864478459570203,3554,0.7236025868739449 -192.6677522659301,1.255561113357544,3793.9894206523895,16305,0,3793.9894206523895,0.2879381645860968,3581,0.7417191263351718,3988.544013738632,0.2676750591823033,0.7472934722900391,0.286674658367992,3554,0.7245056461469471 -196.6825180053711,1.2838151454925537,3874.103070020676,16653,0,3874.103070020676,0.287911780218078,3581,0.7405575323757331,4072.714470386505,0.267695529120309,0.7460013798304966,0.2865367024281619,3554,0.72336675829611 -200.70232272148127,1.3174071311950684,3954.122975349426,16999,0,3954.122975349426,0.2877251806954237,3581,0.7424530480967257,4156.801654577255,0.2673676184245518,0.7481280054364886,0.2864205398516636,3554,0.7252748883300506 -204.7229642868042,1.3461709022521973,4034.12496638298,17344,0,4034.12496638298,0.2879645830424462,3581,0.7410125434148981,4240.8676471710205,0.2677748032978603,0.746361528124128,0.28663830174847,3554,0.7239071787510551 -208.7395868301392,1.3738317489624023,4114.188055753708,17693,0,4114.188055753708,0.2882696395145385,3581,0.7423913482180257,4324.98748755455,0.2673057488032749,0.748692240033831,0.2868741990209095,3554,0.7252666449774902 -212.75606155395508,1.402545690536499,4194.349546909332,18042,0,4194.349546909332,0.2873323808599204,3581,0.7410724025237364,4409.206943273544,0.2669903721128191,0.7466691562107631,0.2859708306221862,3554,0.7238836165016531 -216.77316308021545,1.4304280281066897,4274.440649032593,18389,0,4274.440649032593,0.2874088750741762,3581,0.7420026730705459,4493.35611987114,0.2667764595576695,0.7479186739240374,0.2860190714083163,3554,0.7248502869785804 -220.79004406929016,4.727059602737427,4351.382478237152,18720,0,4351.382478237152,0.2871059661669226,3581,0.7416746069751815,4577.62344622612,0.2662788459232875,0.7479407446725028,0.285733319026537,3554,0.7245068826498312 -224.8030700683593,4.756349325180054,4431.4831800460815,19068,0,4431.4831800460815,0.2871270327553407,3581,0.7421484347729336,4661.779294967651,0.2664348908833095,0.7482282093593052,0.2857875362432734,3554,0.7249527106341446 -228.8146023750305,4.786910057067871,4511.659183979034,19413,0,4511.659183979034,0.2872160373869903,3581,0.7418394581428023,4746.009649276733,0.2666091578347342,0.747701713017055,0.2858498937706633,3554,0.7246733296769485 -232.8302583694458,4.815237998962402,4591.833205223084,19760,0,4591.833205223084,0.2871845397693556,3581,0.7423292392793563,4830.240214586258,0.2660056182316371,0.7489380155290876,0.2857684906641284,3554,0.7252018659652856 -236.8491904735565,4.845395565032959,4671.902544498444,20110,0,4671.902544498444,0.2872105491657358,3581,0.7421734556077213,4914.371288776398,0.2663228171212332,0.7484143120901925,0.2858430243101962,3554,0.724964320022334 -240.86488962173465,4.876446485519409,4752.0476224422455,20456,0,4752.0476224422455,0.2872538072574874,3581,0.7424498437936331,4998.575645685196,0.2665191377912249,0.7484644481113979,0.2858613657696434,3554,0.7253575279394696 -244.88236665725708,4.90512490272522,4832.189853668213,20805,0,4832.189853668213,0.2869579205463732,3581,0.7422169523177883,5082.776785135269,0.2659215075629098,0.7486716679164341,0.2855840456505873,3554,0.7250687358214336 -248.8937950134277,4.9343202114105225,4912.309539794922,21151,0,4912.309539794922,0.2869437057124406,3581,0.7421067788327282,5166.949881315231,0.2661462851933071,0.748410837990897,0.2856016486430343,3554,0.7249383534617684 -252.9103980064392,4.963292837142944,4992.323907136917,21499,0,4992.323907136917,0.2870643443150482,3581,0.7424884999607303,5251.02219581604,0.266301155090332,0.7486587933131627,0.2857149088724852,3554,0.7253595200830051 -256.9296944141388,4.993821382522583,5072.416844844818,21844,0,5072.416844844818,0.286989554517331,3581,0.74234614709142,5335.177726268768,0.2658378056117466,0.7488077027457101,0.2856382285200214,3554,0.7251955460616559 -260.9511814117432,5.023503065109253,5152.545783519745,22193,0,5152.545783519745,0.2869001408257121,3581,0.7422395869694219,5419.370491504669,0.2658830199922834,0.7486767087663923,0.2855381576546673,3554,0.7250403649497046 -264.9676489830017,5.052587985992432,5232.754349708557,22543,0,5232.754349708557,0.2869052881636589,3581,0.7420529192701061,5503.637329339981,0.2659910746983119,0.7485130855015346,0.2855696197836065,3554,0.7248971366989659 -268.98533368110657,5.083728313446045,5312.916896104813,22886,0,5312.916896104813,0.2871465653688739,3581,0.7422529495950851,5587.861164093018,0.265933837209429,0.748805318559919,0.2858411352085678,3554,0.7251649769625774 -273.00149416923523,5.113495111465454,5392.880267381668,23231,0,5392.880267381668,0.2868351002905962,3581,0.7422787885498116,5671.883190155029,0.2657624823706491,0.7487833840506417,0.2855141317166836,3554,0.7251254775648917 -277.0170331001282,5.143171548843384,5472.882793188095,23577,0,5472.882793188095,0.2869184462593375,3581,0.7426354206663641,5755.94358754158,0.2659006799970354,0.7490690095084054,0.2855618401196275,3554,0.725525623637099 -281.03462076187134,5.173579931259155,5552.945104598999,23922,0,5552.945104598999,0.2869320475033161,3581,0.7424687287288816,5840.066252231598,0.2657322883605957,0.749016353062221,0.285539771977877,3554,0.7253279205648565 -285.0586128234863,5.204019546508789,5632.908239126205,24267,0,5632.908239126205,0.2868933231595574,3581,0.742668554523876,5924.096672773361,0.2656536953789847,0.7493527276175362,0.2854886288446996,3554,0.7255556431793402 -289.0750644207001,5.235795736312866,5712.8686356544495,24614,0,5712.8686356544495,0.2870517316326619,3581,0.7432403521842712,6008.117910861969,0.2658327988215855,0.749849796295166,0.2856676298108205,3554,0.7261342578344823 -293.0931673049927,5.266326189041138,5792.95509147644,24958,0,5792.95509147644,0.2869829754694917,3581,0.7426132632513613,6092.265564203262,0.265755363873073,0.7492024557931083,0.2856205568329699,3554,0.725492238059229 -297.111275434494,5.298498392105103,5872.948161840439,25303,0,5872.948161840439,0.2868363615588348,3581,0.7425200657550265,6176.321474313736,0.265476039477757,0.7492395809718541,0.2854243306947278,3554,0.7254227878139069 -301.1283836364746,5.329797029495239,5953.030558347702,25649,0,5953.030558347702,0.2868049662061924,3581,0.7429180811051382,6260.4646916389465,0.2655596562794277,0.7495963232857841,0.2854543845842712,3554,0.7257776641416361 -305.14564299583435,5.360304832458496,6033.104256391525,25994,0,6033.104256391525,0.2868030572596691,3581,0.7422948782419366,6344.598982095718,0.2655539512634277,0.7489735058375767,0.2854096300493282,3554,0.7251500015387592 -309.16142416000366,5.391079664230347,6113.080691099167,26340,0,6113.080691099167,0.2867089734667341,3581,0.7425187703984572,6428.634559392929,0.2652556555611746,0.749401501246861,0.2853258226316298,3554,0.7253881657331528 -313.17753505706787,5.427402496337891,6193.06077671051,26688,0,6193.06077671051,0.2868135564655473,3581,0.7425824474003421,6512.6798322200775,0.2654205220086233,0.7494032042367118,0.2854462957945712,3554,0.7254571351162422 -317.19087767601013,5.4581146240234375,6273.084081888199,27033,0,6273.084081888199,0.2867420732359153,3581,0.7426219216873778,6596.759754896164,0.2654526914869036,0.7493421009608677,0.2853781850940401,3554,0.7255051526449071 -321.2112212181092,5.4910569190979,6353.071182966232,27380,0,6353.071182966232,0.2868342821706576,3581,0.7430699105304035,6680.812620401382,0.2650948081697736,0.7501442773001534,0.2854123091389103,3554,0.7259720011782499 -325.23206520080566,5.522183656692505,6433.0986251831055,27728,0,6433.0986251831055,0.2867781727782044,3581,0.7427992491840617,6764.905099153519,0.2653770787375314,0.7496339934212821,0.2854231972337507,3554,0.7256522277935074 -329.251145362854,5.553809642791748,6513.174396276474,28073,0,6513.174396276474,0.2867435049458077,3581,0.7428340874581123,6849.044062137604,0.2653424569538661,0.749655042375837,0.2853482342464037,3554,0.7257410499173467 -333.2688903808594,5.586345434188843,6593.317857980728,28419,0,6593.317857980728,0.2867419709709229,3581,0.7429455562997417,6933.250350475311,0.2649978228977748,0.7501136234828404,0.2853657857178971,3554,0.7258814616892938 -337.28736329078674,5.618357419967651,6673.433345794678,28766,0,6673.433345794678,0.2867050533086952,3581,0.7429291257243088,7017.428941488266,0.2651689052581787,0.74993896484375,0.285333516427353,3554,0.7257976542715954 -341.3042325973511,5.653614521026611,6753.462169647217,29111,0,6753.462169647217,0.2867233246539898,3581,0.7428636761292237,7101.522633552551,0.2652667249952044,0.7497790881565639,0.2853496424857994,3554,0.7257436603123242 -345.3211672306061,5.685898780822754,6833.543080806732,29455,0,6833.543080806732,0.2868116816073547,3581,0.7431262926295029,7185.665328025818,0.2649745089667184,0.7503256797790527,0.285384728255135,3554,0.7260548468714828 -349.28695940971375,5.716902494430542,6913.672247648239,29801,0,6913.672247648239,0.2866652722266825,3581,0.742863471599239,7269.804100990295,0.2650454044342041,0.7499323572431292,0.2852960950414586,3554,0.7257351421813449 -353.3090696334839,5.7517173290252686,6993.665600538254,30146,0,6993.665600538254,0.2867398915827457,3581,0.7429881667132086,7353.867259740829,0.2652068478720529,0.7499641690935407,0.285359843634593,3554,0.725912923818233 -357.32208609580994,5.783747434616089,7073.665458440781,30491,0,7073.665458440781,0.2868292030093723,3581,0.7430889318189752,7437.925004243851,0.264944110597883,0.7503576959882464,0.2854310284186832,3554,0.7260252394968697 -361.3407151699066,5.820956230163574,7153.841228246689,30838,0,7153.841228246689,0.2866669084665596,3581,0.7428631307159314,7522.169150829315,0.2649753093719482,0.7499535424368722,0.2853013158314135,3554,0.7257873500808948 -365.355765581131,5.85828971862793,7233.882091283798,31183,0,7233.882091283798,0.2867455502456541,3581,0.7429502604893884,7606.274645090103,0.2650706427437918,0.7500086511884417,0.285347959467985,3554,0.7258787139051069 -369.3752720355988,5.89311146736145,7313.855280160904,31525,0,7313.855280160904,0.2868032617896537,3581,0.7428442457806828,7690.314583301544,0.2649495601654053,0.7500947543552944,0.2854118454503288,3554,0.7257277918586452 -373.39921855926514,5.924740552902222,7393.850246191025,31870,0,7393.850246191025,0.2866636359868054,3581,0.7428506543868681,7774.377434492111,0.2648903812680925,0.7500623975481305,0.2852734773428707,3554,0.7257400194982766 -377.42104148864746,5.95652961730957,7473.869105577469,32217,0,7473.869105577469,0.2866684083531136,3581,0.7432891666739389,7858.462511301041,0.2650008542197091,0.7503595352172852,0.2853202583686515,3554,0.7261762302379361 -381.439279794693,5.988731384277344,7553.872319459915,32561,0,7553.872319459915,0.2869319452383237,3581,0.742864085189193,7942.528846740723,0.2650911467415945,0.7500896453857422,0.2854902431679094,3554,0.7257870753024761 -385.458370923996,6.020787715911865,7633.981284618378,32908,0,7633.981284618378,0.286711257384896,3581,0.7429668955948059,8026.701743841171,0.2648369073867798,0.750241756439209,0.2853054546813449,3554,0.7258719131392445 -389.4760494232178,6.053780794143677,7714.053889751434,33255,0,7714.053889751434,0.2866541935191811,3581,0.7429962797359327,8110.837794303894,0.2649034091404506,0.7501583780561175,0.2852611123140299,3554,0.7258987727296707 -393.4959604740143,6.090385913848877,7794.201321601868,33598,0,7794.201321601868,0.2866700445929908,3581,0.7428910831471656,8195.054210424423,0.2648886442184448,0.7500778606959752,0.2852774788035928,3554,0.7257746415790307 -397.5133173465729,6.123018741607666,7874.296406030655,33945,0,7874.296406030655,0.2866699423279985,3581,0.7429065592493368,8279.21177649498,0.2647657053811209,0.7502149173191616,0.2852540195960977,3554,0.7258221095508582 -401.5292658805847,6.1560492515563965,7954.259927272797,34293,0,7954.259927272797,0.2866239912581158,3581,0.7431016808546844,8363.237367153168,0.2648194347109113,0.7502946172441755,0.2852201703296462,3554,0.7260055241453293 -405.5457291603088,6.190333127975464,8034.41494178772,34639,0,8034.41494178772,0.286623650374808,3581,0.7430262092903519,8447.455779075623,0.2648192303521292,0.7502439362662179,0.2852440760520716,3554,0.7259190376380487 -409.5631048679352,6.223457336425781,8114.595928907394,34986,0,8114.595928907394,0.2866677606748289,3581,0.7431830837885717,8531.699686050415,0.264769230570112,0.7504809924534389,0.2852524911411437,3554,0.7261091843037775 -413.5802364349365,6.2608723640441895,8194.691151618958,35334,0,8194.691151618958,0.2866166963553302,3581,0.7429592598087127,8615.862161159515,0.2647643089294433,0.7502155985151019,0.2852156708330402,3554,0.7258480761114238 -417.6001315116882,6.293979644775391,8274.671577215195,35681,0,8274.671577215195,0.2866270932962161,3581,0.7431732663493088,8699.908791542053,0.264780044555664,0.7504230907985142,0.2852255113351593,3554,0.7260905680659117 -421.6159639358521,6.335568904876709,8354.620125293732,36027,0,8354.620125293732,0.2866263774412699,3581,0.7430758418999581,8783.927860736847,0.2647636617933001,0.7503397805350167,0.2852229009401818,3554,0.7259813436444851 -425.62926030158997,6.368803024291992,8390.96516919136,36189,0,8390.96516919136,0.286627570532847,3581,0.7430886591123289,8824.325541734695,0.2647658245904105,0.7503502709524972,0.2852237252754379,3554,0.7259957008168613 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/measurements.csv deleted file mode 100644 index 12288e868..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/measurements.csv +++ /dev/null @@ -1,470 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.160697,0.8857512,,,,,,,,,,,,,, -1,,,0.2656774180276053,0.8928326879228864,0.2613217295323051,0.8956671152882316,3554.0,0.2836721463431129,0.8949576268413153,3581.0,29.17657399177552,33.137226819992065,29.17657399177552,3.9605116844177246,0.0,0.0 -100,0.5582873,0.3209898,,,,,,,,,,,,,, -200,0.20868555,0.2707943,,,,,,,,,,,,,, -300,0.16465956,0.30556566,,,,,,,,,,,,,, -345,,,0.7004432678222656,0.3054479190281459,0.6780527057980444,0.3231287040130838,3554.0,0.6965224312124756,0.3249313665548206,3581.0,109.272686958313,117.27683687210084,109.272686958313,7.972739458084106,0.0181217193603515,0.0 -400,0.22950771,0.29260942,,,,,,,,,,,,,, -500,0.30240402,0.29898804,,,,,,,,,,,,,, -600,0.123142004,0.27235395,,,,,,,,,,,,,, -684,,,0.719583238874163,0.2908952576773507,0.6983924912950197,0.3081439033813133,3554.0,0.7156626198818417,0.3102621692613795,3581.0,189.41457438468933,201.46821546554563,189.41457438468933,11.985413312911987,0.0421442985534668,0.0 -700,0.97287077,0.24656655,,,,,,,,,,,,,, -800,0.18481122,0.35121107,,,,,,,,,,,,,, -900,0.15338825,0.36575383,,,,,,,,,,,,,, -1000,0.25266376,0.28435624,,,,,,,,,,,,,, -1026,,,0.7229753902980259,0.2880307946886335,0.70181561214037,0.305126321134637,3554.0,0.7192106015254119,0.3069072980116413,3581.0,269.56817269325256,285.6784076690674,269.56817269325256,16.004467010498047,0.0665562152862548,0.0 -1100,0.81229,0.2740228,,,,,,,,,,,,,, -1200,0.15972626,0.27921146,,,,,,,,,,,,,, -1300,0.100006856,0.29366085,,,,,,,,,,,,,, -1372,,,0.7281092916216169,0.2867491585867746,0.7068357451551069,0.3044106607431767,3554.0,0.7239998756457693,0.3062297242608559,3581.0,349.55478286743164,369.7181885242462,349.55478286743164,20.019661903381348,0.0912399291992187,0.0 -1400,0.68574035,0.30387726,,,,,,,,,,,,,, -1500,0.3763455,0.27916938,,,,,,,,,,,,,, -1600,0.34039953,0.25390378,,,,,,,,,,,,,, -1700,0.20141703,0.41627136,,,,,,,,,,,,,, -1719,,,0.7329938752310616,0.2795459202357701,0.7116876450830051,0.2974773486410558,3554.0,0.728800262561959,0.2991478394543249,3581.0,429.7373430728912,453.9560031890869,429.7373430728912,24.035717964172363,0.1169250011444091,0.0 -1800,0.23002268,0.25903493,,,,,,,,,,,,,, -1900,0.1416395,0.31441617,,,,,,,,,,,,,, -2000,0.67650574,0.3163867,,,,,,,,,,,,,, -2067,,,0.7248802185058594,0.2891953672681536,0.7033233213236846,0.3067685000065946,3554.0,0.7207346226176348,0.3087029349234676,3581.0,509.9147651195526,538.1823544502258,509.9147651195526,28.04653525352478,0.141535997390747,0.0 -2100,0.24606818,0.31945863,,,,,,,,,,,,,, -2200,0.08365246,0.28252164,,,,,,,,,,,,,, -2300,0.21744324,0.29448253,,,,,,,,,,,,,, -2400,0.25233567,0.34423274,,,,,,,,,,,,,, -2415,,,0.72991943359375,0.2816903250558035,0.7089212446583075,0.2988491112017269,3554.0,0.7260800819428931,0.3006744853480173,3581.0,590.0707066059113,622.3923435211182,590.0707066059113,32.063100814819336,0.1654932498931884,0.0 -2500,0.18792301,0.30590898,,,,,,,,,,,,,, -2600,0.10613573,0.24949019,,,,,,,,,,,,,, -2700,0.12596865,0.30309156,,,,,,,,,,,,,, -2762,,,0.735365731375558,0.2765935829707554,0.713966107730726,0.294382897131753,3554.0,0.7312001492250768,0.2960129061147375,3581.0,670.0614335536957,706.4358239173889,670.0614335536957,36.07767248153687,0.1900866031646728,0.0 -2800,0.15653884,0.23931164,,,,,,,,,,,,,, -2900,0.20514412,0.25376767,,,,,,,,,,,,,, -3000,0.09707613,0.26781246,,,,,,,,,,,,,, -3100,0.17868346,0.28996608,,,,,,,,,,,,,, -3107,,,0.7322643143790108,0.2801121303013393,0.7105216915579277,0.2975265683253025,3554.0,0.7275827637400517,0.2993197469064158,3581.0,750.1785328388214,790.608763217926,750.1785328388214,40.0933940410614,0.216646671295166,0.0 -3200,0.10311283,0.26662737,,,,,,,,,,,,,, -3300,0.15208869,0.2639253,,,,,,,,,,,,,, -3400,0.10625812,0.28514865,,,,,,,,,,,,,, -3455,,,0.7370850699288505,0.2747059549604143,0.7150101970271173,0.2927326122964441,3554.0,0.7320430854684445,0.2944108909217048,3581.0,830.3153464794159,874.7988729476929,830.3153464794159,44.107200622558594,0.2426018714904785,0.0 -3500,0.090178244,0.32994348,,,,,,,,,,,,,, -3600,0.11425918,0.24680275,,,,,,,,,,,,,, -3700,0.2594767,0.25584182,,,,,,,,,,,,,, -3800,0.10531654,0.2661136,,,,,,,,,,,,,, -3803,,,0.7389027050563267,0.2740594318934849,0.7170717221132878,0.2919730217052969,3554.0,0.7343342985330564,0.2935579327090896,3581.0,910.3621096611024,958.8977530002594,910.3621096611024,48.12011766433716,0.2682778835296631,0.0 -3900,0.143684,0.32357952,,,,,,,,,,,,,, -4000,0.17604707,0.23694788,,,,,,,,,,,,,, -4100,0.11539304,0.24300529,,,,,,,,,,,,,, -4149,,,0.7399940490722656,0.2729658229010446,0.7180402473445414,0.2910110224614871,3554.0,0.7352640236665736,0.2925437026035848,3581.0,990.4565176963806,1043.0497572422028,990.4565176963806,52.13883900642395,0.293527603149414,0.0 -4200,0.20872714,0.29471385,,,,,,,,,,,,,, -4300,0.16184255,0.29503763,,,,,,,,,,,,,, -4400,0.086713545,0.294288,,,,,,,,,,,,,, -4496,,,0.7390618324279785,0.2744659355708531,0.7164873371113534,0.2924563912910629,3554.0,0.7336459187814158,0.2942437217475914,3581.0,1070.5930817127228,1127.2398805618286,1070.5930817127228,56.15374946594238,0.3187620639801025,0.0 -4500,0.19826542,0.25580105,,,,,,,,,,,,,, -4600,0.13466683,0.24743575,,,,,,,,,,,,,, -4700,0.11392903,0.2791769,,,,,,,,,,,,,, -4800,0.18652774,0.25336882,,,,,,,,,,,,,, -4843,,,0.7348019736153739,0.2777554307665144,0.7128869154913478,0.2955745424389772,3554.0,0.7299596748682281,0.2972117927080424,3581.0,1150.773977279663,1211.4776709079742,1150.773977279663,60.17170739173889,0.3441455364227295,0.0 -4900,0.09393036,0.28442144,,,,,,,,,,,,,, -5000,0.0986135,0.2578901,,,,,,,,,,,,,, -5100,0.14251481,0.23330481,,,,,,,,,,,,,, -5192,,,0.7410356657845634,0.2717779534203665,0.7192005679120357,0.2897226894981886,3554.0,0.7364080962239947,0.291365269008744,3581.0,1230.890167951584,1295.6464822292328,1230.890167951584,64.18590641021729,0.3689992427825928,0.0 -5200,0.08044354,0.2574613,,,,,,,,,,,,,, -5300,0.15184332,0.25941518,,,,,,,,,,,,,, -5400,0.066272624,0.32233578,,,,,,,,,,,,,, -5500,0.13524516,0.2762893,,,,,,,,,,,,,, -5538,,,0.7400689806256976,0.2726975509098598,0.7185819729969752,0.2905688352912212,3554.0,0.7357066265533371,0.2921690718483838,3581.0,1310.9333474636078,1379.7466549873352,1310.9333474636078,68.20439505577087,0.3939557075500488,0.0 -5600,0.12724903,0.21774577,,,,,,,,,,,,,, -5700,0.08507594,0.2553143,,,,,,,,,,,,,, -5800,0.19948627,0.36212522,,,,,,,,,,,,,, -5887,,,0.7430743489946637,0.2713005372456142,0.7210018091411086,0.2893864294083251,3554.0,0.7382192092379922,0.2909360287476438,3581.0,1390.929930686951,1463.8018777370453,1390.929930686951,72.22309613227844,0.4202187061309814,0.0 -5900,0.067215405,0.25276196,,,,,,,,,,,,,, -6000,0.33733556,0.23273015,,,,,,,,,,,,,, -6100,0.025498752,0.3449694,,,,,,,,,,,,,, -6200,0.086092204,0.23651075,,,,,,,,,,,,,, -6236,,,0.7420650890895298,0.2709931646074567,0.7202885530608117,0.2892046634843662,3554.0,0.7374493583758028,0.2907333395328644,3581.0,1471.0121450424194,1547.9339079856873,1471.0121450424194,76.23468589782715,0.4449434280395508,0.0 -6300,0.14293903,0.26241475,,,,,,,,,,,,,, -6400,0.15729712,0.21481729,,,,,,,,,,,,,, -6500,0.13377877,0.30323067,,,,,,,,,,,,,, -6579,,,0.7435024806431362,0.2706704991204398,0.720853978351857,0.2891689422899374,3554.0,0.7380951277139766,0.2907080800797612,3581.0,1551.1771442890167,1632.15380692482,1551.1771442890167,80.25021290779114,0.4702737331390381,0.0 -6600,0.15116872,0.33911943,,,,,,,,,,,,,, -6700,0.10247202,0.23066571,,,,,,,,,,,,,, -6800,0.07290744,0.36576432,,,,,,,,,,,,,, -6900,0.08921278,0.32866,,,,,,,,,,,,,, -6928,,,0.7426629747663226,0.2704187972205026,0.720276600199599,0.2886667160351892,3554.0,0.7376407302647654,0.2901382936308817,3581.0,1631.2922222614288,1716.3261096477509,1631.2922222614288,84.26713466644287,0.4968934059143066,0.0 -7000,0.13990448,0.25348064,,,,,,,,,,,,,, -7100,0.16438682,0.23195544,,,,,,,,,,,,,, -7200,0.083012275,0.26694322,,,,,,,,,,,,,, -7279,,,0.7441093581063407,0.2699455533708845,0.7217107374613112,0.2883482821952378,3554.0,0.7389450179768221,0.2898458498411757,3581.0,1711.3607697486875,1800.455798149109,1711.3607697486875,88.28663372993469,0.5245425701141357,0.0 -7300,0.09185417,0.28914064,,,,,,,,,,,,,, -7400,0.18695882,0.22172153,,,,,,,,,,,,,, -7500,0.09180783,0.16200839,,,,,,,,,,,,,, -7600,0.1188083,0.2958864,,,,,,,,,,,,,, -7623,,,0.7333408083234515,0.2796070405415126,0.7108606307373734,0.2982994513224535,3554.0,0.7279741659539933,0.2998467865885053,3581.0,1791.4479806423187,1884.60189819336,1791.4479806423187,92.30220317840576,0.5551633834838867,0.0 -7700,0.114301704,0.24660094,,,,,,,,,,,,,, -7800,0.094839334,0.30361134,,,,,,,,,,,,,, -7900,0.06760329,0.29473323,,,,,,,,,,,,,, -7970,,,0.7422810282026019,0.2711213656834194,0.7199205560635903,0.2898536557619935,3554.0,0.737044593536198,0.2914284687739982,3581.0,1871.6037764549253,1968.8186626434328,1871.6037764549253,96.32317352294922,0.5827598571777344,0.0 -8000,0.08907792,0.29544404,,,,,,,,,,,,,, -8100,0.12482849,0.25778994,,,,,,,,,,,,,, -8200,0.040871292,0.29562825,,,,,,,,,,,,,, -8300,0.19148475,0.21405424,,,,,,,,,,,,,, -8317,,,0.7444945062909808,0.2702023983001709,0.7221898136342854,0.2886703911965391,3554.0,0.7393386700205948,0.2901830516091874,3581.0,1951.662151813507,2052.933934688568,1951.662151813507,100.34027457237244,0.6097853183746338,0.0 -8400,0.071025394,0.2683596,,,,,,,,,,,,,, -8500,0.14081797,0.26414213,,,,,,,,,,,,,, -8600,0.34933224,0.35275152,,,,,,,,,,,,,, -8661,,,0.732374940599714,0.2788037402289254,0.7117548971009777,0.2963294274497046,3554.0,0.7288023078618053,0.2982390104675719,3581.0,2031.719271659851,2137.048096179962,2031.719271659851,104.35645294189452,0.6381211280822754,0.0 -8700,0.07882274,0.22416249,,,,,,,,,,,,,, -8800,0.067977294,0.34867674,,,,,,,,,,,,,, -8900,0.114633955,0.28554842,,,,,,,,,,,,,, -9000,0.48505044,0.24208878,,,,,,,,,,,,,, -9010,,,0.7443432807922363,0.2709218774523054,0.7221872032393079,0.289655952689751,3554.0,0.7392850149879573,0.2911604663174567,3581.0,2111.8080835342407,2221.1938774585724,2111.8080835342407,108.3740575313568,0.664588451385498,0.0 -9100,0.18261562,0.2861331,,,,,,,,,,,,,, -9200,0.10485872,0.23910266,,,,,,,,,,,,,, -9300,0.028828153,0.41668868,,,,,,,,,,,,,, -9358,,,0.7424185616629464,0.2705019201551165,0.7198816749173467,0.289040895546831,3554.0,0.7372185121998045,0.2904191814764556,3581.0,2191.930754899978,2305.374571084976,2191.930754899978,112.39150047302246,0.6922500133514404,0.0 -9400,0.16030246,0.30424926,,,,,,,,,,,,,, -9500,0.058614887,0.2781617,,,,,,,,,,,,,, -9600,0.056689672,0.33321378,,,,,,,,,,,,,, -9700,0.06845296,0.40110785,,,,,,,,,,,,,, -9704,,,0.7448130335126605,0.2691502230507986,0.7223689004686621,0.2876758307101154,3554.0,0.7396080360103672,0.2890612387317613,3581.0,2271.9885540008545,2389.488274335861,2271.9885540008545,116.40704345703124,0.7198827266693115,0.0 -9800,0.11833548,0.24694306,,,,,,,,,,,,,, -9900,0.038163647,0.27915865,,,,,,,,,,,,,, -10000,0.04421735,0.29653323,,,,,,,,,,,,,, -10051,,,0.7456962721688407,0.2687918969563075,0.7229699095649268,0.28772884577127,3554.0,0.7401323827143256,0.2891233476704307,3581.0,2351.9646701812744,2473.5214030742645,2351.9646701812744,120.42432022094728,0.7467184066772461,0.0 -10100,0.20059857,0.24283648,,,,,,,,,,,,,, -10200,0.13706362,0.36833864,,,,,,,,,,,,,, -10300,0.06282287,0.2964717,,,,,,,,,,,,,, -10400,0.049326073,0.28236356,,,,,,,,,,,,,, -10401,,,0.7442936897277832,0.2693566083908081,0.722090962098164,0.2877130803594981,3554.0,0.739315762662315,0.2891068830066671,3581.0,2432.060645341873,2557.617630958557,2432.060645341873,124.384624004364,0.773712158203125,0.0 -10500,0.12012164,0.24631625,,,,,,,,,,,,,, -10600,0.102666296,0.24720371,,,,,,,,,,,,,, -10700,0.04892089,0.25151512,,,,,,,,,,,,,, -10748,,,0.7448418481009347,0.2689313207353864,0.7225382326691756,0.2874220900141126,3554.0,0.7397670921617565,0.2888093600556758,3581.0,2512.202274799347,2641.820770740509,2512.202274799347,128.40318179130554,0.8037357330322266,0.0 -10800,0.07049546,0.23354837,,,,,,,,,,,,,, -10900,0.08049668,0.24695385,,,,,,,,,,,,,, -11000,0.05687089,0.26840764,,,,,,,,,,,,,, -11093,,,0.7452192306518555,0.2687351363045828,0.7224126589318374,0.2877938480409397,3554.0,0.7395489950214674,0.289260280495148,3581.0,2592.4162380695343,2726.093177318573,2592.4162380695343,132.41613030433655,0.836331844329834,0.0 -11100,0.09675815,0.24520347,,,,,,,,,,,,,, -11200,0.08024623,0.38641232,,,,,,,,,,,,,, -11300,0.08972286,0.28140026,,,,,,,,,,,,,, -11400,0.08012347,0.26918215,,,,,,,,,,,,,, -11443,,,0.7452154840741839,0.2697348083768572,0.7223850437007597,0.2883542586258441,3554.0,0.7395929007915037,0.2898771429288257,3581.0,2672.483766078949,2810.2135293483734,2672.483766078949,136.42907905578613,0.8633337020874023,0.0 -11500,0.05215066,0.3682465,,,,,,,,,,,,,, -11600,0.15910597,0.25775182,,,,,,,,,,,,,, -11700,0.13672136,0.24162397,,,,,,,,,,,,,, -11789,,,0.7447735241481236,0.2692332097462245,0.7227943261553883,0.2875390082312623,3554.0,0.7399063089046356,0.288991903066968,3581.0,2752.4405829906464,2894.230272054672,2752.4405829906464,140.4444704055786,0.8953886032104492,0.0 -11800,0.09120775,0.2730173,,,,,,,,,,,,,, -11900,0.08077502,0.3151195,,,,,,,,,,,,,, -12000,0.066008545,0.33950293,,,,,,,,,,,,,, -12100,0.03921023,0.32745305,,,,,,,,,,,,,, -12137,,,0.7456636428833008,0.2685306753431047,0.7233629800928532,0.2873849434066369,3554.0,0.7404592216297822,0.288678972190467,3581.0,2832.651171684265,2978.4996314048767,2832.651171684265,144.46209335327148,0.9224934577941896,0.0 -12200,0.069043234,0.26862225,,,,,,,,,,,,,, -12300,0.1234051,0.23005246,,,,,,,,,,,,,, -12400,0.11860108,0.26653603,,,,,,,,,,,,,, -12485,,,0.7451457296098981,0.2689953531537737,0.722774129941615,0.287498426893553,3554.0,0.7399579186374267,0.2889450316121544,3581.0,2912.848475217819,3062.7552905082703,2912.848475217819,148.47949981689453,0.949455499649048,0.0 -12500,0.07222481,0.34863865,,,,,,,,,,,,,, -12600,0.05696205,0.2782365,,,,,,,,,,,,,, -12700,0.10961069,0.30662158,,,,,,,,,,,,,, -12800,0.13050821,0.26346326,,,,,,,,,,,,,, -12834,,,0.7460318974086216,0.2684512138366699,0.7233857180069991,0.2871742742277627,3554.0,0.7405707586480732,0.2885619810392523,3581.0,2992.999326467514,3146.96438741684,2992.999326467514,152.49664402008057,0.9763846397399902,0.0 -12900,0.12291937,0.30536693,,,,,,,,,,,,,, -13000,0.08919284,0.31177768,,,,,,,,,,,,,, -13100,0.06285211,0.25597173,,,,,,,,,,,,,, -13179,,,0.745182854788644,0.2685987268175397,0.7225593219128095,0.2874176076911579,3554.0,0.7396684405324979,0.2888663216563983,3581.0,3072.988017320633,3231.008982419968,3072.988017320633,156.51189756393433,1.0034832954406738,0.0 -13200,0.0960085,0.24050918,,,,,,,,,,,,,, -13300,0.066980265,0.2419861,,,,,,,,,,,,,, -13400,0.06300107,0.29468387,,,,,,,,,,,,,, -13500,0.038296167,0.24924403,,,,,,,,,,,,,, -13526,,,0.7462560108729771,0.2681436708995274,0.7231661700504713,0.2871634720011782,3554.0,0.7403992943442823,0.2885537316632051,3581.0,3153.030436277389,3315.10955786705,3153.030436277389,160.52774262428284,1.0317604541778564,0.0 -13600,0.08773251,0.20642652,,,,,,,,,,,,,, -13700,0.110467575,0.2820495,,,,,,,,,,,,,, -13800,0.10585982,0.32918447,,,,,,,,,,,,,, -13876,,,0.7469626835414341,0.2683169841766357,0.7242131445202589,0.287253771059018,3554.0,0.7413314056609537,0.2886610758168109,3581.0,3233.1198103427887,3399.258532524109,3233.1198103427887,164.54671454429626,1.0586469173431396,0.0 -13900,0.12759863,0.2534175,,,,,,,,,,,,,, -14000,0.07639967,0.24659403,,,,,,,,,,,,,, -14100,0.06624555,0.31972843,,,,,,,,,,,,,, -14200,0.081997305,0.24684554,,,,,,,,,,,,,, -14220,,,0.7435752323695591,0.2695116315569196,0.721224379660242,0.2880564846882474,3554.0,0.7383425408187309,0.2895751203181723,3581.0,3313.156766176224,3483.354829549789,3313.156766176224,168.5634524822235,1.0875027179718018,0.0 -14300,0.04677925,0.27563375,,,,,,,,,,,,,, -14400,0.03386487,0.2895433,,,,,,,,,,,,,, -14500,0.08743919,0.33282647,,,,,,,,,,,,,, -14567,,,0.7471098899841309,0.2674135140010288,0.7242367754642656,0.2864020266557048,3554.0,0.7413343372574002,0.287807469925911,3581.0,3393.3027589321136,3567.560054063797,3393.3027589321136,172.5811402797699,1.114877223968506,0.0 -14600,0.10360926,0.221224,,,,,,,,,,,,,, -14700,0.05309376,0.27107254,,,,,,,,,,,,,, -14800,0.05490071,0.28604284,,,,,,,,,,,,,, -14900,0.068734616,0.27929068,,,,,,,,,,,,,, -14914,,,0.7454605102539062,0.2676483733313424,0.7230939720209623,0.2864320290242948,3554.0,0.7402175353645979,0.2878112878189577,3581.0,3473.428725004196,3651.744841337204,3473.428725004196,176.59819197654724,1.142707586288452,0.0 -15000,0.15449676,0.23818754,,,,,,,,,,,,,, -15100,0.059398267,0.19426674,,,,,,,,,,,,,, -15200,0.07253029,0.19482625,,,,,,,,,,,,,, -15260,,,0.7476382255554199,0.2672039951596941,0.7246152827360017,0.2863704958321609,3554.0,0.7417507603061295,0.2877523490950502,3581.0,3553.576901912689,3735.9517509937286,3553.576901912689,180.614520072937,1.1707472801208496,0.0 -15300,0.18667535,0.22841902,,,,,,,,,,,,,, -15400,0.05958487,0.29766193,,,,,,,,,,,,,, -15500,0.057056982,0.26276067,,,,,,,,,,,,,, -15600,0.05676677,0.35967746,,,,,,,,,,,,,, -15610,,,0.7465495382036481,0.2675726413726806,0.7237017818830894,0.2865305542610439,3554.0,0.7408511010803547,0.287978763788048,3581.0,3633.779043912888,3820.211359500885,3633.779043912888,184.6301848888397,1.1982817649841309,0.0 -15700,0.24777664,0.25677297,,,,,,,,,,,,,, -15800,0.02909639,0.26686865,,,,,,,,,,,,,, -15900,0.09381111,0.19022462,,,,,,,,,,,,,, -15958,,,0.7465058735438755,0.2675352437155587,0.7236025868739449,0.2864478459570203,3554.0,0.7408671907724798,0.2877550420731813,3581.0,3713.819554805756,3904.3096396923065,3713.819554805756,188.6462926864624,1.226337432861328,0.0 -16000,0.13156725,0.26222238,,,,,,,,,,,,,, -16100,0.03562979,0.24279022,,,,,,,,,,,,,, -16200,0.06344441,0.21389458,,,,,,,,,,,,,, -16300,0.094599545,0.24193087,,,,,,,,,,,,,, -16305,,,0.7472934722900391,0.2676750591823033,0.7245056461469471,0.286674658367992,3554.0,0.7417191263351718,0.2879381645860968,3581.0,3793.9894206523895,3988.544013738632,3793.9894206523895,192.6677522659301,1.255561113357544,0.0 -16400,0.04331315,0.28683564,,,,,,,,,,,,,, -16500,0.07372119,0.33365488,,,,,,,,,,,,,, -16600,0.1256671,0.32669318,,,,,,,,,,,,,, -16653,,,0.7460013798304966,0.267695529120309,0.72336675829611,0.2865367024281619,3554.0,0.7405575323757331,0.287911780218078,3581.0,3874.103070020676,4072.714470386505,3874.103070020676,196.6825180053711,1.2838151454925537,0.0 -16700,0.054361187,0.23350088,,,,,,,,,,,,,, -16800,0.07544293,0.22663477,,,,,,,,,,,,,, -16900,0.032261886,0.43366534,,,,,,,,,,,,,, -16999,,,0.7481280054364886,0.2673676184245518,0.7252748883300506,0.2864205398516636,3554.0,0.7424530480967257,0.2877251806954237,3581.0,3954.122975349426,4156.801654577255,3954.122975349426,200.70232272148127,1.3174071311950684,0.0 -17000,0.06641296,0.22498041,,,,,,,,,,,,,, -17100,0.043611012,0.26987764,,,,,,,,,,,,,, -17200,0.10514994,0.27305168,,,,,,,,,,,,,, -17300,0.05099064,0.3080344,,,,,,,,,,,,,, -17344,,,0.746361528124128,0.2677748032978603,0.7239071787510551,0.28663830174847,3554.0,0.7410125434148981,0.2879645830424462,3581.0,4034.12496638298,4240.8676471710205,4034.12496638298,204.7229642868042,1.3461709022521973,0.0 -17400,0.1043758,0.29690564,,,,,,,,,,,,,, -17500,0.06761067,0.28368092,,,,,,,,,,,,,, -17600,0.0667724,0.27523375,,,,,,,,,,,,,, -17693,,,0.748692240033831,0.2673057488032749,0.7252666449774902,0.2868741990209095,3554.0,0.7423913482180257,0.2882696395145385,3581.0,4114.188055753708,4324.98748755455,4114.188055753708,208.7395868301392,1.3738317489624023,0.0 -17700,0.0367351,0.21688397,,,,,,,,,,,,,, -17800,0.112961404,0.30553782,,,,,,,,,,,,,, -17900,0.044012077,0.33988523,,,,,,,,,,,,,, -18000,0.046705365,0.27563012,,,,,,,,,,,,,, -18042,,,0.7466691562107631,0.2669903721128191,0.7238836165016531,0.2859708306221862,3554.0,0.7410724025237364,0.2873323808599204,3581.0,4194.349546909332,4409.206943273544,4194.349546909332,212.75606155395508,1.402545690536499,0.0 -18100,0.07098178,0.22909543,,,,,,,,,,,,,, -18200,0.035849005,0.23658854,,,,,,,,,,,,,, -18300,0.07798345,0.31009325,,,,,,,,,,,,,, -18389,,,0.7479186739240374,0.2667764595576695,0.7248502869785804,0.2860190714083163,3554.0,0.7420026730705459,0.2874088750741762,3581.0,4274.440649032593,4493.35611987114,4274.440649032593,216.77316308021545,1.4304280281066897,0.0 -18400,0.051236443,0.33299237,,,,,,,,,,,,,, -18500,0.0747218,0.2547711,,,,,,,,,,,,,, -18600,0.038245905,0.29343566,,,,,,,,,,,,,, -18700,0.09790104,0.27870125,,,,,,,,,,,,,, -18720,,,0.7479407446725028,0.2662788459232875,0.7245068826498312,0.285733319026537,3554.0,0.7416746069751815,0.2871059661669226,3581.0,4351.382478237152,4577.62344622612,4351.382478237152,220.79004406929016,4.727059602737427,0.0 -18800,0.101526774,0.28513813,,,,,,,,,,,,,, -18900,0.07170262,0.24325922,,,,,,,,,,,,,, -19000,0.05466786,0.27289835,,,,,,,,,,,,,, -19068,,,0.7482282093593052,0.2664348908833095,0.7249527106341446,0.2857875362432734,3554.0,0.7421484347729336,0.2871270327553407,3581.0,4431.4831800460815,4661.779294967651,4431.4831800460815,224.8030700683593,4.756349325180054,0.0 -19100,0.06460481,0.3035511,,,,,,,,,,,,,, -19200,0.11344452,0.22536674,,,,,,,,,,,,,, -19300,0.084861696,0.23728782,,,,,,,,,,,,,, -19400,0.08872194,0.22213447,,,,,,,,,,,,,, -19413,,,0.747701713017055,0.2666091578347342,0.7246733296769485,0.2858498937706633,3554.0,0.7418394581428023,0.2872160373869903,3581.0,4511.659183979034,4746.009649276733,4511.659183979034,228.8146023750305,4.786910057067871,0.0 -19500,0.11727526,0.32872063,,,,,,,,,,,,,, -19600,0.035288457,0.29309353,,,,,,,,,,,,,, -19700,0.10514927,0.3334424,,,,,,,,,,,,,, -19760,,,0.7489380155290876,0.2660056182316371,0.7252018659652856,0.2857684906641284,3554.0,0.7423292392793563,0.2871845397693556,3581.0,4591.833205223084,4830.240214586258,4591.833205223084,232.8302583694458,4.815237998962402,0.0 -19800,0.0657117,0.24873737,,,,,,,,,,,,,, -19900,0.05725422,0.23147765,,,,,,,,,,,,,, -20000,0.052328587,0.23396556,,,,,,,,,,,,,, -20100,0.036353987,0.3523728,,,,,,,,,,,,,, -20110,,,0.7484143120901925,0.2663228171212332,0.724964320022334,0.2858430243101962,3554.0,0.7421734556077213,0.2872105491657358,3581.0,4671.902544498444,4914.371288776398,4671.902544498444,236.8491904735565,4.845395565032959,0.0 -20200,0.107227035,0.28069612,,,,,,,,,,,,,, -20300,0.06466942,0.26345992,,,,,,,,,,,,,, -20400,0.11254204,0.28128454,,,,,,,,,,,,,, -20456,,,0.7484644481113979,0.2665191377912249,0.7253575279394696,0.2858613657696434,3554.0,0.7424498437936331,0.2872538072574874,3581.0,4752.0476224422455,4998.575645685196,4752.0476224422455,240.86488962173465,4.876446485519409,0.0 -20500,0.11890465,0.23478754,,,,,,,,,,,,,, -20600,0.101752885,0.28973198,,,,,,,,,,,,,, -20700,0.057787847,0.26322916,,,,,,,,,,,,,, -20800,0.059916377,0.25884816,,,,,,,,,,,,,, -20805,,,0.7486716679164341,0.2659215075629098,0.7250687358214336,0.2855840456505873,3554.0,0.7422169523177883,0.2869579205463732,3581.0,4832.189853668213,5082.776785135269,4832.189853668213,244.88236665725708,4.90512490272522,0.0 -20900,0.044329986,0.33448157,,,,,,,,,,,,,, -21000,0.07248469,0.3082742,,,,,,,,,,,,,, -21100,0.0488145,0.23327443,,,,,,,,,,,,,, -21151,,,0.748410837990897,0.2661462851933071,0.7249383534617684,0.2856016486430343,3554.0,0.7421067788327282,0.2869437057124406,3581.0,4912.309539794922,5166.949881315231,4912.309539794922,248.8937950134277,4.9343202114105225,0.0 -21200,0.07847499,0.2860881,,,,,,,,,,,,,, -21300,0.05561315,0.31147972,,,,,,,,,,,,,, -21400,0.0828616,0.262404,,,,,,,,,,,,,, -21499,,,0.7486587933131627,0.266301155090332,0.7253595200830051,0.2857149088724852,3554.0,0.7424884999607303,0.2870643443150482,3581.0,4992.323907136917,5251.02219581604,4992.323907136917,252.9103980064392,4.963292837142944,0.0 -21500,0.056648545,0.21670824,,,,,,,,,,,,,, -21600,0.06410942,0.24879758,,,,,,,,,,,,,, -21700,0.03736466,0.32870778,,,,,,,,,,,,,, -21800,0.12735705,0.2670353,,,,,,,,,,,,,, -21844,,,0.7488077027457101,0.2658378056117466,0.7251955460616559,0.2856382285200214,3554.0,0.74234614709142,0.286989554517331,3581.0,5072.416844844818,5335.177726268768,5072.416844844818,256.9296944141388,4.993821382522583,0.0 -21900,0.052334197,0.22983813,,,,,,,,,,,,,, -22000,0.0645689,0.23951626,,,,,,,,,,,,,, -22100,0.05769426,0.30429214,,,,,,,,,,,,,, -22193,,,0.7486767087663923,0.2658830199922834,0.7250403649497046,0.2855381576546673,3554.0,0.7422395869694219,0.2869001408257121,3581.0,5152.545783519745,5419.370491504669,5152.545783519745,260.9511814117432,5.023503065109253,0.0 -22200,0.074345924,0.2585563,,,,,,,,,,,,,, -22300,0.06762768,0.32847926,,,,,,,,,,,,,, -22400,0.040679667,0.22608107,,,,,,,,,,,,,, -22500,0.09638819,0.22498988,,,,,,,,,,,,,, -22543,,,0.7485130855015346,0.2659910746983119,0.7248971366989659,0.2855696197836065,3554.0,0.7420529192701061,0.2869052881636589,3581.0,5232.754349708557,5503.637329339981,5232.754349708557,264.9676489830017,5.052587985992432,0.0 -22600,0.079294875,0.21658443,,,,,,,,,,,,,, -22700,0.087147,0.2973917,,,,,,,,,,,,,, -22800,0.05737671,0.26417145,,,,,,,,,,,,,, -22886,,,0.748805318559919,0.265933837209429,0.7251649769625774,0.2858411352085678,3554.0,0.7422529495950851,0.2871465653688739,3581.0,5312.916896104813,5587.861164093018,5312.916896104813,268.98533368110657,5.083728313446045,0.0 -22900,0.068820655,0.20583326,,,,,,,,,,,,,, -23000,0.05668063,0.25719276,,,,,,,,,,,,,, -23100,0.04001922,0.31600627,,,,,,,,,,,,,, -23200,0.034308717,0.27276352,,,,,,,,,,,,,, -23231,,,0.7487833840506417,0.2657624823706491,0.7251254775648917,0.2855141317166836,3554.0,0.7422787885498116,0.2868351002905962,3581.0,5392.880267381668,5671.883190155029,5392.880267381668,273.00149416923523,5.113495111465454,0.0 -23300,0.05133564,0.2627224,,,,,,,,,,,,,, -23400,0.03099445,0.30057812,,,,,,,,,,,,,, -23500,0.057876863,0.27158305,,,,,,,,,,,,,, -23577,,,0.7490690095084054,0.2659006799970354,0.725525623637099,0.2855618401196275,3554.0,0.7426354206663641,0.2869184462593375,3581.0,5472.882793188095,5755.94358754158,5472.882793188095,277.0170331001282,5.143171548843384,0.0 -23600,0.13359484,0.25705323,,,,,,,,,,,,,, -23700,0.04158831,0.28512567,,,,,,,,,,,,,, -23800,0.05316679,0.29139084,,,,,,,,,,,,,, -23900,0.054748263,0.22818902,,,,,,,,,,,,,, -23922,,,0.749016353062221,0.2657322883605957,0.7253279205648565,0.285539771977877,3554.0,0.7424687287288816,0.2869320475033161,3581.0,5552.945104598999,5840.066252231598,5552.945104598999,281.03462076187134,5.173579931259155,0.0 -24000,0.05222875,0.18874975,,,,,,,,,,,,,, -24100,0.04044312,0.3495168,,,,,,,,,,,,,, -24200,0.113663174,0.2744722,,,,,,,,,,,,,, -24267,,,0.7493527276175362,0.2656536953789847,0.7255556431793402,0.2854886288446996,3554.0,0.742668554523876,0.2868933231595574,3581.0,5632.908239126205,5924.096672773361,5632.908239126205,285.0586128234863,5.204019546508789,0.0 -24300,0.07509545,0.27048597,,,,,,,,,,,,,, -24400,0.079359144,0.33004883,,,,,,,,,,,,,, -24500,0.031024039,0.30664682,,,,,,,,,,,,,, -24600,0.09117029,0.1737819,,,,,,,,,,,,,, -24614,,,0.749849796295166,0.2658327988215855,0.7261342578344823,0.2856676298108205,3554.0,0.7432403521842712,0.2870517316326619,3581.0,5712.8686356544495,6008.117910861969,5712.8686356544495,289.0750644207001,5.235795736312866,0.0 -24700,0.049916558,0.23188072,,,,,,,,,,,,,, -24800,0.034853097,0.31973502,,,,,,,,,,,,,, -24900,0.026076727,0.3157634,,,,,,,,,,,,,, -24958,,,0.7492024557931083,0.265755363873073,0.725492238059229,0.2856205568329699,3554.0,0.7426132632513613,0.2869829754694917,3581.0,5792.95509147644,6092.265564203262,5792.95509147644,293.0931673049927,5.266326189041138,0.0 -25000,0.07521729,0.2565869,,,,,,,,,,,,,, -25100,0.033979442,0.2363615,,,,,,,,,,,,,, -25200,0.03252711,0.2339932,,,,,,,,,,,,,, -25300,0.042388126,0.29327586,,,,,,,,,,,,,, -25303,,,0.7492395809718541,0.265476039477757,0.7254227878139069,0.2854243306947278,3554.0,0.7425200657550265,0.2868363615588348,3581.0,5872.948161840439,6176.321474313736,5872.948161840439,297.111275434494,5.298498392105103,0.0 -25400,0.05052151,0.23113124,,,,,,,,,,,,,, -25500,0.03541327,0.25521865,,,,,,,,,,,,,, -25600,0.040733766,0.22247484,,,,,,,,,,,,,, -25649,,,0.7495963232857841,0.2655596562794277,0.7257776641416361,0.2854543845842712,3554.0,0.7429180811051382,0.2868049662061924,3581.0,5953.030558347702,6260.4646916389465,5953.030558347702,301.1283836364746,5.329797029495239,0.0 -25700,0.04248168,0.26995486,,,,,,,,,,,,,, -25800,0.03224676,0.21695289,,,,,,,,,,,,,, -25900,0.0698158,0.2819911,,,,,,,,,,,,,, -25994,,,0.7489735058375767,0.2655539512634277,0.7251500015387592,0.2854096300493282,3554.0,0.7422948782419366,0.2868030572596691,3581.0,6033.104256391525,6344.598982095718,6033.104256391525,305.14564299583435,5.360304832458496,0.0 -26000,0.044156946,0.36045823,,,,,,,,,,,,,, -26100,0.050928555,0.29220504,,,,,,,,,,,,,, -26200,0.046241544,0.28929952,,,,,,,,,,,,,, -26300,0.04111493,0.24412157,,,,,,,,,,,,,, -26340,,,0.749401501246861,0.2652556555611746,0.7253881657331528,0.2853258226316298,3554.0,0.7425187703984572,0.2867089734667341,3581.0,6113.080691099167,6428.634559392929,6113.080691099167,309.16142416000366,5.391079664230347,0.0 -26400,0.0473458,0.3290697,,,,,,,,,,,,,, -26500,0.059789017,0.28437534,,,,,,,,,,,,,, -26600,0.028577883,0.3144424,,,,,,,,,,,,,, -26688,,,0.7494032042367118,0.2654205220086233,0.7254571351162422,0.2854462957945712,3554.0,0.7425824474003421,0.2868135564655473,3581.0,6193.06077671051,6512.6798322200775,6193.06077671051,313.17753505706787,5.427402496337891,0.0 -26700,0.05146777,0.24761447,,,,,,,,,,,,,, -26800,0.026291888,0.22227353,,,,,,,,,,,,,, -26900,0.04532299,0.24955317,,,,,,,,,,,,,, -27000,0.03723785,0.25419423,,,,,,,,,,,,,, -27033,,,0.7493421009608677,0.2654526914869036,0.7255051526449071,0.2853781850940401,3554.0,0.7426219216873778,0.2867420732359153,3581.0,6273.084081888199,6596.759754896164,6273.084081888199,317.19087767601013,5.4581146240234375,0.0 -27100,0.08392342,0.2890185,,,,,,,,,,,,,, -27200,0.040243473,0.26603505,,,,,,,,,,,,,, -27300,0.06396452,0.24950884,,,,,,,,,,,,,, -27380,,,0.7501442773001534,0.2650948081697736,0.7259720011782499,0.2854123091389103,3554.0,0.7430699105304035,0.2868342821706576,3581.0,6353.071182966232,6680.812620401382,6353.071182966232,321.2112212181092,5.4910569190979,0.0 -27400,0.036918316,0.24790756,,,,,,,,,,,,,, -27500,0.047760032,0.36462876,,,,,,,,,,,,,, -27600,0.060418725,0.28354517,,,,,,,,,,,,,, -27700,0.04228284,0.30571747,,,,,,,,,,,,,, -27728,,,0.7496339934212821,0.2653770787375314,0.7256522277935074,0.2854231972337507,3554.0,0.7427992491840617,0.2867781727782044,3581.0,6433.0986251831055,6764.905099153519,6433.0986251831055,325.23206520080566,5.522183656692505,0.0 -27800,0.030344648,0.2670954,,,,,,,,,,,,,, -27900,0.036728423,0.25836593,,,,,,,,,,,,,, -28000,0.05900149,0.17673334,,,,,,,,,,,,,, -28073,,,0.749655042375837,0.2653424569538661,0.7257410499173467,0.2853482342464037,3554.0,0.7428340874581123,0.2867435049458077,3581.0,6513.174396276474,6849.044062137604,6513.174396276474,329.251145362854,5.553809642791748,0.0 -28100,0.0802693,0.23882931,,,,,,,,,,,,,, -28200,0.026268225,0.30992934,,,,,,,,,,,,,, -28300,0.026989866,0.30916458,,,,,,,,,,,,,, -28400,0.032906044,0.3219085,,,,,,,,,,,,,, -28419,,,0.7501136234828404,0.2649978228977748,0.7258814616892938,0.2853657857178971,3554.0,0.7429455562997417,0.2867419709709229,3581.0,6593.317857980728,6933.250350475311,6593.317857980728,333.2688903808594,5.586345434188843,0.0 -28500,0.03761366,0.2873227,,,,,,,,,,,,,, -28600,0.025715461,0.31715548,,,,,,,,,,,,,, -28700,0.035015702,0.29183802,,,,,,,,,,,,,, -28766,,,0.74993896484375,0.2651689052581787,0.7257976542715954,0.285333516427353,3554.0,0.7429291257243088,0.2867050533086952,3581.0,6673.433345794678,7017.428941488266,6673.433345794678,337.28736329078674,5.618357419967651,0.0 -28800,0.046030868,0.23056422,,,,,,,,,,,,,, -28900,0.026689688,0.249582,,,,,,,,,,,,,, -29000,0.04902145,0.2903678,,,,,,,,,,,,,, -29100,0.03238973,0.23802742,,,,,,,,,,,,,, -29111,,,0.7497790881565639,0.2652667249952044,0.7257436603123242,0.2853496424857994,3554.0,0.7428636761292237,0.2867233246539898,3581.0,6753.462169647217,7101.522633552551,6753.462169647217,341.3042325973511,5.653614521026611,0.0 -29200,0.03269572,0.24650419,,,,,,,,,,,,,, -29300,0.030047147,0.27658185,,,,,,,,,,,,,, -29400,0.04356813,0.19718307,,,,,,,,,,,,,, -29455,,,0.7503256797790527,0.2649745089667184,0.7260548468714828,0.285384728255135,3554.0,0.7431262926295029,0.2868116816073547,3581.0,6833.543080806732,7185.665328025818,6833.543080806732,345.3211672306061,5.685898780822754,0.0 -29500,0.026262766,0.22420119,,,,,,,,,,,,,, -29600,0.06121344,0.26159987,,,,,,,,,,,,,, -29700,0.025533052,0.30334604,,,,,,,,,,,,,, -29800,0.05403989,0.2571887,,,,,,,,,,,,,, -29801,,,0.7499323572431292,0.2650454044342041,0.7257351421813449,0.2852960950414586,3554.0,0.742863471599239,0.2866652722266825,3581.0,6913.672247648239,7269.804100990295,6913.672247648239,349.28695940971375,5.716902494430542,0.0 -29900,0.032955002,0.23982532,,,,,,,,,,,,,, -30000,0.033742648,0.29520917,,,,,,,,,,,,,, -30100,0.019399734,0.24972717,,,,,,,,,,,,,, -30146,,,0.7499641690935407,0.2652068478720529,0.725912923818233,0.285359843634593,3554.0,0.7429881667132086,0.2867398915827457,3581.0,6993.665600538254,7353.867259740829,6993.665600538254,353.3090696334839,5.7517173290252686,0.0 -30200,0.043391943,0.25631624,,,,,,,,,,,,,, -30300,0.033444446,0.18018505,,,,,,,,,,,,,, -30400,0.039285675,0.25397965,,,,,,,,,,,,,, -30491,,,0.7503576959882464,0.264944110597883,0.7260252394968697,0.2854310284186832,3554.0,0.7430889318189752,0.2868292030093723,3581.0,7073.665458440781,7437.925004243851,7073.665458440781,357.32208609580994,5.783747434616089,0.0 -30500,0.038454432,0.21702822,,,,,,,,,,,,,, -30600,0.026828382,0.34390068,,,,,,,,,,,,,, -30700,0.031686533,0.28609914,,,,,,,,,,,,,, -30800,0.020848881,0.32365897,,,,,,,,,,,,,, -30838,,,0.7499535424368722,0.2649753093719482,0.7257873500808948,0.2853013158314135,3554.0,0.7428631307159314,0.2866669084665596,3581.0,7153.841228246689,7522.169150829315,7153.841228246689,361.3407151699066,5.820956230163574,0.0 -30900,0.016430844,0.24844006,,,,,,,,,,,,,, -31000,0.019224837,0.29109916,,,,,,,,,,,,,, -31100,0.024745222,0.20381993,,,,,,,,,,,,,, -31183,,,0.7500086511884417,0.2650706427437918,0.7258787139051069,0.285347959467985,3554.0,0.7429502604893884,0.2867455502456541,3581.0,7233.882091283798,7606.274645090103,7233.882091283798,365.355765581131,5.85828971862793,0.0 -31200,0.02536833,0.190898,,,,,,,,,,,,,, -31300,0.021382974,0.30646542,,,,,,,,,,,,,, -31400,0.027424734,0.19025648,,,,,,,,,,,,,, -31500,0.028853007,0.32677385,,,,,,,,,,,,,, -31525,,,0.7500947543552944,0.2649495601654053,0.7257277918586452,0.2854118454503288,3554.0,0.7428442457806828,0.2868032617896537,3581.0,7313.855280160904,7690.314583301544,7313.855280160904,369.3752720355988,5.89311146736145,0.0 -31600,0.040638246,0.22495028,,,,,,,,,,,,,, -31700,0.022333909,0.2306888,,,,,,,,,,,,,, -31800,0.020248193,0.24458003,,,,,,,,,,,,,, -31870,,,0.7500623975481305,0.2648903812680925,0.7257400194982766,0.2852734773428707,3554.0,0.7428506543868681,0.2866636359868054,3581.0,7393.850246191025,7774.377434492111,7393.850246191025,373.39921855926514,5.924740552902222,0.0 -31900,0.02131105,0.29869112,,,,,,,,,,,,,, -32000,0.0445908,0.30447036,,,,,,,,,,,,,, -32100,0.018732432,0.31205925,,,,,,,,,,,,,, -32200,0.021464167,0.22026762,,,,,,,,,,,,,, -32217,,,0.7503595352172852,0.2650008542197091,0.7261762302379361,0.2853202583686515,3554.0,0.7432891666739389,0.2866684083531136,3581.0,7473.869105577469,7858.462511301041,7473.869105577469,377.42104148864746,5.95652961730957,0.0 -32300,0.03176139,0.27733442,,,,,,,,,,,,,, -32400,0.026207613,0.26117545,,,,,,,,,,,,,, -32500,0.021511953,0.30630037,,,,,,,,,,,,,, -32561,,,0.7500896453857422,0.2650911467415945,0.7257870753024761,0.2854902431679094,3554.0,0.742864085189193,0.2869319452383237,3581.0,7553.872319459915,7942.528846740723,7553.872319459915,381.439279794693,5.988731384277344,0.0 -32600,0.024917336,0.26912355,,,,,,,,,,,,,, -32700,0.024454223,0.2611815,,,,,,,,,,,,,, -32800,0.032569025,0.23126172,,,,,,,,,,,,,, -32900,0.02720182,0.24691667,,,,,,,,,,,,,, -32908,,,0.750241756439209,0.2648369073867798,0.7258719131392445,0.2853054546813449,3554.0,0.7429668955948059,0.286711257384896,3581.0,7633.981284618378,8026.701743841171,7633.981284618378,385.458370923996,6.020787715911865,0.0 -33000,0.025310652,0.3230185,,,,,,,,,,,,,, -33100,0.020486975,0.25637493,,,,,,,,,,,,,, -33200,0.018359223,0.33012748,,,,,,,,,,,,,, -33255,,,0.7501583780561175,0.2649034091404506,0.7258987727296707,0.2852611123140299,3554.0,0.7429962797359327,0.2866541935191811,3581.0,7714.053889751434,8110.837794303894,7714.053889751434,389.4760494232178,6.053780794143677,0.0 -33300,0.026301015,0.19762151,,,,,,,,,,,,,, -33400,0.025932364,0.2671769,,,,,,,,,,,,,, -33500,0.025524981,0.25370026,,,,,,,,,,,,,, -33598,,,0.7500778606959752,0.2648886442184448,0.7257746415790307,0.2852774788035928,3554.0,0.7428910831471656,0.2866700445929908,3581.0,7794.201321601868,8195.054210424423,7794.201321601868,393.4959604740143,6.090385913848877,0.0 -33600,0.020953823,0.28578907,,,,,,,,,,,,,, -33700,0.02056421,0.24591817,,,,,,,,,,,,,, -33800,0.033712547,0.3273809,,,,,,,,,,,,,, -33900,0.01758528,0.21402091,,,,,,,,,,,,,, -33945,,,0.7502149173191616,0.2647657053811209,0.7258221095508582,0.2852540195960977,3554.0,0.7429065592493368,0.2866699423279985,3581.0,7874.296406030655,8279.21177649498,7874.296406030655,397.5133173465729,6.123018741607666,0.0 -34000,0.028174112,0.26293176,,,,,,,,,,,,,, -34100,0.024651112,0.2234975,,,,,,,,,,,,,, -34200,0.021357056,0.23199798,,,,,,,,,,,,,, -34293,,,0.7502946172441755,0.2648194347109113,0.7260055241453293,0.2852201703296462,3554.0,0.7431016808546844,0.2866239912581158,3581.0,7954.259927272797,8363.237367153168,7954.259927272797,401.5292658805847,6.1560492515563965,0.0 -34300,0.018528594,0.25041565,,,,,,,,,,,,,, -34400,0.016657572,0.18702026,,,,,,,,,,,,,, -34500,0.031901497,0.2491027,,,,,,,,,,,,,, -34600,0.018834248,0.22472051,,,,,,,,,,,,,, -34639,,,0.7502439362662179,0.2648192303521292,0.7259190376380487,0.2852440760520716,3554.0,0.7430262092903519,0.286623650374808,3581.0,8034.41494178772,8447.455779075623,8034.41494178772,405.5457291603088,6.190333127975464,0.0 -34700,0.04299468,0.39667568,,,,,,,,,,,,,, -34800,0.029705863,0.34048855,,,,,,,,,,,,,, -34900,0.029450396,0.24589275,,,,,,,,,,,,,, -34986,,,0.7504809924534389,0.264769230570112,0.7261091843037775,0.2852524911411437,3554.0,0.7431830837885717,0.2866677606748289,3581.0,8114.595928907394,8531.699686050415,8114.595928907394,409.5631048679352,6.223457336425781,0.0 -35000,0.016729508,0.32475162,,,,,,,,,,,,,, -35100,0.021755004,0.2808896,,,,,,,,,,,,,, -35200,0.016713146,0.2540621,,,,,,,,,,,,,, -35300,0.01910526,0.30319628,,,,,,,,,,,,,, -35334,,,0.7502155985151019,0.2647643089294433,0.7258480761114238,0.2852156708330402,3554.0,0.7429592598087127,0.2866166963553302,3581.0,8194.691151618958,8615.862161159515,8194.691151618958,413.5802364349365,6.2608723640441895,0.0 -35400,0.02191812,0.30955473,,,,,,,,,,,,,, -35500,0.02530799,0.26402748,,,,,,,,,,,,,, -35600,0.024070429,0.19554688,,,,,,,,,,,,,, -35681,,,0.7504230907985142,0.264780044555664,0.7260905680659117,0.2852255113351593,3554.0,0.7431732663493088,0.2866270932962161,3581.0,8274.671577215195,8699.908791542053,8274.671577215195,417.6001315116882,6.293979644775391,0.0 -35700,0.027744956,0.3105235,,,,,,,,,,,,,, -35800,0.026603458,0.22989835,,,,,,,,,,,,,, -35900,0.026746428,0.29060507,,,,,,,,,,,,,, -36000,0.029041165,0.22448094,,,,,,,,,,,,,, -36027,,,0.7503397805350167,0.2647636617933001,0.7259813436444851,0.2852229009401818,3554.0,0.7430758418999581,0.2866263774412699,3581.0,8354.620125293732,8783.927860736847,8354.620125293732,421.6159639358521,6.335568904876709,0.0 -36100,0.012491676,0.3968719,,,,,,,,,,,,,, -36189,,,0.7503502709524972,0.2647658245904105,0.7259957008168613,0.2852237252754379,3554.0,0.7430886591123289,0.286627570532847,3581.0,8390.96516919136,8824.325541734695,8390.96516919136,425.62926030159,6.368803024291992,0.0 -36189,,,,,,,,,,,8390.96516919136,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 4a7a5a9af..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.95801854133606,0.0,34.67232370376587,1,0,34.67232370376587,0.8949576268413153,3581,0.2836721463431129,38.63045763969421,0.8928326879228864,0.2656774180276053,0.8956671152882316,3554,0.2613217295323051 -7.97329568862915,0.0188593864440917,114.82151770591736,343,0,114.82151770591736,0.3396532303193068,3581,0.6846410116980243,122.8272988796234,0.3203447886875697,0.6874906676156181,0.3383400251037563,3554,0.665604968653278 -11.987972736358644,0.0427210330963134,194.85300540924072,683,0,194.85300540924072,0.3132889744157532,3581,0.7119667630593759,206.91001749038696,0.2942925861903599,0.7154259000505719,0.3114802629409644,3554,0.6941251137582654 -15.995911598205566,0.0665290355682373,274.98014783859253,1021,0,274.98014783859253,0.3037232433874092,3581,0.722000731399225,291.08116149902344,0.2844969374792916,0.7261792591639927,0.301730059879713,3554,0.7047837686189856 -20.01418519020081,0.0900325775146484,355.20496368408203,1367,0,355.20496368408203,0.2994538844880096,3581,0.7267269421076515,375.3608241081238,0.2801398549761091,0.7311660902840751,0.297597289420811,3554,0.7096444614562817 -24.033470630645752,0.1138274669647216,435.2983305454254,1712,0,435.2983305454254,0.2972389611076689,3581,0.7288269878132854,459.5098702907562,0.2782946484429495,0.7326619284493583,0.2955972116585185,3554,0.7117603926693514 -28.04665946960449,0.1377015113830566,515.3351848125458,2053,0,515.3351848125458,0.295980010875541,3581,0.731076544937692,543.5966114997864,0.2767984867095947,0.7352352823529925,0.2942941437025183,3554,0.714054586381542 -32.06492805480957,0.1615633964538574,595.3161046504974,2399,0,595.3161046504974,0.2943691327165073,3581,0.7328256853663432,627.6319966316223,0.2746273449489048,0.7378204890659877,0.2926793739778243,3554,0.715780881796919 -36.08144378662109,0.1853799819946289,675.4076058864594,2745,0,675.4076058864594,0.2939220642584124,3581,0.7329012932839989,711.7762885093689,0.2746836798531668,0.736846787588937,0.2923017597359489,3554,0.7157350624956036 -40.09206223487854,0.2090139389038086,755.4829206466675,3089,0,755.4829206466675,0.293090513517523,3581,0.7343056643352066,795.8983507156372,0.2738098076411656,0.738483156476702,0.2915166490991664,3554,0.7173665593565349 -44.11016654968262,0.2337839603424072,835.5611155033112,3435,0,835.5611155033112,0.2926345820934271,3581,0.7349787043379992,880.0320289134979,0.2729310819080898,0.7398324693952288,0.2910367485909362,3554,0.7178401399611354 -48.12693548202515,0.2577452659606933,915.7322247028352,3781,0,915.7322247028352,0.2921319155678407,3581,0.7351857568591176,964.2566239833832,0.272593651499067,0.7399343763078962,0.2906307978246342,3554,0.7179940845702026 -52.14394688606262,3.60978937149048,992.5416581630708,4113,0,992.5416581630708,0.291919783885437,3581,0.7355518655316252,1048.4469621181488,0.2726095233644758,0.7397107396806989,0.2903784825416784,3554,0.7185598533342712 -56.15869307518005,3.634568214416504,1072.5993592739103,4452,0,1072.5993592739103,0.2924081333120986,3581,0.7355363212527926,1132.556327342987,0.2726199116025652,0.7406341007777623,0.2907874902178883,3554,0.7184141520777645 -60.17582702636719,3.66079044342041,1152.6758415699005,4797,0,1152.6758415699005,0.2922863357062447,3581,0.7350659022881179,1216.6884505748749,0.2729235887527466,0.7394133976527623,0.2907678435609524,3554,0.717810807364941 -64.19480657577515,3.686720371246338,1232.8069953918457,5141,0,1232.8069953918457,0.2913546334495427,3581,0.7372924838775831,1300.876842737198,0.2714065313339233,0.7423439707074847,0.2898378903502215,3554,0.720104863687922 -68.21188974380493,3.715046644210816,1312.845635175705,5482,0,1312.845635175705,0.2909789459560877,3581,0.7371058161782672,1384.973824262619,0.2708125114440918,0.7426614080156598,0.2894304969972214,3554,0.7199558650903911 -72.22917580604553,3.7401299476623535,1392.8805315494535,5826,0,1392.8805315494535,0.2913603262007819,3581,0.7378205802979265,1469.063951730728,0.2719070741108486,0.7421904972621373,0.2898802405740011,3554,0.7206405442151449 -76.24688148498535,3.7643887996673584,1472.8445675373075,6170,0,1472.8445675373075,0.2897655377338732,3581,0.7387264435999022,1553.0830590724945,0.2700311115809849,0.743617125919887,0.2883799847552933,3554,0.7214869991338985 -80.26424646377563,3.7941737174987793,1552.9671349525452,6515,0,1552.9671349525452,0.2907156476891929,3581,0.736829496169017,1637.2661266326904,0.2711503676005772,0.7408600534711566,0.2892438537563309,3554,0.7199719396278841 -84.28073906898499,3.819782733917236,1633.011125087738,6858,0,1633.011125087738,0.2900368808468305,3581,0.7378659177778554,1721.365136384964,0.2702873434339251,0.7428475788661412,0.2885929380297728,3554,0.7205947249138295 -88.29626941680908,3.8452749252319336,1713.1123294830322,7202,0,1713.1123294830322,0.2896335136287873,3581,0.7387140354475007,1805.520592451096,0.2699863229479108,0.7438455990382603,0.2882636847895857,3554,0.7213912388549873 -92.31477165222168,3.871781349182129,1793.1978707313538,7544,0,1793.1978707313538,0.2890629090599693,3581,0.7392936734239738,1889.664298772812,0.2692584310259138,0.7445081983293805,0.2876641011063678,3554,0.7220048190639069 -96.33357906341551,3.897242546081543,1873.3463337421413,7890,0,1873.3463337421413,0.2891253588819464,3581,0.7390526689254049,1973.870161294937,0.269193274634225,0.7442601067679269,0.2877252049572225,3554,0.7217778520900746 -100.34822511672974,3.9240989685058594,1953.4747865200045,8235,0,1953.4747865200045,0.2885202910107163,3581,0.7402122175849972,2058.0535831451416,0.268533672605242,0.7455815587724958,0.2871241099927019,3554,0.72304437451639 -104.36516261100768,3.951184272766113,2033.6733074188232,8579,0,2033.6733074188232,0.2928813134315659,3581,0.7344478126745323,2142.3093214035034,0.272737979888916,0.7400757244655064,0.2915414135041502,3554,0.7171354707064224 -108.38202714920044,3.977944850921631,2113.6891655921936,8923,0,2113.6891655921936,0.2886678253063041,3581,0.7400754552019339,2226.3820304870605,0.2683231149400983,0.7457260404314313,0.2872266023428707,3554,0.7229087713667698 -112.39886569976808,4.004060506820679,2193.8962202072144,9269,0,2193.8962202072144,0.2883772904631213,3581,0.7400321630218515,2310.645069360733,0.2683753967285156,0.7454257011413574,0.2870310974979776,3554,0.7227186933956458 -116.41702151298524,4.029884099960327,2274.0123538970947,9611,0,2274.0123538970947,0.2883160337327213,3581,0.741056653714919,2394.818213224411,0.2683801651000976,0.7463662964957101,0.2869476163796514,3554,0.7238779148494654 -120.43198204040527,4.058996915817261,2354.080169200897,9957,0,2354.080169200897,0.2962296397217956,3581,0.7323646066043005,2478.943204164505,0.2757884774889265,0.7377515520368304,0.2945271214542593,3554,0.7153397250457231 -124.446848154068,4.085970640182495,2434.170556306839,10304,0,2434.170556306839,0.2885957625750489,3581,0.7394697737407497,2563.088749408722,0.2687504972730364,0.7443478448050362,0.2872471076823649,3554,0.722362031008195 -128.46352362632751,4.111758708953857,2514.290768623352,10648,0,2514.290768623352,0.2880172836018221,3581,0.7404704026022759,2647.2644975185394,0.2680527653012957,0.7457329886300224,0.2866149799301843,3554,0.723326915425401 -132.48388409614563,4.137834072113037,2594.4230313301086,10992,0,2594.4230313301086,0.2880957549392627,3581,0.7408430562342921,2731.456456422806,0.2677083015441894,0.746624265398298,0.2867493637505715,3554,0.7237021253561128 -136.5043747425079,4.165032148361206,2674.5201218128204,11339,0,2674.5201218128204,0.2881150148461498,3581,0.7407568127574351,2815.614520311356,0.2680037702832903,0.7461288315909249,0.2868198787622661,3554,0.7235591031891883 -140.52345037460327,4.193560123443604,2754.5981678962708,11685,0,2754.5981678962708,0.2877580077579587,3581,0.7407115434541678,2899.753466129303,0.2677056789398193,0.7460990633283343,0.2864074535294738,3554,0.7235476998848129 -144.54294848442078,4.220882415771484,2834.6626737117767,12028,0,2834.6626737117767,0.2886041142160884,3581,0.7407484952047263,2983.877734422684,0.2680112634386335,0.7467365946088519,0.2872677504110685,3554,0.7236964237039252 -148.56185173988342,4.256128311157227,2914.802984476089,12375,0,2914.802984476089,0.2893144127644164,3581,0.7391867042420064,3068.0855479240417,0.2689501898629324,0.7449325834001813,0.2879221523888136,3554,0.7220681554894134 -152.57717752456665,4.285537481307983,2994.950837135315,12722,0,2994.950837135315,0.2884791463954726,3581,0.7420470560772131,3152.291583776474,0.2685403653553554,0.7473699024745396,0.2871186659452817,3554,0.7248627207020258 -156.59312415122986,4.313032388687134,3075.084317445755,13065,0,3075.084317445755,0.2880311575524469,3581,0.7402109222284278,3236.481616020202,0.267665011542184,0.7461308070591518,0.2867349550572418,3554,0.7229474464291995 -160.61153936386108,4.341656446456909,3155.114428043365,13410,0,3155.114428043365,0.2915980241312657,3581,0.7346096640690449,3320.5721111297607,0.2715660844530378,0.7395822661263602,0.2904655186057963,3554,0.7169801522052617 -164.62888169288635,4.370457410812378,3235.2807846069336,13754,0,3235.2807846069336,0.2874327709940484,3581,0.7413727888945127,3404.797929763794,0.2671864543642316,0.7470260347638812,0.2861750425082213,3554,0.7241098965294387 -168.64652037620544,4.3972272872924805,3315.593653202057,14098,0,3315.593653202057,0.2879964215433887,3581,0.7416673802490575,3489.1684036254883,0.2675982883998326,0.7473585265023368,0.2866321879286543,3554,0.7245061957037845 -172.66461992263794,4.424889802932739,3395.570848941803,14442,0,3395.570848941803,0.2878864525883133,3581,0.7415121419907149,3573.20453453064,0.267539586339678,0.7471727643694196,0.2865762018258476,3554,0.7243391991198298 -176.68185997009277,4.453326225280762,3475.565026283264,14787,0,3475.565026283264,0.2875767601032358,3581,0.7423192854867705,3657.257560491562,0.2672019004821777,0.748114994594029,0.2862941761263717,3554,0.725114967290377 -180.69870519638064,4.48215651512146,3555.575869321823,15130,0,3555.575869321823,0.2874424179916574,3581,0.7418366628996789,3741.327105522156,0.2670585598264421,0.7476911544799805,0.2861452633970965,3554,0.7246122601733962 -184.71685791015625,4.515713930130005,3635.66794872284,15475,0,3635.66794872284,0.2873672873106325,3581,0.7421833412236456,3825.483948945999,0.2666790655681065,0.7482358387538365,0.2860073761518711,3554,0.7250108262696962 -188.73474383354187,4.548308849334717,3715.785654783249,15821,0,3715.785654783249,0.2871656207457763,3581,0.7422936510620287,3909.665248632431,0.2668221167155674,0.748002120426723,0.2859135736691932,3554,0.7250834364668332 -192.75176310539248,4.576169013977051,3795.99942445755,16165,0,3795.99942445755,0.2871942208552953,3581,0.7417378067404357,3993.93709754944,0.266790543283735,0.7476752144949776,0.2858654187513189,3554,0.7244753518262873 -196.76408624649048,4.604237794876099,3876.117094039917,16511,0,3876.117094039917,0.2872465123547019,3581,0.7417247168214186,4078.1082940101614,0.2664880411965506,0.7478722163609096,0.285872906463228,3554,0.7244973340997819 -200.78008460998527,4.63153338432312,3956.22412109375,16856,0,3956.22412109375,0.28726229525185,3581,0.7425804702771572,4162.271712303162,0.2666984285627092,0.7484268460954938,0.2859993045358223,3554,0.7252970079927546 -204.73551487922668,4.660937786102295,4036.373031377792,17201,0,4036.373031377792,0.2869837254127688,3581,0.7419232472598436,4246.418367147446,0.2666475432259695,0.7476682662963867,0.2857119550044844,3554,0.72458849184018 -208.7484924793244,4.688414573669434,4116.48851108551,17545,0,4116.48851108551,0.2871294189384948,3581,0.7414532373551382,4330.587821722031,0.2663342441831316,0.7476886340550014,0.2857955048174152,3554,0.7242569716780388 -212.76672649383545,4.716022253036499,4196.598222017288,17892,0,4196.598222017288,0.2873294151751431,3581,0.7417253985880341,4414.756689548492,0.2668954474585397,0.7475626809256417,0.2859883305727261,3554,0.7244492478765123 -216.7855279445648,4.743778467178345,4276.7827315330505,18239,0,4276.7827315330505,0.2871809945829552,3581,0.7431509044043214,4499.001486063004,0.2666161571230207,0.7490292276654925,0.2858516626567336,3554,0.7260106075460748 -220.7995901107788,4.771770715713501,4356.798830032349,18583,0,4356.798830032349,0.2873580834613236,3581,0.7413240425815065,4583.0732300281525,0.2666305303573608,0.7473668370928083,0.2860666939430044,3554,0.7240039007544317 -224.81593370437625,4.800743103027344,4436.8140461444855,18929,0,4436.8140461444855,0.2871458154255969,3581,0.742283833622766,4667.147452354431,0.2665919235774449,0.7481992585318429,0.2858412897714283,3554,0.7250225043524902 -228.83026695251465,4.829066753387451,4517.03516125679,19275,0,4517.03516125679,0.2869741125034906,3581,0.7412129827998464,4751.42479133606,0.2664503370012556,0.7471611840384347,0.2856836871746623,3554,0.7239721638470737 -232.84253406524653,4.8612940311431885,4597.184770107269,19618,0,4597.184770107269,0.2870630830468095,3581,0.7430842958059899,4835.632307052612,0.2661271776471819,0.7494119235447475,0.2857512483183561,3554,0.7260501756383653 -236.85779643058777,4.890937805175781,4677.34040427208,19964,0,4677.34040427208,0.2869520573534801,3581,0.7426314664199944,4919.846963167191,0.2662330695561,0.7486386980329242,0.2856385032984401,3554,0.7253936613015265 -240.87046551704407,4.920657634735107,4757.374995470047,20308,0,4757.374995470047,0.2868393954202736,3581,0.7432166948827144,5003.937238454819,0.2662207569394793,0.7491888999938965,0.2855256895839195,3554,0.7260768978395822 -244.8877301216125,4.949784517288208,4837.548028469086,20650,0,4837.548028469086,0.2873131209530333,3581,0.7420767811016475,5088.170040607452,0.2663572515760149,0.7485660825456891,0.2859770990048624,3554,0.7249621217949845 -248.90308475494385,4.978290319442749,4917.593822479248,20995,0,4917.593822479248,0.2868216013116099,3581,0.743107612224239,5172.273324012756,0.2659975971494402,0.7492354256766183,0.2854717814929041,3554,0.7259607352630838 -252.9222385883332,5.007358312606812,4997.784978628159,21341,0,4997.784978628159,0.2865837329394722,3581,0.7427538435274714,5256.526384115219,0.2659299033028738,0.748748915536063,0.285351445719172,3554,0.7255290583673326 -256.93958377838135,5.038073301315308,5077.8388476371765,21680,0,5077.8388476371765,0.2892507016742006,3581,0.7392589715032463,5340.641524076462,0.2684741360800607,0.7451433454241071,0.2879481876439839,3554,0.7219432686981219 -260.9596357345581,5.067148923873901,5157.941499471664,22026,0,5157.941499471664,0.2867683894272724,3581,0.7429517603759425,5424.806901931763,0.2656904969896589,0.7494405337742397,0.2854201746711451,3554,0.7258220408562536 -264.9744710922241,5.100812196731567,5238.024378061295,22371,0,5238.024378061295,0.2864998074691078,3581,0.7429030140629364,5508.9515812397,0.2657301425933838,0.74909394127982,0.2852048170855022,3554,0.7256988714300788 -268.993115901947,5.129384994506836,5318.00970864296,22713,0,5318.00970864296,0.2865557805082379,3581,0.7430842276293284,5592.997215270996,0.2658064535685948,0.7492767742701939,0.2852625892480304,3554,0.7258655245410102 -273.005410194397,5.15982985496521,5398.11927318573,23058,0,5398.11927318573,0.2866300248926626,3581,0.7427997945973541,5677.161759853363,0.2653361899512155,0.7493971415928432,0.285262760984542,3554,0.7256113545037282 -277.02224564552307,5.189005136489868,5478.105720996857,23402,0,5478.105720996857,0.2865171584294715,3581,0.7426350116063949,5761.20668721199,0.2655630792890276,0.7490113803318569,0.2852217503055536,3554,0.725430275525816 -281.0366368293762,5.219642877578735,5558.272310495377,23744,0,5558.272310495377,0.2865147040696558,3581,0.7431993780106814,5845.430538654327,0.2656279632023403,0.7494708469935826,0.2851841915304498,3554,0.7260391845016179 -285.0529205799103,5.249758005142212,5638.279019832611,24088,0,5638.279019832611,0.2865926640821348,3581,0.7436515256300614,5929.496938467026,0.2651689904076712,0.7504288128444127,0.2852026016845016,3554,0.7265264353325478 -289.0705211162567,5.280286550521851,5718.319184303284,24433,0,5718.319184303284,0.2864014285464954,3581,0.7432730088051522,6013.598551034927,0.2654334136417934,0.7496530669076102,0.2851277073917593,3554,0.7260332767656162 -293.0866525173187,5.3132641315460205,5798.420118570328,24780,0,5798.420118570328,0.2866211278383307,3581,0.7432156722327912,6097.761975288391,0.2656197888510568,0.7496503421238491,0.2852865808387116,3554,0.7260538851470174 -297.1022679805756,5.344352006912232,5878.403662443161,25123,0,5878.403662443161,0.2865872781258726,3581,0.7439480259311295,6181.804616928101,0.2650500876562936,0.7507190023149762,0.2852178862340409,3554,0.7268040302300225 -301.12024998664856,5.3743064403533936,5958.376376628876,25467,0,5958.376376628876,0.286416836472005,3581,0.7436492076235688,6265.838335990906,0.2652744735990252,0.7501741136823382,0.2851033208071011,3554,0.7264499782375492 -305.1376292705536,5.405126094818115,6038.365174293518,25810,0,6038.365174293518,0.2864697074730347,3581,0.7428756070449944,6349.887370347977,0.2654041562761579,0.7493463924952916,0.2852030138521296,3554,0.7256698823069077 -309.15387773513794,5.434753894805908,6118.489009618759,26151,0,6118.489009618759,0.2866675561448443,3581,0.7432468289671181,6434.069468021393,0.2650456428527832,0.7501864433288574,0.2852682737265669,3554,0.7261618730655599 -313.1702241897583,5.465415954589844,6198.512864589691,26495,0,6198.512864589691,0.2864623443935877,3581,0.7435073319908894,6518.152428388596,0.2651315076010568,0.7501932552882603,0.2851360881335291,3554,0.726330243541608 -317.18662691116333,5.497414350509644,6278.578718662262,26840,0,6278.578718662262,0.2863978151834334,3581,0.7426230806906241,6602.279229164124,0.2652441433497837,0.7490963935852051,0.2850785735757685,3554,0.7254626993792206 -321.2055685520172,5.527190208435059,6358.797013282776,27183,0,6358.797013282776,0.2865224762090722,3581,0.7430782962597738,6686.558983802795,0.2650856801441738,0.7498324257986886,0.2851935168230339,3554,0.7258750043964547 -325.2188792228699,5.559579610824585,6438.806653022766,27528,0,6438.806653022766,0.2863713285504223,3581,0.7431881970381876,6770.626757621765,0.26494300365448,0.7499326297215053,0.285010428527935,3554,0.7259978990442107 -329.2352783679962,5.590218782424927,6519.011382341385,27873,0,6519.011382341385,0.2863230253857163,3581,0.7436978175832519,6854.8912081718445,0.2650255646024431,0.7503581728254046,0.2850092435460045,3554,0.726517230255522 -333.2514762878418,5.620994567871094,6599.230717420578,28216,0,6599.230717420578,0.2867345056264835,3581,0.7434523816016825,6939.170467376709,0.2652479580470493,0.7501605578831264,0.2852860828028278,3554,0.7263499588931486 -337.26547598838806,5.651025295257568,6679.461097002029,28561,0,6679.461097002029,0.2863785893648771,3581,0.7433470486595923,7023.457095623016,0.2646962233952113,0.7503539494105748,0.2849880512604635,3554,0.7262106462348762 -341.2815098762512,5.682204484939575,6759.423446655273,28905,0,6759.423446655273,0.2863066288986142,3581,0.743408953068277,7107.479385375977,0.2648711545126779,0.7502357619149345,0.2849589934426878,3554,0.7262302241972074 -345.2969605922699,5.712467908859253,6839.43584895134,29248,0,6839.43584895134,0.2865055343086777,3581,0.7435973933607931,7191.550079584122,0.264979498726981,0.750530515398298,0.2850998860768676,3554,0.7264658466912282 -349.3138077259064,5.7450034618377686,6919.53741979599,29595,0,6919.53741979599,0.2862506899478148,3581,0.7437091349090686,7275.713355302811,0.2644538879394531,0.7509183202471051,0.2848655859539867,3554,0.7266162191808525 -353.32518696784973,5.7767322063446045,6999.530682325363,29941,0,6999.530682325363,0.2862475197330529,3581,0.7437569949254748,7359.7621948719025,0.2646848814828055,0.7506739071437291,0.2849207305478862,3554,0.7265805666810284 -357.34046387672424,5.810181140899658,7079.652544736862,30283,0,7079.652544736862,0.286300186204098,3581,0.7434573584979755,7443.945029020309,0.2647715636662074,0.7503796986171177,0.2849135519616981,3554,0.7263325104635622 -361.350955247879,5.84227180480957,7159.747346639633,30631,0,7159.747346639633,0.2862401225652751,3581,0.7437173161084544,7528.094771146774,0.2642929894583566,0.7510991777692523,0.2848544746016812,3554,0.726567720789955 -365.3675897121429,5.872438192367554,7239.900948047638,30978,0,7239.900948047638,0.2862286007094736,3581,0.7437042943660989,7612.307431459427,0.2645713601793562,0.7507302420479911,0.2848851810899691,3554,0.7265148259443585 -369.3829791545868,6.204082012176514,7319.654996156692,31320,0,7319.654996156692,0.2862563145223925,3581,0.7436085743332868,7696.420757055283,0.2646412168230329,0.7506299700055804,0.2848772640367807,3554,0.7264860429050014 -373.4013526439667,6.235177278518677,7399.781029224396,31664,0,7399.781029224396,0.2863126284448303,3581,0.7439497985243297,7780.608487844467,0.2642252445220947,0.7513968603951591,0.2848980613283448,3554,0.72686193978176 -377.4197075366974,6.269732475280762,7479.888315439224,32010,0,7479.888315439224,0.2862076704743787,3581,0.7437827657035395,7864.781018733978,0.2644398893628801,0.7509576252528599,0.2848405639442353,3554,0.7266498108425365 -381.4381649494171,6.301154613494873,7559.8783304691315,32355,0,7559.8783304691315,0.2862447585782602,3581,0.7437727437342921,7948.833317041397,0.2645163365772792,0.7508602823529925,0.2848704976182206,3554,0.7266644427933314 -385.4569594860077,6.333528757095337,7640.045894861221,32698,0,7640.045894861221,0.2862259418196732,3581,0.7437449276563809,8033.064491033554,0.2641677345548357,0.751176289149693,0.2848085866057611,3554,0.7266272790122046 -389.47060203552246,6.371289491653442,7720.015320777893,33041,0,7720.015320777893,0.286153435940118,3581,0.7437026581262217,8117.097451686859,0.2642357519694737,0.7510192053658622,0.284751741820396,3554,0.726567720789955 -393.484384059906,6.404409885406494,7799.988336086273,33384,0,7799.988336086273,0.2861949214386693,3581,0.7436151192927953,8201.129628896713,0.2643541949135916,0.7508604867117745,0.284828542388418,3554,0.7265020487478897 -397.50036454200745,6.436081647872925,7880.060120820999,33725,0,7880.060120820999,0.2861725594936819,3581,0.7437837883534627,8285.262045621872,0.264176675251552,0.7511130741664341,0.2847538198321873,3554,0.7266970040359454 -401.5164318084717,6.469133138656616,7960.137799978256,34069,0,7960.137799978256,0.2860788506723855,3581,0.7438623278675649,8369.402106523514,0.2640826020921979,0.7512414796011788,0.2846950344242402,3554,0.7267279166080473 -405.53353238105774,6.502725839614868,8040.238317012787,34415,0,8040.238317012787,0.2860408762719038,3581,0.7440084304532603,8453.566410779953,0.2641456127166748,0.7512760162353516,0.2846884912631454,3554,0.7268698396612971 -409.5499720573425,6.540175437927246,8120.395667791367,34758,0,8120.395667791367,0.2861053373053965,3581,0.7437508590259355,8537.790835618973,0.2641614505222865,0.7510401862008231,0.2847078803153137,3554,0.726607563660664 -413.56875348091125,6.579923152923584,8200.408621549606,35102,0,8200.408621549606,0.2860836912153554,3581,0.7439925452911198,8621.87518954277,0.2640484060559954,0.7514004707336426,0.2846869971554938,3554,0.726867778823157 -417.5888998508453,6.61214542388916,8280.389010429382,35446,0,8280.389010429382,0.2860260478480173,3581,0.7439716832326864,8705.920965194702,0.2640537875039236,0.7513203620910645,0.2846417445846669,3554,0.7268525973155248 -421.6066160202026,6.645025968551636,8360.353709697723,35790,0,8360.353709697723,0.2860469099064507,3581,0.7440648125523597,8789.949140787125,0.2640480313982282,0.7514446122305733,0.2846638985946733,3554,0.7269441672235509 -425.6265361309052,6.677982568740845,8440.450272083282,36134,0,8440.450272083282,0.2860179007369624,3581,0.7440346102912944,8874.111963510513,0.2640267610549927,0.7513912064688546,0.2846340679625949,3554,0.7269162772140546 -429.6439287662506,6.712190389633179,8451.264714956284,36189,0,8451.264714956284,0.2860175257653239,3581,0.7440320877548171,8888.98148393631,0.2640261820384434,0.7513892991202218,0.2846337244895716,3554,0.7269135294298678 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/measurements.csv deleted file mode 100644 index 7d6a91719..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.160697,0.8857512,,,,,,,,,,,,,, -1,,,0.2656774180276053,0.8928326879228864,0.2613217295323051,0.8956671152882316,3554.0,0.2836721463431129,0.8949576268413153,3581.0,34.67232370376587,38.63045763969421,34.67232370376587,3.95801854133606,0.0,0.0 -100,1.0045426,0.38270354,,,,,,,,,,,,,, -200,0.25704113,0.2843664,,,,,,,,,,,,,, -300,0.102702536,0.32611355,,,,,,,,,,,,,, -343,,,0.6874906676156181,0.3203447886875697,0.665604968653278,0.3383400251037563,3554.0,0.6846410116980243,0.3396532303193068,3581.0,114.82151770591736,122.8272988796234,114.82151770591736,7.97329568862915,0.0188593864440917,0.0 -400,0.13150357,0.2991799,,,,,,,,,,,,,, -500,0.16758057,0.30704203,,,,,,,,,,,,,, -600,0.09708159,0.2775314,,,,,,,,,,,,,, -683,,,0.7154259000505719,0.2942925861903599,0.6941251137582654,0.3114802629409644,3554.0,0.7119667630593759,0.3132889744157532,3581.0,194.85300540924072,206.91001749038696,194.85300540924072,11.987972736358644,0.0427210330963134,0.0 -700,0.14231499,0.24069493,,,,,,,,,,,,,, -800,0.12847377,0.35613957,,,,,,,,,,,,,, -900,0.12614387,0.36961144,,,,,,,,,,,,,, -1000,0.13489927,0.2839598,,,,,,,,,,,,,, -1021,,,0.7261792591639927,0.2844969374792916,0.7047837686189856,0.301730059879713,3554.0,0.722000731399225,0.3037232433874092,3581.0,274.98014783859253,291.08116149902344,274.98014783859253,15.995911598205566,0.0665290355682373,0.0 -1100,0.5156921,0.27037162,,,,,,,,,,,,,, -1200,0.15660271,0.26984057,,,,,,,,,,,,,, -1300,0.1941639,0.29411244,,,,,,,,,,,,,, -1367,,,0.7311660902840751,0.2801398549761091,0.7096444614562817,0.297597289420811,3554.0,0.7267269421076515,0.2994538844880096,3581.0,355.20496368408203,375.3608241081238,355.20496368408203,20.01418519020081,0.0900325775146484,0.0 -1400,0.31304848,0.2962491,,,,,,,,,,,,,, -1500,0.24462837,0.2761973,,,,,,,,,,,,,, -1600,0.66708535,0.25307095,,,,,,,,,,,,,, -1700,0.12604518,0.41422498,,,,,,,,,,,,,, -1712,,,0.7326619284493583,0.2782946484429495,0.7117603926693514,0.2955972116585185,3554.0,0.7288269878132854,0.2972389611076689,3581.0,435.2983305454254,459.5098702907562,435.2983305454254,24.033470630645752,0.1138274669647216,0.0 -1800,0.19629326,0.25790018,,,,,,,,,,,,,, -1900,0.35301307,0.31869423,,,,,,,,,,,,,, -2000,0.089472584,0.27461222,,,,,,,,,,,,,, -2053,,,0.7352352823529925,0.2767984867095947,0.714054586381542,0.2942941437025183,3554.0,0.731076544937692,0.295980010875541,3581.0,515.3351848125458,543.5966114997864,515.3351848125458,28.04665946960449,0.1377015113830566,0.0 -2100,0.16081944,0.31119332,,,,,,,,,,,,,, -2200,0.30054736,0.28056562,,,,,,,,,,,,,, -2300,0.19027752,0.27832016,,,,,,,,,,,,,, -2399,,,0.7378204890659877,0.2746273449489048,0.715780881796919,0.2926793739778243,3554.0,0.7328256853663432,0.2943691327165073,3581.0,595.3161046504974,627.6319966316223,595.3161046504974,32.06492805480957,0.1615633964538574,0.0 -2400,0.20696725,0.33755004,,,,,,,,,,,,,, -2500,0.3138887,0.3029544,,,,,,,,,,,,,, -2600,0.1978339,0.24734867,,,,,,,,,,,,,, -2700,0.1055276,0.29929495,,,,,,,,,,,,,, -2745,,,0.736846787588937,0.2746836798531668,0.7157350624956036,0.2923017597359489,3554.0,0.7329012932839989,0.2939220642584124,3581.0,675.4076058864594,711.7762885093689,675.4076058864594,36.08144378662109,0.1853799819946289,0.0 -2800,0.24613436,0.23744786,,,,,,,,,,,,,, -2900,0.24171,0.25138357,,,,,,,,,,,,,, -3000,0.07837304,0.26424122,,,,,,,,,,,,,, -3089,,,0.738483156476702,0.2738098076411656,0.7173665593565349,0.2915166490991664,3554.0,0.7343056643352066,0.293090513517523,3581.0,755.4829206466675,795.8983507156372,755.4829206466675,40.09206223487854,0.2090139389038086,0.0 -3100,0.14384335,0.28686082,,,,,,,,,,,,,, -3200,0.11545193,0.2656305,,,,,,,,,,,,,, -3300,0.10635659,0.26420635,,,,,,,,,,,,,, -3400,0.11530604,0.28499717,,,,,,,,,,,,,, -3435,,,0.7398324693952288,0.2729310819080898,0.7178401399611354,0.2910367485909362,3554.0,0.7349787043379992,0.2926345820934271,3581.0,835.5611155033112,880.0320289134979,835.5611155033112,44.11016654968262,0.2337839603424072,0.0 -3500,0.17004281,0.32802337,,,,,,,,,,,,,, -3600,0.1431602,0.24652785,,,,,,,,,,,,,, -3700,0.21402077,0.25237688,,,,,,,,,,,,,, -3781,,,0.7399343763078962,0.272593651499067,0.7179940845702026,0.2906307978246342,3554.0,0.7351857568591176,0.2921319155678407,3581.0,915.7322247028352,964.2566239833832,915.7322247028352,48.12693548202515,0.2577452659606933,0.0 -3800,0.101471,0.26435837,,,,,,,,,,,,,, -3900,0.16121256,0.32184738,,,,,,,,,,,,,, -4000,0.273132,0.2333734,,,,,,,,,,,,,, -4100,0.12034505,0.24135876,,,,,,,,,,,,,, -4113,,,0.7397107396806989,0.2726095233644758,0.7185598533342712,0.2903784825416784,3554.0,0.7355518655316252,0.291919783885437,3581.0,992.5416581630708,1048.4469621181488,992.5416581630708,52.14394688606262,3.60978937149048,0.0 -4200,0.15105222,0.29368916,,,,,,,,,,,,,, -4300,0.17530672,0.29503092,,,,,,,,,,,,,, -4400,0.10832762,0.29505447,,,,,,,,,,,,,, -4452,,,0.7406341007777623,0.2726199116025652,0.7184141520777645,0.2907874902178883,3554.0,0.7355363212527926,0.2924081333120986,3581.0,1072.5993592739103,1132.556327342987,1072.5993592739103,56.15869307518005,3.634568214416504,0.0 -4500,0.20169868,0.25392574,,,,,,,,,,,,,, -4600,0.13326861,0.25139746,,,,,,,,,,,,,, -4700,0.10194465,0.27995393,,,,,,,,,,,,,, -4797,,,0.7394133976527623,0.2729235887527466,0.717810807364941,0.2907678435609524,3554.0,0.7350659022881179,0.2922863357062447,3581.0,1152.6758415699005,1216.6884505748749,1152.6758415699005,60.17582702636719,3.66079044342041,0.0 -4800,0.16337706,0.25300923,,,,,,,,,,,,,, -4900,0.07305675,0.28432357,,,,,,,,,,,,,, -5000,0.17270337,0.25745535,,,,,,,,,,,,,, -5100,0.13278517,0.2326528,,,,,,,,,,,,,, -5141,,,0.7423439707074847,0.2714065313339233,0.720104863687922,0.2898378903502215,3554.0,0.7372924838775831,0.2913546334495427,3581.0,1232.8069953918457,1300.876842737198,1232.8069953918457,64.19480657577515,3.686720371246338,0.0 -5200,0.13594228,0.25733927,,,,,,,,,,,,,, -5300,0.27090824,0.2579026,,,,,,,,,,,,,, -5400,0.1011892,0.32120246,,,,,,,,,,,,,, -5482,,,0.7426614080156598,0.2708125114440918,0.7199558650903911,0.2894304969972214,3554.0,0.7371058161782672,0.2909789459560877,3581.0,1312.845635175705,1384.973824262619,1312.845635175705,68.21188974380493,3.715046644210816,0.0 -5500,0.10188991,0.2742706,,,,,,,,,,,,,, -5600,0.13558051,0.216883,,,,,,,,,,,,,, -5700,0.17463705,0.25457025,,,,,,,,,,,,,, -5800,0.18724534,0.36040926,,,,,,,,,,,,,, -5826,,,0.7421904972621373,0.2719070741108486,0.7206405442151449,0.2898802405740011,3554.0,0.7378205802979265,0.2913603262007819,3581.0,1392.8805315494535,1469.063951730728,1392.8805315494535,72.22917580604553,3.7401299476623535,0.0 -5900,0.2664735,0.25312826,,,,,,,,,,,,,, -6000,0.17962044,0.22758959,,,,,,,,,,,,,, -6100,0.03940572,0.34418982,,,,,,,,,,,,,, -6170,,,0.743617125919887,0.2700311115809849,0.7214869991338985,0.2883799847552933,3554.0,0.7387264435999022,0.2897655377338732,3581.0,1472.8445675373075,1553.0830590724945,1472.8445675373075,76.24688148498535,3.7643887996673584,0.0 -6200,0.11505488,0.23526779,,,,,,,,,,,,,, -6300,0.16105705,0.26223305,,,,,,,,,,,,,, -6400,0.20857821,0.21379781,,,,,,,,,,,,,, -6500,0.08332754,0.3001471,,,,,,,,,,,,,, -6515,,,0.7408600534711566,0.2711503676005772,0.7199719396278841,0.2892438537563309,3554.0,0.736829496169017,0.2907156476891929,3581.0,1552.9671349525452,1637.2661266326904,1552.9671349525452,80.26424646377563,3.7941737174987793,0.0 -6600,0.28011423,0.33763257,,,,,,,,,,,,,, -6700,0.09772515,0.23027714,,,,,,,,,,,,,, -6800,0.12646784,0.3648526,,,,,,,,,,,,,, -6858,,,0.7428475788661412,0.2702873434339251,0.7205947249138295,0.2885929380297728,3554.0,0.7378659177778554,0.2900368808468305,3581.0,1633.011125087738,1721.365136384964,1633.011125087738,84.28073906898499,3.819782733917236,0.0 -6900,0.097456,0.3278248,,,,,,,,,,,,,, -7000,0.16416258,0.25142488,,,,,,,,,,,,,, -7100,0.2152837,0.23112988,,,,,,,,,,,,,, -7200,0.09771778,0.2661825,,,,,,,,,,,,,, -7202,,,0.7438455990382603,0.2699863229479108,0.7213912388549873,0.2882636847895857,3554.0,0.7387140354475007,0.2896335136287873,3581.0,1713.1123294830322,1805.520592451096,1713.1123294830322,88.29626941680908,3.8452749252319336,0.0 -7300,0.16961224,0.28825298,,,,,,,,,,,,,, -7400,0.18188083,0.21952896,,,,,,,,,,,,,, -7500,0.13294216,0.1609398,,,,,,,,,,,,,, -7544,,,0.7445081983293805,0.2692584310259138,0.7220048190639069,0.2876641011063678,3554.0,0.7392936734239738,0.2890629090599693,3581.0,1793.1978707313538,1889.664298772812,1793.1978707313538,92.31477165222168,3.871781349182129,0.0 -7600,0.13804398,0.2880167,,,,,,,,,,,,,, -7700,0.33848923,0.24549073,,,,,,,,,,,,,, -7800,0.10840711,0.3021476,,,,,,,,,,,,,, -7890,,,0.7442601067679269,0.269193274634225,0.7217778520900746,0.2877252049572225,3554.0,0.7390526689254049,0.2891253588819464,3581.0,1873.3463337421413,1973.870161294937,1873.3463337421413,96.33357906341551,3.897242546081543,0.0 -7900,0.1514641,0.29383364,,,,,,,,,,,,,, -8000,0.14040701,0.29476961,,,,,,,,,,,,,, -8100,0.23164801,0.25592136,,,,,,,,,,,,,, -8200,0.0691385,0.2947613,,,,,,,,,,,,,, -8235,,,0.7455815587724958,0.268533672605242,0.72304437451639,0.2871241099927019,3554.0,0.7402122175849972,0.2885202910107163,3581.0,1953.4747865200045,2058.0535831451416,1953.4747865200045,100.34822511672974,3.9240989685058594,0.0 -8300,0.28536132,0.21304236,,,,,,,,,,,,,, -8400,0.14716211,0.2677426,,,,,,,,,,,,,, -8500,0.182616,0.26314166,,,,,,,,,,,,,, -8579,,,0.7400757244655064,0.272737979888916,0.7171354707064224,0.2915414135041502,3554.0,0.7344478126745323,0.2928813134315659,3581.0,2033.6733074188232,2142.3093214035034,2033.6733074188232,104.36516261100768,3.951184272766113,0.0 -8600,0.12161156,0.2883592,,,,,,,,,,,,,, -8700,0.14791957,0.21719186,,,,,,,,,,,,,, -8800,0.1112766,0.3469149,,,,,,,,,,,,,, -8900,0.14025089,0.283208,,,,,,,,,,,,,, -8923,,,0.7457260404314313,0.2683231149400983,0.7229087713667698,0.2872266023428707,3554.0,0.7400754552019339,0.2886678253063041,3581.0,2113.6891655921936,2226.3820304870605,2113.6891655921936,108.38202714920044,3.977944850921631,0.0 -9000,0.5251573,0.23853458,,,,,,,,,,,,,, -9100,0.18496059,0.28307658,,,,,,,,,,,,,, -9200,0.19853781,0.23747656,,,,,,,,,,,,,, -9269,,,0.7454257011413574,0.2683753967285156,0.7227186933956458,0.2870310974979776,3554.0,0.7400321630218515,0.2883772904631213,3581.0,2193.8962202072144,2310.645069360733,2193.8962202072144,112.39886569976808,4.004060506820679,0.0 -9300,0.084926724,0.41591018,,,,,,,,,,,,,, -9400,0.14880385,0.30259448,,,,,,,,,,,,,, -9500,0.09754071,0.2771289,,,,,,,,,,,,,, -9600,0.09783017,0.33210945,,,,,,,,,,,,,, -9611,,,0.7463662964957101,0.2683801651000976,0.7238779148494654,0.2869476163796514,3554.0,0.741056653714919,0.2883160337327213,3581.0,2274.0123538970947,2394.818213224411,2274.0123538970947,116.41702151298524,4.029884099960327,0.0 -9700,0.070929244,0.39976338,,,,,,,,,,,,,, -9800,0.15830635,0.2453332,,,,,,,,,,,,,, -9900,0.08744219,0.2783085,,,,,,,,,,,,,, -9957,,,0.7377515520368304,0.2757884774889265,0.7153397250457231,0.2945271214542593,3554.0,0.7323646066043005,0.2962296397217956,3581.0,2354.080169200897,2478.943204164505,2354.080169200897,120.43198204040527,4.058996915817261,0.0 -10000,0.14505349,0.2978232,,,,,,,,,,,,,, -10100,0.28129223,0.24388598,,,,,,,,,,,,,, -10200,0.15580554,0.36567682,,,,,,,,,,,,,, -10300,0.09325551,0.29484323,,,,,,,,,,,,,, -10304,,,0.7443478448050362,0.2687504972730364,0.722362031008195,0.2872471076823649,3554.0,0.7394697737407497,0.2885957625750489,3581.0,2434.170556306839,2563.088749408722,2434.170556306839,124.446848154068,4.085970640182495,0.0 -10400,0.123183645,0.28145078,,,,,,,,,,,,,, -10500,0.113378696,0.24492523,,,,,,,,,,,,,, -10600,0.09076503,0.24624272,,,,,,,,,,,,,, -10648,,,0.7457329886300224,0.2680527653012957,0.723326915425401,0.2866149799301843,3554.0,0.7404704026022759,0.2880172836018221,3581.0,2514.290768623352,2647.2644975185394,2514.290768623352,128.46352362632751,4.111758708953857,0.0 -10700,0.089762226,0.2505284,,,,,,,,,,,,,, -10800,0.08961556,0.2323552,,,,,,,,,,,,,, -10900,0.40988907,0.24811597,,,,,,,,,,,,,, -10992,,,0.746624265398298,0.2677083015441894,0.7237021253561128,0.2867493637505715,3554.0,0.7408430562342921,0.2880957549392627,3581.0,2594.4230313301086,2731.456456422806,2594.4230313301086,132.48388409614563,4.137834072113037,0.0 -11000,0.17555484,0.2679785,,,,,,,,,,,,,, -11100,0.15985571,0.24492985,,,,,,,,,,,,,, -11200,0.114770524,0.38689592,,,,,,,,,,,,,, -11300,0.13277893,0.28100342,,,,,,,,,,,,,, -11339,,,0.7461288315909249,0.2680037702832903,0.7235591031891883,0.2868198787622661,3554.0,0.7407568127574351,0.2881150148461498,3581.0,2674.5201218128204,2815.614520311356,2674.5201218128204,136.5043747425079,4.165032148361206,0.0 -11400,0.13072965,0.26810703,,,,,,,,,,,,,, -11500,0.10839143,0.36760747,,,,,,,,,,,,,, -11600,0.25572363,0.25693023,,,,,,,,,,,,,, -11685,,,0.7460990633283343,0.2677056789398193,0.7235476998848129,0.2864074535294738,3554.0,0.7407115434541678,0.2877580077579587,3581.0,2754.5981678962708,2899.753466129303,2754.5981678962708,140.52345037460327,4.193560123443604,0.0 -11700,0.21454018,0.24051708,,,,,,,,,,,,,, -11800,0.10206442,0.27249128,,,,,,,,,,,,,, -11900,0.1323308,0.31392595,,,,,,,,,,,,,, -12000,0.13392039,0.33880946,,,,,,,,,,,,,, -12028,,,0.7467365946088519,0.2680112634386335,0.7236964237039252,0.2872677504110685,3554.0,0.7407484952047263,0.2886041142160884,3581.0,2834.6626737117767,2983.877734422684,2834.6626737117767,144.54294848442078,4.220882415771484,0.0 -12100,0.07559484,0.32698077,,,,,,,,,,,,,, -12200,0.11704274,0.26831552,,,,,,,,,,,,,, -12300,0.109622605,0.22917405,,,,,,,,,,,,,, -12375,,,0.7449325834001813,0.2689501898629324,0.7220681554894134,0.2879221523888136,3554.0,0.7391867042420064,0.2893144127644164,3581.0,2914.802984476089,3068.0855479240417,2914.802984476089,148.56185173988342,4.256128311157227,0.0 -12400,0.17500924,0.26424533,,,,,,,,,,,,,, -12500,0.09087027,0.34695873,,,,,,,,,,,,,, -12600,0.14626545,0.27740186,,,,,,,,,,,,,, -12700,0.11834035,0.30554172,,,,,,,,,,,,,, -12722,,,0.7473699024745396,0.2685403653553554,0.7248627207020258,0.2871186659452817,3554.0,0.7420470560772131,0.2884791463954726,3581.0,2994.950837135315,3152.291583776474,2994.950837135315,152.57717752456665,4.285537481307983,0.0 -12800,0.13033839,0.2621025,,,,,,,,,,,,,, -12900,0.14879961,0.30432782,,,,,,,,,,,,,, -13000,0.10984103,0.31059697,,,,,,,,,,,,,, -13065,,,0.7461308070591518,0.267665011542184,0.7229474464291995,0.2867349550572418,3554.0,0.7402109222284278,0.2880311575524469,3581.0,3075.084317445755,3236.481616020202,3075.084317445755,156.59312415122986,4.313032388687134,0.0 -13100,0.19267423,0.255424,,,,,,,,,,,,,, -13200,0.15753277,0.23979971,,,,,,,,,,,,,, -13300,0.084277906,0.24138825,,,,,,,,,,,,,, -13400,0.09952215,0.29439342,,,,,,,,,,,,,, -13410,,,0.7395822661263602,0.2715660844530378,0.7169801522052617,0.2904655186057963,3554.0,0.7346096640690449,0.2915980241312657,3581.0,3155.114428043365,3320.5721111297607,3155.114428043365,160.61153936386108,4.341656446456909,0.0 -13500,0.1212753,0.24884355,,,,,,,,,,,,,, -13600,0.14960897,0.20554344,,,,,,,,,,,,,, -13700,0.12675664,0.28128356,,,,,,,,,,,,,, -13754,,,0.7470260347638812,0.2671864543642316,0.7241098965294387,0.2861750425082213,3554.0,0.7413727888945127,0.2874327709940484,3581.0,3235.2807846069336,3404.797929763794,3235.2807846069336,164.62888169288635,4.370457410812378,0.0 -13800,0.18439136,0.32870302,,,,,,,,,,,,,, -13900,0.26052323,0.25286514,,,,,,,,,,,,,, -14000,0.12604025,0.24591756,,,,,,,,,,,,,, -14098,,,0.7473585265023368,0.2675982883998326,0.7245061957037845,0.2866321879286543,3554.0,0.7416673802490575,0.2879964215433887,3581.0,3315.593653202057,3489.1684036254883,3315.593653202057,168.64652037620544,4.3972272872924805,0.0 -14100,0.19760303,0.31900117,,,,,,,,,,,,,, -14200,0.25387195,0.24653696,,,,,,,,,,,,,, -14300,0.07951966,0.2752148,,,,,,,,,,,,,, -14400,0.18295793,0.28915185,,,,,,,,,,,,,, -14442,,,0.7471727643694196,0.267539586339678,0.7243391991198298,0.2865762018258476,3554.0,0.7415121419907149,0.2878864525883133,3581.0,3395.570848941803,3573.20453453064,3395.570848941803,172.66461992263794,4.424889802932739,0.0 -14500,0.16085324,0.33241272,,,,,,,,,,,,,, -14600,0.19872177,0.22065315,,,,,,,,,,,,,, -14700,0.1489996,0.2706986,,,,,,,,,,,,,, -14787,,,0.748114994594029,0.2672019004821777,0.725114967290377,0.2862941761263717,3554.0,0.7423192854867705,0.2875767601032358,3581.0,3475.565026283264,3657.257560491562,3475.565026283264,176.68185997009277,4.453326225280762,0.0 -14800,0.093538836,0.28547606,,,,,,,,,,,,,, -14900,0.119417444,0.27919266,,,,,,,,,,,,,, -15000,0.37102175,0.238074,,,,,,,,,,,,,, -15100,0.09521537,0.1936788,,,,,,,,,,,,,, -15130,,,0.7476911544799805,0.2670585598264421,0.7246122601733962,0.2861452633970965,3554.0,0.7418366628996789,0.2874424179916574,3581.0,3555.575869321823,3741.327105522156,3555.575869321823,180.69870519638064,4.48215651512146,0.0 -15200,0.08298924,0.19442992,,,,,,,,,,,,,, -15300,0.18687364,0.2270965,,,,,,,,,,,,,, -15400,0.092241235,0.29725486,,,,,,,,,,,,,, -15475,,,0.7482358387538365,0.2666790655681065,0.7250108262696962,0.2860073761518711,3554.0,0.7421833412236456,0.2873672873106325,3581.0,3635.66794872284,3825.483948945999,3635.66794872284,184.71685791015625,4.515713930130005,0.0 -15500,0.11073464,0.26247346,,,,,,,,,,,,,, -15600,0.08752601,0.35879314,,,,,,,,,,,,,, -15700,0.29408836,0.2549118,,,,,,,,,,,,,, -15800,0.08533525,0.2663743,,,,,,,,,,,,,, -15821,,,0.748002120426723,0.2668221167155674,0.7250834364668332,0.2859135736691932,3554.0,0.7422936510620287,0.2871656207457763,3581.0,3715.785654783249,3909.665248632431,3715.785654783249,188.73474383354187,4.548308849334717,0.0 -15900,0.11665895,0.18971862,,,,,,,,,,,,,, -16000,0.24344957,0.26185948,,,,,,,,,,,,,, -16100,0.12438091,0.24246542,,,,,,,,,,,,,, -16165,,,0.7476752144949776,0.266790543283735,0.7244753518262873,0.2858654187513189,3554.0,0.7417378067404357,0.2871942208552953,3581.0,3795.99942445755,3993.93709754944,3795.99942445755,192.75176310539248,4.576169013977051,0.0 -16200,0.16730307,0.21328117,,,,,,,,,,,,,, -16300,0.18969402,0.24171275,,,,,,,,,,,,,, -16400,0.084531926,0.2876044,,,,,,,,,,,,,, -16500,0.11668561,0.33346942,,,,,,,,,,,,,, -16511,,,0.7478722163609096,0.2664880411965506,0.7244973340997819,0.285872906463228,3554.0,0.7417247168214186,0.2872465123547019,3581.0,3876.117094039917,4078.1082940101614,3876.117094039917,196.76408624649048,4.604237794876099,0.0 -16600,0.1401572,0.32467133,,,,,,,,,,,,,, -16700,0.12911168,0.23308522,,,,,,,,,,,,,, -16800,0.13499685,0.22607854,,,,,,,,,,,,,, -16856,,,0.7484268460954938,0.2666984285627092,0.7252970079927546,0.2859993045358223,3554.0,0.7425804702771572,0.28726229525185,3581.0,3956.22412109375,4162.271712303162,3956.22412109375,200.78008460998527,4.63153338432312,0.0 -16900,0.08846881,0.43335855,,,,,,,,,,,,,, -17000,0.24283494,0.2247242,,,,,,,,,,,,,, -17100,0.14875358,0.26953188,,,,,,,,,,,,,, -17200,0.14669095,0.2724393,,,,,,,,,,,,,, -17201,,,0.7476682662963867,0.2666475432259695,0.72458849184018,0.2857119550044844,3554.0,0.7419232472598436,0.2869837254127688,3581.0,4036.373031377792,4246.418367147446,4036.373031377792,204.73551487922668,4.660937786102295,0.0 -17300,0.123501636,0.30766714,,,,,,,,,,,,,, -17400,0.14643772,0.29647914,,,,,,,,,,,,,, -17500,0.201715,0.28382167,,,,,,,,,,,,,, -17545,,,0.7476886340550014,0.2663342441831316,0.7242569716780388,0.2857955048174152,3554.0,0.7414532373551382,0.2871294189384948,3581.0,4116.48851108551,4330.587821722031,4116.48851108551,208.7484924793244,4.688414573669434,0.0 -17600,0.10365252,0.27514032,,,,,,,,,,,,,, -17700,0.08348998,0.2164603,,,,,,,,,,,,,, -17800,0.19359067,0.30567977,,,,,,,,,,,,,, -17892,,,0.7475626809256417,0.2668954474585397,0.7244492478765123,0.2859883305727261,3554.0,0.7417253985880341,0.2873294151751431,3581.0,4196.598222017288,4414.756689548492,4196.598222017288,212.76672649383545,4.716022253036499,0.0 -17900,0.11259947,0.3395682,,,,,,,,,,,,,, -18000,0.10038499,0.27532262,,,,,,,,,,,,,, -18100,0.19289564,0.2289349,,,,,,,,,,,,,, -18200,0.07519864,0.23634844,,,,,,,,,,,,,, -18239,,,0.7490292276654925,0.2666161571230207,0.7260106075460748,0.2858516626567336,3554.0,0.7431509044043214,0.2871809945829552,3581.0,4276.7827315330505,4499.001486063004,4276.7827315330505,216.7855279445648,4.743778467178345,0.0 -18300,0.13920072,0.30987296,,,,,,,,,,,,,, -18400,0.06823586,0.3327068,,,,,,,,,,,,,, -18500,0.19748655,0.25466946,,,,,,,,,,,,,, -18583,,,0.7473668370928083,0.2666305303573608,0.7240039007544317,0.2860666939430044,3554.0,0.7413240425815065,0.2873580834613236,3581.0,4356.798830032349,4583.0732300281525,4356.798830032349,220.7995901107788,4.771770715713501,0.0 -18600,0.09347913,0.29343998,,,,,,,,,,,,,, -18700,0.17245083,0.27833822,,,,,,,,,,,,,, -18800,0.084064424,0.28426874,,,,,,,,,,,,,, -18900,0.11915558,0.2431444,,,,,,,,,,,,,, -18929,,,0.7481992585318429,0.2665919235774449,0.7250225043524902,0.2858412897714283,3554.0,0.742283833622766,0.2871458154255969,3581.0,4436.8140461444855,4667.147452354431,4436.8140461444855,224.81593370437625,4.800743103027344,0.0 -19000,0.10755725,0.2724963,,,,,,,,,,,,,, -19100,0.13140503,0.30315676,,,,,,,,,,,,,, -19200,0.09521807,0.22490534,,,,,,,,,,,,,, -19275,,,0.7471611840384347,0.2664503370012556,0.7239721638470737,0.2856836871746623,3554.0,0.7412129827998464,0.2869741125034906,3581.0,4517.03516125679,4751.42479133606,4517.03516125679,228.83026695251465,4.829066753387451,0.0 -19300,0.19026318,0.23728767,,,,,,,,,,,,,, -19400,0.2034137,0.22144759,,,,,,,,,,,,,, -19500,0.27137625,0.3249975,,,,,,,,,,,,,, -19600,0.10448537,0.29275656,,,,,,,,,,,,,, -19618,,,0.7494119235447475,0.2661271776471819,0.7260501756383653,0.2857512483183561,3554.0,0.7430842958059899,0.2870630830468095,3581.0,4597.184770107269,4835.632307052612,4597.184770107269,232.84253406524653,4.8612940311431885,0.0 -19700,0.1699592,0.33331177,,,,,,,,,,,,,, -19800,0.13788079,0.24832408,,,,,,,,,,,,,, -19900,0.09130275,0.23089491,,,,,,,,,,,,,, -19964,,,0.7486386980329242,0.2662330695561,0.7253936613015265,0.2856385032984401,3554.0,0.7426314664199944,0.2869520573534801,3581.0,4677.34040427208,4919.846963167191,4677.34040427208,236.85779643058777,4.890937805175781,0.0 -20000,0.15702832,0.2341626,,,,,,,,,,,,,, -20100,0.07084263,0.35234538,,,,,,,,,,,,,, -20200,0.11854637,0.2803697,,,,,,,,,,,,,, -20300,0.1057396,0.2629912,,,,,,,,,,,,,, -20308,,,0.7491888999938965,0.2662207569394793,0.7260768978395822,0.2855256895839195,3554.0,0.7432166948827144,0.2868393954202736,3581.0,4757.374995470047,5003.937238454819,4757.374995470047,240.87046551704407,4.920657634735107,0.0 -20400,0.18292537,0.28084826,,,,,,,,,,,,,, -20500,0.1865403,0.23431146,,,,,,,,,,,,,, -20600,0.18280943,0.28974634,,,,,,,,,,,,,, -20650,,,0.7485660825456891,0.2663572515760149,0.7249621217949845,0.2859770990048624,3554.0,0.7420767811016475,0.2873131209530333,3581.0,4837.548028469086,5088.170040607452,4837.548028469086,244.8877301216125,4.949784517288208,0.0 -20700,0.15731221,0.26291028,,,,,,,,,,,,,, -20800,0.14059184,0.25905818,,,,,,,,,,,,,, -20900,0.30088538,0.3348695,,,,,,,,,,,,,, -20995,,,0.7492354256766183,0.2659975971494402,0.7259607352630838,0.2854717814929041,3554.0,0.743107612224239,0.2868216013116099,3581.0,4917.593822479248,5172.273324012756,4917.593822479248,248.90308475494385,4.978290319442749,0.0 -21000,0.15312643,0.30845433,,,,,,,,,,,,,, -21100,0.15275267,0.23327887,,,,,,,,,,,,,, -21200,0.15215805,0.28583348,,,,,,,,,,,,,, -21300,0.083545804,0.31121078,,,,,,,,,,,,,, -21341,,,0.748748915536063,0.2659299033028738,0.7255290583673326,0.285351445719172,3554.0,0.7427538435274714,0.2865837329394722,3581.0,4997.784978628159,5256.526384115219,4997.784978628159,252.9222385883332,5.007358312606812,0.0 -21400,0.21381423,0.2622373,,,,,,,,,,,,,, -21500,0.18736045,0.21650672,,,,,,,,,,,,,, -21600,0.14358751,0.24855618,,,,,,,,,,,,,, -21680,,,0.7451433454241071,0.2684741360800607,0.7219432686981219,0.2879481876439839,3554.0,0.7392589715032463,0.2892507016742006,3581.0,5077.8388476371765,5340.641524076462,5077.8388476371765,256.93958377838135,5.038073301315308,0.0 -21700,0.16453664,0.3287571,,,,,,,,,,,,,, -21800,0.37207708,0.26650444,,,,,,,,,,,,,, -21900,0.13086493,0.22957718,,,,,,,,,,,,,, -22000,0.10846878,0.23935103,,,,,,,,,,,,,, -22026,,,0.7494405337742397,0.2656904969896589,0.7258220408562536,0.2854201746711451,3554.0,0.7429517603759425,0.2867683894272724,3581.0,5157.941499471664,5424.806901931763,5157.941499471664,260.9596357345581,5.067148923873901,0.0 -22100,0.3061044,0.30239996,,,,,,,,,,,,,, -22200,0.17235024,0.25820863,,,,,,,,,,,,,, -22300,0.15542224,0.32834768,,,,,,,,,,,,,, -22371,,,0.74909394127982,0.2657301425933838,0.7256988714300788,0.2852048170855022,3554.0,0.7429030140629364,0.2864998074691078,3581.0,5238.024378061295,5508.9515812397,5238.024378061295,264.9744710922241,5.100812196731567,0.0 -22400,0.11315298,0.2258792,,,,,,,,,,,,,, -22500,0.1330908,0.22456542,,,,,,,,,,,,,, -22600,0.14592549,0.21614656,,,,,,,,,,,,,, -22700,0.28282169,0.2970761,,,,,,,,,,,,,, -22713,,,0.7492767742701939,0.2658064535685948,0.7258655245410102,0.2852625892480304,3554.0,0.7430842276293284,0.2865557805082379,3581.0,5318.00970864296,5592.997215270996,5318.00970864296,268.993115901947,5.129384994506836,0.0 -22800,0.13730675,0.2642671,,,,,,,,,,,,,, -22900,0.16763069,0.2057944,,,,,,,,,,,,,, -23000,0.2058671,0.25728583,,,,,,,,,,,,,, -23058,,,0.7493971415928432,0.2653361899512155,0.7256113545037282,0.285262760984542,3554.0,0.7427997945973541,0.2866300248926626,3581.0,5398.11927318573,5677.161759853363,5398.11927318573,273.005410194397,5.15982985496521,0.0 -23100,0.15437987,0.31605148,,,,,,,,,,,,,, -23200,0.13232468,0.272712,,,,,,,,,,,,,, -23300,0.1611217,0.2626007,,,,,,,,,,,,,, -23400,0.10832453,0.30029392,,,,,,,,,,,,,, -23402,,,0.7490113803318569,0.2655630792890276,0.725430275525816,0.2852217503055536,3554.0,0.7426350116063949,0.2865171584294715,3581.0,5478.105720996857,5761.20668721199,5478.105720996857,277.02224564552307,5.189005136489868,0.0 -23500,0.1013463,0.27118585,,,,,,,,,,,,,, -23600,0.1990454,0.25663486,,,,,,,,,,,,,, -23700,0.09504689,0.2848487,,,,,,,,,,,,,, -23744,,,0.7494708469935826,0.2656279632023403,0.7260391845016179,0.2851841915304498,3554.0,0.7431993780106814,0.2865147040696558,3581.0,5558.272310495377,5845.430538654327,5558.272310495377,281.0366368293762,5.219642877578735,0.0 -23800,0.1376446,0.29152545,,,,,,,,,,,,,, -23900,0.12975901,0.22816493,,,,,,,,,,,,,, -24000,0.13004622,0.18882889,,,,,,,,,,,,,, -24088,,,0.7504288128444127,0.2651689904076712,0.7265264353325478,0.2852026016845016,3554.0,0.7436515256300614,0.2865926640821348,3581.0,5638.279019832611,5929.496938467026,5638.279019832611,285.0529205799103,5.249758005142212,0.0 -24100,0.07040644,0.3493752,,,,,,,,,,,,,, -24200,0.20849217,0.27421808,,,,,,,,,,,,,, -24300,0.2559368,0.27024958,,,,,,,,,,,,,, -24400,0.21072657,0.3298409,,,,,,,,,,,,,, -24433,,,0.7496530669076102,0.2654334136417934,0.7260332767656162,0.2851277073917593,3554.0,0.7432730088051522,0.2864014285464954,3581.0,5718.319184303284,6013.598551034927,5718.319184303284,289.0705211162567,5.280286550521851,0.0 -24500,0.11675973,0.3064406,,,,,,,,,,,,,, -24600,0.16066526,0.17353615,,,,,,,,,,,,,, -24700,0.15109769,0.23169139,,,,,,,,,,,,,, -24780,,,0.7496503421238491,0.2656197888510568,0.7260538851470174,0.2852865808387116,3554.0,0.7432156722327912,0.2866211278383307,3581.0,5798.420118570328,6097.761975288391,5798.420118570328,293.0866525173187,5.3132641315460205,0.0 -24800,0.14524817,0.31941295,,,,,,,,,,,,,, -24900,0.110433,0.31543,,,,,,,,,,,,,, -25000,0.103302926,0.25642765,,,,,,,,,,,,,, -25100,0.09995158,0.23636618,,,,,,,,,,,,,, -25123,,,0.7507190023149762,0.2650500876562936,0.7268040302300225,0.2852178862340409,3554.0,0.7439480259311295,0.2865872781258726,3581.0,5878.403662443161,6181.804616928101,5878.403662443161,297.1022679805756,5.344352006912232,0.0 -25200,0.08105088,0.23393995,,,,,,,,,,,,,, -25300,0.12408564,0.29326558,,,,,,,,,,,,,, -25400,0.15938503,0.23107599,,,,,,,,,,,,,, -25467,,,0.7501741136823382,0.2652744735990252,0.7264499782375492,0.2851033208071011,3554.0,0.7436492076235688,0.286416836472005,3581.0,5958.376376628876,6265.838335990906,5958.376376628876,301.12024998664856,5.3743064403533936,0.0 -25500,0.082841985,0.25498593,,,,,,,,,,,,,, -25600,0.093413405,0.22227412,,,,,,,,,,,,,, -25700,0.21580213,0.2698674,,,,,,,,,,,,,, -25800,0.080836035,0.21667778,,,,,,,,,,,,,, -25810,,,0.7493463924952916,0.2654041562761579,0.7256698823069077,0.2852030138521296,3554.0,0.7428756070449944,0.2864697074730347,3581.0,6038.365174293518,6349.887370347977,6038.365174293518,305.1376292705536,5.405126094818115,0.0 -25900,0.1613563,0.2815447,,,,,,,,,,,,,, -26000,0.09576715,0.36036345,,,,,,,,,,,,,, -26100,0.11411086,0.29192585,,,,,,,,,,,,,, -26151,,,0.7501864433288574,0.2650456428527832,0.7261618730655599,0.2852682737265669,3554.0,0.7432468289671181,0.2866675561448443,3581.0,6118.489009618759,6434.069468021393,6118.489009618759,309.15387773513794,5.434753894805908,0.0 -26200,0.11940155,0.28910014,,,,,,,,,,,,,, -26300,0.15473871,0.24399365,,,,,,,,,,,,,, -26400,0.15317093,0.3288021,,,,,,,,,,,,,, -26495,,,0.7501932552882603,0.2651315076010568,0.726330243541608,0.2851360881335291,3554.0,0.7435073319908894,0.2864623443935877,3581.0,6198.512864589691,6518.152428388596,6198.512864589691,313.1702241897583,5.465415954589844,0.0 -26500,0.11288524,0.28385913,,,,,,,,,,,,,, -26600,0.09267373,0.31418613,,,,,,,,,,,,,, -26700,0.104262985,0.24727166,,,,,,,,,,,,,, -26800,0.10593263,0.22202851,,,,,,,,,,,,,, -26840,,,0.7490963935852051,0.2652441433497837,0.7254626993792206,0.2850785735757685,3554.0,0.7426230806906241,0.2863978151834334,3581.0,6278.578718662262,6602.279229164124,6278.578718662262,317.18662691116333,5.497414350509644,0.0 -26900,0.109907664,0.24921086,,,,,,,,,,,,,, -27000,0.14766471,0.25381517,,,,,,,,,,,,,, -27100,0.404792,0.28900963,,,,,,,,,,,,,, -27183,,,0.7498324257986886,0.2650856801441738,0.7258750043964547,0.2851935168230339,3554.0,0.7430782962597738,0.2865224762090722,3581.0,6358.797013282776,6686.558983802795,6358.797013282776,321.2055685520172,5.527190208435059,0.0 -27200,0.08792508,0.26595396,,,,,,,,,,,,,, -27300,0.1610609,0.2492893,,,,,,,,,,,,,, -27400,0.067201756,0.2477759,,,,,,,,,,,,,, -27500,0.16949695,0.36445582,,,,,,,,,,,,,, -27528,,,0.7499326297215053,0.26494300365448,0.7259978990442107,0.285010428527935,3554.0,0.7431881970381876,0.2863713285504223,3581.0,6438.806653022766,6770.626757621765,6438.806653022766,325.2188792228699,5.559579610824585,0.0 -27600,0.18441667,0.2832448,,,,,,,,,,,,,, -27700,0.17373124,0.3052263,,,,,,,,,,,,,, -27800,0.10709118,0.26672044,,,,,,,,,,,,,, -27873,,,0.7503581728254046,0.2650255646024431,0.726517230255522,0.2850092435460045,3554.0,0.7436978175832519,0.2863230253857163,3581.0,6519.011382341385,6854.8912081718445,6519.011382341385,329.2352783679962,5.590218782424927,0.0 -27900,0.083687596,0.25790936,,,,,,,,,,,,,, -28000,0.14402924,0.17639908,,,,,,,,,,,,,, -28100,0.17883047,0.23839472,,,,,,,,,,,,,, -28200,0.13543268,0.30966938,,,,,,,,,,,,,, -28216,,,0.7501605578831264,0.2652479580470493,0.7263499588931486,0.2852860828028278,3554.0,0.7434523816016825,0.2867345056264835,3581.0,6599.230717420578,6939.170467376709,6599.230717420578,333.2514762878418,5.620994567871094,0.0 -28300,0.16291405,0.3088529,,,,,,,,,,,,,, -28400,0.14918453,0.32171324,,,,,,,,,,,,,, -28500,0.12826182,0.2871711,,,,,,,,,,,,,, -28561,,,0.7503539494105748,0.2646962233952113,0.7262106462348762,0.2849880512604635,3554.0,0.7433470486595923,0.2863785893648771,3581.0,6679.461097002029,7023.457095623016,6679.461097002029,337.26547598838806,5.651025295257568,0.0 -28600,0.10777911,0.31667405,,,,,,,,,,,,,, -28700,0.08460814,0.2914191,,,,,,,,,,,,,, -28800,0.13348019,0.23001474,,,,,,,,,,,,,, -28900,0.07510707,0.24915564,,,,,,,,,,,,,, -28905,,,0.7502357619149345,0.2648711545126779,0.7262302241972074,0.2849589934426878,3554.0,0.743408953068277,0.2863066288986142,3581.0,6759.423446655273,7107.479385375977,6759.423446655273,341.2815098762512,5.682204484939575,0.0 -29000,0.1620421,0.2899046,,,,,,,,,,,,,, -29100,0.08179053,0.23747237,,,,,,,,,,,,,, -29200,0.09272558,0.24595529,,,,,,,,,,,,,, -29248,,,0.750530515398298,0.264979498726981,0.7264658466912282,0.2850998860768676,3554.0,0.7435973933607931,0.2865055343086777,3581.0,6839.43584895134,7191.550079584122,6839.43584895134,345.2969605922699,5.712467908859253,0.0 -29300,0.079616904,0.27623957,,,,,,,,,,,,,, -29400,0.12664264,0.19680484,,,,,,,,,,,,,, -29500,0.081696086,0.2238282,,,,,,,,,,,,,, -29595,,,0.7509183202471051,0.2644538879394531,0.7266162191808525,0.2848655859539867,3554.0,0.7437091349090686,0.2862506899478148,3581.0,6919.53741979599,7275.713355302811,6919.53741979599,349.3138077259064,5.7450034618377686,0.0 -29600,0.16162883,0.261011,,,,,,,,,,,,,, -29700,0.10456897,0.30272752,,,,,,,,,,,,,, -29800,0.13022062,0.2563867,,,,,,,,,,,,,, -29900,0.12656917,0.23921159,,,,,,,,,,,,,, -29941,,,0.7506739071437291,0.2646848814828055,0.7265805666810284,0.2849207305478862,3554.0,0.7437569949254748,0.2862475197330529,3581.0,6999.530682325363,7359.7621948719025,6999.530682325363,353.32518696784973,5.7767322063446045,0.0 -30000,0.111619055,0.29443616,,,,,,,,,,,,,, -30100,0.044156197,0.24897046,,,,,,,,,,,,,, -30200,0.131106,0.25572473,,,,,,,,,,,,,, -30283,,,0.7503796986171177,0.2647715636662074,0.7263325104635622,0.2849135519616981,3554.0,0.7434573584979755,0.286300186204098,3581.0,7079.652544736862,7443.945029020309,7079.652544736862,357.34046387672424,5.810181140899658,0.0 -30300,0.1057769,0.17941053,,,,,,,,,,,,,, -30400,0.09397489,0.25353435,,,,,,,,,,,,,, -30500,0.12796953,0.21646914,,,,,,,,,,,,,, -30600,0.07158425,0.34375098,,,,,,,,,,,,,, -30631,,,0.7510991777692523,0.2642929894583566,0.726567720789955,0.2848544746016812,3554.0,0.7437173161084544,0.2862401225652751,3581.0,7159.747346639633,7528.094771146774,7159.747346639633,361.350955247879,5.84227180480957,0.0 -30700,0.05460628,0.28550377,,,,,,,,,,,,,, -30800,0.081544004,0.32303467,,,,,,,,,,,,,, -30900,0.075583234,0.2477528,,,,,,,,,,,,,, -30978,,,0.7507302420479911,0.2645713601793562,0.7265148259443585,0.2848851810899691,3554.0,0.7437042943660989,0.2862286007094736,3581.0,7239.900948047638,7612.307431459427,7239.900948047638,365.3675897121429,5.872438192367554,0.0 -31000,0.0845619,0.29020378,,,,,,,,,,,,,, -31100,0.08809458,0.20301482,,,,,,,,,,,,,, -31200,0.10524256,0.19028883,,,,,,,,,,,,,, -31300,0.06680224,0.30544177,,,,,,,,,,,,,, -31320,,,0.7506299700055804,0.2646412168230329,0.7264860429050014,0.2848772640367807,3554.0,0.7436085743332868,0.2862563145223925,3581.0,7319.654996156692,7696.420757055283,7319.654996156692,369.3829791545868,6.204082012176514,0.0 -31400,0.073042825,0.18936425,,,,,,,,,,,,,, -31500,0.14821662,0.3262445,,,,,,,,,,,,,, -31600,0.07054408,0.22440273,,,,,,,,,,,,,, -31664,,,0.7513968603951591,0.2642252445220947,0.72686193978176,0.2848980613283448,3554.0,0.7439497985243297,0.2863126284448303,3581.0,7399.781029224396,7780.608487844467,7399.781029224396,373.4013526439667,6.235177278518677,0.0 -31700,0.08837146,0.2300573,,,,,,,,,,,,,, -31800,0.06175245,0.24420348,,,,,,,,,,,,,, -31900,0.08571076,0.2980411,,,,,,,,,,,,,, -32000,0.077603966,0.3038233,,,,,,,,,,,,,, -32010,,,0.7509576252528599,0.2644398893628801,0.7266498108425365,0.2848405639442353,3554.0,0.7437827657035395,0.2862076704743787,3581.0,7479.888315439224,7864.781018733978,7479.888315439224,377.4197075366974,6.269732475280762,0.0 -32100,0.07837761,0.3111272,,,,,,,,,,,,,, -32200,0.053709254,0.21944451,,,,,,,,,,,,,, -32300,0.062049743,0.27650264,,,,,,,,,,,,,, -32355,,,0.7508602823529925,0.2645163365772792,0.7266644427933314,0.2848704976182206,3554.0,0.7437727437342921,0.2862447585782602,3581.0,7559.8783304691315,7948.833317041397,7559.8783304691315,381.4381649494171,6.301154613494873,0.0 -32400,0.09993525,0.26049697,,,,,,,,,,,,,, -32500,0.043521516,0.30552223,,,,,,,,,,,,,, -32600,0.089983486,0.2685666,,,,,,,,,,,,,, -32698,,,0.751176289149693,0.2641677345548357,0.7266272790122046,0.2848085866057611,3554.0,0.7437449276563809,0.2862259418196732,3581.0,7640.045894861221,8033.064491033554,7640.045894861221,385.4569594860077,6.333528757095337,0.0 -32700,0.09139706,0.2602842,,,,,,,,,,,,,, -32800,0.07749157,0.2307206,,,,,,,,,,,,,, -32900,0.123239994,0.24610426,,,,,,,,,,,,,, -33000,0.08081065,0.32204825,,,,,,,,,,,,,, -33041,,,0.7510192053658622,0.2642357519694737,0.726567720789955,0.284751741820396,3554.0,0.7437026581262217,0.286153435940118,3581.0,7720.015320777893,8117.097451686859,7720.015320777893,389.47060203552246,6.371289491653442,0.0 -33100,0.0952153,0.2556165,,,,,,,,,,,,,, -33200,0.046902806,0.32935905,,,,,,,,,,,,,, -33300,0.077561975,0.1970114,,,,,,,,,,,,,, -33384,,,0.7508604867117745,0.2643541949135916,0.7265020487478897,0.284828542388418,3554.0,0.7436151192927953,0.2861949214386693,3581.0,7799.988336086273,8201.129628896713,7799.988336086273,393.484384059906,6.404409885406494,0.0 -33400,0.09229081,0.26632744,,,,,,,,,,,,,, -33500,0.06015227,0.25271317,,,,,,,,,,,,,, -33600,0.08394826,0.28493252,,,,,,,,,,,,,, -33700,0.063864745,0.24514104,,,,,,,,,,,,,, -33725,,,0.7511130741664341,0.264176675251552,0.7266970040359454,0.2847538198321873,3554.0,0.7437837883534627,0.2861725594936819,3581.0,7880.060120820999,8285.262045621872,7880.060120820999,397.50036454200745,6.436081647872925,0.0 -33800,0.09230712,0.32656774,,,,,,,,,,,,,, -33900,0.04902444,0.21343969,,,,,,,,,,,,,, -34000,0.07322485,0.26201215,,,,,,,,,,,,,, -34069,,,0.7512414796011788,0.2640826020921979,0.7267279166080473,0.2846950344242402,3554.0,0.7438623278675649,0.2860788506723855,3581.0,7960.137799978256,8369.402106523514,7960.137799978256,401.5164318084717,6.469133138656616,0.0 -34100,0.06506333,0.22267954,,,,,,,,,,,,,, -34200,0.07454886,0.23097016,,,,,,,,,,,,,, -34300,0.051072884,0.24954137,,,,,,,,,,,,,, -34400,0.08150893,0.18609913,,,,,,,,,,,,,, -34415,,,0.7512760162353516,0.2641456127166748,0.7268698396612971,0.2846884912631454,3554.0,0.7440084304532603,0.2860408762719038,3581.0,8040.238317012787,8453.566410779953,8040.238317012787,405.53353238105774,6.502725839614868,0.0 -34500,0.080460414,0.2479426,,,,,,,,,,,,,, -34600,0.057493594,0.22366446,,,,,,,,,,,,,, -34700,0.11330064,0.39583412,,,,,,,,,,,,,, -34758,,,0.7510401862008231,0.2641614505222865,0.726607563660664,0.2847078803153137,3554.0,0.7437508590259355,0.2861053373053965,3581.0,8120.395667791367,8537.790835618973,8120.395667791367,409.5499720573425,6.540175437927246,0.0 -34800,0.086998865,0.33987468,,,,,,,,,,,,,, -34900,0.08550857,0.24511394,,,,,,,,,,,,,, -35000,0.044556472,0.32412052,,,,,,,,,,,,,, -35100,0.073136196,0.27984676,,,,,,,,,,,,,, -35102,,,0.7514004707336426,0.2640484060559954,0.726867778823157,0.2846869971554938,3554.0,0.7439925452911198,0.2860836912153554,3581.0,8200.408621549606,8621.87518954277,8200.408621549606,413.56875348091125,6.579923152923584,0.0 -35200,0.05799222,0.2532949,,,,,,,,,,,,,, -35300,0.06391906,0.30224726,,,,,,,,,,,,,, -35400,0.07287789,0.308821,,,,,,,,,,,,,, -35446,,,0.7513203620910645,0.2640537875039236,0.7268525973155248,0.2846417445846669,3554.0,0.7439716832326864,0.2860260478480173,3581.0,8280.389010429382,8705.920965194702,8280.389010429382,417.5888998508453,6.61214542388916,0.0 -35500,0.07365207,0.26291335,,,,,,,,,,,,,, -35600,0.052120764,0.19465722,,,,,,,,,,,,,, -35700,0.060303897,0.3094798,,,,,,,,,,,,,, -35790,,,0.7514446122305733,0.2640480313982282,0.7269441672235509,0.2846638985946733,3554.0,0.7440648125523597,0.2860469099064507,3581.0,8360.353709697723,8789.949140787125,8360.353709697723,421.6066160202026,6.645025968551636,0.0 -35800,0.0748643,0.22916208,,,,,,,,,,,,,, -35900,0.07645566,0.28973126,,,,,,,,,,,,,, -36000,0.085378155,0.22342202,,,,,,,,,,,,,, -36100,0.04352588,0.39614153,,,,,,,,,,,,,, -36134,,,0.7513912064688546,0.2640267610549927,0.7269162772140546,0.2846340679625949,3554.0,0.7440346102912944,0.2860179007369624,3581.0,8440.450272083282,8874.111963510513,8440.450272083282,425.6265361309052,6.677982568740845,0.0 -36189,,,0.7513892991202218,0.2640261820384434,0.7269135294298678,0.2846337244895716,3554.0,0.7440320877548171,0.2860175257653239,3581.0,8451.264714956284,8888.98148393631,8451.264714956284,429.6439287662506,6.712190389633179,0.0 -36189,,,,,,,,,,,8451.264714956284,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/eval_measurements.csv deleted file mode 100644 index d0bdd8d43..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.95932412147522,0.0,30.383830070495605,1,0,30.383830070495605,0.8949576268413153,3581,0.2836721463431129,34.343270778656006,0.8928326879228864,0.2656774180276053,0.8956671152882316,3554,0.2613217295323051 -7.974936962127685,0.01961350440979,110.5013530254364,344,0,110.5013530254364,0.3160043485801801,3581,0.7118314323862049,118.50828313827516,0.2980813298906599,0.714498588017055,0.3143779387551878,3554,0.6943832680826182 -11.9900643825531,0.0437884330749511,190.65867400169373,686,0,190.65867400169373,0.3079342089761938,3581,0.7214672490226194,202.7176303863525,0.2890401227133615,0.7243420055934361,0.3066701636800084,3554,0.7042416994935284 -16.00831699371338,0.0684285163879394,270.77397561073303,1028,0,270.77397561073303,0.3001406620881038,3581,0.7258220332789375,286.888560295105,0.2809634378978184,0.7297190938677106,0.2987067072862444,3554,0.7084653185671075 -20.02594399452209,0.0922152996063232,350.7723970413208,1375,0,350.7723970413208,0.2959859081567648,3581,0.728854599361212,370.9412474632263,0.2769715615681239,0.7327919006347656,0.2944409097253974,3554,0.7112662036833497 -24.04477381706237,0.1164469718933105,430.8206250667572,1721,0,430.8206250667572,0.2958660876740959,3581,0.7323106788650168,455.0456523895264,0.27680504322052,0.7358476775033134,0.2940751109555254,3554,0.7154141213025816 -28.061689853668213,0.1398513317108154,510.9170143604279,2066,0,510.9170143604279,0.2966501874585486,3581,0.7263692191645141,539.1955320835114,0.2779006106512887,0.7291966165815081,0.2951964816821715,3554,0.7088859356315067 -32.079628467559814,0.1638176441192627,591.0181603431702,2413,0,591.0181603431702,0.2949290335494624,3581,0.7297532359370636,623.3513972759247,0.275118316922869,0.7348235675266811,0.2932508787413829,3554,0.7122574668287492 -36.097635984420776,0.1916062831878662,671.2152290344238,2763,0,671.2152290344238,0.2950347414631911,3581,0.7313912484073932,707.6071050167084,0.2760927677154541,0.7348286764962333,0.2936104263022299,3554,0.714132005201006 -40.113075733184814,0.2189505100250244,751.3892691135406,3110,0,751.3892691135406,0.2933647200502653,3581,0.7293849456113864,791.8365585803986,0.2742685760770525,0.7333554540361676,0.2919274084877954,3554,0.7119192145953503 -44.12899589538574,0.2428901195526123,831.370644569397,3457,0,831.370644569397,0.2919519632696872,3581,0.7382981578120636,875.8706209659576,0.2725612606321062,0.7428291184561593,0.2903749791168402,3554,0.7211998556863745 -48.15200138092041,0.2721614837646484,911.4902095794678,3805,0,911.4902095794678,0.293038733343078,3581,0.7343574785979824,960.0550758838654,0.273552485874721,0.7386627197265625,0.2917152108539673,3554,0.7167989358381401 -52.17389678955078,0.2966139316558838,991.694286584854,4153,0,991.694286584854,0.293230650645333,3581,0.7317198599160499,1044.3180837631226,0.2742926052638462,0.735229355948312,0.2919689343763189,3554,0.7141856556872538 -56.19573616981506,0.3215329647064209,1071.6959915161133,4498,0,1071.6959915161133,0.2934164661363795,3581,0.737751721869764,1128.3792896270752,0.2737043585096086,0.7424945150102887,0.2917564276167698,3554,0.7206746167390616 -60.21451163291931,0.3464336395263672,1151.760767698288,4848,0,1151.760767698288,0.294008682706908,3581,0.7292635229771712,1212.5005338191986,0.2749983923775809,0.732649530683245,0.2926891629589899,3554,0.7118654267198931 -64.23522353172302,0.3716499805450439,1231.847367286682,5197,0,1231.847367286682,0.2919204315637217,3581,0.7350328366072675,1296.645854473114,0.2722815786089216,0.7398488180977958,0.2904013921923361,3554,0.7179851542715954 -68.25117325782776,0.3998067378997803,1311.875165224075,5542,0,1311.875165224075,0.2944357072265079,3581,0.7318219203783859,1380.730637550354,0.2749585935047695,0.7362847328186035,0.2929826263101435,3554,0.7142758517031865 -72.26720786094666,0.4259700775146484,1391.8485069274902,5892,0,1391.8485069274902,0.2912932062774888,3581,0.7376195954996858,1464.7589864730835,0.2715468747275216,0.7424530301775251,0.2899135574572664,3554,0.7202798975406233 -76.28772187232971,0.450514554977417,1471.865632534027,6239,0,1471.865632534027,0.2916656894678511,3581,0.7331954074019129,1548.8338203430176,0.2722529854093279,0.7378379276820591,0.2903406661618071,3554,0.715315956712507 -80.30454111099243,0.4805665016174316,1551.8551013469696,6584,0,1551.8551013469696,0.2911859643888753,3581,0.7338993996090477,1632.882734298706,0.2715438093457903,0.7387582915169852,0.289883915735351,3554,0.7162649039814294 -84.32542157173157,0.5068538188934326,1631.901466369629,6930,0,1631.901466369629,0.2917940661215442,3581,0.7372944610007679,1716.9891197681427,0.2723818676812308,0.7413247653416225,0.2905070788416221,3554,0.7200253840303179 -88.34319472312927,0.5372030735015869,1711.892200231552,7277,0,1711.892200231552,0.2905993042162629,3581,0.7388227772226682,1801.0402827262878,0.27117463520595,0.7432306153433663,0.2892516162466587,3554,0.721642317635059 -92.36465859413148,0.5624737739562988,1792.0071773529053,7621,0,1792.0071773529053,0.2919793361992984,3581,0.7362754926172856,1885.2145402431488,0.272381067276001,0.7414225169590541,0.2906780597126477,3554,0.7186626204628588 -96.38457012176514,0.5911064147949219,1872.1870954036715,7971,0,1872.1870954036715,0.2925304422429139,3581,0.736703369345155,1969.45621585846,0.2721464804240635,0.7419336863926479,0.2910100263897193,3554,0.7195013128912845 -100.40389037132265,0.6165444850921631,1952.1830134391785,8319,0,1952.1830134391785,0.2896039249576759,3581,0.7396412380445406,2053.50967502594,0.2699603012629917,0.7445854459490094,0.2881882237663548,3554,0.7223544059070766 -104.42580008506776,0.6430308818817139,2032.2230477333069,8665,0,2032.2230477333069,0.2918098149303616,3581,0.7338329955407009,2137.6106107234955,0.2723539727074759,0.7385866982596261,0.290498972878271,3554,0.7161103411209201 -108.44480562210084,0.6688551902770996,2112.292769908905,9013,0,2112.292769908905,0.2905307866714081,3581,0.737618981909732,2221.7378239631653,0.2706210272652762,0.7424891335623605,0.2891170778634109,3554,0.7202015169966939 -112.46442103385924,0.6951453685760498,2192.431623697281,9361,0,2192.431623697281,0.2898749271873255,3581,0.7415535934009355,2305.935366868973,0.2700404099055699,0.7465317589896066,0.2885494543450161,3554,0.7243574718846723 -116.48691987991332,0.7215027809143066,2272.523061037064,9706,0,2272.523061037064,0.2915895020485723,3581,0.736225314594387,2390.088215827942,0.2719215665544782,0.7407120295933315,0.290316073493335,3554,0.7184759085273635 -120.50583410263062,0.7481822967529297,2352.6956446170807,10054,0,2352.6956446170807,0.2894194729998603,3581,0.7403175505270874,2474.3193640708923,0.2696514810834612,0.7451527459280831,0.2880908835115363,3554,0.7231413026035804 -124.52903652191162,0.7753183841705322,2432.788671255112,10403,0,2432.788671255112,0.2902749537489528,3581,0.7356155425335102,2558.4751613140106,0.2704327957970755,0.7408154351370675,0.2889557829316439,3554,0.717867136940771 -128.54647755622864,0.805931568145752,2512.798567056656,10748,0,2512.798567056656,0.2895765861163956,3581,0.7389399729038676,2642.545251607895,0.270036118371146,0.7433472360883441,0.2883838660004572,3554,0.7214320434501618 -132.5679280757904,0.8328738212585449,2592.814291477204,11094,0,2592.814291477204,0.2899044135934445,3581,0.7398847650795867,2726.622271060944,0.2699331215449742,0.7452190944126674,0.2884922317393254,3554,0.7226893607994513 -136.58704543113708,0.8600804805755615,2672.999428987503,11442,0,2672.999428987503,0.289955068852974,3581,0.7421258001212999,2810.8664784431458,0.2698504243578229,0.7470412254333496,0.2885594150626934,3554,0.7249382847671637 -140.6037676334381,0.8904705047607422,2753.135874271393,11788,0,2753.135874271393,0.2894258134293842,3581,0.7389654709752862,2895.062990665436,0.2697876010622297,0.7438268661499023,0.2880471593956633,3554,0.7216566748074353 -144.61974668502808,0.9170408248901368,2833.2737278938293,12137,0,2833.2737278938293,0.2896093109139381,3581,0.7408732584953575,2979.2558150291443,0.2697081225258963,0.7457892554146903,0.2881403779742016,3554,0.7239719577632597 -148.6403510570526,0.9478676319122314,2913.414836406708,12485,0,2913.414836406708,0.2895721887217258,3581,0.7355890899888299,3063.4614906311035,0.2698268549782889,0.7406058992658343,0.2883866824792487,3554,0.7178359495902504 -152.65930891036987,0.9745488166809082,2993.3852968215942,12831,0,2993.3852968215942,0.2906425963963453,3581,0.7379370260358489,3147.490044116974,0.270629984991891,0.7433413096836635,0.2892357134456774,3554,0.720636972095702 -156.68073511123657,1.001333236694336,3073.576206445694,13177,0,3073.576206445694,0.2906785595853113,3581,0.7409575248490295,3231.741859436035,0.2701584611620222,0.7466381617954799,0.2892952029733223,3554,0.723888837291608 -160.6983847618103,1.0293962955474854,3153.5380742549896,13526,0,3153.5380742549896,0.2898893465512426,3581,0.7412572294531905,3315.7623331546783,0.2698252541678292,0.7466552598135812,0.2885258577483117,3554,0.7241531741303813 -164.72112107276917,1.0566694736480713,3233.532235622406,13872,0,3233.532235622406,0.2902640795714361,3581,0.7415183460669157,3399.819220304489,0.2703482934406825,0.7463925906590053,0.2889438644177335,3554,0.7245600522738463 -168.74269700050354,1.0836646556854248,3313.6611313819885,14220,0,3313.6611313819885,0.2906963536939752,3581,0.7331411387793214,3484.009249448776,0.2708131074905395,0.7378207615443638,0.2894396333796427,3554,0.715547801003271 -172.76443004608154,1.1147141456604004,3393.7974298000336,14567,0,3393.7974298000336,0.2894305857956925,3581,0.7405100814192963,3568.210847377777,0.269301210130964,0.7457613263811383,0.2880752211416713,3554,0.7233533628481992 -176.778422832489,1.1418116092681885,3473.7841737270355,14912,0,3473.7841737270355,0.2891287677150237,3581,0.7394592745348716,3652.251398086548,0.269031354359218,0.745074885232108,0.2877203276402909,3554,0.722156908918648 -180.79540133476257,1.169428825378418,3553.957760572433,15256,0,3553.957760572433,0.2893591025660604,3581,0.7385002334368891,3736.482356786728,0.2690439905439104,0.7442522048950195,0.2880876892124191,3554,0.7211128196222566 -184.8119802474976,1.1967723369598389,3634.062075376511,15593,0,3634.062075376511,0.2893062997416922,3581,0.7376525248272131,3820.642801046372,0.2695131301879883,0.7427872249058315,0.2880110260336065,3554,0.720087552647545 -188.8285722732544,1.2282922267913818,3714.14043045044,15940,0,3714.14043045044,0.2893332295230033,3581,0.7420830533545099,3904.782006263733,0.2692973273141043,0.7473726953778949,0.288025005385657,3554,0.7248972740881753 -192.8431706428528,1.2567307949066162,3794.149676799774,16283,0,3794.149676799774,0.2888464822478881,3581,0.7375420104588453,3988.846718549728,0.2691495077950613,0.742541858128139,0.2876137307874929,3554,0.7198662873259004 -196.8614206314087,1.2849531173706057,3874.2583525180817,16628,0,3874.2583525180817,0.2908273551491378,3581,0.7411906890315205,4073.014472723007,0.270333136831011,0.7466435432434082,0.2894760758674205,3554,0.7241068052722285 -200.88203167915344,1.3127756118774414,3954.3488595485687,16976,0,3954.3488595485687,0.2886493494310248,3581,0.7434004991622452,4157.166150331497,0.2686479602541242,0.7484711919512067,0.2872895952953538,3554,0.726306681292206 -204.9027829170227,1.3420863151550293,4034.3800785541534,17319,0,4034.3800785541534,0.2886111705005585,3581,0.7419237244964745,4241.2600877285,0.2685943841934204,0.7472654070172992,0.2873271884177599,3554,0.7247833097390265 -208.91718530654907,1.3709430694580078,4114.4050397872925,17665,0,4114.4050397872925,0.2885156209094003,3581,0.7382647512479056,4325.3410885334015,0.2685637984957014,0.7432973044259208,0.2872054959255856,3554,0.7210903564865293 -212.93413639068604,1.3991119861602783,4194.461406230927,18010,0,4194.461406230927,0.2881434104256841,3581,0.7421997036224169,4409.454977273941,0.2679294858660017,0.7477449008396694,0.2867555806122942,3554,0.7250144670837436 -216.95491862297047,1.4276528358459473,4274.639910936356,18354,0,4274.639910936356,0.2883986297581856,3581,0.7405550780159174,4493.69522857666,0.2684933287756784,0.7455470221383231,0.2871022135874631,3554,0.723333166634426 -220.97378945350647,1.456956148147583,4354.662460803986,18702,0,4354.662460803986,0.2895410660757295,3581,0.7397161641955808,4577.778399944305,0.2690425259726388,0.7455432074410575,0.2881686458040236,3554,0.7224505783536156 -224.9903819561005,1.4863619804382324,4434.6307945251465,19050,0,4434.6307945251465,0.2879860927791643,3581,0.7406949765254119,4661.805282831192,0.2680240358625139,0.7460323061261859,0.2866896681391126,3554,0.7234105167592854 -229.007297039032,1.5158488750457764,4514.716888904572,19396,0,4514.716888904572,0.2881001523339325,3581,0.7415834547786931,4745.949978351593,0.268166184425354,0.7468460627964565,0.2867450016431749,3554,0.7243975895338 -233.02187514305115,1.544053554534912,4594.704051733017,19741,0,4594.704051733017,0.288564435399068,3581,0.7374969456855627,4829.992509126663,0.2682247332164219,0.7431297983442035,0.2873446196736951,3554,0.7199703596519766 -237.04399013519287,1.5727825164794922,4674.802278280258,20089,0,4674.802278280258,0.2886755633573897,3581,0.7398124978183468,4914.1542048454285,0.2684863635471889,0.745450496673584,0.2874175046492508,3554,0.7224317560319359 -241.05979704856875,1.6012942790985107,4754.768115282059,20436,0,4754.768115282059,0.287764723159121,3581,0.7414250122172578,4998.176766395569,0.2675222669328962,0.7470994676862445,0.2864648135243739,3554,0.7241141555949282 -245.0799369812012,1.6300930976867676,4834.975478887558,20780,0,4834.975478887558,0.2879269154369415,3581,0.7421225958182072,5082.445472002029,0.2671455485480172,0.7483934674944196,0.2866048818232977,3554,0.7249244084570202 -249.0958018302917,1.663682460784912,4914.961860656738,21127,0,4914.961860656738,0.2884889297464046,3581,0.7398913782157568,5166.493827342987,0.2681503977094377,0.74561950138637,0.287237490437711,3554,0.7225248372212648 -253.1110918521881,1.6949870586395264,4994.953593254089,21473,0,4994.953593254089,0.2880482358061644,3581,0.7438104454281276,5250.544547796249,0.2679614509854998,0.7491350173950195,0.2867093663170019,3554,0.7267961303504854 -257.1274366378784,1.7242553234100342,5074.936768531799,21819,0,5074.936768531799,0.2886547353872871,3581,0.742086053127618,5334.585844755173,0.2674870831625802,0.7484683990478516,0.287353824750721,3554,0.7248557825469542 -261.1410791873932,1.7532122135162354,5155.074422121048,22167,0,5155.074422121048,0.2879750822483244,3581,0.7414715087004329,5418.778828620911,0.2676395348140171,0.7472255570547921,0.2866872638279491,3554,0.7242740079399972 -265.16000413894653,1.7835183143615725,5235.182781457901,22514,0,5235.182781457901,0.2872728626343898,3581,0.7404845151712162,5502.94899559021,0.2670364209583827,0.745988300868443,0.2860456390466727,3554,0.7230816069921215 -269.17903685569763,1.8133699893951416,5315.278612852097,22859,0,5315.278612852097,0.2873145867512566,3581,0.7409802276773247,5587.105926513672,0.2666735649108886,0.7471205166407994,0.2860689608649585,3554,0.7236672284969401 -273.1909189224243,1.843076229095459,5395.240823984146,23207,0,5395.240823984146,0.2874790629472389,3581,0.7427466849780089,5671.12223815918,0.2670401845659528,0.7486298424857003,0.2862076209244865,3554,0.7255774880636255 -277.20637464523315,1.872402667999268,5475.210396766663,23554,0,5475.210396766663,0.2872870092916608,3581,0.7415203231901005,5755.149123430252,0.2669352974210466,0.7473702430725098,0.2860737523136343,3554,0.7242064124490011 -281.2231953144073,1.902432203292847,5555.329708099365,23899,0,5555.329708099365,0.2875343542197535,3581,0.7424447987206786,5839.32786488533,0.266871520451137,0.7486155373709542,0.2863125862804234,3554,0.7250420823148214 -285.23908710479736,1.933682203292847,5635.302527427673,24245,0,5635.302527427673,0.287214128440467,3581,0.7427145055937587,5923.360556364059,0.2665752513068063,0.7487003462655204,0.2859320009968961,3554,0.7254798043357836 -289.2548222541809,1.96408486366272,5715.485833406448,24594,0,5715.485833406448,0.2868927436579342,3581,0.7419074302743647,6007.602685213089,0.2664803436824253,0.7478156770978656,0.2856176716595737,3554,0.7245771572304094 -293.2680079936981,1.994885921478272,5795.482012271881,24938,0,5795.482012271881,0.2872251048829761,3581,0.7421870909400308,6091.655306100845,0.2666981220245361,0.7481376784188407,0.2858955241618159,3554,0.7249912483073649 -297.2859468460083,2.02933931350708,5875.4409856796265,25284,0,5875.4409856796265,0.2873365737246056,3581,0.7414855530927116,6175.679006099701,0.2664295264652797,0.7477162906101772,0.2860070498524989,3554,0.7241976882342079 -301.29969024658203,2.0605125427246094,5955.660590648651,25631,0,5955.660590648651,0.2871733928851926,3581,0.7439213006798031,6259.956165790558,0.2663685594286237,0.7501141003199986,0.2858488633515932,3554,0.7267898104468556 -305.313759803772,2.093826532363892,6035.690726995468,25976,0,6035.690726995468,0.286973158030229,3581,0.7422536313617006,6344.0462028980255,0.2665225608008248,0.748063496180943,0.2857278578054656,3554,0.7249627400464266 -309.33096265792847,2.125617742538452,6115.706563234329,26325,0,6115.706563234329,0.2866083447142907,3581,0.742955441915666,6428.123630523682,0.2655692952019827,0.7494532721383231,0.285317854057488,3554,0.7256950932268219 -313.3459963798523,2.156813144683838,6195.682910203934,26674,0,6195.682910203934,0.2866072879760367,3581,0.7426701225870916,6512.158713579178,0.2659331730433872,0.7486578396388462,0.2852956485265282,3554,0.7254540438590321 -317.3675000667572,2.1871368885040283,6275.682108402252,27018,0,6275.682108402252,0.2864416186884774,3581,0.7428295877984502,6596.222285270691,0.2657155820301601,0.7490175792149135,0.2851643903106535,3554,0.725519784595702 -321.38672494888306,2.217285633087158,6355.869081735611,27364,0,6355.869081735611,0.2866985083491867,3581,0.744054381523143,6680.471506595612,0.2653354576655796,0.7507955006190709,0.2853771203276677,3554,0.7269669738323016 -325.39993500709534,2.257634401321411,6435.99257683754,27711,0,6435.99257683754,0.2865437473274748,3581,0.7427899771580914,6764.661067485809,0.2656243188040597,0.7491722106933594,0.2852912692454804,3554,0.7254776748030388 -329.4172194004059,2.288557529449463,6516.111511468887,28058,0,6516.111511468887,0.2864529360142941,3581,0.7426109452448687,6848.840639352799,0.265708327293396,0.7488178525652204,0.2851588260476752,3554,0.7253711294711944 -333.43426752090454,2.319133281707764,6596.127126693726,28403,0,6596.127126693726,0.2863629769093828,3581,0.7432273986185772,6932.916352748871,0.2649412836347307,0.750096184866769,0.2850630829424152,3554,0.7260181639525887 -337.45313835144043,2.350611448287964,6676.110071659088,28752,0,6676.110071659088,0.2864265857346063,3581,0.7432734178651215,7016.962255001068,0.2653285094669887,0.7498783384050641,0.2851461862404157,3554,0.7260880263655388 -341.47138237953186,2.385899305343628,6756.167396783829,29100,0,6756.167396783829,0.2863883045391476,3581,0.7432643503691357,7101.085487127304,0.2654174736567906,0.749572617667062,0.2851238433202465,3554,0.726090087203679 -345.49227356910706,2.417950391769409,6836.258863449097,29446,0,6836.258863449097,0.2862946979828434,3581,0.7437386554035186,7185.242262125015,0.2646945033754621,0.7507331030709403,0.2850275850054516,3554,0.7265406551157146 -349.5119769573212,2.4548652172088623,6916.343825817108,29793,0,6916.343825817108,0.2863362857463872,3581,0.7430140738445965,7269.396333932877,0.2650007350104196,0.7498508180890765,0.2850495672789462,3554,0.7257677034239589 -353.5303068161011,2.487074136734009,6996.4826691150665,30142,0,6996.4826691150665,0.2862814035338418,3581,0.7429910301329936,7353.598170042038,0.2650401251656668,0.7498049054827008,0.2849661376815736,3554,0.7257995777205262 -357.5478813648224,2.5196094512939453,7076.474865198135,30488,0,7076.474865198135,0.286356636479859,3581,0.7437524952658127,7437.652813434601,0.2646484034402029,0.7510360990251813,0.285057432811181,3554,0.7266652671285875 -361.5675451755524,2.557867288589477,7156.552713394165,30836,0,7156.552713394165,0.2862277144128735,3581,0.7437485410194429,7521.801300287247,0.2647977897099086,0.7507091249738421,0.2849529483174768,3554,0.7265978090268008 -365.5828056335449,2.589299201965332,7236.685940265655,31182,0,7236.685940265655,0.2861480499838558,3581,0.743651730160046,7605.993559360504,0.264832649912153,0.750427109854562,0.2848840132816896,3554,0.7264883098269556 -369.5958352088928,2.6212565898895264,7316.703558444977,31529,0,7316.703558444977,0.286220555863411,3581,0.7432802355312762,7690.068843841553,0.2645143951688494,0.7505290167672294,0.2849139813029773,3554,0.7261004600889842 -373.6141917705536,2.6523241996765137,7396.760830163956,31876,0,7396.760830163956,0.2861240858873219,3581,0.7438690091803966,7774.188344955444,0.2645551477159772,0.7509726115635463,0.2848580295474729,3554,0.7267222149558596 -377.6309578418732,2.6843929290771484,7476.919148206711,32223,0,7476.919148206711,0.2861133821514591,3581,0.7439353450720818,7858.408213853836,0.2646592685154506,0.7508950233459473,0.2848581497630311,3554,0.7267572492042417 -381.6490135192871,2.7174501419067383,7557.069484472275,32569,0,7557.069484472275,0.2861683666289968,3581,0.7437707666111072,7942.621923923492,0.26461204460689,0.750824110848563,0.284898696753438,3554,0.7265723920230726 -385.6676907539368,2.7551095485687256,7637.059919595718,32918,0,7637.059919595718,0.2862069887077632,3581,0.7436541845198618,8026.681388616562,0.2643741369247436,0.7510066713605609,0.2848729706239888,3554,0.7264729222355093 -389.682653427124,2.786450147628784,7717.092004537582,33264,0,7717.092004537582,0.286068385554838,3581,0.7436945451034976,8110.772299528122,0.2644939933504377,0.7508365086146763,0.2848100635397615,3554,0.7264935993115152 -393.70337677001953,2.818121194839477,7797.218616485596,33610,0,7797.218616485596,0.2862118974273945,3581,0.7443005674479893,8194.963827133179,0.264496956552778,0.7514575549534389,0.2849336107862619,3554,0.7271822627233399 -397.71788907051086,2.850908041000366,7877.383341789246,33958,0,7877.383341789246,0.286054988840844,3581,0.7438889167655682,8279.188350439072,0.2641548088618687,0.7513347353254046,0.2847540774369548,3554,0.7267312139490715 -401.73359990119934,2.884448528289795,7957.451696395874,34304,0,7957.451696395874,0.2860153782004852,3581,0.7438610325109956,8363.3182117939,0.2642734391348703,0.7511395045689174,0.2847594871370726,3554,0.7266988587902715 -405.75045943260193,2.918001174926758,8037.453133821487,34647,0,8037.453133821487,0.2859457016523841,3581,0.7439592750802848,8447.382189512253,0.2642475877489362,0.7512046950204032,0.2847093915966165,3554,0.7267843835730866 -409.7646481990814,2.950176954269409,8117.535898685455,34993,0,8117.535898685455,0.2859992203316985,3581,0.7439256639861421,8531.523807525635,0.2640795026506696,0.7513880729675293,0.2847361824924381,3554,0.72677325504713 -413.7841286659241,2.983177423477173,8197.736165523529,35342,0,8197.736165523529,0.2859374863646677,3581,0.7438301484833147,8615.789492368698,0.2640912362507411,0.751244136265346,0.284680162042329,3554,0.7266568863868177 -417.794353723526,3.0200717449188232,8277.73072552681,35686,0,8277.73072552681,0.2859263735688355,3581,0.7440562904696663,8699.84346485138,0.2641090495245797,0.7514269011361259,0.2846793720543754,3554,0.7268992409520962 -421.8108897209168,3.0529043674468994,8357.767586946487,36033,0,8357.767586946487,0.2859348615631981,3581,0.7439452306880061,8783.942422628403,0.2640905209950038,0.7513409342084613,0.2846803166051895,3554,0.7267714689874085 -425.8297622203827,3.0856668949127197,8392.034851789474,36189,0,8392.034851789474,0.28593445250322885,3581,0.7439701833461324,8822.267328739166,0.26408934593200684,0.751366410936628,0.28467950944358467,3554,0.7267988781346723 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/measurements.csv deleted file mode 100644 index acd02c47a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/measurements.csv +++ /dev/null @@ -1,470 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.160697,0.8857512,,,,,,,,,,,,,, -1,,,0.2656774180276053,0.8928326879228864,0.2613217295323051,0.8956671152882316,3554.0,0.2836721463431129,0.8949576268413153,3581.0,30.383830070495605,34.343270778656006,30.383830070495605,3.95932412147522,0.0,0.0 -100,0.5909863,0.24547428,,,,,,,,,,,,,, -200,0.7926095,0.2555074,,,,,,,,,,,,,, -300,0.3576562,0.296986,,,,,,,,,,,,,, -344,,,0.714498588017055,0.2980813298906599,0.6943832680826182,0.3143779387551878,3554.0,0.7118314323862049,0.3160043485801801,3581.0,110.5013530254364,118.50828313827516,110.5013530254364,7.974936962127685,0.01961350440979,0.0 -400,0.29695645,0.28386095,,,,,,,,,,,,,, -500,0.21928945,0.28025818,,,,,,,,,,,,,, -600,0.2697355,0.2707761,,,,,,,,,,,,,, -686,,,0.7243420055934361,0.2890401227133615,0.7042416994935284,0.3066701636800084,3554.0,0.7214672490226194,0.3079342089761938,3581.0,190.65867400169373,202.7176303863525,190.65867400169373,11.9900643825531,0.0437884330749511,0.0 -700,0.13697015,0.23610334,,,,,,,,,,,,,, -800,0.2119895,0.34579355,,,,,,,,,,,,,, -900,0.11316707,0.35369682,,,,,,,,,,,,,, -1000,0.20282471,0.2765333,,,,,,,,,,,,,, -1028,,,0.7297190938677106,0.2809634378978184,0.7084653185671075,0.2987067072862444,3554.0,0.7258220332789375,0.3001406620881038,3581.0,270.77397561073303,286.888560295105,270.77397561073303,16.00831699371338,0.0684285163879394,0.0 -1100,0.18590486,0.2670014,,,,,,,,,,,,,, -1200,0.17391141,0.2702939,,,,,,,,,,,,,, -1300,0.29069546,0.29500565,,,,,,,,,,,,,, -1375,,,0.7327919006347656,0.2769715615681239,0.7112662036833497,0.2944409097253974,3554.0,0.728854599361212,0.2959859081567648,3581.0,350.7723970413208,370.9412474632263,350.7723970413208,20.02594399452209,0.0922152996063232,0.0 -1400,0.13724594,0.29032692,,,,,,,,,,,,,, -1500,0.09734282,0.27498695,,,,,,,,,,,,,, -1600,0.15540956,0.2487393,,,,,,,,,,,,,, -1700,0.09210714,0.40882817,,,,,,,,,,,,,, -1721,,,0.7358476775033134,0.27680504322052,0.7154141213025816,0.2940751109555254,3554.0,0.7323106788650168,0.2958660876740959,3581.0,430.8206250667572,455.0456523895264,430.8206250667572,24.04477381706237,0.1164469718933105,0.0 -1800,0.20437716,0.2589693,,,,,,,,,,,,,, -1900,0.36033377,0.31628147,,,,,,,,,,,,,, -2000,0.3309579,0.27624565,,,,,,,,,,,,,, -2066,,,0.7291966165815081,0.2779006106512887,0.7088859356315067,0.2951964816821715,3554.0,0.7263692191645141,0.2966501874585486,3581.0,510.9170143604279,539.1955320835114,510.9170143604279,28.061689853668213,0.1398513317108154,0.0 -2100,0.19061784,0.3138634,,,,,,,,,,,,,, -2200,0.15518008,0.28093344,,,,,,,,,,,,,, -2300,0.13613355,0.2777457,,,,,,,,,,,,,, -2400,0.20346661,0.3373495,,,,,,,,,,,,,, -2413,,,0.7348235675266811,0.275118316922869,0.7122574668287492,0.2932508787413829,3554.0,0.7297532359370636,0.2949290335494624,3581.0,591.0181603431702,623.3513972759247,591.0181603431702,32.079628467559814,0.1638176441192627,0.0 -2500,0.1138327,0.30091602,,,,,,,,,,,,,, -2600,0.22648923,0.24881478,,,,,,,,,,,,,, -2700,0.0989374,0.30047354,,,,,,,,,,,,,, -2763,,,0.7348286764962333,0.2760927677154541,0.714132005201006,0.2936104263022299,3554.0,0.7313912484073932,0.2950347414631911,3581.0,671.2152290344238,707.6071050167084,671.2152290344238,36.097635984420776,0.1916062831878662,0.0 -2800,0.16045786,0.23834145,,,,,,,,,,,,,, -2900,0.14489493,0.25266412,,,,,,,,,,,,,, -3000,0.16859996,0.26450118,,,,,,,,,,,,,, -3100,0.059656464,0.2861216,,,,,,,,,,,,,, -3110,,,0.7333554540361676,0.2742685760770525,0.7119192145953503,0.2919274084877954,3554.0,0.7293849456113864,0.2933647200502653,3581.0,751.3892691135406,791.8365585803986,751.3892691135406,40.113075733184814,0.2189505100250244,0.0 -3200,0.23811424,0.26760882,,,,,,,,,,,,,, -3300,0.106779344,0.2627077,,,,,,,,,,,,,, -3400,0.13500194,0.2862821,,,,,,,,,,,,,, -3457,,,0.7428291184561593,0.2725612606321062,0.7211998556863745,0.2903749791168402,3554.0,0.7382981578120636,0.2919519632696872,3581.0,831.370644569397,875.8706209659576,831.370644569397,44.12899589538574,0.2428901195526123,0.0 -3500,0.1406748,0.32838255,,,,,,,,,,,,,, -3600,0.16952114,0.24667253,,,,,,,,,,,,,, -3700,0.14531507,0.25330615,,,,,,,,,,,,,, -3800,0.21037866,0.2665416,,,,,,,,,,,,,, -3805,,,0.7386627197265625,0.273552485874721,0.7167989358381401,0.2917152108539673,3554.0,0.7343574785979824,0.293038733343078,3581.0,911.4902095794678,960.0550758838654,911.4902095794678,48.15200138092041,0.2721614837646484,0.0 -3900,0.16723068,0.32252634,,,,,,,,,,,,,, -4000,0.27580673,0.2357241,,,,,,,,,,,,,, -4100,0.1665298,0.24278724,,,,,,,,,,,,,, -4153,,,0.735229355948312,0.2742926052638462,0.7141856556872538,0.2919689343763189,3554.0,0.7317198599160499,0.293230650645333,3581.0,991.694286584854,1044.3180837631226,991.694286584854,52.17389678955078,0.2966139316558838,0.0 -4200,0.098569825,0.29406723,,,,,,,,,,,,,, -4300,0.16659145,0.29203993,,,,,,,,,,,,,, -4400,0.16915388,0.29378647,,,,,,,,,,,,,, -4498,,,0.7424945150102887,0.2737043585096086,0.7206746167390616,0.2917564276167698,3554.0,0.737751721869764,0.2934164661363795,3581.0,1071.6959915161133,1128.3792896270752,1071.6959915161133,56.19573616981506,0.3215329647064209,0.0 -4500,0.15772729,0.25464997,,,,,,,,,,,,,, -4600,0.114564,0.24845347,,,,,,,,,,,,,, -4700,0.12701994,0.2805285,,,,,,,,,,,,,, -4800,0.17705071,0.25394827,,,,,,,,,,,,,, -4848,,,0.732649530683245,0.2749983923775809,0.7118654267198931,0.2926891629589899,3554.0,0.7292635229771712,0.294008682706908,3581.0,1151.760767698288,1212.5005338191986,1151.760767698288,60.21451163291931,0.3464336395263672,0.0 -4900,0.18852739,0.2850683,,,,,,,,,,,,,, -5000,0.06529057,0.257897,,,,,,,,,,,,,, -5100,0.14021812,0.23524097,,,,,,,,,,,,,, -5197,,,0.7398488180977958,0.2722815786089216,0.7179851542715954,0.2904013921923361,3554.0,0.7350328366072675,0.2919204315637217,3581.0,1231.847367286682,1296.645854473114,1231.847367286682,64.23522353172302,0.3716499805450439,0.0 -5200,0.22772755,0.2595017,,,,,,,,,,,,,, -5300,0.1485211,0.25931635,,,,,,,,,,,,,, -5400,0.13048795,0.32257083,,,,,,,,,,,,,, -5500,0.2635411,0.27763298,,,,,,,,,,,,,, -5542,,,0.7362847328186035,0.2749585935047695,0.7142758517031865,0.2929826263101435,3554.0,0.7318219203783859,0.2944357072265079,3581.0,1311.875165224075,1380.730637550354,1311.875165224075,68.25117325782776,0.3998067378997803,0.0 -5600,0.2108366,0.21913774,,,,,,,,,,,,,, -5700,0.10898561,0.2554291,,,,,,,,,,,,,, -5800,0.12268351,0.3612184,,,,,,,,,,,,,, -5892,,,0.7424530301775251,0.2715468747275216,0.7202798975406233,0.2899135574572664,3554.0,0.7376195954996858,0.2912932062774888,3581.0,1391.8485069274902,1464.7589864730835,1391.8485069274902,72.26720786094666,0.4259700775146484,0.0 -5900,0.17178816,0.25415677,,,,,,,,,,,,,, -6000,0.15754183,0.22972813,,,,,,,,,,,,,, -6100,0.07526067,0.34507814,,,,,,,,,,,,,, -6200,0.20961663,0.2375118,,,,,,,,,,,,,, -6239,,,0.7378379276820591,0.2722529854093279,0.715315956712507,0.2903406661618071,3554.0,0.7331954074019129,0.2916656894678511,3581.0,1471.865632534027,1548.8338203430176,1471.865632534027,76.28772187232971,0.450514554977417,0.0 -6300,0.07762352,0.26293197,,,,,,,,,,,,,, -6400,0.17759451,0.21591835,,,,,,,,,,,,,, -6500,0.100069866,0.30167776,,,,,,,,,,,,,, -6584,,,0.7387582915169852,0.2715438093457903,0.7162649039814294,0.289883915735351,3554.0,0.7338993996090477,0.2911859643888753,3581.0,1551.8551013469696,1632.882734298706,1551.8551013469696,80.30454111099243,0.4805665016174316,0.0 -6600,0.29165772,0.33805764,,,,,,,,,,,,,, -6700,0.24496691,0.23377399,,,,,,,,,,,,,, -6800,0.1510291,0.367134,,,,,,,,,,,,,, -6900,0.06638849,0.32883123,,,,,,,,,,,,,, -6930,,,0.7413247653416225,0.2723818676812308,0.7200253840303179,0.2905070788416221,3554.0,0.7372944610007679,0.2917940661215442,3581.0,1631.901466369629,1716.9891197681427,1631.901466369629,84.32542157173157,0.5068538188934326,0.0 -7000,0.17841078,0.25360888,,,,,,,,,,,,,, -7100,0.18225113,0.23319995,,,,,,,,,,,,,, -7200,0.087962665,0.26913416,,,,,,,,,,,,,, -7277,,,0.7432306153433663,0.27117463520595,0.721642317635059,0.2892516162466587,3554.0,0.7388227772226682,0.2905993042162629,3581.0,1711.892200231552,1801.0402827262878,1711.892200231552,88.34319472312927,0.5372030735015869,0.0 -7300,0.07629893,0.28989536,,,,,,,,,,,,,, -7400,0.10362339,0.22140583,,,,,,,,,,,,,, -7500,0.30972221,0.16578431,,,,,,,,,,,,,, -7600,0.073075734,0.2888834,,,,,,,,,,,,,, -7621,,,0.7414225169590541,0.272381067276001,0.7186626204628588,0.2906780597126477,3554.0,0.7362754926172856,0.2919793361992984,3581.0,1792.0071773529053,1885.2145402431488,1792.0071773529053,92.36465859413148,0.5624737739562988,0.0 -7700,0.10128859,0.2458495,,,,,,,,,,,,,, -7800,0.29470405,0.30822203,,,,,,,,,,,,,, -7900,0.1252431,0.29683352,,,,,,,,,,,,,, -7971,,,0.7419336863926479,0.2721464804240635,0.7195013128912845,0.2910100263897193,3554.0,0.736703369345155,0.2925304422429139,3581.0,1872.1870954036715,1969.45621585846,1872.1870954036715,96.38457012176514,0.5911064147949219,0.0 -8000,0.14265577,0.29624963,,,,,,,,,,,,,, -8100,0.096957915,0.25676045,,,,,,,,,,,,,, -8200,0.15784612,0.29700312,,,,,,,,,,,,,, -8300,0.18640459,0.21614087,,,,,,,,,,,,,, -8319,,,0.7445854459490094,0.2699603012629917,0.7223544059070766,0.2881882237663548,3554.0,0.7396412380445406,0.2896039249576759,3581.0,1952.1830134391785,2053.50967502594,1952.1830134391785,100.40389037132265,0.6165444850921631,0.0 -8400,0.11138567,0.2692394,,,,,,,,,,,,,, -8500,0.10764883,0.2646183,,,,,,,,,,,,,, -8600,0.08348013,0.289973,,,,,,,,,,,,,, -8665,,,0.7385866982596261,0.2723539727074759,0.7161103411209201,0.290498972878271,3554.0,0.7338329955407009,0.2918098149303616,3581.0,2032.2230477333069,2137.6106107234955,2032.2230477333069,104.42580008506776,0.6430308818817139,0.0 -8700,0.3254154,0.22123113,,,,,,,,,,,,,, -8800,0.112572625,0.3465643,,,,,,,,,,,,,, -8900,0.14939007,0.28518376,,,,,,,,,,,,,, -9000,0.35894018,0.24162747,,,,,,,,,,,,,, -9013,,,0.7424891335623605,0.2706210272652762,0.7202015169966939,0.2891170778634109,3554.0,0.737618981909732,0.2905307866714081,3581.0,2112.292769908905,2221.7378239631653,2112.292769908905,108.44480562210084,0.6688551902770996,0.0 -9100,0.22787625,0.28628847,,,,,,,,,,,,,, -9200,0.15747064,0.2391072,,,,,,,,,,,,,, -9300,0.1282431,0.41851935,,,,,,,,,,,,,, -9361,,,0.7465317589896066,0.2700404099055699,0.7243574718846723,0.2885494543450161,3554.0,0.7415535934009355,0.2898749271873255,3581.0,2192.431623697281,2305.935366868973,2192.431623697281,112.46442103385924,0.6951453685760498,0.0 -9400,0.18445303,0.30468336,,,,,,,,,,,,,, -9500,0.0629941,0.27872908,,,,,,,,,,,,,, -9600,0.07157171,0.33435804,,,,,,,,,,,,,, -9700,0.10400775,0.4022863,,,,,,,,,,,,,, -9706,,,0.7407120295933315,0.2719215665544782,0.7184759085273635,0.290316073493335,3554.0,0.736225314594387,0.2915895020485723,3581.0,2272.523061037064,2390.088215827942,2272.523061037064,116.48691987991332,0.7215027809143066,0.0 -9800,0.08748637,0.24650632,,,,,,,,,,,,,, -9900,0.1828828,0.28126377,,,,,,,,,,,,,, -10000,0.110518515,0.29798636,,,,,,,,,,,,,, -10054,,,0.7451527459280831,0.2696514810834612,0.7231413026035804,0.2880908835115363,3554.0,0.7403175505270874,0.2894194729998603,3581.0,2352.6956446170807,2474.3193640708923,2352.6956446170807,120.50583410263062,0.7481822967529297,0.0 -10100,0.06922175,0.24269089,,,,,,,,,,,,,, -10200,0.11532547,0.36681613,,,,,,,,,,,,,, -10300,0.14819868,0.29659608,,,,,,,,,,,,,, -10400,0.20600957,0.28477597,,,,,,,,,,,,,, -10403,,,0.7408154351370675,0.2704327957970755,0.717867136940771,0.2889557829316439,3554.0,0.7356155425335102,0.2902749537489528,3581.0,2432.788671255112,2558.4751613140106,2432.788671255112,124.52903652191162,0.7753183841705322,0.0 -10500,0.11256786,0.24719933,,,,,,,,,,,,,, -10600,0.22251661,0.2494203,,,,,,,,,,,,,, -10700,0.1576921,0.25378692,,,,,,,,,,,,,, -10748,,,0.7433472360883441,0.270036118371146,0.7214320434501618,0.2883838660004572,3554.0,0.7389399729038676,0.2895765861163956,3581.0,2512.798567056656,2642.545251607895,2512.798567056656,128.54647755622864,0.805931568145752,0.0 -10800,0.09081423,0.23415649,,,,,,,,,,,,,, -10900,0.11995569,0.24841833,,,,,,,,,,,,,, -11000,0.065541714,0.26845053,,,,,,,,,,,,,, -11094,,,0.7452190944126674,0.2699331215449742,0.7226893607994513,0.2884922317393254,3554.0,0.7398847650795867,0.2899044135934445,3581.0,2592.814291477204,2726.622271060944,2592.814291477204,132.5679280757904,0.8328738212585449,0.0 -11100,0.093335286,0.24591796,,,,,,,,,,,,,, -11200,0.10944376,0.3904328,,,,,,,,,,,,,, -11300,0.18979023,0.282685,,,,,,,,,,,,,, -11400,0.10605478,0.26976508,,,,,,,,,,,,,, -11442,,,0.7470412254333496,0.2698504243578229,0.7249382847671637,0.2885594150626934,3554.0,0.7421258001212999,0.289955068852974,3581.0,2672.999428987503,2810.8664784431458,2672.999428987503,136.58704543113708,0.8600804805755615,0.0 -11500,0.1860313,0.37173632,,,,,,,,,,,,,, -11600,0.18794933,0.25929376,,,,,,,,,,,,,, -11700,0.11137093,0.24233437,,,,,,,,,,,,,, -11788,,,0.7438268661499023,0.2697876010622297,0.7216566748074353,0.2880471593956633,3554.0,0.7389654709752862,0.2894258134293842,3581.0,2753.135874271393,2895.062990665436,2753.135874271393,140.6037676334381,0.8904705047607422,0.0 -11800,0.18500155,0.27472448,,,,,,,,,,,,,, -11900,0.16926585,0.3163547,,,,,,,,,,,,,, -12000,0.1276093,0.3409411,,,,,,,,,,,,,, -12100,0.11352658,0.32880482,,,,,,,,,,,,,, -12137,,,0.7457892554146903,0.2697081225258963,0.7239719577632597,0.2881403779742016,3554.0,0.7408732584953575,0.2896093109139381,3581.0,2833.2737278938293,2979.2558150291443,2833.2737278938293,144.61974668502808,0.9170408248901368,0.0 -12200,0.07785212,0.27061772,,,,,,,,,,,,,, -12300,0.16013384,0.2317941,,,,,,,,,,,,,, -12400,0.09178853,0.2661981,,,,,,,,,,,,,, -12485,,,0.7406058992658343,0.2698268549782889,0.7178359495902504,0.2883866824792487,3554.0,0.7355890899888299,0.2895721887217258,3581.0,2913.414836406708,3063.4614906311035,2913.414836406708,148.6403510570526,0.9478676319122314,0.0 -12500,0.13615432,0.34957296,,,,,,,,,,,,,, -12600,0.107563205,0.27909565,,,,,,,,,,,,,, -12700,0.1649401,0.30857316,,,,,,,,,,,,,, -12800,0.091289,0.2640559,,,,,,,,,,,,,, -12831,,,0.7433413096836635,0.270629984991891,0.720636972095702,0.2892357134456774,3554.0,0.7379370260358489,0.2906425963963453,3581.0,2993.3852968215942,3147.490044116974,2993.3852968215942,152.65930891036987,0.9745488166809082,0.0 -12900,0.2442099,0.308307,,,,,,,,,,,,,, -13000,0.12552014,0.31456774,,,,,,,,,,,,,, -13100,0.0930759,0.25693423,,,,,,,,,,,,,, -13177,,,0.7466381617954799,0.2701584611620222,0.723888837291608,0.2892952029733223,3554.0,0.7409575248490295,0.2906785595853113,3581.0,3073.576206445694,3231.741859436035,3073.576206445694,156.68073511123657,1.001333236694336,0.0 -13200,0.20138589,0.24302977,,,,,,,,,,,,,, -13300,0.09539237,0.24364996,,,,,,,,,,,,,, -13400,0.19600654,0.29753053,,,,,,,,,,,,,, -13500,0.09299792,0.25028664,,,,,,,,,,,,,, -13526,,,0.7466552598135812,0.2698252541678292,0.7241531741303813,0.2885258577483117,3554.0,0.7412572294531905,0.2898893465512426,3581.0,3153.5380742549896,3315.7623331546783,3153.5380742549896,160.6983847618103,1.0293962955474854,0.0 -13600,0.2388051,0.20890045,,,,,,,,,,,,,, -13700,0.2274067,0.28449845,,,,,,,,,,,,,, -13800,0.26618758,0.33330113,,,,,,,,,,,,,, -13872,,,0.7463925906590053,0.2703482934406825,0.7245600522738463,0.2889438644177335,3554.0,0.7415183460669157,0.2902640795714361,3581.0,3233.532235622406,3399.819220304489,3233.532235622406,164.72112107276917,1.0566694736480713,0.0 -13900,0.0934724,0.25518572,,,,,,,,,,,,,, -14000,0.19976275,0.24866818,,,,,,,,,,,,,, -14100,0.11328085,0.32125315,,,,,,,,,,,,,, -14200,0.07241538,0.24849023,,,,,,,,,,,,,, -14220,,,0.7378207615443638,0.2708131074905395,0.715547801003271,0.2894396333796427,3554.0,0.7331411387793214,0.2906963536939752,3581.0,3313.6611313819885,3484.009249448776,3313.6611313819885,168.74269700050354,1.0836646556854248,0.0 -14300,0.13949853,0.27771917,,,,,,,,,,,,,, -14400,0.1016464,0.29174566,,,,,,,,,,,,,, -14500,0.09591392,0.33400762,,,,,,,,,,,,,, -14567,,,0.7457613263811383,0.269301210130964,0.7233533628481992,0.2880752211416713,3554.0,0.7405100814192963,0.2894305857956925,3581.0,3393.7974298000336,3568.210847377777,3393.7974298000336,172.76443004608154,1.1147141456604004,0.0 -14600,0.1467455,0.22350878,,,,,,,,,,,,,, -14700,0.057204824,0.2719464,,,,,,,,,,,,,, -14800,0.11216006,0.2876605,,,,,,,,,,,,,, -14900,0.15768528,0.28197896,,,,,,,,,,,,,, -14912,,,0.745074885232108,0.269031354359218,0.722156908918648,0.2877203276402909,3554.0,0.7394592745348716,0.2891287677150237,3581.0,3473.7841737270355,3652.251398086548,3473.7841737270355,176.778422832489,1.1418116092681885,0.0 -15000,0.20302923,0.24096084,,,,,,,,,,,,,, -15100,0.13837133,0.19650051,,,,,,,,,,,,,, -15200,0.16657971,0.19664608,,,,,,,,,,,,,, -15256,,,0.7442522048950195,0.2690439905439104,0.7211128196222566,0.2880876892124191,3554.0,0.7385002334368891,0.2893591025660604,3581.0,3553.957760572433,3736.482356786728,3553.957760572433,180.79540133476257,1.169428825378418,0.0 -15300,0.11132977,0.22858225,,,,,,,,,,,,,, -15400,0.047531217,0.29909125,,,,,,,,,,,,,, -15500,0.09376097,0.26432908,,,,,,,,,,,,,, -15593,,,0.7427872249058315,0.2695131301879883,0.720087552647545,0.2880110260336065,3554.0,0.7376525248272131,0.2893062997416922,3581.0,3634.062075376511,3820.642801046372,3634.062075376511,184.8119802474976,1.1967723369598389,0.0 -15600,0.052374244,0.36038744,,,,,,,,,,,,,, -15700,0.09093254,0.25631642,,,,,,,,,,,,,, -15800,0.055292297,0.26803458,,,,,,,,,,,,,, -15900,0.12368331,0.1920985,,,,,,,,,,,,,, -15940,,,0.7473726953778949,0.2692973273141043,0.7248972740881753,0.288025005385657,3554.0,0.7420830533545099,0.2893332295230033,3581.0,3714.14043045044,3904.782006263733,3714.14043045044,188.8285722732544,1.2282922267913818,0.0 -16000,0.12608315,0.2632597,,,,,,,,,,,,,, -16100,0.14399303,0.24479997,,,,,,,,,,,,,, -16200,0.08308299,0.21520948,,,,,,,,,,,,,, -16283,,,0.742541858128139,0.2691495077950613,0.7198662873259004,0.2876137307874929,3554.0,0.7375420104588453,0.2888464822478881,3581.0,3794.149676799774,3988.846718549728,3794.149676799774,192.8431706428528,1.2567307949066162,0.0 -16300,0.14379542,0.2442888,,,,,,,,,,,,,, -16400,0.1124874,0.28959435,,,,,,,,,,,,,, -16500,0.09950548,0.33561623,,,,,,,,,,,,,, -16600,0.073496036,0.32655326,,,,,,,,,,,,,, -16628,,,0.7466435432434082,0.270333136831011,0.7241068052722285,0.2894760758674205,3554.0,0.7411906890315205,0.2908273551491378,3581.0,3874.2583525180817,4073.014472723007,3874.2583525180817,196.8614206314087,1.2849531173706057,0.0 -16700,0.1201835,0.23486948,,,,,,,,,,,,,, -16800,0.0729986,0.22808173,,,,,,,,,,,,,, -16900,0.12129011,0.43615144,,,,,,,,,,,,,, -16976,,,0.7484711919512067,0.2686479602541242,0.726306681292206,0.2872895952953538,3554.0,0.7434004991622452,0.2886493494310248,3581.0,3954.3488595485687,4157.166150331497,3954.3488595485687,200.88203167915344,1.3127756118774414,0.0 -17000,0.052952938,0.22637872,,,,,,,,,,,,,, -17100,0.09378613,0.27097353,,,,,,,,,,,,,, -17200,0.09383088,0.2741777,,,,,,,,,,,,,, -17300,0.07130343,0.30954856,,,,,,,,,,,,,, -17319,,,0.7472654070172992,0.2685943841934204,0.7247833097390265,0.2873271884177599,3554.0,0.7419237244964745,0.2886111705005585,3581.0,4034.3800785541534,4241.2600877285,4034.3800785541534,204.9027829170227,1.3420863151550293,0.0 -17400,0.078811996,0.2978058,,,,,,,,,,,,,, -17500,0.050753564,0.2855572,,,,,,,,,,,,,, -17600,0.04996008,0.2768879,,,,,,,,,,,,,, -17665,,,0.7432973044259208,0.2685637984957014,0.7210903564865293,0.2872054959255856,3554.0,0.7382647512479056,0.2885156209094003,3581.0,4114.4050397872925,4325.3410885334015,4114.4050397872925,208.91718530654907,1.3709430694580078,0.0 -17700,0.09256653,0.21843064,,,,,,,,,,,,,, -17800,0.09804112,0.30694285,,,,,,,,,,,,,, -17900,0.08095546,0.34118685,,,,,,,,,,,,,, -18000,0.06795904,0.27736333,,,,,,,,,,,,,, -18010,,,0.7477449008396694,0.2679294858660017,0.7250144670837436,0.2867555806122942,3554.0,0.7421997036224169,0.2881434104256841,3581.0,4194.461406230927,4409.454977273941,4194.461406230927,212.93413639068604,1.3991119861602783,0.0 -18100,0.058700956,0.23049259,,,,,,,,,,,,,, -18200,0.12923226,0.23840468,,,,,,,,,,,,,, -18300,0.11344205,0.31166467,,,,,,,,,,,,,, -18354,,,0.7455470221383231,0.2684933287756784,0.723333166634426,0.2871022135874631,3554.0,0.7405550780159174,0.2883986297581856,3581.0,4274.639910936356,4493.69522857666,4274.639910936356,216.95491862297047,1.4276528358459473,0.0 -18400,0.08076997,0.33498535,,,,,,,,,,,,,, -18500,0.075971805,0.2566242,,,,,,,,,,,,,, -18600,0.1278006,0.29570186,,,,,,,,,,,,,, -18700,0.049960915,0.2797917,,,,,,,,,,,,,, -18702,,,0.7455432074410575,0.2690425259726388,0.7224505783536156,0.2881686458040236,3554.0,0.7397161641955808,0.2895410660757295,3581.0,4354.662460803986,4577.778399944305,4354.662460803986,220.97378945350647,1.456956148147583,0.0 -18800,0.109700896,0.28650403,,,,,,,,,,,,,, -18900,0.19275098,0.24550715,,,,,,,,,,,,,, -19000,0.15332879,0.27506158,,,,,,,,,,,,,, -19050,,,0.7460323061261859,0.2680240358625139,0.7234105167592854,0.2866896681391126,3554.0,0.7406949765254119,0.2879860927791643,3581.0,4434.6307945251465,4661.805282831192,4434.6307945251465,224.9903819561005,1.4863619804382324,0.0 -19100,0.087737165,0.30560878,,,,,,,,,,,,,, -19200,0.13286792,0.22696973,,,,,,,,,,,,,, -19300,0.13891584,0.23976608,,,,,,,,,,,,,, -19396,,,0.7468460627964565,0.268166184425354,0.7243975895338,0.2867450016431749,3554.0,0.7415834547786931,0.2881001523339325,3581.0,4514.716888904572,4745.949978351593,4514.716888904572,229.007297039032,1.5158488750457764,0.0 -19400,0.11287056,0.2239263,,,,,,,,,,,,,, -19500,0.08881052,0.33246335,,,,,,,,,,,,,, -19600,0.1020412,0.29436505,,,,,,,,,,,,,, -19700,0.1388563,0.335151,,,,,,,,,,,,,, -19741,,,0.7431297983442035,0.2682247332164219,0.7199703596519766,0.2873446196736951,3554.0,0.7374969456855627,0.288564435399068,3581.0,4594.704051733017,4829.992509126663,4594.704051733017,233.02187514305115,1.544053554534912,0.0 -19800,0.07454825,0.25016868,,,,,,,,,,,,,, -19900,0.13937575,0.23358876,,,,,,,,,,,,,, -20000,0.06159504,0.23513478,,,,,,,,,,,,,, -20089,,,0.745450496673584,0.2684863635471889,0.7224317560319359,0.2874175046492508,3554.0,0.7398124978183468,0.2886755633573897,3581.0,4674.802278280258,4914.1542048454285,4674.802278280258,237.04399013519287,1.5727825164794922,0.0 -20100,0.117725246,0.3540648,,,,,,,,,,,,,, -20200,0.07872334,0.2818069,,,,,,,,,,,,,, -20300,0.08287546,0.2647349,,,,,,,,,,,,,, -20400,0.0436603,0.28215304,,,,,,,,,,,,,, -20436,,,0.7470994676862445,0.2675222669328962,0.7241141555949282,0.2864648135243739,3554.0,0.7414250122172578,0.287764723159121,3581.0,4754.768115282059,4998.176766395569,4754.768115282059,241.05979704856875,1.6012942790985107,0.0 -20500,0.14597204,0.23644929,,,,,,,,,,,,,, -20600,0.11813936,0.29112914,,,,,,,,,,,,,, -20700,0.09773453,0.26482365,,,,,,,,,,,,,, -20780,,,0.7483934674944196,0.2671455485480172,0.7249244084570202,0.2866048818232977,3554.0,0.7421225958182072,0.2879269154369415,3581.0,4834.975478887558,5082.445472002029,4834.975478887558,245.0799369812012,1.6300930976867676,0.0 -20800,0.073634714,0.26028234,,,,,,,,,,,,,, -20900,0.064673476,0.3365739,,,,,,,,,,,,,, -21000,0.06054919,0.30960593,,,,,,,,,,,,,, -21100,0.11146338,0.2349669,,,,,,,,,,,,,, -21127,,,0.74561950138637,0.2681503977094377,0.7225248372212648,0.287237490437711,3554.0,0.7398913782157568,0.2884889297464046,3581.0,4914.961860656738,5166.493827342987,4914.961860656738,249.0958018302917,1.663682460784912,0.0 -21200,0.10087769,0.28790697,,,,,,,,,,,,,, -21300,0.10510419,0.3130204,,,,,,,,,,,,,, -21400,0.13356616,0.26431143,,,,,,,,,,,,,, -21473,,,0.7491350173950195,0.2679614509854998,0.7267961303504854,0.2867093663170019,3554.0,0.7438104454281276,0.2880482358061644,3581.0,4994.953593254089,5250.544547796249,4994.953593254089,253.1110918521881,1.6949870586395264,0.0 -21500,0.065958716,0.21806328,,,,,,,,,,,,,, -21600,0.1194626,0.25060993,,,,,,,,,,,,,, -21700,0.10883997,0.3308245,,,,,,,,,,,,,, -21800,0.10681839,0.268548,,,,,,,,,,,,,, -21819,,,0.7484683990478516,0.2674870831625802,0.7248557825469542,0.287353824750721,3554.0,0.742086053127618,0.2886547353872871,3581.0,5074.936768531799,5334.585844755173,5074.936768531799,257.1274366378784,1.7242553234100342,0.0 -21900,0.11064779,0.23113091,,,,,,,,,,,,,, -22000,0.14985667,0.24223074,,,,,,,,,,,,,, -22100,0.06338973,0.30532864,,,,,,,,,,,,,, -22167,,,0.7472255570547921,0.2676395348140171,0.7242740079399972,0.2866872638279491,3554.0,0.7414715087004329,0.2879750822483244,3581.0,5155.074422121048,5418.778828620911,5155.074422121048,261.1410791873932,1.7532122135162354,0.0 -22200,0.07059488,0.2601071,,,,,,,,,,,,,, -22300,0.12689583,0.32997644,,,,,,,,,,,,,, -22400,0.087710194,0.22771864,,,,,,,,,,,,,, -22500,0.07628148,0.22620139,,,,,,,,,,,,,, -22514,,,0.745988300868443,0.2670364209583827,0.7230816069921215,0.2860456390466727,3554.0,0.7404845151712162,0.2872728626343898,3581.0,5235.182781457901,5502.94899559021,5235.182781457901,265.16000413894653,1.7835183143615725,0.0 -22600,0.04709,0.21839455,,,,,,,,,,,,,, -22700,0.15237498,0.29956025,,,,,,,,,,,,,, -22800,0.0844192,0.26540193,,,,,,,,,,,,,, -22859,,,0.7471205166407994,0.2666735649108886,0.7236672284969401,0.2860689608649585,3554.0,0.7409802276773247,0.2873145867512566,3581.0,5315.278612852097,5587.105926513672,5315.278612852097,269.17903685569763,1.8133699893951416,0.0 -22900,0.106901996,0.20703727,,,,,,,,,,,,,, -23000,0.10910653,0.25899467,,,,,,,,,,,,,, -23100,0.0562037,0.3172139,,,,,,,,,,,,,, -23200,0.10723567,0.2741959,,,,,,,,,,,,,, -23207,,,0.7486298424857003,0.2670401845659528,0.7255774880636255,0.2862076209244865,3554.0,0.7427466849780089,0.2874790629472389,3581.0,5395.240823984146,5671.12223815918,5395.240823984146,273.1909189224243,1.843076229095459,0.0 -23300,0.09333815,0.26387832,,,,,,,,,,,,,, -23400,0.07925485,0.30225673,,,,,,,,,,,,,, -23500,0.06396751,0.27383414,,,,,,,,,,,,,, -23554,,,0.7473702430725098,0.2669352974210466,0.7242064124490011,0.2860737523136343,3554.0,0.7415203231901005,0.2872870092916608,3581.0,5475.210396766663,5755.149123430252,5475.210396766663,277.20637464523315,1.872402667999268,0.0 -23600,0.062157787,0.2580322,,,,,,,,,,,,,, -23700,0.07401095,0.28645274,,,,,,,,,,,,,, -23800,0.056536555,0.29270023,,,,,,,,,,,,,, -23899,,,0.7486155373709542,0.266871520451137,0.7250420823148214,0.2863125862804234,3554.0,0.7424447987206786,0.2875343542197535,3581.0,5555.329708099365,5839.32786488533,5555.329708099365,281.2231953144073,1.902432203292847,0.0 -23900,0.08550762,0.22970301,,,,,,,,,,,,,, -24000,0.03764609,0.19006108,,,,,,,,,,,,,, -24100,0.03878699,0.35012442,,,,,,,,,,,,,, -24200,0.08108686,0.27570567,,,,,,,,,,,,,, -24245,,,0.7487003462655204,0.2665752513068063,0.7254798043357836,0.2859320009968961,3554.0,0.7427145055937587,0.287214128440467,3581.0,5635.302527427673,5923.360556364059,5635.302527427673,285.23908710479736,1.933682203292847,0.0 -24300,0.031438828,0.27117777,,,,,,,,,,,,,, -24400,0.05864675,0.33135957,,,,,,,,,,,,,, -24500,0.08228488,0.3078711,,,,,,,,,,,,,, -24594,,,0.7478156770978656,0.2664803436824253,0.7245771572304094,0.2856176716595737,3554.0,0.7419074302743647,0.2868927436579342,3581.0,5715.485833406448,6007.602685213089,5715.485833406448,289.2548222541809,1.96408486366272,0.0 -24600,0.09633943,0.17490777,,,,,,,,,,,,,, -24700,0.058868274,0.2329806,,,,,,,,,,,,,, -24800,0.044393945,0.32092717,,,,,,,,,,,,,, -24900,0.087238334,0.31741485,,,,,,,,,,,,,, -24938,,,0.7481376784188407,0.2666981220245361,0.7249912483073649,0.2858955241618159,3554.0,0.7421870909400308,0.2872251048829761,3581.0,5795.482012271881,6091.655306100845,5795.482012271881,293.2680079936981,1.994885921478272,0.0 -25000,0.08493415,0.25731462,,,,,,,,,,,,,, -25100,0.050887775,0.23780793,,,,,,,,,,,,,, -25200,0.057793967,0.23482178,,,,,,,,,,,,,, -25284,,,0.7477162906101772,0.2664295264652797,0.7241976882342079,0.2860070498524989,3554.0,0.7414855530927116,0.2873365737246056,3581.0,5875.4409856796265,6175.679006099701,5875.4409856796265,297.2859468460083,2.02933931350708,0.0 -25300,0.07912642,0.29445264,,,,,,,,,,,,,, -25400,0.058962956,0.23222712,,,,,,,,,,,,,, -25500,0.11207486,0.25636294,,,,,,,,,,,,,, -25600,0.07754209,0.22361505,,,,,,,,,,,,,, -25631,,,0.7501141003199986,0.2663685594286237,0.7267898104468556,0.2858488633515932,3554.0,0.7439213006798031,0.2871733928851926,3581.0,5955.660590648651,6259.956165790558,5955.660590648651,301.29969024658203,2.0605125427246094,0.0 -25700,0.040285986,0.2712295,,,,,,,,,,,,,, -25800,0.10266327,0.21808949,,,,,,,,,,,,,, -25900,0.06498948,0.28319514,,,,,,,,,,,,,, -25976,,,0.748063496180943,0.2665225608008248,0.7249627400464266,0.2857278578054656,3554.0,0.7422536313617006,0.286973158030229,3581.0,6035.690726995468,6344.0462028980255,6035.690726995468,305.313759803772,2.093826532363892,0.0 -26000,0.0450039,0.36188117,,,,,,,,,,,,,, -26100,0.057000887,0.2932588,,,,,,,,,,,,,, -26200,0.061011665,0.2901888,,,,,,,,,,,,,, -26300,0.06208024,0.24522327,,,,,,,,,,,,,, -26325,,,0.7494532721383231,0.2655692952019827,0.7256950932268219,0.285317854057488,3554.0,0.742955441915666,0.2866083447142907,3581.0,6115.706563234329,6428.123630523682,6115.706563234329,309.33096265792847,2.125617742538452,0.0 -26400,0.05635009,0.33036166,,,,,,,,,,,,,, -26500,0.028892746,0.2847387,,,,,,,,,,,,,, -26600,0.057189543,0.3150211,,,,,,,,,,,,,, -26674,,,0.7486578396388462,0.2659331730433872,0.7254540438590321,0.2852956485265282,3554.0,0.7426701225870916,0.2866072879760367,3581.0,6195.682910203934,6512.158713579178,6195.682910203934,313.3459963798523,2.156813144683838,0.0 -26700,0.098296955,0.24832153,,,,,,,,,,,,,, -26800,0.06241953,0.22293797,,,,,,,,,,,,,, -26900,0.078414716,0.2508246,,,,,,,,,,,,,, -27000,0.050284844,0.25490013,,,,,,,,,,,,,, -27018,,,0.7490175792149135,0.2657155820301601,0.725519784595702,0.2851643903106535,3554.0,0.7428295877984502,0.2864416186884774,3581.0,6275.682108402252,6596.222285270691,6275.682108402252,317.3675000667572,2.1871368885040283,0.0 -27100,0.07302994,0.2917225,,,,,,,,,,,,,, -27200,0.062877126,0.26696128,,,,,,,,,,,,,, -27300,0.08196724,0.2502623,,,,,,,,,,,,,, -27364,,,0.7507955006190709,0.2653354576655796,0.7269669738323016,0.2853771203276677,3554.0,0.744054381523143,0.2866985083491867,3581.0,6355.869081735611,6680.471506595612,6355.869081735611,321.38672494888306,2.217285633087158,0.0 -27400,0.050553292,0.24884851,,,,,,,,,,,,,, -27500,0.0856486,0.36643678,,,,,,,,,,,,,, -27600,0.039120167,0.28391486,,,,,,,,,,,,,, -27700,0.07839238,0.30656433,,,,,,,,,,,,,, -27711,,,0.7491722106933594,0.2656243188040597,0.7254776748030388,0.2852912692454804,3554.0,0.7427899771580914,0.2865437473274748,3581.0,6435.99257683754,6764.661067485809,6435.99257683754,325.39993500709534,2.257634401321411,0.0 -27800,0.0640149,0.26750726,,,,,,,,,,,,,, -27900,0.06256104,0.25887835,,,,,,,,,,,,,, -28000,0.055878505,0.17756213,,,,,,,,,,,,,, -28058,,,0.7488178525652204,0.265708327293396,0.7253711294711944,0.2851588260476752,3554.0,0.7426109452448687,0.2864529360142941,3581.0,6516.111511468887,6848.840639352799,6516.111511468887,329.4172194004059,2.288557529449463,0.0 -28100,0.05002763,0.23908295,,,,,,,,,,,,,, -28200,0.061897464,0.31047514,,,,,,,,,,,,,, -28300,0.050570965,0.30969596,,,,,,,,,,,,,, -28400,0.056189053,0.3225668,,,,,,,,,,,,,, -28403,,,0.750096184866769,0.2649412836347307,0.7260181639525887,0.2850630829424152,3554.0,0.7432273986185772,0.2863629769093828,3581.0,6596.127126693726,6932.916352748871,6596.127126693726,333.43426752090454,2.319133281707764,0.0 -28500,0.058662713,0.28796563,,,,,,,,,,,,,, -28600,0.08506476,0.31752923,,,,,,,,,,,,,, -28700,0.040571507,0.2919524,,,,,,,,,,,,,, -28752,,,0.7498783384050641,0.2653285094669887,0.7260880263655388,0.2851461862404157,3554.0,0.7432734178651215,0.2864265857346063,3581.0,6676.110071659088,7016.962255001068,6676.110071659088,337.45313835144043,2.350611448287964,0.0 -28800,0.046778157,0.23057117,,,,,,,,,,,,,, -28900,0.039739285,0.24967624,,,,,,,,,,,,,, -29000,0.030941341,0.29096088,,,,,,,,,,,,,, -29100,,,0.749572617667062,0.2654174736567906,0.726090087203679,0.2851238433202465,3554.0,0.7432643503691357,0.2863883045391476,3581.0,6756.167396783829,7101.085487127304,6756.167396783829,341.47138237953186,2.385899305343628,0.0 -29100,0.05823388,0.2383569,,,,,,,,,,,,,, -29200,0.057964284,0.24674347,,,,,,,,,,,,,, -29300,0.03982572,0.27691227,,,,,,,,,,,,,, -29400,0.057665907,0.19740365,,,,,,,,,,,,,, -29446,,,0.7507331030709403,0.2646945033754621,0.7265406551157146,0.2850275850054516,3554.0,0.7437386554035186,0.2862946979828434,3581.0,6836.258863449097,7185.242262125015,6836.258863449097,345.49227356910706,2.417950391769409,0.0 -29500,0.05920654,0.22455749,,,,,,,,,,,,,, -29600,0.05203885,0.26177534,,,,,,,,,,,,,, -29700,0.039252643,0.30365443,,,,,,,,,,,,,, -29793,,,0.7498508180890765,0.2650007350104196,0.7257677034239589,0.2850495672789462,3554.0,0.7430140738445965,0.2863362857463872,3581.0,6916.343825817108,7269.396333932877,6916.343825817108,349.5119769573212,2.4548652172088623,0.0 -29800,0.031529985,0.25728625,,,,,,,,,,,,,, -29900,0.042691763,0.23995598,,,,,,,,,,,,,, -30000,0.028626889,0.2951146,,,,,,,,,,,,,, -30100,0.028501851,0.24961142,,,,,,,,,,,,,, -30142,,,0.7498049054827008,0.2650401251656668,0.7257995777205262,0.2849661376815736,3554.0,0.7429910301329936,0.2862814035338418,3581.0,6996.4826691150665,7353.598170042038,6996.4826691150665,353.5303068161011,2.487074136734009,0.0 -30200,0.057225857,0.2565717,,,,,,,,,,,,,, -30300,0.07726614,0.18032922,,,,,,,,,,,,,, -30400,0.04719271,0.25408316,,,,,,,,,,,,,, -30488,,,0.7510360990251813,0.2646484034402029,0.7266652671285875,0.285057432811181,3554.0,0.7437524952658127,0.286356636479859,3581.0,7076.474865198135,7437.652813434601,7076.474865198135,357.5478813648224,2.5196094512939453,0.0 -30500,0.04229012,0.21700019,,,,,,,,,,,,,, -30600,0.03340116,0.34414953,,,,,,,,,,,,,, -30700,0.026871508,0.28583205,,,,,,,,,,,,,, -30800,0.042251,0.32346863,,,,,,,,,,,,,, -30836,,,0.7507091249738421,0.2647977897099086,0.7265978090268008,0.2849529483174768,3554.0,0.7437485410194429,0.2862277144128735,3581.0,7156.552713394165,7521.801300287247,7156.552713394165,361.5675451755524,2.557867288589477,0.0 -30900,0.030266728,0.24815905,,,,,,,,,,,,,, -31000,0.03444581,0.29074383,,,,,,,,,,,,,, -31100,0.0408844,0.20346802,,,,,,,,,,,,,, -31182,,,0.750427109854562,0.264832649912153,0.7264883098269556,0.2848840132816896,3554.0,0.743651730160046,0.2861480499838558,3581.0,7236.685940265655,7605.993559360504,7236.685940265655,365.5828056335449,2.589299201965332,0.0 -31200,0.027888931,0.19065088,,,,,,,,,,,,,, -31300,0.032615513,0.3061189,,,,,,,,,,,,,, -31400,0.04112756,0.1899925,,,,,,,,,,,,,, -31500,0.025253113,0.32651722,,,,,,,,,,,,,, -31529,,,0.7505290167672294,0.2645143951688494,0.7261004600889842,0.2849139813029773,3554.0,0.7432802355312762,0.286220555863411,3581.0,7316.703558444977,7690.068843841553,7316.703558444977,369.5958352088928,2.6212565898895264,0.0 -31600,0.0298616,0.22463378,,,,,,,,,,,,,, -31700,0.029527005,0.23015684,,,,,,,,,,,,,, -31800,0.024275927,0.2443547,,,,,,,,,,,,,, -31876,,,0.7509726115635463,0.2645551477159772,0.7267222149558596,0.2848580295474729,3554.0,0.7438690091803966,0.2861240858873219,3581.0,7396.760830163956,7774.188344955444,7396.760830163956,373.6141917705536,2.6523241996765137,0.0 -31900,0.024925921,0.2981467,,,,,,,,,,,,,, -32000,0.030342013,0.30398396,,,,,,,,,,,,,, -32100,0.036220234,0.31162372,,,,,,,,,,,,,, -32200,0.034338474,0.21966161,,,,,,,,,,,,,, -32223,,,0.7508950233459473,0.2646592685154506,0.7267572492042417,0.2848581497630311,3554.0,0.7439353450720818,0.2861133821514591,3581.0,7476.919148206711,7858.408213853836,7476.919148206711,377.6309578418732,2.6843929290771484,0.0 -32300,0.027587531,0.27683488,,,,,,,,,,,,,, -32400,0.027351081,0.26065078,,,,,,,,,,,,,, -32500,0.021213034,0.30587626,,,,,,,,,,,,,, -32569,,,0.750824110848563,0.26461204460689,0.7265723920230726,0.284898696753438,3554.0,0.7437707666111072,0.2861683666289968,3581.0,7557.069484472275,7942.621923923492,7557.069484472275,381.6490135192871,2.7174501419067383,0.0 -32600,0.024229337,0.2686624,,,,,,,,,,,,,, -32700,0.02000435,0.26060283,,,,,,,,,,,,,, -32800,0.02448691,0.23063202,,,,,,,,,,,,,, -32900,0.023210013,0.24613607,,,,,,,,,,,,,, -32918,,,0.7510066713605609,0.2643741369247436,0.7264729222355093,0.2848729706239888,3554.0,0.7436541845198618,0.2862069887077632,3581.0,7637.059919595718,8026.681388616562,7637.059919595718,385.6676907539368,2.7551095485687256,0.0 -33000,0.030608084,0.3221043,,,,,,,,,,,,,, -33100,0.023903633,0.25575352,,,,,,,,,,,,,, -33200,0.019006312,0.32951614,,,,,,,,,,,,,, -33264,,,0.7508365086146763,0.2644939933504377,0.7264935993115152,0.2848100635397615,3554.0,0.7436945451034976,0.286068385554838,3581.0,7717.092004537582,8110.772299528122,7717.092004537582,389.682653427124,2.786450147628784,0.0 -33300,0.030595185,0.19720757,,,,,,,,,,,,,, -33400,0.037068065,0.26650384,,,,,,,,,,,,,, -33500,0.025048178,0.25287876,,,,,,,,,,,,,, -33600,0.03276836,0.2851707,,,,,,,,,,,,,, -33610,,,0.7514575549534389,0.264496956552778,0.7271822627233399,0.2849336107862619,3554.0,0.7443005674479893,0.2862118974273945,3581.0,7797.218616485596,8194.963827133179,7797.218616485596,393.70337677001953,2.818121194839477,0.0 -33700,0.015985442,0.24527454,,,,,,,,,,,,,, -33800,0.02317194,0.32647616,,,,,,,,,,,,,, -33900,0.018504279,0.21318631,,,,,,,,,,,,,, -33958,,,0.7513347353254046,0.2641548088618687,0.7267312139490715,0.2847540774369548,3554.0,0.7438889167655682,0.286054988840844,3581.0,7877.383341789246,8279.188350439072,7877.383341789246,397.71788907051086,2.850908041000366,0.0 -34000,0.016744828,0.2619697,,,,,,,,,,,,,, -34100,0.020727392,0.2226749,,,,,,,,,,,,,, -34200,0.020913996,0.2309685,,,,,,,,,,,,,, -34300,0.012468062,0.24961865,,,,,,,,,,,,,, -34304,,,0.7511395045689174,0.2642734391348703,0.7266988587902715,0.2847594871370726,3554.0,0.7438610325109956,0.2860153782004852,3581.0,7957.451696395874,8363.3182117939,7957.451696395874,401.73359990119934,2.884448528289795,0.0 -34400,0.021567028,0.1862854,,,,,,,,,,,,,, -34500,0.03095952,0.24821544,,,,,,,,,,,,,, -34600,0.020359335,0.22381729,,,,,,,,,,,,,, -34647,,,0.7512046950204032,0.2642475877489362,0.7267843835730866,0.2847093915966165,3554.0,0.7439592750802848,0.2859457016523841,3581.0,8037.453133821487,8447.382189512253,8037.453133821487,405.75045943260193,2.918001174926758,0.0 -34700,0.038502753,0.39594924,,,,,,,,,,,,,, -34800,0.03363175,0.34000984,,,,,,,,,,,,,, -34900,0.05114074,0.24509946,,,,,,,,,,,,,, -34993,,,0.7513880729675293,0.2640795026506696,0.72677325504713,0.2847361824924381,3554.0,0.7439256639861421,0.2859992203316985,3581.0,8117.535898685455,8531.523807525635,8117.535898685455,409.7646481990814,2.950176954269409,0.0 -35000,0.018812587,0.32404125,,,,,,,,,,,,,, -35100,0.025758257,0.28016725,,,,,,,,,,,,,, -35200,0.015379055,0.2531384,,,,,,,,,,,,,, -35300,0.01486707,0.30214822,,,,,,,,,,,,,, -35342,,,0.751244136265346,0.2640912362507411,0.7266568863868177,0.284680162042329,3554.0,0.7438301484833147,0.2859374863646677,3581.0,8197.736165523529,8615.789492368698,8197.736165523529,413.7841286659241,2.983177423477173,0.0 -35400,0.015985243,0.30878732,,,,,,,,,,,,,, -35500,0.017524008,0.26310048,,,,,,,,,,,,,, -35600,0.014222952,0.19494593,,,,,,,,,,,,,, -35686,,,0.7514269011361259,0.2641090495245797,0.7268992409520962,0.2846793720543754,3554.0,0.7440562904696663,0.2859263735688355,3581.0,8277.73072552681,8699.84346485138,8277.73072552681,417.794353723526,3.0200717449188232,0.0 -35700,0.018684603,0.30958754,,,,,,,,,,,,,, -35800,0.018559337,0.22935358,,,,,,,,,,,,,, -35900,0.021012818,0.28990605,,,,,,,,,,,,,, -36000,0.021774707,0.22334592,,,,,,,,,,,,,, -36033,,,0.7513409342084613,0.2640905209950038,0.7267714689874085,0.2846803166051895,3554.0,0.7439452306880061,0.2859348615631981,3581.0,8357.767586946487,8783.942422628403,8357.767586946487,421.8108897209168,3.0529043674468994,0.0 -36100,0.015165559,0.3962549,,,,,,,,,,,,,, -36189,,,0.751366410936628,0.2640893459320068,0.7267988781346723,0.2846795094435846,3554.0,0.7439701833461324,0.2859344525032288,3581.0,8392.034851789474,8822.267328739166,8392.034851789474,425.8297622203827,3.08566689491272,0.0 -36189,,,,,,,,,,,8392.034851789474,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 6f5d45dca..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,70 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -202.0728747844696,0.0,56.70858883857727,1,0,56.70858883857727,0.9592445371404636,3581,0.2462517153248045,258.78186297416687,0.9522901262555804,0.2341189384460449,0.9615848092598128,3554,0.2230110232858223 -206.38493728637687,0.0288667678833007,136.92155241966248,335,0,136.92155241966248,0.3322951277685179,3581,0.6994074629991622,343.3468608856201,0.3096670082637242,0.704350199018206,0.3305422944437605,3554,0.6813741586284819 -210.41696405410767,0.0710964202880859,217.1231291294098,571,0,217.1231291294098,0.3205922629306583,3581,0.7031136144975216,427.6304380893707,0.2987579277583531,0.7054879324776786,0.3183833494170301,3554,0.6865909642963914 -214.43421697616577,0.1018157005310058,297.2307069301605,811,0,297.2307069301605,0.3115179833670762,3581,0.7188536285255516,511.7935454845429,0.2885885919843401,0.7250401633126395,0.3090903089498628,3554,0.7017843560952448 -218.45574116706848,0.1410191059112548,377.3906149864197,1087,0,377.3906149864197,0.3083762664496648,3581,0.7177934132618333,596.0232887268066,0.2861387729644775,0.7226156507219587,0.3059486642471511,3554,0.7011145836997046 -222.4755780696869,0.16563081741333,457.4902172088623,1437,0,457.4902172088623,0.3039782241015952,3581,0.7253636134066951,680.1792616844177,0.2812515326908656,0.7317293030875069,0.3017433866330191,3554,0.70807458365574 -226.4948856830597,0.1917774677276611,537.457314491272,1783,0,537.457314491272,0.3032157022305222,3581,0.7269661058363586,764.2037088871002,0.2804828711918422,0.7334301131112235,0.3009206657001794,3554,0.70985033918648 -230.51172065734863,0.2170164585113525,617.5902824401855,2127,0,617.5902824401855,0.301231454584526,3581,0.7303830519364354,848.39049243927,0.2784201417650495,0.7369766235351562,0.2990333844787563,3554,0.7133730672086029 -234.5359199047089,0.241142988204956,697.6991181373596,2473,0,697.6991181373596,0.2985777121221377,3581,0.7304407975687657,932.5592963695526,0.276072655405317,0.7371116365705218,0.2965130824752919,3554,0.7131970716314364 -238.551185131073,0.2645330429077148,777.7977502346039,2821,0,777.7977502346039,0.2976082740832693,3581,0.7322251853314368,1016.7080388069152,0.275220513343811,0.7386969157627651,0.2956504843244407,3554,0.7150522381251758 -242.5719120502472,0.288226842880249,857.8765261173248,3168,0,857.8765261173248,0.2965458430780508,3581,0.7309740072387252,1100.8429176807404,0.2744026524680001,0.7375610896519252,0.2948401284204417,3554,0.7134541268421145 -246.5917451381684,0.3134384155273437,938.0704567432404,3515,0,938.0704567432404,0.2958462482655857,3581,0.7319145724614283,1185.0933780670166,0.2737060444695608,0.737964425768171,0.2941061609168366,3554,0.7145502179542417 -250.6148431301117,0.3367536067962646,1018.14515542984,3861,0,1018.14515542984,0.2969245303446139,3581,0.7325709773588034,1269.226101398468,0.2744531631469726,0.7393647602626255,0.2951041217861916,3554,0.7152820902724043 -254.63929557800293,0.3610873222351074,1098.3866775035858,4209,0,1098.3866775035858,0.2944523082435946,3581,0.7348902792079727,1353.5279922485352,0.2721474340983799,0.7418202672685895,0.2926908116295019,3554,0.717668472144063 -258.6568067073822,0.3862354755401611,1178.4214470386505,4558,0,1178.4214470386505,0.2938810900848226,3581,0.7329621068660989,1437.6167466640472,0.2717377458299909,0.7396998405456543,0.2921931535659644,3554,0.7157129428328995 -262.6799116134644,0.4110124111175537,1258.5635569095612,4905,0,1258.5635569095612,0.293165916905194,3581,0.7345451689472214,1521.8184142112732,0.2708273615155901,0.741436277117048,0.2914666394269661,3554,0.7171715353738745 -266.7015905380249,0.4352340698242187,1338.5920433998108,5250,0,1338.5920433998108,0.2957844120335625,3581,0.732739919126117,1605.9042460918429,0.273174592426845,0.7398110117231097,0.2941134768922341,3554,0.7154668100643641 -270.7231616973877,0.4594244956970215,1418.782052278519,5597,0,1418.782052278519,0.293226287338994,3581,0.7340107320973541,1690.151935338974,0.2707802057266235,0.7411487443106515,0.2916773944740961,3554,0.7166218411472988 -274.7461721897125,0.4857480525970459,1498.9853432178495,5945,0,1498.9853432178495,0.2926771584185632,3581,0.734261349505201,1774.4164345264437,0.2705115420477731,0.7413513319832938,0.2911257767985896,3554,0.7167771596484594 -278.763644695282,0.5121357440948486,1579.0458467006683,6289,0,1579.0458467006683,0.2924217004677464,3581,0.7346909988262706,1858.532172679901,0.270162616457258,0.7416772842407227,0.2906290117649128,3554,0.7174957739079206 -282.7860562801361,0.5379180908203125,1659.1377947330477,6634,0,1659.1377947330477,0.2932752381819847,3581,0.7350663795247486,1942.6838665008545,0.2703203133174351,0.7426674025399345,0.2914958346339512,3554,0.717994221959412 -286.80538868904114,0.562300443649292,1739.3163893222809,6983,0,1739.3163893222809,0.2917695225233873,3581,0.7367043919950782,2026.9179275035856,0.2694126026970999,0.7435639926365444,0.2901051810569956,3554,0.7192983890290869 -290.8272798061371,0.5869457721710205,1819.362901210785,7328,0,1819.362901210785,0.2921486870265812,3581,0.7355595694943801,2111.022508382797,0.2700649499893188,0.742255619594029,0.2905095175000879,3554,0.7180247910584905 -294.8489227294922,0.6131947040557861,1899.522413253784,7672,0,1899.522413253784,0.2915839456506562,3581,0.7349674551888439,2195.2415804862976,0.2691801445824759,0.7419414520263672,0.2900247053276238,3554,0.7174110047657569 -298.8732056617737,0.6418027877807617,1979.6026208400729,8023,0,1979.6026208400729,0.2914978385271223,3581,0.7348966878141581,2279.386494874954,0.2693272318158831,0.7412242889404297,0.2899157213373136,3554,0.7173855877620287 -302.89744091033936,0.671701192855835,2059.7492246627808,8370,0,2059.7492246627808,0.2917135494842572,3581,0.7374802424034836,2363.59885597229,0.2692355258124215,0.7443105152675084,0.2901451269696117,3554,0.7201239607880205 -306.92110443115234,0.7052109241485596,2139.943512916565,8714,0,2139.943512916565,0.2926012777942613,3581,0.7357666220154985,2447.8618161678314,0.2703418731689453,0.7421479906354632,0.2912678372410488,3554,0.7183052711293613 -310.9387810230255,0.7300865650177002,2220.011216402054,9062,0,2220.011216402054,0.2916376006832937,3581,0.735274659225775,2531.983940124512,0.269741552216666,0.7411454745701381,0.2902401659551737,3554,0.7177406014789673 -314.9629945755005,0.7551014423370361,2300.221267700196,9411,0,2300.221267700196,0.2926036639774155,3581,0.7360342154120707,2616.25523352623,0.2701402732304164,0.7425169944763184,0.2910379850938203,3554,0.7188875952931556 -318.9835820198059,0.7812681198120117,2380.346992969513,9756,0,2380.346992969513,0.291624919824246,3581,0.7370602741683538,2700.43922996521,0.2689634391239711,0.7441386495317731,0.2901753182483645,3554,0.7195929514939153 -323.007447719574,0.8100852966308594,2460.3542511463165,10102,0,2460.3542511463165,0.2913862333321698,3581,0.7375806666259425,2784.5107316970825,0.2686057601656232,0.7448515210832868,0.2896271696503939,3554,0.7205442343793964 -327.03503346443176,0.834989070892334,2540.4041180610657,10450,0,2540.4041180610657,0.2915610382923764,3581,0.7359422450956437,2868.624881267548,0.2686093364443098,0.743347304207938,0.2899211138637802,3554,0.7185303146542628 -331.0597715377808,0.8598167896270752,2620.4691038131714,10795,0,2620.4691038131714,0.2916025237909278,3581,0.7371557214945197,2952.7511084079742,0.2687500544956752,0.7445851053510394,0.2898923308244231,3554,0.7199884263330051 -335.07978916168213,0.8851232528686523,2700.4993851184845,11141,0,2700.4993851184845,0.2906030880209788,3581,0.735852797315694,3036.8381135463715,0.2680032934461321,0.7431984628949847,0.2890530201445554,3554,0.7183742405124508 -339.10322880744934,0.9145634174346924,2780.4919424057007,11490,0,2780.4919424057007,0.2918942517256876,3581,0.7359646070406312,3120.895471572876,0.2691079378128052,0.7431731224060059,0.2902030708686515,3554,0.7189385666898214 -343.12741708755493,0.9401652812957764,2860.522198200226,11836,0,2860.522198200226,0.2931858244903658,3581,0.7362211558180327,3204.9869713783264,0.2704858439309256,0.7435904230390277,0.2916788027134918,3554,0.7188783902161298 -347.15452122688293,0.9660358428955078,2940.721024751663,12185,0,2940.721024751663,0.2918707989541155,3581,0.7370257086009494,3289.250340938568,0.2688614640917097,0.7444327899387905,0.2902460049965707,3554,0.7195729613639561 -351.1773955821991,0.9935295581817628,3020.9883308410645,12533,0,3020.9883308410645,0.2909331994161896,3581,0.7362679931845155,3373.5796024799347,0.2686043637139456,0.7431977135794503,0.2894434115828995,3554,0.7186979294896595 -355.19680309295654,1.0215466022491455,3101.179794549942,12879,0,3101.179794549942,0.2900932970342607,3581,0.7389217015585731,3457.8299860954285,0.2675271034240722,0.7460365976606097,0.2885580755179023,3554,0.7216288534925436 -359.2215950489044,1.0475833415985107,3181.1878747940063,13224,0,3181.1878747940063,0.2912266658558189,3581,0.7361949759799986,3541.900530576706,0.2689677136284964,0.7422705377851214,0.289795334042628,3554,0.7188009027020611 -363.24598574638367,1.077925205230713,3261.1537551879883,13570,0,3261.1537551879883,0.2914237645843514,3581,0.7378128081585102,3625.932768821716,0.2687478576387678,0.7451451846531459,0.2898253535848691,3554,0.7206711133142234 -367.2687175273895,1.1079049110412598,3341.286249399185,13916,0,3341.286249399185,0.2901630417590233,3581,0.7366030132993577,3710.1297755241394,0.2678613662719726,0.7433956691196987,0.2888582365930114,3554,0.718911088847953 -371.2918519973755,1.134298324584961,3421.35107922554,14260,0,3421.35107922554,0.2925438048685772,3581,0.7367140049043563,3794.255617141724,0.2692598274775913,0.7445827892848423,0.2909322984445343,3554,0.71959302018852 -375.3122110366821,1.1605091094970703,3501.3453781604767,14605,0,3501.3453781604767,0.2902641136597668,3581,0.7381265571549497,3878.308022737503,0.2677821431841169,0.7452515874590192,0.2887048758880838,3554,0.7207931836267234 -379.3365240097046,1.1882827281951904,3581.5688560009003,14952,0,3581.5688560009003,0.2898208290063879,3581,0.7366066948390813,3962.595205783844,0.2674763543265206,0.7432610648018974,0.2883870259522721,3554,0.7191397044922974 -383.3614344596863,1.214857578277588,3661.5406353473663,15296,0,3661.5406353473663,0.2909001337353393,3581,0.7383170427473122,4046.630147695541,0.2680352415357317,0.7455353736877441,0.2894263409736388,3554,0.7209673244495639 -387.38481283187866,1.240983963012695,3741.661069869995,15642,0,3741.661069869995,0.2907048416883377,3581,0.7376256632225635,4130.811549901962,0.26785751751491,0.7450480461120605,0.2891894132821292,3554,0.7202514579742896 -391.41095185279846,1.268481731414795,3821.717903852463,15989,0,3821.717903852463,0.2903537318813704,3581,0.7378837118865191,4214.933522701263,0.2676250594002859,0.7453368050711495,0.2887504204109806,3554,0.7206592291476154 -395.4308760166168,1.295414924621582,3901.7154943943024,16332,0,3901.7154943943024,0.289822158451288,3581,0.7379716597799149,4298.989427089691,0.2670436246054513,0.7455001558576312,0.2883993222865081,3554,0.7205200538785523 -399.4510545730591,1.3222932815551758,3981.8962936401367,16679,0,3981.8962936401367,0.2899093564014067,3581,0.7388266632923765,4383.229125022888,0.2674624579293387,0.7455932753426688,0.2885471530757597,3554,0.721336626644274 -403.477201461792,1.3522214889526367,4062.037636041641,17027,0,4062.037636041641,0.2900039515193032,3581,0.738394286904845,4467.438356161118,0.2673143659319196,0.7456743376595634,0.2884896556916502,3554,0.7211200325557471 -407.501473903656,1.3806486129760742,4142.102548122406,17371,0,4142.102548122406,0.2897704123651738,3581,0.7371324050762706,4551.567642211914,0.2674683843340192,0.7442355837140765,0.2883284638017902,3554,0.7196316952509496 -411.522497177124,1.4090490341186523,4222.079712629318,17718,0,4222.079712629318,0.2898624508582623,3581,0.7374111794453365,4635.606120109558,0.2668800013405936,0.7450887135096959,0.2883496560873311,3554,0.7200384360052055 -415.54224705696106,1.436967849731445,4302.183186531067,18063,0,4302.183186531067,0.2905329683245776,3581,0.7389545627094387,4719.768888950348,0.267510039465768,0.7464371408735003,0.2890953360210326,3554,0.7215637997019204 -419.5635986328125,1.4638168811798096,4382.376043796539,18409,0,4382.376043796539,0.2908939978358,3581,0.7378443739528064,4804.021505832672,0.2682968207768031,0.7450125558035714,0.2893638632306907,3554,0.7206625264886396 -423.5873386859894,1.4927854537963867,4462.424655199051,18755,0,4462.424655199051,0.2900614585333182,3581,0.7393603501989667,4888.134497642517,0.2671412229537964,0.7469017846243722,0.28858466032991,3554,0.7221098531144485 -427.6101665496826,1.521597385406494,4542.533133029938,19102,0,4542.533133029938,0.289944228763788,3581,0.737858145638439,4972.306104183197,0.2671962295259748,0.7452338082449776,0.2884972807927687,3554,0.7205666288205191 -431.6334185600281,1.5527558326721191,4622.55202126503,19447,0,4622.55202126503,0.2891450960254642,3581,0.7391323674427535,5056.391085863113,0.2665221009935651,0.7461850302559989,0.2877850207842395,3554,0.7216126415658413 -435.6535060405731,1.5806972980499268,4702.551931381226,19792,0,4702.551931381226,0.2899258210651703,3581,0.7397317084744136,5140.450572013855,0.2668184552873884,0.7475410188947406,0.2883770995818971,3554,0.7225712060794176 -439.6756844520569,1.6097443103790283,4782.750027894974,20138,0,4782.750027894974,0.2895012849937168,3581,0.7388371624982547,5224.711462259293,0.2669763224465506,0.7457947049822126,0.2879447013927969,3554,0.7215591284688028 -443.6998512744904,1.637216329574585,4862.860502004623,20483,0,4862.860502004623,0.2897698328635507,3581,0.7379801136859466,5308.885147809982,0.2667371375220163,0.7456110545567104,0.2882838466560565,3554,0.7206008387336452 -447.72234869003296,1.6652333736419678,4942.934024333954,20829,0,4942.934024333954,0.2891209273989458,3581,0.7407019305448896,5393.02094745636,0.2659840413502284,0.7483904702322823,0.2876581590230638,3554,0.7234748149092571 -451.7472302913666,1.6936416625976562,5022.919233322144,21176,0,5022.919233322144,0.2900835477716594,3581,0.738236389756702,5477.071233034134,0.2673719440187727,0.7456222942897252,0.2885616132900429,3554,0.7210958520549029 -455.7704327106476,1.7223103046417236,5103.048098325729,21522,0,5103.048098325729,0.2903546181779705,3581,0.7384746671888089,5561.263752222061,0.2674393142972673,0.7460637092590332,0.2888864013809264,3554,0.7211438695835678 -459.7928726673126,1.7547547817230225,5183.179874658585,21868,0,5183.179874658585,0.2892208062081122,3581,0.7384185577963558,5645.4620842933655,0.2662193434579031,0.7460754939488002,0.2877452122608329,3554,0.7210180210678109 -463.8135120868683,1.7877426147460938,5263.278465986252,22204,0,5263.278465986252,0.2908094587754817,3581,0.7386037937857791,5729.6257474422455,0.2678557123456682,0.7462743350437709,0.2893444913521736,3554,0.7213033784556134 -467.8380720615387,1.8198356628417969,5343.419144153595,22551,0,5343.419144153595,0.2890804986386484,3581,0.7401515403562203,5813.834905862808,0.2660458598818098,0.7475292342049735,0.2876492974190612,3554,0.7228142475907429 -471.8632171154022,1.8486895561218264,5423.412467479706,22892,0,5423.412467479706,0.2890605569651459,3581,0.7391687737800196,5897.893646717072,0.266058257647923,0.7466079848153251,0.2875659193426421,3554,0.7219377731297482 -475.89069175720215,1.8772878646850586,5503.533691883087,23240,0,5503.533691883087,0.2887749990182561,3581,0.7411558507574699,5982.082504987717,0.2657348428453718,0.7483805247715541,0.28723769652152503,3554,0.7238905546567248 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/measurements.csv deleted file mode 100644 index b2c2ae267..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/measurements.csv +++ /dev/null @@ -1,304 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.120433,0.9916711,,,,,,,,,,,,,, -1,,,0.2341189384460449,0.9522901262555804,0.2230110232858223,0.9615848092598128,3554.0,0.2462517153248045,0.9592445371404636,3581.0,56.70858883857727,258.78186297416687,56.70858883857727,202.0728747844696,0.0,0.0 -100,0.2260157,0.30417052,,,,,,,,,,,,,, -200,0.3057949,0.45114797,,,,,,,,,,,,,, -300,0.1304007,0.31970757,,,,,,,,,,,,,, -335,,,0.704350199018206,0.3096670082637242,0.6813741586284819,0.3305422944437605,3554.0,0.6994074629991622,0.3322951277685179,3581.0,136.92155241966248,343.3468608856201,136.92155241966248,206.38493728637687,0.0288667678833007,0.0 -400,0.093983255,0.3136414,,,,,,,,,,,,,, -500,0.121870786,0.259951,,,,,,,,,,,,,, -571,,,0.7054879324776786,0.2987579277583531,0.6865909642963914,0.3183833494170301,3554.0,0.7031136144975216,0.3205922629306583,3581.0,217.1231291294098,427.6304380893707,217.1231291294098,210.41696405410767,0.0710964202880859,0.0 -600,0.06425653,0.31614986,,,,,,,,,,,,,, -700,0.18043934,0.3070223,,,,,,,,,,,,,, -800,0.08235711,0.3033262,,,,,,,,,,,,,, -811,,,0.7250401633126395,0.2885885919843401,0.7017843560952448,0.3090903089498628,3554.0,0.7188536285255516,0.3115179833670762,3581.0,297.2307069301605,511.7935454845429,297.2307069301605,214.43421697616577,0.1018157005310058,0.0 -900,0.19042718,0.33958018,,,,,,,,,,,,,, -1000,0.34950885,0.31845,,,,,,,,,,,,,, -1087,,,0.7226156507219587,0.2861387729644775,0.7011145836997046,0.3059486642471511,3554.0,0.7177934132618333,0.3083762664496648,3581.0,377.3906149864197,596.0232887268066,377.3906149864197,218.45574116706848,0.1410191059112548,0.0 -1100,0.108899735,0.27233565,,,,,,,,,,,,,, -1200,0.082797974,0.24477963,,,,,,,,,,,,,, -1300,0.11520713,0.30520785,,,,,,,,,,,,,, -1400,0.13092223,0.2566104,,,,,,,,,,,,,, -1437,,,0.7317293030875069,0.2812515326908656,0.70807458365574,0.3017433866330191,3554.0,0.7253636134066951,0.3039782241015952,3581.0,457.4902172088623,680.1792616844177,457.4902172088623,222.4755780696869,0.16563081741333,0.0 -1500,0.26301327,0.2899595,,,,,,,,,,,,,, -1600,0.42776218,0.21218824,,,,,,,,,,,,,, -1700,0.051003125,0.32172066,,,,,,,,,,,,,, -1783,,,0.7334301131112235,0.2804828711918422,0.70985033918648,0.3009206657001794,3554.0,0.7269661058363586,0.3032157022305222,3581.0,537.457314491272,764.2037088871002,537.457314491272,226.4948856830597,0.1917774677276611,0.0 -1800,0.08990325,0.35175636,,,,,,,,,,,,,, -1900,0.17481413,0.44932994,,,,,,,,,,,,,, -2000,0.14514934,0.3013752,,,,,,,,,,,,,, -2100,0.19226052,0.45327857,,,,,,,,,,,,,, -2127,,,0.7369766235351562,0.2784201417650495,0.7133730672086029,0.2990333844787563,3554.0,0.7303830519364354,0.301231454584526,3581.0,617.5902824401855,848.39049243927,617.5902824401855,230.51172065734863,0.2170164585113525,0.0 -2200,0.113191664,0.27020907,,,,,,,,,,,,,, -2300,0.2526721,0.29965746,,,,,,,,,,,,,, -2400,0.12523839,0.23183814,,,,,,,,,,,,,, -2473,,,0.7371116365705218,0.276072655405317,0.7131970716314364,0.2965130824752919,3554.0,0.7304407975687657,0.2985777121221377,3581.0,697.6991181373596,932.5592963695526,697.6991181373596,234.5359199047089,0.241142988204956,0.0 -2500,0.23385353,0.2973778,,,,,,,,,,,,,, -2600,0.10319205,0.34898368,,,,,,,,,,,,,, -2700,0.3338454,0.24729088,,,,,,,,,,,,,, -2800,0.17422299,0.3177299,,,,,,,,,,,,,, -2821,,,0.7386969157627651,0.275220513343811,0.7150522381251758,0.2956504843244407,3554.0,0.7322251853314368,0.2976082740832693,3581.0,777.7977502346039,1016.7080388069152,777.7977502346039,238.551185131073,0.2645330429077148,0.0 -2900,0.103763014,0.35907876,,,,,,,,,,,,,, -3000,0.19075985,0.25439325,,,,,,,,,,,,,, -3100,0.11040003,0.25486052,,,,,,,,,,,,,, -3168,,,0.7375610896519252,0.2744026524680001,0.7134541268421145,0.2948401284204417,3554.0,0.7309740072387252,0.2965458430780508,3581.0,857.8765261173248,1100.8429176807404,857.8765261173248,242.5719120502472,0.288226842880249,0.0 -3200,0.32628986,0.23526095,,,,,,,,,,,,,, -3300,0.12668951,0.33054867,,,,,,,,,,,,,, -3400,0.099842176,0.26570937,,,,,,,,,,,,,, -3500,0.1785915,0.26609603,,,,,,,,,,,,,, -3515,,,0.737964425768171,0.2737060444695608,0.7145502179542417,0.2941061609168366,3554.0,0.7319145724614283,0.2958462482655857,3581.0,938.0704567432404,1185.0933780670166,938.0704567432404,246.5917451381684,0.3134384155273437,0.0 -3600,0.3345189,0.3746553,,,,,,,,,,,,,, -3700,0.30097705,0.24743022,,,,,,,,,,,,,, -3800,0.20755535,0.342618,,,,,,,,,,,,,, -3861,,,0.7393647602626255,0.2744531631469726,0.7152820902724043,0.2951041217861916,3554.0,0.7325709773588034,0.2969245303446139,3581.0,1018.14515542984,1269.226101398468,1018.14515542984,250.6148431301117,0.3367536067962646,0.0 -3900,0.0728016,0.37469873,,,,,,,,,,,,,, -4000,0.19362552,0.21937045,,,,,,,,,,,,,, -4100,0.08694752,0.29905236,,,,,,,,,,,,,, -4200,0.12341475,0.24397571,,,,,,,,,,,,,, -4209,,,0.7418202672685895,0.2721474340983799,0.717668472144063,0.2926908116295019,3554.0,0.7348902792079727,0.2944523082435946,3581.0,1098.3866775035858,1353.5279922485352,1098.3866775035858,254.63929557800293,0.3610873222351074,0.0 -4300,0.10701732,0.3504095,,,,,,,,,,,,,, -4400,0.34308878,0.31862468,,,,,,,,,,,,,, -4500,0.30828482,0.27637875,,,,,,,,,,,,,, -4558,,,0.7396998405456543,0.2717377458299909,0.7157129428328995,0.2921931535659644,3554.0,0.7329621068660989,0.2938810900848226,3581.0,1178.4214470386505,1437.6167466640472,1178.4214470386505,258.6568067073822,0.3862354755401611,0.0 -4600,0.07061479,0.3062765,,,,,,,,,,,,,, -4700,0.19769286,0.3233564,,,,,,,,,,,,,, -4800,0.09506498,0.29024437,,,,,,,,,,,,,, -4900,0.14452904,0.34957027,,,,,,,,,,,,,, -4905,,,0.741436277117048,0.2708273615155901,0.7171715353738745,0.2914666394269661,3554.0,0.7345451689472214,0.293165916905194,3581.0,1258.5635569095612,1521.8184142112732,1258.5635569095612,262.6799116134644,0.4110124111175537,0.0 -5000,0.09557955,0.36053362,,,,,,,,,,,,,, -5100,0.037322056,0.27001065,,,,,,,,,,,,,, -5200,0.18337926,0.23769791,,,,,,,,,,,,,, -5250,,,0.7398110117231097,0.273174592426845,0.7154668100643641,0.2941134768922341,3554.0,0.732739919126117,0.2957844120335625,3581.0,1338.5920433998108,1605.9042460918429,1338.5920433998108,266.7015905380249,0.4352340698242187,0.0 -5300,0.1368888,0.28743875,,,,,,,,,,,,,, -5400,0.17156288,0.26312557,,,,,,,,,,,,,, -5500,0.082838155,0.23161802,,,,,,,,,,,,,, -5597,,,0.7411487443106515,0.2707802057266235,0.7166218411472988,0.2916773944740961,3554.0,0.7340107320973541,0.293226287338994,3581.0,1418.782052278519,1690.151935338974,1418.782052278519,270.7231616973877,0.4594244956970215,0.0 -5600,0.16247027,0.25808063,,,,,,,,,,,,,, -5700,0.13106297,0.26932222,,,,,,,,,,,,,, -5800,0.09168298,0.30820832,,,,,,,,,,,,,, -5900,0.17389132,0.28343326,,,,,,,,,,,,,, -5945,,,0.7413513319832938,0.2705115420477731,0.7167771596484594,0.2911257767985896,3554.0,0.734261349505201,0.2926771584185632,3581.0,1498.9853432178495,1774.4164345264437,1498.9853432178495,274.7461721897125,0.4857480525970459,0.0 -6000,0.064733066,0.24732111,,,,,,,,,,,,,, -6100,0.18643826,0.30653962,,,,,,,,,,,,,, -6200,0.15941568,0.34455782,,,,,,,,,,,,,, -6289,,,0.7416772842407227,0.270162616457258,0.7174957739079206,0.2906290117649128,3554.0,0.7346909988262706,0.2924217004677464,3581.0,1579.0458467006683,1858.532172679901,1579.0458467006683,278.763644695282,0.5121357440948486,0.0 -6300,0.21127193,0.2324578,,,,,,,,,,,,,, -6400,0.14818242,0.26075616,,,,,,,,,,,,,, -6500,0.21365382,0.27827677,,,,,,,,,,,,,, -6600,0.18852745,0.27983817,,,,,,,,,,,,,, -6634,,,0.7426674025399345,0.2703203133174351,0.717994221959412,0.2914958346339512,3554.0,0.7350663795247486,0.2932752381819847,3581.0,1659.1377947330477,1942.6838665008545,1659.1377947330477,282.7860562801361,0.5379180908203125,0.0 -6700,0.112875864,0.19446781,,,,,,,,,,,,,, -6800,0.19028673,0.34029034,,,,,,,,,,,,,, -6900,0.1263666,0.3899785,,,,,,,,,,,,,, -6983,,,0.7435639926365444,0.2694126026970999,0.7192983890290869,0.2901051810569956,3554.0,0.7367043919950782,0.2917695225233873,3581.0,1739.3163893222809,2026.9179275035856,1739.3163893222809,286.80538868904114,0.562300443649292,0.0 -7000,0.210317,0.29829594,,,,,,,,,,,,,, -7100,0.15670754,0.29184383,,,,,,,,,,,,,, -7200,0.066206574,0.30112594,,,,,,,,,,,,,, -7300,0.13160045,0.25466892,,,,,,,,,,,,,, -7328,,,0.742255619594029,0.2700649499893188,0.7180247910584905,0.2905095175000879,3554.0,0.7355595694943801,0.2921486870265812,3581.0,1819.362901210785,2111.022508382797,1819.362901210785,290.8272798061371,0.5869457721710205,0.0 -7400,0.14619888,0.2634465,,,,,,,,,,,,,, -7500,0.061717514,0.2516152,,,,,,,,,,,,,, -7600,0.19562851,0.2752011,,,,,,,,,,,,,, -7672,,,0.7419414520263672,0.2691801445824759,0.7174110047657569,0.2900247053276238,3554.0,0.7349674551888439,0.2915839456506562,3581.0,1899.522413253784,2195.2415804862976,1899.522413253784,294.8489227294922,0.6131947040557861,0.0 -7700,0.09896945,0.3520224,,,,,,,,,,,,,, -7800,0.06352617,0.28958163,,,,,,,,,,,,,, -7900,0.21356754,0.24577206,,,,,,,,,,,,,, -8000,0.11266577,0.2638012,,,,,,,,,,,,,, -8023,,,0.7412242889404297,0.2693272318158831,0.7173855877620287,0.2899157213373136,3554.0,0.7348966878141581,0.2914978385271223,3581.0,1979.6026208400729,2279.386494874954,1979.6026208400729,298.8732056617737,0.6418027877807617,0.0 -8100,0.42376846,0.24813612,,,,,,,,,,,,,, -8200,0.10592177,0.3230064,,,,,,,,,,,,,, -8300,0.27258033,0.27567244,,,,,,,,,,,,,, -8370,,,0.7443105152675084,0.2692355258124215,0.7201239607880205,0.2901451269696117,3554.0,0.7374802424034836,0.2917135494842572,3581.0,2059.7492246627808,2363.59885597229,2059.7492246627808,302.89744091033936,0.671701192855835,0.0 -8400,0.41717407,0.26086825,,,,,,,,,,,,,, -8500,0.10710214,0.2657528,,,,,,,,,,,,,, -8600,0.38578957,0.24646011,,,,,,,,,,,,,, -8700,0.092426546,0.28139645,,,,,,,,,,,,,, -8714,,,0.7421479906354632,0.2703418731689453,0.7183052711293613,0.2912678372410488,3554.0,0.7357666220154985,0.2926012777942613,3581.0,2139.943512916565,2447.8618161678314,2139.943512916565,306.92110443115234,0.7052109241485596,0.0 -8800,0.13160671,0.23284289,,,,,,,,,,,,,, -8900,0.34556144,0.29902864,,,,,,,,,,,,,, -9000,0.16753481,0.2550075,,,,,,,,,,,,,, -9062,,,0.7411454745701381,0.269741552216666,0.7177406014789673,0.2902401659551737,3554.0,0.735274659225775,0.2916376006832937,3581.0,2220.011216402054,2531.983940124512,2220.011216402054,310.9387810230255,0.7300865650177002,0.0 -9100,0.14643894,0.27252582,,,,,,,,,,,,,, -9200,0.2597925,0.23183188,,,,,,,,,,,,,, -9300,0.107546605,0.33325914,,,,,,,,,,,,,, -9400,0.122480325,0.28157696,,,,,,,,,,,,,, -9411,,,0.7425169944763184,0.2701402732304164,0.7188875952931556,0.2910379850938203,3554.0,0.7360342154120707,0.2926036639774155,3581.0,2300.221267700196,2616.25523352623,2300.221267700196,314.9629945755005,0.7551014423370361,0.0 -9500,0.30337435,0.26865572,,,,,,,,,,,,,, -9600,0.14959958,0.27078837,,,,,,,,,,,,,, -9700,0.07803966,0.31523237,,,,,,,,,,,,,, -9756,,,0.7441386495317731,0.2689634391239711,0.7195929514939153,0.2901753182483645,3554.0,0.7370602741683538,0.291624919824246,3581.0,2380.346992969513,2700.43922996521,2380.346992969513,318.9835820198059,0.7812681198120117,0.0 -9800,0.26083776,0.27313185,,,,,,,,,,,,,, -9900,0.1994056,0.23449436,,,,,,,,,,,,,, -10000,0.18514988,0.39427418,,,,,,,,,,,,,, -10100,0.18171108,0.2952476,,,,,,,,,,,,,, -10102,,,0.7448515210832868,0.2686057601656232,0.7205442343793964,0.2896271696503939,3554.0,0.7375806666259425,0.2913862333321698,3581.0,2460.3542511463165,2784.5107316970825,2460.3542511463165,323.007447719574,0.8100852966308594,0.0 -10200,0.13247807,0.30590272,,,,,,,,,,,,,, -10300,0.16620065,0.26203868,,,,,,,,,,,,,, -10400,0.31388244,0.31767446,,,,,,,,,,,,,, -10450,,,0.743347304207938,0.2686093364443098,0.7185303146542628,0.2899211138637802,3554.0,0.7359422450956437,0.2915610382923764,3581.0,2540.4041180610657,2868.624881267548,2540.4041180610657,327.03503346443176,0.834989070892334,0.0 -10500,0.13108194,0.2747551,,,,,,,,,,,,,, -10600,0.1872198,0.3205393,,,,,,,,,,,,,, -10700,0.35616735,0.27766517,,,,,,,,,,,,,, -10795,,,0.7445851053510394,0.2687500544956752,0.7199884263330051,0.2898923308244231,3554.0,0.7371557214945197,0.2916025237909278,3581.0,2620.4691038131714,2952.7511084079742,2620.4691038131714,331.0597715377808,0.8598167896270752,0.0 -10800,0.100532606,0.23641221,,,,,,,,,,,,,, -10900,0.1308491,0.24804933,,,,,,,,,,,,,, -11000,0.22285734,0.31839994,,,,,,,,,,,,,, -11100,0.1373852,0.3732086,,,,,,,,,,,,,, -11141,,,0.7431984628949847,0.2680032934461321,0.7183742405124508,0.2890530201445554,3554.0,0.735852797315694,0.2906030880209788,3581.0,2700.4993851184845,3036.8381135463715,2700.4993851184845,335.07978916168213,0.8851232528686523,0.0 -11200,0.098097,0.2889493,,,,,,,,,,,,,, -11300,0.26477587,0.21628691,,,,,,,,,,,,,, -11400,0.1526048,0.42092037,,,,,,,,,,,,,, -11490,,,0.7431731224060059,0.2691079378128052,0.7189385666898214,0.2902030708686515,3554.0,0.7359646070406312,0.2918942517256876,3581.0,2780.4919424057007,3120.895471572876,2780.4919424057007,339.10322880744934,0.9145634174346924,0.0 -11500,0.2479141,0.32857203,,,,,,,,,,,,,, -11600,0.10011446,0.25681326,,,,,,,,,,,,,, -11700,0.13575974,0.29561377,,,,,,,,,,,,,, -11800,0.073259935,0.24357319,,,,,,,,,,,,,, -11836,,,0.7435904230390277,0.2704858439309256,0.7188783902161298,0.2916788027134918,3554.0,0.7362211558180327,0.2931858244903658,3581.0,2860.522198200226,3204.9869713783264,2860.522198200226,343.12741708755493,0.9401652812957764,0.0 -11900,0.13166259,0.29546988,,,,,,,,,,,,,, -12000,0.18728079,0.30171797,,,,,,,,,,,,,, -12100,0.24895579,0.24519722,,,,,,,,,,,,,, -12185,,,0.7444327899387905,0.2688614640917097,0.7195729613639561,0.2902460049965707,3554.0,0.7370257086009494,0.2918707989541155,3581.0,2940.721024751663,3289.250340938568,2940.721024751663,347.15452122688293,0.9660358428955078,0.0 -12200,0.12116897,0.2826181,,,,,,,,,,,,,, -12300,0.10659321,0.2798595,,,,,,,,,,,,,, -12400,0.22979918,0.27992868,,,,,,,,,,,,,, -12500,0.060829934,0.31896704,,,,,,,,,,,,,, -12533,,,0.7431977135794503,0.2686043637139456,0.7186979294896595,0.2894434115828995,3554.0,0.7362679931845155,0.2909331994161896,3581.0,3020.9883308410645,3373.5796024799347,3020.9883308410645,351.1773955821991,0.9935295581817628,0.0 -12600,0.2243465,0.40106457,,,,,,,,,,,,,, -12700,0.12531343,0.2184596,,,,,,,,,,,,,, -12800,0.13883206,0.26101387,,,,,,,,,,,,,, -12879,,,0.7460365976606097,0.2675271034240722,0.7216288534925436,0.2885580755179023,3554.0,0.7389217015585731,0.2900932970342607,3581.0,3101.179794549942,3457.8299860954285,3101.179794549942,355.19680309295654,1.0215466022491455,0.0 -12900,0.11340587,0.22057945,,,,,,,,,,,,,, -13000,0.10179449,0.38127446,,,,,,,,,,,,,, -13100,0.18471962,0.28610072,,,,,,,,,,,,,, -13200,0.14373325,0.22833318,,,,,,,,,,,,,, -13224,,,0.7422705377851214,0.2689677136284964,0.7188009027020611,0.289795334042628,3554.0,0.7361949759799986,0.2912266658558189,3581.0,3181.1878747940063,3541.900530576706,3181.1878747940063,359.2215950489044,1.0475833415985107,0.0 -13300,0.2252976,0.28259653,,,,,,,,,,,,,, -13400,0.21846609,0.36474153,,,,,,,,,,,,,, -13500,0.17302966,0.29365847,,,,,,,,,,,,,, -13570,,,0.7451451846531459,0.2687478576387678,0.7206711133142234,0.2898253535848691,3554.0,0.7378128081585102,0.2914237645843514,3581.0,3261.1537551879883,3625.932768821716,3261.1537551879883,363.24598574638367,1.077925205230713,0.0 -13600,0.08661338,0.26986217,,,,,,,,,,,,,, -13700,0.07717413,0.34934437,,,,,,,,,,,,,, -13800,0.17884152,0.26045075,,,,,,,,,,,,,, -13900,0.17608713,0.23927568,,,,,,,,,,,,,, -13916,,,0.7433956691196987,0.2678613662719726,0.718911088847953,0.2888582365930114,3554.0,0.7366030132993577,0.2901630417590233,3581.0,3341.286249399185,3710.1297755241394,3341.286249399185,367.2687175273895,1.1079049110412598,0.0 -14000,0.09319266,0.39698523,,,,,,,,,,,,,, -14100,0.10555782,0.27560195,,,,,,,,,,,,,, -14200,0.31779,0.2211548,,,,,,,,,,,,,, -14260,,,0.7445827892848423,0.2692598274775913,0.71959302018852,0.2909322984445343,3554.0,0.7367140049043563,0.2925438048685772,3581.0,3421.35107922554,3794.255617141724,3421.35107922554,371.2918519973755,1.134298324584961,0.0 -14300,0.18188426,0.35950613,,,,,,,,,,,,,, -14400,0.14079951,0.28853518,,,,,,,,,,,,,, -14500,0.17608008,0.32744354,,,,,,,,,,,,,, -14600,0.102910616,0.24126546,,,,,,,,,,,,,, -14605,,,0.7452515874590192,0.2677821431841169,0.7207931836267234,0.2887048758880838,3554.0,0.7381265571549497,0.2902641136597668,3581.0,3501.3453781604767,3878.308022737503,3501.3453781604767,375.3122110366821,1.1605091094970703,0.0 -14700,0.14730673,0.30192286,,,,,,,,,,,,,, -14800,0.15710348,0.28396037,,,,,,,,,,,,,, -14900,0.33914846,0.20393966,,,,,,,,,,,,,, -14952,,,0.7432610648018974,0.2674763543265206,0.7191397044922974,0.2883870259522721,3554.0,0.7366066948390813,0.2898208290063879,3581.0,3581.5688560009003,3962.595205783844,3581.5688560009003,379.3365240097046,1.1882827281951904,0.0 -15000,0.13484551,0.26247132,,,,,,,,,,,,,, -15100,0.12462779,0.30548882,,,,,,,,,,,,,, -15200,0.35461596,0.26615345,,,,,,,,,,,,,, -15296,,,0.7455353736877441,0.2680352415357317,0.7209673244495639,0.2894263409736388,3554.0,0.7383170427473122,0.2909001337353393,3581.0,3661.5406353473663,4046.630147695541,3661.5406353473663,383.3614344596863,1.214857578277588,0.0 -15300,0.2160832,0.22644086,,,,,,,,,,,,,, -15400,0.18712136,0.29574388,,,,,,,,,,,,,, -15500,0.08982608,0.29401517,,,,,,,,,,,,,, -15600,0.082862765,0.28404722,,,,,,,,,,,,,, -15642,,,0.7450480461120605,0.26785751751491,0.7202514579742896,0.2891894132821292,3554.0,0.7376256632225635,0.2907048416883377,3581.0,3741.661069869995,4130.811549901962,3741.661069869995,387.38481283187866,1.240983963012695,0.0 -15700,0.13080813,0.19353898,,,,,,,,,,,,,, -15800,0.15866478,0.33773905,,,,,,,,,,,,,, -15900,0.1929335,0.24779113,,,,,,,,,,,,,, -15989,,,0.7453368050711495,0.2676250594002859,0.7206592291476154,0.2887504204109806,3554.0,0.7378837118865191,0.2903537318813704,3581.0,3821.717903852463,4214.933522701263,3821.717903852463,391.41095185279846,1.268481731414795,0.0 -16000,0.15105166,0.27234232,,,,,,,,,,,,,, -16100,0.26810405,0.30146766,,,,,,,,,,,,,, -16200,0.2010364,0.22922055,,,,,,,,,,,,,, -16300,0.19048497,0.2626389,,,,,,,,,,,,,, -16332,,,0.7455001558576312,0.2670436246054513,0.7205200538785523,0.2883993222865081,3554.0,0.7379716597799149,0.289822158451288,3581.0,3901.7154943943024,4298.989427089691,3901.7154943943024,395.4308760166168,1.295414924621582,0.0 -16400,0.16262525,0.28185707,,,,,,,,,,,,,, -16500,0.0936368,0.25593668,,,,,,,,,,,,,, -16600,0.09802523,0.24941353,,,,,,,,,,,,,, -16679,,,0.7455932753426688,0.2674624579293387,0.721336626644274,0.2885471530757597,3554.0,0.7388266632923765,0.2899093564014067,3581.0,3981.8962936401367,4383.229125022888,3981.8962936401367,399.4510545730591,1.3222932815551758,0.0 -16700,0.10216367,0.38106564,,,,,,,,,,,,,, -16800,0.08941659,0.22828555,,,,,,,,,,,,,, -16900,0.09643417,0.3496001,,,,,,,,,,,,,, -17000,0.099550076,0.23279679,,,,,,,,,,,,,, -17027,,,0.7456743376595634,0.2673143659319196,0.7211200325557471,0.2884896556916502,3554.0,0.738394286904845,0.2900039515193032,3581.0,4062.037636041641,4467.438356161118,4062.037636041641,403.477201461792,1.3522214889526367,0.0 -17100,0.19409035,0.2846017,,,,,,,,,,,,,, -17200,0.2777493,0.30796918,,,,,,,,,,,,,, -17300,0.13097888,0.2489891,,,,,,,,,,,,,, -17371,,,0.7442355837140765,0.2674683843340192,0.7196316952509496,0.2883284638017902,3554.0,0.7371324050762706,0.2897704123651738,3581.0,4142.102548122406,4551.567642211914,4142.102548122406,407.501473903656,1.3806486129760742,0.0 -17400,0.39410973,0.2491446,,,,,,,,,,,,,, -17500,0.17428325,0.3501829,,,,,,,,,,,,,, -17600,0.38613486,0.35755485,,,,,,,,,,,,,, -17700,0.08067244,0.24323566,,,,,,,,,,,,,, -17718,,,0.7450887135096959,0.2668800013405936,0.7200384360052055,0.2883496560873311,3554.0,0.7374111794453365,0.2898624508582623,3581.0,4222.079712629318,4635.606120109558,4222.079712629318,411.522497177124,1.4090490341186523,0.0 -17800,0.22518027,0.2647225,,,,,,,,,,,,,, -17900,0.10826801,0.28678513,,,,,,,,,,,,,, -18000,0.15426232,0.27841687,,,,,,,,,,,,,, -18063,,,0.7464371408735003,0.267510039465768,0.7215637997019204,0.2890953360210326,3554.0,0.7389545627094387,0.2905329683245776,3581.0,4302.183186531067,4719.768888950348,4302.183186531067,415.54224705696106,1.436967849731445,0.0 -18100,0.10304708,0.23965348,,,,,,,,,,,,,, -18200,0.12345117,0.28186116,,,,,,,,,,,,,, -18300,0.2831268,0.23691307,,,,,,,,,,,,,, -18400,0.22063383,0.2743618,,,,,,,,,,,,,, -18409,,,0.7450125558035714,0.2682968207768031,0.7206625264886396,0.2893638632306907,3554.0,0.7378443739528064,0.2908939978358,3581.0,4382.376043796539,4804.021505832672,4382.376043796539,419.5635986328125,1.4638168811798096,0.0 -18500,0.09315914,0.27970162,,,,,,,,,,,,,, -18600,0.15203933,0.3473646,,,,,,,,,,,,,, -18700,0.41501436,0.21755648,,,,,,,,,,,,,, -18755,,,0.7469017846243722,0.2671412229537964,0.7221098531144485,0.28858466032991,3554.0,0.7393603501989667,0.2900614585333182,3581.0,4462.424655199051,4888.134497642517,4462.424655199051,423.5873386859894,1.4927854537963867,0.0 -18800,0.13915017,0.36103246,,,,,,,,,,,,,, -18900,0.119248055,0.31654274,,,,,,,,,,,,,, -19000,0.1298931,0.31742522,,,,,,,,,,,,,, -19100,0.11898545,0.27395847,,,,,,,,,,,,,, -19102,,,0.7452338082449776,0.2671962295259748,0.7205666288205191,0.2884972807927687,3554.0,0.737858145638439,0.289944228763788,3581.0,4542.533133029938,4972.306104183197,4542.533133029938,427.6101665496826,1.521597385406494,0.0 -19200,0.17284906,0.22876297,,,,,,,,,,,,,, -19300,0.09834536,0.25338218,,,,,,,,,,,,,, -19400,0.064293385,0.27315217,,,,,,,,,,,,,, -19447,,,0.7461850302559989,0.2665221009935651,0.7216126415658413,0.2877850207842395,3554.0,0.7391323674427535,0.2891450960254642,3581.0,4622.55202126503,5056.391085863113,4622.55202126503,431.6334185600281,1.5527558326721191,0.0 -19500,0.13972062,0.37583393,,,,,,,,,,,,,, -19600,0.18311471,0.27494594,,,,,,,,,,,,,, -19700,0.07149999,0.28365082,,,,,,,,,,,,,, -19792,,,0.7475410188947406,0.2668184552873884,0.7225712060794176,0.2883770995818971,3554.0,0.7397317084744136,0.2899258210651703,3581.0,4702.551931381226,5140.450572013855,4702.551931381226,435.6535060405731,1.5806972980499268,0.0 -19800,0.34110394,0.31404433,,,,,,,,,,,,,, -19900,0.17171285,0.2435101,,,,,,,,,,,,,, -20000,0.14301085,0.2398829,,,,,,,,,,,,,, -20100,0.18419065,0.3401602,,,,,,,,,,,,,, -20138,,,0.7457947049822126,0.2669763224465506,0.7215591284688028,0.2879447013927969,3554.0,0.7388371624982547,0.2895012849937168,3581.0,4782.750027894974,5224.711462259293,4782.750027894974,439.6756844520569,1.6097443103790283,0.0 -20200,0.30856544,0.38177392,,,,,,,,,,,,,, -20300,0.12516318,0.24522606,,,,,,,,,,,,,, -20400,0.15363818,0.24358235,,,,,,,,,,,,,, -20483,,,0.7456110545567104,0.2667371375220163,0.7206008387336452,0.2882838466560565,3554.0,0.7379801136859466,0.2897698328635507,3581.0,4862.860502004623,5308.885147809982,4862.860502004623,443.6998512744904,1.637216329574585,0.0 -20500,0.23550588,0.21777293,,,,,,,,,,,,,, -20600,0.23016089,0.30988875,,,,,,,,,,,,,, -20700,0.15020642,0.32781368,,,,,,,,,,,,,, -20800,0.1664838,0.2142416,,,,,,,,,,,,,, -20829,,,0.7483904702322823,0.2659840413502284,0.7234748149092571,0.2876581590230638,3554.0,0.7407019305448896,0.2891209273989458,3581.0,4942.934024333954,5393.02094745636,4942.934024333954,447.72234869003296,1.6652333736419678,0.0 -20900,0.14582297,0.33051473,,,,,,,,,,,,,, -21000,0.24132736,0.32247117,,,,,,,,,,,,,, -21100,0.26961085,0.31039113,,,,,,,,,,,,,, -21176,,,0.7456222942897252,0.2673719440187727,0.7210958520549029,0.2885616132900429,3554.0,0.738236389756702,0.2900835477716594,3581.0,5022.919233322144,5477.071233034134,5022.919233322144,451.7472302913666,1.6936416625976562,0.0 -21200,0.30652863,0.24785826,,,,,,,,,,,,,, -21300,0.10549263,0.32194045,,,,,,,,,,,,,, -21400,0.10069514,0.2594173,,,,,,,,,,,,,, -21500,0.14814502,0.24144258,,,,,,,,,,,,,, -21522,,,0.7460637092590332,0.2674393142972673,0.7211438695835678,0.2888864013809264,3554.0,0.7384746671888089,0.2903546181779705,3581.0,5103.048098325729,5561.263752222061,5103.048098325729,455.7704327106476,1.7223103046417236,0.0 -21600,0.18440545,0.27016655,,,,,,,,,,,,,, -21700,0.17169447,0.25072175,,,,,,,,,,,,,, -21800,0.2660393,0.31816268,,,,,,,,,,,,,, -21868,,,0.7460754939488002,0.2662193434579031,0.7210180210678109,0.2877452122608329,3554.0,0.7384185577963558,0.2892208062081122,3581.0,5183.179874658585,5645.4620842933655,5183.179874658585,459.7928726673126,1.7547547817230225,0.0 -21900,0.17871726,0.34244257,,,,,,,,,,,,,, -22000,0.13668092,0.22988823,,,,,,,,,,,,,, -22100,0.065150164,0.36632705,,,,,,,,,,,,,, -22200,0.1742855,0.26007348,,,,,,,,,,,,,, -22204,,,0.7462743350437709,0.2678557123456682,0.7213033784556134,0.2893444913521736,3554.0,0.7386037937857791,0.2908094587754817,3581.0,5263.278465986252,5729.6257474422455,5263.278465986252,463.8135120868683,1.7877426147460938,0.0 -22300,0.16534303,0.27419302,,,,,,,,,,,,,, -22400,0.28223303,0.3604869,,,,,,,,,,,,,, -22500,0.20367897,0.26146942,,,,,,,,,,,,,, -22551,,,0.7475292342049735,0.2660458598818098,0.7228142475907429,0.2876492974190612,3554.0,0.7401515403562203,0.2890804986386484,3581.0,5343.419144153595,5813.834905862808,5343.419144153595,467.8380720615387,1.8198356628417969,0.0 -22600,0.10114085,0.21668437,,,,,,,,,,,,,, -22700,0.1700131,0.22565019,,,,,,,,,,,,,, -22800,0.2422924,0.2698168,,,,,,,,,,,,,, -22892,,,0.7466079848153251,0.266058257647923,0.7219377731297482,0.2875659193426421,3554.0,0.7391687737800196,0.2890605569651459,3581.0,5423.412467479706,5897.893646717072,5423.412467479706,471.8632171154022,1.8486895561218264,0.0 -22900,0.2211454,0.22714038,,,,,,,,,,,,,, -23000,0.16713116,0.33099383,,,,,,,,,,,,,, -23100,0.10962717,0.28793952,,,,,,,,,,,,,, -23200,0.24277326,0.30780306,,,,,,,,,,,,,, -23240,,,0.7483805247715541,0.2657348428453718,0.7238905546567248,0.287237696521525,3554.0,0.7411558507574699,0.2887749990182561,3581.0,5503.533691883087,5982.082504987717,5503.533691883087,475.89069175720215,1.8772878646850584,0.0 -23240,,,,,,,,,,,5503.533691883087,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index eeba17d94..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,126 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -37.495994091033936,0.0,54.74183368682861,1,0,54.74183368682861,0.0009000000427477,6.912177562713623,10000,92.23791027069092,0.001335299690254,6.911616802215576,0.0011199999134987,6.912059783935547,50000 -55.4818127155304,0.0258848667144775,564.9765136241913,1508,0,564.9765136241913,0.0485000014305114,5.668511390686035,10000,620.535459280014,0.070990115404129,5.393021583557129,0.0666799992322921,5.462131500244141,50000 -73.87850284576416,0.0510666370391845,1074.9497528076172,3013,0,1074.9497528076172,0.1177000030875206,4.772252559661865,10000,1148.9814901351929,0.1768175959587097,4.263920783996582,0.1591600030660629,4.370645523071289,50000 -91.94419121742249,0.0785439014434814,1584.9473168849945,4518,0,1584.9473168849945,0.1901000142097473,4.155796527862549,10000,1677.1219086647034,0.2875677645206451,3.482984781265259,0.2633000016212463,3.6200780868530273,50000 -110.18438458442688,0.1052768230438232,2095.0244052410126,6024,0,2095.0244052410126,0.255700021982193,3.714011907577514,10000,2205.515574455261,0.3698580861091614,2.9841694831848145,0.3408399820327759,3.133474349975586,50000 -128.3986883163452,0.1336238384246826,2605.0299191474915,7529,0,2605.0299191474915,0.3109000027179718,3.379472017288208,10000,2733.8128702640533,0.4408880770206451,2.5516676902771,0.4115200042724609,2.7000417709350586,50000 -146.5477843284607,0.1685779094696045,3115.184098005295,9036,0,3115.184098005295,0.3530000150203705,3.093779325485229,10000,3262.2012915611267,0.4923070669174194,2.256446599960327,0.4586599767208099,2.4324121475219727,50000 -164.86400961875916,0.1949026584625244,3625.3059175014496,10544,0,3625.3059175014496,0.3918000161647796,2.887331008911133,10000,3790.7159888744354,0.5581353306770325,1.9381918907165527,0.5031399726867676,2.2289462089538574,50000 -183.4552013874054,0.2289702892303466,4135.503875255585,12053,0,4135.503875255585,0.4032000303268432,2.825645446777344,10000,4319.589529514313,0.5622209906578064,1.9551576375961304,0.5192399621009827,2.169461488723755,50000 -204.8366265296936,0.2584633827209472,4645.7563996315,13562,0,4645.7563996315,0.4192000329494476,2.732748746871948,10000,4851.302194833756,0.5889070630073547,1.835581421852112,0.5423399806022644,2.049694776535034,50000 -228.2326774597168,0.2825262546539306,5155.886093854904,15073,0,5155.886093854904,0.426000028848648,2.665627956390381,10000,5384.902832508087,0.5912587642669678,1.7979090213775637,0.551800012588501,1.9871643781661987,50000 -251.177282333374,0.3175804615020752,5666.106993675232,16583,0,5666.106993675232,0.4445000290870666,2.6144609451293945,10000,5918.15540766716,0.6018216013908386,1.7547258138656616,0.5679000020027161,1.932205080986023,50000 -276.3594694137573,0.3460025787353515,6176.272500514984,18094,0,6176.272500514984,0.4395000338554382,2.643144369125366,10000,6453.581561326981,0.6030970811843872,1.7894777059555054,0.5586400032043457,1.985974907875061,50000 -300.57800698280334,0.3844051361083984,6686.243889808655,19605,0,6686.243889808655,0.4588000178337097,2.5408897399902344,10000,6987.860915660858,0.6378746628761292,1.6028066873550415,0.5719999670982361,1.902109146118164,50000 -324.5791804790497,0.4113876819610595,7196.342430591583,21117,0,7196.342430591583,0.4389000236988067,2.6041297912597656,10000,7522.037932395935,0.6107900142669678,1.6893270015716553,0.5640599727630615,1.925768256187439,50000 -349.167008638382,0.4475350379943847,7706.454968452454,22629,0,7706.454968452454,0.4581000208854675,2.543734312057495,10000,8056.826657772064,0.6161710619926453,1.6690157651901243,0.5745599865913391,1.873211145401001,50000 -373.81429505348206,0.4732851982116699,8216.385123491287,24126,0,8216.385123491287,0.4648000299930572,2.4927031993865967,10000,8591.48088502884,0.634785532951355,1.5688326358795166,0.5852999687194824,1.8004077672958374,50000 -397.8678450584412,0.5002624988555908,8726.36254477501,25638,0,8726.36254477501,0.4674000144004822,2.484600067138672,10000,9125.590382575989,0.6300820708274841,1.6126160621643066,0.5874999761581421,1.8088067770004272,50000 -421.4299862384796,0.5421044826507568,9236.353283643724,27151,0,9236.353283643724,0.4571000337600708,2.5513482093811035,10000,9659.236914873123,0.6246811151504517,1.6491903066635132,0.5797399878501892,1.853614926338196,50000 -446.6768915653229,0.7980978488922119,9746.11990594864,28663,0,9746.11990594864,0.4710000157356262,2.4751579761505127,10000,10194.557493686676,0.6515266299247742,1.5246480703353882,0.590399980545044,1.8066246509552,50000 -471.7249677181244,0.8278708457946777,10256.066632509232,30176,0,10256.066632509232,0.4713000357151031,2.4669439792633057,10000,10729.63278746605,0.6483777165412903,1.5167217254638672,0.5963599681854248,1.7699456214904783,50000 -497.735536813736,0.8581020832061768,10766.203356981276,31689,0,10766.203356981276,0.4769000113010406,2.4278860092163086,10000,11265.86218070984,0.6477000713348389,1.5327821969985962,0.6036799550056458,1.7479439973831177,50000 -522.9339916706085,0.8919429779052734,11276.440303564072,33203,0,11276.440303564072,0.4764000177383423,2.419763803482056,10000,11801.382081270218,0.6440529227256775,1.5270545482635498,0.6020199656486511,1.7369011640548706,50000 -548.0605285167694,0.9209282398223876,11786.374200820925,34715,0,11786.374200820925,0.4750000238418579,2.4321818351745605,10000,12336.523176193235,0.6390106678009033,1.5380759239196775,0.5995799899101257,1.7401143312454224,50000 -573.6493542194366,0.9530339241027832,12296.36655974388,36228,0,12296.36655974388,0.4812000095844269,2.404179334640503,10000,12872.187220096588,0.6710578799247742,1.4383693933486938,0.6104999780654907,1.7170253992080688,50000 -597.3628311157227,0.9872357845306396,12806.337379455566,37741,0,12806.337379455566,0.4754000306129455,2.427032232284546,10000,13405.956619262695,0.6660754084587097,1.4482871294021606,0.6075599789619446,1.733280897140503,50000 -620.9647953510284,1.0185627937316897,13316.273628473282,39255,0,13316.273628473282,0.4769000113010406,2.424974203109741,10000,13939.578382253649,0.6478993892669678,1.5248180627822876,0.5951799750328064,1.7707302570343018,50000 -644.1830842494965,1.0521259307861328,13826.47026848793,40769,0,13826.47026848793,0.4877000153064728,2.3885064125061035,10000,14473.0782289505,0.6623883843421936,1.4788810014724731,0.6086199879646301,1.71470308303833,50000 -673.0554871559143,1.0817155838012695,14336.686334371569,42284,0,14336.686334371569,0.4802000224590301,2.396982431411743,10000,15012.248512983322,0.6502909660339355,1.5233001708984375,0.6042400002479553,1.73208749294281,50000 -696.379273891449,1.1133880615234375,14846.899607419968,43798,0,14846.899607419968,0.4892000257968902,2.3520050048828125,10000,15545.86763715744,0.6564692258834839,1.4994585514068604,0.6126999855041504,1.6944860219955444,50000 -719.1942150592804,1.1434450149536133,15356.965127944946,45312,0,15356.965127944946,0.4855000376701355,2.367635488510132,10000,16078.829252958298,0.6901307106018066,1.335003137588501,0.60971999168396,1.702130913734436,50000 -739.6124730110168,1.1734073162078855,15867.11354660988,46826,0,15867.11354660988,0.4852000176906585,2.384239435195923,10000,16609.47766971588,0.6707788705825806,1.4408752918243408,0.6131199598312378,1.7045730352401731,50000 -759.0497002601624,1.20459246635437,16377.163805484772,48340,0,16377.163805484772,0.4865000247955322,2.390916585922241,10000,17139.047789812088,0.6635841727256775,1.4390766620635986,0.6102799773216248,1.6896299123764038,50000 -777.0617418289185,1.239149808883667,16887.197038412094,49854,0,16887.197038412094,0.4975000321865082,2.350130081176758,10000,17667.17898607254,0.6621492505073547,1.4692184925079346,0.6146799921989441,1.694058895111084,50000 -794.8757519721985,1.2756733894348145,17397.333253145218,51369,0,17397.333253145218,0.4969000220298767,2.307199478149414,10000,18195.217745542526,0.6635642647743225,1.439455270767212,0.6191399693489075,1.646724820137024,50000 -812.5476040840149,1.3171625137329102,17907.54860687256,52884,0,17907.54860687256,0.4881000220775604,2.357887029647827,10000,18723.197404146194,0.6588408946990967,1.4924957752227783,0.6126799583435059,1.7022305727005005,50000 -830.1899147033691,1.362135410308838,18417.55141377449,54399,0,18417.55141377449,0.5001000165939331,2.3184971809387207,10000,19250.939685344696,0.7075693607330322,1.2901153564453125,0.625819981098175,1.6629422903060913,50000 -847.9586672782898,1.3999717235565186,18927.4927611351,55913,0,18927.4927611351,0.4935000240802765,2.397702932357788,10000,19778.73899126053,0.6686663031578064,1.4696898460388184,0.6124399900436401,1.7320475578308103,50000 -865.3042812347412,1.4369032382965088,19437.53251218796,57428,0,19437.53251218796,0.5006000399589539,2.341546535491944,10000,20306.21404647827,0.6722536683082581,1.435630440711975,0.6198399662971497,1.672557711601257,50000 -882.6476812362671,1.4736227989196775,19947.70263171196,58942,0,19947.70263171196,0.5063000321388245,2.299123764038086,10000,20833.81520462036,0.6779735088348389,1.4131288528442385,0.6226599812507629,1.664193868637085,50000 -900.1488373279572,1.509693622589111,20457.742354631424,60457,0,20457.742354631424,0.4940000176429748,2.3347575664520264,10000,21361.44446396828,0.6662946343421936,1.4448381662368774,0.6218199729919434,1.656185269355774,50000 -917.5497057437896,1.5463056564331057,20967.97685956955,61971,0,20967.97685956955,0.5004000067710876,2.3209445476531982,10000,21889.169400691982,0.6689453125,1.4233334064483645,0.6264599561691284,1.6362653970718384,50000 -935.2121860980988,1.5857601165771484,21477.900118112564,63485,0,21477.900118112564,0.5062000155448914,2.2677626609802246,10000,22416.84796738625,0.7110969424247742,1.2208093404769895,0.6271799802780151,1.6009138822555542,50000 -952.6368882656096,1.6321525573730469,21988.11581516266,65000,0,21988.11581516266,0.5092000365257263,2.258240222930908,10000,22944.5864379406,0.6969068646430969,1.2933167219161987,0.6342200040817261,1.5750954151153564,50000 -970.0448710918428,1.6718873977661133,22498.35578727722,66515,0,22498.35578727722,0.5099000334739685,2.246936559677124,10000,23472.325043201447,0.6873405575752258,1.344896912574768,0.632099986076355,1.5866305828094482,50000 -987.3612501621246,1.7124810218811035,23008.40454888344,68030,0,23008.40454888344,0.5042000412940979,2.301208734512329,10000,23999.7828745842,0.6845503449440002,1.392478346824646,0.6325799822807312,1.6152527332305908,50000 -1005.0094563961028,1.7519042491912842,23518.33584046364,69544,0,23518.33584046364,0.5063000321388245,2.339195966720581,10000,24527.45420455933,0.6779336333274841,1.4261982440948486,0.6293999552726746,1.6455281972885132,50000 -1022.608335018158,1.7905707359313965,24028.55358481407,71059,0,24028.55358481407,0.4982000291347503,2.287114143371582,10000,25055.36220574379,0.6732102632522583,1.407064437866211,0.6301599740982056,1.6089874505996704,50000 -1041.3291449546814,1.8296470642089844,24538.758437633514,72574,0,24538.758437633514,0.5004000067710876,2.2946672439575195,10000,25584.37909555435,0.7083266973495483,1.2449584007263184,0.6342399716377258,1.5873527526855469,50000 -1058.8174991607666,1.8682844638824463,25048.81888151169,74089,0,25048.81888151169,0.5052000284194946,2.28117036819458,10000,26112.01877140999,0.6872608065605164,1.3399715423583984,0.6256399750709534,1.6160084009170532,50000 -1076.917620420456,1.9076271057128904,25558.90424060821,75604,0,25558.90424060821,0.5152000188827515,2.232867956161499,10000,26640.29567885399,0.7000358700752258,1.2867945432662964,0.6447199583053589,1.5424926280975342,50000 -1094.5691964626312,1.9482781887054443,26069.08917927742,77119,0,26069.08917927742,0.5154000520706177,2.214272737503052,10000,27168.2249045372,0.6952527165412903,1.321738600730896,0.647599995136261,1.5462092161178589,50000 -1112.730740070343,2.8109002113342285,26578.31113815308,78632,0,26578.31113815308,0.5111000537872314,2.251178503036499,10000,27696.523701667786,0.6901904940605164,1.3460441827774048,0.6373400092124939,1.582018494606018,50000 -1129.9987313747406,2.845014333724976,27088.510328292847,80148,0,27088.510328292847,0.5189000368118286,2.245602369308472,10000,28224.07705426216,0.6938576102256775,1.3477110862731934,0.6477599740028381,1.5738341808319092,50000 -1147.4161262512207,2.885287284851074,27598.55618476868,81663,0,27598.55618476868,0.5138000249862671,2.23746657371521,10000,28751.632429361343,0.7150828838348389,1.2265141010284424,0.6447399854660034,1.5567433834075928,50000 -1164.917008638382,2.9261295795440674,28108.67711853981,83177,0,28108.67711853981,0.5205000042915344,2.2279982566833496,10000,29279.347688674927,0.7067522406578064,1.2656277418136597,0.6476799845695496,1.5454624891281128,50000 -1182.339945077896,2.973907709121704,28618.81132388115,84692,0,28618.81132388115,0.526900053024292,2.189680576324463,10000,29807.00520181656,0.70804762840271,1.2592806816101074,0.6536399722099304,1.508924961090088,50000 -1199.9780325889587,3.0153017044067383,29128.871512889866,86207,0,29128.871512889866,0.5340999960899353,2.161987543106079,10000,30334.79681873321,0.7067123651504517,1.260049819946289,0.653659999370575,1.5093188285827637,50000 -1217.268765926361,3.0644567012786865,29638.96631598473,87721,0,29638.96631598473,0.5281000137329102,2.171442985534668,10000,30862.28338265419,0.7043008208274841,1.2704253196716309,0.6558200120925903,1.510108232498169,50000 -1234.6226682662964,3.106571912765503,30149.158131837845,89236,0,30149.158131837845,0.5405000448226929,2.117152214050293,10000,31389.92377448082,0.7318239808082581,1.1713844537734983,0.6634599566459656,1.469886064529419,50000 -1251.9876792430878,3.152724266052246,30659.136559963223,90751,0,30659.136559963223,0.5282000303268433,2.163940191268921,10000,31917.36545681953,0.724609375,1.1772572994232178,0.6525599956512451,1.5016096830368042,50000 -1269.5212585926056,3.195277690887451,31169.08204507828,92266,0,31169.08204507828,0.541700005531311,2.0914676189422607,10000,32444.939204216003,0.7249082922935486,1.186362624168396,0.6613199710845947,1.461118459701538,50000 -1286.7298786640167,3.240107774734497,31679.052606105804,93780,0,31679.052606105804,0.5371000170707703,2.135526180267334,10000,32972.2160525322,0.7208425998687744,1.238020896911621,0.6592599749565125,1.504926085472107,50000 -1303.9475963115692,3.285914421081543,32189.051352262497,95294,0,32189.051352262497,0.5412000417709351,2.117387533187866,10000,33499.531002283096,0.7150231003761292,1.20890474319458,0.6631799936294556,1.4594324827194214,50000 -1321.4094231128693,3.3290951251983643,32699.01106214524,96808,0,32699.01106214524,0.5457000136375427,2.08373498916626,10000,34027.04830813408,0.721101701259613,1.1796294450759888,0.6640200018882751,1.435350775718689,50000 -1338.7655036449432,3.371743440628052,33209.20894575119,98323,0,33209.20894575119,0.5379000306129456,2.11805272102356,10000,34554.6969575882,0.7517338991165161,1.0737617015838623,0.663100004196167,1.4655168056488037,50000 -1355.984681367874,3.415469169616699,33719.22984600067,99838,0,33719.22984600067,0.5509999990463257,2.067492961883545,10000,35082.03232550621,0.7379822731018066,1.116407036781311,0.6669999957084656,1.440180420875549,50000 -1373.4264228343964,3.4589407444000244,34229.33446264267,101353,0,34229.33446264267,0.5529000163078308,2.0516908168792725,10000,35609.674355983734,0.7453961968421936,1.0961382389068604,0.676099956035614,1.4047099351882937,50000 -1390.8268103599548,3.504430055618286,34739.40772628784,102868,0,34739.40772628784,0.5488000512123108,2.043478488922119,10000,36137.24479365349,0.7364476919174194,1.114004135131836,0.672980010509491,1.399997353553772,50000 -1408.2055933475494,3.548323154449463,35249.39946103096,104382,0,35249.39946103096,0.5547000169754028,2.052401065826416,10000,36664.71185588837,0.7351921200752258,1.1262609958648682,0.6737599968910217,1.4070687294006348,50000 -1425.5483980178833,3.5949106216430664,35759.47870969772,105897,0,35759.47870969772,0.5578000545501709,2.038433790206909,10000,37192.232377290726,0.7310666441917419,1.1510498523712158,0.6747599840164185,1.4065779447555542,50000 -1442.9533438682556,3.6414597034454346,36269.60155582428,107412,0,36269.60155582428,0.5497000217437744,2.1009111404418945,10000,37719.8588643074,0.7695910334587097,1.021224856376648,0.6710999608039856,1.4482934474945068,50000 -1460.3775732517242,3.689704656600952,36779.734280347824,108927,0,36779.734280347824,0.5545000433921814,2.0375123023986816,10000,38247.51591229439,0.7567362785339355,1.0428543090820312,0.6768999695777893,1.38744056224823,50000 -1479.0125722885132,3.740776300430298,37289.76195025444,110442,0,37289.76195025444,0.5482000112533569,2.050555467605591,10000,38776.281270504,0.7382413744926453,1.1036720275878906,0.6723799705505371,1.4153038263320925,50000 -1496.186547756195,3.787414312362671,37799.80076622963,111957,0,37799.80076622963,0.5591000318527222,2.00235652923584,10000,39303.59298801422,0.7502391338348389,1.063479781150818,0.6810599565505981,1.373931050300598,50000 -1513.319462299347,3.83447790145874,38309.85378551483,113472,0,38309.85378551483,0.5565000176429749,2.0129952430725098,10000,39830.87760901451,0.744559109210968,1.0819778442382812,0.6834200024604797,1.3549482822418213,50000 -1530.7830486297607,3.8848836421966553,38819.99199438095,114987,0,38819.99199438095,0.5561000108718872,2.0544209480285645,10000,40358.582310676575,0.7444595098495483,1.1149373054504397,0.6838399767875671,1.3911999464035034,50000 -1548.1430974006653,3.932539224624634,39329.903483867645,116501,0,39329.903483867645,0.567300021648407,1.991621375083924,10000,40885.95380759239,0.7858139276504517,0.9308143854141236,0.6909199953079224,1.3370513916015625,50000 -1565.5552134513855,3.9830572605133057,39840.01126456261,118016,0,39840.01126456261,0.567300021648407,1.984636664390564,10000,41413.57640933991,0.7719228267669678,0.9736966490745544,0.692799985408783,1.3256369829177856,50000 -1582.9293761253357,4.029150485992432,40350.17130947113,119531,0,40350.17130947113,0.5732000470161438,1.962451219558716,10000,41941.20932650566,0.7692123651504517,0.9980911016464232,0.6963399648666382,1.3163288831710815,50000 -1600.571870803833,4.077746868133545,40860.15645599365,121046,0,40860.15645599365,0.5732000470161438,1.9472129344940183,10000,42468.937943696976,0.7643494606018066,1.0025675296783447,0.6955400109291077,1.3097865581512451,50000 -1618.0866899490356,4.125354290008545,41370.070234537125,122561,0,41370.070234537125,0.5808000564575195,1.9124568700790403,10000,42996.46684598923,0.7720224857330322,0.97927588224411,0.7010399699211121,1.279943585395813,50000 -1635.502154827118,4.173144578933716,41880.03378677368,124076,0,41880.03378677368,0.5755000114440918,1.924700140953064,10000,43523.94673323631,0.7723413705825806,0.971625566482544,0.7017599940299988,1.276233196258545,50000 -1653.6435549259186,4.225467681884766,42390.04211616516,125591,0,42390.04211616516,0.5813000202178955,1.917003273963928,10000,44052.20104932785,0.7950015664100647,0.878669023513794,0.6998400092124939,1.300983428955078,50000 -1671.1585881710052,4.279773235321045,42900.01652574539,127106,0,42900.01652574539,0.5827000141143799,1.9058300256729128,10000,44579.79776358605,0.78714919090271,0.9083539247512816,0.7061799764633179,1.2675763368606567,50000 -1688.4513931274414,5.066087961196899,43409.1985976696,128618,0,43409.1985976696,0.5888000130653381,1.886717438697815,10000,45107.111483335495,0.7895806431770325,0.9128103852272034,0.7077999711036682,1.2611401081085205,50000 -1706.087923288345,5.117692470550537,43919.28074002266,130133,0,43919.28074002266,0.5927000045776367,1.862455129623413,10000,45634.93291687965,0.7902981638908386,0.9130910634994508,0.711899995803833,1.2529252767562866,50000 -1723.184979915619,5.197498321533203,44429.30247211456,131648,0,44429.30247211456,0.5861000418663025,1.8845510482788088,10000,46162.18374609947,0.7849569320678711,0.9189361333847046,0.7110599875450134,1.247543454170227,50000 -1740.367802619934,5.248680830001831,44939.4055583477,133163,0,44939.4055583477,0.5948000550270081,1.8518083095550537,10000,46689.57337188721,0.7988081574440002,0.8635379076004028,0.7162399888038635,1.2158832550048828,50000 -1757.6859464645386,5.297849655151367,45449.53480505943,134678,0,45449.53480505943,0.5956000089645386,1.858915328979492,10000,47217.12211894989,0.8146922588348389,0.8167023658752441,0.7167999744415283,1.2303428649902344,50000 -1775.1927635669708,5.355523347854614,45959.67552042008,136193,0,45959.67552042008,0.597000002861023,1.836598873138428,10000,47744.87985253334,0.812519907951355,0.822140634059906,0.7186399698257446,1.2230992317199707,50000 -1792.49906373024,5.405432224273682,46469.635543346405,137707,0,46469.635543346405,0.5964000225067139,1.854988932609558,10000,48272.24763154984,0.8054049611091614,0.8238815069198608,0.7189599871635437,1.202906847000122,50000 -1809.889460325241,5.458134174346924,46979.75103497505,139222,0,46979.75103497505,0.6018000245094299,1.8125733137130733,10000,48799.85860180855,0.8106265664100647,0.8035197854042053,0.7279599905014038,1.174459457397461,50000 -1827.4179458618164,5.522055149078369,47489.82584095001,140737,0,47489.82584095001,0.5982000231742859,1.803998947143555,10000,49327.57805490494,0.8134167790412903,0.8105883598327637,0.7305999994277954,1.175118327140808,50000 -1844.4799404144287,5.573156356811523,48000.03847670555,142252,0,48000.03847670555,0.605400025844574,1.813591241836548,10000,49854.95561385155,0.8249959945678711,0.7559878826141357,0.726419985294342,1.179167866706848,50000 -1861.9220707416528,5.628809690475464,48509.954323768616,143766,0,48509.954323768616,0.6038000583648682,1.7907795906066897,10000,50382.42102813721,0.8384287357330322,0.7113813757896423,0.7309799790382385,1.161076545715332,50000 -1879.068108797073,5.681424856185913,49020.05911588669,145281,0,49020.05911588669,0.6077000498771667,1.781200885772705,10000,50909.77686786652,0.8381098508834839,0.7158500552177429,0.7337799668312073,1.148126482963562,50000 -1896.1900537014008,5.749187469482422,49530.228246212006,146796,0,49530.228246212006,0.6104000210762024,1.7563949823379517,10000,51437.18770766258,0.8376116156578064,0.7040087580680847,0.7360000014305115,1.129098415374756,50000 -1913.290997505188,5.799168109893799,50040.213027477264,148311,0,50040.213027477264,0.6163000464439392,1.7626051902770996,10000,51964.37534117699,0.8396643400192261,0.7063088417053223,0.738319993019104,1.1283575296401978,50000 -1930.7620613574984,5.853203773498535,50550.241391181946,149825,0,50550.241391181946,0.6243000030517578,1.719955325126648,10000,52491.98043131828,0.8420161008834839,0.6852911710739136,0.7417799830436707,1.1163774728775024,50000 -1948.308458328247,5.906131267547607,51060.40700316429,151340,0,51060.40700316429,0.6186000108718872,1.740875482559204,10000,53019.79779362679,0.86820387840271,0.5952945947647095,0.7403199672698975,1.1137624979019165,50000 -1966.237622976303,5.958402872085571,51570.638830661774,152855,0,51570.638830661774,0.62090003490448,1.735670804977417,10000,53548.062861442566,0.8589365482330322,0.613325297832489,0.7432799935340881,1.1003096103668213,50000 -1983.548333644867,6.011291027069092,52080.63415670395,154370,0,52080.63415670395,0.6241000294685364,1.7056432962417605,10000,54075.474705934525,0.8587571382522583,0.6176808476448059,0.7478799819946289,1.0903000831604004,50000 -2000.9424073696136,6.064687252044678,52590.81240940094,155885,0,52590.81240940094,0.6253000497817993,1.714283466339111,10000,54603.15359258652,0.8608896732330322,0.6201562881469727,0.75,1.0886688232421875,50000 -2018.272669315338,6.117977380752564,53100.97397065163,157400,0,53100.97397065163,0.6284000277519226,1.701801300048828,10000,55130.75111031532,0.8644570708274841,0.6090182065963745,0.7490599751472473,1.082980036735535,50000 -2035.686059474945,6.175013780593872,53611.190331459045,158915,0,53611.190331459045,0.6343000531196594,1.67049241065979,10000,55658.49032831192,0.871113657951355,0.5796046853065491,0.7538399696350098,1.0558656454086304,50000 -2052.965085029602,6.231770277023315,54121.38585519791,160430,0,54121.38585519791,0.6325000524520874,1.6727036237716677,10000,56186.07378411293,0.8908242583274841,0.5088411569595337,0.7539199590682983,1.0591936111450195,50000 -2070.361840724945,6.284480094909668,54631.54222178459,161945,0,54631.54222178459,0.6374000310897827,1.6483393907546997,10000,56713.73193693161,0.8883330225944519,0.507793128490448,0.7577599883079529,1.0418267250061035,50000 -2087.6100096702576,6.337510824203491,55141.44549107552,163459,0,55141.44549107552,0.6409000158309937,1.6533217430114746,10000,57240.98913478851,0.8888113498687744,0.508962869644165,0.759880006313324,1.0377557277679443,50000 -2104.812753200531,6.391413450241089,55651.61699032784,164974,0,55651.61699032784,0.6428000330924988,1.644087553024292,10000,57768.47083735466,0.8914421200752258,0.5064797401428223,0.7610999941825867,1.0407962799072266,50000 -2122.656609773636,6.452677011489868,56161.74114251137,166488,0,56161.74114251137,0.6368000507354736,1.654982328414917,10000,58296.5521376133,0.8913623690605164,0.4938401579856872,0.761900007724762,1.031830906867981,50000 -2139.942459821701,6.499396800994873,56671.93290805817,168003,0,56671.93290805817,0.6433000564575195,1.6341732740402222,10000,58824.12835144997,0.8980189561843872,0.4819498062133789,0.7639999985694885,1.017918586730957,50000 -2157.704973936081,6.561173677444458,57182.0711376667,169518,0,57182.0711376667,0.6438000202178955,1.6271370649337769,10000,59352.14291119576,0.9111128449440002,0.4368035793304443,0.7679599523544312,1.0132651329040527,50000 -2175.0261178016663,6.616154432296753,57692.27556872368,171033,0,57692.27556872368,0.6476000547409058,1.6199846267700195,10000,59879.77585315704,0.9100565910339355,0.4437120258808136,0.768839955329895,1.0136228799819946,50000 -2192.441407442093,6.679574251174927,58202.40605354309,172548,0,58202.40605354309,0.6457000374794006,1.6168798208236694,10000,60407.436655282974,0.907983899116516,0.44346883893013,0.7683799862861633,1.0122803449630735,50000 -2210.0351996421814,6.738846778869629,58712.38555598259,174063,0,58712.38555598259,0.6491000056266785,1.6106775999069214,10000,60935.12227869034,0.9148397445678712,0.4309082627296448,0.7701599597930908,1.0046439170837402,50000 -2227.369342327118,6.801840305328369,59222.59151554108,175578,0,59222.59151554108,0.650700032711029,1.6060999631881714,10000,61462.77833938599,0.913305163383484,0.4299351274967193,0.77183997631073,1.0019093751907349,50000 -2244.6973419189453,6.859493017196655,59732.50027251244,177092,0,59732.50027251244,0.65010005235672,1.6027166843414309,10000,61990.1252887249,0.914441168308258,0.4245000779628753,0.7725799679756165,0.9952738881111144,50000 -2262.119296312332,6.916916847229004,60242.55020594597,178606,0,60242.55020594597,0.6538000106811523,1.6024290323257446,10000,62517.70723223686,0.922512710094452,0.3977555632591247,0.7713800072669983,0.9977371096611024,50000 -2279.408571243286,6.978010654449463,60752.50250053406,180120,0,60752.50250053406,0.6515000462532043,1.5987427234649658,10000,63045.0622420311,0.9199019074440002,0.4014585316181183,0.7723199725151062,0.9948861598968506,50000 -2296.7791769504547,7.041255712509155,61262.67764997482,181635,0,61262.67764997482,0.6521000266075134,1.5991164445877075,10000,63572.72377586365,0.920918345451355,0.4040337800979614,0.7740799784660339,0.9936750531196594,50000 -2313.8672394752502,7.099210262298584,61772.59363722801,183149,0,61772.59363722801,0.6522000432014465,1.5940217971801758,10000,64099.83883333206,0.920141100883484,0.398114413022995,0.7745400071144104,0.9911282062530518,50000 -2331.1919524669647,7.156313896179199,62282.49484300613,184663,0,62282.49484300613,0.6538000106811523,1.596049427986145,10000,64627.17417168617,0.9205795526504515,0.4052787721157074,0.774399995803833,0.9936506748199464,50000 -2348.503286600113,7.2113025188446045,62792.42883324623,186177,0,62792.42883324623,0.6533000469207764,1.5955201387405396,10000,65154.52738904953,0.922253668308258,0.4010597765445709,0.7745199799537659,0.9941147565841676,50000 -2365.6900820732117,7.268607139587402,62957.09817099571,186666,0,62957.09817099571,0.6532000303268433,1.594335675239563,10000,65336.45809841156,0.9229711294174194,0.38955602049827576,0.774399995803833,0.9909643530845642,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index ec8a28c75..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1994 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6239581,6.925497,,,,,,,,,,,,,, -1,,,0.001335299690254,6.911616802215576,0.0011199999134987,6.912059783935547,50000.0,0.0009000000427477,6.912177562713623,10000.0,54.74183368682861,92.23791027069092,54.74183368682861,37.495994091033936,0.0,0.0 -100,0.5948771,6.9045835,,,,,,,,,,,,,, -200,0.6011343,6.8651357,,,,,,,,,,,,,, -300,0.6382472,6.788351,,,,,,,,,,,,,, -400,0.6636415,6.700946,,,,,,,,,,,,,, -500,0.73483163,6.6423125,,,,,,,,,,,,,, -600,0.7418363,6.5214944,,,,,,,,,,,,,, -700,0.7691768,6.4439383,,,,,,,,,,,,,, -800,0.86085683,6.348688,,,,,,,,,,,,,, -900,1.0663427,6.2706456,,,,,,,,,,,,,, -1000,1.5191617,6.22033,,,,,,,,,,,,,, -1100,2.0453186,6.061069,,,,,,,,,,,,,, -1200,3.6029768,6.034955,,,,,,,,,,,,,, -1300,2.0242965,5.9701796,,,,,,,,,,,,,, -1400,2.0123649,5.8854475,,,,,,,,,,,,,, -1500,3.2370663,5.814906,,,,,,,,,,,,,, -1508,,,0.070990115404129,5.393021583557129,0.0666799992322921,5.462131500244141,50000.0,0.0485000014305114,5.668511390686035,10000.0,564.9765136241913,620.535459280014,564.9765136241913,55.4818127155304,0.0258848667144775,0.0 -1600,2.0486515,5.747715,,,,,,,,,,,,,, -1700,3.2883177,5.736503,,,,,,,,,,,,,, -1800,2.5192208,5.636041,,,,,,,,,,,,,, -1900,2.402588,5.5427775,,,,,,,,,,,,,, -2000,2.979641,5.5172434,,,,,,,,,,,,,, -2100,3.5545719,5.4450006,,,,,,,,,,,,,, -2200,3.1048717,5.415953,,,,,,,,,,,,,, -2300,5.3297253,5.341847,,,,,,,,,,,,,, -2400,3.2601695,5.3113694,,,,,,,,,,,,,, -2500,3.3014967,5.34057,,,,,,,,,,,,,, -2600,3.1894073,5.205324,,,,,,,,,,,,,, -2700,4.1735373,5.116978,,,,,,,,,,,,,, -2800,4.5432057,5.1055913,,,,,,,,,,,,,, -2900,3.456403,4.9829698,,,,,,,,,,,,,, -3000,5.331863,5.0589952,,,,,,,,,,,,,, -3013,,,0.1768175959587097,4.263920783996582,0.1591600030660629,4.370645523071289,50000.0,0.1177000030875206,4.772252559661865,10000.0,1074.9497528076172,1148.9814901351929,1074.9497528076172,73.87850284576416,0.0510666370391845,0.0 -3100,3.7323716,5.026122,,,,,,,,,,,,,, -3200,6.8187447,4.8805265,,,,,,,,,,,,,, -3300,4.0009937,4.832096,,,,,,,,,,,,,, -3400,5.3468165,4.96758,,,,,,,,,,,,,, -3500,4.4629726,4.764493,,,,,,,,,,,,,, -3600,3.7994733,4.8266034,,,,,,,,,,,,,, -3700,4.281647,4.738118,,,,,,,,,,,,,, -3800,5.581806,4.6628175,,,,,,,,,,,,,, -3900,5.742629,4.661164,,,,,,,,,,,,,, -4000,5.0144477,4.6588545,,,,,,,,,,,,,, -4100,6.2088933,4.5860076,,,,,,,,,,,,,, -4200,6.9037995,4.470406,,,,,,,,,,,,,, -4300,3.5296068,4.5216994,,,,,,,,,,,,,, -4400,5.9416122,4.273775,,,,,,,,,,,,,, -4500,6.3297653,4.368678,,,,,,,,,,,,,, -4518,,,0.2875677645206451,3.482984781265259,0.2633000016212463,3.6200780868530273,50000.0,0.1901000142097473,4.155796527862549,10000.0,1584.9473168849945,1677.1219086647034,1584.9473168849945,91.94419121742249,0.0785439014434814,0.0 -4600,8.229104,4.4275365,,,,,,,,,,,,,, -4700,5.818326,4.4058223,,,,,,,,,,,,,, -4800,6.8743153,4.289041,,,,,,,,,,,,,, -4900,4.1725025,4.286368,,,,,,,,,,,,,, -5000,5.2488675,4.2969484,,,,,,,,,,,,,, -5100,5.3062806,4.200756,,,,,,,,,,,,,, -5200,7.036519,4.233947,,,,,,,,,,,,,, -5300,6.4292364,4.0969725,,,,,,,,,,,,,, -5400,5.5843225,4.1087255,,,,,,,,,,,,,, -5500,6.5556345,4.210333,,,,,,,,,,,,,, -5600,4.871556,4.0756526,,,,,,,,,,,,,, -5700,6.697876,4.131717,,,,,,,,,,,,,, -5800,4.3295116,4.000618,,,,,,,,,,,,,, -5900,5.465132,4.0481596,,,,,,,,,,,,,, -6000,4.5623217,3.990101,,,,,,,,,,,,,, -6024,,,0.3698580861091614,2.9841694831848145,0.3408399820327759,3.133474349975586,50000.0,0.255700021982193,3.714011907577514,10000.0,2095.0244052410126,2205.515574455261,2095.0244052410126,110.18438458442688,0.1052768230438232,0.0 -6100,4.7299356,3.9544194,,,,,,,,,,,,,, -6200,5.6171746,3.8651373,,,,,,,,,,,,,, -6300,6.0151396,3.987623,,,,,,,,,,,,,, -6400,7.3012576,4.010089,,,,,,,,,,,,,, -6500,7.8082194,3.8580074,,,,,,,,,,,,,, -6600,4.9772162,3.90113,,,,,,,,,,,,,, -6700,9.473548,3.801995,,,,,,,,,,,,,, -6800,5.313536,3.8093972,,,,,,,,,,,,,, -6900,8.786799,3.694005,,,,,,,,,,,,,, -7000,4.578554,3.615256,,,,,,,,,,,,,, -7100,7.989756,3.7124963,,,,,,,,,,,,,, -7200,6.489112,3.6911528,,,,,,,,,,,,,, -7300,7.1903563,3.6847537,,,,,,,,,,,,,, -7400,5.9322143,3.7838075,,,,,,,,,,,,,, -7500,4.4731593,3.6272824,,,,,,,,,,,,,, -7529,,,0.4408880770206451,2.5516676902771,0.4115200042724609,2.7000417709350586,50000.0,0.3109000027179718,3.379472017288208,10000.0,2605.0299191474915,2733.8128702640533,2605.0299191474915,128.3986883163452,0.1336238384246826,0.0 -7600,7.0704904,3.607764,,,,,,,,,,,,,, -7700,6.5831413,3.633194,,,,,,,,,,,,,, -7800,6.0458627,3.5923944,,,,,,,,,,,,,, -7900,3.9437022,3.6047597,,,,,,,,,,,,,, -8000,4.9557614,3.5063376,,,,,,,,,,,,,, -8100,8.155865,3.569632,,,,,,,,,,,,,, -8200,5.5346746,3.5127137,,,,,,,,,,,,,, -8300,7.0172343,3.4495037,,,,,,,,,,,,,, -8400,6.563126,3.531416,,,,,,,,,,,,,, -8500,4.37713,3.5564044,,,,,,,,,,,,,, -8600,8.233681,3.576018,,,,,,,,,,,,,, -8700,3.6642945,3.4061306,,,,,,,,,,,,,, -8800,3.4122217,3.4394965,,,,,,,,,,,,,, -8900,6.29159,3.3952134,,,,,,,,,,,,,, -9000,5.0573826,3.4777343,,,,,,,,,,,,,, -9036,,,0.4923070669174194,2.256446599960327,0.4586599767208099,2.4324121475219727,50000.0,0.3530000150203705,3.093779325485229,10000.0,3115.184098005295,3262.2012915611267,3115.184098005295,146.5477843284607,0.1685779094696045,0.0 -9100,6.8501315,3.3976934,,,,,,,,,,,,,, -9200,5.6928306,3.446867,,,,,,,,,,,,,, -9300,3.9998844,3.3893838,,,,,,,,,,,,,, -9400,5.011541,3.3574824,,,,,,,,,,,,,, -9500,5.0753975,3.301887,,,,,,,,,,,,,, -9600,3.9204361,3.4379582,,,,,,,,,,,,,, -9700,5.8329587,3.4662495,,,,,,,,,,,,,, -9800,5.4243712,3.3762667,,,,,,,,,,,,,, -9900,4.723493,3.4322343,,,,,,,,,,,,,, -10000,5.9905877,3.3453388,,,,,,,,,,,,,, -10100,5.6201324,3.4129133,,,,,,,,,,,,,, -10200,7.42347,3.370933,,,,,,,,,,,,,, -10300,6.28903,3.3990068,,,,,,,,,,,,,, -10400,4.550727,3.2812533,,,,,,,,,,,,,, -10500,5.2686415,3.2729774,,,,,,,,,,,,,, -10544,,,0.5581353306770325,1.9381918907165527,0.5031399726867676,2.2289462089538574,50000.0,0.3918000161647796,2.887331008911133,10000.0,3625.3059175014496,3790.7159888744354,3625.3059175014496,164.86400961875916,0.1949026584625244,0.0 -10600,4.5596647,3.2428372,,,,,,,,,,,,,, -10700,9.249387,3.2247562,,,,,,,,,,,,,, -10800,6.925232,3.2861228,,,,,,,,,,,,,, -10900,4.7069516,3.1892922,,,,,,,,,,,,,, -11000,4.732746,3.2717173,,,,,,,,,,,,,, -11100,4.4297905,3.2697663,,,,,,,,,,,,,, -11200,4.3635497,3.2583299,,,,,,,,,,,,,, -11300,5.559314,3.2587864,,,,,,,,,,,,,, -11400,6.1881433,3.3260543,,,,,,,,,,,,,, -11500,5.951379,3.205741,,,,,,,,,,,,,, -11600,3.29031,3.1824884,,,,,,,,,,,,,, -11700,3.4770334,3.230749,,,,,,,,,,,,,, -11800,5.051454,3.2078295,,,,,,,,,,,,,, -11900,3.82195,3.0730128,,,,,,,,,,,,,, -12000,5.411558,3.1373882,,,,,,,,,,,,,, -12053,,,0.5622209906578064,1.9551576375961304,0.5192399621009827,2.169461488723755,50000.0,0.4032000303268432,2.825645446777344,10000.0,4135.503875255585,4319.589529514313,4135.503875255585,183.4552013874054,0.2289702892303466,0.0 -12100,6.894994,3.187534,,,,,,,,,,,,,, -12200,4.227958,3.1714778,,,,,,,,,,,,,, -12300,7.870974,3.2573133,,,,,,,,,,,,,, -12400,5.1420064,3.1391604,,,,,,,,,,,,,, -12500,6.0568776,3.1136568,,,,,,,,,,,,,, -12600,4.0869403,3.0750422,,,,,,,,,,,,,, -12700,4.508413,3.2378945,,,,,,,,,,,,,, -12800,5.7794304,3.0559883,,,,,,,,,,,,,, -12900,6.2297215,3.1127725,,,,,,,,,,,,,, -13000,5.3724494,3.1860232,,,,,,,,,,,,,, -13100,5.101268,3.1232476,,,,,,,,,,,,,, -13200,4.266503,3.1750941,,,,,,,,,,,,,, -13300,6.8730197,3.1410234,,,,,,,,,,,,,, -13400,4.4216814,3.0740151,,,,,,,,,,,,,, -13500,5.5570784,3.1119127,,,,,,,,,,,,,, -13562,,,0.5889070630073547,1.835581421852112,0.5423399806022644,2.049694776535034,50000.0,0.4192000329494476,2.732748746871948,10000.0,4645.7563996315,4851.302194833756,4645.7563996315,204.8366265296936,0.2584633827209472,0.0 -13600,3.479042,3.0655036,,,,,,,,,,,,,, -13700,4.998807,3.0449743,,,,,,,,,,,,,, -13800,7.172605,3.114809,,,,,,,,,,,,,, -13900,4.442072,3.0552647,,,,,,,,,,,,,, -14000,7.2802906,3.106455,,,,,,,,,,,,,, -14100,5.7268314,3.0552864,,,,,,,,,,,,,, -14200,6.4864907,3.180555,,,,,,,,,,,,,, -14300,6.555811,3.0880375,,,,,,,,,,,,,, -14400,5.2501535,3.066195,,,,,,,,,,,,,, -14500,3.4084837,3.0339823,,,,,,,,,,,,,, -14600,8.654122,3.061963,,,,,,,,,,,,,, -14700,4.2895346,3.1002688,,,,,,,,,,,,,, -14800,4.9129043,2.9348326,,,,,,,,,,,,,, -14900,5.974022,3.1295009,,,,,,,,,,,,,, -15000,4.4528217,3.071105,,,,,,,,,,,,,, -15073,,,0.5912587642669678,1.7979090213775637,0.551800012588501,1.9871643781661987,50000.0,0.426000028848648,2.665627956390381,10000.0,5155.886093854904,5384.902832508087,5155.886093854904,228.2326774597168,0.2825262546539306,0.0 -15100,4.032474,3.050747,,,,,,,,,,,,,, -15200,6.6187143,2.9964957,,,,,,,,,,,,,, -15300,3.544849,2.945417,,,,,,,,,,,,,, -15400,6.2759085,3.172689,,,,,,,,,,,,,, -15500,4.273002,2.966677,,,,,,,,,,,,,, -15600,4.672142,2.9687858,,,,,,,,,,,,,, -15700,4.725392,3.0691342,,,,,,,,,,,,,, -15800,5.6621184,3.067103,,,,,,,,,,,,,, -15900,4.4949965,2.9855027,,,,,,,,,,,,,, -16000,3.7658796,2.951316,,,,,,,,,,,,,, -16100,5.7607136,2.9471698,,,,,,,,,,,,,, -16200,5.4803658,3.087514,,,,,,,,,,,,,, -16300,4.971734,3.06959,,,,,,,,,,,,,, -16400,5.5885434,2.988215,,,,,,,,,,,,,, -16500,3.1933374,2.9984684,,,,,,,,,,,,,, -16583,,,0.6018216013908386,1.7547258138656616,0.5679000020027161,1.932205080986023,50000.0,0.4445000290870666,2.6144609451293945,10000.0,5666.106993675232,5918.15540766716,5666.106993675232,251.177282333374,0.3175804615020752,0.0 -16600,4.440096,3.0405512,,,,,,,,,,,,,, -16700,4.0545006,3.0203938,,,,,,,,,,,,,, -16800,8.064937,3.019233,,,,,,,,,,,,,, -16900,4.4689627,2.9896662,,,,,,,,,,,,,, -17000,4.9232154,3.0928564,,,,,,,,,,,,,, -17100,7.3973703,3.0186694,,,,,,,,,,,,,, -17200,7.1789727,3.0279684,,,,,,,,,,,,,, -17300,4.3637953,2.9782476,,,,,,,,,,,,,, -17400,5.7150807,3.0877328,,,,,,,,,,,,,, -17500,3.8273785,2.9783652,,,,,,,,,,,,,, -17600,3.7317472,2.946093,,,,,,,,,,,,,, -17700,4.86578,2.8881524,,,,,,,,,,,,,, -17800,5.1394815,2.9947515,,,,,,,,,,,,,, -17900,4.04381,3.064238,,,,,,,,,,,,,, -18000,4.072637,3.0197575,,,,,,,,,,,,,, -18094,,,0.6030970811843872,1.7894777059555054,0.5586400032043457,1.985974907875061,50000.0,0.4395000338554382,2.643144369125366,10000.0,6176.272500514984,6453.581561326981,6176.272500514984,276.3594694137573,0.3460025787353515,0.0 -18100,3.9007401,3.077802,,,,,,,,,,,,,, -18200,4.12167,2.9698858,,,,,,,,,,,,,, -18300,3.1174788,3.0169044,,,,,,,,,,,,,, -18400,3.0446339,3.1153235,,,,,,,,,,,,,, -18500,4.068877,3.01501,,,,,,,,,,,,,, -18600,3.794226,2.9850485,,,,,,,,,,,,,, -18700,5.330354,2.9592073,,,,,,,,,,,,,, -18800,2.9649348,2.9506974,,,,,,,,,,,,,, -18900,3.2232463,2.96085,,,,,,,,,,,,,, -19000,2.9501123,2.9542112,,,,,,,,,,,,,, -19100,2.9131916,2.9688993,,,,,,,,,,,,,, -19200,3.9207602,2.9610808,,,,,,,,,,,,,, -19300,5.40889,3.018076,,,,,,,,,,,,,, -19400,2.8021047,3.038694,,,,,,,,,,,,,, -19500,3.1434114,2.928044,,,,,,,,,,,,,, -19600,5.359111,2.9592788,,,,,,,,,,,,,, -19605,,,0.6378746628761292,1.6028066873550415,0.5719999670982361,1.902109146118164,50000.0,0.4588000178337097,2.5408897399902344,10000.0,6686.243889808655,6987.860915660858,6686.243889808655,300.57800698280334,0.3844051361083984,0.0 -19700,3.346587,3.0850232,,,,,,,,,,,,,, -19800,3.352562,2.8499663,,,,,,,,,,,,,, -19900,2.7313867,2.8122022,,,,,,,,,,,,,, -20000,2.4403315,2.949303,,,,,,,,,,,,,, -20100,3.126098,2.9062877,,,,,,,,,,,,,, -20200,3.7157938,3.043522,,,,,,,,,,,,,, -20300,3.9955757,2.925566,,,,,,,,,,,,,, -20400,4.2422447,3.0255647,,,,,,,,,,,,,, -20500,3.8421376,2.909917,,,,,,,,,,,,,, -20600,2.3601823,2.8714128,,,,,,,,,,,,,, -20700,3.3458855,2.8447995,,,,,,,,,,,,,, -20800,3.6621416,2.9714437,,,,,,,,,,,,,, -20900,2.5957718,2.9860926,,,,,,,,,,,,,, -21000,4.094406,2.961309,,,,,,,,,,,,,, -21100,3.7316084,2.9508405,,,,,,,,,,,,,, -21117,,,0.6107900142669678,1.6893270015716553,0.5640599727630615,1.925768256187439,50000.0,0.4389000236988067,2.6041297912597656,10000.0,7196.342430591583,7522.037932395935,7196.342430591583,324.5791804790497,0.4113876819610595,0.0 -21200,3.8900657,2.9082506,,,,,,,,,,,,,, -21300,4.8495865,3.0498662,,,,,,,,,,,,,, -21400,4.3360415,2.841508,,,,,,,,,,,,,, -21500,3.2802567,2.9121203,,,,,,,,,,,,,, -21600,3.0691302,2.893933,,,,,,,,,,,,,, -21700,3.2831552,2.903327,,,,,,,,,,,,,, -21800,3.5898314,2.9568157,,,,,,,,,,,,,, -21900,3.0289204,2.967769,,,,,,,,,,,,,, -22000,3.3812928,2.9034283,,,,,,,,,,,,,, -22100,4.456885,2.9044018,,,,,,,,,,,,,, -22200,3.4588215,2.9976687,,,,,,,,,,,,,, -22300,3.6280844,2.8300047,,,,,,,,,,,,,, -22400,3.2274761,2.9723458,,,,,,,,,,,,,, -22500,3.4849823,2.867398,,,,,,,,,,,,,, -22600,2.9226522,2.8993056,,,,,,,,,,,,,, -22629,,,0.6161710619926453,1.6690157651901243,0.5745599865913391,1.873211145401001,50000.0,0.4581000208854675,2.543734312057495,10000.0,7706.454968452454,8056.826657772064,7706.454968452454,349.167008638382,0.4475350379943847,0.0 -22700,2.6565356,2.887433,,,,,,,,,,,,,, -22800,2.6048615,2.9572477,,,,,,,,,,,,,, -22900,3.0502825,2.8463812,,,,,,,,,,,,,, -23000,2.9641657,2.9586625,,,,,,,,,,,,,, -23100,2.5972886,2.94113,,,,,,,,,,,,,, -23200,4.2001047,2.8909678,,,,,,,,,,,,,, -23300,4.5950456,2.9338157,,,,,,,,,,,,,, -23400,2.9503229,2.8137817,,,,,,,,,,,,,, -23500,3.9350193,2.8826265,,,,,,,,,,,,,, -23600,2.5500166,2.9005613,,,,,,,,,,,,,, -23700,3.1841555,2.9634082,,,,,,,,,,,,,, -23800,3.674463,2.8746603,,,,,,,,,,,,,, -23900,2.6956677,2.8796842,,,,,,,,,,,,,, -24000,3.2989824,2.9617813,,,,,,,,,,,,,, -24100,3.3057096,2.9119487,,,,,,,,,,,,,, -24126,,,0.634785532951355,1.5688326358795166,0.5852999687194824,1.8004077672958374,50000.0,0.4648000299930572,2.4927031993865967,10000.0,8216.385123491287,8591.48088502884,8216.385123491287,373.81429505348206,0.4732851982116699,0.0 -24200,3.0248778,2.896919,,,,,,,,,,,,,, -24300,3.2341282,2.8234968,,,,,,,,,,,,,, -24400,2.8893106,2.911263,,,,,,,,,,,,,, -24500,3.4817357,2.9378135,,,,,,,,,,,,,, -24600,3.157708,2.8886254,,,,,,,,,,,,,, -24700,3.0473442,2.7515206,,,,,,,,,,,,,, -24800,4.1212897,2.9284604,,,,,,,,,,,,,, -24900,3.3943186,2.9856675,,,,,,,,,,,,,, -25000,3.7116127,2.8929107,,,,,,,,,,,,,, -25100,3.065967,2.8305821,,,,,,,,,,,,,, -25200,4.18405,2.859692,,,,,,,,,,,,,, -25300,3.8640823,2.8443542,,,,,,,,,,,,,, -25400,3.5600338,2.888383,,,,,,,,,,,,,, -25500,3.4196165,2.8932722,,,,,,,,,,,,,, -25600,2.9670875,2.8838944,,,,,,,,,,,,,, -25638,,,0.6300820708274841,1.6126160621643066,0.5874999761581421,1.8088067770004272,50000.0,0.4674000144004822,2.484600067138672,10000.0,8726.36254477501,9125.590382575989,8726.36254477501,397.8678450584412,0.5002624988555908,0.0 -25700,3.2728775,2.7778387,,,,,,,,,,,,,, -25800,3.1582355,2.8459466,,,,,,,,,,,,,, -25900,3.2749777,2.8389945,,,,,,,,,,,,,, -26000,3.0299964,2.9624474,,,,,,,,,,,,,, -26100,4.6306777,2.904117,,,,,,,,,,,,,, -26200,4.0957394,2.816719,,,,,,,,,,,,,, -26300,3.170938,2.9259737,,,,,,,,,,,,,, -26400,3.4882565,2.8583763,,,,,,,,,,,,,, -26500,2.9106426,2.9250684,,,,,,,,,,,,,, -26600,2.9606462,2.8382976,,,,,,,,,,,,,, -26700,3.1258564,2.7545857,,,,,,,,,,,,,, -26800,3.6477613,2.828968,,,,,,,,,,,,,, -26900,2.9465222,2.827422,,,,,,,,,,,,,, -27000,2.9570732,2.9149513,,,,,,,,,,,,,, -27100,4.3737473,2.990104,,,,,,,,,,,,,, -27151,,,0.6246811151504517,1.6491903066635132,0.5797399878501892,1.853614926338196,50000.0,0.4571000337600708,2.5513482093811035,10000.0,9236.353283643724,9659.236914873123,9236.353283643724,421.4299862384796,0.5421044826507568,0.0 -27200,2.8516643,2.7848256,,,,,,,,,,,,,, -27300,3.3796427,2.8516386,,,,,,,,,,,,,, -27400,3.7383034,2.8543794,,,,,,,,,,,,,, -27500,3.1233444,2.7505724,,,,,,,,,,,,,, -27600,3.157916,2.8193526,,,,,,,,,,,,,, -27700,2.901556,2.814974,,,,,,,,,,,,,, -27800,2.709626,2.9074738,,,,,,,,,,,,,, -27900,2.8629856,2.8489945,,,,,,,,,,,,,, -28000,3.2827673,2.8374398,,,,,,,,,,,,,, -28100,2.9229112,2.8517811,,,,,,,,,,,,,, -28200,3.7181327,2.8980799,,,,,,,,,,,,,, -28300,2.5199945,2.9019032,,,,,,,,,,,,,, -28400,3.8677096,2.8365254,,,,,,,,,,,,,, -28500,2.9270086,2.8414178,,,,,,,,,,,,,, -28600,3.1686893,2.9271069,,,,,,,,,,,,,, -28663,,,0.6515266299247742,1.5246480703353882,0.590399980545044,1.8066246509552,50000.0,0.4710000157356262,2.4751579761505127,10000.0,9746.11990594864,10194.557493686676,9746.11990594864,446.6768915653229,0.7980978488922119,0.0 -28700,4.1886053,2.8307407,,,,,,,,,,,,,, -28800,2.544713,2.7881918,,,,,,,,,,,,,, -28900,3.7224653,2.7777557,,,,,,,,,,,,,, -29000,3.1460056,2.8598423,,,,,,,,,,,,,, -29100,3.1050906,2.8787956,,,,,,,,,,,,,, -29200,2.9311671,2.8690915,,,,,,,,,,,,,, -29300,3.1972194,2.884904,,,,,,,,,,,,,, -29400,4.238393,2.8205757,,,,,,,,,,,,,, -29500,3.4268703,2.9098701,,,,,,,,,,,,,, -29600,3.644189,2.8681998,,,,,,,,,,,,,, -29700,2.6892915,2.7086043,,,,,,,,,,,,,, -29800,3.556107,2.876894,,,,,,,,,,,,,, -29900,3.2142153,2.7996492,,,,,,,,,,,,,, -30000,2.6440022,2.8347998,,,,,,,,,,,,,, -30100,2.8583035,2.8125103,,,,,,,,,,,,,, -30176,,,0.6483777165412903,1.5167217254638672,0.5963599681854248,1.7699456214904783,50000.0,0.4713000357151031,2.4669439792633057,10000.0,10256.066632509232,10729.63278746605,10256.066632509232,471.7249677181244,0.8278708457946777,0.0 -30200,3.3444836,2.8795764,,,,,,,,,,,,,, -30300,3.7186651,2.8940067,,,,,,,,,,,,,, -30400,3.596213,2.9539685,,,,,,,,,,,,,, -30500,3.1342773,2.861215,,,,,,,,,,,,,, -30600,3.1745896,2.7465355,,,,,,,,,,,,,, -30700,3.2693882,2.8031862,,,,,,,,,,,,,, -30800,2.6533144,2.8906581,,,,,,,,,,,,,, -30900,2.9677665,2.7347503,,,,,,,,,,,,,, -31000,2.9518762,2.8774843,,,,,,,,,,,,,, -31100,2.6076732,2.8323731,,,,,,,,,,,,,, -31200,3.6144843,2.767278,,,,,,,,,,,,,, -31300,3.1666083,2.8665466,,,,,,,,,,,,,, -31400,3.2862282,2.876595,,,,,,,,,,,,,, -31500,3.414644,2.8196583,,,,,,,,,,,,,, -31600,3.0686116,2.8050952,,,,,,,,,,,,,, -31689,,,0.6477000713348389,1.5327821969985962,0.6036799550056458,1.7479439973831177,50000.0,0.4769000113010406,2.4278860092163086,10000.0,10766.203356981276,11265.86218070984,10766.203356981276,497.735536813736,0.8581020832061768,0.0 -31700,3.3540263,2.8904312,,,,,,,,,,,,,, -31800,3.7875907,2.7794783,,,,,,,,,,,,,, -31900,3.749698,2.9040258,,,,,,,,,,,,,, -32000,3.1389465,2.8526833,,,,,,,,,,,,,, -32100,2.7459335,2.7262864,,,,,,,,,,,,,, -32200,3.0906372,2.8639464,,,,,,,,,,,,,, -32300,2.7809455,2.7724705,,,,,,,,,,,,,, -32400,2.778879,2.8180897,,,,,,,,,,,,,, -32500,2.7987547,2.7536101,,,,,,,,,,,,,, -32600,3.165824,2.857193,,,,,,,,,,,,,, -32700,2.921862,2.841837,,,,,,,,,,,,,, -32800,2.6111243,2.859934,,,,,,,,,,,,,, -32900,3.040867,2.7928505,,,,,,,,,,,,,, -33000,2.712906,2.896935,,,,,,,,,,,,,, -33100,3.034689,2.7639143,,,,,,,,,,,,,, -33200,2.6711488,2.8122764,,,,,,,,,,,,,, -33203,,,0.6440529227256775,1.5270545482635498,0.6020199656486511,1.7369011640548706,50000.0,0.4764000177383423,2.419763803482056,10000.0,11276.440303564072,11801.382081270218,11276.440303564072,522.9339916706085,0.8919429779052734,0.0 -33300,2.8590693,2.82271,,,,,,,,,,,,,, -33400,4.1979637,2.8251164,,,,,,,,,,,,,, -33500,3.6374733,2.7847865,,,,,,,,,,,,,, -33600,4.09263,2.8557007,,,,,,,,,,,,,, -33700,4.1433043,2.818813,,,,,,,,,,,,,, -33800,3.4771826,2.8584967,,,,,,,,,,,,,, -33900,3.286527,2.8631315,,,,,,,,,,,,,, -34000,3.6567042,2.9278407,,,,,,,,,,,,,, -34100,3.315931,2.7694266,,,,,,,,,,,,,, -34200,3.258215,2.791339,,,,,,,,,,,,,, -34300,2.9752295,2.7791166,,,,,,,,,,,,,, -34400,3.2359195,2.729763,,,,,,,,,,,,,, -34500,3.3311846,2.8946667,,,,,,,,,,,,,, -34600,4.7133117,2.8119075,,,,,,,,,,,,,, -34700,3.153625,2.778771,,,,,,,,,,,,,, -34715,,,0.6390106678009033,1.5380759239196775,0.5995799899101257,1.7401143312454224,50000.0,0.4750000238418579,2.4321818351745605,10000.0,11786.374200820925,12336.523176193235,11786.374200820925,548.0605285167694,0.9209282398223876,0.0 -34800,3.0651917,2.7868767,,,,,,,,,,,,,, -34900,3.3003545,2.780243,,,,,,,,,,,,,, -35000,3.0543656,2.8331223,,,,,,,,,,,,,, -35100,3.455123,2.7576241,,,,,,,,,,,,,, -35200,2.8699663,2.8538609,,,,,,,,,,,,,, -35300,2.8476071,2.6560001,,,,,,,,,,,,,, -35400,2.862473,2.6602387,,,,,,,,,,,,,, -35500,3.552619,2.8936708,,,,,,,,,,,,,, -35600,3.6301496,2.8894658,,,,,,,,,,,,,, -35700,3.0097864,2.855288,,,,,,,,,,,,,, -35800,3.1691542,2.8051903,,,,,,,,,,,,,, -35900,3.1665938,2.7935307,,,,,,,,,,,,,, -36000,3.063008,2.8244107,,,,,,,,,,,,,, -36100,3.0446649,2.824382,,,,,,,,,,,,,, -36200,2.7827258,2.8198404,,,,,,,,,,,,,, -36228,,,0.6710578799247742,1.4383693933486938,0.6104999780654907,1.7170253992080688,50000.0,0.4812000095844269,2.404179334640503,10000.0,12296.36655974388,12872.187220096588,12296.36655974388,573.6493542194366,0.9530339241027832,0.0 -36300,2.8972235,2.8488696,,,,,,,,,,,,,, -36400,3.7136776,2.8715916,,,,,,,,,,,,,, -36500,2.9434574,2.7024982,,,,,,,,,,,,,, -36600,2.9733949,2.7948036,,,,,,,,,,,,,, -36700,3.703236,2.7803612,,,,,,,,,,,,,, -36800,2.7140672,2.7537436,,,,,,,,,,,,,, -36900,3.8725913,2.955965,,,,,,,,,,,,,, -37000,2.9456544,2.781833,,,,,,,,,,,,,, -37100,3.1929276,2.7413929,,,,,,,,,,,,,, -37200,3.3385592,2.8580837,,,,,,,,,,,,,, -37300,3.2245312,2.7137384,,,,,,,,,,,,,, -37400,3.3356805,2.745235,,,,,,,,,,,,,, -37500,2.8371453,2.7389188,,,,,,,,,,,,,, -37600,3.067177,2.8796568,,,,,,,,,,,,,, -37700,3.4471436,2.760904,,,,,,,,,,,,,, -37741,,,0.6660754084587097,1.4482871294021606,0.6075599789619446,1.733280897140503,50000.0,0.4754000306129455,2.427032232284546,10000.0,12806.337379455566,13405.956619262695,12806.337379455566,597.3628311157227,0.9872357845306396,0.0 -37800,4.0279465,2.770814,,,,,,,,,,,,,, -37900,3.049967,2.857713,,,,,,,,,,,,,, -38000,3.087076,2.803548,,,,,,,,,,,,,, -38100,3.4969,2.7843826,,,,,,,,,,,,,, -38200,2.7290192,2.760288,,,,,,,,,,,,,, -38300,3.1108458,2.8133066,,,,,,,,,,,,,, -38400,2.5975528,2.8226836,,,,,,,,,,,,,, -38500,2.7984517,2.751972,,,,,,,,,,,,,, -38600,3.020341,2.8119617,,,,,,,,,,,,,, -38700,2.9927597,2.9109125,,,,,,,,,,,,,, -38800,2.9906428,2.7561433,,,,,,,,,,,,,, -38900,3.125754,2.8082867,,,,,,,,,,,,,, -39000,3.2371542,2.7668827,,,,,,,,,,,,,, -39100,2.841916,2.8301694,,,,,,,,,,,,,, -39200,3.0561395,2.7947598,,,,,,,,,,,,,, -39255,,,0.6478993892669678,1.5248180627822876,0.5951799750328064,1.7707302570343018,50000.0,0.4769000113010406,2.424974203109741,10000.0,13316.273628473282,13939.578382253649,13316.273628473282,620.9647953510284,1.0185627937316897,0.0 -39300,3.0790832,2.7253568,,,,,,,,,,,,,, -39400,2.790228,2.6804357,,,,,,,,,,,,,, -39500,3.515071,2.8735514,,,,,,,,,,,,,, -39600,3.6658,2.7871153,,,,,,,,,,,,,, -39700,3.7938046,2.825675,,,,,,,,,,,,,, -39800,2.9990551,2.8349826,,,,,,,,,,,,,, -39900,3.779057,2.716623,,,,,,,,,,,,,, -40000,3.152715,2.7810857,,,,,,,,,,,,,, -40100,3.1074164,2.762169,,,,,,,,,,,,,, -40200,3.9202068,2.7449865,,,,,,,,,,,,,, -40300,3.381642,2.8783882,,,,,,,,,,,,,, -40400,3.238005,2.7502713,,,,,,,,,,,,,, -40500,2.9090652,2.8368845,,,,,,,,,,,,,, -40600,3.2359047,2.722261,,,,,,,,,,,,,, -40700,2.8063893,2.7833018,,,,,,,,,,,,,, -40769,,,0.6623883843421936,1.4788810014724731,0.6086199879646301,1.71470308303833,50000.0,0.4877000153064728,2.3885064125061035,10000.0,13826.47026848793,14473.0782289505,13826.47026848793,644.1830842494965,1.0521259307861328,0.0 -40800,3.194148,2.7318864,,,,,,,,,,,,,, -40900,2.6311848,2.8975313,,,,,,,,,,,,,, -41000,3.2993534,2.7036257,,,,,,,,,,,,,, -41100,2.8824184,2.6998696,,,,,,,,,,,,,, -41200,3.178674,2.8368568,,,,,,,,,,,,,, -41300,3.0928805,2.8261242,,,,,,,,,,,,,, -41400,3.1267614,2.7446077,,,,,,,,,,,,,, -41500,2.9445205,2.7496126,,,,,,,,,,,,,, -41600,3.4656472,2.7778153,,,,,,,,,,,,,, -41700,3.1088734,2.805396,,,,,,,,,,,,,, -41800,3.263242,2.757567,,,,,,,,,,,,,, -41900,3.42966,2.703642,,,,,,,,,,,,,, -42000,3.1773157,2.8180163,,,,,,,,,,,,,, -42100,2.6754205,2.7951055,,,,,,,,,,,,,, -42200,3.2098594,2.8285844,,,,,,,,,,,,,, -42284,,,0.6502909660339355,1.5233001708984375,0.6042400002479553,1.73208749294281,50000.0,0.4802000224590301,2.396982431411743,10000.0,14336.686334371569,15012.248512983322,14336.686334371569,673.0554871559143,1.0817155838012695,0.0 -42300,3.0296924,2.6031418,,,,,,,,,,,,,, -42400,3.323272,2.6821814,,,,,,,,,,,,,, -42500,3.6942148,2.7852046,,,,,,,,,,,,,, -42600,3.0594234,2.751266,,,,,,,,,,,,,, -42700,3.2482154,2.6359386,,,,,,,,,,,,,, -42800,2.659241,2.788895,,,,,,,,,,,,,, -42900,3.586363,2.7769647,,,,,,,,,,,,,, -43000,3.156183,2.8212366,,,,,,,,,,,,,, -43100,3.4589348,2.7976227,,,,,,,,,,,,,, -43200,3.1176653,2.7413592,,,,,,,,,,,,,, -43300,2.8978012,2.7997985,,,,,,,,,,,,,, -43400,3.694133,2.7254896,,,,,,,,,,,,,, -43500,3.130869,2.7831695,,,,,,,,,,,,,, -43600,3.4345446,2.684178,,,,,,,,,,,,,, -43700,3.4454558,2.670086,,,,,,,,,,,,,, -43798,,,0.6564692258834839,1.4994585514068604,0.6126999855041504,1.6944860219955444,50000.0,0.4892000257968902,2.3520050048828125,10000.0,14846.899607419968,15545.86763715744,14846.899607419968,696.379273891449,1.1133880615234375,0.0 -43800,3.3399055,2.7915425,,,,,,,,,,,,,, -43900,4.1005516,2.747895,,,,,,,,,,,,,, -44000,3.4017353,2.8494906,,,,,,,,,,,,,, -44100,3.7700193,2.6921663,,,,,,,,,,,,,, -44200,3.2724898,2.7764912,,,,,,,,,,,,,, -44300,3.576636,2.767152,,,,,,,,,,,,,, -44400,3.4528277,2.742788,,,,,,,,,,,,,, -44500,3.0151484,2.8205307,,,,,,,,,,,,,, -44600,3.2722158,2.8435874,,,,,,,,,,,,,, -44700,3.0921082,2.7860346,,,,,,,,,,,,,, -44800,2.858264,2.6661215,,,,,,,,,,,,,, -44900,3.5114944,2.7306135,,,,,,,,,,,,,, -45000,3.1307497,2.8704975,,,,,,,,,,,,,, -45100,3.5397286,2.690322,,,,,,,,,,,,,, -45200,3.585983,2.8147907,,,,,,,,,,,,,, -45300,3.2738576,2.742382,,,,,,,,,,,,,, -45312,,,0.6901307106018066,1.335003137588501,0.60971999168396,1.702130913734436,50000.0,0.4855000376701355,2.367635488510132,10000.0,15356.965127944946,16078.829252958298,15356.965127944946,719.1942150592804,1.1434450149536133,0.0 -45400,3.3619587,2.764699,,,,,,,,,,,,,, -45500,2.554933,2.780468,,,,,,,,,,,,,, -45600,3.4400885,2.8268635,,,,,,,,,,,,,, -45700,3.27771,2.8440812,,,,,,,,,,,,,, -45800,3.439793,2.776506,,,,,,,,,,,,,, -45900,3.4762092,2.701297,,,,,,,,,,,,,, -46000,3.2881553,2.718849,,,,,,,,,,,,,, -46100,3.219333,2.81001,,,,,,,,,,,,,, -46200,3.5578842,2.7569668,,,,,,,,,,,,,, -46300,3.7611873,2.6721463,,,,,,,,,,,,,, -46400,3.05037,2.8446424,,,,,,,,,,,,,, -46500,2.924478,2.6973238,,,,,,,,,,,,,, -46600,2.811737,2.7825875,,,,,,,,,,,,,, -46700,2.8069293,2.6904001,,,,,,,,,,,,,, -46800,2.7392542,2.8031197,,,,,,,,,,,,,, -46826,,,0.6707788705825806,1.4408752918243408,0.6131199598312378,1.7045730352401731,50000.0,0.4852000176906585,2.384239435195923,10000.0,15867.11354660988,16609.47766971588,15867.11354660988,739.6124730110168,1.1734073162078855,0.0 -46900,3.4078875,2.8433244,,,,,,,,,,,,,, -47000,3.2546735,2.6786218,,,,,,,,,,,,,, -47100,2.9946866,2.7008345,,,,,,,,,,,,,, -47200,2.7802515,2.689988,,,,,,,,,,,,,, -47300,3.7065308,2.7406502,,,,,,,,,,,,,, -47400,3.2698042,2.6627755,,,,,,,,,,,,,, -47500,3.0505135,2.6858275,,,,,,,,,,,,,, -47600,3.3056371,2.860939,,,,,,,,,,,,,, -47700,3.8689506,2.685062,,,,,,,,,,,,,, -47800,3.4989748,2.739847,,,,,,,,,,,,,, -47900,3.3348587,2.710613,,,,,,,,,,,,,, -48000,3.412551,2.7192535,,,,,,,,,,,,,, -48100,3.1512961,2.6936142,,,,,,,,,,,,,, -48200,3.1771069,2.8566775,,,,,,,,,,,,,, -48300,2.8760104,2.7723026,,,,,,,,,,,,,, -48340,,,0.6635841727256775,1.4390766620635986,0.6102799773216248,1.6896299123764038,50000.0,0.4865000247955322,2.390916585922241,10000.0,16377.163805484772,17139.047789812088,16377.163805484772,759.0497002601624,1.20459246635437,0.0 -48400,3.6561651,2.7406979,,,,,,,,,,,,,, -48500,3.2079768,2.798457,,,,,,,,,,,,,, -48600,3.0372493,2.8069909,,,,,,,,,,,,,, -48700,3.0673406,2.737742,,,,,,,,,,,,,, -48800,3.1645458,2.8165746,,,,,,,,,,,,,, -48900,3.0869093,2.6443927,,,,,,,,,,,,,, -49000,2.7945929,2.7126143,,,,,,,,,,,,,, -49100,3.4597135,2.741825,,,,,,,,,,,,,, -49200,3.2514749,2.7364624,,,,,,,,,,,,,, -49300,3.141735,2.6929553,,,,,,,,,,,,,, -49400,3.6219375,2.7293577,,,,,,,,,,,,,, -49500,3.1073334,2.6857114,,,,,,,,,,,,,, -49600,2.8596697,2.6896398,,,,,,,,,,,,,, -49700,3.2080903,2.7308824,,,,,,,,,,,,,, -49800,3.595482,2.621256,,,,,,,,,,,,,, -49854,,,0.6621492505073547,1.4692184925079346,0.6146799921989441,1.694058895111084,50000.0,0.4975000321865082,2.350130081176758,10000.0,16887.197038412094,17667.17898607254,16887.197038412094,777.0617418289185,1.239149808883667,0.0 -49900,3.0309813,2.6990616,,,,,,,,,,,,,, -50000,3.2164717,2.7178097,,,,,,,,,,,,,, -50100,3.2360265,2.7142644,,,,,,,,,,,,,, -50200,3.1585016,2.789828,,,,,,,,,,,,,, -50300,3.0000029,2.6850672,,,,,,,,,,,,,, -50400,2.9905066,2.800795,,,,,,,,,,,,,, -50500,3.403159,2.6739998,,,,,,,,,,,,,, -50600,3.653812,2.7687569,,,,,,,,,,,,,, -50700,3.387186,2.6990204,,,,,,,,,,,,,, -50800,3.4408793,2.780238,,,,,,,,,,,,,, -50900,3.254808,2.8507242,,,,,,,,,,,,,, -51000,3.083522,2.766601,,,,,,,,,,,,,, -51100,2.8331606,2.757971,,,,,,,,,,,,,, -51200,3.473772,2.6757946,,,,,,,,,,,,,, -51300,3.53262,2.8025658,,,,,,,,,,,,,, -51369,,,0.6635642647743225,1.439455270767212,0.6191399693489075,1.646724820137024,50000.0,0.4969000220298767,2.307199478149414,10000.0,17397.333253145218,18195.217745542526,17397.333253145218,794.8757519721985,1.2756733894348145,0.0 -51400,3.6462076,2.6935773,,,,,,,,,,,,,, -51500,3.3512573,2.7853868,,,,,,,,,,,,,, -51600,3.1454508,2.6714754,,,,,,,,,,,,,, -51700,2.9968379,2.7409406,,,,,,,,,,,,,, -51800,2.792858,2.7467346,,,,,,,,,,,,,, -51900,3.5975337,2.7279754,,,,,,,,,,,,,, -52000,3.499188,2.7363522,,,,,,,,,,,,,, -52100,3.084086,2.7466974,,,,,,,,,,,,,, -52200,3.7199848,2.725644,,,,,,,,,,,,,, -52300,2.6734576,2.768328,,,,,,,,,,,,,, -52400,2.9764335,2.727923,,,,,,,,,,,,,, -52500,3.1757843,2.7667933,,,,,,,,,,,,,, -52600,3.3017128,2.798299,,,,,,,,,,,,,, -52700,3.259955,2.7847872,,,,,,,,,,,,,, -52800,3.2218869,2.7357993,,,,,,,,,,,,,, -52884,,,0.6588408946990967,1.4924957752227783,0.6126799583435059,1.7022305727005005,50000.0,0.4881000220775604,2.357887029647827,10000.0,17907.54860687256,18723.197404146194,17907.54860687256,812.5476040840149,1.3171625137329102,0.0 -52900,3.134647,2.7021718,,,,,,,,,,,,,, -53000,3.11233,2.7571542,,,,,,,,,,,,,, -53100,3.347177,2.719406,,,,,,,,,,,,,, -53200,3.103112,2.7249196,,,,,,,,,,,,,, -53300,3.003465,2.6606498,,,,,,,,,,,,,, -53400,3.3625314,2.6793256,,,,,,,,,,,,,, -53500,3.0712242,2.6391501,,,,,,,,,,,,,, -53600,3.5940547,2.5884657,,,,,,,,,,,,,, -53700,2.9733672,2.6641734,,,,,,,,,,,,,, -53800,3.7457018,2.6832485,,,,,,,,,,,,,, -53900,3.2133007,2.6461792,,,,,,,,,,,,,, -54000,3.2771556,2.730949,,,,,,,,,,,,,, -54100,3.643815,2.7341774,,,,,,,,,,,,,, -54200,4.19568,2.7111259,,,,,,,,,,,,,, -54300,3.005409,2.6692677,,,,,,,,,,,,,, -54399,,,0.7075693607330322,1.2901153564453125,0.625819981098175,1.6629422903060913,50000.0,0.5001000165939331,2.3184971809387207,10000.0,18417.55141377449,19250.939685344696,18417.55141377449,830.1899147033691,1.362135410308838,0.0 -54400,3.5492356,2.7314668,,,,,,,,,,,,,, -54500,3.3542655,2.660542,,,,,,,,,,,,,, -54600,3.3069785,2.727203,,,,,,,,,,,,,, -54700,3.2084358,2.689441,,,,,,,,,,,,,, -54800,3.0877383,2.7441037,,,,,,,,,,,,,, -54900,3.3799698,2.630784,,,,,,,,,,,,,, -55000,3.5887618,2.7269993,,,,,,,,,,,,,, -55100,3.1420488,2.8361244,,,,,,,,,,,,,, -55200,3.2836354,2.6975303,,,,,,,,,,,,,, -55300,3.3454506,2.7098217,,,,,,,,,,,,,, -55400,2.909742,2.726725,,,,,,,,,,,,,, -55500,3.484935,2.767774,,,,,,,,,,,,,, -55600,3.2123926,2.6507905,,,,,,,,,,,,,, -55700,3.428146,2.771651,,,,,,,,,,,,,, -55800,3.5374858,2.8308768,,,,,,,,,,,,,, -55900,3.3481603,2.6511803,,,,,,,,,,,,,, -55913,,,0.6686663031578064,1.4696898460388184,0.6124399900436401,1.7320475578308103,50000.0,0.4935000240802765,2.397702932357788,10000.0,18927.4927611351,19778.73899126053,18927.4927611351,847.9586672782898,1.3999717235565186,0.0 -56000,4.0065703,2.6559634,,,,,,,,,,,,,, -56100,2.9552975,2.6252131,,,,,,,,,,,,,, -56200,3.2788272,2.7596302,,,,,,,,,,,,,, -56300,3.0950656,2.752111,,,,,,,,,,,,,, -56400,2.8885832,2.6566615,,,,,,,,,,,,,, -56500,3.159505,2.6633432,,,,,,,,,,,,,, -56600,2.9907904,2.7480767,,,,,,,,,,,,,, -56700,3.0871718,2.6509151,,,,,,,,,,,,,, -56800,2.9025295,2.6823733,,,,,,,,,,,,,, -56900,3.892408,2.7750978,,,,,,,,,,,,,, -57000,2.967183,2.7116315,,,,,,,,,,,,,, -57100,3.4836621,2.7031674,,,,,,,,,,,,,, -57200,3.6762798,2.7732487,,,,,,,,,,,,,, -57300,3.283165,2.710113,,,,,,,,,,,,,, -57400,3.0740452,2.601092,,,,,,,,,,,,,, -57428,,,0.6722536683082581,1.435630440711975,0.6198399662971497,1.672557711601257,50000.0,0.5006000399589539,2.341546535491944,10000.0,19437.53251218796,20306.21404647827,19437.53251218796,865.3042812347412,1.4369032382965088,0.0 -57500,2.9696171,2.6834927,,,,,,,,,,,,,, -57600,3.472478,2.7074108,,,,,,,,,,,,,, -57700,3.6558697,2.7556305,,,,,,,,,,,,,, -57800,3.4347856,2.7742043,,,,,,,,,,,,,, -57900,3.2681763,2.6142366,,,,,,,,,,,,,, -58000,3.1566331,2.7577338,,,,,,,,,,,,,, -58100,3.3211226,2.7551763,,,,,,,,,,,,,, -58200,3.024868,2.6499312,,,,,,,,,,,,,, -58300,3.1419232,2.6679516,,,,,,,,,,,,,, -58400,3.1086054,2.7080822,,,,,,,,,,,,,, -58500,2.9658096,2.7536178,,,,,,,,,,,,,, -58600,3.2931535,2.7208304,,,,,,,,,,,,,, -58700,3.4915857,2.707922,,,,,,,,,,,,,, -58800,3.2031882,2.7576616,,,,,,,,,,,,,, -58900,2.8158674,2.723981,,,,,,,,,,,,,, -58942,,,0.6779735088348389,1.4131288528442385,0.6226599812507629,1.664193868637085,50000.0,0.5063000321388245,2.299123764038086,10000.0,19947.70263171196,20833.81520462036,19947.70263171196,882.6476812362671,1.4736227989196775,0.0 -59000,4.2940574,2.6995711,,,,,,,,,,,,,, -59100,3.3697593,2.7670016,,,,,,,,,,,,,, -59200,2.994976,2.7838674,,,,,,,,,,,,,, -59300,4.0910482,2.7193851,,,,,,,,,,,,,, -59400,3.0405192,2.7774215,,,,,,,,,,,,,, -59500,3.128449,2.6257331,,,,,,,,,,,,,, -59600,3.3327422,2.661035,,,,,,,,,,,,,, -59700,4.2351394,2.754327,,,,,,,,,,,,,, -59800,3.6785655,2.7599044,,,,,,,,,,,,,, -59900,3.1769638,2.64209,,,,,,,,,,,,,, -60000,3.545259,2.7003152,,,,,,,,,,,,,, -60100,3.0917149,2.7617037,,,,,,,,,,,,,, -60200,3.1495926,2.59311,,,,,,,,,,,,,, -60300,3.1824126,2.7732158,,,,,,,,,,,,,, -60400,3.1447353,2.6387281,,,,,,,,,,,,,, -60457,,,0.6662946343421936,1.4448381662368774,0.6218199729919434,1.656185269355774,50000.0,0.4940000176429748,2.3347575664520264,10000.0,20457.742354631424,21361.44446396828,20457.742354631424,900.1488373279572,1.509693622589111,0.0 -60500,4.016052,2.6299305,,,,,,,,,,,,,, -60600,3.0182474,2.6240196,,,,,,,,,,,,,, -60700,2.570963,2.4901829,,,,,,,,,,,,,, -60800,3.3386958,2.7553134,,,,,,,,,,,,,, -60900,2.8599164,2.6139297,,,,,,,,,,,,,, -61000,3.8757565,2.7626286,,,,,,,,,,,,,, -61100,3.0983956,2.7593675,,,,,,,,,,,,,, -61200,3.237271,2.603094,,,,,,,,,,,,,, -61300,3.584328,2.6371827,,,,,,,,,,,,,, -61400,3.27885,2.7179122,,,,,,,,,,,,,, -61500,3.299632,2.6439717,,,,,,,,,,,,,, -61600,3.141186,2.640944,,,,,,,,,,,,,, -61700,2.8334725,2.6298654,,,,,,,,,,,,,, -61800,3.2874799,2.6780133,,,,,,,,,,,,,, -61900,3.6592064,2.7095249,,,,,,,,,,,,,, -61971,,,0.6689453125,1.4233334064483645,0.6264599561691284,1.6362653970718384,50000.0,0.5004000067710876,2.3209445476531982,10000.0,20967.97685956955,21889.169400691982,20967.97685956955,917.5497057437896,1.5463056564331057,0.0 -62000,2.9864395,2.7581215,,,,,,,,,,,,,, -62100,3.1868095,2.6366804,,,,,,,,,,,,,, -62200,3.4954193,2.7228885,,,,,,,,,,,,,, -62300,3.0447037,2.583962,,,,,,,,,,,,,, -62400,3.6609118,2.6643593,,,,,,,,,,,,,, -62500,3.588388,2.6501718,,,,,,,,,,,,,, -62600,3.6336555,2.6664472,,,,,,,,,,,,,, -62700,4.266556,2.7665994,,,,,,,,,,,,,, -62800,2.9362257,2.62258,,,,,,,,,,,,,, -62900,3.485688,2.727025,,,,,,,,,,,,,, -63000,3.0651388,2.598122,,,,,,,,,,,,,, -63100,3.284726,2.6891577,,,,,,,,,,,,,, -63200,4.5820427,2.6597157,,,,,,,,,,,,,, -63300,3.0831392,2.7020502,,,,,,,,,,,,,, -63400,2.9538875,2.689981,,,,,,,,,,,,,, -63485,,,0.7110969424247742,1.2208093404769895,0.6271799802780151,1.6009138822555542,50000.0,0.5062000155448914,2.2677626609802246,10000.0,21477.900118112564,22416.84796738625,21477.900118112564,935.2121860980988,1.5857601165771484,0.0 -63500,3.4313788,2.589128,,,,,,,,,,,,,, -63600,3.3579075,2.7691479,,,,,,,,,,,,,, -63700,3.4703639,2.6677651,,,,,,,,,,,,,, -63800,3.491843,2.714209,,,,,,,,,,,,,, -63900,3.0676618,2.6370282,,,,,,,,,,,,,, -64000,4.201892,2.6423275,,,,,,,,,,,,,, -64100,3.885483,2.6562195,,,,,,,,,,,,,, -64200,3.4164052,2.5936975,,,,,,,,,,,,,, -64300,3.136847,2.73645,,,,,,,,,,,,,, -64400,3.1078026,2.732115,,,,,,,,,,,,,, -64500,3.529062,2.7441838,,,,,,,,,,,,,, -64600,3.5197928,2.637928,,,,,,,,,,,,,, -64700,3.1816852,2.6513493,,,,,,,,,,,,,, -64800,3.2935681,2.7345164,,,,,,,,,,,,,, -64900,3.786998,2.6280751,,,,,,,,,,,,,, -65000,,,0.6969068646430969,1.2933167219161987,0.6342200040817261,1.5750954151153564,50000.0,0.5092000365257263,2.258240222930908,10000.0,21988.11581516266,22944.5864379406,21988.11581516266,952.6368882656096,1.6321525573730469,0.0 -65000,3.1561487,2.5917385,,,,,,,,,,,,,, -65100,3.3614137,2.6315985,,,,,,,,,,,,,, -65200,3.7867386,2.6492667,,,,,,,,,,,,,, -65300,3.551054,2.6315908,,,,,,,,,,,,,, -65400,4.413917,2.5617626,,,,,,,,,,,,,, -65500,3.1631722,2.754449,,,,,,,,,,,,,, -65600,3.16768,2.5245345,,,,,,,,,,,,,, -65700,3.4988596,2.6495235,,,,,,,,,,,,,, -65800,3.5971859,2.6186044,,,,,,,,,,,,,, -65900,3.638631,2.645592,,,,,,,,,,,,,, -66000,3.2998986,2.7572682,,,,,,,,,,,,,, -66100,3.3378034,2.661611,,,,,,,,,,,,,, -66200,2.937489,2.6715672,,,,,,,,,,,,,, -66300,3.2623565,2.6135411,,,,,,,,,,,,,, -66400,3.537145,2.6461694,,,,,,,,,,,,,, -66500,3.423166,2.5833454,,,,,,,,,,,,,, -66515,,,0.6873405575752258,1.344896912574768,0.632099986076355,1.5866305828094482,50000.0,0.5099000334739685,2.246936559677124,10000.0,22498.35578727722,23472.325043201447,22498.35578727722,970.0448710918428,1.6718873977661133,0.0 -66600,3.1036868,2.7134597,,,,,,,,,,,,,, -66700,3.1759129,2.6759346,,,,,,,,,,,,,, -66800,3.6854599,2.5680182,,,,,,,,,,,,,, -66900,2.90355,2.6245344,,,,,,,,,,,,,, -67000,2.9077845,2.6497078,,,,,,,,,,,,,, -67100,3.1050088,2.6122131,,,,,,,,,,,,,, -67200,3.0840323,2.6006837,,,,,,,,,,,,,, -67300,3.0201538,2.7214732,,,,,,,,,,,,,, -67400,3.041352,2.6749313,,,,,,,,,,,,,, -67500,3.9507234,2.6234572,,,,,,,,,,,,,, -67600,3.2342856,2.6835713,,,,,,,,,,,,,, -67700,4.1243477,2.686873,,,,,,,,,,,,,, -67800,3.1112626,2.6253946,,,,,,,,,,,,,, -67900,3.007147,2.661968,,,,,,,,,,,,,, -68000,3.1798766,2.5400305,,,,,,,,,,,,,, -68030,,,0.6845503449440002,1.392478346824646,0.6325799822807312,1.6152527332305908,50000.0,0.5042000412940979,2.301208734512329,10000.0,23008.40454888344,23999.7828745842,23008.40454888344,987.3612501621246,1.7124810218811035,0.0 -68100,2.931924,2.7221715,,,,,,,,,,,,,, -68200,4.010598,2.6381273,,,,,,,,,,,,,, -68300,3.292219,2.7287886,,,,,,,,,,,,,, -68400,3.8490598,2.6426966,,,,,,,,,,,,,, -68500,3.2216156,2.6690223,,,,,,,,,,,,,, -68600,4.007391,2.5853062,,,,,,,,,,,,,, -68700,3.9169052,2.6335974,,,,,,,,,,,,,, -68800,3.3737307,2.6518054,,,,,,,,,,,,,, -68900,3.2512712,2.7770808,,,,,,,,,,,,,, -69000,3.1193366,2.6422787,,,,,,,,,,,,,, -69100,3.2397552,2.7411432,,,,,,,,,,,,,, -69200,3.726816,2.6863325,,,,,,,,,,,,,, -69300,3.4193096,2.7168953,,,,,,,,,,,,,, -69400,3.9781814,2.5982785,,,,,,,,,,,,,, -69500,3.6033287,2.675119,,,,,,,,,,,,,, -69544,,,0.6779336333274841,1.4261982440948486,0.6293999552726746,1.6455281972885132,50000.0,0.5063000321388245,2.339195966720581,10000.0,23518.33584046364,24527.45420455933,23518.33584046364,1005.0094563961028,1.7519042491912842,0.0 -69600,3.6362936,2.5778897,,,,,,,,,,,,,, -69700,3.4178853,2.5844452,,,,,,,,,,,,,, -69800,3.436986,2.6732247,,,,,,,,,,,,,, -69900,3.3142579,2.616139,,,,,,,,,,,,,, -70000,3.479165,2.5508828,,,,,,,,,,,,,, -70100,3.258038,2.6374717,,,,,,,,,,,,,, -70200,3.4591832,2.6640258,,,,,,,,,,,,,, -70300,3.3013291,2.7299047,,,,,,,,,,,,,, -70400,3.042864,2.655401,,,,,,,,,,,,,, -70500,3.4261415,2.698742,,,,,,,,,,,,,, -70600,3.3495824,2.6774094,,,,,,,,,,,,,, -70700,3.6231854,2.650593,,,,,,,,,,,,,, -70800,3.3299232,2.6799254,,,,,,,,,,,,,, -70900,3.4890022,2.599536,,,,,,,,,,,,,, -71000,3.05891,2.6593883,,,,,,,,,,,,,, -71059,,,0.6732102632522583,1.407064437866211,0.6301599740982056,1.6089874505996704,50000.0,0.4982000291347503,2.287114143371582,10000.0,24028.55358481407,25055.36220574379,24028.55358481407,1022.608335018158,1.7905707359313965,0.0 -71100,3.132191,2.5982697,,,,,,,,,,,,,, -71200,3.5503275,2.603496,,,,,,,,,,,,,, -71300,3.4446225,2.6288214,,,,,,,,,,,,,, -71400,4.378181,2.7311168,,,,,,,,,,,,,, -71500,3.1090045,2.6248183,,,,,,,,,,,,,, -71600,3.2790287,2.7028823,,,,,,,,,,,,,, -71700,3.3433268,2.6784172,,,,,,,,,,,,,, -71800,3.2987523,2.6718748,,,,,,,,,,,,,, -71900,3.9061027,2.5875244,,,,,,,,,,,,,, -72000,3.1530814,2.5841427,,,,,,,,,,,,,, -72100,3.865884,2.5619454,,,,,,,,,,,,,, -72200,3.7728493,2.5975542,,,,,,,,,,,,,, -72300,3.2886,2.6487222,,,,,,,,,,,,,, -72400,3.3767102,2.6644096,,,,,,,,,,,,,, -72500,3.309688,2.6757507,,,,,,,,,,,,,, -72574,,,0.7083266973495483,1.2449584007263184,0.6342399716377258,1.5873527526855469,50000.0,0.5004000067710876,2.2946672439575195,10000.0,24538.758437633514,25584.37909555435,24538.758437633514,1041.3291449546814,1.8296470642089844,0.0 -72600,3.6551652,2.7389119,,,,,,,,,,,,,, -72700,3.5440335,2.7689347,,,,,,,,,,,,,, -72800,4.305389,2.7308278,,,,,,,,,,,,,, -72900,3.556682,2.6762419,,,,,,,,,,,,,, -73000,3.640494,2.6196275,,,,,,,,,,,,,, -73100,3.4428585,2.6628854,,,,,,,,,,,,,, -73200,3.9866428,2.6434653,,,,,,,,,,,,,, -73300,4.134654,2.6901488,,,,,,,,,,,,,, -73400,4.001355,2.6273713,,,,,,,,,,,,,, -73500,3.9256468,2.659739,,,,,,,,,,,,,, -73600,3.600205,2.6096919,,,,,,,,,,,,,, -73700,3.1111746,2.7039413,,,,,,,,,,,,,, -73800,3.8476574,2.6824028,,,,,,,,,,,,,, -73900,3.402867,2.6282983,,,,,,,,,,,,,, -74000,3.2623055,2.5701013,,,,,,,,,,,,,, -74089,,,0.6872608065605164,1.3399715423583984,0.6256399750709534,1.6160084009170532,50000.0,0.5052000284194946,2.28117036819458,10000.0,25048.81888151169,26112.01877140999,25048.81888151169,1058.8174991607666,1.8682844638824463,0.0 -74100,3.4694583,2.6978939,,,,,,,,,,,,,, -74200,3.1269052,2.6002991,,,,,,,,,,,,,, -74300,3.4000773,2.6392298,,,,,,,,,,,,,, -74400,3.4607525,2.5341365,,,,,,,,,,,,,, -74500,3.5302339,2.5882456,,,,,,,,,,,,,, -74600,3.6048574,2.6790032,,,,,,,,,,,,,, -74700,3.2936792,2.6809206,,,,,,,,,,,,,, -74800,3.2005591,2.6824389,,,,,,,,,,,,,, -74900,3.8915732,2.6095192,,,,,,,,,,,,,, -75000,3.329798,2.7588472,,,,,,,,,,,,,, -75100,3.596984,2.7016046,,,,,,,,,,,,,, -75200,3.3812568,2.5490012,,,,,,,,,,,,,, -75300,3.6704407,2.5900097,,,,,,,,,,,,,, -75400,3.5436692,2.5859563,,,,,,,,,,,,,, -75500,4.4234447,2.7409158,,,,,,,,,,,,,, -75600,3.4130185,2.6782255,,,,,,,,,,,,,, -75604,,,0.7000358700752258,1.2867945432662964,0.6447199583053589,1.5424926280975342,50000.0,0.5152000188827515,2.232867956161499,10000.0,25558.90424060821,26640.29567885399,25558.90424060821,1076.917620420456,1.9076271057128904,0.0 -75700,3.3042512,2.6561732,,,,,,,,,,,,,, -75800,3.7116199,2.6386206,,,,,,,,,,,,,, -75900,4.4389224,2.588863,,,,,,,,,,,,,, -76000,4.7662206,2.6681657,,,,,,,,,,,,,, -76100,3.5481594,2.5976446,,,,,,,,,,,,,, -76200,3.266333,2.7675407,,,,,,,,,,,,,, -76300,3.9793553,2.6187735,,,,,,,,,,,,,, -76400,3.293599,2.5918438,,,,,,,,,,,,,, -76500,3.9243264,2.639302,,,,,,,,,,,,,, -76600,3.9263346,2.638483,,,,,,,,,,,,,, -76700,3.6574628,2.5872266,,,,,,,,,,,,,, -76800,3.4057493,2.6369421,,,,,,,,,,,,,, -76900,3.3976572,2.5825796,,,,,,,,,,,,,, -77000,3.7421093,2.6338077,,,,,,,,,,,,,, -77100,3.3821507,2.6541128,,,,,,,,,,,,,, -77119,,,0.6952527165412903,1.321738600730896,0.647599995136261,1.5462092161178589,50000.0,0.5154000520706177,2.214272737503052,10000.0,26069.08917927742,27168.2249045372,26069.08917927742,1094.5691964626312,1.9482781887054443,0.0 -77200,3.1420667,2.560467,,,,,,,,,,,,,, -77300,3.9228706,2.7261434,,,,,,,,,,,,,, -77400,4.3540316,2.6393008,,,,,,,,,,,,,, -77500,3.5712311,2.5972228,,,,,,,,,,,,,, -77600,3.4518547,2.6700003,,,,,,,,,,,,,, -77700,3.411755,2.5994458,,,,,,,,,,,,,, -77800,3.6770175,2.5450246,,,,,,,,,,,,,, -77900,3.7181795,2.6423512,,,,,,,,,,,,,, -78000,3.5826993,2.5379615,,,,,,,,,,,,,, -78100,3.1345713,2.6352558,,,,,,,,,,,,,, -78200,3.30184,2.6550856,,,,,,,,,,,,,, -78300,3.1422305,2.5709736,,,,,,,,,,,,,, -78400,3.4606152,2.6553707,,,,,,,,,,,,,, -78500,3.7325702,2.6370153,,,,,,,,,,,,,, -78600,3.1118672,2.5620701,,,,,,,,,,,,,, -78632,,,0.6901904940605164,1.3460441827774048,0.6373400092124939,1.582018494606018,50000.0,0.5111000537872314,2.251178503036499,10000.0,26578.31113815308,27696.523701667786,26578.31113815308,1112.730740070343,2.8109002113342285,0.0 -78700,4.337728,2.597161,,,,,,,,,,,,,, -78800,3.548101,2.6501377,,,,,,,,,,,,,, -78900,3.4819927,2.666302,,,,,,,,,,,,,, -79000,4.7225947,2.6083453,,,,,,,,,,,,,, -79100,3.3291416,2.5605848,,,,,,,,,,,,,, -79200,3.6392283,2.5083518,,,,,,,,,,,,,, -79300,3.2880006,2.766232,,,,,,,,,,,,,, -79400,3.4447129,2.6092296,,,,,,,,,,,,,, -79500,3.3423083,2.553775,,,,,,,,,,,,,, -79600,3.2636743,2.5318391,,,,,,,,,,,,,, -79700,3.655708,2.5760088,,,,,,,,,,,,,, -79800,3.7453144,2.6227531,,,,,,,,,,,,,, -79900,3.4449387,2.639142,,,,,,,,,,,,,, -80000,3.418806,2.6347806,,,,,,,,,,,,,, -80100,4.0167894,2.6382551,,,,,,,,,,,,,, -80148,,,0.6938576102256775,1.3477110862731934,0.6477599740028381,1.5738341808319092,50000.0,0.5189000368118286,2.245602369308472,10000.0,27088.510328292847,28224.07705426216,27088.510328292847,1129.9987313747406,2.845014333724976,0.0 -80200,3.0448565,2.5853255,,,,,,,,,,,,,, -80300,4.1134815,2.605764,,,,,,,,,,,,,, -80400,4.1333075,2.6333618,,,,,,,,,,,,,, -80500,4.1240883,2.6958034,,,,,,,,,,,,,, -80600,4.170234,2.5416937,,,,,,,,,,,,,, -80700,3.670374,2.5830214,,,,,,,,,,,,,, -80800,3.9681306,2.6509616,,,,,,,,,,,,,, -80900,3.92172,2.5727048,,,,,,,,,,,,,, -81000,3.5843463,2.6381464,,,,,,,,,,,,,, -81100,3.968974,2.5953598,,,,,,,,,,,,,, -81200,3.8873916,2.6705182,,,,,,,,,,,,,, -81300,3.704678,2.5978374,,,,,,,,,,,,,, -81400,3.6403139,2.6765301,,,,,,,,,,,,,, -81500,4.2771087,2.5841365,,,,,,,,,,,,,, -81600,3.8028872,2.5524623,,,,,,,,,,,,,, -81663,,,0.7150828838348389,1.2265141010284424,0.6447399854660034,1.5567433834075928,50000.0,0.5138000249862671,2.23746657371521,10000.0,27598.55618476868,28751.632429361343,27598.55618476868,1147.4161262512207,2.885287284851074,0.0 -81700,3.5552523,2.7104,,,,,,,,,,,,,, -81800,3.8380482,2.5853825,,,,,,,,,,,,,, -81900,3.4589608,2.538433,,,,,,,,,,,,,, -82000,3.6874752,2.5121973,,,,,,,,,,,,,, -82100,3.533249,2.6749492,,,,,,,,,,,,,, -82200,3.6861389,2.6070335,,,,,,,,,,,,,, -82300,3.1508248,2.6253903,,,,,,,,,,,,,, -82400,3.8683672,2.523171,,,,,,,,,,,,,, -82500,3.3876305,2.5713336,,,,,,,,,,,,,, -82600,3.7322543,2.6449575,,,,,,,,,,,,,, -82700,3.9305973,2.5835638,,,,,,,,,,,,,, -82800,3.9449093,2.6679094,,,,,,,,,,,,,, -82900,3.9510727,2.5395033,,,,,,,,,,,,,, -83000,3.3472373,2.6892657,,,,,,,,,,,,,, -83100,3.928046,2.6153407,,,,,,,,,,,,,, -83177,,,0.7067522406578064,1.2656277418136597,0.6476799845695496,1.5454624891281128,50000.0,0.5205000042915344,2.2279982566833496,10000.0,28108.67711853981,29279.347688674927,28108.67711853981,1164.917008638382,2.9261295795440674,0.0 -83200,3.241657,2.5292947,,,,,,,,,,,,,, -83300,3.8783858,2.5560453,,,,,,,,,,,,,, -83400,3.3997178,2.6044083,,,,,,,,,,,,,, -83500,4.0684395,2.575631,,,,,,,,,,,,,, -83600,4.206207,2.6330404,,,,,,,,,,,,,, -83700,3.54259,2.5892427,,,,,,,,,,,,,, -83800,3.7510488,2.6183834,,,,,,,,,,,,,, -83900,3.6198928,2.5754008,,,,,,,,,,,,,, -84000,3.4047391,2.6320243,,,,,,,,,,,,,, -84100,3.650529,2.6807308,,,,,,,,,,,,,, -84200,5.7529573,2.5412447,,,,,,,,,,,,,, -84300,3.4791334,2.589182,,,,,,,,,,,,,, -84400,3.9847412,2.5868785,,,,,,,,,,,,,, -84500,3.33972,2.567091,,,,,,,,,,,,,, -84600,3.4283981,2.550015,,,,,,,,,,,,,, -84692,,,0.70804762840271,1.2592806816101074,0.6536399722099304,1.508924961090088,50000.0,0.526900053024292,2.189680576324463,10000.0,28618.81132388115,29807.00520181656,28618.81132388115,1182.339945077896,2.973907709121704,0.0 -84700,4.141605,2.5834908,,,,,,,,,,,,,, -84800,4.1220293,2.5276158,,,,,,,,,,,,,, -84900,4.513106,2.5958648,,,,,,,,,,,,,, -85000,3.8528454,2.6391115,,,,,,,,,,,,,, -85100,3.8771098,2.603716,,,,,,,,,,,,,, -85200,3.960908,2.5917368,,,,,,,,,,,,,, -85300,3.640572,2.5917215,,,,,,,,,,,,,, -85400,3.3178186,2.549287,,,,,,,,,,,,,, -85500,3.5887237,2.6933074,,,,,,,,,,,,,, -85600,3.628155,2.6528172,,,,,,,,,,,,,, -85700,3.7022812,2.603843,,,,,,,,,,,,,, -85800,4.342007,2.581752,,,,,,,,,,,,,, -85900,3.2877836,2.537941,,,,,,,,,,,,,, -86000,3.3152606,2.567499,,,,,,,,,,,,,, -86100,3.4086773,2.6351657,,,,,,,,,,,,,, -86200,3.7299607,2.5628603,,,,,,,,,,,,,, -86207,,,0.7067123651504517,1.260049819946289,0.653659999370575,1.5093188285827637,50000.0,0.5340999960899353,2.161987543106079,10000.0,29128.871512889866,30334.79681873321,29128.871512889866,1199.9780325889587,3.0153017044067383,0.0 -86300,3.9714997,2.5752864,,,,,,,,,,,,,, -86400,3.6855214,2.59881,,,,,,,,,,,,,, -86500,4.2972665,2.5785396,,,,,,,,,,,,,, -86600,3.9206517,2.5311754,,,,,,,,,,,,,, -86700,3.817015,2.5857642,,,,,,,,,,,,,, -86800,4.3611836,2.6157365,,,,,,,,,,,,,, -86900,3.6137993,2.5654883,,,,,,,,,,,,,, -87000,3.658853,2.5562427,,,,,,,,,,,,,, -87100,3.754609,2.6279814,,,,,,,,,,,,,, -87200,3.8046806,2.5424378,,,,,,,,,,,,,, -87300,3.506967,2.5368876,,,,,,,,,,,,,, -87400,3.3546917,2.4799747,,,,,,,,,,,,,, -87500,3.5055106,2.5200279,,,,,,,,,,,,,, -87600,3.899728,2.5740972,,,,,,,,,,,,,, -87700,3.637482,2.5726361,,,,,,,,,,,,,, -87721,,,0.7043008208274841,1.2704253196716309,0.6558200120925903,1.510108232498169,50000.0,0.5281000137329102,2.171442985534668,10000.0,29638.96631598473,30862.28338265419,29638.96631598473,1217.268765926361,3.0644567012786865,0.0 -87800,3.5421798,2.6602874,,,,,,,,,,,,,, -87900,4.6448994,2.578662,,,,,,,,,,,,,, -88000,4.456308,2.5812376,,,,,,,,,,,,,, -88100,3.665317,2.501773,,,,,,,,,,,,,, -88200,4.121338,2.6594589,,,,,,,,,,,,,, -88300,3.437668,2.5550187,,,,,,,,,,,,,, -88400,3.4335957,2.5222268,,,,,,,,,,,,,, -88500,3.5361917,2.5438929,,,,,,,,,,,,,, -88600,3.5094938,2.5225482,,,,,,,,,,,,,, -88700,3.6785948,2.5892537,,,,,,,,,,,,,, -88800,4.0498796,2.5577292,,,,,,,,,,,,,, -88900,3.689616,2.6112428,,,,,,,,,,,,,, -89000,3.673923,2.6088848,,,,,,,,,,,,,, -89100,4.114251,2.4853978,,,,,,,,,,,,,, -89200,3.3874593,2.527661,,,,,,,,,,,,,, -89236,,,0.7318239808082581,1.1713844537734983,0.6634599566459656,1.469886064529419,50000.0,0.5405000448226929,2.117152214050293,10000.0,30149.158131837845,31389.92377448082,30149.158131837845,1234.6226682662964,3.106571912765503,0.0 -89300,3.8378837,2.5027406,,,,,,,,,,,,,, -89400,3.678806,2.5372071,,,,,,,,,,,,,, -89500,3.5228589,2.5375414,,,,,,,,,,,,,, -89600,3.630208,2.5669034,,,,,,,,,,,,,, -89700,4.5716558,2.6314974,,,,,,,,,,,,,, -89800,3.6202772,2.5535097,,,,,,,,,,,,,, -89900,3.7595491,2.529773,,,,,,,,,,,,,, -90000,3.5784755,2.5161593,,,,,,,,,,,,,, -90100,3.4538815,2.5660844,,,,,,,,,,,,,, -90200,3.8154833,2.579329,,,,,,,,,,,,,, -90300,3.4362857,2.6204553,,,,,,,,,,,,,, -90400,3.8776195,2.6286151,,,,,,,,,,,,,, -90500,3.4119132,2.62454,,,,,,,,,,,,,, -90600,4.1967187,2.4585962,,,,,,,,,,,,,, -90700,3.500329,2.507511,,,,,,,,,,,,,, -90751,,,0.724609375,1.1772572994232178,0.6525599956512451,1.5016096830368042,50000.0,0.5282000303268433,2.163940191268921,10000.0,30659.136559963223,31917.36545681953,30659.136559963223,1251.9876792430878,3.152724266052246,0.0 -90800,3.7555153,2.6363819,,,,,,,,,,,,,, -90900,3.9731524,2.5316043,,,,,,,,,,,,,, -91000,3.7674713,2.535871,,,,,,,,,,,,,, -91100,3.7056887,2.4659443,,,,,,,,,,,,,, -91200,3.7156525,2.477878,,,,,,,,,,,,,, -91300,3.7467282,2.5336475,,,,,,,,,,,,,, -91400,3.654282,2.3858984,,,,,,,,,,,,,, -91500,4.0267897,2.611567,,,,,,,,,,,,,, -91600,4.5557513,2.550785,,,,,,,,,,,,,, -91700,3.6734612,2.6955445,,,,,,,,,,,,,, -91800,4.025127,2.5491564,,,,,,,,,,,,,, -91900,3.6699102,2.4373817,,,,,,,,,,,,,, -92000,3.9653733,2.4834013,,,,,,,,,,,,,, -92100,3.640006,2.4753628,,,,,,,,,,,,,, -92200,3.4601195,2.483696,,,,,,,,,,,,,, -92266,,,0.7249082922935486,1.186362624168396,0.6613199710845947,1.461118459701538,50000.0,0.541700005531311,2.0914676189422607,10000.0,31169.08204507828,32444.939204216003,31169.08204507828,1269.5212585926056,3.195277690887451,0.0 -92300,3.6202116,2.4173005,,,,,,,,,,,,,, -92400,3.7416954,2.5139663,,,,,,,,,,,,,, -92500,3.4685774,2.4654136,,,,,,,,,,,,,, -92600,3.9878001,2.5150847,,,,,,,,,,,,,, -92700,3.7050283,2.50812,,,,,,,,,,,,,, -92800,4.149043,2.5811708,,,,,,,,,,,,,, -92900,3.9555955,2.5751727,,,,,,,,,,,,,, -93000,3.6383657,2.4934618,,,,,,,,,,,,,, -93100,3.7827477,2.614709,,,,,,,,,,,,,, -93200,4.067478,2.507654,,,,,,,,,,,,,, -93300,4.325792,2.5360553,,,,,,,,,,,,,, -93400,3.8070076,2.5516148,,,,,,,,,,,,,, -93500,3.8907092,2.5007815,,,,,,,,,,,,,, -93600,3.9811795,2.5128577,,,,,,,,,,,,,, -93700,4.15945,2.5988693,,,,,,,,,,,,,, -93780,,,0.7208425998687744,1.238020896911621,0.6592599749565125,1.504926085472107,50000.0,0.5371000170707703,2.135526180267334,10000.0,31679.052606105804,32972.2160525322,31679.052606105804,1286.7298786640167,3.240107774734497,0.0 -93800,4.0750732,2.510066,,,,,,,,,,,,,, -93900,4.1088204,2.632221,,,,,,,,,,,,,, -94000,3.817677,2.4989676,,,,,,,,,,,,,, -94100,4.0435715,2.5906978,,,,,,,,,,,,,, -94200,4.0384665,2.5303214,,,,,,,,,,,,,, -94300,4.73162,2.62061,,,,,,,,,,,,,, -94400,4.083166,2.5439239,,,,,,,,,,,,,, -94500,3.599483,2.4855652,,,,,,,,,,,,,, -94600,3.882592,2.4928708,,,,,,,,,,,,,, -94700,3.973979,2.5217814,,,,,,,,,,,,,, -94800,4.7545114,2.4833264,,,,,,,,,,,,,, -94900,4.0521693,2.4686973,,,,,,,,,,,,,, -95000,3.6801317,2.5550818,,,,,,,,,,,,,, -95100,3.833239,2.5198786,,,,,,,,,,,,,, -95200,4.1644745,2.5161037,,,,,,,,,,,,,, -95294,,,0.7150231003761292,1.20890474319458,0.6631799936294556,1.4594324827194214,50000.0,0.5412000417709351,2.117387533187866,10000.0,32189.051352262497,33499.531002283096,32189.051352262497,1303.9475963115692,3.285914421081543,0.0 -95300,3.6834738,2.5767198,,,,,,,,,,,,,, -95400,4.4910283,2.4880986,,,,,,,,,,,,,, -95500,3.7315676,2.4797406,,,,,,,,,,,,,, -95600,4.047668,2.427103,,,,,,,,,,,,,, -95700,4.3312697,2.51897,,,,,,,,,,,,,, -95800,4.5150323,2.5846376,,,,,,,,,,,,,, -95900,4.5620675,2.5389574,,,,,,,,,,,,,, -96000,3.295518,2.481341,,,,,,,,,,,,,, -96100,3.6985037,2.5653505,,,,,,,,,,,,,, -96200,3.9673398,2.4844487,,,,,,,,,,,,,, -96300,4.343033,2.5664945,,,,,,,,,,,,,, -96400,4.065496,2.6114135,,,,,,,,,,,,,, -96500,4.4638867,2.4897606,,,,,,,,,,,,,, -96600,3.8990176,2.5603306,,,,,,,,,,,,,, -96700,4.019423,2.474644,,,,,,,,,,,,,, -96800,3.7002468,2.5768566,,,,,,,,,,,,,, -96808,,,0.721101701259613,1.1796294450759888,0.6640200018882751,1.435350775718689,50000.0,0.5457000136375427,2.08373498916626,10000.0,32699.01106214524,34027.04830813408,32699.01106214524,1321.4094231128693,3.3290951251983643,0.0 -96900,3.538811,2.576786,,,,,,,,,,,,,, -97000,4.301903,2.5226052,,,,,,,,,,,,,, -97100,3.379186,2.5261495,,,,,,,,,,,,,, -97200,4.131658,2.4793167,,,,,,,,,,,,,, -97300,3.9589906,2.51932,,,,,,,,,,,,,, -97400,4.509411,2.4934092,,,,,,,,,,,,,, -97500,3.8083687,2.4539742,,,,,,,,,,,,,, -97600,4.1653266,2.5056884,,,,,,,,,,,,,, -97700,4.6158967,2.4983163,,,,,,,,,,,,,, -97800,4.0267015,2.6231217,,,,,,,,,,,,,, -97900,4.6309295,2.4884496,,,,,,,,,,,,,, -98000,4.6627536,2.5120745,,,,,,,,,,,,,, -98100,3.4634645,2.5036287,,,,,,,,,,,,,, -98200,3.8177762,2.4226062,,,,,,,,,,,,,, -98300,3.9664454,2.511054,,,,,,,,,,,,,, -98323,,,0.7517338991165161,1.0737617015838623,0.663100004196167,1.4655168056488037,50000.0,0.5379000306129456,2.11805272102356,10000.0,33209.20894575119,34554.6969575882,33209.20894575119,1338.7655036449432,3.371743440628052,0.0 -98400,4.1744304,2.4579644,,,,,,,,,,,,,, -98500,3.6467733,2.5418653,,,,,,,,,,,,,, -98600,3.5614398,2.4971352,,,,,,,,,,,,,, -98700,3.6325197,2.4113998,,,,,,,,,,,,,, -98800,3.6845336,2.3999417,,,,,,,,,,,,,, -98900,3.9686882,2.4578555,,,,,,,,,,,,,, -99000,4.1327796,2.5561364,,,,,,,,,,,,,, -99100,3.9708989,2.4818554,,,,,,,,,,,,,, -99200,4.131904,2.5318928,,,,,,,,,,,,,, -99300,3.770465,2.4734669,,,,,,,,,,,,,, -99400,4.416468,2.5126398,,,,,,,,,,,,,, -99500,4.3520217,2.4941757,,,,,,,,,,,,,, -99600,3.8429556,2.4761019,,,,,,,,,,,,,, -99700,3.5890226,2.3956687,,,,,,,,,,,,,, -99800,4.2402306,2.4931307,,,,,,,,,,,,,, -99838,,,0.7379822731018066,1.116407036781311,0.6669999957084656,1.440180420875549,50000.0,0.5509999990463257,2.067492961883545,10000.0,33719.22984600067,35082.03232550621,33719.22984600067,1355.984681367874,3.415469169616699,0.0 -99900,4.2233644,2.565966,,,,,,,,,,,,,, -100000,4.8560653,2.5538616,,,,,,,,,,,,,, -100100,3.7362356,2.5245078,,,,,,,,,,,,,, -100200,3.7844641,2.384875,,,,,,,,,,,,,, -100300,4.515004,2.5224845,,,,,,,,,,,,,, -100400,4.074322,2.4989781,,,,,,,,,,,,,, -100500,4.4842696,2.5361943,,,,,,,,,,,,,, -100600,3.825205,2.5320997,,,,,,,,,,,,,, -100700,4.726372,2.5283132,,,,,,,,,,,,,, -100800,4.4728394,2.4405823,,,,,,,,,,,,,, -100900,4.7881293,2.4509838,,,,,,,,,,,,,, -101000,4.0001893,2.4953258,,,,,,,,,,,,,, -101100,3.778007,2.4831488,,,,,,,,,,,,,, -101200,4.2188115,2.5512369,,,,,,,,,,,,,, -101300,3.772932,2.4923143,,,,,,,,,,,,,, -101353,,,0.7453961968421936,1.0961382389068604,0.676099956035614,1.4047099351882937,50000.0,0.5529000163078308,2.0516908168792725,10000.0,34229.33446264267,35609.674355983734,34229.33446264267,1373.4264228343964,3.4589407444000244,0.0 -101400,4.7320795,2.5092015,,,,,,,,,,,,,, -101500,4.0568576,2.4196386,,,,,,,,,,,,,, -101600,3.8797786,2.538221,,,,,,,,,,,,,, -101700,3.6722028,2.4323523,,,,,,,,,,,,,, -101800,3.956643,2.452449,,,,,,,,,,,,,, -101900,4.109291,2.3904905,,,,,,,,,,,,,, -102000,4.440391,2.494332,,,,,,,,,,,,,, -102100,4.095398,2.5302486,,,,,,,,,,,,,, -102200,3.811548,2.448889,,,,,,,,,,,,,, -102300,4.405526,2.5379095,,,,,,,,,,,,,, -102400,4.475648,2.53768,,,,,,,,,,,,,, -102500,4.4336524,2.4604115,,,,,,,,,,,,,, -102600,4.23767,2.4730222,,,,,,,,,,,,,, -102700,4.583588,2.4988256,,,,,,,,,,,,,, -102800,3.7460065,2.4566615,,,,,,,,,,,,,, -102868,,,0.7364476919174194,1.114004135131836,0.672980010509491,1.399997353553772,50000.0,0.5488000512123108,2.043478488922119,10000.0,34739.40772628784,36137.24479365349,34739.40772628784,1390.8268103599548,3.504430055618286,0.0 -102900,3.9935596,2.374616,,,,,,,,,,,,,, -103000,4.130552,2.4489682,,,,,,,,,,,,,, -103100,3.5764782,2.42035,,,,,,,,,,,,,, -103200,4.0527444,2.5220654,,,,,,,,,,,,,, -103300,4.104189,2.444233,,,,,,,,,,,,,, -103400,3.7660167,2.4338908,,,,,,,,,,,,,, -103500,3.695524,2.4195895,,,,,,,,,,,,,, -103600,3.8872426,2.438229,,,,,,,,,,,,,, -103700,4.233918,2.4401338,,,,,,,,,,,,,, -103800,4.0214186,2.5354185,,,,,,,,,,,,,, -103900,3.9951007,2.4312932,,,,,,,,,,,,,, -104000,4.3678,2.49495,,,,,,,,,,,,,, -104100,4.28421,2.4695568,,,,,,,,,,,,,, -104200,4.508502,2.4106648,,,,,,,,,,,,,, -104300,4.0934224,2.4209948,,,,,,,,,,,,,, -104382,,,0.7351921200752258,1.1262609958648682,0.6737599968910217,1.4070687294006348,50000.0,0.5547000169754028,2.052401065826416,10000.0,35249.39946103096,36664.71185588837,35249.39946103096,1408.2055933475494,3.548323154449463,0.0 -104400,4.1031694,2.3909173,,,,,,,,,,,,,, -104500,4.0846376,2.4587905,,,,,,,,,,,,,, -104600,4.4666038,2.493261,,,,,,,,,,,,,, -104700,4.384038,2.4363964,,,,,,,,,,,,,, -104800,4.478115,2.499945,,,,,,,,,,,,,, -104900,3.9973032,2.505605,,,,,,,,,,,,,, -105000,4.023939,2.5089145,,,,,,,,,,,,,, -105100,4.681536,2.464655,,,,,,,,,,,,,, -105200,4.1377935,2.534628,,,,,,,,,,,,,, -105300,4.539337,2.55731,,,,,,,,,,,,,, -105400,4.0539503,2.5201976,,,,,,,,,,,,,, -105500,4.220632,2.4581106,,,,,,,,,,,,,, -105600,4.2237062,2.4194627,,,,,,,,,,,,,, -105700,4.79419,2.4931376,,,,,,,,,,,,,, -105800,3.7091987,2.405254,,,,,,,,,,,,,, -105897,,,0.7310666441917419,1.1510498523712158,0.6747599840164185,1.4065779447555542,50000.0,0.5578000545501709,2.038433790206909,10000.0,35759.47870969772,37192.232377290726,35759.47870969772,1425.5483980178833,3.5949106216430664,0.0 -105900,4.186779,2.4103885,,,,,,,,,,,,,, -106000,4.2331166,2.5038173,,,,,,,,,,,,,, -106100,4.1950874,2.4209309,,,,,,,,,,,,,, -106200,3.908817,2.3704958,,,,,,,,,,,,,, -106300,3.7464025,2.404819,,,,,,,,,,,,,, -106400,4.2775097,2.4292939,,,,,,,,,,,,,, -106500,4.30567,2.4811473,,,,,,,,,,,,,, -106600,4.25162,2.4349537,,,,,,,,,,,,,, -106700,3.964702,2.4713466,,,,,,,,,,,,,, -106800,3.955267,2.468308,,,,,,,,,,,,,, -106900,4.195152,2.3921187,,,,,,,,,,,,,, -107000,3.8135087,2.3449152,,,,,,,,,,,,,, -107100,4.3240614,2.487377,,,,,,,,,,,,,, -107200,4.011002,2.4663086,,,,,,,,,,,,,, -107300,4.2232184,2.4869692,,,,,,,,,,,,,, -107400,4.4008436,2.466025,,,,,,,,,,,,,, -107412,,,0.7695910334587097,1.021224856376648,0.6710999608039856,1.4482934474945068,50000.0,0.5497000217437744,2.1009111404418945,10000.0,36269.60155582428,37719.8588643074,36269.60155582428,1442.9533438682556,3.6414597034454346,0.0 -107500,3.903289,2.4331615,,,,,,,,,,,,,, -107600,4.0766916,2.3822565,,,,,,,,,,,,,, -107700,4.0894823,2.508463,,,,,,,,,,,,,, -107800,4.899813,2.4680626,,,,,,,,,,,,,, -107900,4.344885,2.526462,,,,,,,,,,,,,, -108000,3.9194248,2.4842286,,,,,,,,,,,,,, -108100,4.2757654,2.4779632,,,,,,,,,,,,,, -108200,4.1201797,2.484004,,,,,,,,,,,,,, -108300,4.1369767,2.4811158,,,,,,,,,,,,,, -108400,4.355267,2.4505613,,,,,,,,,,,,,, -108500,4.2269278,2.3391414,,,,,,,,,,,,,, -108600,4.3770876,2.5520272,,,,,,,,,,,,,, -108700,4.161393,2.4534917,,,,,,,,,,,,,, -108800,4.2613463,2.4789684,,,,,,,,,,,,,, -108900,4.4248066,2.3998218,,,,,,,,,,,,,, -108927,,,0.7567362785339355,1.0428543090820312,0.6768999695777893,1.38744056224823,50000.0,0.5545000433921814,2.0375123023986816,10000.0,36779.734280347824,38247.51591229439,36779.734280347824,1460.3775732517242,3.689704656600952,0.0 -109000,4.3241706,2.4231057,,,,,,,,,,,,,, -109100,5.456649,2.4807906,,,,,,,,,,,,,, -109200,4.0771446,2.3499336,,,,,,,,,,,,,, -109300,4.115015,2.4736164,,,,,,,,,,,,,, -109400,4.071341,2.4208992,,,,,,,,,,,,,, -109500,5.185445,2.4783194,,,,,,,,,,,,,, -109600,4.5632715,2.4418375,,,,,,,,,,,,,, -109700,3.9657547,2.3913372,,,,,,,,,,,,,, -109800,4.566623,2.3954773,,,,,,,,,,,,,, -109900,4.9452076,2.433697,,,,,,,,,,,,,, -110000,4.1635413,2.3889978,,,,,,,,,,,,,, -110100,4.5854955,2.3918567,,,,,,,,,,,,,, -110200,4.1864347,2.4279976,,,,,,,,,,,,,, -110300,5.014107,2.3991213,,,,,,,,,,,,,, -110400,4.513448,2.5180147,,,,,,,,,,,,,, -110442,,,0.7382413744926453,1.1036720275878906,0.6723799705505371,1.4153038263320925,50000.0,0.5482000112533569,2.050555467605591,10000.0,37289.76195025444,38776.281270504,37289.76195025444,1479.0125722885132,3.740776300430298,0.0 -110500,4.0066056,2.4052017,,,,,,,,,,,,,, -110600,3.782304,2.3751996,,,,,,,,,,,,,, -110700,4.218128,2.5017858,,,,,,,,,,,,,, -110800,4.786861,2.4753184,,,,,,,,,,,,,, -110900,4.2456913,2.4550276,,,,,,,,,,,,,, -111000,4.5522256,2.4632635,,,,,,,,,,,,,, -111100,4.7653646,2.4577088,,,,,,,,,,,,,, -111200,4.2405643,2.3532565,,,,,,,,,,,,,, -111300,4.416696,2.3918135,,,,,,,,,,,,,, -111400,4.2902303,2.3762462,,,,,,,,,,,,,, -111500,4.648013,2.4494262,,,,,,,,,,,,,, -111600,4.613591,2.4019594,,,,,,,,,,,,,, -111700,4.536293,2.4653084,,,,,,,,,,,,,, -111800,4.5028834,2.4444373,,,,,,,,,,,,,, -111900,4.370769,2.3903077,,,,,,,,,,,,,, -111957,,,0.7502391338348389,1.063479781150818,0.6810599565505981,1.373931050300598,50000.0,0.5591000318527222,2.00235652923584,10000.0,37799.80076622963,39303.59298801422,37799.80076622963,1496.186547756195,3.787414312362671,0.0 -112000,4.460753,2.3547685,,,,,,,,,,,,,, -112100,4.9395757,2.4100711,,,,,,,,,,,,,, -112200,3.7810047,2.3948417,,,,,,,,,,,,,, -112300,4.3558483,2.4319324,,,,,,,,,,,,,, -112400,4.8265386,2.5155997,,,,,,,,,,,,,, -112500,4.379826,2.355503,,,,,,,,,,,,,, -112600,3.872249,2.4078789,,,,,,,,,,,,,, -112700,4.230675,2.3565943,,,,,,,,,,,,,, -112800,4.8954186,2.3812807,,,,,,,,,,,,,, -112900,4.5756865,2.4776874,,,,,,,,,,,,,, -113000,4.580173,2.4302983,,,,,,,,,,,,,, -113100,4.1048265,2.363779,,,,,,,,,,,,,, -113200,4.253018,2.4301736,,,,,,,,,,,,,, -113300,4.258305,2.3818264,,,,,,,,,,,,,, -113400,5.016219,2.3731432,,,,,,,,,,,,,, -113472,,,0.744559109210968,1.0819778442382812,0.6834200024604797,1.3549482822418213,50000.0,0.5565000176429749,2.0129952430725098,10000.0,38309.85378551483,39830.87760901451,38309.85378551483,1513.319462299347,3.83447790145874,0.0 -113500,4.4163437,2.4335353,,,,,,,,,,,,,, -113600,4.371992,2.38135,,,,,,,,,,,,,, -113700,4.134573,2.392328,,,,,,,,,,,,,, -113800,4.4033275,2.44933,,,,,,,,,,,,,, -113900,4.4280615,2.4409409,,,,,,,,,,,,,, -114000,4.428183,2.4126084,,,,,,,,,,,,,, -114100,4.597602,2.3636374,,,,,,,,,,,,,, -114200,4.798062,2.3524785,,,,,,,,,,,,,, -114300,4.3619466,2.3349164,,,,,,,,,,,,,, -114400,4.805878,2.3945615,,,,,,,,,,,,,, -114500,4.176891,2.3893187,,,,,,,,,,,,,, -114600,4.343708,2.419069,,,,,,,,,,,,,, -114700,4.2992697,2.35576,,,,,,,,,,,,,, -114800,4.807005,2.3766084,,,,,,,,,,,,,, -114900,4.0948777,2.4132211,,,,,,,,,,,,,, -114987,,,0.7444595098495483,1.1149373054504397,0.6838399767875671,1.3911999464035034,50000.0,0.5561000108718872,2.0544209480285645,10000.0,38819.99199438095,40358.582310676575,38819.99199438095,1530.7830486297607,3.8848836421966553,0.0 -115000,5.102284,2.3494682,,,,,,,,,,,,,, -115100,4.3683715,2.3876147,,,,,,,,,,,,,, -115200,4.2443247,2.3860786,,,,,,,,,,,,,, -115300,3.9891837,2.3740144,,,,,,,,,,,,,, -115400,4.095178,2.3747745,,,,,,,,,,,,,, -115500,4.873423,2.4441133,,,,,,,,,,,,,, -115600,4.867073,2.327329,,,,,,,,,,,,,, -115700,4.345029,2.4590402,,,,,,,,,,,,,, -115800,4.9163218,2.3370469,,,,,,,,,,,,,, -115900,4.2028413,2.2792647,,,,,,,,,,,,,, -116000,4.497385,2.3224773,,,,,,,,,,,,,, -116100,4.53915,2.3783092,,,,,,,,,,,,,, -116200,4.7263436,2.4399076,,,,,,,,,,,,,, -116300,4.680694,2.3939478,,,,,,,,,,,,,, -116400,4.631877,2.3422484,,,,,,,,,,,,,, -116500,4.763268,2.2772431,,,,,,,,,,,,,, -116501,,,0.7858139276504517,0.9308143854141236,0.6909199953079224,1.3370513916015625,50000.0,0.567300021648407,1.991621375083924,10000.0,39329.903483867645,40885.95380759239,39329.903483867645,1548.1430974006653,3.932539224624634,0.0 -116600,4.2921853,2.3391464,,,,,,,,,,,,,, -116700,4.393295,2.4135373,,,,,,,,,,,,,, -116800,4.474585,2.318664,,,,,,,,,,,,,, -116900,4.3635325,2.3360991,,,,,,,,,,,,,, -117000,4.529863,2.4042423,,,,,,,,,,,,,, -117100,3.9654942,2.3275135,,,,,,,,,,,,,, -117200,4.3707623,2.4152584,,,,,,,,,,,,,, -117300,4.9861636,2.3680763,,,,,,,,,,,,,, -117400,4.5457253,2.30969,,,,,,,,,,,,,, -117500,4.3026304,2.3107176,,,,,,,,,,,,,, -117600,4.720465,2.4566505,,,,,,,,,,,,,, -117700,4.845078,2.3716192,,,,,,,,,,,,,, -117800,4.493024,2.4035487,,,,,,,,,,,,,, -117900,4.8547525,2.320447,,,,,,,,,,,,,, -118000,4.738421,2.365036,,,,,,,,,,,,,, -118016,,,0.7719228267669678,0.9736966490745544,0.692799985408783,1.3256369829177856,50000.0,0.567300021648407,1.984636664390564,10000.0,39840.01126456261,41413.57640933991,39840.01126456261,1565.5552134513855,3.9830572605133057,0.0 -118100,4.516947,2.352794,,,,,,,,,,,,,, -118200,4.4106073,2.4212246,,,,,,,,,,,,,, -118300,4.5885606,2.408373,,,,,,,,,,,,,, -118400,4.8306675,2.2909868,,,,,,,,,,,,,, -118500,4.705962,2.3703477,,,,,,,,,,,,,, -118600,4.382288,2.306173,,,,,,,,,,,,,, -118700,4.5316477,2.4034762,,,,,,,,,,,,,, -118800,5.15468,2.3806553,,,,,,,,,,,,,, -118900,4.698777,2.3681214,,,,,,,,,,,,,, -119000,4.24371,2.3086412,,,,,,,,,,,,,, -119100,4.763708,2.3721428,,,,,,,,,,,,,, -119200,4.756604,2.4000454,,,,,,,,,,,,,, -119300,4.927645,2.301902,,,,,,,,,,,,,, -119400,4.823865,2.338906,,,,,,,,,,,,,, -119500,4.434376,2.3825798,,,,,,,,,,,,,, -119531,,,0.7692123651504517,0.9980911016464232,0.6963399648666382,1.3163288831710815,50000.0,0.5732000470161438,1.962451219558716,10000.0,40350.17130947113,41941.20932650566,40350.17130947113,1582.9293761253357,4.029150485992432,0.0 -119600,4.512509,2.488184,,,,,,,,,,,,,, -119700,5.047175,2.3539696,,,,,,,,,,,,,, -119800,4.8696036,2.3707945,,,,,,,,,,,,,, -119900,4.640444,2.3397145,,,,,,,,,,,,,, -120000,4.581476,2.370081,,,,,,,,,,,,,, -120100,4.374421,2.3055398,,,,,,,,,,,,,, -120200,4.5447345,2.4004974,,,,,,,,,,,,,, -120300,4.880274,2.3760972,,,,,,,,,,,,,, -120400,4.090468,2.3072238,,,,,,,,,,,,,, -120500,4.5716987,2.290718,,,,,,,,,,,,,, -120600,4.69355,2.4104948,,,,,,,,,,,,,, -120700,5.047427,2.4430223,,,,,,,,,,,,,, -120800,4.70536,2.3597622,,,,,,,,,,,,,, -120900,4.8159637,2.335711,,,,,,,,,,,,,, -121000,4.6890845,2.3951635,,,,,,,,,,,,,, -121046,,,0.7643494606018066,1.0025675296783447,0.6955400109291077,1.3097865581512451,50000.0,0.5732000470161438,1.9472129344940183,10000.0,40860.15645599365,42468.937943696976,40860.15645599365,1600.571870803833,4.077746868133545,0.0 -121100,5.451167,2.2397704,,,,,,,,,,,,,, -121200,4.8225055,2.3711984,,,,,,,,,,,,,, -121300,4.94079,2.242066,,,,,,,,,,,,,, -121400,4.901918,2.3305361,,,,,,,,,,,,,, -121500,4.665107,2.3532467,,,,,,,,,,,,,, -121600,4.2608333,2.3423753,,,,,,,,,,,,,, -121700,4.720089,2.3657622,,,,,,,,,,,,,, -121800,4.7901807,2.2818325,,,,,,,,,,,,,, -121900,4.677849,2.4328542,,,,,,,,,,,,,, -122000,4.3942547,2.2772903,,,,,,,,,,,,,, -122100,5.2926726,2.4057398,,,,,,,,,,,,,, -122200,5.086952,2.2628045,,,,,,,,,,,,,, -122300,4.8013263,2.3243554,,,,,,,,,,,,,, -122400,4.1908674,2.3081138,,,,,,,,,,,,,, -122500,4.87117,2.3450358,,,,,,,,,,,,,, -122561,,,0.7720224857330322,0.97927588224411,0.7010399699211121,1.279943585395813,50000.0,0.5808000564575195,1.9124568700790403,10000.0,41370.070234537125,42996.46684598923,41370.070234537125,1618.0866899490356,4.125354290008545,0.0 -122600,4.4892845,2.3620858,,,,,,,,,,,,,, -122700,5.854944,2.2599747,,,,,,,,,,,,,, -122800,4.4677343,2.3289423,,,,,,,,,,,,,, -122900,4.1781025,2.2245712,,,,,,,,,,,,,, -123000,4.634556,2.307963,,,,,,,,,,,,,, -123100,4.53369,2.2708616,,,,,,,,,,,,,, -123200,4.835756,2.3611188,,,,,,,,,,,,,, -123300,4.6426883,2.2706804,,,,,,,,,,,,,, -123400,4.733981,2.3572073,,,,,,,,,,,,,, -123500,5.1236258,2.37221,,,,,,,,,,,,,, -123600,4.755034,2.3212929,,,,,,,,,,,,,, -123700,4.5254936,2.253365,,,,,,,,,,,,,, -123800,5.2544136,2.3266206,,,,,,,,,,,,,, -123900,4.719846,2.2810113,,,,,,,,,,,,,, -124000,4.667584,2.2759182,,,,,,,,,,,,,, -124076,,,0.7723413705825806,0.971625566482544,0.7017599940299988,1.276233196258545,50000.0,0.5755000114440918,1.924700140953064,10000.0,41880.03378677368,43523.94673323631,41880.03378677368,1635.502154827118,4.173144578933716,0.0 -124100,4.398302,2.2606783,,,,,,,,,,,,,, -124200,4.74878,2.3533914,,,,,,,,,,,,,, -124300,4.8218374,2.3979409,,,,,,,,,,,,,, -124400,4.677384,2.2793322,,,,,,,,,,,,,, -124500,5.2964764,2.3431926,,,,,,,,,,,,,, -124600,4.511793,2.2624483,,,,,,,,,,,,,, -124700,4.656187,2.2871084,,,,,,,,,,,,,, -124800,5.1102014,2.3206313,,,,,,,,,,,,,, -124900,4.9725866,2.3090863,,,,,,,,,,,,,, -125000,4.523313,2.2485912,,,,,,,,,,,,,, -125100,5.028142,2.332537,,,,,,,,,,,,,, -125200,5.137025,2.3695762,,,,,,,,,,,,,, -125300,5.3847957,2.325267,,,,,,,,,,,,,, -125400,5.1587,2.344835,,,,,,,,,,,,,, -125500,5.417162,2.3881269,,,,,,,,,,,,,, -125591,,,0.7950015664100647,0.878669023513794,0.6998400092124939,1.300983428955078,50000.0,0.5813000202178955,1.917003273963928,10000.0,42390.04211616516,44052.20104932785,42390.04211616516,1653.6435549259186,4.225467681884766,0.0 -125600,5.267796,2.3402836,,,,,,,,,,,,,, -125700,4.541416,2.3043158,,,,,,,,,,,,,, -125800,5.138378,2.2248683,,,,,,,,,,,,,, -125900,4.701932,2.2852263,,,,,,,,,,,,,, -126000,5.9536705,2.4035616,,,,,,,,,,,,,, -126100,5.1655273,2.2494254,,,,,,,,,,,,,, -126200,4.562233,2.2860994,,,,,,,,,,,,,, -126300,4.591494,2.2393775,,,,,,,,,,,,,, -126400,4.8609633,2.2475276,,,,,,,,,,,,,, -126500,4.932956,2.2608857,,,,,,,,,,,,,, -126600,4.573465,2.1828496,,,,,,,,,,,,,, -126700,5.4121966,2.2539275,,,,,,,,,,,,,, -126800,4.552135,2.22679,,,,,,,,,,,,,, -126900,5.1428256,2.2571912,,,,,,,,,,,,,, -127000,4.960848,2.289866,,,,,,,,,,,,,, -127100,5.113382,2.2561553,,,,,,,,,,,,,, -127106,,,0.78714919090271,0.9083539247512816,0.7061799764633179,1.2675763368606567,50000.0,0.5827000141143799,1.9058300256729128,10000.0,42900.01652574539,44579.79776358605,42900.01652574539,1671.1585881710052,4.279773235321045,0.0 -127200,5.2852025,2.310096,,,,,,,,,,,,,, -127300,4.6902704,2.4040437,,,,,,,,,,,,,, -127400,5.3462963,2.3208787,,,,,,,,,,,,,, -127500,4.7786546,2.2855673,,,,,,,,,,,,,, -127600,5.112434,2.2622983,,,,,,,,,,,,,, -127700,5.3190613,2.2745411,,,,,,,,,,,,,, -127800,5.255514,2.3641782,,,,,,,,,,,,,, -127900,5.22065,2.347227,,,,,,,,,,,,,, -128000,5.1104126,2.3459606,,,,,,,,,,,,,, -128100,5.377942,2.3287983,,,,,,,,,,,,,, -128200,5.337921,2.320857,,,,,,,,,,,,,, -128300,4.372103,2.3549728,,,,,,,,,,,,,, -128400,5.1431737,2.2698278,,,,,,,,,,,,,, -128500,5.4331956,2.2847738,,,,,,,,,,,,,, -128600,4.809324,2.23097,,,,,,,,,,,,,, -128618,,,0.7895806431770325,0.9128103852272034,0.7077999711036682,1.2611401081085205,50000.0,0.5888000130653381,1.886717438697815,10000.0,43409.1985976696,45107.111483335495,43409.1985976696,1688.4513931274414,5.066087961196899,0.0 -128700,4.9643955,2.1598413,,,,,,,,,,,,,, -128800,5.3013444,2.2789392,,,,,,,,,,,,,, -128900,4.933651,2.2764406,,,,,,,,,,,,,, -129000,4.7533436,2.2205245,,,,,,,,,,,,,, -129100,5.148524,2.314249,,,,,,,,,,,,,, -129200,5.2160907,2.2910538,,,,,,,,,,,,,, -129300,5.259037,2.3278763,,,,,,,,,,,,,, -129400,4.7624555,2.1569989,,,,,,,,,,,,,, -129500,5.1323485,2.2721467,,,,,,,,,,,,,, -129600,5.698917,2.1852148,,,,,,,,,,,,,, -129700,5.478888,2.2404144,,,,,,,,,,,,,, -129800,4.7715907,2.2054985,,,,,,,,,,,,,, -129900,5.0531025,2.2891147,,,,,,,,,,,,,, -130000,5.091548,2.3494804,,,,,,,,,,,,,, -130100,5.868107,2.1874468,,,,,,,,,,,,,, -130133,,,0.7902981638908386,0.9130910634994508,0.711899995803833,1.2529252767562866,50000.0,0.5927000045776367,1.862455129623413,10000.0,43919.28074002266,45634.93291687965,43919.28074002266,1706.087923288345,5.117692470550537,0.0 -130200,5.6403904,2.2106128,,,,,,,,,,,,,, -130300,5.529542,2.1866326,,,,,,,,,,,,,, -130400,4.8139796,2.1954303,,,,,,,,,,,,,, -130500,4.657773,2.1598628,,,,,,,,,,,,,, -130600,5.550997,2.2045572,,,,,,,,,,,,,, -130700,4.8356676,2.2581315,,,,,,,,,,,,,, -130800,5.1453867,2.2088304,,,,,,,,,,,,,, -130900,5.439466,2.2070491,,,,,,,,,,,,,, -131000,4.96434,2.2296526,,,,,,,,,,,,,, -131100,5.1913905,2.2034001,,,,,,,,,,,,,, -131200,4.908767,2.2278476,,,,,,,,,,,,,, -131300,5.416502,2.2211766,,,,,,,,,,,,,, -131400,4.515653,2.251562,,,,,,,,,,,,,, -131500,5.2912807,2.3283713,,,,,,,,,,,,,, -131600,6.100405,2.2427602,,,,,,,,,,,,,, -131648,,,0.7849569320678711,0.9189361333847046,0.7110599875450134,1.247543454170227,50000.0,0.5861000418663025,1.8845510482788088,10000.0,44429.30247211456,46162.18374609947,44429.30247211456,1723.184979915619,5.197498321533203,0.0 -131700,5.7928524,2.242198,,,,,,,,,,,,,, -131800,5.4762,2.2452981,,,,,,,,,,,,,, -131900,5.274941,2.2388074,,,,,,,,,,,,,, -132000,4.770878,2.1840873,,,,,,,,,,,,,, -132100,5.120418,2.2437277,,,,,,,,,,,,,, -132200,5.1663694,2.2467396,,,,,,,,,,,,,, -132300,5.188446,2.2268617,,,,,,,,,,,,,, -132400,5.0950484,2.2287865,,,,,,,,,,,,,, -132500,5.201697,2.1660469,,,,,,,,,,,,,, -132600,5.3339972,2.2599342,,,,,,,,,,,,,, -132700,5.1044865,2.2481132,,,,,,,,,,,,,, -132800,5.189179,2.1631842,,,,,,,,,,,,,, -132900,5.558574,2.2980535,,,,,,,,,,,,,, -133000,5.3606133,2.3319926,,,,,,,,,,,,,, -133100,5.8168826,2.154521,,,,,,,,,,,,,, -133163,,,0.7988081574440002,0.8635379076004028,0.7162399888038635,1.2158832550048828,50000.0,0.5948000550270081,1.8518083095550537,10000.0,44939.4055583477,46689.57337188721,44939.4055583477,1740.367802619934,5.248680830001831,0.0 -133200,5.8513265,2.3118489,,,,,,,,,,,,,, -133300,5.159814,2.2134006,,,,,,,,,,,,,, -133400,5.6380754,2.3054771,,,,,,,,,,,,,, -133500,5.4559717,2.2140305,,,,,,,,,,,,,, -133600,5.731827,2.1837883,,,,,,,,,,,,,, -133700,5.340155,2.2233145,,,,,,,,,,,,,, -133800,6.135326,2.323216,,,,,,,,,,,,,, -133900,5.2386565,2.1811376,,,,,,,,,,,,,, -134000,5.626095,2.1956224,,,,,,,,,,,,,, -134100,5.241239,2.1606417,,,,,,,,,,,,,, -134200,4.9656153,2.1786242,,,,,,,,,,,,,, -134300,5.3104634,2.240808,,,,,,,,,,,,,, -134400,5.3516665,2.1535003,,,,,,,,,,,,,, -134500,5.499333,2.289157,,,,,,,,,,,,,, -134600,5.0875506,2.1497087,,,,,,,,,,,,,, -134678,,,0.8146922588348389,0.8167023658752441,0.7167999744415283,1.2303428649902344,50000.0,0.5956000089645386,1.858915328979492,10000.0,45449.53480505943,47217.12211894989,45449.53480505943,1757.6859464645386,5.297849655151367,0.0 -134700,5.300812,2.2722466,,,,,,,,,,,,,, -134800,4.94292,2.1887488,,,,,,,,,,,,,, -134900,5.272847,2.121555,,,,,,,,,,,,,, -135000,5.679523,2.156466,,,,,,,,,,,,,, -135100,5.556411,2.1672382,,,,,,,,,,,,,, -135200,5.749128,2.2645319,,,,,,,,,,,,,, -135300,5.6250954,2.2360692,,,,,,,,,,,,,, -135400,5.9192815,2.3071904,,,,,,,,,,,,,, -135500,5.0165324,2.278934,,,,,,,,,,,,,, -135600,5.390376,2.194304,,,,,,,,,,,,,, -135700,5.148678,2.146851,,,,,,,,,,,,,, -135800,5.5150127,2.1562874,,,,,,,,,,,,,, -135900,5.505015,2.1385992,,,,,,,,,,,,,, -136000,5.089251,2.211027,,,,,,,,,,,,,, -136100,5.2244396,2.2063675,,,,,,,,,,,,,, -136193,,,0.812519907951355,0.822140634059906,0.7186399698257446,1.2230992317199707,50000.0,0.597000002861023,1.836598873138428,10000.0,45959.67552042008,47744.87985253334,45959.67552042008,1775.1927635669708,5.355523347854614,0.0 -136200,5.548082,2.2270231,,,,,,,,,,,,,, -136300,5.6429687,2.1675525,,,,,,,,,,,,,, -136400,5.737675,2.1515594,,,,,,,,,,,,,, -136500,5.372763,2.1734498,,,,,,,,,,,,,, -136600,6.0920258,2.1805077,,,,,,,,,,,,,, -136700,5.1203146,2.1911032,,,,,,,,,,,,,, -136800,5.3961143,2.1920815,,,,,,,,,,,,,, -136900,5.5363564,2.1567729,,,,,,,,,,,,,, -137000,5.1517043,2.1424227,,,,,,,,,,,,,, -137100,5.6414804,2.2211952,,,,,,,,,,,,,, -137200,5.391644,2.1661086,,,,,,,,,,,,,, -137300,5.1556153,2.1836336,,,,,,,,,,,,,, -137400,5.5762568,2.1835558,,,,,,,,,,,,,, -137500,5.2971616,2.2332532,,,,,,,,,,,,,, -137600,4.96328,2.1267066,,,,,,,,,,,,,, -137700,5.866155,2.2618606,,,,,,,,,,,,,, -137707,,,0.8054049611091614,0.8238815069198608,0.7189599871635437,1.202906847000122,50000.0,0.5964000225067139,1.854988932609558,10000.0,46469.635543346405,48272.24763154984,46469.635543346405,1792.49906373024,5.405432224273682,0.0 -137800,4.9080486,2.1443334,,,,,,,,,,,,,, -137900,6.92353,2.19837,,,,,,,,,,,,,, -138000,6.3156347,2.1996202,,,,,,,,,,,,,, -138100,5.851632,2.1589155,,,,,,,,,,,,,, -138200,5.235259,2.1997483,,,,,,,,,,,,,, -138300,6.3766975,2.1972852,,,,,,,,,,,,,, -138400,6.069034,2.2399192,,,,,,,,,,,,,, -138500,5.2332397,2.1384454,,,,,,,,,,,,,, -138600,5.4472594,2.213129,,,,,,,,,,,,,, -138700,6.751887,2.2297182,,,,,,,,,,,,,, -138800,5.853838,2.1278613,,,,,,,,,,,,,, -138900,4.988494,2.1761546,,,,,,,,,,,,,, -139000,5.9460955,2.193904,,,,,,,,,,,,,, -139100,5.7336884,2.1439097,,,,,,,,,,,,,, -139200,5.6525164,2.1608279,,,,,,,,,,,,,, -139222,,,0.8106265664100647,0.8035197854042053,0.7279599905014038,1.174459457397461,50000.0,0.6018000245094299,1.8125733137130733,10000.0,46979.75103497505,48799.85860180855,46979.75103497505,1809.889460325241,5.458134174346924,0.0 -139300,5.9697156,2.2096834,,,,,,,,,,,,,, -139400,5.435877,2.091141,,,,,,,,,,,,,, -139500,5.9473615,2.2265837,,,,,,,,,,,,,, -139600,5.688157,2.1423666,,,,,,,,,,,,,, -139700,5.153337,2.0728953,,,,,,,,,,,,,, -139800,6.1260557,2.1526134,,,,,,,,,,,,,, -139900,5.5701714,2.1279142,,,,,,,,,,,,,, -140000,6.0741754,2.1358082,,,,,,,,,,,,,, -140100,5.4361186,2.2334044,,,,,,,,,,,,,, -140200,5.5433426,2.1095016,,,,,,,,,,,,,, -140300,5.7654448,2.1377044,,,,,,,,,,,,,, -140400,5.7325845,2.1165617,,,,,,,,,,,,,, -140500,6.4635653,2.1891084,,,,,,,,,,,,,, -140600,5.84864,2.0864315,,,,,,,,,,,,,, -140700,5.557838,2.1636724,,,,,,,,,,,,,, -140737,,,0.8134167790412903,0.8105883598327637,0.7305999994277954,1.175118327140808,50000.0,0.5982000231742859,1.803998947143555,10000.0,47489.82584095001,49327.57805490494,47489.82584095001,1827.4179458618164,5.522055149078369,0.0 -140800,5.506823,2.0973434,,,,,,,,,,,,,, -140900,5.7490363,2.1594052,,,,,,,,,,,,,, -141000,5.742148,2.1771321,,,,,,,,,,,,,, -141100,5.7627807,2.106043,,,,,,,,,,,,,, -141200,5.365157,2.1178865,,,,,,,,,,,,,, -141300,5.848535,2.1795762,,,,,,,,,,,,,, -141400,6.02589,2.107991,,,,,,,,,,,,,, -141500,5.4084444,2.1768966,,,,,,,,,,,,,, -141600,6.444319,2.1985295,,,,,,,,,,,,,, -141700,5.663201,2.1598992,,,,,,,,,,,,,, -141800,6.169014,2.1482525,,,,,,,,,,,,,, -141900,6.09303,2.081729,,,,,,,,,,,,,, -142000,5.6007175,2.1450243,,,,,,,,,,,,,, -142100,5.594043,2.1741807,,,,,,,,,,,,,, -142200,5.760767,2.1086624,,,,,,,,,,,,,, -142252,,,0.8249959945678711,0.7559878826141357,0.726419985294342,1.179167866706848,50000.0,0.605400025844574,1.813591241836548,10000.0,48000.03847670555,49854.95561385155,48000.03847670555,1844.4799404144287,5.573156356811523,0.0 -142300,5.773535,2.133787,,,,,,,,,,,,,, -142400,6.2032723,2.1816766,,,,,,,,,,,,,, -142500,5.788555,2.1819077,,,,,,,,,,,,,, -142600,6.264254,2.1165257,,,,,,,,,,,,,, -142700,5.119334,2.0798523,,,,,,,,,,,,,, -142800,5.9914875,2.1867049,,,,,,,,,,,,,, -142900,6.1387157,2.1503491,,,,,,,,,,,,,, -143000,5.8104587,2.1654313,,,,,,,,,,,,,, -143100,6.9594417,2.1682816,,,,,,,,,,,,,, -143200,5.952533,2.090704,,,,,,,,,,,,,, -143300,6.2234063,2.2149396,,,,,,,,,,,,,, -143400,5.8938136,2.1634343,,,,,,,,,,,,,, -143500,5.9106207,2.0773177,,,,,,,,,,,,,, -143600,6.139142,2.1099904,,,,,,,,,,,,,, -143700,5.9699073,2.0422714,,,,,,,,,,,,,, -143766,,,0.8384287357330322,0.7113813757896423,0.7309799790382385,1.161076545715332,50000.0,0.6038000583648682,1.7907795906066897,10000.0,48509.954323768616,50382.42102813721,48509.954323768616,1861.9220707416528,5.628809690475464,0.0 -143800,6.7402973,2.212944,,,,,,,,,,,,,, -143900,5.882301,2.104151,,,,,,,,,,,,,, -144000,5.726966,2.1376872,,,,,,,,,,,,,, -144100,5.732106,2.1323354,,,,,,,,,,,,,, -144200,5.7980933,2.1069386,,,,,,,,,,,,,, -144300,6.229484,2.0710943,,,,,,,,,,,,,, -144400,6.024167,2.1221828,,,,,,,,,,,,,, -144500,5.622561,2.0429723,,,,,,,,,,,,,, -144600,6.444573,2.0992439,,,,,,,,,,,,,, -144700,5.7133603,2.0947475,,,,,,,,,,,,,, -144800,6.1132565,2.0576217,,,,,,,,,,,,,, -144900,5.8761964,2.0241547,,,,,,,,,,,,,, -145000,6.1790543,2.101276,,,,,,,,,,,,,, -145100,5.87889,2.078546,,,,,,,,,,,,,, -145200,5.801024,2.156309,,,,,,,,,,,,,, -145281,,,0.8381098508834839,0.7158500552177429,0.7337799668312073,1.148126482963562,50000.0,0.6077000498771667,1.781200885772705,10000.0,49020.05911588669,50909.77686786652,49020.05911588669,1879.068108797073,5.681424856185913,0.0 -145300,6.3364706,2.1230588,,,,,,,,,,,,,, -145400,5.853158,2.028635,,,,,,,,,,,,,, -145500,5.7525363,2.1156492,,,,,,,,,,,,,, -145600,6.094925,2.1116316,,,,,,,,,,,,,, -145700,5.9491863,2.1238043,,,,,,,,,,,,,, -145800,6.0937686,2.0735867,,,,,,,,,,,,,, -145900,5.7175856,2.0808656,,,,,,,,,,,,,, -146000,5.865076,2.1828291,,,,,,,,,,,,,, -146100,6.2516885,2.1501646,,,,,,,,,,,,,, -146200,6.285933,2.149107,,,,,,,,,,,,,, -146300,5.8933444,2.0970612,,,,,,,,,,,,,, -146400,5.756747,2.0580444,,,,,,,,,,,,,, -146500,6.354463,2.1516373,,,,,,,,,,,,,, -146600,6.1622334,2.0309603,,,,,,,,,,,,,, -146700,6.263842,2.0809085,,,,,,,,,,,,,, -146796,,,0.8376116156578064,0.7040087580680847,0.7360000014305115,1.129098415374756,50000.0,0.6104000210762024,1.7563949823379517,10000.0,49530.228246212006,51437.18770766258,49530.228246212006,1896.1900537014008,5.749187469482422,0.0 -146800,6.1404223,2.1577802,,,,,,,,,,,,,, -146900,6.865083,2.163256,,,,,,,,,,,,,, -147000,5.826188,1.9791641,,,,,,,,,,,,,, -147100,6.191385,2.0801945,,,,,,,,,,,,,, -147200,6.470068,2.1379158,,,,,,,,,,,,,, -147300,5.947062,2.0693722,,,,,,,,,,,,,, -147400,6.7895427,2.1099298,,,,,,,,,,,,,, -147500,6.272813,2.1530924,,,,,,,,,,,,,, -147600,6.747607,2.1189232,,,,,,,,,,,,,, -147700,6.373327,2.086181,,,,,,,,,,,,,, -147800,6.3134356,2.0470965,,,,,,,,,,,,,, -147900,6.846935,2.0808775,,,,,,,,,,,,,, -148000,7.389175,2.1806934,,,,,,,,,,,,,, -148100,6.296167,2.095936,,,,,,,,,,,,,, -148200,6.728619,2.055448,,,,,,,,,,,,,, -148300,6.0954995,2.070157,,,,,,,,,,,,,, -148311,,,0.8396643400192261,0.7063088417053223,0.738319993019104,1.1283575296401978,50000.0,0.6163000464439392,1.7626051902770996,10000.0,50040.213027477264,51964.37534117699,50040.213027477264,1913.290997505188,5.799168109893799,0.0 -148400,6.6508,2.126136,,,,,,,,,,,,,, -148500,6.494281,2.0441427,,,,,,,,,,,,,, -148600,6.0016274,2.0778837,,,,,,,,,,,,,, -148700,6.2108393,2.01958,,,,,,,,,,,,,, -148800,6.5016193,2.0780241,,,,,,,,,,,,,, -148900,5.909759,2.067573,,,,,,,,,,,,,, -149000,6.606558,2.0631976,,,,,,,,,,,,,, -149100,5.8358984,2.07648,,,,,,,,,,,,,, -149200,6.1555657,2.0922744,,,,,,,,,,,,,, -149300,6.6572604,2.093165,,,,,,,,,,,,,, -149400,6.425255,2.0474086,,,,,,,,,,,,,, -149500,6.6759934,2.0529096,,,,,,,,,,,,,, -149600,5.9482036,2.0024192,,,,,,,,,,,,,, -149700,5.8824205,2.0291545,,,,,,,,,,,,,, -149800,6.644055,2.110536,,,,,,,,,,,,,, -149825,,,0.8420161008834839,0.6852911710739136,0.7417799830436707,1.1163774728775024,50000.0,0.6243000030517578,1.719955325126648,10000.0,50550.241391181946,52491.98043131828,50550.241391181946,1930.7620613574984,5.853203773498535,0.0 -149900,6.1613846,2.0654514,,,,,,,,,,,,,, -150000,6.1717386,2.024147,,,,,,,,,,,,,, -150100,6.377383,1.9753234,,,,,,,,,,,,,, -150200,6.438115,2.1202326,,,,,,,,,,,,,, -150300,7.3342047,2.0761225,,,,,,,,,,,,,, -150400,6.579112,2.022375,,,,,,,,,,,,,, -150500,6.1162047,2.0954883,,,,,,,,,,,,,, -150600,6.468466,2.0300846,,,,,,,,,,,,,, -150700,6.460127,2.1384175,,,,,,,,,,,,,, -150800,6.841117,2.041598,,,,,,,,,,,,,, -150900,5.938324,2.04921,,,,,,,,,,,,,, -151000,6.4721236,2.0083747,,,,,,,,,,,,,, -151100,6.0017905,2.029074,,,,,,,,,,,,,, -151200,6.2894955,2.0732236,,,,,,,,,,,,,, -151300,6.240554,2.0952291,,,,,,,,,,,,,, -151340,,,0.86820387840271,0.5952945947647095,0.7403199672698975,1.1137624979019165,50000.0,0.6186000108718872,1.740875482559204,10000.0,51060.40700316429,53019.79779362679,51060.40700316429,1948.308458328247,5.906131267547607,0.0 -151400,7.159317,1.9983486,,,,,,,,,,,,,, -151500,5.931353,2.075829,,,,,,,,,,,,,, -151600,6.688139,2.0983162,,,,,,,,,,,,,, -151700,6.4899993,2.0013812,,,,,,,,,,,,,, -151800,7.0378523,2.0794683,,,,,,,,,,,,,, -151900,5.8230505,1.9868493,,,,,,,,,,,,,, -152000,6.276391,2.0513813,,,,,,,,,,,,,, -152100,7.2538905,1.9912462,,,,,,,,,,,,,, -152200,6.0189323,2.0113583,,,,,,,,,,,,,, -152300,6.573118,1.960193,,,,,,,,,,,,,, -152400,6.308223,2.106822,,,,,,,,,,,,,, -152500,6.3497353,1.965127,,,,,,,,,,,,,, -152600,6.935632,2.0369616,,,,,,,,,,,,,, -152700,6.4133267,2.0037484,,,,,,,,,,,,,, -152800,6.6960645,2.0562208,,,,,,,,,,,,,, -152855,,,0.8589365482330322,0.613325297832489,0.7432799935340881,1.1003096103668213,50000.0,0.62090003490448,1.735670804977417,10000.0,51570.638830661774,53548.062861442566,51570.638830661774,1966.237622976303,5.958402872085571,0.0 -152900,6.2399707,2.0741131,,,,,,,,,,,,,, -153000,7.9306164,1.988547,,,,,,,,,,,,,, -153100,6.4629283,1.9801909,,,,,,,,,,,,,, -153200,6.379546,2.0664876,,,,,,,,,,,,,, -153300,7.233451,2.0000868,,,,,,,,,,,,,, -153400,6.6567035,2.051194,,,,,,,,,,,,,, -153500,6.6178374,2.015758,,,,,,,,,,,,,, -153600,6.8520656,2.0823312,,,,,,,,,,,,,, -153700,6.544041,2.0398033,,,,,,,,,,,,,, -153800,6.4948792,2.000503,,,,,,,,,,,,,, -153900,6.519161,2.006114,,,,,,,,,,,,,, -154000,6.9371724,2.0809493,,,,,,,,,,,,,, -154100,6.7668777,2.0452743,,,,,,,,,,,,,, -154200,7.095303,2.069933,,,,,,,,,,,,,, -154300,6.70021,1.9842389,,,,,,,,,,,,,, -154370,,,0.8587571382522583,0.6176808476448059,0.7478799819946289,1.0903000831604004,50000.0,0.6241000294685364,1.7056432962417605,10000.0,52080.63415670395,54075.474705934525,52080.63415670395,1983.548333644867,6.011291027069092,0.0 -154400,6.640777,2.01183,,,,,,,,,,,,,, -154500,7.1432824,1.9842802,,,,,,,,,,,,,, -154600,6.635598,1.9461482,,,,,,,,,,,,,, -154700,6.858837,2.011133,,,,,,,,,,,,,, -154800,6.688616,1.9831523,,,,,,,,,,,,,, -154900,6.210375,1.9149472,,,,,,,,,,,,,, -155000,6.9786115,1.9920497,,,,,,,,,,,,,, -155100,6.8588295,2.047783,,,,,,,,,,,,,, -155200,6.317278,1.8917649,,,,,,,,,,,,,, -155300,6.8504734,2.01688,,,,,,,,,,,,,, -155400,6.765533,1.9797091,,,,,,,,,,,,,, -155500,6.762876,2.0183117,,,,,,,,,,,,,, -155600,7.157049,2.039011,,,,,,,,,,,,,, -155700,6.482505,2.0019894,,,,,,,,,,,,,, -155800,6.532859,1.9775014,,,,,,,,,,,,,, -155885,,,0.8608896732330322,0.6201562881469727,0.75,1.0886688232421875,50000.0,0.6253000497817993,1.714283466339111,10000.0,52590.81240940094,54603.15359258652,52590.81240940094,2000.9424073696136,6.064687252044678,0.0 -155900,6.365147,1.9561615,,,,,,,,,,,,,, -156000,7.532432,2.0412483,,,,,,,,,,,,,, -156100,6.7046556,1.9677345,,,,,,,,,,,,,, -156200,6.4090123,1.9218422,,,,,,,,,,,,,, -156300,7.1459637,1.9686681,,,,,,,,,,,,,, -156400,6.1762447,1.9857062,,,,,,,,,,,,,, -156500,6.7763405,1.9113925,,,,,,,,,,,,,, -156600,7.8294005,2.1007233,,,,,,,,,,,,,, -156700,6.7997093,1.9863403,,,,,,,,,,,,,, -156800,6.3691382,1.9973335,,,,,,,,,,,,,, -156900,6.5977,2.025089,,,,,,,,,,,,,, -157000,6.6648445,2.038182,,,,,,,,,,,,,, -157100,7.085446,2.0062788,,,,,,,,,,,,,, -157200,6.43353,1.956006,,,,,,,,,,,,,, -157300,6.5563736,1.9540838,,,,,,,,,,,,,, -157400,,,0.8644570708274841,0.6090182065963745,0.7490599751472473,1.082980036735535,50000.0,0.6284000277519226,1.701801300048828,10000.0,53100.97397065163,55130.75111031532,53100.97397065163,2018.272669315338,6.117977380752564,0.0 -157400,6.442891,1.9928197,,,,,,,,,,,,,, -157500,6.822191,2.0019498,,,,,,,,,,,,,, -157600,6.3504906,1.9262562,,,,,,,,,,,,,, -157700,7.457332,1.9237429,,,,,,,,,,,,,, -157800,6.8409247,1.9988787,,,,,,,,,,,,,, -157900,6.6811585,2.0077815,,,,,,,,,,,,,, -158000,7.442835,2.0039682,,,,,,,,,,,,,, -158100,7.067614,2.067435,,,,,,,,,,,,,, -158200,6.5204616,1.9902679,,,,,,,,,,,,,, -158300,6.939004,1.9149662,,,,,,,,,,,,,, -158400,6.84378,1.9569534,,,,,,,,,,,,,, -158500,7.2408104,2.0789998,,,,,,,,,,,,,, -158600,7.1611533,1.9505284,,,,,,,,,,,,,, -158700,6.814787,1.9699078,,,,,,,,,,,,,, -158800,7.2929525,1.9890356,,,,,,,,,,,,,, -158900,7.199452,1.9571986,,,,,,,,,,,,,, -158915,,,0.871113657951355,0.5796046853065491,0.7538399696350098,1.0558656454086304,50000.0,0.6343000531196594,1.67049241065979,10000.0,53611.190331459045,55658.49032831192,53611.190331459045,2035.686059474945,6.175013780593872,0.0 -159000,7.5931907,1.9731863,,,,,,,,,,,,,, -159100,6.3982453,1.9346292,,,,,,,,,,,,,, -159200,7.000481,1.9864341,,,,,,,,,,,,,, -159300,7.4281836,1.8707088,,,,,,,,,,,,,, -159400,6.381107,1.9154122,,,,,,,,,,,,,, -159500,7.618313,2.0200772,,,,,,,,,,,,,, -159600,6.692199,1.9655966,,,,,,,,,,,,,, -159700,7.1116314,1.8994882,,,,,,,,,,,,,, -159800,6.9060426,1.8796549,,,,,,,,,,,,,, -159900,6.586864,1.923101,,,,,,,,,,,,,, -160000,7.154613,1.9374533,,,,,,,,,,,,,, -160100,7.178131,1.9574953,,,,,,,,,,,,,, -160200,7.2596874,1.965025,,,,,,,,,,,,,, -160300,6.6605167,1.9253418,,,,,,,,,,,,,, -160400,7.1920485,1.9895679,,,,,,,,,,,,,, -160430,,,0.8908242583274841,0.5088411569595337,0.7539199590682983,1.0591936111450195,50000.0,0.6325000524520874,1.6727036237716677,10000.0,54121.38585519791,56186.07378411293,54121.38585519791,2052.965085029602,6.231770277023315,0.0 -160500,7.549784,1.9440992,,,,,,,,,,,,,, -160600,7.1809597,1.9180323,,,,,,,,,,,,,, -160700,6.658686,1.9777067,,,,,,,,,,,,,, -160800,7.315005,1.998087,,,,,,,,,,,,,, -160900,6.9419646,1.922035,,,,,,,,,,,,,, -161000,6.5816026,1.9566467,,,,,,,,,,,,,, -161100,6.8444495,1.9955451,,,,,,,,,,,,,, -161200,7.2202697,1.8487811,,,,,,,,,,,,,, -161300,6.3964825,1.8402859,,,,,,,,,,,,,, -161400,6.9202113,1.8947139,,,,,,,,,,,,,, -161500,7.4274855,1.9325141,,,,,,,,,,,,,, -161600,6.8739376,1.9887166,,,,,,,,,,,,,, -161700,7.6688247,1.8980227,,,,,,,,,,,,,, -161800,7.2414813,1.8504508,,,,,,,,,,,,,, -161900,7.191819,1.9155881,,,,,,,,,,,,,, -161945,,,0.8883330225944519,0.507793128490448,0.7577599883079529,1.0418267250061035,50000.0,0.6374000310897827,1.6483393907546997,10000.0,54631.54222178459,56713.73193693161,54631.54222178459,2070.361840724945,6.284480094909668,0.0 -162000,6.1171923,1.8533723,,,,,,,,,,,,,, -162100,7.2534614,1.8958445,,,,,,,,,,,,,, -162200,7.1475425,1.9490682,,,,,,,,,,,,,, -162300,6.915839,1.9237306,,,,,,,,,,,,,, -162400,6.8593755,1.9102352,,,,,,,,,,,,,, -162500,7.3414416,1.932652,,,,,,,,,,,,,, -162600,6.9589014,1.8740573,,,,,,,,,,,,,, -162700,7.373176,1.9737127,,,,,,,,,,,,,, -162800,7.2728815,1.9692657,,,,,,,,,,,,,, -162900,8.195039,1.9735193,,,,,,,,,,,,,, -163000,7.75824,1.9224184,,,,,,,,,,,,,, -163100,6.983545,1.9561863,,,,,,,,,,,,,, -163200,7.055316,1.9396137,,,,,,,,,,,,,, -163300,6.881573,1.8614041,,,,,,,,,,,,,, -163400,7.074708,1.9221251,,,,,,,,,,,,,, -163459,,,0.8888113498687744,0.508962869644165,0.759880006313324,1.0377557277679443,50000.0,0.6409000158309937,1.6533217430114746,10000.0,55141.44549107552,57240.98913478851,55141.44549107552,2087.6100096702576,6.337510824203491,0.0 -163500,7.4534826,1.8452582,,,,,,,,,,,,,, -163600,7.679375,1.9346912,,,,,,,,,,,,,, -163700,7.430423,1.9259154,,,,,,,,,,,,,, -163800,7.045613,1.8772272,,,,,,,,,,,,,, -163900,7.163577,1.9555124,,,,,,,,,,,,,, -164000,7.374799,1.8639628,,,,,,,,,,,,,, -164100,7.394719,1.9357277,,,,,,,,,,,,,, -164200,7.3466115,1.9347961,,,,,,,,,,,,,, -164300,7.2882686,1.8680177,,,,,,,,,,,,,, -164400,7.3128033,1.9536227,,,,,,,,,,,,,, -164500,7.4492393,1.9564867,,,,,,,,,,,,,, -164600,7.7910347,1.8737347,,,,,,,,,,,,,, -164700,7.256158,1.9141359,,,,,,,,,,,,,, -164800,6.974534,1.7816961,,,,,,,,,,,,,, -164900,7.310505,1.8751165,,,,,,,,,,,,,, -164974,,,0.8914421200752258,0.5064797401428223,0.7610999941825867,1.0407962799072266,50000.0,0.6428000330924988,1.644087553024292,10000.0,55651.61699032784,57768.47083735466,55651.61699032784,2104.812753200531,6.391413450241089,0.0 -165000,7.1672273,1.8910512,,,,,,,,,,,,,, -165100,6.6107936,1.8375516,,,,,,,,,,,,,, -165200,7.385255,1.9601289,,,,,,,,,,,,,, -165300,7.441904,1.8804816,,,,,,,,,,,,,, -165400,6.7256474,1.8772509,,,,,,,,,,,,,, -165500,7.104865,1.8790817,,,,,,,,,,,,,, -165600,7.8933268,1.8195353,,,,,,,,,,,,,, -165700,7.875346,1.9090483,,,,,,,,,,,,,, -165800,7.274913,1.8994987,,,,,,,,,,,,,, -165900,6.8481445,1.884165,,,,,,,,,,,,,, -166000,7.6054454,1.9231799,,,,,,,,,,,,,, -166100,6.511314,1.7847927,,,,,,,,,,,,,, -166200,7.985102,1.9643513,,,,,,,,,,,,,, -166300,7.3839664,1.8951242,,,,,,,,,,,,,, -166400,7.450433,1.8848841,,,,,,,,,,,,,, -166488,,,0.8913623690605164,0.4938401579856872,0.761900007724762,1.031830906867981,50000.0,0.6368000507354736,1.654982328414917,10000.0,56161.74114251137,58296.5521376133,56161.74114251137,2122.656609773636,6.452677011489868,0.0 -166500,7.223378,1.8647387,,,,,,,,,,,,,, -166600,7.63588,1.913829,,,,,,,,,,,,,, -166700,7.939752,1.8148425,,,,,,,,,,,,,, -166800,7.510365,1.9027696,,,,,,,,,,,,,, -166900,7.413208,1.8395739,,,,,,,,,,,,,, -167000,7.626991,1.8713734,,,,,,,,,,,,,, -167100,8.061142,1.899983,,,,,,,,,,,,,, -167200,7.4141192,1.8567183,,,,,,,,,,,,,, -167300,7.267326,1.8738055,,,,,,,,,,,,,, -167400,7.1476517,1.8047267,,,,,,,,,,,,,, -167500,7.516079,1.8600993,,,,,,,,,,,,,, -167600,7.4274726,1.7941993,,,,,,,,,,,,,, -167700,7.4424915,1.8535347,,,,,,,,,,,,,, -167800,7.6489944,1.8588675,,,,,,,,,,,,,, -167900,7.7013507,1.9574978,,,,,,,,,,,,,, -168000,7.6242037,1.9012321,,,,,,,,,,,,,, -168003,,,0.8980189561843872,0.4819498062133789,0.7639999985694885,1.017918586730957,50000.0,0.6433000564575195,1.6341732740402222,10000.0,56671.93290805817,58824.12835144997,56671.93290805817,2139.942459821701,6.499396800994873,0.0 -168100,7.1921387,1.840549,,,,,,,,,,,,,, -168200,7.2692275,1.8247881,,,,,,,,,,,,,, -168300,7.635381,1.8792149,,,,,,,,,,,,,, -168400,7.194979,1.8285718,,,,,,,,,,,,,, -168500,7.2270384,1.8486357,,,,,,,,,,,,,, -168600,7.2298474,1.857231,,,,,,,,,,,,,, -168700,7.017291,1.8439795,,,,,,,,,,,,,, -168800,8.152631,1.9088308,,,,,,,,,,,,,, -168900,7.941756,1.8492427,,,,,,,,,,,,,, -169000,7.785146,1.8512156,,,,,,,,,,,,,, -169100,7.531374,1.8769505,,,,,,,,,,,,,, -169200,7.948015,1.8727669,,,,,,,,,,,,,, -169300,7.885187,1.7970923,,,,,,,,,,,,,, -169400,6.984024,1.7934064,,,,,,,,,,,,,, -169500,7.859589,1.8814754,,,,,,,,,,,,,, -169518,,,0.9111128449440002,0.4368035793304443,0.7679599523544312,1.0132651329040527,50000.0,0.6438000202178955,1.6271370649337769,10000.0,57182.0711376667,59352.14291119576,57182.0711376667,2157.704973936081,6.561173677444458,0.0 -169600,7.224221,1.8258004,,,,,,,,,,,,,, -169700,7.3620105,1.8419785,,,,,,,,,,,,,, -169800,7.2605042,1.7927092,,,,,,,,,,,,,, -169900,7.495914,1.813798,,,,,,,,,,,,,, -170000,7.719386,1.8507172,,,,,,,,,,,,,, -170100,8.813881,1.8332247,,,,,,,,,,,,,, -170200,7.439187,1.8776727,,,,,,,,,,,,,, -170300,7.622601,1.8150638,,,,,,,,,,,,,, -170400,7.048433,1.7336348,,,,,,,,,,,,,, -170500,7.6317844,1.8421745,,,,,,,,,,,,,, -170600,7.917677,1.8684256,,,,,,,,,,,,,, -170700,7.617991,1.8421001,,,,,,,,,,,,,, -170800,7.5149965,1.8524975,,,,,,,,,,,,,, -170900,7.4479775,1.7914084,,,,,,,,,,,,,, -171000,7.7815123,1.8158845,,,,,,,,,,,,,, -171033,,,0.9100565910339355,0.4437120258808136,0.768839955329895,1.0136228799819946,50000.0,0.6476000547409058,1.6199846267700195,10000.0,57692.27556872368,59879.77585315704,57692.27556872368,2175.0261178016663,6.616154432296753,0.0 -171100,7.4488463,1.8287764,,,,,,,,,,,,,, -171200,7.4683642,1.8021964,,,,,,,,,,,,,, -171300,7.293106,1.8170574,,,,,,,,,,,,,, -171400,7.3857193,1.8166308,,,,,,,,,,,,,, -171500,7.4593606,1.8741454,,,,,,,,,,,,,, -171600,7.926943,1.8004898,,,,,,,,,,,,,, -171700,7.679171,1.9111397,,,,,,,,,,,,,, -171800,7.7911506,1.8191273,,,,,,,,,,,,,, -171900,7.721437,1.8048167,,,,,,,,,,,,,, -172000,7.887253,1.8792622,,,,,,,,,,,,,, -172100,7.7209244,1.7999492,,,,,,,,,,,,,, -172200,7.5387783,1.8469667,,,,,,,,,,,,,, -172300,7.33617,1.8344135,,,,,,,,,,,,,, -172400,7.4372745,1.7047728,,,,,,,,,,,,,, -172500,7.4954495,1.793125,,,,,,,,,,,,,, -172548,,,0.907983899116516,0.44346883893013,0.7683799862861633,1.0122803449630735,50000.0,0.6457000374794006,1.6168798208236694,10000.0,58202.40605354309,60407.436655282974,58202.40605354309,2192.441407442093,6.679574251174927,0.0 -172600,8.259106,1.823172,,,,,,,,,,,,,, -172700,7.0271816,1.7936364,,,,,,,,,,,,,, -172800,7.5489545,1.788042,,,,,,,,,,,,,, -172900,7.227857,1.7753648,,,,,,,,,,,,,, -173000,8.710443,1.8042562,,,,,,,,,,,,,, -173100,8.142006,1.7988541,,,,,,,,,,,,,, -173200,8.320894,1.8675855,,,,,,,,,,,,,, -173300,7.4491262,1.8017349,,,,,,,,,,,,,, -173400,7.7948503,1.7843604,,,,,,,,,,,,,, -173500,8.088058,1.7628424,,,,,,,,,,,,,, -173600,7.319471,1.7304318,,,,,,,,,,,,,, -173700,8.370543,1.8236666,,,,,,,,,,,,,, -173800,8.156294,1.8112774,,,,,,,,,,,,,, -173900,7.93423,1.7883801,,,,,,,,,,,,,, -174000,8.149504,1.757461,,,,,,,,,,,,,, -174063,,,0.9148397445678712,0.4309082627296448,0.7701599597930908,1.0046439170837402,50000.0,0.6491000056266785,1.6106775999069214,10000.0,58712.38555598259,60935.12227869034,58712.38555598259,2210.0351996421814,6.738846778869629,0.0 -174100,7.6695366,1.8151027,,,,,,,,,,,,,, -174200,7.7571836,1.8114159,,,,,,,,,,,,,, -174300,7.903013,1.876574,,,,,,,,,,,,,, -174400,7.3359866,1.8135576,,,,,,,,,,,,,, -174500,7.6692495,1.7500405,,,,,,,,,,,,,, -174600,7.4725275,1.786437,,,,,,,,,,,,,, -174700,6.8945103,1.7592909,,,,,,,,,,,,,, -174800,7.883026,1.8075014,,,,,,,,,,,,,, -174900,7.686401,1.7777488,,,,,,,,,,,,,, -175000,8.774071,1.8171996,,,,,,,,,,,,,, -175100,7.4429717,1.7752675,,,,,,,,,,,,,, -175200,8.13203,1.8030316,,,,,,,,,,,,,, -175300,8.049388,1.8396788,,,,,,,,,,,,,, -175400,7.654931,1.8096331,,,,,,,,,,,,,, -175500,7.284858,1.8071795,,,,,,,,,,,,,, -175578,,,0.913305163383484,0.4299351274967193,0.77183997631073,1.0019093751907349,50000.0,0.650700032711029,1.6060999631881714,10000.0,59222.59151554108,61462.77833938599,59222.59151554108,2227.369342327118,6.801840305328369,0.0 -175600,6.9896765,1.7237198,,,,,,,,,,,,,, -175700,7.8032856,1.7455775,,,,,,,,,,,,,, -175800,7.956418,1.8361797,,,,,,,,,,,,,, -175900,8.33704,1.778852,,,,,,,,,,,,,, -176000,7.286194,1.7690673,,,,,,,,,,,,,, -176100,7.6314363,1.8313823,,,,,,,,,,,,,, -176200,8.174586,1.7711741,,,,,,,,,,,,,, -176300,7.4904933,1.8212204,,,,,,,,,,,,,, -176400,7.4599524,1.7719018,,,,,,,,,,,,,, -176500,9.0835,1.7947906,,,,,,,,,,,,,, -176600,7.1888704,1.7324759,,,,,,,,,,,,,, -176700,7.64184,1.779323,,,,,,,,,,,,,, -176800,7.1706347,1.7040856,,,,,,,,,,,,,, -176900,7.7135587,1.8162342,,,,,,,,,,,,,, -177000,8.005954,1.8059206,,,,,,,,,,,,,, -177092,,,0.914441168308258,0.4245000779628753,0.7725799679756165,0.9952738881111144,50000.0,0.65010005235672,1.6027166843414309,10000.0,59732.50027251244,61990.1252887249,59732.50027251244,2244.6973419189453,6.859493017196655,0.0 -177100,7.3471966,1.7771213,,,,,,,,,,,,,, -177200,8.283157,1.8555312,,,,,,,,,,,,,, -177300,7.800009,1.7876605,,,,,,,,,,,,,, -177400,8.220334,1.8119862,,,,,,,,,,,,,, -177500,7.3667703,1.6912441,,,,,,,,,,,,,, -177600,7.200552,1.7970284,,,,,,,,,,,,,, -177700,8.374347,1.8383595,,,,,,,,,,,,,, -177800,7.114279,1.7396386,,,,,,,,,,,,,, -177900,7.9470057,1.7620192,,,,,,,,,,,,,, -178000,8.052091,1.7652347,,,,,,,,,,,,,, -178100,7.1477523,1.71723,,,,,,,,,,,,,, -178200,8.145398,1.7507836,,,,,,,,,,,,,, -178300,8.165048,1.802913,,,,,,,,,,,,,, -178400,8.243778,1.812302,,,,,,,,,,,,,, -178500,7.7125916,1.803941,,,,,,,,,,,,,, -178600,8.051802,1.7652657,,,,,,,,,,,,,, -178606,,,0.922512710094452,0.3977555632591247,0.7713800072669983,0.9977371096611024,50000.0,0.6538000106811523,1.6024290323257446,10000.0,60242.55020594597,62517.70723223686,60242.55020594597,2262.119296312332,6.916916847229004,0.0 -178700,7.7162914,1.7653147,,,,,,,,,,,,,, -178800,8.431775,1.827358,,,,,,,,,,,,,, -178900,7.3943787,1.827354,,,,,,,,,,,,,, -179000,8.094884,1.7929313,,,,,,,,,,,,,, -179100,7.33561,1.6812204,,,,,,,,,,,,,, -179200,8.406057,1.7791708,,,,,,,,,,,,,, -179300,7.4598875,1.7424808,,,,,,,,,,,,,, -179400,7.2964168,1.8229532,,,,,,,,,,,,,, -179500,7.7406125,1.8241549,,,,,,,,,,,,,, -179600,8.106843,1.8109866,,,,,,,,,,,,,, -179700,7.698079,1.7885664,,,,,,,,,,,,,, -179800,7.816959,1.8447446,,,,,,,,,,,,,, -179900,7.7472234,1.777141,,,,,,,,,,,,,, -180000,7.326666,1.6961329,,,,,,,,,,,,,, -180100,7.861315,1.7479476,,,,,,,,,,,,,, -180120,,,0.9199019074440002,0.4014585316181183,0.7723199725151062,0.9948861598968506,50000.0,0.6515000462532043,1.5987427234649658,10000.0,60752.50250053406,63045.0622420311,60752.50250053406,2279.408571243286,6.978010654449463,0.0 -180200,7.4434075,1.7638204,,,,,,,,,,,,,, -180300,8.132752,1.7775011,,,,,,,,,,,,,, -180400,7.5899568,1.8108155,,,,,,,,,,,,,, -180500,7.663342,1.8145944,,,,,,,,,,,,,, -180600,7.921272,1.7362313,,,,,,,,,,,,,, -180700,8.008447,1.7683343,,,,,,,,,,,,,, -180800,7.9455047,1.8201473,,,,,,,,,,,,,, -180900,7.879255,1.8096892,,,,,,,,,,,,,, -181000,7.8421693,1.7671762,,,,,,,,,,,,,, -181100,7.5710397,1.8113904,,,,,,,,,,,,,, -181200,8.266313,1.8807191,,,,,,,,,,,,,, -181300,7.9201517,1.7366924,,,,,,,,,,,,,, -181400,8.532148,1.8044508,,,,,,,,,,,,,, -181500,8.057458,1.8161527,,,,,,,,,,,,,, -181600,8.230964,1.8033414,,,,,,,,,,,,,, -181635,,,0.920918345451355,0.4040337800979614,0.7740799784660339,0.9936750531196594,50000.0,0.6521000266075134,1.5991164445877075,10000.0,61262.67764997482,63572.72377586365,61262.67764997482,2296.7791769504547,7.041255712509155,0.0 -181700,7.626606,1.8073759,,,,,,,,,,,,,, -181800,6.9203115,1.755099,,,,,,,,,,,,,, -181900,8.050138,1.7880603,,,,,,,,,,,,,, -182000,7.124271,1.761568,,,,,,,,,,,,,, -182100,8.025608,1.753711,,,,,,,,,,,,,, -182200,7.459174,1.7807616,,,,,,,,,,,,,, -182300,7.5522294,1.7879515,,,,,,,,,,,,,, -182400,8.528654,1.8193152,,,,,,,,,,,,,, -182500,7.869606,1.7637525,,,,,,,,,,,,,, -182600,7.595463,1.7023563,,,,,,,,,,,,,, -182700,8.564877,1.8464686,,,,,,,,,,,,,, -182800,7.605384,1.7092588,,,,,,,,,,,,,, -182900,8.220379,1.7709097,,,,,,,,,,,,,, -183000,8.211809,1.8132412,,,,,,,,,,,,,, -183100,7.862526,1.8220582,,,,,,,,,,,,,, -183149,,,0.920141100883484,0.398114413022995,0.7745400071144104,0.9911282062530518,50000.0,0.6522000432014465,1.5940217971801758,10000.0,61772.59363722801,64099.83883333206,61772.59363722801,2313.8672394752502,7.099210262298584,0.0 -183200,7.794289,1.7752814,,,,,,,,,,,,,, -183300,7.381014,1.7558501,,,,,,,,,,,,,, -183400,8.184507,1.742649,,,,,,,,,,,,,, -183500,7.7731724,1.7192925,,,,,,,,,,,,,, -183600,7.6393876,1.7578156,,,,,,,,,,,,,, -183700,8.168797,1.7916183,,,,,,,,,,,,,, -183800,8.00075,1.8229196,,,,,,,,,,,,,, -183900,8.075762,1.7404685,,,,,,,,,,,,,, -184000,7.727779,1.7981788,,,,,,,,,,,,,, -184100,8.211248,1.783888,,,,,,,,,,,,,, -184200,7.5332685,1.7808559,,,,,,,,,,,,,, -184300,7.26865,1.787719,,,,,,,,,,,,,, -184400,7.622913,1.725931,,,,,,,,,,,,,, -184500,7.598126,1.753693,,,,,,,,,,,,,, -184600,7.5549603,1.7505078,,,,,,,,,,,,,, -184663,,,0.9205795526504515,0.4052787721157074,0.774399995803833,0.9936506748199464,50000.0,0.6538000106811523,1.596049427986145,10000.0,62282.49484300613,64627.17417168617,62282.49484300613,2331.1919524669647,7.156313896179199,0.0 -184700,6.9374704,1.7408254,,,,,,,,,,,,,, -184800,7.688114,1.7418451,,,,,,,,,,,,,, -184900,7.8739,1.8209697,,,,,,,,,,,,,, -185000,8.035901,1.7431533,,,,,,,,,,,,,, -185100,8.211433,1.7326638,,,,,,,,,,,,,, -185200,7.7083426,1.7393064,,,,,,,,,,,,,, -185300,7.7687163,1.6946431,,,,,,,,,,,,,, -185400,7.6575847,1.7785459,,,,,,,,,,,,,, -185500,7.218449,1.7506711,,,,,,,,,,,,,, -185600,8.64583,1.8101074,,,,,,,,,,,,,, -185700,8.083596,1.8170176,,,,,,,,,,,,,, -185800,7.2080727,1.7633725,,,,,,,,,,,,,, -185900,7.2949586,1.7637798,,,,,,,,,,,,,, -186000,7.756521,1.781855,,,,,,,,,,,,,, -186100,7.7563195,1.7507923,,,,,,,,,,,,,, -186177,,,0.922253668308258,0.4010597765445709,0.7745199799537659,0.9941147565841676,50000.0,0.6533000469207764,1.5955201387405396,10000.0,62792.42883324623,65154.52738904953,62792.42883324623,2348.503286600113,7.2113025188446045,0.0 -186200,7.914061,1.8221921,,,,,,,,,,,,,, -186300,7.8228383,1.743831,,,,,,,,,,,,,, -186400,8.285632,1.7448946,,,,,,,,,,,,,, -186500,7.6339407,1.7352808,,,,,,,,,,,,,, -186600,8.026839,1.8062842,,,,,,,,,,,,,, -186666,,,0.9229711294174194,0.3895560204982757,0.774399995803833,0.9909643530845642,50000.0,0.6532000303268433,1.594335675239563,10000.0,62957.09817099571,65336.45809841156,62957.09817099571,2365.690082073212,7.268607139587402,0.0 -186666,,,,,,,,,,,62957.09817099571,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/eval_measurements.csv deleted file mode 100644 index e020d43ff..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,126 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.564436674118042,0.0,31.10077476501465,1,0,31.10077476501465,0.0009000000427477,6.912177562713623,10000,48.6652979850769,0.0011957908282056,6.911755561828613,0.0011199999134987,6.912059783935547,50000 -35.312798261642456,0.018644094467163,541.0707614421844,1508,0,541.0707614421844,0.0577000044286251,5.554655075073242,10000,576.4564385414124,0.0834064111113548,5.276715278625488,0.0773599967360496,5.338540077209473,50000 -53.65842294692993,0.0583107471466064,1051.159835577011,3014,0,1051.159835577011,0.1319000124931335,4.727846622467041,10000,1104.983157634735,0.1994778364896774,4.18760871887207,0.1808199882507324,4.304879188537598,50000 -71.37783074378967,0.0846424102783203,1561.2604904174805,4520,0,1561.2604904174805,0.2049000114202499,4.1407623291015625,10000,1632.881965637207,0.2977120578289032,3.51352596282959,0.2717199921607971,3.6509952545166016,50000 -89.01321029663086,0.1151204109191894,2071.384763002396,6027,0,2071.384763002396,0.2648999989032745,3.793429136276245,10000,2160.7247705459595,0.3806600570678711,3.0440993309021,0.3531999886035919,3.1904051303863525,50000 -106.56827187538148,0.1405167579650879,2581.59383225441,7535,0,2581.59383225441,0.3281000256538391,3.381321668624878,10000,2688.5689568519592,0.4617745578289032,2.5876331329345703,0.4347999989986419,2.728386878967285,50000 -124.32044792175292,0.1719436645507812,3091.5421063899994,9042,0,3091.5421063899994,0.3534000217914581,3.23469614982605,10000,3216.3535330295563,0.5261878371238708,2.2832539081573486,0.4661199748516083,2.568787813186645,50000 -142.20368576049805,0.2005589008331298,3601.6693108081818,10551,0,3601.6693108081818,0.4077000319957733,2.899076700210572,10000,3744.4450080394745,0.5644331574440002,2.038395643234253,0.5180599689483643,2.2616748809814453,50000 -160.059020280838,0.2294116020202636,4111.758868694305,12060,0,4111.758868694305,0.4248000085353851,2.805687189102173,10000,4272.47144985199,0.5901626348495483,1.959301471710205,0.5461400151252747,2.174945831298828,50000 -177.62704491615295,0.2588982582092285,4621.998073577881,13570,0,4621.998073577881,0.4430000185966491,2.674421548843384,10000,4800.3601586818695,0.6247608065605164,1.7613424062728882,0.5734399557113647,2.0024077892303467,50000 -195.13138890266416,0.2899277210235595,5132.248927354813,15080,0,5132.248927354813,0.4603000283241272,2.65906310081482,10000,5328.198922872543,0.6309988498687744,1.777909517288208,0.5839799642562866,1.988774657249451,50000 -212.92911338806152,0.3187780380249023,5642.341354608536,16591,0,5642.341354608536,0.4653000235557556,2.612239122390747,10000,5856.172199487686,0.6404655575752258,1.7349895238876345,0.5945599675178528,1.9475181102752688,50000 -230.6470057964325,0.351142406463623,6152.393952131271,18101,0,6152.393952131271,0.4782000184059143,2.4677436351776123,10000,6384.028230428696,0.6934390664100647,1.4368542432785034,0.613099992275238,1.7975260019302368,50000 -248.4592657089233,0.3803071975708008,6662.580951213837,19612,0,6662.580951213837,0.4949000179767608,2.467326641082764,10000,6912.111881971359,0.6884366869926453,1.4951562881469729,0.6227799654006958,1.7908552885055542,50000 -266.00374031066895,0.4111535549163818,7172.652762174606,21123,0,7172.652762174606,0.493800014257431,2.425893545150757,10000,7439.812463760376,0.6886360049247742,1.4495007991790771,0.625819981098175,1.7343255281448364,50000 -283.67709016799927,0.441507339477539,7682.5801339149475,22633,0,7682.5801339149475,0.5089000463485718,2.38150954246521,10000,7967.496514797211,0.6915856003761292,1.4739915132522583,0.6281999945640564,1.7558631896972656,50000 -301.5229160785675,0.4715721607208252,8192.537194252014,24144,0,8192.537194252014,0.5017000436782837,2.382734537124634,10000,8495.38295841217,0.6928212642669678,1.4294037818908691,0.6322399973869324,1.7027853727340698,50000 -319.04385137557983,0.5028097629547119,8702.589481115341,25656,0,8702.589481115341,0.5093000531196594,2.390270709991455,10000,9023.039636611938,0.6940768361091614,1.4729230403900146,0.6380199790000916,1.7212127447128296,50000 -336.7666335105896,0.5334517955780029,9212.57828116417,27167,0,9212.57828116417,0.5049999952316284,2.398658514022827,10000,9550.835167884828,0.7210817933082581,1.3525702953338623,0.6334599852561951,1.7323843240737915,50000 -354.5671169757843,0.5640599727630615,9722.724038362505,28679,0,9722.724038362505,0.5118000507354736,2.377542734146118,10000,10078.865000724792,0.7137874364852905,1.3917793035507202,0.6363599896430969,1.7337682247161863,50000 -372.3236920833588,0.5949838161468506,10232.818863868712,30190,0,10232.818863868712,0.5134000182151794,2.357249021530152,10000,10606.80100107193,0.7102399468421936,1.383233904838562,0.6411399841308594,1.6908059120178225,50000 -390.22929978370667,0.6257216930389404,10743.079125404358,31701,0,10743.079125404358,0.515500009059906,2.3480465412139893,10000,11135.051835298538,0.7122528553009033,1.4000403881072998,0.6447399854660034,1.6911596059799194,50000 -407.86250853538513,0.65826416015625,11253.20901465416,33213,0,11253.20901465416,0.5081000328063965,2.3862760066986084,10000,11662.901600122452,0.6996970772743225,1.4557379484176636,0.6355999708175659,1.7373415231704712,50000 -425.66084122657776,0.6921558380126953,11763.446279764175,34726,0,11763.446279764175,0.5192000269889832,2.314382553100586,10000,12191.024013519287,0.7174744606018066,1.3685293197631836,0.6532599925994873,1.6530615091323853,50000 -443.2582674026489,0.7246830463409424,12273.675210475922,36238,0,12273.675210475922,0.5265000462532043,2.281526565551758,10000,12718.935508489609,0.7438416481018066,1.2298567295074463,0.6535599827766418,1.6235462427139282,50000 -460.8275353908539,0.7599701881408691,12783.86525940895,37751,0,12783.86525940895,0.5151000022888184,2.3460588455200195,10000,13246.783225536346,0.7174545526504517,1.3231170177459717,0.6446200013160706,1.666609764099121,50000 -478.63541746139526,0.7926428318023682,13293.79348897934,39263,0,13293.79348897934,0.5318000316619873,2.28325629234314,10000,13774.605125188828,0.7337173223495483,1.286712884902954,0.6584399938583374,1.6077271699905396,50000 -497.0984447002411,0.8266785144805908,13803.70843219757,40775,0,13803.70843219757,0.5288000106811523,2.251645803451538,10000,14303.069906711578,0.7258848547935486,1.2848888635635376,0.6577799916267395,1.597143530845642,50000 -514.9081664085388,0.86090087890625,14313.930534362791,42288,0,14313.930534362791,0.5174000263214111,2.3248226642608643,10000,14831.188447713852,0.7098811864852905,1.3791455030441284,0.6428200006484985,1.6814510822296145,50000 -532.3907444477081,0.8976097106933594,14824.281010389328,43801,0,14824.281010389328,0.5326000452041626,2.2799079418182373,10000,15359.110605239868,0.7302295565605164,1.3141027688980105,0.6615200042724609,1.6093543767929075,50000 -550.0225744247437,0.9313144683837892,15334.48612332344,45314,0,15334.48612332344,0.5347000360488892,2.256589412689209,10000,15887.034386634828,0.7517538070678711,1.200649619102478,0.6576399803161621,1.604137659072876,50000 -567.7500638961792,0.9677863121032716,15844.531326293943,46827,0,15844.531326293943,0.539900004863739,2.1983628273010254,10000,16414.89757823944,0.7456353306770325,1.2059592008590698,0.6634799838066101,1.5565850734710691,50000 -585.2028439044952,1.00309419631958,16354.600351333618,48339,0,16354.600351333618,0.5368000268936157,2.243327140808105,10000,16942.508352041245,0.735750138759613,1.2548892498016355,0.6612600088119507,1.583881974220276,50000 -603.0600016117096,1.037532091140747,16864.72789144516,49852,0,16864.72789144516,0.534000039100647,2.243760108947754,10000,17470.579726219177,0.7357302308082581,1.2622418403625488,0.6646999716758728,1.571779489517212,50000 -621.2467834949493,1.0746331214904783,17374.956661462784,51365,0,17374.956661462784,0.5473000407218933,2.198474884033203,10000,17999.086223602295,0.7381417155265808,1.2553119659423828,0.6695399880409241,1.566043734550476,50000 -639.0135197639465,1.109938621520996,17884.904060840607,52877,0,17884.904060840607,0.5195000171661377,2.337864875793457,10000,18526.888315439224,0.7241908311843872,1.344178318977356,0.6525599956512451,1.6665358543395996,50000 -656.6714797019958,1.147374153137207,18395.117757558823,54390,0,18395.117757558823,0.5393000245094299,2.2847721576690674,10000,19054.84995698929,0.750996470451355,1.2404024600982666,0.6648600101470947,1.6224167346954346,50000 -674.4238469600677,1.1857051849365234,18905.16711997986,55902,0,18905.16711997986,0.5325000286102295,2.248415231704712,10000,19582.74394917488,0.7460737824440002,1.2308549880981443,0.6655600070953369,1.5786885023117063,50000 -692.0639700889587,1.2219395637512207,19415.267368793488,57414,0,19415.267368793488,0.5427000522613525,2.203951358795166,10000,20110.57297754288,0.7460139989852905,1.2035185098648071,0.6686199903488159,1.54822039604187,50000 -709.7101130485535,1.257903814315796,19925.24661397934,58926,0,19925.24661397934,0.5497000217437744,2.2031824588775635,10000,20638.287900447845,0.7444595098495483,1.2356679439544678,0.6713399887084961,1.556386947631836,50000 -727.5052762031555,1.2989416122436523,20435.23851680756,60438,0,20435.23851680756,0.5458000302314758,2.216935634613037,10000,21166.16930627823,0.7448979616165161,1.266387701034546,0.6719599962234497,1.57454514503479,50000 -745.1842300891876,1.3371562957763672,20945.41634607315,61951,0,20945.41634607315,0.5335000157356262,2.219144105911255,10000,21694.11650276184,0.7596260905265808,1.151545763015747,0.6668800115585327,1.5589600801467896,50000 -763.0073924064636,1.374301195144653,21455.393117904663,63463,0,21455.393117904663,0.5394999980926514,2.243274211883545,10000,22222.00701785088,0.7645288705825806,1.1519296169281006,0.6728799939155579,1.5548346042633057,50000 -780.7278144359589,1.4133169651031494,21965.388853788376,64976,0,21965.388853788376,0.5503000020980835,2.204355478286743,10000,22749.81556200981,0.7543845772743225,1.1895458698272705,0.6688799858093262,1.562997817993164,50000 -798.4267275333405,1.4519784450531006,22475.54665660858,66488,0,22475.54665660858,0.5478000044822693,2.2059414386749268,10000,23277.76457595825,0.7489436864852905,1.2165971994400024,0.6738399863243103,1.5674960613250732,50000 -815.9583787918091,1.4906814098358154,22985.49106383324,68001,0,22985.49106383324,0.5573000311851501,2.136253595352173,10000,23805.33237314224,0.7564173936843872,1.1640042066574097,0.6771000027656555,1.5072546005249023,50000 -833.9057495594025,1.5299150943756104,23495.39884543419,69513,0,23495.39884543419,0.5519000291824341,2.177412986755371,10000,24333.28107357025,0.7506178021430969,1.205312728881836,0.6768999695777893,1.5448400974273682,50000 -851.3368241786957,1.5676684379577637,24005.601548433304,71026,0,24005.601548433304,0.5546000003814697,2.1463091373443604,10000,24861.00675535202,0.7942841053009033,1.016778826713562,0.6810799837112427,1.494775891304016,50000 -869.2670221328735,1.6069247722625732,24515.58721637726,72538,0,24515.58721637726,0.5504000186920166,2.2045271396636963,10000,25389.0153260231,0.7712053656578064,1.1332015991210938,0.6800199747085571,1.5349023342132568,50000 -886.9353656768799,1.6402819156646729,25025.69666481018,74051,0,25025.69666481018,0.5599000453948975,2.144322395324707,10000,25916.87974834442,0.7703882455825806,1.119797945022583,0.6869999766349792,1.496230125427246,50000 -904.3809192180634,1.6784143447875977,25535.903984308243,75564,0,25535.903984308243,0.5641000270843506,2.1208748817443848,10000,26444.623901844025,0.7689333558082581,1.1266475915908811,0.6871799826622009,1.490763545036316,50000 -921.9570591449738,1.717276096343994,26045.897315502167,77077,0,26045.897315502167,0.5550000071525574,2.229222297668457,10000,26972.284712553024,0.7555803656578064,1.252349853515625,0.6780799627304077,1.591875433921814,50000 -939.5918412208556,1.7566168308258057,26555.934617996216,78589,0,26555.934617996216,0.5601000189781189,2.1550445556640625,10000,27500.04880857468,0.7691724896430969,1.146484375,0.6866399645805359,1.4988430738449097,50000 -958.4196372032166,1.7997441291809082,27066.09313583374,80103,0,27066.09313583374,0.5586000084877014,2.127764940261841,10000,28029.13027572632,0.8092314600944519,0.9729456305503844,0.6866999864578247,1.4889295101165771,50000 -976.1020576953888,1.8335380554199217,27576.04886603356,81614,0,27576.04886603356,0.5557000041007996,2.1231517791748047,10000,28556.85472536087,0.7765266299247742,1.067360758781433,0.6843599677085876,1.4753485918045044,50000 -993.557685136795,1.876111507415772,28085.96573448181,83126,0,28085.96573448181,0.5627000331878662,2.139356851577759,10000,29084.32238149643,0.7771045565605164,1.1030161380767822,0.6881399750709534,1.486668586730957,50000 -1011.1986262798308,1.9192132949829104,28596.036551475525,84638,0,28596.036551475525,0.5644000172615051,2.1074471473693848,10000,29612.1298763752,0.7768853306770325,1.0869734287261963,0.693399965763092,1.4536720514297483,50000 -1028.7158319950104,1.964526891708374,29106.1050195694,86151,0,29106.1050195694,0.5592000484466553,2.1242754459381104,10000,30139.81373643875,0.7732979655265808,1.0890520811080933,0.6902599930763245,1.4564690589904783,50000 -1046.6535975933075,2.0054867267608643,29616.25096130371,87664,0,29616.25096130371,0.5514000058174133,2.184067964553833,10000,30667.99164557457,0.755281388759613,1.1695573329925537,0.6784399747848511,1.5213377475738523,50000 -1064.221899986267,2.043313980102539,30126.34645795822,89177,0,30126.34645795822,0.5473000407218933,2.2165260314941406,10000,31195.745640039444,0.7955994606018066,1.0491045713424685,0.6782199740409851,1.5385297536849976,50000 -1081.7890536785126,2.085658311843872,30636.46873354912,90690,0,30636.46873354912,0.572700023651123,2.065574884414673,10000,31723.53106689453,0.8026148080825806,0.9901580810546876,0.7046200037002563,1.4159032106399536,50000 -1099.251507282257,2.132451057434082,31146.66298151016,92203,0,31146.66298151016,0.5664000511169434,2.099205255508423,10000,32251.288065195084,0.7906369566917419,1.0359426736831665,0.6946600079536438,1.449451208114624,50000 -1116.7954907417295,2.174596071243286,31656.57655262947,93715,0,31656.57655262947,0.5732000470161438,2.0790419578552246,10000,32778.84194803238,0.7890027165412903,1.048642635345459,0.6960399746894836,1.4489803314208984,50000 -1134.5098896026611,2.2162539958953857,32166.51022219658,95228,0,32166.51022219658,0.5764999985694885,2.050769567489624,10000,33306.58448624611,0.7940050959587097,1.0202025175094604,0.706059992313385,1.4011818170547483,50000 -1152.02063703537,2.271630048751831,32676.71812939644,96741,0,32676.71812939644,0.5685000419616699,2.087905168533325,10000,33834.412241220474,0.7877470850944519,1.0436723232269287,0.6979999542236328,1.427424669265747,50000 -1169.504544019699,2.3140299320220947,33186.897922992706,98254,0,33186.897922992706,0.5869000554084778,2.0103774070739746,10000,34362.17217421532,0.8298588991165161,0.8838488459587097,0.7114599943161011,1.3823779821395874,50000 -1187.0604236125946,2.3605294227600098,33696.980467796326,99767,0,33696.980467796326,0.5852000117301941,2.0205318927764893,10000,34889.91017913818,0.8191565275192261,0.9219579696655272,0.7096199989318848,1.3757820129394531,50000 -1204.593267440796,2.404573440551758,34206.91540932655,101280,0,34206.91540932655,0.5839000344276428,2.0435094833374023,10000,35417.475497722626,0.8095503449440002,0.9634189605712892,0.7093799710273743,1.3969210386276243,50000 -1222.228625535965,2.44966459274292,34717.08581137657,102793,0,34717.08581137657,0.5825000405311584,2.034860849380493,10000,35945.378918647766,0.8084542155265808,0.9739400744438172,0.7127400040626526,1.390412449836731,50000 -1239.6494569778442,2.493047952651977,35227.30777025223,104307,0,35227.30777025223,0.5819000005722046,2.023756504058838,10000,36473.11824512482,0.7966358065605164,0.990210771560669,0.7057600021362305,1.3867374658584597,50000 -1257.0016777515411,2.5371103286743164,35737.31253170967,105819,0,35737.31253170967,0.5885000228881836,2.001072645187378,10000,37000.57236742973,0.809968888759613,0.9250097274780272,0.7107399702072144,1.3485785722732544,50000 -1274.6831283569336,2.58165979385376,36247.26858711243,107331,0,36247.26858711243,0.5887000560760498,1.985122561454773,10000,37528.30771255493,0.8362364172935486,0.838945746421814,0.7157599925994873,1.3426353931427002,50000 -1292.3768393993378,2.629001379013061,36757.356506347656,108844,0,36757.356506347656,0.5867000222206116,2.025434732437134,10000,38056.19006371498,0.8233019709587097,0.8994101285934448,0.7140600085258484,1.3612909317016602,50000 -1309.9977324008942,2.672827243804932,37267.36195087433,110357,0,37267.36195087433,0.5910000205039978,1.9762768745422363,10000,38583.91264152527,0.8215481042861938,0.8967534899711609,0.7145199775695801,1.3505741357803345,50000 -1327.280189990997,2.72864294052124,37777.49592471123,111870,0,37777.49592471123,0.5868000388145447,1.9892569780349727,10000,39111.43874955177,0.8194156289100647,0.8783233761787415,0.7137799859046936,1.3280168771743774,50000 -1344.7098760604858,2.7765440940856934,38287.695055007935,113384,0,38287.695055007935,0.5896000266075134,1.9782979488372805,10000,39639.17041897774,0.8175222873687744,0.9152930974960328,0.7148399949073792,1.350134015083313,50000 -1362.3821530342102,2.825171709060669,38797.88224673271,114897,0,38797.88224673271,0.5919000506401062,2.00195837020874,10000,40167.13142871857,0.8275271058082581,0.9153132438659668,0.7184000015258789,1.3692786693572998,50000 -1379.988827228546,2.8746609687805176,39307.83862376213,116409,0,39307.83862376213,0.588200032711029,2.00327205657959,10000,40694.7973818779,0.848074734210968,0.8165138363838196,0.7175999879837036,1.3547149896621704,50000 -1398.249297618866,2.9214396476745605,39818.07694029808,117922,0,39818.07694029808,0.5986000299453735,1.9471795558929443,10000,41223.39574432373,0.8465999364852905,0.8292368650436401,0.7242199778556824,1.3349623680114746,50000 -1415.6428937911987,2.97339940071106,40328.25413155556,119436,0,40328.25413155556,0.5958999991416931,1.961624264717102,10000,41751.07174015045,0.8353993892669678,0.8475539088249207,0.724399983882904,1.3222719430923462,50000 -1433.5952577590942,3.020388603210449,40838.44957041741,120949,0,40838.44957041741,0.6011000275611877,1.9599751234054563,10000,42279.31935048103,0.8393654227256775,0.8338555097579956,0.7299799919128418,1.3040684461593628,50000 -1451.1136507987976,3.0704591274261475,41348.56471085549,122463,0,41348.56471085549,0.6029000282287598,1.931108474731445,10000,42807.05680012703,0.8434908986091614,0.8159423470497131,0.730459988117218,1.2901736497879028,50000 -1468.6732609272003,3.1229352951049805,41858.57477927208,123976,0,41858.57477927208,0.604200005531311,1.907442688941956,10000,43334.73067235947,0.843191921710968,0.7751460671424866,0.7268799543380737,1.2697910070419312,50000 -1486.1433503627777,3.1712000370025635,42368.69026613236,125489,0,42368.69026613236,0.5929000377655029,2.000288963317871,10000,43862.41753697395,0.84574294090271,0.8140817880630493,0.7204599976539612,1.347158670425415,50000 -1503.5928556919098,3.2235565185546875,42878.68905615807,127001,0,42878.68905615807,0.6015000343322754,1.922147512435913,10000,44389.9714012146,0.851980984210968,0.7569774389266968,0.7263799905776978,1.2846840620040894,50000 -1521.1470046043396,3.2711191177368164,43388.70094943047,128514,0,43388.70094943047,0.6098000407218933,1.9216325283050537,10000,44917.63706827164,0.8525390625,0.7666031718254089,0.7315399646759033,1.2773808240890503,50000 -1538.5677382946014,3.3185129165649414,43898.8369243145,130027,0,43898.8369243145,0.6058000326156616,1.929798483848572,10000,45445.29455208778,0.8579002022743225,0.7536612153053284,0.7330600023269653,1.2795575857162476,50000 -1556.2172808647156,3.3665919303894043,44409.057758808136,131540,0,44409.057758808136,0.6103000044822693,1.926138162612915,10000,45973.265880823135,0.8606704473495483,0.7800890803337097,0.7356199622154236,1.2970260381698608,50000 -1573.6209378242493,3.42007064819336,44919.02611017227,133052,0,44919.02611017227,0.615600049495697,1.887601613998413,10000,46500.74430775642,0.8704559803009033,0.7145001888275146,0.737060010433197,1.258089542388916,50000 -1591.034093618393,3.4738011360168457,45429.102596998215,134565,0,45429.102596998215,0.6135000586509705,1.904842615127564,10000,47028.34170055389,0.8752790093421936,0.6944555640220642,0.7367199659347534,1.2754249572753906,50000 -1608.5354924201963,3.526309967041016,45939.30143976212,136078,0,45939.30143976212,0.6200000047683716,1.8570964336395264,10000,47556.14796924591,0.8797034025192261,0.667981743812561,0.7418999671936035,1.2350200414657593,50000 -1626.244454622269,3.5757358074188232,46449.3056447506,137591,0,46449.3056447506,0.6186000108718872,1.8752394914627075,10000,48083.96310710907,0.8743423223495483,0.6792070865631104,0.7400799989700317,1.234809637069702,50000 -1644.1155638694763,3.6292083263397217,46959.34210753441,139104,0,46959.34210753441,0.6166000366210938,1.8933311700820925,10000,48611.97661066055,0.880301296710968,0.6912388801574707,0.7443999648094177,1.2525049448013306,50000 -1661.5373244285583,3.6814053058624254,47469.49969315529,140618,0,47469.49969315529,0.6273000240325928,1.8636746406555176,10000,49139.661640405655,0.881257951259613,0.6726261973381042,0.7454599738121033,1.2311464548110962,50000 -1678.9434888362885,3.731943845748901,47979.43871974945,142130,0,47979.43871974945,0.6145000457763672,1.894722580909729,10000,49667.10993528366,0.8937141299247742,0.6258179545402527,0.7417399883270264,1.2493900060653689,50000 -1696.3737258911133,3.7850804328918457,48489.606682538986,143643,0,48489.606682538986,0.6167000532150269,1.8685358762741089,10000,50194.81318330765,0.8984175324440002,0.5999704599380493,0.7469199895858765,1.220354676246643,50000 -1713.693132162094,3.836029767990112,48999.76789522171,145156,0,48999.76789522171,0.6219000220298767,1.8563332557678225,10000,50722.39692115784,0.894551157951355,0.6113947629928589,0.745959997177124,1.216866374015808,50000 -1731.5386974811554,3.894920825958252,49509.97759890556,146670,0,49509.97759890556,0.6224000453948975,1.8555512428283687,10000,51250.5646352768,0.8981783986091614,0.6124516725540161,0.748479962348938,1.2270301580429075,50000 -1748.8881268501282,3.947747468948364,50020.18068480492,148183,0,50020.18068480492,0.6222000122070312,1.876502990722656,10000,51778.22360420227,0.8929169178009033,0.6327682137489319,0.746679961681366,1.2398338317871094,50000 -1766.2669341564178,4.001991271972656,50530.1194422245,149696,0,50530.1194422245,0.6230000257492065,1.8418248891830444,10000,52305.64930200577,0.89652419090271,0.6026631593704224,0.7511999607086182,1.203833818435669,50000 -1783.8296740055084,4.057976961135864,51040.28805708885,151209,0,51040.28805708885,0.6265000104904175,1.8258957862854004,10000,52833.48942470551,0.9214365482330322,0.5090224146842957,0.7524799704551697,1.1925246715545654,50000 -1801.2003610134125,4.109041690826416,51550.47252988815,152723,0,51550.47252988815,0.6246000528335571,1.8432016372680664,10000,53361.15016055107,0.9169324040412904,0.526613175868988,0.7518599629402161,1.19866681098938,50000 -1818.829811096192,4.163464784622192,52060.446773052216,154235,0,52060.446773052216,0.6273000240325928,1.8422486782073968,10000,53888.86126732826,0.9125677347183228,0.5550875067710876,0.7528600096702576,1.2101103067398071,50000 -1836.47781085968,4.209648609161377,52570.49973154068,155748,0,52570.49973154068,0.6277000308036804,1.836329817771912,10000,54416.66174650192,0.9142418503761292,0.544491171836853,0.7534799575805664,1.1975823640823364,50000 -1854.4784564971924,4.263177156448364,53080.50154972077,157261,0,53080.50154972077,0.6292000412940979,1.8402178287506104,10000,54944.76948451996,0.9153977632522584,0.5491384267807007,0.7535799741744995,1.2018764019012451,50000 -1871.9461405277248,4.316278457641602,53590.45152449608,158773,0,53590.45152449608,0.6296000480651855,1.8320457935333248,10000,55472.29345464706,0.9156568646430968,0.5355087518692017,0.7560999989509583,1.1917190551757812,50000 -1889.3943195343013,4.3780577182769775,54100.61996221542,160286,0,54100.61996221542,0.6318000555038452,1.8300745487213133,10000,56000.02527117729,0.9319993257522584,0.4888227581977844,0.7568399906158447,1.1966514587402344,50000 -1906.8222138881683,4.4321160316467285,54610.668021678925,161799,0,54610.668021678925,0.6349000334739685,1.818547010421753,10000,56527.60770535469,0.9300462007522584,0.4880435168743133,0.7581999897956848,1.1818283796310425,50000 -1924.1260414123533,4.511041164398193,55120.69439053536,163311,0,55120.69439053536,0.6351000070571899,1.817732334136963,10000,57055.07130908966,0.9286909699440002,0.4972872734069824,0.7575799822807312,1.1837843656539917,50000 -1941.725725650788,4.56522536277771,55630.80797767639,164824,0,55630.80797767639,0.6360000371932983,1.8208544254302976,10000,57582.89201283455,0.9278140664100648,0.496550053358078,0.7587599754333496,1.1828312873840332,50000 -1959.3804275989528,4.620617151260376,56140.73363828659,166336,0,56140.73363828659,0.6350000500679016,1.8172756433486936,10000,58110.581184625626,0.9301857352256776,0.4875613152980804,0.7597599625587463,1.1802102327346802,50000 -1977.0555260181427,4.680642366409302,56650.924723148346,167849,0,56650.924723148346,0.6374000310897827,1.8157609701156616,10000,58638.559196949005,0.9341716766357422,0.4808126091957092,0.7604199647903442,1.180977702140808,50000 -1994.679827213288,4.739821195602417,57161.05779337883,169362,0,57161.05779337883,0.6372000575065613,1.8115358352661133,10000,59166.4287109375,0.9384565949440002,0.4624026119709015,0.7603200078010559,1.1795930862426758,50000 -2012.0857291221616,4.797400236129761,57671.088973522186,170875,0,57671.088973522186,0.6381000280380249,1.8069534301757808,10000,59693.97801613808,0.9371013641357422,0.4609328508377075,0.7616399526596069,1.170788288116455,50000 -2029.605746269226,4.853551864624023,58181.277435302734,172387,0,58181.277435302734,0.6375000476837158,1.80931556224823,10000,60221.7962770462,0.9367625713348388,0.4671074450016022,0.7630599737167358,1.1761828660964966,50000 -2046.9627692699432,4.915642261505127,58691.458490133286,173900,0,58691.458490133286,0.6363000273704529,1.8088829517364504,10000,60749.449186086655,0.9363639950752258,0.4647765457630157,0.7618199586868286,1.1747890710830688,50000 -2064.736491918564,4.975682020187378,59201.67768287659,175413,0,59201.67768287659,0.6386000514030457,1.8084932565689087,10000,61277.555790662766,0.9368223547935486,0.4632803499698639,0.7626799941062927,1.174310326576233,50000 -2082.3390328884125,5.031782627105713,59711.65766215325,176926,0,59711.65766215325,0.6410000324249268,1.8080230951309204,10000,61805.24752783775,0.9379583597183228,0.4658437371253967,0.7636199593544006,1.175073504447937,50000 -2099.8054864406586,5.092525005340576,60221.87917423248,178439,0,60221.87917423248,0.6402000188827515,1.8037229776382449,10000,62333.05004048348,0.9407883882522584,0.4490717649459839,0.7630800008773804,1.169806957244873,50000 -2117.058675289154,5.151561737060547,60732.035746097565,179952,0,60732.035746097565,0.640500009059906,1.8030924797058103,10000,62860.571590185165,0.940808355808258,0.4544986188411712,0.7629599571228027,1.1712826490402222,50000 -2134.386967897415,5.21010160446167,61242.113387584686,181464,0,61242.113387584686,0.6407000422477722,1.8057492971420288,10000,63388.0883102417,0.939871609210968,0.4563002586364746,0.7633999586105347,1.173967719078064,50000 -2151.952211856842,5.269731998443604,61752.08497405052,182977,0,61752.08497405052,0.6399000287055969,1.800960659980774,10000,63915.73721885681,0.939851701259613,0.4527520835399627,0.7633599638938904,1.1705142259597778,50000 -2169.6572070121765,5.3304524421691895,62262.14775061607,184489,0,62262.14775061607,0.6402000188827515,1.80498480796814,10000,64443.61814188957,0.93949294090271,0.4551874995231628,0.7637400031089783,1.1731160879135132,50000 -2186.9935114383698,5.393200159072876,62772.11812114716,186002,0,62772.11812114716,0.6404000520706177,1.8014260530471802,10000,64971.04134345055,0.9404894709587096,0.4521209001541137,0.7636399865150452,1.167830228805542,50000 -2204.471424341202,5.451512098312378,62996.11162734032,186666,0,62996.11162734032,0.6404000520706177,1.8022942543029785,10000,65212.594963788986,0.9403100609779358,0.44991639256477356,0.763759970664978,1.1691426038742065,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/measurements.csv deleted file mode 100644 index 14a772499..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1994 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.5562551,6.9256654,,,,,,,,,,,,,, -1,,,0.0011957908282056,6.911755561828613,0.0011199999134987,6.912059783935547,50000.0,0.0009000000427477,6.912177562713623,10000.0,31.10077476501465,48.6652979850769,31.10077476501465,17.564436674118042,0.0,0.0 -100,0.52879894,6.9011106,,,,,,,,,,,,,, -200,0.5411919,6.8568015,,,,,,,,,,,,,, -300,0.58589166,6.7643685,,,,,,,,,,,,,, -400,0.6113756,6.6849775,,,,,,,,,,,,,, -500,0.66772735,6.6414332,,,,,,,,,,,,,, -600,0.6808602,6.5412765,,,,,,,,,,,,,, -700,0.77916497,6.44799,,,,,,,,,,,,,, -800,1.4667825,6.3698874,,,,,,,,,,,,,, -900,1.503084,6.304823,,,,,,,,,,,,,, -1000,1.8208605,6.282628,,,,,,,,,,,,,, -1100,2.348894,6.126011,,,,,,,,,,,,,, -1200,1.8363744,6.0963187,,,,,,,,,,,,,, -1300,1.5095797,6.052471,,,,,,,,,,,,,, -1400,1.7882402,5.983612,,,,,,,,,,,,,, -1500,2.798366,5.9372087,,,,,,,,,,,,,, -1508,,,0.0834064111113548,5.276715278625488,0.0773599967360496,5.338540077209473,50000.0,0.0577000044286251,5.554655075073242,10000.0,541.0707614421844,576.4564385414124,541.0707614421844,35.312798261642456,0.018644094467163,0.0 -1600,1.769571,5.8558216,,,,,,,,,,,,,, -1700,2.1807952,5.8484764,,,,,,,,,,,,,, -1800,2.3275948,5.779243,,,,,,,,,,,,,, -1900,3.4302497,5.7116385,,,,,,,,,,,,,, -2000,2.58228,5.694634,,,,,,,,,,,,,, -2100,2.923163,5.624066,,,,,,,,,,,,,, -2200,2.5319939,5.5905027,,,,,,,,,,,,,, -2300,5.1756544,5.563944,,,,,,,,,,,,,, -2400,4.269682,5.5070477,,,,,,,,,,,,,, -2500,3.0944183,5.5281463,,,,,,,,,,,,,, -2600,2.6137793,5.392652,,,,,,,,,,,,,, -2700,4.873204,5.336563,,,,,,,,,,,,,, -2800,3.2162154,5.280334,,,,,,,,,,,,,, -2900,5.8924227,5.2110205,,,,,,,,,,,,,, -3000,4.9699054,5.271942,,,,,,,,,,,,,, -3014,,,0.1994778364896774,4.18760871887207,0.1808199882507324,4.304879188537598,50000.0,0.1319000124931335,4.727846622467041,10000.0,1051.159835577011,1104.983157634735,1051.159835577011,53.65842294692993,0.0583107471466064,0.0 -3100,3.5252228,5.2207513,,,,,,,,,,,,,, -3200,3.2950227,5.1199827,,,,,,,,,,,,,, -3300,2.59023,5.079239,,,,,,,,,,,,,, -3400,4.9810815,5.1789737,,,,,,,,,,,,,, -3500,3.524756,5.025016,,,,,,,,,,,,,, -3600,3.285896,5.072305,,,,,,,,,,,,,, -3700,2.7614362,4.973147,,,,,,,,,,,,,, -3800,2.8860824,4.918343,,,,,,,,,,,,,, -3900,4.865739,4.9198723,,,,,,,,,,,,,, -4000,4.0742645,4.9512806,,,,,,,,,,,,,, -4100,3.4135904,4.8443003,,,,,,,,,,,,,, -4200,4.570369,4.7867637,,,,,,,,,,,,,, -4300,3.4934049,4.7909036,,,,,,,,,,,,,, -4400,4.6502333,4.6041174,,,,,,,,,,,,,, -4500,9.161667,4.697425,,,,,,,,,,,,,, -4520,,,0.2977120578289032,3.51352596282959,0.2717199921607971,3.6509952545166016,50000.0,0.2049000114202499,4.1407623291015625,10000.0,1561.2604904174805,1632.881965637207,1561.2604904174805,71.37783074378967,0.0846424102783203,0.0 -4600,4.5680394,4.738797,,,,,,,,,,,,,, -4700,3.8749123,4.750304,,,,,,,,,,,,,, -4800,4.491113,4.6418724,,,,,,,,,,,,,, -4900,2.3008304,4.5920587,,,,,,,,,,,,,, -5000,3.3617277,4.6111584,,,,,,,,,,,,,, -5100,2.5283368,4.5244346,,,,,,,,,,,,,, -5200,2.9449258,4.5857544,,,,,,,,,,,,,, -5300,3.0145493,4.458437,,,,,,,,,,,,,, -5400,4.1979704,4.4290347,,,,,,,,,,,,,, -5500,3.6430657,4.5310698,,,,,,,,,,,,,, -5600,2.9670877,4.439545,,,,,,,,,,,,,, -5700,3.0947995,4.4822197,,,,,,,,,,,,,, -5800,3.134993,4.3379583,,,,,,,,,,,,,, -5900,3.0374243,4.4233694,,,,,,,,,,,,,, -6000,3.4959903,4.393601,,,,,,,,,,,,,, -6027,,,0.3806600570678711,3.0440993309021,0.3531999886035919,3.1904051303863525,50000.0,0.2648999989032745,3.793429136276245,10000.0,2071.384763002396,2160.7247705459595,2071.384763002396,89.01321029663086,0.1151204109191894,0.0 -6100,3.3188946,4.3391232,,,,,,,,,,,,,, -6200,2.4041085,4.286675,,,,,,,,,,,,,, -6300,3.2529323,4.3973117,,,,,,,,,,,,,, -6400,2.783098,4.3952537,,,,,,,,,,,,,, -6500,4.5313773,4.278348,,,,,,,,,,,,,, -6600,2.5969794,4.303044,,,,,,,,,,,,,, -6700,3.6577091,4.2531652,,,,,,,,,,,,,, -6800,2.5535393,4.2337914,,,,,,,,,,,,,, -6900,2.8784077,4.1781797,,,,,,,,,,,,,, -7000,2.8215034,4.093569,,,,,,,,,,,,,, -7100,2.89007,4.1785717,,,,,,,,,,,,,, -7200,2.5701227,4.144045,,,,,,,,,,,,,, -7300,2.7494106,4.152108,,,,,,,,,,,,,, -7400,2.5906866,4.195636,,,,,,,,,,,,,, -7500,2.9887602,4.091313,,,,,,,,,,,,,, -7535,,,0.4617745578289032,2.5876331329345703,0.4347999989986419,2.728386878967285,50000.0,0.3281000256538391,3.381321668624878,10000.0,2581.59383225441,2688.5689568519592,2581.59383225441,106.56827187538148,0.1405167579650879,0.0 -7600,2.4102669,4.0924172,,,,,,,,,,,,,, -7700,2.1378605,4.118382,,,,,,,,,,,,,, -7800,3.201398,4.1047964,,,,,,,,,,,,,, -7900,1.7879452,4.0402827,,,,,,,,,,,,,, -8000,2.2274997,4.036589,,,,,,,,,,,,,, -8100,2.9739187,4.02443,,,,,,,,,,,,,, -8200,3.348525,3.9884686,,,,,,,,,,,,,, -8300,3.367972,3.987361,,,,,,,,,,,,,, -8400,3.0288649,3.9962158,,,,,,,,,,,,,, -8500,2.300789,4.048567,,,,,,,,,,,,,, -8600,1.9451736,4.0370817,,,,,,,,,,,,,, -8700,2.514791,3.909171,,,,,,,,,,,,,, -8800,2.6097972,3.927833,,,,,,,,,,,,,, -8900,2.5759773,3.9243197,,,,,,,,,,,,,, -9000,2.3337176,4.024574,,,,,,,,,,,,,, -9042,,,0.5261878371238708,2.2832539081573486,0.4661199748516083,2.568787813186645,50000.0,0.3534000217914581,3.23469614982605,10000.0,3091.5421063899994,3216.3535330295563,3091.5421063899994,124.32044792175292,0.1719436645507812,0.0 -9100,1.5845464,3.9060633,,,,,,,,,,,,,, -9200,2.1251004,3.9335194,,,,,,,,,,,,,, -9300,1.9758509,3.8943841,,,,,,,,,,,,,, -9400,1.8538833,3.8782105,,,,,,,,,,,,,, -9500,2.167395,3.8253083,,,,,,,,,,,,,, -9600,2.217989,3.9410095,,,,,,,,,,,,,, -9700,2.3272226,3.9258862,,,,,,,,,,,,,, -9800,2.0117505,3.8723068,,,,,,,,,,,,,, -9900,2.460431,3.8901114,,,,,,,,,,,,,, -10000,1.8362519,3.8825254,,,,,,,,,,,,,, -10100,1.7170964,3.8966036,,,,,,,,,,,,,, -10200,2.2177827,3.8679326,,,,,,,,,,,,,, -10300,1.9294435,3.8564851,,,,,,,,,,,,,, -10400,1.7017616,3.7931347,,,,,,,,,,,,,, -10500,1.8509865,3.7358482,,,,,,,,,,,,,, -10551,,,0.5644331574440002,2.038395643234253,0.5180599689483643,2.2616748809814453,50000.0,0.4077000319957733,2.899076700210572,10000.0,3601.6693108081818,3744.4450080394745,3601.6693108081818,142.20368576049805,0.2005589008331298,0.0 -10600,1.9948605,3.733345,,,,,,,,,,,,,, -10700,2.3096538,3.7171838,,,,,,,,,,,,,, -10800,2.051727,3.7921448,,,,,,,,,,,,,, -10900,1.832192,3.707442,,,,,,,,,,,,,, -11000,1.9197631,3.7965333,,,,,,,,,,,,,, -11100,2.6168807,3.7299569,,,,,,,,,,,,,, -11200,1.5695492,3.78495,,,,,,,,,,,,,, -11300,1.9474478,3.7238557,,,,,,,,,,,,,, -11400,1.8990769,3.8023067,,,,,,,,,,,,,, -11500,1.5774724,3.7202928,,,,,,,,,,,,,, -11600,1.8088409,3.7018762,,,,,,,,,,,,,, -11700,1.8009822,3.7170904,,,,,,,,,,,,,, -11800,2.3016887,3.7121758,,,,,,,,,,,,,, -11900,1.3862728,3.6042876,,,,,,,,,,,,,, -12000,1.4355953,3.6388452,,,,,,,,,,,,,, -12060,,,0.5901626348495483,1.959301471710205,0.5461400151252747,2.174945831298828,50000.0,0.4248000085353851,2.805687189102173,10000.0,4111.758868694305,4272.47144985199,4111.758868694305,160.059020280838,0.2294116020202636,0.0 -12100,2.033804,3.6920476,,,,,,,,,,,,,, -12200,1.7749747,3.6977472,,,,,,,,,,,,,, -12300,1.7760097,3.7275555,,,,,,,,,,,,,, -12400,1.9330546,3.694358,,,,,,,,,,,,,, -12500,2.9264462,3.6424148,,,,,,,,,,,,,, -12600,2.4753447,3.5709457,,,,,,,,,,,,,, -12700,1.9129627,3.688491,,,,,,,,,,,,,, -12800,1.6914266,3.5743191,,,,,,,,,,,,,, -12900,1.4747545,3.639884,,,,,,,,,,,,,, -13000,1.904068,3.664855,,,,,,,,,,,,,, -13100,1.3906753,3.6157947,,,,,,,,,,,,,, -13200,1.4076977,3.6244643,,,,,,,,,,,,,, -13300,1.6600183,3.647,,,,,,,,,,,,,, -13400,1.7866441,3.575566,,,,,,,,,,,,,, -13500,1.3805004,3.5950954,,,,,,,,,,,,,, -13570,,,0.6247608065605164,1.7613424062728882,0.5734399557113647,2.0024077892303467,50000.0,0.4430000185966491,2.674421548843384,10000.0,4621.998073577881,4800.3601586818695,4621.998073577881,177.62704491615295,0.2588982582092285,0.0 -13600,1.4068216,3.5600076,,,,,,,,,,,,,, -13700,1.4309577,3.513751,,,,,,,,,,,,,, -13800,1.3936424,3.6334908,,,,,,,,,,,,,, -13900,1.5028021,3.5441108,,,,,,,,,,,,,, -14000,1.8630468,3.6026711,,,,,,,,,,,,,, -14100,1.4972643,3.5375788,,,,,,,,,,,,,, -14200,1.7416502,3.618748,,,,,,,,,,,,,, -14300,1.7341714,3.5330238,,,,,,,,,,,,,, -14400,1.6921818,3.5481243,,,,,,,,,,,,,, -14500,1.8325652,3.5124624,,,,,,,,,,,,,, -14600,1.7772923,3.5222251,,,,,,,,,,,,,, -14700,1.4499758,3.578837,,,,,,,,,,,,,, -14800,1.6380969,3.427617,,,,,,,,,,,,,, -14900,1.3396852,3.550614,,,,,,,,,,,,,, -15000,1.1506863,3.519022,,,,,,,,,,,,,, -15080,,,0.6309988498687744,1.777909517288208,0.5839799642562866,1.988774657249451,50000.0,0.4603000283241272,2.65906310081482,10000.0,5132.248927354813,5328.198922872543,5132.248927354813,195.13138890266416,0.2899277210235595,0.0 -15100,1.2525237,3.5284276,,,,,,,,,,,,,, -15200,2.07453,3.4798021,,,,,,,,,,,,,, -15300,1.4802501,3.4821744,,,,,,,,,,,,,, -15400,1.8909225,3.6169343,,,,,,,,,,,,,, -15500,1.2559623,3.4387236,,,,,,,,,,,,,, -15600,1.4513367,3.4726489,,,,,,,,,,,,,, -15700,1.3812599,3.5321832,,,,,,,,,,,,,, -15800,2.558289,3.550839,,,,,,,,,,,,,, -15900,2.1776826,3.4472654,,,,,,,,,,,,,, -16000,1.588503,3.4492211,,,,,,,,,,,,,, -16100,1.6224734,3.4345682,,,,,,,,,,,,,, -16200,1.8552547,3.5692396,,,,,,,,,,,,,, -16300,1.4451371,3.5115798,,,,,,,,,,,,,, -16400,1.4833679,3.4602256,,,,,,,,,,,,,, -16500,1.4341347,3.4605663,,,,,,,,,,,,,, -16591,,,0.6404655575752258,1.7349895238876345,0.5945599675178528,1.9475181102752688,50000.0,0.4653000235557556,2.612239122390747,10000.0,5642.341354608536,5856.172199487686,5642.341354608536,212.92911338806152,0.3187780380249023,0.0 -16600,1.2633024,3.4471617,,,,,,,,,,,,,, -16700,1.3095142,3.4666307,,,,,,,,,,,,,, -16800,1.3045912,3.4503078,,,,,,,,,,,,,, -16900,1.211861,3.436315,,,,,,,,,,,,,, -17000,1.9216399,3.5672445,,,,,,,,,,,,,, -17100,1.6662368,3.4555674,,,,,,,,,,,,,, -17200,1.4597892,3.481784,,,,,,,,,,,,,, -17300,1.0481541,3.4002252,,,,,,,,,,,,,, -17400,1.5838022,3.5306265,,,,,,,,,,,,,, -17500,1.2728164,3.4045715,,,,,,,,,,,,,, -17600,1.2645106,3.3829691,,,,,,,,,,,,,, -17700,1.6246228,3.3566036,,,,,,,,,,,,,, -17800,1.368414,3.40484,,,,,,,,,,,,,, -17900,1.3556507,3.471216,,,,,,,,,,,,,, -18000,1.1393992,3.4066634,,,,,,,,,,,,,, -18100,1.3865775,3.4695413,,,,,,,,,,,,,, -18101,,,0.6934390664100647,1.4368542432785034,0.613099992275238,1.7975260019302368,50000.0,0.4782000184059143,2.4677436351776123,10000.0,6152.393952131271,6384.028230428696,6152.393952131271,230.6470057964325,0.351142406463623,0.0 -18200,1.435304,3.4372804,,,,,,,,,,,,,, -18300,1.4280806,3.468391,,,,,,,,,,,,,, -18400,1.4652534,3.5544555,,,,,,,,,,,,,, -18500,1.3929437,3.4526026,,,,,,,,,,,,,, -18600,1.6178956,3.4225116,,,,,,,,,,,,,, -18700,1.3984509,3.3836558,,,,,,,,,,,,,, -18800,1.3434236,3.3809888,,,,,,,,,,,,,, -18900,1.384113,3.4041178,,,,,,,,,,,,,, -19000,1.360562,3.3959436,,,,,,,,,,,,,, -19100,1.5606599,3.4040663,,,,,,,,,,,,,, -19200,1.1331369,3.3984914,,,,,,,,,,,,,, -19300,1.3088409,3.4045572,,,,,,,,,,,,,, -19400,1.2599719,3.4388764,,,,,,,,,,,,,, -19500,1.2462021,3.3652847,,,,,,,,,,,,,, -19600,1.2850798,3.3742185,,,,,,,,,,,,,, -19612,,,0.6884366869926453,1.4951562881469729,0.6227799654006958,1.7908552885055542,50000.0,0.4949000179767608,2.467326641082764,10000.0,6662.580951213837,6912.111881971359,6662.580951213837,248.4592657089233,0.3803071975708008,0.0 -19700,1.2414665,3.4981613,,,,,,,,,,,,,, -19800,1.3898929,3.3002985,,,,,,,,,,,,,, -19900,1.0515103,3.2559047,,,,,,,,,,,,,, -20000,1.3321321,3.3793843,,,,,,,,,,,,,, -20100,1.4353828,3.3684762,,,,,,,,,,,,,, -20200,1.7503693,3.4604526,,,,,,,,,,,,,, -20300,1.6021558,3.3370767,,,,,,,,,,,,,, -20400,1.4560233,3.411264,,,,,,,,,,,,,, -20500,1.5406868,3.367979,,,,,,,,,,,,,, -20600,1.2890333,3.3122625,,,,,,,,,,,,,, -20700,1.3272843,3.2936676,,,,,,,,,,,,,, -20800,1.7772558,3.4145353,,,,,,,,,,,,,, -20900,1.3874385,3.434156,,,,,,,,,,,,,, -21000,1.2598264,3.363529,,,,,,,,,,,,,, -21100,1.2319244,3.3518882,,,,,,,,,,,,,, -21123,,,0.6886360049247742,1.4495007991790771,0.625819981098175,1.7343255281448364,50000.0,0.493800014257431,2.425893545150757,10000.0,7172.652762174606,7439.812463760376,7172.652762174606,266.00374031066895,0.4111535549163818,0.0 -21200,1.5433326,3.3491778,,,,,,,,,,,,,, -21300,1.6285083,3.444873,,,,,,,,,,,,,, -21400,1.2843595,3.288206,,,,,,,,,,,,,, -21500,1.5078464,3.3401418,,,,,,,,,,,,,, -21600,1.3264085,3.2941637,,,,,,,,,,,,,, -21700,1.6379311,3.3390384,,,,,,,,,,,,,, -21800,1.2688773,3.396504,,,,,,,,,,,,,, -21900,1.3026959,3.4283667,,,,,,,,,,,,,, -22000,1.2460567,3.3545477,,,,,,,,,,,,,, -22100,1.1873715,3.3388517,,,,,,,,,,,,,, -22200,1.8023037,3.4344401,,,,,,,,,,,,,, -22300,1.3426145,3.305298,,,,,,,,,,,,,, -22400,1.4451342,3.332354,,,,,,,,,,,,,, -22500,1.3952111,3.3055608,,,,,,,,,,,,,, -22600,1.5544038,3.3623412,,,,,,,,,,,,,, -22633,,,0.6915856003761292,1.4739915132522583,0.6281999945640564,1.7558631896972656,50000.0,0.5089000463485718,2.38150954246521,10000.0,7682.5801339149475,7967.496514797211,7682.5801339149475,283.67709016799927,0.441507339477539,0.0 -22700,1.2618623,3.3565102,,,,,,,,,,,,,, -22800,1.2682759,3.3872654,,,,,,,,,,,,,, -22900,1.3349097,3.3047879,,,,,,,,,,,,,, -23000,1.2632679,3.4207287,,,,,,,,,,,,,, -23100,1.2680391,3.3515995,,,,,,,,,,,,,, -23200,1.3064263,3.2958515,,,,,,,,,,,,,, -23300,1.29452,3.3701632,,,,,,,,,,,,,, -23400,1.623151,3.2731447,,,,,,,,,,,,,, -23500,1.0447605,3.29049,,,,,,,,,,,,,, -23600,1.6800519,3.3734336,,,,,,,,,,,,,, -23700,1.4176847,3.3908844,,,,,,,,,,,,,, -23800,1.5777872,3.270304,,,,,,,,,,,,,, -23900,1.4083301,3.3311186,,,,,,,,,,,,,, -24000,1.6718335,3.386884,,,,,,,,,,,,,, -24100,1.401443,3.3612087,,,,,,,,,,,,,, -24144,,,0.6928212642669678,1.4294037818908691,0.6322399973869324,1.7027853727340698,50000.0,0.5017000436782837,2.382734537124634,10000.0,8192.537194252014,8495.38295841217,8192.537194252014,301.5229160785675,0.4715721607208252,0.0 -24200,1.3174539,3.3668966,,,,,,,,,,,,,, -24300,1.4142678,3.2959828,,,,,,,,,,,,,, -24400,1.5037661,3.3798125,,,,,,,,,,,,,, -24500,1.364912,3.35211,,,,,,,,,,,,,, -24600,1.7688608,3.3464975,,,,,,,,,,,,,, -24700,1.3022878,3.2274704,,,,,,,,,,,,,, -24800,1.4157379,3.3398829,,,,,,,,,,,,,, -24900,1.2796831,3.376852,,,,,,,,,,,,,, -25000,1.4056848,3.3215604,,,,,,,,,,,,,, -25100,1.1428908,3.291615,,,,,,,,,,,,,, -25200,1.2283026,3.284516,,,,,,,,,,,,,, -25300,1.5112385,3.265901,,,,,,,,,,,,,, -25400,1.2540257,3.3342762,,,,,,,,,,,,,, -25500,1.5222144,3.3189807,,,,,,,,,,,,,, -25600,1.1225804,3.297897,,,,,,,,,,,,,, -25656,,,0.6940768361091614,1.4729230403900146,0.6380199790000916,1.7212127447128296,50000.0,0.5093000531196594,2.390270709991455,10000.0,8702.589481115341,9023.039636611938,8702.589481115341,319.04385137557983,0.5028097629547119,0.0 -25700,1.3726627,3.2612534,,,,,,,,,,,,,, -25800,1.1669288,3.2673934,,,,,,,,,,,,,, -25900,1.2734288,3.3070347,,,,,,,,,,,,,, -26000,1.5772849,3.390897,,,,,,,,,,,,,, -26100,1.3267659,3.3161018,,,,,,,,,,,,,, -26200,1.6121162,3.2499287,,,,,,,,,,,,,, -26300,1.6229447,3.365619,,,,,,,,,,,,,, -26400,1.4662488,3.2576256,,,,,,,,,,,,,, -26500,1.5493104,3.3303163,,,,,,,,,,,,,, -26600,1.1564596,3.2609305,,,,,,,,,,,,,, -26700,1.4838576,3.211064,,,,,,,,,,,,,, -26800,1.2231537,3.2852488,,,,,,,,,,,,,, -26900,1.2932692,3.2435045,,,,,,,,,,,,,, -27000,1.1541233,3.3387492,,,,,,,,,,,,,, -27100,1.4575492,3.358537,,,,,,,,,,,,,, -27167,,,0.7210817933082581,1.3525702953338623,0.6334599852561951,1.7323843240737915,50000.0,0.5049999952316284,2.398658514022827,10000.0,9212.57828116417,9550.835167884828,9212.57828116417,336.7666335105896,0.5334517955780029,0.0 -27200,1.3556697,3.2388492,,,,,,,,,,,,,, -27300,1.8609021,3.3228235,,,,,,,,,,,,,, -27400,1.5990696,3.2469287,,,,,,,,,,,,,, -27500,1.3308607,3.2152224,,,,,,,,,,,,,, -27600,1.3228898,3.254661,,,,,,,,,,,,,, -27700,1.5082738,3.240076,,,,,,,,,,,,,, -27800,1.6890059,3.3201275,,,,,,,,,,,,,, -27900,1.7488227,3.2944992,,,,,,,,,,,,,, -28000,1.431376,3.256734,,,,,,,,,,,,,, -28100,1.4903005,3.2856221,,,,,,,,,,,,,, -28200,1.2596457,3.294497,,,,,,,,,,,,,, -28300,1.9188071,3.339456,,,,,,,,,,,,,, -28400,1.4961504,3.2231004,,,,,,,,,,,,,, -28500,1.123745,3.250401,,,,,,,,,,,,,, -28600,1.2667092,3.3159738,,,,,,,,,,,,,, -28679,,,0.7137874364852905,1.3917793035507202,0.6363599896430969,1.7337682247161863,50000.0,0.5118000507354736,2.377542734146118,10000.0,9722.724038362505,10078.865000724792,9722.724038362505,354.5671169757843,0.5640599727630615,0.0 -28700,1.6485186,3.289362,,,,,,,,,,,,,, -28800,1.4047536,3.224467,,,,,,,,,,,,,, -28900,1.5701289,3.2466402,,,,,,,,,,,,,, -29000,1.4434091,3.2642012,,,,,,,,,,,,,, -29100,1.3847834,3.3092237,,,,,,,,,,,,,, -29200,1.5734158,3.307874,,,,,,,,,,,,,, -29300,1.4680614,3.355033,,,,,,,,,,,,,, -29400,1.4225087,3.2587776,,,,,,,,,,,,,, -29500,1.7156959,3.2782893,,,,,,,,,,,,,, -29600,1.3473654,3.32016,,,,,,,,,,,,,, -29700,1.2718685,3.1926446,,,,,,,,,,,,,, -29800,1.4330013,3.3282154,,,,,,,,,,,,,, -29900,1.4149629,3.2562206,,,,,,,,,,,,,, -30000,1.4271526,3.25947,,,,,,,,,,,,,, -30100,1.2398398,3.240123,,,,,,,,,,,,,, -30190,,,0.7102399468421936,1.383233904838562,0.6411399841308594,1.6908059120178225,50000.0,0.5134000182151794,2.357249021530152,10000.0,10232.818863868712,10606.80100107193,10232.818863868712,372.3236920833588,0.5949838161468506,0.0 -30200,1.2574512,3.3600914,,,,,,,,,,,,,, -30300,1.2904779,3.2689962,,,,,,,,,,,,,, -30400,1.4687665,3.367347,,,,,,,,,,,,,, -30500,1.4264427,3.2753584,,,,,,,,,,,,,, -30600,1.5418633,3.217319,,,,,,,,,,,,,, -30700,1.3838177,3.2239974,,,,,,,,,,,,,, -30800,1.4162241,3.3041875,,,,,,,,,,,,,, -30900,1.4528747,3.2086134,,,,,,,,,,,,,, -31000,1.3305278,3.2624397,,,,,,,,,,,,,, -31100,1.1933576,3.2582061,,,,,,,,,,,,,, -31200,1.2470763,3.2280507,,,,,,,,,,,,,, -31300,1.5181358,3.3131905,,,,,,,,,,,,,, -31400,1.7255249,3.2976103,,,,,,,,,,,,,, -31500,1.8998461,3.2465608,,,,,,,,,,,,,, -31600,2.0057013,3.2516327,,,,,,,,,,,,,, -31700,1.4243736,3.2607203,,,,,,,,,,,,,, -31701,,,0.7122528553009033,1.4000403881072998,0.6447399854660034,1.6911596059799194,50000.0,0.515500009059906,2.3480465412139893,10000.0,10743.079125404358,11135.051835298538,10743.079125404358,390.22929978370667,0.6257216930389404,0.0 -31800,1.5870687,3.2172296,,,,,,,,,,,,,, -31900,1.4068758,3.2495928,,,,,,,,,,,,,, -32000,1.3542237,3.2145834,,,,,,,,,,,,,, -32100,1.3772917,3.201791,,,,,,,,,,,,,, -32200,1.3580077,3.283184,,,,,,,,,,,,,, -32300,1.4480115,3.2353332,,,,,,,,,,,,,, -32400,1.6823004,3.2197824,,,,,,,,,,,,,, -32500,1.55888,3.1917593,,,,,,,,,,,,,, -32600,1.5588291,3.295642,,,,,,,,,,,,,, -32700,1.4363023,3.2563672,,,,,,,,,,,,,, -32800,1.3339502,3.273235,,,,,,,,,,,,,, -32900,1.3691702,3.2096908,,,,,,,,,,,,,, -33000,1.2459854,3.3101327,,,,,,,,,,,,,, -33100,1.3590716,3.2000072,,,,,,,,,,,,,, -33200,1.5797306,3.2425885,,,,,,,,,,,,,, -33213,,,0.6996970772743225,1.4557379484176636,0.6355999708175659,1.7373415231704712,50000.0,0.5081000328063965,2.3862760066986084,10000.0,11253.20901465416,11662.901600122452,11253.20901465416,407.86250853538513,0.65826416015625,0.0 -33300,1.5300035,3.2556105,,,,,,,,,,,,,, -33400,1.3401409,3.237211,,,,,,,,,,,,,, -33500,1.4721737,3.2037551,,,,,,,,,,,,,, -33600,1.8805193,3.242604,,,,,,,,,,,,,, -33700,1.4851822,3.2735324,,,,,,,,,,,,,, -33800,1.5336494,3.294144,,,,,,,,,,,,,, -33900,1.4543123,3.270094,,,,,,,,,,,,,, -34000,1.8959708,3.321154,,,,,,,,,,,,,, -34100,1.4116592,3.2050464,,,,,,,,,,,,,, -34200,1.3830051,3.2108572,,,,,,,,,,,,,, -34300,1.5699595,3.1947496,,,,,,,,,,,,,, -34400,1.3968498,3.1743293,,,,,,,,,,,,,, -34500,1.4337254,3.2779784,,,,,,,,,,,,,, -34600,1.4970132,3.2382338,,,,,,,,,,,,,, -34700,1.3401914,3.215198,,,,,,,,,,,,,, -34726,,,0.7174744606018066,1.3685293197631836,0.6532599925994873,1.6530615091323853,50000.0,0.5192000269889832,2.314382553100586,10000.0,11763.446279764175,12191.024013519287,11763.446279764175,425.66084122657776,0.6921558380126953,0.0 -34800,1.6298151,3.2181535,,,,,,,,,,,,,, -34900,1.8320324,3.2020645,,,,,,,,,,,,,, -35000,1.4900337,3.2504153,,,,,,,,,,,,,, -35100,1.4236048,3.1687233,,,,,,,,,,,,,, -35200,1.4732283,3.253055,,,,,,,,,,,,,, -35300,1.6128074,3.102157,,,,,,,,,,,,,, -35400,1.5764848,3.1146402,,,,,,,,,,,,,, -35500,1.5999748,3.2894654,,,,,,,,,,,,,, -35600,1.345081,3.2626479,,,,,,,,,,,,,, -35700,1.4416426,3.2898693,,,,,,,,,,,,,, -35800,1.5032218,3.2281687,,,,,,,,,,,,,, -35900,1.4274745,3.2256706,,,,,,,,,,,,,, -36000,1.5545188,3.215007,,,,,,,,,,,,,, -36100,1.4911671,3.2214723,,,,,,,,,,,,,, -36200,1.3450437,3.196527,,,,,,,,,,,,,, -36238,,,0.7438416481018066,1.2298567295074463,0.6535599827766418,1.6235462427139282,50000.0,0.5265000462532043,2.281526565551758,10000.0,12273.675210475922,12718.935508489609,12273.675210475922,443.2582674026489,0.7246830463409424,0.0 -36300,1.6114452,3.2763531,,,,,,,,,,,,,, -36400,1.7775611,3.274154,,,,,,,,,,,,,, -36500,1.614921,3.1599908,,,,,,,,,,,,,, -36600,1.4741471,3.2119496,,,,,,,,,,,,,, -36700,1.4550872,3.206945,,,,,,,,,,,,,, -36800,1.5176096,3.2096362,,,,,,,,,,,,,, -36900,1.4496058,3.2927139,,,,,,,,,,,,,, -37000,1.6250666,3.2344565,,,,,,,,,,,,,, -37100,1.5390974,3.1898098,,,,,,,,,,,,,, -37200,1.400416,3.2429445,,,,,,,,,,,,,, -37300,1.5682886,3.1794617,,,,,,,,,,,,,, -37400,1.9138823,3.2326992,,,,,,,,,,,,,, -37500,1.6391225,3.20662,,,,,,,,,,,,,, -37600,1.707297,3.2711124,,,,,,,,,,,,,, -37700,1.5455658,3.1863317,,,,,,,,,,,,,, -37751,,,0.7174545526504517,1.3231170177459717,0.6446200013160706,1.666609764099121,50000.0,0.5151000022888184,2.3460588455200195,10000.0,12783.86525940895,13246.783225536346,12783.86525940895,460.8275353908539,0.7599701881408691,0.0 -37800,1.5052937,3.190292,,,,,,,,,,,,,, -37900,1.5644212,3.2423983,,,,,,,,,,,,,, -38000,1.4287986,3.2217965,,,,,,,,,,,,,, -38100,1.5439321,3.2096019,,,,,,,,,,,,,, -38200,1.5749031,3.2082531,,,,,,,,,,,,,, -38300,1.5987822,3.1956398,,,,,,,,,,,,,, -38400,1.5848584,3.2417612,,,,,,,,,,,,,, -38500,1.4163064,3.157755,,,,,,,,,,,,,, -38600,1.597203,3.2335331,,,,,,,,,,,,,, -38700,1.5304961,3.2907877,,,,,,,,,,,,,, -38800,1.6605079,3.1993175,,,,,,,,,,,,,, -38900,2.0270298,3.20409,,,,,,,,,,,,,, -39000,1.5533262,3.2133722,,,,,,,,,,,,,, -39100,1.6970593,3.2130702,,,,,,,,,,,,,, -39200,1.5156353,3.2246654,,,,,,,,,,,,,, -39263,,,0.7337173223495483,1.286712884902954,0.6584399938583374,1.6077271699905396,50000.0,0.5318000316619873,2.28325629234314,10000.0,13293.79348897934,13774.605125188828,13293.79348897934,478.63541746139526,0.7926428318023682,0.0 -39300,1.5032603,3.1702936,,,,,,,,,,,,,, -39400,1.6413728,3.1159768,,,,,,,,,,,,,, -39500,1.5888734,3.252724,,,,,,,,,,,,,, -39600,1.6359645,3.1784508,,,,,,,,,,,,,, -39700,1.5025289,3.231886,,,,,,,,,,,,,, -39800,1.6350025,3.247418,,,,,,,,,,,,,, -39900,1.5014776,3.1617858,,,,,,,,,,,,,, -40000,1.5446845,3.1883562,,,,,,,,,,,,,, -40100,1.5392278,3.2066662,,,,,,,,,,,,,, -40200,1.5708342,3.2133088,,,,,,,,,,,,,, -40300,1.6148702,3.2766535,,,,,,,,,,,,,, -40400,1.4975592,3.185963,,,,,,,,,,,,,, -40500,1.508051,3.256228,,,,,,,,,,,,,, -40600,1.7985072,3.1360493,,,,,,,,,,,,,, -40700,1.3626171,3.1738,,,,,,,,,,,,,, -40775,,,0.7258848547935486,1.2848888635635376,0.6577799916267395,1.597143530845642,50000.0,0.5288000106811523,2.251645803451538,10000.0,13803.70843219757,14303.069906711578,13803.70843219757,497.0984447002411,0.8266785144805908,0.0 -40800,1.6227571,3.1832767,,,,,,,,,,,,,, -40900,1.7994238,3.3009195,,,,,,,,,,,,,, -41000,1.6248468,3.1763039,,,,,,,,,,,,,, -41100,1.6011375,3.1323965,,,,,,,,,,,,,, -41200,1.7068179,3.2149725,,,,,,,,,,,,,, -41300,1.6019318,3.2119336,,,,,,,,,,,,,, -41400,1.5435425,3.153193,,,,,,,,,,,,,, -41500,1.488089,3.1692395,,,,,,,,,,,,,, -41600,1.5348959,3.2043009,,,,,,,,,,,,,, -41700,1.7313263,3.21164,,,,,,,,,,,,,, -41800,1.7203443,3.144773,,,,,,,,,,,,,, -41900,1.6129984,3.1479754,,,,,,,,,,,,,, -42000,1.7133887,3.207788,,,,,,,,,,,,,, -42100,1.8121715,3.2105083,,,,,,,,,,,,,, -42200,1.6656693,3.2058682,,,,,,,,,,,,,, -42288,,,0.7098811864852905,1.3791455030441284,0.6428200006484985,1.6814510822296145,50000.0,0.5174000263214111,2.3248226642608643,10000.0,14313.930534362791,14831.188447713852,14313.930534362791,514.9081664085388,0.86090087890625,0.0 -42300,1.5679339,3.0573523,,,,,,,,,,,,,, -42400,1.6204349,3.1271017,,,,,,,,,,,,,, -42500,1.4410497,3.1942282,,,,,,,,,,,,,, -42600,1.6934781,3.160173,,,,,,,,,,,,,, -42700,1.7915542,3.1461577,,,,,,,,,,,,,, -42800,1.6880503,3.2219706,,,,,,,,,,,,,, -42900,1.6031244,3.1651497,,,,,,,,,,,,,, -43000,1.6583457,3.2208295,,,,,,,,,,,,,, -43100,1.6341388,3.2007465,,,,,,,,,,,,,, -43200,1.8116592,3.1708925,,,,,,,,,,,,,, -43300,1.6724588,3.1645768,,,,,,,,,,,,,, -43400,1.5677385,3.1500564,,,,,,,,,,,,,, -43500,1.7278368,3.1738594,,,,,,,,,,,,,, -43600,1.5539118,3.095532,,,,,,,,,,,,,, -43700,1.7221165,3.1274333,,,,,,,,,,,,,, -43800,1.5997105,3.2037292,,,,,,,,,,,,,, -43801,,,0.7302295565605164,1.3141027688980105,0.6615200042724609,1.6093543767929075,50000.0,0.5326000452041626,2.2799079418182373,10000.0,14824.281010389328,15359.110605239868,14824.281010389328,532.3907444477081,0.8976097106933594,0.0 -43900,1.911605,3.1678147,,,,,,,,,,,,,, -44000,1.6113565,3.229902,,,,,,,,,,,,,, -44100,1.84734,3.144886,,,,,,,,,,,,,, -44200,1.8479948,3.1755145,,,,,,,,,,,,,, -44300,1.7534873,3.1519756,,,,,,,,,,,,,, -44400,1.8509942,3.1643312,,,,,,,,,,,,,, -44500,1.7824196,3.2154899,,,,,,,,,,,,,, -44600,1.6969965,3.229556,,,,,,,,,,,,,, -44700,1.6796468,3.2320795,,,,,,,,,,,,,, -44800,1.610685,3.0812614,,,,,,,,,,,,,, -44900,1.7448612,3.1441581,,,,,,,,,,,,,, -45000,1.5593373,3.256832,,,,,,,,,,,,,, -45100,1.8548537,3.1334212,,,,,,,,,,,,,, -45200,1.9273779,3.1590815,,,,,,,,,,,,,, -45300,2.045462,3.189362,,,,,,,,,,,,,, -45314,,,0.7517538070678711,1.200649619102478,0.6576399803161621,1.604137659072876,50000.0,0.5347000360488892,2.256589412689209,10000.0,15334.48612332344,15887.034386634828,15334.48612332344,550.0225744247437,0.9313144683837892,0.0 -45400,1.578343,3.1627295,,,,,,,,,,,,,, -45500,1.7131554,3.1896496,,,,,,,,,,,,,, -45600,1.7015519,3.207278,,,,,,,,,,,,,, -45700,1.5882299,3.204866,,,,,,,,,,,,,, -45800,1.6614822,3.1702642,,,,,,,,,,,,,, -45900,1.784215,3.1396487,,,,,,,,,,,,,, -46000,1.6463965,3.1901765,,,,,,,,,,,,,, -46100,1.7563475,3.2360024,,,,,,,,,,,,,, -46200,1.912878,3.1519334,,,,,,,,,,,,,, -46300,1.7412963,3.0790265,,,,,,,,,,,,,, -46400,1.9292598,3.2774203,,,,,,,,,,,,,, -46500,1.9980536,3.145545,,,,,,,,,,,,,, -46600,1.735857,3.2157533,,,,,,,,,,,,,, -46700,1.7126305,3.1013367,,,,,,,,,,,,,, -46800,1.679388,3.2137709,,,,,,,,,,,,,, -46827,,,0.7456353306770325,1.2059592008590698,0.6634799838066101,1.5565850734710691,50000.0,0.539900004863739,2.1983628273010254,10000.0,15844.531326293943,16414.89757823944,15844.531326293943,567.7500638961792,0.9677863121032716,0.0 -46900,1.6081117,3.2066588,,,,,,,,,,,,,, -47000,1.597851,3.1020467,,,,,,,,,,,,,, -47100,1.6653657,3.1292899,,,,,,,,,,,,,, -47200,1.704847,3.1293066,,,,,,,,,,,,,, -47300,1.7409189,3.1720767,,,,,,,,,,,,,, -47400,1.6577462,3.1095543,,,,,,,,,,,,,, -47500,1.611618,3.1474435,,,,,,,,,,,,,, -47600,1.7255601,3.242181,,,,,,,,,,,,,, -47700,1.6251997,3.103837,,,,,,,,,,,,,, -47800,1.6686268,3.1690784,,,,,,,,,,,,,, -47900,1.8238547,3.1083717,,,,,,,,,,,,,, -48000,1.749605,3.1326385,,,,,,,,,,,,,, -48100,1.7854879,3.1002443,,,,,,,,,,,,,, -48200,1.7354916,3.2122827,,,,,,,,,,,,,, -48300,1.6644113,3.2119195,,,,,,,,,,,,,, -48339,,,0.735750138759613,1.2548892498016355,0.6612600088119507,1.583881974220276,50000.0,0.5368000268936157,2.243327140808105,10000.0,16354.600351333618,16942.508352041245,16354.600351333618,585.2028439044952,1.00309419631958,0.0 -48400,1.7026536,3.1913896,,,,,,,,,,,,,, -48500,1.7091036,3.2016015,,,,,,,,,,,,,, -48600,1.6982887,3.2133017,,,,,,,,,,,,,, -48700,1.6890461,3.0956469,,,,,,,,,,,,,, -48800,1.8381197,3.2026763,,,,,,,,,,,,,, -48900,1.6727452,3.1035948,,,,,,,,,,,,,, -49000,1.6745191,3.1274579,,,,,,,,,,,,,, -49100,1.8943903,3.15061,,,,,,,,,,,,,, -49200,1.6552037,3.1450403,,,,,,,,,,,,,, -49300,1.8604044,3.080456,,,,,,,,,,,,,, -49400,1.7179039,3.1506314,,,,,,,,,,,,,, -49500,1.8528578,3.1018972,,,,,,,,,,,,,, -49600,2.0740833,3.1252677,,,,,,,,,,,,,, -49700,1.8202688,3.1253211,,,,,,,,,,,,,, -49800,1.7132643,3.07324,,,,,,,,,,,,,, -49852,,,0.7357302308082581,1.2622418403625488,0.6646999716758728,1.571779489517212,50000.0,0.534000039100647,2.243760108947754,10000.0,16864.72789144516,17470.579726219177,16864.72789144516,603.0600016117096,1.037532091140747,0.0 -49900,1.6890143,3.0926323,,,,,,,,,,,,,, -50000,1.8142215,3.152445,,,,,,,,,,,,,, -50100,1.6558249,3.1296103,,,,,,,,,,,,,, -50200,1.7762343,3.1739843,,,,,,,,,,,,,, -50300,1.9357748,3.129105,,,,,,,,,,,,,, -50400,2.01863,3.1895587,,,,,,,,,,,,,, -50500,1.7179321,3.07404,,,,,,,,,,,,,, -50600,1.9394912,3.1647375,,,,,,,,,,,,,, -50700,1.7196779,3.1110046,,,,,,,,,,,,,, -50800,1.8343191,3.1939046,,,,,,,,,,,,,, -50900,1.9244988,3.246334,,,,,,,,,,,,,, -51000,1.8906415,3.2045205,,,,,,,,,,,,,, -51100,1.7768939,3.1911838,,,,,,,,,,,,,, -51200,1.7546313,3.1250255,,,,,,,,,,,,,, -51300,1.9874686,3.1775703,,,,,,,,,,,,,, -51365,,,0.7381417155265808,1.2553119659423828,0.6695399880409241,1.566043734550476,50000.0,0.5473000407218933,2.198474884033203,10000.0,17374.956661462784,17999.086223602295,17374.956661462784,621.2467834949493,1.0746331214904783,0.0 -51400,1.8437115,3.138657,,,,,,,,,,,,,, -51500,1.85279,3.209225,,,,,,,,,,,,,, -51600,1.8242671,3.1164012,,,,,,,,,,,,,, -51700,1.7946011,3.1888378,,,,,,,,,,,,,, -51800,1.6818383,3.1893716,,,,,,,,,,,,,, -51900,1.9369788,3.1677973,,,,,,,,,,,,,, -52000,1.7766997,3.1570485,,,,,,,,,,,,,, -52100,1.8102373,3.1672378,,,,,,,,,,,,,, -52200,2.007696,3.1511657,,,,,,,,,,,,,, -52300,1.6811167,3.1619155,,,,,,,,,,,,,, -52400,1.8540896,3.1466944,,,,,,,,,,,,,, -52500,1.8661947,3.1647434,,,,,,,,,,,,,, -52600,1.8023223,3.2341642,,,,,,,,,,,,,, -52700,1.9137689,3.170145,,,,,,,,,,,,,, -52800,1.7903634,3.1248865,,,,,,,,,,,,,, -52877,,,0.7241908311843872,1.344178318977356,0.6525599956512451,1.6665358543395996,50000.0,0.5195000171661377,2.337864875793457,10000.0,17884.904060840607,18526.888315439224,17884.904060840607,639.0135197639465,1.109938621520996,0.0 -52900,1.7089499,3.1207576,,,,,,,,,,,,,, -53000,1.9100863,3.169943,,,,,,,,,,,,,, -53100,1.8731161,3.1185284,,,,,,,,,,,,,, -53200,1.9478272,3.1427262,,,,,,,,,,,,,, -53300,2.1646066,3.0856776,,,,,,,,,,,,,, -53400,1.7539374,3.121762,,,,,,,,,,,,,, -53500,1.9667339,3.0905247,,,,,,,,,,,,,, -53600,1.7194291,3.0189316,,,,,,,,,,,,,, -53700,1.9489373,3.072361,,,,,,,,,,,,,, -53800,1.8296138,3.1355968,,,,,,,,,,,,,, -53900,1.8073727,3.1295724,,,,,,,,,,,,,, -54000,1.8761066,3.1432295,,,,,,,,,,,,,, -54100,1.7871681,3.1475725,,,,,,,,,,,,,, -54200,1.84836,3.1243136,,,,,,,,,,,,,, -54300,1.7580358,3.0853958,,,,,,,,,,,,,, -54390,,,0.750996470451355,1.2404024600982666,0.6648600101470947,1.6224167346954346,50000.0,0.5393000245094299,2.2847721576690674,10000.0,18395.117757558823,19054.84995698929,18395.117757558823,656.6714797019958,1.147374153137207,0.0 -54400,1.9178934,3.12896,,,,,,,,,,,,,, -54500,2.1749542,3.1186922,,,,,,,,,,,,,, -54600,1.7052972,3.1511374,,,,,,,,,,,,,, -54700,1.9627576,3.1303558,,,,,,,,,,,,,, -54800,1.9127678,3.1856222,,,,,,,,,,,,,, -54900,1.935091,3.0856454,,,,,,,,,,,,,, -55000,2.0523057,3.144666,,,,,,,,,,,,,, -55100,1.8015023,3.2198334,,,,,,,,,,,,,, -55200,1.9783292,3.1374075,,,,,,,,,,,,,, -55300,2.009131,3.146867,,,,,,,,,,,,,, -55400,1.8741977,3.1291373,,,,,,,,,,,,,, -55500,1.752504,3.1653898,,,,,,,,,,,,,, -55600,2.0704365,3.0716658,,,,,,,,,,,,,, -55700,1.9835402,3.1608422,,,,,,,,,,,,,, -55800,1.8400388,3.2321947,,,,,,,,,,,,,, -55900,2.015236,3.03469,,,,,,,,,,,,,, -55902,,,0.7460737824440002,1.2308549880981443,0.6655600070953369,1.5786885023117063,50000.0,0.5325000286102295,2.248415231704712,10000.0,18905.16711997986,19582.74394917488,18905.16711997986,674.4238469600677,1.1857051849365234,0.0 -56000,1.940415,3.0770514,,,,,,,,,,,,,, -56100,1.8453163,3.0197594,,,,,,,,,,,,,, -56200,2.0843167,3.1751895,,,,,,,,,,,,,, -56300,1.894004,3.1736796,,,,,,,,,,,,,, -56400,1.8633094,3.1173358,,,,,,,,,,,,,, -56500,1.8473893,3.0970001,,,,,,,,,,,,,, -56600,1.9352272,3.1397216,,,,,,,,,,,,,, -56700,1.9426361,3.1110501,,,,,,,,,,,,,, -56800,1.7310158,3.1400704,,,,,,,,,,,,,, -56900,1.8618768,3.1784003,,,,,,,,,,,,,, -57000,1.7935879,3.1182737,,,,,,,,,,,,,, -57100,2.0233579,3.1255646,,,,,,,,,,,,,, -57200,1.891254,3.185082,,,,,,,,,,,,,, -57300,2.0117266,3.1752548,,,,,,,,,,,,,, -57400,1.9251274,3.068028,,,,,,,,,,,,,, -57414,,,0.7460139989852905,1.2035185098648071,0.6686199903488159,1.54822039604187,50000.0,0.5427000522613525,2.203951358795166,10000.0,19415.267368793488,20110.57297754288,19415.267368793488,692.0639700889587,1.2219395637512207,0.0 -57500,1.8162341,3.0778153,,,,,,,,,,,,,, -57600,1.9721706,3.1406682,,,,,,,,,,,,,, -57700,1.9482396,3.136929,,,,,,,,,,,,,, -57800,1.9682611,3.1583428,,,,,,,,,,,,,, -57900,1.9118121,3.0434084,,,,,,,,,,,,,, -58000,2.0840485,3.160098,,,,,,,,,,,,,, -58100,1.9025955,3.1656876,,,,,,,,,,,,,, -58200,2.1356652,3.1011047,,,,,,,,,,,,,, -58300,1.953332,3.066313,,,,,,,,,,,,,, -58400,1.907642,3.1421607,,,,,,,,,,,,,, -58500,1.9067041,3.1442995,,,,,,,,,,,,,, -58600,1.9528668,3.1172175,,,,,,,,,,,,,, -58700,2.0477076,3.1214263,,,,,,,,,,,,,, -58800,1.7815655,3.1498368,,,,,,,,,,,,,, -58900,1.7771794,3.099308,,,,,,,,,,,,,, -58926,,,0.7444595098495483,1.2356679439544678,0.6713399887084961,1.556386947631836,50000.0,0.5497000217437744,2.2031824588775635,10000.0,19925.24661397934,20638.287900447845,19925.24661397934,709.7101130485535,1.257903814315796,0.0 -59000,2.1819448,3.1070776,,,,,,,,,,,,,, -59100,1.935985,3.1769605,,,,,,,,,,,,,, -59200,1.901098,3.1824355,,,,,,,,,,,,,, -59300,2.0906277,3.1419473,,,,,,,,,,,,,, -59400,1.9814355,3.2100208,,,,,,,,,,,,,, -59500,2.077851,3.064873,,,,,,,,,,,,,, -59600,2.040141,3.0810714,,,,,,,,,,,,,, -59700,2.0731003,3.1231952,,,,,,,,,,,,,, -59800,2.101652,3.1385717,,,,,,,,,,,,,, -59900,2.252949,3.0701885,,,,,,,,,,,,,, -60000,2.1510396,3.1090868,,,,,,,,,,,,,, -60100,2.60694,3.1689405,,,,,,,,,,,,,, -60200,2.0164716,3.083093,,,,,,,,,,,,,, -60300,1.9461951,3.2102416,,,,,,,,,,,,,, -60400,2.0155208,3.076853,,,,,,,,,,,,,, -60438,,,0.7448979616165161,1.266387701034546,0.6719599962234497,1.57454514503479,50000.0,0.5458000302314758,2.216935634613037,10000.0,20435.23851680756,21166.16930627823,20435.23851680756,727.5052762031555,1.2989416122436523,0.0 -60500,2.0584784,3.066648,,,,,,,,,,,,,, -60600,1.7847495,3.0655901,,,,,,,,,,,,,, -60700,1.8088508,2.9845874,,,,,,,,,,,,,, -60800,2.013204,3.16217,,,,,,,,,,,,,, -60900,1.9815394,3.045651,,,,,,,,,,,,,, -61000,1.9806707,3.1901178,,,,,,,,,,,,,, -61100,1.8928452,3.1473043,,,,,,,,,,,,,, -61200,1.9209148,3.0467732,,,,,,,,,,,,,, -61300,2.0066328,3.0637865,,,,,,,,,,,,,, -61400,2.0160515,3.119962,,,,,,,,,,,,,, -61500,2.077555,3.0604365,,,,,,,,,,,,,, -61600,2.0933936,3.0902069,,,,,,,,,,,,,, -61700,1.9065589,3.0708973,,,,,,,,,,,,,, -61800,1.9514111,3.0846713,,,,,,,,,,,,,, -61900,2.2582486,3.1171465,,,,,,,,,,,,,, -61951,,,0.7596260905265808,1.151545763015747,0.6668800115585327,1.5589600801467896,50000.0,0.5335000157356262,2.219144105911255,10000.0,20945.41634607315,21694.11650276184,20945.41634607315,745.1842300891876,1.3371562957763672,0.0 -62000,2.020858,3.1398764,,,,,,,,,,,,,, -62100,1.9666048,3.0931697,,,,,,,,,,,,,, -62200,2.195602,3.133737,,,,,,,,,,,,,, -62300,1.8556819,3.0089555,,,,,,,,,,,,,, -62400,2.1328485,3.087443,,,,,,,,,,,,,, -62500,1.9893098,3.106286,,,,,,,,,,,,,, -62600,2.1710453,3.0781448,,,,,,,,,,,,,, -62700,2.1288207,3.1287336,,,,,,,,,,,,,, -62800,1.9632863,3.0529485,,,,,,,,,,,,,, -62900,2.040165,3.1698856,,,,,,,,,,,,,, -63000,1.913427,3.047036,,,,,,,,,,,,,, -63100,1.9191899,3.110577,,,,,,,,,,,,,, -63200,1.8551013,3.050074,,,,,,,,,,,,,, -63300,2.1740181,3.1168418,,,,,,,,,,,,,, -63400,2.0868535,3.0812302,,,,,,,,,,,,,, -63463,,,0.7645288705825806,1.1519296169281006,0.6728799939155579,1.5548346042633057,50000.0,0.5394999980926514,2.243274211883545,10000.0,21455.393117904663,22222.00701785088,21455.393117904663,763.0073924064636,1.374301195144653,0.0 -63500,1.9035853,3.0136778,,,,,,,,,,,,,, -63600,1.8988261,3.095108,,,,,,,,,,,,,, -63700,2.1300383,3.1228452,,,,,,,,,,,,,, -63800,2.0471096,3.096313,,,,,,,,,,,,,, -63900,1.9909232,3.069047,,,,,,,,,,,,,, -64000,2.2317865,3.0592947,,,,,,,,,,,,,, -64100,2.0446813,3.035292,,,,,,,,,,,,,, -64200,2.11094,3.0635724,,,,,,,,,,,,,, -64300,2.0626948,3.1634893,,,,,,,,,,,,,, -64400,1.9870298,3.0935135,,,,,,,,,,,,,, -64500,2.2052147,3.1319149,,,,,,,,,,,,,, -64600,2.0547998,3.060825,,,,,,,,,,,,,, -64700,2.157248,3.0934758,,,,,,,,,,,,,, -64800,2.5836747,3.1547463,,,,,,,,,,,,,, -64900,2.0013452,3.0646417,,,,,,,,,,,,,, -64976,,,0.7543845772743225,1.1895458698272705,0.6688799858093262,1.562997817993164,50000.0,0.5503000020980835,2.204355478286743,10000.0,21965.388853788376,22749.81556200981,21965.388853788376,780.7278144359589,1.4133169651031494,0.0 -65000,1.9189273,3.0048933,,,,,,,,,,,,,, -65100,1.9463811,3.065245,,,,,,,,,,,,,, -65200,2.1711597,3.0836878,,,,,,,,,,,,,, -65300,2.2748609,3.04089,,,,,,,,,,,,,, -65400,1.950605,3.0274374,,,,,,,,,,,,,, -65500,2.0484858,3.1623065,,,,,,,,,,,,,, -65600,2.0779727,3.0095985,,,,,,,,,,,,,, -65700,2.063744,3.067105,,,,,,,,,,,,,, -65800,2.1599677,3.05338,,,,,,,,,,,,,, -65900,2.117161,3.109443,,,,,,,,,,,,,, -66000,2.2727127,3.1493862,,,,,,,,,,,,,, -66100,1.9780277,3.0859451,,,,,,,,,,,,,, -66200,2.1321888,3.096377,,,,,,,,,,,,,, -66300,2.0570934,3.0188391,,,,,,,,,,,,,, -66400,2.4455185,3.093484,,,,,,,,,,,,,, -66488,,,0.7489436864852905,1.2165971994400024,0.6738399863243103,1.5674960613250732,50000.0,0.5478000044822693,2.2059414386749268,10000.0,22475.54665660858,23277.76457595825,22475.54665660858,798.4267275333405,1.4519784450531006,0.0 -66500,2.0228088,3.0645974,,,,,,,,,,,,,, -66600,2.2756977,3.113335,,,,,,,,,,,,,, -66700,2.072388,3.0921452,,,,,,,,,,,,,, -66800,1.8910933,3.0034046,,,,,,,,,,,,,, -66900,2.1237142,3.0415158,,,,,,,,,,,,,, -67000,1.8834664,3.0633254,,,,,,,,,,,,,, -67100,2.0889785,3.0688107,,,,,,,,,,,,,, -67200,2.2695312,3.0480044,,,,,,,,,,,,,, -67300,1.9809725,3.1243677,,,,,,,,,,,,,, -67400,2.4857724,3.0850916,,,,,,,,,,,,,, -67500,2.039014,3.0332277,,,,,,,,,,,,,, -67600,2.0971968,3.0494084,,,,,,,,,,,,,, -67700,1.9796059,3.127309,,,,,,,,,,,,,, -67800,2.0305223,3.087576,,,,,,,,,,,,,, -67900,1.9054589,3.072397,,,,,,,,,,,,,, -68000,1.8980838,3.014615,,,,,,,,,,,,,, -68001,,,0.7564173936843872,1.1640042066574097,0.6771000027656555,1.5072546005249023,50000.0,0.5573000311851501,2.136253595352173,10000.0,22985.49106383324,23805.33237314224,22985.49106383324,815.9583787918091,1.4906814098358154,0.0 -68100,2.147472,3.130551,,,,,,,,,,,,,, -68200,2.110453,3.0498872,,,,,,,,,,,,,, -68300,2.3577049,3.1404035,,,,,,,,,,,,,, -68400,2.0568392,3.0727367,,,,,,,,,,,,,, -68500,2.1111453,3.0764158,,,,,,,,,,,,,, -68600,2.113659,2.9887838,,,,,,,,,,,,,, -68700,2.244969,3.0305963,,,,,,,,,,,,,, -68800,2.1424603,3.0760965,,,,,,,,,,,,,, -68900,2.1825447,3.1906378,,,,,,,,,,,,,, -69000,1.9537969,3.0734632,,,,,,,,,,,,,, -69100,2.2055218,3.1439557,,,,,,,,,,,,,, -69200,2.0850708,3.0589092,,,,,,,,,,,,,, -69300,2.1185114,3.0788062,,,,,,,,,,,,,, -69400,2.2383585,3.0478446,,,,,,,,,,,,,, -69500,1.9987363,3.0843315,,,,,,,,,,,,,, -69513,,,0.7506178021430969,1.205312728881836,0.6768999695777893,1.5448400974273682,50000.0,0.5519000291824341,2.177412986755371,10000.0,23495.39884543419,24333.28107357025,23495.39884543419,833.9057495594025,1.5299150943756104,0.0 -69600,2.1718771,3.030487,,,,,,,,,,,,,, -69700,2.0070047,3.0209308,,,,,,,,,,,,,, -69800,2.0798829,3.0758295,,,,,,,,,,,,,, -69900,2.075657,3.0564332,,,,,,,,,,,,,, -70000,2.026347,3.0072453,,,,,,,,,,,,,, -70100,2.0638587,3.0457299,,,,,,,,,,,,,, -70200,2.070709,3.0749063,,,,,,,,,,,,,, -70300,2.1597164,3.1543546,,,,,,,,,,,,,, -70400,2.1500175,3.0725706,,,,,,,,,,,,,, -70500,2.144366,3.0925255,,,,,,,,,,,,,, -70600,2.0337071,3.1198137,,,,,,,,,,,,,, -70700,2.1377082,3.102366,,,,,,,,,,,,,, -70800,1.9517137,3.078338,,,,,,,,,,,,,, -70900,2.0449402,3.0365353,,,,,,,,,,,,,, -71000,2.1329358,3.1058655,,,,,,,,,,,,,, -71026,,,0.7942841053009033,1.016778826713562,0.6810799837112427,1.494775891304016,50000.0,0.5546000003814697,2.1463091373443604,10000.0,24005.601548433304,24861.00675535202,24005.601548433304,851.3368241786957,1.5676684379577637,0.0 -71100,2.1055758,3.0118704,,,,,,,,,,,,,, -71200,2.10708,3.0434394,,,,,,,,,,,,,, -71300,2.0662236,3.0474687,,,,,,,,,,,,,, -71400,2.22244,3.093185,,,,,,,,,,,,,, -71500,2.037222,3.0262737,,,,,,,,,,,,,, -71600,2.249777,3.1245592,,,,,,,,,,,,,, -71700,2.1272504,3.0733986,,,,,,,,,,,,,, -71800,2.2072597,3.080582,,,,,,,,,,,,,, -71900,2.045068,3.0117123,,,,,,,,,,,,,, -72000,2.0637271,3.0416026,,,,,,,,,,,,,, -72100,2.0929897,3.0240202,,,,,,,,,,,,,, -72200,2.152158,3.0508842,,,,,,,,,,,,,, -72300,2.2502043,3.097394,,,,,,,,,,,,,, -72400,2.1798744,3.0592403,,,,,,,,,,,,,, -72500,2.0734332,3.0611181,,,,,,,,,,,,,, -72538,,,0.7712053656578064,1.1332015991210938,0.6800199747085571,1.5349023342132568,50000.0,0.5504000186920166,2.2045271396636963,10000.0,24515.58721637726,25389.0153260231,24515.58721637726,869.2670221328735,1.6069247722625732,0.0 -72600,2.2466307,3.1324525,,,,,,,,,,,,,, -72700,2.325101,3.1351957,,,,,,,,,,,,,, -72800,2.24387,3.1439724,,,,,,,,,,,,,, -72900,2.1091073,3.0700278,,,,,,,,,,,,,, -73000,2.1315846,3.0300124,,,,,,,,,,,,,, -73100,2.174537,3.0781403,,,,,,,,,,,,,, -73200,2.1107829,3.0640316,,,,,,,,,,,,,, -73300,2.175414,3.113498,,,,,,,,,,,,,, -73400,2.3746905,3.062942,,,,,,,,,,,,,, -73500,2.272867,3.079413,,,,,,,,,,,,,, -73600,2.2613642,3.0730255,,,,,,,,,,,,,, -73700,2.0848138,3.0970314,,,,,,,,,,,,,, -73800,2.1625383,3.075782,,,,,,,,,,,,,, -73900,2.0921528,3.0682743,,,,,,,,,,,,,, -74000,2.031427,3.0173764,,,,,,,,,,,,,, -74051,,,0.7703882455825806,1.119797945022583,0.6869999766349792,1.496230125427246,50000.0,0.5599000453948975,2.144322395324707,10000.0,25025.69666481018,25916.87974834442,25025.69666481018,886.9353656768799,1.6402819156646729,0.0 -74100,2.0589755,3.1170597,,,,,,,,,,,,,, -74200,1.956284,3.0176373,,,,,,,,,,,,,, -74300,2.23166,3.0600123,,,,,,,,,,,,,, -74400,2.0663502,2.9837437,,,,,,,,,,,,,, -74500,2.1635592,3.0108514,,,,,,,,,,,,,, -74600,2.3231971,3.100458,,,,,,,,,,,,,, -74700,2.2283568,3.0863912,,,,,,,,,,,,,, -74800,2.218578,3.0970263,,,,,,,,,,,,,, -74900,2.1789174,3.0198603,,,,,,,,,,,,,, -75000,2.1294491,3.1267803,,,,,,,,,,,,,, -75100,2.0944626,3.0574524,,,,,,,,,,,,,, -75200,2.2609065,3.0052538,,,,,,,,,,,,,, -75300,2.1869094,3.0043616,,,,,,,,,,,,,, -75400,2.197427,3.0536828,,,,,,,,,,,,,, -75500,2.5704806,3.1327863,,,,,,,,,,,,,, -75564,,,0.7689333558082581,1.1266475915908811,0.6871799826622009,1.490763545036316,50000.0,0.5641000270843506,2.1208748817443848,10000.0,25535.903984308243,26444.623901844025,25535.903984308243,904.3809192180634,1.6784143447875977,0.0 -75600,2.1815386,3.1106467,,,,,,,,,,,,,, -75700,2.2289417,3.063771,,,,,,,,,,,,,, -75800,2.2306464,3.036181,,,,,,,,,,,,,, -75900,2.1296575,3.0037277,,,,,,,,,,,,,, -76000,2.4546077,3.0823328,,,,,,,,,,,,,, -76100,2.241062,3.0088477,,,,,,,,,,,,,, -76200,2.2555985,3.155474,,,,,,,,,,,,,, -76300,2.0580032,3.029706,,,,,,,,,,,,,, -76400,2.5501282,3.0420778,,,,,,,,,,,,,, -76500,2.2918181,3.0779238,,,,,,,,,,,,,, -76600,2.4470315,3.055308,,,,,,,,,,,,,, -76700,2.3792853,3.0324593,,,,,,,,,,,,,, -76800,2.1271577,3.0853267,,,,,,,,,,,,,, -76900,2.30752,3.0297518,,,,,,,,,,,,,, -77000,2.12522,3.0446355,,,,,,,,,,,,,, -77077,,,0.7555803656578064,1.252349853515625,0.6780799627304077,1.591875433921814,50000.0,0.5550000071525574,2.229222297668457,10000.0,26045.897315502167,26972.284712553024,26045.897315502167,921.9570591449738,1.717276096343994,0.0 -77100,2.1581764,3.0802202,,,,,,,,,,,,,, -77200,2.1061497,2.9982789,,,,,,,,,,,,,, -77300,2.2590108,3.0906389,,,,,,,,,,,,,, -77400,2.0698369,3.0400653,,,,,,,,,,,,,, -77500,2.3876731,3.018826,,,,,,,,,,,,,, -77600,2.3684337,3.069437,,,,,,,,,,,,,, -77700,2.289516,3.0404363,,,,,,,,,,,,,, -77800,2.1328952,2.9876676,,,,,,,,,,,,,, -77900,2.2740655,3.0410683,,,,,,,,,,,,,, -78000,2.3173947,2.9697123,,,,,,,,,,,,,, -78100,2.2891045,3.0616186,,,,,,,,,,,,,, -78200,2.1857777,3.0588145,,,,,,,,,,,,,, -78300,2.3117733,3.0305262,,,,,,,,,,,,,, -78400,2.1635604,3.0811987,,,,,,,,,,,,,, -78500,2.4836512,3.0689766,,,,,,,,,,,,,, -78589,,,0.7691724896430969,1.146484375,0.6866399645805359,1.4988430738449097,50000.0,0.5601000189781189,2.1550445556640625,10000.0,26555.934617996216,27500.04880857468,26555.934617996216,939.5918412208556,1.7566168308258057,0.0 -78600,2.2413838,3.0128145,,,,,,,,,,,,,, -78700,2.1585824,3.0175123,,,,,,,,,,,,,, -78800,2.2350771,3.0390038,,,,,,,,,,,,,, -78900,2.1691637,3.0567627,,,,,,,,,,,,,, -79000,2.3552494,3.0493007,,,,,,,,,,,,,, -79100,2.1395805,3.0107713,,,,,,,,,,,,,, -79200,2.0706046,2.927735,,,,,,,,,,,,,, -79300,2.573482,3.1812344,,,,,,,,,,,,,, -79400,2.1383367,3.0478494,,,,,,,,,,,,,, -79500,2.2044873,2.965746,,,,,,,,,,,,,, -79600,2.234112,2.9625363,,,,,,,,,,,,,, -79700,2.2492588,2.9968553,,,,,,,,,,,,,, -79800,2.200992,3.058062,,,,,,,,,,,,,, -79900,2.317369,3.0609555,,,,,,,,,,,,,, -80000,2.4135199,3.0424566,,,,,,,,,,,,,, -80100,2.3285854,3.0756938,,,,,,,,,,,,,, -80103,,,0.8092314600944519,0.9729456305503844,0.6866999864578247,1.4889295101165771,50000.0,0.5586000084877014,2.127764940261841,10000.0,27066.09313583374,28029.13027572632,27066.09313583374,958.4196372032166,1.7997441291809082,0.0 -80200,2.2908614,3.010565,,,,,,,,,,,,,, -80300,2.0907028,3.0086126,,,,,,,,,,,,,, -80400,2.2159147,3.0919538,,,,,,,,,,,,,, -80500,2.3669543,3.094497,,,,,,,,,,,,,, -80600,2.490448,3.0098093,,,,,,,,,,,,,, -80700,2.446222,3.024467,,,,,,,,,,,,,, -80800,2.2965724,3.071173,,,,,,,,,,,,,, -80900,2.297657,3.0127034,,,,,,,,,,,,,, -81000,2.3590174,3.0856235,,,,,,,,,,,,,, -81100,2.1788855,3.0912771,,,,,,,,,,,,,, -81200,2.230334,3.039454,,,,,,,,,,,,,, -81300,2.2954237,3.0418942,,,,,,,,,,,,,, -81400,2.2530954,3.0813131,,,,,,,,,,,,,, -81500,2.2666698,3.0134501,,,,,,,,,,,,,, -81600,2.2312243,2.9672294,,,,,,,,,,,,,, -81614,,,0.7765266299247742,1.067360758781433,0.6843599677085876,1.4753485918045044,50000.0,0.5557000041007996,2.1231517791748047,10000.0,27576.04886603356,28556.85472536087,27576.04886603356,976.1020576953888,1.8335380554199217,0.0 -81700,2.4601088,3.103424,,,,,,,,,,,,,, -81800,2.2259214,2.9716854,,,,,,,,,,,,,, -81900,2.237922,2.9749885,,,,,,,,,,,,,, -82000,2.187703,2.9739306,,,,,,,,,,,,,, -82100,2.1972897,3.11575,,,,,,,,,,,,,, -82200,2.3228467,3.0118382,,,,,,,,,,,,,, -82300,2.3275461,3.0225897,,,,,,,,,,,,,, -82400,2.2463613,2.970265,,,,,,,,,,,,,, -82500,2.2746933,3.0182376,,,,,,,,,,,,,, -82600,2.276138,3.037312,,,,,,,,,,,,,, -82700,2.2192628,2.9860291,,,,,,,,,,,,,, -82800,2.363969,3.0895624,,,,,,,,,,,,,, -82900,2.4012685,2.9794302,,,,,,,,,,,,,, -83000,2.5967004,3.0721207,,,,,,,,,,,,,, -83100,2.3936021,3.0373237,,,,,,,,,,,,,, -83126,,,0.7771045565605164,1.1030161380767822,0.6881399750709534,1.486668586730957,50000.0,0.5627000331878662,2.139356851577759,10000.0,28085.96573448181,29084.32238149643,28085.96573448181,993.557685136795,1.876111507415772,0.0 -83200,2.2594316,2.9595783,,,,,,,,,,,,,, -83300,2.337217,3.0235243,,,,,,,,,,,,,, -83400,2.1787798,3.009937,,,,,,,,,,,,,, -83500,2.2763143,3.0358813,,,,,,,,,,,,,, -83600,2.3439095,3.0221956,,,,,,,,,,,,,, -83700,2.3150363,3.0432444,,,,,,,,,,,,,, -83800,2.2947524,3.0479112,,,,,,,,,,,,,, -83900,2.1171575,2.9838145,,,,,,,,,,,,,, -84000,2.294384,3.0426567,,,,,,,,,,,,,, -84100,2.3552794,3.0770633,,,,,,,,,,,,,, -84200,2.267122,2.9797215,,,,,,,,,,,,,, -84300,2.2825496,2.9695196,,,,,,,,,,,,,, -84400,2.579687,3.0171974,,,,,,,,,,,,,, -84500,2.2673676,2.9744215,,,,,,,,,,,,,, -84600,2.295545,2.999632,,,,,,,,,,,,,, -84638,,,0.7768853306770325,1.0869734287261963,0.693399965763092,1.4536720514297483,50000.0,0.5644000172615051,2.1074471473693848,10000.0,28596.036551475525,29612.1298763752,28596.036551475525,1011.1986262798308,1.9192132949829104,0.0 -84700,2.235235,3.039692,,,,,,,,,,,,,, -84800,2.4098408,2.9640403,,,,,,,,,,,,,, -84900,2.1804051,3.001413,,,,,,,,,,,,,, -85000,2.3870797,3.049481,,,,,,,,,,,,,, -85100,2.2533233,2.9855487,,,,,,,,,,,,,, -85200,2.266754,2.9889908,,,,,,,,,,,,,, -85300,2.248193,3.0210354,,,,,,,,,,,,,, -85400,2.2327511,2.962266,,,,,,,,,,,,,, -85500,2.5382144,3.0860856,,,,,,,,,,,,,, -85600,2.5082257,3.0679889,,,,,,,,,,,,,, -85700,2.2934,3.0254183,,,,,,,,,,,,,, -85800,2.3282387,3.0411167,,,,,,,,,,,,,, -85900,2.3426356,2.9701922,,,,,,,,,,,,,, -86000,2.21919,3.0196226,,,,,,,,,,,,,, -86100,2.4717767,3.052503,,,,,,,,,,,,,, -86151,,,0.7732979655265808,1.0890520811080933,0.6902599930763245,1.4564690589904783,50000.0,0.5592000484466553,2.1242754459381104,10000.0,29106.1050195694,30139.81373643875,29106.1050195694,1028.7158319950104,1.964526891708374,0.0 -86200,2.2112617,3.020465,,,,,,,,,,,,,, -86300,2.4932168,3.063751,,,,,,,,,,,,,, -86400,2.2077935,3.0461113,,,,,,,,,,,,,, -86500,2.4193094,3.0394466,,,,,,,,,,,,,, -86600,2.213718,2.9765568,,,,,,,,,,,,,, -86700,2.5031471,3.0107636,,,,,,,,,,,,,, -86800,2.4572215,3.0410044,,,,,,,,,,,,,, -86900,2.3681738,3.0097957,,,,,,,,,,,,,, -87000,2.2856188,2.9608216,,,,,,,,,,,,,, -87100,2.3898354,3.0455658,,,,,,,,,,,,,, -87200,2.2271707,2.974361,,,,,,,,,,,,,, -87300,2.2256963,2.938632,,,,,,,,,,,,,, -87400,2.1626809,2.9608312,,,,,,,,,,,,,, -87500,2.4697304,2.9928515,,,,,,,,,,,,,, -87600,2.5567398,3.008672,,,,,,,,,,,,,, -87664,,,0.755281388759613,1.1695573329925537,0.6784399747848511,1.5213377475738523,50000.0,0.5514000058174133,2.184067964553833,10000.0,29616.25096130371,30667.99164557457,29616.25096130371,1046.6535975933075,2.0054867267608643,0.0 -87700,2.3871765,2.993953,,,,,,,,,,,,,, -87800,2.447939,3.0782423,,,,,,,,,,,,,, -87900,2.396958,2.998371,,,,,,,,,,,,,, -88000,2.553091,3.025962,,,,,,,,,,,,,, -88100,2.3885937,2.9766963,,,,,,,,,,,,,, -88200,2.265626,3.0657709,,,,,,,,,,,,,, -88300,2.5035288,2.9939198,,,,,,,,,,,,,, -88400,2.3772058,2.9785318,,,,,,,,,,,,,, -88500,2.3060749,2.9914303,,,,,,,,,,,,,, -88600,2.4007926,2.9608524,,,,,,,,,,,,,, -88700,2.2703822,3.0324497,,,,,,,,,,,,,, -88800,2.276468,2.9957643,,,,,,,,,,,,,, -88900,2.282925,3.016981,,,,,,,,,,,,,, -89000,2.3576188,3.014258,,,,,,,,,,,,,, -89100,2.2919707,2.9207852,,,,,,,,,,,,,, -89177,,,0.7955994606018066,1.0491045713424685,0.6782199740409851,1.5385297536849976,50000.0,0.5473000407218933,2.2165260314941406,10000.0,30126.34645795822,31195.745640039444,30126.34645795822,1064.221899986267,2.043313980102539,0.0 -89200,2.4657393,2.9593287,,,,,,,,,,,,,, -89300,2.408618,2.9618866,,,,,,,,,,,,,, -89400,2.2891183,2.9596274,,,,,,,,,,,,,, -89500,2.3556695,2.9626296,,,,,,,,,,,,,, -89600,2.257324,2.977253,,,,,,,,,,,,,, -89700,2.6337538,3.0386682,,,,,,,,,,,,,, -89800,2.5014892,3.0036469,,,,,,,,,,,,,, -89900,2.3493752,2.9474564,,,,,,,,,,,,,, -90000,2.390914,2.9727933,,,,,,,,,,,,,, -90100,2.364261,3.04283,,,,,,,,,,,,,, -90200,2.4026978,2.9838927,,,,,,,,,,,,,, -90300,2.6850998,3.0390868,,,,,,,,,,,,,, -90400,2.4919183,3.0577126,,,,,,,,,,,,,, -90500,2.5001729,3.059477,,,,,,,,,,,,,, -90600,2.3633537,2.9210324,,,,,,,,,,,,,, -90690,,,0.8026148080825806,0.9901580810546876,0.7046200037002563,1.4159032106399536,50000.0,0.572700023651123,2.065574884414673,10000.0,30636.46873354912,31723.53106689453,30636.46873354912,1081.7890536785126,2.085658311843872,0.0 -90700,2.3256934,2.9498377,,,,,,,,,,,,,, -90800,2.395515,3.0392644,,,,,,,,,,,,,, -90900,2.6391056,2.9800472,,,,,,,,,,,,,, -91000,2.446682,2.9519725,,,,,,,,,,,,,, -91100,2.434639,2.9368994,,,,,,,,,,,,,, -91200,2.447587,2.9097981,,,,,,,,,,,,,, -91300,2.2865176,2.9677608,,,,,,,,,,,,,, -91400,2.3046124,2.880134,,,,,,,,,,,,,, -91500,2.3710036,3.0564575,,,,,,,,,,,,,, -91600,2.2957134,2.9808807,,,,,,,,,,,,,, -91700,2.418604,3.1010344,,,,,,,,,,,,,, -91800,2.9725006,2.9939725,,,,,,,,,,,,,, -91900,2.3273256,2.914222,,,,,,,,,,,,,, -92000,2.3476512,2.9200292,,,,,,,,,,,,,, -92100,2.43597,2.945142,,,,,,,,,,,,,, -92200,2.4771466,2.9067395,,,,,,,,,,,,,, -92203,,,0.7906369566917419,1.0359426736831665,0.6946600079536438,1.449451208114624,50000.0,0.5664000511169434,2.099205255508423,10000.0,31146.66298151016,32251.288065195084,31146.66298151016,1099.251507282257,2.132451057434082,0.0 -92300,2.4803092,2.8828783,,,,,,,,,,,,,, -92400,2.598011,2.9302075,,,,,,,,,,,,,, -92500,2.4560418,2.9501307,,,,,,,,,,,,,, -92600,2.3844733,2.93704,,,,,,,,,,,,,, -92700,2.6172304,2.9669662,,,,,,,,,,,,,, -92800,2.397493,2.9990225,,,,,,,,,,,,,, -92900,2.5563583,3.0186822,,,,,,,,,,,,,, -93000,2.4243817,2.9484153,,,,,,,,,,,,,, -93100,2.5628643,3.0167325,,,,,,,,,,,,,, -93200,2.331658,2.9196188,,,,,,,,,,,,,, -93300,2.5253131,2.967803,,,,,,,,,,,,,, -93400,2.416209,2.9597104,,,,,,,,,,,,,, -93500,2.4443066,2.9266748,,,,,,,,,,,,,, -93600,2.4771519,2.9782345,,,,,,,,,,,,,, -93700,2.7750962,3.025828,,,,,,,,,,,,,, -93715,,,0.7890027165412903,1.048642635345459,0.6960399746894836,1.4489803314208984,50000.0,0.5732000470161438,2.0790419578552246,10000.0,31656.57655262947,32778.84194803238,31656.57655262947,1116.7954907417295,2.174596071243286,0.0 -93800,2.4451995,2.919316,,,,,,,,,,,,,, -93900,2.5020604,3.0637865,,,,,,,,,,,,,, -94000,2.3640757,2.9433835,,,,,,,,,,,,,, -94100,2.562594,3.0125144,,,,,,,,,,,,,, -94200,2.5848765,2.9469464,,,,,,,,,,,,,, -94300,2.4795315,3.0104396,,,,,,,,,,,,,, -94400,2.4412,3.015262,,,,,,,,,,,,,, -94500,2.494452,2.9504752,,,,,,,,,,,,,, -94600,2.29759,2.914502,,,,,,,,,,,,,, -94700,2.4893904,2.9529219,,,,,,,,,,,,,, -94800,2.4868152,2.9175339,,,,,,,,,,,,,, -94900,2.6019897,2.9615693,,,,,,,,,,,,,, -95000,2.3147125,2.9767935,,,,,,,,,,,,,, -95100,2.628887,2.9772155,,,,,,,,,,,,,, -95200,2.4862978,2.9515328,,,,,,,,,,,,,, -95228,,,0.7940050959587097,1.0202025175094604,0.706059992313385,1.4011818170547483,50000.0,0.5764999985694885,2.050769567489624,10000.0,32166.51022219658,33306.58448624611,32166.51022219658,1134.5098896026611,2.2162539958953857,0.0 -95300,2.512719,3.0108135,,,,,,,,,,,,,, -95400,2.584843,2.876675,,,,,,,,,,,,,, -95500,2.4203308,2.9324324,,,,,,,,,,,,,, -95600,2.3451912,2.9120867,,,,,,,,,,,,,, -95700,2.4930067,2.9568558,,,,,,,,,,,,,, -95800,2.6306758,3.003621,,,,,,,,,,,,,, -95900,2.6805685,2.9594967,,,,,,,,,,,,,, -96000,2.5594828,2.941711,,,,,,,,,,,,,, -96100,2.482649,2.9832006,,,,,,,,,,,,,, -96200,2.6181781,2.9402533,,,,,,,,,,,,,, -96300,2.572227,2.9910889,,,,,,,,,,,,,, -96400,2.7825372,3.0568023,,,,,,,,,,,,,, -96500,2.2693708,2.931695,,,,,,,,,,,,,, -96600,2.5109603,2.9677014,,,,,,,,,,,,,, -96700,2.4563699,2.9158475,,,,,,,,,,,,,, -96741,,,0.7877470850944519,1.0436723232269287,0.6979999542236328,1.427424669265747,50000.0,0.5685000419616699,2.087905168533325,10000.0,32676.71812939644,33834.412241220474,32676.71812939644,1152.02063703537,2.271630048751831,0.0 -96800,2.514723,3.0135832,,,,,,,,,,,,,, -96900,2.4679515,3.0071151,,,,,,,,,,,,,, -97000,2.3896823,2.9753776,,,,,,,,,,,,,, -97100,2.4827487,2.9528036,,,,,,,,,,,,,, -97200,2.5177004,2.9168541,,,,,,,,,,,,,, -97300,2.590303,2.9762855,,,,,,,,,,,,,, -97400,2.3642535,2.9351125,,,,,,,,,,,,,, -97500,2.5344756,2.9076447,,,,,,,,,,,,,, -97600,2.7458823,2.9424179,,,,,,,,,,,,,, -97700,2.367546,2.936024,,,,,,,,,,,,,, -97800,2.6706169,3.0697758,,,,,,,,,,,,,, -97900,2.5725496,2.9612312,,,,,,,,,,,,,, -98000,2.35512,2.9432085,,,,,,,,,,,,,, -98100,2.5685654,2.9393165,,,,,,,,,,,,,, -98200,2.4395885,2.8906436,,,,,,,,,,,,,, -98254,,,0.8298588991165161,0.8838488459587097,0.7114599943161011,1.3823779821395874,50000.0,0.5869000554084778,2.0103774070739746,10000.0,33186.897922992706,34362.17217421532,33186.897922992706,1169.504544019699,2.3140299320220947,0.0 -98300,2.461271,2.9538546,,,,,,,,,,,,,, -98400,2.5159123,2.9107764,,,,,,,,,,,,,, -98500,2.5502374,2.9348202,,,,,,,,,,,,,, -98600,2.5604422,2.9166954,,,,,,,,,,,,,, -98700,2.491122,2.8845558,,,,,,,,,,,,,, -98800,2.4429932,2.885202,,,,,,,,,,,,,, -98900,2.55395,2.9363515,,,,,,,,,,,,,, -99000,2.5055974,2.9857745,,,,,,,,,,,,,, -99100,2.4204483,2.9206543,,,,,,,,,,,,,, -99200,2.6449053,2.9607503,,,,,,,,,,,,,, -99300,2.660181,2.9443088,,,,,,,,,,,,,, -99400,2.5221047,2.9194906,,,,,,,,,,,,,, -99500,2.8897662,2.9291852,,,,,,,,,,,,,, -99600,2.5041263,2.9242887,,,,,,,,,,,,,, -99700,2.3731072,2.8773322,,,,,,,,,,,,,, -99767,,,0.8191565275192261,0.9219579696655272,0.7096199989318848,1.3757820129394531,50000.0,0.5852000117301941,2.0205318927764893,10000.0,33696.980467796326,34889.91017913818,33696.980467796326,1187.0604236125946,2.3605294227600098,0.0 -99800,2.4735541,2.9406881,,,,,,,,,,,,,, -99900,2.394798,2.9533315,,,,,,,,,,,,,, -100000,2.679059,2.9936085,,,,,,,,,,,,,, -100100,2.5895066,2.9573238,,,,,,,,,,,,,, -100200,2.680236,2.8549407,,,,,,,,,,,,,, -100300,2.6288095,2.9881237,,,,,,,,,,,,,, -100400,2.5594504,2.9687793,,,,,,,,,,,,,, -100500,2.6714435,2.974659,,,,,,,,,,,,,, -100600,2.715471,2.9453092,,,,,,,,,,,,,, -100700,2.7179644,2.9671476,,,,,,,,,,,,,, -100800,2.7202544,2.8966475,,,,,,,,,,,,,, -100900,2.454694,2.8879585,,,,,,,,,,,,,, -101000,2.4358044,2.9458573,,,,,,,,,,,,,, -101100,2.711137,2.9530783,,,,,,,,,,,,,, -101200,2.8594759,2.9975286,,,,,,,,,,,,,, -101280,,,0.8095503449440002,0.9634189605712892,0.7093799710273743,1.3969210386276243,50000.0,0.5839000344276428,2.0435094833374023,10000.0,34206.91540932655,35417.475497722626,34206.91540932655,1204.593267440796,2.404573440551758,0.0 -101300,2.5285485,2.9617178,,,,,,,,,,,,,, -101400,2.6452868,2.9539447,,,,,,,,,,,,,, -101500,2.9229443,2.88143,,,,,,,,,,,,,, -101600,2.4132643,2.9296274,,,,,,,,,,,,,, -101700,2.6137567,2.9376612,,,,,,,,,,,,,, -101800,2.7647486,2.909402,,,,,,,,,,,,,, -101900,2.641303,2.853342,,,,,,,,,,,,,, -102000,2.6929286,2.946941,,,,,,,,,,,,,, -102100,2.6455941,2.9517932,,,,,,,,,,,,,, -102200,2.7258935,2.908512,,,,,,,,,,,,,, -102300,2.5728114,2.9412868,,,,,,,,,,,,,, -102400,2.719334,2.9635148,,,,,,,,,,,,,, -102500,2.5198154,2.9091225,,,,,,,,,,,,,, -102600,2.6002343,2.8889062,,,,,,,,,,,,,, -102700,2.577918,2.928433,,,,,,,,,,,,,, -102793,,,0.8084542155265808,0.9739400744438172,0.7127400040626526,1.390412449836731,50000.0,0.5825000405311584,2.034860849380493,10000.0,34717.08581137657,35945.378918647766,34717.08581137657,1222.228625535965,2.44966459274292,0.0 -102800,2.641222,2.9023206,,,,,,,,,,,,,, -102900,2.5559485,2.8468943,,,,,,,,,,,,,, -103000,2.6844206,2.898819,,,,,,,,,,,,,, -103100,2.5591831,2.8997536,,,,,,,,,,,,,, -103200,2.4438646,2.937222,,,,,,,,,,,,,, -103300,2.584804,2.910505,,,,,,,,,,,,,, -103400,2.5682845,2.8884027,,,,,,,,,,,,,, -103500,2.5554833,2.8978074,,,,,,,,,,,,,, -103600,2.603716,2.8926826,,,,,,,,,,,,,, -103700,2.8721774,2.9194133,,,,,,,,,,,,,, -103800,2.6990254,2.9722433,,,,,,,,,,,,,, -103900,2.4426072,2.8540304,,,,,,,,,,,,,, -104000,2.5686154,2.9123118,,,,,,,,,,,,,, -104100,2.69202,2.8945904,,,,,,,,,,,,,, -104200,2.8104205,2.8802865,,,,,,,,,,,,,, -104300,2.699495,2.903702,,,,,,,,,,,,,, -104307,,,0.7966358065605164,0.990210771560669,0.7057600021362305,1.3867374658584597,50000.0,0.5819000005722046,2.023756504058838,10000.0,35227.30777025223,36473.11824512482,35227.30777025223,1239.6494569778442,2.493047952651977,0.0 -104400,2.6946638,2.882976,,,,,,,,,,,,,, -104500,2.8272102,2.9167619,,,,,,,,,,,,,, -104600,2.5928707,2.908515,,,,,,,,,,,,,, -104700,2.7013772,2.8809037,,,,,,,,,,,,,, -104800,2.7018375,2.9305046,,,,,,,,,,,,,, -104900,2.8053367,2.95168,,,,,,,,,,,,,, -105000,2.895306,2.937759,,,,,,,,,,,,,, -105100,2.7904017,2.9276123,,,,,,,,,,,,,, -105200,2.648252,2.947569,,,,,,,,,,,,,, -105300,2.957799,3.0102775,,,,,,,,,,,,,, -105400,2.6035998,2.9455447,,,,,,,,,,,,,, -105500,2.6380117,2.9309866,,,,,,,,,,,,,, -105600,2.6676834,2.8883407,,,,,,,,,,,,,, -105700,2.6796148,2.9269743,,,,,,,,,,,,,, -105800,2.8685238,2.8960094,,,,,,,,,,,,,, -105819,,,0.809968888759613,0.9250097274780272,0.7107399702072144,1.3485785722732544,50000.0,0.5885000228881836,2.001072645187378,10000.0,35737.31253170967,37000.57236742973,35737.31253170967,1257.0016777515411,2.5371103286743164,0.0 -105900,2.828626,2.8577998,,,,,,,,,,,,,, -106000,2.686922,2.935065,,,,,,,,,,,,,, -106100,2.6770356,2.8574805,,,,,,,,,,,,,, -106200,2.6757898,2.8610973,,,,,,,,,,,,,, -106300,2.6771064,2.869503,,,,,,,,,,,,,, -106400,2.5419388,2.8959663,,,,,,,,,,,,,, -106500,2.8051028,2.9207115,,,,,,,,,,,,,, -106600,2.70486,2.8745353,,,,,,,,,,,,,, -106700,2.7732313,2.8787074,,,,,,,,,,,,,, -106800,2.7411356,2.905864,,,,,,,,,,,,,, -106900,2.5759776,2.8736272,,,,,,,,,,,,,, -107000,2.6735058,2.8229308,,,,,,,,,,,,,, -107100,2.8327737,2.9389029,,,,,,,,,,,,,, -107200,2.6662314,2.9754996,,,,,,,,,,,,,, -107300,2.863577,2.9204535,,,,,,,,,,,,,, -107331,,,0.8362364172935486,0.838945746421814,0.7157599925994873,1.3426353931427002,50000.0,0.5887000560760498,1.985122561454773,10000.0,36247.26858711243,37528.30771255493,36247.26858711243,1274.6831283569336,2.58165979385376,0.0 -107400,2.6245759,2.8845296,,,,,,,,,,,,,, -107500,2.5488164,2.8942707,,,,,,,,,,,,,, -107600,2.7418802,2.867895,,,,,,,,,,,,,, -107700,2.6177197,2.9237056,,,,,,,,,,,,,, -107800,2.8214333,2.8973782,,,,,,,,,,,,,, -107900,3.2719574,2.9296122,,,,,,,,,,,,,, -108000,2.6892097,2.9326832,,,,,,,,,,,,,, -108100,2.763088,2.9281833,,,,,,,,,,,,,, -108200,2.6135793,2.8905268,,,,,,,,,,,,,, -108300,2.757172,2.949157,,,,,,,,,,,,,, -108400,2.8264613,2.9457152,,,,,,,,,,,,,, -108500,2.5464377,2.8052986,,,,,,,,,,,,,, -108600,2.8394377,2.9706423,,,,,,,,,,,,,, -108700,2.7144225,2.9107256,,,,,,,,,,,,,, -108800,2.7569194,2.9548247,,,,,,,,,,,,,, -108844,,,0.8233019709587097,0.8994101285934448,0.7140600085258484,1.3612909317016602,50000.0,0.5867000222206116,2.025434732437134,10000.0,36757.356506347656,38056.19006371498,36757.356506347656,1292.3768393993378,2.629001379013061,0.0 -108900,2.5308268,2.8467553,,,,,,,,,,,,,, -109000,2.8443234,2.9273615,,,,,,,,,,,,,, -109100,2.925969,2.889956,,,,,,,,,,,,,, -109200,2.8616447,2.8258843,,,,,,,,,,,,,, -109300,2.9549284,2.9239225,,,,,,,,,,,,,, -109400,2.7914524,2.882346,,,,,,,,,,,,,, -109500,3.2315607,2.9205337,,,,,,,,,,,,,, -109600,2.7746623,2.8848813,,,,,,,,,,,,,, -109700,2.7561986,2.8532982,,,,,,,,,,,,,, -109800,2.8110125,2.8472,,,,,,,,,,,,,, -109900,2.9484107,2.8979044,,,,,,,,,,,,,, -110000,2.7638354,2.8348882,,,,,,,,,,,,,, -110100,2.9494596,2.8439584,,,,,,,,,,,,,, -110200,2.7439451,2.8962195,,,,,,,,,,,,,, -110300,3.0242763,2.8588905,,,,,,,,,,,,,, -110357,,,0.8215481042861938,0.8967534899711609,0.7145199775695801,1.3505741357803345,50000.0,0.5910000205039978,1.9762768745422363,10000.0,37267.36195087433,38583.91264152527,37267.36195087433,1309.9977324008942,2.672827243804932,0.0 -110400,2.9405546,2.9739063,,,,,,,,,,,,,, -110500,2.8982847,2.8804703,,,,,,,,,,,,,, -110600,2.6676013,2.874075,,,,,,,,,,,,,, -110700,2.6786506,2.8841102,,,,,,,,,,,,,, -110800,2.986883,2.9126825,,,,,,,,,,,,,, -110900,2.847072,2.9163089,,,,,,,,,,,,,, -111000,2.7426517,2.8875225,,,,,,,,,,,,,, -111100,2.8107412,2.9379869,,,,,,,,,,,,,, -111200,2.6425014,2.7811263,,,,,,,,,,,,,, -111300,2.7674744,2.8748305,,,,,,,,,,,,,, -111400,2.8682704,2.8415022,,,,,,,,,,,,,, -111500,2.8498783,2.895948,,,,,,,,,,,,,, -111600,2.7958176,2.8590255,,,,,,,,,,,,,, -111700,2.9068348,2.8846314,,,,,,,,,,,,,, -111800,2.8907144,2.8838344,,,,,,,,,,,,,, -111870,,,0.8194156289100647,0.8783233761787415,0.7137799859046936,1.3280168771743774,50000.0,0.5868000388145447,1.9892569780349727,10000.0,37777.49592471123,39111.43874955177,37777.49592471123,1327.280189990997,2.72864294052124,0.0 -111900,2.8730416,2.8711843,,,,,,,,,,,,,, -112000,2.8247528,2.8062234,,,,,,,,,,,,,, -112100,2.9165068,2.8633952,,,,,,,,,,,,,, -112200,2.8351948,2.8762584,,,,,,,,,,,,,, -112300,2.9071877,2.9341211,,,,,,,,,,,,,, -112400,2.9187543,2.9325056,,,,,,,,,,,,,, -112500,3.0230727,2.8633244,,,,,,,,,,,,,, -112600,2.548533,2.8467646,,,,,,,,,,,,,, -112700,2.9145083,2.8413227,,,,,,,,,,,,,, -112800,2.8816884,2.8348403,,,,,,,,,,,,,, -112900,2.9358056,2.9256072,,,,,,,,,,,,,, -113000,2.9036503,2.8448951,,,,,,,,,,,,,, -113100,2.8514261,2.8095632,,,,,,,,,,,,,, -113200,2.9320881,2.9002957,,,,,,,,,,,,,, -113300,2.7644105,2.8274896,,,,,,,,,,,,,, -113384,,,0.8175222873687744,0.9152930974960328,0.7148399949073792,1.350134015083313,50000.0,0.5896000266075134,1.9782979488372805,10000.0,38287.695055007935,39639.17041897774,38287.695055007935,1344.7098760604858,2.7765440940856934,0.0 -113400,2.7423518,2.8284128,,,,,,,,,,,,,, -113500,2.8969853,2.8990464,,,,,,,,,,,,,, -113600,2.9621031,2.8655024,,,,,,,,,,,,,, -113700,2.8073637,2.8534274,,,,,,,,,,,,,, -113800,2.9143987,2.8970866,,,,,,,,,,,,,, -113900,2.9239721,2.8868232,,,,,,,,,,,,,, -114000,2.8630147,2.8897564,,,,,,,,,,,,,, -114100,2.6548216,2.828353,,,,,,,,,,,,,, -114200,2.926552,2.829507,,,,,,,,,,,,,, -114300,2.981313,2.8117442,,,,,,,,,,,,,, -114400,2.8239784,2.8498683,,,,,,,,,,,,,, -114500,2.666999,2.8543622,,,,,,,,,,,,,, -114600,2.9031262,2.8783474,,,,,,,,,,,,,, -114700,2.670768,2.8496304,,,,,,,,,,,,,, -114800,2.7900577,2.851726,,,,,,,,,,,,,, -114897,,,0.8275271058082581,0.9153132438659668,0.7184000015258789,1.3692786693572998,50000.0,0.5919000506401062,2.00195837020874,10000.0,38797.88224673271,40167.13142871857,38797.88224673271,1362.3821530342102,2.825171709060669,0.0 -114900,2.9005618,2.8850574,,,,,,,,,,,,,, -115000,2.961927,2.8313448,,,,,,,,,,,,,, -115100,2.866574,2.8440118,,,,,,,,,,,,,, -115200,3.0368905,2.828255,,,,,,,,,,,,,, -115300,2.8228843,2.8677523,,,,,,,,,,,,,, -115400,2.834771,2.8342953,,,,,,,,,,,,,, -115500,2.9634569,2.901478,,,,,,,,,,,,,, -115600,2.9522588,2.8152092,,,,,,,,,,,,,, -115700,3.0601583,2.9120164,,,,,,,,,,,,,, -115800,3.118657,2.8075483,,,,,,,,,,,,,, -115900,2.8056335,2.7489321,,,,,,,,,,,,,, -116000,2.695562,2.807877,,,,,,,,,,,,,, -116100,2.839259,2.866005,,,,,,,,,,,,,, -116200,2.9668548,2.8920877,,,,,,,,,,,,,, -116300,2.9238815,2.8371303,,,,,,,,,,,,,, -116400,3.0422237,2.8329587,,,,,,,,,,,,,, -116409,,,0.848074734210968,0.8165138363838196,0.7175999879837036,1.3547149896621704,50000.0,0.588200032711029,2.00327205657959,10000.0,39307.83862376213,40694.7973818779,39307.83862376213,1379.988827228546,2.8746609687805176,0.0 -116500,2.7535074,2.7927418,,,,,,,,,,,,,, -116600,2.8389606,2.8092525,,,,,,,,,,,,,, -116700,2.914766,2.8529847,,,,,,,,,,,,,, -116800,3.110576,2.8129976,,,,,,,,,,,,,, -116900,3.023826,2.805797,,,,,,,,,,,,,, -117000,2.980211,2.8531678,,,,,,,,,,,,,, -117100,2.7783766,2.8145025,,,,,,,,,,,,,, -117200,2.9131525,2.8747988,,,,,,,,,,,,,, -117300,3.14356,2.8664987,,,,,,,,,,,,,, -117400,2.7467246,2.7907846,,,,,,,,,,,,,, -117500,2.975486,2.8326511,,,,,,,,,,,,,, -117600,2.895801,2.884047,,,,,,,,,,,,,, -117700,3.0133905,2.8525028,,,,,,,,,,,,,, -117800,3.1020691,2.8740635,,,,,,,,,,,,,, -117900,2.971513,2.8040066,,,,,,,,,,,,,, -117922,,,0.8465999364852905,0.8292368650436401,0.7242199778556824,1.3349623680114746,50000.0,0.5986000299453735,1.9471795558929443,10000.0,39818.07694029808,41223.39574432373,39818.07694029808,1398.249297618866,2.9214396476745605,0.0 -118000,2.9321985,2.8298917,,,,,,,,,,,,,, -118100,2.8033845,2.840984,,,,,,,,,,,,,, -118200,2.9327357,2.8504336,,,,,,,,,,,,,, -118300,3.1078038,2.8794837,,,,,,,,,,,,,, -118400,2.8803566,2.7661443,,,,,,,,,,,,,, -118500,2.9673645,2.8416429,,,,,,,,,,,,,, -118600,3.046098,2.823453,,,,,,,,,,,,,, -118700,2.9418938,2.8544931,,,,,,,,,,,,,, -118800,3.1766162,2.8271513,,,,,,,,,,,,,, -118900,2.9869971,2.8030436,,,,,,,,,,,,,, -119000,3.0148153,2.8271346,,,,,,,,,,,,,, -119100,2.8923397,2.8420095,,,,,,,,,,,,,, -119200,2.9558005,2.8593216,,,,,,,,,,,,,, -119300,2.9077082,2.8263378,,,,,,,,,,,,,, -119400,2.9261408,2.8009727,,,,,,,,,,,,,, -119436,,,0.8353993892669678,0.8475539088249207,0.724399983882904,1.3222719430923462,50000.0,0.5958999991416931,1.961624264717102,10000.0,40328.25413155556,41751.07174015045,40328.25413155556,1415.6428937911987,2.97339940071106,0.0 -119500,2.8439956,2.8442187,,,,,,,,,,,,,, -119600,3.2785802,2.9175415,,,,,,,,,,,,,, -119700,3.0225332,2.831421,,,,,,,,,,,,,, -119800,2.9976263,2.8368325,,,,,,,,,,,,,, -119900,2.9406419,2.7948337,,,,,,,,,,,,,, -120000,2.957569,2.8455372,,,,,,,,,,,,,, -120100,2.9324794,2.8217983,,,,,,,,,,,,,, -120200,2.953662,2.8453555,,,,,,,,,,,,,, -120300,3.040889,2.8732283,,,,,,,,,,,,,, -120400,2.738418,2.7596092,,,,,,,,,,,,,, -120500,2.964214,2.7568512,,,,,,,,,,,,,, -120600,2.9841216,2.909328,,,,,,,,,,,,,, -120700,3.2441406,2.927564,,,,,,,,,,,,,, -120800,2.9790182,2.8194866,,,,,,,,,,,,,, -120900,3.20001,2.8300097,,,,,,,,,,,,,, -120949,,,0.8393654227256775,0.8338555097579956,0.7299799919128418,1.3040684461593628,50000.0,0.6011000275611877,1.9599751234054563,10000.0,40838.44957041741,42279.31935048103,40838.44957041741,1433.5952577590942,3.020388603210449,0.0 -121000,3.0168364,2.8302975,,,,,,,,,,,,,, -121100,3.0229752,2.7499552,,,,,,,,,,,,,, -121200,3.0938723,2.8436673,,,,,,,,,,,,,, -121300,3.233745,2.779789,,,,,,,,,,,,,, -121400,3.1744444,2.8072662,,,,,,,,,,,,,, -121500,2.995171,2.8456798,,,,,,,,,,,,,, -121600,3.163978,2.850093,,,,,,,,,,,,,, -121700,3.1706712,2.8560557,,,,,,,,,,,,,, -121800,3.1102722,2.7845645,,,,,,,,,,,,,, -121900,3.3432546,2.9056175,,,,,,,,,,,,,, -122000,2.9335182,2.7792578,,,,,,,,,,,,,, -122100,3.1238885,2.877059,,,,,,,,,,,,,, -122200,2.9793096,2.7655392,,,,,,,,,,,,,, -122300,2.8930168,2.8138082,,,,,,,,,,,,,, -122400,2.9887614,2.8101525,,,,,,,,,,,,,, -122463,,,0.8434908986091614,0.8159423470497131,0.730459988117218,1.2901736497879028,50000.0,0.6029000282287598,1.931108474731445,10000.0,41348.56471085549,42807.05680012703,41348.56471085549,1451.1136507987976,3.0704591274261475,0.0 -122500,3.057563,2.8402057,,,,,,,,,,,,,, -122600,3.1072068,2.8428895,,,,,,,,,,,,,, -122700,2.98116,2.7234457,,,,,,,,,,,,,, -122800,3.2296772,2.813614,,,,,,,,,,,,,, -122900,3.0306957,2.7296987,,,,,,,,,,,,,, -123000,3.165777,2.8150992,,,,,,,,,,,,,, -123100,3.1275287,2.8069158,,,,,,,,,,,,,, -123200,3.1629715,2.859608,,,,,,,,,,,,,, -123300,3.2210715,2.7703798,,,,,,,,,,,,,, -123400,3.138346,2.8350315,,,,,,,,,,,,,, -123500,3.1600664,2.8499703,,,,,,,,,,,,,, -123600,3.0180326,2.7944171,,,,,,,,,,,,,, -123700,2.944943,2.764831,,,,,,,,,,,,,, -123800,3.1842008,2.7918503,,,,,,,,,,,,,, -123900,3.1116867,2.7953258,,,,,,,,,,,,,, -123976,,,0.843191921710968,0.7751460671424866,0.7268799543380737,1.2697910070419312,50000.0,0.604200005531311,1.907442688941956,10000.0,41858.57477927208,43334.73067235947,41858.57477927208,1468.6732609272003,3.1229352951049805,0.0 -124000,3.0750055,2.7631407,,,,,,,,,,,,,, -124100,3.0268586,2.7739084,,,,,,,,,,,,,, -124200,2.9401133,2.803552,,,,,,,,,,,,,, -124300,3.283021,2.8818207,,,,,,,,,,,,,, -124400,3.0655012,2.7680273,,,,,,,,,,,,,, -124500,3.2409034,2.8140278,,,,,,,,,,,,,, -124600,3.124017,2.7807136,,,,,,,,,,,,,, -124700,3.3461053,2.8217008,,,,,,,,,,,,,, -124800,3.0234127,2.81128,,,,,,,,,,,,,, -124900,3.0107195,2.7946627,,,,,,,,,,,,,, -125000,2.9625344,2.7658594,,,,,,,,,,,,,, -125100,3.2266564,2.8516514,,,,,,,,,,,,,, -125200,3.3091743,2.8521338,,,,,,,,,,,,,, -125300,3.3114119,2.798009,,,,,,,,,,,,,, -125400,3.1027286,2.7957337,,,,,,,,,,,,,, -125489,,,0.84574294090271,0.8140817880630493,0.7204599976539612,1.347158670425415,50000.0,0.5929000377655029,2.000288963317871,10000.0,42368.69026613236,43862.41753697395,42368.69026613236,1486.1433503627777,3.1712000370025635,0.0 -125500,3.2146578,2.8630695,,,,,,,,,,,,,, -125600,3.004526,2.7904994,,,,,,,,,,,,,, -125700,3.1375368,2.792938,,,,,,,,,,,,,, -125800,3.1611526,2.735196,,,,,,,,,,,,,, -125900,3.2765183,2.7928987,,,,,,,,,,,,,, -126000,3.2921827,2.8597999,,,,,,,,,,,,,, -126100,2.92724,2.7264593,,,,,,,,,,,,,, -126200,3.2142003,2.751443,,,,,,,,,,,,,, -126300,2.9742396,2.7516003,,,,,,,,,,,,,, -126400,3.1445687,2.7736657,,,,,,,,,,,,,, -126500,3.4885824,2.772807,,,,,,,,,,,,,, -126600,2.9334288,2.7021918,,,,,,,,,,,,,, -126700,3.15867,2.7369952,,,,,,,,,,,,,, -126800,3.1190224,2.7454197,,,,,,,,,,,,,, -126900,3.1865115,2.7559507,,,,,,,,,,,,,, -127000,3.258861,2.8004313,,,,,,,,,,,,,, -127001,,,0.851980984210968,0.7569774389266968,0.7263799905776978,1.2846840620040894,50000.0,0.6015000343322754,1.922147512435913,10000.0,42878.68905615807,44389.9714012146,42878.68905615807,1503.5928556919098,3.2235565185546875,0.0 -127100,3.0786061,2.7724519,,,,,,,,,,,,,, -127200,3.317691,2.767415,,,,,,,,,,,,,, -127300,3.3306441,2.8906763,,,,,,,,,,,,,, -127400,3.1425853,2.7854192,,,,,,,,,,,,,, -127500,3.0072608,2.7886093,,,,,,,,,,,,,, -127600,2.9913397,2.7697685,,,,,,,,,,,,,, -127700,3.1515832,2.7809837,,,,,,,,,,,,,, -127800,3.4223344,2.820181,,,,,,,,,,,,,, -127900,3.1250255,2.8219578,,,,,,,,,,,,,, -128000,3.1774006,2.843521,,,,,,,,,,,,,, -128100,3.2648764,2.8149004,,,,,,,,,,,,,, -128200,3.3478112,2.8125021,,,,,,,,,,,,,, -128300,3.3185258,2.8171403,,,,,,,,,,,,,, -128400,3.0156193,2.7752082,,,,,,,,,,,,,, -128500,3.3144424,2.7741652,,,,,,,,,,,,,, -128514,,,0.8525390625,0.7666031718254089,0.7315399646759033,1.2773808240890503,50000.0,0.6098000407218933,1.9216325283050537,10000.0,43388.70094943047,44917.63706827164,43388.70094943047,1521.1470046043396,3.2711191177368164,0.0 -128600,3.1963487,2.7295115,,,,,,,,,,,,,, -128700,3.141576,2.676773,,,,,,,,,,,,,, -128800,3.277015,2.79268,,,,,,,,,,,,,, -128900,3.2070265,2.7765563,,,,,,,,,,,,,, -129000,3.1583934,2.7388499,,,,,,,,,,,,,, -129100,3.37614,2.8231115,,,,,,,,,,,,,, -129200,3.6065726,2.780591,,,,,,,,,,,,,, -129300,3.31906,2.8065019,,,,,,,,,,,,,, -129400,2.932639,2.692148,,,,,,,,,,,,,, -129500,3.2383802,2.75399,,,,,,,,,,,,,, -129600,3.1918244,2.701943,,,,,,,,,,,,,, -129700,3.4772909,2.772026,,,,,,,,,,,,,, -129800,3.1904693,2.7337332,,,,,,,,,,,,,, -129900,3.1717918,2.7908537,,,,,,,,,,,,,, -130000,3.1424968,2.8291237,,,,,,,,,,,,,, -130027,,,0.8579002022743225,0.7536612153053284,0.7330600023269653,1.2795575857162476,50000.0,0.6058000326156616,1.929798483848572,10000.0,43898.8369243145,45445.29455208778,43898.8369243145,1538.5677382946014,3.3185129165649414,0.0 -130100,3.2115006,2.7277598,,,,,,,,,,,,,, -130200,3.085947,2.7045665,,,,,,,,,,,,,, -130300,3.1391127,2.716,,,,,,,,,,,,,, -130400,3.203135,2.7257555,,,,,,,,,,,,,, -130500,3.215599,2.6789973,,,,,,,,,,,,,, -130600,3.1908941,2.6893992,,,,,,,,,,,,,, -130700,3.294419,2.776194,,,,,,,,,,,,,, -130800,3.1057286,2.7205722,,,,,,,,,,,,,, -130900,3.2007544,2.7278817,,,,,,,,,,,,,, -131000,3.238165,2.722628,,,,,,,,,,,,,, -131100,3.2311742,2.7053015,,,,,,,,,,,,,, -131200,3.2982733,2.7305923,,,,,,,,,,,,,, -131300,3.471452,2.722415,,,,,,,,,,,,,, -131400,3.2359886,2.7317011,,,,,,,,,,,,,, -131500,3.1752799,2.7967606,,,,,,,,,,,,,, -131540,,,0.8606704473495483,0.7800890803337097,0.7356199622154236,1.2970260381698608,50000.0,0.6103000044822693,1.926138162612915,10000.0,44409.057758808136,45973.265880823135,44409.057758808136,1556.2172808647156,3.3665919303894043,0.0 -131600,3.1343327,2.7357368,,,,,,,,,,,,,, -131700,3.214714,2.76414,,,,,,,,,,,,,, -131800,3.3617895,2.723517,,,,,,,,,,,,,, -131900,3.335571,2.7278695,,,,,,,,,,,,,, -132000,3.0405343,2.6995335,,,,,,,,,,,,,, -132100,3.3623114,2.7629204,,,,,,,,,,,,,, -132200,3.3143113,2.7770076,,,,,,,,,,,,,, -132300,3.3016305,2.737622,,,,,,,,,,,,,, -132400,3.5343099,2.7520432,,,,,,,,,,,,,, -132500,3.4728253,2.7288017,,,,,,,,,,,,,, -132600,3.1142802,2.7349765,,,,,,,,,,,,,, -132700,3.344412,2.7424624,,,,,,,,,,,,,, -132800,3.0379682,2.6952531,,,,,,,,,,,,,, -132900,3.303815,2.7641456,,,,,,,,,,,,,, -133000,3.328776,2.8318157,,,,,,,,,,,,,, -133052,,,0.8704559803009033,0.7145001888275146,0.737060010433197,1.258089542388916,50000.0,0.615600049495697,1.887601613998413,10000.0,44919.02611017227,46500.74430775642,44919.02611017227,1573.6209378242493,3.42007064819336,0.0 -133100,3.2118886,2.6852264,,,,,,,,,,,,,, -133200,3.4411278,2.8105838,,,,,,,,,,,,,, -133300,3.2563548,2.6971185,,,,,,,,,,,,,, -133400,3.60235,2.798006,,,,,,,,,,,,,, -133500,3.3168352,2.7204177,,,,,,,,,,,,,, -133600,3.2780523,2.732108,,,,,,,,,,,,,, -133700,3.3817282,2.7451143,,,,,,,,,,,,,, -133800,3.3971992,2.8108125,,,,,,,,,,,,,, -133900,3.3571172,2.7234924,,,,,,,,,,,,,, -134000,3.3412113,2.7294636,,,,,,,,,,,,,, -134100,3.339586,2.7458875,,,,,,,,,,,,,, -134200,3.2085462,2.696508,,,,,,,,,,,,,, -134300,3.2727945,2.7669759,,,,,,,,,,,,,, -134400,3.285421,2.6876554,,,,,,,,,,,,,, -134500,3.4600182,2.7727358,,,,,,,,,,,,,, -134565,,,0.8752790093421936,0.6944555640220642,0.7367199659347534,1.2754249572753906,50000.0,0.6135000586509705,1.904842615127564,10000.0,45429.102596998215,47028.34170055389,45429.102596998215,1591.034093618393,3.4738011360168457,0.0 -134600,3.3766258,2.7057552,,,,,,,,,,,,,, -134700,3.2531824,2.7658956,,,,,,,,,,,,,, -134800,3.2894368,2.7133005,,,,,,,,,,,,,, -134900,3.2656784,2.694065,,,,,,,,,,,,,, -135000,3.812064,2.7387671,,,,,,,,,,,,,, -135100,3.2661815,2.699904,,,,,,,,,,,,,, -135200,3.2889407,2.7895904,,,,,,,,,,,,,, -135300,3.3344023,2.7612157,,,,,,,,,,,,,, -135400,3.6070046,2.7933545,,,,,,,,,,,,,, -135500,3.4459374,2.785431,,,,,,,,,,,,,, -135600,3.4796824,2.7213001,,,,,,,,,,,,,, -135700,3.2781937,2.6985297,,,,,,,,,,,,,, -135800,3.2845035,2.7222574,,,,,,,,,,,,,, -135900,3.3806877,2.7248163,,,,,,,,,,,,,, -136000,3.344403,2.753193,,,,,,,,,,,,,, -136078,,,0.8797034025192261,0.667981743812561,0.7418999671936035,1.2350200414657593,50000.0,0.6200000047683716,1.8570964336395264,10000.0,45939.30143976212,47556.14796924591,45939.30143976212,1608.5354924201963,3.526309967041016,0.0 -136100,3.5096245,2.7538953,,,,,,,,,,,,,, -136200,3.4165852,2.7091844,,,,,,,,,,,,,, -136300,3.5487382,2.685255,,,,,,,,,,,,,, -136400,3.3317254,2.7255669,,,,,,,,,,,,,, -136500,3.6488264,2.720881,,,,,,,,,,,,,, -136600,3.5228734,2.6973593,,,,,,,,,,,,,, -136700,3.4794176,2.714161,,,,,,,,,,,,,, -136800,3.635883,2.7563975,,,,,,,,,,,,,, -136900,3.4409506,2.7014797,,,,,,,,,,,,,, -137000,3.181462,2.6312816,,,,,,,,,,,,,, -137100,3.5618336,2.7555766,,,,,,,,,,,,,, -137200,3.4623451,2.6885962,,,,,,,,,,,,,, -137300,3.3504179,2.711983,,,,,,,,,,,,,, -137400,3.1408348,2.6980221,,,,,,,,,,,,,, -137500,3.4064746,2.7617178,,,,,,,,,,,,,, -137591,,,0.8743423223495483,0.6792070865631104,0.7400799989700317,1.234809637069702,50000.0,0.6186000108718872,1.8752394914627075,10000.0,46449.3056447506,48083.96310710907,46449.3056447506,1626.244454622269,3.5757358074188232,0.0 -137600,3.3843744,2.6989346,,,,,,,,,,,,,, -137700,3.3708825,2.765168,,,,,,,,,,,,,, -137800,3.2991624,2.6912503,,,,,,,,,,,,,, -137900,3.651223,2.7220278,,,,,,,,,,,,,, -138000,3.6989398,2.7367375,,,,,,,,,,,,,, -138100,3.446006,2.6776078,,,,,,,,,,,,,, -138200,3.26486,2.70249,,,,,,,,,,,,,, -138300,3.4332225,2.7075377,,,,,,,,,,,,,, -138400,3.7552607,2.7492676,,,,,,,,,,,,,, -138500,3.2750795,2.660677,,,,,,,,,,,,,, -138600,3.5668342,2.723834,,,,,,,,,,,,,, -138700,3.5380478,2.75364,,,,,,,,,,,,,, -138800,3.4903824,2.6771357,,,,,,,,,,,,,, -138900,3.2784061,2.6925633,,,,,,,,,,,,,, -139000,3.488468,2.732158,,,,,,,,,,,,,, -139100,3.2754116,2.6946626,,,,,,,,,,,,,, -139104,,,0.880301296710968,0.6912388801574707,0.7443999648094177,1.2525049448013306,50000.0,0.6166000366210938,1.8933311700820925,10000.0,46959.34210753441,48611.97661066055,46959.34210753441,1644.1155638694763,3.6292083263397217,0.0 -139200,3.3382916,2.714276,,,,,,,,,,,,,, -139300,3.5153067,2.758179,,,,,,,,,,,,,, -139400,3.2655222,2.6411,,,,,,,,,,,,,, -139500,3.646165,2.729939,,,,,,,,,,,,,, -139600,3.475407,2.671155,,,,,,,,,,,,,, -139700,3.4734402,2.6439118,,,,,,,,,,,,,, -139800,3.566124,2.6776488,,,,,,,,,,,,,, -139900,3.5064645,2.6747773,,,,,,,,,,,,,, -140000,3.964698,2.6792502,,,,,,,,,,,,,, -140100,3.7307596,2.7760983,,,,,,,,,,,,,, -140200,3.418227,2.6524975,,,,,,,,,,,,,, -140300,3.2486093,2.705918,,,,,,,,,,,,,, -140400,3.4751039,2.6867173,,,,,,,,,,,,,, -140500,3.5934153,2.7224815,,,,,,,,,,,,,, -140600,3.2212293,2.6486106,,,,,,,,,,,,,, -140618,,,0.881257951259613,0.6726261973381042,0.7454599738121033,1.2311464548110962,50000.0,0.6273000240325928,1.8636746406555176,10000.0,47469.49969315529,49139.661640405655,47469.49969315529,1661.5373244285583,3.6814053058624254,0.0 -140700,3.6365464,2.6977227,,,,,,,,,,,,,, -140800,3.5539083,2.6417725,,,,,,,,,,,,,, -140900,3.3590953,2.6909304,,,,,,,,,,,,,, -141000,3.286717,2.6933203,,,,,,,,,,,,,, -141100,3.581841,2.683738,,,,,,,,,,,,,, -141200,3.5199108,2.6728482,,,,,,,,,,,,,, -141300,3.866063,2.726026,,,,,,,,,,,,,, -141400,3.4384723,2.6500158,,,,,,,,,,,,,, -141500,3.731727,2.724709,,,,,,,,,,,,,, -141600,3.7574186,2.7365425,,,,,,,,,,,,,, -141700,3.7596166,2.6805792,,,,,,,,,,,,,, -141800,3.6794994,2.6909552,,,,,,,,,,,,,, -141900,3.454623,2.6654654,,,,,,,,,,,,,, -142000,3.5427592,2.6828153,,,,,,,,,,,,,, -142100,3.9193425,2.7238197,,,,,,,,,,,,,, -142130,,,0.8937141299247742,0.6258179545402527,0.7417399883270264,1.2493900060653689,50000.0,0.6145000457763672,1.894722580909729,10000.0,47979.43871974945,49667.10993528366,47979.43871974945,1678.9434888362885,3.731943845748901,0.0 -142200,3.4546485,2.6456378,,,,,,,,,,,,,, -142300,3.299218,2.6732526,,,,,,,,,,,,,, -142400,3.8619037,2.7046418,,,,,,,,,,,,,, -142500,3.4218848,2.6876767,,,,,,,,,,,,,, -142600,3.7298434,2.687519,,,,,,,,,,,,,, -142700,3.4362087,2.6553102,,,,,,,,,,,,,, -142800,3.5659113,2.739802,,,,,,,,,,,,,, -142900,3.8252752,2.7055643,,,,,,,,,,,,,, -143000,3.3338516,2.6911287,,,,,,,,,,,,,, -143100,3.5720599,2.6909528,,,,,,,,,,,,,, -143200,3.6520236,2.6628938,,,,,,,,,,,,,, -143300,3.5501957,2.755682,,,,,,,,,,,,,, -143400,3.6629503,2.6760895,,,,,,,,,,,,,, -143500,3.8274686,2.6495934,,,,,,,,,,,,,, -143600,3.554138,2.664667,,,,,,,,,,,,,, -143643,,,0.8984175324440002,0.5999704599380493,0.7469199895858765,1.220354676246643,50000.0,0.6167000532150269,1.8685358762741089,10000.0,48489.606682538986,50194.81318330765,48489.606682538986,1696.3737258911133,3.7850804328918457,0.0 -143700,3.5842707,2.6375792,,,,,,,,,,,,,, -143800,4.015861,2.7432265,,,,,,,,,,,,,, -143900,3.2850254,2.6692219,,,,,,,,,,,,,, -144000,3.9959128,2.6823277,,,,,,,,,,,,,, -144100,3.5770473,2.6660354,,,,,,,,,,,,,, -144200,3.5161836,2.660286,,,,,,,,,,,,,, -144300,3.4883342,2.6463497,,,,,,,,,,,,,, -144400,3.5990498,2.6872568,,,,,,,,,,,,,, -144500,3.5872605,2.6180234,,,,,,,,,,,,,, -144600,3.6928248,2.6782832,,,,,,,,,,,,,, -144700,3.4328928,2.6442723,,,,,,,,,,,,,, -144800,3.640556,2.649104,,,,,,,,,,,,,, -144900,3.4711006,2.639719,,,,,,,,,,,,,, -145000,3.6681328,2.6963665,,,,,,,,,,,,,, -145100,3.6001573,2.6472738,,,,,,,,,,,,,, -145156,,,0.894551157951355,0.6113947629928589,0.745959997177124,1.216866374015808,50000.0,0.6219000220298767,1.8563332557678225,10000.0,48999.76789522171,50722.39692115784,48999.76789522171,1713.693132162094,3.836029767990112,0.0 -145200,3.5641785,2.702559,,,,,,,,,,,,,, -145300,3.515673,2.668979,,,,,,,,,,,,,, -145400,3.9060674,2.6179676,,,,,,,,,,,,,, -145500,3.8034878,2.6844645,,,,,,,,,,,,,, -145600,3.847802,2.702632,,,,,,,,,,,,,, -145700,3.959485,2.7152085,,,,,,,,,,,,,, -145800,3.6258326,2.6640785,,,,,,,,,,,,,, -145900,3.8166933,2.641422,,,,,,,,,,,,,, -146000,3.7297778,2.7473218,,,,,,,,,,,,,, -146100,4.0900292,2.6921048,,,,,,,,,,,,,, -146200,3.5346434,2.628711,,,,,,,,,,,,,, -146300,3.5867906,2.6768382,,,,,,,,,,,,,, -146400,3.7502308,2.6008523,,,,,,,,,,,,,, -146500,3.6255217,2.6958342,,,,,,,,,,,,,, -146600,3.6200762,2.570482,,,,,,,,,,,,,, -146670,,,0.8981783986091614,0.6124516725540161,0.748479962348938,1.2270301580429075,50000.0,0.6224000453948975,1.8555512428283687,10000.0,49509.97759890556,51250.5646352768,49509.97759890556,1731.5386974811554,3.894920825958252,0.0 -146700,3.8403525,2.6323254,,,,,,,,,,,,,, -146800,3.9369013,2.704397,,,,,,,,,,,,,, -146900,3.917387,2.6853743,,,,,,,,,,,,,, -147000,3.7063985,2.5738478,,,,,,,,,,,,,, -147100,3.7119524,2.6470144,,,,,,,,,,,,,, -147200,3.6939287,2.657481,,,,,,,,,,,,,, -147300,3.6025934,2.6493087,,,,,,,,,,,,,, -147400,3.9448109,2.6617427,,,,,,,,,,,,,, -147500,4.043768,2.7029924,,,,,,,,,,,,,, -147600,3.812212,2.6543014,,,,,,,,,,,,,, -147700,3.9122953,2.6658187,,,,,,,,,,,,,, -147800,3.7536447,2.6185272,,,,,,,,,,,,,, -147900,3.7769167,2.6446238,,,,,,,,,,,,,, -148000,3.9337382,2.7084298,,,,,,,,,,,,,, -148100,3.857306,2.6948392,,,,,,,,,,,,,, -148183,,,0.8929169178009033,0.6327682137489319,0.746679961681366,1.2398338317871094,50000.0,0.6222000122070312,1.876502990722656,10000.0,50020.18068480492,51778.22360420227,50020.18068480492,1748.8881268501282,3.947747468948364,0.0 -148200,3.5855327,2.6279461,,,,,,,,,,,,,, -148300,3.9781377,2.636734,,,,,,,,,,,,,, -148400,3.8550215,2.6669803,,,,,,,,,,,,,, -148500,3.6459801,2.6156418,,,,,,,,,,,,,, -148600,3.745586,2.6277988,,,,,,,,,,,,,, -148700,3.7139945,2.6174757,,,,,,,,,,,,,, -148800,3.7884154,2.6379695,,,,,,,,,,,,,, -148900,3.7063875,2.6578877,,,,,,,,,,,,,, -149000,3.8557487,2.6551113,,,,,,,,,,,,,, -149100,3.9845781,2.6456203,,,,,,,,,,,,,, -149200,3.533439,2.651156,,,,,,,,,,,,,, -149300,3.7373605,2.662736,,,,,,,,,,,,,, -149400,3.9216893,2.652719,,,,,,,,,,,,,, -149500,3.852886,2.645286,,,,,,,,,,,,,, -149600,3.4743001,2.5855136,,,,,,,,,,,,,, -149696,,,0.89652419090271,0.6026631593704224,0.7511999607086182,1.203833818435669,50000.0,0.6230000257492065,1.8418248891830444,10000.0,50530.1194422245,52305.64930200577,50530.1194422245,1766.2669341564178,4.001991271972656,0.0 -149700,3.888343,2.62228,,,,,,,,,,,,,, -149800,4.0877476,2.659354,,,,,,,,,,,,,, -149900,3.6627412,2.6407506,,,,,,,,,,,,,, -150000,3.5443385,2.6072016,,,,,,,,,,,,,, -150100,3.5983984,2.5726748,,,,,,,,,,,,,, -150200,3.554989,2.6893697,,,,,,,,,,,,,, -150300,3.5873375,2.6248107,,,,,,,,,,,,,, -150400,3.5797424,2.6311007,,,,,,,,,,,,,, -150500,3.9426963,2.6957898,,,,,,,,,,,,,, -150600,3.9012225,2.628952,,,,,,,,,,,,,, -150700,3.9860106,2.702968,,,,,,,,,,,,,, -150800,3.7680657,2.6098762,,,,,,,,,,,,,, -150900,3.7310872,2.6306047,,,,,,,,,,,,,, -151000,3.805762,2.5929973,,,,,,,,,,,,,, -151100,3.7026963,2.6020489,,,,,,,,,,,,,, -151200,3.9128215,2.6678967,,,,,,,,,,,,,, -151209,,,0.9214365482330322,0.5090224146842957,0.7524799704551697,1.1925246715545654,50000.0,0.6265000104904175,1.8258957862854004,10000.0,51040.28805708885,52833.48942470551,51040.28805708885,1783.8296740055084,4.057976961135864,0.0 -151300,3.6643581,2.6572416,,,,,,,,,,,,,, -151400,3.6080897,2.6176395,,,,,,,,,,,,,, -151500,3.7365186,2.6577444,,,,,,,,,,,,,, -151600,3.729253,2.6505356,,,,,,,,,,,,,, -151700,3.670932,2.607027,,,,,,,,,,,,,, -151800,4.0213723,2.6531181,,,,,,,,,,,,,, -151900,3.7850106,2.611436,,,,,,,,,,,,,, -152000,3.7246974,2.6206589,,,,,,,,,,,,,, -152100,4.1618032,2.6142585,,,,,,,,,,,,,, -152200,4.0252905,2.608636,,,,,,,,,,,,,, -152300,3.769155,2.5883913,,,,,,,,,,,,,, -152400,3.8936222,2.7144806,,,,,,,,,,,,,, -152500,4.079853,2.5925086,,,,,,,,,,,,,, -152600,4.1554956,2.6500118,,,,,,,,,,,,,, -152700,3.7572145,2.6000946,,,,,,,,,,,,,, -152723,,,0.9169324040412904,0.526613175868988,0.7518599629402161,1.19866681098938,50000.0,0.6246000528335571,1.8432016372680664,10000.0,51550.47252988815,53361.15016055107,51550.47252988815,1801.2003610134125,4.109041690826416,0.0 -152800,3.7970657,2.6436753,,,,,,,,,,,,,, -152900,3.8312802,2.6454487,,,,,,,,,,,,,, -153000,3.9747298,2.6182466,,,,,,,,,,,,,, -153100,3.8466773,2.5831637,,,,,,,,,,,,,, -153200,4.106427,2.6535945,,,,,,,,,,,,,, -153300,4.115782,2.6112301,,,,,,,,,,,,,, -153400,4.222205,2.6655338,,,,,,,,,,,,,, -153500,3.8152282,2.620261,,,,,,,,,,,,,, -153600,3.9523375,2.6542912,,,,,,,,,,,,,, -153700,3.8583455,2.6301398,,,,,,,,,,,,,, -153800,3.8974175,2.6432192,,,,,,,,,,,,,, -153900,3.8194256,2.60335,,,,,,,,,,,,,, -154000,3.7995548,2.6836467,,,,,,,,,,,,,, -154100,4.067729,2.647892,,,,,,,,,,,,,, -154200,4.000972,2.6631718,,,,,,,,,,,,,, -154235,,,0.9125677347183228,0.5550875067710876,0.7528600096702576,1.2101103067398071,50000.0,0.6273000240325928,1.8422486782073968,10000.0,52060.446773052216,53888.86126732826,52060.446773052216,1818.829811096192,4.163464784622192,0.0 -154300,3.8730373,2.5754542,,,,,,,,,,,,,, -154400,3.6094,2.6298046,,,,,,,,,,,,,, -154500,3.9410293,2.5915685,,,,,,,,,,,,,, -154600,3.6207063,2.5350509,,,,,,,,,,,,,, -154700,3.7723997,2.5971727,,,,,,,,,,,,,, -154800,3.814602,2.5811865,,,,,,,,,,,,,, -154900,3.675901,2.5481138,,,,,,,,,,,,,, -155000,4.115887,2.6107843,,,,,,,,,,,,,, -155100,3.944535,2.6116364,,,,,,,,,,,,,, -155200,3.9950504,2.5473573,,,,,,,,,,,,,, -155300,3.7664335,2.6268718,,,,,,,,,,,,,, -155400,4.1723795,2.5881867,,,,,,,,,,,,,, -155500,4.193735,2.640396,,,,,,,,,,,,,, -155600,3.9104633,2.631461,,,,,,,,,,,,,, -155700,4.11352,2.6131983,,,,,,,,,,,,,, -155748,,,0.9142418503761292,0.544491171836853,0.7534799575805664,1.1975823640823364,50000.0,0.6277000308036804,1.836329817771912,10000.0,52570.49973154068,54416.66174650192,52570.49973154068,1836.47781085968,4.209648609161377,0.0 -155800,3.9199622,2.5809698,,,,,,,,,,,,,, -155900,3.877759,2.5884845,,,,,,,,,,,,,, -156000,3.8808827,2.6390383,,,,,,,,,,,,,, -156100,3.8300602,2.5793428,,,,,,,,,,,,,, -156200,3.9504743,2.5658336,,,,,,,,,,,,,, -156300,3.7577531,2.5965235,,,,,,,,,,,,,, -156400,4.0264683,2.628402,,,,,,,,,,,,,, -156500,3.7520955,2.5691571,,,,,,,,,,,,,, -156600,4.174228,2.6835632,,,,,,,,,,,,,, -156700,3.811793,2.6088622,,,,,,,,,,,,,, -156800,3.9113328,2.602247,,,,,,,,,,,,,, -156900,4.0685205,2.6442235,,,,,,,,,,,,,, -157000,4.0504537,2.6328847,,,,,,,,,,,,,, -157100,3.9598427,2.6088638,,,,,,,,,,,,,, -157200,3.7260156,2.5737383,,,,,,,,,,,,,, -157261,,,0.9153977632522584,0.5491384267807007,0.7535799741744995,1.2018764019012451,50000.0,0.6292000412940979,1.8402178287506104,10000.0,53080.50154972077,54944.76948451996,53080.50154972077,1854.4784564971924,4.263177156448364,0.0 -157300,4.0846896,2.5772758,,,,,,,,,,,,,, -157400,3.7221646,2.618596,,,,,,,,,,,,,, -157500,4.0218425,2.607132,,,,,,,,,,,,,, -157600,3.655475,2.5864518,,,,,,,,,,,,,, -157700,3.6884365,2.5485835,,,,,,,,,,,,,, -157800,4.2872076,2.6121442,,,,,,,,,,,,,, -157900,3.868039,2.6264284,,,,,,,,,,,,,, -158000,4.142761,2.6177392,,,,,,,,,,,,,, -158100,4.0887218,2.6490064,,,,,,,,,,,,,, -158200,3.875366,2.6281831,,,,,,,,,,,,,, -158300,3.9577417,2.576045,,,,,,,,,,,,,, -158400,3.9060447,2.5638468,,,,,,,,,,,,,, -158500,3.996956,2.6478436,,,,,,,,,,,,,, -158600,4.1077065,2.5977347,,,,,,,,,,,,,, -158700,3.8541389,2.57055,,,,,,,,,,,,,, -158773,,,0.9156568646430968,0.5355087518692017,0.7560999989509583,1.1917190551757812,50000.0,0.6296000480651855,1.8320457935333248,10000.0,53590.45152449608,55472.29345464706,53590.45152449608,1871.9461405277248,4.316278457641602,0.0 -158800,3.9036462,2.6245604,,,,,,,,,,,,,, -158900,4.016567,2.5934548,,,,,,,,,,,,,, -159000,4.268785,2.5980577,,,,,,,,,,,,,, -159100,4.07075,2.5590794,,,,,,,,,,,,,, -159200,3.8225105,2.6263826,,,,,,,,,,,,,, -159300,3.7341123,2.5437627,,,,,,,,,,,,,, -159400,3.7292314,2.5344985,,,,,,,,,,,,,, -159500,4.412644,2.635578,,,,,,,,,,,,,, -159600,3.8422809,2.591733,,,,,,,,,,,,,, -159700,3.8467286,2.5289738,,,,,,,,,,,,,, -159800,4.0465794,2.5250304,,,,,,,,,,,,,, -159900,4.1891456,2.5889027,,,,,,,,,,,,,, -160000,3.7769866,2.5960546,,,,,,,,,,,,,, -160100,4.047554,2.6253793,,,,,,,,,,,,,, -160200,4.07196,2.6170769,,,,,,,,,,,,,, -160286,,,0.9319993257522584,0.4888227581977844,0.7568399906158447,1.1966514587402344,50000.0,0.6318000555038452,1.8300745487213133,10000.0,54100.61996221542,56000.02527117729,54100.61996221542,1889.3943195343013,4.3780577182769775,0.0 -160300,3.9744897,2.5678205,,,,,,,,,,,,,, -160400,4.200192,2.6327941,,,,,,,,,,,,,, -160500,4.2735844,2.5951922,,,,,,,,,,,,,, -160600,4.1826415,2.5612624,,,,,,,,,,,,,, -160700,4.11766,2.6229804,,,,,,,,,,,,,, -160800,4.102447,2.6420317,,,,,,,,,,,,,, -160900,3.8846738,2.5841255,,,,,,,,,,,,,, -161000,3.9589512,2.5982075,,,,,,,,,,,,,, -161100,4.072777,2.6249619,,,,,,,,,,,,,, -161200,3.8282988,2.506553,,,,,,,,,,,,,, -161300,3.8142724,2.5304744,,,,,,,,,,,,,, -161400,3.9479852,2.5478814,,,,,,,,,,,,,, -161500,4.028605,2.555116,,,,,,,,,,,,,, -161600,3.8565152,2.595364,,,,,,,,,,,,,, -161700,4.1865244,2.5466278,,,,,,,,,,,,,, -161799,,,0.9300462007522584,0.4880435168743133,0.7581999897956848,1.1818283796310425,50000.0,0.6349000334739685,1.818547010421753,10000.0,54610.668021678925,56527.60770535469,54610.668021678925,1906.8222138881683,4.4321160316467285,0.0 -161800,3.9620526,2.5134933,,,,,,,,,,,,,, -161900,3.935204,2.5629082,,,,,,,,,,,,,, -162000,3.5554934,2.515167,,,,,,,,,,,,,, -162100,3.8611636,2.5509548,,,,,,,,,,,,,, -162200,4.1491714,2.5843368,,,,,,,,,,,,,, -162300,4.000608,2.5837638,,,,,,,,,,,,,, -162400,3.803219,2.5547848,,,,,,,,,,,,,, -162500,4.2643332,2.578299,,,,,,,,,,,,,, -162600,4.0958824,2.5663908,,,,,,,,,,,,,, -162700,4.010541,2.5953677,,,,,,,,,,,,,, -162800,3.9782372,2.6126354,,,,,,,,,,,,,, -162900,4.2583265,2.6270957,,,,,,,,,,,,,, -163000,4.085951,2.5932734,,,,,,,,,,,,,, -163100,4.092385,2.6133673,,,,,,,,,,,,,, -163200,4.508002,2.5912797,,,,,,,,,,,,,, -163300,3.6275117,2.5175319,,,,,,,,,,,,,, -163311,,,0.9286909699440002,0.4972872734069824,0.7575799822807312,1.1837843656539917,50000.0,0.6351000070571899,1.817732334136963,10000.0,55120.69439053536,57055.07130908966,55120.69439053536,1924.1260414123533,4.511041164398193,0.0 -163400,4.073987,2.5876794,,,,,,,,,,,,,, -163500,3.7965932,2.5181358,,,,,,,,,,,,,, -163600,4.0530047,2.5725775,,,,,,,,,,,,,, -163700,4.239175,2.5763612,,,,,,,,,,,,,, -163800,3.9506607,2.5425124,,,,,,,,,,,,,, -163900,4.302896,2.6197674,,,,,,,,,,,,,, -164000,3.9136915,2.5490818,,,,,,,,,,,,,, -164100,4.3072534,2.6227784,,,,,,,,,,,,,, -164200,4.0862694,2.593904,,,,,,,,,,,,,, -164300,4.108467,2.561079,,,,,,,,,,,,,, -164400,4.321875,2.6082234,,,,,,,,,,,,,, -164500,4.233028,2.6054637,,,,,,,,,,,,,, -164600,3.9546266,2.5298636,,,,,,,,,,,,,, -164700,3.9274323,2.5670836,,,,,,,,,,,,,, -164800,3.8052611,2.4550407,,,,,,,,,,,,,, -164824,,,0.9278140664100648,0.496550053358078,0.7587599754333496,1.1828312873840332,50000.0,0.6360000371932983,1.8208544254302976,10000.0,55630.80797767639,57582.89201283455,55630.80797767639,1941.725725650788,4.56522536277771,0.0 -164900,4.2238965,2.5729175,,,,,,,,,,,,,, -165000,4.0352983,2.5561826,,,,,,,,,,,,,, -165100,3.9544885,2.546712,,,,,,,,,,,,,, -165200,4.526474,2.6004734,,,,,,,,,,,,,, -165300,4.0859184,2.5741663,,,,,,,,,,,,,, -165400,3.8634024,2.5292645,,,,,,,,,,,,,, -165500,4.145787,2.559436,,,,,,,,,,,,,, -165600,4.10634,2.503909,,,,,,,,,,,,,, -165700,3.948491,2.5631816,,,,,,,,,,,,,, -165800,3.9333186,2.5823984,,,,,,,,,,,,,, -165900,3.9179897,2.5467527,,,,,,,,,,,,,, -166000,4.136917,2.5716784,,,,,,,,,,,,,, -166100,4.0593834,2.481324,,,,,,,,,,,,,, -166200,4.708854,2.6216578,,,,,,,,,,,,,, -166300,3.9888492,2.5647047,,,,,,,,,,,,,, -166336,,,0.9301857352256776,0.4875613152980804,0.7597599625587463,1.1802102327346802,50000.0,0.6350000500679016,1.8172756433486936,10000.0,56140.73363828659,58110.581184625626,56140.73363828659,1959.3804275989528,4.620617151260376,0.0 -166400,4.0404916,2.5587587,,,,,,,,,,,,,, -166500,3.9812815,2.5504687,,,,,,,,,,,,,, -166600,4.489358,2.5576127,,,,,,,,,,,,,, -166700,4.1759834,2.5162377,,,,,,,,,,,,,, -166800,4.1169524,2.5656462,,,,,,,,,,,,,, -166900,3.872324,2.532836,,,,,,,,,,,,,, -167000,3.9751353,2.5444105,,,,,,,,,,,,,, -167100,4.141607,2.55503,,,,,,,,,,,,,, -167200,4.002877,2.5344248,,,,,,,,,,,,,, -167300,3.9070857,2.5398657,,,,,,,,,,,,,, -167400,3.8496048,2.4945672,,,,,,,,,,,,,, -167500,3.9806807,2.5363662,,,,,,,,,,,,,, -167600,4.2787085,2.5118346,,,,,,,,,,,,,, -167700,3.9238605,2.5322766,,,,,,,,,,,,,, -167800,4.294153,2.5660996,,,,,,,,,,,,,, -167849,,,0.9341716766357422,0.4808126091957092,0.7604199647903442,1.180977702140808,50000.0,0.6374000310897827,1.8157609701156616,10000.0,56650.924723148346,58638.559196949005,56650.924723148346,1977.0555260181427,4.680642366409302,0.0 -167900,4.656649,2.6292818,,,,,,,,,,,,,, -168000,4.429195,2.567473,,,,,,,,,,,,,, -168100,4.007959,2.542251,,,,,,,,,,,,,, -168200,4.035111,2.514061,,,,,,,,,,,,,, -168300,4.4351315,2.5506637,,,,,,,,,,,,,, -168400,4.0516768,2.5114338,,,,,,,,,,,,,, -168500,4.4380836,2.5297894,,,,,,,,,,,,,, -168600,4.32162,2.5443258,,,,,,,,,,,,,, -168700,3.8756542,2.5466397,,,,,,,,,,,,,, -168800,4.2737455,2.592658,,,,,,,,,,,,,, -168900,4.189036,2.5366178,,,,,,,,,,,,,, -169000,4.1306453,2.562274,,,,,,,,,,,,,, -169100,4.087012,2.5557413,,,,,,,,,,,,,, -169200,3.9597597,2.563527,,,,,,,,,,,,,, -169300,4.252098,2.4944677,,,,,,,,,,,,,, -169362,,,0.9384565949440002,0.4624026119709015,0.7603200078010559,1.1795930862426758,50000.0,0.6372000575065613,1.8115358352661133,10000.0,57161.05779337883,59166.4287109375,57161.05779337883,1994.679827213288,4.739821195602417,0.0 -169400,4.1958575,2.4982076,,,,,,,,,,,,,, -169500,4.2511573,2.6034746,,,,,,,,,,,,,, -169600,3.8717961,2.5247884,,,,,,,,,,,,,, -169700,3.9955897,2.557658,,,,,,,,,,,,,, -169800,3.9646773,2.4909441,,,,,,,,,,,,,, -169900,3.6002169,2.4972153,,,,,,,,,,,,,, -170000,3.925064,2.55192,,,,,,,,,,,,,, -170100,4.3641148,2.5317025,,,,,,,,,,,,,, -170200,3.9392033,2.5425167,,,,,,,,,,,,,, -170300,4.156729,2.528441,,,,,,,,,,,,,, -170400,4.0887218,2.478324,,,,,,,,,,,,,, -170500,3.8659663,2.531828,,,,,,,,,,,,,, -170600,4.019816,2.5717685,,,,,,,,,,,,,, -170700,4.1141167,2.5311394,,,,,,,,,,,,,, -170800,4.1221075,2.5569613,,,,,,,,,,,,,, -170875,,,0.9371013641357422,0.4609328508377075,0.7616399526596069,1.170788288116455,50000.0,0.6381000280380249,1.8069534301757808,10000.0,57671.088973522186,59693.97801613808,57671.088973522186,2012.0857291221616,4.797400236129761,0.0 -170900,3.808795,2.531612,,,,,,,,,,,,,, -171000,3.8209677,2.5213695,,,,,,,,,,,,,, -171100,4.1458745,2.5392153,,,,,,,,,,,,,, -171200,4.0024776,2.5304382,,,,,,,,,,,,,, -171300,3.9656782,2.52874,,,,,,,,,,,,,, -171400,3.8793948,2.5217562,,,,,,,,,,,,,, -171500,4.286841,2.5871696,,,,,,,,,,,,,, -171600,3.8164787,2.4929094,,,,,,,,,,,,,, -171700,5.000035,2.612864,,,,,,,,,,,,,, -171800,3.9881127,2.5200176,,,,,,,,,,,,,, -171900,4.4537826,2.5313733,,,,,,,,,,,,,, -172000,4.127736,2.5769324,,,,,,,,,,,,,, -172100,4.050106,2.529285,,,,,,,,,,,,,, -172200,4.073254,2.556778,,,,,,,,,,,,,, -172300,4.1029363,2.5569708,,,,,,,,,,,,,, -172387,,,0.9367625713348388,0.4671074450016022,0.7630599737167358,1.1761828660964966,50000.0,0.6375000476837158,1.80931556224823,10000.0,58181.277435302734,60221.7962770462,58181.277435302734,2029.605746269226,4.853551864624023,0.0 -172400,3.9150221,2.455811,,,,,,,,,,,,,, -172500,4.1656876,2.5170856,,,,,,,,,,,,,, -172600,4.2237196,2.5437717,,,,,,,,,,,,,, -172700,3.908097,2.532535,,,,,,,,,,,,,, -172800,4.484127,2.526195,,,,,,,,,,,,,, -172900,4.3529315,2.5054314,,,,,,,,,,,,,, -173000,4.3481674,2.5189888,,,,,,,,,,,,,, -173100,4.0463,2.5323596,,,,,,,,,,,,,, -173200,4.0863156,2.5692275,,,,,,,,,,,,,, -173300,4.0484915,2.527146,,,,,,,,,,,,,, -173400,4.2068567,2.520156,,,,,,,,,,,,,, -173500,4.1228466,2.495933,,,,,,,,,,,,,, -173600,3.966235,2.487884,,,,,,,,,,,,,, -173700,4.0708704,2.5428941,,,,,,,,,,,,,, -173800,4.262625,2.557098,,,,,,,,,,,,,, -173900,,,0.9363639950752258,0.4647765457630157,0.7618199586868286,1.1747890710830688,50000.0,0.6363000273704529,1.8088829517364504,10000.0,58691.458490133286,60749.449186086655,58691.458490133286,2046.9627692699432,4.915642261505127,0.0 -173900,4.4580646,2.5392609,,,,,,,,,,,,,, -174000,4.154515,2.4832654,,,,,,,,,,,,,, -174100,4.498116,2.5317264,,,,,,,,,,,,,, -174200,4.252042,2.5333319,,,,,,,,,,,,,, -174300,4.423094,2.5647922,,,,,,,,,,,,,, -174400,4.155915,2.5419903,,,,,,,,,,,,,, -174500,4.211814,2.5109165,,,,,,,,,,,,,, -174600,4.3102565,2.5468512,,,,,,,,,,,,,, -174700,4.092504,2.4974327,,,,,,,,,,,,,, -174800,4.452943,2.5560052,,,,,,,,,,,,,, -174900,3.9802017,2.4809957,,,,,,,,,,,,,, -175000,4.2308292,2.5402946,,,,,,,,,,,,,, -175100,4.1252894,2.4924445,,,,,,,,,,,,,, -175200,4.00646,2.5372517,,,,,,,,,,,,,, -175300,4.2220845,2.524011,,,,,,,,,,,,,, -175400,3.9293485,2.544692,,,,,,,,,,,,,, -175413,,,0.9368223547935486,0.4632803499698639,0.7626799941062927,1.174310326576233,50000.0,0.6386000514030457,1.8084932565689087,10000.0,59201.67768287659,61277.555790662766,59201.67768287659,2064.736491918564,4.975682020187378,0.0 -175500,4.133634,2.529698,,,,,,,,,,,,,, -175600,4.384314,2.4760613,,,,,,,,,,,,,, -175700,3.8889198,2.4716644,,,,,,,,,,,,,, -175800,4.2467184,2.5761178,,,,,,,,,,,,,, -175900,4.22974,2.507502,,,,,,,,,,,,,, -176000,3.958027,2.5080726,,,,,,,,,,,,,, -176100,4.0577955,2.5900786,,,,,,,,,,,,,, -176200,4.2937617,2.5194016,,,,,,,,,,,,,, -176300,4.127481,2.5514994,,,,,,,,,,,,,, -176400,4.040116,2.5159013,,,,,,,,,,,,,, -176500,4.1167493,2.5019443,,,,,,,,,,,,,, -176600,4.0608783,2.4752965,,,,,,,,,,,,,, -176700,4.2708135,2.5211468,,,,,,,,,,,,,, -176800,4.044672,2.4719448,,,,,,,,,,,,,, -176900,4.332269,2.5495825,,,,,,,,,,,,,, -176926,,,0.9379583597183228,0.4658437371253967,0.7636199593544006,1.175073504447937,50000.0,0.6410000324249268,1.8080230951309204,10000.0,59711.65766215325,61805.24752783775,59711.65766215325,2082.3390328884125,5.031782627105713,0.0 -177000,4.469068,2.5366123,,,,,,,,,,,,,, -177100,4.073526,2.5191693,,,,,,,,,,,,,, -177200,4.066259,2.5603933,,,,,,,,,,,,,, -177300,4.284318,2.518127,,,,,,,,,,,,,, -177400,4.234683,2.5371046,,,,,,,,,,,,,, -177500,3.8096423,2.4586334,,,,,,,,,,,,,, -177600,4.1215925,2.558576,,,,,,,,,,,,,, -177700,4.277324,2.559114,,,,,,,,,,,,,, -177800,4.3178015,2.4917939,,,,,,,,,,,,,, -177900,4.5613804,2.5201144,,,,,,,,,,,,,, -178000,4.1733913,2.5042105,,,,,,,,,,,,,, -178100,4.22457,2.4564643,,,,,,,,,,,,,, -178200,3.9229772,2.4783216,,,,,,,,,,,,,, -178300,4.0634403,2.542612,,,,,,,,,,,,,, -178400,4.322175,2.5550644,,,,,,,,,,,,,, -178439,,,0.9407883882522584,0.4490717649459839,0.7630800008773804,1.169806957244873,50000.0,0.6402000188827515,1.8037229776382449,10000.0,60221.87917423248,62333.05004048348,60221.87917423248,2099.8054864406586,5.092525005340576,0.0 -178500,4.029507,2.552944,,,,,,,,,,,,,, -178600,4.230369,2.4977949,,,,,,,,,,,,,, -178700,4.1682034,2.524031,,,,,,,,,,,,,, -178800,4.431843,2.557308,,,,,,,,,,,,,, -178900,4.0534363,2.5418744,,,,,,,,,,,,,, -179000,4.3354483,2.5028229,,,,,,,,,,,,,, -179100,4.0805326,2.4372776,,,,,,,,,,,,,, -179200,4.1987033,2.5118139,,,,,,,,,,,,,, -179300,4.384496,2.4830184,,,,,,,,,,,,,, -179400,4.2782397,2.5525906,,,,,,,,,,,,,, -179500,4.237922,2.550637,,,,,,,,,,,,,, -179600,3.8948348,2.5507827,,,,,,,,,,,,,, -179700,4.401483,2.5334096,,,,,,,,,,,,,, -179800,4.24662,2.5532255,,,,,,,,,,,,,, -179900,3.9937334,2.5267327,,,,,,,,,,,,,, -179952,,,0.940808355808258,0.4544986188411712,0.7629599571228027,1.1712826490402222,50000.0,0.640500009059906,1.8030924797058103,10000.0,60732.035746097565,62860.571590185165,60732.035746097565,2117.058675289154,5.151561737060547,0.0 -180000,3.720044,2.4410682,,,,,,,,,,,,,, -180100,3.8781438,2.5161836,,,,,,,,,,,,,, -180200,4.1039696,2.4949307,,,,,,,,,,,,,, -180300,4.258729,2.527395,,,,,,,,,,,,,, -180400,4.148865,2.5484853,,,,,,,,,,,,,, -180500,4.0373588,2.544854,,,,,,,,,,,,,, -180600,4.254353,2.4656584,,,,,,,,,,,,,, -180700,4.489238,2.5210986,,,,,,,,,,,,,, -180800,4.6217227,2.5495377,,,,,,,,,,,,,, -180900,4.0473604,2.547113,,,,,,,,,,,,,, -181000,3.974167,2.530482,,,,,,,,,,,,,, -181100,3.9200766,2.5598125,,,,,,,,,,,,,, -181200,4.2896223,2.5905023,,,,,,,,,,,,,, -181300,3.9837794,2.4697576,,,,,,,,,,,,,, -181400,4.2241945,2.523612,,,,,,,,,,,,,, -181464,,,0.939871609210968,0.4563002586364746,0.7633999586105347,1.173967719078064,50000.0,0.6407000422477722,1.8057492971420288,10000.0,61242.113387584686,63388.0883102417,61242.113387584686,2134.386967897415,5.21010160446167,0.0 -181500,4.2251644,2.5351868,,,,,,,,,,,,,, -181600,4.1047626,2.5227551,,,,,,,,,,,,,, -181700,4.2194076,2.5505934,,,,,,,,,,,,,, -181800,3.9896479,2.5165794,,,,,,,,,,,,,, -181900,3.990296,2.5393007,,,,,,,,,,,,,, -182000,4.1575284,2.509142,,,,,,,,,,,,,, -182100,4.0898013,2.4933722,,,,,,,,,,,,,, -182200,4.1020384,2.531326,,,,,,,,,,,,,, -182300,4.1395445,2.5539532,,,,,,,,,,,,,, -182400,4.150146,2.5186965,,,,,,,,,,,,,, -182500,4.347947,2.5185454,,,,,,,,,,,,,, -182600,3.8596241,2.4426434,,,,,,,,,,,,,, -182700,4.5410876,2.589917,,,,,,,,,,,,,, -182800,4.084043,2.472989,,,,,,,,,,,,,, -182900,4.1672277,2.5121565,,,,,,,,,,,,,, -182977,,,0.939851701259613,0.4527520835399627,0.7633599638938904,1.1705142259597778,50000.0,0.6399000287055969,1.800960659980774,10000.0,61752.08497405052,63915.73721885681,61752.08497405052,2151.952211856842,5.269731998443604,0.0 -183000,4.1859636,2.5606906,,,,,,,,,,,,,, -183100,4.1580386,2.5560791,,,,,,,,,,,,,, -183200,4.2794905,2.5046086,,,,,,,,,,,,,, -183300,3.824577,2.505033,,,,,,,,,,,,,, -183400,4.0817122,2.484505,,,,,,,,,,,,,, -183500,4.0120525,2.498147,,,,,,,,,,,,,, -183600,4.144421,2.512945,,,,,,,,,,,,,, -183700,3.9309568,2.519341,,,,,,,,,,,,,, -183800,4.328236,2.57994,,,,,,,,,,,,,, -183900,3.9192188,2.465816,,,,,,,,,,,,,, -184000,3.9372318,2.5370934,,,,,,,,,,,,,, -184100,3.9589617,2.5104296,,,,,,,,,,,,,, -184200,4.3096204,2.5137377,,,,,,,,,,,,,, -184300,3.9641361,2.507762,,,,,,,,,,,,,, -184400,4.47974,2.4923406,,,,,,,,,,,,,, -184489,,,0.93949294090271,0.4551874995231628,0.7637400031089783,1.1731160879135132,50000.0,0.6402000188827515,1.80498480796814,10000.0,62262.14775061607,64443.61814188957,62262.14775061607,2169.6572070121765,5.3304524421691895,0.0 -184500,4.272539,2.5032167,,,,,,,,,,,,,, -184600,3.9864008,2.4942575,,,,,,,,,,,,,, -184700,3.892571,2.5073485,,,,,,,,,,,,,, -184800,4.5111094,2.5091407,,,,,,,,,,,,,, -184900,4.147476,2.5720131,,,,,,,,,,,,,, -185000,4.1115627,2.5051126,,,,,,,,,,,,,, -185100,4.0522685,2.4685502,,,,,,,,,,,,,, -185200,3.8361416,2.4833684,,,,,,,,,,,,,, -185300,4.2512984,2.4484305,,,,,,,,,,,,,, -185400,3.969208,2.5020266,,,,,,,,,,,,,, -185500,4.047456,2.5109324,,,,,,,,,,,,,, -185600,4.273656,2.5547833,,,,,,,,,,,,,, -185700,4.066029,2.5451057,,,,,,,,,,,,,, -185800,3.9836574,2.509951,,,,,,,,,,,,,, -185900,4.1940193,2.5133886,,,,,,,,,,,,,, -186000,4.12831,2.5258226,,,,,,,,,,,,,, -186002,,,0.9404894709587096,0.4521209001541137,0.7636399865150452,1.167830228805542,50000.0,0.6404000520706177,1.8014260530471802,10000.0,62772.11812114716,64971.04134345055,62772.11812114716,2186.9935114383698,5.393200159072876,0.0 -186100,3.9324415,2.5120234,,,,,,,,,,,,,, -186200,4.818522,2.5442595,,,,,,,,,,,,,, -186300,4.0133123,2.5129702,,,,,,,,,,,,,, -186400,3.985203,2.491315,,,,,,,,,,,,,, -186500,4.1450934,2.4750242,,,,,,,,,,,,,, -186600,3.9786983,2.5488472,,,,,,,,,,,,,, -186666,,,0.9403100609779358,0.4499163925647735,0.763759970664978,1.1691426038742063,50000.0,0.6404000520706177,1.8022942543029783,10000.0,62996.11162734032,65212.594963788986,62996.11162734032,2204.471424341202,5.451512098312378,0.0 -186666,,,,,,,,,,,62996.11162734032,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 90ed4f875..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,126 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.900901556015015,0.0,34.98110485076904,1,0,34.98110485076904,0.0009000000427477,6.912177562713623,10000,52.88213634490967,0.0011360011994838,6.912071704864502,0.0011199999134987,6.912059783935547,50000 -35.82946801185608,0.0212481021881103,545.0515990257263,1508,0,545.0515990257263,0.0460000038146972,5.649470329284668,10000,580.9545924663544,0.0686383917927742,5.3621506690979,0.0652799978852272,5.418907165527344,50000 -53.52310228347778,0.0560157299041748,1055.0695431232452,3015,0,1055.0695431232452,0.1108000054955482,4.82637882232666,10000,1108.7527313232422,0.1673309952020645,4.278256893157959,0.1536999940872192,4.376176357269287,50000 -71.17259407043457,0.0849447250366211,1565.034935951233,4521,0,1565.034935951233,0.1792000085115432,4.221713066101074,10000,1636.450201511383,0.2669403553009033,3.528772592544556,0.245739996433258,3.66829776763916,50000 -89.82699704170227,0.1136150360107421,2075.2632479667664,6028,0,2075.2632479667664,0.249300017952919,3.7188665866851807,10000,2165.413183450699,0.3638791441917419,2.9160544872283936,0.3346199989318847,3.083531141281128,50000 -107.9004201889038,0.1459236145019531,2585.3535408973694,7535,0,2585.3535408973694,0.3009000122547149,3.3908586502075195,10000,2693.6622858047485,0.4528658986091614,2.3958868980407715,0.4010199904441833,2.6911749839782715,50000 -126.01309251785278,0.1767423152923584,3095.388578891754,9043,0,3095.388578891754,0.3348000049591064,3.1936228275299072,10000,3221.8932802677155,0.4929049611091614,2.165973901748657,0.438539981842041,2.4691154956817627,50000 -143.53465509414673,0.2061150074005127,3605.421551465988,10551,0,3605.421551465988,0.3759000301361084,2.9420087337493896,10000,3749.529512166977,0.5285794138908386,1.9861119985580444,0.4908799827098846,2.2068440914154053,50000 -161.0859730243683,0.2334010601043701,4115.560308218002,12061,0,4115.560308218002,0.3946000039577484,2.823363780975342,10000,4277.298977375031,0.549226701259613,1.8957043886184688,0.5112400054931641,2.1058874130249023,50000 -178.88615822792053,0.2622287273406982,4625.673633098602,13571,0,4625.673633098602,0.4144000113010406,2.7417922019958496,10000,4805.292775630951,0.5690967440605164,1.7911511659622192,0.5332599878311157,1.990215301513672,50000 -196.78649830818176,0.2955081462860107,5135.629874706268,15082,0,5135.629874706268,0.4231000244617462,2.674654245376587,10000,5333.236327886581,0.5841637253761292,1.741123914718628,0.5440999865531921,1.943701148033142,50000 -214.31005549430847,0.3311014175415039,5645.554906845093,16593,0,5645.554906845093,0.4305000305175781,2.623843193054199,10000,5860.77290892601,0.6349848508834839,1.4669139385223389,0.557200014591217,1.8688576221466064,50000 -231.911938905716,0.3624053001403808,6155.727502822876,18105,0,6155.727502822876,0.4300000071525574,2.625067234039306,10000,6388.630298376083,0.6109095811843872,1.5742591619491575,0.5535399913787842,1.891344428062439,50000 -249.50973081588745,0.394974946975708,6665.842526435852,19617,0,6665.842526435852,0.4397000074386596,2.6097118854522705,10000,6916.42732667923,0.6103914380073547,1.596194624900818,0.5658800005912781,1.8382683992385864,50000 -267.0219874382019,0.424980878829956,7175.828543901444,21130,0,7175.828543901444,0.4401000142097473,2.5855188369750977,10000,7444.008017539978,0.6112683415412903,1.598503351211548,0.5659999847412109,1.84242594242096,50000 -284.88498640060425,0.4551308155059814,7685.893748044968,22643,0,7685.893748044968,0.4510000348091125,2.5334506034851074,10000,7972.019269227982,0.6199776530265808,1.546796441078186,0.5753399729728699,1.773450255393982,50000 -302.57521986961365,0.4855742454528808,8196.035193443298,24156,0,8196.035193443298,0.4533000290393829,2.5400712490081787,10000,8499.933848619461,0.6190210580825806,1.5578505992889404,0.5760399699211121,1.7675529718399048,50000 -320.2084016799927,0.51666259765625,8706.256494283676,25669,0,8706.256494283676,0.4646000266075134,2.441926956176758,10000,9027.87317752838,0.6695631146430969,1.318595290184021,0.5854399800300598,1.7338712215423584,50000 -337.96847558021545,0.549079179763794,9216.314534425735,27182,0,9216.314534425735,0.4656000137329101,2.442213535308838,10000,9555.776126384735,0.6529615521430969,1.4048078060150146,0.5888800024986267,1.7101259231567385,50000 -355.44871044158936,0.5797502994537354,9726.26763677597,28695,0,9726.26763677597,0.4537000358104706,2.503017902374268,10000,10083.292204618454,0.6298230290412903,1.5027976036071775,0.5792199969291687,1.772684931755066,50000 -373.15779161453247,0.6124565601348877,10236.290426254272,30209,0,10236.290426254272,0.4670000076293945,2.455329656600952,10000,10611.109208583832,0.6353435516357422,1.4675039052963257,0.5884400010108948,1.714426040649414,50000 -390.7127459049225,0.6449933052062988,10746.51861667633,31723,0,10746.51861667633,0.4664000272750854,2.412381172180176,10000,11138.977312088013,0.6358418464660645,1.4786940813064575,0.5914999842643738,1.7083582878112793,50000 -408.448246717453,0.680262565612793,11256.555840015411,33237,0,11256.555840015411,0.4757000207901001,2.400609254837036,10000,11666.838463544846,0.639668345451355,1.453246831893921,0.5940999984741211,1.6839145421981812,50000 -426.0376763343811,0.717193603515625,11766.602644443512,34751,0,11766.602644443512,0.4595000147819519,2.4588377475738525,10000,12194.56355690956,0.6583824753761292,1.3591587543487549,0.5783599615097046,1.7592555284500122,50000 -443.6337375640869,1.826355218887329,12275.55195569992,36262,0,12275.55195569992,0.468500018119812,2.4308671951293945,10000,12722.270855426788,0.6530413031578064,1.3820711374282837,0.5939399600028992,1.7025057077407837,50000 -461.3704402446747,1.859493970870972,12785.783144235613,37777,0,12785.783144235613,0.4780000150203705,2.3967301845550537,10000,13250.32419705391,0.6506098508834839,1.396021008491516,0.5997999906539917,1.6665048599243164,50000 -478.8435335159302,1.8939027786254885,13295.923071146011,39291,0,13295.923071146011,0.4695000350475311,2.4148316383361816,10000,13778.023537874222,0.6446109414100647,1.427964210510254,0.5927000045776367,1.6880282163619995,50000 -496.6829800605774,1.9265995025634768,13806.05341053009,40806,0,13806.05341053009,0.4875000119209289,2.322809934616089,10000,14306.078384160995,0.6479392647743225,1.4033961296081543,0.6071599721908569,1.6209561824798584,50000 -514.1380536556244,1.9594926834106443,14316.080854415894,42321,0,14316.080854415894,0.4732000231742859,2.43147611618042,10000,14833.64709019661,0.6422193646430969,1.4444411993026731,0.5982199907302856,1.6778781414031982,50000 -531.7674918174744,1.993568420410156,14826.026376008987,43835,0,14826.026376008987,0.4761000275611877,2.4039862155914307,10000,15361.308577775955,0.6764788031578064,1.2786322832107544,0.5990399718284607,1.6587549448013306,50000 -549.9132053852081,2.0285537242889404,15336.135235071182,45350,0,15336.135235071182,0.4793000221252441,2.385539293289185,10000,15889.651034116743,0.6533800959587097,1.3688466548919678,0.602180004119873,1.6588586568832395,50000 -567.3988988399506,2.063908100128174,15846.135519742966,46865,0,15846.135519742966,0.4809000194072723,2.3713834285736084,10000,16417.224779605865,0.6542569994926453,1.370741367340088,0.6079999804496765,1.6286641359329224,50000 -584.9136664867401,2.1026611328125,16356.188809633257,48379,0,16356.188809633257,0.4847000241279602,2.338550329208374,10000,16944.88347172737,0.6592593789100647,1.3680092096328735,0.6098399758338928,1.604357361793518,50000 -602.3362815380096,2.140815734863281,16866.31578350067,49894,0,16866.31578350067,0.4773000180721283,2.3852531909942627,10000,17472.523061990738,0.6473612785339355,1.4344884157180786,0.5996400117874146,1.6758544445037842,50000 -620.0699634552002,2.175709724426269,17376.387938022614,51408,0,17376.387938022614,0.4890000224113464,2.2980523109436035,10000,18000.41622543335,0.6523237824440002,1.3949483633041382,0.6103799939155579,1.6124364137649536,50000 -637.6424875259399,2.215566873550415,17886.385818719864,52923,0,17886.385818719864,0.4906000196933746,2.326728820800781,10000,18528.079265117645,0.6839325428009033,1.2379963397979736,0.6139999628067017,1.5940256118774414,50000 -655.2638325691223,2.253763198852539,18396.55052471161,54438,0,18396.55052471161,0.4843000173568725,2.368727922439575,10000,19055.956347703934,0.6634446382522583,1.333989143371582,0.608020007610321,1.6208534240722656,50000 -672.7741053104401,2.2918620109558105,18906.63244438172,55953,0,18906.63244438172,0.4941000342369079,2.28857421875,10000,19583.63987517357,0.6671316623687744,1.318600058555603,0.6162999868392944,1.5899615287780762,50000 -690.3503816127777,2.3279449939727783,19416.74183535576,57468,0,19416.74183535576,0.4899000227451324,2.360703468322754,10000,20111.414390802383,0.6590401530265808,1.3535881042480469,0.6110000014305115,1.6046415567398071,50000 -707.7968149185181,2.370288133621216,19926.887871027,58983,0,19926.887871027,0.4925000369548797,2.331923246383667,10000,20639.101494312286,0.6589404940605164,1.362576246261597,0.6136400103569031,1.60284686088562,50000 -725.9643821716309,2.4173500537872314,20436.94685316085,60498,0,20436.94685316085,0.4991000294685364,2.270559787750244,10000,21167.427884340286,0.6735491156578064,1.2887380123138428,0.620639979839325,1.5495282411575315,50000 -743.5703382492065,2.4481093883514404,20947.05204677581,62014,0,20947.05204677581,0.4821000099182129,2.346834421157837,10000,21695.2219684124,0.6741669178009033,1.2919275760650637,0.6080399751663208,1.623947024345398,50000 -760.9726111888885,2.4869983196258545,21457.020793676376,63529,0,21457.020793676376,0.4971000254154205,2.292905330657959,10000,22222.684602499008,0.6745057106018066,1.275795817375183,0.6182000041007996,1.5899293422698977,50000 -778.5667836666107,2.531611204147339,21966.95255088806,65044,0,21966.95255088806,0.5083000063896179,2.2321155071258545,10000,22750.30890488625,0.6870814561843872,1.2384870052337646,0.6372599601745605,1.5065011978149414,50000 -796.0130662918091,2.5694398880004883,22476.95581459999,66558,0,22476.95581459999,0.5033000111579895,2.2496683597564697,10000,23277.84796833992,0.6863042116165161,1.240123271942139,0.6317600011825562,1.520139217376709,50000 -813.9019508361816,2.613524913787842,22986.95196557045,68073,0,22986.95196557045,0.5074000358581543,2.2572550773620605,10000,23805.82973885536,0.675203263759613,1.27826189994812,0.6322000026702881,1.5258268117904663,50000 -831.6630780696869,2.6471753120422363,23496.924777507786,69588,0,23496.924777507786,0.4936000108718872,2.2800233364105225,10000,24333.650037050247,0.6990393400192261,1.1675734519958496,0.6238600015640259,1.5474979877471924,50000 -849.3363344669342,2.685119152069092,24006.85396337509,71102,0,24006.85396337509,0.5120000243186951,2.2258431911468506,10000,24861.34249138832,0.70023512840271,1.153869867324829,0.632420003414154,1.5144561529159546,50000 -866.8355195522308,2.729054927825928,24516.80539011956,72617,0,24516.80539011956,0.5139000415802002,2.235292196273804,10000,25388.890427351,0.6888153553009033,1.204163670539856,0.6315199732780457,1.518242120742798,50000 -884.3387162685394,2.771080732345581,25026.766946792603,74132,0,25026.766946792603,0.5053000450134277,2.2189886569976807,10000,25916.44958114624,0.6875796914100647,1.2316724061965942,0.6329799890518188,1.49374258518219,50000 -901.9840440750122,2.811651945114136,25536.77026438713,75647,0,25536.77026438713,0.5136000514030457,2.221004009246826,10000,26444.19121265412,0.6935586333274841,1.2051820755004885,0.6389399766921997,1.4845311641693115,50000 -919.7418501377106,2.8506574630737305,26046.95576357841,77163,0,26046.95576357841,0.5131000280380249,2.2043116092681885,10000,26972.225853919983,0.6889349222183228,1.2334457635879517,0.6387199759483337,1.485427737236023,50000 -937.3784308433532,2.897936582565308,26557.14439845085,78679,0,26557.14439845085,0.5126000046730042,2.209225654602051,10000,27500.150496721268,0.721121609210968,1.0787429809570312,0.6339199542999268,1.5018742084503174,50000 -954.9220995903016,2.937781572341919,27067.337899446487,80195,0,27067.337899446487,0.5139999985694885,2.207616090774536,10000,28027.980364322662,0.7099210619926453,1.1163970232009888,0.6413599848747253,1.4619232416152954,50000 -972.7897419929504,2.9804928302764893,27577.38576722145,81710,0,27577.38576722145,0.5199000239372253,2.1657211780548096,10000,28555.9904255867,0.7057955861091614,1.155988097190857,0.6452400088310242,1.4618548154830933,50000 -990.8612344264984,3.0235769748687744,28087.58653569221,83226,0,28087.58653569221,0.5162000060081482,2.1647984981536865,10000,29084.35788321495,0.6995774507522583,1.173744559288025,0.6433599591255188,1.4629638195037842,50000 -1008.616588830948,3.065901756286621,28597.68378353119,84741,0,28597.68378353119,0.5128000378608704,2.198549032211304,10000,29612.305801153183,0.6940369606018066,1.209545612335205,0.6385200023651123,1.4816385507583618,50000 -1026.5237970352173,4.082794904708862,29106.643271446228,86253,0,29106.643271446228,0.5186000466346741,2.212566614151001,10000,30140.24239969253,0.6893335580825806,1.2253234386444092,0.6370799541473389,1.4862534999847412,50000 -1044.2222499847412,4.124652862548828,29616.63855457306,87768,0,29616.63855457306,0.5189000368118286,2.1587750911712646,10000,30668.031265974045,0.7394770383834839,1.004119873046875,0.6474599838256836,1.4325826168060305,50000 -1061.6245312690735,4.165873289108276,30126.547045707703,89283,0,30126.547045707703,0.5223000049591064,2.1549034118652344,10000,31195.43574333191,0.7106584906578064,1.1320785284042358,0.644819974899292,1.4565682411193848,50000 -1079.1629874706268,4.2079079151153564,30636.56070494652,90798,0,30636.56070494652,0.5230000019073486,2.135982751846313,10000,31723.081677913666,0.7100805044174194,1.1253926753997805,0.6494799852371216,1.432708501815796,50000 -1096.6802098751068,4.251695394515991,31146.7821393013,92314,0,31146.7821393013,0.5293000340461731,2.1161322593688965,10000,32250.916907072067,0.7146045565605164,1.1162675619125366,0.6528800129890442,1.4155189990997314,50000 -1114.0649182796478,4.299462080001831,31656.7867538929,93829,0,31656.7867538929,0.5184000134468079,2.1597468852996826,10000,32778.40650200844,0.7053571343421936,1.1537963151931765,0.649679958820343,1.422189474105835,50000 -1131.4108610153198,4.345837831497192,32166.745491981503,95344,0,32166.745491981503,0.5294000506401062,2.091960906982422,10000,33305.81133413315,0.7147639989852905,1.101299524307251,0.6574400067329407,1.3877466917037964,50000 -1148.986226797104,4.391420841217041,32676.94704413414,96860,0,32676.94704413414,0.5304000377655029,2.108454704284668,10000,33833.68562602997,0.7434829473495483,0.9642866849899292,0.6609599590301514,1.386073112487793,50000 -1166.6077721118927,4.435008049011231,33186.945125579834,98375,0,33186.945125579834,0.5261000394821167,2.065514326095581,10000,34361.40209579468,0.7309271097183228,1.028484344482422,0.6595799922943115,1.3712352514266968,50000 -1184.107824802399,4.481401681900024,33696.84750413895,99890,0,33696.84750413895,0.5289000272750854,2.132866382598877,10000,34888.904074430466,0.7154615521430969,1.0986077785491943,0.6500999927520752,1.425310730934143,50000 -1201.5192544460297,4.524857044219971,34207.06012392044,101406,0,34207.06012392044,0.525700032711029,2.14704966545105,10000,35416.624915361404,0.7203643321990967,1.0835435390472412,0.6618399620056152,1.3780913352966309,50000 -1219.1642200946808,4.576931715011597,34717.284263134,102922,0,34717.284263134,0.5365000367164612,2.069563865661621,10000,35944.59972167015,0.7308474183082581,1.019730567932129,0.6702199578285217,1.3424674272537231,50000 -1236.998259305954,4.621797323226929,35227.21813130379,104437,0,35227.21813130379,0.5321000218391418,2.0829954147338867,10000,36472.46457672119,0.7296117544174194,1.0509129762649536,0.6679199934005737,1.347889065742493,50000 -1254.704176902771,4.679151773452759,35737.22608041763,105952,0,35737.22608041763,0.5433000326156616,2.0417113304138184,10000,37000.28816699982,0.76175856590271,0.9061012864112854,0.674560010433197,1.3246057033538818,50000 -1272.3301212787628,4.732055902481079,36247.43604612351,107468,0,36247.43604612351,0.546500027179718,2.02262544631958,10000,37528.2292740345,0.7491828799247742,0.9538630843162536,0.675819993019104,1.3117072582244873,50000 -1290.2720756530762,4.7767369747161865,36757.352653265,108983,0,36757.352653265,0.541700005531311,2.0563201904296875,10000,38056.18480205536,0.7424266338348389,0.9738861918449402,0.6760599613189697,1.3137764930725098,50000 -1307.7142674922943,4.826080322265625,37267.507370471954,110499,0,37267.507370471954,0.5463000535964966,2.017503261566162,10000,38583.88388371468,0.7394770383834839,0.996335506439209,0.6743599772453308,1.319318413734436,50000 -1325.155839920044,4.8738861083984375,37777.445001125336,112014,0,37777.445001125336,0.5491999983787537,2.0260019302368164,10000,39111.363669633865,0.7409917116165161,0.985767662525177,0.6783599853515625,1.3062725067138672,50000 -1342.566482782364,4.919671297073364,38287.514525175095,113529,0,38287.514525175095,0.5555000305175781,1.985240340232849,10000,39638.94228172302,0.7640305757522583,0.9003487825393677,0.6823399662971497,1.285157561302185,50000 -1360.0370292663574,4.967864036560059,38797.47217464447,115044,0,38797.47217464447,0.553600013256073,1.981478214263916,10000,40166.47235298157,0.772859513759613,0.8542397618293762,0.6820399761199951,1.2807635068893433,50000 -1377.9627270698547,5.016592741012573,39307.38990712166,116559,0,39307.38990712166,0.5580000281333923,1.988004207611084,10000,40694.417508125305,0.7597456574440002,0.911958634853363,0.681939959526062,1.299109697341919,50000 -1395.4464178085327,5.063026189804077,39817.51469898224,118074,0,39817.51469898224,0.5614000558853149,1.96667742729187,10000,41222.12512159348,0.7604631781578064,0.8970020413398743,0.6898199915885925,1.2545552253723145,50000 -1413.069516658783,5.112490177154541,40327.41791152954,119589,0,40327.41791152954,0.5593000054359436,1.9771336317062376,10000,41749.75271129608,0.7599050998687744,0.9128103256225586,0.6882799863815308,1.2635095119476318,50000 -1430.3549864292145,5.1657538414001465,40837.543076753616,121104,0,40837.543076753616,0.5651000142097473,1.946260452270508,10000,42277.27058959007,0.7638512253761292,0.8878701329231262,0.6916399598121643,1.2425378561019895,50000 -1448.7214317321775,5.217383623123169,41347.52177858353,122619,0,41347.52177858353,0.5672000050544739,1.9245952367782595,10000,42805.71956944466,0.8024553656578064,0.7302818298339844,0.693619966506958,1.2337989807128906,50000 -1466.0691316127777,5.268237113952637,41857.67465591431,124135,0,41857.67465591431,0.5654000043869019,1.947636365890503,10000,43333.32304549217,0.7833027839660645,0.8094884753227234,0.6947399973869324,1.2455179691314695,50000 -1483.8026728630066,5.32036828994751,42367.645431280136,125650,0,42367.645431280136,0.572700023651123,1.9287670850753784,10000,43861.1323056221,0.7840800285339355,0.8092227578163147,0.6946600079536438,1.2339342832565308,50000 -1501.4231088161469,5.373712062835693,42877.60603952408,127165,0,42877.60603952408,0.5708000063896179,1.908912181854248,10000,44388.82006406784,0.778340220451355,0.8295766711235046,0.7000399827957153,1.2225801944732666,50000 -1519.2647440433502,5.426731824874878,43387.53185915947,128680,0,43387.53185915947,0.5818000435829163,1.8730347156524656,10000,44916.69388747215,0.78324294090271,0.8075078129768372,0.7049199938774109,1.190179467201233,50000 -1536.644121170044,5.475257635116577,43897.6113114357,130196,0,43897.6113114357,0.5768000483512878,1.8925039768219,10000,45444.25380349159,0.7864915132522583,0.789760947227478,0.7076999545097351,1.1743167638778689,50000 -1554.1618838310242,5.526075601577759,44407.7055516243,131711,0,44407.7055516243,0.5764000415802002,1.8879553079605105,10000,45971.96899843216,0.8223453164100647,0.6572174429893494,0.7078799605369568,1.1762797832489014,50000 -1571.7473032474518,5.575676918029785,44917.60682630539,133226,0,44917.60682630539,0.581000030040741,1.8739051818847656,10000,46499.55842018128,0.8046476244926453,0.7072620391845703,0.7079600095748901,1.179681420326233,50000 -1589.98424077034,5.627666711807251,45427.66591835022,134741,0,45427.66591835022,0.5949000120162964,1.822237253189087,10000,47027.95869851112,0.8073381781578064,0.7013617157936096,0.7153399586677551,1.1508142948150637,50000 -1607.3338513374329,6.020332336425781,45937.36339020729,136255,0,45937.36339020729,0.5845000147819519,1.8409979343414309,10000,47555.45125055313,0.8015385866165161,0.7381904721260071,0.7129799723625183,1.1590455770492554,50000 -1624.9835669994354,6.077698230743408,46447.48622989655,137770,0,46447.48622989655,0.5848000049591064,1.854791045188904,10000,48083.33385229111,0.8052256107330322,0.7155845761299133,0.7154799699783325,1.1480318307876587,50000 -1642.6341168880465,6.127922534942627,46957.41595888138,139285,0,46957.41595888138,0.5924000144004822,1.8475563526153564,10000,48611.016573905945,0.8102080225944519,0.6994613409042358,0.7164799571037292,1.1391528844833374,50000 -1660.1281578540802,6.1773834228515625,47467.37585401535,140800,0,47467.37585401535,0.5952000021934509,1.7925559282302856,10000,49138.57271814346,0.8451849222183228,0.5710806846618652,0.7232999801635742,1.1173813343048096,50000 -1677.7803509235382,6.227893829345703,47977.319326639175,142315,0,47977.319326639175,0.5986000299453735,1.7970387935638428,10000,49666.27079510689,0.8337252736091614,0.6066722273826599,0.7225199937820435,1.1123135089874268,50000 -1695.3257067203522,6.27121639251709,48487.28637838364,143830,0,48487.28637838364,0.600100040435791,1.7974672317504885,10000,50193.8794400692,0.8341039419174194,0.6040483713150024,0.7270999550819397,1.1015565395355225,50000 -1712.9535655975342,6.321614980697632,48997.5089635849,145346,0,48997.5089635849,0.5995000004768372,1.8142592906951904,10000,50721.83307147026,0.8286631107330322,0.6169041991233826,0.727840006351471,1.1030001640319824,50000 -1730.040581703186,6.375980854034424,49507.73680782318,146862,0,49507.73680782318,0.6061000227928162,1.7760010957717896,10000,51249.25549435616,0.832051157951355,0.6025562286376953,0.7324999570846558,1.0817550420761108,50000 -1747.6648552417755,6.429377555847168,50017.81396985054,148377,0,50017.81396985054,0.6103000044822693,1.783352971076965,10000,51777.063213825226,0.840840220451355,0.5678520202636719,0.732759952545166,1.0800023078918457,50000 -1765.2281498908997,6.480208396911621,50527.99266386032,149893,0,50527.99266386032,0.6116000413894653,1.751431941986084,10000,52304.908478975296,0.8672671914100647,0.4734492003917694,0.7351399660110474,1.0676651000976562,50000 -1782.6540973186493,6.537206172943115,51038.04159331322,151408,0,51038.04159331322,0.6101000308990479,1.7489426136016846,10000,52832.49210214615,0.8615473508834839,0.4889602661132812,0.7380799651145935,1.0503398180007937,50000 -1799.930284500122,6.58967399597168,51548.12100839615,152923,0,51548.12100839615,0.6165000200271606,1.755265712738037,10000,53359.95241069794,0.8637993931770325,0.4847530722618103,0.7389400005340576,1.0541657209396362,50000 -1817.5693821907043,6.6446380615234375,52058.20289778709,154438,0,52058.20289778709,0.6126000285148621,1.7592875957489014,10000,53887.78027367592,0.8626036047935486,0.4872516095638275,0.7405999898910522,1.0530531406402588,50000 -1835.477769613266,6.697588682174683,52568.401344537735,155954,0,52568.401344537735,0.6145000457763672,1.7565077543258667,10000,54415.99195933342,0.8659518361091614,0.4726797938346863,0.7437599897384644,1.0325734615325928,50000 -1853.0811932086945,6.750476598739624,53078.52582502365,157469,0,53078.52582502365,0.6205000281333923,1.7346514463424685,10000,54943.82565164566,0.875996470451355,0.4355567395687103,0.7448999881744385,1.034436821937561,50000 -1870.427015542984,6.8072190284729,53588.57729744911,158984,0,53588.57729744911,0.6225000023841858,1.7522176504135132,10000,55471.33137130737,0.8933752775192261,0.3799366354942322,0.7441399693489075,1.0344130992889404,50000 -1888.285579442978,6.8666510581970215,54098.76694107056,160500,0,54098.76694107056,0.6256000399589539,1.7291738986968994,10000,55999.49178338051,0.8909438848495483,0.3873023092746734,0.750499963760376,1.014423131942749,50000 -1905.7687640190125,6.921643972396851,54608.69594120979,162015,0,54608.69594120979,0.6287000179290771,1.735564112663269,10000,56527.012323856354,0.8907844424247742,0.3792414665222168,0.7509599924087524,1.014463186264038,50000 -1923.371912240982,6.976501703262329,55118.67601776123,163530,0,55118.67601776123,0.6272000074386597,1.7141947746276855,10000,57054.70344829559,0.8932557106018066,0.3726956844329834,0.7510600090026855,1.0063519477844238,50000 -1941.1338379383087,7.033837080001831,55628.793186903,165045,0,55628.793186903,0.6301000118255615,1.7019473314285278,10000,57582.69201374054,0.8986766338348389,0.3579636812210083,0.7545199990272522,0.9984259009361268,50000 -1958.429122209549,7.088637590408325,56138.75828337669,166560,0,56138.75828337669,0.633400022983551,1.6993948221206665,10000,58110.06244254112,0.9122488498687744,0.3140309453010559,0.7552599906921387,0.9961941838264464,50000 -1976.0279388427728,7.146857023239136,56648.82356977463,168075,0,56648.82356977463,0.6340000033378601,1.6936455965042114,10000,58637.83798003197,0.916772961616516,0.2954529523849487,0.7545599937438965,0.992598831653595,50000 -1993.3908696174624,7.208467960357666,57158.8615398407,169590,0,57158.8615398407,0.6345000267028809,1.6890100240707395,10000,59165.35334587097,0.9133848547935486,0.3045446276664734,0.7583999633789062,0.9829720854759216,50000 -2010.915184259415,7.263584613800049,57668.79889130592,171105,0,57668.79889130592,0.6374000310897827,1.6809219121932983,10000,59692.92246937752,0.9176697731018066,0.2919151186943054,0.7594999670982361,0.9838279485702516,50000 -2028.5382385253904,7.319125175476074,58178.850531578064,172620,0,58178.850531578064,0.6355000138282776,1.6834412813186646,10000,60220.705787181854,0.9196228981018066,0.279366672039032,0.7603200078010559,0.9797834157943726,50000 -2046.0725243091583,7.37781810760498,58688.89499163628,174135,0,58688.89499163628,0.6358000040054321,1.688401460647583,10000,60748.39577841759,0.9229113459587096,0.2737523913383484,0.7621399760246277,0.9760602712631226,50000 -2063.7414784431458,7.4339611530303955,59198.99043011665,175650,0,59198.99043011665,0.6411000490188599,1.676186203956604,10000,61276.268881082535,0.9290696382522584,0.2520670890808105,0.7624199986457825,0.9706498980522156,50000 -2081.356790304184,7.493084192276001,59708.983047008514,177165,0,59708.983047008514,0.6412000060081482,1.673362374305725,10000,61803.98777985573,0.9302256107330322,0.2482085525989532,0.7620399594306946,0.9679412245750428,50000 -2098.725486040116,7.550928115844727,60218.98730373383,178680,0,60218.98730373383,0.6420000195503235,1.6618832349777222,10000,62331.470396757126,0.9304248690605164,0.2487609386444091,0.7635599970817566,0.96607106924057,50000 -2116.085070133209,7.609094381332397,60729.14320731163,180195,0,60729.14320731163,0.6408000588417053,1.6687637567520142,10000,62859.096420288086,0.9312220811843872,0.2476635724306106,0.7628799676895142,0.9658904671669006,50000 -2133.3704464435577,7.666547775268555,61239.31663656235,181711,0,61239.31663656235,0.6413000226020813,1.669214963912964,10000,63386.6651391983,0.932437777519226,0.2460429519414901,0.7635599970817566,0.963376522064209,50000 -2151.172516822815,7.724971771240234,61749.31013250351,183226,0,61749.31013250351,0.6414000391960144,1.668165683746338,10000,63914.57199478149,0.9330556392669678,0.2435648739337921,0.763700008392334,0.9627121686935424,50000 -2168.429575204849,7.789133071899414,62259.26232099533,184741,0,62259.26232099533,0.6416000127792358,1.6681299209594729,10000,64441.89793848992,0.9329758882522584,0.240103930234909,0.7639399766921997,0.962593674659729,50000 -2186.145072221756,7.848384857177734,62769.26744389534,186255,0,62769.26744389534,0.6425000429153442,1.6677354574203491,10000,64969.73010158539,0.9334343075752258,0.2402811050415039,0.7638399600982666,0.9630224704742432,50000 -2203.5093677043915,7.9071362018585205,62907.485946178436,186666,0,62907.485946178436,0.6413000226020813,1.6673704385757446,10000,65125.386219739914,0.9338129758834839,0.23659060895442963,0.7639399766921997,0.9623273015022278,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/measurements.csv deleted file mode 100644 index fbc699b69..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1994 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6918318,6.925329,,,,,,,,,,,,,, -1,,,0.0011360011994838,6.912071704864502,0.0011199999134987,6.912059783935547,50000.0,0.0009000000427477,6.912177562713623,10000.0,34.98110485076904,52.88213634490967,34.98110485076904,17.900901556015015,0.0,0.0 -100,0.6605497,6.9031887,,,,,,,,,,,,,, -200,0.66838104,6.859374,,,,,,,,,,,,,, -300,0.7106597,6.7713594,,,,,,,,,,,,,, -400,0.74135804,6.6677914,,,,,,,,,,,,,, -500,0.8236008,6.592644,,,,,,,,,,,,,, -600,0.83506924,6.44412,,,,,,,,,,,,,, -700,0.87176454,6.3434825,,,,,,,,,,,,,, -800,1.0219997,6.2212377,,,,,,,,,,,,,, -900,1.2797405,6.1169868,,,,,,,,,,,,,, -1000,1.8294537,6.051275,,,,,,,,,,,,,, -1100,2.4579573,5.85211,,,,,,,,,,,,,, -1200,4.7569213,5.8215847,,,,,,,,,,,,,, -1300,2.426664,5.7337084,,,,,,,,,,,,,, -1400,2.4091556,5.630808,,,,,,,,,,,,,, -1500,3.9885478,5.5422153,,,,,,,,,,,,,, -1508,,,0.0686383917927742,5.3621506690979,0.0652799978852272,5.418907165527344,50000.0,0.0460000038146972,5.649470329284668,10000.0,545.0515990257263,580.9545924663544,545.0515990257263,35.82946801185608,0.0212481021881103,0.0 -1600,2.53721,5.455859,,,,,,,,,,,,,, -1700,3.9927826,5.4505234,,,,,,,,,,,,,, -1800,3.0819147,5.318719,,,,,,,,,,,,,, -1900,3.2706342,5.2140975,,,,,,,,,,,,,, -2000,3.913796,5.171115,,,,,,,,,,,,,, -2100,4.6674557,5.074688,,,,,,,,,,,,,, -2200,4.2126503,5.0502615,,,,,,,,,,,,,, -2300,6.2003636,4.9508066,,,,,,,,,,,,,, -2400,4.526266,4.9331083,,,,,,,,,,,,,, -2500,4.2254987,4.9629583,,,,,,,,,,,,,, -2600,3.5820367,4.816556,,,,,,,,,,,,,, -2700,5.1409073,4.7009153,,,,,,,,,,,,,, -2800,5.9813867,4.684201,,,,,,,,,,,,,, -2900,4.9089665,4.524418,,,,,,,,,,,,,, -3000,6.237407,4.6330347,,,,,,,,,,,,,, -3015,,,0.1673309952020645,4.278256893157959,0.1536999940872192,4.376176357269287,50000.0,0.1108000054955482,4.82637882232666,10000.0,1055.0695431232452,1108.7527313232422,1055.0695431232452,53.52310228347778,0.0560157299041748,0.0 -3100,5.013001,4.59711,,,,,,,,,,,,,, -3200,8.40671,4.4264154,,,,,,,,,,,,,, -3300,4.627994,4.3527784,,,,,,,,,,,,,, -3400,6.604889,4.533375,,,,,,,,,,,,,, -3500,5.391113,4.2892876,,,,,,,,,,,,,, -3600,4.53213,4.345494,,,,,,,,,,,,,, -3700,4.5998297,4.2678337,,,,,,,,,,,,,, -3800,7.6352344,4.166814,,,,,,,,,,,,,, -3900,6.898203,4.174037,,,,,,,,,,,,,, -4000,7.079673,4.1589994,,,,,,,,,,,,,, -4100,7.8777766,4.068376,,,,,,,,,,,,,, -4200,7.3467884,3.9160156,,,,,,,,,,,,,, -4300,4.4655275,3.9996865,,,,,,,,,,,,,, -4400,7.3388705,3.6758366,,,,,,,,,,,,,, -4500,8.104232,3.790757,,,,,,,,,,,,,, -4521,,,0.2669403553009033,3.528772592544556,0.245739996433258,3.66829776763916,50000.0,0.1792000085115432,4.221713066101074,10000.0,1565.034935951233,1636.450201511383,1565.034935951233,71.17259407043457,0.0849447250366211,0.0 -4600,12.099805,3.8777313,,,,,,,,,,,,,, -4700,7.2592726,3.8515558,,,,,,,,,,,,,, -4800,7.803293,3.732893,,,,,,,,,,,,,, -4900,5.875465,3.7151234,,,,,,,,,,,,,, -5000,6.4340086,3.713931,,,,,,,,,,,,,, -5100,9.190918,3.627798,,,,,,,,,,,,,, -5200,8.083566,3.6541605,,,,,,,,,,,,,, -5300,8.566788,3.5153763,,,,,,,,,,,,,, -5400,6.340848,3.4977949,,,,,,,,,,,,,, -5500,8.5301075,3.643642,,,,,,,,,,,,,, -5600,6.198027,3.4304996,,,,,,,,,,,,,, -5700,8.770251,3.5450628,,,,,,,,,,,,,, -5800,6.5355525,3.3984222,,,,,,,,,,,,,, -5900,8.115816,3.4771125,,,,,,,,,,,,,, -6000,5.0551977,3.33284,,,,,,,,,,,,,, -6028,,,0.3638791441917419,2.9160544872283936,0.3346199989318847,3.083531141281128,50000.0,0.249300017952919,3.7188665866851807,10000.0,2075.2632479667664,2165.413183450699,2075.2632479667664,89.82699704170227,0.1136150360107421,0.0 -6100,5.5315905,3.328334,,,,,,,,,,,,,, -6200,6.5476875,3.2343206,,,,,,,,,,,,,, -6300,10.031163,3.4289308,,,,,,,,,,,,,, -6400,6.7959857,3.3897252,,,,,,,,,,,,,, -6500,9.299929,3.1957593,,,,,,,,,,,,,, -6600,6.222101,3.2700276,,,,,,,,,,,,,, -6700,6.613223,3.1213877,,,,,,,,,,,,,, -6800,4.905024,3.1571307,,,,,,,,,,,,,, -6900,9.618085,3.023005,,,,,,,,,,,,,, -7000,5.1729727,2.8959858,,,,,,,,,,,,,, -7100,10.577727,3.0257537,,,,,,,,,,,,,, -7200,8.897513,2.9858875,,,,,,,,,,,,,, -7300,8.26587,3.026681,,,,,,,,,,,,,, -7400,7.710528,3.114067,,,,,,,,,,,,,, -7500,5.821417,2.9316084,,,,,,,,,,,,,, -7535,,,0.4528658986091614,2.3958868980407715,0.4010199904441833,2.6911749839782715,50000.0,0.3009000122547149,3.3908586502075195,10000.0,2585.3535408973694,2693.6622858047485,2585.3535408973694,107.9004201889038,0.1459236145019531,0.0 -7600,7.2101984,2.8826833,,,,,,,,,,,,,, -7700,7.1638155,2.9599566,,,,,,,,,,,,,, -7800,8.931111,2.923713,,,,,,,,,,,,,, -7900,9.562929,2.904343,,,,,,,,,,,,,, -8000,7.8379383,2.7829208,,,,,,,,,,,,,, -8100,5.395839,2.8776,,,,,,,,,,,,,, -8200,7.9072013,2.7894106,,,,,,,,,,,,,, -8300,12.501923,2.7491832,,,,,,,,,,,,,, -8400,7.3359165,2.8245344,,,,,,,,,,,,,, -8500,6.561498,2.8731294,,,,,,,,,,,,,, -8600,7.3160796,2.894598,,,,,,,,,,,,,, -8700,5.7141385,2.6452725,,,,,,,,,,,,,, -8800,6.2034616,2.688836,,,,,,,,,,,,,, -8900,8.929108,2.6484401,,,,,,,,,,,,,, -9000,6.2237372,2.7607265,,,,,,,,,,,,,, -9043,,,0.4929049611091614,2.165973901748657,0.438539981842041,2.4691154956817627,50000.0,0.3348000049591064,3.1936228275299072,10000.0,3095.388578891754,3221.8932802677155,3095.388578891754,126.01309251785278,0.1767423152923584,0.0 -9100,5.910812,2.651039,,,,,,,,,,,,,, -9200,5.630504,2.6815853,,,,,,,,,,,,,, -9300,6.1672482,2.667395,,,,,,,,,,,,,, -9400,5.948908,2.5611534,,,,,,,,,,,,,, -9500,8.340653,2.550097,,,,,,,,,,,,,, -9600,6.6299834,2.687902,,,,,,,,,,,,,, -9700,8.407662,2.7546496,,,,,,,,,,,,,, -9800,6.604457,2.6349857,,,,,,,,,,,,,, -9900,4.4601803,2.6720328,,,,,,,,,,,,,, -10000,7.8588142,2.595665,,,,,,,,,,,,,, -10100,7.3966475,2.7037334,,,,,,,,,,,,,, -10200,7.062016,2.6334798,,,,,,,,,,,,,, -10300,7.28587,2.59799,,,,,,,,,,,,,, -10400,8.448896,2.488823,,,,,,,,,,,,,, -10500,5.7713923,2.5008802,,,,,,,,,,,,,, -10551,,,0.5285794138908386,1.9861119985580444,0.4908799827098846,2.2068440914154053,50000.0,0.3759000301361084,2.9420087337493896,10000.0,3605.421551465988,3749.529512166977,3605.421551465988,143.53465509414673,0.2061150074005127,0.0 -10600,5.6615615,2.464042,,,,,,,,,,,,,, -10700,7.260134,2.426726,,,,,,,,,,,,,, -10800,8.782193,2.556694,,,,,,,,,,,,,, -10900,6.8623886,2.39738,,,,,,,,,,,,,, -11000,8.206084,2.5723634,,,,,,,,,,,,,, -11100,4.7505507,2.4480588,,,,,,,,,,,,,, -11200,6.670445,2.5362537,,,,,,,,,,,,,, -11300,8.840716,2.4740548,,,,,,,,,,,,,, -11400,6.573105,2.5800877,,,,,,,,,,,,,, -11500,6.281918,2.4587314,,,,,,,,,,,,,, -11600,6.147467,2.4518993,,,,,,,,,,,,,, -11700,6.001476,2.4912791,,,,,,,,,,,,,, -11800,6.2510085,2.375293,,,,,,,,,,,,,, -11900,7.091833,2.2640743,,,,,,,,,,,,,, -12000,6.6647725,2.3540409,,,,,,,,,,,,,, -12061,,,0.549226701259613,1.8957043886184688,0.5112400054931641,2.1058874130249023,50000.0,0.3946000039577484,2.823363780975342,10000.0,4115.560308218002,4277.298977375031,4115.560308218002,161.0859730243683,0.2334010601043701,0.0 -12100,8.725427,2.3951964,,,,,,,,,,,,,, -12200,6.176702,2.420635,,,,,,,,,,,,,, -12300,6.677278,2.5014462,,,,,,,,,,,,,, -12400,5.3986645,2.3855236,,,,,,,,,,,,,, -12500,7.539214,2.3308806,,,,,,,,,,,,,, -12600,8.460099,2.2717388,,,,,,,,,,,,,, -12700,4.800795,2.4645557,,,,,,,,,,,,,, -12800,5.0112553,2.2793975,,,,,,,,,,,,,, -12900,6.506321,2.296379,,,,,,,,,,,,,, -13000,4.9706903,2.4162111,,,,,,,,,,,,,, -13100,7.1943545,2.330217,,,,,,,,,,,,,, -13200,6.3923244,2.3944025,,,,,,,,,,,,,, -13300,6.466414,2.3446603,,,,,,,,,,,,,, -13400,6.9634657,2.3066347,,,,,,,,,,,,,, -13500,4.7802367,2.3003194,,,,,,,,,,,,,, -13571,,,0.5690967440605164,1.7911511659622192,0.5332599878311157,1.990215301513672,50000.0,0.4144000113010406,2.7417922019958496,10000.0,4625.673633098602,4805.292775630951,4625.673633098602,178.88615822792053,0.2622287273406982,0.0 -13600,5.094012,2.263202,,,,,,,,,,,,,, -13700,4.992976,2.2400944,,,,,,,,,,,,,, -13800,6.7898545,2.378511,,,,,,,,,,,,,, -13900,5.279955,2.2474883,,,,,,,,,,,,,, -14000,5.934893,2.309133,,,,,,,,,,,,,, -14100,5.3385816,2.2247128,,,,,,,,,,,,,, -14200,6.8024197,2.3778205,,,,,,,,,,,,,, -14300,7.658504,2.2843351,,,,,,,,,,,,,, -14400,5.619893,2.2925358,,,,,,,,,,,,,, -14500,5.209236,2.1776214,,,,,,,,,,,,,, -14600,4.256783,2.2428365,,,,,,,,,,,,,, -14700,6.2928476,2.2963974,,,,,,,,,,,,,, -14800,4.3204474,2.1129677,,,,,,,,,,,,,, -14900,7.9693694,2.3304605,,,,,,,,,,,,,, -15000,5.5099454,2.282721,,,,,,,,,,,,,, -15082,,,0.5841637253761292,1.741123914718628,0.5440999865531921,1.943701148033142,50000.0,0.4231000244617462,2.674654245376587,10000.0,5135.629874706268,5333.236327886581,5135.629874706268,196.78649830818176,0.2955081462860107,0.0 -15100,7.4097996,2.2402627,,,,,,,,,,,,,, -15200,11.09363,2.1544204,,,,,,,,,,,,,, -15300,5.7063866,2.155396,,,,,,,,,,,,,, -15400,9.553391,2.4137197,,,,,,,,,,,,,, -15500,5.670063,2.0940006,,,,,,,,,,,,,, -15600,7.481932,2.2080672,,,,,,,,,,,,,, -15700,6.2287197,2.287107,,,,,,,,,,,,,, -15800,7.236441,2.308175,,,,,,,,,,,,,, -15900,10.304726,2.160612,,,,,,,,,,,,,, -16000,8.199587,2.098739,,,,,,,,,,,,,, -16100,5.0772424,2.1187012,,,,,,,,,,,,,, -16200,6.4286976,2.2850885,,,,,,,,,,,,,, -16300,5.265063,2.2440295,,,,,,,,,,,,,, -16400,5.7579546,2.1458423,,,,,,,,,,,,,, -16500,4.775985,2.2238526,,,,,,,,,,,,,, -16593,,,0.6349848508834839,1.4669139385223389,0.557200014591217,1.8688576221466064,50000.0,0.4305000305175781,2.623843193054199,10000.0,5645.554906845093,5860.77290892601,5645.554906845093,214.31005549430847,0.3311014175415039,0.0 -16600,7.8481264,2.2486806,,,,,,,,,,,,,, -16700,5.270316,2.1663601,,,,,,,,,,,,,, -16800,6.5218544,2.19277,,,,,,,,,,,,,, -16900,5.097384,2.122229,,,,,,,,,,,,,, -17000,4.4725003,2.346362,,,,,,,,,,,,,, -17100,7.7964916,2.1603045,,,,,,,,,,,,,, -17200,5.646873,2.2669036,,,,,,,,,,,,,, -17300,4.242427,2.1574788,,,,,,,,,,,,,, -17400,8.255229,2.3061566,,,,,,,,,,,,,, -17500,6.8696313,2.186872,,,,,,,,,,,,,, -17600,4.3447337,2.140676,,,,,,,,,,,,,, -17700,6.9211655,2.0525126,,,,,,,,,,,,,, -17800,4.459744,2.144934,,,,,,,,,,,,,, -17900,4.8675747,2.2166896,,,,,,,,,,,,,, -18000,4.9689007,2.1700568,,,,,,,,,,,,,, -18100,4.3642006,2.2792845,,,,,,,,,,,,,, -18105,,,0.6109095811843872,1.5742591619491575,0.5535399913787842,1.891344428062439,50000.0,0.4300000071525574,2.625067234039306,10000.0,6155.727502822876,6388.630298376083,6155.727502822876,231.911938905716,0.3624053001403808,0.0 -18200,6.5498314,2.1705031,,,,,,,,,,,,,, -18300,5.417239,2.2415075,,,,,,,,,,,,,, -18400,4.3799267,2.357487,,,,,,,,,,,,,, -18500,4.80797,2.2352753,,,,,,,,,,,,,, -18600,5.2964787,2.1467543,,,,,,,,,,,,,, -18700,4.712475,2.1039,,,,,,,,,,,,,, -18800,3.3838286,2.1387568,,,,,,,,,,,,,, -18900,5.2114096,2.1740322,,,,,,,,,,,,,, -19000,4.4705324,2.160033,,,,,,,,,,,,,, -19100,6.027235,2.147112,,,,,,,,,,,,,, -19200,4.5407767,2.1536345,,,,,,,,,,,,,, -19300,5.6931734,2.2166648,,,,,,,,,,,,,, -19400,4.1920843,2.224035,,,,,,,,,,,,,, -19500,4.95933,2.070912,,,,,,,,,,,,,, -19600,5.2566943,2.1638665,,,,,,,,,,,,,, -19617,,,0.6103914380073547,1.596194624900818,0.5658800005912781,1.8382683992385864,50000.0,0.4397000074386596,2.6097118854522705,10000.0,6665.842526435852,6916.42732667923,6665.842526435852,249.50973081588745,0.394974946975708,0.0 -19700,4.6621757,2.2414055,,,,,,,,,,,,,, -19800,3.407278,1.9687914,,,,,,,,,,,,,, -19900,6.2363954,1.9451904,,,,,,,,,,,,,, -20000,3.8515012,2.11922,,,,,,,,,,,,,, -20100,4.4829354,2.0286875,,,,,,,,,,,,,, -20200,3.9209678,2.2162428,,,,,,,,,,,,,, -20300,3.5085034,2.0405016,,,,,,,,,,,,,, -20400,5.412552,2.2063754,,,,,,,,,,,,,, -20500,4.321043,2.1020226,,,,,,,,,,,,,, -20600,4.1003985,2.0276675,,,,,,,,,,,,,, -20700,5.089968,2.0228689,,,,,,,,,,,,,, -20800,4.433331,2.1360347,,,,,,,,,,,,,, -20900,3.402144,2.2060595,,,,,,,,,,,,,, -21000,6.4777393,2.143498,,,,,,,,,,,,,, -21100,3.428987,2.087257,,,,,,,,,,,,,, -21130,,,0.6112683415412903,1.598503351211548,0.5659999847412109,1.84242594242096,50000.0,0.4401000142097473,2.5855188369750977,10000.0,7175.828543901444,7444.008017539978,7175.828543901444,267.0219874382019,0.424980878829956,0.0 -21200,5.0267024,2.0776303,,,,,,,,,,,,,, -21300,4.3811393,2.192462,,,,,,,,,,,,,, -21400,4.425209,2.0389767,,,,,,,,,,,,,, -21500,3.5318165,2.033403,,,,,,,,,,,,,, -21600,4.4498434,2.060489,,,,,,,,,,,,,, -21700,4.122686,2.080049,,,,,,,,,,,,,, -21800,3.7876194,2.1081405,,,,,,,,,,,,,, -21900,4.4680557,2.1715548,,,,,,,,,,,,,, -22000,4.4116516,2.033571,,,,,,,,,,,,,, -22100,3.716316,2.0175936,,,,,,,,,,,,,, -22200,3.8625665,2.1692092,,,,,,,,,,,,,, -22300,4.0066566,2.01678,,,,,,,,,,,,,, -22400,4.496334,2.122203,,,,,,,,,,,,,, -22500,4.250924,2.0307949,,,,,,,,,,,,,, -22600,4.6787915,2.0081677,,,,,,,,,,,,,, -22643,,,0.6199776530265808,1.546796441078186,0.5753399729728699,1.773450255393982,50000.0,0.4510000348091125,2.5334506034851074,10000.0,7685.893748044968,7972.019269227982,7685.893748044968,284.88498640060425,0.4551308155059814,0.0 -22700,3.5876064,2.0549634,,,,,,,,,,,,,, -22800,3.5026956,2.13879,,,,,,,,,,,,,, -22900,5.402143,2.0125172,,,,,,,,,,,,,, -23000,3.911221,2.0849793,,,,,,,,,,,,,, -23100,5.3246922,2.1697178,,,,,,,,,,,,,, -23200,4.430755,2.0617125,,,,,,,,,,,,,, -23300,4.2558103,2.108371,,,,,,,,,,,,,, -23400,3.677214,1.9520276,,,,,,,,,,,,,, -23500,4.0244045,2.0573566,,,,,,,,,,,,,, -23600,5.1884694,2.061713,,,,,,,,,,,,,, -23700,4.4608984,2.1066825,,,,,,,,,,,,,, -23800,4.376013,2.0124538,,,,,,,,,,,,,, -23900,3.6879659,2.0308847,,,,,,,,,,,,,, -24000,4.6255994,2.1311333,,,,,,,,,,,,,, -24100,4.343982,2.0606518,,,,,,,,,,,,,, -24156,,,0.6190210580825806,1.5578505992889404,0.5760399699211121,1.7675529718399048,50000.0,0.4533000290393829,2.5400712490081787,10000.0,8196.035193443298,8499.933848619461,8196.035193443298,302.57521986961365,0.4855742454528808,0.0 -24200,4.854485,2.0648873,,,,,,,,,,,,,, -24300,4.0427155,1.9933325,,,,,,,,,,,,,, -24400,5.194443,2.1014256,,,,,,,,,,,,,, -24500,5.056919,2.0787044,,,,,,,,,,,,,, -24600,5.9935427,2.0420427,,,,,,,,,,,,,, -24700,5.054775,1.88425,,,,,,,,,,,,,, -24800,4.5144763,2.1293592,,,,,,,,,,,,,, -24900,4.04037,2.1078207,,,,,,,,,,,,,, -25000,3.1948211,2.0540094,,,,,,,,,,,,,, -25100,3.2010014,1.9930985,,,,,,,,,,,,,, -25200,4.770114,1.9975641,,,,,,,,,,,,,, -25300,4.484963,1.9935255,,,,,,,,,,,,,, -25400,3.1309187,2.085467,,,,,,,,,,,,,, -25500,4.906482,2.0452375,,,,,,,,,,,,,, -25600,3.0697238,2.0703855,,,,,,,,,,,,,, -25669,,,0.6695631146430969,1.318595290184021,0.5854399800300598,1.7338712215423584,50000.0,0.4646000266075134,2.441926956176758,10000.0,8706.256494283676,9027.87317752838,8706.256494283676,320.2084016799927,0.51666259765625,0.0 -25700,3.8844368,1.9360489,,,,,,,,,,,,,, -25800,3.4634566,1.9690092,,,,,,,,,,,,,, -25900,3.605319,2.021104,,,,,,,,,,,,,, -26000,3.7015407,2.0991151,,,,,,,,,,,,,, -26100,4.317075,2.0067384,,,,,,,,,,,,,, -26200,5.017609,2.014726,,,,,,,,,,,,,, -26300,3.5866077,2.0870109,,,,,,,,,,,,,, -26400,4.167263,2.0338492,,,,,,,,,,,,,, -26500,4.32943,2.040482,,,,,,,,,,,,,, -26600,4.2222404,1.9501106,,,,,,,,,,,,,, -26700,2.9109151,1.8435001,,,,,,,,,,,,,, -26800,3.4955392,1.9812511,,,,,,,,,,,,,, -26900,4.440279,2.0231102,,,,,,,,,,,,,, -27000,3.6850998,2.0687952,,,,,,,,,,,,,, -27100,3.993688,2.1109858,,,,,,,,,,,,,, -27182,,,0.6529615521430969,1.4048078060150146,0.5888800024986267,1.7101259231567385,50000.0,0.4656000137329101,2.442213535308838,10000.0,9216.314534425735,9555.776126384735,9216.314534425735,337.96847558021545,0.549079179763794,0.0 -27200,3.907175,1.986553,,,,,,,,,,,,,, -27300,6.102223,2.0561426,,,,,,,,,,,,,, -27400,3.7747934,1.9677863,,,,,,,,,,,,,, -27500,3.715097,1.8840711,,,,,,,,,,,,,, -27600,3.8069344,1.90926,,,,,,,,,,,,,, -27700,3.3812985,2.019248,,,,,,,,,,,,,, -27800,4.246343,2.0722723,,,,,,,,,,,,,, -27900,4.056918,1.9990051,,,,,,,,,,,,,, -28000,4.155707,1.9858363,,,,,,,,,,,,,, -28100,4.1536093,2.023775,,,,,,,,,,,,,, -28200,4.6229935,2.0557773,,,,,,,,,,,,,, -28300,4.139329,2.0982435,,,,,,,,,,,,,, -28400,5.294293,1.9450476,,,,,,,,,,,,,, -28500,3.9842098,1.9918761,,,,,,,,,,,,,, -28600,3.9135394,2.057629,,,,,,,,,,,,,, -28695,,,0.6298230290412903,1.5027976036071775,0.5792199969291687,1.772684931755066,50000.0,0.4537000358104706,2.503017902374268,10000.0,9726.26763677597,10083.292204618454,9726.26763677597,355.44871044158936,0.5797502994537354,0.0 -28700,3.1936352,1.9689293,,,,,,,,,,,,,, -28800,4.22906,1.9627922,,,,,,,,,,,,,, -28900,4.454586,1.8928761,,,,,,,,,,,,,, -29000,3.2302659,2.024377,,,,,,,,,,,,,, -29100,4.0880795,2.0149179,,,,,,,,,,,,,, -29200,3.974929,2.0174952,,,,,,,,,,,,,, -29300,4.228001,2.1035619,,,,,,,,,,,,,, -29400,3.9544322,1.9742602,,,,,,,,,,,,,, -29500,3.7327015,2.0590088,,,,,,,,,,,,,, -29600,3.8599992,2.0552688,,,,,,,,,,,,,, -29700,3.7301517,1.9012164,,,,,,,,,,,,,, -29800,5.2502823,2.0598497,,,,,,,,,,,,,, -29900,3.7761075,1.9760897,,,,,,,,,,,,,, -30000,3.8111548,1.9538547,,,,,,,,,,,,,, -30100,3.516937,1.9569168,,,,,,,,,,,,,, -30200,3.5829983,2.0840669,,,,,,,,,,,,,, -30209,,,0.6353435516357422,1.4675039052963257,0.5884400010108948,1.714426040649414,50000.0,0.4670000076293945,2.455329656600952,10000.0,10236.290426254272,10611.109208583832,10236.290426254272,373.15779161453247,0.6124565601348877,0.0 -30300,4.5505543,1.9944434,,,,,,,,,,,,,, -30400,3.8737135,2.0855863,,,,,,,,,,,,,, -30500,4.1314564,1.9740711,,,,,,,,,,,,,, -30600,3.4799733,1.9313103,,,,,,,,,,,,,, -30700,4.5444307,1.8903173,,,,,,,,,,,,,, -30800,3.5860472,2.0747886,,,,,,,,,,,,,, -30900,4.1273355,1.8989208,,,,,,,,,,,,,, -31000,3.230602,1.9581822,,,,,,,,,,,,,, -31100,4.0059137,1.9869152,,,,,,,,,,,,,, -31200,4.086983,1.9044544,,,,,,,,,,,,,, -31300,3.997529,1.9984822,,,,,,,,,,,,,, -31400,4.180762,2.0288534,,,,,,,,,,,,,, -31500,3.4508562,1.8882699,,,,,,,,,,,,,, -31600,3.4263942,1.9489658,,,,,,,,,,,,,, -31700,4.0054755,2.018535,,,,,,,,,,,,,, -31723,,,0.6358418464660645,1.4786940813064575,0.5914999842643738,1.7083582878112793,50000.0,0.4664000272750854,2.412381172180176,10000.0,10746.51861667633,11138.977312088013,10746.51861667633,390.7127459049225,0.6449933052062988,0.0 -31800,3.5503693,1.8996878,,,,,,,,,,,,,, -31900,3.9775002,1.9928651,,,,,,,,,,,,,, -32000,4.401783,1.9885473,,,,,,,,,,,,,, -32100,4.4230623,1.845108,,,,,,,,,,,,,, -32200,3.993567,2.00921,,,,,,,,,,,,,, -32300,3.8040078,1.9401221,,,,,,,,,,,,,, -32400,3.4923577,1.8933154,,,,,,,,,,,,,, -32500,3.4942489,1.9073292,,,,,,,,,,,,,, -32600,3.6369274,2.0207856,,,,,,,,,,,,,, -32700,4.844222,1.9983268,,,,,,,,,,,,,, -32800,3.5800295,2.0468745,,,,,,,,,,,,,, -32900,3.9702203,1.9431963,,,,,,,,,,,,,, -33000,3.7151551,2.0609426,,,,,,,,,,,,,, -33100,4.004558,1.8399762,,,,,,,,,,,,,, -33200,3.1357832,1.9158918,,,,,,,,,,,,,, -33237,,,0.639668345451355,1.453246831893921,0.5940999984741211,1.6839145421981812,50000.0,0.4757000207901001,2.400609254837036,10000.0,11256.555840015411,11666.838463544846,11256.555840015411,408.448246717453,0.680262565612793,0.0 -33300,3.3218136,2.0192676,,,,,,,,,,,,,, -33400,3.9465349,1.9473425,,,,,,,,,,,,,, -33500,4.2857113,1.965692,,,,,,,,,,,,,, -33600,4.234598,2.0232697,,,,,,,,,,,,,, -33700,4.2908106,2.0020027,,,,,,,,,,,,,, -33800,3.774921,2.0075994,,,,,,,,,,,,,, -33900,3.936992,1.9738698,,,,,,,,,,,,,, -34000,4.1080685,2.0987225,,,,,,,,,,,,,, -34100,3.5754871,1.896229,,,,,,,,,,,,,, -34200,4.103382,1.9420005,,,,,,,,,,,,,, -34300,3.6541886,1.8868097,,,,,,,,,,,,,, -34400,5.1154156,1.8663621,,,,,,,,,,,,,, -34500,4.745734,2.0504663,,,,,,,,,,,,,, -34600,4.393839,1.9428071,,,,,,,,,,,,,, -34700,3.750946,1.922741,,,,,,,,,,,,,, -34751,,,0.6583824753761292,1.3591587543487549,0.5783599615097046,1.7592555284500122,50000.0,0.4595000147819519,2.4588377475738525,10000.0,11766.602644443512,12194.56355690956,11766.602644443512,426.0376763343811,0.717193603515625,0.0 -34800,3.895911,1.8971303,,,,,,,,,,,,,, -34900,3.8591332,1.9557449,,,,,,,,,,,,,, -35000,3.9929304,2.0019488,,,,,,,,,,,,,, -35100,4.7561107,1.9139332,,,,,,,,,,,,,, -35200,3.8979428,2.0609853,,,,,,,,,,,,,, -35300,3.7159615,1.7936894,,,,,,,,,,,,,, -35400,4.066668,1.7480643,,,,,,,,,,,,,, -35500,4.357567,2.0510216,,,,,,,,,,,,,, -35600,3.9665542,1.9942503,,,,,,,,,,,,,, -35700,3.4923456,2.0361342,,,,,,,,,,,,,, -35800,4.895355,1.9357605,,,,,,,,,,,,,, -35900,3.4539952,1.9329318,,,,,,,,,,,,,, -36000,3.6746104,1.9844818,,,,,,,,,,,,,, -36100,3.217462,2.0235438,,,,,,,,,,,,,, -36200,3.6912758,1.9641486,,,,,,,,,,,,,, -36262,,,0.6530413031578064,1.3820711374282837,0.5939399600028992,1.7025057077407837,50000.0,0.468500018119812,2.4308671951293945,10000.0,12275.55195569992,12722.270855426788,12275.55195569992,443.6337375640869,1.826355218887329,0.0 -36300,4.2517505,2.0072696,,,,,,,,,,,,,, -36400,3.9002597,2.0162961,,,,,,,,,,,,,, -36500,3.5038068,1.8300934,,,,,,,,,,,,,, -36600,3.7000825,1.964264,,,,,,,,,,,,,, -36700,4.3094735,1.8841639,,,,,,,,,,,,,, -36800,3.413028,1.9259231,,,,,,,,,,,,,, -36900,4.2352285,2.1193314,,,,,,,,,,,,,, -37000,3.914015,1.9231944,,,,,,,,,,,,,, -37100,3.8298028,1.9058464,,,,,,,,,,,,,, -37200,3.7516034,2.0218508,,,,,,,,,,,,,, -37300,3.5277874,1.855407,,,,,,,,,,,,,, -37400,4.483726,1.8718673,,,,,,,,,,,,,, -37500,4.9845033,1.9500352,,,,,,,,,,,,,, -37600,3.88403,2.0564184,,,,,,,,,,,,,, -37700,3.6275644,1.908319,,,,,,,,,,,,,, -37777,,,0.6506098508834839,1.396021008491516,0.5997999906539917,1.6665048599243164,50000.0,0.4780000150203705,2.3967301845550537,10000.0,12785.783144235613,13250.32419705391,12785.783144235613,461.3704402446747,1.859493970870972,0.0 -37800,3.98397,1.9589472,,,,,,,,,,,,,, -37900,3.6851468,2.029378,,,,,,,,,,,,,, -38000,4.4175763,1.955493,,,,,,,,,,,,,, -38100,3.8506348,1.9563261,,,,,,,,,,,,,, -38200,4.3436255,1.9230015,,,,,,,,,,,,,, -38300,4.00845,1.9169024,,,,,,,,,,,,,, -38400,3.7873144,1.9825363,,,,,,,,,,,,,, -38500,3.4867127,1.8703469,,,,,,,,,,,,,, -38600,3.2138615,1.9557985,,,,,,,,,,,,,, -38700,3.95643,2.0646846,,,,,,,,,,,,,, -38800,3.7801304,1.9499412,,,,,,,,,,,,,, -38900,3.284408,1.9204112,,,,,,,,,,,,,, -39000,4.20451,1.897518,,,,,,,,,,,,,, -39100,3.6433427,1.9694215,,,,,,,,,,,,,, -39200,4.133368,1.954294,,,,,,,,,,,,,, -39291,,,0.6446109414100647,1.427964210510254,0.5927000045776367,1.6880282163619995,50000.0,0.4695000350475311,2.4148316383361816,10000.0,13295.923071146011,13778.023537874222,13295.923071146011,478.8435335159302,1.8939027786254885,0.0 -39300,4.336007,1.8864362,,,,,,,,,,,,,, -39400,3.705601,1.8015025,,,,,,,,,,,,,, -39500,3.917998,2.036747,,,,,,,,,,,,,, -39600,4.262929,1.9535775,,,,,,,,,,,,,, -39700,3.8389213,1.9820064,,,,,,,,,,,,,, -39800,3.5667841,1.9916046,,,,,,,,,,,,,, -39900,4.1102834,1.8464319,,,,,,,,,,,,,, -40000,3.8715403,1.9115937,,,,,,,,,,,,,, -40100,3.4179556,1.9474136,,,,,,,,,,,,,, -40200,3.759169,1.8796784,,,,,,,,,,,,,, -40300,3.6357632,1.9968982,,,,,,,,,,,,,, -40400,4.3073053,1.879175,,,,,,,,,,,,,, -40500,3.956692,2.0390203,,,,,,,,,,,,,, -40600,4.106557,1.8303733,,,,,,,,,,,,,, -40700,3.2844188,1.9089264,,,,,,,,,,,,,, -40800,4.0452733,1.9033154,,,,,,,,,,,,,, -40806,,,0.6479392647743225,1.4033961296081543,0.6071599721908569,1.6209561824798584,50000.0,0.4875000119209289,2.322809934616089,10000.0,13806.05341053009,14306.078384160995,13806.05341053009,496.6829800605774,1.9265995025634768,0.0 -40900,3.823781,2.078939,,,,,,,,,,,,,, -41000,4.286013,1.8673735,,,,,,,,,,,,,, -41100,3.0443227,1.8451418,,,,,,,,,,,,,, -41200,3.9385586,1.9753251,,,,,,,,,,,,,, -41300,3.3543148,1.9435983,,,,,,,,,,,,,, -41400,3.350949,1.8321545,,,,,,,,,,,,,, -41500,3.6527734,1.8991244,,,,,,,,,,,,,, -41600,3.521822,1.9246564,,,,,,,,,,,,,, -41700,3.9910152,1.9736347,,,,,,,,,,,,,, -41800,4.5276484,1.9334903,,,,,,,,,,,,,, -41900,4.0717006,1.8197474,,,,,,,,,,,,,, -42000,3.8329647,1.9887118,,,,,,,,,,,,,, -42100,3.684869,2.0077186,,,,,,,,,,,,,, -42200,4.353548,1.9654229,,,,,,,,,,,,,, -42300,3.6073651,1.7513735,,,,,,,,,,,,,, -42321,,,0.6422193646430969,1.4444411993026731,0.5982199907302856,1.6778781414031982,50000.0,0.4732000231742859,2.43147611618042,10000.0,14316.080854415894,14833.64709019661,14316.080854415894,514.1380536556244,1.9594926834106443,0.0 -42400,4.1553407,1.854238,,,,,,,,,,,,,, -42500,3.69259,1.898687,,,,,,,,,,,,,, -42600,3.5271678,1.8428168,,,,,,,,,,,,,, -42700,3.8383465,1.7864109,,,,,,,,,,,,,, -42800,3.9696908,1.9434533,,,,,,,,,,,,,, -42900,4.3750954,1.9565712,,,,,,,,,,,,,, -43000,3.4555197,2.0056467,,,,,,,,,,,,,, -43100,3.7983222,1.9207561,,,,,,,,,,,,,, -43200,3.470741,1.8171105,,,,,,,,,,,,,, -43300,3.26927,1.9238062,,,,,,,,,,,,,, -43400,3.5968907,1.8669841,,,,,,,,,,,,,, -43500,3.8879154,1.9199464,,,,,,,,,,,,,, -43600,3.1645343,1.8099499,,,,,,,,,,,,,, -43700,3.343314,1.805572,,,,,,,,,,,,,, -43800,3.4916182,1.9153275,,,,,,,,,,,,,, -43835,,,0.6764788031578064,1.2786322832107544,0.5990399718284607,1.6587549448013306,50000.0,0.4761000275611877,2.4039862155914307,10000.0,14826.026376008987,15361.308577775955,14826.026376008987,531.7674918174744,1.993568420410156,0.0 -43900,3.7283528,1.9089698,,,,,,,,,,,,,, -44000,4.6264515,2.0024734,,,,,,,,,,,,,, -44100,3.5564187,1.8542551,,,,,,,,,,,,,, -44200,4.6637816,1.8834254,,,,,,,,,,,,,, -44300,3.7501335,1.9311506,,,,,,,,,,,,,, -44400,3.786208,1.890802,,,,,,,,,,,,,, -44500,3.9388173,1.9861517,,,,,,,,,,,,,, -44600,4.341872,2.0094686,,,,,,,,,,,,,, -44700,4.3426876,1.9881024,,,,,,,,,,,,,, -44800,3.6002808,1.7757297,,,,,,,,,,,,,, -44900,3.1832714,1.8369436,,,,,,,,,,,,,, -45000,4.2427063,2.011694,,,,,,,,,,,,,, -45100,5.411912,1.8410211,,,,,,,,,,,,,, -45200,3.9257994,1.8944697,,,,,,,,,,,,,, -45300,3.4685981,1.8948758,,,,,,,,,,,,,, -45350,,,0.6533800959587097,1.3688466548919678,0.602180004119873,1.6588586568832395,50000.0,0.4793000221252441,2.385539293289185,10000.0,15336.135235071182,15889.651034116743,15336.135235071182,549.9132053852081,2.0285537242889404,0.0 -45400,4.053127,1.9544961,,,,,,,,,,,,,, -45500,3.3432894,1.9039564,,,,,,,,,,,,,, -45600,3.5509703,1.9210268,,,,,,,,,,,,,, -45700,3.2086625,1.9432919,,,,,,,,,,,,,, -45800,4.0512595,1.9332505,,,,,,,,,,,,,, -45900,3.8056173,1.7972003,,,,,,,,,,,,,, -46000,3.881247,1.8902886,,,,,,,,,,,,,, -46100,3.9305005,1.9682307,,,,,,,,,,,,,, -46200,4.0949454,1.8751315,,,,,,,,,,,,,, -46300,4.07436,1.7643521,,,,,,,,,,,,,, -46400,3.7646575,2.0530581,,,,,,,,,,,,,, -46500,3.3986657,1.8248976,,,,,,,,,,,,,, -46600,4.9181805,1.9606416,,,,,,,,,,,,,, -46700,3.3584034,1.8424587,,,,,,,,,,,,,, -46800,3.585112,2.0031624,,,,,,,,,,,,,, -46865,,,0.6542569994926453,1.370741367340088,0.6079999804496765,1.6286641359329224,50000.0,0.4809000194072723,2.3713834285736084,10000.0,15846.135519742966,16417.224779605865,15846.135519742966,567.3988988399506,2.063908100128174,0.0 -46900,3.931867,1.9350547,,,,,,,,,,,,,, -47000,3.7197244,1.8205023,,,,,,,,,,,,,, -47100,3.9939456,1.8200073,,,,,,,,,,,,,, -47200,3.3116465,1.8300899,,,,,,,,,,,,,, -47300,3.5684516,1.9354444,,,,,,,,,,,,,, -47400,3.4521902,1.803926,,,,,,,,,,,,,, -47500,3.1440058,1.7968682,,,,,,,,,,,,,, -47600,3.6067593,1.9879395,,,,,,,,,,,,,, -47700,4.1765065,1.8029137,,,,,,,,,,,,,, -47800,4.150971,1.8857526,,,,,,,,,,,,,, -47900,4.661058,1.8251476,,,,,,,,,,,,,, -48000,3.9434712,1.8820095,,,,,,,,,,,,,, -48100,4.0288444,1.8453239,,,,,,,,,,,,,, -48200,4.38627,2.0253768,,,,,,,,,,,,,, -48300,3.7587562,1.9725864,,,,,,,,,,,,,, -48379,,,0.6592593789100647,1.3680092096328735,0.6098399758338928,1.604357361793518,50000.0,0.4847000241279602,2.338550329208374,10000.0,16356.188809633257,16944.88347172737,16356.188809633257,584.9136664867401,2.1026611328125,0.0 -48400,4.56974,1.9367385,,,,,,,,,,,,,, -48500,3.9984212,1.9236596,,,,,,,,,,,,,, -48600,3.6837177,1.9720188,,,,,,,,,,,,,, -48700,3.6302857,1.8495055,,,,,,,,,,,,,, -48800,3.9350123,1.9782687,,,,,,,,,,,,,, -48900,3.3123195,1.8086146,,,,,,,,,,,,,, -49000,3.6736174,1.807744,,,,,,,,,,,,,, -49100,3.8740628,1.8689075,,,,,,,,,,,,,, -49200,3.7786212,1.8409662,,,,,,,,,,,,,, -49300,3.774457,1.8511314,,,,,,,,,,,,,, -49400,3.4078844,1.9082216,,,,,,,,,,,,,, -49500,3.574778,1.8192408,,,,,,,,,,,,,, -49600,3.4173248,1.8164607,,,,,,,,,,,,,, -49700,3.6420314,1.8493109,,,,,,,,,,,,,, -49800,3.713165,1.7477164,,,,,,,,,,,,,, -49894,,,0.6473612785339355,1.4344884157180786,0.5996400117874146,1.6758544445037842,50000.0,0.4773000180721283,2.3852531909942627,10000.0,16866.31578350067,17472.523061990738,16866.31578350067,602.3362815380096,2.140815734863281,0.0 -49900,3.5130072,1.8481269,,,,,,,,,,,,,, -50000,3.3140392,1.827048,,,,,,,,,,,,,, -50100,3.6800795,1.848054,,,,,,,,,,,,,, -50200,3.2395198,1.8933748,,,,,,,,,,,,,, -50300,3.6032395,1.8332715,,,,,,,,,,,,,, -50400,4.072158,1.9537554,,,,,,,,,,,,,, -50500,3.487935,1.8172883,,,,,,,,,,,,,, -50600,4.348412,1.9119321,,,,,,,,,,,,,, -50700,3.9540503,1.8258432,,,,,,,,,,,,,, -50800,3.679827,1.9362118,,,,,,,,,,,,,, -50900,3.1599367,1.9777313,,,,,,,,,,,,,, -51000,3.89897,1.9557068,,,,,,,,,,,,,, -51100,3.7876704,1.9471333,,,,,,,,,,,,,, -51200,3.9561558,1.8107555,,,,,,,,,,,,,, -51300,3.7663743,1.9074768,,,,,,,,,,,,,, -51400,4.2272787,1.7977232,,,,,,,,,,,,,, -51408,,,0.6523237824440002,1.3949483633041382,0.6103799939155579,1.6124364137649536,50000.0,0.4890000224113464,2.2980523109436035,10000.0,17376.387938022614,18000.41622543335,17376.387938022614,620.0699634552002,2.175709724426269,0.0 -51500,3.7672064,1.9346339,,,,,,,,,,,,,, -51600,4.1393285,1.8658532,,,,,,,,,,,,,, -51700,3.8888676,1.9166937,,,,,,,,,,,,,, -51800,3.6840405,1.9350713,,,,,,,,,,,,,, -51900,3.8969247,1.8842947,,,,,,,,,,,,,, -52000,4.926097,1.8855605,,,,,,,,,,,,,, -52100,3.7504761,1.8514923,,,,,,,,,,,,,, -52200,3.8618333,1.838086,,,,,,,,,,,,,, -52300,4.0531063,1.941336,,,,,,,,,,,,,, -52400,3.803714,1.9266434,,,,,,,,,,,,,, -52500,3.8633285,1.9476583,,,,,,,,,,,,,, -52600,4.129245,1.9757969,,,,,,,,,,,,,, -52700,3.586032,1.8764219,,,,,,,,,,,,,, -52800,4.907467,1.8934824,,,,,,,,,,,,,, -52900,4.7456536,1.8497144,,,,,,,,,,,,,, -52923,,,0.6839325428009033,1.2379963397979736,0.6139999628067017,1.5940256118774414,50000.0,0.4906000196933746,2.326728820800781,10000.0,17886.385818719864,18528.079265117645,17886.385818719864,637.6424875259399,2.215566873550415,0.0 -53000,4.027455,1.916977,,,,,,,,,,,,,, -53100,4.2189226,1.8806162,,,,,,,,,,,,,, -53200,3.839038,1.8686234,,,,,,,,,,,,,, -53300,3.9411404,1.7670377,,,,,,,,,,,,,, -53400,3.920788,1.8482156,,,,,,,,,,,,,, -53500,3.9106302,1.7879376,,,,,,,,,,,,,, -53600,3.3874502,1.7293304,,,,,,,,,,,,,, -53700,5.037465,1.7787527,,,,,,,,,,,,,, -53800,3.758829,1.8510944,,,,,,,,,,,,,, -53900,3.9169357,1.8317341,,,,,,,,,,,,,, -54000,3.3729837,1.8674887,,,,,,,,,,,,,, -54100,3.6566353,1.8817954,,,,,,,,,,,,,, -54200,4.187257,1.8204418,,,,,,,,,,,,,, -54300,4.0033054,1.8202089,,,,,,,,,,,,,, -54400,3.4950757,1.893774,,,,,,,,,,,,,, -54438,,,0.6634446382522583,1.333989143371582,0.608020007610321,1.6208534240722656,50000.0,0.4843000173568725,2.368727922439575,10000.0,18396.55052471161,19055.956347703934,18396.55052471161,655.2638325691223,2.253763198852539,0.0 -54500,4.103141,1.7786889,,,,,,,,,,,,,, -54600,3.6021225,1.8745371,,,,,,,,,,,,,, -54700,4.2698607,1.8705676,,,,,,,,,,,,,, -54800,3.4724627,1.8844709,,,,,,,,,,,,,, -54900,3.4019802,1.770718,,,,,,,,,,,,,, -55000,3.995369,1.9098866,,,,,,,,,,,,,, -55100,3.5400398,1.9773972,,,,,,,,,,,,,, -55200,4.282717,1.8757458,,,,,,,,,,,,,, -55300,5.274156,1.8998731,,,,,,,,,,,,,, -55400,3.6817575,1.8069754,,,,,,,,,,,,,, -55500,3.95656,1.8960578,,,,,,,,,,,,,, -55600,4.6568527,1.8045553,,,,,,,,,,,,,, -55700,5.476296,1.9445996,,,,,,,,,,,,,, -55800,4.181935,1.9832008,,,,,,,,,,,,,, -55900,4.679861,1.7579602,,,,,,,,,,,,,, -55953,,,0.6671316623687744,1.318600058555603,0.6162999868392944,1.5899615287780762,50000.0,0.4941000342369079,2.28857421875,10000.0,18906.63244438172,19583.63987517357,18906.63244438172,672.7741053104401,2.2918620109558105,0.0 -56000,4.390674,1.7624955,,,,,,,,,,,,,, -56100,3.664662,1.7469412,,,,,,,,,,,,,, -56200,3.6292846,1.8876317,,,,,,,,,,,,,, -56300,3.926178,1.9303472,,,,,,,,,,,,,, -56400,3.705279,1.7500871,,,,,,,,,,,,,, -56500,3.960178,1.8160689,,,,,,,,,,,,,, -56600,4.476594,1.9084712,,,,,,,,,,,,,, -56700,4.107562,1.771157,,,,,,,,,,,,,, -56800,4.601777,1.8416564,,,,,,,,,,,,,, -56900,3.7083113,1.9265676,,,,,,,,,,,,,, -57000,4.78288,1.8213973,,,,,,,,,,,,,, -57100,3.904496,1.8403347,,,,,,,,,,,,,, -57200,4.366998,1.9241416,,,,,,,,,,,,,, -57300,3.5553815,1.8792837,,,,,,,,,,,,,, -57400,3.7030835,1.7602235,,,,,,,,,,,,,, -57468,,,0.6590401530265808,1.3535881042480469,0.6110000014305115,1.6046415567398071,50000.0,0.4899000227451324,2.360703468322754,10000.0,19416.74183535576,20111.414390802383,19416.74183535576,690.3503816127777,2.3279449939727783,0.0 -57500,4.133106,1.9021591,,,,,,,,,,,,,, -57600,4.4980664,1.8607532,,,,,,,,,,,,,, -57700,3.6939578,1.898387,,,,,,,,,,,,,, -57800,3.254748,1.939979,,,,,,,,,,,,,, -57900,4.004997,1.7112463,,,,,,,,,,,,,, -58000,4.478393,1.9146848,,,,,,,,,,,,,, -58100,4.2267065,1.941003,,,,,,,,,,,,,, -58200,3.208225,1.7657857,,,,,,,,,,,,,, -58300,3.2649171,1.7986268,,,,,,,,,,,,,, -58400,4.145215,1.8709303,,,,,,,,,,,,,, -58500,4.1045732,1.8857409,,,,,,,,,,,,,, -58600,3.7896068,1.8793626,,,,,,,,,,,,,, -58700,4.9168053,1.8062329,,,,,,,,,,,,,, -58800,3.7608728,1.894486,,,,,,,,,,,,,, -58900,4.2134027,1.8478001,,,,,,,,,,,,,, -58983,,,0.6589404940605164,1.362576246261597,0.6136400103569031,1.60284686088562,50000.0,0.4925000369548797,2.331923246383667,10000.0,19926.887871027,20639.101494312286,19926.887871027,707.7968149185181,2.370288133621216,0.0 -59000,4.6371536,1.8648058,,,,,,,,,,,,,, -59100,3.8295732,1.9409971,,,,,,,,,,,,,, -59200,4.663275,2.025329,,,,,,,,,,,,,, -59300,4.1795506,1.9146174,,,,,,,,,,,,,, -59400,3.8051398,1.9951303,,,,,,,,,,,,,, -59500,4.916174,1.7747741,,,,,,,,,,,,,, -59600,5.0946393,1.7922773,,,,,,,,,,,,,, -59700,4.0282393,1.9021745,,,,,,,,,,,,,, -59800,3.9687507,1.9427549,,,,,,,,,,,,,, -59900,4.124355,1.7792833,,,,,,,,,,,,,, -60000,5.2244,1.8177843,,,,,,,,,,,,,, -60100,3.4481483,1.9044063,,,,,,,,,,,,,, -60200,4.19261,1.7116612,,,,,,,,,,,,,, -60300,3.8889163,1.920566,,,,,,,,,,,,,, -60400,3.4934783,1.7927454,,,,,,,,,,,,,, -60498,,,0.6735491156578064,1.2887380123138428,0.620639979839325,1.5495282411575315,50000.0,0.4991000294685364,2.270559787750244,10000.0,20436.94685316085,21167.427884340286,20436.94685316085,725.9643821716309,2.4173500537872314,0.0 -60500,3.9638226,1.7387565,,,,,,,,,,,,,, -60600,3.8979187,1.8062327,,,,,,,,,,,,,, -60700,3.6008422,1.6525575,,,,,,,,,,,,,, -60800,4.197592,1.9040902,,,,,,,,,,,,,, -60900,3.8341935,1.7278676,,,,,,,,,,,,,, -61000,4.5849333,1.9090383,,,,,,,,,,,,,, -61100,4.2310567,1.9079801,,,,,,,,,,,,,, -61200,4.3929715,1.7049309,,,,,,,,,,,,,, -61300,3.9664726,1.7433504,,,,,,,,,,,,,, -61400,4.386988,1.807743,,,,,,,,,,,,,, -61500,4.3294435,1.8090047,,,,,,,,,,,,,, -61600,4.141832,1.7845268,,,,,,,,,,,,,, -61700,4.071753,1.738908,,,,,,,,,,,,,, -61800,3.6307924,1.838168,,,,,,,,,,,,,, -61900,4.097953,1.829127,,,,,,,,,,,,,, -62000,3.8401434,1.9186407,,,,,,,,,,,,,, -62014,,,0.6741669178009033,1.2919275760650637,0.6080399751663208,1.623947024345398,50000.0,0.4821000099182129,2.346834421157837,10000.0,20947.05204677581,21695.2219684124,20947.05204677581,743.5703382492065,2.4481093883514404,0.0 -62100,3.9308171,1.7488307,,,,,,,,,,,,,, -62200,4.976371,1.8162555,,,,,,,,,,,,,, -62300,4.219713,1.7301195,,,,,,,,,,,,,, -62400,5.10079,1.814042,,,,,,,,,,,,,, -62500,4.523494,1.8048222,,,,,,,,,,,,,, -62600,3.9651911,1.8071234,,,,,,,,,,,,,, -62700,4.997513,1.9036695,,,,,,,,,,,,,, -62800,3.3941433,1.7079608,,,,,,,,,,,,,, -62900,4.569104,1.8881693,,,,,,,,,,,,,, -63000,4.6084404,1.7461771,,,,,,,,,,,,,, -63100,4.1116414,1.8472364,,,,,,,,,,,,,, -63200,3.7439804,1.7460458,,,,,,,,,,,,,, -63300,3.7646377,1.8494351,,,,,,,,,,,,,, -63400,5.5215044,1.8357717,,,,,,,,,,,,,, -63500,3.803768,1.6994774,,,,,,,,,,,,,, -63529,,,0.6745057106018066,1.275795817375183,0.6182000041007996,1.5899293422698977,50000.0,0.4971000254154205,2.292905330657959,10000.0,21457.020793676376,22222.684602499008,21457.020793676376,760.9726111888885,2.4869983196258545,0.0 -63600,4.131134,1.8859401,,,,,,,,,,,,,, -63700,4.0297313,1.8362391,,,,,,,,,,,,,, -63800,3.7394493,1.8422554,,,,,,,,,,,,,, -63900,3.759087,1.8037266,,,,,,,,,,,,,, -64000,3.9279845,1.769879,,,,,,,,,,,,,, -64100,4.1976533,1.7663105,,,,,,,,,,,,,, -64200,4.0473695,1.6753767,,,,,,,,,,,,,, -64300,3.840741,1.8638039,,,,,,,,,,,,,, -64400,4.568819,1.835506,,,,,,,,,,,,,, -64500,4.007121,1.9376044,,,,,,,,,,,,,, -64600,3.9370675,1.7702589,,,,,,,,,,,,,, -64700,4.146255,1.7692102,,,,,,,,,,,,,, -64800,4.164078,1.8504063,,,,,,,,,,,,,, -64900,3.5907614,1.7685125,,,,,,,,,,,,,, -65000,4.181296,1.7085242,,,,,,,,,,,,,, -65044,,,0.6870814561843872,1.2384870052337646,0.6372599601745605,1.5065011978149414,50000.0,0.5083000063896179,2.2321155071258545,10000.0,21966.95255088806,22750.30890488625,21966.95255088806,778.5667836666107,2.531611204147339,0.0 -65100,3.9936845,1.7639877,,,,,,,,,,,,,, -65200,3.8344402,1.7625966,,,,,,,,,,,,,, -65300,3.6295705,1.7025485,,,,,,,,,,,,,, -65400,4.0880327,1.7400773,,,,,,,,,,,,,, -65500,3.7041187,1.8758776,,,,,,,,,,,,,, -65600,4.7105665,1.689724,,,,,,,,,,,,,, -65700,4.180933,1.7771343,,,,,,,,,,,,,, -65800,4.244617,1.8114583,,,,,,,,,,,,,, -65900,3.930593,1.7551085,,,,,,,,,,,,,, -66000,4.321261,1.9228816,,,,,,,,,,,,,, -66100,3.956782,1.7866559,,,,,,,,,,,,,, -66200,4.3483877,1.7940617,,,,,,,,,,,,,, -66300,4.285784,1.7037108,,,,,,,,,,,,,, -66400,3.897215,1.7805698,,,,,,,,,,,,,, -66500,3.7400563,1.7270721,,,,,,,,,,,,,, -66558,,,0.6863042116165161,1.240123271942139,0.6317600011825562,1.520139217376709,50000.0,0.5033000111579895,2.2496683597564697,10000.0,22476.95581459999,23277.84796833992,22476.95581459999,796.0130662918091,2.5694398880004883,0.0 -66600,4.2557015,1.8049226,,,,,,,,,,,,,, -66700,4.015452,1.8028785,,,,,,,,,,,,,, -66800,4.1540327,1.6778071,,,,,,,,,,,,,, -66900,3.7712429,1.7305064,,,,,,,,,,,,,, -67000,4.17748,1.8358662,,,,,,,,,,,,,, -67100,3.6857376,1.7517333,,,,,,,,,,,,,, -67200,4.516511,1.719603,,,,,,,,,,,,,, -67300,4.2353473,1.8401635,,,,,,,,,,,,,, -67400,4.036757,1.806247,,,,,,,,,,,,,, -67500,4.7294545,1.8026141,,,,,,,,,,,,,, -67600,4.0075736,1.7863039,,,,,,,,,,,,,, -67700,4.81793,1.9030731,,,,,,,,,,,,,, -67800,4.0985055,1.7720244,,,,,,,,,,,,,, -67900,4.000658,1.8327457,,,,,,,,,,,,,, -68000,4.3990116,1.7384031,,,,,,,,,,,,,, -68073,,,0.675203263759613,1.27826189994812,0.6322000026702881,1.5258268117904663,50000.0,0.5074000358581543,2.2572550773620605,10000.0,22986.95196557045,23805.82973885536,22986.95196557045,813.9019508361816,2.613524913787842,0.0 -68100,4.5861034,1.9252465,,,,,,,,,,,,,, -68200,4.101781,1.7434969,,,,,,,,,,,,,, -68300,4.0806494,1.8844006,,,,,,,,,,,,,, -68400,3.872168,1.7540979,,,,,,,,,,,,,, -68500,4.0620823,1.8208603,,,,,,,,,,,,,, -68600,3.6030247,1.6436169,,,,,,,,,,,,,, -68700,4.60612,1.6980487,,,,,,,,,,,,,, -68800,4.912476,1.7657961,,,,,,,,,,,,,, -68900,4.7533464,1.9336165,,,,,,,,,,,,,, -69000,3.6032245,1.8101566,,,,,,,,,,,,,, -69100,4.127538,1.8842943,,,,,,,,,,,,,, -69200,4.1397777,1.8233411,,,,,,,,,,,,,, -69300,4.439967,1.8149073,,,,,,,,,,,,,, -69400,4.050889,1.7486035,,,,,,,,,,,,,, -69500,4.633551,1.757205,,,,,,,,,,,,,, -69588,,,0.6990393400192261,1.1675734519958496,0.6238600015640259,1.5474979877471924,50000.0,0.4936000108718872,2.2800233364105225,10000.0,23496.924777507786,24333.650037050247,23496.924777507786,831.6630780696869,2.6471753120422363,0.0 -69600,4.6216965,1.7257001,,,,,,,,,,,,,, -69700,4.24884,1.7122296,,,,,,,,,,,,,, -69800,3.6340413,1.7759111,,,,,,,,,,,,,, -69900,4.5071073,1.7407091,,,,,,,,,,,,,, -70000,3.978756,1.7067959,,,,,,,,,,,,,, -70100,3.5856333,1.7329342,,,,,,,,,,,,,, -70200,4.0219016,1.82317,,,,,,,,,,,,,, -70300,4.1924367,1.871743,,,,,,,,,,,,,, -70400,4.19105,1.7959721,,,,,,,,,,,,,, -70500,4.0348735,1.8207364,,,,,,,,,,,,,, -70600,3.8114767,1.7833023,,,,,,,,,,,,,, -70700,4.6220417,1.8090261,,,,,,,,,,,,,, -70800,4.0358458,1.8001503,,,,,,,,,,,,,, -70900,4.2645698,1.765625,,,,,,,,,,,,,, -71000,3.7406723,1.7752984,,,,,,,,,,,,,, -71100,3.6479485,1.695277,,,,,,,,,,,,,, -71102,,,0.70023512840271,1.153869867324829,0.632420003414154,1.5144561529159546,50000.0,0.5120000243186951,2.2258431911468506,10000.0,24006.85396337509,24861.34249138832,24006.85396337509,849.3363344669342,2.685119152069092,0.0 -71200,3.9509673,1.7678226,,,,,,,,,,,,,, -71300,4.370953,1.7254121,,,,,,,,,,,,,, -71400,4.790479,1.8058044,,,,,,,,,,,,,, -71500,4.281839,1.7262571,,,,,,,,,,,,,, -71600,4.1051183,1.8503165,,,,,,,,,,,,,, -71700,3.8882034,1.7560195,,,,,,,,,,,,,, -71800,4.246412,1.8316412,,,,,,,,,,,,,, -71900,3.9829268,1.6761475,,,,,,,,,,,,,, -72000,4.1993175,1.7684333,,,,,,,,,,,,,, -72100,4.8905654,1.6987652,,,,,,,,,,,,,, -72200,4.2918406,1.7328396,,,,,,,,,,,,,, -72300,4.102774,1.8221744,,,,,,,,,,,,,, -72400,4.838421,1.7896619,,,,,,,,,,,,,, -72500,3.6503868,1.7951138,,,,,,,,,,,,,, -72600,4.3920197,1.8558178,,,,,,,,,,,,,, -72617,,,0.6888153553009033,1.204163670539856,0.6315199732780457,1.518242120742798,50000.0,0.5139000415802002,2.235292196273804,10000.0,24516.80539011956,25388.890427351,24516.80539011956,866.8355195522308,2.729054927825928,0.0 -72700,4.6321406,1.919986,,,,,,,,,,,,,, -72800,3.9802122,1.8481961,,,,,,,,,,,,,, -72900,3.773278,1.7742493,,,,,,,,,,,,,, -73000,3.826972,1.736831,,,,,,,,,,,,,, -73100,4.7379246,1.8491185,,,,,,,,,,,,,, -73200,3.862887,1.7733805,,,,,,,,,,,,,, -73300,4.134255,1.8693342,,,,,,,,,,,,,, -73400,5.0807486,1.7156621,,,,,,,,,,,,,, -73500,4.5181484,1.7970049,,,,,,,,,,,,,, -73600,4.279305,1.7437372,,,,,,,,,,,,,, -73700,4.134973,1.8442382,,,,,,,,,,,,,, -73800,4.155413,1.8540413,,,,,,,,,,,,,, -73900,4.489659,1.8104749,,,,,,,,,,,,,, -74000,3.597271,1.7328093,,,,,,,,,,,,,, -74100,3.7580416,1.8920479,,,,,,,,,,,,,, -74132,,,0.6875796914100647,1.2316724061965942,0.6329799890518188,1.49374258518219,50000.0,0.5053000450134277,2.2189886569976807,10000.0,25026.766946792603,25916.44958114624,25026.766946792603,884.3387162685394,2.771080732345581,0.0 -74200,3.9901304,1.7219138,,,,,,,,,,,,,, -74300,4.780741,1.7788525,,,,,,,,,,,,,, -74400,3.9080794,1.6175363,,,,,,,,,,,,,, -74500,5.2540293,1.7199864,,,,,,,,,,,,,, -74600,4.3577747,1.7853041,,,,,,,,,,,,,, -74700,4.553264,1.8215338,,,,,,,,,,,,,, -74800,4.1737995,1.8556765,,,,,,,,,,,,,, -74900,4.0941744,1.7305133,,,,,,,,,,,,,, -75000,4.295012,1.8875922,,,,,,,,,,,,,, -75100,4.134335,1.8209054,,,,,,,,,,,,,, -75200,4.3234053,1.6812078,,,,,,,,,,,,,, -75300,4.8236523,1.6992184,,,,,,,,,,,,,, -75400,4.017221,1.7082436,,,,,,,,,,,,,, -75500,4.2423797,1.8907698,,,,,,,,,,,,,, -75600,4.698178,1.8393183,,,,,,,,,,,,,, -75647,,,0.6935586333274841,1.2051820755004885,0.6389399766921997,1.4845311641693115,50000.0,0.5136000514030457,2.221004009246826,10000.0,25536.77026438713,26444.19121265412,25536.77026438713,901.9840440750122,2.811651945114136,0.0 -75700,3.953083,1.8141596,,,,,,,,,,,,,, -75800,6.1671586,1.7571808,,,,,,,,,,,,,, -75900,4.108558,1.672506,,,,,,,,,,,,,, -76000,4.782527,1.8082867,,,,,,,,,,,,,, -76100,4.007048,1.7288464,,,,,,,,,,,,,, -76200,4.1063766,1.9068534,,,,,,,,,,,,,, -76300,4.2790284,1.7203151,,,,,,,,,,,,,, -76400,5.554087,1.7671585,,,,,,,,,,,,,, -76500,4.232011,1.7932433,,,,,,,,,,,,,, -76600,4.689271,1.7468733,,,,,,,,,,,,,, -76700,3.5824125,1.7008967,,,,,,,,,,,,,, -76800,4.3213816,1.7776821,,,,,,,,,,,,,, -76900,3.9944932,1.7078578,,,,,,,,,,,,,, -77000,4.0585933,1.7431263,,,,,,,,,,,,,, -77100,4.0446253,1.8014872,,,,,,,,,,,,,, -77163,,,0.6889349222183228,1.2334457635879517,0.6387199759483337,1.485427737236023,50000.0,0.5131000280380249,2.2043116092681885,10000.0,26046.95576357841,26972.225853919983,26046.95576357841,919.7418501377106,2.8506574630737305,0.0 -77200,3.9870975,1.7115747,,,,,,,,,,,,,, -77300,4.5573244,1.8518823,,,,,,,,,,,,,, -77400,4.480951,1.743314,,,,,,,,,,,,,, -77500,4.159228,1.7235934,,,,,,,,,,,,,, -77600,4.090683,1.8333311,,,,,,,,,,,,,, -77700,4.372147,1.7149233,,,,,,,,,,,,,, -77800,3.6692326,1.6492438,,,,,,,,,,,,,, -77900,4.0589976,1.7313002,,,,,,,,,,,,,, -78000,4.0402064,1.6650933,,,,,,,,,,,,,, -78100,4.1390915,1.7890872,,,,,,,,,,,,,, -78200,4.135429,1.7926451,,,,,,,,,,,,,, -78300,3.7765644,1.6747794,,,,,,,,,,,,,, -78400,4.1202893,1.751561,,,,,,,,,,,,,, -78500,4.1796417,1.7425377,,,,,,,,,,,,,, -78600,4.608006,1.7052736,,,,,,,,,,,,,, -78679,,,0.721121609210968,1.0787429809570312,0.6339199542999268,1.5018742084503174,50000.0,0.5126000046730042,2.209225654602051,10000.0,26557.14439845085,27500.150496721268,26557.14439845085,937.3784308433532,2.897936582565308,0.0 -78700,4.3491435,1.721188,,,,,,,,,,,,,, -78800,4.727346,1.7840178,,,,,,,,,,,,,, -78900,4.1379642,1.7946054,,,,,,,,,,,,,, -79000,4.526201,1.809514,,,,,,,,,,,,,, -79100,3.6492653,1.6426611,,,,,,,,,,,,,, -79200,4.0807858,1.623838,,,,,,,,,,,,,, -79300,4.5242286,1.9240513,,,,,,,,,,,,,, -79400,4.082387,1.7181511,,,,,,,,,,,,,, -79500,4.133951,1.5926363,,,,,,,,,,,,,, -79600,3.4867756,1.5938176,,,,,,,,,,,,,, -79700,4.1991115,1.6876446,,,,,,,,,,,,,, -79800,4.2071743,1.7268579,,,,,,,,,,,,,, -79900,3.6719754,1.7358342,,,,,,,,,,,,,, -80000,3.8885276,1.6968691,,,,,,,,,,,,,, -80100,3.7046247,1.7421675,,,,,,,,,,,,,, -80195,,,0.7099210619926453,1.1163970232009888,0.6413599848747253,1.4619232416152954,50000.0,0.5139999985694885,2.207616090774536,10000.0,27067.337899446487,28027.980364322662,27067.337899446487,954.9220995903016,2.937781572341919,0.0 -80200,4.2243648,1.695462,,,,,,,,,,,,,, -80300,4.5144234,1.6929647,,,,,,,,,,,,,, -80400,5.367367,1.8099368,,,,,,,,,,,,,, -80500,4.2484717,1.774065,,,,,,,,,,,,,, -80600,4.4136662,1.715747,,,,,,,,,,,,,, -80700,4.3169394,1.6978967,,,,,,,,,,,,,, -80800,4.9916058,1.8078687,,,,,,,,,,,,,, -80900,4.4128194,1.7086539,,,,,,,,,,,,,, -81000,4.134921,1.7928455,,,,,,,,,,,,,, -81100,5.311638,1.7973602,,,,,,,,,,,,,, -81200,4.529269,1.756502,,,,,,,,,,,,,, -81300,6.5838366,1.7605592,,,,,,,,,,,,,, -81400,3.9619977,1.8155217,,,,,,,,,,,,,, -81500,4.6946464,1.6932974,,,,,,,,,,,,,, -81600,5.0495105,1.6521877,,,,,,,,,,,,,, -81700,4.1517396,1.8193139,,,,,,,,,,,,,, -81710,,,0.7057955861091614,1.155988097190857,0.6452400088310242,1.4618548154830933,50000.0,0.5199000239372253,2.1657211780548096,10000.0,27577.38576722145,28555.9904255867,27577.38576722145,972.7897419929504,2.9804928302764893,0.0 -81800,4.5865064,1.6723084,,,,,,,,,,,,,, -81900,4.179579,1.6542788,,,,,,,,,,,,,, -82000,4.6312394,1.6319457,,,,,,,,,,,,,, -82100,4.6958294,1.7859209,,,,,,,,,,,,,, -82200,4.6063714,1.7240536,,,,,,,,,,,,,, -82300,4.067366,1.765351,,,,,,,,,,,,,, -82400,4.8067565,1.6360333,,,,,,,,,,,,,, -82500,3.8183966,1.661609,,,,,,,,,,,,,, -82600,4.1495886,1.7881551,,,,,,,,,,,,,, -82700,3.9601855,1.7016653,,,,,,,,,,,,,, -82800,4.1001453,1.8110179,,,,,,,,,,,,,, -82900,4.7408013,1.6459951,,,,,,,,,,,,,, -83000,4.4149113,1.8098141,,,,,,,,,,,,,, -83100,4.302022,1.76687,,,,,,,,,,,,,, -83200,4.453429,1.6144058,,,,,,,,,,,,,, -83226,,,0.6995774507522583,1.173744559288025,0.6433599591255188,1.4629638195037842,50000.0,0.5162000060081482,2.1647984981536865,10000.0,28087.58653569221,29084.35788321495,28087.58653569221,990.8612344264984,3.0235769748687744,0.0 -83300,4.157793,1.6840304,,,,,,,,,,,,,, -83400,3.861226,1.6656466,,,,,,,,,,,,,, -83500,4.5002465,1.6803777,,,,,,,,,,,,,, -83600,3.891224,1.7457825,,,,,,,,,,,,,, -83700,4.6930804,1.7117342,,,,,,,,,,,,,, -83800,4.347626,1.7572398,,,,,,,,,,,,,, -83900,4.426829,1.6671665,,,,,,,,,,,,,, -84000,3.8715885,1.8095317,,,,,,,,,,,,,, -84100,4.3542447,1.8194407,,,,,,,,,,,,,, -84200,4.429485,1.6367867,,,,,,,,,,,,,, -84300,3.6421158,1.6503488,,,,,,,,,,,,,, -84400,4.166665,1.6944109,,,,,,,,,,,,,, -84500,3.9042513,1.6884695,,,,,,,,,,,,,, -84600,4.1445603,1.6588703,,,,,,,,,,,,,, -84700,4.5397,1.7079862,,,,,,,,,,,,,, -84741,,,0.6940369606018066,1.209545612335205,0.6385200023651123,1.4816385507583618,50000.0,0.5128000378608704,2.198549032211304,10000.0,28597.68378353119,29612.305801153183,28597.68378353119,1008.616588830948,3.065901756286621,0.0 -84800,4.502204,1.6673821,,,,,,,,,,,,,, -84900,4.6543264,1.7529213,,,,,,,,,,,,,, -85000,3.8866477,1.758851,,,,,,,,,,,,,, -85100,4.4495287,1.7199162,,,,,,,,,,,,,, -85200,5.819409,1.7146174,,,,,,,,,,,,,, -85300,4.0774126,1.7005222,,,,,,,,,,,,,, -85400,3.9485855,1.6545625,,,,,,,,,,,,,, -85500,5.711665,1.8006189,,,,,,,,,,,,,, -85600,5.440627,1.7364641,,,,,,,,,,,,,, -85700,4.402862,1.7378637,,,,,,,,,,,,,, -85800,4.411391,1.6831362,,,,,,,,,,,,,, -85900,4.6607265,1.6500937,,,,,,,,,,,,,, -86000,4.065704,1.7152267,,,,,,,,,,,,,, -86100,4.545862,1.7928475,,,,,,,,,,,,,, -86200,4.982448,1.6937978,,,,,,,,,,,,,, -86253,,,0.6893335580825806,1.2253234386444092,0.6370799541473389,1.4862534999847412,50000.0,0.5186000466346741,2.212566614151001,10000.0,29106.643271446228,30140.24239969253,29106.643271446228,1026.5237970352173,4.082794904708862,0.0 -86300,4.4790425,1.7528694,,,,,,,,,,,,,, -86400,4.8842716,1.7680203,,,,,,,,,,,,,, -86500,4.634054,1.7256687,,,,,,,,,,,,,, -86600,3.9590695,1.617015,,,,,,,,,,,,,, -86700,4.625004,1.6700335,,,,,,,,,,,,,, -86800,4.425586,1.7479794,,,,,,,,,,,,,, -86900,4.1791954,1.7034929,,,,,,,,,,,,,, -87000,4.30901,1.6219294,,,,,,,,,,,,,, -87100,4.3876553,1.8406695,,,,,,,,,,,,,, -87200,3.8911254,1.6542387,,,,,,,,,,,,,, -87300,3.841505,1.6632335,,,,,,,,,,,,,, -87400,4.0029664,1.6159382,,,,,,,,,,,,,, -87500,4.198643,1.6888306,,,,,,,,,,,,,, -87600,4.4802594,1.6906549,,,,,,,,,,,,,, -87700,4.2199526,1.7489655,,,,,,,,,,,,,, -87768,,,0.7394770383834839,1.004119873046875,0.6474599838256836,1.4325826168060305,50000.0,0.5189000368118286,2.1587750911712646,10000.0,29616.63855457306,30668.031265974045,29616.63855457306,1044.2222499847412,4.124652862548828,0.0 -87800,4.343962,1.7927539,,,,,,,,,,,,,, -87900,4.6265006,1.6767217,,,,,,,,,,,,,, -88000,4.5499873,1.6978974,,,,,,,,,,,,,, -88100,3.8862638,1.6259203,,,,,,,,,,,,,, -88200,4.4622245,1.7752863,,,,,,,,,,,,,, -88300,3.8671083,1.6769137,,,,,,,,,,,,,, -88400,4.6614895,1.6220957,,,,,,,,,,,,,, -88500,4.6269045,1.6481442,,,,,,,,,,,,,, -88600,4.38467,1.6403813,,,,,,,,,,,,,, -88700,4.5605483,1.7027454,,,,,,,,,,,,,, -88800,4.570484,1.6835933,,,,,,,,,,,,,, -88900,4.173081,1.7209035,,,,,,,,,,,,,, -89000,4.4368753,1.7337141,,,,,,,,,,,,,, -89100,4.8977246,1.601924,,,,,,,,,,,,,, -89200,5.0216713,1.6391029,,,,,,,,,,,,,, -89283,,,0.7106584906578064,1.1320785284042358,0.644819974899292,1.4565682411193848,50000.0,0.5223000049591064,2.1549034118652344,10000.0,30126.547045707703,31195.43574333191,30126.547045707703,1061.6245312690735,4.165873289108276,0.0 -89300,4.312597,1.5772429,,,,,,,,,,,,,, -89400,4.9986696,1.6555331,,,,,,,,,,,,,, -89500,4.3355217,1.6637709,,,,,,,,,,,,,, -89600,4.440582,1.6639855,,,,,,,,,,,,,, -89700,3.830895,1.7324742,,,,,,,,,,,,,, -89800,5.130396,1.7082468,,,,,,,,,,,,,, -89900,4.685493,1.6483678,,,,,,,,,,,,,, -90000,4.65596,1.628368,,,,,,,,,,,,,, -90100,4.151954,1.70947,,,,,,,,,,,,,, -90200,4.5647616,1.6708479,,,,,,,,,,,,,, -90300,4.538033,1.7658606,,,,,,,,,,,,,, -90400,4.3314176,1.761253,,,,,,,,,,,,,, -90500,5.86347,1.7506982,,,,,,,,,,,,,, -90600,4.2023854,1.5700082,,,,,,,,,,,,,, -90700,5.395527,1.6556063,,,,,,,,,,,,,, -90798,,,0.7100805044174194,1.1253926753997805,0.6494799852371216,1.432708501815796,50000.0,0.5230000019073486,2.135982751846313,10000.0,30636.56070494652,31723.081677913666,30636.56070494652,1079.1629874706268,4.2079079151153564,0.0 -90800,4.367517,1.7341723,,,,,,,,,,,,,, -90900,5.162245,1.6519158,,,,,,,,,,,,,, -91000,4.56275,1.604879,,,,,,,,,,,,,, -91100,4.237395,1.5767556,,,,,,,,,,,,,, -91200,4.8227253,1.6699502,,,,,,,,,,,,,, -91300,5.846114,1.6443926,,,,,,,,,,,,,, -91400,4.4917107,1.4812213,,,,,,,,,,,,,, -91500,4.684382,1.728674,,,,,,,,,,,,,, -91600,4.47327,1.6657887,,,,,,,,,,,,,, -91700,4.33524,1.8446158,,,,,,,,,,,,,, -91800,5.6307874,1.6521124,,,,,,,,,,,,,, -91900,4.624919,1.5638628,,,,,,,,,,,,,, -92000,4.6977477,1.5752048,,,,,,,,,,,,,, -92100,4.3482976,1.5694726,,,,,,,,,,,,,, -92200,4.5764475,1.55372,,,,,,,,,,,,,, -92300,5.334767,1.518022,,,,,,,,,,,,,, -92314,,,0.7146045565605164,1.1162675619125366,0.6528800129890442,1.4155189990997314,50000.0,0.5293000340461731,2.1161322593688965,10000.0,31146.7821393013,32250.916907072067,31146.7821393013,1096.6802098751068,4.251695394515991,0.0 -92400,4.8403897,1.6228809,,,,,,,,,,,,,, -92500,3.9356773,1.6037706,,,,,,,,,,,,,, -92600,4.697884,1.631914,,,,,,,,,,,,,, -92700,4.640671,1.6445222,,,,,,,,,,,,,, -92800,4.3799033,1.707154,,,,,,,,,,,,,, -92900,4.6967983,1.7089436,,,,,,,,,,,,,, -93000,4.250337,1.6086986,,,,,,,,,,,,,, -93100,4.4322963,1.7479272,,,,,,,,,,,,,, -93200,3.9728315,1.578401,,,,,,,,,,,,,, -93300,4.767332,1.6840034,,,,,,,,,,,,,, -93400,4.438678,1.7008978,,,,,,,,,,,,,, -93500,4.609376,1.6362318,,,,,,,,,,,,,, -93600,4.8099856,1.6247523,,,,,,,,,,,,,, -93700,5.3862405,1.6987021,,,,,,,,,,,,,, -93800,5.6660166,1.6007487,,,,,,,,,,,,,, -93829,,,0.7053571343421936,1.1537963151931765,0.649679958820343,1.422189474105835,50000.0,0.5184000134468079,2.1597468852996826,10000.0,31656.7867538929,32778.40650200844,31656.7867538929,1114.0649182796478,4.299462080001831,0.0 -93900,4.5273914,1.7760332,,,,,,,,,,,,,, -94000,3.8637018,1.5815148,,,,,,,,,,,,,, -94100,4.7334833,1.6734135,,,,,,,,,,,,,, -94200,4.623351,1.6568505,,,,,,,,,,,,,, -94300,4.513364,1.7401819,,,,,,,,,,,,,, -94400,5.021273,1.648349,,,,,,,,,,,,,, -94500,4.0645514,1.6044742,,,,,,,,,,,,,, -94600,4.7304344,1.5955125,,,,,,,,,,,,,, -94700,4.665809,1.6316531,,,,,,,,,,,,,, -94800,5.209061,1.6136031,,,,,,,,,,,,,, -94900,4.5840583,1.5940615,,,,,,,,,,,,,, -95000,4.87153,1.6977613,,,,,,,,,,,,,, -95100,4.742666,1.6067733,,,,,,,,,,,,,, -95200,4.580256,1.6053702,,,,,,,,,,,,,, -95300,5.0658793,1.755563,,,,,,,,,,,,,, -95344,,,0.7147639989852905,1.101299524307251,0.6574400067329407,1.3877466917037964,50000.0,0.5294000506401062,2.091960906982422,10000.0,32166.745491981503,33305.81133413315,32166.745491981503,1131.4108610153198,4.345837831497192,0.0 -95400,5.0657034,1.4977846,,,,,,,,,,,,,, -95500,4.1122546,1.5779728,,,,,,,,,,,,,, -95600,4.2076187,1.5420399,,,,,,,,,,,,,, -95700,4.354445,1.5969255,,,,,,,,,,,,,, -95800,4.773924,1.6904567,,,,,,,,,,,,,, -95900,4.1901274,1.640655,,,,,,,,,,,,,, -96000,5.360098,1.5513207,,,,,,,,,,,,,, -96100,4.0580206,1.6333858,,,,,,,,,,,,,, -96200,4.4244933,1.592581,,,,,,,,,,,,,, -96300,6.844723,1.6805004,,,,,,,,,,,,,, -96400,5.041364,1.7173679,,,,,,,,,,,,,, -96500,4.81256,1.6145642,,,,,,,,,,,,,, -96600,5.3089557,1.6858578,,,,,,,,,,,,,, -96700,4.556672,1.5473754,,,,,,,,,,,,,, -96800,4.511097,1.7305684,,,,,,,,,,,,,, -96860,,,0.7434829473495483,0.9642866849899292,0.6609599590301514,1.386073112487793,50000.0,0.5304000377655029,2.108454704284668,10000.0,32676.94704413414,33833.68562602997,32676.94704413414,1148.986226797104,4.391420841217041,0.0 -96900,5.0251565,1.7184784,,,,,,,,,,,,,, -97000,5.211226,1.6520448,,,,,,,,,,,,,, -97100,5.5431676,1.5574863,,,,,,,,,,,,,, -97200,4.4579606,1.6244903,,,,,,,,,,,,,, -97300,4.8078012,1.651056,,,,,,,,,,,,,, -97400,4.339105,1.5979894,,,,,,,,,,,,,, -97500,5.0906496,1.5354474,,,,,,,,,,,,,, -97600,4.559535,1.593037,,,,,,,,,,,,,, -97700,4.677559,1.6374767,,,,,,,,,,,,,, -97800,4.8973703,1.7393043,,,,,,,,,,,,,, -97900,5.516732,1.6083907,,,,,,,,,,,,,, -98000,5.5491357,1.5881025,,,,,,,,,,,,,, -98100,4.8300567,1.599838,,,,,,,,,,,,,, -98200,4.004778,1.5667714,,,,,,,,,,,,,, -98300,4.5047655,1.595222,,,,,,,,,,,,,, -98375,,,0.7309271097183228,1.028484344482422,0.6595799922943115,1.3712352514266968,50000.0,0.5261000394821167,2.065514326095581,10000.0,33186.945125579834,34361.40209579468,33186.945125579834,1166.6077721118927,4.435008049011231,0.0 -98400,4.632729,1.500219,,,,,,,,,,,,,, -98500,4.9440317,1.6093727,,,,,,,,,,,,,, -98600,5.373897,1.6306677,,,,,,,,,,,,,, -98700,4.8535123,1.5821232,,,,,,,,,,,,,, -98800,5.029027,1.5085076,,,,,,,,,,,,,, -98900,5.822056,1.6073204,,,,,,,,,,,,,, -99000,4.895613,1.6318394,,,,,,,,,,,,,, -99100,4.0447273,1.553688,,,,,,,,,,,,,, -99200,5.6563134,1.6312625,,,,,,,,,,,,,, -99300,5.64953,1.6027462,,,,,,,,,,,,,, -99400,5.214446,1.6841996,,,,,,,,,,,,,, -99500,4.593749,1.5774696,,,,,,,,,,,,,, -99600,5.1365104,1.5959718,,,,,,,,,,,,,, -99700,4.6581793,1.5272381,,,,,,,,,,,,,, -99800,4.731207,1.5988529,,,,,,,,,,,,,, -99890,,,0.7154615521430969,1.0986077785491943,0.6500999927520752,1.425310730934143,50000.0,0.5289000272750854,2.132866382598877,10000.0,33696.84750413895,34888.904074430466,33696.84750413895,1184.107824802399,4.481401681900024,0.0 -99900,4.570705,1.6844645,,,,,,,,,,,,,, -100000,4.813002,1.6742735,,,,,,,,,,,,,, -100100,4.811169,1.6494902,,,,,,,,,,,,,, -100200,4.7494836,1.4780123,,,,,,,,,,,,,, -100300,4.309312,1.6651795,,,,,,,,,,,,,, -100400,5.9492617,1.5952463,,,,,,,,,,,,,, -100500,5.0602984,1.6624984,,,,,,,,,,,,,, -100600,4.817809,1.6252316,,,,,,,,,,,,,, -100700,4.6709065,1.6439173,,,,,,,,,,,,,, -100800,5.2041554,1.5453638,,,,,,,,,,,,,, -100900,4.9327407,1.548666,,,,,,,,,,,,,, -101000,4.9280996,1.552825,,,,,,,,,,,,,, -101100,4.607864,1.5941579,,,,,,,,,,,,,, -101200,4.4344807,1.7016299,,,,,,,,,,,,,, -101300,5.1309166,1.5871435,,,,,,,,,,,,,, -101400,4.3219886,1.6650227,,,,,,,,,,,,,, -101406,,,0.7203643321990967,1.0835435390472412,0.6618399620056152,1.3780913352966309,50000.0,0.525700032711029,2.14704966545105,10000.0,34207.06012392044,35416.624915361404,34207.06012392044,1201.5192544460297,4.524857044219971,0.0 -101500,4.7607746,1.5396061,,,,,,,,,,,,,, -101600,4.8540626,1.615374,,,,,,,,,,,,,, -101700,5.490744,1.5834167,,,,,,,,,,,,,, -101800,5.0430946,1.5852071,,,,,,,,,,,,,, -101900,4.50483,1.4850892,,,,,,,,,,,,,, -102000,5.0556793,1.6547786,,,,,,,,,,,,,, -102100,4.3814387,1.6075398,,,,,,,,,,,,,, -102200,4.596679,1.5744547,,,,,,,,,,,,,, -102300,5.454852,1.649982,,,,,,,,,,,,,, -102400,5.4395657,1.6726518,,,,,,,,,,,,,, -102500,4.8452444,1.5905461,,,,,,,,,,,,,, -102600,6.3318563,1.5511978,,,,,,,,,,,,,, -102700,6.1420364,1.6048267,,,,,,,,,,,,,, -102800,4.3929896,1.5642446,,,,,,,,,,,,,, -102900,4.80949,1.4261746,,,,,,,,,,,,,, -102922,,,0.7308474183082581,1.019730567932129,0.6702199578285217,1.3424674272537231,50000.0,0.5365000367164612,2.069563865661621,10000.0,34717.284263134,35944.59972167015,34717.284263134,1219.1642200946808,4.576931715011597,0.0 -103000,5.4540143,1.6259415,,,,,,,,,,,,,, -103100,5.402918,1.5384287,,,,,,,,,,,,,, -103200,5.0824533,1.5841764,,,,,,,,,,,,,, -103300,5.5359807,1.5569296,,,,,,,,,,,,,, -103400,4.277847,1.5359501,,,,,,,,,,,,,, -103500,4.5877943,1.5225607,,,,,,,,,,,,,, -103600,4.7774434,1.5476345,,,,,,,,,,,,,, -103700,5.2351866,1.5289404,,,,,,,,,,,,,, -103800,4.604494,1.6439295,,,,,,,,,,,,,, -103900,4.9643426,1.5430236,,,,,,,,,,,,,, -104000,5.020053,1.5606753,,,,,,,,,,,,,, -104100,5.5051036,1.5562727,,,,,,,,,,,,,, -104200,4.66642,1.5333124,,,,,,,,,,,,,, -104300,4.726571,1.5349323,,,,,,,,,,,,,, -104400,4.4058433,1.4804683,,,,,,,,,,,,,, -104437,,,0.7296117544174194,1.0509129762649536,0.6679199934005737,1.347889065742493,50000.0,0.5321000218391418,2.0829954147338867,10000.0,35227.21813130379,36472.46457672119,35227.21813130379,1236.998259305954,4.621797323226929,0.0 -104500,5.084397,1.5956019,,,,,,,,,,,,,, -104600,4.76899,1.5672866,,,,,,,,,,,,,, -104700,4.6581497,1.525223,,,,,,,,,,,,,, -104800,4.5118403,1.5510428,,,,,,,,,,,,,, -104900,5.2184234,1.5872911,,,,,,,,,,,,,, -105000,5.675512,1.6043879,,,,,,,,,,,,,, -105100,5.2218995,1.5463583,,,,,,,,,,,,,, -105200,4.352925,1.6570843,,,,,,,,,,,,,, -105300,4.783992,1.6768934,,,,,,,,,,,,,, -105400,4.1960015,1.6183562,,,,,,,,,,,,,, -105500,5.1071115,1.5502155,,,,,,,,,,,,,, -105600,6.0194836,1.5417836,,,,,,,,,,,,,, -105700,5.3092065,1.6245025,,,,,,,,,,,,,, -105800,4.8667364,1.5444814,,,,,,,,,,,,,, -105900,5.6327133,1.4881223,,,,,,,,,,,,,, -105952,,,0.76175856590271,0.9061012864112854,0.674560010433197,1.3246057033538818,50000.0,0.5433000326156616,2.0417113304138184,10000.0,35737.22608041763,37000.28816699982,35737.22608041763,1254.704176902771,4.679151773452759,0.0 -106000,6.2209144,1.6192484,,,,,,,,,,,,,, -106100,4.586579,1.4593002,,,,,,,,,,,,,, -106200,4.873659,1.4705095,,,,,,,,,,,,,, -106300,4.9446716,1.501076,,,,,,,,,,,,,, -106400,4.621945,1.5417205,,,,,,,,,,,,,, -106500,5.1183887,1.5760562,,,,,,,,,,,,,, -106600,5.253082,1.5094168,,,,,,,,,,,,,, -106700,5.118197,1.5683794,,,,,,,,,,,,,, -106800,6.5267324,1.5939126,,,,,,,,,,,,,, -106900,4.6775274,1.4786859,,,,,,,,,,,,,, -107000,4.834519,1.4251136,,,,,,,,,,,,,, -107100,5.3027964,1.5889443,,,,,,,,,,,,,, -107200,4.8223243,1.641547,,,,,,,,,,,,,, -107300,5.0566545,1.6065016,,,,,,,,,,,,,, -107400,4.7291374,1.544948,,,,,,,,,,,,,, -107468,,,0.7491828799247742,0.9538630843162536,0.675819993019104,1.3117072582244873,50000.0,0.546500027179718,2.02262544631958,10000.0,36247.43604612351,37528.2292740345,36247.43604612351,1272.3301212787628,4.732055902481079,0.0 -107500,4.396631,1.5383725,,,,,,,,,,,,,, -107600,6.311528,1.4854362,,,,,,,,,,,,,, -107700,4.65199,1.5855728,,,,,,,,,,,,,, -107800,4.996709,1.5326215,,,,,,,,,,,,,, -107900,5.3343005,1.6160848,,,,,,,,,,,,,, -108000,5.768885,1.6484107,,,,,,,,,,,,,, -108100,5.16365,1.6463492,,,,,,,,,,,,,, -108200,5.0041986,1.55969,,,,,,,,,,,,,, -108300,5.631543,1.57776,,,,,,,,,,,,,, -108400,4.9394903,1.5811411,,,,,,,,,,,,,, -108500,5.1753716,1.4305704,,,,,,,,,,,,,, -108600,5.5439014,1.6825237,,,,,,,,,,,,,, -108700,4.7021255,1.5513588,,,,,,,,,,,,,, -108800,5.1963058,1.5903515,,,,,,,,,,,,,, -108900,5.248639,1.5300007,,,,,,,,,,,,,, -108983,,,0.7424266338348389,0.9738861918449402,0.6760599613189697,1.3137764930725098,50000.0,0.541700005531311,2.0563201904296875,10000.0,36757.352653265,38056.18480205536,36757.352653265,1290.2720756530762,4.7767369747161865,0.0 -109000,5.014316,1.5156546,,,,,,,,,,,,,, -109100,5.6840487,1.5674114,,,,,,,,,,,,,, -109200,5.5533924,1.4430383,,,,,,,,,,,,,, -109300,5.196017,1.5544772,,,,,,,,,,,,,, -109400,4.7842894,1.5336928,,,,,,,,,,,,,, -109500,4.8173566,1.5836087,,,,,,,,,,,,,, -109600,4.989482,1.5183846,,,,,,,,,,,,,, -109700,4.846234,1.4855338,,,,,,,,,,,,,, -109800,5.5921216,1.4629294,,,,,,,,,,,,,, -109900,5.8224874,1.5362728,,,,,,,,,,,,,, -110000,5.432146,1.5121733,,,,,,,,,,,,,, -110100,4.552069,1.4690334,,,,,,,,,,,,,, -110200,5.035556,1.537664,,,,,,,,,,,,,, -110300,4.588936,1.4568872,,,,,,,,,,,,,, -110400,5.164335,1.6364362,,,,,,,,,,,,,, -110499,,,0.7394770383834839,0.996335506439209,0.6743599772453308,1.319318413734436,50000.0,0.5463000535964966,2.017503261566162,10000.0,37267.507370471954,38583.88388371468,37267.507370471954,1307.7142674922943,4.826080322265625,0.0 -110500,5.180003,1.5223024,,,,,,,,,,,,,, -110600,4.5048075,1.4884477,,,,,,,,,,,,,, -110700,5.013239,1.5248942,,,,,,,,,,,,,, -110800,5.005399,1.5515593,,,,,,,,,,,,,, -110900,4.907119,1.5584487,,,,,,,,,,,,,, -111000,4.702987,1.562338,,,,,,,,,,,,,, -111100,5.865047,1.5664566,,,,,,,,,,,,,, -111200,5.9976673,1.4020052,,,,,,,,,,,,,, -111300,5.011124,1.4700937,,,,,,,,,,,,,, -111400,4.985495,1.4787581,,,,,,,,,,,,,, -111500,5.8882165,1.5686023,,,,,,,,,,,,,, -111600,4.89512,1.5054502,,,,,,,,,,,,,, -111700,5.477424,1.5247451,,,,,,,,,,,,,, -111800,5.252787,1.5206966,,,,,,,,,,,,,, -111900,5.4740205,1.4654391,,,,,,,,,,,,,, -112000,5.162662,1.434595,,,,,,,,,,,,,, -112014,,,0.7409917116165161,0.985767662525177,0.6783599853515625,1.3062725067138672,50000.0,0.5491999983787537,2.0260019302368164,10000.0,37777.445001125336,39111.363669633865,37777.445001125336,1325.155839920044,4.8738861083984375,0.0 -112100,5.3086743,1.4725575,,,,,,,,,,,,,, -112200,5.1508994,1.4934448,,,,,,,,,,,,,, -112300,5.1117435,1.5482364,,,,,,,,,,,,,, -112400,5.964021,1.5504496,,,,,,,,,,,,,, -112500,4.972198,1.4016097,,,,,,,,,,,,,, -112600,4.8441715,1.5352579,,,,,,,,,,,,,, -112700,5.5209303,1.4250983,,,,,,,,,,,,,, -112800,5.0144405,1.4274539,,,,,,,,,,,,,, -112900,4.831172,1.5279562,,,,,,,,,,,,,, -113000,4.7413235,1.4936162,,,,,,,,,,,,,, -113100,4.8378835,1.4680846,,,,,,,,,,,,,, -113200,4.593509,1.4662782,,,,,,,,,,,,,, -113300,5.410572,1.4299041,,,,,,,,,,,,,, -113400,6.2286277,1.4695559,,,,,,,,,,,,,, -113500,5.290262,1.5479001,,,,,,,,,,,,,, -113529,,,0.7640305757522583,0.9003487825393677,0.6823399662971497,1.285157561302185,50000.0,0.5555000305175781,1.985240340232849,10000.0,38287.514525175095,39638.94228172302,38287.514525175095,1342.566482782364,4.919671297073364,0.0 -113600,4.5290346,1.4657047,,,,,,,,,,,,,, -113700,5.2333603,1.4654008,,,,,,,,,,,,,, -113800,5.289908,1.5357778,,,,,,,,,,,,,, -113900,5.039063,1.5506928,,,,,,,,,,,,,, -114000,5.782368,1.53739,,,,,,,,,,,,,, -114100,5.2513275,1.4253476,,,,,,,,,,,,,, -114200,5.3718376,1.4646066,,,,,,,,,,,,,, -114300,5.709251,1.4373236,,,,,,,,,,,,,, -114400,5.367743,1.4797424,,,,,,,,,,,,,, -114500,5.487208,1.5125244,,,,,,,,,,,,,, -114600,5.546781,1.5195466,,,,,,,,,,,,,, -114700,4.7511444,1.4265944,,,,,,,,,,,,,, -114800,5.2927027,1.4672587,,,,,,,,,,,,,, -114900,4.935975,1.5143483,,,,,,,,,,,,,, -115000,5.577953,1.4508628,,,,,,,,,,,,,, -115044,,,0.772859513759613,0.8542397618293762,0.6820399761199951,1.2807635068893433,50000.0,0.553600013256073,1.981478214263916,10000.0,38797.47217464447,40166.47235298157,38797.47217464447,1360.0370292663574,4.967864036560059,0.0 -115100,4.9528394,1.4502339,,,,,,,,,,,,,, -115200,5.465086,1.5031765,,,,,,,,,,,,,, -115300,5.1831045,1.4567695,,,,,,,,,,,,,, -115400,5.0929766,1.4643211,,,,,,,,,,,,,, -115500,5.099505,1.526469,,,,,,,,,,,,,, -115600,5.9500995,1.3984386,,,,,,,,,,,,,, -115700,5.2292595,1.5441782,,,,,,,,,,,,,, -115800,5.1539354,1.4229606,,,,,,,,,,,,,, -115900,5.040158,1.3453218,,,,,,,,,,,,,, -116000,4.415904,1.4203266,,,,,,,,,,,,,, -116100,5.7106857,1.5114245,,,,,,,,,,,,,, -116200,6.1591897,1.5250523,,,,,,,,,,,,,, -116300,5.3080125,1.4479338,,,,,,,,,,,,,, -116400,5.100287,1.4597172,,,,,,,,,,,,,, -116500,5.7097945,1.3322358,,,,,,,,,,,,,, -116559,,,0.7597456574440002,0.911958634853363,0.681939959526062,1.299109697341919,50000.0,0.5580000281333923,1.988004207611084,10000.0,39307.38990712166,40694.417508125305,39307.38990712166,1377.9627270698547,5.016592741012573,0.0 -116600,5.8574944,1.4177701,,,,,,,,,,,,,, -116700,6.3156004,1.5054963,,,,,,,,,,,,,, -116800,5.1591296,1.4267721,,,,,,,,,,,,,, -116900,5.525194,1.439896,,,,,,,,,,,,,, -117000,4.876455,1.4153628,,,,,,,,,,,,,, -117100,5.189068,1.4315013,,,,,,,,,,,,,, -117200,5.370707,1.5273885,,,,,,,,,,,,,, -117300,5.59405,1.4395971,,,,,,,,,,,,,, -117400,4.725178,1.3823498,,,,,,,,,,,,,, -117500,6.117973,1.4132614,,,,,,,,,,,,,, -117600,5.4995704,1.5148762,,,,,,,,,,,,,, -117700,5.4478097,1.4356058,,,,,,,,,,,,,, -117800,5.7431536,1.4697506,,,,,,,,,,,,,, -117900,4.765572,1.3585311,,,,,,,,,,,,,, -118000,4.806803,1.4491748,,,,,,,,,,,,,, -118074,,,0.7604631781578064,0.8970020413398743,0.6898199915885925,1.2545552253723145,50000.0,0.5614000558853149,1.96667742729187,10000.0,39817.51469898224,41222.12512159348,39817.51469898224,1395.4464178085327,5.063026189804077,0.0 -118100,4.998506,1.4385271,,,,,,,,,,,,,, -118200,5.2371535,1.4507513,,,,,,,,,,,,,, -118300,5.0797396,1.5013958,,,,,,,,,,,,,, -118400,5.326624,1.3294458,,,,,,,,,,,,,, -118500,5.4666734,1.4695761,,,,,,,,,,,,,, -118600,5.015722,1.3936561,,,,,,,,,,,,,, -118700,5.5115047,1.4562961,,,,,,,,,,,,,, -118800,5.3304195,1.4244828,,,,,,,,,,,,,, -118900,5.117477,1.4055936,,,,,,,,,,,,,, -119000,5.160011,1.4104174,,,,,,,,,,,,,, -119100,5.6025696,1.5025342,,,,,,,,,,,,,, -119200,5.697259,1.4534638,,,,,,,,,,,,,, -119300,5.330858,1.3955852,,,,,,,,,,,,,, -119400,5.054939,1.387919,,,,,,,,,,,,,, -119500,4.9439273,1.4182186,,,,,,,,,,,,,, -119589,,,0.7599050998687744,0.9128103256225586,0.6882799863815308,1.2635095119476318,50000.0,0.5593000054359436,1.9771336317062376,10000.0,40327.41791152954,41749.75271129608,40327.41791152954,1413.069516658783,5.112490177154541,0.0 -119600,5.518734,1.5370595,,,,,,,,,,,,,, -119700,6.11807,1.4584632,,,,,,,,,,,,,, -119800,5.8194084,1.4202303,,,,,,,,,,,,,, -119900,4.966916,1.3793585,,,,,,,,,,,,,, -120000,5.5613446,1.4072036,,,,,,,,,,,,,, -120100,5.423315,1.4046422,,,,,,,,,,,,,, -120200,5.52415,1.473189,,,,,,,,,,,,,, -120300,5.7704344,1.495487,,,,,,,,,,,,,, -120400,5.501818,1.3637006,,,,,,,,,,,,,, -120500,5.383137,1.3467749,,,,,,,,,,,,,, -120600,5.7156196,1.5646055,,,,,,,,,,,,,, -120700,6.4274426,1.5490091,,,,,,,,,,,,,, -120800,6.011456,1.4142084,,,,,,,,,,,,,, -120900,4.9710965,1.4037262,,,,,,,,,,,,,, -121000,6.244154,1.4945354,,,,,,,,,,,,,, -121100,5.6974754,1.3320311,,,,,,,,,,,,,, -121104,,,0.7638512253761292,0.8878701329231262,0.6916399598121643,1.2425378561019895,50000.0,0.5651000142097473,1.946260452270508,10000.0,40837.543076753616,42277.27058959007,40837.543076753616,1430.3549864292145,5.1657538414001465,0.0 -121200,6.0205626,1.4199257,,,,,,,,,,,,,, -121300,5.5212927,1.3229511,,,,,,,,,,,,,, -121400,5.869204,1.4292058,,,,,,,,,,,,,, -121500,5.118605,1.4275131,,,,,,,,,,,,,, -121600,5.969365,1.4400244,,,,,,,,,,,,,, -121700,5.62735,1.4850882,,,,,,,,,,,,,, -121800,5.783166,1.4226481,,,,,,,,,,,,,, -121900,6.194717,1.5872282,,,,,,,,,,,,,, -122000,5.2171893,1.3446661,,,,,,,,,,,,,, -122100,6.4231653,1.4956702,,,,,,,,,,,,,, -122200,5.82857,1.3384693,,,,,,,,,,,,,, -122300,5.246294,1.4201058,,,,,,,,,,,,,, -122400,5.505138,1.3855674,,,,,,,,,,,,,, -122500,5.468547,1.4194344,,,,,,,,,,,,,, -122600,6.4065924,1.4059192,,,,,,,,,,,,,, -122619,,,0.8024553656578064,0.7302818298339844,0.693619966506958,1.2337989807128906,50000.0,0.5672000050544739,1.9245952367782595,10000.0,41347.52177858353,42805.71956944466,41347.52177858353,1448.7214317321775,5.217383623123169,0.0 -122700,4.8396654,1.32075,,,,,,,,,,,,,, -122800,6.3956723,1.4331084,,,,,,,,,,,,,, -122900,5.7042294,1.2491338,,,,,,,,,,,,,, -123000,6.078504,1.382777,,,,,,,,,,,,,, -123100,6.1784124,1.3986969,,,,,,,,,,,,,, -123200,5.6757503,1.4721822,,,,,,,,,,,,,, -123300,5.211739,1.342186,,,,,,,,,,,,,, -123400,5.918817,1.4415209,,,,,,,,,,,,,, -123500,6.096278,1.4722785,,,,,,,,,,,,,, -123600,6.084982,1.351144,,,,,,,,,,,,,, -123700,4.7571096,1.342133,,,,,,,,,,,,,, -123800,5.608222,1.3514137,,,,,,,,,,,,,, -123900,5.744373,1.3200079,,,,,,,,,,,,,, -124000,5.1639485,1.3097861,,,,,,,,,,,,,, -124100,5.3264103,1.3126156,,,,,,,,,,,,,, -124135,,,0.7833027839660645,0.8094884753227234,0.6947399973869324,1.2455179691314695,50000.0,0.5654000043869019,1.947636365890503,10000.0,41857.67465591431,43333.32304549217,41857.67465591431,1466.0691316127777,5.268237113952637,0.0 -124200,5.347891,1.3836961,,,,,,,,,,,,,, -124300,6.0148516,1.4988651,,,,,,,,,,,,,, -124400,5.1481824,1.3960996,,,,,,,,,,,,,, -124500,5.536541,1.4175106,,,,,,,,,,,,,, -124600,5.4956226,1.3021524,,,,,,,,,,,,,, -124700,5.8146214,1.3720051,,,,,,,,,,,,,, -124800,6.2638164,1.4084442,,,,,,,,,,,,,, -124900,5.7290344,1.3874396,,,,,,,,,,,,,, -125000,5.9330606,1.3685412,,,,,,,,,,,,,, -125100,5.8809805,1.4298294,,,,,,,,,,,,,, -125200,5.7239413,1.458184,,,,,,,,,,,,,, -125300,5.865568,1.3614559,,,,,,,,,,,,,, -125400,6.03118,1.3727031,,,,,,,,,,,,,, -125500,5.8907166,1.5115441,,,,,,,,,,,,,, -125600,6.0413957,1.4053923,,,,,,,,,,,,,, -125650,,,0.7840800285339355,0.8092227578163147,0.6946600079536438,1.2339342832565308,50000.0,0.572700023651123,1.9287670850753784,10000.0,42367.645431280136,43861.1323056221,42367.645431280136,1483.8026728630066,5.32036828994751,0.0 -125700,5.6543903,1.3785291,,,,,,,,,,,,,, -125800,5.78402,1.3029808,,,,,,,,,,,,,, -125900,5.4961567,1.3829765,,,,,,,,,,,,,, -126000,5.9028416,1.4335301,,,,,,,,,,,,,, -126100,5.5693326,1.2672503,,,,,,,,,,,,,, -126200,5.765673,1.3461219,,,,,,,,,,,,,, -126300,6.409089,1.380269,,,,,,,,,,,,,, -126400,5.385069,1.3112364,,,,,,,,,,,,,, -126500,7.0001454,1.3286836,,,,,,,,,,,,,, -126600,5.738199,1.2396882,,,,,,,,,,,,,, -126700,5.448824,1.2886027,,,,,,,,,,,,,, -126800,6.7705755,1.2893088,,,,,,,,,,,,,, -126900,5.701719,1.3040013,,,,,,,,,,,,,, -127000,5.7338953,1.3747149,,,,,,,,,,,,,, -127100,5.726292,1.3322984,,,,,,,,,,,,,, -127165,,,0.778340220451355,0.8295766711235046,0.7000399827957153,1.2225801944732666,50000.0,0.5708000063896179,1.908912181854248,10000.0,42877.60603952408,44388.82006406784,42877.60603952408,1501.4231088161469,5.373712062835693,0.0 -127200,6.2575464,1.3889358,,,,,,,,,,,,,, -127300,5.7248135,1.5244417,,,,,,,,,,,,,, -127400,7.360517,1.4032334,,,,,,,,,,,,,, -127500,5.9672704,1.3347038,,,,,,,,,,,,,, -127600,6.4480195,1.3155184,,,,,,,,,,,,,, -127700,6.2917414,1.3726954,,,,,,,,,,,,,, -127800,6.3289394,1.3957583,,,,,,,,,,,,,, -127900,6.125749,1.4489815,,,,,,,,,,,,,, -128000,6.1996155,1.391689,,,,,,,,,,,,,, -128100,7.1861844,1.4291729,,,,,,,,,,,,,, -128200,5.9667635,1.3176185,,,,,,,,,,,,,, -128300,5.9815946,1.4220726,,,,,,,,,,,,,, -128400,6.2031136,1.362615,,,,,,,,,,,,,, -128500,5.931864,1.3758214,,,,,,,,,,,,,, -128600,5.6118183,1.2922579,,,,,,,,,,,,,, -128680,,,0.78324294090271,0.8075078129768372,0.7049199938774109,1.190179467201233,50000.0,0.5818000435829163,1.8730347156524656,10000.0,43387.53185915947,44916.69388747215,43387.53185915947,1519.2647440433502,5.426731824874878,0.0 -128700,6.5253143,1.2077311,,,,,,,,,,,,,, -128800,5.9138217,1.3199248,,,,,,,,,,,,,, -128900,5.503734,1.2994931,,,,,,,,,,,,,, -129000,5.700749,1.2577643,,,,,,,,,,,,,, -129100,5.293547,1.3708922,,,,,,,,,,,,,, -129200,6.1776094,1.3793787,,,,,,,,,,,,,, -129300,6.626723,1.4180499,,,,,,,,,,,,,, -129400,5.8938847,1.2042798,,,,,,,,,,,,,, -129500,6.1506934,1.3041312,,,,,,,,,,,,,, -129600,5.7506876,1.2042418,,,,,,,,,,,,,, -129700,6.8726087,1.3238453,,,,,,,,,,,,,, -129800,5.867978,1.2450491,,,,,,,,,,,,,, -129900,6.068451,1.3485643,,,,,,,,,,,,,, -130000,6.5422077,1.4255035,,,,,,,,,,,,,, -130100,6.036627,1.2233227,,,,,,,,,,,,,, -130196,,,0.7864915132522583,0.789760947227478,0.7076999545097351,1.1743167638778689,50000.0,0.5768000483512878,1.8925039768219,10000.0,43897.6113114357,45444.25380349159,43897.6113114357,1536.644121170044,5.475257635116577,0.0 -130200,6.455589,1.2674106,,,,,,,,,,,,,, -130300,5.271184,1.2297064,,,,,,,,,,,,,, -130400,6.6090956,1.2880203,,,,,,,,,,,,,, -130500,6.3729153,1.2195474,,,,,,,,,,,,,, -130600,5.7692914,1.2614896,,,,,,,,,,,,,, -130700,6.423444,1.3364632,,,,,,,,,,,,,, -130800,5.723543,1.2782959,,,,,,,,,,,,,, -130900,5.913618,1.2667329,,,,,,,,,,,,,, -131000,5.9626923,1.2422042,,,,,,,,,,,,,, -131100,5.5048714,1.2184799,,,,,,,,,,,,,, -131200,6.2066665,1.3269012,,,,,,,,,,,,,, -131300,5.576992,1.2487631,,,,,,,,,,,,,, -131400,6.011139,1.3198392,,,,,,,,,,,,,, -131500,7.082122,1.4171443,,,,,,,,,,,,,, -131600,5.968187,1.3497245,,,,,,,,,,,,,, -131700,6.871319,1.3204452,,,,,,,,,,,,,, -131711,,,0.8223453164100647,0.6572174429893494,0.7078799605369568,1.1762797832489014,50000.0,0.5764000415802002,1.8879553079605105,10000.0,44407.7055516243,45971.96899843216,44407.7055516243,1554.1618838310242,5.526075601577759,0.0 -131800,6.552672,1.3057107,,,,,,,,,,,,,, -131900,5.838248,1.2933891,,,,,,,,,,,,,, -132000,5.4178433,1.2067521,,,,,,,,,,,,,, -132100,6.18295,1.3078996,,,,,,,,,,,,,, -132200,5.796489,1.3228658,,,,,,,,,,,,,, -132300,6.927788,1.2604096,,,,,,,,,,,,,, -132400,6.19545,1.2690191,,,,,,,,,,,,,, -132500,5.496327,1.1718937,,,,,,,,,,,,,, -132600,6.9621983,1.3205132,,,,,,,,,,,,,, -132700,6.3949633,1.3200488,,,,,,,,,,,,,, -132800,6.1069565,1.2262275,,,,,,,,,,,,,, -132900,5.696856,1.3063883,,,,,,,,,,,,,, -133000,6.633883,1.4183985,,,,,,,,,,,,,, -133100,6.209236,1.2458241,,,,,,,,,,,,,, -133200,6.207973,1.4043152,,,,,,,,,,,,,, -133226,,,0.8046476244926453,0.7072620391845703,0.7079600095748901,1.179681420326233,50000.0,0.581000030040741,1.8739051818847656,10000.0,44917.60682630539,46499.55842018128,44917.60682630539,1571.7473032474518,5.575676918029785,0.0 -133300,5.8950753,1.2352992,,,,,,,,,,,,,, -133400,6.6654725,1.3523138,,,,,,,,,,,,,, -133500,6.114542,1.2977796,,,,,,,,,,,,,, -133600,6.463686,1.2532581,,,,,,,,,,,,,, -133700,6.2702565,1.2502754,,,,,,,,,,,,,, -133800,6.515168,1.3818609,,,,,,,,,,,,,, -133900,5.7331114,1.2629745,,,,,,,,,,,,,, -134000,6.0829945,1.2509491,,,,,,,,,,,,,, -134100,6.0862155,1.2638764,,,,,,,,,,,,,, -134200,6.0825562,1.2417992,,,,,,,,,,,,,, -134300,7.161735,1.2754532,,,,,,,,,,,,,, -134400,5.842458,1.1917834,,,,,,,,,,,,,, -134500,6.57491,1.3359135,,,,,,,,,,,,,, -134600,6.4210224,1.1748437,,,,,,,,,,,,,, -134700,6.495916,1.3348296,,,,,,,,,,,,,, -134741,,,0.8073381781578064,0.7013617157936096,0.7153399586677551,1.1508142948150637,50000.0,0.5949000120162964,1.822237253189087,10000.0,45427.66591835022,47027.95869851112,45427.66591835022,1589.98424077034,5.627666711807251,0.0 -134800,5.885553,1.2204254,,,,,,,,,,,,,, -134900,6.1432805,1.1577436,,,,,,,,,,,,,, -135000,6.755051,1.239338,,,,,,,,,,,,,, -135100,6.455429,1.2229064,,,,,,,,,,,,,, -135200,6.4626956,1.3616793,,,,,,,,,,,,,, -135300,6.392474,1.3153285,,,,,,,,,,,,,, -135400,7.263964,1.3661253,,,,,,,,,,,,,, -135500,6.187065,1.3139141,,,,,,,,,,,,,, -135600,6.5879354,1.258733,,,,,,,,,,,,,, -135700,7.009325,1.2229828,,,,,,,,,,,,,, -135800,5.9215555,1.254554,,,,,,,,,,,,,, -135900,6.0273347,1.1813157,,,,,,,,,,,,,, -136000,6.2212863,1.2667527,,,,,,,,,,,,,, -136100,7.47929,1.3056506,,,,,,,,,,,,,, -136200,6.3552094,1.2835158,,,,,,,,,,,,,, -136255,,,0.8015385866165161,0.7381904721260071,0.7129799723625183,1.1590455770492554,50000.0,0.5845000147819519,1.8409979343414309,10000.0,45937.36339020729,47555.45125055313,45937.36339020729,1607.3338513374329,6.020332336425781,0.0 -136300,6.0701985,1.1759362,,,,,,,,,,,,,, -136400,6.3042684,1.2085494,,,,,,,,,,,,,, -136500,6.8946834,1.2319365,,,,,,,,,,,,,, -136600,5.5553355,1.1718582,,,,,,,,,,,,,, -136700,6.3970203,1.263578,,,,,,,,,,,,,, -136800,6.8794045,1.2423856,,,,,,,,,,,,,, -136900,5.636392,1.1987638,,,,,,,,,,,,,, -137000,6.640323,1.2096266,,,,,,,,,,,,,, -137100,6.882215,1.3257287,,,,,,,,,,,,,, -137200,6.62862,1.2243575,,,,,,,,,,,,,, -137300,6.1436396,1.2549604,,,,,,,,,,,,,, -137400,6.199885,1.2079245,,,,,,,,,,,,,, -137500,6.059951,1.2716088,,,,,,,,,,,,,, -137600,5.8367405,1.1550492,,,,,,,,,,,,,, -137700,6.3745966,1.3189503,,,,,,,,,,,,,, -137770,,,0.8052256107330322,0.7155845761299133,0.7154799699783325,1.1480318307876587,50000.0,0.5848000049591064,1.854791045188904,10000.0,46447.48622989655,48083.33385229111,46447.48622989655,1624.9835669994354,6.077698230743408,0.0 -137800,6.0075283,1.2361723,,,,,,,,,,,,,, -137900,6.4994893,1.2386417,,,,,,,,,,,,,, -138000,6.394932,1.2308195,,,,,,,,,,,,,, -138100,6.3872523,1.2281289,,,,,,,,,,,,,, -138200,6.4755554,1.2506529,,,,,,,,,,,,,, -138300,6.961326,1.220588,,,,,,,,,,,,,, -138400,6.496266,1.310917,,,,,,,,,,,,,, -138500,6.145487,1.157207,,,,,,,,,,,,,, -138600,7.17537,1.2508147,,,,,,,,,,,,,, -138700,6.48127,1.2861254,,,,,,,,,,,,,, -138800,6.352991,1.1432836,,,,,,,,,,,,,, -138900,6.9026175,1.2222356,,,,,,,,,,,,,, -139000,6.3377476,1.2920225,,,,,,,,,,,,,, -139100,6.7439394,1.1952575,,,,,,,,,,,,,, -139200,6.255166,1.211312,,,,,,,,,,,,,, -139285,,,0.8102080225944519,0.6994613409042358,0.7164799571037292,1.1391528844833374,50000.0,0.5924000144004822,1.8475563526153564,10000.0,46957.41595888138,48611.016573905945,46957.41595888138,1642.6341168880465,6.127922534942627,0.0 -139300,6.4502134,1.2952218,,,,,,,,,,,,,, -139400,6.1168294,1.1547582,,,,,,,,,,,,,, -139500,6.8299837,1.2417855,,,,,,,,,,,,,, -139600,6.1027627,1.1623598,,,,,,,,,,,,,, -139700,6.6820154,1.1170033,,,,,,,,,,,,,, -139800,7.2697797,1.2249355,,,,,,,,,,,,,, -139900,6.766762,1.1861057,,,,,,,,,,,,,, -140000,7.2871,1.1669935,,,,,,,,,,,,,, -140100,6.868517,1.2920505,,,,,,,,,,,,,, -140200,6.4369416,1.1167367,,,,,,,,,,,,,, -140300,6.254219,1.1912664,,,,,,,,,,,,,, -140400,7.067526,1.1616436,,,,,,,,,,,,,, -140500,7.061565,1.275825,,,,,,,,,,,,,, -140600,6.4760323,1.1580151,,,,,,,,,,,,,, -140700,6.5972443,1.1607634,,,,,,,,,,,,,, -140800,,,0.8451849222183228,0.5710806846618652,0.7232999801635742,1.1173813343048096,50000.0,0.5952000021934509,1.7925559282302856,10000.0,47467.37585401535,49138.57271814346,47467.37585401535,1660.1281578540802,6.1773834228515625,0.0 -140800,6.409684,1.1264422,,,,,,,,,,,,,, -140900,7.0950384,1.194683,,,,,,,,,,,,,, -141000,6.898777,1.1946642,,,,,,,,,,,,,, -141100,6.857469,1.1376945,,,,,,,,,,,,,, -141200,6.5785494,1.1411397,,,,,,,,,,,,,, -141300,7.6740737,1.2500479,,,,,,,,,,,,,, -141400,6.716075,1.167243,,,,,,,,,,,,,, -141500,7.327019,1.2060084,,,,,,,,,,,,,, -141600,7.17308,1.2789326,,,,,,,,,,,,,, -141700,7.003271,1.179034,,,,,,,,,,,,,, -141800,6.594555,1.2125183,,,,,,,,,,,,,, -141900,6.4142456,1.1380699,,,,,,,,,,,,,, -142000,7.156625,1.1713946,,,,,,,,,,,,,, -142100,7.8353543,1.2439089,,,,,,,,,,,,,, -142200,6.8180394,1.0956984,,,,,,,,,,,,,, -142300,6.4928102,1.1931211,,,,,,,,,,,,,, -142315,,,0.8337252736091614,0.6066722273826599,0.7225199937820435,1.1123135089874268,50000.0,0.5986000299453735,1.7970387935638428,10000.0,47977.319326639175,49666.27079510689,47977.319326639175,1677.7803509235382,6.227893829345703,0.0 -142400,6.7740917,1.2095225,,,,,,,,,,,,,, -142500,8.419297,1.2259687,,,,,,,,,,,,,, -142600,7.250378,1.1389306,,,,,,,,,,,,,, -142700,6.5612392,1.1225593,,,,,,,,,,,,,, -142800,7.0409317,1.2476264,,,,,,,,,,,,,, -142900,7.464379,1.2093127,,,,,,,,,,,,,, -143000,6.211812,1.1807292,,,,,,,,,,,,,, -143100,6.936364,1.2074494,,,,,,,,,,,,,, -143200,6.798118,1.1549714,,,,,,,,,,,,,, -143300,6.53018,1.2205272,,,,,,,,,,,,,, -143400,6.973166,1.1763546,,,,,,,,,,,,,, -143500,8.468186,1.077241,,,,,,,,,,,,,, -143600,7.29009,1.142129,,,,,,,,,,,,,, -143700,7.077285,1.0519745,,,,,,,,,,,,,, -143800,7.838365,1.2776085,,,,,,,,,,,,,, -143830,,,0.8341039419174194,0.6040483713150024,0.7270999550819397,1.1015565395355225,50000.0,0.600100040435791,1.7974672317504885,10000.0,48487.28637838364,50193.8794400692,48487.28637838364,1695.3257067203522,6.27121639251709,0.0 -143900,6.166938,1.1089438,,,,,,,,,,,,,, -144000,7.4371996,1.1633724,,,,,,,,,,,,,, -144100,7.1450033,1.1752961,,,,,,,,,,,,,, -144200,8.048636,1.1610075,,,,,,,,,,,,,, -144300,6.82726,1.1058528,,,,,,,,,,,,,, -144400,7.27055,1.1562136,,,,,,,,,,,,,, -144500,6.7503576,1.0796486,,,,,,,,,,,,,, -144600,7.2420425,1.1813734,,,,,,,,,,,,,, -144700,7.247584,1.1362727,,,,,,,,,,,,,, -144800,6.401622,1.0614768,,,,,,,,,,,,,, -144900,6.269927,1.0805542,,,,,,,,,,,,,, -145000,7.5600777,1.2016411,,,,,,,,,,,,,, -145100,6.963135,1.100014,,,,,,,,,,,,,, -145200,7.2046967,1.2054548,,,,,,,,,,,,,, -145300,6.271477,1.1349436,,,,,,,,,,,,,, -145346,,,0.8286631107330322,0.6169041991233826,0.727840006351471,1.1030001640319824,50000.0,0.5995000004768372,1.8142592906951904,10000.0,48997.5089635849,50721.83307147026,48997.5089635849,1712.9535655975342,6.321614980697632,0.0 -145400,8.096639,1.017007,,,,,,,,,,,,,, -145500,6.6360426,1.1403306,,,,,,,,,,,,,, -145600,6.4689326,1.111359,,,,,,,,,,,,,, -145700,7.752715,1.1958036,,,,,,,,,,,,,, -145800,6.451384,1.1218693,,,,,,,,,,,,,, -145900,7.851685,1.141647,,,,,,,,,,,,,, -146000,6.9531837,1.2437489,,,,,,,,,,,,,, -146100,7.049883,1.1659853,,,,,,,,,,,,,, -146200,6.728024,1.1589799,,,,,,,,,,,,,, -146300,6.5330963,1.1379948,,,,,,,,,,,,,, -146400,7.610319,1.0895767,,,,,,,,,,,,,, -146500,7.927475,1.1992797,,,,,,,,,,,,,, -146600,6.7362003,0.9902487,,,,,,,,,,,,,, -146700,8.371789,1.0979815,,,,,,,,,,,,,, -146800,7.637192,1.2285068,,,,,,,,,,,,,, -146862,,,0.832051157951355,0.6025562286376953,0.7324999570846558,1.0817550420761108,50000.0,0.6061000227928162,1.7760010957717896,10000.0,49507.73680782318,51249.25549435616,49507.73680782318,1730.040581703186,6.375980854034424,0.0 -146900,7.2094426,1.1989565,,,,,,,,,,,,,, -147000,7.075985,0.9542252,,,,,,,,,,,,,, -147100,7.2719264,1.055191,,,,,,,,,,,,,, -147200,7.1773596,1.138733,,,,,,,,,,,,,, -147300,8.62469,1.09582,,,,,,,,,,,,,, -147400,6.8674307,1.1326748,,,,,,,,,,,,,, -147500,7.3029575,1.1853853,,,,,,,,,,,,,, -147600,7.8349857,1.158776,,,,,,,,,,,,,, -147700,8.116922,1.1169932,,,,,,,,,,,,,, -147800,6.773388,1.0542599,,,,,,,,,,,,,, -147900,6.7899227,1.0994834,,,,,,,,,,,,,, -148000,7.1184797,1.175249,,,,,,,,,,,,,, -148100,7.70524,1.1549025,,,,,,,,,,,,,, -148200,7.437173,1.0938952,,,,,,,,,,,,,, -148300,7.5546412,1.0664128,,,,,,,,,,,,,, -148377,,,0.840840220451355,0.5678520202636719,0.732759952545166,1.0800023078918457,50000.0,0.6103000044822693,1.783352971076965,10000.0,50017.81396985054,51777.063213825226,50017.81396985054,1747.6648552417755,6.429377555847168,0.0 -148400,7.304529,1.157076,,,,,,,,,,,,,, -148500,7.548203,1.016341,,,,,,,,,,,,,, -148600,7.4571786,1.1144208,,,,,,,,,,,,,, -148700,7.238866,1.0216227,,,,,,,,,,,,,, -148800,7.4063263,1.1101872,,,,,,,,,,,,,, -148900,7.1605825,1.111571,,,,,,,,,,,,,, -149000,8.687685,1.1266793,,,,,,,,,,,,,, -149100,8.384318,1.0939264,,,,,,,,,,,,,, -149200,8.895827,1.1335142,,,,,,,,,,,,,, -149300,7.178083,1.1240935,,,,,,,,,,,,,, -149400,8.482304,1.121155,,,,,,,,,,,,,, -149500,7.886253,1.1175433,,,,,,,,,,,,,, -149600,6.731093,1.0144197,,,,,,,,,,,,,, -149700,7.531521,1.0554013,,,,,,,,,,,,,, -149800,7.9011755,1.1469389,,,,,,,,,,,,,, -149893,,,0.8672671914100647,0.4734492003917694,0.7351399660110474,1.0676651000976562,50000.0,0.6116000413894653,1.751431941986084,10000.0,50527.99266386032,52304.908478975296,50527.99266386032,1765.2281498908997,6.480208396911621,0.0 -149900,7.1514487,1.0835897,,,,,,,,,,,,,, -150000,7.404746,1.0304841,,,,,,,,,,,,,, -150100,7.9054503,0.9650074,,,,,,,,,,,,,, -150200,7.2580185,1.1491071,,,,,,,,,,,,,, -150300,7.041937,1.0570067,,,,,,,,,,,,,, -150400,7.0615597,1.0368773,,,,,,,,,,,,,, -150500,7.810748,1.1333683,,,,,,,,,,,,,, -150600,7.5843453,1.0417693,,,,,,,,,,,,,, -150700,7.389876,1.167049,,,,,,,,,,,,,, -150800,7.8917575,1.0552912,,,,,,,,,,,,,, -150900,7.396494,1.0492785,,,,,,,,,,,,,, -151000,7.5364537,1.0109755,,,,,,,,,,,,,, -151100,7.117861,1.0038841,,,,,,,,,,,,,, -151200,7.785939,1.1202853,,,,,,,,,,,,,, -151300,7.3343277,1.1026454,,,,,,,,,,,,,, -151400,6.7545295,1.0011747,,,,,,,,,,,,,, -151408,,,0.8615473508834839,0.4889602661132812,0.7380799651145935,1.0503398180007937,50000.0,0.6101000308990479,1.7489426136016846,10000.0,51038.04159331322,52832.49210214615,51038.04159331322,1782.6540973186493,6.537206172943115,0.0 -151500,8.879287,1.079075,,,,,,,,,,,,,, -151600,7.7439184,1.1070126,,,,,,,,,,,,,, -151700,7.5110593,1.0510828,,,,,,,,,,,,,, -151800,8.08822,1.1086078,,,,,,,,,,,,,, -151900,6.9825377,1.0005767,,,,,,,,,,,,,, -152000,7.553755,1.0469335,,,,,,,,,,,,,, -152100,8.74952,0.991084,,,,,,,,,,,,,, -152200,7.0405116,1.0358872,,,,,,,,,,,,,, -152300,7.054737,1.0243652,,,,,,,,,,,,,, -152400,7.5451093,1.1884894,,,,,,,,,,,,,, -152500,7.5466275,0.96930474,,,,,,,,,,,,,, -152600,7.63657,1.0741062,,,,,,,,,,,,,, -152700,7.4795866,1.0053372,,,,,,,,,,,,,, -152800,7.4930043,1.0553445,,,,,,,,,,,,,, -152900,8.307243,1.0955628,,,,,,,,,,,,,, -152923,,,0.8637993931770325,0.4847530722618103,0.7389400005340576,1.0541657209396362,50000.0,0.6165000200271606,1.755265712738037,10000.0,51548.12100839615,53359.95241069794,51548.12100839615,1799.930284500122,6.58967399597168,0.0 -153000,7.5146704,1.0474979,,,,,,,,,,,,,, -153100,7.772221,0.9865819,,,,,,,,,,,,,, -153200,7.771872,1.1116878,,,,,,,,,,,,,, -153300,8.962422,1.017495,,,,,,,,,,,,,, -153400,8.181093,1.0534008,,,,,,,,,,,,,, -153500,7.8256717,0.9946354,,,,,,,,,,,,,, -153600,8.496911,1.0615642,,,,,,,,,,,,,, -153700,7.583607,1.0851502,,,,,,,,,,,,,, -153800,7.464393,1.0179732,,,,,,,,,,,,,, -153900,7.31927,1.0056012,,,,,,,,,,,,,, -154000,7.439706,1.1020609,,,,,,,,,,,,,, -154100,8.186277,1.0610563,,,,,,,,,,,,,, -154200,7.6961913,1.0741805,,,,,,,,,,,,,, -154300,8.080909,0.9809106,,,,,,,,,,,,,, -154400,8.145695,1.0579325,,,,,,,,,,,,,, -154438,,,0.8626036047935486,0.4872516095638275,0.7405999898910522,1.0530531406402588,50000.0,0.6126000285148621,1.7592875957489014,10000.0,52058.20289778709,53887.78027367592,52058.20289778709,1817.5693821907043,6.6446380615234375,0.0 -154500,7.854036,0.94951886,,,,,,,,,,,,,, -154600,6.878678,0.90788007,,,,,,,,,,,,,, -154700,7.420093,0.9933705,,,,,,,,,,,,,, -154800,8.0617285,0.96191025,,,,,,,,,,,,,, -154900,7.3654103,0.8991147,,,,,,,,,,,,,, -155000,8.581038,1.000283,,,,,,,,,,,,,, -155100,10.128252,1.0641958,,,,,,,,,,,,,, -155200,6.8682075,0.8999282,,,,,,,,,,,,,, -155300,7.976512,1.0103495,,,,,,,,,,,,,, -155400,7.832035,0.9874734,,,,,,,,,,,,,, -155500,7.775321,1.0306234,,,,,,,,,,,,,, -155600,7.3882294,1.0245539,,,,,,,,,,,,,, -155700,7.620935,1.014716,,,,,,,,,,,,,, -155800,7.480995,0.9722959,,,,,,,,,,,,,, -155900,7.628886,0.97376233,,,,,,,,,,,,,, -155954,,,0.8659518361091614,0.4726797938346863,0.7437599897384644,1.0325734615325928,50000.0,0.6145000457763672,1.7565077543258667,10000.0,52568.401344537735,54415.99195933342,52568.401344537735,1835.477769613266,6.697588682174683,0.0 -156000,8.520894,1.0378221,,,,,,,,,,,,,, -156100,8.192043,0.93912363,,,,,,,,,,,,,, -156200,7.703024,0.91160583,,,,,,,,,,,,,, -156300,7.4858174,0.9555827,,,,,,,,,,,,,, -156400,8.434605,1.0289103,,,,,,,,,,,,,, -156500,8.090132,0.92249614,,,,,,,,,,,,,, -156600,8.920048,1.135712,,,,,,,,,,,,,, -156700,8.159202,1.013128,,,,,,,,,,,,,, -156800,7.569635,0.98298347,,,,,,,,,,,,,, -156900,7.552819,1.0444152,,,,,,,,,,,,,, -157000,8.394049,1.0762469,,,,,,,,,,,,,, -157100,7.869608,1.0009848,,,,,,,,,,,,,, -157200,7.7222047,0.95712453,,,,,,,,,,,,,, -157300,8.37547,0.97568643,,,,,,,,,,,,,, -157400,7.534641,0.9951681,,,,,,,,,,,,,, -157469,,,0.875996470451355,0.4355567395687103,0.7448999881744385,1.034436821937561,50000.0,0.6205000281333923,1.7346514463424685,10000.0,53078.52582502365,54943.82565164566,53078.52582502365,1853.0811932086945,6.750476598739624,0.0 -157500,8.134493,0.9779414,,,,,,,,,,,,,, -157600,7.6050587,0.9528655,,,,,,,,,,,,,, -157700,8.289618,0.9570556,,,,,,,,,,,,,, -157800,8.061912,1.0185221,,,,,,,,,,,,,, -157900,8.276008,1.032887,,,,,,,,,,,,,, -158000,8.037399,0.9992076,,,,,,,,,,,,,, -158100,8.4210205,1.0677367,,,,,,,,,,,,,, -158200,7.8356137,1.0272784,,,,,,,,,,,,,, -158300,7.609215,0.923721,,,,,,,,,,,,,, -158400,7.860541,0.94268787,,,,,,,,,,,,,, -158500,8.694022,1.0478939,,,,,,,,,,,,,, -158600,8.563919,0.97069824,,,,,,,,,,,,,, -158700,8.394092,0.9635924,,,,,,,,,,,,,, -158800,7.717932,0.93753535,,,,,,,,,,,,,, -158900,7.7689214,0.9773427,,,,,,,,,,,,,, -158984,,,0.8933752775192261,0.3799366354942322,0.7441399693489075,1.0344130992889404,50000.0,0.6225000023841858,1.7522176504135132,10000.0,53588.57729744911,55471.33137130737,53588.57729744911,1870.427015542984,6.8072190284729,0.0 -159000,8.5939455,0.9772965,,,,,,,,,,,,,, -159100,7.8252187,0.9010435,,,,,,,,,,,,,, -159200,8.134365,1.0226389,,,,,,,,,,,,,, -159300,7.8795238,0.8884396,,,,,,,,,,,,,, -159400,8.070165,0.91264755,,,,,,,,,,,,,, -159500,10.217535,1.044473,,,,,,,,,,,,,, -159600,8.333321,0.955227,,,,,,,,,,,,,, -159700,8.159805,0.8799023,,,,,,,,,,,,,, -159800,8.187227,0.88396543,,,,,,,,,,,,,, -159900,7.968828,0.9188023,,,,,,,,,,,,,, -160000,8.357263,0.89672834,,,,,,,,,,,,,, -160100,8.745509,1.0072482,,,,,,,,,,,,,, -160200,8.556872,0.97577816,,,,,,,,,,,,,, -160300,7.327989,0.8853307,,,,,,,,,,,,,, -160400,7.879198,0.9702431,,,,,,,,,,,,,, -160500,,,0.8909438848495483,0.3873023092746734,0.750499963760376,1.014423131942749,50000.0,0.6256000399589539,1.7291738986968994,10000.0,54098.76694107056,55999.49178338051,54098.76694107056,1888.285579442978,6.8666510581970215,0.0 -160500,9.336315,0.95882875,,,,,,,,,,,,,, -160600,8.221399,0.8588179,,,,,,,,,,,,,, -160700,7.9606414,0.933209,,,,,,,,,,,,,, -160800,8.489043,0.95574546,,,,,,,,,,,,,, -160900,8.577456,0.9200224,,,,,,,,,,,,,, -161000,7.934179,0.92557657,,,,,,,,,,,,,, -161100,8.584679,0.9788986,,,,,,,,,,,,,, -161200,9.170046,0.86944646,,,,,,,,,,,,,, -161300,8.536702,0.85732615,,,,,,,,,,,,,, -161400,8.138798,0.92289984,,,,,,,,,,,,,, -161500,9.855735,0.9212972,,,,,,,,,,,,,, -161600,10.213684,0.9357922,,,,,,,,,,,,,, -161700,8.318142,0.88468677,,,,,,,,,,,,,, -161800,8.772739,0.8251102,,,,,,,,,,,,,, -161900,7.98207,0.87759876,,,,,,,,,,,,,, -162000,7.6346745,0.8288888,,,,,,,,,,,,,, -162015,,,0.8907844424247742,0.3792414665222168,0.7509599924087524,1.014463186264038,50000.0,0.6287000179290771,1.735564112663269,10000.0,54608.69594120979,56527.012323856354,54608.69594120979,1905.7687640190125,6.921643972396851,0.0 -162100,7.7171283,0.88709366,,,,,,,,,,,,,, -162200,8.104619,0.928469,,,,,,,,,,,,,, -162300,8.038993,0.9067136,,,,,,,,,,,,,, -162400,8.163187,0.88713914,,,,,,,,,,,,,, -162500,8.249397,0.90396184,,,,,,,,,,,,,, -162600,7.5722814,0.8419679,,,,,,,,,,,,,, -162700,7.970053,0.91428965,,,,,,,,,,,,,, -162800,8.169525,0.95969015,,,,,,,,,,,,,, -162900,9.679586,1.0081422,,,,,,,,,,,,,, -163000,7.989174,0.8978707,,,,,,,,,,,,,, -163100,8.510612,0.9700848,,,,,,,,,,,,,, -163200,9.390859,0.91415304,,,,,,,,,,,,,, -163300,8.513431,0.8385129,,,,,,,,,,,,,, -163400,8.617529,0.9156602,,,,,,,,,,,,,, -163500,8.165944,0.8142182,,,,,,,,,,,,,, -163530,,,0.8932557106018066,0.3726956844329834,0.7510600090026855,1.0063519477844238,50000.0,0.6272000074386597,1.7141947746276855,10000.0,55118.67601776123,57054.70344829559,55118.67601776123,1923.371912240982,6.976501703262329,0.0 -163600,7.9259114,0.9081038,,,,,,,,,,,,,, -163700,8.412855,0.9358906,,,,,,,,,,,,,, -163800,8.454626,0.86760443,,,,,,,,,,,,,, -163900,8.347265,0.90556216,,,,,,,,,,,,,, -164000,8.180653,0.83657444,,,,,,,,,,,,,, -164100,9.133014,0.94305986,,,,,,,,,,,,,, -164200,8.255482,0.91408944,,,,,,,,,,,,,, -164300,8.232922,0.8449679,,,,,,,,,,,,,, -164400,8.525845,0.89424586,,,,,,,,,,,,,, -164500,8.665775,0.93126875,,,,,,,,,,,,,, -164600,8.529867,0.85843086,,,,,,,,,,,,,, -164700,8.159676,0.8615259,,,,,,,,,,,,,, -164800,7.7734056,0.7078307,,,,,,,,,,,,,, -164900,8.260199,0.8645796,,,,,,,,,,,,,, -165000,9.392859,0.8489895,,,,,,,,,,,,,, -165045,,,0.8986766338348389,0.3579636812210083,0.7545199990272522,0.9984259009361268,50000.0,0.6301000118255615,1.7019473314285278,10000.0,55628.793186903,57582.69201374054,55628.793186903,1941.1338379383087,7.033837080001831,0.0 -165100,7.894867,0.8568714,,,,,,,,,,,,,, -165200,9.307609,0.9539433,,,,,,,,,,,,,, -165300,8.551812,0.88444436,,,,,,,,,,,,,, -165400,8.187315,0.8512312,,,,,,,,,,,,,, -165500,8.055869,0.81065536,,,,,,,,,,,,,, -165600,7.4707737,0.77123374,,,,,,,,,,,,,, -165700,9.106224,0.9139999,,,,,,,,,,,,,, -165800,9.523957,0.8895912,,,,,,,,,,,,,, -165900,8.40787,0.8841685,,,,,,,,,,,,,, -166000,8.493634,0.88972485,,,,,,,,,,,,,, -166100,7.957158,0.7485266,,,,,,,,,,,,,, -166200,9.151053,0.9374316,,,,,,,,,,,,,, -166300,9.51285,0.9213078,,,,,,,,,,,,,, -166400,8.055467,0.87648034,,,,,,,,,,,,,, -166500,8.001605,0.82605267,,,,,,,,,,,,,, -166560,,,0.9122488498687744,0.3140309453010559,0.7552599906921387,0.9961941838264464,50000.0,0.633400022983551,1.6993948221206665,10000.0,56138.75828337669,58110.06244254112,56138.75828337669,1958.429122209549,7.088637590408325,0.0 -166600,9.267808,0.87783396,,,,,,,,,,,,,, -166700,9.379433,0.81885666,,,,,,,,,,,,,, -166800,8.531744,0.8536814,,,,,,,,,,,,,, -166900,8.491409,0.81715024,,,,,,,,,,,,,, -167000,8.052504,0.87658143,,,,,,,,,,,,,, -167100,9.335342,0.879135,,,,,,,,,,,,,, -167200,8.335119,0.82356703,,,,,,,,,,,,,, -167300,8.752915,0.86104983,,,,,,,,,,,,,, -167400,8.401613,0.7958304,,,,,,,,,,,,,, -167500,8.639758,0.8185315,,,,,,,,,,,,,, -167600,8.9204035,0.77501845,,,,,,,,,,,,,, -167700,8.414389,0.8452496,,,,,,,,,,,,,, -167800,8.38334,0.8407257,,,,,,,,,,,,,, -167900,10.017843,0.95014685,,,,,,,,,,,,,, -168000,9.6027775,0.85674304,,,,,,,,,,,,,, -168075,,,0.916772961616516,0.2954529523849487,0.7545599937438965,0.992598831653595,50000.0,0.6340000033378601,1.6936455965042114,10000.0,56648.82356977463,58637.83798003197,56648.82356977463,1976.0279388427728,7.146857023239136,0.0 -168100,9.557328,0.8203845,,,,,,,,,,,,,, -168200,8.841369,0.80681247,,,,,,,,,,,,,, -168300,9.716835,0.85968506,,,,,,,,,,,,,, -168400,8.843525,0.8268428,,,,,,,,,,,,,, -168500,9.439659,0.8090205,,,,,,,,,,,,,, -168600,8.727482,0.82740706,,,,,,,,,,,,,, -168700,9.864734,0.85255903,,,,,,,,,,,,,, -168800,9.078551,0.90224826,,,,,,,,,,,,,, -168900,8.993972,0.789556,,,,,,,,,,,,,, -169000,8.889262,0.8511989,,,,,,,,,,,,,, -169100,9.226008,0.81941456,,,,,,,,,,,,,, -169200,8.407913,0.86864805,,,,,,,,,,,,,, -169300,9.382455,0.8022952,,,,,,,,,,,,,, -169400,8.272962,0.76043457,,,,,,,,,,,,,, -169500,8.676259,0.8753248,,,,,,,,,,,,,, -169590,,,0.9133848547935486,0.3045446276664734,0.7583999633789062,0.9829720854759216,50000.0,0.6345000267028809,1.6890100240707395,10000.0,57158.8615398407,59165.35334587097,57158.8615398407,1993.3908696174624,7.208467960357666,0.0 -169600,9.26855,0.76812774,,,,,,,,,,,,,, -169700,8.881748,0.85093176,,,,,,,,,,,,,, -169800,8.934759,0.7265774,,,,,,,,,,,,,, -169900,8.2886505,0.7811078,,,,,,,,,,,,,, -170000,9.179733,0.8590857,,,,,,,,,,,,,, -170100,9.097104,0.7993182,,,,,,,,,,,,,, -170200,8.859816,0.8713825,,,,,,,,,,,,,, -170300,8.917518,0.775414,,,,,,,,,,,,,, -170400,9.039725,0.7216106,,,,,,,,,,,,,, -170500,9.11118,0.8142953,,,,,,,,,,,,,, -170600,8.900299,0.88465875,,,,,,,,,,,,,, -170700,8.560219,0.78795505,,,,,,,,,,,,,, -170800,8.762824,0.8132852,,,,,,,,,,,,,, -170900,8.68843,0.7906683,,,,,,,,,,,,,, -171000,8.163291,0.7641589,,,,,,,,,,,,,, -171100,9.2265215,0.77117157,,,,,,,,,,,,,, -171105,,,0.9176697731018066,0.2919151186943054,0.7594999670982361,0.9838279485702516,50000.0,0.6374000310897827,1.6809219121932983,10000.0,57668.79889130592,59692.92246937752,57668.79889130592,2010.915184259415,7.263584613800049,0.0 -171200,8.719589,0.7823224,,,,,,,,,,,,,, -171300,8.671649,0.7685701,,,,,,,,,,,,,, -171400,9.547194,0.79295677,,,,,,,,,,,,,, -171500,8.864623,0.86782104,,,,,,,,,,,,,, -171600,8.972546,0.77493715,,,,,,,,,,,,,, -171700,9.414689,0.8905814,,,,,,,,,,,,,, -171800,9.12113,0.7940284,,,,,,,,,,,,,, -171900,9.2245,0.781978,,,,,,,,,,,,,, -172000,9.401636,0.8704463,,,,,,,,,,,,,, -172100,8.310327,0.73731947,,,,,,,,,,,,,, -172200,9.20382,0.8029147,,,,,,,,,,,,,, -172300,9.054218,0.8278492,,,,,,,,,,,,,, -172400,8.205153,0.69074917,,,,,,,,,,,,,, -172500,9.755433,0.7774526,,,,,,,,,,,,,, -172600,8.578905,0.7967427,,,,,,,,,,,,,, -172620,,,0.9196228981018066,0.279366672039032,0.7603200078010559,0.9797834157943726,50000.0,0.6355000138282776,1.6834412813186646,10000.0,58178.850531578064,60220.705787181854,58178.850531578064,2028.5382385253904,7.319125175476074,0.0 -172700,8.900402,0.7630107,,,,,,,,,,,,,, -172800,8.859331,0.7430009,,,,,,,,,,,,,, -172900,9.98369,0.7433278,,,,,,,,,,,,,, -173000,9.501488,0.81100214,,,,,,,,,,,,,, -173100,8.900623,0.7951813,,,,,,,,,,,,,, -173200,8.887991,0.8221766,,,,,,,,,,,,,, -173300,9.192184,0.82265896,,,,,,,,,,,,,, -173400,8.179909,0.7604398,,,,,,,,,,,,,, -173500,8.793846,0.71330404,,,,,,,,,,,,,, -173600,8.954937,0.6763935,,,,,,,,,,,,,, -173700,9.102237,0.767401,,,,,,,,,,,,,, -173800,9.658056,0.7697113,,,,,,,,,,,,,, -173900,10.166398,0.7813914,,,,,,,,,,,,,, -174000,9.093481,0.72094697,,,,,,,,,,,,,, -174100,9.560716,0.77581865,,,,,,,,,,,,,, -174135,,,0.9229113459587096,0.2737523913383484,0.7621399760246277,0.9760602712631226,50000.0,0.6358000040054321,1.688401460647583,10000.0,58688.89499163628,60748.39577841759,58688.89499163628,2046.0725243091583,7.37781810760498,0.0 -174200,9.298103,0.80146146,,,,,,,,,,,,,, -174300,8.71712,0.82306975,,,,,,,,,,,,,, -174400,9.867437,0.782241,,,,,,,,,,,,,, -174500,9.732572,0.72809875,,,,,,,,,,,,,, -174600,8.677318,0.72323483,,,,,,,,,,,,,, -174700,8.767121,0.72614276,,,,,,,,,,,,,, -174800,8.970044,0.79658103,,,,,,,,,,,,,, -174900,8.079981,0.713543,,,,,,,,,,,,,, -175000,9.212592,0.78838754,,,,,,,,,,,,,, -175100,8.77674,0.7832142,,,,,,,,,,,,,, -175200,9.009466,0.77124286,,,,,,,,,,,,,, -175300,9.687729,0.76971674,,,,,,,,,,,,,, -175400,8.475989,0.7876309,,,,,,,,,,,,,, -175500,8.685446,0.8009461,,,,,,,,,,,,,, -175600,8.588201,0.7005917,,,,,,,,,,,,,, -175650,,,0.9290696382522584,0.2520670890808105,0.7624199986457825,0.9706498980522156,50000.0,0.6411000490188599,1.676186203956604,10000.0,59198.99043011665,61276.268881082535,59198.99043011665,2063.7414784431458,7.4339611530303955,0.0 -175700,9.56463,0.7269386,,,,,,,,,,,,,, -175800,9.161221,0.80684894,,,,,,,,,,,,,, -175900,8.954997,0.718169,,,,,,,,,,,,,, -176000,7.8721075,0.75678414,,,,,,,,,,,,,, -176100,9.212569,0.82069397,,,,,,,,,,,,,, -176200,9.359414,0.7609544,,,,,,,,,,,,,, -176300,8.749191,0.7877758,,,,,,,,,,,,,, -176400,9.924484,0.74695575,,,,,,,,,,,,,, -176500,9.258835,0.7823474,,,,,,,,,,,,,, -176600,8.325842,0.6907357,,,,,,,,,,,,,, -176700,9.12941,0.7771654,,,,,,,,,,,,,, -176800,7.8047504,0.65588975,,,,,,,,,,,,,, -176900,8.900273,0.7898053,,,,,,,,,,,,,, -177000,9.3656435,0.75859517,,,,,,,,,,,,,, -177100,8.826104,0.7261932,,,,,,,,,,,,,, -177165,,,0.9302256107330322,0.2482085525989532,0.7620399594306946,0.9679412245750428,50000.0,0.6412000060081482,1.673362374305725,10000.0,59708.983047008514,61803.98777985573,59708.983047008514,2081.356790304184,7.493084192276001,0.0 -177200,10.565102,0.8210322,,,,,,,,,,,,,, -177300,9.420683,0.7658438,,,,,,,,,,,,,, -177400,9.923072,0.7692702,,,,,,,,,,,,,, -177500,9.541068,0.67669284,,,,,,,,,,,,,, -177600,9.167643,0.7795053,,,,,,,,,,,,,, -177700,9.877147,0.8134113,,,,,,,,,,,,,, -177800,9.144312,0.7274144,,,,,,,,,,,,,, -177900,9.806835,0.73474395,,,,,,,,,,,,,, -178000,8.702977,0.71226805,,,,,,,,,,,,,, -178100,8.476491,0.6710186,,,,,,,,,,,,,, -178200,8.015768,0.69531465,,,,,,,,,,,,,, -178300,9.003298,0.7634978,,,,,,,,,,,,,, -178400,9.495012,0.7370252,,,,,,,,,,,,,, -178500,8.479387,0.7806254,,,,,,,,,,,,,, -178600,9.798377,0.7443733,,,,,,,,,,,,,, -178680,,,0.9304248690605164,0.2487609386444091,0.7635599970817566,0.96607106924057,50000.0,0.6420000195503235,1.6618832349777222,10000.0,60218.98730373383,62331.470396757126,60218.98730373383,2098.725486040116,7.550928115844727,0.0 -178700,8.662738,0.70926255,,,,,,,,,,,,,, -178800,9.105496,0.8360455,,,,,,,,,,,,,, -178900,9.538928,0.7792144,,,,,,,,,,,,,, -179000,9.21513,0.7459802,,,,,,,,,,,,,, -179100,8.57173,0.6303162,,,,,,,,,,,,,, -179200,9.996712,0.7467834,,,,,,,,,,,,,, -179300,9.653631,0.67774355,,,,,,,,,,,,,, -179400,10.120901,0.79457223,,,,,,,,,,,,,, -179500,9.15306,0.7579787,,,,,,,,,,,,,, -179600,9.673507,0.77223384,,,,,,,,,,,,,, -179700,9.951259,0.73607874,,,,,,,,,,,,,, -179800,10.12862,0.75914073,,,,,,,,,,,,,, -179900,9.090756,0.7244226,,,,,,,,,,,,,, -180000,8.995548,0.6682439,,,,,,,,,,,,,, -180100,8.855891,0.6980867,,,,,,,,,,,,,, -180195,,,0.9312220811843872,0.2476635724306106,0.7628799676895142,0.9658904671669006,50000.0,0.6408000588417053,1.6687637567520142,10000.0,60729.14320731163,62859.096420288086,60729.14320731163,2116.085070133209,7.609094381332397,0.0 -180200,8.863712,0.732891,,,,,,,,,,,,,, -180300,8.761485,0.7761571,,,,,,,,,,,,,, -180400,9.672247,0.7718275,,,,,,,,,,,,,, -180500,8.438275,0.770392,,,,,,,,,,,,,, -180600,9.522768,0.65470934,,,,,,,,,,,,,, -180700,9.761167,0.7430692,,,,,,,,,,,,,, -180800,9.303651,0.8025315,,,,,,,,,,,,,, -180900,10.000038,0.78940207,,,,,,,,,,,,,, -181000,8.792645,0.7857776,,,,,,,,,,,,,, -181100,8.6675625,0.77778816,,,,,,,,,,,,,, -181200,9.17824,0.8030969,,,,,,,,,,,,,, -181300,8.821431,0.67249334,,,,,,,,,,,,,, -181400,9.093203,0.7621326,,,,,,,,,,,,,, -181500,8.290939,0.748148,,,,,,,,,,,,,, -181600,8.922345,0.7710564,,,,,,,,,,,,,, -181700,8.9921465,0.7541316,,,,,,,,,,,,,, -181711,,,0.932437777519226,0.2460429519414901,0.7635599970817566,0.963376522064209,50000.0,0.6413000226020813,1.669214963912964,10000.0,61239.31663656235,63386.6651391983,61239.31663656235,2133.3704464435577,7.666547775268555,0.0 -181800,9.619761,0.7662722,,,,,,,,,,,,,, -181900,8.2339735,0.7406728,,,,,,,,,,,,,, -182000,8.538262,0.7521013,,,,,,,,,,,,,, -182100,9.258754,0.725011,,,,,,,,,,,,,, -182200,9.37817,0.76070786,,,,,,,,,,,,,, -182300,8.660069,0.7794449,,,,,,,,,,,,,, -182400,8.96824,0.7744499,,,,,,,,,,,,,, -182500,9.163713,0.7337281,,,,,,,,,,,,,, -182600,8.484168,0.6770491,,,,,,,,,,,,,, -182700,9.361972,0.7996266,,,,,,,,,,,,,, -182800,8.514699,0.68596196,,,,,,,,,,,,,, -182900,9.173877,0.73152256,,,,,,,,,,,,,, -183000,9.755933,0.76487,,,,,,,,,,,,,, -183100,10.997653,0.7706875,,,,,,,,,,,,,, -183200,8.880194,0.72796637,,,,,,,,,,,,,, -183226,,,0.9330556392669678,0.2435648739337921,0.763700008392334,0.9627121686935424,50000.0,0.6414000391960144,1.668165683746338,10000.0,61749.31013250351,63914.57199478149,61749.31013250351,2151.172516822815,7.724971771240234,0.0 -183300,9.09288,0.71460336,,,,,,,,,,,,,, -183400,9.245564,0.69692314,,,,,,,,,,,,,, -183500,8.632017,0.6987537,,,,,,,,,,,,,, -183600,9.3436775,0.71664757,,,,,,,,,,,,,, -183700,9.502648,0.720846,,,,,,,,,,,,,, -183800,9.392648,0.78990626,,,,,,,,,,,,,, -183900,8.630329,0.6950403,,,,,,,,,,,,,, -184000,8.540486,0.741616,,,,,,,,,,,,,, -184100,8.854294,0.7416185,,,,,,,,,,,,,, -184200,8.947096,0.7443515,,,,,,,,,,,,,, -184300,8.864855,0.7288749,,,,,,,,,,,,,, -184400,9.94452,0.67952955,,,,,,,,,,,,,, -184500,10.040203,0.737208,,,,,,,,,,,,,, -184600,8.120382,0.68716305,,,,,,,,,,,,,, -184700,8.89082,0.6656057,,,,,,,,,,,,,, -184741,,,0.9329758882522584,0.240103930234909,0.7639399766921997,0.962593674659729,50000.0,0.6416000127792358,1.6681299209594729,10000.0,62259.26232099533,64441.89793848992,62259.26232099533,2168.429575204849,7.789133071899414,0.0 -184800,8.680597,0.70804423,,,,,,,,,,,,,, -184900,8.569384,0.7932784,,,,,,,,,,,,,, -185000,9.80593,0.70061344,,,,,,,,,,,,,, -185100,9.066284,0.72347534,,,,,,,,,,,,,, -185200,8.133728,0.706967,,,,,,,,,,,,,, -185300,8.8381815,0.64194405,,,,,,,,,,,,,, -185400,9.522535,0.7085919,,,,,,,,,,,,,, -185500,8.418298,0.7191297,,,,,,,,,,,,,, -185600,9.076842,0.78759813,,,,,,,,,,,,,, -185700,9.114385,0.7962474,,,,,,,,,,,,,, -185800,8.818776,0.7194655,,,,,,,,,,,,,, -185900,9.510981,0.7741617,,,,,,,,,,,,,, -186000,9.2888775,0.72775,,,,,,,,,,,,,, -186100,10.013639,0.7467617,,,,,,,,,,,,,, -186200,9.19707,0.7750126,,,,,,,,,,,,,, -186255,,,0.9334343075752258,0.2402811050415039,0.7638399600982666,0.9630224704742432,50000.0,0.6425000429153442,1.6677354574203491,10000.0,62769.26744389534,64969.73010158539,62769.26744389534,2186.145072221756,7.848384857177734,0.0 -186300,8.82342,0.69405544,,,,,,,,,,,,,, -186400,9.020872,0.6992201,,,,,,,,,,,,,, -186500,9.118973,0.6683261,,,,,,,,,,,,,, -186600,8.805377,0.7692966,,,,,,,,,,,,,, -186666,,,0.933812975883484,0.2365906089544296,0.7639399766921997,0.9623273015022278,50000.0,0.6413000226020813,1.6673704385757446,10000.0,62907.48594617844,65125.38621973992,62907.48594617844,2203.5093677043915,7.90713620185852,0.0 -186666,,,,,,,,,,,62907.485946178436,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/eval_measurements.csv deleted file mode 100644 index d885f36c8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.638524532318115,0.0,32.56960868835449,1,0,32.56960868835449,0.0009000000427477,6.912177562713623,10000,50.20822787284851,0.0011160713620483,6.912219524383545,0.0011199999134987,6.912059783935547,50000 -35.02720069885254,0.0212280750274658,542.8288688659668,1509,0,542.8288688659668,0.1509000062942505,4.556992053985596,10000,577.9301149845123,0.2197464853525161,3.89326810836792,0.2019199877977371,4.026599407196045,50000 -52.6793520450592,0.0487828254699707,1052.830677509308,3019,0,1052.830677509308,0.2418000102043151,3.88297438621521,10000,1105.6641788482666,0.3487125337123871,3.0133767127990723,0.3220599889755249,3.1563539505004883,50000 -70.0097508430481,0.0756216049194336,1563.0581607818604,4531,0,1563.0581607818604,0.2779000103473663,3.6371009349823,10000,1633.3022310733795,0.3955078125,2.769458055496216,0.3713600039482116,2.910965919494629,50000 -87.59454846382141,0.1046609878540039,2073.061323404312,6044,0,2073.061323404312,0.2471000105142593,4.004795551300049,10000,2160.9724485874176,0.337890625,3.187882423400879,0.3098799884319305,3.3979156017303467,50000 -105.82629466056824,0.1317870616912841,2583.286673307419,7559,0,2583.286673307419,0.1762000024318695,4.848935604095459,10000,2689.5087604522705,0.2487842738628387,3.991169214248657,0.2201599925756454,4.308241844177246,50000 -123.39727687835692,0.1558332443237304,3093.225257396698,9072,0,3093.225257396698,0.282800018787384,3.630770444869995,10000,3217.094662427902,0.406927615404129,2.715965986251831,0.3807999789714813,2.8986012935638428,50000 -141.6427493095398,0.183197021484375,3603.417446374893,10587,0,3603.417446374893,0.2955000102519989,3.522867679595948,10000,3745.611699104309,0.4227718412876129,2.5902175903320312,0.3854199945926666,2.824572324752808,50000 -159.21364331245422,0.211188793182373,4113.432627916336,12101,0,4113.432627916336,0.2666000127792358,3.802535057067871,10000,4273.278554916382,0.3722098171710968,2.9450957775115967,0.3542799949645996,3.08566689491272,50000 -176.6627459526062,0.2437028884887695,4623.356384038925,13615,0,4623.356384038925,0.1856000125408172,4.675883769989014,10000,4800.736140012741,0.2633330523967743,3.873202085494995,0.2458799928426742,3.999222993850708,50000 -194.09912204742432,0.2718853950500488,5133.596252441406,15131,0,5133.596252441406,0.3039000034332275,3.497375249862671,10000,5328.493609189987,0.4353874325752258,2.5323400497436523,0.3913599848747253,2.8394031524658203,50000 -212.02861714363087,0.3007538318634033,5643.834300756455,16647,0,5643.834300756455,0.2205000072717666,4.4452009201049805,10000,5856.744036912918,0.3074577450752258,3.5628976821899414,0.2874000072479248,3.746313571929932,50000 -229.48013925552368,0.3294222354888916,6154.009658336639,18163,0,6154.009658336639,0.1614000052213668,4.841095924377441,10000,6384.4521453380585,0.2423668652772903,3.9114837646484375,0.2272399961948394,4.067636013031006,50000 -246.9081676006317,0.357844591140747,6664.045799255371,19679,0,6664.045799255371,0.2403000146150589,3.9924895763397217,10000,6911.997058391571,0.3469387590885162,3.0987894535064697,0.3242200016975403,3.262202739715576,50000 -264.5111756324768,0.3970851898193359,7174.211032867432,21195,0,7174.211032867432,0.1844000071287155,4.544374465942383,10000,7439.85765004158,0.283581793308258,3.587366819381714,0.262800008058548,3.7317724227905273,50000 -281.99954295158386,0.4310669898986816,7684.350539445877,22712,0,7684.350539445877,0.1419000029563903,5.220768451690674,10000,7967.571515083313,0.2009127885103225,4.426466464996338,0.1850199997425079,4.574308395385742,50000 -299.5668559074402,0.4717752933502197,8194.534126281738,24229,0,8194.534126281738,0.1383000016212463,5.4493255615234375,10000,8495.416982650757,0.2044005095958709,4.545769691467285,0.1773199886083603,4.824890613555908,50000 -316.89233231544495,0.50246262550354,8704.772486686707,25747,0,8704.772486686707,0.19200000166893,4.759922981262207,10000,9023.064319849014,0.2718032598495483,3.84205961227417,0.2549799978733063,4.00461483001709,50000 -334.2111656665802,0.5347170829772949,9214.88447213173,27265,0,9214.88447213173,0.2277000099420547,4.118068218231201,10000,9550.579601049423,0.3291215002536773,3.240143060684204,0.3055999875068664,3.3980844020843506,50000 -351.86159229278564,0.5654406547546387,9724.903431653976,28783,0,9724.903431653976,0.166700005531311,4.732311725616455,10000,10078.332313776016,0.2516741156578064,3.8752613067626953,0.2350799888372421,4.017209053039551,50000 -369.6040177345276,0.5961604118347168,10235.189134597778,30301,0,10235.189134597778,0.171300008893013,4.892067909240723,10000,10606.446083307266,0.2491828650236129,4.077645778656006,0.2342199981212616,4.213882923126221,50000 -387.2266595363617,0.6278092861175537,10745.33808541298,31820,0,10745.33808541298,0.1625000089406967,5.108082294464111,10000,11134.302097797394,0.2342952787876129,4.242311954498291,0.2231799960136413,4.331212520599365,50000 -404.47988963127136,0.6608669757843018,11255.258068799973,33338,0,11255.258068799973,0.263700008392334,3.885799407958984,10000,11661.563098192217,0.3989955186843872,2.7934722900390625,0.3581999838352203,3.0654101371765137,50000 -421.78941106796265,0.6944155693054199,11765.183889389038,34857,0,11765.183889389038,0.1992000043392181,4.656971454620361,10000,12188.885138511658,0.2908561825752258,3.669871330261232,0.271479994058609,3.857929229736328,50000 -439.2223870754242,0.727224588394165,12275.3075401783,36376,0,12275.3075401783,0.2078000158071518,4.676905632019043,10000,12716.527085065842,0.3069993555545807,3.613547325134277,0.2903600037097931,3.769576787948608,50000 -456.9472150802612,0.7594432830810547,12785.335270166395,37895,0,12785.335270166395,0.2243000119924545,4.230418205261231,10000,13244.364332437515,0.3186583220958709,3.3709821701049805,0.2976000010967254,3.534862756729126,50000 -474.2763805389404,0.7922139167785645,13295.302300453186,39413,0,13295.302300453186,0.176700010895729,4.55711841583252,10000,13771.746730804443,0.2593869566917419,3.792681932449341,0.2449399977922439,3.898503065109253,50000 -491.86146664619446,0.8255927562713623,13805.345292568209,40933,0,13805.345292568209,0.1660000085830688,4.765476703643799,10000,14299.461974859238,0.2283362448215484,3.997976064682007,0.2191199958324432,4.096891403198242,50000 -509.33509039878845,0.8656878471374512,14315.4254693985,42453,0,14315.4254693985,0.2339000105857849,4.010398387908936,10000,14827.108598709106,0.3529775142669678,3.090325355529785,0.3198799788951874,3.338887929916382,50000 -526.9106154441833,0.900765895843506,14825.367683410645,43973,0,14825.367683410645,0.0865000039339065,6.511098384857178,10000,15354.713600158691,0.1209542378783226,5.900317192077637,0.1139400005340576,6.015324115753174,50000 -544.4883089065552,0.9369504451751708,15335.417038440704,45493,0,15335.417038440704,0.1197000071406364,5.432078361511231,10000,15882.430266857147,0.1727319806814193,4.664542675018311,0.1613999903202057,4.75282621383667,50000 -561.8923692703247,0.9734163284301758,15845.601004362106,47014,0,15845.601004362106,0.1162000074982643,5.645530700683594,10000,16410.10750222206,0.1693638414144516,4.89009428024292,0.1624599993228912,4.977695465087891,50000 -579.9127895832062,1.0119578838348389,16355.56469798088,48534,0,16355.56469798088,0.1300000101327896,5.572903156280518,10000,16938.182448148727,0.186902105808258,4.8144307136535645,0.1775399893522262,4.926436901092529,50000 -597.3582053184509,1.0468201637268066,16865.721137285233,50054,0,16865.721137285233,0.1940000057220459,4.523690700531006,10000,17465.871644973755,0.2864915430545807,3.610297918319702,0.2510800063610077,3.918847322463989,50000 -615.1842617988586,1.081218957901001,17375.71496105194,51574,0,17375.71496105194,0.1734000146389007,4.786422729492188,10000,17993.777713537216,0.2438815385103225,4.0654730796813965,0.2254599928855896,4.252574920654297,50000 -632.7463145256042,1.1168007850646973,17885.76908183098,53094,0,17885.76908183098,0.175800010561943,4.857440948486328,10000,18521.48242330551,0.2487045526504516,4.032920837402344,0.2283199876546859,4.205615043640137,50000 -650.3460447788239,1.1529819965362549,18395.96377325058,54614,0,18395.96377325058,0.2151000052690506,4.410431385040283,10000,19049.366036891937,0.3034518361091614,3.569441318511963,0.2828399837017059,3.7012898921966553,50000 -667.7958898544312,1.19061279296875,18906.143671512604,56135,0,18906.143671512604,0.1838000118732452,4.662926197052002,10000,19577.08617591858,0.2638711631298065,3.8232076168060303,0.2496399879455566,3.983182907104492,50000 -685.1396398544312,1.230280876159668,19416.319585561752,57656,0,19416.319585561752,0.1345000118017196,5.564505100250244,10000,20104.69949555397,0.1876992881298065,4.809234619140625,0.1772599965333938,4.956613540649414,50000 -702.7179343700409,1.2668704986572266,19926.450991630554,59177,0,19926.450991630554,0.1234000027179718,5.476072311401367,10000,20632.49976873398,0.1821189373731613,4.640640258789063,0.1666799932718277,4.858196258544922,50000 -719.9272334575653,1.3086819648742676,20436.43869829178,60698,0,20436.43869829178,0.226500004529953,4.168323040008545,10000,21159.79110598564,0.3432318270206451,3.1367218494415283,0.3236799836158752,3.3062968254089355,50000 -737.3943648338318,1.3493919372558594,20946.65442109108,62219,0,20946.65442109108,0.1844000071287155,4.584669589996338,10000,21687.56720471382,0.2721022069454193,3.687821388244629,0.2554399967193603,3.826228618621826,50000 -754.9472448825836,1.3876104354858398,21456.615305900574,63739,0,21456.615305900574,0.1894000023603439,4.6342620849609375,10000,22215.171835184097,0.2697106003761291,3.7855687141418457,0.2567200064659118,3.871646881103516,50000 -772.145806312561,1.428035020828247,21966.762639045715,65260,0,21966.762639045715,0.1604000031948089,5.25397253036499,10000,22742.61117172241,0.2284757643938064,4.340119361877441,0.2192399948835373,4.4335126876831055,50000 -789.7588932514191,1.4698281288146973,22476.804877758022,66781,0,22476.804877758022,0.2578000128269195,3.982248544692993,10000,23270.36111807823,0.3475964665412903,3.1049602031707764,0.3361199796199798,3.2092509269714355,50000 -807.3938567638397,1.5102369785308838,22986.811591625214,68302,0,22986.811591625214,0.2359000146389007,4.184186458587647,10000,23798.096511363983,0.351283460855484,3.137704372406006,0.3225999772548675,3.393313407897949,50000 -824.8659207820892,1.5485587120056152,23496.993092536926,69823,0,23496.993092536926,0.1305000036954879,5.301344871520996,10000,24325.841685056686,0.1961694806814193,4.434426784515381,0.1842399984598159,4.5618109703063965,50000 -842.2808480262756,1.5876126289367676,24007.092417240143,71344,0,24007.092417240143,0.2132000029087066,4.583950042724609,10000,24853.448628664017,0.3081353604793548,3.636435031890869,0.2773399949073791,3.8866045475006095,50000 -859.5648393630981,1.626474142074585,24517.10163712501,72865,0,24517.10163712501,0.1842000037431717,4.706700325012207,10000,25380.833248138428,0.2721420526504516,3.783998489379883,0.2586399912834167,3.9203364849090576,50000 -878.2137405872345,1.6661872863769531,25027.26301074028,74387,0,25027.26301074028,0.2869000136852264,3.61729907989502,10000,25909.73645925522,0.407904177904129,2.691690921783448,0.3840200006961822,2.868121385574341,50000 -895.784318447113,1.700392484664917,25537.45592713356,75908,0,25537.45592713356,0.2714000046253204,3.899897336959839,10000,26437.587026834488,0.4138033986091614,2.737895965576172,0.3630799949169159,3.0857176780700684,50000 -913.3123028278352,1.7399368286132812,26047.5850622654,77429,0,26047.5850622654,0.2776000201702118,3.813933372497559,10000,26965.33577370644,0.3894292116165161,2.835314989089966,0.3652399778366089,3.049356460571289,50000 -930.7702312469482,1.7816412448883057,26557.61221885681,78950,0,26557.61221885681,0.2470000088214874,3.9893369674682617,10000,27492.91514992714,0.3610690236091614,3.041414499282837,0.3311599791049957,3.252563238143921,50000 -948.5069932937622,1.830293893814087,27067.55645370484,80471,0,27067.55645370484,0.2220000177621841,4.42165994644165,10000,28020.69840979576,0.3260921537876129,3.3770241737365723,0.3070800006389618,3.5291993618011475,50000 -966.3523087501526,1.8731324672698968,27577.53954029084,81991,0,27577.53954029084,0.2578000128269195,3.9556169509887695,10000,28548.6232419014,0.375019907951355,2.999455213546753,0.3500199913978576,3.140925168991089,50000 -983.7680652141572,1.908886432647705,28087.72601652145,83512,0,28087.72601652145,0.3202000260353088,3.373796939849853,10000,29076.31339788437,0.4471859037876129,2.482073307037353,0.4222399890422821,2.646937131881714,50000 -1001.0070824623108,1.950409889221192,28597.715670108795,85033,0,28597.715670108795,0.2476000189781189,4.028741359710693,10000,29603.63627338409,0.3713129758834839,2.982234001159668,0.339599996805191,3.212437391281128,50000 -1018.7795708179474,1.9979043006896973,29107.75165891648,86554,0,29107.75165891648,0.2648999989032745,4.093916893005371,10000,30131.54461431504,0.3778499662876129,3.040504217147827,0.3454599976539612,3.298660516738892,50000 -1036.8213591575625,2.039724111557007,29617.712538719177,88074,0,29617.712538719177,0.2532000243663788,3.9751439094543457,10000,30659.6422123909,0.3744818270206451,2.971901655197144,0.3493599891662597,3.141242742538452,50000 -1054.453256368637,2.091838836669922,30127.6463303566,89594,0,30127.6463303566,0.3113000094890594,3.605199098587036,10000,31187.31401515007,0.4371412396430969,2.587308406829834,0.4049800038337707,2.806240558624268,50000 -1071.72580742836,2.1351735591888428,30637.59294629097,91114,0,30637.59294629097,0.2911000251770019,3.8347907066345215,10000,31714.630335330963,0.408223032951355,2.8384816646575928,0.3856199979782104,3.0218026638031006,50000 -1089.054959774017,2.183558702468872,31147.667982816696,92635,0,31147.667982816696,0.3418000042438507,3.3646998405456543,10000,32242.1359269619,0.4818638265132904,2.327448844909668,0.4464399814605713,2.534179210662842,50000 -1106.6452696323397,2.233637571334839,31657.68229198456,94156,0,31657.68229198456,0.3456000089645386,3.1860923767089844,10000,32769.84343409538,0.5026904940605164,2.148667812347412,0.4616999924182892,2.3995540142059326,50000 -1124.3362169265747,2.278075695037842,32167.90555024147,95677,0,32167.90555024147,0.255700021982193,3.914307117462158,10000,33297.85503602028,0.3589166104793548,3.0547327995300293,0.3370199799537658,3.236257076263428,50000 -1142.114492893219,2.3221030235290527,32677.89052796364,97198,0,32677.89052796364,0.3154000043869018,3.509662628173828,10000,33825.715623378754,0.4383370578289032,2.5672011375427246,0.4063999950885772,2.7606093883514404,50000 -1159.3864908218384,2.3658483028411865,33188.105674266815,98720,0,33188.105674266815,0.2054000049829483,4.586277484893799,10000,34353.2997841835,0.3068598508834839,3.599402904510498,0.2856799960136413,3.769570827484131,50000 -1176.9801092147827,2.4154789447784424,33698.1875834465,100241,0,33698.1875834465,0.2949000000953674,3.5665860176086426,10000,34881.077904462814,0.4196627736091614,2.642141103744507,0.4020799994468689,2.7794859409332275,50000 -1194.4681041240692,2.459041118621826,34208.35723352432,101763,0,34208.35723352432,0.2842999994754791,3.691239595413208,10000,35408.832102775574,0.4217952787876129,2.6498255729675293,0.3804999887943268,2.93494200706482,50000 -1212.0220968723297,2.508146286010742,34718.39355421066,103284,0,34718.39355421066,0.3148000240325928,3.657977104187012,10000,35936.52565956116,0.4627311825752258,2.48760986328125,0.427979975938797,2.753029584884644,50000 -1229.5005240440369,2.5585193634033203,35228.39146900177,104805,0,35228.39146900177,0.3752000033855438,3.0397000312805176,10000,36464.10525536537,0.5387436151504517,1.9647153615951536,0.4964599907398224,2.206085205078125,50000 -1247.2531068325045,2.616370439529419,35738.42715740204,106326,0,35738.42715740204,0.3743000030517578,3.179978370666504,10000,36992.00506877899,0.5285993218421936,2.0710136890411377,0.4899199903011322,2.315893173217773,50000 -1264.6096456050873,2.6608407497406006,36248.5503616333,107847,0,36248.5503616333,0.3248000144958496,3.541719675064087,10000,37519.58160948753,0.4673548936843872,2.430663108825684,0.4399999976158142,2.63357925415039,50000 -1281.939534187317,2.70654034614563,36758.63085794449,109368,0,36758.63085794449,0.388700008392334,3.003093719482422,10000,38047.09022331238,0.5494060516357422,1.9401334524154663,0.510699987411499,2.1658456325531006,50000 -1299.4154126644137,2.756378173828125,37268.7436478138,110890,0,37268.7436478138,0.4267000257968902,2.716227054595948,10000,38574.78150773048,0.6084582209587097,1.6108510494232178,0.5411800146102905,1.968849301338196,50000 -1316.9662177562714,2.803908109664917,37778.82757425308,112411,0,37778.82757425308,0.416700005531311,2.758341073989868,10000,39102.51637029648,0.5785833597183228,1.7491344213485718,0.5329799652099609,2.002593278884888,50000 -1334.4177539348602,2.848686695098877,38288.882075071335,113933,0,38288.882075071335,0.4113000333309173,2.844345569610596,10000,39630.119643211365,0.5904615521430969,1.7069220542907717,0.5450599789619446,1.986126065254212,50000 -1351.958430767059,2.899632215499878,38799.00455093384,115454,0,38799.00455093384,0.426000028848648,2.689170360565185,10000,40157.88823246956,0.5925542116165161,1.6781911849975586,0.5508800148963928,1.918237566947937,50000 -1369.5922305583954,2.9562196731567383,39309.09029126167,116975,0,39309.09029126167,0.3038000166416168,3.7805306911468506,10000,40685.7168943882,0.4353674948215484,2.671430349349976,0.4124199748039245,2.853731870651245,50000 -1387.3949823379517,3.004948377609253,39819.04309177399,118496,0,39819.04309177399,0.3787000179290771,3.0569660663604736,10000,41213.57382154465,0.5356146097183228,2.0111894607543945,0.498659998178482,2.2324469089508057,50000 -1405.0201497077942,3.0573155879974365,40329.04471373558,120017,0,40329.04471373558,0.3614000082015991,3.282891273498535,10000,41741.30652666092,0.5190728306770325,2.124185562133789,0.463919997215271,2.485764741897583,50000 -1422.4556045532229,3.1066973209381104,40839.1926074028,121539,0,40839.1926074028,0.4230000078678131,2.734823703765869,10000,42268.991443395615,0.5934311151504517,1.6687780618667605,0.545740008354187,1.940511703491211,50000 -1440.025390625,3.157965421676636,41349.20190811157,123061,0,41349.20190811157,0.3936000168323517,2.924185514450073,10000,42796.67501139641,0.5442841053009033,1.9463707208633425,0.5045599937438965,2.186814069747925,50000 -1457.5347940921783,3.209736585617065,41859.287153720856,124582,0,41859.287153720856,0.458400011062622,2.5650177001953125,10000,43324.37479400635,0.6316764950752258,1.5131675004959106,0.5785199999809265,1.7937418222427368,50000 -1475.3515548706057,3.2582929134368896,42369.22559094429,126103,0,42369.22559094429,0.4564000070095062,2.564093589782715,10000,43852.23295521736,0.630301296710968,1.5103756189346311,0.5861600041389465,1.7711243629455566,50000 -1492.9049038887024,3.306204319000244,42879.3197491169,127624,0,42879.3197491169,0.4485000073909759,2.574965000152588,10000,44379.98077702522,0.6471021771430969,1.4206527471542358,0.5690199732780457,1.836398720741272,50000 -1510.3695363998413,3.35434365272522,43389.23870754242,129145,0,43389.23870754242,0.453900009393692,2.5460996627807617,10000,44907.46559214592,0.6431162357330322,1.4401280879974363,0.5803799629211426,1.769958734512329,50000 -1527.6565673351288,3.409334897994995,43899.253200531006,130666,0,43899.253200531006,0.4480000138282776,2.622263193130493,10000,45434.87541007996,0.6172273755073547,1.5718461275100708,0.5622000098228455,1.8866841793060305,50000 -1544.999391555786,3.458484411239624,44409.17866325378,132187,0,44409.17866325378,0.4357000291347503,2.7501347064971924,10000,45962.245206832886,0.5950454473495483,1.6977694034576416,0.5471599698066711,1.9765727519989007,50000 -1563.513568162918,3.509937763214112,44919.16495108605,133708,0,44919.16495108605,0.4850000143051147,2.405172348022461,10000,46490.850856781006,0.6533800959587097,1.3872883319854736,0.6028199791908264,1.666730284690857,50000 -1581.0774466991425,3.550867080688477,45429.14738154411,135229,0,45429.14738154411,0.5041000247001648,2.260892868041992,10000,47018.49057650566,0.6850087642669678,1.247851014137268,0.632099986076355,1.5180271863937378,50000 -1598.6504790782928,3.605985164642334,45939.10034203529,136750,0,45939.10034203529,0.4648000299930572,2.5056369304656982,10000,47546.12515926361,0.6704998016357422,1.3005253076553345,0.5945799946784973,1.7191460132598877,50000 -1616.1354587078094,3.6534364223480233,46449.014721632,138271,0,46449.014721632,0.5098000168800354,2.2224693298339844,10000,48073.62446284294,0.6985809803009033,1.16620934009552,0.6342599987983704,1.5143696069717407,50000 -1633.5621328353882,3.705156087875366,46959.01626968384,139791,0,46959.01626968384,0.4791000187397003,2.424319267272949,10000,48601.15681409836,0.6724529266357422,1.3127702474594116,0.6125199794769287,1.6271332502365112,50000 -1650.8986172676086,3.7565438747406006,47468.91835808754,141312,0,47468.91835808754,0.5,2.2908971309661865,10000,49128.50067186356,0.6929009556770325,1.2138513326644895,0.6319199800491333,1.516939997673035,50000 -1668.5178027153015,3.8354225158691406,47978.88781452179,142833,0,47978.88781452179,0.5177000164985657,2.194716215133667,10000,49656.22150874138,0.7071707248687744,1.1520593166351318,0.6445599794387817,1.4677226543426514,50000 -1686.1094892024994,3.890009880065918,48488.88462114334,144354,0,48488.88462114334,0.5253000259399414,2.1807503700256348,10000,50183.9176530838,0.7137077450752258,1.1160260438919067,0.65065997838974,1.4377646446228027,50000 -1703.7169313430786,3.942670345306397,48998.85201382637,145875,0,48998.85201382637,0.534000039100647,2.145510673522949,10000,50711.597870111465,0.7431241869926453,0.9743123650550842,0.6586199998855591,1.4034175872802734,50000 -1721.5213029384613,3.9980673789978014,49508.99960565567,147397,0,49508.99960565567,0.5109000205993652,2.299001932144165,10000,51239.65789651871,0.7065529227256775,1.1428754329681396,0.6351799964904785,1.5308077335357666,50000 -1738.840086698532,4.0529398918151855,50019.09476184845,148918,0,50019.09476184845,0.5452000498771667,2.0889601707458496,10000,51767.18022537232,0.7453164458274841,0.9728658199310304,0.6709199547767639,1.3506791591644287,50000 -1756.4312443733215,4.107654333114624,50529.31740260124,150440,0,50529.31740260124,0.5436000227928162,2.0569565296173096,10000,52295.1010248661,0.748465359210968,0.9664836525917052,0.6740999817848206,1.3404970169067385,50000 -1773.9691338539124,4.158005237579346,51039.45336413384,151961,0,51039.45336413384,0.5437000393867493,2.081461906433105,10000,52822.879269361496,0.7463727593421936,0.9735182523727416,0.6724399924278259,1.3431792259216309,50000 -1791.6252291202543,4.218658447265625,51549.48877668381,153482,0,51549.48877668381,0.5611000061035156,1.965351700782776,10000,53350.68533778191,0.7898397445678711,0.7781878113746643,0.6873999834060669,1.2663946151733398,50000 -1808.892503023148,4.2747087478637695,52059.40568423271,155003,0,52059.40568423271,0.5552999973297119,2.019956350326538,10000,53877.97802639008,0.7786391973495483,0.8249820470809937,0.6894599795341492,1.2767852544784546,50000 -1826.589562416077,4.334457635879517,52569.56496477127,156525,0,52569.56496477127,0.5636000037193298,1.9697444438934328,10000,54405.94682025909,0.7767059803009033,0.8280686140060425,0.6906200051307678,1.253747582435608,50000 -1844.0382986068728,4.386325359344482,53079.482422828674,158046,0,53079.482422828674,0.5710000395774841,1.9607090950012207,10000,54933.41873812676,0.7870694994926453,0.7952739596366882,0.6987599730491638,1.248716950416565,50000 -1861.841367483139,4.439189434051514,53589.56930446625,159567,0,53589.56930446625,0.5830000042915344,1.915332674980164,10000,55461.414071798325,0.7997050285339355,0.7518975138664246,0.7044199705123901,1.2122431993484497,50000 -1879.416244983673,4.49467396736145,54099.57530117035,161088,0,54099.57530117035,0.5833000540733337,1.886929988861084,10000,55989.10298204422,0.7997449040412903,0.7390770316123962,0.7134400010108948,1.1649560928344729,50000 -1897.226199388504,4.549734115600586,54609.529450416565,162609,0,54609.529450416565,0.5940999984741211,1.8632287979125977,10000,56516.9751701355,0.8374122977256775,0.5941317677497864,0.7181999683380127,1.1413167715072632,50000 -1914.762225151062,4.603350400924683,55119.610013246536,164130,0,55119.610013246536,0.5824000239372253,1.869858741760254,10000,57044.69862341881,0.8246572017669678,0.6434984803199768,0.7135599851608276,1.1545459032058716,50000 -1933.12388586998,4.66450572013855,55629.50521326065,165651,0,55629.50521326065,0.5944000482559204,1.8382610082626345,10000,57573.0696310997,0.8374919891357422,0.596867024898529,0.7252799868583679,1.1151868104934692,50000 -1950.4011313915253,4.738975524902344,56139.65770363808,167173,0,56139.65770363808,0.6052000522613525,1.811101675033569,10000,58100.62710976601,0.8379902839660645,0.5885335206985474,0.7274599671363831,1.0997930765151978,50000 -1967.9093770980835,4.797713041305542,56649.87207078934,168695,0,56649.87207078934,0.6029000282287598,1.819189548492432,10000,58628.46256566048,0.8406209945678711,0.5672109723091125,0.730459988117218,1.097076654434204,50000 -1985.569852113724,4.8556458950042725,57160.08248925209,170217,0,57160.08248925209,0.6044000387191772,1.8097580671310425,10000,59156.44516658783,0.8475565910339355,0.5495447516441345,0.7321199774742126,1.083914279937744,50000 -2003.057029247284,4.913210391998291,57670.17210030556,171739,0,57670.17210030556,0.6091000437736511,1.781596541404724,10000,59684.1325712204,0.8664699792861938,0.4829581677913666,0.7368800044059753,1.0651110410690308,50000 -2020.586843013764,4.973502397537232,58180.46226763725,173260,0,58180.46226763725,0.6144000291824341,1.779536247253418,10000,60212.06752181053,0.8635004758834839,0.4903871119022369,0.7384999990463257,1.065165638923645,50000 -2038.053610086441,5.037261247634888,58690.6109726429,174781,0,58690.6109726429,0.6187000274658203,1.7618048191070557,10000,60739.801298856735,0.8696388602256775,0.4675772190093994,0.7422800064086914,1.044396162033081,50000 -2055.5500314235687,5.10002589225769,59200.50735998154,176301,0,59200.50735998154,0.6177000403404236,1.7554930448532104,10000,61267.31038618088,0.8717913031578064,0.4568901360034942,0.7438399791717529,1.0421042442321775,50000 -2073.018792390824,5.156181573867798,59710.523655653,177822,0,59710.523655653,0.6203000545501709,1.750431776046753,10000,61794.90425157547,0.8772520422935486,0.435047298669815,0.7451599836349487,1.0354770421981812,50000 -2090.9237022399902,5.213481664657593,60220.68492269516,179343,0,60220.68492269516,0.6243000030517578,1.7491586208343506,10000,62323.07969260216,0.8833904266357422,0.4190746545791626,0.7470600008964539,1.032281517982483,50000 -2108.6718752384186,5.271944284439087,60730.57482671738,180863,0,60730.57482671738,0.6240000128746033,1.746339201927185,10000,62850.83002829552,0.8840082883834839,0.4111105501651764,0.7485599517822266,1.0272870063781738,50000 -2126.1875302791595,5.333211421966553,61240.5575094223,182383,0,61240.5575094223,0.6265000104904175,1.7416620254516602,10000,63378.4419465065,0.8837690949440002,0.4141952693462372,0.7486000061035156,1.0234150886535645,50000 -2143.5833439826965,5.390666723251343,61750.54298973084,183904,0,61750.54298973084,0.6260000467300415,1.7438762187957764,10000,63905.93356466293,0.8855827450752258,0.4065465927124023,0.7492199540138245,1.0245238542556765,50000 -2161.1756496429443,5.449017286300659,62260.43835067749,185424,0,62260.43835067749,0.6255000233650208,1.741317868232727,10000,64433.532873392105,0.8825932741165161,0.4104780256748199,0.7490999698638916,1.0238432884216309,50000 -2178.622392654419,5.510376214981079,62676.811291217804,186666,0,62676.811291217804,0.6260000467300415,1.7412824630737305,10000,64867.45725274086,0.8877750039100647,0.4056401550769806,0.7488799691200256,1.0230340957641602,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/measurements.csv deleted file mode 100644 index 829e4ae79..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1993 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6918318,6.925329,,,,,,,,,,,,,, -1,,,0.0011160713620483,6.912219524383545,0.0011199999134987,6.912059783935547,50000.0,0.0009000000427477,6.912177562713623,10000.0,32.56960868835449,50.20822787284851,32.56960868835449,17.638524532318115,0.0,0.0 -100,0.75867724,6.623493,,,,,,,,,,,,,, -200,0.9301058,6.30686,,,,,,,,,,,,,, -300,2.9603558,6.025938,,,,,,,,,,,,,, -400,2.3817933,5.737984,,,,,,,,,,,,,, -500,4.3303304,5.6007156,,,,,,,,,,,,,, -600,3.1323817,5.4906073,,,,,,,,,,,,,, -700,3.8981688,5.153554,,,,,,,,,,,,,, -800,3.0069458,5.0970273,,,,,,,,,,,,,, -900,2.6323955,4.8379703,,,,,,,,,,,,,, -1000,2.781022,4.8360505,,,,,,,,,,,,,, -1100,3.099763,4.4204946,,,,,,,,,,,,,, -1200,2.9360714,4.360841,,,,,,,,,,,,,, -1300,4.130004,4.2403893,,,,,,,,,,,,,, -1400,3.4329877,4.105845,,,,,,,,,,,,,, -1500,2.0853667,4.113501,,,,,,,,,,,,,, -1509,,,0.2197464853525161,3.89326810836792,0.2019199877977371,4.026599407196045,50000.0,0.1509000062942505,4.556992053985596,10000.0,542.8288688659668,577.9301149845123,542.8288688659668,35.02720069885254,0.0212280750274658,0.0 -1600,2.3858426,3.8714764,,,,,,,,,,,,,, -1700,2.358916,3.9363832,,,,,,,,,,,,,, -1800,2.4065418,3.6828501,,,,,,,,,,,,,, -1900,1.9262757,3.6648555,,,,,,,,,,,,,, -2000,2.8136885,3.5356805,,,,,,,,,,,,,, -2100,1.7582777,3.5298254,,,,,,,,,,,,,, -2200,1.1673323,3.4409008,,,,,,,,,,,,,, -2300,1.8350135,3.4004924,,,,,,,,,,,,,, -2400,1.3902681,3.4100606,,,,,,,,,,,,,, -2500,1.36213,3.290894,,,,,,,,,,,,,, -2600,2.0858245,3.2474535,,,,,,,,,,,,,, -2700,1.0480297,3.2610984,,,,,,,,,,,,,, -2800,1.1775793,3.104563,,,,,,,,,,,,,, -2900,1.3345753,3.1019416,,,,,,,,,,,,,, -3000,1.1653174,3.1591296,,,,,,,,,,,,,, -3019,,,0.3487125337123871,3.0133767127990723,0.3220599889755249,3.1563539505004883,50000.0,0.2418000102043151,3.88297438621521,10000.0,1052.830677509308,1105.6641788482666,1052.830677509308,52.6793520450592,0.0487828254699707,0.0 -3100,1.2123971,3.0827894,,,,,,,,,,,,,, -3200,1.279218,2.9156497,,,,,,,,,,,,,, -3300,1.0950172,3.0515866,,,,,,,,,,,,,, -3400,0.83063334,3.0650992,,,,,,,,,,,,,, -3500,0.89952147,2.9958024,,,,,,,,,,,,,, -3600,0.98236495,3.0642638,,,,,,,,,,,,,, -3700,1.0125346,2.9678643,,,,,,,,,,,,,, -3800,0.8205798,2.9506655,,,,,,,,,,,,,, -3900,0.8005396,2.905457,,,,,,,,,,,,,, -4000,0.97896636,3.0121603,,,,,,,,,,,,,, -4100,0.9985596,2.8313797,,,,,,,,,,,,,, -4200,0.9822568,2.769692,,,,,,,,,,,,,, -4300,0.72321516,2.7834692,,,,,,,,,,,,,, -4400,0.83831626,2.6256535,,,,,,,,,,,,,, -4500,0.9943323,2.6692452,,,,,,,,,,,,,, -4531,,,0.3955078125,2.769458055496216,0.3713600039482116,2.910965919494629,50000.0,0.2779000103473663,3.6371009349823,10000.0,1563.0581607818604,1633.3022310733795,1563.0581607818604,70.0097508430481,0.0756216049194336,0.0 -4600,1.1658111,2.7885015,,,,,,,,,,,,,, -4700,0.8371236,2.8539748,,,,,,,,,,,,,, -4800,1.0150601,2.757195,,,,,,,,,,,,,, -4900,0.8501879,2.700549,,,,,,,,,,,,,, -5000,0.8787015,2.7652144,,,,,,,,,,,,,, -5100,0.9660843,2.6625001,,,,,,,,,,,,,, -5200,0.83305895,2.8212426,,,,,,,,,,,,,, -5300,0.8074452,2.6512182,,,,,,,,,,,,,, -5400,0.98777694,2.6301508,,,,,,,,,,,,,, -5500,0.8262503,2.6927683,,,,,,,,,,,,,, -5600,0.825571,2.5901084,,,,,,,,,,,,,, -5700,1.0419825,2.7025695,,,,,,,,,,,,,, -5800,1.0477382,2.5708516,,,,,,,,,,,,,, -5900,0.9926294,2.6418614,,,,,,,,,,,,,, -6000,0.972275,2.629332,,,,,,,,,,,,,, -6044,,,0.337890625,3.187882423400879,0.3098799884319305,3.3979156017303467,50000.0,0.2471000105142593,4.004795551300049,10000.0,2073.061323404312,2160.9724485874176,2073.061323404312,87.59454846382141,0.1046609878540039,0.0 -6100,0.87360704,2.5422387,,,,,,,,,,,,,, -6200,0.9652313,2.5861616,,,,,,,,,,,,,, -6300,1.2345128,2.7309914,,,,,,,,,,,,,, -6400,0.922676,2.7056227,,,,,,,,,,,,,, -6500,0.89744335,2.572039,,,,,,,,,,,,,, -6600,1.2017502,2.7657533,,,,,,,,,,,,,, -6700,1.0326655,2.5858934,,,,,,,,,,,,,, -6800,1.1503733,2.5700247,,,,,,,,,,,,,, -6900,1.3572811,2.5086932,,,,,,,,,,,,,, -7000,0.9282428,2.350275,,,,,,,,,,,,,, -7100,0.9344136,2.53201,,,,,,,,,,,,,, -7200,0.869952,2.4932497,,,,,,,,,,,,,, -7300,1.0135813,2.4885626,,,,,,,,,,,,,, -7400,1.0522333,2.642329,,,,,,,,,,,,,, -7500,1.2061187,2.5065162,,,,,,,,,,,,,, -7559,,,0.2487842738628387,3.991169214248657,0.2201599925756454,4.308241844177246,50000.0,0.1762000024318695,4.848935604095459,10000.0,2583.286673307419,2689.5087604522705,2583.286673307419,105.82629466056824,0.1317870616912841,0.0 -7600,1.0380834,2.5072663,,,,,,,,,,,,,, -7700,1.0544192,2.544132,,,,,,,,,,,,,, -7800,1.196181,2.5292864,,,,,,,,,,,,,, -7900,0.8769747,2.5066142,,,,,,,,,,,,,, -8000,0.9878864,2.4541228,,,,,,,,,,,,,, -8100,0.93778163,2.5467548,,,,,,,,,,,,,, -8200,0.9338937,2.4542816,,,,,,,,,,,,,, -8300,0.9240605,2.4200766,,,,,,,,,,,,,, -8400,0.8860188,2.3977695,,,,,,,,,,,,,, -8500,1.0844163,2.5790982,,,,,,,,,,,,,, -8600,0.9689347,2.5368073,,,,,,,,,,,,,, -8700,0.9509433,2.401546,,,,,,,,,,,,,, -8800,0.9000289,2.4495318,,,,,,,,,,,,,, -8900,0.9540205,2.387352,,,,,,,,,,,,,, -9000,0.9406375,2.5517597,,,,,,,,,,,,,, -9072,,,0.406927615404129,2.715965986251831,0.3807999789714813,2.8986012935638428,50000.0,0.282800018787384,3.630770444869995,10000.0,3093.225257396698,3217.094662427902,3093.225257396698,123.39727687835692,0.1558332443237304,0.0 -9100,0.9619528,2.4064095,,,,,,,,,,,,,, -9200,1.0610994,2.530755,,,,,,,,,,,,,, -9300,0.9405793,2.4350014,,,,,,,,,,,,,, -9400,0.8827111,2.3540037,,,,,,,,,,,,,, -9500,0.8734927,2.3675008,,,,,,,,,,,,,, -9600,0.93498236,2.5134199,,,,,,,,,,,,,, -9700,0.98578584,2.5220747,,,,,,,,,,,,,, -9800,0.978146,2.4772801,,,,,,,,,,,,,, -9900,1.12708,2.5486946,,,,,,,,,,,,,, -10000,0.9571019,2.4346826,,,,,,,,,,,,,, -10100,0.97654825,2.5295594,,,,,,,,,,,,,, -10200,1.0274974,2.5239584,,,,,,,,,,,,,, -10300,0.9487051,2.5227742,,,,,,,,,,,,,, -10400,1.1297051,2.3991704,,,,,,,,,,,,,, -10500,0.93224823,2.3667264,,,,,,,,,,,,,, -10587,,,0.4227718412876129,2.5902175903320312,0.3854199945926666,2.824572324752808,50000.0,0.2955000102519989,3.522867679595948,10000.0,3603.417446374893,3745.611699104309,3603.417446374893,141.6427493095398,0.183197021484375,0.0 -10600,0.9278645,2.3667502,,,,,,,,,,,,,, -10700,1.121213,2.3825855,,,,,,,,,,,,,, -10800,0.96053547,2.4693172,,,,,,,,,,,,,, -10900,0.93230134,2.3395178,,,,,,,,,,,,,, -11000,0.9578837,2.5156012,,,,,,,,,,,,,, -11100,1.0025424,2.3783984,,,,,,,,,,,,,, -11200,0.9395115,2.5072165,,,,,,,,,,,,,, -11300,1.1584064,2.4105198,,,,,,,,,,,,,, -11400,0.97944427,2.5132647,,,,,,,,,,,,,, -11500,1.1738026,2.4480016,,,,,,,,,,,,,, -11600,0.9758062,2.3774266,,,,,,,,,,,,,, -11700,1.0833039,2.466622,,,,,,,,,,,,,, -11800,0.99014664,2.4001327,,,,,,,,,,,,,, -11900,1.0614855,2.2414186,,,,,,,,,,,,,, -12000,1.0841053,2.3199306,,,,,,,,,,,,,, -12100,0.97636473,2.4054902,,,,,,,,,,,,,, -12101,,,0.3722098171710968,2.9450957775115967,0.3542799949645996,3.08566689491272,50000.0,0.2666000127792358,3.802535057067871,10000.0,4113.432627916336,4273.278554916382,4113.432627916336,159.21364331245422,0.211188793182373,0.0 -12200,0.9699472,2.4615567,,,,,,,,,,,,,, -12300,0.9575459,2.4863608,,,,,,,,,,,,,, -12400,0.9315971,2.466129,,,,,,,,,,,,,, -12500,0.9574609,2.3252213,,,,,,,,,,,,,, -12600,1.0190475,2.2910144,,,,,,,,,,,,,, -12700,1.1068217,2.534083,,,,,,,,,,,,,, -12800,1.0146323,2.329993,,,,,,,,,,,,,, -12900,1.0153763,2.4527535,,,,,,,,,,,,,, -13000,1.297889,2.4487512,,,,,,,,,,,,,, -13100,1.01337,2.3547745,,,,,,,,,,,,,, -13200,1.079503,2.5312932,,,,,,,,,,,,,, -13300,1.1148245,2.4550247,,,,,,,,,,,,,, -13400,1.1224226,2.3857195,,,,,,,,,,,,,, -13500,1.0923051,2.3516965,,,,,,,,,,,,,, -13600,0.8947403,2.3186831,,,,,,,,,,,,,, -13615,,,0.2633330523967743,3.873202085494995,0.2458799928426742,3.999222993850708,50000.0,0.1856000125408172,4.675883769989014,10000.0,4623.356384038925,4800.736140012741,4623.356384038925,176.6627459526062,0.2437028884887695,0.0 -13700,1.0029896,2.3603425,,,,,,,,,,,,,, -13800,1.0235002,2.4540117,,,,,,,,,,,,,, -13900,1.1258758,2.325111,,,,,,,,,,,,,, -14000,1.022581,2.4096766,,,,,,,,,,,,,, -14100,0.95672154,2.2513056,,,,,,,,,,,,,, -14200,0.9377506,2.469487,,,,,,,,,,,,,, -14300,1.0770319,2.3734987,,,,,,,,,,,,,, -14400,1.1738803,2.4087224,,,,,,,,,,,,,, -14500,1.0947499,2.3654294,,,,,,,,,,,,,, -14600,1.1228278,2.314546,,,,,,,,,,,,,, -14700,0.9806299,2.4365325,,,,,,,,,,,,,, -14800,1.069935,2.2636876,,,,,,,,,,,,,, -14900,1.0048152,2.4438572,,,,,,,,,,,,,, -15000,1.0193031,2.4071543,,,,,,,,,,,,,, -15100,1.1611769,2.4266367,,,,,,,,,,,,,, -15131,,,0.4353874325752258,2.5323400497436523,0.3913599848747253,2.8394031524658203,50000.0,0.3039000034332275,3.497375249862671,10000.0,5133.596252441406,5328.493609189987,5133.596252441406,194.09912204742432,0.2718853950500488,0.0 -15200,1.0042772,2.3108113,,,,,,,,,,,,,, -15300,1.0570179,2.3335474,,,,,,,,,,,,,, -15400,1.182984,2.5137677,,,,,,,,,,,,,, -15500,0.98299813,2.1900418,,,,,,,,,,,,,, -15600,1.1022979,2.321773,,,,,,,,,,,,,, -15700,1.0550852,2.413674,,,,,,,,,,,,,, -15800,1.0413762,2.404875,,,,,,,,,,,,,, -15900,1.0457419,2.3241882,,,,,,,,,,,,,, -16000,1.053542,2.25916,,,,,,,,,,,,,, -16100,1.0456344,2.263392,,,,,,,,,,,,,, -16200,1.3739253,2.4287577,,,,,,,,,,,,,, -16300,1.0931123,2.3998353,,,,,,,,,,,,,, -16400,1.0283419,2.280624,,,,,,,,,,,,,, -16500,1.2486978,2.4085224,,,,,,,,,,,,,, -16600,1.0054451,2.388751,,,,,,,,,,,,,, -16647,,,0.3074577450752258,3.5628976821899414,0.2874000072479248,3.746313571929932,50000.0,0.2205000072717666,4.4452009201049805,10000.0,5643.834300756455,5856.744036912918,5643.834300756455,212.02861714363087,0.3007538318634033,0.0 -16700,1.0336918,2.310186,,,,,,,,,,,,,, -16800,1.1276366,2.3759525,,,,,,,,,,,,,, -16900,1.002904,2.30132,,,,,,,,,,,,,, -17000,1.1168402,2.508636,,,,,,,,,,,,,, -17100,1.104242,2.299437,,,,,,,,,,,,,, -17200,1.2715731,2.4366357,,,,,,,,,,,,,, -17300,0.99288577,2.3339622,,,,,,,,,,,,,, -17400,1.0300027,2.450941,,,,,,,,,,,,,, -17500,1.0951173,2.3074622,,,,,,,,,,,,,, -17600,1.1526618,2.291081,,,,,,,,,,,,,, -17700,0.9455926,2.2555246,,,,,,,,,,,,,, -17800,1.0443404,2.299295,,,,,,,,,,,,,, -17900,1.0693957,2.3740754,,,,,,,,,,,,,, -18000,1.0656945,2.314517,,,,,,,,,,,,,, -18100,1.1020108,2.40695,,,,,,,,,,,,,, -18163,,,0.2423668652772903,3.9114837646484375,0.2272399961948394,4.067636013031006,50000.0,0.1614000052213668,4.841095924377441,10000.0,6154.009658336639,6384.4521453380585,6154.009658336639,229.48013925552368,0.3294222354888916,0.0 -18200,1.0712755,2.337069,,,,,,,,,,,,,, -18300,1.097972,2.4382043,,,,,,,,,,,,,, -18400,1.159821,2.551009,,,,,,,,,,,,,, -18500,1.126175,2.3684227,,,,,,,,,,,,,, -18600,1.0900832,2.3437233,,,,,,,,,,,,,, -18700,0.9634922,2.2486546,,,,,,,,,,,,,, -18800,0.9697412,2.3070917,,,,,,,,,,,,,, -18900,1.1990248,2.356819,,,,,,,,,,,,,, -19000,1.2652467,2.3792882,,,,,,,,,,,,,, -19100,0.9588514,2.3091416,,,,,,,,,,,,,, -19200,1.0211673,2.321998,,,,,,,,,,,,,, -19300,1.0423495,2.3724737,,,,,,,,,,,,,, -19400,1.0922844,2.4244666,,,,,,,,,,,,,, -19500,1.0260929,2.3134751,,,,,,,,,,,,,, -19600,1.0858011,2.3833742,,,,,,,,,,,,,, -19679,,,0.3469387590885162,3.0987894535064697,0.3242200016975403,3.262202739715576,50000.0,0.2403000146150589,3.9924895763397217,10000.0,6664.045799255371,6911.997058391571,6664.045799255371,246.9081676006317,0.357844591140747,0.0 -19700,1.189821,2.531449,,,,,,,,,,,,,, -19800,1.1595683,2.222979,,,,,,,,,,,,,, -19900,1.0934895,2.1454494,,,,,,,,,,,,,, -20000,1.2484921,2.3624573,,,,,,,,,,,,,, -20100,1.1462479,2.3233895,,,,,,,,,,,,,, -20200,1.0236509,2.3868358,,,,,,,,,,,,,, -20300,1.0037125,2.284195,,,,,,,,,,,,,, -20400,1.0284976,2.3895836,,,,,,,,,,,,,, -20500,1.0500907,2.246243,,,,,,,,,,,,,, -20600,1.1611434,2.2661958,,,,,,,,,,,,,, -20700,1.278357,2.2216976,,,,,,,,,,,,,, -20800,1.1835324,2.381477,,,,,,,,,,,,,, -20900,1.029218,2.4351993,,,,,,,,,,,,,, -21000,1.0022743,2.325095,,,,,,,,,,,,,, -21100,1.0664277,2.3167706,,,,,,,,,,,,,, -21195,,,0.283581793308258,3.587366819381714,0.262800008058548,3.7317724227905273,50000.0,0.1844000071287155,4.544374465942383,10000.0,7174.211032867432,7439.85765004158,7174.211032867432,264.5111756324768,0.3970851898193359,0.0 -21200,1.1162653,2.2807446,,,,,,,,,,,,,, -21300,1.198351,2.4881103,,,,,,,,,,,,,, -21400,1.1312965,2.2395287,,,,,,,,,,,,,, -21500,1.0946193,2.3415394,,,,,,,,,,,,,, -21600,1.0819503,2.3103657,,,,,,,,,,,,,, -21700,1.2756708,2.3311827,,,,,,,,,,,,,, -21800,1.1622084,2.3912656,,,,,,,,,,,,,, -21900,1.2358739,2.4403849,,,,,,,,,,,,,, -22000,1.0425084,2.2759223,,,,,,,,,,,,,, -22100,1.0538462,2.3185682,,,,,,,,,,,,,, -22200,1.053034,2.4430647,,,,,,,,,,,,,, -22300,1.0664825,2.294352,,,,,,,,,,,,,, -22400,1.1652757,2.4100351,,,,,,,,,,,,,, -22500,1.0797244,2.2818408,,,,,,,,,,,,,, -22600,1.0551499,2.2782805,,,,,,,,,,,,,, -22700,0.9430982,2.3482583,,,,,,,,,,,,,, -22712,,,0.2009127885103225,4.426466464996338,0.1850199997425079,4.574308395385742,50000.0,0.1419000029563903,5.220768451690674,10000.0,7684.350539445877,7967.571515083313,7684.350539445877,281.99954295158386,0.4310669898986816,0.0 -22800,1.0267407,2.3604908,,,,,,,,,,,,,, -22900,1.0610523,2.2763436,,,,,,,,,,,,,, -23000,1.0752685,2.4068542,,,,,,,,,,,,,, -23100,1.1519467,2.4236236,,,,,,,,,,,,,, -23200,1.0167192,2.347848,,,,,,,,,,,,,, -23300,1.0526314,2.366541,,,,,,,,,,,,,, -23400,1.0131822,2.2126122,,,,,,,,,,,,,, -23500,1.0813845,2.3130198,,,,,,,,,,,,,, -23600,1.0606372,2.3920584,,,,,,,,,,,,,, -23700,1.3088449,2.4266303,,,,,,,,,,,,,, -23800,1.0860434,2.3176045,,,,,,,,,,,,,, -23900,1.1838784,2.2765913,,,,,,,,,,,,,, -24000,1.0022871,2.3904147,,,,,,,,,,,,,, -24100,1.0880053,2.4066975,,,,,,,,,,,,,, -24200,1.0297829,2.3835087,,,,,,,,,,,,,, -24229,,,0.2044005095958709,4.545769691467285,0.1773199886083603,4.824890613555908,50000.0,0.1383000016212463,5.4493255615234375,10000.0,8194.534126281738,8495.416982650757,8194.534126281738,299.5668559074402,0.4717752933502197,0.0 -24300,1.0169438,2.2907422,,,,,,,,,,,,,, -24400,1.1858217,2.39229,,,,,,,,,,,,,, -24500,1.1878655,2.4195025,,,,,,,,,,,,,, -24600,1.066637,2.3202634,,,,,,,,,,,,,, -24700,0.99174887,2.1268473,,,,,,,,,,,,,, -24800,1.0468626,2.3937337,,,,,,,,,,,,,, -24900,1.0917542,2.365864,,,,,,,,,,,,,, -25000,1.1243709,2.3516228,,,,,,,,,,,,,, -25100,1.08369,2.277793,,,,,,,,,,,,,, -25200,1.2392342,2.300098,,,,,,,,,,,,,, -25300,0.98637575,2.3268974,,,,,,,,,,,,,, -25400,1.0638889,2.3926916,,,,,,,,,,,,,, -25500,1.1376402,2.353579,,,,,,,,,,,,,, -25600,1.066682,2.3357446,,,,,,,,,,,,,, -25700,0.99738806,2.275955,,,,,,,,,,,,,, -25747,,,0.2718032598495483,3.84205961227417,0.2549799978733063,4.00461483001709,50000.0,0.19200000166893,4.759922981262207,10000.0,8704.772486686707,9023.064319849014,8704.772486686707,316.89233231544495,0.50246262550354,0.0 -25800,1.0730114,2.2615035,,,,,,,,,,,,,, -25900,1.102622,2.312048,,,,,,,,,,,,,, -26000,1.1645281,2.4641194,,,,,,,,,,,,,, -26100,1.1549559,2.3008,,,,,,,,,,,,,, -26200,1.2744064,2.29487,,,,,,,,,,,,,, -26300,1.0931231,2.3755774,,,,,,,,,,,,,, -26400,1.1086199,2.3545432,,,,,,,,,,,,,, -26500,1.0016993,2.35064,,,,,,,,,,,,,, -26600,1.0894963,2.2776792,,,,,,,,,,,,,, -26700,0.956097,2.1222994,,,,,,,,,,,,,, -26800,1.2583299,2.3065658,,,,,,,,,,,,,, -26900,1.1958921,2.2810764,,,,,,,,,,,,,, -27000,1.0205188,2.3944237,,,,,,,,,,,,,, -27100,1.0968462,2.3693624,,,,,,,,,,,,,, -27200,1.2443137,2.3087702,,,,,,,,,,,,,, -27265,,,0.3291215002536773,3.240143060684204,0.3055999875068664,3.3980844020843506,50000.0,0.2277000099420547,4.118068218231201,10000.0,9214.88447213173,9550.579601049423,9214.88447213173,334.2111656665802,0.5347170829772949,0.0 -27300,1.0829878,2.3627186,,,,,,,,,,,,,, -27400,1.2091234,2.248063,,,,,,,,,,,,,, -27500,1.0589267,2.1893187,,,,,,,,,,,,,, -27600,1.1575977,2.2346153,,,,,,,,,,,,,, -27700,1.1876527,2.363522,,,,,,,,,,,,,, -27800,1.2100819,2.391467,,,,,,,,,,,,,, -27900,1.216725,2.3316698,,,,,,,,,,,,,, -28000,1.1604964,2.3318994,,,,,,,,,,,,,, -28100,1.2479483,2.3217103,,,,,,,,,,,,,, -28200,1.1206505,2.3516095,,,,,,,,,,,,,, -28300,1.2434323,2.398546,,,,,,,,,,,,,, -28400,1.1651275,2.3018603,,,,,,,,,,,,,, -28500,1.092487,2.2720456,,,,,,,,,,,,,, -28600,1.0600537,2.3575969,,,,,,,,,,,,,, -28700,1.1095865,2.317822,,,,,,,,,,,,,, -28783,,,0.2516741156578064,3.8752613067626953,0.2350799888372421,4.017209053039551,50000.0,0.166700005531311,4.732311725616455,10000.0,9724.903431653976,10078.332313776016,9724.903431653976,351.86159229278564,0.5654406547546387,0.0 -28800,1.0544639,2.2864065,,,,,,,,,,,,,, -28900,1.0719007,2.257184,,,,,,,,,,,,,, -29000,1.1765914,2.3385038,,,,,,,,,,,,,, -29100,0.985112,2.309783,,,,,,,,,,,,,, -29200,1.0838358,2.3023434,,,,,,,,,,,,,, -29300,1.1086701,2.4281898,,,,,,,,,,,,,, -29400,1.2132907,2.3485382,,,,,,,,,,,,,, -29500,1.1570386,2.3635097,,,,,,,,,,,,,, -29600,1.210295,2.3803244,,,,,,,,,,,,,, -29700,1.1437904,2.1679268,,,,,,,,,,,,,, -29800,1.1750433,2.4130778,,,,,,,,,,,,,, -29900,1.1258264,2.2972865,,,,,,,,,,,,,, -30000,1.0140082,2.2519164,,,,,,,,,,,,,, -30100,1.0078909,2.274013,,,,,,,,,,,,,, -30200,1.0922797,2.44805,,,,,,,,,,,,,, -30300,1.1462933,2.312302,,,,,,,,,,,,,, -30301,,,0.2491828650236129,4.077645778656006,0.2342199981212616,4.213882923126221,50000.0,0.171300008893013,4.892067909240723,10000.0,10235.189134597778,10606.446083307266,10235.189134597778,369.6040177345276,0.5961604118347168,0.0 -30400,1.2461032,2.474751,,,,,,,,,,,,,, -30500,1.0421154,2.313753,,,,,,,,,,,,,, -30600,1.1646949,2.2581954,,,,,,,,,,,,,, -30700,1.1074148,2.237516,,,,,,,,,,,,,, -30800,1.161828,2.380077,,,,,,,,,,,,,, -30900,1.1985434,2.2221072,,,,,,,,,,,,,, -31000,1.0940778,2.2862453,,,,,,,,,,,,,, -31100,1.1040283,2.2570174,,,,,,,,,,,,,, -31200,1.0312693,2.223392,,,,,,,,,,,,,, -31300,1.0605093,2.2524915,,,,,,,,,,,,,, -31400,1.1606013,2.365895,,,,,,,,,,,,,, -31500,1.1447814,2.255437,,,,,,,,,,,,,, -31600,1.3123869,2.3390486,,,,,,,,,,,,,, -31700,1.0627414,2.2840934,,,,,,,,,,,,,, -31800,1.1822853,2.2639205,,,,,,,,,,,,,, -31820,,,0.2342952787876129,4.242311954498291,0.2231799960136413,4.331212520599365,50000.0,0.1625000089406967,5.108082294464111,10000.0,10745.33808541298,11134.302097797394,10745.33808541298,387.2266595363617,0.6278092861175537,0.0 -31900,1.2413781,2.3372905,,,,,,,,,,,,,, -32000,1.1195583,2.3215108,,,,,,,,,,,,,, -32100,1.0346966,2.1761708,,,,,,,,,,,,,, -32200,1.1369492,2.3616786,,,,,,,,,,,,,, -32300,1.0695581,2.2096996,,,,,,,,,,,,,, -32400,1.0652572,2.2370772,,,,,,,,,,,,,, -32500,1.1101447,2.225594,,,,,,,,,,,,,, -32600,1.1369789,2.3282084,,,,,,,,,,,,,, -32700,1.0240537,2.3419359,,,,,,,,,,,,,, -32800,1.0832782,2.2921422,,,,,,,,,,,,,, -32900,1.2610277,2.2525005,,,,,,,,,,,,,, -33000,1.0320969,2.3837147,,,,,,,,,,,,,, -33100,1.2755876,2.20352,,,,,,,,,,,,,, -33200,1.0389496,2.2772996,,,,,,,,,,,,,, -33300,1.1269038,2.3056922,,,,,,,,,,,,,, -33338,,,0.3989955186843872,2.7934722900390625,0.3581999838352203,3.0654101371765137,50000.0,0.263700008392334,3.885799407958984,10000.0,11255.258068799973,11661.563098192217,11255.258068799973,404.47988963127136,0.6608669757843018,0.0 -33400,1.17281,2.2737563,,,,,,,,,,,,,, -33500,1.3294134,2.2611198,,,,,,,,,,,,,, -33600,1.286439,2.3187585,,,,,,,,,,,,,, -33700,1.0873978,2.3333075,,,,,,,,,,,,,, -33800,1.1849767,2.4159303,,,,,,,,,,,,,, -33900,1.1696339,2.3745468,,,,,,,,,,,,,, -34000,1.0773085,2.3804128,,,,,,,,,,,,,, -34100,1.2386627,2.2562118,,,,,,,,,,,,,, -34200,1.0412123,2.271724,,,,,,,,,,,,,, -34300,1.1464467,2.2235548,,,,,,,,,,,,,, -34400,1.1338887,2.1771636,,,,,,,,,,,,,, -34500,1.188727,2.4620097,,,,,,,,,,,,,, -34600,1.1318041,2.2619774,,,,,,,,,,,,,, -34700,1.0454416,2.268232,,,,,,,,,,,,,, -34800,1.0487016,2.2136207,,,,,,,,,,,,,, -34857,,,0.2908561825752258,3.669871330261232,0.271479994058609,3.857929229736328,50000.0,0.1992000043392181,4.656971454620361,10000.0,11765.183889389038,12188.885138511658,11765.183889389038,421.78941106796265,0.6944155693054199,0.0 -34900,1.1122099,2.2644408,,,,,,,,,,,,,, -35000,1.0778135,2.3547556,,,,,,,,,,,,,, -35100,1.1978669,2.2559187,,,,,,,,,,,,,, -35200,1.0971518,2.39178,,,,,,,,,,,,,, -35300,1.1132356,2.1721332,,,,,,,,,,,,,, -35400,1.0606161,2.1519537,,,,,,,,,,,,,, -35500,1.1770881,2.37908,,,,,,,,,,,,,, -35600,1.0726303,2.4474604,,,,,,,,,,,,,, -35700,1.3182949,2.4026022,,,,,,,,,,,,,, -35800,1.1208632,2.2567568,,,,,,,,,,,,,, -35900,1.3591352,2.3166773,,,,,,,,,,,,,, -36000,1.1031556,2.243171,,,,,,,,,,,,,, -36100,1.1048735,2.2772808,,,,,,,,,,,,,, -36200,1.1984591,2.3567963,,,,,,,,,,,,,, -36300,1.0948412,2.3514125,,,,,,,,,,,,,, -36376,,,0.3069993555545807,3.613547325134277,0.2903600037097931,3.769576787948608,50000.0,0.2078000158071518,4.676905632019043,10000.0,12275.3075401783,12716.527085065842,12275.3075401783,439.2223870754242,0.727224588394165,0.0 -36400,1.1665876,2.3945823,,,,,,,,,,,,,, -36500,1.2274996,2.219122,,,,,,,,,,,,,, -36600,1.0375991,2.291918,,,,,,,,,,,,,, -36700,1.1519089,2.2090728,,,,,,,,,,,,,, -36800,1.1844577,2.248428,,,,,,,,,,,,,, -36900,1.1487441,2.4504974,,,,,,,,,,,,,, -37000,1.3352547,2.3149176,,,,,,,,,,,,,, -37100,1.2382882,2.2075713,,,,,,,,,,,,,, -37200,1.1553069,2.2916648,,,,,,,,,,,,,, -37300,1.1986923,2.2259932,,,,,,,,,,,,,, -37400,1.2089149,2.2474976,,,,,,,,,,,,,, -37500,1.1196073,2.265006,,,,,,,,,,,,,, -37600,1.2162246,2.4069118,,,,,,,,,,,,,, -37700,1.1410404,2.2848618,,,,,,,,,,,,,, -37800,1.1556107,2.3236852,,,,,,,,,,,,,, -37895,,,0.3186583220958709,3.3709821701049805,0.2976000010967254,3.534862756729126,50000.0,0.2243000119924545,4.230418205261231,10000.0,12785.335270166395,13244.364332437515,12785.335270166395,456.9472150802612,0.7594432830810547,0.0 -37900,1.1504841,2.3916073,,,,,,,,,,,,,, -38000,1.225532,2.2961154,,,,,,,,,,,,,, -38100,1.1121429,2.236185,,,,,,,,,,,,,, -38200,1.0896853,2.2926593,,,,,,,,,,,,,, -38300,1.1671944,2.2752578,,,,,,,,,,,,,, -38400,1.2263175,2.3515072,,,,,,,,,,,,,, -38500,1.1024308,2.2254899,,,,,,,,,,,,,, -38600,1.1224495,2.279457,,,,,,,,,,,,,, -38700,1.1831032,2.4178836,,,,,,,,,,,,,, -38800,1.2323036,2.309847,,,,,,,,,,,,,, -38900,1.2078801,2.2896538,,,,,,,,,,,,,, -39000,1.1721984,2.3091025,,,,,,,,,,,,,, -39100,1.1611724,2.3156118,,,,,,,,,,,,,, -39200,1.1984222,2.2807968,,,,,,,,,,,,,, -39300,1.1192847,2.245006,,,,,,,,,,,,,, -39400,1.1176031,2.098552,,,,,,,,,,,,,, -39413,,,0.2593869566917419,3.792681932449341,0.2449399977922439,3.898503065109253,50000.0,0.176700010895729,4.55711841583252,10000.0,13295.302300453186,13771.746730804443,13295.302300453186,474.2763805389404,0.7922139167785645,0.0 -39500,1.141682,2.3994439,,,,,,,,,,,,,, -39600,1.1598716,2.2940702,,,,,,,,,,,,,, -39700,1.1973305,2.2917979,,,,,,,,,,,,,, -39800,1.2906992,2.3932984,,,,,,,,,,,,,, -39900,1.1726987,2.2040808,,,,,,,,,,,,,, -40000,1.1923237,2.2427824,,,,,,,,,,,,,, -40100,1.2999395,2.2845151,,,,,,,,,,,,,, -40200,1.135861,2.2587094,,,,,,,,,,,,,, -40300,1.2140622,2.3584688,,,,,,,,,,,,,, -40400,1.1014762,2.2422104,,,,,,,,,,,,,, -40500,1.2757204,2.4420764,,,,,,,,,,,,,, -40600,1.2568312,2.1869113,,,,,,,,,,,,,, -40700,1.0730609,2.222753,,,,,,,,,,,,,, -40800,1.0714976,2.2382941,,,,,,,,,,,,,, -40900,1.1932117,2.379607,,,,,,,,,,,,,, -40933,,,0.2283362448215484,3.997976064682007,0.2191199958324432,4.096891403198242,50000.0,0.1660000085830688,4.765476703643799,10000.0,13805.345292568209,14299.461974859238,13805.345292568209,491.86146664619446,0.8255927562713623,0.0 -41000,1.0795816,2.217747,,,,,,,,,,,,,, -41100,1.091995,2.1695175,,,,,,,,,,,,,, -41200,1.1599079,2.3545556,,,,,,,,,,,,,, -41300,1.1831014,2.287053,,,,,,,,,,,,,, -41400,1.1519916,2.2158895,,,,,,,,,,,,,, -41500,1.2348638,2.2342327,,,,,,,,,,,,,, -41600,1.2700344,2.3621492,,,,,,,,,,,,,, -41700,1.1219885,2.3080945,,,,,,,,,,,,,, -41800,1.2021514,2.2659934,,,,,,,,,,,,,, -41900,1.1419921,2.116197,,,,,,,,,,,,,, -42000,1.1354159,2.32734,,,,,,,,,,,,,, -42100,1.1585116,2.2680926,,,,,,,,,,,,,, -42200,1.1217132,2.2827594,,,,,,,,,,,,,, -42300,1.1504611,2.1023567,,,,,,,,,,,,,, -42400,1.1769906,2.1709774,,,,,,,,,,,,,, -42453,,,0.3529775142669678,3.090325355529785,0.3198799788951874,3.338887929916382,50000.0,0.2339000105857849,4.010398387908936,10000.0,14315.4254693985,14827.108598709106,14315.4254693985,509.33509039878845,0.8656878471374512,0.0 -42500,1.3215932,2.2529182,,,,,,,,,,,,,, -42600,1.1152836,2.2044876,,,,,,,,,,,,,, -42700,1.2277851,2.223343,,,,,,,,,,,,,, -42800,1.147374,2.3378308,,,,,,,,,,,,,, -42900,1.11539,2.2670321,,,,,,,,,,,,,, -43000,1.1629881,2.3146162,,,,,,,,,,,,,, -43100,1.2507119,2.2873971,,,,,,,,,,,,,, -43200,1.1531382,2.1819792,,,,,,,,,,,,,, -43300,1.2377305,2.2541666,,,,,,,,,,,,,, -43400,1.2168634,2.2759473,,,,,,,,,,,,,, -43500,1.1892529,2.2977006,,,,,,,,,,,,,, -43600,1.1379642,2.1314504,,,,,,,,,,,,,, -43700,1.1871324,2.1173863,,,,,,,,,,,,,, -43800,1.3170491,2.3492262,,,,,,,,,,,,,, -43900,1.1707662,2.2060494,,,,,,,,,,,,,, -43973,,,0.1209542378783226,5.900317192077637,0.1139400005340576,6.015324115753174,50000.0,0.0865000039339065,6.511098384857178,10000.0,14825.367683410645,15354.713600158691,14825.367683410645,526.9106154441833,0.900765895843506,0.0 -44000,1.2746193,2.3360667,,,,,,,,,,,,,, -44100,1.138756,2.2043204,,,,,,,,,,,,,, -44200,1.309788,2.3071284,,,,,,,,,,,,,, -44300,1.1433198,2.199205,,,,,,,,,,,,,, -44400,1.1639446,2.2428255,,,,,,,,,,,,,, -44500,1.1187729,2.2794921,,,,,,,,,,,,,, -44600,1.3050573,2.4052029,,,,,,,,,,,,,, -44700,1.1379427,2.3473716,,,,,,,,,,,,,, -44800,1.2274169,2.0937226,,,,,,,,,,,,,, -44900,1.1329552,2.202341,,,,,,,,,,,,,, -45000,1.164899,2.3651545,,,,,,,,,,,,,, -45100,1.1314526,2.1743836,,,,,,,,,,,,,, -45200,1.1089981,2.2271855,,,,,,,,,,,,,, -45300,1.1965108,2.29084,,,,,,,,,,,,,, -45400,1.2047043,2.2070532,,,,,,,,,,,,,, -45493,,,0.1727319806814193,4.664542675018311,0.1613999903202057,4.75282621383667,50000.0,0.1197000071406364,5.432078361511231,10000.0,15335.417038440704,15882.430266857147,15335.417038440704,544.4883089065552,0.9369504451751708,0.0 -45500,1.2284088,2.2565389,,,,,,,,,,,,,, -45600,1.1089945,2.3113089,,,,,,,,,,,,,, -45700,1.1450243,2.291527,,,,,,,,,,,,,, -45800,1.1474417,2.2243257,,,,,,,,,,,,,, -45900,1.2021486,2.1527286,,,,,,,,,,,,,, -46000,1.1265191,2.2175128,,,,,,,,,,,,,, -46100,1.2151992,2.3347263,,,,,,,,,,,,,, -46200,1.4563937,2.2152073,,,,,,,,,,,,,, -46300,1.1201432,2.1294649,,,,,,,,,,,,,, -46400,1.1139355,2.3171482,,,,,,,,,,,,,, -46500,1.1960089,2.171061,,,,,,,,,,,,,, -46600,1.1531498,2.2911284,,,,,,,,,,,,,, -46700,1.2011095,2.1336133,,,,,,,,,,,,,, -46800,1.253277,2.3737998,,,,,,,,,,,,,, -46900,1.2589068,2.353378,,,,,,,,,,,,,, -47000,1.2413757,2.164325,,,,,,,,,,,,,, -47014,,,0.1693638414144516,4.89009428024292,0.1624599993228912,4.977695465087891,50000.0,0.1162000074982643,5.645530700683594,10000.0,15845.601004362106,16410.10750222206,15845.601004362106,561.8923692703247,0.9734163284301758,0.0 -47100,1.1505879,2.1734953,,,,,,,,,,,,,, -47200,1.1774551,2.1781046,,,,,,,,,,,,,, -47300,1.2178837,2.2953749,,,,,,,,,,,,,, -47400,1.2208352,2.1506224,,,,,,,,,,,,,, -47500,1.1005902,2.1996953,,,,,,,,,,,,,, -47600,1.1339356,2.339478,,,,,,,,,,,,,, -47700,1.235223,2.115238,,,,,,,,,,,,,, -47800,1.1935698,2.2697372,,,,,,,,,,,,,, -47900,1.3116937,2.2249804,,,,,,,,,,,,,, -48000,1.3003602,2.2663062,,,,,,,,,,,,,, -48100,1.2493649,2.26376,,,,,,,,,,,,,, -48200,1.2243618,2.3433394,,,,,,,,,,,,,, -48300,1.2481505,2.3147373,,,,,,,,,,,,,, -48400,1.2293264,2.2951908,,,,,,,,,,,,,, -48500,1.195166,2.2750046,,,,,,,,,,,,,, -48534,,,0.186902105808258,4.8144307136535645,0.1775399893522262,4.926436901092529,50000.0,0.1300000101327896,5.572903156280518,10000.0,16355.56469798088,16938.182448148727,16355.56469798088,579.9127895832062,1.0119578838348389,0.0 -48600,1.0925804,2.2685156,,,,,,,,,,,,,, -48700,1.2200773,2.1336951,,,,,,,,,,,,,, -48800,1.178705,2.3435268,,,,,,,,,,,,,, -48900,1.0688106,2.1435642,,,,,,,,,,,,,, -49000,1.1420201,2.1481729,,,,,,,,,,,,,, -49100,1.2038609,2.2073042,,,,,,,,,,,,,, -49200,1.1227615,2.2040238,,,,,,,,,,,,,, -49300,1.1884891,2.1361675,,,,,,,,,,,,,, -49400,1.1952146,2.2316027,,,,,,,,,,,,,, -49500,1.219407,2.1750727,,,,,,,,,,,,,, -49600,1.2376033,2.1870713,,,,,,,,,,,,,, -49700,1.2832037,2.2472672,,,,,,,,,,,,,, -49800,1.1559166,2.123278,,,,,,,,,,,,,, -49900,1.3625917,2.2177386,,,,,,,,,,,,,, -50000,1.1770804,2.2331307,,,,,,,,,,,,,, -50054,,,0.2864915430545807,3.610297918319702,0.2510800063610077,3.918847322463989,50000.0,0.1940000057220459,4.523690700531006,10000.0,16865.721137285233,17465.871644973755,16865.721137285233,597.3582053184509,1.0468201637268066,0.0 -50100,1.1613538,2.2271862,,,,,,,,,,,,,, -50200,1.1863817,2.3344908,,,,,,,,,,,,,, -50300,1.1252403,2.18002,,,,,,,,,,,,,, -50400,1.2554146,2.3422651,,,,,,,,,,,,,, -50500,1.108719,2.208664,,,,,,,,,,,,,, -50600,1.3001757,2.282577,,,,,,,,,,,,,, -50700,1.1989826,2.2182105,,,,,,,,,,,,,, -50800,1.2189732,2.297625,,,,,,,,,,,,,, -50900,1.2977101,2.35076,,,,,,,,,,,,,, -51000,1.3385745,2.2948458,,,,,,,,,,,,,, -51100,1.2492583,2.2652206,,,,,,,,,,,,,, -51200,1.1684144,2.1959934,,,,,,,,,,,,,, -51300,1.2610687,2.335366,,,,,,,,,,,,,, -51400,1.1133282,2.1453063,,,,,,,,,,,,,, -51500,1.1603036,2.2882135,,,,,,,,,,,,,, -51574,,,0.2438815385103225,4.0654730796813965,0.2254599928855896,4.252574920654297,50000.0,0.1734000146389007,4.786422729492188,10000.0,17375.71496105194,17993.777713537216,17375.71496105194,615.1842617988586,1.081218957901001,0.0 -51600,1.2600431,2.183092,,,,,,,,,,,,,, -51700,1.4588836,2.2754283,,,,,,,,,,,,,, -51800,1.258359,2.2732568,,,,,,,,,,,,,, -51900,1.3348774,2.2881289,,,,,,,,,,,,,, -52000,1.221551,2.177152,,,,,,,,,,,,,, -52100,1.3216555,2.256496,,,,,,,,,,,,,, -52200,1.3052217,2.225309,,,,,,,,,,,,,, -52300,1.1710504,2.2925954,,,,,,,,,,,,,, -52400,1.1638943,2.243079,,,,,,,,,,,,,, -52500,1.2398175,2.2348037,,,,,,,,,,,,,, -52600,1.2970774,2.3635416,,,,,,,,,,,,,, -52700,1.3153019,2.305048,,,,,,,,,,,,,, -52800,1.2637073,2.2476158,,,,,,,,,,,,,, -52900,1.1415884,2.1434093,,,,,,,,,,,,,, -53000,1.1414081,2.2333097,,,,,,,,,,,,,, -53094,,,0.2487045526504516,4.032920837402344,0.2283199876546859,4.205615043640137,50000.0,0.175800010561943,4.857440948486328,10000.0,17885.76908183098,18521.48242330551,17885.76908183098,632.7463145256042,1.1168007850646973,0.0 -53100,1.1775641,2.1930158,,,,,,,,,,,,,, -53200,1.2577778,2.2181475,,,,,,,,,,,,,, -53300,1.2755166,2.2253952,,,,,,,,,,,,,, -53400,1.175726,2.180716,,,,,,,,,,,,,, -53500,1.2941196,2.1546657,,,,,,,,,,,,,, -53600,1.3104514,2.0482504,,,,,,,,,,,,,, -53700,1.4086975,2.120734,,,,,,,,,,,,,, -53800,1.2824152,2.1900332,,,,,,,,,,,,,, -53900,1.145394,2.1066465,,,,,,,,,,,,,, -54000,1.1807146,2.2265842,,,,,,,,,,,,,, -54100,1.1649711,2.2251382,,,,,,,,,,,,,, -54200,1.1733809,2.1615577,,,,,,,,,,,,,, -54300,1.1359904,2.190796,,,,,,,,,,,,,, -54400,1.2389354,2.19826,,,,,,,,,,,,,, -54500,1.357779,2.2264318,,,,,,,,,,,,,, -54600,1.2092628,2.251001,,,,,,,,,,,,,, -54614,,,0.3034518361091614,3.569441318511963,0.2828399837017059,3.7012898921966553,50000.0,0.2151000052690506,4.410431385040283,10000.0,18395.96377325058,19049.366036891937,18395.96377325058,650.3460447788239,1.1529819965362549,0.0 -54700,1.2962762,2.1914032,,,,,,,,,,,,,, -54800,1.2010492,2.2728229,,,,,,,,,,,,,, -54900,1.2505826,2.1142273,,,,,,,,,,,,,, -55000,1.2025983,2.2563145,,,,,,,,,,,,,, -55100,1.290044,2.3589382,,,,,,,,,,,,,, -55200,1.2220218,2.278217,,,,,,,,,,,,,, -55300,1.2712172,2.174156,,,,,,,,,,,,,, -55400,1.4453984,2.252089,,,,,,,,,,,,,, -55500,1.1665319,2.208056,,,,,,,,,,,,,, -55600,1.2033546,2.2290714,,,,,,,,,,,,,, -55700,1.2209855,2.315952,,,,,,,,,,,,,, -55800,1.2376837,2.3660173,,,,,,,,,,,,,, -55900,1.3187368,2.1192176,,,,,,,,,,,,,, -56000,1.1735492,2.1120453,,,,,,,,,,,,,, -56100,1.2159767,2.0004022,,,,,,,,,,,,,, -56135,,,0.2638711631298065,3.8232076168060303,0.2496399879455566,3.983182907104492,50000.0,0.1838000118732452,4.662926197052002,10000.0,18906.143671512604,19577.08617591858,18906.143671512604,667.7958898544312,1.19061279296875,0.0 -56200,1.3582785,2.2503195,,,,,,,,,,,,,, -56300,1.1754497,2.2685628,,,,,,,,,,,,,, -56400,1.2814972,2.1166315,,,,,,,,,,,,,, -56500,1.3143522,2.142806,,,,,,,,,,,,,, -56600,1.3325287,2.316032,,,,,,,,,,,,,, -56700,1.4857954,2.093885,,,,,,,,,,,,,, -56800,1.2929871,2.1869507,,,,,,,,,,,,,, -56900,1.4527348,2.2967093,,,,,,,,,,,,,, -57000,1.1576769,2.2242734,,,,,,,,,,,,,, -57100,1.2310332,2.2284932,,,,,,,,,,,,,, -57200,1.3654449,2.3122826,,,,,,,,,,,,,, -57300,1.241922,2.2367964,,,,,,,,,,,,,, -57400,1.2299455,2.139795,,,,,,,,,,,,,, -57500,1.2479143,2.1871202,,,,,,,,,,,,,, -57600,1.2815744,2.1987886,,,,,,,,,,,,,, -57656,,,0.1876992881298065,4.809234619140625,0.1772599965333938,4.956613540649414,50000.0,0.1345000118017196,5.564505100250244,10000.0,19416.319585561752,20104.69949555397,19416.319585561752,685.1396398544312,1.230280876159668,0.0 -57700,1.2781274,2.231355,,,,,,,,,,,,,, -57800,1.3305715,2.279661,,,,,,,,,,,,,, -57900,1.2822868,2.071399,,,,,,,,,,,,,, -58000,1.1766698,2.2101336,,,,,,,,,,,,,, -58100,1.2698709,2.2916932,,,,,,,,,,,,,, -58200,1.22235,2.1726723,,,,,,,,,,,,,, -58300,1.2688819,2.161992,,,,,,,,,,,,,, -58400,1.2419268,2.2735283,,,,,,,,,,,,,, -58500,1.2377261,2.2960322,,,,,,,,,,,,,, -58600,1.2569908,2.229262,,,,,,,,,,,,,, -58700,1.3708053,2.1619244,,,,,,,,,,,,,, -58800,1.2595487,2.243386,,,,,,,,,,,,,, -58900,1.2168791,2.2306705,,,,,,,,,,,,,, -59000,1.5498207,2.274119,,,,,,,,,,,,,, -59100,1.3405313,2.2992125,,,,,,,,,,,,,, -59177,,,0.1821189373731613,4.640640258789063,0.1666799932718277,4.858196258544922,50000.0,0.1234000027179718,5.476072311401367,10000.0,19926.450991630554,20632.49976873398,19926.450991630554,702.7179343700409,1.2668704986572266,0.0 -59200,1.1869941,2.336993,,,,,,,,,,,,,, -59300,1.2048548,2.2251444,,,,,,,,,,,,,, -59400,1.2467902,2.360833,,,,,,,,,,,,,, -59500,1.3031679,2.1342773,,,,,,,,,,,,,, -59600,1.2553521,2.1146548,,,,,,,,,,,,,, -59700,1.4032654,2.2337644,,,,,,,,,,,,,, -59800,1.2951113,2.2821214,,,,,,,,,,,,,, -59900,1.2244983,2.084636,,,,,,,,,,,,,, -60000,1.3941427,2.2118645,,,,,,,,,,,,,, -60100,1.24718,2.2881594,,,,,,,,,,,,,, -60200,1.3003553,2.0585005,,,,,,,,,,,,,, -60300,1.2585893,2.301308,,,,,,,,,,,,,, -60400,1.3916578,2.16033,,,,,,,,,,,,,, -60500,1.2422297,2.1123846,,,,,,,,,,,,,, -60600,1.2131872,2.1327562,,,,,,,,,,,,,, -60698,,,0.3432318270206451,3.1367218494415283,0.3236799836158752,3.3062968254089355,50000.0,0.226500004529953,4.168323040008545,10000.0,20436.43869829178,21159.79110598564,20436.43869829178,719.9272334575653,1.3086819648742676,0.0 -60700,1.2223521,1.9583901,,,,,,,,,,,,,, -60800,1.3181264,2.2705734,,,,,,,,,,,,,, -60900,1.221318,2.109726,,,,,,,,,,,,,, -61000,1.2805917,2.3436952,,,,,,,,,,,,,, -61100,1.3703879,2.2880487,,,,,,,,,,,,,, -61200,1.3603965,2.020271,,,,,,,,,,,,,, -61300,1.1651881,2.1391842,,,,,,,,,,,,,, -61400,1.1840466,2.2318635,,,,,,,,,,,,,, -61500,1.414543,2.1515567,,,,,,,,,,,,,, -61600,1.1264204,2.1515532,,,,,,,,,,,,,, -61700,1.2985089,2.1358438,,,,,,,,,,,,,, -61800,1.1813109,2.1046019,,,,,,,,,,,,,, -61900,1.2076025,2.2072697,,,,,,,,,,,,,, -62000,1.1788443,2.2657986,,,,,,,,,,,,,, -62100,1.2061828,2.1022847,,,,,,,,,,,,,, -62200,1.4446312,2.2182755,,,,,,,,,,,,,, -62219,,,0.2721022069454193,3.687821388244629,0.2554399967193603,3.826228618621826,50000.0,0.1844000071287155,4.584669589996338,10000.0,20946.65442109108,21687.56720471382,20946.65442109108,737.3943648338318,1.3493919372558594,0.0 -62300,1.3599222,2.0623698,,,,,,,,,,,,,, -62400,1.4721979,2.2146509,,,,,,,,,,,,,, -62500,1.2991657,2.1552098,,,,,,,,,,,,,, -62600,1.295462,2.1635668,,,,,,,,,,,,,, -62700,1.3054291,2.2524395,,,,,,,,,,,,,, -62800,1.1640506,2.1035352,,,,,,,,,,,,,, -62900,1.4374323,2.2786021,,,,,,,,,,,,,, -63000,1.5023699,2.1546776,,,,,,,,,,,,,, -63100,1.1921194,2.1964822,,,,,,,,,,,,,, -63200,1.2920759,2.0979996,,,,,,,,,,,,,, -63300,1.4450004,2.2414155,,,,,,,,,,,,,, -63400,1.319046,2.132713,,,,,,,,,,,,,, -63500,1.2145574,2.0325062,,,,,,,,,,,,,, -63600,1.398999,2.27971,,,,,,,,,,,,,, -63700,1.3979291,2.2594438,,,,,,,,,,,,,, -63739,,,0.2697106003761291,3.7855687141418457,0.2567200064659118,3.871646881103516,50000.0,0.1894000023603439,4.6342620849609375,10000.0,21456.615305900574,22215.171835184097,21456.615305900574,754.9472448825836,1.3876104354858398,0.0 -63800,1.1688988,2.219031,,,,,,,,,,,,,, -63900,1.2958722,2.1541834,,,,,,,,,,,,,, -64000,1.3845078,2.0895414,,,,,,,,,,,,,, -64100,1.2993189,2.0960975,,,,,,,,,,,,,, -64200,1.1900225,2.1268609,,,,,,,,,,,,,, -64300,1.2540572,2.2265935,,,,,,,,,,,,,, -64400,1.3742112,2.284755,,,,,,,,,,,,,, -64500,1.300012,2.267433,,,,,,,,,,,,,, -64600,1.34847,2.1260054,,,,,,,,,,,,,, -64700,1.2084994,2.0959616,,,,,,,,,,,,,, -64800,1.3398702,2.2281787,,,,,,,,,,,,,, -64900,1.4172659,2.1706865,,,,,,,,,,,,,, -65000,1.2698103,2.070382,,,,,,,,,,,,,, -65100,1.4833398,2.1646786,,,,,,,,,,,,,, -65200,1.3191619,2.1414208,,,,,,,,,,,,,, -65260,,,0.2284757643938064,4.340119361877441,0.2192399948835373,4.4335126876831055,50000.0,0.1604000031948089,5.25397253036499,10000.0,21966.762639045715,22742.61117172241,21966.762639045715,772.145806312561,1.428035020828247,0.0 -65300,1.1584703,2.0797071,,,,,,,,,,,,,, -65400,1.2521023,2.1205032,,,,,,,,,,,,,, -65500,1.2641058,2.224318,,,,,,,,,,,,,, -65600,1.280125,2.0651865,,,,,,,,,,,,,, -65700,1.2905171,2.1914039,,,,,,,,,,,,,, -65800,1.2370166,2.1655495,,,,,,,,,,,,,, -65900,1.35732,2.1613991,,,,,,,,,,,,,, -66000,1.2750626,2.2889745,,,,,,,,,,,,,, -66100,1.2034771,2.1221774,,,,,,,,,,,,,, -66200,1.4512893,2.1506577,,,,,,,,,,,,,, -66300,1.3507015,2.066817,,,,,,,,,,,,,, -66400,1.379586,2.147145,,,,,,,,,,,,,, -66500,1.2963996,2.094911,,,,,,,,,,,,,, -66600,1.4208623,2.2102349,,,,,,,,,,,,,, -66700,1.3346509,2.200803,,,,,,,,,,,,,, -66781,,,0.3475964665412903,3.1049602031707764,0.3361199796199798,3.2092509269714355,50000.0,0.2578000128269195,3.982248544692993,10000.0,22476.804877758022,23270.36111807823,22476.804877758022,789.7588932514191,1.4698281288146973,0.0 -66800,1.3083788,2.0615687,,,,,,,,,,,,,, -66900,1.2909884,2.1421487,,,,,,,,,,,,,, -67000,1.4364381,2.1977704,,,,,,,,,,,,,, -67100,1.2423977,2.1128109,,,,,,,,,,,,,, -67200,1.342126,2.106711,,,,,,,,,,,,,, -67300,1.5425236,2.2110586,,,,,,,,,,,,,, -67400,1.3921753,2.1555748,,,,,,,,,,,,,, -67500,1.3271106,2.1539173,,,,,,,,,,,,,, -67600,1.2092285,2.0888648,,,,,,,,,,,,,, -67700,1.3392053,2.1939633,,,,,,,,,,,,,, -67800,1.5660906,2.1548996,,,,,,,,,,,,,, -67900,1.3881238,2.2195218,,,,,,,,,,,,,, -68000,1.2626013,2.0747514,,,,,,,,,,,,,, -68100,1.3451196,2.2508373,,,,,,,,,,,,,, -68200,1.2794794,2.1444063,,,,,,,,,,,,,, -68300,1.463734,2.2278237,,,,,,,,,,,,,, -68302,,,0.351283460855484,3.137704372406006,0.3225999772548675,3.393313407897949,50000.0,0.2359000146389007,4.184186458587647,10000.0,22986.811591625214,23798.096511363983,22986.811591625214,807.3938567638397,1.5102369785308838,0.0 -68400,1.2852749,2.0906684,,,,,,,,,,,,,, -68500,1.3951111,2.2109675,,,,,,,,,,,,,, -68600,1.3264312,2.0542307,,,,,,,,,,,,,, -68700,1.3812215,2.0802088,,,,,,,,,,,,,, -68800,1.297981,2.1360798,,,,,,,,,,,,,, -68900,1.3102999,2.2561803,,,,,,,,,,,,,, -69000,1.4532195,2.2518044,,,,,,,,,,,,,, -69100,1.3269395,2.2978966,,,,,,,,,,,,,, -69200,1.4098177,2.1564143,,,,,,,,,,,,,, -69300,1.389827,2.2218273,,,,,,,,,,,,,, -69400,1.4194782,2.1414423,,,,,,,,,,,,,, -69500,1.2634007,2.1165295,,,,,,,,,,,,,, -69600,1.3066287,2.021816,,,,,,,,,,,,,, -69700,1.3016928,2.0513752,,,,,,,,,,,,,, -69800,1.5344946,2.1307662,,,,,,,,,,,,,, -69823,,,0.1961694806814193,4.434426784515381,0.1842399984598159,4.5618109703063965,50000.0,0.1305000036954879,5.301344871520996,10000.0,23496.993092536926,24325.841685056686,23496.993092536926,824.8659207820892,1.5485587120056152,0.0 -69900,1.3709322,2.1514482,,,,,,,,,,,,,, -70000,1.211232,2.029714,,,,,,,,,,,,,, -70100,1.3268038,2.1195905,,,,,,,,,,,,,, -70200,1.3208529,2.1411004,,,,,,,,,,,,,, -70300,1.2685177,2.1914089,,,,,,,,,,,,,, -70400,1.2645019,2.1017058,,,,,,,,,,,,,, -70500,1.4890281,2.1805096,,,,,,,,,,,,,, -70600,1.335319,2.1832335,,,,,,,,,,,,,, -70700,1.3259491,2.128024,,,,,,,,,,,,,, -70800,1.2491518,2.1647363,,,,,,,,,,,,,, -70900,1.4098464,2.0715232,,,,,,,,,,,,,, -71000,1.3678387,2.19059,,,,,,,,,,,,,, -71100,1.3101732,2.0287106,,,,,,,,,,,,,, -71200,1.2587582,2.0912807,,,,,,,,,,,,,, -71300,1.4220569,2.0445313,,,,,,,,,,,,,, -71344,,,0.3081353604793548,3.636435031890869,0.2773399949073791,3.8866045475006095,50000.0,0.2132000029087066,4.583950042724609,10000.0,24007.092417240143,24853.448628664017,24007.092417240143,842.2808480262756,1.5876126289367676,0.0 -71400,1.3202904,2.214835,,,,,,,,,,,,,, -71500,1.348448,2.0698647,,,,,,,,,,,,,, -71600,1.4231713,2.2718818,,,,,,,,,,,,,, -71700,1.3852022,2.1348135,,,,,,,,,,,,,, -71800,1.4719656,2.1931012,,,,,,,,,,,,,, -71900,1.3386292,2.107933,,,,,,,,,,,,,, -72000,1.4188039,2.2081223,,,,,,,,,,,,,, -72100,1.3396639,2.0644226,,,,,,,,,,,,,, -72200,1.3547474,2.034634,,,,,,,,,,,,,, -72300,1.4273322,2.166106,,,,,,,,,,,,,, -72400,1.4070945,2.1512363,,,,,,,,,,,,,, -72500,1.3251113,2.1078303,,,,,,,,,,,,,, -72600,1.4054943,2.2201564,,,,,,,,,,,,,, -72700,1.3931928,2.3155046,,,,,,,,,,,,,, -72800,1.3144741,2.1902194,,,,,,,,,,,,,, -72865,,,0.2721420526504516,3.783998489379883,0.2586399912834167,3.9203364849090576,50000.0,0.1842000037431717,4.706700325012207,10000.0,24517.10163712501,25380.833248138428,24517.10163712501,859.5648393630981,1.626474142074585,0.0 -72900,1.5091932,2.1636333,,,,,,,,,,,,,, -73000,1.3596467,2.0669138,,,,,,,,,,,,,, -73100,1.4177258,2.1393833,,,,,,,,,,,,,, -73200,1.4088988,2.0767725,,,,,,,,,,,,,, -73300,1.3442119,2.1638622,,,,,,,,,,,,,, -73400,1.4413459,2.070244,,,,,,,,,,,,,, -73500,1.5427582,2.1452231,,,,,,,,,,,,,, -73600,1.3755965,2.107294,,,,,,,,,,,,,, -73700,1.320261,2.107066,,,,,,,,,,,,,, -73800,1.3287057,2.1957898,,,,,,,,,,,,,, -73900,1.3073112,2.1380665,,,,,,,,,,,,,, -74000,1.4837604,2.0449579,,,,,,,,,,,,,, -74100,1.4511213,2.2627459,,,,,,,,,,,,,, -74200,1.3396544,2.007286,,,,,,,,,,,,,, -74300,1.4406685,2.1000042,,,,,,,,,,,,,, -74387,,,0.407904177904129,2.691690921783448,0.3840200006961822,2.868121385574341,50000.0,0.2869000136852264,3.61729907989502,10000.0,25027.26301074028,25909.73645925522,25027.26301074028,878.2137405872345,1.6661872863769531,0.0 -74400,1.3436284,1.9887583,,,,,,,,,,,,,, -74500,1.3217114,2.1085918,,,,,,,,,,,,,, -74600,1.5579066,2.1900632,,,,,,,,,,,,,, -74700,1.3907925,2.2004275,,,,,,,,,,,,,, -74800,1.3712547,2.233715,,,,,,,,,,,,,, -74900,1.4220859,2.1255448,,,,,,,,,,,,,, -75000,1.3314859,2.2579625,,,,,,,,,,,,,, -75100,1.3769468,2.1818976,,,,,,,,,,,,,, -75200,1.6036681,2.0944347,,,,,,,,,,,,,, -75300,1.4777576,2.1312351,,,,,,,,,,,,,, -75400,1.4284749,2.1131234,,,,,,,,,,,,,, -75500,1.490012,2.2453742,,,,,,,,,,,,,, -75600,1.4103885,2.1567216,,,,,,,,,,,,,, -75700,1.353441,2.1325543,,,,,,,,,,,,,, -75800,1.3828275,2.1255085,,,,,,,,,,,,,, -75900,1.2740334,1.9988296,,,,,,,,,,,,,, -75908,,,0.4138033986091614,2.737895965576172,0.3630799949169159,3.0857176780700684,50000.0,0.2714000046253204,3.899897336959839,10000.0,25537.45592713356,26437.587026834488,25537.45592713356,895.784318447113,1.700392484664917,0.0 -76000,1.439059,2.176226,,,,,,,,,,,,,, -76100,1.3696989,2.0116606,,,,,,,,,,,,,, -76200,1.707632,2.3011897,,,,,,,,,,,,,, -76300,1.485272,2.0919378,,,,,,,,,,,,,, -76400,1.4005464,2.0705318,,,,,,,,,,,,,, -76500,1.4681664,2.14749,,,,,,,,,,,,,, -76600,1.5382358,2.1495333,,,,,,,,,,,,,, -76700,1.357747,2.0650787,,,,,,,,,,,,,, -76800,1.3919103,2.194432,,,,,,,,,,,,,, -76900,1.4363517,2.098213,,,,,,,,,,,,,, -77000,1.358306,2.0825493,,,,,,,,,,,,,, -77100,1.4040009,2.133988,,,,,,,,,,,,,, -77200,1.3625753,2.0124495,,,,,,,,,,,,,, -77300,1.397188,2.2034798,,,,,,,,,,,,,, -77400,1.3388377,2.0593376,,,,,,,,,,,,,, -77429,,,0.3894292116165161,2.835314989089966,0.3652399778366089,3.049356460571289,50000.0,0.2776000201702118,3.813933372497559,10000.0,26047.5850622654,26965.33577370644,26047.5850622654,913.3123028278352,1.7399368286132812,0.0 -77500,1.2936075,2.0436862,,,,,,,,,,,,,, -77600,1.4233009,2.2204251,,,,,,,,,,,,,, -77700,1.4229616,2.0851316,,,,,,,,,,,,,, -77800,1.3407058,1.9287436,,,,,,,,,,,,,, -77900,1.4412316,2.125263,,,,,,,,,,,,,, -78000,1.3785483,2.0746274,,,,,,,,,,,,,, -78100,1.3662372,2.0846264,,,,,,,,,,,,,, -78200,1.4426643,2.1127474,,,,,,,,,,,,,, -78300,1.440229,2.0522828,,,,,,,,,,,,,, -78400,1.4125369,2.1666234,,,,,,,,,,,,,, -78500,1.3632141,2.1011095,,,,,,,,,,,,,, -78600,1.3943751,2.0234873,,,,,,,,,,,,,, -78700,1.4038557,2.0473585,,,,,,,,,,,,,, -78800,1.3984895,2.1723013,,,,,,,,,,,,,, -78900,1.4866599,2.1922152,,,,,,,,,,,,,, -78950,,,0.3610690236091614,3.041414499282837,0.3311599791049957,3.252563238143921,50000.0,0.2470000088214874,3.9893369674682617,10000.0,26557.61221885681,27492.91514992714,26557.61221885681,930.7702312469482,1.7816412448883057,0.0 -79000,1.3286543,2.0948954,,,,,,,,,,,,,, -79100,1.2162874,1.9832377,,,,,,,,,,,,,, -79200,1.4294422,1.9310538,,,,,,,,,,,,,, -79300,1.5239933,2.3147407,,,,,,,,,,,,,, -79400,1.2973118,2.0425177,,,,,,,,,,,,,, -79500,1.3797829,2.0152993,,,,,,,,,,,,,, -79600,1.6445037,1.995731,,,,,,,,,,,,,, -79700,1.4715239,2.0965016,,,,,,,,,,,,,, -79800,1.4554739,2.119524,,,,,,,,,,,,,, -79900,1.4235847,2.1171072,,,,,,,,,,,,,, -80000,1.4306687,2.1045177,,,,,,,,,,,,,, -80100,1.3410301,2.11522,,,,,,,,,,,,,, -80200,1.388632,2.1180482,,,,,,,,,,,,,, -80300,1.408719,2.1338646,,,,,,,,,,,,,, -80400,1.3892192,2.1389399,,,,,,,,,,,,,, -80471,,,0.3260921537876129,3.3770241737365723,0.3070800006389618,3.5291993618011475,50000.0,0.2220000177621841,4.42165994644165,10000.0,27067.55645370484,28020.69840979576,27067.55645370484,948.5069932937622,1.830293893814087,0.0 -80500,1.3626457,2.1293888,,,,,,,,,,,,,, -80600,1.4327815,2.0079618,,,,,,,,,,,,,, -80700,1.4065788,2.0517967,,,,,,,,,,,,,, -80800,1.5884573,2.1947107,,,,,,,,,,,,,, -80900,1.3687601,2.0834038,,,,,,,,,,,,,, -81000,1.4126686,2.1461902,,,,,,,,,,,,,, -81100,1.8260866,2.1679852,,,,,,,,,,,,,, -81200,1.7200571,2.13594,,,,,,,,,,,,,, -81300,1.6232358,2.1350794,,,,,,,,,,,,,, -81400,1.4461182,2.1504138,,,,,,,,,,,,,, -81500,1.5240195,2.0533633,,,,,,,,,,,,,, -81600,1.3881793,1.9693031,,,,,,,,,,,,,, -81700,1.4183706,2.1341348,,,,,,,,,,,,,, -81800,1.4902754,2.0866885,,,,,,,,,,,,,, -81900,1.5028882,2.0256078,,,,,,,,,,,,,, -81991,,,0.375019907951355,2.999455213546753,0.3500199913978576,3.140925168991089,50000.0,0.2578000128269195,3.9556169509887695,10000.0,27577.53954029084,28548.6232419014,27577.53954029084,966.3523087501526,1.8731324672698968,0.0 -82000,1.3528581,1.9682186,,,,,,,,,,,,,, -82100,1.3914175,2.1776152,,,,,,,,,,,,,, -82200,1.4130585,2.0812056,,,,,,,,,,,,,, -82300,1.4123588,2.0812092,,,,,,,,,,,,,, -82400,1.6191692,2.0035791,,,,,,,,,,,,,, -82500,1.3500501,2.0371425,,,,,,,,,,,,,, -82600,1.6093676,2.1394,,,,,,,,,,,,,, -82700,1.3588376,2.0615547,,,,,,,,,,,,,, -82800,1.460726,2.1843336,,,,,,,,,,,,,, -82900,1.7716298,2.0797808,,,,,,,,,,,,,, -83000,1.5332603,2.1550066,,,,,,,,,,,,,, -83100,1.5481888,2.1543088,,,,,,,,,,,,,, -83200,1.3409475,1.982057,,,,,,,,,,,,,, -83300,1.4718374,2.0595412,,,,,,,,,,,,,, -83400,1.4260522,2.0355535,,,,,,,,,,,,,, -83500,1.3722715,2.055532,,,,,,,,,,,,,, -83512,,,0.4471859037876129,2.482073307037353,0.4222399890422821,2.646937131881714,50000.0,0.3202000260353088,3.373796939849853,10000.0,28087.72601652145,29076.31339788437,28087.72601652145,983.7680652141572,1.908886432647705,0.0 -83600,1.4377861,2.0938087,,,,,,,,,,,,,, -83700,1.3982223,2.1058242,,,,,,,,,,,,,, -83800,1.7688781,2.1035242,,,,,,,,,,,,,, -83900,1.4693955,2.0635848,,,,,,,,,,,,,, -84000,1.4126484,2.153388,,,,,,,,,,,,,, -84100,1.4003974,2.1397436,,,,,,,,,,,,,, -84200,1.5416011,2.0072906,,,,,,,,,,,,,, -84300,1.4993801,1.9931217,,,,,,,,,,,,,, -84400,1.5347421,2.116965,,,,,,,,,,,,,, -84500,1.5241368,2.0333602,,,,,,,,,,,,,, -84600,1.4675493,2.0709333,,,,,,,,,,,,,, -84700,1.3901109,2.0050964,,,,,,,,,,,,,, -84800,1.5402198,2.0346766,,,,,,,,,,,,,, -84900,1.5488176,2.0972376,,,,,,,,,,,,,, -85000,1.7991151,2.1045523,,,,,,,,,,,,,, -85033,,,0.3713129758834839,2.982234001159668,0.339599996805191,3.212437391281128,50000.0,0.2476000189781189,4.028741359710693,10000.0,28597.715670108795,29603.63627338409,28597.715670108795,1001.0070824623108,1.950409889221192,0.0 -85100,1.6355762,2.0961227,,,,,,,,,,,,,, -85200,1.505868,2.0606627,,,,,,,,,,,,,, -85300,1.456564,2.0793896,,,,,,,,,,,,,, -85400,1.6020513,2.0195367,,,,,,,,,,,,,, -85500,1.6445415,2.1960251,,,,,,,,,,,,,, -85600,1.5032259,2.089395,,,,,,,,,,,,,, -85700,1.4376978,2.0986586,,,,,,,,,,,,,, -85800,1.5394341,2.1032462,,,,,,,,,,,,,, -85900,1.5743749,2.0425594,,,,,,,,,,,,,, -86000,1.4351727,2.0430524,,,,,,,,,,,,,, -86100,1.537129,2.1139936,,,,,,,,,,,,,, -86200,1.4809577,2.0076904,,,,,,,,,,,,,, -86300,1.4837317,2.1475554,,,,,,,,,,,,,, -86400,1.4828396,2.117584,,,,,,,,,,,,,, -86500,1.715567,2.1221955,,,,,,,,,,,,,, -86554,,,0.3778499662876129,3.040504217147827,0.3454599976539612,3.298660516738892,50000.0,0.2648999989032745,4.093916893005371,10000.0,29107.75165891648,30131.54461431504,29107.75165891648,1018.7795708179474,1.9979043006896973,0.0 -86600,1.7261854,1.9949191,,,,,,,,,,,,,, -86700,1.54206,2.1301582,,,,,,,,,,,,,, -86800,1.5549868,2.1535385,,,,,,,,,,,,,, -86900,1.5601493,2.0704741,,,,,,,,,,,,,, -87000,1.8888922,2.0141387,,,,,,,,,,,,,, -87100,1.638992,2.1912332,,,,,,,,,,,,,, -87200,1.5315466,2.0578985,,,,,,,,,,,,,, -87300,1.5224327,1.9704318,,,,,,,,,,,,,, -87400,1.4745235,1.962904,,,,,,,,,,,,,, -87500,1.5595845,1.9810221,,,,,,,,,,,,,, -87600,1.5665456,2.0009894,,,,,,,,,,,,,, -87700,1.487532,2.077395,,,,,,,,,,,,,, -87800,1.4888241,2.1329546,,,,,,,,,,,,,, -87900,1.4964025,2.0712845,,,,,,,,,,,,,, -88000,1.6078756,2.0940611,,,,,,,,,,,,,, -88074,,,0.3744818270206451,2.971901655197144,0.3493599891662597,3.141242742538452,50000.0,0.2532000243663788,3.9751439094543457,10000.0,29617.712538719177,30659.6422123909,29617.712538719177,1036.8213591575625,2.039724111557007,0.0 -88100,1.4857316,1.9922786,,,,,,,,,,,,,, -88200,1.4392728,2.1331282,,,,,,,,,,,,,, -88300,1.4859332,2.0413334,,,,,,,,,,,,,, -88400,1.6621202,2.0337195,,,,,,,,,,,,,, -88500,1.7188313,2.0555477,,,,,,,,,,,,,, -88600,1.4494594,1.9632711,,,,,,,,,,,,,, -88700,1.5649178,2.117024,,,,,,,,,,,,,, -88800,1.4352005,2.0642896,,,,,,,,,,,,,, -88900,1.4612055,2.0569363,,,,,,,,,,,,,, -89000,1.520504,2.1336582,,,,,,,,,,,,,, -89100,1.4762766,1.9269165,,,,,,,,,,,,,, -89200,1.4656249,1.9920475,,,,,,,,,,,,,, -89300,1.6209854,1.9297286,,,,,,,,,,,,,, -89400,1.6128402,2.063676,,,,,,,,,,,,,, -89500,1.5629563,2.0323045,,,,,,,,,,,,,, -89594,,,0.4371412396430969,2.587308406829834,0.4049800038337707,2.806240558624268,50000.0,0.3113000094890594,3.605199098587036,10000.0,30127.6463303566,31187.31401515007,30127.6463303566,1054.453256368637,2.091838836669922,0.0 -89600,1.4962178,2.0443482,,,,,,,,,,,,,, -89700,1.6797853,2.1360333,,,,,,,,,,,,,, -89800,1.5924383,2.015007,,,,,,,,,,,,,, -89900,1.4520932,2.0307038,,,,,,,,,,,,,, -90000,1.6358463,2.004251,,,,,,,,,,,,,, -90100,1.6249949,2.1403594,,,,,,,,,,,,,, -90200,1.5984734,2.0367002,,,,,,,,,,,,,, -90300,1.5205271,2.1047804,,,,,,,,,,,,,, -90400,1.66594,2.112791,,,,,,,,,,,,,, -90500,1.4493064,2.1299503,,,,,,,,,,,,,, -90600,1.6407233,1.9292471,,,,,,,,,,,,,, -90700,1.5403665,1.9796957,,,,,,,,,,,,,, -90800,1.4665837,2.1234972,,,,,,,,,,,,,, -90900,1.7103868,1.9932065,,,,,,,,,,,,,, -91000,1.6717116,2.0117269,,,,,,,,,,,,,, -91100,1.514685,1.8802985,,,,,,,,,,,,,, -91114,,,0.408223032951355,2.8384816646575928,0.3856199979782104,3.0218026638031006,50000.0,0.2911000251770019,3.8347907066345215,10000.0,30637.59294629097,31714.630335330963,30637.59294629097,1071.72580742836,2.1351735591888428,0.0 -91200,1.5808483,1.9754167,,,,,,,,,,,,,, -91300,1.6796281,2.0100057,,,,,,,,,,,,,, -91400,1.5390016,1.8605155,,,,,,,,,,,,,, -91500,1.6023697,2.0787451,,,,,,,,,,,,,, -91600,1.6122218,2.04202,,,,,,,,,,,,,, -91700,1.6033276,2.2369344,,,,,,,,,,,,,, -91800,1.6861613,1.953479,,,,,,,,,,,,,, -91900,1.62279,1.9038715,,,,,,,,,,,,,, -92000,1.5470893,2.02594,,,,,,,,,,,,,, -92100,1.5714693,1.9242623,,,,,,,,,,,,,, -92200,1.4924892,1.9246917,,,,,,,,,,,,,, -92300,1.6217902,1.9005182,,,,,,,,,,,,,, -92400,1.537269,1.9579678,,,,,,,,,,,,,, -92500,1.5476643,2.02611,,,,,,,,,,,,,, -92600,2.0378566,2.0159404,,,,,,,,,,,,,, -92635,,,0.4818638265132904,2.327448844909668,0.4464399814605713,2.534179210662842,50000.0,0.3418000042438507,3.3646998405456543,10000.0,31147.667982816696,32242.1359269619,31147.667982816696,1089.054959774017,2.183558702468872,0.0 -92700,1.4796227,1.9887192,,,,,,,,,,,,,, -92800,1.5542386,2.0927184,,,,,,,,,,,,,, -92900,1.6067091,2.0537252,,,,,,,,,,,,,, -93000,1.4942818,2.0059464,,,,,,,,,,,,,, -93100,1.6365273,2.0588624,,,,,,,,,,,,,, -93200,1.4164792,1.9607003,,,,,,,,,,,,,, -93300,1.6727854,2.0687704,,,,,,,,,,,,,, -93400,1.6577302,2.0982742,,,,,,,,,,,,,, -93500,1.6095254,1.9565732,,,,,,,,,,,,,, -93600,1.5995576,2.0254178,,,,,,,,,,,,,, -93700,1.6251543,2.0612788,,,,,,,,,,,,,, -93800,1.5293354,1.9342428,,,,,,,,,,,,,, -93900,1.8331549,2.1736732,,,,,,,,,,,,,, -94000,1.6608659,1.9645125,,,,,,,,,,,,,, -94100,1.6943177,2.074143,,,,,,,,,,,,,, -94156,,,0.5026904940605164,2.148667812347412,0.4616999924182892,2.3995540142059326,50000.0,0.3456000089645386,3.1860923767089844,10000.0,31657.68229198456,32769.84343409538,31657.68229198456,1106.6452696323397,2.233637571334839,0.0 -94200,1.6746011,2.0137024,,,,,,,,,,,,,, -94300,1.6528404,2.0956526,,,,,,,,,,,,,, -94400,1.6858157,2.0318608,,,,,,,,,,,,,, -94500,1.4163599,1.9565947,,,,,,,,,,,,,, -94600,1.6077379,1.9438125,,,,,,,,,,,,,, -94700,1.5787877,2.001948,,,,,,,,,,,,,, -94800,1.6926706,1.9401734,,,,,,,,,,,,,, -94900,1.5586843,1.9114052,,,,,,,,,,,,,, -95000,1.5473183,2.0762274,,,,,,,,,,,,,, -95100,1.5054843,1.9928199,,,,,,,,,,,,,, -95200,2.01188,2.005017,,,,,,,,,,,,,, -95300,1.6795352,2.1103175,,,,,,,,,,,,,, -95400,1.6078383,1.8628362,,,,,,,,,,,,,, -95500,1.492943,1.9992679,,,,,,,,,,,,,, -95600,1.6118029,1.9166565,,,,,,,,,,,,,, -95677,,,0.3589166104793548,3.0547327995300293,0.3370199799537658,3.236257076263428,50000.0,0.255700021982193,3.914307117462158,10000.0,32167.90555024147,33297.85503602028,32167.90555024147,1124.3362169265747,2.278075695037842,0.0 -95700,1.6480881,1.9584742,,,,,,,,,,,,,, -95800,1.7671634,2.0894835,,,,,,,,,,,,,, -95900,1.8479093,1.9659736,,,,,,,,,,,,,, -96000,1.7161305,1.9397178,,,,,,,,,,,,,, -96100,1.6188085,1.9688514,,,,,,,,,,,,,, -96200,1.6205813,1.9614508,,,,,,,,,,,,,, -96300,1.7975706,2.0521166,,,,,,,,,,,,,, -96400,1.6961458,2.0872693,,,,,,,,,,,,,, -96500,1.7421504,2.0010433,,,,,,,,,,,,,, -96600,1.6558479,2.0616803,,,,,,,,,,,,,, -96700,1.5692402,1.9110811,,,,,,,,,,,,,, -96800,1.5333555,2.0916786,,,,,,,,,,,,,, -96900,1.7998589,2.0709512,,,,,,,,,,,,,, -97000,1.7224863,2.0106268,,,,,,,,,,,,,, -97100,1.6633617,1.9684919,,,,,,,,,,,,,, -97198,,,0.4383370578289032,2.5672011375427246,0.4063999950885772,2.7606093883514404,50000.0,0.3154000043869018,3.509662628173828,10000.0,32677.89052796364,33825.715623378754,32677.89052796364,1142.114492893219,2.3221030235290527,0.0 -97200,1.7166495,1.991318,,,,,,,,,,,,,, -97300,1.7261388,2.0224652,,,,,,,,,,,,,, -97400,1.6443733,1.9453372,,,,,,,,,,,,,, -97500,1.6279823,1.910368,,,,,,,,,,,,,, -97600,1.6698015,1.9370985,,,,,,,,,,,,,, -97700,1.7705721,1.9927332,,,,,,,,,,,,,, -97800,1.588717,2.119587,,,,,,,,,,,,,, -97900,1.6151094,1.9953485,,,,,,,,,,,,,, -98000,1.6529058,1.9562113,,,,,,,,,,,,,, -98100,1.7539355,1.943641,,,,,,,,,,,,,, -98200,1.5754308,1.9325421,,,,,,,,,,,,,, -98300,1.5893854,1.9814525,,,,,,,,,,,,,, -98400,1.6921792,1.8883797,,,,,,,,,,,,,, -98500,1.6852468,2.0485191,,,,,,,,,,,,,, -98600,1.5767965,1.9553448,,,,,,,,,,,,,, -98700,1.6454351,1.91697,,,,,,,,,,,,,, -98720,,,0.3068598508834839,3.599402904510498,0.2856799960136413,3.769570827484131,50000.0,0.2054000049829483,4.586277484893799,10000.0,33188.105674266815,34353.2997841835,33188.105674266815,1159.3864908218384,2.3658483028411865,0.0 -98800,1.7020215,1.874248,,,,,,,,,,,,,, -98900,1.7432011,1.9959818,,,,,,,,,,,,,, -99000,1.6913882,2.0403073,,,,,,,,,,,,,, -99100,1.6040261,1.945475,,,,,,,,,,,,,, -99200,1.7466412,1.9667348,,,,,,,,,,,,,, -99300,1.7188017,1.9535792,,,,,,,,,,,,,, -99400,1.6998683,1.9626788,,,,,,,,,,,,,, -99500,1.7972838,1.9970926,,,,,,,,,,,,,, -99600,1.6529233,1.8864403,,,,,,,,,,,,,, -99700,1.6358887,1.8244147,,,,,,,,,,,,,, -99800,1.7435936,1.8897189,,,,,,,,,,,,,, -99900,1.6844654,2.0345895,,,,,,,,,,,,,, -100000,1.7982444,2.0334241,,,,,,,,,,,,,, -100100,1.707622,1.9501011,,,,,,,,,,,,,, -100200,1.9055456,1.8294202,,,,,,,,,,,,,, -100241,,,0.4196627736091614,2.642141103744507,0.4020799994468689,2.7794859409332275,50000.0,0.2949000000953674,3.5665860176086426,10000.0,33698.1875834465,34881.077904462814,33698.1875834465,1176.9801092147827,2.4154789447784424,0.0 -100300,1.8477135,2.1228955,,,,,,,,,,,,,, -100400,1.7922753,1.9401301,,,,,,,,,,,,,, -100500,1.7891974,2.0293643,,,,,,,,,,,,,, -100600,1.6930306,1.9701805,,,,,,,,,,,,,, -100700,1.7954493,2.0229087,,,,,,,,,,,,,, -100800,1.844121,1.9157304,,,,,,,,,,,,,, -100900,1.7094646,1.8599417,,,,,,,,,,,,,, -101000,1.7302362,1.9110608,,,,,,,,,,,,,, -101100,1.6761218,1.9529462,,,,,,,,,,,,,, -101200,1.7070528,2.0140347,,,,,,,,,,,,,, -101300,1.7638928,1.9779397,,,,,,,,,,,,,, -101400,1.7438211,1.99842,,,,,,,,,,,,,, -101500,1.7956275,1.9038997,,,,,,,,,,,,,, -101600,1.6388336,2.0145068,,,,,,,,,,,,,, -101700,1.8758146,1.9198998,,,,,,,,,,,,,, -101763,,,0.4217952787876129,2.6498255729675293,0.3804999887943268,2.93494200706482,50000.0,0.2842999994754791,3.691239595413208,10000.0,34208.35723352432,35408.832102775574,34208.35723352432,1194.4681041240692,2.459041118621826,0.0 -101800,1.7353773,1.905782,,,,,,,,,,,,,, -101900,1.6472013,1.8572814,,,,,,,,,,,,,, -102000,1.7091817,1.9339277,,,,,,,,,,,,,, -102100,1.6574546,1.981653,,,,,,,,,,,,,, -102200,1.7639279,1.9161925,,,,,,,,,,,,,, -102300,1.6884954,2.0370307,,,,,,,,,,,,,, -102400,1.779881,2.009662,,,,,,,,,,,,,, -102500,1.6642532,1.9100318,,,,,,,,,,,,,, -102600,1.7789899,1.8667803,,,,,,,,,,,,,, -102700,1.7502749,2.0126119,,,,,,,,,,,,,, -102800,1.794413,1.9151568,,,,,,,,,,,,,, -102900,1.8512806,1.8749181,,,,,,,,,,,,,, -103000,1.7046651,1.9555901,,,,,,,,,,,,,, -103100,1.8591213,1.9069287,,,,,,,,,,,,,, -103200,1.7983687,1.9890308,,,,,,,,,,,,,, -103284,,,0.4627311825752258,2.48760986328125,0.427979975938797,2.753029584884644,50000.0,0.3148000240325928,3.657977104187012,10000.0,34718.39355421066,35936.52565956116,34718.39355421066,1212.0220968723297,2.508146286010742,0.0 -103300,1.7643076,1.9316599,,,,,,,,,,,,,, -103400,1.7424506,1.867798,,,,,,,,,,,,,, -103500,1.5909127,1.8559817,,,,,,,,,,,,,, -103600,1.7903335,1.8987558,,,,,,,,,,,,,, -103700,1.784398,1.8621981,,,,,,,,,,,,,, -103800,1.804851,1.9666194,,,,,,,,,,,,,, -103900,1.6827269,1.8579185,,,,,,,,,,,,,, -104000,1.7986604,1.9471993,,,,,,,,,,,,,, -104100,1.8307425,1.9752383,,,,,,,,,,,,,, -104200,1.8315058,1.8719761,,,,,,,,,,,,,, -104300,1.8178637,1.8978174,,,,,,,,,,,,,, -104400,1.9361647,1.8301172,,,,,,,,,,,,,, -104500,1.8959619,1.9513706,,,,,,,,,,,,,, -104600,1.8939883,1.9647774,,,,,,,,,,,,,, -104700,1.838261,1.917651,,,,,,,,,,,,,, -104800,1.748299,1.8967606,,,,,,,,,,,,,, -104805,,,0.5387436151504517,1.9647153615951536,0.4964599907398224,2.206085205078125,50000.0,0.3752000033855438,3.0397000312805176,10000.0,35228.39146900177,36464.10525536537,35228.39146900177,1229.5005240440369,2.5585193634033203,0.0 -104900,1.7781202,1.9691445,,,,,,,,,,,,,, -105000,1.7674983,1.944951,,,,,,,,,,,,,, -105100,1.7483792,1.8869334,,,,,,,,,,,,,, -105200,1.8192626,1.9476378,,,,,,,,,,,,,, -105300,1.7277551,2.0305643,,,,,,,,,,,,,, -105400,1.8190236,1.9659249,,,,,,,,,,,,,, -105500,1.7325708,2.026538,,,,,,,,,,,,,, -105600,1.813418,1.9376858,,,,,,,,,,,,,, -105700,1.872682,1.9686221,,,,,,,,,,,,,, -105800,1.713525,1.8687402,,,,,,,,,,,,,, -105900,1.7720547,1.8335346,,,,,,,,,,,,,, -106000,1.8184398,2.0217252,,,,,,,,,,,,,, -106100,1.7696381,1.8952947,,,,,,,,,,,,,, -106200,1.8565496,1.8526368,,,,,,,,,,,,,, -106300,1.9097387,1.8430519,,,,,,,,,,,,,, -106326,,,0.5285993218421936,2.0710136890411377,0.4899199903011322,2.315893173217773,50000.0,0.3743000030517578,3.179978370666504,10000.0,35738.42715740204,36992.00506877899,35738.42715740204,1247.2531068325045,2.616370439529419,0.0 -106400,1.7965596,1.9149039,,,,,,,,,,,,,, -106500,1.8896745,1.9495412,,,,,,,,,,,,,, -106600,1.8196603,1.895767,,,,,,,,,,,,,, -106700,1.8946418,1.9027767,,,,,,,,,,,,,, -106800,2.0584407,1.9396653,,,,,,,,,,,,,, -106900,1.9661918,1.9060645,,,,,,,,,,,,,, -107000,1.8097441,1.7566186,,,,,,,,,,,,,, -107100,2.0151346,2.013771,,,,,,,,,,,,,, -107200,1.9169607,2.031451,,,,,,,,,,,,,, -107300,1.7923306,1.9344145,,,,,,,,,,,,,, -107400,1.8810788,1.921,,,,,,,,,,,,,, -107500,1.7399664,1.917275,,,,,,,,,,,,,, -107600,1.8416237,1.7958878,,,,,,,,,,,,,, -107700,1.7609613,1.9693131,,,,,,,,,,,,,, -107800,1.7444619,1.8927202,,,,,,,,,,,,,, -107847,,,0.4673548936843872,2.430663108825684,0.4399999976158142,2.63357925415039,50000.0,0.3248000144958496,3.541719675064087,10000.0,36248.5503616333,37519.58160948753,36248.5503616333,1264.6096456050873,2.6608407497406006,0.0 -107900,1.9224435,1.9828271,,,,,,,,,,,,,, -108000,1.7618438,1.9546752,,,,,,,,,,,,,, -108100,1.8624594,1.9759872,,,,,,,,,,,,,, -108200,1.8290452,1.9207423,,,,,,,,,,,,,, -108300,1.8132795,1.9408704,,,,,,,,,,,,,, -108400,1.7997085,1.9471143,,,,,,,,,,,,,, -108500,1.8450432,1.8347898,,,,,,,,,,,,,, -108600,2.167609,2.0234761,,,,,,,,,,,,,, -108700,2.0059588,1.8827502,,,,,,,,,,,,,, -108800,1.9849035,2.001867,,,,,,,,,,,,,, -108900,1.812252,1.8764213,,,,,,,,,,,,,, -109000,1.7561096,1.8712726,,,,,,,,,,,,,, -109100,2.0536153,1.9020095,,,,,,,,,,,,,, -109200,1.6894599,1.7849193,,,,,,,,,,,,,, -109300,1.7468686,1.8744104,,,,,,,,,,,,,, -109368,,,0.5494060516357422,1.9401334524154663,0.510699987411499,2.1658456325531006,50000.0,0.388700008392334,3.003093719482422,10000.0,36758.63085794449,38047.09022331238,36758.63085794449,1281.939534187317,2.70654034614563,0.0 -109400,1.9936603,1.9192317,,,,,,,,,,,,,, -109500,1.8907522,1.906499,,,,,,,,,,,,,, -109600,2.083277,1.9050684,,,,,,,,,,,,,, -109700,1.829795,1.828037,,,,,,,,,,,,,, -109800,1.8761501,1.856429,,,,,,,,,,,,,, -109900,2.1794548,1.9264563,,,,,,,,,,,,,, -110000,1.8093122,1.869411,,,,,,,,,,,,,, -110100,1.9435265,1.7995942,,,,,,,,,,,,,, -110200,1.8410151,1.9052458,,,,,,,,,,,,,, -110300,1.9282197,1.8683872,,,,,,,,,,,,,, -110400,1.8760669,2.016523,,,,,,,,,,,,,, -110500,1.9063617,1.9448068,,,,,,,,,,,,,, -110600,1.750198,1.8460035,,,,,,,,,,,,,, -110700,1.8796825,1.9428151,,,,,,,,,,,,,, -110800,1.9578675,1.9753639,,,,,,,,,,,,,, -110890,,,0.6084582209587097,1.6108510494232178,0.5411800146102905,1.968849301338196,50000.0,0.4267000257968902,2.716227054595948,10000.0,37268.7436478138,38574.78150773048,37268.7436478138,1299.4154126644137,2.756378173828125,0.0 -110900,2.168438,1.9064767,,,,,,,,,,,,,, -111000,1.9191563,2.0040393,,,,,,,,,,,,,, -111100,1.7804253,1.916817,,,,,,,,,,,,,, -111200,1.937193,1.7640107,,,,,,,,,,,,,, -111300,1.8734826,1.860448,,,,,,,,,,,,,, -111400,1.8593599,1.8253359,,,,,,,,,,,,,, -111500,1.9973763,1.942124,,,,,,,,,,,,,, -111600,1.8666642,1.8524021,,,,,,,,,,,,,, -111700,2.086469,1.8728139,,,,,,,,,,,,,, -111800,1.8901616,1.8885813,,,,,,,,,,,,,, -111900,1.917594,1.7981452,,,,,,,,,,,,,, -112000,2.034551,1.8268247,,,,,,,,,,,,,, -112100,1.931143,1.8567224,,,,,,,,,,,,,, -112200,1.8926642,1.8040669,,,,,,,,,,,,,, -112300,1.8866712,1.8719409,,,,,,,,,,,,,, -112400,2.0544567,1.9671197,,,,,,,,,,,,,, -112411,,,0.5785833597183228,1.7491344213485718,0.5329799652099609,2.002593278884888,50000.0,0.416700005531311,2.758341073989868,10000.0,37778.82757425308,39102.51637029648,37778.82757425308,1316.9662177562714,2.803908109664917,0.0 -112500,1.9803106,1.8336163,,,,,,,,,,,,,, -112600,1.8300811,1.9059724,,,,,,,,,,,,,, -112700,1.9568393,1.8005832,,,,,,,,,,,,,, -112800,1.9402095,1.7782421,,,,,,,,,,,,,, -112900,1.969545,1.8838013,,,,,,,,,,,,,, -113000,1.8742,1.881738,,,,,,,,,,,,,, -113100,2.0276175,1.8286183,,,,,,,,,,,,,, -113200,1.9565319,1.9000709,,,,,,,,,,,,,, -113300,2.0433853,1.8073808,,,,,,,,,,,,,, -113400,2.0664418,1.9028721,,,,,,,,,,,,,, -113500,1.8634746,1.8714395,,,,,,,,,,,,,, -113600,1.94538,1.8268648,,,,,,,,,,,,,, -113700,1.9932767,1.8622255,,,,,,,,,,,,,, -113800,1.8047707,1.8379139,,,,,,,,,,,,,, -113900,2.1812,1.9687767,,,,,,,,,,,,,, -113933,,,0.5904615521430969,1.7069220542907717,0.5450599789619446,1.986126065254212,50000.0,0.4113000333309173,2.844345569610596,10000.0,38288.882075071335,39630.119643211365,38288.882075071335,1334.4177539348602,2.848686695098877,0.0 -114000,1.9965577,1.856176,,,,,,,,,,,,,, -114100,1.9279327,1.8301651,,,,,,,,,,,,,, -114200,2.0212018,1.8305398,,,,,,,,,,,,,, -114300,2.1035023,1.7323081,,,,,,,,,,,,,, -114400,1.9441183,1.8359749,,,,,,,,,,,,,, -114500,1.9797752,1.8476291,,,,,,,,,,,,,, -114600,1.849433,1.7733041,,,,,,,,,,,,,, -114700,1.9783307,1.8023021,,,,,,,,,,,,,, -114800,1.7899351,1.7699125,,,,,,,,,,,,,, -114900,2.0259988,1.8629184,,,,,,,,,,,,,, -115000,2.000711,1.7630625,,,,,,,,,,,,,, -115100,2.018232,1.8268787,,,,,,,,,,,,,, -115200,2.1029046,1.9335626,,,,,,,,,,,,,, -115300,2.1628242,1.8668956,,,,,,,,,,,,,, -115400,2.0963998,1.8795308,,,,,,,,,,,,,, -115454,,,0.5925542116165161,1.6781911849975586,0.5508800148963928,1.918237566947937,50000.0,0.426000028848648,2.689170360565185,10000.0,38799.00455093384,40157.88823246956,38799.00455093384,1351.958430767059,2.899632215499878,0.0 -115500,2.0358808,1.9084294,,,,,,,,,,,,,, -115600,2.2821712,1.7499292,,,,,,,,,,,,,, -115700,1.898314,1.9280227,,,,,,,,,,,,,, -115800,2.1014066,1.8075942,,,,,,,,,,,,,, -115900,2.0994942,1.676471,,,,,,,,,,,,,, -116000,1.9494709,1.7491126,,,,,,,,,,,,,, -116100,2.1195,1.7892606,,,,,,,,,,,,,, -116200,1.9857216,1.9156897,,,,,,,,,,,,,, -116300,1.9031843,1.7933594,,,,,,,,,,,,,, -116400,2.1142182,1.8477252,,,,,,,,,,,,,, -116500,2.0070999,1.6909018,,,,,,,,,,,,,, -116600,2.1995106,1.8593621,,,,,,,,,,,,,, -116700,2.2343986,1.8700639,,,,,,,,,,,,,, -116800,1.972341,1.7073495,,,,,,,,,,,,,, -116900,1.9320018,1.7374326,,,,,,,,,,,,,, -116975,,,0.4353674948215484,2.671430349349976,0.4124199748039245,2.853731870651245,50000.0,0.3038000166416168,3.7805306911468506,10000.0,39309.09029126167,40685.7168943882,39309.09029126167,1369.5922305583954,2.9562196731567383,0.0 -117000,2.0424654,1.8365374,,,,,,,,,,,,,, -117100,2.0866847,1.756235,,,,,,,,,,,,,, -117200,2.140375,1.9329047,,,,,,,,,,,,,, -117300,2.0169694,1.796978,,,,,,,,,,,,,, -117400,2.1447566,1.7555218,,,,,,,,,,,,,, -117500,1.9945863,1.7918917,,,,,,,,,,,,,, -117600,2.2313085,1.8824253,,,,,,,,,,,,,, -117700,2.1588638,1.8112065,,,,,,,,,,,,,, -117800,2.163126,1.8961703,,,,,,,,,,,,,, -117900,2.1204221,1.8422922,,,,,,,,,,,,,, -118000,2.0190585,1.8170018,,,,,,,,,,,,,, -118100,2.0665395,1.8403511,,,,,,,,,,,,,, -118200,2.3230941,1.8303618,,,,,,,,,,,,,, -118300,2.1287143,1.8844292,,,,,,,,,,,,,, -118400,2.085555,1.7449316,,,,,,,,,,,,,, -118496,,,0.5356146097183228,2.0111894607543945,0.498659998178482,2.2324469089508057,50000.0,0.3787000179290771,3.0569660663604736,10000.0,39819.04309177399,41213.57382154465,39819.04309177399,1387.3949823379517,3.004948377609253,0.0 -118500,2.0426505,1.785384,,,,,,,,,,,,,, -118600,1.9554299,1.7418652,,,,,,,,,,,,,, -118700,2.081224,1.8309343,,,,,,,,,,,,,, -118800,2.16848,1.8154439,,,,,,,,,,,,,, -118900,1.9434115,1.7506166,,,,,,,,,,,,,, -119000,2.087489,1.7546008,,,,,,,,,,,,,, -119100,2.1382508,1.8217477,,,,,,,,,,,,,, -119200,2.0644069,1.8528463,,,,,,,,,,,,,, -119300,1.9957687,1.7415887,,,,,,,,,,,,,, -119400,2.0647182,1.7799875,,,,,,,,,,,,,, -119500,1.9725424,1.7904325,,,,,,,,,,,,,, -119600,2.1609707,1.9394372,,,,,,,,,,,,,, -119700,2.1041653,1.8088318,,,,,,,,,,,,,, -119800,2.1551712,1.8058628,,,,,,,,,,,,,, -119900,2.192793,1.7831875,,,,,,,,,,,,,, -120000,2.1082401,1.8221769,,,,,,,,,,,,,, -120017,,,0.5190728306770325,2.124185562133789,0.463919997215271,2.485764741897583,50000.0,0.3614000082015991,3.282891273498535,10000.0,40329.04471373558,41741.30652666092,40329.04471373558,1405.0201497077942,3.0573155879974365,0.0 -120100,2.094987,1.7447163,,,,,,,,,,,,,, -120200,1.8920689,1.7995988,,,,,,,,,,,,,, -120300,2.0665817,1.8782408,,,,,,,,,,,,,, -120400,2.0354903,1.7750447,,,,,,,,,,,,,, -120500,2.1717923,1.7184162,,,,,,,,,,,,,, -120600,2.1470268,1.957901,,,,,,,,,,,,,, -120700,2.1535578,1.9500179,,,,,,,,,,,,,, -120800,2.2109146,1.8085871,,,,,,,,,,,,,, -120900,2.3792343,1.8345045,,,,,,,,,,,,,, -121000,2.08983,1.8632665,,,,,,,,,,,,,, -121100,2.1222076,1.6312279,,,,,,,,,,,,,, -121200,2.3164508,1.8163393,,,,,,,,,,,,,, -121300,2.0297158,1.699098,,,,,,,,,,,,,, -121400,2.2207634,1.7794435,,,,,,,,,,,,,, -121500,2.056434,1.7797467,,,,,,,,,,,,,, -121539,,,0.5934311151504517,1.6687780618667605,0.545740008354187,1.940511703491211,50000.0,0.4230000078678131,2.734823703765869,10000.0,40839.1926074028,42268.991443395615,40839.1926074028,1422.4556045532229,3.1066973209381104,0.0 -121600,2.2517102,1.8509533,,,,,,,,,,,,,, -121700,2.3493817,1.8847864,,,,,,,,,,,,,, -121800,2.178819,1.7708255,,,,,,,,,,,,,, -121900,2.2840126,1.9242333,,,,,,,,,,,,,, -122000,2.2049072,1.6983203,,,,,,,,,,,,,, -122100,2.1484146,1.8357842,,,,,,,,,,,,,, -122200,2.2729082,1.7232484,,,,,,,,,,,,,, -122300,2.1523404,1.7978926,,,,,,,,,,,,,, -122400,2.2808886,1.7679229,,,,,,,,,,,,,, -122500,2.2438865,1.7758853,,,,,,,,,,,,,, -122600,2.387248,1.8536907,,,,,,,,,,,,,, -122700,2.2390785,1.7143861,,,,,,,,,,,,,, -122800,2.188453,1.8225355,,,,,,,,,,,,,, -122900,2.2860262,1.675684,,,,,,,,,,,,,, -123000,2.5553126,1.8179513,,,,,,,,,,,,,, -123061,,,0.5442841053009033,1.9463707208633425,0.5045599937438965,2.186814069747925,50000.0,0.3936000168323517,2.924185514450073,10000.0,41349.20190811157,42796.67501139641,41349.20190811157,1440.025390625,3.157965421676636,0.0 -123100,2.1723266,1.7720829,,,,,,,,,,,,,, -123200,2.2045658,1.7985044,,,,,,,,,,,,,, -123300,2.0616841,1.7042234,,,,,,,,,,,,,, -123400,2.3515499,1.7502178,,,,,,,,,,,,,, -123500,2.5675285,1.9072075,,,,,,,,,,,,,, -123600,2.2345526,1.7123339,,,,,,,,,,,,,, -123700,2.147772,1.7292246,,,,,,,,,,,,,, -123800,2.5923138,1.8319982,,,,,,,,,,,,,, -123900,2.3057656,1.7473537,,,,,,,,,,,,,, -124000,2.173288,1.6328237,,,,,,,,,,,,,, -124100,2.132289,1.6846931,,,,,,,,,,,,,, -124200,2.175291,1.7915623,,,,,,,,,,,,,, -124300,2.201304,1.8230817,,,,,,,,,,,,,, -124400,2.1814651,1.7565291,,,,,,,,,,,,,, -124500,2.246825,1.7552695,,,,,,,,,,,,,, -124582,,,0.6316764950752258,1.5131675004959106,0.5785199999809265,1.7937418222427368,50000.0,0.458400011062622,2.5650177001953125,10000.0,41859.287153720856,43324.37479400635,41859.287153720856,1457.5347940921783,3.209736585617065,0.0 -124600,2.206771,1.6858017,,,,,,,,,,,,,, -124700,2.3761091,1.7465959,,,,,,,,,,,,,, -124800,2.2992544,1.7712437,,,,,,,,,,,,,, -124900,2.5211372,1.8166772,,,,,,,,,,,,,, -125000,2.4445865,1.7387104,,,,,,,,,,,,,, -125100,2.361373,1.8236164,,,,,,,,,,,,,, -125200,2.2708316,1.7875569,,,,,,,,,,,,,, -125300,2.3518522,1.7586896,,,,,,,,,,,,,, -125400,2.395358,1.7915758,,,,,,,,,,,,,, -125500,2.3284326,1.8355196,,,,,,,,,,,,,, -125600,2.318721,1.7792777,,,,,,,,,,,,,, -125700,2.3240829,1.7551073,,,,,,,,,,,,,, -125800,2.1543972,1.5641922,,,,,,,,,,,,,, -125900,2.2795644,1.7717066,,,,,,,,,,,,,, -126000,2.4852836,1.9119748,,,,,,,,,,,,,, -126100,2.350322,1.6019167,,,,,,,,,,,,,, -126103,,,0.630301296710968,1.5103756189346311,0.5861600041389465,1.7711243629455566,50000.0,0.4564000070095062,2.564093589782715,10000.0,42369.22559094429,43852.23295521736,42369.22559094429,1475.3515548706057,3.2582929134368896,0.0 -126200,2.4270625,1.7519194,,,,,,,,,,,,,, -126300,2.223351,1.7123629,,,,,,,,,,,,,, -126400,2.2332215,1.6751947,,,,,,,,,,,,,, -126500,2.2591321,1.7018448,,,,,,,,,,,,,, -126600,2.1252644,1.5730995,,,,,,,,,,,,,, -126700,2.2668374,1.6269879,,,,,,,,,,,,,, -126800,2.4031153,1.6755202,,,,,,,,,,,,,, -126900,2.2977622,1.6438749,,,,,,,,,,,,,, -127000,2.3708787,1.7952571,,,,,,,,,,,,,, -127100,2.4725173,1.7098172,,,,,,,,,,,,,, -127200,2.2421134,1.755318,,,,,,,,,,,,,, -127300,2.4312978,1.9125621,,,,,,,,,,,,,, -127400,2.5842922,1.749677,,,,,,,,,,,,,, -127500,2.3985488,1.7521113,,,,,,,,,,,,,, -127600,2.4409728,1.7016643,,,,,,,,,,,,,, -127624,,,0.6471021771430969,1.4206527471542358,0.5690199732780457,1.836398720741272,50000.0,0.4485000073909759,2.574965000152588,10000.0,42879.3197491169,44379.98077702522,42879.3197491169,1492.9049038887024,3.306204319000244,0.0 -127700,2.4546614,1.7866998,,,,,,,,,,,,,, -127800,2.4574902,1.7828218,,,,,,,,,,,,,, -127900,2.591851,1.8504112,,,,,,,,,,,,,, -128000,2.3153508,1.789284,,,,,,,,,,,,,, -128100,2.4207587,1.777575,,,,,,,,,,,,,, -128200,2.4057178,1.8178886,,,,,,,,,,,,,, -128300,2.3626957,1.8161747,,,,,,,,,,,,,, -128400,2.346955,1.6793756,,,,,,,,,,,,,, -128500,2.2621677,1.7490733,,,,,,,,,,,,,, -128600,2.3024576,1.7056723,,,,,,,,,,,,,, -128700,2.3084836,1.5970408,,,,,,,,,,,,,, -128800,2.5394955,1.6750767,,,,,,,,,,,,,, -128900,2.4788442,1.6777653,,,,,,,,,,,,,, -129000,2.4406977,1.6844268,,,,,,,,,,,,,, -129100,2.4398222,1.7425625,,,,,,,,,,,,,, -129145,,,0.6431162357330322,1.4401280879974363,0.5803799629211426,1.769958734512329,50000.0,0.453900009393692,2.5460996627807617,10000.0,43389.23870754242,44907.46559214592,43389.23870754242,1510.3695363998413,3.35434365272522,0.0 -129200,2.4728854,1.76705,,,,,,,,,,,,,, -129300,2.570384,1.7768123,,,,,,,,,,,,,, -129400,2.3493814,1.5648229,,,,,,,,,,,,,, -129500,2.4026241,1.6349107,,,,,,,,,,,,,, -129600,2.560721,1.5775392,,,,,,,,,,,,,, -129700,2.4924629,1.6472442,,,,,,,,,,,,,, -129800,2.3505728,1.5764174,,,,,,,,,,,,,, -129900,2.447745,1.7418762,,,,,,,,,,,,,, -130000,2.4474561,1.815275,,,,,,,,,,,,,, -130100,2.5956883,1.6306899,,,,,,,,,,,,,, -130200,2.402107,1.6688223,,,,,,,,,,,,,, -130300,2.4279404,1.6070219,,,,,,,,,,,,,, -130400,2.4324975,1.580741,,,,,,,,,,,,,, -130500,2.4386241,1.5810137,,,,,,,,,,,,,, -130600,2.4284773,1.6003739,,,,,,,,,,,,,, -130666,,,0.6172273755073547,1.5718461275100708,0.5622000098228455,1.8866841793060305,50000.0,0.4480000138282776,2.622263193130493,10000.0,43899.253200531006,45434.87541007996,43899.253200531006,1527.6565673351288,3.409334897994995,0.0 -130700,2.421446,1.6982303,,,,,,,,,,,,,, -130800,2.6333761,1.639053,,,,,,,,,,,,,, -130900,2.4872546,1.6529195,,,,,,,,,,,,,, -131000,2.512254,1.6276706,,,,,,,,,,,,,, -131100,2.433469,1.6160551,,,,,,,,,,,,,, -131200,2.3423178,1.6476561,,,,,,,,,,,,,, -131300,2.5181172,1.6182346,,,,,,,,,,,,,, -131400,2.5033169,1.6677456,,,,,,,,,,,,,, -131500,2.515399,1.7627267,,,,,,,,,,,,,, -131600,2.5082324,1.6705432,,,,,,,,,,,,,, -131700,2.4074857,1.64852,,,,,,,,,,,,,, -131800,2.4966052,1.6937358,,,,,,,,,,,,,, -131900,2.4444394,1.6242626,,,,,,,,,,,,,, -132000,2.4057615,1.5990663,,,,,,,,,,,,,, -132100,2.5080137,1.6609917,,,,,,,,,,,,,, -132187,,,0.5950454473495483,1.6977694034576416,0.5471599698066711,1.9765727519989007,50000.0,0.4357000291347503,2.7501347064971924,10000.0,44409.17866325378,45962.245206832886,44409.17866325378,1544.999391555786,3.458484411239624,0.0 -132200,2.6047435,1.6540382,,,,,,,,,,,,,, -132300,2.6247303,1.6256535,,,,,,,,,,,,,, -132400,2.6127756,1.581582,,,,,,,,,,,,,, -132500,2.6164718,1.5694638,,,,,,,,,,,,,, -132600,2.4754574,1.7038991,,,,,,,,,,,,,, -132700,2.7588656,1.658355,,,,,,,,,,,,,, -132800,2.6081786,1.5826019,,,,,,,,,,,,,, -132900,2.68525,1.7524147,,,,,,,,,,,,,, -133000,2.4539406,1.7257851,,,,,,,,,,,,,, -133100,2.5045352,1.5490987,,,,,,,,,,,,,, -133200,2.5674071,1.7593467,,,,,,,,,,,,,, -133300,2.658739,1.6782527,,,,,,,,,,,,,, -133400,2.6222372,1.7243874,,,,,,,,,,,,,, -133500,2.4742622,1.604448,,,,,,,,,,,,,, -133600,2.506141,1.6343641,,,,,,,,,,,,,, -133700,2.7036445,1.6829903,,,,,,,,,,,,,, -133708,,,0.6533800959587097,1.3872883319854736,0.6028199791908264,1.666730284690857,50000.0,0.4850000143051147,2.405172348022461,10000.0,44919.16495108605,46490.850856781006,44919.16495108605,1563.513568162918,3.509937763214112,0.0 -133800,2.762876,1.7511916,,,,,,,,,,,,,, -133900,2.6983154,1.625614,,,,,,,,,,,,,, -134000,2.829636,1.6181298,,,,,,,,,,,,,, -134100,2.6704514,1.6361063,,,,,,,,,,,,,, -134200,2.5242298,1.6097941,,,,,,,,,,,,,, -134300,2.5642593,1.6712477,,,,,,,,,,,,,, -134400,2.4137619,1.5019687,,,,,,,,,,,,,, -134500,2.8494291,1.7487447,,,,,,,,,,,,,, -134600,2.7050147,1.5974019,,,,,,,,,,,,,, -134700,2.64984,1.65593,,,,,,,,,,,,,, -134800,2.639199,1.596308,,,,,,,,,,,,,, -134900,2.5712962,1.5530776,,,,,,,,,,,,,, -135000,2.5982049,1.5956104,,,,,,,,,,,,,, -135100,2.8154008,1.5395045,,,,,,,,,,,,,, -135200,2.5824797,1.70408,,,,,,,,,,,,,, -135229,,,0.6850087642669678,1.247851014137268,0.632099986076355,1.5180271863937378,50000.0,0.5041000247001648,2.260892868041992,10000.0,45429.14738154411,47018.49057650566,45429.14738154411,1581.0774466991425,3.550867080688477,0.0 -135300,2.6716228,1.6922129,,,,,,,,,,,,,, -135400,2.8748085,1.7631769,,,,,,,,,,,,,, -135500,2.56526,1.6850075,,,,,,,,,,,,,, -135600,2.8393025,1.611475,,,,,,,,,,,,,, -135700,2.5019884,1.5837408,,,,,,,,,,,,,, -135800,2.5358799,1.5990511,,,,,,,,,,,,,, -135900,2.6947882,1.5625218,,,,,,,,,,,,,, -136000,2.5356147,1.6749816,,,,,,,,,,,,,, -136100,3.0696428,1.6983672,,,,,,,,,,,,,, -136200,2.6813843,1.6345212,,,,,,,,,,,,,, -136300,2.8344479,1.5967412,,,,,,,,,,,,,, -136400,2.959944,1.5892508,,,,,,,,,,,,,, -136500,2.9314659,1.6422298,,,,,,,,,,,,,, -136600,2.762838,1.5693609,,,,,,,,,,,,,, -136700,2.8098319,1.6017965,,,,,,,,,,,,,, -136750,,,0.6704998016357422,1.3005253076553345,0.5945799946784973,1.7191460132598877,50000.0,0.4648000299930572,2.5056369304656982,10000.0,45939.10034203529,47546.12515926361,45939.10034203529,1598.6504790782928,3.605985164642334,0.0 -136800,2.8259263,1.6085546,,,,,,,,,,,,,, -136900,2.8016787,1.5609477,,,,,,,,,,,,,, -137000,2.8374252,1.5640503,,,,,,,,,,,,,, -137100,3.0718887,1.6668591,,,,,,,,,,,,,, -137200,2.727299,1.5496012,,,,,,,,,,,,,, -137300,2.757359,1.6053884,,,,,,,,,,,,,, -137400,3.1806853,1.63001,,,,,,,,,,,,,, -137500,2.813826,1.7100276,,,,,,,,,,,,,, -137600,2.8971245,1.5823082,,,,,,,,,,,,,, -137700,2.8788755,1.6919024,,,,,,,,,,,,,, -137800,2.843023,1.6335689,,,,,,,,,,,,,, -137900,2.6964574,1.5307858,,,,,,,,,,,,,, -138000,2.5799463,1.4985323,,,,,,,,,,,,,, -138100,2.7916083,1.5570673,,,,,,,,,,,,,, -138200,2.9526112,1.6396779,,,,,,,,,,,,,, -138271,,,0.6985809803009033,1.16620934009552,0.6342599987983704,1.5143696069717407,50000.0,0.5098000168800354,2.2224693298339844,10000.0,46449.014721632,48073.62446284294,46449.014721632,1616.1354587078094,3.6534364223480233,0.0 -138300,2.7299426,1.559114,,,,,,,,,,,,,, -138400,2.9284146,1.681845,,,,,,,,,,,,,, -138500,2.843637,1.5519214,,,,,,,,,,,,,, -138600,2.8917341,1.6563036,,,,,,,,,,,,,, -138700,2.9122076,1.6248544,,,,,,,,,,,,,, -138800,2.8217509,1.5347359,,,,,,,,,,,,,, -138900,2.8149972,1.5499536,,,,,,,,,,,,,, -139000,2.770687,1.6406522,,,,,,,,,,,,,, -139100,3.063969,1.6004641,,,,,,,,,,,,,, -139200,2.8349998,1.5794864,,,,,,,,,,,,,, -139300,3.020634,1.7103541,,,,,,,,,,,,,, -139400,3.0118892,1.5444735,,,,,,,,,,,,,, -139500,2.9192722,1.631035,,,,,,,,,,,,,, -139600,2.7181797,1.5582237,,,,,,,,,,,,,, -139700,2.6778603,1.4239213,,,,,,,,,,,,,, -139791,,,0.6724529266357422,1.3127702474594116,0.6125199794769287,1.6271332502365112,50000.0,0.4791000187397003,2.424319267272949,10000.0,46959.01626968384,48601.15681409836,46959.01626968384,1633.5621328353882,3.705156087875366,0.0 -139800,2.9712956,1.5754092,,,,,,,,,,,,,, -139900,2.9839957,1.56112,,,,,,,,,,,,,, -140000,3.1559422,1.5908904,,,,,,,,,,,,,, -140100,2.9086664,1.696619,,,,,,,,,,,,,, -140200,2.77822,1.4365869,,,,,,,,,,,,,, -140300,3.0264983,1.5999844,,,,,,,,,,,,,, -140400,2.9312112,1.5342956,,,,,,,,,,,,,, -140500,2.8944955,1.5597594,,,,,,,,,,,,,, -140600,2.895657,1.5680249,,,,,,,,,,,,,, -140700,2.9319077,1.559945,,,,,,,,,,,,,, -140800,2.819784,1.4826441,,,,,,,,,,,,,, -140900,3.2445834,1.5531976,,,,,,,,,,,,,, -141000,2.9839864,1.5778859,,,,,,,,,,,,,, -141100,3.0406792,1.5228887,,,,,,,,,,,,,, -141200,3.1388195,1.5146966,,,,,,,,,,,,,, -141300,3.0621166,1.5927067,,,,,,,,,,,,,, -141312,,,0.6929009556770325,1.2138513326644895,0.6319199800491333,1.516939997673035,50000.0,0.5,2.2908971309661865,10000.0,47468.91835808754,49128.50067186356,47468.91835808754,1650.8986172676086,3.7565438747406006,0.0 -141400,2.8309758,1.453078,,,,,,,,,,,,,, -141500,2.9921684,1.5562627,,,,,,,,,,,,,, -141600,3.2020917,1.647548,,,,,,,,,,,,,, -141700,2.996804,1.5893865,,,,,,,,,,,,,, -141800,3.131544,1.6152927,,,,,,,,,,,,,, -141900,2.860632,1.4722762,,,,,,,,,,,,,, -142000,2.8341296,1.4875332,,,,,,,,,,,,,, -142100,3.1264207,1.6101373,,,,,,,,,,,,,, -142200,2.8611102,1.4624704,,,,,,,,,,,,,, -142300,3.0142264,1.5215218,,,,,,,,,,,,,, -142400,2.9737165,1.5706543,,,,,,,,,,,,,, -142500,3.044984,1.5669516,,,,,,,,,,,,,, -142600,3.1505015,1.5139561,,,,,,,,,,,,,, -142700,3.048422,1.4916384,,,,,,,,,,,,,, -142800,3.130653,1.5825131,,,,,,,,,,,,,, -142833,,,0.7071707248687744,1.1520593166351318,0.6445599794387817,1.4677226543426514,50000.0,0.5177000164985657,2.194716215133667,10000.0,47978.88781452179,49656.22150874138,47978.88781452179,1668.5178027153015,3.8354225158691406,0.0 -142900,3.1046214,1.5270628,,,,,,,,,,,,,, -143000,2.9547079,1.5461572,,,,,,,,,,,,,, -143100,3.0137532,1.5835633,,,,,,,,,,,,,, -143200,3.068452,1.4929887,,,,,,,,,,,,,, -143300,3.0614386,1.6408615,,,,,,,,,,,,,, -143400,3.1136212,1.6001191,,,,,,,,,,,,,, -143500,3.1374838,1.4528468,,,,,,,,,,,,,, -143600,3.0804698,1.5499504,,,,,,,,,,,,,, -143700,3.1130896,1.4490821,,,,,,,,,,,,,, -143800,3.3830514,1.6649252,,,,,,,,,,,,,, -143900,2.9031792,1.4817069,,,,,,,,,,,,,, -144000,3.059451,1.5045378,,,,,,,,,,,,,, -144100,3.3115346,1.5297462,,,,,,,,,,,,,, -144200,2.9342911,1.4811537,,,,,,,,,,,,,, -144300,3.0257046,1.4547739,,,,,,,,,,,,,, -144354,,,0.7137077450752258,1.1160260438919067,0.65065997838974,1.4377646446228027,50000.0,0.5253000259399414,2.1807503700256348,10000.0,48488.88462114334,50183.9176530838,48488.88462114334,1686.1094892024994,3.890009880065918,0.0 -144400,3.1645753,1.5172969,,,,,,,,,,,,,, -144500,3.2411695,1.4448302,,,,,,,,,,,,,, -144600,3.330502,1.5327283,,,,,,,,,,,,,, -144700,3.0176306,1.4934962,,,,,,,,,,,,,, -144800,3.1485455,1.43192,,,,,,,,,,,,,, -144900,2.9285216,1.3862666,,,,,,,,,,,,,, -145000,3.2188761,1.5544578,,,,,,,,,,,,,, -145100,3.1248407,1.4628131,,,,,,,,,,,,,, -145200,3.138484,1.5395328,,,,,,,,,,,,,, -145300,2.978479,1.4741263,,,,,,,,,,,,,, -145400,3.1519883,1.4284071,,,,,,,,,,,,,, -145500,3.0166147,1.5066986,,,,,,,,,,,,,, -145600,3.057454,1.487982,,,,,,,,,,,,,, -145700,3.573282,1.566113,,,,,,,,,,,,,, -145800,2.9805808,1.4504476,,,,,,,,,,,,,, -145875,,,0.7431241869926453,0.9743123650550842,0.6586199998855591,1.4034175872802734,50000.0,0.534000039100647,2.145510673522949,10000.0,48998.85201382637,50711.597870111465,48998.85201382637,1703.7169313430786,3.942670345306397,0.0 -145900,2.9186344,1.4071237,,,,,,,,,,,,,, -146000,3.0829957,1.5885088,,,,,,,,,,,,,, -146100,3.2116349,1.5374641,,,,,,,,,,,,,, -146200,3.233462,1.5443363,,,,,,,,,,,,,, -146300,3.1236525,1.4885772,,,,,,,,,,,,,, -146400,3.0253887,1.4253265,,,,,,,,,,,,,, -146500,3.2878828,1.534903,,,,,,,,,,,,,, -146600,3.2684557,1.4044597,,,,,,,,,,,,,, -146700,3.2739613,1.4492334,,,,,,,,,,,,,, -146800,3.3128333,1.6209095,,,,,,,,,,,,,, -146900,3.1935546,1.5125467,,,,,,,,,,,,,, -147000,3.114302,1.319708,,,,,,,,,,,,,, -147100,3.0687194,1.4258243,,,,,,,,,,,,,, -147200,3.5396516,1.5072651,,,,,,,,,,,,,, -147300,3.3546314,1.4404454,,,,,,,,,,,,,, -147397,,,0.7065529227256775,1.1428754329681396,0.6351799964904785,1.5308077335357666,50000.0,0.5109000205993652,2.299001932144165,10000.0,49508.99960565567,51239.65789651871,49508.99960565567,1721.5213029384613,3.9980673789978014,0.0 -147400,3.422053,1.467501,,,,,,,,,,,,,, -147500,3.3955724,1.5407321,,,,,,,,,,,,,, -147600,3.3230267,1.4527707,,,,,,,,,,,,,, -147700,3.6511877,1.4273013,,,,,,,,,,,,,, -147800,3.2996826,1.37955,,,,,,,,,,,,,, -147900,3.1784098,1.4403049,,,,,,,,,,,,,, -148000,3.5115302,1.5643215,,,,,,,,,,,,,, -148100,3.6122127,1.5759377,,,,,,,,,,,,,, -148200,3.4128933,1.4046351,,,,,,,,,,,,,, -148300,3.5707126,1.4507167,,,,,,,,,,,,,, -148400,3.3808177,1.4628843,,,,,,,,,,,,,, -148500,3.149585,1.3570809,,,,,,,,,,,,,, -148600,3.6459877,1.4801657,,,,,,,,,,,,,, -148700,3.2519512,1.3553711,,,,,,,,,,,,,, -148800,3.3794289,1.4940858,,,,,,,,,,,,,, -148900,3.4033248,1.4809175,,,,,,,,,,,,,, -148918,,,0.7453164458274841,0.9728658199310304,0.6709199547767639,1.3506791591644287,50000.0,0.5452000498771667,2.0889601707458496,10000.0,50019.09476184845,51767.18022537232,50019.09476184845,1738.840086698532,4.0529398918151855,0.0 -149000,3.4459884,1.4939148,,,,,,,,,,,,,, -149100,3.4377825,1.4848464,,,,,,,,,,,,,, -149200,3.5178077,1.4728061,,,,,,,,,,,,,, -149300,3.461215,1.4519776,,,,,,,,,,,,,, -149400,3.2979507,1.3398771,,,,,,,,,,,,,, -149500,4.09534,1.4574687,,,,,,,,,,,,,, -149600,3.5433285,1.381262,,,,,,,,,,,,,, -149700,3.4852662,1.4267964,,,,,,,,,,,,,, -149800,3.4354553,1.4923164,,,,,,,,,,,,,, -149900,3.303148,1.4230856,,,,,,,,,,,,,, -150000,3.5448592,1.3828883,,,,,,,,,,,,,, -150100,3.4172037,1.3052934,,,,,,,,,,,,,, -150200,3.4512439,1.4589739,,,,,,,,,,,,,, -150300,3.5705369,1.4706213,,,,,,,,,,,,,, -150400,3.3572462,1.3758683,,,,,,,,,,,,,, -150440,,,0.748465359210968,0.9664836525917052,0.6740999817848206,1.3404970169067385,50000.0,0.5436000227928162,2.0569565296173096,10000.0,50529.31740260124,52295.1010248661,50529.31740260124,1756.4312443733215,4.107654333114624,0.0 -150500,3.5438235,1.4774326,,,,,,,,,,,,,, -150600,3.629435,1.4303352,,,,,,,,,,,,,, -150700,3.6176229,1.5518389,,,,,,,,,,,,,, -150800,3.3516173,1.3527482,,,,,,,,,,,,,, -150900,3.686572,1.4504043,,,,,,,,,,,,,, -151000,3.534153,1.3571599,,,,,,,,,,,,,, -151100,3.5427287,1.3553574,,,,,,,,,,,,,, -151200,3.627933,1.456691,,,,,,,,,,,,,, -151300,3.5085423,1.4290744,,,,,,,,,,,,,, -151400,3.4036007,1.3425344,,,,,,,,,,,,,, -151500,3.6283643,1.412674,,,,,,,,,,,,,, -151600,3.6647878,1.4412686,,,,,,,,,,,,,, -151700,3.4453647,1.4131483,,,,,,,,,,,,,, -151800,3.6594532,1.442158,,,,,,,,,,,,,, -151900,3.461942,1.3107542,,,,,,,,,,,,,, -151961,,,0.7463727593421936,0.9735182523727416,0.6724399924278259,1.3431792259216309,50000.0,0.5437000393867493,2.081461906433105,10000.0,51039.45336413384,52822.879269361496,51039.45336413384,1773.9691338539124,4.158005237579346,0.0 -152000,3.615896,1.4142351,,,,,,,,,,,,,, -152100,3.5930443,1.3729104,,,,,,,,,,,,,, -152200,3.8449905,1.3775904,,,,,,,,,,,,,, -152300,3.4400697,1.3351953,,,,,,,,,,,,,, -152400,3.7933276,1.5296338,,,,,,,,,,,,,, -152500,4.0664682,1.3452897,,,,,,,,,,,,,, -152600,3.7953973,1.4229285,,,,,,,,,,,,,, -152700,3.812328,1.3864577,,,,,,,,,,,,,, -152800,3.5436075,1.4289774,,,,,,,,,,,,,, -152900,3.868289,1.4687554,,,,,,,,,,,,,, -153000,3.7115705,1.340339,,,,,,,,,,,,,, -153100,3.8107266,1.3481759,,,,,,,,,,,,,, -153200,3.6473117,1.4449404,,,,,,,,,,,,,, -153300,3.7345672,1.3716347,,,,,,,,,,,,,, -153400,3.9602907,1.4037071,,,,,,,,,,,,,, -153482,,,0.7898397445678711,0.7781878113746643,0.6873999834060669,1.2663946151733398,50000.0,0.5611000061035156,1.965351700782776,10000.0,51549.48877668381,53350.68533778191,51549.48877668381,1791.6252291202543,4.218658447265625,0.0 -153500,3.633235,1.36227,,,,,,,,,,,,,, -153600,3.8004131,1.3958564,,,,,,,,,,,,,, -153700,3.7569456,1.3766395,,,,,,,,,,,,,, -153800,4.0244236,1.3423598,,,,,,,,,,,,,, -153900,3.817052,1.3499463,,,,,,,,,,,,,, -154000,3.84082,1.4507053,,,,,,,,,,,,,, -154100,3.772309,1.4105375,,,,,,,,,,,,,, -154200,3.7460308,1.428613,,,,,,,,,,,,,, -154300,3.6821983,1.3022327,,,,,,,,,,,,,, -154400,3.867001,1.4377657,,,,,,,,,,,,,, -154500,3.7302728,1.2404557,,,,,,,,,,,,,, -154600,3.5550792,1.2667913,,,,,,,,,,,,,, -154700,3.745861,1.3386896,,,,,,,,,,,,,, -154800,3.9147816,1.3028016,,,,,,,,,,,,,, -154900,3.709675,1.1724058,,,,,,,,,,,,,, -155000,3.7860353,1.2938712,,,,,,,,,,,,,, -155003,,,0.7786391973495483,0.8249820470809937,0.6894599795341492,1.2767852544784546,50000.0,0.5552999973297119,2.019956350326538,10000.0,52059.40568423271,53877.97802639008,52059.40568423271,1808.892503023148,4.2747087478637695,0.0 -155100,3.9218364,1.3739696,,,,,,,,,,,,,, -155200,3.8114924,1.1908361,,,,,,,,,,,,,, -155300,3.651422,1.3242759,,,,,,,,,,,,,, -155400,4.1455874,1.3160429,,,,,,,,,,,,,, -155500,3.9259012,1.3408437,,,,,,,,,,,,,, -155600,3.787416,1.3803978,,,,,,,,,,,,,, -155700,4.048334,1.3281543,,,,,,,,,,,,,, -155800,3.5562274,1.2939243,,,,,,,,,,,,,, -155900,3.87619,1.2541305,,,,,,,,,,,,,, -156000,4.1074767,1.4051393,,,,,,,,,,,,,, -156100,3.8346763,1.2575154,,,,,,,,,,,,,, -156200,3.9979281,1.2728301,,,,,,,,,,,,,, -156300,4.0799327,1.3127477,,,,,,,,,,,,,, -156400,4.292607,1.3860202,,,,,,,,,,,,,, -156500,3.7433312,1.1970239,,,,,,,,,,,,,, -156525,,,0.7767059803009033,0.8280686140060425,0.6906200051307678,1.253747582435608,50000.0,0.5636000037193298,1.9697444438934328,10000.0,52569.56496477127,54405.94682025909,52569.56496477127,1826.589562416077,4.334457635879517,0.0 -156600,4.103594,1.4581861,,,,,,,,,,,,,, -156700,3.7220654,1.3046277,,,,,,,,,,,,,, -156800,3.8714762,1.2761164,,,,,,,,,,,,,, -156900,4.0617795,1.369448,,,,,,,,,,,,,, -157000,4.2026467,1.410418,,,,,,,,,,,,,, -157100,4.047621,1.3411624,,,,,,,,,,,,,, -157200,3.8014586,1.3146566,,,,,,,,,,,,,, -157300,3.957194,1.2835367,,,,,,,,,,,,,, -157400,3.7851589,1.3084736,,,,,,,,,,,,,, -157500,3.9422572,1.3245385,,,,,,,,,,,,,, -157600,3.902642,1.2568121,,,,,,,,,,,,,, -157700,3.930477,1.201621,,,,,,,,,,,,,, -157800,4.3181725,1.3902056,,,,,,,,,,,,,, -157900,4.3799915,1.3369758,,,,,,,,,,,,,, -158000,4.2386804,1.3383807,,,,,,,,,,,,,, -158046,,,0.7870694994926453,0.7952739596366882,0.6987599730491638,1.248716950416565,50000.0,0.5710000395774841,1.9607090950012207,10000.0,53079.482422828674,54933.41873812676,53079.482422828674,1844.0382986068728,4.386325359344482,0.0 -158100,4.0586286,1.3910251,,,,,,,,,,,,,, -158200,3.933639,1.3549006,,,,,,,,,,,,,, -158300,3.6288548,1.1864047,,,,,,,,,,,,,, -158400,4.130891,1.2869307,,,,,,,,,,,,,, -158500,4.2498813,1.4462527,,,,,,,,,,,,,, -158600,4.0694323,1.259901,,,,,,,,,,,,,, -158700,4.0584083,1.2504323,,,,,,,,,,,,,, -158800,4.194538,1.2674803,,,,,,,,,,,,,, -158900,4.2066183,1.3761804,,,,,,,,,,,,,, -159000,4.4071536,1.3205926,,,,,,,,,,,,,, -159100,3.9412892,1.2615514,,,,,,,,,,,,,, -159200,4.318915,1.3554655,,,,,,,,,,,,,, -159300,3.9702518,1.2119205,,,,,,,,,,,,,, -159400,4.060966,1.2515323,,,,,,,,,,,,,, -159500,4.4463763,1.390747,,,,,,,,,,,,,, -159567,,,0.7997050285339355,0.7518975138664246,0.7044199705123901,1.2122431993484497,50000.0,0.5830000042915344,1.915332674980164,10000.0,53589.56930446625,55461.414071798325,53589.56930446625,1861.841367483139,4.439189434051514,0.0 -159600,4.1290965,1.2601025,,,,,,,,,,,,,, -159700,4.314009,1.2304981,,,,,,,,,,,,,, -159800,4.1221404,1.2174511,,,,,,,,,,,,,, -159900,3.9975967,1.2584997,,,,,,,,,,,,,, -160000,4.24553,1.2232914,,,,,,,,,,,,,, -160100,4.451996,1.2663838,,,,,,,,,,,,,, -160200,4.219806,1.2521715,,,,,,,,,,,,,, -160300,4.0718126,1.2282633,,,,,,,,,,,,,, -160400,4.3444705,1.2478979,,,,,,,,,,,,,, -160500,4.1791115,1.2653332,,,,,,,,,,,,,, -160600,4.445244,1.2518415,,,,,,,,,,,,,, -160700,4.226293,1.3044616,,,,,,,,,,,,,, -160800,4.2046175,1.2856407,,,,,,,,,,,,,, -160900,4.037133,1.2563912,,,,,,,,,,,,,, -161000,4.1835346,1.2600693,,,,,,,,,,,,,, -161088,,,0.7997449040412903,0.7390770316123962,0.7134400010108948,1.1649560928344729,50000.0,0.5833000540733337,1.886929988861084,10000.0,54099.57530117035,55989.10298204422,54099.57530117035,1879.416244983673,4.49467396736145,0.0 -161100,4.0301685,1.291546,,,,,,,,,,,,,, -161200,4.1925373,1.12256,,,,,,,,,,,,,, -161300,4.0557194,1.1541461,,,,,,,,,,,,,, -161400,4.1408534,1.1763867,,,,,,,,,,,,,, -161500,4.420089,1.2452961,,,,,,,,,,,,,, -161600,4.2458057,1.2151572,,,,,,,,,,,,,, -161700,4.169203,1.1575887,,,,,,,,,,,,,, -161800,4.572299,1.1328816,,,,,,,,,,,,,, -161900,4.1327662,1.1685853,,,,,,,,,,,,,, -162000,3.9444995,1.1650486,,,,,,,,,,,,,, -162100,4.3367057,1.1655996,,,,,,,,,,,,,, -162200,4.4412265,1.2524285,,,,,,,,,,,,,, -162300,4.3494883,1.2376546,,,,,,,,,,,,,, -162400,4.316312,1.2279233,,,,,,,,,,,,,, -162500,4.3279843,1.2020777,,,,,,,,,,,,,, -162600,4.055311,1.1158749,,,,,,,,,,,,,, -162609,,,0.8374122977256775,0.5941317677497864,0.7181999683380127,1.1413167715072632,50000.0,0.5940999984741211,1.8632287979125977,10000.0,54609.529450416565,56516.9751701355,54609.529450416565,1897.226199388504,4.549734115600586,0.0 -162700,4.6256127,1.2547303,,,,,,,,,,,,,, -162800,4.594665,1.293786,,,,,,,,,,,,,, -162900,4.5233693,1.3781312,,,,,,,,,,,,,, -163000,4.524726,1.2438514,,,,,,,,,,,,,, -163100,4.4043813,1.2728755,,,,,,,,,,,,,, -163200,4.2228074,1.2273294,,,,,,,,,,,,,, -163300,4.3176928,1.1441491,,,,,,,,,,,,,, -163400,4.5335217,1.2090709,,,,,,,,,,,,,, -163500,4.099526,1.1076992,,,,,,,,,,,,,, -163600,4.2552385,1.1933296,,,,,,,,,,,,,, -163700,4.2378774,1.2100618,,,,,,,,,,,,,, -163800,4.485365,1.1833072,,,,,,,,,,,,,, -163900,4.427382,1.2201601,,,,,,,,,,,,,, -164000,4.5661507,1.1173401,,,,,,,,,,,,,, -164100,4.7677712,1.2811998,,,,,,,,,,,,,, -164130,,,0.8246572017669678,0.6434984803199768,0.7135599851608276,1.1545459032058716,50000.0,0.5824000239372253,1.869858741760254,10000.0,55119.610013246536,57044.69862341881,55119.610013246536,1914.762225151062,4.603350400924683,0.0 -164200,4.6764774,1.1973354,,,,,,,,,,,,,, -164300,4.3299255,1.1646833,,,,,,,,,,,,,, -164400,4.7508135,1.2961133,,,,,,,,,,,,,, -164500,4.543045,1.2288747,,,,,,,,,,,,,, -164600,4.4673905,1.1540325,,,,,,,,,,,,,, -164700,4.601302,1.2180791,,,,,,,,,,,,,, -164800,4.196551,0.987375,,,,,,,,,,,,,, -164900,4.8029137,1.1912129,,,,,,,,,,,,,, -165000,4.5748696,1.1856048,,,,,,,,,,,,,, -165100,4.263168,1.1176871,,,,,,,,,,,,,, -165200,4.78899,1.2698425,,,,,,,,,,,,,, -165300,4.571615,1.1431277,,,,,,,,,,,,,, -165400,4.399917,1.1190776,,,,,,,,,,,,,, -165500,4.5947385,1.1353929,,,,,,,,,,,,,, -165600,4.2740827,1.0566354,,,,,,,,,,,,,, -165651,,,0.8374919891357422,0.596867024898529,0.7252799868583679,1.1151868104934692,50000.0,0.5944000482559204,1.8382610082626345,10000.0,55629.50521326065,57573.0696310997,55629.50521326065,1933.12388586998,4.66450572013855,0.0 -165700,4.5604215,1.1932684,,,,,,,,,,,,,, -165800,4.4766164,1.1464356,,,,,,,,,,,,,, -165900,4.5897408,1.1639779,,,,,,,,,,,,,, -166000,4.333164,1.1420811,,,,,,,,,,,,,, -166100,4.605451,1.027483,,,,,,,,,,,,,, -166200,5.221563,1.2503529,,,,,,,,,,,,,, -166300,4.969811,1.1778206,,,,,,,,,,,,,, -166400,4.789336,1.1550452,,,,,,,,,,,,,, -166500,4.2816396,1.0708386,,,,,,,,,,,,,, -166600,4.734194,1.1323421,,,,,,,,,,,,,, -166700,4.9743376,1.1193953,,,,,,,,,,,,,, -166800,4.585442,1.1933076,,,,,,,,,,,,,, -166900,4.8220425,1.1353614,,,,,,,,,,,,,, -167000,4.275556,1.11692,,,,,,,,,,,,,, -167100,4.784027,1.1704834,,,,,,,,,,,,,, -167173,,,0.8379902839660645,0.5885335206985474,0.7274599671363831,1.0997930765151978,50000.0,0.6052000522613525,1.811101675033569,10000.0,56139.65770363808,58100.62710976601,56139.65770363808,1950.4011313915253,4.738975524902344,0.0 -167200,4.3810644,1.0780131,,,,,,,,,,,,,, -167300,4.860207,1.1489625,,,,,,,,,,,,,, -167400,4.3135033,1.0396805,,,,,,,,,,,,,, -167500,4.726345,1.101629,,,,,,,,,,,,,, -167600,4.556424,1.0512447,,,,,,,,,,,,,, -167700,4.7889915,1.1093355,,,,,,,,,,,,,, -167800,4.695213,1.1313183,,,,,,,,,,,,,, -167900,4.8703136,1.2490782,,,,,,,,,,,,,, -168000,4.901394,1.1858615,,,,,,,,,,,,,, -168100,4.4716935,1.1252601,,,,,,,,,,,,,, -168200,4.999421,1.1099231,,,,,,,,,,,,,, -168300,5.0894814,1.150178,,,,,,,,,,,,,, -168400,5.0582666,1.1314708,,,,,,,,,,,,,, -168500,4.9536524,1.1089652,,,,,,,,,,,,,, -168600,4.794086,1.0803317,,,,,,,,,,,,,, -168695,,,0.8406209945678711,0.5672109723091125,0.730459988117218,1.097076654434204,50000.0,0.6029000282287598,1.819189548492432,10000.0,56649.87207078934,58628.46256566048,56649.87207078934,1967.9093770980835,4.797713041305542,0.0 -168700,4.8756256,1.1865947,,,,,,,,,,,,,, -168800,4.9665394,1.1441294,,,,,,,,,,,,,, -168900,4.84734,1.0848522,,,,,,,,,,,,,, -169000,4.597927,1.0649768,,,,,,,,,,,,,, -169100,4.9592924,1.0968025,,,,,,,,,,,,,, -169200,4.816309,1.1590718,,,,,,,,,,,,,, -169300,4.6755047,1.0373629,,,,,,,,,,,,,, -169400,4.7724752,1.0541532,,,,,,,,,,,,,, -169500,4.922094,1.1867703,,,,,,,,,,,,,, -169600,4.9658365,1.0772574,,,,,,,,,,,,,, -169700,5.1318526,1.0990391,,,,,,,,,,,,,, -169800,4.646808,0.98556674,,,,,,,,,,,,,, -169900,5.174971,1.0237256,,,,,,,,,,,,,, -170000,4.988613,1.1148729,,,,,,,,,,,,,, -170100,4.758819,1.0751258,,,,,,,,,,,,,, -170200,5.3071856,1.1521109,,,,,,,,,,,,,, -170217,,,0.8475565910339355,0.5495447516441345,0.7321199774742126,1.083914279937744,50000.0,0.6044000387191772,1.8097580671310425,10000.0,57160.08248925209,59156.44516658783,57160.08248925209,1985.569852113724,4.8556458950042725,0.0 -170300,4.905506,1.065779,,,,,,,,,,,,,, -170400,4.705943,0.9685149,,,,,,,,,,,,,, -170500,4.742853,1.0227947,,,,,,,,,,,,,, -170600,4.8094463,1.1220217,,,,,,,,,,,,,, -170700,4.7276163,1.0589998,,,,,,,,,,,,,, -170800,5.085618,1.0820019,,,,,,,,,,,,,, -170900,4.8702083,1.0223097,,,,,,,,,,,,,, -171000,4.6962485,1.0023649,,,,,,,,,,,,,, -171100,5.035205,1.0644401,,,,,,,,,,,,,, -171200,4.9324884,1.0396187,,,,,,,,,,,,,, -171300,4.7059655,1.0627408,,,,,,,,,,,,,, -171400,4.690896,1.0407716,,,,,,,,,,,,,, -171500,4.8380966,1.1365316,,,,,,,,,,,,,, -171600,5.0288424,1.0388348,,,,,,,,,,,,,, -171700,5.3974037,1.1856304,,,,,,,,,,,,,, -171739,,,0.8664699792861938,0.4829581677913666,0.7368800044059753,1.0651110410690308,50000.0,0.6091000437736511,1.781596541404724,10000.0,57670.17210030556,59684.1325712204,57670.17210030556,2003.057029247284,4.913210391998291,0.0 -171800,4.884749,1.0545244,,,,,,,,,,,,,, -171900,5.6674657,1.0222589,,,,,,,,,,,,,, -172000,4.8731937,1.126548,,,,,,,,,,,,,, -172100,4.9459558,1.0242584,,,,,,,,,,,,,, -172200,5.073141,1.0600152,,,,,,,,,,,,,, -172300,4.8620443,1.0699862,,,,,,,,,,,,,, -172400,4.7856755,0.94466054,,,,,,,,,,,,,, -172500,5.1134477,1.0184574,,,,,,,,,,,,,, -172600,5.010029,1.0181454,,,,,,,,,,,,,, -172700,4.760458,0.99218196,,,,,,,,,,,,,, -172800,5.15323,0.9891846,,,,,,,,,,,,,, -172900,4.9043193,0.9987968,,,,,,,,,,,,,, -173000,5.3856807,1.030198,,,,,,,,,,,,,, -173100,4.668436,1.0211027,,,,,,,,,,,,,, -173200,5.3311257,1.0654793,,,,,,,,,,,,,, -173260,,,0.8635004758834839,0.4903871119022369,0.7384999990463257,1.065165638923645,50000.0,0.6144000291824341,1.779536247253418,10000.0,58180.46226763725,60212.06752181053,58180.46226763725,2020.586843013764,4.973502397537232,0.0 -173300,5.206032,0.99108285,,,,,,,,,,,,,, -173400,5.381643,0.98004174,,,,,,,,,,,,,, -173500,4.8565,0.9575061,,,,,,,,,,,,,, -173600,4.9201584,0.9633168,,,,,,,,,,,,,, -173700,5.093961,1.0289088,,,,,,,,,,,,,, -173800,5.0605326,1.0285931,,,,,,,,,,,,,, -173900,5.1612816,1.0027921,,,,,,,,,,,,,, -174000,4.923023,0.978785,,,,,,,,,,,,,, -174100,5.3408823,1.0274214,,,,,,,,,,,,,, -174200,5.108451,1.0084002,,,,,,,,,,,,,, -174300,5.103018,1.07534,,,,,,,,,,,,,, -174400,5.1565943,1.0082798,,,,,,,,,,,,,, -174500,5.0268874,0.9398066,,,,,,,,,,,,,, -174600,5.316613,0.99613667,,,,,,,,,,,,,, -174700,4.776165,0.91917455,,,,,,,,,,,,,, -174781,,,0.8696388602256775,0.4675772190093994,0.7422800064086914,1.044396162033081,50000.0,0.6187000274658203,1.7618048191070557,10000.0,58690.6109726429,60739.801298856735,58690.6109726429,2038.053610086441,5.037261247634888,0.0 -174800,5.131151,1.0424533,,,,,,,,,,,,,, -174900,4.961529,0.9386387,,,,,,,,,,,,,, -175000,5.1319757,0.98188317,,,,,,,,,,,,,, -175100,4.7410946,1.0011008,,,,,,,,,,,,,, -175200,5.126221,1.0003531,,,,,,,,,,,,,, -175300,4.8736553,1.0234563,,,,,,,,,,,,,, -175400,5.0683885,1.0299888,,,,,,,,,,,,,, -175500,5.010246,0.9840752,,,,,,,,,,,,,, -175600,4.859823,0.9272581,,,,,,,,,,,,,, -175700,4.9561796,0.94801956,,,,,,,,,,,,,, -175800,5.2272906,1.0514135,,,,,,,,,,,,,, -175900,5.2649927,0.9511851,,,,,,,,,,,,,, -176000,4.6293917,0.9415107,,,,,,,,,,,,,, -176100,4.929928,1.0371939,,,,,,,,,,,,,, -176200,5.107661,0.99635136,,,,,,,,,,,,,, -176300,4.921626,1.010559,,,,,,,,,,,,,, -176301,,,0.8717913031578064,0.4568901360034942,0.7438399791717529,1.0421042442321775,50000.0,0.6177000403404236,1.7554930448532104,10000.0,59200.50735998154,61267.31038618088,59200.50735998154,2055.5500314235687,5.10002589225769,0.0 -176400,4.9999714,0.95966053,,,,,,,,,,,,,, -176500,4.8021083,0.9771631,,,,,,,,,,,,,, -176600,4.969478,0.91254646,,,,,,,,,,,,,, -176700,5.253708,0.9804436,,,,,,,,,,,,,, -176800,5.0983543,0.87999487,,,,,,,,,,,,,, -176900,5.2454324,1.0408342,,,,,,,,,,,,,, -177000,5.261659,1.022862,,,,,,,,,,,,,, -177100,5.2084255,0.96065474,,,,,,,,,,,,,, -177200,5.4388027,1.0390645,,,,,,,,,,,,,, -177300,5.624382,0.9977849,,,,,,,,,,,,,, -177400,5.239889,1.0134338,,,,,,,,,,,,,, -177500,4.876822,0.8577122,,,,,,,,,,,,,, -177600,5.062587,1.0234976,,,,,,,,,,,,,, -177700,5.540904,1.094193,,,,,,,,,,,,,, -177800,5.157197,0.8987943,,,,,,,,,,,,,, -177822,,,0.8772520422935486,0.435047298669815,0.7451599836349487,1.0354770421981812,50000.0,0.6203000545501709,1.750431776046753,10000.0,59710.523655653,61794.90425157547,59710.523655653,2073.018792390824,5.156181573867798,0.0 -177900,5.1027813,0.9592284,,,,,,,,,,,,,, -178000,5.276114,0.90780425,,,,,,,,,,,,,, -178100,4.991815,0.8990126,,,,,,,,,,,,,, -178200,5.21427,0.91048104,,,,,,,,,,,,,, -178300,5.238967,0.96350074,,,,,,,,,,,,,, -178400,5.1196213,0.9776604,,,,,,,,,,,,,, -178500,5.3903055,0.99406433,,,,,,,,,,,,,, -178600,4.931907,0.976404,,,,,,,,,,,,,, -178700,5.5249023,0.93361473,,,,,,,,,,,,,, -178800,5.875884,1.0545704,,,,,,,,,,,,,, -178900,5.4869285,1.0530356,,,,,,,,,,,,,, -179000,5.5180473,0.97409266,,,,,,,,,,,,,, -179100,5.047878,0.8405904,,,,,,,,,,,,,, -179200,5.305094,0.9493791,,,,,,,,,,,,,, -179300,5.214817,0.93126214,,,,,,,,,,,,,, -179343,,,0.8833904266357422,0.4190746545791626,0.7470600008964539,1.032281517982483,50000.0,0.6243000030517578,1.7491586208343506,10000.0,60220.68492269516,62323.07969260216,60220.68492269516,2090.9237022399902,5.213481664657593,0.0 -179400,5.395344,0.9910859,,,,,,,,,,,,,, -179500,5.3869705,1.0218165,,,,,,,,,,,,,, -179600,5.2119956,1.0148314,,,,,,,,,,,,,, -179700,5.3229847,0.9316722,,,,,,,,,,,,,, -179800,6.02726,1.0007045,,,,,,,,,,,,,, -179900,5.4165354,0.9494337,,,,,,,,,,,,,, -180000,5.0830054,0.8368078,,,,,,,,,,,,,, -180100,5.1250544,0.9132382,,,,,,,,,,,,,, -180200,5.052483,0.94607556,,,,,,,,,,,,,, -180300,5.0963078,0.963951,,,,,,,,,,,,,, -180400,5.3465195,0.99825233,,,,,,,,,,,,,, -180500,5.571106,1.0144271,,,,,,,,,,,,,, -180600,5.1658554,0.8986087,,,,,,,,,,,,,, -180700,5.440499,0.9873177,,,,,,,,,,,,,, -180800,5.4966507,1.0181792,,,,,,,,,,,,,, -180863,,,0.8840082883834839,0.4111105501651764,0.7485599517822266,1.0272870063781738,50000.0,0.6240000128746033,1.746339201927185,10000.0,60730.57482671738,62850.83002829552,60730.57482671738,2108.6718752384186,5.271944284439087,0.0 -180900,5.3728075,1.0123749,,,,,,,,,,,,,, -181000,5.3783135,0.96980816,,,,,,,,,,,,,, -181100,5.694904,1.0228043,,,,,,,,,,,,,, -181200,5.511068,1.0753098,,,,,,,,,,,,,, -181300,5.102334,0.8929576,,,,,,,,,,,,,, -181400,5.5999618,0.9364625,,,,,,,,,,,,,, -181500,5.1857314,0.95542395,,,,,,,,,,,,,, -181600,5.165569,0.9617675,,,,,,,,,,,,,, -181700,5.0089717,0.9257985,,,,,,,,,,,,,, -181800,5.2614527,0.9404515,,,,,,,,,,,,,, -181900,5.031305,0.94264627,,,,,,,,,,,,,, -182000,5.1196203,0.97296983,,,,,,,,,,,,,, -182100,4.9297743,0.911828,,,,,,,,,,,,,, -182200,4.921997,0.9444384,,,,,,,,,,,,,, -182300,5.1909885,0.9923643,,,,,,,,,,,,,, -182383,,,0.8837690949440002,0.4141952693462372,0.7486000061035156,1.0234150886535645,50000.0,0.6265000104904175,1.7416620254516602,10000.0,61240.5575094223,63378.4419465065,61240.5575094223,2126.1875302791595,5.333211421966553,0.0 -182400,5.4358115,0.965978,,,,,,,,,,,,,, -182500,5.259094,0.9475452,,,,,,,,,,,,,, -182600,4.7191634,0.8752855,,,,,,,,,,,,,, -182700,5.678208,1.039314,,,,,,,,,,,,,, -182800,5.0967693,0.8375419,,,,,,,,,,,,,, -182900,5.378961,0.93767834,,,,,,,,,,,,,, -183000,5.611936,0.9736835,,,,,,,,,,,,,, -183100,5.615901,0.9998292,,,,,,,,,,,,,, -183200,5.4599347,0.9546598,,,,,,,,,,,,,, -183300,5.0283403,0.9560725,,,,,,,,,,,,,, -183400,5.385259,0.8762734,,,,,,,,,,,,,, -183500,5.07047,0.9141893,,,,,,,,,,,,,, -183600,5.1039615,0.9040534,,,,,,,,,,,,,, -183700,5.001322,0.87071025,,,,,,,,,,,,,, -183800,5.0847073,0.9902391,,,,,,,,,,,,,, -183900,5.124872,0.9027246,,,,,,,,,,,,,, -183904,,,0.8855827450752258,0.4065465927124023,0.7492199540138245,1.0245238542556765,50000.0,0.6260000467300415,1.7438762187957764,10000.0,61750.54298973084,63905.93356466293,61750.54298973084,2143.5833439826965,5.390666723251343,0.0 -184000,5.045528,0.9532101,,,,,,,,,,,,,, -184100,5.5241103,0.95917624,,,,,,,,,,,,,, -184200,5.5957866,0.9775341,,,,,,,,,,,,,, -184300,5.1790223,0.937968,,,,,,,,,,,,,, -184400,5.033718,0.84718543,,,,,,,,,,,,,, -184500,5.3298635,0.92482096,,,,,,,,,,,,,, -184600,5.471038,0.9553295,,,,,,,,,,,,,, -184700,5.0726557,0.8940711,,,,,,,,,,,,,, -184800,5.3707757,0.9126243,,,,,,,,,,,,,, -184900,5.1975384,0.9873948,,,,,,,,,,,,,, -185000,4.9392495,0.88581043,,,,,,,,,,,,,, -185100,5.3161893,0.9053847,,,,,,,,,,,,,, -185200,5.0999165,0.91382873,,,,,,,,,,,,,, -185300,4.9259095,0.85240763,,,,,,,,,,,,,, -185400,5.24523,0.92593074,,,,,,,,,,,,,, -185424,,,0.8825932741165161,0.4104780256748199,0.7490999698638916,1.0238432884216309,50000.0,0.6255000233650208,1.741317868232727,10000.0,62260.43835067749,64433.532873392105,62260.43835067749,2161.1756496429443,5.449017286300659,0.0 -185500,5.2802215,0.94168633,,,,,,,,,,,,,, -185600,5.690959,1.0071659,,,,,,,,,,,,,, -185700,5.1515403,0.9707155,,,,,,,,,,,,,, -185800,6.031607,0.94510406,,,,,,,,,,,,,, -185900,5.032657,0.95897436,,,,,,,,,,,,,, -186000,5.07511,0.9147426,,,,,,,,,,,,,, -186100,5.248062,0.9288209,,,,,,,,,,,,,, -186200,5.4673805,1.0000451,,,,,,,,,,,,,, -186300,5.385488,0.9206418,,,,,,,,,,,,,, -186400,5.353974,0.8908167,,,,,,,,,,,,,, -186500,4.9623713,0.8922589,,,,,,,,,,,,,, -186600,5.5273438,0.9897064,,,,,,,,,,,,,, -186666,,,0.8877750039100647,0.4056401550769806,0.7488799691200256,1.0230340957641602,50000.0,0.6260000467300415,1.7412824630737305,10000.0,62676.81129121781,64867.45725274086,62676.81129121781,2178.622392654419,5.510376214981079,0.0 -186666,,,,,,,,,,,62676.811291217804,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/eval_measurements.csv deleted file mode 100644 index dc54f10ae..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,126 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.791696310043335,0.0,32.42240762710571,1,0,32.42240762710571,0.0009000000427477,6.912177562713623,10000,50.21418762207031,0.0010961415246129,6.911187648773193,0.0011199999134987,6.912059783935547,50000 -35.341097831726074,0.0224995613098144,542.6018404960632,1508,0,542.6018404960632,0.100100003182888,5.178818702697754,10000,578.0186469554901,0.1525031924247741,4.490919589996338,0.1388799995183944,4.628087520599365,50000 -52.99773406982422,0.0534839630126953,1052.5842320919037,3015,0,1052.5842320919037,0.2410000115633011,3.804908037185669,10000,1105.7418451309204,0.3478156924247741,2.9913976192474365,0.3228600025177002,3.1558074951171875,50000 -70.55705571174622,0.086451768875122,1562.7190339565277,4524,0,1562.7190339565277,0.322700023651123,3.2601864337921143,10000,1633.5204920768738,0.4571906924247741,2.4114012718200684,0.4296599924564361,2.576097249984741,50000 -88.38283014297485,0.1140086650848388,2072.789868593216,6033,0,2072.789868593216,0.384300023317337,2.940765380859375,10000,2161.4965052604675,0.5725047588348389,1.785803198814392,0.5069999694824219,2.150753736495972,50000 -105.65375065803528,0.1418612003326416,2582.7939958572388,7544,0,2582.7939958572388,0.3979000151157379,2.8477373123168945,10000,2688.8509685993195,0.5734614133834839,1.7597079277038574,0.5224800109863281,2.043307065963745,50000 -123.14149379730225,0.1691446304321289,3092.727585554123,9055,0,3092.727585554123,0.4310000240802765,2.634711265563965,10000,3216.351620197296,0.6027184128761292,1.6339234113693235,0.5541399717330933,1.899752259254456,50000 -140.74559259414673,0.1966307163238525,3602.7257010936737,10567,0,3602.7257010936737,0.4435000121593475,2.5696029663085938,10000,3744.0334181785583,0.6264349222183228,1.537442684173584,0.5761199593544006,1.796309471130371,50000 -158.65228414535522,0.2255623340606689,4112.730927944183,12079,0,4112.730927944183,0.4504000246524811,2.5434465408325195,10000,4272.026992321014,0.6282485723495483,1.5030643939971924,0.5816199779510498,1.767730951309204,50000 -176.235848903656,0.2562377452850342,4622.837368488312,13591,0,4622.837368488312,0.4679000079631805,2.446086168289185,10000,4799.799918413162,0.644949734210968,1.434355854988098,0.5983200073242188,1.6884334087371826,50000 -193.6725881099701,0.2862663269042969,5132.94886469841,15103,0,5132.94886469841,0.4739000201225281,2.416137933731079,10000,5327.430913209915,0.6813815236091614,1.2665554285049438,0.6004999876022339,1.6677234172821045,50000 -212.0672664642334,0.3150801658630371,5642.935264825821,16615,0,5642.935264825821,0.4627000093460083,2.459235191345215,10000,5855.89298915863,0.666015625,1.3178972005844116,0.6035000085830688,1.6508041620254517,50000 -229.6052343845368,0.3456621170043945,6153.063757181168,18127,0,6153.063757181168,0.4727000296115875,2.390066385269165,10000,6383.643333911896,0.6575454473495483,1.3667640686035156,0.6034199595451355,1.652487874031067,50000 -247.2090749740601,0.3753409385681152,6663.034357786179,19639,0,6663.034357786179,0.4854000210762024,2.3747594356536865,10000,6911.30079293251,0.6714365482330322,1.3168739080429075,0.6110599637031555,1.6108882427215576,50000 -264.745756149292,0.4052441120147705,7173.039954662323,21151,0,7173.039954662323,0.492000013589859,2.333085536956787,10000,7438.92605638504,0.6639229655265808,1.3415753841400146,0.6109600067138672,1.6230710744857788,50000 -282.2081139087677,0.4366991519927978,7683.255749940872,22664,0,7683.255749940872,0.4735000133514404,2.4242477416992188,10000,7966.688236236572,0.6608737111091614,1.3586628437042236,0.6041399836540222,1.640969157218933,50000 -299.73340010643005,0.4670734405517578,8193.423243284225,24177,0,8193.423243284225,0.4853000342845917,2.374408483505249,10000,8494.462911128998,0.6926219463348389,1.2101314067840576,0.6140999794006348,1.6118823289871216,50000 -317.66002774238586,0.5013530254364014,8703.385436296463,25690,0,8703.385436296463,0.4904000163078308,2.407479763031006,10000,9022.43826174736,0.6783721446990967,1.2624869346618652,0.6137799620628357,1.6312508583068848,50000 -336.1436400413513,0.5356407165527344,9213.344497919084,27203,0,9213.344497919084,0.4914000332355499,2.33853816986084,10000,9550.966819286346,0.6784717440605164,1.2687817811965942,0.6159200072288513,1.5993962287902832,50000 -353.745454788208,0.562096118927002,9723.4910697937,28716,0,9723.4910697937,0.4785000085830688,2.4138495922088623,10000,10078.793704748154,0.6600964665412903,1.350205421447754,0.6032999753952026,1.6606773138046265,50000 -371.32841968536377,0.6031649112701416,10233.69608449936,30229,0,10233.69608449936,0.465800017118454,2.448777437210083,10000,10606.674701690674,0.6486168503761292,1.4096747636795044,0.5981199741363525,1.6882946491241455,50000 -388.6515634059906,0.6397063732147217,10743.68086385727,31742,0,10743.68086385727,0.4930000305175781,2.3522725105285645,10000,11134.072080612184,0.6831752061843872,1.2528233528137207,0.6192399859428406,1.6038295030593872,50000 -406.1087462902069,0.6711766719818115,11253.931037902832,33255,0,11253.931037902832,0.5005000233650208,2.294847011566162,10000,11661.8629693985,0.7012914419174194,1.1544904708862305,0.6235600113868713,1.5596383810043335,50000 -423.68754959106445,0.7041482925415039,11763.943603754044,34769,0,11763.943603754044,0.5094000101089478,2.247567176818848,10000,12189.540555000303,0.7081871628761292,1.1398144960403442,0.6377599835395813,1.4949623346328735,50000 -441.0496399402618,0.7320287227630615,12273.98748230934,36281,0,12273.98748230934,0.4908000230789184,2.300273895263672,10000,12717.026733636856,0.6898915767669678,1.230826497077942,0.6264599561691284,1.5518196821212769,50000 -458.7441716194153,0.7648324966430664,12784.045414686205,37795,0,12784.045414686205,0.4897000193595886,2.340038299560547,10000,13244.865463733671,0.6783322691917419,1.2680907249450684,0.6167399883270264,1.5898174047470093,50000 -476.41528153419495,0.797590970993042,13294.002183437347,39308,0,13294.002183437347,0.499500036239624,2.292027950286865,10000,13772.57915186882,0.6907684803009033,1.2177668809890747,0.6317999958992004,1.5191069841384888,50000 -493.7730107307434,0.8306655883789062,13803.981996297836,40822,0,13803.981996297836,0.4988000094890594,2.3107590675354004,10000,14300.002563476562,0.7167769074440002,1.1003592014312744,0.6209999918937683,1.567124843597412,50000 -511.47321367263794,0.8653779029846191,14313.949091911316,42335,0,14313.949091911316,0.5099000334739685,2.241546869277954,10000,14827.75766301155,0.7200653553009033,1.0819741487503052,0.6363199949264526,1.509085774421692,50000 -528.9610199928284,0.9021854400634766,14824.149604320526,43848,0,14824.149604320526,0.5049000382423401,2.2731704711914062,10000,15355.536323785782,0.6968072056770325,1.1697343587875366,0.6294599771499634,1.5310777425765991,50000 -546.4602868556976,0.9373815059661864,15334.175417423248,45362,0,15334.175417423248,0.5038000345230103,2.2749900817871094,10000,15883.150108098984,0.7052175998687744,1.146565079689026,0.6387799978256226,1.4946125745773315,50000 -564.2131533622742,0.9761536121368408,15844.352715015411,46877,0,15844.352715015411,0.515500009059906,2.1804099082946777,10000,16411.172538280487,0.700613796710968,1.157369613647461,0.6425999999046326,1.466826319694519,50000 -581.6602036952972,1.0123159885406494,16354.314910888672,48390,0,16354.314910888672,0.5108000040054321,2.2163541316986084,10000,16938.669855833054,0.7057557106018066,1.150970220565796,0.6429199576377869,1.46943998336792,50000 -599.1710863113403,1.04778790473938,16864.458737134933,49904,0,16864.458737134933,0.513700008392334,2.200896739959717,10000,17466.411892175674,0.758230984210968,0.9265372157096864,0.6476399898529053,1.4463788270950315,50000 -616.627453327179,1.0893511772155762,17374.71946334839,51418,0,17374.71946334839,0.515500009059906,2.1797661781311035,10000,17994.22289633751,0.7257453799247742,1.053958535194397,0.6452999711036682,1.461386799812317,50000 -634.0705862045288,1.129570722579956,17884.717614650726,52932,0,17884.717614650726,0.5178000330924988,2.1698689460754395,10000,18521.756559848785,0.7232341766357422,1.0687999725341797,0.646399974822998,1.4467506408691406,50000 -652.1868450641632,1.1715331077575684,18394.86057209969,54446,0,18394.86057209969,0.5199000239372253,2.165266513824463,10000,19050.109229803085,0.7245296239852905,1.073808670043945,0.6561799645423889,1.4147099256515503,50000 -669.8906226158142,1.211322784423828,18904.974903345108,55961,0,18904.974903345108,0.5128000378608704,2.175576448440552,10000,19578.019869804382,0.7122927308082581,1.1093248128890991,0.6516199707984924,1.4258030652999878,50000 -687.4907431602478,1.2484371662139893,19415.11595249176,57475,0,19415.11595249176,0.5162000060081482,2.202220916748047,10000,20105.84994530677,0.70609450340271,1.13842511177063,0.6474599838256836,1.457894802093506,50000 -704.9665122032166,1.285801649093628,19925.138954639435,58989,0,19925.138954639435,0.5228000283241272,2.17522406578064,10000,20633.438943624496,0.7527901530265808,0.939852774143219,0.6464999914169312,1.4735795259475708,50000 -722.6707804203033,1.3243467807769775,20435.203814983368,60503,0,20435.203814983368,0.5309000015258789,2.1159284114837646,10000,21161.29817390442,0.7373046875,1.001809000968933,0.6541599631309509,1.409217119216919,50000 -739.9851453304291,1.360926389694214,20945.369954109192,62017,0,20945.369954109192,0.5299000144004822,2.1289846897125244,10000,21688.86900305748,0.7272400856018066,1.0499770641326904,0.6545799970626831,1.4146177768707275,50000 -757.318776845932,1.4008488655090332,21455.5750977993,63531,0,21455.5750977993,0.5286000370979309,2.131467580795288,10000,22216.50008130073,0.7257254123687744,1.053279995918274,0.6569799780845642,1.4097890853881836,50000 -775.1528396606445,1.4429562091827393,21965.726397037502,65046,0,21965.726397037502,0.5245000123977661,2.156514883041382,10000,22744.58133101464,0.7092434763908386,1.1382043361663818,0.6457200050354004,1.4658386707305908,50000 -792.794264793396,1.482445240020752,22475.8989508152,66561,0,22475.8989508152,0.5267000198364258,2.145815134048462,10000,23272.48766350746,0.7173748016357422,1.0853809118270874,0.6520000100135803,1.4343310594558716,50000 -810.4946658611298,1.5225434303283691,22986.08687877655,68075,0,22986.08687877655,0.5333000421524048,2.1278412342071533,10000,23800.46881747245,0.7562978267669678,0.9239857792854308,0.6578199863433838,1.4056472778320312,50000 -827.9055438041687,1.5612945556640625,23496.2087392807,69589,0,23496.2087392807,0.5343000292778015,2.123914957046509,10000,24328.093178987503,0.7335578799247742,1.0107797384262085,0.6535399556159973,1.4225095510482788,50000 -845.568799495697,1.6021020412445068,24006.12292265892,71103,0,24006.12292265892,0.5351000428199768,2.0900559425354004,10000,24855.764329195023,0.7446388602256775,0.9683708548545836,0.668940007686615,1.3582696914672852,50000 -862.9524285793304,1.64097261428833,24516.380070209503,72618,0,24516.380070209503,0.5350000262260437,2.121037483215332,10000,25383.496671438217,0.7384406924247742,0.9917774200439452,0.6642999649047852,1.362199068069458,50000 -880.3766114711761,1.679915428161621,25026.531310796738,74133,0,25026.531310796738,0.5330000519752502,2.1223597526550293,10000,25911.16432285309,0.7292729616165161,1.0414572954177856,0.6623799800872803,1.3939917087554932,50000 -898.0017364025116,1.7220737934112549,25536.60414552689,75647,0,25536.60414552689,0.5415000319480896,2.0805461406707764,10000,26438.958388328552,0.7429248690605164,0.9894379377365112,0.6701799631118774,1.3528382778167725,50000 -915.695422887802,1.7612836360931396,26046.688113451004,77161,0,26046.688113451004,0.5258000493049622,2.1486873626708984,10000,26966.82772350312,0.7538663744926453,0.933739185333252,0.6592999696731567,1.3939270973205566,50000 -933.3991062641144,1.8085589408874512,26556.600678920742,78675,0,26556.600678920742,0.5458000302314758,2.0804595947265625,10000,27494.543236494064,0.7557198405265808,0.925915777683258,0.6734399795532227,1.3371450901031494,50000 -951.3102207183838,1.8615412712097168,27066.66562986374,80190,0,27066.66562986374,0.5390000343322754,2.0906593799591064,10000,28022.62483239174,0.7496811151504517,0.947109878063202,0.6690799593925476,1.3521523475646973,50000 -968.7020602226256,1.9029486179351809,27576.59165716172,81704,0,27576.59165716172,0.5506000518798828,2.008496284484864,10000,28550.037356376648,0.7504384517669678,0.9462156295776368,0.6738199591636658,1.321738362312317,50000 -986.0461583137512,1.945774793624878,28086.49978995323,83218,0,28086.49978995323,0.5326000452041626,2.109022378921509,10000,29077.384110450745,0.7326610088348389,1.0291450023651123,0.6612399816513062,1.3827193975448608,50000 -1003.904506444931,1.9876015186309808,28596.737845897675,84732,0,28596.737845897675,0.5491000413894653,2.0207293033599854,10000,29605.5751888752,0.7528101205825806,0.936428725719452,0.6754199862480164,1.3129218816757202,50000 -1021.369663476944,2.03339958190918,29106.95946264267,86247,0,29106.95946264267,0.5403000116348267,2.084959983825684,10000,30133.36285853386,0.7684949040412903,0.8733639717102051,0.6672799587249756,1.357527494430542,50000 -1039.2736871242523,2.084794282913208,29616.86198425293,87761,0,29616.86198425293,0.534600019454956,2.101754903793335,10000,30661.273845672607,0.7531489133834839,0.9372791051864624,0.6677599549293518,1.3522411584854126,50000 -1056.668353557587,2.1289286613464355,30126.8455953598,89275,0,30126.8455953598,0.5367000102996826,2.0899970531463623,10000,31188.748248815536,0.748465359210968,0.9442353248596193,0.6695799827575684,1.3373024463653564,50000 -1074.0231430530548,2.169685840606689,30636.927599668503,90789,0,30636.927599668503,0.5403000116348267,2.083833694458008,10000,31716.2787964344,0.7472097873687744,0.958838939666748,0.6690399646759033,1.3475358486175537,50000 -1091.316163063049,2.213765859603882,31146.947617292404,92303,0,31146.947617292404,0.5448000431060791,2.058915138244629,10000,32243.688675642014,0.7487842440605164,0.9438230991363524,0.6720799803733826,1.337753415107727,50000 -1109.5773482322693,2.25687313079834,31657.00089836121,93815,0,31657.00089836121,0.5523000359535217,2.045220375061035,10000,32772.09962654114,0.7768255472183228,0.8307573795318604,0.6793599724769592,1.3027634620666504,50000 -1126.968933343887,2.3015289306640625,32167.20648097992,95330,0,32167.20648097992,0.5487000346183777,2.0057966709136963,10000,33299.79450273514,0.7841796875,0.8101245760917664,0.6791799664497375,1.3027790784835815,50000 -1144.3458700180054,2.350280284881592,32677.118038654327,96844,0,32677.118038654327,0.5568000078201294,2.0034146308898926,10000,33827.184403419495,0.7759685516357422,0.8316084146499634,0.6843999624252319,1.2773557901382446,50000 -1161.9749476909635,2.399492025375366,33187.21611762047,98358,0,33187.21611762047,0.558899998664856,1.9838387966156008,10000,34355.01317358017,0.7757692933082581,0.8344160318374634,0.6873799562454224,1.2644789218902588,50000 -1179.2202832698822,2.444276094436645,33697.1601536274,99872,0,33697.1601536274,0.5624000430107117,1.971338391304016,10000,34882.30074048042,0.7700095772743225,0.8556905388832092,0.6885600090026855,1.267644286155701,50000 -1196.8182473182678,2.4896559715271,34207.31728053093,101387,0,34207.31728053093,0.5706000328063965,1.9568129777908323,10000,35410.15484523773,0.7763273119926453,0.8312370777130127,0.6921399831771851,1.2392897605895996,50000 -1214.52796459198,2.526557683944702,34717.62727546692,102901,0,34717.62727546692,0.5493000149726868,2.0565006732940674,10000,35938.262838840485,0.8036909699440002,0.7377466559410095,0.6786800026893616,1.3162022829055786,50000 -1231.9329543113708,2.574941396713257,35227.735830545425,104416,0,35227.735830545425,0.5697000026702881,1.912181735038757,10000,36465.87762641907,0.8050262928009033,0.7113240957260132,0.6964199542999268,1.2255223989486694,50000 -1249.209624528885,2.623447895050049,35737.8859539032,105931,0,35737.8859539032,0.5741000175476074,1.9478965997695925,10000,36993.40591478348,0.7947624325752258,0.7517896294593811,0.69896000623703,1.214800238609314,50000 -1266.8068754673004,2.673878192901612,36247.9498064518,107445,0,36247.9498064518,0.5676000118255615,1.9585601091384888,10000,37521.17093753815,0.7896404266357422,0.7744253277778625,0.6942200064659119,1.2482731342315674,50000 -1284.7805352211,2.722330331802368,36757.94217133522,108958,0,36757.94217133522,0.5630000233650208,1.98625648021698,10000,38049.23787140846,0.7810905575752258,0.8077903389930725,0.6916999816894531,1.2639470100402832,50000 -1302.261702299118,2.787564992904663,37267.98639202118,110472,0,37267.98639202118,0.5737000107765198,1.9548555612564087,10000,38576.88211965561,0.7847775816917419,0.7921836972236633,0.6990000009536743,1.2325977087020874,50000 -1319.8213753700256,2.836085319519043,37778.13521909714,111986,0,37778.13521909714,0.5711000561714172,1.9375004768371584,10000,39104.69184041023,0.8351203799247742,0.5935096740722656,0.6996999979019165,1.23037850856781,50000 -1337.3728301525116,2.8874781131744385,38288.519334316254,113501,0,38288.519334316254,0.5791000127792358,1.906409859657288,10000,39632.731977939606,0.8151108026504517,0.672951340675354,0.6999399662017822,1.2182530164718628,50000 -1354.794580221176,2.939371109008789,38798.61583328247,115015,0,38798.61583328247,0.5841000080108643,1.930444598197937,10000,40160.355558395386,0.80961012840271,0.6891085505485535,0.705299973487854,1.2037498950958252,50000 -1372.6245946884155,2.987029552459717,39308.739602565765,116530,0,39308.739602565765,0.5679000020027161,1.9563889503479004,10000,40688.40866851807,0.8014189600944519,0.7300856113433838,0.7023199796676636,1.212501883506775,50000 -1390.393991947174,3.032386541366577,39818.97498655319,118045,0,39818.97498655319,0.5809000134468079,1.901960849761963,10000,41216.51202297211,0.8037906289100647,0.7134524583816528,0.7069399952888489,1.1913037300109863,50000 -1407.8446052074432,3.0808701515197754,40328.96877479553,119559,0,40328.96877479553,0.567300021648407,1.958403825759888,10000,41744.0572385788,0.7993462681770325,0.7303955554962158,0.7000199556350708,1.2318073511123655,50000 -1425.3332903385162,3.131650686264038,40839.1434905529,121073,0,40839.1434905529,0.5856000185012817,1.882620930671692,10000,42271.82386422157,0.8441884517669678,0.5559610724449158,0.7035399675369263,1.1964123249053955,50000 -1442.7801163196564,3.1828622817993164,41349.36904430389,122588,0,41349.36904430389,0.5904000401496887,1.8778742551803589,10000,42799.59975862503,0.83203125,0.6034864783287048,0.7110399603843689,1.172864317893982,50000 -1460.7830305099487,3.2322473526000977,41859.42814803124,124103,0,41859.42814803124,0.5773000121116638,1.9539101123809808,10000,43327.76603150368,0.8218072056770325,0.6350313425064087,0.7073799967765808,1.1969265937805176,50000 -1478.3529794216156,3.285027503967285,42369.41453456879,125617,0,42369.41453456879,0.5925000309944153,1.8581621646881104,10000,43855.42738699913,0.8284637928009033,0.6209262609481812,0.7165399789810181,1.1599438190460205,50000 -1495.813749074936,3.3334877490997314,42879.60580301285,127132,0,42879.60580301285,0.5856000185012817,1.9250835180282595,10000,44383.18127179146,0.8194355964660645,0.6457960605621338,0.7070199847221375,1.2006008625030518,50000 -1513.3805103302002,3.381448268890381,43389.64476656914,128647,0,43389.64476656914,0.5854000449180603,1.894575119018555,10000,44910.88752961159,0.8309949040412903,0.6023525595664978,0.7137399911880493,1.1708685159683228,50000 -1530.9731702804563,3.430178165435791,43899.57999563217,130160,0,43899.57999563217,0.5945000052452087,1.87068510055542,10000,45438.51849889755,0.8641980290412903,0.4822084009647369,0.7169199585914612,1.153667449951172,50000 -1549.4466173648834,3.483417510986328,44409.718037605286,131675,0,44409.718037605286,0.5956000089645386,1.894251823425293,10000,45967.23632621765,0.853535532951355,0.5158511400222778,0.7181800007820129,1.1625126600265503,50000 -1567.1695573329926,3.533687829971313,44919.616294384,133189,0,44919.616294384,0.5996000170707703,1.8523231744766235,10000,46494.96126246452,0.8512635231018066,0.518912672996521,0.7234999537467957,1.1488239765167236,50000 -1584.762847185135,3.58777403831482,45429.56394815445,134703,0,45429.56394815445,0.6012000441551208,1.8476316928863523,10000,47022.609437942505,0.8521404266357422,0.5233726501464844,0.7243199944496155,1.1318014860153198,50000 -1602.4290103912354,3.637843132019043,45939.55491280556,136217,0,45939.55491280556,0.6022000312805176,1.8703584671020508,10000,47550.36924123764,0.8546316623687744,0.5151609182357788,0.7261799573898315,1.1283327341079712,50000 -1620.2749452590942,3.691930770874024,46449.772149086,137732,0,46449.772149086,0.5824000239372253,1.9551299810409544,10000,48078.53934550285,0.8350605964660645,0.5876715183258057,0.7061399817466736,1.225954532623291,50000 -1637.8468651771543,3.743335962295532,46959.832931280136,139246,0,46959.832931280136,0.6080000400543213,1.8433808088302608,10000,48606.27636384964,0.8834900856018066,0.4090985059738159,0.7286199927330017,1.1191855669021606,50000 -1655.6108510494232,3.794053316116333,47470.02541399002,140761,0,47470.02541399002,0.6022000312805176,1.865876317024231,10000,49134.33640837669,0.8775510191917419,0.4234741926193237,0.7308799624443054,1.1201832294464111,50000 -1673.117571592331,3.8474504947662354,47980.01170706749,142275,0,47980.01170706749,0.6037000417709351,1.860552668571472,10000,49661.93563914299,0.869559109210968,0.4549353718757629,0.729699969291687,1.1267502307891846,50000 -1690.5912556648254,3.902509450912476,48489.98080587387,143788,0,48489.98080587387,0.6041000485420227,1.8592493534088133,10000,50189.48695039749,0.8769331574440002,0.4277708530426025,0.7282199859619141,1.118950605392456,50000 -1708.0588409900663,3.9540607929229736,48999.96798682213,145301,0,48999.96798682213,0.6092000007629395,1.8546922206878664,10000,50717.04645681381,0.8763153553009033,0.4283612966537475,0.7317599654197693,1.1129244565963743,50000 -1725.5483317375183,4.008327007293701,49510.08322453499,146815,0,49510.08322453499,0.6107000112533569,1.8342806100845337,10000,51244.758184194565,0.8947902917861938,0.3664727807044983,0.7371000051498413,1.0927094221115112,50000 -1743.310753107071,4.059577941894531,50020.06638741493,148329,0,50020.06638741493,0.6067000031471252,1.849771738052368,10000,51772.60830807686,0.904934585094452,0.327584832906723,0.7351999878883362,1.1068998575210571,50000 -1760.661033153534,4.114696264266968,50530.12015199661,149843,0,50530.12015199661,0.6149000525474548,1.8277121782302856,10000,52300.12011003494,0.901566445827484,0.3411449491977691,0.7376599907875061,1.0968347787857056,50000 -1778.430143117905,4.172338247299194,51040.12117242813,151357,0,51040.12117242813,0.6163000464439392,1.8217467069625848,10000,52828.000644207,0.9032804369926452,0.335657387971878,0.7389000058174133,1.0878074169158936,50000 -1796.0045523643494,4.2290143966674805,51550.18557286263,152872,0,51550.18557286263,0.6160000562667847,1.8262497186660769,10000,53355.74990034104,0.9049944281578064,0.3283750116825104,0.7387199997901917,1.1011860370635986,50000 -1813.832791805268,4.287072420120239,52060.21515059471,154386,0,52060.21515059471,0.616100013256073,1.8450510501861568,10000,53883.71820926666,0.9026825428009032,0.3313542008399963,0.7376799583435059,1.1047000885009766,50000 -1831.2789142131803,4.342725038528442,52570.35527801514,155900,0,52570.35527801514,0.6147000193595886,1.8321329355239868,10000,54411.412940740585,0.9301259517669678,0.2502636015415191,0.7440800070762634,1.0782766342163086,50000 -1848.89755487442,4.397834777832031,53080.38762998581,157414,0,53080.38762998581,0.6175000071525574,1.845854997634888,10000,54939.17242026329,0.9299465417861938,0.2439799904823303,0.742579996585846,1.0852277278900146,50000 -1866.427173614502,4.452563047409058,53590.50220036507,158929,0,53590.50220036507,0.6193000078201294,1.8259567022323608,10000,55466.92384791374,0.9267578125,0.2488669753074646,0.7440999746322632,1.0787391662597656,50000 -1884.1840229034424,4.509575366973877,54100.62163352966,160443,0,54100.62163352966,0.6230000257492065,1.8304731845855715,10000,55994.9102306366,0.927355706691742,0.2522351443767547,0.7450199723243713,1.0835787057876587,50000 -1901.915011882782,4.5679240226745605,54610.548907756805,161957,0,54610.548907756805,0.6254000067710876,1.832878589630127,10000,56522.68081855774,0.928730845451355,0.2452500760555267,0.7450199723243713,1.079473853111267,50000 -1919.380197048188,4.62821364402771,55120.64704847336,163471,0,55120.64704847336,0.6281000375747681,1.822871208190918,10000,57050.35792398453,0.9323580861091614,0.2346342206001281,0.7455399632453918,1.0828678607940674,50000 -1936.8064422607424,4.680109977722168,55630.59670042992,164984,0,55630.59670042992,0.6242000460624695,1.831413507461548,10000,57577.83935189247,0.9518494606018066,0.1781339794397354,0.7479000091552734,1.078954577445984,50000 -1954.128136396408,4.73781156539917,56140.51879167557,166498,0,56140.51879167557,0.6274000406265259,1.8304781913757324,10000,58105.194816827774,0.9480029940605164,0.1860027015209198,0.746999979019165,1.078598976135254,50000 -1971.827444314957,4.794127941131592,56650.4464943409,168012,0,56650.4464943409,0.626800000667572,1.822851657867432,10000,58632.931473732,0.9465281963348388,0.1897157132625579,0.7482799887657166,1.0680350065231323,50000 -1989.6279754638672,4.857915878295898,57160.48657393456,169526,0,57160.48657393456,0.627500057220459,1.824575662612915,10000,59160.888498306274,0.9483816623687744,0.1834833025932312,0.7514199614524841,1.0675344467163086,50000 -2007.9708423614504,4.914608955383301,57670.477942705154,171040,0,57670.477942705154,0.6289000511169434,1.8254389762878416,10000,59689.33236479759,0.950215220451355,0.1781926602125167,0.7507599592208862,1.0668600797653198,50000 -2025.5694625377653,4.971103191375732,58180.52669739723,172554,0,58180.52669739723,0.6328000426292419,1.818946361541748,10000,60217.08939146996,0.9512715339660645,0.1751319766044616,0.7514399886131287,1.0617859363555908,50000 -2042.9718658924105,5.033244848251343,58690.58052444458,174068,0,58690.58052444458,0.6322000026702881,1.8160579204559328,10000,60744.659641981125,0.9608178734779358,0.1495468467473983,0.7513999938964844,1.0650503635406494,50000 -2060.5704913139343,5.091560125350952,59200.59376168251,175582,0,59200.59376168251,0.6325000524520874,1.820577621459961,10000,61272.38275671005,0.958227038383484,0.1554597616195678,0.7520999908447266,1.0633561611175537,50000 -2078.398533344269,5.148442506790161,59710.533801317215,177096,0,59710.533801317215,0.6310000419616699,1.812854528427124,10000,61800.26037359238,0.9596021771430968,0.1510808914899826,0.7536599636077881,1.0556384325027466,50000 -2096.097542285919,5.214344024658203,60220.49709105492,178610,0,60220.49709105492,0.6331000328063965,1.810882329940796,10000,62328.04130482674,0.95902419090271,0.1536361575126648,0.7538599967956543,1.0557180643081665,50000 -2113.768192052841,5.269103527069092,60730.41804790497,180123,0,60730.41804790497,0.6345000267028809,1.8141127824783323,10000,62855.74048471451,0.9585060477256776,0.1502209901809692,0.7537199854850769,1.0573610067367554,50000 -2131.186047077179,5.325862407684326,61240.57641124725,181637,0,61240.57641124725,0.633400022983551,1.8171926736831665,10000,63383.426446676254,0.9604192972183228,0.1468931436538696,0.7546799778938293,1.0567476749420166,50000 -2149.111617088318,5.389604806900024,61750.644112825394,183151,0,61750.644112825394,0.6343000531196594,1.8140968084335327,10000,63911.53665685654,0.9615154266357422,0.1437261253595352,0.7544800043106079,1.0552575588226318,50000 -2166.556547164917,5.439829349517822,62260.7015068531,184665,0,62260.7015068531,0.6353000402450562,1.81319522857666,10000,64439.141348838806,0.961316168308258,0.1466423720121383,0.7540599703788757,1.0541568994522097,50000 -2183.882416248321,5.515969038009644,62770.74982833862,186179,0,62770.74982833862,0.634600043296814,1.8140379190444944,10000,64966.6444914341,0.9597018361091614,0.1491877883672714,0.7537399530410767,1.0555750131607056,50000 -2201.129693508148,5.575985908508301,62934.7845287323,186666,0,62934.7845287323,0.6342000365257263,1.8139417171478271,10000,65148.003600120544,0.9602798223495483,0.14576058089733124,0.75409996509552,1.054492473602295,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/measurements.csv deleted file mode 100644 index 17ea35658..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1994 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6918318,6.925329,,,,,,,,,,,,,, -1,,,0.0010961415246129,6.911187648773193,0.0011199999134987,6.912059783935547,50000.0,0.0009000000427477,6.912177562713623,10000.0,32.42240762710571,50.21418762207031,32.42240762710571,17.791696310043335,0.0,0.0 -100,0.68010306,6.8061194,,,,,,,,,,,,,, -200,0.7912972,6.5498357,,,,,,,,,,,,,, -300,0.9233515,6.2825,,,,,,,,,,,,,, -400,1.9480759,6.025085,,,,,,,,,,,,,, -500,2.7521656,5.8995733,,,,,,,,,,,,,, -600,3.2107162,5.768036,,,,,,,,,,,,,, -700,2.7495427,5.444252,,,,,,,,,,,,,, -800,4.07331,5.367702,,,,,,,,,,,,,, -900,6.307762,5.1746826,,,,,,,,,,,,,, -1000,4.8720717,5.1621137,,,,,,,,,,,,,, -1100,6.7799687,4.8602285,,,,,,,,,,,,,, -1200,5.3209705,4.8530145,,,,,,,,,,,,,, -1300,4.2748423,4.6772776,,,,,,,,,,,,,, -1400,4.7735868,4.536643,,,,,,,,,,,,,, -1500,4.8370247,4.549308,,,,,,,,,,,,,, -1508,,,0.1525031924247741,4.490919589996338,0.1388799995183944,4.628087520599365,50000.0,0.100100003182888,5.178818702697754,10000.0,542.6018404960632,578.0186469554901,542.6018404960632,35.341097831726074,0.0224995613098144,0.0 -1600,6.397731,4.2896647,,,,,,,,,,,,,, -1700,4.7973413,4.3483677,,,,,,,,,,,,,, -1800,5.1549053,4.1644945,,,,,,,,,,,,,, -1900,8.38488,4.1065536,,,,,,,,,,,,,, -2000,5.0073113,4.001276,,,,,,,,,,,,,, -2100,4.3073244,3.9707727,,,,,,,,,,,,,, -2200,5.09961,3.8643932,,,,,,,,,,,,,, -2300,6.4988174,3.8363607,,,,,,,,,,,,,, -2400,4.491079,3.7946305,,,,,,,,,,,,,, -2500,3.1406317,3.7284772,,,,,,,,,,,,,, -2600,4.0467978,3.582374,,,,,,,,,,,,,, -2700,3.5973494,3.5674882,,,,,,,,,,,,,, -2800,3.170961,3.4151442,,,,,,,,,,,,,, -2900,3.3388345,3.3644943,,,,,,,,,,,,,, -3000,3.4562995,3.4026618,,,,,,,,,,,,,, -3015,,,0.3478156924247741,2.9913976192474365,0.3228600025177002,3.1558074951171875,50000.0,0.2410000115633011,3.804908037185669,10000.0,1052.5842320919037,1105.7418451309204,1052.5842320919037,52.99773406982422,0.0534839630126953,0.0 -3100,4.32333,3.2892976,,,,,,,,,,,,,, -3200,3.998074,3.196799,,,,,,,,,,,,,, -3300,5.4813094,3.2182977,,,,,,,,,,,,,, -3400,2.7264595,3.2498507,,,,,,,,,,,,,, -3500,3.8051019,3.190624,,,,,,,,,,,,,, -3600,2.7730155,3.2352388,,,,,,,,,,,,,, -3700,3.7109964,3.1043622,,,,,,,,,,,,,, -3800,3.7934277,3.0882845,,,,,,,,,,,,,, -3900,2.775068,2.9941616,,,,,,,,,,,,,, -4000,3.4051487,3.0312803,,,,,,,,,,,,,, -4100,3.0215108,2.9076812,,,,,,,,,,,,,, -4200,3.4173312,2.8691645,,,,,,,,,,,,,, -4300,3.5399742,2.8554292,,,,,,,,,,,,,, -4400,2.2822447,2.6607292,,,,,,,,,,,,,, -4500,4.551394,2.7052364,,,,,,,,,,,,,, -4524,,,0.4571906924247741,2.4114012718200684,0.4296599924564361,2.576097249984741,50000.0,0.322700023651123,3.2601864337921143,10000.0,1562.7190339565277,1633.5204920768738,1562.7190339565277,70.55705571174622,0.086451768875122,0.0 -4600,2.430105,2.7949471,,,,,,,,,,,,,, -4700,2.0382528,2.7927787,,,,,,,,,,,,,, -4800,2.5707622,2.6729836,,,,,,,,,,,,,, -4900,2.4575942,2.6161785,,,,,,,,,,,,,, -5000,2.0066833,2.7520397,,,,,,,,,,,,,, -5100,2.2189484,2.5912764,,,,,,,,,,,,,, -5200,2.6864684,2.6896117,,,,,,,,,,,,,, -5300,2.3768747,2.5453389,,,,,,,,,,,,,, -5400,2.1050029,2.5079775,,,,,,,,,,,,,, -5500,2.2642784,2.5391197,,,,,,,,,,,,,, -5600,2.4235435,2.4937992,,,,,,,,,,,,,, -5700,2.3132713,2.5551457,,,,,,,,,,,,,, -5800,2.696835,2.4565337,,,,,,,,,,,,,, -5900,1.8293024,2.4533577,,,,,,,,,,,,,, -6000,2.4622889,2.4788277,,,,,,,,,,,,,, -6033,,,0.5725047588348389,1.785803198814392,0.5069999694824219,2.150753736495972,50000.0,0.384300023317337,2.940765380859375,10000.0,2072.789868593216,2161.4965052604675,2072.789868593216,88.38283014297485,0.1140086650848388,0.0 -6100,1.9696679,2.4148715,,,,,,,,,,,,,, -6200,2.5494688,2.3807964,,,,,,,,,,,,,, -6300,2.4138718,2.5286655,,,,,,,,,,,,,, -6400,1.5093867,2.4926085,,,,,,,,,,,,,, -6500,3.0047112,2.390893,,,,,,,,,,,,,, -6600,1.8483016,2.5354104,,,,,,,,,,,,,, -6700,2.3365963,2.3767452,,,,,,,,,,,,,, -6800,2.32264,2.4171634,,,,,,,,,,,,,, -6900,1.6640183,2.228588,,,,,,,,,,,,,, -7000,2.306021,2.1501284,,,,,,,,,,,,,, -7100,2.0110009,2.283247,,,,,,,,,,,,,, -7200,1.8829626,2.2637997,,,,,,,,,,,,,, -7300,2.007664,2.2733853,,,,,,,,,,,,,, -7400,1.889896,2.3738992,,,,,,,,,,,,,, -7500,1.6987846,2.2726116,,,,,,,,,,,,,, -7544,,,0.5734614133834839,1.7597079277038574,0.5224800109863281,2.043307065963745,50000.0,0.3979000151157379,2.8477373123168945,10000.0,2582.7939958572388,2688.8509685993195,2582.7939958572388,105.65375065803528,0.1418612003326416,0.0 -7600,1.470468,2.2241445,,,,,,,,,,,,,, -7700,2.269506,2.2589273,,,,,,,,,,,,,, -7800,2.0052466,2.2243485,,,,,,,,,,,,,, -7900,1.9882227,2.2422547,,,,,,,,,,,,,, -8000,2.0189128,2.1193361,,,,,,,,,,,,,, -8100,1.8131036,2.1689363,,,,,,,,,,,,,, -8200,1.3964225,2.140108,,,,,,,,,,,,,, -8300,1.8088672,2.1076953,,,,,,,,,,,,,, -8400,1.8877631,2.1346896,,,,,,,,,,,,,, -8500,1.6758372,2.2033036,,,,,,,,,,,,,, -8600,1.6381117,2.2488747,,,,,,,,,,,,,, -8700,2.0961015,2.0874734,,,,,,,,,,,,,, -8800,1.906143,2.1183493,,,,,,,,,,,,,, -8900,1.6031384,2.0836017,,,,,,,,,,,,,, -9000,1.6475294,2.18958,,,,,,,,,,,,,, -9055,,,0.6027184128761292,1.6339234113693235,0.5541399717330933,1.899752259254456,50000.0,0.4310000240802765,2.634711265563965,10000.0,3092.727585554123,3216.351620197296,3092.727585554123,123.14149379730225,0.1691446304321289,0.0 -9100,1.5044968,2.0451646,,,,,,,,,,,,,, -9200,1.4881612,2.1175156,,,,,,,,,,,,,, -9300,1.598202,2.08878,,,,,,,,,,,,,, -9400,1.9449213,2.033057,,,,,,,,,,,,,, -9500,2.1494784,2.069245,,,,,,,,,,,,,, -9600,1.3962138,2.1647656,,,,,,,,,,,,,, -9700,2.0289497,2.2453642,,,,,,,,,,,,,, -9800,1.8858681,2.1100223,,,,,,,,,,,,,, -9900,1.3945575,2.1081655,,,,,,,,,,,,,, -10000,1.5054178,2.1031919,,,,,,,,,,,,,, -10100,2.0972483,2.2060223,,,,,,,,,,,,,, -10200,1.7289541,2.0909553,,,,,,,,,,,,,, -10300,1.8152189,2.1179438,,,,,,,,,,,,,, -10400,1.59412,2.0064645,,,,,,,,,,,,,, -10500,1.6396528,2.0212507,,,,,,,,,,,,,, -10567,,,0.6264349222183228,1.537442684173584,0.5761199593544006,1.796309471130371,50000.0,0.4435000121593475,2.5696029663085938,10000.0,3602.7257010936737,3744.0334181785583,3602.7257010936737,140.74559259414673,0.1966307163238525,0.0 -10600,1.4706999,1.972539,,,,,,,,,,,,,, -10700,1.864385,1.942506,,,,,,,,,,,,,, -10800,1.6934477,2.0432348,,,,,,,,,,,,,, -10900,1.5605414,1.8816153,,,,,,,,,,,,,, -11000,1.6484715,2.098398,,,,,,,,,,,,,, -11100,1.8524135,2.0349784,,,,,,,,,,,,,, -11200,1.3696475,2.07374,,,,,,,,,,,,,, -11300,1.7060682,2.042878,,,,,,,,,,,,,, -11400,1.7215416,2.155798,,,,,,,,,,,,,, -11500,1.8861332,2.0328298,,,,,,,,,,,,,, -11600,1.6131768,1.9789412,,,,,,,,,,,,,, -11700,2.055166,2.022081,,,,,,,,,,,,,, -11800,1.5446231,2.0308127,,,,,,,,,,,,,, -11900,1.2436901,1.8267956,,,,,,,,,,,,,, -12000,1.4088178,1.9275553,,,,,,,,,,,,,, -12079,,,0.6282485723495483,1.5030643939971924,0.5816199779510498,1.767730951309204,50000.0,0.4504000246524811,2.5434465408325195,10000.0,4112.730927944183,4272.026992321014,4112.730927944183,158.65228414535522,0.2255623340606689,0.0 -12100,1.6998246,1.9902539,,,,,,,,,,,,,, -12200,1.6075056,1.9963726,,,,,,,,,,,,,, -12300,1.807678,2.1164303,,,,,,,,,,,,,, -12400,1.8335841,2.009938,,,,,,,,,,,,,, -12500,1.5906398,1.9419563,,,,,,,,,,,,,, -12600,1.5497775,1.8802594,,,,,,,,,,,,,, -12700,2.4893122,2.0678513,,,,,,,,,,,,,, -12800,1.7173376,1.8775357,,,,,,,,,,,,,, -12900,1.8244414,1.9454575,,,,,,,,,,,,,, -13000,1.9324846,1.9552588,,,,,,,,,,,,,, -13100,1.8401495,1.9758862,,,,,,,,,,,,,, -13200,1.564938,2.0323317,,,,,,,,,,,,,, -13300,1.4387997,1.975103,,,,,,,,,,,,,, -13400,1.435291,1.9313122,,,,,,,,,,,,,, -13500,1.8094578,1.9211301,,,,,,,,,,,,,, -13591,,,0.644949734210968,1.434355854988098,0.5983200073242188,1.6884334087371826,50000.0,0.4679000079631805,2.446086168289185,10000.0,4622.837368488312,4799.799918413162,4622.837368488312,176.235848903656,0.2562377452850342,0.0 -13600,1.637474,1.8648759,,,,,,,,,,,,,, -13700,1.3221849,1.8669605,,,,,,,,,,,,,, -13800,1.5108106,1.9878379,,,,,,,,,,,,,, -13900,1.368279,1.8456895,,,,,,,,,,,,,, -14000,1.376009,1.9314295,,,,,,,,,,,,,, -14100,1.5885189,1.8659523,,,,,,,,,,,,,, -14200,1.718177,2.0212445,,,,,,,,,,,,,, -14300,1.6095234,1.9314953,,,,,,,,,,,,,, -14400,1.5958651,1.8840423,,,,,,,,,,,,,, -14500,1.7913058,1.8523272,,,,,,,,,,,,,, -14600,1.6784539,1.9111811,,,,,,,,,,,,,, -14700,1.802837,1.9696457,,,,,,,,,,,,,, -14800,1.7304962,1.7736915,,,,,,,,,,,,,, -14900,1.5803102,1.9655573,,,,,,,,,,,,,, -15000,1.3360739,1.93851,,,,,,,,,,,,,, -15100,1.389606,1.9015368,,,,,,,,,,,,,, -15103,,,0.6813815236091614,1.2665554285049438,0.6004999876022339,1.6677234172821045,50000.0,0.4739000201225281,2.416137933731079,10000.0,5132.94886469841,5327.430913209915,5132.94886469841,193.6725881099701,0.2862663269042969,0.0 -15200,2.2626405,1.8162167,,,,,,,,,,,,,, -15300,1.5298078,1.8435428,,,,,,,,,,,,,, -15400,1.877291,2.0656023,,,,,,,,,,,,,, -15500,1.7219782,1.7285608,,,,,,,,,,,,,, -15600,1.5315425,1.790456,,,,,,,,,,,,,, -15700,1.6377741,1.8979326,,,,,,,,,,,,,, -15800,1.4426632,1.9200352,,,,,,,,,,,,,, -15900,1.8339746,1.8087122,,,,,,,,,,,,,, -16000,1.7605549,1.7795165,,,,,,,,,,,,,, -16100,1.5819668,1.8114694,,,,,,,,,,,,,, -16200,1.6702687,1.9306401,,,,,,,,,,,,,, -16300,1.8296965,1.8597425,,,,,,,,,,,,,, -16400,2.2513583,1.832304,,,,,,,,,,,,,, -16500,1.5201334,1.8907523,,,,,,,,,,,,,, -16600,1.821788,1.8413198,,,,,,,,,,,,,, -16615,,,0.666015625,1.3178972005844116,0.6035000085830688,1.6508041620254517,50000.0,0.4627000093460083,2.459235191345215,10000.0,5642.935264825821,5855.89298915863,5642.935264825821,212.0672664642334,0.3150801658630371,0.0 -16700,1.6748333,1.8099241,,,,,,,,,,,,,, -16800,1.4678357,1.8463603,,,,,,,,,,,,,, -16900,1.8432106,1.8171227,,,,,,,,,,,,,, -17000,1.3785381,1.9679588,,,,,,,,,,,,,, -17100,2.013431,1.811119,,,,,,,,,,,,,, -17200,1.5678395,1.8954405,,,,,,,,,,,,,, -17300,1.5688723,1.86275,,,,,,,,,,,,,, -17400,1.9554999,1.9579445,,,,,,,,,,,,,, -17500,1.4061334,1.7794293,,,,,,,,,,,,,, -17600,1.5506356,1.7702334,,,,,,,,,,,,,, -17700,2.0012977,1.7731712,,,,,,,,,,,,,, -17800,1.9092442,1.7848864,,,,,,,,,,,,,, -17900,1.4718832,1.8464859,,,,,,,,,,,,,, -18000,1.4402694,1.8325709,,,,,,,,,,,,,, -18100,1.7645869,1.9220017,,,,,,,,,,,,,, -18127,,,0.6575454473495483,1.3667640686035156,0.6034199595451355,1.652487874031067,50000.0,0.4727000296115875,2.390066385269165,10000.0,6153.063757181168,6383.643333911896,6153.063757181168,229.6052343845368,0.3456621170043945,0.0 -18200,1.5590018,1.799426,,,,,,,,,,,,,, -18300,1.4696466,1.8902166,,,,,,,,,,,,,, -18400,1.7430738,1.9913853,,,,,,,,,,,,,, -18500,1.664995,1.8282772,,,,,,,,,,,,,, -18600,1.7481974,1.8163351,,,,,,,,,,,,,, -18700,1.8018695,1.7126971,,,,,,,,,,,,,, -18800,1.6773592,1.8033841,,,,,,,,,,,,,, -18900,1.5680525,1.8325434,,,,,,,,,,,,,, -19000,1.5418812,1.8336986,,,,,,,,,,,,,, -19100,1.5650817,1.8322479,,,,,,,,,,,,,, -19200,1.6164457,1.8460406,,,,,,,,,,,,,, -19300,1.6207947,1.858954,,,,,,,,,,,,,, -19400,1.858372,1.8923469,,,,,,,,,,,,,, -19500,1.5581669,1.7687801,,,,,,,,,,,,,, -19600,1.8135948,1.8141834,,,,,,,,,,,,,, -19639,,,0.6714365482330322,1.3168739080429075,0.6110599637031555,1.6108882427215576,50000.0,0.4854000210762024,2.3747594356536865,10000.0,6663.034357786179,6911.30079293251,6663.034357786179,247.2090749740601,0.3753409385681152,0.0 -19700,1.6879021,1.9533256,,,,,,,,,,,,,, -19800,1.6708437,1.7003925,,,,,,,,,,,,,, -19900,1.4615829,1.6349258,,,,,,,,,,,,,, -20000,1.693494,1.8036333,,,,,,,,,,,,,, -20100,2.103185,1.757796,,,,,,,,,,,,,, -20200,1.6169859,1.8734516,,,,,,,,,,,,,, -20300,1.6367545,1.6806017,,,,,,,,,,,,,, -20400,1.7020552,1.8736491,,,,,,,,,,,,,, -20500,1.8382634,1.7862177,,,,,,,,,,,,,, -20600,1.6030805,1.6800663,,,,,,,,,,,,,, -20700,1.53309,1.6670713,,,,,,,,,,,,,, -20800,1.7662575,1.8334584,,,,,,,,,,,,,, -20900,1.8755144,1.903579,,,,,,,,,,,,,, -21000,1.6189504,1.8088096,,,,,,,,,,,,,, -21100,1.6889956,1.7815889,,,,,,,,,,,,,, -21151,,,0.6639229655265808,1.3415753841400146,0.6109600067138672,1.6230710744857788,50000.0,0.492000013589859,2.333085536956787,10000.0,7173.039954662323,7438.92605638504,7173.039954662323,264.745756149292,0.4052441120147705,0.0 -21200,2.0125077,1.7835052,,,,,,,,,,,,,, -21300,2.0354562,1.9173809,,,,,,,,,,,,,, -21400,1.6633719,1.7004468,,,,,,,,,,,,,, -21500,1.5030895,1.7758889,,,,,,,,,,,,,, -21600,1.635489,1.7176867,,,,,,,,,,,,,, -21700,1.737065,1.7569051,,,,,,,,,,,,,, -21800,1.599628,1.8564837,,,,,,,,,,,,,, -21900,1.6657162,1.8711679,,,,,,,,,,,,,, -22000,1.6814147,1.76156,,,,,,,,,,,,,, -22100,1.4767891,1.774959,,,,,,,,,,,,,, -22200,1.9750636,1.8623906,,,,,,,,,,,,,, -22300,1.9433092,1.7084582,,,,,,,,,,,,,, -22400,1.7976029,1.7747043,,,,,,,,,,,,,, -22500,1.6392617,1.6808777,,,,,,,,,,,,,, -22600,2.0172777,1.7709135,,,,,,,,,,,,,, -22664,,,0.6608737111091614,1.3586628437042236,0.6041399836540222,1.640969157218933,50000.0,0.4735000133514404,2.4242477416992188,10000.0,7683.255749940872,7966.688236236572,7683.255749940872,282.2081139087677,0.4366991519927978,0.0 -22700,1.9722885,1.8103889,,,,,,,,,,,,,, -22800,1.5365053,1.7966806,,,,,,,,,,,,,, -22900,1.713146,1.7074791,,,,,,,,,,,,,, -23000,1.8065917,1.852678,,,,,,,,,,,,,, -23100,1.9225891,1.8607737,,,,,,,,,,,,,, -23200,1.9568334,1.7306279,,,,,,,,,,,,,, -23300,1.7922955,1.8606067,,,,,,,,,,,,,, -23400,1.5755,1.6763073,,,,,,,,,,,,,, -23500,1.7085327,1.7410454,,,,,,,,,,,,,, -23600,1.7838854,1.8240051,,,,,,,,,,,,,, -23700,1.7920845,1.8643538,,,,,,,,,,,,,, -23800,2.0529423,1.7035624,,,,,,,,,,,,,, -23900,1.5165539,1.7776937,,,,,,,,,,,,,, -24000,1.6111386,1.8397944,,,,,,,,,,,,,, -24100,1.8123764,1.8353595,,,,,,,,,,,,,, -24177,,,0.6926219463348389,1.2101314067840576,0.6140999794006348,1.6118823289871216,50000.0,0.4853000342845917,2.374408483505249,10000.0,8193.423243284225,8494.462911128998,8193.423243284225,299.73340010643005,0.4670734405517578,0.0 -24200,2.0666187,1.8245122,,,,,,,,,,,,,, -24300,1.6487038,1.6820785,,,,,,,,,,,,,, -24400,1.9203795,1.811274,,,,,,,,,,,,,, -24500,2.240077,1.8402864,,,,,,,,,,,,,, -24600,1.9322623,1.7477022,,,,,,,,,,,,,, -24700,1.991659,1.6242386,,,,,,,,,,,,,, -24800,1.702415,1.8190465,,,,,,,,,,,,,, -24900,1.7892954,1.8281198,,,,,,,,,,,,,, -25000,1.6810025,1.8319879,,,,,,,,,,,,,, -25100,1.8118604,1.7361151,,,,,,,,,,,,,, -25200,1.8946855,1.712991,,,,,,,,,,,,,, -25300,1.6517427,1.7359306,,,,,,,,,,,,,, -25400,1.6751683,1.7566174,,,,,,,,,,,,,, -25500,1.8159363,1.7761033,,,,,,,,,,,,,, -25600,1.893867,1.7302394,,,,,,,,,,,,,, -25690,,,0.6783721446990967,1.2624869346618652,0.6137799620628357,1.6312508583068848,50000.0,0.4904000163078308,2.407479763031006,10000.0,8703.385436296463,9022.43826174736,8703.385436296463,317.66002774238586,0.5013530254364014,0.0 -25700,1.80658,1.6895535,,,,,,,,,,,,,, -25800,1.6606104,1.7611265,,,,,,,,,,,,,, -25900,1.6681194,1.7444813,,,,,,,,,,,,,, -26000,1.8155802,1.851486,,,,,,,,,,,,,, -26100,1.7453831,1.745832,,,,,,,,,,,,,, -26200,2.1130834,1.709817,,,,,,,,,,,,,, -26300,1.7248522,1.8632696,,,,,,,,,,,,,, -26400,1.7125199,1.7263458,,,,,,,,,,,,,, -26500,2.20043,1.7771573,,,,,,,,,,,,,, -26600,1.8209286,1.6839027,,,,,,,,,,,,,, -26700,1.727047,1.6058947,,,,,,,,,,,,,, -26800,1.6653955,1.7490405,,,,,,,,,,,,,, -26900,1.8219677,1.7234099,,,,,,,,,,,,,, -27000,1.5471166,1.8199667,,,,,,,,,,,,,, -27100,1.6975688,1.8152094,,,,,,,,,,,,,, -27200,1.9084395,1.7026807,,,,,,,,,,,,,, -27203,,,0.6784717440605164,1.2687817811965942,0.6159200072288513,1.5993962287902832,50000.0,0.4914000332355499,2.33853816986084,10000.0,9213.344497919084,9550.966819286346,9213.344497919084,336.1436400413513,0.5356407165527344,0.0 -27300,2.1285906,1.8141477,,,,,,,,,,,,,, -27400,2.007283,1.6812288,,,,,,,,,,,,,, -27500,1.6142173,1.592174,,,,,,,,,,,,,, -27600,1.9408045,1.6974387,,,,,,,,,,,,,, -27700,1.7397947,1.6662008,,,,,,,,,,,,,, -27800,1.8013816,1.7749757,,,,,,,,,,,,,, -27900,1.9207559,1.7722939,,,,,,,,,,,,,, -28000,1.9608009,1.6861289,,,,,,,,,,,,,, -28100,1.915234,1.6957483,,,,,,,,,,,,,, -28200,1.5895617,1.7194208,,,,,,,,,,,,,, -28300,1.8061925,1.847044,,,,,,,,,,,,,, -28400,1.8829377,1.6657696,,,,,,,,,,,,,, -28500,1.6838363,1.7262006,,,,,,,,,,,,,, -28600,2.0910673,1.7835722,,,,,,,,,,,,,, -28700,1.6917158,1.7423068,,,,,,,,,,,,,, -28716,,,0.6600964665412903,1.350205421447754,0.6032999753952026,1.6606773138046265,50000.0,0.4785000085830688,2.4138495922088623,10000.0,9723.4910697937,10078.793704748154,9723.4910697937,353.745454788208,0.562096118927002,0.0 -28800,1.8803438,1.6678684,,,,,,,,,,,,,, -28900,1.950854,1.6056774,,,,,,,,,,,,,, -29000,1.9227993,1.7623389,,,,,,,,,,,,,, -29100,1.8163923,1.7640983,,,,,,,,,,,,,, -29200,1.7019974,1.7514856,,,,,,,,,,,,,, -29300,1.9471794,1.831437,,,,,,,,,,,,,, -29400,1.7117727,1.6951085,,,,,,,,,,,,,, -29500,1.7114972,1.768162,,,,,,,,,,,,,, -29600,1.5945356,1.7175773,,,,,,,,,,,,,, -29700,1.6615298,1.6062611,,,,,,,,,,,,,, -29800,1.8279408,1.810255,,,,,,,,,,,,,, -29900,1.7877475,1.6939783,,,,,,,,,,,,,, -30000,1.850959,1.717098,,,,,,,,,,,,,, -30100,1.8182942,1.747205,,,,,,,,,,,,,, -30200,1.5580245,1.803586,,,,,,,,,,,,,, -30229,,,0.6486168503761292,1.4096747636795044,0.5981199741363525,1.6882946491241455,50000.0,0.465800017118454,2.448777437210083,10000.0,10233.69608449936,10606.674701690674,10233.69608449936,371.32841968536377,0.6031649112701416,0.0 -30300,1.981609,1.7419771,,,,,,,,,,,,,, -30400,2.020477,1.8951069,,,,,,,,,,,,,, -30500,1.5382932,1.7359576,,,,,,,,,,,,,, -30600,1.7868392,1.6587911,,,,,,,,,,,,,, -30700,2.2193398,1.6775439,,,,,,,,,,,,,, -30800,1.7192419,1.8114324,,,,,,,,,,,,,, -30900,1.6656369,1.5859271,,,,,,,,,,,,,, -31000,1.6452343,1.6809489,,,,,,,,,,,,,, -31100,1.7956785,1.7200344,,,,,,,,,,,,,, -31200,1.684076,1.6297034,,,,,,,,,,,,,, -31300,1.7650244,1.7174639,,,,,,,,,,,,,, -31400,1.7185812,1.7554643,,,,,,,,,,,,,, -31500,1.9424765,1.7085605,,,,,,,,,,,,,, -31600,2.056947,1.6875244,,,,,,,,,,,,,, -31700,1.6489899,1.7027507,,,,,,,,,,,,,, -31742,,,0.6831752061843872,1.2528233528137207,0.6192399859428406,1.6038295030593872,50000.0,0.4930000305175781,2.3522725105285645,10000.0,10743.68086385727,11134.072080612184,10743.68086385727,388.6515634059906,0.6397063732147217,0.0 -31800,1.7756933,1.6388386,,,,,,,,,,,,,, -31900,1.9976078,1.7317082,,,,,,,,,,,,,, -32000,1.8200017,1.678474,,,,,,,,,,,,,, -32100,1.6912327,1.59348,,,,,,,,,,,,,, -32200,1.711298,1.7469358,,,,,,,,,,,,,, -32300,1.9974034,1.6578871,,,,,,,,,,,,,, -32400,1.793199,1.6967167,,,,,,,,,,,,,, -32500,1.7345456,1.6142858,,,,,,,,,,,,,, -32600,1.6582507,1.7215734,,,,,,,,,,,,,, -32700,1.7524866,1.7649428,,,,,,,,,,,,,, -32800,1.6962298,1.6865079,,,,,,,,,,,,,, -32900,1.9154484,1.6688399,,,,,,,,,,,,,, -33000,1.722983,1.7931716,,,,,,,,,,,,,, -33100,1.780714,1.5824146,,,,,,,,,,,,,, -33200,1.8894874,1.7293595,,,,,,,,,,,,,, -33255,,,0.7012914419174194,1.1544904708862305,0.6235600113868713,1.5596383810043335,50000.0,0.5005000233650208,2.294847011566162,10000.0,11253.931037902832,11661.8629693985,11253.931037902832,406.1087462902069,0.6711766719818115,0.0 -33300,1.5829427,1.7159387,,,,,,,,,,,,,, -33400,1.9813998,1.6940945,,,,,,,,,,,,,, -33500,1.9721727,1.6423479,,,,,,,,,,,,,, -33600,1.9980699,1.7149491,,,,,,,,,,,,,, -33700,2.3682306,1.7095728,,,,,,,,,,,,,, -33800,2.184164,1.774567,,,,,,,,,,,,,, -33900,1.7026016,1.6905178,,,,,,,,,,,,,, -34000,1.9835125,1.795645,,,,,,,,,,,,,, -34100,1.9466674,1.633489,,,,,,,,,,,,,, -34200,1.7426947,1.6473472,,,,,,,,,,,,,, -34300,1.8379754,1.6678575,,,,,,,,,,,,,, -34400,1.7813627,1.5745864,,,,,,,,,,,,,, -34500,2.33811,1.7938601,,,,,,,,,,,,,, -34600,1.6911749,1.702092,,,,,,,,,,,,,, -34700,1.7331706,1.7099133,,,,,,,,,,,,,, -34769,,,0.7081871628761292,1.1398144960403442,0.6377599835395813,1.4949623346328735,50000.0,0.5094000101089478,2.247567176818848,10000.0,11763.943603754044,12189.540555000303,11763.943603754044,423.68754959106445,0.7041482925415039,0.0 -34800,1.6709614,1.6761701,,,,,,,,,,,,,, -34900,1.9977589,1.6897318,,,,,,,,,,,,,, -35000,1.8156927,1.7685635,,,,,,,,,,,,,, -35100,1.8176653,1.6460495,,,,,,,,,,,,,, -35200,2.0641756,1.7424155,,,,,,,,,,,,,, -35300,1.7142074,1.5077204,,,,,,,,,,,,,, -35400,1.6866655,1.5287894,,,,,,,,,,,,,, -35500,1.7623241,1.7431394,,,,,,,,,,,,,, -35600,1.808088,1.7748562,,,,,,,,,,,,,, -35700,1.6748058,1.7598349,,,,,,,,,,,,,, -35800,1.7526428,1.6637709,,,,,,,,,,,,,, -35900,1.8191088,1.6648128,,,,,,,,,,,,,, -36000,1.6292272,1.6698351,,,,,,,,,,,,,, -36100,1.9347097,1.7244519,,,,,,,,,,,,,, -36200,1.9181819,1.6703494,,,,,,,,,,,,,, -36281,,,0.6898915767669678,1.230826497077942,0.6264599561691284,1.5518196821212769,50000.0,0.4908000230789184,2.300273895263672,10000.0,12273.98748230934,12717.026733636856,12273.98748230934,441.0496399402618,0.7320287227630615,0.0 -36300,1.9045041,1.7409987,,,,,,,,,,,,,, -36400,1.9322592,1.695755,,,,,,,,,,,,,, -36500,1.7337283,1.5996894,,,,,,,,,,,,,, -36600,1.6322818,1.6671169,,,,,,,,,,,,,, -36700,1.7974031,1.6428924,,,,,,,,,,,,,, -36800,1.6143228,1.6237726,,,,,,,,,,,,,, -36900,1.8395412,1.8233202,,,,,,,,,,,,,, -37000,1.9570459,1.6566772,,,,,,,,,,,,,, -37100,1.8621458,1.6542077,,,,,,,,,,,,,, -37200,1.8653126,1.7495359,,,,,,,,,,,,,, -37300,1.8692592,1.604075,,,,,,,,,,,,,, -37400,1.9348372,1.6265707,,,,,,,,,,,,,, -37500,2.1100492,1.6673625,,,,,,,,,,,,,, -37600,1.8799893,1.8321668,,,,,,,,,,,,,, -37700,1.6815586,1.64485,,,,,,,,,,,,,, -37795,,,0.6783322691917419,1.2680907249450684,0.6167399883270264,1.5898174047470093,50000.0,0.4897000193595886,2.340038299560547,10000.0,12784.045414686205,13244.865463733671,12784.045414686205,458.7441716194153,0.7648324966430664,0.0 -37800,1.726932,1.6466242,,,,,,,,,,,,,, -37900,2.0403137,1.719548,,,,,,,,,,,,,, -38000,1.771754,1.684169,,,,,,,,,,,,,, -38100,1.8063134,1.6483293,,,,,,,,,,,,,, -38200,1.8871691,1.6707194,,,,,,,,,,,,,, -38300,1.8179086,1.6577342,,,,,,,,,,,,,, -38400,1.8214645,1.6979631,,,,,,,,,,,,,, -38500,1.8492349,1.6570196,,,,,,,,,,,,,, -38600,1.8531963,1.7192945,,,,,,,,,,,,,, -38700,1.918615,1.8154669,,,,,,,,,,,,,, -38800,2.0061922,1.663893,,,,,,,,,,,,,, -38900,2.0353086,1.6968284,,,,,,,,,,,,,, -39000,1.7574238,1.6138257,,,,,,,,,,,,,, -39100,2.1276112,1.6888719,,,,,,,,,,,,,, -39200,1.7719153,1.7065718,,,,,,,,,,,,,, -39300,1.6760817,1.5804706,,,,,,,,,,,,,, -39308,,,0.6907684803009033,1.2177668809890747,0.6317999958992004,1.5191069841384888,50000.0,0.499500036239624,2.292027950286865,10000.0,13294.002183437347,13772.57915186882,13294.002183437347,476.41528153419495,0.797590970993042,0.0 -39400,1.7650856,1.5199351,,,,,,,,,,,,,, -39500,1.9275637,1.7283137,,,,,,,,,,,,,, -39600,2.0409818,1.6311479,,,,,,,,,,,,,, -39700,2.0428414,1.7059526,,,,,,,,,,,,,, -39800,1.699198,1.7656035,,,,,,,,,,,,,, -39900,2.1441534,1.6218609,,,,,,,,,,,,,, -40000,1.8436968,1.6815684,,,,,,,,,,,,,, -40100,1.8472297,1.730342,,,,,,,,,,,,,, -40200,1.7615032,1.6771464,,,,,,,,,,,,,, -40300,1.8897749,1.762935,,,,,,,,,,,,,, -40400,1.9125379,1.610006,,,,,,,,,,,,,, -40500,1.7928809,1.7174437,,,,,,,,,,,,,, -40600,2.0915182,1.6396778,,,,,,,,,,,,,, -40700,1.564624,1.6510112,,,,,,,,,,,,,, -40800,1.8656586,1.6700593,,,,,,,,,,,,,, -40822,,,0.7167769074440002,1.1003592014312744,0.6209999918937683,1.567124843597412,50000.0,0.4988000094890594,2.3107590675354004,10000.0,13803.981996297836,14300.002563476562,13803.981996297836,493.7730107307434,0.8306655883789062,0.0 -40900,1.821055,1.7618834,,,,,,,,,,,,,, -41000,1.8080258,1.5593319,,,,,,,,,,,,,, -41100,1.7837297,1.5389425,,,,,,,,,,,,,, -41200,1.7864498,1.7311169,,,,,,,,,,,,,, -41300,1.8275118,1.712439,,,,,,,,,,,,,, -41400,2.0320559,1.623207,,,,,,,,,,,,,, -41500,1.8611118,1.6030947,,,,,,,,,,,,,, -41600,2.1992784,1.7063351,,,,,,,,,,,,,, -41700,2.0475743,1.7468235,,,,,,,,,,,,,, -41800,2.0887709,1.5929569,,,,,,,,,,,,,, -41900,1.7287916,1.5568507,,,,,,,,,,,,,, -42000,1.8584827,1.7335172,,,,,,,,,,,,,, -42100,1.8064501,1.7175195,,,,,,,,,,,,,, -42200,1.9432638,1.7033529,,,,,,,,,,,,,, -42300,1.822776,1.5013899,,,,,,,,,,,,,, -42335,,,0.7200653553009033,1.0819741487503052,0.6363199949264526,1.509085774421692,50000.0,0.5099000334739685,2.241546869277954,10000.0,14313.949091911316,14827.75766301155,14313.949091911316,511.47321367263794,0.8653779029846191,0.0 -42400,1.7521877,1.61637,,,,,,,,,,,,,, -42500,1.6744213,1.6460397,,,,,,,,,,,,,, -42600,1.6870825,1.593488,,,,,,,,,,,,,, -42700,1.8319368,1.5426877,,,,,,,,,,,,,, -42800,1.7335335,1.6920006,,,,,,,,,,,,,, -42900,1.9702697,1.6321751,,,,,,,,,,,,,, -43000,1.8487074,1.6921088,,,,,,,,,,,,,, -43100,1.8197026,1.6256889,,,,,,,,,,,,,, -43200,1.743069,1.5986041,,,,,,,,,,,,,, -43300,1.8895556,1.6365671,,,,,,,,,,,,,, -43400,1.7701228,1.6339812,,,,,,,,,,,,,, -43500,1.9159434,1.5933043,,,,,,,,,,,,,, -43600,1.8792132,1.5845554,,,,,,,,,,,,,, -43700,1.7285422,1.5065769,,,,,,,,,,,,,, -43800,1.6731734,1.6155592,,,,,,,,,,,,,, -43848,,,0.6968072056770325,1.1697343587875366,0.6294599771499634,1.5310777425765991,50000.0,0.5049000382423401,2.2731704711914062,10000.0,14824.149604320526,15355.536323785782,14824.149604320526,528.9610199928284,0.9021854400634766,0.0 -43900,1.7858334,1.60041,,,,,,,,,,,,,, -44000,2.2092786,1.7156438,,,,,,,,,,,,,, -44100,1.9468296,1.614955,,,,,,,,,,,,,, -44200,2.258492,1.6611431,,,,,,,,,,,,,, -44300,1.885175,1.6013803,,,,,,,,,,,,,, -44400,1.7779323,1.6293525,,,,,,,,,,,,,, -44500,1.8965197,1.7115899,,,,,,,,,,,,,, -44600,2.0177662,1.7122822,,,,,,,,,,,,,, -44700,1.8992139,1.7170664,,,,,,,,,,,,,, -44800,1.7080128,1.5393156,,,,,,,,,,,,,, -44900,1.9169221,1.602514,,,,,,,,,,,,,, -45000,1.9171838,1.7556539,,,,,,,,,,,,,, -45100,1.9681659,1.5587229,,,,,,,,,,,,,, -45200,1.8274002,1.5922691,,,,,,,,,,,,,, -45300,1.8213371,1.6753008,,,,,,,,,,,,,, -45362,,,0.7052175998687744,1.146565079689026,0.6387799978256226,1.4946125745773315,50000.0,0.5038000345230103,2.2749900817871094,10000.0,15334.175417423248,15883.150108098984,15334.175417423248,546.4602868556976,0.9373815059661864,0.0 -45400,1.9377246,1.6632007,,,,,,,,,,,,,, -45500,1.818531,1.6614596,,,,,,,,,,,,,, -45600,2.3335354,1.681934,,,,,,,,,,,,,, -45700,1.8816326,1.7115014,,,,,,,,,,,,,, -45800,1.9770312,1.6423945,,,,,,,,,,,,,, -45900,1.8405834,1.5851688,,,,,,,,,,,,,, -46000,1.8904389,1.6542346,,,,,,,,,,,,,, -46100,1.9229873,1.7015499,,,,,,,,,,,,,, -46200,2.0080826,1.648673,,,,,,,,,,,,,, -46300,1.9870726,1.4989125,,,,,,,,,,,,,, -46400,1.813156,1.7286496,,,,,,,,,,,,,, -46500,1.746648,1.5692888,,,,,,,,,,,,,, -46600,2.0227816,1.6657239,,,,,,,,,,,,,, -46700,1.8989915,1.5760043,,,,,,,,,,,,,, -46800,1.9493197,1.7065303,,,,,,,,,,,,,, -46877,,,0.700613796710968,1.157369613647461,0.6425999999046326,1.466826319694519,50000.0,0.515500009059906,2.1804099082946777,10000.0,15844.352715015411,16411.172538280487,15844.352715015411,564.2131533622742,0.9761536121368408,0.0 -46900,1.9878696,1.6617429,,,,,,,,,,,,,, -47000,1.7584691,1.5527682,,,,,,,,,,,,,, -47100,1.8709965,1.5514573,,,,,,,,,,,,,, -47200,1.684134,1.567007,,,,,,,,,,,,,, -47300,1.8746248,1.6293702,,,,,,,,,,,,,, -47400,1.8208117,1.5312161,,,,,,,,,,,,,, -47500,1.9656538,1.5790339,,,,,,,,,,,,,, -47600,1.9948659,1.7512308,,,,,,,,,,,,,, -47700,1.8732659,1.5595016,,,,,,,,,,,,,, -47800,1.8329058,1.6258484,,,,,,,,,,,,,, -47900,1.9697142,1.5594449,,,,,,,,,,,,,, -48000,1.7172524,1.5801163,,,,,,,,,,,,,, -48100,1.7696822,1.5590646,,,,,,,,,,,,,, -48200,1.7425508,1.7049078,,,,,,,,,,,,,, -48300,1.8215415,1.6761549,,,,,,,,,,,,,, -48390,,,0.7057557106018066,1.150970220565796,0.6429199576377869,1.46943998336792,50000.0,0.5108000040054321,2.2163541316986084,10000.0,16354.314910888672,16938.669855833054,16354.314910888672,581.6602036952972,1.0123159885406494,0.0 -48400,2.0912294,1.6462057,,,,,,,,,,,,,, -48500,2.3014774,1.6602666,,,,,,,,,,,,,, -48600,2.0640888,1.6966643,,,,,,,,,,,,,, -48700,1.9747994,1.5606129,,,,,,,,,,,,,, -48800,2.0223246,1.7577698,,,,,,,,,,,,,, -48900,1.7423091,1.5210682,,,,,,,,,,,,,, -49000,1.7215716,1.5965697,,,,,,,,,,,,,, -49100,1.9202958,1.6170985,,,,,,,,,,,,,, -49200,1.884659,1.5968974,,,,,,,,,,,,,, -49300,1.8654726,1.5248083,,,,,,,,,,,,,, -49400,1.9628906,1.6790717,,,,,,,,,,,,,, -49500,1.709228,1.5354528,,,,,,,,,,,,,, -49600,1.7408198,1.5584937,,,,,,,,,,,,,, -49700,1.975373,1.5810964,,,,,,,,,,,,,, -49800,1.928409,1.4738574,,,,,,,,,,,,,, -49900,1.614438,1.5384825,,,,,,,,,,,,,, -49904,,,0.758230984210968,0.9265372157096864,0.6476399898529053,1.4463788270950315,50000.0,0.513700008392334,2.200896739959717,10000.0,16864.458737134933,17466.411892175674,16864.458737134933,599.1710863113403,1.04778790473938,0.0 -50000,1.9791886,1.6000028,,,,,,,,,,,,,, -50100,1.8252877,1.5806816,,,,,,,,,,,,,, -50200,1.9778864,1.6292228,,,,,,,,,,,,,, -50300,1.9548806,1.5437541,,,,,,,,,,,,,, -50400,1.9951953,1.6914895,,,,,,,,,,,,,, -50500,1.6892068,1.5235087,,,,,,,,,,,,,, -50600,2.4966252,1.6729962,,,,,,,,,,,,,, -50700,1.9269633,1.5563302,,,,,,,,,,,,,, -50800,1.863496,1.6708636,,,,,,,,,,,,,, -50900,1.9613614,1.6925385,,,,,,,,,,,,,, -51000,1.9426692,1.6951033,,,,,,,,,,,,,, -51100,2.0216875,1.6375012,,,,,,,,,,,,,, -51200,1.8665429,1.5484854,,,,,,,,,,,,,, -51300,1.9037372,1.6562812,,,,,,,,,,,,,, -51400,2.2921746,1.5872384,,,,,,,,,,,,,, -51418,,,0.7257453799247742,1.053958535194397,0.6452999711036682,1.461386799812317,50000.0,0.515500009059906,2.1797661781311035,10000.0,17374.71946334839,17994.22289633751,17374.71946334839,616.627453327179,1.0893511772155762,0.0 -51500,2.0195925,1.6755397,,,,,,,,,,,,,, -51600,2.015084,1.5415511,,,,,,,,,,,,,, -51700,1.8903004,1.6834267,,,,,,,,,,,,,, -51800,1.7880807,1.6276133,,,,,,,,,,,,,, -51900,2.1119413,1.5882643,,,,,,,,,,,,,, -52000,1.9334301,1.6051865,,,,,,,,,,,,,, -52100,2.0114791,1.6092062,,,,,,,,,,,,,, -52200,2.002649,1.5501013,,,,,,,,,,,,,, -52300,1.9102746,1.7035495,,,,,,,,,,,,,, -52400,1.7358316,1.6125252,,,,,,,,,,,,,, -52500,1.8965479,1.6394722,,,,,,,,,,,,,, -52600,1.9538476,1.6963496,,,,,,,,,,,,,, -52700,1.9185776,1.6445941,,,,,,,,,,,,,, -52800,2.121451,1.6032522,,,,,,,,,,,,,, -52900,2.008778,1.5789262,,,,,,,,,,,,,, -52932,,,0.7232341766357422,1.0687999725341797,0.646399974822998,1.4467506408691406,50000.0,0.5178000330924988,2.1698689460754395,10000.0,17884.717614650726,18521.756559848785,17884.717614650726,634.0705862045288,1.129570722579956,0.0 -53000,1.9523344,1.6317208,,,,,,,,,,,,,, -53100,1.9808311,1.6060526,,,,,,,,,,,,,, -53200,1.8527536,1.6068008,,,,,,,,,,,,,, -53300,2.0283744,1.5485857,,,,,,,,,,,,,, -53400,1.6235214,1.5411043,,,,,,,,,,,,,, -53500,2.0558844,1.4824164,,,,,,,,,,,,,, -53600,1.8927653,1.4032121,,,,,,,,,,,,,, -53700,1.7936941,1.5164909,,,,,,,,,,,,,, -53800,1.807682,1.6167687,,,,,,,,,,,,,, -53900,1.7681285,1.5626043,,,,,,,,,,,,,, -54000,1.9952267,1.6046696,,,,,,,,,,,,,, -54100,1.8529426,1.6098566,,,,,,,,,,,,,, -54200,2.1082823,1.5523164,,,,,,,,,,,,,, -54300,1.7800007,1.5627078,,,,,,,,,,,,,, -54400,2.3002114,1.6168222,,,,,,,,,,,,,, -54446,,,0.7245296239852905,1.073808670043945,0.6561799645423889,1.4147099256515503,50000.0,0.5199000239372253,2.165266513824463,10000.0,18394.86057209969,19050.109229803085,18394.86057209969,652.1868450641632,1.1715331077575684,0.0 -54500,2.1757364,1.5437207,,,,,,,,,,,,,, -54600,2.100107,1.6089482,,,,,,,,,,,,,, -54700,1.8093134,1.5875176,,,,,,,,,,,,,, -54800,1.8265486,1.5908344,,,,,,,,,,,,,, -54900,2.0715551,1.5094664,,,,,,,,,,,,,, -55000,2.1050463,1.6402031,,,,,,,,,,,,,, -55100,2.0766416,1.6973102,,,,,,,,,,,,,, -55200,2.0798168,1.5934111,,,,,,,,,,,,,, -55300,1.8736645,1.5722978,,,,,,,,,,,,,, -55400,1.8659818,1.5361941,,,,,,,,,,,,,, -55500,1.9742781,1.6320746,,,,,,,,,,,,,, -55600,1.941592,1.4999809,,,,,,,,,,,,,, -55700,1.9856356,1.6631899,,,,,,,,,,,,,, -55800,2.213814,1.7493906,,,,,,,,,,,,,, -55900,1.9340148,1.48007,,,,,,,,,,,,,, -55961,,,0.7122927308082581,1.1093248128890991,0.6516199707984924,1.4258030652999878,50000.0,0.5128000378608704,2.175576448440552,10000.0,18904.974903345108,19578.019869804382,18904.974903345108,669.8906226158142,1.211322784423828,0.0 -56000,1.8354479,1.5191965,,,,,,,,,,,,,, -56100,1.7939901,1.4373468,,,,,,,,,,,,,, -56200,2.109356,1.6620314,,,,,,,,,,,,,, -56300,1.7577661,1.6243674,,,,,,,,,,,,,, -56400,2.0212545,1.530328,,,,,,,,,,,,,, -56500,2.0193229,1.507939,,,,,,,,,,,,,, -56600,1.9052308,1.6115813,,,,,,,,,,,,,, -56700,1.9526517,1.5340189,,,,,,,,,,,,,, -56800,1.9870526,1.6024762,,,,,,,,,,,,,, -56900,1.7496619,1.6161875,,,,,,,,,,,,,, -57000,1.9708107,1.5363424,,,,,,,,,,,,,, -57100,2.0551121,1.5239737,,,,,,,,,,,,,, -57200,1.8936114,1.661614,,,,,,,,,,,,,, -57300,2.0024364,1.6315659,,,,,,,,,,,,,, -57400,1.9560738,1.5113018,,,,,,,,,,,,,, -57475,,,0.70609450340271,1.13842511177063,0.6474599838256836,1.457894802093506,50000.0,0.5162000060081482,2.202220916748047,10000.0,19415.11595249176,20105.84994530677,19415.11595249176,687.4907431602478,1.2484371662139893,0.0 -57500,2.1091588,1.5877042,,,,,,,,,,,,,, -57600,2.1780994,1.5843915,,,,,,,,,,,,,, -57700,2.2011466,1.6235951,,,,,,,,,,,,,, -57800,1.8597991,1.6675421,,,,,,,,,,,,,, -57900,1.8342079,1.4652758,,,,,,,,,,,,,, -58000,1.9484699,1.6481535,,,,,,,,,,,,,, -58100,2.1180694,1.6651882,,,,,,,,,,,,,, -58200,1.9102465,1.5136769,,,,,,,,,,,,,, -58300,1.9494644,1.5013595,,,,,,,,,,,,,, -58400,2.1302168,1.5634898,,,,,,,,,,,,,, -58500,1.8955705,1.5996705,,,,,,,,,,,,,, -58600,1.7593927,1.6044807,,,,,,,,,,,,,, -58700,1.9557812,1.5312027,,,,,,,,,,,,,, -58800,1.8262569,1.6204008,,,,,,,,,,,,,, -58900,1.8244469,1.5293902,,,,,,,,,,,,,, -58989,,,0.7527901530265808,0.939852774143219,0.6464999914169312,1.4735795259475708,50000.0,0.5228000283241272,2.17522406578064,10000.0,19925.138954639435,20633.438943624496,19925.138954639435,704.9665122032166,1.285801649093628,0.0 -59000,1.9459635,1.5341173,,,,,,,,,,,,,, -59100,2.0523322,1.6482452,,,,,,,,,,,,,, -59200,1.833652,1.7145443,,,,,,,,,,,,,, -59300,2.1631055,1.6243329,,,,,,,,,,,,,, -59400,1.8590391,1.6728028,,,,,,,,,,,,,, -59500,1.839298,1.5196445,,,,,,,,,,,,,, -59600,1.9909335,1.5176543,,,,,,,,,,,,,, -59700,1.9869363,1.5745021,,,,,,,,,,,,,, -59800,1.96896,1.6342641,,,,,,,,,,,,,, -59900,2.0776746,1.5274825,,,,,,,,,,,,,, -60000,2.100833,1.5914466,,,,,,,,,,,,,, -60100,1.9201103,1.618767,,,,,,,,,,,,,, -60200,2.066461,1.4886906,,,,,,,,,,,,,, -60300,2.0718157,1.6661979,,,,,,,,,,,,,, -60400,2.1840396,1.5107094,,,,,,,,,,,,,, -60500,1.8505807,1.5030406,,,,,,,,,,,,,, -60503,,,0.7373046875,1.001809000968933,0.6541599631309509,1.409217119216919,50000.0,0.5309000015258789,2.1159284114837646,10000.0,20435.203814983368,21161.29817390442,20435.203814983368,722.6707804203033,1.3243467807769775,0.0 -60600,1.9769105,1.449789,,,,,,,,,,,,,, -60700,1.8259884,1.3788431,,,,,,,,,,,,,, -60800,2.0625703,1.6086588,,,,,,,,,,,,,, -60900,1.99664,1.4713273,,,,,,,,,,,,,, -61000,2.2683887,1.6275022,,,,,,,,,,,,,, -61100,1.995423,1.5745642,,,,,,,,,,,,,, -61200,1.7551329,1.465938,,,,,,,,,,,,,, -61300,1.8716129,1.4667771,,,,,,,,,,,,,, -61400,1.864471,1.5520477,,,,,,,,,,,,,, -61500,1.9608656,1.4861189,,,,,,,,,,,,,, -61600,2.2351398,1.5361199,,,,,,,,,,,,,, -61700,1.9020525,1.5143179,,,,,,,,,,,,,, -61800,1.880823,1.5229142,,,,,,,,,,,,,, -61900,1.9159698,1.573693,,,,,,,,,,,,,, -62000,2.029663,1.6317269,,,,,,,,,,,,,, -62017,,,0.7272400856018066,1.0499770641326904,0.6545799970626831,1.4146177768707275,50000.0,0.5299000144004822,2.1289846897125244,10000.0,20945.369954109192,21688.86900305748,20945.369954109192,739.9851453304291,1.360926389694214,0.0 -62100,1.887423,1.4890116,,,,,,,,,,,,,, -62200,2.0684721,1.5957615,,,,,,,,,,,,,, -62300,2.0808492,1.4180708,,,,,,,,,,,,,, -62400,2.3003614,1.5310235,,,,,,,,,,,,,, -62500,2.1455407,1.5544016,,,,,,,,,,,,,, -62600,2.1692603,1.5269115,,,,,,,,,,,,,, -62700,1.9393594,1.5958357,,,,,,,,,,,,,, -62800,1.95681,1.5015936,,,,,,,,,,,,,, -62900,1.8381976,1.6037958,,,,,,,,,,,,,, -63000,1.9682068,1.4551134,,,,,,,,,,,,,, -63100,1.9755994,1.6076021,,,,,,,,,,,,,, -63200,1.8338668,1.4759443,,,,,,,,,,,,,, -63300,2.3832407,1.582102,,,,,,,,,,,,,, -63400,2.1118467,1.538026,,,,,,,,,,,,,, -63500,2.0408492,1.4308904,,,,,,,,,,,,,, -63531,,,0.7257254123687744,1.053279995918274,0.6569799780845642,1.4097890853881836,50000.0,0.5286000370979309,2.131467580795288,10000.0,21455.5750977993,22216.50008130073,21455.5750977993,757.318776845932,1.4008488655090332,0.0 -63600,1.994011,1.5744272,,,,,,,,,,,,,, -63700,2.2254019,1.5592127,,,,,,,,,,,,,, -63800,2.1637433,1.5970604,,,,,,,,,,,,,, -63900,1.9899075,1.469438,,,,,,,,,,,,,, -64000,1.9278954,1.4907436,,,,,,,,,,,,,, -64100,1.9424775,1.5019685,,,,,,,,,,,,,, -64200,2.0167096,1.4891343,,,,,,,,,,,,,, -64300,2.0096996,1.5951489,,,,,,,,,,,,,, -64400,2.0255399,1.5360489,,,,,,,,,,,,,, -64500,1.956276,1.6413314,,,,,,,,,,,,,, -64600,2.1839452,1.5110576,,,,,,,,,,,,,, -64700,1.8820876,1.5020823,,,,,,,,,,,,,, -64800,2.0555668,1.6102517,,,,,,,,,,,,,, -64900,2.1501331,1.5184562,,,,,,,,,,,,,, -65000,1.9893018,1.4480072,,,,,,,,,,,,,, -65046,,,0.7092434763908386,1.1382043361663818,0.6457200050354004,1.4658386707305908,50000.0,0.5245000123977661,2.156514883041382,10000.0,21965.726397037502,22744.58133101464,21965.726397037502,775.1528396606445,1.4429562091827393,0.0 -65100,1.9946613,1.4845611,,,,,,,,,,,,,, -65200,2.118427,1.4743586,,,,,,,,,,,,,, -65300,1.984002,1.4565037,,,,,,,,,,,,,, -65400,2.1953619,1.4685292,,,,,,,,,,,,,, -65500,1.979967,1.5807633,,,,,,,,,,,,,, -65600,1.8856426,1.4164115,,,,,,,,,,,,,, -65700,1.9396899,1.5297863,,,,,,,,,,,,,, -65800,2.0785952,1.521381,,,,,,,,,,,,,, -65900,2.077005,1.4948251,,,,,,,,,,,,,, -66000,1.9163742,1.6242268,,,,,,,,,,,,,, -66100,2.0360615,1.5585641,,,,,,,,,,,,,, -66200,2.198406,1.5057969,,,,,,,,,,,,,, -66300,1.9289808,1.4386358,,,,,,,,,,,,,, -66400,2.1224155,1.5139911,,,,,,,,,,,,,, -66500,2.10056,1.5298262,,,,,,,,,,,,,, -66561,,,0.7173748016357422,1.0853809118270874,0.6520000100135803,1.4343310594558716,50000.0,0.5267000198364258,2.145815134048462,10000.0,22475.8989508152,23272.48766350746,22475.8989508152,792.794264793396,1.482445240020752,0.0 -66600,2.2581134,1.5415851,,,,,,,,,,,,,, -66700,1.9234529,1.5444707,,,,,,,,,,,,,, -66800,1.8620586,1.3908677,,,,,,,,,,,,,, -66900,1.8930688,1.513787,,,,,,,,,,,,,, -67000,2.015848,1.5647324,,,,,,,,,,,,,, -67100,2.1858835,1.5290234,,,,,,,,,,,,,, -67200,1.9258004,1.471673,,,,,,,,,,,,,, -67300,2.0108242,1.5401343,,,,,,,,,,,,,, -67400,1.9504281,1.5201263,,,,,,,,,,,,,, -67500,2.1128008,1.454227,,,,,,,,,,,,,, -67600,1.9930077,1.4894952,,,,,,,,,,,,,, -67700,1.932936,1.576349,,,,,,,,,,,,,, -67800,2.1219232,1.5062702,,,,,,,,,,,,,, -67900,2.0199304,1.5600576,,,,,,,,,,,,,, -68000,2.3256292,1.4065702,,,,,,,,,,,,,, -68075,,,0.7562978267669678,0.9239857792854308,0.6578199863433838,1.4056472778320312,50000.0,0.5333000421524048,2.1278412342071533,10000.0,22986.08687877655,23800.46881747245,22986.08687877655,810.4946658611298,1.5225434303283691,0.0 -68100,1.9851251,1.6457112,,,,,,,,,,,,,, -68200,2.0153573,1.4992924,,,,,,,,,,,,,, -68300,2.1266215,1.5578243,,,,,,,,,,,,,, -68400,1.9750072,1.4573697,,,,,,,,,,,,,, -68500,1.9720577,1.5031128,,,,,,,,,,,,,, -68600,2.154588,1.426151,,,,,,,,,,,,,, -68700,2.2556436,1.495355,,,,,,,,,,,,,, -68800,2.104627,1.5041641,,,,,,,,,,,,,, -68900,2.0381215,1.664952,,,,,,,,,,,,,, -69000,1.8574816,1.5041878,,,,,,,,,,,,,, -69100,2.0065057,1.5868862,,,,,,,,,,,,,, -69200,2.114901,1.5387135,,,,,,,,,,,,,, -69300,1.9950435,1.5629727,,,,,,,,,,,,,, -69400,2.1352804,1.4834665,,,,,,,,,,,,,, -69500,2.05216,1.5308728,,,,,,,,,,,,,, -69589,,,0.7335578799247742,1.0107797384262085,0.6535399556159973,1.4225095510482788,50000.0,0.5343000292778015,2.123914957046509,10000.0,23496.2087392807,24328.093178987503,23496.2087392807,827.9055438041687,1.5612945556640625,0.0 -69600,2.030166,1.4060974,,,,,,,,,,,,,, -69700,1.8460945,1.4286364,,,,,,,,,,,,,, -69800,2.0715582,1.4963669,,,,,,,,,,,,,, -69900,2.0008426,1.4865592,,,,,,,,,,,,,, -70000,1.917816,1.4563863,,,,,,,,,,,,,, -70100,2.2235165,1.4927568,,,,,,,,,,,,,, -70200,1.9713047,1.5421121,,,,,,,,,,,,,, -70300,2.1303947,1.6025116,,,,,,,,,,,,,, -70400,2.1353326,1.554642,,,,,,,,,,,,,, -70500,1.8722249,1.581212,,,,,,,,,,,,,, -70600,1.9662569,1.5293888,,,,,,,,,,,,,, -70700,2.0085044,1.538,,,,,,,,,,,,,, -70800,2.342683,1.5199784,,,,,,,,,,,,,, -70900,1.9396641,1.4589275,,,,,,,,,,,,,, -71000,2.0562112,1.5121589,,,,,,,,,,,,,, -71100,1.9661077,1.4322299,,,,,,,,,,,,,, -71103,,,0.7446388602256775,0.9683708548545836,0.668940007686615,1.3582696914672852,50000.0,0.5351000428199768,2.0900559425354004,10000.0,24006.12292265892,24855.764329195023,24006.12292265892,845.568799495697,1.6021020412445068,0.0 -71200,1.9340451,1.4994948,,,,,,,,,,,,,, -71300,2.2625759,1.4560454,,,,,,,,,,,,,, -71400,2.3121755,1.5849469,,,,,,,,,,,,,, -71500,2.433783,1.4394444,,,,,,,,,,,,,, -71600,2.201859,1.6071179,,,,,,,,,,,,,, -71700,2.0438426,1.522041,,,,,,,,,,,,,, -71800,2.158431,1.549075,,,,,,,,,,,,,, -71900,2.1012466,1.4256564,,,,,,,,,,,,,, -72000,1.9268746,1.4470465,,,,,,,,,,,,,, -72100,2.1266782,1.4371656,,,,,,,,,,,,,, -72200,2.1839862,1.4704587,,,,,,,,,,,,,, -72300,2.1256003,1.5923673,,,,,,,,,,,,,, -72400,2.2396235,1.4966917,,,,,,,,,,,,,, -72500,2.144904,1.5340635,,,,,,,,,,,,,, -72600,1.9110968,1.5660748,,,,,,,,,,,,,, -72618,,,0.7384406924247742,0.9917774200439452,0.6642999649047852,1.362199068069458,50000.0,0.5350000262260437,2.121037483215332,10000.0,24516.380070209503,25383.496671438217,24516.380070209503,862.9524285793304,1.64097261428833,0.0 -72700,1.9856809,1.6411607,,,,,,,,,,,,,, -72800,1.9779835,1.5695591,,,,,,,,,,,,,, -72900,1.9301689,1.5216495,,,,,,,,,,,,,, -73000,1.8810428,1.4804627,,,,,,,,,,,,,, -73100,1.9351035,1.4666575,,,,,,,,,,,,,, -73200,2.0015786,1.4572651,,,,,,,,,,,,,, -73300,2.0532792,1.5239807,,,,,,,,,,,,,, -73400,2.1522481,1.4854453,,,,,,,,,,,,,, -73500,2.1370966,1.5721887,,,,,,,,,,,,,, -73600,2.3230147,1.4607887,,,,,,,,,,,,,, -73700,2.081746,1.5568715,,,,,,,,,,,,,, -73800,2.2226963,1.5882163,,,,,,,,,,,,,, -73900,1.9828303,1.5483248,,,,,,,,,,,,,, -74000,2.081284,1.4779857,,,,,,,,,,,,,, -74100,2.3459384,1.6017071,,,,,,,,,,,,,, -74133,,,0.7292729616165161,1.0414572954177856,0.6623799800872803,1.3939917087554932,50000.0,0.5330000519752502,2.1223597526550293,10000.0,25026.531310796738,25911.16432285309,25026.531310796738,880.3766114711761,1.679915428161621,0.0 -74200,2.178315,1.4742194,,,,,,,,,,,,,, -74300,2.3241427,1.5002793,,,,,,,,,,,,,, -74400,1.9140397,1.3983771,,,,,,,,,,,,,, -74500,1.976415,1.4759787,,,,,,,,,,,,,, -74600,2.0819046,1.5312934,,,,,,,,,,,,,, -74700,2.3438194,1.4952691,,,,,,,,,,,,,, -74800,2.314601,1.6006651,,,,,,,,,,,,,, -74900,2.487355,1.4862858,,,,,,,,,,,,,, -75000,2.1969037,1.6428545,,,,,,,,,,,,,, -75100,2.0378954,1.5188332,,,,,,,,,,,,,, -75200,2.1383333,1.4093454,,,,,,,,,,,,,, -75300,2.055073,1.422759,,,,,,,,,,,,,, -75400,2.0851936,1.4643068,,,,,,,,,,,,,, -75500,2.3001528,1.5689733,,,,,,,,,,,,,, -75600,2.2299328,1.5596788,,,,,,,,,,,,,, -75647,,,0.7429248690605164,0.9894379377365112,0.6701799631118774,1.3528382778167725,50000.0,0.5415000319480896,2.0805461406707764,10000.0,25536.60414552689,26438.958388328552,25536.60414552689,898.0017364025116,1.7220737934112549,0.0 -75700,2.195187,1.5436586,,,,,,,,,,,,,, -75800,2.015146,1.509602,,,,,,,,,,,,,, -75900,2.1776247,1.4121886,,,,,,,,,,,,,, -76000,2.118891,1.5340914,,,,,,,,,,,,,, -76100,2.1845565,1.4357533,,,,,,,,,,,,,, -76200,2.5199547,1.6706222,,,,,,,,,,,,,, -76300,2.047253,1.4650428,,,,,,,,,,,,,, -76400,2.1742413,1.4445305,,,,,,,,,,,,,, -76500,2.2536938,1.535778,,,,,,,,,,,,,, -76600,2.0999625,1.477427,,,,,,,,,,,,,, -76700,1.9183666,1.4088777,,,,,,,,,,,,,, -76800,1.9657406,1.4825208,,,,,,,,,,,,,, -76900,2.1758657,1.44987,,,,,,,,,,,,,, -77000,2.0609186,1.4736776,,,,,,,,,,,,,, -77100,2.2239985,1.4981501,,,,,,,,,,,,,, -77161,,,0.7538663744926453,0.933739185333252,0.6592999696731567,1.3939270973205566,50000.0,0.5258000493049622,2.1486873626708984,10000.0,26046.688113451004,26966.82772350312,26046.688113451004,915.695422887802,1.7612836360931396,0.0 -77200,2.0174432,1.4350497,,,,,,,,,,,,,, -77300,2.2004843,1.5109181,,,,,,,,,,,,,, -77400,2.6696773,1.4703976,,,,,,,,,,,,,, -77500,2.0417929,1.4462142,,,,,,,,,,,,,, -77600,2.1596801,1.5124518,,,,,,,,,,,,,, -77700,2.2429848,1.4615092,,,,,,,,,,,,,, -77800,2.174301,1.3965261,,,,,,,,,,,,,, -77900,2.260156,1.5151534,,,,,,,,,,,,,, -78000,2.1022496,1.3722415,,,,,,,,,,,,,, -78100,2.06927,1.4841491,,,,,,,,,,,,,, -78200,2.2046933,1.4706522,,,,,,,,,,,,,, -78300,2.3002338,1.4517176,,,,,,,,,,,,,, -78400,2.1859705,1.487953,,,,,,,,,,,,,, -78500,2.1671426,1.5041133,,,,,,,,,,,,,, -78600,2.1623638,1.4547691,,,,,,,,,,,,,, -78675,,,0.7557198405265808,0.925915777683258,0.6734399795532227,1.3371450901031494,50000.0,0.5458000302314758,2.0804595947265625,10000.0,26556.600678920742,27494.543236494064,26556.600678920742,933.3991062641144,1.8085589408874512,0.0 -78700,2.1840425,1.4560268,,,,,,,,,,,,,, -78800,2.0539298,1.4981446,,,,,,,,,,,,,, -78900,2.2355757,1.4690402,,,,,,,,,,,,,, -79000,2.135984,1.466302,,,,,,,,,,,,,, -79100,1.8297586,1.3695488,,,,,,,,,,,,,, -79200,2.2175357,1.3634592,,,,,,,,,,,,,, -79300,2.1303825,1.6080639,,,,,,,,,,,,,, -79400,2.1143909,1.455198,,,,,,,,,,,,,, -79500,2.2022612,1.3848752,,,,,,,,,,,,,, -79600,2.1402352,1.3756258,,,,,,,,,,,,,, -79700,2.270239,1.3474659,,,,,,,,,,,,,, -79800,2.2880585,1.4794424,,,,,,,,,,,,,, -79900,2.1791785,1.5040958,,,,,,,,,,,,,, -80000,2.1887033,1.4520868,,,,,,,,,,,,,, -80100,2.1785867,1.521901,,,,,,,,,,,,,, -80190,,,0.7496811151504517,0.947109878063202,0.6690799593925476,1.3521523475646973,50000.0,0.5390000343322754,2.0906593799591064,10000.0,27066.66562986374,28022.62483239174,27066.66562986374,951.3102207183838,1.8615412712097168,0.0 -80200,2.1176975,1.4320384,,,,,,,,,,,,,, -80300,2.194,1.4564707,,,,,,,,,,,,,, -80400,2.3959866,1.5593512,,,,,,,,,,,,,, -80500,2.1176114,1.4918704,,,,,,,,,,,,,, -80600,2.3275225,1.4043677,,,,,,,,,,,,,, -80700,2.0999432,1.4298775,,,,,,,,,,,,,, -80800,2.031357,1.5420057,,,,,,,,,,,,,, -80900,2.0636885,1.4734988,,,,,,,,,,,,,, -81000,2.22813,1.5482509,,,,,,,,,,,,,, -81100,2.1733625,1.5094337,,,,,,,,,,,,,, -81200,2.210585,1.4281644,,,,,,,,,,,,,, -81300,2.062317,1.469151,,,,,,,,,,,,,, -81400,2.2191157,1.5418955,,,,,,,,,,,,,, -81500,2.481627,1.4339797,,,,,,,,,,,,,, -81600,2.0327888,1.4037313,,,,,,,,,,,,,, -81700,2.3442266,1.6055956,,,,,,,,,,,,,, -81704,,,0.7504384517669678,0.9462156295776368,0.6738199591636658,1.321738362312317,50000.0,0.5506000518798828,2.008496284484864,10000.0,27576.59165716172,28550.037356376648,27576.59165716172,968.7020602226256,1.9029486179351809,0.0 -81800,2.175317,1.4113942,,,,,,,,,,,,,, -81900,1.9950916,1.4226698,,,,,,,,,,,,,, -82000,2.061504,1.3504777,,,,,,,,,,,,,, -82100,2.1883345,1.5222614,,,,,,,,,,,,,, -82200,2.0895736,1.4240429,,,,,,,,,,,,,, -82300,2.1492999,1.4947402,,,,,,,,,,,,,, -82400,2.091065,1.3562415,,,,,,,,,,,,,, -82500,2.0100944,1.4180229,,,,,,,,,,,,,, -82600,2.1655965,1.4050854,,,,,,,,,,,,,, -82700,2.5928724,1.469188,,,,,,,,,,,,,, -82800,2.3146713,1.5931574,,,,,,,,,,,,,, -82900,2.3039925,1.3743129,,,,,,,,,,,,,, -83000,2.196723,1.5394926,,,,,,,,,,,,,, -83100,2.1826138,1.4867051,,,,,,,,,,,,,, -83200,2.063006,1.3535503,,,,,,,,,,,,,, -83218,,,0.7326610088348389,1.0291450023651123,0.6612399816513062,1.3827193975448608,50000.0,0.5326000452041626,2.109022378921509,10000.0,28086.49978995323,29077.384110450745,28086.49978995323,986.0461583137512,1.945774793624878,0.0 -83300,2.2065673,1.4352329,,,,,,,,,,,,,, -83400,2.156648,1.4339856,,,,,,,,,,,,,, -83500,2.2950904,1.4402469,,,,,,,,,,,,,, -83600,2.1679332,1.4764037,,,,,,,,,,,,,, -83700,2.1807404,1.4509354,,,,,,,,,,,,,, -83800,2.1712298,1.4698174,,,,,,,,,,,,,, -83900,2.1351156,1.4307468,,,,,,,,,,,,,, -84000,2.3351028,1.4725868,,,,,,,,,,,,,, -84100,2.3053298,1.5177088,,,,,,,,,,,,,, -84200,2.1240792,1.3864064,,,,,,,,,,,,,, -84300,2.203435,1.3704684,,,,,,,,,,,,,, -84400,2.3554468,1.4497283,,,,,,,,,,,,,, -84500,2.2252595,1.4498657,,,,,,,,,,,,,, -84600,2.111276,1.4097407,,,,,,,,,,,,,, -84700,2.0362089,1.4241436,,,,,,,,,,,,,, -84732,,,0.7528101205825806,0.936428725719452,0.6754199862480164,1.3129218816757202,50000.0,0.5491000413894653,2.0207293033599854,10000.0,28596.737845897675,29605.5751888752,28596.737845897675,1003.904506444931,1.9876015186309808,0.0 -84800,2.3310366,1.3561378,,,,,,,,,,,,,, -84900,2.1942108,1.4340445,,,,,,,,,,,,,, -85000,2.1667173,1.4913967,,,,,,,,,,,,,, -85100,2.163547,1.4583372,,,,,,,,,,,,,, -85200,2.49413,1.419431,,,,,,,,,,,,,, -85300,2.142545,1.4666185,,,,,,,,,,,,,, -85400,2.0987673,1.4003631,,,,,,,,,,,,,, -85500,2.3516552,1.5112443,,,,,,,,,,,,,, -85600,2.184249,1.4389608,,,,,,,,,,,,,, -85700,2.2811668,1.4985871,,,,,,,,,,,,,, -85800,2.5801628,1.4550809,,,,,,,,,,,,,, -85900,2.1539595,1.4138637,,,,,,,,,,,,,, -86000,2.139123,1.4363734,,,,,,,,,,,,,, -86100,2.475023,1.5231348,,,,,,,,,,,,,, -86200,2.3461983,1.4189306,,,,,,,,,,,,,, -86247,,,0.7684949040412903,0.8733639717102051,0.6672799587249756,1.357527494430542,50000.0,0.5403000116348267,2.084959983825684,10000.0,29106.95946264267,30133.36285853386,29106.95946264267,1021.369663476944,2.03339958190918,0.0 -86300,2.27601,1.5094814,,,,,,,,,,,,,, -86400,2.3215277,1.5217984,,,,,,,,,,,,,, -86500,2.4479778,1.4757233,,,,,,,,,,,,,, -86600,2.2891128,1.4241424,,,,,,,,,,,,,, -86700,2.3137603,1.430922,,,,,,,,,,,,,, -86800,2.463107,1.5135654,,,,,,,,,,,,,, -86900,2.0600681,1.4305812,,,,,,,,,,,,,, -87000,2.2695735,1.3622637,,,,,,,,,,,,,, -87100,2.58036,1.5264693,,,,,,,,,,,,,, -87200,2.507958,1.3998629,,,,,,,,,,,,,, -87300,2.1178455,1.3552349,,,,,,,,,,,,,, -87400,2.0781767,1.3674537,,,,,,,,,,,,,, -87500,2.3213305,1.4257995,,,,,,,,,,,,,, -87600,2.3293016,1.4572084,,,,,,,,,,,,,, -87700,2.264924,1.437298,,,,,,,,,,,,,, -87761,,,0.7531489133834839,0.9372791051864624,0.6677599549293518,1.3522411584854126,50000.0,0.534600019454956,2.101754903793335,10000.0,29616.86198425293,30661.273845672607,29616.86198425293,1039.2736871242523,2.084794282913208,0.0 -87800,2.2443357,1.5296216,,,,,,,,,,,,,, -87900,2.2277672,1.3871186,,,,,,,,,,,,,, -88000,2.4491963,1.4666839,,,,,,,,,,,,,, -88100,2.1721919,1.3939587,,,,,,,,,,,,,, -88200,2.020826,1.486788,,,,,,,,,,,,,, -88300,2.366585,1.4491168,,,,,,,,,,,,,, -88400,2.2143636,1.3906575,,,,,,,,,,,,,, -88500,2.1959672,1.4195592,,,,,,,,,,,,,, -88600,2.2617085,1.4016964,,,,,,,,,,,,,, -88700,2.1665106,1.4471209,,,,,,,,,,,,,, -88800,2.210578,1.4231871,,,,,,,,,,,,,, -88900,2.0347075,1.4170003,,,,,,,,,,,,,, -89000,2.2347453,1.4068202,,,,,,,,,,,,,, -89100,2.3353753,1.3186649,,,,,,,,,,,,,, -89200,2.1549628,1.3503826,,,,,,,,,,,,,, -89275,,,0.748465359210968,0.9442353248596193,0.6695799827575684,1.3373024463653564,50000.0,0.5367000102996826,2.0899970531463623,10000.0,30126.8455953598,31188.748248815536,30126.8455953598,1056.668353557587,2.1289286613464355,0.0 -89300,2.2313094,1.3489667,,,,,,,,,,,,,, -89400,2.3288083,1.3644097,,,,,,,,,,,,,, -89500,2.2869513,1.3716731,,,,,,,,,,,,,, -89600,2.4245238,1.4340397,,,,,,,,,,,,,, -89700,2.3750567,1.4852374,,,,,,,,,,,,,, -89800,2.315479,1.4326912,,,,,,,,,,,,,, -89900,2.178609,1.3707175,,,,,,,,,,,,,, -90000,2.2532563,1.3969774,,,,,,,,,,,,,, -90100,2.4865885,1.4618726,,,,,,,,,,,,,, -90200,2.304427,1.4116172,,,,,,,,,,,,,, -90300,2.3158362,1.511411,,,,,,,,,,,,,, -90400,2.1296299,1.4717672,,,,,,,,,,,,,, -90500,2.2313085,1.4916117,,,,,,,,,,,,,, -90600,2.3524442,1.3552723,,,,,,,,,,,,,, -90700,2.1755486,1.3619813,,,,,,,,,,,,,, -90789,,,0.7472097873687744,0.958838939666748,0.6690399646759033,1.3475358486175537,50000.0,0.5403000116348267,2.083833694458008,10000.0,30636.927599668503,31716.2787964344,30636.927599668503,1074.0231430530548,2.169685840606689,0.0 -90800,2.2613673,1.4599304,,,,,,,,,,,,,, -90900,2.2153018,1.3705014,,,,,,,,,,,,,, -91000,2.423882,1.3635765,,,,,,,,,,,,,, -91100,2.3828053,1.2966646,,,,,,,,,,,,,, -91200,2.259917,1.3168495,,,,,,,,,,,,,, -91300,2.432043,1.4062642,,,,,,,,,,,,,, -91400,2.1148849,1.20149,,,,,,,,,,,,,, -91500,2.242895,1.4979074,,,,,,,,,,,,,, -91600,2.5241632,1.4088593,,,,,,,,,,,,,, -91700,2.4818504,1.5987184,,,,,,,,,,,,,, -91800,2.0657506,1.3760948,,,,,,,,,,,,,, -91900,2.3309007,1.2986786,,,,,,,,,,,,,, -92000,2.512252,1.3853673,,,,,,,,,,,,,, -92100,2.262364,1.338885,,,,,,,,,,,,,, -92200,2.3726444,1.3386286,,,,,,,,,,,,,, -92300,2.2670634,1.2626958,,,,,,,,,,,,,, -92303,,,0.7487842440605164,0.9438230991363524,0.6720799803733826,1.337753415107727,50000.0,0.5448000431060791,2.058915138244629,10000.0,31146.947617292404,32243.688675642014,31146.947617292404,1091.316163063049,2.213765859603882,0.0 -92400,2.30241,1.3312774,,,,,,,,,,,,,, -92500,2.311271,1.3662661,,,,,,,,,,,,,, -92600,2.340484,1.3702177,,,,,,,,,,,,,, -92700,2.473721,1.3877279,,,,,,,,,,,,,, -92800,2.3821247,1.4726406,,,,,,,,,,,,,, -92900,2.3359516,1.4599707,,,,,,,,,,,,,, -93000,2.1839457,1.3395365,,,,,,,,,,,,,, -93100,2.3868465,1.4525207,,,,,,,,,,,,,, -93200,2.20338,1.330251,,,,,,,,,,,,,, -93300,2.688015,1.4221014,,,,,,,,,,,,,, -93400,2.2958658,1.3902171,,,,,,,,,,,,,, -93500,2.2618926,1.3772188,,,,,,,,,,,,,, -93600,2.52629,1.3842447,,,,,,,,,,,,,, -93700,2.463044,1.4837345,,,,,,,,,,,,,, -93800,2.53902,1.3196635,,,,,,,,,,,,,, -93815,,,0.7768255472183228,0.8307573795318604,0.6793599724769592,1.3027634620666504,50000.0,0.5523000359535217,2.045220375061035,10000.0,31657.00089836121,32772.09962654114,31657.00089836121,1109.5773482322693,2.25687313079834,0.0 -93900,2.3941836,1.5061275,,,,,,,,,,,,,, -94000,2.3247712,1.3548614,,,,,,,,,,,,,, -94100,2.674475,1.4352493,,,,,,,,,,,,,, -94200,2.3566556,1.3236004,,,,,,,,,,,,,, -94300,2.683465,1.4377909,,,,,,,,,,,,,, -94400,2.3419127,1.3936622,,,,,,,,,,,,,, -94500,2.1860313,1.3525462,,,,,,,,,,,,,, -94600,2.3177338,1.3111843,,,,,,,,,,,,,, -94700,2.3777642,1.3937405,,,,,,,,,,,,,, -94800,2.335217,1.3048496,,,,,,,,,,,,,, -94900,2.4325764,1.3567257,,,,,,,,,,,,,, -95000,2.4112918,1.4727464,,,,,,,,,,,,,, -95100,2.2759526,1.3539654,,,,,,,,,,,,,, -95200,2.365316,1.3826454,,,,,,,,,,,,,, -95300,2.4855118,1.4789796,,,,,,,,,,,,,, -95330,,,0.7841796875,0.8101245760917664,0.6791799664497375,1.3027790784835815,50000.0,0.5487000346183777,2.0057966709136963,10000.0,32167.20648097992,33299.79450273514,32167.20648097992,1126.968933343887,2.3015289306640625,0.0 -95400,2.3643053,1.2635503,,,,,,,,,,,,,, -95500,2.45254,1.4042419,,,,,,,,,,,,,, -95600,2.5750244,1.3177453,,,,,,,,,,,,,, -95700,2.480141,1.3825648,,,,,,,,,,,,,, -95800,2.4346693,1.4625838,,,,,,,,,,,,,, -95900,2.2379794,1.3858402,,,,,,,,,,,,,, -96000,2.3759923,1.2974341,,,,,,,,,,,,,, -96100,2.372448,1.3744959,,,,,,,,,,,,,, -96200,2.432613,1.3614149,,,,,,,,,,,,,, -96300,2.4263427,1.4221469,,,,,,,,,,,,,, -96400,2.725351,1.4834569,,,,,,,,,,,,,, -96500,2.262831,1.3473818,,,,,,,,,,,,,, -96600,2.5840714,1.4358996,,,,,,,,,,,,,, -96700,2.307647,1.311209,,,,,,,,,,,,,, -96800,2.3817623,1.4483064,,,,,,,,,,,,,, -96844,,,0.7759685516357422,0.8316084146499634,0.6843999624252319,1.2773557901382446,50000.0,0.5568000078201294,2.0034146308898926,10000.0,32677.118038654327,33827.184403419495,32677.118038654327,1144.3458700180054,2.350280284881592,0.0 -96900,2.6015384,1.4436663,,,,,,,,,,,,,, -97000,2.6430275,1.3629012,,,,,,,,,,,,,, -97100,2.345228,1.3314718,,,,,,,,,,,,,, -97200,2.433009,1.3141354,,,,,,,,,,,,,, -97300,2.5779839,1.3711374,,,,,,,,,,,,,, -97400,2.2646306,1.3417268,,,,,,,,,,,,,, -97500,2.3250122,1.2578812,,,,,,,,,,,,,, -97600,2.6617262,1.3386286,,,,,,,,,,,,,, -97700,2.3536792,1.3298132,,,,,,,,,,,,,, -97800,2.6337967,1.5259291,,,,,,,,,,,,,, -97900,2.473085,1.3600736,,,,,,,,,,,,,, -98000,2.3717601,1.3649906,,,,,,,,,,,,,, -98100,2.1210272,1.2912924,,,,,,,,,,,,,, -98200,2.295988,1.2740979,,,,,,,,,,,,,, -98300,2.3578577,1.3444204,,,,,,,,,,,,,, -98358,,,0.7757692933082581,0.8344160318374634,0.6873799562454224,1.2644789218902588,50000.0,0.558899998664856,1.9838387966156008,10000.0,33187.21611762047,34355.01317358017,33187.21611762047,1161.9749476909635,2.399492025375366,0.0 -98400,2.490398,1.3013382,,,,,,,,,,,,,, -98500,2.5601215,1.4271212,,,,,,,,,,,,,, -98600,2.3997447,1.3189342,,,,,,,,,,,,,, -98700,2.5220218,1.3108658,,,,,,,,,,,,,, -98800,2.3730378,1.2406151,,,,,,,,,,,,,, -98900,2.7570355,1.3137871,,,,,,,,,,,,,, -99000,2.826072,1.4230856,,,,,,,,,,,,,, -99100,2.326129,1.3102512,,,,,,,,,,,,,, -99200,2.4871328,1.3728473,,,,,,,,,,,,,, -99300,2.5171142,1.3305308,,,,,,,,,,,,,, -99400,2.535408,1.4102684,,,,,,,,,,,,,, -99500,2.3046186,1.3525097,,,,,,,,,,,,,, -99600,2.5578363,1.3341689,,,,,,,,,,,,,, -99700,2.1121619,1.2457961,,,,,,,,,,,,,, -99800,2.4189487,1.3106927,,,,,,,,,,,,,, -99872,,,0.7700095772743225,0.8556905388832092,0.6885600090026855,1.267644286155701,50000.0,0.5624000430107117,1.971338391304016,10000.0,33697.1601536274,34882.30074048042,33697.1601536274,1179.2202832698822,2.444276094436645,0.0 -99900,2.285468,1.4086928,,,,,,,,,,,,,, -100000,2.7605965,1.4031174,,,,,,,,,,,,,, -100100,2.5852063,1.3690557,,,,,,,,,,,,,, -100200,2.5396183,1.2287018,,,,,,,,,,,,,, -100300,2.5517502,1.4160863,,,,,,,,,,,,,, -100400,2.3929608,1.3521193,,,,,,,,,,,,,, -100500,2.4017463,1.4422957,,,,,,,,,,,,,, -100600,2.5207698,1.3297882,,,,,,,,,,,,,, -100700,3.0739374,1.3840045,,,,,,,,,,,,,, -100800,2.4649806,1.2830448,,,,,,,,,,,,,, -100900,2.4897752,1.2881298,,,,,,,,,,,,,, -101000,2.4849086,1.3098627,,,,,,,,,,,,,, -101100,2.2550232,1.372652,,,,,,,,,,,,,, -101200,2.4683921,1.39636,,,,,,,,,,,,,, -101300,2.9666703,1.3807795,,,,,,,,,,,,,, -101387,,,0.7763273119926453,0.8312370777130127,0.6921399831771851,1.2392897605895996,50000.0,0.5706000328063965,1.9568129777908323,10000.0,34207.31728053093,35410.15484523773,34207.31728053093,1196.8182473182678,2.4896559715271,0.0 -101400,2.4407823,1.3949798,,,,,,,,,,,,,, -101500,2.561618,1.2737988,,,,,,,,,,,,,, -101600,2.3537323,1.3532089,,,,,,,,,,,,,, -101700,2.4806325,1.2925606,,,,,,,,,,,,,, -101800,2.612262,1.3531266,,,,,,,,,,,,,, -101900,2.196901,1.2358254,,,,,,,,,,,,,, -102000,2.6580737,1.3458319,,,,,,,,,,,,,, -102100,2.5383537,1.3726065,,,,,,,,,,,,,, -102200,2.5681157,1.300782,,,,,,,,,,,,,, -102300,2.6526577,1.3792754,,,,,,,,,,,,,, -102400,2.5566592,1.3892962,,,,,,,,,,,,,, -102500,2.3184621,1.3291087,,,,,,,,,,,,,, -102600,2.5489557,1.2643024,,,,,,,,,,,,,, -102700,2.4480202,1.3857481,,,,,,,,,,,,,, -102800,2.364033,1.3038768,,,,,,,,,,,,,, -102900,2.5209653,1.205826,,,,,,,,,,,,,, -102901,,,0.8036909699440002,0.7377466559410095,0.6786800026893616,1.3162022829055786,50000.0,0.5493000149726868,2.0565006732940674,10000.0,34717.62727546692,35938.262838840485,34717.62727546692,1214.52796459198,2.526557683944702,0.0 -103000,2.9162285,1.3817472,,,,,,,,,,,,,, -103100,2.600346,1.2844241,,,,,,,,,,,,,, -103200,2.6055555,1.3410984,,,,,,,,,,,,,, -103300,2.658646,1.3013971,,,,,,,,,,,,,, -103400,2.583219,1.3054949,,,,,,,,,,,,,, -103500,2.5706,1.3127269,,,,,,,,,,,,,, -103600,2.596758,1.31461,,,,,,,,,,,,,, -103700,2.5118594,1.3249273,,,,,,,,,,,,,, -103800,2.6250436,1.3662953,,,,,,,,,,,,,, -103900,2.4613333,1.3053409,,,,,,,,,,,,,, -104000,2.4171576,1.2843904,,,,,,,,,,,,,, -104100,2.4086514,1.2802366,,,,,,,,,,,,,, -104200,2.8510554,1.3273478,,,,,,,,,,,,,, -104300,2.6858494,1.27304,,,,,,,,,,,,,, -104400,2.4163327,1.2267663,,,,,,,,,,,,,, -104416,,,0.8050262928009033,0.7113240957260132,0.6964199542999268,1.2255223989486694,50000.0,0.5697000026702881,1.912181735038757,10000.0,35227.735830545425,36465.87762641907,35227.735830545425,1231.9329543113708,2.574941396713257,0.0 -104500,2.8208,1.3843657,,,,,,,,,,,,,, -104600,2.8897264,1.3212023,,,,,,,,,,,,,, -104700,2.6014442,1.276503,,,,,,,,,,,,,, -104800,2.6258512,1.3047146,,,,,,,,,,,,,, -104900,3.0497584,1.3229717,,,,,,,,,,,,,, -105000,2.6276534,1.354176,,,,,,,,,,,,,, -105100,2.514531,1.3291135,,,,,,,,,,,,,, -105200,2.4358897,1.3102205,,,,,,,,,,,,,, -105300,2.8250983,1.4627619,,,,,,,,,,,,,, -105400,2.453197,1.3778344,,,,,,,,,,,,,, -105500,2.5964546,1.3350594,,,,,,,,,,,,,, -105600,2.598715,1.2983365,,,,,,,,,,,,,, -105700,2.8275359,1.3305482,,,,,,,,,,,,,, -105800,2.732204,1.2989686,,,,,,,,,,,,,, -105900,2.9130597,1.2214925,,,,,,,,,,,,,, -105931,,,0.7947624325752258,0.7517896294593811,0.69896000623703,1.214800238609314,50000.0,0.5741000175476074,1.9478965997695925,10000.0,35737.8859539032,36993.40591478348,35737.8859539032,1249.209624528885,2.623447895050049,0.0 -106000,2.6881046,1.3794637,,,,,,,,,,,,,, -106100,2.9162874,1.266001,,,,,,,,,,,,,, -106200,2.5816882,1.2384514,,,,,,,,,,,,,, -106300,2.9266036,1.2572314,,,,,,,,,,,,,, -106400,2.790506,1.2531322,,,,,,,,,,,,,, -106500,2.4825895,1.3060033,,,,,,,,,,,,,, -106600,2.8172145,1.2727909,,,,,,,,,,,,,, -106700,2.509417,1.2900782,,,,,,,,,,,,,, -106800,2.716583,1.30194,,,,,,,,,,,,,, -106900,2.7107575,1.2649184,,,,,,,,,,,,,, -107000,2.4887729,1.202645,,,,,,,,,,,,,, -107100,2.6373146,1.319951,,,,,,,,,,,,,, -107200,2.556858,1.3924209,,,,,,,,,,,,,, -107300,2.7385037,1.3682759,,,,,,,,,,,,,, -107400,2.6124785,1.310382,,,,,,,,,,,,,, -107445,,,0.7896404266357422,0.7744253277778625,0.6942200064659119,1.2482731342315674,50000.0,0.5676000118255615,1.9585601091384888,10000.0,36247.9498064518,37521.17093753815,36247.9498064518,1266.8068754673004,2.673878192901612,0.0 -107500,2.880659,1.3178568,,,,,,,,,,,,,, -107600,2.5388446,1.2348089,,,,,,,,,,,,,, -107700,2.8022788,1.3481374,,,,,,,,,,,,,, -107800,2.5706627,1.2549393,,,,,,,,,,,,,, -107900,2.677456,1.3201817,,,,,,,,,,,,,, -108000,2.8983412,1.3600554,,,,,,,,,,,,,, -108100,2.6099076,1.3186327,,,,,,,,,,,,,, -108200,2.6987338,1.3118758,,,,,,,,,,,,,, -108300,2.5817275,1.3125404,,,,,,,,,,,,,, -108400,2.4607003,1.3241978,,,,,,,,,,,,,, -108500,2.5993068,1.1663371,,,,,,,,,,,,,, -108600,2.988585,1.436844,,,,,,,,,,,,,, -108700,2.5076976,1.2936834,,,,,,,,,,,,,, -108800,2.9439094,1.3249171,,,,,,,,,,,,,, -108900,2.3718412,1.2512442,,,,,,,,,,,,,, -108958,,,0.7810905575752258,0.8077903389930725,0.6916999816894531,1.2639470100402832,50000.0,0.5630000233650208,1.98625648021698,10000.0,36757.94217133522,38049.23787140846,36757.94217133522,1284.7805352211,2.722330331802368,0.0 -109000,2.6427562,1.3126428,,,,,,,,,,,,,, -109100,2.8889923,1.2655045,,,,,,,,,,,,,, -109200,2.8069265,1.1999189,,,,,,,,,,,,,, -109300,2.6349065,1.343127,,,,,,,,,,,,,, -109400,2.965644,1.3185116,,,,,,,,,,,,,, -109500,2.796975,1.2664943,,,,,,,,,,,,,, -109600,2.6097977,1.2802528,,,,,,,,,,,,,, -109700,2.6884944,1.2282224,,,,,,,,,,,,,, -109800,2.8219416,1.2269734,,,,,,,,,,,,,, -109900,2.7143307,1.3012712,,,,,,,,,,,,,, -110000,2.633733,1.2495396,,,,,,,,,,,,,, -110100,2.7279818,1.2443007,,,,,,,,,,,,,, -110200,2.6880646,1.2908634,,,,,,,,,,,,,, -110300,2.8536122,1.2222654,,,,,,,,,,,,,, -110400,2.80049,1.4009057,,,,,,,,,,,,,, -110472,,,0.7847775816917419,0.7921836972236633,0.6990000009536743,1.2325977087020874,50000.0,0.5737000107765198,1.9548555612564087,10000.0,37267.98639202118,38576.88211965561,37267.98639202118,1302.261702299118,2.787564992904663,0.0 -110500,3.2270238,1.3134553,,,,,,,,,,,,,, -110600,2.636273,1.2517327,,,,,,,,,,,,,, -110700,2.7283807,1.3465089,,,,,,,,,,,,,, -110800,2.780411,1.2966344,,,,,,,,,,,,,, -110900,2.7769845,1.3326468,,,,,,,,,,,,,, -111000,2.8191185,1.2606769,,,,,,,,,,,,,, -111100,2.7376173,1.2950064,,,,,,,,,,,,,, -111200,2.619202,1.1749442,,,,,,,,,,,,,, -111300,2.6769187,1.2222263,,,,,,,,,,,,,, -111400,2.858385,1.2289239,,,,,,,,,,,,,, -111500,2.8675401,1.3065902,,,,,,,,,,,,,, -111600,3.082428,1.2684798,,,,,,,,,,,,,, -111700,2.874621,1.288553,,,,,,,,,,,,,, -111800,2.9339178,1.2836964,,,,,,,,,,,,,, -111900,2.5507839,1.2209152,,,,,,,,,,,,,, -111986,,,0.8351203799247742,0.5935096740722656,0.6996999979019165,1.23037850856781,50000.0,0.5711000561714172,1.9375004768371584,10000.0,37778.13521909714,39104.69184041023,37778.13521909714,1319.8213753700256,2.836085319519043,0.0 -112000,3.2244985,1.2123522,,,,,,,,,,,,,, -112100,2.9460213,1.2088807,,,,,,,,,,,,,, -112200,2.7644444,1.2315336,,,,,,,,,,,,,, -112300,2.7652075,1.3240445,,,,,,,,,,,,,, -112400,2.70419,1.3686887,,,,,,,,,,,,,, -112500,3.012937,1.2361426,,,,,,,,,,,,,, -112600,2.4076207,1.2400982,,,,,,,,,,,,,, -112700,2.7433238,1.1706314,,,,,,,,,,,,,, -112800,2.631594,1.1649964,,,,,,,,,,,,,, -112900,2.9023116,1.3231065,,,,,,,,,,,,,, -113000,2.714495,1.2112545,,,,,,,,,,,,,, -113100,2.7874086,1.1911391,,,,,,,,,,,,,, -113200,2.7419205,1.2607226,,,,,,,,,,,,,, -113300,2.5766754,1.1838074,,,,,,,,,,,,,, -113400,2.9756181,1.1905482,,,,,,,,,,,,,, -113500,2.6626825,1.2675478,,,,,,,,,,,,,, -113501,,,0.8151108026504517,0.672951340675354,0.6999399662017822,1.2182530164718628,50000.0,0.5791000127792358,1.906409859657288,10000.0,38288.519334316254,39632.731977939606,38288.519334316254,1337.3728301525116,2.8874781131744385,0.0 -113600,2.8126972,1.2673981,,,,,,,,,,,,,, -113700,3.0202322,1.2323002,,,,,,,,,,,,,, -113800,2.961507,1.2900379,,,,,,,,,,,,,, -113900,2.6667712,1.2707292,,,,,,,,,,,,,, -114000,2.7527385,1.2659478,,,,,,,,,,,,,, -114100,2.7066534,1.1958742,,,,,,,,,,,,,, -114200,2.8779547,1.1997893,,,,,,,,,,,,,, -114300,2.6587293,1.1203291,,,,,,,,,,,,,, -114400,2.7707167,1.1808536,,,,,,,,,,,,,, -114500,2.6763902,1.2999278,,,,,,,,,,,,,, -114600,2.9065711,1.2626159,,,,,,,,,,,,,, -114700,2.8192637,1.1937864,,,,,,,,,,,,,, -114800,2.8153229,1.2295822,,,,,,,,,,,,,, -114900,2.5623434,1.2678251,,,,,,,,,,,,,, -115000,3.043035,1.1915199,,,,,,,,,,,,,, -115015,,,0.80961012840271,0.6891085505485535,0.705299973487854,1.2037498950958252,50000.0,0.5841000080108643,1.930444598197937,10000.0,38798.61583328247,40160.355558395386,38798.61583328247,1354.794580221176,2.939371109008789,0.0 -115100,2.7373178,1.1948488,,,,,,,,,,,,,, -115200,3.1009595,1.2548028,,,,,,,,,,,,,, -115300,2.6924489,1.2311621,,,,,,,,,,,,,, -115400,2.845828,1.238488,,,,,,,,,,,,,, -115500,2.838928,1.2289511,,,,,,,,,,,,,, -115600,2.7734163,1.1620277,,,,,,,,,,,,,, -115700,3.118117,1.3480765,,,,,,,,,,,,,, -115800,3.0583994,1.1846999,,,,,,,,,,,,,, -115900,2.7543538,1.0899599,,,,,,,,,,,,,, -116000,2.8429456,1.2305294,,,,,,,,,,,,,, -116100,2.8654103,1.2567375,,,,,,,,,,,,,, -116200,2.8533475,1.2933216,,,,,,,,,,,,,, -116300,2.7712412,1.2165766,,,,,,,,,,,,,, -116400,2.7524755,1.204235,,,,,,,,,,,,,, -116500,2.9911852,1.1070559,,,,,,,,,,,,,, -116530,,,0.8014189600944519,0.7300856113433838,0.7023199796676636,1.212501883506775,50000.0,0.5679000020027161,1.9563889503479004,10000.0,39308.739602565765,40688.40866851807,39308.739602565765,1372.6245946884155,2.987029552459717,0.0 -116600,2.9881349,1.2131536,,,,,,,,,,,,,, -116700,3.038277,1.2430955,,,,,,,,,,,,,, -116800,2.8244252,1.1619924,,,,,,,,,,,,,, -116900,3.0050704,1.1825969,,,,,,,,,,,,,, -117000,2.81661,1.23886,,,,,,,,,,,,,, -117100,2.7509642,1.1464179,,,,,,,,,,,,,, -117200,2.8670602,1.2592127,,,,,,,,,,,,,, -117300,3.1406183,1.1896312,,,,,,,,,,,,,, -117400,2.6462665,1.1159447,,,,,,,,,,,,,, -117500,2.7795956,1.1846408,,,,,,,,,,,,,, -117600,2.86575,1.2576666,,,,,,,,,,,,,, -117700,2.7819345,1.2160504,,,,,,,,,,,,,, -117800,2.8901284,1.2126119,,,,,,,,,,,,,, -117900,2.779957,1.1458595,,,,,,,,,,,,,, -118000,3.0436494,1.2531599,,,,,,,,,,,,,, -118045,,,0.8037906289100647,0.7134524583816528,0.7069399952888489,1.1913037300109863,50000.0,0.5809000134468079,1.901960849761963,10000.0,39818.97498655319,41216.51202297211,39818.97498655319,1390.393991947174,3.032386541366577,0.0 -118100,2.9567533,1.2203839,,,,,,,,,,,,,, -118200,2.9493253,1.2313659,,,,,,,,,,,,,, -118300,3.020503,1.3028284,,,,,,,,,,,,,, -118400,2.990449,1.1125358,,,,,,,,,,,,,, -118500,2.9810493,1.1966236,,,,,,,,,,,,,, -118600,2.9209788,1.1611645,,,,,,,,,,,,,, -118700,2.9912734,1.2266545,,,,,,,,,,,,,, -118800,3.0620296,1.1945863,,,,,,,,,,,,,, -118900,2.923786,1.2074822,,,,,,,,,,,,,, -119000,3.1331494,1.1372519,,,,,,,,,,,,,, -119100,2.8487327,1.2435124,,,,,,,,,,,,,, -119200,3.001728,1.1969606,,,,,,,,,,,,,, -119300,2.81413,1.1169553,,,,,,,,,,,,,, -119400,2.7246726,1.1255302,,,,,,,,,,,,,, -119500,3.2952642,1.253866,,,,,,,,,,,,,, -119559,,,0.7993462681770325,0.7303955554962158,0.7000199556350708,1.2318073511123655,50000.0,0.567300021648407,1.958403825759888,10000.0,40328.96877479553,41744.0572385788,40328.96877479553,1407.8446052074432,3.0808701515197754,0.0 -119600,3.2500534,1.3445942,,,,,,,,,,,,,, -119700,3.1433883,1.2142141,,,,,,,,,,,,,, -119800,2.9636774,1.1567098,,,,,,,,,,,,,, -119900,2.7841375,1.1887636,,,,,,,,,,,,,, -120000,3.1053877,1.2195524,,,,,,,,,,,,,, -120100,2.9571922,1.2004828,,,,,,,,,,,,,, -120200,3.2962306,1.2326291,,,,,,,,,,,,,, -120300,3.343319,1.2634248,,,,,,,,,,,,,, -120400,2.856701,1.148643,,,,,,,,,,,,,, -120500,2.8095343,1.0887415,,,,,,,,,,,,,, -120600,2.9701304,1.2897762,,,,,,,,,,,,,, -120700,3.2941015,1.3012546,,,,,,,,,,,,,, -120800,3.005663,1.2330179,,,,,,,,,,,,,, -120900,2.8596113,1.2038146,,,,,,,,,,,,,, -121000,3.0406005,1.230386,,,,,,,,,,,,,, -121073,,,0.8441884517669678,0.5559610724449158,0.7035399675369263,1.1964123249053955,50000.0,0.5856000185012817,1.882620930671692,10000.0,40839.1434905529,42271.82386422157,40839.1434905529,1425.3332903385162,3.131650686264038,0.0 -121100,3.1392655,1.093747,,,,,,,,,,,,,, -121200,2.8217785,1.169588,,,,,,,,,,,,,, -121300,3.0740993,1.1148157,,,,,,,,,,,,,, -121400,2.9785104,1.1424813,,,,,,,,,,,,,, -121500,2.9323084,1.1990868,,,,,,,,,,,,,, -121600,3.3209324,1.2406309,,,,,,,,,,,,,, -121700,3.0375,1.2305262,,,,,,,,,,,,,, -121800,2.8993828,1.0825778,,,,,,,,,,,,,, -121900,3.13241,1.3034672,,,,,,,,,,,,,, -122000,2.9230254,1.1252344,,,,,,,,,,,,,, -122100,2.9963105,1.2504545,,,,,,,,,,,,,, -122200,3.005407,1.126605,,,,,,,,,,,,,, -122300,2.9584177,1.1742694,,,,,,,,,,,,,, -122400,3.0326254,1.2008944,,,,,,,,,,,,,, -122500,3.4134202,1.2048929,,,,,,,,,,,,,, -122588,,,0.83203125,0.6034864783287048,0.7110399603843689,1.172864317893982,50000.0,0.5904000401496887,1.8778742551803589,10000.0,41349.36904430389,42799.59975862503,41349.36904430389,1442.7801163196564,3.1828622817993164,0.0 -122600,3.2362697,1.1891732,,,,,,,,,,,,,, -122700,2.987382,1.0603409,,,,,,,,,,,,,, -122800,2.9566863,1.1337547,,,,,,,,,,,,,, -122900,2.7850134,1.0289495,,,,,,,,,,,,,, -123000,3.042356,1.2002318,,,,,,,,,,,,,, -123100,2.8860252,1.161711,,,,,,,,,,,,,, -123200,3.1562936,1.2214017,,,,,,,,,,,,,, -123300,3.1292057,1.1119604,,,,,,,,,,,,,, -123400,3.490894,1.1798697,,,,,,,,,,,,,, -123500,3.2129261,1.2229013,,,,,,,,,,,,,, -123600,2.9711046,1.1355084,,,,,,,,,,,,,, -123700,2.9706995,1.0885285,,,,,,,,,,,,,, -123800,3.2912054,1.1680859,,,,,,,,,,,,,, -123900,2.8508224,1.1463445,,,,,,,,,,,,,, -124000,2.8967931,1.0791117,,,,,,,,,,,,,, -124100,2.8715503,1.118571,,,,,,,,,,,,,, -124103,,,0.8218072056770325,0.6350313425064087,0.7073799967765808,1.1969265937805176,50000.0,0.5773000121116638,1.9539101123809808,10000.0,41859.42814803124,43327.76603150368,41859.42814803124,1460.7830305099487,3.2322473526000977,0.0 -124200,2.8894107,1.1741147,,,,,,,,,,,,,, -124300,3.30095,1.2329587,,,,,,,,,,,,,, -124400,2.904628,1.1259602,,,,,,,,,,,,,, -124500,2.908671,1.1689308,,,,,,,,,,,,,, -124600,2.811263,1.0831487,,,,,,,,,,,,,, -124700,3.217156,1.148753,,,,,,,,,,,,,, -124800,3.022651,1.1526634,,,,,,,,,,,,,, -124900,3.302878,1.1543732,,,,,,,,,,,,,, -125000,3.1474133,1.1289806,,,,,,,,,,,,,, -125100,3.3216012,1.2320727,,,,,,,,,,,,,, -125200,3.2267127,1.2316785,,,,,,,,,,,,,, -125300,3.0832531,1.1379882,,,,,,,,,,,,,, -125400,3.2371404,1.1278353,,,,,,,,,,,,,, -125500,3.3422472,1.2212284,,,,,,,,,,,,,, -125600,3.0092545,1.1442235,,,,,,,,,,,,,, -125617,,,0.8284637928009033,0.6209262609481812,0.7165399789810181,1.1599438190460205,50000.0,0.5925000309944153,1.8581621646881104,10000.0,42369.41453456879,43855.42738699913,42369.41453456879,1478.3529794216156,3.285027503967285,0.0 -125700,3.4114501,1.1558707,,,,,,,,,,,,,, -125800,2.882311,1.0643698,,,,,,,,,,,,,, -125900,3.2744694,1.116328,,,,,,,,,,,,,, -126000,3.303984,1.2379038,,,,,,,,,,,,,, -126100,3.1490686,1.0503331,,,,,,,,,,,,,, -126200,3.2093043,1.1190611,,,,,,,,,,,,,, -126300,3.150608,1.1248542,,,,,,,,,,,,,, -126400,3.1588743,1.080699,,,,,,,,,,,,,, -126500,3.1863248,1.1385123,,,,,,,,,,,,,, -126600,3.1835566,0.9987216,,,,,,,,,,,,,, -126700,3.0920875,1.0648565,,,,,,,,,,,,,, -126800,3.071502,1.070481,,,,,,,,,,,,,, -126900,3.2236028,1.0464946,,,,,,,,,,,,,, -127000,3.1606483,1.2035977,,,,,,,,,,,,,, -127100,3.1839507,1.0901539,,,,,,,,,,,,,, -127132,,,0.8194355964660645,0.6457960605621338,0.7070199847221375,1.2006008625030518,50000.0,0.5856000185012817,1.9250835180282595,10000.0,42879.60580301285,44383.18127179146,42879.60580301285,1495.813749074936,3.3334877490997314,0.0 -127200,3.314699,1.1363999,,,,,,,,,,,,,, -127300,3.3005157,1.2857623,,,,,,,,,,,,,, -127400,3.3961391,1.1427364,,,,,,,,,,,,,, -127500,3.0887802,1.1312339,,,,,,,,,,,,,, -127600,2.9565296,1.0749428,,,,,,,,,,,,,, -127700,3.2410638,1.1300778,,,,,,,,,,,,,, -127800,3.1582243,1.1553388,,,,,,,,,,,,,, -127900,3.5585186,1.2051334,,,,,,,,,,,,,, -128000,3.0751388,1.149037,,,,,,,,,,,,,, -128100,3.2334301,1.1598786,,,,,,,,,,,,,, -128200,3.4239535,1.1476448,,,,,,,,,,,,,, -128300,3.2452824,1.1653204,,,,,,,,,,,,,, -128400,3.223028,1.1402906,,,,,,,,,,,,,, -128500,3.4730868,1.118315,,,,,,,,,,,,,, -128600,2.9961736,1.0907148,,,,,,,,,,,,,, -128647,,,0.8309949040412903,0.6023525595664978,0.7137399911880493,1.1708685159683228,50000.0,0.5854000449180603,1.894575119018555,10000.0,43389.64476656914,44910.88752961159,43389.64476656914,1513.3805103302002,3.381448268890381,0.0 -128700,3.1932185,0.9541168,,,,,,,,,,,,,, -128800,3.138723,1.1524982,,,,,,,,,,,,,, -128900,3.2446554,1.1030842,,,,,,,,,,,,,, -129000,3.071177,1.0133743,,,,,,,,,,,,,, -129100,3.2903862,1.1479335,,,,,,,,,,,,,, -129200,3.4220293,1.1278839,,,,,,,,,,,,,, -129300,3.5138106,1.1446593,,,,,,,,,,,,,, -129400,2.9434178,0.9905727,,,,,,,,,,,,,, -129500,3.4085765,1.1049719,,,,,,,,,,,,,, -129600,3.1909525,0.98639774,,,,,,,,,,,,,, -129700,3.4325316,1.054983,,,,,,,,,,,,,, -129800,3.0398922,1.0323211,,,,,,,,,,,,,, -129900,3.1004813,1.1399753,,,,,,,,,,,,,, -130000,3.0741034,1.1737103,,,,,,,,,,,,,, -130100,3.1572528,1.0496504,,,,,,,,,,,,,, -130160,,,0.8641980290412903,0.4822084009647369,0.7169199585914612,1.153667449951172,50000.0,0.5945000052452087,1.87068510055542,10000.0,43899.57999563217,45438.51849889755,43899.57999563217,1530.9731702804563,3.430178165435791,0.0 -130200,3.4583244,1.0379405,,,,,,,,,,,,,, -130300,3.2716265,1.0449729,,,,,,,,,,,,,, -130400,3.1078978,0.9728106,,,,,,,,,,,,,, -130500,3.4099061,0.966226,,,,,,,,,,,,,, -130600,3.2789063,1.0118208,,,,,,,,,,,,,, -130700,3.6650105,1.0874944,,,,,,,,,,,,,, -130800,3.2500718,1.0201588,,,,,,,,,,,,,, -130900,3.1704452,1.0460284,,,,,,,,,,,,,, -131000,3.0871785,1.0585972,,,,,,,,,,,,,, -131100,3.283675,0.9953464,,,,,,,,,,,,,, -131200,3.2152474,1.0316074,,,,,,,,,,,,,, -131300,3.171407,1.0240663,,,,,,,,,,,,,, -131400,3.093949,1.0514017,,,,,,,,,,,,,, -131500,3.363664,1.1845624,,,,,,,,,,,,,, -131600,3.144058,1.0770152,,,,,,,,,,,,,, -131675,,,0.853535532951355,0.5158511400222778,0.7181800007820129,1.1625126600265503,50000.0,0.5956000089645386,1.894251823425293,10000.0,44409.718037605286,45967.23632621765,44409.718037605286,1549.4466173648834,3.483417510986328,0.0 -131700,3.2718794,1.1127633,,,,,,,,,,,,,, -131800,3.3957183,1.0679784,,,,,,,,,,,,,, -131900,3.3600564,1.07014,,,,,,,,,,,,,, -132000,3.0594265,1.0282134,,,,,,,,,,,,,, -132100,3.4864337,1.1150554,,,,,,,,,,,,,, -132200,3.1600046,1.0857904,,,,,,,,,,,,,, -132300,3.430717,1.0559587,,,,,,,,,,,,,, -132400,3.4711213,1.086449,,,,,,,,,,,,,, -132500,3.373563,0.95925224,,,,,,,,,,,,,, -132600,3.4784782,1.0800657,,,,,,,,,,,,,, -132700,3.6722505,1.0894365,,,,,,,,,,,,,, -132800,3.2188363,0.99381125,,,,,,,,,,,,,, -132900,3.373169,1.1097466,,,,,,,,,,,,,, -133000,3.5752363,1.1934328,,,,,,,,,,,,,, -133100,3.3237984,1.007084,,,,,,,,,,,,,, -133189,,,0.8512635231018066,0.518912672996521,0.7234999537467957,1.1488239765167236,50000.0,0.5996000170707703,1.8523231744766235,10000.0,44919.616294384,46494.96126246452,44919.616294384,1567.1695573329926,3.533687829971313,0.0 -133200,3.5917158,1.1671269,,,,,,,,,,,,,, -133300,3.081072,1.023593,,,,,,,,,,,,,, -133400,3.347167,1.1312506,,,,,,,,,,,,,, -133500,3.2605722,1.0758431,,,,,,,,,,,,,, -133600,3.6055002,0.9946462,,,,,,,,,,,,,, -133700,3.709313,1.0224017,,,,,,,,,,,,,, -133800,3.6293962,1.1646272,,,,,,,,,,,,,, -133900,3.1150177,1.0082382,,,,,,,,,,,,,, -134000,3.3033726,0.9980055,,,,,,,,,,,,,, -134100,3.545272,1.0539807,,,,,,,,,,,,,, -134200,3.351135,1.0180953,,,,,,,,,,,,,, -134300,3.3526251,1.1098038,,,,,,,,,,,,,, -134400,3.064416,0.99129015,,,,,,,,,,,,,, -134500,3.467842,1.1254362,,,,,,,,,,,,,, -134600,3.387135,1.0090798,,,,,,,,,,,,,, -134700,3.5486722,1.0853055,,,,,,,,,,,,,, -134703,,,0.8521404266357422,0.5233726501464844,0.7243199944496155,1.1318014860153198,50000.0,0.6012000441551208,1.8476316928863523,10000.0,45429.56394815445,47022.609437942505,45429.56394815445,1584.762847185135,3.58777403831482,0.0 -134800,3.2019799,1.0168715,,,,,,,,,,,,,, -134900,3.2634776,0.9733069,,,,,,,,,,,,,, -135000,3.5169203,1.0374892,,,,,,,,,,,,,, -135100,3.309797,0.9872132,,,,,,,,,,,,,, -135200,3.3378623,1.1014466,,,,,,,,,,,,,, -135300,3.4987752,1.0727926,,,,,,,,,,,,,, -135400,3.7430882,1.1396852,,,,,,,,,,,,,, -135500,3.4150648,1.1068367,,,,,,,,,,,,,, -135600,3.4789565,1.0167336,,,,,,,,,,,,,, -135700,3.3847902,0.99101996,,,,,,,,,,,,,, -135800,3.3222868,1.0239196,,,,,,,,,,,,,, -135900,3.2544646,0.9536018,,,,,,,,,,,,,, -136000,3.3996444,1.062899,,,,,,,,,,,,,, -136100,3.534564,1.0598195,,,,,,,,,,,,,, -136200,3.137151,1.0186986,,,,,,,,,,,,,, -136217,,,0.8546316623687744,0.5151609182357788,0.7261799573898315,1.1283327341079712,50000.0,0.6022000312805176,1.8703584671020508,10000.0,45939.55491280556,47550.36924123764,45939.55491280556,1602.4290103912354,3.637843132019043,0.0 -136300,3.332709,0.9713217,,,,,,,,,,,,,, -136400,3.6530573,1.0071622,,,,,,,,,,,,,, -136500,3.6460986,1.0188974,,,,,,,,,,,,,, -136600,3.4894056,1.017047,,,,,,,,,,,,,, -136700,3.3477764,1.0106483,,,,,,,,,,,,,, -136800,3.401579,1.0626855,,,,,,,,,,,,,, -136900,3.112364,0.99216086,,,,,,,,,,,,,, -137000,3.440428,0.93378127,,,,,,,,,,,,,, -137100,3.5916429,1.0711595,,,,,,,,,,,,,, -137200,3.4444077,1.0114002,,,,,,,,,,,,,, -137300,3.793735,0.9805641,,,,,,,,,,,,,, -137400,3.7741086,0.98184675,,,,,,,,,,,,,, -137500,3.7574034,1.0732017,,,,,,,,,,,,,, -137600,3.6041167,0.9914516,,,,,,,,,,,,,, -137700,3.7038743,1.0961982,,,,,,,,,,,,,, -137732,,,0.8350605964660645,0.5876715183258057,0.7061399817466736,1.225954532623291,50000.0,0.5824000239372253,1.9551299810409544,10000.0,46449.772149086,48078.53934550285,46449.772149086,1620.2749452590942,3.691930770874024,0.0 -137800,3.5975566,0.9808386,,,,,,,,,,,,,, -137900,3.455495,1.008637,,,,,,,,,,,,,, -138000,3.739561,1.0204527,,,,,,,,,,,,,, -138100,3.3096588,0.98954207,,,,,,,,,,,,,, -138200,3.7700987,1.0204523,,,,,,,,,,,,,, -138300,3.3768263,0.9827486,,,,,,,,,,,,,, -138400,3.8612833,1.0790434,,,,,,,,,,,,,, -138500,3.417733,0.92454195,,,,,,,,,,,,,, -138600,3.5717037,1.0326828,,,,,,,,,,,,,, -138700,3.8243303,1.0323988,,,,,,,,,,,,,, -138800,3.4844897,0.96183217,,,,,,,,,,,,,, -138900,3.674161,0.9624563,,,,,,,,,,,,,, -139000,3.5671463,1.0360823,,,,,,,,,,,,,, -139100,3.4419906,0.9683551,,,,,,,,,,,,,, -139200,3.4400132,0.9888097,,,,,,,,,,,,,, -139246,,,0.8834900856018066,0.4090985059738159,0.7286199927330017,1.1191855669021606,50000.0,0.6080000400543213,1.8433808088302608,10000.0,46959.832931280136,48606.27636384964,46959.832931280136,1637.8468651771543,3.743335962295532,0.0 -139300,3.7498534,1.075738,,,,,,,,,,,,,, -139400,3.38602,0.91313046,,,,,,,,,,,,,, -139500,3.4595485,1.0322857,,,,,,,,,,,,,, -139600,3.93243,0.94756067,,,,,,,,,,,,,, -139700,3.462992,0.8913594,,,,,,,,,,,,,, -139800,3.6816857,0.9361329,,,,,,,,,,,,,, -139900,3.8074505,0.9933727,,,,,,,,,,,,,, -140000,3.688745,0.9466858,,,,,,,,,,,,,, -140100,3.7197826,1.0929613,,,,,,,,,,,,,, -140200,3.3304532,0.91424197,,,,,,,,,,,,,, -140300,3.647119,0.9817412,,,,,,,,,,,,,, -140400,3.5062895,0.9433067,,,,,,,,,,,,,, -140500,3.8137023,1.035919,,,,,,,,,,,,,, -140600,3.4264789,0.90612924,,,,,,,,,,,,,, -140700,3.8710024,0.95658016,,,,,,,,,,,,,, -140761,,,0.8775510191917419,0.4234741926193237,0.7308799624443054,1.1201832294464111,50000.0,0.6022000312805176,1.865876317024231,10000.0,47470.02541399002,49134.33640837669,47470.02541399002,1655.6108510494232,3.794053316116333,0.0 -140800,3.677751,0.88930136,,,,,,,,,,,,,, -140900,3.5229714,0.9417216,,,,,,,,,,,,,, -141000,3.7495215,0.9801192,,,,,,,,,,,,,, -141100,3.5297525,0.9588081,,,,,,,,,,,,,, -141200,3.5184433,0.9550386,,,,,,,,,,,,,, -141300,4.103953,1.0214005,,,,,,,,,,,,,, -141400,3.4685102,0.921293,,,,,,,,,,,,,, -141500,3.6222744,0.9937414,,,,,,,,,,,,,, -141600,4.0966587,1.0664668,,,,,,,,,,,,,, -141700,4.1208267,0.9537108,,,,,,,,,,,,,, -141800,3.3685853,0.95375854,,,,,,,,,,,,,, -141900,3.5608914,0.86938834,,,,,,,,,,,,,, -142000,3.644447,0.9630288,,,,,,,,,,,,,, -142100,3.683171,1.0362253,,,,,,,,,,,,,, -142200,3.7120938,0.934593,,,,,,,,,,,,,, -142275,,,0.869559109210968,0.4549353718757629,0.729699969291687,1.1267502307891846,50000.0,0.6037000417709351,1.860552668571472,10000.0,47980.01170706749,49661.93563914299,47980.01170706749,1673.117571592331,3.8474504947662354,0.0 -142300,3.5628278,0.95494664,,,,,,,,,,,,,, -142400,4.0025487,0.9690418,,,,,,,,,,,,,, -142500,3.8817816,0.9825347,,,,,,,,,,,,,, -142600,3.6902177,0.9640564,,,,,,,,,,,,,, -142700,3.2710872,0.90742576,,,,,,,,,,,,,, -142800,3.904098,1.0392843,,,,,,,,,,,,,, -142900,4.0149975,0.95819104,,,,,,,,,,,,,, -143000,3.531726,0.9578894,,,,,,,,,,,,,, -143100,4.275051,1.0056143,,,,,,,,,,,,,, -143200,4.350479,0.9233798,,,,,,,,,,,,,, -143300,3.7354987,1.0397744,,,,,,,,,,,,,, -143400,3.8281226,0.9778752,,,,,,,,,,,,,, -143500,3.5873094,0.9035405,,,,,,,,,,,,,, -143600,4.280659,0.89864916,,,,,,,,,,,,,, -143700,3.6732488,0.8263642,,,,,,,,,,,,,, -143788,,,0.8769331574440002,0.4277708530426025,0.7282199859619141,1.118950605392456,50000.0,0.6041000485420227,1.8592493534088133,10000.0,48489.98080587387,50189.48695039749,48489.98080587387,1690.5912556648254,3.902509450912476,0.0 -143800,4.5782504,1.0646293,,,,,,,,,,,,,, -143900,3.5196173,0.9310656,,,,,,,,,,,,,, -144000,3.660019,0.9253609,,,,,,,,,,,,,, -144100,3.862197,0.9503002,,,,,,,,,,,,,, -144200,3.5974689,0.91417295,,,,,,,,,,,,,, -144300,3.8431897,0.9057482,,,,,,,,,,,,,, -144400,3.7463787,0.9857836,,,,,,,,,,,,,, -144500,3.8163435,0.8749681,,,,,,,,,,,,,, -144600,3.849676,0.9585552,,,,,,,,,,,,,, -144700,3.5000103,0.9043116,,,,,,,,,,,,,, -144800,3.894853,0.8631749,,,,,,,,,,,,,, -144900,3.5306525,0.859725,,,,,,,,,,,,,, -145000,3.7095685,0.8988885,,,,,,,,,,,,,, -145100,3.8026547,0.90023637,,,,,,,,,,,,,, -145200,3.8254385,0.97068393,,,,,,,,,,,,,, -145300,3.7579834,0.94619435,,,,,,,,,,,,,, -145301,,,0.8763153553009033,0.4283612966537475,0.7317599654197693,1.1129244565963743,50000.0,0.6092000007629395,1.8546922206878664,10000.0,48999.96798682213,50717.04645681381,48999.96798682213,1708.0588409900663,3.9540607929229736,0.0 -145400,3.9220755,0.83030033,,,,,,,,,,,,,, -145500,3.6787043,0.9447482,,,,,,,,,,,,,, -145600,3.9543238,0.9472068,,,,,,,,,,,,,, -145700,3.8902519,0.94963485,,,,,,,,,,,,,, -145800,3.6042135,0.92490107,,,,,,,,,,,,,, -145900,3.8145595,0.8714326,,,,,,,,,,,,,, -146000,4.287434,1.0639262,,,,,,,,,,,,,, -146100,4.3053546,0.99517155,,,,,,,,,,,,,, -146200,3.897915,0.9094047,,,,,,,,,,,,,, -146300,4.1377788,0.9496101,,,,,,,,,,,,,, -146400,4.155675,0.9095925,,,,,,,,,,,,,, -146500,4.1094694,0.99868083,,,,,,,,,,,,,, -146600,3.728858,0.82284236,,,,,,,,,,,,,, -146700,3.9295406,0.8863263,,,,,,,,,,,,,, -146800,4.148761,1.0017235,,,,,,,,,,,,,, -146815,,,0.8947902917861938,0.3664727807044983,0.7371000051498413,1.0927094221115112,50000.0,0.6107000112533569,1.8342806100845337,10000.0,49510.08322453499,51244.758184194565,49510.08322453499,1725.5483317375183,4.008327007293701,0.0 -146900,3.936058,0.96722466,,,,,,,,,,,,,, -147000,3.6004548,0.8119663,,,,,,,,,,,,,, -147100,3.7332819,0.88130724,,,,,,,,,,,,,, -147200,3.7157648,0.90386444,,,,,,,,,,,,,, -147300,4.027201,0.9183649,,,,,,,,,,,,,, -147400,3.9605415,0.91455173,,,,,,,,,,,,,, -147500,4.029736,0.95117545,,,,,,,,,,,,,, -147600,3.9202611,0.9363357,,,,,,,,,,,,,, -147700,4.1421213,0.9111255,,,,,,,,,,,,,, -147800,3.9055266,0.85453546,,,,,,,,,,,,,, -147900,3.787865,0.8754556,,,,,,,,,,,,,, -148000,3.9813309,0.94209296,,,,,,,,,,,,,, -148100,4.1899505,0.9474251,,,,,,,,,,,,,, -148200,3.9524722,0.8545438,,,,,,,,,,,,,, -148300,4.323169,0.88735664,,,,,,,,,,,,,, -148329,,,0.904934585094452,0.327584832906723,0.7351999878883362,1.1068998575210571,50000.0,0.6067000031471252,1.849771738052368,10000.0,50020.06638741493,51772.60830807686,50020.06638741493,1743.310753107071,4.059577941894531,0.0 -148400,4.015534,0.94680315,,,,,,,,,,,,,, -148500,3.727202,0.8299132,,,,,,,,,,,,,, -148600,4.078368,0.88334644,,,,,,,,,,,,,, -148700,3.844708,0.8400441,,,,,,,,,,,,,, -148800,4.193438,0.8803427,,,,,,,,,,,,,, -148900,3.93681,0.94091433,,,,,,,,,,,,,, -149000,4.1668744,0.89077234,,,,,,,,,,,,,, -149100,4.1136823,0.8931768,,,,,,,,,,,,,, -149200,3.9339707,0.9086999,,,,,,,,,,,,,, -149300,3.763929,0.9376372,,,,,,,,,,,,,, -149400,3.9096804,0.876829,,,,,,,,,,,,,, -149500,4.144906,0.92668104,,,,,,,,,,,,,, -149600,3.7317452,0.81790626,,,,,,,,,,,,,, -149700,4.244897,0.88377774,,,,,,,,,,,,,, -149800,4.2693853,0.9148557,,,,,,,,,,,,,, -149843,,,0.901566445827484,0.3411449491977691,0.7376599907875061,1.0968347787857056,50000.0,0.6149000525474548,1.8277121782302856,10000.0,50530.12015199661,52300.12011003494,50530.12015199661,1760.661033153534,4.114696264266968,0.0 -149900,3.9261851,0.9040629,,,,,,,,,,,,,, -150000,3.7189152,0.86577916,,,,,,,,,,,,,, -150100,4.2372622,0.7999842,,,,,,,,,,,,,, -150200,4.046445,0.95560443,,,,,,,,,,,,,, -150300,3.9055877,0.86385894,,,,,,,,,,,,,, -150400,4.2988143,0.8816323,,,,,,,,,,,,,, -150500,4.328543,0.9470316,,,,,,,,,,,,,, -150600,4.580611,0.8502341,,,,,,,,,,,,,, -150700,4.4743385,0.9604027,,,,,,,,,,,,,, -150800,3.9356916,0.82028925,,,,,,,,,,,,,, -150900,4.03862,0.89511764,,,,,,,,,,,,,, -151000,4.179754,0.7965307,,,,,,,,,,,,,, -151100,3.838301,0.8506212,,,,,,,,,,,,,, -151200,4.352758,0.9298353,,,,,,,,,,,,,, -151300,4.277946,0.92374015,,,,,,,,,,,,,, -151357,,,0.9032804369926452,0.335657387971878,0.7389000058174133,1.0878074169158936,50000.0,0.6163000464439392,1.8217467069625848,10000.0,51040.12117242813,52828.000644207,51040.12117242813,1778.430143117905,4.172338247299194,0.0 -151400,4.0939336,0.86258394,,,,,,,,,,,,,, -151500,3.8610084,0.90787685,,,,,,,,,,,,,, -151600,4.231354,0.85104537,,,,,,,,,,,,,, -151700,4.1613855,0.8717164,,,,,,,,,,,,,, -151800,4.155852,0.87037116,,,,,,,,,,,,,, -151900,3.824908,0.77889585,,,,,,,,,,,,,, -152000,3.767729,0.8434961,,,,,,,,,,,,,, -152100,4.0460477,0.8239369,,,,,,,,,,,,,, -152200,3.8839867,0.8218497,,,,,,,,,,,,,, -152300,3.782464,0.754029,,,,,,,,,,,,,, -152400,4.2739606,0.97725636,,,,,,,,,,,,,, -152500,4.0020876,0.79552704,,,,,,,,,,,,,, -152600,4.123478,0.8566287,,,,,,,,,,,,,, -152700,3.8224375,0.786465,,,,,,,,,,,,,, -152800,3.917351,0.8588026,,,,,,,,,,,,,, -152872,,,0.9049944281578064,0.3283750116825104,0.7387199997901917,1.1011860370635986,50000.0,0.6160000562667847,1.8262497186660769,10000.0,51550.18557286263,53355.74990034104,51550.18557286263,1796.0045523643494,4.2290143966674805,0.0 -152900,4.2068257,0.87842196,,,,,,,,,,,,,, -153000,4.0345216,0.81131256,,,,,,,,,,,,,, -153100,4.553978,0.8085229,,,,,,,,,,,,,, -153200,4.1327763,0.92271644,,,,,,,,,,,,,, -153300,4.120935,0.82883054,,,,,,,,,,,,,, -153400,4.1280055,0.84147525,,,,,,,,,,,,,, -153500,4.3628016,0.8306212,,,,,,,,,,,,,, -153600,4.3915715,0.85943234,,,,,,,,,,,,,, -153700,4.1322994,0.8591367,,,,,,,,,,,,,, -153800,4.3152266,0.8154378,,,,,,,,,,,,,, -153900,4.2948003,0.8373037,,,,,,,,,,,,,, -154000,4.349469,0.90688074,,,,,,,,,,,,,, -154100,4.1144314,0.8663986,,,,,,,,,,,,,, -154200,4.3850703,0.8950397,,,,,,,,,,,,,, -154300,3.9112675,0.7638825,,,,,,,,,,,,,, -154386,,,0.9026825428009032,0.3313542008399963,0.7376799583435059,1.1047000885009766,50000.0,0.616100013256073,1.8450510501861568,10000.0,52060.21515059471,53883.71820926666,52060.21515059471,1813.832791805268,4.287072420120239,0.0 -154400,4.28677,0.8964618,,,,,,,,,,,,,, -154500,4.003208,0.829636,,,,,,,,,,,,,, -154600,3.9098573,0.7644913,,,,,,,,,,,,,, -154700,3.805604,0.8044286,,,,,,,,,,,,,, -154800,4.4528346,0.768995,,,,,,,,,,,,,, -154900,3.7235749,0.72676736,,,,,,,,,,,,,, -155000,4.2479773,0.8069887,,,,,,,,,,,,,, -155100,4.2704835,0.843094,,,,,,,,,,,,,, -155200,4.135775,0.7245764,,,,,,,,,,,,,, -155300,4.47674,0.82279724,,,,,,,,,,,,,, -155400,3.9355955,0.79375154,,,,,,,,,,,,,, -155500,4.3422623,0.7839641,,,,,,,,,,,,,, -155600,4.1444516,0.82740414,,,,,,,,,,,,,, -155700,4.28424,0.80889344,,,,,,,,,,,,,, -155800,3.8808935,0.7717025,,,,,,,,,,,,,, -155900,,,0.9301259517669678,0.2502636015415191,0.7440800070762634,1.0782766342163086,50000.0,0.6147000193595886,1.8321329355239868,10000.0,52570.35527801514,54411.412940740585,52570.35527801514,1831.2789142131803,4.342725038528442,0.0 -155900,4.0533957,0.7797965,,,,,,,,,,,,,, -156000,4.5264564,0.841571,,,,,,,,,,,,,, -156100,4.063702,0.7513301,,,,,,,,,,,,,, -156200,3.783733,0.7210554,,,,,,,,,,,,,, -156300,3.989038,0.7742172,,,,,,,,,,,,,, -156400,4.0980673,0.8781815,,,,,,,,,,,,,, -156500,4.3334165,0.6814492,,,,,,,,,,,,,, -156600,4.449696,0.90273976,,,,,,,,,,,,,, -156700,4.1733365,0.82237065,,,,,,,,,,,,,, -156800,4.1496153,0.7804185,,,,,,,,,,,,,, -156900,4.243556,0.8468458,,,,,,,,,,,,,, -157000,4.544961,0.84525704,,,,,,,,,,,,,, -157100,4.37832,0.80427575,,,,,,,,,,,,,, -157200,4.28887,0.79192626,,,,,,,,,,,,,, -157300,4.2087717,0.7673945,,,,,,,,,,,,,, -157400,4.0495243,0.8226815,,,,,,,,,,,,,, -157414,,,0.9299465417861938,0.2439799904823303,0.742579996585846,1.0852277278900146,50000.0,0.6175000071525574,1.845854997634888,10000.0,53080.38762998581,54939.17242026329,53080.38762998581,1848.89755487442,4.397834777832031,0.0 -157500,4.6026382,0.8070273,,,,,,,,,,,,,, -157600,3.824225,0.7590424,,,,,,,,,,,,,, -157700,4.088798,0.767505,,,,,,,,,,,,,, -157800,4.157088,0.81334394,,,,,,,,,,,,,, -157900,4.1857963,0.87154794,,,,,,,,,,,,,, -158000,4.7706213,0.8146809,,,,,,,,,,,,,, -158100,4.6038775,0.8974297,,,,,,,,,,,,,, -158200,4.223793,0.84773946,,,,,,,,,,,,,, -158300,3.9769971,0.7585554,,,,,,,,,,,,,, -158400,4.063481,0.7788024,,,,,,,,,,,,,, -158500,4.388408,0.8793893,,,,,,,,,,,,,, -158600,4.075224,0.7706595,,,,,,,,,,,,,, -158700,4.271267,0.7630426,,,,,,,,,,,,,, -158800,4.4670086,0.8074529,,,,,,,,,,,,,, -158900,4.4986525,0.76617897,,,,,,,,,,,,,, -158929,,,0.9267578125,0.2488669753074646,0.7440999746322632,1.0787391662597656,50000.0,0.6193000078201294,1.8259567022323608,10000.0,53590.50220036507,55466.92384791374,53590.50220036507,1866.427173614502,4.452563047409058,0.0 -159000,4.5050154,0.7840423,,,,,,,,,,,,,, -159100,4.19439,0.7309401,,,,,,,,,,,,,, -159200,4.352678,0.79339844,,,,,,,,,,,,,, -159300,4.0939193,0.7291873,,,,,,,,,,,,,, -159400,4.477222,0.7197379,,,,,,,,,,,,,, -159500,4.4790235,0.865562,,,,,,,,,,,,,, -159600,4.0970163,0.7662841,,,,,,,,,,,,,, -159700,4.37864,0.69032604,,,,,,,,,,,,,, -159800,3.9562318,0.71337026,,,,,,,,,,,,,, -159900,4.422882,0.75475776,,,,,,,,,,,,,, -160000,4.3492136,0.78112996,,,,,,,,,,,,,, -160100,4.653423,0.79969305,,,,,,,,,,,,,, -160200,4.0148096,0.7368742,,,,,,,,,,,,,, -160300,4.3745637,0.76250786,,,,,,,,,,,,,, -160400,4.341325,0.76002,,,,,,,,,,,,,, -160443,,,0.927355706691742,0.2522351443767547,0.7450199723243713,1.0835787057876587,50000.0,0.6230000257492065,1.8304731845855715,10000.0,54100.62163352966,55994.9102306366,54100.62163352966,1884.1840229034424,4.509575366973877,0.0 -160500,4.7067037,0.7724918,,,,,,,,,,,,,, -160600,3.919291,0.70325905,,,,,,,,,,,,,, -160700,4.3982687,0.76811516,,,,,,,,,,,,,, -160800,4.196502,0.7976754,,,,,,,,,,,,,, -160900,4.2546477,0.75054353,,,,,,,,,,,,,, -161000,4.2439003,0.75027347,,,,,,,,,,,,,, -161100,4.236537,0.8064651,,,,,,,,,,,,,, -161200,4.084096,0.66733426,,,,,,,,,,,,,, -161300,4.271732,0.66300696,,,,,,,,,,,,,, -161400,4.419788,0.7268276,,,,,,,,,,,,,, -161500,4.1480026,0.7114157,,,,,,,,,,,,,, -161600,4.411297,0.7877692,,,,,,,,,,,,,, -161700,4.1656456,0.6827632,,,,,,,,,,,,,, -161800,4.2612486,0.6369712,,,,,,,,,,,,,, -161900,4.2178545,0.6972984,,,,,,,,,,,,,, -161957,,,0.928730845451355,0.2452500760555267,0.7450199723243713,1.079473853111267,50000.0,0.6254000067710876,1.832878589630127,10000.0,54610.548907756805,56522.68081855774,54610.548907756805,1901.915011882782,4.5679240226745605,0.0 -162000,4.1229706,0.73795736,,,,,,,,,,,,,, -162100,4.375163,0.7182588,,,,,,,,,,,,,, -162200,4.67381,0.78121495,,,,,,,,,,,,,, -162300,4.8381233,0.74782205,,,,,,,,,,,,,, -162400,4.4379873,0.71238923,,,,,,,,,,,,,, -162500,4.381656,0.7710456,,,,,,,,,,,,,, -162600,4.429055,0.71566945,,,,,,,,,,,,,, -162700,4.4373236,0.7391841,,,,,,,,,,,,,, -162800,4.189902,0.75540465,,,,,,,,,,,,,, -162900,4.503177,0.83499664,,,,,,,,,,,,,, -163000,4.2889037,0.7420008,,,,,,,,,,,,,, -163100,4.3497515,0.82784355,,,,,,,,,,,,,, -163200,4.2908077,0.75118846,,,,,,,,,,,,,, -163300,4.3275523,0.68325084,,,,,,,,,,,,,, -163400,4.42663,0.7300899,,,,,,,,,,,,,, -163471,,,0.9323580861091614,0.2346342206001281,0.7455399632453918,1.0828678607940674,50000.0,0.6281000375747681,1.822871208190918,10000.0,55120.64704847336,57050.35792398453,55120.64704847336,1919.380197048188,4.62821364402771,0.0 -163500,4.1213393,0.6649252,,,,,,,,,,,,,, -163600,4.782632,0.76304895,,,,,,,,,,,,,, -163700,4.2659984,0.7362741,,,,,,,,,,,,,, -163800,4.2014737,0.709309,,,,,,,,,,,,,, -163900,4.605631,0.77052474,,,,,,,,,,,,,, -164000,4.1283746,0.6657724,,,,,,,,,,,,,, -164100,4.786109,0.80965626,,,,,,,,,,,,,, -164200,4.5402145,0.7623609,,,,,,,,,,,,,, -164300,4.5477858,0.7219759,,,,,,,,,,,,,, -164400,4.7972302,0.77011186,,,,,,,,,,,,,, -164500,4.616399,0.7903224,,,,,,,,,,,,,, -164600,4.538315,0.7236311,,,,,,,,,,,,,, -164700,4.6578913,0.7481695,,,,,,,,,,,,,, -164800,4.0118365,0.55350477,,,,,,,,,,,,,, -164900,4.5292606,0.7484837,,,,,,,,,,,,,, -164984,,,0.9518494606018066,0.1781339794397354,0.7479000091552734,1.078954577445984,50000.0,0.6242000460624695,1.831413507461548,10000.0,55630.59670042992,57577.83935189247,55630.59670042992,1936.8064422607424,4.680109977722168,0.0 -165000,4.3115087,0.6895703,,,,,,,,,,,,,, -165100,4.239437,0.70446193,,,,,,,,,,,,,, -165200,4.8727345,0.78197604,,,,,,,,,,,,,, -165300,4.4036784,0.7269149,,,,,,,,,,,,,, -165400,4.1581097,0.67857605,,,,,,,,,,,,,, -165500,4.6072097,0.6961981,,,,,,,,,,,,,, -165600,4.432182,0.6361019,,,,,,,,,,,,,, -165700,4.4602838,0.7059632,,,,,,,,,,,,,, -165800,4.1883097,0.6939422,,,,,,,,,,,,,, -165900,4.0302653,0.685204,,,,,,,,,,,,,, -166000,4.774885,0.76522046,,,,,,,,,,,,,, -166100,3.772886,0.5684427,,,,,,,,,,,,,, -166200,4.447456,0.7796866,,,,,,,,,,,,,, -166300,4.466606,0.73499894,,,,,,,,,,,,,, -166400,4.25174,0.71157044,,,,,,,,,,,,,, -166498,,,0.9480029940605164,0.1860027015209198,0.746999979019165,1.078598976135254,50000.0,0.6274000406265259,1.8304781913757324,10000.0,56140.51879167557,58105.194816827774,56140.51879167557,1954.128136396408,4.73781156539917,0.0 -166500,3.9856865,0.6355285,,,,,,,,,,,,,, -166600,4.6426444,0.7251235,,,,,,,,,,,,,, -166700,4.7734027,0.6849552,,,,,,,,,,,,,, -166800,4.9616985,0.733583,,,,,,,,,,,,,, -166900,4.456893,0.69997525,,,,,,,,,,,,,, -167000,4.5524654,0.72421014,,,,,,,,,,,,,, -167100,4.6563764,0.7533301,,,,,,,,,,,,,, -167200,4.471044,0.68006617,,,,,,,,,,,,,, -167300,4.253878,0.70932215,,,,,,,,,,,,,, -167400,4.252145,0.65617174,,,,,,,,,,,,,, -167500,4.637588,0.70558727,,,,,,,,,,,,,, -167600,4.1092763,0.6199726,,,,,,,,,,,,,, -167700,4.048186,0.6675905,,,,,,,,,,,,,, -167800,4.207988,0.7070611,,,,,,,,,,,,,, -167900,4.8942947,0.7661686,,,,,,,,,,,,,, -168000,4.7390914,0.71975166,,,,,,,,,,,,,, -168012,,,0.9465281963348388,0.1897157132625579,0.7482799887657166,1.0680350065231323,50000.0,0.626800000667572,1.822851657867432,10000.0,56650.4464943409,58632.931473732,56650.4464943409,1971.827444314957,4.794127941131592,0.0 -168100,4.1684036,0.683794,,,,,,,,,,,,,, -168200,4.649034,0.67002356,,,,,,,,,,,,,, -168300,4.4487667,0.7104896,,,,,,,,,,,,,, -168400,4.028717,0.65829575,,,,,,,,,,,,,, -168500,4.5973277,0.69702387,,,,,,,,,,,,,, -168600,4.73644,0.7186611,,,,,,,,,,,,,, -168700,4.509205,0.69263285,,,,,,,,,,,,,, -168800,4.551193,0.7438587,,,,,,,,,,,,,, -168900,4.352741,0.67845905,,,,,,,,,,,,,, -169000,4.1602035,0.69352067,,,,,,,,,,,,,, -169100,4.448203,0.67375463,,,,,,,,,,,,,, -169200,4.3739266,0.71824753,,,,,,,,,,,,,, -169300,4.483713,0.6409933,,,,,,,,,,,,,, -169400,4.3631086,0.6155369,,,,,,,,,,,,,, -169500,4.6234865,0.7304706,,,,,,,,,,,,,, -169526,,,0.9483816623687744,0.1834833025932312,0.7514199614524841,1.0675344467163086,50000.0,0.627500057220459,1.824575662612915,10000.0,57160.48657393456,59160.888498306274,57160.48657393456,1989.6279754638672,4.857915878295898,0.0 -169600,4.4243965,0.6887211,,,,,,,,,,,,,, -169700,4.7059083,0.69088113,,,,,,,,,,,,,, -169800,4.1321273,0.6120745,,,,,,,,,,,,,, -169900,4.4759774,0.6723779,,,,,,,,,,,,,, -170000,4.7124233,0.6994495,,,,,,,,,,,,,, -170100,4.505185,0.66664374,,,,,,,,,,,,,, -170200,4.4953527,0.6969395,,,,,,,,,,,,,, -170300,4.4395585,0.648322,,,,,,,,,,,,,, -170400,4.2206964,0.5707834,,,,,,,,,,,,,, -170500,4.5163493,0.66873956,,,,,,,,,,,,,, -170600,4.277311,0.7059509,,,,,,,,,,,,,, -170700,4.1532664,0.6516354,,,,,,,,,,,,,, -170800,4.5178146,0.69160366,,,,,,,,,,,,,, -170900,4.3111863,0.6543589,,,,,,,,,,,,,, -171000,4.461961,0.6477135,,,,,,,,,,,,,, -171040,,,0.950215220451355,0.1781926602125167,0.7507599592208862,1.0668600797653198,50000.0,0.6289000511169434,1.8254389762878416,10000.0,57670.477942705154,59689.33236479759,57670.477942705154,2007.9708423614504,4.914608955383301,0.0 -171100,4.4222937,0.65643114,,,,,,,,,,,,,, -171200,4.2102175,0.6344004,,,,,,,,,,,,,, -171300,4.286723,0.6695389,,,,,,,,,,,,,, -171400,4.3285966,0.66247666,,,,,,,,,,,,,, -171500,5.280111,0.73875165,,,,,,,,,,,,,, -171600,4.1846943,0.5892317,,,,,,,,,,,,,, -171700,4.559212,0.745322,,,,,,,,,,,,,, -171800,4.530733,0.65119404,,,,,,,,,,,,,, -171900,4.7892365,0.67593443,,,,,,,,,,,,,, -172000,4.7708883,0.75126034,,,,,,,,,,,,,, -172100,4.2831063,0.6482376,,,,,,,,,,,,,, -172200,4.5738273,0.71566164,,,,,,,,,,,,,, -172300,4.45585,0.6960131,,,,,,,,,,,,,, -172400,4.159507,0.58389324,,,,,,,,,,,,,, -172500,4.893854,0.64039975,,,,,,,,,,,,,, -172554,,,0.9512715339660645,0.1751319766044616,0.7514399886131287,1.0617859363555908,50000.0,0.6328000426292419,1.818946361541748,10000.0,58180.52669739723,60217.08939146996,58180.52669739723,2025.5694625377653,4.971103191375732,0.0 -172600,4.6003613,0.66674423,,,,,,,,,,,,,, -172700,4.4868584,0.649583,,,,,,,,,,,,,, -172800,4.487968,0.6286116,,,,,,,,,,,,,, -172900,4.6227617,0.63613963,,,,,,,,,,,,,, -173000,4.7394,0.63397634,,,,,,,,,,,,,, -173100,4.295481,0.6744726,,,,,,,,,,,,,, -173200,4.743475,0.70437473,,,,,,,,,,,,,, -173300,4.756542,0.6809252,,,,,,,,,,,,,, -173400,4.488353,0.6257187,,,,,,,,,,,,,, -173500,4.458154,0.60468006,,,,,,,,,,,,,, -173600,4.5540605,0.5923759,,,,,,,,,,,,,, -173700,4.839638,0.6474334,,,,,,,,,,,,,, -173800,4.412545,0.6705384,,,,,,,,,,,,,, -173900,4.4897485,0.66562384,,,,,,,,,,,,,, -174000,4.4338326,0.6102484,,,,,,,,,,,,,, -174068,,,0.9608178734779358,0.1495468467473983,0.7513999938964844,1.0650503635406494,50000.0,0.6322000026702881,1.8160579204559328,10000.0,58690.58052444458,60744.659641981125,58690.58052444458,2042.9718658924105,5.033244848251343,0.0 -174100,4.848406,0.6371241,,,,,,,,,,,,,, -174200,4.7144866,0.6568332,,,,,,,,,,,,,, -174300,4.4253716,0.7045787,,,,,,,,,,,,,, -174400,4.6206055,0.65481234,,,,,,,,,,,,,, -174500,4.468303,0.5815454,,,,,,,,,,,,,, -174600,4.478845,0.62389517,,,,,,,,,,,,,, -174700,4.3813696,0.6162063,,,,,,,,,,,,,, -174800,4.581236,0.6738718,,,,,,,,,,,,,, -174900,4.3299294,0.64718384,,,,,,,,,,,,,, -175000,4.4387107,0.6628642,,,,,,,,,,,,,, -175100,4.2458525,0.6271281,,,,,,,,,,,,,, -175200,4.466261,0.66235524,,,,,,,,,,,,,, -175300,4.4842777,0.67364,,,,,,,,,,,,,, -175400,4.3798904,0.68337804,,,,,,,,,,,,,, -175500,4.616176,0.6479951,,,,,,,,,,,,,, -175582,,,0.958227038383484,0.1554597616195678,0.7520999908447266,1.0633561611175537,50000.0,0.6325000524520874,1.820577621459961,10000.0,59200.59376168251,61272.38275671005,59200.59376168251,2060.5704913139343,5.091560125350952,0.0 -175600,4.142717,0.58227044,,,,,,,,,,,,,, -175700,4.434039,0.5724314,,,,,,,,,,,,,, -175800,4.6249213,0.6932404,,,,,,,,,,,,,, -175900,4.6400766,0.5892945,,,,,,,,,,,,,, -176000,4.5378914,0.6670633,,,,,,,,,,,,,, -176100,4.534222,0.6800349,,,,,,,,,,,,,, -176200,4.5573506,0.6206094,,,,,,,,,,,,,, -176300,4.664262,0.69045293,,,,,,,,,,,,,, -176400,4.504374,0.6167592,,,,,,,,,,,,,, -176500,4.499817,0.6174116,,,,,,,,,,,,,, -176600,4.229496,0.5685679,,,,,,,,,,,,,, -176700,4.2813573,0.62423253,,,,,,,,,,,,,, -176800,4.374528,0.56060356,,,,,,,,,,,,,, -176900,4.647959,0.6519566,,,,,,,,,,,,,, -177000,4.4909763,0.6485603,,,,,,,,,,,,,, -177096,,,0.9596021771430968,0.1510808914899826,0.7536599636077881,1.0556384325027466,50000.0,0.6310000419616699,1.812854528427124,10000.0,59710.533801317215,61800.26037359238,59710.533801317215,2078.398533344269,5.148442506790161,0.0 -177100,4.0195,0.5925612,,,,,,,,,,,,,, -177200,4.9665475,0.6991036,,,,,,,,,,,,,, -177300,4.556619,0.6033248,,,,,,,,,,,,,, -177400,4.6309075,0.66046286,,,,,,,,,,,,,, -177500,4.2321033,0.56853527,,,,,,,,,,,,,, -177600,4.586182,0.6671848,,,,,,,,,,,,,, -177700,4.9628286,0.7064694,,,,,,,,,,,,,, -177800,4.530781,0.63399243,,,,,,,,,,,,,, -177900,4.835795,0.6144823,,,,,,,,,,,,,, -178000,4.268508,0.61707747,,,,,,,,,,,,,, -178100,4.4828663,0.5619826,,,,,,,,,,,,,, -178200,4.232826,0.5812722,,,,,,,,,,,,,, -178300,4.39271,0.629923,,,,,,,,,,,,,, -178400,4.4313498,0.64410645,,,,,,,,,,,,,, -178500,4.622337,0.65773004,,,,,,,,,,,,,, -178600,4.388506,0.6385698,,,,,,,,,,,,,, -178610,,,0.95902419090271,0.1536361575126648,0.7538599967956543,1.0557180643081665,50000.0,0.6331000328063965,1.810882329940796,10000.0,60220.49709105492,62328.04130482674,60220.49709105492,2096.097542285919,5.214344024658203,0.0 -178700,4.2220635,0.58903056,,,,,,,,,,,,,, -178800,4.6214623,0.66715246,,,,,,,,,,,,,, -178900,4.654928,0.6321581,,,,,,,,,,,,,, -179000,4.730133,0.64835083,,,,,,,,,,,,,, -179100,4.3729053,0.52176255,,,,,,,,,,,,,, -179200,4.5457773,0.6130741,,,,,,,,,,,,,, -179300,4.383505,0.5596387,,,,,,,,,,,,,, -179400,4.8306828,0.6560162,,,,,,,,,,,,,, -179500,4.8779135,0.6552594,,,,,,,,,,,,,, -179600,4.5215893,0.6772495,,,,,,,,,,,,,, -179700,4.4801555,0.6058127,,,,,,,,,,,,,, -179800,4.255951,0.6319612,,,,,,,,,,,,,, -179900,4.3654523,0.67295814,,,,,,,,,,,,,, -180000,4.113813,0.5290624,,,,,,,,,,,,,, -180100,4.405206,0.5876495,,,,,,,,,,,,,, -180123,,,0.9585060477256776,0.1502209901809692,0.7537199854850769,1.0573610067367554,50000.0,0.6345000267028809,1.8141127824783323,10000.0,60730.41804790497,62855.74048471451,60730.41804790497,2113.768192052841,5.269103527069092,0.0 -180200,4.6120286,0.6134532,,,,,,,,,,,,,, -180300,4.670912,0.67003894,,,,,,,,,,,,,, -180400,4.845346,0.6856934,,,,,,,,,,,,,, -180500,4.423167,0.6420341,,,,,,,,,,,,,, -180600,4.124637,0.5788629,,,,,,,,,,,,,, -180700,4.745647,0.6299854,,,,,,,,,,,,,, -180800,4.9330416,0.65425164,,,,,,,,,,,,,, -180900,4.5062356,0.67408603,,,,,,,,,,,,,, -181000,4.454658,0.63789445,,,,,,,,,,,,,, -181100,4.5778646,0.6863461,,,,,,,,,,,,,, -181200,5.0740023,0.69497776,,,,,,,,,,,,,, -181300,4.6930704,0.56260085,,,,,,,,,,,,,, -181400,4.649234,0.6623568,,,,,,,,,,,,,, -181500,4.6790304,0.6411753,,,,,,,,,,,,,, -181600,4.3910227,0.647556,,,,,,,,,,,,,, -181637,,,0.9604192972183228,0.1468931436538696,0.7546799778938293,1.0567476749420166,50000.0,0.633400022983551,1.8171926736831665,10000.0,61240.57641124725,63383.426446676254,61240.57641124725,2131.186047077179,5.325862407684326,0.0 -181700,4.5562954,0.66696906,,,,,,,,,,,,,, -181800,4.5948277,0.6531884,,,,,,,,,,,,,, -181900,4.4390707,0.61765236,,,,,,,,,,,,,, -182000,4.368404,0.6430066,,,,,,,,,,,,,, -182100,4.6138716,0.5993149,,,,,,,,,,,,,, -182200,4.7442837,0.6027459,,,,,,,,,,,,,, -182300,4.582229,0.63273203,,,,,,,,,,,,,, -182400,4.5812736,0.62602174,,,,,,,,,,,,,, -182500,4.3742104,0.6101172,,,,,,,,,,,,,, -182600,3.9878368,0.5549427,,,,,,,,,,,,,, -182700,4.7219496,0.6883252,,,,,,,,,,,,,, -182800,4.4599786,0.5836948,,,,,,,,,,,,,, -182900,4.4546766,0.6224523,,,,,,,,,,,,,, -183000,4.561045,0.6422065,,,,,,,,,,,,,, -183100,4.9048185,0.6488973,,,,,,,,,,,,,, -183151,,,0.9615154266357422,0.1437261253595352,0.7544800043106079,1.0552575588226318,50000.0,0.6343000531196594,1.8140968084335327,10000.0,61750.644112825394,63911.53665685654,61750.644112825394,2149.111617088318,5.389604806900024,0.0 -183200,4.696652,0.61576533,,,,,,,,,,,,,, -183300,4.3851466,0.5970209,,,,,,,,,,,,,, -183400,4.325351,0.5734288,,,,,,,,,,,,,, -183500,4.1601644,0.59085757,,,,,,,,,,,,,, -183600,4.7292023,0.62826455,,,,,,,,,,,,,, -183700,4.1149764,0.59510505,,,,,,,,,,,,,, -183800,4.753674,0.67224544,,,,,,,,,,,,,, -183900,4.343579,0.58667874,,,,,,,,,,,,,, -184000,4.603429,0.6492386,,,,,,,,,,,,,, -184100,4.5472345,0.6365735,,,,,,,,,,,,,, -184200,4.8477855,0.6271953,,,,,,,,,,,,,, -184300,4.3707647,0.62385637,,,,,,,,,,,,,, -184400,4.6113787,0.5829914,,,,,,,,,,,,,, -184500,4.415352,0.6034785,,,,,,,,,,,,,, -184600,4.4098563,0.59708136,,,,,,,,,,,,,, -184665,,,0.961316168308258,0.1466423720121383,0.7540599703788757,1.0541568994522097,50000.0,0.6353000402450562,1.81319522857666,10000.0,62260.7015068531,64439.141348838806,62260.7015068531,2166.556547164917,5.439829349517822,0.0 -184700,4.471392,0.59618735,,,,,,,,,,,,,, -184800,4.851125,0.6158637,,,,,,,,,,,,,, -184900,4.543128,0.6902498,,,,,,,,,,,,,, -185000,4.8043923,0.61110264,,,,,,,,,,,,,, -185100,4.362865,0.571815,,,,,,,,,,,,,, -185200,4.157391,0.59813595,,,,,,,,,,,,,, -185300,4.3870463,0.5737334,,,,,,,,,,,,,, -185400,4.5473866,0.60264426,,,,,,,,,,,,,, -185500,4.3034472,0.5829023,,,,,,,,,,,,,, -185600,4.3157783,0.7009779,,,,,,,,,,,,,, -185700,4.581582,0.6362631,,,,,,,,,,,,,, -185800,4.133627,0.6122631,,,,,,,,,,,,,, -185900,4.3318167,0.6064492,,,,,,,,,,,,,, -186000,4.0833364,0.623132,,,,,,,,,,,,,, -186100,4.2938657,0.5943944,,,,,,,,,,,,,, -186179,,,0.9597018361091614,0.1491877883672714,0.7537399530410767,1.0555750131607056,50000.0,0.634600043296814,1.8140379190444944,10000.0,62770.74982833862,64966.6444914341,62770.74982833862,2183.882416248321,5.515969038009644,0.0 -186200,4.40441,0.6628542,,,,,,,,,,,,,, -186300,4.454769,0.6091709,,,,,,,,,,,,,, -186400,4.615877,0.6047577,,,,,,,,,,,,,, -186500,4.37877,0.5991163,,,,,,,,,,,,,, -186600,4.407118,0.6475784,,,,,,,,,,,,,, -186666,,,0.9602798223495485,0.1457605808973312,0.75409996509552,1.054492473602295,50000.0,0.6342000365257263,1.8139417171478271,10000.0,62934.7845287323,65148.003600120544,62934.7845287323,2201.129693508148,5.575985908508301,0.0 -186666,,,,,,,,,,,62934.7845287323,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index d2c178c81..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.396591901779175,0.0,42.824132204055786,1,0,42.824132204055786,0.0010000000474974,6.907756805419922,10000,83.2208137512207,0.0007617187220603,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -61.98973369598389,0.0296323299407959,462.9493408203125,890,0,462.9493408203125,0.0100000007078051,6.503350734710693,10000,525.0217719078064,0.0124023435637354,6.457183837890625,0.0122400000691413,6.465394973754883,50000 -83.87272596359253,0.0612773895263671,883.1624338626862,1829,0,883.1624338626862,0.0310000013560056,5.997738838195801,10000,967.2049548625946,0.0393554680049419,5.864162445068359,0.0392399989068508,5.889999866485596,50000 -105.64576721191406,0.0901827812194824,1303.3358738422394,2769,0,1303.3358738422394,0.0510000027716159,5.634537220001221,10000,1409.2352724075315,0.0716210901737213,5.389620780944824,0.0655199959874153,5.453647136688232,50000 -127.48762011528017,0.1172609329223632,1723.3727452754974,3710,0,1723.3727452754974,0.074800007045269,5.322034358978272,10000,1851.19499373436,0.1043164059519767,5.0515570640563965,0.0981599986553192,5.091566562652588,50000 -149.46379971504211,0.1450655460357666,2143.325270175934,4645,0,2143.325270175934,0.0993000045418739,5.037469387054443,10000,2293.20528960228,0.1397460848093032,4.694349765777588,0.1315799951553344,4.74459171295166,50000 -171.33267450332642,0.1757004261016845,2563.570514202118,5578,0,2563.570514202118,0.1262000054121017,4.740323066711426,10000,2735.4029870033264,0.1884570270776748,4.301875114440918,0.1712799966335296,4.40316104888916,50000 -196.30260586738584,0.2066588401794433,2983.755667209625,6510,0,2983.755667209625,0.163000002503395,4.438085556030273,10000,3180.642364740372,0.2281054705381393,3.9725728034973145,0.2132799923419952,4.052278518676758,50000 -220.32091236114505,0.2356452941894531,3403.835516691208,7447,0,3403.835516691208,0.1923000067472458,4.209742546081543,10000,3624.823510885239,0.2765820324420929,3.669283628463745,0.2522999942302704,3.7821102142333984,50000 -247.93834471702576,0.2755928039550781,3823.778416156769,8384,0,3823.778416156769,0.2238000035285949,3.993281841278076,10000,4072.47727060318,0.3094335794448852,3.4023940563201904,0.2844800055027008,3.535698652267456,50000 -276.0673451423645,0.3065640926361084,4243.810468912125,9320,0,4243.810468912125,0.2459000051021576,3.8278791904449454,10000,4520.722212314606,0.3596288859844208,3.166682004928589,0.3202399909496307,3.347395896911621,50000 -307.3923919200897,0.3343639373779297,4664.097786426544,10253,0,4664.097786426544,0.2634000182151794,3.720574378967285,10000,4972.415317296982,0.3677538931369781,3.1120121479034424,0.3455999791622162,3.226508617401123,50000 -335.6350944042206,0.3666074275970459,5084.326451063156,11185,0,5084.326451063156,0.2808000147342682,3.5686235427856445,10000,5420.9716629982,0.3995117247104645,2.8868730068206787,0.3680599927902221,3.0476648807525635,50000 -366.0485026836395,0.3962712287902832,5504.530478954315,12117,0,5504.530478954315,0.2963000237941742,3.4384853839874268,10000,5871.672466278076,0.4317578077316284,2.685864210128784,0.3901999890804291,2.880558729171753,50000 -396.19059109687805,0.429196834564209,5924.697474718094,13044,0,5924.697474718094,0.3115000128746032,3.408851385116577,10000,6322.066045045853,0.4320703148841858,2.7094104290008545,0.4023999869823456,2.8486061096191406,50000 -431.9696838855744,0.46321702003479,6344.838493108749,13967,0,6344.838493108749,0.3273000121116638,3.2973392009735107,10000,6778.072106599808,0.4513085782527923,2.5765609741210938,0.4201000034809112,2.7298953533172607,50000 -466.9709808826447,0.489241361618042,6764.83682847023,14896,0,6764.83682847023,0.3342000246047973,3.250800371170044,10000,7233.150120258331,0.4676562249660492,2.475290536880493,0.4269999861717224,2.6804628372192383,50000 -505.29772305488586,0.520289421081543,7185.151596069336,15826,0,7185.151596069336,0.3340000212192535,3.236973285675049,10000,7691.874465227127,0.4655663967132568,2.52592134475708,0.435619980096817,2.66223406791687,50000 -545.1009593009949,0.5469293594360352,7605.435419559479,16754,0,7605.435419559479,0.3508000075817108,3.1780033111572266,10000,8152.040428161621,0.4782226383686065,2.447401762008667,0.4474399983882904,2.5987234115600586,50000 -586.1966207027435,0.5725915431976318,8025.745712041855,17682,0,8025.745712041855,0.3485000133514404,3.146413326263428,10000,8613.524369716644,0.4890234172344208,2.3615238666534424,0.4515799880027771,2.543088912963867,50000 -625.3753478527069,0.6053094863891602,8445.693154335022,18605,0,8445.693154335022,0.3604000210762024,3.0848608016967773,10000,9072.734701156616,0.5197656154632568,2.212848901748657,0.4608999788761139,2.4904685020446777,50000 -666.6394157409668,0.6336965560913086,8865.636800050735,19531,0,8865.636800050735,0.3640000224113464,3.088551998138428,10000,9534.022550106049,0.5014843344688416,2.333058357238769,0.4675999879837036,2.486729145050049,50000 -705.9051666259766,2.38281798362732,9284.03717637062,20451,0,9284.03717637062,0.3753000199794769,3.02260684967041,10000,9993.490169763563,0.5143749713897705,2.221330165863037,0.4759399890899658,2.4047062397003174,50000 -743.9496030807495,2.413090705871582,9703.97527360916,21375,0,9703.97527360916,0.3745000064373016,3.008895874023437,10000,10451.55467915535,0.5318750143051147,2.146445274353028,0.4820199906826019,2.3802831172943115,50000 -778.7204098701477,2.43961763381958,10124.1838889122,22298,0,10124.1838889122,0.3839000165462494,2.934317111968994,10000,10906.612278938292,0.525390625,2.1639983654022217,0.4918799996376037,2.320196866989136,50000 -815.3714473247528,2.4672508239746094,10544.424010038376,23222,0,10544.424010038376,0.3893000185489654,2.94062876701355,10000,11363.582971572876,0.5257226228713989,2.1706438064575195,0.4936599731445312,2.339932680130005,50000 -851.4152765274048,2.494343042373657,10964.40007162094,24147,0,10964.40007162094,0.3957000076770782,2.886673212051392,10000,11819.682758808136,0.5456640720367432,2.072561502456665,0.5017399787902832,2.2792351245880127,50000 -887.2221512794495,2.52208948135376,11384.33907365799,25071,0,11384.33907365799,0.4038000106811523,2.85197377204895,10000,12275.507807016373,0.5464843511581421,2.0741374492645264,0.5124599933624268,2.23947811126709,50000 -924.0465886592864,2.554238557815552,11804.526224374771,25985,0,11804.526224374771,0.4057000279426574,2.873729705810547,10000,12732.602709293364,0.5458202958106995,2.107918739318848,0.5095399618148804,2.2755002975463867,50000 -959.8301196098328,2.5867295265197754,12224.458804368973,26908,0,12224.458804368973,0.4059000313282013,2.844035387039185,10000,13188.403272867205,0.554394543170929,2.04179310798645,0.5141599774360657,2.2237548828125,50000 -995.5060038566588,2.615927696228028,12644.591686487198,27831,0,12644.591686487198,0.4076000154018402,2.831859588623047,10000,13644.293235778809,0.5831835865974426,1.937749147415161,0.5192999839782715,2.2157299518585205,50000 -1032.0707716941831,2.6466317176818848,13064.928541898727,28754,0,13064.928541898727,0.415800005197525,2.801346063613892,10000,14101.277722358704,0.56103515625,2.0178725719451904,0.5253599882125854,2.184187412261963,50000 -1067.7865002155304,2.6822452545166016,13485.01470541954,29679,0,13485.01470541954,0.4166000187397003,2.762173652648926,10000,14557.166967391968,0.5685741901397705,1.968665719032288,0.5270999670028687,2.1590123176574707,50000 -1103.7031581401825,2.720014333724976,13905.378865480425,30601,0,13905.378865480425,0.4199000298976898,2.761985540390014,10000,15013.537324428558,0.5899999737739563,1.8909555673599243,0.5362399816513062,2.1361539363861084,50000 -1139.8854219913485,2.751550912857056,14325.58296585083,31525,0,14325.58296585083,0.4225000143051147,2.7142233848571777,10000,15470.007066965103,0.5753320455551147,1.9117754697799685,0.5366199612617493,2.086362838745117,50000 -1175.7121279239657,2.7820372581481934,14745.557694911957,32447,0,14745.557694911957,0.4178000092506408,2.760493278503418,10000,15925.88976097107,0.5731250047683716,1.933479070663452,0.5350800156593323,2.114149808883667,50000 -1211.0903453826904,2.8150620460510254,15165.785947084429,33367,0,15165.785947084429,0.4253000319004059,2.7308449745178223,10000,16381.58100271225,0.5793749690055847,1.8931059837341309,0.5353599786758423,2.102208614349365,50000 -1247.915519475937,2.8516042232513428,15586.13660979271,34289,0,15586.13660979271,0.4348000288009643,2.68808913230896,10000,16838.84478354454,0.5872460603713989,1.8730449676513672,0.5454199910163879,2.0598084926605225,50000 -1283.7321796417236,2.881927490234375,16006.159181833267,35212,0,16006.159181833267,0.4269000291824341,2.722228527069092,10000,17294.766576051712,0.583691418170929,1.88885509967804,0.5408999919891357,2.083228588104248,50000 -1319.5742392539978,2.9135375022888184,16426.25867486,36132,0,16426.25867486,0.4353000223636627,2.674314498901367,10000,17750.791348934174,0.5895312428474426,1.8599177598953247,0.5429800152778625,2.068998336791992,50000 -1355.6950623989103,2.946035861968994,16846.513379335403,37053,0,16846.513379335403,0.4245000183582306,2.73975157737732,10000,18207.25081396103,0.5921093821525574,1.889419674873352,0.5410400032997131,2.1071622371673584,50000 -1391.6465280056,2.97900915145874,17266.63311100006,37974,0,17266.63311100006,0.4247000217437744,2.710883617401123,10000,18663.40659666061,0.5863280892372131,1.8857553005218504,0.5454800128936768,2.069470167160034,50000 -1428.1919829845428,3.0127615928649902,17686.921609401703,38897,0,17686.921609401703,0.4341000318527221,2.703827381134033,10000,19120.326916456223,0.5941210985183716,1.883009552955628,0.5507400035858154,2.086865186691284,50000 -1464.3569984436035,3.0470962524414062,18107.20042920113,39821,0,18107.20042920113,0.4416000247001648,2.6214611530303955,10000,19576.857031822205,0.6173437237739563,1.7155250310897827,0.555620014667511,2.007648706436157,50000 -1500.6704943180084,3.0773818492889404,18527.43185925484,40743,0,18527.43185925484,0.445000022649765,2.613339900970459,10000,20033.48488211632,0.5999609231948853,1.8218204975128167,0.5581799745559692,1.9893065690994265,50000 -1536.0949032306671,3.118054151535034,18947.66690206528,41667,0,18947.66690206528,0.4398000240325928,2.6552605628967285,10000,20489.236602544785,0.5997461080551147,1.8272331953048704,0.5560599565505981,2.0224764347076416,50000 -1573.170075416565,3.149902820587158,19367.82270050049,42591,0,19367.82270050049,0.4319000244140625,2.671928882598877,10000,20946.551624774933,0.6067578196525574,1.7815989255905151,0.5556600093841553,2.026575326919556,50000 -1609.2521879673004,3.1867854595184326,19787.93092918396,43515,0,19787.93092918396,0.45210000872612,2.557616233825684,10000,21402.830537080765,0.6055273413658142,1.7664735317230225,0.5644599795341492,1.9504902362823489,50000 -1644.6554489135742,3.2182297706604004,20207.914827108383,44436,0,20207.914827108383,0.4448000192642212,2.628516912460327,10000,21858.304526090626,0.6005859375,1.8105329275131223,0.5595600008964539,1.997233271598816,50000 -1681.4962322711945,3.2489402294158936,20627.91112852097,45355,0,20627.91112852097,0.4436000287532806,2.6396775245666504,10000,22315.22365355492,0.6087499856948853,1.792636513710022,0.5613799691200256,2.0202178955078125,50000 -1717.3759191036224,3.283371686935425,21047.84006094933,46275,0,21047.84006094933,0.4536000192165375,2.589318037033081,10000,22771.119074821472,0.6075195074081421,1.7662358283996582,0.5708999633789062,1.937433123588562,50000 -1753.7322096824646,3.31660795211792,21468.07667350769,47195,0,21468.07667350769,0.4508000314235687,2.603720664978028,10000,23227.797281980515,0.6070312261581421,1.791407585144043,0.5654999613761902,1.988101363182068,50000 -1789.8139972686768,3.355715274810791,21888.20467019081,48117,0,21888.20467019081,0.4601000249385834,2.536510705947876,10000,23684.09828400612,0.6176171898841858,1.6974457502365112,0.572219967842102,1.9044924974441528,50000 -1825.6325912475584,3.390817880630493,22308.17884039879,49039,0,22308.17884039879,0.4565000236034393,2.5707848072052,10000,24139.977380990986,0.6376562118530273,1.6303468942642212,0.5713199973106384,1.924487590789795,50000 -1862.3381762504573,3.424906015396118,22728.231050014496,49960,0,22728.231050014496,0.4570000171661377,2.5590219497680664,10000,24596.820907592773,0.6035937070846558,1.750400185585022,0.5683599710464478,1.9232509136199951,50000 -1895.92019033432,3.456855535507202,23148.14708518982,50881,0,23148.14708518982,0.4577000141143799,2.5869696140289307,10000,25050.40443754196,0.6149609088897705,1.7614407539367676,0.5702599883079529,1.953475832939148,50000 -1932.593270778656,3.4931063652038574,23568.13712787628,51802,0,23568.13712787628,0.45660001039505,2.557199239730835,10000,25507.155699014664,0.6318554282188416,1.6667355298995972,0.5773599743843079,1.915133237838745,50000 -1969.695535421372,3.5249931812286377,23988.11474442482,52724,0,23988.11474442482,0.4574000239372253,2.535064697265625,10000,25964.319465875626,0.6173242330551147,1.7071141004562378,0.578220009803772,1.8849382400512693,50000 -2005.4340209960933,3.5617029666900635,24408.527057886124,53645,0,24408.527057886124,0.4635000228881836,2.5295896530151367,10000,26420.55820798874,0.6204687356948853,1.695976972579956,0.5777400135993958,1.8915741443634035,50000 -2043.0168118476868,3.5952277183532715,24828.74754881859,54568,0,24828.74754881859,0.4611000120639801,2.53002667427063,10000,26878.447025060654,0.6335741877555847,1.6409857273101809,0.5818600058555603,1.8811854124069207,50000 -2079.7810554504395,3.627671003341675,25248.86332678795,55491,0,25248.86332678795,0.4648000299930572,2.4718806743621826,10000,27335.411857128143,0.6231836080551147,1.6743615865707395,0.5826799869537354,1.8473048210144043,50000 -2117.397604942322,3.666386127471924,25669.019901752472,56413,0,25669.019901752472,0.465800017118454,2.5380899906158447,10000,27793.2751955986,0.6209765672683716,1.711039662361145,0.5796599984169006,1.9055297374725344,50000 -2154.234006166458,3.701314687728882,26089.111981153488,57336,0,26089.111981153488,0.4615000188350677,2.495638608932495,10000,28250.289984464645,0.6293359398841858,1.6536279916763306,0.5833799839019775,1.8656693696975708,50000 -2191.714658260345,3.733286857604981,26509.280297517776,58258,0,26509.280297517776,0.4659000337123871,2.48901891708374,10000,28708.02276062965,0.6524804830551147,1.5615637302398682,0.5855000019073486,1.8619376420974727,50000 -2228.310579776764,3.7675628662109375,26929.2906024456,59178,0,26929.2906024456,0.4636000096797943,2.526654005050659,10000,29164.714790582657,0.6221289038658142,1.7175687551498413,0.5812399983406067,1.9062336683273315,50000 -2265.5188434124,3.804023742675781,27349.29820728302,60097,0,27349.29820728302,0.4661000072956085,2.5786662101745605,10000,29622.019852399822,0.6240624785423279,1.7262558937072754,0.5816799998283386,1.931992769241333,50000 -2301.384221792221,3.842408418655396,27769.632704496384,61019,0,27769.632704496384,0.4717000126838684,2.467651128768921,10000,30078.310692310333,0.6532421708106995,1.5591392517089844,0.5931999683380127,1.8360151052474976,50000 -2339.4837741851807,3.8803889751434326,28189.844376802444,61941,0,28189.844376802444,0.4720000326633453,2.4553451538085938,10000,30536.71168017388,0.6307030916213989,1.6259307861328125,0.5897600054740906,1.8175615072250368,50000 -2375.022164583206,3.916743040084839,28609.88588285446,62861,0,28609.88588285446,0.4687000215053558,2.47857403755188,10000,30992.379618406296,0.6349999904632568,1.6504743099212646,0.5909799933433533,1.858465313911438,50000 -2412.3593595027924,3.953082323074341,29029.89266204834,63784,0,29029.89266204834,0.4695000350475311,2.484283208847046,10000,31449.811811208725,0.6424999833106995,1.608486294746399,0.5922999978065491,1.8405706882476809,50000 -2449.3163084983826,3.987928152084351,29449.915155172348,64704,0,29449.915155172348,0.4734000265598297,2.459376335144043,10000,31906.87753367424,0.6398632526397705,1.6252983808517456,0.5931400060653687,1.8280532360076904,50000 -2488.313669919968,4.030225992202759,29870.18850684166,65625,0,29870.18850684166,0.4728000164031982,2.4222564697265625,10000,32366.242438316345,0.6363476514816284,1.6080312728881836,0.596019983291626,1.798905611038208,50000 -2526.093469142914,4.068289279937744,30290.33807373047,66548,0,30290.33807373047,0.4757000207901001,2.425429582595825,10000,32824.26109623909,0.6463476419448853,1.5539718866348269,0.593779981136322,1.7889726161956787,50000 -2563.291362285614,4.101492404937744,30710.64385533333,67470,0,30710.64385533333,0.4787000119686126,2.4377787113189697,10000,33281.84917807579,0.6481249928474426,1.55096435546875,0.5956199765205383,1.7877963781356812,50000 -2600.11688375473,4.13883113861084,31130.57970190048,68391,0,31130.57970190048,0.4754000306129455,2.4577841758728027,10000,33738.69984698296,0.6390038728713989,1.6182537078857422,0.5971599817276001,1.8191072940826416,50000 -2637.507829427719,4.173320293426514,31550.54110193253,69310,0,31550.54110193253,0.4824000298976898,2.421178340911865,10000,34196.13939833641,0.6440038681030273,1.572296142578125,0.5978800058364868,1.780421018600464,50000 -2672.718720436096,4.21485447883606,31970.663256645203,70234,0,31970.663256645203,0.4819000363349914,2.420220375061035,10000,34651.5661213398,0.6702734231948853,1.4641485214233398,0.6001200079917908,1.762584209442139,50000 -2710.3715307712555,4.249778985977173,32390.83561515808,71154,0,32390.83561515808,0.4907000362873077,2.3740594387054443,10000,35109.47786331177,0.6441210508346558,1.5635147094726562,0.6025399565696716,1.7569859027862549,50000 -2748.524088859558,4.293323755264282,32810.90993022919,72077,0,32810.90993022919,0.4802000224590301,2.438359498977661,10000,35567.80102777481,0.6434960961341858,1.6113227605819702,0.6014800071716309,1.8047555685043333,50000 -2787.061530351639,4.332884788513184,33231.24440646172,72996,0,33231.24440646172,0.4825000166893005,2.439271926879883,10000,36026.76460838318,0.6573632955551147,1.5504130125045776,0.5985599756240845,1.8023332357406616,50000 -2824.5923268795013,4.368396282196045,33651.5921421051,73920,0,33651.5921421051,0.486700028181076,2.417206048965454,10000,36484.73053359985,0.6471093893051147,1.6129900217056274,0.6095799803733826,1.7827093601226809,50000 -2861.388499736786,4.406642913818359,34071.80634212494,74841,0,34071.80634212494,0.4857000112533569,2.386080503463745,10000,36941.8312625885,0.6474413871765137,1.5477566719055176,0.606440007686615,1.7396602630615234,50000 -2898.396971464157,4.448941230773926,34492.00156569481,75764,0,34492.00156569481,0.4821000099182129,2.3967623710632324,10000,37399.13010430336,0.6582812070846558,1.5243412256240845,0.6102199554443359,1.744834065437317,50000 -2935.307032585144,4.486681699752808,34912.220725774765,76687,0,34912.220725774765,0.4863000214099884,2.3870294094085693,10000,37856.34967160225,0.6499804258346558,1.5522480010986328,0.6068199872970581,1.75525164604187,50000 -2971.7798516750336,4.523826599121094,35332.43723273277,77611,0,35332.43723273277,0.4870000183582306,2.381280899047852,10000,38313.127844810486,0.655078113079071,1.556774377822876,0.6091799736022949,1.7530089616775513,50000 -3010.292057991028,4.561509370803833,35752.73362803459,78533,0,35752.73362803459,0.4961000382900238,2.359246730804444,10000,38772.0255856514,0.6649804711341858,1.5041078329086304,0.6135199666023254,1.727328181266785,50000 -3047.432544708252,4.599277496337891,36172.88255786896,79454,0,36172.88255786896,0.4874000251293182,2.350943326950073,10000,39229.4043803215,0.6874414086341858,1.3758230209350586,0.6157599687576294,1.6964789628982544,50000 -3085.007307291031,4.6409912109375,36593.105981349945,80378,0,36593.105981349945,0.4880000352859497,2.4137213230133057,10000,39687.296456336975,0.6550585627555847,1.5762782096862793,0.6106799840927124,1.7810405492782593,50000 -3123.217358827591,4.677295207977295,37013.32633161545,81300,0,37013.32633161545,0.4939000308513641,2.3620481491088867,10000,40145.815786361694,0.6625585556030273,1.5070688724517822,0.6144799590110779,1.7260377407073977,50000 -3160.800005197525,4.718563795089722,37433.64666318893,82221,0,37433.64666318893,0.492900013923645,2.362778425216675,10000,40603.81119298935,0.6760546565055847,1.452250361442566,0.6148200035095215,1.7269800901412964,50000 -3199.3191499710083,4.758185863494873,37853.990706920624,83141,0,37853.990706920624,0.5021000504493713,2.319068193435669,10000,41062.76618885994,0.6602343320846558,1.497167468070984,0.6192399859428406,1.6920735836029053,50000 -3237.663080692292,4.796828985214233,38274.122478723526,84061,0,38274.122478723526,0.4988000094890594,2.355710029602051,10000,41521.33239722252,0.6661523580551147,1.5092904567718506,0.6180399656295776,1.7301098108291626,50000 -3275.256284952164,4.83923602104187,38694.25512552261,84983,0,38694.25512552261,0.5003000497817993,2.320410251617432,10000,41979.15235233307,0.6742187142372131,1.436706781387329,0.6194199919700623,1.680558204650879,50000 -3313.068485021591,4.880053997039795,39114.24275946617,85905,0,39114.24275946617,0.4909000098705292,2.3830223083496094,10000,42437.045094013214,0.657031238079071,1.5600054264068604,0.6131199598312378,1.742652177810669,50000 -3351.4402787685394,4.920087099075317,39534.21739983559,86826,0,39534.21739983559,0.4971000254154205,2.3761723041534424,10000,42895.48432254791,0.6669530868530273,1.5276052951812744,0.6200399994850159,1.735192894935608,50000 -3387.36865234375,4.959564685821533,39954.22675919533,87749,0,39954.22675919533,0.5020000338554382,2.299318552017212,10000,43351.513491392136,0.6812499761581421,1.3998364210128784,0.6264599561691284,1.6430705785751345,50000 -3426.085087776184,4.998172760009766,40374.38328671456,88670,0,40374.38328671456,0.5082000494003296,2.318076610565185,10000,43810.47713375092,0.6852734088897705,1.405772089958191,0.6255399584770203,1.6802127361297607,50000 -3463.438698530197,5.036881446838379,40794.58645391464,89592,0,40794.58645391464,0.5103000402450562,2.291046142578125,10000,44268.12525868416,0.6747655868530273,1.4603726863861084,0.6269800066947937,1.6675292253494265,50000 -3501.4375364780426,5.075735569000244,41214.94577026367,90515,0,41214.94577026367,0.5087000131607056,2.2928144931793213,10000,44726.574046611786,0.6829491853713989,1.4198222160339355,0.6295199990272522,1.649441838264465,50000 -3541.404905796051,5.113611698150635,41635.21865081787,91434,0,41635.21865081787,0.5065000057220459,2.301645517349243,10000,45186.90314650536,0.6955859065055847,1.3533570766448977,0.6298800110816956,1.656938910484314,50000 -3578.7778816223145,5.15111517906189,42055.23688745499,92353,0,42055.23688745499,0.5028000473976135,2.322950601577759,10000,45644.38531947136,0.6741992235183716,1.4724483489990234,0.6269999742507935,1.6782883405685425,50000 -3617.549820184708,5.1927220821380615,42475.59681820869,93277,0,42475.59681820869,0.5100000500679016,2.268878698348999,10000,46103.61071944237,0.681933581829071,1.4165855646133425,0.6343399882316589,1.6316964626312256,50000 -3655.330612421036,5.234953880310059,42895.638721227646,94198,0,42895.638721227646,0.5108000040054321,2.293397188186645,10000,46561.52732515335,0.6891406178474426,1.4098796844482422,0.633080005645752,1.6661624908447266,50000 -3692.772596597672,5.277697801589966,43315.5960958004,95115,0,43315.5960958004,0.5033000111579895,2.2706520557403564,10000,47019.02144479752,0.6744726300239563,1.4238649606704712,0.6321200132369995,1.6284260749816897,50000 -3732.062595129013,5.319170951843262,43735.618525505066,96038,0,43735.618525505066,0.501800000667572,2.325078248977661,10000,47478.4274597168,0.6746679544448853,1.4650830030441284,0.6334800124168396,1.658732295036316,50000 -3769.371959209442,5.359732389450073,44155.63729739189,96961,0,44155.63729739189,0.5160000324249268,2.247906446456909,10000,47935.84779524803,0.6908984184265137,1.370441436767578,0.6380000114440918,1.611522197723389,50000 -3804.532338619232,5.398097276687622,44575.90721774101,97883,0,44575.90721774101,0.5092000365257263,2.252286434173584,10000,48391.369034051895,0.6875585913658142,1.3868331909179688,0.6341399550437927,1.6226903200149536,50000 -3842.685833454132,5.437853097915649,44995.98552107811,98806,0,44995.98552107811,0.5145000219345093,2.221693754196167,10000,48849.69229388237,0.6914257407188416,1.358039379119873,0.6440399885177612,1.5768593549728394,50000 -3881.4554891586304,5.483229875564575,45416.36224722862,99725,0,45416.36224722862,0.5164000391960144,2.2455978393554688,10000,49308.9356777668,0.6920703053474426,1.373831272125244,0.6429199576377869,1.6061257123947144,50000 -3919.975238323212,5.5283708572387695,45836.51930832863,100644,0,45836.51930832863,0.522599995136261,2.2153022289276123,10000,49767.70958662033,0.7134960889816284,1.2532609701156616,0.6430999636650085,1.5673333406448364,50000 -3958.030996799469,5.570519685745239,46256.86202788353,101569,0,46256.86202788353,0.5160000324249268,2.235506534576416,10000,50226.20248699188,0.6882616877555847,1.375690460205078,0.6430999636650085,1.592112421989441,50000 -3994.964736223221,5.617609024047852,46676.82009387016,102491,0,46676.82009387016,0.5182999968528748,2.2385599613189697,10000,50683.19379091263,0.695605456829071,1.3783432245254517,0.6459800004959106,1.5988671779632568,50000 -4035.222962141037,5.66266655921936,47097.179240942,103413,0,47097.179240942,0.5222000479698181,2.2103865146636963,10000,51143.90802812576,0.7080858945846558,1.311424970626831,0.6429799795150757,1.5880955457687378,50000 -4075.8489694595337,5.703973770141602,47517.37286019325,104333,0,47517.37286019325,0.5208000540733337,2.272473096847534,10000,51604.82066082954,0.6900585889816284,1.4127657413482666,0.6438199877738953,1.6250956058502195,50000 -4114.217289924622,5.74323320388794,47937.8508181572,105255,0,47937.8508181572,0.5284000039100647,2.170280933380127,10000,52063.75856423378,0.7038280963897705,1.3086973428726196,0.6522600054740906,1.5347038507461548,50000 -4153.619336605072,5.791689872741699,48358.08609056473,106178,0,48358.08609056473,0.5179000496864319,2.258666753768921,10000,52523.49588823319,0.695117175579071,1.3671385049819946,0.638480007648468,1.6232553720474243,50000 -4193.563049793243,5.832031965255737,48778.02906394005,107099,0,48778.02906394005,0.5246000289916992,2.1825060844421387,10000,52983.47541809082,0.7011132836341858,1.3185851573944092,0.6495000123977661,1.5532654523849487,50000 -4231.1956782341,5.872812747955322,49198.16246008873,108022,0,49198.16246008873,0.5370000004768372,2.137031316757202,10000,53441.33406162262,0.7089062333106995,1.286357879638672,0.6575999855995178,1.5152268409729004,50000 -4270.626555204392,5.915415287017822,49618.37132978439,108948,0,49618.37132978439,0.5306000113487244,2.1434850692749023,10000,53901.06821775437,0.7115429639816284,1.2524975538253784,0.6578199863433838,1.5066378116607666,50000 -4310.15029501915,5.962715864181519,50038.27680063248,109871,0,50038.27680063248,0.5297999978065491,2.186164140701294,10000,54360.60054755211,0.7302343845367432,1.2425791025161743,0.6523799896240234,1.5649343729019165,50000 -4349.666513442993,6.006732702255249,50458.470808029175,110790,0,50458.470808029175,0.5329000353813171,2.155395746231079,10000,54820.40697169304,0.7073437571525574,1.2899636030197144,0.6584799885749817,1.5123716592788696,50000 -4386.731926679611,6.05235767364502,50878.71767401695,111713,0,50878.71767401695,0.5364000201225281,2.1200437545776367,10000,55277.81778669357,0.7128515243530273,1.2536503076553345,0.6627599596977234,1.4864120483398438,50000 -4424.739602088928,6.320374011993408,51298.70504426956,112635,0,51298.70504426956,0.5369000434875488,2.1409804821014404,10000,55736.1331076622,0.7286913990974426,1.2007211446762085,0.6609399914741516,1.50294291973114,50000 -4459.790710687637,6.36035680770874,51718.90090465546,113557,0,51718.90090465546,0.532800018787384,2.165210485458374,10000,56191.47248148918,0.7140820026397705,1.3028429746627808,0.6602199673652649,1.5292065143585205,50000 -4498.1347053051,6.402165174484253,52139.10345339775,114479,0,52139.10345339775,0.5388000011444092,2.1023590564727783,10000,56650.11271905899,0.7202343344688416,1.2131118774414062,0.6656999588012695,1.4627410173416138,50000 -4536.924311637878,6.447792291641235,52559.15979242325,115403,0,52559.15979242325,0.539400041103363,2.100294589996338,10000,57109.05661630631,0.7263085842132568,1.1966229677200315,0.6647199988365173,1.4690905809402466,50000 -4576.432106971741,6.493849992752075,52979.50097155571,116327,0,52979.50097155571,0.5482000112533569,2.084333658218384,10000,57569.00668978691,0.7205468416213989,1.2346270084381104,0.6696799993515015,1.4559895992279053,50000 -4613.82989192009,6.539278507232666,53399.68098139763,117250,0,53399.68098139763,0.539900004863739,2.131301164627075,10000,58026.68306350708,0.7180468440055847,1.26100754737854,0.6637200117111206,1.4970401525497437,50000 -4652.822212696075,6.585542917251587,53819.9982714653,118173,0,53819.9982714653,0.5494000315666199,2.0697150230407715,10000,58486.09160208702,0.7331640720367432,1.1802427768707275,0.6723799705505371,1.4435547590255735,50000 -4692.190171718597,6.631305694580078,54240.08188533783,119096,0,54240.08188533783,0.5421000123023987,2.1308164596557617,10000,58945.64118242264,0.7335156202316284,1.199196219444275,0.6699999570846558,1.4819027185440063,50000 -4730.476469278336,6.678221940994263,54660.09229564667,120018,0,54660.09229564667,0.5497000217437744,2.077719211578369,10000,59404.036954164505,0.7240429520606995,1.2205166816711426,0.6732999682426453,1.439631104469299,50000 -4770.4309067726135,6.724869251251221,55080.20759105682,120944,0,55080.20759105682,0.5525000095367432,2.0614137649536133,10000,59864.2053706646,0.73388671875,1.1822478771209717,0.676099956035614,1.439404010772705,50000 -4810.173624038696,6.771934986114502,55500.20226883888,121867,0,55500.20226883888,0.5576000213623047,2.0427939891815186,10000,60324.041732788086,0.7496874928474426,1.114646315574646,0.6793599724769592,1.4229241609573364,50000 -4849.972870588303,6.823628664016724,55920.51446223259,122791,0,55920.51446223259,0.5555000305175781,2.0610427856445312,10000,60784.256432294846,0.7269140481948853,1.2267969846725464,0.6785999536514282,1.4419078826904297,50000 -4887.135262012482,6.871109485626221,56340.72695922852,123715,0,56340.72695922852,0.560699999332428,2.007731914520264,10000,61241.7313709259,0.7386718392372131,1.1407707929611206,0.6824399828910828,1.3840407133102417,50000 -4926.5925216674805,6.9160919189453125,56760.75370979309,124640,0,56760.75370979309,0.5530000329017639,2.0573394298553467,10000,61701.31235575676,0.7439648509025574,1.139055848121643,0.6798799633979797,1.4270832538604736,50000 -4966.275221347809,6.962021350860596,57180.97926735878,125563,0,57180.97926735878,0.5559000372886658,2.040470600128174,10000,62161.31812500954,0.7366601228713989,1.1655316352844238,0.6827999949455261,1.4155220985412598,50000 -5003.563591241837,7.007826805114746,57601.43358683586,126485,0,57601.43358683586,0.566100001335144,1.9890079498291016,10000,62619.15861034393,0.7461132407188416,1.1315586566925049,0.6882199645042419,1.3758423328399658,50000 -5041.106873750687,7.051853656768799,58021.75379371643,127409,0,58021.75379371643,0.5671000480651855,2.030573844909668,10000,63077.11794781685,0.7524218559265137,1.1364513635635376,0.6879400014877319,1.4114627838134766,50000 -5079.05620598793,7.098832368850708,58441.99978637695,128334,0,58441.99978637695,0.5678000450134277,1.993343472480774,10000,63535.41212558746,0.74964839220047,1.1114648580551147,0.689799964427948,1.371135950088501,50000 -5116.010575294495,7.14700722694397,58861.942873716354,129258,0,58861.942873716354,0.5713000297546387,1.966183304786682,10000,63992.40976333618,0.7498828172683716,1.1026932001113892,0.6937400102615356,1.3566606044769287,50000 -5155.474381446838,7.194597005844116,59282.19617843628,130184,0,59282.19617843628,0.5699000358581543,1.9783955812454224,10000,64452.22687602043,0.75355464220047,1.101863980293274,0.6942200064659119,1.3657922744750977,50000 -5197.139461517334,7.241278171539307,59702.45468664169,131104,0,59702.45468664169,0.5657000541687012,1.97388756275177,10000,64914.25153064728,0.7650976181030273,1.0323588848114014,0.6943999528884888,1.344438910484314,50000 -5236.505722999573,7.285628318786621,60122.40482163429,132025,0,60122.40482163429,0.5679000020027161,1.9602190256118768,10000,65373.663570165634,0.7515429258346558,1.0836161375045776,0.6946199536323547,1.3420709371566772,50000 -5274.264638900757,7.332998752593994,60542.771542072296,132951,0,60542.771542072296,0.5760000348091125,1.9236680269241333,10000,65831.88937497139,0.758593738079071,1.0445502996444702,0.6947999596595764,1.3174066543579102,50000 -5314.167495965958,7.379750490188599,60962.8595468998,133876,0,60962.8595468998,0.5711000561714172,1.9531786441802976,10000,66291.97956442833,0.7695898413658142,1.018075704574585,0.6988999843597412,1.3271454572677612,50000 -5353.349631071091,7.431820392608643,61382.80419540405,134798,0,61382.80419540405,0.5707000494003296,1.9330246448516848,10000,66751.209690094,0.7604882717132568,1.0434051752090454,0.6985399723052979,1.303519606590271,50000 -5391.70734000206,7.477010011672974,61803.14104223251,135721,0,61803.14104223251,0.5758000016212463,1.9241626262664795,10000,67210.00192761421,0.7647656202316284,1.0401690006256104,0.6980400085449219,1.3200068473815918,50000 -5431.663053035736,7.522252082824707,62223.05319976807,136644,0,62223.05319976807,0.5790000557899475,1.9301297664642327,10000,67669.96811890602,0.7710546851158142,1.0173821449279783,0.7045800089836121,1.311528205871582,50000 -5470.751053571701,7.575902938842773,62643.12946271896,137568,0,62643.12946271896,0.5751000046730042,1.937660217285156,10000,68129.23905944824,0.7683789134025574,1.041908621788025,0.7035399675369263,1.3146579265594482,50000 -5510.862103939056,7.621830224990845,63063.101344347,138489,0,63063.101344347,0.5884000062942505,1.885142922401428,10000,68589.41958975792,0.7689452767372131,1.025770902633667,0.7057200074195862,1.2924665212631226,50000 -5549.8006637096405,7.670376300811768,63483.11455059052,139413,0,63483.11455059052,0.5861999988555908,1.8938807249069207,10000,69048.47295331955,0.7711523175239563,1.000161051750183,0.7072599530220032,1.2786537408828735,50000 -5589.37908911705,7.717724561691284,63903.165801763535,140336,0,63903.165801763535,0.5909000039100647,1.8893526792526243,10000,69508.20226073265,0.7856835722923279,0.9613086581230164,0.7094999551773071,1.2880979776382446,50000 -5630.997187137604,7.762471675872803,64323.22742891312,141260,0,64323.22742891312,0.5878000259399414,1.8784040212631223,10000,69969.97873592377,0.7727343440055847,1.010166049003601,0.7104399800300598,1.2714622020721436,50000 -5669.488003015518,7.80783200263977,64743.2424428463,142185,0,64743.2424428463,0.5889000296592712,1.8770174980163568,10000,70428.5822839737,0.7766211032867432,0.98188453912735,0.7114999890327454,1.2591363191604614,50000 -5707.855123996735,7.854290246963501,65163.4554746151,143108,0,65163.4554746151,0.5932000279426575,1.8319133520126345,10000,70887.26061153412,0.7864062190055847,0.9273659586906432,0.7152000069618225,1.233481764793396,50000 -5746.891888856888,7.905426979064941,65583.5791592598,144030,0,65583.5791592598,0.593500018119812,1.8458307981491089,10000,71346.5245103836,0.7824804782867432,0.9607452154159546,0.7184999585151672,1.2336714267730713,50000 -5786.692119598389,7.957865715026855,66003.87094473839,144951,0,66003.87094473839,0.5945000052452087,1.8343600034713743,10000,71806.72054195404,0.7843359112739563,0.9414471983909608,0.7197799682617188,1.2216196060180664,50000 -5825.466367721558,8.004544019699097,66423.95629882812,145873,0,66423.95629882812,0.5939000248908997,1.827216863632202,10000,72265.67800307274,0.7897070050239563,0.930291473865509,0.7209399938583374,1.2243276834487915,50000 -5866.13242316246,8.054444074630737,66843.89471817017,146796,0,66843.89471817017,0.5915000438690186,1.837318658828736,10000,72726.38371825218,0.7828710675239563,0.94913649559021,0.7192999720573425,1.2191599607467651,50000 -5907.169671535492,8.107472896575928,67263.90523648262,147716,0,67263.90523648262,0.5963000059127808,1.8374955654144287,10000,73187.53686141968,0.7864453196525574,0.94319087266922,0.7235599756240845,1.2195154428482056,50000 -5948.59770822525,8.155084133148193,67684.12996077538,148639,0,67684.12996077538,0.6018000245094299,1.8382083177566528,10000,73649.2893576622,0.7929491996765137,0.936110258102417,0.7234199643135071,1.2337615489959717,50000 -5990.390378952026,8.203400373458862,68104.28977417946,149560,0,68104.28977417946,0.601900041103363,1.7990987300872805,10000,74111.3414068222,0.8044726252555847,0.8590974807739258,0.725600004196167,1.193511128425598,50000 -6028.878688812256,8.260981321334839,68524.2951323986,150481,0,68524.2951323986,0.6038000583648682,1.7990301847457886,10000,74569.94378137589,0.7959179282188416,0.9110642075538636,0.7285799980163574,1.19636070728302,50000 -6068.592380285263,8.307986736297607,68944.20941090584,151402,0,68944.20941090584,0.6051000356674194,1.761571168899536,10000,75029.67028093338,0.7992382645606995,0.8694639205932617,0.7313199639320374,1.1657003164291382,50000 -6109.569490194321,8.357463836669922,69364.16141724586,152323,0,69364.16141724586,0.6097000241279602,1.7557659149169922,10000,75490.700922966,0.8080468773841858,0.8454297184944153,0.7318399548530579,1.1674693822860718,50000 -6148.683254241943,8.404613018035889,69784.3099322319,153246,0,69784.3099322319,0.614300012588501,1.7566510438919067,10000,75950.06301856041,0.802539050579071,0.863842248916626,0.7321000099182129,1.1585302352905271,50000 -6186.854301929474,8.456630945205688,70204.3026239872,154167,0,70204.3026239872,0.614300012588501,1.755558729171753,10000,76408.33057045937,0.80726557970047,0.8443677425384521,0.7341799736022949,1.159018874168396,50000 -6224.655606746674,8.505312204360962,70624.36257171631,155089,0,70624.36257171631,0.6105000376701355,1.779624342918396,10000,76866.29339528084,0.8079296946525574,0.8516631722450256,0.7345199584960938,1.1670846939086914,50000 -6264.754841089249,8.815968751907349,71044.26182794571,156010,0,71044.26182794571,0.6135000586509705,1.7380512952804563,10000,77326.6550860405,0.8057616949081421,0.8517761826515198,0.7357400059700012,1.147667646408081,50000 -6304.327245473862,8.869405031204224,71464.36837172508,156933,0,71464.36837172508,0.6139000058174133,1.7471004724502563,10000,77786.44018650055,0.8101562261581421,0.8456878662109375,0.7382000088691711,1.148350715637207,50000 -6345.1846034526825,8.920923233032227,71884.33867025375,157854,0,71884.33867025375,0.6210000514984131,1.7272372245788574,10000,78247.37133145332,0.8182226419448853,0.813230574131012,0.7407799959182739,1.141973853111267,50000 -6385.997955322266,8.971666812896729,72304.50023531914,158776,0,72304.50023531914,0.6224000453948975,1.7298548221588137,10000,78708.44903898239,0.8186913728713989,0.811683177947998,0.7415399551391602,1.1460314989089966,50000 -6427.509117841721,9.025804042816162,72724.4717707634,159700,0,72724.4717707634,0.6160000562667847,1.725846529006958,10000,79170.03859901428,0.8161132335662842,0.8156614303588867,0.7422999739646912,1.1298848390579224,50000 -6468.714903116226,9.074469327926636,73144.85194015503,160621,0,73144.85194015503,0.624500036239624,1.7089358568191528,10000,79631.72527456284,0.8175585865974426,0.804296612739563,0.7449600100517273,1.1215741634368896,50000 -6509.093505382538,9.12383794784546,73564.85105657578,161542,0,73564.85105657578,0.6322000026702881,1.680891036987305,10000,80092.20484352112,0.8251757621765137,0.7653958201408386,0.7464799880981445,1.1053043603897097,50000 -6549.888410568237,9.1816246509552,73984.82580113411,162464,0,73984.82580113411,0.6265000104904175,1.694629788398743,10000,80553.0836699009,0.8218163847923279,0.7823770642280579,0.7467399835586548,1.1013247966766355,50000 -6588.652596235275,9.234551668167114,74404.8524620533,163386,0,74404.8524620533,0.6223000288009644,1.7001694440841677,10000,81011.978900671,0.8213866949081421,0.7875846028327942,0.7483199834823608,1.1053085327148438,50000 -6629.552686929703,9.28411316871643,74824.91179394722,164307,0,74824.91179394722,0.6322000026702881,1.6703184843063354,10000,81473.03928542137,0.8302733898162842,0.7504153251647949,0.7503199577331543,1.0849838256835938,50000 -6669.045813798904,9.33809781074524,75245.17643213272,165227,0,75245.17643213272,0.6283000111579895,1.6670862436294556,10000,81932.90296435356,0.8239843845367432,0.7693715691566467,0.7511799931526184,1.0776283740997314,50000 -6708.982246637344,9.390437126159668,75665.09493494034,166150,0,75665.09493494034,0.6307000517845154,1.6644469499588013,10000,82392.86175775528,0.8310937285423279,0.7499398589134216,0.7531399726867676,1.070204496383667,50000 -6749.0644516944885,9.445293426513672,76085.28699398041,167071,0,76085.28699398041,0.6333000063896179,1.6693150997161863,10000,82853.24296784401,0.83056640625,0.7435762882232666,0.7534199953079224,1.0754247903823853,50000 -6786.874273777008,9.495783567428589,76505.26083564758,167992,0,76505.26083564758,0.634600043296814,1.6695160865783691,10000,83311.12951374054,0.8321484327316284,0.7515702843666077,0.7541399598121643,1.0826858282089231,50000 -6827.716981649399,9.548727989196776,76925.41933321953,168913,0,76925.41933321953,0.638700008392334,1.6415356397628784,10000,83772.23526740074,0.8340038657188416,0.7366316318511963,0.7570599913597107,1.0599473714828491,50000 -6867.189259767532,9.600689172744751,77345.37213397026,169834,0,77345.37213397026,0.6380000114440918,1.6476746797561646,10000,84231.7641518116,0.8338086009025574,0.7275217771530151,0.7583400011062622,1.0552046298980713,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index f6cc36ad5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1890 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.33517402,6.907757,,,,,,,,,,,,,, -1,,,0.0007617187220603,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.824132204055786,83.2208137512207,42.824132204055786,40.396591901779175,0.0,0.0 -100,0.3944687,6.905936,,,,,,,,,,,,,, -200,0.4567522,6.894867,,,,,,,,,,,,,, -300,0.5040463,6.868991,,,,,,,,,,,,,, -400,0.6856747,6.8486357,,,,,,,,,,,,,, -500,0.7926198,6.7884073,,,,,,,,,,,,,, -600,0.65703934,6.740931,,,,,,,,,,,,,, -700,0.7858898,6.7175026,,,,,,,,,,,,,, -800,1.0388302,6.647991,,,,,,,,,,,,,, -890,,,0.0124023435637354,6.457183837890625,0.0122400000691413,6.465394973754883,50000.0,0.0100000007078051,6.503350734710693,10000.0,462.9493408203125,525.0217719078064,462.9493408203125,61.98973369598389,0.0296323299407959,0.0 -900,1.6568426,6.63247,,,,,,,,,,,,,, -1000,1.6070427,6.5840735,,,,,,,,,,,,,, -1100,1.4284866,6.5053806,,,,,,,,,,,,,, -1200,1.1995851,6.475441,,,,,,,,,,,,,, -1300,1.1551089,6.4062586,,,,,,,,,,,,,, -1400,2.7579477,6.3889184,,,,,,,,,,,,,, -1500,1.837784,6.376267,,,,,,,,,,,,,, -1600,1.4822022,6.2762003,,,,,,,,,,,,,, -1700,1.6277338,6.277441,,,,,,,,,,,,,, -1800,1.6519711,6.3694015,,,,,,,,,,,,,, -1829,,,0.0393554680049419,5.864162445068359,0.0392399989068508,5.889999866485596,50000.0,0.0310000013560056,5.997738838195801,10000.0,883.1624338626862,967.2049548625946,883.1624338626862,83.87272596359253,0.0612773895263671,0.0 -1900,1.4732711,6.5706534,,,,,,,,,,,,,, -2000,2.259101,6.2142124,,,,,,,,,,,,,, -2100,1.2659309,6.46337,,,,,,,,,,,,,, -2200,1.6723264,6.1429386,,,,,,,,,,,,,, -2300,1.7681768,6.085765,,,,,,,,,,,,,, -2400,1.2223927,6.567953,,,,,,,,,,,,,, -2500,1.3807824,6.6343265,,,,,,,,,,,,,, -2600,1.6499238,6.015642,,,,,,,,,,,,,, -2700,1.6207423,6.0157337,,,,,,,,,,,,,, -2769,,,0.0716210901737213,5.389620780944824,0.0655199959874153,5.453647136688232,50000.0,0.0510000027716159,5.634537220001221,10000.0,1303.3358738422394,1409.2352724075315,1303.3358738422394,105.64576721191406,0.0901827812194824,0.0 -2800,1.4433343,6.592649,,,,,,,,,,,,,, -2900,1.4951707,6.048842,,,,,,,,,,,,,, -3000,2.1412873,5.878791,,,,,,,,,,,,,, -3100,1.5835642,5.9385095,,,,,,,,,,,,,, -3200,1.8471047,6.531931,,,,,,,,,,,,,, -3300,1.8957144,5.877353,,,,,,,,,,,,,, -3400,1.5057442,5.8509746,,,,,,,,,,,,,, -3500,1.6337367,5.9111,,,,,,,,,,,,,, -3600,1.3735598,6.3580375,,,,,,,,,,,,,, -3700,1.742636,5.825474,,,,,,,,,,,,,, -3710,,,0.1043164059519767,5.0515570640563965,0.0981599986553192,5.091566562652588,50000.0,0.074800007045269,5.322034358978272,10000.0,1723.3727452754974,1851.19499373436,1723.3727452754974,127.48762011528017,0.1172609329223632,0.0 -3800,1.7297789,5.787086,,,,,,,,,,,,,, -3900,1.5352964,5.7607994,,,,,,,,,,,,,, -4000,1.6476717,5.6955876,,,,,,,,,,,,,, -4100,1.409554,5.99283,,,,,,,,,,,,,, -4200,1.5428634,5.6699076,,,,,,,,,,,,,, -4300,1.6690813,5.571576,,,,,,,,,,,,,, -4400,1.4018689,6.206125,,,,,,,,,,,,,, -4500,1.8544066,5.5358367,,,,,,,,,,,,,, -4600,1.8107886,5.492282,,,,,,,,,,,,,, -4645,,,0.1397460848093032,4.694349765777588,0.1315799951553344,4.74459171295166,50000.0,0.0993000045418739,5.037469387054443,10000.0,2143.325270175934,2293.20528960228,2143.325270175934,149.46379971504211,0.1450655460357666,0.0 -4700,1.5961668,5.5222826,,,,,,,,,,,,,, -4800,1.6056712,5.47886,,,,,,,,,,,,,, -4900,1.6883291,5.460701,,,,,,,,,,,,,, -5000,1.1609838,6.2917385,,,,,,,,,,,,,, -5100,1.5313162,5.713639,,,,,,,,,,,,,, -5200,1.407798,5.4826083,,,,,,,,,,,,,, -5300,1.7079568,5.336762,,,,,,,,,,,,,, -5400,1.5457431,5.2867374,,,,,,,,,,,,,, -5500,1.6251787,5.3433027,,,,,,,,,,,,,, -5578,,,0.1884570270776748,4.301875114440918,0.1712799966335296,4.40316104888916,50000.0,0.1262000054121017,4.740323066711426,10000.0,2563.570514202118,2735.4029870033264,2563.570514202118,171.33267450332642,0.1757004261016845,0.0 -5600,1.5831336,5.978959,,,,,,,,,,,,,, -5700,1.5547022,5.3265686,,,,,,,,,,,,,, -5800,1.458939,5.134392,,,,,,,,,,,,,, -5900,1.6681092,5.175317,,,,,,,,,,,,,, -6000,1.6376241,5.5128436,,,,,,,,,,,,,, -6100,1.2395114,6.048193,,,,,,,,,,,,,, -6200,1.6833968,5.3674183,,,,,,,,,,,,,, -6300,1.6589489,5.065326,,,,,,,,,,,,,, -6400,1.3850477,6.0769134,,,,,,,,,,,,,, -6500,1.9404573,4.9798007,,,,,,,,,,,,,, -6510,,,0.2281054705381393,3.9725728034973145,0.2132799923419952,4.052278518676758,50000.0,0.163000002503395,4.438085556030273,10000.0,2983.755667209625,3180.642364740372,2983.755667209625,196.30260586738584,0.2066588401794433,0.0 -6600,1.6280379,5.0951633,,,,,,,,,,,,,, -6700,1.3099027,6.214257,,,,,,,,,,,,,, -6800,1.314678,5.9952397,,,,,,,,,,,,,, -6900,1.390951,6.1743565,,,,,,,,,,,,,, -7000,1.5318607,6.1209803,,,,,,,,,,,,,, -7100,2.0307562,5.0016823,,,,,,,,,,,,,, -7200,1.8804468,5.0745344,,,,,,,,,,,,,, -7300,1.6320628,4.869517,,,,,,,,,,,,,, -7400,1.6361185,4.9214015,,,,,,,,,,,,,, -7447,,,0.2765820324420929,3.669283628463745,0.2522999942302704,3.7821102142333984,50000.0,0.1923000067472458,4.209742546081543,10000.0,3403.835516691208,3624.823510885239,3403.835516691208,220.32091236114505,0.2356452941894531,0.0 -7500,1.5397186,5.1101265,,,,,,,,,,,,,, -7600,2.2702484,4.83105,,,,,,,,,,,,,, -7700,1.6236684,4.7844644,,,,,,,,,,,,,, -7800,1.7342823,4.811121,,,,,,,,,,,,,, -7900,1.6204873,4.7615366,,,,,,,,,,,,,, -8000,2.1175847,4.744965,,,,,,,,,,,,,, -8100,1.4225781,4.9848185,,,,,,,,,,,,,, -8200,1.2266932,6.0975266,,,,,,,,,,,,,, -8300,1.627064,4.63911,,,,,,,,,,,,,, -8384,,,0.3094335794448852,3.4023940563201904,0.2844800055027008,3.535698652267456,50000.0,0.2238000035285949,3.993281841278076,10000.0,3823.778416156769,4072.47727060318,3823.778416156769,247.93834471702576,0.2755928039550781,0.0 -8400,1.3369449,5.721242,,,,,,,,,,,,,, -8500,1.4889462,4.935299,,,,,,,,,,,,,, -8600,1.6684306,4.9022665,,,,,,,,,,,,,, -8700,1.1214494,5.991039,,,,,,,,,,,,,, -8800,1.6637275,4.6082745,,,,,,,,,,,,,, -8900,1.3182817,6.1870666,,,,,,,,,,,,,, -9000,1.299927,5.5648336,,,,,,,,,,,,,, -9100,1.5803969,4.5832796,,,,,,,,,,,,,, -9200,1.7751827,4.5646524,,,,,,,,,,,,,, -9300,1.4913995,4.5023518,,,,,,,,,,,,,, -9320,,,0.3596288859844208,3.166682004928589,0.3202399909496307,3.347395896911621,50000.0,0.2459000051021576,3.8278791904449454,10000.0,4243.810468912125,4520.722212314606,4243.810468912125,276.0673451423645,0.3065640926361084,0.0 -9400,1.0412283,5.976206,,,,,,,,,,,,,, -9500,1.8172841,4.421599,,,,,,,,,,,,,, -9600,1.5608932,4.4627867,,,,,,,,,,,,,, -9700,1.4171172,5.345113,,,,,,,,,,,,,, -9800,1.2561456,5.8431373,,,,,,,,,,,,,, -9900,1.3113762,6.223316,,,,,,,,,,,,,, -10000,1.4268923,4.5805836,,,,,,,,,,,,,, -10100,1.6367049,4.4778643,,,,,,,,,,,,,, -10200,1.5578266,4.5634694,,,,,,,,,,,,,, -10253,,,0.3677538931369781,3.1120121479034424,0.3455999791622162,3.226508617401123,50000.0,0.2634000182151794,3.720574378967285,10000.0,4664.097786426544,4972.415317296982,4664.097786426544,307.3923919200897,0.3343639373779297,0.0 -10300,1.6402473,4.4213357,,,,,,,,,,,,,, -10400,1.5249307,4.343973,,,,,,,,,,,,,, -10500,1.6494812,4.481303,,,,,,,,,,,,,, -10600,1.7422875,4.4378123,,,,,,,,,,,,,, -10700,1.3987882,5.2345743,,,,,,,,,,,,,, -10800,1.2554674,5.5710125,,,,,,,,,,,,,, -10900,1.6371896,4.3128686,,,,,,,,,,,,,, -11000,1.3421242,5.3160677,,,,,,,,,,,,,, -11100,1.6047524,4.3145146,,,,,,,,,,,,,, -11185,,,0.3995117247104645,2.8868730068206787,0.3680599927902221,3.0476648807525635,50000.0,0.2808000147342682,3.5686235427856445,10000.0,5084.326451063156,5420.9716629982,5084.326451063156,335.6350944042206,0.3666074275970459,0.0 -11200,1.4320313,4.1929865,,,,,,,,,,,,,, -11300,1.5044566,4.216322,,,,,,,,,,,,,, -11400,1.6259528,4.824007,,,,,,,,,,,,,, -11500,1.0848694,6.038086,,,,,,,,,,,,,, -11600,1.1300865,5.9306707,,,,,,,,,,,,,, -11700,1.587306,4.375807,,,,,,,,,,,,,, -11800,1.1610994,5.0070376,,,,,,,,,,,,,, -11900,1.3607811,4.426067,,,,,,,,,,,,,, -12000,1.4311336,4.1263294,,,,,,,,,,,,,, -12100,1.2267541,5.259185,,,,,,,,,,,,,, -12117,,,0.4317578077316284,2.685864210128784,0.3901999890804291,2.880558729171753,50000.0,0.2963000237941742,3.4384853839874268,10000.0,5504.530478954315,5871.672466278076,5504.530478954315,366.0485026836395,0.3962712287902832,0.0 -12200,1.7315024,4.1544,,,,,,,,,,,,,, -12300,1.445748,4.3670173,,,,,,,,,,,,,, -12400,1.5935549,4.192972,,,,,,,,,,,,,, -12500,1.4608994,4.3627844,,,,,,,,,,,,,, -12600,1.0504589,5.827262,,,,,,,,,,,,,, -12700,1.6683779,4.1427703,,,,,,,,,,,,,, -12800,1.5389209,4.1655574,,,,,,,,,,,,,, -12900,1.6384673,4.1984696,,,,,,,,,,,,,, -13000,1.2196078,4.8191876,,,,,,,,,,,,,, -13044,,,0.4320703148841858,2.7094104290008545,0.4023999869823456,2.8486061096191406,50000.0,0.3115000128746032,3.408851385116577,10000.0,5924.697474718094,6322.066045045853,5924.697474718094,396.19059109687805,0.429196834564209,0.0 -13100,1.4046534,4.1963916,,,,,,,,,,,,,, -13200,1.8616201,4.083909,,,,,,,,,,,,,, -13300,1.5325838,4.051358,,,,,,,,,,,,,, -13400,2.390177,4.1318884,,,,,,,,,,,,,, -13500,1.4874743,4.957622,,,,,,,,,,,,,, -13600,1.4397324,4.148898,,,,,,,,,,,,,, -13700,1.3671142,4.5437946,,,,,,,,,,,,,, -13800,1.0689243,5.6754584,,,,,,,,,,,,,, -13900,1.358344,4.3096023,,,,,,,,,,,,,, -13967,,,0.4513085782527923,2.5765609741210938,0.4201000034809112,2.7298953533172607,50000.0,0.3273000121116638,3.2973392009735107,10000.0,6344.838493108749,6778.072106599808,6344.838493108749,431.9696838855744,0.46321702003479,0.0 -14000,1.1182157,5.1848707,,,,,,,,,,,,,, -14100,1.5316029,4.559302,,,,,,,,,,,,,, -14200,0.9431723,5.724108,,,,,,,,,,,,,, -14300,1.0030106,5.667451,,,,,,,,,,,,,, -14400,1.02763,6.0243206,,,,,,,,,,,,,, -14500,1.6068822,4.089573,,,,,,,,,,,,,, -14600,1.5199301,4.0381165,,,,,,,,,,,,,, -14700,1.4418695,4.0598354,,,,,,,,,,,,,, -14800,1.3633994,4.1674886,,,,,,,,,,,,,, -14896,,,0.4676562249660492,2.475290536880493,0.4269999861717224,2.6804628372192383,50000.0,0.3342000246047973,3.250800371170044,10000.0,6764.83682847023,7233.150120258331,6764.83682847023,466.9709808826447,0.489241361618042,0.0 -14900,1.5945747,3.852796,,,,,,,,,,,,,, -15000,1.3827907,3.9936388,,,,,,,,,,,,,, -15100,1.4525716,4.161554,,,,,,,,,,,,,, -15200,1.4989096,4.0157027,,,,,,,,,,,,,, -15300,1.4548367,4.101805,,,,,,,,,,,,,, -15400,1.1710943,5.164604,,,,,,,,,,,,,, -15500,1.5339382,4.0550194,,,,,,,,,,,,,, -15600,1.3281765,4.1522846,,,,,,,,,,,,,, -15700,1.5061122,4.04529,,,,,,,,,,,,,, -15800,1.4073306,4.1666236,,,,,,,,,,,,,, -15826,,,0.4655663967132568,2.52592134475708,0.435619980096817,2.66223406791687,50000.0,0.3340000212192535,3.236973285675049,10000.0,7185.151596069336,7691.874465227127,7185.151596069336,505.29772305488586,0.520289421081543,0.0 -15900,1.6428598,4.5499706,,,,,,,,,,,,,, -16000,1.3101712,3.9303632,,,,,,,,,,,,,, -16100,1.4226633,3.931915,,,,,,,,,,,,,, -16200,1.440107,4.253079,,,,,,,,,,,,,, -16300,1.5922475,4.604142,,,,,,,,,,,,,, -16400,1.6853753,4.019253,,,,,,,,,,,,,, -16500,1.0682857,5.6821375,,,,,,,,,,,,,, -16600,1.3749901,3.8820705,,,,,,,,,,,,,, -16700,1.3083951,3.9709702,,,,,,,,,,,,,, -16754,,,0.4782226383686065,2.447401762008667,0.4474399983882904,2.5987234115600586,50000.0,0.3508000075817108,3.1780033111572266,10000.0,7605.435419559479,8152.040428161621,7605.435419559479,545.1009593009949,0.5469293594360352,0.0 -16800,1.1161864,4.8156633,,,,,,,,,,,,,, -16900,1.31452,4.0940766,,,,,,,,,,,,,, -17000,1.1518232,4.5387287,,,,,,,,,,,,,, -17100,1.3088934,4.3568416,,,,,,,,,,,,,, -17200,1.3566141,3.9078608,,,,,,,,,,,,,, -17300,1.2991809,4.5883937,,,,,,,,,,,,,, -17400,1.3695495,4.7508125,,,,,,,,,,,,,, -17500,1.3287022,3.9039767,,,,,,,,,,,,,, -17600,1.3669735,3.9553125,,,,,,,,,,,,,, -17682,,,0.4890234172344208,2.3615238666534424,0.4515799880027771,2.543088912963867,50000.0,0.3485000133514404,3.146413326263428,10000.0,8025.745712041855,8613.524369716644,8025.745712041855,586.1966207027435,0.5725915431976318,0.0 -17700,1.5086837,3.8369765,,,,,,,,,,,,,, -17800,1.3283826,3.8832111,,,,,,,,,,,,,, -17900,1.2936093,4.084863,,,,,,,,,,,,,, -18000,1.2742945,3.937617,,,,,,,,,,,,,, -18100,1.2767797,3.7744076,,,,,,,,,,,,,, -18200,1.2504127,4.5887365,,,,,,,,,,,,,, -18300,1.2174411,4.424422,,,,,,,,,,,,,, -18400,1.2805078,3.841825,,,,,,,,,,,,,, -18500,1.3169707,3.8795035,,,,,,,,,,,,,, -18600,1.25074,3.8224895,,,,,,,,,,,,,, -18605,,,0.5197656154632568,2.212848901748657,0.4608999788761139,2.4904685020446777,50000.0,0.3604000210762024,3.0848608016967773,10000.0,8445.693154335022,9072.734701156616,8445.693154335022,625.3753478527069,0.6053094863891602,0.0 -18700,1.3960762,3.8227105,,,,,,,,,,,,,, -18800,1.055377,5.0069156,,,,,,,,,,,,,, -18900,1.2639996,4.1738257,,,,,,,,,,,,,, -19000,1.4130228,3.773006,,,,,,,,,,,,,, -19100,0.9263901,5.7256374,,,,,,,,,,,,,, -19200,1.1895419,3.8688855,,,,,,,,,,,,,, -19300,1.311854,3.8969343,,,,,,,,,,,,,, -19400,1.6098737,3.9181967,,,,,,,,,,,,,, -19500,0.9570789,5.8581934,,,,,,,,,,,,,, -19531,,,0.5014843344688416,2.333058357238769,0.4675999879837036,2.486729145050049,50000.0,0.3640000224113464,3.088551998138428,10000.0,8865.636800050735,9534.022550106049,8865.636800050735,666.6394157409668,0.6336965560913086,0.0 -19600,1.2897859,4.291822,,,,,,,,,,,,,, -19700,1.2026556,3.884405,,,,,,,,,,,,,, -19800,1.1585981,4.1222076,,,,,,,,,,,,,, -19900,1.3248845,3.8632812,,,,,,,,,,,,,, -20000,1.332415,3.7890148,,,,,,,,,,,,,, -20100,1.0575377,5.5225625,,,,,,,,,,,,,, -20200,1.1268694,4.580862,,,,,,,,,,,,,, -20300,1.0240886,5.107607,,,,,,,,,,,,,, -20400,0.827665,5.8913064,,,,,,,,,,,,,, -20451,,,0.5143749713897705,2.221330165863037,0.4759399890899658,2.4047062397003174,50000.0,0.3753000199794769,3.02260684967041,10000.0,9284.03717637062,9993.490169763563,9284.03717637062,705.9051666259766,2.38281798362732,0.0 -20500,1.5463189,3.8534265,,,,,,,,,,,,,, -20600,0.8530426,5.886279,,,,,,,,,,,,,, -20700,1.4110459,3.7973356,,,,,,,,,,,,,, -20800,1.3141598,3.8187249,,,,,,,,,,,,,, -20900,1.0656337,5.8445964,,,,,,,,,,,,,, -21000,1.0809128,5.2059126,,,,,,,,,,,,,, -21100,1.1913829,3.9574375,,,,,,,,,,,,,, -21200,1.9854105,3.82403,,,,,,,,,,,,,, -21300,1.2869126,3.766968,,,,,,,,,,,,,, -21375,,,0.5318750143051147,2.146445274353028,0.4820199906826019,2.3802831172943115,50000.0,0.3745000064373016,3.008895874023437,10000.0,9703.97527360916,10451.55467915535,9703.97527360916,743.9496030807495,2.413090705871582,0.0 -21400,0.9700244,5.5409536,,,,,,,,,,,,,, -21500,1.0251663,4.462543,,,,,,,,,,,,,, -21600,1.3766007,3.815008,,,,,,,,,,,,,, -21700,1.2528785,3.8870258,,,,,,,,,,,,,, -21800,1.3118476,3.9024084,,,,,,,,,,,,,, -21900,0.99181575,4.889383,,,,,,,,,,,,,, -22000,1.4093006,3.770662,,,,,,,,,,,,,, -22100,1.1913848,3.8869867,,,,,,,,,,,,,, -22200,1.2721884,3.7209456,,,,,,,,,,,,,, -22298,,,0.525390625,2.1639983654022217,0.4918799996376037,2.320196866989136,50000.0,0.3839000165462494,2.934317111968994,10000.0,10124.1838889122,10906.612278938292,10124.1838889122,778.7204098701477,2.43961763381958,0.0 -22300,1.2730278,5.7167745,,,,,,,,,,,,,, -22400,1.3438139,3.712302,,,,,,,,,,,,,, -22500,1.3016139,3.6936507,,,,,,,,,,,,,, -22600,1.2812002,3.7106264,,,,,,,,,,,,,, -22700,1.3373916,3.8861783,,,,,,,,,,,,,, -22800,0.9637715,5.7446856,,,,,,,,,,,,,, -22900,1.3626289,3.8401077,,,,,,,,,,,,,, -23000,1.4668549,3.7131135,,,,,,,,,,,,,, -23100,1.095709,4.0304756,,,,,,,,,,,,,, -23200,0.9612156,5.737274,,,,,,,,,,,,,, -23222,,,0.5257226228713989,2.1706438064575195,0.4936599731445312,2.339932680130005,50000.0,0.3893000185489654,2.94062876701355,10000.0,10544.424010038376,11363.582971572876,10544.424010038376,815.3714473247528,2.4672508239746094,0.0 -23300,1.3004149,4.16434,,,,,,,,,,,,,, -23400,1.4149588,3.8839178,,,,,,,,,,,,,, -23500,1.4951947,3.7208483,,,,,,,,,,,,,, -23600,1.1138532,5.7621493,,,,,,,,,,,,,, -23700,1.3257753,3.9050984,,,,,,,,,,,,,, -23800,1.186448,4.171544,,,,,,,,,,,,,, -23900,1.1874988,4.0204325,,,,,,,,,,,,,, -24000,1.3801149,3.7063344,,,,,,,,,,,,,, -24100,0.88679165,5.6830025,,,,,,,,,,,,,, -24147,,,0.5456640720367432,2.072561502456665,0.5017399787902832,2.2792351245880127,50000.0,0.3957000076770782,2.886673212051392,10000.0,10964.40007162094,11819.682758808136,10964.40007162094,851.4152765274048,2.494343042373657,0.0 -24200,1.401906,3.7225013,,,,,,,,,,,,,, -24300,1.4067308,3.6021786,,,,,,,,,,,,,, -24400,1.2686813,3.7378583,,,,,,,,,,,,,, -24500,1.4588624,3.6432219,,,,,,,,,,,,,, -24600,1.3644376,3.547066,,,,,,,,,,,,,, -24700,1.3674982,3.7187443,,,,,,,,,,,,,, -24800,1.0913763,5.0351367,,,,,,,,,,,,,, -24900,1.316244,3.7169313,,,,,,,,,,,,,, -25000,1.3553078,3.9433708,,,,,,,,,,,,,, -25071,,,0.5464843511581421,2.0741374492645264,0.5124599933624268,2.23947811126709,50000.0,0.4038000106811523,2.85197377204895,10000.0,11384.33907365799,12275.507807016373,11384.33907365799,887.2221512794495,2.52208948135376,0.0 -25100,0.88909256,5.489066,,,,,,,,,,,,,, -25200,1.2224084,3.6789467,,,,,,,,,,,,,, -25300,1.5263709,4.100057,,,,,,,,,,,,,, -25400,1.5517235,3.7532825,,,,,,,,,,,,,, -25500,1.2392205,4.0451474,,,,,,,,,,,,,, -25600,1.1050329,4.4326315,,,,,,,,,,,,,, -25700,1.4098892,3.6923177,,,,,,,,,,,,,, -25800,1.2145044,4.707416,,,,,,,,,,,,,, -25900,1.1927906,4.5259175,,,,,,,,,,,,,, -25985,,,0.5458202958106995,2.107918739318848,0.5095399618148804,2.2755002975463867,50000.0,0.4057000279426574,2.873729705810547,10000.0,11804.526224374771,12732.602709293364,11804.526224374771,924.0465886592864,2.554238557815552,0.0 -26000,1.5670245,3.6207564,,,,,,,,,,,,,, -26100,1.2758925,3.6553063,,,,,,,,,,,,,, -26200,0.9581122,5.396819,,,,,,,,,,,,,, -26300,1.3333784,3.7591648,,,,,,,,,,,,,, -26400,1.366662,3.7763405,,,,,,,,,,,,,, -26500,1.2562938,3.7004278,,,,,,,,,,,,,, -26600,1.1764121,4.0197744,,,,,,,,,,,,,, -26700,1.4847194,3.5008245,,,,,,,,,,,,,, -26800,1.2965316,3.7677784,,,,,,,,,,,,,, -26900,1.446977,4.0964327,,,,,,,,,,,,,, -26908,,,0.554394543170929,2.04179310798645,0.5141599774360657,2.2237548828125,50000.0,0.4059000313282013,2.844035387039185,10000.0,12224.458804368973,13188.403272867205,12224.458804368973,959.8301196098328,2.5867295265197754,0.0 -27000,1.4398328,3.5893123,,,,,,,,,,,,,, -27100,1.3160322,3.552747,,,,,,,,,,,,,, -27200,1.1201563,4.8979607,,,,,,,,,,,,,, -27300,1.2629448,3.9194484,,,,,,,,,,,,,, -27400,1.6254361,3.675377,,,,,,,,,,,,,, -27500,1.2849855,4.085932,,,,,,,,,,,,,, -27600,1.404914,4.0388393,,,,,,,,,,,,,, -27700,1.1753873,4.8529806,,,,,,,,,,,,,, -27800,1.3155726,3.7421815,,,,,,,,,,,,,, -27831,,,0.5831835865974426,1.937749147415161,0.5192999839782715,2.2157299518585205,50000.0,0.4076000154018402,2.831859588623047,10000.0,12644.591686487198,13644.293235778809,12644.591686487198,995.5060038566588,2.615927696228028,0.0 -27900,1.3756055,3.6777565,,,,,,,,,,,,,, -28000,1.2786368,3.544041,,,,,,,,,,,,,, -28100,1.4039577,3.6485634,,,,,,,,,,,,,, -28200,1.4870384,3.556885,,,,,,,,,,,,,, -28300,1.337833,3.5229592,,,,,,,,,,,,,, -28400,1.1083313,5.4800262,,,,,,,,,,,,,, -28500,1.4861562,3.5563903,,,,,,,,,,,,,, -28600,1.0250443,5.435132,,,,,,,,,,,,,, -28700,1.0298393,5.7190933,,,,,,,,,,,,,, -28754,,,0.56103515625,2.0178725719451904,0.5253599882125854,2.184187412261963,50000.0,0.415800005197525,2.801346063613892,10000.0,13064.928541898727,14101.277722358704,13064.928541898727,1032.0707716941831,2.6466317176818848,0.0 -28800,1.1668799,4.911312,,,,,,,,,,,,,, -28900,1.3465817,3.513232,,,,,,,,,,,,,, -29000,1.2150611,4.034042,,,,,,,,,,,,,, -29100,1.1361371,4.2518897,,,,,,,,,,,,,, -29200,0.97973186,5.3952,,,,,,,,,,,,,, -29300,1.560395,3.6275976,,,,,,,,,,,,,, -29400,1.3960607,3.782717,,,,,,,,,,,,,, -29500,1.2119089,4.905441,,,,,,,,,,,,,, -29600,1.4013778,3.5775523,,,,,,,,,,,,,, -29679,,,0.5685741901397705,1.968665719032288,0.5270999670028687,2.1590123176574707,50000.0,0.4166000187397003,2.762173652648926,10000.0,13485.01470541954,14557.166967391968,13485.01470541954,1067.7865002155304,2.6822452545166016,0.0 -29700,1.327778,3.556865,,,,,,,,,,,,,, -29800,1.5184784,3.5657806,,,,,,,,,,,,,, -29900,1.0568287,4.8717723,,,,,,,,,,,,,, -30000,1.3712444,3.936689,,,,,,,,,,,,,, -30100,1.2277291,3.6194854,,,,,,,,,,,,,, -30200,1.1565527,4.112338,,,,,,,,,,,,,, -30300,1.0209578,5.534213,,,,,,,,,,,,,, -30400,1.5733399,3.5154548,,,,,,,,,,,,,, -30500,1.0926967,5.632367,,,,,,,,,,,,,, -30600,1.3784844,3.5687296,,,,,,,,,,,,,, -30601,,,0.5899999737739563,1.8909555673599243,0.5362399816513062,2.1361539363861084,50000.0,0.4199000298976898,2.761985540390014,10000.0,13905.378865480425,15013.537324428558,13905.378865480425,1103.7031581401825,2.720014333724976,0.0 -30700,1.0695727,4.931524,,,,,,,,,,,,,, -30800,1.4049388,3.6087098,,,,,,,,,,,,,, -30900,1.2057663,4.1127133,,,,,,,,,,,,,, -31000,1.1170183,5.337214,,,,,,,,,,,,,, -31100,1.3824284,3.893921,,,,,,,,,,,,,, -31200,1.4714578,3.4210577,,,,,,,,,,,,,, -31300,1.3671867,5.737711,,,,,,,,,,,,,, -31400,1.2643563,3.9600174,,,,,,,,,,,,,, -31500,1.4424386,3.7512045,,,,,,,,,,,,,, -31525,,,0.5753320455551147,1.9117754697799685,0.5366199612617493,2.086362838745117,50000.0,0.4225000143051147,2.7142233848571777,10000.0,14325.58296585083,15470.007066965103,14325.58296585083,1139.8854219913485,2.751550912857056,0.0 -31600,1.3454238,3.4853742,,,,,,,,,,,,,, -31700,1.1051531,4.887311,,,,,,,,,,,,,, -31800,1.4165289,3.8478591,,,,,,,,,,,,,, -31900,1.1910776,4.018475,,,,,,,,,,,,,, -32000,1.3158436,3.6079004,,,,,,,,,,,,,, -32100,1.0794152,4.2830586,,,,,,,,,,,,,, -32200,1.3727474,3.477355,,,,,,,,,,,,,, -32300,1.3982443,3.477656,,,,,,,,,,,,,, -32400,1.3964664,4.1438584,,,,,,,,,,,,,, -32447,,,0.5731250047683716,1.933479070663452,0.5350800156593323,2.114149808883667,50000.0,0.4178000092506408,2.760493278503418,10000.0,14745.557694911957,15925.88976097107,14745.557694911957,1175.7121279239657,2.7820372581481934,0.0 -32500,1.3601714,3.4207392,,,,,,,,,,,,,, -32600,1.367614,3.9328794,,,,,,,,,,,,,, -32700,1.09067,4.4552116,,,,,,,,,,,,,, -32800,1.0329124,5.4948497,,,,,,,,,,,,,, -32900,1.118403,5.0750117,,,,,,,,,,,,,, -33000,1.4198844,3.5109932,,,,,,,,,,,,,, -33100,1.4455903,3.6862192,,,,,,,,,,,,,, -33200,1.2920195,3.931557,,,,,,,,,,,,,, -33300,1.3982977,5.7522855,,,,,,,,,,,,,, -33367,,,0.5793749690055847,1.8931059837341309,0.5353599786758423,2.102208614349365,50000.0,0.4253000319004059,2.7308449745178223,10000.0,15165.785947084429,16381.58100271225,15165.785947084429,1211.0903453826904,2.8150620460510254,0.0 -33400,0.9309152,5.1300535,,,,,,,,,,,,,, -33500,1.619412,3.6092694,,,,,,,,,,,,,, -33600,1.0752639,5.4312873,,,,,,,,,,,,,, -33700,1.4304398,3.5167518,,,,,,,,,,,,,, -33800,1.2063762,4.29366,,,,,,,,,,,,,, -33900,1.1076742,4.8731966,,,,,,,,,,,,,, -34000,1.0304502,5.028012,,,,,,,,,,,,,, -34100,1.4619782,3.521889,,,,,,,,,,,,,, -34200,1.2156123,3.9186442,,,,,,,,,,,,,, -34289,,,0.5872460603713989,1.8730449676513672,0.5454199910163879,2.0598084926605225,50000.0,0.4348000288009643,2.68808913230896,10000.0,15586.13660979271,16838.84478354454,15586.13660979271,1247.915519475937,2.8516042232513428,0.0 -34300,1.3369195,3.530988,,,,,,,,,,,,,, -34400,1.5489534,3.4855506,,,,,,,,,,,,,, -34500,1.3804796,3.4794333,,,,,,,,,,,,,, -34600,1.1618634,4.767331,,,,,,,,,,,,,, -34700,1.0942336,5.691152,,,,,,,,,,,,,, -34800,1.5097761,3.6077175,,,,,,,,,,,,,, -34900,1.0854024,4.321489,,,,,,,,,,,,,, -35000,1.1843855,4.817035,,,,,,,,,,,,,, -35100,1.3299055,3.5306568,,,,,,,,,,,,,, -35200,1.3294377,3.5573363,,,,,,,,,,,,,, -35212,,,0.583691418170929,1.88885509967804,0.5408999919891357,2.083228588104248,50000.0,0.4269000291824341,2.722228527069092,10000.0,16006.159181833267,17294.766576051712,16006.159181833267,1283.7321796417236,2.881927490234375,0.0 -35300,1.2240127,3.9169528,,,,,,,,,,,,,, -35400,1.2570306,3.8905497,,,,,,,,,,,,,, -35500,1.3129815,4.3917923,,,,,,,,,,,,,, -35600,1.1593761,4.788415,,,,,,,,,,,,,, -35700,1.4127804,3.4944887,,,,,,,,,,,,,, -35800,1.4675512,3.6268215,,,,,,,,,,,,,, -35900,1.2037278,4.5174794,,,,,,,,,,,,,, -36000,1.4826876,3.5307171,,,,,,,,,,,,,, -36100,1.1263953,5.6330705,,,,,,,,,,,,,, -36132,,,0.5895312428474426,1.8599177598953247,0.5429800152778625,2.068998336791992,50000.0,0.4353000223636627,2.674314498901367,10000.0,16426.25867486,17750.791348934174,16426.25867486,1319.5742392539978,2.9135375022888184,0.0 -36200,1.457898,3.635407,,,,,,,,,,,,,, -36300,1.5929871,3.4321032,,,,,,,,,,,,,, -36400,1.4246892,3.5705886,,,,,,,,,,,,,, -36500,1.1468499,4.245644,,,,,,,,,,,,,, -36600,1.423402,3.5063672,,,,,,,,,,,,,, -36700,1.3003011,3.8699338,,,,,,,,,,,,,, -36800,1.3448124,3.5667784,,,,,,,,,,,,,, -36900,1.3665301,3.444279,,,,,,,,,,,,,, -37000,1.1242392,5.2653384,,,,,,,,,,,,,, -37053,,,0.5921093821525574,1.889419674873352,0.5410400032997131,2.1071622371673584,50000.0,0.4245000183582306,2.73975157737732,10000.0,16846.513379335403,18207.25081396103,16846.513379335403,1355.6950623989103,2.946035861968994,0.0 -37100,1.082592,5.0353374,,,,,,,,,,,,,, -37200,1.3837693,3.7161853,,,,,,,,,,,,,, -37300,1.4477278,3.402524,,,,,,,,,,,,,, -37400,1.4045105,3.4523911,,,,,,,,,,,,,, -37500,1.1767241,4.0397363,,,,,,,,,,,,,, -37600,1.4421339,3.4550643,,,,,,,,,,,,,, -37700,1.4899968,3.5101342,,,,,,,,,,,,,, -37800,1.1311299,5.2804313,,,,,,,,,,,,,, -37900,1.3806335,3.4456673,,,,,,,,,,,,,, -37974,,,0.5863280892372131,1.8857553005218504,0.5454800128936768,2.069470167160034,50000.0,0.4247000217437744,2.710883617401123,10000.0,17266.63311100006,18663.40659666061,17266.63311100006,1391.6465280056,2.97900915145874,0.0 -38000,1.5205817,3.4553678,,,,,,,,,,,,,, -38100,1.4879117,3.3831596,,,,,,,,,,,,,, -38200,1.3774967,3.4405818,,,,,,,,,,,,,, -38300,1.0613751,5.144145,,,,,,,,,,,,,, -38400,1.3409827,3.6086059,,,,,,,,,,,,,, -38500,1.4765854,3.458194,,,,,,,,,,,,,, -38600,1.638136,3.7129304,,,,,,,,,,,,,, -38700,1.3163316,3.8283277,,,,,,,,,,,,,, -38800,1.3981721,3.5430682,,,,,,,,,,,,,, -38897,,,0.5941210985183716,1.883009552955628,0.5507400035858154,2.086865186691284,50000.0,0.4341000318527221,2.703827381134033,10000.0,17686.921609401703,19120.326916456223,17686.921609401703,1428.1919829845428,3.0127615928649902,0.0 -38900,1.1009296,5.50356,,,,,,,,,,,,,, -39000,1.1009089,5.03108,,,,,,,,,,,,,, -39100,1.4421314,3.6070092,,,,,,,,,,,,,, -39200,1.4308667,3.870319,,,,,,,,,,,,,, -39300,1.5161839,3.4412456,,,,,,,,,,,,,, -39400,1.1856271,4.157056,,,,,,,,,,,,,, -39500,1.5185821,3.544436,,,,,,,,,,,,,, -39600,1.6314793,3.3545215,,,,,,,,,,,,,, -39700,1.4988676,3.3439245,,,,,,,,,,,,,, -39800,1.3686135,4.03205,,,,,,,,,,,,,, -39821,,,0.6173437237739563,1.7155250310897827,0.555620014667511,2.007648706436157,50000.0,0.4416000247001648,2.6214611530303955,10000.0,18107.20042920113,19576.857031822205,18107.20042920113,1464.3569984436035,3.0470962524414062,0.0 -39900,1.4990593,3.4236696,,,,,,,,,,,,,, -40000,1.4752612,3.6213708,,,,,,,,,,,,,, -40100,1.4571722,3.4694796,,,,,,,,,,,,,, -40200,1.377916,3.4431622,,,,,,,,,,,,,, -40300,1.4096674,3.6040158,,,,,,,,,,,,,, -40400,1.5702767,3.3755016,,,,,,,,,,,,,, -40500,1.3360392,3.9318316,,,,,,,,,,,,,, -40600,0.98389465,5.5356607,,,,,,,,,,,,,, -40700,1.5535392,3.4801402,,,,,,,,,,,,,, -40743,,,0.5999609231948853,1.8218204975128167,0.5581799745559692,1.9893065690994265,50000.0,0.445000022649765,2.613339900970459,10000.0,18527.43185925484,20033.48488211632,18527.43185925484,1500.6704943180084,3.0773818492889404,0.0 -40800,1.1257608,5.3239775,,,,,,,,,,,,,, -40900,1.2696817,4.038423,,,,,,,,,,,,,, -41000,1.2842511,4.4221506,,,,,,,,,,,,,, -41100,1.142294,5.343773,,,,,,,,,,,,,, -41200,1.4294802,3.4437847,,,,,,,,,,,,,, -41300,1.429682,3.455772,,,,,,,,,,,,,, -41400,1.1734929,5.555881,,,,,,,,,,,,,, -41500,1.3027158,4.1507545,,,,,,,,,,,,,, -41600,1.6415635,3.4785366,,,,,,,,,,,,,, -41667,,,0.5997461080551147,1.8272331953048704,0.5560599565505981,2.0224764347076416,50000.0,0.4398000240325928,2.6552605628967285,10000.0,18947.66690206528,20489.236602544785,18947.66690206528,1536.0949032306671,3.118054151535034,0.0 -41700,1.0863891,4.4313774,,,,,,,,,,,,,, -41800,1.9132621,3.487046,,,,,,,,,,,,,, -41900,1.7258173,3.6542375,,,,,,,,,,,,,, -42000,1.4724482,3.4236891,,,,,,,,,,,,,, -42100,1.4465281,4.194357,,,,,,,,,,,,,, -42200,1.5923483,3.5103514,,,,,,,,,,,,,, -42300,1.1484933,4.6817603,,,,,,,,,,,,,, -42400,1.1585927,5.307663,,,,,,,,,,,,,, -42500,1.4344822,4.7792587,,,,,,,,,,,,,, -42591,,,0.6067578196525574,1.7815989255905151,0.5556600093841553,2.026575326919556,50000.0,0.4319000244140625,2.671928882598877,10000.0,19367.82270050049,20946.551624774933,19367.82270050049,1573.170075416565,3.149902820587158,0.0 -42600,1.3421242,3.4792528,,,,,,,,,,,,,, -42700,1.4214664,3.7475848,,,,,,,,,,,,,, -42800,1.3839965,3.3921952,,,,,,,,,,,,,, -42900,1.1638018,5.3724794,,,,,,,,,,,,,, -43000,1.327116,5.55859,,,,,,,,,,,,,, -43100,1.3831019,4.0560126,,,,,,,,,,,,,, -43200,1.4696965,4.076774,,,,,,,,,,,,,, -43300,1.3762169,4.939571,,,,,,,,,,,,,, -43400,1.3714826,3.407767,,,,,,,,,,,,,, -43500,1.7108008,3.3083658,,,,,,,,,,,,,, -43515,,,0.6055273413658142,1.7664735317230225,0.5644599795341492,1.9504902362823489,50000.0,0.45210000872612,2.557616233825684,10000.0,19787.93092918396,21402.830537080765,19787.93092918396,1609.2521879673004,3.1867854595184326,0.0 -43600,1.4323927,3.3313937,,,,,,,,,,,,,, -43700,1.6076933,3.4123676,,,,,,,,,,,,,, -43800,1.3794093,3.3416545,,,,,,,,,,,,,, -43900,1.0707762,5.403753,,,,,,,,,,,,,, -44000,1.4398348,3.3974152,,,,,,,,,,,,,, -44100,1.4235396,3.698818,,,,,,,,,,,,,, -44200,1.3981664,3.363635,,,,,,,,,,,,,, -44300,1.5296786,3.5170355,,,,,,,,,,,,,, -44400,1.4135915,3.4470313,,,,,,,,,,,,,, -44436,,,0.6005859375,1.8105329275131223,0.5595600008964539,1.997233271598816,50000.0,0.4448000192642212,2.628516912460327,10000.0,20207.914827108383,21858.304526090626,20207.914827108383,1644.6554489135742,3.2182297706604004,0.0 -44500,1.397162,4.9272885,,,,,,,,,,,,,, -44600,1.263859,4.6302915,,,,,,,,,,,,,, -44700,1.4979767,3.441124,,,,,,,,,,,,,, -44800,1.4237665,3.5880861,,,,,,,,,,,,,, -44900,1.5266664,3.4234471,,,,,,,,,,,,,, -45000,1.6804179,3.5295937,,,,,,,,,,,,,, -45100,1.0992559,4.8325434,,,,,,,,,,,,,, -45200,1.5489312,3.4421375,,,,,,,,,,,,,, -45300,1.5284737,3.3868446,,,,,,,,,,,,,, -45355,,,0.6087499856948853,1.792636513710022,0.5613799691200256,2.0202178955078125,50000.0,0.4436000287532806,2.6396775245666504,10000.0,20627.91112852097,22315.22365355492,20627.91112852097,1681.4962322711945,3.2489402294158936,0.0 -45400,1.3217665,4.168898,,,,,,,,,,,,,, -45500,1.1615018,4.4023275,,,,,,,,,,,,,, -45600,1.4985347,3.2635593,,,,,,,,,,,,,, -45700,1.4633634,3.4811664,,,,,,,,,,,,,, -45800,1.5245286,3.4126997,,,,,,,,,,,,,, -45900,1.3134512,5.3684964,,,,,,,,,,,,,, -46000,1.3844496,3.6181707,,,,,,,,,,,,,, -46100,1.4401743,3.4577239,,,,,,,,,,,,,, -46200,1.2994485,4.0816193,,,,,,,,,,,,,, -46275,,,0.6075195074081421,1.7662358283996582,0.5708999633789062,1.937433123588562,50000.0,0.4536000192165375,2.589318037033081,10000.0,21047.84006094933,22771.119074821472,21047.84006094933,1717.3759191036224,3.283371686935425,0.0 -46300,1.3798045,3.5855217,,,,,,,,,,,,,, -46400,1.6203064,3.6066277,,,,,,,,,,,,,, -46500,1.451541,3.3002334,,,,,,,,,,,,,, -46600,1.7336164,3.4215248,,,,,,,,,,,,,, -46700,1.3704627,4.088948,,,,,,,,,,,,,, -46800,1.3706781,3.4263217,,,,,,,,,,,,,, -46900,1.2802656,4.010287,,,,,,,,,,,,,, -47000,1.137405,4.839507,,,,,,,,,,,,,, -47100,1.3775102,4.269268,,,,,,,,,,,,,, -47195,,,0.6070312261581421,1.791407585144043,0.5654999613761902,1.988101363182068,50000.0,0.4508000314235687,2.603720664978028,10000.0,21468.07667350769,23227.797281980515,21468.07667350769,1753.7322096824646,3.31660795211792,0.0 -47200,1.5190488,3.345691,,,,,,,,,,,,,, -47300,1.464968,3.4303803,,,,,,,,,,,,,, -47400,1.3744543,3.3253112,,,,,,,,,,,,,, -47500,1.5693138,3.3919182,,,,,,,,,,,,,, -47600,1.4422331,3.9078128,,,,,,,,,,,,,, -47700,1.290684,4.219934,,,,,,,,,,,,,, -47800,1.4844421,3.470271,,,,,,,,,,,,,, -47900,1.5625672,3.4222677,,,,,,,,,,,,,, -48000,1.407191,3.6751733,,,,,,,,,,,,,, -48100,1.4987297,3.3751464,,,,,,,,,,,,,, -48117,,,0.6176171898841858,1.6974457502365112,0.572219967842102,1.9044924974441528,50000.0,0.4601000249385834,2.536510705947876,10000.0,21888.20467019081,23684.09828400612,21888.20467019081,1789.8139972686768,3.355715274810791,0.0 -48200,1.5656911,3.4791608,,,,,,,,,,,,,, -48300,1.515554,3.374039,,,,,,,,,,,,,, -48400,1.419277,3.4021654,,,,,,,,,,,,,, -48500,1.4002689,5.2166753,,,,,,,,,,,,,, -48600,1.2725676,4.011326,,,,,,,,,,,,,, -48700,1.266788,4.959113,,,,,,,,,,,,,, -48800,1.606404,3.3966565,,,,,,,,,,,,,, -48900,1.3832427,3.3875945,,,,,,,,,,,,,, -49000,1.068361,5.550053,,,,,,,,,,,,,, -49039,,,0.6376562118530273,1.6303468942642212,0.5713199973106384,1.924487590789795,50000.0,0.4565000236034393,2.5707848072052,10000.0,22308.17884039879,24139.977380990986,22308.17884039879,1825.6325912475584,3.390817880630493,0.0 -49100,1.4861565,3.3785605,,,,,,,,,,,,,, -49200,1.3923732,3.3943748,,,,,,,,,,,,,, -49300,1.250942,4.3853426,,,,,,,,,,,,,, -49400,1.4366899,3.449019,,,,,,,,,,,,,, -49500,1.7158535,3.4446974,,,,,,,,,,,,,, -49600,1.5248976,3.5026894,,,,,,,,,,,,,, -49700,1.4159441,3.3118153,,,,,,,,,,,,,, -49800,1.60045,3.50733,,,,,,,,,,,,,, -49900,1.4248242,5.494286,,,,,,,,,,,,,, -49960,,,0.6035937070846558,1.750400185585022,0.5683599710464478,1.9232509136199951,50000.0,0.4570000171661377,2.5590219497680664,10000.0,22728.231050014496,24596.820907592773,22728.231050014496,1862.3381762504573,3.424906015396118,0.0 -50000,1.1655424,5.431139,,,,,,,,,,,,,, -50100,1.4562838,3.3845515,,,,,,,,,,,,,, -50200,1.5363826,3.4090211,,,,,,,,,,,,,, -50300,1.4909337,3.35223,,,,,,,,,,,,,, -50400,1.1780419,5.1433244,,,,,,,,,,,,,, -50500,1.5057317,3.3171768,,,,,,,,,,,,,, -50600,1.5378122,3.354642,,,,,,,,,,,,,, -50700,1.312893,4.0049977,,,,,,,,,,,,,, -50800,1.2160918,4.5542383,,,,,,,,,,,,,, -50881,,,0.6149609088897705,1.7614407539367676,0.5702599883079529,1.953475832939148,50000.0,0.4577000141143799,2.5869696140289307,10000.0,23148.14708518982,25050.40443754196,23148.14708518982,1895.92019033432,3.456855535507202,0.0 -50900,1.5308548,3.4661322,,,,,,,,,,,,,, -51000,1.6373341,3.3593698,,,,,,,,,,,,,, -51100,1.2417121,5.6013556,,,,,,,,,,,,,, -51200,1.4471431,3.3670309,,,,,,,,,,,,,, -51300,1.5334184,3.506979,,,,,,,,,,,,,, -51400,1.6283389,3.3848584,,,,,,,,,,,,,, -51500,1.3894817,3.4715557,,,,,,,,,,,,,, -51600,1.4771667,3.743594,,,,,,,,,,,,,, -51700,1.5953499,3.5124197,,,,,,,,,,,,,, -51800,1.2290246,5.6003275,,,,,,,,,,,,,, -51802,,,0.6318554282188416,1.6667355298995972,0.5773599743843079,1.915133237838745,50000.0,0.45660001039505,2.557199239730835,10000.0,23568.13712787628,25507.155699014664,23568.13712787628,1932.593270778656,3.4931063652038574,0.0 -51900,1.485987,3.6918159,,,,,,,,,,,,,, -52000,1.3811347,3.937836,,,,,,,,,,,,,, -52100,1.497204,3.279941,,,,,,,,,,,,,, -52200,1.7476113,3.3996468,,,,,,,,,,,,,, -52300,1.4294133,4.2771792,,,,,,,,,,,,,, -52400,1.4081739,3.3522706,,,,,,,,,,,,,, -52500,1.4651228,3.23125,,,,,,,,,,,,,, -52600,1.186015,5.4037094,,,,,,,,,,,,,, -52700,1.3863865,5.483536,,,,,,,,,,,,,, -52724,,,0.6173242330551147,1.7071141004562378,0.578220009803772,1.8849382400512693,50000.0,0.4574000239372253,2.535064697265625,10000.0,23988.11474442482,25964.319465875626,23988.11474442482,1969.695535421372,3.5249931812286377,0.0 -52800,1.4261775,3.4312687,,,,,,,,,,,,,, -52900,1.3989031,3.8810272,,,,,,,,,,,,,, -53000,1.5781121,3.3541458,,,,,,,,,,,,,, -53100,1.5389555,3.4449327,,,,,,,,,,,,,, -53200,1.6573749,3.389726,,,,,,,,,,,,,, -53300,1.5893599,3.301926,,,,,,,,,,,,,, -53400,1.5769271,3.3366961,,,,,,,,,,,,,, -53500,1.4939743,3.2405713,,,,,,,,,,,,,, -53600,1.615295,3.4182224,,,,,,,,,,,,,, -53645,,,0.6204687356948853,1.695976972579956,0.5777400135993958,1.8915741443634035,50000.0,0.4635000228881836,2.5295896530151367,10000.0,24408.527057886124,26420.55820798874,24408.527057886124,2005.4340209960933,3.5617029666900635,0.0 -53700,1.4689612,3.3754015,,,,,,,,,,,,,, -53800,1.448836,3.379928,,,,,,,,,,,,,, -53900,1.6800512,3.4260874,,,,,,,,,,,,,, -54000,1.3986452,4.1057715,,,,,,,,,,,,,, -54100,1.3255795,4.1950164,,,,,,,,,,,,,, -54200,1.6741647,3.414941,,,,,,,,,,,,,, -54300,1.3978542,3.6664708,,,,,,,,,,,,,, -54400,1.6604851,3.4877448,,,,,,,,,,,,,, -54500,1.5444139,3.2403135,,,,,,,,,,,,,, -54568,,,0.6335741877555847,1.6409857273101809,0.5818600058555603,1.8811854124069207,50000.0,0.4611000120639801,2.53002667427063,10000.0,24828.74754881859,26878.447025060654,24828.74754881859,2043.0168118476868,3.5952277183532715,0.0 -54600,1.4251069,4.0355353,,,,,,,,,,,,,, -54700,1.5852594,3.3035884,,,,,,,,,,,,,, -54800,1.543244,3.4180486,,,,,,,,,,,,,, -54900,1.2214215,4.29932,,,,,,,,,,,,,, -55000,1.6385382,3.352611,,,,,,,,,,,,,, -55100,1.3060044,5.3332624,,,,,,,,,,,,,, -55200,1.4961666,3.9246197,,,,,,,,,,,,,, -55300,1.1575947,5.362774,,,,,,,,,,,,,, -55400,1.5155296,3.3664246,,,,,,,,,,,,,, -55491,,,0.6231836080551147,1.6743615865707395,0.5826799869537354,1.8473048210144043,50000.0,0.4648000299930572,2.4718806743621826,10000.0,25248.86332678795,27335.411857128143,25248.86332678795,2079.7810554504395,3.627671003341675,0.0 -55500,1.4372631,4.232883,,,,,,,,,,,,,, -55600,1.3616213,3.7333841,,,,,,,,,,,,,, -55700,1.811741,3.451851,,,,,,,,,,,,,, -55800,1.604629,3.2704425,,,,,,,,,,,,,, -55900,1.8117096,3.3498595,,,,,,,,,,,,,, -56000,1.4071642,3.9336722,,,,,,,,,,,,,, -56100,1.4793072,3.542179,,,,,,,,,,,,,, -56200,1.3899616,3.6053958,,,,,,,,,,,,,, -56300,1.5506184,3.4536462,,,,,,,,,,,,,, -56400,1.6235776,3.2926507,,,,,,,,,,,,,, -56413,,,0.6209765672683716,1.711039662361145,0.5796599984169006,1.9055297374725344,50000.0,0.465800017118454,2.5380899906158447,10000.0,25669.019901752472,27793.2751955986,25669.019901752472,2117.397604942322,3.666386127471924,0.0 -56500,1.4471815,3.2782826,,,,,,,,,,,,,, -56600,1.2596277,4.942795,,,,,,,,,,,,,, -56700,1.5084257,3.693614,,,,,,,,,,,,,, -56800,1.3892207,3.7626302,,,,,,,,,,,,,, -56900,1.4003994,5.35501,,,,,,,,,,,,,, -57000,1.5789138,3.3134737,,,,,,,,,,,,,, -57100,1.316026,4.476963,,,,,,,,,,,,,, -57200,1.6348324,3.2706666,,,,,,,,,,,,,, -57300,1.6623006,3.4914289,,,,,,,,,,,,,, -57336,,,0.6293359398841858,1.6536279916763306,0.5833799839019775,1.8656693696975708,50000.0,0.4615000188350677,2.495638608932495,10000.0,26089.111981153488,28250.289984464645,26089.111981153488,2154.234006166458,3.701314687728882,0.0 -57400,1.6021324,3.2666123,,,,,,,,,,,,,, -57500,1.50483,3.4038143,,,,,,,,,,,,,, -57600,1.3932134,5.239692,,,,,,,,,,,,,, -57700,1.7543751,3.3644743,,,,,,,,,,,,,, -57800,1.6091881,3.3446288,,,,,,,,,,,,,, -57900,1.2218244,4.7250547,,,,,,,,,,,,,, -58000,1.2981539,4.5698943,,,,,,,,,,,,,, -58100,1.6974286,3.4165611,,,,,,,,,,,,,, -58200,1.6269834,3.3423831,,,,,,,,,,,,,, -58258,,,0.6524804830551147,1.5615637302398682,0.5855000019073486,1.8619376420974727,50000.0,0.4659000337123871,2.48901891708374,10000.0,26509.280297517776,28708.02276062965,26509.280297517776,2191.714658260345,3.733286857604981,0.0 -58300,1.6324835,3.262469,,,,,,,,,,,,,, -58400,1.3075081,4.446304,,,,,,,,,,,,,, -58500,1.4995842,3.302044,,,,,,,,,,,,,, -58600,1.7786801,3.372744,,,,,,,,,,,,,, -58700,1.3212782,4.429657,,,,,,,,,,,,,, -58800,1.3605888,3.636927,,,,,,,,,,,,,, -58900,1.7527564,3.278634,,,,,,,,,,,,,, -59000,1.3377528,4.4308004,,,,,,,,,,,,,, -59100,1.3988198,4.9245176,,,,,,,,,,,,,, -59178,,,0.6221289038658142,1.7175687551498413,0.5812399983406067,1.9062336683273315,50000.0,0.4636000096797943,2.526654005050659,10000.0,26929.2906024456,29164.714790582657,26929.2906024456,2228.310579776764,3.7675628662109375,0.0 -59200,1.7121661,3.2046046,,,,,,,,,,,,,, -59300,1.4674265,5.5658884,,,,,,,,,,,,,, -59400,1.62171,3.1476245,,,,,,,,,,,,,, -59500,1.5859112,3.3109477,,,,,,,,,,,,,, -59600,1.2029542,5.3628683,,,,,,,,,,,,,, -59700,1.5539933,3.2407641,,,,,,,,,,,,,, -59800,1.4218955,3.6968014,,,,,,,,,,,,,, -59900,1.2764623,4.476866,,,,,,,,,,,,,, -60000,1.6312524,3.4100976,,,,,,,,,,,,,, -60097,,,0.6240624785423279,1.7262558937072754,0.5816799998283386,1.931992769241333,50000.0,0.4661000072956085,2.5786662101745605,10000.0,27349.29820728302,29622.019852399822,27349.29820728302,2265.5188434124,3.804023742675781,0.0 -60100,1.4978838,3.2586255,,,,,,,,,,,,,, -60200,1.6176416,3.3231585,,,,,,,,,,,,,, -60300,1.5389239,3.1972392,,,,,,,,,,,,,, -60400,1.2147913,4.717794,,,,,,,,,,,,,, -60500,1.56694,3.1436658,,,,,,,,,,,,,, -60600,1.2690148,5.4425516,,,,,,,,,,,,,, -60700,1.225002,5.187043,,,,,,,,,,,,,, -60800,1.3599946,4.4545403,,,,,,,,,,,,,, -60900,1.5548036,3.3840854,,,,,,,,,,,,,, -61000,1.5656849,3.2291715,,,,,,,,,,,,,, -61019,,,0.6532421708106995,1.5591392517089844,0.5931999683380127,1.8360151052474976,50000.0,0.4717000126838684,2.467651128768921,10000.0,27769.632704496384,30078.310692310333,27769.632704496384,2301.384221792221,3.842408418655396,0.0 -61100,1.5800846,3.2322335,,,,,,,,,,,,,, -61200,1.503913,3.2816339,,,,,,,,,,,,,, -61300,1.5610148,3.2513983,,,,,,,,,,,,,, -61400,1.4864967,5.410632,,,,,,,,,,,,,, -61500,1.1429319,5.183804,,,,,,,,,,,,,, -61600,1.6711068,3.2888913,,,,,,,,,,,,,, -61700,1.3929614,5.336631,,,,,,,,,,,,,, -61800,1.4076631,3.583211,,,,,,,,,,,,,, -61900,1.4994075,4.4523525,,,,,,,,,,,,,, -61941,,,0.6307030916213989,1.6259307861328125,0.5897600054740906,1.8175615072250368,50000.0,0.4720000326633453,2.4553451538085938,10000.0,28189.844376802444,30536.71168017388,28189.844376802444,2339.4837741851807,3.8803889751434326,0.0 -62000,1.5948399,3.2382045,,,,,,,,,,,,,, -62100,1.622854,3.2424946,,,,,,,,,,,,,, -62200,1.3358616,3.663153,,,,,,,,,,,,,, -62300,1.3159521,4.561371,,,,,,,,,,,,,, -62400,1.376097,3.8038247,,,,,,,,,,,,,, -62500,1.5195333,4.163169,,,,,,,,,,,,,, -62600,1.5256834,3.2343018,,,,,,,,,,,,,, -62700,1.2578523,5.413676,,,,,,,,,,,,,, -62800,1.4124305,3.7205281,,,,,,,,,,,,,, -62861,,,0.6349999904632568,1.6504743099212646,0.5909799933433533,1.858465313911438,50000.0,0.4687000215053558,2.47857403755188,10000.0,28609.88588285446,30992.379618406296,28609.88588285446,2375.022164583206,3.916743040084839,0.0 -62900,1.6510768,3.2137887,,,,,,,,,,,,,, -63000,1.9804137,3.3771534,,,,,,,,,,,,,, -63100,1.5563765,3.2819371,,,,,,,,,,,,,, -63200,1.5690482,3.251042,,,,,,,,,,,,,, -63300,1.5365808,3.9141068,,,,,,,,,,,,,, -63400,1.2799077,5.3364325,,,,,,,,,,,,,, -63500,1.2889283,4.492838,,,,,,,,,,,,,, -63600,1.5505028,3.2178888,,,,,,,,,,,,,, -63700,1.3801829,4.181286,,,,,,,,,,,,,, -63784,,,0.6424999833106995,1.608486294746399,0.5922999978065491,1.8405706882476809,50000.0,0.4695000350475311,2.484283208847046,10000.0,29029.89266204834,31449.811811208725,29029.89266204834,2412.3593595027924,3.953082323074341,0.0 -63800,1.7297133,3.2111263,,,,,,,,,,,,,, -63900,1.171914,4.7595453,,,,,,,,,,,,,, -64000,1.5805978,3.2025704,,,,,,,,,,,,,, -64100,1.5995089,3.4253752,,,,,,,,,,,,,, -64200,1.6356102,3.1913362,,,,,,,,,,,,,, -64300,1.8369453,3.3171675,,,,,,,,,,,,,, -64400,1.7739656,3.1979156,,,,,,,,,,,,,, -64500,1.7061214,3.3702013,,,,,,,,,,,,,, -64600,1.6139042,3.2535114,,,,,,,,,,,,,, -64700,1.5922302,3.2630992,,,,,,,,,,,,,, -64704,,,0.6398632526397705,1.6252983808517456,0.5931400060653687,1.8280532360076904,50000.0,0.4734000265598297,2.459376335144043,10000.0,29449.915155172348,31906.87753367424,29449.915155172348,2449.3163084983826,3.987928152084351,0.0 -64800,1.5977147,3.167488,,,,,,,,,,,,,, -64900,1.3142676,3.9814372,,,,,,,,,,,,,, -65000,1.2185801,5.4903135,,,,,,,,,,,,,, -65100,1.1963903,4.6594133,,,,,,,,,,,,,, -65200,1.3936716,4.838554,,,,,,,,,,,,,, -65300,1.4589763,3.5924702,,,,,,,,,,,,,, -65400,1.7031236,3.316302,,,,,,,,,,,,,, -65500,1.506186,3.3207805,,,,,,,,,,,,,, -65600,1.4650856,3.5962806,,,,,,,,,,,,,, -65625,,,0.6363476514816284,1.6080312728881836,0.596019983291626,1.798905611038208,50000.0,0.4728000164031982,2.4222564697265625,10000.0,29870.18850684166,32366.242438316345,29870.18850684166,2488.313669919968,4.030225992202759,0.0 -65700,1.5155494,3.4399626,,,,,,,,,,,,,, -65800,1.5722268,3.4233775,,,,,,,,,,,,,, -65900,1.6850668,3.2334874,,,,,,,,,,,,,, -66000,1.6155951,3.3336878,,,,,,,,,,,,,, -66100,1.578576,3.3401952,,,,,,,,,,,,,, -66200,1.3171977,4.233203,,,,,,,,,,,,,, -66300,1.594095,3.367209,,,,,,,,,,,,,, -66400,1.4255981,4.7897005,,,,,,,,,,,,,, -66500,1.2926247,5.364485,,,,,,,,,,,,,, -66548,,,0.6463476419448853,1.5539718866348269,0.593779981136322,1.7889726161956787,50000.0,0.4757000207901001,2.425429582595825,10000.0,30290.33807373047,32824.26109623909,30290.33807373047,2526.093469142914,4.068289279937744,0.0 -66600,1.9462143,3.290664,,,,,,,,,,,,,, -66700,1.5336941,5.3792543,,,,,,,,,,,,,, -66800,1.5928652,3.1918604,,,,,,,,,,,,,, -66900,1.4819412,3.5069866,,,,,,,,,,,,,, -67000,1.4657782,3.5534701,,,,,,,,,,,,,, -67100,1.4406712,4.646687,,,,,,,,,,,,,, -67200,1.2394817,5.455593,,,,,,,,,,,,,, -67300,1.8076074,3.197499,,,,,,,,,,,,,, -67400,1.7475151,3.213417,,,,,,,,,,,,,, -67470,,,0.6481249928474426,1.55096435546875,0.5956199765205383,1.7877963781356812,50000.0,0.4787000119686126,2.4377787113189697,10000.0,30710.64385533333,33281.84917807579,30710.64385533333,2563.291362285614,4.101492404937744,0.0 -67500,1.7562927,3.13412,,,,,,,,,,,,,, -67600,1.3404431,4.312476,,,,,,,,,,,,,, -67700,1.2520077,4.0702295,,,,,,,,,,,,,, -67800,1.2270267,4.7765493,,,,,,,,,,,,,, -67900,1.7168615,3.2969332,,,,,,,,,,,,,, -68000,1.5113168,3.8235672,,,,,,,,,,,,,, -68100,1.7862117,3.1676054,,,,,,,,,,,,,, -68200,1.6838094,4.0656166,,,,,,,,,,,,,, -68300,1.5850862,3.2227159,,,,,,,,,,,,,, -68391,,,0.6390038728713989,1.6182537078857422,0.5971599817276001,1.8191072940826416,50000.0,0.4754000306129455,2.4577841758728027,10000.0,31130.57970190048,33738.69984698296,31130.57970190048,2600.11688375473,4.13883113861084,0.0 -68400,1.8358788,3.3090825,,,,,,,,,,,,,, -68500,1.7613665,3.522103,,,,,,,,,,,,,, -68600,1.614366,3.7076035,,,,,,,,,,,,,, -68700,1.2640111,4.8696356,,,,,,,,,,,,,, -68800,1.5744698,3.474563,,,,,,,,,,,,,, -68900,1.2933742,5.3463087,,,,,,,,,,,,,, -69000,1.685293,3.3095572,,,,,,,,,,,,,, -69100,1.476324,4.109457,,,,,,,,,,,,,, -69200,1.7972479,3.2035105,,,,,,,,,,,,,, -69300,1.7276423,3.2448545,,,,,,,,,,,,,, -69310,,,0.6440038681030273,1.572296142578125,0.5978800058364868,1.780421018600464,50000.0,0.4824000298976898,2.421178340911865,10000.0,31550.54110193253,34196.13939833641,31550.54110193253,2637.507829427719,4.173320293426514,0.0 -69400,1.6606685,3.341425,,,,,,,,,,,,,, -69500,1.8084843,3.2043266,,,,,,,,,,,,,, -69600,1.3040215,5.376007,,,,,,,,,,,,,, -69700,1.5838172,3.3063455,,,,,,,,,,,,,, -69800,1.4998708,3.0509706,,,,,,,,,,,,,, -69900,1.4692097,3.3613036,,,,,,,,,,,,,, -70000,1.52145,3.1508088,,,,,,,,,,,,,, -70100,1.4512255,3.436523,,,,,,,,,,,,,, -70200,1.4803934,3.7335808,,,,,,,,,,,,,, -70234,,,0.6702734231948853,1.4641485214233398,0.6001200079917908,1.762584209442139,50000.0,0.4819000363349914,2.420220375061035,10000.0,31970.663256645203,34651.5661213398,31970.663256645203,2672.718720436096,4.21485447883606,0.0 -70300,1.3348322,5.006994,,,,,,,,,,,,,, -70400,1.5903962,3.2512224,,,,,,,,,,,,,, -70500,1.2924613,5.3666706,,,,,,,,,,,,,, -70600,1.6122137,3.5453153,,,,,,,,,,,,,, -70700,1.5429479,3.8474462,,,,,,,,,,,,,, -70800,1.6298426,3.2444944,,,,,,,,,,,,,, -70900,1.5457895,3.3931673,,,,,,,,,,,,,, -71000,1.3646227,3.8388495,,,,,,,,,,,,,, -71100,1.2894918,4.8257904,,,,,,,,,,,,,, -71154,,,0.6441210508346558,1.5635147094726562,0.6025399565696716,1.7569859027862549,50000.0,0.4907000362873077,2.3740594387054443,10000.0,32390.83561515808,35109.47786331177,32390.83561515808,2710.3715307712555,4.249778985977173,0.0 -71200,1.6642448,3.220045,,,,,,,,,,,,,, -71300,1.6826278,3.341535,,,,,,,,,,,,,, -71400,1.6646572,3.3832164,,,,,,,,,,,,,, -71500,1.6091348,3.332512,,,,,,,,,,,,,, -71600,1.4551938,5.4821806,,,,,,,,,,,,,, -71700,1.5858777,3.4212306,,,,,,,,,,,,,, -71800,1.5229528,3.14882,,,,,,,,,,,,,, -71900,1.6046569,3.3900805,,,,,,,,,,,,,, -72000,1.6858753,3.2542133,,,,,,,,,,,,,, -72077,,,0.6434960961341858,1.6113227605819702,0.6014800071716309,1.8047555685043333,50000.0,0.4802000224590301,2.438359498977661,10000.0,32810.90993022919,35567.80102777481,32810.90993022919,2748.524088859558,4.293323755264282,0.0 -72100,1.6013334,3.249718,,,,,,,,,,,,,, -72200,1.7116075,3.1795678,,,,,,,,,,,,,, -72300,1.4899918,3.6865554,,,,,,,,,,,,,, -72400,1.2743652,4.954829,,,,,,,,,,,,,, -72500,1.8231944,3.1519232,,,,,,,,,,,,,, -72600,1.7606003,3.3172016,,,,,,,,,,,,,, -72700,1.7853976,3.1405067,,,,,,,,,,,,,, -72800,1.6541182,3.6034884,,,,,,,,,,,,,, -72900,1.6471936,3.076346,,,,,,,,,,,,,, -72996,,,0.6573632955551147,1.5504130125045776,0.5985599756240845,1.8023332357406616,50000.0,0.4825000166893005,2.439271926879883,10000.0,33231.24440646172,36026.76460838318,33231.24440646172,2787.061530351639,4.332884788513184,0.0 -73000,1.5373864,3.3479288,,,,,,,,,,,,,, -73100,1.2890491,3.993775,,,,,,,,,,,,,, -73200,1.3362029,4.7853775,,,,,,,,,,,,,, -73300,1.3661362,5.2907624,,,,,,,,,,,,,, -73400,1.6594999,3.1541314,,,,,,,,,,,,,, -73500,1.9256092,3.1239314,,,,,,,,,,,,,, -73600,1.7528524,3.3490896,,,,,,,,,,,,,, -73700,1.528428,3.4194427,,,,,,,,,,,,,, -73800,1.7089852,3.0871592,,,,,,,,,,,,,, -73900,1.6489906,3.1473014,,,,,,,,,,,,,, -73920,,,0.6471093893051147,1.6129900217056274,0.6095799803733826,1.7827093601226809,50000.0,0.486700028181076,2.417206048965454,10000.0,33651.5921421051,36484.73053359985,33651.5921421051,2824.5923268795013,4.368396282196045,0.0 -74000,1.5825915,3.207239,,,,,,,,,,,,,, -74100,1.597182,3.3221617,,,,,,,,,,,,,, -74200,1.4390664,4.1886654,,,,,,,,,,,,,, -74300,1.6223706,3.6086655,,,,,,,,,,,,,, -74400,1.5480229,3.474413,,,,,,,,,,,,,, -74500,1.6997718,3.3063486,,,,,,,,,,,,,, -74600,1.9226855,3.2664104,,,,,,,,,,,,,, -74700,1.8722805,3.199106,,,,,,,,,,,,,, -74800,1.5990548,3.91718,,,,,,,,,,,,,, -74841,,,0.6474413871765137,1.5477566719055176,0.606440007686615,1.7396602630615234,50000.0,0.4857000112533569,2.386080503463745,10000.0,34071.80634212494,36941.8312625885,34071.80634212494,2861.388499736786,4.406642913818359,0.0 -74900,1.7780421,3.168089,,,,,,,,,,,,,, -75000,1.9102966,3.1766808,,,,,,,,,,,,,, -75100,1.6869593,3.5617375,,,,,,,,,,,,,, -75200,1.7009386,3.3699222,,,,,,,,,,,,,, -75300,1.4541056,3.9383898,,,,,,,,,,,,,, -75400,1.4366807,5.2082515,,,,,,,,,,,,,, -75500,1.4773492,3.4651585,,,,,,,,,,,,,, -75600,1.4351035,4.6737394,,,,,,,,,,,,,, -75700,1.7442764,3.170009,,,,,,,,,,,,,, -75764,,,0.6582812070846558,1.5243412256240845,0.6102199554443359,1.744834065437317,50000.0,0.4821000099182129,2.3967623710632324,10000.0,34492.00156569481,37399.13010430336,34492.00156569481,2898.396971464157,4.448941230773926,0.0 -75800,1.6919763,3.0967584,,,,,,,,,,,,,, -75900,1.5624325,4.092467,,,,,,,,,,,,,, -76000,1.5275946,3.2550797,,,,,,,,,,,,,, -76100,1.7596986,3.137707,,,,,,,,,,,,,, -76200,1.7171062,3.1870246,,,,,,,,,,,,,, -76300,1.522854,4.7363715,,,,,,,,,,,,,, -76400,1.8130655,3.1859107,,,,,,,,,,,,,, -76500,1.753638,3.1471148,,,,,,,,,,,,,, -76600,1.4775734,3.88634,,,,,,,,,,,,,, -76687,,,0.6499804258346558,1.5522480010986328,0.6068199872970581,1.75525164604187,50000.0,0.4863000214099884,2.3870294094085693,10000.0,34912.220725774765,37856.34967160225,34912.220725774765,2935.307032585144,4.486681699752808,0.0 -76700,1.962345,3.1998544,,,,,,,,,,,,,, -76800,1.5452136,3.5527363,,,,,,,,,,,,,, -76900,1.5867717,4.5110936,,,,,,,,,,,,,, -77000,1.6322571,3.2803926,,,,,,,,,,,,,, -77100,1.6099862,3.5182512,,,,,,,,,,,,,, -77200,1.8024577,3.203031,,,,,,,,,,,,,, -77300,1.8108153,3.423757,,,,,,,,,,,,,, -77400,1.4599266,5.3102455,,,,,,,,,,,,,, -77500,1.7583342,3.257702,,,,,,,,,,,,,, -77600,1.6334738,3.4057553,,,,,,,,,,,,,, -77611,,,0.655078113079071,1.556774377822876,0.6091799736022949,1.7530089616775513,50000.0,0.4870000183582306,2.381280899047852,10000.0,35332.43723273277,38313.127844810486,35332.43723273277,2971.7798516750336,4.523826599121094,0.0 -77700,1.558783,3.3329742,,,,,,,,,,,,,, -77800,1.8773528,4.6090064,,,,,,,,,,,,,, -77900,1.5566077,5.234619,,,,,,,,,,,,,, -78000,1.8803633,3.0657508,,,,,,,,,,,,,, -78100,1.388272,4.111945,,,,,,,,,,,,,, -78200,1.6841849,3.2166362,,,,,,,,,,,,,, -78300,1.6297188,3.0630362,,,,,,,,,,,,,, -78400,1.7861911,3.160275,,,,,,,,,,,,,, -78500,1.8060597,3.1220512,,,,,,,,,,,,,, -78533,,,0.6649804711341858,1.5041078329086304,0.6135199666023254,1.727328181266785,50000.0,0.4961000382900238,2.359246730804444,10000.0,35752.73362803459,38772.0255856514,35752.73362803459,3010.292057991028,4.561509370803833,0.0 -78600,1.6915784,3.0854998,,,,,,,,,,,,,, -78700,1.6928284,3.1915975,,,,,,,,,,,,,, -78800,1.7903578,3.1212735,,,,,,,,,,,,,, -78900,1.5602777,5.292224,,,,,,,,,,,,,, -79000,1.4056386,5.419692,,,,,,,,,,,,,, -79100,1.7285833,3.206869,,,,,,,,,,,,,, -79200,1.7188386,3.2802958,,,,,,,,,,,,,, -79300,1.4117826,4.8133574,,,,,,,,,,,,,, -79400,1.8341466,3.11311,,,,,,,,,,,,,, -79454,,,0.6874414086341858,1.3758230209350586,0.6157599687576294,1.6964789628982544,50000.0,0.4874000251293182,2.350943326950073,10000.0,36172.88255786896,39229.4043803215,36172.88255786896,3047.432544708252,4.599277496337891,0.0 -79500,1.5021763,5.386443,,,,,,,,,,,,,, -79600,1.2846589,5.0380344,,,,,,,,,,,,,, -79700,1.6235417,5.040043,,,,,,,,,,,,,, -79800,1.7215395,3.341875,,,,,,,,,,,,,, -79900,1.6823329,3.2747586,,,,,,,,,,,,,, -80000,1.6590343,3.2291138,,,,,,,,,,,,,, -80100,1.3494906,4.1002636,,,,,,,,,,,,,, -80200,1.7965127,3.1568499,,,,,,,,,,,,,, -80300,1.7235054,3.156136,,,,,,,,,,,,,, -80378,,,0.6550585627555847,1.5762782096862793,0.6106799840927124,1.7810405492782593,50000.0,0.4880000352859497,2.4137213230133057,10000.0,36593.105981349945,39687.296456336975,36593.105981349945,3085.007307291031,4.6409912109375,0.0 -80400,2.0204577,3.5234554,,,,,,,,,,,,,, -80500,1.5041461,4.404248,,,,,,,,,,,,,, -80600,1.8146864,3.1523407,,,,,,,,,,,,,, -80700,1.6037898,3.102739,,,,,,,,,,,,,, -80800,1.5148122,4.8391066,,,,,,,,,,,,,, -80900,1.8406142,3.2320526,,,,,,,,,,,,,, -81000,1.6042501,3.4492545,,,,,,,,,,,,,, -81100,1.7480704,3.5544786,,,,,,,,,,,,,, -81200,1.3795033,4.7287664,,,,,,,,,,,,,, -81300,,,0.6625585556030273,1.5070688724517822,0.6144799590110779,1.7260377407073977,50000.0,0.4939000308513641,2.3620481491088867,10000.0,37013.32633161545,40145.815786361694,37013.32633161545,3123.217358827591,4.677295207977295,0.0 -81300,2.0091903,3.16811,,,,,,,,,,,,,, -81400,1.8577827,3.108312,,,,,,,,,,,,,, -81500,1.6820291,3.6753578,,,,,,,,,,,,,, -81600,1.853868,3.1755333,,,,,,,,,,,,,, -81700,1.4223965,3.8941853,,,,,,,,,,,,,, -81800,1.8337305,3.0758932,,,,,,,,,,,,,, -81900,1.4417069,4.9725666,,,,,,,,,,,,,, -82000,1.8188348,3.156477,,,,,,,,,,,,,, -82100,1.6532142,3.1813276,,,,,,,,,,,,,, -82200,1.6641619,3.1462293,,,,,,,,,,,,,, -82221,,,0.6760546565055847,1.452250361442566,0.6148200035095215,1.7269800901412964,50000.0,0.492900013923645,2.362778425216675,10000.0,37433.64666318893,40603.81119298935,37433.64666318893,3160.800005197525,4.718563795089722,0.0 -82300,1.5239933,4.1029396,,,,,,,,,,,,,, -82400,1.9726845,3.2454631,,,,,,,,,,,,,, -82500,1.424755,4.510478,,,,,,,,,,,,,, -82600,1.8107787,3.134643,,,,,,,,,,,,,, -82700,1.7726535,3.6807241,,,,,,,,,,,,,, -82800,1.4542747,5.2670217,,,,,,,,,,,,,, -82900,1.6426429,3.4338818,,,,,,,,,,,,,, -83000,1.3811767,4.43595,,,,,,,,,,,,,, -83100,1.7641001,3.200211,,,,,,,,,,,,,, -83141,,,0.6602343320846558,1.497167468070984,0.6192399859428406,1.6920735836029053,50000.0,0.5021000504493713,2.319068193435669,10000.0,37853.990706920624,41062.76618885994,37853.990706920624,3199.3191499710083,4.758185863494873,0.0 -83200,1.6782751,3.4226718,,,,,,,,,,,,,, -83300,1.7781134,3.02492,,,,,,,,,,,,,, -83400,1.7328241,3.2461376,,,,,,,,,,,,,, -83500,1.4936842,4.2411437,,,,,,,,,,,,,, -83600,1.989032,3.151361,,,,,,,,,,,,,, -83700,1.8792093,3.2173615,,,,,,,,,,,,,, -83800,1.4883492,4.868915,,,,,,,,,,,,,, -83900,1.6541585,4.003683,,,,,,,,,,,,,, -84000,1.7158666,3.3486066,,,,,,,,,,,,,, -84061,,,0.6661523580551147,1.5092904567718506,0.6180399656295776,1.7301098108291626,50000.0,0.4988000094890594,2.355710029602051,10000.0,38274.122478723526,41521.33239722252,38274.122478723526,3237.663080692292,4.796828985214233,0.0 -84100,1.7054919,5.2430816,,,,,,,,,,,,,, -84200,1.8739638,3.1129422,,,,,,,,,,,,,, -84300,1.7286375,5.409863,,,,,,,,,,,,,, -84400,1.8929433,3.749805,,,,,,,,,,,,,, -84500,1.4115541,4.715339,,,,,,,,,,,,,, -84600,1.7462769,3.331722,,,,,,,,,,,,,, -84700,1.676091,4.992337,,,,,,,,,,,,,, -84800,2.0372474,3.2895937,,,,,,,,,,,,,, -84900,1.7996504,3.2918909,,,,,,,,,,,,,, -84983,,,0.6742187142372131,1.436706781387329,0.6194199919700623,1.680558204650879,50000.0,0.5003000497817993,2.320410251617432,10000.0,38694.25512552261,41979.15235233307,38694.25512552261,3275.256284952164,4.83923602104187,0.0 -85000,1.7037787,3.1371818,,,,,,,,,,,,,, -85100,1.6823982,3.1428013,,,,,,,,,,,,,, -85200,1.6072094,3.6800315,,,,,,,,,,,,,, -85300,1.4241819,4.775307,,,,,,,,,,,,,, -85400,1.7369134,3.6811523,,,,,,,,,,,,,, -85500,1.7537297,3.1213124,,,,,,,,,,,,,, -85600,1.7127652,2.9742217,,,,,,,,,,,,,, -85700,2.0058103,3.137932,,,,,,,,,,,,,, -85800,1.6682295,3.7058983,,,,,,,,,,,,,, -85900,1.4733143,4.320019,,,,,,,,,,,,,, -85905,,,0.657031238079071,1.5600054264068604,0.6131199598312378,1.742652177810669,50000.0,0.4909000098705292,2.3830223083496094,10000.0,39114.24275946617,42437.045094013214,39114.24275946617,3313.068485021591,4.880053997039795,0.0 -86000,1.7537475,3.2064188,,,,,,,,,,,,,, -86100,1.8662417,3.0472786,,,,,,,,,,,,,, -86200,1.5499403,3.867987,,,,,,,,,,,,,, -86300,1.559863,4.8905625,,,,,,,,,,,,,, -86400,1.7781117,3.1290739,,,,,,,,,,,,,, -86500,1.6259406,3.7782063,,,,,,,,,,,,,, -86600,1.9316941,3.093466,,,,,,,,,,,,,, -86700,1.5456314,3.8457794,,,,,,,,,,,,,, -86800,1.9941553,3.1387422,,,,,,,,,,,,,, -86826,,,0.6669530868530273,1.5276052951812744,0.6200399994850159,1.735192894935608,50000.0,0.4971000254154205,2.3761723041534424,10000.0,39534.21739983559,42895.48432254791,39534.21739983559,3351.4402787685394,4.920087099075317,0.0 -86900,1.8177588,3.2171354,,,,,,,,,,,,,, -87000,1.8086127,3.071698,,,,,,,,,,,,,, -87100,1.8690255,3.0348809,,,,,,,,,,,,,, -87200,1.7335254,3.2166307,,,,,,,,,,,,,, -87300,1.8067601,3.1552937,,,,,,,,,,,,,, -87400,1.7121352,3.2705417,,,,,,,,,,,,,, -87500,2.0799599,3.0865674,,,,,,,,,,,,,, -87600,1.8983449,3.0994196,,,,,,,,,,,,,, -87700,1.9562677,3.2591376,,,,,,,,,,,,,, -87749,,,0.6812499761581421,1.3998364210128784,0.6264599561691284,1.6430705785751345,50000.0,0.5020000338554382,2.299318552017212,10000.0,39954.22675919533,43351.513491392136,39954.22675919533,3387.36865234375,4.959564685821533,0.0 -87800,1.8647561,3.1268663,,,,,,,,,,,,,, -87900,1.9205598,3.8481212,,,,,,,,,,,,,, -88000,1.6703644,3.2933836,,,,,,,,,,,,,, -88100,1.7790802,3.0663552,,,,,,,,,,,,,, -88200,1.5288244,4.2721834,,,,,,,,,,,,,, -88300,1.683934,3.948412,,,,,,,,,,,,,, -88400,1.7649392,3.1921117,,,,,,,,,,,,,, -88500,1.7531142,5.174348,,,,,,,,,,,,,, -88600,1.9022378,3.0912542,,,,,,,,,,,,,, -88670,,,0.6852734088897705,1.405772089958191,0.6255399584770203,1.6802127361297607,50000.0,0.5082000494003296,2.318076610565185,10000.0,40374.38328671456,43810.47713375092,40374.38328671456,3426.085087776184,4.998172760009766,0.0 -88700,2.0944254,3.1681094,,,,,,,,,,,,,, -88800,1.7614911,2.9152713,,,,,,,,,,,,,, -88900,1.7192489,3.2545145,,,,,,,,,,,,,, -89000,1.838767,3.416202,,,,,,,,,,,,,, -89100,1.5513588,3.5087984,,,,,,,,,,,,,, -89200,2.138772,3.1667233,,,,,,,,,,,,,, -89300,1.9217759,3.013359,,,,,,,,,,,,,, -89400,1.7269979,3.4359045,,,,,,,,,,,,,, -89500,1.9177332,3.09446,,,,,,,,,,,,,, -89592,,,0.6747655868530273,1.4603726863861084,0.6269800066947937,1.6675292253494265,50000.0,0.5103000402450562,2.291046142578125,10000.0,40794.58645391464,44268.12525868416,40794.58645391464,3463.438698530197,5.036881446838379,0.0 -89600,1.6180385,5.343636,,,,,,,,,,,,,, -89700,1.8935384,3.1525757,,,,,,,,,,,,,, -89800,1.9288595,3.5369136,,,,,,,,,,,,,, -89900,1.68715,3.2666943,,,,,,,,,,,,,, -90000,1.8580942,3.2822442,,,,,,,,,,,,,, -90100,1.7689402,3.0402012,,,,,,,,,,,,,, -90200,1.8677783,3.0096242,,,,,,,,,,,,,, -90300,1.7266376,3.0696466,,,,,,,,,,,,,, -90400,1.8289665,3.138765,,,,,,,,,,,,,, -90500,1.863435,3.0796387,,,,,,,,,,,,,, -90515,,,0.6829491853713989,1.4198222160339355,0.6295199990272522,1.649441838264465,50000.0,0.5087000131607056,2.2928144931793213,10000.0,41214.94577026367,44726.574046611786,41214.94577026367,3501.4375364780426,5.075735569000244,0.0 -90600,1.9374171,3.0818262,,,,,,,,,,,,,, -90700,1.7462125,3.0241551,,,,,,,,,,,,,, -90800,1.5914241,4.50423,,,,,,,,,,,,,, -90900,1.956528,3.0731862,,,,,,,,,,,,,, -91000,1.609378,3.7026205,,,,,,,,,,,,,, -91100,1.4907255,4.6805897,,,,,,,,,,,,,, -91200,1.6202472,4.5987725,,,,,,,,,,,,,, -91300,1.8444774,3.009286,,,,,,,,,,,,,, -91400,1.6379273,3.6521819,,,,,,,,,,,,,, -91434,,,0.6955859065055847,1.3533570766448977,0.6298800110816956,1.656938910484314,50000.0,0.5065000057220459,2.301645517349243,10000.0,41635.21865081787,45186.90314650536,41635.21865081787,3541.404905796051,5.113611698150635,0.0 -91500,1.9600846,3.242194,,,,,,,,,,,,,, -91600,2.000274,3.0411592,,,,,,,,,,,,,, -91700,1.7796093,3.055749,,,,,,,,,,,,,, -91800,1.6977174,4.748424,,,,,,,,,,,,,, -91900,1.7732084,3.8202596,,,,,,,,,,,,,, -92000,1.9865324,3.0713074,,,,,,,,,,,,,, -92100,1.6703213,3.7490687,,,,,,,,,,,,,, -92200,1.6915402,4.056881,,,,,,,,,,,,,, -92300,1.6523278,3.2591343,,,,,,,,,,,,,, -92353,,,0.6741992235183716,1.4724483489990234,0.6269999742507935,1.6782883405685425,50000.0,0.5028000473976135,2.322950601577759,10000.0,42055.23688745499,45644.38531947136,42055.23688745499,3578.7778816223145,5.15111517906189,0.0 -92400,1.8957558,3.243128,,,,,,,,,,,,,, -92500,1.9495559,3.0397997,,,,,,,,,,,,,, -92600,1.8353542,3.0509136,,,,,,,,,,,,,, -92700,1.8815159,3.0906765,,,,,,,,,,,,,, -92800,1.6756816,4.8265085,,,,,,,,,,,,,, -92900,1.6536298,3.9357934,,,,,,,,,,,,,, -93000,2.0601048,3.099315,,,,,,,,,,,,,, -93100,1.712342,3.0932226,,,,,,,,,,,,,, -93200,1.6471807,5.192338,,,,,,,,,,,,,, -93277,,,0.681933581829071,1.4165855646133425,0.6343399882316589,1.6316964626312256,50000.0,0.5100000500679016,2.268878698348999,10000.0,42475.59681820869,46103.61071944237,42475.59681820869,3617.549820184708,5.1927220821380615,0.0 -93300,1.7987564,4.6894703,,,,,,,,,,,,,, -93400,1.7385043,3.6120608,,,,,,,,,,,,,, -93500,1.7639605,4.1763315,,,,,,,,,,,,,, -93600,1.4988735,5.105922,,,,,,,,,,,,,, -93700,1.7899647,2.9622939,,,,,,,,,,,,,, -93800,1.8240111,3.0889826,,,,,,,,,,,,,, -93900,1.7850184,3.1790724,,,,,,,,,,,,,, -94000,1.7643225,3.0665653,,,,,,,,,,,,,, -94100,2.2212589,4.2548494,,,,,,,,,,,,,, -94198,,,0.6891406178474426,1.4098796844482422,0.633080005645752,1.6661624908447266,50000.0,0.5108000040054321,2.293397188186645,10000.0,42895.638721227646,46561.52732515335,42895.638721227646,3655.330612421036,5.234953880310059,0.0 -94200,1.6371547,5.2702303,,,,,,,,,,,,,, -94300,1.7446576,3.082392,,,,,,,,,,,,,, -94400,1.6850626,4.5845866,,,,,,,,,,,,,, -94500,1.7436407,3.034849,,,,,,,,,,,,,, -94600,1.819558,3.2952878,,,,,,,,,,,,,, -94700,1.7382133,5.0195127,,,,,,,,,,,,,, -94800,1.8066851,5.093833,,,,,,,,,,,,,, -94900,2.1033108,2.987437,,,,,,,,,,,,,, -95000,2.0233054,5.0588903,,,,,,,,,,,,,, -95100,2.0622716,3.1224613,,,,,,,,,,,,,, -95115,,,0.6744726300239563,1.4238649606704712,0.6321200132369995,1.6284260749816897,50000.0,0.5033000111579895,2.2706520557403564,10000.0,43315.5960958004,47019.02144479752,43315.5960958004,3692.772596597672,5.277697801589966,0.0 -95200,1.4780461,3.807492,,,,,,,,,,,,,, -95300,1.7123097,3.276227,,,,,,,,,,,,,, -95400,1.9964572,3.0105872,,,,,,,,,,,,,, -95500,1.9171598,3.003494,,,,,,,,,,,,,, -95600,1.9571481,4.9659534,,,,,,,,,,,,,, -95700,1.8619945,3.0541213,,,,,,,,,,,,,, -95800,1.8788486,3.0345616,,,,,,,,,,,,,, -95900,1.8042506,3.8261716,,,,,,,,,,,,,, -96000,2.1606584,3.1147943,,,,,,,,,,,,,, -96038,,,0.6746679544448853,1.4650830030441284,0.6334800124168396,1.658732295036316,50000.0,0.501800000667572,2.325078248977661,10000.0,43735.618525505066,47478.4274597168,43735.618525505066,3732.062595129013,5.319170951843262,0.0 -96100,1.64896,4.69827,,,,,,,,,,,,,, -96200,2.2175262,3.076138,,,,,,,,,,,,,, -96300,1.8282633,3.058384,,,,,,,,,,,,,, -96400,1.7260278,3.0100436,,,,,,,,,,,,,, -96500,1.6812482,4.96534,,,,,,,,,,,,,, -96600,1.763497,3.4150286,,,,,,,,,,,,,, -96700,1.6709731,3.7907834,,,,,,,,,,,,,, -96800,1.8116915,3.8213444,,,,,,,,,,,,,, -96900,2.1341515,3.1053128,,,,,,,,,,,,,, -96961,,,0.6908984184265137,1.370441436767578,0.6380000114440918,1.611522197723389,50000.0,0.5160000324249268,2.247906446456909,10000.0,44155.63729739189,47935.84779524803,44155.63729739189,3769.371959209442,5.359732389450073,0.0 -97000,1.6248865,4.9811306,,,,,,,,,,,,,, -97100,1.9884621,3.3350968,,,,,,,,,,,,,, -97200,1.7097124,3.2722414,,,,,,,,,,,,,, -97300,1.7344339,3.2812734,,,,,,,,,,,,,, -97400,2.0461576,3.1978683,,,,,,,,,,,,,, -97500,2.3335674,4.992631,,,,,,,,,,,,,, -97600,1.7754092,5.217704,,,,,,,,,,,,,, -97700,1.8070545,4.020188,,,,,,,,,,,,,, -97800,1.9420252,3.5635767,,,,,,,,,,,,,, -97883,,,0.6875585913658142,1.3868331909179688,0.6341399550437927,1.6226903200149536,50000.0,0.5092000365257263,2.252286434173584,10000.0,44575.90721774101,48391.369034051895,44575.90721774101,3804.532338619232,5.398097276687622,0.0 -97900,2.0289605,3.0677686,,,,,,,,,,,,,, -98000,1.8962206,3.2368584,,,,,,,,,,,,,, -98100,1.7198251,4.9584703,,,,,,,,,,,,,, -98200,1.9733725,3.0378985,,,,,,,,,,,,,, -98300,2.306664,3.0142856,,,,,,,,,,,,,, -98400,2.262394,3.0471027,,,,,,,,,,,,,, -98500,1.856164,3.3007085,,,,,,,,,,,,,, -98600,1.8600295,4.601558,,,,,,,,,,,,,, -98700,1.9286201,3.0223358,,,,,,,,,,,,,, -98800,1.7571026,3.399229,,,,,,,,,,,,,, -98806,,,0.6914257407188416,1.358039379119873,0.6440399885177612,1.5768593549728394,50000.0,0.5145000219345093,2.221693754196167,10000.0,44995.98552107811,48849.69229388237,44995.98552107811,3842.685833454132,5.437853097915649,0.0 -98900,2.019656,3.0805273,,,,,,,,,,,,,, -99000,1.9674275,3.03266,,,,,,,,,,,,,, -99100,1.6002582,5.1829886,,,,,,,,,,,,,, -99200,2.0822139,3.0360594,,,,,,,,,,,,,, -99300,2.009323,3.025145,,,,,,,,,,,,,, -99400,1.804681,3.3802319,,,,,,,,,,,,,, -99500,1.979041,2.9913473,,,,,,,,,,,,,, -99600,1.8830037,3.0143137,,,,,,,,,,,,,, -99700,1.8979392,3.0625741,,,,,,,,,,,,,, -99725,,,0.6920703053474426,1.373831272125244,0.6429199576377869,1.6061257123947144,50000.0,0.5164000391960144,2.2455978393554688,10000.0,45416.36224722862,49308.9356777668,45416.36224722862,3881.4554891586304,5.483229875564575,0.0 -99800,1.7787615,2.962067,,,,,,,,,,,,,, -99900,2.1466177,2.9425547,,,,,,,,,,,,,, -100000,1.7124159,4.4980507,,,,,,,,,,,,,, -100100,2.2617702,3.0295215,,,,,,,,,,,,,, -100200,1.7876517,2.9232175,,,,,,,,,,,,,, -100300,2.0508814,3.1064785,,,,,,,,,,,,,, -100400,2.2492876,3.0946097,,,,,,,,,,,,,, -100500,1.8314545,3.7227244,,,,,,,,,,,,,, -100600,2.1402342,2.9912503,,,,,,,,,,,,,, -100644,,,0.7134960889816284,1.2532609701156616,0.6430999636650085,1.5673333406448364,50000.0,0.522599995136261,2.2153022289276123,10000.0,45836.51930832863,49767.70958662033,45836.51930832863,3919.975238323212,5.5283708572387695,0.0 -100700,1.7660155,3.117802,,,,,,,,,,,,,, -100800,1.7031138,5.112919,,,,,,,,,,,,,, -100900,1.7530928,3.094232,,,,,,,,,,,,,, -101000,1.9502339,2.9269433,,,,,,,,,,,,,, -101100,1.6623464,3.5202003,,,,,,,,,,,,,, -101200,1.7793756,3.2254202,,,,,,,,,,,,,, -101300,1.5427822,4.3998055,,,,,,,,,,,,,, -101400,2.034882,3.0018644,,,,,,,,,,,,,, -101500,2.0639448,3.0692325,,,,,,,,,,,,,, -101569,,,0.6882616877555847,1.375690460205078,0.6430999636650085,1.592112421989441,50000.0,0.5160000324249268,2.235506534576416,10000.0,46256.86202788353,50226.20248699188,46256.86202788353,3958.030996799469,5.570519685745239,0.0 -101600,2.1511667,3.0589705,,,,,,,,,,,,,, -101700,2.1121347,2.964335,,,,,,,,,,,,,, -101800,1.8712296,3.158878,,,,,,,,,,,,,, -101900,1.8445382,3.09383,,,,,,,,,,,,,, -102000,2.1476216,2.969659,,,,,,,,,,,,,, -102100,1.7422642,4.4056263,,,,,,,,,,,,,, -102200,2.0396512,3.6050417,,,,,,,,,,,,,, -102300,2.1320791,2.9972334,,,,,,,,,,,,,, -102400,1.6307379,4.3553553,,,,,,,,,,,,,, -102491,,,0.695605456829071,1.3783432245254517,0.6459800004959106,1.5988671779632568,50000.0,0.5182999968528748,2.2385599613189697,10000.0,46676.82009387016,50683.19379091263,46676.82009387016,3994.964736223221,5.617609024047852,0.0 -102500,1.9766165,3.1281018,,,,,,,,,,,,,, -102600,2.0724082,3.022693,,,,,,,,,,,,,, -102700,2.3041718,3.725985,,,,,,,,,,,,,, -102800,2.3055413,2.9913545,,,,,,,,,,,,,, -102900,1.9678825,3.0653572,,,,,,,,,,,,,, -103000,1.9422923,2.9416301,,,,,,,,,,,,,, -103100,2.1178944,3.015641,,,,,,,,,,,,,, -103200,1.6104112,4.207927,,,,,,,,,,,,,, -103300,1.7199255,4.5265913,,,,,,,,,,,,,, -103400,2.3711317,4.299199,,,,,,,,,,,,,, -103413,,,0.7080858945846558,1.311424970626831,0.6429799795150757,1.5880955457687378,50000.0,0.5222000479698181,2.2103865146636963,10000.0,47097.179240942,51143.90802812576,47097.179240942,4035.222962141037,5.66266655921936,0.0 -103500,2.0864754,4.415757,,,,,,,,,,,,,, -103600,2.0201485,3.0483265,,,,,,,,,,,,,, -103700,2.0384881,3.336373,,,,,,,,,,,,,, -103800,1.8877491,3.0056233,,,,,,,,,,,,,, -103900,1.8538319,3.6296463,,,,,,,,,,,,,, -104000,1.9480454,2.98033,,,,,,,,,,,,,, -104100,1.938084,3.2510805,,,,,,,,,,,,,, -104200,2.0383592,2.8883963,,,,,,,,,,,,,, -104300,1.8047029,5.0711775,,,,,,,,,,,,,, -104333,,,0.6900585889816284,1.4127657413482666,0.6438199877738953,1.6250956058502195,50000.0,0.5208000540733337,2.272473096847534,10000.0,47517.37286019325,51604.82066082954,47517.37286019325,4075.8489694595337,5.703973770141602,0.0 -104400,2.0545242,3.0676715,,,,,,,,,,,,,, -104500,1.9679153,3.014545,,,,,,,,,,,,,, -104600,2.0640671,2.9736607,,,,,,,,,,,,,, -104700,2.0379572,2.9640083,,,,,,,,,,,,,, -104800,1.9598882,3.1277359,,,,,,,,,,,,,, -104900,2.159785,2.9270587,,,,,,,,,,,,,, -105000,1.564951,4.114322,,,,,,,,,,,,,, -105100,1.9877595,3.0474062,,,,,,,,,,,,,, -105200,2.15793,3.0373864,,,,,,,,,,,,,, -105255,,,0.7038280963897705,1.3086973428726196,0.6522600054740906,1.5347038507461548,50000.0,0.5284000039100647,2.170280933380127,10000.0,47937.8508181572,52063.75856423378,47937.8508181572,4114.217289924622,5.74323320388794,0.0 -105300,2.3624296,2.9920835,,,,,,,,,,,,,, -105400,2.063314,2.957119,,,,,,,,,,,,,, -105500,1.9356266,3.1634324,,,,,,,,,,,,,, -105600,1.9277245,2.9166582,,,,,,,,,,,,,, -105700,2.1117423,2.9599414,,,,,,,,,,,,,, -105800,2.1478815,2.8952446,,,,,,,,,,,,,, -105900,2.0798683,2.967653,,,,,,,,,,,,,, -106000,1.8579465,3.700151,,,,,,,,,,,,,, -106100,2.0940778,3.0787892,,,,,,,,,,,,,, -106178,,,0.695117175579071,1.3671385049819946,0.638480007648468,1.6232553720474243,50000.0,0.5179000496864319,2.258666753768921,10000.0,48358.08609056473,52523.49588823319,48358.08609056473,4153.619336605072,5.791689872741699,0.0 -106200,2.2792997,3.0332878,,,,,,,,,,,,,, -106300,1.7492875,4.9577136,,,,,,,,,,,,,, -106400,2.0610428,2.9166389,,,,,,,,,,,,,, -106500,2.4286945,2.9497566,,,,,,,,,,,,,, -106600,2.7226171,3.0285442,,,,,,,,,,,,,, -106700,2.0263584,2.8418858,,,,,,,,,,,,,, -106800,2.5814168,4.862165,,,,,,,,,,,,,, -106900,2.3603287,3.152338,,,,,,,,,,,,,, -107000,1.9507474,3.2242613,,,,,,,,,,,,,, -107099,,,0.7011132836341858,1.3185851573944092,0.6495000123977661,1.5532654523849487,50000.0,0.5246000289916992,2.1825060844421387,10000.0,48778.02906394005,52983.47541809082,48778.02906394005,4193.563049793243,5.832031965255737,0.0 -107100,1.8236767,4.6609845,,,,,,,,,,,,,, -107200,1.790159,3.3863907,,,,,,,,,,,,,, -107300,2.2222054,3.2569056,,,,,,,,,,,,,, -107400,2.101677,2.985216,,,,,,,,,,,,,, -107500,1.850246,4.515249,,,,,,,,,,,,,, -107600,2.0675845,2.8844445,,,,,,,,,,,,,, -107700,1.9865571,2.8046908,,,,,,,,,,,,,, -107800,1.9554304,3.152082,,,,,,,,,,,,,, -107900,1.9176903,2.9217885,,,,,,,,,,,,,, -108000,1.931172,5.0877047,,,,,,,,,,,,,, -108022,,,0.7089062333106995,1.286357879638672,0.6575999855995178,1.5152268409729004,50000.0,0.5370000004768372,2.137031316757202,10000.0,49198.16246008873,53441.33406162262,49198.16246008873,4231.1956782341,5.872812747955322,0.0 -108100,2.4650016,3.002092,,,,,,,,,,,,,, -108200,1.9840212,4.3073225,,,,,,,,,,,,,, -108300,1.9682561,2.984622,,,,,,,,,,,,,, -108400,1.9241241,4.8702984,,,,,,,,,,,,,, -108500,1.8353368,3.6181848,,,,,,,,,,,,,, -108600,2.0493116,3.0519185,,,,,,,,,,,,,, -108700,2.447451,2.9569006,,,,,,,,,,,,,, -108800,1.8207389,4.0728736,,,,,,,,,,,,,, -108900,2.1465354,2.9565716,,,,,,,,,,,,,, -108948,,,0.7115429639816284,1.2524975538253784,0.6578199863433838,1.5066378116607666,50000.0,0.5306000113487244,2.1434850692749023,10000.0,49618.37132978439,53901.06821775437,49618.37132978439,4270.626555204392,5.915415287017822,0.0 -109000,2.0757651,2.9941952,,,,,,,,,,,,,, -109100,2.0238051,5.1098313,,,,,,,,,,,,,, -109200,1.78702,4.219658,,,,,,,,,,,,,, -109300,2.2280643,3.0063744,,,,,,,,,,,,,, -109400,2.1279223,2.86655,,,,,,,,,,,,,, -109500,2.1796155,2.9070702,,,,,,,,,,,,,, -109600,1.835983,3.9839458,,,,,,,,,,,,,, -109700,2.2304614,3.1705213,,,,,,,,,,,,,, -109800,1.7533962,4.6102266,,,,,,,,,,,,,, -109871,,,0.7302343845367432,1.2425791025161743,0.6523799896240234,1.5649343729019165,50000.0,0.5297999978065491,2.186164140701294,10000.0,50038.27680063248,54360.60054755211,50038.27680063248,4310.15029501915,5.962715864181519,0.0 -109900,2.153904,2.8869092,,,,,,,,,,,,,, -110000,2.400286,2.9851232,,,,,,,,,,,,,, -110100,2.0788865,2.994146,,,,,,,,,,,,,, -110200,1.9490923,4.760349,,,,,,,,,,,,,, -110300,1.8517135,3.7172973,,,,,,,,,,,,,, -110400,1.982064,4.586409,,,,,,,,,,,,,, -110500,2.2430258,2.9689658,,,,,,,,,,,,,, -110600,1.8221923,4.5570383,,,,,,,,,,,,,, -110700,1.7604499,3.456143,,,,,,,,,,,,,, -110790,,,0.7073437571525574,1.2899636030197144,0.6584799885749817,1.5123716592788696,50000.0,0.5329000353813171,2.155395746231079,10000.0,50458.470808029175,54820.40697169304,50458.470808029175,4349.666513442993,6.006732702255249,0.0 -110800,2.0363529,4.627378,,,,,,,,,,,,,, -110900,1.8533285,4.371954,,,,,,,,,,,,,, -111000,2.0853095,2.9580271,,,,,,,,,,,,,, -111100,2.055991,3.5939121,,,,,,,,,,,,,, -111200,2.0655887,3.561057,,,,,,,,,,,,,, -111300,1.8497244,3.6864114,,,,,,,,,,,,,, -111400,3.6872492,5.0200653,,,,,,,,,,,,,, -111500,2.1893091,2.9504685,,,,,,,,,,,,,, -111600,1.9704409,3.5930307,,,,,,,,,,,,,, -111700,2.0895534,2.980297,,,,,,,,,,,,,, -111713,,,0.7128515243530273,1.2536503076553345,0.6627599596977234,1.4864120483398438,50000.0,0.5364000201225281,2.1200437545776367,10000.0,50878.71767401695,55277.81778669357,50878.71767401695,4386.731926679611,6.05235767364502,0.0 -111800,1.9904765,4.3093452,,,,,,,,,,,,,, -111900,2.0429008,2.9156947,,,,,,,,,,,,,, -112000,2.306969,2.8489032,,,,,,,,,,,,,, -112100,2.2360122,2.9157364,,,,,,,,,,,,,, -112200,2.2654312,5.053862,,,,,,,,,,,,,, -112300,2.035413,3.7921693,,,,,,,,,,,,,, -112400,2.0936615,3.0308235,,,,,,,,,,,,,, -112500,2.3596177,2.9021788,,,,,,,,,,,,,, -112600,1.8965182,3.3997104,,,,,,,,,,,,,, -112635,,,0.7286913990974426,1.2007211446762085,0.6609399914741516,1.50294291973114,50000.0,0.5369000434875488,2.1409804821014404,10000.0,51298.70504426956,55736.1331076622,51298.70504426956,4424.739602088928,6.320374011993408,0.0 -112700,2.2814515,2.9026175,,,,,,,,,,,,,, -112800,2.2034626,4.7611966,,,,,,,,,,,,,, -112900,2.1343298,2.8277793,,,,,,,,,,,,,, -113000,2.1980827,3.009848,,,,,,,,,,,,,, -113100,2.6774454,2.8655932,,,,,,,,,,,,,, -113200,2.188924,2.9902282,,,,,,,,,,,,,, -113300,2.0704648,2.9116914,,,,,,,,,,,,,, -113400,2.2793963,2.7668815,,,,,,,,,,,,,, -113500,2.0778918,2.8701036,,,,,,,,,,,,,, -113557,,,0.7140820026397705,1.3028429746627808,0.6602199673652649,1.5292065143585205,50000.0,0.532800018787384,2.165210485458374,10000.0,51718.90090465546,56191.47248148918,51718.90090465546,4459.790710687637,6.36035680770874,0.0 -113600,2.060001,2.9374714,,,,,,,,,,,,,, -113700,2.3105876,2.9052436,,,,,,,,,,,,,, -113800,2.0479786,3.4246173,,,,,,,,,,,,,, -113900,2.2147853,2.909605,,,,,,,,,,,,,, -114000,2.2283745,4.615859,,,,,,,,,,,,,, -114100,2.0602562,4.9442854,,,,,,,,,,,,,, -114200,1.9604207,4.164029,,,,,,,,,,,,,, -114300,2.0726364,2.9549253,,,,,,,,,,,,,, -114400,2.1982558,4.898988,,,,,,,,,,,,,, -114479,,,0.7202343344688416,1.2131118774414062,0.6656999588012695,1.4627410173416138,50000.0,0.5388000011444092,2.1023590564727783,10000.0,52139.10345339775,56650.11271905899,52139.10345339775,4498.1347053051,6.402165174484253,0.0 -114500,2.1241634,2.8965347,,,,,,,,,,,,,, -114600,2.4054537,2.9312348,,,,,,,,,,,,,, -114700,2.3152788,2.960701,,,,,,,,,,,,,, -114800,2.3587594,2.9628642,,,,,,,,,,,,,, -114900,2.2085218,2.875593,,,,,,,,,,,,,, -115000,2.018119,4.3772326,,,,,,,,,,,,,, -115100,2.2114513,2.8651464,,,,,,,,,,,,,, -115200,2.0983906,4.5274763,,,,,,,,,,,,,, -115300,2.120391,3.0092087,,,,,,,,,,,,,, -115400,2.1783433,3.5696478,,,,,,,,,,,,,, -115403,,,0.7263085842132568,1.1966229677200315,0.6647199988365173,1.4690905809402466,50000.0,0.539400041103363,2.100294589996338,10000.0,52559.15979242325,57109.05661630631,52559.15979242325,4536.924311637878,6.447792291641235,0.0 -115500,2.2148342,3.2665029,,,,,,,,,,,,,, -115600,1.9713222,3.972769,,,,,,,,,,,,,, -115700,2.1727276,2.7687464,,,,,,,,,,,,,, -115800,1.7822032,3.831616,,,,,,,,,,,,,, -115900,2.7473083,3.9785771,,,,,,,,,,,,,, -116000,2.185319,2.888587,,,,,,,,,,,,,, -116100,1.8310738,3.9472795,,,,,,,,,,,,,, -116200,2.2151432,2.8347328,,,,,,,,,,,,,, -116300,2.4382796,2.8713715,,,,,,,,,,,,,, -116327,,,0.7205468416213989,1.2346270084381104,0.6696799993515015,1.4559895992279053,50000.0,0.5482000112533569,2.084333658218384,10000.0,52979.50097155571,57569.00668978691,52979.50097155571,4576.432106971741,6.493849992752075,0.0 -116400,2.4385796,2.8891714,,,,,,,,,,,,,, -116500,2.3791473,2.9384217,,,,,,,,,,,,,, -116600,2.2264347,3.7294111,,,,,,,,,,,,,, -116700,2.231746,2.8779721,,,,,,,,,,,,,, -116800,2.2546647,2.8034241,,,,,,,,,,,,,, -116900,2.146767,2.8465147,,,,,,,,,,,,,, -117000,1.9561547,3.8072317,,,,,,,,,,,,,, -117100,2.3409042,4.6834927,,,,,,,,,,,,,, -117200,2.1135178,4.708544,,,,,,,,,,,,,, -117250,,,0.7180468440055847,1.26100754737854,0.6637200117111206,1.4970401525497437,50000.0,0.539900004863739,2.131301164627075,10000.0,53399.68098139763,58026.68306350708,53399.68098139763,4613.82989192009,6.539278507232666,0.0 -117300,2.310292,2.981729,,,,,,,,,,,,,, -117400,2.2237968,2.9225688,,,,,,,,,,,,,, -117500,2.5007997,2.8646119,,,,,,,,,,,,,, -117600,2.4729385,4.9556828,,,,,,,,,,,,,, -117700,2.18414,2.8421488,,,,,,,,,,,,,, -117800,2.2099473,2.7955396,,,,,,,,,,,,,, -117900,2.2615547,4.8516665,,,,,,,,,,,,,, -118000,2.3022728,2.8579893,,,,,,,,,,,,,, -118100,2.2153678,2.941456,,,,,,,,,,,,,, -118173,,,0.7331640720367432,1.1802427768707275,0.6723799705505371,1.4435547590255735,50000.0,0.5494000315666199,2.0697150230407715,10000.0,53819.9982714653,58486.09160208702,53819.9982714653,4652.822212696075,6.585542917251587,0.0 -118200,2.671162,5.1631413,,,,,,,,,,,,,, -118300,1.9792804,4.6027827,,,,,,,,,,,,,, -118400,2.3283844,2.8169117,,,,,,,,,,,,,, -118500,2.1119905,2.9794905,,,,,,,,,,,,,, -118600,2.378061,2.7418435,,,,,,,,,,,,,, -118700,2.1165824,3.007834,,,,,,,,,,,,,, -118800,2.2976081,4.8022294,,,,,,,,,,,,,, -118900,2.3631337,3.2568593,,,,,,,,,,,,,, -119000,2.008499,3.894805,,,,,,,,,,,,,, -119096,,,0.7335156202316284,1.199196219444275,0.6699999570846558,1.4819027185440063,50000.0,0.5421000123023987,2.1308164596557617,10000.0,54240.08188533783,58945.64118242264,54240.08188533783,4692.190171718597,6.631305694580078,0.0 -119100,2.355402,2.7876668,,,,,,,,,,,,,, -119200,2.289116,2.9621856,,,,,,,,,,,,,, -119300,2.639063,3.1162508,,,,,,,,,,,,,, -119400,2.114803,3.301422,,,,,,,,,,,,,, -119500,2.2807918,2.8063378,,,,,,,,,,,,,, -119600,2.0938718,3.446167,,,,,,,,,,,,,, -119700,2.1694486,4.4908876,,,,,,,,,,,,,, -119800,2.240546,2.7025704,,,,,,,,,,,,,, -119900,2.1599762,2.8858173,,,,,,,,,,,,,, -120000,2.4178615,2.7091787,,,,,,,,,,,,,, -120018,,,0.7240429520606995,1.2205166816711426,0.6732999682426453,1.439631104469299,50000.0,0.5497000217437744,2.077719211578369,10000.0,54660.09229564667,59404.036954164505,54660.09229564667,4730.476469278336,6.678221940994263,0.0 -120100,2.0143747,3.7338722,,,,,,,,,,,,,, -120200,2.4425116,2.904192,,,,,,,,,,,,,, -120300,2.121667,4.3231115,,,,,,,,,,,,,, -120400,1.938065,3.533318,,,,,,,,,,,,,, -120500,2.2753491,2.798964,,,,,,,,,,,,,, -120600,1.9605021,3.5913846,,,,,,,,,,,,,, -120700,2.2462702,2.8173656,,,,,,,,,,,,,, -120800,2.1583214,3.7269368,,,,,,,,,,,,,, -120900,2.8456514,4.9750557,,,,,,,,,,,,,, -120944,,,0.73388671875,1.1822478771209717,0.676099956035614,1.439404010772705,50000.0,0.5525000095367432,2.0614137649536133,10000.0,55080.20759105682,59864.2053706646,55080.20759105682,4770.4309067726135,6.724869251251221,0.0 -121000,2.354821,2.8270855,,,,,,,,,,,,,, -121100,2.4974947,2.8897943,,,,,,,,,,,,,, -121200,2.405977,2.8648307,,,,,,,,,,,,,, -121300,2.0570734,3.544696,,,,,,,,,,,,,, -121400,2.4456477,2.89156,,,,,,,,,,,,,, -121500,2.4376152,2.916679,,,,,,,,,,,,,, -121600,2.309041,3.9708867,,,,,,,,,,,,,, -121700,2.233138,2.8025386,,,,,,,,,,,,,, -121800,2.547131,2.8541203,,,,,,,,,,,,,, -121867,,,0.7496874928474426,1.114646315574646,0.6793599724769592,1.4229241609573364,50000.0,0.5576000213623047,2.0427939891815186,10000.0,55500.20226883888,60324.041732788086,55500.20226883888,4810.173624038696,6.771934986114502,0.0 -121900,2.2686014,4.3594756,,,,,,,,,,,,,, -122000,2.2713814,4.922792,,,,,,,,,,,,,, -122100,2.299239,2.7955003,,,,,,,,,,,,,, -122200,2.2089298,3.2191305,,,,,,,,,,,,,, -122300,2.4121408,2.8276107,,,,,,,,,,,,,, -122400,2.3323765,3.1687398,,,,,,,,,,,,,, -122500,2.3941853,2.8360548,,,,,,,,,,,,,, -122600,2.3127348,2.8684945,,,,,,,,,,,,,, -122700,2.6585984,2.7527125,,,,,,,,,,,,,, -122791,,,0.7269140481948853,1.2267969846725464,0.6785999536514282,1.4419078826904297,50000.0,0.5555000305175781,2.0610427856445312,10000.0,55920.51446223259,60784.256432294846,55920.51446223259,4849.972870588303,6.823628664016724,0.0 -122800,2.423272,2.9183602,,,,,,,,,,,,,, -122900,2.3431053,3.2741926,,,,,,,,,,,,,, -123000,2.6138384,3.2357385,,,,,,,,,,,,,, -123100,2.6098106,3.165083,,,,,,,,,,,,,, -123200,2.3722353,2.75245,,,,,,,,,,,,,, -123300,2.4506586,2.7984214,,,,,,,,,,,,,, -123400,2.3550224,2.7849708,,,,,,,,,,,,,, -123500,2.5103416,2.8142424,,,,,,,,,,,,,, -123600,2.1098757,2.9686544,,,,,,,,,,,,,, -123700,2.241016,2.6590657,,,,,,,,,,,,,, -123715,,,0.7386718392372131,1.1407707929611206,0.6824399828910828,1.3840407133102417,50000.0,0.560699999332428,2.007731914520264,10000.0,56340.72695922852,61241.7313709259,56340.72695922852,4887.135262012482,6.871109485626221,0.0 -123800,2.5136578,2.8383794,,,,,,,,,,,,,, -123900,2.3707383,2.999714,,,,,,,,,,,,,, -124000,2.436695,2.836978,,,,,,,,,,,,,, -124100,2.512061,2.8718352,,,,,,,,,,,,,, -124200,2.5513296,2.8818865,,,,,,,,,,,,,, -124300,2.2208672,3.1110275,,,,,,,,,,,,,, -124400,2.529256,2.8685613,,,,,,,,,,,,,, -124500,2.2607956,3.100016,,,,,,,,,,,,,, -124600,2.3395839,4.9228973,,,,,,,,,,,,,, -124640,,,0.7439648509025574,1.139055848121643,0.6798799633979797,1.4270832538604736,50000.0,0.5530000329017639,2.0573394298553467,10000.0,56760.75370979309,61701.31235575676,56760.75370979309,4926.5925216674805,6.9160919189453125,0.0 -124700,2.5302908,3.1080668,,,,,,,,,,,,,, -124800,2.6228962,2.8000627,,,,,,,,,,,,,, -124900,2.1705725,3.5744898,,,,,,,,,,,,,, -125000,2.4311998,2.7403064,,,,,,,,,,,,,, -125100,2.5310874,2.6481845,,,,,,,,,,,,,, -125200,2.8938363,2.9113994,,,,,,,,,,,,,, -125300,2.5969634,2.8136718,,,,,,,,,,,,,, -125400,2.4548109,4.750865,,,,,,,,,,,,,, -125500,2.928298,2.809116,,,,,,,,,,,,,, -125563,,,0.7366601228713989,1.1655316352844238,0.6827999949455261,1.4155220985412598,50000.0,0.5559000372886658,2.040470600128174,10000.0,57180.97926735878,62161.31812500954,57180.97926735878,4966.275221347809,6.962021350860596,0.0 -125600,2.354147,3.5083432,,,,,,,,,,,,,, -125700,2.4209788,2.7758276,,,,,,,,,,,,,, -125800,2.5404603,2.8625515,,,,,,,,,,,,,, -125900,2.6953456,3.507349,,,,,,,,,,,,,, -126000,2.822657,4.015558,,,,,,,,,,,,,, -126100,2.6073058,2.8023415,,,,,,,,,,,,,, -126200,2.7197607,2.7215137,,,,,,,,,,,,,, -126300,2.795042,4.6632943,,,,,,,,,,,,,, -126400,2.3742468,3.2477894,,,,,,,,,,,,,, -126485,,,0.7461132407188416,1.1315586566925049,0.6882199645042419,1.3758423328399658,50000.0,0.566100001335144,1.9890079498291016,10000.0,57601.43358683586,62619.15861034393,57601.43358683586,5003.563591241837,7.007826805114746,0.0 -126500,3.091918,4.9646883,,,,,,,,,,,,,, -126600,2.1315289,4.038924,,,,,,,,,,,,,, -126700,2.4097028,2.8557506,,,,,,,,,,,,,, -126800,2.6982884,2.926898,,,,,,,,,,,,,, -126900,2.4320033,2.747024,,,,,,,,,,,,,, -127000,2.3446496,3.0303655,,,,,,,,,,,,,, -127100,2.2579668,3.8473086,,,,,,,,,,,,,, -127200,2.4650912,4.721118,,,,,,,,,,,,,, -127300,2.56783,2.685131,,,,,,,,,,,,,, -127400,2.3767686,3.000869,,,,,,,,,,,,,, -127409,,,0.7524218559265137,1.1364513635635376,0.6879400014877319,1.4114627838134766,50000.0,0.5671000480651855,2.030573844909668,10000.0,58021.75379371643,63077.11794781685,58021.75379371643,5041.106873750687,7.051853656768799,0.0 -127500,2.4332669,4.3561645,,,,,,,,,,,,,, -127600,2.2722745,2.7117026,,,,,,,,,,,,,, -127700,2.3268144,3.2227182,,,,,,,,,,,,,, -127800,2.290448,3.0092123,,,,,,,,,,,,,, -127900,2.641818,2.7663283,,,,,,,,,,,,,, -128000,2.4205668,4.13624,,,,,,,,,,,,,, -128100,2.41805,3.6867056,,,,,,,,,,,,,, -128200,2.7045865,4.728705,,,,,,,,,,,,,, -128300,2.5477016,4.9593854,,,,,,,,,,,,,, -128334,,,0.74964839220047,1.1114648580551147,0.689799964427948,1.371135950088501,50000.0,0.5678000450134277,1.993343472480774,10000.0,58441.99978637695,63535.41212558746,58441.99978637695,5079.05620598793,7.098832368850708,0.0 -128400,2.5085082,2.9268973,,,,,,,,,,,,,, -128500,2.5041263,2.686706,,,,,,,,,,,,,, -128600,2.4737127,3.1586807,,,,,,,,,,,,,, -128700,2.5512056,3.1155424,,,,,,,,,,,,,, -128800,2.7900116,2.8167822,,,,,,,,,,,,,, -128900,2.6261845,2.717118,,,,,,,,,,,,,, -129000,2.5085387,2.6147935,,,,,,,,,,,,,, -129100,2.566027,4.8058543,,,,,,,,,,,,,, -129200,2.7244427,2.99827,,,,,,,,,,,,,, -129258,,,0.7498828172683716,1.1026932001113892,0.6937400102615356,1.3566606044769287,50000.0,0.5713000297546387,1.966183304786682,10000.0,58861.942873716354,63992.40976333618,58861.942873716354,5116.010575294495,7.14700722694397,0.0 -129300,2.3229492,3.8440185,,,,,,,,,,,,,, -129400,2.732241,2.760294,,,,,,,,,,,,,, -129500,3.4235322,4.767371,,,,,,,,,,,,,, -129600,2.7992692,2.796432,,,,,,,,,,,,,, -129700,2.4777443,2.759315,,,,,,,,,,,,,, -129800,2.811134,2.8685405,,,,,,,,,,,,,, -129900,3.0010402,2.7434597,,,,,,,,,,,,,, -130000,2.5563767,2.8559425,,,,,,,,,,,,,, -130100,2.4902737,3.320577,,,,,,,,,,,,,, -130184,,,0.75355464220047,1.101863980293274,0.6942200064659119,1.3657922744750977,50000.0,0.5699000358581543,1.9783955812454224,10000.0,59282.19617843628,64452.22687602043,59282.19617843628,5155.474381446838,7.194597005844116,0.0 -130200,2.476028,3.0738668,,,,,,,,,,,,,, -130300,2.7039466,4.9354815,,,,,,,,,,,,,, -130400,2.7390628,2.8567061,,,,,,,,,,,,,, -130500,2.7420807,4.8154554,,,,,,,,,,,,,, -130600,2.3904903,4.2193193,,,,,,,,,,,,,, -130700,2.91387,2.7544432,,,,,,,,,,,,,, -130800,2.47598,4.017023,,,,,,,,,,,,,, -130900,2.6901288,2.6491086,,,,,,,,,,,,,, -131000,2.4681728,3.4172645,,,,,,,,,,,,,, -131100,2.6466374,2.7271082,,,,,,,,,,,,,, -131104,,,0.7650976181030273,1.0323588848114014,0.6943999528884888,1.344438910484314,50000.0,0.5657000541687012,1.97388756275177,10000.0,59702.45468664169,64914.25153064728,59702.45468664169,5197.139461517334,7.241278171539307,0.0 -131200,2.772134,2.7843318,,,,,,,,,,,,,, -131300,2.8626766,2.6930223,,,,,,,,,,,,,, -131400,2.6478744,2.8573906,,,,,,,,,,,,,, -131500,2.6767495,2.748658,,,,,,,,,,,,,, -131600,2.5675027,4.6367154,,,,,,,,,,,,,, -131700,2.488112,4.714795,,,,,,,,,,,,,, -131800,2.442855,3.745042,,,,,,,,,,,,,, -131900,2.7438815,3.3292522,,,,,,,,,,,,,, -132000,2.5225487,4.2265368,,,,,,,,,,,,,, -132025,,,0.7515429258346558,1.0836161375045776,0.6946199536323547,1.3420709371566772,50000.0,0.5679000020027161,1.9602190256118768,10000.0,60122.40482163429,65373.663570165634,60122.40482163429,5236.505722999573,7.285628318786621,0.0 -132100,2.6721766,2.6968908,,,,,,,,,,,,,, -132200,2.6213772,3.038326,,,,,,,,,,,,,, -132300,2.496189,4.211146,,,,,,,,,,,,,, -132400,3.0006669,2.7952037,,,,,,,,,,,,,, -132500,2.5054855,2.6977968,,,,,,,,,,,,,, -132600,2.551792,4.017948,,,,,,,,,,,,,, -132700,2.8025534,2.873208,,,,,,,,,,,,,, -132800,2.7908735,2.8098154,,,,,,,,,,,,,, -132900,3.0636725,4.6353345,,,,,,,,,,,,,, -132951,,,0.758593738079071,1.0445502996444702,0.6947999596595764,1.3174066543579102,50000.0,0.5760000348091125,1.9236680269241333,10000.0,60542.771542072296,65831.88937497139,60542.771542072296,5274.264638900757,7.332998752593994,0.0 -133000,2.747426,2.749865,,,,,,,,,,,,,, -133100,2.8001316,2.7145348,,,,,,,,,,,,,, -133200,2.5825958,2.6433182,,,,,,,,,,,,,, -133300,2.8685915,2.6930418,,,,,,,,,,,,,, -133400,2.8041995,2.804288,,,,,,,,,,,,,, -133500,3.038415,2.6894283,,,,,,,,,,,,,, -133600,2.6911082,2.970066,,,,,,,,,,,,,, -133700,2.8081172,2.67823,,,,,,,,,,,,,, -133800,2.9538884,3.8527198,,,,,,,,,,,,,, -133876,,,0.7695898413658142,1.018075704574585,0.6988999843597412,1.3271454572677612,50000.0,0.5711000561714172,1.9531786441802976,10000.0,60962.8595468998,66291.97956442833,60962.8595468998,5314.167495965958,7.379750490188599,0.0 -133900,2.5969372,2.9386752,,,,,,,,,,,,,, -134000,2.783064,2.6873627,,,,,,,,,,,,,, -134100,2.8120394,2.7533634,,,,,,,,,,,,,, -134200,2.7393045,3.04967,,,,,,,,,,,,,, -134300,2.6334536,2.8652866,,,,,,,,,,,,,, -134400,2.9338336,2.5953622,,,,,,,,,,,,,, -134500,2.6458426,2.5859373,,,,,,,,,,,,,, -134600,2.8671029,2.713525,,,,,,,,,,,,,, -134700,2.6228018,4.043987,,,,,,,,,,,,,, -134798,,,0.7604882717132568,1.0434051752090454,0.6985399723052979,1.303519606590271,50000.0,0.5707000494003296,1.9330246448516848,10000.0,61382.80419540405,66751.209690094,61382.80419540405,5353.349631071091,7.431820392608643,0.0 -134800,2.9160438,2.6671376,,,,,,,,,,,,,, -134900,2.9140403,2.9326777,,,,,,,,,,,,,, -135000,2.6674225,3.3105996,,,,,,,,,,,,,, -135100,2.713539,2.9558558,,,,,,,,,,,,,, -135200,2.837084,4.2431254,,,,,,,,,,,,,, -135300,2.8442783,2.6569967,,,,,,,,,,,,,, -135400,2.917402,2.6353073,,,,,,,,,,,,,, -135500,3.1246023,4.693748,,,,,,,,,,,,,, -135600,2.8195605,2.6609457,,,,,,,,,,,,,, -135700,3.138518,4.556306,,,,,,,,,,,,,, -135721,,,0.7647656202316284,1.0401690006256104,0.6980400085449219,1.3200068473815918,50000.0,0.5758000016212463,1.9241626262664795,10000.0,61803.14104223251,67210.00192761421,61803.14104223251,5391.70734000206,7.477010011672974,0.0 -135800,2.8076897,2.9587543,,,,,,,,,,,,,, -135900,3.0607684,2.6340158,,,,,,,,,,,,,, -136000,2.6471782,3.1619568,,,,,,,,,,,,,, -136100,2.6333735,2.759099,,,,,,,,,,,,,, -136200,2.7649417,3.9968195,,,,,,,,,,,,,, -136300,2.8850675,4.646353,,,,,,,,,,,,,, -136400,3.0985246,4.6982408,,,,,,,,,,,,,, -136500,2.6626797,2.6431167,,,,,,,,,,,,,, -136600,3.6397002,3.1986604,,,,,,,,,,,,,, -136644,,,0.7710546851158142,1.0173821449279783,0.7045800089836121,1.311528205871582,50000.0,0.5790000557899475,1.9301297664642327,10000.0,62223.05319976807,67669.96811890602,62223.05319976807,5431.663053035736,7.522252082824707,0.0 -136700,2.810893,2.8727074,,,,,,,,,,,,,, -136800,2.7845426,3.8096159,,,,,,,,,,,,,, -136900,2.8976755,2.5517068,,,,,,,,,,,,,, -137000,2.7902842,4.3078303,,,,,,,,,,,,,, -137100,2.8462048,2.5481524,,,,,,,,,,,,,, -137200,2.7298946,2.7008708,,,,,,,,,,,,,, -137300,3.3149865,4.756173,,,,,,,,,,,,,, -137400,3.0054178,2.6939816,,,,,,,,,,,,,, -137500,3.186142,3.812648,,,,,,,,,,,,,, -137568,,,0.7683789134025574,1.041908621788025,0.7035399675369263,1.3146579265594482,50000.0,0.5751000046730042,1.937660217285156,10000.0,62643.12946271896,68129.23905944824,62643.12946271896,5470.751053571701,7.575902938842773,0.0 -137600,2.489752,2.9541137,,,,,,,,,,,,,, -137700,2.8842463,2.5487103,,,,,,,,,,,,,, -137800,2.884293,2.7326972,,,,,,,,,,,,,, -137900,3.283981,3.1340537,,,,,,,,,,,,,, -138000,2.5945272,3.571961,,,,,,,,,,,,,, -138100,2.907366,4.0863333,,,,,,,,,,,,,, -138200,2.830152,2.739666,,,,,,,,,,,,,, -138300,3.2137876,2.7918897,,,,,,,,,,,,,, -138400,3.0249858,2.6089182,,,,,,,,,,,,,, -138489,,,0.7689452767372131,1.025770902633667,0.7057200074195862,1.2924665212631226,50000.0,0.5884000062942505,1.885142922401428,10000.0,63063.101344347,68589.41958975792,63063.101344347,5510.862103939056,7.621830224990845,0.0 -138500,2.704395,4.053815,,,,,,,,,,,,,, -138600,3.2570062,4.33329,,,,,,,,,,,,,, -138700,2.9406366,2.7066634,,,,,,,,,,,,,, -138800,2.9740872,3.2210698,,,,,,,,,,,,,, -138900,2.904739,4.135206,,,,,,,,,,,,,, -139000,2.8625183,3.1474037,,,,,,,,,,,,,, -139100,3.0048704,2.612064,,,,,,,,,,,,,, -139200,3.0137455,2.6603615,,,,,,,,,,,,,, -139300,3.4655757,3.0395956,,,,,,,,,,,,,, -139400,2.9815824,2.7091281,,,,,,,,,,,,,, -139413,,,0.7711523175239563,1.000161051750183,0.7072599530220032,1.2786537408828735,50000.0,0.5861999988555908,1.8938807249069207,10000.0,63483.11455059052,69048.47295331955,63483.11455059052,5549.8006637096405,7.670376300811768,0.0 -139500,2.8771744,2.6178946,,,,,,,,,,,,,, -139600,3.2349093,3.7563765,,,,,,,,,,,,,, -139700,3.0857608,2.6440744,,,,,,,,,,,,,, -139800,2.8573623,2.638432,,,,,,,,,,,,,, -139900,3.0146854,3.3265579,,,,,,,,,,,,,, -140000,3.546686,2.5294156,,,,,,,,,,,,,, -140100,3.3966818,2.554689,,,,,,,,,,,,,, -140200,3.5512547,4.6413655,,,,,,,,,,,,,, -140300,2.77737,2.5752218,,,,,,,,,,,,,, -140336,,,0.7856835722923279,0.9613086581230164,0.7094999551773071,1.2880979776382446,50000.0,0.5909000039100647,1.8893526792526243,10000.0,63903.165801763535,69508.20226073265,63903.165801763535,5589.37908911705,7.717724561691284,0.0 -140400,3.2745602,3.4906573,,,,,,,,,,,,,, -140500,3.073835,2.6871476,,,,,,,,,,,,,, -140600,3.0289574,3.3115458,,,,,,,,,,,,,, -140700,2.8867614,3.4590802,,,,,,,,,,,,,, -140800,3.1271644,2.6570668,,,,,,,,,,,,,, -140900,2.831654,3.1809623,,,,,,,,,,,,,, -141000,3.0346603,2.6770687,,,,,,,,,,,,,, -141100,3.180394,3.0215921,,,,,,,,,,,,,, -141200,3.0127187,3.7633033,,,,,,,,,,,,,, -141260,,,0.7727343440055847,1.010166049003601,0.7104399800300598,1.2714622020721436,50000.0,0.5878000259399414,1.8784040212631223,10000.0,64323.22742891312,69969.97873592377,64323.22742891312,5630.997187137604,7.762471675872803,0.0 -141300,3.2843244,2.7138345,,,,,,,,,,,,,, -141400,3.545764,4.1989536,,,,,,,,,,,,,, -141500,3.1773741,2.6914058,,,,,,,,,,,,,, -141600,3.4603336,4.519435,,,,,,,,,,,,,, -141700,3.1146948,4.3946977,,,,,,,,,,,,,, -141800,3.128177,2.585063,,,,,,,,,,,,,, -141900,3.7251954,2.6437466,,,,,,,,,,,,,, -142000,3.6418839,2.6209831,,,,,,,,,,,,,, -142100,3.6511586,4.3185024,,,,,,,,,,,,,, -142185,,,0.7766211032867432,0.98188453912735,0.7114999890327454,1.2591363191604614,50000.0,0.5889000296592712,1.8770174980163568,10000.0,64743.2424428463,70428.5822839737,64743.2424428463,5669.488003015518,7.80783200263977,0.0 -142200,3.35074,2.5713186,,,,,,,,,,,,,, -142300,3.5100448,2.6055562,,,,,,,,,,,,,, -142400,3.1314175,3.011155,,,,,,,,,,,,,, -142500,3.1485615,2.6328838,,,,,,,,,,,,,, -142600,3.5045807,2.9459047,,,,,,,,,,,,,, -142700,3.204449,3.055512,,,,,,,,,,,,,, -142800,2.9531183,3.55896,,,,,,,,,,,,,, -142900,3.103227,2.51082,,,,,,,,,,,,,, -143000,3.084276,2.6040711,,,,,,,,,,,,,, -143100,3.4688861,2.5895123,,,,,,,,,,,,,, -143108,,,0.7864062190055847,0.9273659586906432,0.7152000069618225,1.233481764793396,50000.0,0.5932000279426575,1.8319133520126345,10000.0,65163.4554746151,70887.26061153412,65163.4554746151,5707.855123996735,7.854290246963501,0.0 -143200,3.5616348,2.6359746,,,,,,,,,,,,,, -143300,3.2612264,3.062431,,,,,,,,,,,,,, -143400,2.8205793,3.0656374,,,,,,,,,,,,,, -143500,3.5123496,2.604966,,,,,,,,,,,,,, -143600,3.274515,2.6298168,,,,,,,,,,,,,, -143700,3.2092538,2.6372118,,,,,,,,,,,,,, -143800,3.002876,2.9798436,,,,,,,,,,,,,, -143900,3.1610246,3.9316094,,,,,,,,,,,,,, -144000,3.0271065,4.107193,,,,,,,,,,,,,, -144030,,,0.7824804782867432,0.9607452154159546,0.7184999585151672,1.2336714267730713,50000.0,0.593500018119812,1.8458307981491089,10000.0,65583.5791592598,71346.5245103836,65583.5791592598,5746.891888856888,7.905426979064941,0.0 -144100,3.2681375,3.2399662,,,,,,,,,,,,,, -144200,3.4239097,2.6230116,,,,,,,,,,,,,, -144300,3.0573823,2.542194,,,,,,,,,,,,,, -144400,3.9084353,4.1891193,,,,,,,,,,,,,, -144500,3.558512,2.572164,,,,,,,,,,,,,, -144600,3.7162557,3.754285,,,,,,,,,,,,,, -144700,3.1624148,3.398246,,,,,,,,,,,,,, -144800,3.6782112,4.664362,,,,,,,,,,,,,, -144900,3.4421191,2.5530877,,,,,,,,,,,,,, -144951,,,0.7843359112739563,0.9414471983909608,0.7197799682617188,1.2216196060180664,50000.0,0.5945000052452087,1.8343600034713743,10000.0,66003.87094473839,71806.72054195404,66003.87094473839,5786.692119598389,7.957865715026855,0.0 -145000,3.4394069,4.6661634,,,,,,,,,,,,,, -145100,3.3275375,2.5697138,,,,,,,,,,,,,, -145200,3.4982953,2.5587475,,,,,,,,,,,,,, -145300,3.2347188,2.7790418,,,,,,,,,,,,,, -145400,3.2468934,2.5433836,,,,,,,,,,,,,, -145500,16.976156,3.3755133,,,,,,,,,,,,,, -145600,3.5675986,3.8357832,,,,,,,,,,,,,, -145700,4.13948,4.6791077,,,,,,,,,,,,,, -145800,3.5115256,2.5984578,,,,,,,,,,,,,, -145873,,,0.7897070050239563,0.930291473865509,0.7209399938583374,1.2243276834487915,50000.0,0.5939000248908997,1.827216863632202,10000.0,66423.95629882812,72265.67800307274,66423.95629882812,5825.466367721558,8.004544019699097,0.0 -145900,3.2733412,2.6670332,,,,,,,,,,,,,, -146000,3.4623783,2.7233815,,,,,,,,,,,,,, -146100,3.556769,2.930801,,,,,,,,,,,,,, -146200,2.980689,3.533544,,,,,,,,,,,,,, -146300,3.1344895,2.8582335,,,,,,,,,,,,,, -146400,3.432796,3.8138494,,,,,,,,,,,,,, -146500,3.5741751,4.399704,,,,,,,,,,,,,, -146600,3.2750978,2.449175,,,,,,,,,,,,,, -146700,3.329272,2.522695,,,,,,,,,,,,,, -146796,,,0.7828710675239563,0.94913649559021,0.7192999720573425,1.2191599607467651,50000.0,0.5915000438690186,1.837318658828736,10000.0,66843.89471817017,72726.38371825218,66843.89471817017,5866.13242316246,8.054444074630737,0.0 -146800,3.2215917,2.5759132,,,,,,,,,,,,,, -146900,3.1960018,2.5126536,,,,,,,,,,,,,, -147000,3.923788,3.4774106,,,,,,,,,,,,,, -147100,3.6987417,4.629935,,,,,,,,,,,,,, -147200,3.2257793,3.4090836,,,,,,,,,,,,,, -147300,3.2240753,2.6343431,,,,,,,,,,,,,, -147400,3.04847,2.6797323,,,,,,,,,,,,,, -147500,3.5052242,2.5633025,,,,,,,,,,,,,, -147600,3.4986024,2.6391459,,,,,,,,,,,,,, -147700,3.7262757,2.6034007,,,,,,,,,,,,,, -147716,,,0.7864453196525574,0.94319087266922,0.7235599756240845,1.2195154428482056,50000.0,0.5963000059127808,1.8374955654144287,10000.0,67263.90523648262,73187.53686141968,67263.90523648262,5907.169671535492,8.107472896575928,0.0 -147800,3.4168816,2.502191,,,,,,,,,,,,,, -147900,3.2679625,2.7066565,,,,,,,,,,,,,, -148000,4.08653,4.5139074,,,,,,,,,,,,,, -148100,3.3489447,2.510507,,,,,,,,,,,,,, -148200,3.3615727,2.4165282,,,,,,,,,,,,,, -148300,3.9406047,3.8300333,,,,,,,,,,,,,, -148400,3.2422466,2.7178545,,,,,,,,,,,,,, -148500,3.525959,4.5615435,,,,,,,,,,,,,, -148600,3.795743,4.5897136,,,,,,,,,,,,,, -148639,,,0.7929491996765137,0.936110258102417,0.7234199643135071,1.2337615489959717,50000.0,0.6018000245094299,1.8382083177566528,10000.0,67684.12996077538,73649.2893576622,67684.12996077538,5948.59770822525,8.155084133148193,0.0 -148700,3.3377423,3.1910098,,,,,,,,,,,,,, -148800,3.6938589,2.6253016,,,,,,,,,,,,,, -148900,3.4567811,2.9949455,,,,,,,,,,,,,, -149000,3.5096397,2.6848645,,,,,,,,,,,,,, -149100,3.1196775,3.5566695,,,,,,,,,,,,,, -149200,3.7455301,2.525323,,,,,,,,,,,,,, -149300,3.4963467,2.479682,,,,,,,,,,,,,, -149400,3.4701061,3.9106572,,,,,,,,,,,,,, -149500,3.3715563,2.466226,,,,,,,,,,,,,, -149560,,,0.8044726252555847,0.8590974807739258,0.725600004196167,1.193511128425598,50000.0,0.601900041103363,1.7990987300872805,10000.0,68104.28977417946,74111.3414068222,68104.28977417946,5990.390378952026,8.203400373458862,0.0 -149600,3.576118,3.4809437,,,,,,,,,,,,,, -149700,3.173465,3.6741743,,,,,,,,,,,,,, -149800,3.3073018,3.8886676,,,,,,,,,,,,,, -149900,3.6877925,4.3970165,,,,,,,,,,,,,, -150000,3.647722,3.0575867,,,,,,,,,,,,,, -150100,3.510007,2.5390224,,,,,,,,,,,,,, -150200,3.5087967,2.571619,,,,,,,,,,,,,, -150300,4.086506,4.3042397,,,,,,,,,,,,,, -150400,3.8449926,2.5216324,,,,,,,,,,,,,, -150481,,,0.7959179282188416,0.9110642075538636,0.7285799980163574,1.19636070728302,50000.0,0.6038000583648682,1.7990301847457886,10000.0,68524.2951323986,74569.94378137589,68524.2951323986,6028.878688812256,8.260981321334839,0.0 -150500,3.7817006,2.7386603,,,,,,,,,,,,,, -150600,3.8830261,2.527893,,,,,,,,,,,,,, -150700,4.3076954,2.6297026,,,,,,,,,,,,,, -150800,3.6450212,4.027121,,,,,,,,,,,,,, -150900,3.7346287,2.5864115,,,,,,,,,,,,,, -151000,3.5716429,3.6606526,,,,,,,,,,,,,, -151100,3.5580616,2.48429,,,,,,,,,,,,,, -151200,3.675211,2.5629573,,,,,,,,,,,,,, -151300,4.1948442,4.485324,,,,,,,,,,,,,, -151400,3.4977207,2.4681463,,,,,,,,,,,,,, -151402,,,0.7992382645606995,0.8694639205932617,0.7313199639320374,1.1657003164291382,50000.0,0.6051000356674194,1.761571168899536,10000.0,68944.20941090584,75029.67028093338,68944.20941090584,6068.592380285263,8.307986736297607,0.0 -151500,3.5434227,2.5913534,,,,,,,,,,,,,, -151600,3.4061725,2.4353907,,,,,,,,,,,,,, -151700,4.165634,4.481307,,,,,,,,,,,,,, -151800,3.4318619,3.6668696,,,,,,,,,,,,,, -151900,4.3615384,4.4092607,,,,,,,,,,,,,, -152000,3.7067301,3.7601888,,,,,,,,,,,,,, -152100,3.5621128,3.319479,,,,,,,,,,,,,, -152200,3.8575525,2.4587126,,,,,,,,,,,,,, -152300,3.4191372,3.4080765,,,,,,,,,,,,,, -152323,,,0.8080468773841858,0.8454297184944153,0.7318399548530579,1.1674693822860718,50000.0,0.6097000241279602,1.7557659149169922,10000.0,69364.16141724586,75490.700922966,69364.16141724586,6109.569490194321,8.357463836669922,0.0 -152400,3.651123,2.461794,,,,,,,,,,,,,, -152500,3.6844893,2.5967302,,,,,,,,,,,,,, -152600,3.6874385,2.5249991,,,,,,,,,,,,,, -152700,3.8284302,3.9863052,,,,,,,,,,,,,, -152800,3.7282426,3.166991,,,,,,,,,,,,,, -152900,3.4867635,3.0323055,,,,,,,,,,,,,, -153000,4.0978785,2.5153334,,,,,,,,,,,,,, -153100,4.350247,2.5232763,,,,,,,,,,,,,, -153200,3.8242512,2.5797982,,,,,,,,,,,,,, -153246,,,0.802539050579071,0.863842248916626,0.7321000099182129,1.1585302352905271,50000.0,0.614300012588501,1.7566510438919067,10000.0,69784.3099322319,75950.06301856041,69784.3099322319,6148.683254241943,8.404613018035889,0.0 -153300,4.319853,4.164669,,,,,,,,,,,,,, -153400,3.556941,3.2073405,,,,,,,,,,,,,, -153500,3.6231692,3.273596,,,,,,,,,,,,,, -153600,3.8865519,2.8093245,,,,,,,,,,,,,, -153700,4.0937624,2.6289515,,,,,,,,,,,,,, -153800,4.2074466,3.9602156,,,,,,,,,,,,,, -153900,3.7565992,3.300007,,,,,,,,,,,,,, -154000,3.544088,3.0267808,,,,,,,,,,,,,, -154100,3.7781627,2.9884002,,,,,,,,,,,,,, -154167,,,0.80726557970047,0.8443677425384521,0.7341799736022949,1.159018874168396,50000.0,0.614300012588501,1.755558729171753,10000.0,70204.3026239872,76408.33057045937,70204.3026239872,6186.854301929474,8.456630945205688,0.0 -154200,3.9670413,2.437726,,,,,,,,,,,,,, -154300,4.1806583,4.517824,,,,,,,,,,,,,, -154400,4.0311184,2.430669,,,,,,,,,,,,,, -154500,3.8914773,2.4157536,,,,,,,,,,,,,, -154600,4.2639265,2.5490034,,,,,,,,,,,,,, -154700,4.101656,2.4557981,,,,,,,,,,,,,, -154800,3.8869383,2.523027,,,,,,,,,,,,,, -154900,4.7661476,4.5186796,,,,,,,,,,,,,, -155000,3.8511868,2.615621,,,,,,,,,,,,,, -155089,,,0.8079296946525574,0.8516631722450256,0.7345199584960938,1.1670846939086914,50000.0,0.6105000376701355,1.779624342918396,10000.0,70624.36257171631,76866.29339528084,70624.36257171631,6224.655606746674,8.505312204360962,0.0 -155100,4.206555,2.3642495,,,,,,,,,,,,,, -155200,3.9144688,3.6795983,,,,,,,,,,,,,, -155300,3.8716474,2.5277243,,,,,,,,,,,,,, -155400,4.1329827,2.4674807,,,,,,,,,,,,,, -155500,3.9300225,2.4042163,,,,,,,,,,,,,, -155600,4.247539,2.43476,,,,,,,,,,,,,, -155700,3.980225,2.4443107,,,,,,,,,,,,,, -155800,3.9834204,2.355337,,,,,,,,,,,,,, -155900,3.829629,3.1800587,,,,,,,,,,,,,, -156000,3.8702073,3.898167,,,,,,,,,,,,,, -156010,,,0.8057616949081421,0.8517761826515198,0.7357400059700012,1.147667646408081,50000.0,0.6135000586509705,1.7380512952804563,10000.0,71044.26182794571,77326.6550860405,71044.26182794571,6264.754841089249,8.815968751907349,0.0 -156100,3.8783433,2.6918304,,,,,,,,,,,,,, -156200,4.0676384,3.0522695,,,,,,,,,,,,,, -156300,3.838291,2.3184884,,,,,,,,,,,,,, -156400,3.9445918,2.4419858,,,,,,,,,,,,,, -156500,3.7146842,2.5194864,,,,,,,,,,,,,, -156600,4.2384396,2.3939033,,,,,,,,,,,,,, -156700,4.8028274,4.4981117,,,,,,,,,,,,,, -156800,4.2042694,3.5705173,,,,,,,,,,,,,, -156900,3.6863678,2.337194,,,,,,,,,,,,,, -156933,,,0.8101562261581421,0.8456878662109375,0.7382000088691711,1.148350715637207,50000.0,0.6139000058174133,1.7471004724502563,10000.0,71464.36837172508,77786.44018650055,71464.36837172508,6304.327245473862,8.869405031204224,0.0 -157000,4.094797,2.409365,,,,,,,,,,,,,, -157100,3.8405402,3.097225,,,,,,,,,,,,,, -157200,3.949942,3.3404877,,,,,,,,,,,,,, -157300,4.1284294,2.4446013,,,,,,,,,,,,,, -157400,4.0297027,2.3263624,,,,,,,,,,,,,, -157500,3.8629942,2.3469622,,,,,,,,,,,,,, -157600,4.4579124,2.812073,,,,,,,,,,,,,, -157700,4.148536,2.4019485,,,,,,,,,,,,,, -157800,4.3104873,2.4437895,,,,,,,,,,,,,, -157854,,,0.8182226419448853,0.813230574131012,0.7407799959182739,1.141973853111267,50000.0,0.6210000514984131,1.7272372245788574,10000.0,71884.33867025375,78247.37133145332,71884.33867025375,6345.1846034526825,8.920923233032227,0.0 -157900,4.5731287,4.4329348,,,,,,,,,,,,,, -158000,4.5939593,3.093471,,,,,,,,,,,,,, -158100,3.908171,2.5719893,,,,,,,,,,,,,, -158200,4.3101754,3.493958,,,,,,,,,,,,,, -158300,3.9975176,3.4468424,,,,,,,,,,,,,, -158400,3.972369,2.2706926,,,,,,,,,,,,,, -158500,5.1090655,4.431995,,,,,,,,,,,,,, -158600,4.2770557,2.4514897,,,,,,,,,,,,,, -158700,4.438183,2.3951623,,,,,,,,,,,,,, -158776,,,0.8186913728713989,0.811683177947998,0.7415399551391602,1.1460314989089966,50000.0,0.6224000453948975,1.7298548221588137,10000.0,72304.50023531914,78708.44903898239,72304.50023531914,6385.997955322266,8.971666812896729,0.0 -158800,5.2376585,2.3782852,,,,,,,,,,,,,, -158900,3.9993675,3.174527,,,,,,,,,,,,,, -159000,4.2827797,2.4331014,,,,,,,,,,,,,, -159100,4.149342,2.3890595,,,,,,,,,,,,,, -159200,4.0008035,2.9142718,,,,,,,,,,,,,, -159300,4.3514676,2.939985,,,,,,,,,,,,,, -159400,4.186159,2.4195967,,,,,,,,,,,,,, -159500,4.2053413,2.2857487,,,,,,,,,,,,,, -159600,4.296654,2.4039006,,,,,,,,,,,,,, -159700,,,0.8161132335662842,0.8156614303588867,0.7422999739646912,1.1298848390579224,50000.0,0.6160000562667847,1.725846529006958,10000.0,72724.4717707634,79170.03859901428,72724.4717707634,6427.509117841721,9.025804042816162,0.0 -159700,4.175771,2.5521855,,,,,,,,,,,,,, -159800,5.4273553,4.4346547,,,,,,,,,,,,,, -159900,4.001304,2.3718522,,,,,,,,,,,,,, -160000,4.379473,3.742766,,,,,,,,,,,,,, -160100,4.7786703,4.262555,,,,,,,,,,,,,, -160200,4.45465,2.4368043,,,,,,,,,,,,,, -160300,4.483553,2.4427428,,,,,,,,,,,,,, -160400,3.8049035,2.8227239,,,,,,,,,,,,,, -160500,3.9541385,2.2336316,,,,,,,,,,,,,, -160600,4.119507,2.3987226,,,,,,,,,,,,,, -160621,,,0.8175585865974426,0.804296612739563,0.7449600100517273,1.1215741634368896,50000.0,0.624500036239624,1.7089358568191528,10000.0,73144.85194015503,79631.72527456284,73144.85194015503,6468.714903116226,9.074469327926636,0.0 -160700,4.2046304,2.337423,,,,,,,,,,,,,, -160800,4.718176,2.83411,,,,,,,,,,,,,, -160900,4.3222265,2.8665552,,,,,,,,,,,,,, -161000,4.3014665,2.6220374,,,,,,,,,,,,,, -161100,6.7523146,2.5141912,,,,,,,,,,,,,, -161200,4.336521,2.382824,,,,,,,,,,,,,, -161300,4.520977,2.363259,,,,,,,,,,,,,, -161400,5.349088,4.2518144,,,,,,,,,,,,,, -161500,4.1949654,3.0805855,,,,,,,,,,,,,, -161542,,,0.8251757621765137,0.7653958201408386,0.7464799880981445,1.1053043603897097,50000.0,0.6322000026702881,1.680891036987305,10000.0,73564.85105657578,80092.20484352112,73564.85105657578,6509.093505382538,9.12383794784546,0.0 -161600,4.2762694,2.9710705,,,,,,,,,,,,,, -161700,4.3875003,2.2948914,,,,,,,,,,,,,, -161800,5.4687886,4.3315454,,,,,,,,,,,,,, -161900,4.536882,3.2450414,,,,,,,,,,,,,, -162000,4.445142,2.6214368,,,,,,,,,,,,,, -162100,4.702439,2.3691473,,,,,,,,,,,,,, -162200,4.3783116,3.0739143,,,,,,,,,,,,,, -162300,4.4372306,2.4397163,,,,,,,,,,,,,, -162400,4.164853,3.2810078,,,,,,,,,,,,,, -162464,,,0.8218163847923279,0.7823770642280579,0.7467399835586548,1.1013247966766355,50000.0,0.6265000104904175,1.694629788398743,10000.0,73984.82580113411,80553.0836699009,73984.82580113411,6549.888410568237,9.1816246509552,0.0 -162500,4.424398,2.4167073,,,,,,,,,,,,,, -162600,4.4202914,2.1824331,,,,,,,,,,,,,, -162700,5.0285115,3.7585683,,,,,,,,,,,,,, -162800,4.326557,2.7206614,,,,,,,,,,,,,, -162900,4.5644016,2.3405566,,,,,,,,,,,,,, -163000,4.5993896,2.6012073,,,,,,,,,,,,,, -163100,4.165218,2.3281775,,,,,,,,,,,,,, -163200,4.8362436,2.3654263,,,,,,,,,,,,,, -163300,4.6672907,2.352652,,,,,,,,,,,,,, -163386,,,0.8213866949081421,0.7875846028327942,0.7483199834823608,1.1053085327148438,50000.0,0.6223000288009644,1.7001694440841677,10000.0,74404.8524620533,81011.978900671,74404.8524620533,6588.652596235275,9.234551668167114,0.0 -163400,4.6632414,2.343377,,,,,,,,,,,,,, -163500,4.8306365,2.4410825,,,,,,,,,,,,,, -163600,4.4699435,2.5921578,,,,,,,,,,,,,, -163700,4.651499,2.425528,,,,,,,,,,,,,, -163800,4.92097,3.6976995,,,,,,,,,,,,,, -163900,5.1622586,2.4133322,,,,,,,,,,,,,, -164000,4.625358,2.345813,,,,,,,,,,,,,, -164100,4.631308,2.3778129,,,,,,,,,,,,,, -164200,4.80759,2.3164184,,,,,,,,,,,,,, -164300,4.6316104,2.3052173,,,,,,,,,,,,,, -164307,,,0.8302733898162842,0.7504153251647949,0.7503199577331543,1.0849838256835938,50000.0,0.6322000026702881,1.6703184843063354,10000.0,74824.91179394722,81473.03928542137,74824.91179394722,6629.552686929703,9.28411316871643,0.0 -164400,4.593603,2.3524323,,,,,,,,,,,,,, -164500,4.9125657,2.3356426,,,,,,,,,,,,,, -164600,4.648672,2.2755291,,,,,,,,,,,,,, -164700,4.737032,2.4984927,,,,,,,,,,,,,, -164800,4.546199,2.3349445,,,,,,,,,,,,,, -164900,4.667916,3.345018,,,,,,,,,,,,,, -165000,4.8203263,2.3230674,,,,,,,,,,,,,, -165100,4.7780566,2.2860901,,,,,,,,,,,,,, -165200,4.7289863,3.1816015,,,,,,,,,,,,,, -165227,,,0.8239843845367432,0.7693715691566467,0.7511799931526184,1.0776283740997314,50000.0,0.6283000111579895,1.6670862436294556,10000.0,75245.17643213272,81932.90296435356,75245.17643213272,6669.045813798904,9.33809781074524,0.0 -165300,5.450347,4.109107,,,,,,,,,,,,,, -165400,5.025798,2.3077145,,,,,,,,,,,,,, -165500,4.7520566,2.3366039,,,,,,,,,,,,,, -165600,4.9405255,2.3829882,,,,,,,,,,,,,, -165700,4.846949,2.479698,,,,,,,,,,,,,, -165800,5.024971,2.259168,,,,,,,,,,,,,, -165900,5.2887053,2.36838,,,,,,,,,,,,,, -166000,5.085882,2.6095364,,,,,,,,,,,,,, -166100,4.6767144,2.7299058,,,,,,,,,,,,,, -166150,,,0.8310937285423279,0.7499398589134216,0.7531399726867676,1.070204496383667,50000.0,0.6307000517845154,1.6644469499588013,10000.0,75665.09493494034,82392.86175775528,75665.09493494034,6708.982246637344,9.390437126159668,0.0 -166200,4.9627833,2.2377727,,,,,,,,,,,,,, -166300,5.013967,2.3066332,,,,,,,,,,,,,, -166400,4.5811677,2.268452,,,,,,,,,,,,,, -166500,4.762453,3.8135848,,,,,,,,,,,,,, -166600,4.738137,2.2741787,,,,,,,,,,,,,, -166700,4.6576858,2.883676,,,,,,,,,,,,,, -166800,4.6366315,2.3178012,,,,,,,,,,,,,, -166900,4.58852,2.2801347,,,,,,,,,,,,,, -167000,4.973586,3.2312574,,,,,,,,,,,,,, -167071,,,0.83056640625,0.7435762882232666,0.7534199953079224,1.0754247903823853,50000.0,0.6333000063896179,1.6693150997161863,10000.0,76085.28699398041,82853.24296784401,76085.28699398041,6749.0644516944885,9.445293426513672,0.0 -167100,5.03506,3.7692657,,,,,,,,,,,,,, -167200,4.850314,2.6804159,,,,,,,,,,,,,, -167300,5.164863,3.8261943,,,,,,,,,,,,,, -167400,5.2171063,2.1963296,,,,,,,,,,,,,, -167500,5.0405364,2.2156627,,,,,,,,,,,,,, -167600,5.7395554,4.206012,,,,,,,,,,,,,, -167700,4.8983707,2.3979974,,,,,,,,,,,,,, -167800,5.1929092,3.3113894,,,,,,,,,,,,,, -167900,5.191899,2.270616,,,,,,,,,,,,,, -167992,,,0.8321484327316284,0.7515702843666077,0.7541399598121643,1.0826858282089231,50000.0,0.634600043296814,1.6695160865783691,10000.0,76505.26083564758,83311.12951374054,76505.26083564758,6786.874273777008,9.495783567428589,0.0 -168000,5.288217,2.331593,,,,,,,,,,,,,, -168100,5.5301256,2.3319674,,,,,,,,,,,,,, -168200,4.8300242,2.4107993,,,,,,,,,,,,,, -168300,5.487422,2.3055086,,,,,,,,,,,,,, -168400,4.730133,2.5635273,,,,,,,,,,,,,, -168500,4.9577794,2.2446554,,,,,,,,,,,,,, -168600,5.1086307,3.0857038,,,,,,,,,,,,,, -168700,5.0765085,2.2513473,,,,,,,,,,,,,, -168800,5.466431,2.452456,,,,,,,,,,,,,, -168900,4.801932,2.2304325,,,,,,,,,,,,,, -168913,,,0.8340038657188416,0.7366316318511963,0.7570599913597107,1.0599473714828491,50000.0,0.638700008392334,1.6415356397628784,10000.0,76925.41933321953,83772.23526740074,76925.41933321953,6827.716981649399,9.548727989196776,0.0 -169000,5.6015773,2.2924774,,,,,,,,,,,,,, -169100,4.9937806,3.2917225,,,,,,,,,,,,,, -169200,5.4334264,2.2656918,,,,,,,,,,,,,, -169300,5.96158,2.5360384,,,,,,,,,,,,,, -169400,5.3790627,2.275176,,,,,,,,,,,,,, -169500,5.7040935,3.4755073,,,,,,,,,,,,,, -169600,5.773277,3.5418968,,,,,,,,,,,,,, -169700,4.949899,2.2747068,,,,,,,,,,,,,, -169800,5.203391,2.2491698,,,,,,,,,,,,,, -169834,,,0.8338086009025574,0.7275217771530151,0.7583400011062622,1.0552046298980713,50000.0,0.6380000114440918,1.6476746797561646,10000.0,77345.37213397026,84231.7641518116,77345.37213397026,6867.189259767532,9.600689172744753,0.0 -169900,5.5321746,3.4971378,,,,,,,,,,,,,, -170000,5.2074113,2.4494426,,,,,,,,,,,,,, -170100,5.751034,2.210133,,,,,,,,,,,,,, -170200,5.543153,2.4672225,,,,,,,,,,,,,, -170224,,,,,,,,,,,77520.19496154785,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 0a9816016..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -25.40871500968933,0.0,33.48979616165161,1,0,33.48979616165161,0.0010000000474974,6.907756805419922,10000,58.89862966537476,0.0008203124743886,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -64.14088726043701,0.0175280570983886,453.5089168548584,854,0,453.5089168548584,0.0101000005379319,6.498650550842285,10000,517.7165124416351,0.0140429688617587,6.451799392700195,0.0134399998933076,6.465761661529541,50000 -103.90360403060912,0.0479042530059814,873.7088196277618,1768,0,873.7088196277618,0.0351999998092651,5.988783836364746,10000,977.7617542743684,0.0452539063990116,5.853660583496094,0.044659998267889,5.874964714050293,50000 -145.50300693511963,0.0734522342681884,1293.9095661640167,2686,0,1293.9095661640167,0.0538000017404556,5.632424831390381,10000,1439.6401374340055,0.0721289068460464,5.425541400909424,0.0664799958467483,5.466614246368408,50000 -183.9953689575196,0.1029400825500488,1713.922945976257,3601,0,1713.922945976257,0.0749000012874603,5.360351085662842,10000,1898.2273745536804,0.1110156252980232,5.073092937469482,0.0995799973607063,5.132517814636231,50000 -223.9190099239349,0.1261570453643798,2133.9952919483185,4520,0,2133.9952919483185,0.1058000028133392,5.029491424560547,10000,2358.2990391254425,0.1525195240974426,4.659289360046387,0.1367399990558624,4.747751712799072,50000 -264.90250039100647,0.1519322395324707,2554.283395767212,5437,0,2554.283395767212,0.1422000080347061,4.704031467437744,10000,2819.649272918701,0.2096288949251175,4.283182621002197,0.1894799917936325,4.372105121612549,50000 -306.4709143638611,0.176743745803833,2974.62889289856,6357,0,2974.62889289856,0.1807000041007995,4.413839817047119,10000,3281.6409327983856,0.2499413937330246,3.9393117427825928,0.2325399965047836,4.036503314971924,50000 -347.54998540878296,0.2015523910522461,3394.736057758332,7275,0,3394.736057758332,0.2017000168561935,4.236976146697998,10000,3742.9045355319977,0.2956054508686065,3.652733087539673,0.2621199786663055,3.811382293701172,50000 -387.60815477371216,0.231529951095581,3814.776090860367,8192,0,3814.776090860367,0.2297000139951706,4.045166015625,10000,4203.085846185684,0.3202148377895355,3.4828038215637207,0.2987799942493438,3.5942399501800537,50000 -426.7889401912689,0.2578392028808594,4234.937925100327,9112,0,4234.937925100327,0.2621000111103058,3.810883045196533,10000,4662.507352590561,0.3625390529632568,3.197558879852295,0.3323200047016144,3.329472780227661,50000 -467.26917695999146,0.2834861278533935,4655.196612596512,10031,0,4655.196612596512,0.2775000035762787,3.6873080730438232,10000,5123.324376821518,0.3985546827316284,2.9937057495117188,0.3631999790668487,3.172254800796509,50000 -505.16552233695984,0.3097312450408935,5075.5094628334045,10952,0,5075.5094628334045,0.3062000274658203,3.53668212890625,10000,5581.614113092423,0.4247460961341858,2.851715087890625,0.3962799906730652,2.9970977306365967,50000 -543.6498990058899,0.3350875377655029,5495.436726093292,11872,0,5495.436726093292,0.3208000063896179,3.4061996936798096,10000,6040.103754281998,0.4528906047344208,2.6884093284606934,0.4200599789619446,2.8471667766571045,50000 -584.9318788051605,0.3650226593017578,5915.742879390717,12791,0,5915.742879390717,0.3406000137329101,3.2832746505737305,10000,6501.774571418762,0.4826757609844208,2.515464305877685,0.4415600001811981,2.720947742462158,50000 -621.6272025108337,0.394378662109375,6335.661201000214,13709,0,6335.661201000214,0.3562000095844269,3.150334119796753,10000,6958.471544742584,0.5004491806030273,2.3911969661712646,0.4621599912643432,2.5576767921447754,50000 -660.215106010437,0.4230444431304931,6755.971422195435,14630,0,6755.971422195435,0.3742000162601471,3.108792543411255,10000,7417.451339006424,0.5197460651397705,2.346434593200684,0.4825599789619446,2.5191760063171387,50000 -698.8068988323212,0.4497511386871338,7176.325411558151,15550,0,7176.325411558151,0.3827000260353088,3.0271334648132324,10000,7876.476963996887,0.53466796875,2.234204769134521,0.4933599829673767,2.4432485103607178,50000 -738.5819492340088,0.4763305187225342,7596.317331790924,16467,0,7596.317331790924,0.4006000161170959,2.977833747863769,10000,8336.32350063324,0.5671288967132568,2.112599372863769,0.505299985408783,2.3933229446411133,50000 -778.5356502532959,0.5067436695098877,8016.236355781555,17387,0,8016.236355781555,0.4088000059127807,2.8916168212890625,10000,8796.280049562454,0.5574023127555847,2.110726833343506,0.5231800079345703,2.286231279373169,50000 -819.9306211471558,0.5377237796783447,8436.425481557846,18306,0,8436.425481557846,0.4140000343322754,2.887305736541748,10000,9257.947703361511,0.5698437094688416,2.086137056350708,0.5271199941635132,2.275842905044556,50000 -859.3592464923859,0.5637521743774414,8856.704848766327,19218,0,8856.704848766327,0.417900025844574,2.88185453414917,10000,9717.73423075676,0.5857617259025574,2.032967090606689,0.5317599773406982,2.2811155319213867,50000 -900.5303463935852,0.5916616916656494,9277.106140851974,20135,0,9277.106140851974,0.4216000139713287,2.8544681072235107,10000,10179.38702774048,0.5826562643051147,2.0578503608703613,0.5440599918365479,2.237067222595215,50000 -936.777051448822,0.619286298751831,9697.392251729963,21053,0,9697.392251729963,0.4394000172615051,2.793226718902588,10000,10636.000517368317,0.5937890410423279,1.9904780387878416,0.5496399998664856,2.194326162338257,50000 -975.779391527176,0.6520423889160156,10117.686156272888,21972,0,10117.686156272888,0.4426000118255615,2.743307590484619,10000,11095.38220667839,0.613964855670929,1.871402144432068,0.5616599917411804,2.118140935897827,50000 -1016.273241519928,0.684720516204834,10537.782514810562,22891,0,10537.782514810562,0.4502000212669372,2.677503824234009,10000,11556.057990550997,0.6078515648841858,1.876462459564209,0.5680199861526489,2.0680747032165527,50000 -1057.8305151462555,0.7141125202178955,10958.110257148744,23807,0,10958.110257148744,0.4628000259399414,2.6657001972198486,10000,12018.024418115616,0.6229491829872131,1.856359481811524,0.5753600001335144,2.064666986465454,50000 -1099.2352216243744,0.7455592155456543,11378.051944494247,24724,0,11378.051944494247,0.4601000249385834,2.632568836212158,10000,12479.46283340454,0.6285351514816284,1.7818245887756348,0.5766000151634216,2.016919612884521,50000 -1140.624148607254,0.7733578681945801,11798.264919519424,25642,0,11798.264919519424,0.4608000218868255,2.615546226501465,10000,12941.145199537275,0.6348242163658142,1.7657935619354248,0.5819199681282043,1.9933316707611084,50000 -1175.7698872089386,0.8034751415252686,12218.406804323196,26562,0,12218.406804323196,0.4737000167369842,2.547514677047729,10000,13396.518834590912,0.6374804377555847,1.7287397384643557,0.5893999934196472,1.9390709400177,50000 -1211.7467403411863,0.8337299823760986,12638.708625555038,27477,0,12638.708625555038,0.4715000092983246,2.6012699604034424,10000,13852.880910873411,0.6408398151397705,1.7599726915359497,0.5889399647712708,1.9790270328521729,50000 -1250.0245730876925,0.862741231918335,13059.00931596756,28395,0,13059.00931596756,0.4756000339984894,2.566092729568481,10000,14311.540762901306,0.6732617020606995,1.6018322706222534,0.5973399877548218,1.9325991868972776,50000 -1290.234845161438,0.8960211277008057,13479.215354442596,29312,0,13479.215354442596,0.4768000245094299,2.5661449432373047,10000,14772.042618513107,0.6476367115974426,1.7281827926635742,0.5992599725723267,1.946011662483216,50000 -1331.7159614562988,0.9272348880767822,13899.440378904344,30230,0,13899.440378904344,0.4821000099182129,2.5124006271362305,10000,15233.832745790482,0.6563280820846558,1.6538163423538208,0.6067399978637695,1.877922296524048,50000 -1366.7929692268372,0.9575145244598388,14319.611972808838,31148,0,14319.611972808838,0.4859000146389007,2.506326675415039,10000,15689.16557955742,0.6665819883346558,1.621629238128662,0.6076200008392334,1.8831592798233032,50000 -1406.235392332077,0.989626169204712,14739.628532409668,32064,0,14739.628532409668,0.4937000274658203,2.442761182785034,10000,16148.70953464508,0.6660351157188416,1.5995913743972778,0.616919994354248,1.8113443851470947,50000 -1448.665411233902,1.024622678756714,15159.826711416245,32979,0,15159.826711416245,0.489300012588501,2.53031587600708,10000,16611.425479650497,0.66064453125,1.6689999103546145,0.6136199831962585,1.892295241355896,50000 -1490.9241652488708,1.0565321445465088,15579.860214233398,33898,0,15579.860214233398,0.4989000260829925,2.4557673931121826,10000,17073.8019824028,0.6800194978713989,1.5640877485275269,0.6166399717330933,1.8367356061935425,50000 -1530.2285268306732,1.0866823196411133,16000.001391410828,34815,0,16000.001391410828,0.498600035905838,2.475801467895508,10000,17533.330276489258,0.6681445240974426,1.6435627937316897,0.6217399835586548,1.854430675506592,50000 -1572.68993806839,1.116673469543457,16420.151757001877,35733,0,16420.151757001877,0.5069000124931335,2.359993934631348,10000,17996.025110006332,0.6791015267372131,1.499345064163208,0.6260600090026855,1.74126398563385,50000 -1613.168706893921,1.1520779132843018,16840.42435503006,36651,0,16840.42435503006,0.5042999982833862,2.42209792137146,10000,18456.86468553543,0.6833788752555847,1.5450021028518677,0.6283800005912781,1.795507788658142,50000 -1655.7754156589508,1.1850988864898682,17260.590923786163,37570,0,17260.590923786163,0.5034000277519226,2.409806251525879,10000,18919.72491168976,0.6815234422683716,1.5546667575836182,0.6279199719429016,1.7888097763061523,50000 -1697.8138728141785,1.2178847789764404,17680.91156888008,38491,0,17680.91156888008,0.5108000040054321,2.3639705181121826,10000,19382.169947624207,0.6863867044448853,1.5237195491790771,0.634660005569458,1.7491233348846436,50000 -1737.9596025943756,1.2555735111236572,18101.19814991951,39410,0,18101.19814991951,0.5131000280380249,2.3418853282928467,10000,19842.69286584854,0.6941015720367432,1.4616892337799072,0.6366599798202515,1.711472749710083,50000 -1780.9176177978516,1.2859671115875244,18521.3622546196,40329,0,18521.3622546196,0.5110000371932983,2.36631178855896,10000,20305.8983001709,0.7090820074081421,1.4305404424667358,0.6382399797439575,1.7424010038375854,50000 -1815.185781955719,1.3155813217163086,18941.50358247757,41246,0,18941.50358247757,0.511400043964386,2.348185062408448,10000,20760.38983654976,0.6861523389816284,1.5130900144577026,0.6380400061607361,1.7185349464416504,50000 -1857.3007380962367,1.3462746143341064,19361.500798225403,42165,0,19361.500798225403,0.5162000060081482,2.3191921710968018,10000,21222.585634231567,0.6976562142372131,1.452872633934021,0.6402599811553955,1.701892971992493,50000 -1895.374864578247,1.380638599395752,19781.848207473755,43084,0,19781.848207473755,0.5215000510215759,2.2900736331939697,10000,21681.09438085556,0.7085351347923279,1.3984525203704834,0.6442199945449829,1.6737351417541504,50000 -1935.0757467746728,1.41862154006958,20201.83907580376,44000,0,20201.83907580376,0.5199000239372253,2.348926305770874,10000,22140.8764231205,0.6931054592132568,1.5006791353225708,0.6433599591255188,1.717563271522522,50000 -1977.193481206894,1.4574127197265625,20621.8154001236,44915,0,20621.8154001236,0.5258000493049622,2.2985410690307617,10000,22603.06175518036,0.7015234231948853,1.447856068611145,0.6456999778747559,1.6839470863342283,50000 -2016.4534318447115,1.4908981323242188,21042.14222931862,45835,0,21042.14222931862,0.5192000269889832,2.3293352127075195,10000,23062.734843730927,0.7068163752555847,1.426822304725647,0.6466799974441528,1.6960935592651367,50000 -2053.571048259735,1.5256400108337402,21462.42486310005,46755,0,21462.42486310005,0.5317000150680542,2.2454986572265625,10000,23520.222547769547,0.706250011920929,1.3982168436050415,0.6541199684143066,1.629255294799805,50000 -2095.10959148407,1.5598258972167969,21882.692858219147,47674,0,21882.692858219147,0.5311000347137451,2.2582650184631348,10000,23982.11646294593,0.7119921445846558,1.3972750902175903,0.6566799879074097,1.6393754482269287,50000 -2136.696215391159,1.5978095531463623,22303.524026870728,48595,0,22303.524026870728,0.5323000550270081,2.2515869140625,10000,24444.62548708916,0.7161718606948853,1.3505080938339231,0.6545000076293945,1.6139466762542725,50000 -2173.0738096237183,1.6330816745758057,22723.87028694153,49513,0,22723.87028694153,0.5327000021934509,2.3007359504699707,10000,24901.43813586235,0.7284765243530273,1.366170048713684,0.6598399877548218,1.6686944961547852,50000 -2215.0406877994537,1.67065167427063,23143.849014759064,50430,0,23143.849014759064,0.5389000177383423,2.222001552581787,10000,25363.47724819184,0.7109179496765137,1.366124987602234,0.6569799780845642,1.603185772895813,50000 -2257.400363683701,1.7045204639434814,23563.82791399956,51346,0,23563.82791399956,0.5329000353813171,2.27443790435791,10000,25825.90231370926,0.7173827886581421,1.383202314376831,0.6582199931144714,1.6434626579284668,50000 -2296.1784982681274,1.7483222484588623,23984.125440120697,52265,0,23984.125440120697,0.5343000292778015,2.2770166397094727,10000,26285.075368642807,0.7373827695846558,1.3297600746154783,0.6584999561309814,1.656172513961792,50000 -2338.523509502411,1.7898108959197998,24404.181349277496,53181,0,24404.181349277496,0.5390000343322754,2.2040934562683105,10000,26747.570974826813,0.7142968773841858,1.3528928756713867,0.663599967956543,1.5829827785491943,50000 -2376.301381111145,1.831099271774292,24824.275985479355,54095,0,24824.275985479355,0.5357000231742859,2.2335379123687744,10000,27205.53682422638,0.7222851514816284,1.373044729232788,0.6613799929618835,1.6264369487762451,50000 -2414.699910640717,1.867379903793335,25244.52186369896,55013,0,25244.52186369896,0.5479000210762024,2.186769962310791,10000,27664.27063536644,0.7419335842132568,1.2533196210861206,0.6665599942207336,1.5780009031295776,50000 -2457.052332401276,1.9043676853179927,25664.90718483925,55931,0,25664.90718483925,0.5414000153541565,2.191251754760742,10000,28127.09863138199,0.72083979845047,1.3554744720458984,0.6647799611091614,1.584027886390686,50000 -2499.262162208557,1.9429833889007568,26085.24145579338,56846,0,26085.24145579338,0.5484000444412231,2.171085357666016,10000,28589.73458695412,0.7316796779632568,1.2941749095916748,0.6717199683189392,1.55128014087677,50000 -2538.9583842754364,1.9764173030853271,26505.591643333435,57762,0,26505.591643333435,0.5490000247955322,2.189195394515991,10000,29049.86763215065,0.7390429377555847,1.2824914455413818,0.6715999841690063,1.5749740600585938,50000 -2583.152119636536,2.011991500854492,26925.87080693245,58678,0,26925.87080693245,0.5476000308990479,2.1589043140411377,10000,29514.428884983063,0.7281835675239563,1.3011599779129028,0.6716399788856506,1.546015381813049,50000 -2620.7382864952087,2.05407190322876,27346.01905035973,59597,0,27346.01905035973,0.5433000326156616,2.194332361221313,10000,29972.257979154587,0.7227734327316284,1.3543765544891355,0.6686399579048157,1.5894092321395874,50000 -2663.5940973758698,2.096351623535156,27766.20599770546,60512,0,27766.20599770546,0.5527999997138977,2.1810519695281982,10000,30435.395832777023,0.7395703196525574,1.2908554077148438,0.6735599637031555,1.5688538551330566,50000 -2705.494818210602,2.1315011978149414,28186.2903354168,61429,0,28186.2903354168,0.5517000555992126,2.1733086109161377,10000,30897.46825361252,0.7494531273841858,1.2534152269363403,0.6738799810409546,1.5649051666259766,50000 -2746.435801267624,2.166461229324341,28606.247513771057,62343,0,28606.247513771057,0.5503000020980835,2.170751571655273,10000,31358.454117536545,0.7285351157188416,1.318561315536499,0.6755599975585938,1.5545648336410522,50000 -2784.688705921173,2.2054715156555176,29026.210033893585,63259,0,29026.210033893585,0.5529000163078308,2.176795959472656,10000,31816.76087665558,0.7363085746765137,1.2871750593185425,0.6755399703979492,1.5438107252120972,50000 -2824.250700235367,2.2430827617645264,29446.1415143013,64175,0,29446.1415143013,0.5590000152587891,2.1525299549102783,10000,32276.34983420372,0.752734363079071,1.2191165685653689,0.6774399876594543,1.541650414466858,50000 -2861.604071855545,2.2804622650146484,29866.106069087986,65091,0,29866.106069087986,0.5505000352859497,2.155724048614502,10000,32733.75764989853,0.7371875047683716,1.289478063583374,0.6808599829673767,1.5334938764572144,50000 -2902.1909971237183,2.317537307739258,30286.265585184097,66008,0,30286.265585184097,0.5539000034332275,2.1502127647399902,10000,33194.593849658966,0.7390820384025574,1.268944501876831,0.6768999695777893,1.5356301069259644,50000 -2943.6744248867035,2.357588052749634,30706.27566933632,66927,0,30706.27566933632,0.5534999966621399,2.161947011947632,10000,33656.18039536476,0.7473242282867432,1.2253973484039309,0.6799399852752686,1.5155017375946045,50000 -2985.6455442905426,2.400687456130981,31126.626373529434,67844,0,31126.626373529434,0.5626000165939331,2.114659786224365,10000,34118.59836268425,0.738964855670929,1.255308747291565,0.6803999543190002,1.512239933013916,50000 -3025.6633553504944,2.43937611579895,31546.67461037636,68761,0,31546.67461037636,0.5574000477790833,2.1279420852661133,10000,34578.75641846657,0.7493945360183716,1.2292940616607666,0.6869199872016907,1.5031391382217407,50000 -3067.120540380478,2.4745750427246094,31966.91418480873,69678,0,31966.91418480873,0.5576000213623047,2.2235536575317383,10000,35040.54090499878,0.7514257431030273,1.3177834749221802,0.6843000054359436,1.611196517944336,50000 -3107.352053165436,2.5219199657440186,32387.108137846,70595,0,32387.108137846,0.5586000084877014,2.119065761566162,10000,35501.06599497795,0.7463085651397705,1.2566375732421875,0.6869800090789795,1.5039660930633545,50000 -3149.2554540634155,2.5634772777557373,32807.47200012207,71512,0,32807.47200012207,0.5613000392913818,2.0869626998901367,10000,35963.4273827076,0.7489648461341858,1.1991535425186155,0.6866599917411804,1.4654242992401123,50000 -3189.1428916454315,2.59970760345459,33227.58208632469,72427,0,33227.58208632469,0.5659000277519226,2.090136766433716,10000,36423.51284694672,0.7582226395606995,1.187957525253296,0.6908599734306335,1.4750837087631226,50000 -3233.202719926834,2.637953042984009,33647.756412267685,73345,0,33647.756412267685,0.5666000247001648,2.1033482551574707,10000,36887.83846831322,0.7671679258346558,1.1537946462631226,0.6909599900245667,1.4814988374710083,50000 -3275.276128768921,2.678273916244507,34068.01615142822,74262,0,34068.01615142822,0.5663000345230103,2.115422248840332,10000,37350.26576018333,0.750195324420929,1.2365492582321167,0.6885799765586853,1.5010799169540403,50000 -3319.566258430481,2.720487594604492,34488.080597639084,75178,0,34488.080597639084,0.5640000104904175,2.0848517417907715,10000,37814.71540975571,0.7569140195846558,1.1788876056671145,0.6946199536323547,1.4552501440048218,50000 -3360.3150522708893,2.759145021438598,34908.33102440834,76094,0,34908.33102440834,0.5666000247001648,2.119791030883789,10000,38275.80552268028,0.7713086009025574,1.153877019882202,0.6943599581718445,1.489396095275879,50000 -3405.4010264873505,2.8056118488311768,35328.639944553375,77009,0,35328.639944553375,0.5698000192642212,2.0639443397521973,10000,38741.29879665375,0.7552343606948853,1.1852645874023438,0.6930199861526489,1.4465210437774658,50000 -3448.589109897613,2.8494107723236084,35748.82993221283,77923,0,35748.82993221283,0.5684000253677368,2.111642599105835,10000,39204.77293586731,0.7587695121765137,1.2047119140625,0.6947199702262878,1.4846562147140503,50000 -3489.447700977325,2.887045860290528,36168.960582733154,78839,0,36168.960582733154,0.5698000192642212,2.074385166168213,10000,39665.85251927376,0.7691015601158142,1.1444746255874634,0.6953999996185303,1.456865310668945,50000 -3533.049519300461,2.9246163368225098,36588.97576904297,79757,0,36588.97576904297,0.5781000256538391,2.030494213104248,10000,40129.56078243256,0.7608398199081421,1.156747817993164,0.6993399858474731,1.4243495464324951,50000 -3576.878761768341,2.968376398086548,37009.30336403847,80674,0,37009.30336403847,0.5706000328063965,2.0904242992401123,10000,40593.813328027725,0.7592187523841858,1.2048238515853882,0.6990999579429626,1.4711554050445557,50000 -3621.7705612182617,3.010806083679199,37429.31781625748,81588,0,37429.31781625748,0.574400007724762,2.0534827709198,10000,41058.81526851654,0.7721484303474426,1.1226656436920166,0.7008000016212463,1.4333691596984863,50000 -3665.23730635643,3.0572826862335205,37849.59191274643,82505,0,37849.59191274643,0.5731000304222107,2.033210515975952,10000,41522.65545606613,0.7578710913658142,1.1451164484024048,0.6981199979782104,1.403847336769104,50000 -3709.483957052231,3.096383810043335,38269.6051607132,83422,0,38269.6051607132,0.5782000422477722,2.01192593574524,10000,41987.00690460205,0.7671093344688416,1.126673460006714,0.702459990978241,1.4112627506256104,50000 -3747.7794547080994,3.1379082202911377,38689.86214780808,84337,0,38689.86214780808,0.5787000060081482,2.0397210121154785,10000,42445.65338349342,0.7741601467132568,1.1334631443023682,0.7036600112915039,1.4320693016052246,50000 -3787.216263532639,3.178086042404175,39110.120413541794,85253,0,39110.120413541794,0.5779000520706177,2.0556511878967285,10000,42905.44090247154,0.7835546731948853,1.083608627319336,0.7017199993133545,1.4271472692489624,50000 -3828.928506135941,3.2188732624053955,39530.094547748566,86168,0,39530.094547748566,0.5843999981880188,2.0091748237609863,10000,43367.21985912323,0.7661913633346558,1.129399657249451,0.7059599757194519,1.391916036605835,50000 -3872.0360209941855,3.2629506587982178,39950.2617855072,87086,0,39950.2617855072,0.5800000429153442,2.038839817047119,10000,43830.59176373482,0.7721288800239563,1.1214532852172852,0.7034800052642822,1.4084711074829102,50000 -3915.6184599399567,3.306154489517212,40370.22263884544,88002,0,40370.22263884544,0.5897000432014465,2.0038678646087646,10000,44294.23095417023,0.7907617092132568,1.0512465238571167,0.708620011806488,1.3993642330169678,50000 -3955.6573054790497,3.3536322116851807,40790.28323984146,88915,0,40790.28323984146,0.5835000276565552,2.015311479568481,10000,44754.43089079857,0.7695898413658142,1.1286250352859497,0.7078799605369568,1.400683045387268,50000 -3999.028008460999,3.3931477069854736,41210.26914906502,89831,0,41210.26914906502,0.5851000547409058,1.989667654037476,10000,45217.8800368309,0.7773827910423279,1.0906800031661987,0.7084000110626221,1.3841772079467771,50000 -4042.940614700317,3.434882879257202,41630.42470932007,90745,0,41630.42470932007,0.589400053024292,1.9841927289962769,10000,45682.04362034798,0.7867773175239563,1.062552809715271,0.7113800048828125,1.3798199892044067,50000 -4082.0101313591,3.475362062454224,42050.68173789978,91661,0,42050.68173789978,0.5848000049591064,1.9762824773788448,10000,46141.46359419823,0.7730078101158142,1.0981395244598389,0.7106599807739258,1.366332769393921,50000 -4126.3690321445465,3.515968084335327,42470.79413843155,92575,0,42470.79413843155,0.5852000117301941,1.9942631721496584,10000,46606.02896428108,0.779980480670929,1.0799758434295654,0.708620011806488,1.3764957189559937,50000 -4168.915802717209,3.5591742992401123,42890.790291547775,93494,0,42890.790291547775,0.5883000493049622,1.989989161491394,10000,47068.66774916649,0.7886718511581421,1.0608985424041748,0.7127999663352966,1.3815165758132937,50000 -4213.372656822205,3.6043410301208496,43310.81183815002,94409,0,43310.81183815002,0.5932000279426575,1.9795022010803225,10000,47533.24429869652,0.7827734351158142,1.0911414623260498,0.7120400071144104,1.3809247016906738,50000 -4256.298967838287,3.647977590560913,43731.426703214645,95328,0,43731.426703214645,0.5919000506401062,1.9971227645874023,10000,47996.88214635849,0.7830663919448853,1.0953134298324585,0.7137599587440491,1.3859901428222656,50000 -4301.111127138138,3.689197540283203,44151.71507334709,96244,0,44151.71507334709,0.5871000289916992,1.94822096824646,10000,48462.07651758194,0.7925195097923279,1.027241826057434,0.7157999873161316,1.3453587293624878,50000 -4341.087158918381,3.729381799697876,44571.969373226166,97160,0,44571.969373226166,0.5987000465393066,1.9719301462173464,10000,48922.39983391762,0.8024218678474426,1.013872146606445,0.7195599675178528,1.3655524253845217,50000 -4383.347653388977,3.7755613327026367,44992.07482409477,98072,0,44992.07482409477,0.5934000015258789,1.9582252502441408,10000,49384.86444759369,0.7826171517372131,1.069865107536316,0.7172399759292603,1.3597993850708008,50000 -4428.588559150696,3.823927640914917,45412.23049378395,98987,0,45412.23049378395,0.5945000052452087,1.974447965621948,10000,49850.361988306046,0.7899804711341858,1.0631240606307983,0.7159000039100647,1.3712981939315796,50000 -4472.3286554813385,3.867229223251343,45832.44064116478,99904,0,45832.44064116478,0.5960000157356262,1.927379250526428,10000,50314.4089922905,0.8057421445846558,0.96631920337677,0.7188799977302551,1.3333569765090942,50000 -4510.744811296463,3.916255474090576,46252.78383398056,100823,0,46252.78383398056,0.5957000255584717,1.943742036819458,10000,50773.26985836029,0.7916210889816284,1.0372837781906128,0.723039984703064,1.3263189792633057,50000 -4552.680232286453,3.956141948699951,46672.85423207283,101740,0,46672.85423207283,0.5976999998092651,1.9327346086502075,10000,51235.3676404953,0.7941210865974426,1.010243535041809,0.7223399877548218,1.32013201713562,50000 -4592.725254058838,4.001856803894043,47093.024424791336,102657,0,47093.024424791336,0.6000000238418579,1.952873945236206,10000,51695.68058466911,0.8057616949081421,0.9902685880661012,0.7222999930381775,1.3363498449325562,50000 -4636.055516242981,4.0445640087127686,47512.99956417084,103574,0,47512.99956417084,0.603600025177002,1.899206280708313,10000,52159.08165287972,0.79505854845047,1.0035017728805542,0.7264399528503418,1.2918392419815063,50000 -4675.482330560684,4.093322277069092,47933.23684620857,104491,0,47933.23684620857,0.5982000231742859,1.937688231468201,10000,52618.84631681442,0.7984960675239563,1.0110760927200315,0.7244200110435486,1.3226091861724854,50000 -4718.283473730087,4.140423059463501,48353.525584459305,105405,0,48353.525584459305,0.607200026512146,1.8782799243927,10000,53082.0359852314,0.8055859208106995,0.946385383605957,0.7270799875259399,1.2798376083374023,50000 -4758.99973154068,4.183903694152832,48773.57774686813,106321,0,48773.57774686813,0.6035000085830688,1.9131704568862915,10000,53542.90063166618,0.7962890267372131,1.0095607042312622,0.726639986038208,1.308942794799805,50000 -4798.267370700836,4.23191499710083,49193.492560863495,107240,0,49193.492560863495,0.5995000004768372,1.9311659336090088,10000,54002.184608221054,0.79638671875,1.0163118839263916,0.7265799641609192,1.3119384050369265,50000 -4840.51326918602,4.2742462158203125,49613.74026465416,108158,0,49613.74026465416,0.6043000221252441,1.9325958490371704,10000,54464.774629831314,0.8068554401397705,0.9952368140220642,0.7298399806022644,1.3251501321792605,50000 -4881.228107690811,4.3293633460998535,50033.76714491844,109073,0,50033.76714491844,0.6089000105857849,1.8995212316513064,10000,54925.62378549576,0.8169335722923279,0.9389804601669312,0.7301599979400635,1.2972155809402466,50000 -4918.705953121185,4.377705097198486,50453.88138437271,109990,0,50453.88138437271,0.6098000407218933,1.899088382720948,10000,55383.31662654877,0.8023046851158142,0.9991682767868042,0.730239987373352,1.301327347755432,50000 -4960.672949314117,4.423073053359985,50873.97991299629,110907,0,50873.97991299629,0.6110000014305115,1.8786910772323608,10000,55845.48027086258,0.8082812428474426,0.94392591714859,0.7319799661636353,1.2687028646469116,50000 -4999.882031202316,4.470271587371826,51294.16231417656,111820,0,51294.16231417656,0.613800048828125,1.898208737373352,10000,56304.97105741501,0.8192773461341858,0.9103418588638306,0.7331399917602539,1.2813720703125,50000 -5040.045891523361,4.520854949951172,51714.31570267677,112735,0,51714.31570267677,0.6167000532150269,1.8878358602523804,10000,56765.39139795303,0.8069726228713989,0.985032081604004,0.7335000038146973,1.2869131565093994,50000 -5079.130994319916,4.565948486328125,52134.493783950806,113652,0,52134.493783950806,0.6139000058174133,1.887000322341919,10000,57224.7523059845,0.810351550579071,0.9742421507835388,0.7343800067901611,1.293757438659668,50000 -5123.890210866928,4.613940000534058,52554.77208948136,114570,0,52554.77208948136,0.6096000075340271,1.864834427833557,10000,57689.89100813866,0.8231640458106995,0.8864479660987854,0.7347399592399597,1.2553902864456177,50000 -5165.858438968658,4.662526845932007,52974.86817359924,115486,0,52974.86817359924,0.6133000254631042,1.8954529762268064,10000,58152.056190013885,0.8078905940055847,0.9823316931724548,0.7350599765777588,1.3003755807876587,50000 -5209.096930503845,4.717260360717773,53395.18107008934,116403,0,53395.18107008934,0.612500011920929,1.879354238510132,10000,58615.71476054192,0.8154101371765137,0.9408923983573914,0.738099992275238,1.263514518737793,50000 -5249.836025476456,4.766160011291504,53815.10822582245,117321,0,53815.10822582245,0.6232000589370728,1.8169059753417969,10000,59076.4835767746,0.8239257335662842,0.8793371915817261,0.7392399907112122,1.2323999404907229,50000 -5292.465433597565,4.820432662963867,54235.39735746384,118195,0,54235.39735746384,0.6173000335693359,1.852914571762085,10000,59539.506457567215,0.8167187571525574,0.927771270275116,0.7392599582672119,1.2517048120498655,50000 -5332.194699764252,4.867582082748413,54655.33577847481,119111,0,54655.33577847481,0.6168000102043152,1.890485405921936,10000,59999.27346301079,0.8149804472923279,0.9616518020629884,0.7415800094604492,1.2812042236328125,50000 -5375.512019634247,4.913069009780884,55075.26651906967,120024,0,55075.26651906967,0.6173000335693359,1.8545596599578853,10000,60462.61881303787,0.8238866925239563,0.8949841856956482,0.7396199703216553,1.250055909156799,50000 -5414.950091123581,4.961091995239258,55495.54484438896,120940,0,55495.54484438896,0.6230000257492065,1.833022952079773,10000,60922.43671345711,0.8238281011581421,0.8898305296897888,0.7408599853515625,1.232032060623169,50000 -5457.037507534027,5.007095813751221,55915.70909833908,121856,0,55915.70909833908,0.6186000108718872,1.8269776105880733,10000,61384.78747153282,0.8225390315055847,0.8875148296356201,0.744879961013794,1.2177189588546753,50000 -5501.842509508133,5.055784463882446,56336.03149437904,122773,0,56336.03149437904,0.6287000179290771,1.805085062980652,10000,61850.01596212387,0.82582026720047,0.8763144016265869,0.7443599700927734,1.2213648557662964,50000 -5541.017600536346,5.112732887268066,56756.35433626175,123690,0,56756.35433626175,0.6236000061035156,1.8550283908844,10000,62309.62432742119,0.8439843654632568,0.8367078900337219,0.744159996509552,1.248141527175903,50000 -5582.155804157257,5.160711050033569,57176.65181708336,124606,0,57176.65181708336,0.6204000115394592,1.806689739227295,10000,62771.15999698639,0.8265038728713989,0.8778988718986511,0.7480599880218506,1.2074445486068726,50000 -5626.422222137451,5.204384326934815,57596.83524441719,125525,0,57596.83524441719,0.6232000589370728,1.830721974372864,10000,63235.70528793335,0.8316210508346558,0.8772971630096436,0.7468799948692322,1.236533522605896,50000 -5670.740884780884,5.2526023387908936,58016.88030552864,126442,0,58016.88030552864,0.6271000504493713,1.8440407514572144,10000,63700.16956210136,0.84095698595047,0.8439698815345764,0.7470600008964539,1.2318177223205566,50000 -5709.3451907634735,5.299185037612915,58437.10874581337,127359,0,58437.10874581337,0.6279000043869019,1.790854811668396,10000,64159.10203433037,0.8307226300239563,0.8531256914138794,0.7465199828147888,1.2010209560394287,50000 -5751.883533239365,5.343764781951904,58857.14763689041,128275,0,58857.14763689041,0.6304000020027161,1.7960751056671145,10000,64621.77625489235,0.8322070240974426,0.8553465604782104,0.7501800060272217,1.1990872621536257,50000 -5796.012695074081,5.395809412002564,59277.30268001557,129189,0,59277.30268001557,0.6279000043869019,1.829653024673462,10000,65086.16503977776,0.8402343392372131,0.8571125864982605,0.7488999962806702,1.232455849647522,50000 -5834.219810962677,5.449113845825195,59697.6349568367,130106,0,59697.6349568367,0.6306000351905823,1.786401629447937,10000,65544.8101940155,0.83314448595047,0.8533020615577698,0.7516599893569946,1.195619821548462,50000 -5875.437610626221,5.501766443252564,60117.64026284218,131022,0,60117.64026284218,0.6299000382423401,1.7943549156188965,10000,66006.13822078705,0.8349804282188416,0.847000241279602,0.7511000037193298,1.1966373920440674,50000 -5919.410274267197,5.547304630279541,60537.56083083153,131938,0,60537.56083083153,0.6317000389099121,1.801892876625061,10000,66470.13001537323,0.841113269329071,0.8369109630584717,0.7544800043106079,1.2058559656143188,50000 -5962.992874383926,5.59683632850647,60957.67953944206,132855,0,60957.67953944206,0.6373000144958496,1.7519946098327637,10000,66933.93345880508,0.8413476347923279,0.8087218403816223,0.756060004234314,1.167621612548828,50000 -6007.655216932297,5.646831274032593,61377.61789727211,133769,0,61377.61789727211,0.6380000114440918,1.7589350938796997,10000,67398.6365814209,0.8415429592132568,0.8192632794380188,0.7564199566841125,1.1783759593963623,50000 -6047.2556438446045,5.699421167373657,61797.7135746479,134688,0,61797.7135746479,0.6371000409126282,1.79377543926239,10000,67858.43765830994,0.8436328172683716,0.8358486890792847,0.7567600011825562,1.1952334642410278,50000 -6091.693810224533,5.74712872505188,62218.0052587986,135607,0,62218.0052587986,0.6403000354766846,1.7427631616592407,10000,68323.26888012886,0.857714831829071,0.7488970160484314,0.75764000415802,1.1549402475357056,50000 -6132.479343414307,5.793379783630371,62638.34627819061,136525,0,62638.34627819061,0.6402000188827515,1.761014103889465,10000,68784.49385523796,0.8420116901397705,0.8260501623153687,0.7569199800491333,1.185738205909729,50000 -6174.12397646904,5.840576648712158,63058.35525393486,137441,0,63058.35525393486,0.6407000422477722,1.739344596862793,10000,69246.24775362015,0.8498241901397705,0.7776594161987305,0.7576599717140198,1.155328392982483,50000 -6219.046733617783,5.887576580047607,63478.462926864624,138357,0,63478.462926864624,0.6389000415802002,1.75374174118042,10000,69711.37830281258,0.8564062118530273,0.7429553270339966,0.7604999542236328,1.145262360572815,50000 -6262.304250955582,5.935348272323608,63898.55232334137,139275,0,63898.55232334137,0.64000004529953,1.7675161361694336,10000,70174.82610559464,0.848437488079071,0.8022940158843994,0.7615199685096741,1.1675846576690674,50000 -6306.780901193619,5.983578443527222,64318.52702474594,140189,0,64318.52702474594,0.6385000348091125,1.748010277748108,10000,70639.37820744514,0.850390613079071,0.7843379974365234,0.7585200071334839,1.1634849309921265,50000 -6351.042797088623,6.035136699676514,64738.57985305786,141104,0,64738.57985305786,0.645300030708313,1.7508445978164673,10000,71103.796667099,0.8593944907188416,0.7617772221565247,0.7626199722290039,1.1537964344024658,50000 -6395.823813199997,6.086883783340454,65158.57492208481,142015,0,65158.57492208481,0.6426000595092773,1.7553709745407104,10000,71568.67630791664,0.8515819907188416,0.7848900556564331,0.7644199728965759,1.1486810445785522,50000 -6435.086161613464,6.138216018676758,65578.79041981697,142930,0,65578.79041981697,0.6457000374794006,1.7412879467010498,10000,72028.25753879547,0.8580859303474426,0.7559224367141724,0.7646999955177307,1.1461162567138672,50000 -6473.854023933411,6.194704294204712,65998.90920972824,143845,0,65998.90920972824,0.645300030708313,1.753394603729248,10000,72487.25309491158,0.8574609160423279,0.7697817087173462,0.7655199766159058,1.1547305583953855,50000 -6515.839464187622,6.242878437042236,66419.2152094841,144762,0,66419.2152094841,0.6467000246047974,1.7295129299163818,10000,72949.6464586258,0.8602929711341858,0.742518961429596,0.7663799524307251,1.1345235109329224,50000 -6559.876788377762,6.297804832458496,66839.39895510674,145674,0,66839.39895510674,0.6487000584602356,1.7240246534347534,10000,73413.97380805016,0.8600585460662842,0.7430623173713684,0.7668799757957458,1.1308563947677612,50000 -6602.229855775833,6.350820302963257,67259.50889992714,146588,0,67259.50889992714,0.64410001039505,1.748490333557129,10000,73876.54255223274,0.864062488079071,0.7536942958831787,0.7701799869537354,1.145656943321228,50000 -6646.628881692886,6.404725551605225,67679.66255092621,147504,0,67679.66255092621,0.6520000100135803,1.7187994718551636,10000,74341.20226407051,0.8710546493530273,0.708790123462677,0.7684599757194519,1.1278865337371826,50000 -6690.891248703003,6.458734512329102,68099.97631287575,148421,0,68099.97631287575,0.6528000235557556,1.7120332717895508,10000,74805.88531470299,0.8649804592132568,0.7262025475502014,0.7682799696922302,1.1214993000030518,50000 -6733.445414304733,6.892820119857788,68519.64030075073,149335,0,68519.64030075073,0.651900053024292,1.711452603340149,10000,75268.590113163,0.8691601157188416,0.7210755944252014,0.7702999711036682,1.124349594116211,50000 -6779.409840583801,6.94247579574585,68939.89294409752,150250,0,68939.89294409752,0.6508000493049622,1.714530348777771,10000,75734.91024708748,0.8727734088897705,0.6951795816421509,0.770859956741333,1.1124303340911863,50000 -6819.731669664383,6.9958555698394775,69359.92496109009,151165,0,69359.92496109009,0.6515000462532043,1.7132558822631836,10000,76195.36973547935,0.8675585985183716,0.7195205688476562,0.7733199596405029,1.1224658489227295,50000 -6863.285403966904,7.0559470653533936,69779.90133500099,152079,0,69779.90133500099,0.6582000255584717,1.690200686454773,10000,76659.01178598404,0.8717772960662842,0.6925562620162964,0.7725399732589722,1.0992447137832642,50000 -6901.371284723282,7.113300085067749,70199.87671470642,152996,0,70199.87671470642,0.6490000486373901,1.7110410928726196,10000,77117.18258023262,0.8743945360183716,0.6965427398681641,0.7730000019073486,1.1139885187149048,50000 -6943.878223657608,7.172998905181885,70619.90300965309,153909,0,70619.90300965309,0.6552000045776367,1.6988980770111084,10000,77579.82703661919,0.8716406226158142,0.702102541923523,0.7744199633598328,1.1024911403656006,50000 -6984.388374567032,7.224968194961548,71039.93932843208,154827,0,71039.93932843208,0.6602000594139099,1.70155668258667,10000,78040.47797226906,0.8707226514816284,0.7155790328979492,0.7752999663352966,1.1152511835098269,50000 -7025.718888044357,7.279221773147583,71460.28168869019,155742,0,71460.28168869019,0.6535000205039978,1.7088316679000854,10000,78502.2566742897,0.8770117163658142,0.6983265280723572,0.7755199670791626,1.116679549217224,50000 -7066.457926273346,7.332519292831421,71880.30394387245,156657,0,71880.30394387245,0.657200038433075,1.6836694478988647,10000,78963.12360739708,0.8760351538658142,0.6916414499282837,0.7758199572563171,1.1019681692123413,50000 -7110.197211503983,7.3921473026275635,72300.49660873413,157574,0,72300.49660873413,0.6581000089645386,1.694753646850586,10000,79427.16787004471,0.8784960508346558,0.6948211789131165,0.7772799730300903,1.1076263189315796,50000 -7154.059904813767,7.443153619766235,72720.79914164543,158489,0,72720.79914164543,0.6564000248908997,1.6833994388580322,10000,79891.43614602089,0.8775194883346558,0.6820436716079712,0.776419997215271,1.095947504043579,50000 -7197.746908187866,7.494960784912109,73141.05099487305,159403,0,73141.05099487305,0.6584000587463379,1.6912035942077637,10000,80355.4780535698,0.8854882717132568,0.6608580350875854,0.7779200077056885,1.1026121377944946,50000 -7243.284869670868,7.547072887420654,73561.01437163353,160318,0,73561.01437163353,0.6579000353813171,1.6871545314788818,10000,80821.08317613602,0.8786523342132568,0.6797360181808472,0.778939962387085,1.0944668054580688,50000 -7280.42783331871,7.600719690322876,73980.93570017815,161235,0,73980.93570017815,0.6635000109672546,1.6645426750183103,10000,81278.25377750397,0.8811132907867432,0.6593165993690491,0.7783799767494202,1.0817325115203855,50000 -7323.926148414612,7.653084754943848,74401.1434378624,162150,0,74401.1434378624,0.6640000343322754,1.6774790287017822,10000,81742.06526184082,0.8886327743530273,0.6440286040306091,0.7780199646949768,1.0898090600967407,50000 -7367.785813331604,7.709946155548096,74821.25588536263,163066,0,74821.25588536263,0.6609000563621521,1.669062614440918,10000,82206.14687275887,0.8836523294448853,0.6655579805374146,0.7795799970626831,1.0824384689331057,50000 -7412.314453601837,7.765376091003418,75241.35402417183,163980,0,75241.35402417183,0.659500002861023,1.6772193908691406,10000,82670.88124489784,0.8823828101158142,0.6626566648483276,0.7808399796485901,1.0874329805374146,50000 -7455.80052113533,7.824825286865234,75661.63371706009,164896,0,75661.63371706009,0.6633000373840332,1.6752506494522097,10000,83134.75895619392,0.8887304663658142,0.6466950178146362,0.7809199690818787,1.0870565176010132,50000 -7501.036018610001,7.883217334747314,76081.98018074036,165811,0,76081.98018074036,0.6655000448226929,1.6651965379714966,10000,83600.45110487938,0.8844726085662842,0.6560773849487305,0.7813000082969666,1.0782166719436646,50000 -7539.935069799423,7.935632944107056,76502.31396532059,166727,0,76502.31396532059,0.6671000123023987,1.654746413230896,10000,84059.78851270676,0.8896484375,0.6387453079223633,0.7838599681854248,1.0713698863983154,50000 -7584.034651994705,7.9859490394592285,76922.55463337898,167644,0,76922.55463337898,0.6630000472068787,1.6625733375549316,10000,84524.23124790192,0.8884179592132568,0.649350643157959,0.7827999591827393,1.0806797742843628,50000 -7626.804463386536,8.039025783538818,77342.60869932175,168561,0,77342.60869932175,0.666700005531311,1.6670899391174316,10000,84987.16055512428,0.8898632526397705,0.6395523548126221,0.7828800082206726,1.0810366868972778,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/measurements.csv deleted file mode 100644 index 395d1f341..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1877 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.29810053,6.9077535,,,,,,,,,,,,,, -1,,,0.0008203124743886,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,33.48979616165161,58.89862966537476,33.48979616165161,25.40871500968933,0.0,0.0 -100,0.3882684,6.9038615,,,,,,,,,,,,,, -200,0.43141705,6.8825707,,,,,,,,,,,,,, -300,0.50145763,6.8551836,,,,,,,,,,,,,, -400,0.5497487,6.846946,,,,,,,,,,,,,, -500,0.5792926,6.7863083,,,,,,,,,,,,,, -600,0.49089873,6.741988,,,,,,,,,,,,,, -700,1.0341983,6.7091446,,,,,,,,,,,,,, -800,0.82356286,6.641955,,,,,,,,,,,,,, -854,,,0.0140429688617587,6.451799392700195,0.0134399998933076,6.465761661529541,50000.0,0.0101000005379319,6.498650550842285,10000.0,453.5089168548584,517.7165124416351,453.5089168548584,64.14088726043701,0.0175280570983886,0.0 -900,1.3741965,6.6504807,,,,,,,,,,,,,, -1000,0.9647621,6.600851,,,,,,,,,,,,,, -1100,1.0621064,6.5312424,,,,,,,,,,,,,, -1200,0.99864066,6.517338,,,,,,,,,,,,,, -1300,0.8765631,6.445308,,,,,,,,,,,,,, -1400,2.397496,6.4962764,,,,,,,,,,,,,, -1500,1.6469561,6.440393,,,,,,,,,,,,,, -1600,1.2138859,6.352611,,,,,,,,,,,,,, -1700,1.2445138,6.3434696,,,,,,,,,,,,,, -1768,,,0.0452539063990116,5.853660583496094,0.044659998267889,5.874964714050293,50000.0,0.0351999998092651,5.988783836364746,10000.0,873.7088196277618,977.7617542743684,873.7088196277618,103.90360403060912,0.0479042530059814,0.0 -1800,1.3798757,6.445388,,,,,,,,,,,,,, -1900,0.94226664,6.602548,,,,,,,,,,,,,, -2000,1.0879852,6.306015,,,,,,,,,,,,,, -2100,1.0207015,6.5419097,,,,,,,,,,,,,, -2200,1.239549,6.2600613,,,,,,,,,,,,,, -2300,1.1418899,6.2147884,,,,,,,,,,,,,, -2400,0.8032821,6.6151586,,,,,,,,,,,,,, -2500,0.90412086,6.7003827,,,,,,,,,,,,,, -2600,1.035976,6.144948,,,,,,,,,,,,,, -2686,,,0.0721289068460464,5.425541400909424,0.0664799958467483,5.466614246368408,50000.0,0.0538000017404556,5.632424831390381,10000.0,1293.9095661640167,1439.6401374340055,1293.9095661640167,145.50300693511963,0.0734522342681884,0.0 -2700,1.2723927,6.168304,,,,,,,,,,,,,, -2800,1.4190099,6.661294,,,,,,,,,,,,,, -2900,1.1092544,6.1798944,,,,,,,,,,,,,, -3000,1.2474487,6.037815,,,,,,,,,,,,,, -3100,0.9549809,6.0666766,,,,,,,,,,,,,, -3200,0.9823571,6.616784,,,,,,,,,,,,,, -3300,1.0545716,5.992593,,,,,,,,,,,,,, -3400,1.0906589,5.9971538,,,,,,,,,,,,,, -3500,1.5065795,6.080501,,,,,,,,,,,,,, -3600,0.91560936,6.434338,,,,,,,,,,,,,, -3601,,,0.1110156252980232,5.073092937469482,0.0995799973607063,5.132517814636231,50000.0,0.0749000012874603,5.360351085662842,10000.0,1713.922945976257,1898.2273745536804,1713.922945976257,183.9953689575196,0.1029400825500488,0.0 -3700,1.1028262,5.9439917,,,,,,,,,,,,,, -3800,0.938945,5.9255795,,,,,,,,,,,,,, -3900,1.0749116,5.9374957,,,,,,,,,,,,,, -4000,1.1781092,5.850253,,,,,,,,,,,,,, -4100,1.3215716,6.157834,,,,,,,,,,,,,, -4200,1.1683115,5.817495,,,,,,,,,,,,,, -4300,1.017002,5.7073326,,,,,,,,,,,,,, -4400,0.9394937,6.289195,,,,,,,,,,,,,, -4500,1.4003842,5.76271,,,,,,,,,,,,,, -4520,,,0.1525195240974426,4.659289360046387,0.1367399990558624,4.747751712799072,50000.0,0.1058000028133392,5.029491424560547,10000.0,2133.9952919483185,2358.2990391254425,2133.9952919483185,223.9190099239349,0.1261570453643798,0.0 -4600,1.1270797,5.7108393,,,,,,,,,,,,,, -4700,1.271445,5.7137136,,,,,,,,,,,,,, -4800,1.2455287,5.650526,,,,,,,,,,,,,, -4900,1.526599,5.6673403,,,,,,,,,,,,,, -5000,0.8643478,6.385779,,,,,,,,,,,,,, -5100,1.0899843,5.860268,,,,,,,,,,,,,, -5200,0.99829507,5.638861,,,,,,,,,,,,,, -5300,1.2565957,5.5697217,,,,,,,,,,,,,, -5400,1.0377719,5.496123,,,,,,,,,,,,,, -5437,,,0.2096288949251175,4.283182621002197,0.1894799917936325,4.372105121612549,50000.0,0.1422000080347061,4.704031467437744,10000.0,2554.283395767212,2819.649272918701,2554.283395767212,264.90250039100647,0.1519322395324707,0.0 -5500,1.0925385,5.5527344,,,,,,,,,,,,,, -5600,0.9653474,6.074388,,,,,,,,,,,,,, -5700,1.225729,5.5530396,,,,,,,,,,,,,, -5800,1.2751431,5.430729,,,,,,,,,,,,,, -5900,1.1552624,5.40311,,,,,,,,,,,,,, -6000,1.7441664,5.821639,,,,,,,,,,,,,, -6100,0.96186805,6.174363,,,,,,,,,,,,,, -6200,1.0943319,5.5261564,,,,,,,,,,,,,, -6300,1.2766863,5.318653,,,,,,,,,,,,,, -6357,,,0.2499413937330246,3.9393117427825928,0.2325399965047836,4.036503314971924,50000.0,0.1807000041007995,4.413839817047119,10000.0,2974.62889289856,3281.6409327983856,2974.62889289856,306.4709143638611,0.176743745803833,0.0 -6400,0.9807743,6.164378,,,,,,,,,,,,,, -6500,1.3454124,5.3035994,,,,,,,,,,,,,, -6600,1.1563755,5.3261647,,,,,,,,,,,,,, -6700,0.83554345,6.333549,,,,,,,,,,,,,, -6800,0.92086834,6.1026797,,,,,,,,,,,,,, -6900,0.9928729,6.308443,,,,,,,,,,,,,, -7000,0.8877133,6.2178197,,,,,,,,,,,,,, -7100,1.0296389,5.229633,,,,,,,,,,,,,, -7200,1.2890886,5.348953,,,,,,,,,,,,,, -7275,,,0.2956054508686065,3.652733087539673,0.2621199786663055,3.811382293701172,50000.0,0.2017000168561935,4.236976146697998,10000.0,3394.736057758332,3742.9045355319977,3394.736057758332,347.54998540878296,0.2015523910522461,0.0 -7300,1.1682239,5.1901855,,,,,,,,,,,,,, -7400,1.0083479,5.1901426,,,,,,,,,,,,,, -7500,1.1640896,5.378966,,,,,,,,,,,,,, -7600,1.0328908,5.1031523,,,,,,,,,,,,,, -7700,1.2695708,5.1631885,,,,,,,,,,,,,, -7800,1.1454405,5.084059,,,,,,,,,,,,,, -7900,1.1248627,5.0922217,,,,,,,,,,,,,, -8000,1.2466713,5.079219,,,,,,,,,,,,,, -8100,1.029118,5.2390532,,,,,,,,,,,,,, -8192,,,0.3202148377895355,3.4828038215637207,0.2987799942493438,3.5942399501800537,50000.0,0.2297000139951706,4.045166015625,10000.0,3814.776090860367,4203.085846185684,3814.776090860367,387.60815477371216,0.231529951095581,0.0 -8200,0.7676449,6.2310495,,,,,,,,,,,,,, -8300,1.1623998,4.9899096,,,,,,,,,,,,,, -8400,0.93855876,5.878004,,,,,,,,,,,,,, -8500,0.95912087,5.2001824,,,,,,,,,,,,,, -8600,1.0274538,5.1688604,,,,,,,,,,,,,, -8700,0.74439716,6.134138,,,,,,,,,,,,,, -8800,1.2472966,4.997823,,,,,,,,,,,,,, -8900,0.7524981,6.2832174,,,,,,,,,,,,,, -9000,0.7392098,5.6930475,,,,,,,,,,,,,, -9100,0.97822464,4.9119935,,,,,,,,,,,,,, -9112,,,0.3625390529632568,3.197558879852295,0.3323200047016144,3.329472780227661,50000.0,0.2621000111103058,3.810883045196533,10000.0,4234.937925100327,4662.507352590561,4234.937925100327,426.7889401912689,0.2578392028808594,0.0 -9200,0.9899534,4.8639054,,,,,,,,,,,,,, -9300,0.94622725,4.845247,,,,,,,,,,,,,, -9400,0.6613163,6.104943,,,,,,,,,,,,,, -9500,1.4357435,4.819489,,,,,,,,,,,,,, -9600,1.2027082,4.8190126,,,,,,,,,,,,,, -9700,0.8698272,5.587406,,,,,,,,,,,,,, -9800,0.7499259,5.9579687,,,,,,,,,,,,,, -9900,0.7611952,6.3135443,,,,,,,,,,,,,, -10000,1.0621507,4.924604,,,,,,,,,,,,,, -10031,,,0.3985546827316284,2.9937057495117188,0.3631999790668487,3.172254800796509,50000.0,0.2775000035762787,3.6873080730438232,10000.0,4655.196612596512,5123.324376821518,4655.196612596512,467.26917695999146,0.2834861278533935,0.0 -10100,1.0335361,4.8103495,,,,,,,,,,,,,, -10200,0.9659044,4.8892274,,,,,,,,,,,,,, -10300,1.027035,4.76825,,,,,,,,,,,,,, -10400,0.9233379,4.623209,,,,,,,,,,,,,, -10500,1.1734142,4.810197,,,,,,,,,,,,,, -10600,0.9696659,4.764827,,,,,,,,,,,,,, -10700,0.7474859,5.395315,,,,,,,,,,,,,, -10800,0.70239276,5.7202806,,,,,,,,,,,,,, -10900,1.0015862,4.6118283,,,,,,,,,,,,,, -10952,,,0.4247460961341858,2.851715087890625,0.3962799906730652,2.9970977306365967,50000.0,0.3062000274658203,3.53668212890625,10000.0,5075.5094628334045,5581.614113092423,5075.5094628334045,505.16552233695984,0.3097312450408935,0.0 -11000,0.7821112,5.45662,,,,,,,,,,,,,, -11100,0.93973905,4.6691318,,,,,,,,,,,,,, -11200,0.9446306,4.5595055,,,,,,,,,,,,,, -11300,1.0361946,4.5206804,,,,,,,,,,,,,, -11400,0.79925257,5.0283895,,,,,,,,,,,,,, -11500,0.73039687,6.129276,,,,,,,,,,,,,, -11600,0.66009134,6.030238,,,,,,,,,,,,,, -11700,0.9406419,4.7089396,,,,,,,,,,,,,, -11800,0.72309166,5.187085,,,,,,,,,,,,,, -11872,,,0.4528906047344208,2.6884093284606934,0.4200599789619446,2.8471667766571045,50000.0,0.3208000063896179,3.4061996936798096,10000.0,5495.436726093292,6040.103754281998,5495.436726093292,543.6498990058899,0.3350875377655029,0.0 -11900,0.90224487,4.739918,,,,,,,,,,,,,, -12000,0.95763457,4.424325,,,,,,,,,,,,,, -12100,0.764483,5.386701,,,,,,,,,,,,,, -12200,1.0814296,4.425545,,,,,,,,,,,,,, -12300,0.9779671,4.629334,,,,,,,,,,,,,, -12400,1.0897235,4.54786,,,,,,,,,,,,,, -12500,0.88138205,4.678476,,,,,,,,,,,,,, -12600,0.6386,5.8970027,,,,,,,,,,,,,, -12700,1.0398079,4.4328146,,,,,,,,,,,,,, -12791,,,0.4826757609844208,2.515464305877685,0.4415600001811981,2.720947742462158,50000.0,0.3406000137329101,3.2832746505737305,10000.0,5915.742879390717,6501.774571418762,5915.742879390717,584.9318788051605,0.3650226593017578,0.0 -12800,1.0822655,4.4960136,,,,,,,,,,,,,, -12900,0.9168245,4.494055,,,,,,,,,,,,,, -13000,0.76949674,5.027962,,,,,,,,,,,,,, -13100,0.97616297,4.4402676,,,,,,,,,,,,,, -13200,0.91464907,4.4174256,,,,,,,,,,,,,, -13300,0.8822136,4.35081,,,,,,,,,,,,,, -13400,1.0463009,4.432855,,,,,,,,,,,,,, -13500,0.77086025,5.097184,,,,,,,,,,,,,, -13600,0.86040044,4.410296,,,,,,,,,,,,,, -13700,0.9187656,4.7521095,,,,,,,,,,,,,, -13709,,,0.5004491806030273,2.3911969661712646,0.4621599912643432,2.5576767921447754,50000.0,0.3562000095844269,3.150334119796753,10000.0,6335.661201000214,6958.471544742584,6335.661201000214,621.6272025108337,0.394378662109375,0.0 -13800,0.6866792,5.7455044,,,,,,,,,,,,,, -13900,0.9184479,4.5036125,,,,,,,,,,,,,, -14000,0.68926203,5.3126917,,,,,,,,,,,,,, -14100,0.91336745,4.802228,,,,,,,,,,,,,, -14200,0.630083,5.827702,,,,,,,,,,,,,, -14300,0.8000843,5.735481,,,,,,,,,,,,,, -14400,0.7162349,6.064936,,,,,,,,,,,,,, -14500,1.0292084,4.3553925,,,,,,,,,,,,,, -14600,0.9078199,4.3195615,,,,,,,,,,,,,, -14630,,,0.5197460651397705,2.346434593200684,0.4825599789619446,2.5191760063171387,50000.0,0.3742000162601471,3.108792543411255,10000.0,6755.971422195435,7417.451339006424,6755.971422195435,660.215106010437,0.4230444431304931,0.0 -14700,0.94546247,4.334895,,,,,,,,,,,,,, -14800,0.91294235,4.393267,,,,,,,,,,,,,, -14900,1.041223,4.1846204,,,,,,,,,,,,,, -15000,0.9674033,4.281478,,,,,,,,,,,,,, -15100,0.9891731,4.428173,,,,,,,,,,,,,, -15200,0.9379989,4.287625,,,,,,,,,,,,,, -15300,0.96537954,4.355807,,,,,,,,,,,,,, -15400,0.75388074,5.2782354,,,,,,,,,,,,,, -15500,0.95146894,4.298214,,,,,,,,,,,,,, -15550,,,0.53466796875,2.234204769134521,0.4933599829673767,2.4432485103607178,50000.0,0.3827000260353088,3.0271334648132324,10000.0,7176.325411558151,7876.476963996887,7176.325411558151,698.8068988323212,0.4497511386871338,0.0 -15600,0.9138331,4.429789,,,,,,,,,,,,,, -15700,0.8591358,4.290969,,,,,,,,,,,,,, -15800,0.9256324,4.3800173,,,,,,,,,,,,,, -15900,0.91822284,4.7264276,,,,,,,,,,,,,, -16000,0.9415963,4.262964,,,,,,,,,,,,,, -16100,0.89705694,4.206817,,,,,,,,,,,,,, -16200,0.8742858,4.42272,,,,,,,,,,,,,, -16300,0.83779514,4.752382,,,,,,,,,,,,,, -16400,0.973257,4.2279406,,,,,,,,,,,,,, -16467,,,0.5671288967132568,2.112599372863769,0.505299985408783,2.3933229446411133,50000.0,0.4006000161170959,2.977833747863769,10000.0,7596.317331790924,8336.32350063324,7596.317331790924,738.5819492340088,0.4763305187225342,0.0 -16500,0.7279265,5.693586,,,,,,,,,,,,,, -16600,0.86317486,4.1509757,,,,,,,,,,,,,, -16700,1.14954,4.1771617,,,,,,,,,,,,,, -16800,0.77839524,4.9482346,,,,,,,,,,,,,, -16900,0.8811724,4.3043766,,,,,,,,,,,,,, -17000,0.82797915,4.7098513,,,,,,,,,,,,,, -17100,0.8735716,4.545065,,,,,,,,,,,,,, -17200,0.91017145,4.1774664,,,,,,,,,,,,,, -17300,0.8171707,4.685866,,,,,,,,,,,,,, -17387,,,0.5574023127555847,2.110726833343506,0.5231800079345703,2.286231279373169,50000.0,0.4088000059127807,2.8916168212890625,10000.0,8016.236355781555,8796.280049562454,8016.236355781555,778.5356502532959,0.5067436695098877,0.0 -17400,0.7344579,4.842475,,,,,,,,,,,,,, -17500,0.8120201,4.095276,,,,,,,,,,,,,, -17600,0.8695242,4.1518984,,,,,,,,,,,,,, -17700,0.95140046,4.131654,,,,,,,,,,,,,, -17800,0.953137,4.1013136,,,,,,,,,,,,,, -17900,0.90563226,4.2333975,,,,,,,,,,,,,, -18000,0.9524358,4.1902995,,,,,,,,,,,,,, -18100,0.95598465,4.013528,,,,,,,,,,,,,, -18200,0.8249606,4.721331,,,,,,,,,,,,,, -18300,0.8857458,4.612793,,,,,,,,,,,,,, -18306,,,0.5698437094688416,2.086137056350708,0.5271199941635132,2.275842905044556,50000.0,0.4140000343322754,2.887305736541748,10000.0,8436.425481557846,9257.947703361511,8436.425481557846,819.9306211471558,0.5377237796783447,0.0 -18400,0.8988287,4.068427,,,,,,,,,,,,,, -18500,0.9265351,4.0612125,,,,,,,,,,,,,, -18600,0.9172109,4.0727105,,,,,,,,,,,,,, -18700,0.9169552,4.086375,,,,,,,,,,,,,, -18800,0.7270971,5.1178217,,,,,,,,,,,,,, -18900,0.8481262,4.327803,,,,,,,,,,,,,, -19000,0.8834139,4.009995,,,,,,,,,,,,,, -19100,0.6735661,5.7303915,,,,,,,,,,,,,, -19200,0.9279008,4.127366,,,,,,,,,,,,,, -19218,,,0.5857617259025574,2.032967090606689,0.5317599773406982,2.2811155319213867,50000.0,0.417900025844574,2.88185453414917,10000.0,8856.704848766327,9717.73423075676,8856.704848766327,859.3592464923859,0.5637521743774414,0.0 -19300,1.083869,4.1464167,,,,,,,,,,,,,, -19400,0.97616893,4.151034,,,,,,,,,,,,,, -19500,0.7101121,5.8169217,,,,,,,,,,,,,, -19600,0.7962189,4.405978,,,,,,,,,,,,,, -19700,0.91118616,4.069024,,,,,,,,,,,,,, -19800,0.8560085,4.316399,,,,,,,,,,,,,, -19900,1.0255411,4.085351,,,,,,,,,,,,,, -20000,0.91312575,4.024132,,,,,,,,,,,,,, -20100,0.7107823,5.4873857,,,,,,,,,,,,,, -20135,,,0.5826562643051147,2.0578503608703613,0.5440599918365479,2.237067222595215,50000.0,0.4216000139713287,2.8544681072235107,10000.0,9277.106140851974,10179.38702774048,9277.106140851974,900.5303463935852,0.5916616916656494,0.0 -20200,0.859821,4.6788435,,,,,,,,,,,,,, -20300,0.72182584,5.1506214,,,,,,,,,,,,,, -20400,0.64944917,5.8761835,,,,,,,,,,,,,, -20500,0.98228633,3.9772108,,,,,,,,,,,,,, -20600,0.66739124,5.846271,,,,,,,,,,,,,, -20700,0.98899645,4.031011,,,,,,,,,,,,,, -20800,0.90142137,4.074954,,,,,,,,,,,,,, -20900,0.7204198,5.8278675,,,,,,,,,,,,,, -21000,0.8697034,5.2502565,,,,,,,,,,,,,, -21053,,,0.5937890410423279,1.9904780387878416,0.5496399998664856,2.194326162338257,50000.0,0.4394000172615051,2.793226718902588,10000.0,9697.392251729963,10636.000517368317,9697.392251729963,936.777051448822,0.619286298751831,0.0 -21100,0.8911721,4.176396,,,,,,,,,,,,,, -21200,0.93021417,4.0331984,,,,,,,,,,,,,, -21300,1.0501657,4.043306,,,,,,,,,,,,,, -21400,0.71731514,5.490774,,,,,,,,,,,,,, -21500,0.7901692,4.635183,,,,,,,,,,,,,, -21600,1.1022576,4.04663,,,,,,,,,,,,,, -21700,0.96136093,4.0746813,,,,,,,,,,,,,, -21800,0.9241045,4.1308603,,,,,,,,,,,,,, -21900,0.7858907,4.962491,,,,,,,,,,,,,, -21972,,,0.613964855670929,1.871402144432068,0.5616599917411804,2.118140935897827,50000.0,0.4426000118255615,2.743307590484619,10000.0,10117.686156272888,11095.38220667839,10117.686156272888,975.779391527176,0.6520423889160156,0.0 -22000,0.94845307,3.9400635,,,,,,,,,,,,,, -22100,0.88811266,4.118296,,,,,,,,,,,,,, -22200,0.99654883,3.968538,,,,,,,,,,,,,, -22300,0.8497168,5.726692,,,,,,,,,,,,,, -22400,0.9820904,3.9405403,,,,,,,,,,,,,, -22500,0.97009826,3.904659,,,,,,,,,,,,,, -22600,0.9413426,4.0054054,,,,,,,,,,,,,, -22700,0.94624084,4.129818,,,,,,,,,,,,,, -22800,0.7699689,5.734841,,,,,,,,,,,,,, -22891,,,0.6078515648841858,1.876462459564209,0.5680199861526489,2.0680747032165527,50000.0,0.4502000212669372,2.677503824234009,10000.0,10537.782514810562,11556.057990550997,10537.782514810562,1016.273241519928,0.684720516204834,0.0 -22900,0.93482476,4.017759,,,,,,,,,,,,,, -23000,1.0273107,3.9635954,,,,,,,,,,,,,, -23100,0.92145103,4.230979,,,,,,,,,,,,,, -23200,0.761017,5.705742,,,,,,,,,,,,,, -23300,0.87022156,4.258831,,,,,,,,,,,,,, -23400,0.9487514,4.0289464,,,,,,,,,,,,,, -23500,1.0095416,3.9460976,,,,,,,,,,,,,, -23600,0.8454769,5.713336,,,,,,,,,,,,,, -23700,0.9605956,4.023005,,,,,,,,,,,,,, -23800,0.8833942,4.306746,,,,,,,,,,,,,, -23807,,,0.6229491829872131,1.856359481811524,0.5753600001335144,2.064666986465454,50000.0,0.4628000259399414,2.6657001972198486,10000.0,10958.110257148744,12018.024418115616,10958.110257148744,1057.8305151462555,0.7141125202178955,0.0 -23900,0.91944295,4.1916246,,,,,,,,,,,,,, -24000,0.9207389,3.9084764,,,,,,,,,,,,,, -24100,0.761517,5.654176,,,,,,,,,,,,,, -24200,0.9210358,3.9787736,,,,,,,,,,,,,, -24300,0.9345751,3.8316982,,,,,,,,,,,,,, -24400,0.9057826,3.9318027,,,,,,,,,,,,,, -24500,1.0434071,3.89452,,,,,,,,,,,,,, -24600,1.0279374,3.788765,,,,,,,,,,,,,, -24700,0.97332776,3.9514108,,,,,,,,,,,,,, -24724,,,0.6285351514816284,1.7818245887756348,0.5766000151634216,2.016919612884521,50000.0,0.4601000249385834,2.632568836212158,10000.0,11378.051944494247,12479.46283340454,11378.051944494247,1099.2352216243744,0.7455592155456543,0.0 -24800,0.8546165,5.04655,,,,,,,,,,,,,, -24900,0.9959923,3.9275064,,,,,,,,,,,,,, -25000,0.8935413,4.1173353,,,,,,,,,,,,,, -25100,0.7023649,5.46241,,,,,,,,,,,,,, -25200,0.9724478,3.954536,,,,,,,,,,,,,, -25300,0.8984044,4.2710223,,,,,,,,,,,,,, -25400,1.0032248,3.9139733,,,,,,,,,,,,,, -25500,0.8517816,4.206404,,,,,,,,,,,,,, -25600,0.8082479,4.5686584,,,,,,,,,,,,,, -25642,,,0.6348242163658142,1.7657935619354248,0.5819199681282043,1.9933316707611084,50000.0,0.4608000218868255,2.615546226501465,10000.0,11798.264919519424,12941.145199537275,11798.264919519424,1140.624148607254,0.7733578681945801,0.0 -25700,0.91626114,3.8480422,,,,,,,,,,,,,, -25800,0.8407334,4.8154974,,,,,,,,,,,,,, -25900,0.8632256,4.662093,,,,,,,,,,,,,, -26000,0.91487634,3.8416846,,,,,,,,,,,,,, -26100,1.0100857,3.9149134,,,,,,,,,,,,,, -26200,0.8061771,5.388971,,,,,,,,,,,,,, -26300,0.9026892,4.011157,,,,,,,,,,,,,, -26400,0.9246925,3.9511113,,,,,,,,,,,,,, -26500,0.84472317,3.946595,,,,,,,,,,,,,, -26562,,,0.6374804377555847,1.7287397384643557,0.5893999934196472,1.9390709400177,50000.0,0.4737000167369842,2.547514677047729,10000.0,12218.406804323196,13396.518834590912,12218.406804323196,1175.7698872089386,0.8034751415252686,0.0 -26600,0.8791989,4.223457,,,,,,,,,,,,,, -26700,0.98813576,3.7349193,,,,,,,,,,,,,, -26800,0.95289314,3.9634233,,,,,,,,,,,,,, -26900,0.9514996,4.284561,,,,,,,,,,,,,, -27000,1.0732597,3.8391519,,,,,,,,,,,,,, -27100,0.91130376,3.7680626,,,,,,,,,,,,,, -27200,0.8062008,4.951329,,,,,,,,,,,,,, -27300,0.9342178,4.0968657,,,,,,,,,,,,,, -27400,1.0503197,3.8772364,,,,,,,,,,,,,, -27477,,,0.6408398151397705,1.7599726915359497,0.5889399647712708,1.9790270328521729,50000.0,0.4715000092983246,2.6012699604034424,10000.0,12638.708625555038,13852.880910873411,12638.708625555038,1211.7467403411863,0.8337299823760986,0.0 -27500,1.0027446,4.2805843,,,,,,,,,,,,,, -27600,0.9033756,4.1739197,,,,,,,,,,,,,, -27700,0.87270874,4.917639,,,,,,,,,,,,,, -27800,0.9551657,3.9700885,,,,,,,,,,,,,, -27900,1.0027555,3.8764038,,,,,,,,,,,,,, -28000,1.0098298,3.7779279,,,,,,,,,,,,,, -28100,1.0469549,3.9187288,,,,,,,,,,,,,, -28200,1.0222335,3.7528024,,,,,,,,,,,,,, -28300,1.0529442,3.7649145,,,,,,,,,,,,,, -28395,,,0.6732617020606995,1.6018322706222534,0.5973399877548218,1.9325991868972776,50000.0,0.4756000339984894,2.566092729568481,10000.0,13059.00931596756,14311.540762901306,13059.00931596756,1250.0245730876925,0.862741231918335,0.0 -28400,0.7702968,5.362602,,,,,,,,,,,,,, -28500,0.9933793,3.7481163,,,,,,,,,,,,,, -28600,0.79713047,5.3884478,,,,,,,,,,,,,, -28700,0.8120528,5.617921,,,,,,,,,,,,,, -28800,0.8105028,4.9263787,,,,,,,,,,,,,, -28900,0.92957187,3.7201717,,,,,,,,,,,,,, -29000,0.86523193,4.178105,,,,,,,,,,,,,, -29100,0.8353898,4.425715,,,,,,,,,,,,,, -29200,0.74752945,5.318571,,,,,,,,,,,,,, -29300,1.0456656,3.8288672,,,,,,,,,,,,,, -29312,,,0.6476367115974426,1.7281827926635742,0.5992599725723267,1.946011662483216,50000.0,0.4768000245094299,2.5661449432373047,10000.0,13479.215354442596,14772.042618513107,13479.215354442596,1290.234845161438,0.8960211277008057,0.0 -29400,0.9046203,4.0266767,,,,,,,,,,,,,, -29500,0.8795575,4.950678,,,,,,,,,,,,,, -29600,1.0321153,3.7588763,,,,,,,,,,,,,, -29700,0.9181044,3.8270369,,,,,,,,,,,,,, -29800,1.054212,3.8061018,,,,,,,,,,,,,, -29900,0.8610804,4.881363,,,,,,,,,,,,,, -30000,1.0498875,4.0980577,,,,,,,,,,,,,, -30100,0.94358075,3.8717623,,,,,,,,,,,,,, -30200,0.8470659,4.2767997,,,,,,,,,,,,,, -30230,,,0.6563280820846558,1.6538163423538208,0.6067399978637695,1.877922296524048,50000.0,0.4821000099182129,2.5124006271362305,10000.0,13899.440378904344,15233.832745790482,13899.440378904344,1331.7159614562988,0.9272348880767822,0.0 -30300,0.8096651,5.48256,,,,,,,,,,,,,, -30400,1.0943118,3.7175825,,,,,,,,,,,,,, -30500,0.8213421,5.5419717,,,,,,,,,,,,,, -30600,0.98282206,3.7915215,,,,,,,,,,,,,, -30700,0.82679725,4.966925,,,,,,,,,,,,,, -30800,1.0262125,3.7943547,,,,,,,,,,,,,, -30900,0.88606983,4.2331758,,,,,,,,,,,,,, -31000,0.8021998,5.2165766,,,,,,,,,,,,,, -31100,0.86614347,4.094508,,,,,,,,,,,,,, -31148,,,0.6665819883346558,1.621629238128662,0.6076200008392334,1.8831592798233032,50000.0,0.4859000146389007,2.506326675415039,10000.0,14319.611972808838,15689.16557955742,14319.611972808838,1366.7929692268372,0.9575145244598388,0.0 -31200,0.9532481,3.663616,,,,,,,,,,,,,, -31300,0.9100233,5.666293,,,,,,,,,,,,,, -31400,0.96718735,4.180649,,,,,,,,,,,,,, -31500,0.8473484,3.928837,,,,,,,,,,,,,, -31600,0.9211078,3.73946,,,,,,,,,,,,,, -31700,0.76606816,4.883617,,,,,,,,,,,,,, -31800,1.010112,4.0553956,,,,,,,,,,,,,, -31900,0.9020924,4.1462307,,,,,,,,,,,,,, -32000,1.1217474,3.8111167,,,,,,,,,,,,,, -32064,,,0.6660351157188416,1.5995913743972778,0.616919994354248,1.8113443851470947,50000.0,0.4937000274658203,2.442761182785034,10000.0,14739.628532409668,16148.70953464508,14739.628532409668,1406.235392332077,0.989626169204712,0.0 -32100,0.8566395,4.373824,,,,,,,,,,,,,, -32200,1.1754119,3.7744334,,,,,,,,,,,,,, -32300,0.99658835,3.687417,,,,,,,,,,,,,, -32400,0.8976452,4.286884,,,,,,,,,,,,,, -32500,0.9847798,3.658011,,,,,,,,,,,,,, -32600,0.86936516,4.0876436,,,,,,,,,,,,,, -32700,0.85920286,4.5463486,,,,,,,,,,,,,, -32800,0.810028,5.3808384,,,,,,,,,,,,,, -32900,0.8347292,5.0351067,,,,,,,,,,,,,, -32979,,,0.66064453125,1.6689999103546145,0.6136199831962585,1.892295241355896,50000.0,0.489300012588501,2.53031587600708,10000.0,15159.826711416245,16611.425479650497,15159.826711416245,1448.665411233902,1.024622678756714,0.0 -33000,1.0607302,3.7124896,,,,,,,,,,,,,, -33100,0.9783141,3.8479738,,,,,,,,,,,,,, -33200,0.95251596,4.0858755,,,,,,,,,,,,,, -33300,0.89460355,5.6157866,,,,,,,,,,,,,, -33400,0.81375664,5.0732036,,,,,,,,,,,,,, -33500,1.0356131,3.7242095,,,,,,,,,,,,,, -33600,0.84338915,5.38762,,,,,,,,,,,,,, -33700,0.9971429,3.744428,,,,,,,,,,,,,, -33800,0.8660893,4.4191957,,,,,,,,,,,,,, -33898,,,0.6800194978713989,1.5640877485275269,0.6166399717330933,1.8367356061935425,50000.0,0.4989000260829925,2.4557673931121826,10000.0,15579.860214233398,17073.8019824028,15579.860214233398,1490.9241652488708,1.0565321445465088,0.0 -33900,0.86313385,4.8465514,,,,,,,,,,,,,, -34000,0.8233644,4.9562054,,,,,,,,,,,,,, -34100,0.99609375,3.7009192,,,,,,,,,,,,,, -34200,0.92040396,4.083099,,,,,,,,,,,,,, -34300,0.99526703,3.7165358,,,,,,,,,,,,,, -34400,1.0496981,3.7476814,,,,,,,,,,,,,, -34500,0.94918257,3.7301662,,,,,,,,,,,,,, -34600,0.85787916,4.7985415,,,,,,,,,,,,,, -34700,0.826766,5.5641947,,,,,,,,,,,,,, -34800,0.9710895,3.7842438,,,,,,,,,,,,,, -34815,,,0.6681445240974426,1.6435627937316897,0.6217399835586548,1.854430675506592,50000.0,0.498600035905838,2.475801467895508,10000.0,16000.001391410828,17533.330276489258,16000.001391410828,1530.2285268306732,1.0866823196411133,0.0 -34900,0.92203104,4.425086,,,,,,,,,,,,,, -35000,0.8131949,4.7953734,,,,,,,,,,,,,, -35100,0.92560935,3.7247994,,,,,,,,,,,,,, -35200,0.9759463,3.7870183,,,,,,,,,,,,,, -35300,0.91347414,4.0505047,,,,,,,,,,,,,, -35400,0.9715446,4.00423,,,,,,,,,,,,,, -35500,0.9875375,4.5087485,,,,,,,,,,,,,, -35600,0.8803153,4.803266,,,,,,,,,,,,,, -35700,1.0371999,3.6604786,,,,,,,,,,,,,, -35733,,,0.6791015267372131,1.499345064163208,0.6260600090026855,1.74126398563385,50000.0,0.5069000124931335,2.359993934631348,10000.0,16420.151757001877,17996.025110006332,16420.151757001877,1572.68993806839,1.116673469543457,0.0 -35800,0.99329257,3.8213944,,,,,,,,,,,,,, -35900,0.95982516,4.5691833,,,,,,,,,,,,,, -36000,1.0135249,3.7775414,,,,,,,,,,,,,, -36100,0.88239527,5.4769287,,,,,,,,,,,,,, -36200,0.98076963,3.7943978,,,,,,,,,,,,,, -36300,1.0017551,3.5890148,,,,,,,,,,,,,, -36400,1.0557253,3.7255921,,,,,,,,,,,,,, -36500,0.88315165,4.3300424,,,,,,,,,,,,,, -36600,0.9588125,3.6345334,,,,,,,,,,,,,, -36651,,,0.6833788752555847,1.5450021028518677,0.6283800005912781,1.795507788658142,50000.0,0.5042999982833862,2.42209792137146,10000.0,16840.42435503006,18456.86468553543,16840.42435503006,1613.168706893921,1.1520779132843018,0.0 -36700,0.93677247,3.9911757,,,,,,,,,,,,,, -36800,1.0379497,3.7913547,,,,,,,,,,,,,, -36900,1.0251763,3.6964135,,,,,,,,,,,,,, -37000,0.9225961,5.150618,,,,,,,,,,,,,, -37100,0.8900138,4.9292727,,,,,,,,,,,,,, -37200,1.0372908,3.948456,,,,,,,,,,,,,, -37300,0.9678074,3.5924678,,,,,,,,,,,,,, -37400,1.0160211,3.686813,,,,,,,,,,,,,, -37500,0.98755085,4.1659455,,,,,,,,,,,,,, -37570,,,0.6815234422683716,1.5546667575836182,0.6279199719429016,1.7888097763061523,50000.0,0.5034000277519226,2.409806251525879,10000.0,17260.590923786163,18919.72491168976,17260.590923786163,1655.7754156589508,1.1850988864898682,0.0 -37600,0.9950062,3.6772254,,,,,,,,,,,,,, -37700,1.0190758,3.663836,,,,,,,,,,,,,, -37800,1.0188732,5.234878,,,,,,,,,,,,,, -37900,0.96118075,3.676421,,,,,,,,,,,,,, -38000,1.0300267,3.6813931,,,,,,,,,,,,,, -38100,1.0948431,3.5750244,,,,,,,,,,,,,, -38200,1.0726705,3.628455,,,,,,,,,,,,,, -38300,0.8335265,5.0328617,,,,,,,,,,,,,, -38400,0.9313028,3.7866812,,,,,,,,,,,,,, -38491,,,0.6863867044448853,1.5237195491790771,0.634660005569458,1.7491233348846436,50000.0,0.5108000040054321,2.3639705181121826,10000.0,17680.91156888008,19382.169947624207,17680.91156888008,1697.8138728141785,1.2178847789764404,0.0 -38500,1.021221,3.6147792,,,,,,,,,,,,,, -38600,1.0080527,3.8614645,,,,,,,,,,,,,, -38700,0.9117763,4.012222,,,,,,,,,,,,,, -38800,1.0127932,3.7739577,,,,,,,,,,,,,, -38900,0.90009433,5.4023886,,,,,,,,,,,,,, -39000,0.93414557,4.9551497,,,,,,,,,,,,,, -39100,1.0640831,3.7700315,,,,,,,,,,,,,, -39200,1.0234076,3.9796066,,,,,,,,,,,,,, -39300,1.0378078,3.6960866,,,,,,,,,,,,,, -39400,0.914532,4.267885,,,,,,,,,,,,,, -39410,,,0.6941015720367432,1.4616892337799072,0.6366599798202515,1.711472749710083,50000.0,0.5131000280380249,2.3418853282928467,10000.0,18101.19814991951,19842.69286584854,18101.19814991951,1737.9596025943756,1.2555735111236572,0.0 -39500,1.0919006,3.7642095,,,,,,,,,,,,,, -39600,1.0135524,3.5764241,,,,,,,,,,,,,, -39700,1.0213331,3.5653553,,,,,,,,,,,,,, -39800,0.9288753,4.1311274,,,,,,,,,,,,,, -39900,1.0614614,3.6366305,,,,,,,,,,,,,, -40000,0.9632009,3.8432598,,,,,,,,,,,,,, -40100,1.0322155,3.626994,,,,,,,,,,,,,, -40200,1.0631388,3.6409109,,,,,,,,,,,,,, -40300,0.9869614,3.7906055,,,,,,,,,,,,,, -40329,,,0.7090820074081421,1.4305404424667358,0.6382399797439575,1.7424010038375854,50000.0,0.5110000371932983,2.36631178855896,10000.0,18521.3622546196,20305.8983001709,18521.3622546196,1780.9176177978516,1.2859671115875244,0.0 -40400,1.0293456,3.6102068,,,,,,,,,,,,,, -40500,0.89331657,4.082416,,,,,,,,,,,,,, -40600,0.85342705,5.409062,,,,,,,,,,,,,, -40700,1.0320836,3.677057,,,,,,,,,,,,,, -40800,0.89723706,5.182522,,,,,,,,,,,,,, -40900,0.94056225,4.131309,,,,,,,,,,,,,, -41000,0.8834146,4.457677,,,,,,,,,,,,,, -41100,1.0106298,5.257336,,,,,,,,,,,,,, -41200,1.0398743,3.6349878,,,,,,,,,,,,,, -41246,,,0.6861523389816284,1.5130900144577026,0.6380400061607361,1.7185349464416504,50000.0,0.511400043964386,2.348185062408448,10000.0,18941.50358247757,20760.38983654976,18941.50358247757,1815.185781955719,1.3155813217163086,0.0 -41300,1.0281575,3.6287313,,,,,,,,,,,,,, -41400,0.9226202,5.363484,,,,,,,,,,,,,, -41500,0.9618579,4.2867713,,,,,,,,,,,,,, -41600,1.103941,3.6661224,,,,,,,,,,,,,, -41700,0.85575837,4.4844623,,,,,,,,,,,,,, -41800,1.0778097,3.6958804,,,,,,,,,,,,,, -41900,1.0905876,3.8229337,,,,,,,,,,,,,, -42000,1.0438377,3.6252155,,,,,,,,,,,,,, -42100,0.9919505,4.257034,,,,,,,,,,,,,, -42165,,,0.6976562142372131,1.452872633934021,0.6402599811553955,1.701892971992493,50000.0,0.5162000060081482,2.3191921710968018,10000.0,19361.500798225403,21222.585634231567,19361.500798225403,1857.3007380962367,1.3462746143341064,0.0 -42200,1.0861139,3.6813087,,,,,,,,,,,,,, -42300,0.9712358,4.6740055,,,,,,,,,,,,,, -42400,0.9702077,5.224998,,,,,,,,,,,,,, -42500,1.0105026,4.7291236,,,,,,,,,,,,,, -42600,1.0179522,3.618339,,,,,,,,,,,,,, -42700,1.1121786,3.9186122,,,,,,,,,,,,,, -42800,1.036385,3.6029403,,,,,,,,,,,,,, -42900,0.91182524,5.1930175,,,,,,,,,,,,,, -43000,0.87253773,5.4159937,,,,,,,,,,,,,, -43084,,,0.7085351347923279,1.3984525203704834,0.6442199945449829,1.6737351417541504,50000.0,0.5215000510215759,2.2900736331939697,10000.0,19781.848207473755,21681.09438085556,19781.848207473755,1895.374864578247,1.380638599395752,0.0 -43100,0.90824455,4.1870465,,,,,,,,,,,,,, -43200,1.0262983,4.224984,,,,,,,,,,,,,, -43300,1.0179745,4.896377,,,,,,,,,,,,,, -43400,1.0156926,3.5820694,,,,,,,,,,,,,, -43500,1.105293,3.587769,,,,,,,,,,,,,, -43600,1.0179511,3.5447702,,,,,,,,,,,,,, -43700,1.0386664,3.5760074,,,,,,,,,,,,,, -43800,0.99329925,3.577623,,,,,,,,,,,,,, -43900,0.946611,5.224964,,,,,,,,,,,,,, -44000,,,0.6931054592132568,1.5006791353225708,0.6433599591255188,1.717563271522522,50000.0,0.5199000239372253,2.348926305770874,10000.0,20201.83907580376,22140.8764231205,20201.83907580376,1935.0757467746728,1.41862154006958,0.0 -44000,1.1352296,3.609577,,,,,,,,,,,,,, -44100,0.9895463,3.764942,,,,,,,,,,,,,, -44200,1.0326979,3.5709324,,,,,,,,,,,,,, -44300,1.2410064,3.6798236,,,,,,,,,,,,,, -44400,0.97380364,3.6660337,,,,,,,,,,,,,, -44500,1.0743496,4.9057083,,,,,,,,,,,,,, -44600,1.0283588,4.6615014,,,,,,,,,,,,,, -44700,1.071624,3.6517165,,,,,,,,,,,,,, -44800,0.93606645,3.8124645,,,,,,,,,,,,,, -44900,1.0945266,3.6568532,,,,,,,,,,,,,, -44915,,,0.7015234231948853,1.447856068611145,0.6456999778747559,1.6839470863342283,50000.0,0.5258000493049622,2.2985410690307617,10000.0,20621.8154001236,22603.06175518036,20621.8154001236,1977.193481206894,1.4574127197265625,0.0 -45000,1.1315455,3.6884084,,,,,,,,,,,,,, -45100,0.90678746,4.7870684,,,,,,,,,,,,,, -45200,1.0793451,3.64847,,,,,,,,,,,,,, -45300,1.0081359,3.6248431,,,,,,,,,,,,,, -45400,0.8918159,4.290958,,,,,,,,,,,,,, -45500,0.9667163,4.4311953,,,,,,,,,,,,,, -45600,1.2344873,3.5072057,,,,,,,,,,,,,, -45700,1.0009282,3.618683,,,,,,,,,,,,,, -45800,1.0376127,3.590316,,,,,,,,,,,,,, -45835,,,0.7068163752555847,1.426822304725647,0.6466799974441528,1.6960935592651367,50000.0,0.5192000269889832,2.3293352127075195,10000.0,21042.14222931862,23062.734843730927,21042.14222931862,2016.4534318447115,1.4908981323242188,0.0 -45900,0.9012858,5.180382,,,,,,,,,,,,,, -46000,0.9972616,3.7290711,,,,,,,,,,,,,, -46100,1.0306327,3.6840441,,,,,,,,,,,,,, -46200,1.0228683,4.1605244,,,,,,,,,,,,,, -46300,1.0094097,3.7442122,,,,,,,,,,,,,, -46400,1.0632563,3.7606277,,,,,,,,,,,,,, -46500,1.0477718,3.526128,,,,,,,,,,,,,, -46600,1.2162464,3.5675898,,,,,,,,,,,,,, -46700,0.98876786,4.164054,,,,,,,,,,,,,, -46755,,,0.706250011920929,1.3982168436050415,0.6541199684143066,1.629255294799805,50000.0,0.5317000150680542,2.2454986572265625,10000.0,21462.42486310005,23520.222547769547,21462.42486310005,2053.571048259735,1.5256400108337402,0.0 -46800,1.0089024,3.5781364,,,,,,,,,,,,,, -46900,1.006702,4.103663,,,,,,,,,,,,,, -47000,0.9129489,4.829139,,,,,,,,,,,,,, -47100,0.9370305,4.30798,,,,,,,,,,,,,, -47200,1.0942938,3.5300586,,,,,,,,,,,,,, -47300,0.9920619,3.6383173,,,,,,,,,,,,,, -47400,1.0217601,3.5245614,,,,,,,,,,,,,, -47500,1.0554334,3.530191,,,,,,,,,,,,,, -47600,0.9823994,3.9879885,,,,,,,,,,,,,, -47674,,,0.7119921445846558,1.3972750902175903,0.6566799879074097,1.6393754482269287,50000.0,0.5311000347137451,2.2582650184631348,10000.0,21882.692858219147,23982.11646294593,21882.692858219147,2095.10959148407,1.5598258972167969,0.0 -47700,0.9520664,4.269885,,,,,,,,,,,,,, -47800,1.029373,3.6365647,,,,,,,,,,,,,, -47900,1.0397962,3.5847714,,,,,,,,,,,,,, -48000,0.9489725,3.8265052,,,,,,,,,,,,,, -48100,1.1080618,3.6179652,,,,,,,,,,,,,, -48200,1.0059544,3.6698205,,,,,,,,,,,,,, -48300,1.1768794,3.572732,,,,,,,,,,,,,, -48400,1.036825,3.602577,,,,,,,,,,,,,, -48500,1.0237658,5.0804076,,,,,,,,,,,,,, -48595,,,0.7161718606948853,1.3505080938339231,0.6545000076293945,1.6139466762542725,50000.0,0.5323000550270081,2.2515869140625,10000.0,22303.524026870728,24444.62548708916,22303.524026870728,2136.696215391159,1.5978095531463623,0.0 -48600,1.0311798,4.1108103,,,,,,,,,,,,,, -48700,0.9637624,4.854989,,,,,,,,,,,,,, -48800,1.1940761,3.5790417,,,,,,,,,,,,,, -48900,1.0901203,3.614222,,,,,,,,,,,,,, -49000,0.96465576,5.313847,,,,,,,,,,,,,, -49100,1.057635,3.5946822,,,,,,,,,,,,,, -49200,1.1532234,3.5825498,,,,,,,,,,,,,, -49300,0.90998924,4.4304824,,,,,,,,,,,,,, -49400,0.9980964,3.6128383,,,,,,,,,,,,,, -49500,1.0537913,3.5744977,,,,,,,,,,,,,, -49513,,,0.7284765243530273,1.366170048713684,0.6598399877548218,1.6686944961547852,50000.0,0.5327000021934509,2.3007359504699707,10000.0,22723.87028694153,24901.43813586235,22723.87028694153,2173.0738096237183,1.6330816745758057,0.0 -49600,0.974492,3.7248065,,,,,,,,,,,,,, -49700,1.189547,3.515381,,,,,,,,,,,,,, -49800,1.1027411,3.6792362,,,,,,,,,,,,,, -49900,1.0662596,5.31802,,,,,,,,,,,,,, -50000,1.0695534,5.2467017,,,,,,,,,,,,,, -50100,1.0844305,3.5950341,,,,,,,,,,,,,, -50200,1.0197116,3.615178,,,,,,,,,,,,,, -50300,1.0676913,3.5635211,,,,,,,,,,,,,, -50400,0.9170226,4.9847345,,,,,,,,,,,,,, -50430,,,0.7109179496765137,1.366124987602234,0.6569799780845642,1.603185772895813,50000.0,0.5389000177383423,2.222001552581787,10000.0,23143.849014759064,25363.47724819184,23143.849014759064,2215.0406877994537,1.67065167427063,0.0 -50500,1.0310621,3.5431352,,,,,,,,,,,,,, -50600,1.1520135,3.5239978,,,,,,,,,,,,,, -50700,0.95333225,4.1046457,,,,,,,,,,,,,, -50800,0.88836855,4.549981,,,,,,,,,,,,,, -50900,1.1046991,3.6378877,,,,,,,,,,,,,, -51000,1.1229587,3.6001554,,,,,,,,,,,,,, -51100,0.9993393,5.381174,,,,,,,,,,,,,, -51200,1.1057837,3.5216153,,,,,,,,,,,,,, -51300,1.0908955,3.641657,,,,,,,,,,,,,, -51346,,,0.7173827886581421,1.383202314376831,0.6582199931144714,1.6434626579284668,50000.0,0.5329000353813171,2.27443790435791,10000.0,23563.82791399956,25825.90231370926,23563.82791399956,2257.400363683701,1.7045204639434814,0.0 -51400,1.0764471,3.5122159,,,,,,,,,,,,,, -51500,1.0266836,3.6406577,,,,,,,,,,,,,, -51600,1.0276229,3.8338263,,,,,,,,,,,,,, -51700,1.0921855,3.635588,,,,,,,,,,,,,, -51800,1.0114012,5.42583,,,,,,,,,,,,,, -51900,1.0896572,3.8666644,,,,,,,,,,,,,, -52000,1.0248586,4.084232,,,,,,,,,,,,,, -52100,1.0758817,3.4582396,,,,,,,,,,,,,, -52200,1.097401,3.5247347,,,,,,,,,,,,,, -52265,,,0.7373827695846558,1.3297600746154783,0.6584999561309814,1.656172513961792,50000.0,0.5343000292778015,2.2770166397094727,10000.0,23984.125440120697,26285.075368642807,23984.125440120697,2296.1784982681274,1.7483222484588623,0.0 -52300,0.95229405,4.277356,,,,,,,,,,,,,, -52400,1.0748568,3.5829954,,,,,,,,,,,,,, -52500,1.0340931,3.500022,,,,,,,,,,,,,, -52600,1.1048626,5.2315307,,,,,,,,,,,,,, -52700,1.06175,5.286289,,,,,,,,,,,,,, -52800,1.0638549,3.615297,,,,,,,,,,,,,, -52900,0.91378385,3.9962575,,,,,,,,,,,,,, -53000,1.1688063,3.5461993,,,,,,,,,,,,,, -53100,1.1048464,3.57942,,,,,,,,,,,,,, -53181,,,0.7142968773841858,1.3528928756713867,0.663599967956543,1.5829827785491943,50000.0,0.5390000343322754,2.2040934562683105,10000.0,24404.181349277496,26747.570974826813,24404.181349277496,2338.523509502411,1.7898108959197998,0.0 -53200,1.2223247,3.5823224,,,,,,,,,,,,,, -53300,1.104796,3.522435,,,,,,,,,,,,,, -53400,1.0661662,3.492004,,,,,,,,,,,,,, -53500,1.0251052,3.5385706,,,,,,,,,,,,,, -53600,1.0315448,3.6006064,,,,,,,,,,,,,, -53700,1.1082039,3.4741902,,,,,,,,,,,,,, -53800,1.1228083,3.5534935,,,,,,,,,,,,,, -53900,1.0723687,3.5985024,,,,,,,,,,,,,, -54000,1.01939,4.165265,,,,,,,,,,,,,, -54095,,,0.7222851514816284,1.373044729232788,0.6613799929618835,1.6264369487762451,50000.0,0.5357000231742859,2.2335379123687744,10000.0,24824.275985479355,27205.53682422638,24824.275985479355,2376.301381111145,1.831099271774292,0.0 -54100,0.9068815,4.276989,,,,,,,,,,,,,, -54200,1.1212169,3.5931141,,,,,,,,,,,,,, -54300,1.0523872,3.8142128,,,,,,,,,,,,,, -54400,1.1796069,3.6616244,,,,,,,,,,,,,, -54500,1.0606539,3.43699,,,,,,,,,,,,,, -54600,1.0013815,4.138448,,,,,,,,,,,,,, -54700,1.1463568,3.5084906,,,,,,,,,,,,,, -54800,1.0908037,3.583492,,,,,,,,,,,,,, -54900,1.0303074,4.2984133,,,,,,,,,,,,,, -55000,1.0939664,3.5162363,,,,,,,,,,,,,, -55013,,,0.7419335842132568,1.2533196210861206,0.6665599942207336,1.5780009031295776,50000.0,0.5479000210762024,2.186769962310791,10000.0,25244.52186369896,27664.27063536644,25244.52186369896,2414.699910640717,1.867379903793335,0.0 -55100,1.0862305,5.14213,,,,,,,,,,,,,, -55200,1.0387899,3.9752922,,,,,,,,,,,,,, -55300,1.1049589,5.1626754,,,,,,,,,,,,,, -55400,1.0739747,3.5252957,,,,,,,,,,,,,, -55500,1.0176967,4.2586236,,,,,,,,,,,,,, -55600,0.98875546,3.8489935,,,,,,,,,,,,,, -55700,1.252943,3.6254628,,,,,,,,,,,,,, -55800,1.1268796,3.47919,,,,,,,,,,,,,, -55900,1.2076116,3.5192502,,,,,,,,,,,,,, -55931,,,0.72083979845047,1.3554744720458984,0.6647799611091614,1.584027886390686,50000.0,0.5414000153541565,2.191251754760742,10000.0,25664.90718483925,28127.09863138199,25664.90718483925,2457.052332401276,1.9043676853179927,0.0 -56000,0.9718603,4.02561,,,,,,,,,,,,,, -56100,1.0425831,3.6850715,,,,,,,,,,,,,, -56200,1.0546345,3.7557309,,,,,,,,,,,,,, -56300,1.1026598,3.6131098,,,,,,,,,,,,,, -56400,1.2182666,3.5026956,,,,,,,,,,,,,, -56500,1.1297877,3.513445,,,,,,,,,,,,,, -56600,0.94661385,4.7570734,,,,,,,,,,,,,, -56700,1.0477127,3.8165157,,,,,,,,,,,,,, -56800,1.1039366,3.966373,,,,,,,,,,,,,, -56846,,,0.7316796779632568,1.2941749095916748,0.6717199683189392,1.55128014087677,50000.0,0.5484000444412231,2.171085357666016,10000.0,26085.24145579338,28589.73458695412,26085.24145579338,2499.262162208557,1.9429833889007568,0.0 -56900,1.1196674,5.1313453,,,,,,,,,,,,,, -57000,1.045809,3.5040736,,,,,,,,,,,,,, -57100,1.0098947,4.44922,,,,,,,,,,,,,, -57200,1.0717163,3.5196571,,,,,,,,,,,,,, -57300,1.1562344,3.604778,,,,,,,,,,,,,, -57400,1.0702277,3.448562,,,,,,,,,,,,,, -57500,1.0140045,3.5538135,,,,,,,,,,,,,, -57600,1.0846146,5.086258,,,,,,,,,,,,,, -57700,1.226309,3.506034,,,,,,,,,,,,,, -57762,,,0.7390429377555847,1.2824914455413818,0.6715999841690063,1.5749740600585938,50000.0,0.5490000247955322,2.189195394515991,10000.0,26505.591643333435,29049.86763215065,26505.591643333435,2538.9583842754364,1.9764173030853271,0.0 -57800,1.101281,3.5246177,,,,,,,,,,,,,, -57900,1.013892,4.642469,,,,,,,,,,,,,, -58000,1.1071851,4.5488353,,,,,,,,,,,,,, -58100,1.096918,3.559443,,,,,,,,,,,,,, -58200,1.2724698,3.587545,,,,,,,,,,,,,, -58300,1.159674,3.4703715,,,,,,,,,,,,,, -58400,0.98623043,4.415076,,,,,,,,,,,,,, -58500,1.05953,3.4924018,,,,,,,,,,,,,, -58600,1.1822499,3.5600102,,,,,,,,,,,,,, -58678,,,0.7281835675239563,1.3011599779129028,0.6716399788856506,1.546015381813049,50000.0,0.5476000308990479,2.1589043140411377,10000.0,26925.87080693245,29514.428884983063,26925.87080693245,2583.152119636536,2.011991500854492,0.0 -58700,0.9876223,4.4117217,,,,,,,,,,,,,, -58800,0.98741573,3.7388008,,,,,,,,,,,,,, -58900,1.1758238,3.4506001,,,,,,,,,,,,,, -59000,1.0425711,4.432234,,,,,,,,,,,,,, -59100,0.9923078,4.8375754,,,,,,,,,,,,,, -59200,1.039816,3.3794293,,,,,,,,,,,,,, -59300,1.0668057,5.3343434,,,,,,,,,,,,,, -59400,1.2594563,3.3401282,,,,,,,,,,,,,, -59500,1.0772343,3.4684222,,,,,,,,,,,,,, -59597,,,0.7227734327316284,1.3543765544891355,0.6686399579048157,1.5894092321395874,50000.0,0.5433000326156616,2.194332361221313,10000.0,27346.01905035973,29972.257979154587,27346.01905035973,2620.7382864952087,2.05407190322876,0.0 -59600,1.0175332,5.1101255,,,,,,,,,,,,,, -59700,1.0731531,3.4154425,,,,,,,,,,,,,, -59800,1.0548131,3.8634021,,,,,,,,,,,,,, -59900,1.0164003,4.484619,,,,,,,,,,,,,, -60000,1.0968148,3.4994907,,,,,,,,,,,,,, -60100,1.1461023,3.4976234,,,,,,,,,,,,,, -60200,1.1886531,3.4851546,,,,,,,,,,,,,, -60300,1.1065637,3.4602785,,,,,,,,,,,,,, -60400,1.02093,4.667905,,,,,,,,,,,,,, -60500,1.0759732,3.3904328,,,,,,,,,,,,,, -60512,,,0.7395703196525574,1.2908554077148438,0.6735599637031555,1.5688538551330566,50000.0,0.5527999997138977,2.1810519695281982,10000.0,27766.20599770546,30435.395832777023,27766.20599770546,2663.5940973758698,2.096351623535156,0.0 -60600,1.1103466,5.1831074,,,,,,,,,,,,,, -60700,1.0535067,4.9894404,,,,,,,,,,,,,, -60800,0.96450275,4.4202437,,,,,,,,,,,,,, -60900,1.0680753,3.5149052,,,,,,,,,,,,,, -61000,1.361683,3.4668295,,,,,,,,,,,,,, -61100,1.1222694,3.3468294,,,,,,,,,,,,,, -61200,1.0532309,3.3861284,,,,,,,,,,,,,, -61300,1.1126504,3.4132442,,,,,,,,,,,,,, -61400,1.1679969,5.1823263,,,,,,,,,,,,,, -61429,,,0.7494531273841858,1.2534152269363403,0.6738799810409546,1.5649051666259766,50000.0,0.5517000555992126,2.1733086109161377,10000.0,28186.2903354168,30897.46825361252,28186.2903354168,2705.494818210602,2.1315011978149414,0.0 -61500,1.0978558,5.0084767,,,,,,,,,,,,,, -61600,1.2323608,3.4001746,,,,,,,,,,,,,, -61700,1.0864115,5.1056833,,,,,,,,,,,,,, -61800,1.0967984,3.7393465,,,,,,,,,,,,,, -61900,0.99643,4.4462943,,,,,,,,,,,,,, -62000,1.1539351,3.4269202,,,,,,,,,,,,,, -62100,1.1314309,3.4150186,,,,,,,,,,,,,, -62200,1.0035686,3.7736118,,,,,,,,,,,,,, -62300,0.9996284,4.4833703,,,,,,,,,,,,,, -62343,,,0.7285351157188416,1.318561315536499,0.6755599975585938,1.5545648336410522,50000.0,0.5503000020980835,2.170751571655273,10000.0,28606.247513771057,31358.454117536545,28606.247513771057,2746.435801267624,2.166461229324341,0.0 -62400,1.0881945,3.8718653,,,,,,,,,,,,,, -62500,1.0433434,4.192466,,,,,,,,,,,,,, -62600,1.1611875,3.3824427,,,,,,,,,,,,,, -62700,1.0922489,5.1197443,,,,,,,,,,,,,, -62800,1.1243067,3.848599,,,,,,,,,,,,,, -62900,1.1772761,3.4630673,,,,,,,,,,,,,, -63000,1.1190431,3.5166461,,,,,,,,,,,,,, -63100,1.0554669,3.4882338,,,,,,,,,,,,,, -63200,1.2121084,3.4448156,,,,,,,,,,,,,, -63259,,,0.7363085746765137,1.2871750593185425,0.6755399703979492,1.5438107252120972,50000.0,0.5529000163078308,2.176795959472656,10000.0,29026.210033893585,31816.76087665558,29026.210033893585,2784.688705921173,2.2054715156555176,0.0 -63300,1.0763956,4.013747,,,,,,,,,,,,,, -63400,1.138614,5.1312222,,,,,,,,,,,,,, -63500,1.0117557,4.44853,,,,,,,,,,,,,, -63600,1.1284655,3.4244065,,,,,,,,,,,,,, -63700,1.0123982,4.209993,,,,,,,,,,,,,, -63800,1.1778392,3.3879154,,,,,,,,,,,,,, -63900,1.1368577,4.6585298,,,,,,,,,,,,,, -64000,1.114587,3.358918,,,,,,,,,,,,,, -64100,1.1153549,3.6072407,,,,,,,,,,,,,, -64175,,,0.752734363079071,1.2191165685653689,0.6774399876594543,1.541650414466858,50000.0,0.5590000152587891,2.1525299549102783,10000.0,29446.1415143013,32276.34983420372,29446.1415143013,2824.250700235367,2.2430827617645264,0.0 -64200,1.1731032,3.3960602,,,,,,,,,,,,,, -64300,1.2036324,3.4573894,,,,,,,,,,,,,, -64400,1.163325,3.3694313,,,,,,,,,,,,,, -64500,1.1346142,3.5012436,,,,,,,,,,,,,, -64600,1.1860877,3.5290034,,,,,,,,,,,,,, -64700,1.1913795,3.45431,,,,,,,,,,,,,, -64800,1.0715369,3.4076488,,,,,,,,,,,,,, -64900,1.0516697,4.0521965,,,,,,,,,,,,,, -65000,1.1623387,5.258745,,,,,,,,,,,,,, -65091,,,0.7371875047683716,1.289478063583374,0.6808599829673767,1.5334938764572144,50000.0,0.5505000352859497,2.155724048614502,10000.0,29866.106069087986,32733.75764989853,29866.106069087986,2861.604071855545,2.2804622650146484,0.0 -65100,1.0086759,4.5918455,,,,,,,,,,,,,, -65200,1.0369649,4.6969156,,,,,,,,,,,,,, -65300,1.0245159,3.7140415,,,,,,,,,,,,,, -65400,1.1787409,3.5144155,,,,,,,,,,,,,, -65500,1.169573,3.498598,,,,,,,,,,,,,, -65600,1.0688194,3.7343278,,,,,,,,,,,,,, -65700,1.1560637,3.6028671,,,,,,,,,,,,,, -65800,1.106808,3.5793447,,,,,,,,,,,,,, -65900,1.1407545,3.4077086,,,,,,,,,,,,,, -66000,1.1149573,3.5686235,,,,,,,,,,,,,, -66008,,,0.7390820384025574,1.268944501876831,0.6768999695777893,1.5356301069259644,50000.0,0.5539000034332275,2.1502127647399902,10000.0,30286.265585184097,33194.593849658966,30286.265585184097,2902.1909971237183,2.317537307739258,0.0 -66100,1.1290057,3.4908242,,,,,,,,,,,,,, -66200,1.0527861,4.240366,,,,,,,,,,,,,, -66300,1.1351978,3.5491374,,,,,,,,,,,,,, -66400,1.0309175,4.639072,,,,,,,,,,,,,, -66500,1.1356126,5.134892,,,,,,,,,,,,,, -66600,1.2323364,3.4578643,,,,,,,,,,,,,, -66700,1.371976,5.146758,,,,,,,,,,,,,, -66800,1.1416566,3.426393,,,,,,,,,,,,,, -66900,1.0699323,3.7238724,,,,,,,,,,,,,, -66927,,,0.7473242282867432,1.2253973484039309,0.6799399852752686,1.5155017375946045,50000.0,0.5534999966621399,2.161947011947632,10000.0,30706.27566933632,33656.18039536476,30706.27566933632,2943.6744248867035,2.357588052749634,0.0 -67000,1.030439,3.7147129,,,,,,,,,,,,,, -67100,1.0752463,4.558998,,,,,,,,,,,,,, -67200,1.0907646,5.163343,,,,,,,,,,,,,, -67300,1.1752512,3.4420896,,,,,,,,,,,,,, -67400,1.2319225,3.352807,,,,,,,,,,,,,, -67500,1.1490186,3.4015226,,,,,,,,,,,,,, -67600,1.0775808,4.2806997,,,,,,,,,,,,,, -67700,1.0774921,4.1251254,,,,,,,,,,,,,, -67800,1.0677608,4.6631155,,,,,,,,,,,,,, -67844,,,0.738964855670929,1.255308747291565,0.6803999543190002,1.512239933013916,50000.0,0.5626000165939331,2.114659786224365,10000.0,31126.626373529434,34118.59836268425,31126.626373529434,2985.6455442905426,2.400687456130981,0.0 -67900,1.2216349,3.4654248,,,,,,,,,,,,,, -68000,1.1742092,3.907474,,,,,,,,,,,,,, -68100,1.086264,3.3577545,,,,,,,,,,,,,, -68200,0.9750106,4.0496116,,,,,,,,,,,,,, -68300,1.2285591,3.417605,,,,,,,,,,,,,, -68400,1.1936663,3.4930599,,,,,,,,,,,,,, -68500,1.1607918,3.6341126,,,,,,,,,,,,,, -68600,1.0600858,3.7807925,,,,,,,,,,,,,, -68700,1.1141376,4.7437506,,,,,,,,,,,,,, -68761,,,0.7493945360183716,1.2292940616607666,0.6869199872016907,1.5031391382217407,50000.0,0.5574000477790833,2.1279420852661133,10000.0,31546.67461037636,34578.75641846657,31546.67461037636,3025.6633553504944,2.43937611579895,0.0 -68800,1.1135409,3.674754,,,,,,,,,,,,,, -68900,1.0659819,5.082845,,,,,,,,,,,,,, -69000,1.1851033,3.4608426,,,,,,,,,,,,,, -69100,1.0733768,4.14584,,,,,,,,,,,,,, -69200,1.2532756,3.389604,,,,,,,,,,,,,, -69300,1.0603427,3.4342165,,,,,,,,,,,,,, -69400,1.1531492,3.4700942,,,,,,,,,,,,,, -69500,1.1522101,3.457474,,,,,,,,,,,,,, -69600,1.1586809,5.087449,,,,,,,,,,,,,, -69678,,,0.7514257431030273,1.3177834749221802,0.6843000054359436,1.611196517944336,50000.0,0.5576000213623047,2.2235536575317383,10000.0,31966.91418480873,35040.54090499878,31966.91418480873,3067.120540380478,2.4745750427246094,0.0 -69700,1.114367,3.4165437,,,,,,,,,,,,,, -69800,1.1505182,3.3285203,,,,,,,,,,,,,, -69900,1.1296235,3.4851513,,,,,,,,,,,,,, -70000,1.1271955,3.3696127,,,,,,,,,,,,,, -70100,1.1109692,3.6086159,,,,,,,,,,,,,, -70200,1.0665025,3.824049,,,,,,,,,,,,,, -70300,1.1059089,4.821411,,,,,,,,,,,,,, -70400,1.0992482,3.4551933,,,,,,,,,,,,,, -70500,1.1300058,5.125837,,,,,,,,,,,,,, -70595,,,0.7463085651397705,1.2566375732421875,0.6869800090789795,1.5039660930633545,50000.0,0.5586000084877014,2.119065761566162,10000.0,32387.108137846,35501.06599497795,32387.108137846,3107.352053165436,2.5219199657440186,0.0 -70600,1.0327094,3.668246,,,,,,,,,,,,,, -70700,1.0628115,3.9353256,,,,,,,,,,,,,, -70800,1.1917466,3.4307907,,,,,,,,,,,,,, -70900,1.1546763,3.4967928,,,,,,,,,,,,,, -71000,1.0698007,3.9544916,,,,,,,,,,,,,, -71100,1.1988363,4.747985,,,,,,,,,,,,,, -71200,1.1604484,3.3794637,,,,,,,,,,,,,, -71300,1.1739366,3.462095,,,,,,,,,,,,,, -71400,1.2462586,3.5311134,,,,,,,,,,,,,, -71500,1.246388,3.4704971,,,,,,,,,,,,,, -71512,,,0.7489648461341858,1.1991535425186155,0.6866599917411804,1.4654242992401123,50000.0,0.5613000392913818,2.0869626998901367,10000.0,32807.47200012207,35963.4273827076,32807.47200012207,3149.2554540634155,2.5634772777557373,0.0 -71600,1.26531,5.160118,,,,,,,,,,,,,, -71700,1.0950063,3.5813134,,,,,,,,,,,,,, -71800,1.2582903,3.356848,,,,,,,,,,,,,, -71900,1.1781503,3.5484502,,,,,,,,,,,,,, -72000,1.1850755,3.3936722,,,,,,,,,,,,,, -72100,1.2567871,3.4469607,,,,,,,,,,,,,, -72200,1.2907927,3.3883116,,,,,,,,,,,,,, -72300,1.1433779,3.846797,,,,,,,,,,,,,, -72400,1.1352084,4.782506,,,,,,,,,,,,,, -72427,,,0.7582226395606995,1.187957525253296,0.6908599734306335,1.4750837087631226,50000.0,0.5659000277519226,2.090136766433716,10000.0,33227.58208632469,36423.51284694672,33227.58208632469,3189.1428916454315,2.59970760345459,0.0 -72500,1.1645893,3.366031,,,,,,,,,,,,,, -72600,1.1599181,3.4976022,,,,,,,,,,,,,, -72700,1.3161086,3.3135414,,,,,,,,,,,,,, -72800,1.2046776,3.7554126,,,,,,,,,,,,,, -72900,1.2011554,3.2362108,,,,,,,,,,,,,, -73000,1.1143812,3.5405629,,,,,,,,,,,,,, -73100,1.0423243,4.062601,,,,,,,,,,,,,, -73200,1.1307963,4.6538715,,,,,,,,,,,,,, -73300,1.2012715,5.0564375,,,,,,,,,,,,,, -73345,,,0.7671679258346558,1.1537946462631226,0.6909599900245667,1.4814988374710083,50000.0,0.5666000247001648,2.1033482551574707,10000.0,33647.756412267685,36887.83846831322,33647.756412267685,3233.202719926834,2.637953042984009,0.0 -73400,1.1461607,3.35173,,,,,,,,,,,,,, -73500,1.3018239,3.3265498,,,,,,,,,,,,,, -73600,1.2919008,3.5870578,,,,,,,,,,,,,, -73700,1.094167,3.5823073,,,,,,,,,,,,,, -73800,1.1901599,3.2571785,,,,,,,,,,,,,, -73900,1.3051089,3.3080647,,,,,,,,,,,,,, -74000,1.1342024,3.39547,,,,,,,,,,,,,, -74100,1.1879262,3.5442746,,,,,,,,,,,,,, -74200,1.1321571,4.192454,,,,,,,,,,,,,, -74262,,,0.750195324420929,1.2365492582321167,0.6885799765586853,1.5010799169540403,50000.0,0.5663000345230103,2.115422248840332,10000.0,34068.01615142822,37350.26576018333,34068.01615142822,3275.276128768921,2.678273916244507,0.0 -74300,1.1833954,3.7225266,,,,,,,,,,,,,, -74400,1.0849637,3.610068,,,,,,,,,,,,,, -74500,1.1648284,3.4616928,,,,,,,,,,,,,, -74600,1.1301126,3.4835763,,,,,,,,,,,,,, -74700,1.3237872,3.391251,,,,,,,,,,,,,, -74800,1.133955,3.9823408,,,,,,,,,,,,,, -74900,1.2947242,3.4015927,,,,,,,,,,,,,, -75000,1.2894248,3.3644485,,,,,,,,,,,,,, -75100,1.3541493,3.718985,,,,,,,,,,,,,, -75178,,,0.7569140195846558,1.1788876056671145,0.6946199536323547,1.4552501440048218,50000.0,0.5640000104904175,2.0848517417907715,10000.0,34488.080597639084,37814.71540975571,34488.080597639084,3319.566258430481,2.720487594604492,0.0 -75200,1.2463254,3.5278547,,,,,,,,,,,,,, -75300,1.0661663,4.01947,,,,,,,,,,,,,, -75400,1.2449203,4.9589114,,,,,,,,,,,,,, -75500,1.0912068,3.6357238,,,,,,,,,,,,,, -75600,1.2218858,4.5743766,,,,,,,,,,,,,, -75700,1.1701938,3.353264,,,,,,,,,,,,,, -75800,1.2509112,3.2990074,,,,,,,,,,,,,, -75900,1.1206118,4.1432066,,,,,,,,,,,,,, -76000,1.1708895,3.392409,,,,,,,,,,,,,, -76094,,,0.7713086009025574,1.153877019882202,0.6943599581718445,1.489396095275879,50000.0,0.5666000247001648,2.119791030883789,10000.0,34908.33102440834,38275.80552268028,34908.33102440834,3360.3150522708893,2.759145021438598,0.0 -76100,1.2450536,3.3137383,,,,,,,,,,,,,, -76200,1.2913283,3.346038,,,,,,,,,,,,,, -76300,1.1193719,4.62414,,,,,,,,,,,,,, -76400,1.2203286,3.386417,,,,,,,,,,,,,, -76500,1.2483331,3.3513556,,,,,,,,,,,,,, -76600,1.0663811,3.9641933,,,,,,,,,,,,,, -76700,1.2161362,3.3280563,,,,,,,,,,,,,, -76800,1.1576426,3.738967,,,,,,,,,,,,,, -76900,1.1687452,4.480789,,,,,,,,,,,,,, -77000,1.1960498,3.41554,,,,,,,,,,,,,, -77009,,,0.7552343606948853,1.1852645874023438,0.6930199861526489,1.4465210437774658,50000.0,0.5698000192642212,2.0639443397521973,10000.0,35328.639944553375,38741.29879665375,35328.639944553375,3405.4010264873505,2.8056118488311768,0.0 -77100,1.1256455,3.6917443,,,,,,,,,,,,,, -77200,1.2607328,3.3600123,,,,,,,,,,,,,, -77300,1.2829384,3.5825217,,,,,,,,,,,,,, -77400,1.3012776,5.062294,,,,,,,,,,,,,, -77500,1.2226642,3.4197614,,,,,,,,,,,,,, -77600,1.1904647,3.575517,,,,,,,,,,,,,, -77700,1.3964703,3.513263,,,,,,,,,,,,,, -77800,1.1858147,4.5036297,,,,,,,,,,,,,, -77900,1.3228747,4.9895334,,,,,,,,,,,,,, -77923,,,0.7587695121765137,1.2047119140625,0.6947199702262878,1.4846562147140503,50000.0,0.5684000253677368,2.111642599105835,10000.0,35748.82993221283,39204.77293586731,35748.82993221283,3448.589109897613,2.8494107723236084,0.0 -78000,1.3440009,3.225905,,,,,,,,,,,,,, -78100,1.2119938,4.1399684,,,,,,,,,,,,,, -78200,1.2241932,3.334238,,,,,,,,,,,,,, -78300,1.1507571,3.2771814,,,,,,,,,,,,,, -78400,1.1483344,3.3279335,,,,,,,,,,,,,, -78500,1.1764181,3.3284194,,,,,,,,,,,,,, -78600,1.2824435,3.3309262,,,,,,,,,,,,,, -78700,1.2507701,3.4053044,,,,,,,,,,,,,, -78800,1.2338703,3.3155704,,,,,,,,,,,,,, -78839,,,0.7691015601158142,1.1444746255874634,0.6953999996185303,1.456865310668945,50000.0,0.5698000192642212,2.074385166168213,10000.0,36168.960582733154,39665.85251927376,36168.960582733154,3489.447700977325,2.887045860290528,0.0 -78900,1.3872665,5.0469275,,,,,,,,,,,,,, -79000,1.3144954,5.1042695,,,,,,,,,,,,,, -79100,1.2598372,3.3887322,,,,,,,,,,,,,, -79200,1.2667255,3.515796,,,,,,,,,,,,,, -79300,1.1053174,4.6640415,,,,,,,,,,,,,, -79400,1.1746647,3.3450804,,,,,,,,,,,,,, -79500,1.3083667,5.0985484,,,,,,,,,,,,,, -79600,1.232433,4.8512306,,,,,,,,,,,,,, -79700,1.3109863,4.808946,,,,,,,,,,,,,, -79757,,,0.7608398199081421,1.156747817993164,0.6993399858474731,1.4243495464324951,50000.0,0.5781000256538391,2.030494213104248,10000.0,36588.97576904297,40129.56078243256,36588.97576904297,3533.049519300461,2.9246163368225098,0.0 -79800,1.2077792,3.527805,,,,,,,,,,,,,, -79900,1.2050471,3.4183657,,,,,,,,,,,,,, -80000,1.1538914,3.3893683,,,,,,,,,,,,,, -80100,1.1015388,4.1156306,,,,,,,,,,,,,, -80200,1.3518159,3.3445313,,,,,,,,,,,,,, -80300,1.3248075,3.331388,,,,,,,,,,,,,, -80400,1.1710051,3.651742,,,,,,,,,,,,,, -80500,1.2210034,4.380685,,,,,,,,,,,,,, -80600,1.2150962,3.3295312,,,,,,,,,,,,,, -80674,,,0.7592187523841858,1.2048238515853882,0.6990999579429626,1.4711554050445557,50000.0,0.5706000328063965,2.0904242992401123,10000.0,37009.30336403847,40593.813328027725,37009.30336403847,3576.878761768341,2.968376398086548,0.0 -80700,1.2649108,3.3426387,,,,,,,,,,,,,, -80800,1.1935223,4.6788015,,,,,,,,,,,,,, -80900,1.204555,3.3943186,,,,,,,,,,,,,, -81000,1.1691597,3.6016672,,,,,,,,,,,,,, -81100,1.2111704,3.672531,,,,,,,,,,,,,, -81200,1.20878,4.5600877,,,,,,,,,,,,,, -81300,1.4174204,3.3479674,,,,,,,,,,,,,, -81400,1.3319277,3.3193738,,,,,,,,,,,,,, -81500,1.1804361,3.7807717,,,,,,,,,,,,,, -81588,,,0.7721484303474426,1.1226656436920166,0.7008000016212463,1.4333691596984863,50000.0,0.574400007724762,2.0534827709198,10000.0,37429.31781625748,41058.81526851654,37429.31781625748,3621.7705612182617,3.010806083679199,0.0 -81600,1.2890873,3.3573394,,,,,,,,,,,,,, -81700,1.2012155,3.9549494,,,,,,,,,,,,,, -81800,1.287152,3.270034,,,,,,,,,,,,,, -81900,1.2288572,4.75479,,,,,,,,,,,,,, -82000,1.2433978,3.259497,,,,,,,,,,,,,, -82100,1.2462176,3.3986473,,,,,,,,,,,,,, -82200,1.3191162,3.4227123,,,,,,,,,,,,,, -82300,1.1214052,4.0854893,,,,,,,,,,,,,, -82400,1.3047147,3.3849595,,,,,,,,,,,,,, -82500,1.1644447,4.3908577,,,,,,,,,,,,,, -82505,,,0.7578710913658142,1.1451164484024048,0.6981199979782104,1.403847336769104,50000.0,0.5731000304222107,2.033210515975952,10000.0,37849.59191274643,41522.65545606613,37849.59191274643,3665.23730635643,3.0572826862335205,0.0 -82600,1.2644064,3.3263218,,,,,,,,,,,,,, -82700,1.2080218,3.740936,,,,,,,,,,,,,, -82800,1.3575662,5.0032325,,,,,,,,,,,,,, -82900,1.2307525,3.6399522,,,,,,,,,,,,,, -83000,1.1764554,4.3642936,,,,,,,,,,,,,, -83100,1.2107213,3.3903,,,,,,,,,,,,,, -83200,1.2612242,3.577917,,,,,,,,,,,,,, -83300,1.2083176,3.235808,,,,,,,,,,,,,, -83400,1.1936918,3.419486,,,,,,,,,,,,,, -83422,,,0.7671093344688416,1.126673460006714,0.702459990978241,1.4112627506256104,50000.0,0.5782000422477722,2.01192593574524,10000.0,38269.6051607132,41987.00690460205,38269.6051607132,3709.483957052231,3.096383810043335,0.0 -83500,1.1533618,4.2195716,,,,,,,,,,,,,, -83600,1.27036,3.3047242,,,,,,,,,,,,,, -83700,1.3091514,3.3824193,,,,,,,,,,,,,, -83800,1.2728591,4.673803,,,,,,,,,,,,,, -83900,1.1385067,4.0653195,,,,,,,,,,,,,, -84000,1.1416292,3.5093875,,,,,,,,,,,,,, -84100,1.417253,5.001894,,,,,,,,,,,,,, -84200,1.274541,3.3173926,,,,,,,,,,,,,, -84300,1.4112629,5.09966,,,,,,,,,,,,,, -84337,,,0.7741601467132568,1.1334631443023682,0.7036600112915039,1.4320693016052246,50000.0,0.5787000060081482,2.0397210121154785,10000.0,38689.86214780808,42445.65338349342,38689.86214780808,3747.7794547080994,3.1379082202911377,0.0 -84400,1.1873466,3.8766706,,,,,,,,,,,,,, -84500,1.1988636,4.599886,,,,,,,,,,,,,, -84600,1.2277372,3.4676623,,,,,,,,,,,,,, -84700,1.2371639,4.742018,,,,,,,,,,,,,, -84800,1.3003732,3.3481266,,,,,,,,,,,,,, -84900,1.2227726,3.473573,,,,,,,,,,,,,, -85000,1.2840202,3.364851,,,,,,,,,,,,,, -85100,1.2504528,3.331556,,,,,,,,,,,,,, -85200,1.170516,3.832867,,,,,,,,,,,,,, -85253,,,0.7835546731948853,1.083608627319336,0.7017199993133545,1.4271472692489624,50000.0,0.5779000520706177,2.0556511878967285,10000.0,39110.120413541794,42905.44090247154,39110.120413541794,3787.216263532639,3.178086042404175,0.0 -85300,1.2743386,4.5977397,,,,,,,,,,,,,, -85400,1.1752927,3.7456656,,,,,,,,,,,,,, -85500,1.5501102,3.3103945,,,,,,,,,,,,,, -85600,1.2616788,3.1835084,,,,,,,,,,,,,, -85700,1.1984917,3.3199403,,,,,,,,,,,,,, -85800,1.197173,3.8558564,,,,,,,,,,,,,, -85900,1.3209388,4.311929,,,,,,,,,,,,,, -86000,1.2808343,3.4087853,,,,,,,,,,,,,, -86100,1.2265385,3.285081,,,,,,,,,,,,,, -86168,,,0.7661913633346558,1.129399657249451,0.7059599757194519,1.391916036605835,50000.0,0.5843999981880188,2.0091748237609863,10000.0,39530.094547748566,43367.21985912323,39530.094547748566,3828.928506135941,3.2188732624053955,0.0 -86200,1.2297379,3.8615372,,,,,,,,,,,,,, -86300,1.2251189,4.6884413,,,,,,,,,,,,,, -86400,1.2911085,3.3491857,,,,,,,,,,,,,, -86500,1.1694578,3.8950353,,,,,,,,,,,,,, -86600,1.2387564,3.260838,,,,,,,,,,,,,, -86700,1.1166303,3.8839607,,,,,,,,,,,,,, -86800,1.2808539,3.3158178,,,,,,,,,,,,,, -86900,1.5021504,3.4187315,,,,,,,,,,,,,, -87000,1.3114047,3.2479904,,,,,,,,,,,,,, -87086,,,0.7721288800239563,1.1214532852172852,0.7034800052642822,1.4084711074829102,50000.0,0.5800000429153442,2.038839817047119,10000.0,39950.2617855072,43830.59176373482,39950.2617855072,3872.0360209941855,3.2629506587982178,0.0 -87100,1.2236135,3.2707047,,,,,,,,,,,,,, -87200,1.1904652,3.4063354,,,,,,,,,,,,,, -87300,1.3379,3.303821,,,,,,,,,,,,,, -87400,1.187906,3.421679,,,,,,,,,,,,,, -87500,1.3330239,3.2867992,,,,,,,,,,,,,, -87600,1.2677368,3.272184,,,,,,,,,,,,,, -87700,1.3054434,3.4413059,,,,,,,,,,,,,, -87800,1.2555473,3.2905126,,,,,,,,,,,,,, -87900,1.2462953,3.9145486,,,,,,,,,,,,,, -88000,1.1971396,3.4933283,,,,,,,,,,,,,, -88002,,,0.7907617092132568,1.0512465238571167,0.708620011806488,1.3993642330169678,50000.0,0.5897000432014465,2.0038678646087646,10000.0,40370.22263884544,44294.23095417023,40370.22263884544,3915.6184599399567,3.306154489517212,0.0 -88100,1.3461422,3.3404346,,,,,,,,,,,,,, -88200,1.1560808,4.259831,,,,,,,,,,,,,, -88300,1.1978209,4.02277,,,,,,,,,,,,,, -88400,1.234228,3.3439999,,,,,,,,,,,,,, -88500,1.3206488,4.854894,,,,,,,,,,,,,, -88600,1.2887925,3.2863145,,,,,,,,,,,,,, -88700,1.4429693,3.383651,,,,,,,,,,,,,, -88800,1.3656306,3.1355102,,,,,,,,,,,,,, -88900,1.2761927,3.4775567,,,,,,,,,,,,,, -88915,,,0.7695898413658142,1.1286250352859497,0.7078799605369568,1.400683045387268,50000.0,0.5835000276565552,2.015311479568481,10000.0,40790.28323984146,44754.43089079857,40790.28323984146,3955.6573054790497,3.3536322116851807,0.0 -89000,1.3166597,3.531081,,,,,,,,,,,,,, -89100,1.2371042,3.618568,,,,,,,,,,,,,, -89200,1.2203343,3.3162816,,,,,,,,,,,,,, -89300,1.4667739,3.2864203,,,,,,,,,,,,,, -89400,1.234027,3.5469167,,,,,,,,,,,,,, -89500,1.4705737,3.2754898,,,,,,,,,,,,,, -89600,1.4401563,5.0462646,,,,,,,,,,,,,, -89700,1.34017,3.3512723,,,,,,,,,,,,,, -89800,1.2705548,3.6785161,,,,,,,,,,,,,, -89831,,,0.7773827910423279,1.0906800031661987,0.7084000110626221,1.3841772079467771,50000.0,0.5851000547409058,1.989667654037476,10000.0,41210.26914906502,45217.8800368309,41210.26914906502,3999.028008460999,3.3931477069854736,0.0 -89900,1.208206,3.4607956,,,,,,,,,,,,,, -90000,1.2374371,3.438193,,,,,,,,,,,,,, -90100,1.2889689,3.2929583,,,,,,,,,,,,,, -90200,1.3416963,3.2679236,,,,,,,,,,,,,, -90300,1.2949177,3.2542193,,,,,,,,,,,,,, -90400,1.2099428,3.338564,,,,,,,,,,,,,, -90500,1.5285914,3.312822,,,,,,,,,,,,,, -90600,1.3948643,3.3109758,,,,,,,,,,,,,, -90700,1.2544639,3.2375317,,,,,,,,,,,,,, -90745,,,0.7867773175239563,1.062552809715271,0.7113800048828125,1.3798199892044067,50000.0,0.589400053024292,1.9841927289962769,10000.0,41630.42470932007,45682.04362034798,41630.42470932007,4042.940614700317,3.434882879257202,0.0 -90800,1.6230167,4.4257197,,,,,,,,,,,,,, -90900,1.354586,3.3155203,,,,,,,,,,,,,, -91000,1.246894,3.83111,,,,,,,,,,,,,, -91100,1.308461,4.555081,,,,,,,,,,,,,, -91200,1.2856786,4.471681,,,,,,,,,,,,,, -91300,1.2244759,3.2396717,,,,,,,,,,,,,, -91400,1.3120313,3.7926266,,,,,,,,,,,,,, -91500,1.3027668,3.3710992,,,,,,,,,,,,,, -91600,1.3673006,3.241398,,,,,,,,,,,,,, -91661,,,0.7730078101158142,1.0981395244598389,0.7106599807739258,1.366332769393921,50000.0,0.5848000049591064,1.9762824773788448,10000.0,42050.68173789978,46141.46359419823,42050.68173789978,4082.0101313591,3.475362062454224,0.0 -91700,1.3122196,3.279452,,,,,,,,,,,,,, -91800,1.3412458,4.6080627,,,,,,,,,,,,,, -91900,1.3244767,3.9170926,,,,,,,,,,,,,, -92000,1.3388566,3.253642,,,,,,,,,,,,,, -92100,1.219978,3.8319669,,,,,,,,,,,,,, -92200,1.2118213,4.07624,,,,,,,,,,,,,, -92300,1.2920879,3.411622,,,,,,,,,,,,,, -92400,1.3299996,3.5102444,,,,,,,,,,,,,, -92500,1.2974316,3.2462454,,,,,,,,,,,,,, -92575,,,0.779980480670929,1.0799758434295654,0.708620011806488,1.3764957189559937,50000.0,0.5852000117301941,1.9942631721496584,10000.0,42470.79413843155,46606.02896428108,42470.79413843155,4126.3690321445465,3.515968084335327,0.0 -92600,1.4332161,3.2833312,,,,,,,,,,,,,, -92700,1.386645,3.3232439,,,,,,,,,,,,,, -92800,1.3622252,4.6540318,,,,,,,,,,,,,, -92900,1.2430881,3.9921856,,,,,,,,,,,,,, -93000,1.3214682,3.2366862,,,,,,,,,,,,,, -93100,1.303589,3.2856977,,,,,,,,,,,,,, -93200,1.4286684,4.883076,,,,,,,,,,,,,, -93300,1.3652221,4.5122633,,,,,,,,,,,,,, -93400,1.2186674,3.7026682,,,,,,,,,,,,,, -93494,,,0.7886718511581421,1.0608985424041748,0.7127999663352966,1.3815165758132937,50000.0,0.5883000493049622,1.989989161491394,10000.0,42890.790291547775,47068.66774916649,42890.790291547775,4168.915802717209,3.5591742992401123,0.0 -93500,1.2948341,4.1729198,,,,,,,,,,,,,, -93600,1.3759327,4.816289,,,,,,,,,,,,,, -93700,1.2855233,3.250204,,,,,,,,,,,,,, -93800,1.3070327,3.2915366,,,,,,,,,,,,,, -93900,1.292364,3.366495,,,,,,,,,,,,,, -94000,1.3438935,3.279262,,,,,,,,,,,,,, -94100,1.3002057,4.243024,,,,,,,,,,,,,, -94200,1.4256915,4.9165287,,,,,,,,,,,,,, -94300,1.3574997,3.3142729,,,,,,,,,,,,,, -94400,1.268874,4.5039673,,,,,,,,,,,,,, -94409,,,0.7827734351158142,1.0911414623260498,0.7120400071144104,1.3809247016906738,50000.0,0.5932000279426575,1.9795022010803225,10000.0,43310.81183815002,47533.24429869652,43310.81183815002,4213.372656822205,3.6043410301208496,0.0 -94500,1.3152559,3.2066514,,,,,,,,,,,,,, -94600,1.2938596,3.5043426,,,,,,,,,,,,,, -94700,1.5342623,4.7725835,,,,,,,,,,,,,, -94800,1.5470313,4.76441,,,,,,,,,,,,,, -94900,1.2981836,3.1554642,,,,,,,,,,,,,, -95000,1.4164244,4.659758,,,,,,,,,,,,,, -95100,1.4026062,3.3026037,,,,,,,,,,,,,, -95200,1.1364224,3.8886027,,,,,,,,,,,,,, -95300,1.3559309,3.4469128,,,,,,,,,,,,,, -95328,,,0.7830663919448853,1.0953134298324585,0.7137599587440491,1.3859901428222656,50000.0,0.5919000506401062,1.9971227645874023,10000.0,43731.426703214645,47996.88214635849,43731.426703214645,4256.298967838287,3.647977590560913,0.0 -95400,1.4535813,3.1683774,,,,,,,,,,,,,, -95500,1.412188,3.226233,,,,,,,,,,,,,, -95600,1.5061306,4.7586355,,,,,,,,,,,,,, -95700,1.3092232,3.2355876,,,,,,,,,,,,,, -95800,1.397621,3.223636,,,,,,,,,,,,,, -95900,1.2972053,3.8763313,,,,,,,,,,,,,, -96000,1.3665868,3.2958627,,,,,,,,,,,,,, -96100,1.3344393,4.4914107,,,,,,,,,,,,,, -96200,1.4960127,3.2486696,,,,,,,,,,,,,, -96244,,,0.7925195097923279,1.027241826057434,0.7157999873161316,1.3453587293624878,50000.0,0.5871000289916992,1.94822096824646,10000.0,44151.71507334709,48462.07651758194,44151.71507334709,4301.111127138138,3.689197540283203,0.0 -96300,1.3435653,3.2332482,,,,,,,,,,,,,, -96400,1.3031467,3.2391517,,,,,,,,,,,,,, -96500,1.5694573,4.7451844,,,,,,,,,,,,,, -96600,1.3584546,3.5532975,,,,,,,,,,,,,, -96700,1.3667868,3.8575094,,,,,,,,,,,,,, -96800,1.386326,3.921976,,,,,,,,,,,,,, -96900,1.4637859,3.2668195,,,,,,,,,,,,,, -97000,1.5348039,4.743079,,,,,,,,,,,,,, -97100,1.2990472,3.507841,,,,,,,,,,,,,, -97160,,,0.8024218678474426,1.013872146606445,0.7195599675178528,1.3655524253845217,50000.0,0.5987000465393066,1.9719301462173464,10000.0,44571.969373226166,48922.39983391762,44571.969373226166,4341.087158918381,3.729381799697876,0.0 -97200,1.4596286,3.478029,,,,,,,,,,,,,, -97300,1.5101498,3.449353,,,,,,,,,,,,,, -97400,1.2729145,3.2743175,,,,,,,,,,,,,, -97500,1.539865,4.7325916,,,,,,,,,,,,,, -97600,1.5073873,4.836944,,,,,,,,,,,,,, -97700,1.238915,4.0065145,,,,,,,,,,,,,, -97800,1.3406063,3.6216397,,,,,,,,,,,,,, -97900,1.41029,3.2549353,,,,,,,,,,,,,, -98000,1.3442414,3.4031124,,,,,,,,,,,,,, -98072,,,0.7826171517372131,1.069865107536316,0.7172399759292603,1.3597993850708008,50000.0,0.5934000015258789,1.9582252502441408,10000.0,44992.07482409477,49384.86444759369,44992.07482409477,4383.347653388977,3.7755613327026367,0.0 -98100,1.5091652,4.697132,,,,,,,,,,,,,, -98200,1.3364311,3.2288234,,,,,,,,,,,,,, -98300,1.4619946,3.2438421,,,,,,,,,,,,,, -98400,1.4183503,3.3030145,,,,,,,,,,,,,, -98500,1.3017387,3.4499357,,,,,,,,,,,,,, -98600,1.440863,4.52137,,,,,,,,,,,,,, -98700,1.4122746,3.258435,,,,,,,,,,,,,, -98800,1.3788065,3.5502596,,,,,,,,,,,,,, -98900,1.5430146,3.28692,,,,,,,,,,,,,, -98987,,,0.7899804711341858,1.0631240606307983,0.7159000039100647,1.3712981939315796,50000.0,0.5945000052452087,1.974447965621948,10000.0,45412.23049378395,49850.361988306046,45412.23049378395,4428.588559150696,3.823927640914917,0.0 -99000,1.3658068,3.2289226,,,,,,,,,,,,,, -99100,1.3822608,4.868128,,,,,,,,,,,,,, -99200,1.3455116,3.2162702,,,,,,,,,,,,,, -99300,1.3788624,3.239013,,,,,,,,,,,,,, -99400,1.3530324,3.549671,,,,,,,,,,,,,, -99500,1.3813916,3.2053025,,,,,,,,,,,,,, -99600,1.3436705,3.2704453,,,,,,,,,,,,,, -99700,1.2358315,3.2517111,,,,,,,,,,,,,, -99800,1.3319378,3.2183223,,,,,,,,,,,,,, -99900,1.2856485,3.1547544,,,,,,,,,,,,,, -99904,,,0.8057421445846558,0.96631920337677,0.7188799977302551,1.3333569765090942,50000.0,0.5960000157356262,1.927379250526428,10000.0,45832.44064116478,50314.4089922905,45832.44064116478,4472.3286554813385,3.867229223251343,0.0 -100000,1.4204432,4.4018545,,,,,,,,,,,,,, -100100,1.4214988,3.2751436,,,,,,,,,,,,,, -100200,1.2953588,3.1391351,,,,,,,,,,,,,, -100300,1.3454745,3.3025696,,,,,,,,,,,,,, -100400,1.3868629,3.2537746,,,,,,,,,,,,,, -100500,1.2907754,3.7675173,,,,,,,,,,,,,, -100600,1.6375386,3.1862502,,,,,,,,,,,,,, -100700,1.3729771,3.2862344,,,,,,,,,,,,,, -100800,1.5056427,4.789197,,,,,,,,,,,,,, -100823,,,0.7916210889816284,1.0372837781906128,0.723039984703064,1.3263189792633057,50000.0,0.5957000255584717,1.943742036819458,10000.0,46252.78383398056,50773.26985836029,46252.78383398056,4510.744811296463,3.916255474090576,0.0 -100900,1.3714707,3.2577057,,,,,,,,,,,,,, -101000,1.4497603,3.167036,,,,,,,,,,,,,, -101100,1.2882138,3.6469731,,,,,,,,,,,,,, -101200,1.2860276,3.5050573,,,,,,,,,,,,,, -101300,1.3547896,4.354765,,,,,,,,,,,,,, -101400,1.3831106,3.248546,,,,,,,,,,,,,, -101500,1.4796977,3.262371,,,,,,,,,,,,,, -101600,1.5078729,3.2732718,,,,,,,,,,,,,, -101700,1.4950192,3.2148657,,,,,,,,,,,,,, -101740,,,0.7941210865974426,1.010243535041809,0.7223399877548218,1.32013201713562,50000.0,0.5976999998092651,1.9327346086502075,10000.0,46672.85423207283,51235.3676404953,46672.85423207283,4552.680232286453,3.956141948699951,0.0 -101800,1.3081305,3.376554,,,,,,,,,,,,,, -101900,1.459664,3.3255043,,,,,,,,,,,,,, -102000,1.6134768,3.244859,,,,,,,,,,,,,, -102100,1.3007662,4.3747954,,,,,,,,,,,,,, -102200,1.3647714,3.7207835,,,,,,,,,,,,,, -102300,1.4831158,3.216576,,,,,,,,,,,,,, -102400,1.3248401,4.2418795,,,,,,,,,,,,,, -102500,1.5128702,3.2999582,,,,,,,,,,,,,, -102600,1.41936,3.2235918,,,,,,,,,,,,,, -102657,,,0.8057616949081421,0.9902685880661012,0.7222999930381775,1.3363498449325562,50000.0,0.6000000238418579,1.952873945236206,10000.0,47093.024424791336,51695.68058466911,47093.024424791336,4592.725254058838,4.001856803894043,0.0 -102700,1.3508605,3.719969,,,,,,,,,,,,,, -102800,1.4659889,3.2165737,,,,,,,,,,,,,, -102900,1.4053795,3.2588072,,,,,,,,,,,,,, -103000,1.3848563,3.193396,,,,,,,,,,,,,, -103100,1.3583035,3.2057474,,,,,,,,,,,,,, -103200,1.394325,4.1554003,,,,,,,,,,,,,, -103300,1.4778886,4.369177,,,,,,,,,,,,,, -103400,1.4467798,4.2171054,,,,,,,,,,,,,, -103500,1.530312,4.296254,,,,,,,,,,,,,, -103574,,,0.79505854845047,1.0035017728805542,0.7264399528503418,1.2918392419815063,50000.0,0.603600025177002,1.899206280708313,10000.0,47512.99956417084,52159.08165287972,47512.99956417084,4636.055516242981,4.0445640087127686,0.0 -103600,1.341462,3.194461,,,,,,,,,,,,,, -103700,1.3910389,3.499979,,,,,,,,,,,,,, -103800,1.4481293,3.2170553,,,,,,,,,,,,,, -103900,1.4084216,3.724569,,,,,,,,,,,,,, -104000,1.4425123,3.1557498,,,,,,,,,,,,,, -104100,1.3980721,3.4356523,,,,,,,,,,,,,, -104200,1.3477228,3.1687427,,,,,,,,,,,,,, -104300,1.520737,4.824134,,,,,,,,,,,,,, -104400,1.3868086,3.261683,,,,,,,,,,,,,, -104491,,,0.7984960675239563,1.0110760927200315,0.7244200110435486,1.3226091861724854,50000.0,0.5982000231742859,1.937688231468201,10000.0,47933.23684620857,52618.84631681442,47933.23684620857,4675.482330560684,4.093322277069092,0.0 -104500,1.3817296,3.2217674,,,,,,,,,,,,,, -104600,1.4898137,3.2346373,,,,,,,,,,,,,, -104700,1.4301428,3.2132273,,,,,,,,,,,,,, -104800,1.449038,3.3521013,,,,,,,,,,,,,, -104900,1.5731084,3.1629684,,,,,,,,,,,,,, -105000,1.403471,4.129335,,,,,,,,,,,,,, -105100,1.5091949,3.2541156,,,,,,,,,,,,,, -105200,1.5243782,3.2121463,,,,,,,,,,,,,, -105300,1.516479,3.1322017,,,,,,,,,,,,,, -105400,1.4656506,3.1378741,,,,,,,,,,,,,, -105405,,,0.8055859208106995,0.946385383605957,0.7270799875259399,1.2798376083374023,50000.0,0.607200026512146,1.8782799243927,10000.0,48353.525584459305,53082.0359852314,48353.525584459305,4718.283473730087,4.140423059463501,0.0 -105500,1.3761814,3.3607864,,,,,,,,,,,,,, -105600,1.4936457,3.1821783,,,,,,,,,,,,,, -105700,1.376512,3.1233447,,,,,,,,,,,,,, -105800,1.3971045,3.1321833,,,,,,,,,,,,,, -105900,1.5204202,3.1955047,,,,,,,,,,,,,, -106000,1.340064,3.8292418,,,,,,,,,,,,,, -106100,1.6080216,3.28962,,,,,,,,,,,,,, -106200,1.3689544,3.1927981,,,,,,,,,,,,,, -106300,1.6515046,4.6883526,,,,,,,,,,,,,, -106321,,,0.7962890267372131,1.0095607042312622,0.726639986038208,1.308942794799805,50000.0,0.6035000085830688,1.9131704568862915,10000.0,48773.57774686813,53542.90063166618,48773.57774686813,4758.99973154068,4.183903694152832,0.0 -106400,1.4462788,3.1427739,,,,,,,,,,,,,, -106500,1.5319159,3.158249,,,,,,,,,,,,,, -106600,1.3978512,3.2554436,,,,,,,,,,,,,, -106700,1.3949558,3.1667295,,,,,,,,,,,,,, -106800,1.7578896,4.6053452,,,,,,,,,,,,,, -106900,1.3446181,3.327949,,,,,,,,,,,,,, -107000,1.409179,3.4153664,,,,,,,,,,,,,, -107100,1.6717349,4.4995575,,,,,,,,,,,,,, -107200,1.3811114,3.5665638,,,,,,,,,,,,,, -107240,,,0.79638671875,1.0163118839263916,0.7265799641609192,1.3119384050369265,50000.0,0.5995000004768372,1.9311659336090088,10000.0,49193.492560863495,54002.184608221054,49193.492560863495,4798.267370700836,4.23191499710083,0.0 -107300,1.5000021,3.4365854,,,,,,,,,,,,,, -107400,1.4815117,3.1646662,,,,,,,,,,,,,, -107500,1.4756761,4.371393,,,,,,,,,,,,,, -107600,1.5276955,3.15736,,,,,,,,,,,,,, -107700,1.4561785,3.098845,,,,,,,,,,,,,, -107800,1.4396985,3.3455727,,,,,,,,,,,,,, -107900,1.5443501,3.21421,,,,,,,,,,,,,, -108000,1.6421707,4.834964,,,,,,,,,,,,,, -108100,1.6445069,3.1636858,,,,,,,,,,,,,, -108158,,,0.8068554401397705,0.9952368140220642,0.7298399806022644,1.3251501321792605,50000.0,0.6043000221252441,1.9325958490371704,10000.0,49613.74026465416,54464.774629831314,49613.74026465416,4840.51326918602,4.2742462158203125,0.0 -108200,1.386386,4.153435,,,,,,,,,,,,,, -108300,1.435766,3.1822188,,,,,,,,,,,,,, -108400,1.6398854,4.6424956,,,,,,,,,,,,,, -108500,1.3634027,3.7182271,,,,,,,,,,,,,, -108600,1.4607197,3.2318628,,,,,,,,,,,,,, -108700,1.4965882,3.1753337,,,,,,,,,,,,,, -108800,1.4356878,4.1116033,,,,,,,,,,,,,, -108900,1.625959,3.2176452,,,,,,,,,,,,,, -109000,1.5976506,3.1914184,,,,,,,,,,,,,, -109073,,,0.8169335722923279,0.9389804601669312,0.7301599979400635,1.2972155809402466,50000.0,0.6089000105857849,1.8995212316513064,10000.0,50033.76714491844,54925.62378549576,50033.76714491844,4881.228107690811,4.3293633460998535,0.0 -109100,1.6269183,4.8276362,,,,,,,,,,,,,, -109200,1.5255935,4.2076015,,,,,,,,,,,,,, -109300,1.661705,3.1757517,,,,,,,,,,,,,, -109400,1.3957366,3.132503,,,,,,,,,,,,,, -109500,1.4310582,3.18284,,,,,,,,,,,,,, -109600,1.4098123,4.035686,,,,,,,,,,,,,, -109700,1.4831533,3.35719,,,,,,,,,,,,,, -109800,1.7662119,4.437015,,,,,,,,,,,,,, -109900,1.4471875,3.1370656,,,,,,,,,,,,,, -109990,,,0.8023046851158142,0.9991682767868042,0.730239987373352,1.301327347755432,50000.0,0.6098000407218933,1.899088382720948,10000.0,50453.88138437271,55383.31662654877,50453.88138437271,4918.705953121185,4.377705097198486,0.0 -110000,1.526612,3.1804616,,,,,,,,,,,,,, -110100,1.6043634,3.2127075,,,,,,,,,,,,,, -110200,1.4526815,4.5336113,,,,,,,,,,,,,, -110300,1.3659844,3.790617,,,,,,,,,,,,,, -110400,1.5688043,4.4527864,,,,,,,,,,,,,, -110500,1.5394144,3.209929,,,,,,,,,,,,,, -110600,1.4994527,4.3891993,,,,,,,,,,,,,, -110700,1.4382639,3.57817,,,,,,,,,,,,,, -110800,1.6193761,4.4466524,,,,,,,,,,,,,, -110900,1.4876404,4.2841015,,,,,,,,,,,,,, -110907,,,0.8082812428474426,0.94392591714859,0.7319799661636353,1.2687028646469116,50000.0,0.6110000014305115,1.8786910772323608,10000.0,50873.97991299629,55845.48027086258,50873.97991299629,4960.672949314117,4.423073053359985,0.0 -111000,1.4363441,3.203922,,,,,,,,,,,,,, -111100,1.5447322,3.7377765,,,,,,,,,,,,,, -111200,1.382028,3.6997664,,,,,,,,,,,,,, -111300,1.4247019,3.747311,,,,,,,,,,,,,, -111400,1.6792208,4.69641,,,,,,,,,,,,,, -111500,1.4749776,3.1742778,,,,,,,,,,,,,, -111600,1.4513413,3.6930227,,,,,,,,,,,,,, -111700,1.3917961,3.2506082,,,,,,,,,,,,,, -111800,1.5297419,4.286577,,,,,,,,,,,,,, -111820,,,0.8192773461341858,0.9103418588638306,0.7331399917602539,1.2813720703125,50000.0,0.613800048828125,1.898208737373352,10000.0,51294.16231417656,56304.97105741501,51294.16231417656,4999.882031202316,4.470271587371826,0.0 -111900,1.4365364,3.1379087,,,,,,,,,,,,,, -112000,1.5365008,3.1131792,,,,,,,,,,,,,, -112100,1.5706866,3.1710045,,,,,,,,,,,,,, -112200,1.7362405,4.802866,,,,,,,,,,,,,, -112300,1.4892433,3.8557577,,,,,,,,,,,,,, -112400,1.4757997,3.251583,,,,,,,,,,,,,, -112500,1.5383072,3.1729436,,,,,,,,,,,,,, -112600,1.4825732,3.5454054,,,,,,,,,,,,,, -112700,1.4850503,3.1457686,,,,,,,,,,,,,, -112735,,,0.8069726228713989,0.985032081604004,0.7335000038146973,1.2869131565093994,50000.0,0.6167000532150269,1.8878358602523804,10000.0,51714.31570267677,56765.39139795303,51714.31570267677,5040.045891523361,4.520854949951172,0.0 -112800,1.5643281,4.5738297,,,,,,,,,,,,,, -112900,1.50251,3.11204,,,,,,,,,,,,,, -113000,1.5619872,3.2255528,,,,,,,,,,,,,, -113100,1.4171525,3.0894918,,,,,,,,,,,,,, -113200,1.5032235,3.228628,,,,,,,,,,,,,, -113300,1.6195432,3.1472075,,,,,,,,,,,,,, -113400,1.5431231,3.0715227,,,,,,,,,,,,,, -113500,1.6283348,3.1099174,,,,,,,,,,,,,, -113600,1.8396405,3.1704104,,,,,,,,,,,,,, -113652,,,0.810351550579071,0.9742421507835388,0.7343800067901611,1.293757438659668,50000.0,0.6139000058174133,1.887000322341919,10000.0,52134.493783950806,57224.7523059845,52134.493783950806,5079.130994319916,4.565948486328125,0.0 -113700,1.5177702,3.1611547,,,,,,,,,,,,,, -113800,1.6531183,3.5132523,,,,,,,,,,,,,, -113900,1.5167097,3.1649814,,,,,,,,,,,,,, -114000,1.7029166,4.5104527,,,,,,,,,,,,,, -114100,1.8132586,4.650818,,,,,,,,,,,,,, -114200,1.519562,4.1241674,,,,,,,,,,,,,, -114300,1.6347802,3.2141945,,,,,,,,,,,,,, -114400,1.7720613,4.64952,,,,,,,,,,,,,, -114500,1.5937343,3.1133478,,,,,,,,,,,,,, -114570,,,0.8231640458106995,0.8864479660987854,0.7347399592399597,1.2553902864456177,50000.0,0.6096000075340271,1.864834427833557,10000.0,52554.77208948136,57689.89100813866,52554.77208948136,5123.890210866928,4.613940000534058,0.0 -114600,1.5095031,3.058326,,,,,,,,,,,,,, -114700,1.5673968,3.183166,,,,,,,,,,,,,, -114800,1.5472649,3.129364,,,,,,,,,,,,,, -114900,1.5243974,3.1273508,,,,,,,,,,,,,, -115000,1.5726961,4.2472086,,,,,,,,,,,,,, -115100,1.6477109,3.152302,,,,,,,,,,,,,, -115200,1.6236329,4.417577,,,,,,,,,,,,,, -115300,1.5510205,3.2170596,,,,,,,,,,,,,, -115400,1.5042679,3.682907,,,,,,,,,,,,,, -115486,,,0.8078905940055847,0.9823316931724548,0.7350599765777588,1.3003755807876587,50000.0,0.6133000254631042,1.8954529762268064,10000.0,52974.86817359924,58152.056190013885,52974.86817359924,5165.858438968658,4.662526845932007,0.0 -115500,1.5395766,3.4531417,,,,,,,,,,,,,, -115600,1.519423,4.0099106,,,,,,,,,,,,,, -115700,1.62891,3.0702755,,,,,,,,,,,,,, -115800,1.4582844,3.908502,,,,,,,,,,,,,, -115900,1.5210054,4.003437,,,,,,,,,,,,,, -116000,1.6884654,3.1136124,,,,,,,,,,,,,, -116100,1.5175234,4.0103993,,,,,,,,,,,,,, -116200,1.674885,3.1299787,,,,,,,,,,,,,, -116300,1.5626775,3.1161387,,,,,,,,,,,,,, -116400,1.5855092,3.1501348,,,,,,,,,,,,,, -116403,,,0.8154101371765137,0.9408923983573914,0.738099992275238,1.263514518737793,50000.0,0.612500011920929,1.879354238510132,10000.0,53395.18107008934,58615.71476054192,53395.18107008934,5209.096930503845,4.717260360717773,0.0 -116500,1.5811388,3.1459746,,,,,,,,,,,,,, -116600,1.4552785,3.788582,,,,,,,,,,,,,, -116700,1.619575,3.1488805,,,,,,,,,,,,,, -116800,1.5403132,3.0650792,,,,,,,,,,,,,, -116900,1.625871,3.1286654,,,,,,,,,,,,,, -117000,1.5243295,3.8596575,,,,,,,,,,,,,, -117100,1.7528563,4.4911456,,,,,,,,,,,,,, -117200,1.7691946,4.479727,,,,,,,,,,,,,, -117300,1.6326417,3.2189274,,,,,,,,,,,,,, -117321,,,0.8239257335662842,0.8793371915817261,0.7392399907112122,1.2323999404907229,50000.0,0.6232000589370728,1.8169059753417969,10000.0,53815.10822582245,59076.4835767746,53815.10822582245,5249.836025476456,4.766160011291504,0.0 -117400,1.7369382,3.173782,,,,,,,,,,,,,, -117500,1.7276754,3.131435,,,,,,,,,,,,,, -117600,1.8539094,4.6898046,,,,,,,,,,,,,, -117700,1.5600842,3.0724735,,,,,,,,,,,,,, -117800,1.5768605,3.1558542,,,,,,,,,,,,,, -117900,1.7111317,4.599475,,,,,,,,,,,,,, -118000,1.5766176,3.1455245,,,,,,,,,,,,,, -118100,1.6226087,3.2345154,,,,,,,,,,,,,, -118195,,,0.8167187571525574,0.927771270275116,0.7392599582672119,1.2517048120498655,50000.0,0.6173000335693359,1.852914571762085,10000.0,54235.39735746384,59539.506457567215,54235.39735746384,5292.465433597565,4.820432662963867,0.0 -118200,2.0427172,4.8619742,,,,,,,,,,,,,, -118300,1.6542132,4.4866877,,,,,,,,,,,,,, -118400,1.5804082,3.0563908,,,,,,,,,,,,,, -118500,1.4959962,3.2168708,,,,,,,,,,,,,, -118600,1.6472527,3.0691833,,,,,,,,,,,,,, -118700,1.6099545,3.297126,,,,,,,,,,,,,, -118800,1.8171397,4.581282,,,,,,,,,,,,,, -118900,1.5782729,3.4791956,,,,,,,,,,,,,, -119000,1.6397864,3.9525707,,,,,,,,,,,,,, -119100,1.7333143,3.0692592,,,,,,,,,,,,,, -119111,,,0.8149804472923279,0.9616518020629884,0.7415800094604492,1.2812042236328125,50000.0,0.6168000102043152,1.890485405921936,10000.0,54655.33577847481,59999.27346301079,54655.33577847481,5332.194699764252,4.867582082748413,0.0 -119200,1.6614412,3.2077248,,,,,,,,,,,,,, -119300,1.5936362,3.2888384,,,,,,,,,,,,,, -119400,1.5516549,3.4195404,,,,,,,,,,,,,, -119500,1.774572,3.0810723,,,,,,,,,,,,,, -119600,1.5051167,3.5893486,,,,,,,,,,,,,, -119700,1.6270226,4.3722506,,,,,,,,,,,,,, -119800,1.6259204,3.0497212,,,,,,,,,,,,,, -119900,1.7694616,3.17216,,,,,,,,,,,,,, -120000,1.5650285,3.0197296,,,,,,,,,,,,,, -120024,,,0.8238866925239563,0.8949841856956482,0.7396199703216553,1.250055909156799,50000.0,0.6173000335693359,1.8545596599578853,10000.0,55075.26651906967,60462.61881303787,55075.26651906967,5375.512019634247,4.913069009780884,0.0 -120100,1.5530101,3.83176,,,,,,,,,,,,,, -120200,1.5820792,3.0742133,,,,,,,,,,,,,, -120300,1.5895925,4.187153,,,,,,,,,,,,,, -120400,1.5595887,3.6936493,,,,,,,,,,,,,, -120500,1.6778508,3.0710459,,,,,,,,,,,,,, -120600,1.6394557,3.7408724,,,,,,,,,,,,,, -120700,1.5755762,3.0931838,,,,,,,,,,,,,, -120800,1.6857947,3.7900105,,,,,,,,,,,,,, -120900,1.8661114,4.7120457,,,,,,,,,,,,,, -120940,,,0.8238281011581421,0.8898305296897888,0.7408599853515625,1.232032060623169,50000.0,0.6230000257492065,1.833022952079773,10000.0,55495.54484438896,60922.43671345711,55495.54484438896,5414.950091123581,4.961091995239258,0.0 -121000,1.7637126,3.1227179,,,,,,,,,,,,,, -121100,1.7017659,3.1790347,,,,,,,,,,,,,, -121200,1.7193882,3.1502876,,,,,,,,,,,,,, -121300,1.6423162,3.7189584,,,,,,,,,,,,,, -121400,1.7259696,3.1191435,,,,,,,,,,,,,, -121500,1.7244092,3.1502798,,,,,,,,,,,,,, -121600,1.6236783,3.9722023,,,,,,,,,,,,,, -121700,1.5636349,3.0523927,,,,,,,,,,,,,, -121800,1.6799659,3.1167746,,,,,,,,,,,,,, -121856,,,0.8225390315055847,0.8875148296356201,0.744879961013794,1.2177189588546753,50000.0,0.6186000108718872,1.8269776105880733,10000.0,55915.70909833908,61384.78747153282,55915.70909833908,5457.037507534027,5.007095813751221,0.0 -121900,1.7461616,4.2960305,,,,,,,,,,,,,, -122000,1.9908909,4.6842976,,,,,,,,,,,,,, -122100,1.6586614,3.0720108,,,,,,,,,,,,,, -122200,1.5713865,3.4516327,,,,,,,,,,,,,, -122300,1.6463239,3.0964398,,,,,,,,,,,,,, -122400,1.4680969,3.3293183,,,,,,,,,,,,,, -122500,1.7128246,3.1495447,,,,,,,,,,,,,, -122600,1.6851735,3.129929,,,,,,,,,,,,,, -122700,1.6541173,3.0424106,,,,,,,,,,,,,, -122773,,,0.82582026720047,0.8763144016265869,0.7443599700927734,1.2213648557662964,50000.0,0.6287000179290771,1.805085062980652,10000.0,56336.03149437904,61850.01596212387,56336.03149437904,5501.842509508133,5.055784463882446,0.0 -122800,1.7082525,3.1914468,,,,,,,,,,,,,, -122900,1.5623088,3.4938893,,,,,,,,,,,,,, -123000,1.7314763,3.4308913,,,,,,,,,,,,,, -123100,1.4865062,3.3806005,,,,,,,,,,,,,, -123200,1.6758907,3.0715284,,,,,,,,,,,,,, -123300,1.5914084,3.0917025,,,,,,,,,,,,,, -123400,1.7875097,3.0474534,,,,,,,,,,,,,, -123500,1.6717936,3.0972793,,,,,,,,,,,,,, -123600,1.60912,3.2328515,,,,,,,,,,,,,, -123690,,,0.8439843654632568,0.8367078900337219,0.744159996509552,1.248141527175903,50000.0,0.6236000061035156,1.8550283908844,10000.0,56756.35433626175,62309.62432742119,56756.35433626175,5541.017600536346,5.112732887268066,0.0 -123700,1.518324,2.9749343,,,,,,,,,,,,,, -123800,1.8201551,3.0818458,,,,,,,,,,,,,, -123900,1.7958295,3.2563353,,,,,,,,,,,,,, -124000,1.5826887,3.0935633,,,,,,,,,,,,,, -124100,1.786755,3.102054,,,,,,,,,,,,,, -124200,1.7314887,3.0981102,,,,,,,,,,,,,, -124300,1.5379277,3.3371742,,,,,,,,,,,,,, -124400,1.7152196,3.1137238,,,,,,,,,,,,,, -124500,1.6346223,3.3330805,,,,,,,,,,,,,, -124600,1.9604062,4.647733,,,,,,,,,,,,,, -124606,,,0.8265038728713989,0.8778988718986511,0.7480599880218506,1.2074445486068726,50000.0,0.6204000115394592,1.806689739227295,10000.0,57176.65181708336,62771.15999698639,57176.65181708336,5582.155804157257,5.160711050033569,0.0 -124700,1.7582036,3.2861059,,,,,,,,,,,,,, -124800,1.6469646,3.017052,,,,,,,,,,,,,, -124900,1.6593764,3.7310584,,,,,,,,,,,,,, -125000,1.7102928,3.1073432,,,,,,,,,,,,,, -125100,1.5750746,2.9710383,,,,,,,,,,,,,, -125200,1.6302042,3.1631362,,,,,,,,,,,,,, -125300,1.5942578,3.0391178,,,,,,,,,,,,,, -125400,1.7883879,4.522936,,,,,,,,,,,,,, -125500,1.7159727,3.0814347,,,,,,,,,,,,,, -125525,,,0.8316210508346558,0.8772971630096436,0.7468799948692322,1.236533522605896,50000.0,0.6232000589370728,1.830721974372864,10000.0,57596.83524441719,63235.70528793335,57596.83524441719,5626.422222137451,5.204384326934815,0.0 -125600,1.620462,3.6565025,,,,,,,,,,,,,, -125700,1.7113856,3.084482,,,,,,,,,,,,,, -125800,1.6821041,3.143139,,,,,,,,,,,,,, -125900,1.7198814,3.6226525,,,,,,,,,,,,,, -126000,1.7356648,4.003878,,,,,,,,,,,,,, -126100,1.9067245,3.04715,,,,,,,,,,,,,, -126200,1.6770654,2.995018,,,,,,,,,,,,,, -126300,1.7511302,4.4619336,,,,,,,,,,,,,, -126400,1.5808305,3.4121819,,,,,,,,,,,,,, -126442,,,0.84095698595047,0.8439698815345764,0.7470600008964539,1.2318177223205566,50000.0,0.6271000504493713,1.8440407514572144,10000.0,58016.88030552864,63700.16956210136,58016.88030552864,5670.740884780884,5.2526023387908936,0.0 -126500,1.968146,4.677769,,,,,,,,,,,,,, -126600,1.7776711,4.1056347,,,,,,,,,,,,,, -126700,1.6068746,3.1552525,,,,,,,,,,,,,, -126800,1.6331418,3.1595984,,,,,,,,,,,,,, -126900,1.6921304,3.0139189,,,,,,,,,,,,,, -127000,1.6100899,3.2528543,,,,,,,,,,,,,, -127100,1.6236172,3.9432793,,,,,,,,,,,,,, -127200,1.938884,4.5355635,,,,,,,,,,,,,, -127300,1.750185,2.9920962,,,,,,,,,,,,,, -127359,,,0.8307226300239563,0.8531256914138794,0.7465199828147888,1.2010209560394287,50000.0,0.6279000043869019,1.790854811668396,10000.0,58437.10874581337,64159.10203433037,58437.10874581337,5709.3451907634735,5.299185037612915,0.0 -127400,1.7134132,3.3256078,,,,,,,,,,,,,, -127500,1.9507309,4.2829313,,,,,,,,,,,,,, -127600,1.6418992,3.0492942,,,,,,,,,,,,,, -127700,1.7336411,3.4104233,,,,,,,,,,,,,, -127800,1.6860293,3.2119794,,,,,,,,,,,,,, -127900,1.7523526,3.032963,,,,,,,,,,,,,, -128000,1.9892412,4.176102,,,,,,,,,,,,,, -128100,1.7010137,3.7771082,,,,,,,,,,,,,, -128200,1.9973701,4.4776297,,,,,,,,,,,,,, -128275,,,0.8322070240974426,0.8553465604782104,0.7501800060272217,1.1990872621536257,50000.0,0.6304000020027161,1.7960751056671145,10000.0,58857.14763689041,64621.77625489235,58857.14763689041,5751.883533239365,5.343764781951904,0.0 -128300,2.066591,4.653329,,,,,,,,,,,,,, -128400,1.7349256,3.2084386,,,,,,,,,,,,,, -128500,1.7428478,2.984902,,,,,,,,,,,,,, -128600,1.6714356,3.395677,,,,,,,,,,,,,, -128700,1.7514817,3.3724349,,,,,,,,,,,,,, -128800,1.7594488,3.0781612,,,,,,,,,,,,,, -128900,1.7091208,3.0021167,,,,,,,,,,,,,, -129000,1.8008814,2.9901075,,,,,,,,,,,,,, -129100,2.0109189,4.5883512,,,,,,,,,,,,,, -129189,,,0.8402343392372131,0.8571125864982605,0.7488999962806702,1.232455849647522,50000.0,0.6279000043869019,1.829653024673462,10000.0,59277.30268001557,65086.16503977776,59277.30268001557,5796.012695074081,5.395809412002564,0.0 -129200,1.8042287,3.2321157,,,,,,,,,,,,,, -129300,1.932361,3.939158,,,,,,,,,,,,,, -129400,1.7467388,3.001813,,,,,,,,,,,,,, -129500,2.16425,4.560457,,,,,,,,,,,,,, -129600,1.8824422,3.0731313,,,,,,,,,,,,,, -129700,1.829762,3.0857506,,,,,,,,,,,,,, -129800,1.8193458,3.0926352,,,,,,,,,,,,,, -129900,1.717106,2.9862761,,,,,,,,,,,,,, -130000,1.702929,3.1085682,,,,,,,,,,,,,, -130100,1.6741974,3.5279856,,,,,,,,,,,,,, -130106,,,0.83314448595047,0.8533020615577698,0.7516599893569946,1.195619821548462,50000.0,0.6306000351905823,1.786401629447937,10000.0,59697.6349568367,65544.8101940155,59697.6349568367,5834.219810962677,5.449113845825195,0.0 -130200,1.7371023,3.3101864,,,,,,,,,,,,,, -130300,2.1610157,4.6592865,,,,,,,,,,,,,, -130400,1.8319383,3.1138198,,,,,,,,,,,,,, -130500,2.1063633,4.631998,,,,,,,,,,,,,, -130600,1.8953478,4.1752725,,,,,,,,,,,,,, -130700,1.7708485,2.9989526,,,,,,,,,,,,,, -130800,1.9125626,4.037025,,,,,,,,,,,,,, -130900,1.7236454,2.9551091,,,,,,,,,,,,,, -131000,1.6524403,3.5665321,,,,,,,,,,,,,, -131022,,,0.8349804282188416,0.847000241279602,0.7511000037193298,1.1966373920440674,50000.0,0.6299000382423401,1.7943549156188965,10000.0,60117.64026284218,66006.13822078705,60117.64026284218,5875.437610626221,5.501766443252564,0.0 -131100,1.9147657,2.9906085,,,,,,,,,,,,,, -131200,1.8769661,3.0832264,,,,,,,,,,,,,, -131300,2.126901,3.0401437,,,,,,,,,,,,,, -131400,1.7650791,3.1623244,,,,,,,,,,,,,, -131500,1.9358534,3.0711558,,,,,,,,,,,,,, -131600,2.109169,4.4771857,,,,,,,,,,,,,, -131700,1.8983252,4.5030937,,,,,,,,,,,,,, -131800,1.7427355,3.8545158,,,,,,,,,,,,,, -131900,1.7432348,3.4631345,,,,,,,,,,,,,, -131938,,,0.841113269329071,0.8369109630584717,0.7544800043106079,1.2058559656143188,50000.0,0.6317000389099121,1.801892876625061,10000.0,60537.56083083153,66470.13001537323,60537.56083083153,5919.410274267197,5.547304630279541,0.0 -132000,1.8116752,4.173666,,,,,,,,,,,,,, -132100,1.7664174,2.9939408,,,,,,,,,,,,,, -132200,1.7734314,3.3101869,,,,,,,,,,,,,, -132300,1.9532998,4.0993195,,,,,,,,,,,,,, -132400,1.9206259,3.0742326,,,,,,,,,,,,,, -132500,1.9452182,3.01048,,,,,,,,,,,,,, -132600,1.7399948,4.002247,,,,,,,,,,,,,, -132700,1.752296,3.1530123,,,,,,,,,,,,,, -132800,1.8839602,3.1221528,,,,,,,,,,,,,, -132855,,,0.8413476347923279,0.8087218403816223,0.756060004234314,1.167621612548828,50000.0,0.6373000144958496,1.7519946098327637,10000.0,60957.67953944206,66933.93345880508,60957.67953944206,5962.992874383926,5.59683632850647,0.0 -132900,2.1633654,4.463101,,,,,,,,,,,,,, -133000,1.6614205,3.039019,,,,,,,,,,,,,, -133100,1.7636297,2.9835916,,,,,,,,,,,,,, -133200,1.787527,3.0136456,,,,,,,,,,,,,, -133300,1.8731899,3.0271816,,,,,,,,,,,,,, -133400,1.8075879,3.1263666,,,,,,,,,,,,,, -133500,1.7917259,2.927026,,,,,,,,,,,,,, -133600,1.8171271,3.2119553,,,,,,,,,,,,,, -133700,1.9168179,3.0090132,,,,,,,,,,,,,, -133769,,,0.8415429592132568,0.8192632794380188,0.7564199566841125,1.1783759593963623,50000.0,0.6380000114440918,1.7589350938796997,10000.0,61377.61789727211,67398.6365814209,61377.61789727211,6007.655216932297,5.646831274032593,0.0 -133800,1.87712,3.909968,,,,,,,,,,,,,, -133900,1.8328714,3.2265716,,,,,,,,,,,,,, -134000,1.8887273,2.9967368,,,,,,,,,,,,,, -134100,1.774865,3.027175,,,,,,,,,,,,,, -134200,1.7671587,3.3017166,,,,,,,,,,,,,, -134300,1.9005462,3.126543,,,,,,,,,,,,,, -134400,1.8310835,2.911309,,,,,,,,,,,,,, -134500,1.8972784,2.9640503,,,,,,,,,,,,,, -134600,1.756797,3.035567,,,,,,,,,,,,,, -134688,,,0.8436328172683716,0.8358486890792847,0.7567600011825562,1.1952334642410278,50000.0,0.6371000409126282,1.79377543926239,10000.0,61797.7135746479,67858.43765830994,61797.7135746479,6047.2556438446045,5.699421167373657,0.0 -134700,1.8771492,4.0442286,,,,,,,,,,,,,, -134800,1.8990396,3.0171623,,,,,,,,,,,,,, -134900,1.8140388,3.185067,,,,,,,,,,,,,, -135000,1.8173629,3.539903,,,,,,,,,,,,,, -135100,1.8734345,3.1762133,,,,,,,,,,,,,, -135200,1.9402729,4.2116623,,,,,,,,,,,,,, -135300,1.7840749,3.001697,,,,,,,,,,,,,, -135400,1.8698099,2.9930122,,,,,,,,,,,,,, -135500,2.4361124,4.5382776,,,,,,,,,,,,,, -135600,1.8609686,2.9686444,,,,,,,,,,,,,, -135607,,,0.857714831829071,0.7488970160484314,0.75764000415802,1.1549402475357056,50000.0,0.6403000354766846,1.7427631616592407,10000.0,62218.0052587986,68323.26888012886,62218.0052587986,6091.693810224533,5.74712872505188,0.0 -135700,2.0475736,4.376807,,,,,,,,,,,,,, -135800,1.9391268,3.2710078,,,,,,,,,,,,,, -135900,1.7571313,2.9638143,,,,,,,,,,,,,, -136000,1.8675511,3.3705766,,,,,,,,,,,,,, -136100,1.8576022,3.1163514,,,,,,,,,,,,,, -136200,2.0294151,4.0169578,,,,,,,,,,,,,, -136300,2.1563148,4.421372,,,,,,,,,,,,,, -136400,2.3132558,4.494819,,,,,,,,,,,,,, -136500,2.0767558,2.9693422,,,,,,,,,,,,,, -136525,,,0.8420116901397705,0.8260501623153687,0.7569199800491333,1.185738205909729,50000.0,0.6402000188827515,1.761014103889465,10000.0,62638.34627819061,68784.49385523796,62638.34627819061,6132.479343414307,5.793379783630371,0.0 -136600,1.8465891,3.3937078,,,,,,,,,,,,,, -136700,1.7170408,3.1567018,,,,,,,,,,,,,, -136800,1.8502581,3.8922052,,,,,,,,,,,,,, -136900,1.839168,2.9179158,,,,,,,,,,,,,, -137000,2.1173317,4.2529144,,,,,,,,,,,,,, -137100,1.9747818,2.906254,,,,,,,,,,,,,, -137200,1.8230301,3.035715,,,,,,,,,,,,,, -137300,2.4714532,4.5235295,,,,,,,,,,,,,, -137400,1.9068782,2.9953659,,,,,,,,,,,,,, -137441,,,0.8498241901397705,0.7776594161987305,0.7576599717140198,1.155328392982483,50000.0,0.6407000422477722,1.739344596862793,10000.0,63058.35525393486,69246.24775362015,63058.35525393486,6174.12397646904,5.840576648712158,0.0 -137500,1.8945524,3.868934,,,,,,,,,,,,,, -137600,1.9460834,3.2300668,,,,,,,,,,,,,, -137700,1.9218189,2.907928,,,,,,,,,,,,,, -137800,1.9414749,3.0370526,,,,,,,,,,,,,, -137900,1.8197945,3.3122368,,,,,,,,,,,,,, -138000,1.7074683,3.7146533,,,,,,,,,,,,,, -138100,2.0162342,4.0841436,,,,,,,,,,,,,, -138200,1.8514891,3.0920372,,,,,,,,,,,,,, -138300,2.0240405,3.1067352,,,,,,,,,,,,,, -138357,,,0.8564062118530273,0.7429553270339966,0.7604999542236328,1.145262360572815,50000.0,0.6389000415802002,1.75374174118042,10000.0,63478.462926864624,69711.37830281258,63478.462926864624,6219.046733617783,5.887576580047607,0.0 -138400,1.8583528,2.8819854,,,,,,,,,,,,,, -138500,2.0585878,4.0494127,,,,,,,,,,,,,, -138600,2.246018,4.277855,,,,,,,,,,,,,, -138700,1.8886824,2.980607,,,,,,,,,,,,,, -138800,1.9949208,3.4449837,,,,,,,,,,,,,, -138900,2.0125108,4.1297956,,,,,,,,,,,,,, -139000,2.1184964,3.4486043,,,,,,,,,,,,,, -139100,1.9548165,2.9061346,,,,,,,,,,,,,, -139200,1.9421582,2.991651,,,,,,,,,,,,,, -139275,,,0.848437488079071,0.8022940158843994,0.7615199685096741,1.1675846576690674,50000.0,0.64000004529953,1.7675161361694336,10000.0,63898.55232334137,70174.82610559464,63898.55232334137,6262.304250955582,5.935348272323608,0.0 -139300,1.9714139,3.3095667,,,,,,,,,,,,,, -139400,2.0134602,3.027193,,,,,,,,,,,,,, -139500,2.0262148,2.9742796,,,,,,,,,,,,,, -139600,2.0310366,3.8200235,,,,,,,,,,,,,, -139700,2.0262153,2.9842737,,,,,,,,,,,,,, -139800,1.8939538,2.9417279,,,,,,,,,,,,,, -139900,1.8517087,3.5293539,,,,,,,,,,,,,, -140000,2.1126516,2.892724,,,,,,,,,,,,,, -140100,1.9576705,2.890048,,,,,,,,,,,,,, -140189,,,0.850390613079071,0.7843379974365234,0.7585200071334839,1.1634849309921265,50000.0,0.6385000348091125,1.748010277748108,10000.0,64318.52702474594,70639.37820744514,64318.52702474594,6306.780901193619,5.983578443527222,0.0 -140200,2.7352448,4.453924,,,,,,,,,,,,,, -140300,1.734445,2.9318829,,,,,,,,,,,,,, -140400,1.9211074,3.6335754,,,,,,,,,,,,,, -140500,2.0298657,2.9888272,,,,,,,,,,,,,, -140600,1.8360091,3.528152,,,,,,,,,,,,,, -140700,1.9726768,3.5967398,,,,,,,,,,,,,, -140800,1.9586577,2.9481645,,,,,,,,,,,,,, -140900,2.1410441,3.3779278,,,,,,,,,,,,,, -141000,1.9796535,2.9774609,,,,,,,,,,,,,, -141100,1.8527927,3.2720923,,,,,,,,,,,,,, -141104,,,0.8593944907188416,0.7617772221565247,0.7626199722290039,1.1537964344024658,50000.0,0.645300030708313,1.7508445978164673,10000.0,64738.57985305786,71103.796667099,64738.57985305786,6351.042797088623,6.035136699676514,0.0 -141200,2.144116,3.8956664,,,,,,,,,,,,,, -141300,2.042537,3.1008792,,,,,,,,,,,,,, -141400,2.0881732,4.1232615,,,,,,,,,,,,,, -141500,1.9394418,3.0013652,,,,,,,,,,,,,, -141600,2.2750652,4.40083,,,,,,,,,,,,,, -141700,2.2873662,4.2591014,,,,,,,,,,,,,, -141800,1.9400669,2.9234757,,,,,,,,,,,,,, -141900,1.9818201,2.9997625,,,,,,,,,,,,,, -142000,2.0631037,2.973062,,,,,,,,,,,,,, -142015,,,0.8515819907188416,0.7848900556564331,0.7644199728965759,1.1486810445785522,50000.0,0.6426000595092773,1.7553709745407104,10000.0,65158.57492208481,71568.67630791664,65158.57492208481,6395.823813199997,6.086883783340454,0.0 -142100,2.3477888,4.194345,,,,,,,,,,,,,, -142200,2.1642144,2.8625817,,,,,,,,,,,,,, -142300,2.0930588,2.9240289,,,,,,,,,,,,,, -142400,1.9096547,3.3144412,,,,,,,,,,,,,, -142500,2.0369928,2.9944246,,,,,,,,,,,,,, -142600,2.1120589,3.2478514,,,,,,,,,,,,,, -142700,1.8446816,3.3128552,,,,,,,,,,,,,, -142800,2.0290625,3.7330663,,,,,,,,,,,,,, -142900,1.8617976,2.8613634,,,,,,,,,,,,,, -142930,,,0.8580859303474426,0.7559224367141724,0.7646999955177307,1.1461162567138672,50000.0,0.6457000374794006,1.7412879467010498,10000.0,65578.79041981697,72028.25753879547,65578.79041981697,6435.086161613464,6.138216018676758,0.0 -143000,1.982081,2.9141438,,,,,,,,,,,,,, -143100,1.9986889,2.932541,,,,,,,,,,,,,, -143200,2.0473392,2.9873757,,,,,,,,,,,,,, -143300,1.9225736,3.3048172,,,,,,,,,,,,,, -143400,1.9789878,3.2742374,,,,,,,,,,,,,, -143500,2.0231884,2.9721165,,,,,,,,,,,,,, -143600,2.079194,2.976507,,,,,,,,,,,,,, -143700,2.100215,2.9592078,,,,,,,,,,,,,, -143800,1.8323743,3.2475593,,,,,,,,,,,,,, -143845,,,0.8574609160423279,0.7697817087173462,0.7655199766159058,1.1547305583953855,50000.0,0.645300030708313,1.753394603729248,10000.0,65998.90920972824,72487.25309491158,65998.90920972824,6473.854023933411,6.194704294204712,0.0 -143900,2.1601489,3.9459455,,,,,,,,,,,,,, -144000,2.1169837,4.113109,,,,,,,,,,,,,, -144100,2.047695,3.5040083,,,,,,,,,,,,,, -144200,1.8728583,2.9819574,,,,,,,,,,,,,, -144300,1.9448522,2.9253387,,,,,,,,,,,,,, -144400,2.1580274,4.1469374,,,,,,,,,,,,,, -144500,2.2318678,2.8925605,,,,,,,,,,,,,, -144600,2.317371,3.823367,,,,,,,,,,,,,, -144700,2.0459244,3.5823345,,,,,,,,,,,,,, -144762,,,0.8602929711341858,0.742518961429596,0.7663799524307251,1.1345235109329224,50000.0,0.6467000246047974,1.7295129299163818,10000.0,66419.2152094841,72949.6464586258,66419.2152094841,6515.839464187622,6.242878437042236,0.0 -144800,2.7720845,4.4059834,,,,,,,,,,,,,, -144900,2.0852616,2.9121149,,,,,,,,,,,,,, -145000,2.447408,4.4433355,,,,,,,,,,,,,, -145100,1.9155787,2.8483236,,,,,,,,,,,,,, -145200,1.8940167,2.9184914,,,,,,,,,,,,,, -145300,2.2196186,3.1085682,,,,,,,,,,,,,, -145400,2.2926476,2.9517655,,,,,,,,,,,,,, -145500,1.829976,3.5578997,,,,,,,,,,,,,, -145600,2.148791,3.924052,,,,,,,,,,,,,, -145674,,,0.8600585460662842,0.7430623173713684,0.7668799757957458,1.1308563947677612,50000.0,0.6487000584602356,1.7240246534347534,10000.0,66839.39895510674,73413.97380805016,66839.39895510674,6559.876788377762,6.297804832458496,0.0 -145700,2.9560103,4.4768577,,,,,,,,,,,,,, -145800,2.1361866,2.9756312,,,,,,,,,,,,,, -145900,2.0607088,3.000041,,,,,,,,,,,,,, -146000,2.3253877,3.0644803,,,,,,,,,,,,,, -146100,1.9315677,3.1858697,,,,,,,,,,,,,, -146200,2.2190354,3.665823,,,,,,,,,,,,,, -146300,1.9931626,3.164115,,,,,,,,,,,,,, -146400,2.1602342,3.9011257,,,,,,,,,,,,,, -146500,2.2650158,4.2584925,,,,,,,,,,,,,, -146588,,,0.864062488079071,0.7536942958831787,0.7701799869537354,1.145656943321228,50000.0,0.64410001039505,1.748490333557129,10000.0,67259.50889992714,73876.54255223274,67259.50889992714,6602.229855775833,6.350820302963257,0.0 -146600,2.170834,2.8543377,,,,,,,,,,,,,, -146700,2.0634646,2.8890998,,,,,,,,,,,,,, -146800,2.1042304,2.9544082,,,,,,,,,,,,,, -146900,2.0178468,2.9275649,,,,,,,,,,,,,, -147000,2.0462391,3.7027507,,,,,,,,,,,,,, -147100,2.5047307,4.4561534,,,,,,,,,,,,,, -147200,2.0931063,3.6329303,,,,,,,,,,,,,, -147300,1.9976248,2.9939377,,,,,,,,,,,,,, -147400,2.1813788,3.02159,,,,,,,,,,,,,, -147500,2.1599822,2.936526,,,,,,,,,,,,,, -147504,,,0.8710546493530273,0.708790123462677,0.7684599757194519,1.1278865337371826,50000.0,0.6520000100135803,1.7187994718551636,10000.0,67679.66255092621,74341.20226407051,67679.66255092621,6646.628881692886,6.404725551605225,0.0 -147600,2.2591941,2.95615,,,,,,,,,,,,,, -147700,2.1987453,2.8491738,,,,,,,,,,,,,, -147800,2.082057,2.928907,,,,,,,,,,,,,, -147900,1.976083,3.06189,,,,,,,,,,,,,, -148000,2.896851,4.3732967,,,,,,,,,,,,,, -148100,2.1191914,2.8906868,,,,,,,,,,,,,, -148200,2.1703544,2.858984,,,,,,,,,,,,,, -148300,2.188386,3.8689384,,,,,,,,,,,,,, -148400,2.059691,3.056765,,,,,,,,,,,,,, -148421,,,0.8649804592132568,0.7262025475502014,0.7682799696922302,1.1214993000030518,50000.0,0.6528000235557556,1.7120332717895508,10000.0,68099.97631287575,74805.88531470299,68099.97631287575,6690.891248703003,6.458734512329102,0.0 -148500,2.627271,4.4148836,,,,,,,,,,,,,, -148600,2.668138,4.467204,,,,,,,,,,,,,, -148700,2.1816196,3.421948,,,,,,,,,,,,,, -148800,2.135955,3.0075636,,,,,,,,,,,,,, -148900,2.5186422,3.288029,,,,,,,,,,,,,, -149000,2.250828,3.0444672,,,,,,,,,,,,,, -149100,2.073131,3.7301452,,,,,,,,,,,,,, -149200,2.0718102,2.8966882,,,,,,,,,,,,,, -149300,2.091555,2.8333597,,,,,,,,,,,,,, -149335,,,0.8691601157188416,0.7210755944252014,0.7702999711036682,1.124349594116211,50000.0,0.651900053024292,1.711452603340149,10000.0,68519.64030075073,75268.590113163,68519.64030075073,6733.445414304733,6.892820119857788,0.0 -149400,2.282742,3.9422178,,,,,,,,,,,,,, -149500,2.1747,2.8530893,,,,,,,,,,,,,, -149600,2.1269855,3.6270761,,,,,,,,,,,,,, -149700,2.0859587,3.770271,,,,,,,,,,,,,, -149800,2.364579,3.9437342,,,,,,,,,,,,,, -149900,2.451696,4.287303,,,,,,,,,,,,,, -150000,2.2632391,3.276514,,,,,,,,,,,,,, -150100,1.9909984,2.8835251,,,,,,,,,,,,,, -150200,2.164974,2.9751987,,,,,,,,,,,,,, -150250,,,0.8727734088897705,0.6951795816421509,0.770859956741333,1.1124303340911863,50000.0,0.6508000493049622,1.714530348777771,10000.0,68939.89294409752,75734.91024708748,68939.89294409752,6779.409840583801,6.94247579574585,0.0 -150300,2.7889018,4.211799,,,,,,,,,,,,,, -150400,2.3009522,2.9145288,,,,,,,,,,,,,, -150500,1.9930227,3.0521684,,,,,,,,,,,,,, -150600,2.1966608,2.9015791,,,,,,,,,,,,,, -150700,2.3038144,2.9715412,,,,,,,,,,,,,, -150800,2.308714,4.0775557,,,,,,,,,,,,,, -150900,2.2102895,2.9778051,,,,,,,,,,,,,, -151000,2.300515,3.766277,,,,,,,,,,,,,, -151100,2.2721307,2.860087,,,,,,,,,,,,,, -151165,,,0.8675585985183716,0.7195205688476562,0.7733199596405029,1.1224658489227295,50000.0,0.6515000462532043,1.7132558822631836,10000.0,69359.92496109009,76195.36973547935,69359.92496109009,6819.731669664383,6.9958555698394775,0.0 -151200,2.3163548,2.9331486,,,,,,,,,,,,,, -151300,2.878591,4.365211,,,,,,,,,,,,,, -151400,2.1214907,2.8738313,,,,,,,,,,,,,, -151500,1.9979674,3.0046244,,,,,,,,,,,,,, -151600,2.2829864,2.8259554,,,,,,,,,,,,,, -151700,2.8265743,4.326472,,,,,,,,,,,,,, -151800,2.2130141,3.745979,,,,,,,,,,,,,, -151900,2.5257828,4.267113,,,,,,,,,,,,,, -152000,2.2389162,3.8578784,,,,,,,,,,,,,, -152079,,,0.8717772960662842,0.6925562620162964,0.7725399732589722,1.0992447137832642,50000.0,0.6582000255584717,1.690200686454773,10000.0,69779.90133500099,76659.01178598404,69779.90133500099,6863.285403966904,7.0559470653533936,0.0 -152100,2.1850555,3.5311477,,,,,,,,,,,,,, -152200,2.1885188,2.8749409,,,,,,,,,,,,,, -152300,2.3549361,3.65701,,,,,,,,,,,,,, -152400,2.264505,2.8677435,,,,,,,,,,,,,, -152500,2.2101603,2.9596844,,,,,,,,,,,,,, -152600,2.1181707,2.914987,,,,,,,,,,,,,, -152700,2.435812,4.0722165,,,,,,,,,,,,,, -152800,2.2221923,3.4312272,,,,,,,,,,,,,, -152900,2.2208705,3.3308551,,,,,,,,,,,,,, -152996,,,0.8743945360183716,0.6965427398681641,0.7730000019073486,1.1139885187149048,50000.0,0.6490000486373901,1.7110410928726196,10000.0,70199.87671470642,77117.18258023262,70199.87671470642,6901.371284723282,7.113300085067749,0.0 -153000,2.2081058,2.854379,,,,,,,,,,,,,, -153100,2.356225,2.901253,,,,,,,,,,,,,, -153200,2.1896868,2.9724646,,,,,,,,,,,,,, -153300,2.550318,4.1137743,,,,,,,,,,,,,, -153400,2.1194232,3.479763,,,,,,,,,,,,,, -153500,2.1564538,3.4816544,,,,,,,,,,,,,, -153600,2.0993764,3.151747,,,,,,,,,,,,,, -153700,2.250239,2.9221308,,,,,,,,,,,,,, -153800,2.5008223,3.9889495,,,,,,,,,,,,,, -153900,2.222052,3.5326676,,,,,,,,,,,,,, -153909,,,0.8716406226158142,0.702102541923523,0.7744199633598328,1.1024911403656006,50000.0,0.6552000045776367,1.6988980770111084,10000.0,70619.90300965309,77579.82703661919,70619.90300965309,6943.878223657608,7.172998905181885,0.0 -154000,2.221365,3.3483522,,,,,,,,,,,,,, -154100,2.1319382,3.2922156,,,,,,,,,,,,,, -154200,2.2446158,2.8151882,,,,,,,,,,,,,, -154300,2.6645372,4.3806605,,,,,,,,,,,,,, -154400,2.2267616,2.861504,,,,,,,,,,,,,, -154500,2.3603816,2.8560386,,,,,,,,,,,,,, -154600,2.2475235,2.9229288,,,,,,,,,,,,,, -154700,2.2411737,2.7916386,,,,,,,,,,,,,, -154800,2.1927931,2.9140313,,,,,,,,,,,,,, -154827,,,0.8707226514816284,0.7155790328979492,0.7752999663352966,1.1152511835098269,50000.0,0.6602000594139099,1.70155668258667,10000.0,71039.93932843208,78040.47797226906,71039.93932843208,6984.388374567032,7.224968194961548,0.0 -154900,3.2237666,4.3402567,,,,,,,,,,,,,, -155000,2.2137232,2.999068,,,,,,,,,,,,,, -155100,2.1796606,2.8530157,,,,,,,,,,,,,, -155200,2.30226,3.8199313,,,,,,,,,,,,,, -155300,2.2762265,2.9076214,,,,,,,,,,,,,, -155400,2.239269,2.8795028,,,,,,,,,,,,,, -155500,2.2785316,2.8725471,,,,,,,,,,,,,, -155600,2.3379996,2.8560874,,,,,,,,,,,,,, -155700,2.5233297,2.863867,,,,,,,,,,,,,, -155742,,,0.8770117163658142,0.6983265280723572,0.7755199670791626,1.116679549217224,50000.0,0.6535000205039978,1.7088316679000854,10000.0,71460.28168869019,78502.2566742897,71460.28168869019,7025.718888044357,7.279221773147583,0.0 -155800,2.3135788,2.78213,,,,,,,,,,,,,, -155900,2.3583837,3.4429622,,,,,,,,,,,,,, -156000,2.4556527,3.9520595,,,,,,,,,,,,,, -156100,2.2140253,3.0433629,,,,,,,,,,,,,, -156200,2.2947962,3.3549747,,,,,,,,,,,,,, -156300,2.2270534,2.818131,,,,,,,,,,,,,, -156400,2.3810475,2.8817863,,,,,,,,,,,,,, -156500,2.1517155,2.9680362,,,,,,,,,,,,,, -156600,2.433318,2.834969,,,,,,,,,,,,,, -156657,,,0.8760351538658142,0.6916414499282837,0.7758199572563171,1.1019681692123413,50000.0,0.657200038433075,1.6836694478988647,10000.0,71880.30394387245,78963.12360739708,71880.30394387245,7066.457926273346,7.332519292831421,0.0 -156700,3.7541177,4.3333097,,,,,,,,,,,,,, -156800,2.3341222,3.6864736,,,,,,,,,,,,,, -156900,2.2691832,2.775249,,,,,,,,,,,,,, -157000,2.3272452,2.8470345,,,,,,,,,,,,,, -157100,2.215056,3.3512974,,,,,,,,,,,,,, -157200,2.2091045,3.5671053,,,,,,,,,,,,,, -157300,2.4263623,2.842733,,,,,,,,,,,,,, -157400,2.4488153,2.833075,,,,,,,,,,,,,, -157500,2.2829602,2.7996197,,,,,,,,,,,,,, -157574,,,0.8784960508346558,0.6948211789131165,0.7772799730300903,1.1076263189315796,50000.0,0.6581000089645386,1.694753646850586,10000.0,72300.49660873413,79427.16787004471,72300.49660873413,7110.197211503983,7.3921473026275635,0.0 -157600,2.3149645,3.1578772,,,,,,,,,,,,,, -157700,2.316456,2.7916465,,,,,,,,,,,,,, -157800,2.358574,2.817347,,,,,,,,,,,,,, -157900,3.2079313,4.3416767,,,,,,,,,,,,,, -158000,2.287876,3.3763902,,,,,,,,,,,,,, -158100,2.4944425,2.9647167,,,,,,,,,,,,,, -158200,2.5529914,3.6909113,,,,,,,,,,,,,, -158300,3.2525744,3.5960162,,,,,,,,,,,,,, -158400,2.3042824,2.7647302,,,,,,,,,,,,,, -158489,,,0.8775194883346558,0.6820436716079712,0.776419997215271,1.095947504043579,50000.0,0.6564000248908997,1.6833994388580322,10000.0,72720.79914164543,79891.43614602089,72720.79914164543,7154.059904813767,7.443153619766235,0.0 -158500,3.122917,4.313107,,,,,,,,,,,,,, -158600,2.3976953,2.8284624,,,,,,,,,,,,,, -158700,2.4547215,2.7991476,,,,,,,,,,,,,, -158800,2.50197,2.8304336,,,,,,,,,,,,,, -158900,2.4618475,3.4338164,,,,,,,,,,,,,, -159000,2.3212092,2.849926,,,,,,,,,,,,,, -159100,2.288764,2.8171177,,,,,,,,,,,,,, -159200,2.5846272,3.2206767,,,,,,,,,,,,,, -159300,2.357779,3.2387686,,,,,,,,,,,,,, -159400,2.329463,2.8588552,,,,,,,,,,,,,, -159403,,,0.8854882717132568,0.6608580350875854,0.7779200077056885,1.1026121377944946,50000.0,0.6584000587463379,1.6912035942077637,10000.0,73141.05099487305,80355.4780535698,73141.05099487305,7197.746908187866,7.494960784912109,0.0 -159500,2.318888,2.7482495,,,,,,,,,,,,,, -159600,2.2580462,2.813951,,,,,,,,,,,,,, -159700,2.2987218,2.9814293,,,,,,,,,,,,,, -159800,3.3017378,4.288636,,,,,,,,,,,,,, -159900,2.3651602,2.7704191,,,,,,,,,,,,,, -160000,2.6619709,3.8649583,,,,,,,,,,,,,, -160100,3.1230426,4.2240415,,,,,,,,,,,,,, -160200,2.565746,2.8548563,,,,,,,,,,,,,, -160300,2.3457036,2.8389738,,,,,,,,,,,,,, -160318,,,0.8786523342132568,0.6797360181808472,0.778939962387085,1.0944668054580688,50000.0,0.6579000353813171,1.6871545314788818,10000.0,73561.01437163353,80821.08317613602,73561.01437163353,7243.284869670868,7.547072887420654,0.0 -160400,2.3724113,3.1922264,,,,,,,,,,,,,, -160500,2.2070181,2.7414367,,,,,,,,,,,,,, -160600,2.3557527,2.8568673,,,,,,,,,,,,,, -160700,2.3552518,2.817802,,,,,,,,,,,,,, -160800,2.4685745,3.1877275,,,,,,,,,,,,,, -160900,2.412676,3.2027183,,,,,,,,,,,,,, -161000,2.371883,3.022243,,,,,,,,,,,,,, -161100,2.5698118,2.9003015,,,,,,,,,,,,,, -161200,2.3542101,2.85593,,,,,,,,,,,,,, -161235,,,0.8811132907867432,0.6593165993690491,0.7783799767494202,1.0817325115203855,50000.0,0.6635000109672546,1.6645426750183103,10000.0,73980.93570017815,81278.25377750397,73980.93570017815,7280.42783331871,7.600719690322876,0.0 -161300,2.3090103,2.7949562,,,,,,,,,,,,,, -161400,2.9936035,4.239402,,,,,,,,,,,,,, -161500,2.419917,3.4092913,,,,,,,,,,,,,, -161600,2.4641168,3.264272,,,,,,,,,,,,,, -161700,2.4743512,2.7617092,,,,,,,,,,,,,, -161800,3.7215695,4.3125777,,,,,,,,,,,,,, -161900,2.3864324,3.4391415,,,,,,,,,,,,,, -162000,2.2595162,3.0153556,,,,,,,,,,,,,, -162100,2.3736343,2.7786872,,,,,,,,,,,,,, -162150,,,0.8886327743530273,0.6440286040306091,0.7780199646949768,1.0898090600967407,50000.0,0.6640000343322754,1.6774790287017822,10000.0,74401.1434378624,81742.06526184082,74401.1434378624,7323.926148414612,7.653084754943848,0.0 -162200,2.4203384,3.4146998,,,,,,,,,,,,,, -162300,2.450458,2.847176,,,,,,,,,,,,,, -162400,2.550951,3.5447588,,,,,,,,,,,,,, -162500,2.4657376,2.8187327,,,,,,,,,,,,,, -162600,2.3094208,2.671388,,,,,,,,,,,,,, -162700,2.632611,3.9043155,,,,,,,,,,,,,, -162800,2.2455819,3.1188555,,,,,,,,,,,,,, -162900,2.345354,2.8591022,,,,,,,,,,,,,, -163000,2.219523,3.0179353,,,,,,,,,,,,,, -163066,,,0.8836523294448853,0.6655579805374146,0.7795799970626831,1.0824384689331057,50000.0,0.6609000563621521,1.669062614440918,10000.0,74821.25588536263,82206.14687275887,74821.25588536263,7367.785813331604,7.709946155548096,0.0 -163100,2.398396,2.8036005,,,,,,,,,,,,,, -163200,2.570605,2.8124144,,,,,,,,,,,,,, -163300,2.2627873,2.8307123,,,,,,,,,,,,,, -163400,2.4406552,2.8135161,,,,,,,,,,,,,, -163500,2.4039598,2.8392956,,,,,,,,,,,,,, -163600,2.6366107,3.0016253,,,,,,,,,,,,,, -163700,2.6834185,2.8593338,,,,,,,,,,,,,, -163800,2.732771,3.8037095,,,,,,,,,,,,,, -163900,2.4909358,2.8416438,,,,,,,,,,,,,, -163980,,,0.8823828101158142,0.6626566648483276,0.7808399796485901,1.0874329805374146,50000.0,0.659500002861023,1.6772193908691406,10000.0,75241.35402417183,82670.88124489784,75241.35402417183,7412.314453601837,7.765376091003418,0.0 -164000,2.5725732,2.8094115,,,,,,,,,,,,,, -164100,2.3471506,2.8304567,,,,,,,,,,,,,, -164200,2.525984,2.7937949,,,,,,,,,,,,,, -164300,2.4654846,2.7809205,,,,,,,,,,,,,, -164400,2.426027,2.817333,,,,,,,,,,,,,, -164500,2.5522346,2.7681065,,,,,,,,,,,,,, -164600,2.6411464,2.7670689,,,,,,,,,,,,,, -164700,2.313427,2.927069,,,,,,,,,,,,,, -164800,2.4825273,2.8412974,,,,,,,,,,,,,, -164896,,,0.8887304663658142,0.6466950178146362,0.7809199690818787,1.0870565176010132,50000.0,0.6633000373840332,1.6752506494522097,10000.0,75661.63371706009,83134.75895619392,75661.63371706009,7455.80052113533,7.824825286865234,0.0 -164900,2.575584,3.6036472,,,,,,,,,,,,,, -165000,2.555883,2.7842689,,,,,,,,,,,,,, -165100,2.6120992,2.78377,,,,,,,,,,,,,, -165200,2.225407,3.4377732,,,,,,,,,,,,,, -165300,2.94507,4.0838833,,,,,,,,,,,,,, -165400,2.565496,2.789931,,,,,,,,,,,,,, -165500,2.5536385,2.7880137,,,,,,,,,,,,,, -165600,2.4865448,2.8195796,,,,,,,,,,,,,, -165700,2.5182345,2.8983061,,,,,,,,,,,,,, -165800,2.3461525,2.7555025,,,,,,,,,,,,,, -165811,,,0.8844726085662842,0.6560773849487305,0.7813000082969666,1.0782166719436646,50000.0,0.6655000448226929,1.6651965379714966,10000.0,76081.98018074036,83600.45110487938,76081.98018074036,7501.036018610001,7.883217334747314,0.0 -165900,2.7043452,2.8623328,,,,,,,,,,,,,, -166000,2.4696312,2.9927692,,,,,,,,,,,,,, -166100,2.5255797,3.109977,,,,,,,,,,,,,, -166200,2.5236187,2.727464,,,,,,,,,,,,,, -166300,2.3532927,2.7694464,,,,,,,,,,,,,, -166400,2.4208457,2.7851193,,,,,,,,,,,,,, -166500,2.8258464,3.9440577,,,,,,,,,,,,,, -166600,2.5133364,2.744901,,,,,,,,,,,,,, -166700,2.4149718,3.2479815,,,,,,,,,,,,,, -166727,,,0.8896484375,0.6387453079223633,0.7838599681854248,1.0713698863983154,50000.0,0.6671000123023987,1.654746413230896,10000.0,76502.31396532059,84059.78851270676,76502.31396532059,7539.935069799423,7.935632944107056,0.0 -166800,2.530588,2.7924335,,,,,,,,,,,,,, -166900,2.651802,2.7793548,,,,,,,,,,,,,, -167000,2.452021,3.5313938,,,,,,,,,,,,,, -167100,2.7046976,3.8706274,,,,,,,,,,,,,, -167200,2.464628,3.0601745,,,,,,,,,,,,,, -167300,3.1990306,3.957111,,,,,,,,,,,,,, -167400,2.4504564,2.756278,,,,,,,,,,,,,, -167500,2.549614,2.7048454,,,,,,,,,,,,,, -167600,3.2699456,4.1710505,,,,,,,,,,,,,, -167644,,,0.8884179592132568,0.649350643157959,0.7827999591827393,1.0806797742843628,50000.0,0.6630000472068787,1.6625733375549316,10000.0,76922.55463337898,84524.23124790192,76922.55463337898,7584.034651994705,7.9859490394592285,0.0 -167700,2.53602,2.941904,,,,,,,,,,,,,, -167800,2.4966893,3.5609186,,,,,,,,,,,,,, -167900,2.479311,2.7588074,,,,,,,,,,,,,, -168000,2.700583,2.8013744,,,,,,,,,,,,,, -168100,2.5922573,2.776106,,,,,,,,,,,,,, -168200,2.4177263,2.8812778,,,,,,,,,,,,,, -168300,2.4537501,2.7938735,,,,,,,,,,,,,, -168400,2.5043561,3.005255,,,,,,,,,,,,,, -168500,3.005552,2.713958,,,,,,,,,,,,,, -168561,,,0.8898632526397705,0.6395523548126221,0.7828800082206726,1.0810366868972778,50000.0,0.666700005531311,1.6670899391174316,10000.0,77342.60869932175,84987.16055512428,77342.60869932175,7626.804463386536,8.039025783538818,0.0 -168600,2.5570562,3.4281545,,,,,,,,,,,,,, -168700,2.5084963,2.740342,,,,,,,,,,,,,, -168800,2.490861,2.8732605,,,,,,,,,,,,,, -168900,2.3563287,2.7455316,,,,,,,,,,,,,, -168954,,,,,,,,,,,77520.17618727684,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 857e95a00..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -30.9212863445282,0.0,34.48550629615784,1,0,34.48550629615784,0.0010000000474974,6.907756805419922,10000,65.40688729286194,0.0011328124674037,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -76.71136093139648,0.017364501953125,454.543984413147,853,0,454.543984413147,0.0105000007897615,6.523868083953857,10000,531.3210079669952,0.0112499995157122,6.477405071258545,0.0108599998056888,6.486855030059815,50000 -122.98103356361388,0.0453393459320068,874.770247220993,1759,0,874.770247220993,0.0279000010341405,6.0221991539001465,10000,997.8961570262908,0.0346093736588954,5.880190372467041,0.0334799997508525,5.913444995880127,50000 -169.67155838012695,0.0765357017517089,1294.968270778656,2673,0,1294.968270778656,0.0459000021219253,5.680256366729736,10000,1464.868360042572,0.0605859346687793,5.4684343338012695,0.0572999976575374,5.506952285766602,50000 -218.1023845672608,0.1063158512115478,1715.2244164943695,3588,0,1715.2244164943695,0.0681000053882598,5.376662731170654,10000,1933.6374650001528,0.0949414074420929,5.101712703704834,0.0859000012278556,5.135260105133057,50000 -264.62031650543213,0.1340563297271728,2135.2453899383545,4504,0,2135.2453899383545,0.0938000008463859,5.061717510223389,10000,2400.255875349045,0.1316992193460464,4.694127559661865,0.1209999993443489,4.762543678283691,50000 -311.2709410190582,0.161536693572998,2555.4077939987183,5416,0,2555.4077939987183,0.1170000061392784,4.778557300567627,10000,2867.1481182575226,0.1699804663658142,4.376136302947998,0.1554600000381469,4.456610679626465,50000 -360.0176274776459,0.192075490951538,2975.353636741638,6332,0,2975.353636741638,0.1517000049352646,4.465686321258545,10000,3335.9222333431244,0.2166210860013961,4.001551628112793,0.2006799876689911,4.088918209075928,50000 -404.859354019165,0.2183105945587158,3395.407825946808,7248,0,3395.407825946808,0.1785000115633011,4.246798515319824,10000,3800.8962712287903,0.2599609196186065,3.707098245620728,0.2382399886846542,3.8109381198883057,50000 -453.3918721675873,0.245370864868164,3815.699334621429,8165,0,3815.699334621429,0.2092000097036361,4.015152454376221,10000,4269.799302577972,0.2924023270606994,3.420924186706543,0.2739599943161011,3.5337398052215576,50000 -500.3881521224976,0.2754819393157959,4235.792031049728,9082,0,4235.792031049728,0.2371000051498413,3.8389713764190674,10000,4736.970528125763,0.3405078053474426,3.166249990463257,0.3056800067424774,3.3514530658721924,50000 -549.8398551940918,0.3025200366973877,4656.01069688797,9997,0,4656.01069688797,0.2614000141620636,3.675516605377197,10000,5206.719424247742,0.3578320145606994,3.022150754928589,0.3312000036239624,3.1509346961975098,50000 -599.9115297794342,0.3341560363769531,5076.158985376358,10912,0,5076.158985376358,0.2733000218868255,3.5355544090271,10000,5677.023521661758,0.3854492008686065,2.840325355529785,0.3569599986076355,2.995655536651612,50000 -647.2774829864502,0.3645823001861572,5496.442660808563,11830,0,5496.442660808563,0.2883000075817108,3.481478691101074,10000,6144.755959510803,0.4166601598262787,2.7132813930511475,0.3746599853038788,2.916720151901245,50000 -696.8378114700317,0.3950300216674804,5916.627821445465,12746,0,5916.627821445465,0.3023000061511993,3.4036993980407715,10000,6614.584210395813,0.4178515672683716,2.683227777481079,0.3882399797439575,2.8298232555389404,50000 -740.1960797309875,0.4252433776855469,6336.897862434387,13657,0,6336.897862434387,0.3061000108718872,3.3392672538757324,10000,7078.295861721039,0.4305664002895355,2.5879738330841064,0.4039999842643738,2.7390167713165283,50000 -790.636931180954,0.457782506942749,6757.152330160141,14574,0,6757.152330160141,0.317900002002716,3.2573442459106445,10000,7549.075749397278,0.4543749988079071,2.460162401199341,0.4130399823188782,2.6595137119293213,50000 -838.7734625339508,0.4854211807250976,7177.283156871796,15489,0,7177.283156871796,0.3285000026226043,3.233921527862549,10000,8017.423315048218,0.4491015672683716,2.497513055801392,0.4224399924278259,2.632542133331299,50000 -887.2250754833221,0.5128743648529053,7597.375913381576,16404,0,7597.375913381576,0.3423000276088714,3.1236488819122314,10000,8486.046859264374,0.4757421910762787,2.3351094722747803,0.44132000207901,2.5121448040008545,50000 -934.1244015693665,0.5392537117004395,8017.678372621536,17319,0,8017.678372621536,0.3389000296592712,3.1433703899383545,10000,8953.327250957489,0.4820898473262787,2.317319869995117,0.4428799748420715,2.516441583633423,50000 -978.050395488739,0.5683753490447998,8437.92775630951,18234,0,8437.92775630951,0.343500018119812,3.1278634071350098,10000,9417.584113836288,0.4768359363079071,2.342108249664306,0.4420199990272522,2.51889705657959,50000 -1028.396469831467,0.5990426540374756,8857.909608125687,19148,0,8857.909608125687,0.3586000204086303,3.0742838382720947,10000,9887.99471116066,0.4917773306369781,2.2599568367004395,0.4593999981880188,2.4413387775421143,50000 -1078.1956298351288,0.6277709007263184,9277.858618497849,20064,0,9277.858618497849,0.3616000115871429,3.0053672790527344,10000,10357.822923898697,0.509082019329071,2.185120820999145,0.4652799963951111,2.3844094276428223,50000 -1125.6968231201172,0.6573843955993652,9698.17512512207,20980,0,9698.17512512207,0.3641000092029571,2.975359439849853,10000,10825.721643447876,0.53369140625,2.057878255844116,0.4750799834728241,2.334637403488159,50000 -1172.9720392227173,0.6915838718414307,10118.501426696776,21897,0,10118.501426696776,0.3733000159263611,2.9646759033203125,10000,11293.408746242523,0.507519543170929,2.1788175106048584,0.4765599966049194,2.338596105575561,50000 -1221.914003610611,0.729525089263916,10538.751983642578,22814,0,10538.751983642578,0.3775000274181366,2.9053289890289307,10000,11762.698380231855,0.5238476395606995,2.071101427078247,0.4860999882221222,2.2656033039093018,50000 -1267.7731275558472,0.766730546951294,10958.97521162033,23730,0,10958.97521162033,0.3871000111103058,2.8293137550354004,10000,12228.869998455048,0.5507421493530273,1.938953518867493,0.4956399798393249,2.2019717693328857,50000 -1317.8803231716156,0.8011837005615234,11379.114792823792,24647,0,11379.114792823792,0.3889000117778778,2.873844623565674,10000,12699.20350074768,0.5287109017372131,2.059361219406128,0.4931999742984772,2.238321542739868,50000 -1366.6767621040344,0.8305325508117676,11799.398998975754,25563,0,11799.398998975754,0.3960000276565552,2.799640893936157,10000,13168.366040945051,0.5419335961341858,1.988563776016236,0.5034799575805664,2.1718690395355225,50000 -1415.4040472507477,0.8615057468414307,12219.34901380539,26477,0,12219.34901380539,0.3974000215530395,2.7761967182159424,10000,13637.125775814056,0.5632421970367432,1.8941307067871087,0.5120199918746948,2.144299268722534,50000 -1464.6509912014008,0.891242265701294,12639.654458522797,27392,0,12639.654458522797,0.4012000262737274,2.763537645339966,10000,14106.759474277496,0.5482421517372131,1.9673892259597776,0.5123999714851379,2.13213849067688,50000 -1512.441159248352,0.9247598648071288,13059.61026597023,28308,0,13059.61026597023,0.4006000161170959,2.744941473007202,10000,14574.590921640396,0.5547069907188416,1.929384708404541,0.5189200043678284,2.1071009635925293,50000 -1563.1740138530731,0.9577534198760986,13479.542765378952,29224,0,13479.542765378952,0.4052000045776367,2.7502341270446777,10000,15045.342432498932,0.5635741949081421,1.887037754058838,0.5194000005722046,2.098005533218384,50000 -1611.8296627998352,0.990304946899414,13899.718740701675,30141,0,13899.718740701675,0.4148000180721283,2.698044538497925,10000,15514.25835442543,0.5602734088897705,1.8773443698883057,0.5280199646949768,2.0621399879455566,50000 -1660.1082208156586,1.0202922821044922,14319.735169649124,31057,0,14319.735169649124,0.415800005197525,2.691118001937866,10000,15982.635932683945,0.5613867044448853,1.895029664039612,0.526199996471405,2.0695595741271973,50000 -1708.6653501987455,1.0565845966339111,14739.937356710434,31972,0,14739.937356710434,0.4119000136852264,2.7069880962371826,10000,16451.484882354736,0.5696093440055847,1.8706759214401243,0.5287600159645081,2.060586452484131,50000 -1757.6149718761444,1.0907628536224363,15160.13657617569,32890,0,15160.13657617569,0.4160000085830688,2.677694797515869,10000,16920.72052717209,0.5985351204872131,1.70049250125885,0.5352999567985535,2.021901130676269,50000 -1806.4245445728304,1.1232614517211914,15580.376125097277,33807,0,15580.376125097277,0.4164000153541565,2.689453363418579,10000,17389.854667186737,0.567578136920929,1.879085898399353,0.5329399704933167,2.052635431289673,50000 -1852.8152103424072,1.1534223556518557,16000.444960832596,34724,0,16000.444960832596,0.4126000106334686,2.657298803329468,10000,17856.395292282104,0.5770702958106995,1.802340149879456,0.5371999740600586,1.9966063499450684,50000 -1901.7003903388977,1.184575080871582,16420.363413095474,35638,0,16420.363413095474,0.4256000220775604,2.619293212890625,10000,18325.28321290016,0.5990234017372131,1.7254287004470823,0.5380600094795227,2.003264904022217,50000 -1950.9113097190857,1.2209351062774658,16840.63138628006,36554,0,16840.63138628006,0.4204000234603882,2.6431541442871094,10000,18794.849648475647,0.575488269329071,1.8254637718200684,0.5359799861907959,2.010288715362549,50000 -1993.4251432418823,1.2563939094543457,17260.585029363632,37469,0,17260.585029363632,0.4323000311851501,2.6027326583862305,10000,19257.40512537956,0.5883203148841858,1.7518789768218994,0.5458799600601196,1.955213069915772,50000 -2041.5583720207208,1.2973315715789795,17680.8397500515,38384,0,17680.8397500515,0.4346000254154205,2.583150625228882,10000,19725.88530659676,0.6019335985183716,1.7069507837295532,0.5504199862480164,1.9425902366638184,50000 -2087.815386772156,1.333068609237671,18101.13681983948,39301,0,18101.13681983948,0.4333000183105469,2.616333484649658,10000,20192.52774953842,0.5848632454872131,1.7843482494354248,0.545199990272522,1.975147485733032,50000 -2136.272389173508,1.364790678024292,18521.130932092667,40218,0,18521.130932092667,0.4305000305175781,2.606301069259644,10000,20661.06382799149,0.5871874690055847,1.7634190320968628,0.5512199997901917,1.9427509307861328,50000 -2185.132310628891,1.3997371196746826,18941.098199129105,41135,0,18941.098199129105,0.4401000142097473,2.572782754898072,10000,21129.979234695435,0.597851574420929,1.7207162380218506,0.5503999590873718,1.93899405002594,50000 -2230.481188297272,1.4327614307403564,19361.50342464447,42051,0,19361.50342464447,0.4354000091552734,2.5808470249176025,10000,21595.81792283058,0.5936523079872131,1.7368556261062622,0.5519399642944336,1.9425498247146609,50000 -2278.644249200821,1.4757418632507324,19781.703052520752,42966,0,19781.703052520752,0.4368000328540802,2.5653188228607178,10000,22064.27545189857,0.5927538871765137,1.7270139455795288,0.5557399988174438,1.9057121276855469,50000 -2326.826430082321,1.5156099796295166,20201.69106078148,43882,0,20201.69106078148,0.4455000162124634,2.518744707107544,10000,22532.53761100769,0.6057812571525574,1.6457350254058838,0.5611599683761597,1.869145750999451,50000 -2373.602608203888,1.5518851280212402,20622.27797460556,44796,0,20622.27797460556,0.4451000094413757,2.5309979915618896,10000,22999.989002227783,0.6296679377555847,1.5705506801605225,0.5589399933815002,1.888720750808716,50000 -2420.3978378772736,1.5879981517791748,21042.34734320641,45710,0,21042.34734320641,0.4449000358581543,2.515469789505005,10000,23466.94051671028,0.6051172018051147,1.6690466403961182,0.5628399848937988,1.861746311187744,50000 -2468.4884293079376,1.6240994930267334,21462.64211010933,46623,0,21462.64211010933,0.4453000128269195,2.5359339714050293,10000,23935.413821458817,0.6066796779632568,1.6673763990402222,0.5652799606323242,1.8668668270111084,50000 -2515.0421693325043,1.66135573387146,21882.84224438668,47539,0,21882.84224438668,0.4479000270366668,2.509688377380371,10000,24402.25755739212,0.626269519329071,1.5778939723968506,0.5616999864578247,1.8711748123168943,50000 -2561.5461843013763,1.692392110824585,22303.119409799576,48455,0,22303.119409799576,0.4543000161647796,2.494823455810547,10000,24869.121667146683,0.6093554496765137,1.643090009689331,0.5676599740982056,1.8277660608291624,50000 -2609.5553126335144,1.7268388271331787,22723.300478219982,49371,0,22723.300478219982,0.4473000168800354,2.480165243148804,10000,25337.39897251129,0.611132800579071,1.6391677856445312,0.5672399997711182,1.839793682098389,50000 -2656.3447353839874,1.7708237171173096,23143.49778413773,50287,0,23143.49778413773,0.4564000070095062,2.477119445800781,10000,25804.48171377182,0.6187109351158142,1.5916036367416382,0.5705599784851074,1.8403029441833496,50000 -2703.257434368133,1.8078598976135247,23563.565361738205,51201,0,23563.565361738205,0.457800030708313,2.448241949081421,10000,26271.550952911377,0.6166015267372131,1.6008018255233765,0.5757399797439575,1.7892749309539795,50000 -2752.9996156692505,1.848759651184082,23983.790986299515,52116,0,23983.790986299515,0.4490000307559967,2.522512912750244,10000,26741.61191439629,0.6089453101158142,1.667687177658081,0.5686599612236023,1.8564475774765008,50000 -2797.968763113022,1.881891727447509,24403.7320561409,53032,0,24403.7320561409,0.4565000236034393,2.4617958068847656,10000,27206.607456207275,0.6204296946525574,1.5840575695037842,0.5726199746131897,1.805444836616516,50000 -2847.4651761055,1.9169471263885496,24823.81976938248,53949,0,24823.81976938248,0.4558000266551971,2.460911512374878,10000,27676.279418468475,0.6197265386581421,1.6068668365478516,0.5780799984931946,1.7958285808563232,50000 -2890.782431602478,1.950938701629639,25243.79498982429,54863,0,25243.79498982429,0.4589000344276428,2.458005428314209,10000,28139.657998085026,0.6193749904632568,1.6178147792816162,0.5789799690246582,1.8013064861297607,50000 -2937.7285718917847,1.9914829730987549,25663.81897115708,55774,0,25663.81897115708,0.4631000161170959,2.472367525100708,10000,28606.72021007538,0.6284765601158142,1.5806477069854736,0.5812999606132507,1.804103970527649,50000 -2985.750026702881,2.033883810043335,26084.16915845871,56688,0,26084.16915845871,0.4606000185012817,2.447606086730957,10000,29075.186302661896,0.6417773365974426,1.5043836832046509,0.578719973564148,1.7915959358215332,50000 -3033.001730442047,2.4678242206573486,26504.19334626197,57601,0,26504.19334626197,0.4669000208377838,2.421680450439453,10000,29542.947952270508,0.6244726181030273,1.5794726610183716,0.5821599960327148,1.773970603942871,50000 -3080.904748916626,2.5036418437957764,26924.44057393074,58519,0,26924.44057393074,0.4695000350475311,2.407954692840576,10000,30011.18598389625,0.6330859065055847,1.5414679050445557,0.5845400094985962,1.7692793607711792,50000 -3128.893353700638,2.5405941009521484,27344.42973256111,59431,0,27344.42973256111,0.4673000276088714,2.3970510959625244,10000,30479.25344944,0.6523241996765137,1.4520971775054932,0.5870400071144104,1.7456282377243042,50000 -3177.135461807251,2.592692852020264,27764.40898680687,60346,0,27764.40898680687,0.4643000364303589,2.428161144256592,10000,30947.57905244828,0.6209570169448853,1.5848476886749268,0.58051997423172,1.7770845890045166,50000 -3224.825823068619,2.629995107650757,28184.32632660865,61260,0,28184.32632660865,0.4652000367641449,2.4296584129333496,10000,31415.275759458546,0.6322460770606995,1.57372784614563,0.5851399898529053,1.7762945890426636,50000 -3274.260570049286,2.6725854873657227,28604.31938052177,62175,0,28604.31938052177,0.4636000096797943,2.413270711898804,10000,31884.798108816147,0.6433984041213989,1.473775863647461,0.5825200080871582,1.7541874647140503,50000 -3323.5141406059265,2.708783864974976,29024.578023433685,63090,0,29024.578023433685,0.4662000238895416,2.3992021083831787,10000,32354.39852118492,0.6314452886581421,1.5430465936660769,0.5873000025749207,1.7412338256835938,50000 -3366.8562116622925,2.7485859394073486,29444.48857831955,64006,0,29444.48857831955,0.4687000215053558,2.39388108253479,10000,32817.74343061447,0.6298632621765137,1.5495365858078003,0.5903800129890442,1.736882567405701,50000 -3414.1854150295258,2.78342866897583,29864.842218399048,64921,0,29864.842218399048,0.4788000285625458,2.3833441734313965,10000,33285.5135974884,0.6470116972923279,1.4817765951156616,0.5940399765968323,1.731095790863037,50000 -3460.025162935257,2.822338342666626,30285.16464781761,65839,0,30285.16464781761,0.4762000143527984,2.3636457920074463,10000,33751.76726102829,0.64111328125,1.4903844594955444,0.5980600118637085,1.6916265487670898,50000 -3510.1349868774414,2.8638916015625,30705.16337299347,66752,0,30705.16337299347,0.4765000343322754,2.350003242492676,10000,34221.96933102608,0.6405078172683716,1.486151099205017,0.6005600094795227,1.6860800981521606,50000 -3558.282987117768,2.9099161624908447,31125.316086292267,67668,0,31125.316086292267,0.4780000150203705,2.372490167617798,10000,34690.3684053421,0.6430078148841858,1.4980931282043457,0.5955199599266052,1.7154990434646606,50000 -3605.6924924850464,2.954616069793701,31545.59892988205,68581,0,31545.59892988205,0.4852000176906585,2.3231451511383057,10000,35158.15919518471,0.662304699420929,1.3984365463256836,0.599399983882904,1.6794729232788086,50000 -3654.193194627762,2.99135160446167,31965.907709360123,69499,0,31965.907709360123,0.4850000143051147,2.3373358249664307,10000,35627.05697226524,0.6392773389816284,1.4914463758468628,0.6018999814987183,1.6752866506576538,50000 -3703.6968002319336,3.0350594520568848,32385.870503902435,70415,0,32385.870503902435,0.4850000143051147,2.3356761932373047,10000,36096.61861395836,0.647753894329071,1.4472213983535769,0.5977199673652649,1.679577112197876,50000 -3752.7585911750793,3.072389841079712,32805.93040370941,71331,0,32805.93040370941,0.4736000299453735,2.3584470748901367,10000,36565.83061218262,0.6681054830551147,1.376840353012085,0.5995399951934814,1.685808539390564,50000 -3800.581439733505,3.1176953315734863,33226.05816960335,72246,0,33226.05816960335,0.4778000116348266,2.3342955112457275,10000,37033.87846469879,0.6423632502555847,1.4910801649093628,0.5996599793434143,1.6887861490249634,50000 -3846.234811067581,3.1611366271972656,33646.26445245743,73161,0,33646.26445245743,0.4845000207424164,2.3093228340148926,10000,37499.833355903625,0.6495702862739563,1.450307846069336,0.604200005531311,1.6661208868026731,50000 -3897.773527622223,3.20788049697876,34066.458422899246,74075,0,34066.458422899246,0.4880000352859497,2.348051309585572,10000,37971.66522574425,0.662890613079071,1.4327744245529177,0.6055799722671509,1.702540636062622,50000 -3942.385674238205,3.2509477138519287,34486.554055690765,74988,0,34486.554055690765,0.4859000146389007,2.3130125999450684,10000,38436.4689707756,0.6476953029632568,1.45093834400177,0.6082199811935425,1.6413438320159912,50000 -3990.819508075714,3.2981767654418945,34906.55943346024,75899,0,34906.55943346024,0.4792000353336334,2.326888084411621,10000,38905.00736904144,0.6452929377555847,1.4677882194519043,0.6007599830627441,1.683293342590332,50000 -4038.497713804245,3.342529058456421,35326.47755908966,76813,0,35326.47755908966,0.48580002784729,2.317957162857056,10000,39372.69943475723,0.6568945050239563,1.4267654418945312,0.6042400002479553,1.6688212156295776,50000 -4083.765850067138,3.3802947998046875,35746.67924427986,77728,0,35746.67924427986,0.4881000220775604,2.274594783782959,10000,39838.25917339325,0.6592968702316284,1.4043962955474854,0.6110000014305115,1.6132807731628418,50000 -4132.717834472656,3.4168825149536133,36166.756766080856,78644,0,36166.756766080856,0.4855000376701355,2.3392891883850098,10000,40307.37702679634,0.6494140625,1.4841097593307495,0.6038999557495117,1.6853262186050415,50000 -4179.737259864807,3.457533836364746,36586.85572004318,79559,0,36586.85572004318,0.4878000319004059,2.2785677909851074,10000,40774.5886554718,0.6655663847923279,1.3757052421569824,0.6121199727058411,1.610224366188049,50000 -4225.793916940689,3.5003437995910645,37007.06138277054,80474,0,37007.06138277054,0.4944000244140625,2.2579493522644043,10000,41240.94586133957,0.6681054830551147,1.3617504835128784,0.6156399846076965,1.5994772911071775,50000 -4274.636933803558,3.5497610569000244,37427.10963559151,81388,0,37427.10963559151,0.4949000179767608,2.2725863456726074,10000,41709.9388833046,0.6598241925239563,1.4070793390274048,0.6183599829673767,1.610404133796692,50000 -4320.90106010437,3.595423936843872,37847.287464141846,82301,0,37847.287464141846,0.4942000210285187,2.2834160327911377,10000,42176.4779920578,0.6676562428474426,1.3960615396499634,0.6139400005340576,1.63629150390625,50000 -4367.9880521297455,3.6356606483459473,38267.64587044716,83217,0,38267.64587044716,0.4937000274658203,2.265681266784668,10000,42644.01581954956,0.685742199420929,1.2981655597686768,0.6175999641418457,1.6042635440826416,50000 -4417.275773525238,3.674084186553955,38688.0792453289,84133,0,38688.0792453289,0.4999000132083893,2.265106439590454,10000,43113.82888507843,0.6615625023841858,1.408983588218689,0.617680013179779,1.613947510719299,50000 -4466.535435676575,3.7140793800354,39108.14068603516,85049,0,39108.14068603516,0.5058000087738037,2.244614601135254,10000,43583.24208950997,0.6717578172683716,1.3764458894729614,0.6211400032043457,1.6145461797714231,50000 -4512.167064666748,3.7566189765930176,39528.49926805496,85964,0,39528.49926805496,0.5038000345230103,2.207498550415039,10000,44049.32660079002,0.6827148199081421,1.2997052669525146,0.6204400062561035,1.5726711750030518,50000 -4559.381934642792,3.8039591312408447,39948.76182985306,86878,0,39948.76182985306,0.5035000443458557,2.202033042907715,10000,44516.90362381935,0.6692578196525574,1.3671681880950928,0.6223799586296082,1.5760130882263184,50000 -4606.599714756012,3.843261957168579,40369.03018307686,87793,0,40369.03018307686,0.5076000094413757,2.2284297943115234,10000,44984.48181915283,0.6750195026397705,1.3554376363754272,0.6269999742507935,1.583235740661621,50000 -4652.663172245026,3.891881942749024,40789.127017498016,88707,0,40789.127017498016,0.501800000667572,2.2088680267333984,10000,45450.74204015732,0.6813281178474426,1.3070732355117798,0.6232199668884277,1.5689715147018433,50000 -4702.275573730469,3.930466890335083,41209.09097337723,89620,0,41209.09097337723,0.5082000494003296,2.181424617767334,10000,45920.40969824791,0.6719921827316284,1.342168211936951,0.6292600035667419,1.5393462181091309,50000 -4748.239414215088,3.968465805053711,41629.03953385353,90534,0,41629.03953385353,0.504800021648407,2.219699621200561,10000,46386.41199302673,0.6768749952316284,1.3328397274017334,0.6259399652481079,1.5650203227996826,50000 -4797.702326059341,4.011040687561035,42049.08965468407,91446,0,42049.08965468407,0.511400043964386,2.195029735565185,10000,46856.01924753189,0.6907030940055847,1.2894749641418457,0.6323599815368652,1.547354221343994,50000 -4845.945533275604,4.053576469421387,42469.318239450455,92359,0,42469.318239450455,0.5012000203132629,2.218226194381714,10000,47324.5852496624,0.6765429377555847,1.3410677909851074,0.6250799894332886,1.5731840133666992,50000 -4894.030555963516,4.096000909805298,42889.63666677475,93276,0,42889.63666677475,0.5135000348091125,2.1676101684570312,10000,47793.083832740784,0.68115234375,1.303971529006958,0.6333000063896179,1.5291578769683838,50000 -4941.8448095321655,4.140573740005493,43309.85211658478,94192,0,43309.85211658478,0.5126000046730042,2.1976771354675293,10000,48261.21060657501,0.6882030963897705,1.2961716651916504,0.6321399807929993,1.5467734336853027,50000 -4991.971055984497,4.181777000427246,43730.018189907074,95110,0,43730.018189907074,0.5103999972343445,2.1727511882781982,10000,48731.59634041786,0.7078515291213989,1.1914904117584229,0.6358000040054321,1.5101193189620972,50000 -5037.990777015686,4.222255706787109,44150.107147455215,96025,0,44150.107147455215,0.5094000101089478,2.1757290363311768,10000,49197.796854019165,0.6827734112739563,1.2824567556381226,0.6331599950790405,1.5004512071609497,50000 -5085.7554042339325,4.266266822814941,44570.26457285881,96938,0,44570.26457285881,0.5078999996185303,2.185168504714966,10000,49665.81526470184,0.6849414110183716,1.297351360321045,0.6337800025939941,1.5296707153320312,50000 -5135.472050905228,4.316110610961914,44990.63896560669,97853,0,44990.63896560669,0.5216000080108643,2.120568037033081,10000,50136.00769329071,0.706347644329071,1.1973272562026978,0.6412000060081482,1.4900233745574951,50000 -5182.138915061951,4.366321802139282,45410.94944810867,98770,0,45410.94944810867,0.5159000158309937,2.161956548690796,10000,50603.08851194382,0.6881640553474426,1.2870441675186155,0.6415799856185913,1.4922298192977903,50000 -5231.758380651474,4.412039041519165,45831.12450551987,99686,0,45831.12450551987,0.5254999995231628,2.105508327484131,10000,51072.9808216095,0.6941796541213989,1.2387899160385132,0.6456999778747559,1.462598204612732,50000 -5277.18566441536,4.834028720855713,46250.72084593773,100594,0,46250.72084593773,0.5175999999046326,2.107346296310425,10000,51538.47866082192,0.7018945217132568,1.217579960823059,0.642799973487854,1.4809935092926023,50000 -5326.998807668686,4.8833067417144775,46670.82105779648,101508,0,46670.82105779648,0.5198000073432922,2.138503074645996,10000,52008.493248701096,0.6890038847923279,1.2767558097839355,0.643839955329895,1.4891250133514404,50000 -5376.259582996368,4.928251028060913,47091.16972374916,102425,0,47091.16972374916,0.5297000408172607,2.082939863204956,10000,52478.20104813576,0.7002929449081421,1.2077990770339966,0.6493200063705444,1.445142149925232,50000 -5424.322031259537,4.9686126708984375,47511.17261624336,103341,0,47511.17261624336,0.5266000032424927,2.116064548492432,10000,52946.35978603363,0.7048437595367432,1.213708758354187,0.647599995136261,1.473738670349121,50000 -5473.918003559113,5.021668195724487,47931.445341825485,104258,0,47931.445341825485,0.5238000154495239,2.127571582794189,10000,53416.33458185196,0.6996484398841858,1.24612295627594,0.6481599807739258,1.4764598608016968,50000 -5520.325870513916,5.064110040664673,48351.39259338379,105170,0,48351.39259338379,0.5285000205039978,2.1071197986602783,10000,53882.78431844711,0.701464831829071,1.2353618144989014,0.6521399617195129,1.4591549634933472,50000 -5566.9127151966095,5.108115911483765,48771.49581623077,106082,0,48771.49581623077,0.5339000225067139,2.0650956630706787,10000,54349.56981277466,0.708984375,1.166994333267212,0.6547200083732605,1.4128574132919312,50000 -5616.92050743103,5.152854442596436,49191.43444299698,106995,0,49191.43444299698,0.5242000222206116,2.100848436355591,10000,54819.61263132095,0.7225390672683716,1.131512999534607,0.6502199769020081,1.4502038955688477,50000 -5668.385055780411,5.196820497512817,49611.7597925663,107912,0,49611.7597925663,0.5339000225067139,2.047857999801636,10000,55291.49946713448,0.7108789086341858,1.1651406288146973,0.6574999690055847,1.407038688659668,50000 -5717.807471752167,5.240591287612915,50032.11195850372,108827,0,50032.11195850372,0.5293000340461731,2.086609125137329,10000,55761.36970353127,0.7103906273841858,1.18244469165802,0.6546799540519714,1.4344879388809204,50000 -5765.595708608627,5.286548137664795,50452.2932767868,109742,0,50452.2932767868,0.5376999974250793,2.0645999908447266,10000,56229.43710780144,0.7229882478713989,1.1345579624176023,0.6586799621582031,1.426134467124939,50000 -5811.417835235596,5.338428258895874,50872.51973128319,110658,0,50872.51973128319,0.5340999960899353,2.095458984375,10000,56695.58960843086,0.7133398056030273,1.1982223987579346,0.6570599675178528,1.440184235572815,50000 -5859.738709926605,5.385877132415772,51292.6597931385,111575,0,51292.6597931385,0.5422000288963318,2.019754409790039,10000,57164.1505010128,0.7201171517372131,1.123279690742493,0.6644399762153625,1.3742200136184692,50000 -5904.406593084335,5.429534912109375,51712.80060315132,112491,0,51712.80060315132,0.5380000472068787,2.0178375244140625,10000,57629.05432772637,0.7299218773841858,1.0928118228912354,0.6615599989891052,1.383098602294922,50000 -5952.510481357575,5.479061126708984,52132.85252594948,113407,0,52132.85252594948,0.5384000539779663,2.007035255432129,10000,58097.311703681946,0.7118163704872131,1.1474217176437378,0.6642599701881409,1.3698629140853882,50000 -5997.965879917145,5.521757125854492,52552.87658810616,114319,0,52552.87658810616,0.5433000326156616,2.0026955604553223,10000,58562.885924339294,0.7221288681030273,1.117608666419983,0.6670199632644653,1.3616474866867063,50000 -6047.270622730255,5.568381786346436,52973.100821495056,115236,0,52973.100821495056,0.5494000315666199,2.007647752761841,10000,59032.51432132721,0.7292773127555847,1.092387318611145,0.6671000123023987,1.3630927801132202,50000 -6095.24760055542,5.615738153457642,53393.11542224884,116150,0,53393.11542224884,0.5452000498771667,1.99252724647522,10000,59500.606415987015,0.7247265577316284,1.1091890335083008,0.6702199578285217,1.354252815246582,50000 -6142.140088558197,5.661764621734619,53813.10630583763,117066,0,53813.10630583763,0.5508000254631042,1.99046790599823,10000,59967.58908033371,0.7264843583106995,1.096244215965271,0.672760009765625,1.340212106704712,50000 -6192.417934656143,5.713583707809448,54233.0949075222,117982,0,54233.0949075222,0.5498000383377075,1.969029784202576,10000,60437.959518909454,0.7369335889816284,1.0518841743469238,0.6758599877357483,1.3248143196105957,50000 -6240.577670812607,5.76579213142395,54653.28877854347,118897,0,54653.28877854347,0.5490000247955322,2.009008407592773,10000,60906.417788267136,0.7423437237739563,1.0445845127105713,0.6692599654197693,1.361345291137695,50000 -6289.227605819702,5.812516689300537,55073.62894105911,119812,0,55073.62894105911,0.5487000346183777,2.030686378479004,10000,61375.50624871254,0.7238671779632568,1.152235507965088,0.6700599789619446,1.3908920288085938,50000 -6338.752652645111,5.857455015182495,55493.84318685532,120727,0,55493.84318685532,0.5505000352859497,1.9711410999298096,10000,61845.34255743027,0.7369335889816284,1.071627855300903,0.6764400005340576,1.331760287284851,50000 -6388.082646846771,5.904085636138916,55914.0778169632,121644,0,55914.0778169632,0.5493000149726868,1.980876445770264,10000,62315.006165504456,0.746386706829071,1.0214617252349854,0.6764199733734131,1.3321937322616575,50000 -6439.080575942993,5.9502387046813965,56334.39921450615,122559,0,56334.39921450615,0.5550000071525574,1.9604276418685915,10000,62786.4237511158,0.7341406345367432,1.0811148881912231,0.6796999573707581,1.328904390335083,50000 -6488.3691465854645,6.0046210289001465,56754.34090018272,123473,0,56754.34090018272,0.5523000359535217,1.98598861694336,10000,63255.76044034958,0.7335742115974426,1.078386902809143,0.6755599975585938,1.3359389305114746,50000 -6533.5595326423645,6.052314043045044,57174.481586933136,124389,0,57174.481586933136,0.5511000156402588,1.9417612552642824,10000,63721.19270992279,0.75,0.994757890701294,0.6810799837112427,1.293065071105957,50000 -6583.968408584595,6.108830213546753,57594.81980538368,125305,0,57594.81980538368,0.560200035572052,1.9421415328979488,10000,64192.04764842987,0.74378901720047,1.0434715747833252,0.6869999766349792,1.2920804023742676,50000 -6632.670027494431,6.153582334518433,58015.1064915657,126221,0,58015.1064915657,0.5663000345230103,1.911031007766724,10000,64661.13260555267,0.7441992163658142,1.0305161476135254,0.6879000067710876,1.277883529663086,50000 -6682.635187864304,6.201958656311035,58435.141280412674,127138,0,58435.141280412674,0.5621000528335571,1.8933143615722656,10000,65131.23244309425,0.7535937428474426,0.9919482469558716,0.6867799758911133,1.2697197198867798,50000 -6731.701937913895,6.2473344802856445,58855.1561756134,128053,0,58855.1561756134,0.5675000548362732,1.9025838375091555,10000,65600.41216874123,0.7463476657867432,1.0128285884857178,0.6887399554252625,1.2626125812530518,50000 -6782.662457227707,6.294980049133301,59275.09725427628,128968,0,59275.09725427628,0.5649000406265259,1.8984917402267456,10000,66071.41319656372,0.7517382502555847,0.9961110353469848,0.6873199939727783,1.2667717933654783,50000 -6830.492385864258,6.34734582901001,59695.053881406784,129883,0,59695.053881406784,0.5732000470161438,1.85313093662262,10000,66539.30385136604,0.7596484422683716,0.9546266794204712,0.693839967250824,1.2397924661636353,50000 -6879.277712106705,6.404633522033691,60115.20013332367,130796,0,60115.20013332367,0.5675000548362732,1.8973604440689087,10000,67008.34401488304,0.764453113079071,0.9491795897483826,0.691540002822876,1.2591371536254885,50000 -6928.066775798798,6.457538366317749,60535.40959310532,131711,0,60535.40959310532,0.5690000057220459,1.871274471282959,10000,67477.44813871384,0.7549218535423279,0.9674057364463806,0.6934999823570251,1.2365845441818235,50000 -6978.662467956543,6.508562088012695,60955.45423436165,132625,0,60955.45423436165,0.5770000219345093,1.8836565017700195,10000,67948.19107365608,0.7603515386581421,0.9755155444145204,0.6967200040817261,1.2530672550201416,50000 -7026.758812665939,6.553529500961304,61375.42639231682,133540,0,61375.42639231682,0.5764999985694885,1.851581335067749,10000,68416.36034202576,0.7734179496765137,0.9080670475959778,0.6987400054931641,1.2287685871124268,50000 -7073.3166263103485,6.606256008148193,61795.71336269379,134453,0,61795.71336269379,0.5779000520706177,1.8428810834884644,10000,68883.31156492233,0.7599413990974426,0.9587976336479188,0.7029199600219727,1.2122979164123535,50000 -7123.942526578903,6.653573513031006,62216.026047468185,135369,0,62216.026047468185,0.5750000476837158,1.8443801403045648,10000,69354.34955620766,0.7627929449081421,0.9346864819526672,0.7001399993896484,1.2140921354293823,50000 -7173.187728404999,6.702480316162109,62636.19765305519,136284,0,62636.19765305519,0.5800999999046326,1.845173954963684,10000,69823.86749219894,0.7731054425239563,0.9033991694450378,0.7005999684333801,1.2164199352264404,50000 -7222.835695266724,6.760934591293335,63056.44626426697,137195,0,63056.44626426697,0.5833000540733337,1.8194403648376465,10000,70293.874396801,0.7671093344688416,0.9361597299575806,0.7026599645614624,1.2115135192871094,50000 -7271.722208023071,6.809126853942871,63476.769704818726,138106,0,63476.769704818726,0.5826000571250916,1.85190498828888,10000,70763.18441557884,0.7684765458106995,0.9252225756645204,0.7012400031089783,1.2247495651245115,50000 -7323.079668521881,6.859033823013306,63896.763201236725,139019,0,63896.763201236725,0.5879000425338745,1.8092293739318848,10000,71234.63944602013,0.7814257740974426,0.8607298135757446,0.7080000042915344,1.1757205724716189,50000 -7370.497317314148,6.90878701210022,64317.06393766403,139933,0,64317.06393766403,0.5852000117301941,1.819550514221192,10000,71702.45918250084,0.7718554735183716,0.9126390814781188,0.7079600095748901,1.185042142868042,50000 -7421.191428661346,6.957134008407593,64737.02635455132,140847,0,64737.02635455132,0.5847000479698181,1.7930099964141846,10000,72173.21574926376,0.7747851610183716,0.8819167017936707,0.7116400003433228,1.1624095439910889,50000 -7473.606050014496,7.00886607170105,65157.04562306404,141762,0,65157.04562306404,0.5872000455856323,1.7895723581314087,10000,72645.75348472595,0.77978515625,0.8617091178894043,0.7116599678993225,1.165801763534546,50000 -7526.591713428497,7.064823865890503,65577.20102715492,142676,0,65577.20102715492,0.5898000001907349,1.76817786693573,10000,73119.00194859505,0.7821679711341858,0.8531965613365173,0.7168200016021729,1.1503877639770508,50000 -7577.742676258087,7.511917352676392,65996.80779743195,143588,0,65996.80779743195,0.5919000506401062,1.7643986940383911,10000,73590.25830888748,0.7820116877555847,0.8580734729766846,0.7165200114250183,1.1449278593063354,50000 -7624.79746389389,7.559880018234253,66416.94591355324,144501,0,66416.94591355324,0.5920000076293945,1.7644582986831665,10000,74057.55110120773,0.7851171493530273,0.8476221561431885,0.7163400053977966,1.144951581954956,50000 -7668.535885095596,7.608190298080444,66837.3013036251,145415,0,66837.3013036251,0.6025000214576721,1.743360996246338,10000,74521.7456138134,0.7938281297683716,0.8047267198562622,0.7177599668502808,1.1373757123947144,50000 -7717.922769069672,7.655825853347778,67257.48422813416,146328,0,67257.48422813416,0.6013000011444092,1.727980136871338,10000,74991.41473031044,0.7868554592132568,0.8384296894073486,0.7208999991416931,1.1160486936569214,50000 -7767.141966342926,7.707298278808594,67677.5356965065,147243,0,67677.5356965065,0.6025000214576721,1.7282333374023438,10000,75460.78970003128,0.7939453125,0.8069570064544678,0.7222599983215332,1.10954487323761,50000 -7818.040345191956,7.759216070175171,68097.77912926674,148158,0,68097.77912926674,0.6007000207901001,1.723463773727417,10000,75932.03611707687,0.7961523532867432,0.7881932854652405,0.7221599817276001,1.1152422428131104,50000 -7867.6073389053345,7.813696146011352,68517.94238138199,149073,0,68517.94238138199,0.6014000177383423,1.7236626148223877,10000,76401.87269186974,0.7910937070846558,0.8174343109130859,0.7259599566459656,1.0968220233917236,50000 -7916.614735364914,7.865723133087158,68937.99177789688,149988,0,68937.99177789688,0.602400004863739,1.7529717683792114,10000,76871.03368258476,0.7939257621765137,0.8128007054328918,0.7231400012969971,1.1107922792434692,50000 -7964.20180106163,7.918400287628174,69358.29975485802,150900,0,69358.29975485802,0.6033000349998474,1.7106510400772097,10000,77339.0362174511,0.80189448595047,0.7681282758712769,0.7280600070953369,1.0890097618103027,50000 -8014.445042133331,7.975107192993164,69778.54154014587,151815,0,69778.54154014587,0.614300012588501,1.681167483329773,10000,77809.62989974022,0.7974413633346558,0.7807225584983826,0.7289199829101562,1.0758874416351318,50000 -8065.495423555374,8.04086971282959,70198.79391741753,152729,0,70198.79391741753,0.6094000339508057,1.697064757347107,10000,78281.05008983612,0.8052343726158142,0.7729941606521606,0.7317799925804138,1.0814770460128784,50000 -8112.055419683456,8.100271940231323,70618.86042189598,153644,0,70618.86042189598,0.6111000180244446,1.6869773864746094,10000,78747.78933882713,0.805468738079071,0.7481812834739685,0.7303999662399292,1.070380926132202,50000 -8159.562078952789,8.1575448513031,71038.89610767365,154558,0,71038.89610767365,0.6081000566482544,1.683420181274414,10000,79215.44156551361,0.806640625,0.7597796320915222,0.7336199879646301,1.068161129951477,50000 -8206.094826936722,8.21754264831543,71458.79195690155,155470,0,71458.79195690155,0.6116000413894653,1.6947035789489746,10000,79681.98218774796,0.8057226538658142,0.7690586447715759,0.7343399524688721,1.075039625167847,50000 -8255.587515115738,8.26708173751831,71878.8363673687,156385,0,71878.8363673687,0.6154000163078308,1.6492758989334106,10000,80151.62154054642,0.8109179735183716,0.7232632040977478,0.7350800037384033,1.0466346740722656,50000 -8301.570994377136,8.324889659881592,72298.85396432877,157302,0,72298.85396432877,0.6164000034332275,1.6407500505447388,10000,80617.73319363594,0.8225781321525574,0.6812586188316345,0.739139974117279,1.0413532257080078,50000 -8350.355193376541,8.37764310836792,72719.15087842941,158215,0,72719.15087842941,0.61080002784729,1.6860063076019287,10000,81086.91956949234,0.8095898032188416,0.7571691870689392,0.7379199862480164,1.0595265626907349,50000 -8396.208889245987,8.428681373596191,73139.33359980583,159131,0,73139.33359980583,0.6171000003814697,1.6728469133377075,10000,81553.05968117714,0.8150585889816284,0.7215979695320129,0.7387599945068359,1.0488839149475098,50000 -8446.458607912064,8.487337112426758,73559.36250901222,160046,0,73559.36250901222,0.6206000447273254,1.6554361581802368,10000,82023.44992232323,0.8247265219688416,0.6886405944824219,0.7408199906349182,1.0412501096725464,50000 -8494.669929981232,8.537365913391113,73979.74830436707,160960,0,73979.74830436707,0.6252000331878662,1.632118582725525,10000,82492.14982962608,0.8200390338897705,0.7016811370849609,0.744439959526062,1.0223753452301023,50000 -8541.95247745514,8.591475486755371,74399.8327550888,161873,0,74399.8327550888,0.6238000392913818,1.6342037916183472,10000,82959.62244343758,0.8234765529632568,0.682098388671875,0.7447400093078613,1.019462823867798,50000 -8591.316792964935,8.644585847854614,74819.76510477066,162788,0,74819.76510477066,0.6231000423431396,1.6292407512664795,10000,83429.02394080162,0.8300195336341858,0.6636582612991333,0.7465400099754333,1.0095263719558716,50000 -8636.684086561203,8.705162763595581,75239.78667020798,163702,0,75239.78667020798,0.6289000511169434,1.6085121631622314,10000,83894.52511429787,0.8236327767372131,0.6771386861801147,0.7473399639129639,1.008391499519348,50000 -8684.300383806229,8.755192041397095,75659.76452445984,164617,0,75659.76452445984,0.6265000104904175,1.6200848817825315,10000,84362.22888803482,0.8244921565055847,0.6821433305740356,0.7472999691963196,1.0059750080108645,50000 -8733.994886159897,8.806623458862305,76080.02516317368,165531,0,76080.02516317368,0.6282000541687012,1.6096539497375488,10000,84832.28757619858,0.8278319835662842,0.6549848914146423,0.7478399872779846,0.9953515529632568,50000 -8779.374361515045,8.862721681594849,76500.17499065399,166447,0,76500.17499065399,0.6318000555038452,1.6084636449813845,10000,85297.9245531559,0.8307226300239563,0.6519699096679688,0.7495200037956238,0.9923651218414308,50000 -8828.95196390152,8.925573348999023,76920.16060471535,167359,0,76920.16060471535,0.6297000050544739,1.5886270999908447,10000,85767.60281062126,0.8291015625,0.6523779034614563,0.7506399750709534,0.9864511489868164,50000 -8877.900985002518,8.980156898498535,77340.48335027695,168272,0,77340.48335027695,0.6333000063896179,1.587046504020691,10000,86236.98085284233,0.8313866853713989,0.637546181678772,0.753600001335144,0.9787457585334778,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/measurements.csv deleted file mode 100644 index 637aa29ab..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1874 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.3736868,6.9077563,,,,,,,,,,,,,, -1,,,0.0011328124674037,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,34.48550629615784,65.40688729286194,34.48550629615784,30.9212863445282,0.0,0.0 -100,0.4377955,6.905726,,,,,,,,,,,,,, -200,0.50483423,6.892523,,,,,,,,,,,,,, -300,0.55785394,6.864784,,,,,,,,,,,,,, -400,0.684679,6.8360386,,,,,,,,,,,,,, -500,0.88431513,6.766288,,,,,,,,,,,,,, -600,0.79231024,6.717329,,,,,,,,,,,,,, -700,0.82443124,6.678963,,,,,,,,,,,,,, -800,1.3026481,6.6058955,,,,,,,,,,,,,, -853,,,0.0112499995157122,6.477405071258545,0.0108599998056888,6.486855030059815,50000.0,0.0105000007897615,6.523868083953857,10000.0,454.543984413147,531.3210079669952,454.543984413147,76.71136093139648,0.017364501953125,0.0 -900,1.5872691,6.5772257,,,,,,,,,,,,,, -1000,1.8541744,6.521769,,,,,,,,,,,,,, -1100,2.0020957,6.425653,,,,,,,,,,,,,, -1200,1.6735157,6.3905907,,,,,,,,,,,,,, -1300,1.4376096,6.298111,,,,,,,,,,,,,, -1400,3.020654,6.2893915,,,,,,,,,,,,,, -1500,2.2515695,6.239998,,,,,,,,,,,,,, -1600,1.735535,6.1636357,,,,,,,,,,,,,, -1700,2.116194,6.1481233,,,,,,,,,,,,,, -1759,,,0.0346093736588954,5.880190372467041,0.0334799997508525,5.913444995880127,50000.0,0.0279000010341405,6.0221991539001465,10000.0,874.770247220993,997.8961570262908,874.770247220993,122.98103356361388,0.0453393459320068,0.0 -1800,2.0315106,6.2700825,,,,,,,,,,,,,, -1900,1.8286895,6.48117,,,,,,,,,,,,,, -2000,2.4351346,6.0516386,,,,,,,,,,,,,, -2100,1.6264896,6.3710423,,,,,,,,,,,,,, -2200,2.0051186,5.9911733,,,,,,,,,,,,,, -2300,2.123548,5.8998027,,,,,,,,,,,,,, -2400,1.3887469,6.4786253,,,,,,,,,,,,,, -2500,1.5514843,6.585514,,,,,,,,,,,,,, -2600,2.0684388,5.8261223,,,,,,,,,,,,,, -2673,,,0.0605859346687793,5.4684343338012695,0.0572999976575374,5.506952285766602,50000.0,0.0459000021219253,5.680256366729736,10000.0,1294.968270778656,1464.868360042572,1294.968270778656,169.67155838012695,0.0765357017517089,0.0 -2700,1.8989232,5.810168,,,,,,,,,,,,,, -2800,1.6083347,6.530424,,,,,,,,,,,,,, -2900,1.9745597,5.863526,,,,,,,,,,,,,, -3000,2.3980582,5.631616,,,,,,,,,,,,,, -3100,1.674358,5.7172527,,,,,,,,,,,,,, -3200,1.8205616,6.4764423,,,,,,,,,,,,,, -3300,2.29223,5.643919,,,,,,,,,,,,,, -3400,1.82263,5.591483,,,,,,,,,,,,,, -3500,2.0950053,5.702161,,,,,,,,,,,,,, -3588,,,0.0949414074420929,5.101712703704834,0.0859000012278556,5.135260105133057,50000.0,0.0681000053882598,5.376662731170654,10000.0,1715.2244164943695,1933.6374650001528,1715.2244164943695,218.1023845672608,0.1063158512115478,0.0 -3600,1.8267549,6.279958,,,,,,,,,,,,,, -3700,2.0360165,5.561869,,,,,,,,,,,,,, -3800,1.8618836,5.551045,,,,,,,,,,,,,, -3900,1.718084,5.536628,,,,,,,,,,,,,, -4000,2.019587,5.4198103,,,,,,,,,,,,,, -4100,2.0311217,5.7733693,,,,,,,,,,,,,, -4200,1.8290058,5.389282,,,,,,,,,,,,,, -4300,1.7552866,5.2655363,,,,,,,,,,,,,, -4400,1.900527,6.115322,,,,,,,,,,,,,, -4500,2.4339058,5.2850175,,,,,,,,,,,,,, -4504,,,0.1316992193460464,4.694127559661865,0.1209999993443489,4.762543678283691,50000.0,0.0938000008463859,5.061717510223389,10000.0,2135.2453899383545,2400.255875349045,2135.2453899383545,264.62031650543213,0.1340563297271728,0.0 -4600,2.0598378,5.207844,,,,,,,,,,,,,, -4700,2.211212,5.2485776,,,,,,,,,,,,,, -4800,1.7525636,5.1523514,,,,,,,,,,,,,, -4900,1.856591,5.1763906,,,,,,,,,,,,,, -5000,1.507605,6.185042,,,,,,,,,,,,,, -5100,1.8977662,5.4671526,,,,,,,,,,,,,, -5200,2.002251,5.2106314,,,,,,,,,,,,,, -5300,2.069681,5.038599,,,,,,,,,,,,,, -5400,1.9216468,4.9530907,,,,,,,,,,,,,, -5416,,,0.1699804663658142,4.376136302947998,0.1554600000381469,4.456610679626465,50000.0,0.1170000061392784,4.778557300567627,10000.0,2555.4077939987183,2867.1481182575226,2555.4077939987183,311.2709410190582,0.161536693572998,0.0 -5500,1.835633,4.9762597,,,,,,,,,,,,,, -5600,1.6733167,5.7800136,,,,,,,,,,,,,, -5700,1.7408509,5.061088,,,,,,,,,,,,,, -5800,1.9261721,4.7915373,,,,,,,,,,,,,, -5900,1.8967346,4.843038,,,,,,,,,,,,,, -6000,2.1753726,5.2702155,,,,,,,,,,,,,, -6100,1.3377305,5.919361,,,,,,,,,,,,,, -6200,1.9766296,5.043452,,,,,,,,,,,,,, -6300,2.036704,4.651523,,,,,,,,,,,,,, -6332,,,0.2166210860013961,4.001551628112793,0.2006799876689911,4.088918209075928,50000.0,0.1517000049352646,4.465686321258545,10000.0,2975.353636741638,3335.9222333431244,2975.353636741638,360.0176274776459,0.192075490951538,0.0 -6400,1.5744758,5.897503,,,,,,,,,,,,,, -6500,2.2782595,4.623733,,,,,,,,,,,,,, -6600,1.9259278,4.7214584,,,,,,,,,,,,,, -6700,1.4395233,6.095008,,,,,,,,,,,,,, -6800,1.8101695,5.8679166,,,,,,,,,,,,,, -6900,1.4555731,6.0585957,,,,,,,,,,,,,, -7000,1.6763811,6.003878,,,,,,,,,,,,,, -7100,2.5499024,4.5592217,,,,,,,,,,,,,, -7200,2.2225935,4.7087135,,,,,,,,,,,,,, -7248,,,0.2599609196186065,3.707098245620728,0.2382399886846542,3.8109381198883057,50000.0,0.1785000115633011,4.246798515319824,10000.0,3395.407825946808,3800.8962712287903,3395.407825946808,404.859354019165,0.2183105945587158,0.0 -7300,2.1850822,4.4729424,,,,,,,,,,,,,, -7400,1.7728624,4.5760765,,,,,,,,,,,,,, -7500,1.884736,4.744378,,,,,,,,,,,,,, -7600,2.4943125,4.427091,,,,,,,,,,,,,, -7700,1.9024467,4.354324,,,,,,,,,,,,,, -7800,1.7815336,4.342934,,,,,,,,,,,,,, -7900,1.8960832,4.332199,,,,,,,,,,,,,, -8000,2.4804242,4.341458,,,,,,,,,,,,,, -8100,1.7053375,4.633713,,,,,,,,,,,,,, -8165,,,0.2924023270606994,3.420924186706543,0.2739599943161011,3.5337398052215576,50000.0,0.2092000097036361,4.015152454376221,10000.0,3815.699334621429,4269.799302577972,3815.699334621429,453.3918721675873,0.245370864868164,0.0 -8200,1.1568561,5.9260836,,,,,,,,,,,,,, -8300,1.8314735,4.2068806,,,,,,,,,,,,,, -8400,1.3871912,5.538182,,,,,,,,,,,,,, -8500,1.7562612,4.5184073,,,,,,,,,,,,,, -8600,1.6799903,4.508879,,,,,,,,,,,,,, -8700,1.3911681,5.8609953,,,,,,,,,,,,,, -8800,1.856212,4.1248403,,,,,,,,,,,,,, -8900,1.854578,6.067874,,,,,,,,,,,,,, -9000,1.701393,5.3478656,,,,,,,,,,,,,, -9082,,,0.3405078053474426,3.166249990463257,0.3056800067424774,3.3514530658721924,50000.0,0.2371000051498413,3.8389713764190674,10000.0,4235.792031049728,4736.970528125763,4235.792031049728,500.3881521224976,0.2754819393157959,0.0 -9100,1.8023497,4.148771,,,,,,,,,,,,,, -9200,1.808446,4.075646,,,,,,,,,,,,,, -9300,1.7533838,4.0116243,,,,,,,,,,,,,, -9400,1.3105919,5.8313394,,,,,,,,,,,,,, -9500,2.0123172,3.9271157,,,,,,,,,,,,,, -9600,1.8772125,3.97621,,,,,,,,,,,,,, -9700,1.5713362,5.1545963,,,,,,,,,,,,,, -9800,1.5640185,5.6757364,,,,,,,,,,,,,, -9900,1.5587897,6.0992813,,,,,,,,,,,,,, -9997,,,0.3578320145606994,3.022150754928589,0.3312000036239624,3.1509346961975098,50000.0,0.2614000141620636,3.675516605377197,10000.0,4656.01069688797,5206.719424247742,4656.01069688797,549.8398551940918,0.3025200366973877,0.0 -10000,1.8052336,4.101894,,,,,,,,,,,,,, -10100,1.8945479,4.0468774,,,,,,,,,,,,,, -10200,2.1821766,4.1395726,,,,,,,,,,,,,, -10300,2.0622797,3.9592361,,,,,,,,,,,,,, -10400,1.7892634,3.7881002,,,,,,,,,,,,,, -10500,3.0018327,3.9904687,,,,,,,,,,,,,, -10600,1.9032182,3.9374595,,,,,,,,,,,,,, -10700,1.6337242,4.9266653,,,,,,,,,,,,,, -10800,1.4353564,5.3570676,,,,,,,,,,,,,, -10900,1.826975,3.8759108,,,,,,,,,,,,,, -10912,,,0.3854492008686065,2.840325355529785,0.3569599986076355,2.995655536651612,50000.0,0.2733000218868255,3.5355544090271,10000.0,5076.158985376358,5677.023521661758,5076.158985376358,599.9115297794342,0.3341560363769531,0.0 -11000,1.4654561,5.0527225,,,,,,,,,,,,,, -11100,1.8668442,3.7249205,,,,,,,,,,,,,, -11200,1.8467655,3.720509,,,,,,,,,,,,,, -11300,2.2185006,3.7295918,,,,,,,,,,,,,, -11400,1.5469611,4.4372582,,,,,,,,,,,,,, -11500,1.3914087,5.8998117,,,,,,,,,,,,,, -11600,1.3163862,5.733597,,,,,,,,,,,,,, -11700,1.6951663,3.9221997,,,,,,,,,,,,,, -11800,1.544186,4.730101,,,,,,,,,,,,,, -11830,,,0.4166601598262787,2.7132813930511475,0.3746599853038788,2.916720151901245,50000.0,0.2883000075817108,3.481478691101074,10000.0,5496.442660808563,6144.755959510803,5496.442660808563,647.2774829864502,0.3645823001861572,0.0 -11900,1.637465,3.9564342,,,,,,,,,,,,,, -12000,1.8301762,3.5986192,,,,,,,,,,,,,, -12100,1.6208825,5.010679,,,,,,,,,,,,,, -12200,2.0013235,3.5591836,,,,,,,,,,,,,, -12300,1.6134789,3.9099653,,,,,,,,,,,,,, -12400,1.8698814,3.713546,,,,,,,,,,,,,, -12500,1.7635411,3.96712,,,,,,,,,,,,,, -12600,1.2413961,5.6525717,,,,,,,,,,,,,, -12700,1.9053079,3.513307,,,,,,,,,,,,,, -12746,,,0.4178515672683716,2.683227777481079,0.3882399797439575,2.8298232555389404,50000.0,0.3023000061511993,3.4036993980407715,10000.0,5916.627821445465,6614.584210395813,5916.627821445465,696.8378114700317,0.3950300216674804,0.0 -12800,1.8669972,3.64745,,,,,,,,,,,,,, -12900,1.945956,3.7037477,,,,,,,,,,,,,, -13000,1.4521304,4.4990096,,,,,,,,,,,,,, -13100,1.859278,3.677915,,,,,,,,,,,,,, -13200,1.8072219,3.58079,,,,,,,,,,,,,, -13300,1.6466194,3.4889185,,,,,,,,,,,,,, -13400,1.8429142,3.5019233,,,,,,,,,,,,,, -13500,1.7185614,4.6275854,,,,,,,,,,,,,, -13600,1.7147511,3.6790013,,,,,,,,,,,,,, -13657,,,0.4305664002895355,2.5879738330841064,0.4039999842643738,2.7390167713165283,50000.0,0.3061000108718872,3.3392672538757324,10000.0,6336.897862434387,7078.295861721039,6336.897862434387,740.1960797309875,0.4252433776855469,0.0 -13700,1.6443789,4.109016,,,,,,,,,,,,,, -13800,1.1409706,5.5005865,,,,,,,,,,,,,, -13900,1.6988306,3.8145566,,,,,,,,,,,,,, -14000,1.2913595,4.885404,,,,,,,,,,,,,, -14100,1.3895203,4.1283937,,,,,,,,,,,,,, -14200,1.0083363,5.5433817,,,,,,,,,,,,,, -14300,1.1514064,5.4714847,,,,,,,,,,,,,, -14400,1.2146788,5.8336797,,,,,,,,,,,,,, -14500,1.7771909,3.4366696,,,,,,,,,,,,,, -14574,,,0.4543749988079071,2.460162401199341,0.4130399823188782,2.6595137119293213,50000.0,0.317900002002716,3.2573442459106445,10000.0,6757.152330160141,7549.075749397278,6757.152330160141,790.636931180954,0.457782506942749,0.0 -14600,1.855412,3.4961724,,,,,,,,,,,,,, -14700,1.599307,3.5221915,,,,,,,,,,,,,, -14800,1.6713252,3.5697439,,,,,,,,,,,,,, -14900,1.8094374,3.215625,,,,,,,,,,,,,, -15000,1.6802789,3.4612195,,,,,,,,,,,,,, -15100,1.6097391,3.6708207,,,,,,,,,,,,,, -15200,1.5935725,3.4791996,,,,,,,,,,,,,, -15300,1.8480313,3.6127234,,,,,,,,,,,,,, -15400,1.3594528,4.851919,,,,,,,,,,,,,, -15489,,,0.4491015672683716,2.497513055801392,0.4224399924278259,2.632542133331299,50000.0,0.3285000026226043,3.233921527862549,10000.0,7177.283156871796,8017.423315048218,7177.283156871796,838.7734625339508,0.4854211807250976,0.0 -15500,1.734178,3.499287,,,,,,,,,,,,,, -15600,1.4923706,3.6576939,,,,,,,,,,,,,, -15700,1.7327241,3.5447779,,,,,,,,,,,,,, -15800,1.7034088,3.6269963,,,,,,,,,,,,,, -15900,1.4832835,4.1153564,,,,,,,,,,,,,, -16000,1.5401214,3.3563507,,,,,,,,,,,,,, -16100,1.5762857,3.3200445,,,,,,,,,,,,,, -16200,1.3945651,3.6967764,,,,,,,,,,,,,, -16300,1.4718076,4.179635,,,,,,,,,,,,,, -16400,1.6369286,3.4235315,,,,,,,,,,,,,, -16404,,,0.4757421910762787,2.3351094722747803,0.44132000207901,2.5121448040008545,50000.0,0.3423000276088714,3.1236488819122314,10000.0,7597.375913381576,8486.046859264374,7597.375913381576,887.2250754833221,0.5128743648529053,0.0 -16500,1.1047027,5.482151,,,,,,,,,,,,,, -16600,1.5875683,3.2978296,,,,,,,,,,,,,, -16700,1.5776298,3.3733597,,,,,,,,,,,,,, -16800,1.2723043,4.486142,,,,,,,,,,,,,, -16900,1.5758355,3.5496173,,,,,,,,,,,,,, -17000,1.327184,4.1159444,,,,,,,,,,,,,, -17100,1.4938816,3.8639216,,,,,,,,,,,,,, -17200,1.5084589,3.3696659,,,,,,,,,,,,,, -17300,1.4950513,4.2005205,,,,,,,,,,,,,, -17319,,,0.4820898473262787,2.317319869995117,0.4428799748420715,2.516441583633423,50000.0,0.3389000296592712,3.1433703899383545,10000.0,8017.678372621536,8953.327250957489,8017.678372621536,934.1244015693665,0.5392537117004395,0.0 -17400,1.2663866,4.427748,,,,,,,,,,,,,, -17500,1.5834571,3.3893328,,,,,,,,,,,,,, -17600,1.6858375,3.3563302,,,,,,,,,,,,,, -17700,1.6338154,3.3278785,,,,,,,,,,,,,, -17800,1.6353914,3.233638,,,,,,,,,,,,,, -17900,1.7685764,3.4832344,,,,,,,,,,,,,, -18000,1.6703384,3.3384063,,,,,,,,,,,,,, -18100,1.7271248,3.2226763,,,,,,,,,,,,,, -18200,1.3101696,4.201858,,,,,,,,,,,,,, -18234,,,0.4768359363079071,2.342108249664306,0.4420199990272522,2.51889705657959,50000.0,0.343500018119812,3.1278634071350098,10000.0,8437.92775630951,9417.584113836288,8437.92775630951,978.050395488739,0.5683753490447998,0.0 -18300,1.2409943,4.0625267,,,,,,,,,,,,,, -18400,1.6046027,3.2500129,,,,,,,,,,,,,, -18500,1.5943698,3.2516582,,,,,,,,,,,,,, -18600,1.4474881,3.2228205,,,,,,,,,,,,,, -18700,1.6003674,3.266895,,,,,,,,,,,,,, -18800,1.1135192,4.717021,,,,,,,,,,,,,, -18900,1.4225715,3.6077302,,,,,,,,,,,,,, -19000,1.5800387,3.1552916,,,,,,,,,,,,,, -19100,1.0434957,5.520728,,,,,,,,,,,,,, -19148,,,0.4917773306369781,2.2599568367004395,0.4593999981880188,2.4413387775421143,50000.0,0.3586000204086303,3.0742838382720947,10000.0,8857.909608125687,9887.99471116066,8857.909608125687,1028.396469831467,0.5990426540374756,0.0 -19200,1.5227993,3.2852886,,,,,,,,,,,,,, -19300,1.6440401,3.4049306,,,,,,,,,,,,,, -19400,1.5471094,3.326654,,,,,,,,,,,,,, -19500,1.289362,5.6752987,,,,,,,,,,,,,, -19600,1.3608457,3.7952113,,,,,,,,,,,,,, -19700,1.6029277,3.2574804,,,,,,,,,,,,,, -19800,1.4093652,3.630465,,,,,,,,,,,,,, -19900,1.6310943,3.2486289,,,,,,,,,,,,,, -20000,1.5223604,3.1834404,,,,,,,,,,,,,, -20064,,,0.509082019329071,2.185120820999145,0.4652799963951111,2.3844094276428223,50000.0,0.3616000115871429,3.0053672790527344,10000.0,9277.858618497849,10357.822923898697,9277.858618497849,1078.1956298351288,0.6277709007263184,0.0 -20100,1.0659019,5.292198,,,,,,,,,,,,,, -20200,1.3139043,4.172209,,,,,,,,,,,,,, -20300,1.3394669,4.8302755,,,,,,,,,,,,,, -20400,1.1122952,5.739585,,,,,,,,,,,,,, -20500,1.5983912,3.1908188,,,,,,,,,,,,,, -20600,0.9149168,5.679461,,,,,,,,,,,,,, -20700,1.5978073,3.2106712,,,,,,,,,,,,,, -20800,1.3857377,3.1779342,,,,,,,,,,,,,, -20900,1.1532793,5.661017,,,,,,,,,,,,,, -20980,,,0.53369140625,2.057878255844116,0.4750799834728241,2.334637403488159,50000.0,0.3641000092029571,2.975359439849853,10000.0,9698.17512512207,10825.721643447876,9698.17512512207,1125.6968231201172,0.6573843955993652,0.0 -21000,1.0975606,4.900197,,,,,,,,,,,,,, -21100,1.4131819,3.4158185,,,,,,,,,,,,,, -21200,1.5841873,3.2389605,,,,,,,,,,,,,, -21300,1.7567426,3.1235197,,,,,,,,,,,,,, -21400,1.1335948,5.3392277,,,,,,,,,,,,,, -21500,1.3337559,4.134668,,,,,,,,,,,,,, -21600,1.6117951,3.2760215,,,,,,,,,,,,,, -21700,1.361108,3.2856293,,,,,,,,,,,,,, -21800,1.6178398,3.3778605,,,,,,,,,,,,,, -21897,,,0.507519543170929,2.1788175106048584,0.4765599966049194,2.338596105575561,50000.0,0.3733000159263611,2.9646759033203125,10000.0,10118.501426696776,11293.408746242523,10118.501426696776,1172.9720392227173,0.6915838718414307,0.0 -21900,1.0294839,4.584022,,,,,,,,,,,,,, -22000,1.672421,3.1409485,,,,,,,,,,,,,, -22100,1.3479236,3.3102374,,,,,,,,,,,,,, -22200,1.7587167,3.0755427,,,,,,,,,,,,,, -22300,1.4971763,5.59326,,,,,,,,,,,,,, -22400,1.5876594,3.0520172,,,,,,,,,,,,,, -22500,1.6175338,3.021848,,,,,,,,,,,,,, -22600,1.8765379,3.1879036,,,,,,,,,,,,,, -22700,1.5423181,3.280478,,,,,,,,,,,,,, -22800,1.1007072,5.5429287,,,,,,,,,,,,,, -22814,,,0.5238476395606995,2.071101427078247,0.4860999882221222,2.2656033039093018,50000.0,0.3775000274181366,2.9053289890289307,10000.0,10538.751983642578,11762.698380231855,10538.751983642578,1221.914003610611,0.729525089263916,0.0 -22900,1.56953,3.2208395,,,,,,,,,,,,,, -23000,1.5861394,3.062591,,,,,,,,,,,,,, -23100,1.3737812,3.5370598,,,,,,,,,,,,,, -23200,1.0467223,5.51936,,,,,,,,,,,,,, -23300,1.5795767,3.599698,,,,,,,,,,,,,, -23400,1.5801829,3.306325,,,,,,,,,,,,,, -23500,1.4701576,3.071367,,,,,,,,,,,,,, -23600,1.4244565,5.6286297,,,,,,,,,,,,,, -23700,1.5964833,3.2625194,,,,,,,,,,,,,, -23730,,,0.5507421493530273,1.938953518867493,0.4956399798393249,2.2019717693328857,50000.0,0.3871000111103058,2.8293137550354004,10000.0,10958.97521162033,12228.869998455048,10958.97521162033,1267.7731275558472,0.766730546951294,0.0 -23800,1.322858,3.7267659,,,,,,,,,,,,,, -23900,1.4451349,3.5194328,,,,,,,,,,,,,, -24000,1.6332237,3.1132038,,,,,,,,,,,,,, -24100,1.0073259,5.486715,,,,,,,,,,,,,, -24200,1.4710624,3.108313,,,,,,,,,,,,,, -24300,1.5711573,2.9472575,,,,,,,,,,,,,, -24400,1.659474,3.029727,,,,,,,,,,,,,, -24500,1.8491473,2.9909244,,,,,,,,,,,,,, -24600,1.6470684,2.892229,,,,,,,,,,,,,, -24647,,,0.5287109017372131,2.059361219406128,0.4931999742984772,2.238321542739868,50000.0,0.3889000117778778,2.873844623565674,10000.0,11379.114792823792,12699.20350074768,11379.114792823792,1317.8803231716156,0.8011837005615234,0.0 -24700,1.590078,3.0912423,,,,,,,,,,,,,, -24800,1.2942824,4.7236147,,,,,,,,,,,,,, -24900,1.5614556,3.0784562,,,,,,,,,,,,,, -25000,1.5499252,3.3880196,,,,,,,,,,,,,, -25100,1.0266683,5.3002944,,,,,,,,,,,,,, -25200,1.6075132,3.0736537,,,,,,,,,,,,,, -25300,1.6243517,3.582599,,,,,,,,,,,,,, -25400,1.6186725,3.1462135,,,,,,,,,,,,,, -25500,1.3748192,3.5509121,,,,,,,,,,,,,, -25563,,,0.5419335961341858,1.988563776016236,0.5034799575805664,2.1718690395355225,50000.0,0.3960000276565552,2.799640893936157,10000.0,11799.398998975754,13168.366040945051,11799.398998975754,1366.6767621040344,0.8305325508117676,0.0 -25600,1.2564732,4.039372,,,,,,,,,,,,,, -25700,1.5696445,3.0497415,,,,,,,,,,,,,, -25800,1.2610751,4.3614254,,,,,,,,,,,,,, -25900,1.2766919,4.173349,,,,,,,,,,,,,, -26000,1.6513337,3.0189247,,,,,,,,,,,,,, -26100,1.6131005,2.953814,,,,,,,,,,,,,, -26200,1.092642,5.19406,,,,,,,,,,,,,, -26300,1.5513515,3.1507225,,,,,,,,,,,,,, -26400,1.8607905,3.2271976,,,,,,,,,,,,,, -26477,,,0.5632421970367432,1.8941307067871087,0.5120199918746948,2.144299268722534,50000.0,0.3974000215530395,2.7761967182159424,10000.0,12219.34901380539,13637.125775814056,12219.34901380539,1415.4040472507477,0.8615057468414307,0.0 -26500,1.5005546,3.106028,,,,,,,,,,,,,, -26600,1.3372822,3.5712452,,,,,,,,,,,,,, -26700,1.5956073,2.8107264,,,,,,,,,,,,,, -26800,1.5509624,3.172875,,,,,,,,,,,,,, -26900,1.3498126,3.675086,,,,,,,,,,,,,, -27000,1.6602644,2.9454944,,,,,,,,,,,,,, -27100,1.4664917,2.844412,,,,,,,,,,,,,, -27200,1.2282417,4.568387,,,,,,,,,,,,,, -27300,1.4960653,3.3924563,,,,,,,,,,,,,, -27392,,,0.5482421517372131,1.9673892259597776,0.5123999714851379,2.13213849067688,50000.0,0.4012000262737274,2.763537645339966,10000.0,12639.654458522797,14106.759474277496,12639.654458522797,1464.6509912014008,0.891242265701294,0.0 -27400,1.5712922,3.0384672,,,,,,,,,,,,,, -27500,1.4257288,3.5731318,,,,,,,,,,,,,, -27600,1.4659684,3.5359538,,,,,,,,,,,,,, -27700,1.2733064,4.5019555,,,,,,,,,,,,,, -27800,1.6553385,3.1401668,,,,,,,,,,,,,, -27900,1.562929,3.0049584,,,,,,,,,,,,,, -28000,1.5923393,2.9148269,,,,,,,,,,,,,, -28100,1.5448686,3.0276687,,,,,,,,,,,,,, -28200,1.6075444,2.8902082,,,,,,,,,,,,,, -28300,1.6664962,2.850125,,,,,,,,,,,,,, -28308,,,0.5547069907188416,1.929384708404541,0.5189200043678284,2.1071009635925293,50000.0,0.4006000161170959,2.744941473007202,10000.0,13059.61026597023,14574.590921640396,13059.61026597023,1512.441159248352,0.9247598648071288,0.0 -28400,1.0916032,5.216111,,,,,,,,,,,,,, -28500,1.5651299,2.9136088,,,,,,,,,,,,,, -28600,1.1941617,5.188648,,,,,,,,,,,,,, -28700,1.1393652,5.5508037,,,,,,,,,,,,,, -28800,1.3204491,4.57262,,,,,,,,,,,,,, -28900,1.6288908,2.8899095,,,,,,,,,,,,,, -29000,1.4205594,3.4732227,,,,,,,,,,,,,, -29100,1.3594903,3.821765,,,,,,,,,,,,,, -29200,1.2007538,5.1123366,,,,,,,,,,,,,, -29224,,,0.5635741949081421,1.887037754058838,0.5194000005722046,2.098005533218384,50000.0,0.4052000045776367,2.7502341270446777,10000.0,13479.542765378952,15045.342432498932,13479.542765378952,1563.1740138530731,0.9577534198760986,0.0 -29300,2.2522192,2.9829686,,,,,,,,,,,,,, -29400,1.5756507,3.1916323,,,,,,,,,,,,,, -29500,1.3087102,4.58808,,,,,,,,,,,,,, -29600,1.6204689,2.916179,,,,,,,,,,,,,, -29700,1.5053424,2.9656057,,,,,,,,,,,,,, -29800,1.6197731,2.9168613,,,,,,,,,,,,,, -29900,1.3272743,4.52092,,,,,,,,,,,,,, -30000,1.5586256,3.4066164,,,,,,,,,,,,,, -30100,1.5506004,3.0392618,,,,,,,,,,,,,, -30141,,,0.5602734088897705,1.8773443698883057,0.5280199646949768,2.0621399879455566,50000.0,0.4148000180721283,2.698044538497925,10000.0,13899.718740701675,15514.25835442543,13899.718740701675,1611.8296627998352,0.990304946899414,0.0 -30200,1.3836906,3.676115,,,,,,,,,,,,,, -30300,1.1342664,5.305671,,,,,,,,,,,,,, -30400,1.8883466,2.9318726,,,,,,,,,,,,,, -30500,1.0799525,5.4028606,,,,,,,,,,,,,, -30600,1.7705472,2.8856263,,,,,,,,,,,,,, -30700,1.2715056,4.6711445,,,,,,,,,,,,,, -30800,1.5422884,2.986819,,,,,,,,,,,,,, -30900,1.5788093,3.6169963,,,,,,,,,,,,,, -31000,1.3407117,5.065897,,,,,,,,,,,,,, -31057,,,0.5613867044448853,1.895029664039612,0.526199996471405,2.0695595741271973,50000.0,0.415800005197525,2.691118001937866,10000.0,14319.735169649124,15982.635932683945,14319.735169649124,1660.1082208156586,1.0202922821044922,0.0 -31100,1.6644274,3.3522296,,,,,,,,,,,,,, -31200,1.4850557,2.6503036,,,,,,,,,,,,,, -31300,1.6107739,5.4946685,,,,,,,,,,,,,, -31400,1.529923,3.4471197,,,,,,,,,,,,,, -31500,1.4109957,3.1602995,,,,,,,,,,,,,, -31600,1.6041973,2.8514092,,,,,,,,,,,,,, -31700,1.335729,4.6057153,,,,,,,,,,,,,, -31800,1.8189055,3.3157325,,,,,,,,,,,,,, -31900,1.3240031,3.441408,,,,,,,,,,,,,, -31972,,,0.5696093440055847,1.8706759214401243,0.5287600159645081,2.060586452484131,50000.0,0.4119000136852264,2.7069880962371826,10000.0,14739.937356710434,16451.484882354736,14739.937356710434,1708.6653501987455,1.0565845966339111,0.0 -32000,1.8313515,2.9554539,,,,,,,,,,,,,, -32100,1.1923763,3.8135514,,,,,,,,,,,,,, -32200,1.7960883,2.8006637,,,,,,,,,,,,,, -32300,1.5834472,2.8068426,,,,,,,,,,,,,, -32400,1.4543507,3.6946394,,,,,,,,,,,,,, -32500,1.6341193,2.7930617,,,,,,,,,,,,,, -32600,1.5765648,3.4384713,,,,,,,,,,,,,, -32700,1.3182158,4.075457,,,,,,,,,,,,,, -32800,1.3040465,5.266128,,,,,,,,,,,,,, -32890,,,0.5985351204872131,1.70049250125885,0.5352999567985535,2.021901130676269,50000.0,0.4160000085830688,2.677694797515869,10000.0,15160.13657617569,16920.72052717209,15160.13657617569,1757.6149718761444,1.0907628536224363,0.0 -32900,1.4245427,4.84182,,,,,,,,,,,,,, -33000,1.7908514,2.8869336,,,,,,,,,,,,,, -33100,1.9156309,3.1433644,,,,,,,,,,,,,, -33200,1.695391,3.4281976,,,,,,,,,,,,,, -33300,1.4384097,5.5604925,,,,,,,,,,,,,, -33400,1.1983541,4.863857,,,,,,,,,,,,,, -33500,1.6048627,2.885574,,,,,,,,,,,,,, -33600,1.3595145,5.2117543,,,,,,,,,,,,,, -33700,1.7577876,2.878028,,,,,,,,,,,,,, -33800,1.342765,3.852155,,,,,,,,,,,,,, -33807,,,0.567578136920929,1.879085898399353,0.5329399704933167,2.052635431289673,50000.0,0.4164000153541565,2.689453363418579,10000.0,15580.376125097277,17389.854667186737,15580.376125097277,1806.4245445728304,1.1232614517211914,0.0 -33900,1.2614167,4.5647264,,,,,,,,,,,,,, -34000,1.1631111,4.7037272,,,,,,,,,,,,,, -34100,1.6711853,2.8644013,,,,,,,,,,,,,, -34200,1.4831492,3.366508,,,,,,,,,,,,,, -34300,1.5433178,2.8857353,,,,,,,,,,,,,, -34400,1.7200687,2.7649722,,,,,,,,,,,,,, -34500,1.6709725,2.847022,,,,,,,,,,,,,, -34600,1.3957883,4.472817,,,,,,,,,,,,,, -34700,1.253073,5.4952483,,,,,,,,,,,,,, -34724,,,0.5770702958106995,1.802340149879456,0.5371999740600586,1.9966063499450684,50000.0,0.4126000106334686,2.657298803329468,10000.0,16000.444960832596,17856.395292282104,16000.444960832596,1852.8152103424072,1.1534223556518557,0.0 -34800,1.6067551,2.9708416,,,,,,,,,,,,,, -34900,1.3781292,3.9275465,,,,,,,,,,,,,, -35000,1.4250151,4.492855,,,,,,,,,,,,,, -35100,1.5720917,2.9229739,,,,,,,,,,,,,, -35200,1.6211871,2.8998082,,,,,,,,,,,,,, -35300,1.4879363,3.400229,,,,,,,,,,,,,, -35400,1.5287371,3.2831242,,,,,,,,,,,,,, -35500,1.3159415,3.9919782,,,,,,,,,,,,,, -35600,1.6138428,4.457925,,,,,,,,,,,,,, -35638,,,0.5990234017372131,1.7254287004470823,0.5380600094795227,2.003264904022217,50000.0,0.4256000220775604,2.619293212890625,10000.0,16420.363413095474,18325.28321290016,16420.363413095474,1901.7003903388977,1.184575080871582,0.0 -35700,1.8457054,2.838571,,,,,,,,,,,,,, -35800,1.6471881,3.0021534,,,,,,,,,,,,,, -35900,1.3580226,4.1071653,,,,,,,,,,,,,, -36000,1.823519,2.8936472,,,,,,,,,,,,,, -36100,1.3897824,5.404323,,,,,,,,,,,,,, -36200,1.5714923,2.979948,,,,,,,,,,,,,, -36300,1.9214449,2.7061453,,,,,,,,,,,,,, -36400,1.564999,2.900934,,,,,,,,,,,,,, -36500,1.3701633,3.729506,,,,,,,,,,,,,, -36554,,,0.575488269329071,1.8254637718200684,0.5359799861907959,2.010288715362549,50000.0,0.4204000234603882,2.6431541442871094,10000.0,16840.63138628006,18794.849648475647,16840.63138628006,1950.9113097190857,1.2209351062774658,0.0 -36600,1.6312908,2.8130457,,,,,,,,,,,,,, -36700,1.4951344,3.2611673,,,,,,,,,,,,,, -36800,1.731194,2.9665258,,,,,,,,,,,,,, -36900,1.7462965,2.8445644,,,,,,,,,,,,,, -37000,1.4703232,5.0283628,,,,,,,,,,,,,, -37100,1.2438184,4.6871853,,,,,,,,,,,,,, -37200,1.669354,3.1900241,,,,,,,,,,,,,, -37300,1.6617209,2.743933,,,,,,,,,,,,,, -37400,1.7838069,2.850782,,,,,,,,,,,,,, -37469,,,0.5883203148841858,1.7518789768218994,0.5458799600601196,1.955213069915772,50000.0,0.4323000311851501,2.6027326583862305,10000.0,17260.585029363632,19257.40512537956,17260.585029363632,1993.4251432418823,1.2563939094543457,0.0 -37500,1.5435313,3.5355704,,,,,,,,,,,,,, -37600,1.9169894,2.733632,,,,,,,,,,,,,, -37700,1.61523,2.8612418,,,,,,,,,,,,,, -37800,1.5201288,5.1090198,,,,,,,,,,,,,, -37900,1.5485561,2.7920644,,,,,,,,,,,,,, -38000,1.7501984,2.7651832,,,,,,,,,,,,,, -38100,1.8835084,2.7261748,,,,,,,,,,,,,, -38200,1.6957384,2.7581556,,,,,,,,,,,,,, -38300,1.2528002,4.871035,,,,,,,,,,,,,, -38384,,,0.6019335985183716,1.7069507837295532,0.5504199862480164,1.9425902366638184,50000.0,0.4346000254154205,2.583150625228882,10000.0,17680.8397500515,19725.88530659676,17680.8397500515,2041.5583720207208,1.2973315715789795,0.0 -38400,1.6331613,2.9649658,,,,,,,,,,,,,, -38500,1.5900805,2.7863133,,,,,,,,,,,,,, -38600,1.7675855,3.1119354,,,,,,,,,,,,,, -38700,1.6363186,3.3187704,,,,,,,,,,,,,, -38800,1.7377714,2.8334866,,,,,,,,,,,,,, -38900,1.3542873,5.316988,,,,,,,,,,,,,, -39000,1.2366636,4.7497096,,,,,,,,,,,,,, -39100,1.6161119,2.989025,,,,,,,,,,,,,, -39200,1.5891055,3.3459258,,,,,,,,,,,,,, -39300,1.7149158,2.7516477,,,,,,,,,,,,,, -39301,,,0.5848632454872131,1.7843482494354248,0.545199990272522,1.975147485733032,50000.0,0.4333000183105469,2.616333484649658,10000.0,18101.13681983948,20192.52774953842,18101.13681983948,2087.815386772156,1.333068609237671,0.0 -39400,1.312894,3.729062,,,,,,,,,,,,,, -39500,1.7238477,2.8759272,,,,,,,,,,,,,, -39600,1.8951339,2.6462634,,,,,,,,,,,,,, -39700,1.7848471,2.692247,,,,,,,,,,,,,, -39800,1.8389354,3.5218318,,,,,,,,,,,,,, -39900,1.721772,2.7082253,,,,,,,,,,,,,, -40000,1.5965768,3.079297,,,,,,,,,,,,,, -40100,1.7173321,2.7622588,,,,,,,,,,,,,, -40200,1.6085435,2.7480676,,,,,,,,,,,,,, -40218,,,0.5871874690055847,1.7634190320968628,0.5512199997901917,1.9427509307861328,50000.0,0.4305000305175781,2.606301069259644,10000.0,18521.130932092667,20661.06382799149,18521.130932092667,2136.272389173508,1.364790678024292,0.0 -40300,1.679283,2.9245632,,,,,,,,,,,,,, -40400,1.845753,2.671192,,,,,,,,,,,,,, -40500,1.4908701,3.4361486,,,,,,,,,,,,,, -40600,1.1320084,5.322177,,,,,,,,,,,,,, -40700,1.8405724,2.8143287,,,,,,,,,,,,,, -40800,1.3634917,5.088119,,,,,,,,,,,,,, -40900,1.6526278,3.5004056,,,,,,,,,,,,,, -41000,1.4696722,3.9850168,,,,,,,,,,,,,, -41100,1.5721121,5.146906,,,,,,,,,,,,,, -41135,,,0.597851574420929,1.7207162380218506,0.5503999590873718,1.93899405002594,50000.0,0.4401000142097473,2.572782754898072,10000.0,18941.098199129105,21129.979234695435,18941.098199129105,2185.132310628891,1.3997371196746826,0.0 -41200,1.8140532,2.8382201,,,,,,,,,,,,,, -41300,1.6291633,2.7613025,,,,,,,,,,,,,, -41400,1.4451655,5.389513,,,,,,,,,,,,,, -41500,1.4588089,3.7289324,,,,,,,,,,,,,, -41600,1.7741147,2.7127013,,,,,,,,,,,,,, -41700,1.3829385,4.008975,,,,,,,,,,,,,, -41800,1.9289641,2.8502364,,,,,,,,,,,,,, -41900,1.7000608,3.0039105,,,,,,,,,,,,,, -42000,1.6062003,2.7884479,,,,,,,,,,,,,, -42051,,,0.5936523079872131,1.7368556261062622,0.5519399642944336,1.9425498247146609,50000.0,0.4354000091552734,2.5808470249176025,10000.0,19361.50342464447,21595.81792283058,19361.50342464447,2230.481188297272,1.4327614307403564,0.0 -42100,1.5317565,3.6865726,,,,,,,,,,,,,, -42200,1.761892,2.882737,,,,,,,,,,,,,, -42300,1.4368055,4.365969,,,,,,,,,,,,,, -42400,1.5212597,5.098485,,,,,,,,,,,,,, -42500,1.4218254,4.3810997,,,,,,,,,,,,,, -42600,1.604941,2.7616723,,,,,,,,,,,,,, -42700,1.4986166,3.186542,,,,,,,,,,,,,, -42800,1.6214323,2.7367835,,,,,,,,,,,,,, -42900,1.2675924,5.089408,,,,,,,,,,,,,, -42966,,,0.5927538871765137,1.7270139455795288,0.5557399988174438,1.9057121276855469,50000.0,0.4368000328540802,2.5653188228607178,10000.0,19781.703052520752,22064.27545189857,19781.703052520752,2278.644249200821,1.4757418632507324,0.0 -43000,1.7755773,5.3826857,,,,,,,,,,,,,, -43100,1.5133544,3.5300643,,,,,,,,,,,,,, -43200,1.7856954,3.596795,,,,,,,,,,,,,, -43300,1.4257301,4.633341,,,,,,,,,,,,,, -43400,1.552558,2.6957343,,,,,,,,,,,,,, -43500,1.6819565,2.6416101,,,,,,,,,,,,,, -43600,1.613976,2.6582613,,,,,,,,,,,,,, -43700,1.7145491,2.72533,,,,,,,,,,,,,, -43800,1.7247665,2.620261,,,,,,,,,,,,,, -43882,,,0.6057812571525574,1.6457350254058838,0.5611599683761597,1.869145750999451,50000.0,0.4455000162124634,2.518744707107544,10000.0,20201.69106078148,22532.53761100769,20201.69106078148,2326.826430082321,1.5156099796295166,0.0 -43900,1.2710804,5.1871486,,,,,,,,,,,,,, -44000,1.7523582,2.777659,,,,,,,,,,,,,, -44100,1.5967058,3.028194,,,,,,,,,,,,,, -44200,1.8351407,2.6651125,,,,,,,,,,,,,, -44300,1.7852714,2.818116,,,,,,,,,,,,,, -44400,1.8124251,2.85246,,,,,,,,,,,,,, -44500,1.4888159,4.577803,,,,,,,,,,,,,, -44600,1.4660349,4.281443,,,,,,,,,,,,,, -44700,1.9322513,2.7769186,,,,,,,,,,,,,, -44796,,,0.6296679377555847,1.5705506801605225,0.5589399933815002,1.888720750808716,50000.0,0.4451000094413757,2.5309979915618896,10000.0,20622.27797460556,22999.989002227783,20622.27797460556,2373.602608203888,1.5518851280212402,0.0 -44800,1.6009846,3.0615804,,,,,,,,,,,,,, -44900,1.8234344,2.8128142,,,,,,,,,,,,,, -45000,1.9586695,2.8638434,,,,,,,,,,,,,, -45100,1.5278687,4.5394907,,,,,,,,,,,,,, -45200,1.6397579,2.8090267,,,,,,,,,,,,,, -45300,1.7807996,2.7491193,,,,,,,,,,,,,, -45400,1.4605358,3.743086,,,,,,,,,,,,,, -45500,1.3716936,3.9703317,,,,,,,,,,,,,, -45600,2.02385,2.5668955,,,,,,,,,,,,,, -45700,1.716233,2.7778587,,,,,,,,,,,,,, -45710,,,0.6051172018051147,1.6690466403961182,0.5628399848937988,1.861746311187744,50000.0,0.4449000358581543,2.515469789505005,10000.0,21042.34734320641,23466.94051671028,21042.34734320641,2420.3978378772736,1.5879981517791748,0.0 -45800,1.7167268,2.7094111,,,,,,,,,,,,,, -45900,1.6691085,5.159275,,,,,,,,,,,,,, -46000,1.4738046,2.9543748,,,,,,,,,,,,,, -46100,1.5957564,2.7752678,,,,,,,,,,,,,, -46200,1.41036,3.6252375,,,,,,,,,,,,,, -46300,1.6577885,2.9837577,,,,,,,,,,,,,, -46400,1.738591,2.970627,,,,,,,,,,,,,, -46500,1.6670846,2.709404,,,,,,,,,,,,,, -46600,1.9520004,2.7218256,,,,,,,,,,,,,, -46623,,,0.6066796779632568,1.6673763990402222,0.5652799606323242,1.8668668270111084,50000.0,0.4453000128269195,2.5359339714050293,10000.0,21462.64211010933,23935.413821458817,21462.64211010933,2468.4884293079376,1.6240994930267334,0.0 -46700,1.5902202,3.6034496,,,,,,,,,,,,,, -46800,1.7186284,2.7845292,,,,,,,,,,,,,, -46900,1.5480422,3.498211,,,,,,,,,,,,,, -47000,1.3981053,4.571333,,,,,,,,,,,,,, -47100,1.5559173,3.8359272,,,,,,,,,,,,,, -47200,1.6474582,2.617504,,,,,,,,,,,,,, -47300,1.5774391,2.7954953,,,,,,,,,,,,,, -47400,1.6637062,2.58718,,,,,,,,,,,,,, -47500,1.8465542,2.771733,,,,,,,,,,,,,, -47539,,,0.626269519329071,1.5778939723968506,0.5616999864578247,1.8711748123168943,50000.0,0.4479000270366668,2.509688377380371,10000.0,21882.84224438668,24402.25755739212,21882.84224438668,2515.0421693325043,1.66135573387146,0.0 -47600,1.5222234,3.323785,,,,,,,,,,,,,, -47700,1.5571412,3.79299,,,,,,,,,,,,,, -47800,1.7473036,2.8360238,,,,,,,,,,,,,, -47900,1.7604463,2.7429738,,,,,,,,,,,,,, -48000,1.5496626,3.1584516,,,,,,,,,,,,,, -48100,1.8553189,2.7195597,,,,,,,,,,,,,, -48200,1.8029084,2.8886988,,,,,,,,,,,,,, -48300,1.8939222,2.6651366,,,,,,,,,,,,,, -48400,1.7469307,2.6888928,,,,,,,,,,,,,, -48455,,,0.6093554496765137,1.643090009689331,0.5676599740982056,1.8277660608291624,50000.0,0.4543000161647796,2.494823455810547,10000.0,22303.119409799576,24869.121667146683,22303.119409799576,2561.5461843013763,1.692392110824585,0.0 -48500,1.3513952,4.856321,,,,,,,,,,,,,, -48600,1.5080053,3.514442,,,,,,,,,,,,,, -48700,1.4875392,4.665172,,,,,,,,,,,,,, -48800,1.9661065,2.6943798,,,,,,,,,,,,,, -48900,1.7378637,2.7223842,,,,,,,,,,,,,, -49000,1.2499193,5.3484917,,,,,,,,,,,,,, -49100,1.8272879,2.7035222,,,,,,,,,,,,,, -49200,1.7677822,2.7369409,,,,,,,,,,,,,, -49300,1.4241374,4.005354,,,,,,,,,,,,,, -49371,,,0.611132800579071,1.6391677856445312,0.5672399997711182,1.839793682098389,50000.0,0.4473000168800354,2.480165243148804,10000.0,22723.300478219982,25337.39897251129,22723.300478219982,2609.5553126335144,1.7268388271331787,0.0 -49400,1.6882008,2.809917,,,,,,,,,,,,,, -49500,2.0956798,2.7290983,,,,,,,,,,,,,, -49600,1.7158875,2.8887377,,,,,,,,,,,,,, -49700,1.8574986,2.6258528,,,,,,,,,,,,,, -49800,1.7683213,2.8094122,,,,,,,,,,,,,, -49900,1.5258269,5.263038,,,,,,,,,,,,,, -50000,1.3090633,5.1521635,,,,,,,,,,,,,, -50100,1.8391476,2.650238,,,,,,,,,,,,,, -50200,1.7136165,2.7651894,,,,,,,,,,,,,, -50287,,,0.6187109351158142,1.5916036367416382,0.5705599784851074,1.8403029441833496,50000.0,0.4564000070095062,2.477119445800781,10000.0,23143.49778413773,25804.48171377182,23143.49778413773,2656.3447353839874,1.7708237171173096,0.0 -50300,1.8474184,2.6121972,,,,,,,,,,,,,, -50400,1.2966592,4.8670325,,,,,,,,,,,,,, -50500,1.7895393,2.6505141,,,,,,,,,,,,,, -50600,1.8098396,2.6613576,,,,,,,,,,,,,, -50700,1.4941292,3.5234568,,,,,,,,,,,,,, -50800,1.3836696,4.1783442,,,,,,,,,,,,,, -50900,1.7360308,2.8319712,,,,,,,,,,,,,, -51000,1.9616071,2.7507083,,,,,,,,,,,,,, -51100,1.435885,5.396437,,,,,,,,,,,,,, -51200,1.73933,2.6365635,,,,,,,,,,,,,, -51201,,,0.6166015267372131,1.6008018255233765,0.5757399797439575,1.7892749309539795,50000.0,0.457800030708313,2.448241949081421,10000.0,23563.565361738205,26271.550952911377,23563.565361738205,2703.257434368133,1.8078598976135247,0.0 -51300,1.8535002,2.893326,,,,,,,,,,,,,, -51400,1.847838,2.6237607,,,,,,,,,,,,,, -51500,1.7699744,2.810106,,,,,,,,,,,,,, -51600,1.5786482,3.164136,,,,,,,,,,,,,, -51700,1.8185743,2.8478572,,,,,,,,,,,,,, -51800,1.4505558,5.372991,,,,,,,,,,,,,, -51900,1.5324496,3.1403766,,,,,,,,,,,,,, -52000,1.7733212,3.4709287,,,,,,,,,,,,,, -52100,1.9738714,2.5616634,,,,,,,,,,,,,, -52116,,,0.6089453101158142,1.667687177658081,0.5686599612236023,1.8564475774765008,50000.0,0.4490000307559967,2.522512912750244,10000.0,23983.790986299515,26741.61191439629,23983.790986299515,2752.9996156692505,1.848759651184082,0.0 -52200,1.7998743,2.669475,,,,,,,,,,,,,, -52300,1.5195285,3.8120766,,,,,,,,,,,,,, -52400,1.8609785,2.6726406,,,,,,,,,,,,,, -52500,1.7448568,2.600084,,,,,,,,,,,,,, -52600,1.319663,5.182706,,,,,,,,,,,,,, -52700,1.422924,5.150104,,,,,,,,,,,,,, -52800,1.9885904,2.7577696,,,,,,,,,,,,,, -52900,1.4785874,3.2932181,,,,,,,,,,,,,, -53000,1.8988616,2.681577,,,,,,,,,,,,,, -53032,,,0.6204296946525574,1.5840575695037842,0.5726199746131897,1.805444836616516,50000.0,0.4565000236034393,2.4617958068847656,10000.0,24403.7320561409,27206.607456207275,24403.7320561409,2797.968763113022,1.881891727447509,0.0 -53100,1.7952833,2.7197905,,,,,,,,,,,,,, -53200,1.8070267,2.7208896,,,,,,,,,,,,,, -53300,1.7959849,2.6244211,,,,,,,,,,,,,, -53400,1.8828515,2.6228952,,,,,,,,,,,,,, -53500,1.7590938,2.5712483,,,,,,,,,,,,,, -53600,1.7256347,2.729116,,,,,,,,,,,,,, -53700,1.7486343,2.6147714,,,,,,,,,,,,,, -53800,1.8394886,2.6777701,,,,,,,,,,,,,, -53900,1.7298151,2.7171907,,,,,,,,,,,,,, -53949,,,0.6197265386581421,1.6068668365478516,0.5780799984931946,1.7958285808563232,50000.0,0.4558000266551971,2.460911512374878,10000.0,24823.81976938248,27676.279418468475,24823.81976938248,2847.4651761055,1.9169471263885496,0.0 -54000,1.7325324,3.6239116,,,,,,,,,,,,,, -54100,1.3545388,3.734933,,,,,,,,,,,,,, -54200,1.8246263,2.7398305,,,,,,,,,,,,,, -54300,1.5918728,3.076853,,,,,,,,,,,,,, -54400,2.0267563,2.8666763,,,,,,,,,,,,,, -54500,1.9078002,2.4751198,,,,,,,,,,,,,, -54600,1.6224527,3.5490894,,,,,,,,,,,,,, -54700,1.8342373,2.646343,,,,,,,,,,,,,, -54800,1.9134485,2.680411,,,,,,,,,,,,,, -54863,,,0.6193749904632568,1.6178147792816162,0.5789799690246582,1.8013064861297607,50000.0,0.4589000344276428,2.458005428314209,10000.0,25243.79498982429,28139.657998085026,25243.79498982429,2890.782431602478,1.950938701629639,0.0 -54900,1.4604182,3.8123608,,,,,,,,,,,,,, -55000,2.0142126,2.6566343,,,,,,,,,,,,,, -55100,1.6239904,5.0794196,,,,,,,,,,,,,, -55200,1.5894625,3.374434,,,,,,,,,,,,,, -55300,1.3124468,5.0834618,,,,,,,,,,,,,, -55400,1.8904003,2.6332688,,,,,,,,,,,,,, -55500,1.5710006,3.7896948,,,,,,,,,,,,,, -55600,1.6092707,3.1273453,,,,,,,,,,,,,, -55700,1.6740564,2.7107875,,,,,,,,,,,,,, -55774,,,0.6284765601158142,1.5806477069854736,0.5812999606132507,1.804103970527649,50000.0,0.4631000161170959,2.472367525100708,10000.0,25663.81897115708,28606.72021007538,25663.81897115708,2937.7285718917847,1.9914829730987549,0.0 -55800,1.6832566,2.553311,,,,,,,,,,,,,, -55900,2.0187871,2.6691837,,,,,,,,,,,,,, -56000,1.7251128,3.4731631,,,,,,,,,,,,,, -56100,1.7133089,2.9492984,,,,,,,,,,,,,, -56200,1.5325408,3.0587325,,,,,,,,,,,,,, -56300,1.9864212,2.748586,,,,,,,,,,,,,, -56400,1.7811359,2.5461278,,,,,,,,,,,,,, -56500,1.9715799,2.5988996,,,,,,,,,,,,,, -56600,1.4750506,4.602859,,,,,,,,,,,,,, -56688,,,0.6417773365974426,1.5043836832046509,0.578719973564148,1.7915959358215332,50000.0,0.4606000185012817,2.447606086730957,10000.0,26084.16915845871,29075.186302661896,26084.16915845871,2985.750026702881,2.033883810043335,0.0 -56700,1.7440344,3.1043487,,,,,,,,,,,,,, -56800,1.725592,3.2513852,,,,,,,,,,,,,, -56900,1.5674942,5.076536,,,,,,,,,,,,,, -57000,1.9355624,2.6760964,,,,,,,,,,,,,, -57100,1.409986,4.06371,,,,,,,,,,,,,, -57200,1.8868873,2.5655093,,,,,,,,,,,,,, -57300,1.7563703,2.7852051,,,,,,,,,,,,,, -57400,1.7860631,2.5397758,,,,,,,,,,,,,, -57500,1.8929665,2.696832,,,,,,,,,,,,,, -57600,1.72452,5.0181985,,,,,,,,,,,,,, -57601,,,0.6244726181030273,1.5794726610183716,0.5821599960327148,1.773970603942871,50000.0,0.4669000208377838,2.421680450439453,10000.0,26504.19334626197,29542.947952270508,26504.19334626197,3033.001730442047,2.4678242206573486,0.0 -57700,1.7704036,2.6509414,,,,,,,,,,,,,, -57800,1.9451648,2.657483,,,,,,,,,,,,,, -57900,1.4149628,4.3563986,,,,,,,,,,,,,, -58000,1.4561964,4.185471,,,,,,,,,,,,,, -58100,1.9307371,2.7555337,,,,,,,,,,,,,, -58200,1.9135091,2.6591046,,,,,,,,,,,,,, -58300,1.8893397,2.5279808,,,,,,,,,,,,,, -58400,1.505108,4.0454016,,,,,,,,,,,,,, -58500,1.8145372,2.6105337,,,,,,,,,,,,,, -58519,,,0.6330859065055847,1.5414679050445557,0.5845400094985962,1.7692793607711792,50000.0,0.4695000350475311,2.407954692840576,10000.0,26924.44057393074,30011.18598389625,26924.44057393074,3080.904748916626,2.5036418437957764,0.0 -58600,2.1934016,2.7135146,,,,,,,,,,,,,, -58700,1.4970839,4.0130925,,,,,,,,,,,,,, -58800,1.5210242,3.0265462,,,,,,,,,,,,,, -58900,1.8104709,2.5581536,,,,,,,,,,,,,, -59000,1.374464,4.041095,,,,,,,,,,,,,, -59100,1.4832327,4.630603,,,,,,,,,,,,,, -59200,1.959421,2.5082083,,,,,,,,,,,,,, -59300,2.0328534,5.3940034,,,,,,,,,,,,,, -59400,2.0743597,2.4503012,,,,,,,,,,,,,, -59431,,,0.6523241996765137,1.4520971775054932,0.5870400071144104,1.7456282377243042,50000.0,0.4673000276088714,2.3970510959625244,10000.0,27344.42973256111,30479.25344944,27344.42973256111,3128.893353700638,2.5405941009521484,0.0 -59500,1.6534942,2.6109898,,,,,,,,,,,,,, -59600,1.5999136,5.0263276,,,,,,,,,,,,,, -59700,2.0195463,2.5466857,,,,,,,,,,,,,, -59800,1.5841731,3.135783,,,,,,,,,,,,,, -59900,1.5223327,4.1090713,,,,,,,,,,,,,, -60000,2.2104504,2.6899438,,,,,,,,,,,,,, -60100,1.8117158,2.5702212,,,,,,,,,,,,,, -60200,2.0300403,2.598056,,,,,,,,,,,,,, -60300,1.9214467,2.5330493,,,,,,,,,,,,,, -60346,,,0.6209570169448853,1.5848476886749268,0.58051997423172,1.7770845890045166,50000.0,0.4643000364303589,2.428161144256592,10000.0,27764.40898680687,30947.57905244828,27764.40898680687,3177.135461807251,2.592692852020264,0.0 -60400,1.4077942,4.3932633,,,,,,,,,,,,,, -60500,1.8189149,2.450865,,,,,,,,,,,,,, -60600,1.4570166,5.2332954,,,,,,,,,,,,,, -60700,1.3835428,4.9314218,,,,,,,,,,,,,, -60800,1.5876968,4.0949745,,,,,,,,,,,,,, -60900,1.9211113,2.7136548,,,,,,,,,,,,,, -61000,1.8461053,2.6391227,,,,,,,,,,,,,, -61100,1.8131003,2.4797065,,,,,,,,,,,,,, -61200,1.8914801,2.5353498,,,,,,,,,,,,,, -61260,,,0.6322460770606995,1.57372784614563,0.5851399898529053,1.7762945890426636,50000.0,0.4652000367641449,2.4296584129333496,10000.0,28184.32632660865,31415.275759458546,28184.32632660865,3224.825823068619,2.629995107650757,0.0 -61300,1.7386353,2.4732084,,,,,,,,,,,,,, -61400,1.8590511,5.1813536,,,,,,,,,,,,,, -61500,1.5108961,4.9436207,,,,,,,,,,,,,, -61600,2.1029966,2.558443,,,,,,,,,,,,,, -61700,1.5263817,5.06188,,,,,,,,,,,,,, -61800,1.9141599,2.9780707,,,,,,,,,,,,,, -61900,1.5886325,4.018123,,,,,,,,,,,,,, -62000,1.7340211,2.445835,,,,,,,,,,,,,, -62100,1.8017559,2.5302434,,,,,,,,,,,,,, -62175,,,0.6433984041213989,1.473775863647461,0.5825200080871582,1.7541874647140503,50000.0,0.4636000096797943,2.413270711898804,10000.0,28604.31938052177,31884.798108816147,28604.31938052177,3274.260570049286,2.6725854873657227,0.0 -62200,1.5964476,3.1034255,,,,,,,,,,,,,, -62300,1.4619983,4.185448,,,,,,,,,,,,,, -62400,1.5009309,3.2207336,,,,,,,,,,,,,, -62500,1.6447165,3.6838944,,,,,,,,,,,,,, -62600,1.9432905,2.5022337,,,,,,,,,,,,,, -62700,1.3558699,5.0429807,,,,,,,,,,,,,, -62800,1.6624703,3.103239,,,,,,,,,,,,,, -62900,1.9158643,2.5724967,,,,,,,,,,,,,, -63000,2.2767844,2.6788042,,,,,,,,,,,,,, -63090,,,0.6314452886581421,1.5430465936660769,0.5873000025749207,1.7412338256835938,50000.0,0.4662000238895416,2.3992021083831787,10000.0,29024.578023433685,32354.39852118492,29024.578023433685,3323.5141406059265,2.708783864974976,0.0 -63100,1.7125525,2.561471,,,,,,,,,,,,,, -63200,1.7799512,2.6070213,,,,,,,,,,,,,, -63300,1.7995543,3.4235053,,,,,,,,,,,,,, -63400,1.3180624,5.04229,,,,,,,,,,,,,, -63500,1.501228,4.1028147,,,,,,,,,,,,,, -63600,1.807752,2.4609046,,,,,,,,,,,,,, -63700,1.5735645,3.7792816,,,,,,,,,,,,,, -63800,1.9247159,2.4872499,,,,,,,,,,,,,, -63900,1.3940058,4.4067335,,,,,,,,,,,,,, -64000,1.7379017,2.4305172,,,,,,,,,,,,,, -64006,,,0.6298632621765137,1.5495365858078003,0.5903800129890442,1.736882567405701,50000.0,0.4687000215053558,2.39388108253479,10000.0,29444.48857831955,32817.74343061447,29444.48857831955,3366.8562116622925,2.7485859394073486,0.0 -64100,1.8032179,2.8375587,,,,,,,,,,,,,, -64200,1.841436,2.4291987,,,,,,,,,,,,,, -64300,1.9669232,2.5874133,,,,,,,,,,,,,, -64400,1.8967694,2.489687,,,,,,,,,,,,,, -64500,1.7681217,2.6809688,,,,,,,,,,,,,, -64600,1.8217148,2.5313702,,,,,,,,,,,,,, -64700,1.8792374,2.5440953,,,,,,,,,,,,,, -64800,1.899872,2.4804797,,,,,,,,,,,,,, -64900,1.5558254,3.4705818,,,,,,,,,,,,,, -64921,,,0.6470116972923279,1.4817765951156616,0.5940399765968323,1.731095790863037,50000.0,0.4788000285625458,2.3833441734313965,10000.0,29864.842218399048,33285.5135974884,29864.842218399048,3414.1854150295258,2.78342866897583,0.0 -65000,1.4855173,5.2788506,,,,,,,,,,,,,, -65100,1.5233401,4.3366756,,,,,,,,,,,,,, -65200,1.6072253,4.499369,,,,,,,,,,,,,, -65300,1.8769931,2.9961123,,,,,,,,,,,,,, -65400,2.163098,2.6129694,,,,,,,,,,,,,, -65500,2.0157366,2.6562705,,,,,,,,,,,,,, -65600,1.7910011,3.008278,,,,,,,,,,,,,, -65700,1.8326135,2.8202052,,,,,,,,,,,,,, -65800,1.6456488,2.7471895,,,,,,,,,,,,,, -65839,,,0.64111328125,1.4903844594955444,0.5980600118637085,1.6916265487670898,50000.0,0.4762000143527984,2.3636457920074463,10000.0,30285.16464781761,33751.76726102829,30285.16464781761,3460.025162935257,2.822338342666626,0.0 -65900,1.8947463,2.5474653,,,,,,,,,,,,,, -66000,1.8864648,2.6938102,,,,,,,,,,,,,, -66100,1.8000782,2.6301932,,,,,,,,,,,,,, -66200,1.6938416,3.7895153,,,,,,,,,,,,,, -66300,1.8421671,2.7409313,,,,,,,,,,,,,, -66400,1.5175085,4.4066143,,,,,,,,,,,,,, -66500,1.5593492,5.1415634,,,,,,,,,,,,,, -66600,2.0150726,2.6241822,,,,,,,,,,,,,, -66700,1.7717652,5.0973563,,,,,,,,,,,,,, -66752,,,0.6405078172683716,1.486151099205017,0.6005600094795227,1.6860800981521606,50000.0,0.4765000343322754,2.350003242492676,10000.0,30705.16337299347,34221.96933102608,30705.16337299347,3510.1349868774414,2.8638916015625,0.0 -66800,1.7844131,2.4692307,,,,,,,,,,,,,, -66900,1.6198137,2.945353,,,,,,,,,,,,,, -67000,1.6820402,2.9924812,,,,,,,,,,,,,, -67100,1.6681126,4.2709827,,,,,,,,,,,,,, -67200,1.4195561,5.250383,,,,,,,,,,,,,, -67300,2.0218232,2.5355818,,,,,,,,,,,,,, -67400,1.9823362,2.5182822,,,,,,,,,,,,,, -67500,1.7644533,2.4654045,,,,,,,,,,,,,, -67600,1.3908963,3.8598936,,,,,,,,,,,,,, -67668,,,0.6430078148841858,1.4980931282043457,0.5955199599266052,1.7154990434646606,50000.0,0.4780000150203705,2.372490167617798,10000.0,31125.316086292267,34690.3684053421,31125.316086292267,3558.282987117768,2.9099161624908447,0.0 -67700,1.6225966,3.627049,,,,,,,,,,,,,, -67800,1.4490906,4.4651794,,,,,,,,,,,,,, -67900,2.0713148,2.599945,,,,,,,,,,,,,, -68000,1.6840818,3.2669442,,,,,,,,,,,,,, -68100,1.8681118,2.4427533,,,,,,,,,,,,,, -68200,1.4803673,3.52598,,,,,,,,,,,,,, -68300,1.8532904,2.4905338,,,,,,,,,,,,,, -68400,2.010008,2.615158,,,,,,,,,,,,,, -68500,1.7561015,2.8633745,,,,,,,,,,,,,, -68581,,,0.662304699420929,1.3984365463256836,0.599399983882904,1.6794729232788086,50000.0,0.4852000176906585,2.3231451511383057,10000.0,31545.59892988205,35158.15919518471,31545.59892988205,3605.6924924850464,2.954616069793701,0.0 -68600,1.739283,3.1314583,,,,,,,,,,,,,, -68700,1.6839081,4.548764,,,,,,,,,,,,,, -68800,1.9309419,2.860457,,,,,,,,,,,,,, -68900,1.646444,5.0512447,,,,,,,,,,,,,, -69000,2.1220403,2.6343858,,,,,,,,,,,,,, -69100,1.7397791,3.7005942,,,,,,,,,,,,,, -69200,1.9307755,2.510518,,,,,,,,,,,,,, -69300,2.0476108,2.5209656,,,,,,,,,,,,,, -69400,1.8502684,2.6263442,,,,,,,,,,,,,, -69499,,,0.6392773389816284,1.4914463758468628,0.6018999814987183,1.6752866506576538,50000.0,0.4850000143051147,2.3373358249664307,10000.0,31965.907709360123,35627.05697226524,31965.907709360123,3654.193194627762,2.99135160446167,0.0 -69500,1.9626244,2.4715595,,,,,,,,,,,,,, -69600,1.734223,5.0715036,,,,,,,,,,,,,, -69700,1.7505639,2.5528383,,,,,,,,,,,,,, -69800,1.9524934,2.3668528,,,,,,,,,,,,,, -69900,1.6663344,2.6524374,,,,,,,,,,,,,, -70000,1.9528207,2.3600016,,,,,,,,,,,,,, -70100,1.7107627,2.8129573,,,,,,,,,,,,,, -70200,1.5777476,3.1808639,,,,,,,,,,,,,, -70300,1.5432452,4.705456,,,,,,,,,,,,,, -70400,2.0421026,2.474667,,,,,,,,,,,,,, -70415,,,0.647753894329071,1.4472213983535769,0.5977199673652649,1.679577112197876,50000.0,0.4850000143051147,2.3356761932373047,10000.0,32385.870503902435,36096.61861395836,32385.870503902435,3703.6968002319336,3.0350594520568848,0.0 -70500,1.487042,5.0846663,,,,,,,,,,,,,, -70600,1.7785867,2.8925881,,,,,,,,,,,,,, -70700,1.747173,3.2607946,,,,,,,,,,,,,, -70800,2.071811,2.472801,,,,,,,,,,,,,, -70900,1.8577663,2.7357888,,,,,,,,,,,,,, -71000,1.7163203,3.3411572,,,,,,,,,,,,,, -71100,1.5878158,4.552156,,,,,,,,,,,,,, -71200,2.0731082,2.4820127,,,,,,,,,,,,,, -71300,1.876901,2.6404157,,,,,,,,,,,,,, -71331,,,0.6681054830551147,1.376840353012085,0.5995399951934814,1.685808539390564,50000.0,0.4736000299453735,2.3584470748901367,10000.0,32805.93040370941,36565.83061218262,32805.93040370941,3752.7585911750793,3.072389841079712,0.0 -71400,1.8931375,2.6657643,,,,,,,,,,,,,, -71500,1.782124,2.6005652,,,,,,,,,,,,,, -71600,1.6973547,5.183948,,,,,,,,,,,,,, -71700,1.8265498,2.800996,,,,,,,,,,,,,, -71800,1.9022764,2.4188974,,,,,,,,,,,,,, -71900,2.1739419,2.7210796,,,,,,,,,,,,,, -72000,1.8142829,2.5300946,,,,,,,,,,,,,, -72100,1.9269639,2.4842072,,,,,,,,,,,,,, -72200,2.0014975,2.4588923,,,,,,,,,,,,,, -72246,,,0.6423632502555847,1.4910801649093628,0.5996599793434143,1.6887861490249634,50000.0,0.4778000116348266,2.3342955112457275,10000.0,33226.05816960335,37033.87846469879,33226.05816960335,3800.581439733505,3.1176953315734863,0.0 -72300,1.9489671,3.1278877,,,,,,,,,,,,,, -72400,1.4918375,4.6159,,,,,,,,,,,,,, -72500,2.0071466,2.3532774,,,,,,,,,,,,,, -72600,2.1572373,2.6450813,,,,,,,,,,,,,, -72700,1.9200336,2.3896809,,,,,,,,,,,,,, -72800,1.99072,3.0355086,,,,,,,,,,,,,, -72900,1.8201872,2.2888234,,,,,,,,,,,,,, -73000,1.697296,2.7278447,,,,,,,,,,,,,, -73100,1.5695465,3.5119765,,,,,,,,,,,,,, -73161,,,0.6495702862739563,1.450307846069336,0.604200005531311,1.6661208868026731,50000.0,0.4845000207424164,2.3093228340148926,10000.0,33646.26445245743,37499.833355903625,33646.26445245743,3846.234811067581,3.1611366271972656,0.0 -73200,1.6028755,4.372715,,,,,,,,,,,,,, -73300,1.5084082,5.04965,,,,,,,,,,,,,, -73400,1.8858235,2.3458223,,,,,,,,,,,,,, -73500,2.017878,2.399798,,,,,,,,,,,,,, -73600,2.234191,2.763331,,,,,,,,,,,,,, -73700,1.7788532,2.7472591,,,,,,,,,,,,,, -73800,1.7732685,2.2543645,,,,,,,,,,,,,, -73900,2.0005393,2.405782,,,,,,,,,,,,,, -74000,1.9358248,2.502501,,,,,,,,,,,,,, -74075,,,0.662890613079071,1.4327744245529177,0.6055799722671509,1.702540636062622,50000.0,0.4880000352859497,2.348051309585572,10000.0,34066.458422899246,37971.66522574425,34066.458422899246,3897.773527622223,3.20788049697876,0.0 -74100,1.8558598,2.698442,,,,,,,,,,,,,, -74200,1.6065369,3.740943,,,,,,,,,,,,,, -74300,1.9324588,3.0733516,,,,,,,,,,,,,, -74400,1.6603193,2.8991127,,,,,,,,,,,,,, -74500,1.9881971,2.5528078,,,,,,,,,,,,,, -74600,1.8030334,2.5739336,,,,,,,,,,,,,, -74700,2.0306342,2.55984,,,,,,,,,,,,,, -74800,1.7761183,3.3868237,,,,,,,,,,,,,, -74900,1.8209369,2.3854108,,,,,,,,,,,,,, -74988,,,0.6476953029632568,1.45093834400177,0.6082199811935425,1.6413438320159912,50000.0,0.4859000146389007,2.3130125999450684,10000.0,34486.554055690765,38436.4689707756,34486.554055690765,3942.385674238205,3.2509477138519287,0.0 -75000,2.113486,2.3877277,,,,,,,,,,,,,, -75100,1.7636527,3.0140734,,,,,,,,,,,,,, -75200,1.9779731,2.6482594,,,,,,,,,,,,,, -75300,1.6365995,3.4292555,,,,,,,,,,,,,, -75400,1.5398775,4.880454,,,,,,,,,,,,,, -75500,1.5752798,2.866067,,,,,,,,,,,,,, -75600,1.67554,4.329591,,,,,,,,,,,,,, -75700,1.8991501,2.3616872,,,,,,,,,,,,,, -75800,1.8263909,2.3715796,,,,,,,,,,,,,, -75899,,,0.6452929377555847,1.4677882194519043,0.6007599830627441,1.683293342590332,50000.0,0.4792000353336334,2.326888084411621,10000.0,34906.55943346024,38905.00736904144,34906.55943346024,3990.819508075714,3.2981767654418945,0.0 -75900,1.844781,3.6018639,,,,,,,,,,,,,, -76000,2.0874968,2.5514371,,,,,,,,,,,,,, -76100,1.9782195,2.4011028,,,,,,,,,,,,,, -76200,1.9345567,2.4716344,,,,,,,,,,,,,, -76300,1.5867624,4.3823466,,,,,,,,,,,,,, -76400,1.9026048,2.4833932,,,,,,,,,,,,,, -76500,2.0322182,2.4511263,,,,,,,,,,,,,, -76600,1.6246458,3.3715343,,,,,,,,,,,,,, -76700,2.1590505,2.4456077,,,,,,,,,,,,,, -76800,1.9908731,2.9712338,,,,,,,,,,,,,, -76813,,,0.6568945050239563,1.4267654418945312,0.6042400002479553,1.6688212156295776,50000.0,0.48580002784729,2.317957162857056,10000.0,35326.47755908966,39372.69943475723,35326.47755908966,4038.497713804245,3.342529058456421,0.0 -76900,1.759884,4.204024,,,,,,,,,,,,,, -77000,1.9325027,2.660402,,,,,,,,,,,,,, -77100,1.7700071,2.9567811,,,,,,,,,,,,,, -77200,2.080727,2.5177155,,,,,,,,,,,,,, -77300,1.97065,2.745726,,,,,,,,,,,,,, -77400,1.8624659,5.0502996,,,,,,,,,,,,,, -77500,1.9852594,2.5043757,,,,,,,,,,,,,, -77600,1.8760185,2.7842548,,,,,,,,,,,,,, -77700,1.8467786,2.6059134,,,,,,,,,,,,,, -77728,,,0.6592968702316284,1.4043962955474854,0.6110000014305115,1.6132807731628418,50000.0,0.4881000220775604,2.274594783782959,10000.0,35746.67924427986,39838.25917339325,35746.67924427986,4083.765850067138,3.3802947998046875,0.0 -77800,1.8000847,4.1996336,,,,,,,,,,,,,, -77900,1.8542324,4.9361753,,,,,,,,,,,,,, -78000,2.024901,2.338907,,,,,,,,,,,,,, -78100,1.7025132,3.612802,,,,,,,,,,,,,, -78200,1.9079204,2.3984125,,,,,,,,,,,,,, -78300,1.9802643,2.3154023,,,,,,,,,,,,,, -78400,1.9588997,2.3833108,,,,,,,,,,,,,, -78500,2.125352,2.367429,,,,,,,,,,,,,, -78600,1.9938518,2.3986828,,,,,,,,,,,,,, -78644,,,0.6494140625,1.4841097593307495,0.6038999557495117,1.6853262186050415,50000.0,0.4855000376701355,2.3392891883850098,10000.0,36166.756766080856,40307.37702679634,36166.756766080856,4132.717834472656,3.4168825149536133,0.0 -78700,2.0756652,2.4730906,,,,,,,,,,,,,, -78800,2.1446626,2.3759632,,,,,,,,,,,,,, -78900,1.6357893,4.950428,,,,,,,,,,,,,, -79000,1.685321,5.0694118,,,,,,,,,,,,,, -79100,2.0752933,2.4765854,,,,,,,,,,,,,, -79200,1.998555,2.6094966,,,,,,,,,,,,,, -79300,1.5236284,4.457205,,,,,,,,,,,,,, -79400,1.865356,2.351242,,,,,,,,,,,,,, -79500,1.9233758,5.099604,,,,,,,,,,,,,, -79559,,,0.6655663847923279,1.3757052421569824,0.6121199727058411,1.610224366188049,50000.0,0.4878000319004059,2.2785677909851074,10000.0,36586.85572004318,40774.5886554718,36586.85572004318,4179.737259864807,3.457533836364746,0.0 -79600,1.6154785,4.698633,,,,,,,,,,,,,, -79700,1.7843761,4.6912465,,,,,,,,,,,,,, -79800,1.977193,2.6837635,,,,,,,,,,,,,, -79900,2.0593128,2.5549543,,,,,,,,,,,,,, -80000,2.3246517,2.5062246,,,,,,,,,,,,,, -80100,1.7717812,3.6416862,,,,,,,,,,,,,, -80200,1.9040253,2.4385011,,,,,,,,,,,,,, -80300,2.1239393,2.482826,,,,,,,,,,,,,, -80400,1.8196497,2.9507308,,,,,,,,,,,,,, -80474,,,0.6681054830551147,1.3617504835128784,0.6156399846076965,1.5994772911071775,50000.0,0.4944000244140625,2.2579493522644043,10000.0,37007.06138277054,41240.94586133957,37007.06138277054,4225.793916940689,3.5003437995910645,0.0 -80500,1.7644614,4.0822763,,,,,,,,,,,,,, -80600,2.117618,2.447793,,,,,,,,,,,,,, -80700,2.044296,2.3945134,,,,,,,,,,,,,, -80800,1.5707982,4.439652,,,,,,,,,,,,,, -80900,2.035795,2.516211,,,,,,,,,,,,,, -81000,1.7949862,2.7866464,,,,,,,,,,,,,, -81100,1.7452679,2.9086118,,,,,,,,,,,,,, -81200,1.5922253,4.331115,,,,,,,,,,,,,, -81300,2.3368816,2.4245348,,,,,,,,,,,,,, -81388,,,0.6598241925239563,1.4070793390274048,0.6183599829673767,1.610404133796692,50000.0,0.4949000179767608,2.2725863456726074,10000.0,37427.10963559151,41709.9388833046,37427.10963559151,4274.636933803558,3.5497610569000244,0.0 -81400,2.0627637,2.3079863,,,,,,,,,,,,,, -81500,1.7614518,3.1535978,,,,,,,,,,,,,, -81600,2.0579457,2.3997312,,,,,,,,,,,,,, -81700,1.785355,3.3520248,,,,,,,,,,,,,, -81800,2.0152419,2.362036,,,,,,,,,,,,,, -81900,1.7688692,4.6878896,,,,,,,,,,,,,, -82000,1.9418459,2.3644245,,,,,,,,,,,,,, -82100,1.9939762,2.4283023,,,,,,,,,,,,,, -82200,1.8987561,2.4740653,,,,,,,,,,,,,, -82300,1.5841117,3.611967,,,,,,,,,,,,,, -82301,,,0.6676562428474426,1.3960615396499634,0.6139400005340576,1.63629150390625,50000.0,0.4942000210285187,2.2834160327911377,10000.0,37847.287464141846,42176.4779920578,37847.287464141846,4320.90106010437,3.595423936843872,0.0 -82400,2.2173564,2.5104287,,,,,,,,,,,,,, -82500,1.6618258,4.092658,,,,,,,,,,,,,, -82600,2.0432148,2.3391466,,,,,,,,,,,,,, -82700,1.9742075,3.058815,,,,,,,,,,,,,, -82800,1.7182589,4.9989743,,,,,,,,,,,,,, -82900,1.8587948,2.8447793,,,,,,,,,,,,,, -83000,1.6532023,4.048667,,,,,,,,,,,,,, -83100,2.1122768,2.4504843,,,,,,,,,,,,,, -83200,2.0497785,2.8076687,,,,,,,,,,,,,, -83217,,,0.685742199420929,1.2981655597686768,0.6175999641418457,1.6042635440826416,50000.0,0.4937000274658203,2.265681266784668,10000.0,38267.64587044716,42644.01581954956,38267.64587044716,4367.9880521297455,3.6356606483459473,0.0 -83300,1.9716294,2.3022528,,,,,,,,,,,,,, -83400,2.0417268,2.4575493,,,,,,,,,,,,,, -83500,1.6867652,3.7213526,,,,,,,,,,,,,, -83600,2.3674133,2.3816917,,,,,,,,,,,,,, -83700,2.0568314,2.451077,,,,,,,,,,,,,, -83800,1.6809078,4.53959,,,,,,,,,,,,,, -83900,1.8917847,3.542546,,,,,,,,,,,,,, -84000,2.0227683,2.6595812,,,,,,,,,,,,,, -84100,1.7416477,4.9614305,,,,,,,,,,,,,, -84133,,,0.6615625023841858,1.408983588218689,0.617680013179779,1.613947510719299,50000.0,0.4999000132083893,2.265106439590454,10000.0,38688.0792453289,43113.82888507843,38688.0792453289,4417.275773525238,3.674084186553955,0.0 -84200,2.028829,2.2889318,,,,,,,,,,,,,, -84300,1.7966323,5.0890145,,,,,,,,,,,,,, -84400,1.7810739,3.1963363,,,,,,,,,,,,,, -84500,1.6783112,4.3737683,,,,,,,,,,,,,, -84600,2.0075214,2.6036966,,,,,,,,,,,,,, -84700,1.8590415,4.6458683,,,,,,,,,,,,,, -84800,2.0621712,2.4678226,,,,,,,,,,,,,, -84900,2.3639393,2.6608572,,,,,,,,,,,,,, -85000,1.7990364,2.3947134,,,,,,,,,,,,,, -85049,,,0.6717578172683716,1.3764458894729614,0.6211400032043457,1.6145461797714231,50000.0,0.5058000087738037,2.244614601135254,10000.0,39108.14068603516,43583.24208950997,39108.14068603516,4466.535435676575,3.7140793800354,0.0 -85100,1.9307626,2.3903418,,,,,,,,,,,,,, -85200,1.9531883,3.1545703,,,,,,,,,,,,,, -85300,1.5734473,4.327223,,,,,,,,,,,,,, -85400,1.9117738,3.0763197,,,,,,,,,,,,,, -85500,2.1165466,2.359621,,,,,,,,,,,,,, -85600,1.9714782,2.1505253,,,,,,,,,,,,,, -85700,2.0521793,2.3594048,,,,,,,,,,,,,, -85800,1.793286,3.1499217,,,,,,,,,,,,,, -85900,1.7182199,3.8641462,,,,,,,,,,,,,, -85964,,,0.6827148199081421,1.2997052669525146,0.6204400062561035,1.5726711750030518,50000.0,0.5038000345230103,2.207498550415039,10000.0,39528.49926805496,44049.32660079002,39528.49926805496,4512.167064666748,3.7566189765930176,0.0 -86000,2.5473537,2.4790945,,,,,,,,,,,,,, -86100,2.127073,2.3327053,,,,,,,,,,,,,, -86200,1.6047914,3.3202887,,,,,,,,,,,,,, -86300,1.6436429,4.5847406,,,,,,,,,,,,,, -86400,2.0324519,2.497827,,,,,,,,,,,,,, -86500,1.6701953,3.1917272,,,,,,,,,,,,,, -86600,2.3733225,2.2658532,,,,,,,,,,,,,, -86700,1.6939616,3.302197,,,,,,,,,,,,,, -86800,2.0932312,2.436041,,,,,,,,,,,,,, -86878,,,0.6692578196525574,1.3671681880950928,0.6223799586296082,1.5760130882263184,50000.0,0.5035000443458557,2.202033042907715,10000.0,39948.76182985306,44516.90362381935,39948.76182985306,4559.381934642792,3.8039591312408447,0.0 -86900,2.0380232,2.464378,,,,,,,,,,,,,, -87000,1.9805782,2.250037,,,,,,,,,,,,,, -87100,2.0802994,2.3432927,,,,,,,,,,,,,, -87200,1.9047613,2.577971,,,,,,,,,,,,,, -87300,2.1570485,2.4224932,,,,,,,,,,,,,, -87400,1.8444614,2.5620496,,,,,,,,,,,,,, -87500,1.9850724,2.344049,,,,,,,,,,,,,, -87600,2.2514827,2.3661087,,,,,,,,,,,,,, -87700,2.1700168,2.592706,,,,,,,,,,,,,, -87793,,,0.6750195026397705,1.3554376363754272,0.6269999742507935,1.583235740661621,50000.0,0.5076000094413757,2.2284297943115234,10000.0,40369.03018307686,44984.48181915283,40369.03018307686,4606.599714756012,3.843261957168579,0.0 -87800,2.2673259,2.380691,,,,,,,,,,,,,, -87900,2.0228624,3.3491447,,,,,,,,,,,,,, -88000,2.1034517,2.643017,,,,,,,,,,,,,, -88100,2.3065715,2.3651197,,,,,,,,,,,,,, -88200,1.5954939,3.8231168,,,,,,,,,,,,,, -88300,1.6555872,3.4118254,,,,,,,,,,,,,, -88400,2.1080885,2.397718,,,,,,,,,,,,,, -88500,1.8121291,4.7540474,,,,,,,,,,,,,, -88600,1.9790659,2.338015,,,,,,,,,,,,,, -88700,2.225066,2.4618359,,,,,,,,,,,,,, -88707,,,0.6813281178474426,1.3070732355117798,0.6232199668884277,1.5689715147018433,50000.0,0.501800000667572,2.2088680267333984,10000.0,40789.127017498016,45450.74204015732,40789.127017498016,4652.663172245026,3.891881942749024,0.0 -88800,2.1130266,2.1388595,,,,,,,,,,,,,, -88900,2.2145808,2.6243503,,,,,,,,,,,,,, -89000,2.0683143,2.6944995,,,,,,,,,,,,,, -89100,1.738472,2.8981786,,,,,,,,,,,,,, -89200,2.1051693,2.4438891,,,,,,,,,,,,,, -89300,2.1609607,2.2706335,,,,,,,,,,,,,, -89400,1.8522005,2.7073824,,,,,,,,,,,,,, -89500,1.9766736,2.2872384,,,,,,,,,,,,,, -89600,1.9390996,5.0691547,,,,,,,,,,,,,, -89620,,,0.6719921827316284,1.342168211936951,0.6292600035667419,1.5393462181091309,50000.0,0.5082000494003296,2.181424617767334,10000.0,41209.09097337723,45920.40969824791,41209.09097337723,4702.275573730469,3.930466890335083,0.0 -89700,2.0597334,2.4548478,,,,,,,,,,,,,, -89800,2.2472296,2.9936824,,,,,,,,,,,,,, -89900,2.1630788,2.5606265,,,,,,,,,,,,,, -90000,1.9611008,2.6477869,,,,,,,,,,,,,, -90100,1.9899938,2.2533708,,,,,,,,,,,,,, -90200,2.0874574,2.3102736,,,,,,,,,,,,,, -90300,2.0968308,2.3144808,,,,,,,,,,,,,, -90400,2.0721002,2.4968185,,,,,,,,,,,,,, -90500,2.1174042,2.3238766,,,,,,,,,,,,,, -90534,,,0.6768749952316284,1.3328397274017334,0.6259399652481079,1.5650203227996826,50000.0,0.504800021648407,2.219699621200561,10000.0,41629.03953385353,46386.41199302673,41629.03953385353,4748.239414215088,3.968465805053711,0.0 -90600,2.2980008,2.345715,,,,,,,,,,,,,, -90700,2.0472364,2.2693644,,,,,,,,,,,,,, -90800,1.7130952,4.093922,,,,,,,,,,,,,, -90900,2.0690103,2.3197541,,,,,,,,,,,,,, -91000,1.7674552,3.1195366,,,,,,,,,,,,,, -91100,1.8670607,4.3055778,,,,,,,,,,,,,, -91200,1.975659,4.2098603,,,,,,,,,,,,,, -91300,1.9395638,2.2662148,,,,,,,,,,,,,, -91400,1.8921326,3.031465,,,,,,,,,,,,,, -91446,,,0.6907030940055847,1.2894749641418457,0.6323599815368652,1.547354221343994,50000.0,0.511400043964386,2.195029735565185,10000.0,42049.08965468407,46856.01924753189,42049.08965468407,4797.702326059341,4.011040687561035,0.0 -91500,2.2030127,2.504581,,,,,,,,,,,,,, -91600,2.0992856,2.286253,,,,,,,,,,,,,, -91700,2.018968,2.4409347,,,,,,,,,,,,,, -91800,1.9311366,4.3804526,,,,,,,,,,,,,, -91900,1.9333998,3.2272174,,,,,,,,,,,,,, -92000,2.1298954,2.2426696,,,,,,,,,,,,,, -92100,2.0734544,3.1570196,,,,,,,,,,,,,, -92200,1.7393261,3.4969053,,,,,,,,,,,,,, -92300,1.9162624,2.579738,,,,,,,,,,,,,, -92359,,,0.6765429377555847,1.3410677909851074,0.6250799894332886,1.5731840133666992,50000.0,0.5012000203132629,2.218226194381714,10000.0,42469.318239450455,47324.5852496624,42469.318239450455,4845.945533275604,4.053576469421387,0.0 -92400,2.1971905,2.593549,,,,,,,,,,,,,, -92500,2.3058496,2.2466528,,,,,,,,,,,,,, -92600,2.013508,2.2927897,,,,,,,,,,,,,, -92700,2.1334622,2.360736,,,,,,,,,,,,,, -92800,1.7706325,4.415358,,,,,,,,,,,,,, -92900,1.8438036,3.442202,,,,,,,,,,,,,, -93000,2.2521546,2.2418728,,,,,,,,,,,,,, -93100,2.0494373,2.4255154,,,,,,,,,,,,,, -93200,1.7591732,4.8296814,,,,,,,,,,,,,, -93276,,,0.68115234375,1.303971529006958,0.6333000063896179,1.5291578769683838,50000.0,0.5135000348091125,2.1676101684570312,10000.0,42889.63666677475,47793.083832740784,42889.63666677475,4894.030555963516,4.096000909805298,0.0 -93300,2.0632274,4.2592583,,,,,,,,,,,,,, -93400,1.8725541,2.9626904,,,,,,,,,,,,,, -93500,1.8295782,3.7018993,,,,,,,,,,,,,, -93600,1.5990838,4.7356195,,,,,,,,,,,,,, -93700,2.1225562,2.2557082,,,,,,,,,,,,,, -93800,2.059013,2.303243,,,,,,,,,,,,,, -93900,1.9386488,2.4608707,,,,,,,,,,,,,, -94000,2.0162482,2.2905996,,,,,,,,,,,,,, -94100,1.6938837,3.736494,,,,,,,,,,,,,, -94192,,,0.6882030963897705,1.2961716651916504,0.6321399807929993,1.5467734336853027,50000.0,0.5126000046730042,2.1976771354675293,10000.0,43309.85211658478,48261.21060657501,43309.85211658478,4941.8448095321655,4.140573740005493,0.0 -94200,1.7317495,4.9420714,,,,,,,,,,,,,, -94300,2.2673392,2.3519783,,,,,,,,,,,,,, -94400,1.7713841,4.2462845,,,,,,,,,,,,,, -94500,2.0561543,2.2738402,,,,,,,,,,,,,, -94600,1.8645973,2.7044313,,,,,,,,,,,,,, -94700,1.8321568,4.6166077,,,,,,,,,,,,,, -94800,1.9939194,4.7020173,,,,,,,,,,,,,, -94900,2.10599,2.1827745,,,,,,,,,,,,,, -95000,1.9027243,4.5213027,,,,,,,,,,,,,, -95100,2.3862462,2.335788,,,,,,,,,,,,,, -95110,,,0.7078515291213989,1.1914904117584229,0.6358000040054321,1.5101193189620972,50000.0,0.5103999972343445,2.1727511882781982,10000.0,43730.018189907074,48731.59634041786,43730.018189907074,4991.971055984497,4.181777000427246,0.0 -95200,1.7887015,3.2579498,,,,,,,,,,,,,, -95300,2.0796163,2.597553,,,,,,,,,,,,,, -95400,2.2915232,2.221243,,,,,,,,,,,,,, -95500,2.3539448,2.2226305,,,,,,,,,,,,,, -95600,1.7821,4.5900993,,,,,,,,,,,,,, -95700,2.0225897,2.249174,,,,,,,,,,,,,, -95800,2.1052732,2.2254796,,,,,,,,,,,,,, -95900,1.9326829,3.2546473,,,,,,,,,,,,,, -96000,2.1816,2.3139725,,,,,,,,,,,,,, -96025,,,0.6827734112739563,1.2824567556381226,0.6331599950790405,1.5004512071609497,50000.0,0.5094000101089478,2.1757290363311768,10000.0,44150.107147455215,49197.796854019165,44150.107147455215,5037.990777015686,4.222255706787109,0.0 -96100,1.791497,4.254918,,,,,,,,,,,,,, -96200,2.33543,2.374572,,,,,,,,,,,,,, -96300,2.1427817,2.2894185,,,,,,,,,,,,,, -96400,2.207511,2.2863407,,,,,,,,,,,,,, -96500,1.88514,4.60214,,,,,,,,,,,,,, -96600,1.9663559,2.7369404,,,,,,,,,,,,,, -96700,2.031934,3.2414253,,,,,,,,,,,,,, -96800,2.29253,3.3226416,,,,,,,,,,,,,, -96900,2.5183103,2.3814619,,,,,,,,,,,,,, -96938,,,0.6849414110183716,1.297351360321045,0.6337800025939941,1.5296707153320312,50000.0,0.5078999996185303,2.185168504714966,10000.0,44570.26457285881,49665.81526470184,44570.26457285881,5085.7554042339325,4.266266822814941,0.0 -97000,1.8361039,4.5838175,,,,,,,,,,,,,, -97100,2.1531074,2.7281485,,,,,,,,,,,,,, -97200,2.3004794,2.6325393,,,,,,,,,,,,,, -97300,2.0120044,2.516271,,,,,,,,,,,,,, -97400,2.548559,2.4023378,,,,,,,,,,,,,, -97500,2.3568034,4.5925045,,,,,,,,,,,,,, -97600,1.8819029,4.788124,,,,,,,,,,,,,, -97700,1.7972149,3.5029979,,,,,,,,,,,,,, -97800,2.1619158,2.8546855,,,,,,,,,,,,,, -97853,,,0.706347644329071,1.1973272562026978,0.6412000060081482,1.4900233745574951,50000.0,0.5216000080108643,2.120568037033081,10000.0,44990.63896560669,50136.00769329071,44990.63896560669,5135.472050905228,4.316110610961914,0.0 -97900,2.3058734,2.3221931,,,,,,,,,,,,,, -98000,2.0083425,2.5418284,,,,,,,,,,,,,, -98100,1.9620577,4.6044564,,,,,,,,,,,,,, -98200,2.1751492,2.285254,,,,,,,,,,,,,, -98300,2.195164,2.2398906,,,,,,,,,,,,,, -98400,2.3436937,2.3605292,,,,,,,,,,,,,, -98500,1.9984058,2.6669874,,,,,,,,,,,,,, -98600,1.8544182,4.2716193,,,,,,,,,,,,,, -98700,2.2853105,2.2362838,,,,,,,,,,,,,, -98770,,,0.6881640553474426,1.2870441675186155,0.6415799856185913,1.4922298192977903,50000.0,0.5159000158309937,2.161956548690796,10000.0,45410.94944810867,50603.08851194382,45410.94944810867,5182.138915061951,4.366321802139282,0.0 -98800,2.078293,2.7475762,,,,,,,,,,,,,, -98900,2.2868526,2.3214824,,,,,,,,,,,,,, -99000,2.29885,2.245531,,,,,,,,,,,,,, -99100,1.8613155,4.8408723,,,,,,,,,,,,,, -99200,2.2800884,2.2340298,,,,,,,,,,,,,, -99300,2.341212,2.2555304,,,,,,,,,,,,,, -99400,2.2391875,2.7278829,,,,,,,,,,,,,, -99500,2.2458973,2.1892328,,,,,,,,,,,,,, -99600,2.0867856,2.3717875,,,,,,,,,,,,,, -99686,,,0.6941796541213989,1.2387899160385132,0.6456999778747559,1.462598204612732,50000.0,0.5254999995231628,2.105508327484131,10000.0,45831.12450551987,51072.9808216095,45831.12450551987,5231.758380651474,4.412039041519165,0.0 -99700,2.0905802,2.248328,,,,,,,,,,,,,, -99800,2.0728927,2.2050831,,,,,,,,,,,,,, -99900,1.9486052,2.1264362,,,,,,,,,,,,,, -100000,1.7855062,4.060932,,,,,,,,,,,,,, -100100,2.4904993,2.3023303,,,,,,,,,,,,,, -100200,2.1561363,2.1306298,,,,,,,,,,,,,, -100300,2.2591636,2.3376725,,,,,,,,,,,,,, -100400,2.3224049,2.335036,,,,,,,,,,,,,, -100500,1.9371626,3.0744958,,,,,,,,,,,,,, -100594,,,0.7018945217132568,1.217579960823059,0.642799973487854,1.4809935092926023,50000.0,0.5175999999046326,2.107346296310425,10000.0,46250.72084593773,51538.47866082192,46250.72084593773,5277.18566441536,4.834028720855713,0.0 -100600,2.3867714,2.1748967,,,,,,,,,,,,,, -100700,2.3345282,2.3708763,,,,,,,,,,,,,, -100800,1.8527836,4.652047,,,,,,,,,,,,,, -100900,2.1560228,2.3179743,,,,,,,,,,,,,, -101000,2.2245672,2.1039917,,,,,,,,,,,,,, -101100,2.0103521,2.946003,,,,,,,,,,,,,, -101200,2.0728657,2.655314,,,,,,,,,,,,,, -101300,1.9579401,3.9761035,,,,,,,,,,,,,, -101400,2.3010879,2.3374116,,,,,,,,,,,,,, -101500,2.2063394,2.340593,,,,,,,,,,,,,, -101508,,,0.6890038847923279,1.2767558097839355,0.643839955329895,1.4891250133514404,50000.0,0.5198000073432922,2.138503074645996,10000.0,46670.82105779648,52008.493248701096,46670.82105779648,5326.998807668686,4.8833067417144775,0.0 -101600,2.2348287,2.3965492,,,,,,,,,,,,,, -101700,2.178886,2.2032661,,,,,,,,,,,,,, -101800,2.183266,2.5195298,,,,,,,,,,,,,, -101900,2.1101644,2.406248,,,,,,,,,,,,,, -102000,2.5744424,2.2514448,,,,,,,,,,,,,, -102100,1.9838662,3.957716,,,,,,,,,,,,,, -102200,2.093229,3.0679204,,,,,,,,,,,,,, -102300,2.41417,2.2204566,,,,,,,,,,,,,, -102400,1.928354,3.898003,,,,,,,,,,,,,, -102425,,,0.7002929449081421,1.2077990770339966,0.6493200063705444,1.445142149925232,50000.0,0.5297000408172607,2.082939863204956,10000.0,47091.16972374916,52478.20104813576,47091.16972374916,5376.259582996368,4.928251028060913,0.0 -102500,2.3079777,2.422987,,,,,,,,,,,,,, -102600,2.238245,2.155653,,,,,,,,,,,,,, -102700,2.0153518,3.0960202,,,,,,,,,,,,,, -102800,2.2193518,2.2093031,,,,,,,,,,,,,, -102900,2.1769245,2.275445,,,,,,,,,,,,,, -103000,2.1950364,2.1177366,,,,,,,,,,,,,, -103100,2.1199818,2.21703,,,,,,,,,,,,,, -103200,1.9479005,3.7314506,,,,,,,,,,,,,, -103300,1.9487795,4.086759,,,,,,,,,,,,,, -103341,,,0.7048437595367432,1.213708758354187,0.647599995136261,1.473738670349121,50000.0,0.5266000032424927,2.116064548492432,10000.0,47511.17261624336,52946.35978603363,47511.17261624336,5424.322031259537,4.9686126708984375,0.0 -103400,1.9861976,3.822448,,,,,,,,,,,,,, -103500,2.1831028,4.0050516,,,,,,,,,,,,,, -103600,2.251987,2.2292948,,,,,,,,,,,,,, -103700,2.228478,2.6776383,,,,,,,,,,,,,, -103800,2.2041025,2.19978,,,,,,,,,,,,,, -103900,2.0948043,3.0233626,,,,,,,,,,,,,, -104000,2.4310553,2.2144618,,,,,,,,,,,,,, -104100,2.3155582,2.5705595,,,,,,,,,,,,,, -104200,2.1932807,2.1851542,,,,,,,,,,,,,, -104258,,,0.6996484398841858,1.24612295627594,0.6481599807739258,1.4764598608016968,50000.0,0.5238000154495239,2.127571582794189,10000.0,47931.445341825485,53416.33458185196,47931.445341825485,5473.918003559113,5.021668195724487,0.0 -104300,2.4339242,4.7576203,,,,,,,,,,,,,, -104400,2.372868,2.308864,,,,,,,,,,,,,, -104500,2.1845233,2.228965,,,,,,,,,,,,,, -104600,2.4586258,2.1344905,,,,,,,,,,,,,, -104700,2.2925155,2.1765013,,,,,,,,,,,,,, -104800,2.300517,2.4431236,,,,,,,,,,,,,, -104900,2.6698017,2.1712,,,,,,,,,,,,,, -105000,1.9001465,3.6621056,,,,,,,,,,,,,, -105100,2.4797218,2.256795,,,,,,,,,,,,,, -105170,,,0.701464831829071,1.2353618144989014,0.6521399617195129,1.4591549634933472,50000.0,0.5285000205039978,2.1071197986602783,10000.0,48351.39259338379,53882.78431844711,48351.39259338379,5520.325870513916,5.064110040664673,0.0 -105200,2.155663,2.212997,,,,,,,,,,,,,, -105300,2.4083393,2.1691706,,,,,,,,,,,,,, -105400,2.2510958,2.1491,,,,,,,,,,,,,, -105500,2.5104342,2.4071615,,,,,,,,,,,,,, -105600,2.402516,2.096377,,,,,,,,,,,,,, -105700,2.2944005,2.1199176,,,,,,,,,,,,,, -105800,2.4161587,2.1430595,,,,,,,,,,,,,, -105900,2.4529574,2.2218385,,,,,,,,,,,,,, -106000,1.8684483,3.1755376,,,,,,,,,,,,,, -106082,,,0.708984375,1.166994333267212,0.6547200083732605,1.4128574132919312,50000.0,0.5339000225067139,2.0650956630706787,10000.0,48771.49581623077,54349.56981277466,48771.49581623077,5566.9127151966095,5.108115911483765,0.0 -106100,2.70566,2.3337438,,,,,,,,,,,,,, -106200,2.317549,2.2348328,,,,,,,,,,,,,, -106300,1.9492543,4.537573,,,,,,,,,,,,,, -106400,2.2685578,2.0909214,,,,,,,,,,,,,, -106500,2.3034286,2.1093647,,,,,,,,,,,,,, -106600,2.2996643,2.201841,,,,,,,,,,,,,, -106700,2.3668697,2.0562148,,,,,,,,,,,,,, -106800,2.137172,4.4709983,,,,,,,,,,,,,, -106900,2.5364838,2.421651,,,,,,,,,,,,,, -106995,,,0.7225390672683716,1.131512999534607,0.6502199769020081,1.4502038955688477,50000.0,0.5242000222206116,2.100848436355591,10000.0,49191.43444299698,54819.61263132095,49191.43444299698,5616.92050743103,5.152854442596436,0.0 -107000,2.1891484,2.5295305,,,,,,,,,,,,,, -107100,1.9815555,4.1744823,,,,,,,,,,,,,, -107200,2.064255,2.7227173,,,,,,,,,,,,,, -107300,2.235533,2.592071,,,,,,,,,,,,,, -107400,2.5801105,2.239923,,,,,,,,,,,,,, -107500,2.23581,4.073599,,,,,,,,,,,,,, -107600,2.511917,2.1818848,,,,,,,,,,,,,, -107700,2.4105797,2.0247126,,,,,,,,,,,,,, -107800,2.3150373,2.4820387,,,,,,,,,,,,,, -107900,2.41748,2.2049148,,,,,,,,,,,,,, -107912,,,0.7108789086341858,1.1651406288146973,0.6574999690055847,1.407038688659668,50000.0,0.5339000225067139,2.047857999801636,10000.0,49611.7597925663,55291.49946713448,49611.7597925663,5668.385055780411,5.196820497512817,0.0 -108000,2.266267,4.7363873,,,,,,,,,,,,,, -108100,2.758834,2.223428,,,,,,,,,,,,,, -108200,2.3290718,3.7390547,,,,,,,,,,,,,, -108300,2.2522993,2.2102,,,,,,,,,,,,,, -108400,2.1287887,4.4446764,,,,,,,,,,,,,, -108500,2.194318,3.0862415,,,,,,,,,,,,,, -108600,2.6249714,2.353696,,,,,,,,,,,,,, -108700,2.4150462,2.191533,,,,,,,,,,,,,, -108800,2.2179434,3.5934527,,,,,,,,,,,,,, -108827,,,0.7103906273841858,1.18244469165802,0.6546799540519714,1.4344879388809204,50000.0,0.5293000340461731,2.086609125137329,10000.0,50032.11195850372,55761.36970353127,50032.11195850372,5717.807471752167,5.240591287612915,0.0 -108900,2.3954732,2.2570648,,,,,,,,,,,,,, -109000,2.6577742,2.2778196,,,,,,,,,,,,,, -109100,2.1891427,4.708183,,,,,,,,,,,,,, -109200,1.9370842,3.7676415,,,,,,,,,,,,,, -109300,2.470119,2.2672944,,,,,,,,,,,,,, -109400,2.3898952,2.0701067,,,,,,,,,,,,,, -109500,2.3191743,2.1958735,,,,,,,,,,,,,, -109600,1.915822,3.4451995,,,,,,,,,,,,,, -109700,2.3854456,2.5088778,,,,,,,,,,,,,, -109742,,,0.7229882478713989,1.1345579624176023,0.6586799621582031,1.426134467124939,50000.0,0.5376999974250793,2.0645999908447266,10000.0,50452.2932767868,56229.43710780144,50452.2932767868,5765.595708608627,5.286548137664795,0.0 -109800,2.0654666,4.2041206,,,,,,,,,,,,,, -109900,2.5248935,2.1750247,,,,,,,,,,,,,, -110000,2.4738798,2.1118722,,,,,,,,,,,,,, -110100,2.4416232,2.2443283,,,,,,,,,,,,,, -110200,2.3402233,4.3599553,,,,,,,,,,,,,, -110300,2.1968606,3.1465712,,,,,,,,,,,,,, -110400,2.033809,4.133288,,,,,,,,,,,,,, -110500,2.4053721,2.145793,,,,,,,,,,,,,, -110600,1.9475944,4.1155457,,,,,,,,,,,,,, -110658,,,0.7133398056030273,1.1982223987579346,0.6570599675178528,1.440184235572815,50000.0,0.5340999960899353,2.095458984375,10000.0,50872.51973128319,56695.58960843086,50872.51973128319,5811.417835235596,5.338428258895874,0.0 -110700,2.2654448,2.853128,,,,,,,,,,,,,, -110800,2.2374823,4.1752434,,,,,,,,,,,,,, -110900,2.1459692,3.8647175,,,,,,,,,,,,,, -111000,2.3012571,2.238623,,,,,,,,,,,,,, -111100,2.1775246,3.0003574,,,,,,,,,,,,,, -111200,2.1845942,2.9830616,,,,,,,,,,,,,, -111300,2.1574304,3.0904958,,,,,,,,,,,,,, -111400,2.0431678,4.6118937,,,,,,,,,,,,,, -111500,2.6396823,2.1304305,,,,,,,,,,,,,, -111575,,,0.7201171517372131,1.123279690742493,0.6644399762153625,1.3742200136184692,50000.0,0.5422000288963318,2.019754409790039,10000.0,51292.6597931385,57164.1505010128,51292.6597931385,5859.738709926605,5.385877132415772,0.0 -111600,2.2037365,3.0251193,,,,,,,,,,,,,, -111700,2.3732746,2.2496316,,,,,,,,,,,,,, -111800,2.3219962,3.829143,,,,,,,,,,,,,, -111900,2.7055814,2.1259785,,,,,,,,,,,,,, -112000,2.4732866,2.068357,,,,,,,,,,,,,, -112100,2.5355322,2.0797348,,,,,,,,,,,,,, -112200,2.4849246,4.7856116,,,,,,,,,,,,,, -112300,2.163202,3.1904402,,,,,,,,,,,,,, -112400,2.458723,2.3326764,,,,,,,,,,,,,, -112491,,,0.7299218773841858,1.0928118228912354,0.6615599989891052,1.383098602294922,50000.0,0.5380000472068787,2.0178375244140625,10000.0,51712.80060315132,57629.05432772637,51712.80060315132,5904.406593084335,5.429534912109375,0.0 -112500,2.3691611,2.164809,,,,,,,,,,,,,, -112600,2.275371,2.7705562,,,,,,,,,,,,,, -112700,2.6631923,2.1050034,,,,,,,,,,,,,, -112800,2.389391,4.383293,,,,,,,,,,,,,, -112900,2.963853,2.0941746,,,,,,,,,,,,,, -113000,2.6550672,2.3200054,,,,,,,,,,,,,, -113100,2.4755666,2.0714073,,,,,,,,,,,,,, -113200,2.319669,2.2673934,,,,,,,,,,,,,, -113300,2.5594525,2.086801,,,,,,,,,,,,,, -113400,2.8272343,1.9392238,,,,,,,,,,,,,, -113407,,,0.7118163704872131,1.1474217176437378,0.6642599701881409,1.3698629140853882,50000.0,0.5384000539779663,2.007035255432129,10000.0,52132.85252594948,58097.311703681946,52132.85252594948,5952.510481357575,5.479061126708984,0.0 -113500,2.5831707,2.0436556,,,,,,,,,,,,,, -113600,2.796908,2.2506166,,,,,,,,,,,,,, -113700,2.4363937,2.185411,,,,,,,,,,,,,, -113800,2.2022548,2.790924,,,,,,,,,,,,,, -113900,2.5537856,2.1223073,,,,,,,,,,,,,, -114000,2.2537699,4.246186,,,,,,,,,,,,,, -114100,2.5585175,4.5605764,,,,,,,,,,,,,, -114200,2.1170156,3.7120974,,,,,,,,,,,,,, -114300,2.6250913,2.1598258,,,,,,,,,,,,,, -114319,,,0.7221288681030273,1.117608666419983,0.6670199632644653,1.3616474866867063,50000.0,0.5433000326156616,2.0026955604553223,10000.0,52552.87658810616,58562.885924339294,52552.87658810616,5997.965879917145,5.521757125854492,0.0 -114400,2.3534098,4.530716,,,,,,,,,,,,,, -114500,2.9980402,2.1435204,,,,,,,,,,,,,, -114600,2.6812527,2.0815716,,,,,,,,,,,,,, -114700,2.768874,2.124269,,,,,,,,,,,,,, -114800,2.649207,2.1668181,,,,,,,,,,,,,, -114900,2.5134633,2.1051497,,,,,,,,,,,,,, -115000,2.4478605,3.9484134,,,,,,,,,,,,,, -115100,2.4976296,2.1300602,,,,,,,,,,,,,, -115200,2.4178128,4.1414175,,,,,,,,,,,,,, -115236,,,0.7292773127555847,1.092387318611145,0.6671000123023987,1.3630927801132202,50000.0,0.5494000315666199,2.007647752761841,10000.0,52973.100821495056,59032.51432132721,52973.100821495056,6047.270622730255,5.568381786346436,0.0 -115300,2.6648633,2.2661133,,,,,,,,,,,,,, -115400,2.3070705,2.9631097,,,,,,,,,,,,,, -115500,2.3170183,2.5985181,,,,,,,,,,,,,, -115600,2.2403164,3.4755359,,,,,,,,,,,,,, -115700,2.5041533,2.0034044,,,,,,,,,,,,,, -115800,2.1885948,3.271861,,,,,,,,,,,,,, -115900,2.2004888,3.5343447,,,,,,,,,,,,,, -116000,2.5405216,2.0956137,,,,,,,,,,,,,, -116100,2.4106755,3.4741244,,,,,,,,,,,,,, -116150,,,0.7247265577316284,1.1091890335083008,0.6702199578285217,1.354252815246582,50000.0,0.5452000498771667,1.99252724647522,10000.0,53393.11542224884,59500.606415987015,53393.11542224884,6095.24760055542,5.615738153457642,0.0 -116200,2.523853,2.0722268,,,,,,,,,,,,,, -116300,2.396117,2.0493085,,,,,,,,,,,,,, -116400,2.4503963,2.0936234,,,,,,,,,,,,,, -116500,2.6631286,2.1203861,,,,,,,,,,,,,, -116600,2.2391622,3.135484,,,,,,,,,,,,,, -116700,2.7175004,2.1309404,,,,,,,,,,,,,, -116800,2.664604,1.9830086,,,,,,,,,,,,,, -116900,2.4451194,2.014882,,,,,,,,,,,,,, -117000,2.2182903,3.2353606,,,,,,,,,,,,,, -117066,,,0.7264843583106995,1.096244215965271,0.672760009765625,1.340212106704712,50000.0,0.5508000254631042,1.99046790599823,10000.0,53813.10630583763,59967.58908033371,53813.10630583763,6142.140088558197,5.661764621734619,0.0 -117100,2.748169,4.286808,,,,,,,,,,,,,, -117200,2.5013514,4.2557745,,,,,,,,,,,,,, -117300,2.5341654,2.2150166,,,,,,,,,,,,,, -117400,2.6068048,2.160309,,,,,,,,,,,,,, -117500,2.9607704,2.0401073,,,,,,,,,,,,,, -117600,2.5402765,4.5721793,,,,,,,,,,,,,, -117700,2.4925816,1.9998001,,,,,,,,,,,,,, -117800,2.6826916,2.0122006,,,,,,,,,,,,,, -117900,2.278437,4.470094,,,,,,,,,,,,,, -117982,,,0.7369335889816284,1.0518841743469238,0.6758599877357483,1.3248143196105957,50000.0,0.5498000383377075,1.969029784202576,10000.0,54233.0949075222,60437.959518909454,54233.0949075222,6192.417934656143,5.713583707809448,0.0 -118000,3.1845117,2.0576823,,,,,,,,,,,,,, -118100,2.637486,2.1608713,,,,,,,,,,,,,, -118200,2.6125622,4.773152,,,,,,,,,,,,,, -118300,2.2862864,4.229745,,,,,,,,,,,,,, -118400,2.5502908,2.050189,,,,,,,,,,,,,, -118500,2.8127534,2.2862868,,,,,,,,,,,,,, -118600,2.7391262,1.9595469,,,,,,,,,,,,,, -118700,2.7552128,2.301017,,,,,,,,,,,,,, -118800,2.6371858,4.495287,,,,,,,,,,,,,, -118897,,,0.7423437237739563,1.0445845127105713,0.6692599654197693,1.361345291137695,50000.0,0.5490000247955322,2.009008407592773,10000.0,54653.28877854347,60906.417788267136,54653.28877854347,6240.577670812607,5.76579213142395,0.0 -118900,2.400632,2.5913436,,,,,,,,,,,,,, -119000,2.249098,3.35608,,,,,,,,,,,,,, -119100,2.6749837,2.005149,,,,,,,,,,,,,, -119200,2.6442664,2.2337937,,,,,,,,,,,,,, -119300,2.799783,2.3970156,,,,,,,,,,,,,, -119400,2.5126705,2.6377547,,,,,,,,,,,,,, -119500,2.6091807,1.9880469,,,,,,,,,,,,,, -119600,2.4797535,2.8299112,,,,,,,,,,,,,, -119700,2.4988666,4.0533834,,,,,,,,,,,,,, -119800,2.7374253,1.9893293,,,,,,,,,,,,,, -119812,,,0.7238671779632568,1.152235507965088,0.6700599789619446,1.3908920288085938,50000.0,0.5487000346183777,2.030686378479004,10000.0,55073.62894105911,61375.50624871254,55073.62894105911,6289.227605819702,5.812516689300537,0.0 -119900,2.6561701,2.0953922,,,,,,,,,,,,,, -120000,2.450634,1.9101732,,,,,,,,,,,,,, -120100,2.3416903,3.2318559,,,,,,,,,,,,,, -120200,2.6904702,2.1391807,,,,,,,,,,,,,, -120300,2.3821747,3.859333,,,,,,,,,,,,,, -120400,2.536713,2.9786773,,,,,,,,,,,,,, -120500,3.0854118,2.0275931,,,,,,,,,,,,,, -120600,2.473074,3.0335443,,,,,,,,,,,,,, -120700,2.8325663,2.0615878,,,,,,,,,,,,,, -120727,,,0.7369335889816284,1.071627855300903,0.6764400005340576,1.331760287284851,50000.0,0.5505000352859497,1.9711410999298096,10000.0,55493.84318685532,61845.34255743027,55493.84318685532,6338.752652645111,5.857455015182495,0.0 -120800,2.3785372,3.1769812,,,,,,,,,,,,,, -120900,2.5405061,4.601305,,,,,,,,,,,,,, -121000,2.954017,2.0535712,,,,,,,,,,,,,, -121100,2.6052654,2.1811578,,,,,,,,,,,,,, -121200,3.1257873,2.0931363,,,,,,,,,,,,,, -121300,2.7581599,3.055455,,,,,,,,,,,,,, -121400,2.7775347,2.0700846,,,,,,,,,,,,,, -121500,2.4736195,2.2048845,,,,,,,,,,,,,, -121600,2.42313,3.4850883,,,,,,,,,,,,,, -121644,,,0.746386706829071,1.0214617252349854,0.6764199733734131,1.3321937322616575,50000.0,0.5493000149726868,1.980876445770264,10000.0,55914.0778169632,62315.006165504456,55914.0778169632,6388.082646846771,5.904085636138916,0.0 -121700,2.858127,1.9951252,,,,,,,,,,,,,, -121800,2.812856,2.0878494,,,,,,,,,,,,,, -121900,2.7155426,3.9466834,,,,,,,,,,,,,, -122000,2.5717347,4.5133066,,,,,,,,,,,,,, -122100,2.6980815,1.9864177,,,,,,,,,,,,,, -122200,2.4556615,2.5142076,,,,,,,,,,,,,, -122300,2.758747,2.0744474,,,,,,,,,,,,,, -122400,2.5090008,2.4545267,,,,,,,,,,,,,, -122500,3.2128115,2.05971,,,,,,,,,,,,,, -122559,,,0.7341406345367432,1.0811148881912231,0.6796999573707581,1.328904390335083,50000.0,0.5550000071525574,1.9604276418685915,10000.0,56334.39921450615,62786.4237511158,56334.39921450615,6439.080575942993,5.9502387046813965,0.0 -122600,2.758382,2.0947063,,,,,,,,,,,,,, -122700,3.0275874,2.0025032,,,,,,,,,,,,,, -122800,2.7550936,2.0882943,,,,,,,,,,,,,, -122900,2.6329398,2.549429,,,,,,,,,,,,,, -123000,2.7894416,2.5613916,,,,,,,,,,,,,, -123100,2.6902645,2.4832506,,,,,,,,,,,,,, -123200,3.242601,1.9906938,,,,,,,,,,,,,, -123300,2.535782,2.0404482,,,,,,,,,,,,,, -123400,2.8015532,1.9269154,,,,,,,,,,,,,, -123473,,,0.7335742115974426,1.078386902809143,0.6755599975585938,1.3359389305114746,50000.0,0.5523000359535217,1.98598861694336,10000.0,56754.34090018272,63255.76044034958,56754.34090018272,6488.3691465854645,6.0046210289001465,0.0 -123500,2.6763697,2.018134,,,,,,,,,,,,,, -123600,2.7411733,2.2271435,,,,,,,,,,,,,, -123700,2.6017928,1.9013654,,,,,,,,,,,,,, -123800,2.6602807,2.1219459,,,,,,,,,,,,,, -123900,2.9325476,2.2107723,,,,,,,,,,,,,, -124000,2.9509757,2.070937,,,,,,,,,,,,,, -124100,2.849383,2.0978446,,,,,,,,,,,,,, -124200,3.1011038,2.083594,,,,,,,,,,,,,, -124300,2.443139,2.4104638,,,,,,,,,,,,,, -124389,,,0.75,0.994757890701294,0.6810799837112427,1.293065071105957,50000.0,0.5511000156402588,1.9417612552642824,10000.0,57174.481586933136,63721.19270992279,57174.481586933136,6533.5595326423645,6.052314043045044,0.0 -124400,2.9252543,2.0890257,,,,,,,,,,,,,, -124500,2.5683253,2.3863728,,,,,,,,,,,,,, -124600,2.7314556,4.4992423,,,,,,,,,,,,,, -124700,2.8503754,2.3704386,,,,,,,,,,,,,, -124800,2.9857326,1.9822305,,,,,,,,,,,,,, -124900,2.5135603,3.038348,,,,,,,,,,,,,, -125000,2.9004726,1.9594097,,,,,,,,,,,,,, -125100,2.7304196,1.9227082,,,,,,,,,,,,,, -125200,3.0699923,2.103435,,,,,,,,,,,,,, -125300,2.7064786,1.9878476,,,,,,,,,,,,,, -125305,,,0.74378901720047,1.0434715747833252,0.6869999766349792,1.2920804023742676,50000.0,0.560200035572052,1.9421415328979488,10000.0,57594.81980538368,64192.04764842987,57594.81980538368,6583.968408584595,6.108830213546753,0.0 -125400,2.6549294,4.299717,,,,,,,,,,,,,, -125500,2.8638391,2.0236247,,,,,,,,,,,,,, -125600,2.6636102,2.9279761,,,,,,,,,,,,,, -125700,3.009747,1.9491041,,,,,,,,,,,,,, -125800,3.0847836,2.1733258,,,,,,,,,,,,,, -125900,2.8836753,2.869381,,,,,,,,,,,,,, -126000,3.3193784,3.513115,,,,,,,,,,,,,, -126100,3.210104,1.9804004,,,,,,,,,,,,,, -126200,2.9368894,1.8739247,,,,,,,,,,,,,, -126221,,,0.7441992163658142,1.0305161476135254,0.6879000067710876,1.277883529663086,50000.0,0.5663000345230103,1.911031007766724,10000.0,58015.1064915657,64661.13260555267,58015.1064915657,6632.670027494431,6.153582334518433,0.0 -126300,2.7774508,4.209408,,,,,,,,,,,,,, -126400,2.511993,2.5540376,,,,,,,,,,,,,, -126500,3.1279376,4.60186,,,,,,,,,,,,,, -126600,2.7277882,3.6535597,,,,,,,,,,,,,, -126700,2.8423684,2.0725985,,,,,,,,,,,,,, -126800,2.8807204,2.1835465,,,,,,,,,,,,,, -126900,2.9750693,1.9273264,,,,,,,,,,,,,, -127000,2.6280584,2.3178606,,,,,,,,,,,,,, -127100,2.5271943,3.3829293,,,,,,,,,,,,,, -127138,,,0.7535937428474426,0.9919482469558716,0.6867799758911133,1.2697197198867798,50000.0,0.5621000528335571,1.8933143615722656,10000.0,58435.141280412674,65131.23244309425,58435.141280412674,6682.635187864304,6.201958656311035,0.0 -127200,3.1058717,4.299801,,,,,,,,,,,,,, -127300,3.1107152,1.903212,,,,,,,,,,,,,, -127400,2.5532138,2.3648257,,,,,,,,,,,,,, -127500,3.0063207,3.9401317,,,,,,,,,,,,,, -127600,2.6533704,1.930968,,,,,,,,,,,,,, -127700,2.867516,2.5753908,,,,,,,,,,,,,, -127800,2.5323567,2.3278878,,,,,,,,,,,,,, -127900,3.0632987,1.9604646,,,,,,,,,,,,,, -128000,2.8456051,3.643981,,,,,,,,,,,,,, -128053,,,0.7463476657867432,1.0128285884857178,0.6887399554252625,1.2626125812530518,50000.0,0.5675000548362732,1.9025838375091555,10000.0,58855.1561756134,65600.41216874123,58855.1561756134,6731.701937913895,6.2473344802856445,0.0 -128100,2.6775227,3.1580753,,,,,,,,,,,,,, -128200,2.8810868,4.2743196,,,,,,,,,,,,,, -128300,3.011694,4.6132407,,,,,,,,,,,,,, -128400,3.0650945,2.1706073,,,,,,,,,,,,,, -128500,3.145492,1.8636537,,,,,,,,,,,,,, -128600,2.738062,2.5022054,,,,,,,,,,,,,, -128700,2.7895477,2.4414124,,,,,,,,,,,,,, -128800,3.0730512,2.020978,,,,,,,,,,,,,, -128900,3.1754832,1.8700913,,,,,,,,,,,,,, -128968,,,0.7517382502555847,0.9961110353469848,0.6873199939727783,1.2667717933654783,50000.0,0.5649000406265259,1.8984917402267456,10000.0,59275.09725427628,66071.41319656372,59275.09725427628,6782.662457227707,6.294980049133301,0.0 -129000,3.036555,1.8236973,,,,,,,,,,,,,, -129100,2.9166927,4.3807845,,,,,,,,,,,,,, -129200,2.8718789,2.26024,,,,,,,,,,,,,, -129300,2.375508,3.3433995,,,,,,,,,,,,,, -129400,3.2056878,1.9567089,,,,,,,,,,,,,, -129500,2.894284,4.3767715,,,,,,,,,,,,,, -129600,2.8862646,1.9656526,,,,,,,,,,,,,, -129700,3.0733397,1.9912777,,,,,,,,,,,,,, -129800,3.0800173,2.0679185,,,,,,,,,,,,,, -129883,,,0.7596484422683716,0.9546266794204712,0.693839967250824,1.2397924661636353,50000.0,0.5732000470161438,1.85313093662262,10000.0,59695.053881406784,66539.30385136604,59695.053881406784,6830.492385864258,6.34734582901001,0.0 -129900,3.2449636,1.9195619,,,,,,,,,,,,,, -130000,3.2060986,2.059269,,,,,,,,,,,,,, -130100,2.9320505,2.6754222,,,,,,,,,,,,,, -130200,2.888696,2.4245458,,,,,,,,,,,,,, -130300,3.1235614,4.569894,,,,,,,,,,,,,, -130400,3.5894933,2.04123,,,,,,,,,,,,,, -130500,2.909574,4.4617634,,,,,,,,,,,,,, -130600,2.8295197,3.7473936,,,,,,,,,,,,,, -130700,3.100945,1.9445087,,,,,,,,,,,,,, -130796,,,0.764453113079071,0.9491795897483826,0.691540002822876,1.2591371536254885,50000.0,0.5675000548362732,1.8973604440689087,10000.0,60115.20013332367,67008.34401488304,60115.20013332367,6879.277712106705,6.404633522033691,0.0 -130800,2.8283503,3.490728,,,,,,,,,,,,,, -130900,3.0534198,1.8421447,,,,,,,,,,,,,, -131000,2.855181,2.7984924,,,,,,,,,,,,,, -131100,3.100281,1.8823308,,,,,,,,,,,,,, -131200,3.5230098,2.035772,,,,,,,,,,,,,, -131300,3.126903,1.8663864,,,,,,,,,,,,,, -131400,3.0919144,2.1477723,,,,,,,,,,,,,, -131500,3.1955163,1.9844632,,,,,,,,,,,,,, -131600,2.914349,4.255667,,,,,,,,,,,,,, -131700,3.098037,4.3049226,,,,,,,,,,,,,, -131711,,,0.7549218535423279,0.9674057364463806,0.6934999823570251,1.2365845441818235,50000.0,0.5690000057220459,1.871274471282959,10000.0,60535.40959310532,67477.44813871384,60535.40959310532,6928.066775798798,6.457538366317749,0.0 -131800,2.8493881,3.185123,,,,,,,,,,,,,, -131900,3.118101,2.715005,,,,,,,,,,,,,, -132000,2.886872,3.7669399,,,,,,,,,,,,,, -132100,2.9486282,1.8922143,,,,,,,,,,,,,, -132200,2.8533413,2.3668401,,,,,,,,,,,,,, -132300,3.198983,3.730215,,,,,,,,,,,,,, -132400,3.2753239,1.9849,,,,,,,,,,,,,, -132500,3.019055,1.8823123,,,,,,,,,,,,,, -132600,2.7236137,3.4313502,,,,,,,,,,,,,, -132625,,,0.7603515386581421,0.9755155444145204,0.6967200040817261,1.2530672550201416,50000.0,0.5770000219345093,1.8836565017700195,10000.0,60955.45423436165,67948.19107365608,60955.45423436165,6978.662467956543,6.508562088012695,0.0 -132700,3.1589334,2.1515598,,,,,,,,,,,,,, -132800,3.2263806,2.0191674,,,,,,,,,,,,,, -132900,3.2520466,4.214372,,,,,,,,,,,,,, -133000,3.3805141,1.959938,,,,,,,,,,,,,, -133100,3.1926632,1.8197328,,,,,,,,,,,,,, -133200,2.940481,1.822955,,,,,,,,,,,,,, -133300,3.2232842,1.9210696,,,,,,,,,,,,,, -133400,3.0556047,2.0305724,,,,,,,,,,,,,, -133500,3.3841531,1.8024043,,,,,,,,,,,,,, -133540,,,0.7734179496765137,0.9080670475959778,0.6987400054931641,1.2287685871124268,50000.0,0.5764999985694885,1.851581335067749,10000.0,61375.42639231682,68416.36034202576,61375.42639231682,7026.758812665939,6.553529500961304,0.0 -133600,2.9189215,2.3123956,,,,,,,,,,,,,, -133700,3.5554574,1.9598597,,,,,,,,,,,,,, -133800,3.1898365,3.3672988,,,,,,,,,,,,,, -133900,3.193122,2.286677,,,,,,,,,,,,,, -134000,3.4620929,1.8543646,,,,,,,,,,,,,, -134100,3.0990863,1.8832573,,,,,,,,,,,,,, -134200,2.9032693,2.3445601,,,,,,,,,,,,,, -134300,3.0551124,2.0936615,,,,,,,,,,,,,, -134400,3.084262,1.7204113,,,,,,,,,,,,,, -134453,,,0.7599413990974426,0.9587976336479188,0.7029199600219727,1.2122979164123535,50000.0,0.5779000520706177,1.8428810834884644,10000.0,61795.71336269379,68883.31156492233,61795.71336269379,7073.3166263103485,6.606256008148193,0.0 -134500,3.100837,1.7745583,,,,,,,,,,,,,, -134600,3.1533446,1.9916505,,,,,,,,,,,,,, -134700,3.1434755,3.5702605,,,,,,,,,,,,,, -134800,2.9981232,1.8958743,,,,,,,,,,,,,, -134900,3.2363024,2.150739,,,,,,,,,,,,,, -135000,3.105155,2.747707,,,,,,,,,,,,,, -135100,2.9637377,2.193498,,,,,,,,,,,,,, -135200,3.0886793,3.7698054,,,,,,,,,,,,,, -135300,3.193876,1.8527973,,,,,,,,,,,,,, -135369,,,0.7627929449081421,0.9346864819526672,0.7001399993896484,1.2140921354293823,50000.0,0.5750000476837158,1.8443801403045648,10000.0,62216.026047468185,69354.34955620766,62216.026047468185,7123.942526578903,6.653573513031006,0.0 -135400,3.5339797,1.9349189,,,,,,,,,,,,,, -135500,3.1734662,4.2792873,,,,,,,,,,,,,, -135600,3.2347553,1.8312012,,,,,,,,,,,,,, -135700,3.7740378,4.0858607,,,,,,,,,,,,,, -135800,3.2298503,2.277984,,,,,,,,,,,,,, -135900,3.6200452,1.8024187,,,,,,,,,,,,,, -136000,3.217181,2.446996,,,,,,,,,,,,,, -136100,3.1472762,1.9972905,,,,,,,,,,,,,, -136200,3.166231,3.4771407,,,,,,,,,,,,,, -136284,,,0.7731054425239563,0.9033991694450378,0.7005999684333801,1.2164199352264404,50000.0,0.5800999999046326,1.845173954963684,10000.0,62636.19765305519,69823.86749219894,62636.19765305519,7173.187728404999,6.702480316162109,0.0 -136300,3.493942,4.2087746,,,,,,,,,,,,,, -136400,3.4631684,4.321224,,,,,,,,,,,,,, -136500,3.64512,1.8540672,,,,,,,,,,,,,, -136600,3.1455193,2.50424,,,,,,,,,,,,,, -136700,3.0142548,2.144969,,,,,,,,,,,,,, -136800,3.6158316,3.3181624,,,,,,,,,,,,,, -136900,3.320009,1.7302662,,,,,,,,,,,,,, -137000,3.3637605,3.9089398,,,,,,,,,,,,,, -137100,3.6428838,1.7427475,,,,,,,,,,,,,, -137195,,,0.7671093344688416,0.9361597299575806,0.7026599645614624,1.2115135192871094,50000.0,0.5833000540733337,1.8194403648376465,10000.0,63056.44626426697,70293.874396801,63056.44626426697,7222.835695266724,6.760934591293335,0.0 -137200,3.1441033,1.910552,,,,,,,,,,,,,, -137300,3.6767347,4.370565,,,,,,,,,,,,,, -137400,3.6749911,1.8369408,,,,,,,,,,,,,, -137500,3.2590513,3.3323727,,,,,,,,,,,,,, -137600,3.2784276,2.2734616,,,,,,,,,,,,,, -137700,3.2068982,1.7260203,,,,,,,,,,,,,, -137800,3.4136891,1.9707996,,,,,,,,,,,,,, -137900,3.6982281,2.5079699,,,,,,,,,,,,,, -138000,2.9174592,3.052024,,,,,,,,,,,,,, -138100,3.132107,3.552683,,,,,,,,,,,,,, -138106,,,0.7684765458106995,0.9252225756645204,0.7012400031089783,1.2247495651245115,50000.0,0.5826000571250916,1.85190498828888,10000.0,63476.769704818726,70763.18441557884,63476.769704818726,7271.722208023071,6.809126853942871,0.0 -138200,3.5913486,1.9800221,,,,,,,,,,,,,, -138300,3.4011927,2.0550957,,,,,,,,,,,,,, -138400,3.3896275,1.7711775,,,,,,,,,,,,,, -138500,3.2830176,3.5170102,,,,,,,,,,,,,, -138600,3.4116206,3.8936877,,,,,,,,,,,,,, -138700,3.3492973,1.8953872,,,,,,,,,,,,,, -138800,3.4104328,2.563984,,,,,,,,,,,,,, -138900,3.2092457,3.76016,,,,,,,,,,,,,, -139000,3.4049146,2.5885046,,,,,,,,,,,,,, -139019,,,0.7814257740974426,0.8607298135757446,0.7080000042915344,1.1757205724716189,50000.0,0.5879000425338745,1.8092293739318848,10000.0,63896.763201236725,71234.63944602013,63896.763201236725,7323.079668521881,6.859033823013306,0.0 -139100,3.5139158,1.8216379,,,,,,,,,,,,,, -139200,3.440559,1.8983598,,,,,,,,,,,,,, -139300,3.8956437,2.4057798,,,,,,,,,,,,,, -139400,3.3831425,1.8547585,,,,,,,,,,,,,, -139500,3.367914,1.7733238,,,,,,,,,,,,,, -139600,3.2896502,3.255846,,,,,,,,,,,,,, -139700,3.3263412,1.8712078,,,,,,,,,,,,,, -139800,3.7847855,1.7813882,,,,,,,,,,,,,, -139900,3.5063565,2.7486513,,,,,,,,,,,,,, -139933,,,0.7718554735183716,0.9126390814781188,0.7079600095748901,1.185042142868042,50000.0,0.5852000117301941,1.819550514221192,10000.0,64317.06393766403,71702.45918250084,64317.06393766403,7370.497317314148,6.90878701210022,0.0 -140000,3.5460143,1.6842688,,,,,,,,,,,,,, -140100,3.8130603,1.6855847,,,,,,,,,,,,,, -140200,3.6746001,4.16518,,,,,,,,,,,,,, -140300,3.2907307,1.7959471,,,,,,,,,,,,,, -140400,3.1340365,2.8775735,,,,,,,,,,,,,, -140500,4.152291,1.8430302,,,,,,,,,,,,,, -140600,3.4373064,2.6884499,,,,,,,,,,,,,, -140700,3.0652442,2.8834863,,,,,,,,,,,,,, -140800,3.381941,1.7616295,,,,,,,,,,,,,, -140847,,,0.7747851610183716,0.8819167017936707,0.7116400003433228,1.1624095439910889,50000.0,0.5847000479698181,1.7930099964141846,10000.0,64737.02635455132,72173.21574926376,64737.02635455132,7421.191428661346,6.957134008407593,0.0 -140900,3.3061795,2.5468898,,,,,,,,,,,,,, -141000,3.6075153,1.8879986,,,,,,,,,,,,,, -141100,3.330357,2.3405795,,,,,,,,,,,,,, -141200,3.9664364,3.2880404,,,,,,,,,,,,,, -141300,3.462556,1.9875855,,,,,,,,,,,,,, -141400,3.750132,3.707862,,,,,,,,,,,,,, -141500,3.6291018,1.820084,,,,,,,,,,,,,, -141600,3.619926,4.1288037,,,,,,,,,,,,,, -141700,4.0012946,3.890893,,,,,,,,,,,,,, -141762,,,0.77978515625,0.8617091178894043,0.7116599678993225,1.165801763534546,50000.0,0.5872000455856323,1.7895723581314087,10000.0,65157.04562306404,72645.75348472595,65157.04562306404,7473.606050014496,7.00886607170105,0.0 -141800,3.6485617,1.7391882,,,,,,,,,,,,,, -141900,4.087006,1.84234,,,,,,,,,,,,,, -142000,3.4554079,1.8000004,,,,,,,,,,,,,, -142100,3.9271154,3.84788,,,,,,,,,,,,,, -142200,3.8561287,1.6893926,,,,,,,,,,,,,, -142300,3.4139266,1.7273155,,,,,,,,,,,,,, -142400,3.2391446,2.3619823,,,,,,,,,,,,,, -142500,3.753185,1.8405193,,,,,,,,,,,,,, -142600,3.8022141,2.3028371,,,,,,,,,,,,,, -142676,,,0.7821679711341858,0.8531965613365173,0.7168200016021729,1.1503877639770508,50000.0,0.5898000001907349,1.76817786693573,10000.0,65577.20102715492,73119.00194859505,65577.20102715492,7526.591713428497,7.064823865890503,0.0 -142700,3.9971237,2.4185865,,,,,,,,,,,,,, -142800,3.3276148,3.008454,,,,,,,,,,,,,, -142900,3.5771608,1.6892219,,,,,,,,,,,,,, -143000,3.5953398,1.7199987,,,,,,,,,,,,,, -143100,3.871368,1.7182499,,,,,,,,,,,,,, -143200,4.2754,1.8777783,,,,,,,,,,,,,, -143300,3.6557999,2.3643892,,,,,,,,,,,,,, -143400,3.2699144,2.443175,,,,,,,,,,,,,, -143500,3.7971997,1.8158791,,,,,,,,,,,,,, -143588,,,0.7820116877555847,0.8580734729766846,0.7165200114250183,1.1449278593063354,50000.0,0.5919000506401062,1.7643986940383911,10000.0,65996.80779743195,73590.25830888748,65996.80779743195,7577.742676258087,7.511917352676392,0.0 -143600,3.5827334,1.8198186,,,,,,,,,,,,,, -143700,3.863101,1.7486408,,,,,,,,,,,,,, -143800,3.3531868,2.2448993,,,,,,,,,,,,,, -143900,3.4482234,3.4272287,,,,,,,,,,,,,, -144000,3.6757734,3.6746528,,,,,,,,,,,,,, -144100,3.8609333,2.6774192,,,,,,,,,,,,,, -144200,3.9942894,1.8682406,,,,,,,,,,,,,, -144300,3.8589184,1.74922,,,,,,,,,,,,,, -144400,3.4805474,3.7012982,,,,,,,,,,,,,, -144500,3.680228,1.6669598,,,,,,,,,,,,,, -144501,,,0.7851171493530273,0.8476221561431885,0.7163400053977966,1.144951581954956,50000.0,0.5920000076293945,1.7644582986831665,10000.0,66416.94591355324,74057.55110120773,66416.94591355324,7624.79746389389,7.559880018234253,0.0 -144600,3.8217938,3.16039,,,,,,,,,,,,,, -144700,3.9154804,2.85181,,,,,,,,,,,,,, -144800,3.8550606,4.2947235,,,,,,,,,,,,,, -144900,4.2107344,1.7428786,,,,,,,,,,,,,, -145000,4.094705,4.2002425,,,,,,,,,,,,,, -145100,3.642429,1.6300057,,,,,,,,,,,,,, -145200,3.7619443,1.7156668,,,,,,,,,,,,,, -145300,3.9062185,2.0598526,,,,,,,,,,,,,, -145400,4.047144,1.8071265,,,,,,,,,,,,,, -145415,,,0.7938281297683716,0.8047267198562622,0.7177599668502808,1.1373757123947144,50000.0,0.6025000214576721,1.743360996246338,10000.0,66837.3013036251,74521.7456138134,66837.3013036251,7668.535885095596,7.608190298080444,0.0 -145500,3.59444,2.7810235,,,,,,,,,,,,,, -145600,3.8357003,3.3626049,,,,,,,,,,,,,, -145700,5.5344005,4.320508,,,,,,,,,,,,,, -145800,3.8732986,1.768585,,,,,,,,,,,,,, -145900,3.833375,1.9164777,,,,,,,,,,,,,, -146000,3.815456,1.9062723,,,,,,,,,,,,,, -146100,3.785583,2.248338,,,,,,,,,,,,,, -146200,3.528829,2.9451735,,,,,,,,,,,,,, -146300,3.9354267,2.1724603,,,,,,,,,,,,,, -146328,,,0.7868554592132568,0.8384296894073486,0.7208999991416931,1.1160486936569214,50000.0,0.6013000011444092,1.727980136871338,10000.0,67257.48422813416,74991.41473031044,67257.48422813416,7717.922769069672,7.655825853347778,0.0 -146400,3.9056153,3.2653012,,,,,,,,,,,,,, -146500,4.3399568,3.9659078,,,,,,,,,,,,,, -146600,4.0390043,1.6415417,,,,,,,,,,,,,, -146700,3.7853389,1.7421765,,,,,,,,,,,,,, -146800,4.057602,1.7739983,,,,,,,,,,,,,, -146900,3.7573225,1.7243384,,,,,,,,,,,,,, -147000,3.7644343,2.9393525,,,,,,,,,,,,,, -147100,4.1320953,4.183683,,,,,,,,,,,,,, -147200,3.8502567,2.860025,,,,,,,,,,,,,, -147243,,,0.7939453125,0.8069570064544678,0.7222599983215332,1.10954487323761,50000.0,0.6025000214576721,1.7282333374023438,10000.0,67677.5356965065,75460.78970003128,67677.5356965065,7767.141966342926,7.707298278808594,0.0 -147300,3.780783,1.8665249,,,,,,,,,,,,,, -147400,4.0177646,1.9248891,,,,,,,,,,,,,, -147500,3.916482,1.7753438,,,,,,,,,,,,,, -147600,4.1590614,1.8079487,,,,,,,,,,,,,, -147700,4.09666,1.6956574,,,,,,,,,,,,,, -147800,4.122468,1.6922263,,,,,,,,,,,,,, -147900,4.298794,1.9751029,,,,,,,,,,,,,, -148000,4.2639184,4.061504,,,,,,,,,,,,,, -148100,4.2026844,1.6984634,,,,,,,,,,,,,, -148158,,,0.7961523532867432,0.7881932854652405,0.7221599817276001,1.1152422428131104,50000.0,0.6007000207901001,1.723463773727417,10000.0,68097.77912926674,75932.03611707687,68097.77912926674,7818.040345191956,7.759216070175171,0.0 -148200,3.9654827,1.6569769,,,,,,,,,,,,,, -148300,3.8550277,3.3034902,,,,,,,,,,,,,, -148400,3.8860414,1.9271569,,,,,,,,,,,,,, -148500,4.4800134,4.1248255,,,,,,,,,,,,,, -148600,4.13009,4.168578,,,,,,,,,,,,,, -148700,3.8512838,2.5934598,,,,,,,,,,,,,, -148800,4.526534,1.8220041,,,,,,,,,,,,,, -148900,4.0147,2.2286136,,,,,,,,,,,,,, -149000,4.0366216,1.8429791,,,,,,,,,,,,,, -149073,,,0.7910937070846558,0.8174343109130859,0.7259599566459656,1.0968220233917236,50000.0,0.6014000177383423,1.7236626148223877,10000.0,68517.94238138199,76401.87269186974,68517.94238138199,7867.6073389053345,7.813696146011352,0.0 -149100,3.850809,3.010681,,,,,,,,,,,,,, -149200,4.4954467,1.6996562,,,,,,,,,,,,,, -149300,4.060715,1.6482904,,,,,,,,,,,,,, -149400,3.80276,3.3463094,,,,,,,,,,,,,, -149500,4.019151,1.5754795,,,,,,,,,,,,,, -149600,3.9987729,2.8530073,,,,,,,,,,,,,, -149700,4.070709,3.1033106,,,,,,,,,,,,,, -149800,4.194424,3.382923,,,,,,,,,,,,,, -149900,4.54585,3.944255,,,,,,,,,,,,,, -149988,,,0.7939257621765137,0.8128007054328918,0.7231400012969971,1.1107922792434692,50000.0,0.602400004863739,1.7529717683792114,10000.0,68937.99177789688,76871.03368258476,68937.99177789688,7916.614735364914,7.865723133087158,0.0 -150000,4.370496,2.3795981,,,,,,,,,,,,,, -150100,4.4209766,1.7232043,,,,,,,,,,,,,, -150200,4.1190963,1.7962693,,,,,,,,,,,,,, -150300,5.0636826,3.8523927,,,,,,,,,,,,,, -150400,4.3708014,1.7483662,,,,,,,,,,,,,, -150500,4.226767,1.9939339,,,,,,,,,,,,,, -150600,4.1180396,1.7055216,,,,,,,,,,,,,, -150700,4.3412,1.762701,,,,,,,,,,,,,, -150800,4.269175,3.559122,,,,,,,,,,,,,, -150900,,,0.80189448595047,0.7681282758712769,0.7280600070953369,1.0890097618103027,50000.0,0.6033000349998474,1.7106510400772097,10000.0,69358.29975485802,77339.0362174511,69358.29975485802,7964.20180106163,7.918400287628174,0.0 -150900,3.972483,1.797988,,,,,,,,,,,,,, -151000,3.9961646,3.1558294,,,,,,,,,,,,,, -151100,4.4563417,1.7052554,,,,,,,,,,,,,, -151200,4.3641496,1.7596678,,,,,,,,,,,,,, -151300,4.9740033,4.042652,,,,,,,,,,,,,, -151400,4.21129,1.6271993,,,,,,,,,,,,,, -151500,3.8290277,1.8726994,,,,,,,,,,,,,, -151600,4.1142073,1.6147739,,,,,,,,,,,,,, -151700,4.739556,4.026566,,,,,,,,,,,,,, -151800,4.071616,3.0996184,,,,,,,,,,,,,, -151815,,,0.7974413633346558,0.7807225584983826,0.7289199829101562,1.0758874416351318,50000.0,0.614300012588501,1.681167483329773,10000.0,69778.54154014587,77809.62989974022,69778.54154014587,8014.445042133331,7.975107192993164,0.0 -151900,4.718871,3.8980865,,,,,,,,,,,,,, -152000,4.6753473,3.2287507,,,,,,,,,,,,,, -152100,3.8252413,2.7400153,,,,,,,,,,,,,, -152200,4.2415757,1.5972545,,,,,,,,,,,,,, -152300,3.822134,2.8084872,,,,,,,,,,,,,, -152400,4.4199843,1.5801544,,,,,,,,,,,,,, -152500,4.3311243,1.8082271,,,,,,,,,,,,,, -152600,4.324119,1.7456914,,,,,,,,,,,,,, -152700,4.588441,3.522379,,,,,,,,,,,,,, -152729,,,0.8052343726158142,0.7729941606521606,0.7317799925804138,1.0814770460128784,50000.0,0.6094000339508057,1.697064757347107,10000.0,70198.79391741753,78281.05008983612,70198.79391741753,8065.495423555374,8.04086971282959,0.0 -152800,4.1947203,2.5519538,,,,,,,,,,,,,, -152900,4.341351,2.418126,,,,,,,,,,,,,, -153000,4.7563567,1.6787288,,,,,,,,,,,,,, -153100,5.4185524,1.6737146,,,,,,,,,,,,,, -153200,4.6010013,1.802964,,,,,,,,,,,,,, -153300,4.611399,3.666156,,,,,,,,,,,,,, -153400,4.16263,2.6231217,,,,,,,,,,,,,, -153500,4.151379,2.69286,,,,,,,,,,,,,, -153600,4.0434914,2.095418,,,,,,,,,,,,,, -153644,,,0.805468738079071,0.7481812834739685,0.7303999662399292,1.070380926132202,50000.0,0.6111000180244446,1.6869773864746094,10000.0,70618.86042189598,78747.78933882713,70618.86042189598,8112.055419683456,8.100271940231323,0.0 -153700,4.7403736,1.7733337,,,,,,,,,,,,,, -153800,4.5384655,3.4969068,,,,,,,,,,,,,, -153900,4.569149,2.7078474,,,,,,,,,,,,,, -154000,4.0692244,2.3929923,,,,,,,,,,,,,, -154100,4.1451373,2.294909,,,,,,,,,,,,,, -154200,4.644221,1.6057123,,,,,,,,,,,,,, -154300,5.3683214,4.0590415,,,,,,,,,,,,,, -154400,5.4321837,1.6316518,,,,,,,,,,,,,, -154500,4.518785,1.5413197,,,,,,,,,,,,,, -154558,,,0.806640625,0.7597796320915222,0.7336199879646301,1.068161129951477,50000.0,0.6081000566482544,1.683420181274414,10000.0,71038.89610767365,79215.44156551361,71038.89610767365,8159.562078952789,8.1575448513031,0.0 -154600,4.956747,1.76636,,,,,,,,,,,,,, -154700,4.6778765,1.546231,,,,,,,,,,,,,, -154800,4.9349008,1.6835883,,,,,,,,,,,,,, -154900,5.249209,4.0348716,,,,,,,,,,,,,, -155000,4.336775,1.8868482,,,,,,,,,,,,,, -155100,4.6156383,1.528272,,,,,,,,,,,,,, -155200,4.594363,3.1976533,,,,,,,,,,,,,, -155300,4.8500752,1.707347,,,,,,,,,,,,,, -155400,4.8739686,1.624213,,,,,,,,,,,,,, -155470,,,0.8057226538658142,0.7690586447715759,0.7343399524688721,1.075039625167847,50000.0,0.6116000413894653,1.6947035789489746,10000.0,71458.79195690155,79681.98218774796,71458.79195690155,8206.094826936722,8.21754264831543,0.0 -155500,4.5051584,1.5836048,,,,,,,,,,,,,, -155600,4.9854164,1.6128875,,,,,,,,,,,,,, -155700,4.680887,1.6767284,,,,,,,,,,,,,, -155800,4.474355,1.5418469,,,,,,,,,,,,,, -155900,4.851716,2.5337,,,,,,,,,,,,,, -156000,4.719206,3.3699102,,,,,,,,,,,,,, -156100,4.2924495,2.0038404,,,,,,,,,,,,,, -156200,4.6329856,2.3950174,,,,,,,,,,,,,, -156300,4.428442,1.5385964,,,,,,,,,,,,,, -156385,,,0.8109179735183716,0.7232632040977478,0.7350800037384033,1.0466346740722656,50000.0,0.6154000163078308,1.6492758989334106,10000.0,71878.8363673687,80151.62154054642,71878.8363673687,8255.587515115738,8.26708173751831,0.0 -156400,5.0516534,1.6527714,,,,,,,,,,,,,, -156500,4.3403764,1.767025,,,,,,,,,,,,,, -156600,5.1270742,1.533444,,,,,,,,,,,,,, -156700,5.4769626,4.0450573,,,,,,,,,,,,,, -156800,4.7725687,3.0062685,,,,,,,,,,,,,, -156900,4.627067,1.4812138,,,,,,,,,,,,,, -157000,5.080358,1.6115078,,,,,,,,,,,,,, -157100,4.7973824,2.4213543,,,,,,,,,,,,,, -157200,4.4739647,2.740752,,,,,,,,,,,,,, -157300,4.949353,1.612807,,,,,,,,,,,,,, -157302,,,0.8225781321525574,0.6812586188316345,0.739139974117279,1.0413532257080078,50000.0,0.6164000034332275,1.6407500505447388,10000.0,72298.85396432877,80617.73319363594,72298.85396432877,8301.570994377136,8.324889659881592,0.0 -157400,4.9489813,1.5197766,,,,,,,,,,,,,, -157500,4.8071656,1.5103956,,,,,,,,,,,,,, -157600,4.4776483,2.085876,,,,,,,,,,,,,, -157700,4.905572,1.4923044,,,,,,,,,,,,,, -157800,4.8520017,1.5850139,,,,,,,,,,,,,, -157900,5.3669853,3.948101,,,,,,,,,,,,,, -158000,4.7996078,2.463312,,,,,,,,,,,,,, -158100,4.7847123,1.7865248,,,,,,,,,,,,,, -158200,4.638245,2.908613,,,,,,,,,,,,,, -158215,,,0.8095898032188416,0.7571691870689392,0.7379199862480164,1.0595265626907349,50000.0,0.61080002784729,1.6860063076019287,10000.0,72719.15087842941,81086.91956949234,72719.15087842941,8350.355193376541,8.37764310836792,0.0 -158300,4.839529,2.8138294,,,,,,,,,,,,,, -158400,4.695429,1.4709278,,,,,,,,,,,,,, -158500,6.032057,4.045566,,,,,,,,,,,,,, -158600,4.78814,1.6057942,,,,,,,,,,,,,, -158700,5.2786508,1.4937699,,,,,,,,,,,,,, -158800,5.253671,1.475984,,,,,,,,,,,,,, -158900,4.9238343,2.5539715,,,,,,,,,,,,,, -159000,5.1801753,1.5694728,,,,,,,,,,,,,, -159100,4.935128,1.526159,,,,,,,,,,,,,, -159131,,,0.8150585889816284,0.7215979695320129,0.7387599945068359,1.0488839149475098,50000.0,0.6171000003814697,1.6728469133377075,10000.0,73139.33359980583,81553.05968117714,73139.33359980583,8396.208889245987,8.428681373596191,0.0 -159200,4.5852156,2.2103996,,,,,,,,,,,,,, -159300,5.0276814,2.2691793,,,,,,,,,,,,,, -159400,5.1520205,1.5330384,,,,,,,,,,,,,, -159500,4.992797,1.4631875,,,,,,,,,,,,,, -159600,4.7555966,1.4946533,,,,,,,,,,,,,, -159700,5.6088142,1.8479648,,,,,,,,,,,,,, -159800,6.3013244,3.927193,,,,,,,,,,,,,, -159900,4.695036,1.489479,,,,,,,,,,,,,, -160000,5.4033074,3.223848,,,,,,,,,,,,,, -160046,,,0.8247265219688416,0.6886405944824219,0.7408199906349182,1.0412501096725464,50000.0,0.6206000447273254,1.6554361581802368,10000.0,73559.36250901222,82023.44992232323,73559.36250901222,8446.458607912064,8.487337112426758,0.0 -160100,5.784506,3.8391783,,,,,,,,,,,,,, -160200,5.259662,1.5340546,,,,,,,,,,,,,, -160300,5.9385333,1.596472,,,,,,,,,,,,,, -160400,4.7279344,2.1724868,,,,,,,,,,,,,, -160500,4.8397574,1.3917757,,,,,,,,,,,,,, -160600,4.720094,1.5185914,,,,,,,,,,,,,, -160700,5.1693826,1.5057181,,,,,,,,,,,,,, -160800,5.620216,2.170739,,,,,,,,,,,,,, -160900,4.6659446,2.1563952,,,,,,,,,,,,,, -160960,,,0.8200390338897705,0.7016811370849609,0.744439959526062,1.0223753452301023,50000.0,0.6252000331878662,1.632118582725525,10000.0,73979.74830436707,82492.14982962608,73979.74830436707,8494.669929981232,8.537365913391113,0.0 -161000,4.916928,1.8348354,,,,,,,,,,,,,, -161100,5.532024,1.6545986,,,,,,,,,,,,,, -161200,4.8914747,1.5686581,,,,,,,,,,,,,, -161300,5.178446,1.5275539,,,,,,,,,,,,,, -161400,5.6864166,3.7384787,,,,,,,,,,,,,, -161500,4.997368,2.4598768,,,,,,,,,,,,,, -161600,4.596337,2.2283578,,,,,,,,,,,,,, -161700,5.443029,1.4407523,,,,,,,,,,,,,, -161800,6.2392774,3.886462,,,,,,,,,,,,,, -161873,,,0.8234765529632568,0.682098388671875,0.7447400093078613,1.019462823867798,50000.0,0.6238000392913818,1.6342037916183472,10000.0,74399.8327550888,82959.62244343758,74399.8327550888,8541.95247745514,8.591475486755371,0.0 -161900,5.2482677,2.5571907,,,,,,,,,,,,,, -162000,4.8373475,1.8495018,,,,,,,,,,,,,, -162100,4.9674125,1.4475698,,,,,,,,,,,,,, -162200,4.9403224,2.462951,,,,,,,,,,,,,, -162300,5.350606,1.6478788,,,,,,,,,,,,,, -162400,5.0004487,2.6936123,,,,,,,,,,,,,, -162500,5.460961,1.591559,,,,,,,,,,,,,, -162600,5.197374,1.2375592,,,,,,,,,,,,,, -162700,5.57489,3.291677,,,,,,,,,,,,,, -162788,,,0.8300195336341858,0.6636582612991333,0.7465400099754333,1.0095263719558716,50000.0,0.6231000423431396,1.6292407512664795,10000.0,74819.76510477066,83429.02394080162,74819.76510477066,8591.316792964935,8.644585847854614,0.0 -162800,4.7725205,2.0249648,,,,,,,,,,,,,, -162900,4.911681,1.5236174,,,,,,,,,,,,,, -163000,5.084397,1.8118472,,,,,,,,,,,,,, -163100,5.1547155,1.4252663,,,,,,,,,,,,,, -163200,5.388488,1.4762042,,,,,,,,,,,,,, -163300,5.3907185,1.4835547,,,,,,,,,,,,,, -163400,5.4237676,1.4804566,,,,,,,,,,,,,, -163500,5.608237,1.5366938,,,,,,,,,,,,,, -163600,5.237856,1.8200417,,,,,,,,,,,,,, -163700,5.599107,1.6035872,,,,,,,,,,,,,, -163702,,,0.8236327767372131,0.6771386861801147,0.7473399639129639,1.008391499519348,50000.0,0.6289000511169434,1.6085121631622314,10000.0,75239.78667020798,83894.52511429787,75239.78667020798,8636.684086561203,8.705162763595581,0.0 -163800,5.8250213,3.177343,,,,,,,,,,,,,, -163900,6.151038,1.6202784,,,,,,,,,,,,,, -164000,5.834813,1.532432,,,,,,,,,,,,,, -164100,5.6325817,1.5384928,,,,,,,,,,,,,, -164200,5.65261,1.4967425,,,,,,,,,,,,,, -164300,5.6738753,1.4575517,,,,,,,,,,,,,, -164400,5.241318,1.5079306,,,,,,,,,,,,,, -164500,5.229831,1.4477042,,,,,,,,,,,,,, -164600,5.6354804,1.3750899,,,,,,,,,,,,,, -164617,,,0.8244921565055847,0.6821433305740356,0.7472999691963196,1.0059750080108645,50000.0,0.6265000104904175,1.6200848817825315,10000.0,75659.76452445984,84362.22888803482,75659.76452445984,8684.300383806229,8.755192041397095,0.0 -164700,5.7657633,1.6925919,,,,,,,,,,,,,, -164800,6.19119,1.4630302,,,,,,,,,,,,,, -164900,4.918087,2.751024,,,,,,,,,,,,,, -165000,6.1042304,1.4324303,,,,,,,,,,,,,, -165100,5.4369755,1.5255138,,,,,,,,,,,,,, -165200,5.234058,2.5580325,,,,,,,,,,,,,, -165300,7.2508574,3.629583,,,,,,,,,,,,,, -165400,5.9611,1.4640019,,,,,,,,,,,,,, -165500,5.9237823,1.4704676,,,,,,,,,,,,,, -165531,,,0.8278319835662842,0.6549848914146423,0.7478399872779846,0.9953515529632568,50000.0,0.6282000541687012,1.6096539497375488,10000.0,76080.02516317368,84832.28757619858,76080.02516317368,8733.994886159897,8.806623458862305,0.0 -165600,5.4132547,1.4846781,,,,,,,,,,,,,, -165700,5.696411,1.6296363,,,,,,,,,,,,,, -165800,5.549487,1.4035095,,,,,,,,,,,,,, -165900,6.461188,1.5627928,,,,,,,,,,,,,, -166000,6.6323757,1.8383415,,,,,,,,,,,,,, -166100,5.473381,2.0626278,,,,,,,,,,,,,, -166200,5.8724346,1.4017248,,,,,,,,,,,,,, -166300,5.7911377,1.4418818,,,,,,,,,,,,,, -166400,5.5878706,1.4721651,,,,,,,,,,,,,, -166447,,,0.8307226300239563,0.6519699096679688,0.7495200037956238,0.9923651218414308,50000.0,0.6318000555038452,1.6084636449813845,10000.0,76500.17499065399,85297.9245531559,76500.17499065399,8779.374361515045,8.862721681594849,0.0 -166500,5.586865,3.2793047,,,,,,,,,,,,,, -166600,6.006634,1.4552218,,,,,,,,,,,,,, -166700,5.981642,2.154442,,,,,,,,,,,,,, -166800,6.3353286,1.4501317,,,,,,,,,,,,,, -166900,5.5301375,1.4219623,,,,,,,,,,,,,, -167000,5.4218545,2.659203,,,,,,,,,,,,,, -167100,5.9464326,3.1892674,,,,,,,,,,,,,, -167200,5.9405017,1.9284275,,,,,,,,,,,,,, -167300,5.752539,3.3494608,,,,,,,,,,,,,, -167359,,,0.8291015625,0.6523779034614563,0.7506399750709534,0.9864511489868164,50000.0,0.6297000050544739,1.5886270999908447,10000.0,76920.16060471535,85767.60281062126,76920.16060471535,8828.95196390152,8.925573348999023,0.0 -167400,5.5941234,1.3629596,,,,,,,,,,,,,, -167500,5.5272765,1.3642813,,,,,,,,,,,,,, -167600,6.566525,3.792185,,,,,,,,,,,,,, -167700,5.8834662,1.5930979,,,,,,,,,,,,,, -167800,5.663102,2.737343,,,,,,,,,,,,,, -167900,5.7982707,1.353971,,,,,,,,,,,,,, -168000,5.722736,1.484561,,,,,,,,,,,,,, -168100,5.792393,1.4376589,,,,,,,,,,,,,, -168200,5.5024886,1.6083238,,,,,,,,,,,,,, -168272,,,0.8313866853713989,0.637546181678772,0.753600001335144,0.9787457585334778,50000.0,0.6333000063896179,1.587046504020691,10000.0,77340.48335027695,86236.98085284233,77340.48335027695,8877.900985002518,8.980156898498535,0.0 -168300,5.7141185,1.4701816,,,,,,,,,,,,,, -168400,5.223223,1.8468792,,,,,,,,,,,,,, -168500,5.8917236,1.3719984,,,,,,,,,,,,,, -168600,5.5342402,2.4578388,,,,,,,,,,,,,, -168670,,,,,,,,,,,77520.35540795326,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 18773aff5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -29.252927780151367,0.0,35.27015542984009,1,0,35.27015542984009,0.0010000000474974,6.907756805419922,10000,64.52321434020996,0.0009765625,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -77.09238171577454,0.0176796913146972,455.6410744190216,850,0,455.6410744190216,0.026800001040101,6.039024353027344,10000,532.7995707988739,0.03515625,5.885966777801514,0.0339199975132942,5.9119391441345215,50000 -127.39190244674684,0.0458791255950927,875.7305772304535,1756,0,875.7305772304535,0.0533000007271766,5.654690742492676,10000,1003.27015709877,0.0728906244039535,5.357734680175781,0.0655599981546402,5.43528413772583,50000 -178.35270309448242,0.0743927955627441,1295.863152742386,2665,0,1295.863152742386,0.0734999999403953,5.327393054962158,10000,1474.4437997341156,0.1038476526737213,4.981348991394043,0.0971199944615364,5.043646812438965,50000 -228.4377381801605,0.1031725406646728,1715.986170053482,3575,0,1715.986170053482,0.0984000042080879,5.043509006500244,10000,1944.7330491542816,0.1374218761920929,4.640011787414551,0.1282400041818618,4.715636253356934,50000 -278.38977098464966,0.1342794895172119,2136.13942360878,4482,0,2136.13942360878,0.1194000020623207,4.872969627380371,10000,2414.920921087265,0.1684765517711639,4.419726848602295,0.1536400020122528,4.517857551574707,50000 -330.4113621711731,0.1604244709014892,2556.4348685741425,5394,0,2556.4348685741425,0.1340000033378601,4.738316059112549,10000,2887.315844535828,0.1861328035593032,4.220674514770508,0.1708399951457977,4.3279571533203125,50000 -381.2634997367859,0.18742036819458,2976.544565677643,6306,0,2976.544565677643,0.1624000072479248,4.428020477294922,10000,3358.3559985160828,0.2262695282697677,3.890352487564087,0.2067800015211105,3.99831748008728,50000 -431.6334598064423,0.2173418998718261,3396.534752845764,7218,0,3396.534752845764,0.165800005197525,4.394979000091553,10000,3828.797343492508,0.232421875,3.807736158370972,0.2166399955749511,3.93320369720459,50000 -483.87773609161377,0.2426652908325195,3816.9069616794586,8131,0,3816.9069616794586,0.1724000126123428,4.411588668823242,10000,4301.491326808929,0.243203118443489,3.820199728012085,0.2207199931144714,3.940143585205078,50000 -535.8006870746613,0.69814133644104,4236.489646196365,9041,0,4236.489646196365,0.1932000070810318,4.174107074737549,10000,4773.503950834274,0.2707031071186065,3.5831527709960938,0.2477799952030182,3.710452079772949,50000 -586.9349710941315,0.7309134006500244,4656.436341047287,9953,0,4656.436341047287,0.1997000128030777,4.167118549346924,10000,5244.670069217682,0.2805859446525574,3.5102574825286865,0.258459985256195,3.65179705619812,50000 -637.9864263534546,0.7608444690704346,5076.775891304016,10864,0,5076.775891304016,0.2043000161647796,4.103435516357422,10000,5716.142570257187,0.2991601526737213,3.379409074783325,0.2633799910545349,3.607219934463501,50000 -690.076201915741,0.7959158420562744,5497.108896970749,11773,0,5497.108896970749,0.2072000056505203,4.052846908569336,10000,6188.652234077454,0.2922460734844208,3.4271626472473145,0.2735399901866913,3.5501279830932617,50000 -743.02845287323,0.8236494064331055,5917.175496578217,12683,0,5917.175496578217,0.21390001475811,4.058302879333496,10000,6661.750261306763,0.2971484363079071,3.4095776081085205,0.277099996805191,3.530275821685791,50000 -793.3429248332977,0.8524086475372314,6337.318806171417,13592,0,6337.318806171417,0.2124000042676925,4.011412143707275,10000,7132.289152383804,0.3229882717132568,3.2224056720733643,0.28015998005867,3.494879007339477,50000 -844.848123550415,0.8794562816619873,6757.358831167221,14498,0,6757.358831167221,0.2227000147104263,3.9273860454559326,10000,7603.913234949112,0.3158788979053497,3.273736000061035,0.2939800024032593,3.3951432704925537,50000 -896.8590533733368,0.9088699817657472,7177.752962350845,15405,0,7177.752962350845,0.2321000099182129,3.9265310764312735,10000,8076.399621963501,0.3174218535423279,3.3037431240081787,0.293639987707138,3.4267396926879883,50000 -947.5592894554138,0.9369139671325684,7597.680613279343,16315,0,7597.680613279343,0.2264000177383422,3.9574127197265625,10000,8547.108525753021,0.3279101550579071,3.2197184562683105,0.2950599789619446,3.416271448135376,50000 -1000.5289397239684,0.9675893783569336,8017.9552347660065,17227,0,8017.9552347660065,0.2181000113487243,4.019433975219727,10000,9020.436147928238,0.3093163967132568,3.3662712574005127,0.2869600057601928,3.501178741455078,50000 -1051.960121870041,0.9948971271514891,8437.884447813034,18135,0,8437.884447813034,0.2333000153303146,3.880850553512573,10000,9491.875847578049,0.3327734172344208,3.162445545196533,0.3094799816608429,3.302088022232056,50000 -1105.318029642105,1.0265140533447266,8858.166803836823,19043,0,8858.166803836823,0.2296000123023986,3.9306869506835938,10000,9965.598746538162,0.3293164074420929,3.2388904094696045,0.3053599894046783,3.392557382583618,50000 -1160.165560245514,1.059199333190918,9278.551815271378,19952,0,9278.551815271378,0.2477000057697296,3.826837778091431,10000,10440.915224790571,0.3375976383686065,3.134190082550049,0.3170999884605407,3.262382984161377,50000 -1212.4267747402191,1.0930509567260742,9698.90292072296,20863,0,9698.90292072296,0.2464000135660171,3.7816295623779297,10000,10913.613707065582,0.3479296863079071,3.0743730068206787,0.3225999772548675,3.2060763835906982,50000 -1263.3394284248352,1.1243548393249512,10119.1109790802,21769,0,10119.1109790802,0.2497000098228454,3.733030319213867,10000,11384.816810846329,0.3544921875,3.012260913848877,0.3244200050830841,3.1736767292022705,50000 -1314.6286573410034,1.1545770168304443,10539.403557777405,22677,0,10539.403557777405,0.246500015258789,3.79990553855896,10000,11856.48029446602,0.3376367092132568,3.103149652481079,0.3175399899482727,3.2267537117004395,50000 -1366.626068353653,1.188666820526123,10959.684936523438,23585,0,10959.684936523438,0.2532000243663788,3.731189250946045,10000,12328.844103336334,0.3579882681369781,3.0236752033233643,0.3356199860572815,3.154872179031372,50000 -1419.0612137317655,1.2243428230285645,11380.087949991226,24492,0,11380.087949991226,0.2571000158786773,3.7189767360687256,10000,12801.7687625885,0.3534179627895355,2.9814937114715576,0.3287400007247925,3.1595587730407715,50000 -1470.207144498825,1.257941961288452,11800.085270881653,25398,0,11800.085270881653,0.2442000061273574,3.829700231552124,10000,13272.997171640396,0.3600195348262787,3.08575177192688,0.3273800015449524,3.256126880645752,50000 -1522.5056171417236,1.2966363430023191,12220.36465883255,26305,0,12220.36465883255,0.2581000030040741,3.705676317214966,10000,13745.666129112244,0.3636718690395355,2.97861909866333,0.3398799896240234,3.1036124229431152,50000 -1573.5893032550812,1.327235221862793,12640.606446743011,27209,0,12640.606446743011,0.2712000012397766,3.6168735027313232,10000,14217.07286643982,0.3745898306369781,2.8777029514312744,0.3472599983215332,3.050992727279663,50000 -1625.0876586437223,1.3568122386932373,13060.701996088028,28110,0,13060.701996088028,0.2741000056266784,3.6207833290100098,10000,14688.747186660768,0.4131249785423279,2.7041707038879395,0.3500999808311462,3.0397725105285645,50000 -1677.4866523742676,1.3948171138763428,13480.76611328125,29014,0,13480.76611328125,0.2489000111818313,3.72862720489502,10000,15161.299011945724,0.3537499904632568,3.0552871227264404,0.3319399952888489,3.185477495193481,50000 -1730.9614639282229,1.4306964874267578,13901.006760120392,29917,0,13901.006760120392,0.256600022315979,3.744589328765869,10000,15635.10154223442,0.3595312535762787,3.01259183883667,0.3307799994945526,3.17191219329834,50000 -1784.915414810181,1.4623537063598633,14321.184031248093,30820,0,14321.184031248093,0.2592000067234039,3.72937273979187,10000,16109.316176652908,0.3834374845027923,2.864189624786377,0.3374399840831756,3.1222946643829346,50000 -1840.1619803905487,1.498626470565796,14741.241425275804,31728,0,14741.241425275804,0.2592000067234039,3.729475259780884,10000,16584.70828318596,0.3595312535762787,3.016347885131836,0.3363399803638458,3.149092197418213,50000 -1891.49284529686,1.528019905090332,15161.485805511476,32632,0,15161.485805511476,0.2743000090122223,3.608689308166504,10000,17056.36390018463,0.3854101598262787,2.822803497314453,0.3571199774742126,2.989459991455078,50000 -1943.4836995601647,1.5617244243621826,15581.891350269318,33538,0,15581.891350269318,0.2797000110149383,3.5605785846710205,10000,17528.8449883461,0.3950781226158142,2.7515311241149902,0.360179990530014,2.970732450485229,50000 -1994.9713623523712,1.596980094909668,16001.988507032394,34448,0,16001.988507032394,0.268200010061264,3.622101306915283,10000,18000.516926765442,0.3739648461341858,2.8991634845733643,0.3464599847793579,3.0483076572418213,50000 -2045.7526659965515,1.6344339847564695,16422.013520240784,35352,0,16422.013520240784,0.2821000218391418,3.5719974040985107,10000,18471.41180539131,0.3836914002895355,2.843531370162964,0.3626999855041504,2.977720260620117,50000 -2097.913378477097,1.6660008430480957,16842.142607688904,36258,0,16842.142607688904,0.2752000093460083,3.615450620651245,10000,18943.7842271328,0.3859765529632568,2.824381351470948,0.3565199971199035,3.010390520095825,50000 -2148.6515715122223,1.697392463684082,17262.063611745834,37164,0,17262.063611745834,0.2762000262737274,3.5880584716796875,10000,19414.52570724488,0.3831445276737213,2.815798282623291,0.3597999811172485,2.957936286926269,50000 -2202.1193537712097,1.732421636581421,17682.347589969635,38070,0,17682.347589969635,0.2651000022888183,3.700192451477051,10000,19888.36406493187,0.3713281154632568,2.9289300441741943,0.3464599847793579,3.083930969238281,50000 -2254.9687502384186,1.7733440399169922,18102.575630664825,38976,0,18102.575630664825,0.2685000002384186,3.625632286071777,10000,20361.533406972885,0.3922656178474426,2.8164315223693848,0.3563799858093261,3.0107686519622803,50000 -2305.820141553879,1.8208367824554443,18522.856332540512,39884,0,18522.856332540512,0.2786000072956085,3.6124985218048096,10000,20832.76473426819,0.3822265565395355,2.884771585464477,0.3570599853992462,3.043698787689209,50000 -2355.927830219269,1.8549113273620603,18942.836881637573,40792,0,18942.836881637573,0.2778000235557556,3.527748107910156,10000,21302.93864560128,0.3929687440395355,2.77851676940918,0.369079977273941,2.914356231689453,50000 -2406.526533842087,1.887979984283448,19362.784957170486,41694,0,19362.784957170486,0.268200010061264,3.675233840942383,10000,21773.56910061836,0.3770507872104645,2.974250078201294,0.3516799807548523,3.125171184539795,50000 -2456.8914000988007,1.921765804290772,19783.164499998093,42601,0,19783.164499998093,0.2657000124454498,3.694946765899658,10000,22244.399122715,0.3780468702316284,2.9276158809661865,0.3479799926280975,3.1044631004333496,50000 -2508.969986438751,1.9536054134368896,20203.380412817,43515,0,20203.380412817,0.2962000072002411,3.4557039737701416,10000,22716.77877688408,0.4089257717132568,2.723781108856201,0.3783199787139892,2.8741984367370605,50000 -2558.9942378997803,1.9841821193695068,20623.665967941284,44422,0,20623.665967941284,0.2747000157833099,3.624547004699707,10000,23187.171246290207,0.3873632848262787,2.8844244480133057,0.3587599992752075,3.038835287094116,50000 -2609.552106142044,2.016412734985352,21043.75366783142,45332,0,21043.75366783142,0.2851999998092651,3.533428192138672,10000,23657.901322603226,0.4272656142711639,2.607675313949585,0.367059975862503,2.9391558170318604,50000 -2663.1634533405304,2.049370527267456,21463.72553873062,46241,0,21463.72553873062,0.2811000049114227,3.5713536739349365,10000,24131.569739818573,0.3944140672683716,2.8179240226745605,0.3656199872493744,2.9761483669281006,50000 -2715.443905115128,2.0853078365325928,21884.07934308052,47149,0,21884.07934308052,0.2869000136852264,3.5421478748321533,10000,24604.291342496872,0.4058789014816284,2.753173828125,0.374239981174469,2.925992965698242,50000 -2767.4498670101166,2.1263139247894287,22304.013244628903,48055,0,22304.013244628903,0.2960000038146972,3.4566426277160645,10000,25076.32378435135,0.4291796684265136,2.6209170818328857,0.3867599964141845,2.860398054122925,50000 -2817.928693294525,2.163120985031128,22724.27734947205,48963,0,22724.27734947205,0.3047000169754028,3.414707899093628,10000,25547.15549659729,0.4155077934265136,2.642244815826416,0.3876399993896484,2.803826093673706,50000 -2869.654707431793,2.1987524032592773,23144.547844409943,49872,0,23144.547844409943,0.3007000088691711,3.3947463035583496,10000,26019.239842414856,0.4216601550579071,2.622653722763061,0.3941999971866607,2.784731388092041,50000 -2922.45498919487,2.235773801803589,23564.829872846603,50782,0,23564.829872846603,0.3039000034332275,3.3923990726470947,10000,26492.41095542908,0.431640625,2.567415475845337,0.3929999768733978,2.786828756332397,50000 -2975.215850353241,2.2752928733825684,23985.074984312057,51689,0,23985.074984312057,0.3032000064849853,3.374166250228882,10000,26965.50780677796,0.4244335889816284,2.612565279006958,0.3941600024700165,2.762552499771118,50000 -3025.862987279892,2.319564819335937,24405.261734247208,52597,0,24405.261734247208,0.3110000193119049,3.386230230331421,10000,27436.437440156937,0.4292578101158142,2.5953733921051025,0.3981399834156036,2.7684292793273926,50000 -3078.066597223282,2.363283157348633,24825.304622650143,53503,0,24825.304622650143,0.2998000085353851,3.433039903640747,10000,27908.778692007065,0.4274999797344208,2.5975615978240967,0.3849999904632568,2.821701765060425,50000 -3129.888496398926,2.3979485034942627,25245.264453172684,54413,0,25245.264453172684,0.3037000000476837,3.377845764160156,10000,28380.64738035202,0.4327929615974426,2.579705238342285,0.3990999758243561,2.741082191467285,50000 -3181.6526210308075,2.4328207969665527,25665.59147167205,55323,0,25665.59147167205,0.3120000064373016,3.340991973876953,10000,28852.82416653633,0.4314843714237213,2.5818235874176025,0.4027799963951111,2.7566277980804443,50000 -3233.405528306961,2.467010736465454,26085.61327648163,56233,0,26085.61327648163,0.3122000098228454,3.344634532928467,10000,29324.68583965301,0.4399023354053497,2.5200986862182617,0.407260000705719,2.699357986450196,50000 -3284.786096572876,2.5048763751983643,26505.705159902573,57145,0,26505.705159902573,0.3115000128746032,3.323418140411377,10000,29796.24742078781,0.4365624785423279,2.546628475189209,0.4060999751091003,2.6989681720733643,50000 -3335.712982416153,2.543241739273072,26925.716034412384,58056,0,26925.716034412384,0.3158000111579895,3.3459408283233643,10000,30267.275554418564,0.4286132752895355,2.626063108444214,0.4068599939346313,2.7548227310180664,50000 -3388.341548681259,2.5799450874328613,27345.927248716354,58965,0,27345.927248716354,0.3166000247001648,3.356484651565552,10000,30740.20338773728,0.4401562511920929,2.5463802814483643,0.4059799909591675,2.7379801273345947,50000 -3442.583624601364,2.620101928710937,27766.268231868744,59875,0,27766.268231868744,0.3135000169277191,3.311406373977661,10000,31214.8779771328,0.4756445288658142,2.349417924880981,0.4097599983215332,2.690446376800537,50000 -3494.8005118370056,2.661822557449341,28186.384392499924,60784,0,28186.384392499924,0.3251000046730041,3.2931430339813232,10000,31687.30434346199,0.440253883600235,2.530428647994995,0.4153399765491485,2.6711795330047607,50000 -3547.8963441848755,2.700721263885498,28606.57464933396,61693,0,28606.57464933396,0.3247000277042389,3.2571167945861816,10000,32160.68154025078,0.4561132788658142,2.4191908836364746,0.4205799996852875,2.6124274730682373,50000 -3601.364328861237,2.742002010345459,29026.54181861877,62604,0,29026.54181861877,0.3195000290870666,3.280048131942749,10000,32634.209529399872,0.4695117175579071,2.3920223712921143,0.4214800000190735,2.663222074508667,50000 -3653.226641178131,2.778369903564453,29446.83342576027,63516,0,29446.83342576027,0.3118000030517578,3.331474781036377,10000,33106.452523231506,0.4356835782527923,2.5817770957946777,0.4103799760341644,2.7430014610290527,50000 -3707.521521806717,2.814645767211914,29867.208926677704,64423,0,29867.208926677704,0.3339000046253204,3.209775447845459,10000,33581.21010494232,0.4575585722923279,2.4170541763305664,0.4262599945068359,2.5896904468536377,50000 -3760.2291600704193,2.8578405380249023,30287.316556453705,65330,0,30287.316556453705,0.3272000253200531,3.247157573699951,10000,34054.11939907074,0.4688476324081421,2.36328673362732,0.4254999756813049,2.597842216491699,50000 -3812.165778875351,2.89416766166687,30707.30086231232,66240,0,30707.30086231232,0.3324000239372253,3.1866111755371094,10000,34526.12924003601,0.4552929699420929,2.4185917377471924,0.4275999963283539,2.574275016784668,50000 -3866.608591794968,2.933951139450073,31127.646122694016,67149,0,31127.646122694016,0.332800030708313,3.1850574016571045,10000,35001.00862288475,0.4733007848262787,2.328538417816162,0.4402399957180023,2.4984281063079834,50000 -3919.500876426697,2.9747848510742188,31547.65611815453,68058,0,31547.65611815453,0.333400011062622,3.2271764278411865,10000,35474.00371336937,0.4591992199420929,2.4289512634277344,0.4236399829387665,2.623819351196289,50000 -3971.823595046997,3.0140380859375,31967.69016432762,68967,0,31967.69016432762,0.3456000089645386,3.1654460430145264,10000,35946.450082063675,0.4648632705211639,2.3892576694488525,0.4384599924087524,2.5289251804351807,50000 -4026.2420842647552,3.0550498962402344,32388.05753827095,69875,0,32388.05753827095,0.3344000279903412,3.193197011947632,10000,36421.328429460526,0.4686718583106994,2.3740696907043457,0.4352200031280517,2.5472748279571533,50000 -4078.6816279888153,3.0984416007995605,32808.10027337074,70782,0,32808.10027337074,0.3368000090122223,3.184787034988404,10000,36893.90516161919,0.4740038812160492,2.324805974960327,0.4377000033855438,2.5263214111328125,50000 -4131.9138832092285,3.138212442398072,33228.35967874527,71692,0,33228.35967874527,0.3339000046253204,3.219151020050049,10000,37367.487449646,0.4591406285762787,2.428839921951294,0.4280000030994415,2.5926687717437744,50000 -4186.46707201004,3.180236339569092,33648.43023991585,72601,0,33648.43023991585,0.3451000154018402,3.1486642360687256,10000,37842.20454144478,0.4733984172344208,2.335913896560669,0.442220002412796,2.501166582107544,50000 -4240.780715942383,3.2186429500579834,34068.49537944794,73512,0,34068.49537944794,0.34170001745224,3.1700222492218018,10000,38316.67360329628,0.4789257645606994,2.335693359375,0.4441199898719787,2.5147647857666016,50000 -4291.616386175156,3.258612632751465,34488.55130815506,74422,0,34488.55130815506,0.3525000214576721,3.0565521717071533,10000,38787.65641450882,0.514941394329071,2.1037821769714355,0.456279993057251,2.4261739253997803,50000 -4342.585709571838,3.3010001182556152,34908.481731414795,75331,0,34908.481731414795,0.3541000187397003,3.067671537399292,10000,39258.649804115295,0.4822070300579071,2.265528440475464,0.4532800018787384,2.43257212638855,50000 -4392.15483045578,3.3405954837799072,35328.56245660782,76238,0,35328.56245660782,0.3431000113487243,3.140433549880981,10000,39728.39041757584,0.4829687476158142,2.3049960136413574,0.4452399909496307,2.4958457946777344,50000 -4444.11982011795,3.381718873977661,35748.48908209801,77145,0,35748.48908209801,0.3474000096321106,3.1019248962402344,10000,40200.3743493557,0.5133984088897705,2.1134865283966064,0.4520799815654754,2.433824062347412,50000 -4495.4370748996735,3.4243533611297607,36168.59254050255,78051,0,36168.59254050255,0.3499000072479248,3.089308261871338,10000,40671.88909912109,0.4828515648841858,2.2945902347564697,0.4523399770259857,2.466150522232056,50000 -4546.794453859329,3.463703393936157,36588.80544137955,78960,0,36588.80544137955,0.3455000221729278,3.0912413597106934,10000,41143.55085515976,0.4942382574081421,2.219955921173096,0.4580599963665008,2.418983459472656,50000 -4600.427646636963,3.510913848876953,37008.87369513512,79864,0,37008.87369513512,0.3482000231742859,3.0952043533325195,10000,41617.35136270523,0.5013476610183716,2.2078988552093506,0.4576599895954132,2.443406581878662,50000 -4653.40140414238,3.551713705062866,37429.11387252808,80773,0,37429.11387252808,0.3550000190734863,3.052574396133423,10000,42090.65738582611,0.4915429651737213,2.2458813190460205,0.4599799811840057,2.409734010696411,50000 -4705.331708192825,3.597206354141236,37849.28627562523,81681,0,37849.28627562523,0.3586000204086303,3.044772148132324,10000,42562.85619473457,0.4970312416553497,2.2210443019866943,0.4611999988555908,2.407777547836304,50000 -4756.069163322449,3.641781806945801,38269.44111919403,82586,0,38269.44111919403,0.3614000082015991,3.003753185272217,10000,43033.84465241432,0.5177343487739563,2.106281042098999,0.4700599908828735,2.348548889160156,50000 -4810.5310571193695,3.686613321304321,38689.66334319115,83497,0,38689.66334319115,0.3677000105381012,2.9964702129364014,10000,43508.62535381317,0.4992382824420929,2.189939260482788,0.4672999978065491,2.3583219051361084,50000 -4862.211612701416,3.726073265075684,39109.79868769646,84399,0,39109.79868769646,0.3638000190258026,3.0029189586639404,10000,43980.53152704239,0.5051953196525574,2.1985788345336914,0.4702000021934509,2.3659684658050537,50000 -4915.274274110794,3.7699716091156006,39529.994701862335,85308,0,39529.994701862335,0.35630002617836,3.0596253871917725,10000,44453.88608646393,0.505664050579071,2.189625263214112,0.4622199833393097,2.41805386543274,50000 -4965.784074783325,3.8100292682647705,39950.410684108734,86215,0,39950.410684108734,0.3636000156402588,3.021517038345337,10000,44924.90372800827,0.4994335770606994,2.226904392242432,0.4710399806499481,2.3905842304229736,50000 -5017.331121683121,3.862768650054932,40370.71117019653,87118,0,40370.71117019653,0.3753000199794769,2.977247714996338,10000,45396.8552069664,0.5128515362739563,2.143615484237671,0.4809999763965606,2.3133881092071533,50000 -5068.761607885361,3.913156747817993,40790.89816689491,88021,0,40790.89816689491,0.3805000185966491,2.938375234603882,10000,45868.57350111008,0.5203710794448853,2.086500883102417,0.4777399897575378,2.3107686042785645,50000 -5119.981722831726,3.954137563705444,41210.95104074478,88927,0,41210.95104074478,0.3777000308036804,2.9439592361450195,10000,46339.93909049034,0.5184179544448853,2.0980348587036133,0.4792400002479553,2.289841413497925,50000 -5171.822760105133,3.997264623641968,41631.14861416817,89836,0,41631.14861416817,0.3736000061035156,2.944952964782715,10000,46812.072179317474,0.5190234184265137,2.101839303970337,0.4875999987125397,2.2763702869415283,50000 -5225.636157512665,4.036552667617798,42051.27122282982,90743,0,42051.27122282982,0.3727000057697296,2.997620820999145,10000,47286.09899115562,0.5146093368530273,2.1558313369750977,0.475519984960556,2.3405981063842773,50000 -5277.445053577423,4.078348875045776,42471.242216825485,91651,0,42471.242216825485,0.3713000118732452,2.993382453918457,10000,47757.972594976425,0.5436718463897705,2.0345213413238525,0.4791199862957001,2.349999189376831,50000 -5330.52893781662,4.1219470500946045,42891.40844297409,92556,0,42891.40844297409,0.3812000155448913,2.9217004776000977,10000,48231.31739473343,0.5254101157188416,2.0577504634857178,0.4925599992275238,2.2484467029571533,50000 -5383.027472257614,4.1651999950408936,43311.48286437988,93464,0,43311.48286437988,0.3795000314712524,2.9080698490142822,10000,48703.985027074814,0.5331054329872131,2.015186786651612,0.4948799908161163,2.220568418502808,50000 -5435.459993362427,4.21837306022644,43731.60246658325,94368,0,43731.60246658325,0.3873000144958496,2.886227607727051,10000,49176.6410984993,0.5615038871765137,1.8956527709960933,0.5003199577331543,2.2095954418182373,50000 -5486.836786031723,4.264795303344727,44151.640016794205,95273,0,44151.640016794205,0.3952000141143799,2.8237500190734863,10000,49648.15318131447,0.5392187237739563,1.992055058479309,0.5055800080299377,2.1670339107513428,50000 -5540.123233795166,4.307686805725098,44571.82088375092,96179,0,44571.82088375092,0.3921000063419342,2.8492650985717773,10000,50121.71514558792,0.5384374856948853,2.011241912841797,0.5003399848937988,2.213174343109131,50000 -5589.800358533859,4.35940146446228,44992.06829333305,97086,0,44992.06829333305,0.398000031709671,2.8221137523651123,10000,50591.74269366264,0.5535546541213989,1.9242347478866573,0.5067799687385559,2.171412706375122,50000 -5644.221765995026,4.40524959564209,45412.143065452576,97993,0,45412.143065452576,0.3891000151634216,2.8874340057373047,10000,51066.33588290215,0.532910168170929,2.055387020111084,0.4993399977684021,2.2423460483551025,50000 -5695.656700849533,4.450121402740479,45832.18708944321,98898,0,45832.18708944321,0.3814000189304352,2.900303363800049,10000,51537.9104912281,0.5400781035423279,2.028705358505249,0.5006399750709534,2.236777067184448,50000 -5748.635198116303,4.504448413848877,46252.08383107185,99804,0,46252.08383107185,0.3951000273227691,2.8197717666625977,10000,52010.89158630371,0.5590429306030273,1.925033688545227,0.5095599889755249,2.1697280406951904,50000 -5800.427268981934,4.545479774475098,46672.03926706314,100711,0,46672.03926706314,0.3897000253200531,2.8789451122283936,10000,52482.73132753372,0.5338476300239563,2.064886569976806,0.4959999918937683,2.255467176437378,50000 -5851.827321767807,4.592525005340576,47091.95157575607,101614,0,47091.95157575607,0.4047000110149383,2.784193754196167,10000,52954.14155244827,0.5544726252555847,1.9149682521820068,0.5176399946212769,2.104093313217163,50000 -5902.980116844177,4.639246225357056,47512.10585308075,102519,0,47512.10585308075,0.4044000208377838,2.776407241821289,10000,53425.54627323151,0.558300793170929,1.900124430656433,0.5158199667930603,2.11898159980774,50000 -5955.425267219544,4.68032431602478,47932.22675728798,103426,0,47932.22675728798,0.4073000252246856,2.7569947242736816,10000,53898.2055375576,0.5542968511581421,1.920783758163452,0.5221399664878845,2.0911567211151123,50000 -6007.061721801758,4.724331378936768,48352.40204453468,104333,0,48352.40204453468,0.4038000106811523,2.7949371337890625,10000,54370.11286449432,0.5587499737739563,1.940987467765808,0.5200799703598022,2.1409268379211426,50000 -6060.579028129578,4.768460273742676,48772.720823049545,105241,0,48772.720823049545,0.403300017118454,2.796470880508423,10000,54844.04375267029,0.560839831829071,1.9002279043197632,0.514959990978241,2.125225067138672,50000 -6112.325652837753,4.819467544555664,49193.001353263855,106150,0,49193.001353263855,0.4099000096321106,2.748702049255371,10000,55316.173254966736,0.5675585865974426,1.8988195657730105,0.5279399752616882,2.0975606441497803,50000 -6164.696760416031,4.867316961288452,49612.96082854271,107055,0,49612.96082854271,0.4122000336647033,2.7401907444000244,10000,55788.6028881073,0.5690038800239563,1.8810782432556152,0.5263199806213379,2.085660457611084,50000 -6218.849471569061,4.915220022201538,50033.20030045509,107961,0,50033.20030045509,0.417900025844574,2.732545852661133,10000,56263.09414482117,0.5733398199081421,1.872202754020691,0.5306199789047241,2.0735056400299072,50000 -6271.163687944412,4.974422931671143,50453.37967920303,108871,0,50453.37967920303,0.41880002617836,2.701533317565918,10000,56735.69892835617,0.6080859303474426,1.6695939302444458,0.5309799909591675,2.035268545150757,50000 -6325.108122825623,5.025132894515991,50873.40883421898,109779,0,50873.40883421898,0.4267000257968902,2.641922950744629,10000,57209.77555155754,0.5797070264816284,1.7871202230453491,0.5417999625205994,1.9782638549804688,50000 -6375.277026414871,5.072314500808716,51293.60374307632,110687,0,51293.60374307632,0.419400006532669,2.7238848209381104,10000,57680.238035440445,0.5723242163658142,1.8696619272232056,0.5303800106048584,2.078565120697021,50000 -6427.413257360458,5.120913028717041,51713.767781972885,111597,0,51713.767781972885,0.4234000146389007,2.6781985759735107,10000,58152.63842535019,0.6044335961341858,1.6973642110824585,0.542639970779419,2.007513999938965,50000 -6480.615281820297,5.165015459060669,52134.00865268707,112503,0,52134.00865268707,0.431300014257431,2.65842080116272,10000,58626.1765999794,0.5824804306030273,1.7816941738128662,0.5447999835014343,1.961826205253601,50000 -6532.158484697342,5.209194660186768,52554.14491820336,113409,0,52554.14491820336,0.4369000196456909,2.59963321685791,10000,59097.95101642609,0.5881054401397705,1.7460676431655884,0.5488399863243103,1.9370585680007928,50000 -6583.523062705994,5.259409666061401,52974.33516526222,114316,0,52974.33516526222,0.4376000165939331,2.588284969329834,10000,59569.60813641548,0.6094335913658142,1.6528574228286743,0.5530999898910522,1.9297101497650144,50000 -6635.696802854538,5.305138349533081,53394.59422135353,115222,0,53394.59422135353,0.4300000071525574,2.6224279403686523,10000,60042.13766551018,0.5908203125,1.7480441331863403,0.5513399839401245,1.947059988975525,50000 -6688.041565418243,5.350852966308594,53814.87706851959,116126,0,53814.87706851959,0.4336000084877014,2.6125783920288086,10000,60514.86140489578,0.5944726467132568,1.7341322898864746,0.5532999634742737,1.9365402460098269,50000 -6742.556987047195,5.396175384521484,54235.36640357971,117031,0,54235.36640357971,0.4370000064373016,2.60026216506958,10000,60989.96281218529,0.607128918170929,1.6637282371520996,0.5559200048446655,1.909610152244568,50000 -6794.329073667526,5.449033498764038,54655.59771943092,117937,0,54655.59771943092,0.4412000179290771,2.5633251667022705,10000,61462.07122516632,0.5986914038658142,1.7124024629592896,0.5584200024604797,1.9031184911727903,50000 -6847.463381290436,5.4985644817352295,55075.97532296181,118844,0,55075.97532296181,0.4528000354766845,2.5325958728790283,10000,61935.68432497978,0.6104296445846558,1.6610846519470217,0.5698599815368652,1.871710538864136,50000 -6898.902683258057,5.5450968742370605,55496.04265832901,119751,0,55496.04265832901,0.4409000277519226,2.555757999420166,10000,62407.2891471386,0.6143554449081421,1.6402643918991089,0.5634399652481079,1.8867182731628416,50000 -6949.624935626984,5.5959556102752686,55916.21435809136,120659,0,55916.21435809136,0.4469000101089477,2.5595107078552246,10000,62878.28482103348,0.6095117330551147,1.7067832946777344,0.5638799667358398,1.9024053812026973,50000 -7001.017154455185,5.640802383422852,56336.36602210999,121566,0,56336.36602210999,0.4475000202655792,2.534411907196045,10000,63349.9245557785,0.6133007407188416,1.6413342952728271,0.5678399801254272,1.860566258430481,50000 -7051.67671251297,5.691583156585693,56756.33172440529,122472,0,56756.33172440529,0.4550000131130218,2.503061532974243,10000,63820.65265607834,0.62255859375,1.6071964502334597,0.5728999972343445,1.8406982421875,50000 -7104.192118406296,5.741584062576294,57176.62945842743,123377,0,57176.62945842743,0.4579000174999237,2.4760637283325195,10000,64293.56670308113,0.6289257407188416,1.570678949356079,0.579479992389679,1.805068850517273,50000 -7156.915832519531,5.790433645248413,57596.96713638306,124287,0,57596.96713638306,0.4643000364303589,2.429023504257202,10000,64766.72912335396,0.6289257407188416,1.5635874271392822,0.5827199816703796,1.768791675567627,50000 -7208.776687383652,5.846323251724243,58017.08203911781,125193,0,58017.08203911781,0.4616000354290008,2.4752357006073,10000,65238.811891555786,0.6286327838897705,1.5781629085540771,0.5806199908256531,1.8158116340637207,50000 -7259.260968923569,5.892705202102661,58437.05323433876,126099,0,58437.05323433876,0.4726000130176544,2.447232484817505,10000,65709.36529541016,0.6649023294448853,1.4432421922683716,0.5847399830818176,1.7886931896209717,50000 -7309.869294166565,5.941359758377075,58857.087052583694,127005,0,58857.087052583694,0.4724000096321106,2.40802001953125,10000,66180.10751318932,0.6364062428474426,1.526646852493286,0.5922799706459045,1.7352849245071411,50000 -7363.129547357559,5.994652986526489,59277.23089051247,127910,0,59277.23089051247,0.4738000333309173,2.39067816734314,10000,66653.61686849594,0.6415820121765137,1.496817708015442,0.5906599760055542,1.746111512184143,50000 -7415.558722019196,6.044222116470337,59697.32144474983,128813,0,59697.32144474983,0.4759000241756439,2.3901915550231934,10000,67126.2372546196,0.6640819907188416,1.4066020250320437,0.5945999622344971,1.7294906377792358,50000 -7466.64894080162,6.103210926055908,60117.41744160652,129719,0,60117.41744160652,0.4695000350475311,2.420825242996216,10000,67597.53349399567,0.6342382431030273,1.546700358390808,0.5904600024223328,1.7574831247329712,50000 -7517.243643760681,6.158584356307983,60537.58105373383,130626,0,60537.58105373383,0.4785000085830688,2.387399196624756,10000,68068.39860129356,0.6511132717132568,1.4745047092437744,0.5997399687767029,1.7100039720535278,50000 -7571.182153224945,6.214536190032959,60957.56353855133,131535,0,60957.56353855133,0.4800000190734863,2.348107099533081,10000,68542.42744445801,0.6660742163658142,1.3897231817245483,0.6065599918365479,1.6731947660446167,50000 -7621.960869312286,6.263074159622192,61377.6965944767,132442,0,61377.6965944767,0.48130002617836,2.3436646461486816,10000,69013.43845295906,0.6496874690055847,1.4811943769454956,0.6088399887084961,1.6800014972686768,50000 -7675.555843830109,6.311018705368042,61797.79363250733,133349,0,61797.79363250733,0.4870000183582306,2.3509786128997803,10000,69487.23033475876,0.6616796851158142,1.4510056972503662,0.6108399629592896,1.6753292083740234,50000 -7729.499813079834,6.358997106552124,62217.71622133255,134257,0,62217.71622133255,0.4892000257968902,2.3225879669189453,10000,69961.19663286209,0.6688085794448853,1.3909156322479248,0.6072199940681458,1.6612194776535034,50000 -7783.615670204163,6.407909154891968,62637.634321689606,135162,0,62637.634321689606,0.4907000362873077,2.3041481971740723,10000,70435.33058691025,0.6540820002555847,1.4487427473068235,0.6113799810409546,1.649997591972351,50000 -7834.759024858475,6.456115007400513,63057.867579460144,136068,0,63057.867579460144,0.4901000261306762,2.3076579570770264,10000,70906.8068845272,0.6649999618530273,1.3921610116958618,0.6185199618339539,1.6188608407974243,50000 -7885.648756742477,6.504778623580933,63478.281381607056,136976,0,63478.281381607056,0.496800035238266,2.2707793712615967,10000,71378.2106127739,0.6775000095367432,1.342071533203125,0.6184399724006653,1.606795072555542,50000 -7936.875230550766,6.928737640380859,63898.23468732834,137884,0,63898.23468732834,0.4948000311851501,2.288226366043091,10000,71849.86646604538,0.6682812571525574,1.403730034828186,0.6224200129508972,1.6151797771453855,50000 -7987.347238540649,6.983047246932983,64318.29056835175,138791,0,64318.29056835175,0.4999000132083893,2.231370687484741,10000,72320.50111722946,0.6753710508346558,1.351009488105774,0.6233800053596497,1.5772992372512815,50000 -8039.49192738533,7.045411348342896,64738.50710082054,139699,0,64738.50710082054,0.5088000297546387,2.2103207111358643,10000,72792.97579598427,0.6874608993530273,1.287361741065979,0.6303799748420715,1.5501196384429932,50000 -8092.068526983261,7.1019439697265625,65158.44564986229,140606,0,65158.44564986229,0.5029000043869019,2.2430734634399414,10000,73265.60011100769,0.684277355670929,1.3224250078201294,0.6326999664306641,1.5607450008392334,50000 -8143.69774889946,7.15775990486145,65578.51907277107,141514,0,65578.51907277107,0.5126000046730042,2.173495054244995,10000,73737.41035723686,0.6889257431030273,1.2937071323394775,0.6418200135231018,1.511349320411682,50000 -8195.071061849594,7.205103397369385,65998.46021318436,142421,0,65998.46021318436,0.513700008392334,2.185518980026245,10000,74208.82389330864,0.6948828101158142,1.2586766481399536,0.6380000114440918,1.5189902782440186,50000 -8247.3456325531,7.25283932685852,66418.74589681625,143326,0,66418.74589681625,0.5218000411987305,2.157728672027588,10000,74681.48336172104,0.7229296565055847,1.147066593170166,0.6485999822616577,1.4839391708374023,50000 -8300.968694925308,7.303514003753662,66839.10083460808,144235,0,66839.10083460808,0.5186000466346741,2.162790298461914,10000,75155.56358337402,0.6969531178474426,1.259919285774231,0.6465399861335754,1.4953464269638062,50000 -8352.526899814606,7.353893756866455,67259.42416167259,145143,0,67259.42416167259,0.5236000418663025,2.1719541549682617,10000,75627.54675364494,0.6981835961341858,1.2710505723953247,0.6437000036239624,1.523866057395935,50000 -8404.179302930832,7.413357973098755,67679.52076888084,146049,0,67679.52076888084,0.523300051689148,2.157883644104004,10000,76099.40624260902,0.7145116925239563,1.1924309730529783,0.6484400033950806,1.4934139251708984,50000 -8457.554328680038,7.466028690338135,68099.78795552254,146958,0,68099.78795552254,0.530500054359436,2.0883536338806152,10000,76573.15300369263,0.7075585722923279,1.199978590011597,0.655239999294281,1.4392666816711426,50000 -8508.766267299652,7.52376914024353,68519.88999271393,147867,0,68519.88999271393,0.5245000123977661,2.122317314147949,10000,77044.57645773888,0.7089257836341858,1.191853165626526,0.6549599766731262,1.450493097305298,50000 -8561.076631069183,7.576759338378906,68940.13805937767,148775,0,68940.13805937767,0.5313000082969666,2.1140568256378174,10000,77517.2398583889,0.72328120470047,1.1626847982406616,0.656279981136322,1.4608744382858276,50000 -8612.59052824974,7.63094162940979,69360.1499080658,149682,0,69360.1499080658,0.5441000461578369,2.061110496520996,10000,77988.87147164345,0.71742182970047,1.170621037483215,0.6571800112724304,1.4194693565368652,50000 -8664.59189748764,7.682143688201904,69780.04849529266,150589,0,69780.04849529266,0.5406000018119812,2.0501999855041504,10000,78460.87410640717,0.7193945050239563,1.145675778388977,0.6657599806785583,1.3892215490341189,50000 -8717.290427207947,7.740187168121338,70200.33002853394,151496,0,70200.33002853394,0.5424000024795532,2.038510799407959,10000,78933.96350884438,0.7369726300239563,1.0826455354690552,0.6676999926567078,1.3798246383666992,50000 -8770.050068855286,7.791531801223755,70620.6611096859,152404,0,70620.6611096859,0.5423000454902649,2.045349597930908,10000,79407.156021595,0.7234960794448853,1.141112208366394,0.6663999557495117,1.381853461265564,50000 -8823.263661623001,7.842597007751465,71040.6793308258,153312,0,71040.6793308258,0.5484000444412231,2.018724203109741,10000,79880.49059653282,0.7329882383346558,1.0937864780426023,0.6733799576759338,1.3574568033218384,50000 -8873.951783180237,7.901000738143921,71460.68648219109,154220,0,71460.68648219109,0.5509000420570374,1.9991964101791384,10000,80351.2955019474,0.7409570217132568,1.0447590351104736,0.6767799854278564,1.337766408920288,50000 -8928.184937000275,7.954378366470337,71880.62465643883,155127,0,71880.62465643883,0.5489000082015991,2.018172025680542,10000,80825.57174110413,0.7374804615974426,1.092585206031799,0.6778199672698975,1.3446279764175415,50000 -8979.215026378632,8.012673616409302,72300.71397519112,156033,0,72300.71397519112,0.5550000071525574,1.985244870185852,10000,81296.80072975159,0.7378515601158142,1.0726988315582275,0.6809799671173096,1.3264795541763306,50000 -9029.969490528109,8.063661575317383,72720.9410059452,156942,0,72720.9410059452,0.5562000274658203,1.9805034399032595,10000,81767.88509273529,0.7446679472923279,1.0540313720703125,0.6827600002288818,1.3250590562820437,50000 -9083.979422330856,8.117793560028076,73141.18209695816,157851,0,73141.18209695816,0.5621000528335571,1.955943703651428,10000,82242.24174976349,0.7596679329872131,0.9815232753753662,0.6845600008964539,1.304532527923584,50000 -9134.91816353798,8.174100875854492,73561.48490953445,158758,0,73561.48490953445,0.5637000203132629,1.9300532341003416,10000,82713.59108018875,0.7476562261581421,1.031006932258606,0.6883599758148193,1.2879359722137451,50000 -9185.332689285278,8.230481386184692,73981.68310594559,159669,0,73981.68310594559,0.5711000561714172,1.9215903282165527,10000,83184.31155920029,0.7578319907188416,0.9961916208267212,0.69159996509552,1.2772376537322998,50000 -9240.204246282578,8.28437876701355,74401.90697979927,160576,0,74401.90697979927,0.5687000155448914,1.9138787984848025,10000,83659.51302075386,0.7697070240974426,0.938443958759308,0.6931599974632263,1.2678637504577637,50000 -9291.293983697891,8.341437816619873,74822.10788369179,161485,0,74822.10788369179,0.5685999989509583,1.8960278034210205,10000,84130.91235041618,0.7563671469688416,0.976298451423645,0.6957199573516846,1.2451553344726562,50000 -9343.114990472794,8.395569086074829,75242.05846524239,162394,0,75242.05846524239,0.5740000009536743,1.8750234842300413,10000,84602.79054522514,0.76527339220047,0.9461551308631896,0.7001199722290039,1.229008674621582,50000 -9394.80528140068,8.456037521362305,75661.95459794998,163300,0,75661.95459794998,0.5771000385284424,1.875489592552185,10000,85074.48866915703,0.7712695002555847,0.9198420643806458,0.701259970664978,1.2253583669662476,50000 -9447.132781267166,8.510085105895996,76082.06962704659,164208,0,76082.06962704659,0.5770000219345093,1.8578152656555176,10000,85547.03738641739,0.7705273032188416,0.9230221509933472,0.7032399773597717,1.2151079177856443,50000 -9500.79858636856,8.563549280166626,76502.12402510643,165117,0,76502.12402510643,0.5834000110626221,1.851784825325012,10000,86020.86357402802,0.7754296660423279,0.913085401058197,0.7068799734115601,1.209192991256714,50000 -9551.27962064743,8.616734027862549,76922.23614120483,166022,0,76922.23614120483,0.5837000012397766,1.840895652770996,10000,86491.56161642075,0.7830663919448853,0.8687943816184998,0.7093600034713745,1.190772533416748,50000 -9603.473745822906,8.681183338165283,77342.32014989853,166928,0,77342.32014989853,0.5940000414848328,1.7914539575576782,10000,86963.95547676086,0.7832421660423279,0.871487557888031,0.7151399850845337,1.162901520729065,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/measurements.csv deleted file mode 100644 index b29a444d1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1861 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.37231192,6.9077563,,,,,,,,,,,,,, -1,,,0.0009765625,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,35.27015542984009,64.52321434020996,35.27015542984009,29.252927780151367,0.0,0.0 -100,0.55253536,6.823212,,,,,,,,,,,,,, -200,0.8172916,6.7009325,,,,,,,,,,,,,, -300,0.75907725,6.5808706,,,,,,,,,,,,,, -400,0.99707186,6.6215477,,,,,,,,,,,,,, -500,1.0452117,6.413816,,,,,,,,,,,,,, -600,0.76168597,6.3866057,,,,,,,,,,,,,, -700,0.79361385,6.2539005,,,,,,,,,,,,,, -800,0.723323,6.1751227,,,,,,,,,,,,,, -850,,,0.03515625,5.885966777801514,0.0339199975132942,5.9119391441345215,50000.0,0.026800001040101,6.039024353027344,10000.0,455.6410744190216,532.7995707988739,455.6410744190216,77.09238171577454,0.0176796913146972,0.0 -900,0.9938035,6.347988,,,,,,,,,,,,,, -1000,0.7472816,6.176773,,,,,,,,,,,,,, -1100,0.76724344,6.038356,,,,,,,,,,,,,, -1200,0.75480497,6.0896034,,,,,,,,,,,,,, -1300,0.74082196,5.9728456,,,,,,,,,,,,,, -1400,0.6293219,5.9969096,,,,,,,,,,,,,, -1500,0.8321904,6.0553756,,,,,,,,,,,,,, -1600,0.6700063,5.9065332,,,,,,,,,,,,,, -1700,0.70200026,5.956197,,,,,,,,,,,,,, -1756,,,0.0728906244039535,5.357734680175781,0.0655599981546402,5.43528413772583,50000.0,0.0533000007271766,5.654690742492676,10000.0,875.7305772304535,1003.27015709877,875.7305772304535,127.39190244674684,0.0458791255950927,0.0 -1800,0.62562627,6.130464,,,,,,,,,,,,,, -1900,0.45741743,6.4131784,,,,,,,,,,,,,, -2000,0.4981699,5.8144774,,,,,,,,,,,,,, -2100,0.41845906,6.299509,,,,,,,,,,,,,, -2200,0.50789547,5.741996,,,,,,,,,,,,,, -2300,0.5484276,5.616797,,,,,,,,,,,,,, -2400,0.41595748,6.4199295,,,,,,,,,,,,,, -2500,0.4575557,6.563049,,,,,,,,,,,,,, -2600,0.45111746,5.5120444,,,,,,,,,,,,,, -2665,,,0.1038476526737213,4.981348991394043,0.0971199944615364,5.043646812438965,50000.0,0.0734999999403953,5.327393054962158,10000.0,1295.863152742386,1474.4437997341156,1295.863152742386,178.35270309448242,0.0743927955627441,0.0 -2700,0.4265267,5.4647565,,,,,,,,,,,,,, -2800,0.4515603,6.440812,,,,,,,,,,,,,, -2900,0.5267948,5.6107326,,,,,,,,,,,,,, -3000,0.5722423,5.3895135,,,,,,,,,,,,,, -3100,0.50797796,5.4752607,,,,,,,,,,,,,, -3200,0.505696,6.397774,,,,,,,,,,,,,, -3300,0.6071694,5.383883,,,,,,,,,,,,,, -3400,0.46557218,5.367531,,,,,,,,,,,,,, -3500,0.56486744,5.50188,,,,,,,,,,,,,, -3575,,,0.1374218761920929,4.640011787414551,0.1282400041818618,4.715636253356934,50000.0,0.0984000042080879,5.043509006500244,10000.0,1715.986170053482,1944.7330491542816,1715.986170053482,228.4377381801605,0.1031725406646728,0.0 -3600,0.52467686,6.133027,,,,,,,,,,,,,, -3700,0.61561656,5.290013,,,,,,,,,,,,,, -3800,0.51549244,5.296802,,,,,,,,,,,,,, -3900,0.69056696,5.311901,,,,,,,,,,,,,, -4000,0.7612141,5.217221,,,,,,,,,,,,,, -4100,0.7050866,5.6810875,,,,,,,,,,,,,, -4200,0.58513385,5.14225,,,,,,,,,,,,,, -4300,0.66358364,4.993507,,,,,,,,,,,,,, -4400,0.45342398,5.9181485,,,,,,,,,,,,,, -4482,,,0.1684765517711639,4.419726848602295,0.1536400020122528,4.517857551574707,50000.0,0.1194000020623207,4.872969627380371,10000.0,2136.13942360878,2414.920921087265,2136.13942360878,278.38977098464966,0.1342794895172119,0.0 -4500,1.0700034,5.107685,,,,,,,,,,,,,, -4600,0.87601477,5.041872,,,,,,,,,,,,,, -4700,0.97548306,5.029412,,,,,,,,,,,,,, -4800,0.70146847,4.931555,,,,,,,,,,,,,, -4900,0.7602446,4.9513745,,,,,,,,,,,,,, -5000,0.7757437,6.153802,,,,,,,,,,,,,, -5100,0.76303947,5.3966875,,,,,,,,,,,,,, -5200,0.5982066,5.003382,,,,,,,,,,,,,, -5300,0.7754956,4.918873,,,,,,,,,,,,,, -5394,,,0.1861328035593032,4.220674514770508,0.1708399951457977,4.3279571533203125,50000.0,0.1340000033378601,4.738316059112549,10000.0,2556.4348685741425,2887.315844535828,2556.4348685741425,330.4113621711731,0.1604244709014892,0.0 -5400,0.8310753,4.895216,,,,,,,,,,,,,, -5500,0.9779934,5.021838,,,,,,,,,,,,,, -5600,0.7216487,5.7556496,,,,,,,,,,,,,, -5700,0.66667694,5.021973,,,,,,,,,,,,,, -5800,0.9091356,4.829042,,,,,,,,,,,,,, -5900,0.7314352,4.7906704,,,,,,,,,,,,,, -6000,0.63500905,5.1865616,,,,,,,,,,,,,, -6100,0.8145373,5.9281597,,,,,,,,,,,,,, -6200,0.79108167,4.9629393,,,,,,,,,,,,,, -6300,0.82676667,4.7498245,,,,,,,,,,,,,, -6306,,,0.2262695282697677,3.890352487564087,0.2067800015211105,3.99831748008728,50000.0,0.1624000072479248,4.428020477294922,10000.0,2976.544565677643,3358.3559985160828,2976.544565677643,381.2634997367859,0.18742036819458,0.0 -6400,0.75352204,5.9278083,,,,,,,,,,,,,, -6500,0.8504721,4.642518,,,,,,,,,,,,,, -6600,0.97370267,4.7683992,,,,,,,,,,,,,, -6700,0.75086373,6.1535273,,,,,,,,,,,,,, -6800,0.92901397,5.934276,,,,,,,,,,,,,, -6900,0.93426394,6.1418285,,,,,,,,,,,,,, -7000,0.6737804,6.025516,,,,,,,,,,,,,, -7100,0.97897243,4.8170176,,,,,,,,,,,,,, -7200,0.7591975,4.7731614,,,,,,,,,,,,,, -7218,,,0.232421875,3.807736158370972,0.2166399955749511,3.93320369720459,50000.0,0.165800005197525,4.394979000091553,10000.0,3396.534752845764,3828.797343492508,3396.534752845764,431.6334598064423,0.2173418998718261,0.0 -7300,0.8692689,4.600009,,,,,,,,,,,,,, -7400,0.81564695,4.7430387,,,,,,,,,,,,,, -7500,0.9324722,4.938241,,,,,,,,,,,,,, -7600,0.9841206,4.712836,,,,,,,,,,,,,, -7700,1.002941,4.7123013,,,,,,,,,,,,,, -7800,0.80634266,4.5578084,,,,,,,,,,,,,, -7900,1.0682045,4.6597877,,,,,,,,,,,,,, -8000,1.0448047,4.5888777,,,,,,,,,,,,,, -8100,0.8523327,4.9197526,,,,,,,,,,,,,, -8131,,,0.243203118443489,3.820199728012085,0.2207199931144714,3.940143585205078,50000.0,0.1724000126123428,4.411588668823242,10000.0,3816.9069616794586,4301.491326808929,3816.9069616794586,483.87773609161377,0.2426652908325195,0.0 -8200,0.6335273,6.0491767,,,,,,,,,,,,,, -8300,0.77811885,4.4260683,,,,,,,,,,,,,, -8400,0.7261917,5.687608,,,,,,,,,,,,,, -8500,0.7849367,4.8448653,,,,,,,,,,,,,, -8600,0.82680666,4.787758,,,,,,,,,,,,,, -8700,0.73589396,6.0265784,,,,,,,,,,,,,, -8800,0.79093474,4.479437,,,,,,,,,,,,,, -8900,0.7368625,6.1646776,,,,,,,,,,,,,, -9000,0.71287805,5.504217,,,,,,,,,,,,,, -9041,,,0.2707031071186065,3.5831527709960938,0.2477799952030182,3.710452079772949,50000.0,0.1932000070810318,4.174107074737549,10000.0,4236.489646196365,4773.503950834274,4236.489646196365,535.8006870746613,0.69814133644104,0.0 -9100,0.72651124,4.5147862,,,,,,,,,,,,,, -9200,0.8734748,4.481327,,,,,,,,,,,,,, -9300,0.8897311,4.5332074,,,,,,,,,,,,,, -9400,0.56843483,5.9614673,,,,,,,,,,,,,, -9500,0.9329115,4.3623734,,,,,,,,,,,,,, -9600,0.7424295,4.357006,,,,,,,,,,,,,, -9700,0.80257565,5.390801,,,,,,,,,,,,,, -9800,0.75719404,5.8606987,,,,,,,,,,,,,, -9900,0.81608737,6.3566427,,,,,,,,,,,,,, -9953,,,0.2805859446525574,3.5102574825286865,0.258459985256195,3.65179705619812,50000.0,0.1997000128030777,4.167118549346924,10000.0,4656.436341047287,5244.670069217682,4656.436341047287,586.9349710941315,0.7309134006500244,0.0 -10000,0.9409434,4.5708413,,,,,,,,,,,,,, -10100,0.8423679,4.446596,,,,,,,,,,,,,, -10200,0.76401085,4.6147366,,,,,,,,,,,,,, -10300,0.9308808,4.5679317,,,,,,,,,,,,,, -10400,1.0500841,4.396538,,,,,,,,,,,,,, -10500,0.937792,4.5371003,,,,,,,,,,,,,, -10600,0.9160237,4.428862,,,,,,,,,,,,,, -10700,0.60662675,5.238826,,,,,,,,,,,,,, -10800,0.6223691,5.632261,,,,,,,,,,,,,, -10864,,,0.2991601526737213,3.379409074783325,0.2633799910545349,3.607219934463501,50000.0,0.2043000161647796,4.103435516357422,10000.0,5076.775891304016,5716.142570257187,5076.775891304016,637.9864263534546,0.7608444690704346,0.0 -10900,0.911683,4.5034018,,,,,,,,,,,,,, -11000,0.7750825,5.344603,,,,,,,,,,,,,, -11100,0.8529629,4.360544,,,,,,,,,,,,,, -11200,0.8378531,4.33712,,,,,,,,,,,,,, -11300,0.90481776,4.30059,,,,,,,,,,,,,, -11400,0.9120158,4.9071584,,,,,,,,,,,,,, -11500,0.9483913,6.20292,,,,,,,,,,,,,, -11600,0.7218686,6.0247436,,,,,,,,,,,,,, -11700,0.883753,4.465518,,,,,,,,,,,,,, -11773,,,0.2922460734844208,3.4271626472473145,0.2735399901866913,3.5501279830932617,50000.0,0.2072000056505203,4.052846908569336,10000.0,5497.108896970749,6188.652234077454,5497.108896970749,690.076201915741,0.7959158420562744,0.0 -11800,0.62526184,5.0568805,,,,,,,,,,,,,, -11900,1.1445898,4.5944386,,,,,,,,,,,,,, -12000,0.98662543,4.195829,,,,,,,,,,,,,, -12100,0.71581084,5.3242693,,,,,,,,,,,,,, -12200,1.3324857,4.317966,,,,,,,,,,,,,, -12300,0.96337324,4.388945,,,,,,,,,,,,,, -12400,0.8434965,4.40329,,,,,,,,,,,,,, -12500,0.93150604,4.5178986,,,,,,,,,,,,,, -12600,0.67297494,5.9098063,,,,,,,,,,,,,, -12683,,,0.2971484363079071,3.4095776081085205,0.277099996805191,3.530275821685791,50000.0,0.21390001475811,4.058302879333496,10000.0,5917.175496578217,6661.750261306763,5917.175496578217,743.02845287323,0.8236494064331055,0.0 -12700,1.033222,4.322517,,,,,,,,,,,,,, -12800,0.9638101,4.1972647,,,,,,,,,,,,,, -12900,0.928114,4.401471,,,,,,,,,,,,,, -13000,0.8131576,4.969634,,,,,,,,,,,,,, -13100,0.8697676,4.4556947,,,,,,,,,,,,,, -13200,1.1326288,4.3293324,,,,,,,,,,,,,, -13300,0.93465763,4.1964054,,,,,,,,,,,,,, -13400,1.1402743,4.268597,,,,,,,,,,,,,, -13500,0.84403735,5.094259,,,,,,,,,,,,,, -13592,,,0.3229882717132568,3.2224056720733643,0.28015998005867,3.494879007339477,50000.0,0.2124000042676925,4.011412143707275,10000.0,6337.318806171417,7132.289152383804,6337.318806171417,793.3429248332977,0.8524086475372314,0.0 -13600,0.80134314,4.360967,,,,,,,,,,,,,, -13700,0.88529336,4.720955,,,,,,,,,,,,,, -13800,0.7744684,5.8310585,,,,,,,,,,,,,, -13900,0.86065316,4.450121,,,,,,,,,,,,,, -14000,0.86951965,5.407818,,,,,,,,,,,,,, -14100,0.9889579,4.7810173,,,,,,,,,,,,,, -14200,0.58330977,5.8232493,,,,,,,,,,,,,, -14300,0.82401747,5.859179,,,,,,,,,,,,,, -14400,0.7284654,6.131599,,,,,,,,,,,,,, -14498,,,0.3158788979053497,3.273736000061035,0.2939800024032593,3.3951432704925537,50000.0,0.2227000147104263,3.9273860454559326,10000.0,6757.358831167221,7603.913234949112,6757.358831167221,844.848123550415,0.8794562816619873,0.0 -14500,0.9180745,4.2483525,,,,,,,,,,,,,, -14600,0.96898615,4.2077065,,,,,,,,,,,,,, -14700,0.914578,4.269009,,,,,,,,,,,,,, -14800,0.91478246,4.281115,,,,,,,,,,,,,, -14900,1.2010902,4.100192,,,,,,,,,,,,,, -15000,0.94061357,4.146902,,,,,,,,,,,,,, -15100,1.0268724,4.38999,,,,,,,,,,,,,, -15200,0.92444205,4.293464,,,,,,,,,,,,,, -15300,0.8813629,4.2634754,,,,,,,,,,,,,, -15400,0.7635151,5.3652844,,,,,,,,,,,,,, -15405,,,0.3174218535423279,3.3037431240081787,0.293639987707138,3.4267396926879883,50000.0,0.2321000099182129,3.9265310764312735,10000.0,7177.752962350845,8076.399621963501,7177.752962350845,896.8590533733368,0.9088699817657472,0.0 -15500,0.94413143,4.305468,,,,,,,,,,,,,, -15600,0.89320475,4.358602,,,,,,,,,,,,,, -15700,1.072222,4.3990345,,,,,,,,,,,,,, -15800,0.8416938,4.357018,,,,,,,,,,,,,, -15900,0.8513191,4.6738048,,,,,,,,,,,,,, -16000,0.99772,4.155658,,,,,,,,,,,,,, -16100,0.91168827,4.141082,,,,,,,,,,,,,, -16200,0.7718663,4.452922,,,,,,,,,,,,,, -16300,0.8795629,4.766924,,,,,,,,,,,,,, -16315,,,0.3279101550579071,3.2197184562683105,0.2950599789619446,3.416271448135376,50000.0,0.2264000177383422,3.9574127197265625,10000.0,7597.680613279343,8547.108525753021,7597.680613279343,947.5592894554138,0.9369139671325684,0.0 -16400,0.8958057,4.2322354,,,,,,,,,,,,,, -16500,0.69986916,5.8475723,,,,,,,,,,,,,, -16600,0.9837789,4.111803,,,,,,,,,,,,,, -16700,1.4039582,4.343307,,,,,,,,,,,,,, -16800,0.756189,5.0203896,,,,,,,,,,,,,, -16900,0.8891104,4.2742233,,,,,,,,,,,,,, -17000,0.7691994,4.7339673,,,,,,,,,,,,,, -17100,0.95303714,4.484395,,,,,,,,,,,,,, -17200,1.0292584,4.065135,,,,,,,,,,,,,, -17227,,,0.3093163967132568,3.3662712574005127,0.2869600057601928,3.501178741455078,50000.0,0.2181000113487243,4.019433975219727,10000.0,8017.9552347660065,9020.436147928238,8017.9552347660065,1000.5289397239684,0.9675893783569336,0.0 -17300,0.75356317,4.718058,,,,,,,,,,,,,, -17400,0.9963294,5.044593,,,,,,,,,,,,,, -17500,1.0105183,4.13367,,,,,,,,,,,,,, -17600,0.8584475,4.1773376,,,,,,,,,,,,,, -17700,0.92621154,4.190658,,,,,,,,,,,,,, -17800,1.117711,4.109793,,,,,,,,,,,,,, -17900,0.87498075,4.2648783,,,,,,,,,,,,,, -18000,1.2079636,4.2277107,,,,,,,,,,,,,, -18100,1.0342282,4.0280676,,,,,,,,,,,,,, -18135,,,0.3327734172344208,3.162445545196533,0.3094799816608429,3.302088022232056,50000.0,0.2333000153303146,3.880850553512573,10000.0,8437.884447813034,9491.875847578049,8437.884447813034,1051.960121870041,0.9948971271514891,0.0 -18200,0.7019395,4.7805557,,,,,,,,,,,,,, -18300,1.0155325,4.7458215,,,,,,,,,,,,,, -18400,0.9082048,4.0718374,,,,,,,,,,,,,, -18500,1.3138793,4.235968,,,,,,,,,,,,,, -18600,1.0126655,4.07006,,,,,,,,,,,,,, -18700,1.0938338,4.1308336,,,,,,,,,,,,,, -18800,0.9040427,5.3039565,,,,,,,,,,,,,, -18900,0.92137873,4.3523636,,,,,,,,,,,,,, -19000,0.84312844,4.035587,,,,,,,,,,,,,, -19043,,,0.3293164074420929,3.2388904094696045,0.3053599894046783,3.392557382583618,50000.0,0.2296000123023986,3.9306869506835938,10000.0,8858.166803836823,9965.598746538162,8858.166803836823,1105.318029642105,1.0265140533447266,0.0 -19100,0.7555198,5.979587,,,,,,,,,,,,,, -19200,1.041989,4.131752,,,,,,,,,,,,,, -19300,0.77782255,4.1139536,,,,,,,,,,,,,, -19400,0.9093911,4.121923,,,,,,,,,,,,,, -19500,0.7733843,6.0587053,,,,,,,,,,,,,, -19600,1.0874505,4.5775814,,,,,,,,,,,,,, -19700,1.1648082,4.2979875,,,,,,,,,,,,,, -19800,0.8013429,4.456337,,,,,,,,,,,,,, -19900,1.402299,4.142044,,,,,,,,,,,,,, -19952,,,0.3375976383686065,3.134190082550049,0.3170999884605407,3.262382984161377,50000.0,0.2477000057697296,3.826837778091431,10000.0,9278.551815271378,10440.915224790571,9278.551815271378,1160.165560245514,1.059199333190918,0.0 -20000,1.0065008,4.0887737,,,,,,,,,,,,,, -20100,0.7840756,5.7399583,,,,,,,,,,,,,, -20200,1.1594365,4.9299965,,,,,,,,,,,,,, -20300,0.74733996,5.3797717,,,,,,,,,,,,,, -20400,0.6687324,6.1065516,,,,,,,,,,,,,, -20500,0.91215146,4.0138354,,,,,,,,,,,,,, -20600,0.6361659,6.054471,,,,,,,,,,,,,, -20700,1.1544256,4.1768417,,,,,,,,,,,,,, -20800,0.90959287,4.1257424,,,,,,,,,,,,,, -20863,,,0.3479296863079071,3.0743730068206787,0.3225999772548675,3.2060763835906982,50000.0,0.2464000135660171,3.7816295623779297,10000.0,9698.90292072296,10913.613707065582,9698.90292072296,1212.4267747402191,1.0930509567260742,0.0 -20900,0.78571224,6.0854354,,,,,,,,,,,,,, -21000,0.83506393,5.503471,,,,,,,,,,,,,, -21100,1.0338635,4.3464394,,,,,,,,,,,,,, -21200,1.061357,4.21029,,,,,,,,,,,,,, -21300,0.86781436,4.033852,,,,,,,,,,,,,, -21400,0.7702733,5.7748475,,,,,,,,,,,,,, -21500,0.78307766,4.7949853,,,,,,,,,,,,,, -21600,1.0270188,4.2332335,,,,,,,,,,,,,, -21700,0.8353691,4.0941424,,,,,,,,,,,,,, -21769,,,0.3544921875,3.012260913848877,0.3244200050830841,3.1736767292022705,50000.0,0.2497000098228454,3.733030319213867,10000.0,10119.1109790802,11384.816810846329,10119.1109790802,1263.3394284248352,1.1243548393249512,0.0 -21800,0.8803375,4.2641544,,,,,,,,,,,,,, -21900,0.81573266,5.240739,,,,,,,,,,,,,, -22000,1.2643023,4.1277347,,,,,,,,,,,,,, -22100,0.9976116,4.197023,,,,,,,,,,,,,, -22200,0.8407199,3.9880805,,,,,,,,,,,,,, -22300,1.0302548,6.072651,,,,,,,,,,,,,, -22400,0.9440802,3.9697225,,,,,,,,,,,,,, -22500,0.8575352,3.9297228,,,,,,,,,,,,,, -22600,1.0527908,4.0988073,,,,,,,,,,,,,, -22677,,,0.3376367092132568,3.103149652481079,0.3175399899482727,3.2267537117004395,50000.0,0.246500015258789,3.79990553855896,10000.0,10539.403557777405,11856.48029446602,10539.403557777405,1314.6286573410034,1.1545770168304443,0.0 -22700,0.9149422,4.119547,,,,,,,,,,,,,, -22800,0.8510742,6.004709,,,,,,,,,,,,,, -22900,0.9134335,4.0918155,,,,,,,,,,,,,, -23000,0.9341075,3.8729217,,,,,,,,,,,,,, -23100,1.1599568,4.376851,,,,,,,,,,,,,, -23200,0.82890964,5.999166,,,,,,,,,,,,,, -23300,0.8928606,4.3893824,,,,,,,,,,,,,, -23400,0.94820017,4.1605635,,,,,,,,,,,,,, -23500,0.97704697,3.9536605,,,,,,,,,,,,,, -23585,,,0.3579882681369781,3.0236752033233643,0.3356199860572815,3.154872179031372,50000.0,0.2532000243663788,3.731189250946045,10000.0,10959.684936523438,12328.844103336334,10959.684936523438,1366.626068353653,1.188666820526123,0.0 -23600,0.9768486,6.055979,,,,,,,,,,,,,, -23700,1.042942,4.2522526,,,,,,,,,,,,,, -23800,0.75835866,4.4440217,,,,,,,,,,,,,, -23900,0.82203865,4.3030443,,,,,,,,,,,,,, -24000,0.845139,3.9452875,,,,,,,,,,,,,, -24100,0.5874053,5.9119234,,,,,,,,,,,,,, -24200,1.0749842,4.0110617,,,,,,,,,,,,,, -24300,1.0506446,3.9740252,,,,,,,,,,,,,, -24400,1.1840125,4.0345044,,,,,,,,,,,,,, -24492,,,0.3534179627895355,2.9814937114715576,0.3287400007247925,3.1595587730407715,50000.0,0.2571000158786773,3.7189767360687256,10000.0,11380.087949991226,12801.7687625885,11380.087949991226,1419.0612137317655,1.2243428230285645,0.0 -24500,0.8749317,3.9437494,,,,,,,,,,,,,, -24600,1.3046436,3.9632936,,,,,,,,,,,,,, -24700,0.9659533,4.0295978,,,,,,,,,,,,,, -24800,0.77454746,5.3629503,,,,,,,,,,,,,, -24900,1.1065806,4.129645,,,,,,,,,,,,,, -25000,0.93404394,4.323594,,,,,,,,,,,,,, -25100,0.8251594,5.7647376,,,,,,,,,,,,,, -25200,1.345454,4.0640707,,,,,,,,,,,,,, -25300,0.9440051,4.496231,,,,,,,,,,,,,, -25398,,,0.3600195348262787,3.08575177192688,0.3273800015449524,3.256126880645752,50000.0,0.2442000061273574,3.829700231552124,10000.0,11800.085270881653,13272.997171640396,11800.085270881653,1470.207144498825,1.257941961288452,0.0 -25400,1.0171268,4.060245,,,,,,,,,,,,,, -25500,0.9745265,4.490781,,,,,,,,,,,,,, -25600,1.0822484,4.8274493,,,,,,,,,,,,,, -25700,1.2205309,4.0522895,,,,,,,,,,,,,, -25800,0.87787694,5.066488,,,,,,,,,,,,,, -25900,0.73664737,4.8527637,,,,,,,,,,,,,, -26000,0.9471476,3.935249,,,,,,,,,,,,,, -26100,0.9743689,3.98305,,,,,,,,,,,,,, -26200,0.91552377,5.773072,,,,,,,,,,,,,, -26300,0.92526096,4.0662174,,,,,,,,,,,,,, -26305,,,0.3636718690395355,2.97861909866333,0.3398799896240234,3.1036124229431152,50000.0,0.2581000030040741,3.705676317214966,10000.0,12220.36465883255,13745.666129112244,12220.36465883255,1522.5056171417236,1.2966363430023191,0.0 -26400,0.8380951,4.0334816,,,,,,,,,,,,,, -26500,0.9823744,4.0297017,,,,,,,,,,,,,, -26600,0.8408607,4.3325815,,,,,,,,,,,,,, -26700,1.0955615,3.8243263,,,,,,,,,,,,,, -26800,0.9659243,4.124419,,,,,,,,,,,,,, -26900,0.80685174,4.487803,,,,,,,,,,,,,, -27000,0.94480234,3.9661717,,,,,,,,,,,,,, -27100,0.99976623,3.8982866,,,,,,,,,,,,,, -27200,0.9439564,5.2676992,,,,,,,,,,,,,, -27209,,,0.3745898306369781,2.8777029514312744,0.3472599983215332,3.050992727279663,50000.0,0.2712000012397766,3.6168735027313232,10000.0,12640.606446743011,14217.07286643982,12640.606446743011,1573.5893032550812,1.327235221862793,0.0 -27300,0.9523766,4.227093,,,,,,,,,,,,,, -27400,1.1512637,4.0311975,,,,,,,,,,,,,, -27500,1.0032636,4.4756546,,,,,,,,,,,,,, -27600,0.94629425,4.336035,,,,,,,,,,,,,, -27700,0.8009624,5.1572943,,,,,,,,,,,,,, -27800,0.97035944,4.145494,,,,,,,,,,,,,, -27900,1.1013571,4.1436825,,,,,,,,,,,,,, -28000,1.0593712,3.8005626,,,,,,,,,,,,,, -28100,1.0996554,4.03054,,,,,,,,,,,,,, -28110,,,0.4131249785423279,2.7041707038879395,0.3500999808311462,3.0397725105285645,50000.0,0.2741000056266784,3.6207833290100098,10000.0,13060.701996088028,14688.747186660768,13060.701996088028,1625.0876586437223,1.3568122386932373,0.0 -28200,0.8767696,3.8368068,,,,,,,,,,,,,, -28300,0.98394924,3.8841596,,,,,,,,,,,,,, -28400,0.7722254,5.783076,,,,,,,,,,,,,, -28500,1.2828664,4.0221796,,,,,,,,,,,,,, -28600,0.7768323,5.751974,,,,,,,,,,,,,, -28700,0.6881602,6.0304623,,,,,,,,,,,,,, -28800,0.79224557,5.210574,,,,,,,,,,,,,, -28900,0.94843084,3.90837,,,,,,,,,,,,,, -29000,0.9569326,4.364667,,,,,,,,,,,,,, -29014,,,0.3537499904632568,3.0552871227264404,0.3319399952888489,3.185477495193481,50000.0,0.2489000111818313,3.72862720489502,10000.0,13480.76611328125,15161.299011945724,13480.76611328125,1677.4866523742676,1.3948171138763428,0.0 -29100,1.0634242,4.666823,,,,,,,,,,,,,, -29200,0.6889065,5.7018147,,,,,,,,,,,,,, -29300,1.3305318,4.1523833,,,,,,,,,,,,,, -29400,0.97669584,4.143958,,,,,,,,,,,,,, -29500,0.9538102,5.277127,,,,,,,,,,,,,, -29600,0.9539486,3.9063394,,,,,,,,,,,,,, -29700,1.0592675,4.0394907,,,,,,,,,,,,,, -29800,0.96381915,4.0302005,,,,,,,,,,,,,, -29900,0.8031855,5.2040863,,,,,,,,,,,,,, -29917,,,0.3595312535762787,3.01259183883667,0.3307799994945526,3.17191219329834,50000.0,0.256600022315979,3.744589328765869,10000.0,13901.006760120392,15635.10154223442,13901.006760120392,1730.9614639282229,1.4306964874267578,0.0 -30000,0.9905756,4.4109907,,,,,,,,,,,,,, -30100,0.93828374,4.0012317,,,,,,,,,,,,,, -30200,1.0399098,4.574584,,,,,,,,,,,,,, -30300,0.7741955,5.909687,,,,,,,,,,,,,, -30400,1.2107401,3.9426143,,,,,,,,,,,,,, -30500,0.83457184,5.9714003,,,,,,,,,,,,,, -30600,1.1009167,3.9528115,,,,,,,,,,,,,, -30700,0.8628064,5.308151,,,,,,,,,,,,,, -30800,1.2155056,4.060427,,,,,,,,,,,,,, -30820,,,0.3834374845027923,2.864189624786377,0.3374399840831756,3.1222946643829346,50000.0,0.2592000067234039,3.72937273979187,10000.0,14321.184031248093,16109.316176652908,14321.184031248093,1784.915414810181,1.4623537063598633,0.0 -30900,1.0258555,4.5506496,,,,,,,,,,,,,, -31000,0.66079086,5.6876082,,,,,,,,,,,,,, -31100,0.9500226,4.2905135,,,,,,,,,,,,,, -31200,0.9848862,3.7216232,,,,,,,,,,,,,, -31300,0.94480884,6.0886693,,,,,,,,,,,,,, -31400,0.80091935,4.255659,,,,,,,,,,,,,, -31500,0.92039806,4.1280575,,,,,,,,,,,,,, -31600,1.0092765,3.9313512,,,,,,,,,,,,,, -31700,0.9882781,5.2882047,,,,,,,,,,,,,, -31728,,,0.3595312535762787,3.016347885131836,0.3363399803638458,3.149092197418213,50000.0,0.2592000067234039,3.729475259780884,10000.0,14741.241425275804,16584.70828318596,14741.241425275804,1840.1619803905487,1.498626470565796,0.0 -31800,0.99603224,4.2590685,,,,,,,,,,,,,, -31900,0.84187955,4.2991138,,,,,,,,,,,,,, -32000,1.0296223,3.9347513,,,,,,,,,,,,,, -32100,0.8450686,4.671249,,,,,,,,,,,,,, -32200,1.0246508,3.8290474,,,,,,,,,,,,,, -32300,0.9394472,3.791849,,,,,,,,,,,,,, -32400,1.201175,4.655234,,,,,,,,,,,,,, -32500,0.7762193,3.800065,,,,,,,,,,,,,, -32600,1.0300387,4.386682,,,,,,,,,,,,,, -32632,,,0.3854101598262787,2.822803497314453,0.3571199774742126,2.989459991455078,50000.0,0.2743000090122223,3.608689308166504,10000.0,15161.485805511476,17056.36390018463,15161.485805511476,1891.49284529686,1.528019905090332,0.0 -32700,0.9468142,4.8517957,,,,,,,,,,,,,, -32800,0.6758429,5.7842727,,,,,,,,,,,,,, -32900,0.71174264,5.4489737,,,,,,,,,,,,,, -33000,1.014959,4.00734,,,,,,,,,,,,,, -33100,0.8789238,4.0610766,,,,,,,,,,,,,, -33200,1.0362364,4.375226,,,,,,,,,,,,,, -33300,0.8878956,6.0713563,,,,,,,,,,,,,, -33400,0.9079207,5.5125384,,,,,,,,,,,,,, -33500,0.9136544,3.872467,,,,,,,,,,,,,, -33538,,,0.3950781226158142,2.7515311241149902,0.360179990530014,2.970732450485229,50000.0,0.2797000110149383,3.5605785846710205,10000.0,15581.891350269318,17528.8449883461,15581.891350269318,1943.4836995601647,1.5617244243621826,0.0 -33600,0.7844029,5.8011184,,,,,,,,,,,,,, -33700,1.1359776,4.055661,,,,,,,,,,,,,, -33800,0.86854804,4.650586,,,,,,,,,,,,,, -33900,0.91649455,5.327379,,,,,,,,,,,,,, -34000,0.742046,5.377743,,,,,,,,,,,,,, -34100,0.925091,3.8922367,,,,,,,,,,,,,, -34200,0.9948845,4.2410264,,,,,,,,,,,,,, -34300,1.0554298,3.957753,,,,,,,,,,,,,, -34400,1.0850239,3.9609942,,,,,,,,,,,,,, -34448,,,0.3739648461341858,2.8991634845733643,0.3464599847793579,3.0483076572418213,50000.0,0.268200010061264,3.622101306915283,10000.0,16001.988507032394,18000.516926765442,16001.988507032394,1994.9713623523712,1.596980094909668,0.0 -34500,1.099914,3.8931036,,,,,,,,,,,,,, -34600,0.88607776,5.237255,,,,,,,,,,,,,, -34700,0.731385,6.031337,,,,,,,,,,,,,, -34800,0.983946,3.995248,,,,,,,,,,,,,, -34900,0.68641365,4.6210437,,,,,,,,,,,,,, -35000,0.7938878,5.2053366,,,,,,,,,,,,,, -35100,0.9127383,3.8638732,,,,,,,,,,,,,, -35200,1.0357069,3.9765348,,,,,,,,,,,,,, -35300,0.85665685,4.2375083,,,,,,,,,,,,,, -35352,,,0.3836914002895355,2.843531370162964,0.3626999855041504,2.977720260620117,50000.0,0.2821000218391418,3.5719974040985107,10000.0,16422.013520240784,18471.41180539131,16422.013520240784,2045.7526659965515,1.6344339847564695,0.0 -35400,0.99285066,4.231998,,,,,,,,,,,,,, -35500,1.0157251,4.792582,,,,,,,,,,,,,, -35600,0.750051,5.156627,,,,,,,,,,,,,, -35700,1.020484,3.7478263,,,,,,,,,,,,,, -35800,1.0605037,3.999961,,,,,,,,,,,,,, -35900,0.7638274,4.839756,,,,,,,,,,,,,, -36000,1.1814338,4.012353,,,,,,,,,,,,,, -36100,0.8837269,6.0031652,,,,,,,,,,,,,, -36200,1.0100248,4.053787,,,,,,,,,,,,,, -36258,,,0.3859765529632568,2.824381351470948,0.3565199971199035,3.010390520095825,50000.0,0.2752000093460083,3.615450620651245,10000.0,16842.142607688904,18943.7842271328,16842.142607688904,2097.913378477097,1.6660008430480957,0.0 -36300,0.9538864,3.6889107,,,,,,,,,,,,,, -36400,1.1207417,3.9416735,,,,,,,,,,,,,, -36500,0.86963946,4.6061397,,,,,,,,,,,,,, -36600,1.001471,3.8999171,,,,,,,,,,,,,, -36700,0.89009273,4.1311274,,,,,,,,,,,,,, -36800,1.049404,3.9674726,,,,,,,,,,,,,, -36900,0.8565714,3.800933,,,,,,,,,,,,,, -37000,0.92651504,5.683596,,,,,,,,,,,,,, -37100,0.8277328,5.447791,,,,,,,,,,,,,, -37164,,,0.3831445276737213,2.815798282623291,0.3597999811172485,2.957936286926269,50000.0,0.2762000262737274,3.5880584716796875,10000.0,17262.063611745834,19414.52570724488,17262.063611745834,2148.6515715122223,1.697392463684082,0.0 -37200,1.0562702,4.2326913,,,,,,,,,,,,,, -37300,1.0587091,3.7857406,,,,,,,,,,,,,, -37400,1.1302172,3.839536,,,,,,,,,,,,,, -37500,0.9833233,4.452497,,,,,,,,,,,,,, -37600,1.0221889,3.8351085,,,,,,,,,,,,,, -37700,0.96555,3.8840704,,,,,,,,,,,,,, -37800,0.82589716,5.6923923,,,,,,,,,,,,,, -37900,1.0293987,3.7882857,,,,,,,,,,,,,, -38000,1.0060278,3.9376125,,,,,,,,,,,,,, -38070,,,0.3713281154632568,2.9289300441741943,0.3464599847793579,3.083930969238281,50000.0,0.2651000022888183,3.700192451477051,10000.0,17682.347589969635,19888.36406493187,17682.347589969635,2202.1193537712097,1.732421636581421,0.0 -38100,1.2262262,3.8261743,,,,,,,,,,,,,, -38200,0.9208452,3.7175305,,,,,,,,,,,,,, -38300,0.73786753,5.4947557,,,,,,,,,,,,,, -38400,0.94229615,3.9714859,,,,,,,,,,,,,, -38500,0.93792117,3.867582,,,,,,,,,,,,,, -38600,0.9855332,4.0991945,,,,,,,,,,,,,, -38700,0.91800505,4.1501575,,,,,,,,,,,,,, -38800,0.95252126,3.8827345,,,,,,,,,,,,,, -38900,0.7652788,5.9532986,,,,,,,,,,,,,, -38976,,,0.3922656178474426,2.8164315223693848,0.3563799858093261,3.0107686519622803,50000.0,0.2685000002384186,3.625632286071777,10000.0,18102.575630664825,20361.533406972885,18102.575630664825,2254.9687502384186,1.7733440399169922,0.0 -39000,0.92407787,5.4690385,,,,,,,,,,,,,, -39100,1.0993407,4.1333528,,,,,,,,,,,,,, -39200,0.99745816,4.285211,,,,,,,,,,,,,, -39300,1.2601873,3.9291809,,,,,,,,,,,,,, -39400,0.94787353,4.541074,,,,,,,,,,,,,, -39500,1.0668361,3.9661598,,,,,,,,,,,,,, -39600,1.0131162,3.7030902,,,,,,,,,,,,,, -39700,1.0536579,3.8075361,,,,,,,,,,,,,, -39800,0.97453904,4.3172407,,,,,,,,,,,,,, -39884,,,0.3822265565395355,2.884771585464477,0.3570599853992462,3.043698787689209,50000.0,0.2786000072956085,3.6124985218048096,10000.0,18522.856332540512,20832.76473426819,18522.856332540512,2305.820141553879,1.8208367824554443,0.0 -39900,1.0570567,3.8064282,,,,,,,,,,,,,, -40000,0.8017942,4.062215,,,,,,,,,,,,,, -40100,1.128709,3.8312287,,,,,,,,,,,,,, -40200,1.0289546,3.764418,,,,,,,,,,,,,, -40300,0.981484,3.9452775,,,,,,,,,,,,,, -40400,0.9860841,3.7329423,,,,,,,,,,,,,, -40500,0.84194046,4.3257847,,,,,,,,,,,,,, -40600,0.682543,5.9232783,,,,,,,,,,,,,, -40700,1.0369002,3.849479,,,,,,,,,,,,,, -40792,,,0.3929687440395355,2.77851676940918,0.369079977273941,2.914356231689453,50000.0,0.2778000235557556,3.527748107910156,10000.0,18942.836881637573,21302.93864560128,18942.836881637573,2355.927830219269,1.8549113273620603,0.0 -40800,0.7010831,5.6984115,,,,,,,,,,,,,, -40900,0.8761482,4.3644505,,,,,,,,,,,,,, -41000,0.85361433,4.73737,,,,,,,,,,,,,, -41100,0.8307972,5.770074,,,,,,,,,,,,,, -41200,1.1812713,3.9727798,,,,,,,,,,,,,, -41300,0.92592984,3.7785687,,,,,,,,,,,,,, -41400,0.8209595,5.9180427,,,,,,,,,,,,,, -41500,0.89142096,4.56552,,,,,,,,,,,,,, -41600,1.0662069,3.8853042,,,,,,,,,,,,,, -41694,,,0.3770507872104645,2.974250078201294,0.3516799807548523,3.125171184539795,50000.0,0.268200010061264,3.675233840942383,10000.0,19362.784957170486,21773.56910061836,19362.784957170486,2406.526533842087,1.887979984283448,0.0 -41700,0.7314242,4.8524647,,,,,,,,,,,,,, -41800,1.0538762,3.8755486,,,,,,,,,,,,,, -41900,1.0714984,4.0589232,,,,,,,,,,,,,, -42000,0.90214944,3.854623,,,,,,,,,,,,,, -42100,1.0442302,4.6172466,,,,,,,,,,,,,, -42200,0.9234349,3.8029568,,,,,,,,,,,,,, -42300,0.85603875,5.0673203,,,,,,,,,,,,,, -42400,0.9522568,5.8011355,,,,,,,,,,,,,, -42500,0.85649264,5.1935964,,,,,,,,,,,,,, -42600,1.1593997,3.937105,,,,,,,,,,,,,, -42601,,,0.3780468702316284,2.9276158809661865,0.3479799926280975,3.1044631004333496,50000.0,0.2657000124454498,3.694946765899658,10000.0,19783.164499998093,22244.399122715,19783.164499998093,2456.8914000988007,1.921765804290772,0.0 -42700,0.9200161,4.1169033,,,,,,,,,,,,,, -42800,1.010753,3.7487497,,,,,,,,,,,,,, -42900,0.7397501,5.780289,,,,,,,,,,,,,, -43000,0.7041404,5.9332905,,,,,,,,,,,,,, -43100,0.90350205,4.358953,,,,,,,,,,,,,, -43200,0.86713403,4.4514184,,,,,,,,,,,,,, -43300,1.0323902,5.4356985,,,,,,,,,,,,,, -43400,1.1060803,3.83782,,,,,,,,,,,,,, -43500,1.2484103,3.74515,,,,,,,,,,,,,, -43515,,,0.4089257717132568,2.723781108856201,0.3783199787139892,2.8741984367370605,50000.0,0.2962000072002411,3.4557039737701416,10000.0,20203.380412817,22716.77877688408,20203.380412817,2508.969986438751,1.9536054134368896,0.0 -43600,1.0393771,3.7409964,,,,,,,,,,,,,, -43700,1.0156155,3.7604506,,,,,,,,,,,,,, -43800,1.0235963,3.763887,,,,,,,,,,,,,, -43900,0.722059,5.8138995,,,,,,,,,,,,,, -44000,1.1133658,3.858712,,,,,,,,,,,,,, -44100,0.9512133,4.026616,,,,,,,,,,,,,, -44200,1.0270516,3.7034793,,,,,,,,,,,,,, -44300,1.0517502,3.8467312,,,,,,,,,,,,,, -44400,0.9834008,3.8434021,,,,,,,,,,,,,, -44422,,,0.3873632848262787,2.8844244480133057,0.3587599992752075,3.038835287094116,50000.0,0.2747000157833099,3.624547004699707,10000.0,20623.665967941284,23187.171246290207,20623.665967941284,2558.9942378997803,1.9841821193695068,0.0 -44500,0.80121917,5.342429,,,,,,,,,,,,,, -44600,0.939056,5.09997,,,,,,,,,,,,,, -44700,1.0028893,3.8138564,,,,,,,,,,,,,, -44800,0.84628445,3.9038358,,,,,,,,,,,,,, -44900,1.0680487,3.8665752,,,,,,,,,,,,,, -45000,0.9724786,3.8938618,,,,,,,,,,,,,, -45100,1.1410156,5.3142967,,,,,,,,,,,,,, -45200,0.9178935,3.8033025,,,,,,,,,,,,,, -45300,0.9905579,3.7907941,,,,,,,,,,,,,, -45332,,,0.4272656142711639,2.607675313949585,0.367059975862503,2.9391558170318604,50000.0,0.2851999998092651,3.533428192138672,10000.0,21043.75366783142,23657.901322603226,21043.75366783142,2609.552106142044,2.016412734985352,0.0 -45400,1.0021248,4.529626,,,,,,,,,,,,,, -45500,0.8262636,4.74198,,,,,,,,,,,,,, -45600,1.0660173,3.6499212,,,,,,,,,,,,,, -45700,1.0604395,3.7975705,,,,,,,,,,,,,, -45800,0.92735404,3.7526152,,,,,,,,,,,,,, -45900,1.0605363,5.840126,,,,,,,,,,,,,, -46000,1.0851188,4.0408278,,,,,,,,,,,,,, -46100,1.1138409,3.8527653,,,,,,,,,,,,,, -46200,0.97317153,4.486171,,,,,,,,,,,,,, -46241,,,0.3944140672683716,2.8179240226745605,0.3656199872493744,2.9761483669281006,50000.0,0.2811000049114227,3.5713536739349365,10000.0,21463.72553873062,24131.569739818573,21463.72553873062,2663.1634533405304,2.049370527267456,0.0 -46300,1.0347939,3.879844,,,,,,,,,,,,,, -46400,0.9535312,3.9506164,,,,,,,,,,,,,, -46500,1.1740544,3.6621926,,,,,,,,,,,,,, -46600,1.0760173,3.7474957,,,,,,,,,,,,,, -46700,0.82150936,4.4703207,,,,,,,,,,,,,, -46800,1.0022551,3.8610435,,,,,,,,,,,,,, -46900,1.0485842,4.476443,,,,,,,,,,,,,, -47000,0.8943859,5.285697,,,,,,,,,,,,,, -47100,0.92336863,4.64664,,,,,,,,,,,,,, -47149,,,0.4058789014816284,2.753173828125,0.374239981174469,2.925992965698242,50000.0,0.2869000136852264,3.5421478748321533,10000.0,21884.07934308052,24604.291342496872,21884.07934308052,2715.443905115128,2.0853078365325928,0.0 -47200,1.1266483,3.7018156,,,,,,,,,,,,,, -47300,1.2126938,3.8042674,,,,,,,,,,,,,, -47400,1.1553583,3.6821947,,,,,,,,,,,,,, -47500,1.1065867,3.8228128,,,,,,,,,,,,,, -47600,1.128139,4.2096133,,,,,,,,,,,,,, -47700,0.77578986,4.5466146,,,,,,,,,,,,,, -47800,0.98437744,3.7421641,,,,,,,,,,,,,, -47900,1.0105948,3.8067021,,,,,,,,,,,,,, -48000,1.1801978,4.096107,,,,,,,,,,,,,, -48055,,,0.4291796684265136,2.6209170818328857,0.3867599964141845,2.860398054122925,50000.0,0.2960000038146972,3.4566426277160645,10000.0,22304.013244628903,25076.32378435135,22304.013244628903,2767.4498670101166,2.1263139247894287,0.0 -48100,0.9642741,3.7842422,,,,,,,,,,,,,, -48200,0.9657973,3.7870436,,,,,,,,,,,,,, -48300,1.0098704,3.734543,,,,,,,,,,,,,, -48400,1.0909147,3.660275,,,,,,,,,,,,,, -48500,0.8736434,5.6108623,,,,,,,,,,,,,, -48600,0.86657184,4.3683314,,,,,,,,,,,,,, -48700,1.0379109,5.4880037,,,,,,,,,,,,,, -48800,1.3872665,3.8471792,,,,,,,,,,,,,, -48900,1.2742646,3.6827116,,,,,,,,,,,,,, -48963,,,0.4155077934265136,2.642244815826416,0.3876399993896484,2.803826093673706,50000.0,0.3047000169754028,3.414707899093628,10000.0,22724.27734947205,25547.15549659729,22724.27734947205,2817.928693294525,2.163120985031128,0.0 -49000,0.989007,5.9464474,,,,,,,,,,,,,, -49100,1.0908117,3.7435336,,,,,,,,,,,,,, -49200,0.9659694,3.743392,,,,,,,,,,,,,, -49300,0.9626374,4.773425,,,,,,,,,,,,,, -49400,0.9491711,3.737818,,,,,,,,,,,,,, -49500,0.92269206,3.6726983,,,,,,,,,,,,,, -49600,1.0321399,3.848784,,,,,,,,,,,,,, -49700,0.97830814,3.7073753,,,,,,,,,,,,,, -49800,0.9610203,3.8303483,,,,,,,,,,,,,, -49872,,,0.4216601550579071,2.622653722763061,0.3941999971866607,2.784731388092041,50000.0,0.3007000088691711,3.3947463035583496,10000.0,23144.547844409943,26019.239842414856,23144.547844409943,2869.654707431793,2.1987524032592773,0.0 -49900,0.8666425,5.8795595,,,,,,,,,,,,,, -50000,0.73812354,5.814189,,,,,,,,,,,,,, -50100,1.1558537,3.7524245,,,,,,,,,,,,,, -50200,1.1465702,3.7566614,,,,,,,,,,,,,, -50300,1.3030723,3.780607,,,,,,,,,,,,,, -50400,0.78496677,5.4899664,,,,,,,,,,,,,, -50500,1.1433583,3.7439227,,,,,,,,,,,,,, -50600,0.97582054,3.599656,,,,,,,,,,,,,, -50700,0.84021777,4.2916937,,,,,,,,,,,,,, -50782,,,0.431640625,2.567415475845337,0.3929999768733978,2.786828756332397,50000.0,0.3039000034332275,3.3923990726470947,10000.0,23564.829872846603,26492.41095542908,23564.829872846603,2922.45498919487,2.235773801803589,0.0 -50800,0.807832,4.9146867,,,,,,,,,,,,,, -50900,1.071164,3.809395,,,,,,,,,,,,,, -51000,1.0995125,3.7357569,,,,,,,,,,,,,, -51100,0.96360695,6.0171304,,,,,,,,,,,,,, -51200,1.0096806,3.5845795,,,,,,,,,,,,,, -51300,0.9273627,3.7710867,,,,,,,,,,,,,, -51400,1.3177595,3.724824,,,,,,,,,,,,,, -51500,1.0215428,3.7648354,,,,,,,,,,,,,, -51600,0.8573319,4.048172,,,,,,,,,,,,,, -51689,,,0.4244335889816284,2.612565279006958,0.3941600024700165,2.762552499771118,50000.0,0.3032000064849853,3.374166250228882,10000.0,23985.074984312057,26965.50780677796,23985.074984312057,2975.215850353241,2.2752928733825684,0.0 -51700,1.1728947,3.754836,,,,,,,,,,,,,, -51800,0.8663367,5.977885,,,,,,,,,,,,,, -51900,0.8983907,4.000354,,,,,,,,,,,,,, -52000,0.97949636,4.3266845,,,,,,,,,,,,,, -52100,1.0648385,3.5822134,,,,,,,,,,,,,, -52200,0.84177107,3.6137393,,,,,,,,,,,,,, -52300,0.97701216,4.5720935,,,,,,,,,,,,,, -52400,1.2539884,3.7248063,,,,,,,,,,,,,, -52500,1.122442,3.6888552,,,,,,,,,,,,,, -52597,,,0.4292578101158142,2.5953733921051025,0.3981399834156036,2.7684292793273926,50000.0,0.3110000193119049,3.386230230331421,10000.0,24405.261734247208,27436.437440156937,24405.261734247208,3025.862987279892,2.319564819335937,0.0 -52600,1.1078861,5.856307,,,,,,,,,,,,,, -52700,0.96351063,5.8370733,,,,,,,,,,,,,, -52800,0.99549854,3.7786818,,,,,,,,,,,,,, -52900,0.88669205,4.1564083,,,,,,,,,,,,,, -53000,1.2676708,3.6892157,,,,,,,,,,,,,, -53100,0.9953294,3.7029312,,,,,,,,,,,,,, -53200,1.012179,3.6147645,,,,,,,,,,,,,, -53300,1.198094,3.7760453,,,,,,,,,,,,,, -53400,1.2475686,3.6876373,,,,,,,,,,,,,, -53500,0.9236135,3.4769824,,,,,,,,,,,,,, -53503,,,0.4274999797344208,2.5975615978240967,0.3849999904632568,2.821701765060425,50000.0,0.2998000085353851,3.433039903640747,10000.0,24825.304622650143,27908.778692007065,24825.304622650143,3078.066597223282,2.363283157348633,0.0 -53600,1.0429574,3.6945102,,,,,,,,,,,,,, -53700,1.170374,3.6833622,,,,,,,,,,,,,, -53800,1.0840468,3.7629972,,,,,,,,,,,,,, -53900,0.9258058,3.764538,,,,,,,,,,,,,, -54000,0.9568544,4.4574437,,,,,,,,,,,,,, -54100,1.0623058,4.5695424,,,,,,,,,,,,,, -54200,0.980408,3.7764988,,,,,,,,,,,,,, -54300,0.9817742,3.9270658,,,,,,,,,,,,,, -54400,0.9806343,3.753491,,,,,,,,,,,,,, -54413,,,0.4327929615974426,2.579705238342285,0.3990999758243561,2.741082191467285,50000.0,0.3037000000476837,3.377845764160156,10000.0,25245.264453172684,28380.64738035202,25245.264453172684,3129.888496398926,2.3979485034942627,0.0 -54500,1.0042844,3.5331254,,,,,,,,,,,,,, -54600,0.85461533,4.3594055,,,,,,,,,,,,,, -54700,1.0629574,3.6128564,,,,,,,,,,,,,, -54800,1.1048707,3.8007357,,,,,,,,,,,,,, -54900,0.8248329,4.5544496,,,,,,,,,,,,,, -55000,1.0121043,3.643569,,,,,,,,,,,,,, -55100,1.0768178,5.7631645,,,,,,,,,,,,,, -55200,0.8999737,4.311295,,,,,,,,,,,,,, -55300,0.8416051,5.7690287,,,,,,,,,,,,,, -55323,,,0.4314843714237213,2.5818235874176025,0.4027799963951111,2.7566277980804443,50000.0,0.3120000064373016,3.340991973876953,10000.0,25665.59147167205,28852.82416653633,25665.59147167205,3181.6526210308075,2.4328207969665527,0.0 -55400,1.2080299,3.7265933,,,,,,,,,,,,,, -55500,0.79666644,4.5551276,,,,,,,,,,,,,, -55600,0.8222529,3.9535167,,,,,,,,,,,,,, -55700,1.0367793,3.6713126,,,,,,,,,,,,,, -55800,1.0665609,3.5959382,,,,,,,,,,,,,, -55900,1.2040279,3.6726909,,,,,,,,,,,,,, -56000,0.8622191,4.2390285,,,,,,,,,,,,,, -56100,0.9939912,3.7535558,,,,,,,,,,,,,, -56200,0.9187557,3.9264278,,,,,,,,,,,,,, -56233,,,0.4399023354053497,2.5200986862182617,0.407260000705719,2.699357986450196,50000.0,0.3122000098228454,3.344634532928467,10000.0,26085.61327648163,29324.68583965301,26085.61327648163,3233.405528306961,2.467010736465454,0.0 -56300,1.0246748,3.7252302,,,,,,,,,,,,,, -56400,1.0848523,3.7138753,,,,,,,,,,,,,, -56500,1.0562117,3.6245122,,,,,,,,,,,,,, -56600,0.80991626,5.2906723,,,,,,,,,,,,,, -56700,0.9951647,3.9504614,,,,,,,,,,,,,, -56800,1.0599184,4.130688,,,,,,,,,,,,,, -56900,0.83868974,5.703293,,,,,,,,,,,,,, -57000,1.0006183,3.619731,,,,,,,,,,,,,, -57100,0.92518985,4.8178287,,,,,,,,,,,,,, -57145,,,0.4365624785423279,2.546628475189209,0.4060999751091003,2.6989681720733643,50000.0,0.3115000128746032,3.323418140411377,10000.0,26505.705159902573,29796.24742078781,26505.705159902573,3284.786096572876,2.5048763751983643,0.0 -57200,1.1811497,3.6591654,,,,,,,,,,,,,, -57300,0.9629001,3.6764843,,,,,,,,,,,,,, -57400,1.2019964,3.498268,,,,,,,,,,,,,, -57500,1.0561383,3.7988558,,,,,,,,,,,,,, -57600,0.9337949,5.7252464,,,,,,,,,,,,,, -57700,1.0043824,3.6366694,,,,,,,,,,,,,, -57800,1.0419765,3.6095273,,,,,,,,,,,,,, -57900,0.82974786,5.131116,,,,,,,,,,,,,, -58000,0.79385823,4.8764763,,,,,,,,,,,,,, -58056,,,0.4286132752895355,2.626063108444214,0.4068599939346313,2.7548227310180664,50000.0,0.3158000111579895,3.3459408283233643,10000.0,26925.716034412384,30267.275554418564,26925.716034412384,3335.712982416153,2.543241739273072,0.0 -58100,1.0053997,3.6598501,,,,,,,,,,,,,, -58200,1.2474834,3.6470683,,,,,,,,,,,,,, -58300,1.0219954,3.546163,,,,,,,,,,,,,, -58400,1.0078764,4.818644,,,,,,,,,,,,,, -58500,0.9626546,3.4761102,,,,,,,,,,,,,, -58600,1.2080029,3.6908047,,,,,,,,,,,,,, -58700,0.96133816,4.8083572,,,,,,,,,,,,,, -58800,1.0766907,3.9308705,,,,,,,,,,,,,, -58900,1.1219809,3.6599574,,,,,,,,,,,,,, -58965,,,0.4401562511920929,2.5463802814483643,0.4059799909591675,2.7379801273345947,50000.0,0.3166000247001648,3.356484651565552,10000.0,27345.927248716354,30740.20338773728,27345.927248716354,3388.341548681259,2.5799450874328613,0.0 -59000,0.89881223,4.7944646,,,,,,,,,,,,,, -59100,1.0313746,5.302816,,,,,,,,,,,,,, -59200,1.1607361,3.4917896,,,,,,,,,,,,,, -59300,0.96163744,6.0323415,,,,,,,,,,,,,, -59400,0.9834349,3.387684,,,,,,,,,,,,,, -59500,1.1699135,3.6063654,,,,,,,,,,,,,, -59600,0.79714,5.6723638,,,,,,,,,,,,,, -59700,1.0945967,3.6146824,,,,,,,,,,,,,, -59800,1.1013569,4.010791,,,,,,,,,,,,,, -59875,,,0.4756445288658142,2.349417924880981,0.4097599983215332,2.690446376800537,50000.0,0.3135000169277191,3.311406373977661,10000.0,27766.268231868744,31214.8779771328,27766.268231868744,3442.583624601364,2.620101928710937,0.0 -59900,0.9378393,4.8397007,,,,,,,,,,,,,, -60000,1.1034068,3.6218133,,,,,,,,,,,,,, -60100,1.2107141,3.6323662,,,,,,,,,,,,,, -60200,0.9859821,3.5247824,,,,,,,,,,,,,, -60300,1.3890625,3.5487964,,,,,,,,,,,,,, -60400,0.7319481,5.0812216,,,,,,,,,,,,,, -60500,0.9493938,3.4218688,,,,,,,,,,,,,, -60600,0.76948136,5.8522477,,,,,,,,,,,,,, -60700,0.7966986,5.590809,,,,,,,,,,,,,, -60784,,,0.440253883600235,2.530428647994995,0.4153399765491485,2.6711795330047607,50000.0,0.3251000046730041,3.2931430339813232,10000.0,28186.384392499924,31687.30434346199,28186.384392499924,3494.8005118370056,2.661822557449341,0.0 -60800,0.8274644,4.831905,,,,,,,,,,,,,, -60900,1.0436844,3.686208,,,,,,,,,,,,,, -61000,1.0597941,3.5447764,,,,,,,,,,,,,, -61100,1.0547117,3.4936616,,,,,,,,,,,,,, -61200,1.0765785,3.5725014,,,,,,,,,,,,,, -61300,1.1910673,3.5131297,,,,,,,,,,,,,, -61400,0.8696191,5.800535,,,,,,,,,,,,,, -61500,0.6913131,5.5703425,,,,,,,,,,,,,, -61600,1.0720959,3.4820702,,,,,,,,,,,,,, -61693,,,0.4561132788658142,2.4191908836364746,0.4205799996852875,2.6124274730682373,50000.0,0.3247000277042389,3.2571167945861816,10000.0,28606.57464933396,32160.68154025078,28606.57464933396,3547.8963441848755,2.700721263885498,0.0 -61700,0.67770237,5.7078843,,,,,,,,,,,,,, -61800,0.9933947,3.8396194,,,,,,,,,,,,,, -61900,0.90665776,4.7928095,,,,,,,,,,,,,, -62000,1.0289315,3.4401793,,,,,,,,,,,,,, -62100,0.9918754,3.5014927,,,,,,,,,,,,,, -62200,1.0342584,3.9207778,,,,,,,,,,,,,, -62300,0.8777171,4.8666263,,,,,,,,,,,,,, -62400,0.8571685,3.9760437,,,,,,,,,,,,,, -62500,1.0369605,4.4735923,,,,,,,,,,,,,, -62600,1.0198208,3.405691,,,,,,,,,,,,,, -62604,,,0.4695117175579071,2.3920223712921143,0.4214800000190735,2.663222074508667,50000.0,0.3195000290870666,3.280048131942749,10000.0,29026.54181861877,32634.209529399872,29026.54181861877,3601.364328861237,2.742002010345459,0.0 -62700,0.85922295,5.7816124,,,,,,,,,,,,,, -62800,0.8783605,3.9518018,,,,,,,,,,,,,, -62900,1.1946006,3.5093052,,,,,,,,,,,,,, -63000,1.0147609,3.5786097,,,,,,,,,,,,,, -63100,1.1980349,3.645339,,,,,,,,,,,,,, -63200,1.0768654,3.5599992,,,,,,,,,,,,,, -63300,1.048566,4.1745253,,,,,,,,,,,,,, -63400,0.77860713,5.7355533,,,,,,,,,,,,,, -63500,1.1041951,4.925344,,,,,,,,,,,,,, -63516,,,0.4356835782527923,2.5817770957946777,0.4103799760341644,2.7430014610290527,50000.0,0.3118000030517578,3.331474781036377,10000.0,29446.83342576027,33106.452523231506,29446.83342576027,3653.226641178131,2.778369903564453,0.0 -63600,1.0729209,3.4191186,,,,,,,,,,,,,, -63700,0.9090314,4.4631624,,,,,,,,,,,,,, -63800,1.3228341,3.542609,,,,,,,,,,,,,, -63900,0.8398748,5.070628,,,,,,,,,,,,,, -64000,1.147536,3.4740825,,,,,,,,,,,,,, -64100,0.95855016,3.717854,,,,,,,,,,,,,, -64200,1.0595392,3.344551,,,,,,,,,,,,,, -64300,1.0727996,3.5901945,,,,,,,,,,,,,, -64400,1.0421945,3.376985,,,,,,,,,,,,,, -64423,,,0.4575585722923279,2.4170541763305664,0.4262599945068359,2.5896904468536377,50000.0,0.3339000046253204,3.209775447845459,10000.0,29867.208926677704,33581.21010494232,29867.208926677704,3707.521521806717,2.814645767211914,0.0 -64500,1.0363541,3.5050519,,,,,,,,,,,,,, -64600,1.1137607,3.53423,,,,,,,,,,,,,, -64700,0.9716464,3.4662929,,,,,,,,,,,,,, -64800,1.0041453,3.419908,,,,,,,,,,,,,, -64900,1.0322992,4.2816825,,,,,,,,,,,,,, -65000,0.89294004,5.889802,,,,,,,,,,,,,, -65100,0.7844119,5.0261025,,,,,,,,,,,,,, -65200,0.8085567,5.242818,,,,,,,,,,,,,, -65300,0.9396445,3.7776566,,,,,,,,,,,,,, -65330,,,0.4688476324081421,2.36328673362732,0.4254999756813049,2.597842216491699,50000.0,0.3272000253200531,3.247157573699951,10000.0,30287.316556453705,34054.11939907074,30287.316556453705,3760.2291600704193,2.8578405380249023,0.0 -65400,1.1518122,3.6638389,,,,,,,,,,,,,, -65500,1.1236748,3.6092548,,,,,,,,,,,,,, -65600,0.91522706,3.844605,,,,,,,,,,,,,, -65700,0.98599195,3.6503806,,,,,,,,,,,,,, -65800,1.1986535,3.623131,,,,,,,,,,,,,, -65900,1.0137955,3.445849,,,,,,,,,,,,,, -66000,1.098779,3.617025,,,,,,,,,,,,,, -66100,1.1568918,3.6224616,,,,,,,,,,,,,, -66200,0.94837284,4.4926248,,,,,,,,,,,,,, -66240,,,0.4552929699420929,2.4185917377471924,0.4275999963283539,2.574275016784668,50000.0,0.3324000239372253,3.1866111755371094,10000.0,30707.30086231232,34526.12924003601,30707.30086231232,3812.165778875351,2.89416766166687,0.0 -66300,1.0454686,3.5979662,,,,,,,,,,,,,, -66400,0.89655894,5.1062055,,,,,,,,,,,,,, -66500,0.88853145,5.8168335,,,,,,,,,,,,,, -66600,1.056545,3.6100998,,,,,,,,,,,,,, -66700,0.97053975,5.8401437,,,,,,,,,,,,,, -66800,1.3390489,3.4535863,,,,,,,,,,,,,, -66900,1.3336029,3.8374515,,,,,,,,,,,,,, -67000,1.0021939,3.833844,,,,,,,,,,,,,, -67100,0.9825842,4.990303,,,,,,,,,,,,,, -67149,,,0.4733007848262787,2.328538417816162,0.4402399957180023,2.4984281063079834,50000.0,0.332800030708313,3.1850574016571045,10000.0,31127.646122694016,35001.00862288475,31127.646122694016,3866.608591794968,2.933951139450073,0.0 -67200,0.8294815,5.779514,,,,,,,,,,,,,, -67300,1.1585301,3.6002498,,,,,,,,,,,,,, -67400,1.0451988,3.4313862,,,,,,,,,,,,,, -67500,1.027962,3.3573346,,,,,,,,,,,,,, -67600,0.83220583,4.5873656,,,,,,,,,,,,,, -67700,0.94751483,4.3213687,,,,,,,,,,,,,, -67800,0.91431975,5.177264,,,,,,,,,,,,,, -67900,0.9855752,3.4191122,,,,,,,,,,,,,, -68000,1.0426587,4.070302,,,,,,,,,,,,,, -68058,,,0.4591992199420929,2.4289512634277344,0.4236399829387665,2.623819351196289,50000.0,0.333400011062622,3.2271764278411865,10000.0,31547.65611815453,35474.00371336937,31547.65611815453,3919.500876426697,2.9747848510742188,0.0 -68100,1.1578263,3.3966126,,,,,,,,,,,,,, -68200,0.9609248,4.3077483,,,,,,,,,,,,,, -68300,1.0444996,3.2639205,,,,,,,,,,,,,, -68400,1.0951431,3.6176174,,,,,,,,,,,,,, -68500,1.0578762,3.7723107,,,,,,,,,,,,,, -68600,1.2112204,3.9451218,,,,,,,,,,,,,, -68700,0.83058596,5.170052,,,,,,,,,,,,,, -68800,1.0215822,3.8171725,,,,,,,,,,,,,, -68900,0.77056545,5.7491255,,,,,,,,,,,,,, -68967,,,0.4648632705211639,2.3892576694488525,0.4384599924087524,2.5289251804351807,50000.0,0.3456000089645386,3.1654460430145264,10000.0,31967.69016432762,35946.450082063675,31967.69016432762,3971.823595046997,3.0140380859375,0.0 -69000,1.0592501,3.536203,,,,,,,,,,,,,, -69100,0.8138489,4.368767,,,,,,,,,,,,,, -69200,1.0259554,3.402348,,,,,,,,,,,,,, -69300,1.204397,3.5174122,,,,,,,,,,,,,, -69400,1.1455709,3.5955465,,,,,,,,,,,,,, -69500,1.3096839,3.463017,,,,,,,,,,,,,, -69600,0.83626693,5.769249,,,,,,,,,,,,,, -69700,1.1709163,3.4984345,,,,,,,,,,,,,, -69800,1.1804332,3.3137865,,,,,,,,,,,,,, -69875,,,0.4686718583106994,2.3740696907043457,0.4352200031280517,2.5472748279571533,50000.0,0.3344000279903412,3.193197011947632,10000.0,32388.05753827095,36421.328429460526,32388.05753827095,4026.2420842647552,3.0550498962402344,0.0 -69900,0.9793054,3.499519,,,,,,,,,,,,,, -70000,1.1485518,3.3648567,,,,,,,,,,,,,, -70100,1.325463,3.735176,,,,,,,,,,,,,, -70200,1.2457818,4.083773,,,,,,,,,,,,,, -70300,1.1099508,5.4842134,,,,,,,,,,,,,, -70400,1.011757,3.5312753,,,,,,,,,,,,,, -70500,0.8499919,5.726365,,,,,,,,,,,,,, -70600,1.0156537,3.6548204,,,,,,,,,,,,,, -70700,0.87148154,4.040946,,,,,,,,,,,,,, -70782,,,0.4740038812160492,2.324805974960327,0.4377000033855438,2.5263214111328125,50000.0,0.3368000090122223,3.184787034988404,10000.0,32808.10027337074,36893.90516161919,32808.10027337074,4078.6816279888153,3.0984416007995605,0.0 -70800,1.0612398,3.386038,,,,,,,,,,,,,, -70900,0.95173186,3.5053616,,,,,,,,,,,,,, -71000,1.2887851,4.1775713,,,,,,,,,,,,,, -71100,0.8342031,5.2354555,,,,,,,,,,,,,, -71200,1.3981326,3.5602767,,,,,,,,,,,,,, -71300,1.0659176,3.53726,,,,,,,,,,,,,, -71400,1.0860397,3.5821276,,,,,,,,,,,,,, -71500,0.99514,3.4354806,,,,,,,,,,,,,, -71600,1.0461631,5.858356,,,,,,,,,,,,,, -71692,,,0.4591406285762787,2.428839921951294,0.4280000030994415,2.5926687717437744,50000.0,0.3339000046253204,3.219151020050049,10000.0,33228.35967874527,37367.487449646,33228.35967874527,4131.9138832092285,3.138212442398072,0.0 -71700,1.0965898,3.6605337,,,,,,,,,,,,,, -71800,1.402975,3.4185019,,,,,,,,,,,,,, -71900,1.1321541,3.6667438,,,,,,,,,,,,,, -72000,1.0652462,3.4381812,,,,,,,,,,,,,, -72100,1.0241985,3.435002,,,,,,,,,,,,,, -72200,1.2705258,3.4276,,,,,,,,,,,,,, -72300,0.8677074,3.8802426,,,,,,,,,,,,,, -72400,0.96237224,5.351529,,,,,,,,,,,,,, -72500,0.9912396,3.3476858,,,,,,,,,,,,,, -72600,1.0616559,3.6128564,,,,,,,,,,,,,, -72601,,,0.4733984172344208,2.335913896560669,0.442220002412796,2.501166582107544,50000.0,0.3451000154018402,3.1486642360687256,10000.0,33648.43023991585,37842.20454144478,33648.43023991585,4186.46707201004,3.180236339569092,0.0 -72700,1.3057541,3.470587,,,,,,,,,,,,,, -72800,1.3831582,3.940005,,,,,,,,,,,,,, -72900,0.9731852,3.2913284,,,,,,,,,,,,,, -73000,1.148952,3.6011076,,,,,,,,,,,,,, -73100,0.8214449,4.2366056,,,,,,,,,,,,,, -73200,1.079286,5.1100802,,,,,,,,,,,,,, -73300,0.69042504,5.6595774,,,,,,,,,,,,,, -73400,1.1781003,3.33223,,,,,,,,,,,,,, -73500,1.2475622,3.4053333,,,,,,,,,,,,,, -73512,,,0.4789257645606994,2.335693359375,0.4441199898719787,2.5147647857666016,50000.0,0.34170001745224,3.1700222492218018,10000.0,34068.49537944794,38316.67360329628,34068.49537944794,4240.780715942383,3.2186429500579834,0.0 -73600,1.4230036,3.667636,,,,,,,,,,,,,, -73700,1.0030687,3.5612383,,,,,,,,,,,,,, -73800,1.1793293,3.2964892,,,,,,,,,,,,,, -73900,1.1530055,3.382042,,,,,,,,,,,,,, -74000,1.5550823,3.4079218,,,,,,,,,,,,,, -74100,1.0846846,3.573311,,,,,,,,,,,,,, -74200,0.7710095,4.447818,,,,,,,,,,,,,, -74300,1.0536231,3.8781092,,,,,,,,,,,,,, -74400,0.92534816,3.6486518,,,,,,,,,,,,,, -74422,,,0.514941394329071,2.1037821769714355,0.456279993057251,2.4261739253997803,50000.0,0.3525000214576721,3.0565521717071533,10000.0,34488.55130815506,38787.65641450882,34488.55130815506,4291.616386175156,3.258612632751465,0.0 -74500,1.2003647,3.4750774,,,,,,,,,,,,,, -74600,1.2470775,3.4957194,,,,,,,,,,,,,, -74700,1.1809078,3.5098062,,,,,,,,,,,,,, -74800,0.9756389,4.206628,,,,,,,,,,,,,, -74900,1.1696863,3.2998528,,,,,,,,,,,,,, -75000,1.3474892,3.3254747,,,,,,,,,,,,,, -75100,1.0696745,3.8382764,,,,,,,,,,,,,, -75200,1.0668856,3.4758525,,,,,,,,,,,,,, -75300,1.0609812,4.17274,,,,,,,,,,,,,, -75331,,,0.4822070300579071,2.265528440475464,0.4532800018787384,2.43257212638855,50000.0,0.3541000187397003,3.067671537399292,10000.0,34908.481731414795,39258.649804115295,34908.481731414795,4342.585709571838,3.3010001182556152,0.0 -75400,1.0590978,5.702571,,,,,,,,,,,,,, -75500,0.9250024,3.694728,,,,,,,,,,,,,, -75600,0.96474385,5.0656967,,,,,,,,,,,,,, -75700,1.0863663,3.3478556,,,,,,,,,,,,,, -75800,1.086783,3.2906084,,,,,,,,,,,,,, -75900,0.9363647,4.2895107,,,,,,,,,,,,,, -76000,1.1228471,3.541594,,,,,,,,,,,,,, -76100,1.1597377,3.3564897,,,,,,,,,,,,,, -76200,1.2109914,3.3752117,,,,,,,,,,,,,, -76238,,,0.4829687476158142,2.3049960136413574,0.4452399909496307,2.4958457946777344,50000.0,0.3431000113487243,3.140433549880981,10000.0,35328.56245660782,39728.39041757584,35328.56245660782,4392.15483045578,3.3405954837799072,0.0 -76300,0.78918356,5.104451,,,,,,,,,,,,,, -76400,1.0530356,3.316263,,,,,,,,,,,,,, -76500,1.4180611,3.3963485,,,,,,,,,,,,,, -76600,1.0779383,4.1246505,,,,,,,,,,,,,, -76700,1.0932246,3.4249043,,,,,,,,,,,,,, -76800,1.0206481,3.760538,,,,,,,,,,,,,, -76900,0.9040343,4.837856,,,,,,,,,,,,,, -77000,1.1401123,3.488566,,,,,,,,,,,,,, -77100,1.1424831,3.7693305,,,,,,,,,,,,,, -77145,,,0.5133984088897705,2.1134865283966064,0.4520799815654754,2.433824062347412,50000.0,0.3474000096321106,3.1019248962402344,10000.0,35748.48908209801,40200.3743493557,35748.48908209801,4444.11982011795,3.381718873977661,0.0 -77200,1.2473845,3.4659932,,,,,,,,,,,,,, -77300,1.0466313,3.5767086,,,,,,,,,,,,,, -77400,1.0392575,5.7490597,,,,,,,,,,,,,, -77500,0.9536909,3.5207553,,,,,,,,,,,,,, -77600,1.1095079,3.6079555,,,,,,,,,,,,,, -77700,1.0343307,3.4022372,,,,,,,,,,,,,, -77800,0.92269105,4.9414454,,,,,,,,,,,,,, -77900,0.9874272,5.6682305,,,,,,,,,,,,,, -78000,0.98851126,3.0828998,,,,,,,,,,,,,, -78051,,,0.4828515648841858,2.2945902347564697,0.4523399770259857,2.466150522232056,50000.0,0.3499000072479248,3.089308261871338,10000.0,36168.59254050255,40671.88909912109,36168.59254050255,4495.4370748996735,3.4243533611297607,0.0 -78100,0.9310939,4.3347263,,,,,,,,,,,,,, -78200,1.1855999,3.3832703,,,,,,,,,,,,,, -78300,1.1069721,3.3151822,,,,,,,,,,,,,, -78400,1.1445735,3.2863247,,,,,,,,,,,,,, -78500,1.1495128,3.2698207,,,,,,,,,,,,,, -78600,1.2505213,3.3834112,,,,,,,,,,,,,, -78700,1.1719338,3.3902764,,,,,,,,,,,,,, -78800,1.1993561,3.3655174,,,,,,,,,,,,,, -78900,0.9751034,5.680029,,,,,,,,,,,,,, -78960,,,0.4942382574081421,2.219955921173096,0.4580599963665008,2.418983459472656,50000.0,0.3455000221729278,3.0912413597106934,10000.0,36588.80544137955,41143.55085515976,36588.80544137955,4546.794453859329,3.463703393936157,0.0 -79000,1.0447694,5.7742243,,,,,,,,,,,,,, -79100,1.0372475,3.2948384,,,,,,,,,,,,,, -79200,1.044496,3.493214,,,,,,,,,,,,,, -79300,0.7851738,5.110824,,,,,,,,,,,,,, -79400,1.074257,3.252267,,,,,,,,,,,,,, -79500,1.0056316,5.7656403,,,,,,,,,,,,,, -79600,0.91838646,5.4665375,,,,,,,,,,,,,, -79700,0.8970746,5.4377527,,,,,,,,,,,,,, -79800,1.1229697,3.5248542,,,,,,,,,,,,,, -79864,,,0.5013476610183716,2.2078988552093506,0.4576599895954132,2.443406581878662,50000.0,0.3482000231742859,3.0952043533325195,10000.0,37008.87369513512,41617.35136270523,37008.87369513512,4600.427646636963,3.510913848876953,0.0 -79900,1.099128,3.4711785,,,,,,,,,,,,,, -80000,1.2510136,3.5355604,,,,,,,,,,,,,, -80100,0.8591544,4.2864466,,,,,,,,,,,,,, -80200,1.167224,3.3545227,,,,,,,,,,,,,, -80300,1.3313533,3.4535491,,,,,,,,,,,,,, -80400,1.1225319,3.6635222,,,,,,,,,,,,,, -80500,0.9460504,4.698856,,,,,,,,,,,,,, -80600,1.0398456,3.258898,,,,,,,,,,,,,, -80700,1.2773013,3.2825742,,,,,,,,,,,,,, -80773,,,0.4915429651737213,2.2458813190460205,0.4599799811840057,2.409734010696411,50000.0,0.3550000190734863,3.052574396133423,10000.0,37429.11387252808,42090.65738582611,37429.11387252808,4653.40140414238,3.551713705062866,0.0 -80800,0.9597548,5.1458893,,,,,,,,,,,,,, -80900,1.2590345,3.3674784,,,,,,,,,,,,,, -81000,1.0135185,3.610643,,,,,,,,,,,,,, -81100,0.99266326,3.6470566,,,,,,,,,,,,,, -81200,0.88211566,4.9660625,,,,,,,,,,,,,, -81300,1.0894555,3.294958,,,,,,,,,,,,,, -81400,1.1469464,3.2143435,,,,,,,,,,,,,, -81500,1.0041757,3.7738721,,,,,,,,,,,,,, -81600,0.97982323,3.2891479,,,,,,,,,,,,,, -81681,,,0.4970312416553497,2.2210443019866943,0.4611999988555908,2.407777547836304,50000.0,0.3586000204086303,3.044772148132324,10000.0,37849.28627562523,42562.85619473457,37849.28627562523,4705.331708192825,3.597206354141236,0.0 -81700,1.0138528,4.0915136,,,,,,,,,,,,,, -81800,1.1939626,3.3218281,,,,,,,,,,,,,, -81900,0.873145,5.322836,,,,,,,,,,,,,, -82000,1.052114,3.1330576,,,,,,,,,,,,,, -82100,1.2053428,3.3172908,,,,,,,,,,,,,, -82200,1.0824218,3.316158,,,,,,,,,,,,,, -82300,1.0216326,4.3684173,,,,,,,,,,,,,, -82400,1.0893548,3.3885791,,,,,,,,,,,,,, -82500,1.0578948,4.8144255,,,,,,,,,,,,,, -82586,,,0.5177343487739563,2.106281042098999,0.4700599908828735,2.348548889160156,50000.0,0.3614000082015991,3.003753185272217,10000.0,38269.44111919403,43033.84465241432,38269.44111919403,4756.069163322449,3.641781806945801,0.0 -82600,1.187295,3.296026,,,,,,,,,,,,,, -82700,1.0760309,3.8674066,,,,,,,,,,,,,, -82800,0.85945845,5.654274,,,,,,,,,,,,,, -82900,1.2027988,3.6427407,,,,,,,,,,,,,, -83000,0.9626986,4.726523,,,,,,,,,,,,,, -83100,1.1549717,3.3878367,,,,,,,,,,,,,, -83200,0.9927665,3.5781457,,,,,,,,,,,,,, -83300,1.2021866,3.1705751,,,,,,,,,,,,,, -83400,1.157727,3.3430698,,,,,,,,,,,,,, -83497,,,0.4992382824420929,2.189939260482788,0.4672999978065491,2.3583219051361084,50000.0,0.3677000105381012,2.9964702129364014,10000.0,38689.66334319115,43508.62535381317,38689.66334319115,4810.5310571193695,3.686613321304321,0.0 -83500,0.835596,4.4495745,,,,,,,,,,,,,, -83600,1.3389225,3.3656802,,,,,,,,,,,,,, -83700,1.1090753,3.2303057,,,,,,,,,,,,,, -83800,0.83675253,5.2212615,,,,,,,,,,,,,, -83900,0.9048282,4.1830554,,,,,,,,,,,,,, -84000,1.0410115,3.395811,,,,,,,,,,,,,, -84100,0.9977202,5.643666,,,,,,,,,,,,,, -84200,1.1749719,3.2420502,,,,,,,,,,,,,, -84300,0.9362786,5.767623,,,,,,,,,,,,,, -84399,,,0.5051953196525574,2.1985788345336914,0.4702000021934509,2.3659684658050537,50000.0,0.3638000190258026,3.0029189586639404,10000.0,39109.79868769646,43980.53152704239,39109.79868769646,4862.211612701416,3.726073265075684,0.0 -84400,0.9357597,3.9478335,,,,,,,,,,,,,, -84500,0.8962071,5.0113,,,,,,,,,,,,,, -84600,1.13602,3.5386555,,,,,,,,,,,,,, -84700,0.9084932,5.2624784,,,,,,,,,,,,,, -84800,1.0966558,3.4134362,,,,,,,,,,,,,, -84900,1.0129604,3.4184847,,,,,,,,,,,,,, -85000,1.0901167,3.2336404,,,,,,,,,,,,,, -85100,1.1176783,3.3140604,,,,,,,,,,,,,, -85200,1.0911065,3.8394105,,,,,,,,,,,,,, -85300,0.91505164,5.119936,,,,,,,,,,,,,, -85308,,,0.505664050579071,2.189625263214112,0.4622199833393097,2.41805386543274,50000.0,0.35630002617836,3.0596253871917725,10000.0,39529.994701862335,44453.88608646393,39529.994701862335,4915.274274110794,3.7699716091156006,0.0 -85400,1.0616326,3.831286,,,,,,,,,,,,,, -85500,1.2863595,3.3245327,,,,,,,,,,,,,, -85600,1.1129416,3.1495633,,,,,,,,,,,,,, -85700,1.1293334,3.1743226,,,,,,,,,,,,,, -85800,0.98343116,3.8185143,,,,,,,,,,,,,, -85900,1.0305605,4.629941,,,,,,,,,,,,,, -86000,1.0712298,3.286361,,,,,,,,,,,,,, -86100,1.0143849,3.1195822,,,,,,,,,,,,,, -86200,1.0161545,4.0524893,,,,,,,,,,,,,, -86215,,,0.4994335770606994,2.226904392242432,0.4710399806499481,2.3905842304229736,50000.0,0.3636000156402588,3.021517038345337,10000.0,39950.410684108734,44924.90372800827,39950.410684108734,4965.784074783325,3.8100292682647705,0.0 -86300,0.92096084,5.2552986,,,,,,,,,,,,,, -86400,1.0884329,3.2318945,,,,,,,,,,,,,, -86500,1.0336678,3.920197,,,,,,,,,,,,,, -86600,1.1588061,3.1555524,,,,,,,,,,,,,, -86700,1.1582639,4.0786743,,,,,,,,,,,,,, -86800,1.2151697,3.2359576,,,,,,,,,,,,,, -86900,1.1226624,3.2373269,,,,,,,,,,,,,, -87000,1.0849177,3.161812,,,,,,,,,,,,,, -87100,1.2305834,3.1646993,,,,,,,,,,,,,, -87118,,,0.5128515362739563,2.143615484237671,0.4809999763965606,2.3133881092071533,50000.0,0.3753000199794769,2.977247714996338,10000.0,40370.71117019653,45396.8552069664,40370.71117019653,5017.331121683121,3.862768650054932,0.0 -87200,1.1029226,3.3341343,,,,,,,,,,,,,, -87300,1.1980034,3.3106222,,,,,,,,,,,,,, -87400,1.019784,3.2874293,,,,,,,,,,,,,, -87500,1.3284668,3.2038968,,,,,,,,,,,,,, -87600,1.1247468,3.1872778,,,,,,,,,,,,,, -87700,1.1478212,3.3725533,,,,,,,,,,,,,, -87800,1.1120812,3.1422498,,,,,,,,,,,,,, -87900,1.0278195,4.04597,,,,,,,,,,,,,, -88000,1.1015687,3.3888068,,,,,,,,,,,,,, -88021,,,0.5203710794448853,2.086500883102417,0.4777399897575378,2.3107686042785645,50000.0,0.3805000185966491,2.938375234603882,10000.0,40790.89816689491,45868.57350111008,40790.89816689491,5068.761607885361,3.913156747817993,0.0 -88100,1.0918939,3.3156414,,,,,,,,,,,,,, -88200,0.95276964,4.4614167,,,,,,,,,,,,,, -88300,0.86704504,4.147668,,,,,,,,,,,,,, -88400,1.1542047,3.2971916,,,,,,,,,,,,,, -88500,0.9033171,5.4357967,,,,,,,,,,,,,, -88600,1.2204858,3.2345276,,,,,,,,,,,,,, -88700,1.1997657,3.3946128,,,,,,,,,,,,,, -88800,1.3051777,3.006842,,,,,,,,,,,,,, -88900,1.0125489,3.3663273,,,,,,,,,,,,,, -88927,,,0.5184179544448853,2.0980348587036133,0.4792400002479553,2.289841413497925,50000.0,0.3777000308036804,2.9439592361450195,10000.0,41210.95104074478,46339.93909049034,41210.95104074478,5119.981722831726,3.954137563705444,0.0 -89000,1.2267627,3.5614963,,,,,,,,,,,,,, -89100,0.9954848,3.572168,,,,,,,,,,,,,, -89200,1.1926084,3.3234153,,,,,,,,,,,,,, -89300,1.1149476,3.114956,,,,,,,,,,,,,, -89400,1.0381329,3.4759097,,,,,,,,,,,,,, -89500,1.1088399,3.2092865,,,,,,,,,,,,,, -89600,0.9426625,5.737181,,,,,,,,,,,,,, -89700,1.1822687,3.2612767,,,,,,,,,,,,,, -89800,1.1436051,3.720887,,,,,,,,,,,,,, -89836,,,0.5190234184265137,2.101839303970337,0.4875999987125397,2.2763702869415283,50000.0,0.3736000061035156,2.944952964782715,10000.0,41631.14861416817,46812.072179317474,41631.14861416817,5171.822760105133,3.997264623641968,0.0 -89900,1.1517241,3.3430185,,,,,,,,,,,,,, -90000,1.1563234,3.3956745,,,,,,,,,,,,,, -90100,1.2159785,3.151897,,,,,,,,,,,,,, -90200,1.1449347,3.1739488,,,,,,,,,,,,,, -90300,1.1366292,3.1561687,,,,,,,,,,,,,, -90400,1.1268746,3.245885,,,,,,,,,,,,,, -90500,1.1087455,3.1341639,,,,,,,,,,,,,, -90600,1.2458272,3.243289,,,,,,,,,,,,,, -90700,1.190912,3.111479,,,,,,,,,,,,,, -90743,,,0.5146093368530273,2.1558313369750977,0.475519984960556,2.3405981063842773,50000.0,0.3727000057697296,2.997620820999145,10000.0,42051.27122282982,47286.09899115562,42051.27122282982,5225.636157512665,4.036552667617798,0.0 -90800,0.9250695,4.834938,,,,,,,,,,,,,, -90900,1.2108626,3.1260543,,,,,,,,,,,,,, -91000,1.153259,3.7985456,,,,,,,,,,,,,, -91100,1.012187,4.9874363,,,,,,,,,,,,,, -91200,0.89775693,4.9066815,,,,,,,,,,,,,, -91300,1.1714954,3.145548,,,,,,,,,,,,,, -91400,0.91222125,3.7677636,,,,,,,,,,,,,, -91500,1.1088382,3.3487391,,,,,,,,,,,,,, -91600,1.1448027,3.1334279,,,,,,,,,,,,,, -91651,,,0.5436718463897705,2.0345213413238525,0.4791199862957001,2.349999189376831,50000.0,0.3713000118732452,2.993382453918457,10000.0,42471.242216825485,47757.972594976425,42471.242216825485,5277.445053577423,4.078348875045776,0.0 -91700,1.1590855,3.208767,,,,,,,,,,,,,, -91800,1.123951,5.108833,,,,,,,,,,,,,, -91900,1.0580962,4.0370636,,,,,,,,,,,,,, -92000,1.3117222,3.1589653,,,,,,,,,,,,,, -92100,1.1302459,3.835385,,,,,,,,,,,,,, -92200,1.0512096,4.2124867,,,,,,,,,,,,,, -92300,1.103057,3.3158646,,,,,,,,,,,,,, -92400,1.3282697,3.3604028,,,,,,,,,,,,,, -92500,1.0903108,3.0218434,,,,,,,,,,,,,, -92556,,,0.5254101157188416,2.0577504634857178,0.4925599992275238,2.2484467029571533,50000.0,0.3812000155448913,2.9217004776000977,10000.0,42891.40844297409,48231.31739473343,42891.40844297409,5330.52893781662,4.1219470500946045,0.0 -92600,1.1505475,3.1763117,,,,,,,,,,,,,, -92700,1.2530711,3.1930175,,,,,,,,,,,,,, -92800,0.9885685,5.087852,,,,,,,,,,,,,, -92900,0.9557922,4.094105,,,,,,,,,,,,,, -93000,1.2620426,3.212645,,,,,,,,,,,,,, -93100,1.0842497,3.2280083,,,,,,,,,,,,,, -93200,0.95322686,5.5228505,,,,,,,,,,,,,, -93300,1.0414779,4.9550705,,,,,,,,,,,,,, -93400,0.9287232,3.671286,,,,,,,,,,,,,, -93464,,,0.5331054329872131,2.015186786651612,0.4948799908161163,2.220568418502808,50000.0,0.3795000314712524,2.9080698490142822,10000.0,43311.48286437988,48703.985027074814,43311.48286437988,5383.027472257614,4.1651999950408936,0.0 -93500,0.94065434,4.428635,,,,,,,,,,,,,, -93600,0.88949394,5.442005,,,,,,,,,,,,,, -93700,1.155038,3.0267973,,,,,,,,,,,,,, -93800,1.1879452,3.1029398,,,,,,,,,,,,,, -93900,1.0534093,3.2500346,,,,,,,,,,,,,, -94000,1.1071335,3.1364524,,,,,,,,,,,,,, -94100,0.9181158,4.4200473,,,,,,,,,,,,,, -94200,0.9573997,5.5946693,,,,,,,,,,,,,, -94300,1.0761597,3.113873,,,,,,,,,,,,,, -94368,,,0.5615038871765137,1.8956527709960933,0.5003199577331543,2.2095954418182373,50000.0,0.3873000144958496,2.886227607727051,10000.0,43731.60246658325,49176.6410984993,43731.60246658325,5435.459993362427,4.21837306022644,0.0 -94400,0.8992617,4.9426537,,,,,,,,,,,,,, -94500,1.1156511,3.1199126,,,,,,,,,,,,,, -94600,1.0564845,3.4461973,,,,,,,,,,,,,, -94700,0.9298753,5.328746,,,,,,,,,,,,,, -94800,0.96769994,5.447328,,,,,,,,,,,,,, -94900,1.3005908,2.9937415,,,,,,,,,,,,,, -95000,1.009471,5.268461,,,,,,,,,,,,,, -95100,1.3788573,3.15128,,,,,,,,,,,,,, -95200,0.99211866,3.8642144,,,,,,,,,,,,,, -95273,,,0.5392187237739563,1.992055058479309,0.5055800080299377,2.1670339107513428,50000.0,0.3952000141143799,2.8237500190734863,10000.0,44151.640016794205,49648.15318131447,44151.640016794205,5486.836786031723,4.264795303344727,0.0 -95300,1.0847192,3.3011117,,,,,,,,,,,,,, -95400,1.2049335,2.9820178,,,,,,,,,,,,,, -95500,1.0620731,3.0049624,,,,,,,,,,,,,, -95600,0.92307204,5.3153186,,,,,,,,,,,,,, -95700,1.11341,3.0382128,,,,,,,,,,,,,, -95800,1.251697,3.0370283,,,,,,,,,,,,,, -95900,0.9016226,3.947168,,,,,,,,,,,,,, -96000,1.2091243,3.0681612,,,,,,,,,,,,,, -96100,1.0081216,4.9674187,,,,,,,,,,,,,, -96179,,,0.5384374856948853,2.011241912841797,0.5003399848937988,2.213174343109131,50000.0,0.3921000063419342,2.8492650985717773,10000.0,44571.82088375092,50121.71514558792,44571.82088375092,5540.123233795166,4.307686805725098,0.0 -96200,1.3612096,3.18271,,,,,,,,,,,,,, -96300,1.2839655,3.1877444,,,,,,,,,,,,,, -96400,1.0214707,2.993099,,,,,,,,,,,,,, -96500,0.802686,5.254011,,,,,,,,,,,,,, -96600,1.2119273,3.4579966,,,,,,,,,,,,,, -96700,1.1060929,3.8993535,,,,,,,,,,,,,, -96800,1.1034381,3.9628203,,,,,,,,,,,,,, -96900,1.1933054,3.0725732,,,,,,,,,,,,,, -97000,0.953299,5.3148465,,,,,,,,,,,,,, -97086,,,0.5535546541213989,1.9242347478866573,0.5067799687385559,2.171412706375122,50000.0,0.398000031709671,2.8221137523651123,10000.0,44992.06829333305,50591.74269366264,44992.06829333305,5589.800358533859,4.35940146446228,0.0 -97100,1.1198325,3.4514163,,,,,,,,,,,,,, -97200,1.1544179,3.3983634,,,,,,,,,,,,,, -97300,1.326099,3.398565,,,,,,,,,,,,,, -97400,1.2167994,3.241647,,,,,,,,,,,,,, -97500,1.2240485,5.352982,,,,,,,,,,,,,, -97600,0.99891305,5.5170813,,,,,,,,,,,,,, -97700,0.9440664,4.128698,,,,,,,,,,,,,, -97800,1.1139667,3.5314186,,,,,,,,,,,,,, -97900,1.1626298,3.1028264,,,,,,,,,,,,,, -97993,,,0.532910168170929,2.055387020111084,0.4993399977684021,2.2423460483551025,50000.0,0.3891000151634216,2.8874340057373047,10000.0,45412.143065452576,51066.33588290215,45412.143065452576,5644.221765995026,4.40524959564209,0.0 -98000,1.1904438,3.3225956,,,,,,,,,,,,,, -98100,0.93973637,5.2860565,,,,,,,,,,,,,, -98200,1.2777431,3.0921636,,,,,,,,,,,,,, -98300,1.3255986,3.0574782,,,,,,,,,,,,,, -98400,1.2825637,3.1156492,,,,,,,,,,,,,, -98500,1.1697669,3.3811827,,,,,,,,,,,,,, -98600,1.0736777,4.9342327,,,,,,,,,,,,,, -98700,1.2529231,3.107346,,,,,,,,,,,,,, -98800,1.0363859,3.4039311,,,,,,,,,,,,,, -98898,,,0.5400781035423279,2.028705358505249,0.5006399750709534,2.236777067184448,50000.0,0.3814000189304352,2.900303363800049,10000.0,45832.18708944321,51537.9104912281,45832.18708944321,5695.656700849533,4.450121402740479,0.0 -98900,1.2224818,3.0522878,,,,,,,,,,,,,, -99000,1.3125496,3.0428588,,,,,,,,,,,,,, -99100,1.0330517,5.558308,,,,,,,,,,,,,, -99200,1.289596,3.143911,,,,,,,,,,,,,, -99300,1.5159055,3.0894704,,,,,,,,,,,,,, -99400,1.1186241,3.2970235,,,,,,,,,,,,,, -99500,1.1630752,2.9835672,,,,,,,,,,,,,, -99600,1.2498026,3.084704,,,,,,,,,,,,,, -99700,1.1903604,3.1366286,,,,,,,,,,,,,, -99800,1.3691682,3.075776,,,,,,,,,,,,,, -99804,,,0.5590429306030273,1.925033688545227,0.5095599889755249,2.1697280406951904,50000.0,0.3951000273227691,2.8197717666625977,10000.0,46252.08383107185,52010.89158630371,46252.08383107185,5748.635198116303,4.504448413848877,0.0 -99900,1.1892883,2.8901522,,,,,,,,,,,,,, -100000,0.94831014,4.7451034,,,,,,,,,,,,,, -100100,1.1622733,3.0563161,,,,,,,,,,,,,, -100200,1.2106621,2.9490323,,,,,,,,,,,,,, -100300,1.1012911,3.0824358,,,,,,,,,,,,,, -100400,1.1672226,2.9985795,,,,,,,,,,,,,, -100500,1.0534728,3.7048254,,,,,,,,,,,,,, -100600,1.2134429,2.9671564,,,,,,,,,,,,,, -100700,1.4257499,3.121759,,,,,,,,,,,,,, -100711,,,0.5338476300239563,2.064886569976806,0.4959999918937683,2.255467176437378,50000.0,0.3897000253200531,2.8789451122283936,10000.0,46672.03926706314,52482.73132753372,46672.03926706314,5800.427268981934,4.545479774475098,0.0 -100800,0.9257977,5.3417006,,,,,,,,,,,,,, -100900,1.1806056,3.0840428,,,,,,,,,,,,,, -101000,1.4825177,2.949307,,,,,,,,,,,,,, -101100,1.1227922,3.6128302,,,,,,,,,,,,,, -101200,1.1984184,3.2751093,,,,,,,,,,,,,, -101300,0.9552902,4.6150136,,,,,,,,,,,,,, -101400,1.257202,3.0361707,,,,,,,,,,,,,, -101500,1.2575784,3.1170864,,,,,,,,,,,,,, -101600,1.1750379,3.1040998,,,,,,,,,,,,,, -101614,,,0.5544726252555847,1.9149682521820068,0.5176399946212769,2.104093313217163,50000.0,0.4047000110149383,2.784193754196167,10000.0,47091.95157575607,52954.14155244827,47091.95157575607,5851.827321767807,4.592525005340576,0.0 -101700,1.1782937,2.8775518,,,,,,,,,,,,,, -101800,1.2432668,3.2379692,,,,,,,,,,,,,, -101900,1.2419298,3.1548657,,,,,,,,,,,,,, -102000,1.2073474,2.9976153,,,,,,,,,,,,,, -102100,1.0652288,4.6119003,,,,,,,,,,,,,, -102200,1.0139304,3.6579745,,,,,,,,,,,,,, -102300,1.2056866,2.9782827,,,,,,,,,,,,,, -102400,0.960709,4.523933,,,,,,,,,,,,,, -102500,1.0973701,3.0173535,,,,,,,,,,,,,, -102519,,,0.558300793170929,1.900124430656433,0.5158199667930603,2.11898159980774,50000.0,0.4044000208377838,2.776407241821289,10000.0,47512.10585308075,53425.54627323151,47512.10585308075,5902.980116844177,4.639246225357056,0.0 -102600,1.3673555,3.0304615,,,,,,,,,,,,,, -102700,1.0765672,3.7412505,,,,,,,,,,,,,, -102800,1.1454799,3.0211265,,,,,,,,,,,,,, -102900,1.2612926,3.0253916,,,,,,,,,,,,,, -103000,1.1439699,2.8948796,,,,,,,,,,,,,, -103100,1.151007,2.9546034,,,,,,,,,,,,,, -103200,1.0763332,4.3835773,,,,,,,,,,,,,, -103300,0.89662224,4.7623196,,,,,,,,,,,,,, -103400,0.98723394,4.505414,,,,,,,,,,,,,, -103426,,,0.5542968511581421,1.920783758163452,0.5221399664878845,2.0911567211151123,50000.0,0.4073000252246856,2.7569947242736816,10000.0,47932.22675728798,53898.2055375576,47932.22675728798,5955.425267219544,4.68032431602478,0.0 -103500,1.0345129,4.6013994,,,,,,,,,,,,,, -103600,1.4338633,3.0553217,,,,,,,,,,,,,, -103700,1.0872252,3.3486292,,,,,,,,,,,,,, -103800,1.225623,2.9113853,,,,,,,,,,,,,, -103900,1.0604936,3.735444,,,,,,,,,,,,,, -104000,1.1360382,2.8732579,,,,,,,,,,,,,, -104100,1.2661631,3.2310848,,,,,,,,,,,,,, -104200,1.3343772,2.905026,,,,,,,,,,,,,, -104300,0.936867,5.397493,,,,,,,,,,,,,, -104333,,,0.5587499737739563,1.940987467765808,0.5200799703598022,2.1409268379211426,50000.0,0.4038000106811523,2.7949371337890625,10000.0,48352.40204453468,54370.11286449432,48352.40204453468,6007.061721801758,4.724331378936768,0.0 -104400,1.1515887,2.9522247,,,,,,,,,,,,,, -104500,1.2479476,2.9921765,,,,,,,,,,,,,, -104600,1.1936089,2.9020371,,,,,,,,,,,,,, -104700,1.2587788,3.0181437,,,,,,,,,,,,,, -104800,1.2202706,3.1049833,,,,,,,,,,,,,, -104900,1.4877372,2.964629,,,,,,,,,,,,,, -105000,0.994111,4.2551045,,,,,,,,,,,,,, -105100,1.2667109,3.0190709,,,,,,,,,,,,,, -105200,1.2460192,2.9325695,,,,,,,,,,,,,, -105241,,,0.560839831829071,1.9002279043197632,0.514959990978241,2.125225067138672,50000.0,0.403300017118454,2.796470880508423,10000.0,48772.720823049545,54844.04375267029,48772.720823049545,6060.579028129578,4.768460273742676,0.0 -105300,1.1645057,3.013236,,,,,,,,,,,,,, -105400,1.2209792,2.8825052,,,,,,,,,,,,,, -105500,1.2663195,3.0579238,,,,,,,,,,,,,, -105600,1.2655348,2.8508146,,,,,,,,,,,,,, -105700,1.2975615,2.9540625,,,,,,,,,,,,,, -105800,1.3099883,2.9034383,,,,,,,,,,,,,, -105900,1.2366107,2.9615614,,,,,,,,,,,,,, -106000,1.0268857,3.742132,,,,,,,,,,,,,, -106100,1.3527586,3.08749,,,,,,,,,,,,,, -106150,,,0.5675585865974426,1.8988195657730105,0.5279399752616882,2.0975606441497803,50000.0,0.4099000096321106,2.748702049255371,10000.0,49193.001353263855,55316.173254966736,49193.001353263855,6112.325652837753,4.819467544555664,0.0 -106200,1.2233263,2.9509487,,,,,,,,,,,,,, -106300,1.0080063,5.2131133,,,,,,,,,,,,,, -106400,1.2094883,2.9107401,,,,,,,,,,,,,, -106500,1.5092863,2.91238,,,,,,,,,,,,,, -106600,1.2513189,2.9743185,,,,,,,,,,,,,, -106700,1.1845615,2.8661952,,,,,,,,,,,,,, -106800,1.2209672,5.1555977,,,,,,,,,,,,,, -106900,1.2718297,3.0145893,,,,,,,,,,,,,, -107000,1.1734388,3.151694,,,,,,,,,,,,,, -107055,,,0.5690038800239563,1.8810782432556152,0.5263199806213379,2.085660457611084,50000.0,0.4122000336647033,2.7401907444000244,10000.0,49612.96082854271,55788.6028881073,49612.96082854271,6164.696760416031,4.867316961288452,0.0 -107100,1.0657297,4.8431177,,,,,,,,,,,,,, -107200,1.065268,3.4242032,,,,,,,,,,,,,, -107300,1.280556,3.2417507,,,,,,,,,,,,,, -107400,1.2229667,2.9355931,,,,,,,,,,,,,, -107500,1.0218762,4.7731266,,,,,,,,,,,,,, -107600,1.4336138,2.8373337,,,,,,,,,,,,,, -107700,1.2739564,2.8035893,,,,,,,,,,,,,, -107800,1.2458295,3.1763904,,,,,,,,,,,,,, -107900,1.2855554,2.9122043,,,,,,,,,,,,,, -107961,,,0.5733398199081421,1.872202754020691,0.5306199789047241,2.0735056400299072,50000.0,0.417900025844574,2.732545852661133,10000.0,50033.20030045509,56263.09414482117,50033.20030045509,6218.849471569061,4.915220022201538,0.0 -108000,1.001082,5.3747377,,,,,,,,,,,,,, -108100,1.429539,3.0512793,,,,,,,,,,,,,, -108200,1.1384081,4.393778,,,,,,,,,,,,,, -108300,1.2293159,2.8833048,,,,,,,,,,,,,, -108400,1.0113541,5.1111426,,,,,,,,,,,,,, -108500,1.1711845,3.7155135,,,,,,,,,,,,,, -108600,1.2721944,3.0474455,,,,,,,,,,,,,, -108700,1.398377,2.9672272,,,,,,,,,,,,,, -108800,1.1774133,4.199681,,,,,,,,,,,,,, -108871,,,0.6080859303474426,1.6695939302444458,0.5309799909591675,2.035268545150757,50000.0,0.41880002617836,2.701533317565918,10000.0,50453.37967920303,56735.69892835617,50453.37967920303,6271.163687944412,4.974422931671143,0.0 -108900,1.268059,2.9348066,,,,,,,,,,,,,, -109000,1.2080268,2.95734,,,,,,,,,,,,,, -109100,0.94475996,5.3869934,,,,,,,,,,,,,, -109200,1.0908386,4.434765,,,,,,,,,,,,,, -109300,1.3483374,2.9324255,,,,,,,,,,,,,, -109400,1.3084732,2.8265467,,,,,,,,,,,,,, -109500,1.2732226,2.8893716,,,,,,,,,,,,,, -109600,1.0089228,4.0729823,,,,,,,,,,,,,, -109700,1.243613,3.1861372,,,,,,,,,,,,,, -109779,,,0.5797070264816284,1.7871202230453491,0.5417999625205994,1.9782638549804688,50000.0,0.4267000257968902,2.641922950744629,10000.0,50873.40883421898,57209.77555155754,50873.40883421898,6325.108122825623,5.025132894515991,0.0 -109800,1.0347503,4.844629,,,,,,,,,,,,,, -109900,1.3459435,2.8588703,,,,,,,,,,,,,, -110000,1.2473143,2.907773,,,,,,,,,,,,,, -110100,1.3193439,3.0359995,,,,,,,,,,,,,, -110200,1.0921698,5.0767374,,,,,,,,,,,,,, -110300,1.0444188,3.7448776,,,,,,,,,,,,,, -110400,0.9885526,4.836481,,,,,,,,,,,,,, -110500,1.3417194,2.934132,,,,,,,,,,,,,, -110600,1.041021,4.7834597,,,,,,,,,,,,,, -110687,,,0.5723242163658142,1.8696619272232056,0.5303800106048584,2.078565120697021,50000.0,0.419400006532669,2.7238848209381104,10000.0,51293.60374307632,57680.238035440445,51293.60374307632,6375.277026414871,5.072314500808716,0.0 -110700,1.1835692,3.5266843,,,,,,,,,,,,,, -110800,1.1759994,4.887521,,,,,,,,,,,,,, -110900,1.0164235,4.554049,,,,,,,,,,,,,, -111000,1.2619636,2.9475808,,,,,,,,,,,,,, -111100,1.2143868,3.6390767,,,,,,,,,,,,,, -111200,1.169515,3.575858,,,,,,,,,,,,,, -111300,1.1430347,3.684631,,,,,,,,,,,,,, -111400,0.989588,5.3001695,,,,,,,,,,,,,, -111500,1.2221925,2.8485603,,,,,,,,,,,,,, -111597,,,0.6044335961341858,1.6973642110824585,0.542639970779419,2.007513999938965,50000.0,0.4234000146389007,2.6781985759735107,10000.0,51713.767781972885,58152.63842535019,51713.767781972885,6427.413257360458,5.120913028717041,0.0 -111600,1.0564045,3.582253,,,,,,,,,,,,,, -111700,1.2274065,2.911086,,,,,,,,,,,,,, -111800,1.1526843,4.5676613,,,,,,,,,,,,,, -111900,1.278031,2.8523836,,,,,,,,,,,,,, -112000,1.270978,2.7160985,,,,,,,,,,,,,, -112100,1.2582718,2.7932656,,,,,,,,,,,,,, -112200,1.0986688,5.339691,,,,,,,,,,,,,, -112300,1.0285949,3.8836532,,,,,,,,,,,,,, -112400,1.1794802,2.930324,,,,,,,,,,,,,, -112500,1.3369179,2.805553,,,,,,,,,,,,,, -112503,,,0.5824804306030273,1.7816941738128662,0.5447999835014343,1.961826205253601,50000.0,0.431300014257431,2.65842080116272,10000.0,52134.00865268707,58626.1765999794,52134.00865268707,6480.615281820297,5.165015459060669,0.0 -112600,1.1420875,3.3739147,,,,,,,,,,,,,, -112700,1.299286,2.8543274,,,,,,,,,,,,,, -112800,1.1800463,4.9728703,,,,,,,,,,,,,, -112900,1.3414278,2.7440596,,,,,,,,,,,,,, -113000,1.369612,2.942391,,,,,,,,,,,,,, -113100,1.2951648,2.7135098,,,,,,,,,,,,,, -113200,1.3488971,2.914262,,,,,,,,,,,,,, -113300,1.2128689,2.7531722,,,,,,,,,,,,,, -113400,1.290888,2.6856859,,,,,,,,,,,,,, -113409,,,0.5881054401397705,1.7460676431655884,0.5488399863243103,1.9370585680007928,50000.0,0.4369000196456909,2.59963321685791,10000.0,52554.14491820336,59097.95101642609,52554.14491820336,6532.158484697342,5.209194660186768,0.0 -113500,1.3585216,2.8042183,,,,,,,,,,,,,, -113600,1.3308434,2.8655035,,,,,,,,,,,,,, -113700,1.4515673,2.897819,,,,,,,,,,,,,, -113800,1.1801511,3.4020724,,,,,,,,,,,,,, -113900,1.3363795,2.8024883,,,,,,,,,,,,,, -114000,1.1053352,4.9203157,,,,,,,,,,,,,, -114100,1.3665786,5.1956663,,,,,,,,,,,,,, -114200,1.0233712,4.2747526,,,,,,,,,,,,,, -114300,1.445489,2.875588,,,,,,,,,,,,,, -114316,,,0.6094335913658142,1.6528574228286743,0.5530999898910522,1.9297101497650144,50000.0,0.4376000165939331,2.588284969329834,10000.0,52974.33516526222,59569.60813641548,52974.33516526222,6583.523062705994,5.259409666061401,0.0 -114400,1.1857201,5.191206,,,,,,,,,,,,,, -114500,1.2400774,2.8812788,,,,,,,,,,,,,, -114600,1.3892409,2.8509033,,,,,,,,,,,,,, -114700,1.2916808,2.8252072,,,,,,,,,,,,,, -114800,1.35775,2.9010437,,,,,,,,,,,,,, -114900,1.3496683,2.8356411,,,,,,,,,,,,,, -115000,1.0968215,4.5387716,,,,,,,,,,,,,, -115100,1.2644048,2.8437824,,,,,,,,,,,,,, -115200,1.1285634,4.731786,,,,,,,,,,,,,, -115222,,,0.5908203125,1.7480441331863403,0.5513399839401245,1.947059988975525,50000.0,0.4300000071525574,2.6224279403686523,10000.0,53394.59422135353,60042.13766551018,53394.59422135353,6635.696802854538,5.305138349533081,0.0 -115300,1.2959268,2.9869192,,,,,,,,,,,,,, -115400,1.2456396,3.522224,,,,,,,,,,,,,, -115500,1.1703905,3.2490828,,,,,,,,,,,,,, -115600,1.1280625,4.131317,,,,,,,,,,,,,, -115700,1.2842721,2.619186,,,,,,,,,,,,,, -115800,1.0621569,3.8663204,,,,,,,,,,,,,, -115900,1.0644201,4.0579834,,,,,,,,,,,,,, -116000,1.3720824,2.8542233,,,,,,,,,,,,,, -116100,1.1984596,4.0397396,,,,,,,,,,,,,, -116126,,,0.5944726467132568,1.7341322898864746,0.5532999634742737,1.9365402460098269,50000.0,0.4336000084877014,2.6125783920288086,10000.0,53814.87706851959,60514.86140489578,53814.87706851959,6688.041565418243,5.350852966308594,0.0 -116200,1.402981,2.7432308,,,,,,,,,,,,,, -116300,1.4308224,2.8319561,,,,,,,,,,,,,, -116400,1.4254524,2.7861097,,,,,,,,,,,,,, -116500,1.3730348,2.7589912,,,,,,,,,,,,,, -116600,1.1362597,3.7544792,,,,,,,,,,,,,, -116700,1.2606819,2.7829142,,,,,,,,,,,,,, -116800,1.6413019,2.6687665,,,,,,,,,,,,,, -116900,1.3535075,2.6819117,,,,,,,,,,,,,, -117000,1.38503,3.859818,,,,,,,,,,,,,, -117031,,,0.607128918170929,1.6637282371520996,0.5559200048446655,1.909610152244568,50000.0,0.4370000064373016,2.60026216506958,10000.0,54235.36640357971,60989.96281218529,54235.36640357971,6742.556987047195,5.396175384521484,0.0 -117100,1.3333676,4.954359,,,,,,,,,,,,,, -117200,1.1464441,4.8979225,,,,,,,,,,,,,, -117300,1.3673887,2.830011,,,,,,,,,,,,,, -117400,1.3065945,2.8237777,,,,,,,,,,,,,, -117500,1.6515784,2.7763379,,,,,,,,,,,,,, -117600,1.1136287,5.242521,,,,,,,,,,,,,, -117700,1.3328501,2.6633933,,,,,,,,,,,,,, -117800,1.2224131,2.7023609,,,,,,,,,,,,,, -117900,1.1882745,5.145474,,,,,,,,,,,,,, -117937,,,0.5986914038658142,1.7124024629592896,0.5584200024604797,1.9031184911727903,50000.0,0.4412000179290771,2.5633251667022705,10000.0,54655.59771943092,61462.07122516632,54655.59771943092,6794.329073667526,5.449033498764038,0.0 -118000,1.407766,2.760939,,,,,,,,,,,,,, -118100,1.471273,2.8404322,,,,,,,,,,,,,, -118200,1.3651549,5.458382,,,,,,,,,,,,,, -118300,1.0995996,4.8661156,,,,,,,,,,,,,, -118400,1.2718778,2.6690862,,,,,,,,,,,,,, -118500,1.2630699,2.8598273,,,,,,,,,,,,,, -118600,1.3793228,2.6624296,,,,,,,,,,,,,, -118700,1.3089437,2.9351106,,,,,,,,,,,,,, -118800,1.1780905,5.1173944,,,,,,,,,,,,,, -118844,,,0.6104296445846558,1.6610846519470217,0.5698599815368652,1.871710538864136,50000.0,0.4528000354766845,2.5325958728790283,10000.0,55075.97532296181,61935.68432497978,55075.97532296181,6847.463381290436,5.4985644817352295,0.0 -118900,1.3615916,3.2151525,,,,,,,,,,,,,, -119000,1.1386042,4.001645,,,,,,,,,,,,,, -119100,1.268281,2.6076427,,,,,,,,,,,,,, -119200,1.2978046,2.7801163,,,,,,,,,,,,,, -119300,1.3625709,2.9900572,,,,,,,,,,,,,, -119400,1.4178869,3.2732902,,,,,,,,,,,,,, -119500,1.2882928,2.6679642,,,,,,,,,,,,,, -119600,1.2369922,3.3923545,,,,,,,,,,,,,, -119700,1.102454,4.6979456,,,,,,,,,,,,,, -119751,,,0.6143554449081421,1.6402643918991089,0.5634399652481079,1.8867182731628416,50000.0,0.4409000277519226,2.555757999420166,10000.0,55496.04265832901,62407.2891471386,55496.04265832901,6898.902683258057,5.5450968742370605,0.0 -119800,1.5158839,2.5914483,,,,,,,,,,,,,, -119900,1.3275256,2.8003974,,,,,,,,,,,,,, -120000,1.3006312,2.5677435,,,,,,,,,,,,,, -120100,1.1824613,3.7521858,,,,,,,,,,,,,, -120200,1.3117371,2.775414,,,,,,,,,,,,,, -120300,1.1030021,4.4106565,,,,,,,,,,,,,, -120400,1.1699513,3.550658,,,,,,,,,,,,,, -120500,1.4412926,2.6130092,,,,,,,,,,,,,, -120600,1.2064723,3.5852263,,,,,,,,,,,,,, -120659,,,0.6095117330551147,1.7067832946777344,0.5638799667358398,1.9024053812026973,50000.0,0.4469000101089477,2.5595107078552246,10000.0,55916.21435809136,62878.28482103348,55916.21435809136,6949.624935626984,5.5959556102752686,0.0 -120700,1.2632641,2.6656022,,,,,,,,,,,,,, -120800,1.1779068,3.7170796,,,,,,,,,,,,,, -120900,1.1276096,5.240556,,,,,,,,,,,,,, -121000,1.4611006,2.7145836,,,,,,,,,,,,,, -121100,1.336197,2.7615309,,,,,,,,,,,,,, -121200,1.3602047,2.706603,,,,,,,,,,,,,, -121300,1.177285,3.5760105,,,,,,,,,,,,,, -121400,1.459022,2.7577944,,,,,,,,,,,,,, -121500,1.4130002,2.7655923,,,,,,,,,,,,,, -121566,,,0.6133007407188416,1.6413342952728271,0.5678399801254272,1.860566258430481,50000.0,0.4475000202655792,2.534411907196045,10000.0,56336.36602210999,63349.9245557785,56336.36602210999,7001.017154455185,5.640802383422852,0.0 -121600,1.1607363,4.0488567,,,,,,,,,,,,,, -121700,1.4142125,2.6316462,,,,,,,,,,,,,, -121800,1.4552587,2.7922666,,,,,,,,,,,,,, -121900,1.3407154,4.5337796,,,,,,,,,,,,,, -122000,1.0956608,5.153512,,,,,,,,,,,,,, -122100,1.3547152,2.6383424,,,,,,,,,,,,,, -122200,1.2012211,3.1030383,,,,,,,,,,,,,, -122300,1.362309,2.6792984,,,,,,,,,,,,,, -122400,1.3611119,3.0770504,,,,,,,,,,,,,, -122472,,,0.62255859375,1.6071964502334597,0.5728999972343445,1.8406982421875,50000.0,0.4550000131130218,2.503061532974243,10000.0,56756.33172440529,63820.65265607834,56756.33172440529,7051.67671251297,5.691583156585693,0.0 -122500,1.4936405,2.7270458,,,,,,,,,,,,,, -122600,1.4088391,2.760427,,,,,,,,,,,,,, -122700,1.6531998,2.589837,,,,,,,,,,,,,, -122800,1.3875353,2.7263498,,,,,,,,,,,,,, -122900,1.2399915,3.16365,,,,,,,,,,,,,, -123000,1.2333055,3.0716517,,,,,,,,,,,,,, -123100,1.6818237,3.088733,,,,,,,,,,,,,, -123200,1.4562472,2.5743012,,,,,,,,,,,,,, -123300,1.299415,2.648511,,,,,,,,,,,,,, -123377,,,0.6289257407188416,1.570678949356079,0.579479992389679,1.805068850517273,50000.0,0.4579000174999237,2.4760637283325195,10000.0,57176.62945842743,64293.56670308113,57176.62945842743,7104.192118406296,5.741584062576294,0.0 -123400,1.359676,2.5453453,,,,,,,,,,,,,, -123500,1.4801317,2.6884124,,,,,,,,,,,,,, -123600,1.37108,2.8365164,,,,,,,,,,,,,, -123700,1.3411353,2.46171,,,,,,,,,,,,,, -123800,1.542748,2.7815356,,,,,,,,,,,,,, -123900,1.4988652,2.8733783,,,,,,,,,,,,,, -124000,1.4515767,2.6141496,,,,,,,,,,,,,, -124100,1.4712856,2.701075,,,,,,,,,,,,,, -124200,1.5323716,2.7000434,,,,,,,,,,,,,, -124287,,,0.6289257407188416,1.5635874271392822,0.5827199816703796,1.768791675567627,50000.0,0.4643000364303589,2.429023504257202,10000.0,57596.96713638306,64766.72912335396,57596.96713638306,7156.915832519531,5.790433645248413,0.0 -124300,1.3722159,3.0151067,,,,,,,,,,,,,, -124400,1.3995224,2.640215,,,,,,,,,,,,,, -124500,1.462697,3.016389,,,,,,,,,,,,,, -124600,1.1128769,5.1416802,,,,,,,,,,,,,, -124700,1.2772628,2.93984,,,,,,,,,,,,,, -124800,1.5673863,2.5992594,,,,,,,,,,,,,, -124900,1.2541465,3.5908823,,,,,,,,,,,,,, -125000,1.439598,2.6533446,,,,,,,,,,,,,, -125100,1.551564,2.5870519,,,,,,,,,,,,,, -125193,,,0.6286327838897705,1.5781629085540771,0.5806199908256531,1.8158116340637207,50000.0,0.4616000354290008,2.4752357006073,10000.0,58017.08203911781,65238.811891555786,58017.08203911781,7208.776687383652,5.846323251724243,0.0 -125200,1.5120746,2.8052487,,,,,,,,,,,,,, -125300,1.3646233,2.6758318,,,,,,,,,,,,,, -125400,1.2058859,4.9441986,,,,,,,,,,,,,, -125500,1.4980531,2.620686,,,,,,,,,,,,,, -125600,1.3051537,3.496893,,,,,,,,,,,,,, -125700,1.3646779,2.657442,,,,,,,,,,,,,, -125800,1.4929472,2.6858618,,,,,,,,,,,,,, -125900,1.4098873,3.5246522,,,,,,,,,,,,,, -126000,1.2485577,4.0798373,,,,,,,,,,,,,, -126099,,,0.6649023294448853,1.4432421922683716,0.5847399830818176,1.7886931896209717,50000.0,0.4726000130176544,2.447232484817505,10000.0,58437.05323433876,65709.36529541016,58437.05323433876,7259.260968923569,5.892705202102661,0.0 -126100,1.5386032,2.5924697,,,,,,,,,,,,,, -126200,1.4796506,2.5009615,,,,,,,,,,,,,, -126300,1.1210971,4.7842093,,,,,,,,,,,,,, -126400,1.4019601,3.1197705,,,,,,,,,,,,,, -126500,1.5017809,5.279689,,,,,,,,,,,,,, -126600,1.2208922,4.190543,,,,,,,,,,,,,, -126700,1.3526843,2.670691,,,,,,,,,,,,,, -126800,1.5397055,2.806404,,,,,,,,,,,,,, -126900,1.5892383,2.6041238,,,,,,,,,,,,,, -127000,1.3925692,2.8281364,,,,,,,,,,,,,, -127005,,,0.6364062428474426,1.526646852493286,0.5922799706459045,1.7352849245071411,50000.0,0.4724000096321106,2.40802001953125,10000.0,58857.087052583694,66180.10751318932,58857.087052583694,7309.869294166565,5.941359758377075,0.0 -127100,1.2233036,3.8928914,,,,,,,,,,,,,, -127200,1.117494,4.9651847,,,,,,,,,,,,,, -127300,1.3948928,2.528533,,,,,,,,,,,,,, -127400,1.7137054,2.9370248,,,,,,,,,,,,,, -127500,1.1624354,4.520604,,,,,,,,,,,,,, -127600,1.59909,2.548171,,,,,,,,,,,,,, -127700,1.4888719,3.159446,,,,,,,,,,,,,, -127800,1.3263336,2.8665018,,,,,,,,,,,,,, -127900,1.548064,2.5885096,,,,,,,,,,,,,, -127910,,,0.6415820121765137,1.496817708015442,0.5906599760055542,1.746111512184143,50000.0,0.4738000333309173,2.39067816734314,10000.0,59277.23089051247,66653.61686849594,59277.23089051247,7363.129547357559,5.994652986526489,0.0 -128000,1.2300133,4.2264714,,,,,,,,,,,,,, -128100,1.3219332,3.7030401,,,,,,,,,,,,,, -128200,1.3154895,4.9200296,,,,,,,,,,,,,, -128300,1.2778639,5.2163134,,,,,,,,,,,,,, -128400,1.5599474,2.758377,,,,,,,,,,,,,, -128500,1.727641,2.571745,,,,,,,,,,,,,, -128600,1.4015833,3.055171,,,,,,,,,,,,,, -128700,1.4485104,2.969788,,,,,,,,,,,,,, -128800,1.6155523,2.6357095,,,,,,,,,,,,,, -128813,,,0.6640819907188416,1.4066020250320437,0.5945999622344971,1.7294906377792358,50000.0,0.4759000241756439,2.3901915550231934,10000.0,59697.32144474983,67126.2372546196,59697.32144474983,7415.558722019196,6.044222116470337,0.0 -128900,1.5666996,2.4998288,,,,,,,,,,,,,, -129000,1.5708162,2.4467955,,,,,,,,,,,,,, -129100,1.3324374,5.077939,,,,,,,,,,,,,, -129200,1.4644357,2.8792787,,,,,,,,,,,,,, -129300,1.4233451,3.8854494,,,,,,,,,,,,,, -129400,1.5393461,2.5598335,,,,,,,,,,,,,, -129500,1.2958595,4.9650536,,,,,,,,,,,,,, -129600,1.6350197,2.6211185,,,,,,,,,,,,,, -129700,1.643991,2.6123283,,,,,,,,,,,,,, -129719,,,0.6342382431030273,1.546700358390808,0.5904600024223328,1.7574831247329712,50000.0,0.4695000350475311,2.420825242996216,10000.0,60117.41744160652,67597.53349399567,60117.41744160652,7466.64894080162,6.103210926055908,0.0 -129800,1.5932066,2.7029166,,,,,,,,,,,,,, -129900,1.644967,2.5384336,,,,,,,,,,,,,, -130000,1.4438529,2.6770642,,,,,,,,,,,,,, -130100,1.342717,3.2491975,,,,,,,,,,,,,, -130200,1.5218481,2.987136,,,,,,,,,,,,,, -130300,1.3964386,5.1875963,,,,,,,,,,,,,, -130400,1.5393679,2.5439565,,,,,,,,,,,,,, -130500,1.3208568,5.041627,,,,,,,,,,,,,, -130600,1.5027401,4.3656077,,,,,,,,,,,,,, -130626,,,0.6511132717132568,1.4745047092437744,0.5997399687767029,1.7100039720535278,50000.0,0.4785000085830688,2.387399196624756,10000.0,60537.58105373383,68068.39860129356,60537.58105373383,7517.243643760681,6.158584356307983,0.0 -130700,1.6235815,2.571911,,,,,,,,,,,,,, -130800,1.2391062,4.007875,,,,,,,,,,,,,, -130900,1.5322076,2.4360712,,,,,,,,,,,,,, -131000,1.3440816,3.3356512,,,,,,,,,,,,,, -131100,1.5867058,2.5011532,,,,,,,,,,,,,, -131200,1.7104642,2.5896385,,,,,,,,,,,,,, -131300,1.6608859,2.4697275,,,,,,,,,,,,,, -131400,1.5293484,2.732134,,,,,,,,,,,,,, -131500,1.5610082,2.5219016,,,,,,,,,,,,,, -131535,,,0.6660742163658142,1.3897231817245483,0.6065599918365479,1.6731947660446167,50000.0,0.4800000190734863,2.348107099533081,10000.0,60957.56353855133,68542.42744445801,60957.56353855133,7571.182153224945,6.214536190032959,0.0 -131600,1.284667,4.8816867,,,,,,,,,,,,,, -131700,1.287044,4.885065,,,,,,,,,,,,,, -131800,1.2533178,3.7167892,,,,,,,,,,,,,, -131900,1.43015,3.235951,,,,,,,,,,,,,, -132000,1.4888836,4.315866,,,,,,,,,,,,,, -132100,1.5752513,2.3863895,,,,,,,,,,,,,, -132200,1.4776499,2.8563662,,,,,,,,,,,,,, -132300,1.5631028,4.3471866,,,,,,,,,,,,,, -132400,1.5191559,2.5250993,,,,,,,,,,,,,, -132442,,,0.6496874690055847,1.4811943769454956,0.6088399887084961,1.6800014972686768,50000.0,0.48130002617836,2.3436646461486816,10000.0,61377.6965944767,69013.43845295906,61377.6965944767,7621.960869312286,6.263074159622192,0.0 -132500,1.6715009,2.5051425,,,,,,,,,,,,,, -132600,1.291851,4.0207367,,,,,,,,,,,,,, -132700,1.6306106,2.7144725,,,,,,,,,,,,,, -132800,1.8104471,2.6464267,,,,,,,,,,,,,, -132900,1.7113155,4.8546743,,,,,,,,,,,,,, -133000,1.5040343,2.5408945,,,,,,,,,,,,,, -133100,1.5983611,2.4969711,,,,,,,,,,,,,, -133200,1.7811561,2.4973607,,,,,,,,,,,,,, -133300,1.6347953,2.4896035,,,,,,,,,,,,,, -133349,,,0.6616796851158142,1.4510056972503662,0.6108399629592896,1.6753292083740234,50000.0,0.4870000183582306,2.3509786128997803,10000.0,61797.79363250733,69487.23033475876,61797.79363250733,7675.555843830109,6.311018705368042,0.0 -133400,1.6662469,2.6100001,,,,,,,,,,,,,, -133500,1.647234,2.4622877,,,,,,,,,,,,,, -133600,2.03646,2.811609,,,,,,,,,,,,,, -133700,1.6348554,2.380488,,,,,,,,,,,,,, -133800,1.4293201,3.891403,,,,,,,,,,,,,, -133900,1.4875581,2.729718,,,,,,,,,,,,,, -134000,1.5184187,2.4692929,,,,,,,,,,,,,, -134100,1.7896705,2.541752,,,,,,,,,,,,,, -134200,1.5142952,2.8834074,,,,,,,,,,,,,, -134257,,,0.6688085794448853,1.3909156322479248,0.6072199940681458,1.6612194776535034,50000.0,0.4892000257968902,2.3225879669189453,10000.0,62217.71622133255,69961.19663286209,62217.71622133255,7729.499813079834,6.358997106552124,0.0 -134300,1.5730577,2.596593,,,,,,,,,,,,,, -134400,1.5575283,2.3004677,,,,,,,,,,,,,, -134500,1.6359445,2.3007383,,,,,,,,,,,,,, -134600,1.5690213,2.4875994,,,,,,,,,,,,,, -134700,1.444223,4.0679183,,,,,,,,,,,,,, -134800,1.5801456,2.423832,,,,,,,,,,,,,, -134900,1.6540456,2.7007182,,,,,,,,,,,,,, -135000,1.4891901,3.1831772,,,,,,,,,,,,,, -135100,1.509322,2.6794133,,,,,,,,,,,,,, -135162,,,0.6540820002555847,1.4487427473068235,0.6113799810409546,1.649997591972351,50000.0,0.4907000362873077,2.3041481971740723,10000.0,62637.634321689606,70435.33058691025,62637.634321689606,7783.615670204163,6.407909154891968,0.0 -135200,1.5118184,4.337526,,,,,,,,,,,,,, -135300,1.7288069,2.4430604,,,,,,,,,,,,,, -135400,1.7543945,2.400717,,,,,,,,,,,,,, -135500,1.4531074,4.847916,,,,,,,,,,,,,, -135600,1.7459046,2.437108,,,,,,,,,,,,,, -135700,1.4514463,4.6916375,,,,,,,,,,,,,, -135800,1.6410698,2.752921,,,,,,,,,,,,,, -135900,1.5797185,2.2510664,,,,,,,,,,,,,, -136000,1.4521573,3.0178025,,,,,,,,,,,,,, -136068,,,0.6649999618530273,1.3921610116958618,0.6185199618339539,1.6188608407974243,50000.0,0.4901000261306762,2.3076579570770264,10000.0,63057.867579460144,70906.8068845272,63057.867579460144,7834.759024858475,6.456115007400513,0.0 -136100,1.6887629,2.531054,,,,,,,,,,,,,, -136200,1.3449137,4.031869,,,,,,,,,,,,,, -136300,1.4714351,4.8595047,,,,,,,,,,,,,, -136400,1.4934357,4.9203844,,,,,,,,,,,,,, -136500,1.7225342,2.4010687,,,,,,,,,,,,,, -136600,1.4728435,2.986468,,,,,,,,,,,,,, -136700,1.442044,2.6282973,,,,,,,,,,,,,, -136800,1.5649948,3.7773435,,,,,,,,,,,,,, -136900,1.5521036,2.3279274,,,,,,,,,,,,,, -136976,,,0.6775000095367432,1.342071533203125,0.6184399724006653,1.606795072555542,50000.0,0.496800035238266,2.2707793712615967,10000.0,63478.281381607056,71378.2106127739,63478.281381607056,7885.648756742477,6.504778623580933,0.0 -137000,1.5073056,4.4983454,,,,,,,,,,,,,, -137100,1.7072601,2.2415862,,,,,,,,,,,,,, -137200,1.7198741,2.4709518,,,,,,,,,,,,,, -137300,1.8678508,4.9796643,,,,,,,,,,,,,, -137400,1.6807735,2.414762,,,,,,,,,,,,,, -137500,1.4928463,3.8181236,,,,,,,,,,,,,, -137600,1.5834821,2.734689,,,,,,,,,,,,,, -137700,1.8846732,2.305181,,,,,,,,,,,,,, -137800,1.5406013,2.4563427,,,,,,,,,,,,,, -137884,,,0.6682812571525574,1.403730034828186,0.6224200129508972,1.6151797771453855,50000.0,0.4948000311851501,2.288226366043091,10000.0,63898.23468732834,71849.86646604538,63898.23468732834,7936.875230550766,6.928737640380859,0.0 -137900,1.5624415,2.90876,,,,,,,,,,,,,, -138000,1.3548031,3.543552,,,,,,,,,,,,,, -138100,1.4889809,4.184241,,,,,,,,,,,,,, -138200,1.609407,2.4540482,,,,,,,,,,,,,, -138300,1.7840366,2.5664139,,,,,,,,,,,,,, -138400,1.8855877,2.2886403,,,,,,,,,,,,,, -138500,1.4225212,4.080186,,,,,,,,,,,,,, -138600,1.4927883,4.41941,,,,,,,,,,,,,, -138700,1.8865626,2.4635253,,,,,,,,,,,,,, -138791,,,0.6753710508346558,1.351009488105774,0.6233800053596497,1.5772992372512815,50000.0,0.4999000132083893,2.231370687484741,10000.0,64318.29056835175,72320.50111722946,64318.29056835175,7987.347238540649,6.983047246932983,0.0 -138800,1.5136112,3.062029,,,,,,,,,,,,,, -138900,1.4284322,4.185046,,,,,,,,,,,,,, -139000,1.4391273,2.9482193,,,,,,,,,,,,,, -139100,1.6924449,2.2970896,,,,,,,,,,,,,, -139200,1.7204555,2.348158,,,,,,,,,,,,,, -139300,1.5456004,2.7781126,,,,,,,,,,,,,, -139400,1.824477,2.3685837,,,,,,,,,,,,,, -139500,1.8252702,2.4011834,,,,,,,,,,,,,, -139600,1.4897115,3.7061505,,,,,,,,,,,,,, -139699,,,0.6874608993530273,1.287361741065979,0.6303799748420715,1.5501196384429932,50000.0,0.5088000297546387,2.2103207111358643,10000.0,64738.50710082054,72792.97579598427,64738.50710082054,8039.49192738533,7.045411348342896,0.0 -139700,1.834379,2.4346108,,,,,,,,,,,,,, -139800,1.7723619,2.3275588,,,,,,,,,,,,,, -139900,1.736288,3.233114,,,,,,,,,,,,,, -140000,1.749982,2.1997688,,,,,,,,,,,,,, -140100,2.0723286,2.259914,,,,,,,,,,,,,, -140200,1.6812576,4.8048153,,,,,,,,,,,,,, -140300,1.723279,2.3851602,,,,,,,,,,,,,, -140400,1.666194,3.37798,,,,,,,,,,,,,, -140500,1.7582083,2.3921611,,,,,,,,,,,,,, -140600,1.5338575,3.1403017,,,,,,,,,,,,,, -140606,,,0.684277355670929,1.3224250078201294,0.6326999664306641,1.5607450008392334,50000.0,0.5029000043869019,2.2430734634399414,10000.0,65158.44564986229,73265.60011100769,65158.44564986229,8092.068526983261,7.1019439697265625,0.0 -140700,1.4931309,3.3987608,,,,,,,,,,,,,, -140800,1.7481207,2.2745838,,,,,,,,,,,,,, -140900,1.5365741,2.9889898,,,,,,,,,,,,,, -141000,1.7102022,2.3324664,,,,,,,,,,,,,, -141100,1.6491876,2.73485,,,,,,,,,,,,,, -141200,1.559939,3.7906055,,,,,,,,,,,,,, -141300,1.8257126,2.490667,,,,,,,,,,,,,, -141400,1.6680026,4.1612644,,,,,,,,,,,,,, -141500,1.6892961,2.3860044,,,,,,,,,,,,,, -141514,,,0.6889257431030273,1.2937071323394775,0.6418200135231018,1.511349320411682,50000.0,0.5126000046730042,2.173495054244995,10000.0,65578.51907277107,73737.41035723686,65578.51907277107,8143.69774889946,7.15775990486145,0.0 -141600,1.6152148,4.6672463,,,,,,,,,,,,,, -141700,1.5853506,4.4697285,,,,,,,,,,,,,, -141800,1.6884745,2.210478,,,,,,,,,,,,,, -141900,1.9762424,2.3100364,,,,,,,,,,,,,, -142000,1.8552704,2.3538864,,,,,,,,,,,,,, -142100,1.7357035,4.3951645,,,,,,,,,,,,,, -142200,1.8804251,2.1751258,,,,,,,,,,,,,, -142300,1.6684831,2.2539024,,,,,,,,,,,,,, -142400,1.627749,2.840238,,,,,,,,,,,,,, -142421,,,0.6948828101158142,1.2586766481399536,0.6380000114440918,1.5189902782440186,50000.0,0.513700008392334,2.185518980026245,10000.0,65998.46021318436,74208.82389330864,65998.46021318436,8195.071061849594,7.205103397369385,0.0 -142500,1.8502815,2.34658,,,,,,,,,,,,,, -142600,1.7714138,2.7084377,,,,,,,,,,,,,, -142700,1.6171627,2.8336287,,,,,,,,,,,,,, -142800,1.598431,3.4949474,,,,,,,,,,,,,, -142900,1.8386316,2.2230866,,,,,,,,,,,,,, -143000,1.7688336,2.2113404,,,,,,,,,,,,,, -143100,1.7888885,2.3691325,,,,,,,,,,,,,, -143200,2.1424131,2.4229846,,,,,,,,,,,,,, -143300,1.7057849,2.8331559,,,,,,,,,,,,,, -143326,,,0.7229296565055847,1.147066593170166,0.6485999822616577,1.4839391708374023,50000.0,0.5218000411987305,2.157728672027588,10000.0,66418.74589681625,74681.48336172104,66418.74589681625,8247.3456325531,7.25283932685852,0.0 -143400,1.7891597,2.88596,,,,,,,,,,,,,, -143500,1.7713939,2.3540852,,,,,,,,,,,,,, -143600,1.9100674,2.3761628,,,,,,,,,,,,,, -143700,2.0901346,2.3366964,,,,,,,,,,,,,, -143800,1.7114661,2.7502248,,,,,,,,,,,,,, -143900,1.5382645,3.8855119,,,,,,,,,,,,,, -144000,1.73301,4.1893697,,,,,,,,,,,,,, -144100,1.7034364,3.0950196,,,,,,,,,,,,,, -144200,1.7945614,2.3607974,,,,,,,,,,,,,, -144235,,,0.6969531178474426,1.259919285774231,0.6465399861335754,1.4953464269638062,50000.0,0.5186000466346741,2.162790298461914,10000.0,66839.10083460808,75155.56358337402,66839.10083460808,8300.968694925308,7.303514003753662,0.0 -144300,1.792799,2.2014315,,,,,,,,,,,,,, -144400,1.6759329,4.209196,,,,,,,,,,,,,, -144500,1.9116735,2.188647,,,,,,,,,,,,,, -144600,1.6547608,3.6616747,,,,,,,,,,,,,, -144700,1.721803,3.30392,,,,,,,,,,,,,, -144800,1.6294502,4.808677,,,,,,,,,,,,,, -144900,1.8861551,2.2453246,,,,,,,,,,,,,, -145000,1.58886,4.7884617,,,,,,,,,,,,,, -145100,1.7826552,2.092377,,,,,,,,,,,,,, -145143,,,0.6981835961341858,1.2710505723953247,0.6437000036239624,1.523866057395935,50000.0,0.5236000418663025,2.1719541549682617,10000.0,67259.42416167259,75627.54675364494,67259.42416167259,8352.526899814606,7.353893756866455,0.0 -145200,1.950713,2.2638261,,,,,,,,,,,,,, -145300,1.87832,2.5330186,,,,,,,,,,,,,, -145400,2.1194155,2.2825415,,,,,,,,,,,,,, -145500,1.6126113,3.200988,,,,,,,,,,,,,, -145600,1.6990962,3.806774,,,,,,,,,,,,,, -145700,2.519731,4.9743843,,,,,,,,,,,,,, -145800,1.9078277,2.2680898,,,,,,,,,,,,,, -145900,1.923761,2.3587613,,,,,,,,,,,,,, -146000,1.8920732,2.3370523,,,,,,,,,,,,,, -146049,,,0.7145116925239563,1.1924309730529783,0.6484400033950806,1.4934139251708984,50000.0,0.523300051689148,2.157883644104004,10000.0,67679.52076888084,76099.40624260902,67679.52076888084,8404.179302930832,7.413357973098755,0.0 -146100,1.8262388,2.6746435,,,,,,,,,,,,,, -146200,1.6461009,3.3365214,,,,,,,,,,,,,, -146300,1.9027369,2.5978007,,,,,,,,,,,,,, -146400,1.768361,3.7807448,,,,,,,,,,,,,, -146500,1.7561687,4.5012975,,,,,,,,,,,,,, -146600,1.911845,2.1065896,,,,,,,,,,,,,, -146700,2.1700284,2.2803242,,,,,,,,,,,,,, -146800,2.0308056,2.2370262,,,,,,,,,,,,,, -146900,1.8880514,2.138938,,,,,,,,,,,,,, -146958,,,0.7075585722923279,1.199978590011597,0.655239999294281,1.4392666816711426,50000.0,0.530500054359436,2.0883536338806152,10000.0,68099.78795552254,76573.15300369263,68099.78795552254,8457.554328680038,7.466028690338135,0.0 -147000,1.6530554,3.3578184,,,,,,,,,,,,,, -147100,1.6843035,4.720986,,,,,,,,,,,,,, -147200,1.7156745,3.3134463,,,,,,,,,,,,,, -147300,1.9060256,2.321216,,,,,,,,,,,,,, -147400,2.0578744,2.369176,,,,,,,,,,,,,, -147500,2.0991514,2.248826,,,,,,,,,,,,,, -147600,1.9159786,2.2475026,,,,,,,,,,,,,, -147700,2.1231992,2.21502,,,,,,,,,,,,,, -147800,1.9791642,2.1733162,,,,,,,,,,,,,, -147867,,,0.7089257836341858,1.191853165626526,0.6549599766731262,1.450493097305298,50000.0,0.5245000123977661,2.122317314147949,10000.0,68519.88999271393,77044.57645773888,68519.88999271393,8508.766267299652,7.52376914024353,0.0 -147900,2.1943288,2.4296322,,,,,,,,,,,,,, -148000,1.6693664,4.627697,,,,,,,,,,,,,, -148100,1.9041791,2.1427875,,,,,,,,,,,,,, -148200,1.9465266,2.1151118,,,,,,,,,,,,,, -148300,1.6971492,3.7053652,,,,,,,,,,,,,, -148400,1.9995399,2.370642,,,,,,,,,,,,,, -148500,1.9302045,4.643097,,,,,,,,,,,,,, -148600,1.8273166,4.6970377,,,,,,,,,,,,,, -148700,1.7327056,2.9391992,,,,,,,,,,,,,, -148775,,,0.72328120470047,1.1626847982406616,0.656279981136322,1.4608744382858276,50000.0,0.5313000082969666,2.1140568256378174,10000.0,68940.13805937767,77517.2398583889,68940.13805937767,8561.076631069183,7.576759338378906,0.0 -148800,1.9209937,2.2535145,,,,,,,,,,,,,, -148900,1.8933706,2.686389,,,,,,,,,,,,,, -149000,2.2740932,2.3002722,,,,,,,,,,,,,, -149100,1.7866561,3.4352486,,,,,,,,,,,,,, -149200,2.042819,2.223895,,,,,,,,,,,,,, -149300,2.0513802,2.1230087,,,,,,,,,,,,,, -149400,1.8032545,3.8160322,,,,,,,,,,,,,, -149500,2.1526039,2.0866923,,,,,,,,,,,,,, -149600,1.780078,3.295458,,,,,,,,,,,,,, -149682,,,0.71742182970047,1.170621037483215,0.6571800112724304,1.4194693565368652,50000.0,0.5441000461578369,2.061110496520996,10000.0,69360.1499080658,77988.87147164345,69360.1499080658,8612.59052824974,7.63094162940979,0.0 -149700,1.7827487,3.580517,,,,,,,,,,,,,, -149800,1.7198095,3.8744533,,,,,,,,,,,,,, -149900,2.2224212,4.4950204,,,,,,,,,,,,,, -150000,2.0949578,2.8228068,,,,,,,,,,,,,, -150100,2.200981,2.1372433,,,,,,,,,,,,,, -150200,2.0754197,2.262901,,,,,,,,,,,,,, -150300,2.0907955,4.4057407,,,,,,,,,,,,,, -150400,2.1583805,2.1626425,,,,,,,,,,,,,, -150500,2.121011,2.4467206,,,,,,,,,,,,,, -150589,,,0.7193945050239563,1.145675778388977,0.6657599806785583,1.3892215490341189,50000.0,0.5406000018119812,2.0501999855041504,10000.0,69780.04849529266,78460.87410640717,69780.04849529266,8664.59189748764,7.682143688201904,0.0 -150600,1.983644,2.1145592,,,,,,,,,,,,,, -150700,2.3469644,2.27632,,,,,,,,,,,,,, -150800,1.9460424,4.0147705,,,,,,,,,,,,,, -150900,2.066785,2.2023754,,,,,,,,,,,,,, -151000,1.8381307,3.5470014,,,,,,,,,,,,,, -151100,2.1567044,2.1746898,,,,,,,,,,,,,, -151200,2.112313,2.1998978,,,,,,,,,,,,,, -151300,2.322081,4.7134132,,,,,,,,,,,,,, -151400,2.0364392,2.0505772,,,,,,,,,,,,,, -151496,,,0.7369726300239563,1.0826455354690552,0.6676999926567078,1.3798246383666992,50000.0,0.5424000024795532,2.038510799407959,10000.0,70200.33002853394,78933.96350884438,70200.33002853394,8717.290427207947,7.740187168121338,0.0 -151500,1.9394736,2.2615604,,,,,,,,,,,,,, -151600,2.0974042,2.0080628,,,,,,,,,,,,,, -151700,2.0431035,4.6288476,,,,,,,,,,,,,, -151800,1.8074874,3.5415823,,,,,,,,,,,,,, -151900,2.069681,4.4697657,,,,,,,,,,,,,, -152000,1.8519441,3.687204,,,,,,,,,,,,,, -152100,1.8344176,3.1112776,,,,,,,,,,,,,, -152200,2.4236922,2.0861187,,,,,,,,,,,,,, -152300,1.8215423,3.2947242,,,,,,,,,,,,,, -152400,2.173535,2.0998762,,,,,,,,,,,,,, -152404,,,0.7234960794448853,1.141112208366394,0.6663999557495117,1.381853461265564,50000.0,0.5423000454902649,2.045349597930908,10000.0,70620.6611096859,79407.156021595,70620.6611096859,8770.050068855286,7.791531801223755,0.0 -152500,2.152064,2.231495,,,,,,,,,,,,,, -152600,2.1298144,2.163743,,,,,,,,,,,,,, -152700,1.9689649,4.016792,,,,,,,,,,,,,, -152800,1.8714914,2.9231937,,,,,,,,,,,,,, -152900,1.907445,2.8098783,,,,,,,,,,,,,, -153000,2.1024203,2.0358179,,,,,,,,,,,,,, -153100,2.1878276,2.1587987,,,,,,,,,,,,,, -153200,2.1097815,2.2379067,,,,,,,,,,,,,, -153300,2.0964735,4.2007627,,,,,,,,,,,,,, -153312,,,0.7329882383346558,1.0937864780426023,0.6733799576759338,1.3574568033218384,50000.0,0.5484000444412231,2.018724203109741,10000.0,71040.6793308258,79880.49059653282,71040.6793308258,8823.263661623001,7.842597007751465,0.0 -153400,1.9257064,3.0023909,,,,,,,,,,,,,, -153500,1.9803929,3.0667853,,,,,,,,,,,,,, -153600,2.017529,2.4951317,,,,,,,,,,,,,, -153700,2.1179216,2.1943848,,,,,,,,,,,,,, -153800,1.9241267,3.8999622,,,,,,,,,,,,,, -153900,2.0789013,3.056052,,,,,,,,,,,,,, -154000,2.085763,2.714418,,,,,,,,,,,,,, -154100,2.252403,2.715375,,,,,,,,,,,,,, -154200,2.367344,2.0739083,,,,,,,,,,,,,, -154220,,,0.7409570217132568,1.0447590351104736,0.6767799854278564,1.337766408920288,50000.0,0.5509000420570374,1.9991964101791384,10000.0,71460.68648219109,80351.2955019474,71460.68648219109,8873.951783180237,7.901000738143921,0.0 -154300,2.150839,4.614453,,,,,,,,,,,,,, -154400,2.3067038,2.0641117,,,,,,,,,,,,,, -154500,2.278844,1.9721758,,,,,,,,,,,,,, -154600,2.4733024,2.208001,,,,,,,,,,,,,, -154700,2.691675,1.9838327,,,,,,,,,,,,,, -154800,2.300028,2.0676768,,,,,,,,,,,,,, -154900,2.4667687,4.606698,,,,,,,,,,,,,, -155000,2.126911,2.2267728,,,,,,,,,,,,,, -155100,2.3719437,1.9234129,,,,,,,,,,,,,, -155127,,,0.7374804615974426,1.092585206031799,0.6778199672698975,1.3446279764175415,50000.0,0.5489000082015991,2.018172025680542,10000.0,71880.62465643883,80825.57174110413,71880.62465643883,8928.184937000275,7.954378366470337,0.0 -155200,2.2020776,3.5910425,,,,,,,,,,,,,, -155300,2.2010002,2.088673,,,,,,,,,,,,,, -155400,2.2604866,2.0275464,,,,,,,,,,,,,, -155500,2.433448,1.9910095,,,,,,,,,,,,,, -155600,2.2594454,2.0349436,,,,,,,,,,,,,, -155700,2.2262762,2.0712826,,,,,,,,,,,,,, -155800,2.4028442,1.9233097,,,,,,,,,,,,,, -155900,2.0646663,2.9028707,,,,,,,,,,,,,, -156000,2.1088903,3.8099031,,,,,,,,,,,,,, -156033,,,0.7378515601158142,1.0726988315582275,0.6809799671173096,1.3264795541763306,50000.0,0.5550000071525574,1.985244870185852,10000.0,72300.71397519112,81296.80072975159,72300.71397519112,8979.215026378632,8.012673616409302,0.0 -156100,2.0787013,2.3243518,,,,,,,,,,,,,, -156200,2.2871058,2.7206345,,,,,,,,,,,,,, -156300,2.1931064,1.9354465,,,,,,,,,,,,,, -156400,2.3201158,2.03609,,,,,,,,,,,,,, -156500,2.2952533,2.1090307,,,,,,,,,,,,,, -156600,2.3317876,1.9296691,,,,,,,,,,,,,, -156700,2.3909564,4.5323777,,,,,,,,,,,,,, -156800,2.1485155,3.4074788,,,,,,,,,,,,,, -156900,2.3046365,1.8111625,,,,,,,,,,,,,, -156942,,,0.7446679472923279,1.0540313720703125,0.6827600002288818,1.3250590562820437,50000.0,0.5562000274658203,1.9805034399032595,10000.0,72720.9410059452,81767.88509273529,72720.9410059452,9029.969490528109,8.063661575317383,0.0 -157000,2.4423485,1.9745687,,,,,,,,,,,,,, -157100,2.1160853,2.7979612,,,,,,,,,,,,,, -157200,2.152705,3.1855435,,,,,,,,,,,,,, -157300,2.538051,1.9946853,,,,,,,,,,,,,, -157400,2.3587778,1.8876276,,,,,,,,,,,,,, -157500,2.5098813,1.8843874,,,,,,,,,,,,,, -157600,2.015323,2.4211597,,,,,,,,,,,,,, -157700,2.3247077,1.9003569,,,,,,,,,,,,,, -157800,2.558466,1.9890547,,,,,,,,,,,,,, -157851,,,0.7596679329872131,0.9815232753753662,0.6845600008964539,1.304532527923584,50000.0,0.5621000528335571,1.955943703651428,10000.0,73141.18209695816,82242.24174976349,73141.18209695816,9083.979422330856,8.117793560028076,0.0 -157900,2.5473564,4.5446243,,,,,,,,,,,,,, -158000,2.2248526,2.8227835,,,,,,,,,,,,,, -158100,2.2166271,2.136512,,,,,,,,,,,,,, -158200,2.1528409,3.2782726,,,,,,,,,,,,,, -158300,2.3083396,3.24197,,,,,,,,,,,,,, -158400,2.2164717,1.7766292,,,,,,,,,,,,,, -158500,2.4616501,4.4880705,,,,,,,,,,,,,, -158600,2.4114442,1.9998848,,,,,,,,,,,,,, -158700,2.4074025,1.9069502,,,,,,,,,,,,,, -158758,,,0.7476562261581421,1.031006932258606,0.6883599758148193,1.2879359722137451,50000.0,0.5637000203132629,1.9300532341003416,10000.0,73561.48490953445,82713.59108018875,73561.48490953445,9134.91816353798,8.174100875854492,0.0 -158800,2.2419004,1.9167953,,,,,,,,,,,,,, -158900,2.2719808,2.899387,,,,,,,,,,,,,, -159000,2.3434644,1.9807835,,,,,,,,,,,,,, -159100,2.4385517,1.8622725,,,,,,,,,,,,,, -159200,2.270816,2.6050184,,,,,,,,,,,,,, -159300,2.295737,2.5466847,,,,,,,,,,,,,, -159400,2.3608644,1.9714379,,,,,,,,,,,,,, -159500,2.4537082,1.8408829,,,,,,,,,,,,,, -159600,2.227524,1.8653822,,,,,,,,,,,,,, -159669,,,0.7578319907188416,0.9961916208267212,0.69159996509552,1.2772376537322998,50000.0,0.5711000561714172,1.9215903282165527,10000.0,73981.68310594559,83184.31155920029,73981.68310594559,9185.332689285278,8.230481386184692,0.0 -159700,2.2896693,2.1534517,,,,,,,,,,,,,, -159800,2.5616117,4.4539604,,,,,,,,,,,,,, -159900,2.5511684,1.8597274,,,,,,,,,,,,,, -160000,2.256242,3.557148,,,,,,,,,,,,,, -160100,2.4971023,4.2556977,,,,,,,,,,,,,, -160200,2.5430884,2.0051892,,,,,,,,,,,,,, -160300,2.493257,1.9624426,,,,,,,,,,,,,, -160400,2.561493,2.4329555,,,,,,,,,,,,,, -160500,2.5089934,1.7180631,,,,,,,,,,,,,, -160576,,,0.7697070240974426,0.938443958759308,0.6931599974632263,1.2678637504577637,50000.0,0.5687000155448914,1.9138787984848025,10000.0,74401.90697979927,83659.51302075386,74401.90697979927,9240.204246282578,8.28437876701355,0.0 -160600,2.3816257,1.8925936,,,,,,,,,,,,,, -160700,2.4937098,1.8741461,,,,,,,,,,,,,, -160800,2.5779705,2.4929414,,,,,,,,,,,,,, -160900,2.5244906,2.512127,,,,,,,,,,,,,, -161000,2.3637683,2.1841908,,,,,,,,,,,,,, -161100,3.0188816,2.014144,,,,,,,,,,,,,, -161200,2.4601934,1.9446462,,,,,,,,,,,,,, -161300,2.437053,1.9069358,,,,,,,,,,,,,, -161400,2.734517,4.2634554,,,,,,,,,,,,,, -161485,,,0.7563671469688416,0.976298451423645,0.6957199573516846,1.2451553344726562,50000.0,0.5685999989509583,1.8960278034210205,10000.0,74822.10788369179,84130.91235041618,74822.10788369179,9291.293983697891,8.341437816619873,0.0 -161500,2.2637286,2.7313116,,,,,,,,,,,,,, -161600,2.4443324,2.6130576,,,,,,,,,,,,,, -161700,2.5248313,1.8068367,,,,,,,,,,,,,, -161800,2.5478144,4.3571672,,,,,,,,,,,,,, -161900,2.499083,2.932363,,,,,,,,,,,,,, -162000,2.4111013,2.2374358,,,,,,,,,,,,,, -162100,2.668741,1.8452895,,,,,,,,,,,,,, -162200,2.444939,2.8183997,,,,,,,,,,,,,, -162300,2.4156168,1.9081509,,,,,,,,,,,,,, -162394,,,0.76527339220047,0.9461551308631896,0.7001199722290039,1.229008674621582,50000.0,0.5740000009536743,1.8750234842300413,10000.0,75242.05846524239,84602.79054522514,75242.05846524239,9343.114990472794,8.395569086074829,0.0 -162400,2.3685856,2.9918969,,,,,,,,,,,,,, -162500,2.8302329,1.9171865,,,,,,,,,,,,,, -162600,2.5497398,1.6129284,,,,,,,,,,,,,, -162700,2.516991,3.6373036,,,,,,,,,,,,,, -162800,2.6993766,2.3523636,,,,,,,,,,,,,, -162900,2.54145,1.8167391,,,,,,,,,,,,,, -163000,2.4004686,2.1445246,,,,,,,,,,,,,, -163100,2.7150826,1.7975211,,,,,,,,,,,,,, -163200,2.6546822,1.7900583,,,,,,,,,,,,,, -163300,,,0.7712695002555847,0.9198420643806458,0.701259970664978,1.2253583669662476,50000.0,0.5771000385284424,1.875489592552185,10000.0,75661.95459794998,85074.48866915703,75661.95459794998,9394.80528140068,8.456037521362305,0.0 -163300,2.7909706,1.8260295,,,,,,,,,,,,,, -163400,2.6616178,1.8341684,,,,,,,,,,,,,, -163500,2.6430647,1.9364582,,,,,,,,,,,,,, -163600,2.697217,2.1564798,,,,,,,,,,,,,, -163700,2.6426353,1.927168,,,,,,,,,,,,,, -163800,2.4947383,3.5069535,,,,,,,,,,,,,, -163900,2.8662338,1.9504853,,,,,,,,,,,,,, -164000,2.612474,1.8718256,,,,,,,,,,,,,, -164100,3.175332,1.8664647,,,,,,,,,,,,,, -164200,2.7010338,1.8151662,,,,,,,,,,,,,, -164208,,,0.7705273032188416,0.9230221509933472,0.7032399773597717,1.2151079177856443,50000.0,0.5770000219345093,1.8578152656555176,10000.0,76082.06962704659,85547.03738641739,76082.06962704659,9447.132781267166,8.510085105895996,0.0 -164300,2.8579671,1.813921,,,,,,,,,,,,,, -164400,2.8036027,1.83797,,,,,,,,,,,,,, -164500,2.9004142,1.8730032,,,,,,,,,,,,,, -164600,2.8067534,1.7760502,,,,,,,,,,,,,, -164700,2.5516033,2.0029435,,,,,,,,,,,,,, -164800,2.718566,1.8436463,,,,,,,,,,,,,, -164900,2.5217376,3.0960178,,,,,,,,,,,,,, -165000,2.7357037,1.8354397,,,,,,,,,,,,,, -165100,2.737277,1.8511415,,,,,,,,,,,,,, -165117,,,0.7754296660423279,0.913085401058197,0.7068799734115601,1.209192991256714,50000.0,0.5834000110626221,1.851784825325012,10000.0,76502.12402510643,86020.86357402802,76502.12402510643,9500.79858636856,8.563549280166626,0.0 -165200,2.7359266,2.908489,,,,,,,,,,,,,, -165300,2.808965,4.059792,,,,,,,,,,,,,, -165400,2.7054698,1.777106,,,,,,,,,,,,,, -165500,2.8122025,1.7907727,,,,,,,,,,,,,, -165600,2.8492231,1.855508,,,,,,,,,,,,,, -165700,2.7697763,1.8793701,,,,,,,,,,,,,, -165800,2.8571818,1.7058129,,,,,,,,,,,,,, -165900,3.122035,1.868453,,,,,,,,,,,,,, -166000,2.6721096,2.1238613,,,,,,,,,,,,,, -166022,,,0.7830663919448853,0.8687943816184998,0.7093600034713745,1.190772533416748,50000.0,0.5837000012397766,1.840895652770996,10000.0,76922.23614120483,86491.56161642075,76922.23614120483,9551.27962064743,8.616734027862549,0.0 -166100,2.6617556,2.298531,,,,,,,,,,,,,, -166200,3.022453,1.6680272,,,,,,,,,,,,,, -166300,2.8034298,1.7764908,,,,,,,,,,,,,, -166400,2.6622493,1.7306476,,,,,,,,,,,,,, -166500,2.6860225,3.6891875,,,,,,,,,,,,,, -166600,2.8795717,1.7171819,,,,,,,,,,,,,, -166700,2.5011277,2.468461,,,,,,,,,,,,,, -166800,3.1282969,1.744924,,,,,,,,,,,,,, -166900,2.5685246,1.7129264,,,,,,,,,,,,,, -166928,,,0.7832421660423279,0.871487557888031,0.7151399850845337,1.162901520729065,50000.0,0.5940000414848328,1.7914539575576782,10000.0,77342.32014989853,86963.95547676086,77342.32014989853,9603.473745822906,8.681183338165283,0.0 -167000,2.6177428,2.973693,,,,,,,,,,,,,, -167100,3.3569672,3.5837333,,,,,,,,,,,,,, -167200,2.7798228,2.204937,,,,,,,,,,,,,, -167300,2.6408753,3.658957,,,,,,,,,,,,,, -167317,,,,,,,,,,,77520.20170688629,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 0a6a9a1e1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -31.48529624938965,0.0,41.93090534210205,1,0,41.93090534210205,0.0010000000474974,6.907756805419922,10000,73.41631269454956,0.0009179687476716,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -83.17005729675293,0.0171785354614257,461.9541461467743,839,0,461.9541461467743,0.0220000017434358,6.133785247802734,10000,545.1891384124756,0.0321679674088954,5.976029396057129,0.0268599987030029,6.033750534057617,50000 -132.39385414123535,0.0436456203460693,881.9554603099823,1741,0,881.9554603099823,0.0499000027775764,5.561094284057617,10000,1014.4922845363616,0.0686132833361625,5.289651393890381,0.0642199963331222,5.34092903137207,50000 -183.4759097099304,0.0698153972625732,1302.169753074646,2649,0,1302.169753074646,0.0838000029325485,5.196085453033447,10000,1485.8664710521698,0.114414058625698,4.848740100860596,0.1076399981975555,4.910977840423584,50000 -233.9338686466217,0.1003086566925048,1722.0868241786957,3554,0,1722.0868241786957,0.1200000047683715,4.79691743850708,10000,1956.3246581554413,0.1730273365974426,4.282963752746582,0.1574199944734573,4.420511245727539,50000 -285.4160861968994,0.5835461616516113,2141.7024862766266,4448,0,2141.7024862766266,0.1438000053167343,4.610869884490967,10000,2427.956640481949,0.2078320235013961,4.076893329620361,0.1960600018501281,4.17407751083374,50000 -337.2035584449768,0.6109819412231445,2561.650346755981,5350,0,2561.650346755981,0.1876000016927719,4.200911998748779,10000,2899.76961684227,0.2707226574420929,3.5997939109802246,0.2500199973583221,3.711319923400879,50000 -388.6069040298462,0.6374530792236328,2981.769757270813,6257,0,2981.769757270813,0.2152000069618225,3.982329845428467,10000,3371.370764017105,0.3174609243869781,3.2546160221099854,0.2885800004005432,3.4442403316497803,50000 -441.1956124305725,0.6641831398010254,3401.7290391922,7165,0,3401.7290391922,0.2417000085115432,3.805813789367676,10000,3843.996475219727,0.3345507681369781,3.108985424041748,0.3141399919986725,3.2446765899658203,50000 -492.78946113586426,0.6920418739318848,3821.696554660797,8073,0,3821.696554660797,0.2471000105142593,3.786952495574951,10000,4315.6374979019165,0.3506445288658142,3.069853067398072,0.3208400011062622,3.2362983226776123,50000 -544.0257968902588,0.7220323085784912,4241.8689959049225,8981,0,4241.8689959049225,0.2645000219345093,3.641249418258667,10000,4787.127847909927,0.3797656297683716,2.877138614654541,0.3468599915504455,3.067584991455078,50000 -598.793160200119,0.7506968975067139,4662.201369285584,9886,0,4662.201369285584,0.2873000204563141,3.478476047515869,10000,5262.30820274353,0.4021289050579071,2.717653751373291,0.3776800036430359,2.877584934234619,50000 -650.6066122055054,0.7820405960083008,5082.514708995819,10791,0,5082.514708995819,0.299200028181076,3.382631063461304,10000,5734.52169752121,0.4259765446186065,2.5819032192230225,0.3936199843883514,2.756279706954956,50000 -703.0180859565735,0.8099539279937744,5502.804249048233,11694,0,5502.804249048233,0.317900002002716,3.3103625774383545,10000,6207.301434278488,0.4414648413658142,2.500446081161499,0.4053199887275696,2.709198474884033,50000 -754.8459165096283,0.8380134105682373,5923.055751800537,12599,0,5923.055751800537,0.3225000202655792,3.253435611724853,10000,6679.461454868317,0.4473632574081421,2.448283195495605,0.4174199998378753,2.6113107204437256,50000 -807.402941942215,0.8693232536315918,6343.135094165802,13500,0,6343.135094165802,0.3232000172138214,3.279659509658813,10000,7152.180879831314,0.4481445252895355,2.4669337272644043,0.4198599755764007,2.638878107070923,50000 -861.2038099765778,0.8967950344085693,6763.108623743057,14401,0,6763.108623743057,0.3260000050067901,3.225304126739502,10000,7626.033350944519,0.4692578017711639,2.391101121902466,0.4292799830436706,2.592332601547241,50000 -913.4623625278472,0.926544189453125,7183.137870073319,15307,0,7183.137870073319,0.3416000306606293,3.1499056816101074,10000,8098.40290927887,0.4732421636581421,2.3394665718078613,0.4375399947166443,2.532576084136963,50000 -966.138659477234,0.9530913829803468,7603.157238006592,16210,0,7603.157238006592,0.3473000228404999,3.0882656574249268,10000,8571.176964044571,0.4846288859844208,2.263720750808716,0.4474599957466125,2.450514554977417,50000 -1017.9731209278108,0.9839563369750975,8023.078520536423,17108,0,8023.078520536423,0.3574000298976898,3.0589606761932373,10000,9043.016461133957,0.5009179711341858,2.1943092346191406,0.4619999825954437,2.397297143936157,50000 -1071.4758758544922,1.0145676136016846,8443.216924190521,18007,0,8443.216924190521,0.3568000197410583,3.10998272895813,10000,9516.7394759655,0.5083202719688416,2.2235052585601807,0.4515199959278106,2.495116710662842,50000 -1123.613642454147,1.043529987335205,8863.401502132416,18907,0,8863.401502132416,0.3609000146389007,3.0243079662323,10000,9989.141452550888,0.5022070407867432,2.1813580989837646,0.4710799753665924,2.35910964012146,50000 -1178.4694809913635,1.0712995529174805,9283.423431873322,19804,0,9283.423431873322,0.3738000094890594,2.9481990337371826,10000,10464.098256111143,0.5236718654632568,2.076261043548584,0.4801599979400635,2.295961618423462,50000 -1230.8315978050232,1.1028954982757568,9703.366194963455,20704,0,9703.366194963455,0.3751000165939331,2.9266579151153564,10000,10936.4870698452,0.556835949420929,1.9286086559295648,0.4835599958896637,2.271160364151001,50000 -1283.7493817806244,1.1308624744415283,10123.38339304924,21605,0,10123.38339304924,0.3795000314712524,2.898561477661133,10000,11409.501055002213,0.5281640291213989,2.053582191467285,0.4886199831962585,2.251695156097412,50000 -1337.5087454319,1.162177801132202,10543.948052406313,22504,0,10543.948052406313,0.3874000310897827,2.876436948776245,10000,11883.90704703331,0.5364062190055847,1.9806296825408936,0.4958599805831909,2.1945180892944336,50000 -1391.6791398525238,1.190190076828003,10964.240460157394,23404,0,10964.240460157394,0.3874000310897827,2.8902664184570312,10000,12358.45002245903,0.5655078291893005,1.908821702003479,0.4991799890995025,2.2292399406433105,50000 -1443.191126346588,1.218409776687622,11384.598715543749,24300,0,11384.598715543749,0.3955000042915344,2.81775450706482,10000,12830.399197340012,0.5464257597923279,1.977408051490784,0.5089799761772156,2.1693193912506104,50000 -1495.6441519260406,1.2541024684906006,11804.667268276216,25200,0,11804.667268276216,0.4011000096797943,2.787417411804199,10000,13303.007089614868,0.5544726252555847,1.8977075815200808,0.5140599608421326,2.105320930480957,50000 -1550.408257484436,1.28757643699646,12224.890575885773,26100,0,12224.890575885773,0.4094000160694122,2.7280988693237305,10000,13778.07923269272,0.5863476395606995,1.7209587097167969,0.522059977054596,2.043886423110962,50000 -1602.5821635723114,1.3235411643981934,12645.1004447937,27000,0,12645.1004447937,0.4092000126838684,2.723639726638794,10000,14250.550843000412,0.5658202767372131,1.8472483158111568,0.5233199596405029,2.056349754333496,50000 -1653.192789554596,1.3615074157714844,13065.12490582466,27898,0,13065.12490582466,0.4037000238895416,2.790687322616577,10000,14721.279171228409,0.557812511920929,1.8963301181793213,0.5161399841308594,2.110172748565674,50000 -1705.023080587387,1.392759084701538,13485.352035045624,28795,0,13485.352035045624,0.4162000119686126,2.7257065773010254,10000,15193.418963193892,0.5855664014816284,1.7688456773757937,0.5272600054740906,2.059099674224853,50000 -1757.6186830997467,1.4242591857910156,13905.495992660522,29694,0,13905.495992660522,0.4116000235080719,2.7133548259735107,10000,15666.241523504255,0.5672265291213989,1.8366743326187127,0.526699960231781,2.0453226566314697,50000 -1809.639849185944,1.4553961753845217,14325.84366250038,30595,0,14325.84366250038,0.4234000146389007,2.6812961101531982,10000,16138.692564725876,0.5782226324081421,1.7867298126220703,0.5320000052452087,2.0195980072021484,50000 -1861.2554004192352,1.4885175228118896,14745.781195640564,31494,0,14745.781195640564,0.4247000217437744,2.6817445755004883,10000,16610.330486536026,0.5876171588897705,1.7332189083099363,0.5371800065040588,1.998450875282288,50000 -1912.60741186142,1.521423101425171,15165.913511514664,32397,0,15165.913511514664,0.4207000136375427,2.69540810585022,10000,17081.899649858475,0.572265625,1.8312280178070068,0.5342599749565125,2.031036615371704,50000 -1964.42044878006,1.557866096496582,15585.88121342659,33298,0,15585.88121342659,0.4212000072002411,2.652564764022827,10000,17553.76842737198,0.5862890481948853,1.739125370979309,0.5410199761390686,1.962300419807434,50000 -2015.772707223892,1.5926804542541504,16006.172748804092,34204,0,16006.172748804092,0.4263000190258026,2.6361184120178223,10000,18025.4995098114,0.5970116853713989,1.712268590927124,0.5445600152015686,1.9671906232833865,50000 -2066.1069979667664,1.6230809688568115,16426.362616062164,35105,0,16426.362616062164,0.4274000227451324,2.626077175140381,10000,18496.10616064072,0.5822460651397705,1.7583211660385132,0.5433599948883057,1.9646700620651243,50000 -2118.408214569092,1.6574127674102783,16846.550374031067,36006,0,16846.550374031067,0.433100014925003,2.605111837387085,10000,18968.680662870407,0.5903906226158142,1.75039541721344,0.546999990940094,1.9551115036010744,50000 -2169.996456384659,1.6937315464019775,17266.725713968277,36905,0,17266.725713968277,0.4347000122070312,2.613460063934326,10000,19440.53175663948,0.5983593463897705,1.7371487617492676,0.548259973526001,1.973673343658448,50000 -2220.849625825882,1.7289478778839111,17686.66410136223,37806,0,17686.66410136223,0.4330000281333923,2.59557580947876,10000,19911.408517360687,0.5934765338897705,1.7180927991867063,0.549560010433197,1.944177746772766,50000 -2272.119005203247,1.765486717224121,18106.65107870102,38706,0,18106.65107870102,0.4406000077724457,2.6032638549804688,10000,20382.752317667007,0.5957421660423279,1.7233281135559082,0.5529400110244751,1.9345680475234983,50000 -2325.1231849193573,1.80271577835083,18526.90128660202,39610,0,18526.90128660202,0.4489000141620636,2.53645920753479,10000,20856.095355033875,0.6048241853713989,1.6466890573501587,0.558139979839325,1.8898195028305047,50000 -2376.72646856308,1.842914581298828,18947.14054918289,40511,0,18947.14054918289,0.4395000338554382,2.585554361343384,10000,21328.02912712097,0.6100000143051147,1.6908632516860962,0.5554800033569336,1.9389760494232176,50000 -2427.576056241989,1.8747284412384035,19367.37152767181,41412,0,19367.37152767181,0.4462000131607055,2.5395474433898926,10000,21799.19383573532,0.6074023246765137,1.6232702732086182,0.5635200142860413,1.8514550924301147,50000 -2482.440875768661,1.907930850982666,19787.460819482803,42313,0,19787.460819482803,0.4504000246524811,2.53519868850708,10000,22274.2334086895,0.610644519329071,1.6274819374084473,0.561519980430603,1.8759976625442505,50000 -2536.3484270572662,1.941042184829712,20207.386869430546,43216,0,20207.386869430546,0.4533000290393829,2.504759788513184,10000,22748.151103019714,0.6401757597923279,1.5094099044799805,0.5685200095176697,1.8535585403442385,50000 -2590.112258195877,1.9733588695526123,20627.419096708298,44115,0,20627.419096708298,0.4478000104427337,2.5037691593170166,10000,23222.031039714813,0.615527331829071,1.6147336959838867,0.5694000124931335,1.8382197618484497,50000 -2642.76513504982,2.0060572624206543,21047.822548627853,45016,0,21047.822548627853,0.4537000358104706,2.542588472366333,10000,23695.17087316513,0.6134960651397705,1.659540057182312,0.5658400058746338,1.892349362373352,50000 -2693.3661007881165,2.0465316772460938,21468.1283724308,45915,0,21468.1283724308,0.4538000226020813,2.495013952255249,10000,24166.16872239113,0.6446874737739563,1.4981642961502075,0.5705599784851074,1.83890163898468,50000 -2746.573971748352,2.0856590270996094,21888.07623887062,46809,0,21888.07623887062,0.4558000266551971,2.4935083389282227,10000,24639.414632320404,0.6206445097923279,1.6042301654815674,0.5744799971580505,1.81850016117096,50000 -2799.619107723236,2.118630886077881,22308.360613822937,47711,0,22308.360613822937,0.4577000141143799,2.455000400543213,10000,25112.828466653824,0.6209374666213989,1.5577152967453003,0.5724999904632568,1.797107458114624,50000 -2852.789181947708,2.153792381286621,22728.47146344185,48614,0,22728.47146344185,0.4567000269889831,2.475313901901245,10000,25586.20291495323,0.6391406059265137,1.4967988729476929,0.5768799781799316,1.8095238208770752,50000 -2906.063053369522,2.1936678886413574,23148.824804782867,49511,0,23148.824804782867,0.4651000201702118,2.4164364337921143,10000,26059.92187833786,0.6273437142372131,1.5240154266357422,0.5815399885177612,1.7533224821090698,50000 -2957.601359605789,2.2275655269622803,23569.12908434868,50409,0,23569.12908434868,0.4692000150680542,2.4118294715881348,10000,26531.848781347275,0.6357030868530273,1.4875028133392334,0.5874599814414978,1.7306010723114014,50000 -3011.8984134197235,2.261938571929932,23989.14796257019,51307,0,23989.14796257019,0.4647000133991241,2.418069839477539,10000,27006.25090765953,0.6474804282188416,1.4440979957580566,0.5879999995231628,1.7363550662994385,50000 -3063.552988052368,2.295424699783325,24409.063949108124,52205,0,24409.063949108124,0.461400032043457,2.442796230316162,10000,27477.90633225441,0.6319531202316284,1.5517568588256836,0.5824599862098694,1.791892170906067,50000 -3117.278034448624,2.3338751792907715,24829.066326379776,53106,0,24829.066326379776,0.4686000347137451,2.3851122856140137,10000,27951.72364020348,0.6414257884025574,1.4624534845352173,0.5923799872398376,1.7016264200210571,50000 -3168.877803325653,2.370557308197021,25249.0579559803,54005,0,25249.0579559803,0.4639000296592712,2.440372467041016,10000,28423.403182029724,0.6461132764816284,1.4805848598480225,0.5896399617195129,1.7537370920181274,50000 -3219.899830341339,2.4072391986846924,25669.08655667305,54904,0,25669.08655667305,0.4611000120639801,2.4644737243652344,10000,28894.5414185524,0.6288476586341858,1.5620810985565186,0.5803399682044983,1.7848657369613647,50000 -3273.9854452610016,2.443835020065308,26089.343856811523,55805,0,26089.343856811523,0.4727000296115875,2.3828577995300293,10000,29368.97186183929,0.6421093344688416,1.480443835258484,0.5927599668502808,1.717319130897522,50000 -3326.871330738068,2.479809284210205,26509.37562179565,56706,0,26509.37562179565,0.4781000316143036,2.3427188396453857,10000,29841.977155923843,0.6518945097923279,1.4154984951019287,0.5967999696731567,1.6895005702972412,50000 -3380.7066905498505,2.51355242729187,26929.51243591309,57605,0,26929.51243591309,0.4772000312805176,2.4073524475097656,10000,30316.034165859222,0.6357030868530273,1.5449306964874268,0.5915799736976624,1.7564728260040283,50000 -3434.7488169670105,2.548457622528076,27349.63072085381,58505,0,27349.63072085381,0.4745000302791595,2.364131689071656,10000,30790.28051924705,0.6467968821525574,1.471572995185852,0.5975599884986877,1.7017605304718018,50000 -3485.748226881027,2.594045400619507,27769.91752099991,59407,0,27769.91752099991,0.4765000343322754,2.4175078868865967,10000,31261.66304731369,0.6440038681030273,1.4910470247268677,0.590719997882843,1.7559362649917605,50000 -3536.589988708496,2.6317689418792725,28190.01903152466,60301,0,28190.01903152466,0.4755000174045563,2.361530542373657,10000,31732.694142580032,0.6430468559265137,1.471490502357483,0.5990399718284607,1.6971635818481443,50000 -3589.198195934296,2.6737277507781982,28609.990909576416,61201,0,28609.990909576416,0.4775000214576721,2.361438274383545,10000,32205.369146585464,0.6446093320846558,1.4742058515548706,0.6018999814987183,1.6856762170791626,50000 -3643.361774921417,2.71584153175354,29030.119871616364,62101,0,29030.119871616364,0.4792000353336334,2.371155500411988,10000,32679.75479578972,0.6470312476158142,1.4702602624893188,0.5948399901390076,1.7143629789352417,50000 -3696.535520792008,2.7506535053253174,29450.567351818085,63001,0,29450.567351818085,0.4879000186920166,2.299943447113037,10000,33153.461698293686,0.6622851490974426,1.3750454187393188,0.6061999797821045,1.6341488361358645,50000 -3749.080408573151,2.791459083557129,29870.9140355587,63903,0,29870.9140355587,0.4817000329494476,2.317942380905152,10000,33626.445133686066,0.6542187333106995,1.415706753730774,0.6043999791145325,1.652857780456543,50000 -3801.326912641525,2.8265275955200195,30291.509792089462,64801,0,30291.509792089462,0.4809000194072723,2.324260711669922,10000,34099.37319946289,0.6602929830551147,1.405190348625183,0.6033799648284912,1.6764905452728271,50000 -3851.639670610428,2.863790512084961,30711.8409523964,65703,0,30711.8409523964,0.4862000346183777,2.297815322875977,10000,34570.10649561882,0.6716015338897705,1.325235366821289,0.6054999828338623,1.634854793548584,50000 -3901.277138948441,2.9034883975982666,31131.927243471146,66599,0,31131.927243471146,0.4908000230789184,2.349582433700561,10000,35039.92137694359,0.6601366996765137,1.4673115015029907,0.6058799624443054,1.716723084449768,50000 -3952.83056306839,2.9529545307159424,31552.259560346603,67499,0,31552.259560346603,0.4932000339031219,2.2746410369873047,10000,35511.90776872635,0.6662890315055847,1.3695124387741089,0.6140199899673462,1.6112968921661377,50000 -4004.3154785633087,2.992598533630371,31972.5193977356,68398,0,31972.5193977356,0.4933000206947326,2.2944016456604004,10000,35983.74279308319,0.6966406106948853,1.2503221035003662,0.6096799969673157,1.6365011930465698,50000 -4057.010164737701,3.033708333969116,32392.77770280838,69292,0,32392.77770280838,0.4886000156402588,2.312219858169556,10000,36456.78786849976,0.6567773222923279,1.425057291984558,0.6114400029182434,1.6573946475982666,50000 -4108.516996145248,3.072967052459717,32812.85529613495,70190,0,32812.85529613495,0.4886000156402588,2.298884153366089,10000,36928.463305950165,0.6670116782188416,1.3967492580413818,0.6153599619865417,1.6382001638412476,50000 -4161.779107093811,3.111844539642334,33233.01276636124,71090,0,33233.01276636124,0.4946000277996063,2.290399074554444,10000,37401.97440814972,0.6880077719688416,1.30825936794281,0.6117199659347534,1.6579543352127075,50000 -4211.959260225296,3.1535627841949463,33653.27028799057,71989,0,33653.27028799057,0.4919000267982483,2.313216209411621,10000,37872.505085229874,0.6581054329872131,1.4184592962265017,0.613599956035614,1.642994999885559,50000 -4263.880227088928,3.2022414207458496,34073.381747722626,72888,0,34073.381747722626,0.4941000342369079,2.280538320541382,10000,38344.63732886314,0.6729101538658142,1.3513695001602173,0.6170799732208252,1.6132991313934326,50000 -4316.417612314224,3.24575424194336,34493.59297966957,73785,0,34493.59297966957,0.5059000253677368,2.2482941150665283,10000,38817.48063826561,0.6938085556030273,1.2425113916397097,0.6222400069236755,1.575509786605835,50000 -4367.896469116211,3.292013645172119,34913.6925303936,74687,0,34913.6925303936,0.503600001335144,2.223655939102173,10000,39289.15676212311,0.6724804639816284,1.328711986541748,0.6229999661445618,1.5702327489852903,50000 -4419.369015693665,3.339527130126953,35333.97060585022,75586,0,35333.97060585022,0.5058000087738037,2.2576117515563965,10000,39761.00652861595,0.6808202862739563,1.340014934539795,0.6244800090789795,1.6081055402755735,50000 -4474.300911664963,3.38096022605896,35753.88782739639,76484,0,35753.88782739639,0.5015000104904175,2.227324247360229,10000,40235.9482319355,0.6885741949081421,1.2858747243881226,0.6215599775314331,1.5913811922073364,50000 -4526.091993093491,3.42548942565918,36174.27370333672,77378,0,36174.27370333672,0.5027000308036804,2.2037034034729004,10000,40708.220797777176,0.6756640672683716,1.311980962753296,0.6242200136184692,1.5544767379760742,50000 -4580.758625507355,3.471120595932007,36594.59930849075,78274,0,36594.59930849075,0.5020000338554382,2.3106143474578857,10000,41183.31033778191,0.6756445169448853,1.4200063943862915,0.6239799857139587,1.6556744575500488,50000 -4632.851358652115,3.5113136768341064,37014.5864136219,79172,0,37014.5864136219,0.4998000264167785,2.254966974258423,10000,41655.48101997376,0.6913671493530273,1.2846978902816772,0.6258999705314636,1.5905046463012695,50000 -4685.910947561264,3.5520379543304443,37434.66622328758,80067,0,37434.66622328758,0.5051000118255615,2.206871032714844,10000,42128.71227145195,0.6728710532188416,1.3366504907608032,0.6302599906921387,1.55534029006958,50000 -4737.878388643265,3.595428943634033,37854.888414382935,80967,0,37854.888414382935,0.5120000243186951,2.171586751937866,10000,42600.99657249451,0.6885156035423279,1.2635071277618408,0.634660005569458,1.520119309425354,50000 -4789.569073200226,3.63571047782898,38275.072088718414,81862,0,38275.072088718414,0.5149000287055969,2.177933931350708,10000,43072.96156978607,0.6949023008346558,1.2488195896148682,0.6343599557876587,1.5292257070541382,50000 -4842.699766159058,3.677600860595703,38695.33440542221,82760,0,38695.33440542221,0.5074000358581543,2.211868524551392,10000,43546.44796872139,0.6834570169448853,1.3045527935028076,0.6303600072860718,1.5471371412277222,50000 -4894.128688812256,3.7216544151306152,39115.676505327225,83655,0,39115.676505327225,0.5121999979019165,2.1926448345184326,10000,44018.31398630142,0.6850390434265137,1.279969573020935,0.6342200040817261,1.5213388204574585,50000 -4945.990887403488,3.761420726776123,39535.89488720894,84551,0,39535.89488720894,0.5159000158309937,2.160423755645752,10000,44490.48594856262,0.703125,1.225046157836914,0.6385599970817566,1.5180543661117554,50000 -4997.108908653259,3.810989618301392,39956.06528735161,85449,0,39956.06528735161,0.5161000490188599,2.173837900161743,10000,44961.87512779236,0.683300793170929,1.303541898727417,0.6320399641990662,1.532470464706421,50000 -5049.05076956749,3.851077079772949,40376.11482572556,86344,0,40376.11482572556,0.5146999955177307,2.1729276180267334,10000,45433.95785403252,0.6930859088897705,1.2610045671463013,0.6387799978256226,1.5271767377853394,50000 -5100.5192720890045,3.894582748413086,40796.28626155853,87241,0,40796.28626155853,0.5197000503540039,2.1511378288269043,10000,45905.69135284424,0.6998632550239563,1.2377312183380127,0.638700008392334,1.510784149169922,50000 -5153.735585212708,3.9385111331939697,41216.240828990936,88126,0,41216.240828990936,0.5146000385284424,2.151771545410156,10000,46378.955389261246,0.6940820217132568,1.2332541942596436,0.637179970741272,1.494731068611145,50000 -5205.26829957962,3.983673334121704,41636.33084964752,88977,0,41636.33084964752,0.5200999975204468,2.1336536407470703,10000,46850.67201781273,0.6952733993530273,1.2344191074371338,0.6416999697685242,1.4802577495574951,50000 -5258.57132768631,4.025211334228516,42056.24417757988,89871,0,42056.24417757988,0.516700029373169,2.169416666030884,10000,47323.9805533886,0.704296886920929,1.2123621702194214,0.6403399705886841,1.5044403076171875,50000 -5310.227010965347,4.070418834686279,42476.40330886841,90771,0,42476.40330886841,0.5189000368118286,2.149666786193848,10000,47795.89140582085,0.6926171779632568,1.247671127319336,0.6415799856185913,1.4853407144546509,50000 -5362.071304321289,4.122549295425415,42896.397490262985,91663,0,42896.397490262985,0.5196000337600708,2.1370153427124023,10000,48267.83230733872,0.7013866901397705,1.2053130865097046,0.645639955997467,1.4752509593963623,50000 -5413.447510957718,4.164551734924316,43316.60587668419,92557,0,43316.60587668419,0.52510005235672,2.1200804710388184,10000,48739.50898981094,0.7112109065055847,1.1575263738632202,0.6488800048828125,1.4489514827728271,50000 -5465.9878051280975,4.211713075637817,43736.8600165844,93455,0,43736.8600165844,0.5205000042915344,2.171663999557495,10000,49212.40234661102,0.6953515410423279,1.2690379619598389,0.6404399871826172,1.5253348350524902,50000 -5520.260136604309,4.254514455795288,44157.11583042145,94353,0,44157.11583042145,0.526900053024292,2.0961992740631104,10000,49687.024679899216,0.7071093320846558,1.196900486946106,0.6487599611282349,1.4550457000732422,50000 -5572.704406738281,4.302710056304932,44577.40431547165,95251,0,44577.40431547165,0.5276000499725342,2.12835431098938,10000,50159.85675024986,0.7099023461341858,1.2014402151107788,0.6468999981880188,1.4840155839920044,50000 -5624.41569018364,4.350196361541748,44997.624660253525,96149,0,44997.624660253525,0.525700032711029,2.12905216217041,10000,50631.88677740097,0.7105468511581421,1.17483651638031,0.6487999558448792,1.4631506204605105,50000 -5675.649765968323,4.3937060832977295,45417.73803925514,97044,0,45417.73803925514,0.5293000340461731,2.103090524673462,10000,51103.32812476158,0.7084765434265137,1.200149655342102,0.650879979133606,1.4600813388824463,50000 -5729.255409002304,4.434265851974487,45838.01228451729,97942,0,45838.01228451729,0.5301000475883484,2.0910539627075195,10000,51577.29973363876,0.7183789014816284,1.144914627075195,0.6541199684143066,1.44417405128479,50000 -5781.615537166596,4.480335235595703,46258.07188653946,98841,0,46258.07188653946,0.530500054359436,2.081018686294556,10000,52049.81680226326,0.7341406345367432,1.0801435708999634,0.6582199931144714,1.4175523519515991,50000 -5832.705719232559,4.522157669067383,46678.317527771,99736,0,46678.317527771,0.541100025177002,2.062328338623047,10000,52521.2454571724,0.7138085961341858,1.1494009494781494,0.6567999720573425,1.4137784242630005,50000 -5884.45601606369,4.5659730434417725,47098.33820104599,100630,0,47098.33820104599,0.5279000401496887,2.102136850357056,10000,52993.11081790924,0.71728515625,1.1666183471679688,0.6571399569511414,1.4454323053359983,50000 -5938.6716158390045,4.610373020172119,47518.43212723732,101524,0,47518.43212723732,0.5356000065803528,2.04892635345459,10000,53467.515218257904,0.7432226538658142,1.042794108390808,0.6620199680328369,1.3963292837142944,50000 -5990.034274101257,4.655975341796875,47938.56771445274,102418,0,47938.56771445274,0.5392000079154968,2.046409368515014,10000,53939.10957980156,0.7162694931030273,1.132684588432312,0.6610999703407288,1.4008057117462158,50000 -6042.1712164878845,4.702473402023315,48358.572227716446,103313,0,48358.572227716446,0.5410000085830688,2.047637939453125,10000,54411.3479924202,0.7245898246765137,1.1347663402557373,0.6670799851417542,1.4015071392059326,50000 -6095.1235938072205,4.748953342437744,48778.65916538239,104208,0,48778.65916538239,0.5436000227928162,2.037595748901367,10000,54884.48488640785,0.745898425579071,1.0389422178268433,0.6654399633407593,1.3964356184005735,50000 -6146.176388025284,4.796812057495117,49199.02674174309,105104,0,49199.02674174309,0.5414000153541565,2.042243719100952,10000,55356.00403022766,0.7211328148841858,1.13435959815979,0.6629399657249451,1.3976269960403442,50000 -6199.765183925629,4.843739748001099,49619.57560873032,106001,0,49619.57560873032,0.542900025844574,2.0384507179260254,10000,55830.23876786232,0.7251952886581421,1.1188418865203855,0.6670399904251099,1.396875023841858,50000 -6253.255183458328,4.89024019241333,50039.59346675873,106902,0,50039.59346675873,0.539400041103363,2.044331312179565,10000,56303.84405493736,0.7347265481948853,1.0637056827545166,0.6632599830627441,1.3980926275253296,50000 -6306.424833536148,4.935125350952148,50459.5941298008,107798,0,50459.5941298008,0.5404000282287598,2.063552141189575,10000,56777.11013770104,0.7272851467132568,1.1355656385421753,0.6676599979400635,1.4018704891204834,50000 -6357.710418224335,4.979193210601807,50879.51830530167,108692,0,50879.51830530167,0.5416000485420227,2.09930157661438,10000,57248.41434073448,0.7261914014816284,1.173351764678955,0.6655200123786926,1.454077959060669,50000 -6408.644496679306,5.026316642761231,51299.49491405487,109582,0,51299.49491405487,0.5498000383377075,2.0147860050201416,10000,57719.42250370979,0.7465234398841858,1.033710479736328,0.6727199554443359,1.3676403760910034,50000 -6461.516880512238,5.074175596237183,51719.80661559105,110472,0,51719.80661559105,0.54830002784729,2.028778553009033,10000,58192.70450210571,0.7326952815055847,1.1184043884277344,0.6693399548530579,1.3929319381713867,50000 -6515.1032173633575,5.12527322769165,52140.07601737976,111369,0,52140.07601737976,0.5551000237464905,1.9627418518066408,10000,58666.6617166996,0.7405859231948853,1.0059020519256592,0.678380012512207,1.3027551174163818,50000 -6568.118919849396,5.169904947280884,52560.04998207092,112264,0,52560.04998207092,0.5504000186920166,1.9710676670074463,10000,59139.74605703354,0.7454296946525574,0.994961440563202,0.677619993686676,1.3128902912139893,50000 -6621.345211267471,5.218159914016724,52980.3245677948,113154,0,52980.3245677948,0.5530000329017639,1.994787216186524,10000,59613.3454978466,0.7370312213897705,1.0629130601882937,0.674340009689331,1.3487356901168823,50000 -6672.544312000275,5.261758804321289,53400.63660097122,114049,0,53400.63660097122,0.5480000376701355,2.0506300926208496,10000,60084.95163035393,0.7362695336341858,1.114818453788757,0.6717199683189392,1.402865290641785,50000 -6726.00217795372,5.320492744445801,53820.99761939049,114941,0,53820.99761939049,0.5499000549316406,1.9954674243927,10000,60558.87907385826,0.7517382502555847,1.0051695108413696,0.6760599613189697,1.3408530950546265,50000 -6779.807821750641,5.36723256111145,54240.927837610245,115834,0,54240.927837610245,0.5558000206947327,1.9709168672561648,10000,61032.71226191521,0.7432421445846558,1.0460532903671265,0.6818000078201294,1.3182965517044067,50000 -6830.200934410095,5.416102886199951,54660.92917466164,116723,0,54660.92917466164,0.5578000545501709,1.965359807014465,10000,61503.20591568947,0.7484374642372131,1.0299557447433472,0.6839199662208557,1.325279951095581,50000 -6885.558126449585,5.461392164230347,55081.1960170269,117617,0,55081.1960170269,0.560699999332428,1.9365530014038088,10000,61978.925362825394,0.7597265243530273,0.952623963356018,0.6856600046157837,1.292971968650818,50000 -6940.854739904404,5.514406204223633,55501.43515133858,118506,0,55501.43515133858,0.558899998664856,1.9488836526870728,10000,62454.564338207245,0.7503515481948853,1.0049922466278076,0.6844199895858765,1.30093252658844,50000 -6993.368841648102,5.561697959899902,55921.57339477539,119398,0,55921.57339477539,0.55840003490448,1.978242874145508,10000,62927.31496787071,0.7482226490974426,1.0383548736572266,0.6855599880218506,1.3293486833572388,50000 -7044.214889287949,5.607126235961914,56341.53645062447,120291,0,56341.53645062447,0.5616000294685364,1.9497791528701784,10000,63398.22046136856,0.7626562118530273,0.9669126868247986,0.6891199946403503,1.2968260049819946,50000 -7097.606656551361,5.652962684631348,56761.78638911247,121188,0,56761.78638911247,0.5606000423431396,1.9592289924621584,10000,63871.95942783356,0.7483984231948853,1.023191213607788,0.6844800114631653,1.3076735734939575,50000 -7153.155112504959,5.705412864685059,57181.91209387779,122078,0,57181.91209387779,0.5636000037193298,1.9222257137298584,10000,64347.73667383194,0.7550390362739563,0.9838645458221436,0.689799964427948,1.2810405492782593,50000 -7206.63245177269,5.758823394775391,57601.97654438019,122967,0,57601.97654438019,0.5637000203132629,1.895406723022461,10000,64821.3822324276,0.767285168170929,0.914058804512024,0.6899399757385254,1.2592631578445437,50000 -7257.178694009781,5.806592226028442,58022.1356446743,123855,0,58022.1356446743,0.5642000436782837,1.9256024360656736,10000,65292.18530201912,0.7547070384025574,0.9752053618431092,0.691819965839386,1.2582887411117554,50000 -7309.925290107727,5.855309247970581,58442.4238243103,124748,0,58442.4238243103,0.5690000057220459,1.9072344303131104,10000,65765.31973147392,0.7588085532188416,0.9643204808235168,0.6942799687385559,1.259268045425415,50000 -7361.893659353256,5.9061243534088135,58862.69698309898,125642,0,58862.69698309898,0.5710000395774841,1.878885269165039,10000,66237.66225218773,0.7699999809265137,0.9051749110221864,0.6944599747657776,1.2426235675811768,50000 -7413.29775929451,5.957125425338745,59282.61265707016,126535,0,59282.61265707016,0.5751000046730042,1.8674311637878416,10000,66709.08399271965,0.7580859065055847,0.9454815983772278,0.6959999799728394,1.2273520231246948,50000 -7466.876657009125,6.0067572593688965,59702.56274223328,127428,0,59702.56274223328,0.5734000205993652,1.8793699741363523,10000,67182.71284723282,0.7644921541213989,0.9372758865356444,0.6952999830245972,1.2479116916656494,50000 -7520.418255567551,6.055173873901367,60122.79120898247,128321,0,60122.79120898247,0.5678000450134277,1.8976391553878784,10000,67656.58211922646,0.7732617259025574,0.9040424823760986,0.696179986000061,1.2386150360107422,50000 -7572.68899512291,6.103700399398804,60542.8531806469,129217,0,60542.8531806469,0.572100043296814,1.913463234901428,10000,68129.01347446442,0.7625390291213989,0.9737145900726318,0.6969599723815918,1.2702531814575195,50000 -7626.02169251442,6.154720783233643,60962.88206410408,130113,0,60962.88206410408,0.5756000280380249,1.8801839351654053,10000,68602.47761464119,0.7669726610183716,0.9456799626350404,0.7005999684333801,1.244845986366272,50000 -7678.519439458847,6.209751129150391,61383.08276414871,131000,0,61383.08276414871,0.5766000151634216,1.8699551820755005,10000,69075.28121972084,0.7790820002555847,0.8822490572929382,0.7021999955177307,1.2245677709579468,50000 -7730.9835069179535,6.25759744644165,61803.10675501824,131892,0,61803.10675501824,0.575700044631958,1.8553532361984253,10000,69547.86748552322,0.7725781202316284,0.9053280353546144,0.7041800022125244,1.2155123949050903,50000 -7783.519206285477,6.3048930168151855,62223.355676651,132782,0,62223.355676651,0.5833000540733337,1.8324159383773804,10000,70020.74941420555,0.7796679735183716,0.8771883249282837,0.7079199552536011,1.195351243019104,50000 -7836.93887090683,6.35483980178833,62643.74037742615,133679,0,62643.74037742615,0.5857000350952148,1.817793250083924,10000,70494.65468883514,0.78236323595047,0.8421926498413086,0.7068600058555603,1.1863993406295776,50000 -7888.631495952606,6.403504371643066,63063.97910571098,134573,0,63063.97910571098,0.5770000219345093,1.8554986715316768,10000,70966.68536663055,0.7735351324081421,0.902022123336792,0.7064799666404724,1.2089617252349854,50000 -7940.2364201545715,6.454331636428833,63484.01175141335,135466,0,63484.01175141335,0.5851000547409058,1.8302571773529053,10000,71438.42376971245,0.7805468440055847,0.8645254373550415,0.7113199830055237,1.1867696046829224,50000 -7992.3159103393555,6.5025811195373535,63904.009852170944,136359,0,63904.009852170944,0.5819000005722046,1.8246159553527832,10000,71910.59975337982,0.78480464220047,0.8444609045982361,0.7076399922370911,1.1892009973526,50000 -8045.995781421661,6.56099271774292,64323.90890264511,137253,0,64323.90890264511,0.5891000032424927,1.7870514392852783,10000,72384.28784418106,0.7909960746765137,0.8081295490264893,0.7122399806976318,1.1592646837234497,50000 -8097.431872606277,6.613979339599609,64744.21198558808,138143,0,64744.21198558808,0.5836000442504883,1.8493305444717407,10000,72856.13017225266,0.7822265625,0.8877792358398438,0.7106399536132812,1.2171674966812134,50000 -8148.8644115924835,6.663257837295532,65164.19147825241,139035,0,65164.19147825241,0.5884000062942505,1.7982066869735718,10000,73327.64161705971,0.7908398509025574,0.8290248513221741,0.7122799754142761,1.171581745147705,50000 -8200.685409069061,6.717529058456421,65584.32089519501,139928,0,65584.32089519501,0.5904000401496887,1.78292977809906,10000,73799.69849538803,0.8020117282867432,0.7729999423027039,0.7156800031661987,1.1472184658050537,50000 -8253.039557218552,6.766315937042236,66004.48244214058,140823,0,66004.48244214058,0.5879000425338745,1.7908883094787598,10000,74272.31378889084,0.7893944978713989,0.811639666557312,0.7133199572563171,1.150890588760376,50000 -8306.357367038727,6.814566850662232,66424.64887809753,141714,0,66424.64887809753,0.5871000289916992,1.8304020166397093,10000,74745.89710235596,0.7934374809265137,0.8436886668205261,0.7143200039863586,1.187924861907959,50000 -8357.431869983673,6.8636791706085205,66844.86011886597,142603,0,66844.86011886597,0.5949000120162964,1.7849997282028198,10000,75217.28182458878,0.8104296922683716,0.7512505054473877,0.7206400036811829,1.1383086442947388,50000 -8409.880574703217,6.9201719760894775,67264.9873714447,143493,0,67264.9873714447,0.5906000137329102,1.8258694410324097,10000,75689.96438384056,0.7852538824081421,0.8588006496429443,0.713699996471405,1.1768875122070312,50000 -8462.395847082138,6.977154016494751,67685.38305974007,144389,0,67685.38305974007,0.5975000262260437,1.7978652715682983,10000,76162.98290419579,0.8009960651397705,0.8065456748008728,0.7186799645423889,1.1656146049499512,50000 -8514.762991905212,7.027865171432495,68105.52442455292,145286,0,68105.52442455292,0.593500018119812,1.777706503868103,10000,76635.59283566475,0.8099414110183716,0.7473099231719971,0.720579981803894,1.1367037296295166,50000 -8567.967049360275,7.079846858978272,68525.59720563889,146180,0,68525.59720563889,0.5982000231742859,1.7780965566635132,10000,77108.97180509567,0.7946679592132568,0.8052504062652588,0.7216199636459351,1.132479190826416,50000 -8619.504743099213,7.131839036941528,68945.71439909935,147069,0,68945.71439909935,0.5991000533103943,1.7480182647705078,10000,77580.7287735939,0.8073241710662842,0.7616297602653503,0.7257999777793884,1.1163734197616575,50000 -8672.559925556183,7.190670490264893,69365.70506572723,147961,0,69365.70506572723,0.6000000238418579,1.755022406578064,10000,78053.88421607018,0.8156249523162842,0.7296286821365356,0.7245199680328369,1.1153302192687988,50000 -8723.976521015167,7.239832878112793,69785.79780435562,148852,0,69785.79780435562,0.6010000109672546,1.7640174627304075,10000,78525.49350094795,0.7974609136581421,0.7928144931793213,0.7253400087356567,1.122349977493286,50000 -8778.575244188309,7.29470419883728,70205.98230195045,149745,0,70205.98230195045,0.6070000529289246,1.728402614593506,10000,79000.38216662407,0.8090038895606995,0.7413263320922852,0.7287600040435791,1.0949060916900637,50000 -8831.101482391357,7.352070093154907,70626.07793998718,150640,0,70626.07793998718,0.5976999998092651,1.7766977548599243,10000,79473.11308956146,0.8140038847923279,0.7625994086265564,0.7268999814987183,1.140964388847351,50000 -8883.26012802124,7.406408786773682,71046.26297879219,151533,0,71046.26297879219,0.6025000214576721,1.7719186544418335,10000,79945.56105446815,0.8054101467132568,0.7942066192626953,0.7263000011444092,1.1323789358139038,50000 -8938.53841495514,7.459969520568848,71466.43354034424,152424,0,71466.43354034424,0.6144000291824341,1.7185935974121094,10000,80421.11392855644,0.8122460842132568,0.738717794418335,0.7300999760627747,1.0963596105575562,50000 -8989.983581542969,7.511559963226318,71886.71799898148,153321,0,71886.71799898148,0.6084000468254089,1.746617078781128,10000,80892.94577765465,0.81751948595047,0.732594907283783,0.7286399602890015,1.1145501136779783,50000 -9043.079141378405,7.5628063678741455,72306.71941304207,154215,0,72306.71941304207,0.6096000075340271,1.717441439628601,10000,81366.14441990852,0.8136718273162842,0.7343711853027344,0.7323399782180786,1.0908153057098389,50000 -9097.225280284882,7.617284536361694,72726.99397587776,155112,0,72726.99397587776,0.6073000431060791,1.7077174186706543,10000,81840.67090892792,0.8171679377555847,0.7142869234085083,0.7331199645996094,1.077955961227417,50000 -9151.30518102646,7.670612573623657,73147.06365466118,156004,0,73147.06365466118,0.612000048160553,1.6971514225006104,10000,82314.92569756508,0.8233398199081421,0.6954863667488098,0.735040009021759,1.0758261680603027,50000 -9202.519262313845,7.720755577087402,73567.08738541603,156897,0,73567.08738541603,0.6095000505447388,1.704157471656799,10000,82786.26455402374,0.8152148127555847,0.7171177268028259,0.7366399765014648,1.0684080123901367,50000 -9256.575710058212,7.776738405227661,73987.16595101357,157788,0,73987.16595101357,0.6110000014305115,1.6924991607666016,10000,83260.5069026947,0.8219921588897705,0.6984254121780396,0.7360799908638,1.0686091184616089,50000 -9308.94518852234,7.840383529663086,74407.3215315342,158683,0,74407.3215315342,0.6097000241279602,1.7294844388961792,10000,83733.14594745636,0.8245312571525574,0.7121429443359375,0.7342599630355835,1.108445644378662,50000 -9362.36895108223,7.894147157669067,74827.42336130142,159580,0,74827.42336130142,0.6133000254631042,1.709965705871582,10000,84206.77592754364,0.8193163871765137,0.7136068344116211,0.7355799674987793,1.0825302600860596,50000 -9416.1791973114,7.950130701065063,75247.45334744453,160474,0,75247.45334744453,0.616100013256073,1.6938591003417969,10000,84680.72242164612,0.8253905773162842,0.6981134414672852,0.7397399544715881,1.0691667795181274,50000 -9467.809502601624,8.004290342330933,75667.55098891258,161367,0,75667.55098891258,0.6173000335693359,1.6790910959243774,10000,85152.55523252487,0.8282226324081421,0.6513522267341614,0.7389799952507019,1.0421522855758667,50000 -9519.163709878922,8.066744565963745,76087.53114652634,162260,0,76087.53114652634,0.6164000034332275,1.6963045597076416,10000,85624.00301170349,0.8219531178474426,0.7046661376953125,0.739039957523346,1.0686395168304443,50000 -9574.24474644661,8.126704692840576,76507.66302323341,163149,0,76507.66302323341,0.6152999997138977,1.6925363540649414,10000,86099.32566308975,0.8304882645606995,0.668041467666626,0.7397800087928772,1.0610032081604004,50000 -9625.999377965927,8.179861307144165,76927.88128042221,164040,0,76927.88128042221,0.6208000183105469,1.6820333003997805,10000,86571.4026389122,0.8357031345367432,0.6520687937736511,0.7403199672698975,1.0528931617736816,50000 -9678.775826692581,8.240126609802246,77348.1300997734,164936,0,77348.1300997734,0.6177000403404236,1.687680959701538,10000,87044.53920483589,0.8275195360183716,0.6876667737960815,0.742859959602356,1.0580904483795166,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/measurements.csv deleted file mode 100644 index b792c0103..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1841 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.35730848,6.9077563,,,,,,,,,,,,,, -1,,,0.0009179687476716,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,41.93090534210205,73.41631269454956,41.93090534210205,31.48529624938965,0.0,0.0 -100,0.5226721,6.8770037,,,,,,,,,,,,,, -200,0.56028724,6.7942576,,,,,,,,,,,,,, -300,0.6729114,6.7569294,,,,,,,,,,,,,, -400,0.79278034,6.702725,,,,,,,,,,,,,, -500,1.2728465,6.5422544,,,,,,,,,,,,,, -600,1.103655,6.4454136,,,,,,,,,,,,,, -700,1.6141862,6.398957,,,,,,,,,,,,,, -800,0.99988246,6.277336,,,,,,,,,,,,,, -839,,,0.0321679674088954,5.976029396057129,0.0268599987030029,6.033750534057617,50000.0,0.0220000017434358,6.133785247802734,10000.0,461.9541461467743,545.1891384124756,461.9541461467743,83.17005729675293,0.0171785354614257,0.0 -900,1.7395799,6.3652062,,,,,,,,,,,,,, -1000,0.953873,6.1802454,,,,,,,,,,,,,, -1100,1.1989865,6.036253,,,,,,,,,,,,,, -1200,1.0497041,6.0133004,,,,,,,,,,,,,, -1300,1.0850152,5.926459,,,,,,,,,,,,,, -1400,1.6792569,6.0086327,,,,,,,,,,,,,, -1500,1.3869395,5.9710875,,,,,,,,,,,,,, -1600,1.0625324,5.74369,,,,,,,,,,,,,, -1700,1.3455002,5.8178053,,,,,,,,,,,,,, -1741,,,0.0686132833361625,5.289651393890381,0.0642199963331222,5.34092903137207,50000.0,0.0499000027775764,5.561094284057617,10000.0,881.9554603099823,1014.4922845363616,881.9554603099823,132.39385414123535,0.0436456203460693,0.0 -1800,1.171355,6.0454245,,,,,,,,,,,,,, -1900,1.1000277,6.41831,,,,,,,,,,,,,, -2000,1.0705943,5.716703,,,,,,,,,,,,,, -2100,0.8294253,6.2714014,,,,,,,,,,,,,, -2200,1.1482056,5.6373444,,,,,,,,,,,,,, -2300,0.97436786,5.4591627,,,,,,,,,,,,,, -2400,0.6569817,6.3921175,,,,,,,,,,,,,, -2500,0.72635514,6.5218253,,,,,,,,,,,,,, -2600,0.9380325,5.3905916,,,,,,,,,,,,,, -2649,,,0.114414058625698,4.848740100860596,0.1076399981975555,4.910977840423584,50000.0,0.0838000029325485,5.196085453033447,10000.0,1302.169753074646,1485.8664710521698,1302.169753074646,183.4759097099304,0.0698153972625732,0.0 -2700,1.2610236,5.3379784,,,,,,,,,,,,,, -2800,0.8087803,6.4112344,,,,,,,,,,,,,, -2900,0.93436646,5.5158815,,,,,,,,,,,,,, -3000,0.9011645,5.238053,,,,,,,,,,,,,, -3100,0.8553483,5.256522,,,,,,,,,,,,,, -3200,1.0061047,6.395536,,,,,,,,,,,,,, -3300,0.9631861,5.1334257,,,,,,,,,,,,,, -3400,0.84467316,5.218804,,,,,,,,,,,,,, -3500,0.98899424,5.324915,,,,,,,,,,,,,, -3554,,,0.1730273365974426,4.282963752746582,0.1574199944734573,4.420511245727539,50000.0,0.1200000047683715,4.79691743850708,10000.0,1722.0868241786957,1956.3246581554413,1722.0868241786957,233.9338686466217,0.1003086566925048,0.0 -3600,0.8092314,6.0678725,,,,,,,,,,,,,, -3700,0.95163774,5.1041713,,,,,,,,,,,,,, -3800,0.9411632,5.093264,,,,,,,,,,,,,, -3900,1.1579188,5.0396757,,,,,,,,,,,,,, -4000,0.8179764,4.816235,,,,,,,,,,,,,, -4100,0.8634806,5.482866,,,,,,,,,,,,,, -4200,0.8622679,4.8595357,,,,,,,,,,,,,, -4300,0.8115085,4.6428437,,,,,,,,,,,,,, -4400,0.7373683,5.826535,,,,,,,,,,,,,, -4448,,,0.2078320235013961,4.076893329620361,0.1960600018501281,4.17407751083374,50000.0,0.1438000053167343,4.610869884490967,10000.0,2141.7024862766266,2427.956640481949,2141.7024862766266,285.4160861968994,0.5835461616516113,0.0 -4500,1.2564464,4.7575326,,,,,,,,,,,,,, -4600,0.8982049,4.6641273,,,,,,,,,,,,,, -4700,0.9982737,4.582605,,,,,,,,,,,,,, -4800,0.88182604,4.5579686,,,,,,,,,,,,,, -4900,1.3573879,4.676658,,,,,,,,,,,,,, -5000,0.64204884,6.007911,,,,,,,,,,,,,, -5100,0.7338642,5.0301795,,,,,,,,,,,,,, -5200,0.7410626,4.6488466,,,,,,,,,,,,,, -5300,0.97735584,4.455676,,,,,,,,,,,,,, -5350,,,0.2707226574420929,3.5997939109802246,0.2500199973583221,3.711319923400879,50000.0,0.1876000016927719,4.200911998748779,10000.0,2561.650346755981,2899.76961684227,2561.650346755981,337.2035584449768,0.6109819412231445,0.0 -5400,0.6966047,4.473235,,,,,,,,,,,,,, -5500,0.74112195,4.5332546,,,,,,,,,,,,,, -5600,0.7594701,5.5301843,,,,,,,,,,,,,, -5700,0.7485861,4.5702186,,,,,,,,,,,,,, -5800,1.1629382,4.400188,,,,,,,,,,,,,, -5900,0.8563096,4.302947,,,,,,,,,,,,,, -6000,0.77723294,4.870467,,,,,,,,,,,,,, -6100,0.5410838,5.671882,,,,,,,,,,,,,, -6200,0.77764726,4.5742645,,,,,,,,,,,,,, -6257,,,0.3174609243869781,3.2546160221099854,0.2885800004005432,3.4442403316497803,50000.0,0.2152000069618225,3.982329845428467,10000.0,2981.769757270813,3371.370764017105,2981.769757270813,388.6069040298462,0.6374530792236328,0.0 -6300,0.8051001,4.2107267,,,,,,,,,,,,,, -6400,0.6080121,5.7053967,,,,,,,,,,,,,, -6500,1.3590498,4.281445,,,,,,,,,,,,,, -6600,0.88586885,4.314437,,,,,,,,,,,,,, -6700,0.6262685,5.980509,,,,,,,,,,,,,, -6800,0.757882,5.6735697,,,,,,,,,,,,,, -6900,0.6523457,5.91404,,,,,,,,,,,,,, -7000,0.58986896,5.8173823,,,,,,,,,,,,,, -7100,0.9129842,4.182074,,,,,,,,,,,,,, -7165,,,0.3345507681369781,3.108985424041748,0.3141399919986725,3.2446765899658203,50000.0,0.2417000085115432,3.805813789367676,10000.0,3401.7290391922,3843.996475219727,3401.7290391922,441.1956124305725,0.6641831398010254,0.0 -7200,1.2781397,4.3360023,,,,,,,,,,,,,, -7300,1.0615039,4.085659,,,,,,,,,,,,,, -7400,0.74154407,4.2420297,,,,,,,,,,,,,, -7500,0.8267377,4.4425836,,,,,,,,,,,,,, -7600,0.9658865,4.02886,,,,,,,,,,,,,, -7700,0.96637774,4.100191,,,,,,,,,,,,,, -7800,0.8613533,3.9854264,,,,,,,,,,,,,, -7900,0.8949536,4.01929,,,,,,,,,,,,,, -8000,1.0159405,4.038151,,,,,,,,,,,,,, -8073,,,0.3506445288658142,3.069853067398072,0.3208400011062622,3.2362983226776123,50000.0,0.2471000105142593,3.786952495574951,10000.0,3821.696554660797,4315.6374979019165,3821.696554660797,492.78946113586426,0.6920418739318848,0.0 -8100,0.87996316,4.444483,,,,,,,,,,,,,, -8200,0.51998186,5.838358,,,,,,,,,,,,,, -8300,1.0851622,3.9376123,,,,,,,,,,,,,, -8400,0.6256265,5.363347,,,,,,,,,,,,,, -8500,0.8586904,4.3380156,,,,,,,,,,,,,, -8600,0.9097309,4.2931724,,,,,,,,,,,,,, -8700,0.56867045,5.725294,,,,,,,,,,,,,, -8800,0.82642245,3.89505,,,,,,,,,,,,,, -8900,0.73296803,5.972761,,,,,,,,,,,,,, -8981,,,0.3797656297683716,2.877138614654541,0.3468599915504455,3.067584991455078,50000.0,0.2645000219345093,3.641249418258667,10000.0,4241.8689959049225,4787.127847909927,4241.8689959049225,544.0257968902588,0.7220323085784912,0.0 -9000,0.63339585,5.1959815,,,,,,,,,,,,,, -9100,0.98120624,3.8871453,,,,,,,,,,,,,, -9200,0.8275501,3.8098724,,,,,,,,,,,,,, -9300,0.82496625,3.8068662,,,,,,,,,,,,,, -9400,0.56475013,5.7155557,,,,,,,,,,,,,, -9500,1.2751813,3.7305021,,,,,,,,,,,,,, -9600,1.0108222,3.7285268,,,,,,,,,,,,,, -9700,0.6709746,5.017762,,,,,,,,,,,,,, -9800,0.74974155,5.5545425,,,,,,,,,,,,,, -9886,,,0.4021289050579071,2.717653751373291,0.3776800036430359,2.877584934234619,50000.0,0.2873000204563141,3.478476047515869,10000.0,4662.201369285584,5262.30820274353,4662.201369285584,598.793160200119,0.7506968975067139,0.0 -9900,0.72678226,6.01093,,,,,,,,,,,,,, -10000,0.89333785,3.896326,,,,,,,,,,,,,, -10100,0.95147157,3.8511593,,,,,,,,,,,,,, -10200,0.8297324,3.9586375,,,,,,,,,,,,,, -10300,1.0455898,3.7610452,,,,,,,,,,,,,, -10400,0.9181419,3.651257,,,,,,,,,,,,,, -10500,0.9014045,3.7909315,,,,,,,,,,,,,, -10600,1.0733232,3.7790954,,,,,,,,,,,,,, -10700,0.65175426,4.836776,,,,,,,,,,,,,, -10791,,,0.4259765446186065,2.5819032192230225,0.3936199843883514,2.756279706954956,50000.0,0.299200028181076,3.382631063461304,10000.0,5082.514708995819,5734.52169752121,5082.514708995819,650.6066122055054,0.7820405960083008,0.0 -10800,0.7886346,5.2420435,,,,,,,,,,,,,, -10900,0.95602775,3.7162967,,,,,,,,,,,,,, -11000,0.73141706,4.8795652,,,,,,,,,,,,,, -11100,0.9063784,3.6595497,,,,,,,,,,,,,, -11200,0.88532627,3.540313,,,,,,,,,,,,,, -11300,1.0730101,3.5419745,,,,,,,,,,,,,, -11400,1.0249194,4.3451967,,,,,,,,,,,,,, -11500,0.7976861,5.8282967,,,,,,,,,,,,,, -11600,0.6433364,5.663672,,,,,,,,,,,,,, -11694,,,0.4414648413658142,2.500446081161499,0.4053199887275696,2.709198474884033,50000.0,0.317900002002716,3.3103625774383545,10000.0,5502.804249048233,6207.301434278488,5502.804249048233,703.0180859565735,0.8099539279937744,0.0 -11700,0.8708932,3.7746804,,,,,,,,,,,,,, -11800,0.70681375,4.607446,,,,,,,,,,,,,, -11900,0.9560125,3.867376,,,,,,,,,,,,,, -12000,0.9069574,3.4094124,,,,,,,,,,,,,, -12100,0.78283596,4.902628,,,,,,,,,,,,,, -12200,1.001896,3.4330049,,,,,,,,,,,,,, -12300,0.9185364,3.7837014,,,,,,,,,,,,,, -12400,0.9537564,3.5291367,,,,,,,,,,,,,, -12500,0.88475114,3.850647,,,,,,,,,,,,,, -12599,,,0.4473632574081421,2.448283195495605,0.4174199998378753,2.6113107204437256,50000.0,0.3225000202655792,3.253435611724853,10000.0,5923.055751800537,6679.461454868317,5923.055751800537,754.8459165096283,0.8380134105682373,0.0 -12600,0.7182644,5.5672193,,,,,,,,,,,,,, -12700,1.1598533,3.4640553,,,,,,,,,,,,,, -12800,0.9485012,3.5318987,,,,,,,,,,,,,, -12900,1.0460004,3.5515628,,,,,,,,,,,,,, -13000,0.8044497,4.372379,,,,,,,,,,,,,, -13100,0.9411624,3.57297,,,,,,,,,,,,,, -13200,0.9032262,3.4486604,,,,,,,,,,,,,, -13300,0.970031,3.374043,,,,,,,,,,,,,, -13400,0.89820415,3.4315338,,,,,,,,,,,,,, -13500,,,0.4481445252895355,2.4669337272644043,0.4198599755764007,2.638878107070923,50000.0,0.3232000172138214,3.279659509658813,10000.0,6343.135094165802,7152.180879831314,6343.135094165802,807.402941942215,0.8693232536315918,0.0 -13500,0.84965694,4.4808817,,,,,,,,,,,,,, -13600,0.98009336,3.6173515,,,,,,,,,,,,,, -13700,0.9601508,4.0880957,,,,,,,,,,,,,, -13800,0.7355683,5.413826,,,,,,,,,,,,,, -13900,1.0503451,3.7075922,,,,,,,,,,,,,, -14000,0.70368576,4.8374224,,,,,,,,,,,,,, -14100,0.8535884,4.1058846,,,,,,,,,,,,,, -14200,0.6560043,5.4683905,,,,,,,,,,,,,, -14300,0.7105318,5.4135222,,,,,,,,,,,,,, -14400,0.80277646,5.788411,,,,,,,,,,,,,, -14401,,,0.4692578017711639,2.391101121902466,0.4292799830436706,2.592332601547241,50000.0,0.3260000050067901,3.225304126739502,10000.0,6763.108623743057,7626.033350944519,6763.108623743057,861.2038099765778,0.8967950344085693,0.0 -14500,1.2177154,3.4469237,,,,,,,,,,,,,, -14600,1.0108403,3.4214122,,,,,,,,,,,,,, -14700,0.9384088,3.4449384,,,,,,,,,,,,,, -14800,0.95343626,3.548336,,,,,,,,,,,,,, -14900,1.0833421,3.1557026,,,,,,,,,,,,,, -15000,0.96928537,3.3120072,,,,,,,,,,,,,, -15100,1.1859269,3.681902,,,,,,,,,,,,,, -15200,1.0088967,3.4010491,,,,,,,,,,,,,, -15300,0.9727599,3.495729,,,,,,,,,,,,,, -15307,,,0.4732421636581421,2.3394665718078613,0.4375399947166443,2.532576084136963,50000.0,0.3416000306606293,3.1499056816101074,10000.0,7183.137870073319,8098.40290927887,7183.137870073319,913.4623625278472,0.926544189453125,0.0 -15400,0.8592536,4.8639045,,,,,,,,,,,,,, -15500,0.93097246,3.386692,,,,,,,,,,,,,, -15600,1.0133064,3.5838256,,,,,,,,,,,,,, -15700,0.93182606,3.4803338,,,,,,,,,,,,,, -15800,0.99234706,3.5902362,,,,,,,,,,,,,, -15900,0.8625031,4.036576,,,,,,,,,,,,,, -16000,0.9207528,3.237701,,,,,,,,,,,,,, -16100,0.99688566,3.2269738,,,,,,,,,,,,,, -16200,0.8542064,3.6647184,,,,,,,,,,,,,, -16210,,,0.4846288859844208,2.263720750808716,0.4474599957466125,2.450514554977417,50000.0,0.3473000228404999,3.0882656574249268,10000.0,7603.157238006592,8571.176964044571,7603.157238006592,966.138659477234,0.9530913829803468,0.0 -16300,1.2385818,4.1943364,,,,,,,,,,,,,, -16400,1.1010854,3.2736778,,,,,,,,,,,,,, -16500,0.79266346,5.369965,,,,,,,,,,,,,, -16600,1.0524064,3.2263308,,,,,,,,,,,,,, -16700,1.0611305,3.370592,,,,,,,,,,,,,, -16800,0.8291899,4.4237905,,,,,,,,,,,,,, -16900,0.955887,3.4795759,,,,,,,,,,,,,, -17000,0.807279,4.0685673,,,,,,,,,,,,,, -17100,0.9207965,3.8438723,,,,,,,,,,,,,, -17108,,,0.5009179711341858,2.1943092346191406,0.4619999825954437,2.397297143936157,50000.0,0.3574000298976898,3.0589606761932373,10000.0,8023.078520536423,9043.016461133957,8023.078520536423,1017.9731209278108,0.9839563369750975,0.0 -17200,1.0265532,3.3060167,,,,,,,,,,,,,, -17300,1.095276,4.084996,,,,,,,,,,,,,, -17400,0.799887,4.3183675,,,,,,,,,,,,,, -17500,0.964892,3.2814357,,,,,,,,,,,,,, -17600,1.0777982,3.301809,,,,,,,,,,,,,, -17700,1.0658139,3.2665415,,,,,,,,,,,,,, -17800,1.1453925,3.1842997,,,,,,,,,,,,,, -17900,0.98974293,3.442017,,,,,,,,,,,,,, -18000,1.2839223,3.309353,,,,,,,,,,,,,, -18007,,,0.5083202719688416,2.2235052585601807,0.4515199959278106,2.495116710662842,50000.0,0.3568000197410583,3.10998272895813,10000.0,8443.216924190521,9516.7394759655,8443.216924190521,1071.4758758544922,1.0145676136016846,0.0 -18100,1.0420804,3.044523,,,,,,,,,,,,,, -18200,0.9031155,4.167135,,,,,,,,,,,,,, -18300,0.85551167,3.999561,,,,,,,,,,,,,, -18400,1.0632372,3.16177,,,,,,,,,,,,,, -18500,1.033252,3.2135344,,,,,,,,,,,,,, -18600,1.0710338,3.1651115,,,,,,,,,,,,,, -18700,1.212348,3.27144,,,,,,,,,,,,,, -18800,0.7809735,4.6711726,,,,,,,,,,,,,, -18900,1.0554788,3.5998049,,,,,,,,,,,,,, -18907,,,0.5022070407867432,2.1813580989837646,0.4710799753665924,2.35910964012146,50000.0,0.3609000146389007,3.0243079662323,10000.0,8863.401502132416,9989.141452550888,8863.401502132416,1123.613642454147,1.043529987335205,0.0 -19000,0.9806031,3.1197047,,,,,,,,,,,,,, -19100,0.74125886,5.4309187,,,,,,,,,,,,,, -19200,1.0312145,3.2140152,,,,,,,,,,,,,, -19300,1.009259,3.2055857,,,,,,,,,,,,,, -19400,1.0155548,3.2238002,,,,,,,,,,,,,, -19500,0.7076925,5.602216,,,,,,,,,,,,,, -19600,0.8810885,3.7464097,,,,,,,,,,,,,, -19700,1.1364311,3.2236764,,,,,,,,,,,,,, -19800,0.994671,3.6529703,,,,,,,,,,,,,, -19804,,,0.5236718654632568,2.076261043548584,0.4801599979400635,2.295961618423462,50000.0,0.3738000094890594,2.9481990337371826,10000.0,9283.423431873322,10464.098256111143,9283.423431873322,1178.4694809913635,1.0712995529174805,0.0 -19900,1.194399,3.2067914,,,,,,,,,,,,,, -20000,1.1349964,3.1197658,,,,,,,,,,,,,, -20100,0.80482864,5.2493167,,,,,,,,,,,,,, -20200,0.86872375,4.142427,,,,,,,,,,,,,, -20300,0.81626326,4.766164,,,,,,,,,,,,,, -20400,0.75989586,5.6479855,,,,,,,,,,,,,, -20500,1.0893822,3.1030915,,,,,,,,,,,,,, -20600,0.7564869,5.6689215,,,,,,,,,,,,,, -20700,1.4172583,3.229123,,,,,,,,,,,,,, -20704,,,0.556835949420929,1.9286086559295648,0.4835599958896637,2.271160364151001,50000.0,0.3751000165939331,2.9266579151153564,10000.0,9703.366194963455,10936.4870698452,9703.366194963455,1230.8315978050232,1.1028954982757568,0.0 -20800,1.1506451,3.232943,,,,,,,,,,,,,, -20900,0.91262645,5.628523,,,,,,,,,,,,,, -21000,0.8020037,4.920837,,,,,,,,,,,,,, -21100,1.0170785,3.4048965,,,,,,,,,,,,,, -21200,1.0440581,3.1557364,,,,,,,,,,,,,, -21300,1.1642251,3.1900282,,,,,,,,,,,,,, -21400,0.8260117,5.290009,,,,,,,,,,,,,, -21500,0.87451917,4.0459733,,,,,,,,,,,,,, -21600,1.1041067,3.2110832,,,,,,,,,,,,,, -21605,,,0.5281640291213989,2.053582191467285,0.4886199831962585,2.251695156097412,50000.0,0.3795000314712524,2.898561477661133,10000.0,10123.38339304924,11409.501055002213,10123.38339304924,1283.7493817806244,1.1308624744415283,0.0 -21700,1.0458729,3.2569647,,,,,,,,,,,,,, -21800,1.0141991,3.3171172,,,,,,,,,,,,,, -21900,0.8970286,4.5503783,,,,,,,,,,,,,, -22000,1.1670153,3.0352206,,,,,,,,,,,,,, -22100,1.0259945,3.2833307,,,,,,,,,,,,,, -22200,1.2293247,3.1308618,,,,,,,,,,,,,, -22300,1.1124083,5.52189,,,,,,,,,,,,,, -22400,1.201915,3.0252984,,,,,,,,,,,,,, -22500,1.1430763,2.9852347,,,,,,,,,,,,,, -22504,,,0.5364062190055847,1.9806296825408936,0.4958599805831909,2.1945180892944336,50000.0,0.3874000310897827,2.876436948776245,10000.0,10543.948052406313,11883.90704703331,10543.948052406313,1337.5087454319,1.162177801132202,0.0 -22600,1.2388911,3.101682,,,,,,,,,,,,,, -22700,1.1751401,3.3384054,,,,,,,,,,,,,, -22800,0.73634285,5.483616,,,,,,,,,,,,,, -22900,1.036553,3.1481752,,,,,,,,,,,,,, -23000,1.354987,3.0725372,,,,,,,,,,,,,, -23100,1.021611,3.583063,,,,,,,,,,,,,, -23200,0.88464475,5.485098,,,,,,,,,,,,,, -23300,1.0274639,3.5583692,,,,,,,,,,,,,, -23400,0.97272,3.3259802,,,,,,,,,,,,,, -23404,,,0.5655078291893005,1.908821702003479,0.4991799890995025,2.2292399406433105,50000.0,0.3874000310897827,2.8902664184570312,10000.0,10964.240460157394,12358.45002245903,10964.240460157394,1391.6791398525238,1.190190076828003,0.0 -23500,1.1167936,3.1102705,,,,,,,,,,,,,, -23600,0.9079053,5.482374,,,,,,,,,,,,,, -23700,1.0304611,3.2612052,,,,,,,,,,,,,, -23800,0.965953,3.7174144,,,,,,,,,,,,,, -23900,0.987377,3.5194921,,,,,,,,,,,,,, -24000,1.0078875,3.121285,,,,,,,,,,,,,, -24100,0.7813824,5.443832,,,,,,,,,,,,,, -24200,1.106149,3.1080701,,,,,,,,,,,,,, -24300,,,0.5464257597923279,1.977408051490784,0.5089799761772156,2.1693193912506104,50000.0,0.3955000042915344,2.81775450706482,10000.0,11384.598715543749,12830.399197340012,11384.598715543749,1443.191126346588,1.218409776687622,0.0 -24300,1.2391256,2.8912063,,,,,,,,,,,,,, -24400,1.3424118,3.061146,,,,,,,,,,,,,, -24500,1.2683692,3.0228846,,,,,,,,,,,,,, -24600,1.0623384,2.8927913,,,,,,,,,,,,,, -24700,1.0435921,3.0572739,,,,,,,,,,,,,, -24800,0.96642035,4.74858,,,,,,,,,,,,,, -24900,1.2580421,3.0722783,,,,,,,,,,,,,, -25000,1.0915487,3.4128764,,,,,,,,,,,,,, -25100,0.7671716,5.201717,,,,,,,,,,,,,, -25200,,,0.5544726252555847,1.8977075815200808,0.5140599608421326,2.105320930480957,50000.0,0.4011000096797943,2.787417411804199,10000.0,11804.667268276216,13303.007089614868,11804.667268276216,1495.6441519260406,1.2541024684906006,0.0 -25200,1.0504365,3.1657095,,,,,,,,,,,,,, -25300,0.96753913,3.5916023,,,,,,,,,,,,,, -25400,1.1665659,3.1357265,,,,,,,,,,,,,, -25500,1.0584567,3.5713449,,,,,,,,,,,,,, -25600,0.8594986,4.015044,,,,,,,,,,,,,, -25700,1.1015236,2.974228,,,,,,,,,,,,,, -25800,0.9089969,4.329952,,,,,,,,,,,,,, -25900,0.8857735,4.154649,,,,,,,,,,,,,, -26000,1.1115471,2.970195,,,,,,,,,,,,,, -26100,,,0.5863476395606995,1.7209587097167969,0.522059977054596,2.043886423110962,50000.0,0.4094000160694122,2.7280988693237305,10000.0,12224.890575885773,13778.07923269272,12224.890575885773,1550.408257484436,1.28757643699646,0.0 -26100,1.2962894,3.0046802,,,,,,,,,,,,,, -26200,0.81915325,5.106283,,,,,,,,,,,,,, -26300,1.059835,3.2450435,,,,,,,,,,,,,, -26400,1.1334246,3.196023,,,,,,,,,,,,,, -26500,0.978165,3.0718815,,,,,,,,,,,,,, -26600,1.0738037,3.5864782,,,,,,,,,,,,,, -26700,1.1545669,2.8427236,,,,,,,,,,,,,, -26800,1.1261033,3.2012432,,,,,,,,,,,,,, -26900,1.0889167,3.6246855,,,,,,,,,,,,,, -27000,,,0.5658202767372131,1.8472483158111568,0.5233199596405029,2.056349754333496,50000.0,0.4092000126838684,2.723639726638794,10000.0,12645.1004447937,14250.550843000412,12645.1004447937,1602.5821635723114,1.3235411643981934,0.0 -27000,1.128222,2.9480944,,,,,,,,,,,,,, -27100,1.1179982,2.874032,,,,,,,,,,,,,, -27200,0.87788874,4.5966153,,,,,,,,,,,,,, -27300,0.9390329,3.353895,,,,,,,,,,,,,, -27400,1.1683853,3.0159755,,,,,,,,,,,,,, -27500,1.1202052,3.5800867,,,,,,,,,,,,,, -27600,1.0462378,3.5387025,,,,,,,,,,,,,, -27700,0.92273736,4.519606,,,,,,,,,,,,,, -27800,1.1051534,3.1889505,,,,,,,,,,,,,, -27898,,,0.557812511920929,1.8963301181793213,0.5161399841308594,2.110172748565674,50000.0,0.4037000238895416,2.790687322616577,10000.0,13065.12490582466,14721.279171228409,13065.12490582466,1653.192789554596,1.3615074157714844,0.0 -27900,1.0760646,3.0647454,,,,,,,,,,,,,, -28000,1.1454757,2.9020271,,,,,,,,,,,,,, -28100,1.10284,3.0011456,,,,,,,,,,,,,, -28200,1.2644433,2.824019,,,,,,,,,,,,,, -28300,1.2953365,2.890932,,,,,,,,,,,,,, -28400,0.8262868,5.184815,,,,,,,,,,,,,, -28500,1.1126419,2.844229,,,,,,,,,,,,,, -28600,0.93983376,5.1498523,,,,,,,,,,,,,, -28700,0.7682077,5.4751587,,,,,,,,,,,,,, -28795,,,0.5855664014816284,1.7688456773757937,0.5272600054740906,2.059099674224853,50000.0,0.4162000119686126,2.7257065773010254,10000.0,13485.352035045624,15193.418963193892,13485.352035045624,1705.023080587387,1.392759084701538,0.0 -28800,1.1516987,4.644383,,,,,,,,,,,,,, -28900,1.072431,2.874935,,,,,,,,,,,,,, -29000,0.9110823,3.521678,,,,,,,,,,,,,, -29100,0.934086,3.8395152,,,,,,,,,,,,,, -29200,0.7703969,5.0757284,,,,,,,,,,,,,, -29300,1.3656497,3.0096807,,,,,,,,,,,,,, -29400,1.0044625,3.2015343,,,,,,,,,,,,,, -29500,0.95079625,4.607183,,,,,,,,,,,,,, -29600,1.1342701,2.963546,,,,,,,,,,,,,, -29694,,,0.5672265291213989,1.8366743326187127,0.526699960231781,2.0453226566314697,50000.0,0.4116000235080719,2.7133548259735107,10000.0,13905.495992660522,15666.241523504255,13905.495992660522,1757.6186830997467,1.4242591857910156,0.0 -29700,1.0719129,2.9937499,,,,,,,,,,,,,, -29800,1.3006167,2.905778,,,,,,,,,,,,,, -29900,0.9387445,4.5310297,,,,,,,,,,,,,, -30000,0.9936629,3.376429,,,,,,,,,,,,,, -30100,1.1021326,3.0158634,,,,,,,,,,,,,, -30200,0.95814425,3.6973782,,,,,,,,,,,,,, -30300,0.8490957,5.328363,,,,,,,,,,,,,, -30400,1.3621718,2.8530538,,,,,,,,,,,,,, -30500,0.8017171,5.405604,,,,,,,,,,,,,, -30595,,,0.5782226324081421,1.7867298126220703,0.5320000052452087,2.0195980072021484,50000.0,0.4234000146389007,2.6812961101531982,10000.0,14325.84366250038,16138.692564725876,14325.84366250038,1809.639849185944,1.4553961753845217,0.0 -30600,1.1752852,2.9200494,,,,,,,,,,,,,, -30700,0.85985434,4.645159,,,,,,,,,,,,,, -30800,1.1795294,2.9830685,,,,,,,,,,,,,, -30900,1.0007522,3.6200316,,,,,,,,,,,,,, -31000,0.8304036,5.035512,,,,,,,,,,,,,, -31100,1.0641888,3.4223807,,,,,,,,,,,,,, -31200,1.123242,2.743732,,,,,,,,,,,,,, -31300,1.0905558,5.4965434,,,,,,,,,,,,,, -31400,0.9996398,3.5470657,,,,,,,,,,,,,, -31494,,,0.5876171588897705,1.7332189083099363,0.5371800065040588,1.998450875282288,50000.0,0.4247000217437744,2.6817445755004883,10000.0,14745.781195640564,16610.330486536026,14745.781195640564,1861.2554004192352,1.4885175228118896,0.0 -31500,1.0453221,3.217936,,,,,,,,,,,,,, -31600,1.3686055,2.9212036,,,,,,,,,,,,,, -31700,0.8658943,4.579705,,,,,,,,,,,,,, -31800,1.0773256,3.284882,,,,,,,,,,,,,, -31900,0.9716241,3.4603305,,,,,,,,,,,,,, -32000,1.1921073,2.9573236,,,,,,,,,,,,,, -32100,0.92443085,3.8269145,,,,,,,,,,,,,, -32200,1.3523046,2.8741622,,,,,,,,,,,,,, -32300,1.1451383,2.8434327,,,,,,,,,,,,,, -32397,,,0.572265625,1.8312280178070068,0.5342599749565125,2.031036615371704,50000.0,0.4207000136375427,2.69540810585022,10000.0,15165.913511514664,17081.899649858475,15165.913511514664,1912.60741186142,1.521423101425171,0.0 -32400,0.9592001,3.7344227,,,,,,,,,,,,,, -32500,1.0758553,2.7534604,,,,,,,,,,,,,, -32600,1.0490996,3.37328,,,,,,,,,,,,,, -32700,0.97350985,4.119601,,,,,,,,,,,,,, -32800,0.9975628,5.2486734,,,,,,,,,,,,,, -32900,0.9468931,4.7749,,,,,,,,,,,,,, -33000,1.2382379,2.839482,,,,,,,,,,,,,, -33100,1.2791518,3.088995,,,,,,,,,,,,,, -33200,1.0417719,3.3914816,,,,,,,,,,,,,, -33298,,,0.5862890481948853,1.739125370979309,0.5410199761390686,1.962300419807434,50000.0,0.4212000072002411,2.652564764022827,10000.0,15585.88121342659,17553.76842737198,15585.88121342659,1964.42044878006,1.557866096496582,0.0 -33300,1.2553539,5.547767,,,,,,,,,,,,,, -33400,0.8475644,4.838604,,,,,,,,,,,,,, -33500,1.2358738,2.8938847,,,,,,,,,,,,,, -33600,0.92136586,5.172558,,,,,,,,,,,,,, -33700,1.0785345,2.8855786,,,,,,,,,,,,,, -33800,0.93143296,3.9122982,,,,,,,,,,,,,, -33900,0.9159777,4.5454817,,,,,,,,,,,,,, -34000,0.93865263,4.7522774,,,,,,,,,,,,,, -34100,1.272045,2.89372,,,,,,,,,,,,,, -34200,1.0147748,3.395181,,,,,,,,,,,,,, -34204,,,0.5970116853713989,1.712268590927124,0.5445600152015686,1.9671906232833865,50000.0,0.4263000190258026,2.6361184120178223,10000.0,16006.172748804092,18025.4995098114,16006.172748804092,2015.772707223892,1.5926804542541504,0.0 -34300,1.0613577,2.9334276,,,,,,,,,,,,,, -34400,1.330813,2.8162763,,,,,,,,,,,,,, -34500,1.1524025,2.7991176,,,,,,,,,,,,,, -34600,0.92790335,4.490558,,,,,,,,,,,,,, -34700,0.8663418,5.450371,,,,,,,,,,,,,, -34800,1.0599797,2.9765272,,,,,,,,,,,,,, -34900,0.92097825,3.917087,,,,,,,,,,,,,, -35000,0.9362453,4.50173,,,,,,,,,,,,,, -35100,1.278579,2.9149945,,,,,,,,,,,,,, -35105,,,0.5822460651397705,1.7583211660385132,0.5433599948883057,1.9646700620651243,50000.0,0.4274000227451324,2.626077175140381,10000.0,16426.362616062164,18496.10616064072,16426.362616062164,2066.1069979667664,1.6230809688568115,0.0 -35200,1.0885471,2.9375062,,,,,,,,,,,,,, -35300,1.030528,3.4094477,,,,,,,,,,,,,, -35400,1.0754745,3.3190868,,,,,,,,,,,,,, -35500,1.0969983,4.026483,,,,,,,,,,,,,, -35600,0.92973,4.4605637,,,,,,,,,,,,,, -35700,1.3749942,2.8308249,,,,,,,,,,,,,, -35800,1.1519492,3.0363684,,,,,,,,,,,,,, -35900,0.978466,4.091699,,,,,,,,,,,,,, -36000,1.1775225,2.9331882,,,,,,,,,,,,,, -36006,,,0.5903906226158142,1.75039541721344,0.546999990940094,1.9551115036010744,50000.0,0.433100014925003,2.605111837387085,10000.0,16846.550374031067,18968.680662870407,16846.550374031067,2118.408214569092,1.6574127674102783,0.0 -36100,0.9471928,5.3785567,,,,,,,,,,,,,, -36200,1.1845981,3.0612853,,,,,,,,,,,,,, -36300,1.3205684,2.6834645,,,,,,,,,,,,,, -36400,1.3222312,2.8549056,,,,,,,,,,,,,, -36500,0.93388337,3.7352886,,,,,,,,,,,,,, -36600,1.1806494,2.8181672,,,,,,,,,,,,,, -36700,1.0229441,3.2821982,,,,,,,,,,,,,, -36800,1.2133659,2.9749854,,,,,,,,,,,,,, -36900,1.2060362,2.810832,,,,,,,,,,,,,, -36905,,,0.5983593463897705,1.7371487617492676,0.548259973526001,1.973673343658448,50000.0,0.4347000122070312,2.613460063934326,10000.0,17266.725713968277,19440.53175663948,17266.725713968277,2169.996456384659,1.6937315464019775,0.0 -37000,1.0821608,4.9986353,,,,,,,,,,,,,, -37100,0.859811,4.6798153,,,,,,,,,,,,,, -37200,1.1845057,3.2154636,,,,,,,,,,,,,, -37300,1.2007674,2.7539973,,,,,,,,,,,,,, -37400,1.2236689,2.8208525,,,,,,,,,,,,,, -37500,0.9541387,3.5748758,,,,,,,,,,,,,, -37600,1.2028552,2.7944093,,,,,,,,,,,,,, -37700,1.0938944,2.833116,,,,,,,,,,,,,, -37800,0.92149836,5.0325994,,,,,,,,,,,,,, -37806,,,0.5934765338897705,1.7180927991867063,0.549560010433197,1.944177746772766,50000.0,0.4330000281333923,2.59557580947876,10000.0,17686.66410136223,19911.408517360687,17686.66410136223,2220.849625825882,1.7289478778839111,0.0 -37900,1.0934329,2.787385,,,,,,,,,,,,,, -38000,1.1540285,2.7305286,,,,,,,,,,,,,, -38100,1.3347226,2.670548,,,,,,,,,,,,,, -38200,1.1415632,2.6955507,,,,,,,,,,,,,, -38300,0.85916436,4.8554764,,,,,,,,,,,,,, -38400,1.311071,3.0343223,,,,,,,,,,,,,, -38500,1.2030803,2.8570638,,,,,,,,,,,,,, -38600,1.0322275,3.076562,,,,,,,,,,,,,, -38700,1.0163016,3.2963316,,,,,,,,,,,,,, -38706,,,0.5957421660423279,1.7233281135559082,0.5529400110244751,1.9345680475234983,50000.0,0.4406000077724457,2.6032638549804688,10000.0,18106.65107870102,20382.752317667007,18106.65107870102,2272.119005203247,1.765486717224121,0.0 -38800,1.1394167,2.789346,,,,,,,,,,,,,, -38900,0.83475894,5.2569966,,,,,,,,,,,,,, -39000,1.1613401,4.7702303,,,,,,,,,,,,,, -39100,1.178561,3.031248,,,,,,,,,,,,,, -39200,1.0684133,3.295059,,,,,,,,,,,,,, -39300,1.1815203,2.8357177,,,,,,,,,,,,,, -39400,0.9383284,3.75552,,,,,,,,,,,,,, -39500,1.2521483,2.9289799,,,,,,,,,,,,,, -39600,1.2456809,2.6649156,,,,,,,,,,,,,, -39610,,,0.6048241853713989,1.6466890573501587,0.558139979839325,1.8898195028305047,50000.0,0.4489000141620636,2.53645920753479,10000.0,18526.90128660202,20856.095355033875,18526.90128660202,2325.1231849193573,1.80271577835083,0.0 -39700,1.2554418,2.7034233,,,,,,,,,,,,,, -39800,1.05399,3.4698102,,,,,,,,,,,,,, -39900,1.2246311,2.711698,,,,,,,,,,,,,, -40000,1.025117,3.0457888,,,,,,,,,,,,,, -40100,1.236086,2.8228774,,,,,,,,,,,,,, -40200,1.2808974,2.758224,,,,,,,,,,,,,, -40300,1.1249646,3.0141437,,,,,,,,,,,,,, -40400,1.2051554,2.6329236,,,,,,,,,,,,,, -40500,1.0285637,3.436312,,,,,,,,,,,,,, -40511,,,0.6100000143051147,1.6908632516860962,0.5554800033569336,1.9389760494232176,50000.0,0.4395000338554382,2.585554361343384,10000.0,18947.14054918289,21328.02912712097,18947.14054918289,2376.72646856308,1.842914581298828,0.0 -40600,0.861235,5.3063464,,,,,,,,,,,,,, -40700,1.0631466,2.787697,,,,,,,,,,,,,, -40800,0.90658873,5.0462904,,,,,,,,,,,,,, -40900,1.0462488,3.4487967,,,,,,,,,,,,,, -41000,0.9116227,3.9981873,,,,,,,,,,,,,, -41100,1.1280937,5.134641,,,,,,,,,,,,,, -41200,1.2016037,2.7536478,,,,,,,,,,,,,, -41300,1.278464,2.7322962,,,,,,,,,,,,,, -41400,0.9629541,5.277755,,,,,,,,,,,,,, -41412,,,0.6074023246765137,1.6232702732086182,0.5635200142860413,1.8514550924301147,50000.0,0.4462000131607055,2.5395474433898926,10000.0,19367.37152767181,21799.19383573532,19367.37152767181,2427.576056241989,1.8747284412384035,0.0 -41500,0.9107114,3.7219834,,,,,,,,,,,,,, -41600,1.1518028,2.8119335,,,,,,,,,,,,,, -41700,0.9465632,4.0218053,,,,,,,,,,,,,, -41800,1.1174036,2.7771742,,,,,,,,,,,,,, -41900,1.1313586,2.991086,,,,,,,,,,,,,, -42000,1.3863351,2.8472633,,,,,,,,,,,,,, -42100,1.035324,3.7574496,,,,,,,,,,,,,, -42200,1.1317229,2.8791218,,,,,,,,,,,,,, -42300,1.097644,4.346783,,,,,,,,,,,,,, -42313,,,0.610644519329071,1.6274819374084473,0.561519980430603,1.8759976625442505,50000.0,0.4504000246524811,2.53519868850708,10000.0,19787.460819482803,22274.2334086895,19787.460819482803,2482.440875768661,1.907930850982666,0.0 -42400,1.0492024,5.0505795,,,,,,,,,,,,,, -42500,1.1159027,4.3425326,,,,,,,,,,,,,, -42600,1.178064,2.7991927,,,,,,,,,,,,,, -42700,1.0902443,3.1404247,,,,,,,,,,,,,, -42800,1.2783778,2.722836,,,,,,,,,,,,,, -42900,0.9561732,5.1120396,,,,,,,,,,,,,, -43000,0.93273735,5.304814,,,,,,,,,,,,,, -43100,1.0278554,3.5389912,,,,,,,,,,,,,, -43200,1.0567423,3.6180623,,,,,,,,,,,,,, -43216,,,0.6401757597923279,1.5094099044799805,0.5685200095176697,1.8535585403442385,50000.0,0.4533000290393829,2.504759788513184,10000.0,20207.386869430546,22748.151103019714,20207.386869430546,2536.3484270572662,1.941042184829712,0.0 -43300,0.96430373,4.6300335,,,,,,,,,,,,,, -43400,1.1259232,2.7223048,,,,,,,,,,,,,, -43500,1.1994643,2.6470108,,,,,,,,,,,,,, -43600,1.2044683,2.676199,,,,,,,,,,,,,, -43700,1.2410623,2.6930594,,,,,,,,,,,,,, -43800,1.1110156,2.603564,,,,,,,,,,,,,, -43900,0.84236634,5.133077,,,,,,,,,,,,,, -44000,1.2166071,2.725051,,,,,,,,,,,,,, -44100,1.1214844,3.0380776,,,,,,,,,,,,,, -44115,,,0.615527331829071,1.6147336959838867,0.5694000124931335,1.8382197618484497,50000.0,0.4478000104427337,2.5037691593170166,10000.0,20627.419096708298,23222.031039714813,20627.419096708298,2590.112258195877,1.9733588695526123,0.0 -44200,1.0968627,2.619194,,,,,,,,,,,,,, -44300,1.1722555,2.7787411,,,,,,,,,,,,,, -44400,1.4428389,2.7760596,,,,,,,,,,,,,, -44500,1.1398112,4.5607204,,,,,,,,,,,,,, -44600,1.0188236,4.256463,,,,,,,,,,,,,, -44700,1.2052323,2.7475634,,,,,,,,,,,,,, -44800,1.0702239,3.0186665,,,,,,,,,,,,,, -44900,1.6027862,2.8420117,,,,,,,,,,,,,, -45000,1.191231,2.8518915,,,,,,,,,,,,,, -45016,,,0.6134960651397705,1.659540057182312,0.5658400058746338,1.892349362373352,50000.0,0.4537000358104706,2.542588472366333,10000.0,21047.822548627853,23695.17087316513,21047.822548627853,2642.76513504982,2.0060572624206543,0.0 -45100,0.9640099,4.475194,,,,,,,,,,,,,, -45200,1.2559493,2.790698,,,,,,,,,,,,,, -45300,1.16392,2.701948,,,,,,,,,,,,,, -45400,0.9703524,3.7297506,,,,,,,,,,,,,, -45500,0.9668784,3.978869,,,,,,,,,,,,,, -45600,1.1341373,2.470723,,,,,,,,,,,,,, -45700,1.2754002,2.7839937,,,,,,,,,,,,,, -45800,1.1568561,2.677434,,,,,,,,,,,,,, -45900,1.0826321,5.0544224,,,,,,,,,,,,,, -45915,,,0.6446874737739563,1.4981642961502075,0.5705599784851074,1.83890163898468,50000.0,0.4538000226020813,2.495013952255249,10000.0,21468.1283724308,24166.16872239113,21468.1283724308,2693.3661007881165,2.0465316772460938,0.0 -46000,1.1478707,2.9464726,,,,,,,,,,,,,, -46100,1.244506,2.8179574,,,,,,,,,,,,,, -46200,1.0197848,3.6508052,,,,,,,,,,,,,, -46300,1.1012014,2.947418,,,,,,,,,,,,,, -46400,1.1904483,2.9611626,,,,,,,,,,,,,, -46500,1.1334938,2.618602,,,,,,,,,,,,,, -46600,1.2153713,2.6674492,,,,,,,,,,,,,, -46700,0.94172126,3.5643344,,,,,,,,,,,,,, -46800,1.1124966,2.7152557,,,,,,,,,,,,,, -46809,,,0.6206445097923279,1.6042301654815674,0.5744799971580505,1.81850016117096,50000.0,0.4558000266551971,2.4935083389282227,10000.0,21888.07623887062,24639.414632320404,21888.07623887062,2746.573971748352,2.0856590270996094,0.0 -46900,1.0778348,3.518754,,,,,,,,,,,,,, -47000,0.9758308,4.5284243,,,,,,,,,,,,,, -47100,1.0284787,3.8310485,,,,,,,,,,,,,, -47200,1.146975,2.6457345,,,,,,,,,,,,,, -47300,1.1793307,2.7537775,,,,,,,,,,,,,, -47400,1.2357897,2.5992544,,,,,,,,,,,,,, -47500,1.1801822,2.7238903,,,,,,,,,,,,,, -47600,0.9999719,3.3153555,,,,,,,,,,,,,, -47700,0.9643562,3.7340305,,,,,,,,,,,,,, -47711,,,0.6209374666213989,1.5577152967453003,0.5724999904632568,1.797107458114624,50000.0,0.4577000141143799,2.455000400543213,10000.0,22308.360613822937,25112.828466653824,22308.360613822937,2799.619107723236,2.118630886077881,0.0 -47800,1.1513765,2.7705033,,,,,,,,,,,,,, -47900,1.1969223,2.6773741,,,,,,,,,,,,,, -48000,1.0872477,3.05645,,,,,,,,,,,,,, -48100,1.1235874,2.674977,,,,,,,,,,,,,, -48200,1.1568841,2.8563538,,,,,,,,,,,,,, -48300,1.3784564,2.696611,,,,,,,,,,,,,, -48400,1.2410302,2.7123656,,,,,,,,,,,,,, -48500,0.9710951,4.8529463,,,,,,,,,,,,,, -48600,1.0216539,3.5064485,,,,,,,,,,,,,, -48614,,,0.6391406059265137,1.4967988729476929,0.5768799781799316,1.8095238208770752,50000.0,0.4567000269889831,2.475313901901245,10000.0,22728.47146344185,25586.20291495323,22728.47146344185,2852.789181947708,2.153792381286621,0.0 -48700,0.92307925,4.6028047,,,,,,,,,,,,,, -48800,1.2707285,2.6724715,,,,,,,,,,,,,, -48900,1.5797582,2.7100172,,,,,,,,,,,,,, -49000,0.9723965,5.202457,,,,,,,,,,,,,, -49100,1.3191286,2.665113,,,,,,,,,,,,,, -49200,1.1473191,2.7146974,,,,,,,,,,,,,, -49300,0.9346528,4.0016823,,,,,,,,,,,,,, -49400,1.1128545,2.6968064,,,,,,,,,,,,,, -49500,1.1413836,2.6995263,,,,,,,,,,,,,, -49511,,,0.6273437142372131,1.5240154266357422,0.5815399885177612,1.7533224821090698,50000.0,0.4651000201702118,2.4164364337921143,10000.0,23148.824804782867,26059.92187833786,23148.824804782867,2906.063053369522,2.1936678886413574,0.0 -49600,1.2259519,2.8291478,,,,,,,,,,,,,, -49700,1.2931437,2.5875263,,,,,,,,,,,,,, -49800,1.3614932,2.7954507,,,,,,,,,,,,,, -49900,1.1366451,5.241285,,,,,,,,,,,,,, -50000,0.9665887,5.1589403,,,,,,,,,,,,,, -50100,1.2646781,2.6472821,,,,,,,,,,,,,, -50200,1.1568567,2.7628908,,,,,,,,,,,,,, -50300,1.3946444,2.628679,,,,,,,,,,,,,, -50400,0.903842,4.836564,,,,,,,,,,,,,, -50409,,,0.6357030868530273,1.4875028133392334,0.5874599814414978,1.7306010723114014,50000.0,0.4692000150680542,2.4118294715881348,10000.0,23569.12908434868,26531.848781347275,23569.12908434868,2957.601359605789,2.2275655269622803,0.0 -50500,1.2123697,2.5814657,,,,,,,,,,,,,, -50600,1.2023004,2.6066737,,,,,,,,,,,,,, -50700,1.0651973,3.435133,,,,,,,,,,,,,, -50800,0.9449127,4.195023,,,,,,,,,,,,,, -50900,1.23906,2.7662666,,,,,,,,,,,,,, -51000,1.3321013,2.7778676,,,,,,,,,,,,,, -51100,1.0246764,5.3346415,,,,,,,,,,,,,, -51200,1.1935833,2.666104,,,,,,,,,,,,,, -51300,1.1170019,2.8327127,,,,,,,,,,,,,, -51307,,,0.6474804282188416,1.4440979957580566,0.5879999995231628,1.7363550662994385,50000.0,0.4647000133991241,2.418069839477539,10000.0,23989.14796257019,27006.25090765953,23989.14796257019,3011.8984134197235,2.261938571929932,0.0 -51400,1.2079495,2.5670257,,,,,,,,,,,,,, -51500,1.1175375,2.839415,,,,,,,,,,,,,, -51600,1.1195743,3.1451206,,,,,,,,,,,,,, -51700,1.2883375,2.8281152,,,,,,,,,,,,,, -51800,0.99589866,5.3353653,,,,,,,,,,,,,, -51900,1.085872,3.1308336,,,,,,,,,,,,,, -52000,1.0776083,3.4809623,,,,,,,,,,,,,, -52100,1.2171842,2.577859,,,,,,,,,,,,,, -52200,1.0936123,2.6119552,,,,,,,,,,,,,, -52205,,,0.6319531202316284,1.5517568588256836,0.5824599862098694,1.791892170906067,50000.0,0.461400032043457,2.442796230316162,10000.0,24409.063949108124,27477.90633225441,24409.063949108124,3063.552988052368,2.295424699783325,0.0 -52300,1.2216717,3.763647,,,,,,,,,,,,,, -52400,1.439833,2.7319038,,,,,,,,,,,,,, -52500,1.1920073,2.5824142,,,,,,,,,,,,,, -52600,0.9362385,5.115755,,,,,,,,,,,,,, -52700,1.0694627,5.150553,,,,,,,,,,,,,, -52800,1.2223114,2.712873,,,,,,,,,,,,,, -52900,0.96886253,3.351018,,,,,,,,,,,,,, -53000,1.3515128,2.630859,,,,,,,,,,,,,, -53100,1.2488489,2.7467463,,,,,,,,,,,,,, -53106,,,0.6414257884025574,1.4624534845352173,0.5923799872398376,1.7016264200210571,50000.0,0.4686000347137451,2.3851122856140137,10000.0,24829.066326379776,27951.72364020348,24829.066326379776,3117.278034448624,2.3338751792907715,0.0 -53200,1.2953033,2.6685524,,,,,,,,,,,,,, -53300,1.4878819,2.6553545,,,,,,,,,,,,,, -53400,1.3796893,2.658653,,,,,,,,,,,,,, -53500,1.1855042,2.5598822,,,,,,,,,,,,,, -53600,1.2067894,2.7502103,,,,,,,,,,,,,, -53700,1.2457551,2.6302943,,,,,,,,,,,,,, -53800,1.1619251,2.6473832,,,,,,,,,,,,,, -53900,1.2945712,2.7931786,,,,,,,,,,,,,, -54000,1.0751909,3.5895069,,,,,,,,,,,,,, -54005,,,0.6461132764816284,1.4805848598480225,0.5896399617195129,1.7537370920181274,50000.0,0.4639000296592712,2.440372467041016,10000.0,25249.0579559803,28423.403182029724,25249.0579559803,3168.877803325653,2.370557308197021,0.0 -54100,0.9280587,3.7466054,,,,,,,,,,,,,, -54200,1.2996339,2.733986,,,,,,,,,,,,,, -54300,1.0699108,3.0762045,,,,,,,,,,,,,, -54400,1.2627274,2.7814898,,,,,,,,,,,,,, -54500,1.2755327,2.4896603,,,,,,,,,,,,,, -54600,1.0447395,3.5178645,,,,,,,,,,,,,, -54700,1.3269712,2.6395571,,,,,,,,,,,,,, -54800,1.2635727,2.676521,,,,,,,,,,,,,, -54900,1.107243,3.7867868,,,,,,,,,,,,,, -54904,,,0.6288476586341858,1.5620810985565186,0.5803399682044983,1.7848657369613647,50000.0,0.4611000120639801,2.4644737243652344,10000.0,25669.08655667305,28894.5414185524,25669.08655667305,3219.899830341339,2.4072391986846924,0.0 -55000,1.5071331,2.6703718,,,,,,,,,,,,,, -55100,1.1232362,5.0066185,,,,,,,,,,,,,, -55200,1.0775291,3.4077716,,,,,,,,,,,,,, -55300,1.0335811,5.017677,,,,,,,,,,,,,, -55400,1.255988,2.691362,,,,,,,,,,,,,, -55500,1.0850655,3.7670968,,,,,,,,,,,,,, -55600,1.1038394,3.0966334,,,,,,,,,,,,,, -55700,1.3070723,2.727253,,,,,,,,,,,,,, -55800,1.3533604,2.5281718,,,,,,,,,,,,,, -55805,,,0.6421093344688416,1.480443835258484,0.5927599668502808,1.717319130897522,50000.0,0.4727000296115875,2.3828577995300293,10000.0,26089.343856811523,29368.97186183929,26089.343856811523,3273.9854452610016,2.443835020065308,0.0 -55900,1.3086305,2.609704,,,,,,,,,,,,,, -56000,1.0964806,3.4732027,,,,,,,,,,,,,, -56100,1.0841523,2.8367176,,,,,,,,,,,,,, -56200,1.059004,2.9697502,,,,,,,,,,,,,, -56300,1.2279592,2.711938,,,,,,,,,,,,,, -56400,1.3511523,2.6272302,,,,,,,,,,,,,, -56500,1.4640214,2.5879583,,,,,,,,,,,,,, -56600,0.9646921,4.545494,,,,,,,,,,,,,, -56700,1.1412385,3.1634102,,,,,,,,,,,,,, -56706,,,0.6518945097923279,1.4154984951019287,0.5967999696731567,1.6895005702972412,50000.0,0.4781000316143036,2.3427188396453857,10000.0,26509.37562179565,29841.977155923843,26509.37562179565,3326.871330738068,2.479809284210205,0.0 -56800,1.1266643,3.2468758,,,,,,,,,,,,,, -56900,1.015465,5.026322,,,,,,,,,,,,,, -57000,1.2631161,2.6284466,,,,,,,,,,,,,, -57100,1.0705179,4.0193043,,,,,,,,,,,,,, -57200,1.1713259,2.581166,,,,,,,,,,,,,, -57300,1.2163821,2.726427,,,,,,,,,,,,,, -57400,1.1925933,2.500396,,,,,,,,,,,,,, -57500,1.219208,2.7578712,,,,,,,,,,,,,, -57600,1.091447,4.959499,,,,,,,,,,,,,, -57605,,,0.6357030868530273,1.5449306964874268,0.5915799736976624,1.7564728260040283,50000.0,0.4772000312805176,2.4073524475097656,10000.0,26929.51243591309,30316.034165859222,26929.51243591309,3380.7066905498505,2.51355242729187,0.0 -57700,1.2166791,2.6059175,,,,,,,,,,,,,, -57800,1.1451856,2.6246562,,,,,,,,,,,,,, -57900,0.9762446,4.3025627,,,,,,,,,,,,,, -58000,1.0186341,4.180196,,,,,,,,,,,,,, -58100,1.3824713,2.768912,,,,,,,,,,,,,, -58200,1.3816959,2.6126127,,,,,,,,,,,,,, -58300,1.2448608,2.5565753,,,,,,,,,,,,,, -58400,1.0884669,4.0326276,,,,,,,,,,,,,, -58500,1.1223795,2.523056,,,,,,,,,,,,,, -58505,,,0.6467968821525574,1.471572995185852,0.5975599884986877,1.7017605304718018,50000.0,0.4745000302791595,2.364131689071656,10000.0,27349.63072085381,30790.28051924705,27349.63072085381,3434.7488169670105,2.548457622528076,0.0 -58600,1.3280319,2.6691613,,,,,,,,,,,,,, -58700,0.99094695,3.9858258,,,,,,,,,,,,,, -58800,1.1146128,3.023512,,,,,,,,,,,,,, -58900,1.2749527,2.5397584,,,,,,,,,,,,,, -59000,1.0377233,4.037864,,,,,,,,,,,,,, -59100,0.95833945,4.5656204,,,,,,,,,,,,,, -59200,1.2141178,2.4387147,,,,,,,,,,,,,, -59300,1.1991336,5.2700553,,,,,,,,,,,,,, -59400,1.4156057,2.3876421,,,,,,,,,,,,,, -59407,,,0.6440038681030273,1.4910470247268677,0.590719997882843,1.7559362649917605,50000.0,0.4765000343322754,2.4175078868865967,10000.0,27769.91752099991,31261.66304731369,27769.91752099991,3485.748226881027,2.594045400619507,0.0 -59500,1.2462322,2.6027045,,,,,,,,,,,,,, -59600,0.9856471,4.974103,,,,,,,,,,,,,, -59700,1.1592829,2.446273,,,,,,,,,,,,,, -59800,1.222478,3.15951,,,,,,,,,,,,,, -59900,1.1055214,4.1069746,,,,,,,,,,,,,, -60000,1.2279377,2.6137826,,,,,,,,,,,,,, -60100,1.3896469,2.52212,,,,,,,,,,,,,, -60200,1.3577965,2.569542,,,,,,,,,,,,,, -60300,1.3794001,2.4751263,,,,,,,,,,,,,, -60301,,,0.6430468559265137,1.471490502357483,0.5990399718284607,1.6971635818481443,50000.0,0.4755000174045563,2.361530542373657,10000.0,28190.01903152466,31732.694142580032,28190.01903152466,3536.589988708496,2.6317689418792725,0.0 -60400,1.0716047,4.333805,,,,,,,,,,,,,, -60500,1.254171,2.3572865,,,,,,,,,,,,,, -60600,1.089306,5.1651015,,,,,,,,,,,,,, -60700,0.9641115,4.8166103,,,,,,,,,,,,,, -60800,1.0397742,4.0339317,,,,,,,,,,,,,, -60900,1.2239324,2.637685,,,,,,,,,,,,,, -61000,1.2746581,2.5882485,,,,,,,,,,,,,, -61100,1.2175739,2.4256487,,,,,,,,,,,,,, -61200,1.3057567,2.5234728,,,,,,,,,,,,,, -61201,,,0.6446093320846558,1.4742058515548706,0.6018999814987183,1.6856762170791626,50000.0,0.4775000214576721,2.361438274383545,10000.0,28609.990909576416,32205.369146585464,28609.990909576416,3589.198195934296,2.6737277507781982,0.0 -61300,1.1912683,2.3743496,,,,,,,,,,,,,, -61400,1.1907462,5.100758,,,,,,,,,,,,,, -61500,0.99838924,4.8738446,,,,,,,,,,,,,, -61600,1.2650907,2.5576637,,,,,,,,,,,,,, -61700,1.0148724,5.0644546,,,,,,,,,,,,,, -61800,1.1036042,2.9505258,,,,,,,,,,,,,, -61900,1.028264,4.0170593,,,,,,,,,,,,,, -62000,1.3135397,2.5015724,,,,,,,,,,,,,, -62100,1.2777493,2.5125604,,,,,,,,,,,,,, -62101,,,0.6470312476158142,1.4702602624893188,0.5948399901390076,1.7143629789352417,50000.0,0.4792000353336334,2.371155500411988,10000.0,29030.119871616364,32679.75479578972,29030.119871616364,3643.361774921417,2.71584153175354,0.0 -62200,1.0716484,3.0956085,,,,,,,,,,,,,, -62300,1.0159404,4.1465178,,,,,,,,,,,,,, -62400,1.0733114,3.179152,,,,,,,,,,,,,, -62500,1.2301393,3.6323247,,,,,,,,,,,,,, -62600,1.3569598,2.5124705,,,,,,,,,,,,,, -62700,0.9991827,5.020974,,,,,,,,,,,,,, -62800,1.1083602,3.0908074,,,,,,,,,,,,,, -62900,1.3771249,2.5324755,,,,,,,,,,,,,, -63000,1.3040487,2.6725216,,,,,,,,,,,,,, -63001,,,0.6622851490974426,1.3750454187393188,0.6061999797821045,1.6341488361358645,50000.0,0.4879000186920166,2.299943447113037,10000.0,29450.567351818085,33153.461698293686,29450.567351818085,3696.535520792008,2.7506535053253174,0.0 -63100,1.1918751,2.5149431,,,,,,,,,,,,,, -63200,1.2502348,2.480088,,,,,,,,,,,,,, -63300,1.2480619,3.4237278,,,,,,,,,,,,,, -63400,1.0774065,4.9968657,,,,,,,,,,,,,, -63500,1.1041634,4.1014547,,,,,,,,,,,,,, -63600,1.2961233,2.4631944,,,,,,,,,,,,,, -63700,1.0781971,3.7749913,,,,,,,,,,,,,, -63800,1.27107,2.4421062,,,,,,,,,,,,,, -63900,0.9594029,4.3458033,,,,,,,,,,,,,, -63903,,,0.6542187333106995,1.415706753730774,0.6043999791145325,1.652857780456543,50000.0,0.4817000329494476,2.317942380905152,10000.0,29870.9140355587,33626.445133686066,29870.9140355587,3749.080408573151,2.791459083557129,0.0 -64000,1.408558,2.3558376,,,,,,,,,,,,,, -64100,1.2376446,2.7741604,,,,,,,,,,,,,, -64200,1.3741854,2.4164572,,,,,,,,,,,,,, -64300,1.5171487,2.5226765,,,,,,,,,,,,,, -64400,1.3436999,2.3686948,,,,,,,,,,,,,, -64500,1.2800081,2.6354077,,,,,,,,,,,,,, -64600,1.4044873,2.565231,,,,,,,,,,,,,, -64700,1.254619,2.5458446,,,,,,,,,,,,,, -64800,1.3529049,2.4677076,,,,,,,,,,,,,, -64801,,,0.6602929830551147,1.405190348625183,0.6033799648284912,1.6764905452728271,50000.0,0.4809000194072723,2.324260711669922,10000.0,30291.509792089462,34099.37319946289,30291.509792089462,3801.326912641525,2.8265275955200195,0.0 -64900,1.1422015,3.393499,,,,,,,,,,,,,, -65000,1.0528481,5.170508,,,,,,,,,,,,,, -65100,1.0882612,4.228487,,,,,,,,,,,,,, -65200,1.2053899,4.4739704,,,,,,,,,,,,,, -65300,1.1826892,2.9283903,,,,,,,,,,,,,, -65400,1.3965334,2.6182544,,,,,,,,,,,,,, -65500,1.3549043,2.6405864,,,,,,,,,,,,,, -65600,1.1240293,2.999327,,,,,,,,,,,,,, -65700,1.1979373,2.793619,,,,,,,,,,,,,, -65703,,,0.6716015338897705,1.325235366821289,0.6054999828338623,1.634854793548584,50000.0,0.4862000346183777,2.297815322875977,10000.0,30711.8409523964,34570.10649561882,30711.8409523964,3851.639670610428,2.863790512084961,0.0 -65800,1.581551,2.7823133,,,,,,,,,,,,,, -65900,1.2541895,2.4645905,,,,,,,,,,,,,, -66000,1.2576917,2.6490204,,,,,,,,,,,,,, -66100,1.2347564,2.5930746,,,,,,,,,,,,,, -66200,1.0636678,3.6796916,,,,,,,,,,,,,, -66300,1.2442077,2.6906297,,,,,,,,,,,,,, -66400,1.0540441,4.3787146,,,,,,,,,,,,,, -66500,0.98853326,5.0202994,,,,,,,,,,,,,, -66599,,,0.6601366996765137,1.4673115015029907,0.6058799624443054,1.716723084449768,50000.0,0.4908000230789184,2.349582433700561,10000.0,31131.927243471146,35039.92137694359,31131.927243471146,3901.277138948441,2.9034883975982666,0.0 -66600,1.41845,2.572106,,,,,,,,,,,,,, -66700,1.1255834,5.030598,,,,,,,,,,,,,, -66800,1.1580511,2.3846447,,,,,,,,,,,,,, -66900,1.0646992,2.8550081,,,,,,,,,,,,,, -67000,1.1078192,2.9735773,,,,,,,,,,,,,, -67100,1.294452,4.208972,,,,,,,,,,,,,, -67200,1.0051286,5.107565,,,,,,,,,,,,,, -67300,1.3055944,2.465776,,,,,,,,,,,,,, -67400,1.2578708,2.368309,,,,,,,,,,,,,, -67499,,,0.6662890315055847,1.3695124387741089,0.6140199899673462,1.6112968921661377,50000.0,0.4932000339031219,2.2746410369873047,10000.0,31552.259560346603,35511.90776872635,31552.259560346603,3952.83056306839,2.9529545307159424,0.0 -67500,1.2572769,2.3916054,,,,,,,,,,,,,, -67600,1.0230286,3.8486514,,,,,,,,,,,,,, -67700,1.1617249,3.5833855,,,,,,,,,,,,,, -67800,1.0050759,4.3716965,,,,,,,,,,,,,, -67900,1.3165085,2.5164187,,,,,,,,,,,,,, -68000,1.1347404,3.1872098,,,,,,,,,,,,,, -68100,1.426859,2.422442,,,,,,,,,,,,,, -68200,1.1207716,3.5252063,,,,,,,,,,,,,, -68300,1.2971604,2.4272656,,,,,,,,,,,,,, -68398,,,0.6966406106948853,1.2503221035003662,0.6096799969673157,1.6365011930465698,50000.0,0.4933000206947326,2.2944016456604004,10000.0,31972.5193977356,35983.74279308319,31972.5193977356,4004.3154785633087,2.992598533630371,0.0 -68400,1.3391916,2.574077,,,,,,,,,,,,,, -68500,1.272144,2.8769279,,,,,,,,,,,,,, -68600,1.2444328,3.1025426,,,,,,,,,,,,,, -68700,1.367389,4.482081,,,,,,,,,,,,,, -68800,1.2267995,2.907293,,,,,,,,,,,,,, -68900,0.99400026,5.0273643,,,,,,,,,,,,,, -69000,1.2521591,2.6100707,,,,,,,,,,,,,, -69100,1.2366586,3.6952803,,,,,,,,,,,,,, -69200,1.4207841,2.4496884,,,,,,,,,,,,,, -69292,,,0.6567773222923279,1.425057291984558,0.6114400029182434,1.6573946475982666,50000.0,0.4886000156402588,2.312219858169556,10000.0,32392.77770280838,36456.78786849976,32392.77770280838,4057.010164737701,3.033708333969116,0.0 -69300,1.3248767,2.4942353,,,,,,,,,,,,,, -69400,1.404803,2.6041613,,,,,,,,,,,,,, -69500,1.2645872,2.4116116,,,,,,,,,,,,,, -69600,1.0577546,4.981084,,,,,,,,,,,,,, -69700,1.2164719,2.5087695,,,,,,,,,,,,,, -69800,1.3141567,2.3340049,,,,,,,,,,,,,, -69900,1.3478042,2.615515,,,,,,,,,,,,,, -70000,1.2031633,2.3988245,,,,,,,,,,,,,, -70100,1.2209107,2.7958648,,,,,,,,,,,,,, -70190,,,0.6670116782188416,1.3967492580413818,0.6153599619865417,1.6382001638412476,50000.0,0.4886000156402588,2.298884153366089,10000.0,32812.85529613495,36928.463305950165,32812.85529613495,4108.516996145248,3.072967052459717,0.0 -70200,1.1102825,3.1689749,,,,,,,,,,,,,, -70300,1.0037186,4.6035438,,,,,,,,,,,,,, -70400,1.4412515,2.447419,,,,,,,,,,,,,, -70500,1.1343157,5.027849,,,,,,,,,,,,,, -70600,1.1514299,2.8114157,,,,,,,,,,,,,, -70700,1.0874618,3.2597911,,,,,,,,,,,,,, -70800,1.3635151,2.4673636,,,,,,,,,,,,,, -70900,1.163344,2.6912546,,,,,,,,,,,,,, -71000,1.1006197,3.2927425,,,,,,,,,,,,,, -71090,,,0.6880077719688416,1.30825936794281,0.6117199659347534,1.6579543352127075,50000.0,0.4946000277996063,2.290399074554444,10000.0,33233.01276636124,37401.97440814972,33233.01276636124,4161.779107093811,3.111844539642334,0.0 -71100,1.0615566,4.531987,,,,,,,,,,,,,, -71200,1.4223922,2.3832152,,,,,,,,,,,,,, -71300,1.4391155,2.592855,,,,,,,,,,,,,, -71400,1.3139466,2.636037,,,,,,,,,,,,,, -71500,1.2896866,2.5630443,,,,,,,,,,,,,, -71600,1.0377265,5.079318,,,,,,,,,,,,,, -71700,1.2024124,2.73024,,,,,,,,,,,,,, -71800,1.224799,2.365704,,,,,,,,,,,,,, -71900,1.2249948,2.7485147,,,,,,,,,,,,,, -71989,,,0.6581054329872131,1.4184592962265017,0.613599956035614,1.642994999885559,50000.0,0.4919000267982483,2.313216209411621,10000.0,33653.27028799057,37872.505085229874,33653.27028799057,4211.959260225296,3.1535627841949463,0.0 -72000,1.2314136,2.4653516,,,,,,,,,,,,,, -72100,1.3185588,2.5081196,,,,,,,,,,,,,, -72200,1.5240445,2.4171305,,,,,,,,,,,,,, -72300,1.3576043,3.1263874,,,,,,,,,,,,,, -72400,1.0478172,4.5666656,,,,,,,,,,,,,, -72500,1.3171778,2.420576,,,,,,,,,,,,,, -72600,1.3799584,2.5676093,,,,,,,,,,,,,, -72700,1.41761,2.3632476,,,,,,,,,,,,,, -72800,1.3429102,3.0295591,,,,,,,,,,,,,, -72888,,,0.6729101538658142,1.3513695001602173,0.6170799732208252,1.6132991313934326,50000.0,0.4941000342369079,2.280538320541382,10000.0,34073.381747722626,38344.63732886314,34073.381747722626,4263.880227088928,3.2022414207458496,0.0 -72900,1.2844633,2.273808,,,,,,,,,,,,,, -73000,1.2223499,2.6742673,,,,,,,,,,,,,, -73100,1.0409702,3.483463,,,,,,,,,,,,,, -73200,1.121612,4.3615375,,,,,,,,,,,,,, -73300,1.0333745,4.9900956,,,,,,,,,,,,,, -73400,1.2351681,2.2936182,,,,,,,,,,,,,, -73500,1.3220997,2.3783383,,,,,,,,,,,,,, -73600,1.2989498,2.6732414,,,,,,,,,,,,,, -73700,1.2325945,2.68816,,,,,,,,,,,,,, -73785,,,0.6938085556030273,1.2425113916397097,0.6222400069236755,1.575509786605835,50000.0,0.5059000253677368,2.2482941150665283,10000.0,34493.59297966957,38817.48063826561,34493.59297966957,4316.417612314224,3.24575424194336,0.0 -73800,1.3198421,2.325193,,,,,,,,,,,,,, -73900,1.5863523,2.377213,,,,,,,,,,,,,, -74000,1.2518927,2.402853,,,,,,,,,,,,,, -74100,1.1743984,2.6312177,,,,,,,,,,,,,, -74200,1.1312567,3.7127323,,,,,,,,,,,,,, -74300,1.2220411,2.9740088,,,,,,,,,,,,,, -74400,1.1654435,2.8648634,,,,,,,,,,,,,, -74500,1.3609012,2.5927522,,,,,,,,,,,,,, -74600,1.2512963,2.5684047,,,,,,,,,,,,,, -74687,,,0.6724804639816284,1.328711986541748,0.6229999661445618,1.5702327489852903,50000.0,0.503600001335144,2.223655939102173,10000.0,34913.6925303936,39289.15676212311,34913.6925303936,4367.896469116211,3.292013645172119,0.0 -74700,1.4564592,2.4504583,,,,,,,,,,,,,, -74800,1.2249857,3.3850574,,,,,,,,,,,,,, -74900,1.3639854,2.36582,,,,,,,,,,,,,, -75000,1.4833566,2.408548,,,,,,,,,,,,,, -75100,1.2458158,2.974958,,,,,,,,,,,,,, -75200,1.432073,2.6354518,,,,,,,,,,,,,, -75300,1.110259,3.3494961,,,,,,,,,,,,,, -75400,1.2258795,4.8407907,,,,,,,,,,,,,, -75500,1.1571577,2.8779333,,,,,,,,,,,,,, -75586,,,0.6808202862739563,1.340014934539795,0.6244800090789795,1.6081055402755735,50000.0,0.5058000087738037,2.2576117515563965,10000.0,35333.97060585022,39761.00652861595,35333.97060585022,4419.369015693665,3.339527130126953,0.0 -75600,1.2086974,4.261097,,,,,,,,,,,,,, -75700,1.3060356,2.413331,,,,,,,,,,,,,, -75800,1.4182739,2.321113,,,,,,,,,,,,,, -75900,1.144011,3.5695057,,,,,,,,,,,,,, -76000,1.3332616,2.5339274,,,,,,,,,,,,,, -76100,1.557428,2.3528526,,,,,,,,,,,,,, -76200,1.3072559,2.3953547,,,,,,,,,,,,,, -76300,1.1305212,4.3241596,,,,,,,,,,,,,, -76400,1.4498401,2.4390879,,,,,,,,,,,,,, -76484,,,0.6885741949081421,1.2858747243881226,0.6215599775314331,1.5913811922073364,50000.0,0.5015000104904175,2.227324247360229,10000.0,35753.88782739639,40235.9482319355,35753.88782739639,4474.300911664963,3.38096022605896,0.0 -76500,1.3615636,2.3896651,,,,,,,,,,,,,, -76600,1.2263685,3.3742778,,,,,,,,,,,,,, -76700,1.4101585,2.4380112,,,,,,,,,,,,,, -76800,1.3534287,2.9929051,,,,,,,,,,,,,, -76900,1.227893,4.094976,,,,,,,,,,,,,, -77000,1.540312,2.5419645,,,,,,,,,,,,,, -77100,1.213928,2.942191,,,,,,,,,,,,,, -77200,1.304684,2.4878078,,,,,,,,,,,,,, -77300,1.3723897,2.7488809,,,,,,,,,,,,,, -77378,,,0.6756640672683716,1.311980962753296,0.6242200136184692,1.5544767379760742,50000.0,0.5027000308036804,2.2037034034729004,10000.0,36174.27370333672,40708.220797777176,36174.27370333672,4526.091993093491,3.42548942565918,0.0 -77400,1.5347865,4.9852242,,,,,,,,,,,,,, -77500,1.3639605,2.5196877,,,,,,,,,,,,,, -77600,1.2658036,2.7893248,,,,,,,,,,,,,, -77700,1.3575487,2.5999079,,,,,,,,,,,,,, -77800,1.1915444,4.206351,,,,,,,,,,,,,, -77900,1.2435457,4.8331094,,,,,,,,,,,,,, -78000,1.6198419,2.234765,,,,,,,,,,,,,, -78100,1.348086,3.635549,,,,,,,,,,,,,, -78200,1.3448786,2.3826094,,,,,,,,,,,,,, -78274,,,0.6756445169448853,1.4200063943862915,0.6239799857139587,1.6556744575500488,50000.0,0.5020000338554382,2.3106143474578857,10000.0,36594.59930849075,41183.31033778191,36594.59930849075,4580.758625507355,3.471120595932007,0.0 -78300,1.3575381,2.2800164,,,,,,,,,,,,,, -78400,1.3272343,2.3552444,,,,,,,,,,,,,, -78500,1.4010973,2.3865838,,,,,,,,,,,,,, -78600,1.3693973,2.3054357,,,,,,,,,,,,,, -78700,1.3619976,2.5231292,,,,,,,,,,,,,, -78800,1.3882291,2.3870022,,,,,,,,,,,,,, -78900,1.3389045,4.9543614,,,,,,,,,,,,,, -79000,1.1443926,5.0197306,,,,,,,,,,,,,, -79100,1.2787042,2.4451342,,,,,,,,,,,,,, -79172,,,0.6913671493530273,1.2846978902816772,0.6258999705314636,1.5905046463012695,50000.0,0.4998000264167785,2.254966974258423,10000.0,37014.5864136219,41655.48101997376,37014.5864136219,4632.851358652115,3.5113136768341064,0.0 -79200,1.34804,2.5938318,,,,,,,,,,,,,, -79300,1.0149937,4.320851,,,,,,,,,,,,,, -79400,1.4060702,2.3679793,,,,,,,,,,,,,, -79500,1.2188728,5.024477,,,,,,,,,,,,,, -79600,1.0947958,4.6916924,,,,,,,,,,,,,, -79700,1.4779037,4.706463,,,,,,,,,,,,,, -79800,1.4066012,2.597364,,,,,,,,,,,,,, -79900,1.4645114,2.5774727,,,,,,,,,,,,,, -80000,1.3942944,2.5565743,,,,,,,,,,,,,, -80067,,,0.6728710532188416,1.3366504907608032,0.6302599906921387,1.55534029006958,50000.0,0.5051000118255615,2.206871032714844,10000.0,37434.66622328758,42128.71227145195,37434.66622328758,4685.910947561264,3.5520379543304443,0.0 -80100,1.0779651,3.5616179,,,,,,,,,,,,,, -80200,1.3358135,2.3899386,,,,,,,,,,,,,, -80300,1.3744364,2.3582954,,,,,,,,,,,,,, -80400,1.276216,2.8430126,,,,,,,,,,,,,, -80500,1.1576171,3.9398034,,,,,,,,,,,,,, -80600,1.3682418,2.349899,,,,,,,,,,,,,, -80700,1.4038633,2.318812,,,,,,,,,,,,,, -80800,1.0838852,4.403186,,,,,,,,,,,,,, -80900,1.4242781,2.4623694,,,,,,,,,,,,,, -80967,,,0.6885156035423279,1.2635071277618408,0.634660005569458,1.520119309425354,50000.0,0.5120000243186951,2.171586751937866,10000.0,37854.888414382935,42600.99657249451,37854.888414382935,4737.878388643265,3.595428943634033,0.0 -81000,1.3028316,2.7867377,,,,,,,,,,,,,, -81100,1.3472736,2.8855057,,,,,,,,,,,,,, -81200,1.0931996,4.2599635,,,,,,,,,,,,,, -81300,1.5380435,2.3535554,,,,,,,,,,,,,, -81400,1.6131588,2.299087,,,,,,,,,,,,,, -81500,1.1284627,3.0596845,,,,,,,,,,,,,, -81600,1.3328403,2.3687162,,,,,,,,,,,,,, -81700,1.2387373,3.3570669,,,,,,,,,,,,,, -81800,1.3632905,2.312234,,,,,,,,,,,,,, -81862,,,0.6949023008346558,1.2488195896148682,0.6343599557876587,1.5292257070541382,50000.0,0.5149000287055969,2.177933931350708,10000.0,38275.072088718414,43072.96156978607,38275.072088718414,4789.569073200226,3.63571047782898,0.0 -81900,1.136229,4.592232,,,,,,,,,,,,,, -82000,1.4969162,2.226476,,,,,,,,,,,,,, -82100,1.3833231,2.4419847,,,,,,,,,,,,,, -82200,1.4419664,2.4144738,,,,,,,,,,,,,, -82300,1.2586154,3.5796332,,,,,,,,,,,,,, -82400,1.546952,2.4301095,,,,,,,,,,,,,, -82500,1.1791073,4.0201645,,,,,,,,,,,,,, -82600,1.3184052,2.3136086,,,,,,,,,,,,,, -82700,1.3958236,3.0690317,,,,,,,,,,,,,, -82760,,,0.6834570169448853,1.3045527935028076,0.6303600072860718,1.5471371412277222,50000.0,0.5074000358581543,2.211868524551392,10000.0,38695.33440542221,43546.44796872139,38695.33440542221,4842.699766159058,3.677600860595703,0.0 -82800,1.1302702,4.937075,,,,,,,,,,,,,, -82900,1.1326872,2.7943482,,,,,,,,,,,,,, -83000,1.1986612,3.9620988,,,,,,,,,,,,,, -83100,1.3053788,2.4311516,,,,,,,,,,,,,, -83200,1.2700982,2.7318254,,,,,,,,,,,,,, -83300,1.2845839,2.1931796,,,,,,,,,,,,,, -83400,1.3243203,2.450139,,,,,,,,,,,,,, -83500,1.225453,3.7266595,,,,,,,,,,,,,, -83600,1.3790561,2.352174,,,,,,,,,,,,,, -83655,,,0.6850390434265137,1.279969573020935,0.6342200040817261,1.5213388204574585,50000.0,0.5121999979019165,2.1926448345184326,10000.0,39115.676505327225,44018.31398630142,39115.676505327225,4894.128688812256,3.7216544151306152,0.0 -83700,1.338915,2.3974204,,,,,,,,,,,,,, -83800,1.2609593,4.4450445,,,,,,,,,,,,,, -83900,1.2548754,3.5227716,,,,,,,,,,,,,, -84000,1.4007502,2.637955,,,,,,,,,,,,,, -84100,1.1882427,4.852397,,,,,,,,,,,,,, -84200,1.3620195,2.2940621,,,,,,,,,,,,,, -84300,1.2571061,4.980896,,,,,,,,,,,,,, -84400,1.1873935,3.1590302,,,,,,,,,,,,,, -84500,1.2565914,4.3191004,,,,,,,,,,,,,, -84551,,,0.703125,1.225046157836914,0.6385599970817566,1.5180543661117554,50000.0,0.5159000158309937,2.160423755645752,10000.0,39535.89488720894,44490.48594856262,39535.89488720894,4945.990887403488,3.761420726776123,0.0 -84600,1.2582518,2.601261,,,,,,,,,,,,,, -84700,1.3283615,4.5515623,,,,,,,,,,,,,, -84800,1.4315373,2.4269323,,,,,,,,,,,,,, -84900,1.3252989,2.5665798,,,,,,,,,,,,,, -85000,1.2915479,2.4219222,,,,,,,,,,,,,, -85100,1.3757727,2.4156222,,,,,,,,,,,,,, -85200,1.2243265,3.1029038,,,,,,,,,,,,,, -85300,1.098171,4.293684,,,,,,,,,,,,,, -85400,1.3711957,3.1430616,,,,,,,,,,,,,, -85449,,,0.683300793170929,1.303541898727417,0.6320399641990662,1.532470464706421,50000.0,0.5161000490188599,2.173837900161743,10000.0,39956.06528735161,44961.87512779236,39956.06528735161,4997.108908653259,3.810989618301392,0.0 -85500,1.4909214,2.306964,,,,,,,,,,,,,, -85600,1.4568797,2.1544282,,,,,,,,,,,,,, -85700,1.4611706,2.322933,,,,,,,,,,,,,, -85800,1.2548249,3.0866656,,,,,,,,,,,,,, -85900,1.2342374,3.835744,,,,,,,,,,,,,, -86000,1.6054984,2.4367204,,,,,,,,,,,,,, -86100,1.4397833,2.3083475,,,,,,,,,,,,,, -86200,1.1817253,3.2586799,,,,,,,,,,,,,, -86300,1.154232,4.4929566,,,,,,,,,,,,,, -86344,,,0.6930859088897705,1.2610045671463013,0.6387799978256226,1.5271767377853394,50000.0,0.5146999955177307,2.1729276180267334,10000.0,40376.11482572556,45433.95785403252,40376.11482572556,5049.05076956749,3.851077079772949,0.0 -86400,1.4259684,2.421177,,,,,,,,,,,,,, -86500,1.1218067,3.14945,,,,,,,,,,,,,, -86600,1.3804475,2.188243,,,,,,,,,,,,,, -86700,1.1673099,3.2988153,,,,,,,,,,,,,, -86800,1.4399282,2.3679595,,,,,,,,,,,,,, -86900,1.53619,2.406728,,,,,,,,,,,,,, -87000,1.4652193,2.2692258,,,,,,,,,,,,,, -87100,1.4820296,2.2512019,,,,,,,,,,,,,, -87200,1.2841038,2.4730911,,,,,,,,,,,,,, -87241,,,0.6998632550239563,1.2377312183380127,0.638700008392334,1.510784149169922,50000.0,0.5197000503540039,2.1511378288269043,10000.0,40796.28626155853,45905.69135284424,40796.28626155853,5100.5192720890045,3.894582748413086,0.0 -87300,1.6740806,2.4450922,,,,,,,,,,,,,, -87400,1.3003051,2.5483792,,,,,,,,,,,,,, -87500,1.5115952,2.264469,,,,,,,,,,,,,, -87600,1.3683752,2.2724042,,,,,,,,,,,,,, -87700,1.5094339,2.5179467,,,,,,,,,,,,,, -87800,1.4564971,2.3506122,,,,,,,,,,,,,, -87900,1.2864474,3.2504332,,,,,,,,,,,,,, -88000,1.2884628,2.5912693,,,,,,,,,,,,,, -88100,1.283788,2.2692678,,,,,,,,,,,,,, -88126,,,0.6940820217132568,1.2332541942596436,0.637179970741272,1.494731068611145,50000.0,0.5146000385284424,2.151771545410156,10000.0,41216.240828990936,46378.955389261246,41216.240828990936,5153.735585212708,3.9385111331939697,0.0 -88200,1.294837,3.791469,,,,,,,,,,,,,, -88300,1.2660302,3.3933792,,,,,,,,,,,,,, -88400,1.4287407,2.3670945,,,,,,,,,,,,,, -88500,1.2701833,4.6956286,,,,,,,,,,,,,, -88600,1.4016914,2.2866971,,,,,,,,,,,,,, -88700,1.4450667,2.4421532,,,,,,,,,,,,,, -88800,1.438025,2.0627096,,,,,,,,,,,,,, -88900,1.438249,2.5532181,,,,,,,,,,,,,, -88977,,,0.6952733993530273,1.2344191074371338,0.6416999697685242,1.4802577495574951,50000.0,0.5200999975204468,2.1336536407470703,10000.0,41636.33084964752,46850.67201781273,41636.33084964752,5205.26829957962,3.983673334121704,0.0 -89000,1.449242,2.698514,,,,,,,,,,,,,, -89100,1.2358272,2.8504748,,,,,,,,,,,,,, -89200,1.4492748,2.4396372,,,,,,,,,,,,,, -89300,1.3977432,2.2298734,,,,,,,,,,,,,, -89400,1.2968036,2.7259352,,,,,,,,,,,,,, -89500,1.5780663,2.304777,,,,,,,,,,,,,, -89600,1.2958039,4.9213614,,,,,,,,,,,,,, -89700,1.5539238,2.4049509,,,,,,,,,,,,,, -89800,1.3029546,2.8974118,,,,,,,,,,,,,, -89871,,,0.704296886920929,1.2123621702194214,0.6403399705886841,1.5044403076171875,50000.0,0.516700029373169,2.169416666030884,10000.0,42056.24417757988,47323.9805533886,42056.24417757988,5258.57132768631,4.025211334228516,0.0 -89900,1.2920501,2.5323305,,,,,,,,,,,,,, -90000,1.3933433,2.5583258,,,,,,,,,,,,,, -90100,1.6776057,2.3243585,,,,,,,,,,,,,, -90200,1.6466582,2.2319748,,,,,,,,,,,,,, -90300,1.515081,2.2895226,,,,,,,,,,,,,, -90400,1.4662739,2.4252067,,,,,,,,,,,,,, -90500,1.5830874,2.3406236,,,,,,,,,,,,,, -90600,1.543783,2.3276062,,,,,,,,,,,,,, -90700,1.3800147,2.1847212,,,,,,,,,,,,,, -90771,,,0.6926171779632568,1.247671127319336,0.6415799856185913,1.4853407144546509,50000.0,0.5189000368118286,2.149666786193848,10000.0,42476.40330886841,47795.89140582085,42476.40330886841,5310.227010965347,4.070418834686279,0.0 -90800,1.191965,4.0534673,,,,,,,,,,,,,, -90900,1.443439,2.2759535,,,,,,,,,,,,,, -91000,1.4761374,3.091148,,,,,,,,,,,,,, -91100,1.2530072,4.275415,,,,,,,,,,,,,, -91200,1.2516986,4.1460238,,,,,,,,,,,,,, -91300,1.6478356,2.2608669,,,,,,,,,,,,,, -91400,1.1945009,3.0103722,,,,,,,,,,,,,, -91500,1.4269651,2.431312,,,,,,,,,,,,,, -91600,1.4197283,2.2196927,,,,,,,,,,,,,, -91663,,,0.7013866901397705,1.2053130865097046,0.645639955997467,1.4752509593963623,50000.0,0.5196000337600708,2.1370153427124023,10000.0,42896.397490262985,48267.83230733872,42896.397490262985,5362.071304321289,4.122549295425415,0.0 -91700,1.389126,2.2462127,,,,,,,,,,,,,, -91800,1.3562055,4.39092,,,,,,,,,,,,,, -91900,1.4728283,3.2617435,,,,,,,,,,,,,, -92000,1.5071455,2.1792183,,,,,,,,,,,,,, -92100,1.3302977,3.1588216,,,,,,,,,,,,,, -92200,1.1749192,3.5590708,,,,,,,,,,,,,, -92300,1.3583957,2.5695393,,,,,,,,,,,,,, -92400,1.3133743,2.5693634,,,,,,,,,,,,,, -92500,1.454287,2.2228687,,,,,,,,,,,,,, -92557,,,0.7112109065055847,1.1575263738632202,0.6488800048828125,1.4489514827728271,50000.0,0.52510005235672,2.1200804710388184,10000.0,43316.60587668419,48739.50898981094,43316.60587668419,5413.447510957718,4.164551734924316,0.0 -92600,1.3978637,2.2882586,,,,,,,,,,,,,, -92700,1.5255762,2.2987394,,,,,,,,,,,,,, -92800,1.3522587,4.380617,,,,,,,,,,,,,, -92900,1.1775633,3.306349,,,,,,,,,,,,,, -93000,1.4214041,2.2005708,,,,,,,,,,,,,, -93100,1.3219763,2.3213158,,,,,,,,,,,,,, -93200,1.1844683,4.779039,,,,,,,,,,,,,, -93300,1.3989698,4.2653527,,,,,,,,,,,,,, -93400,1.4229238,3.0361853,,,,,,,,,,,,,, -93455,,,0.6953515410423279,1.2690379619598389,0.6404399871826172,1.5253348350524902,50000.0,0.5205000042915344,2.171663999557495,10000.0,43736.8600165844,49212.40234661102,43736.8600165844,5465.9878051280975,4.211713075637817,0.0 -93500,1.2111977,3.6600819,,,,,,,,,,,,,, -93600,1.2368436,4.701106,,,,,,,,,,,,,, -93700,1.5054691,2.1696873,,,,,,,,,,,,,, -93800,1.5132071,2.2337222,,,,,,,,,,,,,, -93900,1.3951905,2.4338264,,,,,,,,,,,,,, -94000,1.4166282,2.2402258,,,,,,,,,,,,,, -94100,1.1860231,3.6932042,,,,,,,,,,,,,, -94200,1.2526515,4.8382483,,,,,,,,,,,,,, -94300,1.5023781,2.3133993,,,,,,,,,,,,,, -94353,,,0.7071093320846558,1.196900486946106,0.6487599611282349,1.4550457000732422,50000.0,0.526900053024292,2.0961992740631104,10000.0,44157.11583042145,49687.024679899216,44157.11583042145,5520.260136604309,4.254514455795288,0.0 -94400,1.1357799,4.137462,,,,,,,,,,,,,, -94500,1.4391098,2.145118,,,,,,,,,,,,,, -94600,1.3290069,2.6894884,,,,,,,,,,,,,, -94700,1.3673215,4.553104,,,,,,,,,,,,,, -94800,1.4728621,4.6784983,,,,,,,,,,,,,, -94900,1.4866618,2.093617,,,,,,,,,,,,,, -95000,1.2757547,4.4468884,,,,,,,,,,,,,, -95100,1.4828141,2.3396358,,,,,,,,,,,,,, -95200,1.1788508,3.1365843,,,,,,,,,,,,,, -95251,,,0.7099023461341858,1.2014402151107788,0.6468999981880188,1.4840155839920044,50000.0,0.5276000499725342,2.12835431098938,10000.0,44577.40431547165,50159.85675024986,44577.40431547165,5572.704406738281,4.302710056304932,0.0 -95300,1.3621502,2.5529714,,,,,,,,,,,,,, -95400,1.493671,2.1935754,,,,,,,,,,,,,, -95500,1.4918885,2.219125,,,,,,,,,,,,,, -95600,1.2700881,4.5577545,,,,,,,,,,,,,, -95700,1.5235196,2.2081428,,,,,,,,,,,,,, -95800,1.6837647,2.190932,,,,,,,,,,,,,, -95900,1.3172861,3.2503579,,,,,,,,,,,,,, -96000,1.5844897,2.289779,,,,,,,,,,,,,, -96100,1.2011354,4.2017503,,,,,,,,,,,,,, -96149,,,0.7105468511581421,1.17483651638031,0.6487999558448792,1.4631506204605105,50000.0,0.525700032711029,2.12905216217041,10000.0,44997.624660253525,50631.88677740097,44997.624660253525,5624.41569018364,4.350196361541748,0.0 -96200,1.7038789,2.2521644,,,,,,,,,,,,,, -96300,1.4452405,2.2993388,,,,,,,,,,,,,, -96400,1.4692612,2.2506135,,,,,,,,,,,,,, -96500,1.4022648,4.55292,,,,,,,,,,,,,, -96600,1.4320891,2.7303307,,,,,,,,,,,,,, -96700,1.3862051,3.2567885,,,,,,,,,,,,,, -96800,1.3292787,3.2699654,,,,,,,,,,,,,, -96900,1.4717649,2.2334905,,,,,,,,,,,,,, -97000,1.3557365,4.543356,,,,,,,,,,,,,, -97044,,,0.7084765434265137,1.200149655342102,0.650879979133606,1.4600813388824463,50000.0,0.5293000340461731,2.103090524673462,10000.0,45417.73803925514,51103.32812476158,45417.73803925514,5675.649765968323,4.3937060832977295,0.0 -97100,1.4184866,2.6900682,,,,,,,,,,,,,, -97200,1.4304175,2.59463,,,,,,,,,,,,,, -97300,1.3757952,2.5747898,,,,,,,,,,,,,, -97400,1.6577059,2.3839695,,,,,,,,,,,,,, -97500,1.431132,4.5277905,,,,,,,,,,,,,, -97600,1.2880001,4.7093396,,,,,,,,,,,,,, -97700,1.2812082,3.4777224,,,,,,,,,,,,,, -97800,1.3332537,2.8160286,,,,,,,,,,,,,, -97900,1.4953009,2.219075,,,,,,,,,,,,,, -97942,,,0.7183789014816284,1.144914627075195,0.6541199684143066,1.44417405128479,50000.0,0.5301000475883484,2.0910539627075195,10000.0,45838.01228451729,51577.29973363876,45838.01228451729,5729.255409002304,4.434265851974487,0.0 -98000,1.4967964,2.547489,,,,,,,,,,,,,, -98100,1.2072777,4.505828,,,,,,,,,,,,,, -98200,1.6660086,2.2661676,,,,,,,,,,,,,, -98300,1.5352838,2.209839,,,,,,,,,,,,,, -98400,1.5449213,2.271224,,,,,,,,,,,,,, -98500,1.2740103,2.6062887,,,,,,,,,,,,,, -98600,1.241111,4.196491,,,,,,,,,,,,,, -98700,1.620148,2.2510383,,,,,,,,,,,,,, -98800,1.3921851,2.7250924,,,,,,,,,,,,,, -98841,,,0.7341406345367432,1.0801435708999634,0.6582199931144714,1.4175523519515991,50000.0,0.530500054359436,2.081018686294556,10000.0,46258.07188653946,52049.81680226326,46258.07188653946,5781.615537166596,4.480335235595703,0.0 -98900,1.4887129,2.2765355,,,,,,,,,,,,,, -99000,1.6036739,2.156227,,,,,,,,,,,,,, -99100,1.2544696,4.7649117,,,,,,,,,,,,,, -99200,1.6601689,2.206402,,,,,,,,,,,,,, -99300,1.4486176,2.164248,,,,,,,,,,,,,, -99400,1.4722453,2.6759439,,,,,,,,,,,,,, -99500,1.5210398,2.146715,,,,,,,,,,,,,, -99600,1.4523853,2.3035011,,,,,,,,,,,,,, -99700,1.4381883,2.2606087,,,,,,,,,,,,,, -99736,,,0.7138085961341858,1.1494009494781494,0.6567999720573425,1.4137784242630005,50000.0,0.541100025177002,2.062328338623047,10000.0,46678.317527771,52521.2454571724,46678.317527771,5832.705719232559,4.522157669067383,0.0 -99800,1.5143161,2.1370883,,,,,,,,,,,,,, -99900,1.413186,2.1179304,,,,,,,,,,,,,, -100000,1.3255879,4.052937,,,,,,,,,,,,,, -100100,1.732869,2.281477,,,,,,,,,,,,,, -100200,1.5758334,2.1067276,,,,,,,,,,,,,, -100300,1.4998896,2.325104,,,,,,,,,,,,,, -100400,1.587652,2.2255828,,,,,,,,,,,,,, -100500,1.4608797,3.0019453,,,,,,,,,,,,,, -100600,1.6918684,2.1573336,,,,,,,,,,,,,, -100630,,,0.71728515625,1.1666183471679688,0.6571399569511414,1.4454323053359983,50000.0,0.5279000401496887,2.102136850357056,10000.0,47098.33820104599,52993.11081790924,47098.33820104599,5884.45601606369,4.5659730434417725,0.0 -100700,1.513865,2.2995777,,,,,,,,,,,,,, -100800,1.4391718,4.600264,,,,,,,,,,,,,, -100900,1.7103053,2.2414596,,,,,,,,,,,,,, -101000,1.4838562,2.124748,,,,,,,,,,,,,, -101100,1.3270179,2.9083257,,,,,,,,,,,,,, -101200,1.3404385,2.6004372,,,,,,,,,,,,,, -101300,1.2650158,3.9295635,,,,,,,,,,,,,, -101400,1.8524176,2.220149,,,,,,,,,,,,,, -101500,1.6110373,2.2884736,,,,,,,,,,,,,, -101524,,,0.7432226538658142,1.042794108390808,0.6620199680328369,1.3963292837142944,50000.0,0.5356000065803528,2.04892635345459,10000.0,47518.43212723732,53467.515218257904,47518.43212723732,5938.6716158390045,4.610373020172119,0.0 -101600,1.4954617,2.320492,,,,,,,,,,,,,, -101700,1.4338783,2.094051,,,,,,,,,,,,,, -101800,1.686682,2.4797165,,,,,,,,,,,,,, -101900,1.3813769,2.3428855,,,,,,,,,,,,,, -102000,1.6698629,2.2121785,,,,,,,,,,,,,, -102100,1.3735508,3.9231997,,,,,,,,,,,,,, -102200,1.4280267,3.007204,,,,,,,,,,,,,, -102300,1.5165986,2.1731055,,,,,,,,,,,,,, -102400,1.3088751,3.8472662,,,,,,,,,,,,,, -102418,,,0.7162694931030273,1.132684588432312,0.6610999703407288,1.4008057117462158,50000.0,0.5392000079154968,2.046409368515014,10000.0,47938.56771445274,53939.10957980156,47938.56771445274,5990.034274101257,4.655975341796875,0.0 -102500,1.6360177,2.3163762,,,,,,,,,,,,,, -102600,1.6663581,2.193763,,,,,,,,,,,,,, -102700,1.3205029,3.0685344,,,,,,,,,,,,,, -102800,1.7532578,2.2374945,,,,,,,,,,,,,, -102900,1.5344021,2.2375913,,,,,,,,,,,,,, -103000,1.6707419,2.0540514,,,,,,,,,,,,,, -103100,1.6804549,2.1345174,,,,,,,,,,,,,, -103200,1.3290172,3.762699,,,,,,,,,,,,,, -103300,1.1950871,3.9535003,,,,,,,,,,,,,, -103313,,,0.7245898246765137,1.1347663402557373,0.6670799851417542,1.4015071392059326,50000.0,0.5410000085830688,2.047637939453125,10000.0,48358.572227716446,54411.3479924202,48358.572227716446,6042.1712164878845,4.702473402023315,0.0 -103400,1.4041618,3.766459,,,,,,,,,,,,,, -103500,1.5118358,3.9035423,,,,,,,,,,,,,, -103600,1.6119688,2.244184,,,,,,,,,,,,,, -103700,1.5461396,2.6853805,,,,,,,,,,,,,, -103800,1.6135147,2.2032912,,,,,,,,,,,,,, -103900,1.3840114,2.983871,,,,,,,,,,,,,, -104000,1.6355528,2.1368446,,,,,,,,,,,,,, -104100,1.5705802,2.5689387,,,,,,,,,,,,,, -104200,1.7062157,2.0836067,,,,,,,,,,,,,, -104208,,,0.745898425579071,1.0389422178268433,0.6654399633407593,1.3964356184005735,50000.0,0.5436000227928162,2.037595748901367,10000.0,48778.65916538239,54884.48488640785,48778.65916538239,6095.1235938072205,4.748953342437744,0.0 -104300,1.5637596,4.631304,,,,,,,,,,,,,, -104400,1.5333604,2.2332582,,,,,,,,,,,,,, -104500,1.5558393,2.2292607,,,,,,,,,,,,,, -104600,1.5360814,2.175593,,,,,,,,,,,,,, -104700,1.6384063,2.1092043,,,,,,,,,,,,,, -104800,1.6829,2.4003258,,,,,,,,,,,,,, -104900,1.4272715,2.0620406,,,,,,,,,,,,,, -105000,1.2854682,3.618236,,,,,,,,,,,,,, -105100,1.4623983,2.1945076,,,,,,,,,,,,,, -105104,,,0.7211328148841858,1.13435959815979,0.6629399657249451,1.3976269960403442,50000.0,0.5414000153541565,2.042243719100952,10000.0,49199.02674174309,55356.00403022766,49199.02674174309,6146.176388025284,4.796812057495117,0.0 -105200,1.4368035,2.1712692,,,,,,,,,,,,,, -105300,1.5221114,2.0964634,,,,,,,,,,,,,, -105400,1.5028929,2.0992572,,,,,,,,,,,,,, -105500,1.4284742,2.334598,,,,,,,,,,,,,, -105600,1.4944797,2.0794556,,,,,,,,,,,,,, -105700,1.5244757,2.1466627,,,,,,,,,,,,,, -105800,1.5528425,2.1176171,,,,,,,,,,,,,, -105900,1.7748893,2.1960216,,,,,,,,,,,,,, -106000,1.4313713,3.1413574,,,,,,,,,,,,,, -106001,,,0.7251952886581421,1.1188418865203855,0.6670399904251099,1.396875023841858,50000.0,0.542900025844574,2.0384507179260254,10000.0,49619.57560873032,55830.23876786232,49619.57560873032,6199.765183925629,4.843739748001099,0.0 -106100,1.6202052,2.271338,,,,,,,,,,,,,, -106200,1.5323763,2.1662016,,,,,,,,,,,,,, -106300,1.4346919,4.4511566,,,,,,,,,,,,,, -106400,1.5392228,2.0476565,,,,,,,,,,,,,, -106500,1.5934432,2.094993,,,,,,,,,,,,,, -106600,1.6500964,2.2236042,,,,,,,,,,,,,, -106700,1.941536,2.047473,,,,,,,,,,,,,, -106800,1.4444854,4.3828344,,,,,,,,,,,,,, -106900,1.5250609,2.3612542,,,,,,,,,,,,,, -106902,,,0.7347265481948853,1.0637056827545166,0.6632599830627441,1.3980926275253296,50000.0,0.539400041103363,2.044331312179565,10000.0,50039.59346675873,56303.84405493736,50039.59346675873,6253.255183458328,4.89024019241333,0.0 -107000,1.3275324,2.4822257,,,,,,,,,,,,,, -107100,1.3660948,4.127262,,,,,,,,,,,,,, -107200,1.3529505,2.7626483,,,,,,,,,,,,,, -107300,1.5095446,2.5063536,,,,,,,,,,,,,, -107400,1.589193,2.1644795,,,,,,,,,,,,,, -107500,1.4146276,4.062603,,,,,,,,,,,,,, -107600,1.7625705,2.1338944,,,,,,,,,,,,,, -107700,1.5335255,1.972385,,,,,,,,,,,,,, -107798,,,0.7272851467132568,1.1355656385421753,0.6676599979400635,1.4018704891204834,50000.0,0.5404000282287598,2.063552141189575,10000.0,50459.5941298008,56777.11013770104,50459.5941298008,6306.424833536148,4.935125350952148,0.0 -107800,1.4254794,2.4309607,,,,,,,,,,,,,, -107900,1.754521,2.148817,,,,,,,,,,,,,, -108000,1.4473207,4.644761,,,,,,,,,,,,,, -108100,1.66886,2.1373992,,,,,,,,,,,,,, -108200,1.441665,3.7357152,,,,,,,,,,,,,, -108300,1.8617163,2.1225193,,,,,,,,,,,,,, -108400,1.453162,4.41166,,,,,,,,,,,,,, -108500,1.3414611,3.041928,,,,,,,,,,,,,, -108600,1.6690754,2.2564597,,,,,,,,,,,,,, -108692,,,0.7261914014816284,1.173351764678955,0.6655200123786926,1.454077959060669,50000.0,0.5416000485420227,2.09930157661438,10000.0,50879.51830530167,57248.41434073448,50879.51830530167,6357.710418224335,4.979193210601807,0.0 -108700,2.0458195,2.1697896,,,,,,,,,,,,,, -108800,1.4826341,3.563488,,,,,,,,,,,,,, -108900,1.7329463,2.1391115,,,,,,,,,,,,,, -109000,1.6113367,2.1451893,,,,,,,,,,,,,, -109100,1.3681571,4.66466,,,,,,,,,,,,,, -109200,1.3948838,3.7151327,,,,,,,,,,,,,, -109300,1.6638771,2.1456482,,,,,,,,,,,,,, -109400,1.6491852,2.0131097,,,,,,,,,,,,,, -109500,1.5605985,2.1599555,,,,,,,,,,,,,, -109582,,,0.7465234398841858,1.033710479736328,0.6727199554443359,1.3676403760910034,50000.0,0.5498000383377075,2.0147860050201416,10000.0,51299.49491405487,57719.42250370979,51299.49491405487,6408.644496679306,5.026316642761231,0.0 -109600,1.3571761,3.4446213,,,,,,,,,,,,,, -109700,1.6621398,2.448382,,,,,,,,,,,,,, -109800,1.3305624,4.145012,,,,,,,,,,,,,, -109900,1.6325903,2.081753,,,,,,,,,,,,,, -110000,1.7112329,2.111032,,,,,,,,,,,,,, -110100,1.6720115,2.1669989,,,,,,,,,,,,,, -110200,1.4119693,4.3225837,,,,,,,,,,,,,, -110300,1.4496742,3.087037,,,,,,,,,,,,,, -110400,1.4778498,4.148141,,,,,,,,,,,,,, -110472,,,0.7326952815055847,1.1184043884277344,0.6693399548530579,1.3929319381713867,50000.0,0.54830002784729,2.028778553009033,10000.0,51719.80661559105,58192.70450210571,51719.80661559105,6461.516880512238,5.074175596237183,0.0 -110500,1.9003748,2.144356,,,,,,,,,,,,,, -110600,1.5187936,4.1331644,,,,,,,,,,,,,, -110700,1.478601,2.788428,,,,,,,,,,,,,, -110800,1.5841473,4.1944118,,,,,,,,,,,,,, -110900,1.4785854,3.8897815,,,,,,,,,,,,,, -111000,1.5069667,2.23702,,,,,,,,,,,,,, -111100,1.5615047,2.9988317,,,,,,,,,,,,,, -111200,1.4511762,2.9432704,,,,,,,,,,,,,, -111300,1.5026594,3.0392494,,,,,,,,,,,,,, -111369,,,0.7405859231948853,1.0059020519256592,0.678380012512207,1.3027551174163818,50000.0,0.5551000237464905,1.9627418518066408,10000.0,52140.07601737976,58666.6617166996,52140.07601737976,6515.1032173633575,5.12527322769165,0.0 -111400,1.4138672,4.561894,,,,,,,,,,,,,, -111500,1.607693,2.0917623,,,,,,,,,,,,,, -111600,1.5314301,3.000722,,,,,,,,,,,,,, -111700,1.6601446,2.2603168,,,,,,,,,,,,,, -111800,1.5382208,3.8499327,,,,,,,,,,,,,, -111900,1.6315258,2.0801172,,,,,,,,,,,,,, -112000,1.6394886,1.973437,,,,,,,,,,,,,, -112100,2.0572603,2.0679166,,,,,,,,,,,,,, -112200,1.7498724,4.7112503,,,,,,,,,,,,,, -112264,,,0.7454296946525574,0.994961440563202,0.677619993686676,1.3128902912139893,50000.0,0.5504000186920166,1.9710676670074463,10000.0,52560.04998207092,59139.74605703354,52560.04998207092,6568.118919849396,5.169904947280884,0.0 -112300,1.5280005,3.2270422,,,,,,,,,,,,,, -112400,1.6705885,2.30494,,,,,,,,,,,,,, -112500,1.6655406,2.0476089,,,,,,,,,,,,,, -112600,1.6060985,2.7327461,,,,,,,,,,,,,, -112700,1.6688049,2.0648713,,,,,,,,,,,,,, -112800,1.4699349,4.290726,,,,,,,,,,,,,, -112900,1.6401575,1.9822474,,,,,,,,,,,,,, -113000,1.6105525,2.2491264,,,,,,,,,,,,,, -113100,1.619636,1.9875101,,,,,,,,,,,,,, -113154,,,0.7370312213897705,1.0629130601882937,0.674340009689331,1.3487356901168823,50000.0,0.5530000329017639,1.994787216186524,10000.0,52980.3245677948,59613.3454978466,52980.3245677948,6621.345211267471,5.218159914016724,0.0 -113200,1.5049887,2.2812881,,,,,,,,,,,,,, -113300,1.8638947,2.090766,,,,,,,,,,,,,, -113400,1.8171507,1.9156829,,,,,,,,,,,,,, -113500,1.800439,2.0104613,,,,,,,,,,,,,, -113600,1.674992,2.201751,,,,,,,,,,,,,, -113700,1.6630431,2.140978,,,,,,,,,,,,,, -113800,1.6380975,2.7570124,,,,,,,,,,,,,, -113900,1.6235014,2.0586147,,,,,,,,,,,,,, -114000,1.4098935,4.1885643,,,,,,,,,,,,,, -114049,,,0.7362695336341858,1.114818453788757,0.6717199683189392,1.402865290641785,50000.0,0.5480000376701355,2.0506300926208496,10000.0,53400.63660097122,60084.95163035393,53400.63660097122,6672.544312000275,5.261758804321289,0.0 -114100,1.8034047,4.4870653,,,,,,,,,,,,,, -114200,1.3251116,3.6751306,,,,,,,,,,,,,, -114300,1.7021651,2.1458275,,,,,,,,,,,,,, -114400,1.4515339,4.38098,,,,,,,,,,,,,, -114500,1.6294838,2.0970695,,,,,,,,,,,,,, -114600,1.7543643,2.0508165,,,,,,,,,,,,,, -114700,1.7509998,2.1566148,,,,,,,,,,,,,, -114800,1.7774303,2.150587,,,,,,,,,,,,,, -114900,1.7896506,2.101112,,,,,,,,,,,,,, -114941,,,0.7517382502555847,1.0051695108413696,0.6760599613189697,1.3408530950546265,50000.0,0.5499000549316406,1.9954674243927,10000.0,53820.99761939049,60558.87907385826,53820.99761939049,6726.00217795372,5.320492744445801,0.0 -115000,1.4348913,3.8693209,,,,,,,,,,,,,, -115100,1.6942,2.084058,,,,,,,,,,,,,, -115200,1.5529367,4.036004,,,,,,,,,,,,,, -115300,1.9440033,2.2226415,,,,,,,,,,,,,, -115400,1.6221238,2.9746304,,,,,,,,,,,,,, -115500,1.5214638,2.5841568,,,,,,,,,,,,,, -115600,1.467199,3.427462,,,,,,,,,,,,,, -115700,1.7101452,1.9655272,,,,,,,,,,,,,, -115800,1.5790888,3.2640038,,,,,,,,,,,,,, -115834,,,0.7432421445846558,1.0460532903671265,0.6818000078201294,1.3182965517044067,50000.0,0.5558000206947327,1.9709168672561648,10000.0,54240.927837610245,61032.71226191521,54240.927837610245,6779.807821750641,5.36723256111145,0.0 -115900,1.3899807,3.4775686,,,,,,,,,,,,,, -116000,1.8448116,2.0316596,,,,,,,,,,,,,, -116100,1.4940977,3.3809314,,,,,,,,,,,,,, -116200,1.7502936,1.9991229,,,,,,,,,,,,,, -116300,2.0869758,2.0742989,,,,,,,,,,,,,, -116400,1.6888828,2.0631244,,,,,,,,,,,,,, -116500,1.8122306,2.0723586,,,,,,,,,,,,,, -116600,1.5293945,3.1516726,,,,,,,,,,,,,, -116700,1.6722435,2.0888097,,,,,,,,,,,,,, -116723,,,0.7484374642372131,1.0299557447433472,0.6839199662208557,1.325279951095581,50000.0,0.5578000545501709,1.965359807014465,10000.0,54660.92917466164,61503.20591568947,54660.92917466164,6830.200934410095,5.416102886199951,0.0 -116800,1.659441,1.9738805,,,,,,,,,,,,,, -116900,1.7462869,1.9887741,,,,,,,,,,,,,, -117000,1.6847018,3.2398634,,,,,,,,,,,,,, -117100,1.5636945,4.2429476,,,,,,,,,,,,,, -117200,1.5272099,4.149276,,,,,,,,,,,,,, -117300,1.9232881,2.2214534,,,,,,,,,,,,,, -117400,1.8583956,2.0478787,,,,,,,,,,,,,, -117500,1.7690091,2.0705574,,,,,,,,,,,,,, -117600,1.5300162,4.480446,,,,,,,,,,,,,, -117617,,,0.7597265243530273,0.952623963356018,0.6856600046157837,1.292971968650818,50000.0,0.560699999332428,1.9365530014038088,10000.0,55081.1960170269,61978.925362825394,55081.1960170269,6885.558126449585,5.461392164230347,0.0 -117700,1.763261,1.9451784,,,,,,,,,,,,,, -117800,1.71335,1.9832346,,,,,,,,,,,,,, -117900,1.5385032,4.427619,,,,,,,,,,,,,, -118000,2.1175656,2.0655935,,,,,,,,,,,,,, -118100,1.6705384,2.2107286,,,,,,,,,,,,,, -118200,1.8989944,4.713432,,,,,,,,,,,,,, -118300,1.711862,4.1655154,,,,,,,,,,,,,, -118400,1.8001679,1.9605777,,,,,,,,,,,,,, -118500,1.6006882,2.1568377,,,,,,,,,,,,,, -118506,,,0.7503515481948853,1.0049922466278076,0.6844199895858765,1.30093252658844,50000.0,0.558899998664856,1.9488836526870728,10000.0,55501.43515133858,62454.564338207245,55501.43515133858,6940.854739904404,5.514406204223633,0.0 -118600,1.6534647,1.903783,,,,,,,,,,,,,, -118700,1.6728575,2.2879553,,,,,,,,,,,,,, -118800,1.679055,4.3980837,,,,,,,,,,,,,, -118900,2.2650516,2.6370168,,,,,,,,,,,,,, -119000,1.8645761,3.315597,,,,,,,,,,,,,, -119100,1.6667641,1.9329451,,,,,,,,,,,,,, -119200,1.8261211,2.150397,,,,,,,,,,,,,, -119300,1.6565024,2.3867674,,,,,,,,,,,,,, -119398,,,0.7482226490974426,1.0383548736572266,0.6855599880218506,1.3293486833572388,50000.0,0.55840003490448,1.978242874145508,10000.0,55921.57339477539,62927.31496787071,55921.57339477539,6993.368841648102,5.561697959899902,0.0 -119400,1.5812552,2.5886078,,,,,,,,,,,,,, -119500,1.690458,1.938123,,,,,,,,,,,,,, -119600,1.6263142,2.840556,,,,,,,,,,,,,, -119700,1.4466943,4.0453906,,,,,,,,,,,,,, -119800,1.8043829,1.8954415,,,,,,,,,,,,,, -119900,1.8603897,2.1024854,,,,,,,,,,,,,, -120000,1.736595,1.8755062,,,,,,,,,,,,,, -120100,1.5335102,3.262256,,,,,,,,,,,,,, -120200,1.6814996,2.0616195,,,,,,,,,,,,,, -120291,,,0.7626562118530273,0.9669126868247986,0.6891199946403503,1.2968260049819946,50000.0,0.5616000294685364,1.9497791528701784,10000.0,56341.53645062447,63398.22046136856,56341.53645062447,7044.214889287949,5.607126235961914,0.0 -120300,1.88183,3.7998514,,,,,,,,,,,,,, -120400,1.7702833,2.9778392,,,,,,,,,,,,,, -120500,1.7411622,1.9860837,,,,,,,,,,,,,, -120600,1.668203,3.0205863,,,,,,,,,,,,,, -120700,1.6662108,1.9988269,,,,,,,,,,,,,, -120800,1.7506999,3.124813,,,,,,,,,,,,,, -120900,1.6725141,4.5606804,,,,,,,,,,,,,, -121000,1.8275901,1.9830621,,,,,,,,,,,,,, -121100,1.8379744,2.1673946,,,,,,,,,,,,,, -121188,,,0.7483984231948853,1.023191213607788,0.6844800114631653,1.3076735734939575,50000.0,0.5606000423431396,1.9592289924621584,10000.0,56761.78638911247,63871.95942783356,56761.78638911247,7097.606656551361,5.652962684631348,0.0 -121200,1.8711942,2.0448747,,,,,,,,,,,,,, -121300,1.6110814,3.0677114,,,,,,,,,,,,,, -121400,1.8574252,2.084964,,,,,,,,,,,,,, -121500,1.844475,2.1293426,,,,,,,,,,,,,, -121600,1.5968126,3.4682076,,,,,,,,,,,,,, -121700,1.7502476,1.9370087,,,,,,,,,,,,,, -121800,1.7983983,2.0137534,,,,,,,,,,,,,, -121900,1.5646344,3.904067,,,,,,,,,,,,,, -122000,1.5900738,4.4817305,,,,,,,,,,,,,, -122078,,,0.7550390362739563,0.9838645458221436,0.689799964427948,1.2810405492782593,50000.0,0.5636000037193298,1.9222257137298584,10000.0,57181.91209387779,64347.73667383194,57181.91209387779,7153.155112504959,5.705412864685059,0.0 -122100,1.8639542,2.010961,,,,,,,,,,,,,, -122200,1.6430162,2.5846214,,,,,,,,,,,,,, -122300,1.80912,1.9900271,,,,,,,,,,,,,, -122400,1.6336008,2.4551253,,,,,,,,,,,,,, -122500,1.6921247,2.005225,,,,,,,,,,,,,, -122600,1.7070472,2.063063,,,,,,,,,,,,,, -122700,1.8350234,1.910537,,,,,,,,,,,,,, -122800,1.8075049,2.0613472,,,,,,,,,,,,,, -122900,1.7138045,2.5551655,,,,,,,,,,,,,, -122967,,,0.767285168170929,0.914058804512024,0.6899399757385254,1.2592631578445437,50000.0,0.5637000203132629,1.895406723022461,10000.0,57601.97654438019,64821.3822324276,57601.97654438019,7206.63245177269,5.758823394775391,0.0 -123000,1.7407267,2.5361876,,,,,,,,,,,,,, -123100,1.7374748,2.5054657,,,,,,,,,,,,,, -123200,1.9254062,1.8866807,,,,,,,,,,,,,, -123300,1.656839,1.9960858,,,,,,,,,,,,,, -123400,1.9790412,1.9135942,,,,,,,,,,,,,, -123500,1.9654144,2.003454,,,,,,,,,,,,,, -123600,1.7612879,2.213022,,,,,,,,,,,,,, -123700,1.7707654,1.8358309,,,,,,,,,,,,,, -123800,1.8659511,2.0268996,,,,,,,,,,,,,, -123855,,,0.7547070384025574,0.9752053618431092,0.691819965839386,1.2582887411117554,50000.0,0.5642000436782837,1.9256024360656736,10000.0,58022.1356446743,65292.18530201912,58022.1356446743,7257.178694009781,5.806592226028442,0.0 -123900,1.9136887,2.2049434,,,,,,,,,,,,,, -124000,1.7601036,1.9998285,,,,,,,,,,,,,, -124100,1.9385532,2.0543854,,,,,,,,,,,,,, -124200,1.8291818,2.0315144,,,,,,,,,,,,,, -124300,1.5562971,2.3873055,,,,,,,,,,,,,, -124400,1.9067198,1.974124,,,,,,,,,,,,,, -124500,1.7375802,2.384556,,,,,,,,,,,,,, -124600,1.743726,4.4459505,,,,,,,,,,,,,, -124700,1.8251655,2.2710052,,,,,,,,,,,,,, -124748,,,0.7588085532188416,0.9643204808235168,0.6942799687385559,1.259268045425415,50000.0,0.5690000057220459,1.9072344303131104,10000.0,58442.4238243103,65765.31973147392,58442.4238243103,7309.925290107727,5.855309247970581,0.0 -124800,1.8405882,1.9081614,,,,,,,,,,,,,, -124900,1.657381,3.02946,,,,,,,,,,,,,, -125000,1.9770408,1.9613047,,,,,,,,,,,,,, -125100,2.0410776,1.9386334,,,,,,,,,,,,,, -125200,1.814514,2.090684,,,,,,,,,,,,,, -125300,1.7519627,2.0186558,,,,,,,,,,,,,, -125400,1.6091166,4.2604923,,,,,,,,,,,,,, -125500,1.7645159,1.9729084,,,,,,,,,,,,,, -125600,1.7528001,2.9280996,,,,,,,,,,,,,, -125642,,,0.7699999809265137,0.9051749110221864,0.6944599747657776,1.2426235675811768,50000.0,0.5710000395774841,1.878885269165039,10000.0,58862.69698309898,66237.66225218773,58862.69698309898,7361.893659353256,5.9061243534088135,0.0 -125700,1.7952266,1.9908243,,,,,,,,,,,,,, -125800,1.9269041,2.0984693,,,,,,,,,,,,,, -125900,1.6819304,2.900653,,,,,,,,,,,,,, -126000,1.8632622,3.4761798,,,,,,,,,,,,,, -126100,2.0178823,1.9884984,,,,,,,,,,,,,, -126200,1.7609756,1.8832593,,,,,,,,,,,,,, -126300,1.7898544,4.1745768,,,,,,,,,,,,,, -126400,1.6836672,2.567596,,,,,,,,,,,,,, -126500,1.8164157,4.5560684,,,,,,,,,,,,,, -126535,,,0.7580859065055847,0.9454815983772278,0.6959999799728394,1.2273520231246948,50000.0,0.5751000046730042,1.8674311637878416,10000.0,59282.61265707016,66709.08399271965,59282.61265707016,7413.29775929451,5.957125425338745,0.0 -126600,1.748932,3.6075203,,,,,,,,,,,,,, -126700,1.6889061,2.1196945,,,,,,,,,,,,,, -126800,2.06665,2.1444862,,,,,,,,,,,,,, -126900,2.1529608,1.9086719,,,,,,,,,,,,,, -127000,1.9790071,2.3253083,,,,,,,,,,,,,, -127100,1.7378078,3.3729076,,,,,,,,,,,,,, -127200,1.7975963,4.267705,,,,,,,,,,,,,, -127300,2.008302,1.7996013,,,,,,,,,,,,,, -127400,1.6709121,2.323742,,,,,,,,,,,,,, -127428,,,0.7644921541213989,0.9372758865356444,0.6952999830245972,1.2479116916656494,50000.0,0.5734000205993652,1.8793699741363523,10000.0,59702.56274223328,67182.71284723282,59702.56274223328,7466.876657009125,6.0067572593688965,0.0 -127500,1.6539273,3.8812118,,,,,,,,,,,,,, -127600,1.8895814,1.9634624,,,,,,,,,,,,,, -127700,1.7129878,2.495855,,,,,,,,,,,,,, -127800,2.0051455,2.3104405,,,,,,,,,,,,,, -127900,1.8717808,1.9374516,,,,,,,,,,,,,, -128000,1.7695727,3.691059,,,,,,,,,,,,,, -128100,1.6379548,3.147265,,,,,,,,,,,,,, -128200,1.8608081,4.2399464,,,,,,,,,,,,,, -128300,1.80757,4.50665,,,,,,,,,,,,,, -128321,,,0.7732617259025574,0.9040424823760986,0.696179986000061,1.2386150360107422,50000.0,0.5678000450134277,1.8976391553878784,10000.0,60122.79120898247,67656.58211922646,60122.79120898247,7520.418255567551,6.055173873901367,0.0 -128400,2.1908116,2.1892004,,,,,,,,,,,,,, -128500,1.9943131,1.830558,,,,,,,,,,,,,, -128600,1.944203,2.5077903,,,,,,,,,,,,,, -128700,1.7930332,2.4248588,,,,,,,,,,,,,, -128800,2.061382,2.0364003,,,,,,,,,,,,,, -128900,2.0975735,1.8802693,,,,,,,,,,,,,, -129000,2.2788734,1.7843527,,,,,,,,,,,,,, -129100,1.8955467,4.371849,,,,,,,,,,,,,, -129200,1.9807814,2.2416961,,,,,,,,,,,,,, -129217,,,0.7625390291213989,0.9737145900726318,0.6969599723815918,1.2702531814575195,50000.0,0.572100043296814,1.913463234901428,10000.0,60542.8531806469,68129.01347446442,60542.8531806469,7572.68899512291,6.103700399398804,0.0 -129300,1.6893247,3.3165865,,,,,,,,,,,,,, -129400,2.0847237,1.9458243,,,,,,,,,,,,,, -129500,1.8068966,4.352754,,,,,,,,,,,,,, -129600,1.8592455,1.9848521,,,,,,,,,,,,,, -129700,2.017206,2.0560262,,,,,,,,,,,,,, -129800,1.9182343,2.0758562,,,,,,,,,,,,,, -129900,1.9322245,1.8930575,,,,,,,,,,,,,, -130000,1.891106,2.0405126,,,,,,,,,,,,,, -130100,1.8835542,2.7710981,,,,,,,,,,,,,, -130113,,,0.7669726610183716,0.9456799626350404,0.7005999684333801,1.244845986366272,50000.0,0.5756000280380249,1.8801839351654053,10000.0,60962.88206410408,68602.47761464119,60962.88206410408,7626.02169251442,6.154720783233643,0.0 -130200,1.8048002,2.3570774,,,,,,,,,,,,,, -130300,1.9405332,4.514884,,,,,,,,,,,,,, -130400,1.8608568,1.9885187,,,,,,,,,,,,,, -130500,1.8791667,4.3755584,,,,,,,,,,,,,, -130600,2.2630417,3.7437487,,,,,,,,,,,,,, -130700,2.1958632,1.900032,,,,,,,,,,,,,, -130800,1.6951654,3.5298955,,,,,,,,,,,,,, -130900,2.259881,1.8731695,,,,,,,,,,,,,, -131000,,,0.7790820002555847,0.8822490572929382,0.7021999955177307,1.2245677709579468,50000.0,0.5766000151634216,1.8699551820755005,10000.0,61383.08276414871,69075.28121972084,61383.08276414871,7678.519439458847,6.209751129150391,0.0 -131000,1.7281071,2.8077157,,,,,,,,,,,,,, -131100,1.9729214,1.8364727,,,,,,,,,,,,,, -131200,1.9689317,1.9484258,,,,,,,,,,,,,, -131300,1.8301065,1.8150941,,,,,,,,,,,,,, -131400,1.89558,2.1493201,,,,,,,,,,,,,, -131500,1.8315161,1.9278275,,,,,,,,,,,,,, -131600,1.6933187,4.1638107,,,,,,,,,,,,,, -131700,1.7458206,4.266141,,,,,,,,,,,,,, -131800,1.7781647,3.2110887,,,,,,,,,,,,,, -131892,,,0.7725781202316284,0.9053280353546144,0.7041800022125244,1.2155123949050903,50000.0,0.575700044631958,1.8553532361984253,10000.0,61803.10675501824,69547.86748552322,61803.10675501824,7730.9835069179535,6.25759744644165,0.0 -131900,1.7647556,2.6557448,,,,,,,,,,,,,, -132000,1.8613759,3.7116556,,,,,,,,,,,,,, -132100,1.8204249,1.8532127,,,,,,,,,,,,,, -132200,1.7981521,2.30545,,,,,,,,,,,,,, -132300,1.7910446,3.6406293,,,,,,,,,,,,,, -132400,2.027154,1.9548352,,,,,,,,,,,,,, -132500,2.3174496,1.9618957,,,,,,,,,,,,,, -132600,1.7621032,3.4840117,,,,,,,,,,,,,, -132700,2.0472531,2.1261036,,,,,,,,,,,,,, -132782,,,0.7796679735183716,0.8771883249282837,0.7079199552536011,1.195351243019104,50000.0,0.5833000540733337,1.8324159383773804,10000.0,62223.355676651,70020.74941420555,62223.355676651,7783.519206285477,6.3048930168151855,0.0 -132800,2.19322,2.04052,,,,,,,,,,,,,, -132900,2.16242,4.2224703,,,,,,,,,,,,,, -133000,1.9356141,1.887097,,,,,,,,,,,,,, -133100,2.2977183,1.8099424,,,,,,,,,,,,,, -133200,2.2895238,1.8254557,,,,,,,,,,,,,, -133300,1.908275,1.8695776,,,,,,,,,,,,,, -133400,2.1098187,2.0121274,,,,,,,,,,,,,, -133500,1.9254524,1.7695107,,,,,,,,,,,,,, -133600,1.9624537,2.2695076,,,,,,,,,,,,,, -133679,,,0.78236323595047,0.8421926498413086,0.7068600058555603,1.1863993406295776,50000.0,0.5857000350952148,1.817793250083924,10000.0,62643.74037742615,70494.65468883514,62643.74037742615,7836.93887090683,6.35483980178833,0.0 -133700,2.0247428,1.8844392,,,,,,,,,,,,,, -133800,1.7966633,3.3622465,,,,,,,,,,,,,, -133900,1.9169846,2.265601,,,,,,,,,,,,,, -134000,1.9502202,1.8396921,,,,,,,,,,,,,, -134100,2.050073,1.8658055,,,,,,,,,,,,,, -134200,1.8753647,2.3757792,,,,,,,,,,,,,, -134300,1.9250436,1.9983897,,,,,,,,,,,,,, -134400,1.9662691,1.6846464,,,,,,,,,,,,,, -134500,2.068731,1.753687,,,,,,,,,,,,,, -134573,,,0.7735351324081421,0.902022123336792,0.7064799666404724,1.2089617252349854,50000.0,0.5770000219345093,1.8554986715316768,10000.0,63063.97910571098,70966.68536663055,63063.97910571098,7888.631495952606,6.403504371643066,0.0 -134600,2.041778,1.9115663,,,,,,,,,,,,,, -134700,1.8467996,3.5279326,,,,,,,,,,,,,, -134800,2.073388,1.8761652,,,,,,,,,,,,,, -134900,1.8123611,2.1828942,,,,,,,,,,,,,, -135000,2.1693873,2.72041,,,,,,,,,,,,,, -135100,1.8672329,2.1540134,,,,,,,,,,,,,, -135200,2.30031,3.784018,,,,,,,,,,,,,, -135300,1.9644091,1.8876926,,,,,,,,,,,,,, -135400,1.9755119,1.8430845,,,,,,,,,,,,,, -135466,,,0.7805468440055847,0.8645254373550415,0.7113199830055237,1.1867696046829224,50000.0,0.5851000547409058,1.8302571773529053,10000.0,63484.01175141335,71438.42376971245,63484.01175141335,7940.2364201545715,6.454331636428833,0.0 -135500,1.8053987,4.2373385,,,,,,,,,,,,,, -135600,2.4822245,1.9128418,,,,,,,,,,,,,, -135700,1.9912741,4.107058,,,,,,,,,,,,,, -135800,1.9966391,2.3087616,,,,,,,,,,,,,, -135900,1.9215564,1.742474,,,,,,,,,,,,,, -136000,1.891539,2.4572346,,,,,,,,,,,,,, -136100,1.9278532,2.0153658,,,,,,,,,,,,,, -136200,1.838952,3.5300846,,,,,,,,,,,,,, -136300,1.9138371,4.202426,,,,,,,,,,,,,, -136359,,,0.78480464220047,0.8444609045982361,0.7076399922370911,1.1892009973526,50000.0,0.5819000005722046,1.8246159553527832,10000.0,63904.009852170944,71910.59975337982,63904.009852170944,7992.3159103393555,6.5025811195373535,0.0 -136400,2.263384,4.2405605,,,,,,,,,,,,,, -136500,2.375539,1.9012287,,,,,,,,,,,,,, -136600,2.215901,2.541783,,,,,,,,,,,,,, -136700,1.8606529,2.1318521,,,,,,,,,,,,,, -136800,2.1014616,3.3032806,,,,,,,,,,,,,, -136900,1.8943168,1.6715207,,,,,,,,,,,,,, -137000,2.01832,3.9065154,,,,,,,,,,,,,, -137100,2.249009,1.6187214,,,,,,,,,,,,,, -137200,2.0258086,1.9011617,,,,,,,,,,,,,, -137253,,,0.7909960746765137,0.8081295490264893,0.7122399806976318,1.1592646837234497,50000.0,0.5891000032424927,1.7870514392852783,10000.0,64323.90890264511,72384.28784418106,64323.90890264511,8045.995781421661,6.56099271774292,0.0 -137300,1.9727573,4.308461,,,,,,,,,,,,,, -137400,2.086873,1.8325344,,,,,,,,,,,,,, -137500,1.8940719,3.315617,,,,,,,,,,,,,, -137600,2.1825047,2.306887,,,,,,,,,,,,,, -137700,2.182593,1.7506735,,,,,,,,,,,,,, -137800,2.0504344,1.9191469,,,,,,,,,,,,,, -137900,1.9691474,2.453713,,,,,,,,,,,,,, -138000,1.7575936,3.068505,,,,,,,,,,,,,, -138100,2.0869505,3.6284075,,,,,,,,,,,,,, -138143,,,0.7822265625,0.8877792358398438,0.7106399536132812,1.2171674966812134,50000.0,0.5836000442504883,1.8493305444717407,10000.0,64744.21198558808,72856.13017225266,64744.21198558808,8097.431872606277,6.613979339599609,0.0 -138200,2.0867784,1.9720253,,,,,,,,,,,,,, -138300,2.1213896,2.0189393,,,,,,,,,,,,,, -138400,2.297039,1.8187468,,,,,,,,,,,,,, -138500,1.7807878,3.563027,,,,,,,,,,,,,, -138600,2.1353192,3.861918,,,,,,,,,,,,,, -138700,2.8770995,1.8814776,,,,,,,,,,,,,, -138800,2.0634568,2.5783281,,,,,,,,,,,,,, -138900,2.038816,3.7149796,,,,,,,,,,,,,, -139000,2.0818412,2.4894035,,,,,,,,,,,,,, -139035,,,0.7908398509025574,0.8290248513221741,0.7122799754142761,1.171581745147705,50000.0,0.5884000062942505,1.7982066869735718,10000.0,65164.19147825241,73327.64161705971,65164.19147825241,8148.8644115924835,6.663257837295532,0.0 -139100,2.8355043,1.6756405,,,,,,,,,,,,,, -139200,2.1968813,1.824547,,,,,,,,,,,,,, -139300,2.2709286,2.341906,,,,,,,,,,,,,, -139400,2.0363774,1.8392537,,,,,,,,,,,,,, -139500,2.1748207,1.8301969,,,,,,,,,,,,,, -139600,2.2252548,3.1864116,,,,,,,,,,,,,, -139700,1.9768231,1.7971855,,,,,,,,,,,,,, -139800,2.3912961,1.7303165,,,,,,,,,,,,,, -139900,1.9704571,2.688045,,,,,,,,,,,,,, -139928,,,0.8020117282867432,0.7729999423027039,0.7156800031661987,1.1472184658050537,50000.0,0.5904000401496887,1.78292977809906,10000.0,65584.32089519501,73799.69849538803,65584.32089519501,8200.685409069061,6.717529058456421,0.0 -140000,2.0649498,1.6584967,,,,,,,,,,,,,, -140100,2.03764,1.6544573,,,,,,,,,,,,,, -140200,2.209036,4.1763973,,,,,,,,,,,,,, -140300,2.0529618,1.7387879,,,,,,,,,,,,,, -140400,2.0928948,2.9031727,,,,,,,,,,,,,, -140500,2.3649764,1.8809443,,,,,,,,,,,,,, -140600,2.0778832,2.683239,,,,,,,,,,,,,, -140700,1.9752085,2.8533363,,,,,,,,,,,,,, -140800,2.1228492,1.7965939,,,,,,,,,,,,,, -140823,,,0.7893944978713989,0.811639666557312,0.7133199572563171,1.150890588760376,50000.0,0.5879000425338745,1.7908883094787598,10000.0,66004.48244214058,74272.31378889084,66004.48244214058,8253.039557218552,6.766315937042236,0.0 -140900,2.2921944,2.5102248,,,,,,,,,,,,,, -141000,2.3564453,1.8092965,,,,,,,,,,,,,, -141100,1.9860427,2.3183396,,,,,,,,,,,,,, -141200,2.1884768,3.3190153,,,,,,,,,,,,,, -141300,2.34532,1.995027,,,,,,,,,,,,,, -141400,2.375298,3.65534,,,,,,,,,,,,,, -141500,2.1339753,1.8995519,,,,,,,,,,,,,, -141600,2.190962,4.11322,,,,,,,,,,,,,, -141700,2.0445673,3.9303422,,,,,,,,,,,,,, -141714,,,0.7934374809265137,0.8436886668205261,0.7143200039863586,1.187924861907959,50000.0,0.5871000289916992,1.8304020166397093,10000.0,66424.64887809753,74745.89710235596,66424.64887809753,8306.357367038727,6.814566850662232,0.0 -141800,2.3540134,1.7905747,,,,,,,,,,,,,, -141900,2.2953448,1.8149774,,,,,,,,,,,,,, -142000,2.2303457,1.7846574,,,,,,,,,,,,,, -142100,3.0630047,3.8678706,,,,,,,,,,,,,, -142200,2.4299877,1.7102711,,,,,,,,,,,,,, -142300,2.1596441,1.744715,,,,,,,,,,,,,, -142400,2.2366998,2.3936045,,,,,,,,,,,,,, -142500,2.47351,1.7924883,,,,,,,,,,,,,, -142600,2.0490515,2.2621694,,,,,,,,,,,,,, -142603,,,0.8104296922683716,0.7512505054473877,0.7206400036811829,1.1383086442947388,50000.0,0.5949000120162964,1.7849997282028198,10000.0,66844.86011886597,75217.28182458878,66844.86011886597,8357.431869983673,6.8636791706085205,0.0 -142700,2.1010544,2.4323156,,,,,,,,,,,,,, -142800,1.9436021,3.0436149,,,,,,,,,,,,,, -142900,2.0438862,1.6721859,,,,,,,,,,,,,, -143000,2.082737,1.7064805,,,,,,,,,,,,,, -143100,2.1412625,1.7582306,,,,,,,,,,,,,, -143200,2.5516148,1.830695,,,,,,,,,,,,,, -143300,2.079211,2.353871,,,,,,,,,,,,,, -143400,1.935938,2.427242,,,,,,,,,,,,,, -143493,,,0.7852538824081421,0.8588006496429443,0.713699996471405,1.1768875122070312,50000.0,0.5906000137329102,1.8258694410324097,10000.0,67264.9873714447,75689.96438384056,67264.9873714447,8409.880574703217,6.9201719760894775,0.0 -143500,2.1902025,1.770169,,,,,,,,,,,,,, -143600,2.1045017,1.799111,,,,,,,,,,,,,, -143700,2.3751962,1.8675959,,,,,,,,,,,,,, -143800,2.118664,2.2830262,,,,,,,,,,,,,, -143900,1.8999084,3.3752043,,,,,,,,,,,,,, -144000,2.1138573,3.662823,,,,,,,,,,,,,, -144100,2.1183376,2.6498947,,,,,,,,,,,,,, -144200,2.174626,1.8413696,,,,,,,,,,,,,, -144300,2.7966955,1.7140936,,,,,,,,,,,,,, -144389,,,0.8009960651397705,0.8065456748008728,0.7186799645423889,1.1656146049499512,50000.0,0.5975000262260437,1.7978652715682983,10000.0,67685.38305974007,76162.98290419579,67685.38305974007,8462.395847082138,6.977154016494751,0.0 -144400,1.9165038,3.7503695,,,,,,,,,,,,,, -144500,3.732419,1.7336075,,,,,,,,,,,,,, -144600,2.2157087,3.2340252,,,,,,,,,,,,,, -144700,2.208966,2.8376627,,,,,,,,,,,,,, -144800,2.1595562,4.265273,,,,,,,,,,,,,, -144900,2.172741,1.7164214,,,,,,,,,,,,,, -145000,2.3948967,4.1906095,,,,,,,,,,,,,, -145100,2.3236935,1.6515784,,,,,,,,,,,,,, -145200,2.2850163,1.7517334,,,,,,,,,,,,,, -145286,,,0.8099414110183716,0.7473099231719971,0.720579981803894,1.1367037296295166,50000.0,0.593500018119812,1.777706503868103,10000.0,68105.52442455292,76635.59283566475,68105.52442455292,8514.762991905212,7.027865171432495,0.0 -145300,2.1729188,2.034401,,,,,,,,,,,,,, -145400,2.2597733,1.7712953,,,,,,,,,,,,,, -145500,2.039679,2.7292888,,,,,,,,,,,,,, -145600,2.2309082,3.350417,,,,,,,,,,,,,, -145700,2.5301032,4.2458687,,,,,,,,,,,,,, -145800,2.2549207,1.7760897,,,,,,,,,,,,,, -145900,2.1447139,1.8284092,,,,,,,,,,,,,, -146000,2.2000632,1.9269705,,,,,,,,,,,,,, -146100,2.118146,2.219215,,,,,,,,,,,,,, -146180,,,0.7946679592132568,0.8052504062652588,0.7216199636459351,1.132479190826416,50000.0,0.5982000231742859,1.7780965566635132,10000.0,68525.59720563889,77108.97180509567,68525.59720563889,8567.967049360275,7.079846858978272,0.0 -146200,2.6052303,2.9310956,,,,,,,,,,,,,, -146300,2.3672493,2.200394,,,,,,,,,,,,,, -146400,2.1668198,3.2707572,,,,,,,,,,,,,, -146500,2.0239968,3.9287634,,,,,,,,,,,,,, -146600,2.4300601,1.6547313,,,,,,,,,,,,,, -146700,2.2813363,1.724675,,,,,,,,,,,,,, -146800,2.4077961,1.7599599,,,,,,,,,,,,,, -146900,2.3189619,1.6952155,,,,,,,,,,,,,, -147000,2.2109287,2.950681,,,,,,,,,,,,,, -147069,,,0.8073241710662842,0.7616297602653503,0.7257999777793884,1.1163734197616575,50000.0,0.5991000533103943,1.7480182647705078,10000.0,68945.71439909935,77580.7287735939,68945.71439909935,8619.504743099213,7.131839036941528,0.0 -147100,2.228856,4.2687616,,,,,,,,,,,,,, -147200,2.5447032,2.8792949,,,,,,,,,,,,,, -147300,2.4812975,1.8743334,,,,,,,,,,,,,, -147400,2.4995477,1.8868283,,,,,,,,,,,,,, -147500,2.2119129,1.7547705,,,,,,,,,,,,,, -147600,2.4935327,1.7735844,,,,,,,,,,,,,, -147700,2.6964483,1.6813993,,,,,,,,,,,,,, -147800,2.1215184,1.7240286,,,,,,,,,,,,,, -147900,2.349162,1.9680861,,,,,,,,,,,,,, -147961,,,0.8156249523162842,0.7296286821365356,0.7245199680328369,1.1153302192687988,50000.0,0.6000000238418579,1.755022406578064,10000.0,69365.70506572723,78053.88421607018,69365.70506572723,8672.559925556183,7.190670490264893,0.0 -148000,2.249868,4.0795155,,,,,,,,,,,,,, -148100,2.4565673,1.5943704,,,,,,,,,,,,,, -148200,2.800695,1.6765695,,,,,,,,,,,,,, -148300,2.2925992,3.3088293,,,,,,,,,,,,,, -148400,2.338208,1.968151,,,,,,,,,,,,,, -148500,2.6797967,4.1286774,,,,,,,,,,,,,, -148600,2.3104384,4.1452336,,,,,,,,,,,,,, -148700,2.1273115,2.5543509,,,,,,,,,,,,,, -148800,2.3156254,1.8126061,,,,,,,,,,,,,, -148852,,,0.7974609136581421,0.7928144931793213,0.7253400087356567,1.122349977493286,50000.0,0.6010000109672546,1.7640174627304075,10000.0,69785.79780435562,78525.49350094795,69785.79780435562,8723.976521015167,7.239832878112793,0.0 -148900,2.2393796,2.3086817,,,,,,,,,,,,,, -149000,2.5253732,1.8971869,,,,,,,,,,,,,, -149100,2.4802098,3.068357,,,,,,,,,,,,,, -149200,2.4249105,1.7131609,,,,,,,,,,,,,, -149300,2.473019,1.664513,,,,,,,,,,,,,, -149400,2.609488,3.4454808,,,,,,,,,,,,,, -149500,2.4744499,1.5818729,,,,,,,,,,,,,, -149600,2.1581333,2.8877897,,,,,,,,,,,,,, -149700,2.316568,3.1414561,,,,,,,,,,,,,, -149745,,,0.8090038895606995,0.7413263320922852,0.7287600040435791,1.0949060916900637,50000.0,0.6070000529289246,1.728402614593506,10000.0,70205.98230195045,79000.38216662407,70205.98230195045,8778.575244188309,7.29470419883728,0.0 -149800,2.3692813,3.4627495,,,,,,,,,,,,,, -149900,2.231626,3.9471536,,,,,,,,,,,,,, -150000,2.83232,2.3987317,,,,,,,,,,,,,, -150100,2.4217,1.6861578,,,,,,,,,,,,,, -150200,2.3591344,1.7727892,,,,,,,,,,,,,, -150300,2.6778038,3.848731,,,,,,,,,,,,,, -150400,2.4433253,1.7407997,,,,,,,,,,,,,, -150500,2.650543,2.0263305,,,,,,,,,,,,,, -150600,2.6281614,1.6791536,,,,,,,,,,,,,, -150640,,,0.8140038847923279,0.7625994086265564,0.7268999814987183,1.140964388847351,50000.0,0.5976999998092651,1.7766977548599243,10000.0,70626.07793998718,79473.11308956146,70626.07793998718,8831.101482391357,7.352070093154907,0.0 -150700,2.5381072,1.7961686,,,,,,,,,,,,,, -150800,2.2429848,3.5702808,,,,,,,,,,,,,, -150900,2.4998977,1.8502111,,,,,,,,,,,,,, -151000,2.4075997,3.183477,,,,,,,,,,,,,, -151100,2.5866308,1.7081394,,,,,,,,,,,,,, -151200,2.6026797,1.7702138,,,,,,,,,,,,,, -151300,2.4740117,4.1064215,,,,,,,,,,,,,, -151400,2.465962,1.6090493,,,,,,,,,,,,,, -151500,2.1946034,1.8399955,,,,,,,,,,,,,, -151533,,,0.8054101467132568,0.7942066192626953,0.7263000011444092,1.1323789358139038,50000.0,0.6025000214576721,1.7719186544418335,10000.0,71046.26297879219,79945.56105446815,71046.26297879219,8883.26012802124,7.406408786773682,0.0 -151600,2.27026,1.6363206,,,,,,,,,,,,,, -151700,2.547498,4.089762,,,,,,,,,,,,,, -151800,2.219397,3.1238632,,,,,,,,,,,,,, -151900,2.442816,3.9364598,,,,,,,,,,,,,, -152000,2.5011802,3.252135,,,,,,,,,,,,,, -152100,2.2682722,2.6980186,,,,,,,,,,,,,, -152200,2.5421455,1.6242418,,,,,,,,,,,,,, -152300,2.5412154,2.9148345,,,,,,,,,,,,,, -152400,2.5013902,1.5815808,,,,,,,,,,,,,, -152424,,,0.8122460842132568,0.738717794418335,0.7300999760627747,1.0963596105575562,50000.0,0.6144000291824341,1.7185935974121094,10000.0,71466.43354034424,80421.11392855644,71466.43354034424,8938.53841495514,7.459969520568848,0.0 -152500,2.3556442,1.8263849,,,,,,,,,,,,,, -152600,2.462983,1.7271003,,,,,,,,,,,,,, -152700,2.1910813,3.528183,,,,,,,,,,,,,, -152800,2.4060175,2.5903935,,,,,,,,,,,,,, -152900,2.934462,2.4962833,,,,,,,,,,,,,, -153000,2.589257,1.6260959,,,,,,,,,,,,,, -153100,2.6902275,1.7053721,,,,,,,,,,,,,, -153200,2.2103686,1.8217789,,,,,,,,,,,,,, -153300,2.3207018,3.7447433,,,,,,,,,,,,,, -153321,,,0.81751948595047,0.732594907283783,0.7286399602890015,1.1145501136779783,50000.0,0.6084000468254089,1.746617078781128,10000.0,71886.71799898148,80892.94577765465,71886.71799898148,8989.983581542969,7.511559963226318,0.0 -153400,2.4108768,2.674418,,,,,,,,,,,,,, -153500,2.270848,2.690497,,,,,,,,,,,,,, -153600,2.4765604,2.14304,,,,,,,,,,,,,, -153700,2.482671,1.8010597,,,,,,,,,,,,,, -153800,2.4935324,3.4712477,,,,,,,,,,,,,, -153900,3.1160693,2.7225878,,,,,,,,,,,,,, -154000,2.3095946,2.3914874,,,,,,,,,,,,,, -154100,2.3946438,2.3977294,,,,,,,,,,,,,, -154200,2.5705168,1.5923481,,,,,,,,,,,,,, -154215,,,0.8136718273162842,0.7343711853027344,0.7323399782180786,1.0908153057098389,50000.0,0.6096000075340271,1.717441439628601,10000.0,72306.71941304207,81366.14441990852,72306.71941304207,9043.079141378405,7.5628063678741455,0.0 -154300,2.7343588,4.1524005,,,,,,,,,,,,,, -154400,2.6351323,1.6117927,,,,,,,,,,,,,, -154500,4.797693,1.5718752,,,,,,,,,,,,,, -154600,2.5508978,1.7311546,,,,,,,,,,,,,, -154700,2.494427,1.5455734,,,,,,,,,,,,,, -154800,3.9509416,1.723881,,,,,,,,,,,,,, -154900,2.4715774,4.060496,,,,,,,,,,,,,, -155000,2.334967,1.8655827,,,,,,,,,,,,,, -155100,2.311117,1.5746906,,,,,,,,,,,,,, -155112,,,0.8171679377555847,0.7142869234085083,0.7331199645996094,1.077955961227417,50000.0,0.6073000431060791,1.7077174186706543,10000.0,72726.99397587776,81840.67090892792,72726.99397587776,9097.225280284882,7.617284536361694,0.0 -155200,2.2745626,3.1913912,,,,,,,,,,,,,, -155300,2.4225144,1.6772863,,,,,,,,,,,,,, -155400,2.5199988,1.6479324,,,,,,,,,,,,,, -155500,2.5026295,1.6061602,,,,,,,,,,,,,, -155600,2.7515233,1.6402149,,,,,,,,,,,,,, -155700,2.4398744,1.6709011,,,,,,,,,,,,,, -155800,2.2752855,1.4790686,,,,,,,,,,,,,, -155900,2.4154754,2.5610259,,,,,,,,,,,,,, -156000,3.0130184,3.4818954,,,,,,,,,,,,,, -156004,,,0.8233398199081421,0.6954863667488098,0.735040009021759,1.0758261680603027,50000.0,0.612000048160553,1.6971514225006104,10000.0,73147.06365466118,82314.92569756508,73147.06365466118,9151.30518102646,7.670612573623657,0.0 -156100,2.540586,2.0329628,,,,,,,,,,,,,, -156200,2.4541092,2.4400973,,,,,,,,,,,,,, -156300,2.612074,1.565854,,,,,,,,,,,,,, -156400,2.5507112,1.6342094,,,,,,,,,,,,,, -156500,2.2940915,1.7209156,,,,,,,,,,,,,, -156600,2.4412603,1.54363,,,,,,,,,,,,,, -156700,2.718606,4.1003876,,,,,,,,,,,,,, -156800,2.375821,3.065412,,,,,,,,,,,,,, -156897,,,0.8152148127555847,0.7171177268028259,0.7366399765014648,1.0684080123901367,50000.0,0.6095000505447388,1.704157471656799,10000.0,73567.08738541603,82786.26455402374,73567.08738541603,9202.519262313845,7.720755577087402,0.0 -156900,2.6822002,1.4651036,,,,,,,,,,,,,, -157000,2.6335552,1.6141555,,,,,,,,,,,,,, -157100,2.2239773,2.4608395,,,,,,,,,,,,,, -157200,2.415529,2.7965436,,,,,,,,,,,,,, -157300,2.4150763,1.5941646,,,,,,,,,,,,,, -157400,2.5330377,1.5169979,,,,,,,,,,,,,, -157500,2.458045,1.4722488,,,,,,,,,,,,,, -157600,2.1837912,2.0827386,,,,,,,,,,,,,, -157700,2.7347703,1.5036485,,,,,,,,,,,,,, -157788,,,0.8219921588897705,0.6984254121780396,0.7360799908638,1.0686091184616089,50000.0,0.6110000014305115,1.6924991607666016,10000.0,73987.16595101357,83260.5069026947,73987.16595101357,9256.575710058212,7.776738405227661,0.0 -157800,2.601859,1.6355023,,,,,,,,,,,,,, -157900,2.6676033,4.0611706,,,,,,,,,,,,,, -158000,2.597226,2.5297287,,,,,,,,,,,,,, -158100,2.3587465,1.8042668,,,,,,,,,,,,,, -158200,2.834399,2.9914565,,,,,,,,,,,,,, -158300,2.4931846,2.8608265,,,,,,,,,,,,,, -158400,2.3064861,1.4476174,,,,,,,,,,,,,, -158500,2.6566699,4.062558,,,,,,,,,,,,,, -158600,2.4988923,1.5689061,,,,,,,,,,,,,, -158683,,,0.8245312571525574,0.7121429443359375,0.7342599630355835,1.108445644378662,50000.0,0.6097000241279602,1.7294844388961792,10000.0,74407.3215315342,83733.14594745636,74407.3215315342,9308.94518852234,7.840383529663086,0.0 -158700,3.3475034,1.5501237,,,,,,,,,,,,,, -158800,2.4390657,1.4908953,,,,,,,,,,,,,, -158900,2.4389133,2.577374,,,,,,,,,,,,,, -159000,2.5935411,1.6049656,,,,,,,,,,,,,, -159100,2.532316,1.5138074,,,,,,,,,,,,,, -159200,2.6309144,2.2380002,,,,,,,,,,,,,, -159300,2.5352929,2.2574232,,,,,,,,,,,,,, -159400,2.646417,1.6166683,,,,,,,,,,,,,, -159500,2.5359514,1.4526166,,,,,,,,,,,,,, -159580,,,0.8193163871765137,0.7136068344116211,0.7355799674987793,1.0825302600860596,50000.0,0.6133000254631042,1.709965705871582,10000.0,74827.42336130142,84206.77592754364,74827.42336130142,9362.36895108223,7.894147157669067,0.0 -159600,2.4053102,1.5627055,,,,,,,,,,,,,, -159700,3.336613,1.8088138,,,,,,,,,,,,,, -159800,3.023626,4.0654054,,,,,,,,,,,,,, -159900,2.8182096,1.5306833,,,,,,,,,,,,,, -160000,3.0125718,3.214889,,,,,,,,,,,,,, -160100,2.6061442,3.8908675,,,,,,,,,,,,,, -160200,2.463943,1.5651292,,,,,,,,,,,,,, -160300,2.731194,1.5944016,,,,,,,,,,,,,, -160400,2.558893,2.142712,,,,,,,,,,,,,, -160474,,,0.8253905773162842,0.6981134414672852,0.7397399544715881,1.0691667795181274,50000.0,0.616100013256073,1.6938591003417969,10000.0,75247.45334744453,84680.72242164612,75247.45334744453,9416.1791973114,7.950130701065063,0.0 -160500,2.3478143,1.3679652,,,,,,,,,,,,,, -160600,2.4530258,1.5316625,,,,,,,,,,,,,, -160700,2.606149,1.5203289,,,,,,,,,,,,,, -160800,2.4478872,2.1801217,,,,,,,,,,,,,, -160900,2.4073844,2.1857784,,,,,,,,,,,,,, -161000,2.8533278,1.9059868,,,,,,,,,,,,,, -161100,2.931777,1.6761302,,,,,,,,,,,,,, -161200,2.6564732,1.575817,,,,,,,,,,,,,, -161300,2.3216205,1.5802556,,,,,,,,,,,,,, -161367,,,0.8282226324081421,0.6513522267341614,0.7389799952507019,1.0421522855758667,50000.0,0.6173000335693359,1.6790910959243774,10000.0,75667.55098891258,85152.55523252487,75667.55098891258,9467.809502601624,8.004290342330933,0.0 -161400,2.884806,3.837165,,,,,,,,,,,,,, -161500,2.5705683,2.5018368,,,,,,,,,,,,,, -161600,2.6755333,2.3108604,,,,,,,,,,,,,, -161700,2.697876,1.4603163,,,,,,,,,,,,,, -161800,3.0240755,3.9833097,,,,,,,,,,,,,, -161900,2.4228847,2.6728363,,,,,,,,,,,,,, -162000,2.5476172,1.9087665,,,,,,,,,,,,,, -162100,2.5611253,1.477732,,,,,,,,,,,,,, -162200,2.7699833,2.5247583,,,,,,,,,,,,,, -162260,,,0.8219531178474426,0.7046661376953125,0.739039957523346,1.0686395168304443,50000.0,0.6164000034332275,1.6963045597076416,10000.0,76087.53114652634,85624.00301170349,76087.53114652634,9519.163709878922,8.066744565963745,0.0 -162300,2.5316598,1.5665929,,,,,,,,,,,,,, -162400,2.3785515,2.688004,,,,,,,,,,,,,, -162500,2.5794654,1.5945103,,,,,,,,,,,,,, -162600,2.5314193,1.2565832,,,,,,,,,,,,,, -162700,2.4671843,3.354564,,,,,,,,,,,,,, -162800,2.5860226,2.1072292,,,,,,,,,,,,,, -162900,2.7557406,1.5669909,,,,,,,,,,,,,, -163000,2.4432552,1.8893176,,,,,,,,,,,,,, -163100,2.8817356,1.5437325,,,,,,,,,,,,,, -163149,,,0.8304882645606995,0.668041467666626,0.7397800087928772,1.0610032081604004,50000.0,0.6152999997138977,1.6925363540649414,10000.0,76507.66302323341,86099.32566308975,76507.66302323341,9574.24474644661,8.126704692840576,0.0 -163200,3.2740664,1.4628197,,,,,,,,,,,,,, -163300,3.1409369,1.5154592,,,,,,,,,,,,,, -163400,2.762637,1.5356681,,,,,,,,,,,,,, -163500,2.875778,1.6016119,,,,,,,,,,,,,, -163600,2.771146,1.8909523,,,,,,,,,,,,,, -163700,3.056106,1.6006656,,,,,,,,,,,,,, -163800,2.7485335,3.261585,,,,,,,,,,,,,, -163900,2.863567,1.6482081,,,,,,,,,,,,,, -164000,3.353546,1.5628072,,,,,,,,,,,,,, -164040,,,0.8357031345367432,0.6520687937736511,0.7403199672698975,1.0528931617736816,50000.0,0.6208000183105469,1.6820333003997805,10000.0,76927.88128042221,86571.4026389122,76927.88128042221,9625.999377965927,8.179861307144165,0.0 -164100,2.9199858,1.5688834,,,,,,,,,,,,,, -164200,2.6340313,1.4620807,,,,,,,,,,,,,, -164300,2.7041605,1.5050765,,,,,,,,,,,,,, -164400,2.8826616,1.5338855,,,,,,,,,,,,,, -164500,2.872771,1.4872766,,,,,,,,,,,,,, -164600,2.647962,1.4850419,,,,,,,,,,,,,, -164700,2.5595596,1.7212489,,,,,,,,,,,,,, -164800,3.105859,1.5344689,,,,,,,,,,,,,, -164900,4.0795016,2.847735,,,,,,,,,,,,,, -164936,,,0.8275195360183716,0.6876667737960815,0.742859959602356,1.0580904483795166,50000.0,0.6177000403404236,1.687680959701538,10000.0,77348.1300997734,87044.53920483589,77348.1300997734,9678.77582669258,8.240126609802246,0.0 -165000,2.734139,1.4487802,,,,,,,,,,,,,, -165100,2.7881846,1.5835046,,,,,,,,,,,,,, -165200,2.6793187,2.6797452,,,,,,,,,,,,,, -165300,3.1160948,3.734395,,,,,,,,,,,,,, -165309,,,,,,,,,,,77520.39787626266,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 4e8369c5f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -171.2729423046112,0.0,61.41118311882019,1,0,61.41118311882019,30.214182,2472,0.9757885970791949,232.6841826438904,31.169794,1.120843137663203,30.090126,5348,0.9587360128213792 -274.23418140411377,0.0439934730529785,1501.6815497875214,1862,0,1501.6815497875214,5.7778125,2472,0.8994576808238377,1776.0325355529783,5.747149,0.9443883751107084,5.8983974,5348,0.8964635005841065 -398.6696357727051,0.0901834964752197,2942.039139032364,3751,0,2942.039139032364,2.7431552,2472,0.5617167347104585,3340.9508938789368,2.6828344,0.569682763764217,3.0748646,5348,0.6164013246183998 -526.5003283023834,0.1402344703674316,4382.566341638565,5646,0,4382.566341638565,0.7743883,2472,0.2493043283976194,4909.4394245147705,0.7053174,0.2440659740341888,1.0583082,5348,0.3103391679619993 -654.8250741958618,0.1919043064117431,5822.735749721527,7541,0,5822.735749721527,0.5533391,2472,0.185343164137875,6478.064398050308,0.50713336,0.1764339347987726,0.8168226,5348,0.2444944340925109 -782.5141160488129,0.2431385517120361,7262.69296336174,9419,0,7262.69296336174,0.47338814,2472,0.1583693863871793,8045.841583013535,0.42972216,0.1516395224945095,0.7233104,5348,0.2171814205856512 -911.7616634368896,0.2947139739990234,8702.930635690689,11299,0,8702.930635690689,0.4285467,2472,0.1473605102268803,9615.45825767517,0.4041753,0.1443203823932654,0.6717132,5348,0.2063585545053438 -1042.6672360897064,0.353374719619751,10143.44744372368,13172,0,10143.44744372368,0.3915885,2472,0.1337517518737432,11187.017598867416,0.34445938,0.1244815263863186,0.6240188,5348,0.1905442327929946 -1172.421515226364,0.408015489578247,11583.814504861832,15049,0,11583.814504861832,0.36472577,2472,0.1249162147340198,12757.273002386091,0.29312336,0.1087287760992225,0.5906545,5348,0.1808702704268322 -1301.5782821178436,0.4575080871582031,13024.158386707306,16926,0,13024.158386707306,0.3517897,2472,0.1186805597871346,14326.901668548584,0.27578232,0.1014794992791364,0.57158834,5348,0.1736003166726203 -1430.684591293335,0.510749340057373,14464.73980808258,18807,0,14464.73980808258,0.3391219,2472,0.1144354396441411,15896.72179198265,0.26276767,0.099878262641647,0.5514074,5348,0.1670641165509717 -1561.4255406856537,0.5602102279663086,15905.15338897705,20678,0,15905.15338897705,0.32171547,2472,0.1073466983527308,17468.002720832825,0.2670035,0.0969882863111134,0.5288567,5348,0.159601069735559 -1692.6593585014343,0.6153922080993652,17345.456032276154,22560,0,17345.456032276154,0.31339175,2472,0.1064936120082058,19039.6722035408,0.2625285,0.0952119661276281,0.5192188,5348,0.1560867760217036 -1823.7593805789948,0.6687924861907959,18785.56553816796,24436,0,18785.56553816796,0.30653575,2472,0.1035281213819998,20611.01500535012,0.2446675,0.0893456116884453,0.50914836,5348,0.1537793139403535 -1953.7842502594,0.7299957275390625,20225.87024140358,26310,0,20225.87024140358,0.29196113,2472,0.098531472792639,22181.484354496,0.22402444,0.0829937502639246,0.4923003,5348,0.1487106210838313 -2083.775384426117,0.7895841598510742,21666.047757864,28188,0,21666.047757864,0.28334022,2472,0.0969268580017468,23751.79039502144,0.21112266,0.0805446151438483,0.47902232,5348,0.1447329040230939 -2213.889081478119,0.8440456390380859,23105.96652579308,30062,0,23105.96652579308,0.2829486,2472,0.0935957589421729,25321.95567250252,0.21464497,0.0779781325750989,0.48762044,5348,0.1455149309209573 -2344.442668914795,0.900031566619873,24546.469081401825,31936,0,24546.469081401825,0.27928066,2472,0.0920114557309122,26893.146836042404,0.22176528,0.0809282371294851,0.46996224,5348,0.1390945866360292 -2474.185190677643,0.9610598087310792,25986.34561252594,33828,0,25986.34561252594,0.2653281,2472,0.08896471878618,28462.90608000756,0.20109509,0.0728218818333251,0.4561879,5348,0.1363719744730973 -2606.3973863124847,1.0138370990753174,27426.30551338196,35705,0,27426.30551338196,0.26111528,2472,0.0867304450267097,30035.2096157074,0.20980647,0.0764939461342072,0.45002848,5348,0.1360244069629357 -2736.506334066391,1.0663466453552246,28866.55081510544,37585,0,28866.55081510544,0.2542946,2472,0.0849024028598704,31605.695909023285,0.19981158,0.0712032567111054,0.4394024,5348,0.1306178012493121 -2867.031692266464,1.1251184940338137,30306.92392349243,39462,0,30306.92392349243,0.24808052,2472,0.0844352365283448,33176.73208355904,0.15873039,0.0605711181492984,0.43230936,5348,0.1297585371269683 -2997.738777399063,1.1801607608795166,31747.11150074005,41334,0,31747.11150074005,0.24189256,2472,0.0804541669205614,34747.760195970535,0.17622687,0.0637209351439251,0.4257966,5348,0.1274896936578584 -3127.2881696224213,1.2326452732086182,33187.53070783615,43207,0,33187.53070783615,0.23688465,2472,0.0798245079519834,36317.85943412781,0.21449152,0.0781664433752976,0.4130111,5348,0.1229423520665784 -3256.428151845932,1.2872276306152344,34627.65042424202,45073,0,34627.65042424202,0.23037176,2472,0.0759043730830946,37887.25235772133,0.21359752,0.0784236442263388,0.4034216,5348,0.1212141691688309 -3384.257961034775,1.3474400043487549,36067.74671292305,46943,0,36067.74671292305,0.22683704,2472,0.0778339731480917,39455.31745886803,0.25127593,0.0932313937912033,0.40328342,5348,0.1205769620668681 -3512.821098566056,1.4097182750701904,37508.35221242905,48821,0,37508.35221242905,0.22061086,2472,0.0718217455771535,41024.62561440468,0.21432593,0.0774040471242103,0.39701816,5348,0.11645442521023 -3641.26109457016,1.4670865535736084,38948.3796274662,50688,0,38948.3796274662,0.21984032,2472,0.0727560782402047,42593.2292907238,0.19890381,0.0744825359357776,0.39558187,5348,0.1149579539859235 -3772.089239597321,1.5247292518615725,40388.37187838554,52542,0,40388.37187838554,0.21027523,2472,0.0703796234233136,44164.18422412872,0.1622737,0.0626379840998814,0.37560198,5348,0.1109223090068258 -3901.5167338848114,1.5817480087280271,41828.6324532032,54436,0,41828.6324532032,0.20149006,2472,0.068429711778685,45734.00941133499,0.17281756,0.0661337011953705,0.36489117,5348,0.1076783455786516 -4032.1063737869263,1.6375198364257812,43269.12178993225,56319,0,43269.12178993225,0.19897437,2472,0.0670079012044766,47305.22345614433,0.14977348,0.0572549819618622,0.36254397,5348,0.1062880755380055 -4161.976230621338,1.697356939315796,44709.46357750893,58199,0,44709.46357750893,0.18987672,2472,0.0646720695468486,48875.57385182381,0.14436826,0.056059406155415,0.3545192,5348,0.1020496828446469 -4290.845764160156,1.7542970180511477,46149.45885229111,60080,0,46149.45885229111,0.18780683,2472,0.0621128105132736,50444.575095653534,0.1338687,0.0509979296066252,0.34373295,5348,0.0988926112940131 -4419.383289575577,1.8110344409942627,47589.61220765114,61946,0,47589.61220765114,0.18222867,2472,0.0603660146649604,52013.4002263546,0.13237758,0.0509769582088527,0.33442134,5348,0.0973189028452262 -4550.955982685089,1.872158527374268,49029.50516271591,63814,0,49029.50516271591,0.17482771,2472,0.0576239514147015,53585.0060338974,0.12851778,0.0490040187760703,0.32344717,5348,0.0930418915396275 -4680.961303472519,1.9368162155151367,50469.89919543266,65679,0,50469.89919543266,0.17353387,2472,0.0559584018849145,55155.54966759682,0.10282077,0.0405429641087721,0.31785193,5348,0.090879249254178 -4809.875825405121,1.9941620826721191,51910.30626130104,67553,0,51910.30626130104,0.16759367,2472,0.0542725407754961,56725.00738286972,0.098737516,0.037794943179225,0.3111092,5348,0.0892283035809108 -4938.435527801514,2.053472757339477,53350.63709568977,69429,0,53350.63709568977,0.16458964,2472,0.0534803891698657,58294.03611445427,0.08786926,0.0342476454046333,0.30810562,5348,0.0879442347239252 -5067.244613170624,2.112919569015503,54790.89835476875,71305,0,54790.89835476875,0.16141044,2472,0.0517539049011841,59863.24539685249,0.08843144,0.0338653017302732,0.30182338,5348,0.0861291599486372 -5196.659123182297,2.1780107021331787,56231.20172023773,73177,0,56231.20172023773,0.16003639,2472,0.0511242459326061,61433.10589194298,0.08011135,0.030846805751213,0.30042586,5348,0.084931982969192 -5326.510108470917,2.247305393218994,57671.870633125305,75065,0,57671.870633125305,0.15909529,2472,0.0502508480084496,63003.7757563591,0.07875839,0.0295318951108552,0.2965718,5348,0.0830493256224837 -5456.7151482105255,2.3087172508239746,59112.54464316368,76942,0,59112.54464316368,0.15672454,2472,0.0505352101232912,64574.79559850693,0.06596456,0.0250914497973121,0.2928794,5348,0.0825665929694816 -5584.824814081192,2.3721067905426025,60552.484139442444,78817,0,60552.484139442444,0.15655187,2472,0.05000710905287104,66142.98616504669,0.06438663,0.02406642764495647,0.29217574,5348,0.0821224789287197 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index 104eb3b96..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,840 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,66.69013,31.531,,,,,,,,,,,,,, -1,,,31.169794,1.120843137663203,30.090126,0.9587360128213792,5348.0,30.214182,0.9757885970791949,2472.0,61.41118311882019,232.6841826438904,61.41118311882019,171.2729423046112,0.0,0.0 -100,9.565185,6.9463134,,,,,,,,,,,,,, -200,1.9593261,6.087654,,,,,,,,,,,,,, -300,0.6718752,5.8619585,,,,,,,,,,,,,, -400,0.38509175,5.810034,,,,,,,,,,,,,, -500,0.56165564,5.8057303,,,,,,,,,,,,,, -600,0.55292475,5.8009467,,,,,,,,,,,,,, -700,0.47328255,5.8037624,,,,,,,,,,,,,, -800,0.27119768,5.795229,,,,,,,,,,,,,, -900,0.28185552,5.816907,,,,,,,,,,,,,, -1000,0.28820735,5.815549,,,,,,,,,,,,,, -1100,0.31184426,5.8074646,,,,,,,,,,,,,, -1200,0.31463575,5.762359,,,,,,,,,,,,,, -1300,0.42102736,5.7439528,,,,,,,,,,,,,, -1400,0.35533693,5.6535335,,,,,,,,,,,,,, -1500,3.3306751,5.564676,,,,,,,,,,,,,, -1600,1.0550107,5.4318275,,,,,,,,,,,,,, -1700,1.8448768,5.1779666,,,,,,,,,,,,,, -1800,1.3177139,4.656218,,,,,,,,,,,,,, -1862,,,5.747149,0.9443883751107084,5.8983974,0.8964635005841065,5348.0,5.7778125,0.8994576808238377,2472.0,1501.6815497875214,1776.0325355529783,1501.6815497875214,274.23418140411377,0.0439934730529785,0.0 -1900,0.8607091,4.1603136,,,,,,,,,,,,,, -2000,1.03487,3.8494015,,,,,,,,,,,,,, -2100,1.7096262,3.6005912,,,,,,,,,,,,,, -2200,0.8949761,3.4151092,,,,,,,,,,,,,, -2300,1.3796263,3.2684035,,,,,,,,,,,,,, -2400,1.1138763,3.1437328,,,,,,,,,,,,,, -2500,1.2749282,3.089414,,,,,,,,,,,,,, -2600,1.23498,2.973093,,,,,,,,,,,,,, -2700,1.1549097,2.8670259,,,,,,,,,,,,,, -2800,0.9939264,2.8493512,,,,,,,,,,,,,, -2900,1.1062951,2.756405,,,,,,,,,,,,,, -3000,1.1371177,2.6627083,,,,,,,,,,,,,, -3100,1.4020548,2.587855,,,,,,,,,,,,,, -3200,1.098178,2.5496776,,,,,,,,,,,,,, -3300,1.0725197,2.5222704,,,,,,,,,,,,,, -3400,0.99578625,2.5107338,,,,,,,,,,,,,, -3500,1.2155474,2.5608363,,,,,,,,,,,,,, -3600,1.0691515,2.4582953,,,,,,,,,,,,,, -3700,0.9240093,2.453467,,,,,,,,,,,,,, -3751,,,2.6828344,0.569682763764217,3.0748646,0.6164013246183998,5348.0,2.7431552,0.5617167347104585,2472.0,2942.039139032364,3340.9508938789368,2942.039139032364,398.6696357727051,0.0901834964752197,0.0 -3800,1.0100827,2.40849,,,,,,,,,,,,,, -3900,0.9416064,2.391762,,,,,,,,,,,,,, -4000,1.0081433,2.3183806,,,,,,,,,,,,,, -4100,1.022526,2.318294,,,,,,,,,,,,,, -4200,1.1654165,2.2687285,,,,,,,,,,,,,, -4300,0.8549441,2.244625,,,,,,,,,,,,,, -4400,1.1038784,2.227096,,,,,,,,,,,,,, -4500,0.852466,2.2119825,,,,,,,,,,,,,, -4600,0.9825778,2.1164784,,,,,,,,,,,,,, -4700,1.0133792,2.159534,,,,,,,,,,,,,, -4800,0.9873033,2.1141934,,,,,,,,,,,,,, -4900,1.0554085,2.0910184,,,,,,,,,,,,,, -5000,0.93831736,2.0768144,,,,,,,,,,,,,, -5100,1.2057954,2.0619051,,,,,,,,,,,,,, -5200,1.0904727,2.0259063,,,,,,,,,,,,,, -5300,0.8407374,1.9956805,,,,,,,,,,,,,, -5400,0.8009201,1.9859077,,,,,,,,,,,,,, -5500,0.80352914,1.9585952,,,,,,,,,,,,,, -5600,0.85937536,1.964231,,,,,,,,,,,,,, -5646,,,0.7053174,0.2440659740341888,1.0583082,0.3103391679619993,5348.0,0.7743883,0.2493043283976194,2472.0,4382.566341638565,4909.4394245147705,4382.566341638565,526.5003283023834,0.1402344703674316,0.0 -5700,0.87035775,1.9809304,,,,,,,,,,,,,, -5800,0.787232,1.8711938,,,,,,,,,,,,,, -5900,0.7396848,1.9024634,,,,,,,,,,,,,, -6000,0.85913306,1.916235,,,,,,,,,,,,,, -6100,0.7719303,1.8576413,,,,,,,,,,,,,, -6200,0.9366944,1.8264126,,,,,,,,,,,,,, -6300,0.7761312,1.8683023,,,,,,,,,,,,,, -6400,0.8374709,1.8117625,,,,,,,,,,,,,, -6500,0.8544618,1.8106728,,,,,,,,,,,,,, -6600,0.778047,1.8085423,,,,,,,,,,,,,, -6700,0.76687646,1.8266183,,,,,,,,,,,,,, -6800,0.88767594,1.7801664,,,,,,,,,,,,,, -6900,0.7998025,1.7990017,,,,,,,,,,,,,, -7000,0.6995809,1.8089582,,,,,,,,,,,,,, -7100,0.8091625,1.7821612,,,,,,,,,,,,,, -7200,0.8105731,1.7243903,,,,,,,,,,,,,, -7300,0.81608874,1.7461658,,,,,,,,,,,,,, -7400,0.712658,1.7411777,,,,,,,,,,,,,, -7500,0.8408201,1.7457056,,,,,,,,,,,,,, -7541,,,0.50713336,0.1764339347987726,0.8168226,0.2444944340925109,5348.0,0.5533391,0.185343164137875,2472.0,5822.735749721527,6478.064398050308,5822.735749721527,654.8250741958618,0.1919043064117431,0.0 -7600,0.7488539,1.7029226,,,,,,,,,,,,,, -7700,0.7948956,1.693045,,,,,,,,,,,,,, -7800,0.9873561,1.7484093,,,,,,,,,,,,,, -7900,0.72148806,1.7535887,,,,,,,,,,,,,, -8000,0.75510865,1.7729282,,,,,,,,,,,,,, -8100,0.76202077,1.6695292,,,,,,,,,,,,,, -8200,0.7260573,1.7304232,,,,,,,,,,,,,, -8300,0.8241089,1.7542726,,,,,,,,,,,,,, -8400,0.71107626,1.6714071,,,,,,,,,,,,,, -8500,0.72129554,1.7649757,,,,,,,,,,,,,, -8600,0.7782079,1.7236199,,,,,,,,,,,,,, -8700,0.88636726,1.7851588,,,,,,,,,,,,,, -8800,0.8885135,1.665623,,,,,,,,,,,,,, -8900,0.8816513,1.6452756,,,,,,,,,,,,,, -9000,0.7635385,1.6836286,,,,,,,,,,,,,, -9100,0.67089385,1.7557575,,,,,,,,,,,,,, -9200,0.701615,1.6851546,,,,,,,,,,,,,, -9300,0.70184636,1.6640344,,,,,,,,,,,,,, -9400,0.7046346,1.6038537,,,,,,,,,,,,,, -9419,,,0.42972216,0.1516395224945095,0.7233104,0.2171814205856512,5348.0,0.47338814,0.1583693863871793,2472.0,7262.69296336174,8045.841583013535,7262.69296336174,782.5141160488129,0.2431385517120361,0.0 -9500,0.699842,1.5961632,,,,,,,,,,,,,, -9600,0.6722065,1.6244013,,,,,,,,,,,,,, -9700,0.7876577,1.6474019,,,,,,,,,,,,,, -9800,0.61563605,1.5779344,,,,,,,,,,,,,, -9900,0.68129385,1.592373,,,,,,,,,,,,,, -10000,0.8358457,1.615575,,,,,,,,,,,,,, -10100,0.67894334,1.6222497,,,,,,,,,,,,,, -10200,0.7195215,1.5980628,,,,,,,,,,,,,, -10300,0.6681297,1.4959271,,,,,,,,,,,,,, -10400,0.7610432,1.6392272,,,,,,,,,,,,,, -10500,0.7508427,1.5749208,,,,,,,,,,,,,, -10600,0.68081695,1.6716231,,,,,,,,,,,,,, -10700,0.73557436,1.5796406,,,,,,,,,,,,,, -10800,0.61239374,1.5382931,,,,,,,,,,,,,, -10900,0.71644944,1.5807471,,,,,,,,,,,,,, -11000,0.7797974,1.637205,,,,,,,,,,,,,, -11100,0.60893565,1.5938857,,,,,,,,,,,,,, -11200,0.6442448,1.5255243,,,,,,,,,,,,,, -11299,,,0.4041753,0.1443203823932654,0.6717132,0.2063585545053438,5348.0,0.4285467,0.1473605102268803,2472.0,8702.930635690689,9615.45825767517,8702.930635690689,911.7616634368896,0.2947139739990234,0.0 -11300,0.74941033,1.5731817,,,,,,,,,,,,,, -11400,0.67863953,1.5582081,,,,,,,,,,,,,, -11500,0.6946114,1.5257341,,,,,,,,,,,,,, -11600,0.73116016,1.5233706,,,,,,,,,,,,,, -11700,0.6046243,1.5662211,,,,,,,,,,,,,, -11800,0.7530354,1.4854221,,,,,,,,,,,,,, -11900,0.67727584,1.5232136,,,,,,,,,,,,,, -12000,0.70927274,1.6248288,,,,,,,,,,,,,, -12100,0.69838595,1.5822418,,,,,,,,,,,,,, -12200,0.8201416,1.5193659,,,,,,,,,,,,,, -12300,0.67760724,1.5809399,,,,,,,,,,,,,, -12400,0.59253514,1.4925641,,,,,,,,,,,,,, -12500,0.65298206,1.4876981,,,,,,,,,,,,,, -12600,0.6504046,1.4763255,,,,,,,,,,,,,, -12700,0.6396257,1.5240773,,,,,,,,,,,,,, -12800,0.6330907,1.4624283,,,,,,,,,,,,,, -12900,0.750372,1.5601164,,,,,,,,,,,,,, -13000,0.8579196,1.5060463,,,,,,,,,,,,,, -13100,0.8211452,1.5016975,,,,,,,,,,,,,, -13172,,,0.34445938,0.1244815263863186,0.6240188,0.1905442327929946,5348.0,0.3915885,0.1337517518737432,2472.0,10143.44744372368,11187.017598867416,10143.44744372368,1042.6672360897064,0.353374719619751,0.0 -13200,0.6836461,1.4802884,,,,,,,,,,,,,, -13300,0.7150534,1.4017311,,,,,,,,,,,,,, -13400,0.9778839,1.5150985,,,,,,,,,,,,,, -13500,0.76073766,1.4785011,,,,,,,,,,,,,, -13600,0.79910046,1.4654114,,,,,,,,,,,,,, -13700,0.6437007,1.5482904,,,,,,,,,,,,,, -13800,0.7863793,1.5225899,,,,,,,,,,,,,, -13900,0.64525384,1.4721658,,,,,,,,,,,,,, -14000,0.71333224,1.5211916,,,,,,,,,,,,,, -14100,0.6687904,1.4728961,,,,,,,,,,,,,, -14200,0.6915798,1.4989411,,,,,,,,,,,,,, -14300,0.61867607,1.5304865,,,,,,,,,,,,,, -14400,0.661549,1.4263613,,,,,,,,,,,,,, -14500,0.65322584,1.5005324,,,,,,,,,,,,,, -14600,0.5709814,1.4825426,,,,,,,,,,,,,, -14700,0.66027826,1.4826376,,,,,,,,,,,,,, -14800,0.650235,1.4859744,,,,,,,,,,,,,, -14900,0.6738036,1.4553521,,,,,,,,,,,,,, -15000,0.6356708,1.4286357,,,,,,,,,,,,,, -15049,,,0.29312336,0.1087287760992225,0.5906545,0.1808702704268322,5348.0,0.36472577,0.1249162147340198,2472.0,11583.814504861832,12757.273002386091,11583.814504861832,1172.421515226364,0.408015489578247,0.0 -15100,0.7589359,1.4496739,,,,,,,,,,,,,, -15200,0.69625115,1.5065051,,,,,,,,,,,,,, -15300,0.6419396,1.4509568,,,,,,,,,,,,,, -15400,0.61288536,1.5101491,,,,,,,,,,,,,, -15500,0.6513035,1.4463104,,,,,,,,,,,,,, -15600,0.7202708,1.3965708,,,,,,,,,,,,,, -15700,0.64773417,1.4548718,,,,,,,,,,,,,, -15800,0.62324613,1.4547745,,,,,,,,,,,,,, -15900,0.71662736,1.469502,,,,,,,,,,,,,, -16000,0.7207353,1.5004574,,,,,,,,,,,,,, -16100,0.7482611,1.423389,,,,,,,,,,,,,, -16200,0.73431134,1.3828017,,,,,,,,,,,,,, -16300,0.63413215,1.475272,,,,,,,,,,,,,, -16400,0.6826658,1.4657258,,,,,,,,,,,,,, -16500,0.7023965,1.4155327,,,,,,,,,,,,,, -16600,0.61838543,1.4071107,,,,,,,,,,,,,, -16700,0.7166536,1.3869901,,,,,,,,,,,,,, -16800,0.64497405,1.4282213,,,,,,,,,,,,,, -16900,0.74096614,1.4672273,,,,,,,,,,,,,, -16926,,,0.27578232,0.1014794992791364,0.57158834,0.1736003166726203,5348.0,0.3517897,0.1186805597871346,2472.0,13024.158386707306,14326.901668548584,13024.158386707306,1301.5782821178436,0.4575080871582031,0.0 -17000,0.70082814,1.4276974,,,,,,,,,,,,,, -17100,0.72258854,1.4727057,,,,,,,,,,,,,, -17200,0.87013435,1.4191176,,,,,,,,,,,,,, -17300,0.7821442,1.4194188,,,,,,,,,,,,,, -17400,0.71627957,1.445313,,,,,,,,,,,,,, -17500,0.64658654,1.4999769,,,,,,,,,,,,,, -17600,0.6609753,1.3685737,,,,,,,,,,,,,, -17700,0.752537,1.4245371,,,,,,,,,,,,,, -17800,0.61809176,1.4184912,,,,,,,,,,,,,, -17900,0.6693435,1.4431759,,,,,,,,,,,,,, -18000,0.7239468,1.4564582,,,,,,,,,,,,,, -18100,0.7145688,1.3792908,,,,,,,,,,,,,, -18200,0.7208503,1.4544305,,,,,,,,,,,,,, -18300,0.6253583,1.3986998,,,,,,,,,,,,,, -18400,0.65392846,1.3705212,,,,,,,,,,,,,, -18500,0.6494299,1.4424317,,,,,,,,,,,,,, -18600,0.68398005,1.4565251,,,,,,,,,,,,,, -18700,0.7005792,1.3764288,,,,,,,,,,,,,, -18800,0.6307731,1.4195939,,,,,,,,,,,,,, -18807,,,0.26276767,0.099878262641647,0.5514074,0.1670641165509717,5348.0,0.3391219,0.1144354396441411,2472.0,14464.73980808258,15896.72179198265,14464.73980808258,1430.684591293335,0.510749340057373,0.0 -18900,0.7063847,1.3616959,,,,,,,,,,,,,, -19000,0.6692885,1.4224074,,,,,,,,,,,,,, -19100,0.6435986,1.3749803,,,,,,,,,,,,,, -19200,0.6454998,1.3991154,,,,,,,,,,,,,, -19300,0.68132156,1.42962,,,,,,,,,,,,,, -19400,0.6553459,1.3423587,,,,,,,,,,,,,, -19500,0.78311664,1.4537203,,,,,,,,,,,,,, -19600,0.67981786,1.3166116,,,,,,,,,,,,,, -19700,0.71273285,1.3784205,,,,,,,,,,,,,, -19800,0.7684412,1.3825611,,,,,,,,,,,,,, -19900,0.6583294,1.3511223,,,,,,,,,,,,,, -20000,0.6702344,1.3254366,,,,,,,,,,,,,, -20100,0.64481294,1.4119983,,,,,,,,,,,,,, -20200,0.6818333,1.323582,,,,,,,,,,,,,, -20300,0.6752533,1.3990983,,,,,,,,,,,,,, -20400,0.73301196,1.3788683,,,,,,,,,,,,,, -20500,0.7162825,1.403455,,,,,,,,,,,,,, -20600,0.59727716,1.3224484,,,,,,,,,,,,,, -20678,,,0.2670035,0.0969882863111134,0.5288567,0.159601069735559,5348.0,0.32171547,0.1073466983527308,2472.0,15905.15338897705,17468.002720832825,15905.15338897705,1561.4255406856537,0.5602102279663086,0.0 -20700,0.702958,1.3967649,,,,,,,,,,,,,, -20800,0.6321653,1.3395671,,,,,,,,,,,,,, -20900,0.64468914,1.3645014,,,,,,,,,,,,,, -21000,0.6490954,1.4327943,,,,,,,,,,,,,, -21100,0.5970098,1.3582358,,,,,,,,,,,,,, -21200,0.73320353,1.3746454,,,,,,,,,,,,,, -21300,0.8229604,1.334569,,,,,,,,,,,,,, -21400,0.61493194,1.3764006,,,,,,,,,,,,,, -21500,0.6235714,1.3499659,,,,,,,,,,,,,, -21600,0.62455696,1.3202465,,,,,,,,,,,,,, -21700,0.6829005,1.3464314,,,,,,,,,,,,,, -21800,0.62686753,1.3567951,,,,,,,,,,,,,, -21900,0.7247318,1.3747655,,,,,,,,,,,,,, -22000,0.81782115,1.3934625,,,,,,,,,,,,,, -22100,0.7397035,1.3601991,,,,,,,,,,,,,, -22200,0.9124346,1.3734306,,,,,,,,,,,,,, -22300,0.6823543,1.3674015,,,,,,,,,,,,,, -22400,0.6923336,1.3765016,,,,,,,,,,,,,, -22500,0.56230056,1.3508894,,,,,,,,,,,,,, -22560,,,0.2625285,0.0952119661276281,0.5192188,0.1560867760217036,5348.0,0.31339175,0.1064936120082058,2472.0,17345.456032276154,19039.6722035408,17345.456032276154,1692.6593585014343,0.6153922080993652,0.0 -22600,0.5958206,1.3189534,,,,,,,,,,,,,, -22700,0.7233642,1.3177967,,,,,,,,,,,,,, -22800,0.6923138,1.3893943,,,,,,,,,,,,,, -22900,0.6532291,1.3393391,,,,,,,,,,,,,, -23000,0.7389248,1.3521433,,,,,,,,,,,,,, -23100,0.6303171,1.3387454,,,,,,,,,,,,,, -23200,0.60246104,1.3348571,,,,,,,,,,,,,, -23300,0.7976467,1.3648326,,,,,,,,,,,,,, -23400,0.59563136,1.2854083,,,,,,,,,,,,,, -23500,0.71871674,1.3358126,,,,,,,,,,,,,, -23600,0.6832555,1.2874614,,,,,,,,,,,,,, -23700,0.60424167,1.3273208,,,,,,,,,,,,,, -23800,0.6990731,1.3609582,,,,,,,,,,,,,, -23900,0.6906915,1.3225162,,,,,,,,,,,,,, -24000,0.7154671,1.3676355,,,,,,,,,,,,,, -24100,0.62640154,1.3678324,,,,,,,,,,,,,, -24200,0.7433263,1.3733547,,,,,,,,,,,,,, -24300,0.8229473,1.3556553,,,,,,,,,,,,,, -24400,0.6474116,1.3471171,,,,,,,,,,,,,, -24436,,,0.2446675,0.0893456116884453,0.50914836,0.1537793139403535,5348.0,0.30653575,0.1035281213819998,2472.0,18785.56553816796,20611.01500535012,18785.56553816796,1823.7593805789948,0.6687924861907959,0.0 -24500,0.7262468,1.37673,,,,,,,,,,,,,, -24600,0.62904906,1.3291496,,,,,,,,,,,,,, -24700,0.8196431,1.3524213,,,,,,,,,,,,,, -24800,0.9686103,1.3068719,,,,,,,,,,,,,, -24900,0.6828155,1.3421267,,,,,,,,,,,,,, -25000,0.6532573,1.3285676,,,,,,,,,,,,,, -25100,0.92418295,1.3734734,,,,,,,,,,,,,, -25200,0.7054985,1.3793826,,,,,,,,,,,,,, -25300,0.69424254,1.3595392,,,,,,,,,,,,,, -25400,0.7977364,1.3101265,,,,,,,,,,,,,, -25500,0.60323536,1.2829785,,,,,,,,,,,,,, -25600,0.93729573,1.353108,,,,,,,,,,,,,, -25700,0.7671907,1.2577407,,,,,,,,,,,,,, -25800,0.59884125,1.2929592,,,,,,,,,,,,,, -25900,0.7023635,1.3273178,,,,,,,,,,,,,, -26000,0.71567595,1.3507727,,,,,,,,,,,,,, -26100,0.66280854,1.3350275,,,,,,,,,,,,,, -26200,0.780298,1.2936442,,,,,,,,,,,,,, -26300,0.66410714,1.31515,,,,,,,,,,,,,, -26310,,,0.22402444,0.0829937502639246,0.4923003,0.1487106210838313,5348.0,0.29196113,0.098531472792639,2472.0,20225.87024140358,22181.484354496,20225.87024140358,1953.7842502594,0.7299957275390625,0.0 -26400,0.67734087,1.3854821,,,,,,,,,,,,,, -26500,0.66863173,1.2869501,,,,,,,,,,,,,, -26600,0.7153802,1.2742088,,,,,,,,,,,,,, -26700,0.63437086,1.3633304,,,,,,,,,,,,,, -26800,0.59571487,1.2451917,,,,,,,,,,,,,, -26900,0.7361804,1.3405478,,,,,,,,,,,,,, -27000,0.7243287,1.2255502,,,,,,,,,,,,,, -27100,0.7840028,1.3063864,,,,,,,,,,,,,, -27200,0.68937683,1.3105227,,,,,,,,,,,,,, -27300,0.5899553,1.3185202,,,,,,,,,,,,,, -27400,0.6941704,1.3104608,,,,,,,,,,,,,, -27500,0.5759943,1.2421982,,,,,,,,,,,,,, -27600,0.7174439,1.3599715,,,,,,,,,,,,,, -27700,0.67054826,1.302405,,,,,,,,,,,,,, -27800,0.746701,1.3378111,,,,,,,,,,,,,, -27900,0.6854886,1.3194219,,,,,,,,,,,,,, -28000,0.68435067,1.2662263,,,,,,,,,,,,,, -28100,0.6753524,1.312916,,,,,,,,,,,,,, -28188,,,0.21112266,0.0805446151438483,0.47902232,0.1447329040230939,5348.0,0.28334022,0.0969268580017468,2472.0,21666.047757864,23751.79039502144,21666.047757864,2083.775384426117,0.7895841598510742,0.0 -28200,0.6503355,1.2333318,,,,,,,,,,,,,, -28300,0.75858724,1.272951,,,,,,,,,,,,,, -28400,0.80133975,1.3181745,,,,,,,,,,,,,, -28500,0.67057765,1.3017069,,,,,,,,,,,,,, -28600,0.72070366,1.3050785,,,,,,,,,,,,,, -28700,0.6632138,1.2938523,,,,,,,,,,,,,, -28800,0.6308821,1.2533659,,,,,,,,,,,,,, -28900,0.72883713,1.3310132,,,,,,,,,,,,,, -29000,0.76158196,1.3091564,,,,,,,,,,,,,, -29100,0.67649126,1.2875532,,,,,,,,,,,,,, -29200,0.7853635,1.3057436,,,,,,,,,,,,,, -29300,0.7698339,1.3115561,,,,,,,,,,,,,, -29400,0.65913254,1.2909665,,,,,,,,,,,,,, -29500,0.63907385,1.3187699,,,,,,,,,,,,,, -29600,0.6341792,1.306853,,,,,,,,,,,,,, -29700,0.72303706,1.3073486,,,,,,,,,,,,,, -29800,0.78920585,1.2869258,,,,,,,,,,,,,, -29900,0.7928069,1.2562153,,,,,,,,,,,,,, -30000,0.81102484,1.2619985,,,,,,,,,,,,,, -30062,,,0.21464497,0.0779781325750989,0.48762044,0.1455149309209573,5348.0,0.2829486,0.0935957589421729,2472.0,23105.96652579308,25321.95567250252,23105.96652579308,2213.889081478119,0.8440456390380859,0.0 -30100,0.8217154,1.3249669,,,,,,,,,,,,,, -30200,0.5939515,1.2222052,,,,,,,,,,,,,, -30300,0.8331717,1.2976867,,,,,,,,,,,,,, -30400,0.7875536,1.3215679,,,,,,,,,,,,,, -30500,0.6879002,1.2867556,,,,,,,,,,,,,, -30600,0.78423005,1.3008512,,,,,,,,,,,,,, -30700,0.6442821,1.2931353,,,,,,,,,,,,,, -30800,0.6133264,1.3258315,,,,,,,,,,,,,, -30900,0.67525166,1.2486526,,,,,,,,,,,,,, -31000,0.65058106,1.2808471,,,,,,,,,,,,,, -31100,0.79439145,1.238514,,,,,,,,,,,,,, -31200,0.60602415,1.2879466,,,,,,,,,,,,,, -31300,0.67958677,1.2335844,,,,,,,,,,,,,, -31400,0.65599644,1.2545495,,,,,,,,,,,,,, -31500,0.68962985,1.254968,,,,,,,,,,,,,, -31600,0.83083165,1.2597934,,,,,,,,,,,,,, -31700,0.75443715,1.275175,,,,,,,,,,,,,, -31800,0.70660263,1.2777504,,,,,,,,,,,,,, -31900,0.65063447,1.2734509,,,,,,,,,,,,,, -31936,,,0.22176528,0.0809282371294851,0.46996224,0.1390945866360292,5348.0,0.27928066,0.0920114557309122,2472.0,24546.469081401825,26893.146836042404,24546.469081401825,2344.442668914795,0.900031566619873,0.0 -32000,0.67732376,1.2360806,,,,,,,,,,,,,, -32100,0.7703466,1.2732232,,,,,,,,,,,,,, -32200,0.85139334,1.2489053,,,,,,,,,,,,,, -32300,0.79881126,1.2990429,,,,,,,,,,,,,, -32400,0.7318138,1.2959148,,,,,,,,,,,,,, -32500,0.5417265,1.2589526,,,,,,,,,,,,,, -32600,0.6350395,1.1974016,,,,,,,,,,,,,, -32700,0.7526317,1.2332414,,,,,,,,,,,,,, -32800,0.6696928,1.2271063,,,,,,,,,,,,,, -32900,0.648651,1.3110672,,,,,,,,,,,,,, -33000,0.62349206,1.2299088,,,,,,,,,,,,,, -33100,0.6345776,1.1789224,,,,,,,,,,,,,, -33200,0.87880725,1.309445,,,,,,,,,,,,,, -33300,0.740194,1.2963301,,,,,,,,,,,,,, -33400,0.65681016,1.2711915,,,,,,,,,,,,,, -33500,0.68684953,1.2947642,,,,,,,,,,,,,, -33600,0.6445831,1.2494731,,,,,,,,,,,,,, -33700,0.7354525,1.2318631,,,,,,,,,,,,,, -33800,0.7868837,1.2733513,,,,,,,,,,,,,, -33828,,,0.20109509,0.0728218818333251,0.4561879,0.1363719744730973,5348.0,0.2653281,0.08896471878618,2472.0,25986.34561252594,28462.90608000756,25986.34561252594,2474.185190677643,0.9610598087310792,0.0 -33900,0.8060414,1.3010803,,,,,,,,,,,,,, -34000,0.6541802,1.2813288,,,,,,,,,,,,,, -34100,0.77036476,1.292404,,,,,,,,,,,,,, -34200,0.66259885,1.239016,,,,,,,,,,,,,, -34300,0.69603235,1.2127116,,,,,,,,,,,,,, -34400,0.60892576,1.2706604,,,,,,,,,,,,,, -34500,0.66709465,1.234734,,,,,,,,,,,,,, -34600,0.8850685,1.2656997,,,,,,,,,,,,,, -34700,0.7372185,1.2612338,,,,,,,,,,,,,, -34800,0.6865995,1.239886,,,,,,,,,,,,,, -34900,0.7666632,1.2474078,,,,,,,,,,,,,, -35000,0.6782942,1.244218,,,,,,,,,,,,,, -35100,0.6411277,1.2128004,,,,,,,,,,,,,, -35200,0.8489034,1.2073808,,,,,,,,,,,,,, -35300,0.73920685,1.2216568,,,,,,,,,,,,,, -35400,0.8151676,1.2038112,,,,,,,,,,,,,, -35500,0.8425366,1.1859413,,,,,,,,,,,,,, -35600,0.6607287,1.2385079,,,,,,,,,,,,,, -35700,0.7215768,1.245172,,,,,,,,,,,,,, -35705,,,0.20980647,0.0764939461342072,0.45002848,0.1360244069629357,5348.0,0.26111528,0.0867304450267097,2472.0,27426.30551338196,30035.2096157074,27426.30551338196,2606.3973863124847,1.0138370990753174,0.0 -35800,0.69833297,1.3173654,,,,,,,,,,,,,, -35900,0.7211272,1.2642008,,,,,,,,,,,,,, -36000,0.7322747,1.2012633,,,,,,,,,,,,,, -36100,0.76165193,1.1708101,,,,,,,,,,,,,, -36200,0.78413004,1.2113241,,,,,,,,,,,,,, -36300,0.74546504,1.2528008,,,,,,,,,,,,,, -36400,0.6453593,1.1747024,,,,,,,,,,,,,, -36500,0.7050314,1.2646465,,,,,,,,,,,,,, -36600,0.68679434,1.2157154,,,,,,,,,,,,,, -36700,0.7971533,1.2581421,,,,,,,,,,,,,, -36800,0.7490935,1.1783674,,,,,,,,,,,,,, -36900,0.7949636,1.2659115,,,,,,,,,,,,,, -37000,0.77302134,1.2367941,,,,,,,,,,,,,, -37100,0.7221242,1.1192364,,,,,,,,,,,,,, -37200,0.81518024,1.2206075,,,,,,,,,,,,,, -37300,0.7382368,1.2048185,,,,,,,,,,,,,, -37400,0.7471926,1.2252218,,,,,,,,,,,,,, -37500,0.70860857,1.1917566,,,,,,,,,,,,,, -37585,,,0.19981158,0.0712032567111054,0.4394024,0.1306178012493121,5348.0,0.2542946,0.0849024028598704,2472.0,28866.55081510544,31605.695909023285,28866.55081510544,2736.506334066391,1.0663466453552246,0.0 -37600,0.8211213,1.2319115,,,,,,,,,,,,,, -37700,0.87982845,1.243791,,,,,,,,,,,,,, -37800,0.71876144,1.242815,,,,,,,,,,,,,, -37900,0.7246659,1.2203689,,,,,,,,,,,,,, -38000,0.74392784,1.1974237,,,,,,,,,,,,,, -38100,0.7633817,1.2662363,,,,,,,,,,,,,, -38200,0.71309114,1.2082572,,,,,,,,,,,,,, -38300,0.85076034,1.1703582,,,,,,,,,,,,,, -38400,0.6691698,1.2226564,,,,,,,,,,,,,, -38500,0.69127864,1.2451243,,,,,,,,,,,,,, -38600,0.77535313,1.2198122,,,,,,,,,,,,,, -38700,0.6967082,1.2037222,,,,,,,,,,,,,, -38800,0.8224107,1.2514482,,,,,,,,,,,,,, -38900,0.7390426,1.2298537,,,,,,,,,,,,,, -39000,0.71071225,1.2693522,,,,,,,,,,,,,, -39100,0.7388311,1.2390423,,,,,,,,,,,,,, -39200,0.8147656,1.1669301,,,,,,,,,,,,,, -39300,0.7053178,1.2129296,,,,,,,,,,,,,, -39400,0.6741585,1.1779205,,,,,,,,,,,,,, -39462,,,0.15873039,0.0605711181492984,0.43230936,0.1297585371269683,5348.0,0.24808052,0.0844352365283448,2472.0,30306.92392349243,33176.73208355904,30306.92392349243,2867.031692266464,1.1251184940338137,0.0 -39500,0.8190526,1.216174,,,,,,,,,,,,,, -39600,0.7456697,1.2073752,,,,,,,,,,,,,, -39700,0.7646484,1.2278364,,,,,,,,,,,,,, -39800,0.7299668,1.1770625,,,,,,,,,,,,,, -39900,0.7578431,1.1530585,,,,,,,,,,,,,, -40000,0.71004397,1.2157387,,,,,,,,,,,,,, -40100,0.70685714,1.2230004,,,,,,,,,,,,,, -40200,0.769494,1.1540089,,,,,,,,,,,,,, -40300,0.6981467,1.2281793,,,,,,,,,,,,,, -40400,0.7372605,1.1607745,,,,,,,,,,,,,, -40500,0.96760046,1.2410352,,,,,,,,,,,,,, -40600,0.71238565,1.2104261,,,,,,,,,,,,,, -40700,0.74320817,1.2191632,,,,,,,,,,,,,, -40800,0.78962606,1.1744761,,,,,,,,,,,,,, -40900,0.8473352,1.2067918,,,,,,,,,,,,,, -41000,0.7104983,1.2066494,,,,,,,,,,,,,, -41100,0.7161423,1.1652825,,,,,,,,,,,,,, -41200,0.9554363,1.1706411,,,,,,,,,,,,,, -41300,0.7109407,1.1547921,,,,,,,,,,,,,, -41334,,,0.17622687,0.0637209351439251,0.4257966,0.1274896936578584,5348.0,0.24189256,0.0804541669205614,2472.0,31747.11150074005,34747.760195970535,31747.11150074005,2997.738777399063,1.1801607608795166,0.0 -41400,0.7308323,1.1950796,,,,,,,,,,,,,, -41500,0.885395,1.129426,,,,,,,,,,,,,, -41600,0.7040755,1.2472314,,,,,,,,,,,,,, -41700,0.892941,1.1874429,,,,,,,,,,,,,, -41800,0.7488028,1.1835659,,,,,,,,,,,,,, -41900,0.7868663,1.1913373,,,,,,,,,,,,,, -42000,0.83719593,1.2159866,,,,,,,,,,,,,, -42100,0.68060195,1.1844655,,,,,,,,,,,,,, -42200,0.81168413,1.1548822,,,,,,,,,,,,,, -42300,0.6631085,1.1611575,,,,,,,,,,,,,, -42400,0.7206839,1.1756723,,,,,,,,,,,,,, -42500,0.78715634,1.1543198,,,,,,,,,,,,,, -42600,0.66729414,1.1414576,,,,,,,,,,,,,, -42700,1.0059205,1.1345301,,,,,,,,,,,,,, -42800,0.85933995,1.133815,,,,,,,,,,,,,, -42900,0.7612028,1.1697856,,,,,,,,,,,,,, -43000,0.7913567,1.1562209,,,,,,,,,,,,,, -43100,0.7160363,1.1588275,,,,,,,,,,,,,, -43200,0.7355809,1.140032,,,,,,,,,,,,,, -43207,,,0.21449152,0.0781664433752976,0.4130111,0.1229423520665784,5348.0,0.23688465,0.0798245079519834,2472.0,33187.53070783615,36317.85943412781,33187.53070783615,3127.2881696224213,1.2326452732086182,0.0 -43300,0.7897317,1.1777588,,,,,,,,,,,,,, -43400,0.71580046,1.1668564,,,,,,,,,,,,,, -43500,0.8623807,1.1446186,,,,,,,,,,,,,, -43600,0.70284945,1.1604283,,,,,,,,,,,,,, -43700,0.6864968,1.1482667,,,,,,,,,,,,,, -43800,0.82133037,1.2091308,,,,,,,,,,,,,, -43900,0.8777453,1.1732602,,,,,,,,,,,,,, -44000,0.83179986,1.176868,,,,,,,,,,,,,, -44100,0.9184484,1.1987736,,,,,,,,,,,,,, -44200,0.67931414,1.1996901,,,,,,,,,,,,,, -44300,0.7394595,1.1420783,,,,,,,,,,,,,, -44400,0.75548375,1.1956527,,,,,,,,,,,,,, -44500,0.7951689,1.1617851,,,,,,,,,,,,,, -44600,0.7479014,1.1410025,,,,,,,,,,,,,, -44700,0.7463511,1.1248688,,,,,,,,,,,,,, -44800,0.74132454,1.1785717,,,,,,,,,,,,,, -44900,0.7668622,1.1684009,,,,,,,,,,,,,, -45000,0.7916334,1.143758,,,,,,,,,,,,,, -45073,,,0.21359752,0.0784236442263388,0.4034216,0.1212141691688309,5348.0,0.23037176,0.0759043730830946,2472.0,34627.65042424202,37887.25235772133,34627.65042424202,3256.428151845932,1.2872276306152344,0.0 -45100,0.7576583,1.1800025,,,,,,,,,,,,,, -45200,1.0419617,1.1629187,,,,,,,,,,,,,, -45300,0.75475,1.1406649,,,,,,,,,,,,,, -45400,0.7067734,1.1528269,,,,,,,,,,,,,, -45500,0.8128985,1.1782471,,,,,,,,,,,,,, -45600,0.96019,1.1997694,,,,,,,,,,,,,, -45700,0.945443,1.1424791,,,,,,,,,,,,,, -45800,0.92070687,1.1661605,,,,,,,,,,,,,, -45900,0.8869099,1.1449183,,,,,,,,,,,,,, -46000,0.7933,1.216534,,,,,,,,,,,,,, -46100,0.76337266,1.132599,,,,,,,,,,,,,, -46200,0.7436977,1.1384927,,,,,,,,,,,,,, -46300,0.7993358,1.137419,,,,,,,,,,,,,, -46400,0.87636256,1.1649909,,,,,,,,,,,,,, -46500,0.74665,1.1237311,,,,,,,,,,,,,, -46600,0.82737774,1.1671772,,,,,,,,,,,,,, -46700,0.87085694,1.1353253,,,,,,,,,,,,,, -46800,0.906715,1.1072066,,,,,,,,,,,,,, -46900,0.83770216,1.1376463,,,,,,,,,,,,,, -46943,,,0.25127593,0.0932313937912033,0.40328342,0.1205769620668681,5348.0,0.22683704,0.0778339731480917,2472.0,36067.74671292305,39455.31745886803,36067.74671292305,3384.257961034775,1.3474400043487549,0.0 -47000,0.8517697,1.1394936,,,,,,,,,,,,,, -47100,0.8477514,1.1837765,,,,,,,,,,,,,, -47200,0.8557897,1.1988072,,,,,,,,,,,,,, -47300,0.78790224,1.0928596,,,,,,,,,,,,,, -47400,0.9058788,1.1009572,,,,,,,,,,,,,, -47500,0.8138947,1.1365161,,,,,,,,,,,,,, -47600,0.84401643,1.1019349,,,,,,,,,,,,,, -47700,0.8387982,1.1507152,,,,,,,,,,,,,, -47800,0.802932,1.144057,,,,,,,,,,,,,, -47900,0.699764,1.1255163,,,,,,,,,,,,,, -48000,0.89818853,1.0807108,,,,,,,,,,,,,, -48100,0.8876634,1.1086836,,,,,,,,,,,,,, -48200,0.7497158,1.1459454,,,,,,,,,,,,,, -48300,0.8025434,1.133497,,,,,,,,,,,,,, -48400,0.7501647,1.1580242,,,,,,,,,,,,,, -48500,0.7964973,1.0690207,,,,,,,,,,,,,, -48600,0.83262163,1.087149,,,,,,,,,,,,,, -48700,0.7751343,1.0657403,,,,,,,,,,,,,, -48800,0.9323789,1.0984013,,,,,,,,,,,,,, -48821,,,0.21432593,0.0774040471242103,0.39701816,0.11645442521023,5348.0,0.22061086,0.0718217455771535,2472.0,37508.35221242905,41024.62561440468,37508.35221242905,3512.821098566056,1.4097182750701904,0.0 -48900,0.89985406,1.1184192,,,,,,,,,,,,,, -49000,0.88652563,1.1727062,,,,,,,,,,,,,, -49100,0.8018863,1.1604437,,,,,,,,,,,,,, -49200,0.76550984,1.0986073,,,,,,,,,,,,,, -49300,0.90829164,1.111413,,,,,,,,,,,,,, -49400,0.87162524,1.0960445,,,,,,,,,,,,,, -49500,0.95769787,1.1226267,,,,,,,,,,,,,, -49600,0.7840524,1.0685953,,,,,,,,,,,,,, -49700,1.0193299,1.1054304,,,,,,,,,,,,,, -49800,0.77978784,1.0838637,,,,,,,,,,,,,, -49900,0.8880828,1.152329,,,,,,,,,,,,,, -50000,0.825967,1.0810134,,,,,,,,,,,,,, -50100,0.9287582,1.1320852,,,,,,,,,,,,,, -50200,0.91043293,1.0802157,,,,,,,,,,,,,, -50300,0.8873805,1.0970354,,,,,,,,,,,,,, -50400,0.9300928,1.1321069,,,,,,,,,,,,,, -50500,0.7754534,1.0898179,,,,,,,,,,,,,, -50600,0.8312462,1.1233218,,,,,,,,,,,,,, -50688,,,0.19890381,0.0744825359357776,0.39558187,0.1149579539859235,5348.0,0.21984032,0.0727560782402047,2472.0,38948.3796274662,42593.2292907238,38948.3796274662,3641.26109457016,1.4670865535736084,0.0 -50700,0.90032494,1.0723305,,,,,,,,,,,,,, -50800,0.77855957,1.0874853,,,,,,,,,,,,,, -50900,0.8942855,1.0935731,,,,,,,,,,,,,, -51000,0.76227355,1.0851942,,,,,,,,,,,,,, -51100,0.72563165,1.0645808,,,,,,,,,,,,,, -51200,1.0281734,1.0728763,,,,,,,,,,,,,, -51300,0.96258634,1.1935953,,,,,,,,,,,,,, -51400,0.7692232,1.1235452,,,,,,,,,,,,,, -51500,0.8243583,1.0456157,,,,,,,,,,,,,, -51600,0.89674133,1.0564228,,,,,,,,,,,,,, -51700,0.78461486,1.1066496,,,,,,,,,,,,,, -51800,0.79390174,1.1149952,,,,,,,,,,,,,, -51900,0.91024196,1.1064937,,,,,,,,,,,,,, -52000,0.82340366,1.0844197,,,,,,,,,,,,,, -52100,0.71653396,1.0735364,,,,,,,,,,,,,, -52200,0.83603364,1.1010507,,,,,,,,,,,,,, -52300,0.8379352,1.07794,,,,,,,,,,,,,, -52400,0.7728855,1.114921,,,,,,,,,,,,,, -52500,0.952414,1.0510671,,,,,,,,,,,,,, -52542,,,0.1622737,0.0626379840998814,0.37560198,0.1109223090068258,5348.0,0.21027523,0.0703796234233136,2472.0,40388.37187838554,44164.18422412872,40388.37187838554,3772.089239597321,1.5247292518615725,0.0 -52600,0.9141734,1.0613796,,,,,,,,,,,,,, -52700,1.0910189,1.0818007,,,,,,,,,,,,,, -52800,0.94248945,1.0501871,,,,,,,,,,,,,, -52900,0.9977385,1.063,,,,,,,,,,,,,, -53000,0.99929106,1.0706439,,,,,,,,,,,,,, -53100,0.8410798,1.116001,,,,,,,,,,,,,, -53200,0.7762035,1.050589,,,,,,,,,,,,,, -53300,0.9283744,1.0525607,,,,,,,,,,,,,, -53400,1.0381701,1.0880805,,,,,,,,,,,,,, -53500,0.95254743,1.1574847,,,,,,,,,,,,,, -53600,0.9421183,1.0676008,,,,,,,,,,,,,, -53700,0.85811377,1.0625863,,,,,,,,,,,,,, -53800,1.0051677,1.0823,,,,,,,,,,,,,, -53900,0.878319,1.0305073,,,,,,,,,,,,,, -54000,0.8681572,1.0377111,,,,,,,,,,,,,, -54100,1.1250664,1.0267042,,,,,,,,,,,,,, -54200,0.82000333,1.0650426,,,,,,,,,,,,,, -54300,0.8223944,1.100435,,,,,,,,,,,,,, -54400,0.967315,1.0377449,,,,,,,,,,,,,, -54436,,,0.17281756,0.0661337011953705,0.36489117,0.1076783455786516,5348.0,0.20149006,0.068429711778685,2472.0,41828.6324532032,45734.00941133499,41828.6324532032,3901.5167338848114,1.5817480087280271,0.0 -54500,1.0499862,1.1032228,,,,,,,,,,,,,, -54600,0.9229683,1.1054246,,,,,,,,,,,,,, -54700,0.84033346,1.0331646,,,,,,,,,,,,,, -54800,0.9060461,1.0659658,,,,,,,,,,,,,, -54900,0.9293368,1.0883907,,,,,,,,,,,,,, -55000,0.95445967,1.0604222,,,,,,,,,,,,,, -55100,1.0056359,1.1108747,,,,,,,,,,,,,, -55200,1.0659254,1.0902224,,,,,,,,,,,,,, -55300,0.9856906,1.065259,,,,,,,,,,,,,, -55400,0.8390934,1.0756481,,,,,,,,,,,,,, -55500,0.9486195,1.0463322,,,,,,,,,,,,,, -55600,0.9327904,1.0526174,,,,,,,,,,,,,, -55700,1.1440423,1.0321136,,,,,,,,,,,,,, -55800,1.0839555,1.009127,,,,,,,,,,,,,, -55900,0.8861761,1.0180227,,,,,,,,,,,,,, -56000,1.0184066,0.97827095,,,,,,,,,,,,,, -56100,0.9601905,1.051841,,,,,,,,,,,,,, -56200,0.9092107,1.0129825,,,,,,,,,,,,,, -56300,0.9020677,0.9407454,,,,,,,,,,,,,, -56319,,,0.14977348,0.0572549819618622,0.36254397,0.1062880755380055,5348.0,0.19897437,0.0670079012044766,2472.0,43269.12178993225,47305.22345614433,43269.12178993225,4032.1063737869263,1.6375198364257812,0.0 -56400,0.9442413,1.069475,,,,,,,,,,,,,, -56500,0.9525087,1.0767767,,,,,,,,,,,,,, -56600,1.3173709,1.0163925,,,,,,,,,,,,,, -56700,0.92494714,0.9722889,,,,,,,,,,,,,, -56800,1.0910562,1.0013736,,,,,,,,,,,,,, -56900,0.91439754,0.9983739,,,,,,,,,,,,,, -57000,1.0209918,1.0733144,,,,,,,,,,,,,, -57100,1.0188812,1.0481578,,,,,,,,,,,,,, -57200,0.99662375,1.0850801,,,,,,,,,,,,,, -57300,0.9617449,1.0784897,,,,,,,,,,,,,, -57400,1.0807064,1.0380795,,,,,,,,,,,,,, -57500,1.0610272,1.0290724,,,,,,,,,,,,,, -57600,0.96187687,1.037031,,,,,,,,,,,,,, -57700,0.8917209,1.0504682,,,,,,,,,,,,,, -57800,0.860106,0.97615844,,,,,,,,,,,,,, -57900,0.9053347,1.0386131,,,,,,,,,,,,,, -58000,0.9388274,1.0225849,,,,,,,,,,,,,, -58100,0.8081343,0.98807484,,,,,,,,,,,,,, -58199,,,0.14436826,0.056059406155415,0.3545192,0.1020496828446469,5348.0,0.18987672,0.0646720695468486,2472.0,44709.46357750893,48875.57385182381,44709.46357750893,4161.976230621338,1.697356939315796,0.0 -58200,0.9998107,1.0149163,,,,,,,,,,,,,, -58300,0.9742703,0.9978372,,,,,,,,,,,,,, -58400,0.9395388,1.0362883,,,,,,,,,,,,,, -58500,0.9473321,1.0281639,,,,,,,,,,,,,, -58600,0.99720776,0.95622396,,,,,,,,,,,,,, -58700,0.8907604,1.0010847,,,,,,,,,,,,,, -58800,0.91382915,0.98462707,,,,,,,,,,,,,, -58900,0.89854,0.96352625,,,,,,,,,,,,,, -59000,0.8994886,1.0024632,,,,,,,,,,,,,, -59100,0.94861287,0.99452543,,,,,,,,,,,,,, -59200,0.9833029,1.0094386,,,,,,,,,,,,,, -59300,0.9969895,1.0102555,,,,,,,,,,,,,, -59400,0.98069215,0.99691546,,,,,,,,,,,,,, -59500,1.1363935,0.98975426,,,,,,,,,,,,,, -59600,0.9931057,0.9779357,,,,,,,,,,,,,, -59700,1.0516087,0.9829285,,,,,,,,,,,,,, -59800,0.9411195,0.9573026,,,,,,,,,,,,,, -59900,1.1417159,1.0165966,,,,,,,,,,,,,, -60000,1.2308762,0.9644725,,,,,,,,,,,,,, -60080,,,0.1338687,0.0509979296066252,0.34373295,0.0988926112940131,5348.0,0.18780683,0.0621128105132736,2472.0,46149.45885229111,50444.575095653534,46149.45885229111,4290.845764160156,1.7542970180511477,0.0 -60100,1.01574,0.95383304,,,,,,,,,,,,,, -60200,1.0404351,1.0225257,,,,,,,,,,,,,, -60300,1.1051369,0.9779747,,,,,,,,,,,,,, -60400,0.95217484,0.9919455,,,,,,,,,,,,,, -60500,1.0505972,1.0185155,,,,,,,,,,,,,, -60600,0.9375995,0.94437045,,,,,,,,,,,,,, -60700,0.9655963,0.97927856,,,,,,,,,,,,,, -60800,0.89326847,0.93963814,,,,,,,,,,,,,, -60900,1.369652,1.002995,,,,,,,,,,,,,, -61000,0.9607592,0.9446743,,,,,,,,,,,,,, -61100,1.4340024,0.9677482,,,,,,,,,,,,,, -61200,0.98995566,0.9781171,,,,,,,,,,,,,, -61300,1.0294571,0.9781284,,,,,,,,,,,,,, -61400,0.9739796,0.9765387,,,,,,,,,,,,,, -61500,0.87523913,0.9827494,,,,,,,,,,,,,, -61600,1.0671669,0.9415253,,,,,,,,,,,,,, -61700,1.0225849,0.95017135,,,,,,,,,,,,,, -61800,0.8899211,0.9515644,,,,,,,,,,,,,, -61900,1.2929071,0.97723365,,,,,,,,,,,,,, -61946,,,0.13237758,0.0509769582088527,0.33442134,0.0973189028452262,5348.0,0.18222867,0.0603660146649604,2472.0,47589.61220765114,52013.4002263546,47589.61220765114,4419.383289575577,1.8110344409942627,0.0 -62000,1.0137516,0.9362521,,,,,,,,,,,,,, -62100,1.0681738,0.98314613,,,,,,,,,,,,,, -62200,1.1308266,0.9839459,,,,,,,,,,,,,, -62300,0.8965969,0.9161985,,,,,,,,,,,,,, -62400,1.4251406,0.9707506,,,,,,,,,,,,,, -62500,1.1528081,1.0076725,,,,,,,,,,,,,, -62600,1.0700394,0.975363,,,,,,,,,,,,,, -62700,1.2674875,0.97365046,,,,,,,,,,,,,, -62800,1.15134,0.96013147,,,,,,,,,,,,,, -62900,1.1014634,0.95174867,,,,,,,,,,,,,, -63000,1.1131905,0.92890584,,,,,,,,,,,,,, -63100,1.0960051,0.98026055,,,,,,,,,,,,,, -63200,1.176204,0.9692952,,,,,,,,,,,,,, -63300,1.0520506,0.9193969,,,,,,,,,,,,,, -63400,1.0503845,0.92385125,,,,,,,,,,,,,, -63500,1.0847119,0.9678977,,,,,,,,,,,,,, -63600,1.2198235,0.92655647,,,,,,,,,,,,,, -63700,1.0100154,0.91694844,,,,,,,,,,,,,, -63800,1.2514706,0.94867724,,,,,,,,,,,,,, -63814,,,0.12851778,0.0490040187760703,0.32344717,0.0930418915396275,5348.0,0.17482771,0.0576239514147015,2472.0,49029.50516271591,53585.0060338974,49029.50516271591,4550.955982685089,1.872158527374268,0.0 -63900,0.99742645,0.9396071,,,,,,,,,,,,,, -64000,0.9752634,0.938448,,,,,,,,,,,,,, -64100,1.0243661,0.9608589,,,,,,,,,,,,,, -64200,0.99316883,0.9329211,,,,,,,,,,,,,, -64300,1.1590347,0.9375949,,,,,,,,,,,,,, -64400,1.2903175,0.93424207,,,,,,,,,,,,,, -64500,1.2880274,0.95712143,,,,,,,,,,,,,, -64600,1.0391959,0.93697953,,,,,,,,,,,,,, -64700,1.0773065,0.9295654,,,,,,,,,,,,,, -64800,1.1694638,0.9390495,,,,,,,,,,,,,, -64900,1.1420205,0.90097326,,,,,,,,,,,,,, -65000,1.1124343,0.92640376,,,,,,,,,,,,,, -65100,1.1605121,0.9546755,,,,,,,,,,,,,, -65200,1.0380223,0.9076551,,,,,,,,,,,,,, -65300,1.0083661,0.9426214,,,,,,,,,,,,,, -65400,1.1063468,0.98951435,,,,,,,,,,,,,, -65500,1.059128,0.9396144,,,,,,,,,,,,,, -65600,1.241761,0.9918663,,,,,,,,,,,,,, -65679,,,0.10282077,0.0405429641087721,0.31785193,0.090879249254178,5348.0,0.17353387,0.0559584018849145,2472.0,50469.89919543266,55155.54966759682,50469.89919543266,4680.961303472519,1.9368162155151367,0.0 -65700,1.1114621,0.9340109,,,,,,,,,,,,,, -65800,1.1314548,0.94028634,,,,,,,,,,,,,, -65900,1.1224779,0.8898316,,,,,,,,,,,,,, -66000,1.0949589,0.90555286,,,,,,,,,,,,,, -66100,1.1157515,0.95884556,,,,,,,,,,,,,, -66200,1.0160484,0.931278,,,,,,,,,,,,,, -66300,1.2366401,0.9207984,,,,,,,,,,,,,, -66400,1.0818188,0.9030474,,,,,,,,,,,,,, -66500,1.1037656,0.9230182,,,,,,,,,,,,,, -66600,1.1316139,0.8833204,,,,,,,,,,,,,, -66700,1.0054482,0.91853815,,,,,,,,,,,,,, -66800,1.2506053,0.9482973,,,,,,,,,,,,,, -66900,1.1541572,0.94285816,,,,,,,,,,,,,, -67000,1.2588816,0.9031156,,,,,,,,,,,,,, -67100,1.1175891,0.9230896,,,,,,,,,,,,,, -67200,1.1474167,0.9406719,,,,,,,,,,,,,, -67300,1.152049,0.9277726,,,,,,,,,,,,,, -67400,1.0952762,0.891332,,,,,,,,,,,,,, -67500,1.0990545,0.914775,,,,,,,,,,,,,, -67553,,,0.098737516,0.037794943179225,0.3111092,0.0892283035809108,5348.0,0.16759367,0.0542725407754961,2472.0,51910.30626130104,56725.00738286972,51910.30626130104,4809.875825405121,1.9941620826721191,0.0 -67600,1.2050313,0.9518658,,,,,,,,,,,,,, -67700,1.0101664,0.8778137,,,,,,,,,,,,,, -67800,1.1763803,0.9167955,,,,,,,,,,,,,, -67900,1.3178271,0.92876077,,,,,,,,,,,,,, -68000,1.1095027,0.92590225,,,,,,,,,,,,,, -68100,1.3205813,0.87661386,,,,,,,,,,,,,, -68200,1.1909178,0.8947389,,,,,,,,,,,,,, -68300,1.3865657,0.9169109,,,,,,,,,,,,,, -68400,1.0746515,0.8797255,,,,,,,,,,,,,, -68500,1.1701276,0.9231076,,,,,,,,,,,,,, -68600,1.1416899,0.90102607,,,,,,,,,,,,,, -68700,1.1331552,0.8878821,,,,,,,,,,,,,, -68800,1.1570688,0.8650235,,,,,,,,,,,,,, -68900,1.1175964,0.84540886,,,,,,,,,,,,,, -69000,1.0369812,0.8680401,,,,,,,,,,,,,, -69100,1.0437158,0.90495825,,,,,,,,,,,,,, -69200,1.041062,0.85768676,,,,,,,,,,,,,, -69300,1.1768466,0.88363105,,,,,,,,,,,,,, -69400,1.4464791,0.8407606,,,,,,,,,,,,,, -69429,,,0.08786926,0.0342476454046333,0.30810562,0.0879442347239252,5348.0,0.16458964,0.0534803891698657,2472.0,53350.63709568977,58294.03611445427,53350.63709568977,4938.435527801514,2.053472757339477,0.0 -69500,1.2339528,0.869379,,,,,,,,,,,,,, -69600,1.1556861,0.8974502,,,,,,,,,,,,,, -69700,1.1045095,0.87906826,,,,,,,,,,,,,, -69800,1.3544234,0.84491926,,,,,,,,,,,,,, -69900,1.4783709,0.87778866,,,,,,,,,,,,,, -70000,1.2386243,0.9004697,,,,,,,,,,,,,, -70100,1.1695728,0.9045742,,,,,,,,,,,,,, -70200,1.3188094,0.877276,,,,,,,,,,,,,, -70300,1.2846681,0.8736334,,,,,,,,,,,,,, -70400,0.9716013,0.82996094,,,,,,,,,,,,,, -70500,1.1080843,0.8759157,,,,,,,,,,,,,, -70600,1.1092308,0.84132457,,,,,,,,,,,,,, -70700,1.0194879,0.86674815,,,,,,,,,,,,,, -70800,1.3665655,0.9261157,,,,,,,,,,,,,, -70900,1.2143912,0.90921515,,,,,,,,,,,,,, -71000,1.326311,0.884723,,,,,,,,,,,,,, -71100,1.4271673,0.8000816,,,,,,,,,,,,,, -71200,1.1579363,0.8680913,,,,,,,,,,,,,, -71300,1.1585836,0.8628966,,,,,,,,,,,,,, -71305,,,0.08843144,0.0338653017302732,0.30182338,0.0861291599486372,5348.0,0.16141044,0.0517539049011841,2472.0,54790.89835476875,59863.24539685249,54790.89835476875,5067.244613170624,2.112919569015503,0.0 -71400,1.2113552,0.8719414,,,,,,,,,,,,,, -71500,1.5744205,0.8709533,,,,,,,,,,,,,, -71600,1.1986448,0.86052614,,,,,,,,,,,,,, -71700,1.222883,0.85780495,,,,,,,,,,,,,, -71800,1.1967895,0.85962325,,,,,,,,,,,,,, -71900,1.6004573,0.8975634,,,,,,,,,,,,,, -72000,1.3058157,0.84117156,,,,,,,,,,,,,, -72100,1.2561045,0.83262014,,,,,,,,,,,,,, -72200,1.4827083,0.8168684,,,,,,,,,,,,,, -72300,1.3448657,0.87270874,,,,,,,,,,,,,, -72400,1.4035894,0.8833938,,,,,,,,,,,,,, -72500,1.237864,0.8579343,,,,,,,,,,,,,, -72600,1.1712459,0.8658625,,,,,,,,,,,,,, -72700,1.1839066,0.86756235,,,,,,,,,,,,,, -72800,1.1718532,0.8401504,,,,,,,,,,,,,, -72900,1.3331989,0.8527046,,,,,,,,,,,,,, -73000,1.35435,0.8756176,,,,,,,,,,,,,, -73100,1.3193336,0.8731091,,,,,,,,,,,,,, -73177,,,0.08011135,0.030846805751213,0.30042586,0.084931982969192,5348.0,0.16003639,0.0511242459326061,2472.0,56231.20172023773,61433.10589194298,56231.20172023773,5196.659123182297,2.1780107021331787,0.0 -73200,1.0572785,0.85456216,,,,,,,,,,,,,, -73300,1.2566895,0.86810154,,,,,,,,,,,,,, -73400,1.3762703,0.8801123,,,,,,,,,,,,,, -73500,1.21077,0.89454234,,,,,,,,,,,,,, -73600,1.448057,0.8595209,,,,,,,,,,,,,, -73700,1.1036266,0.8177822,,,,,,,,,,,,,, -73800,1.254419,0.8454363,,,,,,,,,,,,,, -73900,1.1438546,0.8511195,,,,,,,,,,,,,, -74000,1.3814535,0.84379745,,,,,,,,,,,,,, -74100,1.3959736,0.89268154,,,,,,,,,,,,,, -74200,1.4646287,0.83889216,,,,,,,,,,,,,, -74300,1.1268656,0.83559704,,,,,,,,,,,,,, -74400,1.0570889,0.8355075,,,,,,,,,,,,,, -74500,1.9249659,0.8443031,,,,,,,,,,,,,, -74600,1.2071095,0.8638964,,,,,,,,,,,,,, -74700,1.5084692,0.8692855,,,,,,,,,,,,,, -74800,1.3651696,0.86677516,,,,,,,,,,,,,, -74900,1.4120308,0.8543666,,,,,,,,,,,,,, -75000,1.0439484,0.7954845,,,,,,,,,,,,,, -75065,,,0.07875839,0.0295318951108552,0.2965718,0.0830493256224837,5348.0,0.15909529,0.0502508480084496,2472.0,57671.870633125305,63003.7757563591,57671.870633125305,5326.510108470917,2.247305393218994,0.0 -75100,1.7513094,0.8350454,,,,,,,,,,,,,, -75200,1.2606221,0.88088393,,,,,,,,,,,,,, -75300,1.341322,0.81327975,,,,,,,,,,,,,, -75400,1.1359191,0.8038904,,,,,,,,,,,,,, -75500,1.2596314,0.8576258,,,,,,,,,,,,,, -75600,1.3399656,0.84143865,,,,,,,,,,,,,, -75700,1.2004305,0.8039103,,,,,,,,,,,,,, -75800,1.2236456,0.8530586,,,,,,,,,,,,,, -75900,1.2942876,0.8632519,,,,,,,,,,,,,, -76000,1.4695545,0.86794823,,,,,,,,,,,,,, -76100,1.3990171,0.9084105,,,,,,,,,,,,,, -76200,1.4568924,0.8365115,,,,,,,,,,,,,, -76300,1.2626995,0.8297349,,,,,,,,,,,,,, -76400,1.5076815,0.85612565,,,,,,,,,,,,,, -76500,1.1772115,0.83490896,,,,,,,,,,,,,, -76600,1.1710296,0.8392591,,,,,,,,,,,,,, -76700,1.2929476,0.8073833,,,,,,,,,,,,,, -76800,1.2389276,0.80194217,,,,,,,,,,,,,, -76900,1.0642116,0.8570089,,,,,,,,,,,,,, -76942,,,0.06596456,0.0250914497973121,0.2928794,0.0825665929694816,5348.0,0.15672454,0.0505352101232912,2472.0,59112.54464316368,64574.79559850693,59112.54464316368,5456.7151482105255,2.3087172508239746,0.0 -77000,1.5020738,0.84199536,,,,,,,,,,,,,, -77100,1.1585103,0.8657122,,,,,,,,,,,,,, -77200,1.2956209,0.88064164,,,,,,,,,,,,,, -77300,1.5075682,0.86176944,,,,,,,,,,,,,, -77400,1.3950994,0.8344617,,,,,,,,,,,,,, -77500,1.3902048,0.82330394,,,,,,,,,,,,,, -77600,1.572775,0.8672877,,,,,,,,,,,,,, -77700,1.1671087,0.89755744,,,,,,,,,,,,,, -77800,1.2916375,0.78011763,,,,,,,,,,,,,, -77900,1.3506887,0.85618734,,,,,,,,,,,,,, -78000,1.6050376,0.8624453,,,,,,,,,,,,,, -78100,1.2973799,0.79780376,,,,,,,,,,,,,, -78200,1.2420088,0.85395205,,,,,,,,,,,,,, -78300,1.1392392,0.83760625,,,,,,,,,,,,,, -78400,1.3629559,0.7981073,,,,,,,,,,,,,, -78500,1.1560316,0.83096236,,,,,,,,,,,,,, -78600,1.3852842,0.81083924,,,,,,,,,,,,,, -78700,1.3997142,0.84956414,,,,,,,,,,,,,, -78800,1.241011,0.8067364,,,,,,,,,,,,,, -78817,,,0.06438663,0.0240664276449564,0.29217574,0.0821224789287197,5348.0,0.15655187,0.050007109052871,2472.0,60552.48413944245,66142.98616504669,60552.48413944245,5584.824814081192,2.3721067905426025,0.0 -78900,1.5362868,0.78306544,,,,,,,,,,,,,, -79000,1.0829719,0.81353945,,,,,,,,,,,,,, -79100,1.4034619,0.81719387,,,,,,,,,,,,,, -79200,1.353332,0.8531073,,,,,,,,,,,,,, -79300,1.3268358,0.86196285,,,,,,,,,,,,,, -79400,1.4131544,0.8320516,,,,,,,,,,,,,, -79496,,,,,,,,,,,61068.20896577835,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 54f80fd7d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -128.25435638427734,0.0,35.07322812080383,1,0,35.07322812080383,30.214182,2472,0.9757885970791949,163.3276309967041,31.51566,1.1540354719622474,30.090126,5348,0.9587360128213792 -235.74946308135983,0.030027151107788,1475.6158843040466,1836,0,1475.6158843040466,5.594893,2472,0.8955578575345805,1711.4689011573792,5.6703877,0.9337997724687144,5.695042,5348,0.8934609034824333 -365.60345458984375,0.0844104290008544,2916.513821363449,3701,0,2916.513821363449,1.9130712,2472,0.462210306095505,3282.3531198501587,2.2839155,0.5220066461751967,2.2517276,5348,0.5036446315301659 -494.8561522960663,0.1389400959014892,4356.575798273087,5571,0,4356.575798273087,0.6027771,2472,0.2008002762374829,4851.800904512405,0.7571463,0.2462243320821373,0.86724573,5348,0.2605597767844212 -622.6275777816772,0.1896283626556396,5797.141031265259,7439,0,5797.141031265259,0.4682187,2472,0.1578209737371275,6420.266736745834,0.57816416,0.1990697781645697,0.7265921,5348,0.2208308794423472 -751.8287582397461,0.2437617778778076,7237.32617020607,9302,0,7237.32617020607,0.40672186,2472,0.1386468425649462,7989.78463101387,0.4748012,0.1659062013242554,0.6415892,5348,0.1948308987516533 -906.8597481250764,0.2932982444763183,8677.753060102463,11198,0,8677.753060102463,0.37524578,2472,0.1282270022139622,9585.371998786926,0.29416263,0.1091385393640537,0.600491,5348,0.1828012010388406 -1038.117871761322,0.349048376083374,10117.966343402864,13079,0,10117.966343402864,0.34251323,2472,0.1157556923201917,11156.975125789642,0.26120993,0.0949281272565755,0.5636484,5348,0.1702308427546656 -1168.8186223506927,0.4085848331451416,11558.134120464323,14950,0,11558.134120464323,0.31989244,2472,0.109316921576991,12727.983196496964,0.25069952,0.0936282671317952,0.53552085,5348,0.1616961294495882 -1298.191771030426,0.4646894931793213,12998.268486976624,16819,0,12998.268486976624,0.30219615,2472,0.1039749761338939,14297.6266579628,0.21822503,0.0836307287753568,0.5099662,5348,0.1562991783890246 -1428.2382934093475,0.5230772495269775,14438.530722856522,18672,0,14438.530722856522,0.29380992,2472,0.099567363353848,15868.071051120758,0.21849884,0.083072543196976,0.4908321,5348,0.1499174527163366 -1559.4781639575958,0.5743563175201416,15879.123154878616,20532,0,15879.123154878616,0.28147247,2472,0.0954034895293807,17440.032745838165,0.20132376,0.0777778965027193,0.49051324,5348,0.1492223176960136 -1690.011076211929,0.626962423324585,17319.44011616707,22391,0,17319.44011616707,0.27528322,2472,0.0942254179107509,19011.014504671097,0.2204811,0.0809672504077603,0.4734632,5348,0.1452349459822161 -1818.5990698337555,0.6770737171173096,18759.946305513386,24266,0,18759.946305513386,0.26455674,2472,0.0904271525196514,20580.23752808571,0.19056775,0.0741942324003392,0.45724428,5348,0.1387952923911679 -1949.59156537056,0.7306003570556641,20200.31648516655,26138,0,20200.31648516655,0.25610334,2472,0.0864054597526049,22151.73225450516,0.15853791,0.0627695006212217,0.4487414,5348,0.1351941067997721 -2078.5168731212616,0.785571813583374,21640.60538506508,28016,0,21640.60538506508,0.24930707,2472,0.0856539313062376,23721.0798664093,0.15658772,0.0616395227219364,0.43760598,5348,0.1324521853307201 -2208.4737632274628,0.8361155986785889,23080.832427740097,29887,0,23080.832427740097,0.24524696,2472,0.0811244490484024,25291.391725063324,0.1453827,0.0589077983627746,0.42563367,5348,0.1275958948415188 -2339.2270991802216,0.8912203311920166,24521.21066737175,31777,0,24521.21066737175,0.2405563,2472,0.0804338553409298,26862.657521247864,0.15269236,0.0587187496735235,0.4212963,5348,0.1271131621885167 -2469.7058634758,0.9466145038604736,25961.10282206536,33656,0,25961.10282206536,0.23121922,2472,0.0780777121036703,28433.16188645363,0.14723137,0.0569556770910319,0.41512138,5348,0.1238209254950423 -2600.440567970276,1.0080838203430176,27401.28202843666,35533,0,27401.28202843666,0.22429307,2472,0.0769808868035667,30004.2161693573,0.13383973,0.0528793184604323,0.4006253,5348,0.1209148749239696 -2730.957435131073,1.0638580322265625,28841.636016368862,37406,0,28841.636016368862,0.21991388,2472,0.0740153961773607,31575.221014022827,0.12914908,0.0497980099658896,0.39642468,5348,0.1179122778222964 -2861.2528777122498,1.1172175407409668,30282.072615385056,39263,0,30282.072615385056,0.21920747,2472,0.0737513456421505,33146.08340334892,0.113466114,0.0459796173373894,0.3963149,5348,0.1163096054143294 -2992.140277147293,1.1743803024291992,31722.417976379395,41131,0,31722.417976379395,0.21722014,2472,0.0728373245587309,34717.45092463493,0.11233612,0.0440406167117225,0.39019135,5348,0.1171012869652529 -3122.092783689499,1.228879451751709,33162.53774404526,42993,0,33162.53774404526,0.20735815,2472,0.0708671013344707,36287.6553106308,0.117409885,0.0466727980956502,0.38365835,5348,0.1144655666798613 -3251.867955684662,1.2873446941375732,34603.10031223297,44863,0,34603.10031223297,0.20476188,2472,0.0682875307212641,37858.13029813767,0.10558901,0.041014103534872,0.378038,5348,0.1115112428434884 -3382.6824221611023,1.341782808303833,36043.22028398514,46737,0,36043.22028398514,0.20145229,2472,0.0675156906952653,39429.1965405941,0.098575264,0.0401815943367189,0.37337637,5348,0.1083252073336744 -3512.5706765651703,1.40153169631958,37483.42413520813,48609,0,37483.42413520813,0.19999944,2472,0.0676375601730546,40999.42491674423,0.11590645,0.0434868162140889,0.36853057,5348,0.1070411384766888 -3642.982671022415,1.4595589637756348,38923.31358218193,50478,0,38923.31358218193,0.19472961,2472,0.0644689537505331,42569.863436460495,0.07502113,0.030328978403101,0.35714394,5348,0.1041640518647962 -3773.3202497959137,1.5202560424804688,40363.50043487549,52367,0,40363.50043487549,0.18940833,2472,0.0632096358133772,44140.52767682076,0.080772586,0.032610684468887,0.35393357,5348,0.1022234665997277 -3902.9258086681366,1.5857644081115725,41803.85074186325,54234,0,41803.85074186325,0.18704453,2472,0.0617878252391688,45710.62641215325,0.09836104,0.0394453360300745,0.34822646,5348,0.0997615300694169 -4032.204874753952,1.641676664352417,43244.468445539474,56103,0,43244.468445539474,0.18341146,2472,0.0607316230983283,47280.65885734558,0.096853964,0.03788420696577,0.34267035,5348,0.0985836623960918 -4161.369755983353,1.7073440551757812,44684.62251138687,57971,0,44684.62251138687,0.18483536,2472,0.0600613409704872,48850.121492147446,0.10674078,0.0438091702246148,0.34073088,5348,0.0973768307635865 -4290.192884206772,1.7623403072357178,46125.16544651985,59827,0,46125.16544651985,0.1787121,2472,0.0574411471980175,50419.62053847313,0.090027034,0.03508040849865,0.3356909,5348,0.0955038280699383 -4420.0238037109375,1.8231210708618164,47566.06129670143,61701,0,47566.06129670143,0.17447585,2472,0.0569942924461235,51990.48733663559,0.08025563,0.0323838636989418,0.33256233,5348,0.0946155999884144 -4549.780838727951,1.8821280002594,49007.134853601456,63562,0,49007.134853601456,0.17370649,2472,0.0566286840127556,53561.45655012131,0.06656477,0.0269001800554765,0.3317119,5348,0.09362117072323 -4679.105979681015,1.9498560428619385,50447.04723358154,65430,0,50447.04723358154,0.1742499,2472,0.0555927934515467,55130.83956623077,0.073906675,0.0299568072249732,0.32825038,5348,0.0913523272541201 -4809.865817785263,2.0128204822540283,51887.26133728027,67298,0,51887.26133728027,0.17171825,2472,0.0543537870940223,56701.95521783829,0.063671306,0.0247620575339535,0.326362,5348,0.0915261110092008 -4939.1666939258575,2.068631172180176,53327.44131612778,69166,0,53327.44131612778,0.1701673,2472,0.0543131639347592,58271.57095623016,0.062218998,0.0248977123524252,0.32358345,5348,0.089508288519652 -5068.928290843964,2.1295037269592285,54767.586223602295,71034,0,54767.586223602295,0.17008592,2472,0.0537038165458127,59841.618315935135,0.05749372,0.0227205455123496,0.32213494,5348,0.0897496548461531 -5197.914718389511,2.186281204223633,56208.07806992531,72886,0,56208.07806992531,0.16816369,2472,0.0537647512847074,61411.23158049584,0.06182419,0.0240718941112291,0.3198358,5348,0.0882145650096063 -5326.994605541229,2.25032901763916,57648.61340594292,74754,0,57648.61340594292,0.16721477,2472,0.0523226291308675,62980.98971796036,0.065528736,0.0250461231886041,0.31876612,5348,0.0880987091728858 -5456.8192529678345,2.312541007995605,59088.544001579285,76618,0,59088.544001579285,0.16727436,2472,0.05274917230313,64550.88621211052,0.055176824,0.0221612671227148,0.31730285,5348,0.0872684090097222 -5583.916262388229,2.374246597290039,60528.807544231415,78485,0,60528.807544231415,0.16668418,2472,0.052769483882761564,66118.38577461243,0.058894917,0.022904281303285867,0.31750944,5348,0.08736495554032266 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/measurements.csv deleted file mode 100644 index a506bcc5a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/measurements.csv +++ /dev/null @@ -1,837 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,69.10867,31.631193,,,,,,,,,,,,,, -1,,,31.51566,1.1540354719622474,30.090126,0.9587360128213792,5348.0,30.214182,0.9757885970791949,2472.0,35.07322812080383,163.3276309967041,35.07322812080383,128.25435638427734,0.0,0.0 -100,2.0676281,6.5627236,,,,,,,,,,,,,, -200,1.0750921,5.8906384,,,,,,,,,,,,,, -300,0.97287625,5.8126683,,,,,,,,,,,,,, -400,0.44534242,5.820738,,,,,,,,,,,,,, -500,2.6772306,5.8111587,,,,,,,,,,,,,, -600,0.3571988,5.7828298,,,,,,,,,,,,,, -700,2.217507,5.8054223,,,,,,,,,,,,,, -800,0.29511672,5.802971,,,,,,,,,,,,,, -900,2.6368334,5.782876,,,,,,,,,,,,,, -1000,5.0546603,5.7841053,,,,,,,,,,,,,, -1100,2.5368822,5.6358547,,,,,,,,,,,,,, -1200,2.9234958,5.5448856,,,,,,,,,,,,,, -1300,2.9649384,5.459657,,,,,,,,,,,,,, -1400,1.6653543,5.1016717,,,,,,,,,,,,,, -1500,0.9928049,4.4792438,,,,,,,,,,,,,, -1600,0.860254,4.035604,,,,,,,,,,,,,, -1700,1.4116457,3.6988447,,,,,,,,,,,,,, -1800,1.4880776,3.5313773,,,,,,,,,,,,,, -1836,,,5.6703877,0.9337997724687144,5.695042,0.8934609034824333,5348.0,5.594893,0.8955578575345805,2472.0,1475.6158843040466,1711.4689011573792,1475.6158843040466,235.74946308135983,0.030027151107788,0.0 -1900,1.0190183,3.2691455,,,,,,,,,,,,,, -2000,2.4809659,3.2644725,,,,,,,,,,,,,, -2100,1.1331536,3.0879762,,,,,,,,,,,,,, -2200,1.2416612,2.9544652,,,,,,,,,,,,,, -2300,1.2843913,2.9410322,,,,,,,,,,,,,, -2400,1.1085624,2.8780975,,,,,,,,,,,,,, -2500,1.3797523,2.8050947,,,,,,,,,,,,,, -2600,1.8140831,2.6867144,,,,,,,,,,,,,, -2700,0.9543843,2.648938,,,,,,,,,,,,,, -2800,0.95483667,2.6076217,,,,,,,,,,,,,, -2900,1.4176595,2.6492298,,,,,,,,,,,,,, -3000,1.5673739,2.549545,,,,,,,,,,,,,, -3100,1.0833873,2.4427946,,,,,,,,,,,,,, -3200,1.796581,2.4236882,,,,,,,,,,,,,, -3300,1.1604578,2.3976932,,,,,,,,,,,,,, -3400,2.3731,2.3877194,,,,,,,,,,,,,, -3500,0.9131953,2.2819974,,,,,,,,,,,,,, -3600,1.3873271,2.303511,,,,,,,,,,,,,, -3700,0.9313752,2.1849055,,,,,,,,,,,,,, -3701,,,2.2839155,0.5220066461751967,2.2517276,0.5036446315301659,5348.0,1.9130712,0.462210306095505,2472.0,2916.513821363449,3282.3531198501587,2916.513821363449,365.60345458984375,0.0844104290008544,0.0 -3800,1.1462643,2.1946797,,,,,,,,,,,,,, -3900,0.80157787,2.185653,,,,,,,,,,,,,, -4000,1.1142055,2.1296024,,,,,,,,,,,,,, -4100,1.5926026,2.1641657,,,,,,,,,,,,,, -4200,0.71890616,2.0726533,,,,,,,,,,,,,, -4300,0.7903205,2.066669,,,,,,,,,,,,,, -4400,0.9328933,2.0217116,,,,,,,,,,,,,, -4500,1.4069319,2.0031867,,,,,,,,,,,,,, -4600,0.8447269,1.9359974,,,,,,,,,,,,,, -4700,0.6855471,1.9151443,,,,,,,,,,,,,, -4800,0.7919679,1.9284097,,,,,,,,,,,,,, -4900,0.8081175,1.8806735,,,,,,,,,,,,,, -5000,0.6801709,1.8699104,,,,,,,,,,,,,, -5100,0.6996142,1.8970582,,,,,,,,,,,,,, -5200,0.61992353,1.795906,,,,,,,,,,,,,, -5300,0.6319067,1.8514804,,,,,,,,,,,,,, -5400,0.95039576,1.8411177,,,,,,,,,,,,,, -5500,0.739402,1.7870829,,,,,,,,,,,,,, -5571,,,0.7571463,0.2462243320821373,0.86724573,0.2605597767844212,5348.0,0.6027771,0.2008002762374829,2472.0,4356.575798273087,4851.800904512405,4356.575798273087,494.8561522960663,0.1389400959014892,0.0 -5600,0.6939507,1.7758682,,,,,,,,,,,,,, -5700,0.66328275,1.7862656,,,,,,,,,,,,,, -5800,0.673345,1.7128642,,,,,,,,,,,,,, -5900,0.74599814,1.7710161,,,,,,,,,,,,,, -6000,0.860848,1.787444,,,,,,,,,,,,,, -6100,1.3841293,1.7558727,,,,,,,,,,,,,, -6200,0.638752,1.7358701,,,,,,,,,,,,,, -6300,0.68171483,1.698611,,,,,,,,,,,,,, -6400,0.7138117,1.7376426,,,,,,,,,,,,,, -6500,0.83708555,1.6826087,,,,,,,,,,,,,, -6600,0.8475737,1.7048496,,,,,,,,,,,,,, -6700,0.7400704,1.6913153,,,,,,,,,,,,,, -6800,0.6626303,1.6625804,,,,,,,,,,,,,, -6900,1.0164742,1.6736401,,,,,,,,,,,,,, -7000,0.77670145,1.7127173,,,,,,,,,,,,,, -7100,0.7689358,1.66912,,,,,,,,,,,,,, -7200,0.6300319,1.6476064,,,,,,,,,,,,,, -7300,0.72148544,1.6271348,,,,,,,,,,,,,, -7400,0.6735359,1.5700065,,,,,,,,,,,,,, -7439,,,0.57816416,0.1990697781645697,0.7265921,0.2208308794423472,5348.0,0.4682187,0.1578209737371275,2472.0,5797.141031265259,6420.266736745834,5797.141031265259,622.6275777816772,0.1896283626556396,0.0 -7500,0.6231922,1.645595,,,,,,,,,,,,,, -7600,0.584259,1.5976992,,,,,,,,,,,,,, -7700,0.6142988,1.61055,,,,,,,,,,,,,, -7800,0.77949625,1.6105498,,,,,,,,,,,,,, -7900,0.6535482,1.6358829,,,,,,,,,,,,,, -8000,0.68302846,1.6040033,,,,,,,,,,,,,, -8100,0.6521095,1.5816423,,,,,,,,,,,,,, -8200,0.63124675,1.587713,,,,,,,,,,,,,, -8300,1.0214729,1.6236387,,,,,,,,,,,,,, -8400,0.65522325,1.573697,,,,,,,,,,,,,, -8500,0.70054996,1.5931274,,,,,,,,,,,,,, -8600,0.6752775,1.6165408,,,,,,,,,,,,,, -8700,0.692308,1.5924823,,,,,,,,,,,,,, -8800,0.6955092,1.6220006,,,,,,,,,,,,,, -8900,0.6386377,1.5872471,,,,,,,,,,,,,, -9000,0.71027845,1.4852831,,,,,,,,,,,,,, -9100,0.58549273,1.6001879,,,,,,,,,,,,,, -9200,0.6136077,1.5180557,,,,,,,,,,,,,, -9300,0.63928354,1.5448258,,,,,,,,,,,,,, -9302,,,0.4748012,0.1659062013242554,0.6415892,0.1948308987516533,5348.0,0.40672186,0.1386468425649462,2472.0,7237.32617020607,7989.78463101387,7237.32617020607,751.8287582397461,0.2437617778778076,0.0 -9400,0.600776,1.5100256,,,,,,,,,,,,,, -9500,0.71797055,1.5007795,,,,,,,,,,,,,, -9600,0.6403093,1.4792116,,,,,,,,,,,,,, -9700,0.6739799,1.4715143,,,,,,,,,,,,,, -9800,0.5979604,1.454031,,,,,,,,,,,,,, -9900,0.7057018,1.5005708,,,,,,,,,,,,,, -10000,0.710834,1.4779688,,,,,,,,,,,,,, -10100,0.6057992,1.5144814,,,,,,,,,,,,,, -10200,0.6526474,1.5160187,,,,,,,,,,,,,, -10300,0.5574107,1.4287338,,,,,,,,,,,,,, -10400,0.60206723,1.4985694,,,,,,,,,,,,,, -10500,0.58045286,1.4449645,,,,,,,,,,,,,, -10600,0.721119,1.5572208,,,,,,,,,,,,,, -10700,0.58674014,1.4741172,,,,,,,,,,,,,, -10800,0.4961461,1.4724737,,,,,,,,,,,,,, -10900,0.5977205,1.4834188,,,,,,,,,,,,,, -11000,0.5933256,1.4049306,,,,,,,,,,,,,, -11100,0.6854897,1.4996207,,,,,,,,,,,,,, -11198,,,0.29416263,0.1091385393640537,0.600491,0.1828012010388406,5348.0,0.37524578,0.1282270022139622,2472.0,8677.753060102463,9585.371998786926,8677.753060102463,906.8597481250764,0.2932982444763183,0.0 -11200,0.5753904,1.4543654,,,,,,,,,,,,,, -11300,0.6195182,1.4196154,,,,,,,,,,,,,, -11400,0.71471995,1.4490567,,,,,,,,,,,,,, -11500,0.77460563,1.4309343,,,,,,,,,,,,,, -11600,0.63303,1.4807565,,,,,,,,,,,,,, -11700,0.61426556,1.4662124,,,,,,,,,,,,,, -11800,0.5796184,1.4256245,,,,,,,,,,,,,, -11900,0.5394333,1.4050764,,,,,,,,,,,,,, -12000,0.6056214,1.5014902,,,,,,,,,,,,,, -12100,0.5881672,1.4504695,,,,,,,,,,,,,, -12200,0.81514066,1.399004,,,,,,,,,,,,,, -12300,0.51512706,1.427344,,,,,,,,,,,,,, -12400,0.6009917,1.4241862,,,,,,,,,,,,,, -12500,0.6581031,1.3979778,,,,,,,,,,,,,, -12600,0.61614275,1.4214983,,,,,,,,,,,,,, -12700,0.58348143,1.3761182,,,,,,,,,,,,,, -12800,0.66592014,1.4113748,,,,,,,,,,,,,, -12900,0.55666536,1.4089508,,,,,,,,,,,,,, -13000,0.7792613,1.4193317,,,,,,,,,,,,,, -13079,,,0.26120993,0.0949281272565755,0.5636484,0.1702308427546656,5348.0,0.34251323,0.1157556923201917,2472.0,10117.966343402864,11156.975125789642,10117.966343402864,1038.117871761322,0.349048376083374,0.0 -13100,0.6524575,1.4590521,,,,,,,,,,,,,, -13200,0.649377,1.418834,,,,,,,,,,,,,, -13300,0.6445744,1.3573235,,,,,,,,,,,,,, -13400,0.6341731,1.3813236,,,,,,,,,,,,,, -13500,0.71979237,1.3848066,,,,,,,,,,,,,, -13600,0.54584014,1.3522056,,,,,,,,,,,,,, -13700,0.58853734,1.4427724,,,,,,,,,,,,,, -13800,0.59838384,1.4010406,,,,,,,,,,,,,, -13900,0.71609557,1.411269,,,,,,,,,,,,,, -14000,0.61843276,1.3727915,,,,,,,,,,,,,, -14100,0.6527223,1.3891358,,,,,,,,,,,,,, -14200,0.5202229,1.3756211,,,,,,,,,,,,,, -14300,0.59428114,1.4055616,,,,,,,,,,,,,, -14400,0.48892194,1.3586676,,,,,,,,,,,,,, -14500,0.79465544,1.3683244,,,,,,,,,,,,,, -14600,0.6552521,1.3861787,,,,,,,,,,,,,, -14700,0.73428124,1.3631958,,,,,,,,,,,,,, -14800,0.6639529,1.3857636,,,,,,,,,,,,,, -14900,0.7414375,1.3545976,,,,,,,,,,,,,, -14950,,,0.25069952,0.0936282671317952,0.53552085,0.1616961294495882,5348.0,0.31989244,0.109316921576991,2472.0,11558.134120464323,12727.983196496964,11558.134120464323,1168.8186223506927,0.4085848331451416,0.0 -15000,0.5512523,1.3296978,,,,,,,,,,,,,, -15100,0.6209035,1.3461894,,,,,,,,,,,,,, -15200,0.64906347,1.3701334,,,,,,,,,,,,,, -15300,0.65086114,1.3534821,,,,,,,,,,,,,, -15400,0.63265806,1.3447883,,,,,,,,,,,,,, -15500,0.71441275,1.367255,,,,,,,,,,,,,, -15600,0.7785812,1.3187323,,,,,,,,,,,,,, -15700,0.5565688,1.3471699,,,,,,,,,,,,,, -15800,0.5434674,1.3542447,,,,,,,,,,,,,, -15900,0.6533227,1.3528221,,,,,,,,,,,,,, -16000,0.58652395,1.3516059,,,,,,,,,,,,,, -16100,0.6858121,1.3125534,,,,,,,,,,,,,, -16200,0.6919951,1.3295184,,,,,,,,,,,,,, -16300,0.59889287,1.3541908,,,,,,,,,,,,,, -16400,0.70676297,1.3804181,,,,,,,,,,,,,, -16500,0.6903742,1.3201851,,,,,,,,,,,,,, -16600,0.694658,1.3732738,,,,,,,,,,,,,, -16700,0.63888836,1.2811948,,,,,,,,,,,,,, -16800,0.62552977,1.2668332,,,,,,,,,,,,,, -16819,,,0.21822503,0.0836307287753568,0.5099662,0.1562991783890246,5348.0,0.30219615,0.1039749761338939,2472.0,12998.268486976624,14297.6266579628,12998.268486976624,1298.191771030426,0.4646894931793213,0.0 -16900,0.7314531,1.3180367,,,,,,,,,,,,,, -17000,0.7054039,1.3241878,,,,,,,,,,,,,, -17100,0.5377459,1.3533355,,,,,,,,,,,,,, -17200,0.6182101,1.3484917,,,,,,,,,,,,,, -17300,0.5977745,1.3231505,,,,,,,,,,,,,, -17400,0.64483917,1.326251,,,,,,,,,,,,,, -17500,0.6114678,1.3510206,,,,,,,,,,,,,, -17600,0.7185998,1.3128632,,,,,,,,,,,,,, -17700,0.7496927,1.3103157,,,,,,,,,,,,,, -17800,0.57738316,1.352247,,,,,,,,,,,,,, -17900,0.82678133,1.3066875,,,,,,,,,,,,,, -18000,0.55988854,1.3001678,,,,,,,,,,,,,, -18100,0.7107146,1.2632124,,,,,,,,,,,,,, -18200,0.68202966,1.2475524,,,,,,,,,,,,,, -18300,0.74836755,1.3213475,,,,,,,,,,,,,, -18400,0.5382101,1.2889595,,,,,,,,,,,,,, -18500,0.64479965,1.3460104,,,,,,,,,,,,,, -18600,0.8559622,1.2962185,,,,,,,,,,,,,, -18672,,,0.21849884,0.083072543196976,0.4908321,0.1499174527163366,5348.0,0.29380992,0.099567363353848,2472.0,14438.530722856522,15868.071051120758,14438.530722856522,1428.2382934093475,0.5230772495269775,0.0 -18700,0.7778182,1.2321407,,,,,,,,,,,,,, -18800,0.5540009,1.2767702,,,,,,,,,,,,,, -18900,0.47758222,1.2478818,,,,,,,,,,,,,, -19000,0.5344517,1.2786677,,,,,,,,,,,,,, -19100,0.6253174,1.2708572,,,,,,,,,,,,,, -19200,0.6709148,1.2650406,,,,,,,,,,,,,, -19300,0.65072215,1.3458272,,,,,,,,,,,,,, -19400,0.6072333,1.299405,,,,,,,,,,,,,, -19500,0.5897844,1.3432897,,,,,,,,,,,,,, -19600,0.6126031,1.2378168,,,,,,,,,,,,,, -19700,0.68419707,1.2457035,,,,,,,,,,,,,, -19800,0.62397945,1.273241,,,,,,,,,,,,,, -19900,0.57855415,1.3254722,,,,,,,,,,,,,, -20000,0.72697246,1.2668673,,,,,,,,,,,,,, -20100,0.5661323,1.232838,,,,,,,,,,,,,, -20200,0.51376945,1.2618577,,,,,,,,,,,,,, -20300,0.6928565,1.2806393,,,,,,,,,,,,,, -20400,0.71919936,1.323073,,,,,,,,,,,,,, -20500,0.67131525,1.3018079,,,,,,,,,,,,,, -20532,,,0.20132376,0.0777778965027193,0.49051324,0.1492223176960136,5348.0,0.28147247,0.0954034895293807,2472.0,15879.123154878616,17440.032745838165,15879.123154878616,1559.4781639575958,0.5743563175201416,0.0 -20600,0.63403475,1.2366067,,,,,,,,,,,,,, -20700,0.51656157,1.2848016,,,,,,,,,,,,,, -20800,0.6659229,1.2839849,,,,,,,,,,,,,, -20900,0.6365115,1.2096109,,,,,,,,,,,,,, -21000,0.64765936,1.3093212,,,,,,,,,,,,,, -21100,0.5896835,1.2319527,,,,,,,,,,,,,, -21200,0.5505128,1.2422582,,,,,,,,,,,,,, -21300,0.71612835,1.2758548,,,,,,,,,,,,,, -21400,0.68607545,1.2262099,,,,,,,,,,,,,, -21500,0.5523895,1.3162448,,,,,,,,,,,,,, -21600,0.6292209,1.26037,,,,,,,,,,,,,, -21700,0.7013105,1.2578187,,,,,,,,,,,,,, -21800,0.77822715,1.2199289,,,,,,,,,,,,,, -21900,0.6236192,1.2512144,,,,,,,,,,,,,, -22000,0.6361105,1.2587436,,,,,,,,,,,,,, -22100,0.5752244,1.2770518,,,,,,,,,,,,,, -22200,0.61710924,1.2318299,,,,,,,,,,,,,, -22300,0.5335205,1.2942905,,,,,,,,,,,,,, -22391,,,0.2204811,0.0809672504077603,0.4734632,0.1452349459822161,5348.0,0.27528322,0.0942254179107509,2472.0,17319.44011616707,19011.014504671097,17319.44011616707,1690.011076211929,0.626962423324585,0.0 -22400,0.63511777,1.3096489,,,,,,,,,,,,,, -22500,0.7257137,1.3078221,,,,,,,,,,,,,, -22600,0.65510756,1.2580696,,,,,,,,,,,,,, -22700,0.56193405,1.2769101,,,,,,,,,,,,,, -22800,0.600216,1.2248006,,,,,,,,,,,,,, -22900,0.5097737,1.2359468,,,,,,,,,,,,,, -23000,0.646618,1.2604097,,,,,,,,,,,,,, -23100,0.5404118,1.1637455,,,,,,,,,,,,,, -23200,0.5593215,1.2191497,,,,,,,,,,,,,, -23300,0.5498461,1.232635,,,,,,,,,,,,,, -23400,0.624762,1.1760375,,,,,,,,,,,,,, -23500,0.51820606,1.25557,,,,,,,,,,,,,, -23600,0.68236136,1.2072238,,,,,,,,,,,,,, -23700,0.52867997,1.187502,,,,,,,,,,,,,, -23800,0.5646702,1.1840605,,,,,,,,,,,,,, -23900,0.81965417,1.2123834,,,,,,,,,,,,,, -24000,0.64161384,1.2661891,,,,,,,,,,,,,, -24100,0.5420106,1.2611626,,,,,,,,,,,,,, -24200,0.63985485,1.1955996,,,,,,,,,,,,,, -24266,,,0.19056775,0.0741942324003392,0.45724428,0.1387952923911679,5348.0,0.26455674,0.0904271525196514,2472.0,18759.946305513386,20580.23752808571,18759.946305513386,1818.5990698337555,0.6770737171173096,0.0 -24300,0.6444248,1.2551223,,,,,,,,,,,,,, -24400,0.7247025,1.2526451,,,,,,,,,,,,,, -24500,0.5175719,1.255402,,,,,,,,,,,,,, -24600,0.5846991,1.2720983,,,,,,,,,,,,,, -24700,0.60081863,1.2191192,,,,,,,,,,,,,, -24800,0.61961895,1.1711915,,,,,,,,,,,,,, -24900,0.60884213,1.2268287,,,,,,,,,,,,,, -25000,0.5601249,1.1895324,,,,,,,,,,,,,, -25100,0.6271106,1.2778786,,,,,,,,,,,,,, -25200,0.6131153,1.2506026,,,,,,,,,,,,,, -25300,0.70161414,1.2226394,,,,,,,,,,,,,, -25400,0.9173985,1.202613,,,,,,,,,,,,,, -25500,0.58620507,1.2448677,,,,,,,,,,,,,, -25600,0.4901487,1.1644664,,,,,,,,,,,,,, -25700,0.7111246,1.1580486,,,,,,,,,,,,,, -25800,0.660567,1.2571172,,,,,,,,,,,,,, -25900,0.5767943,1.1883616,,,,,,,,,,,,,, -26000,0.6022795,1.2117546,,,,,,,,,,,,,, -26100,0.5556024,1.2346294,,,,,,,,,,,,,, -26138,,,0.15853791,0.0627695006212217,0.4487414,0.1351941067997721,5348.0,0.25610334,0.0864054597526049,2472.0,20200.31648516655,22151.73225450516,20200.31648516655,1949.59156537056,0.7306003570556641,0.0 -26200,0.6261362,1.187452,,,,,,,,,,,,,, -26300,0.5765483,1.1845496,,,,,,,,,,,,,, -26400,0.66422415,1.2153864,,,,,,,,,,,,,, -26500,0.6633857,1.1929146,,,,,,,,,,,,,, -26600,0.692955,1.1778347,,,,,,,,,,,,,, -26700,0.8906377,1.3141959,,,,,,,,,,,,,, -26800,0.6359335,1.2219568,,,,,,,,,,,,,, -26900,0.6619149,1.2059857,,,,,,,,,,,,,, -27000,0.56295305,1.1299175,,,,,,,,,,,,,, -27100,0.6230416,1.2373226,,,,,,,,,,,,,, -27200,0.6732593,1.2280453,,,,,,,,,,,,,, -27300,0.6039473,1.2243207,,,,,,,,,,,,,, -27400,0.58321434,1.1577631,,,,,,,,,,,,,, -27500,0.5375755,1.2018211,,,,,,,,,,,,,, -27600,0.481343,1.194176,,,,,,,,,,,,,, -27700,0.7251521,1.1881928,,,,,,,,,,,,,, -27800,0.7374267,1.2123039,,,,,,,,,,,,,, -27900,0.6092974,1.2142245,,,,,,,,,,,,,, -28000,0.55433565,1.1580322,,,,,,,,,,,,,, -28016,,,0.15658772,0.0616395227219364,0.43760598,0.1324521853307201,5348.0,0.24930707,0.0856539313062376,2472.0,21640.60538506508,23721.0798664093,21640.60538506508,2078.5168731212616,0.785571813583374,0.0 -28100,0.5537226,1.176328,,,,,,,,,,,,,, -28200,0.54828453,1.1828446,,,,,,,,,,,,,, -28300,0.71278393,1.1748375,,,,,,,,,,,,,, -28400,0.61393005,1.2025144,,,,,,,,,,,,,, -28500,0.62235606,1.175523,,,,,,,,,,,,,, -28600,0.711236,1.1906879,,,,,,,,,,,,,, -28700,0.7069138,1.1924267,,,,,,,,,,,,,, -28800,0.63722557,1.1832427,,,,,,,,,,,,,, -28900,0.5506378,1.1811038,,,,,,,,,,,,,, -29000,0.656189,1.1959388,,,,,,,,,,,,,, -29100,0.5592631,1.1653038,,,,,,,,,,,,,, -29200,0.618614,1.1767989,,,,,,,,,,,,,, -29300,0.640325,1.1696568,,,,,,,,,,,,,, -29400,0.59429735,1.195224,,,,,,,,,,,,,, -29500,0.6164226,1.2154143,,,,,,,,,,,,,, -29600,0.5677708,1.1599253,,,,,,,,,,,,,, -29700,0.61977893,1.1845145,,,,,,,,,,,,,, -29800,0.69024473,1.2028474,,,,,,,,,,,,,, -29887,,,0.1453827,0.0589077983627746,0.42563367,0.1275958948415188,5348.0,0.24524696,0.0811244490484024,2472.0,23080.832427740097,25291.391725063324,23080.832427740097,2208.4737632274628,0.8361155986785889,0.0 -29900,0.7273447,1.1858602,,,,,,,,,,,,,, -30000,0.70123506,1.1231318,,,,,,,,,,,,,, -30100,0.6546545,1.1777283,,,,,,,,,,,,,, -30200,0.6617906,1.1590943,,,,,,,,,,,,,, -30300,0.7213175,1.169688,,,,,,,,,,,,,, -30400,0.5998406,1.1800842,,,,,,,,,,,,,, -30500,0.60452724,1.1387237,,,,,,,,,,,,,, -30600,0.55581385,1.1717503,,,,,,,,,,,,,, -30700,0.66583836,1.1472741,,,,,,,,,,,,,, -30800,0.51180506,1.1958258,,,,,,,,,,,,,, -30900,0.60156983,1.111095,,,,,,,,,,,,,, -31000,0.664811,1.1936876,,,,,,,,,,,,,, -31100,0.56449157,1.1553642,,,,,,,,,,,,,, -31200,0.5475936,1.1471628,,,,,,,,,,,,,, -31300,0.8560872,1.200125,,,,,,,,,,,,,, -31400,0.6639247,1.1615171,,,,,,,,,,,,,, -31500,0.6164318,1.1433161,,,,,,,,,,,,,, -31600,0.5619377,1.0750126,,,,,,,,,,,,,, -31700,0.6672229,1.1097007,,,,,,,,,,,,,, -31777,,,0.15269236,0.0587187496735235,0.4212963,0.1271131621885167,5348.0,0.2405563,0.0804338553409298,2472.0,24521.21066737175,26862.657521247864,24521.21066737175,2339.2270991802216,0.8912203311920166,0.0 -31800,0.55630285,1.1299791,,,,,,,,,,,,,, -31900,0.5903761,1.1611855,,,,,,,,,,,,,, -32000,0.692622,1.134917,,,,,,,,,,,,,, -32100,0.752987,1.1691551,,,,,,,,,,,,,, -32200,0.6601967,1.0937781,,,,,,,,,,,,,, -32300,0.6321213,1.1613234,,,,,,,,,,,,,, -32400,0.5748637,1.1602845,,,,,,,,,,,,,, -32500,0.6208561,1.1055202,,,,,,,,,,,,,, -32600,0.6819772,1.1118811,,,,,,,,,,,,,, -32700,0.548031,1.120517,,,,,,,,,,,,,, -32800,0.5370779,1.1516399,,,,,,,,,,,,,, -32900,0.74882543,1.1895854,,,,,,,,,,,,,, -33000,0.5506182,1.1292251,,,,,,,,,,,,,, -33100,0.5994219,1.1138664,,,,,,,,,,,,,, -33200,0.5497031,1.1421807,,,,,,,,,,,,,, -33300,0.6427563,1.1986881,,,,,,,,,,,,,, -33400,0.6962299,1.1743922,,,,,,,,,,,,,, -33500,0.62488014,1.1445206,,,,,,,,,,,,,, -33600,0.6430311,1.1670895,,,,,,,,,,,,,, -33656,,,0.14723137,0.0569556770910319,0.41512138,0.1238209254950423,5348.0,0.23121922,0.0780777121036703,2472.0,25961.10282206536,28433.16188645363,25961.10282206536,2469.7058634758,0.9466145038604736,0.0 -33700,0.52944577,1.1110218,,,,,,,,,,,,,, -33800,0.6054783,1.2142358,,,,,,,,,,,,,, -33900,0.50473636,1.1733086,,,,,,,,,,,,,, -34000,0.6785058,1.1099541,,,,,,,,,,,,,, -34100,0.59087694,1.174376,,,,,,,,,,,,,, -34200,0.54641515,1.0778648,,,,,,,,,,,,,, -34300,0.5607598,1.0817807,,,,,,,,,,,,,, -34400,0.60718066,1.1048335,,,,,,,,,,,,,, -34500,0.52691436,1.105189,,,,,,,,,,,,,, -34600,0.7260851,1.1357918,,,,,,,,,,,,,, -34700,0.6655667,1.1787808,,,,,,,,,,,,,, -34800,0.577886,1.118999,,,,,,,,,,,,,, -34900,0.62606496,1.1322345,,,,,,,,,,,,,, -35000,0.6644467,1.0594716,,,,,,,,,,,,,, -35100,0.7709732,1.0860579,,,,,,,,,,,,,, -35200,0.57049274,1.1307459,,,,,,,,,,,,,, -35300,0.7851411,1.1066829,,,,,,,,,,,,,, -35400,0.7799699,1.1332082,,,,,,,,,,,,,, -35500,0.7264088,1.126912,,,,,,,,,,,,,, -35533,,,0.13383973,0.0528793184604323,0.4006253,0.1209148749239696,5348.0,0.22429307,0.0769808868035667,2472.0,27401.28202843666,30004.2161693573,27401.28202843666,2600.440567970276,1.0080838203430176,0.0 -35600,0.55815065,1.1427412,,,,,,,,,,,,,, -35700,0.6646052,1.1066508,,,,,,,,,,,,,, -35800,0.680508,1.1075505,,,,,,,,,,,,,, -35900,1.1904289,1.1703453,,,,,,,,,,,,,, -36000,0.7204616,1.1196307,,,,,,,,,,,,,, -36100,0.7691035,1.0846046,,,,,,,,,,,,,, -36200,0.5287156,1.0449942,,,,,,,,,,,,,, -36300,0.61669594,1.123431,,,,,,,,,,,,,, -36400,0.85264,1.1237977,,,,,,,,,,,,,, -36500,0.6425754,1.1255307,,,,,,,,,,,,,, -36600,0.62586355,1.0764464,,,,,,,,,,,,,, -36700,0.7217804,1.0818887,,,,,,,,,,,,,, -36800,0.59492415,1.1284103,,,,,,,,,,,,,, -36900,0.62082577,1.0724696,,,,,,,,,,,,,, -37000,0.65514076,1.1815064,,,,,,,,,,,,,, -37100,0.50667036,1.0756541,,,,,,,,,,,,,, -37200,0.6223847,1.1346247,,,,,,,,,,,,,, -37300,0.604697,1.1370876,,,,,,,,,,,,,, -37400,0.56594974,1.0573257,,,,,,,,,,,,,, -37406,,,0.12914908,0.0497980099658896,0.39642468,0.1179122778222964,5348.0,0.21991388,0.0740153961773607,2472.0,28841.636016368862,31575.221014022827,28841.636016368862,2730.957435131073,1.0638580322265625,0.0 -37500,0.8005656,1.0785488,,,,,,,,,,,,,, -37600,0.7137715,1.0781842,,,,,,,,,,,,,, -37700,1.1370893,1.1244329,,,,,,,,,,,,,, -37800,0.62505686,1.1348442,,,,,,,,,,,,,, -37900,0.7162863,1.1015638,,,,,,,,,,,,,, -38000,0.6876814,1.0348024,,,,,,,,,,,,,, -38100,0.5843981,1.159916,,,,,,,,,,,,,, -38200,0.9379256,1.1048084,,,,,,,,,,,,,, -38300,0.62463915,1.0724509,,,,,,,,,,,,,, -38400,0.7137695,1.0793284,,,,,,,,,,,,,, -38500,0.5578771,1.089863,,,,,,,,,,,,,, -38600,0.6430792,1.1134851,,,,,,,,,,,,,, -38700,0.68425125,1.0687428,,,,,,,,,,,,,, -38800,0.6164068,1.0894724,,,,,,,,,,,,,, -38900,0.57700276,1.1203539,,,,,,,,,,,,,, -39000,0.58275104,1.0676997,,,,,,,,,,,,,, -39100,0.5617723,1.1425117,,,,,,,,,,,,,, -39200,0.6567296,1.0244333,,,,,,,,,,,,,, -39263,,,0.113466114,0.0459796173373894,0.3963149,0.1163096054143294,5348.0,0.21920747,0.0737513456421505,2472.0,30282.072615385056,33146.08340334892,30282.072615385056,2861.2528777122498,1.1172175407409668,0.0 -39300,0.580599,1.0991232,,,,,,,,,,,,,, -39400,0.5479868,1.0494105,,,,,,,,,,,,,, -39500,0.6352475,1.05952,,,,,,,,,,,,,, -39600,0.7104316,1.0999812,,,,,,,,,,,,,, -39700,0.7693372,1.0602726,,,,,,,,,,,,,, -39800,0.5153088,0.9973919,,,,,,,,,,,,,, -39900,0.5627184,1.0677638,,,,,,,,,,,,,, -40000,0.62470466,1.0685862,,,,,,,,,,,,,, -40100,0.6105622,1.0860913,,,,,,,,,,,,,, -40200,0.60638267,1.0453303,,,,,,,,,,,,,, -40300,0.69876504,1.0710667,,,,,,,,,,,,,, -40400,0.6773685,1.0206568,,,,,,,,,,,,,, -40500,0.62890595,1.0880424,,,,,,,,,,,,,, -40600,0.5272351,1.0197028,,,,,,,,,,,,,, -40700,0.7234767,1.051082,,,,,,,,,,,,,, -40800,0.7061073,1.0637592,,,,,,,,,,,,,, -40900,0.73319644,1.131521,,,,,,,,,,,,,, -41000,0.65942717,1.0955321,,,,,,,,,,,,,, -41100,0.9004422,1.098036,,,,,,,,,,,,,, -41131,,,0.11233612,0.0440406167117225,0.39019135,0.1171012869652529,5348.0,0.21722014,0.0728373245587309,2472.0,31722.417976379395,34717.45092463493,31722.417976379395,2992.140277147293,1.1743803024291992,0.0 -41200,0.7565639,1.134222,,,,,,,,,,,,,, -41300,0.60177845,1.0775669,,,,,,,,,,,,,, -41400,0.5918511,1.0592564,,,,,,,,,,,,,, -41500,0.6554119,1.0295568,,,,,,,,,,,,,, -41600,0.7021832,1.0425444,,,,,,,,,,,,,, -41700,0.8141551,1.041108,,,,,,,,,,,,,, -41800,0.67256314,1.0652114,,,,,,,,,,,,,, -41900,0.6972375,1.0397829,,,,,,,,,,,,,, -42000,0.7453126,1.0984641,,,,,,,,,,,,,, -42100,0.5588197,1.0411988,,,,,,,,,,,,,, -42200,0.59757954,1.0656835,,,,,,,,,,,,,, -42300,0.8137107,1.024857,,,,,,,,,,,,,, -42400,0.5924436,1.0491378,,,,,,,,,,,,,, -42500,0.79489315,1.0203512,,,,,,,,,,,,,, -42600,0.5771968,1.0439371,,,,,,,,,,,,,, -42700,0.83377665,0.9781123,,,,,,,,,,,,,, -42800,0.60476905,1.0781623,,,,,,,,,,,,,, -42900,0.68600786,1.0777109,,,,,,,,,,,,,, -42993,,,0.117409885,0.0466727980956502,0.38365835,0.1144655666798613,5348.0,0.20735815,0.0708671013344707,2472.0,33162.53774404526,36287.6553106308,33162.53774404526,3122.092783689499,1.228879451751709,0.0 -43000,0.7146216,1.0971702,,,,,,,,,,,,,, -43100,0.5726761,1.0120969,,,,,,,,,,,,,, -43200,0.621445,1.063522,,,,,,,,,,,,,, -43300,0.72016555,1.0535498,,,,,,,,,,,,,, -43400,0.5633576,0.99328214,,,,,,,,,,,,,, -43500,0.65039825,1.0012345,,,,,,,,,,,,,, -43600,0.6593695,1.05341,,,,,,,,,,,,,, -43700,0.76713306,1.0346227,,,,,,,,,,,,,, -43800,0.572972,1.0434234,,,,,,,,,,,,,, -43900,0.7255448,1.008302,,,,,,,,,,,,,, -44000,0.6402186,1.0609057,,,,,,,,,,,,,, -44100,0.62050194,1.0619928,,,,,,,,,,,,,, -44200,0.63836014,1.0834036,,,,,,,,,,,,,, -44300,0.83565825,1.0649842,,,,,,,,,,,,,, -44400,0.68227196,1.0495119,,,,,,,,,,,,,, -44500,0.57177365,1.0097994,,,,,,,,,,,,,, -44600,0.83064586,1.0293709,,,,,,,,,,,,,, -44700,0.60627997,1.0316453,,,,,,,,,,,,,, -44800,0.6886175,1.010259,,,,,,,,,,,,,, -44863,,,0.10558901,0.041014103534872,0.378038,0.1115112428434884,5348.0,0.20476188,0.0682875307212641,2472.0,34603.10031223297,37858.13029813767,34603.10031223297,3251.867955684662,1.2873446941375732,0.0 -44900,0.6031049,0.9952697,,,,,,,,,,,,,, -45000,0.72334355,1.0608797,,,,,,,,,,,,,, -45100,0.67185056,1.0469588,,,,,,,,,,,,,, -45200,0.7596476,1.0189419,,,,,,,,,,,,,, -45300,0.6281719,1.0558664,,,,,,,,,,,,,, -45400,0.6845578,1.0157028,,,,,,,,,,,,,, -45500,0.6743231,1.0607051,,,,,,,,,,,,,, -45600,0.6733986,1.054694,,,,,,,,,,,,,, -45700,0.633581,1.0389115,,,,,,,,,,,,,, -45800,0.6285282,1.0420463,,,,,,,,,,,,,, -45900,0.6850503,1.0350693,,,,,,,,,,,,,, -46000,0.62462455,1.0638607,,,,,,,,,,,,,, -46100,0.6076955,1.0216943,,,,,,,,,,,,,, -46200,0.63897496,1.0086529,,,,,,,,,,,,,, -46300,0.8629967,1.0359268,,,,,,,,,,,,,, -46400,0.78891593,0.9935406,,,,,,,,,,,,,, -46500,0.6838359,1.021135,,,,,,,,,,,,,, -46600,0.6233028,1.0270787,,,,,,,,,,,,,, -46700,0.60372883,1.029242,,,,,,,,,,,,,, -46737,,,0.098575264,0.0401815943367189,0.37337637,0.1083252073336744,5348.0,0.20145229,0.0675156906952653,2472.0,36043.22028398514,39429.1965405941,36043.22028398514,3382.6824221611023,1.341782808303833,0.0 -46800,0.63308316,1.0388656,,,,,,,,,,,,,, -46900,0.67603415,1.0675834,,,,,,,,,,,,,, -47000,0.7361452,1.0061849,,,,,,,,,,,,,, -47100,0.66922694,1.0173602,,,,,,,,,,,,,, -47200,0.8285862,1.0461315,,,,,,,,,,,,,, -47300,0.5908765,1.0143578,,,,,,,,,,,,,, -47400,0.7073852,1.0264487,,,,,,,,,,,,,, -47500,0.726949,1.0039488,,,,,,,,,,,,,, -47600,0.8335502,0.9897351,,,,,,,,,,,,,, -47700,0.56070644,0.9955881,,,,,,,,,,,,,, -47800,0.60731864,1.0189841,,,,,,,,,,,,,, -47900,0.66843206,1.0457772,,,,,,,,,,,,,, -48000,0.63071096,0.97363627,,,,,,,,,,,,,, -48100,0.654111,0.99284065,,,,,,,,,,,,,, -48200,0.6179233,1.0173182,,,,,,,,,,,,,, -48300,0.6205833,1.0351236,,,,,,,,,,,,,, -48400,0.865633,1.0034475,,,,,,,,,,,,,, -48500,0.751478,1.0420469,,,,,,,,,,,,,, -48600,0.6403692,0.9786578,,,,,,,,,,,,,, -48609,,,0.11590645,0.0434868162140889,0.36853057,0.1070411384766888,5348.0,0.19999944,0.0676375601730546,2472.0,37483.42413520813,40999.42491674423,37483.42413520813,3512.5706765651703,1.40153169631958,0.0 -48700,0.77591085,0.991237,,,,,,,,,,,,,, -48800,0.6492449,1.0394291,,,,,,,,,,,,,, -48900,0.56747717,1.0113466,,,,,,,,,,,,,, -49000,0.7738168,1.0199552,,,,,,,,,,,,,, -49100,0.666613,0.9801124,,,,,,,,,,,,,, -49200,0.9801754,0.979632,,,,,,,,,,,,,, -49300,0.7164919,1.0037636,,,,,,,,,,,,,, -49400,0.6511286,0.9644109,,,,,,,,,,,,,, -49500,0.6367321,1.00823,,,,,,,,,,,,,, -49600,0.64437485,0.9627355,,,,,,,,,,,,,, -49700,0.77587724,1.0136496,,,,,,,,,,,,,, -49800,0.72106624,0.99548787,,,,,,,,,,,,,, -49900,0.654798,0.9854623,,,,,,,,,,,,,, -50000,0.71480685,0.95495903,,,,,,,,,,,,,, -50100,0.7757062,1.0260367,,,,,,,,,,,,,, -50200,0.84178776,0.98095757,,,,,,,,,,,,,, -50300,1.242057,1.0191745,,,,,,,,,,,,,, -50400,0.8060583,1.0111474,,,,,,,,,,,,,, -50478,,,0.07502113,0.030328978403101,0.35714394,0.1041640518647962,5348.0,0.19472961,0.0644689537505331,2472.0,38923.31358218193,42569.863436460495,38923.31358218193,3642.982671022415,1.4595589637756348,0.0 -50500,0.8938559,0.9589713,,,,,,,,,,,,,, -50600,0.6295589,0.9561788,,,,,,,,,,,,,, -50700,0.82488513,0.9601795,,,,,,,,,,,,,, -50800,0.7058269,0.96135056,,,,,,,,,,,,,, -50900,0.58856285,1.0015604,,,,,,,,,,,,,, -51000,0.7654164,1.0004207,,,,,,,,,,,,,, -51100,0.7557521,0.9685804,,,,,,,,,,,,,, -51200,0.87208873,0.9317926,,,,,,,,,,,,,, -51300,0.6369066,1.0116584,,,,,,,,,,,,,, -51400,0.8406069,1.0203885,,,,,,,,,,,,,, -51500,0.61507964,0.93226665,,,,,,,,,,,,,, -51600,0.6823546,0.93825775,,,,,,,,,,,,,, -51700,0.63712263,0.9947022,,,,,,,,,,,,,, -51800,0.79526824,0.9898341,,,,,,,,,,,,,, -51900,0.6488492,1.0057881,,,,,,,,,,,,,, -52000,0.7427779,1.0176423,,,,,,,,,,,,,, -52100,0.65455675,1.021681,,,,,,,,,,,,,, -52200,0.6075764,1.0459225,,,,,,,,,,,,,, -52300,0.6605397,0.92881024,,,,,,,,,,,,,, -52367,,,0.080772586,0.032610684468887,0.35393357,0.1022234665997277,5348.0,0.18940833,0.0632096358133772,2472.0,40363.50043487549,44140.52767682076,40363.50043487549,3773.3202497959137,1.5202560424804688,0.0 -52400,0.7543168,0.9767921,,,,,,,,,,,,,, -52500,0.7508883,0.94287044,,,,,,,,,,,,,, -52600,0.665472,1.0042602,,,,,,,,,,,,,, -52700,0.9996119,0.9053923,,,,,,,,,,,,,, -52800,0.6869204,0.946309,,,,,,,,,,,,,, -52900,1.2587084,0.9504033,,,,,,,,,,,,,, -53000,0.8780783,0.9660136,,,,,,,,,,,,,, -53100,0.8267751,0.9953571,,,,,,,,,,,,,, -53200,0.7933435,0.97895,,,,,,,,,,,,,, -53300,0.6861527,1.0095576,,,,,,,,,,,,,, -53400,0.7357064,0.9793124,,,,,,,,,,,,,, -53500,0.6950415,0.9991055,,,,,,,,,,,,,, -53600,0.8688787,0.98259985,,,,,,,,,,,,,, -53700,0.7382124,0.95499766,,,,,,,,,,,,,, -53800,0.6808088,0.97159755,,,,,,,,,,,,,, -53900,0.6310649,1.0196066,,,,,,,,,,,,,, -54000,0.7274693,0.99386793,,,,,,,,,,,,,, -54100,0.717335,0.9722315,,,,,,,,,,,,,, -54200,0.6612566,0.98565143,,,,,,,,,,,,,, -54234,,,0.09836104,0.0394453360300745,0.34822646,0.0997615300694169,5348.0,0.18704453,0.0617878252391688,2472.0,41803.85074186325,45710.62641215325,41803.85074186325,3902.9258086681366,1.5857644081115725,0.0 -54300,0.7439543,1.0008206,,,,,,,,,,,,,, -54400,0.8557452,0.9582234,,,,,,,,,,,,,, -54500,0.7073671,0.9644125,,,,,,,,,,,,,, -54600,0.9794889,0.9900554,,,,,,,,,,,,,, -54700,1.3862256,0.95576817,,,,,,,,,,,,,, -54800,0.64517856,0.9413464,,,,,,,,,,,,,, -54900,0.81165105,0.96379834,,,,,,,,,,,,,, -55000,0.90994817,0.94703823,,,,,,,,,,,,,, -55100,0.7215683,0.9742575,,,,,,,,,,,,,, -55200,0.6505152,0.98126906,,,,,,,,,,,,,, -55300,0.64808214,0.90350056,,,,,,,,,,,,,, -55400,0.8103138,0.9634934,,,,,,,,,,,,,, -55500,0.7494915,0.9831389,,,,,,,,,,,,,, -55600,0.80893,0.9551179,,,,,,,,,,,,,, -55700,0.7192845,0.96122843,,,,,,,,,,,,,, -55800,0.6445957,0.91666985,,,,,,,,,,,,,, -55900,0.7534988,0.91612023,,,,,,,,,,,,,, -56000,0.699379,0.9356199,,,,,,,,,,,,,, -56100,0.77664065,0.91231817,,,,,,,,,,,,,, -56103,,,0.096853964,0.03788420696577,0.34267035,0.0985836623960918,5348.0,0.18341146,0.0607316230983283,2472.0,43244.468445539474,47280.65885734558,43244.468445539474,4032.204874753952,1.641676664352417,0.0 -56200,0.7503607,0.9305396,,,,,,,,,,,,,, -56300,0.70807934,0.90566796,,,,,,,,,,,,,, -56400,0.8546748,0.99934727,,,,,,,,,,,,,, -56500,0.75674313,0.9440436,,,,,,,,,,,,,, -56600,0.8286761,0.9789447,,,,,,,,,,,,,, -56700,0.75605315,0.91118425,,,,,,,,,,,,,, -56800,0.9698035,0.91541845,,,,,,,,,,,,,, -56900,0.6912227,0.9195468,,,,,,,,,,,,,, -57000,0.7729079,0.97425526,,,,,,,,,,,,,, -57100,0.76106817,0.9396154,,,,,,,,,,,,,, -57200,0.6919973,0.95633495,,,,,,,,,,,,,, -57300,0.68183327,0.93890613,,,,,,,,,,,,,, -57400,0.81509334,0.95666635,,,,,,,,,,,,,, -57500,0.70953697,0.9416816,,,,,,,,,,,,,, -57600,0.7196269,0.91393584,,,,,,,,,,,,,, -57700,1.0793465,0.9393636,,,,,,,,,,,,,, -57800,0.7648051,0.960867,,,,,,,,,,,,,, -57900,0.76711845,0.9701182,,,,,,,,,,,,,, -57971,,,0.10674078,0.0438091702246148,0.34073088,0.0973768307635865,5348.0,0.18483536,0.0600613409704872,2472.0,44684.62251138687,48850.121492147446,44684.62251138687,4161.369755983353,1.7073440551757812,0.0 -58000,0.6731403,0.96943164,,,,,,,,,,,,,, -58100,0.69608396,0.97417027,,,,,,,,,,,,,, -58200,0.7499657,0.95271313,,,,,,,,,,,,,, -58300,1.0295988,0.9237245,,,,,,,,,,,,,, -58400,1.2965827,0.9744527,,,,,,,,,,,,,, -58500,0.80497754,0.9524155,,,,,,,,,,,,,, -58600,0.7942701,0.9061335,,,,,,,,,,,,,, -58700,1.212871,0.9288797,,,,,,,,,,,,,, -58800,0.8776123,0.9287515,,,,,,,,,,,,,, -58900,0.8215034,0.89850163,,,,,,,,,,,,,, -59000,0.82882196,0.9381489,,,,,,,,,,,,,, -59100,0.6765127,0.9241059,,,,,,,,,,,,,, -59200,0.6983933,0.9207246,,,,,,,,,,,,,, -59300,0.7475843,0.9399773,,,,,,,,,,,,,, -59400,0.68843704,0.9066171,,,,,,,,,,,,,, -59500,0.63417834,0.90682983,,,,,,,,,,,,,, -59600,0.77696884,0.93517816,,,,,,,,,,,,,, -59700,0.8089544,0.9301397,,,,,,,,,,,,,, -59800,0.60887176,0.88589835,,,,,,,,,,,,,, -59827,,,0.090027034,0.03508040849865,0.3356909,0.0955038280699383,5348.0,0.1787121,0.0574411471980175,2472.0,46125.16544651985,50419.62053847313,46125.16544651985,4290.192884206772,1.7623403072357178,0.0 -59900,0.7393156,0.9079789,,,,,,,,,,,,,, -60000,0.6480785,0.86970246,,,,,,,,,,,,,, -60100,1.3604589,0.90168685,,,,,,,,,,,,,, -60200,0.8032457,0.9176823,,,,,,,,,,,,,, -60300,0.8882415,0.9535337,,,,,,,,,,,,,, -60400,1.0416588,0.9316408,,,,,,,,,,,,,, -60500,0.7783153,0.9185786,,,,,,,,,,,,,, -60600,0.9722963,0.93996567,,,,,,,,,,,,,, -60700,0.65923464,0.93818504,,,,,,,,,,,,,, -60800,0.8856175,0.926546,,,,,,,,,,,,,, -60900,0.8734527,0.92309934,,,,,,,,,,,,,, -61000,0.80459255,0.8549334,,,,,,,,,,,,,, -61100,0.8443982,0.9579933,,,,,,,,,,,,,, -61200,0.74059,0.90557086,,,,,,,,,,,,,, -61300,0.93780476,0.89762896,,,,,,,,,,,,,, -61400,0.6834177,0.8861335,,,,,,,,,,,,,, -61500,0.6409201,0.8897936,,,,,,,,,,,,,, -61600,0.72564816,0.8862169,,,,,,,,,,,,,, -61700,0.9275851,0.847226,,,,,,,,,,,,,, -61701,,,0.08025563,0.0323838636989418,0.33256233,0.0946155999884144,5348.0,0.17447585,0.0569942924461235,2472.0,47566.06129670143,51990.48733663559,47566.06129670143,4420.0238037109375,1.8231210708618164,0.0 -61800,0.756742,0.8601365,,,,,,,,,,,,,, -61900,0.74528015,0.9246017,,,,,,,,,,,,,, -62000,1.13712,0.89801115,,,,,,,,,,,,,, -62100,0.83090425,0.9199423,,,,,,,,,,,,,, -62200,0.7274528,0.9184566,,,,,,,,,,,,,, -62300,0.7696151,0.89600706,,,,,,,,,,,,,, -62400,1.7900976,0.9276615,,,,,,,,,,,,,, -62500,0.8394942,0.89545643,,,,,,,,,,,,,, -62600,0.96238625,0.9512217,,,,,,,,,,,,,, -62700,0.88706654,0.8906398,,,,,,,,,,,,,, -62800,0.7883403,0.928339,,,,,,,,,,,,,, -62900,0.8856475,0.8851808,,,,,,,,,,,,,, -63000,0.7560939,0.8526318,,,,,,,,,,,,,, -63100,0.85470843,0.9241284,,,,,,,,,,,,,, -63200,0.66254294,0.8831635,,,,,,,,,,,,,, -63300,0.6656279,0.8467201,,,,,,,,,,,,,, -63400,1.0362909,0.91179675,,,,,,,,,,,,,, -63500,0.7862765,0.90667796,,,,,,,,,,,,,, -63562,,,0.06656477,0.0269001800554765,0.3317119,0.09362117072323,5348.0,0.17370649,0.0566286840127556,2472.0,49007.134853601456,53561.45655012131,49007.134853601456,4549.780838727951,1.8821280002594,0.0 -63600,1.9143107,0.9042724,,,,,,,,,,,,,, -63700,1.3102759,0.9055298,,,,,,,,,,,,,, -63800,0.69377774,0.9124351,,,,,,,,,,,,,, -63900,0.819034,0.9338188,,,,,,,,,,,,,, -64000,0.76496416,0.89029545,,,,,,,,,,,,,, -64100,0.71721005,0.8614501,,,,,,,,,,,,,, -64200,1.2248861,0.851857,,,,,,,,,,,,,, -64300,0.97129995,0.8893968,,,,,,,,,,,,,, -64400,0.67903894,0.9138454,,,,,,,,,,,,,, -64500,0.7831605,0.91130006,,,,,,,,,,,,,, -64600,0.9659938,0.90307945,,,,,,,,,,,,,, -64700,0.9526526,0.889704,,,,,,,,,,,,,, -64800,0.8265997,0.90354466,,,,,,,,,,,,,, -64900,0.710047,0.8417491,,,,,,,,,,,,,, -65000,0.7117917,0.8508695,,,,,,,,,,,,,, -65100,0.8171792,0.8904142,,,,,,,,,,,,,, -65200,0.7201123,0.8532386,,,,,,,,,,,,,, -65300,0.9511907,0.94720423,,,,,,,,,,,,,, -65400,0.7505003,0.89432436,,,,,,,,,,,,,, -65430,,,0.073906675,0.0299568072249732,0.32825038,0.0913523272541201,5348.0,0.1742499,0.0555927934515467,2472.0,50447.04723358154,55130.83956623077,50447.04723358154,4679.105979681015,1.9498560428619385,0.0 -65500,0.99662656,0.91107935,,,,,,,,,,,,,, -65600,0.76473045,0.8813818,,,,,,,,,,,,,, -65700,0.82317704,0.9187,,,,,,,,,,,,,, -65800,0.6716388,0.85304713,,,,,,,,,,,,,, -65900,0.7872169,0.8951141,,,,,,,,,,,,,, -66000,0.9808316,0.87012523,,,,,,,,,,,,,, -66100,0.8000358,0.8672932,,,,,,,,,,,,,, -66200,0.78681093,0.87208873,,,,,,,,,,,,,, -66300,0.7013726,0.91957563,,,,,,,,,,,,,, -66400,0.6400109,0.83200186,,,,,,,,,,,,,, -66500,0.86580807,0.89256203,,,,,,,,,,,,,, -66600,0.8211081,0.85576975,,,,,,,,,,,,,, -66700,0.8574062,0.8842799,,,,,,,,,,,,,, -66800,0.93429565,0.84910965,,,,,,,,,,,,,, -66900,0.69104356,0.89036167,,,,,,,,,,,,,, -67000,0.77742493,0.8987876,,,,,,,,,,,,,, -67100,0.9980823,0.8970909,,,,,,,,,,,,,, -67200,0.7874945,0.8677447,,,,,,,,,,,,,, -67298,,,0.063671306,0.0247620575339535,0.326362,0.0915261110092008,5348.0,0.17171825,0.0543537870940223,2472.0,51887.26133728027,56701.95521783829,51887.26133728027,4809.865817785263,2.0128204822540283,0.0 -67300,1.0310078,0.8458481,,,,,,,,,,,,,, -67400,0.7639495,0.8696639,,,,,,,,,,,,,, -67500,0.8844616,0.91677755,,,,,,,,,,,,,, -67600,0.71159446,0.85929346,,,,,,,,,,,,,, -67700,0.7934695,0.8600895,,,,,,,,,,,,,, -67800,1.1039333,0.8738859,,,,,,,,,,,,,, -67900,1.0119336,0.8606806,,,,,,,,,,,,,, -68000,0.7717712,0.87062764,,,,,,,,,,,,,, -68100,0.79651177,0.89710516,,,,,,,,,,,,,, -68200,0.8678168,0.92548573,,,,,,,,,,,,,, -68300,0.89840984,0.90291214,,,,,,,,,,,,,, -68400,0.81441975,0.835181,,,,,,,,,,,,,, -68500,0.918564,0.861185,,,,,,,,,,,,,, -68600,0.94012576,0.8798535,,,,,,,,,,,,,, -68700,0.95440316,0.8825995,,,,,,,,,,,,,, -68800,0.8352148,0.88035285,,,,,,,,,,,,,, -68900,0.9276267,0.8930403,,,,,,,,,,,,,, -69000,0.73808926,0.8612435,,,,,,,,,,,,,, -69100,1.0491887,0.8638701,,,,,,,,,,,,,, -69166,,,0.062218998,0.0248977123524252,0.32358345,0.089508288519652,5348.0,0.1701673,0.0543131639347592,2472.0,53327.44131612778,58271.57095623016,53327.44131612778,4939.1666939258575,2.068631172180176,0.0 -69200,0.98889995,0.86336917,,,,,,,,,,,,,, -69300,0.69144607,0.84211373,,,,,,,,,,,,,, -69400,0.73211133,0.850995,,,,,,,,,,,,,, -69500,0.9031828,0.89739627,,,,,,,,,,,,,, -69600,1.1613175,0.8425213,,,,,,,,,,,,,, -69700,0.8259482,0.8576675,,,,,,,,,,,,,, -69800,0.89907926,0.8359072,,,,,,,,,,,,,, -69900,1.6491336,0.85383177,,,,,,,,,,,,,, -70000,0.75945216,0.8539879,,,,,,,,,,,,,, -70100,1.1454104,0.8791508,,,,,,,,,,,,,, -70200,0.7590524,0.8648278,,,,,,,,,,,,,, -70300,1.0798268,0.83275455,,,,,,,,,,,,,, -70400,0.84498066,0.84021074,,,,,,,,,,,,,, -70500,0.7731346,0.8188001,,,,,,,,,,,,,, -70600,0.868701,0.83503175,,,,,,,,,,,,,, -70700,0.78572214,0.81664366,,,,,,,,,,,,,, -70800,0.8974357,0.85858417,,,,,,,,,,,,,, -70900,0.8616655,0.90344995,,,,,,,,,,,,,, -71000,0.76888025,0.8641442,,,,,,,,,,,,,, -71034,,,0.05749372,0.0227205455123496,0.32213494,0.0897496548461531,5348.0,0.17008592,0.0537038165458127,2472.0,54767.586223602295,59841.618315935135,54767.586223602295,5068.928290843964,2.1295037269592285,0.0 -71100,0.7635248,0.7980517,,,,,,,,,,,,,, -71200,0.96730375,0.8711619,,,,,,,,,,,,,, -71300,0.99905324,0.8529387,,,,,,,,,,,,,, -71400,0.65346307,0.83736175,,,,,,,,,,,,,, -71500,0.8430431,0.88188726,,,,,,,,,,,,,, -71600,0.8576097,0.86785686,,,,,,,,,,,,,, -71700,0.824647,0.82348216,,,,,,,,,,,,,, -71800,1.0113045,0.84117305,,,,,,,,,,,,,, -71900,0.8687711,0.866427,,,,,,,,,,,,,, -72000,0.7922709,0.8976156,,,,,,,,,,,,,, -72100,0.873529,0.79616314,,,,,,,,,,,,,, -72200,0.76100355,0.8524426,,,,,,,,,,,,,, -72300,0.93589664,0.8303088,,,,,,,,,,,,,, -72400,0.6965648,0.8593398,,,,,,,,,,,,,, -72500,1.2256784,0.83921,,,,,,,,,,,,,, -72600,0.91831416,0.8417374,,,,,,,,,,,,,, -72700,0.9272339,0.82005036,,,,,,,,,,,,,, -72800,0.83796334,0.8665874,,,,,,,,,,,,,, -72886,,,0.06182419,0.0240718941112291,0.3198358,0.0882145650096063,5348.0,0.16816369,0.0537647512847074,2472.0,56208.07806992531,61411.23158049584,56208.07806992531,5197.914718389511,2.186281204223633,0.0 -72900,1.9004155,0.85641587,,,,,,,,,,,,,, -73000,1.5087553,0.8634388,,,,,,,,,,,,,, -73100,0.78670937,0.83594567,,,,,,,,,,,,,, -73200,0.8427748,0.833193,,,,,,,,,,,,,, -73300,0.6972904,0.84211004,,,,,,,,,,,,,, -73400,0.7425252,0.8727868,,,,,,,,,,,,,, -73500,0.71425724,0.8596169,,,,,,,,,,,,,, -73600,0.9232825,0.8360401,,,,,,,,,,,,,, -73700,0.73236257,0.8272534,,,,,,,,,,,,,, -73800,0.7286607,0.8416148,,,,,,,,,,,,,, -73900,0.67544883,0.85962766,,,,,,,,,,,,,, -74000,0.75930804,0.8154705,,,,,,,,,,,,,, -74100,0.724993,0.87236196,,,,,,,,,,,,,, -74200,0.8997167,0.8366979,,,,,,,,,,,,,, -74300,0.8814333,0.83906347,,,,,,,,,,,,,, -74400,0.97569233,0.81072956,,,,,,,,,,,,,, -74500,0.7673435,0.8279705,,,,,,,,,,,,,, -74600,0.7253273,0.8190277,,,,,,,,,,,,,, -74700,1.2582171,0.8687561,,,,,,,,,,,,,, -74754,,,0.065528736,0.0250461231886041,0.31876612,0.0880987091728858,5348.0,0.16721477,0.0523226291308675,2472.0,57648.61340594292,62980.98971796036,57648.61340594292,5326.994605541229,2.25032901763916,0.0 -74800,1.1291815,0.86428535,,,,,,,,,,,,,, -74900,0.8666521,0.8222009,,,,,,,,,,,,,, -75000,1.775714,0.8589307,,,,,,,,,,,,,, -75100,0.94494677,0.80165064,,,,,,,,,,,,,, -75200,1.1447654,0.8446275,,,,,,,,,,,,,, -75300,0.95809096,0.80413187,,,,,,,,,,,,,, -75400,0.93952584,0.80897605,,,,,,,,,,,,,, -75500,1.003875,0.85525084,,,,,,,,,,,,,, -75600,1.320369,0.8238395,,,,,,,,,,,,,, -75700,0.9148212,0.8061432,,,,,,,,,,,,,, -75800,0.77893716,0.8286886,,,,,,,,,,,,,, -75900,0.9549231,0.8771518,,,,,,,,,,,,,, -76000,0.900809,0.8723922,,,,,,,,,,,,,, -76100,1.137376,0.8875621,,,,,,,,,,,,,, -76200,1.31406,0.84854835,,,,,,,,,,,,,, -76300,0.73701894,0.80499786,,,,,,,,,,,,,, -76400,0.8705507,0.8217105,,,,,,,,,,,,,, -76500,1.6046814,0.83122396,,,,,,,,,,,,,, -76600,0.7425134,0.875054,,,,,,,,,,,,,, -76618,,,0.055176824,0.0221612671227148,0.31730285,0.0872684090097222,5348.0,0.16727436,0.05274917230313,2472.0,59088.544001579285,64550.88621211052,59088.544001579285,5456.8192529678345,2.312541007995605,0.0 -76700,1.4755051,0.8625085,,,,,,,,,,,,,, -76800,0.90904146,0.8334552,,,,,,,,,,,,,, -76900,0.8384395,0.84036815,,,,,,,,,,,,,, -77000,0.8640219,0.8520224,,,,,,,,,,,,,, -77100,0.86040586,0.8403785,,,,,,,,,,,,,, -77200,0.9989398,0.85607815,,,,,,,,,,,,,, -77300,0.98001146,0.82421315,,,,,,,,,,,,,, -77400,0.75783664,0.81994575,,,,,,,,,,,,,, -77500,0.8189812,0.8424785,,,,,,,,,,,,,, -77600,0.96588224,0.81767654,,,,,,,,,,,,,, -77700,0.7848971,0.8494664,,,,,,,,,,,,,, -77800,0.90196353,0.83852255,,,,,,,,,,,,,, -77900,0.7346612,0.83388186,,,,,,,,,,,,,, -78000,0.990222,0.87134874,,,,,,,,,,,,,, -78100,0.76347667,0.85475266,,,,,,,,,,,,,, -78200,0.8324568,0.8169766,,,,,,,,,,,,,, -78300,0.8016517,0.8521074,,,,,,,,,,,,,, -78400,0.8072597,0.79718286,,,,,,,,,,,,,, -78485,,,0.058894917,0.0229042813032858,0.31750944,0.0873649555403226,5348.0,0.16668418,0.0527694838827615,2472.0,60528.807544231415,66118.38577461243,60528.807544231415,5583.916262388229,2.374246597290039,0.0 -78500,0.96748227,0.81660306,,,,,,,,,,,,,, -78600,1.0347056,0.8430063,,,,,,,,,,,,,, -78700,1.1458981,0.8494377,,,,,,,,,,,,,, -78800,1.1072582,0.8325872,,,,,,,,,,,,,, -78900,1.2282017,0.8361513,,,,,,,,,,,,,, -79000,1.6122153,0.813729,,,,,,,,,,,,,, -79100,0.98422813,0.81288797,,,,,,,,,,,,,, -79196,,,,,,,,,,,61068.353628873825,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 553970086..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -129.469651222229,0.0,34.07029867172241,1,0,34.07029867172241,30.214182,2472,0.9757885970791949,163.53999519348145,31.87367,1.1416312330885194,30.090126,5348,0.9587456674744392 -236.02138757705688,0.0324082374572753,1474.5567479133606,1828,0,1474.5567479133606,5.791992,2472,0.899579550301627,1710.6854240894318,5.8229604,0.9391896477614642,5.8899813,5348,0.8966179750330672 -353.74039936065674,0.0953776836395263,2914.5891761779785,3682,0,2914.5891761779785,2.893342,2472,0.6339040887209798,3268.5771906375885,3.304899,0.7357704402515723,3.243913,5348,0.7009953947304903 -482.6379177570343,0.1445662975311279,4354.786093473434,5545,0,4354.786093473434,0.80377287,2472,0.2609834866857595,4837.799213647842,1.0559078,0.3306912298852493,1.0954505,5348,0.3188642266140166 -611.7937211990356,0.1955733299255371,5795.037358760834,7405,0,5795.037358760834,0.5571465,2472,0.1881055389677655,6407.335177659988,0.68506795,0.2282069414193986,0.83268803,5348,0.2479314905818859 -739.1186428070068,0.2479183673858642,7235.207192897797,9263,0,7235.207192897797,0.47376433,2472,0.1573131842463388,7974.961220741272,0.56729215,0.1905387948281184,0.7289868,5348,0.2186778918099578 -868.9974067211151,0.300107479095459,8675.277812957764,11114,0,8675.277812957764,0.42439848,2472,0.1452481059451993,9545.040509700775,0.55909514,0.192576916815888,0.6618509,5348,0.2001023393224364 -999.3121480941772,0.3550164699554443,10115.687534570694,12982,0,10115.687534570694,0.39323118,2472,0.1341782950460057,11115.899310827255,0.47844568,0.1660066636003634,0.6252392,5348,0.189733241935951 -1129.5970721244812,0.4081838130950928,11555.559470653534,14851,0,11555.559470653534,0.36638296,2472,0.1254646273840716,12686.189148187636,0.4646131,0.15785684451186,0.59083736,5348,0.1807254506309315 -1259.186912059784,0.4650015830993652,12995.451155424118,16714,0,12995.451155424118,0.3518366,2472,0.119696138768712,14255.806090593338,0.38969892,0.137784051896975,0.5713257,5348,0.1726445060196761 -1387.4215536117554,0.5193827152252197,14435.541295289991,18576,0,14435.541295289991,0.33858824,2472,0.1145979322811935,15824.262801408768,0.39127594,0.1395070727318284,0.5567127,5348,0.1681647469998165 -1517.047952890396,0.5736007690429688,15876.112209320068,20463,0,15876.112209320068,0.32874158,2472,0.1117136879735137,17394.59325671196,0.3756434,0.1329070768267737,0.53292954,5348,0.1632312192861349 -1657.830862045288,0.6243786811828613,17316.20405101776,22339,0,17316.20405101776,0.31686965,2472,0.1066967278045213,18975.59814763069,0.25177395,0.0927705361205535,0.5271379,5348,0.1575253193276499 -1789.5025107860563,0.6770033836364746,18756.79286241532,24218,0,18756.79286241532,0.30003414,2472,0.1033046940060528,20547.990478992466,0.23562592,0.0862083217093911,0.5088276,5348,0.1535958755322127 -1919.0365262031555,0.7321634292602539,20196.840671777725,26095,0,20196.840671777725,0.2946906,2472,0.0992220665001117,22117.707132577896,0.23156746,0.0854878824686216,0.5032359,5348,0.1530841789200305 -2049.692736387253,0.7841732501983643,21636.85958075524,27962,0,21636.85958075524,0.2876642,2472,0.0954441126886437,23688.51284337044,0.21804456,0.0809259763946248,0.48922175,5348,0.146354885737181 -2181.014355182648,0.8413059711456299,23077.1240503788,29839,0,23077.1240503788,0.2838056,2472,0.095972213759064,25260.23565888405,0.21714135,0.0803820814055137,0.4858514,5348,0.1460459368392596 -2312.3158464431763,0.8909256458282471,24517.084721565247,31702,0,24517.084721565247,0.26807746,2472,0.0917067820364389,26831.625014066696,0.19798476,0.074705159769927,0.46828082,5348,0.140648985778696 -2442.287088871002,0.9477226734161376,25957.28610897064,33574,0,25957.28610897064,0.26978704,2472,0.0925598683809639,28401.9341943264,0.23919919,0.0845372933058556,0.46265563,5348,0.1399828147175531 -2573.6164152622223,1.000833511352539,27397.751355171204,35448,0,27397.751355171204,0.2617301,2472,0.088558487193549,29973.861132621765,0.19585215,0.0732812986505813,0.45211855,5348,0.1349817044324512 -2703.905078172684,1.056443452835083,28838.30406999588,37325,0,28838.30406999588,0.25678006,2472,0.0865476408100257,31544.836877584457,0.17662221,0.068772327400343,0.4406018,5348,0.1327900981878216 -2834.99561214447,1.10945463180542,30278.698662042618,39194,0,30278.698662042618,0.250659,2472,0.0842118091523977,33116.45455908775,0.17059627,0.0640816435006777,0.43647674,5348,0.1290440928005252 -2966.282071590424,1.1651394367218018,31718.72379803657,41071,0,31718.72379803657,0.24354605,2472,0.083704019661609,34687.90142393112,0.16845806,0.0658107235142118,0.4294395,5348,0.1285903241067032 -3099.5344285964966,1.2207741737365725,33159.33412575722,42942,0,33159.33412575722,0.23960043,2472,0.0789307984481953,36261.89895486832,0.1750182,0.0647522907774169,0.421468,5348,0.1254718711683095 -3231.975162744522,1.27516770362854,34599.805364370346,44822,0,34599.805364370346,0.23448966,2472,0.0786058131740905,37834.94379091263,0.1670249,0.0623548677552352,0.40870935,5348,0.1211562412504706 -3363.430745601654,1.3329179286956787,36039.968794584274,46699,0,36039.968794584274,0.2308128,2472,0.0774683647147238,39406.69979095459,0.15492408,0.0587481983056295,0.40402222,5348,0.1208183283933691 -3493.983107566833,1.3930652141571045,37480.06201171875,48569,0,37480.06201171875,0.22736883,2472,0.0756403225478845,40977.48327946663,0.14281599,0.0546745031368479,0.3916889,5348,0.1152765575369049 -3626.39207482338,1.450796365737915,38920.40566134453,50447,0,38920.40566134453,0.21504119,2472,0.0715373834623118,42550.37265348434,0.12362996,0.0471352378940698,0.38162008,5348,0.1128822035780144 -3759.255940914154,1.513139009475708,40360.53560185432,52319,0,40360.53560185432,0.21084793,2472,0.0695468486584201,44123.50813269615,0.12619652,0.0476263446627233,0.37294978,5348,0.1103140658640431 -3890.098317861557,1.5733115673065186,41801.09794139862,54199,0,41801.09794139862,0.20148818,2472,0.0674750675360022,45695.05196380615,0.12776229,0.0495070365468247,0.36549774,5348,0.1087693213744364 -4020.719121932984,1.631901502609253,43241.61392068863,56083,0,43241.61392068863,0.19801894,2472,0.0659720106432677,47266.32650756836,0.11001102,0.0422276984798891,0.35788244,5348,0.1045019647218977 -4153.070083618164,1.6941826343536377,44681.81542778015,57964,0,44681.81542778015,0.19171545,2472,0.0642455263745861,48839.02103638649,0.10896836,0.0432601685542037,0.34203002,5348,0.101615223456945 -4284.613779306412,1.749586582183838,46122.25700163841,59837,0,46122.25700163841,0.18553233,2472,0.0619909410354843,50411.139713048935,0.12385153,0.0441396256721466,0.3396015,5348,0.0995201637429159 -4417.90145611763,1.811532735824585,47562.51353478432,61721,0,47562.51353478432,0.17721963,2472,0.0592895009444884,51984.82677769661,0.079162344,0.0305428935156667,0.33002672,5348,0.0957741583556194 -4550.166195392609,1.872153282165528,49002.65753102303,63598,0,49002.65753102303,0.17631762,2472,0.057827067211017,53557.37498831749,0.07950283,0.0309600099870999,0.33140823,5348,0.0957162304372592 -4679.587248086929,1.935394525527954,50443.269728422165,65476,0,50443.269728422165,0.16963969,2472,0.0557755976682306,55127.55088472366,0.09328667,0.0359912868411498,0.3145385,5348,0.0916226575398013 -4808.630482435226,1.9940145015716555,51883.24606084824,67363,0,51883.24606084824,0.17028022,2472,0.0544756565718115,56696.70898079872,0.0919391,0.0344043891459985,0.31448692,5348,0.089740000193093 -4936.9902057647705,2.0550296306610107,53323.26822352409,69247,0,53323.26822352409,0.1640094,2472,0.0533585196920764,58265.23030591011,0.10656315,0.0412882922637434,0.30797923,5348,0.0877511416627243 -5064.64481139183,2.114377737045288,54763.13962602615,71130,0,54763.13962602615,0.16000023,2472,0.0515304775252371,59832.89411592484,0.08343864,0.0310373453686598,0.29869056,5348,0.0850092201936723 -5193.7060170173645,2.1751956939697266,56203.4166662693,73022,0,56203.4166662693,0.15819612,2472,0.0514086080474478,61402.37157559395,0.08170432,0.0314145653484378,0.2942599,5348,0.0834065477857053 -5324.08319067955,2.238858938217163,57643.36317658424,74899,0,57643.36317658424,0.1556986,2472,0.0505148985436597,62972.83807229996,0.061528895,0.0238534577830379,0.29282883,5348,0.0826148662347818 -5455.210937261581,2.299424171447754,59083.7230618,76783,0,59083.7230618,0.15388285,2472,0.0497430585176609,64544.46607041359,0.07354782,0.0285005856659624,0.29150823,5348,0.0815721637042972 -5585.668171644211,2.3606507778167725,60523.66146183014,78661,0,60523.66146183014,0.15340821,2472,0.049864927995450205,66115.00144767761,0.066440485,0.02468937875751503,0.2898537,5348,0.08116666827577551 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/measurements.csv deleted file mode 100644 index 86ad2a527..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/measurements.csv +++ /dev/null @@ -1,839 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,72.012146,32.30332,,,,,,,,,,,,,, -1,,,31.87367,1.1416312330885194,30.090126,0.9587456674744392,5348.0,30.214182,0.9757885970791949,2472.0,34.07029867172241,163.53999519348145,34.07029867172241,129.469651222229,0.0,0.0 -100,9.809667,6.9501767,,,,,,,,,,,,,, -200,2.743786,6.1262074,,,,,,,,,,,,,, -300,0.57362705,5.8401637,,,,,,,,,,,,,, -400,0.43969125,5.8465023,,,,,,,,,,,,,, -500,0.29867804,5.8103876,,,,,,,,,,,,,, -600,0.31931326,5.793688,,,,,,,,,,,,,, -700,0.2856928,5.816399,,,,,,,,,,,,,, -800,0.919538,5.795424,,,,,,,,,,,,,, -900,0.33954224,5.785499,,,,,,,,,,,,,, -1000,0.59882873,5.8058147,,,,,,,,,,,,,, -1100,0.48904017,5.7922564,,,,,,,,,,,,,, -1200,0.33760017,5.7839622,,,,,,,,,,,,,, -1300,1.2611649,5.7489543,,,,,,,,,,,,,, -1400,0.7754452,5.6233997,,,,,,,,,,,,,, -1500,0.7870281,5.5124917,,,,,,,,,,,,,, -1600,1.1806167,5.4293175,,,,,,,,,,,,,, -1700,2.1858244,5.139142,,,,,,,,,,,,,, -1800,0.6435465,4.625364,,,,,,,,,,,,,, -1828,,,5.8229604,0.9391896477614642,5.8899813,0.8966179750330672,5348.0,5.791992,0.899579550301627,2472.0,1474.5567479133606,1710.6854240894318,1474.5567479133606,236.02138757705688,0.0324082374572753,0.0 -1900,0.8577593,4.149359,,,,,,,,,,,,,, -2000,0.9398271,3.8066757,,,,,,,,,,,,,, -2100,1.3087928,3.5782278,,,,,,,,,,,,,, -2200,1.0481522,3.4404538,,,,,,,,,,,,,, -2300,1.1091897,3.2848082,,,,,,,,,,,,,, -2400,1.0123681,3.154174,,,,,,,,,,,,,, -2500,1.2067392,3.0947487,,,,,,,,,,,,,, -2600,1.0866786,2.974602,,,,,,,,,,,,,, -2700,1.5867802,2.8835056,,,,,,,,,,,,,, -2800,1.0388191,2.8921125,,,,,,,,,,,,,, -2900,1.1538297,2.77826,,,,,,,,,,,,,, -3000,1.1222498,2.7176833,,,,,,,,,,,,,, -3100,1.4865154,2.7001357,,,,,,,,,,,,,, -3200,1.1723493,2.6365154,,,,,,,,,,,,,, -3300,1.3909897,2.6278381,,,,,,,,,,,,,, -3400,1.4036034,2.487005,,,,,,,,,,,,,, -3500,1.2978461,2.535245,,,,,,,,,,,,,, -3600,1.3758405,2.4741502,,,,,,,,,,,,,, -3682,,,3.304899,0.7357704402515723,3.243913,0.7009953947304903,5348.0,2.893342,0.6339040887209798,2472.0,2914.5891761779785,3268.5771906375885,2914.5891761779785,353.74039936065674,0.0953776836395263,0.0 -3700,1.0898345,2.4191744,,,,,,,,,,,,,, -3800,1.0445346,2.389454,,,,,,,,,,,,,, -3900,1.0538033,2.467204,,,,,,,,,,,,,, -4000,1.3673629,2.336024,,,,,,,,,,,,,, -4100,1.0546099,2.3482568,,,,,,,,,,,,,, -4200,1.0024042,2.3098488,,,,,,,,,,,,,, -4300,0.918727,2.253067,,,,,,,,,,,,,, -4400,1.1084024,2.2654378,,,,,,,,,,,,,, -4500,0.9278074,2.2339573,,,,,,,,,,,,,, -4600,0.9011376,2.160984,,,,,,,,,,,,,, -4700,0.8794597,2.1625268,,,,,,,,,,,,,, -4800,0.9641215,2.0839436,,,,,,,,,,,,,, -4900,0.8212954,2.0614095,,,,,,,,,,,,,, -5000,0.8201158,2.0388763,,,,,,,,,,,,,, -5100,0.922712,2.0180707,,,,,,,,,,,,,, -5200,0.89652205,2.0065951,,,,,,,,,,,,,, -5300,0.8285133,2.0113647,,,,,,,,,,,,,, -5400,1.1893613,2.0154676,,,,,,,,,,,,,, -5500,0.8874415,2.0135744,,,,,,,,,,,,,, -5545,,,1.0559078,0.3306912298852493,1.0954505,0.3188642266140166,5348.0,0.80377287,0.2609834866857595,2472.0,4354.786093473434,4837.799213647842,4354.786093473434,482.6379177570343,0.1445662975311279,0.0 -5600,1.1251278,1.9978162,,,,,,,,,,,,,, -5700,1.0963098,2.0044103,,,,,,,,,,,,,, -5800,0.72041583,1.8817277,,,,,,,,,,,,,, -5900,0.8705225,1.9017401,,,,,,,,,,,,,, -6000,0.81243575,1.8792472,,,,,,,,,,,,,, -6100,0.83092684,1.8871282,,,,,,,,,,,,,, -6200,0.8374709,1.7926859,,,,,,,,,,,,,, -6300,0.73804617,1.8324753,,,,,,,,,,,,,, -6400,0.71629435,1.7869446,,,,,,,,,,,,,, -6500,0.878008,1.7972497,,,,,,,,,,,,,, -6600,0.799565,1.8336194,,,,,,,,,,,,,, -6700,0.8208614,1.8914375,,,,,,,,,,,,,, -6800,0.8278721,1.7661883,,,,,,,,,,,,,, -6900,0.84207594,1.8266724,,,,,,,,,,,,,, -7000,0.90593165,1.8151617,,,,,,,,,,,,,, -7100,0.7713407,1.8114752,,,,,,,,,,,,,, -7200,0.7034707,1.7452519,,,,,,,,,,,,,, -7300,0.9039763,1.7706146,,,,,,,,,,,,,, -7400,0.7591049,1.7661828,,,,,,,,,,,,,, -7405,,,0.68506795,0.2282069414193986,0.83268803,0.2479314905818859,5348.0,0.5571465,0.1881055389677655,2472.0,5795.037358760834,6407.335177659988,5795.037358760834,611.7937211990356,0.1955733299255371,0.0 -7500,0.81344724,1.7452679,,,,,,,,,,,,,, -7600,0.82476854,1.7635267,,,,,,,,,,,,,, -7700,0.67791927,1.7472513,,,,,,,,,,,,,, -7800,0.73455346,1.734127,,,,,,,,,,,,,, -7900,0.8077706,1.7147168,,,,,,,,,,,,,, -8000,0.7196705,1.724358,,,,,,,,,,,,,, -8100,0.66303295,1.7500142,,,,,,,,,,,,,, -8200,0.7669381,1.7281753,,,,,,,,,,,,,, -8300,0.80327076,1.7293972,,,,,,,,,,,,,, -8400,0.7778053,1.6578631,,,,,,,,,,,,,, -8500,0.6999327,1.6711222,,,,,,,,,,,,,, -8600,0.680302,1.7415761,,,,,,,,,,,,,, -8700,0.7875974,1.6830269,,,,,,,,,,,,,, -8800,0.62813765,1.6718904,,,,,,,,,,,,,, -8900,0.72900885,1.6819538,,,,,,,,,,,,,, -9000,0.7617155,1.6154501,,,,,,,,,,,,,, -9100,0.6404152,1.7041214,,,,,,,,,,,,,, -9200,0.8114175,1.6373781,,,,,,,,,,,,,, -9263,,,0.56729215,0.1905387948281184,0.7289868,0.2186778918099578,5348.0,0.47376433,0.1573131842463388,2472.0,7235.207192897797,7974.961220741272,7235.207192897797,739.1186428070068,0.2479183673858642,0.0 -9300,0.74407727,1.6382222,,,,,,,,,,,,,, -9400,0.619507,1.633007,,,,,,,,,,,,,, -9500,0.7460863,1.5707234,,,,,,,,,,,,,, -9600,0.6997404,1.6084687,,,,,,,,,,,,,, -9700,0.6466331,1.6155642,,,,,,,,,,,,,, -9800,0.64704245,1.6407582,,,,,,,,,,,,,, -9900,0.6486187,1.5790031,,,,,,,,,,,,,, -10000,0.7253237,1.5959778,,,,,,,,,,,,,, -10100,0.7705986,1.5723828,,,,,,,,,,,,,, -10200,0.9401489,1.6130774,,,,,,,,,,,,,, -10300,0.6009552,1.5824077,,,,,,,,,,,,,, -10400,0.68406326,1.5753645,,,,,,,,,,,,,, -10500,0.5880128,1.564933,,,,,,,,,,,,,, -10600,0.70750856,1.647777,,,,,,,,,,,,,, -10700,0.6637449,1.5837938,,,,,,,,,,,,,, -10800,0.66786563,1.5648284,,,,,,,,,,,,,, -10900,0.7594677,1.5948527,,,,,,,,,,,,,, -11000,0.64050156,1.5783098,,,,,,,,,,,,,, -11100,0.7304794,1.5910006,,,,,,,,,,,,,, -11114,,,0.55909514,0.192576916815888,0.6618509,0.2001023393224364,5348.0,0.42439848,0.1452481059451993,2472.0,8675.277812957764,9545.040509700775,8675.277812957764,868.9974067211151,0.300107479095459,0.0 -11200,0.72700876,1.5247957,,,,,,,,,,,,,, -11300,0.69690794,1.5376706,,,,,,,,,,,,,, -11400,0.7953661,1.6017214,,,,,,,,,,,,,, -11500,0.58181727,1.5473027,,,,,,,,,,,,,, -11600,0.57751936,1.568068,,,,,,,,,,,,,, -11700,0.64479697,1.5177015,,,,,,,,,,,,,, -11800,0.6821705,1.5710905,,,,,,,,,,,,,, -11900,0.6411267,1.5248287,,,,,,,,,,,,,, -12000,0.69686604,1.5837027,,,,,,,,,,,,,, -12100,0.59452474,1.5090954,,,,,,,,,,,,,, -12200,0.65600556,1.5808588,,,,,,,,,,,,,, -12300,0.64489883,1.5396819,,,,,,,,,,,,,, -12400,0.823806,1.5368266,,,,,,,,,,,,,, -12500,0.7785827,1.5132618,,,,,,,,,,,,,, -12600,0.62569475,1.5004314,,,,,,,,,,,,,, -12700,0.7240543,1.5724492,,,,,,,,,,,,,, -12800,0.72203857,1.5157475,,,,,,,,,,,,,, -12900,0.5460722,1.4520397,,,,,,,,,,,,,, -12982,,,0.47844568,0.1660066636003634,0.6252392,0.189733241935951,5348.0,0.39323118,0.1341782950460057,2472.0,10115.687534570694,11115.899310827255,10115.687534570694,999.3121480941772,0.3550164699554443,0.0 -13000,0.7651669,1.5150355,,,,,,,,,,,,,, -13100,0.6928409,1.5166422,,,,,,,,,,,,,, -13200,0.61773974,1.5081071,,,,,,,,,,,,,, -13300,0.74339026,1.5030363,,,,,,,,,,,,,, -13400,0.6768084,1.4525138,,,,,,,,,,,,,, -13500,0.7745634,1.5296553,,,,,,,,,,,,,, -13600,0.5796598,1.3883355,,,,,,,,,,,,,, -13700,0.6497941,1.4904152,,,,,,,,,,,,,, -13800,0.7261262,1.4629657,,,,,,,,,,,,,, -13900,0.6788716,1.5472264,,,,,,,,,,,,,, -14000,0.60558015,1.5390205,,,,,,,,,,,,,, -14100,0.839986,1.5347686,,,,,,,,,,,,,, -14200,0.7608592,1.5161759,,,,,,,,,,,,,, -14300,0.64963853,1.5007348,,,,,,,,,,,,,, -14400,0.65536034,1.4629916,,,,,,,,,,,,,, -14500,0.71086746,1.4664549,,,,,,,,,,,,,, -14600,0.9613596,1.4667857,,,,,,,,,,,,,, -14700,0.65209264,1.4490184,,,,,,,,,,,,,, -14800,0.77052116,1.4955881,,,,,,,,,,,,,, -14851,,,0.4646131,0.15785684451186,0.59083736,0.1807254506309315,5348.0,0.36638296,0.1254646273840716,2472.0,11555.559470653534,12686.189148187636,11555.559470653534,1129.5970721244812,0.4081838130950928,0.0 -14900,0.6471896,1.4817303,,,,,,,,,,,,,, -15000,0.6228854,1.4453628,,,,,,,,,,,,,, -15100,0.6927675,1.4514319,,,,,,,,,,,,,, -15200,0.7348059,1.4888096,,,,,,,,,,,,,, -15300,0.8975051,1.4492152,,,,,,,,,,,,,, -15400,0.6365984,1.4411149,,,,,,,,,,,,,, -15500,0.76027024,1.4568125,,,,,,,,,,,,,, -15600,0.75158095,1.4333727,,,,,,,,,,,,,, -15700,0.6969085,1.4348052,,,,,,,,,,,,,, -15800,0.67291284,1.3708817,,,,,,,,,,,,,, -15900,0.6372666,1.4460673,,,,,,,,,,,,,, -16000,0.6472868,1.4656104,,,,,,,,,,,,,, -16100,0.6185123,1.4239626,,,,,,,,,,,,,, -16200,0.6847749,1.4509337,,,,,,,,,,,,,, -16300,0.7252087,1.476291,,,,,,,,,,,,,, -16400,0.5883209,1.4422978,,,,,,,,,,,,,, -16500,0.92969805,1.4404027,,,,,,,,,,,,,, -16600,0.6418353,1.4160537,,,,,,,,,,,,,, -16700,0.6155115,1.3866452,,,,,,,,,,,,,, -16714,,,0.38969892,0.137784051896975,0.5713257,0.1726445060196761,5348.0,0.3518366,0.119696138768712,2472.0,12995.451155424118,14255.806090593338,12995.451155424118,1259.186912059784,0.4650015830993652,0.0 -16800,0.66955304,1.4228181,,,,,,,,,,,,,, -16900,0.6286915,1.3974506,,,,,,,,,,,,,, -17000,0.7272533,1.45553,,,,,,,,,,,,,, -17100,0.908411,1.5032676,,,,,,,,,,,,,, -17200,0.57353246,1.380721,,,,,,,,,,,,,, -17300,0.73348397,1.4657055,,,,,,,,,,,,,, -17400,0.75415957,1.4215674,,,,,,,,,,,,,, -17500,0.6430674,1.4012642,,,,,,,,,,,,,, -17600,0.61585736,1.423165,,,,,,,,,,,,,, -17700,0.5622988,1.3809843,,,,,,,,,,,,,, -17800,0.8767442,1.4065065,,,,,,,,,,,,,, -17900,0.77692723,1.4525201,,,,,,,,,,,,,, -18000,0.6503216,1.3771663,,,,,,,,,,,,,, -18100,0.7163668,1.4301515,,,,,,,,,,,,,, -18200,0.6484789,1.420767,,,,,,,,,,,,,, -18300,0.5987074,1.3311553,,,,,,,,,,,,,, -18400,0.657558,1.3833213,,,,,,,,,,,,,, -18500,0.78680056,1.450903,,,,,,,,,,,,,, -18576,,,0.39127594,0.1395070727318284,0.5567127,0.1681647469998165,5348.0,0.33858824,0.1145979322811935,2472.0,14435.541295289991,15824.262801408768,14435.541295289991,1387.4215536117554,0.5193827152252197,0.0 -18600,0.58931893,1.3806953,,,,,,,,,,,,,, -18700,0.6813884,1.3597337,,,,,,,,,,,,,, -18800,0.7376866,1.4345307,,,,,,,,,,,,,, -18900,0.6485083,1.3496107,,,,,,,,,,,,,, -19000,0.5819306,1.3830905,,,,,,,,,,,,,, -19100,0.6656156,1.3990312,,,,,,,,,,,,,, -19200,0.63721365,1.4018557,,,,,,,,,,,,,, -19300,0.74854726,1.4097933,,,,,,,,,,,,,, -19400,0.6467977,1.3714828,,,,,,,,,,,,,, -19500,0.69262975,1.404926,,,,,,,,,,,,,, -19600,0.6370728,1.4049202,,,,,,,,,,,,,, -19700,0.60780966,1.3764148,,,,,,,,,,,,,, -19800,0.64017254,1.3631692,,,,,,,,,,,,,, -19900,0.8016101,1.4132041,,,,,,,,,,,,,, -20000,0.6523281,1.3519951,,,,,,,,,,,,,, -20100,0.82737386,1.4172707,,,,,,,,,,,,,, -20200,0.6813381,1.3156918,,,,,,,,,,,,,, -20300,0.7152096,1.3979405,,,,,,,,,,,,,, -20400,0.70620286,1.4342242,,,,,,,,,,,,,, -20463,,,0.3756434,0.1329070768267737,0.53292954,0.1632312192861349,5348.0,0.32874158,0.1117136879735137,2472.0,15876.112209320068,17394.59325671196,15876.112209320068,1517.047952890396,0.5736007690429688,0.0 -20500,0.6296611,1.4304433,,,,,,,,,,,,,, -20600,0.6248168,1.3227668,,,,,,,,,,,,,, -20700,0.8560992,1.3796408,,,,,,,,,,,,,, -20800,0.72095674,1.3923702,,,,,,,,,,,,,, -20900,0.78266656,1.3703139,,,,,,,,,,,,,, -21000,0.6638241,1.3227067,,,,,,,,,,,,,, -21100,0.7358852,1.3678925,,,,,,,,,,,,,, -21200,0.6865049,1.3614941,,,,,,,,,,,,,, -21300,0.7993406,1.3537346,,,,,,,,,,,,,, -21400,0.6667875,1.3425701,,,,,,,,,,,,,, -21500,0.8338307,1.3354529,,,,,,,,,,,,,, -21600,0.90477353,1.3565937,,,,,,,,,,,,,, -21700,0.68266684,1.359158,,,,,,,,,,,,,, -21800,0.67893076,1.3702205,,,,,,,,,,,,,, -21900,0.7029296,1.3458252,,,,,,,,,,,,,, -22000,0.62568873,1.365729,,,,,,,,,,,,,, -22100,0.96626186,1.3739532,,,,,,,,,,,,,, -22200,0.63261026,1.3636087,,,,,,,,,,,,,, -22300,0.5991352,1.3460459,,,,,,,,,,,,,, -22339,,,0.25177395,0.0927705361205535,0.5271379,0.1575253193276499,5348.0,0.31686965,0.1066967278045213,2472.0,17316.20405101776,18975.59814763069,17316.20405101776,1657.830862045288,0.6243786811828613,0.0 -22400,0.7020129,1.4114032,,,,,,,,,,,,,, -22500,0.75287545,1.3289057,,,,,,,,,,,,,, -22600,0.7124008,1.3645303,,,,,,,,,,,,,, -22700,0.6327884,1.3733163,,,,,,,,,,,,,, -22800,0.8183611,1.3337162,,,,,,,,,,,,,, -22900,0.823595,1.3806704,,,,,,,,,,,,,, -23000,0.65113646,1.3613422,,,,,,,,,,,,,, -23100,0.5930567,1.2845429,,,,,,,,,,,,,, -23200,0.6875973,1.3624486,,,,,,,,,,,,,, -23300,0.68450195,1.3317387,,,,,,,,,,,,,, -23400,0.6819581,1.2822485,,,,,,,,,,,,,, -23500,0.694461,1.3835508,,,,,,,,,,,,,, -23600,0.69088817,1.314722,,,,,,,,,,,,,, -23700,0.67585254,1.3178153,,,,,,,,,,,,,, -23800,0.6715205,1.3311915,,,,,,,,,,,,,, -23900,0.87394375,1.3307835,,,,,,,,,,,,,, -24000,0.66729605,1.312542,,,,,,,,,,,,,, -24100,0.57790715,1.3711091,,,,,,,,,,,,,, -24200,0.6962864,1.3658271,,,,,,,,,,,,,, -24218,,,0.23562592,0.0862083217093911,0.5088276,0.1535958755322127,5348.0,0.30003414,0.1033046940060528,2472.0,18756.79286241532,20547.990478992466,18756.79286241532,1789.5025107860563,0.6770033836364746,0.0 -24300,0.64988375,1.3523822,,,,,,,,,,,,,, -24400,0.7131265,1.360835,,,,,,,,,,,,,, -24500,0.6394573,1.3009446,,,,,,,,,,,,,, -24600,0.59590673,1.3814499,,,,,,,,,,,,,, -24700,0.7306105,1.3261398,,,,,,,,,,,,,, -24800,0.80906045,1.2574674,,,,,,,,,,,,,, -24900,0.6369321,1.2857143,,,,,,,,,,,,,, -25000,0.74272656,1.3226479,,,,,,,,,,,,,, -25100,0.88415813,1.3584588,,,,,,,,,,,,,, -25200,0.68788534,1.3804305,,,,,,,,,,,,,, -25300,0.6375756,1.2959753,,,,,,,,,,,,,, -25400,0.72989696,1.3180213,,,,,,,,,,,,,, -25500,0.72963864,1.2873755,,,,,,,,,,,,,, -25600,0.8284157,1.3060329,,,,,,,,,,,,,, -25700,0.62539864,1.3191497,,,,,,,,,,,,,, -25800,0.7557669,1.3146114,,,,,,,,,,,,,, -25900,0.69741225,1.2777243,,,,,,,,,,,,,, -26000,0.7112332,1.3697206,,,,,,,,,,,,,, -26095,,,0.23156746,0.0854878824686216,0.5032359,0.1530841789200305,5348.0,0.2946906,0.0992220665001117,2472.0,20196.840671777725,22117.707132577896,20196.840671777725,1919.0365262031555,0.7321634292602539,0.0 -26100,0.7305183,1.3349856,,,,,,,,,,,,,, -26200,0.8128963,1.3045803,,,,,,,,,,,,,, -26300,0.7009803,1.3131151,,,,,,,,,,,,,, -26400,0.7160855,1.31832,,,,,,,,,,,,,, -26500,0.7028875,1.3201253,,,,,,,,,,,,,, -26600,0.72194546,1.272775,,,,,,,,,,,,,, -26700,0.66678625,1.35306,,,,,,,,,,,,,, -26800,0.69598925,1.2654282,,,,,,,,,,,,,, -26900,0.69314337,1.308116,,,,,,,,,,,,,, -27000,0.6365773,1.2083865,,,,,,,,,,,,,, -27100,0.75167423,1.345359,,,,,,,,,,,,,, -27200,0.6659731,1.301338,,,,,,,,,,,,,, -27300,0.6458611,1.363052,,,,,,,,,,,,,, -27400,0.6979941,1.2823629,,,,,,,,,,,,,, -27500,0.7061267,1.2875168,,,,,,,,,,,,,, -27600,0.6303804,1.254908,,,,,,,,,,,,,, -27700,0.72387177,1.2921109,,,,,,,,,,,,,, -27800,0.74155706,1.3193977,,,,,,,,,,,,,, -27900,0.6436671,1.3120865,,,,,,,,,,,,,, -27962,,,0.21804456,0.0809259763946248,0.48922175,0.146354885737181,5348.0,0.2876642,0.0954441126886437,2472.0,21636.85958075524,23688.51284337044,21636.85958075524,2049.692736387253,0.7841732501983643,0.0 -28000,0.7464466,1.2991048,,,,,,,,,,,,,, -28100,0.70002997,1.3233293,,,,,,,,,,,,,, -28200,0.6501411,1.2871665,,,,,,,,,,,,,, -28300,0.70810765,1.2894819,,,,,,,,,,,,,, -28400,0.69133943,1.3244138,,,,,,,,,,,,,, -28500,0.6981843,1.2949488,,,,,,,,,,,,,, -28600,0.60538644,1.2769697,,,,,,,,,,,,,, -28700,0.6824322,1.3184465,,,,,,,,,,,,,, -28800,0.68239516,1.2976017,,,,,,,,,,,,,, -28900,0.730775,1.3370792,,,,,,,,,,,,,, -29000,0.722886,1.2790486,,,,,,,,,,,,,, -29100,0.9161924,1.2860506,,,,,,,,,,,,,, -29200,0.67655647,1.3031877,,,,,,,,,,,,,, -29300,0.75596666,1.2773911,,,,,,,,,,,,,, -29400,0.7416357,1.2801708,,,,,,,,,,,,,, -29500,0.65467715,1.2570642,,,,,,,,,,,,,, -29600,0.63546,1.3131618,,,,,,,,,,,,,, -29700,0.62907755,1.2826881,,,,,,,,,,,,,, -29800,0.7099347,1.3298233,,,,,,,,,,,,,, -29839,,,0.21714135,0.0803820814055137,0.4858514,0.1460459368392596,5348.0,0.2838056,0.095972213759064,2472.0,23077.1240503788,25260.23565888405,23077.1240503788,2181.014355182648,0.8413059711456299,0.0 -29900,0.78966,1.2996165,,,,,,,,,,,,,, -30000,0.7146174,1.2699838,,,,,,,,,,,,,, -30100,0.8207277,1.2475762,,,,,,,,,,,,,, -30200,0.61599743,1.2971799,,,,,,,,,,,,,, -30300,0.7516119,1.2399924,,,,,,,,,,,,,, -30400,0.73764825,1.2832352,,,,,,,,,,,,,, -30500,0.70598054,1.231887,,,,,,,,,,,,,, -30600,0.7517042,1.3051192,,,,,,,,,,,,,, -30700,0.61973584,1.310527,,,,,,,,,,,,,, -30800,0.6162143,1.3066759,,,,,,,,,,,,,, -30900,0.7119552,1.262527,,,,,,,,,,,,,, -31000,0.6759857,1.3003949,,,,,,,,,,,,,, -31100,0.68346345,1.2793387,,,,,,,,,,,,,, -31200,0.73461246,1.3138615,,,,,,,,,,,,,, -31300,0.80627364,1.3023853,,,,,,,,,,,,,, -31400,0.65570766,1.2562897,,,,,,,,,,,,,, -31500,0.6697028,1.2733281,,,,,,,,,,,,,, -31600,0.777234,1.2644513,,,,,,,,,,,,,, -31700,0.64534765,1.2641746,,,,,,,,,,,,,, -31702,,,0.19798476,0.074705159769927,0.46828082,0.140648985778696,5348.0,0.26807746,0.0917067820364389,2472.0,24517.084721565247,26831.625014066696,24517.084721565247,2312.3158464431763,0.8909256458282471,0.0 -31800,0.65327483,1.3240316,,,,,,,,,,,,,, -31900,0.7644696,1.3198664,,,,,,,,,,,,,, -32000,0.8894232,1.2694589,,,,,,,,,,,,,, -32100,0.5811696,1.2766665,,,,,,,,,,,,,, -32200,0.69572645,1.255225,,,,,,,,,,,,,, -32300,0.692795,1.2437605,,,,,,,,,,,,,, -32400,0.6769994,1.3030144,,,,,,,,,,,,,, -32500,0.61065984,1.2246124,,,,,,,,,,,,,, -32600,0.71351016,1.2241372,,,,,,,,,,,,,, -32700,0.6532448,1.2904027,,,,,,,,,,,,,, -32800,0.59367347,1.2473598,,,,,,,,,,,,,, -32900,0.7596026,1.3099761,,,,,,,,,,,,,, -33000,0.7412244,1.2247641,,,,,,,,,,,,,, -33100,0.6969169,1.2145776,,,,,,,,,,,,,, -33200,0.67661446,1.2400799,,,,,,,,,,,,,, -33300,0.78404874,1.2804104,,,,,,,,,,,,,, -33400,0.83598745,1.2345363,,,,,,,,,,,,,, -33500,0.61790466,1.2282542,,,,,,,,,,,,,, -33574,,,0.23919919,0.0845372933058556,0.46265563,0.1399828147175531,5348.0,0.26978704,0.0925598683809639,2472.0,25957.28610897064,28401.9341943264,25957.28610897064,2442.287088871002,0.9477226734161376,0.0 -33600,0.7862155,1.2192053,,,,,,,,,,,,,, -33700,0.7857327,1.2479923,,,,,,,,,,,,,, -33800,0.75644135,1.251156,,,,,,,,,,,,,, -33900,0.7922645,1.256941,,,,,,,,,,,,,, -34000,0.68941176,1.2028023,,,,,,,,,,,,,, -34100,0.7917225,1.2738762,,,,,,,,,,,,,, -34200,0.68654466,1.1885678,,,,,,,,,,,,,, -34300,0.71180403,1.2365992,,,,,,,,,,,,,, -34400,0.71363366,1.2643473,,,,,,,,,,,,,, -34500,0.7250032,1.180156,,,,,,,,,,,,,, -34600,0.72858804,1.2652651,,,,,,,,,,,,,, -34700,0.7756827,1.2630657,,,,,,,,,,,,,, -34800,0.6843972,1.1935499,,,,,,,,,,,,,, -34900,0.6634275,1.2556145,,,,,,,,,,,,,, -35000,0.602912,1.189944,,,,,,,,,,,,,, -35100,0.71920156,1.3141813,,,,,,,,,,,,,, -35200,0.67717594,1.2249746,,,,,,,,,,,,,, -35300,0.72902,1.2327664,,,,,,,,,,,,,, -35400,0.7049541,1.2127454,,,,,,,,,,,,,, -35448,,,0.19585215,0.0732812986505813,0.45211855,0.1349817044324512,5348.0,0.2617301,0.088558487193549,2472.0,27397.751355171204,29973.861132621765,27397.751355171204,2573.6164152622223,1.000833511352539,0.0 -35500,0.6760159,1.2133114,,,,,,,,,,,,,, -35600,0.706934,1.2799538,,,,,,,,,,,,,, -35700,0.6878892,1.2294172,,,,,,,,,,,,,, -35800,0.75006473,1.2833066,,,,,,,,,,,,,, -35900,0.7643726,1.2347721,,,,,,,,,,,,,, -36000,0.7579161,1.14362,,,,,,,,,,,,,, -36100,0.7650158,1.188615,,,,,,,,,,,,,, -36200,0.7591953,1.2018638,,,,,,,,,,,,,, -36300,0.7225895,1.2193207,,,,,,,,,,,,,, -36400,0.7297614,1.2483456,,,,,,,,,,,,,, -36500,0.6961605,1.2556092,,,,,,,,,,,,,, -36600,0.95161533,1.2244791,,,,,,,,,,,,,, -36700,1.1366594,1.2142956,,,,,,,,,,,,,, -36800,0.76185,1.231671,,,,,,,,,,,,,, -36900,0.757383,1.2390025,,,,,,,,,,,,,, -37000,0.6820051,1.2658312,,,,,,,,,,,,,, -37100,0.7619694,1.1869366,,,,,,,,,,,,,, -37200,0.7043474,1.2311275,,,,,,,,,,,,,, -37300,0.7169492,1.2596045,,,,,,,,,,,,,, -37325,,,0.17662221,0.068772327400343,0.4406018,0.1327900981878216,5348.0,0.25678006,0.0865476408100257,2472.0,28838.30406999588,31544.836877584457,28838.30406999588,2703.905078172684,1.056443452835083,0.0 -37400,0.6315017,1.2190889,,,,,,,,,,,,,, -37500,0.7030598,1.1726842,,,,,,,,,,,,,, -37600,0.70637536,1.2307004,,,,,,,,,,,,,, -37700,0.7791442,1.2404711,,,,,,,,,,,,,, -37800,0.7638038,1.2216848,,,,,,,,,,,,,, -37900,0.754817,1.2532192,,,,,,,,,,,,,, -38000,0.7894506,1.2435113,,,,,,,,,,,,,, -38100,0.7215795,1.2274264,,,,,,,,,,,,,, -38200,0.80364865,1.2629887,,,,,,,,,,,,,, -38300,0.74735814,1.1804469,,,,,,,,,,,,,, -38400,0.8890505,1.1843574,,,,,,,,,,,,,, -38500,0.74619836,1.1855742,,,,,,,,,,,,,, -38600,0.7710806,1.2354674,,,,,,,,,,,,,, -38700,0.79309034,1.2802236,,,,,,,,,,,,,, -38800,0.748822,1.2045442,,,,,,,,,,,,,, -38900,0.7031322,1.2326963,,,,,,,,,,,,,, -39000,0.8280509,1.2096473,,,,,,,,,,,,,, -39100,0.70731586,1.2449368,,,,,,,,,,,,,, -39194,,,0.17059627,0.0640816435006777,0.43647674,0.1290440928005252,5348.0,0.250659,0.0842118091523977,2472.0,30278.698662042618,33116.45455908775,30278.698662042618,2834.99561214447,1.10945463180542,0.0 -39200,0.7203894,1.2062047,,,,,,,,,,,,,, -39300,0.7708856,1.179683,,,,,,,,,,,,,, -39400,0.7321108,1.1472028,,,,,,,,,,,,,, -39500,0.7232341,1.2175944,,,,,,,,,,,,,, -39600,0.8444343,1.2011347,,,,,,,,,,,,,, -39700,0.7403906,1.1695896,,,,,,,,,,,,,, -39800,0.6927232,1.1967046,,,,,,,,,,,,,, -39900,0.86058974,1.2128264,,,,,,,,,,,,,, -40000,0.73155427,1.1853218,,,,,,,,,,,,,, -40100,0.63782007,1.25249,,,,,,,,,,,,,, -40200,0.68003887,1.2048409,,,,,,,,,,,,,, -40300,0.63831836,1.1543422,,,,,,,,,,,,,, -40400,0.7563658,1.1535741,,,,,,,,,,,,,, -40500,0.76863086,1.1288241,,,,,,,,,,,,,, -40600,0.68716896,1.1933187,,,,,,,,,,,,,, -40700,0.7827848,1.1818055,,,,,,,,,,,,,, -40800,0.7279257,1.1902367,,,,,,,,,,,,,, -40900,0.8036652,1.2207831,,,,,,,,,,,,,, -41000,0.7458383,1.18848,,,,,,,,,,,,,, -41071,,,0.16845806,0.0658107235142118,0.4294395,0.1285903241067032,5348.0,0.24354605,0.083704019661609,2472.0,31718.72379803657,34687.90142393112,31718.72379803657,2966.282071590424,1.1651394367218018,0.0 -41100,0.847959,1.1563982,,,,,,,,,,,,,, -41200,0.7424565,1.180045,,,,,,,,,,,,,, -41300,0.77511925,1.2108428,,,,,,,,,,,,,, -41400,0.8293225,1.1838492,,,,,,,,,,,,,, -41500,0.95523125,1.1761402,,,,,,,,,,,,,, -41600,0.91244173,1.2118951,,,,,,,,,,,,,, -41700,0.7707203,1.158517,,,,,,,,,,,,,, -41800,0.7330937,1.1735057,,,,,,,,,,,,,, -41900,0.68560684,1.1582596,,,,,,,,,,,,,, -42000,0.81657755,1.1442491,,,,,,,,,,,,,, -42100,0.92202103,1.2086736,,,,,,,,,,,,,, -42200,0.66386896,1.116281,,,,,,,,,,,,,, -42300,0.68928355,1.1584313,,,,,,,,,,,,,, -42400,0.81610626,1.217718,,,,,,,,,,,,,, -42500,0.7778396,1.190534,,,,,,,,,,,,,, -42600,0.94740766,1.1599281,,,,,,,,,,,,,, -42700,0.790863,1.1542917,,,,,,,,,,,,,, -42800,0.7650204,1.2008588,,,,,,,,,,,,,, -42900,0.6637027,1.1674818,,,,,,,,,,,,,, -42942,,,0.1750182,0.0647522907774169,0.421468,0.1254718711683095,5348.0,0.23960043,0.0789307984481953,2472.0,33159.33412575722,36261.89895486832,33159.33412575722,3099.5344285964966,1.2207741737365725,0.0 -43000,0.728226,1.1539655,,,,,,,,,,,,,, -43100,0.7334312,1.1409624,,,,,,,,,,,,,, -43200,0.6654842,1.1785514,,,,,,,,,,,,,, -43300,0.81176376,1.0917519,,,,,,,,,,,,,, -43400,0.738715,1.1146986,,,,,,,,,,,,,, -43500,0.83510417,1.1589042,,,,,,,,,,,,,, -43600,0.7241297,1.1539471,,,,,,,,,,,,,, -43700,0.7630203,1.1693907,,,,,,,,,,,,,, -43800,0.64787954,1.119238,,,,,,,,,,,,,, -43900,0.7960302,1.2011025,,,,,,,,,,,,,, -44000,0.80224484,1.1601174,,,,,,,,,,,,,, -44100,0.8104888,1.1910717,,,,,,,,,,,,,, -44200,0.882582,1.176424,,,,,,,,,,,,,, -44300,0.7924962,1.1282002,,,,,,,,,,,,,, -44400,0.8097977,1.1586698,,,,,,,,,,,,,, -44500,0.78382206,1.167449,,,,,,,,,,,,,, -44600,0.78401774,1.1783446,,,,,,,,,,,,,, -44700,0.8742469,1.1768098,,,,,,,,,,,,,, -44800,0.7266768,1.103268,,,,,,,,,,,,,, -44822,,,0.1670249,0.0623548677552352,0.40870935,0.1211562412504706,5348.0,0.23448966,0.0786058131740905,2472.0,34599.805364370346,37834.94379091263,34599.805364370346,3231.975162744522,1.27516770362854,0.0 -44900,0.72193503,1.118478,,,,,,,,,,,,,, -45000,0.72086275,1.1378931,,,,,,,,,,,,,, -45100,0.88720614,1.2116907,,,,,,,,,,,,,, -45200,0.7980471,1.1574808,,,,,,,,,,,,,, -45300,0.7094839,1.202387,,,,,,,,,,,,,, -45400,0.7262751,1.1482811,,,,,,,,,,,,,, -45500,0.76072156,1.1072261,,,,,,,,,,,,,, -45600,0.8245161,1.1623589,,,,,,,,,,,,,, -45700,0.7719679,1.1340594,,,,,,,,,,,,,, -45800,0.7013545,1.1131582,,,,,,,,,,,,,, -45900,0.911877,1.1452546,,,,,,,,,,,,,, -46000,0.71878684,1.1641946,,,,,,,,,,,,,, -46100,0.8318959,1.1152105,,,,,,,,,,,,,, -46200,0.7713646,1.1367054,,,,,,,,,,,,,, -46300,0.8420978,1.1427372,,,,,,,,,,,,,, -46400,0.75306624,1.1399221,,,,,,,,,,,,,, -46500,0.89478654,1.2036815,,,,,,,,,,,,,, -46600,0.82498634,1.1467558,,,,,,,,,,,,,, -46699,,,0.15492408,0.0587481983056295,0.40402222,0.1208183283933691,5348.0,0.2308128,0.0774683647147238,2472.0,36039.968794584274,39406.69979095459,36039.968794584274,3363.430745601654,1.3329179286956787,0.0 -46700,0.77732146,1.1110728,,,,,,,,,,,,,, -46800,0.75487703,1.0877573,,,,,,,,,,,,,, -46900,0.7771121,1.099394,,,,,,,,,,,,,, -47000,0.80077344,1.1472862,,,,,,,,,,,,,, -47100,0.7871223,1.1896657,,,,,,,,,,,,,, -47200,0.7333071,1.1517267,,,,,,,,,,,,,, -47300,0.77484256,1.1428622,,,,,,,,,,,,,, -47400,0.7372378,1.1251483,,,,,,,,,,,,,, -47500,0.8427868,1.1701251,,,,,,,,,,,,,, -47600,0.7483216,1.0879099,,,,,,,,,,,,,, -47700,1.0110288,1.1490074,,,,,,,,,,,,,, -47800,0.8902129,1.1231071,,,,,,,,,,,,,, -47900,0.7531633,1.1170206,,,,,,,,,,,,,, -48000,1.0049291,1.0742756,,,,,,,,,,,,,, -48100,0.85390854,1.0587596,,,,,,,,,,,,,, -48200,0.94720197,1.1850394,,,,,,,,,,,,,, -48300,0.817539,1.1448824,,,,,,,,,,,,,, -48400,0.7314264,1.1252724,,,,,,,,,,,,,, -48500,0.8516779,1.0907114,,,,,,,,,,,,,, -48569,,,0.14281599,0.0546745031368479,0.3916889,0.1152765575369049,5348.0,0.22736883,0.0756403225478845,2472.0,37480.06201171875,40977.48327946663,37480.06201171875,3493.983107566833,1.3930652141571045,0.0 -48600,0.71528333,1.0716078,,,,,,,,,,,,,, -48700,0.6552494,1.0780241,,,,,,,,,,,,,, -48800,0.8807591,1.0914121,,,,,,,,,,,,,, -48900,0.853167,1.156122,,,,,,,,,,,,,, -49000,0.8022733,1.1067466,,,,,,,,,,,,,, -49100,0.8338616,1.0922343,,,,,,,,,,,,,, -49200,0.7578127,1.1218604,,,,,,,,,,,,,, -49300,0.69861984,1.1032511,,,,,,,,,,,,,, -49400,0.7535617,1.0619231,,,,,,,,,,,,,, -49500,0.7093332,1.0727514,,,,,,,,,,,,,, -49600,0.87152356,1.076573,,,,,,,,,,,,,, -49700,0.8760876,1.1020168,,,,,,,,,,,,,, -49800,0.7944216,1.100833,,,,,,,,,,,,,, -49900,0.7454066,1.0756174,,,,,,,,,,,,,, -50000,0.85418135,1.0877773,,,,,,,,,,,,,, -50100,0.9386515,1.0635227,,,,,,,,,,,,,, -50200,0.9955353,1.113438,,,,,,,,,,,,,, -50300,0.796538,1.0693957,,,,,,,,,,,,,, -50400,0.8265423,1.0788976,,,,,,,,,,,,,, -50447,,,0.12362996,0.0471352378940698,0.38162008,0.1128822035780144,5348.0,0.21504119,0.0715373834623118,2472.0,38920.40566134453,42550.37265348434,38920.40566134453,3626.39207482338,1.450796365737915,0.0 -50500,0.80736965,1.1147276,,,,,,,,,,,,,, -50600,0.9007038,1.0882914,,,,,,,,,,,,,, -50700,0.82021534,1.0420654,,,,,,,,,,,,,, -50800,0.8460136,1.1106284,,,,,,,,,,,,,, -50900,0.77023256,1.0231736,,,,,,,,,,,,,, -51000,0.7646863,1.0602891,,,,,,,,,,,,,, -51100,0.77197534,1.0859468,,,,,,,,,,,,,, -51200,0.8456194,1.0853102,,,,,,,,,,,,,, -51300,0.8116633,1.1401844,,,,,,,,,,,,,, -51400,0.84135,1.0886352,,,,,,,,,,,,,, -51500,0.8077007,1.0682616,,,,,,,,,,,,,, -51600,0.85520226,1.072887,,,,,,,,,,,,,, -51700,0.93119437,1.1056446,,,,,,,,,,,,,, -51800,1.0631527,1.0730804,,,,,,,,,,,,,, -51900,0.93426347,1.0852443,,,,,,,,,,,,,, -52000,0.8994084,1.0742178,,,,,,,,,,,,,, -52100,0.814666,1.1633675,,,,,,,,,,,,,, -52200,0.81743836,1.074265,,,,,,,,,,,,,, -52300,0.8278018,1.0492123,,,,,,,,,,,,,, -52319,,,0.12619652,0.0476263446627233,0.37294978,0.1103140658640431,5348.0,0.21084793,0.0695468486584201,2472.0,40360.53560185432,44123.50813269615,40360.53560185432,3759.255940914154,1.513139009475708,0.0 -52400,0.82489157,1.0945581,,,,,,,,,,,,,, -52500,0.8546769,1.0673038,,,,,,,,,,,,,, -52600,0.7833078,1.0602504,,,,,,,,,,,,,, -52700,1.1055549,1.0427989,,,,,,,,,,,,,, -52800,0.8107311,1.0516472,,,,,,,,,,,,,, -52900,1.10492,1.0670434,,,,,,,,,,,,,, -53000,0.83371633,1.0604894,,,,,,,,,,,,,, -53100,0.9270147,1.079688,,,,,,,,,,,,,, -53200,1.001457,1.0502092,,,,,,,,,,,,,, -53300,0.7996374,1.100762,,,,,,,,,,,,,, -53400,0.82647467,1.074052,,,,,,,,,,,,,, -53500,0.9314669,1.0536844,,,,,,,,,,,,,, -53600,0.8298566,1.0558184,,,,,,,,,,,,,, -53700,0.99776804,1.0154024,,,,,,,,,,,,,, -53800,0.91924614,1.0741656,,,,,,,,,,,,,, -53900,0.85983,1.054844,,,,,,,,,,,,,, -54000,0.83423674,1.0880616,,,,,,,,,,,,,, -54100,0.9255672,1.0548129,,,,,,,,,,,,,, -54199,,,0.12776229,0.0495070365468247,0.36549774,0.1087693213744364,5348.0,0.20148818,0.0674750675360022,2472.0,41801.09794139862,45695.05196380615,41801.09794139862,3890.098317861557,1.5733115673065186,0.0 -54200,0.89357555,1.0745227,,,,,,,,,,,,,, -54300,0.9054384,1.0875208,,,,,,,,,,,,,, -54400,0.8454584,1.0306749,,,,,,,,,,,,,, -54500,0.9190438,1.105999,,,,,,,,,,,,,, -54600,0.8858987,1.0758345,,,,,,,,,,,,,, -54700,0.8655705,1.0638535,,,,,,,,,,,,,, -54800,0.88394225,1.0116949,,,,,,,,,,,,,, -54900,0.86606383,1.0368742,,,,,,,,,,,,,, -55000,0.8688941,1.0579668,,,,,,,,,,,,,, -55100,1.249351,1.0465946,,,,,,,,,,,,,, -55200,0.98815215,1.063745,,,,,,,,,,,,,, -55300,0.91687405,1.0427822,,,,,,,,,,,,,, -55400,0.8847725,1.0300899,,,,,,,,,,,,,, -55500,0.91088456,1.0325098,,,,,,,,,,,,,, -55600,1.2827426,1.0427169,,,,,,,,,,,,,, -55700,1.134171,1.0405442,,,,,,,,,,,,,, -55800,0.8436266,0.99544865,,,,,,,,,,,,,, -55900,0.8601544,0.9905738,,,,,,,,,,,,,, -56000,0.8222178,1.0605824,,,,,,,,,,,,,, -56083,,,0.11001102,0.0422276984798891,0.35788244,0.1045019647218977,5348.0,0.19801894,0.0659720106432677,2472.0,43241.61392068863,47266.32650756836,43241.61392068863,4020.719121932984,1.631901502609253,0.0 -56100,0.90105796,1.0143288,,,,,,,,,,,,,, -56200,0.9123908,1.0117431,,,,,,,,,,,,,, -56300,0.98655826,1.0168562,,,,,,,,,,,,,, -56400,0.8871652,1.0415021,,,,,,,,,,,,,, -56500,0.9237552,1.0639006,,,,,,,,,,,,,, -56600,0.91487193,1.0233694,,,,,,,,,,,,,, -56700,0.90508425,0.99717855,,,,,,,,,,,,,, -56800,1.0952793,1.008168,,,,,,,,,,,,,, -56900,0.8966768,1.0046508,,,,,,,,,,,,,, -57000,0.9045429,1.092748,,,,,,,,,,,,,, -57100,0.87176687,1.0519296,,,,,,,,,,,,,, -57200,0.83855486,1.0464531,,,,,,,,,,,,,, -57300,0.8528148,1.0411131,,,,,,,,,,,,,, -57400,0.8766781,1.0111989,,,,,,,,,,,,,, -57500,1.1604501,1.0623937,,,,,,,,,,,,,, -57600,0.89571476,1.053277,,,,,,,,,,,,,, -57700,1.1160511,1.0265391,,,,,,,,,,,,,, -57800,0.9434673,1.0071936,,,,,,,,,,,,,, -57900,0.98451763,1.0708395,,,,,,,,,,,,,, -57964,,,0.10896836,0.0432601685542037,0.34203002,0.101615223456945,5348.0,0.19171545,0.0642455263745861,2472.0,44681.81542778015,48839.02103638649,44681.81542778015,4153.070083618164,1.6941826343536377,0.0 -58000,0.98916405,1.0653335,,,,,,,,,,,,,, -58100,0.97169286,1.0539547,,,,,,,,,,,,,, -58200,1.0853131,1.0165449,,,,,,,,,,,,,, -58300,0.91445535,0.9559185,,,,,,,,,,,,,, -58400,0.87697357,1.0364068,,,,,,,,,,,,,, -58500,0.9192248,1.0275061,,,,,,,,,,,,,, -58600,0.8480317,0.9827012,,,,,,,,,,,,,, -58700,1.0410006,0.9884954,,,,,,,,,,,,,, -58800,0.8358445,0.98387426,,,,,,,,,,,,,, -58900,0.9480191,0.97593015,,,,,,,,,,,,,, -59000,1.3629972,1.0033544,,,,,,,,,,,,,, -59100,0.95673555,0.98660827,,,,,,,,,,,,,, -59200,1.0257211,1.018179,,,,,,,,,,,,,, -59300,1.0655911,1.015467,,,,,,,,,,,,,, -59400,0.84066194,0.9577718,,,,,,,,,,,,,, -59500,1.395335,0.9916286,,,,,,,,,,,,,, -59600,0.9738567,1.0101956,,,,,,,,,,,,,, -59700,1.0894951,1.0089104,,,,,,,,,,,,,, -59800,0.9365064,0.97929627,,,,,,,,,,,,,, -59837,,,0.12385153,0.0441396256721466,0.3396015,0.0995201637429159,5348.0,0.18553233,0.0619909410354843,2472.0,46122.25700163841,50411.139713048935,46122.25700163841,4284.613779306412,1.749586582183838,0.0 -59900,0.89404976,0.95384204,,,,,,,,,,,,,, -60000,0.9129684,0.9897153,,,,,,,,,,,,,, -60100,1.066091,0.98858774,,,,,,,,,,,,,, -60200,1.1317024,1.0178806,,,,,,,,,,,,,, -60300,1.072983,1.0123311,,,,,,,,,,,,,, -60400,1.0813313,1.0337485,,,,,,,,,,,,,, -60500,1.1894472,1.0204715,,,,,,,,,,,,,, -60600,1.1339078,0.9708933,,,,,,,,,,,,,, -60700,0.979876,0.98007405,,,,,,,,,,,,,, -60800,1.107968,0.9486208,,,,,,,,,,,,,, -60900,1.0905691,0.9949385,,,,,,,,,,,,,, -61000,1.4129443,0.92428285,,,,,,,,,,,,,, -61100,1.0172689,1.0153894,,,,,,,,,,,,,, -61200,1.09772,1.0181539,,,,,,,,,,,,,, -61300,1.0552286,0.98403484,,,,,,,,,,,,,, -61400,1.0613698,0.9542082,,,,,,,,,,,,,, -61500,1.0551862,0.98489994,,,,,,,,,,,,,, -61600,0.924929,0.9816865,,,,,,,,,,,,,, -61700,1.1435335,0.97424746,,,,,,,,,,,,,, -61721,,,0.079162344,0.0305428935156667,0.33002672,0.0957741583556194,5348.0,0.17721963,0.0592895009444884,2472.0,47562.51353478432,51984.82677769661,47562.51353478432,4417.90145611763,1.811532735824585,0.0 -61800,0.9750801,0.97088474,,,,,,,,,,,,,, -61900,1.0586783,0.9611768,,,,,,,,,,,,,, -62000,1.2124792,0.95771587,,,,,,,,,,,,,, -62100,1.0926687,0.91946065,,,,,,,,,,,,,, -62200,1.107907,0.991212,,,,,,,,,,,,,, -62300,1.1264982,0.93562025,,,,,,,,,,,,,, -62400,1.0318943,1.0090729,,,,,,,,,,,,,, -62500,0.9008066,0.93873405,,,,,,,,,,,,,, -62600,1.0833887,0.9467336,,,,,,,,,,,,,, -62700,1.2923294,0.969582,,,,,,,,,,,,,, -62800,1.100978,0.99115205,,,,,,,,,,,,,, -62900,1.0778939,0.92367345,,,,,,,,,,,,,, -63000,1.0685004,0.8941504,,,,,,,,,,,,,, -63100,1.0180917,0.9437546,,,,,,,,,,,,,, -63200,1.1010624,1.008999,,,,,,,,,,,,,, -63300,1.0099237,0.9338643,,,,,,,,,,,,,, -63400,1.2160411,0.9438637,,,,,,,,,,,,,, -63500,1.1472261,0.98814183,,,,,,,,,,,,,, -63598,,,0.07950283,0.0309600099870999,0.33140823,0.0957162304372592,5348.0,0.17631762,0.057827067211017,2472.0,49002.65753102303,53557.37498831749,49002.65753102303,4550.166195392609,1.872153282165528,0.0 -63600,1.0870823,0.98810416,,,,,,,,,,,,,, -63700,1.3062059,0.91817063,,,,,,,,,,,,,, -63800,1.1438732,0.8873941,,,,,,,,,,,,,, -63900,1.0979475,0.94653654,,,,,,,,,,,,,, -64000,0.97073615,0.9205333,,,,,,,,,,,,,, -64100,0.9851654,0.9454705,,,,,,,,,,,,,, -64200,1.0469627,0.9251551,,,,,,,,,,,,,, -64300,1.1470947,0.9102031,,,,,,,,,,,,,, -64400,0.9981883,0.9635718,,,,,,,,,,,,,, -64500,0.97410965,0.9263021,,,,,,,,,,,,,, -64600,0.9692579,0.92765945,,,,,,,,,,,,,, -64700,0.9913155,0.98882955,,,,,,,,,,,,,, -64800,1.1246147,0.94557947,,,,,,,,,,,,,, -64900,1.0640038,0.9330836,,,,,,,,,,,,,, -65000,1.3257678,0.94760144,,,,,,,,,,,,,, -65100,1.2851821,0.8955631,,,,,,,,,,,,,, -65200,0.9828771,0.88733333,,,,,,,,,,,,,, -65300,1.1607844,0.9302439,,,,,,,,,,,,,, -65400,1.205517,0.9388157,,,,,,,,,,,,,, -65476,,,0.09328667,0.0359912868411498,0.3145385,0.0916226575398013,5348.0,0.16963969,0.0557755976682306,2472.0,50443.269728422165,55127.55088472366,50443.269728422165,4679.587248086929,1.935394525527954,0.0 -65500,1.0863241,0.94048697,,,,,,,,,,,,,, -65600,1.0692338,0.9201942,,,,,,,,,,,,,, -65700,1.1373491,0.92729914,,,,,,,,,,,,,, -65800,0.9657994,0.8977434,,,,,,,,,,,,,, -65900,1.1358838,0.9129579,,,,,,,,,,,,,, -66000,1.0483726,0.8797727,,,,,,,,,,,,,, -66100,1.0674983,0.95374167,,,,,,,,,,,,,, -66200,0.9353006,0.9221628,,,,,,,,,,,,,, -66300,1.1555579,0.95062834,,,,,,,,,,,,,, -66400,1.2087884,0.90402985,,,,,,,,,,,,,, -66500,1.0447721,0.915501,,,,,,,,,,,,,, -66600,1.1180947,0.9424228,,,,,,,,,,,,,, -66700,1.2320735,0.93324494,,,,,,,,,,,,,, -66800,1.0386358,0.8744582,,,,,,,,,,,,,, -66900,1.0987483,0.9324384,,,,,,,,,,,,,, -67000,1.3523307,0.96134275,,,,,,,,,,,,,, -67100,0.9834256,0.94686615,,,,,,,,,,,,,, -67200,1.268243,0.9183901,,,,,,,,,,,,,, -67300,1.1190912,0.886382,,,,,,,,,,,,,, -67363,,,0.0919391,0.0344043891459985,0.31448692,0.089740000193093,5348.0,0.17028022,0.0544756565718115,2472.0,51883.24606084824,56696.70898079872,51883.24606084824,4808.630482435226,1.9940145015716555,0.0 -67400,1.008441,0.94660646,,,,,,,,,,,,,, -67500,1.3067257,0.9489188,,,,,,,,,,,,,, -67600,1.3132352,0.9003158,,,,,,,,,,,,,, -67700,0.98330456,0.8694275,,,,,,,,,,,,,, -67800,1.1139244,0.8537238,,,,,,,,,,,,,, -67900,1.402659,0.8932399,,,,,,,,,,,,,, -68000,1.2245516,0.8824135,,,,,,,,,,,,,, -68100,1.1734898,0.8933761,,,,,,,,,,,,,, -68200,1.3229922,0.91868544,,,,,,,,,,,,,, -68300,1.2999097,0.89096427,,,,,,,,,,,,,, -68400,1.2260382,0.8874136,,,,,,,,,,,,,, -68500,1.0915847,0.93114835,,,,,,,,,,,,,, -68600,1.1956625,0.87615085,,,,,,,,,,,,,, -68700,1.0570033,0.90382195,,,,,,,,,,,,,, -68800,1.2906692,0.89621764,,,,,,,,,,,,,, -68900,1.2863113,0.8471354,,,,,,,,,,,,,, -69000,1.1573669,0.8396445,,,,,,,,,,,,,, -69100,1.1202166,0.86882675,,,,,,,,,,,,,, -69200,1.0777475,0.88309604,,,,,,,,,,,,,, -69247,,,0.10656315,0.0412882922637434,0.30797923,0.0877511416627243,5348.0,0.1640094,0.0533585196920764,2472.0,53323.26822352409,58265.23030591011,53323.26822352409,4936.9902057647705,2.0550296306610107,0.0 -69300,1.1773206,0.8856338,,,,,,,,,,,,,, -69400,1.3936069,0.87199974,,,,,,,,,,,,,, -69500,1.1015618,0.86500996,,,,,,,,,,,,,, -69600,1.214621,0.8951051,,,,,,,,,,,,,, -69700,1.4658186,0.8606898,,,,,,,,,,,,,, -69800,1.2517185,0.8456194,,,,,,,,,,,,,, -69900,1.0671495,0.8320818,,,,,,,,,,,,,, -70000,1.5654303,0.8989439,,,,,,,,,,,,,, -70100,1.1529189,0.8966408,,,,,,,,,,,,,, -70200,1.162505,0.8517051,,,,,,,,,,,,,, -70300,1.1409916,0.9235112,,,,,,,,,,,,,, -70400,1.1140823,0.86473495,,,,,,,,,,,,,, -70500,1.1028358,0.8539257,,,,,,,,,,,,,, -70600,1.0565802,0.8516429,,,,,,,,,,,,,, -70700,1.265353,0.87485224,,,,,,,,,,,,,, -70800,1.0785227,0.9240528,,,,,,,,,,,,,, -70900,1.2352288,0.89401,,,,,,,,,,,,,, -71000,1.4350015,0.8922061,,,,,,,,,,,,,, -71100,1.3168377,0.83414865,,,,,,,,,,,,,, -71130,,,0.08343864,0.0310373453686598,0.29869056,0.0850092201936723,5348.0,0.16000023,0.0515304775252371,2472.0,54763.13962602615,59832.89411592484,54763.13962602615,5064.64481139183,2.114377737045288,0.0 -71200,1.1094295,0.87806964,,,,,,,,,,,,,, -71300,1.1885285,0.8894997,,,,,,,,,,,,,, -71400,1.2148498,0.85375273,,,,,,,,,,,,,, -71500,1.271082,0.90194994,,,,,,,,,,,,,, -71600,1.1611519,0.8763012,,,,,,,,,,,,,, -71700,1.3226345,0.8485642,,,,,,,,,,,,,, -71800,1.2358077,0.87541944,,,,,,,,,,,,,, -71900,1.119484,0.88190293,,,,,,,,,,,,,, -72000,1.2705421,0.8768693,,,,,,,,,,,,,, -72100,1.3843136,0.8537264,,,,,,,,,,,,,, -72200,1.0188372,0.8305833,,,,,,,,,,,,,, -72300,1.28084,0.8869012,,,,,,,,,,,,,, -72400,1.0691926,0.8424754,,,,,,,,,,,,,, -72500,1.3386116,0.8678417,,,,,,,,,,,,,, -72600,1.5603749,0.86792576,,,,,,,,,,,,,, -72700,1.1942147,0.90193206,,,,,,,,,,,,,, -72800,1.2143145,0.85727817,,,,,,,,,,,,,, -72900,1.2439976,0.8485658,,,,,,,,,,,,,, -73000,1.2974551,0.84348476,,,,,,,,,,,,,, -73022,,,0.08170432,0.0314145653484378,0.2942599,0.0834065477857053,5348.0,0.15819612,0.0514086080474478,2472.0,56203.4166662693,61402.37157559395,56203.4166662693,5193.7060170173645,2.1751956939697266,0.0 -73100,1.2609807,0.8337954,,,,,,,,,,,,,, -73200,1.3831413,0.88889784,,,,,,,,,,,,,, -73300,1.202747,0.82135856,,,,,,,,,,,,,, -73400,1.106639,0.86539125,,,,,,,,,,,,,, -73500,1.251776,0.8678642,,,,,,,,,,,,,, -73600,1.1811042,0.8804076,,,,,,,,,,,,,, -73700,1.1967995,0.8696787,,,,,,,,,,,,,, -73800,1.3122163,0.8825916,,,,,,,,,,,,,, -73900,1.4478843,0.8594656,,,,,,,,,,,,,, -74000,1.2035288,0.8564557,,,,,,,,,,,,,, -74100,1.2514235,0.88855517,,,,,,,,,,,,,, -74200,1.3201991,0.80642116,,,,,,,,,,,,,, -74300,1.4582056,0.8281593,,,,,,,,,,,,,, -74400,1.0863627,0.8155677,,,,,,,,,,,,,, -74500,1.3312114,0.8546892,,,,,,,,,,,,,, -74600,1.3894606,0.8163259,,,,,,,,,,,,,, -74700,1.4252658,0.8323017,,,,,,,,,,,,,, -74800,1.1668189,0.84011483,,,,,,,,,,,,,, -74899,,,0.061528895,0.0238534577830379,0.29282883,0.0826148662347818,5348.0,0.1556986,0.0505148985436597,2472.0,57643.36317658424,62972.83807229996,57643.36317658424,5324.08319067955,2.238858938217163,0.0 -74900,1.3636711,0.8512164,,,,,,,,,,,,,, -75000,1.2776663,0.86535805,,,,,,,,,,,,,, -75100,1.2522354,0.8425313,,,,,,,,,,,,,, -75200,1.2582389,0.866718,,,,,,,,,,,,,, -75300,1.2293854,0.86725765,,,,,,,,,,,,,, -75400,1.3587857,0.82912415,,,,,,,,,,,,,, -75500,1.1396884,0.8596525,,,,,,,,,,,,,, -75600,1.6399882,0.8766605,,,,,,,,,,,,,, -75700,1.5607644,0.85227907,,,,,,,,,,,,,, -75800,1.3331573,0.8063433,,,,,,,,,,,,,, -75900,1.2133487,0.8908159,,,,,,,,,,,,,, -76000,1.1989535,0.8511659,,,,,,,,,,,,,, -76100,1.8879958,0.89592654,,,,,,,,,,,,,, -76200,1.2928942,0.8354638,,,,,,,,,,,,,, -76300,1.2129661,0.80099446,,,,,,,,,,,,,, -76400,1.5237266,0.8241429,,,,,,,,,,,,,, -76500,1.1231039,0.81994,,,,,,,,,,,,,, -76600,1.104404,0.8664894,,,,,,,,,,,,,, -76700,1.2374009,0.82714903,,,,,,,,,,,,,, -76783,,,0.07354782,0.0285005856659624,0.29150823,0.0815721637042972,5348.0,0.15388285,0.0497430585176609,2472.0,59083.7230618,64544.46607041359,59083.7230618,5455.210937261581,2.299424171447754,0.0 -76800,1.2053989,0.83928627,,,,,,,,,,,,,, -76900,1.1573529,0.8246659,,,,,,,,,,,,,, -77000,1.2821525,0.84886116,,,,,,,,,,,,,, -77100,1.159477,0.8675886,,,,,,,,,,,,,, -77200,1.2849633,0.82093585,,,,,,,,,,,,,, -77300,1.4223256,0.8571777,,,,,,,,,,,,,, -77400,1.2959613,0.810423,,,,,,,,,,,,,, -77500,1.2856307,0.8033776,,,,,,,,,,,,,, -77600,1.1864022,0.82791257,,,,,,,,,,,,,, -77700,1.6617459,0.8305045,,,,,,,,,,,,,, -77800,1.194972,0.8033202,,,,,,,,,,,,,, -77900,1.2507287,0.80395395,,,,,,,,,,,,,, -78000,1.5658451,0.8221376,,,,,,,,,,,,,, -78100,1.259436,0.80681574,,,,,,,,,,,,,, -78200,1.3632156,0.84460217,,,,,,,,,,,,,, -78300,1.1574031,0.8405053,,,,,,,,,,,,,, -78400,1.1895633,0.8493895,,,,,,,,,,,,,, -78500,1.8196203,0.83492243,,,,,,,,,,,,,, -78600,1.3274658,0.8586115,,,,,,,,,,,,,, -78661,,,0.066440485,0.024689378757515,0.2898537,0.0811666682757755,5348.0,0.15340821,0.0498649279954502,2472.0,60523.66146183014,66115.00144767761,60523.66146183014,5585.668171644211,2.3606507778167725,0.0 -78700,1.169797,0.80599874,,,,,,,,,,,,,, -78800,1.2853754,0.81283194,,,,,,,,,,,,,, -78900,1.3011132,0.8004138,,,,,,,,,,,,,, -79000,1.3444463,0.8032846,,,,,,,,,,,,,, -79100,1.3031027,0.80910647,,,,,,,,,,,,,, -79200,1.1924583,0.85200393,,,,,,,,,,,,,, -79300,1.1542051,0.8327467,,,,,,,,,,,,,, -79379,,,,,,,,,,,61068.697877407074,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/eval_measurements.csv deleted file mode 100644 index b35308b73..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -128.7245156764984,0.0,35.27013564109802,1,0,35.27013564109802,30.214182,2472,0.9757885970791949,163.99474143981934,30.981934,1.1274528946597675,30.090126,5348,0.9587360128213792 -236.133540391922,0.030684471130371,1475.9121260643003,1847,0,1475.9121260643003,7.232382,2472,0.899579550301627,1712.1499185562134,7.2952933,0.9413900245298448,7.2312818,5348,0.8966179750330672 -343.08443236351013,0.0878360271453857,2916.4644272327423,3713,0,2916.4644272327423,11.63979,2472,0.899579550301627,3259.789473295212,11.699928,0.9387706290361156,11.644811,5348,0.8966179750330672 -465.2628490924835,0.1403732299804687,4356.433173418045,5586,0,4356.433173418045,2.911543,2472,0.6546625231044219,4822.067915201187,3.2989602,0.7216809110434939,3.2346897,5348,0.6911669579153673 -593.5417983531952,0.1914393901824951,5796.8844957351685,7453,0,5796.8844957351685,1.5693043,2472,0.4566652448560924,6390.927451133728,1.9405282,0.5345338386931263,1.9695325,5348,0.5126138042229452 -721.8849267959595,0.2475202083587646,7237.020395278931,9310,0,7237.020395278931,1.1772131,2472,0.3633335364491296,7959.54049539566,1.5818727,0.4580297052680436,1.544094,5348,0.4266487733763287 -854.1291942596436,0.302591323852539,8677.674111127853,11181,0,8677.674111127853,1.0078455,2472,0.3222635224341397,9532.57375240326,1.2154001,0.3766602007901923,1.3427572,5348,0.3861088851772111 -984.2290835380554,0.3531553745269775,10117.671419858932,13048,0,10117.671419858932,0.9073358,2472,0.2978286921373875,11102.799887180328,1.1826919,0.3672558613082949,1.2360047,5348,0.3613060814659625 -1113.2875900268557,0.4085633754730224,11557.662242412567,14917,0,11557.662242412567,0.825852,2472,0.2695549732902728,12671.982889652252,1.1055752,0.3447297012862157,1.1436431,5348,0.3354123019589291 -1243.396124124527,0.4651896953582763,12997.665132284164,16786,0,12997.665132284164,0.76820284,2472,0.2563930696890297,14242.229521751404,1.0170374,0.3245074974715272,1.073623,5348,0.3191538662058178 -1372.4247515201569,0.5202996730804443,14437.630206108091,18642,0,14437.630206108091,0.7460582,2472,0.2492840168179879,15811.356358766556,0.9199811,0.3001748843506713,1.0531265,5348,0.3134962395126331 -1501.3820896148682,0.5721523761749268,15877.807674646378,20504,0,15877.807674646378,0.7135042,2472,0.2379298438039526,17380.621470928192,0.8688427,0.2856110859482051,1.0261918,5348,0.3053766762891375 -1632.0220003128052,0.6320977210998535,17318.290395498276,22366,0,17318.290395498276,0.66886187,2472,0.2255804033879714,18951.88121271133,0.8891562,0.2931569977350705,0.9662605,5348,0.2899099220869498 -1760.4569537639618,0.6845858097076416,18758.73726868629,24236,0,18758.73726868629,0.6421993,2472,0.2152621209351451,20520.893973350525,0.8207097,0.2707461272015434,0.93003184,5348,0.2816455390675536 -1890.377283334732,0.7414276599884033,20198.5342566967,26106,0,20198.5342566967,0.6293675,2472,0.216257388337091,22090.823832035065,0.8356706,0.2724414374372397,0.9050223,5348,0.2752927773540458 -2021.034260749817,0.7939121723175049,21638.650775671005,27969,0,21638.650775671005,0.6123362,2472,0.2079905754270509,23661.728536605835,0.7512261,0.2498455009324634,0.8914386,5348,0.26941309364048 -2150.510663509369,0.8455498218536377,23078.74978494644,29839,0,23078.74978494644,1.012293,2472,0.3135701663518372,25231.434408426285,1.465784,0.4232012911945225,1.4207444,5348,0.3876053564015177 -2279.962345123291,0.9009816646575928,24519.01113843918,31714,0,24519.01113843918,0.564204,2472,0.1911725874921292,26801.28219127655,0.7389771,0.2490417309784311,0.8475726,5348,0.2582426600500111 -2418.74117398262,0.9561948776245116,25959.022108078003,33599,0,25959.022108078003,0.5521909,2472,0.1895273495419738,28380.20565652848,0.525048,0.1873489863803322,0.8184007,5348,0.250229298010176 -2550.938821077347,1.008819818496704,27399.132900476456,35477,0,27399.132900476456,0.52806854,2472,0.1811792903134076,29952.6443362236,0.4653179,0.1710944701654082,0.8033444,5348,0.2460777971943578 -2682.19997549057,1.0661985874176023,28839.23782515525,37348,0,28839.23782515525,0.49983346,2472,0.1728312310848414,31524.1459839344,0.44106385,0.1581023314363641,0.7626492,5348,0.2333819284204022 -2813.408171415329,1.1235594749450684,30279.30905532837,39210,0,30279.30905532837,0.48848337,2472,0.1663518371823776,33095.56187868118,0.41415095,0.1518818386404201,0.74457777,5348,0.2290180252372631 -2944.7663497924805,1.187300205230713,31719.44522428513,41085,0,31719.44522428513,0.46268943,2472,0.159324030629862,34667.19851708412,0.4113284,0.1506914356935129,0.7187782,5348,0.2191123511976597 -3076.475456237793,1.2487943172454834,33159.86926102638,42952,0,33159.86926102638,0.4484339,2472,0.1555257652387626,36239.47293305397,0.38013545,0.1403506930963199,0.7010598,5348,0.2169303996060901 -3206.82887673378,1.3093175888061523,34600.46166920662,44827,0,34600.46166920662,0.4318812,2472,0.1488026323807202,37810.55836892128,0.42511886,0.1494690125217942,0.67514026,5348,0.2071695453623874 -3337.507324695587,1.3644671440124512,36040.71513128281,46706,0,36040.71513128281,0.42331767,2472,0.1455121564804095,39381.62409090996,0.36075482,0.1357082317487888,0.66591084,5348,0.2046786448728965 -3468.593843460083,1.4214856624603271,37481.05079865456,48574,0,37481.05079865456,0.40389547,2472,0.1369406698758962,40953.1812107563,0.33703363,0.1252548560952095,0.63967717,5348,0.1958349826698977 -3599.5541915893555,1.4804182052612305,38921.05453419685,50440,0,38921.05453419685,0.38614887,2472,0.1320049560254301,42524.28307437897,0.30615988,0.1129864016459773,0.6176672,5348,0.1900711547930525 -3731.398805141449,1.538435935974121,40361.0865046978,52298,0,40361.0865046978,0.3775948,2472,0.1291207117177503,44096.294939517975,0.29809442,0.1123178510033256,0.604947,5348,0.1849252247120499 -3864.732713460922,1.5968918800354004,41801.07320189476,54167,0,41801.07320189476,0.35808083,2472,0.1219710356874454,45669.7519903183,0.30563256,0.1127773453742896,0.58232206,5348,0.1802620272840495 -3996.17692565918,1.6657493114471436,43241.62187099457,56037,0,43241.62187099457,0.34402737,2472,0.1187008713667662,47241.89470553398,0.28636116,0.1062314602486693,0.56730247,5348,0.1750774785908068 -4127.185501813889,1.7246437072753906,44681.83616948128,57906,0,44681.83616948128,0.32470948,2472,0.1107590437308309,48813.25409722328,0.26780295,0.0981005955978449,0.5353414,5348,0.1650366394083628 -4258.248637199402,1.7822024822235107,46122.00344848633,59768,0,46122.00344848633,0.30931154,2472,0.1048483740580504,50384.61937427521,0.2401949,0.0903859619566075,0.51465917,5348,0.157988742674532 -4389.823725938797,1.842402458190918,47561.935584545135,61647,0,47561.935584545135,0.29717535,2472,0.1007454349724778,51956.26599335671,0.22045113,0.0828239963686649,0.49300355,5348,0.1518483833283451 -4522.307675123215,1.9072229862213133,49002.48093056679,63519,0,49002.48093056679,0.28151563,2472,0.0974346474925355,53529.43874812126,0.20646651,0.0763023078873953,0.47628415,5348,0.1478513569614876 -4652.021373748779,1.967803716659546,50443.06334590912,65400,0,50443.06334590912,0.2692499,2472,0.0920114557309122,55099.87467169762,0.19849926,0.0751354364544603,0.45558718,5348,0.1399828147175531 -4782.024187803268,2.0273022651672363,51883.29848217964,67276,0,51883.29848217964,0.25713265,2472,0.0860195397396055,56670.25017309189,0.17537242,0.0655394871243481,0.4380229,5348,0.1361595721057763 -4912.617544412613,2.0856642723083496,53323.735122442245,69151,0,53323.735122442245,0.24621879,2472,0.0841914975727662,58241.41588020325,0.1733773,0.0655802799387666,0.4252374,5348,0.129690954555548 -5044.270725250244,2.151578187942505,54764.13278627396,71024,0,54764.13278627396,0.2360904,2472,0.0806369711372453,59813.611429452896,0.18672377,0.0644929298911721,0.41258916,5348,0.1245932977398457 -5178.692124843597,2.211843967437744,56204.91225242615,72889,0,56204.91225242615,0.22824389,2472,0.0783214510592488,61388.951934337616,0.13206606,0.0492559071545209,0.4033787,5348,0.1225947845564169 -5312.1723392009735,2.276295185089112,57645.34727025032,74767,0,57645.34727025032,0.22153029,2472,0.0754575183312006,62963.01133060455,0.13103484,0.0474115181823407,0.39362144,5348,0.1193315118221226 -5441.886587142944,2.340604066848755,59085.85717082024,76649,0,59085.85717082024,0.21700285,2472,0.0736904109032559,64533.37995290756,0.1626185,0.061350264086916,0.3823941,5348,0.1163192600673894 -5571.456894159317,2.406038999557495,60526.26091170311,78528,0,60526.26091170311,0.21509795,2472,0.07287794771799402,66103.49896073341,0.16589448,0.06081247166399922,0.38134545,5348,0.11600065651640808 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/measurements.csv deleted file mode 100644 index 9acbb87c6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/measurements.csv +++ /dev/null @@ -1,838 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,67.55436,31.791185,,,,,,,,,,,,,, -1,,,30.981934,1.1274528946597675,30.090126,0.9587360128213792,5348.0,30.214182,0.9757885970791949,2472.0,35.27013564109802,163.99474143981934,35.27013564109802,128.7245156764984,0.0,0.0 -100,3.72571,5.803013,,,,,,,,,,,,,, -200,3.0470178,5.787939,,,,,,,,,,,,,, -300,3.53603,5.6608405,,,,,,,,,,,,,, -400,1.5516669,5.572333,,,,,,,,,,,,,, -500,3.3740325,5.563582,,,,,,,,,,,,,, -600,15.072767,6.6580033,,,,,,,,,,,,,, -700,1.8052535,5.530485,,,,,,,,,,,,,, -800,0.80869347,5.5402246,,,,,,,,,,,,,, -900,3.6815965,5.559422,,,,,,,,,,,,,, -1000,0.27819717,5.508923,,,,,,,,,,,,,, -1100,1.2595199,5.496183,,,,,,,,,,,,,, -1200,0.80130035,5.4995995,,,,,,,,,,,,,, -1300,0.3131474,5.485476,,,,,,,,,,,,,, -1400,1.76882,5.5011196,,,,,,,,,,,,,, -1500,0.3325946,5.4755564,,,,,,,,,,,,,, -1600,0.7654036,5.486213,,,,,,,,,,,,,, -1700,0.76713806,5.479709,,,,,,,,,,,,,, -1800,1.2978945,5.4784017,,,,,,,,,,,,,, -1847,,,7.2952933,0.9413900245298448,7.2312818,0.8966179750330672,5348.0,7.232382,0.899579550301627,2472.0,1475.9121260643003,1712.1499185562134,1475.9121260643003,236.133540391922,0.030684471130371,0.0 -1900,2.0779345,5.4681044,,,,,,,,,,,,,, -2000,0.81034684,5.4603343,,,,,,,,,,,,,, -2100,0.87061113,5.426196,,,,,,,,,,,,,, -2200,0.8879874,5.4473767,,,,,,,,,,,,,, -2300,1.8847954,5.4506774,,,,,,,,,,,,,, -2400,0.33070737,5.4772563,,,,,,,,,,,,,, -2500,1.572093,5.4310207,,,,,,,,,,,,,, -2600,0.8287665,5.3903346,,,,,,,,,,,,,, -2700,1.7828455,5.405651,,,,,,,,,,,,,, -2800,0.9819796,5.3317456,,,,,,,,,,,,,, -2900,0.85722405,5.2776823,,,,,,,,,,,,,, -3000,0.34418848,5.1655855,,,,,,,,,,,,,, -3100,1.4532989,4.8103857,,,,,,,,,,,,,, -3200,0.69472736,4.4364867,,,,,,,,,,,,,, -3300,1.8212528,4.249783,,,,,,,,,,,,,, -3400,0.5125106,4.105883,,,,,,,,,,,,,, -3500,1.9962108,4.035677,,,,,,,,,,,,,, -3600,1.8489528,3.8294482,,,,,,,,,,,,,, -3700,1.1802571,3.784126,,,,,,,,,,,,,, -3713,,,11.699928,0.9387706290361156,11.644811,0.8966179750330672,5348.0,11.63979,0.899579550301627,2472.0,2916.4644272327423,3259.789473295212,2916.4644272327423,343.08443236351013,0.0878360271453857,0.0 -3800,2.1757126,3.689537,,,,,,,,,,,,,, -3900,1.0521228,3.6404173,,,,,,,,,,,,,, -4000,1.1126899,3.6015139,,,,,,,,,,,,,, -4100,1.7036796,3.560375,,,,,,,,,,,,,, -4200,1.6698725,3.4257536,,,,,,,,,,,,,, -4300,1.040254,3.3352163,,,,,,,,,,,,,, -4400,1.9890991,3.4382303,,,,,,,,,,,,,, -4500,1.1272839,3.3786447,,,,,,,,,,,,,, -4600,1.3011308,3.3200543,,,,,,,,,,,,,, -4700,1.2129987,3.2897472,,,,,,,,,,,,,, -4800,1.104607,3.1835515,,,,,,,,,,,,,, -4900,0.89952844,3.175308,,,,,,,,,,,,,, -5000,0.8386799,3.1300535,,,,,,,,,,,,,, -5100,1.3449881,3.1173866,,,,,,,,,,,,,, -5200,0.9182592,3.013253,,,,,,,,,,,,,, -5300,1.230735,3.091093,,,,,,,,,,,,,, -5400,1.3341724,2.9808555,,,,,,,,,,,,,, -5500,1.5607738,2.9961505,,,,,,,,,,,,,, -5586,,,3.2989602,0.7216809110434939,3.2346897,0.6911669579153673,5348.0,2.911543,0.6546625231044219,2472.0,4356.433173418045,4822.067915201187,4356.433173418045,465.2628490924835,0.1403732299804687,0.0 -5600,0.9043031,2.9467938,,,,,,,,,,,,,, -5700,0.7780536,3.0500965,,,,,,,,,,,,,, -5800,0.78376853,2.8951368,,,,,,,,,,,,,, -5900,1.6015613,2.9268792,,,,,,,,,,,,,, -6000,0.62884194,2.9137573,,,,,,,,,,,,,, -6100,0.7835719,2.7937262,,,,,,,,,,,,,, -6200,0.5014314,2.8639112,,,,,,,,,,,,,, -6300,1.0376483,2.9031851,,,,,,,,,,,,,, -6400,1.8525767,2.8498988,,,,,,,,,,,,,, -6500,1.7162966,2.8555176,,,,,,,,,,,,,, -6600,1.0403353,2.8310952,,,,,,,,,,,,,, -6700,0.9055046,2.764169,,,,,,,,,,,,,, -6800,1.1277287,2.6851194,,,,,,,,,,,,,, -6900,1.3427292,2.7597048,,,,,,,,,,,,,, -7000,0.9897511,2.7544918,,,,,,,,,,,,,, -7100,0.60501623,2.717206,,,,,,,,,,,,,, -7200,0.9243747,2.754,,,,,,,,,,,,,, -7300,0.5630936,2.6733556,,,,,,,,,,,,,, -7400,0.6966546,2.62719,,,,,,,,,,,,,, -7453,,,1.9405282,0.5345338386931263,1.9695325,0.5126138042229452,5348.0,1.5693043,0.4566652448560924,2472.0,5796.8844957351685,6390.927451133728,5796.8844957351685,593.5417983531952,0.1914393901824951,0.0 -7500,1.0460331,2.6782749,,,,,,,,,,,,,, -7600,1.1907368,2.6037526,,,,,,,,,,,,,, -7700,0.8393955,2.7339876,,,,,,,,,,,,,, -7800,0.6177327,2.5791502,,,,,,,,,,,,,, -7900,0.51125,2.5980582,,,,,,,,,,,,,, -8000,0.5654107,2.623713,,,,,,,,,,,,,, -8100,0.9464888,2.6776967,,,,,,,,,,,,,, -8200,1.5502437,2.5717778,,,,,,,,,,,,,, -8300,2.38831,2.5697706,,,,,,,,,,,,,, -8400,1.0634977,2.5391405,,,,,,,,,,,,,, -8500,1.0162416,2.573908,,,,,,,,,,,,,, -8600,0.595501,2.5432684,,,,,,,,,,,,,, -8700,0.9608967,2.5387204,,,,,,,,,,,,,, -8800,1.0919963,2.5194912,,,,,,,,,,,,,, -8900,1.3226595,2.554548,,,,,,,,,,,,,, -9000,1.2461963,2.5112329,,,,,,,,,,,,,, -9100,1.0113212,2.4992251,,,,,,,,,,,,,, -9200,1.188775,2.4539115,,,,,,,,,,,,,, -9300,0.60748255,2.410587,,,,,,,,,,,,,, -9310,,,1.5818727,0.4580297052680436,1.544094,0.4266487733763287,5348.0,1.1772131,0.3633335364491296,2472.0,7237.020395278931,7959.54049539566,7237.020395278931,721.8849267959595,0.2475202083587646,0.0 -9400,0.91692775,2.4105258,,,,,,,,,,,,,, -9500,1.3772424,2.5040386,,,,,,,,,,,,,, -9600,0.8374198,2.4423766,,,,,,,,,,,,,, -9700,0.51080877,2.4221263,,,,,,,,,,,,,, -9800,0.80947775,2.3375568,,,,,,,,,,,,,, -9900,1.5222471,2.4195464,,,,,,,,,,,,,, -10000,1.1900305,2.415318,,,,,,,,,,,,,, -10100,0.7093966,2.445229,,,,,,,,,,,,,, -10200,1.1552434,2.4025931,,,,,,,,,,,,,, -10300,1.079242,2.3266604,,,,,,,,,,,,,, -10400,1.0646168,2.2886107,,,,,,,,,,,,,, -10500,0.7029452,2.2848485,,,,,,,,,,,,,, -10600,1.5761355,2.4199128,,,,,,,,,,,,,, -10700,0.77075,2.2765434,,,,,,,,,,,,,, -10800,0.9634893,2.3352315,,,,,,,,,,,,,, -10900,0.9177132,2.3026028,,,,,,,,,,,,,, -11000,1.4262773,2.3143468,,,,,,,,,,,,,, -11100,0.83334273,2.3365269,,,,,,,,,,,,,, -11181,,,1.2154001,0.3766602007901923,1.3427572,0.3861088851772111,5348.0,1.0078455,0.3222635224341397,2472.0,8677.674111127853,9532.57375240326,8677.674111127853,854.1291942596436,0.302591323852539,0.0 -11200,1.0091016,2.3148162,,,,,,,,,,,,,, -11300,0.48524618,2.2645266,,,,,,,,,,,,,, -11400,1.0064403,2.2486117,,,,,,,,,,,,,, -11500,1.4809952,2.2715442,,,,,,,,,,,,,, -11600,1.2214826,2.2789268,,,,,,,,,,,,,, -11700,0.7086453,2.2533996,,,,,,,,,,,,,, -11800,0.51476413,2.209849,,,,,,,,,,,,,, -11900,1.4311421,2.2829597,,,,,,,,,,,,,, -12000,0.7868106,2.2932875,,,,,,,,,,,,,, -12100,0.8662822,2.2157838,,,,,,,,,,,,,, -12200,0.62233514,2.189412,,,,,,,,,,,,,, -12300,0.5603915,2.205083,,,,,,,,,,,,,, -12400,0.8717522,2.15639,,,,,,,,,,,,,, -12500,0.767371,2.209862,,,,,,,,,,,,,, -12600,1.1687635,2.209325,,,,,,,,,,,,,, -12700,0.96542966,2.2855465,,,,,,,,,,,,,, -12800,0.7787322,2.1870143,,,,,,,,,,,,,, -12900,0.66731596,2.1276953,,,,,,,,,,,,,, -13000,0.5832878,2.1773715,,,,,,,,,,,,,, -13048,,,1.1826919,0.3672558613082949,1.2360047,0.3613060814659625,5348.0,0.9073358,0.2978286921373875,2472.0,10117.671419858932,11102.799887180328,10117.671419858932,984.2290835380554,0.3531553745269775,0.0 -13100,0.8307735,2.1977952,,,,,,,,,,,,,, -13200,0.6429228,2.144509,,,,,,,,,,,,,, -13300,0.6200258,2.0897615,,,,,,,,,,,,,, -13400,0.91055995,2.0930061,,,,,,,,,,,,,, -13500,1.0512888,2.1231315,,,,,,,,,,,,,, -13600,0.87938267,2.091014,,,,,,,,,,,,,, -13700,1.0332786,2.129965,,,,,,,,,,,,,, -13800,0.8048786,2.130906,,,,,,,,,,,,,, -13900,0.49880984,2.0981221,,,,,,,,,,,,,, -14000,0.73496175,2.136628,,,,,,,,,,,,,, -14100,0.7260825,2.1191955,,,,,,,,,,,,,, -14200,0.6527556,2.1294923,,,,,,,,,,,,,, -14300,0.66082156,2.118391,,,,,,,,,,,,,, -14400,0.9476323,2.1079743,,,,,,,,,,,,,, -14500,0.9990048,2.139726,,,,,,,,,,,,,, -14600,0.78858584,2.102261,,,,,,,,,,,,,, -14700,0.7239432,2.1553018,,,,,,,,,,,,,, -14800,0.8545574,2.140127,,,,,,,,,,,,,, -14900,1.1507902,2.0762546,,,,,,,,,,,,,, -14917,,,1.1055752,0.3447297012862157,1.1436431,0.3354123019589291,5348.0,0.825852,0.2695549732902728,2472.0,11557.662242412567,12671.982889652252,11557.662242412567,1113.2875900268557,0.4085633754730224,0.0 -15000,3.9381175,2.0593145,,,,,,,,,,,,,, -15100,0.7411221,2.0436113,,,,,,,,,,,,,, -15200,1.1717421,2.2039504,,,,,,,,,,,,,, -15300,0.8956037,2.0652552,,,,,,,,,,,,,, -15400,0.70350873,2.030845,,,,,,,,,,,,,, -15500,0.6155662,2.1425445,,,,,,,,,,,,,, -15600,0.5962907,2.0646386,,,,,,,,,,,,,, -15700,0.64530456,2.086664,,,,,,,,,,,,,, -15800,0.9589931,2.0413644,,,,,,,,,,,,,, -15900,0.5996679,2.0686874,,,,,,,,,,,,,, -16000,0.6249959,2.099434,,,,,,,,,,,,,, -16100,0.9221334,1.9658229,,,,,,,,,,,,,, -16200,0.95674294,1.992094,,,,,,,,,,,,,, -16300,0.60006446,2.0603912,,,,,,,,,,,,,, -16400,0.95331335,2.0661485,,,,,,,,,,,,,, -16500,0.68477744,2.0220017,,,,,,,,,,,,,, -16600,0.7877813,2.0156722,,,,,,,,,,,,,, -16700,0.70409197,2.0311842,,,,,,,,,,,,,, -16786,,,1.0170374,0.3245074974715272,1.073623,0.3191538662058178,5348.0,0.76820284,0.2563930696890297,2472.0,12997.665132284164,14242.229521751404,12997.665132284164,1243.396124124527,0.4651896953582763,0.0 -16800,0.6729585,2.0221786,,,,,,,,,,,,,, -16900,1.0111175,2.0679357,,,,,,,,,,,,,, -17000,0.5938512,2.0254204,,,,,,,,,,,,,, -17100,0.7899421,2.067398,,,,,,,,,,,,,, -17200,2.0155427,2.017041,,,,,,,,,,,,,, -17300,0.54902726,1.9959816,,,,,,,,,,,,,, -17400,0.9598184,2.0207262,,,,,,,,,,,,,, -17500,0.8134071,2.0079637,,,,,,,,,,,,,, -17600,0.6850101,2.0777295,,,,,,,,,,,,,, -17700,0.89319396,2.0386076,,,,,,,,,,,,,, -17800,1.5498804,2.0837908,,,,,,,,,,,,,, -17900,0.81932807,2.0483186,,,,,,,,,,,,,, -18000,1.159328,1.9927151,,,,,,,,,,,,,, -18100,0.5003971,1.9427881,,,,,,,,,,,,,, -18200,0.67950624,2.0142446,,,,,,,,,,,,,, -18300,0.6753151,2.0286787,,,,,,,,,,,,,, -18400,1.0239197,1.9950895,,,,,,,,,,,,,, -18500,1.1279103,2.0284293,,,,,,,,,,,,,, -18600,0.70435345,1.9907932,,,,,,,,,,,,,, -18642,,,0.9199811,0.3001748843506713,1.0531265,0.3134962395126331,5348.0,0.7460582,0.2492840168179879,2472.0,14437.630206108091,15811.356358766556,14437.630206108091,1372.4247515201569,0.5202996730804443,0.0 -18700,0.92389387,1.9594675,,,,,,,,,,,,,, -18800,1.1527039,2.0021174,,,,,,,,,,,,,, -18900,0.9548838,1.944743,,,,,,,,,,,,,, -19000,1.117866,2.0042121,,,,,,,,,,,,,, -19100,0.60768837,1.9522113,,,,,,,,,,,,,, -19200,0.64046085,1.9845643,,,,,,,,,,,,,, -19300,1.0229679,2.0176342,,,,,,,,,,,,,, -19400,1.094496,1.9969474,,,,,,,,,,,,,, -19500,0.66168684,2.0316746,,,,,,,,,,,,,, -19600,0.7458128,1.9497368,,,,,,,,,,,,,, -19700,0.59458184,1.9562646,,,,,,,,,,,,,, -19800,0.7296847,2.0211327,,,,,,,,,,,,,, -19900,0.63739634,1.9878455,,,,,,,,,,,,,, -20000,0.85933614,1.946222,,,,,,,,,,,,,, -20100,0.64769953,1.955032,,,,,,,,,,,,,, -20200,0.96043175,1.9513978,,,,,,,,,,,,,, -20300,1.0744253,1.9569659,,,,,,,,,,,,,, -20400,0.64698887,2.0363703,,,,,,,,,,,,,, -20500,0.98077226,1.9576076,,,,,,,,,,,,,, -20504,,,0.8688427,0.2856110859482051,1.0261918,0.3053766762891375,5348.0,0.7135042,0.2379298438039526,2472.0,15877.807674646378,17380.621470928192,15877.807674646378,1501.3820896148682,0.5721523761749268,0.0 -20600,1.002123,2.0245626,,,,,,,,,,,,,, -20700,0.57551575,1.9244252,,,,,,,,,,,,,, -20800,0.7861561,1.9431067,,,,,,,,,,,,,, -20900,0.6697211,1.9060876,,,,,,,,,,,,,, -21000,0.8894846,1.9901361,,,,,,,,,,,,,, -21100,0.6089118,1.9165008,,,,,,,,,,,,,, -21200,0.57303625,1.9136088,,,,,,,,,,,,,, -21300,0.7388395,1.942788,,,,,,,,,,,,,, -21400,0.89419967,1.8842101,,,,,,,,,,,,,, -21500,0.81434953,1.957804,,,,,,,,,,,,,, -21600,0.69466734,1.9175631,,,,,,,,,,,,,, -21700,0.827773,1.9469641,,,,,,,,,,,,,, -21800,0.9471056,1.9119573,,,,,,,,,,,,,, -21900,0.7661786,1.9138019,,,,,,,,,,,,,, -22000,0.8845927,1.9531039,,,,,,,,,,,,,, -22100,0.5587651,1.9364707,,,,,,,,,,,,,, -22200,1.0842487,1.9495294,,,,,,,,,,,,,, -22300,0.6974019,1.9047476,,,,,,,,,,,,,, -22366,,,0.8891562,0.2931569977350705,0.9662605,0.2899099220869498,5348.0,0.66886187,0.2255804033879714,2472.0,17318.290395498276,18951.88121271133,17318.290395498276,1632.0220003128052,0.6320977210998535,0.0 -22400,0.7633195,1.9132713,,,,,,,,,,,,,, -22500,0.6091135,1.8433439,,,,,,,,,,,,,, -22600,0.7173611,1.878743,,,,,,,,,,,,,, -22700,0.9196651,1.9007022,,,,,,,,,,,,,, -22800,0.57791805,1.8617271,,,,,,,,,,,,,, -22900,0.72006685,1.9102864,,,,,,,,,,,,,, -23000,0.7807686,1.8497436,,,,,,,,,,,,,, -23100,0.6632601,1.7978095,,,,,,,,,,,,,, -23200,0.55778664,1.8925918,,,,,,,,,,,,,, -23300,1.1107517,1.9378031,,,,,,,,,,,,,, -23400,0.87639064,1.8515903,,,,,,,,,,,,,, -23500,0.9392972,1.9070532,,,,,,,,,,,,,, -23600,0.6295261,1.9002829,,,,,,,,,,,,,, -23700,0.86869293,1.8106154,,,,,,,,,,,,,, -23800,0.5311121,1.880173,,,,,,,,,,,,,, -23900,0.67337877,1.868866,,,,,,,,,,,,,, -24000,0.89747024,1.8720326,,,,,,,,,,,,,, -24100,0.5788262,1.8546273,,,,,,,,,,,,,, -24200,0.7033053,1.8602573,,,,,,,,,,,,,, -24236,,,0.8207097,0.2707461272015434,0.93003184,0.2816455390675536,5348.0,0.6421993,0.2152621209351451,2472.0,18758.73726868629,20520.893973350525,18758.73726868629,1760.4569537639618,0.6845858097076416,0.0 -24300,0.56927925,1.8591504,,,,,,,,,,,,,, -24400,1.1627239,1.9487602,,,,,,,,,,,,,, -24500,0.7070601,1.7880545,,,,,,,,,,,,,, -24600,0.47804084,1.8458593,,,,,,,,,,,,,, -24700,0.84771,1.8002557,,,,,,,,,,,,,, -24800,0.9144837,1.8673483,,,,,,,,,,,,,, -24900,0.87692815,1.8591664,,,,,,,,,,,,,, -25000,0.8082589,1.8184854,,,,,,,,,,,,,, -25100,0.62565047,1.8298908,,,,,,,,,,,,,, -25200,0.5162185,1.8870889,,,,,,,,,,,,,, -25300,0.61912227,1.7994775,,,,,,,,,,,,,, -25400,0.6977186,1.8383139,,,,,,,,,,,,,, -25500,0.56702965,1.8379145,,,,,,,,,,,,,, -25600,0.5769881,1.7926022,,,,,,,,,,,,,, -25700,0.7124559,1.8208557,,,,,,,,,,,,,, -25800,0.49737898,1.8175203,,,,,,,,,,,,,, -25900,0.52401227,1.7732911,,,,,,,,,,,,,, -26000,0.52429074,1.862578,,,,,,,,,,,,,, -26100,1.0021268,1.8471785,,,,,,,,,,,,,, -26106,,,0.8356706,0.2724414374372397,0.9050223,0.2752927773540458,5348.0,0.6293675,0.216257388337091,2472.0,20198.5342566967,22090.823832035065,20198.5342566967,1890.377283334732,0.7414276599884033,0.0 -26200,0.70898783,1.794077,,,,,,,,,,,,,, -26300,0.5757559,1.8464884,,,,,,,,,,,,,, -26400,0.53600234,1.8609878,,,,,,,,,,,,,, -26500,0.59622675,1.8362776,,,,,,,,,,,,,, -26600,0.5729647,1.7713778,,,,,,,,,,,,,, -26700,0.71696216,1.8556341,,,,,,,,,,,,,, -26800,0.6123018,1.8108591,,,,,,,,,,,,,, -26900,0.5199642,1.794332,,,,,,,,,,,,,, -27000,0.60474616,1.7304417,,,,,,,,,,,,,, -27100,0.509401,1.8381749,,,,,,,,,,,,,, -27200,0.68499434,1.8113247,,,,,,,,,,,,,, -27300,0.57640755,1.8257049,,,,,,,,,,,,,, -27400,0.5001406,1.7807916,,,,,,,,,,,,,, -27500,0.5520035,1.7762336,,,,,,,,,,,,,, -27600,0.6826806,1.8641651,,,,,,,,,,,,,, -27700,0.45611358,1.7483535,,,,,,,,,,,,,, -27800,0.43984193,1.7780801,,,,,,,,,,,,,, -27900,0.6596687,1.7730098,,,,,,,,,,,,,, -27969,,,0.7512261,0.2498455009324634,0.8914386,0.26941309364048,5348.0,0.6123362,0.2079905754270509,2472.0,21638.650775671005,23661.728536605835,21638.650775671005,2021.034260749817,0.7939121723175049,0.0 -28000,0.68856466,1.772014,,,,,,,,,,,,,, -28100,0.49931276,1.7878227,,,,,,,,,,,,,, -28200,0.64790255,1.8248734,,,,,,,,,,,,,, -28300,0.5140654,1.7490588,,,,,,,,,,,,,, -28400,0.6180466,1.8262254,,,,,,,,,,,,,, -28500,0.7772241,1.8261555,,,,,,,,,,,,,, -28600,0.66023535,1.7627057,,,,,,,,,,,,,, -28700,0.54106504,1.8216228,,,,,,,,,,,,,, -28800,0.681241,1.8103033,,,,,,,,,,,,,, -28900,0.5952284,1.8391658,,,,,,,,,,,,,, -29000,0.54657096,1.7846528,,,,,,,,,,,,,, -29100,0.70457613,1.7840449,,,,,,,,,,,,,, -29200,0.6996087,1.7763194,,,,,,,,,,,,,, -29300,0.8515489,1.7731342,,,,,,,,,,,,,, -29400,0.5602586,1.7475185,,,,,,,,,,,,,, -29500,0.61871195,1.7631391,,,,,,,,,,,,,, -29600,0.56031615,1.7525744,,,,,,,,,,,,,, -29700,0.6651441,1.7767485,,,,,,,,,,,,,, -29800,0.76483494,2.1844556,,,,,,,,,,,,,, -29839,,,1.465784,0.4232012911945225,1.4207444,0.3876053564015177,5348.0,1.012293,0.3135701663518372,2472.0,23078.74978494644,25231.434408426285,23078.74978494644,2150.510663509369,0.8455498218536377,0.0 -29900,0.50973916,1.8593447,,,,,,,,,,,,,, -30000,0.5093601,1.8354622,,,,,,,,,,,,,, -30100,0.5774663,1.8304785,,,,,,,,,,,,,, -30200,0.6802979,1.7642783,,,,,,,,,,,,,, -30300,0.60260135,1.7501833,,,,,,,,,,,,,, -30400,0.7909763,1.8235563,,,,,,,,,,,,,, -30500,0.5264123,1.726553,,,,,,,,,,,,,, -30600,0.63887626,1.8029635,,,,,,,,,,,,,, -30700,0.67139214,1.7956007,,,,,,,,,,,,,, -30800,0.48366904,1.7969612,,,,,,,,,,,,,, -30900,0.73399556,1.7599452,,,,,,,,,,,,,, -31000,0.56758636,1.7670199,,,,,,,,,,,,,, -31100,0.67869705,1.8031775,,,,,,,,,,,,,, -31200,0.6303912,1.7641257,,,,,,,,,,,,,, -31300,0.64121604,1.777527,,,,,,,,,,,,,, -31400,0.6146988,1.7465705,,,,,,,,,,,,,, -31500,0.65017253,1.7794396,,,,,,,,,,,,,, -31600,0.63470054,1.7220095,,,,,,,,,,,,,, -31700,0.59087193,1.7276075,,,,,,,,,,,,,, -31714,,,0.7389771,0.2490417309784311,0.8475726,0.2582426600500111,5348.0,0.564204,0.1911725874921292,2472.0,24519.01113843918,26801.28219127655,24519.01113843918,2279.962345123291,0.9009816646575928,0.0 -31800,0.5488781,1.7502116,,,,,,,,,,,,,, -31900,0.71393126,1.8084862,,,,,,,,,,,,,, -32000,0.5774385,1.7435664,,,,,,,,,,,,,, -32100,0.55595267,1.7833167,,,,,,,,,,,,,, -32200,0.6051068,1.7141665,,,,,,,,,,,,,, -32300,0.6455736,1.7772048,,,,,,,,,,,,,, -32400,0.6156457,1.7398155,,,,,,,,,,,,,, -32500,0.4728462,1.7469822,,,,,,,,,,,,,, -32600,0.6282523,1.7207707,,,,,,,,,,,,,, -32700,0.49687392,1.7435168,,,,,,,,,,,,,, -32800,0.5220538,1.7152003,,,,,,,,,,,,,, -32900,0.61462706,1.746444,,,,,,,,,,,,,, -33000,0.72639966,1.6922536,,,,,,,,,,,,,, -33100,0.5717475,1.7087584,,,,,,,,,,,,,, -33200,0.6049136,1.6990707,,,,,,,,,,,,,, -33300,0.46984023,1.7724315,,,,,,,,,,,,,, -33400,0.64690644,1.7310618,,,,,,,,,,,,,, -33500,0.66493887,1.776971,,,,,,,,,,,,,, -33599,,,0.525048,0.1873489863803322,0.8184007,0.250229298010176,5348.0,0.5521909,0.1895273495419738,2472.0,25959.022108078003,28380.20565652848,25959.022108078003,2418.74117398262,0.9561948776245116,0.0 -33600,0.5338377,1.7094193,,,,,,,,,,,,,, -33700,0.68197966,1.7102495,,,,,,,,,,,,,, -33800,0.5977731,1.7836978,,,,,,,,,,,,,, -33900,0.691742,1.7249407,,,,,,,,,,,,,, -34000,0.7792282,1.7308241,,,,,,,,,,,,,, -34100,0.5526915,1.7686555,,,,,,,,,,,,,, -34200,0.56379795,1.6588941,,,,,,,,,,,,,, -34300,0.7525552,1.6955421,,,,,,,,,,,,,, -34400,0.64649606,1.7278489,,,,,,,,,,,,,, -34500,0.589428,1.6646101,,,,,,,,,,,,,, -34600,1.0204492,1.8257762,,,,,,,,,,,,,, -34700,0.51538324,1.6947005,,,,,,,,,,,,,, -34800,0.5913263,1.72078,,,,,,,,,,,,,, -34900,0.5603626,1.6929584,,,,,,,,,,,,,, -35000,0.56186086,1.6577761,,,,,,,,,,,,,, -35100,0.64056396,1.7491447,,,,,,,,,,,,,, -35200,0.6199715,1.6959703,,,,,,,,,,,,,, -35300,0.65766305,1.6769315,,,,,,,,,,,,,, -35400,0.70379746,1.7247163,,,,,,,,,,,,,, -35477,,,0.4653179,0.1710944701654082,0.8033444,0.2460777971943578,5348.0,0.52806854,0.1811792903134076,2472.0,27399.132900476456,29952.6443362236,27399.132900476456,2550.938821077347,1.008819818496704,0.0 -35500,0.5105181,1.670621,,,,,,,,,,,,,, -35600,0.74569935,1.7653397,,,,,,,,,,,,,, -35700,0.79089797,1.697312,,,,,,,,,,,,,, -35800,0.50814044,1.7090554,,,,,,,,,,,,,, -35900,0.5970605,1.7159134,,,,,,,,,,,,,, -36000,0.66777855,1.6720663,,,,,,,,,,,,,, -36100,0.6529145,1.6718112,,,,,,,,,,,,,, -36200,0.60061,1.6900485,,,,,,,,,,,,,, -36300,0.5825556,1.67003,,,,,,,,,,,,,, -36400,0.7183851,1.6445391,,,,,,,,,,,,,, -36500,0.56800663,1.6914859,,,,,,,,,,,,,, -36600,0.5280998,1.665071,,,,,,,,,,,,,, -36700,0.7809818,1.6741053,,,,,,,,,,,,,, -36800,0.6491356,1.6802925,,,,,,,,,,,,,, -36900,0.66696125,1.7055657,,,,,,,,,,,,,, -37000,0.60926515,1.7342731,,,,,,,,,,,,,, -37100,0.66493577,1.6099787,,,,,,,,,,,,,, -37200,0.66870147,1.6963751,,,,,,,,,,,,,, -37300,0.54055804,1.6845396,,,,,,,,,,,,,, -37348,,,0.44106385,0.1581023314363641,0.7626492,0.2333819284204022,5348.0,0.49983346,0.1728312310848414,2472.0,28839.23782515525,31524.1459839344,28839.23782515525,2682.19997549057,1.0661985874176023,0.0 -37400,0.7789804,1.6766578,,,,,,,,,,,,,, -37500,0.7928771,1.6826103,,,,,,,,,,,,,, -37600,0.5366047,1.688525,,,,,,,,,,,,,, -37700,0.6192065,1.7196099,,,,,,,,,,,,,, -37800,0.6695318,1.6504353,,,,,,,,,,,,,, -37900,0.62004906,1.6511122,,,,,,,,,,,,,, -38000,0.5831644,1.6809586,,,,,,,,,,,,,, -38100,0.57635283,1.6766928,,,,,,,,,,,,,, -38200,0.6987331,1.7010931,,,,,,,,,,,,,, -38300,0.7149293,1.6989763,,,,,,,,,,,,,, -38400,0.6679242,1.6333581,,,,,,,,,,,,,, -38500,0.7321296,1.6338488,,,,,,,,,,,,,, -38600,0.5217205,1.6380435,,,,,,,,,,,,,, -38700,0.58968616,1.6450413,,,,,,,,,,,,,, -38800,2.642928,1.6994063,,,,,,,,,,,,,, -38900,0.53238595,1.6561923,,,,,,,,,,,,,, -39000,0.6013656,1.6353667,,,,,,,,,,,,,, -39100,0.6430345,1.6581461,,,,,,,,,,,,,, -39200,0.5367043,1.6226703,,,,,,,,,,,,,, -39210,,,0.41415095,0.1518818386404201,0.74457777,0.2290180252372631,5348.0,0.48848337,0.1663518371823776,2472.0,30279.30905532837,33095.56187868118,30279.30905532837,2813.408171415329,1.1235594749450684,0.0 -39300,0.7339947,1.639748,,,,,,,,,,,,,, -39400,0.52481025,1.6218309,,,,,,,,,,,,,, -39500,0.56519526,1.6196232,,,,,,,,,,,,,, -39600,0.58532584,1.646689,,,,,,,,,,,,,, -39700,0.6190427,1.6578624,,,,,,,,,,,,,, -39800,0.6323058,1.5401984,,,,,,,,,,,,,, -39900,0.77892387,1.6668723,,,,,,,,,,,,,, -40000,0.6082372,1.685857,,,,,,,,,,,,,, -40100,0.8007063,1.6281965,,,,,,,,,,,,,, -40200,0.6416214,1.6117309,,,,,,,,,,,,,, -40300,0.6008635,1.5924726,,,,,,,,,,,,,, -40400,0.6888906,1.5610541,,,,,,,,,,,,,, -40500,0.60609156,1.6179545,,,,,,,,,,,,,, -40600,0.5711917,1.6180831,,,,,,,,,,,,,, -40700,0.840815,1.6496015,,,,,,,,,,,,,, -40800,0.5539934,1.6584432,,,,,,,,,,,,,, -40900,0.7180145,1.6217204,,,,,,,,,,,,,, -41000,0.7773501,1.6201965,,,,,,,,,,,,,, -41085,,,0.4113284,0.1506914356935129,0.7187782,0.2191123511976597,5348.0,0.46268943,0.159324030629862,2472.0,31719.44522428513,34667.19851708412,31719.44522428513,2944.7663497924805,1.187300205230713,0.0 -41100,0.63129014,1.6108615,,,,,,,,,,,,,, -41200,0.7719041,1.5580314,,,,,,,,,,,,,, -41300,0.66936195,1.5793623,,,,,,,,,,,,,, -41400,0.70663935,1.6211557,,,,,,,,,,,,,, -41500,0.4884376,1.5222281,,,,,,,,,,,,,, -41600,0.6070342,1.6147913,,,,,,,,,,,,,, -41700,0.5359392,1.5495692,,,,,,,,,,,,,, -41800,0.5157005,1.5434612,,,,,,,,,,,,,, -41900,0.6481884,1.5895426,,,,,,,,,,,,,, -42000,0.62065,1.6165812,,,,,,,,,,,,,, -42100,0.75426525,1.694252,,,,,,,,,,,,,, -42200,0.67888767,1.6311013,,,,,,,,,,,,,, -42300,0.60531366,1.6105238,,,,,,,,,,,,,, -42400,0.65491605,1.631314,,,,,,,,,,,,,, -42500,0.680051,1.5816032,,,,,,,,,,,,,, -42600,0.6059868,1.6230224,,,,,,,,,,,,,, -42700,0.57612944,1.5207574,,,,,,,,,,,,,, -42800,0.8062069,1.582415,,,,,,,,,,,,,, -42900,0.7868553,1.6172005,,,,,,,,,,,,,, -42952,,,0.38013545,0.1403506930963199,0.7010598,0.2169303996060901,5348.0,0.4484339,0.1555257652387626,2472.0,33159.86926102638,36239.47293305397,33159.86926102638,3076.475456237793,1.2487943172454834,0.0 -43000,0.6153169,1.5981047,,,,,,,,,,,,,, -43100,0.5354157,1.5941532,,,,,,,,,,,,,, -43200,0.62956667,1.579083,,,,,,,,,,,,,, -43300,0.8168646,1.5747236,,,,,,,,,,,,,, -43400,0.7092684,1.6183195,,,,,,,,,,,,,, -43500,0.6682891,1.6165721,,,,,,,,,,,,,, -43600,0.594595,1.6149122,,,,,,,,,,,,,, -43700,0.64847404,1.615979,,,,,,,,,,,,,, -43800,0.72463024,1.5619661,,,,,,,,,,,,,, -43900,0.8500878,1.5881902,,,,,,,,,,,,,, -44000,0.69270706,1.5480934,,,,,,,,,,,,,, -44100,0.7282088,1.5608853,,,,,,,,,,,,,, -44200,0.59166443,1.5250021,,,,,,,,,,,,,, -44300,0.65298826,1.5520651,,,,,,,,,,,,,, -44400,0.6190614,1.57986,,,,,,,,,,,,,, -44500,0.63494545,1.5379456,,,,,,,,,,,,,, -44600,0.61517197,1.5670342,,,,,,,,,,,,,, -44700,0.72735965,1.5538454,,,,,,,,,,,,,, -44800,0.6323353,1.527141,,,,,,,,,,,,,, -44827,,,0.42511886,0.1494690125217942,0.67514026,0.2071695453623874,5348.0,0.4318812,0.1488026323807202,2472.0,34600.46166920662,37810.55836892128,34600.46166920662,3206.82887673378,1.3093175888061523,0.0 -44900,0.6391177,1.5406003,,,,,,,,,,,,,, -45000,0.66560584,1.5910598,,,,,,,,,,,,,, -45100,0.5826804,1.6084585,,,,,,,,,,,,,, -45200,0.7551672,1.615248,,,,,,,,,,,,,, -45300,0.5874693,1.5571021,,,,,,,,,,,,,, -45400,0.58985287,1.5590605,,,,,,,,,,,,,, -45500,0.6423875,1.5646809,,,,,,,,,,,,,, -45600,0.66932756,1.5375416,,,,,,,,,,,,,, -45700,0.7707509,1.6326802,,,,,,,,,,,,,, -45800,0.83145565,1.5132298,,,,,,,,,,,,,, -45900,0.69536185,1.584411,,,,,,,,,,,,,, -46000,0.5980845,1.5527091,,,,,,,,,,,,,, -46100,0.7172038,1.492959,,,,,,,,,,,,,, -46200,0.77153766,1.4960915,,,,,,,,,,,,,, -46300,0.6600589,1.5463241,,,,,,,,,,,,,, -46400,0.70558566,1.5192158,,,,,,,,,,,,,, -46500,0.67279774,1.5037899,,,,,,,,,,,,,, -46600,0.66061026,1.5435923,,,,,,,,,,,,,, -46700,0.70952415,1.4917582,,,,,,,,,,,,,, -46706,,,0.36075482,0.1357082317487888,0.66591084,0.2046786448728965,5348.0,0.42331767,0.1455121564804095,2472.0,36040.71513128281,39381.62409090996,36040.71513128281,3337.507324695587,1.3644671440124512,0.0 -46800,0.6642748,1.4736152,,,,,,,,,,,,,, -46900,0.7778984,1.5497742,,,,,,,,,,,,,, -47000,0.76850533,1.5103302,,,,,,,,,,,,,, -47100,0.63004935,1.5470085,,,,,,,,,,,,,, -47200,0.71056336,1.531591,,,,,,,,,,,,,, -47300,0.7620434,1.4999952,,,,,,,,,,,,,, -47400,0.7278165,1.5110902,,,,,,,,,,,,,, -47500,0.668349,1.5108545,,,,,,,,,,,,,, -47600,0.6756662,1.5437181,,,,,,,,,,,,,, -47700,0.66981906,1.5286801,,,,,,,,,,,,,, -47800,0.7760976,1.5141034,,,,,,,,,,,,,, -47900,0.58206815,1.490836,,,,,,,,,,,,,, -48000,0.7890905,1.486635,,,,,,,,,,,,,, -48100,0.638902,1.5088661,,,,,,,,,,,,,, -48200,0.65174246,1.5177894,,,,,,,,,,,,,, -48300,0.6284186,1.5040288,,,,,,,,,,,,,, -48400,0.5806732,1.4974394,,,,,,,,,,,,,, -48500,0.7375295,1.4712151,,,,,,,,,,,,,, -48574,,,0.33703363,0.1252548560952095,0.63967717,0.1958349826698977,5348.0,0.40389547,0.1369406698758962,2472.0,37481.05079865456,40953.1812107563,37481.05079865456,3468.593843460083,1.4214856624603271,0.0 -48600,0.66381156,1.5394367,,,,,,,,,,,,,, -48700,0.6682144,1.4759727,,,,,,,,,,,,,, -48800,0.67938375,1.4951254,,,,,,,,,,,,,, -48900,0.6332726,1.5722764,,,,,,,,,,,,,, -49000,0.66523665,1.5677874,,,,,,,,,,,,,, -49100,0.6532178,1.4926419,,,,,,,,,,,,,, -49200,0.6341266,1.4543253,,,,,,,,,,,,,, -49300,0.75031894,1.4712692,,,,,,,,,,,,,, -49400,0.7303518,1.4409707,,,,,,,,,,,,,, -49500,0.5811328,1.4587789,,,,,,,,,,,,,, -49600,0.94222605,1.4683063,,,,,,,,,,,,,, -49700,0.6714684,1.482173,,,,,,,,,,,,,, -49800,0.85586977,1.4725807,,,,,,,,,,,,,, -49900,0.6281954,1.4659293,,,,,,,,,,,,,, -50000,0.63499016,1.4640262,,,,,,,,,,,,,, -50100,0.7522692,1.4646987,,,,,,,,,,,,,, -50200,0.6624357,1.4987859,,,,,,,,,,,,,, -50300,0.620262,1.4879618,,,,,,,,,,,,,, -50400,0.66657096,1.4893585,,,,,,,,,,,,,, -50440,,,0.30615988,0.1129864016459773,0.6176672,0.1900711547930525,5348.0,0.38614887,0.1320049560254301,2472.0,38921.05453419685,42524.28307437897,38921.05453419685,3599.5541915893555,1.4804182052612305,0.0 -50500,0.95072055,1.4462649,,,,,,,,,,,,,, -50600,0.618186,1.4478818,,,,,,,,,,,,,, -50700,0.66830075,1.4567374,,,,,,,,,,,,,, -50800,0.65210634,1.452426,,,,,,,,,,,,,, -50900,0.75734544,1.4761885,,,,,,,,,,,,,, -51000,0.6503013,1.470427,,,,,,,,,,,,,, -51100,0.68737465,1.4859295,,,,,,,,,,,,,, -51200,0.7238639,1.4814738,,,,,,,,,,,,,, -51300,0.7534506,1.4591846,,,,,,,,,,,,,, -51400,0.63423806,1.4849565,,,,,,,,,,,,,, -51500,0.72541714,1.4244281,,,,,,,,,,,,,, -51600,0.60194314,1.4337859,,,,,,,,,,,,,, -51700,0.6881538,1.4941719,,,,,,,,,,,,,, -51800,0.6960846,1.4183018,,,,,,,,,,,,,, -51900,0.6779186,1.5041167,,,,,,,,,,,,,, -52000,0.6402925,1.4799836,,,,,,,,,,,,,, -52100,0.8067359,1.4787537,,,,,,,,,,,,,, -52200,0.7060737,1.5464642,,,,,,,,,,,,,, -52298,,,0.29809442,0.1123178510033256,0.604947,0.1849252247120499,5348.0,0.3775948,0.1291207117177503,2472.0,40361.0865046978,44096.294939517975,40361.0865046978,3731.398805141449,1.538435935974121,0.0 -52300,0.84160155,1.468421,,,,,,,,,,,,,, -52400,0.72900355,1.4371706,,,,,,,,,,,,,, -52500,0.7352952,1.4564943,,,,,,,,,,,,,, -52600,0.7915799,1.4739469,,,,,,,,,,,,,, -52700,0.64285964,1.4159237,,,,,,,,,,,,,, -52800,0.5705357,1.4148432,,,,,,,,,,,,,, -52900,0.60879517,1.4730661,,,,,,,,,,,,,, -53000,0.85399675,1.4237561,,,,,,,,,,,,,, -53100,0.7745382,1.4046736,,,,,,,,,,,,,, -53200,0.76585704,1.4402341,,,,,,,,,,,,,, -53300,0.7078433,1.4934531,,,,,,,,,,,,,, -53400,0.7258701,1.4362929,,,,,,,,,,,,,, -53500,0.6595257,1.4376374,,,,,,,,,,,,,, -53600,0.74564123,1.3731703,,,,,,,,,,,,,, -53700,0.73695534,1.3889027,,,,,,,,,,,,,, -53800,0.79057336,1.4203194,,,,,,,,,,,,,, -53900,0.69591814,1.3921515,,,,,,,,,,,,,, -54000,0.6622741,1.4304235,,,,,,,,,,,,,, -54100,0.79596317,1.392637,,,,,,,,,,,,,, -54167,,,0.30563256,0.1127773453742896,0.58232206,0.1802620272840495,5348.0,0.35808083,0.1219710356874454,2472.0,41801.07320189476,45669.7519903183,41801.07320189476,3864.732713460922,1.5968918800354004,0.0 -54200,0.5481375,1.4230245,,,,,,,,,,,,,, -54300,0.7197645,1.422878,,,,,,,,,,,,,, -54400,0.69331324,1.4024578,,,,,,,,,,,,,, -54500,0.76167524,1.4384094,,,,,,,,,,,,,, -54600,0.83475786,1.4071128,,,,,,,,,,,,,, -54700,0.67992324,1.356837,,,,,,,,,,,,,, -54800,0.68553483,1.3976899,,,,,,,,,,,,,, -54900,0.6859332,1.4083514,,,,,,,,,,,,,, -55000,0.7300501,1.4080279,,,,,,,,,,,,,, -55100,0.668157,1.3849574,,,,,,,,,,,,,, -55200,0.74605066,1.417636,,,,,,,,,,,,,, -55300,0.7877441,1.4249369,,,,,,,,,,,,,, -55400,0.6780214,1.4053154,,,,,,,,,,,,,, -55500,0.67778015,1.4298593,,,,,,,,,,,,,, -55600,0.8029316,1.3279524,,,,,,,,,,,,,, -55700,0.8082869,1.40093,,,,,,,,,,,,,, -55800,0.71837205,1.367461,,,,,,,,,,,,,, -55900,0.67123765,1.3827914,,,,,,,,,,,,,, -56000,0.65973026,1.4011923,,,,,,,,,,,,,, -56037,,,0.28636116,0.1062314602486693,0.56730247,0.1750774785908068,5348.0,0.34402737,0.1187008713667662,2472.0,43241.62187099457,47241.89470553398,43241.62187099457,3996.17692565918,1.6657493114471436,0.0 -56100,0.78262746,1.4445063,,,,,,,,,,,,,, -56200,0.86139643,1.4186093,,,,,,,,,,,,,, -56300,0.75543725,1.3867328,,,,,,,,,,,,,, -56400,0.6464058,1.3336304,,,,,,,,,,,,,, -56500,0.67503214,1.3576802,,,,,,,,,,,,,, -56600,0.8718485,1.4251444,,,,,,,,,,,,,, -56700,0.71060187,1.3799247,,,,,,,,,,,,,, -56800,0.68241256,1.3271908,,,,,,,,,,,,,, -56900,0.7084323,1.3659265,,,,,,,,,,,,,, -57000,0.78904575,1.41025,,,,,,,,,,,,,, -57100,0.68771315,1.3575687,,,,,,,,,,,,,, -57200,0.6307099,1.4138572,,,,,,,,,,,,,, -57300,0.73096776,1.3772405,,,,,,,,,,,,,, -57400,0.765634,1.3482774,,,,,,,,,,,,,, -57500,0.6869817,1.3768183,,,,,,,,,,,,,, -57600,0.78530306,1.3466223,,,,,,,,,,,,,, -57700,0.70870596,1.3464075,,,,,,,,,,,,,, -57800,0.7561497,1.3627033,,,,,,,,,,,,,, -57900,0.85385066,1.3731847,,,,,,,,,,,,,, -57906,,,0.26780295,0.0981005955978449,0.5353414,0.1650366394083628,5348.0,0.32470948,0.1107590437308309,2472.0,44681.83616948128,48813.25409722328,44681.83616948128,4127.185501813889,1.7246437072753906,0.0 -58000,0.71796834,1.3637081,,,,,,,,,,,,,, -58100,0.8302763,1.3153769,,,,,,,,,,,,,, -58200,0.77642715,1.3954438,,,,,,,,,,,,,, -58300,0.863277,1.3458424,,,,,,,,,,,,,, -58400,0.7067515,1.3742903,,,,,,,,,,,,,, -58500,0.74569273,1.3480235,,,,,,,,,,,,,, -58600,0.9064722,1.3114084,,,,,,,,,,,,,, -58700,0.69990885,1.3350186,,,,,,,,,,,,,, -58800,0.82522804,1.3737917,,,,,,,,,,,,,, -58900,0.77819467,1.3150151,,,,,,,,,,,,,, -59000,0.9109483,1.3567953,,,,,,,,,,,,,, -59100,0.86443496,1.3650333,,,,,,,,,,,,,, -59200,0.79911286,1.2738125,,,,,,,,,,,,,, -59300,0.8136149,1.3739803,,,,,,,,,,,,,, -59400,0.78010255,1.300653,,,,,,,,,,,,,, -59500,0.7040109,1.3191911,,,,,,,,,,,,,, -59600,0.719073,1.3125978,,,,,,,,,,,,,, -59700,0.7792152,1.3309927,,,,,,,,,,,,,, -59768,,,0.2401949,0.0903859619566075,0.51465917,0.157988742674532,5348.0,0.30931154,0.1048483740580504,2472.0,46122.00344848633,50384.61937427521,46122.00344848633,4258.248637199402,1.7822024822235107,0.0 -59800,0.7248606,1.2405264,,,,,,,,,,,,,, -59900,0.8798989,1.3200293,,,,,,,,,,,,,, -60000,0.72813535,1.3540314,,,,,,,,,,,,,, -60100,0.67095757,1.3120605,,,,,,,,,,,,,, -60200,0.74987286,1.314159,,,,,,,,,,,,,, -60300,0.6698078,1.2839357,,,,,,,,,,,,,, -60400,0.7873955,1.2906138,,,,,,,,,,,,,, -60500,0.8433501,1.2861446,,,,,,,,,,,,,, -60600,0.9185862,1.3387678,,,,,,,,,,,,,, -60700,0.7265036,1.3065205,,,,,,,,,,,,,, -60800,0.68597037,1.2795044,,,,,,,,,,,,,, -60900,0.74442947,1.3038048,,,,,,,,,,,,,, -61000,0.9415624,1.3056488,,,,,,,,,,,,,, -61100,0.8277853,1.3343898,,,,,,,,,,,,,, -61200,0.9111727,1.3505806,,,,,,,,,,,,,, -61300,0.6943357,1.3149557,,,,,,,,,,,,,, -61400,0.84044516,1.301079,,,,,,,,,,,,,, -61500,0.6433981,1.2719771,,,,,,,,,,,,,, -61600,0.8451667,1.2813017,,,,,,,,,,,,,, -61647,,,0.22045113,0.0828239963686649,0.49300355,0.1518483833283451,5348.0,0.29717535,0.1007454349724778,2472.0,47561.935584545135,51956.26599335671,47561.935584545135,4389.823725938797,1.842402458190918,0.0 -61700,0.7644095,1.2927575,,,,,,,,,,,,,, -61800,0.76241034,1.2248224,,,,,,,,,,,,,, -61900,0.7522522,1.2510746,,,,,,,,,,,,,, -62000,0.7977921,1.2605376,,,,,,,,,,,,,, -62100,0.81835335,1.2979207,,,,,,,,,,,,,, -62200,0.7811872,1.2977064,,,,,,,,,,,,,, -62300,0.6585677,1.2617277,,,,,,,,,,,,,, -62400,0.7459188,1.2943424,,,,,,,,,,,,,, -62500,0.81967705,1.2450465,,,,,,,,,,,,,, -62600,0.75189245,1.2479613,,,,,,,,,,,,,, -62700,0.81861174,1.2844346,,,,,,,,,,,,,, -62800,0.8390626,1.2294325,,,,,,,,,,,,,, -62900,0.8809038,1.2685401,,,,,,,,,,,,,, -63000,0.844111,1.2060547,,,,,,,,,,,,,, -63100,0.7953562,1.2998159,,,,,,,,,,,,,, -63200,0.74870604,1.2436116,,,,,,,,,,,,,, -63300,0.8121633,1.2229345,,,,,,,,,,,,,, -63400,0.7426772,1.281688,,,,,,,,,,,,,, -63500,0.80418795,1.2285968,,,,,,,,,,,,,, -63519,,,0.20646651,0.0763023078873953,0.47628415,0.1478513569614876,5348.0,0.28151563,0.0974346474925355,2472.0,49002.48093056679,53529.43874812126,49002.48093056679,4522.307675123215,1.9072229862213133,0.0 -63600,0.6627388,1.2094593,,,,,,,,,,,,,, -63700,0.82260627,1.2402331,,,,,,,,,,,,,, -63800,0.7460631,1.2323203,,,,,,,,,,,,,, -63900,0.8438249,1.2014005,,,,,,,,,,,,,, -64000,0.8185105,1.2311814,,,,,,,,,,,,,, -64100,0.8824025,1.2548314,,,,,,,,,,,,,, -64200,0.9129088,1.2369907,,,,,,,,,,,,,, -64300,0.79354185,1.2344807,,,,,,,,,,,,,, -64400,0.910282,1.2597927,,,,,,,,,,,,,, -64500,0.97890556,1.2198907,,,,,,,,,,,,,, -64600,0.8904869,1.2161646,,,,,,,,,,,,,, -64700,0.92571187,1.2745012,,,,,,,,,,,,,, -64800,0.7448939,1.2357162,,,,,,,,,,,,,, -64900,0.76024264,1.2065114,,,,,,,,,,,,,, -65000,0.8807522,1.2284387,,,,,,,,,,,,,, -65100,0.8124116,1.2527161,,,,,,,,,,,,,, -65200,0.7478033,1.2466061,,,,,,,,,,,,,, -65300,1.0601677,1.2847598,,,,,,,,,,,,,, -65400,,,0.19849926,0.0751354364544603,0.45558718,0.1399828147175531,5348.0,0.2692499,0.0920114557309122,2472.0,50443.06334590912,55099.87467169762,50443.06334590912,4652.021373748779,1.967803716659546,0.0 -65400,0.7948984,1.2481191,,,,,,,,,,,,,, -65500,0.80885106,1.2382331,,,,,,,,,,,,,, -65600,0.86852884,1.2443094,,,,,,,,,,,,,, -65700,0.88828266,1.2737216,,,,,,,,,,,,,, -65800,0.7742299,1.2277743,,,,,,,,,,,,,, -65900,0.86800927,1.2040651,,,,,,,,,,,,,, -66000,0.82860494,1.183838,,,,,,,,,,,,,, -66100,0.8169572,1.1907276,,,,,,,,,,,,,, -66200,0.9073848,1.2090893,,,,,,,,,,,,,, -66300,0.9041064,1.2677253,,,,,,,,,,,,,, -66400,0.8820861,1.2434806,,,,,,,,,,,,,, -66500,1.1836011,1.2567544,,,,,,,,,,,,,, -66600,0.9054391,1.288184,,,,,,,,,,,,,, -66700,0.88744533,1.2797841,,,,,,,,,,,,,, -66800,0.80514956,1.1666278,,,,,,,,,,,,,, -66900,0.9824204,1.2305652,,,,,,,,,,,,,, -67000,0.80530214,1.1935307,,,,,,,,,,,,,, -67100,0.9721183,1.2573061,,,,,,,,,,,,,, -67200,0.88270396,1.1841989,,,,,,,,,,,,,, -67276,,,0.17537242,0.0655394871243481,0.4380229,0.1361595721057763,5348.0,0.25713265,0.0860195397396055,2472.0,51883.29848217964,56670.25017309189,51883.29848217964,4782.024187803268,2.0273022651672363,0.0 -67300,0.8506629,1.2183046,,,,,,,,,,,,,, -67400,1.0861623,1.1700767,,,,,,,,,,,,,, -67500,1.0327308,1.2314081,,,,,,,,,,,,,, -67600,0.99722016,1.2090349,,,,,,,,,,,,,, -67700,0.93939966,1.1847947,,,,,,,,,,,,,, -67800,0.936368,1.1752049,,,,,,,,,,,,,, -67900,1.1076174,1.1850084,,,,,,,,,,,,,, -68000,0.861364,1.1436679,,,,,,,,,,,,,, -68100,0.85134125,1.1712725,,,,,,,,,,,,,, -68200,0.9043576,1.1834102,,,,,,,,,,,,,, -68300,0.8590797,1.1713052,,,,,,,,,,,,,, -68400,0.86267895,1.1575922,,,,,,,,,,,,,, -68500,1.047746,1.2230471,,,,,,,,,,,,,, -68600,1.0981368,1.197972,,,,,,,,,,,,,, -68700,0.9580501,1.1881853,,,,,,,,,,,,,, -68800,0.805003,1.162475,,,,,,,,,,,,,, -68900,0.87515557,1.1732188,,,,,,,,,,,,,, -69000,0.86491936,1.1462727,,,,,,,,,,,,,, -69100,0.88498944,1.1778806,,,,,,,,,,,,,, -69151,,,0.1733773,0.0655802799387666,0.4252374,0.129690954555548,5348.0,0.24621879,0.0841914975727662,2472.0,53323.735122442245,58241.41588020325,53323.735122442245,4912.617544412613,2.0856642723083496,0.0 -69200,0.84346724,1.1320094,,,,,,,,,,,,,, -69300,0.76352483,1.1190377,,,,,,,,,,,,,, -69400,0.9371586,1.0918655,,,,,,,,,,,,,, -69500,0.8187076,1.1458592,,,,,,,,,,,,,, -69600,1.0249658,1.1464069,,,,,,,,,,,,,, -69700,0.8227784,1.1754208,,,,,,,,,,,,,, -69800,0.99557465,1.1213169,,,,,,,,,,,,,, -69900,0.95217407,1.1220208,,,,,,,,,,,,,, -70000,0.80763304,1.135115,,,,,,,,,,,,,, -70100,0.8213729,1.1254058,,,,,,,,,,,,,, -70200,0.7446245,1.1321903,,,,,,,,,,,,,, -70300,0.8913445,1.1559299,,,,,,,,,,,,,, -70400,1.0884243,1.1324466,,,,,,,,,,,,,, -70500,0.9422067,1.1168267,,,,,,,,,,,,,, -70600,1.1664052,1.1836883,,,,,,,,,,,,,, -70700,1.0271043,1.1468366,,,,,,,,,,,,,, -70800,0.8385704,1.1171422,,,,,,,,,,,,,, -70900,1.0271949,1.2018852,,,,,,,,,,,,,, -71000,1.0118077,1.1205884,,,,,,,,,,,,,, -71024,,,0.18672377,0.0644929298911721,0.41258916,0.1245932977398457,5348.0,0.2360904,0.0806369711372453,2472.0,54764.13278627396,59813.611429452896,54764.13278627396,5044.270725250244,2.151578187942505,0.0 -71100,0.8673402,1.0811933,,,,,,,,,,,,,, -71200,0.94086486,1.0836166,,,,,,,,,,,,,, -71300,0.84466606,1.1305094,,,,,,,,,,,,,, -71400,0.8679168,1.0787032,,,,,,,,,,,,,, -71500,0.90057683,1.174714,,,,,,,,,,,,,, -71600,0.94479346,1.1137319,,,,,,,,,,,,,, -71700,0.9132796,1.1295379,,,,,,,,,,,,,, -71800,0.8538172,1.1200093,,,,,,,,,,,,,, -71900,0.8517687,1.114277,,,,,,,,,,,,,, -72000,1.0577854,1.1243867,,,,,,,,,,,,,, -72100,0.8952317,1.1413221,,,,,,,,,,,,,, -72200,0.9127423,1.0979029,,,,,,,,,,,,,, -72300,0.8285641,1.188735,,,,,,,,,,,,,, -72400,1.0390325,1.1166942,,,,,,,,,,,,,, -72500,0.8691519,1.1486472,,,,,,,,,,,,,, -72600,0.88960433,1.0881459,,,,,,,,,,,,,, -72700,0.96429145,1.1402464,,,,,,,,,,,,,, -72800,0.9998059,1.0885999,,,,,,,,,,,,,, -72889,,,0.13206606,0.0492559071545209,0.4033787,0.1225947845564169,5348.0,0.22824389,0.0783214510592488,2472.0,56204.91225242615,61388.951934337616,56204.91225242615,5178.692124843597,2.211843967437744,0.0 -72900,0.9280404,1.1117055,,,,,,,,,,,,,, -73000,0.993548,1.0905622,,,,,,,,,,,,,, -73100,1.0058097,1.0992099,,,,,,,,,,,,,, -73200,0.92863405,1.080534,,,,,,,,,,,,,, -73300,1.0493369,1.1171627,,,,,,,,,,,,,, -73400,0.8699856,1.1182971,,,,,,,,,,,,,, -73500,1.0118796,1.0894558,,,,,,,,,,,,,, -73600,0.9263527,1.0827447,,,,,,,,,,,,,, -73700,0.85048443,1.0533612,,,,,,,,,,,,,, -73800,0.87967306,1.0835229,,,,,,,,,,,,,, -73900,0.9858476,1.1144449,,,,,,,,,,,,,, -74000,1.0280839,1.0650067,,,,,,,,,,,,,, -74100,1.0260847,1.0788938,,,,,,,,,,,,,, -74200,0.9087813,1.0441097,,,,,,,,,,,,,, -74300,0.9109276,1.1090406,,,,,,,,,,,,,, -74400,0.89186794,1.0813673,,,,,,,,,,,,,, -74500,1.241005,1.0586243,,,,,,,,,,,,,, -74600,1.0915539,1.0885557,,,,,,,,,,,,,, -74700,0.8181044,1.0217252,,,,,,,,,,,,,, -74767,,,0.13103484,0.0474115181823407,0.39362144,0.1193315118221226,5348.0,0.22153029,0.0754575183312006,2472.0,57645.34727025032,62963.01133060455,57645.34727025032,5312.1723392009735,2.276295185089112,0.0 -74800,0.9011028,1.127062,,,,,,,,,,,,,, -74900,1.0537728,1.0589422,,,,,,,,,,,,,, -75000,1.0327066,1.0777503,,,,,,,,,,,,,, -75100,0.988781,1.060997,,,,,,,,,,,,,, -75200,1.1113968,1.0512468,,,,,,,,,,,,,, -75300,1.059595,1.0785584,,,,,,,,,,,,,, -75400,0.8057899,1.0605451,,,,,,,,,,,,,, -75500,0.86669034,1.0816464,,,,,,,,,,,,,, -75600,0.99714214,1.088897,,,,,,,,,,,,,, -75700,0.86178094,1.0205817,,,,,,,,,,,,,, -75800,1.1948966,1.0599197,,,,,,,,,,,,,, -75900,1.1331908,1.1097835,,,,,,,,,,,,,, -76000,0.98648137,1.1137346,,,,,,,,,,,,,, -76100,0.88821036,1.1242298,,,,,,,,,,,,,, -76200,0.93893695,1.0534817,,,,,,,,,,,,,, -76300,0.96701384,1.0239927,,,,,,,,,,,,,, -76400,1.3395241,1.0821149,,,,,,,,,,,,,, -76500,0.8129357,1.067326,,,,,,,,,,,,,, -76600,1.0831715,1.0876677,,,,,,,,,,,,,, -76649,,,0.1626185,0.061350264086916,0.3823941,0.1163192600673894,5348.0,0.21700285,0.0736904109032559,2472.0,59085.85717082024,64533.37995290756,59085.85717082024,5441.886587142944,2.340604066848755,0.0 -76700,1.0200368,1.0482132,,,,,,,,,,,,,, -76800,1.0767545,1.0008625,,,,,,,,,,,,,, -76900,1.0629354,1.0072,,,,,,,,,,,,,, -77000,0.83267754,1.0435385,,,,,,,,,,,,,, -77100,1.1751429,1.1029743,,,,,,,,,,,,,, -77200,1.1299446,1.0967065,,,,,,,,,,,,,, -77300,1.0319148,1.0977738,,,,,,,,,,,,,, -77400,1.060411,1.0799874,,,,,,,,,,,,,, -77500,1.0139005,1.0377767,,,,,,,,,,,,,, -77600,0.9759844,1.0587089,,,,,,,,,,,,,, -77700,0.86620045,1.0313654,,,,,,,,,,,,,, -77800,0.9082528,1.0221587,,,,,,,,,,,,,, -77900,1.0702541,1.0741245,,,,,,,,,,,,,, -78000,1.1325033,1.102477,,,,,,,,,,,,,, -78100,0.9177222,1.0389742,,,,,,,,,,,,,, -78200,1.0754981,1.0817605,,,,,,,,,,,,,, -78300,0.88219833,1.0684259,,,,,,,,,,,,,, -78400,1.0493093,1.0357125,,,,,,,,,,,,,, -78500,1.6837077,1.0750989,,,,,,,,,,,,,, -78528,,,0.16589448,0.0608124716639992,0.38134545,0.116000656516408,5348.0,0.21509795,0.072877947717994,2472.0,60526.26091170311,66103.49896073341,60526.26091170311,5571.456894159317,2.406038999557495,0.0 -78600,1.1537725,1.0478811,,,,,,,,,,,,,, -78700,0.9962919,1.0188209,,,,,,,,,,,,,, -78800,1.0544864,1.0619563,,,,,,,,,,,,,, -78900,0.9157202,1.0196968,,,,,,,,,,,,,, -79000,1.2727807,1.0781585,,,,,,,,,,,,,, -79100,1.5859945,1.0428052,,,,,,,,,,,,,, -79200,0.9962216,1.0647651,,,,,,,,,,,,,, -79242,,,,,,,,,,,61068.479226350784,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/eval_measurements.csv deleted file mode 100644 index eac72919e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -129.26838898658752,0.0,36.713239431381226,1,0,36.713239431381226,30.214182,2472,0.9758089086588264,165.9817259311676,31.602995,1.1353257853575371,30.090126,5348,0.9587263581683192 -254.03513860702515,0.0323665142059326,1477.383898973465,1843,0,1477.383898973465,2.8049276,2472,0.5481079763573213,1731.5264666080477,3.174557,0.6152079537398865,3.1326847,5348,0.5973044208656362 -384.9413070678711,0.0895602703094482,2917.6238310337067,3716,0,2917.6238310337067,0.6313826,2472,0.2064875185343164,3302.8093264102936,0.86722195,0.2781396477884187,0.9116082,5348,0.2709771474362069 -517.5147247314453,0.1419603824615478,4357.761467218399,5589,0,4357.761467218399,0.4836131,2472,0.1620051591412264,4875.651361227036,0.54877484,0.1869594467317887,0.7388587,5348,0.2215453237687903 -649.0162920951843,0.197472333908081,5798.268486738205,7461,0,5798.268486738205,0.42196423,2472,0.1419779416245201,6447.796489477158,0.5382412,0.1822545781718035,0.6600867,5348,0.1976886760574258 -780.1031017303467,0.2476940155029297,7238.677075624466,9324,0,7238.677075624466,0.3988943,2472,0.1319237097069039,8019.420377254486,0.4458762,0.1552308174996942,0.6294824,5348,0.1864603145485966 -913.8577964305878,0.3009212017059326,8679.58162689209,11201,0,8679.58162689209,0.36695814,2472,0.1247130989377043,9594.21326494217,0.42947593,0.1529523511411656,0.589256,5348,0.1787172827944428 -1045.2302241325378,0.3532946109771728,10119.531209468842,13062,0,10119.531209468842,0.35079175,2472,0.1198789429853959,11165.667533397676,0.4079645,0.1412404371584699,0.5708126,5348,0.1735423887542601 -1177.3215169906616,0.4089851379394531,11559.680967330933,14933,0,11559.680967330933,0.337386,2472,0.1130745638088274,12738.043025493622,0.39691854,0.1382461631606303,0.5532182,5348,0.1656255732450254 -1309.8762323856354,0.4639158248901367,12999.688222646711,16804,0,12999.688222646711,0.3234783,2472,0.1077935531046249,14310.739187717438,0.39982402,0.1352574171994955,0.53366345,5348,0.1591183370825569 -1441.2039613723755,0.5192501544952393,14439.60961842537,18669,0,14439.60961842537,0.31250748,2472,0.1062701846322588,15882.122144460678,0.36982483,0.1303468192072545,0.5220176,5348,0.157081205286888 -1571.8473567962646,0.5762767791748047,15880.007965564728,20541,0,15880.007965564728,0.31171122,2472,0.1040359108727885,17453.30059313774,0.37734643,0.1359049127256365,0.5120922,5348,0.1548606350830783 -1704.1879363059998,0.6316831111907959,17320.585915327072,22399,0,17320.585915327072,0.2929786,2472,0.0975158938110616,19026.35332155228,0.3108637,0.109997742676085,0.4965751,5348,0.146818309084063 -1834.467442512512,0.6906123161315918,18760.76453447342,24268,0,18760.76453447342,0.28838947,2472,0.0962971990331688,20596.949521303177,0.3285102,0.1157363447277988,0.48729917,5348,0.1454956216148372 -1963.9308321475985,0.7447621822357178,20201.36756658554,26136,0,20201.36756658554,0.27694356,2472,0.0920520788901752,22167.149206638336,0.31813964,0.1122346663041076,0.47977257,5348,0.1423385500642034 -2095.584389448166,0.7992823123931885,21641.74042582512,27999,0,21641.74042582512,0.27267453,2472,0.0919098978327544,23739.309053897858,0.3083447,0.1089282379691773,0.46320656,5348,0.1385056527993666 -2225.088423728943,0.8564896583557129,23081.77032160759,29860,0,23081.77032160759,0.2670367,2472,0.0891678345824954,25308.97817516327,0.27677372,0.1016867090032878,0.4609452,5348,0.1371829653301408 -2355.266172170639,0.9134845733642578,24521.79063367844,31706,0,24521.79063367844,0.2575679,2472,0.0871163650397091,26879.311678886414,0.24583429,0.0905271570054513,0.4394589,5348,0.1313032816165751 -2486.330880880356,0.966865062713623,25961.964494228363,33570,0,25961.964494228363,0.25297192,2472,0.0839680701968192,28450.681255578995,0.2740817,0.0992087059315946,0.4404438,5348,0.1302991976983307 -2618.166530609131,1.0274250507354736,27402.324479341507,35434,0,27402.324479341507,0.23996721,2472,0.0799870005890358,30023.01597237587,0.2441035,0.0893927837749063,0.426791,5348,0.1284165403516224 -2749.658364534378,1.082904815673828,28842.528745412827,37298,0,28842.528745412827,0.23475464,2472,0.0799260658501411,31594.84669661522,0.26563275,0.0912432650183795,0.4151319,5348,0.1227492590053776 -2880.340024471283,1.1378488540649414,30282.826360464096,39159,0,30282.826360464096,0.23633035,2472,0.0776308573517762,33165.95847392082,0.21324308,0.0787642760250295,0.41352835,5348,0.1211658959035307 -3011.626853942871,1.2022075653076172,31722.86250114441,41031,0,31722.86250114441,0.22629774,2472,0.0753762720126744,34737.42574644089,0.19461007,0.0723608999189762,0.39671153,5348,0.1168695752918118 -3143.738133907318,1.2583415508270264,33163.35210299492,42900,0,33163.35210299492,0.22514184,2472,0.0751528446367274,36310.16140389442,0.22017571,0.0809919242074023,0.39385065,5348,0.1172654160672736 -3284.47208070755,1.315403699874878,34603.75563144684,44779,0,34603.75563144684,0.21745369,2472,0.0732232445717303,37891.43506240845,0.14650366,0.0550815640269449,0.39157462,5348,0.1133070083126562 -3417.5548565387726,1.3723094463348389,36043.75965952873,46654,0,36043.75965952873,0.21394913,2472,0.0718623687364166,39464.65749335289,0.12380371,0.0473838948874723,0.38062832,5348,0.1110574741496664 -3550.433384656906,1.4295480251312256,37483.8937060833,48521,0,37483.8937060833,0.20314053,2472,0.0680031686064225,41037.80516719818,0.12260363,0.0468208541607313,0.3690256,5348,0.1080645317010533 -3683.415236711502,1.4886324405670166,38924.372160196304,50389,0,38924.372160196304,0.20159614,2472,0.068653139154632,42611.40343165398,0.119473204,0.0460013418141535,0.36581305,5348,0.1060853278237446 -3814.117434978485,1.552154541015625,40364.40458655357,52250,0,40364.40458655357,0.19382338,2472,0.065057989559848,44182.28010845184,0.108347975,0.0429867342424755,0.35260636,5348,0.1029186016200507 -3947.251857280731,1.6132729053497314,41804.90568423271,54120,0,41804.90568423271,0.19269316,2472,0.0634127516096926,45756.05579662323,0.09962546,0.0394493188699585,0.35237452,5348,0.1011807640692431 -4077.9627606868735,1.6707770824432373,43245.16593647003,55989,0,43245.16593647003,0.18800361,2472,0.0612191010094855,47327.161296606064,0.11142533,0.0426152274898692,0.34446225,5348,0.0983712600287708 -4211.173542499542,1.7332322597503662,44685.39541554451,57853,0,44685.39541554451,0.1831914,2472,0.0594519935815408,48900.7434065342,0.09508877,0.03473609018244,0.33794376,5348,0.0959962153760004 -4342.618939638138,1.794440984725952,46125.69869709015,59717,0,46125.69869709015,0.1784495,2472,0.058273921962911,50472.63144659996,0.092592455,0.0354570637119113,0.33490953,5348,0.0958127769678596 -4475.0985696315765,1.8526763916015625,47566.37073302269,61577,0,47566.37073302269,0.17452674,2472,0.0569333577072288,52045.9193854332,0.073665015,0.0287280887458717,0.3230342,5348,0.0926363961111057 -4608.362781047821,1.9141466617584229,49006.61024451256,63446,0,49006.61024451256,0.17051816,2472,0.0555927934515467,53619.56326055527,0.06885544,0.0277676345912648,0.32002056,5348,0.0898558560298135 -4739.67716550827,1.9725675582885744,50446.49883818626,65317,0,50446.49883818626,0.16880462,2472,0.0543740986736538,55190.90368533135,0.07590765,0.029906168248546,0.3158513,5348,0.0880118172953454 -4873.053730249405,2.0313730239868164,51886.94908380509,67186,0,51886.94908380509,0.1675134,2472,0.0536225702272865,56764.86767911911,0.06836843,0.0268750163044896,0.3090984,5348,0.085858829662956 -5006.324478149414,2.09139347076416,53326.903443574905,69044,0,53326.903443574905,0.16303761,2472,0.0516726585826579,58338.23026275635,0.06659844,0.0258889079465069,0.3053904,5348,0.0853760970099539 -5139.588047981262,2.15257215499878,54767.0825843811,70916,0,54767.0825843811,0.16102883,2472,0.0516320354233948,59911.81316590309,0.05856837,0.0229841510387068,0.30260167,5348,0.0837734246019869 -5273.123430967331,2.217568635940552,56207.34043216705,72786,0,56207.34043216705,0.15916285,2472,0.0509617532955537,61485.750854730606,0.05365826,0.020761263701913,0.30205798,5348,0.0834837850101856 -5406.259341955185,2.278391599655152,57647.41107773781,74654,0,57647.41107773781,0.15752992,2472,0.0500883553713972,63059.09764456749,0.05984657,0.0220294749769726,0.2980946,5348,0.0816494009287776 -5538.735710859299,2.3376240730285645,59087.65986609459,76528,0,59087.65986609459,0.15713376,2472,0.0502914711677127,64631.96136951447,0.057438117,0.021416039085486,0.29819798,5348,0.0816300916226575 -5670.963258266449,2.402360677719116,60527.83698248863,78379,0,60527.83698248863,0.15672202,2472,0.049986797473239496,66204.51050138474,0.058575712,0.022106165300765975,0.29666498,5348,0.08126321480637594 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/measurements.csv deleted file mode 100644 index 584646241..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/measurements.csv +++ /dev/null @@ -1,836 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,67.56899,32.317066,,,,,,,,,,,,,, -1,,,31.602995,1.1353257853575371,30.090126,0.9587263581683192,5348.0,30.214182,0.9758089086588264,2472.0,36.713239431381226,165.9817259311676,36.713239431381226,129.26838898658752,0.0,0.0 -100,0.5141086,5.921081,,,,,,,,,,,,,, -200,0.4076565,5.833982,,,,,,,,,,,,,, -300,1.8658774,5.8159657,,,,,,,,,,,,,, -400,1.8619895,5.7951493,,,,,,,,,,,,,, -500,0.4125521,5.7816653,,,,,,,,,,,,,, -600,3.1287978,5.7295475,,,,,,,,,,,,,, -700,0.54899544,5.5177374,,,,,,,,,,,,,, -800,2.12451,5.172237,,,,,,,,,,,,,, -900,0.94707155,4.063635,,,,,,,,,,,,,, -1000,1.4907787,3.61856,,,,,,,,,,,,,, -1100,1.0071626,3.370638,,,,,,,,,,,,,, -1200,1.6913943,3.0567417,,,,,,,,,,,,,, -1300,1.8449073,2.9810843,,,,,,,,,,,,,, -1400,0.70547247,2.8414273,,,,,,,,,,,,,, -1500,0.7949841,2.7463121,,,,,,,,,,,,,, -1600,0.59255785,2.6098611,,,,,,,,,,,,,, -1700,0.6518116,2.5227509,,,,,,,,,,,,,, -1800,0.7244784,2.4756794,,,,,,,,,,,,,, -1843,,,3.174557,0.6152079537398865,3.1326847,0.5973044208656362,5348.0,2.8049276,0.5481079763573213,2472.0,1477.383898973465,1731.5264666080477,1477.383898973465,254.03513860702515,0.0323665142059326,0.0 -1900,0.64503884,2.3644278,,,,,,,,,,,,,, -2000,0.7621703,2.3181567,,,,,,,,,,,,,, -2100,0.65258026,2.232455,,,,,,,,,,,,,, -2200,1.0080104,2.2626247,,,,,,,,,,,,,, -2300,0.8993011,2.12886,,,,,,,,,,,,,, -2400,0.84980845,2.1452365,,,,,,,,,,,,,, -2500,0.86991775,2.0508428,,,,,,,,,,,,,, -2600,0.60407645,2.0356967,,,,,,,,,,,,,, -2700,0.5823261,2.0068476,,,,,,,,,,,,,, -2800,0.58476865,1.9889454,,,,,,,,,,,,,, -2900,0.5941195,2.009365,,,,,,,,,,,,,, -3000,0.80710346,1.9830773,,,,,,,,,,,,,, -3100,0.5800504,1.8546864,,,,,,,,,,,,,, -3200,0.45872438,1.8717037,,,,,,,,,,,,,, -3300,0.62162465,1.8703561,,,,,,,,,,,,,, -3400,0.72098726,1.8957516,,,,,,,,,,,,,, -3500,0.53441113,1.8701968,,,,,,,,,,,,,, -3600,0.5777987,1.9102211,,,,,,,,,,,,,, -3700,0.5843495,1.8645072,,,,,,,,,,,,,, -3716,,,0.86722195,0.2781396477884187,0.9116082,0.2709771474362069,5348.0,0.6313826,0.2064875185343164,2472.0,2917.6238310337067,3302.8093264102936,2917.6238310337067,384.9413070678711,0.0895602703094482,0.0 -3800,0.61259043,1.8471638,,,,,,,,,,,,,, -3900,0.51938677,1.851981,,,,,,,,,,,,,, -4000,0.51309675,1.8145641,,,,,,,,,,,,,, -4100,0.6561899,1.8507509,,,,,,,,,,,,,, -4200,0.66855764,1.762444,,,,,,,,,,,,,, -4300,0.525708,1.7404745,,,,,,,,,,,,,, -4400,0.6090219,1.7317889,,,,,,,,,,,,,, -4500,0.45203254,1.7384781,,,,,,,,,,,,,, -4600,0.4772466,1.7485397,,,,,,,,,,,,,, -4700,0.52116174,1.8087068,,,,,,,,,,,,,, -4800,0.4410769,1.7125728,,,,,,,,,,,,,, -4900,0.42982805,1.7353336,,,,,,,,,,,,,, -5000,0.5109525,1.6874425,,,,,,,,,,,,,, -5100,0.55534095,1.6938281,,,,,,,,,,,,,, -5200,0.56036216,1.6317117,,,,,,,,,,,,,, -5300,0.47812185,1.6364306,,,,,,,,,,,,,, -5400,0.5803487,1.6967517,,,,,,,,,,,,,, -5500,0.69693893,1.608383,,,,,,,,,,,,,, -5589,,,0.54877484,0.1869594467317887,0.7388587,0.2215453237687903,5348.0,0.4836131,0.1620051591412264,2472.0,4357.761467218399,4875.651361227036,4357.761467218399,517.5147247314453,0.1419603824615478,0.0 -5600,0.4643096,1.6364048,,,,,,,,,,,,,, -5700,0.59659106,1.7313567,,,,,,,,,,,,,, -5800,0.47446275,1.6168402,,,,,,,,,,,,,, -5900,0.5634595,1.7294682,,,,,,,,,,,,,, -6000,0.6106565,1.6961304,,,,,,,,,,,,,, -6100,0.59894085,1.6224697,,,,,,,,,,,,,, -6200,0.5820926,1.666139,,,,,,,,,,,,,, -6300,0.46607786,1.6632948,,,,,,,,,,,,,, -6400,0.49915832,1.6454171,,,,,,,,,,,,,, -6500,0.5125292,1.7018355,,,,,,,,,,,,,, -6600,0.551906,1.6202037,,,,,,,,,,,,,, -6700,0.5722111,1.6193765,,,,,,,,,,,,,, -6800,0.597398,1.568748,,,,,,,,,,,,,, -6900,0.42441127,1.6112648,,,,,,,,,,,,,, -7000,0.66009927,1.6504847,,,,,,,,,,,,,, -7100,0.51331145,1.6088866,,,,,,,,,,,,,, -7200,0.42701748,1.5416818,,,,,,,,,,,,,, -7300,0.79482543,1.5213833,,,,,,,,,,,,,, -7400,0.58662087,1.619446,,,,,,,,,,,,,, -7461,,,0.5382412,0.1822545781718035,0.6600867,0.1976886760574258,5348.0,0.42196423,0.1419779416245201,2472.0,5798.268486738205,6447.796489477158,5798.268486738205,649.0162920951843,0.197472333908081,0.0 -7500,0.6909335,1.5688365,,,,,,,,,,,,,, -7600,0.6022658,1.5308831,,,,,,,,,,,,,, -7700,0.44201776,1.5581931,,,,,,,,,,,,,, -7800,0.5532033,1.5506421,,,,,,,,,,,,,, -7900,0.49209806,1.6438962,,,,,,,,,,,,,, -8000,0.4798966,1.6173404,,,,,,,,,,,,,, -8100,0.55225533,1.5758053,,,,,,,,,,,,,, -8200,0.45434332,1.5484419,,,,,,,,,,,,,, -8300,0.49957386,1.553703,,,,,,,,,,,,,, -8400,0.54115987,1.5758542,,,,,,,,,,,,,, -8500,0.56831515,1.6385982,,,,,,,,,,,,,, -8600,0.6369291,1.5623115,,,,,,,,,,,,,, -8700,0.5296364,1.6039164,,,,,,,,,,,,,, -8800,0.40451506,1.5381358,,,,,,,,,,,,,, -8900,0.5950592,1.5948339,,,,,,,,,,,,,, -9000,0.55776554,1.5040753,,,,,,,,,,,,,, -9100,0.54428315,1.5939596,,,,,,,,,,,,,, -9200,0.5915014,1.5449551,,,,,,,,,,,,,, -9300,0.5712458,1.5976907,,,,,,,,,,,,,, -9324,,,0.4458762,0.1552308174996942,0.6294824,0.1864603145485966,5348.0,0.3988943,0.1319237097069039,2472.0,7238.677075624466,8019.420377254486,7238.677075624466,780.1031017303467,0.2476940155029297,0.0 -9400,0.64274096,1.4971508,,,,,,,,,,,,,, -9500,0.47015104,1.5237669,,,,,,,,,,,,,, -9600,0.5568171,1.5140134,,,,,,,,,,,,,, -9700,0.41395843,1.5421242,,,,,,,,,,,,,, -9800,0.5828981,1.5335963,,,,,,,,,,,,,, -9900,0.46484825,1.4486408,,,,,,,,,,,,,, -10000,0.55168736,1.543165,,,,,,,,,,,,,, -10100,0.8477133,1.5195214,,,,,,,,,,,,,, -10200,0.5109406,1.4829221,,,,,,,,,,,,,, -10300,0.4688721,1.48817,,,,,,,,,,,,,, -10400,0.46154186,1.4694456,,,,,,,,,,,,,, -10500,0.48611686,1.4973913,,,,,,,,,,,,,, -10600,0.61226475,1.5509038,,,,,,,,,,,,,, -10700,0.46118674,1.52498,,,,,,,,,,,,,, -10800,0.50243944,1.4440575,,,,,,,,,,,,,, -10900,0.46619835,1.5533721,,,,,,,,,,,,,, -11000,0.45307958,1.5377834,,,,,,,,,,,,,, -11100,0.47673634,1.5235308,,,,,,,,,,,,,, -11200,0.5118977,1.5065919,,,,,,,,,,,,,, -11201,,,0.42947593,0.1529523511411656,0.589256,0.1787172827944428,5348.0,0.36695814,0.1247130989377043,2472.0,8679.58162689209,9594.21326494217,8679.58162689209,913.8577964305878,0.3009212017059326,0.0 -11300,0.45475852,1.5209893,,,,,,,,,,,,,, -11400,0.7407127,1.4601128,,,,,,,,,,,,,, -11500,0.5449138,1.4753996,,,,,,,,,,,,,, -11600,0.53546065,1.4813708,,,,,,,,,,,,,, -11700,0.520975,1.4824679,,,,,,,,,,,,,, -11800,0.5273828,1.4598066,,,,,,,,,,,,,, -11900,0.4941872,1.5079427,,,,,,,,,,,,,, -12000,0.63785535,1.5482134,,,,,,,,,,,,,, -12100,0.46427506,1.5226443,,,,,,,,,,,,,, -12200,0.505409,1.48763,,,,,,,,,,,,,, -12300,0.5377944,1.4627731,,,,,,,,,,,,,, -12400,0.5226048,1.4665027,,,,,,,,,,,,,, -12500,0.47303966,1.4389666,,,,,,,,,,,,,, -12600,0.5055972,1.4472086,,,,,,,,,,,,,, -12700,0.5461656,1.4634465,,,,,,,,,,,,,, -12800,0.4636452,1.4356443,,,,,,,,,,,,,, -12900,0.509131,1.4708297,,,,,,,,,,,,,, -13000,0.47920126,1.4771448,,,,,,,,,,,,,, -13062,,,0.4079645,0.1412404371584699,0.5708126,0.1735423887542601,5348.0,0.35079175,0.1198789429853959,2472.0,10119.531209468842,11165.667533397676,10119.531209468842,1045.2302241325378,0.3532946109771728,0.0 -13100,0.4546651,1.4774144,,,,,,,,,,,,,, -13200,0.44150892,1.481864,,,,,,,,,,,,,, -13300,0.5516772,1.4457,,,,,,,,,,,,,, -13400,0.4925613,1.4407417,,,,,,,,,,,,,, -13500,0.4320942,1.4363139,,,,,,,,,,,,,, -13600,0.548009,1.3808903,,,,,,,,,,,,,, -13700,0.44470108,1.3888539,,,,,,,,,,,,,, -13800,0.5862511,1.4022253,,,,,,,,,,,,,, -13900,0.49112496,1.4229815,,,,,,,,,,,,,, -14000,0.5747795,1.5334915,,,,,,,,,,,,,, -14100,0.5223944,1.5199459,,,,,,,,,,,,,, -14200,0.5223619,1.4880337,,,,,,,,,,,,,, -14300,0.6999613,1.4825253,,,,,,,,,,,,,, -14400,0.5078951,1.4802483,,,,,,,,,,,,,, -14500,0.4136045,1.4439148,,,,,,,,,,,,,, -14600,0.7255982,1.4525242,,,,,,,,,,,,,, -14700,0.44931227,1.4090613,,,,,,,,,,,,,, -14800,0.5727006,1.5294576,,,,,,,,,,,,,, -14900,0.45024645,1.3473818,,,,,,,,,,,,,, -14933,,,0.39691854,0.1382461631606303,0.5532182,0.1656255732450254,5348.0,0.337386,0.1130745638088274,2472.0,11559.680967330933,12738.043025493622,11559.680967330933,1177.3215169906616,0.4089851379394531,0.0 -15000,0.5062354,1.3877102,,,,,,,,,,,,,, -15100,0.53576404,1.4382955,,,,,,,,,,,,,, -15200,0.6095915,1.436463,,,,,,,,,,,,,, -15300,0.47977462,1.458304,,,,,,,,,,,,,, -15400,0.70603365,1.4029086,,,,,,,,,,,,,, -15500,0.5074076,1.3799804,,,,,,,,,,,,,, -15600,0.51273507,1.3987427,,,,,,,,,,,,,, -15700,0.5962817,1.3913244,,,,,,,,,,,,,, -15800,0.62696034,1.3854737,,,,,,,,,,,,,, -15900,0.6176795,1.4417523,,,,,,,,,,,,,, -16000,0.5284762,1.4290549,,,,,,,,,,,,,, -16100,0.5439221,1.4283282,,,,,,,,,,,,,, -16200,0.51966494,1.4143277,,,,,,,,,,,,,, -16300,0.43886352,1.3961625,,,,,,,,,,,,,, -16400,0.4624511,1.4648471,,,,,,,,,,,,,, -16500,0.51156574,1.4019994,,,,,,,,,,,,,, -16600,0.45678794,1.4220359,,,,,,,,,,,,,, -16700,0.578038,1.3608645,,,,,,,,,,,,,, -16800,0.47338164,1.3897473,,,,,,,,,,,,,, -16804,,,0.39982402,0.1352574171994955,0.53366345,0.1591183370825569,5348.0,0.3234783,0.1077935531046249,2472.0,12999.688222646711,14310.739187717438,12999.688222646711,1309.8762323856354,0.4639158248901367,0.0 -16900,0.46407762,1.397829,,,,,,,,,,,,,, -17000,0.40219164,1.3990942,,,,,,,,,,,,,, -17100,0.44457835,1.4201046,,,,,,,,,,,,,, -17200,0.42519066,1.384266,,,,,,,,,,,,,, -17300,0.5983492,1.3655939,,,,,,,,,,,,,, -17400,0.539308,1.3858496,,,,,,,,,,,,,, -17500,0.50347084,1.4054096,,,,,,,,,,,,,, -17600,0.39205417,1.380985,,,,,,,,,,,,,, -17700,0.41053134,1.3305626,,,,,,,,,,,,,, -17800,0.49768767,1.41814,,,,,,,,,,,,,, -17900,0.5166358,1.4564044,,,,,,,,,,,,,, -18000,0.57477254,1.3737814,,,,,,,,,,,,,, -18100,0.5216005,1.4053785,,,,,,,,,,,,,, -18200,0.50288314,1.4106451,,,,,,,,,,,,,, -18300,0.45068884,1.369632,,,,,,,,,,,,,, -18400,0.6074462,1.3977209,,,,,,,,,,,,,, -18500,0.5122769,1.4363883,,,,,,,,,,,,,, -18600,0.5224499,1.3670781,,,,,,,,,,,,,, -18669,,,0.36982483,0.1303468192072545,0.5220176,0.157081205286888,5348.0,0.31250748,0.1062701846322588,2472.0,14439.60961842537,15882.122144460678,14439.60961842537,1441.2039613723755,0.5192501544952393,0.0 -18700,0.40818042,1.3462864,,,,,,,,,,,,,, -18800,0.52821255,1.371961,,,,,,,,,,,,,, -18900,0.52051604,1.374088,,,,,,,,,,,,,, -19000,0.46211943,1.4058002,,,,,,,,,,,,,, -19100,0.5099456,1.3669778,,,,,,,,,,,,,, -19200,0.40889862,1.3396147,,,,,,,,,,,,,, -19300,0.53744364,1.3569479,,,,,,,,,,,,,, -19400,0.49984723,1.37801,,,,,,,,,,,,,, -19500,0.43433374,1.4061475,,,,,,,,,,,,,, -19600,0.55326194,1.3321608,,,,,,,,,,,,,, -19700,0.6169935,1.3754796,,,,,,,,,,,,,, -19800,0.43726677,1.383105,,,,,,,,,,,,,, -19900,0.64749223,1.387538,,,,,,,,,,,,,, -20000,0.41914833,1.3052509,,,,,,,,,,,,,, -20100,0.50066876,1.2801791,,,,,,,,,,,,,, -20200,0.56500626,1.3510078,,,,,,,,,,,,,, -20300,0.55477554,1.4008063,,,,,,,,,,,,,, -20400,0.4655714,1.4027456,,,,,,,,,,,,,, -20500,0.6004513,1.413698,,,,,,,,,,,,,, -20541,,,0.37734643,0.1359049127256365,0.5120922,0.1548606350830783,5348.0,0.31171122,0.1040359108727885,2472.0,15880.007965564728,17453.30059313774,15880.007965564728,1571.8473567962646,0.5762767791748047,0.0 -20600,0.47329792,1.3514858,,,,,,,,,,,,,, -20700,0.4889165,1.3379511,,,,,,,,,,,,,, -20800,0.5278944,1.3798511,,,,,,,,,,,,,, -20900,0.41134256,1.342826,,,,,,,,,,,,,, -21000,0.53693956,1.3331113,,,,,,,,,,,,,, -21100,0.4571218,1.3586851,,,,,,,,,,,,,, -21200,0.47242334,1.3699441,,,,,,,,,,,,,, -21300,0.4930591,1.3677152,,,,,,,,,,,,,, -21400,0.6059409,1.328353,,,,,,,,,,,,,, -21500,0.5637689,1.4123296,,,,,,,,,,,,,, -21600,0.5027929,1.3659788,,,,,,,,,,,,,, -21700,0.510921,1.3121334,,,,,,,,,,,,,, -21800,0.67318106,1.3685045,,,,,,,,,,,,,, -21900,0.43910414,1.3349639,,,,,,,,,,,,,, -22000,0.51984394,1.3810201,,,,,,,,,,,,,, -22100,0.54126984,1.3630637,,,,,,,,,,,,,, -22200,0.49060887,1.3880534,,,,,,,,,,,,,, -22300,0.53627783,1.3366127,,,,,,,,,,,,,, -22399,,,0.3108637,0.109997742676085,0.4965751,0.146818309084063,5348.0,0.2929786,0.0975158938110616,2472.0,17320.585915327072,19026.35332155228,17320.585915327072,1704.1879363059998,0.6316831111907959,0.0 -22400,0.60425067,1.3640133,,,,,,,,,,,,,, -22500,0.42763332,1.3583905,,,,,,,,,,,,,, -22600,0.50234264,1.2998871,,,,,,,,,,,,,, -22700,0.5310731,1.3266408,,,,,,,,,,,,,, -22800,0.5181355,1.321962,,,,,,,,,,,,,, -22900,0.5121901,1.329102,,,,,,,,,,,,,, -23000,0.5399547,1.3364799,,,,,,,,,,,,,, -23100,0.73593074,1.2776508,,,,,,,,,,,,,, -23200,0.5135588,1.3873991,,,,,,,,,,,,,, -23300,0.57574505,1.3365217,,,,,,,,,,,,,, -23400,0.5019385,1.3036621,,,,,,,,,,,,,, -23500,0.5251732,1.3002019,,,,,,,,,,,,,, -23600,0.39383554,1.3122493,,,,,,,,,,,,,, -23700,0.5535609,1.3440605,,,,,,,,,,,,,, -23800,0.4809536,1.2595748,,,,,,,,,,,,,, -23900,0.51266277,1.3162848,,,,,,,,,,,,,, -24000,0.531769,1.3205103,,,,,,,,,,,,,, -24100,0.57748777,1.371566,,,,,,,,,,,,,, -24200,0.5185484,1.3288304,,,,,,,,,,,,,, -24268,,,0.3285102,0.1157363447277988,0.48729917,0.1454956216148372,5348.0,0.28838947,0.0962971990331688,2472.0,18760.76453447342,20596.949521303177,18760.76453447342,1834.467442512512,0.6906123161315918,0.0 -24300,0.5342007,1.3803213,,,,,,,,,,,,,, -24400,0.46102807,1.3691984,,,,,,,,,,,,,, -24500,0.48756123,1.288535,,,,,,,,,,,,,, -24600,0.41138706,1.3406763,,,,,,,,,,,,,, -24700,0.5502236,1.2943833,,,,,,,,,,,,,, -24800,0.5093272,1.3010137,,,,,,,,,,,,,, -24900,0.5240106,1.3267924,,,,,,,,,,,,,, -25000,0.5240815,1.2952347,,,,,,,,,,,,,, -25100,0.5698338,1.3585585,,,,,,,,,,,,,, -25200,0.6654323,1.3477736,,,,,,,,,,,,,, -25300,0.6291507,1.3339287,,,,,,,,,,,,,, -25400,0.5477162,1.3259658,,,,,,,,,,,,,, -25500,0.5844699,1.2859489,,,,,,,,,,,,,, -25600,0.43472832,1.302607,,,,,,,,,,,,,, -25700,0.5508519,1.3173755,,,,,,,,,,,,,, -25800,0.6824717,1.289984,,,,,,,,,,,,,, -25900,0.45618588,1.2717713,,,,,,,,,,,,,, -26000,0.54667306,1.2819268,,,,,,,,,,,,,, -26100,0.588463,1.2965173,,,,,,,,,,,,,, -26136,,,0.31813964,0.1122346663041076,0.47977257,0.1423385500642034,5348.0,0.27694356,0.0920520788901752,2472.0,20201.36756658554,22167.149206638336,20201.36756658554,1963.9308321475985,0.7447621822357178,0.0 -26200,0.4498179,1.2778999,,,,,,,,,,,,,, -26300,0.6259565,1.305329,,,,,,,,,,,,,, -26400,0.5326319,1.3047265,,,,,,,,,,,,,, -26500,0.5334617,1.3074178,,,,,,,,,,,,,, -26600,0.5322708,1.3030418,,,,,,,,,,,,,, -26700,0.52437323,1.3354348,,,,,,,,,,,,,, -26800,0.45633492,1.2693292,,,,,,,,,,,,,, -26900,0.45224586,1.328647,,,,,,,,,,,,,, -27000,0.47413266,1.3192089,,,,,,,,,,,,,, -27100,0.58074147,1.3110465,,,,,,,,,,,,,, -27200,0.53123754,1.3154843,,,,,,,,,,,,,, -27300,0.69096017,1.2617041,,,,,,,,,,,,,, -27400,0.47184095,1.2420001,,,,,,,,,,,,,, -27500,0.8398469,1.295348,,,,,,,,,,,,,, -27600,0.5068591,1.2901028,,,,,,,,,,,,,, -27700,0.5623752,1.2506436,,,,,,,,,,,,,, -27800,0.60122633,1.2745571,,,,,,,,,,,,,, -27900,0.6086773,1.302926,,,,,,,,,,,,,, -27999,,,0.3083447,0.1089282379691773,0.46320656,0.1385056527993666,5348.0,0.27267453,0.0919098978327544,2472.0,21641.74042582512,23739.309053897858,21641.74042582512,2095.584389448166,0.7992823123931885,0.0 -28000,0.50839543,1.2338636,,,,,,,,,,,,,, -28100,0.47717014,1.264114,,,,,,,,,,,,,, -28200,0.48781887,1.255659,,,,,,,,,,,,,, -28300,0.51778525,1.2562052,,,,,,,,,,,,,, -28400,0.57681346,1.2712816,,,,,,,,,,,,,, -28500,0.5372634,1.3282691,,,,,,,,,,,,,, -28600,0.58007747,1.2912489,,,,,,,,,,,,,, -28700,0.46951744,1.340149,,,,,,,,,,,,,, -28800,0.4695786,1.2662129,,,,,,,,,,,,,, -28900,0.43284196,1.2939439,,,,,,,,,,,,,, -29000,0.514765,1.312043,,,,,,,,,,,,,, -29100,0.51902944,1.2975016,,,,,,,,,,,,,, -29200,0.4585784,1.2721496,,,,,,,,,,,,,, -29300,0.50680023,1.2407286,,,,,,,,,,,,,, -29400,0.5970206,1.2353077,,,,,,,,,,,,,, -29500,0.4850034,1.2910609,,,,,,,,,,,,,, -29600,0.49665707,1.2341174,,,,,,,,,,,,,, -29700,0.5799819,1.2736597,,,,,,,,,,,,,, -29800,0.59703,1.2901652,,,,,,,,,,,,,, -29860,,,0.27677372,0.1016867090032878,0.4609452,0.1371829653301408,5348.0,0.2670367,0.0891678345824954,2472.0,23081.77032160759,25308.97817516327,23081.77032160759,2225.088423728943,0.8564896583557129,0.0 -29900,0.5346567,1.2726811,,,,,,,,,,,,,, -30000,0.4671385,1.2170938,,,,,,,,,,,,,, -30100,0.53275216,1.3145096,,,,,,,,,,,,,, -30200,0.46032947,1.2282106,,,,,,,,,,,,,, -30300,0.545674,1.2091926,,,,,,,,,,,,,, -30400,0.50480556,1.3099419,,,,,,,,,,,,,, -30500,0.5532258,1.2701591,,,,,,,,,,,,,, -30600,0.53246135,1.239719,,,,,,,,,,,,,, -30700,0.5873782,1.329033,,,,,,,,,,,,,, -30800,0.56184953,1.2963883,,,,,,,,,,,,,, -30900,0.52477014,1.2192694,,,,,,,,,,,,,, -31000,0.5053953,1.2787566,,,,,,,,,,,,,, -31100,0.5252379,1.2242297,,,,,,,,,,,,,, -31200,0.4572942,1.2122414,,,,,,,,,,,,,, -31300,0.4744863,1.2586526,,,,,,,,,,,,,, -31400,0.44280434,1.2267256,,,,,,,,,,,,,, -31500,0.5380366,1.2815267,,,,,,,,,,,,,, -31600,0.5189817,1.2487369,,,,,,,,,,,,,, -31700,0.57810116,1.2137046,,,,,,,,,,,,,, -31706,,,0.24583429,0.0905271570054513,0.4394589,0.1313032816165751,5348.0,0.2575679,0.0871163650397091,2472.0,24521.79063367844,26879.311678886414,24521.79063367844,2355.266172170639,0.9134845733642578,0.0 -31800,0.48953235,1.2181467,,,,,,,,,,,,,, -31900,0.5936051,1.2634596,,,,,,,,,,,,,, -32000,0.5756522,1.2750701,,,,,,,,,,,,,, -32100,0.5906836,1.2085117,,,,,,,,,,,,,, -32200,0.57752603,1.2099972,,,,,,,,,,,,,, -32300,0.5136492,1.2230512,,,,,,,,,,,,,, -32400,0.51835847,1.252804,,,,,,,,,,,,,, -32500,0.5077782,1.2361059,,,,,,,,,,,,,, -32600,0.5110309,1.193556,,,,,,,,,,,,,, -32700,0.44725567,1.2320275,,,,,,,,,,,,,, -32800,0.67483574,1.2302088,,,,,,,,,,,,,, -32900,0.6341211,1.2929734,,,,,,,,,,,,,, -33000,0.7458049,1.1978108,,,,,,,,,,,,,, -33100,0.6540112,1.1858746,,,,,,,,,,,,,, -33200,0.51085347,1.2798336,,,,,,,,,,,,,, -33300,0.60438263,1.2304239,,,,,,,,,,,,,, -33400,0.7159875,1.2946903,,,,,,,,,,,,,, -33500,0.5811735,1.2438307,,,,,,,,,,,,,, -33570,,,0.2740817,0.0992087059315946,0.4404438,0.1302991976983307,5348.0,0.25297192,0.0839680701968192,2472.0,25961.964494228363,28450.681255578995,25961.964494228363,2486.330880880356,0.966865062713623,0.0 -33600,0.46762216,1.2353742,,,,,,,,,,,,,, -33700,0.51987076,1.2798659,,,,,,,,,,,,,, -33800,0.4775731,1.2979568,,,,,,,,,,,,,, -33900,0.51677054,1.2842957,,,,,,,,,,,,,, -34000,0.4476387,1.233986,,,,,,,,,,,,,, -34100,0.49948806,1.2778533,,,,,,,,,,,,,, -34200,0.58000654,1.2219315,,,,,,,,,,,,,, -34300,0.49957842,1.2220724,,,,,,,,,,,,,, -34400,0.57478684,1.2353655,,,,,,,,,,,,,, -34500,0.44865406,1.2389733,,,,,,,,,,,,,, -34600,0.4412234,1.2354234,,,,,,,,,,,,,, -34700,0.75519425,1.226137,,,,,,,,,,,,,, -34800,0.4951923,1.2538594,,,,,,,,,,,,,, -34900,0.6007562,1.2292361,,,,,,,,,,,,,, -35000,0.54829603,1.1879092,,,,,,,,,,,,,, -35100,0.54168636,1.2881795,,,,,,,,,,,,,, -35200,0.48142505,1.1965401,,,,,,,,,,,,,, -35300,0.53905517,1.2079448,,,,,,,,,,,,,, -35400,0.57447696,1.1921958,,,,,,,,,,,,,, -35434,,,0.2441035,0.0893927837749063,0.426791,0.1284165403516224,5348.0,0.23996721,0.0799870005890358,2472.0,27402.324479341507,30023.01597237587,27402.324479341507,2618.166530609131,1.0274250507354736,0.0 -35500,0.54315907,1.1998835,,,,,,,,,,,,,, -35600,0.48705745,1.2015748,,,,,,,,,,,,,, -35700,0.49710488,1.2536212,,,,,,,,,,,,,, -35800,0.6730868,1.2408388,,,,,,,,,,,,,, -35900,0.5836662,1.2160056,,,,,,,,,,,,,, -36000,0.56142133,1.2009174,,,,,,,,,,,,,, -36100,0.5560168,1.1970233,,,,,,,,,,,,,, -36200,0.58340603,1.185385,,,,,,,,,,,,,, -36300,0.5027963,1.2172242,,,,,,,,,,,,,, -36400,0.5658635,1.1697364,,,,,,,,,,,,,, -36500,0.62646294,1.2118827,,,,,,,,,,,,,, -36600,0.60177916,1.2081126,,,,,,,,,,,,,, -36700,0.6023773,1.1896571,,,,,,,,,,,,,, -36800,0.5236644,1.2133518,,,,,,,,,,,,,, -36900,0.54473513,1.2268218,,,,,,,,,,,,,, -37000,0.50096023,1.2165283,,,,,,,,,,,,,, -37100,0.4805044,1.1524596,,,,,,,,,,,,,, -37200,0.51713187,1.187453,,,,,,,,,,,,,, -37298,,,0.26563275,0.0912432650183795,0.4151319,0.1227492590053776,5348.0,0.23475464,0.0799260658501411,2472.0,28842.528745412827,31594.84669661522,28842.528745412827,2749.658364534378,1.082904815673828,0.0 -37300,0.5685357,1.2271155,,,,,,,,,,,,,, -37400,0.66575205,1.2102871,,,,,,,,,,,,,, -37500,0.51688004,1.1731956,,,,,,,,,,,,,, -37600,0.652196,1.2215755,,,,,,,,,,,,,, -37700,0.54676604,1.1930845,,,,,,,,,,,,,, -37800,0.5609675,1.1915336,,,,,,,,,,,,,, -37900,0.6130691,1.2070135,,,,,,,,,,,,,, -38000,0.6041403,1.1394135,,,,,,,,,,,,,, -38100,0.54563946,1.2228215,,,,,,,,,,,,,, -38200,0.57682616,1.1708418,,,,,,,,,,,,,, -38300,0.5802471,1.1613609,,,,,,,,,,,,,, -38400,0.5275049,1.2137349,,,,,,,,,,,,,, -38500,0.6033195,1.1973375,,,,,,,,,,,,,, -38600,0.54146725,1.1743124,,,,,,,,,,,,,, -38700,0.52827764,1.1707917,,,,,,,,,,,,,, -38800,0.5091803,1.1880758,,,,,,,,,,,,,, -38900,0.5771634,1.2294703,,,,,,,,,,,,,, -39000,0.73909235,1.1786516,,,,,,,,,,,,,, -39100,0.60068077,1.2463615,,,,,,,,,,,,,, -39159,,,0.21324308,0.0787642760250295,0.41352835,0.1211658959035307,5348.0,0.23633035,0.0776308573517762,2472.0,30282.826360464096,33165.95847392082,30282.826360464096,2880.340024471283,1.1378488540649414,0.0 -39200,0.6373149,1.1448778,,,,,,,,,,,,,, -39300,0.6645362,1.2207191,,,,,,,,,,,,,, -39400,0.6482884,1.182843,,,,,,,,,,,,,, -39500,0.5579593,1.219482,,,,,,,,,,,,,, -39600,0.5594359,1.2234412,,,,,,,,,,,,,, -39700,0.6286953,1.1851475,,,,,,,,,,,,,, -39800,0.539709,1.1426007,,,,,,,,,,,,,, -39900,0.48661992,1.1682239,,,,,,,,,,,,,, -40000,0.5896642,1.1754191,,,,,,,,,,,,,, -40100,0.63595957,1.2187694,,,,,,,,,,,,,, -40200,0.5337007,1.2250627,,,,,,,,,,,,,, -40300,0.67286795,1.2049406,,,,,,,,,,,,,, -40400,0.5014321,1.121076,,,,,,,,,,,,,, -40500,0.617257,1.2091688,,,,,,,,,,,,,, -40600,0.6573333,1.185838,,,,,,,,,,,,,, -40700,0.7241771,1.142017,,,,,,,,,,,,,, -40800,0.58535504,1.1481335,,,,,,,,,,,,,, -40900,0.64086926,1.1868454,,,,,,,,,,,,,, -41000,0.49826545,1.1925579,,,,,,,,,,,,,, -41031,,,0.19461007,0.0723608999189762,0.39671153,0.1168695752918118,5348.0,0.22629774,0.0753762720126744,2472.0,31722.86250114441,34737.42574644089,31722.86250114441,3011.626853942871,1.2022075653076172,0.0 -41100,0.52439624,1.144318,,,,,,,,,,,,,, -41200,0.5381969,1.1582305,,,,,,,,,,,,,, -41300,0.57985365,1.129821,,,,,,,,,,,,,, -41400,0.55387443,1.1775656,,,,,,,,,,,,,, -41500,0.508737,1.16244,,,,,,,,,,,,,, -41600,0.5456861,1.1959084,,,,,,,,,,,,,, -41700,0.7200386,1.132772,,,,,,,,,,,,,, -41800,0.57359546,1.1036112,,,,,,,,,,,,,, -41900,0.64903903,1.1607077,,,,,,,,,,,,,, -42000,0.62525815,1.2247576,,,,,,,,,,,,,, -42100,0.49639714,1.2292368,,,,,,,,,,,,,, -42200,0.6655555,1.1611279,,,,,,,,,,,,,, -42300,0.4950954,1.1223592,,,,,,,,,,,,,, -42400,0.47838587,1.1222488,,,,,,,,,,,,,, -42500,0.6163923,1.2031255,,,,,,,,,,,,,, -42600,0.5926642,1.1977606,,,,,,,,,,,,,, -42700,0.5929339,1.0946124,,,,,,,,,,,,,, -42800,0.58586866,1.1570654,,,,,,,,,,,,,, -42900,,,0.22017571,0.0809919242074023,0.39385065,0.1172654160672736,5348.0,0.22514184,0.0751528446367274,2472.0,33163.35210299492,36310.16140389442,33163.35210299492,3143.738133907318,1.2583415508270264,0.0 -42900,0.5780049,1.1513644,,,,,,,,,,,,,, -43000,0.55919915,1.1527818,,,,,,,,,,,,,, -43100,0.6758243,1.1693447,,,,,,,,,,,,,, -43200,0.5207015,1.1861604,,,,,,,,,,,,,, -43300,0.56717145,1.0843415,,,,,,,,,,,,,, -43400,0.49738842,1.1560472,,,,,,,,,,,,,, -43500,0.6325167,1.141759,,,,,,,,,,,,,, -43600,0.5770321,1.0770278,,,,,,,,,,,,,, -43700,0.46563715,1.1414591,,,,,,,,,,,,,, -43800,0.55499953,1.090348,,,,,,,,,,,,,, -43900,0.6475008,1.1866384,,,,,,,,,,,,,, -44000,0.5677185,1.1781399,,,,,,,,,,,,,, -44100,0.6032836,1.1646885,,,,,,,,,,,,,, -44200,0.54675007,1.1144234,,,,,,,,,,,,,, -44300,0.62291896,1.154888,,,,,,,,,,,,,, -44400,0.57050437,1.1702503,,,,,,,,,,,,,, -44500,0.513396,1.0554259,,,,,,,,,,,,,, -44600,0.5221097,1.1580226,,,,,,,,,,,,,, -44700,0.7793252,1.173158,,,,,,,,,,,,,, -44779,,,0.14650366,0.0550815640269449,0.39157462,0.1133070083126562,5348.0,0.21745369,0.0732232445717303,2472.0,34603.75563144684,37891.43506240845,34603.75563144684,3284.47208070755,1.315403699874878,0.0 -44800,0.6221364,1.1094716,,,,,,,,,,,,,, -44900,0.61015385,1.0661224,,,,,,,,,,,,,, -45000,0.7584525,1.1736505,,,,,,,,,,,,,, -45100,0.5455676,1.1775326,,,,,,,,,,,,,, -45200,0.5632673,1.10256,,,,,,,,,,,,,, -45300,0.64272326,1.1482147,,,,,,,,,,,,,, -45400,0.6960993,1.1542966,,,,,,,,,,,,,, -45500,0.60048693,1.0893465,,,,,,,,,,,,,, -45600,0.675922,1.1477575,,,,,,,,,,,,,, -45700,0.5979508,1.1607512,,,,,,,,,,,,,, -45800,0.66829824,1.1278741,,,,,,,,,,,,,, -45900,0.6072145,1.109459,,,,,,,,,,,,,, -46000,0.72973734,1.1328192,,,,,,,,,,,,,, -46100,0.6724066,1.1364149,,,,,,,,,,,,,, -46200,0.59428966,1.1530589,,,,,,,,,,,,,, -46300,0.661537,1.1762922,,,,,,,,,,,,,, -46400,0.6234365,1.097081,,,,,,,,,,,,,, -46500,0.5670378,1.129064,,,,,,,,,,,,,, -46600,0.5284982,1.1816541,,,,,,,,,,,,,, -46654,,,0.12380371,0.0473838948874723,0.38062832,0.1110574741496664,5348.0,0.21394913,0.0718623687364166,2472.0,36043.75965952873,39464.65749335289,36043.75965952873,3417.5548565387726,1.3723094463348389,0.0 -46700,0.6448627,1.1594588,,,,,,,,,,,,,, -46800,0.6133815,1.112335,,,,,,,,,,,,,, -46900,0.8014737,1.1783024,,,,,,,,,,,,,, -47000,0.5599422,1.1299471,,,,,,,,,,,,,, -47100,0.48296645,1.1420366,,,,,,,,,,,,,, -47200,0.5168563,1.107359,,,,,,,,,,,,,, -47300,0.62287337,1.1290329,,,,,,,,,,,,,, -47400,0.51757634,1.1184876,,,,,,,,,,,,,, -47500,0.5494002,1.1503549,,,,,,,,,,,,,, -47600,0.57621646,1.1081394,,,,,,,,,,,,,, -47700,1.1339772,1.1037842,,,,,,,,,,,,,, -47800,0.69035286,1.1045158,,,,,,,,,,,,,, -47900,0.55991834,1.0920202,,,,,,,,,,,,,, -48000,0.62770164,1.0713569,,,,,,,,,,,,,, -48100,0.5563455,1.0466105,,,,,,,,,,,,,, -48200,0.6973733,1.0944237,,,,,,,,,,,,,, -48300,0.71929675,1.0888082,,,,,,,,,,,,,, -48400,0.81120276,1.1103665,,,,,,,,,,,,,, -48500,0.59256035,1.0885894,,,,,,,,,,,,,, -48521,,,0.12260363,0.0468208541607313,0.3690256,0.1080645317010533,5348.0,0.20314053,0.0680031686064225,2472.0,37483.8937060833,41037.80516719818,37483.8937060833,3550.433384656906,1.4295480251312256,0.0 -48600,0.62133455,1.088439,,,,,,,,,,,,,, -48700,0.5326127,1.057876,,,,,,,,,,,,,, -48800,0.69448614,1.0856693,,,,,,,,,,,,,, -48900,0.70446914,1.1631466,,,,,,,,,,,,,, -49000,0.45705858,1.0840632,,,,,,,,,,,,,, -49100,0.6105597,1.06787,,,,,,,,,,,,,, -49200,0.64725745,1.0662118,,,,,,,,,,,,,, -49300,0.6250183,1.0865476,,,,,,,,,,,,,, -49400,0.6943758,1.1095828,,,,,,,,,,,,,, -49500,0.5770783,1.0937117,,,,,,,,,,,,,, -49600,0.74590015,1.0832238,,,,,,,,,,,,,, -49700,0.7728666,1.1079413,,,,,,,,,,,,,, -49800,0.53749907,1.0830196,,,,,,,,,,,,,, -49900,0.70128113,1.0851569,,,,,,,,,,,,,, -50000,0.5650355,1.0892576,,,,,,,,,,,,,, -50100,0.5134615,1.0774213,,,,,,,,,,,,,, -50200,0.63204217,1.0427978,,,,,,,,,,,,,, -50300,0.6353225,1.1002204,,,,,,,,,,,,,, -50389,,,0.119473204,0.0460013418141535,0.36581305,0.1060853278237446,5348.0,0.20159614,0.068653139154632,2472.0,38924.372160196304,42611.40343165398,38924.372160196304,3683.415236711502,1.4886324405670166,0.0 -50400,0.5123744,1.0868018,,,,,,,,,,,,,, -50500,0.55895525,1.0699036,,,,,,,,,,,,,, -50600,0.5864944,1.0720603,,,,,,,,,,,,,, -50700,0.649996,1.0404091,,,,,,,,,,,,,, -50800,0.6006277,1.0649849,,,,,,,,,,,,,, -50900,0.4920378,1.0253934,,,,,,,,,,,,,, -51000,0.5941808,1.1230499,,,,,,,,,,,,,, -51100,0.70700395,1.0622085,,,,,,,,,,,,,, -51200,0.5690782,1.0718708,,,,,,,,,,,,,, -51300,0.61817884,1.1065549,,,,,,,,,,,,,, -51400,0.63822114,1.0958972,,,,,,,,,,,,,, -51500,0.64900744,1.04919,,,,,,,,,,,,,, -51600,0.6365162,1.0663689,,,,,,,,,,,,,, -51700,0.8865229,1.0724615,,,,,,,,,,,,,, -51800,0.5905973,1.0439976,,,,,,,,,,,,,, -51900,0.5764024,1.0966583,,,,,,,,,,,,,, -52000,0.6646371,1.0277916,,,,,,,,,,,,,, -52100,0.5504856,1.11235,,,,,,,,,,,,,, -52200,0.7150815,1.0887096,,,,,,,,,,,,,, -52250,,,0.108347975,0.0429867342424755,0.35260636,0.1029186016200507,5348.0,0.19382338,0.065057989559848,2472.0,40364.40458655357,44182.28010845184,40364.40458655357,3814.117434978485,1.552154541015625,0.0 -52300,0.6276057,1.1078751,,,,,,,,,,,,,, -52400,0.5566719,1.0753282,,,,,,,,,,,,,, -52500,0.6063839,1.0821248,,,,,,,,,,,,,, -52600,0.62217104,1.0715866,,,,,,,,,,,,,, -52700,0.64980507,1.1134632,,,,,,,,,,,,,, -52800,0.5101058,1.0664341,,,,,,,,,,,,,, -52900,0.59055376,1.0278566,,,,,,,,,,,,,, -53000,0.62242436,1.066659,,,,,,,,,,,,,, -53100,0.5262755,1.041197,,,,,,,,,,,,,, -53200,0.6351612,1.0436145,,,,,,,,,,,,,, -53300,0.5266121,1.0708243,,,,,,,,,,,,,, -53400,0.6707775,1.0418578,,,,,,,,,,,,,, -53500,0.6400573,1.0360698,,,,,,,,,,,,,, -53600,0.66667193,1.0925759,,,,,,,,,,,,,, -53700,0.54430604,1.0513179,,,,,,,,,,,,,, -53800,0.6396508,1.0197052,,,,,,,,,,,,,, -53900,0.71495634,1.0385116,,,,,,,,,,,,,, -54000,0.68192196,1.0210268,,,,,,,,,,,,,, -54100,0.61586016,1.0172243,,,,,,,,,,,,,, -54120,,,0.09962546,0.0394493188699585,0.35237452,0.1011807640692431,5348.0,0.19269316,0.0634127516096926,2472.0,41804.90568423271,45756.05579662323,41804.90568423271,3947.251857280731,1.6132729053497314,0.0 -54200,0.55790377,1.059818,,,,,,,,,,,,,, -54300,0.6171258,1.062957,,,,,,,,,,,,,, -54400,0.75726366,1.0359621,,,,,,,,,,,,,, -54500,0.5592873,1.0683911,,,,,,,,,,,,,, -54600,0.6234206,1.007788,,,,,,,,,,,,,, -54700,0.7431279,1.0450121,,,,,,,,,,,,,, -54800,0.6738538,1.0259845,,,,,,,,,,,,,, -54900,0.67852175,1.0719335,,,,,,,,,,,,,, -55000,0.6702145,1.0519359,,,,,,,,,,,,,, -55100,0.8012407,1.01965,,,,,,,,,,,,,, -55200,0.555488,1.0499732,,,,,,,,,,,,,, -55300,0.8318253,1.0078127,,,,,,,,,,,,,, -55400,0.56657636,1.057716,,,,,,,,,,,,,, -55500,0.5121407,1.0463881,,,,,,,,,,,,,, -55600,0.61913383,1.0449308,,,,,,,,,,,,,, -55700,0.61781496,1.0512362,,,,,,,,,,,,,, -55800,0.60180575,1.0547191,,,,,,,,,,,,,, -55900,0.5460638,1.0069586,,,,,,,,,,,,,, -55989,,,0.11142533,0.0426152274898692,0.34446225,0.0983712600287708,5348.0,0.18800361,0.0612191010094855,2472.0,43245.16593647003,47327.161296606064,43245.16593647003,4077.9627606868735,1.6707770824432373,0.0 -56000,0.68787134,1.0216213,,,,,,,,,,,,,, -56100,0.68773305,1.039196,,,,,,,,,,,,,, -56200,0.6448178,0.9944779,,,,,,,,,,,,,, -56300,0.6690158,1.0351979,,,,,,,,,,,,,, -56400,0.5834468,1.0574809,,,,,,,,,,,,,, -56500,0.6510126,1.0718769,,,,,,,,,,,,,, -56600,0.6067031,1.0559965,,,,,,,,,,,,,, -56700,0.59620714,0.95246124,,,,,,,,,,,,,, -56800,0.56750727,0.9899644,,,,,,,,,,,,,, -56900,0.68091446,1.0405343,,,,,,,,,,,,,, -57000,0.5909074,1.0548052,,,,,,,,,,,,,, -57100,0.6526899,1.0395535,,,,,,,,,,,,,, -57200,0.7354577,1.03009,,,,,,,,,,,,,, -57300,0.5890093,1.0320243,,,,,,,,,,,,,, -57400,0.76843226,1.0093778,,,,,,,,,,,,,, -57500,0.6096557,1.0529047,,,,,,,,,,,,,, -57600,0.59175247,0.98784727,,,,,,,,,,,,,, -57700,0.5528661,1.066534,,,,,,,,,,,,,, -57800,0.6187494,0.99057364,,,,,,,,,,,,,, -57853,,,0.09508877,0.03473609018244,0.33794376,0.0959962153760004,5348.0,0.1831914,0.0594519935815408,2472.0,44685.39541554451,48900.7434065342,44685.39541554451,4211.173542499542,1.7332322597503662,0.0 -57900,0.6578569,1.0584096,,,,,,,,,,,,,, -58000,0.56464803,1.015519,,,,,,,,,,,,,, -58100,0.7215142,1.0398871,,,,,,,,,,,,,, -58200,0.61426324,1.0442257,,,,,,,,,,,,,, -58300,0.5776798,0.9797966,,,,,,,,,,,,,, -58400,0.65496284,1.028761,,,,,,,,,,,,,, -58500,0.5903928,1.0052457,,,,,,,,,,,,,, -58600,0.6328448,1.0088904,,,,,,,,,,,,,, -58700,0.99941486,1.010067,,,,,,,,,,,,,, -58800,0.69169235,0.99288505,,,,,,,,,,,,,, -58900,0.8095522,0.9791374,,,,,,,,,,,,,, -59000,0.698773,1.0226223,,,,,,,,,,,,,, -59100,0.6212863,1.0078261,,,,,,,,,,,,,, -59200,0.6990878,1.005882,,,,,,,,,,,,,, -59300,0.62993336,1.01882,,,,,,,,,,,,,, -59400,0.54177403,1.0233018,,,,,,,,,,,,,, -59500,0.6518221,0.9740147,,,,,,,,,,,,,, -59600,0.6347934,0.9988064,,,,,,,,,,,,,, -59700,0.6514312,0.9920787,,,,,,,,,,,,,, -59717,,,0.092592455,0.0354570637119113,0.33490953,0.0958127769678596,5348.0,0.1784495,0.058273921962911,2472.0,46125.69869709015,50472.63144659996,46125.69869709015,4342.618939638138,1.794440984725952,0.0 -59800,0.533207,0.9438719,,,,,,,,,,,,,, -59900,0.70518994,0.99578524,,,,,,,,,,,,,, -60000,0.6555084,0.99760306,,,,,,,,,,,,,, -60100,0.57957894,0.9517747,,,,,,,,,,,,,, -60200,0.6619261,0.99884367,,,,,,,,,,,,,, -60300,0.6233029,0.98942673,,,,,,,,,,,,,, -60400,0.65895784,0.9881238,,,,,,,,,,,,,, -60500,0.77970636,0.95409375,,,,,,,,,,,,,, -60600,0.7444285,0.9853119,,,,,,,,,,,,,, -60700,0.87734884,0.9994049,,,,,,,,,,,,,, -60800,0.66016704,0.9847151,,,,,,,,,,,,,, -60900,0.7258683,1.0028627,,,,,,,,,,,,,, -61000,0.707671,0.9883866,,,,,,,,,,,,,, -61100,0.65553415,1.0262203,,,,,,,,,,,,,, -61200,0.70079976,1.0153606,,,,,,,,,,,,,, -61300,0.7601478,0.9812605,,,,,,,,,,,,,, -61400,0.6766453,0.9487751,,,,,,,,,,,,,, -61500,0.8836784,1.0032213,,,,,,,,,,,,,, -61577,,,0.073665015,0.0287280887458717,0.3230342,0.0926363961111057,5348.0,0.17452674,0.0569333577072288,2472.0,47566.37073302269,52045.9193854332,47566.37073302269,4475.0985696315765,1.8526763916015625,0.0 -61600,0.6524075,0.95802903,,,,,,,,,,,,,, -61700,0.7241233,0.953962,,,,,,,,,,,,,, -61800,0.6411488,0.98477983,,,,,,,,,,,,,, -61900,0.7927574,0.9770151,,,,,,,,,,,,,, -62000,0.65529776,0.9720523,,,,,,,,,,,,,, -62100,0.6618426,0.95224786,,,,,,,,,,,,,, -62200,0.8176746,0.94281155,,,,,,,,,,,,,, -62300,0.6031791,0.92704993,,,,,,,,,,,,,, -62400,0.61427116,0.9671263,,,,,,,,,,,,,, -62500,0.7193419,0.93266636,,,,,,,,,,,,,, -62600,0.58035177,0.9668751,,,,,,,,,,,,,, -62700,0.59197384,0.9650827,,,,,,,,,,,,,, -62800,0.8540582,0.99159473,,,,,,,,,,,,,, -62900,1.0219202,0.95902,,,,,,,,,,,,,, -63000,0.75933456,0.94012,,,,,,,,,,,,,, -63100,0.6300727,0.9818819,,,,,,,,,,,,,, -63200,0.6447726,0.9671811,,,,,,,,,,,,,, -63300,0.8902739,0.9281475,,,,,,,,,,,,,, -63400,0.609724,0.983495,,,,,,,,,,,,,, -63446,,,0.06885544,0.0277676345912648,0.32002056,0.0898558560298135,5348.0,0.17051816,0.0555927934515467,2472.0,49006.61024451256,53619.56326055527,49006.61024451256,4608.362781047821,1.9141466617584229,0.0 -63500,0.7093197,1.0033454,,,,,,,,,,,,,, -63600,0.7125051,0.93915474,,,,,,,,,,,,,, -63700,0.64742327,0.952979,,,,,,,,,,,,,, -63800,0.73246795,0.953929,,,,,,,,,,,,,, -63900,0.6844967,0.97553605,,,,,,,,,,,,,, -64000,0.675454,0.91718453,,,,,,,,,,,,,, -64100,0.64693403,0.96901554,,,,,,,,,,,,,, -64200,0.73367864,0.92781395,,,,,,,,,,,,,, -64300,0.7999578,0.9523914,,,,,,,,,,,,,, -64400,0.6169848,0.9903476,,,,,,,,,,,,,, -64500,0.6750239,0.9685067,,,,,,,,,,,,,, -64600,1.0729165,0.9303798,,,,,,,,,,,,,, -64700,0.6732263,0.9664167,,,,,,,,,,,,,, -64800,0.6924565,0.9775639,,,,,,,,,,,,,, -64900,0.7255964,0.9028659,,,,,,,,,,,,,, -65000,0.7315342,0.9105416,,,,,,,,,,,,,, -65100,0.7692286,0.90757775,,,,,,,,,,,,,, -65200,0.5936813,0.9050159,,,,,,,,,,,,,, -65300,0.75322974,0.9427634,,,,,,,,,,,,,, -65317,,,0.07590765,0.029906168248546,0.3158513,0.0880118172953454,5348.0,0.16880462,0.0543740986736538,2472.0,50446.49883818626,55190.90368533135,50446.49883818626,4739.67716550827,1.9725675582885744,0.0 -65400,0.8636542,0.9661705,,,,,,,,,,,,,, -65500,0.7828073,0.9314573,,,,,,,,,,,,,, -65600,0.7265273,0.9565782,,,,,,,,,,,,,, -65700,0.6467596,0.9482999,,,,,,,,,,,,,, -65800,0.6237521,0.9572749,,,,,,,,,,,,,, -65900,0.7575204,0.94795525,,,,,,,,,,,,,, -66000,1.055622,0.89997816,,,,,,,,,,,,,, -66100,0.6437994,0.93808025,,,,,,,,,,,,,, -66200,0.765784,1.0214149,,,,,,,,,,,,,, -66300,0.76031035,0.95488095,,,,,,,,,,,,,, -66400,0.70293164,0.9486249,,,,,,,,,,,,,, -66500,0.70590025,0.97123617,,,,,,,,,,,,,, -66600,0.6511972,0.92913455,,,,,,,,,,,,,, -66700,0.6677037,0.9411851,,,,,,,,,,,,,, -66800,0.6354736,0.8898127,,,,,,,,,,,,,, -66900,0.5970458,0.94027865,,,,,,,,,,,,,, -67000,0.6592762,0.91552323,,,,,,,,,,,,,, -67100,0.64386344,0.9625644,,,,,,,,,,,,,, -67186,,,0.06836843,0.0268750163044896,0.3090984,0.085858829662956,5348.0,0.1675134,0.0536225702272865,2472.0,51886.94908380509,56764.86767911911,51886.94908380509,4873.053730249405,2.0313730239868164,0.0 -67200,0.6237575,0.97073936,,,,,,,,,,,,,, -67300,0.7862353,0.9087853,,,,,,,,,,,,,, -67400,0.82750094,0.90346587,,,,,,,,,,,,,, -67500,0.64372146,0.9266331,,,,,,,,,,,,,, -67600,0.85505575,0.9287031,,,,,,,,,,,,,, -67700,0.74475574,0.91140914,,,,,,,,,,,,,, -67800,0.80624497,0.9254695,,,,,,,,,,,,,, -67900,0.69550276,0.9076851,,,,,,,,,,,,,, -68000,0.6005598,0.910831,,,,,,,,,,,,,, -68100,0.7741013,0.9518109,,,,,,,,,,,,,, -68200,0.64577407,0.9287067,,,,,,,,,,,,,, -68300,0.69821,0.93257135,,,,,,,,,,,,,, -68400,0.6977414,0.9151993,,,,,,,,,,,,,, -68500,0.7111463,0.9423993,,,,,,,,,,,,,, -68600,0.7974398,0.9921847,,,,,,,,,,,,,, -68700,0.61352664,0.95401734,,,,,,,,,,,,,, -68800,0.6912591,0.91082495,,,,,,,,,,,,,, -68900,0.7440202,0.96129376,,,,,,,,,,,,,, -69000,0.7086694,0.9159448,,,,,,,,,,,,,, -69044,,,0.06659844,0.0258889079465069,0.3053904,0.0853760970099539,5348.0,0.16303761,0.0516726585826579,2472.0,53326.903443574905,58338.23026275635,53326.903443574905,5006.324478149414,2.09139347076416,0.0 -69100,0.6484442,0.96013427,,,,,,,,,,,,,, -69200,0.7009627,0.88499266,,,,,,,,,,,,,, -69300,0.66991234,0.91365755,,,,,,,,,,,,,, -69400,0.96818894,0.90313387,,,,,,,,,,,,,, -69500,0.6504612,0.93642354,,,,,,,,,,,,,, -69600,1.0288821,0.9329096,,,,,,,,,,,,,, -69700,0.64414334,0.93015355,,,,,,,,,,,,,, -69800,0.60243547,0.8886489,,,,,,,,,,,,,, -69900,0.6756951,0.9086165,,,,,,,,,,,,,, -70000,0.66665614,0.8898869,,,,,,,,,,,,,, -70100,0.97538126,0.9090379,,,,,,,,,,,,,, -70200,0.8423215,0.91284484,,,,,,,,,,,,,, -70300,0.841807,0.9167783,,,,,,,,,,,,,, -70400,0.70282954,0.85719234,,,,,,,,,,,,,, -70500,0.63446975,0.91143835,,,,,,,,,,,,,, -70600,0.7091213,0.87301886,,,,,,,,,,,,,, -70700,0.6612849,0.92713577,,,,,,,,,,,,,, -70800,0.56574297,0.9264394,,,,,,,,,,,,,, -70900,0.7068077,0.9155503,,,,,,,,,,,,,, -70916,,,0.05856837,0.0229841510387068,0.30260167,0.0837734246019869,5348.0,0.16102883,0.0516320354233948,2472.0,54767.0825843811,59911.81316590309,54767.0825843811,5139.588047981262,2.15257215499878,0.0 -71000,0.77473867,0.8876763,,,,,,,,,,,,,, -71100,0.6216215,0.8616406,,,,,,,,,,,,,, -71200,0.8301045,0.91205096,,,,,,,,,,,,,, -71300,0.6277434,0.9110472,,,,,,,,,,,,,, -71400,0.7039519,0.8957462,,,,,,,,,,,,,, -71500,0.910985,0.9858902,,,,,,,,,,,,,, -71600,0.8978459,0.83988494,,,,,,,,,,,,,, -71700,0.99416435,0.87303436,,,,,,,,,,,,,, -71800,0.70658827,0.91601884,,,,,,,,,,,,,, -71900,0.64009464,0.87617826,,,,,,,,,,,,,, -72000,0.6471528,0.86945015,,,,,,,,,,,,,, -72100,0.67877287,0.90615237,,,,,,,,,,,,,, -72200,0.9531585,0.8532559,,,,,,,,,,,,,, -72300,0.7197671,0.90133613,,,,,,,,,,,,,, -72400,0.676861,0.913942,,,,,,,,,,,,,, -72500,0.6169312,0.921949,,,,,,,,,,,,,, -72600,0.6761119,0.8712295,,,,,,,,,,,,,, -72700,0.6261895,0.9551278,,,,,,,,,,,,,, -72786,,,0.05365826,0.020761263701913,0.30205798,0.0834837850101856,5348.0,0.15916285,0.0509617532955537,2472.0,56207.34043216705,61485.750854730606,56207.34043216705,5273.123430967331,2.217568635940552,0.0 -72800,0.6250141,0.8909404,,,,,,,,,,,,,, -72900,0.7450738,0.91270643,,,,,,,,,,,,,, -73000,0.7148386,0.89456177,,,,,,,,,,,,,, -73100,0.8412723,0.89373946,,,,,,,,,,,,,, -73200,0.64057624,0.8674504,,,,,,,,,,,,,, -73300,0.6324413,0.9014311,,,,,,,,,,,,,, -73400,0.6602893,0.9295077,,,,,,,,,,,,,, -73500,1.2487644,0.91676563,,,,,,,,,,,,,, -73600,0.75044364,0.92404395,,,,,,,,,,,,,, -73700,1.0468147,0.8581146,,,,,,,,,,,,,, -73800,0.67434645,0.89120966,,,,,,,,,,,,,, -73900,0.6967801,0.91145366,,,,,,,,,,,,,, -74000,0.6891288,0.87951696,,,,,,,,,,,,,, -74100,0.66025215,0.87247795,,,,,,,,,,,,,, -74200,0.6152944,0.87763894,,,,,,,,,,,,,, -74300,0.89464,0.89384806,,,,,,,,,,,,,, -74400,0.72658056,0.90465045,,,,,,,,,,,,,, -74500,1.0261835,0.87313086,,,,,,,,,,,,,, -74600,0.6236157,0.8637087,,,,,,,,,,,,,, -74654,,,0.05984657,0.0220294749769726,0.2980946,0.0816494009287776,5348.0,0.15752992,0.0500883553713972,2472.0,57647.41107773781,63059.09764456749,57647.41107773781,5406.259341955185,2.278391599655152,0.0 -74700,0.79045576,0.85418606,,,,,,,,,,,,,, -74800,0.7841248,0.9362332,,,,,,,,,,,,,, -74900,0.85282737,0.88456535,,,,,,,,,,,,,, -75000,0.6692759,0.8686538,,,,,,,,,,,,,, -75100,0.9400764,0.8962641,,,,,,,,,,,,,, -75200,0.7263907,0.895255,,,,,,,,,,,,,, -75300,0.99348056,0.90020853,,,,,,,,,,,,,, -75400,0.6675476,0.8889555,,,,,,,,,,,,,, -75500,0.82838136,0.8868688,,,,,,,,,,,,,, -75600,0.75165236,0.8509973,,,,,,,,,,,,,, -75700,0.6982449,0.87401587,,,,,,,,,,,,,, -75800,0.6284477,0.87999433,,,,,,,,,,,,,, -75900,0.7099464,0.90803975,,,,,,,,,,,,,, -76000,0.6910709,0.93372107,,,,,,,,,,,,,, -76100,1.0340798,0.9577335,,,,,,,,,,,,,, -76200,0.582163,0.90068513,,,,,,,,,,,,,, -76300,0.94248915,0.8544187,,,,,,,,,,,,,, -76400,0.81283385,0.90055287,,,,,,,,,,,,,, -76500,0.87998235,0.91065365,,,,,,,,,,,,,, -76528,,,0.057438117,0.021416039085486,0.29819798,0.0816300916226575,5348.0,0.15713376,0.0502914711677127,2472.0,59087.65986609459,64631.96136951447,59087.65986609459,5538.735710859299,2.3376240730285645,0.0 -76600,0.8292527,0.9469555,,,,,,,,,,,,,, -76700,0.7855499,0.90273386,,,,,,,,,,,,,, -76800,0.64746124,0.8934928,,,,,,,,,,,,,, -76900,0.7665663,0.8874049,,,,,,,,,,,,,, -77000,0.6580819,0.90279347,,,,,,,,,,,,,, -77100,0.8061049,0.8845098,,,,,,,,,,,,,, -77200,0.8802247,0.9046993,,,,,,,,,,,,,, -77300,1.2066125,0.87400925,,,,,,,,,,,,,, -77400,0.7267102,0.86756206,,,,,,,,,,,,,, -77500,0.87213933,0.88376874,,,,,,,,,,,,,, -77600,0.6326289,0.8948565,,,,,,,,,,,,,, -77700,1.1692102,0.86191094,,,,,,,,,,,,,, -77800,0.7853449,0.86160904,,,,,,,,,,,,,, -77900,0.9019682,0.8623641,,,,,,,,,,,,,, -78000,0.784173,0.90239835,,,,,,,,,,,,,, -78100,0.695626,0.8775079,,,,,,,,,,,,,, -78200,0.968157,0.89702696,,,,,,,,,,,,,, -78300,0.6604572,0.8694376,,,,,,,,,,,,,, -78379,,,0.058575712,0.0221061653007659,0.29666498,0.0812632148063759,5348.0,0.15672202,0.0499867974732394,2472.0,60527.83698248863,66204.51050138474,60527.83698248863,5670.963258266449,2.402360677719116,0.0 -78400,0.797992,0.87399256,,,,,,,,,,,,,, -78500,0.6761745,0.8966381,,,,,,,,,,,,,, -78600,0.8907472,0.9042856,,,,,,,,,,,,,, -78700,0.68064827,0.9027103,,,,,,,,,,,,,, -78800,0.7351543,0.85976917,,,,,,,,,,,,,, -78900,0.7175498,0.8520567,,,,,,,,,,,,,, -79000,0.8417604,0.8321967,,,,,,,,,,,,,, -79091,,,,,,,,,,,61068.11158013344,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index ac32fde7d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -233.57833003997803,0.0,44.56120181083679,1,0,44.56120181083679,30.536453,2472,3.361383624804501,278.1395950317383,31.23901,3.532270995148158,30.351154,5348,3.0460816590555817 -341.49453115463257,0.044753074645996,1485.454912185669,1829,0,1485.454912185669,6.149786,2472,0.8985639713200496,1827.0729236602783,6.2994285,0.9422720438301992,6.226698,5348,0.8960097318902845 -464.7445025444031,0.0834105014801025,2925.4223692417145,3638,0,2925.4223692417145,3.2384326,2472,0.6621168728292,3390.408249616623,3.389681,0.697909064272967,3.6511831,5348,0.7208646707280574 -598.6258668899536,0.1237275600433349,4365.819350719452,5458,0,4365.819350719452,0.68839693,2472,0.2146730851258302,4964.804627656937,0.6259818,0.2104833140935423,1.0288459,5348,0.287447985556639 -736.6049258708954,0.1600997447967529,5806.179289579392,7226,0,5806.179289579392,0.5244182,2472,0.1696220015030569,6543.255874633789,0.48525193,0.1632282252406036,0.8257652,5348,0.237224480338299 -871.8770303726196,0.198617935180664,7246.616269826889,9028,0,7246.616269826889,0.47539398,2472,0.1525196514532935,8119.081926584244,0.43553168,0.1476059492487333,0.76645046,5348,0.2220570203809726 -1006.6984577178956,0.2374663352966308,8687.026130914688,10819,0,8687.026130914688,0.50596225,2472,0.1609489570003859,9694.430442094805,0.4950031,0.1622237278008653,0.8763321,5348,0.2440020467864487 -1139.495243549347,0.2780308723449707,10127.646817445757,12624,0,10127.646817445757,0.4055817,2472,0.129445696991855,11267.96731686592,0.37129277,0.1258162635866675,0.6872197,5348,0.1990210181797117 -1272.3366241455078,0.3184385299682617,11568.152488470078,14400,0,11568.152488470078,0.39234465,2472,0.1261552210915442,12841.431944847109,0.31816357,0.1100496632777763,0.6623367,5348,0.190650433976655 -1408.132239818573,0.3631329536437988,13008.389578580856,16169,0,13008.389578580856,0.3739464,2472,0.1225803830763918,14417.589095115662,0.3033341,0.1042565529422404,0.64782745,5348,0.1861706749567954 -1541.5824666023254,0.4025120735168457,14448.71219611168,17980,0,14448.71219611168,0.36189067,2472,0.1170556334166108,15991.481305122375,0.29420593,0.1039004220585948,0.6150627,5348,0.1781669675700203 -1674.7982478141785,0.4451837539672851,15888.818793296814,19763,0,15888.818793296814,0.35354367,2472,0.1129933174903012,17564.92405796051,0.30935612,0.1035617017901055,0.6012965,5348,0.1738320283460613 -1809.982824802399,0.487401008605957,17329.08687067032,21541,0,17329.08687067032,0.34354234,2472,0.1115715069160928,19140.49734663964,0.30396807,0.1025509800300736,0.5889155,5348,0.1706653021423675 -1944.016057014465,0.5339093208312988,18769.07917690277,23309,0,18769.07917690277,0.33128116,2472,0.106351430950785,20714.64650440216,0.28539538,0.0970058648009054,0.5723087,5348,0.1664655280612491 -2078.508171081543,0.5718889236450195,20209.84367251396,25094,0,20209.84367251396,0.3237855,2472,0.1045437003635772,22290.02009320259,0.26591775,0.0911437439297326,0.5673303,5348,0.1649111289185823 -2223.6968109607697,0.6204798221588135,21650.451526880264,26858,0,21650.451526880264,0.31746915,2472,0.1007251233928462,23875.94362139702,0.24384421,0.0851665405983463,0.55111766,5348,0.1586549137356749 -2358.1283671855927,0.6628987789154053,23090.378248929977,28641,0,23090.378248929977,0.2958047,2472,0.0958503442812747,25450.42253112793,0.22982687,0.079006044737192,0.52235276,5348,0.1524759357772478 -2496.595645904541,0.7129337787628174,24530.86066389084,30429,0,24530.86066389084,0.2912043,2472,0.0946925842422765,27029.50042438507,0.24319935,0.0841892009013693,0.51407164,5348,0.148855440879732 -2629.5989384651184,0.7543802261352539,25970.978246688843,32214,0,25970.978246688843,0.27680507,2472,0.0891678345824954,28602.740975379944,0.221281,0.0749217304488901,0.4996829,5348,0.1448004865945142 -2763.8282132148743,0.8003880977630615,27412.06226158142,33991,0,27412.06226158142,0.2669557,2472,0.0857554892043954,30178.17682671547,0.22154376,0.0749528687156147,0.47980833,5348,0.1381870492483852 -2897.226144552231,0.8438427448272705,28852.0884373188,35787,0,28852.0884373188,0.25955525,2472,0.0829728027948733,31751.72282457352,0.21164294,0.0702121764336029,0.46737054,5348,0.1357830406364347 -3032.571701526642,0.8902647495269775,30292.03697609901,37564,0,30292.03697609901,0.24834183,2472,0.0789307984481953,33327.14091706276,0.15737125,0.0545032440576858,0.44882238,5348,0.1306371105554322 -3166.721561670304,0.934283971786499,31732.59752678871,39355,0,31732.59752678871,0.24206062,2472,0.0764527857331464,34901.9730963707,0.17283808,0.0592490126924457,0.43858835,5348,0.1268814505150757 -3299.519961833954,0.9837052822113036,33172.82326436043,41130,0,33172.82326436043,0.23490724,2472,0.0752340909552535,36475.123975515366,0.21339306,0.0732522252295742,0.42771238,5348,0.1242940034949844 -3429.93771147728,1.0329890251159668,34613.33181476593,42909,0,34613.33181476593,0.22871256,2472,0.0724514045457315,38046.176903009415,0.21680024,0.0741213631846872,0.41661587,5348,0.1206059260260482 -3561.055627822876,1.075854778289795,36053.66913104057,44705,0,36053.66913104057,0.22644746,2472,0.0710499055511547,39617.75190496445,0.24842209,0.0861474925563465,0.412366,5348,0.1190032536180812 -3691.6905381679535,1.123474359512329,37494.07298588753,46467,0,37494.07298588753,0.22343448,2472,0.0708264781752076,41188.91581845284,0.22267985,0.0739774741093559,0.40985504,5348,0.1182019174140977 -3832.4622428417206,1.1707794666290283,38733.70230102539,48000,0,38733.70230102539,0.22379321,2472,0.07068429711778686,42569.4317715168,0.2067998,0.0725563114471179,0.41027457,5348,0.11796055108759666 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index 7da290c51..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,22.147774,32.615826,,,,,,,,,,,,,, -1,,,31.23901,3.532270995148158,30.351154,3.0460816590555817,5348.0,30.536453,3.361383624804501,2472.0,44.56120181083679,278.1395950317383,44.56120181083679,233.57833003997803,0.0,0.0 -100,7.8380814,9.602661,,,,,,,,,,,,,, -200,1.8056128,6.6094503,,,,,,,,,,,,,, -300,0.5066531,5.8954425,,,,,,,,,,,,,, -400,0.34048876,5.8645124,,,,,,,,,,,,,, -500,0.46748134,5.8475175,,,,,,,,,,,,,, -600,0.32294354,5.793175,,,,,,,,,,,,,, -700,0.40629444,5.778879,,,,,,,,,,,,,, -800,0.36939514,5.717817,,,,,,,,,,,,,, -900,0.46345606,5.6048827,,,,,,,,,,,,,, -1000,1.0211233,5.5013366,,,,,,,,,,,,,, -1100,0.6452601,5.354233,,,,,,,,,,,,,, -1200,0.6563787,5.0553303,,,,,,,,,,,,,, -1300,1.3352801,4.5509195,,,,,,,,,,,,,, -1400,1.3604256,4.1213613,,,,,,,,,,,,,, -1500,2.4294436,3.8402033,,,,,,,,,,,,,, -1600,2.006952,3.5816517,,,,,,,,,,,,,, -1700,1.8443294,3.424828,,,,,,,,,,,,,, -1800,3.2240732,3.2808568,,,,,,,,,,,,,, -1829,,,6.2994285,0.9422720438301992,6.226698,0.8960097318902845,5348.0,6.149786,0.8985639713200496,2472.0,1485.454912185669,1827.0729236602783,1485.454912185669,341.49453115463257,0.044753074645996,0.0 -1900,3.7624505,3.147782,,,,,,,,,,,,,, -2000,1.8948482,2.982963,,,,,,,,,,,,,, -2100,2.535188,2.7986598,,,,,,,,,,,,,, -2200,4.8903227,2.8255067,,,,,,,,,,,,,, -2300,2.6208067,2.700283,,,,,,,,,,,,,, -2400,2.8298082,2.649091,,,,,,,,,,,,,, -2500,4.675197,2.5984244,,,,,,,,,,,,,, -2600,3.6624048,2.5094063,,,,,,,,,,,,,, -2700,4.631579,2.5270164,,,,,,,,,,,,,, -2800,3.4223924,2.393486,,,,,,,,,,,,,, -2900,2.4042847,2.3818429,,,,,,,,,,,,,, -3000,5.6695976,2.3409255,,,,,,,,,,,,,, -3100,3.1537085,2.3471792,,,,,,,,,,,,,, -3200,3.923031,2.261043,,,,,,,,,,,,,, -3300,3.4046338,2.260425,,,,,,,,,,,,,, -3400,2.718353,2.1858006,,,,,,,,,,,,,, -3500,3.3597515,2.2083175,,,,,,,,,,,,,, -3600,5.0719295,2.1544702,,,,,,,,,,,,,, -3638,,,3.389681,0.697909064272967,3.6511831,0.7208646707280574,5348.0,3.2384326,0.6621168728292,2472.0,2925.4223692417145,3390.408249616623,2925.4223692417145,464.7445025444031,0.0834105014801025,0.0 -3700,4.24444,2.1306317,,,,,,,,,,,,,, -3800,3.400603,2.1399863,,,,,,,,,,,,,, -3900,4.965709,2.1259983,,,,,,,,,,,,,, -4000,3.8779466,2.1549308,,,,,,,,,,,,,, -4100,3.6390984,2.1001225,,,,,,,,,,,,,, -4200,4.7039204,2.0646472,,,,,,,,,,,,,, -4300,3.240824,2.0645237,,,,,,,,,,,,,, -4400,3.0007827,2.0328724,,,,,,,,,,,,,, -4500,3.2843423,2.0156572,,,,,,,,,,,,,, -4600,3.3876052,2.0031545,,,,,,,,,,,,,, -4700,4.3051214,1.9461629,,,,,,,,,,,,,, -4800,2.941184,1.9721967,,,,,,,,,,,,,, -4900,3.7997205,1.9294578,,,,,,,,,,,,,, -5000,4.584596,2.0453792,,,,,,,,,,,,,, -5100,2.8581817,1.9052347,,,,,,,,,,,,,, -5200,4.1097426,1.9303689,,,,,,,,,,,,,, -5300,3.6417177,1.9246036,,,,,,,,,,,,,, -5400,3.432431,2.0021927,,,,,,,,,,,,,, -5458,,,0.6259818,0.2104833140935423,1.0288459,0.287447985556639,5348.0,0.68839693,0.2146730851258302,2472.0,4365.819350719452,4964.804627656937,4365.819350719452,598.6258668899536,0.1237275600433349,0.0 -5500,2.7993803,1.8749723,,,,,,,,,,,,,, -5600,3.3403263,1.8817154,,,,,,,,,,,,,, -5700,3.8923054,1.8327636,,,,,,,,,,,,,, -5800,3.4864388,1.8673744,,,,,,,,,,,,,, -5900,2.9217744,1.9208283,,,,,,,,,,,,,, -6000,3.9236705,1.8490895,,,,,,,,,,,,,, -6100,5.7859955,1.9528346,,,,,,,,,,,,,, -6200,5.1869006,1.8350897,,,,,,,,,,,,,, -6300,3.0254724,1.8169398,,,,,,,,,,,,,, -6400,3.0765412,1.7624315,,,,,,,,,,,,,, -6500,3.3087795,1.7738181,,,,,,,,,,,,,, -6600,3.893984,1.7900627,,,,,,,,,,,,,, -6700,3.017784,1.8089492,,,,,,,,,,,,,, -6800,2.5479398,1.7747467,,,,,,,,,,,,,, -6900,4.561333,1.7940937,,,,,,,,,,,,,, -7000,3.3375998,1.7794045,,,,,,,,,,,,,, -7100,4.245174,1.8144281,,,,,,,,,,,,,, -7200,5.7641087,1.7352456,,,,,,,,,,,,,, -7226,,,0.48525193,0.1632282252406036,0.8257652,0.237224480338299,5348.0,0.5244182,0.1696220015030569,2472.0,5806.179289579392,6543.255874633789,5806.179289579392,736.6049258708954,0.1600997447967529,0.0 -7300,3.8560095,1.8026054,,,,,,,,,,,,,, -7400,3.5161486,1.7311518,,,,,,,,,,,,,, -7500,2.234071,1.7655956,,,,,,,,,,,,,, -7600,5.408951,1.6685243,,,,,,,,,,,,,, -7700,4.655695,1.743797,,,,,,,,,,,,,, -7800,4.3625526,1.7285098,,,,,,,,,,,,,, -7900,4.286613,1.7238085,,,,,,,,,,,,,, -8000,3.791013,1.7625656,,,,,,,,,,,,,, -8100,3.9652169,1.7285877,,,,,,,,,,,,,, -8200,3.5078907,1.6571715,,,,,,,,,,,,,, -8300,4.252633,1.7409527,,,,,,,,,,,,,, -8400,2.9313433,1.6696044,,,,,,,,,,,,,, -8500,4.0308022,1.6124533,,,,,,,,,,,,,, -8600,2.841611,1.7398659,,,,,,,,,,,,,, -8700,2.704028,1.6900741,,,,,,,,,,,,,, -8800,3.315727,1.6648742,,,,,,,,,,,,,, -8900,2.6739492,1.6275492,,,,,,,,,,,,,, -9000,3.6436656,1.6602471,,,,,,,,,,,,,, -9028,,,0.43553168,0.1476059492487333,0.76645046,0.2220570203809726,5348.0,0.47539398,0.1525196514532935,2472.0,7246.616269826889,8119.081926584244,7246.616269826889,871.8770303726196,0.198617935180664,0.0 -9100,2.8042054,1.7064553,,,,,,,,,,,,,, -9200,2.549745,1.6451432,,,,,,,,,,,,,, -9300,3.1880777,1.63652,,,,,,,,,,,,,, -9400,3.338663,1.703117,,,,,,,,,,,,,, -9500,4.4439764,1.6954339,,,,,,,,,,,,,, -9600,3.3031988,1.6840882,,,,,,,,,,,,,, -9700,3.5018806,1.6437062,,,,,,,,,,,,,, -9800,5.272682,1.6168777,,,,,,,,,,,,,, -9900,3.1560283,1.626661,,,,,,,,,,,,,, -10000,2.8056028,1.5705689,,,,,,,,,,,,,, -10100,3.0730867,1.6477447,,,,,,,,,,,,,, -10200,1.9653699,1.5735643,,,,,,,,,,,,,, -10300,2.9325247,1.632237,,,,,,,,,,,,,, -10400,3.1592762,1.6243069,,,,,,,,,,,,,, -10500,4.3111477,1.6609529,,,,,,,,,,,,,, -10600,3.4052315,1.5340109,,,,,,,,,,,,,, -10700,2.343913,1.6504692,,,,,,,,,,,,,, -10800,2.901842,1.6125141,,,,,,,,,,,,,, -10819,,,0.4950031,0.1622237278008653,0.8763321,0.2440020467864487,5348.0,0.50596225,0.1609489570003859,2472.0,8687.026130914688,9694.430442094805,8687.026130914688,1006.6984577178956,0.2374663352966308,0.0 -10900,2.5998778,1.6331214,,,,,,,,,,,,,, -11000,3.4133976,1.6157955,,,,,,,,,,,,,, -11100,3.259229,1.605659,,,,,,,,,,,,,, -11200,3.041686,1.5762928,,,,,,,,,,,,,, -11300,3.222455,1.636687,,,,,,,,,,,,,, -11400,2.8475392,1.6297,,,,,,,,,,,,,, -11500,2.8671076,1.580381,,,,,,,,,,,,,, -11600,3.3774762,1.6053988,,,,,,,,,,,,,, -11700,2.8733437,1.5419196,,,,,,,,,,,,,, -11800,2.3776777,1.6386578,,,,,,,,,,,,,, -11900,3.4755309,1.5530849,,,,,,,,,,,,,, -12000,2.3125014,1.5571206,,,,,,,,,,,,,, -12100,2.3247275,1.5871651,,,,,,,,,,,,,, -12200,2.4868853,1.6289964,,,,,,,,,,,,,, -12300,2.7718327,1.6067654,,,,,,,,,,,,,, -12400,4.6886115,1.547992,,,,,,,,,,,,,, -12500,4.9223633,1.6011142,,,,,,,,,,,,,, -12600,4.352079,1.6367604,,,,,,,,,,,,,, -12624,,,0.37129277,0.1258162635866675,0.6872197,0.1990210181797117,5348.0,0.4055817,0.129445696991855,2472.0,10127.646817445757,11267.96731686592,10127.646817445757,1139.495243549347,0.2780308723449707,0.0 -12700,3.146875,1.564803,,,,,,,,,,,,,, -12800,3.0083423,1.6566491,,,,,,,,,,,,,, -12900,2.4159615,1.5803587,,,,,,,,,,,,,, -13000,3.3574576,1.5638193,,,,,,,,,,,,,, -13100,2.6762795,1.5821384,,,,,,,,,,,,,, -13200,2.9275813,1.5578068,,,,,,,,,,,,,, -13300,2.6336393,1.6022515,,,,,,,,,,,,,, -13400,3.6186242,1.5457504,,,,,,,,,,,,,, -13500,2.7730067,1.5810583,,,,,,,,,,,,,, -13600,3.0540836,1.5770524,,,,,,,,,,,,,, -13700,3.8429828,1.6404884,,,,,,,,,,,,,, -13800,2.039898,1.52359,,,,,,,,,,,,,, -13900,3.1474833,1.5914166,,,,,,,,,,,,,, -14000,4.237532,1.500794,,,,,,,,,,,,,, -14100,2.9699101,1.4780173,,,,,,,,,,,,,, -14200,2.50274,1.5631988,,,,,,,,,,,,,, -14300,2.3040743,1.5651941,,,,,,,,,,,,,, -14400,,,0.31816357,0.1100496632777763,0.6623367,0.190650433976655,5348.0,0.39234465,0.1261552210915442,2472.0,11568.152488470078,12841.431944847109,11568.152488470078,1272.3366241455078,0.3184385299682617,0.0 -14400,3.3625832,1.5448141,,,,,,,,,,,,,, -14500,3.1129546,1.5786837,,,,,,,,,,,,,, -14600,3.242456,1.5189463,,,,,,,,,,,,,, -14700,3.72212,1.515999,,,,,,,,,,,,,, -14800,3.1599154,1.5321764,,,,,,,,,,,,,, -14900,2.7404163,1.6069121,,,,,,,,,,,,,, -15000,4.9475694,1.5200281,,,,,,,,,,,,,, -15100,3.6111064,1.5137438,,,,,,,,,,,,,, -15200,3.1848664,1.5118951,,,,,,,,,,,,,, -15300,2.8214183,1.5869304,,,,,,,,,,,,,, -15400,2.2382343,1.5629327,,,,,,,,,,,,,, -15500,2.5207279,1.4942905,,,,,,,,,,,,,, -15600,3.6554024,1.5352458,,,,,,,,,,,,,, -15700,2.381152,1.533622,,,,,,,,,,,,,, -15800,2.1201305,1.5507946,,,,,,,,,,,,,, -15900,3.410875,1.4710702,,,,,,,,,,,,,, -16000,3.2714918,1.5309751,,,,,,,,,,,,,, -16100,2.5483093,1.5083401,,,,,,,,,,,,,, -16169,,,0.3033341,0.1042565529422404,0.64782745,0.1861706749567954,5348.0,0.3739464,0.1225803830763918,2472.0,13008.389578580856,14417.589095115662,13008.389578580856,1408.132239818573,0.3631329536437988,0.0 -16200,3.1357622,1.5838461,,,,,,,,,,,,,, -16300,2.4729671,1.4911777,,,,,,,,,,,,,, -16400,3.4495966,1.534085,,,,,,,,,,,,,, -16500,3.2035148,1.4986289,,,,,,,,,,,,,, -16600,3.32721,1.4859359,,,,,,,,,,,,,, -16700,2.0885963,1.4996277,,,,,,,,,,,,,, -16800,2.5799584,1.538456,,,,,,,,,,,,,, -16900,2.1346908,1.5712436,,,,,,,,,,,,,, -17000,1.8035148,1.4699336,,,,,,,,,,,,,, -17100,2.6842673,1.5139546,,,,,,,,,,,,,, -17200,3.4934819,1.529375,,,,,,,,,,,,,, -17300,2.6166286,1.5340095,,,,,,,,,,,,,, -17400,1.8977398,1.5207461,,,,,,,,,,,,,, -17500,2.7423732,1.5330342,,,,,,,,,,,,,, -17600,2.9717774,1.4795861,,,,,,,,,,,,,, -17700,2.849745,1.421926,,,,,,,,,,,,,, -17800,3.1735785,1.5462455,,,,,,,,,,,,,, -17900,2.5829082,1.5967962,,,,,,,,,,,,,, -17980,,,0.29420593,0.1039004220585948,0.6150627,0.1781669675700203,5348.0,0.36189067,0.1170556334166108,2472.0,14448.71219611168,15991.481305122375,14448.71219611168,1541.5824666023254,0.4025120735168457,0.0 -18000,4.2631536,1.4501793,,,,,,,,,,,,,, -18100,2.7290022,1.485398,,,,,,,,,,,,,, -18200,2.8158393,1.4827782,,,,,,,,,,,,,, -18300,3.1831787,1.466582,,,,,,,,,,,,,, -18400,2.9045622,1.5308287,,,,,,,,,,,,,, -18500,2.478624,1.4878155,,,,,,,,,,,,,, -18600,4.9284906,1.4887347,,,,,,,,,,,,,, -18700,2.3802216,1.5334406,,,,,,,,,,,,,, -18800,3.766846,1.4570636,,,,,,,,,,,,,, -18900,2.776145,1.5073147,,,,,,,,,,,,,, -19000,3.6898181,1.4897907,,,,,,,,,,,,,, -19100,3.3498797,1.4714092,,,,,,,,,,,,,, -19200,2.311377,1.4118805,,,,,,,,,,,,,, -19300,3.0993073,1.4623917,,,,,,,,,,,,,, -19400,2.590718,1.483354,,,,,,,,,,,,,, -19500,2.8760788,1.5326884,,,,,,,,,,,,,, -19600,2.532563,1.5065677,,,,,,,,,,,,,, -19700,3.1346486,1.4360231,,,,,,,,,,,,,, -19763,,,0.30935612,0.1035617017901055,0.6012965,0.1738320283460613,5348.0,0.35354367,0.1129933174903012,2472.0,15888.818793296814,17564.92405796051,15888.818793296814,1674.7982478141785,0.4451837539672851,0.0 -19800,2.642296,1.5055095,,,,,,,,,,,,,, -19900,2.7403095,1.4708531,,,,,,,,,,,,,, -20000,3.7763953,1.5099051,,,,,,,,,,,,,, -20100,2.4625952,1.3920114,,,,,,,,,,,,,, -20200,3.0937076,1.460287,,,,,,,,,,,,,, -20300,2.7970946,1.4913808,,,,,,,,,,,,,, -20400,3.2158716,1.4847804,,,,,,,,,,,,,, -20500,2.249474,1.4455805,,,,,,,,,,,,,, -20600,2.4993398,1.4276254,,,,,,,,,,,,,, -20700,2.979988,1.4020808,,,,,,,,,,,,,, -20800,2.9945097,1.4195738,,,,,,,,,,,,,, -20900,3.2199345,1.5055871,,,,,,,,,,,,,, -21000,2.3693178,1.4481957,,,,,,,,,,,,,, -21100,3.905719,1.4623969,,,,,,,,,,,,,, -21200,3.266732,1.4933435,,,,,,,,,,,,,, -21300,3.0903628,1.4750104,,,,,,,,,,,,,, -21400,3.8159502,1.4101672,,,,,,,,,,,,,, -21500,2.2560153,1.4381278,,,,,,,,,,,,,, -21541,,,0.30396807,0.1025509800300736,0.5889155,0.1706653021423675,5348.0,0.34354234,0.1115715069160928,2472.0,17329.08687067032,19140.49734663964,17329.08687067032,1809.982824802399,0.487401008605957,0.0 -21600,2.0276604,1.4575635,,,,,,,,,,,,,, -21700,2.4812078,1.4414716,,,,,,,,,,,,,, -21800,2.6962607,1.4246311,,,,,,,,,,,,,, -21900,3.6346781,1.4673333,,,,,,,,,,,,,, -22000,1.9218484,1.4462502,,,,,,,,,,,,,, -22100,2.74296,1.4493217,,,,,,,,,,,,,, -22200,2.7253118,1.4083555,,,,,,,,,,,,,, -22300,3.3874102,1.5600632,,,,,,,,,,,,,, -22400,3.4298213,1.469319,,,,,,,,,,,,,, -22500,2.532029,1.4362577,,,,,,,,,,,,,, -22600,3.4654632,1.4581097,,,,,,,,,,,,,, -22700,3.306721,1.4858042,,,,,,,,,,,,,, -22800,2.8978555,1.4127562,,,,,,,,,,,,,, -22900,3.823562,1.4913229,,,,,,,,,,,,,, -23000,2.3529897,1.4026004,,,,,,,,,,,,,, -23100,2.512678,1.4578274,,,,,,,,,,,,,, -23200,1.9574873,1.4084789,,,,,,,,,,,,,, -23300,2.9055502,1.4356606,,,,,,,,,,,,,, -23309,,,0.28539538,0.0970058648009054,0.5723087,0.1664655280612491,5348.0,0.33128116,0.106351430950785,2472.0,18769.07917690277,20714.64650440216,18769.07917690277,1944.016057014465,0.5339093208312988,0.0 -23400,4.0280523,1.4713436,,,,,,,,,,,,,, -23500,3.6444674,1.4352779,,,,,,,,,,,,,, -23600,2.8729193,1.4194416,,,,,,,,,,,,,, -23700,2.8075957,1.4114308,,,,,,,,,,,,,, -23800,2.4393325,1.3436754,,,,,,,,,,,,,, -23900,2.3840566,1.3826025,,,,,,,,,,,,,, -24000,4.3802304,1.4991926,,,,,,,,,,,,,, -24100,3.1330993,1.4018179,,,,,,,,,,,,,, -24200,3.306259,1.37331,,,,,,,,,,,,,, -24300,2.2502115,1.3689455,,,,,,,,,,,,,, -24400,3.278146,1.4386561,,,,,,,,,,,,,, -24500,2.7786608,1.4088569,,,,,,,,,,,,,, -24600,3.265647,1.4226097,,,,,,,,,,,,,, -24700,2.1883683,1.363473,,,,,,,,,,,,,, -24800,2.727392,1.410502,,,,,,,,,,,,,, -24900,3.531644,1.437918,,,,,,,,,,,,,, -25000,3.2864199,1.3809589,,,,,,,,,,,,,, -25094,,,0.26591775,0.0911437439297326,0.5673303,0.1649111289185823,5348.0,0.3237855,0.1045437003635772,2472.0,20209.84367251396,22290.02009320259,20209.84367251396,2078.508171081543,0.5718889236450195,0.0 -25100,2.63169,1.4013194,,,,,,,,,,,,,, -25200,2.8414621,1.3806719,,,,,,,,,,,,,, -25300,4.417436,1.4717064,,,,,,,,,,,,,, -25400,3.008773,1.4316323,,,,,,,,,,,,,, -25500,2.7171724,1.4709396,,,,,,,,,,,,,, -25600,4.054471,1.4080254,,,,,,,,,,,,,, -25700,2.7638662,1.4200243,,,,,,,,,,,,,, -25800,2.9956353,1.4193758,,,,,,,,,,,,,, -25900,2.7205775,1.4084724,,,,,,,,,,,,,, -26000,2.9562485,1.4213364,,,,,,,,,,,,,, -26100,2.6524615,1.3845476,,,,,,,,,,,,,, -26200,3.3616042,1.4159849,,,,,,,,,,,,,, -26300,2.3349476,1.4152104,,,,,,,,,,,,,, -26400,2.260625,1.3733928,,,,,,,,,,,,,, -26500,4.6468015,1.4024374,,,,,,,,,,,,,, -26600,2.3978686,1.4025598,,,,,,,,,,,,,, -26700,2.306175,1.4186424,,,,,,,,,,,,,, -26800,2.5961072,1.4609768,,,,,,,,,,,,,, -26858,,,0.24384421,0.0851665405983463,0.55111766,0.1586549137356749,5348.0,0.31746915,0.1007251233928462,2472.0,21650.451526880264,23875.94362139702,21650.451526880264,2223.6968109607697,0.6204798221588135,0.0 -26900,1.9244462,1.3704001,,,,,,,,,,,,,, -27000,3.1414273,1.3961422,,,,,,,,,,,,,, -27100,2.75944,1.3973416,,,,,,,,,,,,,, -27200,2.711991,1.3947273,,,,,,,,,,,,,, -27300,2.8470762,1.3769535,,,,,,,,,,,,,, -27400,2.4159448,1.4458216,,,,,,,,,,,,,, -27500,2.447088,1.3841062,,,,,,,,,,,,,, -27600,3.5879562,1.3546212,,,,,,,,,,,,,, -27700,3.1010664,1.3927333,,,,,,,,,,,,,, -27800,3.7387235,1.3658131,,,,,,,,,,,,,, -27900,2.3610132,1.4043071,,,,,,,,,,,,,, -28000,3.0416737,1.38667,,,,,,,,,,,,,, -28100,3.382141,1.3936733,,,,,,,,,,,,,, -28200,5.181467,1.3614454,,,,,,,,,,,,,, -28300,3.4131866,1.3782954,,,,,,,,,,,,,, -28400,5.350908,1.3985525,,,,,,,,,,,,,, -28500,3.169579,1.3602915,,,,,,,,,,,,,, -28600,2.6391199,1.3933557,,,,,,,,,,,,,, -28641,,,0.22982687,0.079006044737192,0.52235276,0.1524759357772478,5348.0,0.2958047,0.0958503442812747,2472.0,23090.378248929977,25450.42253112793,23090.378248929977,2358.1283671855927,0.6628987789154053,0.0 -28700,3.0494547,1.4025948,,,,,,,,,,,,,, -28800,2.860873,1.3689022,,,,,,,,,,,,,, -28900,2.31186,1.3427192,,,,,,,,,,,,,, -29000,2.4356859,1.3370436,,,,,,,,,,,,,, -29100,3.3177373,1.3983557,,,,,,,,,,,,,, -29200,2.726945,1.3033543,,,,,,,,,,,,,, -29300,2.536003,1.3629771,,,,,,,,,,,,,, -29400,2.733079,1.402635,,,,,,,,,,,,,, -29500,3.144722,1.3350146,,,,,,,,,,,,,, -29600,3.040768,1.382514,,,,,,,,,,,,,, -29700,3.178263,1.3168226,,,,,,,,,,,,,, -29800,2.848854,1.3095356,,,,,,,,,,,,,, -29900,2.8503277,1.3717599,,,,,,,,,,,,,, -30000,2.473205,1.3376493,,,,,,,,,,,,,, -30100,3.3809013,1.3526491,,,,,,,,,,,,,, -30200,2.3023221,1.393272,,,,,,,,,,,,,, -30300,2.1253748,1.2965614,,,,,,,,,,,,,, -30400,3.3155813,1.3801173,,,,,,,,,,,,,, -30429,,,0.24319935,0.0841892009013693,0.51407164,0.148855440879732,5348.0,0.2912043,0.0946925842422765,2472.0,24530.86066389084,27029.50042438507,24530.86066389084,2496.595645904541,0.7129337787628174,0.0 -30500,3.2760265,1.3067689,,,,,,,,,,,,,, -30600,3.4472647,1.385017,,,,,,,,,,,,,, -30700,3.5176082,1.371676,,,,,,,,,,,,,, -30800,2.668403,1.3960238,,,,,,,,,,,,,, -30900,2.9801543,1.3590211,,,,,,,,,,,,,, -31000,3.315759,1.4009812,,,,,,,,,,,,,, -31100,2.3506393,1.298072,,,,,,,,,,,,,, -31200,2.0867596,1.3269825,,,,,,,,,,,,,, -31300,2.3876176,1.3294976,,,,,,,,,,,,,, -31400,2.428559,1.2694472,,,,,,,,,,,,,, -31500,3.1961362,1.3228822,,,,,,,,,,,,,, -31600,3.339054,1.2659917,,,,,,,,,,,,,, -31700,2.2111423,1.3163159,,,,,,,,,,,,,, -31800,2.6081924,1.2875456,,,,,,,,,,,,,, -31900,2.9145696,1.3518994,,,,,,,,,,,,,, -32000,2.7317753,1.3401692,,,,,,,,,,,,,, -32100,2.2842007,1.3272784,,,,,,,,,,,,,, -32200,3.8929641,1.3045665,,,,,,,,,,,,,, -32214,,,0.221281,0.0749217304488901,0.4996829,0.1448004865945142,5348.0,0.27680507,0.0891678345824954,2472.0,25970.978246688843,28602.740975379944,25970.978246688843,2629.5989384651184,0.7543802261352539,0.0 -32300,2.6248922,1.3238133,,,,,,,,,,,,,, -32400,2.8418226,1.2400715,,,,,,,,,,,,,, -32500,3.1545029,1.3285959,,,,,,,,,,,,,, -32600,3.6276314,1.3455321,,,,,,,,,,,,,, -32700,3.0513122,1.2225622,,,,,,,,,,,,,, -32800,2.9700782,1.2906713,,,,,,,,,,,,,, -32900,2.4230835,1.3594594,,,,,,,,,,,,,, -33000,2.4129276,1.2168399,,,,,,,,,,,,,, -33100,2.6745455,1.3042619,,,,,,,,,,,,,, -33200,2.938608,1.2407626,,,,,,,,,,,,,, -33300,3.648444,1.3047934,,,,,,,,,,,,,, -33400,2.8139863,1.2839352,,,,,,,,,,,,,, -33500,3.1843917,1.310776,,,,,,,,,,,,,, -33600,3.7690322,1.3006923,,,,,,,,,,,,,, -33700,2.513898,1.2339383,,,,,,,,,,,,,, -33800,2.8866587,1.263547,,,,,,,,,,,,,, -33900,2.9382145,1.2649664,,,,,,,,,,,,,, -33991,,,0.22154376,0.0749528687156147,0.47980833,0.1381870492483852,5348.0,0.2669557,0.0857554892043954,2472.0,27412.06226158142,30178.17682671547,27412.06226158142,2763.8282132148743,0.8003880977630615,0.0 -34000,3.1108077,1.2475511,,,,,,,,,,,,,, -34100,2.8631215,1.3103278,,,,,,,,,,,,,, -34200,2.6483004,1.295712,,,,,,,,,,,,,, -34300,2.3576088,1.2910124,,,,,,,,,,,,,, -34400,3.340096,1.2532178,,,,,,,,,,,,,, -34500,4.1977234,1.3264985,,,,,,,,,,,,,, -34600,3.7830758,1.3132569,,,,,,,,,,,,,, -34700,3.5016668,1.2685404,,,,,,,,,,,,,, -34800,3.4543047,1.2693297,,,,,,,,,,,,,, -34900,2.7203102,1.2582024,,,,,,,,,,,,,, -35000,2.0241923,1.2530284,,,,,,,,,,,,,, -35100,3.1869042,1.2768533,,,,,,,,,,,,,, -35200,3.0756838,1.1891048,,,,,,,,,,,,,, -35300,4.7053027,1.2599527,,,,,,,,,,,,,, -35400,2.4486697,1.2401196,,,,,,,,,,,,,, -35500,2.7420156,1.247769,,,,,,,,,,,,,, -35600,3.6202188,1.2491024,,,,,,,,,,,,,, -35700,2.9982157,1.2220784,,,,,,,,,,,,,, -35787,,,0.21164294,0.0702121764336029,0.46737054,0.1357830406364347,5348.0,0.25955525,0.0829728027948733,2472.0,28852.0884373188,31751.72282457352,28852.0884373188,2897.226144552231,0.8438427448272705,0.0 -35800,2.506471,1.2265366,,,,,,,,,,,,,, -35900,2.4498038,1.2393237,,,,,,,,,,,,,, -36000,3.7383556,1.2823988,,,,,,,,,,,,,, -36100,2.682148,1.2692463,,,,,,,,,,,,,, -36200,2.442957,1.2626123,,,,,,,,,,,,,, -36300,2.917051,1.1909707,,,,,,,,,,,,,, -36400,2.9999542,1.2431501,,,,,,,,,,,,,, -36500,3.9352834,1.2638646,,,,,,,,,,,,,, -36600,2.8717535,1.2343876,,,,,,,,,,,,,, -36700,2.8353283,1.1764767,,,,,,,,,,,,,, -36800,2.8387554,1.195927,,,,,,,,,,,,,, -36900,3.1350462,1.2166307,,,,,,,,,,,,,, -37000,3.9330165,1.1951677,,,,,,,,,,,,,, -37100,3.000756,1.1611702,,,,,,,,,,,,,, -37200,2.9187675,1.2361448,,,,,,,,,,,,,, -37300,2.5810497,1.2352357,,,,,,,,,,,,,, -37400,2.72432,1.1648289,,,,,,,,,,,,,, -37500,5.105663,1.2073811,,,,,,,,,,,,,, -37564,,,0.15737125,0.0545032440576858,0.44882238,0.1306371105554322,5348.0,0.24834183,0.0789307984481953,2472.0,30292.03697609901,33327.14091706276,30292.03697609901,3032.571701526642,0.8902647495269775,0.0 -37600,2.5495126,1.2353542,,,,,,,,,,,,,, -37700,2.3986404,1.2235553,,,,,,,,,,,,,, -37800,3.3608906,1.2056875,,,,,,,,,,,,,, -37900,4.6522994,1.2000728,,,,,,,,,,,,,, -38000,3.2955859,1.2276305,,,,,,,,,,,,,, -38100,2.5960388,1.237362,,,,,,,,,,,,,, -38200,2.8081372,1.1984496,,,,,,,,,,,,,, -38300,3.4841285,1.1803056,,,,,,,,,,,,,, -38400,2.7146611,1.1704156,,,,,,,,,,,,,, -38500,2.7992668,1.2015241,,,,,,,,,,,,,, -38600,3.4049902,1.1631591,,,,,,,,,,,,,, -38700,4.057486,1.2057619,,,,,,,,,,,,,, -38800,3.2311301,1.2125257,,,,,,,,,,,,,, -38900,3.4572747,1.1765985,,,,,,,,,,,,,, -39000,2.2695222,1.2080569,,,,,,,,,,,,,, -39100,2.6646442,1.2123581,,,,,,,,,,,,,, -39200,3.3311954,1.1939744,,,,,,,,,,,,,, -39300,3.9530354,1.1640297,,,,,,,,,,,,,, -39355,,,0.17283808,0.0592490126924457,0.43858835,0.1268814505150757,5348.0,0.24206062,0.0764527857331464,2472.0,31732.59752678871,34901.9730963707,31732.59752678871,3166.721561670304,0.934283971786499,0.0 -39400,2.6224523,1.1351554,,,,,,,,,,,,,, -39500,2.4668956,1.1679951,,,,,,,,,,,,,, -39600,3.7529602,1.1390699,,,,,,,,,,,,,, -39700,3.5930142,1.2360677,,,,,,,,,,,,,, -39800,4.2798944,1.1189488,,,,,,,,,,,,,, -39900,3.0717618,1.2080511,,,,,,,,,,,,,, -40000,3.5242016,1.1174809,,,,,,,,,,,,,, -40100,3.0848982,1.2527531,,,,,,,,,,,,,, -40200,3.4757934,1.1631043,,,,,,,,,,,,,, -40300,2.9949715,1.1642249,,,,,,,,,,,,,, -40400,3.7154207,1.1526265,,,,,,,,,,,,,, -40500,3.850483,1.1618334,,,,,,,,,,,,,, -40600,2.658219,1.190973,,,,,,,,,,,,,, -40700,3.0669246,1.193522,,,,,,,,,,,,,, -40800,2.7795458,1.2013419,,,,,,,,,,,,,, -40900,2.8743107,1.1678529,,,,,,,,,,,,,, -41000,3.1572475,1.1198863,,,,,,,,,,,,,, -41100,3.399042,1.1792227,,,,,,,,,,,,,, -41130,,,0.21339306,0.0732522252295742,0.42771238,0.1242940034949844,5348.0,0.23490724,0.0752340909552535,2472.0,33172.82326436043,36475.123975515366,33172.82326436043,3299.519961833954,0.9837052822113036,0.0 -41200,4.2078457,1.1647875,,,,,,,,,,,,,, -41300,3.533618,1.1737647,,,,,,,,,,,,,, -41400,2.615364,1.1532601,,,,,,,,,,,,,, -41500,3.833295,1.1488545,,,,,,,,,,,,,, -41600,3.2325113,1.1817255,,,,,,,,,,,,,, -41700,3.3316574,1.0909971,,,,,,,,,,,,,, -41800,2.7918305,1.1144574,,,,,,,,,,,,,, -41900,3.6548233,1.0987983,,,,,,,,,,,,,, -42000,4.7361274,1.1217773,,,,,,,,,,,,,, -42100,3.8864532,1.1434722,,,,,,,,,,,,,, -42200,3.5143626,1.1594801,,,,,,,,,,,,,, -42300,2.9315536,1.1421198,,,,,,,,,,,,,, -42400,6.3136353,1.1386547,,,,,,,,,,,,,, -42500,3.554959,1.153375,,,,,,,,,,,,,, -42600,4.9727445,1.1811934,,,,,,,,,,,,,, -42700,3.552581,1.1638094,,,,,,,,,,,,,, -42800,3.3583765,1.1375611,,,,,,,,,,,,,, -42900,3.8763099,1.0890701,,,,,,,,,,,,,, -42909,,,0.21680024,0.0741213631846872,0.41661587,0.1206059260260482,5348.0,0.22871256,0.0724514045457315,2472.0,34613.33181476593,38046.176903009415,34613.33181476593,3429.93771147728,1.0329890251159668,0.0 -43000,3.2518203,1.1786318,,,,,,,,,,,,,, -43100,2.8756394,1.1552969,,,,,,,,,,,,,, -43200,2.9625707,1.2062106,,,,,,,,,,,,,, -43300,4.5981617,1.1825027,,,,,,,,,,,,,, -43400,3.6315134,1.1321472,,,,,,,,,,,,,, -43500,3.4948146,1.1517106,,,,,,,,,,,,,, -43600,2.8118794,1.1678245,,,,,,,,,,,,,, -43700,3.2155752,1.1235584,,,,,,,,,,,,,, -43800,4.007563,1.1310529,,,,,,,,,,,,,, -43900,3.1311688,1.1766287,,,,,,,,,,,,,, -44000,3.0851064,1.1809822,,,,,,,,,,,,,, -44100,2.7378352,1.103108,,,,,,,,,,,,,, -44200,3.3594143,1.1596265,,,,,,,,,,,,,, -44300,2.6221216,1.0733519,,,,,,,,,,,,,, -44400,2.5625837,1.0842593,,,,,,,,,,,,,, -44500,3.1030884,1.1704593,,,,,,,,,,,,,, -44600,2.250696,1.122053,,,,,,,,,,,,,, -44700,2.9337924,1.0729277,,,,,,,,,,,,,, -44705,,,0.24842209,0.0861474925563465,0.412366,0.1190032536180812,5348.0,0.22644746,0.0710499055511547,2472.0,36053.66913104057,39617.75190496445,36053.66913104057,3561.055627822876,1.075854778289795,0.0 -44800,4.9544706,1.1209018,,,,,,,,,,,,,, -44900,3.4767592,1.11882,,,,,,,,,,,,,, -45000,3.1932414,1.1396117,,,,,,,,,,,,,, -45100,3.268748,1.1008582,,,,,,,,,,,,,, -45200,3.104939,1.0877923,,,,,,,,,,,,,, -45300,4.2024217,1.1480421,,,,,,,,,,,,,, -45400,3.348944,1.15785,,,,,,,,,,,,,, -45500,3.6174874,1.1428581,,,,,,,,,,,,,, -45600,2.5062964,1.1461451,,,,,,,,,,,,,, -45700,2.756964,1.110352,,,,,,,,,,,,,, -45800,2.672716,1.1410744,,,,,,,,,,,,,, -45900,3.5554497,1.0502585,,,,,,,,,,,,,, -46000,2.313814,1.1266328,,,,,,,,,,,,,, -46100,2.5122836,1.1398755,,,,,,,,,,,,,, -46200,4.0477195,1.1602285,,,,,,,,,,,,,, -46300,4.182689,1.1114829,,,,,,,,,,,,,, -46400,2.5248425,1.1281593,,,,,,,,,,,,,, -46467,,,0.22267985,0.0739774741093559,0.40985504,0.1182019174140977,5348.0,0.22343448,0.0708264781752076,2472.0,37494.07298588753,41188.91581845284,37494.07298588753,3691.6905381679535,1.123474359512329,0.0 -46500,2.7568252,1.1321638,,,,,,,,,,,,,, -46600,3.0671,1.1050062,,,,,,,,,,,,,, -46700,3.7894804,1.152169,,,,,,,,,,,,,, -46800,2.6383076,1.107735,,,,,,,,,,,,,, -46900,3.0353386,1.1571056,,,,,,,,,,,,,, -47000,3.4941552,1.1670604,,,,,,,,,,,,,, -47100,4.9120154,1.088128,,,,,,,,,,,,,, -47200,3.5644855,1.1217562,,,,,,,,,,,,,, -47300,2.622633,1.1786923,,,,,,,,,,,,,, -47400,4.8500333,1.1613333,,,,,,,,,,,,,, -47500,5.0061474,1.0754945,,,,,,,,,,,,,, -47600,3.1186795,1.0744575,,,,,,,,,,,,,, -47700,4.49002,1.1571865,,,,,,,,,,,,,, -47800,3.6005113,1.0394541,,,,,,,,,,,,,, -47900,4.38788,1.1347,,,,,,,,,,,,,, -48000,,,0.2067998,0.0725563114471179,0.41027457,0.1179605510875966,5348.0,0.22379321,0.0706842971177868,2472.0,38733.70230102539,42569.4317715168,38733.70230102539,3832.4622428417206,1.1707794666290283,0.0 -48000,,,,,,,,,,,38733.70230102539,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 2645b46e3..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -190.299305677414,0.0,15.445929288864136,1,0,15.445929288864136,30.536596,2472,3.3615054942822904,205.7453055381775,31.72197,3.266288921032312,30.351294,5348,3.0462071695453625 -306.5587999820709,0.0313451290130615,1455.9250264167786,1783,0,1455.9250264167786,6.6782165,2472,0.8730932504620884,1762.5931413173676,6.9688888,0.9123050937668022,6.7451887,5348,0.8757156511580756 -431.56023931503296,0.0696589946746826,2895.872971773148,3592,0,2895.872971773148,2.390297,2472,0.5203623585806268,3327.6603693962097,3.1315818,0.6487877941934376,2.9454577,5348,0.6002201260897689 -562.577470779419,0.1072361469268798,4336.168129205704,5384,0,4336.168129205704,0.58996063,2472,0.1858103304694006,4899.088443279266,0.7923331,0.2481584742728821,0.9136013,5348,0.2602025546211997 -694.9200584888458,0.1430966854095459,5776.282276391983,7156,0,5776.282276391983,0.4931348,2472,0.161842666504174,6471.658866643906,0.67502964,0.2146252587991718,0.7972467,5348,0.2310261930737518 -827.719379901886,0.1855082511901855,7216.382225990295,8929,0,7216.382225990295,0.4402412,2472,0.142790404809782,8044.677983999252,0.61008316,0.1959925826722983,0.7310794,5348,0.2113596647904457 -963.2478840351104,0.227180004119873,8656.938082933426,10735,0,8656.938082933426,0.4191527,2472,0.1371844088314748,9620.884849786758,0.5634109,0.1819830318343697,0.69694114,5348,0.20091333017948 -1094.5947580337524,0.2684915065765381,10097.11484861374,12523,0,10097.11484861374,0.39195168,2472,0.127962951678752,11192.52817583084,0.5032294,0.1660165444885332,0.6597988,5348,0.1920407040173011 -1225.9796528816223,0.3098704814910888,11537.74647140503,14309,0,11537.74647140503,0.3724739,2472,0.1187008713667662,12764.665168046951,0.4797483,0.1595602665327917,0.62959296,5348,0.1824922521409193 -1357.3227698802948,0.3514468669891357,12977.85252571106,16084,0,12977.85252571106,0.35472116,2472,0.1176446692259257,14336.233743190764,0.43343863,0.1482390751238673,0.6055314,5348,0.1772208115701362 -1490.1781721115112,0.4107773303985595,14417.783019304276,17879,0,14417.783019304276,0.3461188,2472,0.1114496374383035,15909.162315368652,0.4497294,0.1467116837568393,0.5981819,5348,0.1723548664278749 -1622.651347875595,0.4618921279907226,15858.06008553505,19647,0,15858.06008553505,0.32342452,2472,0.1053561635488392,17482.04532623291,0.41132402,0.1372341852820822,0.5722585,5348,0.1672475549591125 -1755.9761242866516,0.511904239654541,17298.850484371185,21421,0,17298.850484371185,0.31684622,2472,0.1029593971523165,19056.292241811752,0.4032407,0.1366182216663161,0.5550742,5348,0.1624878110005116 -1891.118372440338,0.5660281181335449,18738.948315143585,23194,0,18738.948315143585,0.30597436,2472,0.0991001970223224,20631.66925215721,0.3486651,0.1175532726871823,0.53862715,5348,0.1576122112051903 -2022.37153339386,0.6189842224121094,20179.532305002213,24964,0,20179.532305002213,0.29627094,2472,0.0959519021794324,22203.64442896843,0.3170169,0.1061324361523168,0.5284512,5348,0.1534027824710119 -2154.782774209976,0.691807746887207,21619.72815155983,26733,0,21619.72815155983,0.28524765,2472,0.0948550768793289,23776.40490913391,0.35497886,0.1223003714282488,0.5095953,5348,0.1485754559409907 -2288.435446739197,0.7465569972991943,23059.842032194138,28465,0,23059.842032194138,0.2816703,2472,0.090041232506652,25350.305683374405,0.32599312,0.1092377701934016,0.5059581,5348,0.1474362068799057 -2415.957443475724,0.8124251365661621,24499.839766025543,30341,0,24499.839766025543,0.27179763,2472,0.0894318851177056,26917.968510866165,0.3049627,0.1020345102376034,0.48760784,5348,0.1428599013294457 -2540.700318098068,0.8646657466888428,25940.395836114883,32253,0,25940.395836114883,0.26502502,2472,0.0862023439562894,28483.39506316185,0.28022093,0.0971240892029145,0.4823358,5348,0.1403786554930148 -2665.4868993759155,0.9170644283294678,27380.309321403503,34162,0,27380.309321403503,0.2560349,2472,0.0843539902098186,30048.22342205048,0.26070222,0.0920251056153354,0.46313033,5348,0.133803836759126 -2793.465112447738,0.9679844379425048,28820.84319806099,36072,0,28820.84319806099,0.25247654,2472,0.082627505941137,31616.8638048172,0.24969074,0.0875957002977648,0.45965692,5348,0.1332728308408237 -2940.232330560684,1.021756649017334,30261.34621310234,37985,0,30261.34621310234,0.24373694,2472,0.0783620742185119,33204.265234947205,0.1514019,0.0536420251657881,0.44699508,5348,0.1286868706373036 -3069.708080053329,1.0760719776153564,31701.76964020729,39897,0,31701.76964020729,0.24011387,2472,0.0778542847277232,34774.295803546906,0.14774501,0.0514875865615892,0.44163772,5348,0.1280593181884009 -3199.9463255405426,1.1304571628570557,33142.368851184845,41796,0,33142.368851184845,0.23733485,2472,0.0761481120386732,36345.265635252,0.15505517,0.0539654853458072,0.4351663,5348,0.1263504445967734 -3327.5098538398743,1.182192325592041,34582.90179729462,43697,0,34582.90179729462,0.23436242,2472,0.0753153372737797,37913.49289941788,0.1378165,0.0491589448201018,0.43216133,5348,0.1245643337806655 -3455.9925067424774,1.236759901046753,36023.13359427452,45596,0,36023.13359427452,0.2329999,2472,0.0748481709422541,39482.34113240242,0.14721335,0.0522255572556525,0.43008903,5348,0.1244581325970051 -3584.948172807693,1.291504144668579,37463.108900785446,47488,0,37463.108900785446,0.23242217,2472,0.0747059898848333,41051.40569543839,0.13965625,0.050840394499236,0.42927903,5348,0.124207111617444 -3714.087497472763,1.3495750427246094,37846.51352763176,48000,0,37846.51352763176,0.23241976,2472,0.07482785936262264,41564.03453874588,0.16727167,0.05534273328182677,0.4293102,5348,0.12426503953580428 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/measurements.csv deleted file mode 100644 index 7c969378b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.7948,32.23123,,,,,,,,,,,,,, -1,,,31.72197,3.266288921032312,30.351294,3.0462071695453625,5348.0,30.536596,3.3615054942822904,2472.0,15.445929288864136,205.7453055381775,15.445929288864136,190.299305677414,0.0,0.0 -100,3.094108,7.534158,,,,,,,,,,,,,, -200,0.8875849,5.965072,,,,,,,,,,,,,, -300,0.7276109,5.8689446,,,,,,,,,,,,,, -400,1.1698911,5.8296294,,,,,,,,,,,,,, -500,1.087934,5.785864,,,,,,,,,,,,,, -600,0.5251651,5.642248,,,,,,,,,,,,,, -700,0.91530687,5.482492,,,,,,,,,,,,,, -800,1.7861487,5.305217,,,,,,,,,,,,,, -900,1.4999477,4.809888,,,,,,,,,,,,,, -1000,1.2624145,4.283008,,,,,,,,,,,,,, -1100,1.6110429,3.9371495,,,,,,,,,,,,,, -1200,1.8768065,3.7358134,,,,,,,,,,,,,, -1300,2.4276826,3.4767919,,,,,,,,,,,,,, -1400,2.8328,3.3205106,,,,,,,,,,,,,, -1500,2.4818478,3.2398376,,,,,,,,,,,,,, -1600,2.1301289,3.0221856,,,,,,,,,,,,,, -1700,2.7751524,2.9446802,,,,,,,,,,,,,, -1783,,,6.9688888,0.9123050937668022,6.7451887,0.8757156511580756,5348.0,6.6782165,0.8730932504620884,2472.0,1455.9250264167786,1762.5931413173676,1455.9250264167786,306.5587999820709,0.0313451290130615,0.0 -1800,1.9640167,2.9088044,,,,,,,,,,,,,, -1900,3.3236911,2.8182273,,,,,,,,,,,,,, -2000,4.0811334,2.7351737,,,,,,,,,,,,,, -2100,3.0757232,2.6139798,,,,,,,,,,,,,, -2200,2.6084976,2.6042144,,,,,,,,,,,,,, -2300,2.7226198,2.5262766,,,,,,,,,,,,,, -2400,2.523121,2.4795425,,,,,,,,,,,,,, -2500,2.0193923,2.442922,,,,,,,,,,,,,, -2600,3.3312519,2.3599079,,,,,,,,,,,,,, -2700,2.3408315,2.3096023,,,,,,,,,,,,,, -2800,4.1588655,2.3625648,,,,,,,,,,,,,, -2900,3.4948955,2.2148476,,,,,,,,,,,,,, -3000,4.041931,2.2058146,,,,,,,,,,,,,, -3100,4.549141,2.1939702,,,,,,,,,,,,,, -3200,3.059498,2.194392,,,,,,,,,,,,,, -3300,2.2930894,2.197947,,,,,,,,,,,,,, -3400,2.1637414,2.1226838,,,,,,,,,,,,,, -3500,2.47822,2.0961096,,,,,,,,,,,,,, -3592,,,3.1315818,0.6487877941934376,2.9454577,0.6002201260897689,5348.0,2.390297,0.5203623585806268,2472.0,2895.872971773148,3327.6603693962097,2895.872971773148,431.56023931503296,0.0696589946746826,0.0 -3600,2.5156596,2.0526369,,,,,,,,,,,,,, -3700,4.2850623,2.0005236,,,,,,,,,,,,,, -3800,3.5182304,2.083558,,,,,,,,,,,,,, -3900,3.1806288,2.1023896,,,,,,,,,,,,,, -4000,3.3839595,2.0350764,,,,,,,,,,,,,, -4100,4.509757,2.0145772,,,,,,,,,,,,,, -4200,2.6773005,1.9833759,,,,,,,,,,,,,, -4300,2.810843,1.9743544,,,,,,,,,,,,,, -4400,2.920932,1.9625676,,,,,,,,,,,,,, -4500,4.042014,1.9159226,,,,,,,,,,,,,, -4600,1.9292715,1.8612777,,,,,,,,,,,,,, -4700,2.047019,1.9038581,,,,,,,,,,,,,, -4800,2.8077857,1.9218154,,,,,,,,,,,,,, -4900,2.080131,1.929638,,,,,,,,,,,,,, -5000,3.9571733,1.8727342,,,,,,,,,,,,,, -5100,1.7771939,1.870231,,,,,,,,,,,,,, -5200,2.8175359,1.8144945,,,,,,,,,,,,,, -5300,2.519483,1.836284,,,,,,,,,,,,,, -5384,,,0.7923331,0.2481584742728821,0.9136013,0.2602025546211997,5348.0,0.58996063,0.1858103304694006,2472.0,4336.168129205704,4899.088443279266,4336.168129205704,562.577470779419,0.1072361469268798,0.0 -5400,6.8496194,1.9141239,,,,,,,,,,,,,, -5500,2.134888,1.7420105,,,,,,,,,,,,,, -5600,3.7274694,1.799243,,,,,,,,,,,,,, -5700,3.0322618,1.8908416,,,,,,,,,,,,,, -5800,4.188907,1.7139534,,,,,,,,,,,,,, -5900,2.3231251,1.8050224,,,,,,,,,,,,,, -6000,3.9666228,1.8071111,,,,,,,,,,,,,, -6100,2.251052,1.7863276,,,,,,,,,,,,,, -6200,2.5580287,1.8089651,,,,,,,,,,,,,, -6300,3.1909404,1.7702229,,,,,,,,,,,,,, -6400,2.240934,1.7389274,,,,,,,,,,,,,, -6500,4.9908967,1.7581412,,,,,,,,,,,,,, -6600,1.7199724,1.6739553,,,,,,,,,,,,,, -6700,3.614588,1.7071705,,,,,,,,,,,,,, -6800,2.7228553,1.7193922,,,,,,,,,,,,,, -6900,2.530565,1.7207,,,,,,,,,,,,,, -7000,3.5548491,1.7127199,,,,,,,,,,,,,, -7100,2.0237079,1.7183965,,,,,,,,,,,,,, -7156,,,0.67502964,0.2146252587991718,0.7972467,0.2310261930737518,5348.0,0.4931348,0.161842666504174,2472.0,5776.282276391983,6471.658866643906,5776.282276391983,694.9200584888458,0.1430966854095459,0.0 -7200,4.4570236,1.7010245,,,,,,,,,,,,,, -7300,2.2042677,1.7180426,,,,,,,,,,,,,, -7400,2.420933,1.7306985,,,,,,,,,,,,,, -7500,2.8495665,1.7717787,,,,,,,,,,,,,, -7600,2.96613,1.7217821,,,,,,,,,,,,,, -7700,3.5412085,1.6739285,,,,,,,,,,,,,, -7800,2.2985451,1.6597701,,,,,,,,,,,,,, -7900,2.9293678,1.6939325,,,,,,,,,,,,,, -8000,4.128589,1.6225338,,,,,,,,,,,,,, -8100,3.8216386,1.703807,,,,,,,,,,,,,, -8200,1.6689907,1.6404437,,,,,,,,,,,,,, -8300,2.7162616,1.6685655,,,,,,,,,,,,,, -8400,2.8244374,1.6901982,,,,,,,,,,,,,, -8500,1.9069055,1.5713149,,,,,,,,,,,,,, -8600,5.0164886,1.6410209,,,,,,,,,,,,,, -8700,2.4939816,1.6445895,,,,,,,,,,,,,, -8800,4.0642614,1.6015302,,,,,,,,,,,,,, -8900,2.2877898,1.6812986,,,,,,,,,,,,,, -8929,,,0.61008316,0.1959925826722983,0.7310794,0.2113596647904457,5348.0,0.4402412,0.142790404809782,2472.0,7216.382225990295,8044.677983999252,7216.382225990295,827.719379901886,0.1855082511901855,0.0 -9000,3.0035422,1.7137529,,,,,,,,,,,,,, -9100,3.4685614,1.6266514,,,,,,,,,,,,,, -9200,3.3466043,1.6332723,,,,,,,,,,,,,, -9300,3.828295,1.5846215,,,,,,,,,,,,,, -9400,2.621049,1.6049311,,,,,,,,,,,,,, -9500,2.4296622,1.5660101,,,,,,,,,,,,,, -9600,3.850722,1.6602769,,,,,,,,,,,,,, -9700,2.8061497,1.5415199,,,,,,,,,,,,,, -9800,2.0151396,1.5979997,,,,,,,,,,,,,, -9900,4.77237,1.6088592,,,,,,,,,,,,,, -10000,2.5447605,1.6191947,,,,,,,,,,,,,, -10100,2.2812796,1.6631527,,,,,,,,,,,,,, -10200,1.5946606,1.5829699,,,,,,,,,,,,,, -10300,3.0983026,1.55948,,,,,,,,,,,,,, -10400,4.4030633,1.5527573,,,,,,,,,,,,,, -10500,1.4263242,1.5753597,,,,,,,,,,,,,, -10600,4.0028434,1.5428923,,,,,,,,,,,,,, -10700,2.0603855,1.518202,,,,,,,,,,,,,, -10735,,,0.5634109,0.1819830318343697,0.69694114,0.20091333017948,5348.0,0.4191527,0.1371844088314748,2472.0,8656.938082933426,9620.884849786758,8656.938082933426,963.2478840351104,0.227180004119873,0.0 -10800,2.8957021,1.6122191,,,,,,,,,,,,,, -10900,3.0642884,1.5369767,,,,,,,,,,,,,, -11000,6.486134,1.5811747,,,,,,,,,,,,,, -11100,2.180897,1.5434302,,,,,,,,,,,,,, -11200,2.2369084,1.5436938,,,,,,,,,,,,,, -11300,1.9722533,1.5776268,,,,,,,,,,,,,, -11400,2.1178935,1.575099,,,,,,,,,,,,,, -11500,2.8358438,1.5481913,,,,,,,,,,,,,, -11600,2.3311594,1.5469606,,,,,,,,,,,,,, -11700,4.728336,1.5096369,,,,,,,,,,,,,, -11800,2.084702,1.5258464,,,,,,,,,,,,,, -11900,2.9035232,1.5625222,,,,,,,,,,,,,, -12000,2.7021015,1.5786772,,,,,,,,,,,,,, -12100,3.408351,1.5326707,,,,,,,,,,,,,, -12200,2.4136853,1.5937794,,,,,,,,,,,,,, -12300,2.4846053,1.5026741,,,,,,,,,,,,,, -12400,2.3431184,1.5644959,,,,,,,,,,,,,, -12500,2.7633858,1.5584388,,,,,,,,,,,,,, -12523,,,0.5032294,0.1660165444885332,0.6597988,0.1920407040173011,5348.0,0.39195168,0.127962951678752,2472.0,10097.11484861374,11192.52817583084,10097.11484861374,1094.5947580337524,0.2684915065765381,0.0 -12600,2.7683392,1.5359389,,,,,,,,,,,,,, -12700,2.7371917,1.544634,,,,,,,,,,,,,, -12800,3.0869153,1.6269956,,,,,,,,,,,,,, -12900,2.100785,1.5901109,,,,,,,,,,,,,, -13000,2.1584027,1.5754675,,,,,,,,,,,,,, -13100,3.878624,1.5366684,,,,,,,,,,,,,, -13200,2.4923902,1.5143281,,,,,,,,,,,,,, -13300,2.273994,1.4872372,,,,,,,,,,,,,, -13400,3.269054,1.5600462,,,,,,,,,,,,,, -13500,6.0924945,1.4923203,,,,,,,,,,,,,, -13600,2.3743722,1.582637,,,,,,,,,,,,,, -13700,3.3784413,1.5711122,,,,,,,,,,,,,, -13800,2.7942202,1.4968774,,,,,,,,,,,,,, -13900,3.074984,1.5682598,,,,,,,,,,,,,, -14000,2.1294975,1.4666071,,,,,,,,,,,,,, -14100,5.537916,1.5147684,,,,,,,,,,,,,, -14200,3.2134836,1.5654033,,,,,,,,,,,,,, -14300,4.925414,1.4999639,,,,,,,,,,,,,, -14309,,,0.4797483,0.1595602665327917,0.62959296,0.1824922521409193,5348.0,0.3724739,0.1187008713667662,2472.0,11537.74647140503,12764.665168046951,11537.74647140503,1225.9796528816223,0.3098704814910888,0.0 -14400,2.319979,1.5281596,,,,,,,,,,,,,, -14500,2.636431,1.4766971,,,,,,,,,,,,,, -14600,1.9019526,1.4979982,,,,,,,,,,,,,, -14700,3.157743,1.4753765,,,,,,,,,,,,,, -14800,3.1304898,1.4898944,,,,,,,,,,,,,, -14900,2.5251303,1.5221308,,,,,,,,,,,,,, -15000,1.9892101,1.492511,,,,,,,,,,,,,, -15100,1.8619648,1.4182949,,,,,,,,,,,,,, -15200,3.208455,1.5236365,,,,,,,,,,,,,, -15300,2.8813994,1.4795381,,,,,,,,,,,,,, -15400,2.0331743,1.4688327,,,,,,,,,,,,,, -15500,2.2148793,1.4931521,,,,,,,,,,,,,, -15600,2.3737857,1.4158486,,,,,,,,,,,,,, -15700,1.9936423,1.4965837,,,,,,,,,,,,,, -15800,2.5443,1.480237,,,,,,,,,,,,,, -15900,2.3143637,1.4436393,,,,,,,,,,,,,, -16000,2.8668222,1.4833491,,,,,,,,,,,,,, -16084,,,0.43343863,0.1482390751238673,0.6055314,0.1772208115701362,5348.0,0.35472116,0.1176446692259257,2472.0,12977.85252571106,14336.233743190764,12977.85252571106,1357.3227698802948,0.3514468669891357,0.0 -16100,2.681932,1.4493309,,,,,,,,,,,,,, -16200,4.9354334,1.4954331,,,,,,,,,,,,,, -16300,2.5510375,1.4740447,,,,,,,,,,,,,, -16400,2.4339323,1.5640758,,,,,,,,,,,,,, -16500,2.7447803,1.4739693,,,,,,,,,,,,,, -16600,2.0040448,1.4874144,,,,,,,,,,,,,, -16700,2.611728,1.4106349,,,,,,,,,,,,,, -16800,3.9861813,1.4582756,,,,,,,,,,,,,, -16900,2.6390023,1.4407055,,,,,,,,,,,,,, -17000,2.5927734,1.4118124,,,,,,,,,,,,,, -17100,3.2115982,1.4824581,,,,,,,,,,,,,, -17200,3.1841311,1.5121362,,,,,,,,,,,,,, -17300,2.2689407,1.4565475,,,,,,,,,,,,,, -17400,2.40304,1.4081753,,,,,,,,,,,,,, -17500,3.2758834,1.4659932,,,,,,,,,,,,,, -17600,3.2394505,1.3875092,,,,,,,,,,,,,, -17700,1.7531419,1.4126215,,,,,,,,,,,,,, -17800,2.6849465,1.513544,,,,,,,,,,,,,, -17879,,,0.4497294,0.1467116837568393,0.5981819,0.1723548664278749,5348.0,0.3461188,0.1114496374383035,2472.0,14417.783019304276,15909.162315368652,14417.783019304276,1490.1781721115112,0.4107773303985595,0.0 -17900,3.1015558,1.5238885,,,,,,,,,,,,,, -18000,2.31952,1.4378862,,,,,,,,,,,,,, -18100,2.4288416,1.4635085,,,,,,,,,,,,,, -18200,2.4149115,1.4284228,,,,,,,,,,,,,, -18300,2.333271,1.398473,,,,,,,,,,,,,, -18400,2.4506404,1.4009336,,,,,,,,,,,,,, -18500,2.2845404,1.4143107,,,,,,,,,,,,,, -18600,2.2640007,1.4216254,,,,,,,,,,,,,, -18700,2.2921178,1.5375069,,,,,,,,,,,,,, -18800,2.4023345,1.3810873,,,,,,,,,,,,,, -18900,4.877492,1.4897684,,,,,,,,,,,,,, -19000,3.6343455,1.4569749,,,,,,,,,,,,,, -19100,2.7193084,1.4146768,,,,,,,,,,,,,, -19200,5.353214,1.4429708,,,,,,,,,,,,,, -19300,2.0387142,1.3220448,,,,,,,,,,,,,, -19400,2.6400273,1.3698684,,,,,,,,,,,,,, -19500,3.974529,1.4028407,,,,,,,,,,,,,, -19600,2.9127655,1.4317635,,,,,,,,,,,,,, -19647,,,0.41132402,0.1372341852820822,0.5722585,0.1672475549591125,5348.0,0.32342452,0.1053561635488392,2472.0,15858.06008553505,17482.04532623291,15858.06008553505,1622.651347875595,0.4618921279907226,0.0 -19700,2.1972144,1.3597872,,,,,,,,,,,,,, -19800,2.5624332,1.4604032,,,,,,,,,,,,,, -19900,3.097714,1.4488156,,,,,,,,,,,,,, -20000,3.2928464,1.4037116,,,,,,,,,,,,,, -20100,2.8940368,1.4526719,,,,,,,,,,,,,, -20200,3.220375,1.4736717,,,,,,,,,,,,,, -20300,5.0008125,1.3786354,,,,,,,,,,,,,, -20400,3.1326876,1.46102,,,,,,,,,,,,,, -20500,1.8046045,1.3611897,,,,,,,,,,,,,, -20600,2.4664178,1.4256445,,,,,,,,,,,,,, -20700,2.1711905,1.3365489,,,,,,,,,,,,,, -20800,1.8817711,1.3838449,,,,,,,,,,,,,, -20900,2.1756396,1.4481809,,,,,,,,,,,,,, -21000,2.0672915,1.4118754,,,,,,,,,,,,,, -21100,2.486509,1.4175957,,,,,,,,,,,,,, -21200,2.9567132,1.4216084,,,,,,,,,,,,,, -21300,2.0310447,1.4454386,,,,,,,,,,,,,, -21400,2.4219508,1.3962312,,,,,,,,,,,,,, -21421,,,0.4032407,0.1366182216663161,0.5550742,0.1624878110005116,5348.0,0.31684622,0.1029593971523165,2472.0,17298.850484371185,19056.292241811752,17298.850484371185,1755.9761242866516,0.511904239654541,0.0 -21500,1.7451794,1.4223843,,,,,,,,,,,,,, -21600,2.190113,1.4552178,,,,,,,,,,,,,, -21700,2.0752618,1.3542958,,,,,,,,,,,,,, -21800,2.326953,1.340317,,,,,,,,,,,,,, -21900,2.7131658,1.3982178,,,,,,,,,,,,,, -22000,2.567473,1.4345963,,,,,,,,,,,,,, -22100,3.2862844,1.3710865,,,,,,,,,,,,,, -22200,2.3671074,1.38867,,,,,,,,,,,,,, -22300,1.7350562,1.3274674,,,,,,,,,,,,,, -22400,1.9726348,1.4110932,,,,,,,,,,,,,, -22500,7.0465217,1.3528198,,,,,,,,,,,,,, -22600,2.8404305,1.3288708,,,,,,,,,,,,,, -22700,6.704125,1.3680876,,,,,,,,,,,,,, -22800,2.1845872,1.3473065,,,,,,,,,,,,,, -22900,2.7171905,1.3839365,,,,,,,,,,,,,, -23000,2.2219687,1.3833662,,,,,,,,,,,,,, -23100,2.502116,1.4258003,,,,,,,,,,,,,, -23194,,,0.3486651,0.1175532726871823,0.53862715,0.1576122112051903,5348.0,0.30597436,0.0991001970223224,2472.0,18738.948315143585,20631.66925215721,18738.948315143585,1891.118372440338,0.5660281181335449,0.0 -23200,3.3923116,1.3334543,,,,,,,,,,,,,, -23300,4.213286,1.4055458,,,,,,,,,,,,,, -23400,2.443218,1.3888582,,,,,,,,,,,,,, -23500,2.5566905,1.4312019,,,,,,,,,,,,,, -23600,3.1399972,1.4054589,,,,,,,,,,,,,, -23700,2.1277618,1.4147022,,,,,,,,,,,,,, -23800,3.1281016,1.3523817,,,,,,,,,,,,,, -23900,2.4838567,1.4115131,,,,,,,,,,,,,, -24000,2.6163046,1.3567709,,,,,,,,,,,,,, -24100,2.0142927,1.3549479,,,,,,,,,,,,,, -24200,2.1232011,1.3661042,,,,,,,,,,,,,, -24300,1.822498,1.3929393,,,,,,,,,,,,,, -24400,2.7119856,1.3335861,,,,,,,,,,,,,, -24500,2.171651,1.3759682,,,,,,,,,,,,,, -24600,1.8968885,1.3342316,,,,,,,,,,,,,, -24700,2.132386,1.3607788,,,,,,,,,,,,,, -24800,3.659301,1.3485612,,,,,,,,,,,,,, -24900,3.045677,1.3889902,,,,,,,,,,,,,, -24964,,,0.3170169,0.1061324361523168,0.5284512,0.1534027824710119,5348.0,0.29627094,0.0959519021794324,2472.0,20179.532305002213,22203.64442896843,20179.532305002213,2022.37153339386,0.6189842224121094,0.0 -25000,4.3797755,1.3485112,,,,,,,,,,,,,, -25100,2.164858,1.2935385,,,,,,,,,,,,,, -25200,3.5858662,1.3399236,,,,,,,,,,,,,, -25300,2.9322853,1.3310885,,,,,,,,,,,,,, -25400,1.8239167,1.3727363,,,,,,,,,,,,,, -25500,2.990468,1.3766395,,,,,,,,,,,,,, -25600,3.3407667,1.3591748,,,,,,,,,,,,,, -25700,2.5222821,1.3720735,,,,,,,,,,,,,, -25800,3.37524,1.3545655,,,,,,,,,,,,,, -25900,1.8528929,1.3127955,,,,,,,,,,,,,, -26000,4.133004,1.3966737,,,,,,,,,,,,,, -26100,3.8191617,1.3708918,,,,,,,,,,,,,, -26200,2.5568585,1.3657038,,,,,,,,,,,,,, -26300,2.4751515,1.3486912,,,,,,,,,,,,,, -26400,2.3009558,1.3455428,,,,,,,,,,,,,, -26500,2.9948173,1.3584429,,,,,,,,,,,,,, -26600,2.1197953,1.3094095,,,,,,,,,,,,,, -26700,2.1816733,1.3699629,,,,,,,,,,,,,, -26733,,,0.35497886,0.1223003714282488,0.5095953,0.1485754559409907,5348.0,0.28524765,0.0948550768793289,2472.0,21619.72815155983,23776.40490913391,21619.72815155983,2154.782774209976,0.691807746887207,0.0 -26800,3.0423834,1.4056015,,,,,,,,,,,,,, -26900,1.9478266,1.3362317,,,,,,,,,,,,,, -27000,2.145172,1.3303798,,,,,,,,,,,,,, -27100,2.1001515,1.2980056,,,,,,,,,,,,,, -27200,2.1347585,1.3159372,,,,,,,,,,,,,, -27300,4.1461763,1.3041493,,,,,,,,,,,,,, -27400,2.8151438,1.3315914,,,,,,,,,,,,,, -27500,2.1557302,1.2888267,,,,,,,,,,,,,, -27600,1.7421017,1.3330762,,,,,,,,,,,,,, -27700,2.4704258,1.2773157,,,,,,,,,,,,,, -27800,2.0024421,1.3483394,,,,,,,,,,,,,, -27900,2.596135,1.3480852,,,,,,,,,,,,,, -28000,4.319267,1.2921677,,,,,,,,,,,,,, -28100,2.4250076,1.2453443,,,,,,,,,,,,,, -28200,2.5162342,1.269058,,,,,,,,,,,,,, -28300,2.6697671,1.3123692,,,,,,,,,,,,,, -28400,2.691198,1.3265517,,,,,,,,,,,,,, -28465,,,0.32599312,0.1092377701934016,0.5059581,0.1474362068799057,5348.0,0.2816703,0.090041232506652,2472.0,23059.842032194138,25350.305683374405,23059.842032194138,2288.435446739197,0.7465569972991943,0.0 -28500,2.4857655,1.3246077,,,,,,,,,,,,,, -28600,3.0518415,1.3121243,,,,,,,,,,,,,, -28700,2.122362,1.3337152,,,,,,,,,,,,,, -28800,2.214391,1.2518682,,,,,,,,,,,,,, -28900,2.2922854,1.3239144,,,,,,,,,,,,,, -29000,3.8428667,1.2662874,,,,,,,,,,,,,, -29100,2.743589,1.2872591,,,,,,,,,,,,,, -29200,2.5796862,1.2710234,,,,,,,,,,,,,, -29300,3.0382264,1.3291875,,,,,,,,,,,,,, -29400,2.2140887,1.3106173,,,,,,,,,,,,,, -29500,4.579713,1.2943146,,,,,,,,,,,,,, -29600,2.600625,1.2937762,,,,,,,,,,,,,, -29700,2.7356725,1.2889502,,,,,,,,,,,,,, -29800,2.2304025,1.2563348,,,,,,,,,,,,,, -29900,2.18268,1.3167062,,,,,,,,,,,,,, -30000,1.8947798,1.260984,,,,,,,,,,,,,, -30100,3.5302,1.2595786,,,,,,,,,,,,,, -30200,2.1858027,1.3191959,,,,,,,,,,,,,, -30300,3.9813876,1.26339,,,,,,,,,,,,,, -30341,,,0.3049627,0.1020345102376034,0.48760784,0.1428599013294457,5348.0,0.27179763,0.0894318851177056,2472.0,24499.839766025543,26917.968510866165,24499.839766025543,2415.957443475724,0.8124251365661621,0.0 -30400,2.4621668,1.2641171,,,,,,,,,,,,,, -30500,2.2624912,1.2469995,,,,,,,,,,,,,, -30600,3.1875734,1.2399462,,,,,,,,,,,,,, -30700,2.452183,1.2868102,,,,,,,,,,,,,, -30800,1.8071378,1.3041173,,,,,,,,,,,,,, -30900,2.837754,1.291308,,,,,,,,,,,,,, -31000,4.175926,1.2867788,,,,,,,,,,,,,, -31100,4.0390344,1.2643911,,,,,,,,,,,,,, -31200,2.2578926,1.2445258,,,,,,,,,,,,,, -31300,1.6547217,1.2928715,,,,,,,,,,,,,, -31400,2.8215296,1.2333393,,,,,,,,,,,,,, -31500,3.3805745,1.2667563,,,,,,,,,,,,,, -31600,1.917822,1.2899716,,,,,,,,,,,,,, -31700,3.5271864,1.3523084,,,,,,,,,,,,,, -31800,2.8880982,1.2401885,,,,,,,,,,,,,, -31900,2.7788994,1.2711705,,,,,,,,,,,,,, -32000,3.0855744,1.2969794,,,,,,,,,,,,,, -32100,2.4767938,1.2851512,,,,,,,,,,,,,, -32200,2.0811682,1.3423195,,,,,,,,,,,,,, -32253,,,0.28022093,0.0971240892029145,0.4823358,0.1403786554930148,5348.0,0.26502502,0.0862023439562894,2472.0,25940.395836114883,28483.39506316185,25940.395836114883,2540.700318098068,0.8646657466888428,0.0 -32300,2.1855083,1.2570115,,,,,,,,,,,,,, -32400,2.752584,1.2399474,,,,,,,,,,,,,, -32500,4.7268076,1.2167767,,,,,,,,,,,,,, -32600,1.8681791,1.287537,,,,,,,,,,,,,, -32700,3.060836,1.211827,,,,,,,,,,,,,, -32800,2.3190784,1.2627685,,,,,,,,,,,,,, -32900,3.010663,1.341927,,,,,,,,,,,,,, -33000,2.859234,1.1953689,,,,,,,,,,,,,, -33100,2.005719,1.2080836,,,,,,,,,,,,,, -33200,2.950728,1.252653,,,,,,,,,,,,,, -33300,3.0995898,1.2899998,,,,,,,,,,,,,, -33400,2.3016303,1.2364438,,,,,,,,,,,,,, -33500,2.2631721,1.2415683,,,,,,,,,,,,,, -33600,2.9664447,1.2563727,,,,,,,,,,,,,, -33700,2.2195575,1.2090056,,,,,,,,,,,,,, -33800,2.5277429,1.2340668,,,,,,,,,,,,,, -33900,1.5970272,1.1954855,,,,,,,,,,,,,, -34000,2.3635473,1.2317084,,,,,,,,,,,,,, -34100,2.364716,1.1903796,,,,,,,,,,,,,, -34162,,,0.26070222,0.0920251056153354,0.46313033,0.133803836759126,5348.0,0.2560349,0.0843539902098186,2472.0,27380.309321403503,30048.22342205048,27380.309321403503,2665.4868993759155,0.9170644283294678,0.0 -34200,3.1324608,1.230746,,,,,,,,,,,,,, -34300,2.4530058,1.2158246,,,,,,,,,,,,,, -34400,2.7861905,1.1900631,,,,,,,,,,,,,, -34500,5.53414,1.2294222,,,,,,,,,,,,,, -34600,1.9405783,1.2846961,,,,,,,,,,,,,, -34700,2.2454317,1.2579727,,,,,,,,,,,,,, -34800,3.0116894,1.2580131,,,,,,,,,,,,,, -34900,2.4270384,1.2148033,,,,,,,,,,,,,, -35000,2.5010357,1.2955847,,,,,,,,,,,,,, -35100,2.5326314,1.2065684,,,,,,,,,,,,,, -35200,1.7373511,1.2005482,,,,,,,,,,,,,, -35300,1.8762089,1.226789,,,,,,,,,,,,,, -35400,4.239975,1.2230247,,,,,,,,,,,,,, -35500,2.8556008,1.2169137,,,,,,,,,,,,,, -35600,2.3551166,1.2590688,,,,,,,,,,,,,, -35700,2.1675847,1.1682719,,,,,,,,,,,,,, -35800,2.2788382,1.1829593,,,,,,,,,,,,,, -35900,2.6433635,1.1465567,,,,,,,,,,,,,, -36000,2.7899086,1.1729417,,,,,,,,,,,,,, -36072,,,0.24969074,0.0875957002977648,0.45965692,0.1332728308408237,5348.0,0.25247654,0.082627505941137,2472.0,28820.84319806099,31616.8638048172,28820.84319806099,2793.465112447738,0.9679844379425048,0.0 -36100,2.2243378,1.2278489,,,,,,,,,,,,,, -36200,2.2981105,1.2108368,,,,,,,,,,,,,, -36300,2.3975747,1.232969,,,,,,,,,,,,,, -36400,3.2830946,1.1570195,,,,,,,,,,,,,, -36500,2.8448887,1.19056,,,,,,,,,,,,,, -36600,1.9408919,1.2316155,,,,,,,,,,,,,, -36700,4.145617,1.2311574,,,,,,,,,,,,,, -36800,3.2910306,1.1292058,,,,,,,,,,,,,, -36900,2.2939456,1.1980848,,,,,,,,,,,,,, -37000,2.0541856,1.1882515,,,,,,,,,,,,,, -37100,2.3399258,1.1891142,,,,,,,,,,,,,, -37200,2.6322234,1.237676,,,,,,,,,,,,,, -37300,1.6614785,1.1954476,,,,,,,,,,,,,, -37400,3.544562,1.2525225,,,,,,,,,,,,,, -37500,3.6612093,1.1885558,,,,,,,,,,,,,, -37600,2.5233657,1.2306968,,,,,,,,,,,,,, -37700,5.2772884,1.2082931,,,,,,,,,,,,,, -37800,3.0220253,1.1957766,,,,,,,,,,,,,, -37900,2.6622639,1.213904,,,,,,,,,,,,,, -37985,,,0.1514019,0.0536420251657881,0.44699508,0.1286868706373036,5348.0,0.24373694,0.0783620742185119,2472.0,30261.34621310234,33204.265234947205,30261.34621310234,2940.232330560684,1.021756649017334,0.0 -38000,2.305059,1.1180372,,,,,,,,,,,,,, -38100,2.322555,1.225274,,,,,,,,,,,,,, -38200,5.7397494,1.1631078,,,,,,,,,,,,,, -38300,2.8565829,1.1999035,,,,,,,,,,,,,, -38400,2.0569837,1.179944,,,,,,,,,,,,,, -38500,3.429883,1.1849383,,,,,,,,,,,,,, -38600,1.5564889,1.145085,,,,,,,,,,,,,, -38700,2.5647633,1.1940603,,,,,,,,,,,,,, -38800,3.0948877,1.1762935,,,,,,,,,,,,,, -38900,1.6955366,1.1952206,,,,,,,,,,,,,, -39000,3.3332767,1.1756897,,,,,,,,,,,,,, -39100,3.3658607,1.1346672,,,,,,,,,,,,,, -39200,4.5274153,1.1869811,,,,,,,,,,,,,, -39300,3.028976,1.19934,,,,,,,,,,,,,, -39400,2.168029,1.1539851,,,,,,,,,,,,,, -39500,2.2040913,1.1946514,,,,,,,,,,,,,, -39600,2.425206,1.1734172,,,,,,,,,,,,,, -39700,2.315956,1.2053115,,,,,,,,,,,,,, -39800,2.4744556,1.1784784,,,,,,,,,,,,,, -39897,,,0.14774501,0.0514875865615892,0.44163772,0.1280593181884009,5348.0,0.24011387,0.0778542847277232,2472.0,31701.76964020729,34774.295803546906,31701.76964020729,3069.708080053329,1.0760719776153564,0.0 -39900,6.514329,1.2260424,,,,,,,,,,,,,, -40000,3.159142,1.1305802,,,,,,,,,,,,,, -40100,3.1382802,1.1633517,,,,,,,,,,,,,, -40200,2.5893404,1.1494223,,,,,,,,,,,,,, -40300,4.0491924,1.1793354,,,,,,,,,,,,,, -40400,2.734538,1.149957,,,,,,,,,,,,,, -40500,1.8427908,1.1846029,,,,,,,,,,,,,, -40600,2.6749606,1.1553795,,,,,,,,,,,,,, -40700,1.9128447,1.1836568,,,,,,,,,,,,,, -40800,3.1919405,1.1980941,,,,,,,,,,,,,, -40900,3.3144605,1.1516428,,,,,,,,,,,,,, -41000,3.259317,1.1722003,,,,,,,,,,,,,, -41100,2.9823306,1.1738503,,,,,,,,,,,,,, -41200,2.8078768,1.1263305,,,,,,,,,,,,,, -41300,3.1096947,1.1804191,,,,,,,,,,,,,, -41400,2.1659846,1.1427525,,,,,,,,,,,,,, -41500,2.7069883,1.1139624,,,,,,,,,,,,,, -41600,2.1555574,1.1348442,,,,,,,,,,,,,, -41700,3.8126729,1.1261001,,,,,,,,,,,,,, -41796,,,0.15505517,0.0539654853458072,0.4351663,0.1263504445967734,5348.0,0.23733485,0.0761481120386732,2472.0,33142.368851184845,36345.265635252,33142.368851184845,3199.9463255405426,1.1304571628570557,0.0 -41800,4.00994,1.1592329,,,,,,,,,,,,,, -41900,3.3434875,1.153045,,,,,,,,,,,,,, -42000,1.9433929,1.1373914,,,,,,,,,,,,,, -42100,2.737845,1.2067105,,,,,,,,,,,,,, -42200,3.184421,1.1597614,,,,,,,,,,,,,, -42300,2.8857632,1.209139,,,,,,,,,,,,,, -42400,2.7453628,1.1354359,,,,,,,,,,,,,, -42500,3.6493778,1.1555885,,,,,,,,,,,,,, -42600,2.1055372,1.1206294,,,,,,,,,,,,,, -42700,3.0993233,1.1474547,,,,,,,,,,,,,, -42800,2.865939,1.0969533,,,,,,,,,,,,,, -42900,3.303873,1.1272161,,,,,,,,,,,,,, -43000,3.9700277,1.1712452,,,,,,,,,,,,,, -43100,2.2932289,1.1467637,,,,,,,,,,,,,, -43200,3.9312625,1.1537749,,,,,,,,,,,,,, -43300,2.6131718,1.2003598,,,,,,,,,,,,,, -43400,2.4896088,1.1353008,,,,,,,,,,,,,, -43500,5.1919966,1.1502546,,,,,,,,,,,,,, -43600,4.2495604,1.181889,,,,,,,,,,,,,, -43697,,,0.1378165,0.0491589448201018,0.43216133,0.1245643337806655,5348.0,0.23436242,0.0753153372737797,2472.0,34582.90179729462,37913.49289941788,34582.90179729462,3327.5098538398743,1.182192325592041,0.0 -43700,3.624972,1.1555353,,,,,,,,,,,,,, -43800,3.2017355,1.1479343,,,,,,,,,,,,,, -43900,4.2707486,1.2532889,,,,,,,,,,,,,, -44000,2.155742,1.2474655,,,,,,,,,,,,,, -44100,3.6062458,1.1773665,,,,,,,,,,,,,, -44200,2.7712944,1.1891121,,,,,,,,,,,,,, -44300,6.669731,1.1147913,,,,,,,,,,,,,, -44400,2.4659083,1.1176844,,,,,,,,,,,,,, -44500,1.6624515,1.1518525,,,,,,,,,,,,,, -44600,2.6915038,1.1723895,,,,,,,,,,,,,, -44700,2.645889,1.1835203,,,,,,,,,,,,,, -44800,9.537493,1.1536655,,,,,,,,,,,,,, -44900,2.6133194,1.1927122,,,,,,,,,,,,,, -45000,4.7228518,1.14775,,,,,,,,,,,,,, -45100,4.5550375,1.1447437,,,,,,,,,,,,,, -45200,2.2290788,1.130885,,,,,,,,,,,,,, -45300,8.312192,1.1241076,,,,,,,,,,,,,, -45400,2.2727294,1.1752323,,,,,,,,,,,,,, -45500,4.4907007,1.1555967,,,,,,,,,,,,,, -45596,,,0.14721335,0.0522255572556525,0.43008903,0.1244581325970051,5348.0,0.2329999,0.0748481709422541,2472.0,36023.13359427452,39482.34113240242,36023.13359427452,3455.9925067424774,1.236759901046753,0.0 -45600,3.325439,1.1296597,,,,,,,,,,,,,, -45700,3.9282525,1.1422002,,,,,,,,,,,,,, -45800,3.4797401,1.142011,,,,,,,,,,,,,, -45900,2.0033784,1.1412158,,,,,,,,,,,,,, -46000,2.3055212,1.126383,,,,,,,,,,,,,, -46100,2.7730289,1.1370052,,,,,,,,,,,,,, -46200,2.3370056,1.1715324,,,,,,,,,,,,,, -46300,2.0828133,1.166665,,,,,,,,,,,,,, -46400,2.456373,1.1307162,,,,,,,,,,,,,, -46500,4.046015,1.1015985,,,,,,,,,,,,,, -46600,3.4159393,1.1552575,,,,,,,,,,,,,, -46700,2.6121457,1.1507845,,,,,,,,,,,,,, -46800,2.386882,1.1726424,,,,,,,,,,,,,, -46900,2.0994682,1.1536417,,,,,,,,,,,,,, -47000,4.5857882,1.1128812,,,,,,,,,,,,,, -47100,2.1282973,1.1655641,,,,,,,,,,,,,, -47200,6.6594877,1.1410533,,,,,,,,,,,,,, -47300,8.536851,1.1570032,,,,,,,,,,,,,, -47400,4.427421,1.1185431,,,,,,,,,,,,,, -47488,,,0.13965625,0.050840394499236,0.42927903,0.124207111617444,5348.0,0.23242217,0.0747059898848333,2472.0,37463.108900785446,41051.40569543839,37463.108900785446,3584.948172807693,1.291504144668579,0.0 -47500,3.628744,1.1891383,,,,,,,,,,,,,, -47600,7.496813,1.1211063,,,,,,,,,,,,,, -47700,2.1530983,1.1391666,,,,,,,,,,,,,, -47800,3.3724308,1.1098202,,,,,,,,,,,,,, -47900,2.4894307,1.1409334,,,,,,,,,,,,,, -48000,,,0.16727167,0.0553427332818267,0.4293102,0.1242650395358042,5348.0,0.23241976,0.0748278593626226,2472.0,37846.51352763176,41564.03453874588,37846.51352763176,3714.087497472763,1.3495750427246094,0.0 -48000,,,,,,,,,,,37846.51352763176,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/eval_measurements.csv deleted file mode 100644 index f0330b035..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,28 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -183.2568774223328,0.0,14.655245780944824,1,0,14.655245780944824,30.536522,2472,3.36162736376008,197.91219639778137,29.73173,3.3317589058524173,30.35123,5348,3.0462844067698427 -289.58756613731384,0.0293948650360107,1454.732485294342,1870,0,1454.732485294342,6.3189464,2472,0.8995592387219954,1744.4271211624146,6.4368887,0.9437977009574124,6.3937473,5348,0.8966083203800072 -410.0903272628784,0.0782649517059326,2895.4914994239807,3751,0,2895.4914994239807,2.8410707,2472,0.6370523835638697,3305.8187985420227,3.090075,0.6942303997950864,3.3290048,5348,0.7045386524035259 -539.3774290084839,0.1310043334960937,4335.885547399521,5632,0,4335.885547399521,0.6315554,2472,0.2048829037434241,4875.635069847107,0.59384084,0.2020196036191296,0.9646992,5348,0.2732459909053168 -671.8170447349548,0.1812648773193359,5776.177357435226,7518,0,5776.177357435226,0.54608774,2472,0.1780513070501493,6448.498773574829,0.4997202,0.1679969494039845,0.8595321,5348,0.2458171215617366 -801.1017694473267,0.2320382595062255,7216.76273560524,9405,0,7216.76273560524,0.4813573,2472,0.1556070115572888,8018.502638101578,0.44924083,0.1499989441898769,0.76941633,5348,0.2218446180136516 -934.3454160690308,0.2858431339263916,8657.179950237274,11289,0,8657.179950237274,0.44250494,2472,0.1421404342615725,9592.300470352173,0.41514525,0.1392252504823433,0.7317493,5348,0.2103362715660812 -1066.3841433525083,0.3352360725402832,10097.492826223372,13156,0,10097.492826223372,0.4097351,2472,0.1304409643938009,11164.784424304962,0.38317287,0.1291478692802238,0.6829772,5348,0.1974473097309248 -1198.4742548465729,0.388714075088501,11537.915561676024,15032,0,11537.915561676024,0.39503682,2472,0.1293441390936973,12737.434754610062,0.332868,0.1158005177386105,0.666079,5348,0.1920696679764812 -1330.4830236434937,0.4386367797851562,12978.067933797836,16910,0,12978.067933797836,0.38526803,2472,0.1252818231673877,14309.72917985916,0.33133,0.1122018775264075,0.6564473,5348,0.1873099240178804 -1460.588921546936,0.4884147644042969,14418.436093091965,18779,0,14418.436093091965,0.36856124,2472,0.1179696545000304,15880.33690047264,0.33849165,0.1135576714996754,0.6315583,5348,0.1813240391206542 -1590.1804230213163,0.5372920036315918,15858.46332526207,20639,0,15858.46332526207,0.35437745,2472,0.1130745638088274,17450.086770772934,0.30381674,0.1015356456085111,0.6062795,5348,0.1727410525502766 -1723.2054243087769,0.5952105522155762,17298.85318994522,22516,0,17298.85318994522,0.34771985,2472,0.1126683322161964,19023.645206451416,0.29839107,0.1026623576485072,0.59857935,5348,0.1738223736930013 -1854.3274881839752,0.6483442783355713,18739.20464229584,24387,0,18739.20464229584,0.33003947,2472,0.1075904373083094,20595.2571952343,0.31947687,0.1022953295680568,0.5687596,5348,0.1662048524286279 -1986.9184362888336,0.6978366374969482,20179.67861032486,26266,0,20179.67861032486,0.31949303,2472,0.1042187150894725,22168.45689558983,0.23455991,0.0812806366265356,0.5503892,5348,0.1592052289600973 -2118.6045274734497,0.7499384880065918,21619.813318490986,28145,0,21619.813318490986,0.30606657,2472,0.0993845591371641,23740.41608357429,0.24839616,0.0834191555097837,0.541494,5348,0.1560095387972233 -2249.3122136592865,0.7996718883514404,23060.3713555336,30009,0,23060.3713555336,0.291443,2472,0.0950175695163813,25311.81352329254,0.32168937,0.107909952294062,0.52409434,5348,0.1520897496548461 -2379.4582090377808,0.8533070087432861,24500.939710855484,31874,0,24500.939710855484,0.2857252,2472,0.0927629841772794,26882.66648888588,0.33718112,0.1116452902589013,0.50746393,5348,0.1489809513695125 -2508.6631367206573,0.9063496589660645,25941.51305270195,33739,0,25941.51305270195,0.2731824,2472,0.0880913208620234,28452.583513498303,0.36467797,0.1227608130684982,0.487907,5348,0.1411703370439383 -2638.0246579647064,0.9644997119903564,27381.70995497704,35622,0,27381.70995497704,0.2628657,2472,0.0838462007190299,30022.285489320755,0.30550236,0.0994893766874046,0.47197154,5348,0.1370960734526004 -2768.7975289821625,1.0214431285858154,28821.724100589752,37493,0,28821.724100589752,0.25255555,2472,0.0803119858631405,31593.214646816254,0.2715685,0.0926955085160065,0.45568103,5348,0.131959798024658 -2900.555379629135,1.07637619972229,30262.321855068207,39364,0,30262.321855068207,0.24169599,2472,0.0783011394796173,33165.71136927605,0.22365406,0.0767750063533088,0.44460142,5348,0.1275765855353987 -3031.871179819107,1.132845163345337,31702.706634283066,41230,0,31702.706634283066,0.23356742,2472,0.07413726565515,34737.552344322205,0.25303861,0.08619467289108,0.429852,5348,0.1248539733724668 -3163.4506227970123,1.1886072158813477,33142.88335800171,43113,0,33142.88335800171,0.22784421,2472,0.0733857372087827,36309.45052433014,0.21915351,0.0752219014009197,0.42223984,5348,0.1220541239850546 -3294.2698872089386,1.2424731254577637,34583.50731277466,44985,0,34583.50731277466,0.22397295,2472,0.0724107813864684,37881.03183031082,0.21478947,0.0762512871082151,0.41844097,5348,0.120547998107688 -3425.495941877365,1.2960054874420166,36024.03443598747,46864,0,36024.03443598747,0.22283582,2472,0.0718826803160481,39452.92549777031,0.21869154,0.0752600940614139,0.41614845,5348,0.1198432084343049 -3555.660044193268,1.351287841796875,36883.611184597015,48000,0,36883.611184597015,0.22271666,2472,0.0716795645197327,40442.77611327171,0.23044662,0.07937063940968733,0.4161281,5348,0.11978528051594466 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/measurements.csv deleted file mode 100644 index fab654e0a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/measurements.csv +++ /dev/null @@ -1,509 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,23.855167,32.945007,,,,,,,,,,,,,, -1,,,29.73173,3.3317589058524173,30.35123,3.0462844067698427,5348.0,30.536522,3.36162736376008,2472.0,14.655245780944824,197.91219639778137,14.655245780944824,183.2568774223328,0.0,0.0 -100,7.366188,9.443339,,,,,,,,,,,,,, -200,1.7605289,6.5504284,,,,,,,,,,,,,, -300,0.77724403,5.909629,,,,,,,,,,,,,, -400,0.5094949,5.856266,,,,,,,,,,,,,, -500,0.52208257,5.8432617,,,,,,,,,,,,,, -600,0.35727647,5.821673,,,,,,,,,,,,,, -700,0.6942517,5.7721505,,,,,,,,,,,,,, -800,0.42385772,5.709428,,,,,,,,,,,,,, -900,0.40143132,5.606819,,,,,,,,,,,,,, -1000,0.5099695,5.494177,,,,,,,,,,,,,, -1100,0.5513551,5.341774,,,,,,,,,,,,,, -1200,1.5589997,5.056851,,,,,,,,,,,,,, -1300,1.8142133,4.554522,,,,,,,,,,,,,, -1400,1.6148173,4.094589,,,,,,,,,,,,,, -1500,1.7605033,3.7985106,,,,,,,,,,,,,, -1600,1.5146009,3.5708618,,,,,,,,,,,,,, -1700,1.7821426,3.398912,,,,,,,,,,,,,, -1800,2.232893,3.234894,,,,,,,,,,,,,, -1870,,,6.4368887,0.9437977009574124,6.3937473,0.8966083203800072,5348.0,6.3189464,0.8995592387219954,2472.0,1454.732485294342,1744.4271211624146,1454.732485294342,289.58756613731384,0.0293948650360107,0.0 -1900,2.3035362,3.0912223,,,,,,,,,,,,,, -2000,2.0871317,3.0003746,,,,,,,,,,,,,, -2100,2.5768986,2.87369,,,,,,,,,,,,,, -2200,2.0145786,2.7732327,,,,,,,,,,,,,, -2300,2.20431,2.6757624,,,,,,,,,,,,,, -2400,3.223526,2.6373076,,,,,,,,,,,,,, -2500,3.6209548,2.6065526,,,,,,,,,,,,,, -2600,4.120243,2.550852,,,,,,,,,,,,,, -2700,3.544594,2.52141,,,,,,,,,,,,,, -2800,4.3908253,2.430044,,,,,,,,,,,,,, -2900,3.2455182,2.4374645,,,,,,,,,,,,,, -3000,3.9679081,2.3693602,,,,,,,,,,,,,, -3100,3.167034,2.3426242,,,,,,,,,,,,,, -3200,3.480273,2.2445571,,,,,,,,,,,,,, -3300,4.0480185,2.2906134,,,,,,,,,,,,,, -3400,3.4272714,2.1678133,,,,,,,,,,,,,, -3500,3.725747,2.10166,,,,,,,,,,,,,, -3600,2.7392652,2.051895,,,,,,,,,,,,,, -3700,3.449511,2.0572553,,,,,,,,,,,,,, -3751,,,3.090075,0.6942303997950864,3.3290048,0.7045386524035259,5348.0,2.8410707,0.6370523835638697,2472.0,2895.4914994239807,3305.8187985420227,2895.4914994239807,410.0903272628784,0.0782649517059326,0.0 -3800,2.8487062,2.1492977,,,,,,,,,,,,,, -3900,4.674115,2.136942,,,,,,,,,,,,,, -4000,2.6212926,2.0901446,,,,,,,,,,,,,, -4100,4.193834,2.05863,,,,,,,,,,,,,, -4200,4.1618958,2.0128949,,,,,,,,,,,,,, -4300,4.819425,2.0349913,,,,,,,,,,,,,, -4400,3.7360942,2.0676696,,,,,,,,,,,,,, -4500,3.5156362,2.0437806,,,,,,,,,,,,,, -4600,4.4802456,2.0878563,,,,,,,,,,,,,, -4700,3.4191504,2.0165815,,,,,,,,,,,,,, -4800,3.4879816,1.9943755,,,,,,,,,,,,,, -4900,4.5282173,1.949732,,,,,,,,,,,,,, -5000,11.12454,2.0169296,,,,,,,,,,,,,, -5100,4.083558,1.9546832,,,,,,,,,,,,,, -5200,4.5395193,1.9715028,,,,,,,,,,,,,, -5300,3.5996473,1.9797907,,,,,,,,,,,,,, -5400,4.14144,1.93439,,,,,,,,,,,,,, -5500,3.3098056,1.8074976,,,,,,,,,,,,,, -5600,2.286511,1.8122228,,,,,,,,,,,,,, -5632,,,0.59384084,0.2020196036191296,0.9646992,0.2732459909053168,5348.0,0.6315554,0.2048829037434241,2472.0,4335.885547399521,4875.635069847107,4335.885547399521,539.3774290084839,0.1310043334960937,0.0 -5700,4.394531,1.89114,,,,,,,,,,,,,, -5800,3.0535374,1.8257055,,,,,,,,,,,,,, -5900,2.9763992,1.9457631,,,,,,,,,,,,,, -6000,3.4355962,1.818079,,,,,,,,,,,,,, -6100,3.1265152,1.9172906,,,,,,,,,,,,,, -6200,2.8485265,1.8608685,,,,,,,,,,,,,, -6300,3.8178563,1.8066928,,,,,,,,,,,,,, -6400,7.0652766,1.8035126,,,,,,,,,,,,,, -6500,3.6595356,1.877315,,,,,,,,,,,,,, -6600,2.1344905,1.8377873,,,,,,,,,,,,,, -6700,3.3864455,1.8222997,,,,,,,,,,,,,, -6800,4.21416,1.8284608,,,,,,,,,,,,,, -6900,3.095543,1.8429646,,,,,,,,,,,,,, -7000,3.1053953,1.7931521,,,,,,,,,,,,,, -7100,2.7264946,1.8261135,,,,,,,,,,,,,, -7200,4.708413,1.7483426,,,,,,,,,,,,,, -7300,3.1806831,1.764796,,,,,,,,,,,,,, -7400,2.4480593,1.8008981,,,,,,,,,,,,,, -7500,2.229334,1.8335546,,,,,,,,,,,,,, -7518,,,0.4997202,0.1679969494039845,0.8595321,0.2458171215617366,5348.0,0.54608774,0.1780513070501493,2472.0,5776.177357435226,6448.498773574829,5776.177357435226,671.8170447349548,0.1812648773193359,0.0 -7600,2.6770358,1.7463161,,,,,,,,,,,,,, -7700,2.988373,1.7346256,,,,,,,,,,,,,, -7800,2.585496,1.7754049,,,,,,,,,,,,,, -7900,3.041882,1.6612123,,,,,,,,,,,,,, -8000,5.2289276,1.8109123,,,,,,,,,,,,,, -8100,2.6492267,1.8074137,,,,,,,,,,,,,, -8200,3.2038078,1.6924727,,,,,,,,,,,,,, -8300,3.5091858,1.6954746,,,,,,,,,,,,,, -8400,2.9353514,1.8083433,,,,,,,,,,,,,, -8500,2.3716753,1.7178104,,,,,,,,,,,,,, -8600,3.302687,1.728688,,,,,,,,,,,,,, -8700,2.7899876,1.7595028,,,,,,,,,,,,,, -8800,3.7314863,1.6522521,,,,,,,,,,,,,, -8900,2.7546327,1.7011342,,,,,,,,,,,,,, -9000,2.3580067,1.7204969,,,,,,,,,,,,,, -9100,2.9105067,1.7320219,,,,,,,,,,,,,, -9200,2.8551002,1.6742811,,,,,,,,,,,,,, -9300,2.5787055,1.6891413,,,,,,,,,,,,,, -9400,3.3104148,1.7040304,,,,,,,,,,,,,, -9405,,,0.44924083,0.1499989441898769,0.76941633,0.2218446180136516,5348.0,0.4813573,0.1556070115572888,2472.0,7216.76273560524,8018.502638101578,7216.76273560524,801.1017694473267,0.2320382595062255,0.0 -9500,3.178639,1.7460667,,,,,,,,,,,,,, -9600,4.0179935,1.668621,,,,,,,,,,,,,, -9700,2.42831,1.6780139,,,,,,,,,,,,,, -9800,3.4365056,1.7502975,,,,,,,,,,,,,, -9900,2.9726884,1.6309093,,,,,,,,,,,,,, -10000,2.0425189,1.6407804,,,,,,,,,,,,,, -10100,2.3122113,1.6955565,,,,,,,,,,,,,, -10200,2.9268498,1.6946532,,,,,,,,,,,,,, -10300,3.0845702,1.7117649,,,,,,,,,,,,,, -10400,5.115329,1.6150246,,,,,,,,,,,,,, -10500,2.830721,1.6769549,,,,,,,,,,,,,, -10600,2.5330832,1.6536729,,,,,,,,,,,,,, -10700,3.1752179,1.6587788,,,,,,,,,,,,,, -10800,3.5781727,1.7129819,,,,,,,,,,,,,, -10900,2.58671,1.6935706,,,,,,,,,,,,,, -11000,2.3606281,1.6384625,,,,,,,,,,,,,, -11100,3.003452,1.6304963,,,,,,,,,,,,,, -11200,2.641957,1.6545048,,,,,,,,,,,,,, -11289,,,0.41514525,0.1392252504823433,0.7317493,0.2103362715660812,5348.0,0.44250494,0.1421404342615725,2472.0,8657.179950237274,9592.300470352173,8657.179950237274,934.3454160690308,0.2858431339263916,0.0 -11300,4.061969,1.6590401,,,,,,,,,,,,,, -11400,2.6112814,1.6720443,,,,,,,,,,,,,, -11500,3.1440618,1.5597247,,,,,,,,,,,,,, -11600,4.189682,1.6698511,,,,,,,,,,,,,, -11700,3.4316838,1.6244643,,,,,,,,,,,,,, -11800,2.8701785,1.6764565,,,,,,,,,,,,,, -11900,2.7052002,1.6351317,,,,,,,,,,,,,, -12000,3.241165,1.6320866,,,,,,,,,,,,,, -12100,5.566329,1.6282761,,,,,,,,,,,,,, -12200,2.5754683,1.6978545,,,,,,,,,,,,,, -12300,3.096598,1.6390043,,,,,,,,,,,,,, -12400,2.387958,1.6168904,,,,,,,,,,,,,, -12500,3.6550908,1.5911939,,,,,,,,,,,,,, -12600,3.0533283,1.6189054,,,,,,,,,,,,,, -12700,2.9006937,1.6999251,,,,,,,,,,,,,, -12800,3.6930857,1.6439937,,,,,,,,,,,,,, -12900,2.316825,1.5926263,,,,,,,,,,,,,, -13000,2.1824458,1.6020875,,,,,,,,,,,,,, -13100,1.9417714,1.6499035,,,,,,,,,,,,,, -13156,,,0.38317287,0.1291478692802238,0.6829772,0.1974473097309248,5348.0,0.4097351,0.1304409643938009,2472.0,10097.492826223372,11164.784424304962,10097.492826223372,1066.3841433525083,0.3352360725402832,0.0 -13200,4.355909,1.5464998,,,,,,,,,,,,,, -13300,8.224519,1.592584,,,,,,,,,,,,,, -13400,3.4963453,1.6097498,,,,,,,,,,,,,, -13500,3.0426376,1.6142086,,,,,,,,,,,,,, -13600,2.8676522,1.6028612,,,,,,,,,,,,,, -13700,2.2195876,1.6409312,,,,,,,,,,,,,, -13800,2.5968623,1.5887727,,,,,,,,,,,,,, -13900,2.4962034,1.6097909,,,,,,,,,,,,,, -14000,3.1399713,1.6100426,,,,,,,,,,,,,, -14100,2.934489,1.5988619,,,,,,,,,,,,,, -14200,2.5901868,1.5905033,,,,,,,,,,,,,, -14300,2.8476782,1.5965502,,,,,,,,,,,,,, -14400,2.9217696,1.5960081,,,,,,,,,,,,,, -14500,1.8845803,1.6088051,,,,,,,,,,,,,, -14600,2.3310833,1.5815933,,,,,,,,,,,,,, -14700,4.258266,1.52964,,,,,,,,,,,,,, -14800,3.5318396,1.6032757,,,,,,,,,,,,,, -14900,3.0805986,1.5780612,,,,,,,,,,,,,, -15000,2.8274908,1.6446675,,,,,,,,,,,,,, -15032,,,0.332868,0.1158005177386105,0.666079,0.1920696679764812,5348.0,0.39503682,0.1293441390936973,2472.0,11537.915561676024,12737.434754610062,11537.915561676024,1198.4742548465729,0.388714075088501,0.0 -15100,3.2590563,1.6099242,,,,,,,,,,,,,, -15200,2.550641,1.5179883,,,,,,,,,,,,,, -15300,2.74994,1.6303165,,,,,,,,,,,,,, -15400,4.058529,1.5409781,,,,,,,,,,,,,, -15500,3.5956683,1.5818697,,,,,,,,,,,,,, -15600,3.2283764,1.5973775,,,,,,,,,,,,,, -15700,3.8366675,1.5384064,,,,,,,,,,,,,, -15800,3.828292,1.5197943,,,,,,,,,,,,,, -15900,2.2097452,1.5948044,,,,,,,,,,,,,, -16000,2.765398,1.5381293,,,,,,,,,,,,,, -16100,3.9878745,1.6762809,,,,,,,,,,,,,, -16200,4.5847416,1.5709685,,,,,,,,,,,,,, -16300,2.795014,1.6566011,,,,,,,,,,,,,, -16400,3.0506406,1.5703096,,,,,,,,,,,,,, -16500,3.3601446,1.5460303,,,,,,,,,,,,,, -16600,2.85122,1.5875005,,,,,,,,,,,,,, -16700,2.1631944,1.5357844,,,,,,,,,,,,,, -16800,2.219732,1.5228485,,,,,,,,,,,,,, -16900,5.8699427,1.5810053,,,,,,,,,,,,,, -16910,,,0.33133,0.1122018775264075,0.6564473,0.1873099240178804,5348.0,0.38526803,0.1252818231673877,2472.0,12978.067933797836,14309.72917985916,12978.067933797836,1330.4830236434937,0.4386367797851562,0.0 -17000,3.5391066,1.5552168,,,,,,,,,,,,,, -17100,2.6752622,1.5329967,,,,,,,,,,,,,, -17200,3.941613,1.568127,,,,,,,,,,,,,, -17300,2.3174345,1.558338,,,,,,,,,,,,,, -17400,2.2909975,1.6033437,,,,,,,,,,,,,, -17500,4.066459,1.5488648,,,,,,,,,,,,,, -17600,1.8656427,1.4771204,,,,,,,,,,,,,, -17700,3.1231425,1.4933678,,,,,,,,,,,,,, -17800,3.4558952,1.5635267,,,,,,,,,,,,,, -17900,2.6804485,1.5563523,,,,,,,,,,,,,, -18000,4.2672606,1.5573375,,,,,,,,,,,,,, -18100,5.54186,1.6168097,,,,,,,,,,,,,, -18200,2.9854605,1.5865085,,,,,,,,,,,,,, -18300,2.8136425,1.5409261,,,,,,,,,,,,,, -18400,3.45189,1.505171,,,,,,,,,,,,,, -18500,2.1089532,1.4700837,,,,,,,,,,,,,, -18600,2.0338008,1.5315067,,,,,,,,,,,,,, -18700,2.3394942,1.534076,,,,,,,,,,,,,, -18779,,,0.33849165,0.1135576714996754,0.6315583,0.1813240391206542,5348.0,0.36856124,0.1179696545000304,2472.0,14418.436093091965,15880.33690047264,14418.436093091965,1460.588921546936,0.4884147644042969,0.0 -18800,1.9438947,1.4675344,,,,,,,,,,,,,, -18900,3.375074,1.5129452,,,,,,,,,,,,,, -19000,3.018156,1.6022719,,,,,,,,,,,,,, -19100,2.5835252,1.5299689,,,,,,,,,,,,,, -19200,2.2227063,1.4383143,,,,,,,,,,,,,, -19300,3.3595467,1.5186156,,,,,,,,,,,,,, -19400,2.6125205,1.4986775,,,,,,,,,,,,,, -19500,2.3532152,1.5180964,,,,,,,,,,,,,, -19600,2.6291451,1.540626,,,,,,,,,,,,,, -19700,2.6216462,1.4055724,,,,,,,,,,,,,, -19800,2.2505472,1.4945632,,,,,,,,,,,,,, -19900,3.380812,1.4467068,,,,,,,,,,,,,, -20000,2.7432604,1.5238044,,,,,,,,,,,,,, -20100,4.1033883,1.4686179,,,,,,,,,,,,,, -20200,3.3643684,1.5151297,,,,,,,,,,,,,, -20300,3.5948431,1.5313976,,,,,,,,,,,,,, -20400,2.7284763,1.5441049,,,,,,,,,,,,,, -20500,2.6643252,1.5110067,,,,,,,,,,,,,, -20600,2.7603946,1.4989841,,,,,,,,,,,,,, -20639,,,0.30381674,0.1015356456085111,0.6062795,0.1727410525502766,5348.0,0.35437745,0.1130745638088274,2472.0,15858.46332526207,17450.086770772934,15858.46332526207,1590.1804230213163,0.5372920036315918,0.0 -20700,2.0779588,1.4901707,,,,,,,,,,,,,, -20800,2.1357143,1.4715177,,,,,,,,,,,,,, -20900,3.380953,1.5099152,,,,,,,,,,,,,, -21000,1.9813174,1.451618,,,,,,,,,,,,,, -21100,3.06028,1.5579134,,,,,,,,,,,,,, -21200,2.3817043,1.5128769,,,,,,,,,,,,,, -21300,2.885232,1.5231779,,,,,,,,,,,,,, -21400,1.6476791,1.3968914,,,,,,,,,,,,,, -21500,3.1101384,1.4589093,,,,,,,,,,,,,, -21600,3.864246,1.4576554,,,,,,,,,,,,,, -21700,3.6801674,1.4822586,,,,,,,,,,,,,, -21800,2.2020845,1.4528967,,,,,,,,,,,,,, -21900,3.7994173,1.491885,,,,,,,,,,,,,, -22000,2.871898,1.5420171,,,,,,,,,,,,,, -22100,2.550892,1.4981095,,,,,,,,,,,,,, -22200,2.7003946,1.4390066,,,,,,,,,,,,,, -22300,3.3064325,1.5100118,,,,,,,,,,,,,, -22400,3.4273174,1.4754496,,,,,,,,,,,,,, -22500,2.3057666,1.5145022,,,,,,,,,,,,,, -22516,,,0.29839107,0.1026623576485072,0.59857935,0.1738223736930013,5348.0,0.34771985,0.1126683322161964,2472.0,17298.85318994522,19023.645206451416,17298.85318994522,1723.2054243087769,0.5952105522155762,0.0 -22600,3.3824987,1.4411199,,,,,,,,,,,,,, -22700,3.3832963,1.5115745,,,,,,,,,,,,,, -22800,2.317615,1.4570773,,,,,,,,,,,,,, -22900,3.0338995,1.467916,,,,,,,,,,,,,, -23000,2.8796918,1.4722407,,,,,,,,,,,,,, -23100,3.459211,1.4401582,,,,,,,,,,,,,, -23200,2.6635506,1.4595755,,,,,,,,,,,,,, -23300,2.7668982,1.4714042,,,,,,,,,,,,,, -23400,1.7452304,1.4636086,,,,,,,,,,,,,, -23500,3.0558355,1.441453,,,,,,,,,,,,,, -23600,2.9159098,1.4601294,,,,,,,,,,,,,, -23700,3.4262235,1.4267219,,,,,,,,,,,,,, -23800,2.2673817,1.4394857,,,,,,,,,,,,,, -23900,2.3265486,1.4553909,,,,,,,,,,,,,, -24000,2.7228696,1.4654291,,,,,,,,,,,,,, -24100,3.0368958,1.4638472,,,,,,,,,,,,,, -24200,3.6173043,1.4839953,,,,,,,,,,,,,, -24300,2.1876853,1.4302112,,,,,,,,,,,,,, -24387,,,0.31947687,0.1022953295680568,0.5687596,0.1662048524286279,5348.0,0.33003947,0.1075904373083094,2472.0,18739.20464229584,20595.2571952343,18739.20464229584,1854.3274881839752,0.6483442783355713,0.0 -24400,3.1275482,1.4668015,,,,,,,,,,,,,, -24500,2.9907992,1.416213,,,,,,,,,,,,,, -24600,2.7675354,1.5155566,,,,,,,,,,,,,, -24700,2.9945107,1.4380838,,,,,,,,,,,,,, -24800,2.4813251,1.4658881,,,,,,,,,,,,,, -24900,2.037158,1.4718859,,,,,,,,,,,,,, -25000,3.0932317,1.4285269,,,,,,,,,,,,,, -25100,3.0423203,1.4581738,,,,,,,,,,,,,, -25200,2.6424837,1.3880215,,,,,,,,,,,,,, -25300,2.7490337,1.4418263,,,,,,,,,,,,,, -25400,3.0091782,1.4295151,,,,,,,,,,,,,, -25500,2.6948075,1.4661237,,,,,,,,,,,,,, -25600,2.0825415,1.371473,,,,,,,,,,,,,, -25700,2.139588,1.456392,,,,,,,,,,,,,, -25800,3.259117,1.3977449,,,,,,,,,,,,,, -25900,3.4204547,1.3918548,,,,,,,,,,,,,, -26000,2.638076,1.4211446,,,,,,,,,,,,,, -26100,3.2205262,1.3793476,,,,,,,,,,,,,, -26200,1.636138,1.3985565,,,,,,,,,,,,,, -26266,,,0.23455991,0.0812806366265356,0.5503892,0.1592052289600973,5348.0,0.31949303,0.1042187150894725,2472.0,20179.67861032486,22168.45689558983,20179.67861032486,1986.9184362888336,0.6978366374969482,0.0 -26300,3.2004273,1.3848261,,,,,,,,,,,,,, -26400,3.4101183,1.4076098,,,,,,,,,,,,,, -26500,2.3819556,1.4585651,,,,,,,,,,,,,, -26600,2.490819,1.3972971,,,,,,,,,,,,,, -26700,3.4372559,1.4200853,,,,,,,,,,,,,, -26800,2.3962998,1.4472638,,,,,,,,,,,,,, -26900,2.2272437,1.3961093,,,,,,,,,,,,,, -27000,4.119461,1.4536548,,,,,,,,,,,,,, -27100,2.7400372,1.4159777,,,,,,,,,,,,,, -27200,1.9158548,1.4410931,,,,,,,,,,,,,, -27300,2.056799,1.4248952,,,,,,,,,,,,,, -27400,2.3408456,1.3916508,,,,,,,,,,,,,, -27500,4.0300293,1.4211578,,,,,,,,,,,,,, -27600,2.785648,1.3769491,,,,,,,,,,,,,, -27700,2.8945708,1.3695459,,,,,,,,,,,,,, -27800,4.9158235,1.398779,,,,,,,,,,,,,, -27900,3.0658665,1.3671021,,,,,,,,,,,,,, -28000,2.804242,1.3929,,,,,,,,,,,,,, -28100,2.7165062,1.38478,,,,,,,,,,,,,, -28145,,,0.24839616,0.0834191555097837,0.541494,0.1560095387972233,5348.0,0.30606657,0.0993845591371641,2472.0,21619.813318490986,23740.41608357429,21619.813318490986,2118.6045274734497,0.7499384880065918,0.0 -28200,2.4446619,1.356422,,,,,,,,,,,,,, -28300,2.781369,1.3829796,,,,,,,,,,,,,, -28400,2.0520675,1.395065,,,,,,,,,,,,,, -28500,2.7865477,1.3957454,,,,,,,,,,,,,, -28600,3.0388265,1.3635691,,,,,,,,,,,,,, -28700,2.5676332,1.4063146,,,,,,,,,,,,,, -28800,3.2293508,1.3805176,,,,,,,,,,,,,, -28900,4.84771,1.3585185,,,,,,,,,,,,,, -29000,2.949488,1.3497714,,,,,,,,,,,,,, -29100,2.6705081,1.3868921,,,,,,,,,,,,,, -29200,2.1926038,1.387034,,,,,,,,,,,,,, -29300,3.7864408,1.4366295,,,,,,,,,,,,,, -29400,2.2552638,1.3500886,,,,,,,,,,,,,, -29500,3.5003402,1.3850348,,,,,,,,,,,,,, -29600,2.4157517,1.3174449,,,,,,,,,,,,,, -29700,3.346257,1.3344474,,,,,,,,,,,,,, -29800,2.302519,1.307893,,,,,,,,,,,,,, -29900,2.2724898,1.3803519,,,,,,,,,,,,,, -30000,2.4272494,1.3151696,,,,,,,,,,,,,, -30009,,,0.32168937,0.107909952294062,0.52409434,0.1520897496548461,5348.0,0.291443,0.0950175695163813,2472.0,23060.3713555336,25311.81352329254,23060.3713555336,2249.3122136592865,0.7996718883514404,0.0 -30100,3.4616075,1.4025869,,,,,,,,,,,,,, -30200,2.062909,1.3848009,,,,,,,,,,,,,, -30300,2.9004648,1.3755482,,,,,,,,,,,,,, -30400,2.777698,1.3712898,,,,,,,,,,,,,, -30500,3.384997,1.3555515,,,,,,,,,,,,,, -30600,5.520186,1.282941,,,,,,,,,,,,,, -30700,3.9263592,1.3183571,,,,,,,,,,,,,, -30800,2.4621942,1.3408924,,,,,,,,,,,,,, -30900,3.613191,1.3248973,,,,,,,,,,,,,, -31000,3.0284595,1.3485624,,,,,,,,,,,,,, -31100,2.0810366,1.3062723,,,,,,,,,,,,,, -31200,2.8214934,1.393779,,,,,,,,,,,,,, -31300,4.1624327,1.3437824,,,,,,,,,,,,,, -31400,2.5739365,1.3256046,,,,,,,,,,,,,, -31500,2.5487726,1.3624866,,,,,,,,,,,,,, -31600,1.8099737,1.3033105,,,,,,,,,,,,,, -31700,2.812335,1.280024,,,,,,,,,,,,,, -31800,4.690214,1.3043182,,,,,,,,,,,,,, -31874,,,0.33718112,0.1116452902589013,0.50746393,0.1489809513695125,5348.0,0.2857252,0.0927629841772794,2472.0,24500.939710855484,26882.66648888588,24500.939710855484,2379.4582090377808,0.8533070087432861,0.0 -31900,3.063344,1.3252019,,,,,,,,,,,,,, -32000,2.6634202,1.3383713,,,,,,,,,,,,,, -32100,3.847699,1.3680013,,,,,,,,,,,,,, -32200,3.4059854,1.3476804,,,,,,,,,,,,,, -32300,3.6199775,1.2953353,,,,,,,,,,,,,, -32400,1.9758726,1.2635852,,,,,,,,,,,,,, -32500,3.280346,1.335536,,,,,,,,,,,,,, -32600,2.1689715,1.3070333,,,,,,,,,,,,,, -32700,2.4323266,1.1975557,,,,,,,,,,,,,, -32800,2.1807253,1.2998903,,,,,,,,,,,,,, -32900,5.475537,1.3113059,,,,,,,,,,,,,, -33000,2.6104164,1.2411113,,,,,,,,,,,,,, -33100,2.604451,1.2540501,,,,,,,,,,,,,, -33200,2.5826945,1.2984192,,,,,,,,,,,,,, -33300,3.097428,1.2472255,,,,,,,,,,,,,, -33400,2.379989,1.2939947,,,,,,,,,,,,,, -33500,2.4110415,1.2906898,,,,,,,,,,,,,, -33600,2.299114,1.3057846,,,,,,,,,,,,,, -33700,3.2366474,1.2236688,,,,,,,,,,,,,, -33739,,,0.36467797,0.1227608130684982,0.487907,0.1411703370439383,5348.0,0.2731824,0.0880913208620234,2472.0,25941.51305270195,28452.583513498303,25941.51305270195,2508.6631367206573,0.9063496589660645,0.0 -33800,2.5843043,1.2761153,,,,,,,,,,,,,, -33900,3.8021255,1.2396128,,,,,,,,,,,,,, -34000,2.9832485,1.2792541,,,,,,,,,,,,,, -34100,3.985551,1.2632428,,,,,,,,,,,,,, -34200,2.142674,1.2933509,,,,,,,,,,,,,, -34300,2.8314493,1.2572223,,,,,,,,,,,,,, -34400,3.266704,1.2916563,,,,,,,,,,,,,, -34500,2.4288037,1.237809,,,,,,,,,,,,,, -34600,5.359385,1.2811917,,,,,,,,,,,,,, -34700,3.9585855,1.2756073,,,,,,,,,,,,,, -34800,2.8973637,1.335129,,,,,,,,,,,,,, -34900,2.994602,1.3227857,,,,,,,,,,,,,, -35000,2.3671293,1.2620392,,,,,,,,,,,,,, -35100,2.5153449,1.2402037,,,,,,,,,,,,,, -35200,4.0569153,1.2512331,,,,,,,,,,,,,, -35300,2.8433535,1.2875376,,,,,,,,,,,,,, -35400,2.539412,1.2257887,,,,,,,,,,,,,, -35500,3.0465367,1.2917008,,,,,,,,,,,,,, -35600,4.43433,1.2565136,,,,,,,,,,,,,, -35622,,,0.30550236,0.0994893766874046,0.47197154,0.1370960734526004,5348.0,0.2628657,0.0838462007190299,2472.0,27381.70995497704,30022.285489320755,27381.70995497704,2638.0246579647064,0.9644997119903564,0.0 -35700,2.6322744,1.2112526,,,,,,,,,,,,,, -35800,2.655934,1.2468009,,,,,,,,,,,,,, -35900,2.5566332,1.3005716,,,,,,,,,,,,,, -36000,2.7686021,1.312766,,,,,,,,,,,,,, -36100,2.9541516,1.278796,,,,,,,,,,,,,, -36200,2.62995,1.2441287,,,,,,,,,,,,,, -36300,3.1873944,1.242209,,,,,,,,,,,,,, -36400,2.2400362,1.2208035,,,,,,,,,,,,,, -36500,2.7148147,1.227902,,,,,,,,,,,,,, -36600,2.5010378,1.2616018,,,,,,,,,,,,,, -36700,2.9595768,1.2076274,,,,,,,,,,,,,, -36800,2.266072,1.256128,,,,,,,,,,,,,, -36900,2.6922467,1.2285523,,,,,,,,,,,,,, -37000,3.1740305,1.2034086,,,,,,,,,,,,,, -37100,3.0737364,1.2570181,,,,,,,,,,,,,, -37200,2.472334,1.3239928,,,,,,,,,,,,,, -37300,2.848467,1.2451717,,,,,,,,,,,,,, -37400,2.5270154,1.1976638,,,,,,,,,,,,,, -37493,,,0.2715685,0.0926955085160065,0.45568103,0.131959798024658,5348.0,0.25255555,0.0803119858631405,2472.0,28821.724100589752,31593.214646816254,28821.724100589752,2768.7975289821625,1.0214431285858154,0.0 -37500,2.967374,1.2279456,,,,,,,,,,,,,, -37600,2.0965202,1.2118561,,,,,,,,,,,,,, -37700,3.0302222,1.2678217,,,,,,,,,,,,,, -37800,3.3210022,1.211503,,,,,,,,,,,,,, -37900,3.3940828,1.2387122,,,,,,,,,,,,,, -38000,4.0725594,1.2232852,,,,,,,,,,,,,, -38100,3.4496813,1.2294974,,,,,,,,,,,,,, -38200,3.4256866,1.2195829,,,,,,,,,,,,,, -38300,2.6897137,1.2221779,,,,,,,,,,,,,, -38400,2.8846397,1.21846,,,,,,,,,,,,,, -38500,2.1327894,1.2606268,,,,,,,,,,,,,, -38600,2.7938776,1.2029492,,,,,,,,,,,,,, -38700,2.9505358,1.2064615,,,,,,,,,,,,,, -38800,3.4386313,1.2160801,,,,,,,,,,,,,, -38900,4.550839,1.1539662,,,,,,,,,,,,,, -39000,3.1066802,1.2640461,,,,,,,,,,,,,, -39100,3.0245352,1.2054143,,,,,,,,,,,,,, -39200,2.526556,1.2375398,,,,,,,,,,,,,, -39300,2.0782328,1.1894585,,,,,,,,,,,,,, -39364,,,0.22365406,0.0767750063533088,0.44460142,0.1275765855353987,5348.0,0.24169599,0.0783011394796173,2472.0,30262.321855068207,33165.71136927605,30262.321855068207,2900.555379629135,1.07637619972229,0.0 -39400,2.427227,1.1647288,,,,,,,,,,,,,, -39500,2.861053,1.2059265,,,,,,,,,,,,,, -39600,2.661113,1.1837379,,,,,,,,,,,,,, -39700,2.8231704,1.18569,,,,,,,,,,,,,, -39800,2.9067147,1.1773751,,,,,,,,,,,,,, -39900,4.395349,1.209171,,,,,,,,,,,,,, -40000,3.108378,1.1416346,,,,,,,,,,,,,, -40100,2.3630033,1.1326388,,,,,,,,,,,,,, -40200,4.3427067,1.1885768,,,,,,,,,,,,,, -40300,4.1900253,1.1414781,,,,,,,,,,,,,, -40400,2.5579119,1.234183,,,,,,,,,,,,,, -40500,2.1554549,1.259736,,,,,,,,,,,,,, -40600,3.3908796,1.1767614,,,,,,,,,,,,,, -40700,2.0488605,1.1468451,,,,,,,,,,,,,, -40800,2.477001,1.1663026,,,,,,,,,,,,,, -40900,2.9479797,1.1823242,,,,,,,,,,,,,, -41000,4.6881056,1.188978,,,,,,,,,,,,,, -41100,4.0140476,1.2021308,,,,,,,,,,,,,, -41200,3.8564374,1.1741364,,,,,,,,,,,,,, -41230,,,0.25303861,0.08619467289108,0.429852,0.1248539733724668,5348.0,0.23356742,0.07413726565515,2472.0,31702.706634283066,34737.552344322205,31702.706634283066,3031.871179819107,1.132845163345337,0.0 -41300,3.1884484,1.2199144,,,,,,,,,,,,,, -41400,3.8731701,1.1767566,,,,,,,,,,,,,, -41500,3.4836547,1.1795126,,,,,,,,,,,,,, -41600,2.9896474,1.1430947,,,,,,,,,,,,,, -41700,2.6736617,1.1383597,,,,,,,,,,,,,, -41800,2.2668567,1.1318443,,,,,,,,,,,,,, -41900,3.3534603,1.2320782,,,,,,,,,,,,,, -42000,2.7443273,1.1352273,,,,,,,,,,,,,, -42100,2.4508054,1.1430051,,,,,,,,,,,,,, -42200,3.6166344,1.1496022,,,,,,,,,,,,,, -42300,3.5115244,1.1506479,,,,,,,,,,,,,, -42400,2.9496682,1.188964,,,,,,,,,,,,,, -42500,3.1188073,1.166732,,,,,,,,,,,,,, -42600,2.9319797,1.1637969,,,,,,,,,,,,,, -42700,3.0676644,1.1660371,,,,,,,,,,,,,, -42800,3.0847697,1.1720164,,,,,,,,,,,,,, -42900,2.7303417,1.1332257,,,,,,,,,,,,,, -43000,2.8107827,1.1522726,,,,,,,,,,,,,, -43100,2.729191,1.1460204,,,,,,,,,,,,,, -43113,,,0.21915351,0.0752219014009197,0.42223984,0.1220541239850546,5348.0,0.22784421,0.0733857372087827,2472.0,33142.88335800171,36309.45052433014,33142.88335800171,3163.4506227970123,1.1886072158813477,0.0 -43200,4.4779787,1.1999781,,,,,,,,,,,,,, -43300,2.8841536,1.1370671,,,,,,,,,,,,,, -43400,3.9104257,1.1495973,,,,,,,,,,,,,, -43500,3.0529153,1.1866202,,,,,,,,,,,,,, -43600,2.1961606,1.1028596,,,,,,,,,,,,,, -43700,3.9719968,1.1772512,,,,,,,,,,,,,, -43800,2.6499333,1.1444538,,,,,,,,,,,,,, -43900,2.652928,1.145806,,,,,,,,,,,,,, -44000,2.2313347,1.1666857,,,,,,,,,,,,,, -44100,4.3347425,1.1168387,,,,,,,,,,,,,, -44200,2.444234,1.1811002,,,,,,,,,,,,,, -44300,3.3667526,1.0849985,,,,,,,,,,,,,, -44400,4.704138,1.1568394,,,,,,,,,,,,,, -44500,3.2890708,1.1693344,,,,,,,,,,,,,, -44600,2.642783,1.128084,,,,,,,,,,,,,, -44700,3.980417,1.1503158,,,,,,,,,,,,,, -44800,2.6039963,1.2144277,,,,,,,,,,,,,, -44900,2.7426474,1.1149975,,,,,,,,,,,,,, -44985,,,0.21478947,0.0762512871082151,0.41844097,0.120547998107688,5348.0,0.22397295,0.0724107813864684,2472.0,34583.50731277466,37881.03183031082,34583.50731277466,3294.2698872089386,1.2424731254577637,0.0 -45000,3.4059894,1.1652474,,,,,,,,,,,,,, -45100,2.8375874,1.1296146,,,,,,,,,,,,,, -45200,4.496663,1.1822468,,,,,,,,,,,,,, -45300,2.5074663,1.2081871,,,,,,,,,,,,,, -45400,2.335813,1.181121,,,,,,,,,,,,,, -45500,2.8802996,1.1240027,,,,,,,,,,,,,, -45600,2.3207529,1.093505,,,,,,,,,,,,,, -45700,3.4279056,1.2111707,,,,,,,,,,,,,, -45800,3.325857,1.1288247,,,,,,,,,,,,,, -45900,2.3865764,1.1223754,,,,,,,,,,,,,, -46000,3.3055985,1.148633,,,,,,,,,,,,,, -46100,2.9608824,1.152618,,,,,,,,,,,,,, -46200,4.78436,1.0880463,,,,,,,,,,,,,, -46300,2.535629,1.11404,,,,,,,,,,,,,, -46400,3.5085936,1.1322892,,,,,,,,,,,,,, -46500,2.5743396,1.1480896,,,,,,,,,,,,,, -46600,4.4819465,1.1017568,,,,,,,,,,,,,, -46700,5.005669,1.1293312,,,,,,,,,,,,,, -46800,2.6807892,1.0942148,,,,,,,,,,,,,, -46864,,,0.21869154,0.0752600940614139,0.41614845,0.1198432084343049,5348.0,0.22283582,0.0718826803160481,2472.0,36024.03443598747,39452.92549777031,36024.03443598747,3425.495941877365,1.2960054874420166,0.0 -46900,2.586648,1.1528312,,,,,,,,,,,,,, -47000,3.1750712,1.1581496,,,,,,,,,,,,,, -47100,3.1502845,1.1074352,,,,,,,,,,,,,, -47200,4.222103,1.1564015,,,,,,,,,,,,,, -47300,4.3279767,1.1599236,,,,,,,,,,,,,, -47400,2.7730637,1.0875992,,,,,,,,,,,,,, -47500,2.6101797,1.1374002,,,,,,,,,,,,,, -47600,3.0916615,1.0945446,,,,,,,,,,,,,, -47700,2.120076,1.136825,,,,,,,,,,,,,, -47800,3.2265406,1.0762395,,,,,,,,,,,,,, -47900,2.7727976,1.1708032,,,,,,,,,,,,,, -48000,,,0.23044662,0.0793706394096873,0.4161281,0.1197852805159446,5348.0,0.22271666,0.0716795645197327,2472.0,36883.611184597015,40442.77611327171,36883.611184597015,3555.660044193268,1.351287841796875,0.0 -48000,,,,,,,,,,,36883.611184597015,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 9afd2fd94..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,28 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -181.3793807029724,0.0,15.840157985687256,1,0,15.840157985687256,30.536505,2472,3.361383624804501,197.21963810920715,31.12199,3.282438151130771,30.351221,5348,3.0459561485658013 -311.0581033229828,0.0296719074249267,1456.1781809329989,1866,0,1456.1781809329989,1.1032181,2472,0.3216947982044563,1767.3490042686462,1.4378089,0.3935562282917147,1.5070161,5348,0.3872867528505363 -440.4767470359802,0.0824496746063232,2896.190999746322,3736,0,2896.190999746322,0.7128284,2472,0.2208681169134523,3336.9190402030945,1.0180577,0.3027155629291672,1.064015,5348,0.2916574142908175 -571.3350307941437,0.1324553489685058,4336.195417165756,5605,0,4336.195417165756,0.6446429,2472,0.1994597119818008,4907.917762994766,0.8513874,0.2599757026892705,0.9749225,5348,0.2722998349054327 -701.4711322784424,0.1807506084442138,5776.223026514053,7472,0,5776.223026514053,0.6244865,2472,0.1915381959254971,6478.2153515815735,0.8907499,0.2634945591928149,0.97102195,5348,0.2661498209061857 -831.224268913269,0.2314159870147705,7216.673100471497,9327,0,7216.673100471497,0.57379514,2472,0.1766904312148355,8048.553774595261,0.8192644,0.243019991015274,0.89382845,5348,0.2470625718064821 -962.327393770218,0.2845883369445801,8657.137380123138,11202,0,8657.137380123138,0.54874396,2472,0.1715109784087908,9620.26033258438,0.77553225,0.2367157819791857,0.8555793,5348,0.2415401102561379 -1092.690021276474,0.3334639072418213,10097.43169784546,13074,0,10097.43169784546,0.5158074,2472,0.1610708264781752,11191.052989721298,0.687591,0.2090459084094702,0.8205807,5348,0.2295779951147455 -1223.362502336502,0.3845009803771972,11537.74842619896,14952,0,11537.74842619896,0.49638095,2472,0.1571303800296549,12762.179047107697,0.63921106,0.1976266962376962,0.79561263,5348,0.224673431360244 -1354.249887228012,0.4348087310791015,12977.63620853424,16828,0,12977.63620853424,0.49312958,2472,0.1551601568053947,14333.092150449753,0.7159693,0.2213396801517847,0.7896422,5348,0.2200198885853037 -1484.704974651337,0.4846115112304687,14418.066945314407,18695,0,14418.066945314407,0.45597544,2472,0.1456949606970934,15904.111874103546,0.6039645,0.1899518739973749,0.74285376,5348,0.2117458509128474 -1616.0522694587708,0.5335447788238525,15858.226209878922,20562,0,15858.226209878922,0.43188784,2472,0.1362703877480551,17475.754253149033,0.6238497,0.1896832201191496,0.70414096,5348,0.2015215733222626 -1746.9885816574097,0.5863668918609619,17298.253359794617,22428,0,17298.253359794617,0.41050902,2472,0.1309690654642211,19046.85565328598,0.5352325,0.1693792467028832,0.69400257,5348,0.1981327900981878 -1876.3985664844515,0.6419112682342529,18738.440461158752,24303,0,18738.440461158752,0.40436313,2472,0.1286129222269615,20616.59484243393,0.5379182,0.1729687250859889,0.6751597,5348,0.1924075808335827 -2006.694060087204,0.6964304447174072,20178.6946849823,26186,0,20178.6946849823,0.38142642,2472,0.1244693599821258,22187.284984588623,0.5195341,0.1666871597679435,0.64813316,5348,0.1873002693648204 -2144.9430723190308,0.751788854598999,21618.925753593445,28066,0,21618.925753593445,0.3617607,2472,0.1158978733776125,23765.90765786171,0.31689537,0.106635261368807,0.6243423,5348,0.1771532289987159 -2276.684823989868,0.8093042373657227,23059.09961295128,29931,0,23059.09961295128,0.33669013,2472,0.1090731826214124,25337.965800762177,0.28561276,0.0947428783980657,0.57865715,5348,0.1673827201019531 -2408.727742433548,0.86600661277771,24499.12723898888,31803,0,24499.12723898888,0.32298186,2472,0.1034671866431052,26910.178052663803,0.27346468,0.0927704951650813,0.5726918,5348,0.1643415043880398 -2540.49427819252,0.921947717666626,25939.097049236298,33678,0,25939.097049236298,0.306992,2472,0.0976377632888509,28482.05617165565,0.24806608,0.0840600825978491,0.53813505,5348,0.1538275872056537 -2672.538813829422,0.9750196933746338,27379.475867271423,35561,0,27379.475867271423,0.28383914,2472,0.0896553124936526,30054.61894655228,0.23072459,0.0780527067653141,0.5039565,5348,0.1447715226353341 -2805.0091433525085,1.0354893207550049,28820.02853703499,37440,0,28820.02853703499,0.26603636,2472,0.0848617797006073,31627.78924536705,0.20172007,0.0698390438583079,0.47738373,5348,0.138592544676907 -2936.458508014679,1.0954077243804932,30260.46949505806,39309,0,30260.46949505806,0.25134808,2472,0.0799260658501411,33199.82506918907,0.22671203,0.073330542067458,0.45624575,5348,0.1313032816165751 -3069.434670448303,1.153717041015625,31700.723588705063,41183,0,31700.723588705063,0.23483199,2472,0.0743810046107285,34773.198744535446,0.17770247,0.0616869580819994,0.4342081,5348,0.1249408652500072 -3201.1860456466675,1.2138450145721436,33141.31151199341,43048,0,33141.31151199341,0.2274025,2472,0.0706842971177868,36345.68494963646,0.15259655,0.0528985621285323,0.41574052,5348,0.1196887339853442 -3333.5102150440216,1.2725615501403809,34581.25995087624,44925,0,34581.25995087624,0.21977012,2472,0.0682875307212641,37918.101712465286,0.14536211,0.0501873654556568,0.40926024,5348,0.1170337043938326 -3464.165018558502,1.3316552639007568,36021.1313123703,46805,0,36021.1313123703,0.21717538,2472,0.0679422338675278,39488.77329707146,0.13815887,0.048562661498708,0.40407094,5348,0.1158075634552072 -3596.8922271728516,1.3875088691711426,36923.73557591438,48000,0,36923.73557591438,0.21718435,2472,0.06792192228789633,40524.21762919426,0.15959443,0.053562875984863235,0.40380514,5348,0.11573998088378694 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/measurements.csv deleted file mode 100644 index e07d5f562..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/measurements.csv +++ /dev/null @@ -1,509 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,22.598402,33.018986,,,,,,,,,,,,,, -1,,,31.12199,3.282438151130771,30.351221,3.0459561485658013,5348.0,30.536505,3.361383624804501,2472.0,15.840157985687256,197.21963810920715,15.840157985687256,181.3793807029724,0.0,0.0 -100,4.004505,5.7958865,,,,,,,,,,,,,, -200,2.391825,5.0599575,,,,,,,,,,,,,, -300,2.4677982,3.7760694,,,,,,,,,,,,,, -400,3.0403495,3.3089952,,,,,,,,,,,,,, -500,2.4812646,3.0648465,,,,,,,,,,,,,, -600,2.2952778,2.87627,,,,,,,,,,,,,, -700,2.478006,2.8306448,,,,,,,,,,,,,, -800,5.3073893,2.6918132,,,,,,,,,,,,,, -900,3.2582383,2.6748648,,,,,,,,,,,,,, -1000,3.890757,2.5825086,,,,,,,,,,,,,, -1100,2.5635576,2.458304,,,,,,,,,,,,,, -1200,2.5699742,2.4497383,,,,,,,,,,,,,, -1300,2.4727783,2.3307934,,,,,,,,,,,,,, -1400,3.1793873,2.3054821,,,,,,,,,,,,,, -1500,3.078123,2.3373687,,,,,,,,,,,,,, -1600,5.1972966,2.2939796,,,,,,,,,,,,,, -1700,4.453969,2.2445626,,,,,,,,,,,,,, -1800,2.0023088,2.258729,,,,,,,,,,,,,, -1866,,,1.4378089,0.3935562282917147,1.5070161,0.3872867528505363,5348.0,1.1032181,0.3216947982044563,2472.0,1456.1781809329989,1767.3490042686462,1456.1781809329989,311.0581033229828,0.0296719074249267,0.0 -1900,2.643239,2.164547,,,,,,,,,,,,,, -2000,2.5756795,2.139079,,,,,,,,,,,,,, -2100,2.606976,2.1484218,,,,,,,,,,,,,, -2200,2.4134338,2.1612408,,,,,,,,,,,,,, -2300,3.2076151,2.155547,,,,,,,,,,,,,, -2400,2.9048028,2.161336,,,,,,,,,,,,,, -2500,2.8748124,2.2174892,,,,,,,,,,,,,, -2600,2.9985373,2.1462615,,,,,,,,,,,,,, -2700,2.265158,2.146223,,,,,,,,,,,,,, -2800,2.7730699,2.2158782,,,,,,,,,,,,,, -2900,2.5874665,2.0791306,,,,,,,,,,,,,, -3000,4.2925506,2.1161716,,,,,,,,,,,,,, -3100,2.328296,2.0975628,,,,,,,,,,,,,, -3200,2.9015594,2.105783,,,,,,,,,,,,,, -3300,3.6202428,2.106225,,,,,,,,,,,,,, -3400,3.899725,2.173953,,,,,,,,,,,,,, -3500,1.9632603,2.084718,,,,,,,,,,,,,, -3600,3.2614646,2.11952,,,,,,,,,,,,,, -3700,2.3742063,1.9985679,,,,,,,,,,,,,, -3736,,,1.0180577,0.3027155629291672,1.064015,0.2916574142908175,5348.0,0.7128284,0.2208681169134523,2472.0,2896.190999746322,3336.9190402030945,2896.190999746322,440.4767470359802,0.0824496746063232,0.0 -3800,3.4248717,2.0694554,,,,,,,,,,,,,, -3900,4.192684,2.0303657,,,,,,,,,,,,,, -4000,3.313357,2.0541804,,,,,,,,,,,,,, -4100,3.8359668,1.9805275,,,,,,,,,,,,,, -4200,4.312417,2.073807,,,,,,,,,,,,,, -4300,3.359871,2.0482638,,,,,,,,,,,,,, -4400,4.272013,2.0825748,,,,,,,,,,,,,, -4500,2.6972034,2.0940115,,,,,,,,,,,,,, -4600,2.8963664,2.0319414,,,,,,,,,,,,,, -4700,2.780874,2.0582645,,,,,,,,,,,,,, -4800,2.8333597,2.0288656,,,,,,,,,,,,,, -4900,5.3776903,2.0402582,,,,,,,,,,,,,, -5000,2.4699655,2.0818682,,,,,,,,,,,,,, -5100,2.8024101,1.9870381,,,,,,,,,,,,,, -5200,2.754062,2.024565,,,,,,,,,,,,,, -5300,3.2122695,1.9938073,,,,,,,,,,,,,, -5400,5.1593313,2.0528905,,,,,,,,,,,,,, -5500,3.3096318,1.9410414,,,,,,,,,,,,,, -5600,2.8930516,2.0138903,,,,,,,,,,,,,, -5605,,,0.8513874,0.2599757026892705,0.9749225,0.2722998349054327,5348.0,0.6446429,0.1994597119818008,2472.0,4336.195417165756,4907.917762994766,4336.195417165756,571.3350307941437,0.1324553489685058,0.0 -5700,3.162447,2.0260072,,,,,,,,,,,,,, -5800,2.3473082,1.9469999,,,,,,,,,,,,,, -5900,4.86016,2.1975453,,,,,,,,,,,,,, -6000,4.03227,2.0602138,,,,,,,,,,,,,, -6100,2.9007974,2.9852583,,,,,,,,,,,,,, -6200,2.5030978,2.0945258,,,,,,,,,,,,,, -6300,2.6152334,2.0091143,,,,,,,,,,,,,, -6400,1.852979,2.121009,,,,,,,,,,,,,, -6500,1.9105072,2.0360036,,,,,,,,,,,,,, -6600,3.318472,2.0120962,,,,,,,,,,,,,, -6700,3.172434,2.0374854,,,,,,,,,,,,,, -6800,4.5438576,1.9800771,,,,,,,,,,,,,, -6900,3.6505554,1.9582407,,,,,,,,,,,,,, -7000,4.5426903,1.9615184,,,,,,,,,,,,,, -7100,4.555605,2.0013406,,,,,,,,,,,,,, -7200,3.2534082,2.0303922,,,,,,,,,,,,,, -7300,3.077942,1.9759678,,,,,,,,,,,,,, -7400,5.899155,1.9915137,,,,,,,,,,,,,, -7472,,,0.8907499,0.2634945591928149,0.97102195,0.2661498209061857,5348.0,0.6244865,0.1915381959254971,2472.0,5776.223026514053,6478.2153515815735,5776.223026514053,701.4711322784424,0.1807506084442138,0.0 -7500,4.033106,1.9566784,,,,,,,,,,,,,, -7600,3.1090055,1.9046273,,,,,,,,,,,,,, -7700,3.8848262,1.9422774,,,,,,,,,,,,,, -7800,2.5309596,1.954315,,,,,,,,,,,,,, -7900,3.4335594,1.9744701,,,,,,,,,,,,,, -8000,2.094106,1.9884642,,,,,,,,,,,,,, -8100,2.009635,1.9560281,,,,,,,,,,,,,, -8200,4.629459,1.8684502,,,,,,,,,,,,,, -8300,2.2262366,1.8776187,,,,,,,,,,,,,, -8400,2.3593116,1.9051846,,,,,,,,,,,,,, -8500,3.421912,1.8453041,,,,,,,,,,,,,, -8600,3.7629588,1.9940926,,,,,,,,,,,,,, -8700,2.6499147,2.0416596,,,,,,,,,,,,,, -8800,3.1045249,1.8721411,,,,,,,,,,,,,, -8900,3.1781046,1.840653,,,,,,,,,,,,,, -9000,2.4384377,1.9322294,,,,,,,,,,,,,, -9100,3.4908752,1.9637588,,,,,,,,,,,,,, -9200,2.2548845,1.9936152,,,,,,,,,,,,,, -9300,3.3851745,1.8350189,,,,,,,,,,,,,, -9327,,,0.8192644,0.243019991015274,0.89382845,0.2470625718064821,5348.0,0.57379514,0.1766904312148355,2472.0,7216.673100471497,8048.553774595261,7216.673100471497,831.224268913269,0.2314159870147705,0.0 -9400,4.3691316,1.9300504,,,,,,,,,,,,,, -9500,3.455959,1.9012105,,,,,,,,,,,,,, -9600,2.035026,1.8601848,,,,,,,,,,,,,, -9700,2.980479,1.9562624,,,,,,,,,,,,,, -9800,2.4756436,1.8705006,,,,,,,,,,,,,, -9900,1.9848266,1.8510431,,,,,,,,,,,,,, -10000,2.4186146,1.8745108,,,,,,,,,,,,,, -10100,4.2333403,1.8556246,,,,,,,,,,,,,, -10200,4.115133,1.8660597,,,,,,,,,,,,,, -10300,2.4362686,1.876783,,,,,,,,,,,,,, -10400,4.4592986,1.9404185,,,,,,,,,,,,,, -10500,2.8208992,1.8746549,,,,,,,,,,,,,, -10600,2.5062678,1.8835528,,,,,,,,,,,,,, -10700,3.2962425,1.89064,,,,,,,,,,,,,, -10800,3.3371572,1.9052867,,,,,,,,,,,,,, -10900,3.8313315,1.7727915,,,,,,,,,,,,,, -11000,4.1214566,1.9309185,,,,,,,,,,,,,, -11100,2.6807806,1.8905176,,,,,,,,,,,,,, -11200,2.1265912,1.8657237,,,,,,,,,,,,,, -11202,,,0.77553225,0.2367157819791857,0.8555793,0.2415401102561379,5348.0,0.54874396,0.1715109784087908,2472.0,8657.137380123138,9620.26033258438,8657.137380123138,962.327393770218,0.2845883369445801,0.0 -11300,6.33875,1.8647983,,,,,,,,,,,,,, -11400,2.0215347,1.9545947,,,,,,,,,,,,,, -11500,2.0989826,1.8319385,,,,,,,,,,,,,, -11600,2.3026588,1.784651,,,,,,,,,,,,,, -11700,4.959821,1.7686601,,,,,,,,,,,,,, -11800,4.1056743,1.86542,,,,,,,,,,,,,, -11900,1.7103857,1.8952332,,,,,,,,,,,,,, -12000,4.2939067,1.8493944,,,,,,,,,,,,,, -12100,2.4461756,1.853953,,,,,,,,,,,,,, -12200,4.951642,1.8499284,,,,,,,,,,,,,, -12300,2.5218391,1.8841981,,,,,,,,,,,,,, -12400,2.6794014,1.828579,,,,,,,,,,,,,, -12500,4.118846,1.8822178,,,,,,,,,,,,,, -12600,2.4527068,1.8795476,,,,,,,,,,,,,, -12700,2.1501417,1.8235875,,,,,,,,,,,,,, -12800,3.0315068,1.8779513,,,,,,,,,,,,,, -12900,2.0818782,1.8514336,,,,,,,,,,,,,, -13000,4.2340145,1.785016,,,,,,,,,,,,,, -13074,,,0.687591,0.2090459084094702,0.8205807,0.2295779951147455,5348.0,0.5158074,0.1610708264781752,2472.0,10097.43169784546,11191.052989721298,10097.43169784546,1092.690021276474,0.3334639072418213,0.0 -13100,2.4772785,1.8452414,,,,,,,,,,,,,, -13200,3.8762343,1.7701594,,,,,,,,,,,,,, -13300,2.8173294,1.870097,,,,,,,,,,,,,, -13400,1.5132339,1.7920176,,,,,,,,,,,,,, -13500,2.5127203,1.8380208,,,,,,,,,,,,,, -13600,2.5098894,1.7997038,,,,,,,,,,,,,, -13700,2.6904135,1.914872,,,,,,,,,,,,,, -13800,2.7258804,1.7380786,,,,,,,,,,,,,, -13900,2.6785357,1.8272003,,,,,,,,,,,,,, -14000,2.6398032,1.7333369,,,,,,,,,,,,,, -14100,2.0586755,1.7651445,,,,,,,,,,,,,, -14200,2.2423573,1.7970889,,,,,,,,,,,,,, -14300,1.7796319,1.759606,,,,,,,,,,,,,, -14400,2.6165988,1.8116392,,,,,,,,,,,,,, -14500,2.4639535,1.8296741,,,,,,,,,,,,,, -14600,3.293213,1.72166,,,,,,,,,,,,,, -14700,2.6785603,1.7732341,,,,,,,,,,,,,, -14800,2.7277803,1.714175,,,,,,,,,,,,,, -14900,2.6251302,1.8055775,,,,,,,,,,,,,, -14952,,,0.63921106,0.1976266962376962,0.79561263,0.224673431360244,5348.0,0.49638095,0.1571303800296549,2472.0,11537.74842619896,12762.179047107697,11537.74842619896,1223.362502336502,0.3845009803771972,0.0 -15000,2.1026168,1.786244,,,,,,,,,,,,,, -15100,2.011616,1.7998711,,,,,,,,,,,,,, -15200,3.8916097,1.7595292,,,,,,,,,,,,,, -15300,2.2790086,1.7697138,,,,,,,,,,,,,, -15400,2.148046,1.8184698,,,,,,,,,,,,,, -15500,2.6371791,1.7855818,,,,,,,,,,,,,, -15600,3.5262907,1.7480576,,,,,,,,,,,,,, -15700,1.9236618,1.6894847,,,,,,,,,,,,,, -15800,2.23138,1.7657907,,,,,,,,,,,,,, -15900,2.5353136,1.7934915,,,,,,,,,,,,,, -16000,2.7865732,1.8073996,,,,,,,,,,,,,, -16100,2.4791517,1.7538838,,,,,,,,,,,,,, -16200,2.8003185,1.7564515,,,,,,,,,,,,,, -16300,3.326947,1.7989433,,,,,,,,,,,,,, -16400,2.5646617,1.7892225,,,,,,,,,,,,,, -16500,2.6652105,1.7938675,,,,,,,,,,,,,, -16600,2.8031359,1.736305,,,,,,,,,,,,,, -16700,4.322357,1.7319598,,,,,,,,,,,,,, -16800,3.1413157,1.754501,,,,,,,,,,,,,, -16828,,,0.7159693,0.2213396801517847,0.7896422,0.2200198885853037,5348.0,0.49312958,0.1551601568053947,2472.0,12977.63620853424,14333.092150449753,12977.63620853424,1354.249887228012,0.4348087310791015,0.0 -16900,2.1441362,1.77529,,,,,,,,,,,,,, -17000,2.797697,1.7095095,,,,,,,,,,,,,, -17100,1.9651322,1.7168477,,,,,,,,,,,,,, -17200,1.9226518,1.8013563,,,,,,,,,,,,,, -17300,2.9793622,1.716644,,,,,,,,,,,,,, -17400,1.9251335,1.7296664,,,,,,,,,,,,,, -17500,1.9168899,1.719571,,,,,,,,,,,,,, -17600,2.5611498,1.6980679,,,,,,,,,,,,,, -17700,3.1567047,1.6665269,,,,,,,,,,,,,, -17800,2.7761807,1.7395227,,,,,,,,,,,,,, -17900,2.6536815,1.7790061,,,,,,,,,,,,,, -18000,1.7362542,1.7044365,,,,,,,,,,,,,, -18100,2.4870868,1.7414573,,,,,,,,,,,,,, -18200,1.6495303,1.6810598,,,,,,,,,,,,,, -18300,2.6160095,1.7301047,,,,,,,,,,,,,, -18400,1.6465428,1.7350398,,,,,,,,,,,,,, -18500,4.339133,1.7401094,,,,,,,,,,,,,, -18600,4.0876355,1.7093064,,,,,,,,,,,,,, -18695,,,0.6039645,0.1899518739973749,0.74285376,0.2117458509128474,5348.0,0.45597544,0.1456949606970934,2472.0,14418.066945314407,15904.111874103546,14418.066945314407,1484.704974651337,0.4846115112304687,0.0 -18700,1.862966,1.6825786,,,,,,,,,,,,,, -18800,4.0202675,1.6591713,,,,,,,,,,,,,, -18900,2.5237901,1.7585344,,,,,,,,,,,,,, -19000,4.0807657,1.6837149,,,,,,,,,,,,,, -19100,2.7733235,1.6419705,,,,,,,,,,,,,, -19200,3.105137,1.746102,,,,,,,,,,,,,, -19300,2.1295843,1.6169705,,,,,,,,,,,,,, -19400,1.7818983,1.6572919,,,,,,,,,,,,,, -19500,2.2731123,1.7309505,,,,,,,,,,,,,, -19600,2.0147264,1.7311175,,,,,,,,,,,,,, -19700,2.3608606,1.6582881,,,,,,,,,,,,,, -19800,4.1272388,1.6805764,,,,,,,,,,,,,, -19900,3.1093874,1.643548,,,,,,,,,,,,,, -20000,2.8557386,1.6507229,,,,,,,,,,,,,, -20100,3.0851607,1.6378286,,,,,,,,,,,,,, -20200,2.7678266,1.6729093,,,,,,,,,,,,,, -20300,1.618383,1.6678828,,,,,,,,,,,,,, -20400,2.9612486,1.7520777,,,,,,,,,,,,,, -20500,2.4455864,1.6732858,,,,,,,,,,,,,, -20562,,,0.6238497,0.1896832201191496,0.70414096,0.2015215733222626,5348.0,0.43188784,0.1362703877480551,2472.0,15858.226209878922,17475.754253149033,15858.226209878922,1616.0522694587708,0.5335447788238525,0.0 -20600,2.2117348,1.6297655,,,,,,,,,,,,,, -20700,2.7098129,1.6499041,,,,,,,,,,,,,, -20800,1.9896488,1.5969695,,,,,,,,,,,,,, -20900,4.6987367,1.6857277,,,,,,,,,,,,,, -21000,3.0744708,1.6642233,,,,,,,,,,,,,, -21100,2.303685,1.7154105,,,,,,,,,,,,,, -21200,2.935488,1.6937186,,,,,,,,,,,,,, -21300,2.2128808,1.6396735,,,,,,,,,,,,,, -21400,2.436888,1.6631995,,,,,,,,,,,,,, -21500,2.8470445,1.6215763,,,,,,,,,,,,,, -21600,2.8805535,1.731376,,,,,,,,,,,,,, -21700,2.2718706,1.662362,,,,,,,,,,,,,, -21800,3.0585282,1.6529742,,,,,,,,,,,,,, -21900,1.797376,1.6450918,,,,,,,,,,,,,, -22000,2.9539614,1.667869,,,,,,,,,,,,,, -22100,1.6323379,1.6860899,,,,,,,,,,,,,, -22200,2.216319,1.6282493,,,,,,,,,,,,,, -22300,3.3640382,1.6307416,,,,,,,,,,,,,, -22400,3.8347597,1.7035396,,,,,,,,,,,,,, -22428,,,0.5352325,0.1693792467028832,0.69400257,0.1981327900981878,5348.0,0.41050902,0.1309690654642211,2472.0,17298.253359794617,19046.85565328598,17298.253359794617,1746.9885816574097,0.5863668918609619,0.0 -22500,2.1917922,1.6378021,,,,,,,,,,,,,, -22600,2.5995448,1.5635103,,,,,,,,,,,,,, -22700,2.0620127,1.5760912,,,,,,,,,,,,,, -22800,2.9057684,1.664867,,,,,,,,,,,,,, -22900,2.2544973,1.6681626,,,,,,,,,,,,,, -23000,2.206677,1.6468787,,,,,,,,,,,,,, -23100,1.7960582,1.6002141,,,,,,,,,,,,,, -23200,1.5942003,1.5745225,,,,,,,,,,,,,, -23300,2.294855,1.6329638,,,,,,,,,,,,,, -23400,1.8136108,1.5854988,,,,,,,,,,,,,, -23500,2.66548,1.6279402,,,,,,,,,,,,,, -23600,2.4261742,1.5852739,,,,,,,,,,,,,, -23700,1.7867656,1.5312396,,,,,,,,,,,,,, -23800,2.5891573,1.6533749,,,,,,,,,,,,,, -23900,2.0057113,1.5929497,,,,,,,,,,,,,, -24000,1.7896525,1.5774958,,,,,,,,,,,,,, -24100,3.9154854,1.5388798,,,,,,,,,,,,,, -24200,2.0178747,1.5856075,,,,,,,,,,,,,, -24300,2.4627833,1.6279061,,,,,,,,,,,,,, -24303,,,0.5379182,0.1729687250859889,0.6751597,0.1924075808335827,5348.0,0.40436313,0.1286129222269615,2472.0,18738.440461158752,20616.59484243393,18738.440461158752,1876.3985664844515,0.6419112682342529,0.0 -24400,1.7218118,1.6151898,,,,,,,,,,,,,, -24500,2.5628135,1.6164975,,,,,,,,,,,,,, -24600,3.1111221,1.6290329,,,,,,,,,,,,,, -24700,2.3581016,1.5890684,,,,,,,,,,,,,, -24800,3.8174624,1.5473932,,,,,,,,,,,,,, -24900,1.6390427,1.5856676,,,,,,,,,,,,,, -25000,1.97245,1.6538135,,,,,,,,,,,,,, -25100,1.9546874,1.5226202,,,,,,,,,,,,,, -25200,2.2257926,1.6107221,,,,,,,,,,,,,, -25300,2.1938524,1.5041163,,,,,,,,,,,,,, -25400,2.783727,1.5958439,,,,,,,,,,,,,, -25500,2.0905602,1.5265669,,,,,,,,,,,,,, -25600,2.1048708,1.5289412,,,,,,,,,,,,,, -25700,2.169258,1.5288694,,,,,,,,,,,,,, -25800,2.1819575,1.6206025,,,,,,,,,,,,,, -25900,3.3406096,1.5798799,,,,,,,,,,,,,, -26000,2.715909,1.6060001,,,,,,,,,,,,,, -26100,2.1208003,1.5804789,,,,,,,,,,,,,, -26186,,,0.5195341,0.1666871597679435,0.64813316,0.1873002693648204,5348.0,0.38142642,0.1244693599821258,2472.0,20178.6946849823,22187.284984588623,20178.6946849823,2006.694060087204,0.6964304447174072,0.0 -26200,2.1314309,1.5608857,,,,,,,,,,,,,, -26300,1.988251,1.5459056,,,,,,,,,,,,,, -26400,2.8382792,1.5225924,,,,,,,,,,,,,, -26500,2.7181394,1.5821904,,,,,,,,,,,,,, -26600,1.9127979,1.5774258,,,,,,,,,,,,,, -26700,3.7296917,1.584903,,,,,,,,,,,,,, -26800,1.9174781,1.5464101,,,,,,,,,,,,,, -26900,1.4474536,1.5224043,,,,,,,,,,,,,, -27000,1.9454246,1.5516859,,,,,,,,,,,,,, -27100,3.3264017,1.498344,,,,,,,,,,,,,, -27200,2.4256122,1.5164864,,,,,,,,,,,,,, -27300,3.2838738,1.5596045,,,,,,,,,,,,,, -27400,2.5592544,1.5878445,,,,,,,,,,,,,, -27500,1.6171665,1.5376712,,,,,,,,,,,,,, -27600,1.9443399,1.5836538,,,,,,,,,,,,,, -27700,1.6767622,1.4388075,,,,,,,,,,,,,, -27800,1.650124,1.5696574,,,,,,,,,,,,,, -27900,2.4029212,1.485382,,,,,,,,,,,,,, -28000,2.1790576,1.5095955,,,,,,,,,,,,,, -28066,,,0.31689537,0.106635261368807,0.6243423,0.1771532289987159,5348.0,0.3617607,0.1158978733776125,2472.0,21618.925753593445,23765.90765786171,21618.925753593445,2144.9430723190308,0.751788854598999,0.0 -28100,2.017218,1.5247363,,,,,,,,,,,,,, -28200,2.803313,1.5028155,,,,,,,,,,,,,, -28300,1.8070697,1.5010605,,,,,,,,,,,,,, -28400,1.6126071,1.5158578,,,,,,,,,,,,,, -28500,2.130237,1.5344197,,,,,,,,,,,,,, -28600,2.1107163,1.5283719,,,,,,,,,,,,,, -28700,3.09141,1.6364881,,,,,,,,,,,,,, -28800,1.8189465,1.5085943,,,,,,,,,,,,,, -28900,2.624939,1.5018528,,,,,,,,,,,,,, -29000,3.9401278,1.5462092,,,,,,,,,,,,,, -29100,2.2115233,1.5140873,,,,,,,,,,,,,, -29200,3.162785,1.4807612,,,,,,,,,,,,,, -29300,2.0938876,1.5068998,,,,,,,,,,,,,, -29400,1.4344373,1.534004,,,,,,,,,,,,,, -29500,4.142569,1.5337548,,,,,,,,,,,,,, -29600,2.7583075,1.4925364,,,,,,,,,,,,,, -29700,2.5219762,1.5338603,,,,,,,,,,,,,, -29800,1.8375734,1.4241499,,,,,,,,,,,,,, -29900,1.7351713,1.529066,,,,,,,,,,,,,, -29931,,,0.28561276,0.0947428783980657,0.57865715,0.1673827201019531,5348.0,0.33669013,0.1090731826214124,2472.0,23059.09961295128,25337.965800762177,23059.09961295128,2276.684823989868,0.8093042373657227,0.0 -30000,2.8621666,1.4537859,,,,,,,,,,,,,, -30100,2.454988,1.4991151,,,,,,,,,,,,,, -30200,3.0824175,1.4674032,,,,,,,,,,,,,, -30300,1.7246366,1.4726247,,,,,,,,,,,,,, -30400,2.02304,1.497907,,,,,,,,,,,,,, -30500,2.8368125,1.4391507,,,,,,,,,,,,,, -30600,1.7531117,1.4839844,,,,,,,,,,,,,, -30700,2.1879368,1.4951484,,,,,,,,,,,,,, -30800,3.88295,1.5223446,,,,,,,,,,,,,, -30900,2.206653,1.417207,,,,,,,,,,,,,, -31000,2.4546268,1.5070223,,,,,,,,,,,,,, -31100,1.667676,1.4614449,,,,,,,,,,,,,, -31200,3.4355562,1.4763261,,,,,,,,,,,,,, -31300,1.8825788,1.4767638,,,,,,,,,,,,,, -31400,3.6013708,1.515841,,,,,,,,,,,,,, -31500,1.7604653,1.4908168,,,,,,,,,,,,,, -31600,1.8783342,1.4413939,,,,,,,,,,,,,, -31700,2.617161,1.4756335,,,,,,,,,,,,,, -31800,2.6378398,1.4534501,,,,,,,,,,,,,, -31803,,,0.27346468,0.0927704951650813,0.5726918,0.1643415043880398,5348.0,0.32298186,0.1034671866431052,2472.0,24499.12723898888,26910.178052663803,24499.12723898888,2408.727742433548,0.86600661277771,0.0 -31900,1.5967785,1.4574453,,,,,,,,,,,,,, -32000,2.3255002,1.520211,,,,,,,,,,,,,, -32100,2.6465425,1.3849831,,,,,,,,,,,,,, -32200,2.1471927,1.4357487,,,,,,,,,,,,,, -32300,1.7819006,1.4552339,,,,,,,,,,,,,, -32400,1.6304705,1.3924751,,,,,,,,,,,,,, -32500,2.5121179,1.4494171,,,,,,,,,,,,,, -32600,1.7471563,1.4050903,,,,,,,,,,,,,, -32700,2.124783,1.3705364,,,,,,,,,,,,,, -32800,1.8191054,1.3862259,,,,,,,,,,,,,, -32900,3.538689,1.4490484,,,,,,,,,,,,,, -33000,2.0560641,1.4247962,,,,,,,,,,,,,, -33100,1.4346817,1.3622991,,,,,,,,,,,,,, -33200,2.415148,1.427782,,,,,,,,,,,,,, -33300,1.910236,1.4553335,,,,,,,,,,,,,, -33400,3.8761017,1.4048667,,,,,,,,,,,,,, -33500,2.3131871,1.4199079,,,,,,,,,,,,,, -33600,1.8875281,1.4072368,,,,,,,,,,,,,, -33678,,,0.24806608,0.0840600825978491,0.53813505,0.1538275872056537,5348.0,0.306992,0.0976377632888509,2472.0,25939.097049236298,28482.05617165565,25939.097049236298,2540.49427819252,0.921947717666626,0.0 -33700,3.3962936,1.4067852,,,,,,,,,,,,,, -33800,1.7786707,1.4954382,,,,,,,,,,,,,, -33900,2.2184286,1.3881149,,,,,,,,,,,,,, -34000,2.4885497,1.3589474,,,,,,,,,,,,,, -34100,3.281699,1.3940336,,,,,,,,,,,,,, -34200,2.3515487,1.384641,,,,,,,,,,,,,, -34300,2.015223,1.3998218,,,,,,,,,,,,,, -34400,2.7073913,1.3813775,,,,,,,,,,,,,, -34500,2.3606164,1.4079865,,,,,,,,,,,,,, -34600,1.6360943,1.3226067,,,,,,,,,,,,,, -34700,2.5900166,1.3437355,,,,,,,,,,,,,, -34800,3.0753827,1.432637,,,,,,,,,,,,,, -34900,3.3098655,1.3301066,,,,,,,,,,,,,, -35000,2.077057,1.3481911,,,,,,,,,,,,,, -35100,1.8162571,1.3210446,,,,,,,,,,,,,, -35200,2.5759323,1.3139614,,,,,,,,,,,,,, -35300,1.7803555,1.317189,,,,,,,,,,,,,, -35400,2.4598632,1.3588129,,,,,,,,,,,,,, -35500,2.8146634,1.3330442,,,,,,,,,,,,,, -35561,,,0.23072459,0.0780527067653141,0.5039565,0.1447715226353341,5348.0,0.28383914,0.0896553124936526,2472.0,27379.475867271423,30054.61894655228,27379.475867271423,2672.538813829422,0.9750196933746338,0.0 -35600,4.134631,1.4244053,,,,,,,,,,,,,, -35700,3.4054332,1.3332894,,,,,,,,,,,,,, -35800,1.9404902,1.30898,,,,,,,,,,,,,, -35900,2.1608043,1.4085575,,,,,,,,,,,,,, -36000,4.67159,1.342751,,,,,,,,,,,,,, -36100,2.9095623,1.3753338,,,,,,,,,,,,,, -36200,2.4937465,1.3012257,,,,,,,,,,,,,, -36300,1.7189387,1.375031,,,,,,,,,,,,,, -36400,1.8164463,1.2754803,,,,,,,,,,,,,, -36500,1.5957499,1.3519297,,,,,,,,,,,,,, -36600,2.771053,1.3434936,,,,,,,,,,,,,, -36700,3.3069959,1.292899,,,,,,,,,,,,,, -36800,1.7094413,1.285143,,,,,,,,,,,,,, -36900,2.265654,1.3462013,,,,,,,,,,,,,, -37000,1.8410301,1.3387069,,,,,,,,,,,,,, -37100,1.9839499,1.294105,,,,,,,,,,,,,, -37200,2.2638533,1.3116236,,,,,,,,,,,,,, -37300,2.918899,1.3302345,,,,,,,,,,,,,, -37400,1.8569387,1.360108,,,,,,,,,,,,,, -37440,,,0.20172007,0.0698390438583079,0.47738373,0.138592544676907,5348.0,0.26603636,0.0848617797006073,2472.0,28820.02853703499,31627.78924536705,28820.02853703499,2805.0091433525085,1.0354893207550049,0.0 -37500,2.7335799,1.2983283,,,,,,,,,,,,,, -37600,2.8258634,1.301235,,,,,,,,,,,,,, -37700,4.33758,1.3122042,,,,,,,,,,,,,, -37800,2.5716648,1.3293936,,,,,,,,,,,,,, -37900,1.8068084,1.260801,,,,,,,,,,,,,, -38000,2.0760837,1.282563,,,,,,,,,,,,,, -38100,3.3351805,1.2713871,,,,,,,,,,,,,, -38200,2.738743,1.2407678,,,,,,,,,,,,,, -38300,2.7828174,1.32124,,,,,,,,,,,,,, -38400,2.1682186,1.305762,,,,,,,,,,,,,, -38500,2.4175522,1.3221196,,,,,,,,,,,,,, -38600,2.4316144,1.2485207,,,,,,,,,,,,,, -38700,2.7542753,1.25604,,,,,,,,,,,,,, -38800,2.6114068,1.3113866,,,,,,,,,,,,,, -38900,2.8655787,1.283389,,,,,,,,,,,,,, -39000,2.0047824,1.2770375,,,,,,,,,,,,,, -39100,1.5790665,1.249863,,,,,,,,,,,,,, -39200,2.4644647,1.2518225,,,,,,,,,,,,,, -39300,2.3621786,1.2555151,,,,,,,,,,,,,, -39309,,,0.22671203,0.073330542067458,0.45624575,0.1313032816165751,5348.0,0.25134808,0.0799260658501411,2472.0,30260.46949505806,33199.82506918907,30260.46949505806,2936.458508014679,1.0954077243804932,0.0 -39400,1.3348777,1.1686041,,,,,,,,,,,,,, -39500,2.6345415,1.26841,,,,,,,,,,,,,, -39600,2.1856015,1.2587054,,,,,,,,,,,,,, -39700,10.157481,1.2853878,,,,,,,,,,,,,, -39800,2.1176007,1.2302344,,,,,,,,,,,,,, -39900,2.2903507,1.303578,,,,,,,,,,,,,, -40000,1.7700684,1.2285076,,,,,,,,,,,,,, -40100,2.892406,1.2574054,,,,,,,,,,,,,, -40200,3.6600878,1.2794636,,,,,,,,,,,,,, -40300,1.6601069,1.2266835,,,,,,,,,,,,,, -40400,2.2342725,1.2533398,,,,,,,,,,,,,, -40500,1.9134811,1.2118341,,,,,,,,,,,,,, -40600,2.1863973,1.2395847,,,,,,,,,,,,,, -40700,4.211133,1.2703965,,,,,,,,,,,,,, -40800,2.4929817,1.1736405,,,,,,,,,,,,,, -40900,1.8315485,1.1930616,,,,,,,,,,,,,, -41000,2.045559,1.1684974,,,,,,,,,,,,,, -41100,2.3002646,1.2321911,,,,,,,,,,,,,, -41183,,,0.17770247,0.0616869580819994,0.4342081,0.1249408652500072,5348.0,0.23483199,0.0743810046107285,2472.0,31700.723588705063,34773.198744535446,31700.723588705063,3069.434670448303,1.153717041015625,0.0 -41200,2.269231,1.1986129,,,,,,,,,,,,,, -41300,1.260011,1.1816219,,,,,,,,,,,,,, -41400,2.4685261,1.1873415,,,,,,,,,,,,,, -41500,2.2041187,1.2524171,,,,,,,,,,,,,, -41600,2.2167416,1.2759874,,,,,,,,,,,,,, -41700,2.136892,1.1574996,,,,,,,,,,,,,, -41800,2.7444434,1.213587,,,,,,,,,,,,,, -41900,1.9473028,1.2777332,,,,,,,,,,,,,, -42000,2.2089074,1.1668272,,,,,,,,,,,,,, -42100,1.5243684,1.2682258,,,,,,,,,,,,,, -42200,2.0038478,1.2550089,,,,,,,,,,,,,, -42300,2.8003643,1.209873,,,,,,,,,,,,,, -42400,2.0723855,1.218901,,,,,,,,,,,,,, -42500,2.9130952,1.135844,,,,,,,,,,,,,, -42600,3.3642187,1.2379903,,,,,,,,,,,,,, -42700,1.7302805,1.1819326,,,,,,,,,,,,,, -42800,2.9970424,1.1086099,,,,,,,,,,,,,, -42900,2.3836558,1.2065006,,,,,,,,,,,,,, -43000,2.7599092,1.195696,,,,,,,,,,,,,, -43048,,,0.15259655,0.0528985621285323,0.41574052,0.1196887339853442,5348.0,0.2274025,0.0706842971177868,2472.0,33141.31151199341,36345.68494963646,33141.31151199341,3201.1860456466675,1.2138450145721436,0.0 -43100,1.8818316,1.1901063,,,,,,,,,,,,,, -43200,3.173349,1.2035953,,,,,,,,,,,,,, -43300,2.5136604,1.1589961,,,,,,,,,,,,,, -43400,1.6213735,1.1384151,,,,,,,,,,,,,, -43500,1.7300395,1.2123926,,,,,,,,,,,,,, -43600,5.0217175,1.1599962,,,,,,,,,,,,,, -43700,2.5769644,1.1930907,,,,,,,,,,,,,, -43800,2.4613228,1.1846852,,,,,,,,,,,,,, -43900,2.5380466,1.2110752,,,,,,,,,,,,,, -44000,2.14563,1.1965268,,,,,,,,,,,,,, -44100,2.2978687,1.1522682,,,,,,,,,,,,,, -44200,2.7787344,1.2527636,,,,,,,,,,,,,, -44300,1.9613785,1.1862493,,,,,,,,,,,,,, -44400,2.5243635,1.1591793,,,,,,,,,,,,,, -44500,1.7645475,1.1587583,,,,,,,,,,,,,, -44600,2.076214,1.1705941,,,,,,,,,,,,,, -44700,2.0307117,1.1156472,,,,,,,,,,,,,, -44800,3.2971692,1.1283257,,,,,,,,,,,,,, -44900,2.0891173,1.118209,,,,,,,,,,,,,, -44925,,,0.14536211,0.0501873654556568,0.40926024,0.1170337043938326,5348.0,0.21977012,0.0682875307212641,2472.0,34581.25995087624,37918.101712465286,34581.25995087624,3333.5102150440216,1.2725615501403809,0.0 -45000,1.7961788,1.1464621,,,,,,,,,,,,,, -45100,2.4509873,1.1490145,,,,,,,,,,,,,, -45200,3.7567532,1.1722975,,,,,,,,,,,,,, -45300,2.4004388,1.1554685,,,,,,,,,,,,,, -45400,1.984331,1.1458095,,,,,,,,,,,,,, -45500,1.89505,1.1098089,,,,,,,,,,,,,, -45600,2.1729064,1.182112,,,,,,,,,,,,,, -45700,2.023393,1.1160107,,,,,,,,,,,,,, -45800,2.5531693,1.1140608,,,,,,,,,,,,,, -45900,2.2999296,1.1824987,,,,,,,,,,,,,, -46000,1.6407354,1.121028,,,,,,,,,,,,,, -46100,2.6307778,1.1937592,,,,,,,,,,,,,, -46200,2.6890047,1.1556636,,,,,,,,,,,,,, -46300,1.4717804,1.138803,,,,,,,,,,,,,, -46400,2.400062,1.1583382,,,,,,,,,,,,,, -46500,1.502376,1.1605422,,,,,,,,,,,,,, -46600,2.574388,1.1791128,,,,,,,,,,,,,, -46700,2.0768871,1.1570276,,,,,,,,,,,,,, -46800,1.9512554,1.0978038,,,,,,,,,,,,,, -46805,,,0.13815887,0.048562661498708,0.40407094,0.1158075634552072,5348.0,0.21717538,0.0679422338675278,2472.0,36021.1313123703,39488.77329707146,36021.1313123703,3464.165018558502,1.3316552639007568,0.0 -46900,2.042996,1.1320921,,,,,,,,,,,,,, -47000,3.283246,1.1098038,,,,,,,,,,,,,, -47100,2.1844885,1.157654,,,,,,,,,,,,,, -47200,2.8126957,1.1276473,,,,,,,,,,,,,, -47300,2.7233777,1.1493998,,,,,,,,,,,,,, -47400,2.3656938,1.1106021,,,,,,,,,,,,,, -47500,2.785307,1.1079067,,,,,,,,,,,,,, -47600,2.591249,1.1270345,,,,,,,,,,,,,, -47700,1.7138162,1.1442529,,,,,,,,,,,,,, -47800,2.724315,1.0991757,,,,,,,,,,,,,, -47900,2.3861842,1.1296784,,,,,,,,,,,,,, -48000,,,0.15959443,0.0535628759848632,0.40380514,0.1157399808837869,5348.0,0.21718435,0.0679219222878963,2472.0,36923.73557591438,40524.21762919426,36923.73557591438,3596.892227172852,1.3875088691711426,0.0 -48000,,,,,,,,,,,36923.73557591438,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/eval_measurements.csv deleted file mode 100644 index abcfcdd32..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,28 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -186.00546002388,0.0,16.372990608215332,1,0,16.372990608215332,30.536512,2472,3.361485182702659,202.37854552268985,30.62459,3.3513766830074823,30.351221,5348,3.0459271846066214 -315.22576928138733,0.0285015106201171,1456.8387684822085,1857,0,1456.8387684822085,2.1457593,2472,0.4915808502427233,1772.174908161163,2.3449352,0.5320623826954385,2.671262,5348,0.5683211523793893 -450.2640874385834,0.0764284133911132,2897.1221652030945,3727,0,2897.1221652030945,0.6082362,2472,0.1934068612515995,3347.631866931916,0.5966041,0.1932851115758293,0.9358604,5348,0.2626355271923303 -582.7502326965332,0.1263625621795654,4337.322243690491,5606,0,4337.322243690491,0.5005587,2472,0.1601974285540186,4920.453497171402,0.4300289,0.1458097903575012,0.8080621,5348,0.2312289407880128 -714.9384062290192,0.174260139465332,5777.62396812439,7486,0,5777.62396812439,0.44369552,2472,0.1436231795746755,6493.077576160431,0.3892914,0.1317028818944293,0.7440744,5348,0.2108383135252034 -847.2784371376038,0.2258105278015136,7217.692116260529,9352,0,7217.692116260529,0.43267328,2472,0.1383624804501046,8065.623285531998,0.39885035,0.1344804300262013,0.7127357,5348,0.2025642758527472 -978.80659198761,0.2755625247955322,8658.064115285873,11220,0,8658.064115285873,0.39661396,2472,0.1267239453212276,9637.65848660469,0.33331928,0.1113040383505799,0.67308295,5348,0.1915869353234791 -1112.9421019554138,0.3250164985656738,10098.212848901749,13089,0,10098.212848901749,0.3859156,2472,0.1237787662746531,11212.078014612198,0.3510411,0.1183566224770094,0.6683462,5348,0.1900421908338724 -1244.2248899936676,0.3792309761047363,11538.172552585602,14959,0,11538.172552585602,0.36685,2472,0.1178477850222411,12783.459538698196,0.3626199,0.1157342155230644,0.6363704,5348,0.1824150149164389 -1378.1725063323977,0.5102035999298096,12978.370990991592,16835,0,12978.370990991592,0.3457338,2472,0.1115511953364613,14357.825068950651,0.2640011,0.0905968954047061,0.5966437,5348,0.1708680498566284 -1510.10333442688,0.5596368312835693,14418.6655626297,18711,0,14418.6655626297,0.33608237,2472,0.1087278857676761,15930.184390544891,0.26224798,0.0893585368898506,0.5874665,5348,0.168280602836537 -1641.366404056549,0.6074838638305664,15859.262013435364,20585,0,15859.262013435364,0.33556062,2472,0.1091950520992017,17502.176399946213,0.38126385,0.127479454767484,0.5813514,5348,0.1663400175714685 -1773.641833782196,0.6536881923675537,17299.8719329834,22448,0,17299.8719329834,0.32734656,2472,0.1047671277395243,19075.193665504456,0.3882775,0.1288513710578304,0.57250744,5348,0.1641870299390791 -1905.782042503357,0.7064037322998047,18740.242106437683,24324,0,18740.242106437683,0.30448222,2472,0.0971096622184307,20647.843740701675,0.44150472,0.1432519683210627,0.5418794,5348,0.1553240584299603 -2036.0781605243685,0.7543544769287109,20180.723200559616,26201,0,20180.723200559616,0.30177844,2472,0.0966424958869051,22218.75596666336,0.36288148,0.1198465963566634,0.5348055,5348,0.1541751547158152 -2165.942393541336,0.8169291019439697,21621.01096129417,28079,0,21621.01096129417,0.29157224,2472,0.0944082221274348,23789.05851054192,0.37025255,0.1247483536356501,0.52608824,5348,0.1505160412060592 -2297.1110696792603,0.867875337600708,23061.27196407318,29945,0,23061.27196407318,0.279768,2472,0.0905693335770723,25360.624280691147,0.27366316,0.0941657476396696,0.50481427,5348,0.1438832945538102 -2429.268038749695,0.9267137050628662,24501.19046640396,31827,0,24501.19046640396,0.26446828,2472,0.0857961123636585,26932.845405817032,0.31801972,0.1082496797994504,0.48673412,5348,0.1397800670032922 -2561.341689825058,0.980743169784546,25941.62428545952,33708,0,25941.62428545952,0.25827703,2472,0.0828103101578209,28505.49468064308,0.2762536,0.0937875751503006,0.47617844,5348,0.1371153827587205 -2693.789966583252,1.0337035655975342,27381.94872641564,35586,0,27381.94872641564,0.24837476,2472,0.0786464363333536,30078.40672755241,0.2669341,0.0928764359399915,0.4645654,5348,0.132191509698099 -2825.580310821533,1.0857267379760742,28821.92249059677,37464,0,28821.92249059677,0.24016106,2472,0.077326183657303,31650.30978918076,0.24166115,0.082856364131916,0.4427787,5348,0.1273738378211379 -2956.2057723999023,1.1383299827575684,30261.989458084103,39340,0,30261.989458084103,0.23111513,2472,0.0736700993236244,33221.140043735504,0.24197032,0.0847769387592396,0.43660167,5348,0.1253077420662888 -3089.1986536979675,1.1911566257476809,31702.491810321808,41204,0,31702.491810321808,0.22799182,2472,0.0729591940365202,34794.77567911148,0.26820278,0.0889005727581886,0.42843142,5348,0.1224403101074562 -3221.5781757831573,1.2448463439941406,33142.95555186272,43081,0,33142.95555186272,0.22370486,2472,0.0714561371437856,36367.7583668232,0.21494322,0.075831554091704,0.42295602,5348,0.1212045145157708 -3352.5498700141907,1.303229808807373,34583.11553454399,44955,0,34583.11553454399,0.2223216,2472,0.0710092823918916,37939.03490805626,0.22929652,0.0814922255743792,0.4198003,5348,0.1197370072506444 -3485.300974607468,1.3599789142608645,36023.417892456055,46830,0,36023.417892456055,0.22184521,2472,0.0705827392196291,39512.2322010994,0.19979791,0.069397354668894,0.41824093,5348,0.1195632234955636 -3615.8436131477356,1.4163413047790527,36906.64269685745,48000,0,36906.64269685745,0.22146653,2472,0.07098897081226006,40526.11241531372,0.224268,0.07741169150711276,0.4184095,5348,0.1196790793322842 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/measurements.csv deleted file mode 100644 index 85729560a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/measurements.csv +++ /dev/null @@ -1,509 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.422195,33.24812,,,,,,,,,,,,,, -1,,,30.62459,3.3513766830074823,30.351221,3.0459271846066214,5348.0,30.536512,3.361485182702659,2472.0,16.372990608215332,202.37854552268985,16.372990608215332,186.00546002388,0.0,0.0 -100,1.2104547,6.012604,,,,,,,,,,,,,, -200,0.31214607,5.8238053,,,,,,,,,,,,,, -300,0.5862525,5.763004,,,,,,,,,,,,,, -400,1.2817597,5.540213,,,,,,,,,,,,,, -500,1.2565951,5.0697985,,,,,,,,,,,,,, -600,2.1842577,4.1616564,,,,,,,,,,,,,, -700,1.7467345,3.5908647,,,,,,,,,,,,,, -800,2.0117004,3.2705722,,,,,,,,,,,,,, -900,2.5184982,3.0931911,,,,,,,,,,,,,, -1000,2.4293103,2.9027836,,,,,,,,,,,,,, -1100,2.2892966,2.6931975,,,,,,,,,,,,,, -1200,3.3487656,2.704033,,,,,,,,,,,,,, -1300,1.8318769,2.5680978,,,,,,,,,,,,,, -1400,3.6337,2.562898,,,,,,,,,,,,,, -1500,1.957139,2.4325504,,,,,,,,,,,,,, -1600,3.2660983,2.4605446,,,,,,,,,,,,,, -1700,2.5939455,2.3479111,,,,,,,,,,,,,, -1800,3.4035192,2.3640578,,,,,,,,,,,,,, -1857,,,2.3449352,0.5320623826954385,2.671262,0.5683211523793893,5348.0,2.1457593,0.4915808502427233,2472.0,1456.8387684822085,1772.174908161163,1456.8387684822085,315.22576928138733,0.0285015106201171,0.0 -1900,2.67545,2.3174992,,,,,,,,,,,,,, -2000,2.7539034,2.2507596,,,,,,,,,,,,,, -2100,2.406539,2.1423826,,,,,,,,,,,,,, -2200,1.9945608,2.1801653,,,,,,,,,,,,,, -2300,2.8849802,2.1598303,,,,,,,,,,,,,, -2400,2.4917812,2.0331562,,,,,,,,,,,,,, -2500,2.024488,2.1263213,,,,,,,,,,,,,, -2600,1.8875407,2.048295,,,,,,,,,,,,,, -2700,2.4460795,2.0120873,,,,,,,,,,,,,, -2800,2.0105386,1.9875388,,,,,,,,,,,,,, -2900,4.9274697,1.9694775,,,,,,,,,,,,,, -3000,2.4416735,2.0101557,,,,,,,,,,,,,, -3100,3.1280541,1.9007854,,,,,,,,,,,,,, -3200,2.5571394,1.8842491,,,,,,,,,,,,,, -3300,1.4740337,1.8994118,,,,,,,,,,,,,, -3400,2.1690316,1.9308773,,,,,,,,,,,,,, -3500,2.177427,1.9063647,,,,,,,,,,,,,, -3600,2.6001596,1.8531538,,,,,,,,,,,,,, -3700,2.8724287,1.8630781,,,,,,,,,,,,,, -3727,,,0.5966041,0.1932851115758293,0.9358604,0.2626355271923303,5348.0,0.6082362,0.1934068612515995,2472.0,2897.1221652030945,3347.631866931916,2897.1221652030945,450.2640874385834,0.0764284133911132,0.0 -3800,2.5869896,1.9270604,,,,,,,,,,,,,, -3900,1.9965485,1.872142,,,,,,,,,,,,,, -4000,2.0187247,1.835449,,,,,,,,,,,,,, -4100,2.1679559,1.9105244,,,,,,,,,,,,,, -4200,3.9076777,1.9324825,,,,,,,,,,,,,, -4300,2.9253469,1.9262692,,,,,,,,,,,,,, -4400,2.3314161,1.8334144,,,,,,,,,,,,,, -4500,1.7302388,1.8100482,,,,,,,,,,,,,, -4600,2.6799057,1.8335987,,,,,,,,,,,,,, -4700,1.7918473,1.7922539,,,,,,,,,,,,,, -4800,2.2940795,1.8533955,,,,,,,,,,,,,, -4900,1.9764655,1.820421,,,,,,,,,,,,,, -5000,4.17454,1.8156486,,,,,,,,,,,,,, -5100,2.5219843,1.737641,,,,,,,,,,,,,, -5200,3.7307692,1.753277,,,,,,,,,,,,,, -5300,2.745261,1.7291744,,,,,,,,,,,,,, -5400,2.5641975,1.7502608,,,,,,,,,,,,,, -5500,3.3811142,1.6939338,,,,,,,,,,,,,, -5600,2.915746,1.7428839,,,,,,,,,,,,,, -5606,,,0.4300289,0.1458097903575012,0.8080621,0.2312289407880128,5348.0,0.5005587,0.1601974285540186,2472.0,4337.322243690491,4920.453497171402,4337.322243690491,582.7502326965332,0.1263625621795654,0.0 -5700,5.005967,1.7617438,,,,,,,,,,,,,, -5800,2.4464114,1.7117051,,,,,,,,,,,,,, -5900,2.848459,1.7646109,,,,,,,,,,,,,, -6000,4.1732597,1.7806743,,,,,,,,,,,,,, -6100,3.693445,1.7584884,,,,,,,,,,,,,, -6200,3.4666986,1.740508,,,,,,,,,,,,,, -6300,1.8440163,1.7073914,,,,,,,,,,,,,, -6400,2.3956075,1.6425464,,,,,,,,,,,,,, -6500,2.4540775,1.7449172,,,,,,,,,,,,,, -6600,3.7121286,1.72019,,,,,,,,,,,,,, -6700,1.9410502,1.7367878,,,,,,,,,,,,,, -6800,1.7171496,1.7007313,,,,,,,,,,,,,, -6900,2.5253222,1.7016801,,,,,,,,,,,,,, -7000,1.504662,1.689215,,,,,,,,,,,,,, -7100,2.360459,1.7064812,,,,,,,,,,,,,, -7200,2.0003932,1.7345055,,,,,,,,,,,,,, -7300,2.2320104,1.7579793,,,,,,,,,,,,,, -7400,2.5660365,1.652568,,,,,,,,,,,,,, -7486,,,0.3892914,0.1317028818944293,0.7440744,0.2108383135252034,5348.0,0.44369552,0.1436231795746755,2472.0,5777.62396812439,6493.077576160431,5777.62396812439,714.9384062290192,0.174260139465332,0.0 -7500,2.558392,1.6434712,,,,,,,,,,,,,, -7600,2.8727384,1.6039532,,,,,,,,,,,,,, -7700,2.6682398,1.6265464,,,,,,,,,,,,,, -7800,1.8729277,1.6138678,,,,,,,,,,,,,, -7900,1.6818881,1.6403153,,,,,,,,,,,,,, -8000,4.182618,1.6368088,,,,,,,,,,,,,, -8100,8.136972,1.685006,,,,,,,,,,,,,, -8200,3.2051017,1.672631,,,,,,,,,,,,,, -8300,3.9464154,1.6655868,,,,,,,,,,,,,, -8400,2.2420688,1.6636783,,,,,,,,,,,,,, -8500,3.3350894,1.6699028,,,,,,,,,,,,,, -8600,2.2515419,1.7203325,,,,,,,,,,,,,, -8700,2.8811934,1.6310492,,,,,,,,,,,,,, -8800,12.913555,1.6324296,,,,,,,,,,,,,, -8900,3.2038672,1.6085632,,,,,,,,,,,,,, -9000,2.9798849,1.6623579,,,,,,,,,,,,,, -9100,2.4885623,1.6321025,,,,,,,,,,,,,, -9200,2.2780254,1.6033393,,,,,,,,,,,,,, -9300,2.1375237,1.6090182,,,,,,,,,,,,,, -9352,,,0.39885035,0.1344804300262013,0.7127357,0.2025642758527472,5348.0,0.43267328,0.1383624804501046,2472.0,7217.692116260529,8065.623285531998,7217.692116260529,847.2784371376038,0.2258105278015136,0.0 -9400,2.8072114,1.6179612,,,,,,,,,,,,,, -9500,2.7259254,1.6125239,,,,,,,,,,,,,, -9600,2.6549413,1.6110725,,,,,,,,,,,,,, -9700,2.6561244,1.588685,,,,,,,,,,,,,, -9800,4.9040833,1.5521088,,,,,,,,,,,,,, -9900,2.4513896,1.5959104,,,,,,,,,,,,,, -10000,2.1417809,1.5947913,,,,,,,,,,,,,, -10100,3.298577,1.5936244,,,,,,,,,,,,,, -10200,3.1339386,1.6740637,,,,,,,,,,,,,, -10300,2.2470133,1.607417,,,,,,,,,,,,,, -10400,3.4686332,1.6233395,,,,,,,,,,,,,, -10500,3.9058685,1.6518478,,,,,,,,,,,,,, -10600,2.2670338,1.6294601,,,,,,,,,,,,,, -10700,1.762158,1.5829358,,,,,,,,,,,,,, -10800,2.1665876,1.6079583,,,,,,,,,,,,,, -10900,2.2278848,1.6104741,,,,,,,,,,,,,, -11000,1.7388119,1.6122706,,,,,,,,,,,,,, -11100,2.345848,1.5522815,,,,,,,,,,,,,, -11200,2.6640081,1.5806319,,,,,,,,,,,,,, -11220,,,0.33331928,0.1113040383505799,0.67308295,0.1915869353234791,5348.0,0.39661396,0.1267239453212276,2472.0,8658.064115285873,9637.65848660469,8658.064115285873,978.80659198761,0.2755625247955322,0.0 -11300,3.0979023,1.5924574,,,,,,,,,,,,,, -11400,2.948001,1.6128684,,,,,,,,,,,,,, -11500,2.1546402,1.5460804,,,,,,,,,,,,,, -11600,2.2415318,1.5921944,,,,,,,,,,,,,, -11700,3.21989,1.5009463,,,,,,,,,,,,,, -11800,2.2486084,1.5648142,,,,,,,,,,,,,, -11900,2.7737238,1.5982606,,,,,,,,,,,,,, -12000,2.4143186,1.5898579,,,,,,,,,,,,,, -12100,2.6029253,1.6199194,,,,,,,,,,,,,, -12200,2.958162,1.6292487,,,,,,,,,,,,,, -12300,4.997993,1.6192731,,,,,,,,,,,,,, -12400,2.5497499,1.5626552,,,,,,,,,,,,,, -12500,3.832932,1.5619092,,,,,,,,,,,,,, -12600,3.44226,1.5757749,,,,,,,,,,,,,, -12700,2.0369406,1.5419654,,,,,,,,,,,,,, -12800,3.6306417,1.5803564,,,,,,,,,,,,,, -12900,2.4595013,1.5354134,,,,,,,,,,,,,, -13000,2.5465457,1.6218698,,,,,,,,,,,,,, -13089,,,0.3510411,0.1183566224770094,0.6683462,0.1900421908338724,5348.0,0.3859156,0.1237787662746531,2472.0,10098.212848901749,11212.078014612198,10098.212848901749,1112.9421019554138,0.3250164985656738,0.0 -13100,3.206564,1.6401968,,,,,,,,,,,,,, -13200,2.7069676,1.5492,,,,,,,,,,,,,, -13300,2.1001594,1.6138334,,,,,,,,,,,,,, -13400,3.525186,1.5773916,,,,,,,,,,,,,, -13500,1.6681923,1.4888837,,,,,,,,,,,,,, -13600,2.5003018,1.5106469,,,,,,,,,,,,,, -13700,3.0599892,1.5963817,,,,,,,,,,,,,, -13800,2.2569563,1.5352182,,,,,,,,,,,,,, -13900,2.1039662,1.6534199,,,,,,,,,,,,,, -14000,4.6706953,1.5357897,,,,,,,,,,,,,, -14100,2.6903496,1.5726902,,,,,,,,,,,,,, -14200,2.8567896,1.558333,,,,,,,,,,,,,, -14300,3.057667,1.5154529,,,,,,,,,,,,,, -14400,4.221023,1.5436157,,,,,,,,,,,,,, -14500,5.5763235,1.508726,,,,,,,,,,,,,, -14600,4.2789335,1.5541583,,,,,,,,,,,,,, -14700,3.7174275,1.5747116,,,,,,,,,,,,,, -14800,4.2404594,1.5400244,,,,,,,,,,,,,, -14900,2.4688776,1.4922528,,,,,,,,,,,,,, -14959,,,0.3626199,0.1157342155230644,0.6363704,0.1824150149164389,5348.0,0.36685,0.1178477850222411,2472.0,11538.172552585602,12783.459538698196,11538.172552585602,1244.2248899936676,0.3792309761047363,0.0 -15000,2.274386,1.5802783,,,,,,,,,,,,,, -15100,3.3624609,1.5645915,,,,,,,,,,,,,, -15200,2.2079716,1.5122635,,,,,,,,,,,,,, -15300,2.4007657,1.5368404,,,,,,,,,,,,,, -15400,3.0333333,1.5640318,,,,,,,,,,,,,, -15500,3.743812,1.5384119,,,,,,,,,,,,,, -15600,1.7896568,1.5110352,,,,,,,,,,,,,, -15700,2.9396703,1.4964291,,,,,,,,,,,,,, -15800,1.9570012,1.5378808,,,,,,,,,,,,,, -15900,2.3421867,1.5041133,,,,,,,,,,,,,, -16000,2.622994,1.4861723,,,,,,,,,,,,,, -16100,3.096592,1.5827823,,,,,,,,,,,,,, -16200,2.345978,1.5527774,,,,,,,,,,,,,, -16300,3.4541743,1.5136462,,,,,,,,,,,,,, -16400,2.3667736,1.5405741,,,,,,,,,,,,,, -16500,3.1907816,1.5084543,,,,,,,,,,,,,, -16600,1.7765806,1.5395147,,,,,,,,,,,,,, -16700,2.80383,1.522494,,,,,,,,,,,,,, -16800,2.9623344,1.4597279,,,,,,,,,,,,,, -16835,,,0.2640011,0.0905968954047061,0.5966437,0.1708680498566284,5348.0,0.3457338,0.1115511953364613,2472.0,12978.370990991592,14357.825068950651,12978.370990991592,1378.1725063323977,0.5102035999298096,0.0 -16900,9.575913,1.52379,,,,,,,,,,,,,, -17000,4.184395,1.4971858,,,,,,,,,,,,,, -17100,2.2445734,1.5257883,,,,,,,,,,,,,, -17200,1.9729632,1.4597508,,,,,,,,,,,,,, -17300,3.1989622,1.494116,,,,,,,,,,,,,, -17400,2.304904,1.5286207,,,,,,,,,,,,,, -17500,2.818439,1.4965897,,,,,,,,,,,,,, -17600,1.9372087,1.4702218,,,,,,,,,,,,,, -17700,1.8785739,1.3932005,,,,,,,,,,,,,, -17800,2.8429675,1.4701642,,,,,,,,,,,,,, -17900,2.0000358,1.5175211,,,,,,,,,,,,,, -18000,1.9504591,1.5077928,,,,,,,,,,,,,, -18100,3.1900656,1.5457159,,,,,,,,,,,,,, -18200,3.4915972,1.5652617,,,,,,,,,,,,,, -18300,2.3297884,1.5074382,,,,,,,,,,,,,, -18400,3.1478312,1.5230491,,,,,,,,,,,,,, -18500,2.2215867,1.4890751,,,,,,,,,,,,,, -18600,1.5637027,1.4397639,,,,,,,,,,,,,, -18700,3.0929615,1.5314147,,,,,,,,,,,,,, -18711,,,0.26224798,0.0893585368898506,0.5874665,0.168280602836537,5348.0,0.33608237,0.1087278857676761,2472.0,14418.6655626297,15930.184390544891,14418.6655626297,1510.10333442688,0.5596368312835693,0.0 -18800,2.702339,1.4589593,,,,,,,,,,,,,, -18900,1.6857067,1.4790531,,,,,,,,,,,,,, -19000,2.985109,1.4930854,,,,,,,,,,,,,, -19100,1.6556591,1.4786819,,,,,,,,,,,,,, -19200,2.8240778,1.4539633,,,,,,,,,,,,,, -19300,2.375414,1.4373218,,,,,,,,,,,,,, -19400,2.3875768,1.4519563,,,,,,,,,,,,,, -19500,2.5829153,1.5293504,,,,,,,,,,,,,, -19600,2.42903,1.4925284,,,,,,,,,,,,,, -19700,1.867876,1.4517208,,,,,,,,,,,,,, -19800,2.7178378,1.4539115,,,,,,,,,,,,,, -19900,2.190376,1.4681658,,,,,,,,,,,,,, -20000,1.962687,1.500624,,,,,,,,,,,,,, -20100,2.3630536,1.4525867,,,,,,,,,,,,,, -20200,3.0596774,1.4579067,,,,,,,,,,,,,, -20300,2.6840405,1.4568325,,,,,,,,,,,,,, -20400,4.53348,1.5007713,,,,,,,,,,,,,, -20500,2.8248625,1.5159347,,,,,,,,,,,,,, -20585,,,0.38126385,0.127479454767484,0.5813514,0.1663400175714685,5348.0,0.33556062,0.1091950520992017,2472.0,15859.262013435364,17502.176399946213,15859.262013435364,1641.366404056549,0.6074838638305664,0.0 -20600,3.2587018,1.4449264,,,,,,,,,,,,,, -20700,4.111459,1.4270989,,,,,,,,,,,,,, -20800,1.9314641,1.4028058,,,,,,,,,,,,,, -20900,3.382753,1.4815943,,,,,,,,,,,,,, -21000,2.7080717,1.4674405,,,,,,,,,,,,,, -21100,2.6484466,1.5356646,,,,,,,,,,,,,, -21200,3.6494727,1.5016668,,,,,,,,,,,,,, -21300,2.5148532,1.4995991,,,,,,,,,,,,,, -21400,6.755698,1.4396003,,,,,,,,,,,,,, -21500,2.55656,1.4189881,,,,,,,,,,,,,, -21600,2.6608717,1.5419397,,,,,,,,,,,,,, -21700,1.7926401,1.4207844,,,,,,,,,,,,,, -21800,5.1120124,1.4888455,,,,,,,,,,,,,, -21900,4.4006367,1.4264159,,,,,,,,,,,,,, -22000,2.3047264,1.4793772,,,,,,,,,,,,,, -22100,3.5297356,1.4429909,,,,,,,,,,,,,, -22200,1.9512459,1.482717,,,,,,,,,,,,,, -22300,2.533003,1.507674,,,,,,,,,,,,,, -22400,3.296397,1.4589604,,,,,,,,,,,,,, -22448,,,0.3882775,0.1288513710578304,0.57250744,0.1641870299390791,5348.0,0.32734656,0.1047671277395243,2472.0,17299.8719329834,19075.193665504456,17299.8719329834,1773.641833782196,0.6536881923675537,0.0 -22500,2.882563,1.5143523,,,,,,,,,,,,,, -22600,2.351535,1.4172485,,,,,,,,,,,,,, -22700,2.5842855,1.4714091,,,,,,,,,,,,,, -22800,2.564968,1.3696153,,,,,,,,,,,,,, -22900,3.497139,1.5068694,,,,,,,,,,,,,, -23000,1.5163535,1.4114369,,,,,,,,,,,,,, -23100,2.2332046,1.4129559,,,,,,,,,,,,,, -23200,2.3823385,1.4098271,,,,,,,,,,,,,, -23300,2.5693183,1.425745,,,,,,,,,,,,,, -23400,1.8733844,1.4597664,,,,,,,,,,,,,, -23500,2.6285703,1.4935337,,,,,,,,,,,,,, -23600,2.0583072,1.4180154,,,,,,,,,,,,,, -23700,2.3753688,1.352622,,,,,,,,,,,,,, -23800,2.775646,1.3849814,,,,,,,,,,,,,, -23900,2.3479729,1.4348146,,,,,,,,,,,,,, -24000,3.7558374,1.4424049,,,,,,,,,,,,,, -24100,2.6593347,1.3570706,,,,,,,,,,,,,, -24200,3.275403,1.4125775,,,,,,,,,,,,,, -24300,1.9497546,1.3910189,,,,,,,,,,,,,, -24324,,,0.44150472,0.1432519683210627,0.5418794,0.1553240584299603,5348.0,0.30448222,0.0971096622184307,2472.0,18740.242106437683,20647.843740701675,18740.242106437683,1905.782042503357,0.7064037322998047,0.0 -24400,3.0077975,1.4001735,,,,,,,,,,,,,, -24500,5.4588723,1.4281487,,,,,,,,,,,,,, -24600,1.80116,1.4086932,,,,,,,,,,,,,, -24700,3.5592527,1.418419,,,,,,,,,,,,,, -24800,2.9749668,1.3954501,,,,,,,,,,,,,, -24900,2.198694,1.4339503,,,,,,,,,,,,,, -25000,2.0375702,1.3778232,,,,,,,,,,,,,, -25100,2.6057448,1.3962951,,,,,,,,,,,,,, -25200,2.5444508,1.4728254,,,,,,,,,,,,,, -25300,3.9352295,1.3819021,,,,,,,,,,,,,, -25400,2.3283412,1.4664074,,,,,,,,,,,,,, -25500,2.214183,1.4416065,,,,,,,,,,,,,, -25600,2.4618642,1.3617363,,,,,,,,,,,,,, -25700,3.141915,1.4183152,,,,,,,,,,,,,, -25800,1.9257884,1.3962682,,,,,,,,,,,,,, -25900,2.1610231,1.4374876,,,,,,,,,,,,,, -26000,4.975571,1.4178671,,,,,,,,,,,,,, -26100,3.9147425,1.3968877,,,,,,,,,,,,,, -26200,2.7838259,1.397542,,,,,,,,,,,,,, -26201,,,0.36288148,0.1198465963566634,0.5348055,0.1541751547158152,5348.0,0.30177844,0.0966424958869051,2472.0,20180.723200559616,22218.75596666336,20180.723200559616,2036.0781605243685,0.7543544769287109,0.0 -26300,2.281806,1.3866843,,,,,,,,,,,,,, -26400,2.4302106,1.3321854,,,,,,,,,,,,,, -26500,3.302581,1.3714246,,,,,,,,,,,,,, -26600,2.0624313,1.3763185,,,,,,,,,,,,,, -26700,2.92176,1.465731,,,,,,,,,,,,,, -26800,2.046092,1.3772988,,,,,,,,,,,,,, -26900,1.9110975,1.3911533,,,,,,,,,,,,,, -27000,1.8161771,1.346718,,,,,,,,,,,,,, -27100,2.040701,1.4126314,,,,,,,,,,,,,, -27200,3.7170641,1.4085815,,,,,,,,,,,,,, -27300,2.3589206,1.426346,,,,,,,,,,,,,, -27400,2.1727223,1.4344411,,,,,,,,,,,,,, -27500,2.889036,1.3837554,,,,,,,,,,,,,, -27600,2.2146125,1.3040433,,,,,,,,,,,,,, -27700,2.632871,1.3339864,,,,,,,,,,,,,, -27800,2.540248,1.3997824,,,,,,,,,,,,,, -27900,1.8872849,1.369254,,,,,,,,,,,,,, -28000,2.2892258,1.3133997,,,,,,,,,,,,,, -28079,,,0.37025255,0.1247483536356501,0.52608824,0.1505160412060592,5348.0,0.29157224,0.0944082221274348,2472.0,21621.01096129417,23789.05851054192,21621.01096129417,2165.942393541336,0.8169291019439697,0.0 -28100,1.8862216,1.3928111,,,,,,,,,,,,,, -28200,1.9856685,1.2830119,,,,,,,,,,,,,, -28300,3.4937954,1.367467,,,,,,,,,,,,,, -28400,2.534363,1.350863,,,,,,,,,,,,,, -28500,3.6687334,1.3766983,,,,,,,,,,,,,, -28600,2.2589443,1.3250384,,,,,,,,,,,,,, -28700,2.045041,1.3758982,,,,,,,,,,,,,, -28800,3.2729611,1.3792802,,,,,,,,,,,,,, -28900,2.5851884,1.3449059,,,,,,,,,,,,,, -29000,4.4356604,1.3397144,,,,,,,,,,,,,, -29100,2.1660464,1.420996,,,,,,,,,,,,,, -29200,1.9530889,1.3740559,,,,,,,,,,,,,, -29300,4.1273885,1.3499911,,,,,,,,,,,,,, -29400,2.7302291,1.3071433,,,,,,,,,,,,,, -29500,6.3310866,1.3600284,,,,,,,,,,,,,, -29600,1.8676251,1.3627083,,,,,,,,,,,,,, -29700,2.4115617,1.3585715,,,,,,,,,,,,,, -29800,2.4073334,1.2473909,,,,,,,,,,,,,, -29900,3.27498,1.4029366,,,,,,,,,,,,,, -29945,,,0.27366316,0.0941657476396696,0.50481427,0.1438832945538102,5348.0,0.279768,0.0905693335770723,2472.0,23061.27196407318,25360.624280691147,23061.27196407318,2297.1110696792603,0.867875337600708,0.0 -30000,1.8556813,1.2705374,,,,,,,,,,,,,, -30100,1.8942772,1.3744854,,,,,,,,,,,,,, -30200,2.9483855,1.344412,,,,,,,,,,,,,, -30300,3.008676,1.3192164,,,,,,,,,,,,,, -30400,3.6304412,1.3412442,,,,,,,,,,,,,, -30500,2.9627445,1.2800441,,,,,,,,,,,,,, -30600,3.0307026,1.311408,,,,,,,,,,,,,, -30700,1.8091353,1.3610699,,,,,,,,,,,,,, -30800,1.9358519,1.3808904,,,,,,,,,,,,,, -30900,1.8495437,1.3036984,,,,,,,,,,,,,, -31000,2.289351,1.3762298,,,,,,,,,,,,,, -31100,2.982093,1.2813674,,,,,,,,,,,,,, -31200,2.3435674,1.3051525,,,,,,,,,,,,,, -31300,2.4785101,1.3091052,,,,,,,,,,,,,, -31400,3.949361,1.2356573,,,,,,,,,,,,,, -31500,2.2339702,1.2960384,,,,,,,,,,,,,, -31600,1.7900004,1.3059913,,,,,,,,,,,,,, -31700,1.7687953,1.2873288,,,,,,,,,,,,,, -31800,1.9449277,1.2546772,,,,,,,,,,,,,, -31827,,,0.31801972,0.1082496797994504,0.48673412,0.1397800670032922,5348.0,0.26446828,0.0857961123636585,2472.0,24501.19046640396,26932.845405817032,24501.19046640396,2429.268038749695,0.9267137050628662,0.0 -31900,2.0385592,1.3015621,,,,,,,,,,,,,, -32000,3.0901074,1.3232114,,,,,,,,,,,,,, -32100,3.1145134,1.3202486,,,,,,,,,,,,,, -32200,2.235128,1.3074919,,,,,,,,,,,,,, -32300,2.1193697,1.2904907,,,,,,,,,,,,,, -32400,2.3636143,1.2714486,,,,,,,,,,,,,, -32500,2.0035152,1.2735946,,,,,,,,,,,,,, -32600,3.3919542,1.2911702,,,,,,,,,,,,,, -32700,2.262109,1.2805942,,,,,,,,,,,,,, -32800,3.8185012,1.2696441,,,,,,,,,,,,,, -32900,2.6569002,1.3797106,,,,,,,,,,,,,, -33000,2.3138752,1.2824711,,,,,,,,,,,,,, -33100,2.9866917,1.2464077,,,,,,,,,,,,,, -33200,2.2068384,1.2817471,,,,,,,,,,,,,, -33300,3.1693985,1.2981771,,,,,,,,,,,,,, -33400,3.4710982,1.2893425,,,,,,,,,,,,,, -33500,2.9421973,1.2653388,,,,,,,,,,,,,, -33600,2.0292153,1.2185227,,,,,,,,,,,,,, -33700,2.70463,1.2449849,,,,,,,,,,,,,, -33708,,,0.2762536,0.0937875751503006,0.47617844,0.1371153827587205,5348.0,0.25827703,0.0828103101578209,2472.0,25941.62428545952,28505.49468064308,25941.62428545952,2561.341689825058,0.980743169784546,0.0 -33800,3.468597,1.262149,,,,,,,,,,,,,, -33900,2.030013,1.2104548,,,,,,,,,,,,,, -34000,3.0991054,1.3350722,,,,,,,,,,,,,, -34100,2.7109458,1.2587364,,,,,,,,,,,,,, -34200,6.6148105,1.2546543,,,,,,,,,,,,,, -34300,1.8778555,1.2211633,,,,,,,,,,,,,, -34400,5.2724533,1.2258368,,,,,,,,,,,,,, -34500,4.7328606,1.298044,,,,,,,,,,,,,, -34600,3.5671036,1.2771442,,,,,,,,,,,,,, -34700,2.3276827,1.2615154,,,,,,,,,,,,,, -34800,1.8728236,1.3539383,,,,,,,,,,,,,, -34900,3.006943,1.2879741,,,,,,,,,,,,,, -35000,3.5745678,1.2678425,,,,,,,,,,,,,, -35100,1.5788581,1.2212147,,,,,,,,,,,,,, -35200,1.8292186,1.2152534,,,,,,,,,,,,,, -35300,2.328843,1.2699668,,,,,,,,,,,,,, -35400,3.104605,1.1994244,,,,,,,,,,,,,, -35500,3.1602,1.2702836,,,,,,,,,,,,,, -35586,,,0.2669341,0.0928764359399915,0.4645654,0.132191509698099,5348.0,0.24837476,0.0786464363333536,2472.0,27381.94872641564,30078.40672755241,27381.94872641564,2693.789966583252,1.0337035655975342,0.0 -35600,3.5365608,1.2132248,,,,,,,,,,,,,, -35700,3.1511428,1.1985261,,,,,,,,,,,,,, -35800,2.8968716,1.2062281,,,,,,,,,,,,,, -35900,5.497448,1.2645338,,,,,,,,,,,,,, -36000,2.432537,1.1917377,,,,,,,,,,,,,, -36100,2.7925286,1.2534441,,,,,,,,,,,,,, -36200,3.9797077,1.245737,,,,,,,,,,,,,, -36300,2.9393892,1.2225077,,,,,,,,,,,,,, -36400,2.329481,1.2402205,,,,,,,,,,,,,, -36500,1.7509181,1.2387894,,,,,,,,,,,,,, -36600,2.2650845,1.2301459,,,,,,,,,,,,,, -36700,2.6095016,1.2355518,,,,,,,,,,,,,, -36800,1.8963534,1.1776348,,,,,,,,,,,,,, -36900,1.7075056,1.2231661,,,,,,,,,,,,,, -37000,4.6421027,1.1884613,,,,,,,,,,,,,, -37100,3.105621,1.1621703,,,,,,,,,,,,,, -37200,4.2285585,1.2337836,,,,,,,,,,,,,, -37300,3.6156328,1.2426844,,,,,,,,,,,,,, -37400,4.6005373,1.2300082,,,,,,,,,,,,,, -37464,,,0.24166115,0.082856364131916,0.4427787,0.1273738378211379,5348.0,0.24016106,0.077326183657303,2472.0,28821.92249059677,31650.30978918076,28821.92249059677,2825.580310821533,1.0857267379760742,0.0 -37500,3.9338527,1.2492939,,,,,,,,,,,,,, -37600,2.0528224,1.2192473,,,,,,,,,,,,,, -37700,3.6242056,1.280044,,,,,,,,,,,,,, -37800,1.7575177,1.2199789,,,,,,,,,,,,,, -37900,3.5948436,1.206519,,,,,,,,,,,,,, -38000,2.340378,1.2056773,,,,,,,,,,,,,, -38100,2.207724,1.223712,,,,,,,,,,,,,, -38200,2.361717,1.2357306,,,,,,,,,,,,,, -38300,2.3018708,1.2031565,,,,,,,,,,,,,, -38400,2.973374,1.2569386,,,,,,,,,,,,,, -38500,1.847143,1.2765508,,,,,,,,,,,,,, -38600,2.1188319,1.1961254,,,,,,,,,,,,,, -38700,2.916475,1.2278415,,,,,,,,,,,,,, -38800,2.7151384,1.2039927,,,,,,,,,,,,,, -38900,1.8613544,1.2076988,,,,,,,,,,,,,, -39000,5.5483007,1.1497489,,,,,,,,,,,,,, -39100,2.157764,1.2014648,,,,,,,,,,,,,, -39200,2.8245394,1.2312052,,,,,,,,,,,,,, -39300,4.646051,1.1679102,,,,,,,,,,,,,, -39340,,,0.24197032,0.0847769387592396,0.43660167,0.1253077420662888,5348.0,0.23111513,0.0736700993236244,2472.0,30261.989458084103,33221.140043735504,30261.989458084103,2956.2057723999023,1.1383299827575684,0.0 -39400,2.614951,1.1855773,,,,,,,,,,,,,, -39500,3.7426445,1.2085339,,,,,,,,,,,,,, -39600,2.9658277,1.187446,,,,,,,,,,,,,, -39700,2.5341716,1.1871477,,,,,,,,,,,,,, -39800,2.1554868,1.1418238,,,,,,,,,,,,,, -39900,2.0816464,1.2121563,,,,,,,,,,,,,, -40000,2.1096,1.1677117,,,,,,,,,,,,,, -40100,1.9951525,1.1602571,,,,,,,,,,,,,, -40200,9.41253,1.2065343,,,,,,,,,,,,,, -40300,3.1696384,1.1580936,,,,,,,,,,,,,, -40400,4.4064226,1.2152281,,,,,,,,,,,,,, -40500,3.6149883,1.1967804,,,,,,,,,,,,,, -40600,3.4670289,1.1497415,,,,,,,,,,,,,, -40700,1.7316993,1.190594,,,,,,,,,,,,,, -40800,1.7978219,1.121887,,,,,,,,,,,,,, -40900,2.1968112,1.1486974,,,,,,,,,,,,,, -41000,1.5946,1.2594321,,,,,,,,,,,,,, -41100,2.039409,1.1209916,,,,,,,,,,,,,, -41200,3.0215352,1.1680244,,,,,,,,,,,,,, -41204,,,0.26820278,0.0889005727581886,0.42843142,0.1224403101074562,5348.0,0.22799182,0.0729591940365202,2472.0,31702.491810321808,34794.77567911148,31702.491810321808,3089.1986536979675,1.1911566257476809,0.0 -41300,2.3511536,1.1858788,,,,,,,,,,,,,, -41400,3.4466445,1.2243514,,,,,,,,,,,,,, -41500,2.4557276,1.1821208,,,,,,,,,,,,,, -41600,4.862012,1.2136834,,,,,,,,,,,,,, -41700,2.6207278,1.1490202,,,,,,,,,,,,,, -41800,4.472651,1.1857436,,,,,,,,,,,,,, -41900,3.930455,1.204498,,,,,,,,,,,,,, -42000,2.787975,1.1480272,,,,,,,,,,,,,, -42100,1.5147002,1.1941682,,,,,,,,,,,,,, -42200,3.8632715,1.1469623,,,,,,,,,,,,,, -42300,3.270023,1.1762533,,,,,,,,,,,,,, -42400,3.3011308,1.2197813,,,,,,,,,,,,,, -42500,3.6998925,1.1612798,,,,,,,,,,,,,, -42600,2.4929342,1.197645,,,,,,,,,,,,,, -42700,1.7885709,1.1639543,,,,,,,,,,,,,, -42800,3.9744802,1.1050975,,,,,,,,,,,,,, -42900,4.77079,1.1943437,,,,,,,,,,,,,, -43000,2.8044243,1.2117053,,,,,,,,,,,,,, -43081,,,0.21494322,0.075831554091704,0.42295602,0.1212045145157708,5348.0,0.22370486,0.0714561371437856,2472.0,33142.95555186272,36367.7583668232,33142.95555186272,3221.5781757831573,1.2448463439941406,0.0 -43100,5.4394593,1.2036141,,,,,,,,,,,,,, -43200,3.6988428,1.1691484,,,,,,,,,,,,,, -43300,3.8000903,1.1927855,,,,,,,,,,,,,, -43400,3.3577445,1.1248841,,,,,,,,,,,,,, -43500,1.5901505,1.2345757,,,,,,,,,,,,,, -43600,3.5234244,1.1958667,,,,,,,,,,,,,, -43700,2.0138576,1.1381758,,,,,,,,,,,,,, -43800,2.595931,1.1858562,,,,,,,,,,,,,, -43900,3.1498792,1.1680895,,,,,,,,,,,,,, -44000,3.4596279,1.2016114,,,,,,,,,,,,,, -44100,5.0793486,1.1573875,,,,,,,,,,,,,, -44200,3.1015832,1.2150363,,,,,,,,,,,,,, -44300,4.945821,1.1389775,,,,,,,,,,,,,, -44400,2.3468256,1.1413329,,,,,,,,,,,,,, -44500,3.498189,1.1820476,,,,,,,,,,,,,, -44600,2.493856,1.1699073,,,,,,,,,,,,,, -44700,3.6544921,1.1558298,,,,,,,,,,,,,, -44800,2.2781227,1.1583942,,,,,,,,,,,,,, -44900,1.8176197,1.1762865,,,,,,,,,,,,,, -44955,,,0.22929652,0.0814922255743792,0.4198003,0.1197370072506444,5348.0,0.2223216,0.0710092823918916,2472.0,34583.11553454399,37939.03490805626,34583.11553454399,3352.5498700141907,1.303229808807373,0.0 -45000,1.6940238,1.1181468,,,,,,,,,,,,,, -45100,1.7140708,1.1138393,,,,,,,,,,,,,, -45200,3.0377853,1.1889422,,,,,,,,,,,,,, -45300,1.6582146,1.1589633,,,,,,,,,,,,,, -45400,1.6925477,1.16875,,,,,,,,,,,,,, -45500,3.0608408,1.1099815,,,,,,,,,,,,,, -45600,4.542061,1.1645956,,,,,,,,,,,,,, -45700,1.8866535,1.1636349,,,,,,,,,,,,,, -45800,3.8185956,1.1775908,,,,,,,,,,,,,, -45900,4.1191854,1.1482744,,,,,,,,,,,,,, -46000,2.4212456,1.1419286,,,,,,,,,,,,,, -46100,2.9620676,1.1684899,,,,,,,,,,,,,, -46200,2.168909,1.1953523,,,,,,,,,,,,,, -46300,3.272571,1.1530242,,,,,,,,,,,,,, -46400,2.5839534,1.1756599,,,,,,,,,,,,,, -46500,2.0708685,1.1761088,,,,,,,,,,,,,, -46600,2.174228,1.0871519,,,,,,,,,,,,,, -46700,1.978855,1.1462709,,,,,,,,,,,,,, -46800,4.7844934,1.1358732,,,,,,,,,,,,,, -46830,,,0.19979791,0.069397354668894,0.41824093,0.1195632234955636,5348.0,0.22184521,0.0705827392196291,2472.0,36023.417892456055,39512.2322010994,36023.417892456055,3485.300974607468,1.3599789142608645,0.0 -46900,4.997574,1.1382056,,,,,,,,,,,,,, -47000,4.904955,1.1784375,,,,,,,,,,,,,, -47100,2.6742063,1.1804308,,,,,,,,,,,,,, -47200,2.915362,1.1574793,,,,,,,,,,,,,, -47300,1.9698874,1.1820486,,,,,,,,,,,,,, -47400,2.5120597,1.1352205,,,,,,,,,,,,,, -47500,2.0347886,1.1653216,,,,,,,,,,,,,, -47600,5.942016,1.1041712,,,,,,,,,,,,,, -47700,2.0510154,1.1565326,,,,,,,,,,,,,, -47800,1.9419646,1.0992357,,,,,,,,,,,,,, -47900,2.272173,1.159939,,,,,,,,,,,,,, -48000,,,0.224268,0.0774116915071127,0.4184095,0.1196790793322842,5348.0,0.22146653,0.07098897081226,2472.0,36906.64269685745,40526.11241531372,36906.64269685745,3615.843613147736,1.4163413047790527,0.0 -48000,,,,,,,,,,,36906.64269685745,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 0394232f9..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -312.095463514328,0.0,18.22455358505249,1,0,18.22455358505249,0.5224818587303162,0.7161954641342163,0.0278240410838958,43793,330.3200616836548,0.5250149965286255,0.7151498794555664,0.0224614420980837,0.5213832855224609,0.7166012525558472,0.0261677396182137,43793 -429.68863344192505,0.0337791442871093,258.50570034980774,749,0,258.50570034980774,0.983142077922821,0.0803441479802131,0.0379646768887439,43793,688.2471423149109,0.9867318868637084,0.0685997605323791,0.034143125103973,0.9841179251670836,0.0774130299687385,0.0363251293632594,43793 -552.933183670044,0.0614261627197265,498.6134088039398,1498,0,498.6134088039398,0.983509361743927,0.0611369796097278,0.0874589221408354,43793,1051.6457195281982,0.987120807170868,0.0488369651138782,0.0850988099663653,0.9844743609428406,0.0578992888331413,0.0861650117542072,43793 -674.4949362277985,0.0938889980316162,738.7220523357391,2246,0,738.7220523357391,0.984077513217926,0.0567509531974792,0.1285198326876661,43793,1413.3673105239868,0.9877936840057372,0.04438441619277,0.1354723842981192,0.9850642085075378,0.0536051467061042,0.1311888501001946,43793 -799.3096804618835,0.1233859062194824,978.9202573299408,2973,0,978.9202573299408,0.9842742681503296,0.0546730384230613,0.1505660785614086,43793,1778.4332921504974,0.9880256056785583,0.0422192029654979,0.165306800246985,0.9852164387702942,0.0518075451254844,0.1575073027716013,43793 -926.6731991767884,0.1525752544403076,1219.0944874286652,3716,0,1219.0944874286652,0.984433889389038,0.0534871667623519,0.1657220910875919,43793,2146.022847175598,0.9881114959716796,0.0410234928131103,0.1904496845569043,0.9853609204292296,0.0507001467049121,0.1661783376870877,43793 -1053.461250782013,0.1797826290130615,1459.0584998130798,4471,0,1459.0584998130798,0.9847017526626588,0.0521049872040748,0.1806297780760786,43793,2512.822060108185,0.988349676132202,0.0401872619986534,0.2060088676469007,0.9855566024780272,0.0494523346424102,0.1817516461108204,43793 -1181.8993849754331,0.2082588672637939,1699.3178343772888,5230,0,1699.3178343772888,0.9849346876144408,0.0508664101362228,0.1926672031567617,43793,2881.5673213005066,0.9888502955436708,0.0382868386805057,0.22891057221235,0.9857863783836364,0.0481408014893531,0.1905339620153854,43793 -1309.9222609996796,0.2366900444030761,1939.3397772312164,5983,0,1939.3397772312164,0.9850884079933168,0.0498464331030845,0.2021478143178211,43793,3249.6604022979736,0.9890902042388916,0.0375760123133659,0.2509227669922907,0.9859312772750854,0.0473600141704082,0.2025852483408946,43793 -1436.303508758545,0.2638471126556396,2179.482953310013,6733,0,2179.482953310013,0.9849119186401368,0.0504280105233192,0.2090736040470227,43793,3616.232548236847,0.9888186454772948,0.0379299372434616,0.2507952327458457,0.9858505129814148,0.0478081554174423,0.2071134306774288,43793 -1563.5682861804962,0.2922539710998535,2419.6403181552887,7486,0,2419.6403181552887,0.9852008819580078,0.0493823438882827,0.2157097935430117,43793,3983.702286481857,0.9891477227211,0.0370690897107124,0.2660915521885904,0.9860952496528624,0.046781238168478,0.214461564179151,43793 -1690.2876436710358,0.3210947513580322,2659.7981646060944,8237,0,2659.7981646060944,0.985374391078949,0.0486757755279541,0.2213807100207803,43793,4350.627862453461,0.9890878200531006,0.036891583353281,0.2597051674587416,0.986197590827942,0.0462719053030014,0.2185389391485515,43793 -1817.3364193439484,0.3503391742706299,2899.909004926681,8975,0,2899.909004926681,0.9852442741394044,0.0491074994206428,0.2204789272640459,43793,4717.839361667633,0.9893003106117249,0.0360952019691467,0.2744277171019114,0.9861196279525756,0.0465352907776832,0.2194612766953757,43793 -1943.068927526474,0.3804106712341308,3139.976670026779,9731,0,3139.976670026779,0.985558032989502,0.0480080991983413,0.2356168489598163,43793,5083.68940782547,0.989581823348999,0.0353526175022125,0.3135008539187511,0.9864249229431152,0.0454021021723747,0.2378901041038153,43793 -2074.2522208690643,0.4082696437835693,3380.1543962955475,10489,0,3380.1543962955475,0.9854817986488342,0.0480877570807933,0.2358211077402679,43793,5455.0981414318085,0.9896994233131408,0.0344877429306507,0.3108126612908478,0.9863623976707458,0.0455813333392143,0.2351091964849687,43793 -2202.160078048706,0.4365499019622803,3620.235187292099,11242,0,3620.235187292099,0.985623300075531,0.0481476299464702,0.2396074151400248,43793,5823.134348869324,0.9899011850357056,0.0337928086519241,0.3420008787401611,0.986520290374756,0.0452591404318809,0.2445117594369473,43793 -2332.0449130535126,0.4658217430114746,3860.454726219177,11989,0,3860.454726219177,0.9856700897216796,0.0474320091307163,0.2479243786116264,43793,6193.287379264832,0.990170955657959,0.0327141284942626,0.3666973354664275,0.9865543842315674,0.0449664629995822,0.2531282590109368,43793 -2458.299390316009,0.495845079421997,4100.5837614536285,12739,0,4100.5837614536285,0.985740840435028,0.0472175404429435,0.2425868423137374,43793,6559.720872402191,0.990369439125061,0.0322626382112503,0.3685053446267267,0.9865767359733582,0.0446254387497901,0.246913705823616,43793 -2586.215543985367,0.5247492790222168,4340.619060993195,13486,0,4340.619060993195,0.985637664794922,0.0477154850959777,0.2470636721113114,43793,6927.721772909164,0.9902468919754028,0.0322908461093902,0.3681962166742015,0.9864902496337892,0.0449402891099453,0.2522724175171478,43793 -2713.153738975525,0.5549750328063965,4580.62403678894,14241,0,4580.62403678894,0.9858179092407228,0.0471119470894336,0.2473383380424254,43793,7294.715516328812,0.99042010307312,0.0318863168358802,0.3612297607359896,0.9867191910743712,0.0443770363926887,0.2567491818925974,43793 -2839.149816274643,0.584143877029419,4820.717617750168,14995,0,4820.717617750168,0.985866367816925,0.0470742098987102,0.2501931113658026,43793,7660.854436397552,0.9904540181159972,0.0318196974694728,0.3714848644505445,0.9867293238639832,0.0442788004875183,0.2621776228683085,43793 -2969.4709827899933,0.6131572723388672,5060.922473907471,15738,0,5060.922473907471,0.9857829809188844,0.0475221239030361,0.2503552406650511,43793,8031.429186582565,0.9903156161308287,0.0321997068822383,0.3752334703336287,0.9866416454315186,0.0446469374001026,0.2536723963678569,43793 -3098.746400117874,0.6441349983215332,5301.190806150436,16488,0,5301.190806150436,0.9859042763710022,0.0470405742526054,0.2510058219218242,43793,8401.024397611618,0.9904052019119264,0.0317552462220191,0.364807285871065,0.9867590069770812,0.044371198862791,0.2592793316141011,43793 -3226.335287809372,0.6756045818328857,5541.426566362381,17237,0,5541.426566362381,0.9857720136642456,0.0469924472272396,0.2551105020349281,43793,8768.900849103928,0.9905096292495728,0.0313777886331081,0.3839918210457591,0.986682653427124,0.0444314666092395,0.260249022031754,43793 -3355.5258157253265,0.706317663192749,5781.60738492012,17983,0,5781.60738492012,0.9858996272087096,0.0470846034586429,0.2526331454065066,43793,9138.323543548584,0.9906891584396362,0.030599458143115,0.4047101758861847,0.9867675304412842,0.0443162843585014,0.2684684903382087,43793 -3485.375589132309,0.7408199310302734,6021.643416404724,18719,0,6021.643416404724,0.985871434211731,0.0470295920968055,0.2514217514588064,43793,9508.266571044922,0.9908905029296876,0.0300750564783811,0.4239957922232233,0.9867285490036012,0.04435645788908,0.2595765740247147,43793 -3611.999095201492,0.7735686302185059,6261.605680465698,19464,0,6261.605680465698,0.9858448505401612,0.046876560896635,0.2534075491441924,43793,9874.90622472763,0.991010844707489,0.0295705664902925,0.4344861831450195,0.9867455959320068,0.0441209748387336,0.2650249423300213,43793 -3739.80638551712,0.8040339946746826,6501.562925338745,20208,0,6501.562925338745,0.9859581589698792,0.0471153147518634,0.2562529002068713,43793,10242.721267700195,0.9910753965377808,0.0293495748192071,0.437858584351662,0.9867947101593018,0.0442087724804878,0.269595815018364,43793 -3867.823861837387,0.8348932266235352,6741.627146005631,20958,0,6741.627146005631,0.9859585762023926,0.0475255995988845,0.2523169835991808,43793,10610.854069948196,0.9908106327056884,0.0300772711634635,0.407418383224706,0.98684424161911,0.0444803908467292,0.266787557603906,43793 -3995.446721315384,0.8683032989501953,6981.642642021179,21690,0,6981.642642021179,0.9858292937278748,0.0474737659096717,0.2580787181285667,43793,10978.547773361206,0.9906601905822754,0.0309933628886938,0.4074115877835104,0.9866725206375122,0.0447343923151493,0.2620614710724585,43793 -4126.191339492798,0.8996272087097168,7221.879897594452,22438,0,7221.879897594452,0.9859687089920044,0.0470192581415176,0.257428460130306,43793,11349.581412315369,0.9909005165100098,0.0299562234431505,0.3997272030494833,0.986880362033844,0.0442997403442859,0.2655132926619737,43793 -4252.171051979065,0.9298944473266602,7462.123221635818,23191,0,7462.123221635818,0.985948085784912,0.0467676855623722,0.2523688362908369,43793,11715.854789495468,0.990909993648529,0.0301405116915702,0.415039702271021,0.9867057800292968,0.0441821590065956,0.2628321328663399,43793 -4381.744997501373,1.251793384552002,7701.862661600113,23939,0,7701.862661600113,0.9859236478805542,0.0469748564064502,0.2573575046521077,43793,12085.510291099548,0.9909388422966005,0.0295868385583162,0.4346094288609707,0.9869027137756348,0.0441179126501083,0.2732023445683491,43793 -4506.559768199921,1.282576322555542,7941.946875095367,24693,0,7941.946875095367,0.9858819246292114,0.0468236804008483,0.2541363107731255,43793,12450.460037469864,0.991054356098175,0.0292239245027303,0.4274887319021707,0.986707866191864,0.0442143566906452,0.2662884585720443,43793 -4633.280071258545,1.3154983520507812,8181.9895396232605,25443,0,8181.9895396232605,0.98597252368927,0.0471655651926994,0.2596896627611287,43793,12817.27622628212,0.9912801384925842,0.0285352878272533,0.4508424274078592,0.9867947101593018,0.0443945862352848,0.2647947449786009,43793 -4758.165808677673,1.3469769954681396,8421.981696367264,26189,0,8421.981696367264,0.9858335256576538,0.0473558753728866,0.2510148645408534,43793,13182.20546245575,0.9914413094520568,0.0279674343764781,0.4703452705316486,0.9867812991142272,0.0445206128060817,0.2657041952537768,43793 -4887.066943645477,1.378154993057251,8662.139254808426,26938,0,8662.139254808426,0.9859678745269777,0.047476228326559,0.2551753097398039,43793,13551.315371513369,0.9915258884429932,0.0276505500078201,0.4760765050180133,0.9869157075881958,0.0446074083447456,0.2645389905862119,43793 -5015.407616376877,1.4100637435913086,8902.194403409958,27690,0,8902.194403409958,0.9857825636863708,0.0472225956618785,0.2515572262223682,43793,13919.763410568235,0.991478443145752,0.0281016025692224,0.4671701164093043,0.9867671132087708,0.0443702042102813,0.2653502731165771,43793 -5144.1230499744415,1.4474318027496338,9142.177165269852,28425,0,9142.177165269852,0.9859636425971984,0.0474299751222133,0.2580036411863702,43793,14288.520901203156,0.9911364912986756,0.0287772286683321,0.4375929169048969,0.9868263602256776,0.0445405580103397,0.2722830766897607,43793 -5271.874435424805,1.480443239212036,9382.21336388588,29159,0,9382.21336388588,0.985961139202118,0.0474254563450813,0.2587341616032605,43793,14656.361695289612,0.9911613464355468,0.028942160308361,0.4474896226765509,0.986789047718048,0.0447144210338592,0.2646365791157178,43793 -5402.314550876617,1.512584209442139,9622.341174840927,29899,0,9622.341174840927,0.9859535694122314,0.0473793819546699,0.2597263787708996,43793,15026.982689142227,0.9912360906600952,0.0286519527435302,0.4469355000732045,0.9867979288101196,0.0446059815585613,0.2667757483159876,43793 -5533.393811702728,1.550180196762085,9862.550683259964,30634,0,9862.550683259964,0.9859750270843506,0.0477339103817939,0.2601599084649045,43793,15398.330970048904,0.9913098812103271,0.0283946581184864,0.4532779013963894,0.9868007898330688,0.0447873137891292,0.2689093160310568,43793 -5659.4643721580505,1.5825748443603516,10102.536264419556,31379,0,10102.536264419556,0.985964059829712,0.0474789142608642,0.2624128359043144,43793,15764.439517259598,0.9914488196372986,0.0278518609702587,0.4698440030144065,0.9867720007896424,0.04453831538558,0.2701344358372728,43793 -5786.365542411804,1.615133285522461,10342.64687371254,32124,0,10342.64687371254,0.9859771132469176,0.0472414530813694,0.2615611187014712,43793,16131.504002094269,0.9916564226150512,0.027042893692851,0.4944718535620369,0.9868767261505128,0.0442681089043617,0.271961120654364,43793 -5908.865457057953,1.652513027191162,10582.716272592545,32856,0,10582.716272592545,0.9860023856163024,0.0472645424306392,0.2620336225403224,43793,16494.133712530136,0.99176424741745,0.0266096200793981,0.498826540416365,0.9868531823158264,0.0443659499287605,0.2692807413861312,43793 -6036.689339399338,1.685504913330078,10822.800383806229,33595,0,10822.800383806229,0.9859889149665833,0.0477740578353405,0.2634912505782247,43793,16862.094605207443,0.9918482303619384,0.0264681000262498,0.5026169864173307,0.986846685409546,0.0450812652707099,0.2694701894953767,43793 -6165.134907007217,1.7191433906555176,11062.89523434639,34335,0,11062.89523434639,0.9860196709632874,0.0474854782223701,0.2597895271918095,43793,17230.688241004944,0.991648256778717,0.0270575191825628,0.4754191822452325,0.9869412779808044,0.0447731092572212,0.2740292724154801,43793 -6293.562938928604,1.7529196739196775,11302.936974048616,35083,0,11302.936974048616,0.9860441088676452,0.047148123383522,0.2659752052189602,43793,17599.211398363113,0.9914780855178832,0.0276719629764556,0.4771699893653773,0.986946940422058,0.0444764271378517,0.2762864765832121,43793 -6422.19800400734,1.7894747257232666,11543.209511995316,35831,0,11543.209511995316,0.9860904216766356,0.0475407242774963,0.2623007088811127,43793,17968.17503976822,0.9915109276771544,0.0274749193340539,0.4855648311323703,0.9869274497032166,0.0445232838392257,0.2701972526758367,43793 -6549.4206647872925,1.8252544403076167,11783.2023396492,36577,0,11783.2023396492,0.9860007166862488,0.0474309548735618,0.2636249846278241,43793,18335.44613981247,0.9916188716888428,0.0272390656173229,0.4803275266367559,0.9868852496147156,0.044582299888134,0.271263958619182,43793 -6676.074639558792,1.858780860900879,12023.32865190506,37323,0,12023.32865190506,0.9859986305236816,0.0479762442409992,0.2591521938112194,43793,18702.27978849411,0.9916990995407104,0.0268614571541547,0.484263148365716,0.9868641495704652,0.0448390506207942,0.2684500922044856,43793 -6802.231438398361,1.8928401470184328,12263.584111452104,38072,0,12263.584111452104,0.9859472513198853,0.0474219359457492,0.2636671236942603,43793,19068.74548768997,0.9920689463615416,0.0256481822580099,0.516311374792447,0.9868032336235046,0.0446102283895015,0.2661784733789995,43793 -6928.428899765015,1.9264566898345947,12503.6462123394,38821,0,12503.6462123394,0.9859569072723388,0.0475714839994907,0.2615809471411672,43793,19435.058240175247,0.9921510815620422,0.0255514401942491,0.5158998215716353,0.9867951273918152,0.0446099750697612,0.2744577401290294,43793 -7056.762053728104,1.9602704048156736,12743.860652923584,39573,0,12743.860652923584,0.9850488305091858,0.0510807670652866,0.2129662973348563,43793,19803.6591398716,0.990967333316803,0.0293112080544233,0.4569042325964516,0.986046552658081,0.0476715117692947,0.233260071570542,43793 -7180.665674209595,1.9954063892364504,12983.941451787949,40320,0,12983.941451787949,0.9859982132911682,0.0479002371430397,0.2639822632255164,43793,20167.69873785973,0.9924874901771544,0.0244606956839561,0.5351341203817217,0.9868787527084352,0.0445943921804428,0.2759908697710495,43793 -7300.5468554496765,2.0311975479125977,13224.08312869072,41067,0,13224.08312869072,0.9859969019889832,0.0472699888050556,0.2602967469117153,43793,20527.777365922928,0.9923626780509948,0.0249550715088844,0.5233125303946242,0.9868706464767456,0.0443031042814254,0.2764290908345192,43793 -7426.20378446579,2.0658912658691406,13464.266607761385,41813,0,13464.266607761385,0.9860803484916688,0.0479672290384769,0.2647587286793563,43793,20893.67218565941,0.9921444058418274,0.0252819508314132,0.528187037457823,0.9869668483734132,0.0449664071202278,0.2712547909511459,43793 -7550.482100486755,2.102457523345948,13704.357329368591,42551,0,13704.357329368591,0.9861013889312744,0.0479997955262661,0.261924874372922,43793,21258.0973508358,0.9920623898506165,0.0255641397088766,0.5207061522784611,0.986931085586548,0.0451455265283584,0.2757223796749376,43793 -7676.686524152756,2.1386983394622803,13944.305476903915,43293,0,13944.305476903915,0.9859657287597656,0.0482613332569599,0.2562264799589845,43793,21624.305897712708,0.9919952154159546,0.0256527718156576,0.5196389960976198,0.9868032336235046,0.0451744943857193,0.2733776448060774,43793 -7805.214061498642,2.1753900051116943,14184.53102874756,44038,0,14184.53102874756,0.9860146045684814,0.0481569580733776,0.2617285586172711,43793,21993.115421056747,0.9922802448272704,0.02483881264925,0.5377617439585939,0.9868596792221068,0.0452494956552982,0.270351017601974,43793 -7930.460354089737,2.211509704589844,14424.795090436935,44777,0,14424.795090436935,0.9860327243804932,0.0481882318854332,0.2616151770568872,43793,22358.681414604187,0.9924088716506958,0.0242243651300668,0.5362641543231005,0.9869104027748108,0.0453433506190776,0.2695658229800238,43793 -8058.0960521698,2.247072696685791,14664.76844573021,45515,0,14664.76844573021,0.986042022705078,0.0484271347522735,0.2579358481884553,43793,22726.34547829628,0.9927688837051392,0.0231147781014442,0.5721043220273573,0.9869270324707032,0.0452565178275108,0.2640619486550432,43793 -8184.0625557899475,2.28225326538086,14904.87128996849,46262,0,14904.87128996849,0.985990583896637,0.048694908618927,0.2598243224261893,43793,23092.4700255394,0.9930108189582824,0.0223857462406158,0.5915540953232792,0.9868503212928772,0.0455769263207912,0.268653912629021,43793 -8307.041923999786,2.317678451538086,15144.95430803299,47008,0,15144.95430803299,0.9859114289283752,0.0486891642212867,0.264437696290971,43793,23455.587432146072,0.993267297744751,0.0218341443687677,0.6045476022820708,0.986806869506836,0.0458246506750583,0.2718907877934435,43793 -8434.66095161438,2.353482723236084,15385.188156366348,47748,0,15385.188156366348,0.9859864115715028,0.0491985715925693,0.2576000876595037,43793,23823.49522161484,0.992554783821106,0.0238631702959537,0.5605475454067415,0.9868738651275636,0.045866098254919,0.274083811671372,43793 -8559.123657226562,2.389343738555908,15625.320230960846,48496,0,15625.320230960846,0.9859628081321716,0.0490183718502521,0.2601319877159591,43793,24188.14550971985,0.9926549792289734,0.0236036144196987,0.5488211291481068,0.9868596792221068,0.0460748337209224,0.2715950909297707,43793 -8681.672759532928,2.4259297847747803,15865.286841392515,49243,0,15865.286841392515,0.9859880805015564,0.049083225429058,0.2621291751755614,43793,24550.717749118805,0.9927665591239928,0.0231651198118925,0.5600378176647713,0.9868649244308472,0.0459690317511558,0.2729185746436617,43793 -8800.601737260818,2.4639928340911865,16105.240637540815,49988,0,16105.240637540815,0.9858924746513368,0.049880214035511,0.2606511181190385,43793,24909.65825247765,0.9928274750709534,0.0229332204908132,0.5803153113657378,0.9868227243423462,0.0465543456375598,0.2713811193405487,43793 -8921.704628229141,2.500322103500366,16345.393598794935,50738,0,16345.393598794935,0.98593670129776,0.0497093610465526,0.2601520911636134,43793,25270.969877958298,0.992751955986023,0.0230761207640171,0.5672821856545007,0.9868150353431702,0.0462609492242336,0.2708342805840612,43793 -9044.232073783876,2.5364415645599365,16585.514472723007,51489,0,16585.514472723007,0.9858849048614502,0.0498098582029342,0.263006766620887,43793,25633.673954486847,0.9931676983833312,0.0218032449483871,0.5953628427219528,0.9868471026420592,0.0464733131229877,0.2731393930659657,43793 -9164.892379283903,2.573488712310791,16825.499007225037,52235,0,16825.499007225037,0.9858680367469788,0.0501306690275669,0.2569849594604288,43793,25994.37507820129,0.993556797504425,0.0206640735268592,0.6220936529065954,0.9868056774139404,0.0468620583415031,0.2664892358792565,43793 -9288.184242248535,2.612375974655152,17065.48960494995,52980,0,17065.48960494995,0.98589289188385,0.0506656877696514,0.2570500295083031,43793,26357.715396165848,0.9937600493431092,0.0198313146829605,0.6459993043755188,0.9867037534713744,0.0474918149411678,0.2685034861718007,43793 -9407.865612268448,2.6504719257354736,17305.6275203228,53724,0,17305.6275203228,0.985889494419098,0.0503124706447124,0.261990889494677,43793,26717.59257531166,0.9939901828765868,0.0193338319659233,0.6436779537032165,0.9866481423377992,0.0475041940808296,0.2681561503925183,43793 -9532.784777402878,2.6869962215423584,17545.65809392929,54469,0,17545.65809392929,0.9858878254890442,0.0506690517067909,0.2569829867377682,43793,27082.599598646164,0.9936777949333192,0.0201111733913421,0.6458658038745391,0.9866968989372252,0.047279093414545,0.267442034344366,43793 -9653.950505495071,2.723959684371948,17785.798129081726,55211,0,17785.798129081726,0.985811173915863,0.0510495565831661,0.2536814701130161,43793,27443.96195435524,0.9936330318450928,0.0203234907239675,0.6208121445688634,0.9867082238197328,0.0476051606237888,0.2657908474666073,43793 -9779.996412038803,2.7612576484680176,18026.053458452225,55959,0,18026.053458452225,0.9858027696609496,0.0511068813502788,0.2597542658100841,43793,27810.32054162025,0.9933901429176332,0.0209429822862148,0.6129876580194846,0.9866664409637452,0.0478049889206886,0.2669865188916774,43793 -9904.859407424927,2.799586296081543,18266.163946151733,56710,0,18266.163946151733,0.9857563972473145,0.05178118124604225,0.2562075157043629,43793,28175.351845502853,0.9932118058204651,0.021330367773771286,0.594974708618752,0.9866802096366882,0.04851559177041054,0.26519575804293455,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index 87c0e349a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,653 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.7371044,0.71459216,,,,,,,,,,,,,,,,, -1,,,0.5250149965286255,0.7151498794555664,0.0224614420980837,0.5213832855224609,0.7166012525558472,0.0261677396182137,43793.0,0.5224818587303162,0.7161954641342163,0.0278240410838958,43793.0,18.22455358505249,330.3200616836548,18.22455358505249,312.095463514328,0.0,0.0 -100,0.60223687,0.42012855,,,,,,,,,,,,,,,,, -200,0.35242286,0.31525835,,,,,,,,,,,,,,,,, -300,0.2569092,0.22896984,,,,,,,,,,,,,,,,, -400,0.1691178,0.1599384,,,,,,,,,,,,,,,,, -500,0.10653999,0.109940656,,,,,,,,,,,,,,,,, -600,0.065417066,0.08791898,,,,,,,,,,,,,,,,, -700,0.046901397,0.066946425,,,,,,,,,,,,,,,,, -749,,,0.9867318868637084,0.0685997605323791,0.034143125103973,0.9841179251670836,0.0774130299687385,0.0363251293632594,43793.0,0.983142077922821,0.0803441479802131,0.0379646768887439,43793.0,258.50570034980774,688.2471423149109,258.50570034980774,429.68863344192505,0.0337791442871093,0.0 -800,0.032160554,0.06665626,,,,,,,,,,,,,,,,, -900,0.15983059,0.061792757,,,,,,,,,,,,,,,,, -1000,0.05475169,0.059090167,,,,,,,,,,,,,,,,, -1100,0.11100432,0.056161534,,,,,,,,,,,,,,,,, -1200,0.26966006,0.053658847,,,,,,,,,,,,,,,,, -1300,0.33447692,0.05209105,,,,,,,,,,,,,,,,, -1400,0.11226441,0.04776481,,,,,,,,,,,,,,,,, -1498,,,0.987120807170868,0.0488369651138782,0.0850988099663653,0.9844743609428406,0.0578992888331413,0.0861650117542072,43793.0,0.983509361743927,0.0611369796097278,0.0874589221408354,43793.0,498.6134088039398,1051.6457195281982,498.6134088039398,552.933183670044,0.0614261627197265,0.0 -1500,0.32168272,0.045913726,,,,,,,,,,,,,,,,, -1600,0.42647645,0.049845982,,,,,,,,,,,,,,,,, -1700,0.22718193,0.05190289,,,,,,,,,,,,,,,,, -1800,0.17423648,0.050273526,,,,,,,,,,,,,,,,, -1900,0.07139972,0.049996622,,,,,,,,,,,,,,,,, -2000,0.17172761,0.045269463,,,,,,,,,,,,,,,,, -2100,0.19899097,0.042673852,,,,,,,,,,,,,,,,, -2200,0.18663466,0.04551201,,,,,,,,,,,,,,,,, -2246,,,0.9877936840057372,0.04438441619277,0.1354723842981192,0.9850642085075378,0.0536051467061042,0.1311888501001946,43793.0,0.984077513217926,0.0567509531974792,0.1285198326876661,43793.0,738.7220523357391,1413.3673105239868,738.7220523357391,674.4949362277985,0.0938889980316162,0.0 -2300,0.10463924,0.04280965,,,,,,,,,,,,,,,,, -2400,0.089078054,0.045085166,,,,,,,,,,,,,,,,, -2500,0.06857383,0.04023597,,,,,,,,,,,,,,,,, -2600,0.14510436,0.047692537,,,,,,,,,,,,,,,,, -2700,0.13045396,0.04117077,,,,,,,,,,,,,,,,, -2800,0.10649617,0.040586114,,,,,,,,,,,,,,,,, -2900,0.11231413,0.040797237,,,,,,,,,,,,,,,,, -2973,,,0.9880256056785583,0.0422192029654979,0.165306800246985,0.9852164387702942,0.0518075451254844,0.1575073027716013,43793.0,0.9842742681503296,0.0546730384230613,0.1505660785614086,43793.0,978.9202573299408,1778.4332921504974,978.9202573299408,799.3096804618835,0.1233859062194824,0.0 -3000,0.122324176,0.046032075,,,,,,,,,,,,,,,,, -3100,0.0803333,0.045717105,,,,,,,,,,,,,,,,, -3200,0.10522786,0.040946092,,,,,,,,,,,,,,,,, -3300,0.10653806,0.043020587,,,,,,,,,,,,,,,,, -3400,0.050599154,0.04338199,,,,,,,,,,,,,,,,, -3500,0.043403916,0.04356351,,,,,,,,,,,,,,,,, -3600,0.116087444,0.03991685,,,,,,,,,,,,,,,,, -3700,0.14346845,0.04448991,,,,,,,,,,,,,,,,, -3716,,,0.9881114959716796,0.0410234928131103,0.1904496845569043,0.9853609204292296,0.0507001467049121,0.1661783376870877,43793.0,0.984433889389038,0.0534871667623519,0.1657220910875919,43793.0,1219.0944874286652,2146.022847175598,1219.0944874286652,926.6731991767884,0.1525752544403076,0.0 -3800,0.062395427,0.04284718,,,,,,,,,,,,,,,,, -3900,0.04736853,0.04399071,,,,,,,,,,,,,,,,, -4000,0.08979814,0.045421656,,,,,,,,,,,,,,,,, -4100,0.1388391,0.039775264,,,,,,,,,,,,,,,,, -4200,0.091951095,0.046014134,,,,,,,,,,,,,,,,, -4300,0.074212216,0.04384021,,,,,,,,,,,,,,,,, -4400,0.074881986,0.035501927,,,,,,,,,,,,,,,,, -4471,,,0.988349676132202,0.0401872619986534,0.2060088676469007,0.9855566024780272,0.0494523346424102,0.1817516461108204,43793.0,0.9847017526626588,0.0521049872040748,0.1806297780760786,43793.0,1459.0584998130798,2512.822060108185,1459.0584998130798,1053.461250782013,0.1797826290130615,0.0 -4500,0.05505213,0.043110576,,,,,,,,,,,,,,,,, -4600,0.05302992,0.04604776,,,,,,,,,,,,,,,,, -4700,0.04866206,0.044693924,,,,,,,,,,,,,,,,, -4800,0.03573704,0.04270119,,,,,,,,,,,,,,,,, -4900,0.04968999,0.039914284,,,,,,,,,,,,,,,,, -5000,0.049450677,0.039825317,,,,,,,,,,,,,,,,, -5100,0.14910844,0.04276553,,,,,,,,,,,,,,,,, -5200,0.040012293,0.042220835,,,,,,,,,,,,,,,,, -5230,,,0.9888502955436708,0.0382868386805057,0.22891057221235,0.9857863783836364,0.0481408014893531,0.1905339620153854,43793.0,0.9849346876144408,0.0508664101362228,0.1926672031567617,43793.0,1699.3178343772888,2881.5673213005066,1699.3178343772888,1181.8993849754331,0.2082588672637939,0.0 -5300,0.065350726,0.041310094,,,,,,,,,,,,,,,,, -5400,0.040056445,0.039033897,,,,,,,,,,,,,,,,, -5500,0.048822224,0.043493822,,,,,,,,,,,,,,,,, -5600,0.061794594,0.03738674,,,,,,,,,,,,,,,,, -5700,0.04919036,0.041041974,,,,,,,,,,,,,,,,, -5800,0.04566235,0.038228095,,,,,,,,,,,,,,,,, -5900,0.0688407,0.040822193,,,,,,,,,,,,,,,,, -5983,,,0.9890902042388916,0.0375760123133659,0.2509227669922907,0.9859312772750854,0.0473600141704082,0.2025852483408946,43793.0,0.9850884079933168,0.0498464331030845,0.2021478143178211,43793.0,1939.3397772312164,3249.6604022979736,1939.3397772312164,1309.9222609996796,0.2366900444030761,0.0 -6000,0.0325702,0.039184637,,,,,,,,,,,,,,,,, -6100,0.040285055,0.041636143,,,,,,,,,,,,,,,,, -6200,0.084686786,0.040756706,,,,,,,,,,,,,,,,, -6300,0.052145634,0.040727768,,,,,,,,,,,,,,,,, -6400,0.035025153,0.03859377,,,,,,,,,,,,,,,,, -6500,0.031437635,0.03775892,,,,,,,,,,,,,,,,, -6600,0.032930415,0.041661453,,,,,,,,,,,,,,,,, -6700,0.03549215,0.04124663,,,,,,,,,,,,,,,,, -6733,,,0.9888186454772948,0.0379299372434616,0.2507952327458457,0.9858505129814148,0.0478081554174423,0.2071134306774288,43793.0,0.9849119186401368,0.0504280105233192,0.2090736040470227,43793.0,2179.482953310013,3616.232548236847,2179.482953310013,1436.303508758545,0.2638471126556396,0.0 -6800,0.038360715,0.04100093,,,,,,,,,,,,,,,,, -6900,0.08944501,0.04292042,,,,,,,,,,,,,,,,, -7000,0.03198601,0.037749555,,,,,,,,,,,,,,,,, -7100,0.03541021,0.04154605,,,,,,,,,,,,,,,,, -7200,0.025069725,0.037201583,,,,,,,,,,,,,,,,, -7300,0.025419034,0.040039964,,,,,,,,,,,,,,,,, -7400,0.025374046,0.03539117,,,,,,,,,,,,,,,,, -7486,,,0.9891477227211,0.0370690897107124,0.2660915521885904,0.9860952496528624,0.046781238168478,0.214461564179151,43793.0,0.9852008819580078,0.0493823438882827,0.2157097935430117,43793.0,2419.6403181552887,3983.702286481857,2419.6403181552887,1563.5682861804962,0.2922539710998535,0.0 -7500,0.030122498,0.040309902,,,,,,,,,,,,,,,,, -7600,0.042592064,0.038266364,,,,,,,,,,,,,,,,, -7700,0.03031111,0.04409063,,,,,,,,,,,,,,,,, -7800,0.032594886,0.04070201,,,,,,,,,,,,,,,,, -7900,0.023572017,0.038016707,,,,,,,,,,,,,,,,, -8000,0.036219858,0.04047473,,,,,,,,,,,,,,,,, -8100,0.039427478,0.043451786,,,,,,,,,,,,,,,,, -8200,0.026689697,0.038640633,,,,,,,,,,,,,,,,, -8237,,,0.9890878200531006,0.036891583353281,0.2597051674587416,0.986197590827942,0.0462719053030014,0.2185389391485515,43793.0,0.985374391078949,0.0486757755279541,0.2213807100207803,43793.0,2659.7981646060944,4350.627862453461,2659.7981646060944,1690.2876436710358,0.3210947513580322,0.0 -8300,0.02423164,0.04106651,,,,,,,,,,,,,,,,, -8400,0.022832373,0.03911947,,,,,,,,,,,,,,,,, -8500,0.031816054,0.037111655,,,,,,,,,,,,,,,,, -8600,0.035238873,0.039850753,,,,,,,,,,,,,,,,, -8700,0.024209572,0.039283462,,,,,,,,,,,,,,,,, -8800,0.017225107,0.036626466,,,,,,,,,,,,,,,,, -8900,0.029115997,0.04392107,,,,,,,,,,,,,,,,, -8975,,,0.9893003106117249,0.0360952019691467,0.2744277171019114,0.9861196279525756,0.0465352907776832,0.2194612766953757,43793.0,0.9852442741394044,0.0491074994206428,0.2204789272640459,43793.0,2899.909004926681,4717.839361667633,2899.909004926681,1817.3364193439484,0.3503391742706299,0.0 -9000,0.029303398,0.03879161,,,,,,,,,,,,,,,,, -9100,0.031545985,0.039469387,,,,,,,,,,,,,,,,, -9200,0.022464614,0.03448655,,,,,,,,,,,,,,,,, -9300,0.018298825,0.03831052,,,,,,,,,,,,,,,,, -9400,0.034906283,0.03640238,,,,,,,,,,,,,,,,, -9500,0.02107196,0.033252243,,,,,,,,,,,,,,,,, -9600,0.031812925,0.03876495,,,,,,,,,,,,,,,,, -9700,0.03760027,0.03869984,,,,,,,,,,,,,,,,, -9731,,,0.989581823348999,0.0353526175022125,0.3135008539187511,0.9864249229431152,0.0454021021723747,0.2378901041038153,43793.0,0.985558032989502,0.0480080991983413,0.2356168489598163,43793.0,3139.976670026779,5083.68940782547,3139.976670026779,1943.068927526474,0.3804106712341308,0.0 -9800,0.045728154,0.038212314,,,,,,,,,,,,,,,,, -9900,0.026788196,0.038473982,,,,,,,,,,,,,,,,, -10000,0.033190094,0.03729832,,,,,,,,,,,,,,,,, -10100,0.02506898,0.040601373,,,,,,,,,,,,,,,,, -10200,0.023108916,0.035385534,,,,,,,,,,,,,,,,, -10300,0.025456086,0.040764075,,,,,,,,,,,,,,,,, -10400,0.03727146,0.037253264,,,,,,,,,,,,,,,,, -10489,,,0.9896994233131408,0.0344877429306507,0.3108126612908478,0.9863623976707458,0.0455813333392143,0.2351091964849687,43793.0,0.9854817986488342,0.0480877570807933,0.2358211077402679,43793.0,3380.1543962955475,5455.0981414318085,3380.1543962955475,2074.2522208690643,0.4082696437835693,0.0 -10500,0.02941923,0.03744262,,,,,,,,,,,,,,,,, -10600,0.036601003,0.038591076,,,,,,,,,,,,,,,,, -10700,0.04179584,0.03866215,,,,,,,,,,,,,,,,, -10800,0.04797486,0.03613354,,,,,,,,,,,,,,,,, -10900,0.03911008,0.041262582,,,,,,,,,,,,,,,,, -11000,0.025778307,0.037533615,,,,,,,,,,,,,,,,, -11100,0.03154294,0.036586452,,,,,,,,,,,,,,,,, -11200,0.029595366,0.03849004,,,,,,,,,,,,,,,,, -11242,,,0.9899011850357056,0.0337928086519241,0.3420008787401611,0.986520290374756,0.0452591404318809,0.2445117594369473,43793.0,0.985623300075531,0.0481476299464702,0.2396074151400248,43793.0,3620.235187292099,5823.134348869324,3620.235187292099,2202.160078048706,0.4365499019622803,0.0 -11300,0.042491365,0.038958963,,,,,,,,,,,,,,,,, -11400,0.027545871,0.03891201,,,,,,,,,,,,,,,,, -11500,0.027324734,0.03747196,,,,,,,,,,,,,,,,, -11600,0.026407106,0.037020113,,,,,,,,,,,,,,,,, -11700,0.028164495,0.03434518,,,,,,,,,,,,,,,,, -11800,0.032150194,0.03541308,,,,,,,,,,,,,,,,, -11900,0.04376963,0.039361976,,,,,,,,,,,,,,,,, -11989,,,0.990170955657959,0.0327141284942626,0.3666973354664275,0.9865543842315674,0.0449664629995822,0.2531282590109368,43793.0,0.9856700897216796,0.0474320091307163,0.2479243786116264,43793.0,3860.454726219177,6193.287379264832,3860.454726219177,2332.0449130535126,0.4658217430114746,0.0 -12000,0.030512145,0.03634727,,,,,,,,,,,,,,,,, -12100,0.030956697,0.037778754,,,,,,,,,,,,,,,,, -12200,0.02847434,0.035161488,,,,,,,,,,,,,,,,, -12300,0.029849783,0.032794118,,,,,,,,,,,,,,,,, -12400,0.03026678,0.03456594,,,,,,,,,,,,,,,,, -12500,0.029863823,0.03443944,,,,,,,,,,,,,,,,, -12600,0.033363283,0.036941957,,,,,,,,,,,,,,,,, -12700,0.029781852,0.035236873,,,,,,,,,,,,,,,,, -12739,,,0.990369439125061,0.0322626382112503,0.3685053446267267,0.9865767359733582,0.0446254387497901,0.246913705823616,43793.0,0.985740840435028,0.0472175404429435,0.2425868423137374,43793.0,4100.5837614536285,6559.720872402191,4100.5837614536285,2458.299390316009,0.495845079421997,0.0 -12800,0.038656037,0.038329355,,,,,,,,,,,,,,,,, -12900,0.04733078,0.035707153,,,,,,,,,,,,,,,,, -13000,0.029587995,0.03904109,,,,,,,,,,,,,,,,, -13100,0.033588655,0.038581498,,,,,,,,,,,,,,,,, -13200,0.07277448,0.041485623,,,,,,,,,,,,,,,,, -13300,0.0635401,0.03478141,,,,,,,,,,,,,,,,, -13400,0.047107205,0.039481197,,,,,,,,,,,,,,,,, -13486,,,0.9902468919754028,0.0322908461093902,0.3681962166742015,0.9864902496337892,0.0449402891099453,0.2522724175171478,43793.0,0.985637664794922,0.0477154850959777,0.2470636721113114,43793.0,4340.619060993195,6927.721772909164,4340.619060993195,2586.215543985367,0.5247492790222168,0.0 -13500,0.032419965,0.03245461,,,,,,,,,,,,,,,,, -13600,0.03298479,0.03638411,,,,,,,,,,,,,,,,, -13700,0.038611926,0.037503194,,,,,,,,,,,,,,,,, -13800,0.032892905,0.036713287,,,,,,,,,,,,,,,,, -13900,0.045932468,0.037698273,,,,,,,,,,,,,,,,, -14000,0.05569323,0.040713042,,,,,,,,,,,,,,,,, -14100,0.035285503,0.039188657,,,,,,,,,,,,,,,,, -14200,0.057720724,0.035410743,,,,,,,,,,,,,,,,, -14241,,,0.99042010307312,0.0318863168358802,0.3612297607359896,0.9867191910743712,0.0443770363926887,0.2567491818925974,43793.0,0.9858179092407228,0.0471119470894336,0.2473383380424254,43793.0,4580.62403678894,7294.715516328812,4580.62403678894,2713.153738975525,0.5549750328063965,0.0 -14300,0.04433599,0.034010243,,,,,,,,,,,,,,,,, -14400,0.042333916,0.034812327,,,,,,,,,,,,,,,,, -14500,0.04049614,0.03622252,,,,,,,,,,,,,,,,, -14600,0.04277096,0.035152294,,,,,,,,,,,,,,,,, -14700,0.08782277,0.03384873,,,,,,,,,,,,,,,,, -14800,0.06211911,0.037494637,,,,,,,,,,,,,,,,, -14900,0.041626807,0.032720327,,,,,,,,,,,,,,,,, -14995,,,0.9904540181159972,0.0318196974694728,0.3714848644505445,0.9867293238639832,0.0442788004875183,0.2621776228683085,43793.0,0.985866367816925,0.0470742098987102,0.2501931113658026,43793.0,4820.717617750168,7660.854436397552,4820.717617750168,2839.149816274643,0.584143877029419,0.0 -15000,0.053921364,0.036718346,,,,,,,,,,,,,,,,, -15100,0.043400828,0.032491222,,,,,,,,,,,,,,,,, -15200,0.040827222,0.03518444,,,,,,,,,,,,,,,,, -15300,0.05693308,0.037701927,,,,,,,,,,,,,,,,, -15400,0.049539745,0.038847353,,,,,,,,,,,,,,,,, -15500,0.0633358,0.03530064,,,,,,,,,,,,,,,,, -15600,0.04599061,0.031869397,,,,,,,,,,,,,,,,, -15700,0.05490261,0.03276371,,,,,,,,,,,,,,,,, -15738,,,0.9903156161308287,0.0321997068822383,0.3752334703336287,0.9866416454315186,0.0446469374001026,0.2536723963678569,43793.0,0.9857829809188844,0.0475221239030361,0.2503552406650511,43793.0,5060.922473907471,8031.429186582565,5060.922473907471,2969.4709827899933,0.6131572723388672,0.0 -15800,0.07392004,0.03467152,,,,,,,,,,,,,,,,, -15900,0.04982798,0.03749546,,,,,,,,,,,,,,,,, -16000,0.049135998,0.034605835,,,,,,,,,,,,,,,,, -16100,0.049847517,0.03438525,,,,,,,,,,,,,,,,, -16200,0.0626836,0.035800993,,,,,,,,,,,,,,,,, -16300,0.08590935,0.031159928,,,,,,,,,,,,,,,,, -16400,0.04816251,0.036894806,,,,,,,,,,,,,,,,, -16488,,,0.9904052019119264,0.0317552462220191,0.364807285871065,0.9867590069770812,0.044371198862791,0.2592793316141011,43793.0,0.9859042763710022,0.0470405742526054,0.2510058219218242,43793.0,5301.190806150436,8401.024397611618,5301.190806150436,3098.746400117874,0.6441349983215332,0.0 -16500,0.06962712,0.037788372,,,,,,,,,,,,,,,,, -16600,0.050404582,0.035602815,,,,,,,,,,,,,,,,, -16700,0.05134923,0.03456368,,,,,,,,,,,,,,,,, -16800,0.046725217,0.0348847,,,,,,,,,,,,,,,,, -16900,0.076232955,0.034630436,,,,,,,,,,,,,,,,, -17000,0.06599627,0.03522124,,,,,,,,,,,,,,,,, -17100,0.049447738,0.034175705,,,,,,,,,,,,,,,,, -17200,0.10380318,0.03401317,,,,,,,,,,,,,,,,, -17237,,,0.9905096292495728,0.0313777886331081,0.3839918210457591,0.986682653427124,0.0444314666092395,0.260249022031754,43793.0,0.9857720136642456,0.0469924472272396,0.2551105020349281,43793.0,5541.426566362381,8768.900849103928,5541.426566362381,3226.335287809372,0.6756045818328857,0.0 -17300,0.0768527,0.03451787,,,,,,,,,,,,,,,,, -17400,0.0629925,0.034664255,,,,,,,,,,,,,,,,, -17500,0.04914436,0.03625905,,,,,,,,,,,,,,,,, -17600,0.067277476,0.034555074,,,,,,,,,,,,,,,,, -17700,0.116099305,0.033746317,,,,,,,,,,,,,,,,, -17800,0.05313307,0.036645334,,,,,,,,,,,,,,,,, -17900,0.10209897,0.032236237,,,,,,,,,,,,,,,,, -17983,,,0.9906891584396362,0.030599458143115,0.4047101758861847,0.9867675304412842,0.0443162843585014,0.2684684903382087,43793.0,0.9858996272087096,0.0470846034586429,0.2526331454065066,43793.0,5781.60738492012,9138.323543548584,5781.60738492012,3355.5258157253265,0.706317663192749,0.0 -18000,0.064338,0.038786806,,,,,,,,,,,,,,,,, -18100,0.06526654,0.0328959,,,,,,,,,,,,,,,,, -18200,0.080023594,0.033194084,,,,,,,,,,,,,,,,, -18300,0.077117585,0.036979724,,,,,,,,,,,,,,,,, -18400,0.08634717,0.034242306,,,,,,,,,,,,,,,,, -18500,0.048100814,0.035610884,,,,,,,,,,,,,,,,, -18600,0.06863055,0.036221325,,,,,,,,,,,,,,,,, -18700,0.056741603,0.035159606,,,,,,,,,,,,,,,,, -18719,,,0.9908905029296876,0.0300750564783811,0.4239957922232233,0.9867285490036012,0.04435645788908,0.2595765740247147,43793.0,0.985871434211731,0.0470295920968055,0.2514217514588064,43793.0,6021.643416404724,9508.266571044922,6021.643416404724,3485.375589132309,0.7408199310302734,0.0 -18800,0.0616783,0.03419317,,,,,,,,,,,,,,,,, -18900,0.068189494,0.037374493,,,,,,,,,,,,,,,,, -19000,0.06424454,0.038886726,,,,,,,,,,,,,,,,, -19100,0.065480895,0.035427667,,,,,,,,,,,,,,,,, -19200,0.072973035,0.034588177,,,,,,,,,,,,,,,,, -19300,0.08521106,0.037141956,,,,,,,,,,,,,,,,, -19400,0.07849642,0.035817817,,,,,,,,,,,,,,,,, -19464,,,0.991010844707489,0.0295705664902925,0.4344861831450195,0.9867455959320068,0.0441209748387336,0.2650249423300213,43793.0,0.9858448505401612,0.046876560896635,0.2534075491441924,43793.0,6261.605680465698,9874.90622472763,6261.605680465698,3611.999095201492,0.7735686302185059,0.0 -19500,0.05494177,0.034228086,,,,,,,,,,,,,,,,, -19600,0.053782817,0.034499522,,,,,,,,,,,,,,,,, -19700,0.06476184,0.034037933,,,,,,,,,,,,,,,,, -19800,0.049515743,0.0331445,,,,,,,,,,,,,,,,, -19900,0.07096027,0.030787557,,,,,,,,,,,,,,,,, -20000,0.055938385,0.03539789,,,,,,,,,,,,,,,,, -20100,0.091239676,0.03434177,,,,,,,,,,,,,,,,, -20200,0.06631746,0.032424178,,,,,,,,,,,,,,,,, -20208,,,0.9910753965377808,0.0293495748192071,0.437858584351662,0.9867947101593018,0.0442087724804878,0.269595815018364,43793.0,0.9859581589698792,0.0471153147518634,0.2562529002068713,43793.0,6501.562925338745,10242.721267700195,6501.562925338745,3739.80638551712,0.8040339946746826,0.0 -20300,0.115082905,0.034635477,,,,,,,,,,,,,,,,, -20400,0.08782323,0.031525787,,,,,,,,,,,,,,,,, -20500,0.08660005,0.035236053,,,,,,,,,,,,,,,,, -20600,0.07873971,0.03540641,,,,,,,,,,,,,,,,, -20700,0.054209035,0.031070469,,,,,,,,,,,,,,,,, -20800,0.07182883,0.030869922,,,,,,,,,,,,,,,,, -20900,0.0569634,0.035978638,,,,,,,,,,,,,,,,, -20958,,,0.9908106327056884,0.0300772711634635,0.407418383224706,0.98684424161911,0.0444803908467292,0.266787557603906,43793.0,0.9859585762023926,0.0475255995988845,0.2523169835991808,43793.0,6741.627146005631,10610.854069948196,6741.627146005631,3867.823861837387,0.8348932266235352,0.0 -21000,0.0660658,0.03382313,,,,,,,,,,,,,,,,, -21100,0.123518825,0.03316639,,,,,,,,,,,,,,,,, -21200,0.068358585,0.033528823,,,,,,,,,,,,,,,,, -21300,0.09692027,0.03364424,,,,,,,,,,,,,,,,, -21400,0.09102709,0.0343134,,,,,,,,,,,,,,,,, -21500,0.06607273,0.031698413,,,,,,,,,,,,,,,,, -21600,0.05899911,0.034073178,,,,,,,,,,,,,,,,, -21690,,,0.9906601905822754,0.0309933628886938,0.4074115877835104,0.9866725206375122,0.0447343923151493,0.2620614710724585,43793.0,0.9858292937278748,0.0474737659096717,0.2580787181285667,43793.0,6981.642642021179,10978.547773361206,6981.642642021179,3995.446721315384,0.8683032989501953,0.0 -21700,0.08204989,0.031753194,,,,,,,,,,,,,,,,, -21800,0.070195116,0.03189263,,,,,,,,,,,,,,,,, -21900,0.06999686,0.03803173,,,,,,,,,,,,,,,,, -22000,0.06571137,0.03389273,,,,,,,,,,,,,,,,, -22100,0.082696944,0.033405866,,,,,,,,,,,,,,,,, -22200,0.06783496,0.032347787,,,,,,,,,,,,,,,,, -22300,0.08497985,0.03557359,,,,,,,,,,,,,,,,, -22400,0.089746974,0.031044848,,,,,,,,,,,,,,,,, -22438,,,0.9909005165100098,0.0299562234431505,0.3997272030494833,0.986880362033844,0.0442997403442859,0.2655132926619737,43793.0,0.9859687089920044,0.0470192581415176,0.257428460130306,43793.0,7221.879897594452,11349.581412315369,7221.879897594452,4126.191339492798,0.8996272087097168,0.0 -22500,0.061830886,0.030982528,,,,,,,,,,,,,,,,, -22600,0.064070955,0.03478029,,,,,,,,,,,,,,,,, -22700,0.07989132,0.031533808,,,,,,,,,,,,,,,,, -22800,0.07868409,0.03162475,,,,,,,,,,,,,,,,, -22900,0.065043345,0.03390367,,,,,,,,,,,,,,,,, -23000,0.07474102,0.0336213,,,,,,,,,,,,,,,,, -23100,0.08418895,0.028843934,,,,,,,,,,,,,,,,, -23191,,,0.990909993648529,0.0301405116915702,0.415039702271021,0.9867057800292968,0.0441821590065956,0.2628321328663399,43793.0,0.985948085784912,0.0467676855623722,0.2523688362908369,43793.0,7462.123221635818,11715.854789495468,7462.123221635818,4252.171051979065,0.9298944473266602,0.0 -23200,0.066979446,0.030688364,,,,,,,,,,,,,,,,, -23300,0.07745505,0.038459834,,,,,,,,,,,,,,,,, -23400,0.071810074,0.03530791,,,,,,,,,,,,,,,,, -23500,0.06736394,0.032535743,,,,,,,,,,,,,,,,, -23600,0.10089043,0.03233871,,,,,,,,,,,,,,,,, -23700,0.081907555,0.03233679,,,,,,,,,,,,,,,,, -23800,0.08641487,0.038387697,,,,,,,,,,,,,,,,, -23900,0.06327424,0.034397766,,,,,,,,,,,,,,,,, -23939,,,0.9909388422966005,0.0295868385583162,0.4346094288609707,0.9869027137756348,0.0441179126501083,0.2732023445683491,43793.0,0.9859236478805542,0.0469748564064502,0.2573575046521077,43793.0,7701.862661600113,12085.510291099548,7701.862661600113,4381.744997501373,1.251793384552002,0.0 -24000,0.08888704,0.034757044,,,,,,,,,,,,,,,,, -24100,0.0700341,0.030316485,,,,,,,,,,,,,,,,, -24200,0.0836016,0.03190754,,,,,,,,,,,,,,,,, -24300,0.071081564,0.033701334,,,,,,,,,,,,,,,,, -24400,0.08319938,0.0365274,,,,,,,,,,,,,,,,, -24500,0.10334504,0.034369983,,,,,,,,,,,,,,,,, -24600,0.073731825,0.03429822,,,,,,,,,,,,,,,,, -24693,,,0.991054356098175,0.0292239245027303,0.4274887319021707,0.986707866191864,0.0442143566906452,0.2662884585720443,43793.0,0.9858819246292114,0.0468236804008483,0.2541363107731255,43793.0,7941.946875095367,12450.460037469864,7941.946875095367,4506.559768199921,1.282576322555542,0.0 -24700,0.070066966,0.033398204,,,,,,,,,,,,,,,,, -24800,0.0755505,0.034422465,,,,,,,,,,,,,,,,, -24900,0.07008783,0.031865228,,,,,,,,,,,,,,,,, -25000,0.08412206,0.034249447,,,,,,,,,,,,,,,,, -25100,0.06794579,0.03252431,,,,,,,,,,,,,,,,, -25200,0.07819382,0.033850484,,,,,,,,,,,,,,,,, -25300,0.06806038,0.037711866,,,,,,,,,,,,,,,,, -25400,0.10933174,0.0343974,,,,,,,,,,,,,,,,, -25443,,,0.9912801384925842,0.0285352878272533,0.4508424274078592,0.9867947101593018,0.0443945862352848,0.2647947449786009,43793.0,0.98597252368927,0.0471655651926994,0.2596896627611287,43793.0,8181.9895396232605,12817.27622628212,8181.9895396232605,4633.280071258545,1.3154983520507812,0.0 -25500,0.11253364,0.03482939,,,,,,,,,,,,,,,,, -25600,0.06891014,0.03438352,,,,,,,,,,,,,,,,, -25700,0.14413019,0.035176497,,,,,,,,,,,,,,,,, -25800,0.077684365,0.030685125,,,,,,,,,,,,,,,,, -25900,0.07526737,0.0314453,,,,,,,,,,,,,,,,, -26000,0.0783386,0.032736752,,,,,,,,,,,,,,,,, -26100,0.10349536,0.038441442,,,,,,,,,,,,,,,,, -26189,,,0.9914413094520568,0.0279674343764781,0.4703452705316486,0.9867812991142272,0.0445206128060817,0.2657041952537768,43793.0,0.9858335256576538,0.0473558753728866,0.2510148645408534,43793.0,8421.981696367264,13182.20546245575,8421.981696367264,4758.165808677673,1.3469769954681396,0.0 -26200,0.0725387,0.030789325,,,,,,,,,,,,,,,,, -26300,0.07343643,0.034533117,,,,,,,,,,,,,,,,, -26400,0.089327104,0.035176992,,,,,,,,,,,,,,,,, -26500,0.06852932,0.03526473,,,,,,,,,,,,,,,,, -26600,0.106627114,0.034386143,,,,,,,,,,,,,,,,, -26700,0.10706522,0.03151283,,,,,,,,,,,,,,,,, -26800,0.08171247,0.035055947,,,,,,,,,,,,,,,,, -26900,0.08352305,0.0326369,,,,,,,,,,,,,,,,, -26938,,,0.9915258884429932,0.0276505500078201,0.4760765050180133,0.9869157075881958,0.0446074083447456,0.2645389905862119,43793.0,0.9859678745269777,0.047476228326559,0.2551753097398039,43793.0,8662.139254808426,13551.315371513369,8662.139254808426,4887.066943645477,1.378154993057251,0.0 -27000,0.080667146,0.03138927,,,,,,,,,,,,,,,,, -27100,0.1229984,0.034403026,,,,,,,,,,,,,,,,, -27200,0.07980911,0.03310474,,,,,,,,,,,,,,,,, -27300,0.09459659,0.03022108,,,,,,,,,,,,,,,,, -27400,0.11438281,0.039023705,,,,,,,,,,,,,,,,, -27500,0.068119995,0.03254544,,,,,,,,,,,,,,,,, -27600,0.10076305,0.031040488,,,,,,,,,,,,,,,,, -27690,,,0.991478443145752,0.0281016025692224,0.4671701164093043,0.9867671132087708,0.0443702042102813,0.2653502731165771,43793.0,0.9857825636863708,0.0472225956618785,0.2515572262223682,43793.0,8902.194403409958,13919.763410568235,8902.194403409958,5015.407616376877,1.4100637435913086,0.0 -27700,0.064910546,0.02926849,,,,,,,,,,,,,,,,, -27800,0.09111166,0.034457784,,,,,,,,,,,,,,,,, -27900,0.08321296,0.03628612,,,,,,,,,,,,,,,,, -28000,0.16096373,0.032907356,,,,,,,,,,,,,,,,, -28100,0.085212715,0.03028726,,,,,,,,,,,,,,,,, -28200,0.063331835,0.030736927,,,,,,,,,,,,,,,,, -28300,0.091087475,0.03110885,,,,,,,,,,,,,,,,, -28400,0.07143312,0.033061422,,,,,,,,,,,,,,,,, -28425,,,0.9911364912986756,0.0287772286683321,0.4375929169048969,0.9868263602256776,0.0445405580103397,0.2722830766897607,43793.0,0.9859636425971984,0.0474299751222133,0.2580036411863702,43793.0,9142.177165269852,14288.520901203156,9142.177165269852,5144.1230499744415,1.4474318027496338,0.0 -28500,0.09290696,0.03357845,,,,,,,,,,,,,,,,, -28600,0.096227765,0.03538663,,,,,,,,,,,,,,,,, -28700,0.0843358,0.03499547,,,,,,,,,,,,,,,,, -28800,0.06644488,0.033069007,,,,,,,,,,,,,,,,, -28900,0.07889028,0.03135541,,,,,,,,,,,,,,,,, -29000,0.078933746,0.033575676,,,,,,,,,,,,,,,,, -29100,0.08341061,0.030213965,,,,,,,,,,,,,,,,, -29159,,,0.9911613464355468,0.028942160308361,0.4474896226765509,0.986789047718048,0.0447144210338592,0.2646365791157178,43793.0,0.985961139202118,0.0474254563450813,0.2587341616032605,43793.0,9382.21336388588,14656.361695289612,9382.21336388588,5271.874435424805,1.480443239212036,0.0 -29200,0.09225022,0.034224562,,,,,,,,,,,,,,,,, -29300,0.0725693,0.032336924,,,,,,,,,,,,,,,,, -29400,0.10284584,0.030135015,,,,,,,,,,,,,,,,, -29500,0.08447511,0.033300806,,,,,,,,,,,,,,,,, -29600,0.080839746,0.03340416,,,,,,,,,,,,,,,,, -29700,0.106140286,0.03455115,,,,,,,,,,,,,,,,, -29800,0.08646029,0.033100307,,,,,,,,,,,,,,,,, -29899,,,0.9912360906600952,0.0286519527435302,0.4469355000732045,0.9867979288101196,0.0446059815585613,0.2667757483159876,43793.0,0.9859535694122314,0.0473793819546699,0.2597263787708996,43793.0,9622.341174840927,15026.982689142227,9622.341174840927,5402.314550876617,1.512584209442139,0.0 -29900,0.08149136,0.033147044,,,,,,,,,,,,,,,,, -30000,0.08728776,0.03269684,,,,,,,,,,,,,,,,, -30100,0.07837398,0.032072984,,,,,,,,,,,,,,,,, -30200,0.10764598,0.031234423,,,,,,,,,,,,,,,,, -30300,0.08041379,0.033197902,,,,,,,,,,,,,,,,, -30400,0.08230675,0.03416159,,,,,,,,,,,,,,,,, -30500,0.084663704,0.03466215,,,,,,,,,,,,,,,,, -30600,0.08912704,0.031330947,,,,,,,,,,,,,,,,, -30634,,,0.9913098812103271,0.0283946581184864,0.4532779013963894,0.9868007898330688,0.0447873137891292,0.2689093160310568,43793.0,0.9859750270843506,0.0477339103817939,0.2601599084649045,43793.0,9862.550683259964,15398.330970048904,9862.550683259964,5533.393811702728,1.550180196762085,0.0 -30700,0.09240375,0.032594964,,,,,,,,,,,,,,,,, -30800,0.07200682,0.032544978,,,,,,,,,,,,,,,,, -30900,0.08925762,0.03367824,,,,,,,,,,,,,,,,, -31000,0.063243344,0.029395618,,,,,,,,,,,,,,,,, -31100,0.081674695,0.03430771,,,,,,,,,,,,,,,,, -31200,0.109379016,0.029194396,,,,,,,,,,,,,,,,, -31300,0.089206696,0.033214327,,,,,,,,,,,,,,,,, -31379,,,0.9914488196372986,0.0278518609702587,0.4698440030144065,0.9867720007896424,0.04453831538558,0.2701344358372728,43793.0,0.985964059829712,0.0474789142608642,0.2624128359043144,43793.0,10102.536264419556,15764.439517259598,10102.536264419556,5659.4643721580505,1.5825748443603516,0.0 -31400,0.0766067,0.03138326,,,,,,,,,,,,,,,,, -31500,0.08180464,0.03149603,,,,,,,,,,,,,,,,, -31600,0.08363752,0.031244153,,,,,,,,,,,,,,,,, -31700,0.10276382,0.031902492,,,,,,,,,,,,,,,,, -31800,0.06744162,0.029792463,,,,,,,,,,,,,,,,, -31900,0.08483168,0.029619677,,,,,,,,,,,,,,,,, -32000,0.10968376,0.031750347,,,,,,,,,,,,,,,,, -32100,0.07240908,0.030519389,,,,,,,,,,,,,,,,, -32124,,,0.9916564226150512,0.027042893692851,0.4944718535620369,0.9868767261505128,0.0442681089043617,0.271961120654364,43793.0,0.9859771132469176,0.0472414530813694,0.2615611187014712,43793.0,10342.64687371254,16131.504002094269,10342.64687371254,5786.365542411804,1.615133285522461,0.0 -32200,0.070697755,0.031480595,,,,,,,,,,,,,,,,, -32300,0.0909792,0.03242934,,,,,,,,,,,,,,,,, -32400,0.08378321,0.03154891,,,,,,,,,,,,,,,,, -32500,0.11714117,0.031043146,,,,,,,,,,,,,,,,, -32600,0.09512837,0.032586075,,,,,,,,,,,,,,,,, -32700,0.10260408,0.03583373,,,,,,,,,,,,,,,,, -32800,0.10194074,0.03165154,,,,,,,,,,,,,,,,, -32856,,,0.99176424741745,0.0266096200793981,0.498826540416365,0.9868531823158264,0.0443659499287605,0.2692807413861312,43793.0,0.9860023856163024,0.0472645424306392,0.2620336225403224,43793.0,10582.716272592545,16494.133712530136,10582.716272592545,5908.865457057953,1.652513027191162,0.0 -32900,0.14883777,0.030562017,,,,,,,,,,,,,,,,, -33000,0.07860897,0.031774584,,,,,,,,,,,,,,,,, -33100,0.07021263,0.032707956,,,,,,,,,,,,,,,,, -33200,0.08695379,0.030673146,,,,,,,,,,,,,,,,, -33300,0.08044098,0.0321561,,,,,,,,,,,,,,,,, -33400,0.124048136,0.03200121,,,,,,,,,,,,,,,,, -33500,0.086046055,0.033515364,,,,,,,,,,,,,,,,, -33595,,,0.9918482303619384,0.0264681000262498,0.5026169864173307,0.986846685409546,0.0450812652707099,0.2694701894953767,43793.0,0.9859889149665833,0.0477740578353405,0.2634912505782247,43793.0,10822.800383806229,16862.094605207443,10822.800383806229,6036.689339399338,1.685504913330078,0.0 -33600,0.17477018,0.034570184,,,,,,,,,,,,,,,,, -33700,0.1171869,0.03032848,,,,,,,,,,,,,,,,, -33800,0.1823887,0.028929764,,,,,,,,,,,,,,,,, -33900,0.083970755,0.033438347,,,,,,,,,,,,,,,,, -34000,0.08450609,0.03134706,,,,,,,,,,,,,,,,, -34100,0.10473918,0.037417565,,,,,,,,,,,,,,,,, -34200,0.08419283,0.029636344,,,,,,,,,,,,,,,,, -34300,0.090232834,0.03335523,,,,,,,,,,,,,,,,, -34335,,,0.991648256778717,0.0270575191825628,0.4754191822452325,0.9869412779808044,0.0447731092572212,0.2740292724154801,43793.0,0.9860196709632874,0.0474854782223701,0.2597895271918095,43793.0,11062.89523434639,17230.688241004944,11062.89523434639,6165.134907007217,1.7191433906555176,0.0 -34400,0.09176027,0.03587543,,,,,,,,,,,,,,,,, -34500,0.07910303,0.031031743,,,,,,,,,,,,,,,,, -34600,0.08026281,0.029569468,,,,,,,,,,,,,,,,, -34700,0.097953945,0.033393066,,,,,,,,,,,,,,,,, -34800,0.07733253,0.03099064,,,,,,,,,,,,,,,,, -34900,0.13094994,0.033814747,,,,,,,,,,,,,,,,, -35000,0.11300431,0.032870356,,,,,,,,,,,,,,,,, -35083,,,0.9914780855178832,0.0276719629764556,0.4771699893653773,0.986946940422058,0.0444764271378517,0.2762864765832121,43793.0,0.9860441088676452,0.047148123383522,0.2659752052189602,43793.0,11302.936974048616,17599.211398363113,11302.936974048616,6293.562938928604,1.7529196739196775,0.0 -35100,0.101629846,0.035524063,,,,,,,,,,,,,,,,, -35200,0.15240365,0.030960063,,,,,,,,,,,,,,,,, -35300,0.07605813,0.030979777,,,,,,,,,,,,,,,,, -35400,0.13987996,0.03399782,,,,,,,,,,,,,,,,, -35500,0.09135259,0.031035887,,,,,,,,,,,,,,,,, -35600,0.09080369,0.031125689,,,,,,,,,,,,,,,,, -35700,0.104610726,0.031645775,,,,,,,,,,,,,,,,, -35800,0.10166012,0.031351347,,,,,,,,,,,,,,,,, -35831,,,0.9915109276771544,0.0274749193340539,0.4855648311323703,0.9869274497032166,0.0445232838392257,0.2701972526758367,43793.0,0.9860904216766356,0.0475407242774963,0.2623007088811127,43793.0,11543.209511995316,17968.17503976822,11543.209511995316,6422.19800400734,1.7894747257232666,0.0 -35900,0.096000604,0.03191204,,,,,,,,,,,,,,,,, -36000,0.10839194,0.032625504,,,,,,,,,,,,,,,,, -36100,0.08484171,0.030729463,,,,,,,,,,,,,,,,, -36200,0.139062,0.029416438,,,,,,,,,,,,,,,,, -36300,0.15635656,0.03303199,,,,,,,,,,,,,,,,, -36400,0.07549833,0.031147249,,,,,,,,,,,,,,,,, -36500,0.10482102,0.03409765,,,,,,,,,,,,,,,,, -36577,,,0.9916188716888428,0.0272390656173229,0.4803275266367559,0.9868852496147156,0.044582299888134,0.271263958619182,43793.0,0.9860007166862488,0.0474309548735618,0.2636249846278241,43793.0,11783.2023396492,18335.44613981247,11783.2023396492,6549.4206647872925,1.8252544403076167,0.0 -36600,0.09466018,0.031312954,,,,,,,,,,,,,,,,, -36700,0.07785302,0.029670777,,,,,,,,,,,,,,,,, -36800,0.08964187,0.03084351,,,,,,,,,,,,,,,,, -36900,0.08820485,0.032113563,,,,,,,,,,,,,,,,, -37000,0.10829258,0.030548116,,,,,,,,,,,,,,,,, -37100,0.103511594,0.029382097,,,,,,,,,,,,,,,,, -37200,0.07865753,0.029939713,,,,,,,,,,,,,,,,, -37300,0.11282072,0.031135553,,,,,,,,,,,,,,,,, -37323,,,0.9916990995407104,0.0268614571541547,0.484263148365716,0.9868641495704652,0.0448390506207942,0.2684500922044856,43793.0,0.9859986305236816,0.0479762442409992,0.2591521938112194,43793.0,12023.32865190506,18702.27978849411,12023.32865190506,6676.074639558792,1.858780860900879,0.0 -37400,0.088719144,0.028805504,,,,,,,,,,,,,,,,, -37500,0.09441365,0.031034468,,,,,,,,,,,,,,,,, -37600,0.077770606,0.029602908,,,,,,,,,,,,,,,,, -37700,0.081249945,0.029080503,,,,,,,,,,,,,,,,, -37800,0.10844507,0.032057747,,,,,,,,,,,,,,,,, -37900,0.09315075,0.029546192,,,,,,,,,,,,,,,,, -38000,0.0864059,0.029094564,,,,,,,,,,,,,,,,, -38072,,,0.9920689463615416,0.0256481822580099,0.516311374792447,0.9868032336235046,0.0446102283895015,0.2661784733789995,43793.0,0.9859472513198853,0.0474219359457492,0.2636671236942603,43793.0,12263.584111452104,19068.74548768997,12263.584111452104,6802.231438398361,1.8928401470184328,0.0 -38100,0.07929104,0.027849816,,,,,,,,,,,,,,,,, -38200,0.10938105,0.029195769,,,,,,,,,,,,,,,,, -38300,0.107096314,0.0333001,,,,,,,,,,,,,,,,, -38400,0.10022392,0.027337484,,,,,,,,,,,,,,,,, -38500,0.10637036,0.030625077,,,,,,,,,,,,,,,,, -38600,0.07785976,0.02879773,,,,,,,,,,,,,,,,, -38700,0.09190394,0.03459153,,,,,,,,,,,,,,,,, -38800,0.07734396,0.03013568,,,,,,,,,,,,,,,,, -38821,,,0.9921510815620422,0.0255514401942491,0.5158998215716353,0.9867951273918152,0.0446099750697612,0.2744577401290294,43793.0,0.9859569072723388,0.0475714839994907,0.2615809471411672,43793.0,12503.6462123394,19435.058240175247,12503.6462123394,6928.428899765015,1.9264566898345947,0.0 -38900,0.09221483,0.031391237,,,,,,,,,,,,,,,,, -39000,0.09289224,0.028473912,,,,,,,,,,,,,,,,, -39100,0.08378374,0.028286584,,,,,,,,,,,,,,,,, -39200,0.11537705,0.031951327,,,,,,,,,,,,,,,,, -39300,0.09414337,0.030360255,,,,,,,,,,,,,,,,, -39400,0.09657024,0.03134174,,,,,,,,,,,,,,,,, -39500,0.09467583,0.030642934,,,,,,,,,,,,,,,,, -39573,,,0.990967333316803,0.0293112080544233,0.4569042325964516,0.986046552658081,0.0476715117692947,0.233260071570542,43793.0,0.9850488305091858,0.0510807670652866,0.2129662973348563,43793.0,12743.860652923584,19803.6591398716,12743.860652923584,7056.762053728104,1.9602704048156736,0.0 -39600,0.15209311,0.0349229,,,,,,,,,,,,,,,,, -39700,0.108951524,0.03349177,,,,,,,,,,,,,,,,, -39800,0.10994055,0.032183543,,,,,,,,,,,,,,,,, -39900,0.11021694,0.031991538,,,,,,,,,,,,,,,,, -40000,0.10318921,0.030037628,,,,,,,,,,,,,,,,, -40100,0.08604337,0.028019631,,,,,,,,,,,,,,,,, -40200,0.09660001,0.027602438,,,,,,,,,,,,,,,,, -40300,0.15351725,0.032396134,,,,,,,,,,,,,,,,, -40320,,,0.9924874901771544,0.0244606956839561,0.5351341203817217,0.9868787527084352,0.0445943921804428,0.2759908697710495,43793.0,0.9859982132911682,0.0479002371430397,0.2639822632255164,43793.0,12983.941451787949,20167.69873785973,12983.941451787949,7180.665674209595,1.9954063892364504,0.0 -40400,0.15976457,0.033582438,,,,,,,,,,,,,,,,, -40500,0.10909921,0.031640906,,,,,,,,,,,,,,,,, -40600,0.09794101,0.030332316,,,,,,,,,,,,,,,,, -40700,0.10681837,0.033005886,,,,,,,,,,,,,,,,, -40800,0.123600386,0.029559463,,,,,,,,,,,,,,,,, -40900,0.13560744,0.031917,,,,,,,,,,,,,,,,, -41000,0.119482316,0.03218828,,,,,,,,,,,,,,,,, -41067,,,0.9923626780509948,0.0249550715088844,0.5233125303946242,0.9868706464767456,0.0443031042814254,0.2764290908345192,43793.0,0.9859969019889832,0.0472699888050556,0.2602967469117153,43793.0,13224.08312869072,20527.777365922928,13224.08312869072,7300.5468554496765,2.0311975479125977,0.0 -41100,0.1047951,0.030249402,,,,,,,,,,,,,,,,, -41200,0.09719651,0.030270899,,,,,,,,,,,,,,,,, -41300,0.09374228,0.031903584,,,,,,,,,,,,,,,,, -41400,0.10241655,0.028069606,,,,,,,,,,,,,,,,, -41500,0.13855273,0.030549685,,,,,,,,,,,,,,,,, -41600,0.07400316,0.027533589,,,,,,,,,,,,,,,,, -41700,0.102760926,0.02968511,,,,,,,,,,,,,,,,, -41800,0.11178782,0.030664044,,,,,,,,,,,,,,,,, -41813,,,0.9921444058418274,0.0252819508314132,0.528187037457823,0.9869668483734132,0.0449664071202278,0.2712547909511459,43793.0,0.9860803484916688,0.0479672290384769,0.2647587286793563,43793.0,13464.266607761385,20893.67218565941,13464.266607761385,7426.20378446579,2.0658912658691406,0.0 -41900,0.10425673,0.032274198,,,,,,,,,,,,,,,,, -42000,0.10159752,0.027215458,,,,,,,,,,,,,,,,, -42100,0.10910818,0.030122317,,,,,,,,,,,,,,,,, -42200,0.1628971,0.03330666,,,,,,,,,,,,,,,,, -42300,0.1059913,0.02994023,,,,,,,,,,,,,,,,, -42400,0.13474867,0.03287345,,,,,,,,,,,,,,,,, -42500,0.09790219,0.028775433,,,,,,,,,,,,,,,,, -42551,,,0.9920623898506165,0.0255641397088766,0.5207061522784611,0.986931085586548,0.0451455265283584,0.2757223796749376,43793.0,0.9861013889312744,0.0479997955262661,0.261924874372922,43793.0,13704.357329368591,21258.0973508358,13704.357329368591,7550.482100486755,2.102457523345948,0.0 -42600,0.117244974,0.02812475,,,,,,,,,,,,,,,,, -42700,0.0996752,0.02723488,,,,,,,,,,,,,,,,, -42800,0.089535,0.029369758,,,,,,,,,,,,,,,,, -42900,0.11990992,0.027116396,,,,,,,,,,,,,,,,, -43000,0.102845006,0.030998183,,,,,,,,,,,,,,,,, -43100,0.10897328,0.030454472,,,,,,,,,,,,,,,,, -43200,0.09429079,0.028816346,,,,,,,,,,,,,,,,, -43293,,,0.9919952154159546,0.0256527718156576,0.5196389960976198,0.9868032336235046,0.0451744943857193,0.2733776448060774,43793.0,0.9859657287597656,0.0482613332569599,0.2562264799589845,43793.0,13944.305476903915,21624.305897712708,13944.305476903915,7676.686524152756,2.1386983394622803,0.0 -43300,0.09784645,0.027059184,,,,,,,,,,,,,,,,, -43400,0.0977768,0.029994665,,,,,,,,,,,,,,,,, -43500,0.10831614,0.028611898,,,,,,,,,,,,,,,,, -43600,0.15617706,0.029856265,,,,,,,,,,,,,,,,, -43700,0.10727521,0.029868782,,,,,,,,,,,,,,,,, -43800,0.10462808,0.027624866,,,,,,,,,,,,,,,,, -43900,0.11097343,0.0310055,,,,,,,,,,,,,,,,, -44000,0.12541057,0.02837362,,,,,,,,,,,,,,,,, -44038,,,0.9922802448272704,0.02483881264925,0.5377617439585939,0.9868596792221068,0.0452494956552982,0.270351017601974,43793.0,0.9860146045684814,0.0481569580733776,0.2617285586172711,43793.0,14184.53102874756,21993.115421056747,14184.53102874756,7805.214061498642,2.1753900051116943,0.0 -44100,0.11386963,0.030487357,,,,,,,,,,,,,,,,, -44200,0.11099659,0.031186705,,,,,,,,,,,,,,,,, -44300,0.108306386,0.030223109,,,,,,,,,,,,,,,,, -44400,0.107464835,0.028258907,,,,,,,,,,,,,,,,, -44500,0.1261724,0.027774777,,,,,,,,,,,,,,,,, -44600,0.1165992,0.029032718,,,,,,,,,,,,,,,,, -44700,0.11430188,0.02861127,,,,,,,,,,,,,,,,, -44777,,,0.9924088716506958,0.0242243651300668,0.5362641543231005,0.9869104027748108,0.0453433506190776,0.2695658229800238,43793.0,0.9860327243804932,0.0481882318854332,0.2616151770568872,43793.0,14424.795090436935,22358.681414604187,14424.795090436935,7930.460354089737,2.211509704589844,0.0 -44800,0.10059612,0.03161227,,,,,,,,,,,,,,,,, -44900,0.115661636,0.029587185,,,,,,,,,,,,,,,,, -45000,0.112475246,0.029725207,,,,,,,,,,,,,,,,, -45100,0.13632606,0.031176614,,,,,,,,,,,,,,,,, -45200,0.10607571,0.030138245,,,,,,,,,,,,,,,,, -45300,0.12751569,0.029586837,,,,,,,,,,,,,,,,, -45400,0.11377879,0.03074294,,,,,,,,,,,,,,,,, -45500,0.11115793,0.030949822,,,,,,,,,,,,,,,,, -45515,,,0.9927688837051392,0.0231147781014442,0.5721043220273573,0.9869270324707032,0.0452565178275108,0.2640619486550432,43793.0,0.986042022705078,0.0484271347522735,0.2579358481884553,43793.0,14664.76844573021,22726.34547829628,14664.76844573021,8058.0960521698,2.247072696685791,0.0 -45600,0.1143194,0.026027584,,,,,,,,,,,,,,,,, -45700,0.14832819,0.028588653,,,,,,,,,,,,,,,,, -45800,0.13386874,0.028872518,,,,,,,,,,,,,,,,, -45900,0.10410111,0.027983243,,,,,,,,,,,,,,,,, -46000,0.13600817,0.027037209,,,,,,,,,,,,,,,,, -46100,0.111327626,0.028262105,,,,,,,,,,,,,,,,, -46200,0.109829456,0.030778252,,,,,,,,,,,,,,,,, -46262,,,0.9930108189582824,0.0223857462406158,0.5915540953232792,0.9868503212928772,0.0455769263207912,0.268653912629021,43793.0,0.985990583896637,0.048694908618927,0.2598243224261893,43793.0,14904.87128996849,23092.4700255394,14904.87128996849,8184.0625557899475,2.28225326538086,0.0 -46300,0.114033595,0.030602101,,,,,,,,,,,,,,,,, -46400,0.12803593,0.028508194,,,,,,,,,,,,,,,,, -46500,0.13280138,0.027566444,,,,,,,,,,,,,,,,, -46600,0.12719427,0.027062906,,,,,,,,,,,,,,,,, -46700,0.12152762,0.029528946,,,,,,,,,,,,,,,,, -46800,0.10859775,0.028452268,,,,,,,,,,,,,,,,, -46900,0.21572305,0.029678017,,,,,,,,,,,,,,,,, -47000,0.096596174,0.025832506,,,,,,,,,,,,,,,,, -47008,,,0.993267297744751,0.0218341443687677,0.6045476022820708,0.986806869506836,0.0458246506750583,0.2718907877934435,43793.0,0.9859114289283752,0.0486891642212867,0.264437696290971,43793.0,15144.95430803299,23455.587432146072,15144.95430803299,8307.041923999786,2.317678451538086,0.0 -47100,0.12702279,0.029075619,,,,,,,,,,,,,,,,, -47200,0.10249064,0.028078478,,,,,,,,,,,,,,,,, -47300,0.10922977,0.026852395,,,,,,,,,,,,,,,,, -47400,0.14524122,0.02894663,,,,,,,,,,,,,,,,, -47500,0.16307816,0.027144019,,,,,,,,,,,,,,,,, -47600,0.113947205,0.026942642,,,,,,,,,,,,,,,,, -47700,0.12183476,0.025723156,,,,,,,,,,,,,,,,, -47748,,,0.992554783821106,0.0238631702959537,0.5605475454067415,0.9868738651275636,0.045866098254919,0.274083811671372,43793.0,0.9859864115715028,0.0491985715925693,0.2576000876595037,43793.0,15385.188156366348,23823.49522161484,15385.188156366348,8434.66095161438,2.353482723236084,0.0 -47800,0.12643337,0.02889456,,,,,,,,,,,,,,,,, -47900,0.13347538,0.027165044,,,,,,,,,,,,,,,,, -48000,0.12336192,0.030216347,,,,,,,,,,,,,,,,, -48100,0.104393795,0.02851454,,,,,,,,,,,,,,,,, -48200,0.16564028,0.0284161,,,,,,,,,,,,,,,,, -48300,0.1115654,0.027966527,,,,,,,,,,,,,,,,, -48400,0.11260345,0.027222417,,,,,,,,,,,,,,,,, -48496,,,0.9926549792289734,0.0236036144196987,0.5488211291481068,0.9868596792221068,0.0460748337209224,0.2715950909297707,43793.0,0.9859628081321716,0.0490183718502521,0.2601319877159591,43793.0,15625.320230960846,24188.14550971985,15625.320230960846,8559.123657226562,2.389343738555908,0.0 -48500,0.12233689,0.029824749,,,,,,,,,,,,,,,,, -48600,0.12568866,0.024538442,,,,,,,,,,,,,,,,, -48700,0.11253831,0.027651876,,,,,,,,,,,,,,,,, -48800,0.11102748,0.025872774,,,,,,,,,,,,,,,,, -48900,0.14059857,0.028620362,,,,,,,,,,,,,,,,, -49000,0.123361394,0.028771628,,,,,,,,,,,,,,,,, -49100,0.1222144,0.028749194,,,,,,,,,,,,,,,,, -49200,0.11581765,0.028735992,,,,,,,,,,,,,,,,, -49243,,,0.9927665591239928,0.0231651198118925,0.5600378176647713,0.9868649244308472,0.0459690317511558,0.2729185746436617,43793.0,0.9859880805015564,0.049083225429058,0.2621291751755614,43793.0,15865.286841392515,24550.717749118805,15865.286841392515,8681.672759532928,2.4259297847747803,0.0 -49300,0.15067905,0.026473185,,,,,,,,,,,,,,,,, -49400,0.13527814,0.029698988,,,,,,,,,,,,,,,,, -49500,0.15133655,0.027269546,,,,,,,,,,,,,,,,, -49600,0.13477889,0.02633241,,,,,,,,,,,,,,,,, -49700,0.17073376,0.028870432,,,,,,,,,,,,,,,,, -49800,0.13443032,0.026214598,,,,,,,,,,,,,,,,, -49900,0.13257271,0.026350712,,,,,,,,,,,,,,,,, -49988,,,0.9928274750709534,0.0229332204908132,0.5803153113657378,0.9868227243423462,0.0465543456375598,0.2713811193405487,43793.0,0.9858924746513368,0.049880214035511,0.2606511181190385,43793.0,16105.240637540815,24909.65825247765,16105.240637540815,8800.601737260818,2.4639928340911865,0.0 -50000,0.11443443,0.022525419,,,,,,,,,,,,,,,,, -50100,0.14529876,0.027489081,,,,,,,,,,,,,,,,, -50200,0.14397402,0.026917672,,,,,,,,,,,,,,,,, -50300,0.13661021,0.026610171,,,,,,,,,,,,,,,,, -50400,0.11172684,0.024958685,,,,,,,,,,,,,,,,, -50500,0.14855224,0.028061364,,,,,,,,,,,,,,,,, -50600,0.14535531,0.026252732,,,,,,,,,,,,,,,,, -50700,0.16331515,0.03155972,,,,,,,,,,,,,,,,, -50738,,,0.992751955986023,0.0230761207640171,0.5672821856545007,0.9868150353431702,0.0462609492242336,0.2708342805840612,43793.0,0.98593670129776,0.0497093610465526,0.2601520911636134,43793.0,16345.393598794935,25270.969877958298,16345.393598794935,8921.704628229141,2.500322103500366,0.0 -50800,0.15642525,0.029131126,,,,,,,,,,,,,,,,, -50900,0.16444781,0.027203633,,,,,,,,,,,,,,,,, -51000,0.12476521,0.026707502,,,,,,,,,,,,,,,,, -51100,0.12490831,0.024420453,,,,,,,,,,,,,,,,, -51200,0.1491842,0.026083343,,,,,,,,,,,,,,,,, -51300,0.12149335,0.026418231,,,,,,,,,,,,,,,,, -51400,0.1513076,0.025787799,,,,,,,,,,,,,,,,, -51489,,,0.9931676983833312,0.0218032449483871,0.5953628427219528,0.9868471026420592,0.0464733131229877,0.2731393930659657,43793.0,0.9858849048614502,0.0498098582029342,0.263006766620887,43793.0,16585.514472723007,25633.673954486847,16585.514472723007,9044.232073783876,2.5364415645599365,0.0 -51500,0.13376728,0.026495388,,,,,,,,,,,,,,,,, -51600,0.12897727,0.025840448,,,,,,,,,,,,,,,,, -51700,0.12835823,0.026608009,,,,,,,,,,,,,,,,, -51800,0.13861676,0.027772522,,,,,,,,,,,,,,,,, -51900,0.17885098,0.029200062,,,,,,,,,,,,,,,,, -52000,0.1360037,0.02459855,,,,,,,,,,,,,,,,, -52100,0.22597094,0.026518999,,,,,,,,,,,,,,,,, -52200,0.13366136,0.026711546,,,,,,,,,,,,,,,,, -52235,,,0.993556797504425,0.0206640735268592,0.6220936529065954,0.9868056774139404,0.0468620583415031,0.2664892358792565,43793.0,0.9858680367469788,0.0501306690275669,0.2569849594604288,43793.0,16825.499007225037,25994.37507820129,16825.499007225037,9164.892379283903,2.573488712310791,0.0 -52300,0.19145173,0.027987743,,,,,,,,,,,,,,,,, -52400,0.16127822,0.024374314,,,,,,,,,,,,,,,,, -52500,0.1605573,0.028339647,,,,,,,,,,,,,,,,, -52600,0.13789162,0.025931448,,,,,,,,,,,,,,,,, -52700,0.17086314,0.027725255,,,,,,,,,,,,,,,,, -52800,0.16316248,0.026880864,,,,,,,,,,,,,,,,, -52900,0.12737699,0.023609633,,,,,,,,,,,,,,,,, -52980,,,0.9937600493431092,0.0198313146829605,0.6459993043755188,0.9867037534713744,0.0474918149411678,0.2685034861718007,43793.0,0.98589289188385,0.0506656877696514,0.2570500295083031,43793.0,17065.48960494995,26357.715396165848,17065.48960494995,9288.184242248535,2.612375974655152,0.0 -53000,0.1519721,0.026174316,,,,,,,,,,,,,,,,, -53100,0.16769545,0.028080046,,,,,,,,,,,,,,,,, -53200,0.15621394,0.027259089,,,,,,,,,,,,,,,,, -53300,0.17117004,0.024577776,,,,,,,,,,,,,,,,, -53400,0.1532315,0.026890604,,,,,,,,,,,,,,,,, -53500,0.14053792,0.02760216,,,,,,,,,,,,,,,,, -53600,0.13360637,0.026700864,,,,,,,,,,,,,,,,, -53700,0.14571644,0.024385814,,,,,,,,,,,,,,,,, -53724,,,0.9939901828765868,0.0193338319659233,0.6436779537032165,0.9866481423377992,0.0475041940808296,0.2681561503925183,43793.0,0.985889494419098,0.0503124706447124,0.261990889494677,43793.0,17305.6275203228,26717.59257531166,17305.6275203228,9407.865612268448,2.6504719257354736,0.0 -53800,0.13666378,0.026934346,,,,,,,,,,,,,,,,, -53900,0.14280006,0.022163957,,,,,,,,,,,,,,,,, -54000,0.17446417,0.025876433,,,,,,,,,,,,,,,,, -54100,0.15594265,0.025621254,,,,,,,,,,,,,,,,, -54200,0.1641202,0.023936871,,,,,,,,,,,,,,,,, -54300,0.15319178,0.026348392,,,,,,,,,,,,,,,,, -54400,0.14684246,0.025148574,,,,,,,,,,,,,,,,, -54469,,,0.9936777949333192,0.0201111733913421,0.6458658038745391,0.9866968989372252,0.047279093414545,0.267442034344366,43793.0,0.9858878254890442,0.0506690517067909,0.2569829867377682,43793.0,17545.65809392929,27082.599598646164,17545.65809392929,9532.784777402878,2.6869962215423584,0.0 -54500,0.16235565,0.02533039,,,,,,,,,,,,,,,,, -54600,0.15745476,0.025624571,,,,,,,,,,,,,,,,, -54700,0.14466715,0.026707591,,,,,,,,,,,,,,,,, -54800,0.1798825,0.024084132,,,,,,,,,,,,,,,,, -54900,0.16503367,0.02363483,,,,,,,,,,,,,,,,, -55000,0.15593617,0.026145037,,,,,,,,,,,,,,,,, -55100,0.1661194,0.02676908,,,,,,,,,,,,,,,,, -55200,0.19195461,0.026382403,,,,,,,,,,,,,,,,, -55211,,,0.9936330318450928,0.0203234907239675,0.6208121445688634,0.9867082238197328,0.0476051606237888,0.2657908474666073,43793.0,0.985811173915863,0.0510495565831661,0.2536814701130161,43793.0,17785.798129081726,27443.96195435524,17785.798129081726,9653.950505495071,2.723959684371948,0.0 -55300,0.19691773,0.025248474,,,,,,,,,,,,,,,,, -55400,0.15107547,0.025323529,,,,,,,,,,,,,,,,, -55500,0.18086885,0.0267466,,,,,,,,,,,,,,,,, -55600,0.1793818,0.02744615,,,,,,,,,,,,,,,,, -55700,0.1673088,0.02623543,,,,,,,,,,,,,,,,, -55800,0.17978728,0.024577873,,,,,,,,,,,,,,,,, -55900,0.17093062,0.027703896,,,,,,,,,,,,,,,,, -55959,,,0.9933901429176332,0.0209429822862148,0.6129876580194846,0.9866664409637452,0.0478049889206886,0.2669865188916774,43793.0,0.9858027696609496,0.0511068813502788,0.2597542658100841,43793.0,18026.053458452225,27810.32054162025,18026.053458452225,9779.996412038803,2.7612576484680176,0.0 -56000,0.16261156,0.024734909,,,,,,,,,,,,,,,,, -56100,0.1606523,0.024407597,,,,,,,,,,,,,,,,, -56200,0.15563586,0.025570739,,,,,,,,,,,,,,,,, -56300,0.20912842,0.023328507,,,,,,,,,,,,,,,,, -56400,0.21190134,0.025924345,,,,,,,,,,,,,,,,, -56500,0.14994107,0.025576156,,,,,,,,,,,,,,,,, -56600,0.17245847,0.024398971,,,,,,,,,,,,,,,,, -56700,0.1568633,0.02345529,,,,,,,,,,,,,,,,, -56710,,,0.9932118058204652,0.0213303677737712,0.594974708618752,0.9866802096366882,0.0485155917704105,0.2651957580429345,43793.0,0.9857563972473145,0.0517811812460422,0.2562075157043629,43793.0,18266.163946151733,28175.351845502853,18266.163946151733,9904.859407424929,2.799586296081543,0.0 -56800,0.1804839,0.02397401,,,,,,,,,,,,,,,,, -56900,0.18922861,0.023331456,,,,,,,,,,,,,,,,, -57000,0.1777507,0.025536044,,,,,,,,,,,,,,,,, -57100,0.17755829,0.024718234,,,,,,,,,,,,,,,,, -57200,0.19211124,0.02658745,,,,,,,,,,,,,,,,, -57300,0.16268106,0.023821607,,,,,,,,,,,,,,,,, -57359,,,,,,,,,,,,,,18477.090780735016,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/eval_measurements.csv deleted file mode 100644 index a745b9543..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -121.37401604652403,0.0,12.159255981445312,1,0,12.159255981445312,0.5224818587303162,0.7161954641342163,0.0278402619301766,43793,133.5333137512207,0.5251287221908569,0.7151197195053101,0.0241442787101899,0.5213832855224609,0.7166012525558472,0.0261539542103122,43793 -247.18216228485107,0.0235548019409179,252.4109787940979,742,0,252.4109787940979,0.983142077922821,0.0757889300584793,0.0436686746645511,43793,499.6362552642822,0.9867380261421204,0.064150258898735,0.0403616680620335,0.9841179251670836,0.0728871896862983,0.0420972196052659,43793 -371.4069876670837,0.0522158145904541,492.4810729026794,1468,0,492.4810729026794,0.9831437468528748,0.0672952905297279,0.0508437648925251,43793,863.9817779064178,0.9867783784866332,0.0536482743918895,0.0473054507702929,0.9841195344924928,0.0639380291104316,0.0484474674040182,43793 -490.7164118289948,0.07963228225708,732.6191091537476,2213,0,732.6191091537476,0.9836630821228028,0.0585633553564548,0.1103763842191628,43793,1223.476276397705,0.9873389005661012,0.0461795665323734,0.1193167657896497,0.984654188156128,0.0554618798196315,0.1120072526561411,43793 -608.784569978714,0.1080617904663086,972.7951905727386,2951,0,972.7951905727386,0.9839431643486024,0.0571308247745037,0.1403339919641544,43793,1581.7689995765686,0.9876325130462646,0.0444084219634532,0.1469359938076572,0.984946072101593,0.0540356822311878,0.1438901259098251,43793 -730.6357326507568,0.1368811130523681,1213.0103611946106,3695,0,1213.0103611946106,0.9841756820678712,0.0548621453344821,0.1504054410516651,43793,1943.883972644806,0.987868309020996,0.0430908612906932,0.1693227572686155,0.98510479927063,0.05209706351161,0.1507358351030023,43793 -847.740592956543,0.1639246940612793,1453.0346467494965,4435,0,1453.0346467494965,0.9844208359718324,0.0528491176664829,0.1704869601797195,43793,2301.0603160858154,0.9882818460464478,0.0408571362495422,0.2046375427604755,0.9853986501693726,0.0501000694930553,0.1703647678159369,43793 -969.2806987762452,0.1926295757293701,1693.2447366714478,5175,0,1693.2447366714478,0.9846398234367372,0.0521175302565097,0.1875863792805293,43793,2662.8599441051483,0.9885432720184326,0.0398456007242202,0.216454931039379,0.9855533242225648,0.0495475642383098,0.19147633222763,43793 -1091.2149925231934,0.2206106185913086,1933.3324942588808,5912,0,1933.3324942588808,0.9847046732902528,0.051370620727539,0.1949138699611138,43793,3024.929833889008,0.9885270595550536,0.038916241377592,0.238007550440989,0.9856272339820862,0.0486249588429927,0.1974359053649196,43793 -1213.531497001648,0.2502303123474121,2173.4175686836243,6651,0,2173.4175686836243,0.9848660230636596,0.0509888343513011,0.2112681620169295,43793,3387.38121843338,0.9885249137878418,0.038509763777256,0.2527049176774597,0.9857376217842102,0.0481778718531131,0.2078144214837223,43793 -1331.4031381607056,0.2816083431243896,2413.567792892456,7381,0,2413.567792892456,0.984985649585724,0.0502688474953174,0.2317656739722545,43793,3745.457129716873,0.9886531829833984,0.0377003587782382,0.2923059632380551,0.985855758190155,0.0476825460791587,0.2244302866692918,43793 -1452.041398525238,0.3138508796691894,2653.66479420662,8111,0,2653.66479420662,0.9850509166717528,0.0505181662738323,0.2287647445540784,43793,4106.247630119324,0.989067792892456,0.0366627164185047,0.3014880708827807,0.9859344959259032,0.0479256436228752,0.2225765737842831,43793 -1569.6379013061523,0.3439414501190185,2893.8072040081024,8847,0,2893.8072040081024,0.9853474497795104,0.04921505600214,0.2329979128399078,43793,4464.036288499832,0.989511251449585,0.0353168547153472,0.3274415535340602,0.9862276315689088,0.0465082228183746,0.2328149372357666,43793 -1690.8187873363495,0.373058557510376,3133.857388734817,9594,0,3133.857388734817,0.9854249358177184,0.0489817410707473,0.2374337427035071,43793,4825.316298484802,0.9896472096443176,0.0346287861466407,0.3339020980092739,0.9862998723983764,0.0462951585650444,0.2387715916397305,43793 -1813.10927772522,0.4024133682250976,3374.110833644867,10326,0,3374.110833644867,0.9856414198875428,0.0482327677309513,0.2421628109443619,43793,5187.90918803215,0.989692211151123,0.0348010957241058,0.3372411340607423,0.9864484667778016,0.0456375554203987,0.2384509236503546,43793 -1932.6103360652924,0.4336550235748291,3614.1969878673553,11064,0,3614.1969878673553,0.9856801629066468,0.0483282282948493,0.2508927952331046,43793,5547.547726154327,0.9897984266281128,0.0339798592031002,0.3418119491218034,0.986470341682434,0.0457224324345588,0.2462407117357929,43793 -2051.3599536418915,0.463397741317749,3854.333966732025,11806,0,3854.333966732025,0.9856789112091064,0.0481674931943416,0.2501293202794034,43793,5906.48420381546,0.9900467395782472,0.0330719351768493,0.3749070903184368,0.9865052700042723,0.0455107390880584,0.2459694547153054,43793 -2172.739722967148,0.4942986965179443,4094.3583641052246,12535,0,4094.3583641052246,0.9857109189033508,0.0479223653674125,0.2474763344552082,43793,6267.943195104599,0.9904047846794128,0.0322107784450054,0.378024805566873,0.9865750670433044,0.0453389957547187,0.2465292262857419,43793 -2290.664441347122,0.525383710861206,4334.632151842117,13272,0,4334.632151842117,0.9858461618423462,0.0478320717811584,0.256542108864486,43793,6626.192683458328,0.990593671798706,0.0313814431428909,0.4109021595111499,0.9866944551467896,0.0450428761541843,0.2539900303246503,43793 -2412.107246160507,0.5567858219146729,4574.894693851471,14008,0,4574.894693851471,0.985842764377594,0.0476877987384796,0.2592138244146481,43793,6987.949181556702,0.9907403588294984,0.0303930770605802,0.4359508097884353,0.9867265224456788,0.0450580827891826,0.2625926952812717,43793 -2535.4455637931824,0.586010217666626,4814.908618211746,14739,0,4814.908618211746,0.9856974482536316,0.0481946356594562,0.2531345840715379,43793,7351.350212574005,0.9908390641212464,0.0299171060323715,0.4426423572423156,0.986622989177704,0.0454111546277999,0.2583273177120799,43793 -2656.38614153862,0.6158895492553711,5054.8963351249695,15467,0,5054.8963351249695,0.9858457446098328,0.0474570840597152,0.2533422638624075,43793,7712.331021547317,0.991133749485016,0.0297701694071292,0.439455381327504,0.9866952300071716,0.0447630360722541,0.2584276547547586,43793 -2772.7395095825195,0.6458499431610107,5295.120626449585,16201,0,5295.120626449585,0.985964059829712,0.0480331107974052,0.2627030608426233,43793,8068.958131074905,0.9909468293190002,0.0296922866255044,0.4539990121975116,0.9867402911186218,0.0453396849334239,0.2613085320446217,43793 -2891.1803154945374,0.6761302947998047,5535.215556144714,16944,0,5535.215556144714,0.9858773350715636,0.0486006960272789,0.2582021144128389,43793,8427.543749332428,0.9909261465072632,0.0296106729656457,0.4592231931246184,0.9867545366287231,0.0458285622298717,0.2621394887681972,43793 -3007.234807729721,0.7085833549499512,5775.349369287491,17681,0,5775.349369287491,0.985878586769104,0.0483734384179115,0.2591937009499683,43793,8783.784770011902,0.9912601709365844,0.0284853987395763,0.4747558361436588,0.9867882132530212,0.0454965271055698,0.2641498525409075,43793 -3129.774727344513,0.739865779876709,6015.588186979294,18420,0,6015.588186979294,0.985854148864746,0.0481930561363697,0.2527952731086541,43793,9146.614343166351,0.9915626645088196,0.0277974400669336,0.4831674995977595,0.9867565631866456,0.0454152189195156,0.2554176813356391,43793 -3244.3884332180023,0.7705888748168945,6255.743933677673,19167,0,6255.743933677673,0.9859089255332948,0.0483778268098831,0.2589651663826672,43793,9501.434201717377,0.9915648102760316,0.0273576527833938,0.5086215310274381,0.9867614507675172,0.0455816313624382,0.2631518378835457,43793 -3358.1936955451965,0.8017594814300537,6496.008260965347,19903,0,6496.008260965347,0.9860078692436218,0.0478366501629352,0.2606229283561334,43793,9855.55399274826,0.992231011390686,0.0255564451217651,0.5372999658293216,0.9868332743644714,0.0451035685837268,0.2715457824918656,43793 -3474.561186313629,0.832097053527832,6736.173248052597,20644,0,6736.173248052597,0.9859581589698792,0.0484275855123996,0.2603293757011919,43793,10212.136778831482,0.9922563433647156,0.0252972152084112,0.5544247501332782,0.9868401885032654,0.0456213280558586,0.2686777262266689,43793 -3590.268481016159,0.8643553256988525,6976.150879383087,21384,0,6976.150879383087,0.9859493374824524,0.0490799285471439,0.2598057736054661,43793,10567.873800992966,0.99234277009964,0.0249669533222913,0.5519044794078625,0.9868003726005554,0.0461722835898399,0.264854724285865,43793 -3709.2970135211945,0.8965957164764404,7216.116875648498,22120,0,7216.116875648498,0.985874354839325,0.0491098128259182,0.261793099253754,43793,10926.920147418976,0.991874098777771,0.0265325270593166,0.5180254499914452,0.986780881881714,0.0461693368852138,0.2620975266770983,43793 -3831.945240736008,0.9279828071594238,7456.379545927048,22854,0,7456.379545927048,0.9859341979026794,0.04970159009099,0.2569096916812762,43793,11289.88250374794,0.9917595386505128,0.0264614392071962,0.5174314820179169,0.986764669418335,0.0466690026223659,0.2659003367203237,43793 -3952.633655309677,0.962378740310669,7696.485308170319,23586,0,7696.485308170319,0.9858810901641846,0.0492570735514164,0.2546088235250043,43793,11650.7317340374,0.992215096950531,0.0253068357706069,0.5333598000806847,0.9867005348205566,0.0463218465447425,0.2577843055251133,43793 -4071.9370653629294,0.9949030876159668,7936.522948503494,24314,0,7936.522948503494,0.9858802556991576,0.0502522885799407,0.2530300101575706,43793,12010.12484574318,0.9922340512275696,0.0250587128102779,0.5520819790954581,0.986710250377655,0.047242235392332,0.2548902204472823,43793 -4188.638366937637,1.026634693145752,8176.764466285706,25054,0,8176.764466285706,0.985859215259552,0.0502813048660755,0.255721761943902,43793,12367.119560956957,0.9923595190048218,0.0244757682085037,0.5629086457355192,0.9867419600486756,0.0471686944365501,0.2574200603232365,43793 -4309.53524518013,1.0576236248016355,8416.962620973587,25778,0,8416.962620973587,0.9858731031417848,0.0505585819482803,0.2549490936009421,43793,12728.26784825325,0.9928763508796692,0.0227034576237201,0.6162934071137858,0.9866648316383362,0.047454223036766,0.2594423493508385,43793 -4429.807000875473,1.0906052589416504,8657.234840393066,26515,0,8657.234840393066,0.985743761062622,0.0503081008791923,0.2500445869214605,43793,13088.86488366127,0.9934829473495485,0.0216501988470554,0.6098429034027366,0.9865832328796388,0.0472255833446979,0.2603134068454962,43793 -4551.0147252082825,1.12199068069458,8897.440607070923,27259,0,8897.440607070923,0.9857968688011168,0.0504861883819103,0.2517114387345971,43793,13450.329896450045,0.9931455254554749,0.0224333405494689,0.601588529797614,0.986559271812439,0.0475690253078937,0.2533269911318941,43793 -4675.2046592235565,1.1547377109527588,9137.503035783768,27986,0,9137.503035783768,0.9857610464096068,0.0507408790290355,0.2519577155097544,43793,13814.634695529938,0.99298095703125,0.0228356607258319,0.5933179400988164,0.986605942249298,0.0477442778646945,0.2548442086332938,43793 -4790.3657166957855,1.1872403621673584,9377.56341791153,28719,0,9377.56341791153,0.9859505891799928,0.0510335154831409,0.2555062883256959,43793,14169.90866613388,0.9927420020103456,0.0231804102659225,0.5969560782400278,0.9867175817489624,0.0479126051068306,0.2682034494169585,43793 -4911.0401475429535,1.2209062576293943,9617.663268327711,29456,0,9617.663268327711,0.9856972694396972,0.0518477261066436,0.2487521176330588,43793,14530.736252069471,0.9926162958145142,0.0233869291841983,0.5947517350336357,0.9866631627082824,0.0486439242959022,0.2574898460922729,43793 -5029.308331251144,1.2543549537658691,9857.94836807251,30183,0,9857.94836807251,0.9857484102249146,0.0511756390333175,0.2540606307889602,43793,14889.342999219894,0.9931610822677612,0.0220011305063962,0.6025305047501668,0.986594557762146,0.0480909161269664,0.2596651647218211,43793 -5148.699724435806,1.288421869277954,10098.119551181791,30913,0,10098.119551181791,0.9857787489891052,0.0517778843641281,0.2512915422425292,43793,15248.962916851044,0.9933783411979676,0.0210400521755218,0.6411639773982251,0.9866043329238892,0.0487705618143081,0.2531882033520808,43793 -5270.216884851456,1.32468581199646,10338.317598104475,31645,0,10338.317598104475,0.9857193827629088,0.0524006262421608,0.2476491451965961,43793,15610.73458480835,0.993828535079956,0.0197061952203512,0.6617192203318676,0.9866339564323424,0.0489429533481597,0.2552728387461361,43793 -5389.546606063843,1.3582172393798828,10578.267916440964,32376,0,10578.267916440964,0.9855828881263732,0.0523907206952571,0.2438136395282832,43793,15970.068334579468,0.994285523891449,0.0187331493943929,0.6851187612466829,0.986394464969635,0.0492233633995056,0.2534544147804244,43793 -5508.245011806488,1.3919477462768557,10818.28539800644,33103,0,10818.28539800644,0.9857964515686036,0.0526494719088077,0.2487113766268982,43793,16328.837913751602,0.9938401579856871,0.0198705811053514,0.6640449668679048,0.9866132736206056,0.0494127534329891,0.2477094520285919,43793 -5624.259745597839,1.4259235858917236,11058.477831840515,33847,0,11058.477831840515,0.9856372475624084,0.0534129738807678,0.2489563508088895,43793,16685.099545240402,0.9930763244628906,0.0219397060573101,0.604933073567349,0.9864837527275084,0.0501733385026454,0.2480326956404477,43793 -5736.817763805389,1.4586491584777832,11298.49192237854,34589,0,11298.49192237854,0.9856675267219543,0.0528992712497711,0.2486676367310245,43793,17037.724573135376,0.9932107329368592,0.0217495448887348,0.6007094082387573,0.9864163994789124,0.0497723706066608,0.2514264851650297,43793 -5852.193566083908,1.4922783374786377,11538.711616039276,35331,0,11538.711616039276,0.9855715036392212,0.0532553158700466,0.251005720832414,43793,17393.373587608337,0.993598461151123,0.020510371774435,0.6334068095016367,0.986381471157074,0.050090841948986,0.2464012148980412,43793 -5962.109864473343,1.5256366729736328,11778.746633768082,36077,0,11778.746633768082,0.9856035113334656,0.0539854243397712,0.2435299307766965,43793,17743.378707647324,0.99366557598114,0.0199929028749465,0.6459082535084215,0.9864638447761536,0.0505931377410888,0.2457527230082513,43793 -6079.482671022415,1.9244272708892824,12018.441509246826,36814,0,12018.441509246826,0.9855087399482728,0.0544721521437168,0.2419251836295627,43793,18100.86488223076,0.9936034679412842,0.0200266540050506,0.6523864222962256,0.98634535074234,0.0508432760834693,0.2410080625010881,43793 -6195.485498428345,1.959678649902344,12258.66541147232,37548,0,12258.66541147232,0.985682725906372,0.0545710735023021,0.2465838969900322,43793,18457.14681982994,0.9947233200073242,0.0171383786946535,0.7183873072526152,0.9864943027496338,0.0512232780456543,0.246641803540727,43793 -6314.177472352982,1.9961557388305664,12498.67090535164,38282,0,12498.67090535164,0.985569417476654,0.0554567091166973,0.2416775400873894,43793,18815.900892019272,0.995094358921051,0.0160580594092607,0.7305360156173912,0.9864216446876526,0.0519767627120018,0.2444132828398841,43793 -6430.45547246933,2.035463333129883,12738.752710580826,39020,0,12738.752710580826,0.9855411648750304,0.0549594648182392,0.2413185281618792,43793,19172.32074737549,0.9950499534606934,0.0164400376379489,0.7171123304405793,0.9864265322685242,0.0511059425771236,0.2482558400172647,43793 -6545.504342556,2.070361852645874,12978.81694483757,39750,0,12978.81694483757,0.9855074882507324,0.05539982765913,0.2396051910604039,43793,19527.48945236206,0.9946119785308838,0.0173363238573074,0.7018877172650734,0.9864224791526794,0.0517116338014602,0.2454003609325848,43793 -6660.941750526428,2.1060330867767334,13218.769835948944,40480,0,13218.769835948944,0.9855066537857056,0.0563833639025688,0.2406186509886538,43793,19882.935147047043,0.9938182830810548,0.0190417077392339,0.6818837267982015,0.9863936305046082,0.0527338311076164,0.2413260157386747,43793 -6770.818592071533,2.141692876815796,13458.771996974943,41218,0,13458.771996974943,0.9854856133461,0.0560379549860954,0.232853285548857,43793,20232.86994457245,0.9941673874855042,0.0182603262364864,0.6868479320912788,0.986441969871521,0.0524128153920173,0.2450127274158767,43793 -6886.031506538391,2.176521062850952,13698.72559428215,41956,0,13698.72559428215,0.9855399131774902,0.0567248314619064,0.2337525892945866,43793,20588.09084820748,0.9942653775215148,0.0178059693425893,0.6941094896473745,0.9863846898078918,0.053409494459629,0.2359936574011092,43793 -6995.729851484299,2.2115795612335205,13938.932447195051,42698,0,13938.932447195051,0.9854699969291688,0.056529015302658,0.2339992933288823,43793,20938.051344156265,0.9948505759239196,0.0164919458329677,0.7170254342784487,0.9862925410270692,0.0528776720166206,0.2373408264054241,43793 -7108.043355226517,2.249853610992432,14179.018709421158,43447,0,14179.018709421158,0.985352098941803,0.0571636594831943,0.232133301550064,43793,21290.509116888046,0.9954681992530824,0.0149921560660004,0.7636412831561692,0.986258029937744,0.0534884706139564,0.2398435653849362,43793 -7220.051170825958,2.2852609157562256,14418.994359016418,44197,0,14418.994359016418,0.9854013323783876,0.0583654716610908,0.2327379965088002,43793,21642.547493696213,0.995785117149353,0.0140907894819974,0.7848275030564575,0.9862913489341736,0.0545599684119224,0.2348540301396725,43793 -7332.72357583046,2.32065486907959,14659.189492940905,44943,0,14659.189492940905,0.9854552745819092,0.0584064722061157,0.2363809785728483,43793,21995.470474243164,0.9962485432624816,0.0128924082964658,0.8084429781038803,0.9863153100013732,0.0548417158424854,0.2340596465030607,43793 -7444.150077819824,2.3560049533843994,14899.14623594284,45683,0,14899.14623594284,0.9852480292320251,0.0591854453086853,0.2316662683194748,43793,22346.908483743668,0.9945888519287108,0.0169161148369312,0.7274850797457694,0.9861624836921692,0.0553164780139923,0.2296347303095366,43793 -7553.0161652565,2.39290738105774,15139.321603536606,46426,0,15139.321603536606,0.9853516817092896,0.0587752237915992,0.2365998127794417,43793,22696.00659751892,0.9955013990402222,0.0145522383973002,0.7470824289675861,0.9862040877342224,0.0551997199654579,0.2317819172577487,43793 -7664.210397481918,2.4298739433288574,15379.40378499031,47163,0,15379.40378499031,0.9853390455245972,0.0595657564699649,0.2342264089381991,43793,23047.339591503143,0.9950583577156068,0.0155198480933904,0.7462739776245799,0.9862487316131592,0.0558253675699234,0.233228704631362,43793 -7775.425299167633,2.46665096282959,15619.506531715391,47902,0,15619.506531715391,0.9852008819580078,0.0601598471403121,0.2281059140172945,43793,23398.714463949203,0.9947388172149658,0.0161861889064311,0.7402327222700571,0.9861894249916076,0.0563515312969684,0.2244347422110065,43793 -7885.1501557827,2.50447678565979,15859.705310821531,48640,0,15859.705310821531,0.9852421283721924,0.0606749877333641,0.2275132350776089,43793,23748.696949481964,0.9944795966148376,0.0167838465422391,0.7291917979386155,0.986224353313446,0.056857194751501,0.2288818870293448,43793 -7998.654189348221,2.5425143241882324,16099.769166469574,49384,0,16099.769166469574,0.9852739572525024,0.0606629438698291,0.2257372623571563,43793,24102.322848558422,0.9957962036132812,0.0137219009920954,0.7867519225622568,0.9862629175186156,0.0566055364906787,0.2319078980631626,43793 -8114.529004335403,2.579529523849488,16339.929149627686,50127,0,16339.929149627686,0.9852084517478944,0.0609424151480197,0.2246761260090098,43793,24458.41442155838,0.996640145778656,0.0119273085147142,0.8215199199748869,0.9861634969711304,0.0570623017847538,0.2324574809517005,43793 -8229.751652002335,2.6154143810272217,16580.062649965286,50874,0,16580.062649965286,0.9852164387702942,0.0615736842155456,0.2228248464410566,43793,24813.826580524445,0.9970665574073792,0.0110013792291283,0.8423189663398486,0.9861395359039308,0.0575702711939811,0.2293329472246431,43793 -8343.433227062225,2.657003402709961,16820.25114750862,51617,0,16820.25114750862,0.9852501749992372,0.062121957540512,0.2230208669252613,43793,25167.75855875016,0.996950089931488,0.011073999106884,0.8505485895345607,0.986198365688324,0.0580263286828994,0.2304312454150581,43793 -8456.342872619629,2.694946765899658,17060.339695692062,52353,0,17060.339695692062,0.985108196735382,0.062013104557991,0.2234894896601067,43793,25520.81440806389,0.9966402053833008,0.0118224844336509,0.8342763710567649,0.9861005544662476,0.0578526854515075,0.2286860382086004,43793 -8565.075862884521,2.733237028121948,17300.507427692413,53106,0,17300.507427692413,0.9851238131523132,0.0628400072455406,0.2246251051391217,43793,25869.773553848267,0.9960552453994752,0.0127846905961632,0.808120967576912,0.986150085926056,0.0588770769536495,0.2253980118229107,43793 -8678.787027597427,2.770055055618286,17540.69450187683,53849,0,17540.69450187683,0.9851983189582824,0.0627905428409576,0.2211224959491054,43793,26223.72825813293,0.99503892660141,0.0151807256042957,0.755597599761523,0.986154556274414,0.0587789863348007,0.2282030288300124,43793 -8789.08589553833,2.8089280128479004,17780.679956674576,54595,0,17780.679956674576,0.9851747751235962,0.0633826255798339,0.2202348065167658,43793,26574.07158923149,0.995265007019043,0.0146405976265668,0.7755410902328488,0.9860976934432985,0.0594463013112545,0.2263200967059221,43793 -8899.616615772247,2.8510804176330566,18020.666786193848,55339,0,18020.666786193848,0.9852118492126464,0.0640837997198104,0.2175502037047203,43793,26924.650710582733,0.9959477186203004,0.0129019608721137,0.8156166730274892,0.9861111044883728,0.0599651709198951,0.2242800732972598,43793 -9008.119769573212,2.8922278881073,18260.690786361694,56087,0,18260.690786361694,0.9850656390190125,0.06437241286039352,0.2164854876489734,43793,27273.239208221436,0.9954010248184204,0.014091824181377888,0.8007883470937304,0.9860416650772095,0.06013806164264679,0.22281162844515326,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/measurements.csv deleted file mode 100644 index f2c116c35..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/measurements.csv +++ /dev/null @@ -1,647 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.7374306,0.71465915,,,,,,,,,,,,,,,,, -1,,,0.5251287221908569,0.7151197195053101,0.0241442787101899,0.5213832855224609,0.7166012525558472,0.0261539542103122,43793.0,0.5224818587303162,0.7161954641342163,0.0278402619301766,43793.0,12.159255981445312,133.5333137512207,12.159255981445312,121.37401604652403,0.0,0.0 -100,0.43090063,0.37779137,,,,,,,,,,,,,,,,, -200,0.34009856,0.28988427,,,,,,,,,,,,,,,,, -300,0.25765303,0.20068635,,,,,,,,,,,,,,,,, -400,0.15369603,0.13612215,,,,,,,,,,,,,,,,, -500,0.13582495,0.094141796,,,,,,,,,,,,,,,,, -600,0.09296992,0.07757967,,,,,,,,,,,,,,,,, -700,0.106897525,0.0610802,,,,,,,,,,,,,,,,, -742,,,0.9867380261421204,0.064150258898735,0.0403616680620335,0.9841179251670836,0.0728871896862983,0.0420972196052659,43793.0,0.983142077922821,0.0757889300584793,0.0436686746645511,43793.0,252.4109787940979,499.6362552642822,252.4109787940979,247.18216228485107,0.0235548019409179,0.0 -800,0.03171458,0.061985992,,,,,,,,,,,,,,,,, -900,0.058275506,0.058535147,,,,,,,,,,,,,,,,, -1000,0.032781743,0.05698038,,,,,,,,,,,,,,,,, -1100,0.027900672,0.05412717,,,,,,,,,,,,,,,,, -1200,0.052511495,0.05226143,,,,,,,,,,,,,,,,, -1300,0.08766163,0.051040027,,,,,,,,,,,,,,,,, -1400,0.030923763,0.04751333,,,,,,,,,,,,,,,,, -1468,,,0.9867783784866332,0.0536482743918895,0.0473054507702929,0.9841195344924928,0.0639380291104316,0.0484474674040182,43793.0,0.9831437468528748,0.0672952905297279,0.0508437648925251,43793.0,492.4810729026794,863.9817779064178,492.4810729026794,371.4069876670837,0.0522158145904541,0.0 -1500,0.05413845,0.04964868,,,,,,,,,,,,,,,,, -1600,0.03575883,0.050236702,,,,,,,,,,,,,,,,, -1700,0.11966938,0.05531621,,,,,,,,,,,,,,,,, -1800,0.04577616,0.051737573,,,,,,,,,,,,,,,,, -1900,0.09076723,0.051791877,,,,,,,,,,,,,,,,, -2000,0.092518024,0.048363175,,,,,,,,,,,,,,,,, -2100,0.123445086,0.046751022,,,,,,,,,,,,,,,,, -2200,0.060705807,0.04907848,,,,,,,,,,,,,,,,, -2213,,,0.9873389005661012,0.0461795665323734,0.1193167657896497,0.984654188156128,0.0554618798196315,0.1120072526561411,43793.0,0.9836630821228028,0.0585633553564548,0.1103763842191628,43793.0,732.6191091537476,1223.476276397705,732.6191091537476,490.7164118289948,0.07963228225708,0.0 -2300,0.040745165,0.04508823,,,,,,,,,,,,,,,,, -2400,0.028631486,0.04715676,,,,,,,,,,,,,,,,, -2500,0.057712823,0.04446732,,,,,,,,,,,,,,,,, -2600,0.04963796,0.04966117,,,,,,,,,,,,,,,,, -2700,0.05641993,0.04395234,,,,,,,,,,,,,,,,, -2800,0.023665449,0.043355107,,,,,,,,,,,,,,,,, -2900,0.04534983,0.043435395,,,,,,,,,,,,,,,,, -2951,,,0.9876325130462646,0.0444084219634532,0.1469359938076572,0.984946072101593,0.0540356822311878,0.1438901259098251,43793.0,0.9839431643486024,0.0571308247745037,0.1403339919641544,43793.0,972.7951905727386,1581.7689995765686,972.7951905727386,608.784569978714,0.1080617904663086,0.0 -3000,0.044870343,0.048441548,,,,,,,,,,,,,,,,, -3100,0.034398586,0.04819179,,,,,,,,,,,,,,,,, -3200,0.030830463,0.04358172,,,,,,,,,,,,,,,,, -3300,0.02590187,0.044616323,,,,,,,,,,,,,,,,, -3400,0.018751698,0.045788273,,,,,,,,,,,,,,,,, -3500,0.02629222,0.04817317,,,,,,,,,,,,,,,,, -3600,0.06846924,0.04487653,,,,,,,,,,,,,,,,, -3695,,,0.987868309020996,0.0430908612906932,0.1693227572686155,0.98510479927063,0.05209706351161,0.1507358351030023,43793.0,0.9841756820678712,0.0548621453344821,0.1504054410516651,43793.0,1213.0103611946106,1943.883972644806,1213.0103611946106,730.6357326507568,0.1368811130523681,0.0 -3700,0.049667746,0.046941407,,,,,,,,,,,,,,,,, -3800,0.027366374,0.04518434,,,,,,,,,,,,,,,,, -3900,0.032669198,0.046587687,,,,,,,,,,,,,,,,, -4000,0.027653757,0.048168126,,,,,,,,,,,,,,,,, -4100,0.04389618,0.042548325,,,,,,,,,,,,,,,,, -4200,0.01729105,0.04877333,,,,,,,,,,,,,,,,, -4300,0.018790357,0.04611555,,,,,,,,,,,,,,,,, -4400,0.03438347,0.038912717,,,,,,,,,,,,,,,,, -4435,,,0.9882818460464478,0.0408571362495422,0.2046375427604755,0.9853986501693726,0.0501000694930553,0.1703647678159369,43793.0,0.9844208359718324,0.0528491176664829,0.1704869601797195,43793.0,1453.0346467494965,2301.0603160858154,1453.0346467494965,847.740592956543,0.1639246940612793,0.0 -4500,0.025253981,0.04544213,,,,,,,,,,,,,,,,, -4600,0.022534998,0.047839537,,,,,,,,,,,,,,,,, -4700,0.033213906,0.04680278,,,,,,,,,,,,,,,,, -4800,0.04031537,0.045547824,,,,,,,,,,,,,,,,, -4900,0.019588478,0.04218425,,,,,,,,,,,,,,,,, -5000,0.021304147,0.042112872,,,,,,,,,,,,,,,,, -5100,0.04154813,0.044748757,,,,,,,,,,,,,,,,, -5175,,,0.9885432720184326,0.0398456007242202,0.216454931039379,0.9855533242225648,0.0495475642383098,0.19147633222763,43793.0,0.9846398234367372,0.0521175302565097,0.1875863792805293,43793.0,1693.2447366714478,2662.8599441051483,1693.2447366714478,969.2806987762452,0.1926295757293701,0.0 -5200,0.025291132,0.044567503,,,,,,,,,,,,,,,,, -5300,0.03195139,0.04320134,,,,,,,,,,,,,,,,, -5400,0.019636258,0.041894943,,,,,,,,,,,,,,,,, -5500,0.022208096,0.045167945,,,,,,,,,,,,,,,,, -5600,0.018570315,0.039844282,,,,,,,,,,,,,,,,, -5700,0.01424234,0.043035332,,,,,,,,,,,,,,,,, -5800,0.018017719,0.040763628,,,,,,,,,,,,,,,,, -5900,0.031898014,0.04306663,,,,,,,,,,,,,,,,, -5912,,,0.9885270595550536,0.038916241377592,0.238007550440989,0.9856272339820862,0.0486249588429927,0.1974359053649196,43793.0,0.9847046732902528,0.051370620727539,0.1949138699611138,43793.0,1933.3324942588808,3024.929833889008,1933.3324942588808,1091.2149925231934,0.2206106185913086,0.0 -6000,0.016793579,0.04160886,,,,,,,,,,,,,,,,, -6100,0.02496337,0.04442508,,,,,,,,,,,,,,,,, -6200,0.034244232,0.042624302,,,,,,,,,,,,,,,,, -6300,0.01586366,0.042934198,,,,,,,,,,,,,,,,, -6400,0.026301512,0.04035589,,,,,,,,,,,,,,,,, -6500,0.014896395,0.039819628,,,,,,,,,,,,,,,,, -6600,0.014251416,0.042529896,,,,,,,,,,,,,,,,, -6651,,,0.9885249137878418,0.038509763777256,0.2527049176774597,0.9857376217842102,0.0481778718531131,0.2078144214837223,43793.0,0.9848660230636596,0.0509888343513011,0.2112681620169295,43793.0,2173.4175686836243,3387.38121843338,2173.4175686836243,1213.531497001648,0.2502303123474121,0.0 -6700,0.018345838,0.043526005,,,,,,,,,,,,,,,,, -6800,0.030086493,0.042977415,,,,,,,,,,,,,,,,, -6900,0.024036046,0.043872137,,,,,,,,,,,,,,,,, -7000,0.016003098,0.039564658,,,,,,,,,,,,,,,,, -7100,0.017614821,0.04352025,,,,,,,,,,,,,,,,, -7200,0.01386407,0.039641917,,,,,,,,,,,,,,,,, -7300,0.016293831,0.041841663,,,,,,,,,,,,,,,,, -7381,,,0.9886531829833984,0.0377003587782382,0.2923059632380551,0.985855758190155,0.0476825460791587,0.2244302866692918,43793.0,0.984985649585724,0.0502688474953174,0.2317656739722545,43793.0,2413.567792892456,3745.457129716873,2413.567792892456,1331.4031381607056,0.2816083431243896,0.0 -7400,0.013587175,0.03673655,,,,,,,,,,,,,,,,, -7500,0.012333086,0.041828692,,,,,,,,,,,,,,,,, -7600,0.020408109,0.040639434,,,,,,,,,,,,,,,,, -7700,0.016898239,0.04570847,,,,,,,,,,,,,,,,, -7800,0.023503648,0.042423576,,,,,,,,,,,,,,,,, -7900,0.012486074,0.039697032,,,,,,,,,,,,,,,,, -8000,0.019338982,0.042595275,,,,,,,,,,,,,,,,, -8100,0.020884598,0.045047186,,,,,,,,,,,,,,,,, -8111,,,0.989067792892456,0.0366627164185047,0.3014880708827807,0.9859344959259032,0.0479256436228752,0.2225765737842831,43793.0,0.9850509166717528,0.0505181662738323,0.2287647445540784,43793.0,2653.66479420662,4106.247630119324,2653.66479420662,1452.041398525238,0.3138508796691894,0.0 -8200,0.01810151,0.040874816,,,,,,,,,,,,,,,,, -8300,0.020410141,0.04310378,,,,,,,,,,,,,,,,, -8400,0.012440818,0.041040882,,,,,,,,,,,,,,,,, -8500,0.021322437,0.03917093,,,,,,,,,,,,,,,,, -8600,0.022751363,0.04160734,,,,,,,,,,,,,,,,, -8700,0.01420434,0.040988274,,,,,,,,,,,,,,,,, -8800,0.019567968,0.0396229,,,,,,,,,,,,,,,,, -8847,,,0.989511251449585,0.0353168547153472,0.3274415535340602,0.9862276315689088,0.0465082228183746,0.2328149372357666,43793.0,0.9853474497795104,0.04921505600214,0.2329979128399078,43793.0,2893.8072040081024,4464.036288499832,2893.8072040081024,1569.6379013061523,0.3439414501190185,0.0 -8900,0.016629605,0.045134082,,,,,,,,,,,,,,,,, -9000,0.022517968,0.040463377,,,,,,,,,,,,,,,,, -9100,0.01666147,0.041070048,,,,,,,,,,,,,,,,, -9200,0.017298108,0.03647387,,,,,,,,,,,,,,,,, -9300,0.013882599,0.040153146,,,,,,,,,,,,,,,,, -9400,0.012678916,0.03830711,,,,,,,,,,,,,,,,, -9500,0.012479324,0.03558525,,,,,,,,,,,,,,,,, -9594,,,0.9896472096443176,0.0346287861466407,0.3339020980092739,0.9862998723983764,0.0462951585650444,0.2387715916397305,43793.0,0.9854249358177184,0.0489817410707473,0.2374337427035071,43793.0,3133.857388734817,4825.316298484802,3133.857388734817,1690.8187873363495,0.373058557510376,0.0 -9600,0.017922157,0.040059205,,,,,,,,,,,,,,,,, -9700,0.019493146,0.040227465,,,,,,,,,,,,,,,,, -9800,0.018473295,0.03998223,,,,,,,,,,,,,,,,, -9900,0.01721518,0.04040477,,,,,,,,,,,,,,,,, -10000,0.020942671,0.03869111,,,,,,,,,,,,,,,,, -10100,0.018634573,0.042354397,,,,,,,,,,,,,,,,, -10200,0.0165813,0.03709575,,,,,,,,,,,,,,,,, -10300,0.015831362,0.042163942,,,,,,,,,,,,,,,,, -10326,,,0.989692211151123,0.0348010957241058,0.3372411340607423,0.9864484667778016,0.0456375554203987,0.2384509236503546,43793.0,0.9856414198875428,0.0482327677309513,0.2421628109443619,43793.0,3374.110833644867,5187.90918803215,3374.110833644867,1813.10927772522,0.4024133682250976,0.0 -10400,0.015147741,0.039381437,,,,,,,,,,,,,,,,, -10500,0.013485591,0.0393359,,,,,,,,,,,,,,,,, -10600,0.022313086,0.040343948,,,,,,,,,,,,,,,,, -10700,0.015943173,0.03970031,,,,,,,,,,,,,,,,, -10800,0.0155500425,0.037826575,,,,,,,,,,,,,,,,, -10900,0.016349627,0.043082107,,,,,,,,,,,,,,,,, -11000,0.014662168,0.03936974,,,,,,,,,,,,,,,,, -11064,,,0.9897984266281128,0.0339798592031002,0.3418119491218034,0.986470341682434,0.0457224324345588,0.2462407117357929,43793.0,0.9856801629066468,0.0483282282948493,0.2508927952331046,43793.0,3614.1969878673553,5547.547726154327,3614.1969878673553,1932.6103360652924,0.4336550235748291,0.0 -11100,0.016884718,0.037861515,,,,,,,,,,,,,,,,, -11200,0.013792417,0.040063545,,,,,,,,,,,,,,,,, -11300,0.015969187,0.040490597,,,,,,,,,,,,,,,,, -11400,0.01464247,0.04094247,,,,,,,,,,,,,,,,, -11500,0.014156209,0.03920565,,,,,,,,,,,,,,,,, -11600,0.017077677,0.039050657,,,,,,,,,,,,,,,,, -11700,0.014440817,0.035815213,,,,,,,,,,,,,,,,, -11800,0.017900158,0.0372276,,,,,,,,,,,,,,,,, -11806,,,0.9900467395782472,0.0330719351768493,0.3749070903184368,0.9865052700042723,0.0455107390880584,0.2459694547153054,43793.0,0.9856789112091064,0.0481674931943416,0.2501293202794034,43793.0,3854.333966732025,5906.48420381546,3854.333966732025,2051.3599536418915,0.463397741317749,0.0 -11900,0.015800865,0.04092727,,,,,,,,,,,,,,,,, -12000,0.01788053,0.038194336,,,,,,,,,,,,,,,,, -12100,0.016634332,0.039176423,,,,,,,,,,,,,,,,, -12200,0.016645968,0.03757685,,,,,,,,,,,,,,,,, -12300,0.015866779,0.03540938,,,,,,,,,,,,,,,,, -12400,0.01659654,0.036298282,,,,,,,,,,,,,,,,, -12500,0.015944777,0.036968052,,,,,,,,,,,,,,,,, -12535,,,0.9904047846794128,0.0322107784450054,0.378024805566873,0.9865750670433044,0.0453389957547187,0.2465292262857419,43793.0,0.9857109189033508,0.0479223653674125,0.2474763344552082,43793.0,4094.3583641052246,6267.943195104599,4094.3583641052246,2172.739722967148,0.4942986965179443,0.0 -12600,0.016958587,0.038292978,,,,,,,,,,,,,,,,, -12700,0.01575876,0.036914796,,,,,,,,,,,,,,,,, -12800,0.018495852,0.03942824,,,,,,,,,,,,,,,,, -12900,0.023070194,0.03791136,,,,,,,,,,,,,,,,, -13000,0.018276943,0.040223226,,,,,,,,,,,,,,,,, -13100,0.01795967,0.040557586,,,,,,,,,,,,,,,,, -13200,0.022107726,0.042486515,,,,,,,,,,,,,,,,, -13272,,,0.990593671798706,0.0313814431428909,0.4109021595111499,0.9866944551467896,0.0450428761541843,0.2539900303246503,43793.0,0.9858461618423462,0.0478320717811584,0.256542108864486,43793.0,4334.632151842117,6626.192683458328,4334.632151842117,2290.664441347122,0.525383710861206,0.0 -13300,0.01890418,0.03652618,,,,,,,,,,,,,,,,, -13400,0.01917173,0.039738137,,,,,,,,,,,,,,,,, -13500,0.014322492,0.034905385,,,,,,,,,,,,,,,,, -13600,0.016122838,0.03812269,,,,,,,,,,,,,,,,, -13700,0.017468005,0.03832994,,,,,,,,,,,,,,,,, -13800,0.016187588,0.038114112,,,,,,,,,,,,,,,,, -13900,0.021099111,0.038843803,,,,,,,,,,,,,,,,, -14000,0.019433279,0.04029085,,,,,,,,,,,,,,,,, -14008,,,0.9907403588294984,0.0303930770605802,0.4359508097884353,0.9867265224456788,0.0450580827891826,0.2625926952812717,43793.0,0.985842764377594,0.0476877987384796,0.2592138244146481,43793.0,4574.894693851471,6987.949181556702,4574.894693851471,2412.107246160507,0.5567858219146729,0.0 -14100,0.021758888,0.04032136,,,,,,,,,,,,,,,,, -14200,0.022844821,0.036522374,,,,,,,,,,,,,,,,, -14300,0.015987901,0.03591352,,,,,,,,,,,,,,,,, -14400,0.016153231,0.03624274,,,,,,,,,,,,,,,,, -14500,0.01669077,0.03748264,,,,,,,,,,,,,,,,, -14600,0.01987761,0.03679028,,,,,,,,,,,,,,,,, -14700,0.02807483,0.035703033,,,,,,,,,,,,,,,,, -14739,,,0.9908390641212464,0.0299171060323715,0.4426423572423156,0.986622989177704,0.0454111546277999,0.2583273177120799,43793.0,0.9856974482536316,0.0481946356594562,0.2531345840715379,43793.0,4814.908618211746,7351.350212574005,4814.908618211746,2535.4455637931824,0.586010217666626,0.0 -14800,0.021531671,0.039192975,,,,,,,,,,,,,,,,, -14900,0.014655235,0.033913255,,,,,,,,,,,,,,,,, -15000,0.023210244,0.037905637,,,,,,,,,,,,,,,,, -15100,0.015874285,0.034348834,,,,,,,,,,,,,,,,, -15200,0.01781749,0.036484465,,,,,,,,,,,,,,,,, -15300,0.018908551,0.038063947,,,,,,,,,,,,,,,,, -15400,0.022314109,0.039221723,,,,,,,,,,,,,,,,, -15467,,,0.991133749485016,0.0297701694071292,0.439455381327504,0.9866952300071716,0.0447630360722541,0.2584276547547586,43793.0,0.9858457446098328,0.0474570840597152,0.2533422638624075,43793.0,5054.8963351249695,7712.331021547317,5054.8963351249695,2656.38614153862,0.6158895492553711,0.0 -15500,0.01919812,0.036824696,,,,,,,,,,,,,,,,, -15600,0.021875639,0.033456575,,,,,,,,,,,,,,,,, -15700,0.01886285,0.034077704,,,,,,,,,,,,,,,,, -15800,0.026734937,0.03565537,,,,,,,,,,,,,,,,, -15900,0.01883578,0.038048208,,,,,,,,,,,,,,,,, -16000,0.018281266,0.03580335,,,,,,,,,,,,,,,,, -16100,0.019584058,0.036116607,,,,,,,,,,,,,,,,, -16200,0.018431002,0.03602461,,,,,,,,,,,,,,,,, -16201,,,0.9909468293190002,0.0296922866255044,0.4539990121975116,0.9867402911186218,0.0453396849334239,0.2613085320446217,43793.0,0.985964059829712,0.0480331107974052,0.2627030608426233,43793.0,5295.120626449585,8068.958131074905,5295.120626449585,2772.7395095825195,0.6458499431610107,0.0 -16300,0.01753031,0.03272322,,,,,,,,,,,,,,,,, -16400,0.018338168,0.037239257,,,,,,,,,,,,,,,,, -16500,0.02465009,0.03793838,,,,,,,,,,,,,,,,, -16600,0.020376831,0.03717854,,,,,,,,,,,,,,,,, -16700,0.019475628,0.03593515,,,,,,,,,,,,,,,,, -16800,0.024552278,0.035934936,,,,,,,,,,,,,,,,, -16900,0.020589259,0.03601191,,,,,,,,,,,,,,,,, -16944,,,0.9909261465072632,0.0296106729656457,0.4592231931246184,0.9867545366287231,0.0458285622298717,0.2621394887681972,43793.0,0.9858773350715636,0.0486006960272789,0.2582021144128389,43793.0,5535.215556144714,8427.543749332428,5535.215556144714,2891.1803154945374,0.6761302947998047,0.0 -17000,0.019110238,0.035409234,,,,,,,,,,,,,,,,, -17100,0.025429824,0.036302984,,,,,,,,,,,,,,,,, -17200,0.02203043,0.034607325,,,,,,,,,,,,,,,,, -17300,0.02512493,0.035041064,,,,,,,,,,,,,,,,, -17400,0.021889675,0.035758544,,,,,,,,,,,,,,,,, -17500,0.022649752,0.037474662,,,,,,,,,,,,,,,,, -17600,0.02216504,0.03447264,,,,,,,,,,,,,,,,, -17681,,,0.9912601709365844,0.0284853987395763,0.4747558361436588,0.9867882132530212,0.0454965271055698,0.2641498525409075,43793.0,0.985878586769104,0.0483734384179115,0.2591937009499683,43793.0,5775.349369287491,8783.784770011902,5775.349369287491,3007.234807729721,0.7085833549499512,0.0 -17700,0.02648964,0.03450572,,,,,,,,,,,,,,,,, -17800,0.029729076,0.03791724,,,,,,,,,,,,,,,,, -17900,0.018422024,0.032799862,,,,,,,,,,,,,,,,, -18000,0.02390969,0.038592186,,,,,,,,,,,,,,,,, -18100,0.020193668,0.03284333,,,,,,,,,,,,,,,,, -18200,0.021469219,0.033457607,,,,,,,,,,,,,,,,, -18300,0.021501731,0.03757372,,,,,,,,,,,,,,,,, -18400,0.034082115,0.034827277,,,,,,,,,,,,,,,,, -18420,,,0.9915626645088196,0.0277974400669336,0.4831674995977595,0.9867565631866456,0.0454152189195156,0.2554176813356391,43793.0,0.985854148864746,0.0481930561363697,0.2527952731086541,43793.0,6015.588186979294,9146.614343166351,6015.588186979294,3129.774727344513,0.739865779876709,0.0 -18500,0.022713041,0.036562797,,,,,,,,,,,,,,,,, -18600,0.023970151,0.03566027,,,,,,,,,,,,,,,,, -18700,0.027265871,0.036143564,,,,,,,,,,,,,,,,, -18800,0.027685499,0.034964714,,,,,,,,,,,,,,,,, -18900,0.039559674,0.03690158,,,,,,,,,,,,,,,,, -19000,0.02462155,0.038310617,,,,,,,,,,,,,,,,, -19100,0.026099043,0.03498022,,,,,,,,,,,,,,,,, -19167,,,0.9915648102760316,0.0273576527833938,0.5086215310274381,0.9867614507675172,0.0455816313624382,0.2631518378835457,43793.0,0.9859089255332948,0.0483778268098831,0.2589651663826672,43793.0,6255.743933677673,9501.434201717377,6255.743933677673,3244.3884332180023,0.7705888748168945,0.0 -19200,0.029324686,0.035480168,,,,,,,,,,,,,,,,, -19300,0.022521203,0.036030527,,,,,,,,,,,,,,,,, -19400,0.027356973,0.036311414,,,,,,,,,,,,,,,,, -19500,0.023280405,0.035148066,,,,,,,,,,,,,,,,, -19600,0.022100175,0.03413597,,,,,,,,,,,,,,,,, -19700,0.025728751,0.0343496,,,,,,,,,,,,,,,,, -19800,0.022539848,0.03392756,,,,,,,,,,,,,,,,, -19900,0.029707976,0.03282769,,,,,,,,,,,,,,,,, -19903,,,0.992231011390686,0.0255564451217651,0.5372999658293216,0.9868332743644714,0.0451035685837268,0.2715457824918656,43793.0,0.9860078692436218,0.0478366501629352,0.2606229283561334,43793.0,6496.008260965347,9855.55399274826,6496.008260965347,3358.1936955451965,0.8017594814300537,0.0 -20000,0.028074132,0.035353046,,,,,,,,,,,,,,,,, -20100,0.02619727,0.034704152,,,,,,,,,,,,,,,,, -20200,0.0288285,0.03326357,,,,,,,,,,,,,,,,, -20300,0.02387349,0.03446022,,,,,,,,,,,,,,,,, -20400,0.022997629,0.031993452,,,,,,,,,,,,,,,,, -20500,0.028970676,0.035118207,,,,,,,,,,,,,,,,, -20600,0.027345484,0.034554888,,,,,,,,,,,,,,,,, -20644,,,0.9922563433647156,0.0252972152084112,0.5544247501332782,0.9868401885032654,0.0456213280558586,0.2686777262266689,43793.0,0.9859581589698792,0.0484275855123996,0.2603293757011919,43793.0,6736.173248052597,10212.136778831482,6736.173248052597,3474.561186313629,0.832097053527832,0.0 -20700,0.02414215,0.030793602,,,,,,,,,,,,,,,,, -20800,0.026658814,0.032360256,,,,,,,,,,,,,,,,, -20900,0.026479589,0.035875496,,,,,,,,,,,,,,,,, -21000,0.027723204,0.03347786,,,,,,,,,,,,,,,,, -21100,0.026957113,0.033217683,,,,,,,,,,,,,,,,, -21200,0.027245998,0.03380749,,,,,,,,,,,,,,,,, -21300,0.02995045,0.032879055,,,,,,,,,,,,,,,,, -21384,,,0.99234277009964,0.0249669533222913,0.5519044794078625,0.9868003726005554,0.0461722835898399,0.264854724285865,43793.0,0.9859493374824524,0.0490799285471439,0.2598057736054661,43793.0,6976.150879383087,10567.873800992966,6976.150879383087,3590.268481016159,0.8643553256988525,0.0 -21400,0.02764232,0.033239707,,,,,,,,,,,,,,,,, -21500,0.033871014,0.032107458,,,,,,,,,,,,,,,,, -21600,0.025942815,0.034061994,,,,,,,,,,,,,,,,, -21700,0.028918473,0.031489965,,,,,,,,,,,,,,,,, -21800,0.030988371,0.03216987,,,,,,,,,,,,,,,,, -21900,0.031598844,0.036904305,,,,,,,,,,,,,,,,, -22000,0.024801526,0.032510363,,,,,,,,,,,,,,,,, -22100,0.028143702,0.03270575,,,,,,,,,,,,,,,,, -22120,,,0.991874098777771,0.0265325270593166,0.5180254499914452,0.986780881881714,0.0461693368852138,0.2620975266770983,43793.0,0.985874354839325,0.0491098128259182,0.261793099253754,43793.0,7216.116875648498,10926.920147418976,7216.116875648498,3709.2970135211945,0.8965957164764404,0.0 -22200,0.030418785,0.032576934,,,,,,,,,,,,,,,,, -22300,0.030304724,0.033876475,,,,,,,,,,,,,,,,, -22400,0.030991044,0.03159063,,,,,,,,,,,,,,,,, -22500,0.030238194,0.030768547,,,,,,,,,,,,,,,,, -22600,0.029332105,0.035072494,,,,,,,,,,,,,,,,, -22700,0.027051264,0.03098086,,,,,,,,,,,,,,,,, -22800,0.030429814,0.031550046,,,,,,,,,,,,,,,,, -22854,,,0.9917595386505128,0.0264614392071962,0.5174314820179169,0.986764669418335,0.0466690026223659,0.2659003367203237,43793.0,0.9859341979026794,0.04970159009099,0.2569096916812762,43793.0,7456.379545927048,11289.88250374794,7456.379545927048,3831.945240736008,0.9279828071594238,0.0 -22900,0.027305584,0.032616913,,,,,,,,,,,,,,,,, -23000,0.031115772,0.032431375,,,,,,,,,,,,,,,,, -23100,0.031654958,0.029048277,,,,,,,,,,,,,,,,, -23200,0.029953355,0.0317899,,,,,,,,,,,,,,,,, -23300,0.03743064,0.03706478,,,,,,,,,,,,,,,,, -23400,0.033291463,0.034215134,,,,,,,,,,,,,,,,, -23500,0.030419486,0.03283222,,,,,,,,,,,,,,,,, -23586,,,0.992215096950531,0.0253068357706069,0.5333598000806847,0.9867005348205566,0.0463218465447425,0.2577843055251133,43793.0,0.9858810901641846,0.0492570735514164,0.2546088235250043,43793.0,7696.485308170319,11650.7317340374,7696.485308170319,3952.633655309677,0.962378740310669,0.0 -23600,0.036592785,0.031206569,,,,,,,,,,,,,,,,, -23700,0.03907426,0.03215623,,,,,,,,,,,,,,,,, -23800,0.038730226,0.036499467,,,,,,,,,,,,,,,,, -23900,0.032479487,0.032916427,,,,,,,,,,,,,,,,, -24000,0.041575313,0.033799715,,,,,,,,,,,,,,,,, -24100,0.032398347,0.030565433,,,,,,,,,,,,,,,,, -24200,0.035009388,0.031843618,,,,,,,,,,,,,,,,, -24300,0.030570315,0.032634534,,,,,,,,,,,,,,,,, -24314,,,0.9922340512275696,0.0250587128102779,0.5520819790954581,0.986710250377655,0.047242235392332,0.2548902204472823,43793.0,0.9858802556991576,0.0502522885799407,0.2530300101575706,43793.0,7936.522948503494,12010.12484574318,7936.522948503494,4071.9370653629294,0.9949030876159668,0.0 -24400,0.032899134,0.035181087,,,,,,,,,,,,,,,,, -24500,0.034609213,0.032898515,,,,,,,,,,,,,,,,, -24600,0.030449597,0.032547664,,,,,,,,,,,,,,,,, -24700,0.030459099,0.03245148,,,,,,,,,,,,,,,,, -24800,0.035374064,0.03410427,,,,,,,,,,,,,,,,, -24900,0.034341775,0.03188201,,,,,,,,,,,,,,,,, -25000,0.034703325,0.03285133,,,,,,,,,,,,,,,,, -25054,,,0.9923595190048218,0.0244757682085037,0.5629086457355192,0.9867419600486756,0.0471686944365501,0.2574200603232365,43793.0,0.985859215259552,0.0502813048660755,0.255721761943902,43793.0,8176.764466285706,12367.119560956957,8176.764466285706,4188.638366937637,1.026634693145752,0.0 -25100,0.0361656,0.03149943,,,,,,,,,,,,,,,,, -25200,0.03179968,0.031639233,,,,,,,,,,,,,,,,, -25300,0.034734637,0.034568273,,,,,,,,,,,,,,,,, -25400,0.037038174,0.03260327,,,,,,,,,,,,,,,,, -25500,0.05441922,0.034229252,,,,,,,,,,,,,,,,, -25600,0.03222697,0.033127572,,,,,,,,,,,,,,,,, -25700,0.047548857,0.032606643,,,,,,,,,,,,,,,,, -25778,,,0.9928763508796692,0.0227034576237201,0.6162934071137858,0.9866648316383362,0.047454223036766,0.2594423493508385,43793.0,0.9858731031417848,0.0505585819482803,0.2549490936009421,43793.0,8416.962620973587,12728.26784825325,8416.962620973587,4309.53524518013,1.0576236248016355,0.0 -25800,0.04822118,0.030286564,,,,,,,,,,,,,,,,, -25900,0.03588411,0.030498931,,,,,,,,,,,,,,,,, -26000,0.035500705,0.030607667,,,,,,,,,,,,,,,,, -26100,0.035955,0.035555903,,,,,,,,,,,,,,,,, -26200,0.047288258,0.030769626,,,,,,,,,,,,,,,,, -26300,0.039536666,0.033531334,,,,,,,,,,,,,,,,, -26400,0.03595817,0.03304431,,,,,,,,,,,,,,,,, -26500,0.03438893,0.033554863,,,,,,,,,,,,,,,,, -26515,,,0.9934829473495485,0.0216501988470554,0.6098429034027366,0.9865832328796388,0.0472255833446979,0.2603134068454962,43793.0,0.985743761062622,0.0503081008791923,0.2500445869214605,43793.0,8657.234840393066,13088.86488366127,8657.234840393066,4429.807000875473,1.0906052589416504,0.0 -26600,0.0430536,0.032751996,,,,,,,,,,,,,,,,, -26700,0.052238375,0.031212216,,,,,,,,,,,,,,,,, -26800,0.04322019,0.033839095,,,,,,,,,,,,,,,,, -26900,0.042209424,0.031279374,,,,,,,,,,,,,,,,, -27000,0.034406565,0.030198827,,,,,,,,,,,,,,,,, -27100,0.04068478,0.031873103,,,,,,,,,,,,,,,,, -27200,0.048371907,0.031889804,,,,,,,,,,,,,,,,, -27259,,,0.9931455254554749,0.0224333405494689,0.601588529797614,0.986559271812439,0.0475690253078937,0.2533269911318941,43793.0,0.9857968688011168,0.0504861883819103,0.2517114387345971,43793.0,8897.440607070923,13450.329896450045,8897.440607070923,4551.0147252082825,1.12199068069458,0.0 -27300,0.040016104,0.029561723,,,,,,,,,,,,,,,,, -27400,0.048785027,0.035176277,,,,,,,,,,,,,,,,, -27500,0.0431171,0.03133343,,,,,,,,,,,,,,,,, -27600,0.03772643,0.029384224,,,,,,,,,,,,,,,,, -27700,0.03982365,0.029545052,,,,,,,,,,,,,,,,, -27800,0.0503427,0.03144088,,,,,,,,,,,,,,,,, -27900,0.038400274,0.03271289,,,,,,,,,,,,,,,,, -27986,,,0.99298095703125,0.0228356607258319,0.5933179400988164,0.986605942249298,0.0477442778646945,0.2548442086332938,43793.0,0.9857610464096068,0.0507408790290355,0.2519577155097544,43793.0,9137.503035783768,13814.634695529938,9137.503035783768,4675.2046592235565,1.1547377109527588,0.0 -28000,0.036452543,0.030670622,,,,,,,,,,,,,,,,, -28100,0.037353665,0.028949799,,,,,,,,,,,,,,,,, -28200,0.040250387,0.029546749,,,,,,,,,,,,,,,,, -28300,0.04255295,0.030231165,,,,,,,,,,,,,,,,, -28400,0.04180943,0.031710673,,,,,,,,,,,,,,,,, -28500,0.04820393,0.031869162,,,,,,,,,,,,,,,,, -28600,0.044028703,0.032512426,,,,,,,,,,,,,,,,, -28700,0.045475774,0.03354996,,,,,,,,,,,,,,,,, -28719,,,0.9927420020103456,0.0231804102659225,0.5969560782400278,0.9867175817489624,0.0479126051068306,0.2682034494169585,43793.0,0.9859505891799928,0.0510335154831409,0.2555062883256959,43793.0,9377.56341791153,14169.90866613388,9377.56341791153,4790.3657166957855,1.1872403621673584,0.0 -28800,0.049464423,0.032957412,,,,,,,,,,,,,,,,, -28900,0.04301304,0.030124284,,,,,,,,,,,,,,,,, -29000,0.039080013,0.03133602,,,,,,,,,,,,,,,,, -29100,0.044465948,0.028173748,,,,,,,,,,,,,,,,, -29200,0.046528388,0.03103928,,,,,,,,,,,,,,,,, -29300,0.046096854,0.0318575,,,,,,,,,,,,,,,,, -29400,0.045802053,0.02910044,,,,,,,,,,,,,,,,, -29456,,,0.9926162958145142,0.0233869291841983,0.5947517350336357,0.9866631627082824,0.0486439242959022,0.2574898460922729,43793.0,0.9856972694396972,0.0518477261066436,0.2487521176330588,43793.0,9617.663268327711,14530.736252069471,9617.663268327711,4911.0401475429535,1.2209062576293943,0.0 -29500,0.03971229,0.030489633,,,,,,,,,,,,,,,,, -29600,0.04569931,0.030863212,,,,,,,,,,,,,,,,, -29700,0.044322576,0.031385485,,,,,,,,,,,,,,,,, -29800,0.046322618,0.03080357,,,,,,,,,,,,,,,,, -29900,0.04927579,0.030986452,,,,,,,,,,,,,,,,, -30000,0.044773396,0.03018339,,,,,,,,,,,,,,,,, -30100,0.04710343,0.029646644,,,,,,,,,,,,,,,,, -30183,,,0.9931610822677612,0.0220011305063962,0.6025305047501668,0.986594557762146,0.0480909161269664,0.2596651647218211,43793.0,0.9857484102249146,0.0511756390333175,0.2540606307889602,43793.0,9857.94836807251,14889.342999219894,9857.94836807251,5029.308331251144,1.2543549537658691,0.0 -30200,0.06534482,0.030238906,,,,,,,,,,,,,,,,, -30300,0.052271325,0.031641927,,,,,,,,,,,,,,,,, -30400,0.04024244,0.031995438,,,,,,,,,,,,,,,,, -30500,0.043221205,0.032213256,,,,,,,,,,,,,,,,, -30600,0.050446935,0.029360691,,,,,,,,,,,,,,,,, -30700,0.050053183,0.030247185,,,,,,,,,,,,,,,,, -30800,0.050535828,0.030689523,,,,,,,,,,,,,,,,, -30900,0.04826834,0.031190654,,,,,,,,,,,,,,,,, -30913,,,0.9933783411979676,0.0210400521755218,0.6411639773982251,0.9866043329238892,0.0487705618143081,0.2531882033520808,43793.0,0.9857787489891052,0.0517778843641281,0.2512915422425292,43793.0,10098.119551181791,15248.962916851044,10098.119551181791,5148.699724435806,1.288421869277954,0.0 -31000,0.04435971,0.028622404,,,,,,,,,,,,,,,,, -31100,0.05419314,0.031821303,,,,,,,,,,,,,,,,, -31200,0.047911312,0.028115312,,,,,,,,,,,,,,,,, -31300,0.04854963,0.030898705,,,,,,,,,,,,,,,,, -31400,0.047011156,0.030185902,,,,,,,,,,,,,,,,, -31500,0.046812654,0.030095838,,,,,,,,,,,,,,,,, -31600,0.049126707,0.02995008,,,,,,,,,,,,,,,,, -31645,,,0.993828535079956,0.0197061952203512,0.6617192203318676,0.9866339564323424,0.0489429533481597,0.2552728387461361,43793.0,0.9857193827629088,0.0524006262421608,0.2476491451965961,43793.0,10338.317598104475,15610.73458480835,10338.317598104475,5270.216884851456,1.32468581199646,0.0 -31700,0.049898054,0.029293554,,,,,,,,,,,,,,,,, -31800,0.05110163,0.028456489,,,,,,,,,,,,,,,,, -31900,0.05491881,0.028969133,,,,,,,,,,,,,,,,, -32000,0.05214607,0.02968459,,,,,,,,,,,,,,,,, -32100,0.048233602,0.02827279,,,,,,,,,,,,,,,,, -32200,0.059166577,0.030483609,,,,,,,,,,,,,,,,, -32300,0.053141292,0.030207608,,,,,,,,,,,,,,,,, -32376,,,0.994285523891449,0.0187331493943929,0.6851187612466829,0.986394464969635,0.0492233633995056,0.2534544147804244,43793.0,0.9855828881263732,0.0523907206952571,0.2438136395282832,43793.0,10578.267916440964,15970.068334579468,10578.267916440964,5389.546606063843,1.3582172393798828,0.0 -32400,0.05096452,0.028999873,,,,,,,,,,,,,,,,, -32500,0.060685758,0.029302733,,,,,,,,,,,,,,,,, -32600,0.052852903,0.029630864,,,,,,,,,,,,,,,,, -32700,0.054976933,0.031288106,,,,,,,,,,,,,,,,, -32800,0.07480684,0.029961929,,,,,,,,,,,,,,,,, -32900,0.053318422,0.029023234,,,,,,,,,,,,,,,,, -33000,0.05686823,0.029779766,,,,,,,,,,,,,,,,, -33100,0.056242727,0.029778851,,,,,,,,,,,,,,,,, -33103,,,0.9938401579856871,0.0198705811053514,0.6640449668679048,0.9866132736206056,0.0494127534329891,0.2477094520285919,43793.0,0.9857964515686036,0.0526494719088077,0.2487113766268982,43793.0,10818.28539800644,16328.837913751602,10818.28539800644,5508.245011806488,1.3919477462768557,0.0 -33200,0.044989973,0.027685879,,,,,,,,,,,,,,,,, -33300,0.052740484,0.029622488,,,,,,,,,,,,,,,,, -33400,0.05596247,0.029407747,,,,,,,,,,,,,,,,, -33500,0.053889263,0.030326545,,,,,,,,,,,,,,,,, -33600,0.058213584,0.030240668,,,,,,,,,,,,,,,,, -33700,0.060888655,0.029075146,,,,,,,,,,,,,,,,, -33800,0.047607537,0.026588619,,,,,,,,,,,,,,,,, -33847,,,0.9930763244628906,0.0219397060573101,0.604933073567349,0.9864837527275084,0.0501733385026454,0.2480326956404477,43793.0,0.9856372475624084,0.0534129738807678,0.2489563508088895,43793.0,11058.477831840515,16685.099545240402,11058.477831840515,5624.259745597839,1.4259235858917236,0.0 -33900,0.06029255,0.03026274,,,,,,,,,,,,,,,,, -34000,0.058362532,0.029276617,,,,,,,,,,,,,,,,, -34100,0.060925663,0.032826796,,,,,,,,,,,,,,,,, -34200,0.057535827,0.028674087,,,,,,,,,,,,,,,,, -34300,0.06902096,0.031170443,,,,,,,,,,,,,,,,, -34400,0.061246257,0.032463986,,,,,,,,,,,,,,,,, -34500,0.061780985,0.028464567,,,,,,,,,,,,,,,,, -34589,,,0.9932107329368592,0.0217495448887348,0.6007094082387573,0.9864163994789124,0.0497723706066608,0.2514264851650297,43793.0,0.9856675267219543,0.0528992712497711,0.2486676367310245,43793.0,11298.49192237854,17037.724573135376,11298.49192237854,5736.817763805389,1.4586491584777832,0.0 -34600,0.053896636,0.02770837,,,,,,,,,,,,,,,,, -34700,0.05481114,0.029875398,,,,,,,,,,,,,,,,, -34800,0.056984127,0.028262885,,,,,,,,,,,,,,,,, -34900,0.08654992,0.030782836,,,,,,,,,,,,,,,,, -35000,0.056221996,0.029991947,,,,,,,,,,,,,,,,, -35100,0.06379572,0.030137228,,,,,,,,,,,,,,,,, -35200,0.06512853,0.02853557,,,,,,,,,,,,,,,,, -35300,0.0632976,0.029385267,,,,,,,,,,,,,,,,, -35331,,,0.993598461151123,0.020510371774435,0.6334068095016367,0.986381471157074,0.050090841948986,0.2464012148980412,43793.0,0.9855715036392212,0.0532553158700466,0.251005720832414,43793.0,11538.711616039276,17393.373587608337,11538.711616039276,5852.193566083908,1.4922783374786377,0.0 -35400,0.06573815,0.031111129,,,,,,,,,,,,,,,,, -35500,0.058699552,0.027934667,,,,,,,,,,,,,,,,, -35600,0.067240655,0.028918177,,,,,,,,,,,,,,,,, -35700,0.066655815,0.02912738,,,,,,,,,,,,,,,,, -35800,0.06516765,0.029023815,,,,,,,,,,,,,,,,, -35900,0.059746765,0.029048197,,,,,,,,,,,,,,,,, -36000,0.06455419,0.030376432,,,,,,,,,,,,,,,,, -36077,,,0.99366557598114,0.0199929028749465,0.6459082535084215,0.9864638447761536,0.0505931377410888,0.2457527230082513,43793.0,0.9856035113334656,0.0539854243397712,0.2435299307766965,43793.0,11778.746633768082,17743.378707647324,11778.746633768082,5962.109864473343,1.5256366729736328,0.0 -36100,0.057099402,0.027061721,,,,,,,,,,,,,,,,, -36200,0.06588316,0.027904334,,,,,,,,,,,,,,,,, -36300,0.06309819,0.029754832,,,,,,,,,,,,,,,,, -36400,0.06070309,0.027274758,,,,,,,,,,,,,,,,, -36500,0.061791386,0.030945532,,,,,,,,,,,,,,,,, -36600,0.064220265,0.028731491,,,,,,,,,,,,,,,,, -36700,0.06611037,0.028068064,,,,,,,,,,,,,,,,, -36800,0.07046598,0.028626002,,,,,,,,,,,,,,,,, -36814,,,0.9936034679412842,0.0200266540050506,0.6523864222962256,0.98634535074234,0.0508432760834693,0.2410080625010881,43793.0,0.9855087399482728,0.0544721521437168,0.2419251836295627,43793.0,12018.441509246826,18100.86488223076,12018.441509246826,6079.482671022415,1.9244272708892824,0.0 -36900,0.06697221,0.02969912,,,,,,,,,,,,,,,,, -37000,0.061282456,0.028208686,,,,,,,,,,,,,,,,, -37100,0.06022351,0.028161775,,,,,,,,,,,,,,,,, -37200,0.07603933,0.027522271,,,,,,,,,,,,,,,,, -37300,0.0725688,0.028489428,,,,,,,,,,,,,,,,, -37400,0.061621483,0.02625455,,,,,,,,,,,,,,,,, -37500,0.06016286,0.02774978,,,,,,,,,,,,,,,,, -37548,,,0.9947233200073242,0.0171383786946535,0.7183873072526152,0.9864943027496338,0.0512232780456543,0.246641803540727,43793.0,0.985682725906372,0.0545710735023021,0.2465838969900322,43793.0,12258.66541147232,18457.14681982994,12258.66541147232,6195.485498428345,1.959678649902344,0.0 -37600,0.064052045,0.027777757,,,,,,,,,,,,,,,,, -37700,0.05685796,0.026663346,,,,,,,,,,,,,,,,, -37800,0.06394784,0.029048733,,,,,,,,,,,,,,,,, -37900,0.06859481,0.026977455,,,,,,,,,,,,,,,,, -38000,0.064425744,0.026841428,,,,,,,,,,,,,,,,, -38100,0.05949572,0.024985295,,,,,,,,,,,,,,,,, -38200,0.06287559,0.026680425,,,,,,,,,,,,,,,,, -38282,,,0.995094358921051,0.0160580594092607,0.7305360156173912,0.9864216446876526,0.0519767627120018,0.2444132828398841,43793.0,0.985569417476654,0.0554567091166973,0.2416775400873894,43793.0,12498.67090535164,18815.900892019272,12498.67090535164,6314.177472352982,1.9961557388305664,0.0 -38300,0.08741095,0.029772585,,,,,,,,,,,,,,,,, -38400,0.062276367,0.02551366,,,,,,,,,,,,,,,,, -38500,0.06432529,0.027669825,,,,,,,,,,,,,,,,, -38600,0.059328906,0.025817132,,,,,,,,,,,,,,,,, -38700,0.088518545,0.030344523,,,,,,,,,,,,,,,,, -38800,0.068908475,0.027492624,,,,,,,,,,,,,,,,, -38900,0.06176239,0.027891912,,,,,,,,,,,,,,,,, -39000,0.07860531,0.025827283,,,,,,,,,,,,,,,,, -39020,,,0.9950499534606934,0.0164400376379489,0.7171123304405793,0.9864265322685242,0.0511059425771236,0.2482558400172647,43793.0,0.9855411648750304,0.0549594648182392,0.2413185281618792,43793.0,12738.752710580826,19172.32074737549,12738.752710580826,6430.45547246933,2.035463333129883,0.0 -39100,0.06945106,0.027015697,,,,,,,,,,,,,,,,, -39200,0.07316366,0.028241603,,,,,,,,,,,,,,,,, -39300,0.06922587,0.027810307,,,,,,,,,,,,,,,,, -39400,0.06388847,0.027628336,,,,,,,,,,,,,,,,, -39500,0.07327638,0.027442614,,,,,,,,,,,,,,,,, -39600,0.07042926,0.028318362,,,,,,,,,,,,,,,,, -39700,0.084246226,0.029317485,,,,,,,,,,,,,,,,, -39750,,,0.9946119785308838,0.0173363238573074,0.7018877172650734,0.9864224791526794,0.0517116338014602,0.2454003609325848,43793.0,0.9855074882507324,0.05539982765913,0.2396051910604039,43793.0,12978.81694483757,19527.48945236206,12978.81694483757,6545.504342556,2.070361852645874,0.0 -39800,0.069822945,0.028540658,,,,,,,,,,,,,,,,, -39900,0.07776056,0.027707241,,,,,,,,,,,,,,,,, -40000,0.08073113,0.027706457,,,,,,,,,,,,,,,,, -40100,0.07146173,0.0255618,,,,,,,,,,,,,,,,, -40200,0.069882005,0.025788706,,,,,,,,,,,,,,,,, -40300,0.08253495,0.028242087,,,,,,,,,,,,,,,,, -40400,0.08134946,0.02934411,,,,,,,,,,,,,,,,, -40480,,,0.9938182830810548,0.0190417077392339,0.6818837267982015,0.9863936305046082,0.0527338311076164,0.2413260157386747,43793.0,0.9855066537857056,0.0563833639025688,0.2406186509886538,43793.0,13218.769835948944,19882.935147047043,13218.769835948944,6660.941750526428,2.1060330867767334,0.0 -40500,0.07204121,0.02828986,,,,,,,,,,,,,,,,, -40600,0.086335786,0.0277532,,,,,,,,,,,,,,,,, -40700,0.07816042,0.02896029,,,,,,,,,,,,,,,,, -40800,0.08527272,0.027025724,,,,,,,,,,,,,,,,, -40900,0.09803873,0.028086523,,,,,,,,,,,,,,,,, -41000,0.07770083,0.028778601,,,,,,,,,,,,,,,,, -41100,0.06564694,0.026288787,,,,,,,,,,,,,,,,, -41200,0.075904965,0.026908465,,,,,,,,,,,,,,,,, -41218,,,0.9941673874855042,0.0182603262364864,0.6868479320912788,0.986441969871521,0.0524128153920173,0.2450127274158767,43793.0,0.9854856133461,0.0560379549860954,0.232853285548857,43793.0,13458.771996974943,20232.86994457245,13458.771996974943,6770.818592071533,2.141692876815796,0.0 -41300,0.08041581,0.027663503,,,,,,,,,,,,,,,,, -41400,0.064959705,0.025739126,,,,,,,,,,,,,,,,, -41500,0.07573091,0.026276026,,,,,,,,,,,,,,,,, -41600,0.06817163,0.025427718,,,,,,,,,,,,,,,,, -41700,0.06796409,0.025722824,,,,,,,,,,,,,,,,, -41800,0.08099882,0.027144076,,,,,,,,,,,,,,,,, -41900,0.07684375,0.026941087,,,,,,,,,,,,,,,,, -41956,,,0.9942653775215148,0.0178059693425893,0.6941094896473745,0.9863846898078918,0.053409494459629,0.2359936574011092,43793.0,0.9855399131774902,0.0567248314619064,0.2337525892945866,43793.0,13698.72559428215,20588.09084820748,13698.72559428215,6886.031506538391,2.176521062850952,0.0 -42000,0.0849018,0.026165225,,,,,,,,,,,,,,,,, -42100,0.08859459,0.026723117,,,,,,,,,,,,,,,,, -42200,0.08637159,0.02869345,,,,,,,,,,,,,,,,, -42300,0.0679312,0.025883976,,,,,,,,,,,,,,,,, -42400,0.09346773,0.029330224,,,,,,,,,,,,,,,,, -42500,0.08228185,0.026161505,,,,,,,,,,,,,,,,, -42600,0.09097318,0.025490394,,,,,,,,,,,,,,,,, -42698,,,0.9948505759239196,0.0164919458329677,0.7170254342784487,0.9862925410270692,0.0528776720166206,0.2373408264054241,43793.0,0.9854699969291688,0.056529015302658,0.2339992933288823,43793.0,13938.932447195051,20938.051344156265,13938.932447195051,6995.729851484299,2.2115795612335205,0.0 -42700,0.07896229,0.024875289,,,,,,,,,,,,,,,,, -42800,0.076012686,0.026176626,,,,,,,,,,,,,,,,, -42900,0.06987217,0.024281971,,,,,,,,,,,,,,,,, -43000,0.07776409,0.027636338,,,,,,,,,,,,,,,,, -43100,0.08215397,0.027394399,,,,,,,,,,,,,,,,, -43200,0.07769484,0.026017183,,,,,,,,,,,,,,,,, -43300,0.07761934,0.024157813,,,,,,,,,,,,,,,,, -43400,0.07776691,0.025302345,,,,,,,,,,,,,,,,, -43447,,,0.9954681992530824,0.0149921560660004,0.7636412831561692,0.986258029937744,0.0534884706139564,0.2398435653849362,43793.0,0.985352098941803,0.0571636594831943,0.232133301550064,43793.0,14179.018709421158,21290.509116888046,14179.018709421158,7108.043355226517,2.249853610992432,0.0 -43500,0.083105266,0.025857145,,,,,,,,,,,,,,,,, -43600,0.088407345,0.02698587,,,,,,,,,,,,,,,,, -43700,0.082259916,0.02654022,,,,,,,,,,,,,,,,, -43800,0.08374699,0.025642613,,,,,,,,,,,,,,,,, -43900,0.07450858,0.027008865,,,,,,,,,,,,,,,,, -44000,0.06956342,0.025189398,,,,,,,,,,,,,,,,, -44100,0.07960503,0.026342072,,,,,,,,,,,,,,,,, -44197,,,0.995785117149353,0.0140907894819974,0.7848275030564575,0.9862913489341736,0.0545599684119224,0.2348540301396725,43793.0,0.9854013323783876,0.0583654716610908,0.2327379965088002,43793.0,14418.994359016418,21642.547493696213,14418.994359016418,7220.051170825958,2.2852609157562256,0.0 -44200,0.08408944,0.026824478,,,,,,,,,,,,,,,,, -44300,0.10449069,0.02757678,,,,,,,,,,,,,,,,, -44400,0.07792194,0.025194619,,,,,,,,,,,,,,,,, -44500,0.08149211,0.02563299,,,,,,,,,,,,,,,,, -44600,0.087151565,0.026593992,,,,,,,,,,,,,,,,, -44700,0.09255665,0.02594723,,,,,,,,,,,,,,,,, -44800,0.082642354,0.027594946,,,,,,,,,,,,,,,,, -44900,0.08644091,0.02594403,,,,,,,,,,,,,,,,, -44943,,,0.9962485432624816,0.0128924082964658,0.8084429781038803,0.9863153100013732,0.0548417158424854,0.2340596465030607,43793.0,0.9854552745819092,0.0584064722061157,0.2363809785728483,43793.0,14659.189492940905,21995.470474243164,14659.189492940905,7332.72357583046,2.32065486907959,0.0 -45000,0.09393503,0.027129102,,,,,,,,,,,,,,,,, -45100,0.07359832,0.02707053,,,,,,,,,,,,,,,,, -45200,0.08576524,0.026504522,,,,,,,,,,,,,,,,, -45300,0.07679406,0.025624672,,,,,,,,,,,,,,,,, -45400,0.08270404,0.026676541,,,,,,,,,,,,,,,,, -45500,0.08068537,0.026650088,,,,,,,,,,,,,,,,, -45600,0.08271144,0.023923416,,,,,,,,,,,,,,,,, -45683,,,0.9945888519287108,0.0169161148369312,0.7274850797457694,0.9861624836921692,0.0553164780139923,0.2296347303095366,43793.0,0.9852480292320251,0.0591854453086853,0.2316662683194748,43793.0,14899.14623594284,22346.908483743668,14899.14623594284,7444.150077819824,2.3560049533843994,0.0 -45700,0.09962385,0.025338069,,,,,,,,,,,,,,,,, -45800,0.0738515,0.025359467,,,,,,,,,,,,,,,,, -45900,0.085359134,0.025269778,,,,,,,,,,,,,,,,, -46000,0.080035426,0.024946095,,,,,,,,,,,,,,,,, -46100,0.085694686,0.024884364,,,,,,,,,,,,,,,,, -46200,0.08048576,0.026630618,,,,,,,,,,,,,,,,, -46300,0.083069086,0.027090922,,,,,,,,,,,,,,,,, -46400,0.08972683,0.024809754,,,,,,,,,,,,,,,,, -46426,,,0.9955013990402222,0.0145522383973002,0.7470824289675861,0.9862040877342224,0.0551997199654579,0.2317819172577487,43793.0,0.9853516817092896,0.0587752237915992,0.2365998127794417,43793.0,15139.321603536606,22696.00659751892,15139.321603536606,7553.0161652565,2.39290738105774,0.0 -46500,0.10156902,0.02456296,,,,,,,,,,,,,,,,, -46600,0.07776954,0.024199128,,,,,,,,,,,,,,,,, -46700,0.08404111,0.02572547,,,,,,,,,,,,,,,,, -46800,0.079112805,0.024908165,,,,,,,,,,,,,,,,, -46900,0.08941251,0.025378924,,,,,,,,,,,,,,,,, -47000,0.08711643,0.02396757,,,,,,,,,,,,,,,,, -47100,0.08177023,0.025220882,,,,,,,,,,,,,,,,, -47163,,,0.9950583577156068,0.0155198480933904,0.7462739776245799,0.9862487316131592,0.0558253675699234,0.233228704631362,43793.0,0.9853390455245972,0.0595657564699649,0.2342264089381991,43793.0,15379.40378499031,23047.339591503143,15379.40378499031,7664.210397481918,2.4298739433288574,0.0 -47200,0.08744607,0.025275141,,,,,,,,,,,,,,,,, -47300,0.090704456,0.024026234,,,,,,,,,,,,,,,,, -47400,0.10465175,0.025384491,,,,,,,,,,,,,,,,, -47500,0.083305895,0.024796097,,,,,,,,,,,,,,,,, -47600,0.09076708,0.023577232,,,,,,,,,,,,,,,,, -47700,0.08150592,0.023825664,,,,,,,,,,,,,,,,, -47800,0.09770529,0.025450885,,,,,,,,,,,,,,,,, -47900,0.10283344,0.023771659,,,,,,,,,,,,,,,,, -47902,,,0.9947388172149658,0.0161861889064311,0.7402327222700571,0.9861894249916076,0.0563515312969684,0.2244347422110065,43793.0,0.9852008819580078,0.0601598471403121,0.2281059140172945,43793.0,15619.506531715391,23398.714463949203,15619.506531715391,7775.425299167633,2.46665096282959,0.0 -48000,0.110680364,0.025736516,,,,,,,,,,,,,,,,, -48100,0.08692327,0.02500699,,,,,,,,,,,,,,,,, -48200,0.084056005,0.025555419,,,,,,,,,,,,,,,,, -48300,0.088985965,0.025100628,,,,,,,,,,,,,,,,, -48400,0.086700074,0.024294622,,,,,,,,,,,,,,,,, -48500,0.08963618,0.02510876,,,,,,,,,,,,,,,,, -48600,0.08634366,0.022513002,,,,,,,,,,,,,,,,, -48640,,,0.9944795966148376,0.0167838465422391,0.7291917979386155,0.986224353313446,0.056857194751501,0.2288818870293448,43793.0,0.9852421283721924,0.0606749877333641,0.2275132350776089,43793.0,15859.705310821531,23748.696949481964,15859.705310821531,7885.1501557827,2.50447678565979,0.0 -48700,0.08926924,0.024050796,,,,,,,,,,,,,,,,, -48800,0.093836874,0.023244314,,,,,,,,,,,,,,,,, -48900,0.091805466,0.025828786,,,,,,,,,,,,,,,,, -49000,0.09439604,0.024871275,,,,,,,,,,,,,,,,, -49100,0.08350367,0.024883326,,,,,,,,,,,,,,,,, -49200,0.08760892,0.024704438,,,,,,,,,,,,,,,,, -49300,0.092676565,0.024023592,,,,,,,,,,,,,,,,, -49384,,,0.9957962036132812,0.0137219009920954,0.7867519225622568,0.9862629175186156,0.0566055364906787,0.2319078980631626,43793.0,0.9852739572525024,0.0606629438698291,0.2257372623571563,43793.0,16099.769166469574,24102.322848558422,16099.769166469574,7998.654189348221,2.5425143241882324,0.0 -49400,0.09987727,0.025376631,,,,,,,,,,,,,,,,, -49500,0.08323977,0.023551142,,,,,,,,,,,,,,,,, -49600,0.10022226,0.023478722,,,,,,,,,,,,,,,,, -49700,0.09337637,0.02485057,,,,,,,,,,,,,,,,, -49800,0.08062202,0.023340264,,,,,,,,,,,,,,,,, -49900,0.09991593,0.023462163,,,,,,,,,,,,,,,,, -50000,0.08198984,0.021202745,,,,,,,,,,,,,,,,, -50100,0.09183007,0.02413864,,,,,,,,,,,,,,,,, -50127,,,0.996640145778656,0.0119273085147142,0.8215199199748869,0.9861634969711304,0.0570623017847538,0.2324574809517005,43793.0,0.9852084517478944,0.0609424151480197,0.2246761260090098,43793.0,16339.929149627686,24458.41442155838,16339.929149627686,8114.529004335403,2.579529523849488,0.0 -50200,0.099861585,0.024139702,,,,,,,,,,,,,,,,, -50300,0.09230964,0.023770165,,,,,,,,,,,,,,,,, -50400,0.087275706,0.022732887,,,,,,,,,,,,,,,,, -50500,0.09452105,0.023546416,,,,,,,,,,,,,,,,, -50600,0.08195619,0.022792796,,,,,,,,,,,,,,,,, -50700,0.09711565,0.025654972,,,,,,,,,,,,,,,,, -50800,0.088386595,0.0249202,,,,,,,,,,,,,,,,, -50874,,,0.9970665574073792,0.0110013792291283,0.8423189663398486,0.9861395359039308,0.0575702711939811,0.2293329472246431,43793.0,0.9852164387702942,0.0615736842155456,0.2228248464410566,43793.0,16580.062649965286,24813.826580524445,16580.062649965286,8229.751652002335,2.6154143810272217,0.0 -50900,0.099540375,0.023854198,,,,,,,,,,,,,,,,, -51000,0.08456451,0.023244236,,,,,,,,,,,,,,,,, -51100,0.08139661,0.023356183,,,,,,,,,,,,,,,,, -51200,0.09084138,0.022558311,,,,,,,,,,,,,,,,, -51300,0.07862441,0.022953039,,,,,,,,,,,,,,,,, -51400,0.08613624,0.022293504,,,,,,,,,,,,,,,,, -51500,0.09149995,0.023273055,,,,,,,,,,,,,,,,, -51600,0.07995702,0.022437628,,,,,,,,,,,,,,,,, -51617,,,0.996950089931488,0.011073999106884,0.8505485895345607,0.986198365688324,0.0580263286828994,0.2304312454150581,43793.0,0.9852501749992372,0.062121957540512,0.2230208669252613,43793.0,16820.25114750862,25167.75855875016,16820.25114750862,8343.433227062225,2.657003402709961,0.0 -51700,0.09450298,0.024046706,,,,,,,,,,,,,,,,, -51800,0.089114085,0.02337224,,,,,,,,,,,,,,,,, -51900,0.092781566,0.023992967,,,,,,,,,,,,,,,,, -52000,0.08392598,0.022306554,,,,,,,,,,,,,,,,, -52100,0.086315796,0.022866024,,,,,,,,,,,,,,,,, -52200,0.09765704,0.023385497,,,,,,,,,,,,,,,,, -52300,0.09442985,0.02396515,,,,,,,,,,,,,,,,, -52353,,,0.9966402053833008,0.0118224844336509,0.8342763710567649,0.9861005544662476,0.0578526854515075,0.2286860382086004,43793.0,0.985108196735382,0.062013104557991,0.2234894896601067,43793.0,17060.339695692062,25520.81440806389,17060.339695692062,8456.342872619629,2.694946765899658,0.0 -52400,0.0863384,0.022027832,,,,,,,,,,,,,,,,, -52500,0.102176785,0.024414549,,,,,,,,,,,,,,,,, -52600,0.13075095,0.024483263,,,,,,,,,,,,,,,,, -52700,0.08866654,0.023656009,,,,,,,,,,,,,,,,, -52800,0.09427155,0.023369487,,,,,,,,,,,,,,,,, -52900,0.088499375,0.022056712,,,,,,,,,,,,,,,,, -53000,0.08812785,0.022975665,,,,,,,,,,,,,,,,, -53100,0.101332396,0.024361787,,,,,,,,,,,,,,,,, -53106,,,0.9960552453994752,0.0127846905961632,0.808120967576912,0.986150085926056,0.0588770769536495,0.2253980118229107,43793.0,0.9851238131523132,0.0628400072455406,0.2246251051391217,43793.0,17300.507427692413,25869.773553848267,17300.507427692413,8565.075862884521,2.733237028121948,0.0 -53200,0.086263366,0.023769043,,,,,,,,,,,,,,,,, -53300,0.09970567,0.022402072,,,,,,,,,,,,,,,,, -53400,0.07653155,0.022812394,,,,,,,,,,,,,,,,, -53500,0.08520397,0.02330633,,,,,,,,,,,,,,,,, -53600,0.09703041,0.023743164,,,,,,,,,,,,,,,,, -53700,0.100480705,0.022110071,,,,,,,,,,,,,,,,, -53800,0.09354045,0.023504974,,,,,,,,,,,,,,,,, -53849,,,0.99503892660141,0.0151807256042957,0.755597599761523,0.986154556274414,0.0587789863348007,0.2282030288300124,43793.0,0.9851983189582824,0.0627905428409576,0.2211224959491054,43793.0,17540.69450187683,26223.72825813293,17540.69450187683,8678.787027597427,2.770055055618286,0.0 -53900,0.06835604,0.020168206,,,,,,,,,,,,,,,,, -54000,0.0834328,0.022725949,,,,,,,,,,,,,,,,, -54100,0.0829052,0.022727557,,,,,,,,,,,,,,,,, -54200,0.10042038,0.021911312,,,,,,,,,,,,,,,,, -54300,0.08419502,0.023472808,,,,,,,,,,,,,,,,, -54400,0.07800377,0.0226272,,,,,,,,,,,,,,,,, -54500,0.0997595,0.022475895,,,,,,,,,,,,,,,,, -54595,,,0.995265007019043,0.0146405976265668,0.7755410902328488,0.9860976934432985,0.0594463013112545,0.2263200967059221,43793.0,0.9851747751235962,0.0633826255798339,0.2202348065167658,43793.0,17780.679956674576,26574.07158923149,17780.679956674576,8789.08589553833,2.8089280128479004,0.0 -54600,0.095006816,0.022928264,,,,,,,,,,,,,,,,, -54700,0.0748983,0.02288728,,,,,,,,,,,,,,,,, -54800,0.08174396,0.021559348,,,,,,,,,,,,,,,,, -54900,0.088191904,0.021561125,,,,,,,,,,,,,,,,, -55000,0.09312162,0.023323158,,,,,,,,,,,,,,,,, -55100,0.0962314,0.023864986,,,,,,,,,,,,,,,,, -55200,0.10067153,0.023130208,,,,,,,,,,,,,,,,, -55300,0.096229225,0.022813767,,,,,,,,,,,,,,,,, -55339,,,0.9959477186203004,0.0129019608721137,0.8156166730274892,0.9861111044883728,0.0599651709198951,0.2242800732972598,43793.0,0.9852118492126464,0.0640837997198104,0.2175502037047203,43793.0,18020.666786193848,26924.650710582733,18020.666786193848,8899.616615772247,2.8510804176330566,0.0 -55400,0.07200422,0.022398073,,,,,,,,,,,,,,,,, -55500,0.08134283,0.022543292,,,,,,,,,,,,,,,,, -55600,0.08794257,0.023370957,,,,,,,,,,,,,,,,, -55700,0.094576515,0.022795891,,,,,,,,,,,,,,,,, -55800,0.085335195,0.02204328,,,,,,,,,,,,,,,,, -55900,0.09033691,0.0242606,,,,,,,,,,,,,,,,, -56000,0.09071415,0.021634309,,,,,,,,,,,,,,,,, -56087,,,0.9954010248184204,0.0140918241813778,0.8007883470937304,0.9860416650772096,0.0601380616426467,0.2228116284451532,43793.0,0.9850656390190125,0.0643724128603935,0.2164854876489734,43793.0,18260.69078636169,27273.23920822144,18260.69078636169,9008.119769573212,2.8922278881073,0.0 -56100,0.07873113,0.021922553,,,,,,,,,,,,,,,,, -56200,0.0975884,0.02247489,,,,,,,,,,,,,,,,, -56300,0.1058401,0.021600371,,,,,,,,,,,,,,,,, -56400,0.10813331,0.02316987,,,,,,,,,,,,,,,,, -56500,0.09282533,0.022408145,,,,,,,,,,,,,,,,, -56600,0.09122546,0.022106139,,,,,,,,,,,,,,,,, -56700,0.079892665,0.021735802,,,,,,,,,,,,,,,,, -56756,,,,,,,,,,,,,,18477.04327273369,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 2c3204fc8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -109.07290530204772,0.0,13.67792797088623,1,0,13.67792797088623,0.5224818587303162,0.7161954641342163,0.0278308650959588,43793,122.75087356567384,0.5250248908996582,0.7151214480400085,0.0225963444812421,0.5213832855224609,0.7166012525558472,0.0261423612204539,43793 -221.7531237602234,0.0207765102386474,253.8570437431336,746,0,253.8570437431336,0.983142077922821,0.0802183300256729,0.0382531656617273,43793,475.6508138179779,0.9866614937782288,0.0689144805073738,0.0343829040720385,0.9841179251670836,0.0773229226469993,0.0368152229443847,43793 -328.72299671173096,0.0476095676422119,493.8156957626343,1491,0,493.8156957626343,0.9835514426231384,0.0615411400794982,0.0909965974697616,43793,822.6260054111481,0.9872501492500304,0.0485505908727645,0.0815675372660544,0.9845539331436156,0.0582684986293315,0.087065007711593,43793 -440.2566766738892,0.0757808685302734,733.8995008468628,2243,0,733.8995008468628,0.9840973615646362,0.055955272167921,0.1340190769220492,43793,1174.291626214981,0.9877920150756836,0.0436826199293136,0.1433774872818228,0.9850483536720276,0.0529293678700923,0.1361727300785612,43793 -550.643904209137,0.1027662754058837,974.0467576980592,2990,0,974.0467576980592,0.9841592311859132,0.0553637146949768,0.1528859321678368,43793,1524.873197555542,0.9879328012466432,0.0425906032323837,0.1625567684371191,0.9850613474845886,0.0526109784841537,0.1509572942259753,43793 -659.9346287250519,0.134716510772705,1214.0511775016785,3731,0,1214.0511775016785,0.9846293330192566,0.0528262816369533,0.1723300618497228,43793,1874.2206497192385,0.9883593916893004,0.0405401736497879,0.1877052384465054,0.9855139851570128,0.0501751229166984,0.1707993933649933,43793 -769.0754227638245,0.1634418964385986,1454.1758043766022,4471,0,1454.1758043766022,0.984872341156006,0.0512144789099693,0.1853954365318094,43793,2223.534078359604,0.9885321259498596,0.039501205086708,0.2105221418082331,0.9856792092323304,0.048727061599493,0.1853128791746986,43793 -876.6488153934479,0.1911218166351318,1694.1631109714508,5214,0,1694.1631109714508,0.9848892092704772,0.0511879958212375,0.1917954177815328,43793,2571.1419973373413,0.9887311458587646,0.038518089801073,0.2226331766038031,0.9857717156410216,0.048647440969944,0.1899224175615532,43793 -985.5234348773956,0.2186455726623535,1934.2749044895168,5958,0,1934.2749044895168,0.9849742650985718,0.0500043183565139,0.2009979646605937,43793,2920.176949262619,0.9891310334205629,0.0371371954679489,0.2455989572156849,0.9858813285827636,0.0474062897264957,0.1994098609323364,43793 -1093.623272895813,0.24747896194458,2174.433957338333,6707,0,2174.433957338333,0.9851874113082886,0.0493319183588027,0.216469287800845,43793,3268.484499692917,0.9892842769622804,0.0365219675004482,0.2603477099825366,0.9860348105430604,0.0468565262854099,0.2105766077737178,43793 -1201.0050013065338,0.2745974063873291,2414.466232776642,7447,0,2414.466232776642,0.9851532578468324,0.0495551265776157,0.2166048805413218,43793,3615.946009159088,0.989247500896454,0.0362639613449573,0.273397267948866,0.9859694242477416,0.0470303744077682,0.2144172040557444,43793 -1307.7273774147034,0.3020987510681152,2654.612272977829,8197,0,2654.612272977829,0.9851823449134828,0.0489699803292751,0.2174333997986895,43793,3962.862035751343,0.9892252683639526,0.0363873355090618,0.2732933598362437,0.9859321117401124,0.046553336083889,0.2161812692848252,43793 -1419.5197608470917,0.3305845260620117,2894.6948940753937,8944,0,2894.6948940753937,0.9854089617729188,0.0488144159317016,0.2225351922664364,43793,4314.785717964172,0.989263653755188,0.0359275192022323,0.2766664071539449,0.9862779378890992,0.0461990498006343,0.2300533427368359,43793 -1528.5073668956757,0.3598239421844482,3134.852992296219,9690,0,3134.852992296219,0.9854245185852052,0.0481401570141315,0.2321340221277662,43793,4663.980555534363,0.9895308017730712,0.0351092107594013,0.297002617842767,0.9863014817237854,0.0454533360898494,0.2307974903300786,43793 -1634.1574666500092,0.3879969120025635,3374.9913704395294,10443,0,3374.9913704395294,0.9855690002441406,0.0480202287435531,0.2404512825341332,43793,5009.81748175621,0.9899264574050904,0.0339541770517826,0.3211785197160298,0.9864752292633056,0.0452593080699443,0.2440599289904552,43793 -1740.821620941162,0.4178082942962646,3614.948825597763,11189,0,3614.948825597763,0.9851945638656616,0.0487063974142074,0.2293547026063842,43793,5356.488221406937,0.9897028207778932,0.0340827777981758,0.3120466220388052,0.9860798716545104,0.0459315627813339,0.2344883683706195,43793 -1846.8564975261688,0.4476304054260254,3855.154119253159,11938,0,3855.154119253159,0.9857610464096068,0.0475765503942966,0.2491225917398928,43793,5702.779677867889,0.9901050925254822,0.0328376404941082,0.3517702449602501,0.9865832328796388,0.0449516586959362,0.2546755345667482,43793 -1952.983324766159,0.4773571491241455,4095.321917295456,12680,0,4095.321917295456,0.9856886267662048,0.0474681220948696,0.2495203850540048,43793,6049.124286413193,0.990390956401825,0.0319230556488037,0.3709610937739455,0.986536145210266,0.0446862578392028,0.2527580664724838,43793 -2065.694550514221,0.5062572956085205,4335.301592588425,13428,0,4335.301592588425,0.9858078360557556,0.0477101020514965,0.2436257342937955,43793,6401.863872051239,0.9904465079307556,0.0313658602535724,0.3848623867266765,0.986697256565094,0.0446103140711784,0.2574315640066265,43793 -2171.998544931412,0.5388197898864746,4575.268049478531,14163,0,4575.268049478531,0.9858394265174866,0.047771580517292,0.2455737676991713,43793,6748.188427686691,0.9905685186386108,0.0309013519436121,0.3852918889568212,0.9867021441459656,0.0448856018483638,0.248555407906438,43793 -2280.576951980591,0.5679218769073486,4815.410197257996,14900,0,4815.410197257996,0.985740840435028,0.0479223951697349,0.2417695698170992,43793,7096.95876121521,0.9904906153678894,0.0313890948891639,0.3731714129751292,0.9866453409194946,0.0449206680059433,0.2515589492994831,43793 -2385.3323168754578,0.599492073059082,5055.575888395309,15641,0,5055.575888395309,0.9857046008110046,0.0473469384014606,0.2544363572094351,43793,7441.932022809982,0.990464985370636,0.031555913388729,0.3589935910191276,0.9865288138389589,0.044826865196228,0.2577375655676164,43793 -2502.7095897197723,0.631049633026123,5295.590830564499,16376,0,5295.590830564499,0.9857821464538574,0.0473953075706958,0.2509488054962481,43793,7799.37594294548,0.990566611289978,0.0311063472181558,0.397761553804051,0.9866769909858704,0.0444979481399059,0.2670592840195563,43793 -2611.6243760585785,0.6633293628692627,5535.744588136673,17110,0,5535.744588136673,0.9858613014221193,0.0475315563380718,0.2466091120410939,43793,8148.500215291977,0.9904873967170716,0.0309950038790702,0.3928430456424588,0.9867208003997804,0.044674951583147,0.2598440136279063,43793 -2718.475531101227,0.6938979625701904,5775.890267133713,17841,0,5775.890267133713,0.9859004616737366,0.0474623404443264,0.2501062094638308,43793,8495.548310518265,0.9907149076461792,0.0303160939365625,0.4062286195899587,0.9867061972618104,0.0447667352855205,0.2629582640949041,43793 -2827.3079738616943,0.7254743576049805,6016.049191236496,18577,0,6016.049191236496,0.985743761062622,0.0472783781588077,0.2573864426298581,43793,8844.591564416885,0.9908282160758972,0.0297470297664403,0.4252668941110549,0.986599862575531,0.0443700589239597,0.2613065277900078,43793 -2931.338026046753,0.7550005912780762,6256.203496932983,19329,0,6256.203496932983,0.9858258962631226,0.0471059717237949,0.2557344872373129,43793,9188.825388908386,0.9912741780281068,0.028726851567626,0.4346732790575787,0.9866855144500732,0.044583573937416,0.2656386849488621,43793 -3043.9894185066223,0.7850522994995117,6496.22972202301,20076,0,6496.22972202301,0.9858339428901672,0.046953234821558,0.2626131459051882,43793,9541.552764177322,0.9913241863250732,0.0283257327973842,0.4522674479990455,0.9865933656692504,0.0445747785270214,0.2639246465543347,43793 -3157.6582396030426,0.8184309005737305,6736.18346953392,20821,0,6736.18346953392,0.9857593774795532,0.047393824905157,0.2585024581789923,43793,9895.228059530258,0.9910668134689332,0.0292000044137239,0.4294235006191442,0.9865803718566896,0.0445712096989154,0.2657850786309507,43793 -3264.571383714676,0.8491370677947998,6976.153180122376,21563,0,6976.153180122376,0.9858583807945251,0.0471864379942417,0.264102735060814,43793,10242.161269426346,0.9908154606819152,0.0297638848423957,0.4070325561547697,0.9867427349090576,0.0443494729697704,0.2697587835020477,43793 -3377.514060497284,0.8812205791473389,7216.347539186478,22303,0,7216.347539186478,0.9859581589698792,0.0471544824540615,0.2582019135397603,43793,10595.352976083755,0.9910184144973756,0.0295051932334899,0.4229856291489511,0.9867549538612366,0.0441675409674644,0.2666595006962322,43793 -3484.15520453453,0.9138262271881104,7456.5054433345795,23037,0,7456.5054433345795,0.986004114151001,0.0473821796476841,0.2643163832457119,43793,10942.205500364304,0.991021454334259,0.0294076334685087,0.431725123072511,0.9867947101593018,0.0446119531989097,0.2679407448618573,43793 -3592.053944826126,0.946509599685669,7696.663548946381,23786,0,7696.663548946381,0.9857631921768188,0.0475984513759613,0.2511069827343094,43793,11290.314909934998,0.9908633828163148,0.0297190472483634,0.4275191391629354,0.9866741299629213,0.0445358306169509,0.266686509655212,43793 -3700.659593105316,0.978080987930298,7936.902099847794,24530,0,7936.902099847794,0.9859198331832886,0.0475117191672325,0.2620446093170533,43793,11639.210319280624,0.991230607032776,0.0282744206488132,0.4655806054908825,0.9868429899215698,0.0444994159042835,0.2668326704110005,43793 -3803.438230276108,1.0112404823303225,8177.051397800446,25279,0,8177.051397800446,0.9858301281929016,0.0474405437707901,0.2627028867128916,43793,11982.191071748734,0.9913411736488342,0.0279610250145196,0.4508059522935717,0.986743986606598,0.0443639457225799,0.2656890057799961,43793 -3914.915863752365,1.0447053909301758,8417.059885501862,26024,0,8417.059885501862,0.9858301281929016,0.0477521158754825,0.260429039874971,43793,12333.731023311617,0.9914472103118896,0.0275970380753278,0.4639268114410071,0.9867309927940368,0.0446482859551906,0.2693065018483168,43793 -4023.420888900757,1.0764586925506592,8657.316404104233,26765,0,8657.316404104233,0.9860129356384276,0.0470789335668087,0.2639423076898141,43793,12682.5444586277,0.9916866421699524,0.0269755423069,0.4817028067889763,0.9868255853652954,0.0443985722959041,0.2758032396284843,43793 -4129.012980937958,1.1086671352386477,8897.550789117813,27507,0,8897.550789117813,0.9858406782150269,0.0475322641432285,0.2560654260851063,43793,13028.423025131226,0.9914228916168212,0.0279256198555231,0.4597945923570499,0.9868003726005554,0.0443090498447418,0.2730667285241964,43793 -4237.702749490738,1.1413967609405518,9137.598731279371,28250,0,9137.598731279371,0.9859476685523988,0.0476632639765739,0.2613195491298001,43793,13377.213291406631,0.9913296699523926,0.0280524995177984,0.4551051103635106,0.9868763089179992,0.0447614639997482,0.2724308668815321,43793 -4343.46866941452,1.1738147735595703,9377.87127995491,28993,0,9377.87127995491,0.9859139323234558,0.0477691814303398,0.258527015701871,43793,13723.3050699234,0.991166055202484,0.028592262417078,0.448091841051458,0.9867184162139891,0.0448287017643451,0.2768478901678559,43793 -4448.318630695343,1.2066032886505127,9618.070008039474,29745,0,9618.070008039474,0.9858364462852478,0.0477632582187652,0.2666181185574914,43793,14068.406420230864,0.991333782672882,0.0280229710042476,0.4609738230187459,0.986785352230072,0.0448971949517726,0.2769324470043009,43793 -4554.223012685776,1.2404890060424805,9858.108284711838,30499,0,9858.108284711838,0.9857665300369264,0.0479437932372093,0.2634665263536712,43793,14414.402602910995,0.9914437532424928,0.0275432635098695,0.4647404749149301,0.9866364002227784,0.0448987856507301,0.2743575633883375,43793 -4664.768091201782,1.2731256484985352,10098.096035957336,31246,0,10098.096035957336,0.9858992099761964,0.0478649251163005,0.261310824095927,43793,14764.987814426422,0.991542637348175,0.0271708536893129,0.4721364663286831,0.9867780804634094,0.0449287667870521,0.2721670726653356,43793 -4771.527981519699,1.307499647140503,10338.08183336258,31991,0,10338.08183336258,0.9860554933547974,0.0480229407548904,0.2684884152472522,43793,15111.78766322136,0.9917091727256776,0.0264940895140171,0.5052146545409436,0.9868125915527344,0.0450993701815605,0.2773346138812802,43793 -4873.972446680069,1.34173321723938,10578.065612077711,32748,0,10578.065612077711,0.985779583454132,0.0479509383440017,0.2652142004790609,43793,15454.269760608671,0.9920130968093872,0.0255514904856681,0.5120100156677214,0.9867143034934998,0.0450541898608207,0.2664563256925483,43793 -4986.211427688599,1.3757221698760986,10818.260104179382,33481,0,10818.260104179382,0.9859358668327332,0.0482783876359462,0.2660083338093088,43793,15806.758916139604,0.9919853806495668,0.0255288798362016,0.5158409082918909,0.986780881881714,0.0451483465731143,0.2665157521690351,43793 -5092.624956607819,1.4126989841461182,11058.269440174105,34209,0,11058.269440174105,0.9858587980270386,0.0481285229325294,0.2619270975504091,43793,16153.24194407463,0.9919201135635376,0.0261252131313085,0.4842983115492911,0.9868369102478028,0.0448420867323875,0.2726715227388324,43793 -5202.5957906246185,1.445598840713501,11298.358862400057,34945,0,11298.358862400057,0.98597252368927,0.0481930114328861,0.2647074282382107,43793,16503.35486650467,0.991661548614502,0.0268019344657659,0.4983068987457481,0.9867764711380004,0.0450432524085044,0.2768250623646225,43793 -5308.205292224884,1.4793949127197266,11538.379633426666,35690,0,11538.379633426666,0.9858074188232422,0.048274740576744,0.2578979398603731,43793,16849.039041757584,0.9916203618049622,0.0269064083695411,0.4801378497014065,0.9867455959320068,0.0451300702989101,0.2683754941878571,43793 -5419.337324857712,1.5126359462738037,11778.522836446762,36436,0,11778.522836446762,0.9859842658042908,0.0479830466210842,0.2682538265776146,43793,17200.367270946503,0.9917700290679932,0.0262035969644784,0.5078279699959175,0.9868600368499756,0.0449077449738979,0.279754069660172,43793 -5525.407692909241,1.5457723140716553,12018.514259815216,37188,0,12018.514259815216,0.9859581589698792,0.0480935610830783,0.265025083874453,43793,17546.48204088211,0.9918221831321716,0.0260329116135835,0.4927025177955553,0.9868361353874208,0.0448870733380317,0.2795414593961319,43793 -5628.5393006801605,1.5792734622955322,12258.54852104187,37936,0,12258.54852104187,0.9859695434570312,0.048520628362894,0.2581425592235508,43793,17889.70045185089,0.9920411109924316,0.0253579262644052,0.521290996199623,0.9868373274803162,0.0454892143607139,0.2716746889496871,43793 -5734.337277889252,1.6127097606658936,12498.732014417648,38681,0,12498.732014417648,0.9859328866004944,0.0484734661877155,0.2601893972789249,43793,18235.73575282097,0.9922906160354614,0.0244758687913417,0.5456552959656337,0.9868028163909912,0.0452447049319744,0.2777050270904259,43793 -5842.779334545136,1.6485440731048584,12738.766446590424,39429,0,12738.766446590424,0.985876441001892,0.0485768802464008,0.2619989013770163,43793,18584.267076969147,0.9927371740341188,0.0233322009444236,0.5693393673773479,0.9867658615112304,0.0452545844018459,0.2783459699096976,43793 -5945.3051841259,1.685373306274414,12978.871631383896,40178,0,12978.871631383896,0.9857779145240784,0.0483333766460418,0.2643365091769727,43793,18926.955935239792,0.9926030039787292,0.0237077176570892,0.5622201792353998,0.9866286516189576,0.0452627837657928,0.272614353780963,43793 -6054.370141029358,1.7200186252593994,13219.148606538773,40922,0,13219.148606538773,0.9859089255332948,0.0489551685750484,0.262507941629819,43793,19276.35291576385,0.9924303889274596,0.0241570603102445,0.531331590513631,0.986861288547516,0.0455635003745555,0.2747620842624564,43793 -6158.383851766586,1.759134292602539,13459.148141384125,41668,0,13459.148141384125,0.9858528971672058,0.0490683577954769,0.2610362912805396,43793,19620.42688894272,0.9921717047691344,0.0247845407575368,0.5356050889597663,0.986797571182251,0.0456201955676078,0.2794463055468996,43793 -6267.478013277054,1.794682502746582,13699.192795991898,42412,0,13699.192795991898,0.9858364462852478,0.0495792776346206,0.2634238571883222,43793,19969.621764421463,0.9920896291732788,0.0249362252652645,0.5138407443009074,0.9867837429046632,0.0462259538471698,0.2705062063880303,43793 -6369.788389205933,1.8295137882232664,13939.405371427536,43161,0,13939.405371427536,0.9857585430145264,0.0493027940392494,0.2641423804382998,43793,20312.19924545288,0.9922144412994384,0.0246965419501066,0.5323905620648921,0.9867528676986694,0.0458684600889682,0.277431209463629,43793 -6479.280866146088,1.8684642314910889,14179.438196897509,43909,0,14179.438196897509,0.9857884645462036,0.0491280667483806,0.2697252043773885,43793,20661.7832839489,0.9923882484436036,0.0241861324757337,0.5461527346502013,0.9866071343421936,0.04615493491292,0.27460194184325,43793 -6583.301622629166,1.906172752380371,14419.610810041428,44660,0,14419.610810041428,0.9857361912727356,0.049146506935358,0.2656135139712025,43793,21006.03396821022,0.9926992058753968,0.0229804348200559,0.573160671933495,0.9866668581962584,0.0460957139730453,0.2783106684866889,43793 -6688.38343501091,1.94104528427124,14659.80555152893,45411,0,14659.80555152893,0.9858831763267516,0.050001386553049,0.2606782498808195,43793,21351.36517882347,0.992841899394989,0.0224623624235391,0.5858940454668893,0.9867216348648072,0.0465570129454135,0.273357246040638,43793 -6791.985899925232,1.9806413650512693,14900.003878116608,46154,0,14900.003878116608,0.9857484102249146,0.0501597858965396,0.262155570037263,43793,21695.22643852234,0.9932268857955932,0.0213325824588537,0.6116497269315911,0.9866270422935486,0.0467495769262313,0.2717662805382729,43793 -6898.519830942154,2.017240524291992,15140.24789738655,46898,0,15140.24789738655,0.9857871532440186,0.0503666512668132,0.2615554241055192,43793,22042.06078195572,0.993324100971222,0.0210649482905864,0.6244343061105562,0.986622154712677,0.0470507629215717,0.2749869542278762,43793 -7000.982522249222,2.052924633026123,15380.289312124252,47655,0,15380.289312124252,0.9858199954032898,0.0504102669656276,0.2589229370037478,43793,22384.61990666389,0.9933059215545654,0.0211957413703203,0.6117602752374381,0.9866501688957214,0.0472495593130588,0.2753124853835367,43793 -7109.367336988449,2.09198260307312,15620.493504047394,48397,0,15620.493504047394,0.985643982887268,0.051022358238697,0.2503140930026999,43793,22733.268282413483,0.9927676320075988,0.0228091683238744,0.5803819765058651,0.986539363861084,0.0478117503225803,0.2669560169577921,43793 -7211.308455705643,2.13135313987732,15860.683614492416,49145,0,15860.683614492416,0.9856220483779908,0.05133892968297,0.2546934945497422,43793,23075.4584479332,0.9926386475563048,0.0231125317513942,0.5560028585107328,0.986544668674469,0.0476967096328735,0.2754886023087445,43793 -7319.88063287735,2.16735315322876,16100.88074851036,49889,0,16100.88074851036,0.985731601715088,0.0513281859457492,0.256285389852942,43793,23424.28323435784,0.992885947227478,0.0223179999738931,0.5716329006994851,0.9866303205490112,0.0478951223194599,0.2716456802267858,43793 -7424.692152500153,2.204059839248657,16341.046777009964,50635,0,16341.046777009964,0.9856178760528564,0.0514837466180324,0.2527862156016627,43793,23769.31679654121,0.9929296374320984,0.0218769125640392,0.6079837780471014,0.9864983558654784,0.0479838177561759,0.2703280493989189,43793 -7529.710796117783,2.239814043045044,16581.230067253113,51387,0,16581.230067253113,0.985525608062744,0.0517373047769069,0.25799338239721,43793,24114.57426643372,0.99312424659729,0.0213737860321998,0.6051922907115049,0.986455738544464,0.0481710955500602,0.2786032396239108,43793 -7630.984179973602,2.2781577110290527,16821.317676067352,52126,0,16821.317676067352,0.9856654405593872,0.0520572513341903,0.2631199337380247,43793,24455.993393421173,0.9935033917427064,0.0201828982681036,0.6243054543350697,0.9865494966506958,0.0487725138664245,0.2713997153176349,43793 -7733.823001623154,2.314923048019409,17061.454341888428,52874,0,17061.454341888428,0.9856258630752563,0.0528428927063941,0.2543117989191361,43793,24799.02546787262,0.9940038323402404,0.0186669565737247,0.6656468057073384,0.9866250157356262,0.0489798784255981,0.2706539868005573,43793 -7845.143984079361,2.351712226867676,17301.6496155262,53612,0,17301.6496155262,0.9855963587760924,0.0533251948654651,0.2544714874167593,43793,25150.59912610054,0.9942200183868408,0.0180659666657447,0.6819209665010129,0.9864720106124878,0.0495971888303756,0.2684478956022052,43793 -7950.844016551971,2.3931870460510254,17541.733896255493,54367,0,17541.733896255493,0.9855533838272096,0.0534865520894527,0.2591776566716485,43793,25496.44469404221,0.994297206401825,0.0178672652691602,0.6827159617126499,0.9865003824234008,0.0496802628040313,0.2730149322743814,43793 -8059.008177280426,2.4300525188446045,17781.83658337593,55116,0,17781.83658337593,0.985566020011902,0.0536334998905658,0.2579593188538949,43793,25844.768416643143,0.9939027428627014,0.0188218895345926,0.6740066988974451,0.9864963293075562,0.0502084381878376,0.2762999691316146,43793 -8162.010338068008,2.467262029647827,18021.98615550995,55871,0,18021.98615550995,0.9854717254638672,0.0546558536589145,0.2514616051676566,43793,26187.977236509323,0.9938641786575316,0.0188779514282941,0.6532421190749456,0.9863501787185668,0.0509905144572258,0.2727208477607921,43793 -8265.10645365715,2.506730794906616,18262.12407398224,56616,0,18262.12407398224,0.985351026058197,0.05505384877324104,0.2538779185836647,43793,26531.270143032074,0.9935514330863953,0.019919833168387413,0.6342096266526925,0.9862856864929199,0.051337894052267075,0.2656026553756837,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/measurements.csv deleted file mode 100644 index b840f6019..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/measurements.csv +++ /dev/null @@ -1,652 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.736839,0.7145253,,,,,,,,,,,,,,,,, -1,,,0.5250248908996582,0.7151214480400085,0.0225963444812421,0.5213832855224609,0.7166012525558472,0.0261423612204539,43793.0,0.5224818587303162,0.7161954641342163,0.0278308650959588,43793.0,13.67792797088623,122.75087356567384,13.67792797088623,109.07290530204772,0.0,0.0 -100,0.6021669,0.42064476,,,,,,,,,,,,,,,,, -200,0.35215497,0.31573054,,,,,,,,,,,,,,,,, -300,0.256844,0.22945008,,,,,,,,,,,,,,,,, -400,0.1691628,0.16078915,,,,,,,,,,,,,,,,, -500,0.10682292,0.11055862,,,,,,,,,,,,,,,,, -600,0.06593774,0.08887928,,,,,,,,,,,,,,,,, -700,0.047252987,0.06710344,,,,,,,,,,,,,,,,, -746,,,0.9866614937782288,0.0689144805073738,0.0343829040720385,0.9841179251670836,0.0773229226469993,0.0368152229443847,43793.0,0.983142077922821,0.0802183300256729,0.0382531656617273,43793.0,253.8570437431336,475.6508138179779,253.8570437431336,221.7531237602234,0.0207765102386474,0.0 -800,0.03477175,0.06740332,,,,,,,,,,,,,,,,, -900,0.21892506,0.06215878,,,,,,,,,,,,,,,,, -1000,0.070104554,0.05902601,,,,,,,,,,,,,,,,, -1100,0.2174903,0.0557108,,,,,,,,,,,,,,,,, -1200,0.23384,0.05258089,,,,,,,,,,,,,,,,, -1300,0.19029747,0.049876235,,,,,,,,,,,,,,,,, -1400,0.18845022,0.046419527,,,,,,,,,,,,,,,,, -1491,,,0.9872501492500304,0.0485505908727645,0.0815675372660544,0.9845539331436156,0.0582684986293315,0.087065007711593,43793.0,0.9835514426231384,0.0615411400794982,0.0909965974697616,43793.0,493.8156957626343,822.6260054111481,493.8156957626343,328.72299671173096,0.0476095676422119,0.0 -1500,0.3349249,0.043733105,,,,,,,,,,,,,,,,, -1600,0.5290166,0.048917573,,,,,,,,,,,,,,,,, -1700,0.24359909,0.050353877,,,,,,,,,,,,,,,,, -1800,0.22130764,0.04851325,,,,,,,,,,,,,,,,, -1900,0.08006981,0.04855258,,,,,,,,,,,,,,,,, -2000,0.2269593,0.043706965,,,,,,,,,,,,,,,,, -2100,0.13584577,0.040103484,,,,,,,,,,,,,,,,, -2200,0.20477715,0.044035576,,,,,,,,,,,,,,,,, -2243,,,0.9877920150756836,0.0436826199293136,0.1433774872818228,0.9850483536720276,0.0529293678700923,0.1361727300785612,43793.0,0.9840973615646362,0.055955272167921,0.1340190769220492,43793.0,733.8995008468628,1174.291626214981,733.8995008468628,440.2566766738892,0.0757808685302734,0.0 -2300,0.13167384,0.040757183,,,,,,,,,,,,,,,,, -2400,0.066449836,0.042905,,,,,,,,,,,,,,,,, -2500,0.077699766,0.03792321,,,,,,,,,,,,,,,,, -2600,0.11209747,0.04507673,,,,,,,,,,,,,,,,, -2700,0.10873799,0.039019763,,,,,,,,,,,,,,,,, -2800,0.09168806,0.038135044,,,,,,,,,,,,,,,,, -2900,0.1295042,0.038485657,,,,,,,,,,,,,,,,, -2990,,,0.9879328012466432,0.0425906032323837,0.1625567684371191,0.9850613474845886,0.0526109784841537,0.1509572942259753,43793.0,0.9841592311859132,0.0553637146949768,0.1528859321678368,43793.0,974.0467576980592,1524.873197555542,974.0467576980592,550.643904209137,0.1027662754058837,0.0 -3000,0.10263753,0.04362779,,,,,,,,,,,,,,,,, -3100,0.07696137,0.043589164,,,,,,,,,,,,,,,,, -3200,0.093104266,0.038259797,,,,,,,,,,,,,,,,, -3300,0.14268948,0.0408994,,,,,,,,,,,,,,,,, -3400,0.09506091,0.040460095,,,,,,,,,,,,,,,,, -3500,0.06653921,0.04140807,,,,,,,,,,,,,,,,, -3600,0.15273423,0.037124533,,,,,,,,,,,,,,,,, -3700,0.17528193,0.04204057,,,,,,,,,,,,,,,,, -3731,,,0.9883593916893004,0.0405401736497879,0.1877052384465054,0.9855139851570128,0.0501751229166984,0.1707993933649933,43793.0,0.9846293330192566,0.0528262816369533,0.1723300618497228,43793.0,1214.0511775016785,1874.2206497192385,1214.0511775016785,659.9346287250519,0.134716510772705,0.0 -3800,0.09811728,0.03999652,,,,,,,,,,,,,,,,, -3900,0.056501314,0.042116713,,,,,,,,,,,,,,,,, -4000,0.11763214,0.042799033,,,,,,,,,,,,,,,,, -4100,0.15791239,0.03723101,,,,,,,,,,,,,,,,, -4200,0.090160556,0.043308374,,,,,,,,,,,,,,,,, -4300,0.05498333,0.040764537,,,,,,,,,,,,,,,,, -4400,0.07852218,0.033192188,,,,,,,,,,,,,,,,, -4471,,,0.9885321259498596,0.039501205086708,0.2105221418082331,0.9856792092323304,0.048727061599493,0.1853128791746986,43793.0,0.984872341156006,0.0512144789099693,0.1853954365318094,43793.0,1454.1758043766022,2223.534078359604,1454.1758043766022,769.0754227638245,0.1634418964385986,0.0 -4500,0.06279481,0.040605348,,,,,,,,,,,,,,,,, -4600,0.052626032,0.043533158,,,,,,,,,,,,,,,,, -4700,0.053227656,0.042979885,,,,,,,,,,,,,,,,, -4800,0.05944386,0.040093504,,,,,,,,,,,,,,,,, -4900,0.055763375,0.037226845,,,,,,,,,,,,,,,,, -5000,0.053901605,0.03704632,,,,,,,,,,,,,,,,, -5100,0.16606782,0.040210962,,,,,,,,,,,,,,,,, -5200,0.06258809,0.039707787,,,,,,,,,,,,,,,,, -5214,,,0.9887311458587646,0.038518089801073,0.2226331766038031,0.9857717156410216,0.048647440969944,0.1899224175615532,43793.0,0.9848892092704772,0.0511879958212375,0.1917954177815328,43793.0,1694.1631109714508,2571.1419973373413,1694.1631109714508,876.6488153934479,0.1911218166351318,0.0 -5300,0.07562118,0.037473965,,,,,,,,,,,,,,,,, -5400,0.052678294,0.036628895,,,,,,,,,,,,,,,,, -5500,0.059052896,0.041039117,,,,,,,,,,,,,,,,, -5600,0.061886214,0.03470134,,,,,,,,,,,,,,,,, -5700,0.055679165,0.038081683,,,,,,,,,,,,,,,,, -5800,0.061589938,0.035124585,,,,,,,,,,,,,,,,, -5900,0.08357498,0.03804251,,,,,,,,,,,,,,,,, -5958,,,0.9891310334205629,0.0371371954679489,0.2455989572156849,0.9858813285827636,0.0474062897264957,0.1994098609323364,43793.0,0.9849742650985718,0.0500043183565139,0.2009979646605937,43793.0,1934.2749044895168,2920.176949262619,1934.2749044895168,985.5234348773956,0.2186455726623535,0.0 -6000,0.04007588,0.036132082,,,,,,,,,,,,,,,,, -6100,0.04287815,0.039033808,,,,,,,,,,,,,,,,, -6200,0.09314924,0.0377904,,,,,,,,,,,,,,,,, -6300,0.06799514,0.038170632,,,,,,,,,,,,,,,,, -6400,0.04632738,0.036061324,,,,,,,,,,,,,,,,, -6500,0.038746286,0.034776367,,,,,,,,,,,,,,,,, -6600,0.03790998,0.039060146,,,,,,,,,,,,,,,,, -6700,0.04153527,0.03867416,,,,,,,,,,,,,,,,, -6707,,,0.9892842769622804,0.0365219675004482,0.2603477099825366,0.9860348105430604,0.0468565262854099,0.2105766077737178,43793.0,0.9851874113082886,0.0493319183588027,0.216469287800845,43793.0,2174.433957338333,3268.484499692917,2174.433957338333,1093.623272895813,0.24747896194458,0.0 -6800,0.06068647,0.039179705,,,,,,,,,,,,,,,,, -6900,0.113292284,0.041014668,,,,,,,,,,,,,,,,, -7000,0.03938729,0.034434598,,,,,,,,,,,,,,,,, -7100,0.03949969,0.037655253,,,,,,,,,,,,,,,,, -7200,0.03465047,0.034526333,,,,,,,,,,,,,,,,, -7300,0.033339508,0.03677891,,,,,,,,,,,,,,,,, -7400,0.030816698,0.031546365,,,,,,,,,,,,,,,,, -7447,,,0.989247500896454,0.0362639613449573,0.273397267948866,0.9859694242477416,0.0470303744077682,0.2144172040557444,43793.0,0.9851532578468324,0.0495551265776157,0.2166048805413218,43793.0,2414.466232776642,3615.946009159088,2414.466232776642,1201.0050013065338,0.2745974063873291,0.0 -7500,0.02246668,0.036551137,,,,,,,,,,,,,,,,, -7600,0.03839017,0.034574695,,,,,,,,,,,,,,,,, -7700,0.045908514,0.04155614,,,,,,,,,,,,,,,,, -7800,0.055068288,0.03849049,,,,,,,,,,,,,,,,, -7900,0.025910081,0.03533487,,,,,,,,,,,,,,,,, -8000,0.035976514,0.037436914,,,,,,,,,,,,,,,,, -8100,0.038401585,0.0408932,,,,,,,,,,,,,,,,, -8197,,,0.9892252683639526,0.0363873355090618,0.2732933598362437,0.9859321117401124,0.046553336083889,0.2161812692848252,43793.0,0.9851823449134828,0.0489699803292751,0.2174333997986895,43793.0,2654.612272977829,3962.862035751343,2654.612272977829,1307.7273774147034,0.3020987510681152,0.0 -8200,0.023741024,0.03503423,,,,,,,,,,,,,,,,, -8300,0.02861127,0.038264386,,,,,,,,,,,,,,,,, -8400,0.028115548,0.03601785,,,,,,,,,,,,,,,,, -8500,0.037882067,0.034309767,,,,,,,,,,,,,,,,, -8600,0.037346795,0.037026733,,,,,,,,,,,,,,,,, -8700,0.028592968,0.03628683,,,,,,,,,,,,,,,,, -8800,0.031192716,0.034488,,,,,,,,,,,,,,,,, -8900,0.030570699,0.040583894,,,,,,,,,,,,,,,,, -8944,,,0.989263653755188,0.0359275192022323,0.2766664071539449,0.9862779378890992,0.0461990498006343,0.2300533427368359,43793.0,0.9854089617729188,0.0488144159317016,0.2225351922664364,43793.0,2894.6948940753937,4314.785717964172,2894.6948940753937,1419.5197608470917,0.3305845260620117,0.0 -9000,0.03298572,0.035478335,,,,,,,,,,,,,,,,, -9100,0.030820487,0.036394224,,,,,,,,,,,,,,,,, -9200,0.03660055,0.031488467,,,,,,,,,,,,,,,,, -9300,0.023644645,0.03526316,,,,,,,,,,,,,,,,, -9400,0.03133497,0.032986533,,,,,,,,,,,,,,,,, -9500,0.022027826,0.029941719,,,,,,,,,,,,,,,,, -9600,0.030747237,0.0355467,,,,,,,,,,,,,,,,, -9690,,,0.9895308017730712,0.0351092107594013,0.297002617842767,0.9863014817237854,0.0454533360898494,0.2307974903300786,43793.0,0.9854245185852052,0.0481401570141315,0.2321340221277662,43793.0,3134.852992296219,4663.980555534363,3134.852992296219,1528.5073668956757,0.3598239421844482,0.0 -9700,0.049494494,0.035948817,,,,,,,,,,,,,,,,, -9800,0.043123182,0.03520254,,,,,,,,,,,,,,,,, -9900,0.030428644,0.035541188,,,,,,,,,,,,,,,,, -10000,0.04196148,0.03318708,,,,,,,,,,,,,,,,, -10100,0.032536596,0.03796072,,,,,,,,,,,,,,,,, -10200,0.046462342,0.032425,,,,,,,,,,,,,,,,, -10300,0.02807419,0.038292233,,,,,,,,,,,,,,,,, -10400,0.050462946,0.03353299,,,,,,,,,,,,,,,,, -10443,,,0.9899264574050904,0.0339541770517826,0.3211785197160298,0.9864752292633056,0.0452593080699443,0.2440599289904552,43793.0,0.9855690002441406,0.0480202287435531,0.2404512825341332,43793.0,3374.9913704395294,5009.81748175621,3374.9913704395294,1634.1574666500092,0.3879969120025635,0.0 -10500,0.02533382,0.033726968,,,,,,,,,,,,,,,,, -10600,0.04094257,0.036109604,,,,,,,,,,,,,,,,, -10700,0.03826613,0.03455402,,,,,,,,,,,,,,,,, -10800,0.058827065,0.032403592,,,,,,,,,,,,,,,,, -10900,0.052210048,0.038215503,,,,,,,,,,,,,,,,, -11000,0.02716547,0.035228286,,,,,,,,,,,,,,,,, -11100,0.036192585,0.032407213,,,,,,,,,,,,,,,,, -11189,,,0.9897028207778932,0.0340827777981758,0.3120466220388052,0.9860798716545104,0.0459315627813339,0.2344883683706195,43793.0,0.9851945638656616,0.0487063974142074,0.2293547026063842,43793.0,3614.948825597763,5356.488221406937,3614.948825597763,1740.821620941162,0.4178082942962646,0.0 -11200,0.036542043,0.034671307,,,,,,,,,,,,,,,,, -11300,0.04492636,0.0356158,,,,,,,,,,,,,,,,, -11400,0.036197472,0.036404226,,,,,,,,,,,,,,,,, -11500,0.029020054,0.03383197,,,,,,,,,,,,,,,,, -11600,0.039071303,0.03302847,,,,,,,,,,,,,,,,, -11700,0.044017218,0.031303585,,,,,,,,,,,,,,,,, -11800,0.03423276,0.030956775,,,,,,,,,,,,,,,,, -11900,0.044632148,0.036648005,,,,,,,,,,,,,,,,, -11938,,,0.9901050925254822,0.0328376404941082,0.3517702449602501,0.9865832328796388,0.0449516586959362,0.2546755345667482,43793.0,0.9857610464096068,0.0475765503942966,0.2491225917398928,43793.0,3855.154119253159,5702.779677867889,3855.154119253159,1846.8564975261688,0.4476304054260254,0.0 -12000,0.044274516,0.032119695,,,,,,,,,,,,,,,,, -12100,0.044603087,0.035092365,,,,,,,,,,,,,,,,, -12200,0.044902608,0.031472713,,,,,,,,,,,,,,,,, -12300,0.031568598,0.02927003,,,,,,,,,,,,,,,,, -12400,0.044066705,0.031153886,,,,,,,,,,,,,,,,, -12500,0.040515874,0.031859156,,,,,,,,,,,,,,,,, -12600,0.037593894,0.032635137,,,,,,,,,,,,,,,,, -12680,,,0.990390956401825,0.0319230556488037,0.3709610937739455,0.986536145210266,0.0446862578392028,0.2527580664724838,43793.0,0.9856886267662048,0.0474681220948696,0.2495203850540048,43793.0,4095.321917295456,6049.124286413193,4095.321917295456,1952.983324766159,0.4773571491241455,0.0 -12700,0.043352485,0.031849034,,,,,,,,,,,,,,,,, -12800,0.05572111,0.03424288,,,,,,,,,,,,,,,,, -12900,0.043767657,0.031640735,,,,,,,,,,,,,,,,, -13000,0.041211747,0.03572922,,,,,,,,,,,,,,,,, -13100,0.0408718,0.035696056,,,,,,,,,,,,,,,,, -13200,0.08764862,0.038950656,,,,,,,,,,,,,,,,, -13300,0.066724144,0.02984381,,,,,,,,,,,,,,,,, -13400,0.062010728,0.03623871,,,,,,,,,,,,,,,,, -13428,,,0.9904465079307556,0.0313658602535724,0.3848623867266765,0.986697256565094,0.0446103140711784,0.2574315640066265,43793.0,0.9858078360557556,0.0477101020514965,0.2436257342937955,43793.0,4335.301592588425,6401.863872051239,4335.301592588425,2065.694550514221,0.5062572956085205,0.0 -13500,0.036679044,0.028757257,,,,,,,,,,,,,,,,, -13600,0.057010558,0.033051837,,,,,,,,,,,,,,,,, -13700,0.041812416,0.034102418,,,,,,,,,,,,,,,,, -13800,0.049281728,0.033949126,,,,,,,,,,,,,,,,, -13900,0.04605741,0.033768818,,,,,,,,,,,,,,,,, -14000,0.049451366,0.036450967,,,,,,,,,,,,,,,,, -14100,0.04046794,0.03523962,,,,,,,,,,,,,,,,, -14163,,,0.9905685186386108,0.0309013519436121,0.3852918889568212,0.9867021441459656,0.0448856018483638,0.248555407906438,43793.0,0.9858394265174866,0.047771580517292,0.2455737676991713,43793.0,4575.268049478531,6748.188427686691,4575.268049478531,2171.998544931412,0.5388197898864746,0.0 -14200,0.06362424,0.031623382,,,,,,,,,,,,,,,,, -14300,0.046233524,0.031017639,,,,,,,,,,,,,,,,, -14400,0.038497128,0.03064206,,,,,,,,,,,,,,,,, -14500,0.05959014,0.032637563,,,,,,,,,,,,,,,,, -14600,0.053658385,0.031412978,,,,,,,,,,,,,,,,, -14700,0.094402805,0.02973747,,,,,,,,,,,,,,,,, -14800,0.055632874,0.033446282,,,,,,,,,,,,,,,,, -14900,,,0.9904906153678894,0.0313890948891639,0.3731714129751292,0.9866453409194946,0.0449206680059433,0.2515589492994831,43793.0,0.985740840435028,0.0479223951697349,0.2417695698170992,43793.0,4815.410197257996,7096.95876121521,4815.410197257996,2280.576951980591,0.5679218769073486,0.0 -14900,0.0473877,0.028240627,,,,,,,,,,,,,,,,, -15000,0.057412382,0.03249536,,,,,,,,,,,,,,,,, -15100,0.05500905,0.028383046,,,,,,,,,,,,,,,,, -15200,0.057273533,0.03151725,,,,,,,,,,,,,,,,, -15300,0.0659813,0.034251783,,,,,,,,,,,,,,,,, -15400,0.057410166,0.035079237,,,,,,,,,,,,,,,,, -15500,0.054962236,0.031338204,,,,,,,,,,,,,,,,, -15600,0.060670476,0.028364832,,,,,,,,,,,,,,,,, -15641,,,0.990464985370636,0.031555913388729,0.3589935910191276,0.9865288138389589,0.044826865196228,0.2577375655676164,43793.0,0.9857046008110046,0.0473469384014606,0.2544363572094351,43793.0,5055.575888395309,7441.932022809982,5055.575888395309,2385.3323168754578,0.599492073059082,0.0 -15700,0.04803343,0.027616328,,,,,,,,,,,,,,,,, -15800,0.048537433,0.029991832,,,,,,,,,,,,,,,,, -15900,0.07461159,0.03438981,,,,,,,,,,,,,,,,, -16000,0.05298494,0.03131688,,,,,,,,,,,,,,,,, -16100,0.06412544,0.031403225,,,,,,,,,,,,,,,,, -16200,0.060572308,0.03164522,,,,,,,,,,,,,,,,, -16300,0.091379054,0.027775913,,,,,,,,,,,,,,,,, -16376,,,0.990566611289978,0.0311063472181558,0.397761553804051,0.9866769909858704,0.0444979481399059,0.2670592840195563,43793.0,0.9857821464538574,0.0473953075706958,0.2509488054962481,43793.0,5295.590830564499,7799.37594294548,5295.590830564499,2502.7095897197723,0.631049633026123,0.0 -16400,0.06400132,0.033776067,,,,,,,,,,,,,,,,, -16500,0.07492143,0.0336502,,,,,,,,,,,,,,,,, -16600,0.06381625,0.033010148,,,,,,,,,,,,,,,,, -16700,0.061619174,0.031150453,,,,,,,,,,,,,,,,, -16800,0.087606534,0.030720407,,,,,,,,,,,,,,,,, -16900,0.10866775,0.031861145,,,,,,,,,,,,,,,,, -17000,0.078274265,0.030415947,,,,,,,,,,,,,,,,, -17100,0.08684737,0.03067564,,,,,,,,,,,,,,,,, -17110,,,0.9904873967170716,0.0309950038790702,0.3928430456424588,0.9867208003997804,0.044674951583147,0.2598440136279063,43793.0,0.9858613014221193,0.0475315563380718,0.2466091120410939,43793.0,5535.744588136673,8148.500215291977,5535.744588136673,2611.6243760585785,0.6633293628692627,0.0 -17200,0.1536026,0.030742884,,,,,,,,,,,,,,,,, -17300,0.07382239,0.030801421,,,,,,,,,,,,,,,,, -17400,0.06358831,0.03031778,,,,,,,,,,,,,,,,, -17500,0.062845886,0.032329828,,,,,,,,,,,,,,,,, -17600,0.08270361,0.029631628,,,,,,,,,,,,,,,,, -17700,0.06703516,0.029642584,,,,,,,,,,,,,,,,, -17800,0.11028485,0.033385724,,,,,,,,,,,,,,,,, -17841,,,0.9907149076461792,0.0303160939365625,0.4062286195899587,0.9867061972618104,0.0447667352855205,0.2629582640949041,43793.0,0.9859004616737366,0.0474623404443264,0.2501062094638308,43793.0,5775.890267133713,8495.548310518265,5775.890267133713,2718.475531101227,0.6938979625701904,0.0 -17900,0.11197488,0.027784094,,,,,,,,,,,,,,,,, -18000,0.067915365,0.034992483,,,,,,,,,,,,,,,,, -18100,0.08190066,0.029424815,,,,,,,,,,,,,,,,, -18200,0.0725269,0.028342191,,,,,,,,,,,,,,,,, -18300,0.077497,0.03313861,,,,,,,,,,,,,,,,, -18400,0.10415502,0.030216075,,,,,,,,,,,,,,,,, -18500,0.057824142,0.03223843,,,,,,,,,,,,,,,,, -18577,,,0.9908282160758972,0.0297470297664403,0.4252668941110549,0.986599862575531,0.0443700589239597,0.2613065277900078,43793.0,0.985743761062622,0.0472783781588077,0.2573864426298581,43793.0,6016.049191236496,8844.591564416885,6016.049191236496,2827.3079738616943,0.7254743576049805,0.0 -18600,0.08562918,0.032601304,,,,,,,,,,,,,,,,, -18700,0.07364147,0.030829983,,,,,,,,,,,,,,,,, -18800,0.069378085,0.030716633,,,,,,,,,,,,,,,,, -18900,0.11099942,0.03331164,,,,,,,,,,,,,,,,, -19000,0.0819256,0.035855696,,,,,,,,,,,,,,,,, -19100,0.099616215,0.03202422,,,,,,,,,,,,,,,,, -19200,0.09815566,0.030684756,,,,,,,,,,,,,,,,, -19300,0.08593391,0.032336738,,,,,,,,,,,,,,,,, -19329,,,0.9912741780281068,0.028726851567626,0.4346732790575787,0.9866855144500732,0.044583573937416,0.2656386849488621,43793.0,0.9858258962631226,0.0471059717237949,0.2557344872373129,43793.0,6256.203496932983,9188.825388908386,6256.203496932983,2931.338026046753,0.7550005912780762,0.0 -19400,0.09935342,0.03246606,,,,,,,,,,,,,,,,, -19500,0.05749277,0.030227875,,,,,,,,,,,,,,,,, -19600,0.06873221,0.030473502,,,,,,,,,,,,,,,,, -19700,0.10682322,0.03081832,,,,,,,,,,,,,,,,, -19800,0.08970223,0.029136801,,,,,,,,,,,,,,,,, -19900,0.08452749,0.026673518,,,,,,,,,,,,,,,,, -20000,0.07530416,0.03210343,,,,,,,,,,,,,,,,, -20076,,,0.9913241863250732,0.0283257327973842,0.4522674479990455,0.9865933656692504,0.0445747785270214,0.2639246465543347,43793.0,0.9858339428901672,0.046953234821558,0.2626131459051882,43793.0,6496.22972202301,9541.552764177322,6496.22972202301,3043.9894185066223,0.7850522994995117,0.0 -20100,0.11734931,0.03040728,,,,,,,,,,,,,,,,, -20200,0.08230767,0.028510252,,,,,,,,,,,,,,,,, -20300,0.1802055,0.030787561,,,,,,,,,,,,,,,,, -20400,0.07656433,0.027256373,,,,,,,,,,,,,,,,, -20500,0.07922064,0.03110792,,,,,,,,,,,,,,,,, -20600,0.10187606,0.030905977,,,,,,,,,,,,,,,,, -20700,0.07212467,0.026190782,,,,,,,,,,,,,,,,, -20800,0.08009995,0.027450683,,,,,,,,,,,,,,,,, -20821,,,0.9910668134689332,0.0292000044137239,0.4294235006191442,0.9865803718566896,0.0445712096989154,0.2657850786309507,43793.0,0.9857593774795532,0.047393824905157,0.2585024581789923,43793.0,6736.18346953392,9895.228059530258,6736.18346953392,3157.6582396030426,0.8184309005737305,0.0 -20900,0.06713913,0.03170617,,,,,,,,,,,,,,,,, -21000,0.07985364,0.028972581,,,,,,,,,,,,,,,,, -21100,0.13494503,0.027873946,,,,,,,,,,,,,,,,, -21200,0.0760656,0.029366568,,,,,,,,,,,,,,,,, -21300,0.074129015,0.029796453,,,,,,,,,,,,,,,,, -21400,0.0804267,0.030239027,,,,,,,,,,,,,,,,, -21500,0.06519442,0.027177157,,,,,,,,,,,,,,,,, -21563,,,0.9908154606819152,0.0297638848423957,0.4070325561547697,0.9867427349090576,0.0443494729697704,0.2697587835020477,43793.0,0.9858583807945251,0.0471864379942417,0.264102735060814,43793.0,6976.153180122376,10242.161269426346,6976.153180122376,3264.571383714676,0.8491370677947998,0.0 -21600,0.08833916,0.029920915,,,,,,,,,,,,,,,,, -21700,0.081997365,0.02744542,,,,,,,,,,,,,,,,, -21800,0.09720677,0.027493516,,,,,,,,,,,,,,,,, -21900,0.089816816,0.034085974,,,,,,,,,,,,,,,,, -22000,0.08178619,0.03046926,,,,,,,,,,,,,,,,, -22100,0.08533926,0.029816784,,,,,,,,,,,,,,,,, -22200,0.0818455,0.02701943,,,,,,,,,,,,,,,,, -22300,0.090151206,0.030471938,,,,,,,,,,,,,,,,, -22303,,,0.9910184144973756,0.0295051932334899,0.4229856291489511,0.9867549538612366,0.0441675409674644,0.2666595006962322,43793.0,0.9859581589698792,0.0471544824540615,0.2582019135397603,43793.0,7216.347539186478,10595.352976083755,7216.347539186478,3377.514060497284,0.8812205791473389,0.0 -22400,0.09362395,0.026649183,,,,,,,,,,,,,,,,, -22500,0.06257619,0.026314937,,,,,,,,,,,,,,,,, -22600,0.1012362,0.030215316,,,,,,,,,,,,,,,,, -22700,0.07982333,0.026799807,,,,,,,,,,,,,,,,, -22800,0.09285153,0.027808854,,,,,,,,,,,,,,,,, -22900,0.06739385,0.02874063,,,,,,,,,,,,,,,,, -23000,0.08313236,0.029620532,,,,,,,,,,,,,,,,, -23037,,,0.991021454334259,0.0294076334685087,0.431725123072511,0.9867947101593018,0.0446119531989097,0.2679407448618573,43793.0,0.986004114151001,0.0473821796476841,0.2643163832457119,43793.0,7456.5054433345795,10942.205500364304,7456.5054433345795,3484.15520453453,0.9138262271881104,0.0 -23100,0.08989075,0.02397795,,,,,,,,,,,,,,,,, -23200,0.081693046,0.027078182,,,,,,,,,,,,,,,,, -23300,0.083626196,0.034261823,,,,,,,,,,,,,,,,, -23400,0.10344433,0.031037563,,,,,,,,,,,,,,,,, -23500,0.07829427,0.027953384,,,,,,,,,,,,,,,,, -23600,0.10906731,0.028222354,,,,,,,,,,,,,,,,, -23700,0.08319319,0.027445279,,,,,,,,,,,,,,,,, -23786,,,0.9908633828163148,0.0297190472483634,0.4275191391629354,0.9866741299629213,0.0445358306169509,0.266686509655212,43793.0,0.9857631921768188,0.0475984513759613,0.2511069827343094,43793.0,7696.663548946381,11290.314909934998,7696.663548946381,3592.053944826126,0.946509599685669,0.0 -23800,0.08973073,0.034891684,,,,,,,,,,,,,,,,, -23900,0.124471344,0.03132767,,,,,,,,,,,,,,,,, -24000,0.0935381,0.02963943,,,,,,,,,,,,,,,,, -24100,0.0955964,0.026068468,,,,,,,,,,,,,,,,, -24200,0.09120298,0.02785571,,,,,,,,,,,,,,,,, -24300,0.08115016,0.029117282,,,,,,,,,,,,,,,,, -24400,0.08393908,0.03280154,,,,,,,,,,,,,,,,, -24500,0.09302802,0.029361708,,,,,,,,,,,,,,,,, -24530,,,0.991230607032776,0.0282744206488132,0.4655806054908825,0.9868429899215698,0.0444994159042835,0.2668326704110005,43793.0,0.9859198331832886,0.0475117191672325,0.2620446093170533,43793.0,7936.902099847794,11639.210319280624,7936.902099847794,3700.659593105316,0.978080987930298,0.0 -24600,0.09260433,0.029962968,,,,,,,,,,,,,,,,, -24700,0.08827182,0.02972855,,,,,,,,,,,,,,,,, -24800,0.11770073,0.03200978,,,,,,,,,,,,,,,,, -24900,0.08574748,0.027505038,,,,,,,,,,,,,,,,, -25000,0.09341903,0.029766096,,,,,,,,,,,,,,,,, -25100,0.1223971,0.029021211,,,,,,,,,,,,,,,,, -25200,0.09125783,0.029107423,,,,,,,,,,,,,,,,, -25279,,,0.9913411736488342,0.0279610250145196,0.4508059522935717,0.986743986606598,0.0443639457225799,0.2656890057799961,43793.0,0.9858301281929016,0.0474405437707901,0.2627028867128916,43793.0,8177.051397800446,11982.191071748734,8177.051397800446,3803.438230276108,1.0112404823303225,0.0 -25300,0.11706819,0.03368289,,,,,,,,,,,,,,,,, -25400,0.085443586,0.030875364,,,,,,,,,,,,,,,,, -25500,0.22676492,0.030859629,,,,,,,,,,,,,,,,, -25600,0.11034724,0.030668613,,,,,,,,,,,,,,,,, -25700,0.13057776,0.030807614,,,,,,,,,,,,,,,,, -25800,0.09347975,0.0262819,,,,,,,,,,,,,,,,, -25900,0.13092731,0.02694117,,,,,,,,,,,,,,,,, -26000,0.09373378,0.02872861,,,,,,,,,,,,,,,,, -26024,,,0.9914472103118896,0.0275970380753278,0.4639268114410071,0.9867309927940368,0.0446482859551906,0.2693065018483168,43793.0,0.9858301281929016,0.0477521158754825,0.260429039874971,43793.0,8417.059885501862,12333.731023311617,8417.059885501862,3914.915863752365,1.0447053909301758,0.0 -26100,0.10036772,0.03274375,,,,,,,,,,,,,,,,, -26200,0.118522026,0.027453136,,,,,,,,,,,,,,,,, -26300,0.09645198,0.029531704,,,,,,,,,,,,,,,,, -26400,0.10355748,0.03190114,,,,,,,,,,,,,,,,, -26500,0.079072304,0.030764181,,,,,,,,,,,,,,,,, -26600,0.15135756,0.030518934,,,,,,,,,,,,,,,,, -26700,0.08217952,0.027449436,,,,,,,,,,,,,,,,, -26765,,,0.9916866421699524,0.0269755423069,0.4817028067889763,0.9868255853652954,0.0443985722959041,0.2758032396284843,43793.0,0.9860129356384276,0.0470789335668087,0.2639423076898141,43793.0,8657.316404104233,12682.5444586277,8657.316404104233,4023.420888900757,1.0764586925506592,0.0 -26800,0.12564352,0.031153563,,,,,,,,,,,,,,,,, -26900,0.10667296,0.02762316,,,,,,,,,,,,,,,,, -27000,0.081377916,0.026562508,,,,,,,,,,,,,,,,, -27100,0.10878079,0.028957883,,,,,,,,,,,,,,,,, -27200,0.10358591,0.028427841,,,,,,,,,,,,,,,,, -27300,0.14564958,0.026334409,,,,,,,,,,,,,,,,, -27400,0.10146204,0.032725178,,,,,,,,,,,,,,,,, -27500,0.08902009,0.02779387,,,,,,,,,,,,,,,,, -27507,,,0.9914228916168212,0.0279256198555231,0.4597945923570499,0.9868003726005554,0.0443090498447418,0.2730667285241964,43793.0,0.9858406782150269,0.0475322641432285,0.2560654260851063,43793.0,8897.550789117813,13028.423025131226,8897.550789117813,4129.012980937958,1.1086671352386477,0.0 -27600,0.1277747,0.027178034,,,,,,,,,,,,,,,,, -27700,0.08325973,0.02514218,,,,,,,,,,,,,,,,, -27800,0.103078924,0.029019447,,,,,,,,,,,,,,,,, -27900,0.12137801,0.031186137,,,,,,,,,,,,,,,,, -28000,0.10359739,0.028519003,,,,,,,,,,,,,,,,, -28100,0.17840406,0.026060611,,,,,,,,,,,,,,,,, -28200,0.11847715,0.02691201,,,,,,,,,,,,,,,,, -28250,,,0.9913296699523926,0.0280524995177984,0.4551051103635106,0.9868763089179992,0.0447614639997482,0.2724308668815321,43793.0,0.9859476685523988,0.0476632639765739,0.2613195491298001,43793.0,9137.598731279371,13377.213291406631,9137.598731279371,4237.702749490738,1.1413967609405518,0.0 -28300,0.08713887,0.027067108,,,,,,,,,,,,,,,,, -28400,0.101265006,0.028068941,,,,,,,,,,,,,,,,, -28500,0.10898972,0.030141177,,,,,,,,,,,,,,,,, -28600,0.12837097,0.029682232,,,,,,,,,,,,,,,,, -28700,0.10846912,0.030539438,,,,,,,,,,,,,,,,, -28800,0.07900927,0.028486358,,,,,,,,,,,,,,,,, -28900,0.10826957,0.027644351,,,,,,,,,,,,,,,,, -28993,,,0.991166055202484,0.028592262417078,0.448091841051458,0.9867184162139891,0.0448287017643451,0.2768478901678559,43793.0,0.9859139323234558,0.0477691814303398,0.258527015701871,43793.0,9377.87127995491,13723.3050699234,9377.87127995491,4343.46866941452,1.1738147735595703,0.0 -29000,0.09969186,0.028634647,,,,,,,,,,,,,,,,, -29100,0.09518878,0.024729643,,,,,,,,,,,,,,,,, -29200,0.11345313,0.028963761,,,,,,,,,,,,,,,,, -29300,0.12794687,0.02757029,,,,,,,,,,,,,,,,, -29400,0.09928039,0.026122812,,,,,,,,,,,,,,,,, -29500,0.10543552,0.027225226,,,,,,,,,,,,,,,,, -29600,0.13473192,0.02859547,,,,,,,,,,,,,,,,, -29700,0.12954295,0.031051252,,,,,,,,,,,,,,,,, -29745,,,0.991333782672882,0.0280229710042476,0.4609738230187459,0.986785352230072,0.0448971949517726,0.2769324470043009,43793.0,0.9858364462852478,0.0477632582187652,0.2666181185574914,43793.0,9618.070008039474,14068.406420230864,9618.070008039474,4448.318630695343,1.2066032886505127,0.0 -29800,0.104730025,0.028959202,,,,,,,,,,,,,,,,, -29900,0.09308511,0.027966155,,,,,,,,,,,,,,,,, -30000,0.08819442,0.028253717,,,,,,,,,,,,,,,,, -30100,0.090650745,0.028158126,,,,,,,,,,,,,,,,, -30200,0.11775338,0.027619174,,,,,,,,,,,,,,,,, -30300,0.10604989,0.028451607,,,,,,,,,,,,,,,,, -30400,0.13626857,0.029592052,,,,,,,,,,,,,,,,, -30499,,,0.9914437532424928,0.0275432635098695,0.4647404749149301,0.9866364002227784,0.0448987856507301,0.2743575633883375,43793.0,0.9857665300369264,0.0479437932372093,0.2634665263536712,43793.0,9858.108284711838,14414.402602910995,9858.108284711838,4554.223012685776,1.2404890060424805,0.0 -30500,0.13226764,0.0301831,,,,,,,,,,,,,,,,, -30600,0.13926813,0.027455254,,,,,,,,,,,,,,,,, -30700,0.09665813,0.027728837,,,,,,,,,,,,,,,,, -30800,0.10668887,0.028224839,,,,,,,,,,,,,,,,, -30900,0.13785307,0.028756442,,,,,,,,,,,,,,,,, -31000,0.10939129,0.02522112,,,,,,,,,,,,,,,,, -31100,0.10790816,0.030227426,,,,,,,,,,,,,,,,, -31200,0.09318721,0.025478154,,,,,,,,,,,,,,,,, -31246,,,0.991542637348175,0.0271708536893129,0.4721364663286831,0.9867780804634094,0.0449287667870521,0.2721670726653356,43793.0,0.9858992099761964,0.0478649251163005,0.261310824095927,43793.0,10098.096035957336,14764.987814426422,10098.096035957336,4664.768091201782,1.2731256484985352,0.0 -31300,0.11281898,0.028586738,,,,,,,,,,,,,,,,, -31400,0.10039009,0.026788304,,,,,,,,,,,,,,,,, -31500,0.09865778,0.027324548,,,,,,,,,,,,,,,,, -31600,0.091953106,0.027045788,,,,,,,,,,,,,,,,, -31700,0.105907716,0.028049849,,,,,,,,,,,,,,,,, -31800,0.08103709,0.024676166,,,,,,,,,,,,,,,,, -31900,0.086442344,0.024176726,,,,,,,,,,,,,,,,, -31991,,,0.9917091727256776,0.0264940895140171,0.5052146545409436,0.9868125915527344,0.0450993701815605,0.2773346138812802,43793.0,0.9860554933547974,0.0480229407548904,0.2684884152472522,43793.0,10338.08183336258,15111.78766322136,10338.08183336258,4771.527981519699,1.307499647140503,0.0 -32000,0.11892572,0.02744799,,,,,,,,,,,,,,,,, -32100,0.08831571,0.025271984,,,,,,,,,,,,,,,,, -32200,0.09780543,0.028493151,,,,,,,,,,,,,,,,, -32300,0.097212605,0.026527533,,,,,,,,,,,,,,,,, -32400,0.102972426,0.02651129,,,,,,,,,,,,,,,,, -32500,0.09014029,0.026756136,,,,,,,,,,,,,,,,, -32600,0.10879867,0.027127229,,,,,,,,,,,,,,,,, -32700,0.119701385,0.032041747,,,,,,,,,,,,,,,,, -32748,,,0.9920130968093872,0.0255514904856681,0.5120100156677214,0.9867143034934998,0.0450541898608207,0.2664563256925483,43793.0,0.985779583454132,0.0479509383440017,0.2652142004790609,43793.0,10578.065612077711,15454.269760608671,10578.065612077711,4873.972446680069,1.34173321723938,0.0 -32800,0.13549411,0.027122913,,,,,,,,,,,,,,,,, -32900,0.11202732,0.025416886,,,,,,,,,,,,,,,,, -33000,0.11598106,0.026874715,,,,,,,,,,,,,,,,, -33100,0.10183498,0.028363852,,,,,,,,,,,,,,,,, -33200,0.13789725,0.025728054,,,,,,,,,,,,,,,,, -33300,0.09970163,0.026913162,,,,,,,,,,,,,,,,, -33400,0.13446863,0.025044048,,,,,,,,,,,,,,,,, -33481,,,0.9919853806495668,0.0255288798362016,0.5158409082918909,0.986780881881714,0.0451483465731143,0.2665157521690351,43793.0,0.9859358668327332,0.0482783876359462,0.2660083338093088,43793.0,10818.260104179382,15806.758916139604,10818.260104179382,4986.211427688599,1.3757221698760986,0.0 -33500,0.15069126,0.028441653,,,,,,,,,,,,,,,,, -33600,0.17372347,0.0282326,,,,,,,,,,,,,,,,, -33700,0.14025053,0.025902739,,,,,,,,,,,,,,,,, -33800,0.12046373,0.02390339,,,,,,,,,,,,,,,,, -33900,0.118110046,0.027478974,,,,,,,,,,,,,,,,, -34000,0.1446732,0.028216688,,,,,,,,,,,,,,,,, -34100,0.15880476,0.032456305,,,,,,,,,,,,,,,,, -34200,0.10238378,0.025512792,,,,,,,,,,,,,,,,, -34209,,,0.9919201135635376,0.0261252131313085,0.4842983115492911,0.9868369102478028,0.0448420867323875,0.2726715227388324,43793.0,0.9858587980270386,0.0481285229325294,0.2619270975504091,43793.0,11058.269440174105,16153.24194407463,11058.269440174105,5092.624956607819,1.4126989841461182,0.0 -34300,0.11718505,0.02828773,,,,,,,,,,,,,,,,, -34400,0.12078089,0.030809866,,,,,,,,,,,,,,,,, -34500,0.08211133,0.025397303,,,,,,,,,,,,,,,,, -34600,0.11224197,0.024307054,,,,,,,,,,,,,,,,, -34700,0.10394813,0.028076926,,,,,,,,,,,,,,,,, -34800,0.10853259,0.026903395,,,,,,,,,,,,,,,,, -34900,0.18320216,0.028007645,,,,,,,,,,,,,,,,, -34945,,,0.991661548614502,0.0268019344657659,0.4983068987457481,0.9867764711380004,0.0450432524085044,0.2768250623646225,43793.0,0.98597252368927,0.0481930114328861,0.2647074282382107,43793.0,11298.358862400057,16503.35486650467,11298.358862400057,5202.5957906246185,1.445598840713501,0.0 -35000,0.10754332,0.028117076,,,,,,,,,,,,,,,,, -35100,0.12772134,0.03038621,,,,,,,,,,,,,,,,, -35200,0.13400324,0.026037471,,,,,,,,,,,,,,,,, -35300,0.09382946,0.026149642,,,,,,,,,,,,,,,,, -35400,0.13716875,0.029199304,,,,,,,,,,,,,,,,, -35500,0.12708369,0.025835197,,,,,,,,,,,,,,,,, -35600,0.11673727,0.02625477,,,,,,,,,,,,,,,,, -35690,,,0.9916203618049622,0.0269064083695411,0.4801378497014065,0.9867455959320068,0.0451300702989101,0.2683754941878571,43793.0,0.9858074188232422,0.048274740576744,0.2578979398603731,43793.0,11538.379633426666,16849.039041757584,11538.379633426666,5308.205292224884,1.4793949127197266,0.0 -35700,0.123256914,0.026134523,,,,,,,,,,,,,,,,, -35800,0.11875008,0.027254872,,,,,,,,,,,,,,,,, -35900,0.12575702,0.027720014,,,,,,,,,,,,,,,,, -36000,0.11799576,0.027249245,,,,,,,,,,,,,,,,, -36100,0.16092508,0.025735753,,,,,,,,,,,,,,,,, -36200,0.110522844,0.025918799,,,,,,,,,,,,,,,,, -36300,0.11978901,0.027610186,,,,,,,,,,,,,,,,, -36400,0.10125429,0.027245184,,,,,,,,,,,,,,,,, -36436,,,0.9917700290679932,0.0262035969644784,0.5078279699959175,0.9868600368499756,0.0449077449738979,0.279754069660172,43793.0,0.9859842658042908,0.0479830466210842,0.2682538265776146,43793.0,11778.522836446762,17200.367270946503,11778.522836446762,5419.337324857712,1.5126359462738037,0.0 -36500,0.14046817,0.029369634,,,,,,,,,,,,,,,,, -36600,0.111976504,0.02561981,,,,,,,,,,,,,,,,, -36700,0.16275305,0.025361653,,,,,,,,,,,,,,,,, -36800,0.11894762,0.02600778,,,,,,,,,,,,,,,,, -36900,0.12002781,0.027302884,,,,,,,,,,,,,,,,, -37000,0.11889536,0.026393287,,,,,,,,,,,,,,,,, -37100,0.11502461,0.025026392,,,,,,,,,,,,,,,,, -37188,,,0.9918221831321716,0.0260329116135835,0.4927025177955553,0.9868361353874208,0.0448870733380317,0.2795414593961319,43793.0,0.9859581589698792,0.0480935610830783,0.265025083874453,43793.0,12018.514259815216,17546.48204088211,12018.514259815216,5525.407692909241,1.5457723140716553,0.0 -37200,0.10761296,0.025560709,,,,,,,,,,,,,,,,, -37300,0.13817869,0.026232364,,,,,,,,,,,,,,,,, -37400,0.09769777,0.023071233,,,,,,,,,,,,,,,,, -37500,0.09886855,0.024536202,,,,,,,,,,,,,,,,, -37600,0.10308553,0.024926575,,,,,,,,,,,,,,,,, -37700,0.10048157,0.023202807,,,,,,,,,,,,,,,,, -37800,0.14603707,0.02683726,,,,,,,,,,,,,,,,, -37900,0.11407408,0.02453869,,,,,,,,,,,,,,,,, -37936,,,0.9920411109924316,0.0253579262644052,0.521290996199623,0.9868373274803162,0.0454892143607139,0.2716746889496871,43793.0,0.9859695434570312,0.048520628362894,0.2581425592235508,43793.0,12258.54852104187,17889.70045185089,12258.54852104187,5628.5393006801605,1.5792734622955322,0.0 -38000,0.12124412,0.024034113,,,,,,,,,,,,,,,,, -38100,0.11547659,0.021981388,,,,,,,,,,,,,,,,, -38200,0.10521701,0.02432024,,,,,,,,,,,,,,,,, -38300,0.16625719,0.028970446,,,,,,,,,,,,,,,,, -38400,0.13987188,0.023141455,,,,,,,,,,,,,,,,, -38500,0.1294718,0.02482087,,,,,,,,,,,,,,,,, -38600,0.10374128,0.024037518,,,,,,,,,,,,,,,,, -38681,,,0.9922906160354614,0.0244758687913417,0.5456552959656337,0.9868028163909912,0.0452447049319744,0.2777050270904259,43793.0,0.9859328866004944,0.0484734661877155,0.2601893972789249,43793.0,12498.732014417648,18235.73575282097,12498.732014417648,5734.337277889252,1.6127097606658936,0.0 -38700,0.1219764,0.028429534,,,,,,,,,,,,,,,,, -38800,0.108988576,0.025614798,,,,,,,,,,,,,,,,, -38900,0.10566133,0.02573707,,,,,,,,,,,,,,,,, -39000,0.106067665,0.023174666,,,,,,,,,,,,,,,,, -39100,0.10870171,0.022818418,,,,,,,,,,,,,,,,, -39200,0.14590849,0.027666101,,,,,,,,,,,,,,,,, -39300,0.18087757,0.024895208,,,,,,,,,,,,,,,,, -39400,0.10840877,0.026757594,,,,,,,,,,,,,,,,, -39429,,,0.9927371740341188,0.0233322009444236,0.5693393673773479,0.9867658615112304,0.0452545844018459,0.2783459699096976,43793.0,0.985876441001892,0.0485768802464008,0.2619989013770163,43793.0,12738.766446590424,18584.267076969147,12738.766446590424,5842.779334545136,1.6485440731048584,0.0 -39500,0.11925152,0.02494765,,,,,,,,,,,,,,,,, -39600,0.14222558,0.027364086,,,,,,,,,,,,,,,,, -39700,0.12950976,0.028815037,,,,,,,,,,,,,,,,, -39800,0.12047321,0.027284207,,,,,,,,,,,,,,,,, -39900,0.1374024,0.027291795,,,,,,,,,,,,,,,,, -40000,0.16162683,0.024976516,,,,,,,,,,,,,,,,, -40100,0.11154042,0.023520831,,,,,,,,,,,,,,,,, -40178,,,0.9926030039787292,0.0237077176570892,0.5622201792353998,0.9866286516189576,0.0452627837657928,0.272614353780963,43793.0,0.9857779145240784,0.0483333766460418,0.2643365091769727,43793.0,12978.871631383896,18926.955935239792,12978.871631383896,5945.3051841259,1.685373306274414,0.0 -40200,0.12383015,0.021970777,,,,,,,,,,,,,,,,, -40300,0.124488235,0.02675773,,,,,,,,,,,,,,,,, -40400,0.20040362,0.029138362,,,,,,,,,,,,,,,,, -40500,0.1388357,0.027559854,,,,,,,,,,,,,,,,, -40600,0.10831515,0.024108386,,,,,,,,,,,,,,,,, -40700,0.122093596,0.027579475,,,,,,,,,,,,,,,,, -40800,0.13068339,0.024306988,,,,,,,,,,,,,,,,, -40900,0.12305358,0.02658422,,,,,,,,,,,,,,,,, -40922,,,0.9924303889274596,0.0241570603102445,0.531331590513631,0.986861288547516,0.0455635003745555,0.2747620842624564,43793.0,0.9859089255332948,0.0489551685750484,0.262507941629819,43793.0,13219.148606538773,19276.35291576385,13219.148606538773,6054.370141029358,1.7200186252593994,0.0 -41000,0.13318568,0.026466155,,,,,,,,,,,,,,,,, -41100,0.22671421,0.025085526,,,,,,,,,,,,,,,,, -41200,0.12910992,0.025620714,,,,,,,,,,,,,,,,, -41300,0.11599074,0.027435184,,,,,,,,,,,,,,,,, -41400,0.1271121,0.023495153,,,,,,,,,,,,,,,,, -41500,0.16197388,0.024406936,,,,,,,,,,,,,,,,, -41600,0.10488578,0.021666905,,,,,,,,,,,,,,,,, -41668,,,0.9921717047691344,0.0247845407575368,0.5356050889597663,0.986797571182251,0.0456201955676078,0.2794463055468996,43793.0,0.9858528971672058,0.0490683577954769,0.2610362912805396,43793.0,13459.148141384125,19620.42688894272,13459.148141384125,6158.383851766586,1.759134292602539,0.0 -41700,0.14603692,0.024634626,,,,,,,,,,,,,,,,, -41800,0.13069415,0.02595767,,,,,,,,,,,,,,,,, -41900,0.11557842,0.026115984,,,,,,,,,,,,,,,,, -42000,0.13854352,0.021613063,,,,,,,,,,,,,,,,, -42100,0.13591155,0.024310125,,,,,,,,,,,,,,,,, -42200,0.1267052,0.028513636,,,,,,,,,,,,,,,,, -42300,0.119438134,0.0248337,,,,,,,,,,,,,,,,, -42400,0.124140754,0.02676857,,,,,,,,,,,,,,,,, -42412,,,0.9920896291732788,0.0249362252652645,0.5138407443009074,0.9867837429046632,0.0462259538471698,0.2705062063880303,43793.0,0.9858364462852478,0.0495792776346206,0.2634238571883222,43793.0,13699.192795991898,19969.621764421463,13699.192795991898,6267.478013277054,1.794682502746582,0.0 -42500,0.14990912,0.023698444,,,,,,,,,,,,,,,,, -42600,0.17579827,0.023495493,,,,,,,,,,,,,,,,, -42700,0.13489582,0.022370864,,,,,,,,,,,,,,,,, -42800,0.12502486,0.025416851,,,,,,,,,,,,,,,,, -42900,0.12459326,0.022068204,,,,,,,,,,,,,,,,, -43000,0.13895117,0.02660649,,,,,,,,,,,,,,,,, -43100,0.14014597,0.025320591,,,,,,,,,,,,,,,,, -43161,,,0.9922144412994384,0.0246965419501066,0.5323905620648921,0.9867528676986694,0.0458684600889682,0.277431209463629,43793.0,0.9857585430145264,0.0493027940392494,0.2641423804382998,43793.0,13939.405371427536,20312.19924545288,13939.405371427536,6369.788389205933,1.8295137882232664,0.0 -43200,0.11285074,0.022464246,,,,,,,,,,,,,,,,, -43300,0.12339846,0.02187039,,,,,,,,,,,,,,,,, -43400,0.13377766,0.024644068,,,,,,,,,,,,,,,,, -43500,0.14197966,0.023732832,,,,,,,,,,,,,,,,, -43600,0.16568418,0.024287922,,,,,,,,,,,,,,,,, -43700,0.13198633,0.024673168,,,,,,,,,,,,,,,,, -43800,0.13652891,0.022024646,,,,,,,,,,,,,,,,, -43900,0.17840986,0.026364014,,,,,,,,,,,,,,,,, -43909,,,0.9923882484436036,0.0241861324757337,0.5461527346502013,0.9866071343421936,0.04615493491292,0.27460194184325,43793.0,0.9857884645462036,0.0491280667483806,0.2697252043773885,43793.0,14179.438196897509,20661.7832839489,14179.438196897509,6479.280866146088,1.8684642314910889,0.0 -44000,0.1483784,0.023057638,,,,,,,,,,,,,,,,, -44100,0.14140876,0.025927098,,,,,,,,,,,,,,,,, -44200,0.1336503,0.025915656,,,,,,,,,,,,,,,,, -44300,0.13716958,0.024158848,,,,,,,,,,,,,,,,, -44400,0.14449677,0.022318043,,,,,,,,,,,,,,,,, -44500,0.15527993,0.022795653,,,,,,,,,,,,,,,,, -44600,0.17078355,0.02356684,,,,,,,,,,,,,,,,, -44660,,,0.9926992058753968,0.0229804348200559,0.573160671933495,0.9866668581962584,0.0460957139730453,0.2783106684866889,43793.0,0.9857361912727356,0.049146506935358,0.2656135139712025,43793.0,14419.610810041428,21006.03396821022,14419.610810041428,6583.301622629166,1.906172752380371,0.0 -44700,0.12481369,0.023058683,,,,,,,,,,,,,,,,, -44800,0.12408478,0.025794212,,,,,,,,,,,,,,,,, -44900,0.13099514,0.024425652,,,,,,,,,,,,,,,,, -45000,0.17665428,0.024804171,,,,,,,,,,,,,,,,, -45100,0.15774386,0.025994143,,,,,,,,,,,,,,,,, -45200,0.15533409,0.025191832,,,,,,,,,,,,,,,,, -45300,0.13686387,0.023817645,,,,,,,,,,,,,,,,, -45400,0.1571552,0.025480893,,,,,,,,,,,,,,,,, -45411,,,0.992841899394989,0.0224623624235391,0.5858940454668893,0.9867216348648072,0.0465570129454135,0.273357246040638,43793.0,0.9858831763267516,0.050001386553049,0.2606782498808195,43793.0,14659.80555152893,21351.36517882347,14659.80555152893,6688.38343501091,1.94104528427124,0.0 -45500,0.14710097,0.02477,,,,,,,,,,,,,,,,, -45600,0.13671237,0.02067505,,,,,,,,,,,,,,,,, -45700,0.1574154,0.02259048,,,,,,,,,,,,,,,,, -45800,0.1524067,0.022578038,,,,,,,,,,,,,,,,, -45900,0.16830732,0.022248698,,,,,,,,,,,,,,,,, -46000,0.14283164,0.022293445,,,,,,,,,,,,,,,,, -46100,0.14660063,0.024481526,,,,,,,,,,,,,,,,, -46154,,,0.9932268857955932,0.0213325824588537,0.6116497269315911,0.9866270422935486,0.0467495769262313,0.2717662805382729,43793.0,0.9857484102249146,0.0501597858965396,0.262155570037263,43793.0,14900.003878116608,21695.22643852234,14900.003878116608,6791.985899925232,1.9806413650512693,0.0 -46200,0.15547529,0.026264293,,,,,,,,,,,,,,,,, -46300,0.15082571,0.025106922,,,,,,,,,,,,,,,,, -46400,0.13589792,0.022335505,,,,,,,,,,,,,,,,, -46500,0.150032,0.021200445,,,,,,,,,,,,,,,,, -46600,0.1640567,0.021286087,,,,,,,,,,,,,,,,, -46700,0.1405661,0.023850897,,,,,,,,,,,,,,,,, -46800,0.12546849,0.022605505,,,,,,,,,,,,,,,,, -46898,,,0.993324100971222,0.0210649482905864,0.6244343061105562,0.986622154712677,0.0470507629215717,0.2749869542278762,43793.0,0.9857871532440186,0.0503666512668132,0.2615554241055192,43793.0,15140.24789738655,22042.06078195572,15140.24789738655,6898.519830942154,2.017240524291992,0.0 -46900,0.22120622,0.023738002,,,,,,,,,,,,,,,,, -47000,0.14670885,0.02081321,,,,,,,,,,,,,,,,, -47100,0.18003315,0.025368448,,,,,,,,,,,,,,,,, -47200,0.1380495,0.021463608,,,,,,,,,,,,,,,,, -47300,0.15568173,0.021428201,,,,,,,,,,,,,,,,, -47400,0.16212292,0.022715606,,,,,,,,,,,,,,,,, -47500,0.15139867,0.0217014,,,,,,,,,,,,,,,,, -47600,0.13986956,0.021000138,,,,,,,,,,,,,,,,, -47655,,,0.9933059215545654,0.0211957413703203,0.6117602752374381,0.9866501688957214,0.0472495593130588,0.2753124853835367,43793.0,0.9858199954032898,0.0504102669656276,0.2589229370037478,43793.0,15380.289312124252,22384.61990666389,15380.289312124252,7000.982522249222,2.052924633026123,0.0 -47700,0.13735642,0.020003092,,,,,,,,,,,,,,,,, -47800,0.1463703,0.022334946,,,,,,,,,,,,,,,,, -47900,0.13528906,0.020874955,,,,,,,,,,,,,,,,, -48000,0.16870062,0.023360448,,,,,,,,,,,,,,,,, -48100,0.16010739,0.02341183,,,,,,,,,,,,,,,,, -48200,0.14765534,0.022343783,,,,,,,,,,,,,,,,, -48300,0.15198925,0.02296871,,,,,,,,,,,,,,,,, -48397,,,0.9927676320075988,0.0228091683238744,0.5803819765058651,0.986539363861084,0.0478117503225803,0.2669560169577921,43793.0,0.985643982887268,0.051022358238697,0.2503140930026999,43793.0,15620.493504047394,22733.268282413483,15620.493504047394,7109.367336988449,2.09198260307312,0.0 -48400,0.13692689,0.020228606,,,,,,,,,,,,,,,,, -48500,0.15410727,0.023780694,,,,,,,,,,,,,,,,, -48600,0.18246831,0.018379247,,,,,,,,,,,,,,,,, -48700,0.16222301,0.022360263,,,,,,,,,,,,,,,,, -48800,0.18991527,0.021191435,,,,,,,,,,,,,,,,, -48900,0.13718462,0.022494884,,,,,,,,,,,,,,,,, -49000,0.1909854,0.022830095,,,,,,,,,,,,,,,,, -49100,0.15946627,0.02309107,,,,,,,,,,,,,,,,, -49145,,,0.9926386475563048,0.0231125317513942,0.5560028585107328,0.986544668674469,0.0476967096328735,0.2754886023087445,43793.0,0.9856220483779908,0.05133892968297,0.2546934945497422,43793.0,15860.683614492416,23075.4584479332,15860.683614492416,7211.308455705643,2.13135313987732,0.0 -49200,0.16683173,0.023038026,,,,,,,,,,,,,,,,, -49300,0.1730268,0.020879209,,,,,,,,,,,,,,,,, -49400,0.13976745,0.022509215,,,,,,,,,,,,,,,,, -49500,0.2967876,0.0216642,,,,,,,,,,,,,,,,, -49600,0.16009603,0.020551618,,,,,,,,,,,,,,,,, -49700,0.17480592,0.02372357,,,,,,,,,,,,,,,,, -49800,0.14876685,0.020651827,,,,,,,,,,,,,,,,, -49889,,,0.992885947227478,0.0223179999738931,0.5716329006994851,0.9866303205490112,0.0478951223194599,0.2716456802267858,43793.0,0.985731601715088,0.0513281859457492,0.256285389852942,43793.0,16100.88074851036,23424.28323435784,16100.88074851036,7319.88063287735,2.16735315322876,0.0 -49900,0.19719134,0.021389307,,,,,,,,,,,,,,,,, -50000,0.14810051,0.017312184,,,,,,,,,,,,,,,,, -50100,0.19066218,0.02201532,,,,,,,,,,,,,,,,, -50200,0.17327933,0.020549355,,,,,,,,,,,,,,,,, -50300,0.16700532,0.021712665,,,,,,,,,,,,,,,,, -50400,0.15321745,0.018744629,,,,,,,,,,,,,,,,, -50500,0.15031798,0.020851746,,,,,,,,,,,,,,,,, -50600,0.1603009,0.019913292,,,,,,,,,,,,,,,,, -50635,,,0.9929296374320984,0.0218769125640392,0.6079837780471014,0.9864983558654784,0.0479838177561759,0.2703280493989189,43793.0,0.9856178760528564,0.0514837466180324,0.2527862156016627,43793.0,16341.046777009964,23769.31679654121,16341.046777009964,7424.692152500153,2.204059839248657,0.0 -50700,0.187813,0.024096161,,,,,,,,,,,,,,,,, -50800,0.19957569,0.02370697,,,,,,,,,,,,,,,,, -50900,0.1658345,0.021271758,,,,,,,,,,,,,,,,, -51000,0.14283973,0.019198483,,,,,,,,,,,,,,,,, -51100,0.16901685,0.019048145,,,,,,,,,,,,,,,,, -51200,0.19009046,0.020884678,,,,,,,,,,,,,,,,, -51300,0.18659596,0.020399308,,,,,,,,,,,,,,,,, -51387,,,0.99312424659729,0.0213737860321998,0.6051922907115049,0.986455738544464,0.0481710955500602,0.2786032396239108,43793.0,0.985525608062744,0.0517373047769069,0.25799338239721,43793.0,16581.230067253113,24114.57426643372,16581.230067253113,7529.710796117783,2.239814043045044,0.0 -51400,0.16733319,0.020299917,,,,,,,,,,,,,,,,, -51500,0.24201816,0.02103727,,,,,,,,,,,,,,,,, -51600,0.17517053,0.01868013,,,,,,,,,,,,,,,,, -51700,0.17790444,0.020874282,,,,,,,,,,,,,,,,, -51800,0.18061341,0.02057172,,,,,,,,,,,,,,,,, -51900,0.18845785,0.024398722,,,,,,,,,,,,,,,,, -52000,0.18438555,0.018895965,,,,,,,,,,,,,,,,, -52100,0.19223951,0.020741545,,,,,,,,,,,,,,,,, -52126,,,0.9935033917427064,0.0201828982681036,0.6243054543350697,0.9865494966506958,0.0487725138664245,0.2713997153176349,43793.0,0.9856654405593872,0.0520572513341903,0.2631199337380247,43793.0,16821.317676067352,24455.993393421173,16821.317676067352,7630.984179973602,2.2781577110290527,0.0 -52200,0.16460893,0.020055931,,,,,,,,,,,,,,,,, -52300,0.19666918,0.021156793,,,,,,,,,,,,,,,,, -52400,0.2357962,0.018078396,,,,,,,,,,,,,,,,, -52500,0.19642217,0.021608936,,,,,,,,,,,,,,,,, -52600,0.19221126,0.020599242,,,,,,,,,,,,,,,,, -52700,0.1928895,0.020768927,,,,,,,,,,,,,,,,, -52800,0.1896994,0.019050447,,,,,,,,,,,,,,,,, -52874,,,0.9940038323402404,0.0186669565737247,0.6656468057073384,0.9866250157356262,0.0489798784255981,0.2706539868005573,43793.0,0.9856258630752563,0.0528428927063941,0.2543117989191361,43793.0,17061.454341888428,24799.02546787262,17061.454341888428,7733.823001623154,2.314923048019409,0.0 -52900,0.17830223,0.016565295,,,,,,,,,,,,,,,,, -53000,0.20355631,0.020406716,,,,,,,,,,,,,,,,, -53100,0.16470645,0.02144255,,,,,,,,,,,,,,,,, -53200,0.24846318,0.021662153,,,,,,,,,,,,,,,,, -53300,0.19254325,0.018298633,,,,,,,,,,,,,,,,, -53400,0.18597242,0.020859808,,,,,,,,,,,,,,,,, -53500,0.18395653,0.020812621,,,,,,,,,,,,,,,,, -53600,0.17567514,0.020015577,,,,,,,,,,,,,,,,, -53612,,,0.9942200183868408,0.0180659666657447,0.6819209665010129,0.9864720106124878,0.0495971888303756,0.2684478956022052,43793.0,0.9855963587760924,0.0533251948654651,0.2544714874167593,43793.0,17301.6496155262,25150.59912610054,17301.6496155262,7845.143984079361,2.351712226867676,0.0 -53700,0.19331384,0.01896084,,,,,,,,,,,,,,,,, -53800,0.26398012,0.020670796,,,,,,,,,,,,,,,,, -53900,0.18749891,0.016165107,,,,,,,,,,,,,,,,, -54000,0.19614023,0.01947771,,,,,,,,,,,,,,,,, -54100,0.1695908,0.018162685,,,,,,,,,,,,,,,,, -54200,0.18436573,0.017805193,,,,,,,,,,,,,,,,, -54300,0.22168273,0.019231794,,,,,,,,,,,,,,,,, -54367,,,0.994297206401825,0.0178672652691602,0.6827159617126499,0.9865003824234008,0.0496802628040313,0.2730149322743814,43793.0,0.9855533838272096,0.0534865520894527,0.2591776566716485,43793.0,17541.733896255493,25496.44469404221,17541.733896255493,7950.844016551971,2.3931870460510254,0.0 -54400,0.1906611,0.019206367,,,,,,,,,,,,,,,,, -54500,0.19659328,0.018368261,,,,,,,,,,,,,,,,, -54600,0.23210935,0.019609494,,,,,,,,,,,,,,,,, -54700,0.23011631,0.02051221,,,,,,,,,,,,,,,,, -54800,0.18980229,0.017285587,,,,,,,,,,,,,,,,, -54900,0.18964608,0.018330121,,,,,,,,,,,,,,,,, -55000,0.17200443,0.018594593,,,,,,,,,,,,,,,,, -55100,0.2967151,0.019147826,,,,,,,,,,,,,,,,, -55116,,,0.9939027428627014,0.0188218895345926,0.6740066988974451,0.9864963293075562,0.0502084381878376,0.2762999691316146,43793.0,0.985566020011902,0.0536334998905658,0.2579593188538949,43793.0,17781.83658337593,25844.768416643143,17781.83658337593,8059.008177280426,2.4300525188446045,0.0 -55200,0.21561469,0.01921948,,,,,,,,,,,,,,,,, -55300,0.2829247,0.019710217,,,,,,,,,,,,,,,,, -55400,0.19421977,0.018830225,,,,,,,,,,,,,,,,, -55500,0.2174335,0.021044744,,,,,,,,,,,,,,,,, -55600,0.2416606,0.020009812,,,,,,,,,,,,,,,,, -55700,0.19913425,0.019104932,,,,,,,,,,,,,,,,, -55800,0.19367217,0.017538283,,,,,,,,,,,,,,,,, -55871,,,0.9938641786575316,0.0188779514282941,0.6532421190749456,0.9863501787185668,0.0509905144572258,0.2727208477607921,43793.0,0.9854717254638672,0.0546558536589145,0.2514616051676566,43793.0,18021.98615550995,26187.977236509323,18021.98615550995,8162.010338068008,2.467262029647827,0.0 -55900,0.2283731,0.02063952,,,,,,,,,,,,,,,,, -56000,0.20434053,0.017291805,,,,,,,,,,,,,,,,, -56100,0.21188515,0.018006155,,,,,,,,,,,,,,,,, -56200,0.22481431,0.018805873,,,,,,,,,,,,,,,,, -56300,0.20529453,0.017140731,,,,,,,,,,,,,,,,, -56400,0.20457312,0.018704854,,,,,,,,,,,,,,,,, -56500,0.20750278,0.017918764,,,,,,,,,,,,,,,,, -56600,0.2373483,0.018878598,,,,,,,,,,,,,,,,, -56616,,,0.9935514330863952,0.0199198331683874,0.6342096266526925,0.98628568649292,0.051337894052267,0.2656026553756837,43793.0,0.985351026058197,0.055053848773241,0.2538779185836647,43793.0,18262.12407398224,26531.27014303208,18262.12407398224,8265.10645365715,2.506730794906616,0.0 -56700,0.20247404,0.017297111,,,,,,,,,,,,,,,,, -56800,0.25879198,0.017123615,,,,,,,,,,,,,,,,, -56900,0.24395241,0.016899979,,,,,,,,,,,,,,,,, -57000,0.20487665,0.018153138,,,,,,,,,,,,,,,,, -57100,0.22766806,0.017469235,,,,,,,,,,,,,,,,, -57200,0.23083131,0.018789748,,,,,,,,,,,,,,,,, -57282,,,,,,,,,,,,,,18477.27854990959,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 87d05f72c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -108.1970920562744,0.0,12.220922231674194,1,0,12.220922231674194,0.5224818587303162,0.7161954641342163,0.0278457039326483,43793,120.41805958747864,0.5250684022903442,0.715176522731781,0.0237950027495085,0.5213832855224609,0.7166012525558472,0.0261413871170089,43793 -217.11013174057007,0.0211904048919677,252.46294784545896,742,0,252.46294784545896,0.9831947088241576,0.0670140460133552,0.0409508164805676,43793,469.6144058704376,0.9868351221084596,0.0535385869443416,0.0419205494794269,0.9841658473014832,0.0638294294476509,0.0402648250883774,43793 -322.55107712745667,0.0476987361907959,492.63631868362427,1482,0,492.63631868362427,0.9832208156585692,0.0637183338403701,0.0687140178744884,43793,815.2754602432251,0.9870707988739014,0.0494044385850429,0.072728017218383,0.9842279553413392,0.0601286329329013,0.0695433154886204,43793 -431.5114159584045,0.0747332572937011,732.8385767936707,2230,0,732.8385767936707,0.9836769700050354,0.0601913519203662,0.1044002999493918,43793,1164.4848325252533,0.9873882532119752,0.0467899739742279,0.1048028276153056,0.9846578240394592,0.0567194893956184,0.0997552259313046,43793 -536.5080606937408,0.1023402214050293,972.8108472824096,2978,0,972.8108472824096,0.9839507341384888,0.0563553124666214,0.1253405511373911,43793,1509.5015892982483,0.9875863790512084,0.0444664433598518,0.133050320459334,0.9849082827568054,0.0534515306353569,0.1206809639673781,43793 -645.8973081111908,0.1297438144683838,1212.889556646347,3728,0,1212.889556646347,0.9840299487113952,0.0558254644274711,0.1258086693069963,43793,1859.016436338425,0.9877572655677797,0.0430839397013187,0.1463087512056293,0.9850422739982604,0.0526715070009231,0.1245638568835015,43793 -746.6748743057251,0.156790018081665,1452.92014336586,4478,0,1452.92014336586,0.984278440475464,0.0547199919819831,0.1437030584159531,43793,2199.8714084625244,0.9880750179290771,0.0414841137826442,0.1616528893787325,0.985185980796814,0.0514313615858554,0.144405763800987,43793 -851.5282611846924,0.1850535869598388,1692.9912858009338,5222,0,1692.9912858009338,0.9843176007270812,0.0537596717476844,0.1524632543170833,43793,2544.843962907791,0.9881662726402284,0.0413249917328357,0.1795573077062224,0.9852378964424132,0.0509257391095161,0.1538947014999502,43793 -955.644908428192,0.2118959426879882,1933.113857269287,5973,0,1933.113857269287,0.9843466877937316,0.0536030866205692,0.1532219121304731,43793,2889.129508972168,0.9882667064666748,0.0411820746958255,0.164725269141511,0.9852712154388428,0.0508301928639411,0.1485142866215823,43793 -1056.376545906067,0.2399752140045166,2173.093494415283,6726,0,2173.093494415283,0.98445326089859,0.0532456524670124,0.1665050708613972,43793,3229.888239145279,0.9882901310920716,0.0403869524598121,0.189390218607616,0.9854116439819336,0.0502158813178539,0.1654686194838573,43793 -1161.199847459793,0.2680511474609375,2413.061748027801,7478,0,2413.061748027801,0.9845025539398192,0.0525631867349147,0.1614758018767295,43793,3574.727502822876,0.9882825613021852,0.0407113768160343,0.1801941639881932,0.9854328036308287,0.0499150194227695,0.1604554826060077,43793 -1263.1020185947418,0.2975277900695801,2653.143779039383,8227,0,2653.143779039383,0.9846537113189696,0.0527768693864345,0.1617903733675443,43793,3916.760485887528,0.9884076118469238,0.0400715358555316,0.1891099807972201,0.9855748414993286,0.0497897751629352,0.1705180084490086,43793 -1368.509529352188,0.3260746002197265,2893.2505328655243,8974,0,2893.2505328655243,0.9845378994941713,0.0523785389959812,0.166351327124705,43793,4262.32323884964,0.9886220097541808,0.0399004854261875,0.1923071782334979,0.9855508804321288,0.0497305281460285,0.1655544296563194,43793 -1470.3253271579742,0.3538753986358642,3133.2117640972137,9727,0,3133.2117640972137,0.984627604484558,0.0532267913222312,0.1667354992957371,43793,4604.147631645203,0.9885628819465636,0.0397010818123817,0.1998062097585312,0.985582947731018,0.0500649698078632,0.1718920686631062,43793 -1576.149382352829,0.384164810180664,3373.1918189525604,10476,0,3373.1918189525604,0.9844890236854552,0.0524704307317733,0.1669158520132354,43793,4950.001894235611,0.9885133504867554,0.0394293665885925,0.200452295893541,0.9854490160942078,0.0496672354638576,0.1666263098371183,43793 -1679.6482055187223,0.4154295921325683,3613.176426172257,11226,0,3613.176426172257,0.9845783710479736,0.0523843765258789,0.1724580365021767,43793,5293.536472797394,0.9886851906776428,0.0390754826366901,0.2089894518928057,0.9855192303657532,0.0493711568415164,0.1749158195471736,43793 -1785.6903955936432,0.4441530704498291,3853.230010032654,11977,0,3853.230010032654,0.9847741723060608,0.0515562482178211,0.1748346744354908,43793,5639.680479049683,0.9885552525520324,0.0391332358121871,0.2056036590679171,0.985720992088318,0.0487687140703201,0.1817230665469874,43793 -1892.4185304641724,0.4730062484741211,4093.308432340622,12734,0,4093.308432340622,0.9846027493476868,0.0522534400224685,0.1657067585304924,43793,5986.535817146301,0.9885872602462769,0.0398699790239334,0.2044354306931992,0.9855306148529052,0.0496048107743263,0.1690678095134202,43793 -1997.831268787384,0.5020365715026855,4333.41264462471,13486,0,4333.41264462471,0.984676480293274,0.0518286414444446,0.1755089764659194,43793,6332.10172867775,0.9886804819107056,0.0392939932644367,0.2035504562494897,0.9856122136116028,0.048954602330923,0.1798079085889805,43793 -2103.823152303696,0.5319912433624268,4573.500229597092,14228,0,4573.500229597092,0.9847055673599244,0.0521335490047931,0.1784581841483761,43793,6678.234007120132,0.9885827898979188,0.0393496677279472,0.1989838706543882,0.9857429265975952,0.0490309186279773,0.1796446625647285,43793 -2209.81316947937,0.5619616508483887,4813.4798884391785,14978,0,4813.4798884391785,0.9847737550735474,0.0518770404160022,0.1807104800860751,43793,7024.253228664398,0.9885103106498718,0.0394258648157119,0.2103199067351833,0.9857169389724731,0.0490715727210044,0.1809956758013063,43793 -2309.904142856598,0.5947067737579346,5053.459381818771,15730,0,5053.459381818771,0.9846811294555664,0.0517447367310524,0.1799350506932672,43793,7364.3765437603,0.9885491132736206,0.0389850884675979,0.2031699322901911,0.985655665397644,0.0487467497587204,0.1846246212844854,43793 -2413.569778442383,0.625511884689331,5293.689656734467,16481,0,5293.689656734467,0.9846815466880798,0.0517053455114364,0.1696599787089885,43793,7708.322807788849,0.9886710047721864,0.039756067097187,0.2001787874556077,0.9856000542640686,0.0490147173404693,0.1758140324766489,43793 -2516.3563482761383,0.6553435325622559,5533.651293992996,17235,0,5533.651293992996,0.9848251342773438,0.0511091090738773,0.1786148488164832,43793,8051.120588302612,0.988852858543396,0.0381815806031227,0.2247231109183317,0.9858115315437316,0.0481926798820495,0.1914216357737809,43793 -2619.8639080524445,0.6865098476409912,5773.6067237854,17987,0,5773.6067237854,0.9846655130386353,0.0518954284489154,0.1731670891302522,43793,8394.634209156036,0.988699436187744,0.0385518334805965,0.2194349941750594,0.9855996370315552,0.0491237789392471,0.1775695241781132,43793 -2728.1255388259888,0.7215981483459473,6013.775787115097,18728,0,6013.775787115097,0.98477041721344,0.0521637499332428,0.1759919712575796,43793,8743.120332956314,0.9887787699699402,0.0383401215076446,0.2242228187728377,0.9857478141784668,0.0490990057587623,0.1829504384835823,43793 -2832.679259777069,0.7523925304412842,6253.961903810501,19480,0,6253.961903810501,0.9848179817199708,0.0520950555801391,0.1765644162673381,43793,9087.911134004593,0.988776445388794,0.0384742207825183,0.2229867901020646,0.9857891798019408,0.0489357411861419,0.1849021107267856,43793 -2935.130667924881,0.7841694355010986,6494.021353006363,20226,0,6494.021353006363,0.9848251342773438,0.0520286560058593,0.174959123510546,43793,9430.473743200302,0.9885373115539552,0.0390364378690719,0.2137670788115012,0.9857355952262878,0.0490761920809745,0.1825315784667147,43793 -3041.1337745189667,0.8143470287322998,6734.287105321884,20980,0,6734.287105321884,0.984847903251648,0.0528499558568,0.1779089723611548,43793,9776.793322563171,0.988884449005127,0.0387068316340446,0.2198519838767519,0.9858046174049376,0.0497435145080089,0.1831063059045827,43793 -3150.0082376003265,0.8462753295898438,6974.567994594574,21726,0,6974.567994594574,0.9848167300224304,0.0513973832130432,0.1776546379820925,43793,10126.000288248062,0.9888972640037536,0.0385029278695583,0.2088390445070321,0.9857494235038756,0.0482809171080589,0.1844016397267828,43793 -3253.414433002472,0.877194881439209,7214.62762594223,22477,0,7214.62762594223,0.984839916229248,0.0518183521926403,0.1856877932834978,43793,10469.517782449722,0.9887509942054749,0.0384994745254516,0.2163177758440355,0.9858188033103944,0.0485533736646175,0.1951387458625174,43793 -3362.0041534900665,0.9081747531890868,7454.670845508575,23223,0,7454.670845508575,0.9848597049713136,0.0520401038229465,0.173719214320403,43793,10818.201761245728,0.988722324371338,0.0386566929519176,0.2140264105266031,0.9857449531555176,0.0490856431424617,0.1800239360133262,43793 -3466.1369805336,0.9395201206207277,7694.770886421204,23962,0,7694.770886421204,0.9845973253250122,0.052184447646141,0.1705859883723574,43793,11162.487454891205,0.988548755645752,0.0393689386546611,0.2090295690843694,0.9855188131332396,0.0490764528512954,0.1787901620259714,43793 -3570.6146986484528,0.9710242748260498,7934.792539834976,24708,0,7934.792539834976,0.9849161505699158,0.0509380102157592,0.1800086315183968,43793,11507.038450717926,0.9888287782669068,0.0382855348289012,0.2273005416580073,0.9858277440071106,0.0481491759419441,0.1914915996375246,43793 -3676.598792552948,1.0045819282531738,8174.904351234436,25447,0,8174.904351234436,0.9848967790603638,0.0509436763823032,0.1838913211204439,43793,11853.188624620438,0.9889517426490784,0.0376424901187419,0.2403767031906625,0.9858801364898682,0.0479339472949504,0.1939360319991367,43793 -3779.948962688446,1.4443256855010986,8414.594813585281,26174,0,8414.594813585281,0.9846979379653932,0.0513856559991836,0.1769077695913019,43793,12196.693078517914,0.988925576210022,0.0379108116030693,0.2249631482400948,0.9856353402137756,0.0484851188957691,0.1846468363042744,43793 -3884.9562027454376,1.477400302886963,8654.654390335083,26923,0,8654.654390335083,0.9848226308822632,0.0519541725516319,0.1780696550770492,43793,12541.8131275177,0.9887903332710266,0.0383216552436351,0.2137093013014565,0.9858123064041138,0.0486218445003032,0.1889470301621161,43793 -3987.4045078754425,1.5093200206756592,8894.8265645504,27663,0,8894.8265645504,0.984944760799408,0.0511478707194328,0.1803156483087207,43793,12884.48639678955,0.9887416362762452,0.0384444631636142,0.2172426777091902,0.9857981204986572,0.0482029020786285,0.1893011427508064,43793 -4098.485007286072,1.5454697608947754,9135.05375790596,28398,0,9135.05375790596,0.9848782420158386,0.0512909218668937,0.1815005829169209,43793,13235.853038072586,0.9888597130775452,0.0383498929440975,0.2187355641821716,0.985792875289917,0.0483689941465854,0.1914001890482414,43793 -4201.458696365356,1.5778565406799316,9375.042990922928,29139,0,9375.042990922928,0.9849388599395752,0.0508907735347747,0.1844989454817836,43793,13578.868658781052,0.989043354988098,0.0376631803810596,0.2262856218682795,0.9858456254005432,0.0478578992187976,0.1901224715875739,43793 -4306.060763597488,1.6103568077087402,9615.254949569702,29886,0,9615.254949569702,0.984988570213318,0.0511430278420448,0.1817438590122285,43793,13923.735192537308,0.9888697862625122,0.0380624830722808,0.2286901512819851,0.9859182834625244,0.0480234026908874,0.1951326886481693,43793 -4406.300930023193,1.644965410232544,9855.509615659714,30632,0,9855.509615659714,0.9848508834838868,0.0519293956458568,0.177579091184478,43793,14264.285064697266,0.9889466166496276,0.0379223525524139,0.2214070215655898,0.985720992088318,0.0488744862377643,0.1847238076122485,43793 -4508.494157552719,1.681239128112793,10095.65578699112,31367,0,10095.65578699112,0.9847914576530457,0.0509674362838268,0.1805531748658989,43793,14606.683991193771,0.988867163658142,0.0382557809352874,0.2291374449002138,0.9856942296028136,0.0482274815440177,0.1854468354650805,43793 -4616.725874662399,1.714320421218872,10335.89645934105,32111,0,10335.89645934105,0.9849165678024292,0.051450528204441,0.1954261868553897,43793,14955.209387540815,0.988877534866333,0.0377889573574066,0.2309434431589295,0.9858277440071106,0.0486466772854328,0.1995270009315271,43793 -4722.938477993012,1.7510161399841309,10575.972088098526,32848,0,10575.972088098526,0.9847506284713744,0.0512115359306335,0.1892154607444269,43793,15301.55495762825,0.9889434576034546,0.0377963371574878,0.2432826322175141,0.9857007265090942,0.0483184233307838,0.1951626788260467,43793 -4828.792944908142,1.7866477966308594,10816.100275278091,33585,0,10816.100275278091,0.9850643873214722,0.0500930547714233,0.1914235613692715,43793,15647.596687793732,0.9891308546066284,0.0373149178922176,0.2260269949128286,0.9859727025032043,0.0474255979061126,0.1937569568102272,43793 -4935.460122585297,1.8201825618743896,11056.169059515,34329,0,11056.169059515,0.9850370287895204,0.0505309477448463,0.1907580082815553,43793,15994.386904001236,0.9889580011367798,0.0375989601016044,0.2289441943128746,0.9859438538551332,0.0477371960878372,0.2020584578773697,43793 -5045.811078548431,1.8531954288482664,11296.211079359056,35077,0,11296.211079359056,0.9851309657096864,0.0503195337951183,0.1868910440185294,43793,16344.833291053772,0.9889733791351318,0.0374145433306694,0.2383705253421114,0.985992968082428,0.0475658178329467,0.1923817748258664,43793 -5146.181943178177,1.8905532360076904,11536.243942975998,35816,0,11536.243942975998,0.9848508834838868,0.0513558462262153,0.1860502985099147,43793,16685.29561161995,0.9889889359474182,0.0376660488545894,0.2281931754273816,0.9858512878417968,0.0482183136045932,0.1911937336152979,43793 -5248.06134390831,1.9240601062774656,11776.2650411129,36561,0,11776.2650411129,0.984813392162323,0.0513754598796367,0.1791125450303736,43793,17027.249828100204,0.9889605641365052,0.0377225056290626,0.2295263943943149,0.9858009815216064,0.0483268313109874,0.1856618333966113,43793 -5352.1737768650055,1.9577631950378416,12016.520510673525,37307,0,12016.520510673525,0.9851098656654358,0.0500889718532562,0.1917312015242276,43793,17371.671184062958,0.9892958998680116,0.0363865718245506,0.2465555241199644,0.985998272895813,0.0471748933196067,0.2012492139526782,43793 -5458.929462432861,1.9929468631744385,12256.781912088394,38051,0,12256.781912088394,0.9851751923561096,0.050572469830513,0.1954931106804427,43793,17718.74373793602,0.9892456531524658,0.0365177951753139,0.2509095829787848,0.9860972762107848,0.0475809313356876,0.1989474597551918,43793 -5561.9447453022,2.0276620388031006,12496.968587875366,38792,0,12496.968587875366,0.9852147698402404,0.0499976128339767,0.2027752433634516,43793,18062.000724554066,0.9892786145210266,0.0362685807049274,0.2614758905918543,0.9860640168190002,0.0471036955714225,0.2022572474787074,43793 -5664.227452039719,2.0624632835388184,12736.987015485764,39536,0,12736.987015485764,0.98500794172287,0.0498680137097835,0.1890189421100775,43793,18404.35669374466,0.9891278743743896,0.0369930937886238,0.2449056318351088,0.985833466053009,0.0472330823540687,0.196620120614496,43793 -5770.586403608322,2.098353624343872,12977.07556772232,40280,0,12977.07556772232,0.9851473569869996,0.0499475635588169,0.1944470872124818,43793,18750.859877347943,0.9891670346260072,0.0370794646441936,0.2359508469546338,0.985985279083252,0.0471570491790771,0.2041718924831687,43793 -5869.447708368301,2.133057117462158,13217.259890556335,41030,0,13217.259890556335,0.9852400422096252,0.0495612360537052,0.1957468493772349,43793,19089.960252285004,0.9892788529396056,0.0365193933248519,0.2582791490632765,0.9861382842063904,0.0466996915638446,0.2035863850120045,43793 -5967.794405460358,2.167390823364258,13457.330917358398,41783,0,13457.330917358398,0.9851107597351074,0.0501090809702873,0.1927060702893194,43793,19428.43219280243,0.98917156457901,0.036688920110464,0.2375985526014747,0.9860579371452332,0.0470901429653167,0.20033803818553,43793 -6078.01136302948,2.20347547531128,13697.295249938965,42524,0,13697.295249938965,0.9851625561714172,0.0495956428349018,0.2029422827180296,43793,19778.66919779777,0.9892622828483582,0.0363938324153423,0.2609484153302455,0.9860400557518004,0.0468417443335056,0.2093666927716416,43793 -6179.531805515289,2.237947463989258,13937.44949221611,43273,0,13937.44949221611,0.985236644744873,0.0498459227383136,0.2022384611718691,43793,20120.39776873589,0.9893701076507568,0.0361695028841495,0.2488368199052049,0.9862012267112732,0.0469544194638729,0.2082395055408017,43793 -6283.494655847549,2.2736637592315674,14177.5388982296,44019,0,14177.5388982296,0.9852383732795716,0.0496873632073402,0.2024729546533087,43793,20464.505380392075,0.9893001317977904,0.0360747314989566,0.2637567741698877,0.9860668778419496,0.0468038506805896,0.2082817521928868,43793 -6384.748462438583,2.308631896972656,14417.7179479599,44780,0,14417.7179479599,0.9852139353752136,0.0498356223106384,0.1990046135939414,43793,20805.993231773376,0.9894915223121644,0.035425916314125,0.2697776492707348,0.9861618280410768,0.0468348525464534,0.2099969184826789,43793 -6490.651208877564,2.3464467525482178,14657.810967445374,45528,0,14657.810967445374,0.9852522611618042,0.049809843301773,0.1999492152297651,43793,21152.046332597733,0.9893769025802612,0.0358100272715091,0.2722272612931093,0.9861736297607422,0.0468456894159317,0.20630033965734,43793 -6591.992488861084,2.38614821434021,14897.910877466202,46269,0,14897.910877466202,0.9853171110153198,0.0493623949587345,0.2056291911976642,43793,21493.54910182953,0.9895589351654052,0.0355012081563472,0.2688084506959136,0.9862186908721924,0.0464554876089096,0.2145519109500384,43793 -6696.60099697113,2.4226717948913574,15138.077248096466,47009,0,15138.077248096466,0.9852370619773864,0.0499967895448207,0.1964800375281237,43793,21838.38315463066,0.9891928434371948,0.0365955233573913,0.2510488937477338,0.9861078858375548,0.0471156612038612,0.2064479490678904,43793 -6797.390934467316,2.4604079723358154,15378.110914945602,47751,0,15378.110914945602,0.9853697419166564,0.0497363582253456,0.2064511036544932,43793,22179.26458454132,0.9892840385437012,0.0362851954996585,0.2603987360262739,0.9862515330314636,0.046829804778099,0.2111941524972733,43793 -6898.0792760849,2.502638578414917,15618.3207821846,48497,0,15618.3207821846,0.9850791692733764,0.0494124293327331,0.2029720569531877,43793,22520.22537112236,0.9894214272499084,0.0358730033040046,0.2551420798235866,0.9860554933547974,0.0465924367308616,0.2120485444070477,43793 -7000.6140151023865,2.537923812866211,15858.295699596403,49250,0,15858.295699596403,0.9852433800697328,0.0496956780552864,0.2003439402274657,43793,22862.790357112885,0.9894852638244628,0.0357244722545146,0.2669809593413159,0.9862223267555236,0.0465139858424663,0.2131052753435941,43793 -7101.655977487564,2.57417106628418,16098.242151737211,50000,0,16098.242151737211,0.9852758646011353,0.0492729060351848,0.2034335283017547,43793,23203.83507847786,0.9894618391990662,0.0355752818286418,0.2711189138805558,0.9861240983009338,0.0466132201254367,0.2095939907786124,43793 -7206.845845937729,2.612422466278076,16338.338171482086,50747,0,16338.338171482086,0.9854114651679992,0.0491921380162239,0.2062226497609948,43793,23549.17933702469,0.989648938179016,0.0350227542221546,0.2890718598209157,0.9863201379776,0.0462420061230659,0.212382103524577,43793 -7307.999075889587,2.64997935295105,16578.522582292557,51497,0,16578.522582292557,0.9853802919387816,0.0490918084979057,0.2061251770082258,43793,23890.57517528534,0.989640176296234,0.0348977521061897,0.2752238390381377,0.9863294959068298,0.0461780801415443,0.2151606578956752,43793 -7407.201065540314,2.6873366832733154,16818.65753221512,52238,0,16818.65753221512,0.9854455590248108,0.049126137048006,0.2123704547099925,43793,24229.96969652176,0.9896869659423828,0.0347418673336505,0.29000266580794,0.9863424897193908,0.0462528392672538,0.2170042776596787,43793 -7509.441777706146,2.724439859390259,17058.785735607147,52977,0,17058.785735607147,0.985417366027832,0.049002967774868,0.211536483233568,43793,24572.39704155922,0.9897372126579284,0.0343862660229206,0.2859971482991346,0.986326277256012,0.0460726581513881,0.2227344429919973,43793 -7612.057886600494,2.76061749458313,17298.919110774994,53732,0,17298.919110774994,0.9854910969734192,0.048701986670494,0.2097780509319448,43793,24915.203050851826,0.989798903465271,0.0346322394907474,0.273643243785204,0.986401379108429,0.0458697639405727,0.2181591479311418,43793 -7713.9464428424835,2.797553777694702,17538.96286535263,54493,0,17538.96286535263,0.9853390455245972,0.0490941666066646,0.2090533931014886,43793,25257.191864728928,0.989654541015625,0.0349539890885353,0.2809974503744279,0.9862645268440248,0.0462120249867439,0.2155153642334833,43793 -7815.701034784317,2.834656476974488,17779.07443547249,55249,0,17779.07443547249,0.9854329228401184,0.0487629063427448,0.2074281320406993,43793,25599.11487293244,0.989753007888794,0.0345969647169113,0.2854877147970602,0.9863563179969788,0.0458742342889308,0.2150310554975685,43793 -7916.534770727158,2.8772406578063965,18019.22664308548,55990,0,18019.22664308548,0.9854813814163208,0.0485040582716465,0.2149564124202218,43793,25940.16689515114,0.9899150133132936,0.0341606698930263,0.2922762819836991,0.9864853620529176,0.0455940142273902,0.2239880195362973,43793 -8021.491682767868,2.9176416397094727,18259.172052145004,56739,0,18259.172052145004,0.9855051636695862,0.048889946192502975,0.21376523695951066,43793,26285.1296813488,0.9898921847343445,0.03401986509561539,0.29999873610666594,0.986504077911377,0.04590983688831329,0.2238110699225643,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/measurements.csv deleted file mode 100644 index b6954ad64..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/measurements.csv +++ /dev/null @@ -1,654 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.7368343,0.7145253,,,,,,,,,,,,,,,,, -1,,,0.5250684022903442,0.715176522731781,0.0237950027495085,0.5213832855224609,0.7166012525558472,0.0261413871170089,43793.0,0.5224818587303162,0.7161954641342163,0.0278457039326483,43793.0,12.220922231674194,120.41805958747864,12.220922231674194,108.1970920562744,0.0,0.0 -100,0.10305035,0.112475425,,,,,,,,,,,,,,,,, -200,0.006617447,0.05630973,,,,,,,,,,,,,,,,, -300,0.010944607,0.054760095,,,,,,,,,,,,,,,,, -400,0.014553424,0.05808673,,,,,,,,,,,,,,,,, -500,0.024274442,0.05203868,,,,,,,,,,,,,,,,, -600,0.022319114,0.05792223,,,,,,,,,,,,,,,,, -700,0.005230985,0.046618916,,,,,,,,,,,,,,,,, -742,,,0.9868351221084596,0.0535385869443416,0.0419205494794269,0.9841658473014832,0.0638294294476509,0.0402648250883774,43793.0,0.9831947088241576,0.0670140460133552,0.0409508164805676,43793.0,252.46294784545896,469.6144058704376,252.46294784545896,217.11013174057007,0.0211904048919677,0.0 -800,0.010451046,0.05699752,,,,,,,,,,,,,,,,, -900,0.00880578,0.056861665,,,,,,,,,,,,,,,,, -1000,0.012447842,0.055755187,,,,,,,,,,,,,,,,, -1100,0.009900678,0.05301253,,,,,,,,,,,,,,,,, -1200,0.0070153354,0.0500547,,,,,,,,,,,,,,,,, -1300,0.00608548,0.048727617,,,,,,,,,,,,,,,,, -1400,0.007847441,0.046485264,,,,,,,,,,,,,,,,, -1482,,,0.9870707988739014,0.0494044385850429,0.072728017218383,0.9842279553413392,0.0601286329329013,0.0695433154886204,43793.0,0.9832208156585692,0.0637183338403701,0.0687140178744884,43793.0,492.63631868362427,815.2754602432251,492.63631868362427,322.55107712745667,0.0476987361907959,0.0 -1500,0.009242539,0.044253215,,,,,,,,,,,,,,,,, -1600,0.0071481853,0.04825765,,,,,,,,,,,,,,,,, -1700,0.023729118,0.05428505,,,,,,,,,,,,,,,,, -1800,0.016824203,0.052102573,,,,,,,,,,,,,,,,, -1900,0.00919883,0.049722742,,,,,,,,,,,,,,,,, -2000,0.0071174866,0.044314392,,,,,,,,,,,,,,,,, -2100,0.032370735,0.04574344,,,,,,,,,,,,,,,,, -2200,0.010567335,0.046446618,,,,,,,,,,,,,,,,, -2230,,,0.9873882532119752,0.0467899739742279,0.1048028276153056,0.9846578240394592,0.0567194893956184,0.0997552259313046,43793.0,0.9836769700050354,0.0601913519203662,0.1044002999493918,43793.0,732.8385767936707,1164.4848325252533,732.8385767936707,431.5114159584045,0.0747332572937011,0.0 -2300,0.01423386,0.043021187,,,,,,,,,,,,,,,,, -2400,0.024158308,0.046337485,,,,,,,,,,,,,,,,, -2500,0.018543806,0.041773647,,,,,,,,,,,,,,,,, -2600,0.01631477,0.04799132,,,,,,,,,,,,,,,,, -2700,0.012092511,0.041016784,,,,,,,,,,,,,,,,, -2800,0.035713717,0.043012,,,,,,,,,,,,,,,,, -2900,0.02064748,0.03964,,,,,,,,,,,,,,,,, -2978,,,0.9875863790512084,0.0444664433598518,0.133050320459334,0.9849082827568054,0.0534515306353569,0.1206809639673781,43793.0,0.9839507341384888,0.0563553124666214,0.1253405511373911,43793.0,972.8108472824096,1509.5015892982483,972.8108472824096,536.5080606937408,0.1023402214050293,0.0 -3000,0.012839264,0.04646348,,,,,,,,,,,,,,,,, -3100,0.026605964,0.04670595,,,,,,,,,,,,,,,,, -3200,0.037094038,0.04139749,,,,,,,,,,,,,,,,, -3300,0.030378582,0.041849144,,,,,,,,,,,,,,,,, -3400,0.017037284,0.04334672,,,,,,,,,,,,,,,,, -3500,0.016877089,0.04351013,,,,,,,,,,,,,,,,, -3600,0.04171627,0.041228004,,,,,,,,,,,,,,,,, -3700,0.03908229,0.04477712,,,,,,,,,,,,,,,,, -3728,,,0.9877572655677797,0.0430839397013187,0.1463087512056293,0.9850422739982604,0.0526715070009231,0.1245638568835015,43793.0,0.9840299487113952,0.0558254644274711,0.1258086693069963,43793.0,1212.889556646347,1859.016436338425,1212.889556646347,645.8973081111908,0.1297438144683838,0.0 -3800,0.035344806,0.042486582,,,,,,,,,,,,,,,,, -3900,0.023014924,0.044879675,,,,,,,,,,,,,,,,, -4000,0.033737507,0.046485357,,,,,,,,,,,,,,,,, -4100,0.039595947,0.040264484,,,,,,,,,,,,,,,,, -4200,0.019132765,0.046602905,,,,,,,,,,,,,,,,, -4300,0.031147785,0.045581233,,,,,,,,,,,,,,,,, -4400,0.03244446,0.034473293,,,,,,,,,,,,,,,,, -4478,,,0.9880750179290771,0.0414841137826442,0.1616528893787325,0.985185980796814,0.0514313615858554,0.144405763800987,43793.0,0.984278440475464,0.0547199919819831,0.1437030584159531,43793.0,1452.92014336586,2199.8714084625244,1452.92014336586,746.6748743057251,0.156790018081665,0.0 -4500,0.014154301,0.042609498,,,,,,,,,,,,,,,,, -4600,0.029033428,0.049056564,,,,,,,,,,,,,,,,, -4700,0.056677777,0.045723032,,,,,,,,,,,,,,,,, -4800,0.04781185,0.043670718,,,,,,,,,,,,,,,,, -4900,0.047366627,0.040521488,,,,,,,,,,,,,,,,, -5000,0.02676586,0.039941683,,,,,,,,,,,,,,,,, -5100,0.05769919,0.042226955,,,,,,,,,,,,,,,,, -5200,0.03399419,0.042915497,,,,,,,,,,,,,,,,, -5222,,,0.9881662726402284,0.0413249917328357,0.1795573077062224,0.9852378964424132,0.0509257391095161,0.1538947014999502,43793.0,0.9843176007270812,0.0537596717476844,0.1524632543170833,43793.0,1692.9912858009338,2544.843962907791,1692.9912858009338,851.5282611846924,0.1850535869598388,0.0 -5300,0.038418334,0.041256547,,,,,,,,,,,,,,,,, -5400,0.03135825,0.0407461,,,,,,,,,,,,,,,,, -5500,0.024009451,0.04309972,,,,,,,,,,,,,,,,, -5600,0.03537132,0.036380395,,,,,,,,,,,,,,,,, -5700,0.026727822,0.03983527,,,,,,,,,,,,,,,,, -5800,0.022683151,0.037437543,,,,,,,,,,,,,,,,, -5900,0.034352534,0.040627263,,,,,,,,,,,,,,,,, -5973,,,0.9882667064666748,0.0411820746958255,0.164725269141511,0.9852712154388428,0.0508301928639411,0.1485142866215823,43793.0,0.9843466877937316,0.0536030866205692,0.1532219121304731,43793.0,1933.113857269287,2889.129508972168,1933.113857269287,955.644908428192,0.2118959426879882,0.0 -6000,0.024012605,0.039799962,,,,,,,,,,,,,,,,, -6100,0.034868017,0.041321743,,,,,,,,,,,,,,,,, -6200,0.061948013,0.038903814,,,,,,,,,,,,,,,,, -6300,0.045907587,0.041044533,,,,,,,,,,,,,,,,, -6400,0.03976306,0.03694747,,,,,,,,,,,,,,,,, -6500,0.029418437,0.03827048,,,,,,,,,,,,,,,,, -6600,0.028066458,0.040336404,,,,,,,,,,,,,,,,, -6700,0.039791238,0.042275526,,,,,,,,,,,,,,,,, -6726,,,0.9882901310920716,0.0403869524598121,0.189390218607616,0.9854116439819336,0.0502158813178539,0.1654686194838573,43793.0,0.98445326089859,0.0532456524670124,0.1665050708613972,43793.0,2173.093494415283,3229.888239145279,2173.093494415283,1056.376545906067,0.2399752140045166,0.0 -6800,0.05939569,0.041469786,,,,,,,,,,,,,,,,, -6900,0.06392614,0.043293573,,,,,,,,,,,,,,,,, -7000,0.048981898,0.03779842,,,,,,,,,,,,,,,,, -7100,0.038276765,0.04124221,,,,,,,,,,,,,,,,, -7200,0.04241562,0.0385511,,,,,,,,,,,,,,,,, -7300,0.05371325,0.04142221,,,,,,,,,,,,,,,,, -7400,0.027378079,0.03583935,,,,,,,,,,,,,,,,, -7478,,,0.9882825613021852,0.0407113768160343,0.1801941639881932,0.9854328036308287,0.0499150194227695,0.1604554826060077,43793.0,0.9845025539398192,0.0525631867349147,0.1614758018767295,43793.0,2413.061748027801,3574.727502822876,2413.061748027801,1161.199847459793,0.2680511474609375,0.0 -7500,0.04125508,0.041006655,,,,,,,,,,,,,,,,, -7600,0.03338652,0.038397923,,,,,,,,,,,,,,,,, -7700,0.05378157,0.045647208,,,,,,,,,,,,,,,,, -7800,0.051837537,0.041477527,,,,,,,,,,,,,,,,, -7900,0.05421332,0.038681496,,,,,,,,,,,,,,,,, -8000,0.046004623,0.042042077,,,,,,,,,,,,,,,,, -8100,0.032108918,0.045057096,,,,,,,,,,,,,,,,, -8200,0.035772085,0.03923527,,,,,,,,,,,,,,,,, -8227,,,0.9884076118469238,0.0400715358555316,0.1891099807972201,0.9855748414993286,0.0497897751629352,0.1705180084490086,43793.0,0.9846537113189696,0.0527768693864345,0.1617903733675443,43793.0,2653.143779039383,3916.760485887528,2653.143779039383,1263.1020185947418,0.2975277900695801,0.0 -8300,0.058830734,0.041503776,,,,,,,,,,,,,,,,, -8400,0.041798204,0.041223954,,,,,,,,,,,,,,,,, -8500,0.036167964,0.03725516,,,,,,,,,,,,,,,,, -8600,0.024888827,0.040679213,,,,,,,,,,,,,,,,, -8700,0.024759423,0.04060987,,,,,,,,,,,,,,,,, -8800,0.081916116,0.03954278,,,,,,,,,,,,,,,,, -8900,0.040631607,0.04757808,,,,,,,,,,,,,,,,, -8974,,,0.9886220097541808,0.0399004854261875,0.1923071782334979,0.9855508804321288,0.0497305281460285,0.1655544296563194,43793.0,0.9845378994941713,0.0523785389959812,0.166351327124705,43793.0,2893.2505328655243,4262.32323884964,2893.2505328655243,1368.509529352188,0.3260746002197265,0.0 -9000,0.038480196,0.03980502,,,,,,,,,,,,,,,,, -9100,0.060324304,0.042510726,,,,,,,,,,,,,,,,, -9200,0.07636171,0.03615807,,,,,,,,,,,,,,,,, -9300,0.04932866,0.039245192,,,,,,,,,,,,,,,,, -9400,0.04500456,0.03652978,,,,,,,,,,,,,,,,, -9500,0.04104056,0.033546034,,,,,,,,,,,,,,,,, -9600,0.06706842,0.039628573,,,,,,,,,,,,,,,,, -9700,0.05933567,0.039046656,,,,,,,,,,,,,,,,, -9727,,,0.9885628819465636,0.0397010818123817,0.1998062097585312,0.985582947731018,0.0500649698078632,0.1718920686631062,43793.0,0.984627604484558,0.0532267913222312,0.1667354992957371,43793.0,3133.2117640972137,4604.147631645203,3133.2117640972137,1470.3253271579742,0.3538753986358642,0.0 -9800,0.05856827,0.04067842,,,,,,,,,,,,,,,,, -9900,0.051763467,0.040671226,,,,,,,,,,,,,,,,, -10000,0.057560347,0.03776233,,,,,,,,,,,,,,,,, -10100,0.029370967,0.042501893,,,,,,,,,,,,,,,,, -10200,0.046185005,0.03750496,,,,,,,,,,,,,,,,, -10300,0.07545597,0.043615256,,,,,,,,,,,,,,,,, -10400,0.06428292,0.040323522,,,,,,,,,,,,,,,,, -10476,,,0.9885133504867554,0.0394293665885925,0.200452295893541,0.9854490160942078,0.0496672354638576,0.1666263098371183,43793.0,0.9844890236854552,0.0524704307317733,0.1669158520132354,43793.0,3373.1918189525604,4950.001894235611,3373.1918189525604,1576.149382352829,0.384164810180664,0.0 -10500,0.025047164,0.039023135,,,,,,,,,,,,,,,,, -10600,0.09792573,0.041609965,,,,,,,,,,,,,,,,, -10700,0.044829268,0.038886763,,,,,,,,,,,,,,,,, -10800,0.051381994,0.03778449,,,,,,,,,,,,,,,,, -10900,0.05753196,0.044342965,,,,,,,,,,,,,,,,, -11000,0.0300942,0.04003148,,,,,,,,,,,,,,,,, -11100,0.050354596,0.03839289,,,,,,,,,,,,,,,,, -11200,0.036299303,0.041431013,,,,,,,,,,,,,,,,, -11226,,,0.9886851906776428,0.0390754826366901,0.2089894518928057,0.9855192303657532,0.0493711568415164,0.1749158195471736,43793.0,0.9845783710479736,0.0523843765258789,0.1724580365021767,43793.0,3613.176426172257,5293.536472797394,3613.176426172257,1679.6482055187223,0.4154295921325683,0.0 -11300,0.051483985,0.042199533,,,,,,,,,,,,,,,,, -11400,0.031236658,0.042346217,,,,,,,,,,,,,,,,, -11500,0.03691016,0.038549606,,,,,,,,,,,,,,,,, -11600,0.040110625,0.03775836,,,,,,,,,,,,,,,,, -11700,0.031838074,0.036310658,,,,,,,,,,,,,,,,, -11800,0.034077495,0.03773262,,,,,,,,,,,,,,,,, -11900,0.064232536,0.043850094,,,,,,,,,,,,,,,,, -11977,,,0.9885552525520324,0.0391332358121871,0.2056036590679171,0.985720992088318,0.0487687140703201,0.1817230665469874,43793.0,0.9847741723060608,0.0515562482178211,0.1748346744354908,43793.0,3853.230010032654,5639.680479049683,3853.230010032654,1785.6903955936432,0.4441530704498291,0.0 -12000,0.07272796,0.039440755,,,,,,,,,,,,,,,,, -12100,0.027671166,0.039356258,,,,,,,,,,,,,,,,, -12200,0.05983907,0.037625983,,,,,,,,,,,,,,,,, -12300,0.03381056,0.034575164,,,,,,,,,,,,,,,,, -12400,0.04340146,0.03695317,,,,,,,,,,,,,,,,, -12500,0.028532153,0.036931984,,,,,,,,,,,,,,,,, -12600,0.045498885,0.039662797,,,,,,,,,,,,,,,,, -12700,0.031901356,0.037541665,,,,,,,,,,,,,,,,, -12734,,,0.9885872602462769,0.0398699790239334,0.2044354306931992,0.9855306148529052,0.0496048107743263,0.1690678095134202,43793.0,0.9846027493476868,0.0522534400224685,0.1657067585304924,43793.0,4093.308432340622,5986.535817146301,4093.308432340622,1892.4185304641724,0.4730062484741211,0.0 -12800,0.06893088,0.041548595,,,,,,,,,,,,,,,,, -12900,0.12120799,0.04110231,,,,,,,,,,,,,,,,, -13000,0.026178407,0.042760074,,,,,,,,,,,,,,,,, -13100,0.09015995,0.042645443,,,,,,,,,,,,,,,,, -13200,0.083888955,0.04676535,,,,,,,,,,,,,,,,, -13300,0.047950953,0.03621453,,,,,,,,,,,,,,,,, -13400,0.11487919,0.045439985,,,,,,,,,,,,,,,,, -13486,,,0.9886804819107056,0.0392939932644367,0.2035504562494897,0.9856122136116028,0.048954602330923,0.1798079085889805,43793.0,0.984676480293274,0.0518286414444446,0.1755089764659194,43793.0,4333.41264462471,6332.10172867775,4333.41264462471,1997.831268787384,0.5020365715026855,0.0 -13500,0.058788247,0.033921923,,,,,,,,,,,,,,,,, -13600,0.035537772,0.040614497,,,,,,,,,,,,,,,,, -13700,0.059296586,0.041113555,,,,,,,,,,,,,,,,, -13800,0.035596408,0.038502797,,,,,,,,,,,,,,,,, -13900,0.039518874,0.04182743,,,,,,,,,,,,,,,,, -14000,0.06053439,0.045370277,,,,,,,,,,,,,,,,, -14100,0.065602444,0.043827064,,,,,,,,,,,,,,,,, -14200,0.03045004,0.03840446,,,,,,,,,,,,,,,,, -14228,,,0.9885827898979188,0.0393496677279472,0.1989838706543882,0.9857429265975952,0.0490309186279773,0.1796446625647285,43793.0,0.9847055673599244,0.0521335490047931,0.1784581841483761,43793.0,4573.500229597092,6678.234007120132,4573.500229597092,2103.823152303696,0.5319912433624268,0.0 -14300,0.05616613,0.03814682,,,,,,,,,,,,,,,,, -14400,0.04210274,0.037574314,,,,,,,,,,,,,,,,, -14500,0.024883542,0.03974427,,,,,,,,,,,,,,,,, -14600,0.022950558,0.03772144,,,,,,,,,,,,,,,,, -14700,0.11028018,0.03807817,,,,,,,,,,,,,,,,, -14800,0.031664,0.042639133,,,,,,,,,,,,,,,,, -14900,0.017864257,0.033520475,,,,,,,,,,,,,,,,, -14978,,,0.9885103106498718,0.0394258648157119,0.2103199067351833,0.9857169389724731,0.0490715727210044,0.1809956758013063,43793.0,0.9847737550735474,0.0518770404160022,0.1807104800860751,43793.0,4813.4798884391785,7024.253228664398,4813.4798884391785,2209.81316947937,0.5619616508483887,0.0 -15000,0.05570267,0.040190812,,,,,,,,,,,,,,,,, -15100,0.06508002,0.03616492,,,,,,,,,,,,,,,,, -15200,0.04539135,0.03977441,,,,,,,,,,,,,,,,, -15300,0.035493832,0.04261783,,,,,,,,,,,,,,,,, -15400,0.03892091,0.042681392,,,,,,,,,,,,,,,,, -15500,0.10719732,0.03923257,,,,,,,,,,,,,,,,, -15600,0.05823851,0.03526276,,,,,,,,,,,,,,,,, -15700,0.072576374,0.03519691,,,,,,,,,,,,,,,,, -15730,,,0.9885491132736206,0.0389850884675979,0.2031699322901911,0.985655665397644,0.0487467497587204,0.1846246212844854,43793.0,0.9846811294555664,0.0517447367310524,0.1799350506932672,43793.0,5053.459381818771,7364.3765437603,5053.459381818771,2309.904142856598,0.5947067737579346,0.0 -15800,0.04569355,0.038450208,,,,,,,,,,,,,,,,, -15900,0.055495154,0.04210249,,,,,,,,,,,,,,,,, -16000,0.057426963,0.03825146,,,,,,,,,,,,,,,,, -16100,0.058459856,0.037773684,,,,,,,,,,,,,,,,, -16200,0.054313257,0.03944111,,,,,,,,,,,,,,,,, -16300,0.033372123,0.034463726,,,,,,,,,,,,,,,,, -16400,0.041566107,0.040300205,,,,,,,,,,,,,,,,, -16481,,,0.9886710047721864,0.039756067097187,0.2001787874556077,0.9856000542640686,0.0490147173404693,0.1758140324766489,43793.0,0.9846815466880798,0.0517053455114364,0.1696599787089885,43793.0,5293.689656734467,7708.322807788849,5293.689656734467,2413.569778442383,0.625511884689331,0.0 -16500,0.041238382,0.04189467,,,,,,,,,,,,,,,,, -16600,0.0394607,0.038895458,,,,,,,,,,,,,,,,, -16700,0.07447491,0.03831168,,,,,,,,,,,,,,,,, -16800,0.03534591,0.037696645,,,,,,,,,,,,,,,,, -16900,0.06032246,0.03880442,,,,,,,,,,,,,,,,, -17000,0.05092649,0.039158422,,,,,,,,,,,,,,,,, -17100,0.045324594,0.038741507,,,,,,,,,,,,,,,,, -17200,0.09818107,0.03698424,,,,,,,,,,,,,,,,, -17235,,,0.988852858543396,0.0381815806031227,0.2247231109183317,0.9858115315437316,0.0481926798820495,0.1914216357737809,43793.0,0.9848251342773438,0.0511091090738773,0.1786148488164832,43793.0,5533.651293992996,8051.120588302612,5533.651293992996,2516.3563482761383,0.6553435325622559,0.0 -17300,0.0641559,0.036817204,,,,,,,,,,,,,,,,, -17400,0.031999897,0.037816595,,,,,,,,,,,,,,,,, -17500,0.05988805,0.04235089,,,,,,,,,,,,,,,,, -17600,0.029984638,0.036229365,,,,,,,,,,,,,,,,, -17700,0.05639642,0.037896175,,,,,,,,,,,,,,,,, -17800,0.112954244,0.042688474,,,,,,,,,,,,,,,,, -17900,0.08373367,0.03628962,,,,,,,,,,,,,,,,, -17987,,,0.988699436187744,0.0385518334805965,0.2194349941750594,0.9855996370315552,0.0491237789392471,0.1775695241781132,43793.0,0.9846655130386353,0.0518954284489154,0.1731670891302522,43793.0,5773.6067237854,8394.634209156036,5773.6067237854,2619.8639080524445,0.6865098476409912,0.0 -18000,0.07967406,0.04426142,,,,,,,,,,,,,,,,, -18100,0.048362415,0.035854742,,,,,,,,,,,,,,,,, -18200,0.03616971,0.03757357,,,,,,,,,,,,,,,,, -18300,0.030728243,0.042169906,,,,,,,,,,,,,,,,, -18400,0.07760945,0.039317675,,,,,,,,,,,,,,,,, -18500,0.058550343,0.040212132,,,,,,,,,,,,,,,,, -18600,0.03635082,0.039266877,,,,,,,,,,,,,,,,, -18700,0.083164655,0.039528046,,,,,,,,,,,,,,,,, -18728,,,0.9887787699699402,0.0383401215076446,0.2242228187728377,0.9857478141784668,0.0490990057587623,0.1829504384835823,43793.0,0.98477041721344,0.0521637499332428,0.1759919712575796,43793.0,6013.775787115097,8743.120332956314,6013.775787115097,2728.1255388259888,0.7215981483459473,0.0 -18800,0.05645905,0.039435793,,,,,,,,,,,,,,,,, -18900,0.040560123,0.040703434,,,,,,,,,,,,,,,,, -19000,0.041500323,0.042786323,,,,,,,,,,,,,,,,, -19100,0.030191245,0.038863026,,,,,,,,,,,,,,,,, -19200,0.048404865,0.040867783,,,,,,,,,,,,,,,,, -19300,0.055067685,0.041848667,,,,,,,,,,,,,,,,, -19400,0.03643892,0.040704653,,,,,,,,,,,,,,,,, -19480,,,0.988776445388794,0.0384742207825183,0.2229867901020646,0.9857891798019408,0.0489357411861419,0.1849021107267856,43793.0,0.9848179817199708,0.0520950555801391,0.1765644162673381,43793.0,6253.961903810501,9087.911134004593,6253.961903810501,2832.679259777069,0.7523925304412842,0.0 -19500,0.048332762,0.038266767,,,,,,,,,,,,,,,,, -19600,0.031545337,0.037920862,,,,,,,,,,,,,,,,, -19700,0.041585695,0.03818392,,,,,,,,,,,,,,,,, -19800,0.08938063,0.03968844,,,,,,,,,,,,,,,,, -19900,0.033206373,0.035223335,,,,,,,,,,,,,,,,, -20000,0.035724442,0.04056107,,,,,,,,,,,,,,,,, -20100,0.04819733,0.039753255,,,,,,,,,,,,,,,,, -20200,0.06569826,0.037983257,,,,,,,,,,,,,,,,, -20226,,,0.9885373115539552,0.0390364378690719,0.2137670788115012,0.9857355952262878,0.0490761920809745,0.1825315784667147,43793.0,0.9848251342773438,0.0520286560058593,0.174959123510546,43793.0,6494.021353006363,9430.473743200302,6494.021353006363,2935.130667924881,0.7841694355010986,0.0 -20300,0.08346763,0.04136601,,,,,,,,,,,,,,,,, -20400,0.029535558,0.035546385,,,,,,,,,,,,,,,,, -20500,0.05236629,0.04069077,,,,,,,,,,,,,,,,, -20600,0.041929934,0.04260756,,,,,,,,,,,,,,,,, -20700,0.0545322,0.033757355,,,,,,,,,,,,,,,,, -20800,0.022361673,0.034717835,,,,,,,,,,,,,,,,, -20900,0.03593679,0.04121339,,,,,,,,,,,,,,,,, -20980,,,0.988884449005127,0.0387068316340446,0.2198519838767519,0.9858046174049376,0.0497435145080089,0.1831063059045827,43793.0,0.984847903251648,0.0528499558568,0.1779089723611548,43793.0,6734.287105321884,9776.793322563171,6734.287105321884,3041.1337745189667,0.8143470287322998,0.0 -21000,0.047317203,0.036166802,,,,,,,,,,,,,,,,, -21100,0.043062426,0.03684865,,,,,,,,,,,,,,,,, -21200,0.056007884,0.036578424,,,,,,,,,,,,,,,,, -21300,0.068218194,0.038905554,,,,,,,,,,,,,,,,, -21400,0.04045159,0.037444565,,,,,,,,,,,,,,,,, -21500,0.041514717,0.034753647,,,,,,,,,,,,,,,,, -21600,0.042948104,0.038816668,,,,,,,,,,,,,,,,, -21700,0.058646176,0.037503697,,,,,,,,,,,,,,,,, -21726,,,0.9888972640037536,0.0385029278695583,0.2088390445070321,0.9857494235038756,0.0482809171080589,0.1844016397267828,43793.0,0.9848167300224304,0.0513973832130432,0.1776546379820925,43793.0,6974.567994594574,10126.000288248062,6974.567994594574,3150.0082376003265,0.8462753295898438,0.0 -21800,0.025442464,0.03403818,,,,,,,,,,,,,,,,, -21900,0.026397012,0.043106806,,,,,,,,,,,,,,,,, -22000,0.030370764,0.037935324,,,,,,,,,,,,,,,,, -22100,0.052133344,0.039606296,,,,,,,,,,,,,,,,, -22200,0.048772756,0.035770964,,,,,,,,,,,,,,,,, -22300,0.06783106,0.040805917,,,,,,,,,,,,,,,,, -22400,0.04070612,0.033718284,,,,,,,,,,,,,,,,, -22477,,,0.9887509942054749,0.0384994745254516,0.2163177758440355,0.9858188033103944,0.0485533736646175,0.1951387458625174,43793.0,0.984839916229248,0.0518183521926403,0.1856877932834978,43793.0,7214.62762594223,10469.517782449722,7214.62762594223,3253.414433002472,0.877194881439209,0.0 -22500,0.027977388,0.0335413,,,,,,,,,,,,,,,,, -22600,0.034783967,0.039947305,,,,,,,,,,,,,,,,, -22700,0.03450495,0.033720925,,,,,,,,,,,,,,,,, -22800,0.048547044,0.035490654,,,,,,,,,,,,,,,,, -22900,0.04928759,0.038258437,,,,,,,,,,,,,,,,, -23000,0.04534713,0.040021647,,,,,,,,,,,,,,,,, -23100,0.03944411,0.033072997,,,,,,,,,,,,,,,,, -23200,0.030016506,0.03626245,,,,,,,,,,,,,,,,, -23223,,,0.988722324371338,0.0386566929519176,0.2140264105266031,0.9857449531555176,0.0490856431424617,0.1800239360133262,43793.0,0.9848597049713136,0.0520401038229465,0.173719214320403,43793.0,7454.670845508575,10818.201761245728,7454.670845508575,3362.0041534900665,0.9081747531890868,0.0 -23300,0.04108751,0.04456079,,,,,,,,,,,,,,,,, -23400,0.07010899,0.040693827,,,,,,,,,,,,,,,,, -23500,0.05342665,0.037468355,,,,,,,,,,,,,,,,, -23600,0.040098127,0.036609847,,,,,,,,,,,,,,,,, -23700,0.061712664,0.03493397,,,,,,,,,,,,,,,,, -23800,0.032357484,0.04397474,,,,,,,,,,,,,,,,, -23900,0.028340328,0.038816173,,,,,,,,,,,,,,,,, -23962,,,0.988548755645752,0.0393689386546611,0.2090295690843694,0.9855188131332396,0.0490764528512954,0.1787901620259714,43793.0,0.9845973253250122,0.052184447646141,0.1705859883723574,43793.0,7694.770886421204,11162.487454891205,7694.770886421204,3466.1369805336,0.9395201206207277,0.0 -24000,0.0416148,0.039972518,,,,,,,,,,,,,,,,, -24100,0.034701873,0.034237985,,,,,,,,,,,,,,,,, -24200,0.06448649,0.035958327,,,,,,,,,,,,,,,,, -24300,0.034610406,0.03792924,,,,,,,,,,,,,,,,, -24400,0.036781177,0.041124348,,,,,,,,,,,,,,,,, -24500,0.06339632,0.03974727,,,,,,,,,,,,,,,,, -24600,0.035260007,0.040471066,,,,,,,,,,,,,,,,, -24700,0.030491881,0.036474537,,,,,,,,,,,,,,,,, -24708,,,0.9888287782669068,0.0382855348289012,0.2273005416580073,0.9858277440071106,0.0481491759419441,0.1914915996375246,43793.0,0.9849161505699158,0.0509380102157592,0.1800086315183968,43793.0,7934.792539834976,11507.038450717926,7934.792539834976,3570.6146986484528,0.9710242748260498,0.0 -24800,0.049334437,0.040549282,,,,,,,,,,,,,,,,, -24900,0.032226678,0.035212792,,,,,,,,,,,,,,,,, -25000,0.024223184,0.038755573,,,,,,,,,,,,,,,,, -25100,0.046126053,0.036753695,,,,,,,,,,,,,,,,, -25200,0.08064174,0.038552377,,,,,,,,,,,,,,,,, -25300,0.07618261,0.04440863,,,,,,,,,,,,,,,,, -25400,0.06698976,0.038256224,,,,,,,,,,,,,,,,, -25447,,,0.9889517426490784,0.0376424901187419,0.2403767031906625,0.9858801364898682,0.0479339472949504,0.1939360319991367,43793.0,0.9848967790603638,0.0509436763823032,0.1838913211204439,43793.0,8174.904351234436,11853.188624620438,8174.904351234436,3676.598792552948,1.0045819282531738,0.0 -25500,0.040349454,0.039531156,,,,,,,,,,,,,,,,, -25600,0.060937785,0.03876306,,,,,,,,,,,,,,,,, -25700,0.08641693,0.041435987,,,,,,,,,,,,,,,,, -25800,0.072181046,0.03404883,,,,,,,,,,,,,,,,, -25900,0.043462247,0.035545666,,,,,,,,,,,,,,,,, -26000,0.036425963,0.0374025,,,,,,,,,,,,,,,,, -26100,0.05232204,0.046086755,,,,,,,,,,,,,,,,, -26174,,,0.988925576210022,0.0379108116030693,0.2249631482400948,0.9856353402137756,0.0484851188957691,0.1846468363042744,43793.0,0.9846979379653932,0.0513856559991836,0.1769077695913019,43793.0,8414.594813585281,12196.693078517914,8414.594813585281,3779.948962688446,1.4443256855010986,0.0 -26200,0.049142353,0.036506105,,,,,,,,,,,,,,,,, -26300,0.056608263,0.038479816,,,,,,,,,,,,,,,,, -26400,0.037344206,0.0414892,,,,,,,,,,,,,,,,, -26500,0.034351718,0.040141456,,,,,,,,,,,,,,,,, -26600,0.031788237,0.039570313,,,,,,,,,,,,,,,,, -26700,0.06196582,0.03675768,,,,,,,,,,,,,,,,, -26800,0.038354103,0.042600125,,,,,,,,,,,,,,,,, -26900,0.042014748,0.036206435,,,,,,,,,,,,,,,,, -26923,,,0.9887903332710266,0.0383216552436351,0.2137093013014565,0.9858123064041138,0.0486218445003032,0.1889470301621161,43793.0,0.9848226308822632,0.0519541725516319,0.1780696550770492,43793.0,8654.654390335083,12541.8131275177,8654.654390335083,3884.9562027454376,1.477400302886963,0.0 -27000,0.039158117,0.036413345,,,,,,,,,,,,,,,,, -27100,0.060680937,0.039709516,,,,,,,,,,,,,,,,, -27200,0.042248655,0.037057146,,,,,,,,,,,,,,,,, -27300,0.053581502,0.035720155,,,,,,,,,,,,,,,,, -27400,0.08018854,0.043887254,,,,,,,,,,,,,,,,, -27500,0.034863226,0.037127182,,,,,,,,,,,,,,,,, -27600,0.030597102,0.03475186,,,,,,,,,,,,,,,,, -27663,,,0.9887416362762452,0.0384444631636142,0.2172426777091902,0.9857981204986572,0.0482029020786285,0.1893011427508064,43793.0,0.984944760799408,0.0511478707194328,0.1803156483087207,43793.0,8894.8265645504,12884.48639678955,8894.8265645504,3987.4045078754425,1.5093200206756592,0.0 -27700,0.02848403,0.03331613,,,,,,,,,,,,,,,,, -27800,0.06413183,0.04090856,,,,,,,,,,,,,,,,, -27900,0.047642134,0.041624554,,,,,,,,,,,,,,,,, -28000,0.032830887,0.03994672,,,,,,,,,,,,,,,,, -28100,0.046791688,0.0358101,,,,,,,,,,,,,,,,, -28200,0.042062495,0.034804087,,,,,,,,,,,,,,,,, -28300,0.07162879,0.037488356,,,,,,,,,,,,,,,,, -28398,,,0.9888597130775452,0.0383498929440975,0.2187355641821716,0.985792875289917,0.0483689941465854,0.1914001890482414,43793.0,0.9848782420158386,0.0512909218668937,0.1815005829169209,43793.0,9135.05375790596,13235.853038072586,9135.05375790596,4098.485007286072,1.5454697608947754,0.0 -28400,0.03243045,0.039842855,,,,,,,,,,,,,,,,, -28500,0.047135964,0.0392918,,,,,,,,,,,,,,,,, -28600,0.04129879,0.040412307,,,,,,,,,,,,,,,,, -28700,0.050313447,0.043207455,,,,,,,,,,,,,,,,, -28800,0.09129127,0.03958869,,,,,,,,,,,,,,,,, -28900,0.032129597,0.036922544,,,,,,,,,,,,,,,,, -29000,0.051060487,0.036942404,,,,,,,,,,,,,,,,, -29100,0.037469324,0.03411475,,,,,,,,,,,,,,,,, -29139,,,0.989043354988098,0.0376631803810596,0.2262856218682795,0.9858456254005432,0.0478578992187976,0.1901224715875739,43793.0,0.9849388599395752,0.0508907735347747,0.1844989454817836,43793.0,9375.042990922928,13578.868658781052,9375.042990922928,4201.458696365356,1.5778565406799316,0.0 -29200,0.056540027,0.039197095,,,,,,,,,,,,,,,,, -29300,0.043879133,0.038040053,,,,,,,,,,,,,,,,, -29400,0.057790738,0.034978066,,,,,,,,,,,,,,,,, -29500,0.075416245,0.039620653,,,,,,,,,,,,,,,,, -29600,0.030134406,0.040365264,,,,,,,,,,,,,,,,, -29700,0.0601657,0.039484046,,,,,,,,,,,,,,,,, -29800,0.0916145,0.03742821,,,,,,,,,,,,,,,,, -29886,,,0.9888697862625122,0.0380624830722808,0.2286901512819851,0.9859182834625244,0.0480234026908874,0.1951326886481693,43793.0,0.984988570213318,0.0511430278420448,0.1817438590122285,43793.0,9615.254949569702,13923.735192537308,9615.254949569702,4306.060763597488,1.6103568077087402,0.0 -29900,0.049680423,0.038169555,,,,,,,,,,,,,,,,, -30000,0.03744842,0.038357913,,,,,,,,,,,,,,,,, -30100,0.07425707,0.0379648,,,,,,,,,,,,,,,,, -30200,0.09569824,0.039608955,,,,,,,,,,,,,,,,, -30300,0.0497061,0.040819112,,,,,,,,,,,,,,,,, -30400,0.072322145,0.041192994,,,,,,,,,,,,,,,,, -30500,0.0735613,0.04209795,,,,,,,,,,,,,,,,, -30600,0.071688116,0.037998714,,,,,,,,,,,,,,,,, -30632,,,0.9889466166496276,0.0379223525524139,0.2214070215655898,0.985720992088318,0.0488744862377643,0.1847238076122485,43793.0,0.9848508834838868,0.0519293956458568,0.177579091184478,43793.0,9855.509615659714,14264.285064697266,9855.509615659714,4406.300930023193,1.644965410232544,0.0 -30700,0.042822506,0.036952943,,,,,,,,,,,,,,,,, -30800,0.059785973,0.039580192,,,,,,,,,,,,,,,,, -30900,0.03391597,0.04025945,,,,,,,,,,,,,,,,, -31000,0.07245156,0.034125194,,,,,,,,,,,,,,,,, -31100,0.067789845,0.04102773,,,,,,,,,,,,,,,,, -31200,0.08850959,0.035264574,,,,,,,,,,,,,,,,, -31300,0.061925862,0.037625436,,,,,,,,,,,,,,,,, -31367,,,0.988867163658142,0.0382557809352874,0.2291374449002138,0.9856942296028136,0.0482274815440177,0.1854468354650805,43793.0,0.9847914576530457,0.0509674362838268,0.1805531748658989,43793.0,10095.65578699112,14606.683991193771,10095.65578699112,4508.494157552719,1.681239128112793,0.0 -31400,0.07819885,0.03628132,,,,,,,,,,,,,,,,, -31500,0.06555846,0.036601745,,,,,,,,,,,,,,,,, -31600,0.07124342,0.036226965,,,,,,,,,,,,,,,,, -31700,0.046703715,0.039502338,,,,,,,,,,,,,,,,, -31800,0.035893556,0.033858,,,,,,,,,,,,,,,,, -31900,0.064699955,0.03543138,,,,,,,,,,,,,,,,, -32000,0.04065816,0.037266683,,,,,,,,,,,,,,,,, -32100,0.06516264,0.035874404,,,,,,,,,,,,,,,,, -32111,,,0.988877534866333,0.0377889573574066,0.2309434431589295,0.9858277440071106,0.0486466772854328,0.1995270009315271,43793.0,0.9849165678024292,0.051450528204441,0.1954261868553897,43793.0,10335.89645934105,14955.209387540815,10335.89645934105,4616.725874662399,1.714320421218872,0.0 -32200,0.04776319,0.037010815,,,,,,,,,,,,,,,,, -32300,0.07916556,0.037194453,,,,,,,,,,,,,,,,, -32400,0.044429597,0.03683188,,,,,,,,,,,,,,,,, -32500,0.032262534,0.035210695,,,,,,,,,,,,,,,,, -32600,0.038762283,0.038964752,,,,,,,,,,,,,,,,, -32700,0.08339995,0.044255782,,,,,,,,,,,,,,,,, -32800,0.048045345,0.036628425,,,,,,,,,,,,,,,,, -32848,,,0.9889434576034546,0.0377963371574878,0.2432826322175141,0.9857007265090942,0.0483184233307838,0.1951626788260467,43793.0,0.9847506284713744,0.0512115359306335,0.1892154607444269,43793.0,10575.972088098526,15301.55495762825,10575.972088098526,4722.938477993012,1.7510161399841309,0.0 -32900,0.044538595,0.033481833,,,,,,,,,,,,,,,,, -33000,0.040296584,0.037578106,,,,,,,,,,,,,,,,, -33100,0.04868668,0.039281517,,,,,,,,,,,,,,,,, -33200,0.07610613,0.034951515,,,,,,,,,,,,,,,,, -33300,0.061794102,0.039353613,,,,,,,,,,,,,,,,, -33400,0.043613236,0.038620066,,,,,,,,,,,,,,,,, -33500,0.11195796,0.038780376,,,,,,,,,,,,,,,,, -33585,,,0.9891308546066284,0.0373149178922176,0.2260269949128286,0.9859727025032043,0.0474255979061126,0.1937569568102272,43793.0,0.9850643873214722,0.0500930547714233,0.1914235613692715,43793.0,10816.100275278091,15647.596687793732,10816.100275278091,4828.792944908142,1.7866477966308594,0.0 -33600,0.08398005,0.04067691,,,,,,,,,,,,,,,,, -33700,0.045413494,0.037362076,,,,,,,,,,,,,,,,, -33800,0.064117365,0.033156294,,,,,,,,,,,,,,,,, -33900,0.04425611,0.040231615,,,,,,,,,,,,,,,,, -34000,0.040160097,0.039214592,,,,,,,,,,,,,,,,, -34100,0.064161584,0.042520866,,,,,,,,,,,,,,,,, -34200,0.041448258,0.037375327,,,,,,,,,,,,,,,,, -34300,0.07668567,0.03987419,,,,,,,,,,,,,,,,, -34329,,,0.9889580011367798,0.0375989601016044,0.2289441943128746,0.9859438538551332,0.0477371960878372,0.2020584578773697,43793.0,0.9850370287895204,0.0505309477448463,0.1907580082815553,43793.0,11056.169059515,15994.386904001236,11056.169059515,4935.460122585297,1.8201825618743896,0.0 -34400,0.042403303,0.0430191,,,,,,,,,,,,,,,,, -34500,0.06681446,0.03717444,,,,,,,,,,,,,,,,, -34600,0.041735075,0.033603474,,,,,,,,,,,,,,,,, -34700,0.039131895,0.040907025,,,,,,,,,,,,,,,,, -34800,0.070598856,0.03882684,,,,,,,,,,,,,,,,, -34900,0.04730255,0.039842017,,,,,,,,,,,,,,,,, -35000,0.034231324,0.038932763,,,,,,,,,,,,,,,,, -35077,,,0.9889733791351318,0.0374145433306694,0.2383705253421114,0.985992968082428,0.0475658178329467,0.1923817748258664,43793.0,0.9851309657096864,0.0503195337951183,0.1868910440185294,43793.0,11296.211079359056,16344.833291053772,11296.211079359056,5045.811078548431,1.8531954288482664,0.0 -35100,0.06437449,0.043489262,,,,,,,,,,,,,,,,, -35200,0.079943724,0.037038073,,,,,,,,,,,,,,,,, -35300,0.034531146,0.036116894,,,,,,,,,,,,,,,,, -35400,0.04870297,0.040859073,,,,,,,,,,,,,,,,, -35500,0.052753475,0.036901653,,,,,,,,,,,,,,,,, -35600,0.053608216,0.037938878,,,,,,,,,,,,,,,,, -35700,0.043830216,0.035692893,,,,,,,,,,,,,,,,, -35800,0.065785274,0.038564436,,,,,,,,,,,,,,,,, -35816,,,0.9889889359474182,0.0376660488545894,0.2281931754273816,0.9858512878417968,0.0482183136045932,0.1911937336152979,43793.0,0.9848508834838868,0.0513558462262153,0.1860502985099147,43793.0,11536.243942975998,16685.29561161995,11536.243942975998,5146.181943178177,1.8905532360076904,0.0 -35900,0.05422063,0.039127946,,,,,,,,,,,,,,,,, -36000,0.055863965,0.041542325,,,,,,,,,,,,,,,,, -36100,0.043039694,0.03680348,,,,,,,,,,,,,,,,, -36200,0.08296999,0.03568293,,,,,,,,,,,,,,,,, -36300,0.044476073,0.038350996,,,,,,,,,,,,,,,,, -36400,0.0573417,0.036974747,,,,,,,,,,,,,,,,, -36500,0.082189195,0.04161191,,,,,,,,,,,,,,,,, -36561,,,0.9889605641365052,0.0377225056290626,0.2295263943943149,0.9858009815216064,0.0483268313109874,0.1856618333966113,43793.0,0.984813392162323,0.0513754598796367,0.1791125450303736,43793.0,11776.2650411129,17027.249828100204,11776.2650411129,5248.06134390831,1.9240601062774656,0.0 -36600,0.055650797,0.037079997,,,,,,,,,,,,,,,,, -36700,0.041650094,0.035432044,,,,,,,,,,,,,,,,, -36800,0.07167458,0.03831769,,,,,,,,,,,,,,,,, -36900,0.05551295,0.03810415,,,,,,,,,,,,,,,,, -37000,0.041680764,0.034464944,,,,,,,,,,,,,,,,, -37100,0.09438165,0.03603118,,,,,,,,,,,,,,,,, -37200,0.081347644,0.037162676,,,,,,,,,,,,,,,,, -37300,0.04496459,0.037204076,,,,,,,,,,,,,,,,, -37307,,,0.9892958998680116,0.0363865718245506,0.2465555241199644,0.985998272895813,0.0471748933196067,0.2012492139526782,43793.0,0.9851098656654358,0.0500889718532562,0.1917312015242276,43793.0,12016.520510673525,17371.671184062958,12016.520510673525,5352.1737768650055,1.9577631950378416,0.0 -37400,0.042874616,0.034105144,,,,,,,,,,,,,,,,, -37500,0.063488305,0.036377195,,,,,,,,,,,,,,,,, -37600,0.05983851,0.03605571,,,,,,,,,,,,,,,,, -37700,0.07851481,0.03348157,,,,,,,,,,,,,,,,, -37800,0.09047632,0.04137988,,,,,,,,,,,,,,,,, -37900,0.06038528,0.03458473,,,,,,,,,,,,,,,,, -38000,0.08379629,0.03463358,,,,,,,,,,,,,,,,, -38051,,,0.9892456531524658,0.0365177951753139,0.2509095829787848,0.9860972762107848,0.0475809313356876,0.1989474597551918,43793.0,0.9851751923561096,0.050572469830513,0.1954931106804427,43793.0,12256.781912088394,17718.74373793602,12256.781912088394,5458.929462432861,1.9929468631744385,0.0 -38100,0.048108064,0.03268896,,,,,,,,,,,,,,,,, -38200,0.06568773,0.033856794,,,,,,,,,,,,,,,,, -38300,0.09280777,0.041864935,,,,,,,,,,,,,,,,, -38400,0.05557208,0.031360578,,,,,,,,,,,,,,,,, -38500,0.13004738,0.037107885,,,,,,,,,,,,,,,,, -38600,0.06608143,0.035315108,,,,,,,,,,,,,,,,, -38700,0.052810807,0.04122,,,,,,,,,,,,,,,,, -38792,,,0.9892786145210266,0.0362685807049274,0.2614758905918543,0.9860640168190002,0.0471036955714225,0.2022572474787074,43793.0,0.9852147698402404,0.0499976128339767,0.2027752433634516,43793.0,12496.968587875366,18062.000724554066,12496.968587875366,5561.9447453022,2.0276620388031006,0.0 -38800,0.044898286,0.0359516,,,,,,,,,,,,,,,,, -38900,0.052256078,0.036525805,,,,,,,,,,,,,,,,, -39000,0.044357672,0.034002163,,,,,,,,,,,,,,,,, -39100,0.03896512,0.033831157,,,,,,,,,,,,,,,,, -39200,0.08771018,0.038808145,,,,,,,,,,,,,,,,, -39300,0.043298297,0.036992405,,,,,,,,,,,,,,,,, -39400,0.064867966,0.037643228,,,,,,,,,,,,,,,,, -39500,0.03906952,0.039996833,,,,,,,,,,,,,,,,, -39536,,,0.9891278743743896,0.0369930937886238,0.2449056318351088,0.985833466053009,0.0472330823540687,0.196620120614496,43793.0,0.98500794172287,0.0498680137097835,0.1890189421100775,43793.0,12736.987015485764,18404.35669374466,12736.987015485764,5664.227452039719,2.0624632835388184,0.0 -39600,0.054595996,0.03983946,,,,,,,,,,,,,,,,, -39700,0.057428963,0.03919424,,,,,,,,,,,,,,,,, -39800,0.042272426,0.039966777,,,,,,,,,,,,,,,,, -39900,0.10495113,0.039301034,,,,,,,,,,,,,,,,, -40000,0.045012876,0.0362658,,,,,,,,,,,,,,,,, -40100,0.07858942,0.033470284,,,,,,,,,,,,,,,,, -40200,0.10194181,0.03360751,,,,,,,,,,,,,,,,, -40280,,,0.9891670346260072,0.0370794646441936,0.2359508469546338,0.985985279083252,0.0471570491790771,0.2041718924831687,43793.0,0.9851473569869996,0.0499475635588169,0.1944470872124818,43793.0,12977.07556772232,18750.859877347943,12977.07556772232,5770.586403608322,2.098353624343872,0.0 -40300,0.06944603,0.039726246,,,,,,,,,,,,,,,,, -40400,0.046469014,0.040053707,,,,,,,,,,,,,,,,, -40500,0.061690174,0.039940108,,,,,,,,,,,,,,,,, -40600,0.05166477,0.036312055,,,,,,,,,,,,,,,,, -40700,0.07204496,0.040347043,,,,,,,,,,,,,,,,, -40800,0.05004933,0.035923507,,,,,,,,,,,,,,,,, -40900,0.07333659,0.03893823,,,,,,,,,,,,,,,,, -41000,0.062377058,0.042165313,,,,,,,,,,,,,,,,, -41030,,,0.9892788529396056,0.0365193933248519,0.2582791490632765,0.9861382842063904,0.0466996915638446,0.2035863850120045,43793.0,0.9852400422096252,0.0495612360537052,0.1957468493772349,43793.0,13217.259890556335,19089.960252285004,13217.259890556335,5869.447708368301,2.133057117462158,0.0 -41100,0.058744617,0.03616053,,,,,,,,,,,,,,,,, -41200,0.052701864,0.03645578,,,,,,,,,,,,,,,,, -41300,0.053059526,0.038983453,,,,,,,,,,,,,,,,, -41400,0.047952358,0.034800593,,,,,,,,,,,,,,,,, -41500,0.10051014,0.039027166,,,,,,,,,,,,,,,,, -41600,0.11797534,0.035697155,,,,,,,,,,,,,,,,, -41700,0.094289705,0.036779236,,,,,,,,,,,,,,,,, -41783,,,0.98917156457901,0.036688920110464,0.2375985526014747,0.9860579371452332,0.0470901429653167,0.20033803818553,43793.0,0.9851107597351074,0.0501090809702873,0.1927060702893194,43793.0,13457.330917358398,19428.43219280243,13457.330917358398,5967.794405460358,2.167390823364258,0.0 -41800,0.060284372,0.037063237,,,,,,,,,,,,,,,,, -41900,0.052909788,0.03905625,,,,,,,,,,,,,,,,, -42000,0.062805355,0.034406442,,,,,,,,,,,,,,,,, -42100,0.051183455,0.03571199,,,,,,,,,,,,,,,,, -42200,0.045601513,0.04165509,,,,,,,,,,,,,,,,, -42300,0.053038925,0.038227025,,,,,,,,,,,,,,,,, -42400,0.07133382,0.042040262,,,,,,,,,,,,,,,,, -42500,0.050429225,0.03470744,,,,,,,,,,,,,,,,, -42524,,,0.9892622828483582,0.0363938324153423,0.2609484153302455,0.9860400557518004,0.0468417443335056,0.2093666927716416,43793.0,0.9851625561714172,0.0495956428349018,0.2029422827180296,43793.0,13697.295249938965,19778.66919779777,13697.295249938965,6078.01136302948,2.20347547531128,0.0 -42600,0.07471785,0.033825487,,,,,,,,,,,,,,,,, -42700,0.06772207,0.032503005,,,,,,,,,,,,,,,,, -42800,0.070259236,0.03552197,,,,,,,,,,,,,,,,, -42900,0.04925412,0.03328324,,,,,,,,,,,,,,,,, -43000,0.053795405,0.04035248,,,,,,,,,,,,,,,,, -43100,0.047500744,0.03795922,,,,,,,,,,,,,,,,, -43200,0.06061966,0.03673315,,,,,,,,,,,,,,,,, -43273,,,0.9893701076507568,0.0361695028841495,0.2488368199052049,0.9862012267112732,0.0469544194638729,0.2082395055408017,43793.0,0.985236644744873,0.0498459227383136,0.2022384611718691,43793.0,13937.44949221611,20120.39776873589,13937.44949221611,6179.531805515289,2.237947463989258,0.0 -43300,0.06810284,0.033356655,,,,,,,,,,,,,,,,, -43400,0.09084433,0.035424475,,,,,,,,,,,,,,,,, -43500,0.0451164,0.03550024,,,,,,,,,,,,,,,,, -43600,0.08202211,0.03898031,,,,,,,,,,,,,,,,, -43700,0.038838726,0.037247267,,,,,,,,,,,,,,,,, -43800,0.07738127,0.033899,,,,,,,,,,,,,,,,, -43900,0.041797962,0.03736994,,,,,,,,,,,,,,,,, -44000,0.07354111,0.035396583,,,,,,,,,,,,,,,,, -44019,,,0.9893001317977904,0.0360747314989566,0.2637567741698877,0.9860668778419496,0.0468038506805896,0.2082817521928868,43793.0,0.9852383732795716,0.0496873632073402,0.2024729546533087,43793.0,14177.5388982296,20464.505380392075,14177.5388982296,6283.494655847549,2.2736637592315674,0.0 -44100,0.04097748,0.03717226,,,,,,,,,,,,,,,,, -44200,0.06755874,0.039420728,,,,,,,,,,,,,,,,, -44300,0.07139231,0.039577827,,,,,,,,,,,,,,,,, -44400,0.0462805,0.036236398,,,,,,,,,,,,,,,,, -44500,0.07194996,0.035870902,,,,,,,,,,,,,,,,, -44600,0.101671234,0.036961406,,,,,,,,,,,,,,,,, -44700,0.04652714,0.03418822,,,,,,,,,,,,,,,,, -44780,,,0.9894915223121644,0.035425916314125,0.2697776492707348,0.9861618280410768,0.0468348525464534,0.2099969184826789,43793.0,0.9852139353752136,0.0498356223106384,0.1990046135939414,43793.0,14417.7179479599,20805.993231773376,14417.7179479599,6384.748462438583,2.308631896972656,0.0 -44800,0.053148802,0.038031682,,,,,,,,,,,,,,,,, -44900,0.047990117,0.036052104,,,,,,,,,,,,,,,,, -45000,0.0561402,0.037040025,,,,,,,,,,,,,,,,, -45100,0.07261752,0.040919073,,,,,,,,,,,,,,,,, -45200,0.085082136,0.037000343,,,,,,,,,,,,,,,,, -45300,0.06714718,0.03476488,,,,,,,,,,,,,,,,, -45400,0.06006479,0.03862909,,,,,,,,,,,,,,,,, -45500,0.050224084,0.036563735,,,,,,,,,,,,,,,,, -45528,,,0.9893769025802612,0.0358100272715091,0.2722272612931093,0.9861736297607422,0.0468456894159317,0.20630033965734,43793.0,0.9852522611618042,0.049809843301773,0.1999492152297651,43793.0,14657.810967445374,21152.046332597733,14657.810967445374,6490.651208877564,2.3464467525482178,0.0 -45600,0.038893256,0.03043871,,,,,,,,,,,,,,,,, -45700,0.06453002,0.03474946,,,,,,,,,,,,,,,,, -45800,0.08378117,0.036619052,,,,,,,,,,,,,,,,, -45900,0.13006486,0.03652781,,,,,,,,,,,,,,,,, -46000,0.05706337,0.035638034,,,,,,,,,,,,,,,,, -46100,0.052129414,0.033789154,,,,,,,,,,,,,,,,, -46200,0.05594041,0.039238326,,,,,,,,,,,,,,,,, -46269,,,0.9895589351654052,0.0355012081563472,0.2688084506959136,0.9862186908721924,0.0464554876089096,0.2145519109500384,43793.0,0.9853171110153198,0.0493623949587345,0.2056291911976642,43793.0,14897.910877466202,21493.54910182953,14897.910877466202,6591.992488861084,2.38614821434021,0.0 -46300,0.08487459,0.0394335,,,,,,,,,,,,,,,,, -46400,0.047633916,0.035105582,,,,,,,,,,,,,,,,, -46500,0.06944554,0.036329776,,,,,,,,,,,,,,,,, -46600,0.08967555,0.033236768,,,,,,,,,,,,,,,,, -46700,0.04576786,0.03874291,,,,,,,,,,,,,,,,, -46800,0.04566292,0.03598158,,,,,,,,,,,,,,,,, -46900,0.11243234,0.038998015,,,,,,,,,,,,,,,,, -47000,0.058129452,0.033380277,,,,,,,,,,,,,,,,, -47009,,,0.9891928434371948,0.0365955233573913,0.2510488937477338,0.9861078858375548,0.0471156612038612,0.2064479490678904,43793.0,0.9852370619773864,0.0499967895448207,0.1964800375281237,43793.0,15138.077248096466,21838.38315463066,15138.077248096466,6696.60099697113,2.4226717948913574,0.0 -47100,0.09410474,0.037538424,,,,,,,,,,,,,,,,, -47200,0.057278384,0.03589899,,,,,,,,,,,,,,,,, -47300,0.05295194,0.03353207,,,,,,,,,,,,,,,,, -47400,0.05562184,0.033913042,,,,,,,,,,,,,,,,, -47500,0.17298043,0.03749574,,,,,,,,,,,,,,,,, -47600,0.043976676,0.03329097,,,,,,,,,,,,,,,,, -47700,0.07967725,0.03272196,,,,,,,,,,,,,,,,, -47751,,,0.9892840385437012,0.0362851954996585,0.2603987360262739,0.9862515330314636,0.046829804778099,0.2111941524972733,43793.0,0.9853697419166564,0.0497363582253456,0.2064511036544932,43793.0,15378.110914945602,22179.26458454132,15378.110914945602,6797.390934467316,2.4604079723358154,0.0 -47800,0.07560489,0.039435122,,,,,,,,,,,,,,,,, -47900,0.07020613,0.03283954,,,,,,,,,,,,,,,,, -48000,0.16628927,0.03927532,,,,,,,,,,,,,,,,, -48100,0.09600493,0.037147388,,,,,,,,,,,,,,,,, -48200,0.070234604,0.036570624,,,,,,,,,,,,,,,,, -48300,0.06939687,0.03737438,,,,,,,,,,,,,,,,, -48400,0.048565853,0.035474263,,,,,,,,,,,,,,,,, -48497,,,0.9894214272499084,0.0358730033040046,0.2551420798235866,0.9860554933547974,0.0465924367308616,0.2120485444070477,43793.0,0.9850791692733764,0.0494124293327331,0.2029720569531877,43793.0,15618.3207821846,22520.22537112236,15618.3207821846,6898.0792760849,2.502638578414917,0.0 -48500,0.048604358,0.03839347,,,,,,,,,,,,,,,,, -48600,0.0523939,0.029840332,,,,,,,,,,,,,,,,, -48700,0.12653913,0.034621503,,,,,,,,,,,,,,,,, -48800,0.055962693,0.032816414,,,,,,,,,,,,,,,,, -48900,0.06200517,0.039173357,,,,,,,,,,,,,,,,, -49000,0.09604607,0.03778172,,,,,,,,,,,,,,,,, -49100,0.06844121,0.037126496,,,,,,,,,,,,,,,,, -49200,0.08309043,0.03874998,,,,,,,,,,,,,,,,, -49250,,,0.9894852638244628,0.0357244722545146,0.2669809593413159,0.9862223267555236,0.0465139858424663,0.2131052753435941,43793.0,0.9852433800697328,0.0496956780552864,0.2003439402274657,43793.0,15858.295699596403,22862.790357112885,15858.295699596403,7000.6140151023865,2.537923812866211,0.0 -49300,0.07842012,0.03414368,,,,,,,,,,,,,,,,, -49400,0.06718013,0.037369173,,,,,,,,,,,,,,,,, -49500,0.064972155,0.03648693,,,,,,,,,,,,,,,,, -49600,0.06539601,0.03268685,,,,,,,,,,,,,,,,, -49700,0.07686462,0.035896864,,,,,,,,,,,,,,,,, -49800,0.06021778,0.034263324,,,,,,,,,,,,,,,,, -49900,0.0685127,0.036554415,,,,,,,,,,,,,,,,, -50000,,,0.9894618391990662,0.0355752818286418,0.2711189138805558,0.9861240983009338,0.0466132201254367,0.2095939907786124,43793.0,0.9852758646011353,0.0492729060351848,0.2034335283017547,43793.0,16098.242151737211,23203.83507847786,16098.242151737211,7101.655977487564,2.57417106628418,0.0 -50000,0.060587533,0.02772005,,,,,,,,,,,,,,,,, -50100,0.08969902,0.036294296,,,,,,,,,,,,,,,,, -50200,0.10729209,0.033578742,,,,,,,,,,,,,,,,, -50300,0.1061425,0.035329998,,,,,,,,,,,,,,,,, -50400,0.07282264,0.031771336,,,,,,,,,,,,,,,,, -50500,0.055297907,0.034276415,,,,,,,,,,,,,,,,, -50600,0.07160061,0.032587465,,,,,,,,,,,,,,,,, -50700,0.06298859,0.040379126,,,,,,,,,,,,,,,,, -50747,,,0.989648938179016,0.0350227542221546,0.2890718598209157,0.9863201379776,0.0462420061230659,0.212382103524577,43793.0,0.9854114651679992,0.0491921380162239,0.2062226497609948,43793.0,16338.338171482086,23549.17933702469,16338.338171482086,7206.845845937729,2.612422466278076,0.0 -50800,0.080017306,0.037201006,,,,,,,,,,,,,,,,, -50900,0.11254017,0.0376334,,,,,,,,,,,,,,,,, -51000,0.057173435,0.03438844,,,,,,,,,,,,,,,,, -51100,0.074692994,0.032988265,,,,,,,,,,,,,,,,, -51200,0.08909388,0.033151884,,,,,,,,,,,,,,,,, -51300,0.0596771,0.035042867,,,,,,,,,,,,,,,,, -51400,0.050093032,0.033366673,,,,,,,,,,,,,,,,, -51497,,,0.989640176296234,0.0348977521061897,0.2752238390381377,0.9863294959068298,0.0461780801415443,0.2151606578956752,43793.0,0.9853802919387816,0.0490918084979057,0.2061251770082258,43793.0,16578.522582292557,23890.57517528534,16578.522582292557,7307.999075889587,2.64997935295105,0.0 -51500,0.07599085,0.03378283,,,,,,,,,,,,,,,,, -51600,0.08023994,0.03396074,,,,,,,,,,,,,,,,, -51700,0.055352405,0.037501257,,,,,,,,,,,,,,,,, -51800,0.05094412,0.036002547,,,,,,,,,,,,,,,,, -51900,0.09463014,0.037827455,,,,,,,,,,,,,,,,, -52000,0.08817433,0.03215276,,,,,,,,,,,,,,,,, -52100,0.1377328,0.037237775,,,,,,,,,,,,,,,,, -52200,0.06975062,0.033192243,,,,,,,,,,,,,,,,, -52238,,,0.9896869659423828,0.0347418673336505,0.29000266580794,0.9863424897193908,0.0462528392672538,0.2170042776596787,43793.0,0.9854455590248108,0.049126137048006,0.2123704547099925,43793.0,16818.65753221512,24229.96969652176,16818.65753221512,7407.201065540314,2.6873366832733154,0.0 -52300,0.092729665,0.037122227,,,,,,,,,,,,,,,,, -52400,0.08044644,0.031740002,,,,,,,,,,,,,,,,, -52500,0.06539487,0.035309702,,,,,,,,,,,,,,,,, -52600,0.07069088,0.03296321,,,,,,,,,,,,,,,,, -52700,0.12749666,0.03776438,,,,,,,,,,,,,,,,, -52800,0.09692553,0.035002824,,,,,,,,,,,,,,,,, -52900,0.060673498,0.02921543,,,,,,,,,,,,,,,,, -52977,,,0.9897372126579284,0.0343862660229206,0.2859971482991346,0.986326277256012,0.0460726581513881,0.2227344429919973,43793.0,0.985417366027832,0.049002967774868,0.211536483233568,43793.0,17058.785735607147,24572.39704155922,17058.785735607147,7509.441777706146,2.724439859390259,0.0 -53000,0.09501092,0.034230977,,,,,,,,,,,,,,,,, -53100,0.08854493,0.037900902,,,,,,,,,,,,,,,,, -53200,0.06481269,0.037218966,,,,,,,,,,,,,,,,, -53300,0.06378017,0.032865252,,,,,,,,,,,,,,,,, -53400,0.055913474,0.034872424,,,,,,,,,,,,,,,,, -53500,0.054320943,0.035837576,,,,,,,,,,,,,,,,, -53600,0.08498241,0.035491977,,,,,,,,,,,,,,,,, -53700,0.11370264,0.03333551,,,,,,,,,,,,,,,,, -53732,,,0.989798903465271,0.0346322394907474,0.273643243785204,0.986401379108429,0.0458697639405727,0.2181591479311418,43793.0,0.9854910969734192,0.048701986670494,0.2097780509319448,43793.0,17298.919110774994,24915.203050851826,17298.919110774994,7612.057886600494,2.76061749458313,0.0 -53800,0.09704607,0.03630023,,,,,,,,,,,,,,,,, -53900,0.108703665,0.027370395,,,,,,,,,,,,,,,,, -54000,0.08019307,0.035007346,,,,,,,,,,,,,,,,, -54100,0.07560816,0.03449797,,,,,,,,,,,,,,,,, -54200,0.07935277,0.031168783,,,,,,,,,,,,,,,,, -54300,0.07001659,0.037036303,,,,,,,,,,,,,,,,, -54400,0.06722619,0.033079207,,,,,,,,,,,,,,,,, -54493,,,0.989654541015625,0.0349539890885353,0.2809974503744279,0.9862645268440248,0.0462120249867439,0.2155153642334833,43793.0,0.9853390455245972,0.0490941666066646,0.2090533931014886,43793.0,17538.96286535263,25257.191864728928,17538.96286535263,7713.9464428424835,2.797553777694702,0.0 -54500,0.06941174,0.034088764,,,,,,,,,,,,,,,,, -54600,0.062614456,0.034574393,,,,,,,,,,,,,,,,, -54700,0.08330314,0.037083156,,,,,,,,,,,,,,,,, -54800,0.06366747,0.030902786,,,,,,,,,,,,,,,,, -54900,0.10165764,0.03268847,,,,,,,,,,,,,,,,, -55000,0.08315711,0.03629938,,,,,,,,,,,,,,,,, -55100,0.09042141,0.036745574,,,,,,,,,,,,,,,,, -55200,0.123743124,0.036161277,,,,,,,,,,,,,,,,, -55249,,,0.989753007888794,0.0345969647169113,0.2854877147970602,0.9863563179969788,0.0458742342889308,0.2150310554975685,43793.0,0.9854329228401184,0.0487629063427448,0.2074281320406993,43793.0,17779.07443547249,25599.11487293244,17779.07443547249,7815.701034784317,2.834656476974488,0.0 -55300,0.0665811,0.03579074,,,,,,,,,,,,,,,,, -55400,0.09642959,0.034211624,,,,,,,,,,,,,,,,, -55500,0.09999147,0.03592895,,,,,,,,,,,,,,,,, -55600,0.13514332,0.03724675,,,,,,,,,,,,,,,,, -55700,0.05256573,0.03561941,,,,,,,,,,,,,,,,, -55800,0.06922448,0.03327599,,,,,,,,,,,,,,,,, -55900,0.07371141,0.038635038,,,,,,,,,,,,,,,,, -55990,,,0.9899150133132936,0.0341606698930263,0.2922762819836991,0.9864853620529176,0.0455940142273902,0.2239880195362973,43793.0,0.9854813814163208,0.0485040582716465,0.2149564124202218,43793.0,18019.22664308548,25940.16689515114,18019.22664308548,7916.534770727158,2.8772406578063965,0.0 -56000,0.08505922,0.033288933,,,,,,,,,,,,,,,,, -56100,0.10222753,0.033131637,,,,,,,,,,,,,,,,, -56200,0.07129729,0.03460676,,,,,,,,,,,,,,,,, -56300,0.10687168,0.032808967,,,,,,,,,,,,,,,,, -56400,0.10390359,0.036353286,,,,,,,,,,,,,,,,, -56500,0.087349944,0.034024607,,,,,,,,,,,,,,,,, -56600,0.07780021,0.034867935,,,,,,,,,,,,,,,,, -56700,0.08225389,0.03385255,,,,,,,,,,,,,,,,, -56739,,,0.9898921847343444,0.0340198650956153,0.2999987361066659,0.986504077911377,0.0459098368883132,0.2238110699225643,43793.0,0.9855051636695862,0.0488899461925029,0.2137652369595106,43793.0,18259.172052145004,26285.1296813488,18259.172052145004,8021.491682767868,2.9176416397094727,0.0 -56800,0.09756491,0.032365162,,,,,,,,,,,,,,,,, -56900,0.077852264,0.03142506,,,,,,,,,,,,,,,,, -57000,0.112937585,0.035462588,,,,,,,,,,,,,,,,, -57100,0.10024904,0.03585824,,,,,,,,,,,,,,,,, -57200,0.10017346,0.036433883,,,,,,,,,,,,,,,,, -57300,0.11027485,0.031712417,,,,,,,,,,,,,,,,, -57400,0.09817149,0.030444616,,,,,,,,,,,,,,,,, -57413,,,,,,,,,,,,,,18477.2912671566,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 86ac736a5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -103.0699977874756,0.0,13.2299907207489,1,0,13.2299907207489,0.5224818587303162,0.7161954641342163,0.0278394680676913,43793,116.30002975463869,0.5250520706176758,0.7150473594665527,0.0223154203866871,0.5213832855224609,0.7166012525558472,0.0261420653014558,43793 -205.0555601119995,0.0197958946228027,253.348135471344,758,0,253.348135471344,0.9832048416137696,0.0645277202129364,0.0557577787410884,43793,458.4430618286133,0.986740231513977,0.0512496270239353,0.0562020739249715,0.9841589331626892,0.0611993111670017,0.0552324893701033,43793 -311.6006577014923,0.0477762222290039,493.5868184566498,1500,0,493.5868184566498,0.9832578897476196,0.0637605413794517,0.0747159425183009,43793,805.2762854099274,0.9867498874664308,0.0503420047461986,0.0742255061412899,0.9842255115509032,0.0604429580271244,0.0718821453766525,43793 -412.3603291511536,0.3961923122406006,733.4666738510132,2254,0,733.4666738510132,0.9839398264884948,0.0567453317344188,0.1191155747065219,43793,1146.283578157425,0.9876054525375366,0.0443707220256328,0.1313076185234076,0.9849160313606262,0.0536021776497364,0.1235890399748405,43793 -514.584146976471,0.431208848953247,973.4950165748596,2992,0,973.4950165748596,0.984123468399048,0.0544441118836402,0.140199927936232,43793,1488.5906157493591,0.9879937171936036,0.0425633452832698,0.156441868377698,0.9850869178771972,0.0516696348786354,0.1419235231911762,43793 -616.377361536026,0.4610466957092285,1213.5263085365295,3734,0,1213.5263085365295,0.9844486117362976,0.053183663636446,0.161748321889361,43793,1830.4661691188808,0.9882489442825316,0.0411407016217708,0.1847905469470313,0.9853414297103882,0.0505308620631694,0.1624161177546391,43793 -723.1845881938934,0.4890384674072265,1453.7445459365845,4485,0,1453.7445459365845,0.9847337603569032,0.0520514212548732,0.1894657997357002,43793,2177.53892993927,0.9885805249214172,0.0395596623420715,0.2050289150720067,0.9856662154197692,0.0490881167352199,0.183546899527389,43793 -824.2064425945282,0.5192301273345947,1693.9610152244568,5237,0,1693.9610152244568,0.9849599599838256,0.0505706556141376,0.2014641590199237,43793,2518.826466560364,0.9888032674789428,0.0379521921277046,0.2372561723309122,0.985903263092041,0.0477237068116664,0.200352854559572,43793 -926.5682003498076,0.5474085807800293,1934.0389623641968,5990,0,1934.0389623641968,0.984913170337677,0.0503614880144596,0.2117786355546247,43793,2861.313567876816,0.9887657761573792,0.0378564335405826,0.244673735261001,0.9857993125915528,0.0475665219128131,0.2110000050222093,43793 -1027.5352380275726,0.5763721466064453,2174.274698495865,6739,0,2174.274698495865,0.985205054283142,0.0494270324707031,0.2250390013924116,43793,3202.564444541931,0.9891668558120728,0.0364877134561538,0.2775128441355874,0.986108660697937,0.0465688593685627,0.2195144180369691,43793 -1131.0144610404968,0.6038436889648438,2414.481337785721,7485,0,2414.481337785721,0.9853411316871644,0.0486358180642128,0.2302153420167335,43793,3546.296167373657,0.9896649718284608,0.0348514281213283,0.3058615825134636,0.9863104224205016,0.0458557456731796,0.228201156598706,43793 -1231.274868965149,0.635286808013916,2654.502183198929,8230,0,2654.502183198929,0.9855673313140868,0.0482541732490062,0.2369594849311999,43793,3886.628069639206,0.9897923469543456,0.0343365855515003,0.3150043600492316,0.98642897605896,0.0455535389482975,0.2336724299582642,43793 -1333.866200208664,0.6630842685699463,2894.498073577881,8977,0,2894.498073577881,0.9856507182121276,0.0477606654167175,0.2460367723696103,43793,4229.2618017196655,0.99002343416214,0.0333950668573379,0.3414138007447258,0.9864451885223388,0.0453813783824443,0.2379311461333866,43793 -1438.5838916301727,0.6951918601989746,3134.520377397537,9726,0,3134.520377397537,0.9857496619224548,0.0478009134531021,0.2438241087385982,43793,4574.052921056747,0.9900118708610536,0.0332933627068996,0.3366389186653354,0.9865231513977052,0.0451737977564334,0.2406404996389662,43793 -1545.8462941646576,0.7231438159942627,3374.544422149658,10464,0,3374.544422149658,0.9857842326164246,0.0477038696408271,0.2462669049591603,43793,4921.391419172287,0.9898894429206848,0.0337132103741169,0.3243495702523755,0.9865986108779908,0.0451174639165401,0.2425262042197452,43793 -1644.2886338233948,0.7532014846801758,3614.744652986528,11209,0,3614.744652986528,0.9858301281929016,0.0476005971431732,0.257797762911955,43793,5260.084446191788,0.9900863766670228,0.032924685627222,0.3577444428522381,0.9866254329681396,0.0449906326830387,0.252842845892908,43793 -1747.719571352005,0.7825462818145752,3854.700940847397,11953,0,3854.700940847397,0.985815405845642,0.0470694378018379,0.2561756224992291,43793,5603.52001285553,0.9902267456054688,0.0325705744326114,0.3536929634358303,0.9866960644721984,0.0445419959723949,0.2551666926403044,43793 -1853.3250706195831,0.812180757522583,4094.770954608917,12698,0,4094.770954608917,0.9857736825942992,0.0474612824618816,0.2538050406626495,43793,5949.243740320206,0.9903358221054076,0.0318615287542343,0.3735765411648799,0.9865893125534058,0.0446997880935668,0.2553556998427123,43793 -1953.5731303691864,0.8416659832000732,4334.911934137344,13452,0,4334.911934137344,0.9858141541481018,0.0467435866594314,0.2679424626869677,43793,6289.6809866428375,0.9905796051025392,0.031181801110506,0.3965822545995702,0.9866871237754822,0.0440555736422538,0.265454619956688,43793 -2058.5749821662903,0.8709104061126709,4575.066102266312,14202,0,4575.066102266312,0.9858920574188232,0.0474026165902614,0.2573156301471659,43793,6634.885343790054,0.9905725121498108,0.0308720637112855,0.3942485855207641,0.9867480397224426,0.0445171296596527,0.2602028697357067,43793 -2161.5691606998444,0.9022881984710692,4815.111397981644,14946,0,4815.111397981644,0.9859813451766968,0.0469692908227443,0.2671970450686207,43793,6977.977539300919,0.9908130168914796,0.0302795935422182,0.4061079881876005,0.986812949180603,0.0440843030810356,0.2687047924207835,43793 -2266.796665430069,0.933624029159546,5055.322886943817,15696,0,5055.322886943817,0.9860504269599916,0.0470215938985347,0.2661189499958922,43793,7323.467862606049,0.9909314513206482,0.0297503676265478,0.4212134215553861,0.9868617057800292,0.0441776104271411,0.2672141950776982,43793 -2366.7448992729187,0.9629554748535156,5295.4963212013245,16444,0,5295.4963212013245,0.9859354496002196,0.0465993173420429,0.2583384105555724,43793,7663.638106822967,0.9908322095870972,0.0304487310349941,0.403097629666715,0.9867277145385742,0.0441249907016754,0.2647822363531874,43793 -2466.7602257728577,0.9934120178222656,5535.677088022232,17196,0,5535.677088022232,0.9858874082565308,0.0464788116514682,0.2605033970119208,43793,8003.883762598038,0.99091899394989,0.0300291255116462,0.4113099548393802,0.9866400361061096,0.044159110635519,0.2626244980078488,43793 -2569.17115855217,1.0242016315460205,5775.816581726074,17942,0,5775.816581726074,0.9860824346542358,0.0468352399766445,0.2684401779090156,43793,8346.484165668488,0.9906976222991944,0.0304965618997812,0.3988523759511924,0.986907124519348,0.0442199669778347,0.2648747445150785,43793 -2668.448446750641,1.055455446243286,6015.941589832306,18695,0,6015.941589832306,0.9860491752624512,0.0467876940965652,0.2734708380090722,43793,8685.937515974045,0.9908912777900696,0.0298970993608236,0.4184613505153205,0.9868507385253906,0.0440458469092845,0.2643315627503256,43793 -2771.006584405899,1.0866594314575195,6256.012636184692,19445,0,6256.012636184692,0.9860230684280396,0.0465541593730449,0.2650854230397668,43793,9028.61729645729,0.990937113761902,0.0297523606568574,0.4243669747970852,0.9867671132087708,0.0439590215682983,0.267616429782502,43793 -2868.755485534668,1.1163904666900637,6496.045476913452,20202,0,6496.045476913452,0.9861186742782592,0.0470700450241565,0.2681579230676182,43793,9366.448516368866,0.9910696148872375,0.029148319736123,0.4278207878860953,0.9869027137756348,0.044284913688898,0.2742441528896123,43793 -2966.293663740158,1.1478157043457031,6736.057872056961,20957,0,6736.057872056961,0.9860588312149048,0.0468362160027027,0.2750265405447729,43793,9704.049597024918,0.9910153150558472,0.029240183532238,0.4452358825783558,0.9868357181549072,0.0443260408937931,0.2762801502975392,43793 -3071.3200681209564,1.1784136295318604,6976.239338636398,21712,0,6976.239338636398,0.9860668182373048,0.0465934127569198,0.273082646921358,43793,10049.307490348816,0.9912437200546264,0.0283992681652307,0.4604391958469054,0.9868929386138916,0.0440583266317844,0.2723770620871654,43793 -3172.609664440155,1.208643913269043,7216.246557235718,22467,0,7216.246557235718,0.98606938123703,0.0467710085213184,0.2726564770947386,43793,10390.653838157654,0.991296112537384,0.028296872973442,0.4488773275626957,0.9868758916854858,0.0440400317311286,0.2674766798136144,43793 -3276.680414915085,1.241858959197998,7456.395405292511,23222,0,7456.395405292511,0.9861102104187012,0.0468363463878631,0.2697631826066058,43793,10734.925931692123,0.9915757775306702,0.0274291150271892,0.4692676798938344,0.986928641796112,0.0437940917909145,0.2775344659929296,43793 -3377.009297847748,1.2736289501190186,7696.503025054932,23968,0,7696.503025054932,0.986065149307251,0.0466242805123329,0.2755573066660334,43793,11075.41303539276,0.991378128528595,0.0281359814107418,0.4653400569960497,0.9869207739830016,0.043841116130352,0.2782805106836062,43793 -3479.635217189789,1.3058831691741943,7936.5431044101715,24710,0,7936.5431044101715,0.9860036373138428,0.0464859828352928,0.2684493736634549,43793,11418.130235671995,0.9912890195846558,0.0285564847290515,0.446825708880746,0.9868153929710388,0.0439945720136165,0.2730327227108191,43793 -3584.0178577899933,1.3375873565673828,8176.756483316421,25457,0,8176.756483316421,0.9860908389091492,0.0465183109045028,0.2790029952847814,43793,11762.777160644531,0.9913530349731444,0.0281521100550889,0.4645335916983681,0.9869887232780457,0.0438172034919261,0.2778937891981876,43793 -3682.528043746948,1.3701817989349363,8416.975894927979,26216,0,8416.975894927979,0.9861236810684204,0.0470030941069126,0.2659943522029855,43793,12101.558113098145,0.9911909699440002,0.0284435730427503,0.4417814461731026,0.98688805103302,0.0441344380378723,0.2709613871925868,43793 -3783.687194108963,1.4013991355895996,8657.060805559158,26968,0,8657.060805559158,0.986162006855011,0.0469765998423099,0.275428224527491,43793,12442.852400064468,0.9912831783294678,0.0281435083597898,0.4646475549258269,0.9869964718818665,0.0441713444888591,0.2806928255369479,43793 -3886.666358947754,1.4324800968170166,8897.122105360031,27724,0,8897.122105360031,0.986177623271942,0.0467942990362644,0.2716887555121834,43793,12785.94291329384,0.991607904434204,0.0271330680698156,0.4827749643581413,0.986968457698822,0.0441259928047657,0.27225728252014,43793 -3991.973956346512,1.4689247608184814,9137.191581249235,28474,0,9137.191581249235,0.986123263835907,0.0466069392859935,0.2744048140787645,43793,13131.37703680992,0.9916527271270752,0.027179455384612,0.4807633852413003,0.9868791699409484,0.0440702252089977,0.2754178447037879,43793 -4093.964473724365,1.5014734268188477,9377.420338869097,29223,0,9377.420338869097,0.986240804195404,0.046885460615158,0.2749978669854959,43793,13473.64820098877,0.9917328357696532,0.0264650303870439,0.5178903073222842,0.9871068596839904,0.0441155098378658,0.2825732089125649,43793 -4197.028309106827,1.5345180034637451,9617.694372415544,29972,0,9617.694372415544,0.9862378239631652,0.0468265935778617,0.2805010265121149,43793,13817.03848195076,0.9918742179870604,0.0262527354061603,0.5005576253532122,0.9870269298553468,0.0442079156637191,0.2820148549135489,43793 -4303.904276847839,1.5686705112457275,9857.652417898178,30714,0,9857.652417898178,0.9860790371894836,0.0469512119889259,0.2753131299952883,43793,14163.926680326462,0.9917741417884828,0.0264660846441984,0.4974052599352223,0.986899435520172,0.0442755110561847,0.2783416322284315,43793 -4404.187355518341,1.6013495922088623,10097.801033735275,31471,0,10097.801033735275,0.9862290024757384,0.0471750982105731,0.2739805160313872,43793,14504.410804271698,0.9916498064994812,0.0269456524401903,0.4888814574380777,0.9870484471321106,0.0444845706224441,0.2807809220691736,43793 -4506.080503463745,1.637082815170288,10337.873920917513,32224,0,10337.873920917513,0.9862456321716307,0.0467881597578525,0.2771946456173613,43793,14846.432065725328,0.9916934370994568,0.0268517304211854,0.4979061005428039,0.9870707392692566,0.0440482571721077,0.2837913263385009,43793 -4607.861759901047,1.6719791889190674,10577.960082054138,32971,0,10577.960082054138,0.9862130284309388,0.0467044077813625,0.2748783828748157,43793,15188.354211330414,0.9917108416557312,0.0268477853387594,0.4847077089609162,0.9870195984840392,0.0439106449484825,0.2810509003984059,43793 -4708.640298604965,1.7060413360595703,10818.10796546936,33723,0,10818.10796546936,0.986255943775177,0.0467122979462146,0.2788147279553553,43793,15529.334302663803,0.9917195439338684,0.0264511443674564,0.5071388805226102,0.9871170520782472,0.0440160892903804,0.2869082133800527,43793 -4815.907956838608,1.7401671409606934,11058.298412799835,34468,0,11058.298412799835,0.986292600631714,0.0474320203065872,0.2736994714566074,43793,15876.846804141998,0.9917593002319336,0.0264794621616601,0.4987217774287702,0.9870354533195496,0.0445708893239498,0.2862611577565212,43793 -4917.55984044075,1.7752108573913574,11298.292127609251,35216,0,11298.292127609251,0.986276149749756,0.0471402741968631,0.2778236007067979,43793,16218.54702448845,0.9919956922531128,0.0257887691259384,0.514730269745362,0.9871243238449096,0.0442296601831913,0.2876763413028705,43793 -5019.775372505188,1.809107780456543,11538.531027317047,35967,0,11538.531027317047,0.9861658215522766,0.047429759055376,0.2739156328156931,43793,16561.054585933685,0.9921483397483826,0.0250180773437023,0.5371498315246481,0.9870285391807556,0.0445261783897876,0.2823993135022633,43793 -5118.604428529739,1.8418443202972408,11778.565999269484,36721,0,11778.565999269484,0.9860352277755736,0.046927087008953,0.2776079649370204,43793,16899.971519231796,0.9923694133758544,0.0246111080050468,0.5423980058733963,0.9868807792663574,0.0442409552633762,0.2812444716684876,43793 -5222.37273478508,1.878031730651856,12018.567860364914,37462,0,12018.567860364914,0.986195743083954,0.0474725700914859,0.2828257711099307,43793,17243.798904657364,0.9923901557922364,0.0243136920034885,0.5580372747770705,0.9870139360427856,0.0445088073611259,0.2837891887851236,43793 -5324.253618955612,1.9119784832000728,12258.701553821564,38214,0,12258.701553821564,0.9862033128738404,0.0475928448140621,0.2776924273187139,43793,17585.86657810211,0.9922755360603333,0.0247705467045307,0.5347153345263589,0.9870476126670836,0.0448093898594379,0.2813780929325306,43793 -5425.832328796387,2.3672399520874023,12498.43436551094,38966,0,12498.43436551094,0.9862778782844543,0.0475385151803493,0.2776829744498602,43793,17927.652944803238,0.9921236634254456,0.0251753758639097,0.5225601239481771,0.9870768189430236,0.044680256396532,0.2864147846605825,43793 -5530.745029449463,2.4028241634368896,12738.577870845796,39706,0,12738.577870845796,0.9863018989562988,0.0470027439296245,0.2751548149248845,43793,18272.76641869545,0.9922091364860536,0.0250520091503858,0.534985723044657,0.9870837330818176,0.0441632904112339,0.2902683111279471,43793 -5629.910222530365,2.4375898838043213,12978.589567184448,40458,0,12978.589567184448,0.986185610294342,0.0472837202250957,0.2740302064957952,43793,18611.997394800183,0.9922121167182922,0.0248994305729866,0.5352071481301257,0.9870695471763612,0.0442797504365444,0.2866353280879107,43793 -5727.685400247574,2.4737045764923096,13218.623850822449,41211,0,13218.623850822449,0.9861717224121094,0.0471930019557476,0.2720835039963426,43793,18949.86185884476,0.9924872517585754,0.0240594092756509,0.5571624292463508,0.9870370626449584,0.0443583317101001,0.2841668355790797,43793 -5827.797815561295,2.508739709854126,13458.846201658249,41969,0,13458.846201658249,0.9861666560173036,0.0474649108946323,0.2750019408692665,43793,19290.2509264946,0.9924891591072084,0.0239780861884355,0.5498505304311031,0.9871364831924438,0.0445279330015182,0.292061123135377,43793 -5928.517250061035,2.544694185256958,13698.999931812286,42708,0,13698.999931812286,0.9862711429595948,0.0472346059978008,0.2784730233764366,43793,19631.180638074875,0.9924418926239014,0.0239533670246601,0.5589275876659786,0.987052083015442,0.0447279922664165,0.2869159028380207,43793 -6029.976556539536,2.5795791149139404,13939.12776207924,43459,0,13939.12776207924,0.9862121343612672,0.0479162074625492,0.2752638868043979,43793,19972.82164144516,0.9928225874900818,0.022863321006298,0.5811168779417648,0.987057328224182,0.0449929870665073,0.2861434822523052,43793 -6127.8829135894775,2.6141395568847656,14179.133565664291,44206,0,14179.133565664291,0.9861574172973632,0.0479033477604389,0.2757788378159617,43793,20310.78754711151,0.9930243492126464,0.0223029311746358,0.6021831467185383,0.9870301485061646,0.0448726527392864,0.2870586605260568,43793 -6228.195499420166,2.649040937423706,14419.119592666626,44963,0,14419.119592666626,0.9862648248672484,0.0477472729980945,0.2821340983594425,43793,20651.140026569366,0.9930409789085388,0.0221349876374006,0.602127596669349,0.98716002702713,0.0449102371931076,0.2858967786479935,43793 -6327.404683113098,2.686012029647827,14659.334668159485,45727,0,14659.334668159485,0.9863384962081908,0.0479014739394187,0.2862644840320328,43793,20990.62044978141,0.9928721785545348,0.0226054582744836,0.583901784898843,0.9871146082878112,0.0450869426131248,0.2856805049431565,43793 -6423.754679679871,2.7223784923553467,14899.518951892853,46489,0,14899.518951892853,0.9862883687019348,0.0479216575622558,0.2808426418318341,43793,21327.210332155228,0.992776334285736,0.0229767374694347,0.5796602909704152,0.9870861768722534,0.0449971593916416,0.2878844677435476,43793 -6528.876852750778,2.7595677375793457,15139.666759252548,47245,0,15139.666759252548,0.9861990809440612,0.0483264736831188,0.2735987649905027,43793,21672.53662109375,0.9926639199256896,0.0231765508651733,0.5689805751294961,0.987066686153412,0.0454099029302597,0.2847175279822495,43793 -6627.028215408325,2.79514479637146,15379.932568311691,47994,0,15379.932568311691,0.9861843585968018,0.0484066344797611,0.2701901439753169,43793,22011.00979018212,0.9927076697349548,0.0229984018951654,0.5812239180112349,0.9869928359985352,0.0454793311655521,0.2820994474233479,43793 -6722.673615455627,2.831093072891236,15620.12927222252,48748,0,15620.12927222252,0.9862820506095886,0.0481756180524826,0.2763885679228161,43793,22346.907153129578,0.992895781993866,0.0225099623203277,0.6029228174570651,0.9870654940605164,0.0453277118504047,0.2864824900130744,43793 -6826.473671674728,2.866787433624268,15860.07743358612,49500,0,15860.07743358612,0.9862605929374696,0.0483807176351547,0.2754886985604703,43793,22690.710390090942,0.9930621981620787,0.0219377987086772,0.6060850216324132,0.9870135188102722,0.0454810671508312,0.2898876379847001,43793 -6922.95885682106,2.9028890132904053,16100.23567533493,50255,0,16100.23567533493,0.9861767888069152,0.0484723262488842,0.273646059366143,43793,23027.40929079056,0.9931796193122864,0.0217054821550846,0.5990312491287693,0.9870184063911438,0.045732669532299,0.2864895013825414,43793 -7027.812830686569,2.9394538402557373,16340.418535232544,51016,0,16340.418535232544,0.9862424731254578,0.0485534220933914,0.2766943537149877,43793,23372.501963615417,0.9933199286460876,0.0210258439183235,0.6328068561122213,0.9870354533195496,0.0458424612879753,0.2870732851061626,43793 -7129.999175310135,2.975675106048584,16580.547554254532,51773,0,16580.547554254532,0.9862707257270812,0.0488488227128982,0.2793308413871986,43793,23714.87309408188,0.993573784828186,0.020389275625348,0.6390301099909069,0.9870334267616272,0.0460025481879711,0.2893493358659952,43793 -7231.016023159027,3.0121028423309326,16820.687520742416,52523,0,16820.687520742416,0.9861144423484802,0.048769537359476,0.2781312477674411,43793,24056.08555865288,0.9938708543777466,0.0197893120348453,0.6536397233896962,0.9869379997253418,0.045902457088232,0.2885353006882328,43793 -7331.012581825256,3.0534260272979736,17060.65802717209,53269,0,17060.65802717209,0.9862328171730042,0.0491791106760501,0.2761081457071273,43793,24396.114936590195,0.9936069250106812,0.0201866496354341,0.6379870321649714,0.98707115650177,0.0463605523109436,0.2920631409937257,43793 -7433.599872112274,3.091221570968628,17300.819898843765,54016,0,17300.819898843765,0.9862012267112732,0.0492318421602249,0.2789847800402106,43793,24738.92270088196,0.993407905101776,0.020683042705059,0.6422283249361697,0.9870346188545228,0.0463643260300159,0.2859237206899969,43793 -7534.629323959351,3.1281652450561523,17541.016096830368,54759,0,17541.016096830368,0.9863005876541138,0.0494023486971855,0.2773694298836116,43793,25080.20478367805,0.993454933166504,0.0205811411142349,0.6245154233950967,0.9870610237121582,0.0465365014970302,0.2884550829211076,43793 -7630.906837940216,3.166072130203247,17781.16787624359,55505,0,17781.16787624359,0.98625510931015,0.0496644601225853,0.2748268458454385,43793,25416.691744089127,0.9934847950935364,0.0203968025743961,0.6366287142374494,0.9871174097061156,0.0465443395078182,0.2906898791426333,43793 -7733.94634437561,3.202991008758545,18021.238358020782,56252,0,18021.238358020782,0.9861961603164672,0.0496541261672973,0.2770761467327258,43793,25759.857971429825,0.993678867816925,0.0199342537671327,0.6509323873209593,0.9869863390922546,0.0468346439301967,0.2884947556034788,43793 -7831.50794172287,3.240222930908203,18261.479460000992,57001,0,18261.479460000992,0.9861481189727783,0.04986385628581047,0.2743496012662404,43793,26097.717761278152,0.9936870336532593,0.01976962573826313,0.6579927457781373,0.9869254231452942,0.0470384806394577,0.2847933420483199,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/measurements.csv deleted file mode 100644 index 9f75bf47b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/measurements.csv +++ /dev/null @@ -1,656 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.0256538,0.7228075,,,,,,,,,,,,,,,,, -1,,,0.5250520706176758,0.7150473594665527,0.0223154203866871,0.5213832855224609,0.7166012525558472,0.0261420653014558,43793.0,0.5224818587303162,0.7161954641342163,0.0278394680676913,43793.0,13.2299907207489,116.30002975463869,13.2299907207489,103.0699977874756,0.0,0.0 -100,0.26948208,0.25041583,,,,,,,,,,,,,,,,, -200,0.084610455,0.10300187,,,,,,,,,,,,,,,,, -300,0.028566165,0.064968824,,,,,,,,,,,,,,,,, -400,0.023708766,0.06121402,,,,,,,,,,,,,,,,, -500,0.012395626,0.05290257,,,,,,,,,,,,,,,,, -600,0.0327294,0.057464734,,,,,,,,,,,,,,,,, -700,0.015620986,0.046034243,,,,,,,,,,,,,,,,, -758,,,0.986740231513977,0.0512496270239353,0.0562020739249715,0.9841589331626892,0.0611993111670017,0.0552324893701033,43793.0,0.9832048416137696,0.0645277202129364,0.0557577787410884,43793.0,253.348135471344,458.4430618286133,253.348135471344,205.0555601119995,0.0197958946228027,0.0 -800,0.01942055,0.054119736,,,,,,,,,,,,,,,,, -900,0.015586647,0.052434843,,,,,,,,,,,,,,,,, -1000,0.026044806,0.05201346,,,,,,,,,,,,,,,,, -1100,0.016221242,0.04959465,,,,,,,,,,,,,,,,, -1200,0.016622104,0.050890233,,,,,,,,,,,,,,,,, -1300,0.026223961,0.048248205,,,,,,,,,,,,,,,,, -1400,0.012097944,0.045386616,,,,,,,,,,,,,,,,, -1500,,,0.9867498874664308,0.0503420047461986,0.0742255061412899,0.9842255115509032,0.0604429580271244,0.0718821453766525,43793.0,0.9832578897476196,0.0637605413794517,0.0747159425183009,43793.0,493.5868184566498,805.2762854099274,493.5868184566498,311.6006577014923,0.0477762222290039,0.0 -1500,0.04487256,0.044993732,,,,,,,,,,,,,,,,, -1600,0.056843247,0.04895391,,,,,,,,,,,,,,,,, -1700,0.025313012,0.052258797,,,,,,,,,,,,,,,,, -1800,0.015197262,0.049300913,,,,,,,,,,,,,,,,, -1900,0.0148859965,0.049096152,,,,,,,,,,,,,,,,, -2000,0.028357387,0.044809785,,,,,,,,,,,,,,,,, -2100,0.025620095,0.041638434,,,,,,,,,,,,,,,,, -2200,0.016712643,0.045324158,,,,,,,,,,,,,,,,, -2254,,,0.9876054525375366,0.0443707220256328,0.1313076185234076,0.9849160313606262,0.0536021776497364,0.1235890399748405,43793.0,0.9839398264884948,0.0567453317344188,0.1191155747065219,43793.0,733.4666738510132,1146.283578157425,733.4666738510132,412.3603291511536,0.3961923122406006,0.0 -2300,0.01866059,0.042365182,,,,,,,,,,,,,,,,, -2400,0.011688546,0.044706758,,,,,,,,,,,,,,,,, -2500,0.01186619,0.040179376,,,,,,,,,,,,,,,,, -2600,0.010394438,0.047445055,,,,,,,,,,,,,,,,, -2700,0.014613482,0.040636484,,,,,,,,,,,,,,,,, -2800,0.013967117,0.04021071,,,,,,,,,,,,,,,,, -2900,0.013140153,0.039843068,,,,,,,,,,,,,,,,, -2992,,,0.9879937171936036,0.0425633452832698,0.156441868377698,0.9850869178771972,0.0516696348786354,0.1419235231911762,43793.0,0.984123468399048,0.0544441118836402,0.140199927936232,43793.0,973.4950165748596,1488.5906157493591,973.4950165748596,514.584146976471,0.431208848953247,0.0 -3000,0.013300484,0.044944897,,,,,,,,,,,,,,,,, -3100,0.011381927,0.04484059,,,,,,,,,,,,,,,,, -3200,0.013667951,0.039614484,,,,,,,,,,,,,,,,, -3300,0.02498446,0.04238184,,,,,,,,,,,,,,,,, -3400,0.021729944,0.043018006,,,,,,,,,,,,,,,,, -3500,0.011229767,0.042252176,,,,,,,,,,,,,,,,, -3600,0.019297924,0.03887076,,,,,,,,,,,,,,,,, -3700,0.01902157,0.042525273,,,,,,,,,,,,,,,,, -3734,,,0.9882489442825316,0.0411407016217708,0.1847905469470313,0.9853414297103882,0.0505308620631694,0.1624161177546391,43793.0,0.9844486117362976,0.053183663636446,0.161748321889361,43793.0,1213.5263085365295,1830.4661691188808,1213.5263085365295,616.377361536026,0.4610466957092285,0.0 -3800,0.010284366,0.040983792,,,,,,,,,,,,,,,,, -3900,0.00959744,0.042956974,,,,,,,,,,,,,,,,, -4000,0.024854736,0.04598168,,,,,,,,,,,,,,,,, -4100,0.013159732,0.037080195,,,,,,,,,,,,,,,,, -4200,0.010189037,0.04346478,,,,,,,,,,,,,,,,, -4300,0.011869594,0.041386705,,,,,,,,,,,,,,,,, -4400,0.013387328,0.032886982,,,,,,,,,,,,,,,,, -4485,,,0.9885805249214172,0.0395596623420715,0.2050289150720067,0.9856662154197692,0.0490881167352199,0.183546899527389,43793.0,0.9847337603569032,0.0520514212548732,0.1894657997357002,43793.0,1453.7445459365845,2177.53892993927,1453.7445459365845,723.1845881938934,0.4890384674072265,0.0 -4500,0.014609794,0.040108826,,,,,,,,,,,,,,,,, -4600,0.010360192,0.04516131,,,,,,,,,,,,,,,,, -4700,0.014120975,0.042915422,,,,,,,,,,,,,,,,, -4800,0.015606862,0.040503837,,,,,,,,,,,,,,,,, -4900,0.011676624,0.03814533,,,,,,,,,,,,,,,,, -5000,0.0089523075,0.036923368,,,,,,,,,,,,,,,,, -5100,0.018336909,0.039519962,,,,,,,,,,,,,,,,, -5200,0.013878049,0.039385725,,,,,,,,,,,,,,,,, -5237,,,0.9888032674789428,0.0379521921277046,0.2372561723309122,0.985903263092041,0.0477237068116664,0.200352854559572,43793.0,0.9849599599838256,0.0505706556141376,0.2014641590199237,43793.0,1693.9610152244568,2518.826466560364,1693.9610152244568,824.2064425945282,0.5192301273345947,0.0 -5300,0.0123178195,0.038188554,,,,,,,,,,,,,,,,, -5400,0.010309091,0.03626551,,,,,,,,,,,,,,,,, -5500,0.011549067,0.041204415,,,,,,,,,,,,,,,,, -5600,0.012753935,0.035012856,,,,,,,,,,,,,,,,, -5700,0.013462728,0.03832923,,,,,,,,,,,,,,,,, -5800,0.016392784,0.035578523,,,,,,,,,,,,,,,,, -5900,0.018056944,0.038484566,,,,,,,,,,,,,,,,, -5990,,,0.9887657761573792,0.0378564335405826,0.244673735261001,0.9857993125915528,0.0475665219128131,0.2110000050222093,43793.0,0.984913170337677,0.0503614880144596,0.2117786355546247,43793.0,1934.0389623641968,2861.313567876816,1934.0389623641968,926.5682003498076,0.5474085807800293,0.0 -6000,0.010021587,0.035840627,,,,,,,,,,,,,,,,, -6100,0.018432336,0.039600726,,,,,,,,,,,,,,,,, -6200,0.021224802,0.037730414,,,,,,,,,,,,,,,,, -6300,0.02327763,0.038259443,,,,,,,,,,,,,,,,, -6400,0.012077622,0.03502231,,,,,,,,,,,,,,,,, -6500,0.012611324,0.034542687,,,,,,,,,,,,,,,,, -6600,0.017875774,0.03800651,,,,,,,,,,,,,,,,, -6700,0.019891333,0.039206684,,,,,,,,,,,,,,,,, -6739,,,0.9891668558120728,0.0364877134561538,0.2775128441355874,0.986108660697937,0.0465688593685627,0.2195144180369691,43793.0,0.985205054283142,0.0494270324707031,0.2250390013924116,43793.0,2174.274698495865,3202.564444541931,2174.274698495865,1027.5352380275726,0.5763721466064453,0.0 -6800,0.017224971,0.039047305,,,,,,,,,,,,,,,,, -6900,0.01818762,0.040468354,,,,,,,,,,,,,,,,, -7000,0.024617456,0.034624223,,,,,,,,,,,,,,,,, -7100,0.016090408,0.037045617,,,,,,,,,,,,,,,,, -7200,0.01257639,0.03379787,,,,,,,,,,,,,,,,, -7300,0.015927061,0.037333682,,,,,,,,,,,,,,,,, -7400,0.015694914,0.031552,,,,,,,,,,,,,,,,, -7485,,,0.9896649718284608,0.0348514281213283,0.3058615825134636,0.9863104224205016,0.0458557456731796,0.228201156598706,43793.0,0.9853411316871644,0.0486358180642128,0.2302153420167335,43793.0,2414.481337785721,3546.296167373657,2414.481337785721,1131.0144610404968,0.6038436889648438,0.0 -7500,0.017396336,0.03625717,,,,,,,,,,,,,,,,, -7600,0.013266593,0.03435185,,,,,,,,,,,,,,,,, -7700,0.015722794,0.04102468,,,,,,,,,,,,,,,,, -7800,0.01977235,0.037604205,,,,,,,,,,,,,,,,, -7900,0.01153061,0.03419411,,,,,,,,,,,,,,,,, -8000,0.02001891,0.037541483,,,,,,,,,,,,,,,,, -8100,0.015019611,0.04011926,,,,,,,,,,,,,,,,, -8200,0.016229501,0.03461629,,,,,,,,,,,,,,,,, -8230,,,0.9897923469543456,0.0343365855515003,0.3150043600492316,0.98642897605896,0.0455535389482975,0.2336724299582642,43793.0,0.9855673313140868,0.0482541732490062,0.2369594849311999,43793.0,2654.502183198929,3886.628069639206,2654.502183198929,1231.274868965149,0.635286808013916,0.0 -8300,0.01311641,0.037831854,,,,,,,,,,,,,,,,, -8400,0.014456641,0.03518669,,,,,,,,,,,,,,,,, -8500,0.019599907,0.03379037,,,,,,,,,,,,,,,,, -8600,0.026242789,0.037347667,,,,,,,,,,,,,,,,, -8700,0.016122496,0.035592645,,,,,,,,,,,,,,,,, -8800,0.016676445,0.033707187,,,,,,,,,,,,,,,,, -8900,0.022311974,0.040084578,,,,,,,,,,,,,,,,, -8977,,,0.99002343416214,0.0333950668573379,0.3414138007447258,0.9864451885223388,0.0453813783824443,0.2379311461333866,43793.0,0.9856507182121276,0.0477606654167175,0.2460367723696103,43793.0,2894.498073577881,4229.2618017196655,2894.498073577881,1333.866200208664,0.6630842685699463,0.0 -9000,0.020467976,0.03470736,,,,,,,,,,,,,,,,, -9100,0.018110268,0.036027018,,,,,,,,,,,,,,,,, -9200,0.016362717,0.030877935,,,,,,,,,,,,,,,,, -9300,0.017081274,0.035427485,,,,,,,,,,,,,,,,, -9400,0.014904575,0.033260353,,,,,,,,,,,,,,,,, -9500,0.01645154,0.029671159,,,,,,,,,,,,,,,,, -9600,0.023000108,0.03524796,,,,,,,,,,,,,,,,, -9700,0.022527918,0.035424523,,,,,,,,,,,,,,,,, -9726,,,0.9900118708610536,0.0332933627068996,0.3366389186653354,0.9865231513977052,0.0451737977564334,0.2406404996389662,43793.0,0.9857496619224548,0.0478009134531021,0.2438241087385982,43793.0,3134.520377397537,4574.052921056747,3134.520377397537,1438.5838916301727,0.6951918601989746,0.0 -9800,0.027551938,0.034660824,,,,,,,,,,,,,,,,, -9900,0.02323153,0.03573762,,,,,,,,,,,,,,,,, -10000,0.022850432,0.032336567,,,,,,,,,,,,,,,,, -10100,0.019246727,0.038558632,,,,,,,,,,,,,,,,, -10200,0.018089384,0.031445324,,,,,,,,,,,,,,,,, -10300,0.023752479,0.037496474,,,,,,,,,,,,,,,,, -10400,0.03517672,0.034019172,,,,,,,,,,,,,,,,, -10464,,,0.9898894429206848,0.0337132103741169,0.3243495702523755,0.9865986108779908,0.0451174639165401,0.2425262042197452,43793.0,0.9857842326164246,0.0477038696408271,0.2462669049591603,43793.0,3374.544422149658,4921.391419172287,3374.544422149658,1545.8462941646576,0.7231438159942627,0.0 -10500,0.0246651,0.034513667,,,,,,,,,,,,,,,,, -10600,0.031501073,0.03642073,,,,,,,,,,,,,,,,, -10700,0.023047209,0.034739245,,,,,,,,,,,,,,,,, -10800,0.019209355,0.032559264,,,,,,,,,,,,,,,,, -10900,0.03689941,0.039566275,,,,,,,,,,,,,,,,, -11000,0.023901576,0.03491412,,,,,,,,,,,,,,,,, -11100,0.01840742,0.03162341,,,,,,,,,,,,,,,,, -11200,0.027051598,0.034892406,,,,,,,,,,,,,,,,, -11209,,,0.9900863766670228,0.032924685627222,0.3577444428522381,0.9866254329681396,0.0449906326830387,0.252842845892908,43793.0,0.9858301281929016,0.0476005971431732,0.257797762911955,43793.0,3614.744652986528,5260.084446191788,3614.744652986528,1644.2886338233948,0.7532014846801758,0.0 -11300,0.027498154,0.036502328,,,,,,,,,,,,,,,,, -11400,0.019949196,0.035297893,,,,,,,,,,,,,,,,, -11500,0.026471809,0.03446974,,,,,,,,,,,,,,,,, -11600,0.02114853,0.03279827,,,,,,,,,,,,,,,,, -11700,0.022156436,0.030540844,,,,,,,,,,,,,,,,, -11800,0.027478516,0.031449553,,,,,,,,,,,,,,,,, -11900,0.027898842,0.036984134,,,,,,,,,,,,,,,,, -11953,,,0.9902267456054688,0.0325705744326114,0.3536929634358303,0.9866960644721984,0.0445419959723949,0.2551666926403044,43793.0,0.985815405845642,0.0470694378018379,0.2561756224992291,43793.0,3854.700940847397,5603.52001285553,3854.700940847397,1747.719571352005,0.7825462818145752,0.0 -12000,0.023747506,0.032951955,,,,,,,,,,,,,,,,, -12100,0.02953478,0.03515718,,,,,,,,,,,,,,,,, -12200,0.023445707,0.031546388,,,,,,,,,,,,,,,,, -12300,0.021252671,0.029842021,,,,,,,,,,,,,,,,, -12400,0.028649937,0.031518873,,,,,,,,,,,,,,,,, -12500,0.022221286,0.032073703,,,,,,,,,,,,,,,,, -12600,0.027568875,0.033803664,,,,,,,,,,,,,,,,, -12698,,,0.9903358221054076,0.0318615287542343,0.3735765411648799,0.9865893125534058,0.0446997880935668,0.2553556998427123,43793.0,0.9857736825942992,0.0474612824618816,0.2538050406626495,43793.0,4094.770954608917,5949.243740320206,4094.770954608917,1853.3250706195831,0.812180757522583,0.0 -12700,0.03123418,0.031787954,,,,,,,,,,,,,,,,, -12800,0.03329509,0.035296816,,,,,,,,,,,,,,,,, -12900,0.02989626,0.03274244,,,,,,,,,,,,,,,,, -13000,0.024785286,0.036559008,,,,,,,,,,,,,,,,, -13100,0.034832586,0.035759725,,,,,,,,,,,,,,,,, -13200,0.0514489,0.03946108,,,,,,,,,,,,,,,,, -13300,0.03302067,0.031116836,,,,,,,,,,,,,,,,, -13400,0.03546624,0.037232373,,,,,,,,,,,,,,,,, -13452,,,0.9905796051025392,0.031181801110506,0.3965822545995702,0.9866871237754822,0.0440555736422538,0.265454619956688,43793.0,0.9858141541481018,0.0467435866594314,0.2679424626869677,43793.0,4334.911934137344,6289.6809866428375,4334.911934137344,1953.5731303691864,0.8416659832000732,0.0 -13500,0.027333532,0.029845499,,,,,,,,,,,,,,,,, -13600,0.024909217,0.0329318,,,,,,,,,,,,,,,,, -13700,0.030725066,0.03374297,,,,,,,,,,,,,,,,, -13800,0.025980446,0.033023622,,,,,,,,,,,,,,,,, -13900,0.035556808,0.03329569,,,,,,,,,,,,,,,,, -14000,0.03309363,0.03707571,,,,,,,,,,,,,,,,, -14100,0.036285162,0.036257435,,,,,,,,,,,,,,,,, -14200,0.033225037,0.032013115,,,,,,,,,,,,,,,,, -14202,,,0.9905725121498108,0.0308720637112855,0.3942485855207641,0.9867480397224426,0.0445171296596527,0.2602028697357067,43793.0,0.9858920574188232,0.0474026165902614,0.2573156301471659,43793.0,4575.066102266312,6634.885343790054,4575.066102266312,2058.5749821662903,0.8709104061126709,0.0 -14300,0.031778373,0.03143225,,,,,,,,,,,,,,,,, -14400,0.028142633,0.032353632,,,,,,,,,,,,,,,,, -14500,0.033136852,0.032662712,,,,,,,,,,,,,,,,, -14600,0.029023282,0.03132193,,,,,,,,,,,,,,,,, -14700,0.05496665,0.030597987,,,,,,,,,,,,,,,,, -14800,0.035126023,0.034159686,,,,,,,,,,,,,,,,, -14900,0.02675824,0.028494675,,,,,,,,,,,,,,,,, -14946,,,0.9908130168914796,0.0302795935422182,0.4061079881876005,0.986812949180603,0.0440843030810356,0.2687047924207835,43793.0,0.9859813451766968,0.0469692908227443,0.2671970450686207,43793.0,4815.111397981644,6977.977539300919,4815.111397981644,2161.5691606998444,0.9022881984710692,0.0 -15000,0.036978196,0.034156695,,,,,,,,,,,,,,,,, -15100,0.032228973,0.029012144,,,,,,,,,,,,,,,,, -15200,0.03798353,0.033745777,,,,,,,,,,,,,,,,, -15300,0.0388135,0.03444841,,,,,,,,,,,,,,,,, -15400,0.03689454,0.03537468,,,,,,,,,,,,,,,,, -15500,0.03541033,0.03252681,,,,,,,,,,,,,,,,, -15600,0.04730254,0.029095866,,,,,,,,,,,,,,,,, -15696,,,0.9909314513206482,0.0297503676265478,0.4212134215553861,0.9868617057800292,0.0441776104271411,0.2672141950776982,43793.0,0.9860504269599916,0.0470215938985347,0.2661189499958922,43793.0,5055.322886943817,7323.467862606049,5055.322886943817,2266.796665430069,0.933624029159546,0.0 -15700,0.028805425,0.02861907,,,,,,,,,,,,,,,,, -15800,0.04013955,0.03127732,,,,,,,,,,,,,,,,, -15900,0.03741093,0.035257865,,,,,,,,,,,,,,,,, -16000,0.03383375,0.032154422,,,,,,,,,,,,,,,,, -16100,0.03297166,0.031361524,,,,,,,,,,,,,,,,, -16200,0.050316606,0.03278974,,,,,,,,,,,,,,,,, -16300,0.054438945,0.027537871,,,,,,,,,,,,,,,,, -16400,0.038249224,0.033797156,,,,,,,,,,,,,,,,, -16444,,,0.9908322095870972,0.0304487310349941,0.403097629666715,0.9867277145385742,0.0441249907016754,0.2647822363531874,43793.0,0.9859354496002196,0.0465993173420429,0.2583384105555724,43793.0,5295.4963212013245,7663.638106822967,5295.4963212013245,2366.7448992729187,0.9629554748535156,0.0 -16500,0.037001673,0.033847842,,,,,,,,,,,,,,,,, -16600,0.037876405,0.033514865,,,,,,,,,,,,,,,,, -16700,0.04605842,0.032037206,,,,,,,,,,,,,,,,, -16800,0.0538775,0.032243405,,,,,,,,,,,,,,,,, -16900,0.05431527,0.03290583,,,,,,,,,,,,,,,,, -17000,0.03641762,0.03130387,,,,,,,,,,,,,,,,, -17100,0.044506032,0.032614037,,,,,,,,,,,,,,,,, -17196,,,0.99091899394989,0.0300291255116462,0.4113099548393802,0.9866400361061096,0.044159110635519,0.2626244980078488,43793.0,0.9858874082565308,0.0464788116514682,0.2605033970119208,43793.0,5535.677088022232,8003.883762598038,5535.677088022232,2466.7602257728577,0.9934120178222656,0.0 -17200,0.071871065,0.031202508,,,,,,,,,,,,,,,,, -17300,0.03732602,0.030580603,,,,,,,,,,,,,,,,, -17400,0.03697258,0.032242928,,,,,,,,,,,,,,,,, -17500,0.049709957,0.03350746,,,,,,,,,,,,,,,,, -17600,0.031722143,0.03036042,,,,,,,,,,,,,,,,, -17700,0.049008913,0.030440507,,,,,,,,,,,,,,,,, -17800,0.047186628,0.033282176,,,,,,,,,,,,,,,,, -17900,0.0439862,0.028748922,,,,,,,,,,,,,,,,, -17942,,,0.9906976222991944,0.0304965618997812,0.3988523759511924,0.986907124519348,0.0442199669778347,0.2648747445150785,43793.0,0.9860824346542358,0.0468352399766445,0.2684401779090156,43793.0,5775.816581726074,8346.484165668488,5775.816581726074,2569.17115855217,1.0242016315460205,0.0 -18000,0.04245082,0.036056492,,,,,,,,,,,,,,,,, -18100,0.033629663,0.029200548,,,,,,,,,,,,,,,,, -18200,0.046888337,0.02944463,,,,,,,,,,,,,,,,, -18300,0.045700952,0.035032947,,,,,,,,,,,,,,,,, -18400,0.040943287,0.03179815,,,,,,,,,,,,,,,,, -18500,0.037476055,0.033887178,,,,,,,,,,,,,,,,, -18600,0.04486218,0.032944925,,,,,,,,,,,,,,,,, -18695,,,0.9908912777900696,0.0298970993608236,0.4184613505153205,0.9868507385253906,0.0440458469092845,0.2643315627503256,43793.0,0.9860491752624512,0.0467876940965652,0.2734708380090722,43793.0,6015.941589832306,8685.937515974045,6015.941589832306,2668.448446750641,1.055455446243286,0.0 -18700,0.046364307,0.030899825,,,,,,,,,,,,,,,,, -18800,0.038832363,0.031069793,,,,,,,,,,,,,,,,, -18900,0.039849427,0.034305654,,,,,,,,,,,,,,,,, -19000,0.046960063,0.036787935,,,,,,,,,,,,,,,,, -19100,0.041803606,0.032923665,,,,,,,,,,,,,,,,, -19200,0.045600887,0.032505203,,,,,,,,,,,,,,,,, -19300,0.038993035,0.032766372,,,,,,,,,,,,,,,,, -19400,0.042822376,0.032252297,,,,,,,,,,,,,,,,, -19445,,,0.990937113761902,0.0297523606568574,0.4243669747970852,0.9867671132087708,0.0439590215682983,0.267616429782502,43793.0,0.9860230684280396,0.0465541593730449,0.2650854230397668,43793.0,6256.012636184692,9028.61729645729,6256.012636184692,2771.006584405899,1.0866594314575195,0.0 -19500,0.0424141,0.031566598,,,,,,,,,,,,,,,,, -19600,0.04420846,0.030906368,,,,,,,,,,,,,,,,, -19700,0.03742398,0.030194182,,,,,,,,,,,,,,,,, -19800,0.057047524,0.031071698,,,,,,,,,,,,,,,,, -19900,0.042139802,0.02807094,,,,,,,,,,,,,,,,, -20000,0.054679945,0.03279669,,,,,,,,,,,,,,,,, -20100,0.045596536,0.030421618,,,,,,,,,,,,,,,,, -20200,0.053313006,0.029209495,,,,,,,,,,,,,,,,, -20202,,,0.9910696148872375,0.029148319736123,0.4278207878860953,0.9869027137756348,0.044284913688898,0.2742441528896123,43793.0,0.9861186742782592,0.0470700450241565,0.2681579230676182,43793.0,6496.045476913452,9366.448516368866,6496.045476913452,2868.755485534668,1.1163904666900637,0.0 -20300,0.05619448,0.03149921,,,,,,,,,,,,,,,,, -20400,0.032839246,0.027161172,,,,,,,,,,,,,,,,, -20500,0.04700922,0.03196052,,,,,,,,,,,,,,,,, -20600,0.04620588,0.031940047,,,,,,,,,,,,,,,,, -20700,0.03785038,0.02609296,,,,,,,,,,,,,,,,, -20800,0.048873696,0.028988631,,,,,,,,,,,,,,,,, -20900,0.04389445,0.033130385,,,,,,,,,,,,,,,,, -20957,,,0.9910153150558472,0.029240183532238,0.4452358825783558,0.9868357181549072,0.0443260408937931,0.2762801502975392,43793.0,0.9860588312149048,0.0468362160027027,0.2750265405447729,43793.0,6736.057872056961,9704.049597024918,6736.057872056961,2966.293663740158,1.1478157043457031,0.0 -21000,0.038502634,0.028370215,,,,,,,,,,,,,,,,, -21100,0.049860008,0.029904852,,,,,,,,,,,,,,,,, -21200,0.051361572,0.029884143,,,,,,,,,,,,,,,,, -21300,0.046296522,0.031926043,,,,,,,,,,,,,,,,, -21400,0.055778287,0.031707525,,,,,,,,,,,,,,,,, -21500,0.05739517,0.028287245,,,,,,,,,,,,,,,,, -21600,0.05714261,0.0315246,,,,,,,,,,,,,,,,, -21700,0.055166252,0.028660676,,,,,,,,,,,,,,,,, -21712,,,0.9912437200546264,0.0283992681652307,0.4604391958469054,0.9868929386138916,0.0440583266317844,0.2723770620871654,43793.0,0.9860668182373048,0.0465934127569198,0.273082646921358,43793.0,6976.239338636398,10049.307490348816,6976.239338636398,3071.3200681209564,1.1784136295318604,0.0 -21800,0.0410761,0.027849419,,,,,,,,,,,,,,,,, -21900,0.048473805,0.035447855,,,,,,,,,,,,,,,,, -22000,0.05301262,0.029886391,,,,,,,,,,,,,,,,, -22100,0.060280077,0.03144386,,,,,,,,,,,,,,,,, -22200,0.047772877,0.028017705,,,,,,,,,,,,,,,,, -22300,0.053452216,0.031794444,,,,,,,,,,,,,,,,, -22400,0.06512957,0.027782293,,,,,,,,,,,,,,,,, -22467,,,0.991296112537384,0.028296872973442,0.4488773275626957,0.9868758916854858,0.0440400317311286,0.2674766798136144,43793.0,0.98606938123703,0.0467710085213184,0.2726564770947386,43793.0,7216.246557235718,10390.653838157654,7216.246557235718,3172.609664440155,1.208643913269043,0.0 -22500,0.04820739,0.027081644,,,,,,,,,,,,,,,,, -22600,0.05359494,0.031516027,,,,,,,,,,,,,,,,, -22700,0.045308344,0.027526993,,,,,,,,,,,,,,,,, -22800,0.058783073,0.029355947,,,,,,,,,,,,,,,,, -22900,0.044227123,0.029900301,,,,,,,,,,,,,,,,, -23000,0.05332259,0.031151839,,,,,,,,,,,,,,,,, -23100,0.05181953,0.025954619,,,,,,,,,,,,,,,,, -23200,0.03875327,0.02706749,,,,,,,,,,,,,,,,, -23222,,,0.9915757775306702,0.0274291150271892,0.4692676798938344,0.986928641796112,0.0437940917909145,0.2775344659929296,43793.0,0.9861102104187012,0.0468363463878631,0.2697631826066058,43793.0,7456.395405292511,10734.925931692123,7456.395405292511,3276.680414915085,1.241858959197998,0.0 -23300,0.05928886,0.03555007,,,,,,,,,,,,,,,,, -23400,0.054688256,0.03251639,,,,,,,,,,,,,,,,, -23500,0.049826104,0.029437488,,,,,,,,,,,,,,,,, -23600,0.06688759,0.028798334,,,,,,,,,,,,,,,,, -23700,0.060597315,0.028614195,,,,,,,,,,,,,,,,, -23800,0.052219663,0.035040442,,,,,,,,,,,,,,,,, -23900,0.046051044,0.031789865,,,,,,,,,,,,,,,,, -23968,,,0.991378128528595,0.0281359814107418,0.4653400569960497,0.9869207739830016,0.043841116130352,0.2782805106836062,43793.0,0.986065149307251,0.0466242805123329,0.2755573066660334,43793.0,7696.503025054932,11075.41303539276,7696.503025054932,3377.009297847748,1.2736289501190186,0.0 -24000,0.05704601,0.031609103,,,,,,,,,,,,,,,,, -24100,0.05579426,0.02736327,,,,,,,,,,,,,,,,, -24200,0.045160387,0.027289145,,,,,,,,,,,,,,,,, -24300,0.054604657,0.03066503,,,,,,,,,,,,,,,,, -24400,0.054706562,0.038378067,,,,,,,,,,,,,,,,, -24500,0.056197654,0.03137301,,,,,,,,,,,,,,,,, -24600,0.05252151,0.031366486,,,,,,,,,,,,,,,,, -24700,0.049708113,0.030426482,,,,,,,,,,,,,,,,, -24710,,,0.9912890195846558,0.0285564847290515,0.446825708880746,0.9868153929710388,0.0439945720136165,0.2730327227108191,43793.0,0.9860036373138428,0.0464859828352928,0.2684493736634549,43793.0,7936.5431044101715,11418.130235671995,7936.5431044101715,3479.635217189789,1.3058831691741943,0.0 -24800,0.06418633,0.032323144,,,,,,,,,,,,,,,,, -24900,0.045539364,0.02750319,,,,,,,,,,,,,,,,, -25000,0.052871954,0.031567335,,,,,,,,,,,,,,,,, -25100,0.05514013,0.030540688,,,,,,,,,,,,,,,,, -25200,0.053154163,0.031533707,,,,,,,,,,,,,,,,, -25300,0.05272269,0.035559855,,,,,,,,,,,,,,,,, -25400,0.07397325,0.03292121,,,,,,,,,,,,,,,,, -25457,,,0.9913530349731444,0.0281521100550889,0.4645335916983681,0.9869887232780457,0.0438172034919261,0.2778937891981876,43793.0,0.9860908389091492,0.0465183109045028,0.2790029952847814,43793.0,8176.756483316421,11762.777160644531,8176.756483316421,3584.0178577899933,1.3375873565673828,0.0 -25500,0.07510105,0.032027353,,,,,,,,,,,,,,,,, -25600,0.066317916,0.03184397,,,,,,,,,,,,,,,,, -25700,0.05075096,0.03139141,,,,,,,,,,,,,,,,, -25800,0.054895762,0.026900833,,,,,,,,,,,,,,,,, -25900,0.058087055,0.02884246,,,,,,,,,,,,,,,,, -26000,0.065162554,0.029183056,,,,,,,,,,,,,,,,, -26100,0.060960855,0.035999246,,,,,,,,,,,,,,,,, -26200,0.059940416,0.02799979,,,,,,,,,,,,,,,,, -26216,,,0.9911909699440002,0.0284435730427503,0.4417814461731026,0.98688805103302,0.0441344380378723,0.2709613871925868,43793.0,0.9861236810684204,0.0470030941069126,0.2659943522029855,43793.0,8416.975894927979,12101.558113098145,8416.975894927979,3682.528043746948,1.3701817989349363,0.0 -26300,0.046177555,0.029825617,,,,,,,,,,,,,,,,, -26400,0.05226465,0.03237219,,,,,,,,,,,,,,,,, -26500,0.044604026,0.031601187,,,,,,,,,,,,,,,,, -26600,0.049299933,0.031104583,,,,,,,,,,,,,,,,, -26700,0.059276924,0.028466372,,,,,,,,,,,,,,,,, -26800,0.06151938,0.03230035,,,,,,,,,,,,,,,,, -26900,0.05313442,0.028069135,,,,,,,,,,,,,,,,, -26968,,,0.9912831783294678,0.0281435083597898,0.4646475549258269,0.9869964718818665,0.0441713444888591,0.2806928255369479,43793.0,0.986162006855011,0.0469765998423099,0.275428224527491,43793.0,8657.060805559158,12442.852400064468,8657.060805559158,3783.687194108963,1.4013991355895996,0.0 -27000,0.056810223,0.02868472,,,,,,,,,,,,,,,,, -27100,0.06068147,0.030690383,,,,,,,,,,,,,,,,, -27200,0.05186401,0.029115064,,,,,,,,,,,,,,,,, -27300,0.062827736,0.026288718,,,,,,,,,,,,,,,,, -27400,0.061427344,0.034386877,,,,,,,,,,,,,,,,, -27500,0.048507802,0.028800882,,,,,,,,,,,,,,,,, -27600,0.054728057,0.027924886,,,,,,,,,,,,,,,,, -27700,0.04144777,0.02581051,,,,,,,,,,,,,,,,, -27724,,,0.991607904434204,0.0271330680698156,0.4827749643581413,0.986968457698822,0.0441259928047657,0.27225728252014,43793.0,0.986177623271942,0.0467942990362644,0.2716887555121834,43793.0,8897.122105360031,12785.94291329384,8897.122105360031,3886.666358947754,1.4324800968170166,0.0 -27800,0.06933041,0.030514376,,,,,,,,,,,,,,,,, -27900,0.059825152,0.033235006,,,,,,,,,,,,,,,,, -28000,0.061084494,0.02995259,,,,,,,,,,,,,,,,, -28100,0.060357973,0.02596266,,,,,,,,,,,,,,,,, -28200,0.05052001,0.027082715,,,,,,,,,,,,,,,,, -28300,0.059505425,0.029010816,,,,,,,,,,,,,,,,, -28400,0.05908064,0.029740429,,,,,,,,,,,,,,,,, -28474,,,0.9916527271270752,0.027179455384612,0.4807633852413003,0.9868791699409484,0.0440702252089977,0.2754178447037879,43793.0,0.986123263835907,0.0466069392859935,0.2744048140787645,43793.0,9137.191581249235,13131.37703680992,9137.191581249235,3991.973956346512,1.4689247608184814,0.0 -28500,0.06766815,0.03124129,,,,,,,,,,,,,,,,, -28600,0.053436555,0.030647708,,,,,,,,,,,,,,,,, -28700,0.05786989,0.031952865,,,,,,,,,,,,,,,,, -28800,0.05416617,0.030529851,,,,,,,,,,,,,,,,, -28900,0.06677797,0.029214151,,,,,,,,,,,,,,,,, -29000,0.05973632,0.030286178,,,,,,,,,,,,,,,,, -29100,0.051826328,0.025652146,,,,,,,,,,,,,,,,, -29200,0.055079613,0.029576974,,,,,,,,,,,,,,,,, -29223,,,0.9917328357696532,0.0264650303870439,0.5178903073222842,0.9871068596839904,0.0441155098378658,0.2825732089125649,43793.0,0.986240804195404,0.046885460615158,0.2749978669854959,43793.0,9377.420338869097,13473.64820098877,9377.420338869097,4093.964473724365,1.5014734268188477,0.0 -29300,0.061520554,0.028312705,,,,,,,,,,,,,,,,, -29400,0.06128046,0.026952585,,,,,,,,,,,,,,,,, -29500,0.14767715,0.031093795,,,,,,,,,,,,,,,,, -29600,0.05990281,0.030285846,,,,,,,,,,,,,,,,, -29700,0.06956893,0.032113317,,,,,,,,,,,,,,,,, -29800,0.07293671,0.029867947,,,,,,,,,,,,,,,,, -29900,0.05536434,0.029916205,,,,,,,,,,,,,,,,, -29972,,,0.9918742179870604,0.0262527354061603,0.5005576253532122,0.9870269298553468,0.0442079156637191,0.2820148549135489,43793.0,0.9862378239631652,0.0468265935778617,0.2805010265121149,43793.0,9617.694372415544,13817.03848195076,9617.694372415544,4197.028309106827,1.5345180034637451,0.0 -30000,0.059452813,0.0292046,,,,,,,,,,,,,,,,, -30100,0.06507817,0.028880313,,,,,,,,,,,,,,,,, -30200,0.05865182,0.02831293,,,,,,,,,,,,,,,,, -30300,0.0630612,0.030342449,,,,,,,,,,,,,,,,, -30400,0.046500437,0.030250212,,,,,,,,,,,,,,,,, -30500,0.05967727,0.031299245,,,,,,,,,,,,,,,,, -30600,0.05727707,0.02804758,,,,,,,,,,,,,,,,, -30700,0.06623647,0.028999623,,,,,,,,,,,,,,,,, -30714,,,0.9917741417884828,0.0264660846441984,0.4974052599352223,0.986899435520172,0.0442755110561847,0.2783416322284315,43793.0,0.9860790371894836,0.0469512119889259,0.2753131299952883,43793.0,9857.652417898178,14163.926680326462,9857.652417898178,4303.904276847839,1.5686705112457275,0.0 -30800,0.062478743,0.029348072,,,,,,,,,,,,,,,,, -30900,0.056811985,0.030218521,,,,,,,,,,,,,,,,, -31000,0.050300557,0.026057594,,,,,,,,,,,,,,,,, -31100,0.061113995,0.03167248,,,,,,,,,,,,,,,,, -31200,0.06393369,0.026792599,,,,,,,,,,,,,,,,, -31300,0.0632919,0.029953044,,,,,,,,,,,,,,,,, -31400,0.07270283,0.027958484,,,,,,,,,,,,,,,,, -31471,,,0.9916498064994812,0.0269456524401903,0.4888814574380777,0.9870484471321106,0.0444845706224441,0.2807809220691736,43793.0,0.9862290024757384,0.0471750982105731,0.2739805160313872,43793.0,10097.801033735275,14504.410804271698,10097.801033735275,4404.187355518341,1.6013495922088623,0.0 -31500,0.059706792,0.027981786,,,,,,,,,,,,,,,,, -31600,0.052560918,0.02759279,,,,,,,,,,,,,,,,, -31700,0.0593707,0.028408036,,,,,,,,,,,,,,,,, -31800,0.052864302,0.026654443,,,,,,,,,,,,,,,,, -31900,0.06403561,0.026370015,,,,,,,,,,,,,,,,, -32000,0.06867698,0.029148327,,,,,,,,,,,,,,,,, -32100,0.055949505,0.027834581,,,,,,,,,,,,,,,,, -32200,0.05159342,0.027273018,,,,,,,,,,,,,,,,, -32224,,,0.9916934370994568,0.0268517304211854,0.4979061005428039,0.9870707392692566,0.0440482571721077,0.2837913263385009,43793.0,0.9862456321716307,0.0467881597578525,0.2771946456173613,43793.0,10337.873920917513,14846.432065725328,10337.873920917513,4506.080503463745,1.637082815170288,0.0 -32300,0.0519116,0.02737133,,,,,,,,,,,,,,,,, -32400,0.05748766,0.02716068,,,,,,,,,,,,,,,,, -32500,0.049389042,0.028048113,,,,,,,,,,,,,,,,, -32600,0.05859628,0.028942324,,,,,,,,,,,,,,,,, -32700,0.06960056,0.033269115,,,,,,,,,,,,,,,,, -32800,0.05927808,0.027421292,,,,,,,,,,,,,,,,, -32900,0.066918276,0.027162967,,,,,,,,,,,,,,,,, -32971,,,0.9917108416557312,0.0268477853387594,0.4847077089609162,0.9870195984840392,0.0439106449484825,0.2810509003984059,43793.0,0.9862130284309388,0.0467044077813625,0.2748783828748157,43793.0,10577.960082054138,15188.354211330414,10577.960082054138,4607.861759901047,1.6719791889190674,0.0 -33000,0.060553618,0.028764952,,,,,,,,,,,,,,,,, -33100,0.050362773,0.029136034,,,,,,,,,,,,,,,,, -33200,0.06754487,0.0278726,,,,,,,,,,,,,,,,, -33300,0.06337932,0.029077403,,,,,,,,,,,,,,,,, -33400,0.06723784,0.027997816,,,,,,,,,,,,,,,,, -33500,0.07527603,0.029886227,,,,,,,,,,,,,,,,, -33600,0.07909029,0.030676747,,,,,,,,,,,,,,,,, -33700,0.060280368,0.027752299,,,,,,,,,,,,,,,,, -33723,,,0.9917195439338684,0.0264511443674564,0.5071388805226102,0.9871170520782472,0.0440160892903804,0.2869082133800527,43793.0,0.986255943775177,0.0467122979462146,0.2788147279553553,43793.0,10818.10796546936,15529.334302663803,10818.10796546936,4708.640298604965,1.7060413360595703,0.0 -33800,0.05375205,0.024476605,,,,,,,,,,,,,,,,, -33900,0.055121925,0.028964987,,,,,,,,,,,,,,,,, -34000,0.060293965,0.02886772,,,,,,,,,,,,,,,,, -34100,0.07014635,0.032529,,,,,,,,,,,,,,,,, -34200,0.049708687,0.026791316,,,,,,,,,,,,,,,,, -34300,0.05121606,0.029623684,,,,,,,,,,,,,,,,, -34400,0.063310474,0.033290766,,,,,,,,,,,,,,,,, -34468,,,0.9917593002319336,0.0264794621616601,0.4987217774287702,0.9870354533195496,0.0445708893239498,0.2862611577565212,43793.0,0.986292600631714,0.0474320203065872,0.2736994714566074,43793.0,11058.298412799835,15876.846804141998,11058.298412799835,4815.907956838608,1.7401671409606934,0.0 -34500,0.05218253,0.026929121,,,,,,,,,,,,,,,,, -34600,0.05646905,0.025003044,,,,,,,,,,,,,,,,, -34700,0.06516432,0.029517941,,,,,,,,,,,,,,,,, -34800,0.057721402,0.027975352,,,,,,,,,,,,,,,,, -34900,0.07199921,0.030089397,,,,,,,,,,,,,,,,, -35000,0.058989294,0.029768705,,,,,,,,,,,,,,,,, -35100,0.06631494,0.03255528,,,,,,,,,,,,,,,,, -35200,0.058978654,0.027313119,,,,,,,,,,,,,,,,, -35216,,,0.9919956922531128,0.0257887691259384,0.514730269745362,0.9871243238449096,0.0442296601831913,0.2876763413028705,43793.0,0.986276149749756,0.0471402741968631,0.2778236007067979,43793.0,11298.292127609251,16218.54702448845,11298.292127609251,4917.55984044075,1.7752108573913574,0.0 -35300,0.053100437,0.027822133,,,,,,,,,,,,,,,,, -35400,0.10042386,0.030506996,,,,,,,,,,,,,,,,, -35500,0.07511813,0.02718497,,,,,,,,,,,,,,,,, -35600,0.058384433,0.02641005,,,,,,,,,,,,,,,,, -35700,0.062362686,0.027492791,,,,,,,,,,,,,,,,, -35800,0.06452862,0.026829213,,,,,,,,,,,,,,,,, -35900,0.06838366,0.028689522,,,,,,,,,,,,,,,,, -35967,,,0.9921483397483826,0.0250180773437023,0.5371498315246481,0.9870285391807556,0.0445261783897876,0.2823993135022633,43793.0,0.9861658215522766,0.047429759055376,0.2739156328156931,43793.0,11538.531027317047,16561.054585933685,11538.531027317047,5019.775372505188,1.809107780456543,0.0 -36000,0.06873152,0.028966704,,,,,,,,,,,,,,,,, -36100,0.060269177,0.027579546,,,,,,,,,,,,,,,,, -36200,0.059655186,0.027063198,,,,,,,,,,,,,,,,, -36300,0.066721685,0.029378787,,,,,,,,,,,,,,,,, -36400,0.0548324,0.02770697,,,,,,,,,,,,,,,,, -36500,0.05359136,0.030196432,,,,,,,,,,,,,,,,, -36600,0.058573794,0.027101636,,,,,,,,,,,,,,,,, -36700,0.058770414,0.026874227,,,,,,,,,,,,,,,,, -36721,,,0.9923694133758544,0.0246111080050468,0.5423980058733963,0.9868807792663574,0.0442409552633762,0.2812444716684876,43793.0,0.9860352277755736,0.046927087008953,0.2776079649370204,43793.0,11778.565999269484,16899.971519231796,11778.565999269484,5118.604428529739,1.8418443202972408,0.0 -36800,0.066241845,0.027695201,,,,,,,,,,,,,,,,, -36900,0.061397403,0.027994145,,,,,,,,,,,,,,,,, -37000,0.057271495,0.02604024,,,,,,,,,,,,,,,,, -37100,0.063222155,0.025882127,,,,,,,,,,,,,,,,, -37200,0.05621915,0.026911836,,,,,,,,,,,,,,,,, -37300,0.06509055,0.027386468,,,,,,,,,,,,,,,,, -37400,0.061247963,0.025596963,,,,,,,,,,,,,,,,, -37462,,,0.9923901557922364,0.0243136920034885,0.5580372747770705,0.9870139360427856,0.0445088073611259,0.2837891887851236,43793.0,0.986195743083954,0.0474725700914859,0.2828257711099307,43793.0,12018.567860364914,17243.798904657364,12018.567860364914,5222.37273478508,1.878031730651856,0.0 -37500,0.06470492,0.027632153,,,,,,,,,,,,,,,,, -37600,0.060554367,0.027023101,,,,,,,,,,,,,,,,, -37700,0.056178287,0.023946542,,,,,,,,,,,,,,,,, -37800,0.08232552,0.029355826,,,,,,,,,,,,,,,,, -37900,0.053598497,0.025551446,,,,,,,,,,,,,,,,, -38000,0.067909285,0.024859993,,,,,,,,,,,,,,,,, -38100,0.054954287,0.023702871,,,,,,,,,,,,,,,,, -38200,0.05600954,0.025323197,,,,,,,,,,,,,,,,, -38214,,,0.9922755360603333,0.0247705467045307,0.5347153345263589,0.9870476126670836,0.0448093898594379,0.2813780929325306,43793.0,0.9862033128738404,0.0475928448140621,0.2776924273187139,43793.0,12258.701553821564,17585.86657810211,12258.701553821564,5324.253618955612,1.9119784832000728,0.0 -38300,0.09791027,0.030874342,,,,,,,,,,,,,,,,, -38400,0.06181427,0.02357789,,,,,,,,,,,,,,,,, -38500,0.071891844,0.027084043,,,,,,,,,,,,,,,,, -38600,0.052182913,0.024604525,,,,,,,,,,,,,,,,, -38700,0.07350598,0.030710889,,,,,,,,,,,,,,,,, -38800,0.06693775,0.025901904,,,,,,,,,,,,,,,,, -38900,0.071327955,0.02731503,,,,,,,,,,,,,,,,, -38966,,,0.9921236634254456,0.0251753758639097,0.5225601239481771,0.9870768189430236,0.044680256396532,0.2864147846605825,43793.0,0.9862778782844543,0.0475385151803493,0.2776829744498602,43793.0,12498.43436551094,17927.652944803238,12498.43436551094,5425.832328796387,2.3672399520874023,0.0 -39000,0.071404144,0.024865802,,,,,,,,,,,,,,,,, -39100,0.066410765,0.025013274,,,,,,,,,,,,,,,,, -39200,0.07620762,0.028323684,,,,,,,,,,,,,,,,, -39300,0.07102178,0.026707463,,,,,,,,,,,,,,,,, -39400,0.06604512,0.028470343,,,,,,,,,,,,,,,,, -39500,0.07322624,0.027272558,,,,,,,,,,,,,,,,, -39600,0.095712185,0.02958792,,,,,,,,,,,,,,,,, -39700,0.07625919,0.02985786,,,,,,,,,,,,,,,,, -39706,,,0.9922091364860536,0.0250520091503858,0.534985723044657,0.9870837330818176,0.0441632904112339,0.2902683111279471,43793.0,0.9863018989562988,0.0470027439296245,0.2751548149248845,43793.0,12738.577870845796,18272.76641869545,12738.577870845796,5530.745029449463,2.4028241634368896,0.0 -39800,0.06446197,0.02825718,,,,,,,,,,,,,,,,, -39900,0.06784487,0.028833328,,,,,,,,,,,,,,,,, -40000,0.077327445,0.026970647,,,,,,,,,,,,,,,,, -40100,0.07317198,0.024681821,,,,,,,,,,,,,,,,, -40200,0.061333496,0.023746666,,,,,,,,,,,,,,,,, -40300,0.08107508,0.029610403,,,,,,,,,,,,,,,,, -40400,0.08509202,0.031017235,,,,,,,,,,,,,,,,, -40458,,,0.9922121167182922,0.0248994305729866,0.5352071481301257,0.9870695471763612,0.0442797504365444,0.2866353280879107,43793.0,0.986185610294342,0.0472837202250957,0.2740302064957952,43793.0,12978.589567184448,18611.997394800183,12978.589567184448,5629.910222530365,2.4375898838043213,0.0 -40500,0.071067534,0.029638076,,,,,,,,,,,,,,,,, -40600,0.061574306,0.026125908,,,,,,,,,,,,,,,,, -40700,0.088446796,0.029772392,,,,,,,,,,,,,,,,, -40800,0.07256214,0.026644256,,,,,,,,,,,,,,,,, -40900,0.07449714,0.027617797,,,,,,,,,,,,,,,,, -41000,0.07980828,0.029532442,,,,,,,,,,,,,,,,, -41100,0.06800768,0.026483659,,,,,,,,,,,,,,,,, -41200,0.071429886,0.026768137,,,,,,,,,,,,,,,,, -41211,,,0.9924872517585754,0.0240594092756509,0.5571624292463508,0.9870370626449584,0.0443583317101001,0.2841668355790797,43793.0,0.9861717224121094,0.0471930019557476,0.2720835039963426,43793.0,13218.623850822449,18949.86185884476,13218.623850822449,5727.685400247574,2.4737045764923096,0.0 -41300,0.070654504,0.029056413,,,,,,,,,,,,,,,,, -41400,0.0818373,0.026431216,,,,,,,,,,,,,,,,, -41500,0.09424236,0.028052805,,,,,,,,,,,,,,,,, -41600,0.06254995,0.023798596,,,,,,,,,,,,,,,,, -41700,0.0818508,0.026026212,,,,,,,,,,,,,,,,, -41800,0.08005119,0.028006181,,,,,,,,,,,,,,,,, -41900,0.069500156,0.028973479,,,,,,,,,,,,,,,,, -41969,,,0.9924891591072084,0.0239780861884355,0.5498505304311031,0.9871364831924438,0.0445279330015182,0.292061123135377,43793.0,0.9861666560173036,0.0474649108946323,0.2750019408692665,43793.0,13458.846201658249,19290.2509264946,13458.846201658249,5827.797815561295,2.508739709854126,0.0 -42000,0.0713563,0.023106212,,,,,,,,,,,,,,,,, -42100,0.0735887,0.025961485,,,,,,,,,,,,,,,,, -42200,0.08440437,0.030011598,,,,,,,,,,,,,,,,, -42300,0.07517272,0.026732462,,,,,,,,,,,,,,,,, -42400,0.082413,0.030350588,,,,,,,,,,,,,,,,, -42500,0.07251652,0.025713708,,,,,,,,,,,,,,,,, -42600,0.07868823,0.026130239,,,,,,,,,,,,,,,,, -42700,0.05980643,0.023557475,,,,,,,,,,,,,,,,, -42708,,,0.9924418926239014,0.0239533670246601,0.5589275876659786,0.987052083015442,0.0447279922664165,0.2869159028380207,43793.0,0.9862711429595948,0.0472346059978008,0.2784730233764366,43793.0,13698.999931812286,19631.180638074875,13698.999931812286,5928.517250061035,2.544694185256958,0.0 -42800,0.068750285,0.026153633,,,,,,,,,,,,,,,,, -42900,0.06908561,0.02309272,,,,,,,,,,,,,,,,, -43000,0.08050788,0.028527003,,,,,,,,,,,,,,,,, -43100,0.069106646,0.027609827,,,,,,,,,,,,,,,,, -43200,0.08201005,0.025582826,,,,,,,,,,,,,,,,, -43300,0.069199726,0.023835462,,,,,,,,,,,,,,,,, -43400,0.08695149,0.026820488,,,,,,,,,,,,,,,,, -43459,,,0.9928225874900818,0.022863321006298,0.5811168779417648,0.987057328224182,0.0449929870665073,0.2861434822523052,43793.0,0.9862121343612672,0.0479162074625492,0.2752638868043979,43793.0,13939.12776207924,19972.82164144516,13939.12776207924,6029.976556539536,2.5795791149139404,0.0 -43500,0.0795701,0.026481504,,,,,,,,,,,,,,,,, -43600,0.08995256,0.026712341,,,,,,,,,,,,,,,,, -43700,0.08141188,0.02680128,,,,,,,,,,,,,,,,, -43800,0.07451097,0.024194036,,,,,,,,,,,,,,,,, -43900,0.08610354,0.027340416,,,,,,,,,,,,,,,,, -44000,0.06389434,0.024268815,,,,,,,,,,,,,,,,, -44100,0.06785759,0.027875392,,,,,,,,,,,,,,,,, -44200,0.082228385,0.027361356,,,,,,,,,,,,,,,,, -44206,,,0.9930243492126464,0.0223029311746358,0.6021831467185383,0.9870301485061646,0.0448726527392864,0.2870586605260568,43793.0,0.9861574172973632,0.0479033477604389,0.2757788378159617,43793.0,14179.133565664291,20310.78754711151,14179.133565664291,6127.8829135894775,2.6141395568847656,0.0 -44300,0.0755046,0.026240481,,,,,,,,,,,,,,,,, -44400,0.08103286,0.02477109,,,,,,,,,,,,,,,,, -44500,0.07983704,0.025256036,,,,,,,,,,,,,,,,, -44600,0.08715359,0.025643893,,,,,,,,,,,,,,,,, -44700,0.068100125,0.02530865,,,,,,,,,,,,,,,,, -44800,0.07133788,0.027757686,,,,,,,,,,,,,,,,, -44900,0.07169097,0.0263221,,,,,,,,,,,,,,,,, -44963,,,0.9930409789085388,0.0221349876374006,0.602127596669349,0.98716002702713,0.0449102371931076,0.2858967786479935,43793.0,0.9862648248672484,0.0477472729980945,0.2821340983594425,43793.0,14419.119592666626,20651.140026569366,14419.119592666626,6228.195499420166,2.649040937423706,0.0 -45000,0.07835864,0.026297443,,,,,,,,,,,,,,,,, -45100,0.08270327,0.02883519,,,,,,,,,,,,,,,,, -45200,0.10295257,0.027120907,,,,,,,,,,,,,,,,, -45300,0.07498227,0.025010701,,,,,,,,,,,,,,,,, -45400,0.08198886,0.027983608,,,,,,,,,,,,,,,,, -45500,0.07663036,0.026920088,,,,,,,,,,,,,,,,, -45600,0.06724578,0.02338148,,,,,,,,,,,,,,,,, -45700,0.08217276,0.02576689,,,,,,,,,,,,,,,,, -45727,,,0.9928721785545348,0.0226054582744836,0.583901784898843,0.9871146082878112,0.0450869426131248,0.2856805049431565,43793.0,0.9863384962081908,0.0479014739394187,0.2862644840320328,43793.0,14659.334668159485,20990.62044978141,14659.334668159485,6327.404683113098,2.686012029647827,0.0 -45800,0.09368614,0.026063606,,,,,,,,,,,,,,,,, -45900,0.08412539,0.024718788,,,,,,,,,,,,,,,,, -46000,0.08749984,0.023798143,,,,,,,,,,,,,,,,, -46100,0.07133271,0.025417529,,,,,,,,,,,,,,,,, -46200,0.117355004,0.027803568,,,,,,,,,,,,,,,,, -46300,0.07956463,0.027058464,,,,,,,,,,,,,,,,, -46400,0.09264146,0.025883881,,,,,,,,,,,,,,,,, -46489,,,0.992776334285736,0.0229767374694347,0.5796602909704152,0.9870861768722534,0.0449971593916416,0.2878844677435476,43793.0,0.9862883687019348,0.0479216575622558,0.2808426418318341,43793.0,14899.518951892853,21327.210332155228,14899.518951892853,6423.754679679871,2.7223784923553467,0.0 -46500,0.099344455,0.02517338,,,,,,,,,,,,,,,,, -46600,0.066951625,0.022813847,,,,,,,,,,,,,,,,, -46700,0.0963396,0.027198859,,,,,,,,,,,,,,,,, -46800,0.07524825,0.024663981,,,,,,,,,,,,,,,,, -46900,0.10166017,0.027873922,,,,,,,,,,,,,,,,, -47000,0.06935556,0.023157064,,,,,,,,,,,,,,,,, -47100,0.083404414,0.026637627,,,,,,,,,,,,,,,,, -47200,0.08410456,0.025392363,,,,,,,,,,,,,,,,, -47245,,,0.9926639199256896,0.0231765508651733,0.5689805751294961,0.987066686153412,0.0454099029302597,0.2847175279822495,43793.0,0.9861990809440612,0.0483264736831188,0.2735987649905027,43793.0,15139.666759252548,21672.53662109375,15139.666759252548,6528.876852750778,2.7595677375793457,0.0 -47300,0.07328518,0.023361677,,,,,,,,,,,,,,,,, -47400,0.09508553,0.025499376,,,,,,,,,,,,,,,,, -47500,0.090129994,0.023446236,,,,,,,,,,,,,,,,, -47600,0.08191167,0.023811704,,,,,,,,,,,,,,,,, -47700,0.077551015,0.023159504,,,,,,,,,,,,,,,,, -47800,0.09004128,0.026123924,,,,,,,,,,,,,,,,, -47900,0.0950802,0.023692043,,,,,,,,,,,,,,,,, -47994,,,0.9927076697349548,0.0229984018951654,0.5812239180112349,0.9869928359985352,0.0454793311655521,0.2820994474233479,43793.0,0.9861843585968018,0.0484066344797611,0.2701901439753169,43793.0,15379.932568311691,22011.00979018212,15379.932568311691,6627.028215408325,2.79514479637146,0.0 -48000,0.086119376,0.02640923,,,,,,,,,,,,,,,,, -48100,0.0777851,0.025396654,,,,,,,,,,,,,,,,, -48200,0.075060606,0.024860498,,,,,,,,,,,,,,,,, -48300,0.089304626,0.02509803,,,,,,,,,,,,,,,,, -48400,0.07085008,0.023319324,,,,,,,,,,,,,,,,, -48500,0.08543521,0.027238147,,,,,,,,,,,,,,,,, -48600,0.084298484,0.022255126,,,,,,,,,,,,,,,,, -48700,0.07474231,0.024328353,,,,,,,,,,,,,,,,, -48748,,,0.992895781993866,0.0225099623203277,0.6029228174570651,0.9870654940605164,0.0453277118504047,0.2864824900130744,43793.0,0.9862820506095886,0.0481756180524826,0.2763885679228161,43793.0,15620.12927222252,22346.907153129578,15620.12927222252,6722.673615455627,2.831093072891236,0.0 -48800,0.10987846,0.023687169,,,,,,,,,,,,,,,,, -48900,0.10420066,0.025927689,,,,,,,,,,,,,,,,, -49000,0.08121735,0.025425894,,,,,,,,,,,,,,,,, -49100,0.080763556,0.026627176,,,,,,,,,,,,,,,,, -49200,0.08769049,0.026593262,,,,,,,,,,,,,,,,, -49300,0.08484929,0.022256885,,,,,,,,,,,,,,,,, -49400,0.07469917,0.025739633,,,,,,,,,,,,,,,,, -49500,,,0.9930621981620787,0.0219377987086772,0.6060850216324132,0.9870135188102722,0.0454810671508312,0.2898876379847001,43793.0,0.9862605929374696,0.0483807176351547,0.2754886985604703,43793.0,15860.07743358612,22690.710390090942,15860.07743358612,6826.473671674728,2.866787433624268,0.0 -49500,0.09966945,0.025056977,,,,,,,,,,,,,,,,, -49600,0.08728137,0.023140376,,,,,,,,,,,,,,,,, -49700,0.08977761,0.024641007,,,,,,,,,,,,,,,,, -49800,0.08894264,0.022392694,,,,,,,,,,,,,,,,, -49900,0.096191764,0.023357583,,,,,,,,,,,,,,,,, -50000,0.08181365,0.020165956,,,,,,,,,,,,,,,,, -50100,0.10020241,0.026034396,,,,,,,,,,,,,,,,, -50200,0.09703819,0.024418259,,,,,,,,,,,,,,,,, -50255,,,0.9931796193122864,0.0217054821550846,0.5990312491287693,0.9870184063911438,0.045732669532299,0.2864895013825414,43793.0,0.9861767888069152,0.0484723262488842,0.273646059366143,43793.0,16100.23567533493,23027.40929079056,16100.23567533493,6922.95885682106,2.9028890132904053,0.0 -50300,0.09637728,0.024510264,,,,,,,,,,,,,,,,, -50400,0.082468286,0.022541335,,,,,,,,,,,,,,,,, -50500,0.09180753,0.025044505,,,,,,,,,,,,,,,,, -50600,0.08353786,0.0226006,,,,,,,,,,,,,,,,, -50700,0.10292878,0.02810418,,,,,,,,,,,,,,,,, -50800,0.08209324,0.025344498,,,,,,,,,,,,,,,,, -50900,0.08483079,0.024654089,,,,,,,,,,,,,,,,, -51000,0.08233548,0.022296643,,,,,,,,,,,,,,,,, -51016,,,0.9933199286460876,0.0210258439183235,0.6328068561122213,0.9870354533195496,0.0458424612879753,0.2870732851061626,43793.0,0.9862424731254578,0.0485534220933914,0.2766943537149877,43793.0,16340.418535232544,23372.501963615417,16340.418535232544,7027.812830686569,2.9394538402557373,0.0 -51100,0.086978436,0.022048911,,,,,,,,,,,,,,,,, -51200,0.08521582,0.022884151,,,,,,,,,,,,,,,,, -51300,0.08575107,0.023521027,,,,,,,,,,,,,,,,, -51400,0.10600727,0.024620032,,,,,,,,,,,,,,,,, -51500,0.10032372,0.023684734,,,,,,,,,,,,,,,,, -51600,0.10449954,0.023019757,,,,,,,,,,,,,,,,, -51700,0.08627077,0.023736972,,,,,,,,,,,,,,,,, -51773,,,0.993573784828186,0.020389275625348,0.6390301099909069,0.9870334267616272,0.0460025481879711,0.2893493358659952,43793.0,0.9862707257270812,0.0488488227128982,0.2793308413871986,43793.0,16580.547554254532,23714.87309408188,16580.547554254532,7129.999175310135,2.975675106048584,0.0 -51800,0.08780002,0.024713257,,,,,,,,,,,,,,,,, -51900,0.09758489,0.029016243,,,,,,,,,,,,,,,,, -52000,0.082041,0.021359626,,,,,,,,,,,,,,,,, -52100,0.113954544,0.025044775,,,,,,,,,,,,,,,,, -52200,0.08941431,0.023412423,,,,,,,,,,,,,,,,, -52300,0.09442994,0.023777239,,,,,,,,,,,,,,,,, -52400,0.08654176,0.021792335,,,,,,,,,,,,,,,,, -52500,0.093977325,0.025415529,,,,,,,,,,,,,,,,, -52523,,,0.9938708543777466,0.0197893120348453,0.6536397233896962,0.9869379997253418,0.045902457088232,0.2885353006882328,43793.0,0.9861144423484802,0.048769537359476,0.2781312477674411,43793.0,16820.687520742416,24056.08555865288,16820.687520742416,7231.016023159027,3.0121028423309326,0.0 -52600,0.08896674,0.023418577,,,,,,,,,,,,,,,,, -52700,0.109651186,0.025184713,,,,,,,,,,,,,,,,, -52800,0.09969417,0.023454687,,,,,,,,,,,,,,,,, -52900,0.08864315,0.019883107,,,,,,,,,,,,,,,,, -53000,0.09902724,0.023024745,,,,,,,,,,,,,,,,, -53100,0.104408816,0.025783194,,,,,,,,,,,,,,,,, -53200,0.09482157,0.02461716,,,,,,,,,,,,,,,,, -53269,,,0.9936069250106812,0.0201866496354341,0.6379870321649714,0.98707115650177,0.0463605523109436,0.2920631409937257,43793.0,0.9862328171730042,0.0491791106760501,0.2761081457071273,43793.0,17060.65802717209,24396.114936590195,17060.65802717209,7331.012581825256,3.0534260272979736,0.0 -53300,0.10829946,0.021711405,,,,,,,,,,,,,,,,, -53400,0.09417911,0.024859102,,,,,,,,,,,,,,,,, -53500,0.09130999,0.024281826,,,,,,,,,,,,,,,,, -53600,0.09201464,0.024090849,,,,,,,,,,,,,,,,, -53700,0.10322922,0.022322167,,,,,,,,,,,,,,,,, -53800,0.12375881,0.025576968,,,,,,,,,,,,,,,,, -53900,0.10302634,0.019137483,,,,,,,,,,,,,,,,, -54000,0.08487996,0.022457693,,,,,,,,,,,,,,,,, -54016,,,0.993407905101776,0.020683042705059,0.6422283249361697,0.9870346188545228,0.0463643260300159,0.2859237206899969,43793.0,0.9862012267112732,0.0492318421602249,0.2789847800402106,43793.0,17300.819898843765,24738.92270088196,17300.819898843765,7433.599872112274,3.091221570968628,0.0 -54100,0.08292383,0.02177957,,,,,,,,,,,,,,,,, -54200,0.09330321,0.02054458,,,,,,,,,,,,,,,,, -54300,0.10902893,0.02385354,,,,,,,,,,,,,,,,, -54400,0.096782066,0.023816535,,,,,,,,,,,,,,,,, -54500,0.09415442,0.021563066,,,,,,,,,,,,,,,,, -54600,0.1097123,0.022425205,,,,,,,,,,,,,,,,, -54700,0.09943381,0.02437432,,,,,,,,,,,,,,,,, -54759,,,0.993454933166504,0.0205811411142349,0.6245154233950967,0.9870610237121582,0.0465365014970302,0.2884550829211076,43793.0,0.9863005876541138,0.0494023486971855,0.2773694298836116,43793.0,17541.016096830368,25080.20478367805,17541.016096830368,7534.629323959351,3.1281652450561523,0.0 -54800,0.10202427,0.021244975,,,,,,,,,,,,,,,,, -54900,0.09135007,0.021060897,,,,,,,,,,,,,,,,, -55000,0.09850043,0.023191938,,,,,,,,,,,,,,,,, -55100,0.121660635,0.024579654,,,,,,,,,,,,,,,,, -55200,0.10799559,0.023484813,,,,,,,,,,,,,,,,, -55300,0.10158163,0.021768456,,,,,,,,,,,,,,,,, -55400,0.10036112,0.022943266,,,,,,,,,,,,,,,,, -55500,0.10245432,0.025170492,,,,,,,,,,,,,,,,, -55505,,,0.9934847950935364,0.0203968025743961,0.6366287142374494,0.9871174097061156,0.0465443395078182,0.2906898791426333,43793.0,0.98625510931015,0.0496644601225853,0.2748268458454385,43793.0,17781.16787624359,25416.691744089127,17781.16787624359,7630.906837940216,3.166072130203247,0.0 -55600,0.12821342,0.025188662,,,,,,,,,,,,,,,,, -55700,0.10205551,0.0238792,,,,,,,,,,,,,,,,, -55800,0.09336459,0.02154278,,,,,,,,,,,,,,,,, -55900,0.09547183,0.024935616,,,,,,,,,,,,,,,,, -56000,0.09831818,0.022109749,,,,,,,,,,,,,,,,, -56100,0.09710566,0.02118011,,,,,,,,,,,,,,,,, -56200,0.10358441,0.022997046,,,,,,,,,,,,,,,,, -56252,,,0.993678867816925,0.0199342537671327,0.6509323873209593,0.9869863390922546,0.0468346439301967,0.2884947556034788,43793.0,0.9861961603164672,0.0496541261672973,0.2770761467327258,43793.0,18021.238358020782,25759.857971429825,18021.238358020782,7733.94634437561,3.202991008758545,0.0 -56300,0.10021105,0.022023434,,,,,,,,,,,,,,,,, -56400,0.11610994,0.023570415,,,,,,,,,,,,,,,,, -56500,0.09536492,0.022242868,,,,,,,,,,,,,,,,, -56600,0.12068127,0.023749627,,,,,,,,,,,,,,,,, -56700,0.09877371,0.02214013,,,,,,,,,,,,,,,,, -56800,0.123088166,0.020531848,,,,,,,,,,,,,,,,, -56900,0.09906331,0.021053437,,,,,,,,,,,,,,,,, -57000,0.11889654,0.023783725,,,,,,,,,,,,,,,,, -57001,,,0.9936870336532592,0.0197696257382631,0.6579927457781373,0.9869254231452942,0.0470384806394577,0.2847933420483199,43793.0,0.9861481189727784,0.0498638562858104,0.2743496012662404,43793.0,18261.479460000992,26097.717761278152,18261.479460000992,7831.50794172287,3.240222930908203,0.0 -57100,0.11055782,0.022581603,,,,,,,,,,,,,,,,, -57200,0.12563452,0.023917533,,,,,,,,,,,,,,,,, -57300,0.10570808,0.021357091,,,,,,,,,,,,,,,,, -57400,0.10419358,0.019407764,,,,,,,,,,,,,,,,, -57500,0.10987195,0.024209907,,,,,,,,,,,,,,,,, -57600,0.110908955,0.0238439,,,,,,,,,,,,,,,,, -57678,,,,,,,,,,,,,,18477.257613420486,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index d99d9eb84..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -882.5367727279663,0.0,37.38932204246521,1,0,37.38932204246521,0.0007088489946909,0.0,10.966498374938965,3003,919.9261367321014,0.0006382566643878,0.0,10.960665702819824,0.0004835649742744,0.0,10.980294227600098,3000 -1477.5438141822815,0.0384318828582763,877.514372587204,2411,0,877.514372587204,0.3811051249504089,8.015197424590168,4.32146692276001,3003,2355.1783118247986,0.4126627445220947,14.245377447202788,4.002304553985596,0.3975647985935211,9.613389479466637,4.118062496185303,3000 -1979.8509588241573,0.0781202316284179,1717.5991599559784,4821,0,1717.5991599559784,0.5461623668670654,19.172664785009378,2.733055353164673,3003,3697.6896362304688,0.5397511124610901,24.513584478437387,2.753583669662476,0.5449033379554749,20.62834035354065,2.7072184085845947,3000 -2472.026073932648,0.1051478385925293,2557.73996424675,7234,0,2557.73996424675,0.5904944539070129,22.381906824340813,2.3054873943328857,3003,5030.110518932343,0.5832327604293823,27.40869814351565,2.348278522491455,0.5864651203155518,23.116123270683246,2.313688039779663,3000 -2925.8024010658264,0.1338064670562744,3397.9064087867737,9649,0,3397.9064087867737,0.6160014271736145,23.70429726840615,2.095056533813477,3003,6324.159275770187,0.5944892764091492,27.885071200396293,2.2307252883911133,0.608386754989624,24.95503798977198,2.125382423400879,3000 -3395.214864253998,0.15960693359375,4237.891711235046,12064,0,4237.891711235046,0.6272267699241638,24.68109616132204,1.9692918062210083,3003,7633.663153886795,0.6005798578262329,28.89981054140853,2.151870012283325,0.6196823120117188,25.802299412078856,2.009481906890869,3000 -3893.4397599697113,0.1855201721191406,5078.001572847366,14480,0,5078.001572847366,0.6392307281494141,25.5829851352678,1.868680477142334,3003,8972.102147102356,0.6170651316642761,29.54870471997391,2.030705451965332,0.630109965801239,26.213816054497475,1.917237877845764,3000 -4376.910125494003,0.2147214412689209,5918.144634008408,16895,0,5918.144634008408,0.648678183555603,26.35533227720316,1.8056501150131223,3003,10295.827911376951,0.6205698847770691,30.48051704857396,2.0030786991119385,0.6400168538093567,26.91149584465031,1.853195786476136,3000 -4865.887587070465,0.2430028915405273,6758.210176944733,19311,0,6758.210176944733,0.6544535756111145,26.59917892660872,1.754570722579956,3003,11624.979398727415,0.6361596584320068,30.9128753871064,1.8797343969345093,0.6456212401390076,27.351768905608047,1.810285568237305,3000 -5401.86420583725,0.2701001167297363,7598.422609567642,21728,0,7598.422609567642,0.2937540113925934,0.1403133577495965,4.477366924285889,3003,13001.2740752697,0.3390825092792511,0.5862118968036852,3.9086546897888175,0.3057122528553009,0.1320855284362729,4.292892932891846,3000 -5907.838281869888,0.300400972366333,8438.351751565933,24147,0,8438.351751565933,0.6543838381767273,26.747935955855187,1.7676335573196411,3003,14347.287384033203,0.6233353614807129,30.139503659176032,1.9733836650848389,0.6435753703117371,27.41255425476972,1.819201946258545,3000 -6431.037258863449,0.3333091735839844,9278.293791770937,26564,0,9278.293791770937,0.6584277749061584,26.87580244357025,1.7201982736587524,3003,15710.543276309969,0.6321825981140137,30.340070514809657,1.8988949060440063,0.648423433303833,27.29769121173301,1.778126239776611,3000 -6922.425815820694,0.3626482486724853,10118.426603794098,28982,0,10118.426603794098,0.6632153987884521,27.12441943548814,1.6976944208145142,3003,17042.172969341278,0.6355407238006592,30.347952863442696,1.885409474372864,0.6500105261802673,27.48807983750271,1.75833261013031,3000 -7458.717993736267,0.391817569732666,10958.39440727234,31400,0,10958.39440727234,0.664307713508606,27.28199048678084,1.6905139684677124,3003,18418.541491508484,0.6615469455718994,32.51895289150185,1.7062214612960815,0.6527507305145264,27.932306559997773,1.750123143196106,3000 -7950.335062503815,0.4220054149627685,11798.52327799797,33818,0,11798.52327799797,0.6667364239692688,27.411963702224377,1.6695778369903564,3003,19750.40006542205,0.636962890625,30.93731275058393,1.8799915313720703,0.6577227711677551,28.16495224589458,1.7262121438980105,3000 -8434.08610200882,0.4524543285369873,12638.686345100405,36237,0,12638.686345100405,0.6673523187637329,27.305733465104947,1.6683061122894287,3003,21074.42302322388,0.6327520608901978,30.885318130356406,1.8985131978988647,0.6562100648880005,28.097237758642365,1.7289568185806274,3000 -8890.070576429367,0.4886150360107422,13478.862303972244,38656,0,13478.862303972244,0.670001745223999,27.66217813006899,1.6540709733963013,3003,22370.699309825897,0.6423766613006592,31.39003853798642,1.8290132284164429,0.6569168567657471,27.887943009530343,1.7200067043304443,3000 -9385.98035311699,0.5256316661834717,14318.88666820526,41073,0,14318.88666820526,0.6714078187942505,27.96062907699173,1.647200584411621,3003,23706.75301337242,0.6356822848320007,31.44775272147028,1.8719481229782104,0.6580947637557983,28.085698986814226,1.707008957862854,3000 -9895.090104341509,0.5565388202667236,15158.90975689888,43492,0,15158.90975689888,0.6728022694587708,27.88330925488826,1.6397759914398191,3003,25055.99835085869,0.6358650922775269,31.489122957348748,1.8773494958877563,0.6598678231239319,28.275403261768112,1.6993831396102903,3000 -10440.49032473564,0.5880370140075684,15998.971328496931,45910,0,15998.971328496931,0.6724420785903931,28.059474468605845,1.634313702583313,3003,26441.571583509445,0.6427233219146729,31.07863303548564,1.8206764459609983,0.6608349680900574,28.36347288268632,1.692588448524475,3000 -10916.7520134449,0.619476318359375,16838.999277830124,48328,0,16838.999277830124,0.6750450730323792,28.01314804173373,1.6206564903259275,3003,27757.97251176834,0.6389437913894653,31.00010065151017,1.855763554573059,0.6613061428070068,28.22158443610769,1.6884702444076538,3000 -11382.648072242737,0.6527702808380127,17679.204163074493,50747,0,17679.204163074493,0.6755331158638,28.33245909376165,1.6183867454528809,3003,29064.18690776825,0.6508919596672058,31.78269588076056,1.7648361921310425,0.6645422577857971,28.74312417864181,1.6845265626907349,3000 -11919.171528339386,0.6851651668548584,18519.205310583115,53166,0,18519.205310583115,0.6751612424850464,28.21440380671305,1.60860013961792,3003,30440.82387685776,0.6417784690856934,31.41106280404022,1.827314376831055,0.664182722568512,28.51513503042841,1.6704254150390625,3000 -12487.166585206984,0.7213225364685059,19359.40021085739,55585,0,19359.40021085739,0.6774388551712036,28.394650410572435,1.6036652326583862,3003,31849.13062644005,0.6417572498321533,31.531738068339457,1.8483734130859373,0.665658175945282,28.607842205522875,1.668606519699097,3000 -12950.59977388382,0.7550153732299805,20199.471867084503,58004,0,20199.471867084503,0.6786706447601318,28.438258704322525,1.5954159498214722,3003,33152.7517850399,0.650646448135376,32.27619017894549,1.7800432443618774,0.6677040457725525,28.97028555750232,1.6552205085754397,3000 -13445.960545539856,0.7892227172851562,21039.458937883377,60422,0,21039.458937883377,0.6806228756904602,28.597255906283987,1.5879414081573486,3003,34488.21444058418,0.6451210379600525,31.819075921555985,1.8069937229156487,0.6668609380722046,28.74034449614497,1.6589609384536743,3000 -13930.784398078918,0.8219027519226074,21879.506196975708,62841,0,21879.506196975708,0.6802510023117065,28.31530092941621,1.5829477310180664,3003,35813.19875717163,0.6688887476921082,33.49501242969906,1.6635355949401855,0.6675428748130798,28.96451958189699,1.650787353515625,3000 -14404.023327350616,0.8547759056091309,22719.717471838,65260,0,22719.717471838,0.6811341643333435,28.40901193383648,1.5736756324768066,3003,37126.762442588806,0.6497659683227539,32.40010374364767,1.7808959484100342,0.6687827706336975,28.94475644983091,1.6441073417663574,3000 -14930.35793542862,0.8874788284301758,23559.93111491204,67679,0,23559.93111491204,0.6845738291740417,29.088031046885384,1.5635030269622805,3003,38493.42343258858,0.6502645611763,31.68374978602823,1.7860978841781616,0.6713865995407104,29.22256790706708,1.6337947845458984,3000 -15469.750158786774,0.9219081401824952,24400.08797216416,70098,0,24400.08797216416,0.6856196522712708,28.82446304234327,1.5533530712127686,3003,39873.08808875084,0.6582711338996887,32.39691829939411,1.7285667657852173,0.671584963798523,29.149120230108544,1.6267058849334717,3000 -15970.62186050415,0.957329273223877,25240.307821035385,72518,0,25240.307821035385,0.6862123012542725,28.91150751075374,1.5499813556671145,3003,41214.29425621033,0.6576229333877563,32.81365972041144,1.7471884489059448,0.6732960343360901,29.323903687794544,1.6212689876556396,3000 -16445.30373263359,0.995025396347046,26080.382111549377,74937,0,26080.382111549377,0.6883504986763,29.17476942337257,1.5352879762649536,3003,42529.167917490005,0.6538400053977966,32.41990521138852,1.7619236707687378,0.6753295063972473,29.47000517384073,1.6074609756469729,3000 -16943.978850841522,1.0299007892608645,26920.60708403588,77357,0,26920.60708403588,0.6897449493408203,29.111522492945284,1.5332320928573608,3003,43868.18283462525,0.6637595295906067,32.32256300712956,1.6998523473739624,0.6762842535972595,29.577603225878857,1.6049737930297852,3000 -17437.31713628769,1.068063259124756,27760.75708150864,79777,0,27760.75708150864,0.6903492212295532,29.494597510156364,1.5179253816604614,3003,45201.78984117508,0.6550542116165161,32.499716599105,1.7492350339889526,0.6763834357261658,29.572896222219057,1.593187689781189,3000 -17931.717635393143,1.1103150844573977,28600.833948373795,82196,0,28600.833948373795,0.6923130750656128,29.512227379029262,1.5137943029403689,3003,46536.39051222801,0.6707914471626282,33.72058966171932,1.653606414794922,0.6779953241348267,29.64467562287056,1.590328335762024,3000 -18450.751024246216,1.1467373371124268,29440.945076942444,84615,0,29440.945076942444,0.693312406539917,30.06056970667257,1.509398102760315,3003,47895.65140795708,0.6634590029716492,32.766808653370965,1.7018529176712036,0.6799171566963196,29.75577826387701,1.586037039756775,3000 -18958.97201180458,1.1856064796447754,30280.988560199738,87033,0,30280.988560199738,0.6963453888893127,29.654388366237686,1.4961185455322266,3003,49244.03613877296,0.6625377535820007,32.974561510453555,1.709887146949768,0.6813058853149414,29.83404728465848,1.572974443435669,3000 -19494.512558221817,1.2216103076934814,31121.10328722,89452,0,31121.10328722,0.6961362361907959,29.73615868519291,1.495225429534912,3003,50619.80656862259,0.6736891269683838,33.41154952869727,1.6366500854492188,0.6804007291793823,29.573600778223444,1.5719475746154783,3000 -19998.722029209137,1.26001238822937,31961.19503569603,91871,0,31961.19503569603,0.6966475248336792,29.9217194579761,1.4875853061676023,3003,51964.228222608566,0.6677938103675842,33.170331808642466,1.668520450592041,0.6823474168777466,30.001956328565168,1.5652003288269043,3000 -20481.708650112152,1.298454761505127,32801.38553190231,94290,0,32801.38553190231,0.6995874643325806,30.34351668402616,1.4738984107971191,3003,53287.52480864525,0.6874837279319763,34.74978426592824,1.560304880142212,0.684963583946228,30.26482790250689,1.5531262159347534,3000 -20996.7893345356,1.3360624313354492,33641.31292676926,96708,0,33641.31292676926,0.700888991355896,30.135616820141905,1.4674443006515503,3003,54642.65100026131,0.6748186945915222,33.85164936344683,1.623995304107666,0.6853107810020447,30.209081609458195,1.5497137308120728,3000 -21516.848722696304,1.3743152618408203,34481.37762069702,99127,0,34481.37762069702,0.7012376189231873,30.20869715529164,1.4639596939086914,3003,56002.891575336456,0.6738321185112,33.38403674379716,1.6364940404891968,0.6848272085189819,30.294506822664715,1.545107364654541,3000 -21981.911460876465,1.413140058517456,35321.28910493851,101546,0,35321.28910493851,0.7024809718132019,30.57824113953,1.4583240747451782,3003,57307.98519515991,0.6848159432411194,34.37797153427266,1.569870948791504,0.6859679222106934,30.165309046410087,1.5355159044265747,3000 -22471.217081546783,1.451249599456787,36161.35554885864,103965,0,36161.35554885864,0.7020859122276306,30.37345328936431,1.455241084098816,3003,58637.47705602646,0.6798616647720337,34.46955027904853,1.5981993675231934,0.688063383102417,30.38069848744398,1.5368680953979492,3000 -22978.96356916428,1.4913089275360107,37001.44723653793,106383,0,37001.44723653793,0.7030503749847412,30.489053354388908,1.4487695693969729,3003,59985.43886613846,0.7143574953079224,36.8438553370823,1.4251322746276855,0.6882121562957764,30.254751595029035,1.5287905931472778,3000 -23485.036183595657,1.5315580368041992,37841.80606007576,108801,0,37841.80606007576,0.7047237157821655,30.474966315456253,1.443089723587036,3003,61331.99062085152,0.6898171305656433,34.717774021813774,1.552299976348877,0.6894644498825073,30.55597763090765,1.5260159969329834,3000 -23982.11319732666,1.572606325149536,38681.82443928719,111219,0,38681.82443928719,0.7083958387374878,30.872641774580423,1.4284865856170654,3003,62669.208958387375,0.6886301040649414,34.74471014801158,1.5508441925048828,0.6914111375808716,30.64500096572725,1.5159528255462646,3000 -24451.67786312104,1.6134326457977295,39521.84886193276,113635,0,39521.84886193276,0.7083725929260254,30.95872246263002,1.428377389907837,3003,63978.92404127121,0.7003282308578491,35.50631142680764,1.4843988418579102,0.6920930743217468,30.741345891342124,1.5145164728164673,3000 -24944.99639344216,1.653876543045044,40361.91211247444,116053,0,40361.91211247444,0.7091976404190063,31.064218560260848,1.4271239042282104,3003,65312.42619681358,0.6954756379127502,35.4937397891447,1.5078685283660889,0.6918203234672546,30.846350466007053,1.5126873254776,3000 -25441.72201323509,1.702117919921875,41202.006234169006,118470,0,41202.006234169006,0.7104526162147522,30.91869931762568,1.4224282503128052,3003,66649.37670564651,0.6973755359649658,35.351594846978585,1.5098248720169067,0.691745936870575,30.6144301486764,1.5092313289642334,3000 -25943.07703590393,1.7425308227539062,42042.0838637352,120888,0,42042.0838637352,0.7091627717018127,31.02129557676758,1.4226619005203247,3003,67990.92956995964,0.7043367624282837,35.96213979511151,1.4685285091400146,0.6929238438606262,30.69302274081346,1.5071839094161987,3000 -26427.96184802056,1.7832961082458496,42882.12957930565,123305,0,42882.12957930565,0.7106502056121826,31.02842778685176,1.4200072288513184,3003,69315.98330974579,0.7063530087471008,36.09334633422933,1.4596189260482788,0.6923534870147705,30.75848752404956,1.504960298538208,3000 -26911.074389457703,1.8275110721588133,43722.28950881958,125724,0,43722.28950881958,0.710615336894989,31.023551642698948,1.4180010557174685,3003,70639.37914276123,0.7098019123077393,36.35349656081629,1.4383488893508911,0.6933577656745911,30.86917300683674,1.502547264099121,3000 -27387.63648557663,1.872190237045288,44562.2737493515,128142,0,44562.2737493515,0.7110220193862915,31.07880207453981,1.4158810377120972,3003,71956.04916667938,0.7098627090454102,36.20999485164631,1.4390041828155518,0.6937545537948608,30.961534202035583,1.502580165863037,3000 -27864.195012569427,1.916002988815308,45402.22863721848,130560,0,45402.22863721848,0.711266040802002,31.17882674781833,1.415252923965454,3003,73272.68795681,0.7109185457229614,36.323414242246834,1.434372901916504,0.6939157247543335,30.871359894567888,1.5021934509277344,3000 -28334.5637383461,1.960179090499878,46242.14085435867,132978,0,46242.14085435867,0.7113938927650452,31.23585328606663,1.4153876304626465,3003,74583.09764313698,0.7064332365989685,36.41087151297877,1.4602043628692627,0.6937917470932007,30.904665803245663,1.502265214920044,3000 -28808.747275829315,2.0037403106689453,46365.04322433472,133333,0,46365.04322433472,0.7113706469535828,31.22193609040821,1.415420651435852,3003,75180.2390575409,0.7103666067123413,36.36197193300483,1.433834433555603,0.6937669515609741,30.902044416208398,1.502273440361023,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index fd01dfe96..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.2275486,10.960339,,,,,,,,,,,,,,,,, -1,,,0.0006382566643878,10.960665702819824,0.0,0.0004835649742744,10.980294227600098,0.0,3000.0,0.0007088489946909,10.966498374938965,0.0,3003.0,37.38932204246521,919.9261367321014,37.38932204246521,882.5367727279663,0.0,0.0 -100,0.39252505,8.968057,,,,,,,,,,,,,,,,, -200,0.16082531,8.581539,,,,,,,,,,,,,,,,, -300,0.17985734,8.335307,,,,,,,,,,,,,,,,, -400,0.25205564,8.013865,,,,,,,,,,,,,,,,, -500,0.35767213,7.663035,,,,,,,,,,,,,,,,, -600,0.5109454,7.411954,,,,,,,,,,,,,,,,, -700,0.61367494,7.247362,,,,,,,,,,,,,,,,, -800,0.5663554,6.976062,,,,,,,,,,,,,,,,, -900,0.57988864,6.7435584,,,,,,,,,,,,,,,,, -1000,0.563016,6.590374,,,,,,,,,,,,,,,,, -1100,0.69635063,6.3629355,,,,,,,,,,,,,,,,, -1200,0.81312186,6.2095537,,,,,,,,,,,,,,,,, -1300,0.506547,6.1209908,,,,,,,,,,,,,,,,, -1400,0.77841777,5.9826207,,,,,,,,,,,,,,,,, -1500,0.6148735,5.8296156,,,,,,,,,,,,,,,,, -1600,0.6450021,5.666721,,,,,,,,,,,,,,,,, -1700,0.54518723,5.5673146,,,,,,,,,,,,,,,,, -1800,1.5623139,5.4891357,,,,,,,,,,,,,,,,, -1900,1.0547086,5.3503547,,,,,,,,,,,,,,,,, -2000,0.6484911,5.307024,,,,,,,,,,,,,,,,, -2100,1.1187018,5.1577477,,,,,,,,,,,,,,,,, -2200,0.64605397,5.094729,,,,,,,,,,,,,,,,, -2300,0.74373925,4.989433,,,,,,,,,,,,,,,,, -2400,1.206457,4.847176,,,,,,,,,,,,,,,,, -2411,,,0.4126627445220947,4.002304553985596,14.245377447202788,0.3975647985935211,4.118062496185303,9.613389479466637,3000.0,0.3811051249504089,4.32146692276001,8.015197424590168,3003.0,877.514372587204,2355.1783118247986,877.514372587204,1477.5438141822815,0.0384318828582763,0.0 -2500,0.80570596,4.8671355,,,,,,,,,,,,,,,,, -2600,1.0746293,4.644018,,,,,,,,,,,,,,,,, -2700,1.3878789,4.660332,,,,,,,,,,,,,,,,, -2800,1.3421057,4.6049747,,,,,,,,,,,,,,,,, -2900,0.66344064,4.4743853,,,,,,,,,,,,,,,,, -3000,0.6725298,4.434315,,,,,,,,,,,,,,,,, -3100,0.8506165,4.383146,,,,,,,,,,,,,,,,, -3200,0.7431074,4.345572,,,,,,,,,,,,,,,,, -3300,0.9651064,4.210878,,,,,,,,,,,,,,,,, -3400,0.76616174,4.2104177,,,,,,,,,,,,,,,,, -3500,0.72440964,4.1439686,,,,,,,,,,,,,,,,, -3600,0.66361034,4.1694517,,,,,,,,,,,,,,,,, -3700,0.6447309,4.0986547,,,,,,,,,,,,,,,,, -3800,0.6762318,4.060281,,,,,,,,,,,,,,,,, -3900,0.6723311,4.0118537,,,,,,,,,,,,,,,,, -4000,0.72291213,4.0280747,,,,,,,,,,,,,,,,, -4100,0.63329357,3.9788342,,,,,,,,,,,,,,,,, -4200,0.6585279,3.9316688,,,,,,,,,,,,,,,,, -4300,0.70780534,3.8942819,,,,,,,,,,,,,,,,, -4400,0.5523742,3.8605099,,,,,,,,,,,,,,,,, -4500,0.56868297,3.8099186,,,,,,,,,,,,,,,,, -4600,0.56505084,3.8154848,,,,,,,,,,,,,,,,, -4700,0.5659034,3.782592,,,,,,,,,,,,,,,,, -4800,0.5085804,3.752678,,,,,,,,,,,,,,,,, -4821,,,0.5397511124610901,2.753583669662476,24.513584478437387,0.5449033379554749,2.7072184085845947,20.62834035354065,3000.0,0.5461623668670654,2.733055353164673,19.172664785009378,3003.0,1717.5991599559784,3697.6896362304688,1717.5991599559784,1979.8509588241573,0.0781202316284179,0.0 -4900,0.5931035,3.777735,,,,,,,,,,,,,,,,, -5000,0.5547078,3.7565389,,,,,,,,,,,,,,,,, -5100,0.58252776,3.7031488,,,,,,,,,,,,,,,,, -5200,0.53691214,3.69021,,,,,,,,,,,,,,,,, -5300,0.47806272,3.6487427,,,,,,,,,,,,,,,,, -5400,0.5045373,3.6647828,,,,,,,,,,,,,,,,, -5500,0.60625297,3.6937451,,,,,,,,,,,,,,,,, -5600,0.49247485,3.6752017,,,,,,,,,,,,,,,,, -5700,0.4781562,3.6503582,,,,,,,,,,,,,,,,, -5800,0.53489184,3.6206937,,,,,,,,,,,,,,,,, -5900,0.5837455,3.6176536,,,,,,,,,,,,,,,,, -6000,0.5024573,3.591188,,,,,,,,,,,,,,,,, -6100,0.5925126,3.6254005,,,,,,,,,,,,,,,,, -6200,0.47268948,3.620698,,,,,,,,,,,,,,,,, -6300,0.43305656,3.5316153,,,,,,,,,,,,,,,,, -6400,0.46436405,3.5508022,,,,,,,,,,,,,,,,, -6500,0.42809412,3.574417,,,,,,,,,,,,,,,,, -6600,0.5007066,3.6033292,,,,,,,,,,,,,,,,, -6700,0.39962915,3.5346022,,,,,,,,,,,,,,,,, -6800,0.50777686,3.5118387,,,,,,,,,,,,,,,,, -6900,0.44078943,3.5543382,,,,,,,,,,,,,,,,, -7000,0.4331341,3.5046873,,,,,,,,,,,,,,,,, -7100,0.40931875,3.550097,,,,,,,,,,,,,,,,, -7200,0.46508738,3.4676018,,,,,,,,,,,,,,,,, -7234,,,0.5832327604293823,2.348278522491455,27.40869814351565,0.5864651203155518,2.313688039779663,23.116123270683246,3000.0,0.5904944539070129,2.3054873943328857,22.381906824340813,3003.0,2557.73996424675,5030.110518932343,2557.73996424675,2472.026073932648,0.1051478385925293,0.0 -7300,0.47227982,3.4267893,,,,,,,,,,,,,,,,, -7400,0.40985486,3.390139,,,,,,,,,,,,,,,,, -7500,0.3846465,3.4081795,,,,,,,,,,,,,,,,, -7600,0.51437664,3.38502,,,,,,,,,,,,,,,,, -7700,0.47044644,3.4234283,,,,,,,,,,,,,,,,, -7800,0.43230873,3.4490268,,,,,,,,,,,,,,,,, -7900,0.4092601,3.5037918,,,,,,,,,,,,,,,,, -8000,0.35586303,3.3813496,,,,,,,,,,,,,,,,, -8100,0.35927278,3.403218,,,,,,,,,,,,,,,,, -8200,0.34859908,3.4089365,,,,,,,,,,,,,,,,, -8300,0.41517335,3.3731618,,,,,,,,,,,,,,,,, -8400,0.37998962,3.393767,,,,,,,,,,,,,,,,, -8500,0.33201575,3.2748437,,,,,,,,,,,,,,,,, -8600,0.32286876,3.2871377,,,,,,,,,,,,,,,,, -8700,0.33565,3.3740103,,,,,,,,,,,,,,,,, -8800,0.4032893,3.294327,,,,,,,,,,,,,,,,, -8900,0.41500074,3.3203263,,,,,,,,,,,,,,,,, -9000,0.32754916,3.4274917,,,,,,,,,,,,,,,,, -9100,0.35529578,3.3640332,,,,,,,,,,,,,,,,, -9200,0.34843233,3.4304583,,,,,,,,,,,,,,,,, -9300,0.3447424,3.3211696,,,,,,,,,,,,,,,,, -9400,0.32194826,3.3775349,,,,,,,,,,,,,,,,, -9500,0.30619663,3.2913477,,,,,,,,,,,,,,,,, -9600,0.3259735,3.3017902,,,,,,,,,,,,,,,,, -9649,,,0.5944892764091492,2.2307252883911133,27.885071200396293,0.608386754989624,2.125382423400879,24.95503798977198,3000.0,0.6160014271736145,2.095056533813477,23.70429726840615,3003.0,3397.9064087867737,6324.159275770187,3397.9064087867737,2925.8024010658264,0.1338064670562744,0.0 -9700,0.30526662,3.2355876,,,,,,,,,,,,,,,,, -9800,0.2855355,3.3002598,,,,,,,,,,,,,,,,, -9900,0.30147123,3.2819252,,,,,,,,,,,,,,,,, -10000,0.3041153,3.2554853,,,,,,,,,,,,,,,,, -10100,0.3464335,3.361351,,,,,,,,,,,,,,,,, -10200,0.3137512,3.2769976,,,,,,,,,,,,,,,,, -10300,0.30692917,3.2743874,,,,,,,,,,,,,,,,, -10400,0.30265993,3.2707467,,,,,,,,,,,,,,,,, -10500,0.2777225,3.3604264,,,,,,,,,,,,,,,,, -10600,0.2896631,3.2825954,,,,,,,,,,,,,,,,, -10700,0.27028474,3.3261173,,,,,,,,,,,,,,,,, -10800,0.26410922,3.285944,,,,,,,,,,,,,,,,, -10900,0.25657257,3.2536361,,,,,,,,,,,,,,,,, -11000,0.32161313,3.3734434,,,,,,,,,,,,,,,,, -11100,0.2922208,3.1935005,,,,,,,,,,,,,,,,, -11200,0.2673738,3.2042654,,,,,,,,,,,,,,,,, -11300,0.2636149,3.257447,,,,,,,,,,,,,,,,, -11400,0.2854974,3.2332299,,,,,,,,,,,,,,,,, -11500,0.2521972,3.2548015,,,,,,,,,,,,,,,,, -11600,0.25635788,3.218306,,,,,,,,,,,,,,,,, -11700,0.28260782,3.2771666,,,,,,,,,,,,,,,,, -11800,0.2684893,3.142275,,,,,,,,,,,,,,,,, -11900,0.27251592,3.238768,,,,,,,,,,,,,,,,, -12000,0.3340705,3.2707422,,,,,,,,,,,,,,,,, -12064,,,0.6005798578262329,2.151870012283325,28.89981054140853,0.6196823120117188,2.009481906890869,25.802299412078856,3000.0,0.6272267699241638,1.9692918062210083,24.68109616132204,3003.0,4237.891711235046,7633.663153886795,4237.891711235046,3395.214864253998,0.15960693359375,0.0 -12100,0.26004812,3.1697614,,,,,,,,,,,,,,,,, -12200,0.3016056,3.2389421,,,,,,,,,,,,,,,,, -12300,0.24179186,3.153571,,,,,,,,,,,,,,,,, -12400,0.25865066,3.1942878,,,,,,,,,,,,,,,,, -12500,0.25805748,3.2817504,,,,,,,,,,,,,,,,, -12600,0.26463544,3.2369554,,,,,,,,,,,,,,,,, -12700,0.2513738,3.2638013,,,,,,,,,,,,,,,,, -12800,0.2778486,3.2301915,,,,,,,,,,,,,,,,, -12900,0.24338779,3.1387372,,,,,,,,,,,,,,,,, -13000,0.22688057,3.1480606,,,,,,,,,,,,,,,,, -13100,0.23473108,3.1933763,,,,,,,,,,,,,,,,, -13200,0.25882238,3.1974764,,,,,,,,,,,,,,,,, -13300,0.2622393,3.1675951,,,,,,,,,,,,,,,,, -13400,0.24577999,3.2129838,,,,,,,,,,,,,,,,, -13500,0.26636043,3.172494,,,,,,,,,,,,,,,,, -13600,0.26866034,3.16552,,,,,,,,,,,,,,,,, -13700,0.2696446,3.1810083,,,,,,,,,,,,,,,,, -13800,0.2514293,3.0982997,,,,,,,,,,,,,,,,, -13900,0.257217,3.2399836,,,,,,,,,,,,,,,,, -14000,0.23973295,3.1567614,,,,,,,,,,,,,,,,, -14100,0.23575175,3.168263,,,,,,,,,,,,,,,,, -14200,0.26650316,3.1678157,,,,,,,,,,,,,,,,, -14300,0.28743553,3.1001234,,,,,,,,,,,,,,,,, -14400,0.26768517,3.1544495,,,,,,,,,,,,,,,,, -14480,,,0.6170651316642761,2.030705451965332,29.54870471997391,0.630109965801239,1.917237877845764,26.213816054497475,3000.0,0.6392307281494141,1.868680477142334,25.5829851352678,3003.0,5078.001572847366,8972.102147102356,5078.001572847366,3893.4397599697113,0.1855201721191406,0.0 -14500,0.2508948,3.0909493,,,,,,,,,,,,,,,,, -14600,0.252036,3.1672091,,,,,,,,,,,,,,,,, -14700,0.24547215,3.1892257,,,,,,,,,,,,,,,,, -14800,0.28479666,3.1375935,,,,,,,,,,,,,,,,, -14900,0.26247418,3.169771,,,,,,,,,,,,,,,,, -15000,0.2727037,3.1038287,,,,,,,,,,,,,,,,, -15100,0.29804146,3.1799269,,,,,,,,,,,,,,,,, -15200,0.25690562,3.1932266,,,,,,,,,,,,,,,,, -15300,0.27002493,3.1448221,,,,,,,,,,,,,,,,, -15400,0.25288764,3.0837624,,,,,,,,,,,,,,,,, -15500,0.22979261,3.0758905,,,,,,,,,,,,,,,,, -15600,0.3434945,3.1473274,,,,,,,,,,,,,,,,, -15700,0.36056468,3.1736257,,,,,,,,,,,,,,,,, -15800,0.26774773,3.0975697,,,,,,,,,,,,,,,,, -15900,0.2920781,3.1415484,,,,,,,,,,,,,,,,, -16000,0.29883787,3.0200045,,,,,,,,,,,,,,,,, -16100,0.28421187,3.1314085,,,,,,,,,,,,,,,,, -16200,0.28039205,3.0481339,,,,,,,,,,,,,,,,, -16300,0.25814033,3.0697863,,,,,,,,,,,,,,,,, -16400,0.27843237,3.047747,,,,,,,,,,,,,,,,, -16500,0.29799268,3.1566927,,,,,,,,,,,,,,,,, -16600,0.2867805,3.0136704,,,,,,,,,,,,,,,,, -16700,0.3507007,3.1464489,,,,,,,,,,,,,,,,, -16800,0.2950353,3.0503774,,,,,,,,,,,,,,,,, -16895,,,0.6205698847770691,2.0030786991119385,30.48051704857396,0.6400168538093567,1.853195786476136,26.91149584465031,3000.0,0.648678183555603,1.8056501150131223,26.35533227720316,3003.0,5918.144634008408,10295.827911376951,5918.144634008408,4376.910125494003,0.2147214412689209,0.0 -16900,0.2952534,3.1278255,,,,,,,,,,,,,,,,, -17000,0.27573937,3.2066557,,,,,,,,,,,,,,,,, -17100,0.27843368,3.1288633,,,,,,,,,,,,,,,,, -17200,0.2809365,3.0417185,,,,,,,,,,,,,,,,, -17300,0.3298672,3.1133444,,,,,,,,,,,,,,,,, -17400,0.27664527,3.047349,,,,,,,,,,,,,,,,, -17500,0.2868167,3.0897405,,,,,,,,,,,,,,,,, -17600,0.299274,3.117866,,,,,,,,,,,,,,,,, -17700,0.29290187,3.0494232,,,,,,,,,,,,,,,,, -17800,0.30405787,3.0771732,,,,,,,,,,,,,,,,, -17900,0.31429762,3.0552914,,,,,,,,,,,,,,,,, -18000,0.3208718,3.0769658,,,,,,,,,,,,,,,,, -18100,0.27747634,3.0225139,,,,,,,,,,,,,,,,, -18200,0.3319313,3.1616778,,,,,,,,,,,,,,,,, -18300,0.27875847,3.0043569,,,,,,,,,,,,,,,,, -18400,0.3228637,3.0773296,,,,,,,,,,,,,,,,, -18500,0.32039452,3.0780365,,,,,,,,,,,,,,,,, -18600,0.31622478,3.0876026,,,,,,,,,,,,,,,,, -18700,0.35858664,3.0108562,,,,,,,,,,,,,,,,, -18800,0.3348845,3.093548,,,,,,,,,,,,,,,,, -18900,0.32934517,2.989937,,,,,,,,,,,,,,,,, -19000,0.38035697,3.0012436,,,,,,,,,,,,,,,,, -19100,0.37286696,3.0697687,,,,,,,,,,,,,,,,, -19200,0.31860572,3.0394602,,,,,,,,,,,,,,,,, -19300,0.34063157,2.9742846,,,,,,,,,,,,,,,,, -19311,,,0.6361596584320068,1.8797343969345093,30.9128753871064,0.6456212401390076,1.810285568237305,27.351768905608047,3000.0,0.6544535756111145,1.754570722579956,26.59917892660872,3003.0,6758.210176944733,11624.979398727415,6758.210176944733,4865.887587070465,0.2430028915405273,0.0 -19400,0.36272988,3.0846658,,,,,,,,,,,,,,,,, -19500,0.38342777,2.9784408,,,,,,,,,,,,,,,,, -19600,0.33376572,3.1306052,,,,,,,,,,,,,,,,, -19700,0.35111,3.0518086,,,,,,,,,,,,,,,,, -19800,0.36861223,3.049944,,,,,,,,,,,,,,,,, -19900,0.36824968,3.0933762,,,,,,,,,,,,,,,,, -20000,0.31248578,3.0418282,,,,,,,,,,,,,,,,, -20100,0.3291934,2.9727619,,,,,,,,,,,,,,,,, -20200,0.40061858,2.998131,,,,,,,,,,,,,,,,, -20300,0.42912862,3.0704434,,,,,,,,,,,,,,,,, -20400,0.32168666,3.1021729,,,,,,,,,,,,,,,,, -20500,0.32056653,3.0164683,,,,,,,,,,,,,,,,, -20600,0.3140387,2.9405885,,,,,,,,,,,,,,,,, -20700,0.37227204,2.9591892,,,,,,,,,,,,,,,,, -20800,0.3093337,3.0375905,,,,,,,,,,,,,,,,, -20900,0.3324505,3.0079172,,,,,,,,,,,,,,,,, -21000,0.33961037,3.0201252,,,,,,,,,,,,,,,,, -21100,0.42910627,3.0412626,,,,,,,,,,,,,,,,, -21200,0.33417663,3.0664868,,,,,,,,,,,,,,,,, -21300,0.5206797,3.080964,,,,,,,,,,,,,,,,, -21400,2.567396,5.064516,,,,,,,,,,,,,,,,, -21500,0.8150995,4.869471,,,,,,,,,,,,,,,,, -21600,0.34181827,4.772241,,,,,,,,,,,,,,,,, -21700,0.5021196,4.753515,,,,,,,,,,,,,,,,, -21728,,,0.3390825092792511,3.9086546897888175,0.5862118968036852,0.3057122528553009,4.292892932891846,0.1320855284362729,3000.0,0.2937540113925934,4.477366924285889,0.1403133577495965,3003.0,7598.422609567642,13001.2740752697,7598.422609567642,5401.86420583725,0.2701001167297363,0.0 -21800,0.44276774,4.7420173,,,,,,,,,,,,,,,,, -21900,0.53113693,4.7302613,,,,,,,,,,,,,,,,, -22000,0.7017814,4.792573,,,,,,,,,,,,,,,,, -22100,0.56494755,4.702534,,,,,,,,,,,,,,,,, -22200,0.90934175,4.7186522,,,,,,,,,,,,,,,,, -22300,0.66424495,4.624328,,,,,,,,,,,,,,,,, -22400,2.4120624,4.666037,,,,,,,,,,,,,,,,, -22500,3.99696,4.647831,,,,,,,,,,,,,,,,, -22600,3.0438147,4.586124,,,,,,,,,,,,,,,,, -22700,1.9158037,4.361886,,,,,,,,,,,,,,,,, -22800,0.6807436,3.3157945,,,,,,,,,,,,,,,,, -22900,0.3927096,3.1905997,,,,,,,,,,,,,,,,, -23000,0.44743654,3.1031675,,,,,,,,,,,,,,,,, -23100,0.3451953,3.0813243,,,,,,,,,,,,,,,,, -23200,0.32571808,3.0965507,,,,,,,,,,,,,,,,, -23300,0.3244751,3.063844,,,,,,,,,,,,,,,,, -23400,0.33577463,3.1138813,,,,,,,,,,,,,,,,, -23500,0.32787293,3.026123,,,,,,,,,,,,,,,,, -23600,0.36644077,3.1123438,,,,,,,,,,,,,,,,, -23700,0.428848,3.0261173,,,,,,,,,,,,,,,,, -23800,0.3565795,3.0322511,,,,,,,,,,,,,,,,, -23900,0.45167664,3.073529,,,,,,,,,,,,,,,,, -24000,0.3207673,3.037026,,,,,,,,,,,,,,,,, -24100,0.3591559,3.04165,,,,,,,,,,,,,,,,, -24147,,,0.6233353614807129,1.9733836650848389,30.139503659176032,0.6435753703117371,1.819201946258545,27.41255425476972,3000.0,0.6543838381767273,1.7676335573196411,26.747935955855187,3003.0,8438.351751565933,14347.287384033203,8438.351751565933,5907.838281869888,0.300400972366333,0.0 -24200,0.37468415,3.036394,,,,,,,,,,,,,,,,, -24300,0.3016356,3.048237,,,,,,,,,,,,,,,,, -24400,0.5027949,2.9839883,,,,,,,,,,,,,,,,, -24500,0.36979154,2.979511,,,,,,,,,,,,,,,,, -24600,0.32602066,3.0649362,,,,,,,,,,,,,,,,, -24700,0.37964907,2.9969263,,,,,,,,,,,,,,,,, -24800,0.33987737,3.0049894,,,,,,,,,,,,,,,,, -24900,0.33071122,3.0664191,,,,,,,,,,,,,,,,, -25000,0.37388697,2.9586532,,,,,,,,,,,,,,,,, -25100,0.41807616,3.053597,,,,,,,,,,,,,,,,, -25200,0.39376402,3.0844698,,,,,,,,,,,,,,,,, -25300,0.37806153,3.0000396,,,,,,,,,,,,,,,,, -25400,0.33637503,2.9799252,,,,,,,,,,,,,,,,, -25500,0.39228618,2.9630642,,,,,,,,,,,,,,,,, -25600,0.34207532,2.9999986,,,,,,,,,,,,,,,,, -25700,0.41220427,3.0029926,,,,,,,,,,,,,,,,, -25800,0.33956975,3.063231,,,,,,,,,,,,,,,,, -25900,0.42858353,2.960474,,,,,,,,,,,,,,,,, -26000,0.39862347,3.0237856,,,,,,,,,,,,,,,,, -26100,0.4234204,3.047319,,,,,,,,,,,,,,,,, -26200,0.33556965,3.054041,,,,,,,,,,,,,,,,, -26300,0.35004964,3.0262918,,,,,,,,,,,,,,,,, -26400,0.40320987,3.0083587,,,,,,,,,,,,,,,,, -26500,0.34093708,3.0209637,,,,,,,,,,,,,,,,, -26564,,,0.6321825981140137,1.8988949060440063,30.340070514809657,0.648423433303833,1.778126239776611,27.29769121173301,3000.0,0.6584277749061584,1.7201982736587524,26.87580244357025,3003.0,9278.293791770937,15710.543276309969,9278.293791770937,6431.037258863449,0.3333091735839844,0.0 -26600,0.41355857,3.0210655,,,,,,,,,,,,,,,,, -26700,0.38225353,2.9592807,,,,,,,,,,,,,,,,, -26800,0.38811776,3.0124724,,,,,,,,,,,,,,,,, -26900,0.46816015,3.0154657,,,,,,,,,,,,,,,,, -27000,0.382269,3.0246072,,,,,,,,,,,,,,,,, -27100,0.45034072,2.9664664,,,,,,,,,,,,,,,,, -27200,0.36564043,3.0368474,,,,,,,,,,,,,,,,, -27300,0.3784797,3.039443,,,,,,,,,,,,,,,,, -27400,0.3323137,2.9809513,,,,,,,,,,,,,,,,, -27500,0.42185304,3.0788648,,,,,,,,,,,,,,,,, -27600,0.36728743,3.037404,,,,,,,,,,,,,,,,, -27700,0.35497686,3.076834,,,,,,,,,,,,,,,,, -27800,0.4233005,3.0033784,,,,,,,,,,,,,,,,, -27900,0.34448466,3.0256662,,,,,,,,,,,,,,,,, -28000,0.41609237,3.048154,,,,,,,,,,,,,,,,, -28100,0.4144877,3.0455956,,,,,,,,,,,,,,,,, -28200,0.47342217,2.9484801,,,,,,,,,,,,,,,,, -28300,0.3487746,2.9355395,,,,,,,,,,,,,,,,, -28400,0.45053625,3.0040677,,,,,,,,,,,,,,,,, -28500,0.40586036,3.0423613,,,,,,,,,,,,,,,,, -28600,0.3361354,2.9363272,,,,,,,,,,,,,,,,, -28700,0.36852863,3.037718,,,,,,,,,,,,,,,,, -28800,0.36396188,2.9921672,,,,,,,,,,,,,,,,, -28900,0.3689911,3.0351171,,,,,,,,,,,,,,,,, -28982,,,0.6355407238006592,1.885409474372864,30.347952863442696,0.6500105261802673,1.75833261013031,27.48807983750271,3000.0,0.6632153987884521,1.6976944208145142,27.12441943548814,3003.0,10118.426603794098,17042.172969341278,10118.426603794098,6922.425815820694,0.3626482486724853,0.0 -29000,0.4453941,3.0437183,,,,,,,,,,,,,,,,, -29100,0.39803857,2.9923873,,,,,,,,,,,,,,,,, -29200,0.39048332,3.0154102,,,,,,,,,,,,,,,,, -29300,0.35899,3.0343812,,,,,,,,,,,,,,,,, -29400,0.36034653,3.0070374,,,,,,,,,,,,,,,,, -29500,0.31119847,2.9052393,,,,,,,,,,,,,,,,, -29600,0.33512396,2.9432275,,,,,,,,,,,,,,,,, -29700,0.5207687,3.0543468,,,,,,,,,,,,,,,,, -29800,0.4195082,3.0103295,,,,,,,,,,,,,,,,, -29900,0.34930068,2.9986877,,,,,,,,,,,,,,,,, -30000,0.369341,2.9903696,,,,,,,,,,,,,,,,, -30100,0.36686212,3.0013652,,,,,,,,,,,,,,,,, -30200,0.3836494,2.9958084,,,,,,,,,,,,,,,,, -30300,0.3493569,2.9623034,,,,,,,,,,,,,,,,, -30400,0.42953354,2.9639432,,,,,,,,,,,,,,,,, -30500,0.3893368,2.9732265,,,,,,,,,,,,,,,,, -30600,0.47937256,3.0080683,,,,,,,,,,,,,,,,, -30700,0.34282434,2.9780397,,,,,,,,,,,,,,,,, -30800,0.3654103,2.944427,,,,,,,,,,,,,,,,, -30900,0.37816897,2.9478462,,,,,,,,,,,,,,,,, -31000,0.3273806,2.9209864,,,,,,,,,,,,,,,,, -31100,0.33217755,2.9722672,,,,,,,,,,,,,,,,, -31200,0.3597456,3.0173228,,,,,,,,,,,,,,,,, -31300,0.89186287,3.042693,,,,,,,,,,,,,,,,, -31400,,,0.6615469455718994,1.7062214612960815,32.51895289150185,0.6527507305145264,1.750123143196106,27.932306559997773,3000.0,0.664307713508606,1.6905139684677124,27.28199048678084,3003.0,10958.39440727234,18418.541491508484,10958.39440727234,7458.717993736267,0.391817569732666,0.0 -31400,0.41896126,3.0322237,,,,,,,,,,,,,,,,, -31500,1.2113571,3.044215,,,,,,,,,,,,,,,,, -31600,1.389527,2.9660642,,,,,,,,,,,,,,,,, -31700,0.49283943,3.0443752,,,,,,,,,,,,,,,,, -31800,0.35557732,3.0330923,,,,,,,,,,,,,,,,, -31900,0.35186628,3.0271485,,,,,,,,,,,,,,,,, -32000,0.39408037,3.007642,,,,,,,,,,,,,,,,, -32100,0.35631928,3.005756,,,,,,,,,,,,,,,,, -32200,4.066105,3.007328,,,,,,,,,,,,,,,,, -32300,0.3317078,2.9187489,,,,,,,,,,,,,,,,, -32400,0.3765269,3.0152676,,,,,,,,,,,,,,,,, -32500,0.3748003,2.8652828,,,,,,,,,,,,,,,,, -32600,0.39907745,2.9874737,,,,,,,,,,,,,,,,, -32700,0.3710107,3.007333,,,,,,,,,,,,,,,,, -32800,0.38980186,3.0760818,,,,,,,,,,,,,,,,, -32900,0.37257397,2.9546838,,,,,,,,,,,,,,,,, -33000,0.4190247,2.9232461,,,,,,,,,,,,,,,,, -33100,0.34663337,2.995711,,,,,,,,,,,,,,,,, -33200,0.37550962,2.988586,,,,,,,,,,,,,,,,, -33300,0.34254307,2.8929505,,,,,,,,,,,,,,,,, -33400,0.45669013,3.0269058,,,,,,,,,,,,,,,,, -33500,0.36950487,2.9735847,,,,,,,,,,,,,,,,, -33600,0.45350942,3.0310843,,,,,,,,,,,,,,,,, -33700,0.37483102,3.0220199,,,,,,,,,,,,,,,,, -33800,0.37186837,2.9675908,,,,,,,,,,,,,,,,, -33818,,,0.636962890625,1.8799915313720703,30.93731275058393,0.6577227711677551,1.7262121438980105,28.16495224589458,3000.0,0.6667364239692688,1.6695778369903564,27.411963702224377,3003.0,11798.52327799797,19750.40006542205,11798.52327799797,7950.335062503815,0.4220054149627685,0.0 -33900,0.47494128,3.0033529,,,,,,,,,,,,,,,,, -34000,0.35574296,2.9691985,,,,,,,,,,,,,,,,, -34100,0.8701068,3.2164388,,,,,,,,,,,,,,,,, -34200,0.41139045,3.0251472,,,,,,,,,,,,,,,,, -34300,0.39856124,2.9993174,,,,,,,,,,,,,,,,, -34400,0.34915334,2.9370368,,,,,,,,,,,,,,,,, -34500,0.33753312,2.956266,,,,,,,,,,,,,,,,, -34600,0.3760323,2.9662743,,,,,,,,,,,,,,,,, -34700,0.37138158,2.9510338,,,,,,,,,,,,,,,,, -34800,0.3603373,2.8564167,,,,,,,,,,,,,,,,, -34900,0.33539745,2.9603262,,,,,,,,,,,,,,,,, -35000,0.3390265,2.92343,,,,,,,,,,,,,,,,, -35100,0.34349945,2.9440646,,,,,,,,,,,,,,,,, -35200,0.360997,2.9621453,,,,,,,,,,,,,,,,, -35300,0.3145128,2.895585,,,,,,,,,,,,,,,,, -35400,0.33623987,2.9882925,,,,,,,,,,,,,,,,, -35500,0.35616094,2.9899988,,,,,,,,,,,,,,,,, -35600,0.38244605,2.9672983,,,,,,,,,,,,,,,,, -35700,0.36547977,2.8926055,,,,,,,,,,,,,,,,, -35800,0.37984452,3.003269,,,,,,,,,,,,,,,,, -35900,0.42575708,3.0290349,,,,,,,,,,,,,,,,, -36000,0.38803628,2.9183085,,,,,,,,,,,,,,,,, -36100,0.33665082,2.948947,,,,,,,,,,,,,,,,, -36200,0.34136513,2.9552891,,,,,,,,,,,,,,,,, -36237,,,0.6327520608901978,1.8985131978988647,30.885318130356406,0.6562100648880005,1.7289568185806274,28.097237758642365,3000.0,0.6673523187637329,1.6683061122894287,27.305733465104947,3003.0,12638.686345100405,21074.42302322388,12638.686345100405,8434.08610200882,0.4524543285369873,0.0 -36300,0.3764311,2.941353,,,,,,,,,,,,,,,,, -36400,0.41088915,2.94111,,,,,,,,,,,,,,,,, -36500,0.38095197,2.9245174,,,,,,,,,,,,,,,,, -36600,0.35701516,2.914738,,,,,,,,,,,,,,,,, -36700,0.3566676,2.9084277,,,,,,,,,,,,,,,,, -36800,0.4127535,2.9761865,,,,,,,,,,,,,,,,, -36900,0.35156977,3.027545,,,,,,,,,,,,,,,,, -37000,0.34943998,2.9389265,,,,,,,,,,,,,,,,, -37100,0.37536258,2.9884229,,,,,,,,,,,,,,,,, -37200,0.3749075,2.9481602,,,,,,,,,,,,,,,,, -37300,0.40729263,3.0190492,,,,,,,,,,,,,,,,, -37400,0.36025533,2.9480326,,,,,,,,,,,,,,,,, -37500,0.40466073,2.9649754,,,,,,,,,,,,,,,,, -37600,0.34844863,2.9750416,,,,,,,,,,,,,,,,, -37700,0.35564345,2.914898,,,,,,,,,,,,,,,,, -37800,0.34621772,2.8981767,,,,,,,,,,,,,,,,, -37900,0.41134045,3.0387776,,,,,,,,,,,,,,,,, -38000,0.3991624,2.9491606,,,,,,,,,,,,,,,,, -38100,0.3613168,2.9911916,,,,,,,,,,,,,,,,, -38200,0.34250784,2.896448,,,,,,,,,,,,,,,,, -38300,0.35397023,2.9053025,,,,,,,,,,,,,,,,, -38400,0.34221587,2.973788,,,,,,,,,,,,,,,,, -38500,0.3472401,2.9483979,,,,,,,,,,,,,,,,, -38600,0.32592335,2.9039838,,,,,,,,,,,,,,,,, -38656,,,0.6423766613006592,1.8290132284164429,31.39003853798642,0.6569168567657471,1.7200067043304443,27.887943009530343,3000.0,0.670001745223999,1.6540709733963013,27.66217813006899,3003.0,13478.862303972244,22370.699309825897,13478.862303972244,8890.070576429367,0.4886150360107422,0.0 -38700,0.33664474,2.923365,,,,,,,,,,,,,,,,, -38800,0.3593855,2.9904857,,,,,,,,,,,,,,,,, -38900,0.35490623,2.9373407,,,,,,,,,,,,,,,,, -39000,0.3992503,2.9762669,,,,,,,,,,,,,,,,, -39100,0.36971027,2.9257736,,,,,,,,,,,,,,,,, -39200,0.38146165,2.9910374,,,,,,,,,,,,,,,,, -39300,0.35041383,2.9098847,,,,,,,,,,,,,,,,, -39400,0.46231106,2.9331346,,,,,,,,,,,,,,,,, -39500,0.35580423,2.9702504,,,,,,,,,,,,,,,,, -39600,0.35237005,2.9464982,,,,,,,,,,,,,,,,, -39700,0.36836138,3.002436,,,,,,,,,,,,,,,,, -39800,0.34109807,2.9988842,,,,,,,,,,,,,,,,, -39900,0.38898787,2.9760294,,,,,,,,,,,,,,,,, -40000,0.4247431,2.9389837,,,,,,,,,,,,,,,,, -40100,0.35976243,2.9229698,,,,,,,,,,,,,,,,, -40200,0.41500235,2.9943101,,,,,,,,,,,,,,,,, -40300,0.32959148,2.9095368,,,,,,,,,,,,,,,,, -40400,0.3407055,2.9712863,,,,,,,,,,,,,,,,, -40500,0.4311979,2.9962482,,,,,,,,,,,,,,,,, -40600,0.3942568,2.9402113,,,,,,,,,,,,,,,,, -40700,0.3227648,2.953249,,,,,,,,,,,,,,,,, -40800,0.34145752,2.9676685,,,,,,,,,,,,,,,,, -40900,0.35137096,3.0267942,,,,,,,,,,,,,,,,, -41000,0.35481274,2.946371,,,,,,,,,,,,,,,,, -41073,,,0.6356822848320007,1.8719481229782104,31.44775272147028,0.6580947637557983,1.707008957862854,28.085698986814226,3000.0,0.6714078187942505,1.647200584411621,27.96062907699173,3003.0,14318.88666820526,23706.75301337242,14318.88666820526,9385.98035311699,0.5256316661834717,0.0 -41100,0.408555,2.9318962,,,,,,,,,,,,,,,,, -41200,0.37623206,2.9764187,,,,,,,,,,,,,,,,, -41300,0.3620116,2.9762583,,,,,,,,,,,,,,,,, -41400,0.3396834,2.9698536,,,,,,,,,,,,,,,,, -41500,0.4208476,2.899035,,,,,,,,,,,,,,,,, -41600,0.3613057,3.0465963,,,,,,,,,,,,,,,,, -41700,0.38585532,2.9464414,,,,,,,,,,,,,,,,, -41800,0.3857099,2.9725082,,,,,,,,,,,,,,,,, -41900,0.37178555,2.933661,,,,,,,,,,,,,,,,, -42000,0.43209535,2.9645143,,,,,,,,,,,,,,,,, -42100,0.40119877,2.907324,,,,,,,,,,,,,,,,, -42200,0.37337065,2.920075,,,,,,,,,,,,,,,,, -42300,0.38527122,2.9096568,,,,,,,,,,,,,,,,, -42400,0.5410582,2.9294689,,,,,,,,,,,,,,,,, -42500,0.37438536,2.944049,,,,,,,,,,,,,,,,, -42600,0.34974438,2.9708018,,,,,,,,,,,,,,,,, -42700,0.4820144,2.9886608,,,,,,,,,,,,,,,,, -42800,0.42206955,2.9011,,,,,,,,,,,,,,,,, -42900,0.3707062,2.9420323,,,,,,,,,,,,,,,,, -43000,0.37003013,2.9206834,,,,,,,,,,,,,,,,, -43100,0.3743864,2.9497998,,,,,,,,,,,,,,,,, -43200,0.3537571,2.9321933,,,,,,,,,,,,,,,,, -43300,0.3776133,3.01656,,,,,,,,,,,,,,,,, -43400,0.36322147,2.9257689,,,,,,,,,,,,,,,,, -43492,,,0.6358650922775269,1.8773494958877563,31.489122957348748,0.6598678231239319,1.6993831396102903,28.275403261768112,3000.0,0.6728022694587708,1.6397759914398191,27.88330925488826,3003.0,15158.90975689888,25055.99835085869,15158.90975689888,9895.090104341509,0.5565388202667236,0.0 -43500,0.35627753,2.890202,,,,,,,,,,,,,,,,, -43600,0.34917063,2.9406338,,,,,,,,,,,,,,,,, -43700,0.34805816,2.9624994,,,,,,,,,,,,,,,,, -43800,0.38539198,3.0430233,,,,,,,,,,,,,,,,, -43900,0.38057223,2.9468324,,,,,,,,,,,,,,,,, -44000,0.37844127,2.9442682,,,,,,,,,,,,,,,,, -44100,0.3344037,2.9830623,,,,,,,,,,,,,,,,, -44200,0.400055,2.9622345,,,,,,,,,,,,,,,,, -44300,0.38872978,2.9807944,,,,,,,,,,,,,,,,, -44400,0.33371013,2.8661156,,,,,,,,,,,,,,,,, -44500,0.4160336,2.9576852,,,,,,,,,,,,,,,,, -44600,0.40135047,2.8529387,,,,,,,,,,,,,,,,, -44700,0.41560766,2.9633698,,,,,,,,,,,,,,,,, -44800,0.36001313,2.9344218,,,,,,,,,,,,,,,,, -44900,0.37250733,3.0527244,,,,,,,,,,,,,,,,, -45000,0.3914002,2.9231126,,,,,,,,,,,,,,,,, -45100,0.3825643,2.9285166,,,,,,,,,,,,,,,,, -45200,0.3376761,2.894836,,,,,,,,,,,,,,,,, -45300,0.34300447,2.9096587,,,,,,,,,,,,,,,,, -45400,0.4068734,2.9956539,,,,,,,,,,,,,,,,, -45500,0.36892727,2.971465,,,,,,,,,,,,,,,,, -45600,0.33560398,2.8997428,,,,,,,,,,,,,,,,, -45700,0.43799803,2.900368,,,,,,,,,,,,,,,,, -45800,0.45088693,2.8486187,,,,,,,,,,,,,,,,, -45900,0.3690504,2.936345,,,,,,,,,,,,,,,,, -45910,,,0.6427233219146729,1.8206764459609983,31.07863303548564,0.6608349680900574,1.692588448524475,28.36347288268632,3000.0,0.6724420785903931,1.634313702583313,28.059474468605845,3003.0,15998.971328496931,26441.571583509445,15998.971328496931,10440.49032473564,0.5880370140075684,0.0 -46000,0.36913362,2.9945111,,,,,,,,,,,,,,,,, -46100,0.39890665,2.9042513,,,,,,,,,,,,,,,,, -46200,0.35582182,2.8889148,,,,,,,,,,,,,,,,, -46300,0.3618282,2.98022,,,,,,,,,,,,,,,,, -46400,0.3745313,2.9112508,,,,,,,,,,,,,,,,, -46500,0.32912406,2.9416432,,,,,,,,,,,,,,,,, -46600,0.37712443,2.938649,,,,,,,,,,,,,,,,, -46700,0.41174385,2.9281406,,,,,,,,,,,,,,,,, -46800,0.4096847,2.9691124,,,,,,,,,,,,,,,,, -46900,0.44483438,2.953753,,,,,,,,,,,,,,,,, -47000,0.3276072,2.901253,,,,,,,,,,,,,,,,, -47100,0.41656852,2.9044707,,,,,,,,,,,,,,,,, -47200,0.35255054,2.864084,,,,,,,,,,,,,,,,, -47300,0.34089136,2.869347,,,,,,,,,,,,,,,,, -47400,0.3873029,2.939524,,,,,,,,,,,,,,,,, -47500,0.36467463,2.9469483,,,,,,,,,,,,,,,,, -47600,0.35648724,2.8692322,,,,,,,,,,,,,,,,, -47700,0.40069717,2.9781609,,,,,,,,,,,,,,,,, -47800,0.36461318,2.9634569,,,,,,,,,,,,,,,,, -47900,0.3584463,3.0098484,,,,,,,,,,,,,,,,, -48000,0.40259627,2.8837624,,,,,,,,,,,,,,,,, -48100,0.34112868,2.9109948,,,,,,,,,,,,,,,,, -48200,0.40350685,2.887527,,,,,,,,,,,,,,,,, -48300,0.35924348,2.9829187,,,,,,,,,,,,,,,,, -48328,,,0.6389437913894653,1.855763554573059,31.00010065151017,0.6613061428070068,1.6884702444076538,28.22158443610769,3000.0,0.6750450730323792,1.6206564903259275,28.01314804173373,3003.0,16838.999277830124,27757.97251176834,16838.999277830124,10916.7520134449,0.619476318359375,0.0 -48400,0.3565581,2.898013,,,,,,,,,,,,,,,,, -48500,0.3903272,3.012882,,,,,,,,,,,,,,,,, -48600,0.37717932,2.8810577,,,,,,,,,,,,,,,,, -48700,0.35581762,2.9202547,,,,,,,,,,,,,,,,, -48800,0.35128152,2.891218,,,,,,,,,,,,,,,,, -48900,0.3755036,2.9490435,,,,,,,,,,,,,,,,, -49000,0.36511794,2.9485495,,,,,,,,,,,,,,,,, -49100,0.38544834,2.9035077,,,,,,,,,,,,,,,,, -49200,0.4071739,2.9654856,,,,,,,,,,,,,,,,, -49300,0.34587,2.9392154,,,,,,,,,,,,,,,,, -49400,0.37146696,2.8889108,,,,,,,,,,,,,,,,, -49500,0.3854153,2.9535615,,,,,,,,,,,,,,,,, -49600,0.39610285,2.941235,,,,,,,,,,,,,,,,, -49700,0.40639138,2.898928,,,,,,,,,,,,,,,,, -49800,0.34917453,2.8622034,,,,,,,,,,,,,,,,, -49900,0.34274462,2.8955646,,,,,,,,,,,,,,,,, -50000,0.3283139,2.9397693,,,,,,,,,,,,,,,,, -50100,0.34484804,2.883262,,,,,,,,,,,,,,,,, -50200,0.36407572,2.8897018,,,,,,,,,,,,,,,,, -50300,0.3716517,2.9235053,,,,,,,,,,,,,,,,, -50400,0.34622818,2.955901,,,,,,,,,,,,,,,,, -50500,0.3781707,2.966953,,,,,,,,,,,,,,,,, -50600,0.35802892,2.9517403,,,,,,,,,,,,,,,,, -50700,0.35865745,2.9703963,,,,,,,,,,,,,,,,, -50747,,,0.6508919596672058,1.7648361921310425,31.78269588076056,0.6645422577857971,1.6845265626907349,28.74312417864181,3000.0,0.6755331158638,1.6183867454528809,28.33245909376165,3003.0,17679.204163074493,29064.18690776825,17679.204163074493,11382.648072242737,0.6527702808380127,0.0 -50800,0.48587307,2.9035532,,,,,,,,,,,,,,,,, -50900,0.36510164,2.9568803,,,,,,,,,,,,,,,,, -51000,0.34822395,2.8831193,,,,,,,,,,,,,,,,, -51100,0.35810673,2.927562,,,,,,,,,,,,,,,,, -51200,0.3602574,2.9178574,,,,,,,,,,,,,,,,, -51300,0.3771359,2.9802508,,,,,,,,,,,,,,,,, -51400,0.38531047,2.9135725,,,,,,,,,,,,,,,,, -51500,0.36876845,2.9140651,,,,,,,,,,,,,,,,, -51600,0.375818,2.8914616,,,,,,,,,,,,,,,,, -51700,0.34088147,2.9871967,,,,,,,,,,,,,,,,, -51800,0.345872,2.949608,,,,,,,,,,,,,,,,, -51900,0.33983478,2.8505526,,,,,,,,,,,,,,,,, -52000,0.3839552,2.9751954,,,,,,,,,,,,,,,,, -52100,0.38361356,2.9388382,,,,,,,,,,,,,,,,, -52200,0.34139758,2.8975546,,,,,,,,,,,,,,,,, -52300,0.36431658,2.8915284,,,,,,,,,,,,,,,,, -52400,0.35901442,2.9603071,,,,,,,,,,,,,,,,, -52500,0.3750785,2.8634586,,,,,,,,,,,,,,,,, -52600,0.3319995,2.8874087,,,,,,,,,,,,,,,,, -52700,0.40362865,2.8496358,,,,,,,,,,,,,,,,, -52800,0.37324667,2.9199607,,,,,,,,,,,,,,,,, -52900,0.36463815,2.9629993,,,,,,,,,,,,,,,,, -53000,0.36293408,2.8607042,,,,,,,,,,,,,,,,, -53100,0.36391503,2.9449358,,,,,,,,,,,,,,,,, -53166,,,0.6417784690856934,1.827314376831055,31.41106280404022,0.664182722568512,1.6704254150390625,28.51513503042841,3000.0,0.6751612424850464,1.60860013961792,28.21440380671305,3003.0,18519.205310583115,30440.82387685776,18519.205310583115,11919.171528339386,0.6851651668548584,0.0 -53200,0.4526058,2.938611,,,,,,,,,,,,,,,,, -53300,0.37590814,2.8757918,,,,,,,,,,,,,,,,, -53400,0.8606134,2.9188833,,,,,,,,,,,,,,,,, -53500,0.32981452,2.8488255,,,,,,,,,,,,,,,,, -53600,0.39117345,2.9218574,,,,,,,,,,,,,,,,, -53700,0.36651254,2.8157883,,,,,,,,,,,,,,,,, -53800,0.43199328,2.9720373,,,,,,,,,,,,,,,,, -53900,0.37700954,2.9414132,,,,,,,,,,,,,,,,, -54000,0.3577861,2.917058,,,,,,,,,,,,,,,,, -54100,0.3492389,2.9437644,,,,,,,,,,,,,,,,, -54200,0.3865474,2.9380002,,,,,,,,,,,,,,,,, -54300,0.3610047,2.9200802,,,,,,,,,,,,,,,,, -54400,0.3700897,2.9102097,,,,,,,,,,,,,,,,, -54500,0.3628392,3.019977,,,,,,,,,,,,,,,,, -54600,0.40068665,2.9815855,,,,,,,,,,,,,,,,, -54700,0.3417141,2.9364605,,,,,,,,,,,,,,,,, -54800,0.33732516,2.9115536,,,,,,,,,,,,,,,,, -54900,0.355542,2.8798492,,,,,,,,,,,,,,,,, -55000,0.33651528,2.8883865,,,,,,,,,,,,,,,,, -55100,0.384666,2.864996,,,,,,,,,,,,,,,,, -55200,0.38901898,2.8983684,,,,,,,,,,,,,,,,, -55300,0.3558385,2.8974016,,,,,,,,,,,,,,,,, -55400,0.40641275,2.8784235,,,,,,,,,,,,,,,,, -55500,0.344772,3.024484,,,,,,,,,,,,,,,,, -55585,,,0.6417572498321533,1.8483734130859373,31.531738068339457,0.665658175945282,1.668606519699097,28.607842205522875,3000.0,0.6774388551712036,1.6036652326583862,28.394650410572435,3003.0,19359.40021085739,31849.13062644005,19359.40021085739,12487.166585206984,0.7213225364685059,0.0 -55600,0.3878324,2.9108534,,,,,,,,,,,,,,,,, -55700,0.34375572,2.9307632,,,,,,,,,,,,,,,,, -55800,0.3969509,2.9384665,,,,,,,,,,,,,,,,, -55900,0.37240496,3.0403774,,,,,,,,,,,,,,,,, -56000,0.36585003,2.9237678,,,,,,,,,,,,,,,,, -56100,0.33327398,2.8900077,,,,,,,,,,,,,,,,, -56200,0.33929303,2.920685,,,,,,,,,,,,,,,,, -56300,0.3815895,2.9842546,,,,,,,,,,,,,,,,, -56400,0.38164663,2.8658566,,,,,,,,,,,,,,,,, -56500,0.3902043,2.8785665,,,,,,,,,,,,,,,,, -56600,0.41953865,2.907534,,,,,,,,,,,,,,,,, -56700,0.3789675,2.8937266,,,,,,,,,,,,,,,,, -56800,0.3830235,2.926589,,,,,,,,,,,,,,,,, -56900,0.358851,2.8828852,,,,,,,,,,,,,,,,, -57000,0.37563288,2.9615426,,,,,,,,,,,,,,,,, -57100,0.37047616,2.8744237,,,,,,,,,,,,,,,,, -57200,0.3382597,2.9605052,,,,,,,,,,,,,,,,, -57300,0.3744782,2.884417,,,,,,,,,,,,,,,,, -57400,0.38012764,2.925135,,,,,,,,,,,,,,,,, -57500,0.3479481,2.9067092,,,,,,,,,,,,,,,,, -57600,0.33803576,2.9455526,,,,,,,,,,,,,,,,, -57700,0.36484793,2.8799806,,,,,,,,,,,,,,,,, -57800,0.31824112,2.9066584,,,,,,,,,,,,,,,,, -57900,0.42742822,2.8516352,,,,,,,,,,,,,,,,, -58000,0.41589934,2.8971457,,,,,,,,,,,,,,,,, -58004,,,0.650646448135376,1.7800432443618774,32.27619017894549,0.6677040457725525,1.6552205085754397,28.97028555750232,3000.0,0.6786706447601318,1.5954159498214722,28.438258704322525,3003.0,20199.471867084503,33152.7517850399,20199.471867084503,12950.59977388382,0.7550153732299805,0.0 -58100,0.37279642,2.8675272,,,,,,,,,,,,,,,,, -58200,0.34707025,2.9018843,,,,,,,,,,,,,,,,, -58300,0.36256403,2.848658,,,,,,,,,,,,,,,,, -58400,0.36234465,2.8956897,,,,,,,,,,,,,,,,, -58500,0.33852312,2.9236252,,,,,,,,,,,,,,,,, -58600,0.3753125,2.8723006,,,,,,,,,,,,,,,,, -58700,0.39695945,2.9171054,,,,,,,,,,,,,,,,, -58800,0.3640885,2.8198287,,,,,,,,,,,,,,,,, -58900,0.3859901,2.872913,,,,,,,,,,,,,,,,, -59000,0.33541483,2.8645322,,,,,,,,,,,,,,,,, -59100,0.3515297,2.9120052,,,,,,,,,,,,,,,,, -59200,0.33507764,2.9596784,,,,,,,,,,,,,,,,, -59300,0.33831137,2.9641013,,,,,,,,,,,,,,,,, -59400,0.39418688,2.8431532,,,,,,,,,,,,,,,,, -59500,0.34710807,2.888387,,,,,,,,,,,,,,,,, -59600,0.3479931,2.8696005,,,,,,,,,,,,,,,,, -59700,0.41479018,2.9390867,,,,,,,,,,,,,,,,, -59800,0.3796579,2.9120276,,,,,,,,,,,,,,,,, -59900,0.3573839,2.8968418,,,,,,,,,,,,,,,,, -60000,0.37732407,2.9558642,,,,,,,,,,,,,,,,, -60100,0.38614956,2.891467,,,,,,,,,,,,,,,,, -60200,0.35859078,2.887433,,,,,,,,,,,,,,,,, -60300,0.38613743,2.948386,,,,,,,,,,,,,,,,, -60400,0.45798412,2.940232,,,,,,,,,,,,,,,,, -60422,,,0.6451210379600525,1.8069937229156487,31.819075921555985,0.6668609380722046,1.6589609384536743,28.74034449614497,3000.0,0.6806228756904602,1.5879414081573486,28.597255906283987,3003.0,21039.458937883377,34488.21444058418,21039.458937883377,13445.960545539856,0.7892227172851562,0.0 -60500,0.37038764,2.9094858,,,,,,,,,,,,,,,,, -60600,0.38156512,2.8531778,,,,,,,,,,,,,,,,, -60700,0.40046883,2.9762738,,,,,,,,,,,,,,,,, -60800,0.40618986,2.970374,,,,,,,,,,,,,,,,, -60900,0.3484741,2.9067008,,,,,,,,,,,,,,,,, -61000,0.38692603,2.9781616,,,,,,,,,,,,,,,,, -61100,0.33577314,2.8858309,,,,,,,,,,,,,,,,, -61200,0.36652994,2.8833685,,,,,,,,,,,,,,,,, -61300,0.35558662,2.8418322,,,,,,,,,,,,,,,,, -61400,0.36951208,2.8918483,,,,,,,,,,,,,,,,, -61500,0.36141908,2.9241912,,,,,,,,,,,,,,,,, -61600,0.36998463,2.9382157,,,,,,,,,,,,,,,,, -61700,0.34638953,2.887382,,,,,,,,,,,,,,,,, -61800,0.37742785,2.9622364,,,,,,,,,,,,,,,,, -61900,0.34022713,2.8247573,,,,,,,,,,,,,,,,, -62000,0.3579812,2.91907,,,,,,,,,,,,,,,,, -62100,0.33115608,2.936572,,,,,,,,,,,,,,,,, -62200,0.38120034,2.8464055,,,,,,,,,,,,,,,,, -62300,0.4440066,2.8632123,,,,,,,,,,,,,,,,, -62400,0.33409533,2.8014288,,,,,,,,,,,,,,,,, -62500,0.35205713,2.885088,,,,,,,,,,,,,,,,, -62600,0.38087976,2.843496,,,,,,,,,,,,,,,,, -62700,0.38482654,2.9052,,,,,,,,,,,,,,,,, -62800,0.37281948,2.8934236,,,,,,,,,,,,,,,,, -62841,,,0.6688887476921082,1.6635355949401855,33.49501242969906,0.6675428748130798,1.650787353515625,28.96451958189699,3000.0,0.6802510023117065,1.5829477310180664,28.31530092941621,3003.0,21879.506196975708,35813.19875717163,21879.506196975708,13930.784398078918,0.8219027519226074,0.0 -62900,0.34163105,2.9102364,,,,,,,,,,,,,,,,, -63000,0.34185556,2.8935661,,,,,,,,,,,,,,,,, -63100,0.36438102,2.8844488,,,,,,,,,,,,,,,,, -63200,2.4071562,2.9170647,,,,,,,,,,,,,,,,, -63300,0.3828997,2.822728,,,,,,,,,,,,,,,,, -63400,0.40194702,2.9400032,,,,,,,,,,,,,,,,, -63500,0.3615095,2.8694742,,,,,,,,,,,,,,,,, -63600,0.362991,2.8797483,,,,,,,,,,,,,,,,, -63700,0.3520924,2.8910575,,,,,,,,,,,,,,,,, -63800,0.3426326,2.899064,,,,,,,,,,,,,,,,, -63900,0.3520588,2.900508,,,,,,,,,,,,,,,,, -64000,0.39552298,2.8437774,,,,,,,,,,,,,,,,, -64100,0.3851394,2.891002,,,,,,,,,,,,,,,,, -64200,0.3606147,2.9020057,,,,,,,,,,,,,,,,, -64300,0.3809013,2.8810108,,,,,,,,,,,,,,,,, -64400,0.3680209,2.908046,,,,,,,,,,,,,,,,, -64500,0.3639929,2.9342895,,,,,,,,,,,,,,,,, -64600,0.36763087,2.897775,,,,,,,,,,,,,,,,, -64700,0.38588527,2.9307172,,,,,,,,,,,,,,,,, -64800,0.35490847,2.8964965,,,,,,,,,,,,,,,,, -64900,0.33391812,2.8559356,,,,,,,,,,,,,,,,, -65000,0.35370424,2.8631377,,,,,,,,,,,,,,,,, -65100,0.38848564,2.8926756,,,,,,,,,,,,,,,,, -65200,0.3700131,2.8409214,,,,,,,,,,,,,,,,, -65260,,,0.6497659683227539,1.7808959484100342,32.40010374364767,0.6687827706336975,1.6441073417663574,28.94475644983091,3000.0,0.6811341643333435,1.5736756324768066,28.40901193383648,3003.0,22719.717471838,37126.762442588806,22719.717471838,14404.023327350616,0.8547759056091309,0.0 -65300,0.37114015,2.8740203,,,,,,,,,,,,,,,,, -65400,0.3843636,2.9174616,,,,,,,,,,,,,,,,, -65500,0.34671974,2.853151,,,,,,,,,,,,,,,,, -65600,0.35887408,2.9168024,,,,,,,,,,,,,,,,, -65700,0.36986622,2.9262726,,,,,,,,,,,,,,,,, -65800,0.37368256,2.8468962,,,,,,,,,,,,,,,,, -65900,0.37107274,2.8728042,,,,,,,,,,,,,,,,, -66000,0.33971766,2.8350286,,,,,,,,,,,,,,,,, -66100,0.36574078,2.884032,,,,,,,,,,,,,,,,, -66200,0.35982275,2.8887067,,,,,,,,,,,,,,,,, -66300,0.37978446,2.917813,,,,,,,,,,,,,,,,, -66400,0.37510398,2.821106,,,,,,,,,,,,,,,,, -66500,0.36617747,2.8590224,,,,,,,,,,,,,,,,, -66600,0.3601734,2.8794038,,,,,,,,,,,,,,,,, -66700,0.4419693,2.9380915,,,,,,,,,,,,,,,,, -66800,0.3369067,2.844018,,,,,,,,,,,,,,,,, -66900,0.36773315,2.8922467,,,,,,,,,,,,,,,,, -67000,0.34065711,2.8434572,,,,,,,,,,,,,,,,, -67100,0.36015484,2.9065106,,,,,,,,,,,,,,,,, -67200,0.35682547,2.8386211,,,,,,,,,,,,,,,,, -67300,0.36895403,2.893203,,,,,,,,,,,,,,,,, -67400,0.36502606,2.905403,,,,,,,,,,,,,,,,, -67500,0.35332248,2.9320796,,,,,,,,,,,,,,,,, -67600,0.36538967,2.8201978,,,,,,,,,,,,,,,,, -67679,,,0.6502645611763,1.7860978841781616,31.68374978602823,0.6713865995407104,1.6337947845458984,29.22256790706708,3000.0,0.6845738291740417,1.5635030269622805,29.088031046885384,3003.0,23559.93111491204,38493.42343258858,23559.93111491204,14930.35793542862,0.8874788284301758,0.0 -67700,0.34768996,2.8431764,,,,,,,,,,,,,,,,, -67800,0.35336405,2.8900044,,,,,,,,,,,,,,,,, -67900,0.36379468,2.9118764,,,,,,,,,,,,,,,,, -68000,0.3818819,2.8584383,,,,,,,,,,,,,,,,, -68100,0.36431155,2.8465886,,,,,,,,,,,,,,,,, -68200,0.3869924,2.9460495,,,,,,,,,,,,,,,,, -68300,0.32929242,2.821643,,,,,,,,,,,,,,,,, -68400,0.36689195,2.832326,,,,,,,,,,,,,,,,, -68500,0.36591733,2.8849006,,,,,,,,,,,,,,,,, -68600,0.34601057,2.8348832,,,,,,,,,,,,,,,,, -68700,0.39792475,2.914014,,,,,,,,,,,,,,,,, -68800,0.36371264,2.7967145,,,,,,,,,,,,,,,,, -68900,0.3503146,2.8812404,,,,,,,,,,,,,,,,, -69000,0.343864,2.8762195,,,,,,,,,,,,,,,,, -69100,0.3549097,2.9180439,,,,,,,,,,,,,,,,, -69200,0.4022509,2.9190922,,,,,,,,,,,,,,,,, -69300,0.34166226,2.9052968,,,,,,,,,,,,,,,,, -69400,0.34252673,2.8316505,,,,,,,,,,,,,,,,, -69500,0.35670742,2.8433905,,,,,,,,,,,,,,,,, -69600,0.36659747,2.8387103,,,,,,,,,,,,,,,,, -69700,0.39603987,2.9084435,,,,,,,,,,,,,,,,, -69800,0.41204613,2.82459,,,,,,,,,,,,,,,,, -69900,0.37556374,2.892341,,,,,,,,,,,,,,,,, -70000,0.34836975,2.8663597,,,,,,,,,,,,,,,,, -70098,,,0.6582711338996887,1.7285667657852173,32.39691829939411,0.671584963798523,1.6267058849334717,29.149120230108544,3000.0,0.6856196522712708,1.5533530712127686,28.82446304234327,3003.0,24400.08797216416,39873.08808875084,24400.08797216416,15469.750158786774,0.9219081401824952,0.0 -70100,0.3843172,2.842855,,,,,,,,,,,,,,,,, -70200,0.3406656,2.9027858,,,,,,,,,,,,,,,,, -70300,0.37285775,2.842028,,,,,,,,,,,,,,,,, -70400,0.36534342,2.8018198,,,,,,,,,,,,,,,,, -70500,0.3687636,2.8438058,,,,,,,,,,,,,,,,, -70600,0.38613054,2.8377118,,,,,,,,,,,,,,,,, -70700,0.35853454,2.842778,,,,,,,,,,,,,,,,, -70800,0.36670527,2.841119,,,,,,,,,,,,,,,,, -70900,0.37430844,2.8640485,,,,,,,,,,,,,,,,, -71000,0.40766805,2.7715547,,,,,,,,,,,,,,,,, -71100,0.36144167,2.8495326,,,,,,,,,,,,,,,,, -71200,0.3697103,2.9288843,,,,,,,,,,,,,,,,, -71300,0.3706761,2.8568847,,,,,,,,,,,,,,,,, -71400,0.3564343,2.855725,,,,,,,,,,,,,,,,, -71500,0.33782262,2.8681579,,,,,,,,,,,,,,,,, -71600,0.36452472,2.845593,,,,,,,,,,,,,,,,, -71700,0.39281967,2.8673322,,,,,,,,,,,,,,,,, -71800,0.379163,2.877191,,,,,,,,,,,,,,,,, -71900,0.32939857,2.8396552,,,,,,,,,,,,,,,,, -72000,0.36809257,2.8908129,,,,,,,,,,,,,,,,, -72100,0.38455245,2.8471527,,,,,,,,,,,,,,,,, -72200,0.3962542,2.799652,,,,,,,,,,,,,,,,, -72300,0.37032285,2.8281147,,,,,,,,,,,,,,,,, -72400,0.40197948,2.8968465,,,,,,,,,,,,,,,,, -72500,0.38677788,2.83047,,,,,,,,,,,,,,,,, -72518,,,0.6576229333877563,1.7471884489059448,32.81365972041144,0.6732960343360901,1.6212689876556396,29.323903687794544,3000.0,0.6862123012542725,1.5499813556671145,28.91150751075374,3003.0,25240.307821035385,41214.29425621033,25240.307821035385,15970.62186050415,0.957329273223877,0.0 -72600,0.4330479,2.8293064,,,,,,,,,,,,,,,,, -72700,0.39246738,2.8546665,,,,,,,,,,,,,,,,, -72800,0.34862143,2.871731,,,,,,,,,,,,,,,,, -72900,0.36903962,2.8867686,,,,,,,,,,,,,,,,, -73000,0.38660184,2.867091,,,,,,,,,,,,,,,,, -73100,0.34897017,2.8774683,,,,,,,,,,,,,,,,, -73200,0.35885668,2.8095553,,,,,,,,,,,,,,,,, -73300,0.3638391,2.7723417,,,,,,,,,,,,,,,,, -73400,0.3533303,2.7910955,,,,,,,,,,,,,,,,, -73500,0.354108,2.8483603,,,,,,,,,,,,,,,,, -73600,0.3723877,2.9114714,,,,,,,,,,,,,,,,, -73700,0.38449568,2.8305376,,,,,,,,,,,,,,,,, -73800,0.38568297,2.8500812,,,,,,,,,,,,,,,,, -73900,0.348081,2.8337772,,,,,,,,,,,,,,,,, -74000,0.38755742,2.8912487,,,,,,,,,,,,,,,,, -74100,0.38876885,2.8994944,,,,,,,,,,,,,,,,, -74200,0.36230874,2.8847191,,,,,,,,,,,,,,,,, -74300,0.3733859,2.8889503,,,,,,,,,,,,,,,,, -74400,0.38848817,2.8293028,,,,,,,,,,,,,,,,, -74500,0.38359156,2.7710686,,,,,,,,,,,,,,,,, -74600,0.34622583,2.7937698,,,,,,,,,,,,,,,,, -74700,0.3615639,2.7925458,,,,,,,,,,,,,,,,, -74800,0.35975417,2.889425,,,,,,,,,,,,,,,,, -74900,0.38032514,2.842724,,,,,,,,,,,,,,,,, -74937,,,0.6538400053977966,1.7619236707687378,32.41990521138852,0.6753295063972473,1.6074609756469729,29.47000517384073,3000.0,0.6883504986763,1.5352879762649536,29.17476942337257,3003.0,26080.382111549377,42529.167917490005,26080.382111549377,16445.30373263359,0.995025396347046,0.0 -75000,0.3679481,2.8136811,,,,,,,,,,,,,,,,, -75100,0.34679767,2.7624686,,,,,,,,,,,,,,,,, -75200,0.3841245,2.8214862,,,,,,,,,,,,,,,,, -75300,0.396284,2.8594112,,,,,,,,,,,,,,,,, -75400,0.34867623,2.8049078,,,,,,,,,,,,,,,,, -75500,0.35624337,2.8552785,,,,,,,,,,,,,,,,, -75600,0.3773054,2.863888,,,,,,,,,,,,,,,,, -75700,0.37161404,2.8639317,,,,,,,,,,,,,,,,, -75800,0.34757927,2.851297,,,,,,,,,,,,,,,,, -75900,0.40684026,2.8298986,,,,,,,,,,,,,,,,, -76000,0.36465532,2.7961578,,,,,,,,,,,,,,,,, -76100,0.36259508,2.8332286,,,,,,,,,,,,,,,,, -76200,0.35747182,2.8122132,,,,,,,,,,,,,,,,, -76300,0.385081,2.846182,,,,,,,,,,,,,,,,, -76400,0.36261594,2.9233923,,,,,,,,,,,,,,,,, -76500,0.3715899,2.82508,,,,,,,,,,,,,,,,, -76600,0.37308538,2.793164,,,,,,,,,,,,,,,,, -76700,0.3468704,2.8542097,,,,,,,,,,,,,,,,, -76800,0.36694536,2.8313103,,,,,,,,,,,,,,,,, -76900,0.3757103,2.8235424,,,,,,,,,,,,,,,,, -77000,0.3675709,2.8779092,,,,,,,,,,,,,,,,, -77100,0.3905575,2.820562,,,,,,,,,,,,,,,,, -77200,0.36796978,2.8263965,,,,,,,,,,,,,,,,, -77300,0.383089,2.838885,,,,,,,,,,,,,,,,, -77357,,,0.6637595295906067,1.6998523473739624,32.32256300712956,0.6762842535972595,1.6049737930297852,29.577603225878857,3000.0,0.6897449493408203,1.5332320928573608,29.111522492945284,3003.0,26920.60708403588,43868.18283462525,26920.60708403588,16943.978850841522,1.0299007892608645,0.0 -77400,0.36870292,2.8411648,,,,,,,,,,,,,,,,, -77500,0.40062657,2.8445926,,,,,,,,,,,,,,,,, -77600,0.39124423,2.8266845,,,,,,,,,,,,,,,,, -77700,0.37886077,2.9918776,,,,,,,,,,,,,,,,, -77800,0.59633255,2.8861992,,,,,,,,,,,,,,,,, -77900,0.3810999,2.8236341,,,,,,,,,,,,,,,,, -78000,0.3876523,2.79607,,,,,,,,,,,,,,,,, -78100,0.39697805,2.8813877,,,,,,,,,,,,,,,,, -78200,0.38616243,2.8773954,,,,,,,,,,,,,,,,, -78300,0.39858934,2.8505707,,,,,,,,,,,,,,,,, -78400,0.37096938,2.8148324,,,,,,,,,,,,,,,,, -78500,0.36324283,2.847444,,,,,,,,,,,,,,,,, -78600,0.3525223,2.8397992,,,,,,,,,,,,,,,,, -78700,0.35806498,2.8430045,,,,,,,,,,,,,,,,, -78800,0.39309856,2.7713525,,,,,,,,,,,,,,,,, -78900,0.41522503,2.8100042,,,,,,,,,,,,,,,,, -79000,0.38214362,2.8140397,,,,,,,,,,,,,,,,, -79100,0.36903155,2.8706334,,,,,,,,,,,,,,,,, -79200,0.36825126,2.8601437,,,,,,,,,,,,,,,,, -79300,0.3933565,2.8452103,,,,,,,,,,,,,,,,, -79400,0.37478724,2.8008146,,,,,,,,,,,,,,,,, -79500,0.4150948,2.9045455,,,,,,,,,,,,,,,,, -79600,0.37533322,2.8508687,,,,,,,,,,,,,,,,, -79700,0.3882572,2.8608801,,,,,,,,,,,,,,,,, -79777,,,0.6550542116165161,1.7492350339889526,32.499716599105,0.6763834357261658,1.593187689781189,29.572896222219057,3000.0,0.6903492212295532,1.5179253816604614,29.494597510156364,3003.0,27760.75708150864,45201.78984117508,27760.75708150864,17437.31713628769,1.068063259124756,0.0 -79800,0.40028378,2.8441453,,,,,,,,,,,,,,,,, -79900,0.38369438,2.8630497,,,,,,,,,,,,,,,,, -80000,0.42024234,2.7481525,,,,,,,,,,,,,,,,, -80100,0.39432254,2.8986,,,,,,,,,,,,,,,,, -80200,0.3709913,2.797163,,,,,,,,,,,,,,,,, -80300,0.3926703,2.867501,,,,,,,,,,,,,,,,, -80400,0.3730625,2.817616,,,,,,,,,,,,,,,,, -80500,0.3889356,2.9030569,,,,,,,,,,,,,,,,, -80600,0.35838288,2.8050041,,,,,,,,,,,,,,,,, -80700,0.37066102,2.792101,,,,,,,,,,,,,,,,, -80800,0.38657445,2.863404,,,,,,,,,,,,,,,,, -80900,0.43676603,2.835516,,,,,,,,,,,,,,,,, -81000,0.37165332,2.8681512,,,,,,,,,,,,,,,,, -81100,0.39262074,2.790086,,,,,,,,,,,,,,,,, -81200,0.38261122,2.831925,,,,,,,,,,,,,,,,, -81300,0.3979329,2.9181137,,,,,,,,,,,,,,,,, -81400,0.37912428,2.8129332,,,,,,,,,,,,,,,,, -81500,0.37713754,2.7596622,,,,,,,,,,,,,,,,, -81600,0.4110829,2.854522,,,,,,,,,,,,,,,,, -81700,0.36958358,2.797551,,,,,,,,,,,,,,,,, -81800,0.39331883,2.8214862,,,,,,,,,,,,,,,,, -81900,0.39531082,2.861297,,,,,,,,,,,,,,,,, -82000,0.38747066,2.87881,,,,,,,,,,,,,,,,, -82100,0.38608617,2.762903,,,,,,,,,,,,,,,,, -82196,,,0.6707914471626282,1.653606414794922,33.72058966171932,0.6779953241348267,1.590328335762024,29.64467562287056,3000.0,0.6923130750656128,1.5137943029403689,29.512227379029262,3003.0,28600.833948373795,46536.39051222801,28600.833948373795,17931.717635393143,1.1103150844573977,0.0 -82200,0.36710533,2.8694425,,,,,,,,,,,,,,,,, -82300,0.3963144,2.9044154,,,,,,,,,,,,,,,,, -82400,0.35734066,2.854073,,,,,,,,,,,,,,,,, -82500,0.38504744,2.8329456,,,,,,,,,,,,,,,,, -82600,0.41616452,2.7813842,,,,,,,,,,,,,,,,, -82700,0.38556424,2.8389723,,,,,,,,,,,,,,,,, -82800,0.40427274,2.8264291,,,,,,,,,,,,,,,,, -82900,0.37447232,2.822065,,,,,,,,,,,,,,,,, -83000,0.36956644,2.7801394,,,,,,,,,,,,,,,,, -83100,0.38118586,2.8288958,,,,,,,,,,,,,,,,, -83200,0.3804045,2.8806522,,,,,,,,,,,,,,,,, -83300,0.37941715,2.7669065,,,,,,,,,,,,,,,,, -83400,0.4013294,2.8236976,,,,,,,,,,,,,,,,, -83500,0.4038856,2.8244975,,,,,,,,,,,,,,,,, -83600,0.3831953,2.8266182,,,,,,,,,,,,,,,,, -83700,0.42472136,2.784159,,,,,,,,,,,,,,,,, -83800,0.39815718,2.7540097,,,,,,,,,,,,,,,,, -83900,0.38960317,2.8633668,,,,,,,,,,,,,,,,, -84000,0.3761279,2.7779613,,,,,,,,,,,,,,,,, -84100,0.41817772,2.7928827,,,,,,,,,,,,,,,,, -84200,0.38558325,2.8394446,,,,,,,,,,,,,,,,, -84300,0.3980124,2.7818727,,,,,,,,,,,,,,,,, -84400,0.43511894,2.8614058,,,,,,,,,,,,,,,,, -84500,0.38973266,2.782573,,,,,,,,,,,,,,,,, -84600,0.3736846,2.8146894,,,,,,,,,,,,,,,,, -84615,,,0.6634590029716492,1.7018529176712036,32.766808653370965,0.6799171566963196,1.586037039756775,29.75577826387701,3000.0,0.693312406539917,1.509398102760315,30.06056970667257,3003.0,29440.945076942444,47895.65140795708,29440.945076942444,18450.751024246216,1.1467373371124268,0.0 -84700,0.38603485,2.8002908,,,,,,,,,,,,,,,,, -84800,0.40692377,2.756203,,,,,,,,,,,,,,,,, -84900,0.37954628,2.8365965,,,,,,,,,,,,,,,,, -85000,0.3859586,2.7891867,,,,,,,,,,,,,,,,, -85100,0.36442482,2.7784805,,,,,,,,,,,,,,,,, -85200,0.4049911,2.8398688,,,,,,,,,,,,,,,,, -85300,0.38696396,2.8256497,,,,,,,,,,,,,,,,, -85400,0.38320592,2.8098147,,,,,,,,,,,,,,,,, -85500,0.4079651,2.8701196,,,,,,,,,,,,,,,,, -85600,0.43775913,2.750066,,,,,,,,,,,,,,,,, -85700,0.40073657,2.8468816,,,,,,,,,,,,,,,,, -85800,0.40587416,2.8926878,,,,,,,,,,,,,,,,, -85900,0.46252915,2.8800845,,,,,,,,,,,,,,,,, -86000,0.39379287,2.8432772,,,,,,,,,,,,,,,,, -86100,0.3803946,2.7668855,,,,,,,,,,,,,,,,, -86200,0.41248262,2.7818823,,,,,,,,,,,,,,,,, -86300,0.39680243,2.775948,,,,,,,,,,,,,,,,, -86400,0.408182,2.8877208,,,,,,,,,,,,,,,,, -86500,0.4086609,2.7950733,,,,,,,,,,,,,,,,, -86600,0.41209945,2.773104,,,,,,,,,,,,,,,,, -86700,0.43467063,2.8149958,,,,,,,,,,,,,,,,, -86800,0.43255988,2.8066988,,,,,,,,,,,,,,,,, -86900,0.41059345,2.7841609,,,,,,,,,,,,,,,,, -87000,0.3922505,2.767442,,,,,,,,,,,,,,,,, -87033,,,0.6625377535820007,1.709887146949768,32.974561510453555,0.6813058853149414,1.572974443435669,29.83404728465848,3000.0,0.6963453888893127,1.4961185455322266,29.654388366237686,3003.0,30280.988560199738,49244.03613877296,30280.988560199738,18958.97201180458,1.1856064796447754,0.0 -87100,0.394283,2.7568223,,,,,,,,,,,,,,,,, -87200,0.39514244,2.8075373,,,,,,,,,,,,,,,,, -87300,0.43446344,2.7963104,,,,,,,,,,,,,,,,, -87400,0.39571851,2.7853434,,,,,,,,,,,,,,,,, -87500,0.41419455,2.841322,,,,,,,,,,,,,,,,, -87600,0.39557672,2.7613308,,,,,,,,,,,,,,,,, -87700,0.39642608,2.7970972,,,,,,,,,,,,,,,,, -87800,0.40360767,2.7610388,,,,,,,,,,,,,,,,, -87900,0.44016534,2.7688038,,,,,,,,,,,,,,,,, -88000,0.41231364,2.7395043,,,,,,,,,,,,,,,,, -88100,0.39339456,2.8199193,,,,,,,,,,,,,,,,, -88200,0.39437044,2.7922518,,,,,,,,,,,,,,,,, -88300,0.41209817,2.829835,,,,,,,,,,,,,,,,, -88400,0.4114237,2.807956,,,,,,,,,,,,,,,,, -88500,0.4053799,2.7977107,,,,,,,,,,,,,,,,, -88600,0.40485084,2.7217445,,,,,,,,,,,,,,,,, -88700,0.42744526,2.7876618,,,,,,,,,,,,,,,,, -88800,0.40373865,2.8261068,,,,,,,,,,,,,,,,, -88900,0.42772865,2.8144379,,,,,,,,,,,,,,,,, -89000,0.42488462,2.7937944,,,,,,,,,,,,,,,,, -89100,0.41432187,2.7967076,,,,,,,,,,,,,,,,, -89200,0.44261038,2.8041632,,,,,,,,,,,,,,,,, -89300,0.43921,2.7895293,,,,,,,,,,,,,,,,, -89400,0.3911609,2.7609463,,,,,,,,,,,,,,,,, -89452,,,0.6736891269683838,1.6366500854492188,33.41154952869727,0.6804007291793823,1.5719475746154783,29.573600778223444,3000.0,0.6961362361907959,1.495225429534912,29.73615868519291,3003.0,31121.10328722,50619.80656862259,31121.10328722,19494.512558221817,1.2216103076934814,0.0 -89500,0.4123077,2.743342,,,,,,,,,,,,,,,,, -89600,0.42283794,2.7492442,,,,,,,,,,,,,,,,, -89700,0.43928897,2.793394,,,,,,,,,,,,,,,,, -89800,0.42273715,2.7710044,,,,,,,,,,,,,,,,, -89900,0.42200446,2.7608874,,,,,,,,,,,,,,,,, -90000,0.4374758,2.8503966,,,,,,,,,,,,,,,,, -90100,0.43438002,2.8145711,,,,,,,,,,,,,,,,, -90200,0.43325418,2.8284774,,,,,,,,,,,,,,,,, -90300,0.39968655,2.8386927,,,,,,,,,,,,,,,,, -90400,0.46172532,2.7503145,,,,,,,,,,,,,,,,, -90500,0.4454841,2.7382252,,,,,,,,,,,,,,,,, -90600,0.441348,2.801164,,,,,,,,,,,,,,,,, -90700,0.41740224,2.8614361,,,,,,,,,,,,,,,,, -90800,0.48920697,2.8363261,,,,,,,,,,,,,,,,, -90900,0.41794991,2.7985456,,,,,,,,,,,,,,,,, -91000,0.63419455,2.780336,,,,,,,,,,,,,,,,, -91100,0.41187313,2.7898347,,,,,,,,,,,,,,,,, -91200,0.44845548,2.823616,,,,,,,,,,,,,,,,, -91300,0.4400888,2.8252714,,,,,,,,,,,,,,,,, -91400,0.425658,2.7409337,,,,,,,,,,,,,,,,, -91500,0.43498778,2.7493007,,,,,,,,,,,,,,,,, -91600,0.43775344,2.7728407,,,,,,,,,,,,,,,,, -91700,0.4428932,2.836663,,,,,,,,,,,,,,,,, -91800,0.4310889,2.7837915,,,,,,,,,,,,,,,,, -91871,,,0.6677938103675842,1.668520450592041,33.170331808642466,0.6823474168777466,1.5652003288269043,30.001956328565168,3000.0,0.6966475248336792,1.4875853061676023,29.9217194579761,3003.0,31961.19503569603,51964.228222608566,31961.19503569603,19998.722029209137,1.26001238822937,0.0 -91900,0.43745923,2.7460997,,,,,,,,,,,,,,,,, -92000,0.4492033,2.8301961,,,,,,,,,,,,,,,,, -92100,0.43172342,2.780719,,,,,,,,,,,,,,,,, -92200,0.46225598,2.8691351,,,,,,,,,,,,,,,,, -92300,0.4280829,2.7503529,,,,,,,,,,,,,,,,, -92400,0.46344143,2.8677537,,,,,,,,,,,,,,,,, -92500,0.4598179,2.753779,,,,,,,,,,,,,,,,, -92600,0.43900707,2.7231414,,,,,,,,,,,,,,,,, -92700,0.425124,2.8053312,,,,,,,,,,,,,,,,, -92800,0.4480142,2.8297722,,,,,,,,,,,,,,,,, -92900,0.44845194,2.8619308,,,,,,,,,,,,,,,,, -93000,0.41749576,2.8273098,,,,,,,,,,,,,,,,, -93100,0.45147067,2.7894394,,,,,,,,,,,,,,,,, -93200,0.46772623,2.7476535,,,,,,,,,,,,,,,,, -93300,0.42892617,2.7240493,,,,,,,,,,,,,,,,, -93400,0.45143956,2.7711663,,,,,,,,,,,,,,,,, -93500,0.4673295,2.7231257,,,,,,,,,,,,,,,,, -93600,0.43296176,2.7516093,,,,,,,,,,,,,,,,, -93700,0.45016652,2.7964108,,,,,,,,,,,,,,,,, -93800,0.418395,2.7153924,,,,,,,,,,,,,,,,, -93900,0.4566385,2.7733712,,,,,,,,,,,,,,,,, -94000,0.4579275,2.7913136,,,,,,,,,,,,,,,,, -94100,0.44882077,2.855984,,,,,,,,,,,,,,,,, -94200,0.45776367,2.7459676,,,,,,,,,,,,,,,,, -94290,,,0.6874837279319763,1.560304880142212,34.74978426592824,0.684963583946228,1.5531262159347534,30.26482790250689,3000.0,0.6995874643325806,1.4738984107971191,30.34351668402616,3003.0,32801.38553190231,53287.52480864525,32801.38553190231,20481.708650112152,1.298454761505127,0.0 -94300,0.4552582,2.7125685,,,,,,,,,,,,,,,,, -94400,0.44764328,2.7427654,,,,,,,,,,,,,,,,, -94500,0.43130973,2.7347703,,,,,,,,,,,,,,,,, -94600,0.46511003,2.775372,,,,,,,,,,,,,,,,, -94700,0.47868884,2.741602,,,,,,,,,,,,,,,,, -94800,0.46496707,2.7565577,,,,,,,,,,,,,,,,, -94900,0.4351555,2.8280652,,,,,,,,,,,,,,,,, -95000,0.4431901,2.736917,,,,,,,,,,,,,,,,, -95100,0.4396485,2.7366683,,,,,,,,,,,,,,,,, -95200,0.43779486,2.7541547,,,,,,,,,,,,,,,,, -95300,0.44523376,2.825968,,,,,,,,,,,,,,,,, -95400,0.44400623,2.7697854,,,,,,,,,,,,,,,,, -95500,0.48641503,2.7438118,,,,,,,,,,,,,,,,, -95600,0.43400502,2.762673,,,,,,,,,,,,,,,,, -95700,0.481403,2.8239145,,,,,,,,,,,,,,,,, -95800,0.4609408,2.7048254,,,,,,,,,,,,,,,,, -95900,0.46435744,2.71097,,,,,,,,,,,,,,,,, -96000,0.47049594,2.783372,,,,,,,,,,,,,,,,, -96100,0.45304576,2.7798238,,,,,,,,,,,,,,,,, -96200,0.46767116,2.655855,,,,,,,,,,,,,,,,, -96300,0.4747314,2.7480102,,,,,,,,,,,,,,,,, -96400,0.47122666,2.7279534,,,,,,,,,,,,,,,,, -96500,0.45993713,2.756686,,,,,,,,,,,,,,,,, -96600,0.472423,2.8016934,,,,,,,,,,,,,,,,, -96700,0.45428655,2.74313,,,,,,,,,,,,,,,,, -96708,,,0.6748186945915222,1.623995304107666,33.85164936344683,0.6853107810020447,1.5497137308120728,30.209081609458195,3000.0,0.700888991355896,1.4674443006515503,30.135616820141905,3003.0,33641.31292676926,54642.65100026131,33641.31292676926,20996.7893345356,1.3360624313354492,0.0 -96800,0.4420546,2.7618437,,,,,,,,,,,,,,,,, -96900,0.4689716,2.7625327,,,,,,,,,,,,,,,,, -97000,0.45937034,2.7602074,,,,,,,,,,,,,,,,, -97100,0.4604459,2.6913683,,,,,,,,,,,,,,,,, -97200,0.47392574,2.7658584,,,,,,,,,,,,,,,,, -97300,0.46517557,2.7942774,,,,,,,,,,,,,,,,, -97400,0.452964,2.7304146,,,,,,,,,,,,,,,,, -97500,0.47422022,2.7025886,,,,,,,,,,,,,,,,, -97600,0.46355116,2.7679107,,,,,,,,,,,,,,,,, -97700,0.49033806,2.7419455,,,,,,,,,,,,,,,,, -97800,0.49221328,2.774946,,,,,,,,,,,,,,,,, -97900,0.4666752,2.7631102,,,,,,,,,,,,,,,,, -98000,0.48957005,2.7420065,,,,,,,,,,,,,,,,, -98100,0.48828328,2.777534,,,,,,,,,,,,,,,,, -98200,0.46832615,2.7330968,,,,,,,,,,,,,,,,, -98300,0.4786525,2.7202468,,,,,,,,,,,,,,,,, -98400,0.49225467,2.7022119,,,,,,,,,,,,,,,,, -98500,0.47723672,2.738011,,,,,,,,,,,,,,,,, -98600,0.48514953,2.7612832,,,,,,,,,,,,,,,,, -98700,0.47445256,2.7371204,,,,,,,,,,,,,,,,, -98800,0.4975619,2.6767907,,,,,,,,,,,,,,,,, -98900,0.48282957,2.7721133,,,,,,,,,,,,,,,,, -99000,0.50308585,2.7019334,,,,,,,,,,,,,,,,, -99100,0.5113678,2.735284,,,,,,,,,,,,,,,,, -99127,,,0.6738321185112,1.6364940404891968,33.38403674379716,0.6848272085189819,1.545107364654541,30.294506822664715,3000.0,0.7012376189231873,1.4639596939086914,30.20869715529164,3003.0,34481.37762069702,56002.891575336456,34481.37762069702,21516.848722696304,1.3743152618408203,0.0 -99200,0.4924953,2.6872284,,,,,,,,,,,,,,,,, -99300,0.51237893,2.759103,,,,,,,,,,,,,,,,, -99400,0.50465006,2.7419221,,,,,,,,,,,,,,,,, -99500,0.48851418,2.703639,,,,,,,,,,,,,,,,, -99600,0.49687895,2.7312,,,,,,,,,,,,,,,,, -99700,0.49886748,2.6896253,,,,,,,,,,,,,,,,, -99800,0.4864435,2.692924,,,,,,,,,,,,,,,,, -99900,0.4899555,2.7530046,,,,,,,,,,,,,,,,, -100000,0.4876081,2.7495127,,,,,,,,,,,,,,,,, -100100,0.513822,2.7678595,,,,,,,,,,,,,,,,, -100200,0.5027133,2.742511,,,,,,,,,,,,,,,,, -100300,0.5058998,2.6981125,,,,,,,,,,,,,,,,, -100400,0.49827597,2.7341726,,,,,,,,,,,,,,,,, -100500,0.5281067,2.7462027,,,,,,,,,,,,,,,,, -100600,0.50969523,2.6867044,,,,,,,,,,,,,,,,, -100700,0.53246117,2.7229972,,,,,,,,,,,,,,,,, -100800,0.52595913,2.7308424,,,,,,,,,,,,,,,,, -100900,0.48284006,2.7385318,,,,,,,,,,,,,,,,, -101000,0.538886,2.7441628,,,,,,,,,,,,,,,,, -101100,0.49702108,2.6951113,,,,,,,,,,,,,,,,, -101200,0.51492745,2.7171447,,,,,,,,,,,,,,,,, -101300,0.51627755,2.694299,,,,,,,,,,,,,,,,, -101400,0.4890613,2.7244816,,,,,,,,,,,,,,,,, -101500,0.51441044,2.700345,,,,,,,,,,,,,,,,, -101546,,,0.6848159432411194,1.569870948791504,34.37797153427266,0.6859679222106934,1.5355159044265747,30.165309046410087,3000.0,0.7024809718132019,1.4583240747451782,30.57824113953,3003.0,35321.28910493851,57307.98519515991,35321.28910493851,21981.911460876465,1.413140058517456,0.0 -101600,0.49959582,2.783734,,,,,,,,,,,,,,,,, -101700,0.49856332,2.7208595,,,,,,,,,,,,,,,,, -101800,0.530116,2.6905274,,,,,,,,,,,,,,,,, -101900,0.512961,2.8014874,,,,,,,,,,,,,,,,, -102000,0.51963764,2.700533,,,,,,,,,,,,,,,,, -102100,0.53189605,2.6717858,,,,,,,,,,,,,,,,, -102200,0.50370777,2.7461314,,,,,,,,,,,,,,,,, -102300,0.5232406,2.8204374,,,,,,,,,,,,,,,,, -102400,0.5230932,2.7837675,,,,,,,,,,,,,,,,, -102500,0.55355144,2.7036064,,,,,,,,,,,,,,,,, -102600,0.48440295,2.6521099,,,,,,,,,,,,,,,,, -102700,0.5287816,2.6780698,,,,,,,,,,,,,,,,, -102800,0.51817816,2.6767156,,,,,,,,,,,,,,,,, -102900,0.5089771,2.7561808,,,,,,,,,,,,,,,,, -103000,0.5539227,2.7167203,,,,,,,,,,,,,,,,, -103100,0.508462,2.671376,,,,,,,,,,,,,,,,, -103200,0.52204543,2.6121852,,,,,,,,,,,,,,,,, -103300,0.5124489,2.6458733,,,,,,,,,,,,,,,,, -103400,0.52072144,2.7171817,,,,,,,,,,,,,,,,, -103500,0.5335922,2.691737,,,,,,,,,,,,,,,,, -103600,0.5526817,2.6887872,,,,,,,,,,,,,,,,, -103700,0.55165684,2.720758,,,,,,,,,,,,,,,,, -103800,0.5290833,2.7121084,,,,,,,,,,,,,,,,, -103900,0.571649,2.769704,,,,,,,,,,,,,,,,, -103965,,,0.6798616647720337,1.5981993675231934,34.46955027904853,0.688063383102417,1.5368680953979492,30.38069848744398,3000.0,0.7020859122276306,1.455241084098816,30.37345328936431,3003.0,36161.35554885864,58637.47705602646,36161.35554885864,22471.217081546783,1.451249599456787,0.0 -104000,0.5541218,2.7173705,,,,,,,,,,,,,,,,, -104100,0.5446838,2.717148,,,,,,,,,,,,,,,,, -104200,0.56332695,2.731935,,,,,,,,,,,,,,,,, -104300,0.54210687,2.7570996,,,,,,,,,,,,,,,,, -104400,0.5165677,2.668408,,,,,,,,,,,,,,,,, -104500,0.5502689,2.7235622,,,,,,,,,,,,,,,,, -104600,0.52835023,2.6940753,,,,,,,,,,,,,,,,, -104700,0.54859096,2.7225013,,,,,,,,,,,,,,,,, -104800,0.5503085,2.7605977,,,,,,,,,,,,,,,,, -104900,0.5484923,2.708644,,,,,,,,,,,,,,,,, -105000,0.55712116,2.7489395,,,,,,,,,,,,,,,,, -105100,0.54876685,2.707445,,,,,,,,,,,,,,,,, -105200,0.56153643,2.752932,,,,,,,,,,,,,,,,, -105300,0.55382276,2.662681,,,,,,,,,,,,,,,,, -105400,0.53966725,2.6685004,,,,,,,,,,,,,,,,, -105500,0.5651776,2.6746042,,,,,,,,,,,,,,,,, -105600,0.5603108,2.726164,,,,,,,,,,,,,,,,, -105700,0.52344525,2.6925974,,,,,,,,,,,,,,,,, -105800,0.57787526,2.7369802,,,,,,,,,,,,,,,,, -105900,0.54120404,2.7046363,,,,,,,,,,,,,,,,, -106000,0.5612588,2.6877875,,,,,,,,,,,,,,,,, -106100,0.56876147,2.7829177,,,,,,,,,,,,,,,,, -106200,0.5320262,2.7252626,,,,,,,,,,,,,,,,, -106300,0.54751694,2.7499835,,,,,,,,,,,,,,,,, -106383,,,0.7143574953079224,1.4251322746276855,36.8438553370823,0.6882121562957764,1.5287905931472778,30.254751595029035,3000.0,0.7030503749847412,1.4487695693969729,30.489053354388908,3003.0,37001.44723653793,59985.43886613846,37001.44723653793,22978.96356916428,1.4913089275360107,0.0 -106400,0.5800271,2.7074223,,,,,,,,,,,,,,,,, -106500,0.62612456,2.7074502,,,,,,,,,,,,,,,,, -106600,0.5611848,2.659661,,,,,,,,,,,,,,,,, -106700,0.57923204,2.6872015,,,,,,,,,,,,,,,,, -106800,0.58471936,2.7277308,,,,,,,,,,,,,,,,, -106900,0.5875439,2.6772711,,,,,,,,,,,,,,,,, -107000,0.5824331,2.784168,,,,,,,,,,,,,,,,, -107100,0.55922765,2.6475253,,,,,,,,,,,,,,,,, -107200,0.542275,2.6607168,,,,,,,,,,,,,,,,, -107300,0.5567177,2.6696427,,,,,,,,,,,,,,,,, -107400,0.57242763,2.7274923,,,,,,,,,,,,,,,,, -107500,0.5699823,2.6630714,,,,,,,,,,,,,,,,, -107600,0.5804152,2.607719,,,,,,,,,,,,,,,,, -107700,0.60320234,2.7063596,,,,,,,,,,,,,,,,, -107800,0.5791836,2.660211,,,,,,,,,,,,,,,,, -107900,0.5897185,2.660955,,,,,,,,,,,,,,,,, -108000,0.57952577,2.7275949,,,,,,,,,,,,,,,,, -108100,0.5819654,2.6954877,,,,,,,,,,,,,,,,, -108200,0.60323066,2.7314272,,,,,,,,,,,,,,,,, -108300,0.59873194,2.719784,,,,,,,,,,,,,,,,, -108400,0.5965193,2.6525798,,,,,,,,,,,,,,,,, -108500,0.58900404,2.6949716,,,,,,,,,,,,,,,,, -108600,0.57151204,2.6786814,,,,,,,,,,,,,,,,, -108700,0.56267107,2.6478949,,,,,,,,,,,,,,,,, -108800,0.58372885,2.628823,,,,,,,,,,,,,,,,, -108801,,,0.6898171305656433,1.552299976348877,34.717774021813774,0.6894644498825073,1.5260159969329834,30.55597763090765,3000.0,0.7047237157821655,1.443089723587036,30.474966315456253,3003.0,37841.80606007576,61331.99062085152,37841.80606007576,23485.036183595657,1.5315580368041992,0.0 -108900,0.6013556,2.6423614,,,,,,,,,,,,,,,,, -109000,0.59973705,2.7324495,,,,,,,,,,,,,,,,, -109100,0.57682467,2.6523411,,,,,,,,,,,,,,,,, -109200,0.59824055,2.6783743,,,,,,,,,,,,,,,,, -109300,0.5960463,2.6391292,,,,,,,,,,,,,,,,, -109400,0.6080131,2.616153,,,,,,,,,,,,,,,,, -109500,0.63763106,2.6991143,,,,,,,,,,,,,,,,, -109600,0.6169813,2.6495073,,,,,,,,,,,,,,,,, -109700,0.6174928,2.6701574,,,,,,,,,,,,,,,,, -109800,0.95739853,2.6655366,,,,,,,,,,,,,,,,, -109900,0.63482785,2.6921918,,,,,,,,,,,,,,,,, -110000,0.5978119,2.6148462,,,,,,,,,,,,,,,,, -110100,0.606905,2.6829958,,,,,,,,,,,,,,,,, -110200,0.59724426,2.7098892,,,,,,,,,,,,,,,,, -110300,0.59098023,2.6251235,,,,,,,,,,,,,,,,, -110400,0.6099195,2.6499166,,,,,,,,,,,,,,,,, -110500,0.6287988,2.6571953,,,,,,,,,,,,,,,,, -110600,0.6263242,2.596922,,,,,,,,,,,,,,,,, -110700,0.606035,2.6489666,,,,,,,,,,,,,,,,, -110800,0.59743875,2.6906257,,,,,,,,,,,,,,,,, -110900,0.60807914,2.7118456,,,,,,,,,,,,,,,,, -111000,0.59923744,2.7056487,,,,,,,,,,,,,,,,, -111100,0.6191789,2.6980326,,,,,,,,,,,,,,,,, -111200,0.611768,2.6592953,,,,,,,,,,,,,,,,, -111219,,,0.6886301040649414,1.5508441925048828,34.74471014801158,0.6914111375808716,1.5159528255462646,30.64500096572725,3000.0,0.7083958387374878,1.4284865856170654,30.872641774580423,3003.0,38681.82443928719,62669.208958387375,38681.82443928719,23982.11319732666,1.572606325149536,0.0 -111300,0.6144291,2.6362765,,,,,,,,,,,,,,,,, -111400,0.6485989,2.6846983,,,,,,,,,,,,,,,,, -111500,0.64518785,2.6894116,,,,,,,,,,,,,,,,, -111600,0.61972296,2.6752126,,,,,,,,,,,,,,,,, -111700,0.6602376,2.678048,,,,,,,,,,,,,,,,, -111800,0.64261407,2.6191885,,,,,,,,,,,,,,,,, -111900,0.6454642,2.6651862,,,,,,,,,,,,,,,,, -112000,0.6437709,2.602195,,,,,,,,,,,,,,,,, -112100,0.6291986,2.6213171,,,,,,,,,,,,,,,,, -112200,0.6511423,2.7407591,,,,,,,,,,,,,,,,, -112300,0.676021,2.692323,,,,,,,,,,,,,,,,, -112400,0.6233592,2.6045997,,,,,,,,,,,,,,,,, -112500,0.66271955,2.6529334,,,,,,,,,,,,,,,,, -112600,0.672896,2.670534,,,,,,,,,,,,,,,,, -112700,0.6162707,2.611499,,,,,,,,,,,,,,,,, -112800,0.6291166,2.6406443,,,,,,,,,,,,,,,,, -112900,0.67897725,2.6639934,,,,,,,,,,,,,,,,, -113000,0.6272211,2.687007,,,,,,,,,,,,,,,,, -113100,0.6429402,2.6155198,,,,,,,,,,,,,,,,, -113200,0.63857925,2.690123,,,,,,,,,,,,,,,,, -113300,0.65529656,2.7256145,,,,,,,,,,,,,,,,, -113400,0.652789,2.6713605,,,,,,,,,,,,,,,,, -113500,0.626832,2.6503088,,,,,,,,,,,,,,,,, -113600,0.6340207,2.7008963,,,,,,,,,,,,,,,,, -113635,,,0.7003282308578491,1.4843988418579102,35.50631142680764,0.6920930743217468,1.5145164728164673,30.741345891342124,3000.0,0.7083725929260254,1.428377389907837,30.95872246263002,3003.0,39521.84886193276,63978.92404127121,39521.84886193276,24451.67786312104,1.6134326457977295,0.0 -113700,0.6674882,2.7131045,,,,,,,,,,,,,,,,, -113800,0.6510239,2.6351306,,,,,,,,,,,,,,,,, -113900,0.6674122,2.7168326,,,,,,,,,,,,,,,,, -114000,0.650289,2.609997,,,,,,,,,,,,,,,,, -114100,0.6310679,2.5896327,,,,,,,,,,,,,,,,, -114200,0.6772869,2.5966847,,,,,,,,,,,,,,,,, -114300,0.66145784,2.6764205,,,,,,,,,,,,,,,,, -114400,0.65306294,2.6491122,,,,,,,,,,,,,,,,, -114500,0.64782345,2.6286364,,,,,,,,,,,,,,,,, -114600,0.64428806,2.652607,,,,,,,,,,,,,,,,, -114700,0.67888206,2.7040975,,,,,,,,,,,,,,,,, -114800,0.6896027,2.6143723,,,,,,,,,,,,,,,,, -114900,0.66955596,2.6420302,,,,,,,,,,,,,,,,, -115000,0.65826625,2.6422727,,,,,,,,,,,,,,,,, -115100,0.6760453,2.634799,,,,,,,,,,,,,,,,, -115200,0.6649565,2.6135702,,,,,,,,,,,,,,,,, -115300,0.69082886,2.666199,,,,,,,,,,,,,,,,, -115400,0.6725762,2.6524255,,,,,,,,,,,,,,,,, -115500,0.6820174,2.6183634,,,,,,,,,,,,,,,,, -115600,0.6938309,2.6291978,,,,,,,,,,,,,,,,, -115700,0.67604303,2.658192,,,,,,,,,,,,,,,,, -115800,0.67233616,2.6130388,,,,,,,,,,,,,,,,, -115900,0.66242206,2.6017027,,,,,,,,,,,,,,,,, -116000,0.68288773,2.6729906,,,,,,,,,,,,,,,,, -116053,,,0.6954756379127502,1.5078685283660889,35.4937397891447,0.6918203234672546,1.5126873254776,30.846350466007053,3000.0,0.7091976404190063,1.4271239042282104,31.064218560260848,3003.0,40361.91211247444,65312.42619681358,40361.91211247444,24944.99639344216,1.653876543045044,0.0 -116100,0.66057867,2.5934222,,,,,,,,,,,,,,,,, -116200,0.69235164,2.6343124,,,,,,,,,,,,,,,,, -116300,0.6975528,2.5827286,,,,,,,,,,,,,,,,, -116400,0.6877382,2.599534,,,,,,,,,,,,,,,,, -116500,0.6478525,2.6892245,,,,,,,,,,,,,,,,, -116600,0.70047855,2.6720495,,,,,,,,,,,,,,,,, -116700,0.6560434,2.6112416,,,,,,,,,,,,,,,,, -116800,0.6843725,2.697152,,,,,,,,,,,,,,,,, -116900,0.6602895,2.678166,,,,,,,,,,,,,,,,, -117000,0.69821346,2.6809723,,,,,,,,,,,,,,,,, -117100,0.6830814,2.6014428,,,,,,,,,,,,,,,,, -117200,0.7007237,2.6560373,,,,,,,,,,,,,,,,, -117300,0.69803745,2.5793705,,,,,,,,,,,,,,,,, -117400,0.6880063,2.712473,,,,,,,,,,,,,,,,, -117500,0.6826319,2.6653037,,,,,,,,,,,,,,,,, -117600,0.6668294,2.5937855,,,,,,,,,,,,,,,,, -117700,0.70125794,2.7247236,,,,,,,,,,,,,,,,, -117800,0.70633006,2.6522648,,,,,,,,,,,,,,,,, -117900,0.7039854,2.5926716,,,,,,,,,,,,,,,,, -118000,0.6978279,2.573457,,,,,,,,,,,,,,,,, -118100,0.70030296,2.5791402,,,,,,,,,,,,,,,,, -118200,0.70932645,2.6222165,,,,,,,,,,,,,,,,, -118300,0.66729385,2.6220667,,,,,,,,,,,,,,,,, -118400,0.6621264,2.6177778,,,,,,,,,,,,,,,,, -118470,,,0.6973755359649658,1.5098248720169067,35.351594846978585,0.691745936870575,1.5092313289642334,30.6144301486764,3000.0,0.7104526162147522,1.4224282503128052,30.91869931762568,3003.0,41202.006234169006,66649.37670564651,41202.006234169006,25441.72201323509,1.702117919921875,0.0 -118500,0.69854945,2.6227522,,,,,,,,,,,,,,,,, -118600,0.72139955,2.657856,,,,,,,,,,,,,,,,, -118700,0.70952547,2.6857116,,,,,,,,,,,,,,,,, -118800,0.6815591,2.6111662,,,,,,,,,,,,,,,,, -118900,0.6979446,2.6291363,,,,,,,,,,,,,,,,, -119000,0.7084276,2.6134233,,,,,,,,,,,,,,,,, -119100,0.74086493,2.592312,,,,,,,,,,,,,,,,, -119200,0.7085169,2.5680056,,,,,,,,,,,,,,,,, -119300,0.71065116,2.5827496,,,,,,,,,,,,,,,,, -119400,0.6891731,2.5819488,,,,,,,,,,,,,,,,, -119500,0.70769966,2.578353,,,,,,,,,,,,,,,,, -119600,0.71860117,2.5802636,,,,,,,,,,,,,,,,, -119700,0.7368893,2.6415744,,,,,,,,,,,,,,,,, -119800,0.6891543,2.6063359,,,,,,,,,,,,,,,,, -119900,0.72062457,2.589353,,,,,,,,,,,,,,,,, -120000,0.72242606,2.6042686,,,,,,,,,,,,,,,,, -120100,0.7118086,2.6744456,,,,,,,,,,,,,,,,, -120200,0.71568745,2.6353068,,,,,,,,,,,,,,,,, -120300,0.75388527,2.594543,,,,,,,,,,,,,,,,, -120400,0.74281335,2.6610851,,,,,,,,,,,,,,,,, -120500,0.7237491,2.5945156,,,,,,,,,,,,,,,,, -120600,0.71810347,2.6388667,,,,,,,,,,,,,,,,, -120700,0.69641083,2.5667615,,,,,,,,,,,,,,,,, -120800,0.7366591,2.5666547,,,,,,,,,,,,,,,,, -120888,,,0.7043367624282837,1.4685285091400146,35.96213979511151,0.6929238438606262,1.5071839094161987,30.69302274081346,3000.0,0.7091627717018127,1.4226619005203247,31.02129557676758,3003.0,42042.0838637352,67990.92956995964,42042.0838637352,25943.07703590393,1.7425308227539062,0.0 -120900,0.7345151,2.592553,,,,,,,,,,,,,,,,, -121000,0.7131843,2.5975564,,,,,,,,,,,,,,,,, -121100,0.71621644,2.5954287,,,,,,,,,,,,,,,,, -121200,0.74762464,2.6146615,,,,,,,,,,,,,,,,, -121300,0.70043343,2.6210375,,,,,,,,,,,,,,,,, -121400,0.7347797,2.6271908,,,,,,,,,,,,,,,,, -121500,0.73105705,2.6233723,,,,,,,,,,,,,,,,, -121600,0.7238964,2.6195903,,,,,,,,,,,,,,,,, -121700,0.7198082,2.5830328,,,,,,,,,,,,,,,,, -121800,0.7655605,2.6701016,,,,,,,,,,,,,,,,, -121900,0.72419995,2.6447685,,,,,,,,,,,,,,,,, -122000,0.7183852,2.6036916,,,,,,,,,,,,,,,,, -122100,0.730678,2.587185,,,,,,,,,,,,,,,,, -122200,0.71419954,2.6780665,,,,,,,,,,,,,,,,, -122300,0.74732006,2.614235,,,,,,,,,,,,,,,,, -122400,0.73858625,2.609721,,,,,,,,,,,,,,,,, -122500,0.7427306,2.5938656,,,,,,,,,,,,,,,,, -122600,0.7631268,2.6113372,,,,,,,,,,,,,,,,, -122700,0.7434831,2.6565168,,,,,,,,,,,,,,,,, -122800,0.7192605,2.5823448,,,,,,,,,,,,,,,,, -122900,0.7598317,2.5604866,,,,,,,,,,,,,,,,, -123000,0.7322354,2.621529,,,,,,,,,,,,,,,,, -123100,0.733253,2.605727,,,,,,,,,,,,,,,,, -123200,0.7505736,2.559152,,,,,,,,,,,,,,,,, -123300,0.75280184,2.589901,,,,,,,,,,,,,,,,, -123305,,,0.7063530087471008,1.4596189260482788,36.09334633422933,0.6923534870147705,1.504960298538208,30.75848752404956,3000.0,0.7106502056121826,1.4200072288513184,31.02842778685176,3003.0,42882.12957930565,69315.98330974579,42882.12957930565,26427.96184802056,1.7832961082458496,0.0 -123400,0.7310429,2.6390753,,,,,,,,,,,,,,,,, -123500,0.7634118,2.631552,,,,,,,,,,,,,,,,, -123600,0.7211031,2.6393745,,,,,,,,,,,,,,,,, -123700,0.75252616,2.618255,,,,,,,,,,,,,,,,, -123800,0.73533946,2.5558019,,,,,,,,,,,,,,,,, -123900,0.7450604,2.6129992,,,,,,,,,,,,,,,,, -124000,0.75685704,2.6010165,,,,,,,,,,,,,,,,, -124100,0.76098305,2.615887,,,,,,,,,,,,,,,,, -124200,0.7404236,2.5881145,,,,,,,,,,,,,,,,, -124300,0.76933044,2.6135676,,,,,,,,,,,,,,,,, -124400,0.75044316,2.5718477,,,,,,,,,,,,,,,,, -124500,0.73321164,2.5756333,,,,,,,,,,,,,,,,, -124600,0.75150394,2.624206,,,,,,,,,,,,,,,,, -124700,0.73374146,2.597656,,,,,,,,,,,,,,,,, -124800,0.7718618,2.575733,,,,,,,,,,,,,,,,, -124900,0.78040105,2.6010835,,,,,,,,,,,,,,,,, -125000,0.7540457,2.6365156,,,,,,,,,,,,,,,,, -125100,0.75505453,2.6329005,,,,,,,,,,,,,,,,, -125200,0.75505567,2.5730677,,,,,,,,,,,,,,,,, -125300,0.73072016,2.5349126,,,,,,,,,,,,,,,,, -125400,0.7407254,2.6222575,,,,,,,,,,,,,,,,, -125500,0.75003576,2.5794945,,,,,,,,,,,,,,,,, -125600,0.75059396,2.6312876,,,,,,,,,,,,,,,,, -125700,0.77140075,2.596064,,,,,,,,,,,,,,,,, -125724,,,0.7098019123077393,1.4383488893508911,36.35349656081629,0.6933577656745911,1.502547264099121,30.86917300683674,3000.0,0.710615336894989,1.4180010557174685,31.023551642698948,3003.0,43722.28950881958,70639.37914276123,43722.28950881958,26911.074389457703,1.8275110721588133,0.0 -125800,0.7545237,2.5999389,,,,,,,,,,,,,,,,, -125900,0.77467686,2.5625143,,,,,,,,,,,,,,,,, -126000,0.72039086,2.5773726,,,,,,,,,,,,,,,,, -126100,0.7761038,2.6074734,,,,,,,,,,,,,,,,, -126200,0.76406306,2.587995,,,,,,,,,,,,,,,,, -126300,0.7392466,2.555898,,,,,,,,,,,,,,,,, -126400,0.7427572,2.5581486,,,,,,,,,,,,,,,,, -126500,0.7497156,2.6270764,,,,,,,,,,,,,,,,, -126600,0.737107,2.6097283,,,,,,,,,,,,,,,,, -126700,0.7054489,2.5398905,,,,,,,,,,,,,,,,, -126800,0.75510466,2.5716453,,,,,,,,,,,,,,,,, -126900,0.74702144,2.5596385,,,,,,,,,,,,,,,,, -127000,0.7280771,2.564783,,,,,,,,,,,,,,,,, -127100,0.75204146,2.5714362,,,,,,,,,,,,,,,,, -127200,0.75791943,2.5816815,,,,,,,,,,,,,,,,, -127300,0.7694216,2.599008,,,,,,,,,,,,,,,,, -127400,0.7442244,2.5741737,,,,,,,,,,,,,,,,, -127500,0.74886507,2.6022415,,,,,,,,,,,,,,,,, -127600,0.7512391,2.5520988,,,,,,,,,,,,,,,,, -127700,0.75813365,2.614071,,,,,,,,,,,,,,,,, -127800,0.7439743,2.5637817,,,,,,,,,,,,,,,,, -127900,0.73912406,2.5816715,,,,,,,,,,,,,,,,, -128000,0.7515736,2.5717049,,,,,,,,,,,,,,,,, -128100,0.7584066,2.5723433,,,,,,,,,,,,,,,,, -128142,,,0.7098627090454102,1.4390041828155518,36.20999485164631,0.6937545537948608,1.502580165863037,30.961534202035583,3000.0,0.7110220193862915,1.4158810377120972,31.07880207453981,3003.0,44562.2737493515,71956.04916667938,44562.2737493515,27387.63648557663,1.872190237045288,0.0 -128200,0.71139914,2.5722103,,,,,,,,,,,,,,,,, -128300,0.77673054,2.5821187,,,,,,,,,,,,,,,,, -128400,0.74857056,2.5578659,,,,,,,,,,,,,,,,, -128500,0.75294393,2.5790744,,,,,,,,,,,,,,,,, -128600,0.7076852,2.5330832,,,,,,,,,,,,,,,,, -128700,0.7634585,2.5360258,,,,,,,,,,,,,,,,, -128800,0.7668858,2.6167595,,,,,,,,,,,,,,,,, -128900,0.75305414,2.5622454,,,,,,,,,,,,,,,,, -129000,0.77825886,2.593336,,,,,,,,,,,,,,,,, -129100,0.7309951,2.5878913,,,,,,,,,,,,,,,,, -129200,0.7390471,2.6180277,,,,,,,,,,,,,,,,, -129300,0.742404,2.6100943,,,,,,,,,,,,,,,,, -129400,0.7495616,2.5769658,,,,,,,,,,,,,,,,, -129500,0.75170094,2.5270245,,,,,,,,,,,,,,,,, -129600,0.72767764,2.593807,,,,,,,,,,,,,,,,, -129700,0.78102136,2.6700172,,,,,,,,,,,,,,,,, -129800,0.7385933,2.5846193,,,,,,,,,,,,,,,,, -129900,0.7669713,2.5311434,,,,,,,,,,,,,,,,, -130000,0.7844386,2.6008344,,,,,,,,,,,,,,,,, -130100,0.7606846,2.5705922,,,,,,,,,,,,,,,,, -130200,0.7744084,2.5890365,,,,,,,,,,,,,,,,, -130300,0.7422165,2.587865,,,,,,,,,,,,,,,,, -130400,0.7457731,2.5805118,,,,,,,,,,,,,,,,, -130500,0.7186012,2.5742497,,,,,,,,,,,,,,,,, -130560,,,0.7109185457229614,1.434372901916504,36.323414242246834,0.6939157247543335,1.5021934509277344,30.871359894567888,3000.0,0.711266040802002,1.415252923965454,31.17882674781833,3003.0,45402.22863721848,73272.68795681,45402.22863721848,27864.195012569427,1.916002988815308,0.0 -130600,0.7696635,2.5863528,,,,,,,,,,,,,,,,, -130700,0.74296206,2.5936825,,,,,,,,,,,,,,,,, -130800,0.7397483,2.6221054,,,,,,,,,,,,,,,,, -130900,0.74153656,2.5468194,,,,,,,,,,,,,,,,, -131000,0.7513337,2.5749564,,,,,,,,,,,,,,,,, -131100,0.7328603,2.5751445,,,,,,,,,,,,,,,,, -131200,0.75260574,2.5036411,,,,,,,,,,,,,,,,, -131300,0.76402086,2.5732946,,,,,,,,,,,,,,,,, -131400,0.7404627,2.5463014,,,,,,,,,,,,,,,,, -131500,0.7446334,2.563341,,,,,,,,,,,,,,,,, -131600,0.7630706,2.633473,,,,,,,,,,,,,,,,, -131700,0.7515491,2.6247504,,,,,,,,,,,,,,,,, -131800,0.76238066,2.6288471,,,,,,,,,,,,,,,,, -131900,0.726754,2.5421689,,,,,,,,,,,,,,,,, -132000,0.7411047,2.5569544,,,,,,,,,,,,,,,,, -132100,0.7501872,2.580158,,,,,,,,,,,,,,,,, -132200,0.7554857,2.5910034,,,,,,,,,,,,,,,,, -132300,0.7486301,2.64979,,,,,,,,,,,,,,,,, -132400,0.7498629,2.563368,,,,,,,,,,,,,,,,, -132500,0.7542315,2.5674505,,,,,,,,,,,,,,,,, -132600,0.7321159,2.6271079,,,,,,,,,,,,,,,,, -132700,0.7514801,2.535298,,,,,,,,,,,,,,,,, -132800,0.7248462,2.4928145,,,,,,,,,,,,,,,,, -132900,0.74745804,2.5920515,,,,,,,,,,,,,,,,, -132978,,,0.7064332365989685,1.4602043628692627,36.41087151297877,0.6937917470932007,1.502265214920044,30.904665803245663,3000.0,0.7113938927650452,1.4153876304626465,31.23585328606663,3003.0,46242.14085435867,74583.09764313698,46242.14085435867,28334.5637383461,1.960179090499878,0.0 -133000,0.759186,2.6067953,,,,,,,,,,,,,,,,, -133100,0.73461753,2.589456,,,,,,,,,,,,,,,,, -133200,0.7430192,2.633551,,,,,,,,,,,,,,,,, -133300,0.7260182,2.5888143,,,,,,,,,,,,,,,,, -133333,,,0.7103666067123413,1.433834433555603,36.36197193300483,0.6937669515609741,1.502273440361023,30.9020444162084,3000.0,0.7113706469535828,1.415420651435852,31.22193609040821,3003.0,46365.04322433472,75180.2390575409,46365.04322433472,28808.747275829315,2.0037403106689453,0.0 -133333,,,,,,,,,,,,,,46365.04322433472,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 79b8edf65..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -861.9812302589417,0.0,28.013572454452515,1,0,28.013572454452515,0.0007088489946909,0.0,10.966498374938965,3003,889.9948675632477,0.0006404464365914,0.0,10.957476615905762,0.0004835649742744,0.0,10.980294227600098,3000 -1452.7302029132843,0.0190260410308837,868.0313177108765,2413,0,868.0313177108765,0.4059613049030304,9.371626452081102,4.175229549407959,3003,2320.8611991405487,0.4285485744476318,15.366745138241075,3.919285297393799,0.4183457195758819,10.708148967332471,4.008279800415039,3000 -1923.3364017009733,0.0459742546081543,1708.2604315280914,4825,0,1708.2604315280914,0.5490791201591492,19.09038173838629,2.7924811840057373,3003,3631.8051176071167,0.5466793179512024,24.074602836910767,2.79789137840271,0.5493918061256409,20.89684893807884,2.760988473892212,3000 -2385.8916296958923,0.0791592597961425,2548.210542678833,7238,0,2548.210542678833,0.5952356457710266,22.04175614493549,2.362957000732422,3003,4934.42494559288,0.582229733467102,25.84377936152035,2.4481654167175293,0.5931110382080078,23.34587319084209,2.375218152999878,3000 -2843.837694168091,0.1052148342132568,3388.2512698173523,9653,0,3388.2512698173523,0.6251118779182434,24.447402386344702,2.1336803436279297,3003,6232.517226934433,0.603490948677063,28.78597901242265,2.285142183303833,0.6179216504096985,25.328648638725443,2.1735126972198486,3000 -3303.861836194992,0.1316878795623779,4228.148333311081,12068,0,4228.148333311081,0.6416710615158081,25.192112223007232,1.9897395372390747,3003,7532.544618368149,0.6128459572792053,29.37454607756388,2.18766450881958,0.6334205269813538,26.08702096366193,2.038684844970703,3000 -3743.176248073578,0.162806749343872,5068.229043722153,14483,0,5068.229043722153,0.6509325504302979,26.48782476873385,1.8961308002471924,3003,8812.05256319046,0.6248981952667236,30.089301867317968,2.080275535583496,0.641777515411377,27.244437164531448,1.9527431726455688,3000 -4210.046847343445,0.1902122497558593,5908.441868543625,16900,0,5908.441868543625,0.6578467488288879,26.744218529681,1.8625454902648928,3003,10119.24120926857,0.6296117901802063,30.56911605973895,2.061455726623535,0.6512380242347717,27.760250709466792,1.9128708839416504,3000 -4706.243428945541,0.2173449993133545,6748.534591674805,19316,0,6748.534591674805,0.6642031669616699,27.05667497708177,1.809813857078552,3003,11455.637856960297,0.6452200412750244,31.575388848299944,1.942769169807434,0.6562596559524536,28.11841567073765,1.870498538017273,3000 -5181.89380979538,0.2472145557403564,7588.722680091858,21732,0,7588.722680091858,0.6690372824668884,27.950302449347426,1.7705334424972534,3003,12771.588398694992,0.6420875787734985,31.35088749641315,1.9512648582458496,0.6596198081970215,28.56932057810524,1.8318740129470823,3000 -5662.523268461227,0.2762696743011474,8428.657732963562,24147,0,8428.657732963562,0.6719539761543274,28.07511583204494,1.7517828941345217,3003,14092.263476133348,0.6424695253372192,31.319786346641795,1.9587422609329224,0.6618020534515381,28.48977382265801,1.812417268753052,3000 -6321.172736406326,0.3053386211395263,9268.82431268692,26562,0,9268.82431268692,0.6746034622192383,28.215674239974483,1.7212979793548584,3003,15591.191541194916,0.6491546630859375,32.128709662827625,1.897887825965881,0.6637735366821289,28.636768559280444,1.7884488105773926,3000 -6805.479952096939,0.3381767272949219,10108.904735088348,28978,0,10108.904735088348,0.6767067909240723,28.1050435110422,1.701871633529663,3003,16915.69260573387,0.6483582854270935,31.718846363850467,1.9046748876571653,0.6666997075080872,28.82188555062492,1.7674710750579834,3000 -7331.497166872024,0.3668546676635742,10948.82987356186,31393,0,10948.82987356186,0.6804369688034058,28.100285054452986,1.6927775144577026,3003,18281.74568796158,0.6834477782249451,34.22825500162002,1.6875489950180054,0.6686215996742249,28.802539963238083,1.760544776916504,3000 -7941.9084758758545,0.3961889743804931,11788.780126810074,33808,0,11788.780126810074,0.681715190410614,28.551759458820776,1.691118240356445,3003,19732.2159409523,0.6541592478752136,31.89343011566094,1.8681278228759768,0.6689191460609436,29.06737236228412,1.7596172094345093,3000 -8462.458575487137,0.4259674549102783,12629.00738811493,36224,0,12629.00738811493,0.6827842593193054,28.751117281827,1.6493988037109375,3003,21093.10581278801,0.654525101184845,32.16741371237828,1.8604371547698968,0.673159658908844,29.286452756607737,1.7235437631607056,3000 -8911.495740890503,0.4579017162322998,13469.150447130203,38641,0,13469.150447130203,0.6843181848526001,28.69606835632274,1.6504517793655396,3003,22382.396797180176,0.6630606055259705,32.4857500217436,1.7945261001586914,0.6725273132324219,29.01776935916361,1.7204879522323608,3000 -9388.679366111755,0.4881284236907959,14309.070579051971,41057,0,14309.070579051971,0.6871187090873718,29.00677069047909,1.634048342704773,3003,23699.610480308533,0.6570492386817932,32.32277040187047,1.825039505958557,0.6732960343360901,29.21355250631401,1.7080892324447632,3000 -9852.615337371826,0.5199141502380371,15149.12542128563,43472,0,15149.12542128563,0.6887455582618713,28.902835204010888,1.635982632637024,3003,25003.713657855988,0.6580873727798462,32.67865687974197,1.8422174453735352,0.6763338446617126,29.45904070467425,1.7116023302078247,3000 -10475.447494745256,0.5502684116363525,15989.273291826248,45888,0,15989.273291826248,0.6890825629234314,28.80330193904192,1.618253231048584,3003,26466.803416490555,0.6625374555587769,32.6860518428758,1.791016936302185,0.6753295063972473,29.099113091513008,1.6926016807556152,3000 -10964.418253660202,0.5813858509063721,16829.422029733658,48304,0,16829.422029733658,0.6889082789421082,29.31796205687577,1.6244529485702517,3003,27796.0339281559,0.6612311601638794,32.39720039613303,1.815684676170349,0.6767057776451111,29.48392422047256,1.702085256576538,3000 -11437.406776428224,0.6123223304748535,17669.506113767624,50720,0,17669.506113767624,0.6897914409637451,29.17607198808368,1.608918070793152,3003,29109.21942853928,0.6744092702865601,33.865250969584935,1.718177080154419,0.6792972087860107,29.988029477140355,1.6851235628128052,3000 -11987.390083789824,0.6450626850128174,18509.686230421063,53136,0,18509.686230421063,0.6934635043144226,29.42710453871561,1.597248673439026,3003,30499.49797201157,0.6658750772476196,32.865700967124376,1.772932529449463,0.677821695804596,29.386465831879853,1.675114989280701,3000 -12460.27916264534,0.6809508800506592,19349.803040981293,55552,0,19349.803040981293,0.6937540173530579,29.22578359814461,1.5995779037475586,3003,31812.61920762062,0.6659360527992249,32.94168061929005,1.7834343910217283,0.6789996027946472,29.670184999573905,1.6796774864196775,3000 -12921.206790924072,0.71565842628479,20189.68864178657,57968,0,20189.68864178657,0.6937307715415955,29.51395797552482,1.5895719528198242,3003,33113.54695224762,0.6726042032241821,33.86621185465534,1.7362380027770996,0.6793344020843506,29.51395916056389,1.6724979877471924,3000 -13383.99617767334,0.7550392150878906,21029.69061565399,60383,0,21029.69061565399,0.6942071914672852,29.482354140661048,1.5787086486816406,3003,34416.46136879921,0.6659606695175171,33.31700121865782,1.7663089036941528,0.6792724132537842,30.04262259680157,1.6618462800979614,3000 -13954.791036367416,0.7893767356872559,21869.710822820663,62799,0,21869.710822820663,0.6970425844192505,29.63074699460067,1.576493263244629,3003,35827.3902528286,0.6952691674232483,34.85920428088039,1.5940762758255005,0.6813182830810547,30.02150473419823,1.6626694202423096,3000 -14495.03400874138,0.8237001895904541,22709.933475017548,65215,0,22709.933475017548,0.6985416412353516,29.660603258547127,1.574635028839111,3003,37207.97281885147,0.6770080327987671,33.7335961330243,1.7094494104385376,0.6842692494392395,30.38194529115567,1.6536860466003418,3000 -14988.952271461489,0.8565609455108643,23549.856046438217,67630,0,23549.856046438217,0.6962407827377319,29.662369265097585,1.5660450458526611,3003,38541.92924189568,0.6708812117576599,33.23434424702395,1.747066617012024,0.6815910339355469,29.844001559891947,1.65132737159729,3000 -15469.488009691238,0.8931279182434082,24390.056773662567,70045,0,24390.056773662567,0.697623610496521,29.73855557334434,1.5698145627975464,3003,39862.7853975296,0.6814951300621033,34.03927182611093,1.671489715576172,0.6828433275222778,30.02545059786362,1.652961492538452,3000 -15953.917650938034,0.9359502792358398,25230.09949684143,72461,0,25230.09949684143,0.6997618079185486,29.75574832873671,1.559238314628601,3003,41187.3813123703,0.6753785014152527,33.9016781817681,1.71561861038208,0.6829301714897156,30.174211488067066,1.6497820615768433,3000 -16500.258165597916,0.977301836013794,26069.99861884117,74876,0,26069.99861884117,0.7012027502059937,30.20991710316605,1.540870189666748,3003,42573.744396448135,0.6737843155860901,33.94385616955503,1.714847445487976,0.6842072606086731,30.11589849755397,1.6327464580535889,3000 -17024.11154460907,1.0152418613433838,26909.980088472366,77292,0,26909.980088472366,0.7004590034484863,30.350398961195094,1.5464823246002195,3003,43937.69822263718,0.6831181645393372,34.21749717101149,1.6621711254119873,0.685459554195404,30.35544077774712,1.6368142366409302,3000 -17480.916483163834,1.0527923107147217,27749.933834314343,79708,0,27749.933834314343,0.7026204466819763,30.37668151012431,1.5370594263076782,3003,45234.57467293739,0.6789937019348145,34.13086744144354,1.6847949028015137,0.6863647103309631,30.33725457330963,1.625876784324646,3000 -17937.84757256508,1.0906941890716553,28590.127032995224,82124,0,28590.127032995224,0.7039335370063782,30.4180427003378,1.5355430841445925,3003,46531.81851029396,0.6997047066688538,35.5428039352093,1.573757290840149,0.6866622567176819,30.64508817624996,1.6228492259979248,3000 -18416.06471323967,1.1282691955566406,29430.197852134705,84539,0,29430.197852134705,0.7041078805923462,30.54304867956124,1.5266790390014648,3003,47850.226432323456,0.682442307472229,34.606211236537106,1.6693923473358154,0.688286542892456,30.211782980971023,1.6197954416275024,3000 -18922.261901140213,1.1654300689697266,30270.317069530487,86955,0,30270.317069530487,0.7032247185707092,30.31646394999133,1.533256769180298,3003,49196.66105890274,0.6825690269470215,34.71119184261502,1.666361689567566,0.6877037882804871,30.182893205006323,1.6188981533050537,3000 -19389.877872228622,1.2036826610565186,31110.298770189285,89370,0,31110.298770189285,0.705083966255188,30.645274576054728,1.517716407775879,3003,50504.37959980965,0.6950032711029053,35.154835319669424,1.60042405128479,0.6891793012619019,30.554063056987584,1.609760046005249,3000 -19851.59775352478,1.242016077041626,31950.33271098137,91785,0,31950.33271098137,0.7055255770683289,30.73101495101404,1.5153443813323977,3003,51806.25148367882,0.6907691955566406,34.67632533933311,1.61209237575531,0.6889809370040894,30.47947166718364,1.6088262796401978,3000 -20333.17138814926,1.2815396785736084,32790.51075673103,94200,0,32790.51075673103,0.7057114839553833,30.51289879789388,1.514120101928711,3003,53128.12596178055,0.7140809893608093,36.6485987570553,1.4977169036865234,0.6901712417602539,30.60468141212876,1.6079293489456177,3000 -20813.4673306942,1.321253538131714,33630.70009255409,96616,0,33630.70009255409,0.7062111496925354,30.341710528802896,1.511307954788208,3003,54448.7342133522,0.6998093128204346,35.22440418347292,1.5723187923431396,0.6890801191329956,30.391074376933528,1.608446478843689,3000 -21287.758934021,1.3601250648498535,34470.83130598068,99032,0,34470.83130598068,0.706362247467041,30.670146910166952,1.5094925165176392,3003,55763.27522611618,0.6960033774375916,35.784215910569934,1.58683180809021,0.6911879777908325,30.653904813297338,1.6049529314041138,3000 -21750.94939732552,1.3992786407470703,35310.982800245285,101448,0,35310.982800245285,0.7069781422615051,30.958594252207657,1.5060288906097412,3003,57066.73641419411,0.7109413743019104,36.498945434107455,1.5052330493927002,0.6911631226539612,30.694062047415063,1.6019834280014038,3000 -22215.19164896012,1.43782639503479,36151.15802383423,103863,0,36151.15802383423,0.7081633806228638,30.9335101099981,1.5044052600860596,3003,58371.27501010895,0.6996182203292847,36.38101795307617,1.565675139427185,0.6910267472267151,30.47832296179434,1.6026326417922974,3000 -22687.063174963,1.4773857593536377,36991.27319264412,106278,0,36991.27319264412,0.7085352540016174,30.86064708275572,1.5015463829040527,3003,59683.38256788254,0.7071945071220398,35.68798725043688,1.5247044563293457,0.6906920075416565,30.56966933807441,1.6015530824661257,3000 -23157.18415951729,1.517770767211914,37831.44410777092,108693,0,37831.44410777092,0.7084887623786926,30.698081478592425,1.4992235898971558,3003,60993.79703378677,0.7097283601760864,36.4322003452034,1.5101783275604248,0.691696286201477,30.72172265147943,1.6009178161621094,3000 -23627.291821718216,1.559746265411377,38671.49594020844,111109,0,38671.49594020844,0.7071989178657532,30.854497798748653,1.5012623071670532,3003,62304.079092502594,0.7096648812294006,36.62844563646342,1.506799578666687,0.6907168030738831,30.522907929975723,1.6045445203781128,3000 -24088.021463871,1.606471061706543,39511.46545481682,113523,0,39511.46545481682,0.7090232968330383,30.81991409290025,1.4939895868301392,3003,63604.90850496292,0.7205490469932556,37.3473881454824,1.4559381008148191,0.6923410892486572,30.74776238203665,1.5988622903823853,3000 -24553.03221011161,1.6485328674316406,40351.63733792305,115939,0,40351.63733792305,0.7080239653587341,30.801907041652694,1.49728524684906,3003,64910.21405529976,0.7139232158660889,36.63052469599194,1.487220048904419,0.6912499666213989,30.716334354820788,1.601006269454956,3000 -25012.41903567314,1.6906397342681885,41191.82350349426,118355,0,41191.82350349426,0.7077218294143677,30.7039356366018,1.4967893362045288,3003,66209.90850758553,0.7143579721450806,37.03462724545019,1.484772801399231,0.6923038959503174,30.58042947513172,1.6000486612319946,3000 -25484.83347582817,1.7393221855163574,42031.85128903389,120769,0,42031.85128903389,0.7095927000045776,30.727894032551443,1.4932905435562134,3003,67522.48354244232,0.7247974872589111,37.53693442719608,1.4364838600158691,0.6926262378692627,30.67829971443381,1.5970616340637207,3000 -25945.196885108948,1.7829713821411133,42871.75427532196,123184,0,42871.75427532196,0.7095230221748352,30.812469473793247,1.493468999862671,3003,68822.87196969986,0.7164363265037537,37.23399471055593,1.4766035079956057,0.6918575167655945,30.603763521386373,1.5992780923843384,3000 -26413.436207294464,1.82772159576416,43711.76387667656,125598,0,43711.76387667656,0.7091279029846191,30.822391857216715,1.4928165674209597,3003,70131.25042271614,0.7242320775985718,37.27877258251524,1.4370570182800293,0.6923782825469971,30.46979271110354,1.598088026046753,3000 -26876.707780361176,1.870469331741333,44551.90661859512,128013,0,44551.90661859512,0.7096856832504272,30.728892446349207,1.493368148803711,3003,71434.78941488266,0.7244399785995483,37.314039333533486,1.435945749282837,0.6923162937164307,30.50014741034056,1.597839117050171,3000 -27340.31756520272,1.9223430156707764,45392.08640527725,130429,0,45392.08640527725,0.7091395258903503,30.76292408930312,1.493313193321228,3003,72738.71410131454,0.7212345004081726,37.38313524691233,1.4529694318771362,0.6924278736114502,30.58731194343409,1.598196029663086,3000 -27812.481696128845,1.968170166015625,46231.99690008164,132843,0,46231.99690008164,0.7089187502861023,30.805018297453067,1.4930022954940796,3003,74050.91628265381,0.7237308025360107,37.43974567291994,1.441246747970581,0.6926634311676025,30.54856858515057,1.598312497138977,3000 -28281.60786294937,2.0151820182800293,46401.92576980591,133333,0,46401.92576980591,0.7088838815689087,30.78181083377449,1.4929988384246826,3003,74690.03471302986,0.7221422791481018,37.50803900245006,1.4469060897827148,0.6926758289337158,30.54901382135056,1.598305106163025,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/measurements.csv deleted file mode 100644 index f809a37ea..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.718122,10.972803,,,,,,,,,,,,,,,,, -1,,,0.0006404464365914,10.957476615905762,0.0,0.0004835649742744,10.980294227600098,0.0,3000.0,0.0007088489946909,10.966498374938965,0.0,3003.0,28.013572454452515,889.9948675632477,28.013572454452515,861.9812302589417,0.0,0.0 -100,0.26217154,9.080983,,,,,,,,,,,,,,,,, -200,0.20238712,8.686631,,,,,,,,,,,,,,,,, -300,0.86200863,8.31746,,,,,,,,,,,,,,,,, -400,0.702041,8.102167,,,,,,,,,,,,,,,,, -500,1.0521156,7.861621,,,,,,,,,,,,,,,,, -600,0.6750308,7.653686,,,,,,,,,,,,,,,,, -700,0.80034643,7.5326624,,,,,,,,,,,,,,,,, -800,0.60453415,7.2980566,,,,,,,,,,,,,,,,, -900,0.58557457,7.1105933,,,,,,,,,,,,,,,,, -1000,0.7684279,7.0089364,,,,,,,,,,,,,,,,, -1100,0.5497048,6.8114786,,,,,,,,,,,,,,,,, -1200,0.55962455,6.7065144,,,,,,,,,,,,,,,,, -1300,0.5158026,6.641334,,,,,,,,,,,,,,,,, -1400,0.6539942,6.5464277,,,,,,,,,,,,,,,,, -1500,0.7065697,6.43843,,,,,,,,,,,,,,,,, -1600,0.60874933,6.304315,,,,,,,,,,,,,,,,, -1700,0.60077995,6.2144494,,,,,,,,,,,,,,,,, -1800,0.7954559,6.1620245,,,,,,,,,,,,,,,,, -1900,0.7098702,6.023779,,,,,,,,,,,,,,,,, -2000,0.6988197,5.9722896,,,,,,,,,,,,,,,,, -2100,0.63756204,5.824293,,,,,,,,,,,,,,,,, -2200,0.5037304,5.7686286,,,,,,,,,,,,,,,,, -2300,0.77904904,5.699641,,,,,,,,,,,,,,,,, -2400,0.4909803,5.5310774,,,,,,,,,,,,,,,,, -2413,,,0.4285485744476318,3.919285297393799,15.366745138241075,0.4183457195758819,4.008279800415039,10.708148967332471,3000.0,0.4059613049030304,4.175229549407959,9.371626452081102,3003.0,868.0313177108765,2320.8611991405487,868.0313177108765,1452.7302029132843,0.0190260410308837,0.0 -2500,0.65532875,5.5900655,,,,,,,,,,,,,,,,, -2600,0.54390574,5.376059,,,,,,,,,,,,,,,,, -2700,0.77021927,5.4274793,,,,,,,,,,,,,,,,, -2800,0.7414649,5.3610945,,,,,,,,,,,,,,,,, -2900,0.5093289,5.238914,,,,,,,,,,,,,,,,, -3000,0.5242716,5.2140055,,,,,,,,,,,,,,,,, -3100,0.53530085,5.185772,,,,,,,,,,,,,,,,, -3200,0.46857914,5.1332335,,,,,,,,,,,,,,,,, -3300,0.5407777,5.0374513,,,,,,,,,,,,,,,,, -3400,0.66170865,5.0832005,,,,,,,,,,,,,,,,, -3500,0.47535637,4.986652,,,,,,,,,,,,,,,,, -3600,0.5021697,5.030281,,,,,,,,,,,,,,,,, -3700,0.45826206,4.959399,,,,,,,,,,,,,,,,, -3800,0.4913259,4.9507594,,,,,,,,,,,,,,,,, -3900,0.42912933,4.899036,,,,,,,,,,,,,,,,, -4000,0.40886965,4.9090014,,,,,,,,,,,,,,,,, -4100,0.44358784,4.8754725,,,,,,,,,,,,,,,,, -4200,0.49942487,4.8544226,,,,,,,,,,,,,,,,, -4300,0.41098085,4.796178,,,,,,,,,,,,,,,,, -4400,0.42510644,4.7789793,,,,,,,,,,,,,,,,, -4500,0.4286625,4.7460785,,,,,,,,,,,,,,,,, -4600,0.4168868,4.751547,,,,,,,,,,,,,,,,, -4700,0.36536014,4.708461,,,,,,,,,,,,,,,,, -4800,0.40667206,4.70608,,,,,,,,,,,,,,,,, -4825,,,0.5466793179512024,2.79789137840271,24.074602836910767,0.5493918061256409,2.760988473892212,20.89684893807884,3000.0,0.5490791201591492,2.7924811840057373,19.09038173838629,3003.0,1708.2604315280914,3631.8051176071167,1708.2604315280914,1923.3364017009733,0.0459742546081543,0.0 -4900,0.35252914,4.7009044,,,,,,,,,,,,,,,,, -5000,0.36115822,4.694727,,,,,,,,,,,,,,,,, -5100,0.33185712,4.643216,,,,,,,,,,,,,,,,, -5200,0.31707326,4.6330934,,,,,,,,,,,,,,,,, -5300,0.34737486,4.607986,,,,,,,,,,,,,,,,, -5400,0.34914544,4.6245704,,,,,,,,,,,,,,,,, -5500,0.33914715,4.644703,,,,,,,,,,,,,,,,, -5600,0.3415663,4.6280317,,,,,,,,,,,,,,,,, -5700,0.35512805,4.6170144,,,,,,,,,,,,,,,,, -5800,0.3325846,4.5825233,,,,,,,,,,,,,,,,, -5900,0.31719297,4.5755634,,,,,,,,,,,,,,,,, -6000,0.29526815,4.558788,,,,,,,,,,,,,,,,, -6100,0.31876808,4.5817823,,,,,,,,,,,,,,,,, -6200,0.2787825,4.5798635,,,,,,,,,,,,,,,,, -6300,0.30351037,4.509255,,,,,,,,,,,,,,,,, -6400,0.2686023,4.512942,,,,,,,,,,,,,,,,, -6500,0.2586723,4.539218,,,,,,,,,,,,,,,,, -6600,0.28254256,4.5737495,,,,,,,,,,,,,,,,, -6700,0.25596514,4.5147123,,,,,,,,,,,,,,,,, -6800,0.24393368,4.4847555,,,,,,,,,,,,,,,,, -6900,0.29258758,4.515305,,,,,,,,,,,,,,,,, -7000,0.22852707,4.464119,,,,,,,,,,,,,,,,, -7100,0.29699287,4.52708,,,,,,,,,,,,,,,,, -7200,0.26832396,4.4357166,,,,,,,,,,,,,,,,, -7238,,,0.582229733467102,2.4481654167175293,25.84377936152035,0.5931110382080078,2.375218152999878,23.34587319084209,3000.0,0.5952356457710266,2.362957000732422,22.04175614493549,3003.0,2548.210542678833,4934.42494559288,2548.210542678833,2385.8916296958923,0.0791592597961425,0.0 -7300,0.2799663,4.411298,,,,,,,,,,,,,,,,, -7400,0.22520222,4.3654594,,,,,,,,,,,,,,,,, -7500,0.22528253,4.382702,,,,,,,,,,,,,,,,, -7600,0.22612941,4.3521023,,,,,,,,,,,,,,,,, -7700,0.23018588,4.384777,,,,,,,,,,,,,,,,, -7800,0.224747,4.4005413,,,,,,,,,,,,,,,,, -7900,0.23131342,4.457604,,,,,,,,,,,,,,,,, -8000,0.21210681,4.347127,,,,,,,,,,,,,,,,, -8100,0.21429272,4.3690205,,,,,,,,,,,,,,,,, -8200,0.20308463,4.370054,,,,,,,,,,,,,,,,, -8300,0.2255682,4.340023,,,,,,,,,,,,,,,,, -8400,0.22205852,4.3611984,,,,,,,,,,,,,,,,, -8500,0.18517877,4.2574735,,,,,,,,,,,,,,,,, -8600,0.20493354,4.258499,,,,,,,,,,,,,,,,, -8700,0.19978671,4.332851,,,,,,,,,,,,,,,,, -8800,0.25720483,4.273906,,,,,,,,,,,,,,,,, -8900,0.2424546,4.2878594,,,,,,,,,,,,,,,,, -9000,0.21048412,4.3906617,,,,,,,,,,,,,,,,, -9100,0.18154663,4.3234215,,,,,,,,,,,,,,,,, -9200,0.20472029,4.3739867,,,,,,,,,,,,,,,,, -9300,0.18741766,4.292516,,,,,,,,,,,,,,,,, -9400,0.18918277,4.3400865,,,,,,,,,,,,,,,,, -9500,0.17886783,4.256828,,,,,,,,,,,,,,,,, -9600,0.20181073,4.266094,,,,,,,,,,,,,,,,, -9653,,,0.603490948677063,2.285142183303833,28.78597901242265,0.6179216504096985,2.1735126972198486,25.328648638725443,3000.0,0.6251118779182434,2.1336803436279297,24.447402386344702,3003.0,3388.2512698173523,6232.517226934433,3388.2512698173523,2843.837694168091,0.1052148342132568,0.0 -9700,0.18435355,4.204712,,,,,,,,,,,,,,,,, -9800,0.17888819,4.2729664,,,,,,,,,,,,,,,,, -9900,0.18402956,4.2505026,,,,,,,,,,,,,,,,, -10000,0.18308285,4.2193227,,,,,,,,,,,,,,,,, -10100,0.18000115,4.3152795,,,,,,,,,,,,,,,,, -10200,0.20928021,4.241019,,,,,,,,,,,,,,,,, -10300,0.17631656,4.238051,,,,,,,,,,,,,,,,, -10400,0.17558616,4.233207,,,,,,,,,,,,,,,,, -10500,0.1735815,4.319882,,,,,,,,,,,,,,,,, -10600,0.17791645,4.246314,,,,,,,,,,,,,,,,, -10700,0.19412684,4.2882285,,,,,,,,,,,,,,,,, -10800,0.17048128,4.2452936,,,,,,,,,,,,,,,,, -10900,0.17777489,4.2176046,,,,,,,,,,,,,,,,, -11000,0.2195508,4.3305173,,,,,,,,,,,,,,,,, -11100,0.17678735,4.1672163,,,,,,,,,,,,,,,,, -11200,0.16485918,4.172787,,,,,,,,,,,,,,,,, -11300,0.16273887,4.224958,,,,,,,,,,,,,,,,, -11400,0.19789875,4.20639,,,,,,,,,,,,,,,,, -11500,0.17331962,4.2208047,,,,,,,,,,,,,,,,, -11600,0.16269565,4.1800637,,,,,,,,,,,,,,,,, -11700,0.16219518,4.2304034,,,,,,,,,,,,,,,,, -11800,0.16227578,4.1212535,,,,,,,,,,,,,,,,, -11900,0.168709,4.2076616,,,,,,,,,,,,,,,,, -12000,0.18081452,4.226919,,,,,,,,,,,,,,,,, -12068,,,0.6128459572792053,2.18766450881958,29.37454607756388,0.6334205269813538,2.038684844970703,26.08702096366193,3000.0,0.6416710615158081,1.9897395372390747,25.192112223007232,3003.0,4228.148333311081,7532.544618368149,4228.148333311081,3303.861836194992,0.1316878795623779,0.0 -12100,0.16065511,4.14626,,,,,,,,,,,,,,,,, -12200,0.150078,4.1954374,,,,,,,,,,,,,,,,, -12300,0.19969487,4.122411,,,,,,,,,,,,,,,,, -12400,0.17394592,4.1620493,,,,,,,,,,,,,,,,, -12500,0.17891781,4.238187,,,,,,,,,,,,,,,,, -12600,0.17158031,4.1985197,,,,,,,,,,,,,,,,, -12700,0.16678561,4.2260847,,,,,,,,,,,,,,,,, -12800,0.16242933,4.192098,,,,,,,,,,,,,,,,, -12900,0.16456239,4.1090894,,,,,,,,,,,,,,,,, -13000,0.15449701,4.1278386,,,,,,,,,,,,,,,,, -13100,0.15806215,4.160906,,,,,,,,,,,,,,,,, -13200,0.15287884,4.164647,,,,,,,,,,,,,,,,, -13300,0.16206802,4.136923,,,,,,,,,,,,,,,,, -13400,0.16659884,4.172173,,,,,,,,,,,,,,,,, -13500,0.16333257,4.137351,,,,,,,,,,,,,,,,, -13600,0.1509775,4.1299725,,,,,,,,,,,,,,,,, -13700,0.1512101,4.1469517,,,,,,,,,,,,,,,,, -13800,0.15049078,4.0783973,,,,,,,,,,,,,,,,, -13900,0.17039177,4.211523,,,,,,,,,,,,,,,,, -14000,0.14564939,4.129735,,,,,,,,,,,,,,,,, -14100,0.15036821,4.138506,,,,,,,,,,,,,,,,, -14200,0.1502819,4.142111,,,,,,,,,,,,,,,,, -14300,0.17721878,4.078578,,,,,,,,,,,,,,,,, -14400,0.14875154,4.1310987,,,,,,,,,,,,,,,,, -14483,,,0.6248981952667236,2.080275535583496,30.089301867317968,0.641777515411377,1.9527431726455688,27.244437164531448,3000.0,0.6509325504302979,1.8961308002471924,26.48782476873385,3003.0,5068.229043722153,8812.05256319046,5068.229043722153,3743.176248073578,0.162806749343872,0.0 -14500,0.15675183,4.0696425,,,,,,,,,,,,,,,,, -14600,0.14959258,4.1291027,,,,,,,,,,,,,,,,, -14700,0.15308657,4.165112,,,,,,,,,,,,,,,,, -14800,0.16522042,4.10889,,,,,,,,,,,,,,,,, -14900,0.15036398,4.1466875,,,,,,,,,,,,,,,,, -15000,0.156108,4.087738,,,,,,,,,,,,,,,,, -15100,0.16351514,4.1629066,,,,,,,,,,,,,,,,, -15200,0.16186272,4.168757,,,,,,,,,,,,,,,,, -15300,0.17980611,4.1279664,,,,,,,,,,,,,,,,, -15400,0.14975096,4.079939,,,,,,,,,,,,,,,,, -15500,0.16608204,4.0660458,,,,,,,,,,,,,,,,, -15600,0.16030511,4.139727,,,,,,,,,,,,,,,,, -15700,0.18346469,4.1542916,,,,,,,,,,,,,,,,, -15800,0.17451389,4.0869884,,,,,,,,,,,,,,,,, -15900,0.15282868,4.1177845,,,,,,,,,,,,,,,,, -16000,0.14974682,4.006855,,,,,,,,,,,,,,,,, -16100,0.15431118,4.113213,,,,,,,,,,,,,,,,, -16200,0.1428017,4.0371633,,,,,,,,,,,,,,,,, -16300,0.1584688,4.0704503,,,,,,,,,,,,,,,,, -16400,0.19182625,4.049878,,,,,,,,,,,,,,,,, -16500,0.155091,4.1298714,,,,,,,,,,,,,,,,, -16600,0.16709581,4.009322,,,,,,,,,,,,,,,,, -16700,0.15917893,4.1302524,,,,,,,,,,,,,,,,, -16800,0.15976542,4.041918,,,,,,,,,,,,,,,,, -16900,,,0.6296117901802063,2.061455726623535,30.56911605973895,0.6512380242347717,1.9128708839416504,27.760250709466792,3000.0,0.6578467488288879,1.8625454902648928,26.744218529681,3003.0,5908.441868543625,10119.24120926857,5908.441868543625,4210.046847343445,0.1902122497558593,0.0 -16900,0.15620913,4.109905,,,,,,,,,,,,,,,,, -17000,0.1796731,4.18407,,,,,,,,,,,,,,,,, -17100,0.15549517,4.114802,,,,,,,,,,,,,,,,, -17200,0.15413104,4.041161,,,,,,,,,,,,,,,,, -17300,0.15457016,4.0925264,,,,,,,,,,,,,,,,, -17400,0.20450868,4.042363,,,,,,,,,,,,,,,,, -17500,0.15302011,4.0763574,,,,,,,,,,,,,,,,, -17600,0.15837917,4.1045675,,,,,,,,,,,,,,,,, -17700,0.15120855,4.041872,,,,,,,,,,,,,,,,, -17800,0.14350614,4.062099,,,,,,,,,,,,,,,,, -17900,0.16018696,4.040727,,,,,,,,,,,,,,,,, -18000,0.20446159,4.0592055,,,,,,,,,,,,,,,,, -18100,0.14999065,4.0103707,,,,,,,,,,,,,,,,, -18200,0.15913445,4.1417503,,,,,,,,,,,,,,,,, -18300,0.14884329,3.995737,,,,,,,,,,,,,,,,, -18400,0.18526152,4.05573,,,,,,,,,,,,,,,,, -18500,0.15489009,4.05974,,,,,,,,,,,,,,,,, -18600,0.15542008,4.0717936,,,,,,,,,,,,,,,,, -18700,0.15707462,4.004346,,,,,,,,,,,,,,,,, -18800,0.15623826,4.0775614,,,,,,,,,,,,,,,,, -18900,0.1802299,3.9881048,,,,,,,,,,,,,,,,, -19000,0.15601994,3.9981742,,,,,,,,,,,,,,,,, -19100,0.20442264,4.054589,,,,,,,,,,,,,,,,, -19200,0.1811348,4.021943,,,,,,,,,,,,,,,,, -19300,0.18115562,3.970094,,,,,,,,,,,,,,,,, -19316,,,0.6452200412750244,1.942769169807434,31.575388848299944,0.6562596559524536,1.870498538017273,28.11841567073765,3000.0,0.6642031669616699,1.809813857078552,27.05667497708177,3003.0,6748.534591674805,11455.637856960297,6748.534591674805,4706.243428945541,0.2173449993133545,0.0 -19400,0.16222757,4.0647902,,,,,,,,,,,,,,,,, -19500,0.2433044,3.9741209,,,,,,,,,,,,,,,,, -19600,0.15605085,4.1190543,,,,,,,,,,,,,,,,, -19700,0.1582164,4.0339036,,,,,,,,,,,,,,,,, -19800,0.16691186,4.0289683,,,,,,,,,,,,,,,,, -19900,0.16190474,4.0772033,,,,,,,,,,,,,,,,, -20000,0.16804793,4.026738,,,,,,,,,,,,,,,,, -20100,0.17573562,3.9712203,,,,,,,,,,,,,,,,, -20200,0.16088903,3.9843786,,,,,,,,,,,,,,,,, -20300,0.16160303,4.0481663,,,,,,,,,,,,,,,,, -20400,0.16755709,4.0807843,,,,,,,,,,,,,,,,, -20500,0.18347998,4.003741,,,,,,,,,,,,,,,,, -20600,0.17405485,3.9448948,,,,,,,,,,,,,,,,, -20700,0.1563434,3.9487827,,,,,,,,,,,,,,,,, -20800,0.16712403,4.0245004,,,,,,,,,,,,,,,,, -20900,0.16301936,3.9947033,,,,,,,,,,,,,,,,, -21000,0.15873653,4.001884,,,,,,,,,,,,,,,,, -21100,0.17244372,4.030586,,,,,,,,,,,,,,,,, -21200,0.18142863,4.045987,,,,,,,,,,,,,,,,, -21300,0.20333947,4.0401397,,,,,,,,,,,,,,,,, -21400,0.16968584,4.0398693,,,,,,,,,,,,,,,,, -21500,0.17286627,4.021098,,,,,,,,,,,,,,,,, -21600,0.16725689,3.9549372,,,,,,,,,,,,,,,,, -21700,0.16540952,3.932931,,,,,,,,,,,,,,,,, -21732,,,0.6420875787734985,1.9512648582458496,31.35088749641315,0.6596198081970215,1.8318740129470823,28.56932057810524,3000.0,0.6690372824668884,1.7705334424972534,27.950302449347426,3003.0,7588.722680091858,12771.588398694992,7588.722680091858,5181.89380979538,0.2472145557403564,0.0 -21800,0.17814462,4.000715,,,,,,,,,,,,,,,,, -21900,0.16558307,4.0580626,,,,,,,,,,,,,,,,, -22000,0.16724479,4.1192904,,,,,,,,,,,,,,,,, -22100,0.1717353,4.0656786,,,,,,,,,,,,,,,,, -22200,0.1613133,4.066546,,,,,,,,,,,,,,,,, -22300,0.18635117,4.0264735,,,,,,,,,,,,,,,,, -22400,0.18272884,4.0604825,,,,,,,,,,,,,,,,, -22500,0.16952704,4.0196347,,,,,,,,,,,,,,,,, -22600,0.16659158,4.0297995,,,,,,,,,,,,,,,,, -22700,0.20138498,3.9804945,,,,,,,,,,,,,,,,, -22800,0.15870346,3.9701695,,,,,,,,,,,,,,,,, -22900,0.17538732,4.0152946,,,,,,,,,,,,,,,,, -23000,0.18117423,3.9861042,,,,,,,,,,,,,,,,, -23100,0.21433406,3.9781468,,,,,,,,,,,,,,,,, -23200,0.20492004,4.0166235,,,,,,,,,,,,,,,,, -23300,0.16211925,4.000675,,,,,,,,,,,,,,,,, -23400,0.17984734,4.049149,,,,,,,,,,,,,,,,, -23500,0.17722821,3.9624083,,,,,,,,,,,,,,,,, -23600,0.17598492,4.039781,,,,,,,,,,,,,,,,, -23700,0.20596032,3.972359,,,,,,,,,,,,,,,,, -23800,0.1718154,3.9695804,,,,,,,,,,,,,,,,, -23900,0.19741619,4.0154624,,,,,,,,,,,,,,,,, -24000,0.17397128,3.9858997,,,,,,,,,,,,,,,,, -24100,0.17785202,3.9936547,,,,,,,,,,,,,,,,, -24147,,,0.6424695253372192,1.9587422609329224,31.319786346641795,0.6618020534515381,1.812417268753052,28.48977382265801,3000.0,0.6719539761543274,1.7517828941345217,28.07511583204494,3003.0,8428.657732963562,14092.263476133348,8428.657732963562,5662.523268461227,0.2762696743011474,0.0 -24200,0.2542359,3.9937727,,,,,,,,,,,,,,,,, -24300,0.1848127,4.0039434,,,,,,,,,,,,,,,,, -24400,0.24886133,3.9337306,,,,,,,,,,,,,,,,, -24500,0.16671956,3.9359791,,,,,,,,,,,,,,,,, -24600,0.1510605,4.0145755,,,,,,,,,,,,,,,,, -24700,0.16899824,3.9466841,,,,,,,,,,,,,,,,, -24800,0.1935019,3.9621162,,,,,,,,,,,,,,,,, -24900,0.18743767,4.0106964,,,,,,,,,,,,,,,,, -25000,0.18107735,3.9128227,,,,,,,,,,,,,,,,, -25100,0.18249293,4.0058503,,,,,,,,,,,,,,,,, -25200,0.2067789,4.031285,,,,,,,,,,,,,,,,, -25300,0.17956299,3.949326,,,,,,,,,,,,,,,,, -25400,0.19388407,3.944121,,,,,,,,,,,,,,,,, -25500,0.21790345,3.9202764,,,,,,,,,,,,,,,,, -25600,0.18927579,3.95625,,,,,,,,,,,,,,,,, -25700,0.17813358,3.9629462,,,,,,,,,,,,,,,,, -25800,0.1785952,4.016351,,,,,,,,,,,,,,,,, -25900,0.17964914,3.916852,,,,,,,,,,,,,,,,, -26000,0.17919342,3.9708478,,,,,,,,,,,,,,,,, -26100,0.19963329,3.9993873,,,,,,,,,,,,,,,,, -26200,0.18772386,3.997765,,,,,,,,,,,,,,,,, -26300,0.17703152,3.9715548,,,,,,,,,,,,,,,,, -26400,0.19772702,3.9640627,,,,,,,,,,,,,,,,, -26500,0.16723046,3.9726398,,,,,,,,,,,,,,,,, -26562,,,0.6491546630859375,1.897887825965881,32.128709662827625,0.6637735366821289,1.7884488105773926,28.636768559280444,3000.0,0.6746034622192383,1.7212979793548584,28.215674239974483,3003.0,9268.82431268692,15591.191541194916,9268.82431268692,6321.172736406326,0.3053386211395263,0.0 -26600,0.19614978,3.9786675,,,,,,,,,,,,,,,,, -26700,0.17298613,3.9195035,,,,,,,,,,,,,,,,, -26800,0.18536387,3.971221,,,,,,,,,,,,,,,,, -26900,0.18843617,3.966914,,,,,,,,,,,,,,,,, -27000,0.21240316,3.977241,,,,,,,,,,,,,,,,, -27100,0.25509378,3.9252236,,,,,,,,,,,,,,,,, -27200,0.20358518,3.985929,,,,,,,,,,,,,,,,, -27300,0.18008961,3.9844575,,,,,,,,,,,,,,,,, -27400,0.18607087,3.9502294,,,,,,,,,,,,,,,,, -27500,0.2638177,4.0259356,,,,,,,,,,,,,,,,, -27600,0.26981702,3.9802897,,,,,,,,,,,,,,,,, -27700,0.20252438,4.01982,,,,,,,,,,,,,,,,, -27800,0.19881763,3.9533942,,,,,,,,,,,,,,,,, -27900,0.20847334,3.9663756,,,,,,,,,,,,,,,,, -28000,0.1837373,3.9942582,,,,,,,,,,,,,,,,, -28100,0.2573745,3.9868755,,,,,,,,,,,,,,,,, -28200,0.2182292,3.907525,,,,,,,,,,,,,,,,, -28300,0.18023393,3.9027607,,,,,,,,,,,,,,,,, -28400,0.2901226,3.9489608,,,,,,,,,,,,,,,,, -28500,0.19037342,3.9726002,,,,,,,,,,,,,,,,, -28600,0.18156253,3.8906233,,,,,,,,,,,,,,,,, -28700,0.20262928,3.9823377,,,,,,,,,,,,,,,,, -28800,0.18665402,3.9447994,,,,,,,,,,,,,,,,, -28900,0.19187753,3.9861684,,,,,,,,,,,,,,,,, -28978,,,0.6483582854270935,1.9046748876571653,31.718846363850467,0.6666997075080872,1.7674710750579834,28.82188555062492,3000.0,0.6767067909240723,1.701871633529663,28.1050435110422,3003.0,10108.904735088348,16915.69260573387,10108.904735088348,6805.479952096939,0.3381767272949219,0.0 -29000,0.22514296,3.9902132,,,,,,,,,,,,,,,,, -29100,0.18928602,3.9418612,,,,,,,,,,,,,,,,, -29200,0.25481975,3.9689105,,,,,,,,,,,,,,,,, -29300,0.20453799,3.9821486,,,,,,,,,,,,,,,,, -29400,0.24034962,3.967474,,,,,,,,,,,,,,,,, -29500,0.19078094,3.8741903,,,,,,,,,,,,,,,,, -29600,0.1954415,3.9066691,,,,,,,,,,,,,,,,, -29700,0.2391758,3.9995046,,,,,,,,,,,,,,,,, -29800,0.19604324,3.967161,,,,,,,,,,,,,,,,, -29900,0.19256797,3.9509678,,,,,,,,,,,,,,,,, -30000,0.21780443,3.939441,,,,,,,,,,,,,,,,, -30100,0.19640577,3.9506853,,,,,,,,,,,,,,,,, -30200,0.21743658,3.9534104,,,,,,,,,,,,,,,,, -30300,0.19227426,3.9200847,,,,,,,,,,,,,,,,, -30400,0.19737665,3.929103,,,,,,,,,,,,,,,,, -30500,0.2067082,3.928431,,,,,,,,,,,,,,,,, -30600,0.27417576,3.952674,,,,,,,,,,,,,,,,, -30700,0.18364774,3.9332438,,,,,,,,,,,,,,,,, -30800,0.20041093,3.9035718,,,,,,,,,,,,,,,,, -30900,0.2001885,3.9073758,,,,,,,,,,,,,,,,, -31000,0.21684289,3.8862903,,,,,,,,,,,,,,,,, -31100,0.20309846,3.9345005,,,,,,,,,,,,,,,,, -31200,0.2019973,3.9709918,,,,,,,,,,,,,,,,, -31300,0.19069472,3.951365,,,,,,,,,,,,,,,,, -31393,,,0.6834477782249451,1.6875489950180054,34.22825500162002,0.6686215996742249,1.760544776916504,28.802539963238083,3000.0,0.6804369688034058,1.6927775144577026,28.100285054452986,3003.0,10948.82987356186,18281.74568796158,10948.82987356186,7331.497166872024,0.3668546676635742,0.0 -31400,0.33216512,3.9781218,,,,,,,,,,,,,,,,, -31500,0.24586083,3.9578958,,,,,,,,,,,,,,,,, -31600,0.1873015,3.904448,,,,,,,,,,,,,,,,, -31700,0.25187624,3.9870665,,,,,,,,,,,,,,,,, -31800,0.26387456,3.9818814,,,,,,,,,,,,,,,,, -31900,0.20502603,3.9824576,,,,,,,,,,,,,,,,, -32000,0.22971238,3.9614527,,,,,,,,,,,,,,,,, -32100,0.21558695,3.9563417,,,,,,,,,,,,,,,,, -32200,0.2646772,3.9339862,,,,,,,,,,,,,,,,, -32300,0.21775378,3.875419,,,,,,,,,,,,,,,,, -32400,0.21216542,3.9587197,,,,,,,,,,,,,,,,, -32500,0.19700214,3.8259788,,,,,,,,,,,,,,,,, -32600,0.2001425,3.9378853,,,,,,,,,,,,,,,,, -32700,0.22780903,3.9623764,,,,,,,,,,,,,,,,, -32800,0.20079805,4.0130863,,,,,,,,,,,,,,,,, -32900,0.24117282,3.9126055,,,,,,,,,,,,,,,,, -33000,0.19597511,3.8813605,,,,,,,,,,,,,,,,, -33100,0.22898683,3.9518027,,,,,,,,,,,,,,,,, -33200,0.23660012,3.9392238,,,,,,,,,,,,,,,,, -33300,0.21306162,3.8554945,,,,,,,,,,,,,,,,, -33400,0.21094164,3.9602325,,,,,,,,,,,,,,,,, -33500,0.20440632,3.919313,,,,,,,,,,,,,,,,, -33600,0.21421868,3.9765327,,,,,,,,,,,,,,,,, -33700,0.2865983,3.9629831,,,,,,,,,,,,,,,,, -33800,0.2292834,3.9198258,,,,,,,,,,,,,,,,, -33808,,,0.6541592478752136,1.8681278228759768,31.89343011566094,0.6689191460609436,1.7596172094345093,29.06737236228412,3000.0,0.681715190410614,1.691118240356445,28.551759458820776,3003.0,11788.780126810074,19732.2159409523,11788.780126810074,7941.9084758758545,0.3961889743804931,0.0 -33900,0.27747843,3.9557145,,,,,,,,,,,,,,,,, -34000,0.25708425,3.911964,,,,,,,,,,,,,,,,, -34100,0.22857171,3.965915,,,,,,,,,,,,,,,,, -34200,0.2132413,3.9313595,,,,,,,,,,,,,,,,, -34300,0.21902184,3.9404368,,,,,,,,,,,,,,,,, -34400,0.22441366,3.8926816,,,,,,,,,,,,,,,,, -34500,0.2913384,3.9209495,,,,,,,,,,,,,,,,, -34600,0.24490176,3.9251683,,,,,,,,,,,,,,,,, -34700,0.24945207,3.9194489,,,,,,,,,,,,,,,,, -34800,0.21046366,3.8369036,,,,,,,,,,,,,,,,, -34900,0.23151886,3.9146194,,,,,,,,,,,,,,,,, -35000,0.21064802,3.8857584,,,,,,,,,,,,,,,,, -35100,0.20176208,3.910406,,,,,,,,,,,,,,,,, -35200,0.23840483,3.9093616,,,,,,,,,,,,,,,,, -35300,0.1860209,3.8587694,,,,,,,,,,,,,,,,, -35400,0.28733405,3.94213,,,,,,,,,,,,,,,,, -35500,0.23788188,3.933155,,,,,,,,,,,,,,,,, -35600,0.22504409,3.9250853,,,,,,,,,,,,,,,,, -35700,0.21156241,3.8494227,,,,,,,,,,,,,,,,, -35800,0.21932654,3.9494677,,,,,,,,,,,,,,,,, -35900,0.2510977,3.9670935,,,,,,,,,,,,,,,,, -36000,0.2938572,3.874793,,,,,,,,,,,,,,,,, -36100,0.20873858,3.8914244,,,,,,,,,,,,,,,,, -36200,0.2586239,3.9070363,,,,,,,,,,,,,,,,, -36224,,,0.654525101184845,1.8604371547698968,32.16741371237828,0.673159658908844,1.7235437631607056,29.286452756607737,3000.0,0.6827842593193054,1.6493988037109375,28.751117281827,3003.0,12629.00738811493,21093.10581278801,12629.00738811493,8462.458575487137,0.4259674549102783,0.0 -36300,0.22406551,3.8923724,,,,,,,,,,,,,,,,, -36400,0.23255356,3.8874655,,,,,,,,,,,,,,,,, -36500,0.2405254,3.8742123,,,,,,,,,,,,,,,,, -36600,0.22691312,3.8676925,,,,,,,,,,,,,,,,, -36700,0.2884138,3.8597717,,,,,,,,,,,,,,,,, -36800,0.3033909,3.9173813,,,,,,,,,,,,,,,,, -36900,0.20687401,3.9639776,,,,,,,,,,,,,,,,, -37000,0.205548,3.8795953,,,,,,,,,,,,,,,,, -37100,0.2576321,3.9258013,,,,,,,,,,,,,,,,, -37200,0.24386472,3.89338,,,,,,,,,,,,,,,,, -37300,0.2293287,3.9491255,,,,,,,,,,,,,,,,, -37400,0.2087589,3.8899274,,,,,,,,,,,,,,,,, -37500,0.2319504,3.9046633,,,,,,,,,,,,,,,,, -37600,0.26861224,3.918473,,,,,,,,,,,,,,,,, -37700,0.21357863,3.86403,,,,,,,,,,,,,,,,, -37800,0.23704375,3.8486767,,,,,,,,,,,,,,,,, -37900,0.21427333,3.9712765,,,,,,,,,,,,,,,,, -38000,0.25543502,3.8883164,,,,,,,,,,,,,,,,, -38100,0.20540485,3.92153,,,,,,,,,,,,,,,,, -38200,0.24009213,3.8434935,,,,,,,,,,,,,,,,, -38300,0.26928288,3.8578298,,,,,,,,,,,,,,,,, -38400,0.23284356,3.918891,,,,,,,,,,,,,,,,, -38500,0.22844365,3.8946488,,,,,,,,,,,,,,,,, -38600,0.2171895,3.859374,,,,,,,,,,,,,,,,, -38641,,,0.6630606055259705,1.7945261001586914,32.4857500217436,0.6725273132324219,1.7204879522323608,29.01776935916361,3000.0,0.6843181848526001,1.6504517793655396,28.69606835632274,3003.0,13469.150447130203,22382.396797180176,13469.150447130203,8911.495740890503,0.4579017162322998,0.0 -38700,0.20747733,3.8740132,,,,,,,,,,,,,,,,, -38800,0.22634582,3.9296691,,,,,,,,,,,,,,,,, -38900,0.23182403,3.888436,,,,,,,,,,,,,,,,, -39000,0.21957831,3.9126806,,,,,,,,,,,,,,,,, -39100,0.23987807,3.868548,,,,,,,,,,,,,,,,, -39200,0.29187286,3.9282103,,,,,,,,,,,,,,,,, -39300,0.25499895,3.8603048,,,,,,,,,,,,,,,,, -39400,0.21250874,3.8783145,,,,,,,,,,,,,,,,, -39500,0.25699615,3.9066873,,,,,,,,,,,,,,,,, -39600,0.20723474,3.884043,,,,,,,,,,,,,,,,, -39700,0.24580547,3.9381561,,,,,,,,,,,,,,,,, -39800,0.27526096,3.941647,,,,,,,,,,,,,,,,, -39900,0.28125477,3.9186916,,,,,,,,,,,,,,,,, -40000,0.25159702,3.875761,,,,,,,,,,,,,,,,, -40100,0.22127052,3.8666515,,,,,,,,,,,,,,,,, -40200,0.2430962,3.9261646,,,,,,,,,,,,,,,,, -40300,0.28458616,3.8511834,,,,,,,,,,,,,,,,, -40400,0.22025694,3.9055746,,,,,,,,,,,,,,,,, -40500,0.31886247,3.928965,,,,,,,,,,,,,,,,, -40600,0.22659828,3.8786852,,,,,,,,,,,,,,,,, -40700,0.2164961,3.8919263,,,,,,,,,,,,,,,,, -40800,0.23577291,3.9055052,,,,,,,,,,,,,,,,, -40900,0.22024138,3.9566112,,,,,,,,,,,,,,,,, -41000,0.21591042,3.8878038,,,,,,,,,,,,,,,,, -41057,,,0.6570492386817932,1.825039505958557,32.32277040187047,0.6732960343360901,1.7080892324447632,29.21355250631401,3000.0,0.6871187090873718,1.634048342704773,29.00677069047909,3003.0,14309.070579051971,23699.610480308533,14309.070579051971,9388.679366111755,0.4881284236907959,0.0 -41100,0.22200574,3.8779209,,,,,,,,,,,,,,,,, -41200,0.2940682,3.9146705,,,,,,,,,,,,,,,,, -41300,0.22967958,3.9237533,,,,,,,,,,,,,,,,, -41400,0.22844474,3.905107,,,,,,,,,,,,,,,,, -41500,0.25051203,3.8493135,,,,,,,,,,,,,,,,, -41600,0.23395325,3.972989,,,,,,,,,,,,,,,,, -41700,0.26117942,3.8809645,,,,,,,,,,,,,,,,, -41800,0.25012708,3.9106786,,,,,,,,,,,,,,,,, -41900,0.25203785,3.873426,,,,,,,,,,,,,,,,, -42000,0.25612327,3.878897,,,,,,,,,,,,,,,,, -42100,0.22763826,3.8562396,,,,,,,,,,,,,,,,, -42200,0.22757897,3.8661928,,,,,,,,,,,,,,,,, -42300,0.23074907,3.8569858,,,,,,,,,,,,,,,,, -42400,0.27569926,3.8726475,,,,,,,,,,,,,,,,, -42500,0.24965554,3.8840017,,,,,,,,,,,,,,,,, -42600,0.23408633,3.9075685,,,,,,,,,,,,,,,,, -42700,0.23828349,3.916125,,,,,,,,,,,,,,,,, -42800,0.29915595,3.8417192,,,,,,,,,,,,,,,,, -42900,0.2501851,3.8824058,,,,,,,,,,,,,,,,, -43000,0.25009143,3.862282,,,,,,,,,,,,,,,,, -43100,0.24462423,3.8916843,,,,,,,,,,,,,,,,, -43200,0.24372216,3.8728814,,,,,,,,,,,,,,,,, -43300,0.26399302,3.9551005,,,,,,,,,,,,,,,,, -43400,0.24546535,3.8631597,,,,,,,,,,,,,,,,, -43472,,,0.6580873727798462,1.8422174453735352,32.67865687974197,0.6763338446617126,1.7116023302078247,29.45904070467425,3000.0,0.6887455582618713,1.635982632637024,28.902835204010888,3003.0,15149.12542128563,25003.713657855988,15149.12542128563,9852.615337371826,0.5199141502380371,0.0 -43500,0.23486148,3.8354425,,,,,,,,,,,,,,,,, -43600,0.263214,3.877975,,,,,,,,,,,,,,,,, -43700,0.2513165,3.8983111,,,,,,,,,,,,,,,,, -43800,0.23362957,3.9646301,,,,,,,,,,,,,,,,, -43900,0.24370305,3.883312,,,,,,,,,,,,,,,,, -44000,0.2538607,3.890081,,,,,,,,,,,,,,,,, -44100,0.25301155,3.9183993,,,,,,,,,,,,,,,,, -44200,0.24355228,3.8888586,,,,,,,,,,,,,,,,, -44300,0.23883107,3.9091692,,,,,,,,,,,,,,,,, -44400,0.24175943,3.8174806,,,,,,,,,,,,,,,,, -44500,0.2492773,3.893803,,,,,,,,,,,,,,,,, -44600,0.23242527,3.8085501,,,,,,,,,,,,,,,,, -44700,0.2557594,3.899404,,,,,,,,,,,,,,,,, -44800,0.23848198,3.8769178,,,,,,,,,,,,,,,,, -44900,0.28498033,3.9805458,,,,,,,,,,,,,,,,, -45000,0.25732607,3.8653474,,,,,,,,,,,,,,,,, -45100,0.25416157,3.8697934,,,,,,,,,,,,,,,,, -45200,0.27568737,3.8402479,,,,,,,,,,,,,,,,, -45300,0.22765537,3.8552957,,,,,,,,,,,,,,,,, -45400,0.28179535,3.9293444,,,,,,,,,,,,,,,,, -45500,0.2766263,3.9033155,,,,,,,,,,,,,,,,, -45600,0.24953654,3.8477435,,,,,,,,,,,,,,,,, -45700,0.2341264,3.8402805,,,,,,,,,,,,,,,,, -45800,0.24623463,3.802832,,,,,,,,,,,,,,,,, -45888,,,0.6625374555587769,1.791016936302185,32.6860518428758,0.6753295063972473,1.6926016807556152,29.099113091513008,3000.0,0.6890825629234314,1.618253231048584,28.80330193904192,3003.0,15989.273291826248,26466.803416490555,15989.273291826248,10475.447494745256,0.5502684116363525,0.0 -45900,0.23112638,3.8757427,,,,,,,,,,,,,,,,, -46000,0.28232265,3.927458,,,,,,,,,,,,,,,,, -46100,0.24593213,3.8498635,,,,,,,,,,,,,,,,, -46200,0.230852,3.8341265,,,,,,,,,,,,,,,,, -46300,0.26317528,3.913739,,,,,,,,,,,,,,,,, -46400,0.23308256,3.856984,,,,,,,,,,,,,,,,, -46500,0.22354552,3.8808982,,,,,,,,,,,,,,,,, -46600,0.23127978,3.8741171,,,,,,,,,,,,,,,,, -46700,0.24321757,3.864215,,,,,,,,,,,,,,,,, -46800,0.28941917,3.8970244,,,,,,,,,,,,,,,,, -46900,0.3225288,3.894127,,,,,,,,,,,,,,,,, -47000,0.24150021,3.844776,,,,,,,,,,,,,,,,, -47100,0.23916963,3.8326693,,,,,,,,,,,,,,,,, -47200,0.2558705,3.8132153,,,,,,,,,,,,,,,,, -47300,0.28675988,3.8192885,,,,,,,,,,,,,,,,, -47400,0.24837591,3.8759923,,,,,,,,,,,,,,,,, -47500,0.23811552,3.8766084,,,,,,,,,,,,,,,,, -47600,0.26613945,3.8180552,,,,,,,,,,,,,,,,, -47700,0.24880156,3.9129307,,,,,,,,,,,,,,,,, -47800,0.2640885,3.9014854,,,,,,,,,,,,,,,,, -47900,0.24127725,3.93682,,,,,,,,,,,,,,,,, -48000,0.23250581,3.825518,,,,,,,,,,,,,,,,, -48100,0.24884039,3.8544745,,,,,,,,,,,,,,,,, -48200,0.26302913,3.8290198,,,,,,,,,,,,,,,,, -48300,0.253456,3.9137785,,,,,,,,,,,,,,,,, -48304,,,0.6612311601638794,1.815684676170349,32.39720039613303,0.6767057776451111,1.702085256576538,29.48392422047256,3000.0,0.6889082789421082,1.6244529485702517,29.31796205687577,3003.0,16829.422029733658,27796.0339281559,16829.422029733658,10964.418253660202,0.5813858509063721,0.0 -48400,0.28414038,3.8493543,,,,,,,,,,,,,,,,, -48500,0.24771504,3.9356222,,,,,,,,,,,,,,,,, -48600,0.2785148,3.8303168,,,,,,,,,,,,,,,,, -48700,0.29411712,3.8658938,,,,,,,,,,,,,,,,, -48800,0.24759981,3.8311055,,,,,,,,,,,,,,,,, -48900,0.26470244,3.8876853,,,,,,,,,,,,,,,,, -49000,0.28563577,3.887974,,,,,,,,,,,,,,,,, -49100,0.25711328,3.8490784,,,,,,,,,,,,,,,,, -49200,0.2493715,3.8944345,,,,,,,,,,,,,,,,, -49300,0.24706686,3.8731003,,,,,,,,,,,,,,,,, -49400,0.24079777,3.8288279,,,,,,,,,,,,,,,,, -49500,0.24444285,3.8799357,,,,,,,,,,,,,,,,, -49600,0.29989144,3.8764117,,,,,,,,,,,,,,,,, -49700,0.24229687,3.8389351,,,,,,,,,,,,,,,,, -49800,0.24325675,3.813201,,,,,,,,,,,,,,,,, -49900,0.26211286,3.8458796,,,,,,,,,,,,,,,,, -50000,0.23071757,3.875121,,,,,,,,,,,,,,,,, -50100,0.24453096,3.8316243,,,,,,,,,,,,,,,,, -50200,0.24151222,3.8416777,,,,,,,,,,,,,,,,, -50300,0.25925103,3.8625512,,,,,,,,,,,,,,,,, -50400,0.28128782,3.8916407,,,,,,,,,,,,,,,,, -50500,0.25577477,3.8971245,,,,,,,,,,,,,,,,, -50600,0.24545686,3.887358,,,,,,,,,,,,,,,,, -50700,0.26799423,3.8969858,,,,,,,,,,,,,,,,, -50720,,,0.6744092702865601,1.718177080154419,33.865250969584935,0.6792972087860107,1.6851235628128052,29.988029477140355,3000.0,0.6897914409637451,1.608918070793152,29.17607198808368,3003.0,17669.506113767624,29109.21942853928,17669.506113767624,11437.406776428224,0.6123223304748535,0.0 -50800,0.24883074,3.848322,,,,,,,,,,,,,,,,, -50900,0.2735013,3.8857994,,,,,,,,,,,,,,,,, -51000,0.23069109,3.8274803,,,,,,,,,,,,,,,,, -51100,0.30484298,3.866778,,,,,,,,,,,,,,,,, -51200,0.2549028,3.852472,,,,,,,,,,,,,,,,, -51300,0.25800866,3.911924,,,,,,,,,,,,,,,,, -51400,0.25362182,3.847913,,,,,,,,,,,,,,,,, -51500,0.2585657,3.8568373,,,,,,,,,,,,,,,,, -51600,0.2638691,3.8321133,,,,,,,,,,,,,,,,, -51700,0.25142184,3.9125752,,,,,,,,,,,,,,,,, -51800,0.24457788,3.878497,,,,,,,,,,,,,,,,, -51900,0.23447348,3.804636,,,,,,,,,,,,,,,,, -52000,0.25804788,3.9036143,,,,,,,,,,,,,,,,, -52100,0.24702561,3.867855,,,,,,,,,,,,,,,,, -52200,0.27951404,3.840674,,,,,,,,,,,,,,,,, -52300,0.2694282,3.8409047,,,,,,,,,,,,,,,,, -52400,0.35332903,3.893214,,,,,,,,,,,,,,,,, -52500,0.2743119,3.8051043,,,,,,,,,,,,,,,,, -52600,0.26354238,3.8299727,,,,,,,,,,,,,,,,, -52700,0.2570216,3.7955549,,,,,,,,,,,,,,,,, -52800,0.28179216,3.8529155,,,,,,,,,,,,,,,,, -52900,0.25710014,3.8927546,,,,,,,,,,,,,,,,, -53000,0.2725214,3.8041387,,,,,,,,,,,,,,,,, -53100,0.25357175,3.8739407,,,,,,,,,,,,,,,,, -53136,,,0.6658750772476196,1.772932529449463,32.865700967124376,0.677821695804596,1.675114989280701,29.386465831879853,3000.0,0.6934635043144226,1.597248673439026,29.42710453871561,3003.0,18509.686230421063,30499.49797201157,18509.686230421063,11987.390083789824,0.6450626850128174,0.0 -53200,0.3366339,3.8737283,,,,,,,,,,,,,,,,, -53300,0.27263808,3.8236208,,,,,,,,,,,,,,,,, -53400,0.24794166,3.848375,,,,,,,,,,,,,,,,, -53500,0.33871344,3.7937267,,,,,,,,,,,,,,,,, -53600,0.25079852,3.855547,,,,,,,,,,,,,,,,, -53700,0.3537537,3.7672548,,,,,,,,,,,,,,,,, -53800,0.27793717,3.9041994,,,,,,,,,,,,,,,,, -53900,0.2760518,3.886573,,,,,,,,,,,,,,,,, -54000,0.2864504,3.8528907,,,,,,,,,,,,,,,,, -54100,0.2733452,3.8768756,,,,,,,,,,,,,,,,, -54200,0.26536888,3.866729,,,,,,,,,,,,,,,,, -54300,0.24627933,3.8544421,,,,,,,,,,,,,,,,, -54400,0.2770085,3.8441827,,,,,,,,,,,,,,,,, -54500,0.29191515,3.941954,,,,,,,,,,,,,,,,, -54600,0.2790067,3.9147267,,,,,,,,,,,,,,,,, -54700,0.25256455,3.8728929,,,,,,,,,,,,,,,,, -54800,0.25002417,3.8461037,,,,,,,,,,,,,,,,, -54900,0.28353792,3.8255215,,,,,,,,,,,,,,,,, -55000,0.27923563,3.8323927,,,,,,,,,,,,,,,,, -55100,0.2715745,3.8130453,,,,,,,,,,,,,,,,, -55200,0.24170122,3.8369036,,,,,,,,,,,,,,,,, -55300,0.26394856,3.8429847,,,,,,,,,,,,,,,,, -55400,0.27827042,3.8066335,,,,,,,,,,,,,,,,, -55500,0.25863066,3.941042,,,,,,,,,,,,,,,,, -55552,,,0.6659360527992249,1.7834343910217283,32.94168061929005,0.6789996027946472,1.6796774864196775,29.670184999573905,3000.0,0.6937540173530579,1.5995779037475586,29.22578359814461,3003.0,19349.803040981293,31812.61920762062,19349.803040981293,12460.27916264534,0.6809508800506592,0.0 -55600,0.3379092,3.8484852,,,,,,,,,,,,,,,,, -55700,0.2653084,3.8688357,,,,,,,,,,,,,,,,, -55800,0.28537032,3.8727846,,,,,,,,,,,,,,,,, -55900,0.25212762,3.9520044,,,,,,,,,,,,,,,,, -56000,0.28691122,3.8572106,,,,,,,,,,,,,,,,, -56100,0.2633527,3.823636,,,,,,,,,,,,,,,,, -56200,0.2454479,3.857605,,,,,,,,,,,,,,,,, -56300,0.2655132,3.9109924,,,,,,,,,,,,,,,,, -56400,0.24647003,3.8173263,,,,,,,,,,,,,,,,, -56500,0.25770456,3.8177688,,,,,,,,,,,,,,,,, -56600,0.3170488,3.845603,,,,,,,,,,,,,,,,, -56700,0.2572246,3.8340068,,,,,,,,,,,,,,,,, -56800,0.29117832,3.8549874,,,,,,,,,,,,,,,,, -56900,0.26690358,3.8257506,,,,,,,,,,,,,,,,, -57000,0.27709928,3.8843765,,,,,,,,,,,,,,,,, -57100,0.2737968,3.8222473,,,,,,,,,,,,,,,,, -57200,0.2609997,3.8932064,,,,,,,,,,,,,,,,, -57300,0.28124171,3.8237278,,,,,,,,,,,,,,,,, -57400,0.28389463,3.860672,,,,,,,,,,,,,,,,, -57500,0.25313187,3.8432944,,,,,,,,,,,,,,,,, -57600,0.2645628,3.8780503,,,,,,,,,,,,,,,,, -57700,0.26723504,3.8282251,,,,,,,,,,,,,,,,, -57800,0.26370704,3.8383965,,,,,,,,,,,,,,,,, -57900,0.28307113,3.798837,,,,,,,,,,,,,,,,, -57968,,,0.6726042032241821,1.7362380027770996,33.86621185465534,0.6793344020843506,1.6724979877471924,29.51395916056389,3000.0,0.6937307715415955,1.5895719528198242,29.51395797552482,3003.0,20189.68864178657,33113.54695224762,20189.68864178657,12921.206790924072,0.71565842628479,0.0 -58000,0.2772906,3.8290598,,,,,,,,,,,,,,,,, -58100,0.2614738,3.798895,,,,,,,,,,,,,,,,, -58200,0.30867538,3.8461602,,,,,,,,,,,,,,,,, -58300,0.25687963,3.7957225,,,,,,,,,,,,,,,,, -58400,0.29864323,3.8405917,,,,,,,,,,,,,,,,, -58500,0.25974682,3.8585365,,,,,,,,,,,,,,,,, -58600,0.25138873,3.813821,,,,,,,,,,,,,,,,, -58700,0.25446114,3.8511627,,,,,,,,,,,,,,,,, -58800,0.3234122,3.7747173,,,,,,,,,,,,,,,,, -58900,0.248925,3.8151813,,,,,,,,,,,,,,,,, -59000,0.26847503,3.811874,,,,,,,,,,,,,,,,, -59100,0.25454178,3.8409305,,,,,,,,,,,,,,,,, -59200,0.28733316,3.888592,,,,,,,,,,,,,,,,, -59300,0.30615154,3.8863163,,,,,,,,,,,,,,,,, -59400,0.269398,3.7879796,,,,,,,,,,,,,,,,, -59500,0.28010136,3.8300297,,,,,,,,,,,,,,,,, -59600,0.2495549,3.8118443,,,,,,,,,,,,,,,,, -59700,0.2747148,3.8616414,,,,,,,,,,,,,,,,, -59800,0.27069765,3.8402662,,,,,,,,,,,,,,,,, -59900,0.3134156,3.8226154,,,,,,,,,,,,,,,,, -60000,0.28525442,3.873111,,,,,,,,,,,,,,,,, -60100,0.30052492,3.8291066,,,,,,,,,,,,,,,,, -60200,0.27816302,3.8291667,,,,,,,,,,,,,,,,, -60300,0.2752917,3.8723993,,,,,,,,,,,,,,,,, -60383,,,0.6659606695175171,1.7663089036941528,33.31700121865782,0.6792724132537842,1.6618462800979614,30.04262259680157,3000.0,0.6942071914672852,1.5787086486816406,29.482354140661048,3003.0,21029.69061565399,34416.46136879921,21029.69061565399,13383.99617767334,0.7550392150878906,0.0 -60400,0.35117418,3.872205,,,,,,,,,,,,,,,,, -60500,0.3009367,3.8331668,,,,,,,,,,,,,,,,, -60600,0.2697581,3.7919395,,,,,,,,,,,,,,,,, -60700,0.27163863,3.9027178,,,,,,,,,,,,,,,,, -60800,0.2515851,3.897203,,,,,,,,,,,,,,,,, -60900,0.29209882,3.8407652,,,,,,,,,,,,,,,,, -61000,0.28167635,3.8966446,,,,,,,,,,,,,,,,, -61100,0.27988392,3.8244033,,,,,,,,,,,,,,,,, -61200,0.30022314,3.824108,,,,,,,,,,,,,,,,, -61300,0.28210616,3.7856753,,,,,,,,,,,,,,,,, -61400,0.27575436,3.8346295,,,,,,,,,,,,,,,,, -61500,0.2755484,3.853682,,,,,,,,,,,,,,,,, -61600,0.2544933,3.8639123,,,,,,,,,,,,,,,,, -61700,0.31457412,3.8181922,,,,,,,,,,,,,,,,, -61800,0.27068415,3.8862417,,,,,,,,,,,,,,,,, -61900,0.26076132,3.7656004,,,,,,,,,,,,,,,,, -62000,0.28978598,3.8526099,,,,,,,,,,,,,,,,, -62100,0.27280778,3.8676546,,,,,,,,,,,,,,,,, -62200,0.2524668,3.7878597,,,,,,,,,,,,,,,,, -62300,0.31155238,3.8034291,,,,,,,,,,,,,,,,, -62400,0.28338417,3.7522128,,,,,,,,,,,,,,,,, -62500,0.26349857,3.8281143,,,,,,,,,,,,,,,,, -62600,0.2898218,3.7867138,,,,,,,,,,,,,,,,, -62700,0.29758704,3.835047,,,,,,,,,,,,,,,,, -62799,,,0.6952691674232483,1.5940762758255005,34.85920428088039,0.6813182830810547,1.6626694202423096,30.02150473419823,3000.0,0.6970425844192505,1.576493263244629,29.63074699460067,3003.0,21869.710822820663,35827.3902528286,21869.710822820663,13954.791036367416,0.7893767356872559,0.0 -62800,0.2823805,3.8247745,,,,,,,,,,,,,,,,, -62900,0.32137594,3.8459215,,,,,,,,,,,,,,,,, -63000,0.2925079,3.8300955,,,,,,,,,,,,,,,,, -63100,0.26965994,3.8218038,,,,,,,,,,,,,,,,, -63200,0.2751998,3.8355532,,,,,,,,,,,,,,,,, -63300,0.28503865,3.7754056,,,,,,,,,,,,,,,,, -63400,0.31910938,3.8600643,,,,,,,,,,,,,,,,, -63500,0.2613089,3.8054636,,,,,,,,,,,,,,,,, -63600,0.39788282,3.7951274,,,,,,,,,,,,,,,,, -63700,0.2746194,3.8290677,,,,,,,,,,,,,,,,, -63800,0.27974313,3.8348343,,,,,,,,,,,,,,,,, -63900,0.28262278,3.8361769,,,,,,,,,,,,,,,,, -64000,0.25212997,3.7860196,,,,,,,,,,,,,,,,, -64100,0.33906084,3.8258944,,,,,,,,,,,,,,,,, -64200,0.26341376,3.8422196,,,,,,,,,,,,,,,,, -64300,0.27958465,3.8124092,,,,,,,,,,,,,,,,, -64400,0.2742472,3.8395739,,,,,,,,,,,,,,,,, -64500,0.27065143,3.8633304,,,,,,,,,,,,,,,,, -64600,0.30662137,3.8246033,,,,,,,,,,,,,,,,, -64700,0.29544678,3.858807,,,,,,,,,,,,,,,,, -64800,0.28807703,3.8275988,,,,,,,,,,,,,,,,, -64900,0.27014247,3.7924507,,,,,,,,,,,,,,,,, -65000,0.2634165,3.7995014,,,,,,,,,,,,,,,,, -65100,0.2618005,3.8259168,,,,,,,,,,,,,,,,, -65200,0.27805597,3.780817,,,,,,,,,,,,,,,,, -65215,,,0.6770080327987671,1.7094494104385376,33.7335961330243,0.6842692494392395,1.6536860466003418,30.38194529115567,3000.0,0.6985416412353516,1.574635028839111,29.660603258547127,3003.0,22709.933475017548,37207.97281885147,22709.933475017548,14495.03400874138,0.8237001895904541,0.0 -65300,0.27422142,3.8145213,,,,,,,,,,,,,,,,, -65400,0.27021202,3.850708,,,,,,,,,,,,,,,,, -65500,0.250992,3.793119,,,,,,,,,,,,,,,,, -65600,0.27806023,3.8460987,,,,,,,,,,,,,,,,, -65700,0.2744082,3.8605545,,,,,,,,,,,,,,,,, -65800,0.2600763,3.795949,,,,,,,,,,,,,,,,, -65900,0.27689788,3.8029392,,,,,,,,,,,,,,,,, -66000,0.2839165,3.7755334,,,,,,,,,,,,,,,,, -66100,0.2886662,3.8112683,,,,,,,,,,,,,,,,, -66200,0.27482405,3.8223546,,,,,,,,,,,,,,,,, -66300,0.27170378,3.8453252,,,,,,,,,,,,,,,,, -66400,0.24810079,3.764822,,,,,,,,,,,,,,,,, -66500,0.31588864,3.803267,,,,,,,,,,,,,,,,, -66600,0.28755903,3.8229806,,,,,,,,,,,,,,,,, -66700,0.28455052,3.858416,,,,,,,,,,,,,,,,, -66800,0.29642972,3.7876287,,,,,,,,,,,,,,,,, -66900,0.2971421,3.8349257,,,,,,,,,,,,,,,,, -67000,0.25113395,3.78959,,,,,,,,,,,,,,,,, -67100,0.29124945,3.83895,,,,,,,,,,,,,,,,, -67200,0.28629345,3.7893903,,,,,,,,,,,,,,,,, -67300,0.2841939,3.8225083,,,,,,,,,,,,,,,,, -67400,0.3256374,3.8359044,,,,,,,,,,,,,,,,, -67500,0.2707604,3.8619802,,,,,,,,,,,,,,,,, -67600,0.37747467,3.7599194,,,,,,,,,,,,,,,,, -67630,,,0.6708812117576599,1.747066617012024,33.23434424702395,0.6815910339355469,1.65132737159729,29.844001559891947,3000.0,0.6962407827377319,1.5660450458526611,29.662369265097585,3003.0,23549.856046438217,38541.92924189568,23549.856046438217,14988.952271461489,0.8565609455108643,0.0 -67700,0.29252267,3.7838352,,,,,,,,,,,,,,,,, -67800,0.28948346,3.8284287,,,,,,,,,,,,,,,,, -67900,0.2681778,3.843952,,,,,,,,,,,,,,,,, -68000,0.27747345,3.7931557,,,,,,,,,,,,,,,,, -68100,0.31289527,3.788899,,,,,,,,,,,,,,,,, -68200,0.29157364,3.8663137,,,,,,,,,,,,,,,,, -68300,0.2639985,3.768882,,,,,,,,,,,,,,,,, -68400,0.30036467,3.7728956,,,,,,,,,,,,,,,,, -68500,0.2799571,3.8228958,,,,,,,,,,,,,,,,, -68600,0.31040064,3.776381,,,,,,,,,,,,,,,,, -68700,0.2956731,3.8446584,,,,,,,,,,,,,,,,, -68800,0.2754606,3.743904,,,,,,,,,,,,,,,,, -68900,0.3433797,3.8145487,,,,,,,,,,,,,,,,, -69000,0.28067437,3.810036,,,,,,,,,,,,,,,,, -69100,0.282402,3.8489177,,,,,,,,,,,,,,,,, -69200,0.31658885,3.8482926,,,,,,,,,,,,,,,,, -69300,0.32100368,3.8390086,,,,,,,,,,,,,,,,, -69400,0.2886464,3.7679715,,,,,,,,,,,,,,,,, -69500,0.3045051,3.781925,,,,,,,,,,,,,,,,, -69600,0.3283181,3.772366,,,,,,,,,,,,,,,,, -69700,0.29016042,3.8363082,,,,,,,,,,,,,,,,, -69800,0.3438167,3.771091,,,,,,,,,,,,,,,,, -69900,0.26912278,3.826513,,,,,,,,,,,,,,,,, -70000,0.27165538,3.8064883,,,,,,,,,,,,,,,,, -70045,,,0.6814951300621033,1.671489715576172,34.03927182611093,0.6828433275222778,1.652961492538452,30.02545059786362,3000.0,0.697623610496521,1.5698145627975464,29.73855557334434,3003.0,24390.056773662567,39862.7853975296,24390.056773662567,15469.488009691238,0.8931279182434082,0.0 -70100,0.2946728,3.7917368,,,,,,,,,,,,,,,,, -70200,0.2708475,3.840583,,,,,,,,,,,,,,,,, -70300,0.29096907,3.7805903,,,,,,,,,,,,,,,,, -70400,0.29147768,3.7445965,,,,,,,,,,,,,,,,, -70500,0.2669073,3.7777834,,,,,,,,,,,,,,,,, -70600,0.26809856,3.7770245,,,,,,,,,,,,,,,,, -70700,0.2867867,3.7833083,,,,,,,,,,,,,,,,, -70800,0.2907628,3.782439,,,,,,,,,,,,,,,,, -70900,0.2918785,3.806453,,,,,,,,,,,,,,,,, -71000,0.26544833,3.7263587,,,,,,,,,,,,,,,,, -71100,0.27866673,3.7953632,,,,,,,,,,,,,,,,, -71200,0.28754866,3.859626,,,,,,,,,,,,,,,,, -71300,0.31860206,3.7853858,,,,,,,,,,,,,,,,, -71400,0.27788228,3.7938566,,,,,,,,,,,,,,,,, -71500,0.26739466,3.8118336,,,,,,,,,,,,,,,,, -71600,0.26848465,3.7854488,,,,,,,,,,,,,,,,, -71700,0.2940904,3.7998123,,,,,,,,,,,,,,,,, -71800,0.3414003,3.8027844,,,,,,,,,,,,,,,,, -71900,0.2757214,3.7847157,,,,,,,,,,,,,,,,, -72000,0.3403167,3.826702,,,,,,,,,,,,,,,,, -72100,0.28842813,3.7896118,,,,,,,,,,,,,,,,, -72200,0.27764037,3.747368,,,,,,,,,,,,,,,,, -72300,0.2846901,3.7659776,,,,,,,,,,,,,,,,, -72400,0.28977677,3.8322635,,,,,,,,,,,,,,,,, -72461,,,0.6753785014152527,1.71561861038208,33.9016781817681,0.6829301714897156,1.6497820615768433,30.174211488067066,3000.0,0.6997618079185486,1.559238314628601,29.75574832873671,3003.0,25230.09949684143,41187.3813123703,25230.09949684143,15953.917650938034,0.9359502792358398,0.0 -72500,0.29459658,3.7601523,,,,,,,,,,,,,,,,, -72600,0.281524,3.7687173,,,,,,,,,,,,,,,,, -72700,0.2787641,3.790789,,,,,,,,,,,,,,,,, -72800,0.28653452,3.808542,,,,,,,,,,,,,,,,, -72900,0.29270765,3.823007,,,,,,,,,,,,,,,,, -73000,0.28879878,3.7980175,,,,,,,,,,,,,,,,, -73100,0.3051927,3.815447,,,,,,,,,,,,,,,,, -73200,0.288905,3.7612593,,,,,,,,,,,,,,,,, -73300,0.3011253,3.7250054,,,,,,,,,,,,,,,,, -73400,0.27214617,3.7317002,,,,,,,,,,,,,,,,, -73500,0.29329014,3.790128,,,,,,,,,,,,,,,,, -73600,0.29242778,3.8300436,,,,,,,,,,,,,,,,, -73700,0.28373143,3.7593205,,,,,,,,,,,,,,,,, -73800,0.30653986,3.7799382,,,,,,,,,,,,,,,,, -73900,0.27674603,3.7760851,,,,,,,,,,,,,,,,, -74000,0.27666584,3.8206086,,,,,,,,,,,,,,,,, -74100,0.31425667,3.824887,,,,,,,,,,,,,,,,, -74200,0.305002,3.8147252,,,,,,,,,,,,,,,,, -74300,0.30486265,3.8231785,,,,,,,,,,,,,,,,, -74400,0.29518154,3.76699,,,,,,,,,,,,,,,,, -74500,0.27730122,3.7210317,,,,,,,,,,,,,,,,, -74600,0.26995146,3.7478452,,,,,,,,,,,,,,,,, -74700,0.2726761,3.734052,,,,,,,,,,,,,,,,, -74800,0.2804614,3.8170133,,,,,,,,,,,,,,,,, -74876,,,0.6737843155860901,1.714847445487976,33.94385616955503,0.6842072606086731,1.6327464580535889,30.11589849755397,3000.0,0.7012027502059937,1.540870189666748,30.20991710316605,3003.0,26069.99861884117,42573.744396448135,26069.99861884117,16500.258165597916,0.977301836013794,0.0 -74900,0.2871252,3.7855792,,,,,,,,,,,,,,,,, -75000,0.28238934,3.763703,,,,,,,,,,,,,,,,, -75100,0.3141586,3.7128696,,,,,,,,,,,,,,,,, -75200,0.30371642,3.7602296,,,,,,,,,,,,,,,,, -75300,0.29689404,3.7877934,,,,,,,,,,,,,,,,, -75400,0.27822348,3.7508469,,,,,,,,,,,,,,,,, -75500,0.28554732,3.791635,,,,,,,,,,,,,,,,, -75600,0.30047798,3.801256,,,,,,,,,,,,,,,,, -75700,0.28039414,3.7960472,,,,,,,,,,,,,,,,, -75800,0.2863754,3.789063,,,,,,,,,,,,,,,,, -75900,0.2854771,3.7658448,,,,,,,,,,,,,,,,, -76000,0.30939472,3.741855,,,,,,,,,,,,,,,,, -76100,0.29065272,3.7636764,,,,,,,,,,,,,,,,, -76200,0.2937815,3.7533321,,,,,,,,,,,,,,,,, -76300,0.28439012,3.7886167,,,,,,,,,,,,,,,,, -76400,0.31024137,3.8326654,,,,,,,,,,,,,,,,, -76500,0.2661151,3.7617502,,,,,,,,,,,,,,,,, -76600,0.27311447,3.7410827,,,,,,,,,,,,,,,,, -76700,0.30107966,3.7846253,,,,,,,,,,,,,,,,, -76800,0.2880282,3.773991,,,,,,,,,,,,,,,,, -76900,0.2791386,3.7649415,,,,,,,,,,,,,,,,, -77000,0.29948214,3.8112152,,,,,,,,,,,,,,,,, -77100,0.28218088,3.7637537,,,,,,,,,,,,,,,,, -77200,0.26660588,3.7641387,,,,,,,,,,,,,,,,, -77292,,,0.6831181645393372,1.6621711254119873,34.21749717101149,0.685459554195404,1.6368142366409302,30.35544077774712,3000.0,0.7004590034484863,1.5464823246002195,30.350398961195094,3003.0,26909.980088472366,43937.69822263718,26909.980088472366,17024.11154460907,1.0152418613433838,0.0 -77300,0.29272816,3.7760763,,,,,,,,,,,,,,,,, -77400,0.295765,3.7714384,,,,,,,,,,,,,,,,, -77500,0.286941,3.7785969,,,,,,,,,,,,,,,,, -77600,0.29019615,3.765801,,,,,,,,,,,,,,,,, -77700,0.3109408,3.8954566,,,,,,,,,,,,,,,,, -77800,0.29343387,3.8114676,,,,,,,,,,,,,,,,, -77900,0.28880295,3.7557702,,,,,,,,,,,,,,,,, -78000,0.28937605,3.7238128,,,,,,,,,,,,,,,,, -78100,0.27904856,3.8119028,,,,,,,,,,,,,,,,, -78200,0.29914886,3.7913723,,,,,,,,,,,,,,,,, -78300,0.30430937,3.7851667,,,,,,,,,,,,,,,,, -78400,0.28978646,3.7545052,,,,,,,,,,,,,,,,, -78500,0.29739568,3.7843826,,,,,,,,,,,,,,,,, -78600,0.28103706,3.7771306,,,,,,,,,,,,,,,,, -78700,0.29588935,3.7844958,,,,,,,,,,,,,,,,, -78800,0.29339498,3.7193928,,,,,,,,,,,,,,,,, -78900,0.28588417,3.7519991,,,,,,,,,,,,,,,,, -79000,0.3116456,3.7573643,,,,,,,,,,,,,,,,, -79100,0.28757507,3.8110363,,,,,,,,,,,,,,,,, -79200,0.28883287,3.794132,,,,,,,,,,,,,,,,, -79300,0.30601513,3.787093,,,,,,,,,,,,,,,,, -79400,0.2875203,3.7436264,,,,,,,,,,,,,,,,, -79500,0.31136778,3.8253734,,,,,,,,,,,,,,,,, -79600,0.30007383,3.7832065,,,,,,,,,,,,,,,,, -79700,0.30877718,3.801673,,,,,,,,,,,,,,,,, -79708,,,0.6789937019348145,1.6847949028015137,34.13086744144354,0.6863647103309631,1.625876784324646,30.33725457330963,3000.0,0.7026204466819763,1.5370594263076782,30.37668151012431,3003.0,27749.933834314343,45234.57467293739,27749.933834314343,17480.916483163834,1.0527923107147217,0.0 -79800,0.31247017,3.779022,,,,,,,,,,,,,,,,, -79900,0.2892412,3.7944584,,,,,,,,,,,,,,,,, -80000,0.2932478,3.6950986,,,,,,,,,,,,,,,,, -80100,0.32530782,3.8288255,,,,,,,,,,,,,,,,, -80200,0.2902298,3.7389755,,,,,,,,,,,,,,,,, -80300,0.29663348,3.7971656,,,,,,,,,,,,,,,,, -80400,0.31466538,3.7554762,,,,,,,,,,,,,,,,, -80500,0.31375182,3.832453,,,,,,,,,,,,,,,,, -80600,0.29739743,3.752422,,,,,,,,,,,,,,,,, -80700,0.28757432,3.7342596,,,,,,,,,,,,,,,,, -80800,0.28643015,3.7940881,,,,,,,,,,,,,,,,, -80900,0.29145595,3.7739928,,,,,,,,,,,,,,,,, -81000,0.30822808,3.8005157,,,,,,,,,,,,,,,,, -81100,0.30613583,3.7298975,,,,,,,,,,,,,,,,, -81200,0.29159102,3.763275,,,,,,,,,,,,,,,,, -81300,0.3349116,3.842905,,,,,,,,,,,,,,,,, -81400,0.29359826,3.7483993,,,,,,,,,,,,,,,,, -81500,0.28804153,3.7097287,,,,,,,,,,,,,,,,, -81600,0.36294916,3.7851722,,,,,,,,,,,,,,,,, -81700,0.29367208,3.7414393,,,,,,,,,,,,,,,,, -81800,0.32077485,3.7633004,,,,,,,,,,,,,,,,, -81900,0.32476488,3.7959876,,,,,,,,,,,,,,,,, -82000,0.28778136,3.811374,,,,,,,,,,,,,,,,, -82100,0.29380232,3.7128747,,,,,,,,,,,,,,,,, -82124,,,0.6997047066688538,1.573757290840149,35.5428039352093,0.6866622567176819,1.6228492259979248,30.64508817624996,3000.0,0.7039335370063782,1.5355430841445925,30.4180427003378,3003.0,28590.127032995224,46531.81851029396,28590.127032995224,17937.84757256508,1.0906941890716553,0.0 -82200,0.31033057,3.7972026,,,,,,,,,,,,,,,,, -82300,0.32017586,3.828687,,,,,,,,,,,,,,,,, -82400,0.2946238,3.7858064,,,,,,,,,,,,,,,,, -82500,0.30191913,3.7740219,,,,,,,,,,,,,,,,, -82600,0.3091474,3.7217412,,,,,,,,,,,,,,,,, -82700,0.31271988,3.7685342,,,,,,,,,,,,,,,,, -82800,0.2900409,3.7679408,,,,,,,,,,,,,,,,, -82900,0.28591448,3.7653298,,,,,,,,,,,,,,,,, -83000,0.2837805,3.725043,,,,,,,,,,,,,,,,, -83100,0.29354963,3.7721064,,,,,,,,,,,,,,,,, -83200,0.29692364,3.7928908,,,,,,,,,,,,,,,,, -83300,0.29474202,3.7117255,,,,,,,,,,,,,,,,, -83400,0.3020688,3.7648687,,,,,,,,,,,,,,,,, -83500,0.2834891,3.762591,,,,,,,,,,,,,,,,, -83600,0.2869553,3.7661433,,,,,,,,,,,,,,,,, -83700,0.30306444,3.7286754,,,,,,,,,,,,,,,,, -83800,0.29775503,3.7047591,,,,,,,,,,,,,,,,, -83900,0.2992753,3.7905018,,,,,,,,,,,,,,,,, -84000,0.296119,3.7243967,,,,,,,,,,,,,,,,, -84100,0.32602528,3.7280946,,,,,,,,,,,,,,,,, -84200,0.274524,3.7739205,,,,,,,,,,,,,,,,, -84300,0.32354355,3.7285311,,,,,,,,,,,,,,,,, -84400,0.30226663,3.78268,,,,,,,,,,,,,,,,, -84500,0.3063006,3.7306356,,,,,,,,,,,,,,,,, -84539,,,0.682442307472229,1.6693923473358154,34.606211236537106,0.688286542892456,1.6197954416275024,30.211782980971023,3000.0,0.7041078805923462,1.5266790390014648,30.54304867956124,3003.0,29430.197852134705,47850.226432323456,29430.197852134705,18416.06471323967,1.1282691955566406,0.0 -84600,0.28779605,3.7510207,,,,,,,,,,,,,,,,, -84700,0.28063497,3.7430987,,,,,,,,,,,,,,,,, -84800,0.29576862,3.700653,,,,,,,,,,,,,,,,, -84900,0.30945936,3.775379,,,,,,,,,,,,,,,,, -85000,0.3126446,3.7341244,,,,,,,,,,,,,,,,, -85100,0.29915535,3.716572,,,,,,,,,,,,,,,,, -85200,0.30290473,3.7836988,,,,,,,,,,,,,,,,, -85300,0.32903814,3.7639475,,,,,,,,,,,,,,,,, -85400,0.30437744,3.756446,,,,,,,,,,,,,,,,, -85500,0.318664,3.784098,,,,,,,,,,,,,,,,, -85600,0.31112915,3.7006154,,,,,,,,,,,,,,,,, -85700,0.3185772,3.7865632,,,,,,,,,,,,,,,,, -85800,0.33218876,3.8150663,,,,,,,,,,,,,,,,, -85900,0.31068572,3.7900534,,,,,,,,,,,,,,,,, -86000,0.3430626,3.7792542,,,,,,,,,,,,,,,,, -86100,0.30178794,3.716086,,,,,,,,,,,,,,,,, -86200,0.3453646,3.726133,,,,,,,,,,,,,,,,, -86300,0.30170587,3.724619,,,,,,,,,,,,,,,,, -86400,0.30827898,3.8119607,,,,,,,,,,,,,,,,, -86500,0.3240555,3.7393906,,,,,,,,,,,,,,,,, -86600,0.3276666,3.7136006,,,,,,,,,,,,,,,,, -86700,0.32706589,3.7555094,,,,,,,,,,,,,,,,, -86800,0.33007896,3.739333,,,,,,,,,,,,,,,,, -86900,0.31415227,3.7253716,,,,,,,,,,,,,,,,, -86955,,,0.6825690269470215,1.666361689567566,34.71119184261502,0.6877037882804871,1.6188981533050537,30.182893205006323,3000.0,0.7032247185707092,1.533256769180298,30.31646394999133,3003.0,30270.317069530487,49196.66105890274,30270.317069530487,18922.261901140213,1.1654300689697266,0.0 -87000,0.30085897,3.7132025,,,,,,,,,,,,,,,,, -87100,0.2989087,3.698566,,,,,,,,,,,,,,,,, -87200,0.32695594,3.7499738,,,,,,,,,,,,,,,,, -87300,0.34673423,3.729009,,,,,,,,,,,,,,,,, -87400,0.31690624,3.726211,,,,,,,,,,,,,,,,, -87500,0.30823916,3.7721088,,,,,,,,,,,,,,,,, -87600,0.3333487,3.7057865,,,,,,,,,,,,,,,,, -87700,0.2995191,3.7321935,,,,,,,,,,,,,,,,, -87800,0.35049573,3.707984,,,,,,,,,,,,,,,,, -87900,0.34168792,3.7167633,,,,,,,,,,,,,,,,, -88000,0.31822032,3.6876192,,,,,,,,,,,,,,,,, -88100,0.31860822,3.7487965,,,,,,,,,,,,,,,,, -88200,0.30084926,3.7409742,,,,,,,,,,,,,,,,, -88300,0.30330023,3.7669754,,,,,,,,,,,,,,,,, -88400,0.3171028,3.747029,,,,,,,,,,,,,,,,, -88500,0.31526107,3.736953,,,,,,,,,,,,,,,,, -88600,0.29899663,3.6783938,,,,,,,,,,,,,,,,, -88700,0.3151874,3.735583,,,,,,,,,,,,,,,,, -88800,0.33105114,3.7705383,,,,,,,,,,,,,,,,, -88900,0.31195673,3.7609699,,,,,,,,,,,,,,,,, -89000,0.31970423,3.73662,,,,,,,,,,,,,,,,, -89100,0.33730513,3.7460635,,,,,,,,,,,,,,,,, -89200,0.32047477,3.7434046,,,,,,,,,,,,,,,,, -89300,0.32403046,3.7303708,,,,,,,,,,,,,,,,, -89370,,,0.6950032711029053,1.60042405128479,35.154835319669424,0.6891793012619019,1.609760046005249,30.554063056987584,3000.0,0.705083966255188,1.517716407775879,30.645274576054728,3003.0,31110.298770189285,50504.37959980965,31110.298770189285,19389.877872228622,1.2036826610565186,0.0 -89400,0.29410088,3.7051287,,,,,,,,,,,,,,,,, -89500,0.30705646,3.6878765,,,,,,,,,,,,,,,,, -89600,0.3325733,3.6951923,,,,,,,,,,,,,,,,, -89700,0.34518626,3.7372146,,,,,,,,,,,,,,,,, -89800,0.30885297,3.7197104,,,,,,,,,,,,,,,,, -89900,0.30386576,3.714039,,,,,,,,,,,,,,,,, -90000,0.33102787,3.7883804,,,,,,,,,,,,,,,,, -90100,0.3304548,3.7515895,,,,,,,,,,,,,,,,, -90200,0.31185246,3.7623756,,,,,,,,,,,,,,,,, -90300,0.31895173,3.7773066,,,,,,,,,,,,,,,,, -90400,0.32388246,3.6991594,,,,,,,,,,,,,,,,, -90500,0.31352884,3.6744387,,,,,,,,,,,,,,,,, -90600,0.34665793,3.7368078,,,,,,,,,,,,,,,,, -90700,0.3517159,3.7928903,,,,,,,,,,,,,,,,, -90800,0.33079603,3.7566347,,,,,,,,,,,,,,,,, -90900,0.3129494,3.7406893,,,,,,,,,,,,,,,,, -91000,0.3035207,3.7250946,,,,,,,,,,,,,,,,, -91100,0.3165166,3.7261639,,,,,,,,,,,,,,,,, -91200,0.3687947,3.765122,,,,,,,,,,,,,,,,, -91300,0.32781687,3.7677355,,,,,,,,,,,,,,,,, -91400,0.33003327,3.694552,,,,,,,,,,,,,,,,, -91500,0.31893277,3.7035801,,,,,,,,,,,,,,,,, -91600,0.32565323,3.7134144,,,,,,,,,,,,,,,,, -91700,0.35099363,3.7698636,,,,,,,,,,,,,,,,, -91785,,,0.6907691955566406,1.61209237575531,34.67632533933311,0.6889809370040894,1.6088262796401978,30.47947166718364,3000.0,0.7055255770683289,1.5153443813323977,30.73101495101404,3003.0,31950.33271098137,51806.25148367882,31950.33271098137,19851.59775352478,1.242016077041626,0.0 -91800,0.33304018,3.728759,,,,,,,,,,,,,,,,, -91900,0.30922315,3.6955473,,,,,,,,,,,,,,,,, -92000,0.35480314,3.7663603,,,,,,,,,,,,,,,,, -92100,0.30862907,3.7259247,,,,,,,,,,,,,,,,, -92200,0.32867458,3.7891185,,,,,,,,,,,,,,,,, -92300,0.31572628,3.6993113,,,,,,,,,,,,,,,,, -92400,0.33046475,3.7983356,,,,,,,,,,,,,,,,, -92500,0.32045197,3.6984644,,,,,,,,,,,,,,,,, -92600,0.3159597,3.669921,,,,,,,,,,,,,,,,, -92700,0.33148232,3.7478828,,,,,,,,,,,,,,,,, -92800,0.31403306,3.7576208,,,,,,,,,,,,,,,,, -92900,0.32308576,3.779254,,,,,,,,,,,,,,,,, -93000,0.315524,3.7671838,,,,,,,,,,,,,,,,, -93100,0.3278051,3.7221282,,,,,,,,,,,,,,,,, -93200,0.3406132,3.699546,,,,,,,,,,,,,,,,, -93300,0.33133563,3.673524,,,,,,,,,,,,,,,,, -93400,0.32743922,3.721115,,,,,,,,,,,,,,,,, -93500,0.32762498,3.6711116,,,,,,,,,,,,,,,,, -93600,0.32003486,3.7029955,,,,,,,,,,,,,,,,, -93700,0.3308878,3.7330112,,,,,,,,,,,,,,,,, -93800,0.29230204,3.6648657,,,,,,,,,,,,,,,,, -93900,0.31652328,3.7184181,,,,,,,,,,,,,,,,, -94000,0.34232128,3.7385552,,,,,,,,,,,,,,,,, -94100,0.33278108,3.7862382,,,,,,,,,,,,,,,,, -94200,,,0.7140809893608093,1.4977169036865234,36.6485987570553,0.6901712417602539,1.6079293489456177,30.60468141212876,3000.0,0.7057114839553833,1.514120101928711,30.51289879789388,3003.0,32790.51075673103,53128.12596178055,32790.51075673103,20333.17138814926,1.2815396785736084,0.0 -94200,0.32705504,3.690223,,,,,,,,,,,,,,,,, -94300,0.3226949,3.679449,,,,,,,,,,,,,,,,, -94400,0.33468264,3.688405,,,,,,,,,,,,,,,,, -94500,0.31260514,3.6915307,,,,,,,,,,,,,,,,, -94600,0.32517865,3.7265284,,,,,,,,,,,,,,,,, -94700,0.35967043,3.6902997,,,,,,,,,,,,,,,,, -94800,0.3318577,3.7014446,,,,,,,,,,,,,,,,, -94900,0.32659894,3.761705,,,,,,,,,,,,,,,,, -95000,0.32639506,3.688034,,,,,,,,,,,,,,,,, -95100,0.3354474,3.6895773,,,,,,,,,,,,,,,,, -95200,0.35050485,3.6975186,,,,,,,,,,,,,,,,, -95300,0.32123756,3.7638853,,,,,,,,,,,,,,,,, -95400,0.38571313,3.70802,,,,,,,,,,,,,,,,, -95500,0.33273205,3.6980367,,,,,,,,,,,,,,,,, -95600,0.34477875,3.7091904,,,,,,,,,,,,,,,,, -95700,0.3319727,3.7550447,,,,,,,,,,,,,,,,, -95800,0.31557524,3.6598163,,,,,,,,,,,,,,,,, -95900,0.3327746,3.6610081,,,,,,,,,,,,,,,,, -96000,0.34167892,3.7231193,,,,,,,,,,,,,,,,, -96100,0.35256746,3.7222672,,,,,,,,,,,,,,,,, -96200,0.33714068,3.6214817,,,,,,,,,,,,,,,,, -96300,0.32671914,3.6959236,,,,,,,,,,,,,,,,, -96400,0.32219627,3.6822152,,,,,,,,,,,,,,,,, -96500,0.3206809,3.696819,,,,,,,,,,,,,,,,, -96600,0.34082744,3.7480166,,,,,,,,,,,,,,,,, -96616,,,0.6998093128204346,1.5723187923431396,35.22440418347292,0.6890801191329956,1.608446478843689,30.391074376933528,3000.0,0.7062111496925354,1.511307954788208,30.341710528802896,3003.0,33630.70009255409,54448.7342133522,33630.70009255409,20813.4673306942,1.321253538131714,0.0 -96700,0.3216187,3.689965,,,,,,,,,,,,,,,,, -96800,0.33957687,3.713169,,,,,,,,,,,,,,,,, -96900,0.35956442,3.7066898,,,,,,,,,,,,,,,,, -97000,0.3490901,3.7088974,,,,,,,,,,,,,,,,, -97100,0.32787862,3.6509256,,,,,,,,,,,,,,,,, -97200,0.33886597,3.713072,,,,,,,,,,,,,,,,, -97300,0.35547057,3.728868,,,,,,,,,,,,,,,,, -97400,0.32424432,3.6845207,,,,,,,,,,,,,,,,, -97500,0.3505095,3.6549478,,,,,,,,,,,,,,,,, -97600,0.34451354,3.7097986,,,,,,,,,,,,,,,,, -97700,0.3552802,3.6954503,,,,,,,,,,,,,,,,, -97800,0.35904175,3.7184346,,,,,,,,,,,,,,,,, -97900,0.33744442,3.6986496,,,,,,,,,,,,,,,,, -98000,0.34100664,3.6899204,,,,,,,,,,,,,,,,, -98100,0.35095483,3.72653,,,,,,,,,,,,,,,,, -98200,0.3334781,3.6891477,,,,,,,,,,,,,,,,, -98300,0.33492795,3.6779425,,,,,,,,,,,,,,,,, -98400,0.32768592,3.6624746,,,,,,,,,,,,,,,,, -98500,0.32799152,3.6868398,,,,,,,,,,,,,,,,, -98600,0.35086584,3.7191954,,,,,,,,,,,,,,,,, -98700,0.34166658,3.6976528,,,,,,,,,,,,,,,,, -98800,0.33139467,3.6417587,,,,,,,,,,,,,,,,, -98900,0.35294667,3.7172303,,,,,,,,,,,,,,,,, -99000,0.33949292,3.6534576,,,,,,,,,,,,,,,,, -99032,,,0.6960033774375916,1.58683180809021,35.784215910569934,0.6911879777908325,1.6049529314041138,30.653904813297338,3000.0,0.706362247467041,1.5094925165176392,30.670146910166952,3003.0,34470.83130598068,55763.27522611618,34470.83130598068,21287.758934021,1.3601250648498535,0.0 -99100,0.35231644,3.690014,,,,,,,,,,,,,,,,, -99200,0.3279873,3.653791,,,,,,,,,,,,,,,,, -99300,0.33465403,3.7084587,,,,,,,,,,,,,,,,, -99400,0.3601969,3.6896935,,,,,,,,,,,,,,,,, -99500,0.33629358,3.6653447,,,,,,,,,,,,,,,,, -99600,0.3306396,3.6775963,,,,,,,,,,,,,,,,, -99700,0.3302523,3.644229,,,,,,,,,,,,,,,,, -99800,0.3261925,3.64791,,,,,,,,,,,,,,,,, -99900,0.36718,3.696951,,,,,,,,,,,,,,,,, -100000,0.34019706,3.6983063,,,,,,,,,,,,,,,,, -100100,0.35594556,3.711332,,,,,,,,,,,,,,,,, -100200,0.34752208,3.6929975,,,,,,,,,,,,,,,,, -100300,0.32820818,3.64849,,,,,,,,,,,,,,,,, -100400,0.35092738,3.682934,,,,,,,,,,,,,,,,, -100500,0.36535326,3.6920907,,,,,,,,,,,,,,,,, -100600,0.33750984,3.647812,,,,,,,,,,,,,,,,, -100700,0.3528077,3.6729178,,,,,,,,,,,,,,,,, -100800,0.35459915,3.680488,,,,,,,,,,,,,,,,, -100900,0.33972272,3.6895337,,,,,,,,,,,,,,,,, -101000,0.35185498,3.691432,,,,,,,,,,,,,,,,, -101100,0.34181082,3.6579084,,,,,,,,,,,,,,,,, -101200,0.3521709,3.6712537,,,,,,,,,,,,,,,,, -101300,0.32517463,3.6587622,,,,,,,,,,,,,,,,, -101400,0.33587822,3.6785176,,,,,,,,,,,,,,,,, -101448,,,0.7109413743019104,1.5052330493927002,36.498945434107455,0.6911631226539612,1.6019834280014038,30.694062047415063,3000.0,0.7069781422615051,1.5060288906097412,30.958594252207657,3003.0,35310.982800245285,57066.73641419411,35310.982800245285,21750.94939732552,1.3992786407470703,0.0 -101500,0.3505531,3.6611319,,,,,,,,,,,,,,,,, -101600,0.35172388,3.7304585,,,,,,,,,,,,,,,,, -101700,0.34863013,3.6797009,,,,,,,,,,,,,,,,, -101800,0.34767318,3.651066,,,,,,,,,,,,,,,,, -101900,0.36721066,3.744429,,,,,,,,,,,,,,,,, -102000,0.34442207,3.6607547,,,,,,,,,,,,,,,,, -102100,0.33920258,3.6339276,,,,,,,,,,,,,,,,, -102200,0.3455678,3.7059643,,,,,,,,,,,,,,,,, -102300,0.35553783,3.76236,,,,,,,,,,,,,,,,, -102400,0.36469826,3.723933,,,,,,,,,,,,,,,,, -102500,0.38345698,3.667465,,,,,,,,,,,,,,,,, -102600,0.3606844,3.6247377,,,,,,,,,,,,,,,,, -102700,0.35623637,3.6343148,,,,,,,,,,,,,,,,, -102800,0.3452842,3.6337235,,,,,,,,,,,,,,,,, -102900,0.351586,3.7014718,,,,,,,,,,,,,,,,, -103000,0.3604491,3.6737394,,,,,,,,,,,,,,,,, -103100,0.35106072,3.6329558,,,,,,,,,,,,,,,,, -103200,0.34055713,3.5859768,,,,,,,,,,,,,,,,, -103300,0.33650568,3.619954,,,,,,,,,,,,,,,,, -103400,0.35230175,3.6732385,,,,,,,,,,,,,,,,, -103500,0.3449552,3.6559625,,,,,,,,,,,,,,,,, -103600,0.362841,3.653297,,,,,,,,,,,,,,,,, -103700,0.3575864,3.6769376,,,,,,,,,,,,,,,,, -103800,0.35681137,3.6676238,,,,,,,,,,,,,,,,, -103863,,,0.6996182203292847,1.565675139427185,36.38101795307617,0.6910267472267151,1.6026326417922974,30.47832296179434,3000.0,0.7081633806228638,1.5044052600860596,30.9335101099981,3003.0,36151.15802383423,58371.27501010895,36151.15802383423,22215.19164896012,1.43782639503479,0.0 -103900,0.35869628,3.712956,,,,,,,,,,,,,,,,, -104000,0.37462032,3.676192,,,,,,,,,,,,,,,,, -104100,0.37734604,3.6782813,,,,,,,,,,,,,,,,, -104200,0.35988513,3.6872997,,,,,,,,,,,,,,,,, -104300,0.3656632,3.7013652,,,,,,,,,,,,,,,,, -104400,0.34425375,3.6347601,,,,,,,,,,,,,,,,, -104500,0.3546196,3.6942387,,,,,,,,,,,,,,,,, -104600,0.36096564,3.6522608,,,,,,,,,,,,,,,,, -104700,0.3606025,3.6772957,,,,,,,,,,,,,,,,, -104800,0.3573039,3.6955369,,,,,,,,,,,,,,,,, -104900,0.36136523,3.6698005,,,,,,,,,,,,,,,,, -105000,0.37432203,3.700799,,,,,,,,,,,,,,,,, -105100,0.34580925,3.6694188,,,,,,,,,,,,,,,,, -105200,0.36009824,3.7049444,,,,,,,,,,,,,,,,, -105300,0.36210635,3.632008,,,,,,,,,,,,,,,,, -105400,0.3419515,3.6319833,,,,,,,,,,,,,,,,, -105500,0.34660295,3.6399753,,,,,,,,,,,,,,,,, -105600,0.38386157,3.6846387,,,,,,,,,,,,,,,,, -105700,0.3503563,3.6556835,,,,,,,,,,,,,,,,, -105800,0.3730579,3.6921124,,,,,,,,,,,,,,,,, -105900,0.35000667,3.6632187,,,,,,,,,,,,,,,,, -106000,0.3383578,3.6489427,,,,,,,,,,,,,,,,, -106100,0.3647841,3.72755,,,,,,,,,,,,,,,,, -106200,0.3484916,3.6887074,,,,,,,,,,,,,,,,, -106278,,,0.7071945071220398,1.5247044563293457,35.68798725043688,0.6906920075416565,1.6015530824661257,30.56966933807441,3000.0,0.7085352540016174,1.5015463829040527,30.86064708275572,3003.0,36991.27319264412,59683.38256788254,36991.27319264412,22687.063174963,1.4773857593536377,0.0 -106300,0.37776953,3.6963534,,,,,,,,,,,,,,,,, -106400,0.362591,3.6661007,,,,,,,,,,,,,,,,, -106500,0.37286437,3.656394,,,,,,,,,,,,,,,,, -106600,0.39193186,3.6296165,,,,,,,,,,,,,,,,, -106700,0.37370452,3.6552334,,,,,,,,,,,,,,,,, -106800,0.36124417,3.686755,,,,,,,,,,,,,,,,, -106900,0.37433958,3.640254,,,,,,,,,,,,,,,,, -107000,0.3624874,3.722633,,,,,,,,,,,,,,,,, -107100,0.37499267,3.6213682,,,,,,,,,,,,,,,,, -107200,0.35227498,3.6286345,,,,,,,,,,,,,,,,, -107300,0.35941818,3.6404035,,,,,,,,,,,,,,,,, -107400,0.3705651,3.6905246,,,,,,,,,,,,,,,,, -107500,0.38223284,3.637185,,,,,,,,,,,,,,,,, -107600,0.35818395,3.5900235,,,,,,,,,,,,,,,,, -107700,0.37594885,3.6729817,,,,,,,,,,,,,,,,, -107800,0.36394018,3.635716,,,,,,,,,,,,,,,,, -107900,0.36626154,3.6272187,,,,,,,,,,,,,,,,, -108000,0.36680463,3.6915462,,,,,,,,,,,,,,,,, -108100,0.3796949,3.6579769,,,,,,,,,,,,,,,,, -108200,0.40413994,3.6875973,,,,,,,,,,,,,,,,, -108300,0.35680747,3.6753895,,,,,,,,,,,,,,,,, -108400,0.36858025,3.6288176,,,,,,,,,,,,,,,,, -108500,0.3772965,3.6617897,,,,,,,,,,,,,,,,, -108600,0.37062785,3.6479025,,,,,,,,,,,,,,,,, -108693,,,0.7097283601760864,1.5101783275604248,36.4322003452034,0.691696286201477,1.6009178161621094,30.72172265147943,3000.0,0.7084887623786926,1.4992235898971558,30.698081478592425,3003.0,37831.44410777092,60993.79703378677,37831.44410777092,23157.18415951729,1.517770767211914,0.0 -108700,0.34975046,3.6165183,,,,,,,,,,,,,,,,, -108800,0.34523392,3.5992289,,,,,,,,,,,,,,,,, -108900,0.34616655,3.6189966,,,,,,,,,,,,,,,,, -109000,0.38966042,3.6964152,,,,,,,,,,,,,,,,, -109100,0.36835277,3.6351018,,,,,,,,,,,,,,,,, -109200,0.37794366,3.6487398,,,,,,,,,,,,,,,,, -109300,0.37899664,3.614447,,,,,,,,,,,,,,,,, -109400,0.36613947,3.5853138,,,,,,,,,,,,,,,,, -109500,0.38673073,3.6593409,,,,,,,,,,,,,,,,, -109600,0.373628,3.6256893,,,,,,,,,,,,,,,,, -109700,0.39946166,3.640563,,,,,,,,,,,,,,,,, -109800,0.4044709,3.6201184,,,,,,,,,,,,,,,,, -109900,0.38438922,3.665064,,,,,,,,,,,,,,,,, -110000,0.35489547,3.5969553,,,,,,,,,,,,,,,,, -110100,0.36314645,3.6536748,,,,,,,,,,,,,,,,, -110200,0.3684349,3.67822,,,,,,,,,,,,,,,,, -110300,0.38400504,3.604145,,,,,,,,,,,,,,,,, -110400,0.3593461,3.6310914,,,,,,,,,,,,,,,,, -110500,0.36295134,3.6328628,,,,,,,,,,,,,,,,, -110600,0.37237284,3.582436,,,,,,,,,,,,,,,,, -110700,0.3698367,3.621214,,,,,,,,,,,,,,,,, -110800,0.36871728,3.6595669,,,,,,,,,,,,,,,,, -110900,0.39364767,3.68224,,,,,,,,,,,,,,,,, -111000,0.38993594,3.6737537,,,,,,,,,,,,,,,,, -111100,0.3752285,3.6681063,,,,,,,,,,,,,,,,, -111109,,,0.7096648812294006,1.506799578666687,36.62844563646342,0.6907168030738831,1.6045445203781128,30.522907929975723,3000.0,0.7071989178657532,1.5012623071670532,30.854497798748653,3003.0,38671.49594020844,62304.079092502594,38671.49594020844,23627.291821718216,1.559746265411377,0.0 -111200,0.37389466,3.6396332,,,,,,,,,,,,,,,,, -111300,0.3708535,3.6148949,,,,,,,,,,,,,,,,, -111400,0.39834803,3.6514256,,,,,,,,,,,,,,,,, -111500,0.39260885,3.6678126,,,,,,,,,,,,,,,,, -111600,0.39266565,3.6436305,,,,,,,,,,,,,,,,, -111700,0.38596675,3.6532183,,,,,,,,,,,,,,,,, -111800,0.36161074,3.596621,,,,,,,,,,,,,,,,, -111900,0.36882135,3.649876,,,,,,,,,,,,,,,,, -112000,0.4011229,3.5917861,,,,,,,,,,,,,,,,, -112100,0.3904602,3.600975,,,,,,,,,,,,,,,,, -112200,0.3855909,3.7055092,,,,,,,,,,,,,,,,, -112300,0.40800655,3.6646073,,,,,,,,,,,,,,,,, -112400,0.38119596,3.597448,,,,,,,,,,,,,,,,, -112500,0.37951466,3.6300724,,,,,,,,,,,,,,,,, -112600,0.38112018,3.6414282,,,,,,,,,,,,,,,,, -112700,0.35926193,3.5958624,,,,,,,,,,,,,,,,, -112800,0.37103993,3.6169512,,,,,,,,,,,,,,,,, -112900,0.3965224,3.6367512,,,,,,,,,,,,,,,,, -113000,0.37663487,3.658852,,,,,,,,,,,,,,,,, -113100,0.38493383,3.5935013,,,,,,,,,,,,,,,,, -113200,0.37464136,3.6586943,,,,,,,,,,,,,,,,, -113300,0.39415964,3.6867023,,,,,,,,,,,,,,,,, -113400,0.38995728,3.6436298,,,,,,,,,,,,,,,,, -113500,0.37402162,3.63626,,,,,,,,,,,,,,,,, -113523,,,0.7205490469932556,1.4559381008148191,37.3473881454824,0.6923410892486572,1.5988622903823853,30.74776238203665,3000.0,0.7090232968330383,1.4939895868301392,30.81991409290025,3003.0,39511.46545481682,63604.90850496292,39511.46545481682,24088.021463871,1.606471061706543,0.0 -113600,0.37910524,3.6709795,,,,,,,,,,,,,,,,, -113700,0.3934538,3.6780632,,,,,,,,,,,,,,,,, -113800,0.38225174,3.6172462,,,,,,,,,,,,,,,,, -113900,0.38046193,3.689146,,,,,,,,,,,,,,,,, -114000,0.3839561,3.6051433,,,,,,,,,,,,,,,,, -114100,0.3719614,3.5824344,,,,,,,,,,,,,,,,, -114200,0.3819086,3.5833097,,,,,,,,,,,,,,,,, -114300,0.3999239,3.653311,,,,,,,,,,,,,,,,, -114400,0.3762625,3.6235652,,,,,,,,,,,,,,,,, -114500,0.3682117,3.6184754,,,,,,,,,,,,,,,,, -114600,0.38290527,3.637453,,,,,,,,,,,,,,,,, -114700,0.39122656,3.6722813,,,,,,,,,,,,,,,,, -114800,0.37915474,3.596693,,,,,,,,,,,,,,,,, -114900,0.3816267,3.6285734,,,,,,,,,,,,,,,,, -115000,0.39138582,3.6251194,,,,,,,,,,,,,,,,, -115100,0.39285138,3.6230087,,,,,,,,,,,,,,,,, -115200,0.38052925,3.6006055,,,,,,,,,,,,,,,,, -115300,0.39260942,3.6470525,,,,,,,,,,,,,,,,, -115400,0.3760214,3.6318774,,,,,,,,,,,,,,,,, -115500,0.38312453,3.6074464,,,,,,,,,,,,,,,,, -115600,0.4030354,3.6178842,,,,,,,,,,,,,,,,, -115700,0.3782481,3.636632,,,,,,,,,,,,,,,,, -115800,0.37962762,3.6022038,,,,,,,,,,,,,,,,, -115900,0.37241355,3.5949893,,,,,,,,,,,,,,,,, -115939,,,0.7139232158660889,1.487220048904419,36.63052469599194,0.6912499666213989,1.601006269454956,30.716334354820788,3000.0,0.7080239653587341,1.49728524684906,30.801907041652694,3003.0,40351.63733792305,64910.21405529976,40351.63733792305,24553.03221011161,1.6485328674316406,0.0 -116000,0.40050736,3.647921,,,,,,,,,,,,,,,,, -116100,0.36562195,3.583254,,,,,,,,,,,,,,,,, -116200,0.38941374,3.6128638,,,,,,,,,,,,,,,,, -116300,0.37472767,3.5807729,,,,,,,,,,,,,,,,, -116400,0.3963639,3.5969858,,,,,,,,,,,,,,,,, -116500,0.3968776,3.6649933,,,,,,,,,,,,,,,,, -116600,0.39183423,3.6543124,,,,,,,,,,,,,,,,, -116700,0.38662693,3.6002576,,,,,,,,,,,,,,,,, -116800,0.3810862,3.6712003,,,,,,,,,,,,,,,,, -116900,0.37723368,3.6609983,,,,,,,,,,,,,,,,, -117000,0.38776967,3.6608195,,,,,,,,,,,,,,,,, -117100,0.37693226,3.5982685,,,,,,,,,,,,,,,,, -117200,0.4039077,3.6429102,,,,,,,,,,,,,,,,, -117300,0.37510857,3.578107,,,,,,,,,,,,,,,,, -117400,0.40213498,3.6883419,,,,,,,,,,,,,,,,, -117500,0.4026496,3.6485434,,,,,,,,,,,,,,,,, -117600,0.3867943,3.5908663,,,,,,,,,,,,,,,,, -117700,0.41303664,3.7057562,,,,,,,,,,,,,,,,, -117800,0.39502928,3.6330392,,,,,,,,,,,,,,,,, -117900,0.38167495,3.5922704,,,,,,,,,,,,,,,,, -118000,0.38443995,3.56998,,,,,,,,,,,,,,,,, -118100,0.37848127,3.577961,,,,,,,,,,,,,,,,, -118200,0.37266484,3.6161757,,,,,,,,,,,,,,,,, -118300,0.39774406,3.619848,,,,,,,,,,,,,,,,, -118355,,,0.7143579721450806,1.484772801399231,37.03462724545019,0.6923038959503174,1.6000486612319946,30.58042947513172,3000.0,0.7077218294143677,1.4967893362045288,30.7039356366018,3003.0,41191.82350349426,66209.90850758553,41191.82350349426,25012.41903567314,1.6906397342681885,0.0 -118400,0.36815694,3.6101825,,,,,,,,,,,,,,,,, -118500,0.3934073,3.6129947,,,,,,,,,,,,,,,,, -118600,0.40565702,3.6494563,,,,,,,,,,,,,,,,, -118700,0.40625072,3.670567,,,,,,,,,,,,,,,,, -118800,0.40237376,3.5998812,,,,,,,,,,,,,,,,, -118900,0.38446835,3.6147313,,,,,,,,,,,,,,,,, -119000,0.39367235,3.6039984,,,,,,,,,,,,,,,,, -119100,0.3939928,3.591148,,,,,,,,,,,,,,,,, -119200,0.41060704,3.5751595,,,,,,,,,,,,,,,,, -119300,0.3721481,3.5804493,,,,,,,,,,,,,,,,, -119400,0.40632808,3.585218,,,,,,,,,,,,,,,,, -119500,0.39333528,3.5751839,,,,,,,,,,,,,,,,, -119600,0.39288053,3.5821745,,,,,,,,,,,,,,,,, -119700,0.38701952,3.6250503,,,,,,,,,,,,,,,,, -119800,0.39742208,3.6000795,,,,,,,,,,,,,,,,, -119900,0.4022337,3.5849116,,,,,,,,,,,,,,,,, -120000,0.38444746,3.60129,,,,,,,,,,,,,,,,, -120100,0.40419164,3.6560705,,,,,,,,,,,,,,,,, -120200,0.41911212,3.6252475,,,,,,,,,,,,,,,,, -120300,0.387419,3.5959854,,,,,,,,,,,,,,,,, -120400,0.40546626,3.6503196,,,,,,,,,,,,,,,,, -120500,0.37988356,3.5899248,,,,,,,,,,,,,,,,, -120600,0.39020088,3.6282003,,,,,,,,,,,,,,,,, -120700,0.3671636,3.5770319,,,,,,,,,,,,,,,,, -120769,,,0.7247974872589111,1.4364838600158691,37.53693442719608,0.6926262378692627,1.5970616340637207,30.67829971443381,3000.0,0.7095927000045776,1.4932905435562134,30.727894032551443,3003.0,42031.85128903389,67522.48354244232,42031.85128903389,25484.83347582817,1.7393221855163574,0.0 -120800,0.36966953,3.5772457,,,,,,,,,,,,,,,,, -120900,0.3876029,3.5901742,,,,,,,,,,,,,,,,, -121000,0.38527158,3.589912,,,,,,,,,,,,,,,,, -121100,0.4171539,3.5974643,,,,,,,,,,,,,,,,, -121200,0.39802417,3.6085217,,,,,,,,,,,,,,,,, -121300,0.3715839,3.6244879,,,,,,,,,,,,,,,,, -121400,0.39346525,3.624219,,,,,,,,,,,,,,,,, -121500,0.39284334,3.6310616,,,,,,,,,,,,,,,,, -121600,0.37228653,3.6153495,,,,,,,,,,,,,,,,, -121700,0.39953786,3.584927,,,,,,,,,,,,,,,,, -121800,0.4097172,3.6564157,,,,,,,,,,,,,,,,, -121900,0.3970606,3.6406853,,,,,,,,,,,,,,,,, -122000,0.39116517,3.6065652,,,,,,,,,,,,,,,,, -122100,0.4003788,3.597207,,,,,,,,,,,,,,,,, -122200,0.41206124,3.6655195,,,,,,,,,,,,,,,,, -122300,0.38499516,3.6094086,,,,,,,,,,,,,,,,, -122400,0.38066137,3.6127572,,,,,,,,,,,,,,,,, -122500,0.4020645,3.5948977,,,,,,,,,,,,,,,,, -122600,0.4062097,3.6097054,,,,,,,,,,,,,,,,, -122700,0.4166083,3.6529582,,,,,,,,,,,,,,,,, -122800,0.40361527,3.5918286,,,,,,,,,,,,,,,,, -122900,0.38193375,3.5716248,,,,,,,,,,,,,,,,, -123000,0.40359962,3.6116579,,,,,,,,,,,,,,,,, -123100,0.38199797,3.612153,,,,,,,,,,,,,,,,, -123184,,,0.7164363265037537,1.4766035079956057,37.23399471055593,0.6918575167655945,1.5992780923843384,30.603763521386373,3000.0,0.7095230221748352,1.493468999862671,30.812469473793247,3003.0,42871.75427532196,68822.87196969986,42871.75427532196,25945.196885108948,1.7829713821411133,0.0 -123200,0.39685512,3.5655644,,,,,,,,,,,,,,,,, -123300,0.38645488,3.5918753,,,,,,,,,,,,,,,,, -123400,0.4145446,3.639483,,,,,,,,,,,,,,,,, -123500,0.39478093,3.6252627,,,,,,,,,,,,,,,,, -123600,0.394631,3.6415699,,,,,,,,,,,,,,,,, -123700,0.39648637,3.6204062,,,,,,,,,,,,,,,,, -123800,0.37514028,3.5677993,,,,,,,,,,,,,,,,, -123900,0.39156318,3.6154745,,,,,,,,,,,,,,,,, -124000,0.3967715,3.609748,,,,,,,,,,,,,,,,, -124100,0.41275522,3.6189659,,,,,,,,,,,,,,,,, -124200,0.39831358,3.592778,,,,,,,,,,,,,,,,, -124300,0.40165448,3.6137197,,,,,,,,,,,,,,,,, -124400,0.39138862,3.5830717,,,,,,,,,,,,,,,,, -124500,0.37038898,3.5876184,,,,,,,,,,,,,,,,, -124600,0.3974603,3.628659,,,,,,,,,,,,,,,,, -124700,0.38332728,3.6002731,,,,,,,,,,,,,,,,, -124800,0.40443945,3.581548,,,,,,,,,,,,,,,,, -124900,0.4074397,3.608658,,,,,,,,,,,,,,,,, -125000,0.39903784,3.6340966,,,,,,,,,,,,,,,,, -125100,0.39929724,3.629597,,,,,,,,,,,,,,,,, -125200,0.3916244,3.5893261,,,,,,,,,,,,,,,,, -125300,0.3810055,3.545964,,,,,,,,,,,,,,,,, -125400,0.40283906,3.619323,,,,,,,,,,,,,,,,, -125500,0.39111447,3.5900843,,,,,,,,,,,,,,,,, -125598,,,0.7242320775985718,1.4370570182800293,37.27877258251524,0.6923782825469971,1.598088026046753,30.46979271110354,3000.0,0.7091279029846191,1.4928165674209597,30.822391857216715,3003.0,43711.76387667656,70131.25042271614,43711.76387667656,26413.436207294464,1.82772159576416,0.0 -125600,0.40571168,3.62656,,,,,,,,,,,,,,,,, -125700,0.41835144,3.6090443,,,,,,,,,,,,,,,,, -125800,0.3900576,3.6121697,,,,,,,,,,,,,,,,, -125900,0.39412248,3.572685,,,,,,,,,,,,,,,,, -126000,0.39540973,3.5884054,,,,,,,,,,,,,,,,, -126100,0.38805947,3.6168013,,,,,,,,,,,,,,,,, -126200,0.40578213,3.5983644,,,,,,,,,,,,,,,,, -126300,0.3880184,3.5625486,,,,,,,,,,,,,,,,, -126400,0.37807474,3.5764184,,,,,,,,,,,,,,,,, -126500,0.40267172,3.6245472,,,,,,,,,,,,,,,,, -126600,0.40090108,3.6189678,,,,,,,,,,,,,,,,, -126700,0.3719888,3.5572972,,,,,,,,,,,,,,,,, -126800,0.3932298,3.580907,,,,,,,,,,,,,,,,, -126900,0.41095543,3.5710473,,,,,,,,,,,,,,,,, -127000,0.3798033,3.5846384,,,,,,,,,,,,,,,,, -127100,0.37205675,3.5844576,,,,,,,,,,,,,,,,, -127200,0.3864507,3.596756,,,,,,,,,,,,,,,,, -127300,0.40504068,3.6050935,,,,,,,,,,,,,,,,, -127400,0.39460868,3.5886655,,,,,,,,,,,,,,,,, -127500,0.39168584,3.6172376,,,,,,,,,,,,,,,,, -127600,0.40106857,3.5658174,,,,,,,,,,,,,,,,, -127700,0.39232484,3.623443,,,,,,,,,,,,,,,,, -127800,0.37604675,3.5778036,,,,,,,,,,,,,,,,, -127900,0.36986846,3.5961657,,,,,,,,,,,,,,,,, -128000,0.3901932,3.5803168,,,,,,,,,,,,,,,,, -128013,,,0.7244399785995483,1.435945749282837,37.314039333533486,0.6923162937164307,1.597839117050171,30.50014741034056,3000.0,0.7096856832504272,1.493368148803711,30.728892446349207,3003.0,44551.90661859512,71434.78941488266,44551.90661859512,26876.707780361176,1.870469331741333,0.0 -128100,0.38851044,3.5860977,,,,,,,,,,,,,,,,, -128200,0.38835528,3.5835958,,,,,,,,,,,,,,,,, -128300,0.4008386,3.5940156,,,,,,,,,,,,,,,,, -128400,0.38213882,3.573742,,,,,,,,,,,,,,,,, -128500,0.4117652,3.593503,,,,,,,,,,,,,,,,, -128600,0.38676402,3.5546212,,,,,,,,,,,,,,,,, -128700,0.38964188,3.5545776,,,,,,,,,,,,,,,,, -128800,0.40917858,3.6222022,,,,,,,,,,,,,,,,, -128900,0.39272815,3.5817356,,,,,,,,,,,,,,,,, -129000,0.4030414,3.6053317,,,,,,,,,,,,,,,,, -129100,0.3986822,3.5949047,,,,,,,,,,,,,,,,, -129200,0.3906509,3.6230912,,,,,,,,,,,,,,,,, -129300,0.39407754,3.6194487,,,,,,,,,,,,,,,,, -129400,0.40891933,3.5943642,,,,,,,,,,,,,,,,, -129500,0.37686288,3.5512211,,,,,,,,,,,,,,,,, -129600,0.38081726,3.602369,,,,,,,,,,,,,,,,, -129700,0.41232145,3.6710749,,,,,,,,,,,,,,,,, -129800,0.40054512,3.5988913,,,,,,,,,,,,,,,,, -129900,0.36986864,3.5534308,,,,,,,,,,,,,,,,, -130000,0.40285653,3.6093795,,,,,,,,,,,,,,,,, -130100,0.386185,3.5857975,,,,,,,,,,,,,,,,, -130200,0.40643433,3.6035326,,,,,,,,,,,,,,,,, -130300,0.38997558,3.6052728,,,,,,,,,,,,,,,,, -130400,0.3896528,3.5951083,,,,,,,,,,,,,,,,, -130429,,,0.7212345004081726,1.4529694318771362,37.38313524691233,0.6924278736114502,1.598196029663086,30.58731194343409,3000.0,0.7091395258903503,1.493313193321228,30.76292408930312,3003.0,45392.08640527725,72738.71410131454,45392.08640527725,27340.31756520272,1.9223430156707764,0.0 -130500,0.36642152,3.5944164,,,,,,,,,,,,,,,,, -130600,0.3898282,3.601053,,,,,,,,,,,,,,,,, -130700,0.39123294,3.6018586,,,,,,,,,,,,,,,,, -130800,0.3975207,3.6362743,,,,,,,,,,,,,,,,, -130900,0.38136265,3.5649142,,,,,,,,,,,,,,,,, -131000,0.38958126,3.590117,,,,,,,,,,,,,,,,, -131100,0.38324502,3.5930705,,,,,,,,,,,,,,,,, -131200,0.3892511,3.5313094,,,,,,,,,,,,,,,,, -131300,0.40147328,3.5930538,,,,,,,,,,,,,,,,, -131400,0.3969067,3.567947,,,,,,,,,,,,,,,,, -131500,0.39102238,3.5846536,,,,,,,,,,,,,,,,, -131600,0.44167393,3.6414733,,,,,,,,,,,,,,,,, -131700,0.4158844,3.6286094,,,,,,,,,,,,,,,,, -131800,0.3950753,3.636835,,,,,,,,,,,,,,,,, -131900,0.41779307,3.560917,,,,,,,,,,,,,,,,, -132000,0.39074722,3.5778003,,,,,,,,,,,,,,,,, -132100,0.39351374,3.59992,,,,,,,,,,,,,,,,, -132200,0.39336324,3.6083016,,,,,,,,,,,,,,,,, -132300,0.39102027,3.6541564,,,,,,,,,,,,,,,,, -132400,0.39822876,3.5785828,,,,,,,,,,,,,,,,, -132500,0.41828898,3.5789793,,,,,,,,,,,,,,,,, -132600,0.39441472,3.641481,,,,,,,,,,,,,,,,, -132700,0.38456216,3.5589888,,,,,,,,,,,,,,,,, -132800,0.38842168,3.5297453,,,,,,,,,,,,,,,,, -132843,,,0.7237308025360107,1.441246747970581,37.43974567291994,0.6926634311676025,1.598312497138977,30.54856858515057,3000.0,0.7089187502861023,1.4930022954940796,30.805018297453067,3003.0,46231.99690008164,74050.91628265381,46231.99690008164,27812.481696128845,1.968170166015625,0.0 -132900,0.401232,3.6046839,,,,,,,,,,,,,,,,, -133000,0.41205806,3.6178136,,,,,,,,,,,,,,,,, -133100,0.39118704,3.6027083,,,,,,,,,,,,,,,,, -133200,0.39902928,3.6445103,,,,,,,,,,,,,,,,, -133300,0.38537857,3.6015842,,,,,,,,,,,,,,,,, -133333,,,0.7221422791481018,1.4469060897827148,37.50803900245006,0.6926758289337158,1.598305106163025,30.54901382135056,3000.0,0.7088838815689087,1.4929988384246826,30.78181083377449,3003.0,46401.92576980591,74690.03471302986,46401.92576980591,28281.60786294937,2.0151820182800293,0.0 -133333,,,,,,,,,,,,,,46401.92576980591,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 732edb021..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -859.8778374195099,0.0,30.05414128303528,1,0,30.05414128303528,0.0007088489946909,0.0,10.966498374938965,3003,889.9320287704468,0.0006362841231748,0.0,10.959882736206056,0.0004835649742744,0.0,10.980294227600098,3000 -1459.2620244026184,0.0198702812194824,870.1676576137543,2413,0,870.1676576137543,0.3826157748699188,8.030870962674728,4.261987686157227,3003,2329.528389930725,0.4128057062625885,14.08984575280066,3.936654329299927,0.3986683189868927,9.633888628769736,4.049123287200928,3000 -1924.25225186348,0.0505769252777099,1710.180258989334,4824,0,1710.180258989334,0.5445354580879211,19.00498182839845,2.6518774032592773,3003,3634.642117023468,0.5432161092758179,24.829318029776022,2.668916940689087,0.5427582859992981,20.365649432861616,2.6340794563293457,3000 -2394.763978004456,0.0766811370849609,2550.388504981994,7235,0,2550.388504981994,0.5854976773262024,21.73489214068676,2.21711802482605,3003,4945.470089673996,0.581144392490387,27.389879639296822,2.2662315368652344,0.5851012468338013,23.189117839268576,2.2258543968200684,3000 -2832.442953109741,0.1034173965454101,3390.442659378052,9648,0,3390.442659378052,0.6137005686759949,23.58660872714324,1.992855429649353,3003,6223.30933713913,0.592995285987854,28.40913522209264,2.1528778076171875,0.6077048182487488,24.757049137021426,2.034355640411377,3000 -3329.537977933884,0.1310830116271972,4230.618235588074,12063,0,4230.618235588074,0.6264017224311829,24.72895569223176,1.8653494119644165,3003,7560.686071395874,0.6006006598472595,28.907384311822263,2.061574697494507,0.6192111372947693,25.72642706038576,1.9119383096694944,3000 -3822.848229885101,0.1623303890228271,5070.55059671402,14479,0,5070.55059671402,0.6380454301834106,25.639332507671263,1.7743580341339111,3003,8894.039239883423,0.6118413209915161,29.58988159437644,1.9555613994598389,0.6300107836723328,26.35269237560661,1.8212212324142456,3000 -4307.02596282959,0.1945400238037109,5910.659248828888,16895,0,5910.659248828888,0.6462843418121338,26.05308990904883,1.706995964050293,3003,10218.438910245895,0.6211254596710205,29.799160596214755,1.907320499420166,0.6405500173568726,26.86963051346412,1.7542517185211182,3000 -4810.599547386169,0.2221553325653076,6750.589684724808,19311,0,6750.589684724808,0.6546743512153625,26.37794400092443,1.6618460416793823,3003,11562.04913663864,0.6253435015678406,30.11935736955452,1.8560582399368288,0.6459808349609375,27.16554742317671,1.716254472732544,3000 -5276.457659244537,0.2528328895568847,7590.577194213867,21727,0,7590.577194213867,0.6591482162475586,27.06400431392655,1.6333097219467163,3003,12868.004507541656,0.625741720199585,30.296728445860584,1.8575769662857056,0.6483986377716064,27.708869311230742,1.690137267112732,3000 -5809.7654638290405,0.2813522815704345,8430.762261390686,24142,0,8430.762261390686,0.6639590859413147,27.245981275229024,1.604680418968201,3003,14241.60857963562,0.6286821961402893,30.94324298130113,1.842318058013916,0.6500725150108337,27.84176105484295,1.673073410987854,3000 -6348.470281600952,0.3095395565032959,9270.876539945602,26559,0,9270.876539945602,0.6608215570449829,26.744297453024025,1.6026358604431152,3003,15620.536381721497,0.6353502869606018,30.71224749308285,1.772486925125122,0.6529862880706787,27.82423233567216,1.6593254804611206,3000 -6806.992643594742,0.3389849662780761,10111.109763383864,28976,0,10111.109763383864,0.6632967591285706,27.30567537970401,1.58885395526886,3003,16919.401398181915,0.6322528719902039,30.72612753003488,1.8003666400909424,0.6520811915397644,27.845179168567565,1.648480772972107,3000 -7369.476313352585,0.3675158023834228,10951.089903831482,31392,0,10951.089903831482,0.6667131781578064,27.34374131007221,1.5764648914337158,3003,18321.97534775734,0.6668255925178528,33.53180261019265,1.5802048444747925,0.6552553772926331,27.770932078759643,1.6375457048416138,3000 -7893.05903172493,0.4018421173095703,11791.302706718445,33809,0,11791.302706718445,0.665690541267395,27.31780751025964,1.5672624111175537,3003,19685.88606619835,0.6364154815673828,30.883054872066605,1.7697381973266602,0.6572515964508057,28.07798566784241,1.6273834705352783,3000 -8378.304702997208,0.4316813945770263,12631.28459262848,36227,0,12631.28459262848,0.6670966148376465,27.478248748542704,1.5575518608093262,3003,21011.22130537033,0.6372659206390381,31.08953855694967,1.7720032930374146,0.6579459309577942,28.23293001497443,1.616313338279724,3000 -8841.15231704712,0.4688034057617187,13471.477101564407,38643,0,13471.477101564407,0.6660391688346863,27.41756213472292,1.5656033754348757,3003,22314.383660316467,0.6398760676383972,31.15971808188359,1.7399176359176636,0.6560736894607544,28.02467696226403,1.6214781999588013,3000 -9422.134581565855,0.5005180835723877,14311.619084835052,41061,0,14311.619084835052,0.672163188457489,27.999417482086034,1.5471595525741575,3003,23735.61979198456,0.6372715830802917,30.75969332973245,1.7691075801849363,0.6587767004966736,28.07393183728591,1.6140390634536743,3000 -9898.793123483658,0.5321898460388184,15151.775276899338,43479,0,15151.775276899338,0.6713265180587769,27.56919141028713,1.540757417678833,3003,25052.546183347706,0.6386000514030457,31.427426086467367,1.760711908340454,0.6623228192329407,28.480353182921547,1.598664164543152,3000 -10439.45487523079,0.5634627342224121,15991.93740272522,45897,0,15991.93740272522,0.6712567806243896,27.98019850766508,1.539738416671753,3003,26433.4802134037,0.6438037753105164,31.366122874842,1.7203915119171145,0.6614301204681396,28.56949921427732,1.597715973854065,3000 -10959.045782327652,0.5941922664642334,16832.123507976532,48315,0,16832.123507976532,0.6734995245933533,28.134675838345206,1.5197216272354126,3003,27793.366586208344,0.6425302028656006,31.29823991504049,1.7314515113830566,0.6614673137664795,28.392376189178425,1.5935068130493164,3000 -11434.958333969116,0.6256530284881592,17672.332036733627,50733,0,17672.332036733627,0.6736041307449341,27.925587206144616,1.5203142166137695,3003,29109.59782481193,0.6550332307815552,31.87787872913857,1.6446468830108645,0.664765477180481,28.684633210753194,1.5815393924713137,3000 -12006.873229980469,0.6573050022125244,18512.45762705803,53150,0,18512.45762705803,0.6757422685623169,28.26698253487057,1.5066756010055542,3003,30521.75099492073,0.6440820097923279,31.623226056556543,1.7206270694732666,0.6641455292701721,28.54921184734869,1.5732247829437256,3000 -12472.107189893724,0.6961920261383057,19352.615475177765,55566,0,19352.615475177765,0.6787520051002502,28.449937622484967,1.5035967826843262,3003,31827.267405748367,0.6458576917648315,31.80318325835912,1.7048068046569824,0.6658813953399658,28.654052868631133,1.5674703121185305,3000 -13002.97655892372,0.7319936752319336,20192.78905391693,57983,0,20192.78905391693,0.678345263004303,28.715793486279985,1.492568016052246,3003,33198.427434682846,0.6503816843032837,32.32253220630714,1.6813665628433228,0.6673941016197205,28.76944588602514,1.5596922636032104,3000 -13500.465492010117,0.7707424163818359,21032.677373170853,60399,0,21032.677373170853,0.6783917546272278,28.376137687711587,1.4889075756072998,3003,34535.92731380463,0.6484251022338867,31.836385315191027,1.6919599771499634,0.667815625667572,28.898338564560326,1.5593230724334717,3000 -14147.967857837675,0.803971529006958,21872.860898256306,62817,0,21872.860898256306,0.6795189380645752,28.410224716126773,1.4860724210739136,3003,36023.72711586952,0.6717912554740906,33.68370344851488,1.5358587503433228,0.6687455773353577,28.54941150424726,1.5462676286697388,3000 -14705.24766755104,0.8449218273162842,22712.9723572731,65235,0,22712.9723572731,0.6814130544662476,28.651393552679146,1.4726934432983398,3003,37421.23839473725,0.6527928113937378,32.17369020688684,1.659690499305725,0.6705434322357178,28.912173341920933,1.5372453927993774,3000 -15250.439332008362,0.881615400314331,23552.86766433716,67652,0,23552.86766433716,0.6824589371681213,28.896382659621786,1.463875412940979,3003,38806.44243097305,0.6534407138824463,32.31266434351574,1.6673787832260132,0.6718453764915466,29.001009543577776,1.5317059755325315,3000 -15839.218275308607,0.917823314666748,24392.89284348488,70069,0,24392.89284348488,0.6847829818725586,28.9900289674124,1.4570910930633545,3003,40235.36283278465,0.6589791178703308,32.72365867691716,1.6300126314163208,0.673568844795227,29.25410080130436,1.5262335538864136,3000 -16385.491423606873,0.9566261768341064,25232.844252109528,72486,0,25232.844252109528,0.6864911913871765,29.154292326536787,1.4465367794036863,3003,41621.70757579804,0.6565991044044495,32.39992250394725,1.6378833055496216,0.6732092499732971,29.28347733531719,1.5226255655288696,3000 -16930.741188049316,0.9927854537963868,26072.833317756653,74903,0,26072.833317756653,0.6882458925247192,29.18233746463092,1.440338373184204,3003,43007.06311130524,0.65643310546875,32.34811927392111,1.6501281261444092,0.6752055287361145,29.52027020865552,1.5135127305984497,3000 -17470.37271785736,1.029709815979004,26912.88521933556,77321,0,26912.88521933556,0.6892685294151306,29.210126730147856,1.427660584449768,3003,44386.86210536957,0.6606552600860596,33.132585237997496,1.6113643646240234,0.6767429709434509,29.46409467654244,1.5007519721984863,3000 -17985.91515660286,1.0653765201568604,27752.805153131485,79737,0,27752.805153131485,0.6904189586639404,29.423481171518823,1.426446557044983,3003,45742.4435338974,0.6597093939781189,32.721303780072326,1.622588872909546,0.6777225136756897,29.9094847833621,1.5018718242645264,3000 -18510.5944993496,1.1032533645629885,28592.93453645706,82154,0,28592.93453645706,0.6918599009513855,29.724196297362315,1.4198691844940186,3003,47107.37024998665,0.6717402935028076,33.79730536923105,1.538611888885498,0.6788632273674011,29.745504916128827,1.4914655685424805,3000 -19105.83856487274,1.1398842334747314,29432.89954471588,84571,0,29432.89954471588,0.6945209503173828,29.945540079129746,1.4088683128356934,3003,48542.69511389732,0.665301501750946,33.228433223994266,1.583433747291565,0.6785780787467957,29.50485520503944,1.48270583152771,3000 -19636.73670172692,1.1791932582855225,30272.9449737072,86989,0,30272.9449737072,0.6953459978103638,29.84721352815368,1.4079169034957886,3003,49913.75622940064,0.6616896986961365,33.137917686474665,1.6091835498809814,0.6786152720451355,29.751823489337745,1.482520580291748,3000 -20182.18912935257,1.2226190567016602,31113.069878339767,89407,0,31113.069878339767,0.6946952939033508,29.631786634944643,1.3910845518112185,3003,51299.45757818222,0.6713466644287109,33.52182294055766,1.5552958250045776,0.6820374131202698,29.869141731518045,1.465645670890808,3000 -20672.136009454727,1.2626848220825195,31953.005130767822,91823,0,31953.005130767822,0.6980187296867371,30.353041929433832,1.381521701812744,3003,52629.45927906037,0.6688807606697083,33.646249972575525,1.561652421951294,0.6818886399269104,29.931857466564303,1.4689723253250122,3000 -21232.269668102264,1.301042079925537,32793.152134656906,94241,0,32793.152134656906,0.699785053730011,30.07975448373041,1.3779276609420776,3003,54029.85724711418,0.6905794739723206,34.98234446474538,1.4270117282867432,0.6850627660751343,30.200363953398387,1.4553813934326172,3000 -21778.425297021862,1.342395544052124,33633.102585315704,96658,0,33633.102585315704,0.6992853283882141,30.073980041317707,1.3757984638214111,3003,55416.08647465706,0.6737492084503174,34.108211883515416,1.5275613069534302,0.6844552159309387,30.20829723913238,1.4521570205688477,3000 -22323.09084415436,1.3831169605255127,34472.99043941498,99074,0,34472.99043941498,0.7000755667686462,30.03763911580841,1.370320200920105,3003,56800.761588811874,0.6748491525650024,33.783884560990835,1.5263543128967283,0.6855711340904236,30.292007534617788,1.4503666162490845,3000 -22862.66453719139,1.4255762100219729,35312.9445669651,101491,0,35312.9445669651,0.7018999457359314,30.280502251710864,1.3603651523590088,3003,58180.41181755066,0.6866604685783386,34.48494577502086,1.455691695213318,0.6869474649429321,30.187936515973128,1.4450162649154663,3000 -23525.2458164692,1.4649641513824463,36152.97557926178,103909,0,36152.97557926178,0.7025158405303955,30.232446000713,1.3520406484603882,3003,59683.14560890198,0.6815876960754395,34.64913971543467,1.492857575416565,0.6869598627090454,30.17169848257597,1.4377763271331787,3000 -24089.24018883705,1.504408836364746,36993.1280477047,106327,0,36993.1280477047,0.7030619978904724,30.41185965475121,1.3502763509750366,3003,61087.41273570061,0.6837038993835449,35.08743437169032,1.4776220321655271,0.688646137714386,30.522066373997887,1.4313939809799194,3000 -24620.91137957573,1.5444655418395996,37833.29610252381,108745,0,37833.29610252381,0.705420970916748,30.55323781090465,1.344992756843567,3003,62459.371757268906,0.6909144520759583,34.72590805081428,1.4369436502456665,0.6895388960838318,30.422796842264475,1.42962646484375,3000 -25158.474792718887,1.5841474533081057,38673.18712234497,111162,0,38673.18712234497,0.7066760063171387,30.61793987736942,1.337507247924805,3003,63836.94434714317,0.6862839460372925,35.03915497026832,1.4570157527923584,0.6900472044944763,30.63425878074417,1.4253438711166382,3000 -25712.46093392372,1.6317753791809082,39513.24196100235,113580,0,39513.24196100235,0.7059555053710938,30.445572866631903,1.3358170986175537,3003,65231.111525297165,0.703749418258667,36.01579664231886,1.3690860271453855,0.691473126411438,30.438168648771725,1.4214249849319458,3000 -26318.045152664185,1.672149419784546,40353.26636815071,115998,0,40353.26636815071,0.7085468769073486,30.682219921723902,1.3254882097244265,3003,66676.8403544426,0.6970692873001099,35.804610108981215,1.398681640625,0.691386342048645,30.566812091793675,1.418184757232666,3000 -26890.30124187469,1.715043306350708,41193.21762943268,118415,0,41193.21762943268,0.7084422707557678,30.9174266863902,1.3272607326507568,3003,68089.17468738556,0.6983740925788879,35.78991389340325,1.3953670263290403,0.6922046542167664,30.73623908706971,1.4161663055419922,3000 -27430.756391763687,1.8395373821258545,42033.18365478516,120833,0,42033.18365478516,0.7084422707557678,30.767148916451248,1.324387550354004,3003,69469.79971814156,0.7079231142997742,36.54008424706086,1.3452441692352295,0.69246506690979,30.88375035702032,1.4150506258010864,3000 -27963.00019860268,1.8919637203216555,42873.383053064346,123251,0,42873.383053064346,0.7096391916275024,30.74864159068372,1.320741891860962,3003,70842.37611746788,0.7070516347885132,36.243165574604014,1.346768140792847,0.6931470036506653,30.778888404449745,1.413427233695984,3000 -28548.430537700653,1.933983564376831,43713.294801950455,125668,0,43713.294801950455,0.709337055683136,30.857238354469462,1.3211370706558228,3003,72267.84021043777,0.710022509098053,36.40957170008483,1.332497239112854,0.6937421560287476,30.8876838523228,1.4104636907577517,3000 -29095.45155930519,1.9780395030975344,44553.37302994728,128086,0,44553.37302994728,0.7093021869659424,30.86260403130315,1.319244623184204,3003,73655.06421208382,0.7109796404838562,36.99412188193442,1.329494595527649,0.6932337880134583,30.99008929284226,1.4097765684127808,3000 -29643.51740694046,2.031567573547364,45393.514297008514,130505,0,45393.514297008514,0.710220217704773,30.82152023947141,1.3187133073806765,3003,75043.40402579308,0.7105883359909058,36.643996629084846,1.3296940326690674,0.6933826208114624,30.929830153579665,1.4100453853607178,3000 -30203.573014974598,2.0750794410705566,46233.72875595093,132924,0,46233.72875595093,0.7099761962890625,30.837535923196857,1.3190022706985474,3003,76443.79782891273,0.7117727398872375,36.79409669434138,1.324880599975586,0.6934942007064819,30.80455651715488,1.4099982976913452,3000 -30751.867042303085,2.1205337047576904,46375.332560777664,133333,0,46375.332560777664,0.7099761962890625,30.846443803364846,1.318996548652649,3003,77133.75478172302,0.7115277051925659,36.53925289250019,1.324217438697815,0.6935189962387085,30.825316571983876,1.40999174118042,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/measurements.csv deleted file mode 100644 index 8736e4dcb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.756458,10.947874,,,,,,,,,,,,,,,,, -1,,,0.0006362841231748,10.959882736206056,0.0,0.0004835649742744,10.980294227600098,0.0,3000.0,0.0007088489946909,10.966498374938965,0.0,3003.0,30.05414128303528,889.9320287704468,30.05414128303528,859.8778374195099,0.0,0.0 -100,0.43636623,8.727115,,,,,,,,,,,,,,,,, -200,0.18352073,8.299497,,,,,,,,,,,,,,,,, -300,0.20910476,8.015485,,,,,,,,,,,,,,,,, -400,0.28724572,7.6509776,,,,,,,,,,,,,,,,, -500,0.4089872,7.2564383,,,,,,,,,,,,,,,,, -600,0.6025768,6.973931,,,,,,,,,,,,,,,,, -700,0.6915962,6.7854657,,,,,,,,,,,,,,,,, -800,0.6633699,6.475345,,,,,,,,,,,,,,,,, -900,0.6709908,6.2056136,,,,,,,,,,,,,,,,, -1000,0.62693435,6.0277357,,,,,,,,,,,,,,,,, -1100,0.7626876,5.7627954,,,,,,,,,,,,,,,,, -1200,1.1794981,5.584896,,,,,,,,,,,,,,,,, -1300,0.68090546,5.473249,,,,,,,,,,,,,,,,, -1400,0.9197639,5.3056355,,,,,,,,,,,,,,,,, -1500,0.6875018,5.1224475,,,,,,,,,,,,,,,,, -1600,0.68983877,4.9264603,,,,,,,,,,,,,,,,, -1700,0.64467424,4.805288,,,,,,,,,,,,,,,,, -1800,1.5413042,4.7125883,,,,,,,,,,,,,,,,, -1900,1.6085511,4.555047,,,,,,,,,,,,,,,,, -2000,0.7795955,4.4959702,,,,,,,,,,,,,,,,, -2100,0.77823323,4.307182,,,,,,,,,,,,,,,,, -2200,1.0662092,4.246206,,,,,,,,,,,,,,,,, -2300,0.7970316,4.112397,,,,,,,,,,,,,,,,, -2400,1.213558,3.9406664,,,,,,,,,,,,,,,,, -2413,,,0.4128057062625885,3.936654329299927,14.08984575280066,0.3986683189868927,4.049123287200928,9.633888628769736,3000.0,0.3826157748699188,4.261987686157227,8.030870962674728,3003.0,870.1676576137543,2329.528389930725,870.1676576137543,1459.2620244026184,0.0198702812194824,0.0 -2500,0.8030683,3.9633765,,,,,,,,,,,,,,,,, -2600,0.83380944,3.6881893,,,,,,,,,,,,,,,,, -2700,0.8842217,3.6935318,,,,,,,,,,,,,,,,, -2800,1.7114344,3.6567,,,,,,,,,,,,,,,,, -2900,1.3451463,3.5154202,,,,,,,,,,,,,,,,, -3000,0.786722,3.4499195,,,,,,,,,,,,,,,,, -3100,0.9925693,3.3813097,,,,,,,,,,,,,,,,, -3200,0.76444,3.3407028,,,,,,,,,,,,,,,,, -3300,1.1004518,3.1798759,,,,,,,,,,,,,,,,, -3400,0.7188384,3.1770575,,,,,,,,,,,,,,,,, -3500,0.91128194,3.1030047,,,,,,,,,,,,,,,,, -3600,0.78607756,3.1351235,,,,,,,,,,,,,,,,, -3700,0.73459274,3.0406036,,,,,,,,,,,,,,,,, -3800,0.82364094,3.0017807,,,,,,,,,,,,,,,,, -3900,0.85181385,2.9508533,,,,,,,,,,,,,,,,, -4000,0.7886067,2.9601038,,,,,,,,,,,,,,,,, -4100,0.7123784,2.9083083,,,,,,,,,,,,,,,,, -4200,0.84984434,2.8579214,,,,,,,,,,,,,,,,, -4300,0.612369,2.7969873,,,,,,,,,,,,,,,,, -4400,0.6769874,2.773632,,,,,,,,,,,,,,,,, -4500,0.636431,2.7093916,,,,,,,,,,,,,,,,, -4600,0.7868551,2.7184498,,,,,,,,,,,,,,,,, -4700,0.67160803,2.670638,,,,,,,,,,,,,,,,, -4800,0.63801765,2.6466537,,,,,,,,,,,,,,,,, -4824,,,0.5432161092758179,2.668916940689087,24.829318029776022,0.5427582859992981,2.6340794563293457,20.365649432861616,3000.0,0.5445354580879211,2.6518774032592773,19.00498182839845,3003.0,1710.180258989334,3634.642117023468,1710.180258989334,1924.25225186348,0.0505769252777099,0.0 -4900,0.6036918,2.6756232,,,,,,,,,,,,,,,,, -5000,0.5803073,2.6454413,,,,,,,,,,,,,,,,, -5100,0.7208338,2.5836482,,,,,,,,,,,,,,,,, -5200,0.79410434,2.5740292,,,,,,,,,,,,,,,,, -5300,0.6058675,2.5199928,,,,,,,,,,,,,,,,, -5400,0.5757684,2.5371487,,,,,,,,,,,,,,,,, -5500,0.68131196,2.56964,,,,,,,,,,,,,,,,, -5600,0.6189619,2.5491607,,,,,,,,,,,,,,,,, -5700,0.6270749,2.5245414,,,,,,,,,,,,,,,,, -5800,0.5668511,2.480766,,,,,,,,,,,,,,,,, -5900,0.5172791,2.4777539,,,,,,,,,,,,,,,,, -6000,0.5676179,2.457432,,,,,,,,,,,,,,,,, -6100,0.59763527,2.4949567,,,,,,,,,,,,,,,,, -6200,0.5486406,2.4857867,,,,,,,,,,,,,,,,, -6300,0.61723197,2.3933885,,,,,,,,,,,,,,,,, -6400,0.49188346,2.4028609,,,,,,,,,,,,,,,,, -6500,0.5596532,2.4377384,,,,,,,,,,,,,,,,, -6600,0.5177502,2.4708748,,,,,,,,,,,,,,,,, -6700,0.5402101,2.3917513,,,,,,,,,,,,,,,,, -6800,0.4859856,2.3642066,,,,,,,,,,,,,,,,, -6900,0.6293431,2.4267135,,,,,,,,,,,,,,,,, -7000,0.4699507,2.3581312,,,,,,,,,,,,,,,,, -7100,0.5290806,2.4119189,,,,,,,,,,,,,,,,, -7200,0.47245526,2.308573,,,,,,,,,,,,,,,,, -7235,,,0.581144392490387,2.2662315368652344,27.389879639296822,0.5851012468338013,2.2258543968200684,23.189117839268576,3000.0,0.5854976773262024,2.21711802482605,21.73489214068676,3003.0,2550.388504981994,4945.470089673996,2550.388504981994,2394.763978004456,0.0766811370849609,0.0 -7300,0.50309336,2.27569,,,,,,,,,,,,,,,,, -7400,0.45214587,2.2318342,,,,,,,,,,,,,,,,, -7500,0.46112418,2.2554953,,,,,,,,,,,,,,,,, -7600,0.43591976,2.2207444,,,,,,,,,,,,,,,,, -7700,0.51043254,2.2665415,,,,,,,,,,,,,,,,, -7800,0.48728916,2.2979982,,,,,,,,,,,,,,,,, -7900,0.57458013,2.3686585,,,,,,,,,,,,,,,,, -8000,0.452347,2.2264504,,,,,,,,,,,,,,,,, -8100,0.45448506,2.2584958,,,,,,,,,,,,,,,,, -8200,0.4144092,2.2509732,,,,,,,,,,,,,,,,, -8300,0.47367388,2.2187674,,,,,,,,,,,,,,,,, -8400,0.49606535,2.237892,,,,,,,,,,,,,,,,, -8500,0.4405324,2.0985243,,,,,,,,,,,,,,,,, -8600,0.39641932,2.1231468,,,,,,,,,,,,,,,,, -8700,0.3956742,2.2158432,,,,,,,,,,,,,,,,, -8800,0.43835786,2.1237905,,,,,,,,,,,,,,,,, -8900,0.4293,2.1488478,,,,,,,,,,,,,,,,, -9000,0.40179038,2.2753465,,,,,,,,,,,,,,,,, -9100,0.410133,2.2065792,,,,,,,,,,,,,,,,, -9200,0.42548385,2.2803986,,,,,,,,,,,,,,,,, -9300,0.35272363,2.1585305,,,,,,,,,,,,,,,,, -9400,0.36678112,2.2246728,,,,,,,,,,,,,,,,, -9500,0.37967265,2.1129224,,,,,,,,,,,,,,,,, -9600,0.3955888,2.1296103,,,,,,,,,,,,,,,,, -9648,,,0.592995285987854,2.1528778076171875,28.40913522209264,0.6077048182487488,2.034355640411377,24.757049137021426,3000.0,0.6137005686759949,1.992855429649353,23.58660872714324,3003.0,3390.442659378052,6223.30933713913,3390.442659378052,2832.442953109741,0.1034173965454101,0.0 -9700,0.37233928,2.0608225,,,,,,,,,,,,,,,,, -9800,0.35335416,2.133482,,,,,,,,,,,,,,,,, -9900,0.34511894,2.1142335,,,,,,,,,,,,,,,,, -10000,0.3656415,2.0809443,,,,,,,,,,,,,,,,, -10100,0.35389182,2.2102668,,,,,,,,,,,,,,,,, -10200,0.38722706,2.106284,,,,,,,,,,,,,,,,, -10300,0.36918214,2.1033578,,,,,,,,,,,,,,,,, -10400,0.3390132,2.105094,,,,,,,,,,,,,,,,, -10500,0.3166391,2.2113779,,,,,,,,,,,,,,,,, -10600,0.3370047,2.118498,,,,,,,,,,,,,,,,, -10700,0.32650244,2.168459,,,,,,,,,,,,,,,,, -10800,0.31991065,2.1232657,,,,,,,,,,,,,,,,, -10900,0.29911244,2.0882978,,,,,,,,,,,,,,,,, -11000,0.37584507,2.2188427,,,,,,,,,,,,,,,,, -11100,0.30846062,2.018473,,,,,,,,,,,,,,,,, -11200,0.347638,2.0270197,,,,,,,,,,,,,,,,, -11300,0.3080035,2.0973616,,,,,,,,,,,,,,,,, -11400,0.32644275,2.0631964,,,,,,,,,,,,,,,,, -11500,0.2956233,2.0889156,,,,,,,,,,,,,,,,, -11600,0.31514058,2.0494642,,,,,,,,,,,,,,,,, -11700,0.30618244,2.1080642,,,,,,,,,,,,,,,,, -11800,0.34339607,1.9583933,,,,,,,,,,,,,,,,, -11900,0.3054332,2.0731642,,,,,,,,,,,,,,,,, -12000,0.381553,2.1042151,,,,,,,,,,,,,,,,, -12063,,,0.6006006598472595,2.061574697494507,28.907384311822263,0.6192111372947693,1.9119383096694944,25.72642706038576,3000.0,0.6264017224311829,1.8653494119644165,24.72895569223176,3003.0,4230.618235588074,7560.686071395874,4230.618235588074,3329.537977933884,0.1310830116271972,0.0 -12100,0.34044272,1.9991934,,,,,,,,,,,,,,,,, -12200,0.2963797,2.0632143,,,,,,,,,,,,,,,,, -12300,0.31286624,1.9641378,,,,,,,,,,,,,,,,, -12400,0.30594042,2.0134318,,,,,,,,,,,,,,,,, -12500,0.29679635,2.121949,,,,,,,,,,,,,,,,, -12600,0.2926285,2.0664928,,,,,,,,,,,,,,,,, -12700,0.29704568,2.1068556,,,,,,,,,,,,,,,,, -12800,0.2866511,2.058662,,,,,,,,,,,,,,,,, -12900,0.2918404,1.9567344,,,,,,,,,,,,,,,,, -13000,0.28142595,1.978296,,,,,,,,,,,,,,,,, -13100,0.2969334,2.0181437,,,,,,,,,,,,,,,,, -13200,0.28428334,2.0217235,,,,,,,,,,,,,,,,, -13300,0.27539223,1.9844458,,,,,,,,,,,,,,,,, -13400,0.27641243,2.0443876,,,,,,,,,,,,,,,,, -13500,0.31941527,1.989375,,,,,,,,,,,,,,,,, -13600,0.31219444,1.9872304,,,,,,,,,,,,,,,,, -13700,0.290167,2.011687,,,,,,,,,,,,,,,,, -13800,0.2841542,1.9269207,,,,,,,,,,,,,,,,, -13900,0.30715686,2.0886662,,,,,,,,,,,,,,,,, -14000,0.30153257,1.9851743,,,,,,,,,,,,,,,,, -14100,0.2717624,1.9981072,,,,,,,,,,,,,,,,, -14200,0.26448044,1.9950572,,,,,,,,,,,,,,,,, -14300,0.27503386,1.9168638,,,,,,,,,,,,,,,,, -14400,0.25412396,1.9876926,,,,,,,,,,,,,,,,, -14479,,,0.6118413209915161,1.9555613994598389,29.58988159437644,0.6300107836723328,1.8212212324142456,26.35269237560661,3000.0,0.6380454301834106,1.7743580341339111,25.639332507671263,3003.0,5070.55059671402,8894.039239883423,5070.55059671402,3822.848229885101,0.1623303890228271,0.0 -14500,0.25780416,1.9110975,,,,,,,,,,,,,,,,, -14600,0.26916933,1.9867945,,,,,,,,,,,,,,,,, -14700,0.27709794,2.0205114,,,,,,,,,,,,,,,,, -14800,0.31872112,1.9654862,,,,,,,,,,,,,,,,, -14900,0.30013818,2.0038908,,,,,,,,,,,,,,,,, -15000,0.2996425,1.9415168,,,,,,,,,,,,,,,,, -15100,0.2864839,2.0189123,,,,,,,,,,,,,,,,, -15200,0.2707539,2.031745,,,,,,,,,,,,,,,,, -15300,0.25675374,1.967555,,,,,,,,,,,,,,,,, -15400,0.27464542,1.9098225,,,,,,,,,,,,,,,,, -15500,0.3161467,1.8900592,,,,,,,,,,,,,,,,, -15600,0.333892,1.9842719,,,,,,,,,,,,,,,,, -15700,0.34250596,2.0031755,,,,,,,,,,,,,,,,, -15800,0.38718784,1.9236224,,,,,,,,,,,,,,,,, -15900,0.28715256,1.9661844,,,,,,,,,,,,,,,,, -16000,0.3298286,1.8309857,,,,,,,,,,,,,,,,, -16100,0.30356112,1.9651679,,,,,,,,,,,,,,,,, -16200,0.29453987,1.8650944,,,,,,,,,,,,,,,,, -16300,0.32306877,1.9032066,,,,,,,,,,,,,,,,, -16400,0.4028019,1.8703796,,,,,,,,,,,,,,,,, -16500,0.36409023,1.9951007,,,,,,,,,,,,,,,,, -16600,0.3592236,1.8265342,,,,,,,,,,,,,,,,, -16700,0.31478465,1.9740531,,,,,,,,,,,,,,,,, -16800,0.3094116,1.8758371,,,,,,,,,,,,,,,,, -16895,,,0.6211254596710205,1.907320499420166,29.799160596214755,0.6405500173568726,1.7542517185211182,26.86963051346412,3000.0,0.6462843418121338,1.706995964050293,26.05308990904883,3003.0,5910.659248828888,10218.438910245895,5910.659248828888,4307.02596282959,0.1945400238037109,0.0 -16900,0.2872787,1.966302,,,,,,,,,,,,,,,,, -17000,0.3258389,2.0405068,,,,,,,,,,,,,,,,, -17100,0.3177424,1.9596417,,,,,,,,,,,,,,,,, -17200,0.33709642,1.8647053,,,,,,,,,,,,,,,,, -17300,0.36171436,1.9409844,,,,,,,,,,,,,,,,, -17400,0.39216897,1.8742162,,,,,,,,,,,,,,,,, -17500,0.3197272,1.9138392,,,,,,,,,,,,,,,,, -17600,0.33067283,1.9415617,,,,,,,,,,,,,,,,, -17700,0.33428884,1.861137,,,,,,,,,,,,,,,,, -17800,0.2860216,1.8987138,,,,,,,,,,,,,,,,, -17900,0.3581389,1.866491,,,,,,,,,,,,,,,,, -18000,0.3089606,1.8958707,,,,,,,,,,,,,,,,, -18100,0.3168437,1.8288538,,,,,,,,,,,,,,,,, -18200,0.3412254,1.9987047,,,,,,,,,,,,,,,,, -18300,0.45081052,1.8270876,,,,,,,,,,,,,,,,, -18400,0.29852808,1.8917474,,,,,,,,,,,,,,,,, -18500,0.32856432,1.9030575,,,,,,,,,,,,,,,,, -18600,0.41042718,4.032799,,,,,,,,,,,,,,,,, -18700,0.2858307,3.8838897,,,,,,,,,,,,,,,,, -18800,1.2103266,2.6138408,,,,,,,,,,,,,,,,, -18900,0.35085964,1.8682652,,,,,,,,,,,,,,,,, -19000,0.3348012,1.837214,,,,,,,,,,,,,,,,, -19100,0.35584015,1.8933173,,,,,,,,,,,,,,,,, -19200,0.29665643,1.8588358,,,,,,,,,,,,,,,,, -19300,0.28783193,1.7837056,,,,,,,,,,,,,,,,, -19311,,,0.6253435015678406,1.8560582399368288,30.11935736955452,0.6459808349609375,1.716254472732544,27.16554742317671,3000.0,0.6546743512153625,1.6618460416793823,26.37794400092443,3003.0,6750.589684724808,11562.04913663864,6750.589684724808,4810.599547386169,0.2221553325653076,0.0 -19400,0.33468327,1.911655,,,,,,,,,,,,,,,,, -19500,0.32171866,1.784027,,,,,,,,,,,,,,,,, -19600,0.33651388,1.9717908,,,,,,,,,,,,,,,,, -19700,0.3716232,1.864103,,,,,,,,,,,,,,,,, -19800,0.38298255,1.8655208,,,,,,,,,,,,,,,,, -19900,0.32635903,1.9222088,,,,,,,,,,,,,,,,, -20000,0.34260812,1.8685435,,,,,,,,,,,,,,,,, -20100,0.3047376,1.7794496,,,,,,,,,,,,,,,,, -20200,0.35183275,1.8154128,,,,,,,,,,,,,,,,, -20300,0.34703177,1.8951555,,,,,,,,,,,,,,,,, -20400,0.33116972,1.9359834,,,,,,,,,,,,,,,,, -20500,0.33812457,1.843293,,,,,,,,,,,,,,,,, -20600,0.34626815,1.7527982,,,,,,,,,,,,,,,,, -20700,0.35454834,1.7819476,,,,,,,,,,,,,,,,, -20800,0.36979496,1.8761197,,,,,,,,,,,,,,,,, -20900,0.33791476,1.8327365,,,,,,,,,,,,,,,,, -21000,0.38998008,1.8381598,,,,,,,,,,,,,,,,, -21100,0.33063594,1.8747349,,,,,,,,,,,,,,,,, -21200,0.34954718,1.8960136,,,,,,,,,,,,,,,,, -21300,0.39778945,1.8947809,,,,,,,,,,,,,,,,, -21400,1.0131671,1.9249744,,,,,,,,,,,,,,,,, -21500,0.5885605,1.8855597,,,,,,,,,,,,,,,,, -21600,0.3833321,1.7793936,,,,,,,,,,,,,,,,, -21700,0.37058353,1.7550206,,,,,,,,,,,,,,,,, -21727,,,0.625741720199585,1.8575769662857056,30.296728445860584,0.6483986377716064,1.690137267112732,27.708869311230742,3000.0,0.6591482162475586,1.6333097219467163,27.06400431392655,3003.0,7590.577194213867,12868.004507541656,7590.577194213867,5276.457659244537,0.2528328895568847,0.0 -21800,0.36485797,1.8378031,,,,,,,,,,,,,,,,, -21900,0.39090306,1.9035347,,,,,,,,,,,,,,,,, -22000,0.358029,1.984748,,,,,,,,,,,,,,,,, -22100,0.4271569,1.9250764,,,,,,,,,,,,,,,,, -22200,0.38248911,1.9247978,,,,,,,,,,,,,,,,, -22300,0.33611137,1.8853132,,,,,,,,,,,,,,,,, -22400,0.4475338,1.9180899,,,,,,,,,,,,,,,,, -22500,0.36689445,1.8727468,,,,,,,,,,,,,,,,, -22600,0.39100385,1.8798683,,,,,,,,,,,,,,,,, -22700,0.3843394,1.8202075,,,,,,,,,,,,,,,,, -22800,0.38033983,1.806332,,,,,,,,,,,,,,,,, -22900,0.38878107,1.8701472,,,,,,,,,,,,,,,,, -23000,0.38192737,1.8357866,,,,,,,,,,,,,,,,, -23100,0.34067658,1.823498,,,,,,,,,,,,,,,,, -23200,0.4325724,1.8746561,,,,,,,,,,,,,,,,, -23300,0.41604158,1.8440262,,,,,,,,,,,,,,,,, -23400,0.36620438,1.9066974,,,,,,,,,,,,,,,,, -23500,0.4562884,1.8114655,,,,,,,,,,,,,,,,, -23600,0.36651108,1.9009682,,,,,,,,,,,,,,,,, -23700,0.44310567,1.8067969,,,,,,,,,,,,,,,,, -23800,0.34332237,1.8177568,,,,,,,,,,,,,,,,, -23900,0.43383402,1.8663751,,,,,,,,,,,,,,,,, -24000,0.49332693,1.8422273,,,,,,,,,,,,,,,,, -24100,0.33964494,1.8392974,,,,,,,,,,,,,,,,, -24142,,,0.6286821961402893,1.842318058013916,30.94324298130113,0.6500725150108337,1.673073410987854,27.84176105484295,3000.0,0.6639590859413147,1.604680418968201,27.245981275229024,3003.0,8430.762261390686,14241.60857963562,8430.762261390686,5809.7654638290405,0.2813522815704345,0.0 -24200,0.39950642,1.8301275,,,,,,,,,,,,,,,,, -24300,0.36036476,1.8460798,,,,,,,,,,,,,,,,, -24400,0.4542525,1.7658936,,,,,,,,,,,,,,,,, -24500,0.34433678,1.767894,,,,,,,,,,,,,,,,, -24600,0.35135418,1.8727714,,,,,,,,,,,,,,,,, -24700,0.36990002,1.7850046,,,,,,,,,,,,,,,,, -24800,0.36163568,1.8011118,,,,,,,,,,,,,,,,, -24900,0.38599867,1.8775089,,,,,,,,,,,,,,,,, -25000,0.408278,1.7461498,,,,,,,,,,,,,,,,, -25100,0.3654474,1.8612301,,,,,,,,,,,,,,,,, -25200,0.4704523,1.8970599,,,,,,,,,,,,,,,,, -25300,0.3890321,1.7853118,,,,,,,,,,,,,,,,, -25400,0.45016465,1.7829658,,,,,,,,,,,,,,,,, -25500,0.37816852,1.7560748,,,,,,,,,,,,,,,,, -25600,0.40765965,1.800653,,,,,,,,,,,,,,,,, -25700,0.44719663,1.8039684,,,,,,,,,,,,,,,,, -25800,0.37903753,1.8782586,,,,,,,,,,,,,,,,, -25900,0.35671523,1.7491817,,,,,,,,,,,,,,,,, -26000,0.3731693,1.8242456,,,,,,,,,,,,,,,,, -26100,0.44282225,1.851949,,,,,,,,,,,,,,,,, -26200,0.42010447,1.864122,,,,,,,,,,,,,,,,, -26300,0.4129029,1.8212585,,,,,,,,,,,,,,,,, -26400,0.4256567,1.8124363,,,,,,,,,,,,,,,,, -26500,0.35014328,1.8245822,,,,,,,,,,,,,,,,, -26559,,,0.6353502869606018,1.772486925125122,30.71224749308285,0.6529862880706787,1.6593254804611206,27.82423233567216,3000.0,0.6608215570449829,1.6026358604431152,26.744297453024025,3003.0,9270.876539945602,15620.536381721497,9270.876539945602,6348.470281600952,0.3095395565032959,0.0 -26600,0.424572,1.834527,,,,,,,,,,,,,,,,, -26700,0.39803004,1.7600865,,,,,,,,,,,,,,,,, -26800,0.41825548,1.8194704,,,,,,,,,,,,,,,,, -26900,0.3988726,1.8134179,,,,,,,,,,,,,,,,, -27000,0.3852093,1.8313831,,,,,,,,,,,,,,,,, -27100,0.4685126,1.7775006,,,,,,,,,,,,,,,,, -27200,0.38505676,1.8411644,,,,,,,,,,,,,,,,, -27300,0.36595675,1.847104,,,,,,,,,,,,,,,,, -27400,0.37853524,1.7868712,,,,,,,,,,,,,,,,, -27500,0.48186648,1.8888437,,,,,,,,,,,,,,,,, -27600,0.44124043,1.8388081,,,,,,,,,,,,,,,,, -27700,0.39973623,1.8831346,,,,,,,,,,,,,,,,, -27800,0.44935176,1.8015345,,,,,,,,,,,,,,,,, -27900,0.42359135,1.8300576,,,,,,,,,,,,,,,,, -28000,0.45502207,1.85556,,,,,,,,,,,,,,,,, -28100,0.40618637,1.8542119,,,,,,,,,,,,,,,,, -28200,0.6522288,1.7638569,,,,,,,,,,,,,,,,, -28300,0.81101364,1.7387097,,,,,,,,,,,,,,,,, -28400,0.44106042,1.8088142,,,,,,,,,,,,,,,,, -28500,0.42377663,1.8421308,,,,,,,,,,,,,,,,, -28600,0.42511854,1.720953,,,,,,,,,,,,,,,,, -28700,0.4747582,1.8450164,,,,,,,,,,,,,,,,, -28800,0.3955662,1.7973055,,,,,,,,,,,,,,,,, -28900,0.38031647,1.8558084,,,,,,,,,,,,,,,,, -28976,,,0.6322528719902039,1.8003666400909424,30.72612753003488,0.6520811915397644,1.648480772972107,27.845179168567565,3000.0,0.6632967591285706,1.58885395526886,27.30567537970401,3003.0,10111.109763383864,16919.401398181915,10111.109763383864,6806.992643594742,0.3389849662780761,0.0 -29000,0.45433196,1.841863,,,,,,,,,,,,,,,,, -29100,0.4369319,1.7985195,,,,,,,,,,,,,,,,, -29200,0.45084614,1.8302685,,,,,,,,,,,,,,,,, -29300,0.44941178,1.8488618,,,,,,,,,,,,,,,,, -29400,0.3411823,1.8148159,,,,,,,,,,,,,,,,, -29500,0.39261085,1.6996816,,,,,,,,,,,,,,,,, -29600,0.36912352,1.7403728,,,,,,,,,,,,,,,,, -29700,0.4743532,1.8701794,,,,,,,,,,,,,,,,, -29800,0.438239,1.8184597,,,,,,,,,,,,,,,,, -29900,0.39859784,1.8080513,,,,,,,,,,,,,,,,, -30000,0.4706808,1.7973689,,,,,,,,,,,,,,,,, -30100,0.38128015,1.808129,,,,,,,,,,,,,,,,, -30200,0.3660248,1.8009717,,,,,,,,,,,,,,,,, -30300,0.36461163,1.7661134,,,,,,,,,,,,,,,,, -30400,0.39365634,1.7733554,,,,,,,,,,,,,,,,, -30500,0.43332735,1.7821978,,,,,,,,,,,,,,,,, -30600,0.40211013,1.8225324,,,,,,,,,,,,,,,,, -30700,0.43969128,1.7872459,,,,,,,,,,,,,,,,, -30800,0.43068087,1.7521552,,,,,,,,,,,,,,,,, -30900,0.37873402,1.7570082,,,,,,,,,,,,,,,,, -31000,0.35618117,1.7290779,,,,,,,,,,,,,,,,, -31100,0.39097816,1.7822815,,,,,,,,,,,,,,,,, -31200,0.3912505,1.8340088,,,,,,,,,,,,,,,,, -31300,0.41868588,1.8203565,,,,,,,,,,,,,,,,, -31392,,,0.6668255925178528,1.5802048444747925,33.53180261019265,0.6552553772926331,1.6375457048416138,27.770932078759643,3000.0,0.6667131781578064,1.5764648914337158,27.34374131007221,3003.0,10951.089903831482,18321.97534775734,10951.089903831482,7369.476313352585,0.3675158023834228,0.0 -31400,0.75157493,1.846995,,,,,,,,,,,,,,,,, -31500,0.41621926,1.8220435,,,,,,,,,,,,,,,,, -31600,0.38031837,1.7620564,,,,,,,,,,,,,,,,, -31700,0.47540498,1.8659791,,,,,,,,,,,,,,,,, -31800,0.42060897,1.8528202,,,,,,,,,,,,,,,,, -31900,0.4557507,1.8502603,,,,,,,,,,,,,,,,, -32000,0.42857882,1.81847,,,,,,,,,,,,,,,,, -32100,0.39797747,1.8320224,,,,,,,,,,,,,,,,, -32200,0.41213852,1.7997409,,,,,,,,,,,,,,,,, -32300,0.3973317,1.720233,,,,,,,,,,,,,,,,, -32400,0.43631566,1.8327725,,,,,,,,,,,,,,,,, -32500,0.39236438,1.6559113,,,,,,,,,,,,,,,,, -32600,0.4724618,1.8066139,,,,,,,,,,,,,,,,, -32700,0.5333962,1.8396343,,,,,,,,,,,,,,,,, -32800,0.412926,1.9068002,,,,,,,,,,,,,,,,, -32900,0.39186925,1.7634046,,,,,,,,,,,,,,,,, -33000,0.4005247,1.7242886,,,,,,,,,,,,,,,,, -33100,0.42198417,1.8067359,,,,,,,,,,,,,,,,, -33200,0.39558756,1.8064328,,,,,,,,,,,,,,,,, -33300,0.5686467,1.6942157,,,,,,,,,,,,,,,,, -33400,0.4271921,1.8229828,,,,,,,,,,,,,,,,, -33500,0.3843906,1.7702527,,,,,,,,,,,,,,,,, -33600,0.41849345,1.8530866,,,,,,,,,,,,,,,,, -33700,0.43743205,1.8497405,,,,,,,,,,,,,,,,, -33800,0.45447466,1.7865788,,,,,,,,,,,,,,,,, -33809,,,0.6364154815673828,1.7697381973266602,30.883054872066605,0.6572515964508057,1.6273834705352783,28.07798566784241,3000.0,0.665690541267395,1.5672624111175537,27.31780751025964,3003.0,11791.302706718445,19685.88606619835,11791.302706718445,7893.05903172493,0.4018421173095703,0.0 -33900,0.5033249,1.8230643,,,,,,,,,,,,,,,,, -34000,0.4077298,1.8043802,,,,,,,,,,,,,,,,, -34100,0.45658627,1.853675,,,,,,,,,,,,,,,,, -34200,0.38710192,1.795832,,,,,,,,,,,,,,,,, -34300,0.39776197,1.7991326,,,,,,,,,,,,,,,,, -34400,0.34751362,1.7383554,,,,,,,,,,,,,,,,, -34500,0.4014955,1.7677711,,,,,,,,,,,,,,,,, -34600,0.39299977,1.7926937,,,,,,,,,,,,,,,,, -34700,0.42545325,1.7944499,,,,,,,,,,,,,,,,, -34800,0.5138201,1.673139,,,,,,,,,,,,,,,,, -34900,0.4113244,1.776984,,,,,,,,,,,,,,,,, -35000,0.6521703,1.8608055,,,,,,,,,,,,,,,,, -35100,0.443859,1.8041694,,,,,,,,,,,,,,,,, -35200,0.42451993,1.7911228,,,,,,,,,,,,,,,,, -35300,0.3843685,1.7111405,,,,,,,,,,,,,,,,, -35400,0.42564082,1.8017817,,,,,,,,,,,,,,,,, -35500,0.38335854,1.7939665,,,,,,,,,,,,,,,,, -35600,0.3535315,1.7642993,,,,,,,,,,,,,,,,, -35700,0.34717244,1.7049607,,,,,,,,,,,,,,,,, -35800,0.3963304,1.812757,,,,,,,,,,,,,,,,, -35900,0.48580304,1.8312232,,,,,,,,,,,,,,,,, -36000,0.43381774,1.7321222,,,,,,,,,,,,,,,,, -36100,0.37893358,1.7473981,,,,,,,,,,,,,,,,, -36200,0.36653832,1.7750489,,,,,,,,,,,,,,,,, -36227,,,0.6372659206390381,1.7720032930374146,31.08953855694967,0.6579459309577942,1.616313338279724,28.23293001497443,3000.0,0.6670966148376465,1.5575518608093262,27.478248748542704,3003.0,12631.28459262848,21011.22130537033,12631.28459262848,8378.304702997208,0.4316813945770263,0.0 -36300,0.3558632,1.7403206,,,,,,,,,,,,,,,,, -36400,0.37797767,1.7425711,,,,,,,,,,,,,,,,, -36500,0.39867595,1.7279602,,,,,,,,,,,,,,,,, -36600,0.36368194,1.7198222,,,,,,,,,,,,,,,,, -36700,0.4194442,1.7114155,,,,,,,,,,,,,,,,, -36800,0.40645117,1.7850266,,,,,,,,,,,,,,,,, -36900,0.37370402,1.8415077,,,,,,,,,,,,,,,,, -37000,2.3056853,1.9459096,,,,,,,,,,,,,,,,, -37100,0.42934325,1.8236027,,,,,,,,,,,,,,,,, -37200,0.43704295,1.7655913,,,,,,,,,,,,,,,,, -37300,0.37769863,1.8348688,,,,,,,,,,,,,,,,, -37400,0.44153035,1.7569914,,,,,,,,,,,,,,,,, -37500,0.41702846,1.7681123,,,,,,,,,,,,,,,,, -37600,0.38224083,1.7911489,,,,,,,,,,,,,,,,, -37700,0.4372152,1.707534,,,,,,,,,,,,,,,,, -37800,0.6052072,1.7050569,,,,,,,,,,,,,,,,, -37900,0.39585304,1.8585073,,,,,,,,,,,,,,,,, -38000,0.40922803,1.757727,,,,,,,,,,,,,,,,, -38100,0.4532626,1.7882911,,,,,,,,,,,,,,,,, -38200,0.3624096,1.6969384,,,,,,,,,,,,,,,,, -38300,0.3970987,1.7148519,,,,,,,,,,,,,,,,, -38400,0.3442416,1.7852362,,,,,,,,,,,,,,,,, -38500,0.40456384,1.7578568,,,,,,,,,,,,,,,,, -38600,0.3291412,1.7090926,,,,,,,,,,,,,,,,, -38643,,,0.6398760676383972,1.7399176359176636,31.15971808188359,0.6560736894607544,1.6214781999588013,28.02467696226403,3000.0,0.6660391688346863,1.5656033754348757,27.41756213472292,3003.0,13471.477101564407,22314.383660316467,13471.477101564407,8841.15231704712,0.4688034057617187,0.0 -38700,0.3634928,1.7294426,,,,,,,,,,,,,,,,, -38800,0.37376064,1.8037955,,,,,,,,,,,,,,,,, -38900,0.3882324,1.7499657,,,,,,,,,,,,,,,,, -39000,0.42601946,1.7956023,,,,,,,,,,,,,,,,, -39100,0.46230632,1.7610676,,,,,,,,,,,,,,,,, -39200,0.43526447,1.8105626,,,,,,,,,,,,,,,,, -39300,0.37198538,1.7123978,,,,,,,,,,,,,,,,, -39400,0.40293112,1.7536585,,,,,,,,,,,,,,,,, -39500,0.40954116,1.7808758,,,,,,,,,,,,,,,,, -39600,0.37240157,1.7538898,,,,,,,,,,,,,,,,, -39700,0.43213233,1.8278891,,,,,,,,,,,,,,,,, -39800,0.3700554,1.8154769,,,,,,,,,,,,,,,,, -39900,0.3754393,1.7867011,,,,,,,,,,,,,,,,, -40000,0.47054783,1.7429254,,,,,,,,,,,,,,,,, -40100,0.37453893,1.7313733,,,,,,,,,,,,,,,,, -40200,0.39432928,1.8092198,,,,,,,,,,,,,,,,, -40300,0.37702304,1.7213209,,,,,,,,,,,,,,,,, -40400,0.4029062,1.7799737,,,,,,,,,,,,,,,,, -40500,0.42666066,1.80999,,,,,,,,,,,,,,,,, -40600,0.36361083,1.7462546,,,,,,,,,,,,,,,,, -40700,0.37753117,1.7538611,,,,,,,,,,,,,,,,, -40800,0.36388236,1.7775021,,,,,,,,,,,,,,,,, -40900,0.3936459,1.8427522,,,,,,,,,,,,,,,,, -41000,0.38179767,1.7535518,,,,,,,,,,,,,,,,, -41061,,,0.6372715830802917,1.7691075801849363,30.75969332973245,0.6587767004966736,1.6140390634536743,28.07393183728591,3000.0,0.672163188457489,1.5471595525741575,27.999417482086034,3003.0,14311.619084835052,23735.61979198456,14311.619084835052,9422.134581565855,0.5005180835723877,0.0 -41100,0.42847806,1.7499077,,,,,,,,,,,,,,,,, -41200,0.47387692,1.8075155,,,,,,,,,,,,,,,,, -41300,0.39158344,1.8085283,,,,,,,,,,,,,,,,, -41400,0.3864482,1.7836343,,,,,,,,,,,,,,,,, -41500,0.4397008,1.7121184,,,,,,,,,,,,,,,,, -41600,0.43116724,1.8693967,,,,,,,,,,,,,,,,, -41700,0.4231888,1.7591902,,,,,,,,,,,,,,,,, -41800,0.40052104,1.787631,,,,,,,,,,,,,,,,, -41900,0.37801692,1.7429438,,,,,,,,,,,,,,,,, -42000,0.43086156,1.7522932,,,,,,,,,,,,,,,,, -42100,0.37827644,1.7110348,,,,,,,,,,,,,,,,, -42200,0.36340988,1.7223532,,,,,,,,,,,,,,,,, -42300,0.3708384,1.7127213,,,,,,,,,,,,,,,,, -42400,0.38579422,1.7324744,,,,,,,,,,,,,,,,, -42500,0.38865688,1.7595605,,,,,,,,,,,,,,,,, -42600,0.3834774,1.783012,,,,,,,,,,,,,,,,, -42700,0.41494858,1.8012972,,,,,,,,,,,,,,,,, -42800,0.47553864,1.7031914,,,,,,,,,,,,,,,,, -42900,0.37202454,1.7527642,,,,,,,,,,,,,,,,, -43000,0.40582937,1.7291907,,,,,,,,,,,,,,,,, -43100,0.40921193,1.7628108,,,,,,,,,,,,,,,,, -43200,0.41846403,1.7499058,,,,,,,,,,,,,,,,, -43300,0.42773554,1.8435456,,,,,,,,,,,,,,,,, -43400,0.38352934,1.7411954,,,,,,,,,,,,,,,,, -43479,,,0.6386000514030457,1.760711908340454,31.427426086467367,0.6623228192329407,1.598664164543152,28.480353182921547,3000.0,0.6713265180587769,1.540757417678833,27.56919141028713,3003.0,15151.775276899338,25052.546183347706,15151.775276899338,9898.793123483658,0.5321898460388184,0.0 -43500,0.41084987,1.6965392,,,,,,,,,,,,,,,,, -43600,0.3923469,1.7551256,,,,,,,,,,,,,,,,, -43700,0.417484,1.7762573,,,,,,,,,,,,,,,,, -43800,0.4327211,1.8595357,,,,,,,,,,,,,,,,, -43900,0.40039414,1.7569623,,,,,,,,,,,,,,,,, -44000,0.3711289,1.7593725,,,,,,,,,,,,,,,,, -44100,0.4069036,1.8016263,,,,,,,,,,,,,,,,, -44200,0.43990824,1.7728387,,,,,,,,,,,,,,,,, -44300,0.4075986,1.7963479,,,,,,,,,,,,,,,,, -44400,0.37611464,1.6707319,,,,,,,,,,,,,,,,, -44500,0.4225802,1.7693081,,,,,,,,,,,,,,,,, -44600,0.41024983,1.661158,,,,,,,,,,,,,,,,, -44700,0.4203121,1.7734057,,,,,,,,,,,,,,,,, -44800,0.3747604,1.7420794,,,,,,,,,,,,,,,,, -44900,0.39022762,1.8795246,,,,,,,,,,,,,,,,, -45000,0.40206578,1.7368976,,,,,,,,,,,,,,,,, -45100,0.4286717,1.7360005,,,,,,,,,,,,,,,,, -45200,0.37600794,1.6994845,,,,,,,,,,,,,,,,, -45300,0.3922843,1.720565,,,,,,,,,,,,,,,,, -45400,0.38872722,1.8159974,,,,,,,,,,,,,,,,, -45500,0.4087501,1.7810053,,,,,,,,,,,,,,,,, -45600,0.4549594,1.7108753,,,,,,,,,,,,,,,,, -45700,0.43320772,1.7110915,,,,,,,,,,,,,,,,, -45800,0.46653065,1.6472431,,,,,,,,,,,,,,,,, -45897,,,0.6438037753105164,1.7203915119171145,31.366122874842,0.6614301204681396,1.597715973854065,28.56949921427732,3000.0,0.6712567806243896,1.539738416671753,27.98019850766508,3003.0,15991.93740272522,26433.4802134037,15991.93740272522,10439.45487523079,0.5634627342224121,0.0 -45900,0.37695795,1.7478067,,,,,,,,,,,,,,,,, -46000,0.39300206,1.8044599,,,,,,,,,,,,,,,,, -46100,0.41516122,1.7048191,,,,,,,,,,,,,,,,, -46200,0.42431885,1.6919706,,,,,,,,,,,,,,,,, -46300,0.36012042,1.7885382,,,,,,,,,,,,,,,,, -46400,0.39610922,1.7219911,,,,,,,,,,,,,,,,, -46500,0.36164615,1.7497121,,,,,,,,,,,,,,,,, -46600,0.3773856,1.7382758,,,,,,,,,,,,,,,,, -46700,0.47162965,1.7339917,,,,,,,,,,,,,,,,, -46800,0.37877396,1.7823285,,,,,,,,,,,,,,,,, -46900,0.4073411,1.7584614,,,,,,,,,,,,,,,,, -47000,0.40374544,1.7047437,,,,,,,,,,,,,,,,, -47100,0.43645528,1.7036988,,,,,,,,,,,,,,,,, -47200,0.47230405,1.6836969,,,,,,,,,,,,,,,,, -47300,0.4185215,1.6813886,,,,,,,,,,,,,,,,, -47400,0.45115915,1.7474122,,,,,,,,,,,,,,,,, -47500,0.36710483,1.760362,,,,,,,,,,,,,,,,, -47600,0.35429204,1.672454,,,,,,,,,,,,,,,,, -47700,0.36144388,1.7969636,,,,,,,,,,,,,,,,, -47800,0.42049173,1.7813704,,,,,,,,,,,,,,,,, -47900,0.41208604,1.8328478,,,,,,,,,,,,,,,,, -48000,0.41388774,1.6851776,,,,,,,,,,,,,,,,, -48100,0.37622133,1.7241831,,,,,,,,,,,,,,,,, -48200,0.41532928,1.6858926,,,,,,,,,,,,,,,,, -48300,0.4113372,1.7986443,,,,,,,,,,,,,,,,, -48315,,,0.6425302028656006,1.7314515113830566,31.29823991504049,0.6614673137664795,1.5935068130493164,28.392376189178425,3000.0,0.6734995245933533,1.5197216272354126,28.134675838345206,3003.0,16832.123507976532,27793.366586208344,16832.123507976532,10959.045782327652,0.5941922664642334,0.0 -48400,0.43254486,1.7084087,,,,,,,,,,,,,,,,, -48500,0.3965909,1.8286252,,,,,,,,,,,,,,,,, -48600,0.42679816,1.691911,,,,,,,,,,,,,,,,, -48700,0.41734064,1.7348993,,,,,,,,,,,,,,,,, -48800,0.3808768,1.6894146,,,,,,,,,,,,,,,,, -48900,0.40089846,1.759962,,,,,,,,,,,,,,,,, -49000,0.3640594,1.7601298,,,,,,,,,,,,,,,,, -49100,0.43857846,1.7104436,,,,,,,,,,,,,,,,, -49200,0.45940816,1.7838339,,,,,,,,,,,,,,,,, -49300,0.38085827,1.7465104,,,,,,,,,,,,,,,,, -49400,0.38834587,1.6967905,,,,,,,,,,,,,,,,, -49500,0.38591433,1.7646923,,,,,,,,,,,,,,,,, -49600,0.38886157,1.7528452,,,,,,,,,,,,,,,,, -49700,0.3723922,1.7101718,,,,,,,,,,,,,,,,, -49800,0.40095463,1.6691837,,,,,,,,,,,,,,,,, -49900,0.37737286,1.7110224,,,,,,,,,,,,,,,,, -50000,0.39699495,1.7557212,,,,,,,,,,,,,,,,, -50100,0.39745536,1.6903073,,,,,,,,,,,,,,,,, -50200,0.39122584,1.696791,,,,,,,,,,,,,,,,, -50300,0.37488416,1.7283927,,,,,,,,,,,,,,,,, -50400,0.3771456,1.7723384,,,,,,,,,,,,,,,,, -50500,0.37077963,1.7771477,,,,,,,,,,,,,,,,, -50600,0.39476055,1.76463,,,,,,,,,,,,,,,,, -50700,0.3924663,1.7831041,,,,,,,,,,,,,,,,, -50733,,,0.6550332307815552,1.6446468830108645,31.87787872913857,0.664765477180481,1.5815393924713137,28.684633210753194,3000.0,0.6736041307449341,1.5203142166137695,27.925587206144616,3003.0,17672.332036733627,29109.59782481193,17672.332036733627,11434.958333969116,0.6256530284881592,0.0 -50800,0.44828013,1.7056046,,,,,,,,,,,,,,,,, -50900,0.37589133,1.768427,,,,,,,,,,,,,,,,, -51000,0.43675765,1.6906409,,,,,,,,,,,,,,,,, -51100,0.38831827,1.7393116,,,,,,,,,,,,,,,,, -51200,0.38184673,1.7217294,,,,,,,,,,,,,,,,, -51300,0.3831699,1.8010479,,,,,,,,,,,,,,,,, -51400,0.39424607,1.7166036,,,,,,,,,,,,,,,,, -51500,0.3820604,1.722919,,,,,,,,,,,,,,,,, -51600,0.39709762,1.6927396,,,,,,,,,,,,,,,,, -51700,0.40457824,1.8078805,,,,,,,,,,,,,,,,, -51800,0.54534477,1.7568786,,,,,,,,,,,,,,,,, -51900,0.35980493,1.6541171,,,,,,,,,,,,,,,,, -52000,0.37790576,1.783481,,,,,,,,,,,,,,,,, -52100,0.38556117,1.7567381,,,,,,,,,,,,,,,,, -52200,0.37319654,1.7056795,,,,,,,,,,,,,,,,, -52300,0.35667774,1.6977623,,,,,,,,,,,,,,,,, -52400,0.43087128,1.7758703,,,,,,,,,,,,,,,,, -52500,0.3982106,1.6658523,,,,,,,,,,,,,,,,, -52600,0.38938677,1.6984705,,,,,,,,,,,,,,,,, -52700,0.37670997,1.6463206,,,,,,,,,,,,,,,,, -52800,0.36721054,1.7315227,,,,,,,,,,,,,,,,, -52900,0.41012707,1.7778035,,,,,,,,,,,,,,,,, -53000,0.37136537,1.6628602,,,,,,,,,,,,,,,,, -53100,0.41926026,1.7567676,,,,,,,,,,,,,,,,, -53150,,,0.6440820097923279,1.7206270694732666,31.623226056556543,0.6641455292701721,1.5732247829437256,28.54921184734869,3000.0,0.6757422685623169,1.5066756010055542,28.26698253487057,3003.0,18512.45762705803,30521.75099492073,18512.45762705803,12006.873229980469,0.6573050022125244,0.0 -53200,0.42293715,1.750115,,,,,,,,,,,,,,,,, -53300,0.40188444,1.6793683,,,,,,,,,,,,,,,,, -53400,0.3887324,1.7220627,,,,,,,,,,,,,,,,, -53500,0.3983713,1.6501212,,,,,,,,,,,,,,,,, -53600,0.3678726,1.7240081,,,,,,,,,,,,,,,,, -53700,0.3833017,1.6031841,,,,,,,,,,,,,,,,, -53800,0.42766422,1.7912198,,,,,,,,,,,,,,,,, -53900,0.384309,1.7632372,,,,,,,,,,,,,,,,, -54000,0.37305394,1.7272795,,,,,,,,,,,,,,,,, -54100,0.3740119,1.7539543,,,,,,,,,,,,,,,,, -54200,0.37686414,1.7382485,,,,,,,,,,,,,,,,, -54300,0.4245658,1.7321855,,,,,,,,,,,,,,,,, -54400,0.41319278,1.7197781,,,,,,,,,,,,,,,,, -54500,0.38796762,1.8418075,,,,,,,,,,,,,,,,, -54600,0.46423548,1.7993956,,,,,,,,,,,,,,,,, -54700,0.41681787,1.7448355,,,,,,,,,,,,,,,,, -54800,0.34299937,1.7202859,,,,,,,,,,,,,,,,, -54900,0.369228,1.6808683,,,,,,,,,,,,,,,,, -55000,0.40278122,1.6920289,,,,,,,,,,,,,,,,, -55100,0.40779802,1.6672139,,,,,,,,,,,,,,,,, -55200,0.379302,1.7045263,,,,,,,,,,,,,,,,, -55300,0.35347876,1.7038918,,,,,,,,,,,,,,,,, -55400,0.3927512,1.6777083,,,,,,,,,,,,,,,,, -55500,0.37234613,1.8421979,,,,,,,,,,,,,,,,, -55566,,,0.6458576917648315,1.7048068046569824,31.80318325835912,0.6658813953399658,1.5674703121185305,28.654052868631133,3000.0,0.6787520051002502,1.5035967826843262,28.449937622484967,3003.0,19352.615475177765,31827.267405748367,19352.615475177765,12472.107189893724,0.6961920261383057,0.0 -55600,0.39903998,1.7220439,,,,,,,,,,,,,,,,, -55700,0.37112674,1.7391809,,,,,,,,,,,,,,,,, -55800,0.38893774,1.7518975,,,,,,,,,,,,,,,,, -55900,0.39608946,1.8581783,,,,,,,,,,,,,,,,, -56000,0.40563837,1.733057,,,,,,,,,,,,,,,,, -56100,0.3571439,1.6940902,,,,,,,,,,,,,,,,, -56200,0.3782528,1.7340481,,,,,,,,,,,,,,,,, -56300,0.4019402,1.8026402,,,,,,,,,,,,,,,,, -56400,0.3576951,1.6704265,,,,,,,,,,,,,,,,, -56500,0.3748812,1.6912057,,,,,,,,,,,,,,,,, -56600,0.45588195,1.7162691,,,,,,,,,,,,,,,,, -56700,0.37396002,1.6970447,,,,,,,,,,,,,,,,, -56800,0.39808804,1.7292786,,,,,,,,,,,,,,,,, -56900,0.38042635,1.683002,,,,,,,,,,,,,,,,, -57000,0.4998951,1.773072,,,,,,,,,,,,,,,,, -57100,0.3975611,1.683348,,,,,,,,,,,,,,,,, -57200,0.37728822,1.7731187,,,,,,,,,,,,,,,,, -57300,0.39439562,1.6881677,,,,,,,,,,,,,,,,, -57400,0.37189248,1.7329241,,,,,,,,,,,,,,,,, -57500,0.4117712,1.7205794,,,,,,,,,,,,,,,,, -57600,0.36783823,1.7552207,,,,,,,,,,,,,,,,, -57700,0.38259423,1.6828258,,,,,,,,,,,,,,,,, -57800,0.4159046,1.7021947,,,,,,,,,,,,,,,,, -57900,0.42419896,1.6529875,,,,,,,,,,,,,,,,, -57983,,,0.6503816843032837,1.6813665628433228,32.32253220630714,0.6673941016197205,1.5596922636032104,28.76944588602514,3000.0,0.678345263004303,1.492568016052246,28.715793486279985,3003.0,20192.78905391693,33198.427434682846,20192.78905391693,13002.97655892372,0.7319936752319336,0.0 -58000,0.39953634,1.7017378,,,,,,,,,,,,,,,,, -58100,0.3657687,1.6539565,,,,,,,,,,,,,,,,, -58200,0.3904267,1.7071127,,,,,,,,,,,,,,,,, -58300,0.37227097,1.6464646,,,,,,,,,,,,,,,,, -58400,0.37150097,1.6994556,,,,,,,,,,,,,,,,, -58500,0.555187,1.7472618,,,,,,,,,,,,,,,,, -58600,0.36644232,1.6805547,,,,,,,,,,,,,,,,, -58700,0.43771833,1.7207491,,,,,,,,,,,,,,,,, -58800,0.3822135,1.6167638,,,,,,,,,,,,,,,,, -58900,0.36753255,1.6762675,,,,,,,,,,,,,,,,, -59000,0.35263756,1.6658384,,,,,,,,,,,,,,,,, -59100,0.4064347,1.7131935,,,,,,,,,,,,,,,,, -59200,0.39024445,1.7789145,,,,,,,,,,,,,,,,, -59300,0.40233517,1.768276,,,,,,,,,,,,,,,,, -59400,0.37852857,1.6352696,,,,,,,,,,,,,,,,, -59500,0.3693366,1.7004948,,,,,,,,,,,,,,,,, -59600,0.34220657,1.6725088,,,,,,,,,,,,,,,,, -59700,0.44722605,1.7494862,,,,,,,,,,,,,,,,, -59800,0.3686231,1.7292899,,,,,,,,,,,,,,,,, -59900,0.4615442,1.6966091,,,,,,,,,,,,,,,,, -60000,0.453175,1.7641823,,,,,,,,,,,,,,,,, -60100,0.37922165,1.6929536,,,,,,,,,,,,,,,,, -60200,0.35082626,1.6897154,,,,,,,,,,,,,,,,, -60300,0.35998425,1.7502823,,,,,,,,,,,,,,,,, -60399,,,0.6484251022338867,1.6919599771499634,31.836385315191027,0.667815625667572,1.5593230724334717,28.898338564560326,3000.0,0.6783917546272278,1.4889075756072998,28.376137687711587,3003.0,21032.677373170853,34535.92731380463,21032.677373170853,13500.465492010117,0.7707424163818359,0.0 -60400,0.38207245,1.7526467,,,,,,,,,,,,,,,,, -60500,0.43857723,1.7010372,,,,,,,,,,,,,,,,, -60600,0.3705653,1.6490002,,,,,,,,,,,,,,,,, -60700,0.40428713,1.7944998,,,,,,,,,,,,,,,,, -60800,0.4638131,1.7896447,,,,,,,,,,,,,,,,, -60900,0.3739469,1.7122054,,,,,,,,,,,,,,,,, -61000,0.37329382,1.7876254,,,,,,,,,,,,,,,,, -61100,0.3807924,1.6920857,,,,,,,,,,,,,,,,, -61200,0.41547626,1.6904356,,,,,,,,,,,,,,,,, -61300,0.3925492,1.6443739,,,,,,,,,,,,,,,,, -61400,0.41309386,1.7020913,,,,,,,,,,,,,,,,, -61500,0.3889356,1.732546,,,,,,,,,,,,,,,,, -61600,0.36589113,1.7443223,,,,,,,,,,,,,,,,, -61700,0.38340744,1.695224,,,,,,,,,,,,,,,,, -61800,0.41412303,1.7764345,,,,,,,,,,,,,,,,, -61900,0.41814762,1.6209512,,,,,,,,,,,,,,,,, -62000,0.37206882,1.7342905,,,,,,,,,,,,,,,,, -62100,0.37050444,1.7482893,,,,,,,,,,,,,,,,, -62200,0.39630863,1.6399441,,,,,,,,,,,,,,,,, -62300,0.49061608,1.6703337,,,,,,,,,,,,,,,,, -62400,0.3742268,1.5989488,,,,,,,,,,,,,,,,, -62500,0.38751617,1.691567,,,,,,,,,,,,,,,,, -62600,0.37427205,1.6499016,,,,,,,,,,,,,,,,, -62700,0.42156112,1.7039301,,,,,,,,,,,,,,,,, -62800,0.44203454,1.6960639,,,,,,,,,,,,,,,,, -62817,,,0.6717912554740906,1.5358587503433228,33.68370344851488,0.6687455773353577,1.5462676286697388,28.54941150424726,3000.0,0.6795189380645752,1.4860724210739136,28.410224716126773,3003.0,21872.860898256306,36023.72711586952,21872.860898256306,14147.967857837675,0.803971529006958,0.0 -62900,0.39266002,1.7170179,,,,,,,,,,,,,,,,, -63000,0.38220602,1.7064818,,,,,,,,,,,,,,,,, -63100,0.3881557,1.6879432,,,,,,,,,,,,,,,,, -63200,0.38755262,1.7087258,,,,,,,,,,,,,,,,, -63300,0.46792933,1.616877,,,,,,,,,,,,,,,,, -63400,0.42101386,1.7443123,,,,,,,,,,,,,,,,, -63500,0.35664165,1.6712672,,,,,,,,,,,,,,,,, -63600,0.41214788,1.6606045,,,,,,,,,,,,,,,,, -63700,0.36892775,1.6967841,,,,,,,,,,,,,,,,, -63800,0.42710555,1.702178,,,,,,,,,,,,,,,,, -63900,0.4030672,1.7043798,,,,,,,,,,,,,,,,, -64000,0.41587463,1.6481731,,,,,,,,,,,,,,,,, -64100,0.41553572,1.6961632,,,,,,,,,,,,,,,,, -64200,0.38620877,1.707037,,,,,,,,,,,,,,,,, -64300,0.39088985,1.6811519,,,,,,,,,,,,,,,,, -64400,0.38503096,1.7129183,,,,,,,,,,,,,,,,, -64500,0.3950043,1.7496623,,,,,,,,,,,,,,,,, -64600,0.3940329,1.7000697,,,,,,,,,,,,,,,,, -64700,0.42682052,1.734464,,,,,,,,,,,,,,,,, -64800,0.40593633,1.7006521,,,,,,,,,,,,,,,,, -64900,0.38152903,1.6498392,,,,,,,,,,,,,,,,, -65000,0.35982993,1.6631943,,,,,,,,,,,,,,,,, -65100,0.38078722,1.6934048,,,,,,,,,,,,,,,,, -65200,0.43869802,1.6371833,,,,,,,,,,,,,,,,, -65235,,,0.6527928113937378,1.659690499305725,32.17369020688684,0.6705434322357178,1.5372453927993774,28.912173341920933,3000.0,0.6814130544662476,1.4726934432983398,28.651393552679146,3003.0,22712.9723572731,37421.23839473725,22712.9723572731,14705.24766755104,0.8449218273162842,0.0 -65300,0.4085081,1.6755843,,,,,,,,,,,,,,,,, -65400,0.39658108,1.7271205,,,,,,,,,,,,,,,,, -65500,0.36249754,1.6492293,,,,,,,,,,,,,,,,, -65600,0.41851887,1.7295895,,,,,,,,,,,,,,,,, -65700,0.37946072,1.7386469,,,,,,,,,,,,,,,,, -65800,0.36374047,1.6535918,,,,,,,,,,,,,,,,, -65900,0.453879,1.6810786,,,,,,,,,,,,,,,,, -66000,0.3852148,1.6290693,,,,,,,,,,,,,,,,, -66100,0.40560412,1.6855793,,,,,,,,,,,,,,,,, -66200,0.38780874,1.690131,,,,,,,,,,,,,,,,, -66300,0.36805746,1.7274897,,,,,,,,,,,,,,,,, -66400,0.37063614,1.616403,,,,,,,,,,,,,,,,, -66500,0.40036035,1.6582782,,,,,,,,,,,,,,,,, -66600,0.42113838,1.6831467,,,,,,,,,,,,,,,,, -66700,0.43486804,1.7431123,,,,,,,,,,,,,,,,, -66800,0.33987868,1.6393497,,,,,,,,,,,,,,,,, -66900,0.3911328,1.7018102,,,,,,,,,,,,,,,,, -67000,0.36980549,1.6439444,,,,,,,,,,,,,,,,, -67100,0.40036857,1.714615,,,,,,,,,,,,,,,,, -67200,0.40326646,1.6431725,,,,,,,,,,,,,,,,, -67300,0.41934115,1.6960281,,,,,,,,,,,,,,,,, -67400,0.40554017,1.712626,,,,,,,,,,,,,,,,, -67500,0.40499756,1.7438499,,,,,,,,,,,,,,,,, -67600,0.42027196,1.6149702,,,,,,,,,,,,,,,,, -67652,,,0.6534407138824463,1.6673787832260132,32.31266434351574,0.6718453764915466,1.5317059755325315,29.001009543577776,3000.0,0.6824589371681213,1.463875412940979,28.896382659621786,3003.0,23552.86766433716,38806.44243097305,23552.86766433716,15250.439332008362,0.881615400314331,0.0 -67700,0.36251387,1.6364349,,,,,,,,,,,,,,,,, -67800,0.4113513,1.701709,,,,,,,,,,,,,,,,, -67900,0.42984352,1.7177488,,,,,,,,,,,,,,,,, -68000,0.40385136,1.6602544,,,,,,,,,,,,,,,,, -68100,0.3922572,1.6483394,,,,,,,,,,,,,,,,, -68200,0.40430197,1.7528508,,,,,,,,,,,,,,,,, -68300,0.37259498,1.6226974,,,,,,,,,,,,,,,,, -68400,0.37614754,1.633595,,,,,,,,,,,,,,,,, -68500,0.38000774,1.7025388,,,,,,,,,,,,,,,,, -68600,0.39923692,1.6410729,,,,,,,,,,,,,,,,, -68700,0.41876742,1.7242997,,,,,,,,,,,,,,,,, -68800,0.43455917,1.5937357,,,,,,,,,,,,,,,,, -68900,0.39204627,1.6818925,,,,,,,,,,,,,,,,, -69000,0.4044422,1.6859162,,,,,,,,,,,,,,,,, -69100,0.37517118,1.7311771,,,,,,,,,,,,,,,,, -69200,0.42025882,1.729683,,,,,,,,,,,,,,,,, -69300,0.38205874,1.7153479,,,,,,,,,,,,,,,,, -69400,0.37630534,1.6236067,,,,,,,,,,,,,,,,, -69500,0.43118793,1.645464,,,,,,,,,,,,,,,,, -69600,0.43091184,1.6398227,,,,,,,,,,,,,,,,, -69700,0.40350458,1.7106462,,,,,,,,,,,,,,,,, -69800,0.5271065,1.6228446,,,,,,,,,,,,,,,,, -69900,0.4088911,1.6979855,,,,,,,,,,,,,,,,, -70000,0.40918478,1.6696651,,,,,,,,,,,,,,,,, -70069,,,0.6589791178703308,1.6300126314163208,32.72365867691716,0.673568844795227,1.5262335538864136,29.25410080130436,3000.0,0.6847829818725586,1.4570910930633545,28.9900289674124,3003.0,24392.89284348488,40235.36283278465,24392.89284348488,15839.218275308607,0.917823314666748,0.0 -70100,0.3911219,1.6372446,,,,,,,,,,,,,,,,, -70200,0.39996344,1.7172931,,,,,,,,,,,,,,,,, -70300,0.3682146,1.6374862,,,,,,,,,,,,,,,,, -70400,0.42630294,1.6007552,,,,,,,,,,,,,,,,, -70500,0.40623316,1.643558,,,,,,,,,,,,,,,,, -70600,0.37408024,1.6302319,,,,,,,,,,,,,,,,, -70700,0.43042967,1.6425476,,,,,,,,,,,,,,,,, -70800,0.4145352,1.6487232,,,,,,,,,,,,,,,,, -70900,0.4168969,1.6738783,,,,,,,,,,,,,,,,, -71000,0.39683256,1.5638949,,,,,,,,,,,,,,,,, -71100,0.39611536,1.6508846,,,,,,,,,,,,,,,,, -71200,0.4055609,1.7391692,,,,,,,,,,,,,,,,, -71300,0.4135717,1.6554592,,,,,,,,,,,,,,,,, -71400,0.4217029,1.6512673,,,,,,,,,,,,,,,,, -71500,0.3939725,1.6742496,,,,,,,,,,,,,,,,, -71600,0.37647542,1.641462,,,,,,,,,,,,,,,,, -71700,0.40084124,1.6595888,,,,,,,,,,,,,,,,, -71800,0.41879046,1.6724031,,,,,,,,,,,,,,,,, -71900,0.39809862,1.6383314,,,,,,,,,,,,,,,,, -72000,0.383821,1.7006533,,,,,,,,,,,,,,,,, -72100,0.4095011,1.6478795,,,,,,,,,,,,,,,,, -72200,0.42575437,1.5888792,,,,,,,,,,,,,,,,, -72300,0.3763949,1.6187743,,,,,,,,,,,,,,,,, -72400,0.40566003,1.710033,,,,,,,,,,,,,,,,, -72486,,,0.6565991044044495,1.6378833055496216,32.39992250394725,0.6732092499732971,1.5226255655288696,29.28347733531719,3000.0,0.6864911913871765,1.4465367794036863,29.154292326536787,3003.0,25232.844252109528,41621.70757579804,25232.844252109528,16385.491423606873,0.9566261768341064,0.0 -72500,0.42541346,1.6161708,,,,,,,,,,,,,,,,, -72600,0.3822396,1.6297525,,,,,,,,,,,,,,,,, -72700,0.39299002,1.6559763,,,,,,,,,,,,,,,,, -72800,0.38839436,1.6760402,,,,,,,,,,,,,,,,, -72900,0.37791649,1.6937289,,,,,,,,,,,,,,,,, -73000,0.42792422,1.6684213,,,,,,,,,,,,,,,,, -73100,0.37966803,1.6843449,,,,,,,,,,,,,,,,, -73200,0.40129718,1.6063795,,,,,,,,,,,,,,,,, -73300,0.40910882,1.5542353,,,,,,,,,,,,,,,,, -73400,0.3697741,1.5808073,,,,,,,,,,,,,,,,, -73500,0.3735049,1.6520251,,,,,,,,,,,,,,,,, -73600,0.39800477,1.7127508,,,,,,,,,,,,,,,,, -73700,0.38106593,1.6130396,,,,,,,,,,,,,,,,, -73800,0.4333231,1.6456618,,,,,,,,,,,,,,,,, -73900,0.38776973,1.6299953,,,,,,,,,,,,,,,,, -74000,0.4192184,1.6940315,,,,,,,,,,,,,,,,, -74100,0.4275737,1.6945266,,,,,,,,,,,,,,,,, -74200,0.38265827,1.6841567,,,,,,,,,,,,,,,,, -74300,0.41937748,1.7006959,,,,,,,,,,,,,,,,, -74400,0.42229792,1.6234303,,,,,,,,,,,,,,,,, -74500,0.4108974,1.5585643,,,,,,,,,,,,,,,,, -74600,0.43279475,1.5857505,,,,,,,,,,,,,,,,, -74700,0.3688092,1.5834827,,,,,,,,,,,,,,,,, -74800,0.41585043,1.6925504,,,,,,,,,,,,,,,,, -74900,0.39386323,1.644974,,,,,,,,,,,,,,,,, -74903,,,0.65643310546875,1.6501281261444092,32.34811927392111,0.6752055287361145,1.5135127305984497,29.52027020865552,3000.0,0.6882458925247192,1.440338373184204,29.18233746463092,3003.0,26072.833317756653,43007.06311130524,26072.833317756653,16930.741188049316,0.9927854537963868,0.0 -75000,0.39847666,1.6111584,,,,,,,,,,,,,,,,, -75100,0.39562488,1.541449,,,,,,,,,,,,,,,,, -75200,0.40197074,1.6154876,,,,,,,,,,,,,,,,, -75300,0.47720793,1.6649512,,,,,,,,,,,,,,,,, -75400,0.40746126,1.6053067,,,,,,,,,,,,,,,,, -75500,0.38069564,1.6528306,,,,,,,,,,,,,,,,, -75600,0.39574778,1.6639823,,,,,,,,,,,,,,,,, -75700,0.40940714,1.6678588,,,,,,,,,,,,,,,,, -75800,0.3945367,1.6541284,,,,,,,,,,,,,,,,, -75900,0.3983839,1.6202669,,,,,,,,,,,,,,,,, -76000,0.44548842,1.595253,,,,,,,,,,,,,,,,, -76100,0.4317021,1.6209011,,,,,,,,,,,,,,,,, -76200,0.3848473,1.6080897,,,,,,,,,,,,,,,,, -76300,0.39743307,1.6541991,,,,,,,,,,,,,,,,, -76400,0.4438897,1.7263823,,,,,,,,,,,,,,,,, -76500,0.40948683,1.6165082,,,,,,,,,,,,,,,,, -76600,0.4186729,1.588726,,,,,,,,,,,,,,,,, -76700,0.43600848,1.6550468,,,,,,,,,,,,,,,,, -76800,1.9479043,1.6340649,,,,,,,,,,,,,,,,, -76900,0.40395585,1.6226369,,,,,,,,,,,,,,,,, -77000,0.39086586,1.6883273,,,,,,,,,,,,,,,,, -77100,0.39380115,1.6244682,,,,,,,,,,,,,,,,, -77200,0.42265135,1.6171767,,,,,,,,,,,,,,,,, -77300,0.41262442,1.6347175,,,,,,,,,,,,,,,,, -77321,,,0.6606552600860596,1.6113643646240234,33.132585237997496,0.6767429709434509,1.5007519721984863,29.46409467654244,3000.0,0.6892685294151306,1.427660584449768,29.210126730147856,3003.0,26912.88521933556,44386.86210536957,26912.88521933556,17470.37271785736,1.029709815979004,0.0 -77400,0.4173484,1.6344007,,,,,,,,,,,,,,,,, -77500,0.40152195,1.6390654,,,,,,,,,,,,,,,,, -77600,0.4400475,1.6247108,,,,,,,,,,,,,,,,, -77700,0.44159746,1.7973272,,,,,,,,,,,,,,,,, -77800,0.39379677,1.6837405,,,,,,,,,,,,,,,,, -77900,0.414535,1.6140888,,,,,,,,,,,,,,,,, -78000,0.40658364,1.5794151,,,,,,,,,,,,,,,,, -78100,0.44952154,1.6849798,,,,,,,,,,,,,,,,, -78200,0.5508634,1.6641515,,,,,,,,,,,,,,,,, -78300,0.38883033,1.6410238,,,,,,,,,,,,,,,,, -78400,0.39265564,1.6032284,,,,,,,,,,,,,,,,, -78500,0.3920017,1.6452047,,,,,,,,,,,,,,,,, -78600,0.39795384,1.6326131,,,,,,,,,,,,,,,,, -78700,0.39442664,1.6475692,,,,,,,,,,,,,,,,, -78800,0.39912638,1.5596691,,,,,,,,,,,,,,,,, -78900,0.43275505,1.6049228,,,,,,,,,,,,,,,,, -79000,0.411327,1.6076691,,,,,,,,,,,,,,,,, -79100,0.41875502,1.679266,,,,,,,,,,,,,,,,, -79200,0.3866858,1.6719176,,,,,,,,,,,,,,,,, -79300,0.41551813,1.6445242,,,,,,,,,,,,,,,,, -79400,0.409688,1.5966294,,,,,,,,,,,,,,,,, -79500,0.44856757,1.7023667,,,,,,,,,,,,,,,,, -79600,0.42262018,1.6501219,,,,,,,,,,,,,,,,, -79700,0.41852698,1.6586187,,,,,,,,,,,,,,,,, -79737,,,0.6597093939781189,1.622588872909546,32.721303780072326,0.6777225136756897,1.5018718242645264,29.9094847833621,3000.0,0.6904189586639404,1.426446557044983,29.423481171518823,3003.0,27752.805153131485,45742.4435338974,27752.805153131485,17985.91515660286,1.0653765201568604,0.0 -79800,0.4347387,1.6414394,,,,,,,,,,,,,,,,, -79900,0.44552007,1.6722233,,,,,,,,,,,,,,,,, -80000,0.41572487,1.5355551,,,,,,,,,,,,,,,,, -80100,0.39337468,1.7082666,,,,,,,,,,,,,,,,, -80200,0.4052058,1.5888854,,,,,,,,,,,,,,,,, -80300,0.41349918,1.664716,,,,,,,,,,,,,,,,, -80400,0.39584088,1.6099348,,,,,,,,,,,,,,,,, -80500,0.40288362,1.7182082,,,,,,,,,,,,,,,,, -80600,0.3959491,1.6050214,,,,,,,,,,,,,,,,, -80700,0.40011668,1.5846777,,,,,,,,,,,,,,,,, -80800,0.43762907,1.6663342,,,,,,,,,,,,,,,,, -80900,0.43413493,1.6363075,,,,,,,,,,,,,,,,, -81000,0.41589633,1.6816096,,,,,,,,,,,,,,,,, -81100,0.39446202,1.5666846,,,,,,,,,,,,,,,,, -81200,0.42810634,1.6321633,,,,,,,,,,,,,,,,, -81300,0.42988974,1.7190511,,,,,,,,,,,,,,,,, -81400,0.40572587,1.6052952,,,,,,,,,,,,,,,,, -81500,0.41889563,1.555366,,,,,,,,,,,,,,,,, -81600,0.4301659,1.6497984,,,,,,,,,,,,,,,,, -81700,0.5024468,1.5914803,,,,,,,,,,,,,,,,, -81800,0.41817486,1.6178113,,,,,,,,,,,,,,,,, -81900,0.4094118,1.6656764,,,,,,,,,,,,,,,,, -82000,0.43386477,1.6841481,,,,,,,,,,,,,,,,, -82100,0.44735503,1.551802,,,,,,,,,,,,,,,,, -82154,,,0.6717402935028076,1.538611888885498,33.79730536923105,0.6788632273674011,1.4914655685424805,29.745504916128827,3000.0,0.6918599009513855,1.4198691844940186,29.724196297362315,3003.0,28592.93453645706,47107.37024998665,28592.93453645706,18510.5944993496,1.1032533645629885,0.0 -82200,0.4169632,1.6688327,,,,,,,,,,,,,,,,, -82300,0.42678922,1.7073749,,,,,,,,,,,,,,,,, -82400,0.42949882,1.6591632,,,,,,,,,,,,,,,,, -82500,0.41900417,1.6328539,,,,,,,,,,,,,,,,, -82600,0.4325598,1.5577899,,,,,,,,,,,,,,,,, -82700,0.42873743,1.628984,,,,,,,,,,,,,,,,, -82800,0.41696012,1.6241354,,,,,,,,,,,,,,,,, -82900,0.42150337,1.6204711,,,,,,,,,,,,,,,,, -83000,0.41904113,1.5680084,,,,,,,,,,,,,,,,, -83100,0.44935524,1.6288308,,,,,,,,,,,,,,,,, -83200,0.4400641,1.671537,,,,,,,,,,,,,,,,, -83300,0.42909768,1.5600253,,,,,,,,,,,,,,,,, -83400,0.42685014,1.6189976,,,,,,,,,,,,,,,,, -83500,0.42426965,1.6253779,,,,,,,,,,,,,,,,, -83600,0.39573616,1.6234403,,,,,,,,,,,,,,,,, -83700,0.4376456,1.5704602,,,,,,,,,,,,,,,,, -83800,0.42627835,1.5368881,,,,,,,,,,,,,,,,, -83900,0.42260972,1.6643528,,,,,,,,,,,,,,,,, -84000,0.40750638,1.5797832,,,,,,,,,,,,,,,,, -84100,0.42153734,1.5787476,,,,,,,,,,,,,,,,, -84200,0.3897048,1.6383957,,,,,,,,,,,,,,,,, -84300,0.41327617,1.5697856,,,,,,,,,,,,,,,,, -84400,0.44140735,1.6555315,,,,,,,,,,,,,,,,, -84500,0.42940134,1.5790513,,,,,,,,,,,,,,,,, -84571,,,0.665301501750946,1.583433747291565,33.228433223994266,0.6785780787467957,1.48270583152771,29.50485520503944,3000.0,0.6945209503173828,1.4088683128356934,29.945540079129746,3003.0,29432.89954471588,48542.69511389732,29432.89954471588,19105.83856487274,1.1398842334747314,0.0 -84600,0.42036244,1.6080383,,,,,,,,,,,,,,,,, -84700,0.41555476,1.5951953,,,,,,,,,,,,,,,,, -84800,0.44057122,1.5398434,,,,,,,,,,,,,,,,, -84900,0.4395047,1.6386068,,,,,,,,,,,,,,,,, -85000,0.45397744,1.5840067,,,,,,,,,,,,,,,,, -85100,0.41243538,1.563778,,,,,,,,,,,,,,,,, -85200,0.4394186,1.6418881,,,,,,,,,,,,,,,,, -85300,0.4463047,1.6286916,,,,,,,,,,,,,,,,, -85400,0.44719958,1.602158,,,,,,,,,,,,,,,,, -85500,0.45613524,1.6629246,,,,,,,,,,,,,,,,, -85600,0.41043827,1.5378058,,,,,,,,,,,,,,,,, -85700,0.43245044,1.6440748,,,,,,,,,,,,,,,,, -85800,0.45547423,1.6926389,,,,,,,,,,,,,,,,, -85900,0.48984677,1.6774201,,,,,,,,,,,,,,,,, -86000,0.4287408,1.6391443,,,,,,,,,,,,,,,,, -86100,0.41795722,1.5624583,,,,,,,,,,,,,,,,, -86200,0.4537581,1.5763713,,,,,,,,,,,,,,,,, -86300,0.42863777,1.5728565,,,,,,,,,,,,,,,,, -86400,0.44070986,1.6914119,,,,,,,,,,,,,,,,, -86500,0.43171814,1.5887699,,,,,,,,,,,,,,,,, -86600,0.45468232,1.5614152,,,,,,,,,,,,,,,,, -86700,0.43389425,1.610972,,,,,,,,,,,,,,,,, -86800,0.46432644,1.6046938,,,,,,,,,,,,,,,,, -86900,0.4359216,1.5730215,,,,,,,,,,,,,,,,, -86989,,,0.6616896986961365,1.6091835498809814,33.137917686474665,0.6786152720451355,1.482520580291748,29.751823489337745,3000.0,0.6953459978103638,1.4079169034957886,29.84721352815368,3003.0,30272.9449737072,49913.75622940064,30272.9449737072,19636.73670172692,1.1791932582855225,0.0 -87000,0.4084899,1.5535985,,,,,,,,,,,,,,,,, -87100,0.44592917,1.5445234,,,,,,,,,,,,,,,,, -87200,0.4400096,1.6032443,,,,,,,,,,,,,,,,, -87300,0.45775363,1.5836202,,,,,,,,,,,,,,,,, -87400,0.8247046,1.5824325,,,,,,,,,,,,,,,,, -87500,0.43002892,1.632948,,,,,,,,,,,,,,,,, -87600,0.42689535,1.5545449,,,,,,,,,,,,,,,,, -87700,0.4312549,1.5853512,,,,,,,,,,,,,,,,, -87800,0.42313,1.5498893,,,,,,,,,,,,,,,,, -87900,0.43046117,1.5598662,,,,,,,,,,,,,,,,, -88000,0.5031413,1.5279502,,,,,,,,,,,,,,,,, -88100,0.43996766,1.6090271,,,,,,,,,,,,,,,,, -88200,0.4402143,1.5933583,,,,,,,,,,,,,,,,, -88300,0.4474118,1.6281536,,,,,,,,,,,,,,,,, -88400,0.44445303,1.6035408,,,,,,,,,,,,,,,,, -88500,0.4359394,1.5949162,,,,,,,,,,,,,,,,, -88600,0.43390742,1.5039673,,,,,,,,,,,,,,,,, -88700,0.4687569,1.5798652,,,,,,,,,,,,,,,,, -88800,0.4358825,1.6214669,,,,,,,,,,,,,,,,, -88900,0.46383238,1.6142145,,,,,,,,,,,,,,,,, -89000,0.42963633,1.590681,,,,,,,,,,,,,,,,, -89100,0.43803787,1.5951957,,,,,,,,,,,,,,,,, -89200,0.4337411,1.5992231,,,,,,,,,,,,,,,,, -89300,0.45421422,1.5811158,,,,,,,,,,,,,,,,, -89400,0.43017152,1.5487285,,,,,,,,,,,,,,,,, -89407,,,0.6713466644287109,1.5552958250045776,33.52182294055766,0.6820374131202698,1.465645670890808,29.869141731518045,3000.0,0.6946952939033508,1.3910845518112185,29.631786634944643,3003.0,31113.069878339767,51299.45757818222,31113.069878339767,20182.18912935257,1.2226190567016602,0.0 -89500,0.43933782,1.5196466,,,,,,,,,,,,,,,,, -89600,0.46485916,1.5355922,,,,,,,,,,,,,,,,, -89700,0.48489907,1.5838602,,,,,,,,,,,,,,,,, -89800,0.4515706,1.5581826,,,,,,,,,,,,,,,,, -89900,0.42891124,1.5547701,,,,,,,,,,,,,,,,, -90000,0.46496603,1.6536212,,,,,,,,,,,,,,,,, -90100,0.45746624,1.612337,,,,,,,,,,,,,,,,, -90200,0.43612114,1.6271842,,,,,,,,,,,,,,,,, -90300,0.4374108,1.6381527,,,,,,,,,,,,,,,,, -90400,0.45991072,1.537393,,,,,,,,,,,,,,,,, -90500,0.46054292,1.5150034,,,,,,,,,,,,,,,,, -90600,0.4839354,1.592269,,,,,,,,,,,,,,,,, -90700,0.46200645,1.663495,,,,,,,,,,,,,,,,, -90800,0.46566483,1.6187682,,,,,,,,,,,,,,,,, -90900,0.44318596,1.5944564,,,,,,,,,,,,,,,,, -91000,0.46203545,1.5794846,,,,,,,,,,,,,,,,, -91100,0.51360023,1.5778251,,,,,,,,,,,,,,,,, -91200,0.46772394,1.6328226,,,,,,,,,,,,,,,,, -91300,0.46376112,1.6254919,,,,,,,,,,,,,,,,, -91400,0.46599102,1.5317094,,,,,,,,,,,,,,,,, -91500,0.45820716,1.542774,,,,,,,,,,,,,,,,, -91600,0.46164513,1.5675852,,,,,,,,,,,,,,,,, -91700,0.48476437,1.6295567,,,,,,,,,,,,,,,,, -91800,0.447773,1.5806667,,,,,,,,,,,,,,,,, -91823,,,0.6688807606697083,1.561652421951294,33.646249972575525,0.6818886399269104,1.4689723253250122,29.931857466564303,3000.0,0.6980187296867371,1.381521701812744,30.353041929433832,3003.0,31953.005130767822,52629.45927906037,31953.005130767822,20672.136009454727,1.2626848220825195,0.0 -91900,0.4560448,1.5298086,,,,,,,,,,,,,,,,, -92000,0.46093985,1.6271347,,,,,,,,,,,,,,,,, -92100,0.4569159,1.5797943,,,,,,,,,,,,,,,,, -92200,0.5100775,1.6577396,,,,,,,,,,,,,,,,, -92300,0.489274,1.5464818,,,,,,,,,,,,,,,,, -92400,0.48240304,1.6669984,,,,,,,,,,,,,,,,, -92500,0.46180242,1.5391523,,,,,,,,,,,,,,,,, -92600,0.4458803,1.5018244,,,,,,,,,,,,,,,,, -92700,0.43319714,1.5989093,,,,,,,,,,,,,,,,, -92800,0.47219032,1.6217405,,,,,,,,,,,,,,,,, -92900,0.4675843,1.6403004,,,,,,,,,,,,,,,,, -93000,0.4903752,1.6236305,,,,,,,,,,,,,,,,, -93100,0.49744862,1.5771242,,,,,,,,,,,,,,,,, -93200,0.4993595,1.5368472,,,,,,,,,,,,,,,,, -93300,0.47692922,1.498166,,,,,,,,,,,,,,,,, -93400,0.48028335,1.5610516,,,,,,,,,,,,,,,,, -93500,0.46222174,1.5056804,,,,,,,,,,,,,,,,, -93600,0.49110377,1.5392787,,,,,,,,,,,,,,,,, -93700,0.4890772,1.5829405,,,,,,,,,,,,,,,,, -93800,0.47407612,1.5033948,,,,,,,,,,,,,,,,, -93900,0.52653944,1.562776,,,,,,,,,,,,,,,,, -94000,0.48266447,1.5915383,,,,,,,,,,,,,,,,, -94100,0.4770353,1.6483227,,,,,,,,,,,,,,,,, -94200,0.49182442,1.5357789,,,,,,,,,,,,,,,,, -94241,,,0.6905794739723206,1.4270117282867432,34.98234446474538,0.6850627660751343,1.4553813934326172,30.200363953398387,3000.0,0.699785053730011,1.3779276609420776,30.07975448373041,3003.0,32793.152134656906,54029.85724711418,32793.152134656906,21232.269668102264,1.301042079925537,0.0 -94300,0.4628757,1.5000339,,,,,,,,,,,,,,,,, -94400,0.47794706,1.530123,,,,,,,,,,,,,,,,, -94500,0.47323057,1.5229903,,,,,,,,,,,,,,,,, -94600,0.48651868,1.5637673,,,,,,,,,,,,,,,,, -94700,0.4757991,1.5313169,,,,,,,,,,,,,,,,, -94800,0.5165444,1.5400667,,,,,,,,,,,,,,,,, -94900,0.4605301,1.6196635,,,,,,,,,,,,,,,,, -95000,0.47082633,1.5208275,,,,,,,,,,,,,,,,, -95100,0.48077357,1.5196694,,,,,,,,,,,,,,,,, -95200,0.4571251,1.5469916,,,,,,,,,,,,,,,,, -95300,0.47235587,1.6201886,,,,,,,,,,,,,,,,, -95400,0.4907115,1.5586221,,,,,,,,,,,,,,,,, -95500,0.49238372,1.5314894,,,,,,,,,,,,,,,,, -95600,0.50727016,1.5542513,,,,,,,,,,,,,,,,, -95700,0.48176193,1.6206968,,,,,,,,,,,,,,,,, -95800,0.4806561,1.4836483,,,,,,,,,,,,,,,,, -95900,0.48635054,1.4924592,,,,,,,,,,,,,,,,, -96000,0.49439755,1.5733533,,,,,,,,,,,,,,,,, -96100,0.49648777,1.5735037,,,,,,,,,,,,,,,,, -96200,0.49334946,1.4346495,,,,,,,,,,,,,,,,, -96300,0.5070262,1.540755,,,,,,,,,,,,,,,,, -96400,0.48740068,1.5163256,,,,,,,,,,,,,,,,, -96500,0.51666975,1.5429156,,,,,,,,,,,,,,,,, -96600,0.4908995,1.5947891,,,,,,,,,,,,,,,,, -96658,,,0.6737492084503174,1.5275613069534302,34.108211883515416,0.6844552159309387,1.4521570205688477,30.20829723913238,3000.0,0.6992853283882141,1.3757984638214111,30.073980041317707,3003.0,33633.102585315704,55416.08647465706,33633.102585315704,21778.425297021862,1.342395544052124,0.0 -96700,0.49809694,1.5258862,,,,,,,,,,,,,,,,, -96800,0.5034237,1.5545613,,,,,,,,,,,,,,,,, -96900,0.50728947,1.5508277,,,,,,,,,,,,,,,,, -97000,0.5187353,1.5524544,,,,,,,,,,,,,,,,, -97100,0.51981014,1.474479,,,,,,,,,,,,,,,,, -97200,0.5150759,1.551719,,,,,,,,,,,,,,,,, -97300,0.51117086,1.583073,,,,,,,,,,,,,,,,, -97400,0.5195819,1.5132753,,,,,,,,,,,,,,,,, -97500,0.49465656,1.4771874,,,,,,,,,,,,,,,,, -97600,0.5115466,1.5591891,,,,,,,,,,,,,,,,, -97700,0.49802992,1.535026,,,,,,,,,,,,,,,,, -97800,0.51154435,1.5644062,,,,,,,,,,,,,,,,, -97900,0.51498365,1.5430968,,,,,,,,,,,,,,,,, -98000,0.51191634,1.519753,,,,,,,,,,,,,,,,, -98100,0.5336775,1.5751301,,,,,,,,,,,,,,,,, -98200,0.5165827,1.5147401,,,,,,,,,,,,,,,,, -98300,0.5188703,1.5125428,,,,,,,,,,,,,,,,, -98400,0.48721275,1.4793988,,,,,,,,,,,,,,,,, -98500,0.50263846,1.5208352,,,,,,,,,,,,,,,,, -98600,0.51857764,1.5595098,,,,,,,,,,,,,,,,, -98700,0.5075414,1.5361749,,,,,,,,,,,,,,,,, -98800,0.5209347,1.4637809,,,,,,,,,,,,,,,,, -98900,0.50789034,1.5628383,,,,,,,,,,,,,,,,, -99000,0.5460942,1.4856559,,,,,,,,,,,,,,,,, -99074,,,0.6748491525650024,1.5263543128967283,33.783884560990835,0.6855711340904236,1.4503666162490845,30.292007534617788,3000.0,0.7000755667686462,1.370320200920105,30.03763911580841,3003.0,34472.99043941498,56800.761588811874,34472.99043941498,22323.09084415436,1.3831169605255127,0.0 -99100,0.5358338,1.522864,,,,,,,,,,,,,,,,, -99200,0.50231874,1.4699678,,,,,,,,,,,,,,,,, -99300,0.52530444,1.5411175,,,,,,,,,,,,,,,,, -99400,0.5462143,1.5227994,,,,,,,,,,,,,,,,, -99500,0.52586424,1.4909787,,,,,,,,,,,,,,,,, -99600,0.5308813,1.5073855,,,,,,,,,,,,,,,,, -99700,0.5349687,1.4658682,,,,,,,,,,,,,,,,, -99800,0.521002,1.474941,,,,,,,,,,,,,,,,, -99900,0.55857414,1.5376855,,,,,,,,,,,,,,,,, -100000,0.53814024,1.5318993,,,,,,,,,,,,,,,,, -100100,0.5077718,1.5563543,,,,,,,,,,,,,,,,, -100200,0.5230411,1.5214132,,,,,,,,,,,,,,,,, -100300,0.5243988,1.4703714,,,,,,,,,,,,,,,,, -100400,0.53873706,1.5227814,,,,,,,,,,,,,,,,, -100500,0.54424405,1.525412,,,,,,,,,,,,,,,,, -100600,0.5333535,1.464014,,,,,,,,,,,,,,,,, -100700,0.54209757,1.5059495,,,,,,,,,,,,,,,,, -100800,0.5499294,1.5108612,,,,,,,,,,,,,,,,, -100900,0.526836,1.5252727,,,,,,,,,,,,,,,,, -101000,0.5328109,1.5278913,,,,,,,,,,,,,,,,, -101100,0.52646714,1.4766965,,,,,,,,,,,,,,,,, -101200,0.5578922,1.5017163,,,,,,,,,,,,,,,,, -101300,0.53620166,1.4770046,,,,,,,,,,,,,,,,, -101400,0.5303319,1.5042242,,,,,,,,,,,,,,,,, -101491,,,0.6866604685783386,1.455691695213318,34.48494577502086,0.6869474649429321,1.4450162649154663,30.187936515973128,3000.0,0.7018999457359314,1.3603651523590088,30.280502251710864,3003.0,35312.9445669651,58180.41181755066,35312.9445669651,22862.66453719139,1.4255762100219729,0.0 -101500,0.551607,1.4808255,,,,,,,,,,,,,,,,, -101600,0.5409117,1.5758951,,,,,,,,,,,,,,,,, -101700,0.5422307,1.5067091,,,,,,,,,,,,,,,,, -101800,0.5511967,1.474968,,,,,,,,,,,,,,,,, -101900,0.54838437,1.5914611,,,,,,,,,,,,,,,,, -102000,0.5568134,1.4838761,,,,,,,,,,,,,,,,, -102100,0.5530425,1.4503083,,,,,,,,,,,,,,,,, -102200,0.5417041,1.5342098,,,,,,,,,,,,,,,,, -102300,0.5761054,1.617568,,,,,,,,,,,,,,,,, -102400,0.5581049,1.5684224,,,,,,,,,,,,,,,,, -102500,0.5822287,1.486639,,,,,,,,,,,,,,,,, -102600,0.53847474,1.4287704,,,,,,,,,,,,,,,,, -102700,0.5725727,1.4557562,,,,,,,,,,,,,,,,, -102800,0.5628259,1.4570173,,,,,,,,,,,,,,,,, -102900,0.5409246,1.537896,,,,,,,,,,,,,,,,, -103000,0.5561987,1.4966757,,,,,,,,,,,,,,,,, -103100,0.56969744,1.4511466,,,,,,,,,,,,,,,,, -103200,0.5471651,1.3792543,,,,,,,,,,,,,,,,, -103300,0.55086315,1.4220458,,,,,,,,,,,,,,,,, -103400,0.55865157,1.504094,,,,,,,,,,,,,,,,, -103500,0.5923491,1.4749806,,,,,,,,,,,,,,,,, -103600,0.58984894,1.4742815,,,,,,,,,,,,,,,,, -103700,0.5859511,1.4996861,,,,,,,,,,,,,,,,, -103800,0.58638275,1.4900633,,,,,,,,,,,,,,,,, -103900,0.59476924,1.5652041,,,,,,,,,,,,,,,,, -103909,,,0.6815876960754395,1.492857575416565,34.64913971543467,0.6869598627090454,1.4377763271331787,30.17169848257597,3000.0,0.7025158405303955,1.3520406484603882,30.232446000713,3003.0,36152.97557926178,59683.14560890198,36152.97557926178,23525.2458164692,1.4649641513824463,0.0 -104000,0.5801483,1.5035764,,,,,,,,,,,,,,,,, -104100,0.6146481,1.5113953,,,,,,,,,,,,,,,,, -104200,0.616302,1.5212353,,,,,,,,,,,,,,,,, -104300,0.58298546,1.5397607,,,,,,,,,,,,,,,,, -104400,0.5482006,1.4474561,,,,,,,,,,,,,,,,, -104500,0.574734,1.5128872,,,,,,,,,,,,,,,,, -104600,0.5354935,1.4706767,,,,,,,,,,,,,,,,, -104700,0.5749673,1.5006547,,,,,,,,,,,,,,,,, -104800,0.5888825,1.5341774,,,,,,,,,,,,,,,,, -104900,0.6080857,1.4971005,,,,,,,,,,,,,,,,, -105000,0.573804,1.5348871,,,,,,,,,,,,,,,,, -105100,0.5855444,1.488796,,,,,,,,,,,,,,,,, -105200,0.6022738,1.5483171,,,,,,,,,,,,,,,,, -105300,0.5881076,1.4372663,,,,,,,,,,,,,,,,, -105400,0.56932145,1.4426116,,,,,,,,,,,,,,,,, -105500,0.58831733,1.4553275,,,,,,,,,,,,,,,,, -105600,0.60041326,1.5194428,,,,,,,,,,,,,,,,, -105700,0.565747,1.466074,,,,,,,,,,,,,,,,, -105800,0.6643444,1.5245012,,,,,,,,,,,,,,,,, -105900,0.5896376,1.4887936,,,,,,,,,,,,,,,,, -106000,0.5913592,1.4701428,,,,,,,,,,,,,,,,, -106100,0.5929738,1.5760232,,,,,,,,,,,,,,,,, -106200,0.61411357,1.5103819,,,,,,,,,,,,,,,,, -106300,0.5914821,1.5280976,,,,,,,,,,,,,,,,, -106327,,,0.6837038993835449,1.4776220321655271,35.08743437169032,0.688646137714386,1.4313939809799194,30.522066373997887,3000.0,0.7030619978904724,1.3502763509750366,30.41185965475121,3003.0,36993.1280477047,61087.41273570061,36993.1280477047,24089.24018883705,1.504408836364746,0.0 -106400,0.6136687,1.4892241,,,,,,,,,,,,,,,,, -106500,0.6837772,1.484202,,,,,,,,,,,,,,,,, -106600,0.62979776,1.4346021,,,,,,,,,,,,,,,,, -106700,0.6199839,1.4682906,,,,,,,,,,,,,,,,, -106800,0.5929912,1.5158917,,,,,,,,,,,,,,,,, -106900,0.62228614,1.4517097,,,,,,,,,,,,,,,,, -107000,0.61026627,1.5580151,,,,,,,,,,,,,,,,, -107100,0.5977088,1.4255496,,,,,,,,,,,,,,,,, -107200,0.5859839,1.437171,,,,,,,,,,,,,,,,, -107300,0.601201,1.4486316,,,,,,,,,,,,,,,,, -107400,0.61488676,1.5162128,,,,,,,,,,,,,,,,, -107500,0.63087815,1.4486192,,,,,,,,,,,,,,,,, -107600,0.60010934,1.3839437,,,,,,,,,,,,,,,,, -107700,0.60582054,1.4907707,,,,,,,,,,,,,,,,, -107800,0.60746986,1.441726,,,,,,,,,,,,,,,,, -107900,0.644308,1.4424931,,,,,,,,,,,,,,,,, -108000,0.6213626,1.509982,,,,,,,,,,,,,,,,, -108100,0.61726785,1.4740952,,,,,,,,,,,,,,,,, -108200,0.66035694,1.5227257,,,,,,,,,,,,,,,,, -108300,0.63170755,1.5056682,,,,,,,,,,,,,,,,, -108400,0.6159219,1.4317787,,,,,,,,,,,,,,,,, -108500,0.651014,1.4797916,,,,,,,,,,,,,,,,, -108600,0.6073962,1.4538871,,,,,,,,,,,,,,,,, -108700,0.5974094,1.4249712,,,,,,,,,,,,,,,,, -108745,,,0.6909144520759583,1.4369436502456665,34.72590805081428,0.6895388960838318,1.42962646484375,30.422796842264475,3000.0,0.705420970916748,1.344992756843567,30.55323781090465,3003.0,37833.29610252381,62459.371757268906,37833.29610252381,24620.91137957573,1.5444655418395996,0.0 -108800,0.5962686,1.3974123,,,,,,,,,,,,,,,,, -108900,0.63505054,1.4248315,,,,,,,,,,,,,,,,, -109000,0.64317757,1.52046,,,,,,,,,,,,,,,,, -109100,0.6242322,1.4330027,,,,,,,,,,,,,,,,, -109200,0.63092583,1.457282,,,,,,,,,,,,,,,,, -109300,0.6370616,1.4167577,,,,,,,,,,,,,,,,, -109400,0.6543284,1.378091,,,,,,,,,,,,,,,,, -109500,0.6311457,1.4768226,,,,,,,,,,,,,,,,, -109600,0.6436955,1.4227136,,,,,,,,,,,,,,,,, -109700,0.64057297,1.4443188,,,,,,,,,,,,,,,,, -109800,0.6544805,1.427687,,,,,,,,,,,,,,,,, -109900,0.6696135,1.481041,,,,,,,,,,,,,,,,, -110000,0.6231514,1.3844016,,,,,,,,,,,,,,,,, -110100,0.6468311,1.4615648,,,,,,,,,,,,,,,,, -110200,0.67166495,1.4973863,,,,,,,,,,,,,,,,, -110300,0.64318514,1.4047037,,,,,,,,,,,,,,,,, -110400,0.6427527,1.4265622,,,,,,,,,,,,,,,,, -110500,0.67146605,1.4376856,,,,,,,,,,,,,,,,, -110600,0.6500104,1.3614671,,,,,,,,,,,,,,,,, -110700,0.65290666,1.4161055,,,,,,,,,,,,,,,,, -110800,0.6346326,1.4718577,,,,,,,,,,,,,,,,, -110900,0.6420976,1.4998218,,,,,,,,,,,,,,,,, -111000,0.63516814,1.4817625,,,,,,,,,,,,,,,,, -111100,0.64482003,1.4862194,,,,,,,,,,,,,,,,, -111162,,,0.6862839460372925,1.4570157527923584,35.03915497026832,0.6900472044944763,1.4253438711166382,30.63425878074417,3000.0,0.7066760063171387,1.337507247924805,30.61793987736942,3003.0,38673.18712234497,63836.94434714317,38673.18712234497,25158.474792718887,1.5841474533081057,0.0 -111200,0.6411615,1.4386623,,,,,,,,,,,,,,,,, -111300,0.65545154,1.4057367,,,,,,,,,,,,,,,,, -111400,0.69567263,1.4655854,,,,,,,,,,,,,,,,, -111500,0.6651757,1.4696665,,,,,,,,,,,,,,,,, -111600,0.68337786,1.4487268,,,,,,,,,,,,,,,,, -111700,0.68670815,1.4561757,,,,,,,,,,,,,,,,, -111800,0.68677354,1.3941796,,,,,,,,,,,,,,,,, -111900,0.675604,1.4491608,,,,,,,,,,,,,,,,, -112000,0.6842416,1.3752644,,,,,,,,,,,,,,,,, -112100,0.70128316,1.3949722,,,,,,,,,,,,,,,,, -112200,0.71310914,1.535819,,,,,,,,,,,,,,,,, -112300,0.6740276,1.4763262,,,,,,,,,,,,,,,,, -112400,0.653344,1.38596,,,,,,,,,,,,,,,,, -112500,0.66382754,1.4278598,,,,,,,,,,,,,,,,, -112600,0.7146414,1.445531,,,,,,,,,,,,,,,,, -112700,0.65979654,1.3830472,,,,,,,,,,,,,,,,, -112800,0.6947927,1.4149972,,,,,,,,,,,,,,,,, -112900,0.7138016,1.4373137,,,,,,,,,,,,,,,,, -113000,0.67835855,1.4748453,,,,,,,,,,,,,,,,, -113100,0.6880592,1.3772721,,,,,,,,,,,,,,,,, -113200,0.66986215,1.4649866,,,,,,,,,,,,,,,,, -113300,0.6902088,1.5074627,,,,,,,,,,,,,,,,, -113400,0.7089335,1.4579791,,,,,,,,,,,,,,,,, -113500,0.67479396,1.4358793,,,,,,,,,,,,,,,,, -113580,,,0.703749418258667,1.3690860271453855,36.01579664231886,0.691473126411438,1.4214249849319458,30.438168648771725,3000.0,0.7059555053710938,1.3358170986175537,30.445572866631903,3003.0,39513.24196100235,65231.111525297165,39513.24196100235,25712.46093392372,1.6317753791809082,0.0 -113600,0.68950343,1.4803513,,,,,,,,,,,,,,,,, -113700,0.730396,1.4969689,,,,,,,,,,,,,,,,, -113800,0.70623827,1.4140493,,,,,,,,,,,,,,,,, -113900,0.68934935,1.5040555,,,,,,,,,,,,,,,,, -114000,0.7606879,1.3892814,,,,,,,,,,,,,,,,, -114100,0.69051826,1.3601028,,,,,,,,,,,,,,,,, -114200,0.72162783,1.3627423,,,,,,,,,,,,,,,,, -114300,0.71888334,1.4588817,,,,,,,,,,,,,,,,, -114400,0.73512334,1.4239483,,,,,,,,,,,,,,,,, -114500,0.69653785,1.4056859,,,,,,,,,,,,,,,,, -114600,0.73308957,1.435099,,,,,,,,,,,,,,,,, -114700,0.7268656,1.4882021,,,,,,,,,,,,,,,,, -114800,0.72407466,1.3821876,,,,,,,,,,,,,,,,, -114900,0.68431354,1.4240285,,,,,,,,,,,,,,,,, -115000,0.6930619,1.4164305,,,,,,,,,,,,,,,,, -115100,0.72264284,1.4095293,,,,,,,,,,,,,,,,, -115200,0.7117175,1.3843309,,,,,,,,,,,,,,,,, -115300,0.7197506,1.4490322,,,,,,,,,,,,,,,,, -115400,0.7227246,1.4193218,,,,,,,,,,,,,,,,, -115500,0.7223901,1.3936251,,,,,,,,,,,,,,,,, -115600,0.7458831,1.4032105,,,,,,,,,,,,,,,,, -115700,0.7072479,1.4335225,,,,,,,,,,,,,,,,, -115800,0.6940586,1.3771634,,,,,,,,,,,,,,,,, -115900,0.74061716,1.3731753,,,,,,,,,,,,,,,,, -115998,,,0.6970692873001099,1.398681640625,35.804610108981215,0.691386342048645,1.418184757232666,30.566812091793675,3000.0,0.7085468769073486,1.3254882097244265,30.682219921723902,3003.0,40353.26636815071,66676.8403544426,40353.26636815071,26318.045152664185,1.672149419784546,0.0 -116000,0.7096799,1.4509497,,,,,,,,,,,,,,,,, -116100,0.7081537,1.3598986,,,,,,,,,,,,,,,,, -116200,0.7300078,1.4012555,,,,,,,,,,,,,,,,, -116300,0.71888596,1.3565656,,,,,,,,,,,,,,,,, -116400,0.73240274,1.3715336,,,,,,,,,,,,,,,,, -116500,0.72187066,1.4693408,,,,,,,,,,,,,,,,, -116600,0.7712998,1.4490894,,,,,,,,,,,,,,,,, -116700,0.7417946,1.3818948,,,,,,,,,,,,,,,,, -116800,0.74001306,1.4745989,,,,,,,,,,,,,,,,, -116900,0.7418634,1.4584491,,,,,,,,,,,,,,,,, -117000,0.74486744,1.4578811,,,,,,,,,,,,,,,,, -117100,0.7445728,1.3762797,,,,,,,,,,,,,,,,, -117200,0.7540476,1.4333684,,,,,,,,,,,,,,,,, -117300,0.72846603,1.3441417,,,,,,,,,,,,,,,,, -117400,0.73556775,1.49037,,,,,,,,,,,,,,,,, -117500,0.7359993,1.4384555,,,,,,,,,,,,,,,,, -117600,0.7277457,1.3670899,,,,,,,,,,,,,,,,, -117700,0.75754666,1.5084845,,,,,,,,,,,,,,,,, -117800,0.77951974,1.4338324,,,,,,,,,,,,,,,,, -117900,0.74224156,1.3601255,,,,,,,,,,,,,,,,, -118000,0.744381,1.3420993,,,,,,,,,,,,,,,,, -118100,0.75148505,1.3524293,,,,,,,,,,,,,,,,, -118200,0.76085246,1.3960918,,,,,,,,,,,,,,,,, -118300,0.72396016,1.3998374,,,,,,,,,,,,,,,,, -118400,0.7502484,1.3911353,,,,,,,,,,,,,,,,, -118415,,,0.6983740925788879,1.3953670263290403,35.78991389340325,0.6922046542167664,1.4161663055419922,30.73623908706971,3000.0,0.7084422707557678,1.3272607326507568,30.9174266863902,3003.0,41193.21762943268,68089.17468738556,41193.21762943268,26890.30124187469,1.715043306350708,0.0 -118500,0.75850254,1.4022676,,,,,,,,,,,,,,,,, -118600,0.7466006,1.4295716,,,,,,,,,,,,,,,,, -118700,0.771793,1.4667702,,,,,,,,,,,,,,,,, -118800,0.7357999,1.3821406,,,,,,,,,,,,,,,,, -118900,0.7581029,1.3997176,,,,,,,,,,,,,,,,, -119000,0.7521148,1.3826036,,,,,,,,,,,,,,,,, -119100,0.83634853,1.3646042,,,,,,,,,,,,,,,,, -119200,0.7591844,1.332981,,,,,,,,,,,,,,,,, -119300,0.75583583,1.3493829,,,,,,,,,,,,,,,,, -119400,0.7334419,1.3595307,,,,,,,,,,,,,,,,, -119500,0.76800865,1.3440098,,,,,,,,,,,,,,,,, -119600,0.7754286,1.347245,,,,,,,,,,,,,,,,, -119700,0.7890078,1.4195731,,,,,,,,,,,,,,,,, -119800,0.7666414,1.3780782,,,,,,,,,,,,,,,,, -119900,0.77593064,1.3532509,,,,,,,,,,,,,,,,, -120000,0.7489685,1.3706856,,,,,,,,,,,,,,,,, -120100,0.7636986,1.4499443,,,,,,,,,,,,,,,,, -120200,0.7745963,1.4045117,,,,,,,,,,,,,,,,, -120300,0.78156203,1.3673804,,,,,,,,,,,,,,,,, -120400,0.8125042,1.442434,,,,,,,,,,,,,,,,, -120500,0.7835555,1.3616972,,,,,,,,,,,,,,,,, -120600,0.75701785,1.4144093,,,,,,,,,,,,,,,,, -120700,0.75903636,1.3312066,,,,,,,,,,,,,,,,, -120800,0.7439198,1.3348783,,,,,,,,,,,,,,,,, -120833,,,0.7079231142997742,1.3452441692352295,36.54008424706086,0.69246506690979,1.4150506258010864,30.88375035702032,3000.0,0.7084422707557678,1.324387550354004,30.767148916451248,3003.0,42033.18365478516,69469.79971814156,42033.18365478516,27430.756391763687,1.8395373821258545,0.0 -120900,0.79226655,1.3699396,,,,,,,,,,,,,,,,, -121000,0.7769003,1.364297,,,,,,,,,,,,,,,,, -121100,0.77162904,1.3625939,,,,,,,,,,,,,,,,, -121200,0.7857236,1.3898457,,,,,,,,,,,,,,,,, -121300,0.75974035,1.4052598,,,,,,,,,,,,,,,,, -121400,0.8089363,1.4016798,,,,,,,,,,,,,,,,, -121500,0.7938926,1.3996593,,,,,,,,,,,,,,,,, -121600,0.7602004,1.3931836,,,,,,,,,,,,,,,,, -121700,0.78774315,1.3549707,,,,,,,,,,,,,,,,, -121800,0.7904765,1.4478499,,,,,,,,,,,,,,,,, -121900,0.76659846,1.4289936,,,,,,,,,,,,,,,,, -122000,0.7854305,1.3740699,,,,,,,,,,,,,,,,, -122100,0.7725148,1.3534817,,,,,,,,,,,,,,,,, -122200,0.78219426,1.4568317,,,,,,,,,,,,,,,,, -122300,0.8023246,1.3786622,,,,,,,,,,,,,,,,, -122400,0.7845223,1.3849193,,,,,,,,,,,,,,,,, -122500,0.81041515,1.3618125,,,,,,,,,,,,,,,,, -122600,0.81174386,1.3835117,,,,,,,,,,,,,,,,, -122700,0.803146,1.4323636,,,,,,,,,,,,,,,,, -122800,0.7910185,1.3517659,,,,,,,,,,,,,,,,, -122900,0.8118909,1.3239868,,,,,,,,,,,,,,,,, -123000,0.7788463,1.3933535,,,,,,,,,,,,,,,,, -123100,0.7922045,1.3817457,,,,,,,,,,,,,,,,, -123200,0.8148978,1.3238684,,,,,,,,,,,,,,,,, -123251,,,0.7070516347885132,1.346768140792847,36.243165574604014,0.6931470036506653,1.413427233695984,30.778888404449745,3000.0,0.7096391916275024,1.320741891860962,30.74864159068372,3003.0,42873.383053064346,70842.37611746788,42873.383053064346,27963.00019860268,1.8919637203216555,0.0 -123300,0.7862018,1.3613575,,,,,,,,,,,,,,,,, -123400,0.78007185,1.4191834,,,,,,,,,,,,,,,,, -123500,0.83114487,1.4015868,,,,,,,,,,,,,,,,, -123600,0.7790446,1.4138818,,,,,,,,,,,,,,,,, -123700,0.7797847,1.3960773,,,,,,,,,,,,,,,,, -123800,0.7763109,1.3211138,,,,,,,,,,,,,,,,, -123900,0.789275,1.3883696,,,,,,,,,,,,,,,,, -124000,0.81843793,1.3734998,,,,,,,,,,,,,,,,, -124100,0.809514,1.391055,,,,,,,,,,,,,,,,, -124200,0.7985261,1.3544302,,,,,,,,,,,,,,,,, -124300,0.7919869,1.3758564,,,,,,,,,,,,,,,,, -124400,0.7977073,1.339851,,,,,,,,,,,,,,,,, -124500,0.78819126,1.3413681,,,,,,,,,,,,,,,,, -124600,0.81165266,1.4060183,,,,,,,,,,,,,,,,, -124700,0.7918804,1.363173,,,,,,,,,,,,,,,,, -124800,0.8421834,1.3450543,,,,,,,,,,,,,,,,, -124900,0.8353496,1.3726294,,,,,,,,,,,,,,,,, -125000,0.81402326,1.4176092,,,,,,,,,,,,,,,,, -125100,0.78348964,1.4011323,,,,,,,,,,,,,,,,, -125200,0.8122943,1.3406011,,,,,,,,,,,,,,,,, -125300,0.79977775,1.2964925,,,,,,,,,,,,,,,,, -125400,0.81045306,1.3927672,,,,,,,,,,,,,,,,, -125500,0.7962996,1.3426884,,,,,,,,,,,,,,,,, -125600,0.81372356,1.4036353,,,,,,,,,,,,,,,,, -125668,,,0.710022509098053,1.332497239112854,36.40957170008483,0.6937421560287476,1.4104636907577517,30.8876838523228,3000.0,0.709337055683136,1.3211370706558228,30.857238354469462,3003.0,43713.294801950455,72267.84021043777,43713.294801950455,28548.430537700653,1.933983564376831,0.0 -125700,0.81285816,1.368059,,,,,,,,,,,,,,,,, -125800,0.7991281,1.3741785,,,,,,,,,,,,,,,,, -125900,0.8181255,1.3271158,,,,,,,,,,,,,,,,, -126000,0.7847876,1.3446453,,,,,,,,,,,,,,,,, -126100,0.77917254,1.3853307,,,,,,,,,,,,,,,,, -126200,0.79313976,1.3532683,,,,,,,,,,,,,,,,, -126300,0.7832494,1.3198159,,,,,,,,,,,,,,,,, -126400,0.78203714,1.3256884,,,,,,,,,,,,,,,,, -126500,0.80291283,1.403834,,,,,,,,,,,,,,,,, -126600,0.811046,1.3823141,,,,,,,,,,,,,,,,, -126700,0.7709989,1.308299,,,,,,,,,,,,,,,,, -126800,0.8012928,1.3381108,,,,,,,,,,,,,,,,, -126900,0.78805983,1.3231827,,,,,,,,,,,,,,,,, -127000,0.7843977,1.3334911,,,,,,,,,,,,,,,,, -127100,0.80955845,1.3432602,,,,,,,,,,,,,,,,, -127200,0.78072566,1.3526607,,,,,,,,,,,,,,,,, -127300,0.82266307,1.3639355,,,,,,,,,,,,,,,,, -127400,0.7782202,1.34594,,,,,,,,,,,,,,,,, -127500,0.81171864,1.3724955,,,,,,,,,,,,,,,,, -127600,0.80753636,1.3163037,,,,,,,,,,,,,,,,, -127700,0.819937,1.3909035,,,,,,,,,,,,,,,,, -127800,0.80178994,1.329092,,,,,,,,,,,,,,,,, -127900,0.7972311,1.3522954,,,,,,,,,,,,,,,,, -128000,0.8272135,1.3386066,,,,,,,,,,,,,,,,, -128086,,,0.7109796404838562,1.329494595527649,36.99412188193442,0.6932337880134583,1.4097765684127808,30.99008929284226,3000.0,0.7093021869659424,1.319244623184204,30.86260403130315,3003.0,44553.37302994728,73655.06421208382,44553.37302994728,29095.45155930519,1.9780395030975344,0.0 -128100,0.80742484,1.3363047,,,,,,,,,,,,,,,,, -128200,0.766123,1.3370697,,,,,,,,,,,,,,,,, -128300,0.80885726,1.3559498,,,,,,,,,,,,,,,,, -128400,0.79388845,1.3168646,,,,,,,,,,,,,,,,, -128500,0.80526745,1.3486981,,,,,,,,,,,,,,,,, -128600,0.78146964,1.2938793,,,,,,,,,,,,,,,,, -128700,0.79698724,1.3021568,,,,,,,,,,,,,,,,, -128800,0.8063006,1.3916999,,,,,,,,,,,,,,,,, -128900,0.7963213,1.3276976,,,,,,,,,,,,,,,,, -129000,0.8031868,1.3578256,,,,,,,,,,,,,,,,, -129100,0.78373915,1.3501247,,,,,,,,,,,,,,,,, -129200,0.81954885,1.392085,,,,,,,,,,,,,,,,, -129300,0.7957575,1.3805084,,,,,,,,,,,,,,,,, -129400,0.8123947,1.34564,,,,,,,,,,,,,,,,, -129500,0.7713551,1.2930348,,,,,,,,,,,,,,,,, -129600,0.7949119,1.3641566,,,,,,,,,,,,,,,,, -129700,0.8584166,1.4525083,,,,,,,,,,,,,,,,, -129800,0.8112218,1.3496023,,,,,,,,,,,,,,,,, -129900,0.7868661,1.2909558,,,,,,,,,,,,,,,,, -130000,0.80124444,1.370586,,,,,,,,,,,,,,,,, -130100,0.7903707,1.3353608,,,,,,,,,,,,,,,,, -130200,0.8199021,1.3590909,,,,,,,,,,,,,,,,, -130300,0.8273688,1.3588995,,,,,,,,,,,,,,,,, -130400,0.80331075,1.3527633,,,,,,,,,,,,,,,,, -130500,0.7905339,1.3451476,,,,,,,,,,,,,,,,, -130505,,,0.7105883359909058,1.3296940326690674,36.643996629084846,0.6933826208114624,1.4100453853607178,30.929830153579665,3000.0,0.710220217704773,1.3187133073806765,30.82152023947141,3003.0,45393.514297008514,75043.40402579308,45393.514297008514,29643.51740694046,2.031567573547364,0.0 -130600,0.8107554,1.3578917,,,,,,,,,,,,,,,,, -130700,0.8136261,1.3567482,,,,,,,,,,,,,,,,, -130800,0.78733104,1.4000738,,,,,,,,,,,,,,,,, -130900,0.7948978,1.3093847,,,,,,,,,,,,,,,,, -131000,0.8256926,1.3451406,,,,,,,,,,,,,,,,, -131100,0.7935686,1.3408995,,,,,,,,,,,,,,,,, -131200,0.82423294,1.2664338,,,,,,,,,,,,,,,,, -131300,0.802205,1.3435001,,,,,,,,,,,,,,,,, -131400,0.7910595,1.3050547,,,,,,,,,,,,,,,,, -131500,0.792261,1.3300753,,,,,,,,,,,,,,,,, -131600,0.82085633,1.4093927,,,,,,,,,,,,,,,,, -131700,0.8106132,1.3925767,,,,,,,,,,,,,,,,, -131800,0.8376152,1.4033139,,,,,,,,,,,,,,,,, -131900,0.8035145,1.3022919,,,,,,,,,,,,,,,,, -132000,0.7976361,1.3186735,,,,,,,,,,,,,,,,, -132100,0.81107247,1.349003,,,,,,,,,,,,,,,,, -132200,0.79442364,1.3594736,,,,,,,,,,,,,,,,, -132300,0.80285037,1.4264809,,,,,,,,,,,,,,,,, -132400,0.7901176,1.3255652,,,,,,,,,,,,,,,,, -132500,0.8170909,1.3323231,,,,,,,,,,,,,,,,, -132600,0.79686016,1.4052395,,,,,,,,,,,,,,,,, -132700,0.8013104,1.2982033,,,,,,,,,,,,,,,,, -132800,0.81653255,1.2538137,,,,,,,,,,,,,,,,, -132900,0.8285873,1.3628513,,,,,,,,,,,,,,,,, -132924,,,0.7117727398872375,1.324880599975586,36.79409669434138,0.6934942007064819,1.4099982976913452,30.80455651715488,3000.0,0.7099761962890625,1.3190022706985474,30.837535923196857,3003.0,46233.72875595093,76443.79782891273,46233.72875595093,30203.5730149746,2.0750794410705566,0.0 -133000,0.8325353,1.3757931,,,,,,,,,,,,,,,,, -133100,0.79107034,1.3590623,,,,,,,,,,,,,,,,, -133200,0.80831903,1.4134985,,,,,,,,,,,,,,,,, -133300,0.7889452,1.3556072,,,,,,,,,,,,,,,,, -133333,,,0.7115277051925659,1.324217438697815,36.53925289250019,0.6935189962387085,1.40999174118042,30.82531657198388,3000.0,0.7099761962890625,1.318996548652649,30.84644380336485,3003.0,46375.332560777664,77133.75478172302,46375.332560777664,30751.867042303085,2.1205337047576904,0.0 -133333,,,,,,,,,,,,,,46375.332560777664,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 5131312ec..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,58 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -860.1454901695251,0.0,29.8894636631012,1,0,29.8894636631012,0.0007088489946909,0.0,10.966498374938965,3003,890.0350096225739,0.0005914963549003,0.0,10.96237564086914,0.0004835649742744,0.0,10.980294227600098,3000 -1357.0244114398956,0.0199313163757324,870.12095952034,2417,0,870.12095952034,0.5345999598503113,17.927984064231914,2.571000099182129,3003,2227.244187355041,0.5345553755760193,23.673126618849807,2.5997135639190674,0.5361867547035217,19.77231194625641,2.562365293502808,3000 -1817.197232961655,0.0494604110717773,1710.169316291809,4833,0,1710.169316291809,0.5962698459625244,21.65287838064829,2.062872171401977,3003,3527.5778126716614,0.5767702460289001,26.59802861754333,2.2277464866638184,0.5945369601249695,23.0944986510416,2.089860439300537,3000 -2296.1606137752533,0.0747115612030029,2550.408034324646,7251,0,2550.408034324646,0.6099936366081238,22.926869041848573,1.969637036323548,3003,4846.885024785996,0.5924220681190491,27.298196480210727,2.1205623149871826,0.6061673164367676,24.03146513922641,2.006256580352783,3000 -2804.2633929252625,0.0998330116271972,3390.572445869446,9669,0,3390.572445869446,0.6165592074394226,23.52521008844378,1.929946780204773,3003,6195.257897377014,0.5942032337188721,27.3323492483322,2.1156461238861084,0.6107797622680664,24.644655251514664,1.9769822359085083,3000 -3388.406713962555,0.1255655288696289,4230.6515011787415,12086,0,4230.6515011787415,0.6239033341407776,24.09416863267773,1.8900692462921145,3003,7619.584226846695,0.5959599018096924,28.17086301658108,2.1070985794067383,0.6141027212142944,24.87685627139683,1.9404385089874268,3000 -3911.032277584076,0.1560111045837402,5070.606368780136,14503,0,5070.606368780136,0.6236360669136047,24.06938099765642,1.8791215419769287,3003,8982.273378133774,0.6016581058502197,28.41499795066516,2.071608066558838,0.6191367506980896,25.37529482175652,1.9150434732437127,3000 -4442.9956386089325,0.1835007667541504,5910.714811086655,16919,0,5910.714811086655,0.6234152913093567,24.18222592701892,1.8667335510253904,3003,10354.45345211029,0.5999884605407715,28.36836530080516,2.073273181915283,0.6177108883857727,25.15550042864411,1.9168777465820312,3000 -4971.604582309723,0.2146503925323486,6750.888104915619,19337,0,6750.888104915619,0.6259601712226868,23.853407099789305,1.8603891134262085,3003,11723.34387397766,0.6061010956764221,28.632670103777563,1.999993920326233,0.6188639998435974,24.93843937756793,1.9008632898330688,3000 -5588.410370588303,0.2451803684234619,7590.796914815903,21754,0,7590.796914815903,0.6301435232162476,24.47606036026552,1.8423911333084104,3003,13180.169267416,0.6028353571891785,28.765118879859944,2.0488171577453613,0.6212570071220398,25.664352122554146,1.893480896949768,3000 -6178.330642223358,0.2723922729492187,8430.839210271835,24172,0,8430.839210271835,0.626308798789978,24.13714878439688,1.852922797203064,3003,14610.239954471588,0.6010239124298096,28.34594869682518,2.058652639389038,0.6217281818389893,25.278200789983146,1.882087469100952,3000 -6655.143758058548,0.3059818744659424,9271.049669265749,26590,0,9271.049669265749,0.6293068528175354,24.118035745694435,1.835338115692139,3003,15927.374853849413,0.6073386073112488,28.424016769886777,2.0060691833496094,0.6224597096443176,25.24002727839257,1.866793274879456,3000 -7179.2492508888245,0.3355967998504638,10111.076085329056,29007,0,10111.076085329056,0.6305851340293884,23.666496475395743,1.827478289604187,3003,17291.617335796356,0.601288914680481,28.23649580127213,2.059950113296509,0.6220133304595947,24.748724901143305,1.878191590309143,3000 -7640.133558750153,0.3651926517486572,10951.275916576384,31426,0,10951.275916576384,0.6326651573181152,24.27457520068237,1.811199426651001,3003,18592.80987429619,0.6311168670654297,29.951347937008773,1.7990312576293943,0.6247039437294006,25.24824983902175,1.865667939186096,3000 -8146.219051361084,0.3942883014678955,11791.334715604782,33843,0,11791.334715604782,0.6364302039146423,24.618080302001683,1.796127200126648,3003,19939.061821222305,0.608905017375946,28.668787281923468,2.006382942199707,0.6290560364723206,26.221402628623725,1.847534775733948,3000 -8679.177055120468,0.4261302947998047,12631.307217121124,36261,0,12631.307217121124,0.6352681517601013,25.25024140268233,1.80757749080658,3003,21312.102447271347,0.6050881147384644,28.62304099873272,2.023306131362915,0.6242327690124512,25.53469168614256,1.8724839687347408,3000 -9177.565561771393,0.4557979106903076,13471.460614442823,38678,0,13471.460614442823,0.6340363621711731,24.727151582813143,1.797541260719299,3003,22650.755472898483,0.6133084893226624,29.0881076266154,1.960720181465149,0.6289196610450745,25.56870350178375,1.840861201286316,3000 -9677.77205824852,0.4883375167846679,14311.619910955427,41096,0,14311.619910955427,0.6390215754508972,24.106862780559368,1.7725989818572998,3003,23991.232154369354,0.6108484864234924,28.27439783242662,1.9878480434417725,0.6305067539215088,25.164706094613813,1.8244279623031616,3000 -10220.935767889025,0.5191919803619385,15151.617607355118,43513,0,15151.617607355118,0.6409505605697632,25.588925953245663,1.772014617919922,3003,25374.50620198249,0.607908308506012,29.054417956878822,2.006985902786255,0.6317094564437866,25.956385935198032,1.8237781524658203,3000 -10751.859727859495,0.5495162010192871,15991.709733009338,45930,0,15991.709733009338,0.6413573026657104,25.74506411049333,1.751164436340332,3003,26745.632090568542,0.6155939102172852,29.238332232257907,1.9596761465072632,0.633432924747467,26.45347452918118,1.8007560968399048,3000 -11253.07631278038,0.5821371078491211,16831.879290819168,48348,0,16831.879290819168,0.6401836276054382,25.13559768501155,1.750659704208374,3003,28087.13062429428,0.6091585755348206,28.95863394471144,1.986215591430664,0.6338297128677368,26.30413678744673,1.803328514099121,3000 -11832.61929321289,0.6142852306365967,17671.927947998047,50765,0,17671.927947998047,0.6443437337875366,25.595392106041817,1.731534719467163,3003,29506.834899187088,0.6214835047721863,29.929704359714087,1.8999727964401243,0.6367310881614685,26.16699754167921,1.78389310836792,3000 -12361.133578777311,0.6469564437866211,18511.936827898026,53183,0,18511.936827898026,0.6461565494537354,25.63268747559185,1.7227002382278442,3003,30875.469685792923,0.6156458258628845,29.42759289127698,1.9492367506027224,0.6360987424850464,26.239123202813865,1.7813689708709717,3000 -12897.13995742798,0.6800875663757324,19352.044857025143,55600,0,19352.044857025143,0.6483876705169678,25.566127278252768,1.7087100744247437,3003,32251.69904208184,0.6179764866828918,28.86257199663916,1.9319465160369875,0.64002925157547,26.359163168974383,1.7602189779281616,3000 -13441.818828821182,0.7141232490539551,20191.96054983139,58017,0,20191.96054983139,0.6493057012557983,26.47097900002062,1.694807529449463,3003,33636.40701699257,0.6243790984153748,30.10999079467282,1.8864920139312744,0.6418147087097168,26.486150866715008,1.7520946264266968,3000 -13985.639262914658,0.752568244934082,21032.15257835388,60436,0,21032.15257835388,0.6511765718460083,25.956983521488905,1.6859363317489624,3003,35020.5366435051,0.6193897128105164,29.84611340870408,1.9212905168533323,0.6405872106552124,26.8364294960986,1.7477840185165403,3000 -14516.028207540512,0.7864320278167725,21872.31868505478,62854,0,21872.31868505478,0.6542444229125977,26.21284965459793,1.6709492206573486,3003,36391.20599746704,0.6402044296264648,31.06420597935536,1.7503422498703003,0.6426454782485962,26.713522805537497,1.7346762418746948,3000 -15101.708109140396,0.8203840255737305,22712.382335186005,65273,0,22712.382335186005,0.6531752943992615,26.230637257259648,1.6693758964538574,3003,37817.06096410751,0.6218117475509644,29.803210788948807,1.9062598943710327,0.6440093517303467,26.75458586334858,1.7179428339004517,3000 -15905.911636829376,0.8546721935272217,23552.53147172928,67690,0,23552.53147172928,0.6579745411872864,26.80649253582942,1.6437774896621704,3003,39461.53255653381,0.6277031898498535,30.25100134313,1.868759274482727,0.6481258869171143,26.14190472558656,1.7066301107406616,3000 -16440.199833631516,0.8966896533966064,24392.43099117279,70107,0,24392.43099117279,0.6586717963218689,26.89588272520431,1.635362982749939,3003,40835.84396624565,0.6321364045143127,30.434628193983656,1.8177008628845213,0.6484978199005127,27.146623495404736,1.7004215717315674,3000 -17030.950835227966,0.933668613433838,25232.470304965973,72525,0,25232.470304965973,0.6619255542755127,26.698471399758624,1.614753007888794,3003,42266.7501308918,0.6318503618240356,30.32549065257292,1.8400778770446773,0.6501221060752869,27.80913835525352,1.6853595972061155,3000 -17649.86233663559,0.9771442413330078,26072.504900217056,74944,0,26072.504900217056,0.6597524881362915,26.434105724544263,1.6147087812423706,3003,43725.81726980209,0.6331945657730103,30.63485098077157,1.8335165977478027,0.6535691022872925,27.427614088747323,1.6698365211486816,3000 -18164.56058192253,1.013715744018555,26912.40016245842,77361,0,26912.40016245842,0.6661786437034607,27.594957648339527,1.5921257734298706,3003,45080.52706384659,0.6339204907417297,30.797597115589767,1.8167153596878047,0.6549205780029297,27.880582151012234,1.665434956550598,3000 -18708.115093946457,1.0589666366577148,27752.500133752823,79779,0,27752.500133752823,0.6673523187637329,27.324396090936787,1.5821276903152466,3003,46464.30451631546,0.6340129375457764,31.369896513488115,1.818768858909607,0.6570656299591064,28.10779161190765,1.6416860818862915,3000 -19247.07579922676,1.102534532546997,28592.67398762703,82197,0,28592.67398762703,0.6687816381454468,27.69355476262512,1.567957639694214,3003,47843.56256365776,0.6473284363746643,31.41526678659185,1.7201316356658936,0.659334659576416,28.270451851715933,1.631327986717224,3000 -19813.59550333023,1.1386573314666748,29432.80449271202,84615,0,29432.80449271202,0.6705595254898071,27.64755868083939,1.5564008951187134,3003,49250.32803225517,0.6412078142166138,30.883998162965828,1.7685593366622925,0.6598802208900452,27.8358965330051,1.6230331659317017,3000 -20362.637778520584,1.175865888595581,30272.858829975128,87033,0,30272.858829975128,0.6758236289024353,28.10549738521797,1.5261303186416626,3003,50639.543586969376,0.6418304443359375,31.223073845288223,1.7551774978637695,0.6620252728462219,28.48434794106564,1.602751851081848,3000 -21066.31170296669,1.2124691009521484,31112.956319332123,89450,0,31112.956319332123,0.6781593561172485,28.19384526539295,1.5138115882873535,3003,52183.43749761581,0.6507663130760193,31.590313155448776,1.6971262693405151,0.6641083359718323,28.41043268094581,1.5904189348220823,3000 -21709.92742419243,1.2507178783416748,31952.85079932213,91866,0,31952.85079932213,0.6778688430786133,28.213751552293576,1.5026555061340332,3003,53667.065616846085,0.6441076993942261,31.650447712787365,1.7398148775100708,0.6667988896369934,28.44722863245684,1.5745916366577148,3000 -22231.343421697617,1.2895457744598389,32792.86261463165,94283,0,32792.86261463165,0.6807971596717834,28.15908149035932,1.4862122535705566,3003,55028.61365580559,0.6625533699989319,33.16702852310345,1.6080623865127563,0.6701342463493347,28.759051738058805,1.558275818824768,3000 -22777.78276062012,1.3282885551452637,33633.10122871399,96701,0,33633.10122871399,0.6864447593688965,29.245032393980694,1.4664220809936523,3003,56415.408299446106,0.6566953063011169,31.842143989658336,1.6568286418914795,0.670741856098175,29.138297350353145,1.5458422899246216,3000 -23282.706238031387,1.3684589862823486,34473.18996334076,99118,0,34473.18996334076,0.6870489716529846,29.00504494574323,1.4522655010223389,3003,57760.54006314278,0.6576961874961853,32.159575954341314,1.6613869667053225,0.6734076142311096,29.481519388764976,1.527236819267273,3000 -23890.701112031937,1.4091482162475586,35313.20211029053,101536,0,35313.20211029053,0.6894544363021851,29.022071865705737,1.432341456413269,3003,59208.66853928566,0.6628363132476807,33.3643248815988,1.61624276638031,0.6767801642417908,29.35418899946777,1.5128036737442017,3000 -24466.42604756356,1.4494218826293943,36153.19030022621,103954,0,36153.19030022621,0.6900238394737244,29.251162714857767,1.422572374343872,3003,60624.50135970116,0.6624199748039246,32.52034005600726,1.615139722824097,0.6787640452384949,29.72840660397787,1.4980305433273315,3000 -25030.5927362442,1.497650384902954,36993.35413503647,106372,0,36993.35413503647,0.6945441961288452,29.731491757638405,1.4052151441574097,3003,62028.96114087105,0.7069824934005737,36.09946513355275,1.3659493923187256,0.6812934875488281,29.75335012117342,1.4833416938781738,3000 -25604.80843281746,1.5368640422821045,37833.25030493736,108790,0,37833.25030493736,0.6963105201721191,29.78246186358487,1.3900195360183716,3003,63443.19142055512,0.6711868047714233,33.68167779038636,1.5594106912612915,0.682409405708313,29.792658094298663,1.4713393449783323,3000 -26251.868897914886,1.5767569541931152,38673.397919654846,111208,0,38673.397919654846,0.6995410323143005,30.15858276774193,1.3734357357025146,3003,64930.52002501488,0.6703004240989685,33.46816417771401,1.5756620168685913,0.683078944683075,29.83414942148012,1.4614923000335691,3000 -26806.072550058365,1.625173568725586,39513.55310797691,113626,0,39513.55310797691,0.7013537883758545,30.219203249138683,1.3622872829437256,3003,66325.00939846039,0.6843309998512268,34.52292679167273,1.4797334671020508,0.6871954202651978,30.258902249238226,1.4451779127120972,3000 -27418.13508272171,1.6672179698944092,40353.66239070892,116044,0,40353.66239070892,0.7033408880233765,30.47281998498783,1.350867509841919,3003,67777.30443549156,0.677942156791687,34.21801980482632,1.5225650072097778,0.6876294016838074,30.542062526187728,1.43970787525177,3000 -27988.293677330017,1.7116749286651611,41193.64802837372,118462,0,41193.64802837372,0.705479085445404,30.77297665042285,1.3375895023345947,3003,69187.57227706909,0.6817564964294434,34.63627435914792,1.5030604600906372,0.6891049146652222,30.23524154073049,1.429511785507202,3000 -28561.33699631691,1.754795789718628,42033.85468816757,120881,0,42033.85468816757,0.7065946459770203,30.69767269998512,1.3328224420547483,3003,70600.94706273079,0.6928209066390991,35.03023370942456,1.4292106628417969,0.6907663941383362,30.38323877714415,1.4212204217910769,3000 -29151.12611794472,1.7970449924468994,42874.04085731506,123300,0,42874.04085731506,0.7075010538101196,30.729561965493104,1.3267334699630735,3003,72031.04406666756,0.692010223865509,35.203263862741345,1.4411240816116333,0.6924154758453369,30.84412796852067,1.4140108823776243,3000 -29742.05605435372,1.842005014419556,43714.04322123528,125718,0,43714.04322123528,0.708349347114563,30.87662334271036,1.3209024667739868,3003,73462.10220861435,0.698666512966156,35.56513313304348,1.3995305299758911,0.6939033269882202,30.6555715813736,1.4084949493408203,3000 -30328.435676574707,1.886183738708496,44554.06897211075,128136,0,44554.06897211075,0.710615336894989,30.94520762179943,1.315718173980713,3003,74888.634370327,0.6984717845916748,35.613569304448035,1.403459548950195,0.6934818029403687,30.81339413073557,1.4072446823120115,3000 -30929.09312582016,1.9304907321929927,45394.16031217575,130555,0,45394.16031217575,0.7104758620262146,30.946965112698592,1.3132108449935913,3003,76329.50495123863,0.7004496455192566,35.82390045814242,1.3945019245147705,0.693221390247345,30.76973079067687,1.4052176475524902,3000 -31533.85618448257,1.974720478057861,46234.09957194328,132973,0,46234.09957194328,0.7101853489875793,31.12658905369212,1.313064455986023,3003,77774.33254957199,0.7006556391716003,35.9691844802764,1.3924342393875122,0.6937049627304077,30.88978323109175,1.4050370454788208,3000 -32123.061244010925,2.0193135738372803,46358.67147278786,133333,0,46358.67147278786,0.7102318406105042,31.113830544833622,1.3130912780761719,3003,78488.16630244255,0.6965222358703613,35.569271445766226,1.4127204418182373,0.6937049627304077,30.89082856801816,1.4050407409667969,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/measurements.csv deleted file mode 100644 index 08deb33c8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1393 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.756458,10.947874,,,,,,,,,,,,,,,,, -1,,,0.0005914963549003,10.96237564086914,0.0,0.0004835649742744,10.980294227600098,0.0,3000.0,0.0007088489946909,10.966498374938965,0.0,3003.0,29.8894636631012,890.0350096225739,29.8894636631012,860.1454901695251,0.0,0.0 -100,0.49193034,7.586776,,,,,,,,,,,,,,,,, -200,0.50852764,6.622317,,,,,,,,,,,,,,,,, -300,0.5207285,5.84456,,,,,,,,,,,,,,,,, -400,0.5993211,5.5002418,,,,,,,,,,,,,,,,, -500,0.40841293,5.0374327,,,,,,,,,,,,,,,,, -600,0.5274712,4.7514114,,,,,,,,,,,,,,,,, -700,0.59246874,4.5574737,,,,,,,,,,,,,,,,, -800,0.41844046,4.2035165,,,,,,,,,,,,,,,,, -900,0.4279014,3.9457226,,,,,,,,,,,,,,,,, -1000,0.41056985,3.7846806,,,,,,,,,,,,,,,,, -1100,0.5646035,3.6510975,,,,,,,,,,,,,,,,, -1200,0.3841698,3.5195632,,,,,,,,,,,,,,,,, -1300,0.4050893,3.4949188,,,,,,,,,,,,,,,,, -1400,0.27583095,3.4144764,,,,,,,,,,,,,,,,, -1500,0.29461387,3.2122705,,,,,,,,,,,,,,,,, -1600,0.22993813,3.1128712,,,,,,,,,,,,,,,,, -1700,0.25217518,3.0168664,,,,,,,,,,,,,,,,, -1800,0.25897545,2.9395494,,,,,,,,,,,,,,,,, -1900,0.26759362,2.8543148,,,,,,,,,,,,,,,,, -2000,0.18490961,2.8232503,,,,,,,,,,,,,,,,, -2100,0.2186267,2.7950945,,,,,,,,,,,,,,,,, -2200,0.30442336,2.7616143,,,,,,,,,,,,,,,,, -2300,0.17249916,2.6637354,,,,,,,,,,,,,,,,, -2400,0.20111692,2.5834095,,,,,,,,,,,,,,,,, -2417,,,0.5345553755760193,2.5997135639190674,23.673126618849807,0.5361867547035217,2.562365293502808,19.77231194625641,3000.0,0.5345999598503113,2.571000099182129,17.927984064231914,3003.0,870.12095952034,2227.244187355041,870.12095952034,1357.0244114398956,0.0199313163757324,0.0 -2500,0.2067844,2.6777682,,,,,,,,,,,,,,,,, -2600,0.19431603,2.4585552,,,,,,,,,,,,,,,,, -2700,0.18567258,2.5042996,,,,,,,,,,,,,,,,, -2800,0.23129262,2.489593,,,,,,,,,,,,,,,,, -2900,0.20998405,2.430954,,,,,,,,,,,,,,,,, -3000,0.27978212,2.442265,,,,,,,,,,,,,,,,, -3100,0.2433598,2.451685,,,,,,,,,,,,,,,,, -3200,0.2505995,2.4142745,,,,,,,,,,,,,,,,, -3300,0.21989681,2.301504,,,,,,,,,,,,,,,,, -3400,0.22951207,2.2986104,,,,,,,,,,,,,,,,, -3500,0.33939448,2.3045506,,,,,,,,,,,,,,,,, -3600,0.21352135,2.3731244,,,,,,,,,,,,,,,,, -3700,0.4288365,2.3264277,,,,,,,,,,,,,,,,, -3800,0.31698346,2.3110218,,,,,,,,,,,,,,,,, -3900,0.36685768,2.2752838,,,,,,,,,,,,,,,,, -4000,0.3691128,2.3080678,,,,,,,,,,,,,,,,, -4100,0.38510755,2.2836401,,,,,,,,,,,,,,,,, -4200,0.400828,2.2645998,,,,,,,,,,,,,,,,, -4300,0.3305625,2.2512915,,,,,,,,,,,,,,,,, -4400,0.27481028,2.226676,,,,,,,,,,,,,,,,, -4500,0.26685447,2.1853101,,,,,,,,,,,,,,,,, -4600,0.3117092,2.2257843,,,,,,,,,,,,,,,,, -4700,0.51152605,2.197021,,,,,,,,,,,,,,,,, -4800,0.42301792,2.1973069,,,,,,,,,,,,,,,,, -4833,,,0.5767702460289001,2.2277464866638184,26.59802861754333,0.5945369601249695,2.089860439300537,23.0944986510416,3000.0,0.5962698459625244,2.062872171401977,21.65287838064829,3003.0,1710.169316291809,3527.5778126716614,1710.169316291809,1817.197232961655,0.0494604110717773,0.0 -4900,0.30344772,2.2344775,,,,,,,,,,,,,,,,, -5000,0.45036545,2.2537346,,,,,,,,,,,,,,,,, -5100,0.29521415,2.175109,,,,,,,,,,,,,,,,, -5200,0.3732838,2.1926384,,,,,,,,,,,,,,,,, -5300,0.30293012,2.1482475,,,,,,,,,,,,,,,,, -5400,0.4411821,2.1837962,,,,,,,,,,,,,,,,, -5500,0.5849624,2.2233818,,,,,,,,,,,,,,,,, -5600,0.50636476,2.2150872,,,,,,,,,,,,,,,,, -5700,0.5974105,2.2103217,,,,,,,,,,,,,,,,, -5800,0.29550806,2.1682248,,,,,,,,,,,,,,,,, -5900,0.2861009,2.1843724,,,,,,,,,,,,,,,,, -6000,0.40653908,2.177291,,,,,,,,,,,,,,,,, -6100,0.2980455,2.2087348,,,,,,,,,,,,,,,,, -6200,0.5589807,2.2384822,,,,,,,,,,,,,,,,, -6300,0.3093427,2.1323555,,,,,,,,,,,,,,,,, -6400,0.42201826,2.163304,,,,,,,,,,,,,,,,, -6500,0.29153764,2.1925108,,,,,,,,,,,,,,,,, -6600,0.3503221,2.2570221,,,,,,,,,,,,,,,,, -6700,0.28211635,2.1681938,,,,,,,,,,,,,,,,, -6800,0.28358296,2.1654088,,,,,,,,,,,,,,,,, -6900,0.43637434,2.2221391,,,,,,,,,,,,,,,,, -7000,0.3016772,2.1740215,,,,,,,,,,,,,,,,, -7100,0.6531628,2.2559085,,,,,,,,,,,,,,,,, -7200,0.32826436,2.1338837,,,,,,,,,,,,,,,,, -7251,,,0.5924220681190491,2.1205623149871826,27.298196480210727,0.6061673164367676,2.006256580352783,24.03146513922641,3000.0,0.6099936366081238,1.969637036323548,22.926869041848573,3003.0,2550.408034324646,4846.885024785996,2550.408034324646,2296.1606137752533,0.0747115612030029,0.0 -7300,0.75347495,2.1062882,,,,,,,,,,,,,,,,, -7400,0.2714791,2.0674465,,,,,,,,,,,,,,,,, -7500,0.5938162,2.1188905,,,,,,,,,,,,,,,,, -7600,0.26037735,2.067335,,,,,,,,,,,,,,,,, -7700,0.49026558,2.1403139,,,,,,,,,,,,,,,,, -7800,0.62244546,2.1763804,,,,,,,,,,,,,,,,, -7900,0.52926064,2.2359178,,,,,,,,,,,,,,,,, -8000,0.30307695,2.095781,,,,,,,,,,,,,,,,, -8100,0.47221616,2.149299,,,,,,,,,,,,,,,,, -8200,0.36154637,2.138695,,,,,,,,,,,,,,,,, -8300,0.3243047,2.117834,,,,,,,,,,,,,,,,, -8400,0.40818676,2.1443994,,,,,,,,,,,,,,,,, -8500,0.44089687,2.016842,,,,,,,,,,,,,,,,, -8600,0.28490815,2.022824,,,,,,,,,,,,,,,,, -8700,0.3246216,2.129172,,,,,,,,,,,,,,,,, -8800,0.51508796,2.0500722,,,,,,,,,,,,,,,,, -8900,0.3007032,2.082357,,,,,,,,,,,,,,,,, -9000,0.71213806,2.218069,,,,,,,,,,,,,,,,, -9100,0.51462317,2.1316912,,,,,,,,,,,,,,,,, -9200,0.30428296,2.2044194,,,,,,,,,,,,,,,,, -9300,0.35641465,2.11137,,,,,,,,,,,,,,,,, -9400,0.306422,2.1762178,,,,,,,,,,,,,,,,, -9500,0.31527,2.0791311,,,,,,,,,,,,,,,,, -9600,0.2933912,2.1008396,,,,,,,,,,,,,,,,, -9669,,,0.5942032337188721,2.1156461238861084,27.3323492483322,0.6107797622680664,1.9769822359085083,24.644655251514664,3000.0,0.6165592074394226,1.929946780204773,23.52521008844378,3003.0,3390.572445869446,6195.257897377014,3390.572445869446,2804.2633929252625,0.0998330116271972,0.0 -9700,0.4156002,2.0144262,,,,,,,,,,,,,,,,, -9800,0.3372372,2.106537,,,,,,,,,,,,,,,,, -9900,0.5544329,2.094477,,,,,,,,,,,,,,,,, -10000,0.2700044,2.0463328,,,,,,,,,,,,,,,,, -10100,0.55561846,2.1968782,,,,,,,,,,,,,,,,, -10200,0.69259953,2.1004798,,,,,,,,,,,,,,,,, -10300,0.40213355,2.103678,,,,,,,,,,,,,,,,, -10400,0.335004,2.0931795,,,,,,,,,,,,,,,,, -10500,0.36502296,2.209657,,,,,,,,,,,,,,,,, -10600,0.6308872,2.117428,,,,,,,,,,,,,,,,, -10700,0.5448846,2.1992793,,,,,,,,,,,,,,,,, -10800,0.4639641,2.1130507,,,,,,,,,,,,,,,,, -10900,0.6418324,2.1078985,,,,,,,,,,,,,,,,, -11000,0.86971647,2.2285926,,,,,,,,,,,,,,,,, -11100,0.27808976,2.0253367,,,,,,,,,,,,,,,,, -11200,0.39284727,2.050614,,,,,,,,,,,,,,,,, -11300,0.28814453,2.1072156,,,,,,,,,,,,,,,,, -11400,0.36165935,2.08873,,,,,,,,,,,,,,,,, -11500,0.29565343,2.0977118,,,,,,,,,,,,,,,,, -11600,0.31173074,2.0734334,,,,,,,,,,,,,,,,, -11700,0.33871973,2.1406527,,,,,,,,,,,,,,,,, -11800,0.47560087,1.9950554,,,,,,,,,,,,,,,,, -11900,0.36688706,2.1226678,,,,,,,,,,,,,,,,, -12000,0.59839183,2.143535,,,,,,,,,,,,,,,,, -12086,,,0.5959599018096924,2.1070985794067383,28.17086301658108,0.6141027212142944,1.9404385089874268,24.87685627139683,3000.0,0.6239033341407776,1.8900692462921145,24.09416863267773,3003.0,4230.6515011787415,7619.584226846695,4230.6515011787415,3388.406713962555,0.1255655288696289,0.0 -12100,0.45718163,2.0340672,,,,,,,,,,,,,,,,, -12200,0.36629733,2.1053953,,,,,,,,,,,,,,,,, -12300,0.43072802,2.0129986,,,,,,,,,,,,,,,,, -12400,0.5498598,2.0653367,,,,,,,,,,,,,,,,, -12500,0.5982058,2.1875036,,,,,,,,,,,,,,,,, -12600,0.336619,2.1079352,,,,,,,,,,,,,,,,, -12700,0.33918852,2.158597,,,,,,,,,,,,,,,,, -12800,0.43533143,2.1202314,,,,,,,,,,,,,,,,, -12900,0.35778603,2.0233767,,,,,,,,,,,,,,,,, -13000,0.5308085,2.0578637,,,,,,,,,,,,,,,,, -13100,0.5937642,2.090917,,,,,,,,,,,,,,,,, -13200,0.2901864,2.0918024,,,,,,,,,,,,,,,,, -13300,0.6852342,2.071467,,,,,,,,,,,,,,,,, -13400,0.46382588,2.1252983,,,,,,,,,,,,,,,,, -13500,0.6552446,2.0747833,,,,,,,,,,,,,,,,, -13600,0.33646005,2.068742,,,,,,,,,,,,,,,,, -13700,0.5894014,2.108417,,,,,,,,,,,,,,,,, -13800,0.3898206,2.0069087,,,,,,,,,,,,,,,,, -13900,0.35408917,2.184597,,,,,,,,,,,,,,,,, -14000,0.46064076,2.0818052,,,,,,,,,,,,,,,,, -14100,0.8261634,2.10244,,,,,,,,,,,,,,,,, -14200,0.7041885,2.1175082,,,,,,,,,,,,,,,,, -14300,0.5473854,2.0179477,,,,,,,,,,,,,,,,, -14400,0.30841443,2.0886278,,,,,,,,,,,,,,,,, -14500,0.27590948,2.0222464,,,,,,,,,,,,,,,,, -14503,,,0.6016581058502197,2.071608066558838,28.41499795066516,0.6191367506980896,1.9150434732437127,25.37529482175652,3000.0,0.6236360669136047,1.8791215419769287,24.06938099765642,3003.0,5070.606368780136,8982.273378133774,5070.606368780136,3911.032277584076,0.1560111045837402,0.0 -14600,0.33486316,2.097501,,,,,,,,,,,,,,,,, -14700,0.9375684,2.1548316,,,,,,,,,,,,,,,,, -14800,0.37796012,2.0690935,,,,,,,,,,,,,,,,, -14900,0.30263424,2.1029227,,,,,,,,,,,,,,,,, -15000,0.2832165,2.0483546,,,,,,,,,,,,,,,,, -15100,0.6690594,2.1402323,,,,,,,,,,,,,,,,, -15200,0.50214255,2.1440876,,,,,,,,,,,,,,,,, -15300,0.8600654,2.1057527,,,,,,,,,,,,,,,,, -15400,0.7472631,2.0410366,,,,,,,,,,,,,,,,, -15500,0.29054543,2.0029001,,,,,,,,,,,,,,,,, -15600,0.4402405,2.110881,,,,,,,,,,,,,,,,, -15700,0.64760804,2.1608353,,,,,,,,,,,,,,,,, -15800,0.7671789,2.0605986,,,,,,,,,,,,,,,,, -15900,0.4978557,2.099394,,,,,,,,,,,,,,,,, -16000,0.32352418,1.9721155,,,,,,,,,,,,,,,,, -16100,0.4246449,2.097308,,,,,,,,,,,,,,,,, -16200,0.3503662,2.0075817,,,,,,,,,,,,,,,,, -16300,0.31456888,2.0236354,,,,,,,,,,,,,,,,, -16400,0.26978987,2.007536,,,,,,,,,,,,,,,,, -16500,0.6083821,2.1411002,,,,,,,,,,,,,,,,, -16600,0.413157,1.9801226,,,,,,,,,,,,,,,,, -16700,0.5462099,2.1423502,,,,,,,,,,,,,,,,, -16800,0.35329658,2.0264988,,,,,,,,,,,,,,,,, -16900,0.394487,2.1192634,,,,,,,,,,,,,,,,, -16919,,,0.5999884605407715,2.073273181915283,28.36836530080516,0.6177108883857727,1.9168777465820312,25.15550042864411,3000.0,0.6234152913093567,1.8667335510253904,24.18222592701892,3003.0,5910.714811086655,10354.45345211029,5910.714811086655,4442.9956386089325,0.1835007667541504,0.0 -17000,0.5829009,2.2097602,,,,,,,,,,,,,,,,, -17100,0.6786325,2.1293018,,,,,,,,,,,,,,,,, -17200,0.34654284,2.0247314,,,,,,,,,,,,,,,,, -17300,0.49707192,2.100647,,,,,,,,,,,,,,,,, -17400,0.45463666,2.0235767,,,,,,,,,,,,,,,,, -17500,0.34008807,2.0744948,,,,,,,,,,,,,,,,, -17600,0.2924609,2.1110497,,,,,,,,,,,,,,,,, -17700,0.4438634,2.0352094,,,,,,,,,,,,,,,,, -17800,0.314274,2.0470026,,,,,,,,,,,,,,,,, -17900,0.4847864,2.0227666,,,,,,,,,,,,,,,,, -18000,0.30230808,2.056835,,,,,,,,,,,,,,,,, -18100,0.50894946,2.006961,,,,,,,,,,,,,,,,, -18200,0.46825963,2.1719608,,,,,,,,,,,,,,,,, -18300,0.21699059,1.9839147,,,,,,,,,,,,,,,,, -18400,0.33166465,2.0546317,,,,,,,,,,,,,,,,, -18500,0.43877184,2.0720088,,,,,,,,,,,,,,,,, -18600,0.75593346,2.0924644,,,,,,,,,,,,,,,,, -18700,0.44540107,2.0049565,,,,,,,,,,,,,,,,, -18800,0.44946963,2.098138,,,,,,,,,,,,,,,,, -18900,0.29561657,1.9771965,,,,,,,,,,,,,,,,, -19000,0.40381438,1.9965135,,,,,,,,,,,,,,,,, -19100,0.31340542,2.0584416,,,,,,,,,,,,,,,,, -19200,0.58610904,2.0380738,,,,,,,,,,,,,,,,, -19300,0.27918223,1.9629365,,,,,,,,,,,,,,,,, -19337,,,0.6061010956764221,1.999993920326233,28.632670103777563,0.6188639998435974,1.9008632898330688,24.93843937756793,3000.0,0.6259601712226868,1.8603891134262085,23.853407099789305,3003.0,6750.888104915619,11723.34387397766,6750.888104915619,4971.604582309723,0.2146503925323486,0.0 -19400,0.59785986,2.1067467,,,,,,,,,,,,,,,,, -19500,0.51185644,1.9709325,,,,,,,,,,,,,,,,, -19600,0.26323965,2.1528068,,,,,,,,,,,,,,,,, -19700,0.41451037,2.0493088,,,,,,,,,,,,,,,,, -19800,0.56442857,2.0379705,,,,,,,,,,,,,,,,, -19900,0.39564514,2.114916,,,,,,,,,,,,,,,,, -20000,0.8136595,2.0513217,,,,,,,,,,,,,,,,, -20100,0.57710683,1.9654198,,,,,,,,,,,,,,,,, -20200,0.28593805,1.9954093,,,,,,,,,,,,,,,,, -20300,0.6045572,2.0974119,,,,,,,,,,,,,,,,, -20400,0.69420034,2.1307588,,,,,,,,,,,,,,,,, -20500,0.39545926,2.007111,,,,,,,,,,,,,,,,, -20600,0.58293563,1.9461086,,,,,,,,,,,,,,,,, -20700,0.36901313,1.9450694,,,,,,,,,,,,,,,,, -20800,0.4997003,2.0548594,,,,,,,,,,,,,,,,, -20900,0.32667127,2.0239036,,,,,,,,,,,,,,,,, -21000,0.2689484,2.0209837,,,,,,,,,,,,,,,,, -21100,0.4081186,2.05714,,,,,,,,,,,,,,,,, -21200,0.34409976,2.0921938,,,,,,,,,,,,,,,,, -21300,0.6059751,2.101228,,,,,,,,,,,,,,,,, -21400,0.44484884,2.0900269,,,,,,,,,,,,,,,,, -21500,0.56219035,2.0620215,,,,,,,,,,,,,,,,, -21600,0.575451,1.9794965,,,,,,,,,,,,,,,,, -21700,0.5014155,1.9434772,,,,,,,,,,,,,,,,, -21754,,,0.6028353571891785,2.0488171577453613,28.765118879859944,0.6212570071220398,1.893480896949768,25.664352122554146,3000.0,0.6301435232162476,1.8423911333084104,24.47606036026552,3003.0,7590.796914815903,13180.169267416,7590.796914815903,5588.410370588303,0.2451803684234619,0.0 -21800,0.63144284,2.0447693,,,,,,,,,,,,,,,,, -21900,0.675928,2.1053605,,,,,,,,,,,,,,,,, -22000,0.3379339,2.1983774,,,,,,,,,,,,,,,,, -22100,0.38917318,2.1403313,,,,,,,,,,,,,,,,, -22200,0.6051576,2.1198027,,,,,,,,,,,,,,,,, -22300,0.31846657,2.0760968,,,,,,,,,,,,,,,,, -22400,0.3181591,2.1232677,,,,,,,,,,,,,,,,, -22500,0.28061196,2.0650878,,,,,,,,,,,,,,,,, -22600,0.47498247,2.0912757,,,,,,,,,,,,,,,,, -22700,0.5466202,2.0287783,,,,,,,,,,,,,,,,, -22800,0.51215744,2.0057132,,,,,,,,,,,,,,,,, -22900,0.39193437,2.0726745,,,,,,,,,,,,,,,,, -23000,0.324957,2.0284178,,,,,,,,,,,,,,,,, -23100,0.49326316,2.0247912,,,,,,,,,,,,,,,,, -23200,0.6816258,2.0818157,,,,,,,,,,,,,,,,, -23300,0.2876593,2.0343072,,,,,,,,,,,,,,,,, -23400,0.37902585,2.0975335,,,,,,,,,,,,,,,,, -23500,0.28963497,2.020553,,,,,,,,,,,,,,,,, -23600,0.6580409,2.1226995,,,,,,,,,,,,,,,,, -23700,0.2556632,2.0078325,,,,,,,,,,,,,,,,, -23800,0.73589754,2.029905,,,,,,,,,,,,,,,,, -23900,0.80419797,2.0722418,,,,,,,,,,,,,,,,, -24000,0.27175805,2.0375469,,,,,,,,,,,,,,,,, -24100,0.32678434,2.0456295,,,,,,,,,,,,,,,,, -24172,,,0.6010239124298096,2.058652639389038,28.34594869682518,0.6217281818389893,1.882087469100952,25.278200789983146,3000.0,0.626308798789978,1.852922797203064,24.13714878439688,3003.0,8430.839210271835,14610.239954471588,8430.839210271835,6178.330642223358,0.2723922729492187,0.0 -24200,0.51073605,2.0497947,,,,,,,,,,,,,,,,, -24300,0.4712131,2.0484354,,,,,,,,,,,,,,,,, -24400,0.4273268,1.9628868,,,,,,,,,,,,,,,,, -24500,0.33816406,1.9805019,,,,,,,,,,,,,,,,, -24600,0.57652426,2.0819755,,,,,,,,,,,,,,,,, -24700,0.30798975,1.9923427,,,,,,,,,,,,,,,,, -24800,0.39369676,2.0161138,,,,,,,,,,,,,,,,, -24900,0.55424255,2.0999358,,,,,,,,,,,,,,,,, -25000,0.39926183,1.9397503,,,,,,,,,,,,,,,,, -25100,0.46491188,2.0612628,,,,,,,,,,,,,,,,, -25200,0.44497547,2.1014984,,,,,,,,,,,,,,,,, -25300,0.45004842,2.0007057,,,,,,,,,,,,,,,,, -25400,0.27344307,1.9955835,,,,,,,,,,,,,,,,, -25500,0.3812961,1.9545515,,,,,,,,,,,,,,,,, -25600,0.5638188,2.0035727,,,,,,,,,,,,,,,,, -25700,0.32396328,2.0163612,,,,,,,,,,,,,,,,, -25800,0.3716823,2.0861347,,,,,,,,,,,,,,,,, -25900,0.7087125,1.9726508,,,,,,,,,,,,,,,,, -26000,0.4763983,2.0373027,,,,,,,,,,,,,,,,, -26100,0.6739032,2.07149,,,,,,,,,,,,,,,,, -26200,0.7837888,2.0757396,,,,,,,,,,,,,,,,, -26300,0.6120814,2.036185,,,,,,,,,,,,,,,,, -26400,0.46615934,2.020129,,,,,,,,,,,,,,,,, -26500,0.26631466,2.0243485,,,,,,,,,,,,,,,,, -26590,,,0.6073386073112488,2.0060691833496094,28.424016769886777,0.6224597096443176,1.866793274879456,25.24002727839257,3000.0,0.6293068528175354,1.835338115692139,24.118035745694435,3003.0,9271.049669265749,15927.374853849413,9271.049669265749,6655.143758058548,0.3059818744659424,0.0 -26600,0.36799547,2.0443425,,,,,,,,,,,,,,,,, -26700,0.43001842,1.974011,,,,,,,,,,,,,,,,, -26800,0.2569169,2.025503,,,,,,,,,,,,,,,,, -26900,0.46858814,2.0318315,,,,,,,,,,,,,,,,, -27000,0.4237067,2.03544,,,,,,,,,,,,,,,,, -27100,0.4154652,1.9815112,,,,,,,,,,,,,,,,, -27200,0.37554544,2.049875,,,,,,,,,,,,,,,,, -27300,0.5871642,2.0646992,,,,,,,,,,,,,,,,, -27400,0.3550321,1.9954976,,,,,,,,,,,,,,,,, -27500,0.35116497,2.1107988,,,,,,,,,,,,,,,,, -27600,0.27096125,2.0606794,,,,,,,,,,,,,,,,, -27700,0.44706622,2.0988846,,,,,,,,,,,,,,,,, -27800,0.4715447,2.0170946,,,,,,,,,,,,,,,,, -27900,0.48434842,2.0666482,,,,,,,,,,,,,,,,, -28000,0.5806096,2.0707304,,,,,,,,,,,,,,,,, -28100,0.6558419,2.057869,,,,,,,,,,,,,,,,, -28200,0.3279948,1.9552946,,,,,,,,,,,,,,,,, -28300,0.44417873,1.9237055,,,,,,,,,,,,,,,,, -28400,0.379427,2.0303388,,,,,,,,,,,,,,,,, -28500,0.40205643,2.0420165,,,,,,,,,,,,,,,,, -28600,0.4205677,1.9284706,,,,,,,,,,,,,,,,, -28700,0.28261745,2.0486035,,,,,,,,,,,,,,,,, -28800,0.49476784,2.010732,,,,,,,,,,,,,,,,, -28900,0.2971744,2.059888,,,,,,,,,,,,,,,,, -29000,0.51286256,2.0845237,,,,,,,,,,,,,,,,, -29007,,,0.601288914680481,2.059950113296509,28.23649580127213,0.6220133304595947,1.878191590309143,24.748724901143305,3000.0,0.6305851340293884,1.827478289604187,23.666496475395743,3003.0,10111.076085329056,17291.617335796356,10111.076085329056,7179.2492508888245,0.3355967998504638,0.0 -29100,0.24996498,2.0170178,,,,,,,,,,,,,,,,, -29200,0.6496458,2.0359764,,,,,,,,,,,,,,,,, -29300,0.4580817,2.0720265,,,,,,,,,,,,,,,,, -29400,1.0089567,2.0499637,,,,,,,,,,,,,,,,, -29500,0.29476902,1.9115491,,,,,,,,,,,,,,,,, -29600,0.35500446,1.9680407,,,,,,,,,,,,,,,,, -29700,0.31902644,2.0861166,,,,,,,,,,,,,,,,, -29800,0.30003336,2.0558467,,,,,,,,,,,,,,,,, -29900,0.45201486,2.0337112,,,,,,,,,,,,,,,,, -30000,0.46929568,2.020635,,,,,,,,,,,,,,,,, -30100,0.4481421,2.025457,,,,,,,,,,,,,,,,, -30200,0.8698949,2.0373945,,,,,,,,,,,,,,,,, -30300,0.26258218,1.9820251,,,,,,,,,,,,,,,,, -30400,0.37917575,1.9897102,,,,,,,,,,,,,,,,, -30500,0.5072568,2.007388,,,,,,,,,,,,,,,,, -30600,0.4004106,2.0433521,,,,,,,,,,,,,,,,, -30700,0.27592182,1.9965206,,,,,,,,,,,,,,,,, -30800,0.33147967,1.968027,,,,,,,,,,,,,,,,, -30900,0.31084585,1.9668524,,,,,,,,,,,,,,,,, -31000,0.24649015,1.928625,,,,,,,,,,,,,,,,, -31100,0.2641066,1.9966935,,,,,,,,,,,,,,,,, -31200,0.31676674,2.0513384,,,,,,,,,,,,,,,,, -31300,0.450014,2.0575624,,,,,,,,,,,,,,,,, -31400,0.27212253,2.0545077,,,,,,,,,,,,,,,,, -31426,,,0.6311168670654297,1.7990312576293943,29.951347937008773,0.6247039437294006,1.865667939186096,25.24824983902175,3000.0,0.6326651573181152,1.811199426651001,24.27457520068237,3003.0,10951.275916576384,18592.80987429619,10951.275916576384,7640.133558750153,0.3651926517486572,0.0 -31500,0.31979877,2.0399752,,,,,,,,,,,,,,,,, -31600,0.23338789,1.9661536,,,,,,,,,,,,,,,,, -31700,0.63895553,2.078977,,,,,,,,,,,,,,,,, -31800,0.32119334,2.0750253,,,,,,,,,,,,,,,,, -31900,0.28926557,2.06648,,,,,,,,,,,,,,,,, -32000,0.31893685,2.0435135,,,,,,,,,,,,,,,,, -32100,0.3694942,2.057232,,,,,,,,,,,,,,,,, -32200,0.6795909,2.0170724,,,,,,,,,,,,,,,,, -32300,0.5546234,1.9285036,,,,,,,,,,,,,,,,, -32400,0.38968197,2.0469725,,,,,,,,,,,,,,,,, -32500,0.31826732,1.8645287,,,,,,,,,,,,,,,,, -32600,0.58460903,2.0340397,,,,,,,,,,,,,,,,, -32700,0.43506587,2.0536702,,,,,,,,,,,,,,,,, -32800,0.46947405,2.114509,,,,,,,,,,,,,,,,, -32900,0.6601081,1.9858959,,,,,,,,,,,,,,,,, -33000,0.26071104,1.9274788,,,,,,,,,,,,,,,,, -33100,0.403269,2.0287716,,,,,,,,,,,,,,,,, -33200,0.47028822,2.0287788,,,,,,,,,,,,,,,,, -33300,0.64824045,1.9218055,,,,,,,,,,,,,,,,, -33400,0.5009099,2.0401337,,,,,,,,,,,,,,,,, -33500,0.5434334,1.9993067,,,,,,,,,,,,,,,,, -33600,0.30956557,2.0750575,,,,,,,,,,,,,,,,, -33700,0.69562835,2.0669732,,,,,,,,,,,,,,,,, -33800,0.40754166,2.009743,,,,,,,,,,,,,,,,, -33843,,,0.608905017375946,2.006382942199707,28.668787281923468,0.6290560364723206,1.847534775733948,26.221402628623725,3000.0,0.6364302039146423,1.796127200126648,24.618080302001683,3003.0,11791.334715604782,19939.061821222305,11791.334715604782,8146.219051361084,0.3942883014678955,0.0 -33900,0.568473,2.0671954,,,,,,,,,,,,,,,,, -34000,0.4753338,1.9964288,,,,,,,,,,,,,,,,, -34100,0.36579382,2.0725293,,,,,,,,,,,,,,,,, -34200,0.31354666,2.021091,,,,,,,,,,,,,,,,, -34300,0.4131147,2.0253048,,,,,,,,,,,,,,,,, -34400,0.5411668,1.9626135,,,,,,,,,,,,,,,,, -34500,0.28246376,1.9881223,,,,,,,,,,,,,,,,, -34600,0.28462818,2.012112,,,,,,,,,,,,,,,,, -34700,0.36084563,1.9877785,,,,,,,,,,,,,,,,, -34800,0.26941964,1.8700881,,,,,,,,,,,,,,,,, -34900,0.36708915,1.9922656,,,,,,,,,,,,,,,,, -35000,0.3132351,1.9587641,,,,,,,,,,,,,,,,, -35100,0.3764255,1.9787636,,,,,,,,,,,,,,,,, -35200,0.3384179,2.0113034,,,,,,,,,,,,,,,,, -35300,0.26371032,1.9184142,,,,,,,,,,,,,,,,, -35400,0.26920918,2.0314221,,,,,,,,,,,,,,,,, -35500,0.53153586,2.0399096,,,,,,,,,,,,,,,,, -35600,0.43960506,2.0048826,,,,,,,,,,,,,,,,, -35700,0.44944653,1.9093071,,,,,,,,,,,,,,,,, -35800,0.44114226,2.046785,,,,,,,,,,,,,,,,, -35900,0.48917136,2.0728822,,,,,,,,,,,,,,,,, -36000,0.6161522,1.9389616,,,,,,,,,,,,,,,,, -36100,0.3313064,1.9852171,,,,,,,,,,,,,,,,, -36200,0.41024843,1.9834822,,,,,,,,,,,,,,,,, -36261,,,0.6050881147384644,2.023306131362915,28.62304099873272,0.6242327690124512,1.8724839687347408,25.53469168614256,3000.0,0.6352681517601013,1.80757749080658,25.25024140268233,3003.0,12631.307217121124,21312.102447271347,12631.307217121124,8679.177055120468,0.4261302947998047,0.0 -36300,0.5071763,1.970167,,,,,,,,,,,,,,,,, -36400,0.26292282,1.9572097,,,,,,,,,,,,,,,,, -36500,0.32956412,1.9402431,,,,,,,,,,,,,,,,, -36600,0.5642104,1.9363891,,,,,,,,,,,,,,,,, -36700,0.26541072,1.9252002,,,,,,,,,,,,,,,,, -36800,0.35974115,2.01224,,,,,,,,,,,,,,,,, -36900,0.96037024,2.0709128,,,,,,,,,,,,,,,,, -37000,0.35257527,1.9622467,,,,,,,,,,,,,,,,, -37100,0.7986452,2.0273747,,,,,,,,,,,,,,,,, -37200,0.51039034,1.9804422,,,,,,,,,,,,,,,,, -37300,0.30535343,2.0473924,,,,,,,,,,,,,,,,, -37400,0.32436493,1.9828955,,,,,,,,,,,,,,,,, -37500,0.24571788,1.9893254,,,,,,,,,,,,,,,,, -37600,0.6444658,2.0072505,,,,,,,,,,,,,,,,, -37700,0.59213907,1.9382664,,,,,,,,,,,,,,,,, -37800,0.32076293,1.9155211,,,,,,,,,,,,,,,,, -37900,0.34497207,2.0821524,,,,,,,,,,,,,,,,, -38000,0.29381588,1.9763789,,,,,,,,,,,,,,,,, -38100,0.44466323,2.0089376,,,,,,,,,,,,,,,,, -38200,0.523779,1.9135762,,,,,,,,,,,,,,,,, -38300,0.5979949,1.9313468,,,,,,,,,,,,,,,,, -38400,0.3268286,2.002725,,,,,,,,,,,,,,,,, -38500,0.38343057,1.9716865,,,,,,,,,,,,,,,,, -38600,0.24102592,1.9053221,,,,,,,,,,,,,,,,, -38678,,,0.6133084893226624,1.960720181465149,29.0881076266154,0.6289196610450745,1.840861201286316,25.56870350178375,3000.0,0.6340363621711731,1.797541260719299,24.727151582813143,3003.0,13471.460614442823,22650.755472898483,13471.460614442823,9177.565561771393,0.4557979106903076,0.0 -38700,0.45329553,1.9574853,,,,,,,,,,,,,,,,, -38800,0.69146097,2.0270512,,,,,,,,,,,,,,,,, -38900,0.53647184,1.96816,,,,,,,,,,,,,,,,, -39000,0.2780735,1.9958835,,,,,,,,,,,,,,,,, -39100,0.4866806,1.939513,,,,,,,,,,,,,,,,, -39200,0.3542796,2.0253017,,,,,,,,,,,,,,,,, -39300,0.34295022,1.943666,,,,,,,,,,,,,,,,, -39400,0.60459363,1.953969,,,,,,,,,,,,,,,,, -39500,0.2894012,1.9979113,,,,,,,,,,,,,,,,, -39600,0.90054625,1.9850233,,,,,,,,,,,,,,,,, -39700,0.4459347,2.0504808,,,,,,,,,,,,,,,,, -39800,0.3861088,2.037365,,,,,,,,,,,,,,,,, -39900,0.7740857,2.0148904,,,,,,,,,,,,,,,,, -40000,0.6519572,1.9688529,,,,,,,,,,,,,,,,, -40100,0.38540977,1.9462625,,,,,,,,,,,,,,,,, -40200,0.32624868,2.024662,,,,,,,,,,,,,,,,, -40300,1.1736239,1.9470919,,,,,,,,,,,,,,,,, -40400,0.38023555,1.9962198,,,,,,,,,,,,,,,,, -40500,1.134645,2.0451875,,,,,,,,,,,,,,,,, -40600,0.33722106,1.960054,,,,,,,,,,,,,,,,, -40700,0.6017666,1.9762214,,,,,,,,,,,,,,,,, -40800,0.6223714,1.9954263,,,,,,,,,,,,,,,,, -40900,0.6602896,2.074742,,,,,,,,,,,,,,,,, -41000,0.42246735,1.9665892,,,,,,,,,,,,,,,,, -41096,,,0.6108484864234924,1.9878480434417725,28.27439783242662,0.6305067539215088,1.8244279623031616,25.164706094613813,3000.0,0.6390215754508972,1.7725989818572998,24.106862780559368,3003.0,14311.619910955427,23991.232154369354,14311.619910955427,9677.77205824852,0.4883375167846679,0.0 -41100,0.5286519,1.9628569,,,,,,,,,,,,,,,,, -41200,0.30631137,2.0124915,,,,,,,,,,,,,,,,, -41300,0.4255459,2.0144658,,,,,,,,,,,,,,,,, -41400,0.33454838,1.9976372,,,,,,,,,,,,,,,,, -41500,0.8975091,1.9233093,,,,,,,,,,,,,,,,, -41600,0.6858968,2.0953588,,,,,,,,,,,,,,,,, -41700,0.3856424,1.9688873,,,,,,,,,,,,,,,,, -41800,0.49514875,1.9946598,,,,,,,,,,,,,,,,, -41900,0.31644595,1.9420043,,,,,,,,,,,,,,,,, -42000,0.4943603,1.9695104,,,,,,,,,,,,,,,,, -42100,0.29419014,1.9215922,,,,,,,,,,,,,,,,, -42200,0.33942464,1.9380116,,,,,,,,,,,,,,,,, -42300,0.26143175,1.9198046,,,,,,,,,,,,,,,,, -42400,0.33565703,1.9512395,,,,,,,,,,,,,,,,, -42500,0.3007746,1.9796262,,,,,,,,,,,,,,,,, -42600,0.6712268,2.020997,,,,,,,,,,,,,,,,, -42700,0.38580713,2.0316484,,,,,,,,,,,,,,,,, -42800,0.45257345,1.927819,,,,,,,,,,,,,,,,, -42900,0.3318487,1.9622827,,,,,,,,,,,,,,,,, -43000,0.9731613,1.9488122,,,,,,,,,,,,,,,,, -43100,0.72736156,1.987735,,,,,,,,,,,,,,,,, -43200,0.51334125,1.9649544,,,,,,,,,,,,,,,,, -43300,0.8448642,2.0746384,,,,,,,,,,,,,,,,, -43400,0.27253824,1.9565002,,,,,,,,,,,,,,,,, -43500,0.28591102,1.8943924,,,,,,,,,,,,,,,,, -43513,,,0.607908308506012,2.006985902786255,29.054417956878822,0.6317094564437866,1.8237781524658203,25.956385935198032,3000.0,0.6409505605697632,1.772014617919922,25.588925953245663,3003.0,15151.617607355118,25374.50620198249,15151.617607355118,10220.935767889025,0.5191919803619385,0.0 -43600,0.3052536,1.9495713,,,,,,,,,,,,,,,,, -43700,0.49765632,1.9916296,,,,,,,,,,,,,,,,, -43800,0.34876096,2.0973926,,,,,,,,,,,,,,,,, -43900,0.31151158,1.9762317,,,,,,,,,,,,,,,,, -44000,0.4863802,1.9705025,,,,,,,,,,,,,,,,, -44100,0.4142425,2.0267048,,,,,,,,,,,,,,,,, -44200,0.6401361,1.9989425,,,,,,,,,,,,,,,,, -44300,0.2822755,2.0267975,,,,,,,,,,,,,,,,, -44400,0.30345738,1.8684577,,,,,,,,,,,,,,,,, -44500,0.3455526,1.9893702,,,,,,,,,,,,,,,,, -44600,0.40377793,1.866855,,,,,,,,,,,,,,,,, -44700,0.59399515,1.9933382,,,,,,,,,,,,,,,,, -44800,0.43237332,1.964635,,,,,,,,,,,,,,,,, -44900,0.28830466,2.104637,,,,,,,,,,,,,,,,, -45000,0.3832237,1.9376521,,,,,,,,,,,,,,,,, -45100,0.36461592,1.9505961,,,,,,,,,,,,,,,,, -45200,0.36355925,1.9088844,,,,,,,,,,,,,,,,, -45300,0.39326793,1.9364874,,,,,,,,,,,,,,,,, -45400,0.3834273,2.0371358,,,,,,,,,,,,,,,,, -45500,0.26077512,1.9950807,,,,,,,,,,,,,,,,, -45600,0.2790257,1.9137028,,,,,,,,,,,,,,,,, -45700,0.2902205,1.9129964,,,,,,,,,,,,,,,,, -45800,0.22808592,1.8449795,,,,,,,,,,,,,,,,, -45900,0.233317,1.9652718,,,,,,,,,,,,,,,,, -45930,,,0.6155939102172852,1.9596761465072632,29.238332232257907,0.633432924747467,1.8007560968399048,26.45347452918118,3000.0,0.6413573026657104,1.751164436340332,25.74506411049333,3003.0,15991.709733009338,26745.632090568542,15991.709733009338,10751.859727859495,0.5495162010192871,0.0 -46000,0.61735976,2.0208209,,,,,,,,,,,,,,,,, -46100,0.41124177,1.9201221,,,,,,,,,,,,,,,,, -46200,0.37837481,1.9095228,,,,,,,,,,,,,,,,, -46300,0.6978169,2.0076594,,,,,,,,,,,,,,,,, -46400,0.41580564,1.9288659,,,,,,,,,,,,,,,,, -46500,0.595466,1.9673848,,,,,,,,,,,,,,,,, -46600,0.3505066,1.9858006,,,,,,,,,,,,,,,,, -46700,0.37640554,1.9515148,,,,,,,,,,,,,,,,, -46800,0.47088787,2.0066361,,,,,,,,,,,,,,,,, -46900,0.5473122,1.9743806,,,,,,,,,,,,,,,,, -47000,0.7007037,1.9218248,,,,,,,,,,,,,,,,, -47100,0.367634,1.9193108,,,,,,,,,,,,,,,,, -47200,0.36440083,1.8681506,,,,,,,,,,,,,,,,, -47300,0.5380952,1.891319,,,,,,,,,,,,,,,,, -47400,0.4571818,1.9582794,,,,,,,,,,,,,,,,, -47500,0.46480307,1.971213,,,,,,,,,,,,,,,,, -47600,0.33817655,1.8844591,,,,,,,,,,,,,,,,, -47700,0.5333931,2.0321484,,,,,,,,,,,,,,,,, -47800,0.51911545,1.9933548,,,,,,,,,,,,,,,,, -47900,0.71092176,2.0474548,,,,,,,,,,,,,,,,, -48000,0.30554113,1.911458,,,,,,,,,,,,,,,,, -48100,0.593334,1.94209,,,,,,,,,,,,,,,,, -48200,0.39225942,1.8971951,,,,,,,,,,,,,,,,, -48300,0.46595532,2.0108113,,,,,,,,,,,,,,,,, -48348,,,0.6091585755348206,1.986215591430664,28.95863394471144,0.6338297128677368,1.803328514099121,26.30413678744673,3000.0,0.6401836276054382,1.750659704208374,25.13559768501155,3003.0,16831.879290819168,28087.13062429428,16831.879290819168,11253.07631278038,0.5821371078491211,0.0 -48400,0.59332955,1.9164109,,,,,,,,,,,,,,,,, -48500,0.28897518,2.0485888,,,,,,,,,,,,,,,,, -48600,0.2865467,1.9075701,,,,,,,,,,,,,,,,, -48700,0.46393865,1.954574,,,,,,,,,,,,,,,,, -48800,0.28425154,1.9072311,,,,,,,,,,,,,,,,, -48900,0.44530702,1.970933,,,,,,,,,,,,,,,,, -49000,0.454078,1.9732219,,,,,,,,,,,,,,,,, -49100,0.3155634,1.913921,,,,,,,,,,,,,,,,, -49200,0.60979074,2.0153682,,,,,,,,,,,,,,,,, -49300,0.36744955,1.9688929,,,,,,,,,,,,,,,,, -49400,0.2725878,1.9099183,,,,,,,,,,,,,,,,, -49500,0.50868744,1.9824771,,,,,,,,,,,,,,,,, -49600,1.0712227,1.9802309,,,,,,,,,,,,,,,,, -49700,0.31855747,1.9097275,,,,,,,,,,,,,,,,, -49800,0.40420893,1.8774855,,,,,,,,,,,,,,,,, -49900,0.39667577,1.927173,,,,,,,,,,,,,,,,, -50000,0.5636227,1.9579495,,,,,,,,,,,,,,,,, -50100,0.36081687,1.9030552,,,,,,,,,,,,,,,,, -50200,0.31255054,1.9080712,,,,,,,,,,,,,,,,, -50300,0.5696509,1.9366542,,,,,,,,,,,,,,,,, -50400,0.25468695,1.9845632,,,,,,,,,,,,,,,,, -50500,0.41278493,1.9962149,,,,,,,,,,,,,,,,, -50600,0.5366282,1.9842672,,,,,,,,,,,,,,,,, -50700,0.4776224,2.0013924,,,,,,,,,,,,,,,,, -50765,,,0.6214835047721863,1.8999727964401243,29.929704359714087,0.6367310881614685,1.78389310836792,26.16699754167921,3000.0,0.6443437337875366,1.731534719467163,25.595392106041817,3003.0,17671.927947998047,29506.834899187088,17671.927947998047,11832.61929321289,0.6142852306365967,0.0 -50800,0.4797599,1.9329977,,,,,,,,,,,,,,,,, -50900,0.49147666,1.9841617,,,,,,,,,,,,,,,,, -51000,0.28616205,1.9063474,,,,,,,,,,,,,,,,, -51100,0.36391723,1.9488152,,,,,,,,,,,,,,,,, -51200,0.5008929,1.9496279,,,,,,,,,,,,,,,,, -51300,0.4015326,2.0059288,,,,,,,,,,,,,,,,, -51400,0.4127017,1.929769,,,,,,,,,,,,,,,,, -51500,0.40450788,1.9471813,,,,,,,,,,,,,,,,, -51600,0.3263806,1.8837614,,,,,,,,,,,,,,,,, -51700,0.34190482,2.0182264,,,,,,,,,,,,,,,,, -51800,0.3411849,1.9652302,,,,,,,,,,,,,,,,, -51900,0.6400427,1.8752974,,,,,,,,,,,,,,,,, -52000,0.85650235,1.9930347,,,,,,,,,,,,,,,,, -52100,0.520312,1.9588209,,,,,,,,,,,,,,,,, -52200,0.7684856,1.901615,,,,,,,,,,,,,,,,, -52300,0.677955,1.9155278,,,,,,,,,,,,,,,,, -52400,0.9485724,1.9924546,,,,,,,,,,,,,,,,, -52500,0.30223876,1.8647045,,,,,,,,,,,,,,,,, -52600,0.30727002,1.9065685,,,,,,,,,,,,,,,,, -52700,0.56201416,1.8558176,,,,,,,,,,,,,,,,, -52800,0.25219664,1.948836,,,,,,,,,,,,,,,,, -52900,0.30127504,1.9853715,,,,,,,,,,,,,,,,, -53000,0.3153287,1.8790914,,,,,,,,,,,,,,,,, -53100,0.30064395,1.9643887,,,,,,,,,,,,,,,,, -53183,,,0.6156458258628845,1.9492367506027224,29.42759289127698,0.6360987424850464,1.7813689708709717,26.239123202813865,3000.0,0.6461565494537354,1.7227002382278442,25.63268747559185,3003.0,18511.936827898026,30875.469685792923,18511.936827898026,12361.133578777311,0.6469564437866211,0.0 -53200,0.44966406,1.9670743,,,,,,,,,,,,,,,,, -53300,0.49120897,1.8862085,,,,,,,,,,,,,,,,, -53400,0.33676475,1.923978,,,,,,,,,,,,,,,,, -53500,0.8834245,1.8497553,,,,,,,,,,,,,,,,, -53600,0.61012185,1.9451923,,,,,,,,,,,,,,,,, -53700,0.3811291,1.8155752,,,,,,,,,,,,,,,,, -53800,0.6048089,2.0011678,,,,,,,,,,,,,,,,, -53900,0.8880957,1.9662707,,,,,,,,,,,,,,,,, -54000,0.24946211,1.9318931,,,,,,,,,,,,,,,,, -54100,0.48137423,1.9754602,,,,,,,,,,,,,,,,, -54200,0.5065619,1.9578696,,,,,,,,,,,,,,,,, -54300,0.56629616,1.9579232,,,,,,,,,,,,,,,,, -54400,0.31403637,1.9271642,,,,,,,,,,,,,,,,, -54500,0.5701647,2.0639794,,,,,,,,,,,,,,,,, -54600,0.4028489,2.0164032,,,,,,,,,,,,,,,,, -54700,0.2760958,1.9594225,,,,,,,,,,,,,,,,, -54800,0.2804019,1.9207815,,,,,,,,,,,,,,,,, -54900,0.2996831,1.8899605,,,,,,,,,,,,,,,,, -55000,0.62782735,1.8970165,,,,,,,,,,,,,,,,, -55100,0.4764201,1.8801268,,,,,,,,,,,,,,,,, -55200,0.29083282,1.9140744,,,,,,,,,,,,,,,,, -55300,0.24411629,1.9079792,,,,,,,,,,,,,,,,, -55400,0.35217905,1.8716315,,,,,,,,,,,,,,,,, -55500,0.31926632,2.0703979,,,,,,,,,,,,,,,,, -55600,,,0.6179764866828918,1.9319465160369875,28.86257199663916,0.64002925157547,1.7602189779281616,26.359163168974383,3000.0,0.6483876705169678,1.7087100744247437,25.566127278252768,3003.0,19352.044857025143,32251.69904208184,19352.044857025143,12897.13995742798,0.6800875663757324,0.0 -55600,0.34148374,1.9281479,,,,,,,,,,,,,,,,, -55700,0.3123126,1.9434532,,,,,,,,,,,,,,,,, -55800,0.34070128,1.9530109,,,,,,,,,,,,,,,,, -55900,0.34217796,2.0822363,,,,,,,,,,,,,,,,, -56000,0.23425145,1.9471543,,,,,,,,,,,,,,,,, -56100,0.65284014,1.9153198,,,,,,,,,,,,,,,,, -56200,0.35345277,1.9397415,,,,,,,,,,,,,,,,, -56300,0.29140356,2.0424988,,,,,,,,,,,,,,,,, -56400,0.42875513,1.8696021,,,,,,,,,,,,,,,,, -56500,0.36304036,1.9083929,,,,,,,,,,,,,,,,, -56600,0.48524374,1.923867,,,,,,,,,,,,,,,,, -56700,0.59357864,1.9045569,,,,,,,,,,,,,,,,, -56800,0.3946909,1.9508785,,,,,,,,,,,,,,,,, -56900,0.283918,1.8904566,,,,,,,,,,,,,,,,, -57000,0.51882374,1.9865779,,,,,,,,,,,,,,,,, -57100,0.3885532,1.8865832,,,,,,,,,,,,,,,,, -57200,0.30212358,2.003191,,,,,,,,,,,,,,,,, -57300,0.2621715,1.8948734,,,,,,,,,,,,,,,,, -57400,0.27806437,1.9441874,,,,,,,,,,,,,,,,, -57500,0.7043691,1.9229258,,,,,,,,,,,,,,,,, -57600,0.41443893,1.9653404,,,,,,,,,,,,,,,,, -57700,0.5701512,1.8946768,,,,,,,,,,,,,,,,, -57800,0.2990891,1.9220288,,,,,,,,,,,,,,,,, -57900,0.5997842,1.8760437,,,,,,,,,,,,,,,,, -58000,0.41044426,1.9123566,,,,,,,,,,,,,,,,, -58017,,,0.6243790984153748,1.8864920139312744,30.10999079467282,0.6418147087097168,1.7520946264266968,26.486150866715008,3000.0,0.6493057012557983,1.694807529449463,26.47097900002062,3003.0,20191.96054983139,33636.40701699257,20191.96054983139,13441.818828821182,0.7141232490539551,0.0 -58100,0.40097243,1.8481268,,,,,,,,,,,,,,,,, -58200,0.6230783,1.9194866,,,,,,,,,,,,,,,,, -58300,0.3221948,1.849548,,,,,,,,,,,,,,,,, -58400,0.32340232,1.9117002,,,,,,,,,,,,,,,,, -58500,0.29648265,1.944658,,,,,,,,,,,,,,,,, -58600,0.5878134,1.8868148,,,,,,,,,,,,,,,,, -58700,0.404503,1.9357183,,,,,,,,,,,,,,,,, -58800,0.31188455,1.8029498,,,,,,,,,,,,,,,,, -58900,0.30964234,1.876365,,,,,,,,,,,,,,,,, -59000,0.24576941,1.8619885,,,,,,,,,,,,,,,,, -59100,0.77633977,1.9251595,,,,,,,,,,,,,,,,, -59200,0.30416286,1.9810269,,,,,,,,,,,,,,,,, -59300,0.6408584,1.9694669,,,,,,,,,,,,,,,,, -59400,0.33439884,1.8405508,,,,,,,,,,,,,,,,, -59500,0.58784133,1.8985463,,,,,,,,,,,,,,,,, -59600,0.29738244,1.8781266,,,,,,,,,,,,,,,,, -59700,0.33702677,1.9558804,,,,,,,,,,,,,,,,, -59800,0.62656766,1.9160165,,,,,,,,,,,,,,,,, -59900,0.35430634,1.9074967,,,,,,,,,,,,,,,,, -60000,0.4361345,1.9743538,,,,,,,,,,,,,,,,, -60100,0.48636615,1.891319,,,,,,,,,,,,,,,,, -60200,0.36387065,1.8925321,,,,,,,,,,,,,,,,, -60300,0.29995066,1.9559298,,,,,,,,,,,,,,,,, -60400,0.49913213,1.9577628,,,,,,,,,,,,,,,,, -60436,,,0.6193897128105164,1.9212905168533323,29.84611340870408,0.6405872106552124,1.7477840185165403,26.8364294960986,3000.0,0.6511765718460083,1.6859363317489624,25.956983521488905,3003.0,21032.15257835388,35020.5366435051,21032.15257835388,13985.639262914658,0.752568244934082,0.0 -60500,0.30790126,1.911191,,,,,,,,,,,,,,,,, -60600,0.32416487,1.8481165,,,,,,,,,,,,,,,,, -60700,0.3659046,2.0108886,,,,,,,,,,,,,,,,, -60800,0.28612036,1.989331,,,,,,,,,,,,,,,,, -60900,0.5812857,1.9224663,,,,,,,,,,,,,,,,, -61000,0.3414648,1.9856716,,,,,,,,,,,,,,,,, -61100,0.5735818,1.9131418,,,,,,,,,,,,,,,,, -61200,0.3973802,1.8996272,,,,,,,,,,,,,,,,, -61300,0.39324018,1.8478287,,,,,,,,,,,,,,,,, -61400,0.41717336,1.9071127,,,,,,,,,,,,,,,,, -61500,0.31242606,1.9382888,,,,,,,,,,,,,,,,, -61600,0.34327677,1.949342,,,,,,,,,,,,,,,,, -61700,0.36537924,1.8987836,,,,,,,,,,,,,,,,, -61800,0.30558562,1.9746561,,,,,,,,,,,,,,,,, -61900,0.73655695,1.8289254,,,,,,,,,,,,,,,,, -62000,0.42624107,1.9462049,,,,,,,,,,,,,,,,, -62100,0.35926262,1.9575266,,,,,,,,,,,,,,,,, -62200,0.28299263,1.8351473,,,,,,,,,,,,,,,,, -62300,0.8081848,1.8782071,,,,,,,,,,,,,,,,, -62400,0.5497772,1.7910688,,,,,,,,,,,,,,,,, -62500,0.54579586,1.8927835,,,,,,,,,,,,,,,,, -62600,0.400302,1.8366148,,,,,,,,,,,,,,,,, -62700,0.6111035,1.8994837,,,,,,,,,,,,,,,,, -62800,0.4832671,1.9016706,,,,,,,,,,,,,,,,, -62854,,,0.6402044296264648,1.7503422498703003,31.06420597935536,0.6426454782485962,1.7346762418746948,26.713522805537497,3000.0,0.6542444229125977,1.6709492206573486,26.21284965459793,3003.0,21872.31868505478,36391.20599746704,21872.31868505478,14516.028207540512,0.7864320278167725,0.0 -62900,0.31756434,1.9174899,,,,,,,,,,,,,,,,, -63000,0.29916796,1.9092089,,,,,,,,,,,,,,,,, -63100,0.31671154,1.8941793,,,,,,,,,,,,,,,,, -63200,0.478254,1.9234725,,,,,,,,,,,,,,,,, -63300,0.43287596,1.8237816,,,,,,,,,,,,,,,,, -63400,0.29455212,1.9506756,,,,,,,,,,,,,,,,, -63500,0.4029833,1.8764312,,,,,,,,,,,,,,,,, -63600,0.5104609,1.8525331,,,,,,,,,,,,,,,,, -63700,0.31333408,1.8926642,,,,,,,,,,,,,,,,, -63800,0.5601118,1.8985306,,,,,,,,,,,,,,,,, -63900,0.3884456,1.903945,,,,,,,,,,,,,,,,, -64000,0.2822314,1.8401095,,,,,,,,,,,,,,,,, -64100,0.55599153,1.885527,,,,,,,,,,,,,,,,, -64200,0.28357136,1.9057853,,,,,,,,,,,,,,,,, -64300,0.31654164,1.8871557,,,,,,,,,,,,,,,,, -64400,0.27444494,1.9172857,,,,,,,,,,,,,,,,, -64500,0.6571779,1.9444551,,,,,,,,,,,,,,,,, -64600,0.3248327,1.9054468,,,,,,,,,,,,,,,,, -64700,0.41026866,1.9486833,,,,,,,,,,,,,,,,, -64800,0.40599883,1.9058844,,,,,,,,,,,,,,,,, -64900,0.27909106,1.8442242,,,,,,,,,,,,,,,,, -65000,0.31193104,1.8770514,,,,,,,,,,,,,,,,, -65100,0.32922256,1.9148524,,,,,,,,,,,,,,,,, -65200,0.7096019,1.8362534,,,,,,,,,,,,,,,,, -65273,,,0.6218117475509644,1.9062598943710327,29.803210788948807,0.6440093517303467,1.7179428339004517,26.75458586334858,3000.0,0.6531752943992615,1.6693758964538574,26.230637257259648,3003.0,22712.382335186005,37817.06096410751,22712.382335186005,15101.708109140396,0.8203840255737305,0.0 -65300,0.3469765,1.871121,,,,,,,,,,,,,,,,, -65400,0.6523272,1.932757,,,,,,,,,,,,,,,,, -65500,0.31529608,1.8511078,,,,,,,,,,,,,,,,, -65600,0.27935,1.9279867,,,,,,,,,,,,,,,,, -65700,0.404234,1.9402049,,,,,,,,,,,,,,,,, -65800,0.34304827,1.8520144,,,,,,,,,,,,,,,,, -65900,0.597946,1.8778435,,,,,,,,,,,,,,,,, -66000,0.3726112,1.8222575,,,,,,,,,,,,,,,,, -66100,0.27349475,1.8810521,,,,,,,,,,,,,,,,, -66200,0.34481224,1.8834921,,,,,,,,,,,,,,,,, -66300,0.45499945,1.9323226,,,,,,,,,,,,,,,,, -66400,0.40811852,1.794074,,,,,,,,,,,,,,,,, -66500,0.26757905,1.8408247,,,,,,,,,,,,,,,,, -66600,0.30823296,1.8852179,,,,,,,,,,,,,,,,, -66700,0.4828246,1.9430466,,,,,,,,,,,,,,,,, -66800,0.6464303,1.8363448,,,,,,,,,,,,,,,,, -66900,0.28887862,1.8969862,,,,,,,,,,,,,,,,, -67000,0.2887528,1.8429462,,,,,,,,,,,,,,,,, -67100,0.3171335,1.9068831,,,,,,,,,,,,,,,,, -67200,0.2825187,1.8309245,,,,,,,,,,,,,,,,, -67300,0.32217646,1.8962718,,,,,,,,,,,,,,,,, -67400,0.47593328,1.9056288,,,,,,,,,,,,,,,,, -67500,0.4583361,1.9512554,,,,,,,,,,,,,,,,, -67600,0.33084568,1.8127,,,,,,,,,,,,,,,,, -67690,,,0.6277031898498535,1.868759274482727,30.25100134313,0.6481258869171143,1.7066301107406616,26.14190472558656,3000.0,0.6579745411872864,1.6437774896621704,26.80649253582942,3003.0,23552.53147172928,39461.53255653381,23552.53147172928,15905.911636829376,0.8546721935272217,0.0 -67700,0.30885455,1.8299053,,,,,,,,,,,,,,,,, -67800,0.46142948,1.8964798,,,,,,,,,,,,,,,,, -67900,0.27642035,1.9253802,,,,,,,,,,,,,,,,, -68000,0.31430358,1.8510609,,,,,,,,,,,,,,,,, -68100,0.29344803,1.8369077,,,,,,,,,,,,,,,,, -68200,0.32932496,1.9479905,,,,,,,,,,,,,,,,, -68300,0.4744698,1.8186678,,,,,,,,,,,,,,,,, -68400,0.34333608,1.8244439,,,,,,,,,,,,,,,,, -68500,0.598396,1.8856877,,,,,,,,,,,,,,,,, -68600,0.27945858,1.832977,,,,,,,,,,,,,,,,, -68700,0.5947783,1.9257656,,,,,,,,,,,,,,,,, -68800,0.4062337,1.7797338,,,,,,,,,,,,,,,,, -68900,0.3479739,1.8854365,,,,,,,,,,,,,,,,, -69000,0.53771377,1.8764199,,,,,,,,,,,,,,,,, -69100,0.33016387,1.9223106,,,,,,,,,,,,,,,,, -69200,0.31897843,1.9371351,,,,,,,,,,,,,,,,, -69300,0.4732785,1.9138899,,,,,,,,,,,,,,,,, -69400,0.4009812,1.8140029,,,,,,,,,,,,,,,,, -69500,0.62870336,1.8231199,,,,,,,,,,,,,,,,, -69600,0.61982876,1.8331074,,,,,,,,,,,,,,,,, -69700,0.5584872,1.9193236,,,,,,,,,,,,,,,,, -69800,0.632509,1.8193692,,,,,,,,,,,,,,,,, -69900,0.2743615,1.8979483,,,,,,,,,,,,,,,,, -70000,0.60496765,1.8633571,,,,,,,,,,,,,,,,, -70100,0.26380536,1.8313665,,,,,,,,,,,,,,,,, -70107,,,0.6321364045143127,1.8177008628845213,30.434628193983656,0.6484978199005127,1.7004215717315674,27.146623495404736,3000.0,0.6586717963218689,1.635362982749939,26.89588272520431,3003.0,24392.43099117279,40835.84396624565,24392.43099117279,16440.199833631516,0.8966896533966064,0.0 -70200,0.3264941,1.9032989,,,,,,,,,,,,,,,,, -70300,0.45987883,1.824519,,,,,,,,,,,,,,,,, -70400,0.28924477,1.767604,,,,,,,,,,,,,,,,, -70500,0.64929044,1.8335284,,,,,,,,,,,,,,,,, -70600,0.3723969,1.8226186,,,,,,,,,,,,,,,,, -70700,0.44249782,1.8297716,,,,,,,,,,,,,,,,, -70800,0.32878825,1.8419943,,,,,,,,,,,,,,,,, -70900,0.2957033,1.8588493,,,,,,,,,,,,,,,,, -71000,0.36207178,1.7423205,,,,,,,,,,,,,,,,, -71100,0.499114,1.8413119,,,,,,,,,,,,,,,,, -71200,0.38066912,1.9243077,,,,,,,,,,,,,,,,, -71300,0.4384351,1.8465806,,,,,,,,,,,,,,,,, -71400,0.4161827,1.8510525,,,,,,,,,,,,,,,,, -71500,0.4951101,1.868033,,,,,,,,,,,,,,,,, -71600,0.29248598,1.8264338,,,,,,,,,,,,,,,,, -71700,0.5347067,1.8576154,,,,,,,,,,,,,,,,, -71800,0.36104372,1.8693012,,,,,,,,,,,,,,,,, -71900,0.3124174,1.8328422,,,,,,,,,,,,,,,,, -72000,0.42729977,1.8872439,,,,,,,,,,,,,,,,, -72100,0.3548131,1.8500066,,,,,,,,,,,,,,,,, -72200,0.51409876,1.7742821,,,,,,,,,,,,,,,,, -72300,0.29032668,1.8082078,,,,,,,,,,,,,,,,, -72400,0.30958858,1.8905776,,,,,,,,,,,,,,,,, -72500,0.57196933,1.8003652,,,,,,,,,,,,,,,,, -72525,,,0.6318503618240356,1.8400778770446773,30.32549065257292,0.6501221060752869,1.6853595972061155,27.80913835525352,3000.0,0.6619255542755127,1.614753007888794,26.698471399758624,3003.0,25232.470304965973,42266.7501308918,25232.470304965973,17030.950835227966,0.933668613433838,0.0 -72600,0.3923631,1.8141754,,,,,,,,,,,,,,,,, -72700,0.31959918,1.8441014,,,,,,,,,,,,,,,,, -72800,0.46545225,1.8526523,,,,,,,,,,,,,,,,, -72900,0.33234283,1.8746934,,,,,,,,,,,,,,,,, -73000,0.4155824,1.8618355,,,,,,,,,,,,,,,,, -73100,0.46411294,1.8746824,,,,,,,,,,,,,,,,, -73200,0.40494946,1.7975838,,,,,,,,,,,,,,,,, -73300,0.45901752,1.7347542,,,,,,,,,,,,,,,,, -73400,0.30990425,1.7616276,,,,,,,,,,,,,,,,, -73500,0.52709055,1.8340023,,,,,,,,,,,,,,,,, -73600,0.34888425,1.8990592,,,,,,,,,,,,,,,,, -73700,0.40169883,1.7823906,,,,,,,,,,,,,,,,, -73800,0.30765957,1.8374147,,,,,,,,,,,,,,,,, -73900,0.32606077,1.8219815,,,,,,,,,,,,,,,,, -74000,0.50991964,1.8744106,,,,,,,,,,,,,,,,, -74100,0.27458268,1.8742226,,,,,,,,,,,,,,,,, -74200,0.31843254,1.8714802,,,,,,,,,,,,,,,,, -74300,0.38061363,1.8800374,,,,,,,,,,,,,,,,, -74400,0.30858132,1.8160431,,,,,,,,,,,,,,,,, -74500,0.35641417,1.7377483,,,,,,,,,,,,,,,,, -74600,0.35200018,1.769754,,,,,,,,,,,,,,,,, -74700,0.5099263,1.773833,,,,,,,,,,,,,,,,, -74800,0.37679547,1.8876528,,,,,,,,,,,,,,,,, -74900,0.47086453,1.8303859,,,,,,,,,,,,,,,,, -74944,,,0.6331945657730103,1.8335165977478027,30.63485098077157,0.6535691022872925,1.6698365211486816,27.427614088747323,3000.0,0.6597524881362915,1.6147087812423706,26.434105724544263,3003.0,26072.504900217056,43725.81726980209,26072.504900217056,17649.86233663559,0.9771442413330078,0.0 -75000,0.31073382,1.7851956,,,,,,,,,,,,,,,,, -75100,0.53421247,1.7156335,,,,,,,,,,,,,,,,, -75200,0.29046625,1.7932082,,,,,,,,,,,,,,,,, -75300,0.30513453,1.843587,,,,,,,,,,,,,,,,, -75400,0.2958785,1.7887031,,,,,,,,,,,,,,,,, -75500,0.52342945,1.8296422,,,,,,,,,,,,,,,,, -75600,0.4419102,1.8527644,,,,,,,,,,,,,,,,, -75700,0.5602175,1.8570828,,,,,,,,,,,,,,,,, -75800,0.33758092,1.8385848,,,,,,,,,,,,,,,,, -75900,0.28181267,1.7999882,,,,,,,,,,,,,,,,, -76000,0.3172494,1.7782788,,,,,,,,,,,,,,,,, -76100,0.60157627,1.8050302,,,,,,,,,,,,,,,,, -76200,0.39821234,1.773039,,,,,,,,,,,,,,,,, -76300,0.47457218,1.8356203,,,,,,,,,,,,,,,,, -76400,0.29456413,1.9188957,,,,,,,,,,,,,,,,, -76500,0.32441843,1.8015097,,,,,,,,,,,,,,,,, -76600,0.2865426,1.7688261,,,,,,,,,,,,,,,,, -76700,0.60250247,1.8408184,,,,,,,,,,,,,,,,, -76800,0.34688115,1.8101139,,,,,,,,,,,,,,,,, -76900,0.2859783,1.8150433,,,,,,,,,,,,,,,,, -77000,0.3778063,1.8623066,,,,,,,,,,,,,,,,, -77100,0.33656487,1.8048788,,,,,,,,,,,,,,,,, -77200,0.35601425,1.8039405,,,,,,,,,,,,,,,,, -77300,0.4268168,1.8127738,,,,,,,,,,,,,,,,, -77361,,,0.6339204907417297,1.8167153596878047,30.797597115589767,0.6549205780029297,1.665434956550598,27.880582151012234,3000.0,0.6661786437034607,1.5921257734298706,27.594957648339527,3003.0,26912.40016245842,45080.52706384659,26912.40016245842,18164.56058192253,1.013715744018555,0.0 -77400,0.5228241,1.8056213,,,,,,,,,,,,,,,,, -77500,0.40755948,1.8188857,,,,,,,,,,,,,,,,, -77600,0.5108982,1.7996843,,,,,,,,,,,,,,,,, -77700,0.4183402,1.9881257,,,,,,,,,,,,,,,,, -77800,0.40907162,1.8648915,,,,,,,,,,,,,,,,, -77900,0.281869,1.7903161,,,,,,,,,,,,,,,,, -78000,0.28562325,1.7436632,,,,,,,,,,,,,,,,, -78100,0.37482318,1.873901,,,,,,,,,,,,,,,,, -78200,0.3027009,1.8387413,,,,,,,,,,,,,,,,, -78300,0.45866048,1.8267382,,,,,,,,,,,,,,,,, -78400,0.43425044,1.7828592,,,,,,,,,,,,,,,,, -78500,0.3236273,1.8174895,,,,,,,,,,,,,,,,, -78600,0.3001673,1.8156232,,,,,,,,,,,,,,,,, -78700,0.44267648,1.8195597,,,,,,,,,,,,,,,,, -78800,0.291043,1.7427342,,,,,,,,,,,,,,,,, -78900,0.46066245,1.7792103,,,,,,,,,,,,,,,,, -79000,0.3208988,1.7780776,,,,,,,,,,,,,,,,, -79100,0.3056571,1.8459287,,,,,,,,,,,,,,,,, -79200,0.357659,1.847248,,,,,,,,,,,,,,,,, -79300,0.36403614,1.8158787,,,,,,,,,,,,,,,,, -79400,0.34242398,1.7751298,,,,,,,,,,,,,,,,, -79500,0.38367215,1.8803002,,,,,,,,,,,,,,,,, -79600,0.30764297,1.829077,,,,,,,,,,,,,,,,, -79700,0.51992244,1.8421823,,,,,,,,,,,,,,,,, -79779,,,0.6340129375457764,1.818768858909607,31.369896513488115,0.6570656299591064,1.6416860818862915,28.10779161190765,3000.0,0.6673523187637329,1.5821276903152466,27.324396090936787,3003.0,27752.500133752823,46464.30451631546,27752.500133752823,18708.115093946457,1.0589666366577148,0.0 -79800,0.36841238,1.8207939,,,,,,,,,,,,,,,,, -79900,0.34134117,1.8417338,,,,,,,,,,,,,,,,, -80000,0.72163403,1.7098327,,,,,,,,,,,,,,,,, -80100,0.28950045,1.8795,,,,,,,,,,,,,,,,, -80200,0.27831993,1.7578778,,,,,,,,,,,,,,,,, -80300,0.28825545,1.8503083,,,,,,,,,,,,,,,,, -80400,0.32284433,1.800775,,,,,,,,,,,,,,,,, -80500,0.35311264,1.8864213,,,,,,,,,,,,,,,,, -80600,0.3194826,1.7860855,,,,,,,,,,,,,,,,, -80700,0.29516664,1.7608892,,,,,,,,,,,,,,,,, -80800,0.45520395,1.8419108,,,,,,,,,,,,,,,,, -80900,0.32540384,1.8146434,,,,,,,,,,,,,,,,, -81000,0.36150122,1.8588134,,,,,,,,,,,,,,,,, -81100,0.43070307,1.7486459,,,,,,,,,,,,,,,,, -81200,0.36789832,1.7901739,,,,,,,,,,,,,,,,, -81300,0.27615446,1.8887002,,,,,,,,,,,,,,,,, -81400,0.3152678,1.7710935,,,,,,,,,,,,,,,,, -81500,0.29410475,1.7149466,,,,,,,,,,,,,,,,, -81600,0.32561168,1.8282443,,,,,,,,,,,,,,,,, -81700,0.53584504,1.7653214,,,,,,,,,,,,,,,,, -81800,0.35642123,1.7877334,,,,,,,,,,,,,,,,, -81900,0.51503867,1.8392745,,,,,,,,,,,,,,,,, -82000,0.3805703,1.8607572,,,,,,,,,,,,,,,,, -82100,0.40027547,1.7311244,,,,,,,,,,,,,,,,, -82197,,,0.6473284363746643,1.7201316356658936,31.41526678659185,0.659334659576416,1.631327986717224,28.270451851715933,3000.0,0.6687816381454468,1.567957639694214,27.69355476262512,3003.0,28592.67398762703,47843.56256365776,28592.67398762703,19247.07579922676,1.102534532546997,0.0 -82200,0.28240278,1.8426826,,,,,,,,,,,,,,,,, -82300,0.59468687,1.8872392,,,,,,,,,,,,,,,,, -82400,0.47812372,1.8130597,,,,,,,,,,,,,,,,, -82500,0.36171317,1.8113608,,,,,,,,,,,,,,,,, -82600,0.38763833,1.7381063,,,,,,,,,,,,,,,,, -82700,0.32730472,1.7982398,,,,,,,,,,,,,,,,, -82800,0.31706735,1.7965204,,,,,,,,,,,,,,,,, -82900,0.2873251,1.776649,,,,,,,,,,,,,,,,, -83000,0.35544786,1.7443631,,,,,,,,,,,,,,,,, -83100,0.43462333,1.8080325,,,,,,,,,,,,,,,,, -83200,0.8189135,1.8456215,,,,,,,,,,,,,,,,, -83300,0.48174387,1.7212602,,,,,,,,,,,,,,,,, -83400,0.30510297,1.7942704,,,,,,,,,,,,,,,,, -83500,0.31596792,1.7883617,,,,,,,,,,,,,,,,, -83600,0.3370208,1.7922539,,,,,,,,,,,,,,,,, -83700,0.2774859,1.7511772,,,,,,,,,,,,,,,,, -83800,0.4284926,1.70395,,,,,,,,,,,,,,,,, -83900,0.32172114,1.8384079,,,,,,,,,,,,,,,,, -84000,0.37238577,1.7414223,,,,,,,,,,,,,,,,, -84100,0.29593483,1.7417942,,,,,,,,,,,,,,,,, -84200,0.5041076,1.8150121,,,,,,,,,,,,,,,,, -84300,0.39010748,1.743019,,,,,,,,,,,,,,,,, -84400,0.28510678,1.8427595,,,,,,,,,,,,,,,,, -84500,0.32865438,1.7596968,,,,,,,,,,,,,,,,, -84600,0.44847807,1.7799294,,,,,,,,,,,,,,,,, -84615,,,0.6412078142166138,1.7685593366622925,30.883998162965828,0.6598802208900452,1.6230331659317017,27.8358965330051,3000.0,0.6705595254898071,1.5564008951187134,27.64755868083939,3003.0,29432.80449271202,49250.32803225517,29432.80449271202,19813.59550333023,1.1386573314666748,0.0 -84700,0.2995006,1.7534944,,,,,,,,,,,,,,,,, -84800,0.3404693,1.7044396,,,,,,,,,,,,,,,,, -84900,0.4778871,1.810203,,,,,,,,,,,,,,,,, -85000,0.33335307,1.756901,,,,,,,,,,,,,,,,, -85100,0.4067662,1.7194817,,,,,,,,,,,,,,,,, -85200,0.34464896,1.8037928,,,,,,,,,,,,,,,,, -85300,0.3870841,1.7984465,,,,,,,,,,,,,,,,, -85400,0.37657976,1.7705423,,,,,,,,,,,,,,,,, -85500,0.3546691,1.832491,,,,,,,,,,,,,,,,, -85600,0.29846832,1.6940372,,,,,,,,,,,,,,,,, -85700,0.3023693,1.8044347,,,,,,,,,,,,,,,,, -85800,0.34642506,1.8677418,,,,,,,,,,,,,,,,, -85900,0.41681257,1.8463875,,,,,,,,,,,,,,,,, -86000,0.307229,1.8044872,,,,,,,,,,,,,,,,, -86100,0.32185444,1.7187647,,,,,,,,,,,,,,,,, -86200,0.29500186,1.7440588,,,,,,,,,,,,,,,,, -86300,0.4637999,1.7423766,,,,,,,,,,,,,,,,, -86400,0.3499073,1.8674886,,,,,,,,,,,,,,,,, -86500,0.37776834,1.7557045,,,,,,,,,,,,,,,,, -86600,0.4017816,1.7334192,,,,,,,,,,,,,,,,, -86700,0.36363235,1.7845132,,,,,,,,,,,,,,,,, -86800,0.36018962,1.7722176,,,,,,,,,,,,,,,,, -86900,0.34465468,1.7414563,,,,,,,,,,,,,,,,, -87000,0.30332428,1.7175683,,,,,,,,,,,,,,,,, -87033,,,0.6418304443359375,1.7551774978637695,31.223073845288223,0.6620252728462219,1.602751851081848,28.48434794106564,3000.0,0.6758236289024353,1.5261303186416626,28.10549738521797,3003.0,30272.858829975128,50639.543586969376,30272.858829975128,20362.637778520584,1.175865888595581,0.0 -87100,0.30906162,1.7014079,,,,,,,,,,,,,,,,, -87200,0.43619964,1.7699488,,,,,,,,,,,,,,,,, -87300,0.39261368,1.7443008,,,,,,,,,,,,,,,,, -87400,0.32828465,1.7431053,,,,,,,,,,,,,,,,, -87500,0.4015976,1.81148,,,,,,,,,,,,,,,,, -87600,0.31116262,1.7248716,,,,,,,,,,,,,,,,, -87700,0.33161312,1.7519218,,,,,,,,,,,,,,,,, -87800,0.27378422,1.7015344,,,,,,,,,,,,,,,,, -87900,0.49176747,1.7289153,,,,,,,,,,,,,,,,, -88000,0.3117497,1.6803304,,,,,,,,,,,,,,,,, -88100,0.3130426,1.778403,,,,,,,,,,,,,,,,, -88200,0.29547375,1.7451903,,,,,,,,,,,,,,,,, -88300,0.47400314,1.8021317,,,,,,,,,,,,,,,,, -88400,0.30038533,1.7689972,,,,,,,,,,,,,,,,, -88500,0.31603897,1.7713853,,,,,,,,,,,,,,,,, -88600,0.32375154,1.6508228,,,,,,,,,,,,,,,,, -88700,0.36835018,1.7455698,,,,,,,,,,,,,,,,, -88800,0.32412863,1.8034184,,,,,,,,,,,,,,,,, -88900,0.30587554,1.7824454,,,,,,,,,,,,,,,,, -89000,0.36810458,1.7506489,,,,,,,,,,,,,,,,, -89100,0.33397692,1.7569865,,,,,,,,,,,,,,,,, -89200,0.3100836,1.7631787,,,,,,,,,,,,,,,,, -89300,0.30979112,1.7367536,,,,,,,,,,,,,,,,, -89400,0.33288845,1.6991292,,,,,,,,,,,,,,,,, -89450,,,0.6507663130760193,1.6971262693405151,31.590313155448776,0.6641083359718323,1.5904189348220823,28.41043268094581,3000.0,0.6781593561172485,1.5138115882873535,28.19384526539295,3003.0,31112.956319332123,52183.43749761581,31112.956319332123,21066.31170296669,1.2124691009521484,0.0 -89500,0.31263092,1.6792967,,,,,,,,,,,,,,,,, -89600,0.31216815,1.7004371,,,,,,,,,,,,,,,,, -89700,0.39484206,1.74722,,,,,,,,,,,,,,,,, -89800,0.29387888,1.7181627,,,,,,,,,,,,,,,,, -89900,0.3888847,1.7063687,,,,,,,,,,,,,,,,, -90000,0.31544366,1.8128945,,,,,,,,,,,,,,,,, -90100,0.3308143,1.7885113,,,,,,,,,,,,,,,,, -90200,0.32238564,1.7850752,,,,,,,,,,,,,,,,, -90300,0.49775314,1.8001419,,,,,,,,,,,,,,,,, -90400,0.29537532,1.6974212,,,,,,,,,,,,,,,,, -90500,0.33485943,1.6673031,,,,,,,,,,,,,,,,, -90600,0.3575797,1.7489402,,,,,,,,,,,,,,,,, -90700,0.3226873,1.8287075,,,,,,,,,,,,,,,,, -90800,0.33747134,1.776538,,,,,,,,,,,,,,,,, -90900,0.50175285,1.7592504,,,,,,,,,,,,,,,,, -91000,0.31133822,1.7247318,,,,,,,,,,,,,,,,, -91100,0.30306095,1.7397102,,,,,,,,,,,,,,,,, -91200,0.47942466,1.7830983,,,,,,,,,,,,,,,,, -91300,0.4988533,1.7963938,,,,,,,,,,,,,,,,, -91400,0.3308801,1.6822801,,,,,,,,,,,,,,,,, -91500,0.3509074,1.6938919,,,,,,,,,,,,,,,,, -91600,0.30563155,1.7268826,,,,,,,,,,,,,,,,, -91700,0.42399088,1.7916636,,,,,,,,,,,,,,,,, -91800,0.2828126,1.7334914,,,,,,,,,,,,,,,,, -91866,,,0.6441076993942261,1.7398148775100708,31.650447712787365,0.6667988896369934,1.5745916366577148,28.44722863245684,3000.0,0.6778688430786133,1.5026555061340332,28.213751552293576,3003.0,31952.85079932213,53667.065616846085,31952.85079932213,21709.92742419243,1.2507178783416748,0.0 -91900,0.28777477,1.6816938,,,,,,,,,,,,,,,,, -92000,0.31809017,1.7819026,,,,,,,,,,,,,,,,, -92100,0.38441315,1.7272171,,,,,,,,,,,,,,,,, -92200,0.3206461,1.8196255,,,,,,,,,,,,,,,,, -92300,0.34224603,1.7033236,,,,,,,,,,,,,,,,, -92400,0.28985888,1.8329899,,,,,,,,,,,,,,,,, -92500,0.3703325,1.6906742,,,,,,,,,,,,,,,,, -92600,0.320357,1.6482157,,,,,,,,,,,,,,,,, -92700,0.3352646,1.7655665,,,,,,,,,,,,,,,,, -92800,0.3410689,1.7793617,,,,,,,,,,,,,,,,, -92900,0.3251468,1.8059341,,,,,,,,,,,,,,,,, -93000,0.30520758,1.7823881,,,,,,,,,,,,,,,,, -93100,0.36150414,1.7313287,,,,,,,,,,,,,,,,, -93200,0.37007025,1.6900448,,,,,,,,,,,,,,,,, -93300,0.3436087,1.6447238,,,,,,,,,,,,,,,,, -93400,0.4612164,1.7174109,,,,,,,,,,,,,,,,, -93500,0.37725937,1.6472782,,,,,,,,,,,,,,,,, -93600,0.37264386,1.7012635,,,,,,,,,,,,,,,,, -93700,0.34462497,1.7338868,,,,,,,,,,,,,,,,, -93800,0.3024624,1.6511912,,,,,,,,,,,,,,,,, -93900,0.32885507,1.7241702,,,,,,,,,,,,,,,,, -94000,0.33531255,1.7504807,,,,,,,,,,,,,,,,, -94100,0.34984878,1.8152303,,,,,,,,,,,,,,,,, -94200,0.3473592,1.6898625,,,,,,,,,,,,,,,,, -94283,,,0.6625533699989319,1.6080623865127563,33.16702852310345,0.6701342463493347,1.558275818824768,28.759051738058805,3000.0,0.6807971596717834,1.4862122535705566,28.15908149035932,3003.0,32792.86261463165,55028.61365580559,32792.86261463165,22231.343421697617,1.2895457744598389,0.0 -94300,0.3605795,1.6545963,,,,,,,,,,,,,,,,, -94400,0.33576262,1.6767912,,,,,,,,,,,,,,,,, -94500,0.34875277,1.6696632,,,,,,,,,,,,,,,,, -94600,0.292359,1.7142657,,,,,,,,,,,,,,,,, -94700,0.3429478,1.6780155,,,,,,,,,,,,,,,,, -94800,0.32937488,1.6852775,,,,,,,,,,,,,,,,, -94900,0.41412273,1.7699115,,,,,,,,,,,,,,,,, -95000,0.37061697,1.6732035,,,,,,,,,,,,,,,,, -95100,0.3297057,1.6709832,,,,,,,,,,,,,,,,, -95200,0.39198807,1.701855,,,,,,,,,,,,,,,,, -95300,0.3703887,1.7764283,,,,,,,,,,,,,,,,, -95400,0.30910528,1.7023989,,,,,,,,,,,,,,,,, -95500,0.31899557,1.67936,,,,,,,,,,,,,,,,, -95600,0.3816989,1.7181855,,,,,,,,,,,,,,,,, -95700,0.3095727,1.7787629,,,,,,,,,,,,,,,,, -95800,0.34887114,1.6291994,,,,,,,,,,,,,,,,, -95900,0.32258075,1.6481334,,,,,,,,,,,,,,,,, -96000,0.41585955,1.717307,,,,,,,,,,,,,,,,, -96100,0.37642354,1.7292746,,,,,,,,,,,,,,,,, -96200,0.31590366,1.5727848,,,,,,,,,,,,,,,,, -96300,0.39347738,1.6876976,,,,,,,,,,,,,,,,, -96400,0.30957687,1.6610194,,,,,,,,,,,,,,,,, -96500,0.3594755,1.6858552,,,,,,,,,,,,,,,,, -96600,0.33127767,1.7534186,,,,,,,,,,,,,,,,, -96700,0.37555704,1.666306,,,,,,,,,,,,,,,,, -96701,,,0.6566953063011169,1.6568286418914795,31.842143989658336,0.670741856098175,1.5458422899246216,29.138297350353145,3000.0,0.6864447593688965,1.4664220809936523,29.245032393980694,3003.0,33633.10122871399,56415.408299446106,33633.10122871399,22777.78276062012,1.3282885551452637,0.0 -96800,0.3353913,1.6977229,,,,,,,,,,,,,,,,, -96900,0.31509712,1.6982992,,,,,,,,,,,,,,,,, -97000,0.33365348,1.6989082,,,,,,,,,,,,,,,,, -97100,0.3272858,1.6096333,,,,,,,,,,,,,,,,, -97200,0.34526515,1.7009948,,,,,,,,,,,,,,,,, -97300,0.3662859,1.728396,,,,,,,,,,,,,,,,, -97400,0.4081034,1.6496263,,,,,,,,,,,,,,,,, -97500,0.3064819,1.6169038,,,,,,,,,,,,,,,,, -97600,0.30856803,1.70703,,,,,,,,,,,,,,,,, -97700,0.31916204,1.6806381,,,,,,,,,,,,,,,,, -97800,0.32589513,1.717841,,,,,,,,,,,,,,,,, -97900,0.31556657,1.6835483,,,,,,,,,,,,,,,,, -98000,0.29621807,1.6587901,,,,,,,,,,,,,,,,, -98100,0.31088713,1.7279744,,,,,,,,,,,,,,,,, -98200,0.30577177,1.6720192,,,,,,,,,,,,,,,,, -98300,0.31278294,1.6545562,,,,,,,,,,,,,,,,, -98400,0.31109682,1.6263573,,,,,,,,,,,,,,,,, -98500,0.3692007,1.6600274,,,,,,,,,,,,,,,,, -98600,0.31731346,1.7025367,,,,,,,,,,,,,,,,, -98700,0.42249236,1.6715838,,,,,,,,,,,,,,,,, -98800,0.30389804,1.6136022,,,,,,,,,,,,,,,,, -98900,0.32212555,1.7073038,,,,,,,,,,,,,,,,, -99000,0.32935327,1.624194,,,,,,,,,,,,,,,,, -99100,0.33742526,1.6670384,,,,,,,,,,,,,,,,, -99118,,,0.6576961874961853,1.6613869667053225,32.159575954341314,0.6734076142311096,1.527236819267273,29.481519388764976,3000.0,0.6870489716529846,1.4522655010223389,29.00504494574323,3003.0,34473.18996334076,57760.54006314278,34473.18996334076,23282.706238031387,1.3684589862823486,0.0 -99200,0.30501464,1.6148648,,,,,,,,,,,,,,,,, -99300,0.34788078,1.6902144,,,,,,,,,,,,,,,,, -99400,0.39628193,1.6669947,,,,,,,,,,,,,,,,, -99500,0.36620283,1.6232971,,,,,,,,,,,,,,,,, -99600,0.31986123,1.6469449,,,,,,,,,,,,,,,,, -99700,0.3095535,1.5938209,,,,,,,,,,,,,,,,, -99800,0.32548293,1.6079913,,,,,,,,,,,,,,,,, -99900,0.34935352,1.6769824,,,,,,,,,,,,,,,,, -100000,0.30768254,1.6709433,,,,,,,,,,,,,,,,, -100100,0.34967396,1.7058738,,,,,,,,,,,,,,,,, -100200,0.3465381,1.6605401,,,,,,,,,,,,,,,,, -100300,0.33058313,1.6051182,,,,,,,,,,,,,,,,, -100400,0.33236322,1.6725384,,,,,,,,,,,,,,,,, -100500,0.33970398,1.6653346,,,,,,,,,,,,,,,,, -100600,0.4077767,1.6017867,,,,,,,,,,,,,,,,, -100700,0.35774967,1.644783,,,,,,,,,,,,,,,,, -100800,0.3632378,1.6514609,,,,,,,,,,,,,,,,, -100900,0.3075775,1.6696748,,,,,,,,,,,,,,,,, -101000,0.32132766,1.6747745,,,,,,,,,,,,,,,,, -101100,0.3528605,1.6181526,,,,,,,,,,,,,,,,, -101200,0.31162834,1.6439551,,,,,,,,,,,,,,,,, -101300,0.30664554,1.6121292,,,,,,,,,,,,,,,,, -101400,0.3576007,1.6412464,,,,,,,,,,,,,,,,, -101500,0.3187588,1.6263059,,,,,,,,,,,,,,,,, -101536,,,0.6628363132476807,1.61624276638031,33.3643248815988,0.6767801642417908,1.5128036737442017,29.35418899946777,3000.0,0.6894544363021851,1.432341456413269,29.022071865705737,3003.0,35313.20211029053,59208.66853928566,35313.20211029053,23890.701112031937,1.4091482162475586,0.0 -101600,0.32229853,1.7112365,,,,,,,,,,,,,,,,, -101700,0.33664712,1.6522446,,,,,,,,,,,,,,,,, -101800,0.29798466,1.615473,,,,,,,,,,,,,,,,, -101900,0.34802565,1.7339576,,,,,,,,,,,,,,,,, -102000,0.36668658,1.6215255,,,,,,,,,,,,,,,,, -102100,0.31738165,1.5956646,,,,,,,,,,,,,,,,, -102200,0.32910866,1.6626502,,,,,,,,,,,,,,,,, -102300,0.35313496,1.7589356,,,,,,,,,,,,,,,,, -102400,0.33787352,1.6997291,,,,,,,,,,,,,,,,, -102500,0.3195941,1.6147488,,,,,,,,,,,,,,,,, -102600,0.32545325,1.55938,,,,,,,,,,,,,,,,, -102700,0.31284082,1.5874603,,,,,,,,,,,,,,,,, -102800,0.32370892,1.5729933,,,,,,,,,,,,,,,,, -102900,0.33397585,1.6746122,,,,,,,,,,,,,,,,, -103000,0.32794642,1.629647,,,,,,,,,,,,,,,,, -103100,0.30639645,1.5725528,,,,,,,,,,,,,,,,, -103200,0.30890623,1.512121,,,,,,,,,,,,,,,,, -103300,0.32244456,1.5547854,,,,,,,,,,,,,,,,, -103400,0.32959726,1.6410141,,,,,,,,,,,,,,,,, -103500,0.32995045,1.6103156,,,,,,,,,,,,,,,,, -103600,0.30085063,1.5985212,,,,,,,,,,,,,,,,, -103700,0.31172922,1.6372347,,,,,,,,,,,,,,,,, -103800,0.35376438,1.6189324,,,,,,,,,,,,,,,,, -103900,0.38180146,1.6955786,,,,,,,,,,,,,,,,, -103954,,,0.6624199748039246,1.615139722824097,32.52034005600726,0.6787640452384949,1.4980305433273315,29.72840660397787,3000.0,0.6900238394737244,1.422572374343872,29.251162714857767,3003.0,36153.19030022621,60624.50135970116,36153.19030022621,24466.42604756356,1.4494218826293943,0.0 -104000,0.3753026,1.640019,,,,,,,,,,,,,,,,, -104100,0.34825805,1.6349909,,,,,,,,,,,,,,,,, -104200,0.33635494,1.6672335,,,,,,,,,,,,,,,,, -104300,0.32313257,1.6703635,,,,,,,,,,,,,,,,, -104400,0.3094518,1.5740834,,,,,,,,,,,,,,,,, -104500,0.31733376,1.6473817,,,,,,,,,,,,,,,,, -104600,0.3335845,1.5999044,,,,,,,,,,,,,,,,, -104700,0.37398133,1.6357145,,,,,,,,,,,,,,,,, -104800,0.33965778,1.6640005,,,,,,,,,,,,,,,,, -104900,0.3438606,1.6348745,,,,,,,,,,,,,,,,, -105000,0.3287012,1.6607925,,,,,,,,,,,,,,,,, -105100,0.3276404,1.6242248,,,,,,,,,,,,,,,,, -105200,0.331188,1.6736101,,,,,,,,,,,,,,,,, -105300,0.34574243,1.5659403,,,,,,,,,,,,,,,,, -105400,0.32413492,1.5570703,,,,,,,,,,,,,,,,, -105500,0.32714635,1.580809,,,,,,,,,,,,,,,,, -105600,0.3281996,1.6477026,,,,,,,,,,,,,,,,, -105700,0.33335987,1.5893127,,,,,,,,,,,,,,,,, -105800,0.35706168,1.6531723,,,,,,,,,,,,,,,,, -105900,0.30837238,1.6167649,,,,,,,,,,,,,,,,, -106000,0.31470817,1.6020913,,,,,,,,,,,,,,,,, -106100,0.3252722,1.7048149,,,,,,,,,,,,,,,,, -106200,0.38033798,1.6339812,,,,,,,,,,,,,,,,, -106300,0.32634243,1.6617981,,,,,,,,,,,,,,,,, -106372,,,0.7069824934005737,1.3659493923187256,36.09946513355275,0.6812934875488281,1.4833416938781738,29.75335012117342,3000.0,0.6945441961288452,1.4052151441574097,29.731491757638405,3003.0,36993.35413503647,62028.96114087105,36993.35413503647,25030.5927362442,1.497650384902954,0.0 -106400,0.34206644,1.6145688,,,,,,,,,,,,,,,,, -106500,0.37649396,1.6008211,,,,,,,,,,,,,,,,, -106600,0.33867273,1.5645266,,,,,,,,,,,,,,,,, -106700,0.33898446,1.596865,,,,,,,,,,,,,,,,, -106800,0.3305619,1.63936,,,,,,,,,,,,,,,,, -106900,0.33890337,1.5754709,,,,,,,,,,,,,,,,, -107000,0.33886126,1.6793078,,,,,,,,,,,,,,,,, -107100,0.33054206,1.5384699,,,,,,,,,,,,,,,,, -107200,0.31803977,1.5486616,,,,,,,,,,,,,,,,, -107300,0.31945798,1.5658846,,,,,,,,,,,,,,,,, -107400,0.34275305,1.6359731,,,,,,,,,,,,,,,,, -107500,0.30708128,1.5659174,,,,,,,,,,,,,,,,, -107600,0.3267516,1.4943568,,,,,,,,,,,,,,,,, -107700,0.3465236,1.6224685,,,,,,,,,,,,,,,,, -107800,0.3171414,1.5614203,,,,,,,,,,,,,,,,, -107900,0.35850507,1.5561794,,,,,,,,,,,,,,,,, -108000,0.33032686,1.6402022,,,,,,,,,,,,,,,,, -108100,0.36844686,1.6049505,,,,,,,,,,,,,,,,, -108200,0.4065679,1.6537147,,,,,,,,,,,,,,,,, -108300,0.36922494,1.6239431,,,,,,,,,,,,,,,,, -108400,0.30663818,1.5522454,,,,,,,,,,,,,,,,, -108500,0.32751217,1.598196,,,,,,,,,,,,,,,,, -108600,0.3182981,1.5830142,,,,,,,,,,,,,,,,, -108700,0.32455435,1.5368899,,,,,,,,,,,,,,,,, -108790,,,0.6711868047714233,1.5594106912612915,33.68167779038636,0.682409405708313,1.4713393449783323,29.792658094298663,3000.0,0.6963105201721191,1.3900195360183716,29.78246186358487,3003.0,37833.25030493736,63443.19142055512,37833.25030493736,25604.80843281746,1.5368640422821045,0.0 -108800,0.32306194,1.5070769,,,,,,,,,,,,,,,,, -108900,0.34294096,1.5336484,,,,,,,,,,,,,,,,, -109000,0.33132255,1.6486324,,,,,,,,,,,,,,,,, -109100,0.34567624,1.5490181,,,,,,,,,,,,,,,,, -109200,0.32552114,1.5751523,,,,,,,,,,,,,,,,, -109300,0.34215552,1.5295913,,,,,,,,,,,,,,,,, -109400,0.33913347,1.4927989,,,,,,,,,,,,,,,,, -109500,0.33834013,1.5944028,,,,,,,,,,,,,,,,, -109600,0.3245674,1.5354042,,,,,,,,,,,,,,,,, -109700,0.36688158,1.5629722,,,,,,,,,,,,,,,,, -109800,0.34186405,1.5424088,,,,,,,,,,,,,,,,, -109900,0.34226894,1.6004512,,,,,,,,,,,,,,,,, -110000,0.33731368,1.4918962,,,,,,,,,,,,,,,,, -110100,0.34996778,1.5717329,,,,,,,,,,,,,,,,, -110200,0.33912668,1.6144986,,,,,,,,,,,,,,,,, -110300,0.33954692,1.5088115,,,,,,,,,,,,,,,,, -110400,0.35428682,1.5390264,,,,,,,,,,,,,,,,, -110500,0.33530563,1.5569547,,,,,,,,,,,,,,,,, -110600,0.3452148,1.4767562,,,,,,,,,,,,,,,,, -110700,0.342051,1.5315344,,,,,,,,,,,,,,,,, -110800,0.3449952,1.5970031,,,,,,,,,,,,,,,,, -110900,0.3270141,1.617262,,,,,,,,,,,,,,,,, -111000,0.33115274,1.601068,,,,,,,,,,,,,,,,, -111100,0.3467863,1.5942812,,,,,,,,,,,,,,,,, -111200,0.3367674,1.5601698,,,,,,,,,,,,,,,,, -111208,,,0.6703004240989685,1.5756620168685913,33.46816417771401,0.683078944683075,1.4614923000335691,29.83414942148012,3000.0,0.6995410323143005,1.3734357357025146,30.15858276774193,3003.0,38673.397919654846,64930.52002501488,38673.397919654846,26251.868897914886,1.5767569541931152,0.0 -111300,0.34271532,1.5154035,,,,,,,,,,,,,,,,, -111400,0.34971434,1.58034,,,,,,,,,,,,,,,,, -111500,0.3339324,1.5870845,,,,,,,,,,,,,,,,, -111600,0.34443712,1.5741204,,,,,,,,,,,,,,,,, -111700,0.35448772,1.5733968,,,,,,,,,,,,,,,,, -111800,0.33943585,1.4975697,,,,,,,,,,,,,,,,, -111900,0.3191288,1.5626535,,,,,,,,,,,,,,,,, -112000,0.3655872,1.4864751,,,,,,,,,,,,,,,,, -112100,0.35357788,1.5068592,,,,,,,,,,,,,,,,, -112200,0.3731247,1.6539378,,,,,,,,,,,,,,,,, -112300,0.34966344,1.5887473,,,,,,,,,,,,,,,,, -112400,0.32897267,1.4942735,,,,,,,,,,,,,,,,, -112500,0.3477688,1.5423329,,,,,,,,,,,,,,,,, -112600,0.37035653,1.56634,,,,,,,,,,,,,,,,, -112700,0.34765178,1.4854276,,,,,,,,,,,,,,,,, -112800,0.34824893,1.5232174,,,,,,,,,,,,,,,,, -112900,0.35865876,1.5559075,,,,,,,,,,,,,,,,, -113000,0.3754342,1.579478,,,,,,,,,,,,,,,,, -113100,0.32998875,1.4866518,,,,,,,,,,,,,,,,, -113200,0.34790894,1.571906,,,,,,,,,,,,,,,,, -113300,0.371278,1.6189371,,,,,,,,,,,,,,,,, -113400,0.35088974,1.5606323,,,,,,,,,,,,,,,,, -113500,0.35147595,1.5403715,,,,,,,,,,,,,,,,, -113600,0.36897683,1.5919447,,,,,,,,,,,,,,,,, -113626,,,0.6843309998512268,1.4797334671020508,34.52292679167273,0.6871954202651978,1.4451779127120972,30.258902249238226,3000.0,0.7013537883758545,1.3622872829437256,30.219203249138683,3003.0,39513.55310797691,66325.00939846039,39513.55310797691,26806.072550058365,1.625173568725586,0.0 -113700,0.3851017,1.6042075,,,,,,,,,,,,,,,,, -113800,0.32037628,1.5150286,,,,,,,,,,,,,,,,, -113900,0.3606693,1.6134475,,,,,,,,,,,,,,,,, -114000,0.34044114,1.4897108,,,,,,,,,,,,,,,,, -114100,0.33653143,1.4606931,,,,,,,,,,,,,,,,, -114200,0.34898016,1.4661348,,,,,,,,,,,,,,,,, -114300,0.37305403,1.5568005,,,,,,,,,,,,,,,,, -114400,0.36299717,1.5236503,,,,,,,,,,,,,,,,, -114500,0.3472379,1.5073578,,,,,,,,,,,,,,,,, -114600,0.36030588,1.5321448,,,,,,,,,,,,,,,,, -114700,0.35787836,1.5888383,,,,,,,,,,,,,,,,, -114800,0.3552737,1.4868312,,,,,,,,,,,,,,,,, -114900,0.35031328,1.5235276,,,,,,,,,,,,,,,,, -115000,0.34828553,1.517675,,,,,,,,,,,,,,,,, -115100,0.34582123,1.5147865,,,,,,,,,,,,,,,,, -115200,0.35614774,1.4821653,,,,,,,,,,,,,,,,, -115300,0.38368195,1.5502185,,,,,,,,,,,,,,,,, -115400,0.36033726,1.5330595,,,,,,,,,,,,,,,,, -115500,0.32131127,1.4864755,,,,,,,,,,,,,,,,, -115600,0.3891836,1.5111698,,,,,,,,,,,,,,,,, -115700,0.36471248,1.5315515,,,,,,,,,,,,,,,,, -115800,0.35655692,1.4819798,,,,,,,,,,,,,,,,, -115900,0.3343637,1.4796388,,,,,,,,,,,,,,,,, -116000,0.3555671,1.5469141,,,,,,,,,,,,,,,,, -116044,,,0.677942156791687,1.5225650072097778,34.21801980482632,0.6876294016838074,1.43970787525177,30.542062526187728,3000.0,0.7033408880233765,1.350867509841919,30.47281998498783,3003.0,40353.66239070892,67777.30443549156,40353.66239070892,27418.13508272171,1.6672179698944092,0.0 -116100,0.37433875,1.4482063,,,,,,,,,,,,,,,,, -116200,0.34316322,1.4973197,,,,,,,,,,,,,,,,, -116300,0.34952268,1.4404076,,,,,,,,,,,,,,,,, -116400,0.34890127,1.4677321,,,,,,,,,,,,,,,,, -116500,0.361573,1.5741729,,,,,,,,,,,,,,,,, -116600,0.3665121,1.5504982,,,,,,,,,,,,,,,,, -116700,0.3415304,1.4764063,,,,,,,,,,,,,,,,, -116800,0.36020073,1.5700545,,,,,,,,,,,,,,,,, -116900,0.35555246,1.5498513,,,,,,,,,,,,,,,,, -117000,0.35945845,1.5612079,,,,,,,,,,,,,,,,, -117100,0.35187092,1.4646826,,,,,,,,,,,,,,,,, -117200,0.3711745,1.5279404,,,,,,,,,,,,,,,,, -117300,0.37590018,1.4378668,,,,,,,,,,,,,,,,, -117400,0.3784086,1.6012605,,,,,,,,,,,,,,,,, -117500,0.35246852,1.5355262,,,,,,,,,,,,,,,,, -117600,0.36237445,1.4604431,,,,,,,,,,,,,,,,, -117700,0.38126606,1.6175327,,,,,,,,,,,,,,,,, -117800,0.38108322,1.5274105,,,,,,,,,,,,,,,,, -117900,0.33763444,1.4593024,,,,,,,,,,,,,,,,, -118000,0.34991825,1.4340925,,,,,,,,,,,,,,,,, -118100,0.3549599,1.4363259,,,,,,,,,,,,,,,,, -118200,0.35594624,1.4872217,,,,,,,,,,,,,,,,, -118300,0.36614436,1.4849014,,,,,,,,,,,,,,,,, -118400,0.35920134,1.4765667,,,,,,,,,,,,,,,,, -118462,,,0.6817564964294434,1.5030604600906372,34.63627435914792,0.6891049146652222,1.429511785507202,30.23524154073049,3000.0,0.705479085445404,1.3375895023345947,30.77297665042285,3003.0,41193.64802837372,69187.57227706909,41193.64802837372,27988.293677330017,1.7116749286651611,0.0 -118500,0.3997778,1.493548,,,,,,,,,,,,,,,,, -118600,0.36242208,1.5352831,,,,,,,,,,,,,,,,, -118700,0.38131022,1.5662192,,,,,,,,,,,,,,,,, -118800,0.34830976,1.4701382,,,,,,,,,,,,,,,,, -118900,0.35361785,1.4865949,,,,,,,,,,,,,,,,, -119000,0.37308142,1.4683868,,,,,,,,,,,,,,,,, -119100,0.35806248,1.4505892,,,,,,,,,,,,,,,,, -119200,0.35704264,1.4222769,,,,,,,,,,,,,,,,, -119300,0.36790463,1.438412,,,,,,,,,,,,,,,,, -119400,0.35353452,1.4343642,,,,,,,,,,,,,,,,, -119500,0.3541425,1.4297727,,,,,,,,,,,,,,,,, -119600,0.37359142,1.4368355,,,,,,,,,,,,,,,,, -119700,0.35949162,1.5162232,,,,,,,,,,,,,,,,, -119800,0.37552428,1.4660815,,,,,,,,,,,,,,,,, -119900,0.39539924,1.4375981,,,,,,,,,,,,,,,,, -120000,0.35544068,1.4561137,,,,,,,,,,,,,,,,, -120100,0.38429025,1.5424533,,,,,,,,,,,,,,,,, -120200,0.36463037,1.4909648,,,,,,,,,,,,,,,,, -120300,0.3784323,1.44659,,,,,,,,,,,,,,,,, -120400,0.37926987,1.5319667,,,,,,,,,,,,,,,,, -120500,0.37685713,1.4482856,,,,,,,,,,,,,,,,, -120600,0.37915555,1.5018485,,,,,,,,,,,,,,,,, -120700,0.3668942,1.4101299,,,,,,,,,,,,,,,,, -120800,0.35882676,1.4229112,,,,,,,,,,,,,,,,, -120881,,,0.6928209066390991,1.4292106628417969,35.03023370942456,0.6907663941383362,1.4212204217910769,30.38323877714415,3000.0,0.7065946459770203,1.3328224420547483,30.69767269998512,3003.0,42033.85468816757,70600.94706273079,42033.85468816757,28561.33699631691,1.754795789718628,0.0 -120900,0.3747279,1.4458164,,,,,,,,,,,,,,,,, -121000,0.37508962,1.4441499,,,,,,,,,,,,,,,,, -121100,0.38244337,1.4426877,,,,,,,,,,,,,,,,, -121200,0.37221754,1.4668962,,,,,,,,,,,,,,,,, -121300,0.3609252,1.4827667,,,,,,,,,,,,,,,,, -121400,0.39186576,1.4813279,,,,,,,,,,,,,,,,, -121500,0.39140844,1.4961321,,,,,,,,,,,,,,,,, -121600,0.37840068,1.4765725,,,,,,,,,,,,,,,,, -121700,0.36517113,1.4378371,,,,,,,,,,,,,,,,, -121800,0.3886322,1.5301217,,,,,,,,,,,,,,,,, -121900,0.3805156,1.5083454,,,,,,,,,,,,,,,,, -122000,0.39028144,1.4562124,,,,,,,,,,,,,,,,, -122100,0.37710437,1.4421202,,,,,,,,,,,,,,,,, -122200,0.3904903,1.5433505,,,,,,,,,,,,,,,,, -122300,0.3860507,1.4597269,,,,,,,,,,,,,,,,, -122400,0.35899368,1.4609215,,,,,,,,,,,,,,,,, -122500,0.38523746,1.4372185,,,,,,,,,,,,,,,,, -122600,0.3998583,1.4681122,,,,,,,,,,,,,,,,, -122700,0.39231244,1.5176978,,,,,,,,,,,,,,,,, -122800,0.369423,1.4280005,,,,,,,,,,,,,,,,, -122900,0.38293758,1.3984653,,,,,,,,,,,,,,,,, -123000,0.3898376,1.4677706,,,,,,,,,,,,,,,,, -123100,0.3766933,1.4552594,,,,,,,,,,,,,,,,, -123200,0.37777635,1.4067116,,,,,,,,,,,,,,,,, -123300,,,0.692010223865509,1.4411240816116333,35.203263862741345,0.6924154758453369,1.4140108823776243,30.84412796852067,3000.0,0.7075010538101196,1.3267334699630735,30.729561965493104,3003.0,42874.04085731506,72031.04406666756,42874.04085731506,29151.12611794472,1.7970449924468994,0.0 -123300,0.38892558,1.4318054,,,,,,,,,,,,,,,,, -123400,0.3733278,1.4953674,,,,,,,,,,,,,,,,, -123500,0.4056068,1.4840792,,,,,,,,,,,,,,,,, -123600,0.37489712,1.4919945,,,,,,,,,,,,,,,,, -123700,0.38782576,1.4701902,,,,,,,,,,,,,,,,, -123800,0.37853903,1.3966833,,,,,,,,,,,,,,,,, -123900,0.39609993,1.4658906,,,,,,,,,,,,,,,,, -124000,0.38290587,1.442923,,,,,,,,,,,,,,,,, -124100,0.38364685,1.4661909,,,,,,,,,,,,,,,,, -124200,0.39213106,1.4302307,,,,,,,,,,,,,,,,, -124300,0.3754113,1.4543893,,,,,,,,,,,,,,,,, -124400,0.3904761,1.4180948,,,,,,,,,,,,,,,,, -124500,0.37009865,1.414671,,,,,,,,,,,,,,,,, -124600,0.39869326,1.477522,,,,,,,,,,,,,,,,, -124700,0.39403665,1.4422296,,,,,,,,,,,,,,,,, -124800,0.39803833,1.4178461,,,,,,,,,,,,,,,,, -124900,0.38185024,1.4414074,,,,,,,,,,,,,,,,, -125000,0.39660606,1.4848871,,,,,,,,,,,,,,,,, -125100,0.3780226,1.4766806,,,,,,,,,,,,,,,,, -125200,0.38057992,1.4062608,,,,,,,,,,,,,,,,, -125300,0.3918326,1.3685338,,,,,,,,,,,,,,,,, -125400,0.39045995,1.4640228,,,,,,,,,,,,,,,,, -125500,0.4070698,1.4235129,,,,,,,,,,,,,,,,, -125600,0.3718375,1.4701067,,,,,,,,,,,,,,,,, -125700,0.39965287,1.4365951,,,,,,,,,,,,,,,,, -125718,,,0.698666512966156,1.3995305299758911,35.56513313304348,0.6939033269882202,1.4084949493408203,30.6555715813736,3000.0,0.708349347114563,1.3209024667739868,30.87662334271036,3003.0,43714.04322123528,73462.10220861435,43714.04322123528,29742.05605435372,1.842005014419556,0.0 -125800,0.3942465,1.4392499,,,,,,,,,,,,,,,,, -125900,0.3705948,1.3953079,,,,,,,,,,,,,,,,, -126000,0.3837414,1.4105955,,,,,,,,,,,,,,,,, -126100,0.4082367,1.4532869,,,,,,,,,,,,,,,,, -126200,0.39177814,1.4285012,,,,,,,,,,,,,,,,, -126300,0.36957192,1.3803958,,,,,,,,,,,,,,,,, -126400,0.3719603,1.3869212,,,,,,,,,,,,,,,,, -126500,0.40524474,1.4695127,,,,,,,,,,,,,,,,, -126600,0.394327,1.4531167,,,,,,,,,,,,,,,,, -126700,0.37482148,1.3687732,,,,,,,,,,,,,,,,, -126800,0.3975232,1.3999511,,,,,,,,,,,,,,,,, -126900,0.37567148,1.3820205,,,,,,,,,,,,,,,,, -127000,0.36396673,1.3954757,,,,,,,,,,,,,,,,, -127100,0.37591365,1.4017673,,,,,,,,,,,,,,,,, -127200,0.3712195,1.4152235,,,,,,,,,,,,,,,,, -127300,0.4144039,1.4299757,,,,,,,,,,,,,,,,, -127400,0.40054289,1.404774,,,,,,,,,,,,,,,,, -127500,0.38913646,1.4418007,,,,,,,,,,,,,,,,, -127600,0.37904412,1.3716407,,,,,,,,,,,,,,,,, -127700,0.37824026,1.4523326,,,,,,,,,,,,,,,,, -127800,0.39312297,1.398296,,,,,,,,,,,,,,,,, -127900,0.3640656,1.415519,,,,,,,,,,,,,,,,, -128000,0.38041863,1.4035251,,,,,,,,,,,,,,,,, -128100,0.37689862,1.4010372,,,,,,,,,,,,,,,,, -128136,,,0.6984717845916748,1.403459548950195,35.613569304448035,0.6934818029403687,1.4072446823120115,30.81339413073557,3000.0,0.710615336894989,1.315718173980713,30.94520762179943,3003.0,44554.06897211075,74888.634370327,44554.06897211075,30328.435676574707,1.886183738708496,0.0 -128200,0.3759452,1.3993527,,,,,,,,,,,,,,,,, -128300,0.39748347,1.4115423,,,,,,,,,,,,,,,,, -128400,0.38842046,1.3858825,,,,,,,,,,,,,,,,, -128500,0.39935946,1.4126428,,,,,,,,,,,,,,,,, -128600,0.3768891,1.3480326,,,,,,,,,,,,,,,,, -128700,0.3804613,1.3565513,,,,,,,,,,,,,,,,, -128800,0.384551,1.4509778,,,,,,,,,,,,,,,,, -128900,0.38983443,1.3862071,,,,,,,,,,,,,,,,, -129000,0.3971243,1.4224852,,,,,,,,,,,,,,,,, -129100,0.37276947,1.412551,,,,,,,,,,,,,,,,, -129200,0.38428903,1.4548928,,,,,,,,,,,,,,,,, -129300,0.3922965,1.4453108,,,,,,,,,,,,,,,,, -129400,0.38370526,1.4060992,,,,,,,,,,,,,,,,, -129500,0.37528068,1.347654,,,,,,,,,,,,,,,,, -129600,0.36534542,1.4132471,,,,,,,,,,,,,,,,, -129700,0.3844652,1.5135555,,,,,,,,,,,,,,,,, -129800,0.37916392,1.41684,,,,,,,,,,,,,,,,, -129900,0.36811924,1.3490669,,,,,,,,,,,,,,,,, -130000,0.36447835,1.4280537,,,,,,,,,,,,,,,,, -130100,0.38123694,1.4030951,,,,,,,,,,,,,,,,, -130200,0.37790757,1.4220332,,,,,,,,,,,,,,,,, -130300,0.39695057,1.4210768,,,,,,,,,,,,,,,,, -130400,0.38470367,1.409078,,,,,,,,,,,,,,,,, -130500,0.36987007,1.4133025,,,,,,,,,,,,,,,,, -130555,,,0.7004496455192566,1.3945019245147705,35.82390045814242,0.693221390247345,1.4052176475524902,30.76973079067687,3000.0,0.7104758620262146,1.3132108449935913,30.946965112698592,3003.0,45394.16031217575,76329.50495123863,45394.16031217575,30929.09312582016,1.9304907321929927,0.0 -130600,0.38805157,1.4169331,,,,,,,,,,,,,,,,, -130700,0.3905813,1.4208865,,,,,,,,,,,,,,,,, -130800,0.371848,1.461905,,,,,,,,,,,,,,,,, -130900,0.37572074,1.3709701,,,,,,,,,,,,,,,,, -131000,0.39388108,1.4033488,,,,,,,,,,,,,,,,, -131100,0.40026355,1.4004608,,,,,,,,,,,,,,,,, -131200,0.35727724,1.3262734,,,,,,,,,,,,,,,,, -131300,0.3976135,1.4137957,,,,,,,,,,,,,,,,, -131400,0.38481057,1.3645611,,,,,,,,,,,,,,,,, -131500,0.37925386,1.3891089,,,,,,,,,,,,,,,,, -131600,0.39513582,1.46887,,,,,,,,,,,,,,,,, -131700,0.40590668,1.456115,,,,,,,,,,,,,,,,, -131800,0.39003822,1.4672827,,,,,,,,,,,,,,,,, -131900,0.37113604,1.3648843,,,,,,,,,,,,,,,,, -132000,0.3865474,1.3823202,,,,,,,,,,,,,,,,, -132100,0.4075123,1.4178827,,,,,,,,,,,,,,,,, -132200,0.3703153,1.4213194,,,,,,,,,,,,,,,,, -132300,0.410023,1.4823319,,,,,,,,,,,,,,,,, -132400,0.40367183,1.3871619,,,,,,,,,,,,,,,,, -132500,0.40058774,1.3875073,,,,,,,,,,,,,,,,, -132600,0.38163927,1.4642558,,,,,,,,,,,,,,,,, -132700,0.39631212,1.3478826,,,,,,,,,,,,,,,,, -132800,0.36158505,1.3053414,,,,,,,,,,,,,,,,, -132900,0.3689254,1.4173428,,,,,,,,,,,,,,,,, -132973,,,0.7006556391716003,1.3924342393875122,35.9691844802764,0.6937049627304077,1.4050370454788208,30.88978323109175,3000.0,0.7101853489875793,1.313064455986023,31.12658905369212,3003.0,46234.09957194328,77774.33254957199,46234.09957194328,31533.85618448257,1.974720478057861,0.0 -133000,0.3980262,1.4332103,,,,,,,,,,,,,,,,, -133100,0.3809121,1.419903,,,,,,,,,,,,,,,,, -133200,0.39325455,1.4737884,,,,,,,,,,,,,,,,, -133300,0.38796213,1.4189658,,,,,,,,,,,,,,,,, -133333,,,0.6965222358703613,1.4127204418182373,35.569271445766226,0.6937049627304077,1.4050407409667969,30.89082856801816,3000.0,0.7102318406105042,1.313091278076172,31.113830544833625,3003.0,46358.67147278786,78488.16630244255,46358.67147278786,32123.061244010925,2.0193135738372803,0.0 -133333,,,,,,,,,,,,,,46358.67147278786,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 6e008af5f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,59 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -859.7216436862946,0.0,31.81131482124329,1,0,31.81131482124329,0.0007088489946909,0.0,10.966498374938965,3003,891.5330049991608,0.0005722984205931,0.0,10.961396217346191,0.0004835649742744,0.0,10.980294227600098,3000 -1335.8473892211914,0.0207440853118896,871.7362320423126,2358,0,871.7362320423126,0.5146359801292419,17.058797586743704,2.8742642402648926,3003,2207.681263446808,0.5152115821838379,22.559558180281464,2.8529348373413086,0.5171913504600525,18.385519591984245,2.8354642391204834,3000 -1839.4309611320496,0.0474672317504882,1711.7513601779938,4718,0,1711.7513601779938,0.5932833552360535,22.21513088925912,2.123673677444458,3003,3551.384021759033,0.5785308480262756,27.16403077071468,2.2483465671539307,0.5911272168159485,23.69484945016536,2.150007486343384,3000 -2279.2648980617523,0.0735304355621337,2551.796592235565,7077,0,2551.796592235565,0.6244146227836609,24.14670156771929,1.8764768838882449,3003,4831.369480133057,0.6040958762168884,29.06062816167113,2.007240056991577,0.6170040965080261,25.44547280020576,1.9235773086547847,3000 -2732.8127586841583,0.1012864112854003,3391.7186181545258,9435,0,3391.7186181545258,0.6360118389129639,24.85918981164338,1.7616527080535889,3003,6124.946580171585,0.6132701635360718,29.630833378883725,1.943329811096192,0.6301347613334656,26.173138325326224,1.815695643424988,3000 -3206.662947416305,0.1285784244537353,4231.872063875198,11794,0,4231.872063875198,0.6483644247055054,25.8616843932732,1.6862653493881226,3003,7439.05501461029,0.6201885938644409,29.409620364240315,1.886154294013977,0.6406244039535522,26.7636870644405,1.745435357093811,3000 -3740.741831302643,0.1564745903015136,5072.088208436966,14153,0,5072.088208436966,0.6512695550918579,25.934985390513155,1.6413919925689695,3003,8813.45656490326,0.6258756518363953,30.30475543405984,1.830366611480713,0.6443069577217102,27.178247847671944,1.7058216333389282,3000 -4229.750396728516,0.1886544227600097,5912.17545747757,16512,0,5912.17545747757,0.6567195653915405,26.76462025361148,1.6103800535202026,3003,10142.664711236954,0.6253536343574524,30.42698447853597,1.8386945724487305,0.6469727754592896,27.51153504880521,1.6795895099639893,3000 -4817.370816230774,0.2168202400207519,6752.3742599487305,18871,0,6752.3742599487305,0.6603451371192932,26.534122605706614,1.5860060453414917,3003,11570.590962171556,0.6545403003692627,32.29815543479825,1.619762897491455,0.6512380242347717,27.504371133986982,1.6543937921524048,3000 -5347.3108031749725,0.2467694282531738,7592.311160326004,21228,0,7592.311160326004,0.6626227498054504,26.67511559058768,1.5686941146850586,3003,12940.57827091217,0.6335942149162292,30.58009748460818,1.776694416999817,0.6522423624992371,27.75300949580246,1.640693187713623,3000 -5933.192608118057,0.275264024734497,8432.473328590393,23587,0,8432.473328590393,0.6640636920928955,26.70121926683885,1.562119483947754,3003,14366.7304251194,0.6337989568710327,30.87184258005141,1.7762765884399414,0.6548957824707031,27.84559009417645,1.6313564777374268,3000 -6419.971329689026,0.303957462310791,9272.582363128662,25946,0,9272.582363128662,0.6673174500465393,27.51932513636056,1.539101481437683,3003,15693.724759817123,0.6462305188179016,31.452835525252222,1.6862976551055908,0.6573259830474854,27.836506053367746,1.6084153652191162,3000 -6881.22886633873,0.3334319591522217,10112.716992616652,28304,0,10112.716992616652,0.6687235236167908,27.162785054412023,1.5342345237731934,3003,16995.22621178627,0.6390774250030518,31.344862714431464,1.7401479482650757,0.6581319570541382,28.05053729965437,1.6097475290298462,3000 -7389.272298574448,0.3637864589691162,10952.624883651732,30662,0,10952.624883651732,0.6705711483955383,27.320115273429533,1.5236083269119265,3003,18343.28938150406,0.63572096824646,31.040252556013616,1.763773798942566,0.6615169048309326,28.108661415052204,1.5909042358398438,3000 -7953.593356847763,0.4070956707000732,11792.630442142488,33020,0,11792.630442142488,0.6711289286613464,27.68331961654764,1.5155937671661377,3003,19747.73938369751,0.6443125009536743,31.166148767869927,1.6968934535980225,0.6595454216003418,27.80920383950556,1.5900745391845703,3000 -8705.197973966599,0.4429934024810791,12632.544389486313,35378,0,12632.544389486313,0.6730231046676636,27.782769305202045,1.5036847591400146,3003,21339.37600016594,0.6399866938591003,31.33820875483425,1.7375435829162598,0.6620624661445618,28.38960980084261,1.5832661390304563,3000 -9295.258509159088,0.474111795425415,13472.63930630684,37736,0,13472.63930630684,0.6728371381759644,27.78853257605112,1.4999555349349976,3003,22769.643906354904,0.663627564907074,32.87796516096618,1.5535968542099,0.6638479232788086,28.38134980475151,1.5714770555496216,3000 -10041.635178804398,0.5051698684692383,14312.627776622772,40094,0,14312.627776622772,0.6751612424850464,28.055212927233068,1.4886205196380615,3003,24356.119666576385,0.6476908922195435,31.62556446241438,1.6861525774002075,0.6630792021751404,28.31053959711827,1.5618013143539429,3000 -10602.350082874298,0.5357956886291504,15152.739064216614,42452,0,15152.739064216614,0.6750218272209167,27.51980545574519,1.486008644104004,3003,25757.05531859398,0.6445221900939941,31.200524801237783,1.700629711151123,0.6668733358383179,28.73763681514912,1.5635435581207275,3000 -11158.329864501951,0.566993236541748,15992.64580154419,44811,0,15992.64580154419,0.6767880916595459,27.9387057645756,1.471121311187744,3003,27153.050884485245,0.6530870795249939,31.688902445897607,1.634215235710144,0.6673692464828491,28.50414605105004,1.5494558811187744,3000 -11643.594151735306,0.5991125106811523,16832.79411482811,47170,0,16832.79411482811,0.6807855367660522,28.13457178472897,1.45988130569458,3003,28478.575281381607,0.6481767892837524,31.457149841710663,1.6744226217269895,0.668088436126709,28.76677074342338,1.545769214630127,3000 -12202.3813393116,0.6326305866241455,17672.922026872635,49530,0,17672.922026872635,0.681494414806366,28.54769671956192,1.4549793004989624,3003,29877.601472377777,0.6455420851707458,31.77329398717936,1.7034410238265991,0.666240930557251,28.787587888853047,1.5409092903137207,3000 -12881.750809669496,0.6645634174346924,18513.092685222626,51890,0,18513.092685222626,0.6804369688034058,28.056858321750056,1.45551860332489,3003,31397.252323150635,0.652571439743042,32.39050117798098,1.6386526823043823,0.6687827706336975,28.55614230670583,1.5333878993988037,3000 -13559.124497413635,0.6962974071502686,19353.05917453766,54249,0,19353.05917453766,0.6831910014152527,28.418639510555767,1.4422683715820312,3003,32914.703605651855,0.6511799693107605,31.79512948258729,1.6562837362289429,0.6702210903167725,29.00607766441165,1.527982473373413,3000 -14192.409289360046,0.7298541069030762,20193.22402882576,56609,0,20193.22402882576,0.6822032332420349,28.28308659347989,1.4344021081924438,3003,34388.2636551857,0.6711536049842834,33.20045640260435,1.5183404684066772,0.6715229749679565,29.05587524880472,1.5183759927749634,3000 -14971.292563676834,0.7623600959777832,21033.412729263306,58968,0,21033.412729263306,0.6840625405311584,28.442912978921708,1.4318151473999023,3003,36007.44893741608,0.6562167406082153,32.00557628346472,1.6146293878555298,0.6700970530509949,28.6046113423605,1.5160971879959106,3000 -15518.967089891434,0.7969238758087158,21873.331847190857,61327,0,21873.331847190857,0.6860031485557556,28.75322589399797,1.4202231168746948,3003,37395.15630316734,0.6536709666252136,32.201144941719825,1.6435863971710205,0.67365562915802,29.024937312260363,1.5085053443908691,3000 -16095.283498048782,0.9187502861022948,22713.31698822975,63686,0,22713.31698822975,0.686851441860199,29.019938516980574,1.4139506816864014,3003,38811.658296108246,0.6629753708839417,33.01746848462258,1.568068265914917,0.6728744506835938,28.890694782084704,1.503533124923706,3000 -16623.977352142334,0.9587399959564208,23553.378241539,66045,0,23553.378241539,0.6908256411552429,29.21380668381085,1.397244215011597,3003,40180.534745931625,0.6583556532859802,32.41859989742564,1.5996129512786863,0.6759990453720093,29.61073050124265,1.4937245845794678,3000 -17171.25036072731,0.9950568675994872,24393.38697886467,68403,0,24393.38697886467,0.6906862258911133,28.9712420024266,1.391940951347351,3003,41567.93536186218,0.6556589603424072,32.53331232232554,1.61857008934021,0.675317108631134,29.18859526057956,1.4912937879562378,3000 -17874.665945768356,1.0316176414489746,25233.515276670456,70762,0,25233.515276670456,0.6927662491798401,29.39340593968376,1.3877900838851929,3003,43111.59431147576,0.6649782657623291,33.27862937609742,1.5656242370605469,0.6770901679992676,29.4606829320584,1.4805506467819214,3000 -18456.486599206924,1.0666627883911133,26073.75528979301,73121,0,26073.75528979301,0.6933588981628418,29.37666021325784,1.37883198261261,3003,44533.76924061775,0.6601581573486328,33.035396256855584,1.5850191116333008,0.6781564950942993,29.27033968859275,1.4745867252349854,3000 -19089.9792573452,1.1101996898651123,26913.816690921783,75480,0,26913.816690921783,0.6943350434303284,29.297592249645188,1.3725438117980957,3003,46007.44537067413,0.6792148947715759,33.80764561643805,1.4641389846801758,0.6794211864471436,29.555809985213795,1.4671216011047363,3000 -19886.381717681885,1.1533830165863037,27753.721145629883,77838,0,27753.721145629883,0.6929754614830017,29.21646204970246,1.3658453226089478,3003,47643.87626647949,0.6665636897087097,33.07441180399439,1.5482903718948364,0.68016517162323,29.50371177493752,1.459350347518921,3000 -20572.528188943863,1.1913011074066162,28593.74077177048,80197,0,28593.74077177048,0.6975771188735962,29.435464433153378,1.3530352115631104,3003,49170.16017580032,0.664623498916626,33.14200471817913,1.563930869102478,0.6821985840797424,29.735981389350844,1.4505184888839722,3000 -21310.009718179703,1.229506015777588,29433.655618667603,82557,0,29433.655618667603,0.6959037780761719,29.501150496830565,1.3536893129348757,3003,50747.67319107056,0.6770812273025513,34.02139815328523,1.4800242185592651,0.6815042495727539,29.45909323576052,1.4466660022735596,3000 -22040.064561843872,1.2669031620025637,30273.7618765831,84916,0,30273.7618765831,0.698843777179718,29.460929968466225,1.3375470638275146,3003,52317.95136976242,0.6718239784240723,33.53921780192777,1.5177863836288452,0.6841204762458801,29.954040391928764,1.4370832443237305,3000 -22610.86890363693,1.3044085502624512,31113.80547237396,87275,0,31113.80547237396,0.699785053730011,29.740632616160205,1.3343571424484253,3003,53728.91593122482,0.6702315211296082,33.57381970171942,1.5227426290512085,0.6851620078086853,30.223012227748367,1.4330618381500244,3000 -23232.804119586945,1.3434412479400637,31953.73739719391,89633,0,31953.73739719391,0.7009587287902832,29.58023972500876,1.3281270265579224,3003,55190.8995449543,0.6788753867149353,34.11023781864019,1.4781488180160522,0.6842444539070129,30.039313991534137,1.4293118715286257,3000 -23843.425915002823,1.3887052536010742,32793.78981876373,91992,0,32793.78981876373,0.7017953991889954,30.07968705955243,1.3203881978988647,3003,56641.69893527031,0.674981951713562,33.107271415973,1.497982621192932,0.6863523125648499,29.94934206485915,1.421521544456482,3000 -24446.093323946,1.426877498626709,33633.84631562233,94351,0,33633.84631562233,0.704317033290863,30.00001747016476,1.3113782405853271,3003,58084.538499593735,0.6841570734977722,34.501937700227685,1.4367806911468506,0.6876170039176941,30.217930488187925,1.4109599590301514,3000 -25121.57190084457,1.466022491455078,34473.7677295208,96710,0,34473.7677295208,0.7056649923324585,30.114120610001063,1.3058409690856934,3003,59600.05470824242,0.6783720254898071,34.322575911748935,1.473360896110535,0.6874558329582214,30.34860323432693,1.4094278812408447,3000 -25678.95575976372,1.5046508312225342,35313.945922613144,99070,0,35313.945922613144,0.7047004699707031,30.178915366389266,1.3031493425369265,3003,60997.73252558708,0.6781623959541321,33.968795052975345,1.4802968502044678,0.6889064908027649,30.316761453508192,1.4011088609695437,3000 -26332.34897923469,1.548503875732422,36153.99291825295,101430,0,36153.99291825295,0.7055255770683289,30.34395915705975,1.2928892374038696,3003,62491.293865680695,0.6870272159576416,34.51947111345885,1.419952392578125,0.6896132826805115,30.23538142734516,1.3957315683364868,3000 -26937.60739517212,1.5876469612121582,36993.88757181168,103788,0,36993.88757181168,0.7094416618347168,30.334914627920924,1.28492534160614,3003,63936.5658159256,0.6809467077255249,34.7810708399674,1.462726354598999,0.6901588439941406,30.3594091547342,1.3913047313690186,3000 -27505.089040517807,1.6410527229309082,37833.93421292305,106147,0,37833.93421292305,0.7098251581192017,30.52362190743765,1.2795196771621704,3003,65344.22820472717,0.687261164188385,34.439128262689984,1.4270676374435425,0.6915971040725708,30.57444599493856,1.3868253231048584,3000 -28149.444666147232,1.6813023090362549,38673.95379757881,108506,0,38673.95379757881,0.7099180817604065,30.54607527793401,1.2736002206802368,3003,66828.72117614746,0.6925578117370605,35.037549363106514,1.3921180963516235,0.6933826208114624,30.82730396732934,1.3761391639709473,3000 -28747.312561512,1.722498893737793,39513.917432546616,110865,0,39513.917432546616,0.7104293704032898,30.823263329626307,1.2698314189910889,3003,68266.67229032516,0.6871196031570435,34.8735757618847,1.423920512199402,0.6927750110626221,30.65946671057725,1.379041075706482,3000 -29395.77043747902,1.7632851600646973,40353.82215619087,113222,0,40353.82215619087,0.7110685110092163,30.57820511062706,1.2664237022399902,3003,69755.15716338158,0.6949504017829895,35.29129947522768,1.366610407829285,0.6940769553184509,30.822025096723628,1.373827338218689,3000 -30026.637234210968,1.8142666816711424,41193.96239209175,115581,0,41193.96239209175,0.7114287614822388,30.915548870642528,1.2623103857040403,3003,71226.29557180405,0.692932665348053,34.87558627803715,1.3845998048782349,0.6947588920593262,30.69022831385768,1.3719394207000732,3000 -30575.64464998245,1.8567523956298828,42033.97531580925,117940,0,42033.97531580925,0.7126837372779846,31.04379746971764,1.2599252462387085,3003,72615.43852734566,0.6903288960456848,35.18875162329316,1.406310796737671,0.6950441002845764,30.774742153044187,1.3678932189941406,3000 -31137.830321788788,1.8995048999786377,42873.95320272446,120299,0,42873.95320272446,0.7123816609382629,30.803521074857283,1.2575608491897583,3003,74017.72199606895,0.6964380145072937,35.66247077415662,1.3709895610809326,0.6952300667762756,30.875901980810767,1.3651654720306396,3000 -31734.78758907318,1.942683458328247,43713.98013854027,122659,0,43713.98013854027,0.7139387726783752,30.83641365004823,1.2523609399795532,3003,75454.82591438293,0.6943261623382568,35.10050834963898,1.382710337638855,0.6955276131629944,31.04729680749209,1.3625982999801636,3000 -32310.166967868805,1.9930167198181152,44553.89228606224,125019,0,44553.89228606224,0.713927149772644,30.98107247465367,1.2523263692855835,3003,76870.24442744255,0.6947394609451294,35.855026133964195,1.377518892288208,0.6967179775238037,30.87607412885397,1.361031174659729,3000 -32904.57277917862,2.0373668670654297,45394.04832172394,127379,0,45394.04832172394,0.7140550017356873,30.88051741175642,1.251923441886902,3003,78304.93042588234,0.6961644291877747,35.79156931045429,1.368328094482422,0.6965195536613464,30.94567958438413,1.3605613708496094,3000 -33483.43272519112,2.08833909034729,46234.16683530808,129739,0,46234.16683530808,0.7140898704528809,30.877321535383,1.2510143518447876,3003,79724.03761744499,0.6961685419082642,35.88213839829184,1.3700485229492188,0.6964327692985535,30.98190345178364,1.36008882522583,3000 -34078.76060009003,2.137282371520996,47074.08837556839,132099,0,47074.08837556839,0.7141827940940857,30.886567051931777,1.2505288124084473,3003,81159.41278767586,0.6940845251083374,35.896859808845385,1.3850692510604858,0.6964203715324402,30.921410299106665,1.359937071800232,3000 -34649.70331478119,2.181474447250366,47513.109431266785,133333,0,47513.109431266785,0.7141595482826233,30.925173570681565,1.2506054639816284,3003,82169.46109032631,0.6966027021408081,35.481369630996994,1.3738631010055542,0.6963583827018738,30.92441491532882,1.360085368156433,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/measurements.csv deleted file mode 100644 index e198fa685..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1394 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.866031,10.961176,,,,,,,,,,,,,,,,, -1,,,0.0005722984205931,10.961396217346191,0.0,0.0004835649742744,10.980294227600098,0.0,3000.0,0.0007088489946909,10.966498374938965,0.0,3003.0,31.81131482124329,891.5330049991608,31.81131482124329,859.7216436862946,0.0,0.0 -100,0.16309346,8.263269,,,,,,,,,,,,,,,,, -200,0.34484667,7.440059,,,,,,,,,,,,,,,,, -300,0.63377917,6.8457956,,,,,,,,,,,,,,,,, -400,0.38963336,6.378072,,,,,,,,,,,,,,,,, -500,0.48827317,5.8951907,,,,,,,,,,,,,,,,, -600,0.6352641,5.5688763,,,,,,,,,,,,,,,,, -700,0.4651748,5.335845,,,,,,,,,,,,,,,,, -800,0.6650674,5.0998588,,,,,,,,,,,,,,,,, -900,0.5952439,4.8158417,,,,,,,,,,,,,,,,, -1000,0.45355114,4.576095,,,,,,,,,,,,,,,,, -1100,0.74068767,4.2885394,,,,,,,,,,,,,,,,, -1200,0.47756234,4.0769176,,,,,,,,,,,,,,,,, -1300,0.4636505,3.9753165,,,,,,,,,,,,,,,,, -1400,0.44879144,3.8670514,,,,,,,,,,,,,,,,, -1500,0.5330285,3.6312335,,,,,,,,,,,,,,,,, -1600,0.4711977,3.5142868,,,,,,,,,,,,,,,,, -1700,0.530964,3.4039948,,,,,,,,,,,,,,,,, -1800,0.4621943,3.3233066,,,,,,,,,,,,,,,,, -1900,0.40230772,3.224062,,,,,,,,,,,,,,,,, -2000,0.45852113,3.2079709,,,,,,,,,,,,,,,,, -2100,0.5029061,3.1433392,,,,,,,,,,,,,,,,, -2200,0.41940093,3.1078396,,,,,,,,,,,,,,,,, -2300,0.42138833,3.0280933,,,,,,,,,,,,,,,,, -2358,,,0.5152115821838379,2.8529348373413086,22.559558180281464,0.5171913504600525,2.8354642391204834,18.385519591984245,3000.0,0.5146359801292419,2.8742642402648926,17.058797586743704,3003.0,871.7362320423126,2207.681263446808,871.7362320423126,1335.8473892211914,0.0207440853118896,0.0 -2400,0.46809822,2.9291873,,,,,,,,,,,,,,,,, -2500,0.32862297,3.0117266,,,,,,,,,,,,,,,,, -2600,0.28408188,2.78312,,,,,,,,,,,,,,,,, -2700,0.3019398,2.8127127,,,,,,,,,,,,,,,,, -2800,0.3594723,2.807623,,,,,,,,,,,,,,,,, -2900,0.32453093,2.731549,,,,,,,,,,,,,,,,, -3000,0.31618467,2.7165205,,,,,,,,,,,,,,,,, -3100,0.25403133,2.7085886,,,,,,,,,,,,,,,,, -3200,0.22609653,2.6690097,,,,,,,,,,,,,,,,, -3300,0.2506487,2.5562294,,,,,,,,,,,,,,,,, -3400,0.24169265,2.5666554,,,,,,,,,,,,,,,,, -3500,0.22280082,2.5251405,,,,,,,,,,,,,,,,, -3600,0.20590897,2.590582,,,,,,,,,,,,,,,,, -3700,0.23239714,2.521679,,,,,,,,,,,,,,,,, -3800,0.21177395,2.4962928,,,,,,,,,,,,,,,,, -3900,0.20283979,2.4602695,,,,,,,,,,,,,,,,, -4000,0.2078164,2.4844992,,,,,,,,,,,,,,,,, -4100,0.2269506,2.458177,,,,,,,,,,,,,,,,, -4200,0.20316899,2.4283948,,,,,,,,,,,,,,,,, -4300,0.17778632,2.3816006,,,,,,,,,,,,,,,,, -4400,0.18044074,2.3866296,,,,,,,,,,,,,,,,, -4500,0.20165826,2.3364925,,,,,,,,,,,,,,,,, -4600,0.18896493,2.3311348,,,,,,,,,,,,,,,,, -4700,0.19070664,2.3162134,,,,,,,,,,,,,,,,, -4718,,,0.5785308480262756,2.2483465671539307,27.16403077071468,0.5911272168159485,2.150007486343384,23.69484945016536,3000.0,0.5932833552360535,2.123673677444458,22.21513088925912,3003.0,1711.7513601779938,3551.384021759033,1711.7513601779938,1839.4309611320496,0.0474672317504882,0.0 -4800,0.20917065,2.3020875,,,,,,,,,,,,,,,,, -4900,0.18689896,2.3381739,,,,,,,,,,,,,,,,, -5000,0.16452,2.3209057,,,,,,,,,,,,,,,,, -5100,0.16782556,2.268517,,,,,,,,,,,,,,,,, -5200,0.1513035,2.2577548,,,,,,,,,,,,,,,,, -5300,0.16595276,2.2384012,,,,,,,,,,,,,,,,, -5400,0.1583696,2.2457268,,,,,,,,,,,,,,,,, -5500,0.15449373,2.2876854,,,,,,,,,,,,,,,,, -5600,0.21273924,2.2747674,,,,,,,,,,,,,,,,, -5700,0.17878889,2.2550042,,,,,,,,,,,,,,,,, -5800,0.18004844,2.2220423,,,,,,,,,,,,,,,,, -5900,0.1778948,2.21628,,,,,,,,,,,,,,,,, -6000,0.16484803,2.213916,,,,,,,,,,,,,,,,, -6100,0.16596551,2.2321835,,,,,,,,,,,,,,,,, -6200,0.19912069,2.2589927,,,,,,,,,,,,,,,,, -6300,0.16659614,2.157328,,,,,,,,,,,,,,,,, -6400,0.1455757,2.1810706,,,,,,,,,,,,,,,,, -6500,0.19331945,2.2175257,,,,,,,,,,,,,,,,, -6600,0.17526914,2.260029,,,,,,,,,,,,,,,,, -6700,0.14389212,2.1827142,,,,,,,,,,,,,,,,, -6800,0.17818093,2.1572938,,,,,,,,,,,,,,,,, -6900,0.18820867,2.206259,,,,,,,,,,,,,,,,, -7000,0.14213485,2.1636498,,,,,,,,,,,,,,,,, -7077,,,0.6040958762168884,2.007240056991577,29.06062816167113,0.6170040965080261,1.9235773086547847,25.44547280020576,3000.0,0.6244146227836609,1.8764768838882449,24.14670156771929,3003.0,2551.796592235565,4831.369480133057,2551.796592235565,2279.2648980617523,0.0735304355621337,0.0 -7100,0.21015613,2.218535,,,,,,,,,,,,,,,,, -7200,0.16041522,2.1273296,,,,,,,,,,,,,,,,, -7300,0.19892345,2.0766163,,,,,,,,,,,,,,,,, -7400,0.14810877,2.0534856,,,,,,,,,,,,,,,,, -7500,0.18310077,2.082099,,,,,,,,,,,,,,,,, -7600,0.15552849,2.0365977,,,,,,,,,,,,,,,,, -7700,0.18913744,2.091416,,,,,,,,,,,,,,,,, -7800,0.24312572,2.1345634,,,,,,,,,,,,,,,,, -7900,0.17842914,2.179497,,,,,,,,,,,,,,,,, -8000,0.15164964,2.061892,,,,,,,,,,,,,,,,, -8100,0.1736087,2.0813632,,,,,,,,,,,,,,,,, -8200,0.15221249,2.0920906,,,,,,,,,,,,,,,,, -8300,0.20023063,2.064548,,,,,,,,,,,,,,,,, -8400,0.17637613,2.0845857,,,,,,,,,,,,,,,,, -8500,0.16893207,1.9522544,,,,,,,,,,,,,,,,, -8600,0.16594113,1.9766757,,,,,,,,,,,,,,,,, -8700,0.17702359,2.0598497,,,,,,,,,,,,,,,,, -8800,0.18111745,1.9731094,,,,,,,,,,,,,,,,, -8900,0.19888593,1.996583,,,,,,,,,,,,,,,,, -9000,0.15271403,2.1430452,,,,,,,,,,,,,,,,, -9100,0.25799263,2.0542324,,,,,,,,,,,,,,,,, -9200,0.16838887,2.143006,,,,,,,,,,,,,,,,, -9300,0.16201797,2.0320706,,,,,,,,,,,,,,,,, -9400,0.20565586,2.1019192,,,,,,,,,,,,,,,,, -9435,,,0.6132701635360718,1.943329811096192,29.630833378883725,0.6301347613334656,1.815695643424988,26.173138325326224,3000.0,0.6360118389129639,1.7616527080535889,24.85918981164338,3003.0,3391.7186181545258,6124.946580171585,3391.7186181545258,2732.8127586841583,0.1012864112854003,0.0 -9500,0.15386544,1.9992334,,,,,,,,,,,,,,,,, -9600,0.15350881,2.0080078,,,,,,,,,,,,,,,,, -9700,0.24001028,1.9384094,,,,,,,,,,,,,,,,, -9800,0.16338375,2.006362,,,,,,,,,,,,,,,,, -9900,0.17941631,1.9906831,,,,,,,,,,,,,,,,, -10000,0.1977977,1.9645184,,,,,,,,,,,,,,,,, -10100,0.2604733,2.0903,,,,,,,,,,,,,,,,, -10200,0.28964147,1.9869001,,,,,,,,,,,,,,,,, -10300,0.22336836,1.9960157,,,,,,,,,,,,,,,,, -10400,0.16861652,1.9987191,,,,,,,,,,,,,,,,, -10500,0.23260105,2.1103888,,,,,,,,,,,,,,,,, -10600,0.24001043,2.0062425,,,,,,,,,,,,,,,,, -10700,0.23164855,2.0749755,,,,,,,,,,,,,,,,, -10800,0.21437147,2.0070214,,,,,,,,,,,,,,,,, -10900,0.21335632,1.9866209,,,,,,,,,,,,,,,,, -11000,0.24619898,2.1185837,,,,,,,,,,,,,,,,, -11100,0.18509738,1.9145396,,,,,,,,,,,,,,,,, -11200,0.2194634,1.9238523,,,,,,,,,,,,,,,,, -11300,0.24812701,1.9947767,,,,,,,,,,,,,,,,, -11400,0.28714576,1.9706138,,,,,,,,,,,,,,,,, -11500,0.19422245,1.9800758,,,,,,,,,,,,,,,,, -11600,0.2061132,1.9546906,,,,,,,,,,,,,,,,, -11700,0.19336781,2.010086,,,,,,,,,,,,,,,,, -11794,,,0.6201885938644409,1.886154294013977,29.409620364240315,0.6406244039535522,1.745435357093811,26.7636870644405,3000.0,0.6483644247055054,1.6862653493881226,25.8616843932732,3003.0,4231.872063875198,7439.05501461029,4231.872063875198,3206.662947416305,0.1285784244537353,0.0 -11800,0.1906168,1.8808306,,,,,,,,,,,,,,,,, -11900,0.29704967,1.9859356,,,,,,,,,,,,,,,,, -12000,0.34235215,2.008742,,,,,,,,,,,,,,,,, -12100,0.32518232,1.9153833,,,,,,,,,,,,,,,,, -12200,0.22169285,1.9798753,,,,,,,,,,,,,,,,, -12300,0.2149191,1.8880876,,,,,,,,,,,,,,,,, -12400,0.17589325,1.9378049,,,,,,,,,,,,,,,,, -12500,0.20422806,2.0374045,,,,,,,,,,,,,,,,, -12600,0.20603651,1.9780688,,,,,,,,,,,,,,,,, -12700,0.18255591,2.0237052,,,,,,,,,,,,,,,,, -12800,0.24972752,1.977906,,,,,,,,,,,,,,,,, -12900,0.18204543,1.8767847,,,,,,,,,,,,,,,,, -13000,0.1926454,1.8986783,,,,,,,,,,,,,,,,, -13100,0.27402025,1.9476126,,,,,,,,,,,,,,,,, -13200,0.2176452,1.9534653,,,,,,,,,,,,,,,,, -13300,0.21245176,1.913291,,,,,,,,,,,,,,,,, -13400,0.25957048,1.9682316,,,,,,,,,,,,,,,,, -13500,0.21244228,1.9192736,,,,,,,,,,,,,,,,, -13600,0.191954,1.917769,,,,,,,,,,,,,,,,, -13700,0.2356224,1.9425324,,,,,,,,,,,,,,,,, -13800,0.1982384,1.8524766,,,,,,,,,,,,,,,,, -13900,0.19337927,2.0237763,,,,,,,,,,,,,,,,, -14000,0.20508221,1.9195734,,,,,,,,,,,,,,,,, -14100,0.19756648,1.9456587,,,,,,,,,,,,,,,,, -14153,,,0.6258756518363953,1.830366611480713,30.30475543405984,0.6443069577217102,1.7058216333389282,27.178247847671944,3000.0,0.6512695550918579,1.6413919925689695,25.934985390513155,3003.0,5072.088208436966,8813.45656490326,5072.088208436966,3740.741831302643,0.1564745903015136,0.0 -14200,0.20829357,1.9475776,,,,,,,,,,,,,,,,, -14300,0.19653925,1.8561583,,,,,,,,,,,,,,,,, -14400,0.19306305,1.9345762,,,,,,,,,,,,,,,,, -14500,0.24908943,1.8587669,,,,,,,,,,,,,,,,, -14600,0.17903492,1.9429394,,,,,,,,,,,,,,,,, -14700,0.17559528,1.9690939,,,,,,,,,,,,,,,,, -14800,0.21108143,1.9023149,,,,,,,,,,,,,,,,, -14900,0.18914375,1.9567851,,,,,,,,,,,,,,,,, -15000,0.19802034,1.8915023,,,,,,,,,,,,,,,,, -15100,0.18074097,1.9682541,,,,,,,,,,,,,,,,, -15200,0.21429877,1.9884305,,,,,,,,,,,,,,,,, -15300,0.21472853,1.9321511,,,,,,,,,,,,,,,,, -15400,0.31260318,1.8769925,,,,,,,,,,,,,,,,, -15500,0.25784236,1.854002,,,,,,,,,,,,,,,,, -15600,0.18759127,1.9466214,,,,,,,,,,,,,,,,, -15700,0.35300094,1.9720316,,,,,,,,,,,,,,,,, -15800,0.30034083,1.8897188,,,,,,,,,,,,,,,,, -15900,0.21895668,1.9317446,,,,,,,,,,,,,,,,, -16000,0.19629802,1.8071861,,,,,,,,,,,,,,,,, -16100,0.20386049,1.9214169,,,,,,,,,,,,,,,,, -16200,0.20534554,1.8319465,,,,,,,,,,,,,,,,, -16300,0.17931035,1.8703532,,,,,,,,,,,,,,,,, -16400,0.3391199,1.8476914,,,,,,,,,,,,,,,,, -16500,0.27231988,1.9599407,,,,,,,,,,,,,,,,, -16512,,,0.6253536343574524,1.8386945724487305,30.42698447853597,0.6469727754592896,1.6795895099639893,27.51153504880521,3000.0,0.6567195653915405,1.6103800535202026,26.76462025361148,3003.0,5912.17545747757,10142.664711236954,5912.17545747757,4229.750396728516,0.1886544227600097,0.0 -16600,0.23265961,1.7985064,,,,,,,,,,,,,,,,, -16700,0.19525062,1.949576,,,,,,,,,,,,,,,,, -16800,0.19910397,1.8409328,,,,,,,,,,,,,,,,, -16900,0.2720403,1.9396869,,,,,,,,,,,,,,,,, -17000,0.22710423,2.015688,,,,,,,,,,,,,,,,, -17100,0.19898507,1.9501898,,,,,,,,,,,,,,,,, -17200,0.26104483,1.8404071,,,,,,,,,,,,,,,,, -17300,0.26500252,1.9279104,,,,,,,,,,,,,,,,, -17400,0.3527575,1.8498534,,,,,,,,,,,,,,,,, -17500,0.19149709,1.8982055,,,,,,,,,,,,,,,,, -17600,0.18068667,1.9259887,,,,,,,,,,,,,,,,, -17700,0.1989299,1.8465642,,,,,,,,,,,,,,,,, -17800,0.17227298,1.8765138,,,,,,,,,,,,,,,,, -17900,0.1873738,1.8423976,,,,,,,,,,,,,,,,, -18000,0.18686497,1.8788741,,,,,,,,,,,,,,,,, -18100,0.19610122,1.8153204,,,,,,,,,,,,,,,,, -18200,0.19187926,1.9728004,,,,,,,,,,,,,,,,, -18300,0.16664658,1.8051051,,,,,,,,,,,,,,,,, -18400,0.24959555,1.8803194,,,,,,,,,,,,,,,,, -18500,0.22688007,1.8888435,,,,,,,,,,,,,,,,, -18600,0.19704984,1.886221,,,,,,,,,,,,,,,,, -18700,0.20751879,1.8174821,,,,,,,,,,,,,,,,, -18800,0.18843083,1.907529,,,,,,,,,,,,,,,,, -18871,,,0.6545403003692627,1.619762897491455,32.29815543479825,0.6512380242347717,1.6543937921524048,27.504371133986982,3000.0,0.6603451371192932,1.5860060453414917,26.534122605706614,3003.0,6752.3742599487305,11570.590962171556,6752.3742599487305,4817.370816230774,0.2168202400207519,0.0 -18900,0.20568012,1.7892083,,,,,,,,,,,,,,,,, -19000,0.19560184,1.8096195,,,,,,,,,,,,,,,,, -19100,0.25422397,1.876677,,,,,,,,,,,,,,,,, -19200,0.20807967,1.8553109,,,,,,,,,,,,,,,,, -19300,0.31273088,1.7790062,,,,,,,,,,,,,,,,, -19400,0.28523865,1.8937061,,,,,,,,,,,,,,,,, -19500,0.27564934,1.779665,,,,,,,,,,,,,,,,, -19600,0.19848904,1.9594862,,,,,,,,,,,,,,,,, -19700,0.18582794,1.8651284,,,,,,,,,,,,,,,,, -19800,0.20540927,1.8581355,,,,,,,,,,,,,,,,, -19900,0.21003863,1.9102209,,,,,,,,,,,,,,,,, -20000,0.20202354,1.848754,,,,,,,,,,,,,,,,, -20100,0.18243042,1.7825136,,,,,,,,,,,,,,,,, -20200,0.19149604,1.8059156,,,,,,,,,,,,,,,,, -20300,0.23332338,1.8910589,,,,,,,,,,,,,,,,, -20400,0.19475845,1.9268978,,,,,,,,,,,,,,,,, -20500,0.20131253,1.8334315,,,,,,,,,,,,,,,,, -20600,0.20090096,1.7386107,,,,,,,,,,,,,,,,, -20700,0.20185393,1.7614102,,,,,,,,,,,,,,,,, -20800,0.4619403,1.8735025,,,,,,,,,,,,,,,,, -20900,0.2284489,1.8211553,,,,,,,,,,,,,,,,, -21000,0.20889445,1.8347936,,,,,,,,,,,,,,,,, -21100,0.18629594,1.8619822,,,,,,,,,,,,,,,,, -21200,0.19214542,1.8954407,,,,,,,,,,,,,,,,, -21228,,,0.6335942149162292,1.776694416999817,30.58009748460818,0.6522423624992371,1.640693187713623,27.75300949580246,3000.0,0.6626227498054504,1.5686941146850586,26.67511559058768,3003.0,7592.311160326004,12940.57827091217,7592.311160326004,5347.3108031749725,0.2467694282531738,0.0 -21300,0.22432247,1.8865528,,,,,,,,,,,,,,,,, -21400,0.3145957,1.8813764,,,,,,,,,,,,,,,,, -21500,0.21561074,1.8564104,,,,,,,,,,,,,,,,, -21600,0.18998653,1.7826424,,,,,,,,,,,,,,,,, -21700,0.21394649,1.7443115,,,,,,,,,,,,,,,,, -21800,0.21167423,1.8269587,,,,,,,,,,,,,,,,, -21900,0.18998963,1.9051269,,,,,,,,,,,,,,,,, -22000,0.20611532,1.9785802,,,,,,,,,,,,,,,,, -22100,0.23965986,1.9232842,,,,,,,,,,,,,,,,, -22200,0.2213328,1.9148903,,,,,,,,,,,,,,,,, -22300,0.18757504,1.8796076,,,,,,,,,,,,,,,,, -22400,0.26202485,1.9219637,,,,,,,,,,,,,,,,, -22500,0.22985163,1.874367,,,,,,,,,,,,,,,,, -22600,0.20034786,1.8707111,,,,,,,,,,,,,,,,, -22700,0.249991,1.8175789,,,,,,,,,,,,,,,,, -22800,0.21657924,1.8054082,,,,,,,,,,,,,,,,, -22900,0.2201951,1.8628447,,,,,,,,,,,,,,,,, -23000,0.19239607,1.8311514,,,,,,,,,,,,,,,,, -23100,0.21068116,1.8183566,,,,,,,,,,,,,,,,, -23200,0.2001137,1.8622912,,,,,,,,,,,,,,,,, -23300,0.20023945,1.8503661,,,,,,,,,,,,,,,,, -23400,0.21330108,1.902196,,,,,,,,,,,,,,,,, -23500,0.19693962,1.8018252,,,,,,,,,,,,,,,,, -23587,,,0.6337989568710327,1.7762765884399414,30.87184258005141,0.6548957824707031,1.6313564777374268,27.84559009417645,3000.0,0.6640636920928955,1.562119483947754,26.70121926683885,3003.0,8432.473328590393,14366.7304251194,8432.473328590393,5933.192608118057,0.275264024734497,0.0 -23600,0.22656168,1.9003037,,,,,,,,,,,,,,,,, -23700,0.3306062,1.8019749,,,,,,,,,,,,,,,,, -23800,0.22867778,1.8223171,,,,,,,,,,,,,,,,, -23900,0.19733717,1.8645321,,,,,,,,,,,,,,,,, -24000,0.25278023,1.8367177,,,,,,,,,,,,,,,,, -24100,0.2003951,1.8366451,,,,,,,,,,,,,,,,, -24200,0.33024168,1.8289201,,,,,,,,,,,,,,,,, -24300,0.20330366,1.8534682,,,,,,,,,,,,,,,,, -24400,0.30720535,1.7697464,,,,,,,,,,,,,,,,, -24500,0.2461305,1.7855588,,,,,,,,,,,,,,,,, -24600,0.22799556,1.8783233,,,,,,,,,,,,,,,,, -24700,0.26204818,1.7898071,,,,,,,,,,,,,,,,, -24800,0.20140792,1.8012687,,,,,,,,,,,,,,,,, -24900,0.2286541,1.8735392,,,,,,,,,,,,,,,,, -25000,0.20941193,1.7383173,,,,,,,,,,,,,,,,, -25100,0.20280917,1.8624775,,,,,,,,,,,,,,,,, -25200,0.22275893,1.8973423,,,,,,,,,,,,,,,,, -25300,0.21048258,1.7876027,,,,,,,,,,,,,,,,, -25400,0.18632111,1.7866497,,,,,,,,,,,,,,,,, -25500,0.20036818,1.7646035,,,,,,,,,,,,,,,,, -25600,0.24197501,1.8005202,,,,,,,,,,,,,,,,, -25700,0.21009612,1.7960391,,,,,,,,,,,,,,,,, -25800,0.2532353,1.8724462,,,,,,,,,,,,,,,,, -25900,0.2932069,1.7417228,,,,,,,,,,,,,,,,, -25946,,,0.6462305188179016,1.6862976551055908,31.452835525252222,0.6573259830474854,1.6084153652191162,27.836506053367746,3000.0,0.6673174500465393,1.539101481437683,27.51932513636056,3003.0,9272.582363128662,15693.724759817123,9272.582363128662,6419.971329689026,0.303957462310791,0.0 -26000,0.22109625,1.83417,,,,,,,,,,,,,,,,, -26100,0.23932074,1.858033,,,,,,,,,,,,,,,,, -26200,0.18466999,1.8555486,,,,,,,,,,,,,,,,, -26300,0.23819897,1.8243734,,,,,,,,,,,,,,,,, -26400,0.23205225,1.8161929,,,,,,,,,,,,,,,,, -26500,0.21470055,1.8351684,,,,,,,,,,,,,,,,, -26600,0.19882411,1.8242563,,,,,,,,,,,,,,,,, -26700,0.33054343,1.7601248,,,,,,,,,,,,,,,,, -26800,0.19930384,1.8236171,,,,,,,,,,,,,,,,, -26900,0.21263416,1.8311491,,,,,,,,,,,,,,,,, -27000,0.22560094,1.8328009,,,,,,,,,,,,,,,,, -27100,0.2674753,1.7727909,,,,,,,,,,,,,,,,, -27200,0.21926643,1.8424101,,,,,,,,,,,,,,,,, -27300,0.2032481,1.8433908,,,,,,,,,,,,,,,,, -27400,0.2014472,1.7996637,,,,,,,,,,,,,,,,, -27500,0.21231021,1.8974539,,,,,,,,,,,,,,,,, -27600,0.1884816,1.840587,,,,,,,,,,,,,,,,, -27700,0.19810963,1.8836942,,,,,,,,,,,,,,,,, -27800,0.23368905,1.8073841,,,,,,,,,,,,,,,,, -27900,0.24936952,1.8354298,,,,,,,,,,,,,,,,, -28000,0.22478646,1.8607476,,,,,,,,,,,,,,,,, -28100,0.29215744,1.8490964,,,,,,,,,,,,,,,,, -28200,0.23079339,1.7487926,,,,,,,,,,,,,,,,, -28300,0.19591059,1.7512726,,,,,,,,,,,,,,,,, -28304,,,0.6390774250030518,1.7401479482650757,31.344862714431464,0.6581319570541382,1.6097475290298462,28.05053729965437,3000.0,0.6687235236167908,1.5342345237731934,27.162785054412023,3003.0,10112.716992616652,16995.22621178627,10112.716992616652,6881.22886633873,0.3334319591522217,0.0 -28400,0.23018818,1.8047509,,,,,,,,,,,,,,,,, -28500,0.25473002,1.8307397,,,,,,,,,,,,,,,,, -28600,0.21905601,1.722973,,,,,,,,,,,,,,,,, -28700,0.25086358,1.8440125,,,,,,,,,,,,,,,,, -28800,0.18043469,1.7966286,,,,,,,,,,,,,,,,, -28900,0.21319059,1.8536855,,,,,,,,,,,,,,,,, -29000,0.24765293,1.862468,,,,,,,,,,,,,,,,, -29100,0.20258835,1.8064532,,,,,,,,,,,,,,,,, -29200,0.20004402,1.8251224,,,,,,,,,,,,,,,,, -29300,0.21124749,1.8491288,,,,,,,,,,,,,,,,, -29400,0.17882156,1.8264585,,,,,,,,,,,,,,,,, -29500,0.26156253,1.7051351,,,,,,,,,,,,,,,,, -29600,0.22204907,1.7483635,,,,,,,,,,,,,,,,, -29700,0.28862786,1.8827922,,,,,,,,,,,,,,,,, -29800,3.2386596,1.82671,,,,,,,,,,,,,,,,, -29900,0.19003174,1.8137959,,,,,,,,,,,,,,,,, -30000,0.23174198,1.7984327,,,,,,,,,,,,,,,,, -30100,0.20404895,1.810594,,,,,,,,,,,,,,,,, -30200,0.23828875,1.8163586,,,,,,,,,,,,,,,,, -30300,0.21125612,1.7700678,,,,,,,,,,,,,,,,, -30400,0.20485313,1.779644,,,,,,,,,,,,,,,,, -30500,0.22330321,1.7821085,,,,,,,,,,,,,,,,, -30600,0.25458056,1.8244591,,,,,,,,,,,,,,,,, -30662,,,0.63572096824646,1.763773798942566,31.040252556013616,0.6615169048309326,1.5909042358398438,28.108661415052204,3000.0,0.6705711483955383,1.5236083269119265,27.320115273429533,3003.0,10952.624883651732,18343.28938150406,10952.624883651732,7389.272298574448,0.3637864589691162,0.0 -30700,0.20937471,1.7943017,,,,,,,,,,,,,,,,, -30800,0.22285374,1.7556329,,,,,,,,,,,,,,,,, -30900,0.20225425,1.7628337,,,,,,,,,,,,,,,,, -31000,0.18818463,1.7290584,,,,,,,,,,,,,,,,, -31100,0.19382694,1.7869588,,,,,,,,,,,,,,,,, -31200,0.20998925,1.8383107,,,,,,,,,,,,,,,,, -31300,0.2145068,1.8193618,,,,,,,,,,,,,,,,, -31400,0.26166368,1.8426691,,,,,,,,,,,,,,,,, -31500,0.22231446,1.8259076,,,,,,,,,,,,,,,,, -31600,2.1261694,1.8406516,,,,,,,,,,,,,,,,, -31700,0.28029382,1.8709502,,,,,,,,,,,,,,,,, -31800,0.21734354,1.8529948,,,,,,,,,,,,,,,,, -31900,0.24886778,1.8516407,,,,,,,,,,,,,,,,, -32000,0.19723254,1.8246819,,,,,,,,,,,,,,,,, -32100,0.24339072,1.8232228,,,,,,,,,,,,,,,,, -32200,0.24348034,1.8027422,,,,,,,,,,,,,,,,, -32300,0.20442663,1.7124698,,,,,,,,,,,,,,,,, -32400,0.21598296,1.8392776,,,,,,,,,,,,,,,,, -32500,0.21817404,1.6577059,,,,,,,,,,,,,,,,, -32600,0.20271303,1.8093399,,,,,,,,,,,,,,,,, -32700,0.21148892,1.8324208,,,,,,,,,,,,,,,,, -32800,0.22771284,1.9039389,,,,,,,,,,,,,,,,, -32900,0.22917624,1.7724785,,,,,,,,,,,,,,,,, -33000,0.2234713,1.7222877,,,,,,,,,,,,,,,,, -33020,,,0.6443125009536743,1.6968934535980225,31.166148767869927,0.6595454216003418,1.5900745391845703,27.80920383950556,3000.0,0.6711289286613464,1.5155937671661377,27.68331961654764,3003.0,11792.630442142488,19747.73938369751,11792.630442142488,7953.593356847763,0.4070956707000732,0.0 -33100,0.1942956,1.8175385,,,,,,,,,,,,,,,,, -33200,0.22374395,1.8136247,,,,,,,,,,,,,,,,, -33300,0.21976857,1.7019691,,,,,,,,,,,,,,,,, -33400,0.2016015,1.8290446,,,,,,,,,,,,,,,,, -33500,0.2149344,1.778467,,,,,,,,,,,,,,,,, -33600,0.2065373,1.852644,,,,,,,,,,,,,,,,, -33700,1.2075089,1.9194689,,,,,,,,,,,,,,,,, -33800,0.23562473,1.7931141,,,,,,,,,,,,,,,,, -33900,0.29252848,1.8264328,,,,,,,,,,,,,,,,, -34000,0.5170896,1.7889559,,,,,,,,,,,,,,,,, -34100,0.1862886,1.8554643,,,,,,,,,,,,,,,,, -34200,0.21034569,1.7991954,,,,,,,,,,,,,,,,, -34300,0.20413741,1.8079517,,,,,,,,,,,,,,,,, -34400,0.19109933,1.74776,,,,,,,,,,,,,,,,, -34500,0.24961,1.7826558,,,,,,,,,,,,,,,,, -34600,0.19621606,1.7952547,,,,,,,,,,,,,,,,, -34700,0.19106972,1.7744509,,,,,,,,,,,,,,,,, -34800,0.28059983,1.6719,,,,,,,,,,,,,,,,, -34900,0.19412053,1.7789404,,,,,,,,,,,,,,,,, -35000,0.19203384,1.7347506,,,,,,,,,,,,,,,,, -35100,0.20121257,1.7799557,,,,,,,,,,,,,,,,, -35200,0.21933526,1.7828159,,,,,,,,,,,,,,,,, -35300,0.21040946,1.7124171,,,,,,,,,,,,,,,,, -35378,,,0.6399866938591003,1.7375435829162598,31.33820875483425,0.6620624661445618,1.5832661390304563,28.38960980084261,3000.0,0.6730231046676636,1.5036847591400146,27.782769305202045,3003.0,12632.544389486313,21339.37600016594,12632.544389486313,8705.197973966599,0.4429934024810791,0.0 -35400,0.22539596,1.8125995,,,,,,,,,,,,,,,,, -35500,0.24933517,1.8236998,,,,,,,,,,,,,,,,, -35600,0.24923477,1.7825925,,,,,,,,,,,,,,,,, -35700,0.18886366,1.7032037,,,,,,,,,,,,,,,,, -35800,0.2198066,1.8260893,,,,,,,,,,,,,,,,, -35900,0.2629357,1.8554335,,,,,,,,,,,,,,,,, -36000,0.25322583,1.7323661,,,,,,,,,,,,,,,,, -36100,0.21830206,1.7596936,,,,,,,,,,,,,,,,, -36200,0.20315163,1.7693752,,,,,,,,,,,,,,,,, -36300,0.22245488,1.754584,,,,,,,,,,,,,,,,, -36400,0.22687331,1.7555449,,,,,,,,,,,,,,,,, -36500,0.21350409,1.7308519,,,,,,,,,,,,,,,,, -36600,0.20174041,1.7179044,,,,,,,,,,,,,,,,, -36700,0.1909745,1.7133101,,,,,,,,,,,,,,,,, -36800,0.24664721,1.7906389,,,,,,,,,,,,,,,,, -36900,0.192201,1.8492391,,,,,,,,,,,,,,,,, -37000,0.19468641,1.7459847,,,,,,,,,,,,,,,,, -37100,0.22366937,1.80373,,,,,,,,,,,,,,,,, -37200,0.22830881,1.7561023,,,,,,,,,,,,,,,,, -37300,0.20481756,1.8328524,,,,,,,,,,,,,,,,, -37400,0.23007503,1.762398,,,,,,,,,,,,,,,,, -37500,0.24135365,1.7668971,,,,,,,,,,,,,,,,, -37600,0.24974582,1.7879395,,,,,,,,,,,,,,,,, -37700,0.28459704,1.7139235,,,,,,,,,,,,,,,,, -37736,,,0.663627564907074,1.5535968542099,32.87796516096618,0.6638479232788086,1.5714770555496216,28.38134980475151,3000.0,0.6728371381759644,1.4999555349349976,27.78853257605112,3003.0,13472.63930630684,22769.643906354904,13472.63930630684,9295.258509159088,0.474111795425415,0.0 -37800,0.17645687,1.7055396,,,,,,,,,,,,,,,,, -37900,0.24072935,1.8690655,,,,,,,,,,,,,,,,, -38000,0.22862011,1.764155,,,,,,,,,,,,,,,,, -38100,0.17921937,1.7936932,,,,,,,,,,,,,,,,, -38200,0.26329187,1.7065938,,,,,,,,,,,,,,,,, -38300,0.2259871,1.7522969,,,,,,,,,,,,,,,,, -38400,0.20050012,1.8065782,,,,,,,,,,,,,,,,, -38500,0.24894826,1.7582932,,,,,,,,,,,,,,,,, -38600,0.18744792,1.7083141,,,,,,,,,,,,,,,,, -38700,0.22008269,1.7369508,,,,,,,,,,,,,,,,, -38800,0.21481222,1.804193,,,,,,,,,,,,,,,,, -38900,0.19849718,1.7489794,,,,,,,,,,,,,,,,, -39000,0.23435189,1.7816057,,,,,,,,,,,,,,,,, -39100,0.2050276,1.7298968,,,,,,,,,,,,,,,,, -39200,0.20220152,1.8026853,,,,,,,,,,,,,,,,, -39300,0.20307288,1.71381,,,,,,,,,,,,,,,,, -39400,0.2340662,1.737684,,,,,,,,,,,,,,,,, -39500,0.19731542,1.7857003,,,,,,,,,,,,,,,,, -39600,0.2085527,1.746227,,,,,,,,,,,,,,,,, -39700,0.19019115,1.8267282,,,,,,,,,,,,,,,,, -39800,0.20474103,1.8138714,,,,,,,,,,,,,,,,, -39900,0.31439066,1.7928616,,,,,,,,,,,,,,,,, -40000,0.23219346,1.7550442,,,,,,,,,,,,,,,,, -40094,,,0.6476908922195435,1.6861525774002075,31.62556446241438,0.6630792021751404,1.5618013143539429,28.31053959711827,3000.0,0.6751612424850464,1.4886205196380615,28.055212927233068,3003.0,14312.627776622772,24356.119666576385,14312.627776622772,10041.635178804398,0.5051698684692383,0.0 -40100,0.22000149,1.7366264,,,,,,,,,,,,,,,,, -40200,0.19804145,1.8100361,,,,,,,,,,,,,,,,, -40300,0.20157896,1.7171987,,,,,,,,,,,,,,,,, -40400,0.2146601,1.7806168,,,,,,,,,,,,,,,,, -40500,0.30153814,1.8085841,,,,,,,,,,,,,,,,, -40600,0.19163854,1.7488916,,,,,,,,,,,,,,,,, -40700,0.19013512,1.7516565,,,,,,,,,,,,,,,,, -40800,0.2043597,1.7908566,,,,,,,,,,,,,,,,, -40900,0.41494742,1.8494939,,,,,,,,,,,,,,,,, -41000,0.2774044,1.7684894,,,,,,,,,,,,,,,,, -41100,0.20719399,1.7592543,,,,,,,,,,,,,,,,, -41200,0.23210032,1.7920835,,,,,,,,,,,,,,,,, -41300,0.18770468,1.8019546,,,,,,,,,,,,,,,,, -41400,0.19466297,1.7853234,,,,,,,,,,,,,,,,, -41500,0.24430448,1.7092885,,,,,,,,,,,,,,,,, -41600,0.19300467,1.8723702,,,,,,,,,,,,,,,,, -41700,0.25540316,1.75879,,,,,,,,,,,,,,,,, -41800,0.23720153,1.7918943,,,,,,,,,,,,,,,,, -41900,0.20255548,1.7389754,,,,,,,,,,,,,,,,, -42000,0.27501854,1.7600971,,,,,,,,,,,,,,,,, -42100,0.20147787,1.7184513,,,,,,,,,,,,,,,,, -42200,0.21611501,1.7222388,,,,,,,,,,,,,,,,, -42300,0.22905949,1.7170931,,,,,,,,,,,,,,,,, -42400,0.33296612,1.7486575,,,,,,,,,,,,,,,,, -42452,,,0.6445221900939941,1.700629711151123,31.200524801237783,0.6668733358383179,1.5635435581207275,28.73763681514912,3000.0,0.6750218272209167,1.486008644104004,27.51980545574519,3003.0,15152.739064216614,25757.05531859398,15152.739064216614,10602.350082874298,0.5357956886291504,0.0 -42500,0.18872315,1.7655997,,,,,,,,,,,,,,,,, -42600,0.23648082,1.7922355,,,,,,,,,,,,,,,,, -42700,0.23703136,1.7996715,,,,,,,,,,,,,,,,, -42800,1.0696356,1.7353001,,,,,,,,,,,,,,,,, -42900,0.20900437,1.74873,,,,,,,,,,,,,,,,, -43000,0.20336413,1.7239448,,,,,,,,,,,,,,,,, -43100,0.20437235,1.7605929,,,,,,,,,,,,,,,,, -43200,0.19159473,1.7388706,,,,,,,,,,,,,,,,, -43300,0.20919895,1.8519318,,,,,,,,,,,,,,,,, -43400,0.23280983,1.7411599,,,,,,,,,,,,,,,,, -43500,0.19341956,1.6903012,,,,,,,,,,,,,,,,, -43600,0.23095828,1.739668,,,,,,,,,,,,,,,,, -43700,0.32326382,1.7959985,,,,,,,,,,,,,,,,, -43800,0.19974007,1.8812898,,,,,,,,,,,,,,,,, -43900,0.19973499,1.7644053,,,,,,,,,,,,,,,,, -44000,0.1890275,1.756958,,,,,,,,,,,,,,,,, -44100,0.20754403,1.8013828,,,,,,,,,,,,,,,,, -44200,0.2171773,1.7721982,,,,,,,,,,,,,,,,, -44300,0.20368162,1.7985947,,,,,,,,,,,,,,,,, -44400,0.19932012,1.6602105,,,,,,,,,,,,,,,,, -44500,0.19676875,1.7707113,,,,,,,,,,,,,,,,, -44600,0.19444309,1.6543568,,,,,,,,,,,,,,,,, -44700,0.20377125,1.7707084,,,,,,,,,,,,,,,,, -44800,0.2068707,1.7484499,,,,,,,,,,,,,,,,, -44811,,,0.6530870795249939,1.634215235710144,31.688902445897607,0.6673692464828491,1.5494558811187744,28.50414605105004,3000.0,0.6767880916595459,1.471121311187744,27.9387057645756,3003.0,15992.64580154419,27153.050884485245,15992.64580154419,11158.329864501951,0.566993236541748,0.0 -44900,0.18881565,1.8869696,,,,,,,,,,,,,,,,, -45000,0.2361548,1.7292529,,,,,,,,,,,,,,,,, -45100,0.19419608,1.7436048,,,,,,,,,,,,,,,,, -45200,0.20245615,1.6982934,,,,,,,,,,,,,,,,, -45300,0.20092611,1.721618,,,,,,,,,,,,,,,,, -45400,0.21171466,1.8193717,,,,,,,,,,,,,,,,, -45500,0.24655765,1.7945778,,,,,,,,,,,,,,,,, -45600,0.20814344,1.7117348,,,,,,,,,,,,,,,,, -45700,0.20048934,1.7103384,,,,,,,,,,,,,,,,, -45800,0.21159647,1.6433322,,,,,,,,,,,,,,,,, -45900,0.1919141,1.7502279,,,,,,,,,,,,,,,,, -46000,0.1831576,1.8099469,,,,,,,,,,,,,,,,, -46100,0.20891711,1.7041942,,,,,,,,,,,,,,,,, -46200,0.2162756,1.7023718,,,,,,,,,,,,,,,,, -46300,0.19748585,1.796828,,,,,,,,,,,,,,,,, -46400,0.19023947,1.7308065,,,,,,,,,,,,,,,,, -46500,0.21147162,1.7555534,,,,,,,,,,,,,,,,, -46600,0.20112504,1.7545995,,,,,,,,,,,,,,,,, -46700,0.194485,1.7282431,,,,,,,,,,,,,,,,, -46800,0.20630947,1.7846633,,,,,,,,,,,,,,,,, -46900,0.23857549,1.7598493,,,,,,,,,,,,,,,,, -47000,0.19661362,1.7134336,,,,,,,,,,,,,,,,, -47100,0.20250137,1.6983565,,,,,,,,,,,,,,,,, -47170,,,0.6481767892837524,1.6744226217269895,31.457149841710663,0.668088436126709,1.545769214630127,28.76677074342338,3000.0,0.6807855367660522,1.45988130569458,28.13457178472897,3003.0,16832.79411482811,28478.575281381607,16832.79411482811,11643.594151735306,0.5991125106811523,0.0 -47200,0.195202,1.6696463,,,,,,,,,,,,,,,,, -47300,0.21690246,1.6851718,,,,,,,,,,,,,,,,, -47400,0.23105966,1.7656221,,,,,,,,,,,,,,,,, -47500,0.17776857,1.7610579,,,,,,,,,,,,,,,,, -47600,0.2500959,1.6878835,,,,,,,,,,,,,,,,, -47700,0.22037846,1.7965693,,,,,,,,,,,,,,,,, -47800,0.2002359,1.780044,,,,,,,,,,,,,,,,, -47900,0.22053435,1.8382959,,,,,,,,,,,,,,,,, -48000,0.20069556,1.685532,,,,,,,,,,,,,,,,, -48100,0.20138246,1.7293715,,,,,,,,,,,,,,,,, -48200,0.23916486,1.6877024,,,,,,,,,,,,,,,,, -48300,0.20219116,1.8049269,,,,,,,,,,,,,,,,, -48400,0.20318528,1.706929,,,,,,,,,,,,,,,,, -48500,0.22358672,1.8232344,,,,,,,,,,,,,,,,, -48600,0.27206582,1.7016404,,,,,,,,,,,,,,,,, -48700,0.21178344,1.7393231,,,,,,,,,,,,,,,,, -48800,0.21139261,1.7017494,,,,,,,,,,,,,,,,, -48900,0.23764493,1.7587223,,,,,,,,,,,,,,,,, -49000,0.18407314,1.763925,,,,,,,,,,,,,,,,, -49100,0.20616582,1.7122476,,,,,,,,,,,,,,,,, -49200,0.19377665,1.7839087,,,,,,,,,,,,,,,,, -49300,0.22096993,1.7556081,,,,,,,,,,,,,,,,, -49400,0.21410066,1.6965376,,,,,,,,,,,,,,,,, -49500,0.19850051,1.7581561,,,,,,,,,,,,,,,,, -49530,,,0.6455420851707458,1.7034410238265991,31.77329398717936,0.666240930557251,1.5409092903137207,28.787587888853047,3000.0,0.681494414806366,1.4549793004989624,28.54769671956192,3003.0,17672.922026872635,29877.601472377777,17672.922026872635,12202.3813393116,0.6326305866241455,0.0 -49600,0.21704385,1.7562042,,,,,,,,,,,,,,,,, -49700,0.19662027,1.7029743,,,,,,,,,,,,,,,,, -49800,0.19564354,1.6680081,,,,,,,,,,,,,,,,, -49900,0.32213703,1.723249,,,,,,,,,,,,,,,,, -50000,0.1738184,1.7549745,,,,,,,,,,,,,,,,, -50100,0.2042695,1.6924957,,,,,,,,,,,,,,,,, -50200,0.21375993,1.7079413,,,,,,,,,,,,,,,,, -50300,0.20729984,1.7391317,,,,,,,,,,,,,,,,, -50400,0.19183041,1.777486,,,,,,,,,,,,,,,,, -50500,0.19041431,1.7879776,,,,,,,,,,,,,,,,, -50600,0.20452406,1.7694473,,,,,,,,,,,,,,,,, -50700,0.21633346,1.7930759,,,,,,,,,,,,,,,,, -50800,0.24379748,1.7143178,,,,,,,,,,,,,,,,, -50900,0.20702219,1.7681578,,,,,,,,,,,,,,,,, -51000,0.19648036,1.6962905,,,,,,,,,,,,,,,,, -51100,0.2070531,1.7312199,,,,,,,,,,,,,,,,, -51200,0.19271278,1.7211007,,,,,,,,,,,,,,,,, -51300,0.20566146,1.7993104,,,,,,,,,,,,,,,,, -51400,0.19778606,1.717358,,,,,,,,,,,,,,,,, -51500,0.22322315,1.732321,,,,,,,,,,,,,,,,, -51600,0.20399255,1.6902452,,,,,,,,,,,,,,,,, -51700,0.20263354,1.8011223,,,,,,,,,,,,,,,,, -51800,0.19191347,1.7607564,,,,,,,,,,,,,,,,, -51890,,,0.652571439743042,1.6386526823043823,32.39050117798098,0.6687827706336975,1.5333878993988037,28.55614230670583,3000.0,0.6804369688034058,1.45551860332489,28.056858321750056,3003.0,18513.092685222626,31397.252323150635,18513.092685222626,12881.750809669496,0.6645634174346924,0.0 -51900,0.19898702,1.6616588,,,,,,,,,,,,,,,,, -52000,0.19565794,1.7894846,,,,,,,,,,,,,,,,, -52100,0.20379424,1.7464525,,,,,,,,,,,,,,,,, -52200,0.19148852,1.7013181,,,,,,,,,,,,,,,,, -52300,0.18688138,1.7084068,,,,,,,,,,,,,,,,, -52400,0.21724626,1.7794445,,,,,,,,,,,,,,,,, -52500,0.20087786,1.6620425,,,,,,,,,,,,,,,,, -52600,0.2090347,1.7013003,,,,,,,,,,,,,,,,, -52700,0.19730146,1.6534712,,,,,,,,,,,,,,,,, -52800,0.266485,1.7386339,,,,,,,,,,,,,,,,, -52900,0.18742085,1.7803177,,,,,,,,,,,,,,,,, -53000,0.22774172,1.674379,,,,,,,,,,,,,,,,, -53100,0.19043529,1.7554797,,,,,,,,,,,,,,,,, -53200,0.2097573,1.7546666,,,,,,,,,,,,,,,,, -53300,0.21185403,1.6772846,,,,,,,,,,,,,,,,, -53400,0.19235806,1.7123945,,,,,,,,,,,,,,,,, -53500,0.20110664,1.6442279,,,,,,,,,,,,,,,,, -53600,0.26373708,1.736623,,,,,,,,,,,,,,,,, -53700,0.26133135,1.6144376,,,,,,,,,,,,,,,,, -53800,0.21028183,1.7829912,,,,,,,,,,,,,,,,, -53900,0.22184257,1.75787,,,,,,,,,,,,,,,,, -54000,0.19651344,1.7211815,,,,,,,,,,,,,,,,, -54100,0.24243033,1.765423,,,,,,,,,,,,,,,,, -54200,0.21200226,1.7472733,,,,,,,,,,,,,,,,, -54249,,,0.6511799693107605,1.6562837362289429,31.79512948258729,0.6702210903167725,1.527982473373413,29.00607766441165,3000.0,0.6831910014152527,1.4422683715820312,28.418639510555767,3003.0,19353.05917453766,32914.703605651855,19353.05917453766,13559.124497413635,0.6962974071502686,0.0 -54300,0.21898271,1.7361406,,,,,,,,,,,,,,,,, -54400,0.19893676,1.7194737,,,,,,,,,,,,,,,,, -54500,0.21733083,1.8411278,,,,,,,,,,,,,,,,, -54600,0.20728476,1.8027422,,,,,,,,,,,,,,,,, -54700,0.2137363,1.744039,,,,,,,,,,,,,,,,, -54800,0.205296,1.7236077,,,,,,,,,,,,,,,,, -54900,0.21656787,1.7001547,,,,,,,,,,,,,,,,, -55000,0.20367141,1.7000872,,,,,,,,,,,,,,,,, -55100,0.2710293,1.6742688,,,,,,,,,,,,,,,,, -55200,0.22376917,1.7088239,,,,,,,,,,,,,,,,, -55300,0.20959881,1.7025219,,,,,,,,,,,,,,,,, -55400,0.21390952,1.6691852,,,,,,,,,,,,,,,,, -55500,0.21045676,1.8492088,,,,,,,,,,,,,,,,, -55600,0.22569752,1.7152157,,,,,,,,,,,,,,,,, -55700,0.19765171,1.7398802,,,,,,,,,,,,,,,,, -55800,0.2021264,1.7517462,,,,,,,,,,,,,,,,, -55900,0.20483266,1.8584462,,,,,,,,,,,,,,,,, -56000,0.21428667,1.7387849,,,,,,,,,,,,,,,,, -56100,0.199447,1.6942645,,,,,,,,,,,,,,,,, -56200,0.20017703,1.7331051,,,,,,,,,,,,,,,,, -56300,0.19763875,1.8011514,,,,,,,,,,,,,,,,, -56400,0.21233019,1.6805068,,,,,,,,,,,,,,,,, -56500,0.19823301,1.6935521,,,,,,,,,,,,,,,,, -56600,0.3468578,1.715454,,,,,,,,,,,,,,,,, -56609,,,0.6711536049842834,1.5183404684066772,33.20045640260435,0.6715229749679565,1.5183759927749634,29.05587524880472,3000.0,0.6822032332420349,1.4344021081924438,28.28308659347989,3003.0,20193.22402882576,34388.2636551857,20193.22402882576,14192.409289360046,0.7298541069030762,0.0 -56700,0.21785681,1.6956686,,,,,,,,,,,,,,,,, -56800,0.23315248,1.7284337,,,,,,,,,,,,,,,,, -56900,0.19018427,1.6884574,,,,,,,,,,,,,,,,, -57000,0.24242604,1.7703488,,,,,,,,,,,,,,,,, -57100,0.19704479,1.690899,,,,,,,,,,,,,,,,, -57200,0.20651811,1.7827244,,,,,,,,,,,,,,,,, -57300,0.21363729,1.6858611,,,,,,,,,,,,,,,,, -57400,0.19614804,1.7361376,,,,,,,,,,,,,,,,, -57500,0.20838334,1.7232552,,,,,,,,,,,,,,,,, -57600,0.17968735,1.754231,,,,,,,,,,,,,,,,, -57700,0.35415226,1.6996601,,,,,,,,,,,,,,,,, -57800,0.19317572,1.7206026,,,,,,,,,,,,,,,,, -57900,0.20672877,1.6537993,,,,,,,,,,,,,,,,, -58000,0.21400885,1.7002119,,,,,,,,,,,,,,,,, -58100,0.7999465,1.6439141,,,,,,,,,,,,,,,,, -58200,0.24101646,1.7061096,,,,,,,,,,,,,,,,, -58300,0.19726941,1.6488652,,,,,,,,,,,,,,,,, -58400,0.18373528,1.6977473,,,,,,,,,,,,,,,,, -58500,0.19598527,1.7301766,,,,,,,,,,,,,,,,, -58600,0.19538173,1.6693496,,,,,,,,,,,,,,,,, -58700,0.25280437,1.7412359,,,,,,,,,,,,,,,,, -58800,0.20685937,1.6131446,,,,,,,,,,,,,,,,, -58900,0.2089689,1.6767377,,,,,,,,,,,,,,,,, -58968,,,0.6562167406082153,1.6146293878555298,32.00557628346472,0.6700970530509949,1.5160971879959106,28.6046113423605,3000.0,0.6840625405311584,1.4318151473999023,28.442912978921708,3003.0,21033.412729263306,36007.44893741608,21033.412729263306,14971.292563676834,0.7623600959777832,0.0 -59000,0.19167733,1.6670152,,,,,,,,,,,,,,,,, -59100,0.19875428,1.7145442,,,,,,,,,,,,,,,,, -59200,0.20749633,1.7747669,,,,,,,,,,,,,,,,, -59300,0.19414069,1.7704176,,,,,,,,,,,,,,,,, -59400,0.20519303,1.6430004,,,,,,,,,,,,,,,,, -59500,0.19911979,1.6983929,,,,,,,,,,,,,,,,, -59600,0.19905886,1.6781584,,,,,,,,,,,,,,,,, -59700,0.2309405,1.7492138,,,,,,,,,,,,,,,,, -59800,0.20458603,1.7149284,,,,,,,,,,,,,,,,, -59900,0.2373399,1.7012048,,,,,,,,,,,,,,,,, -60000,0.21546374,1.7599103,,,,,,,,,,,,,,,,, -60100,0.21974136,1.6931953,,,,,,,,,,,,,,,,, -60200,0.21207522,1.7058878,,,,,,,,,,,,,,,,, -60300,0.21304491,1.7531557,,,,,,,,,,,,,,,,, -60400,0.21129523,1.7379787,,,,,,,,,,,,,,,,, -60500,0.20086455,1.7077656,,,,,,,,,,,,,,,,, -60600,0.22189708,1.6440793,,,,,,,,,,,,,,,,, -60700,0.22985938,1.7934327,,,,,,,,,,,,,,,,, -60800,0.18365018,1.7891468,,,,,,,,,,,,,,,,, -60900,0.20407458,1.718481,,,,,,,,,,,,,,,,, -61000,0.19571589,1.7839463,,,,,,,,,,,,,,,,, -61100,0.20927086,1.6917547,,,,,,,,,,,,,,,,, -61200,0.22452396,1.6989468,,,,,,,,,,,,,,,,, -61300,0.26896486,1.648605,,,,,,,,,,,,,,,,, -61327,,,0.6536709666252136,1.6435863971710205,32.201144941719825,0.67365562915802,1.5085053443908691,29.024937312260363,3000.0,0.6860031485557556,1.4202231168746948,28.75322589399797,3003.0,21873.331847190857,37395.15630316734,21873.331847190857,15518.967089891434,0.7969238758087158,0.0 -61400,0.22628923,1.7071334,,,,,,,,,,,,,,,,, -61500,0.20665888,1.7329761,,,,,,,,,,,,,,,,, -61600,0.20927571,1.7473946,,,,,,,,,,,,,,,,, -61700,0.2006578,1.6940724,,,,,,,,,,,,,,,,, -61800,0.2090838,1.7771703,,,,,,,,,,,,,,,,, -61900,0.20243864,1.6182684,,,,,,,,,,,,,,,,, -62000,0.2234933,1.726973,,,,,,,,,,,,,,,,, -62100,0.22964938,1.756898,,,,,,,,,,,,,,,,, -62200,0.19845451,1.6404697,,,,,,,,,,,,,,,,, -62300,0.23883508,1.6692524,,,,,,,,,,,,,,,,, -62400,0.19116274,1.5972755,,,,,,,,,,,,,,,,, -62500,0.20674008,1.6968538,,,,,,,,,,,,,,,,, -62600,0.19908416,1.6382664,,,,,,,,,,,,,,,,, -62700,0.2535982,1.7056327,,,,,,,,,,,,,,,,, -62800,0.21961011,1.697635,,,,,,,,,,,,,,,,, -62900,0.2054147,1.7212602,,,,,,,,,,,,,,,,, -63000,0.20153713,1.7046294,,,,,,,,,,,,,,,,, -63100,0.20195843,1.6914351,,,,,,,,,,,,,,,,, -63200,0.20172718,1.7204958,,,,,,,,,,,,,,,,, -63300,0.20370327,1.6267107,,,,,,,,,,,,,,,,, -63400,0.24299598,1.7486745,,,,,,,,,,,,,,,,, -63500,0.19028641,1.6780777,,,,,,,,,,,,,,,,, -63600,0.40159813,1.6458961,,,,,,,,,,,,,,,,, -63686,,,0.6629753708839417,1.568068265914917,33.01746848462258,0.6728744506835938,1.503533124923706,28.890694782084704,3000.0,0.686851441860199,1.4139506816864014,29.019938516980574,3003.0,22713.31698822975,38811.658296108246,22713.31698822975,16095.283498048782,0.9187502861022948,0.0 -63700,0.19976853,1.7021471,,,,,,,,,,,,,,,,, -63800,0.20139474,1.7009064,,,,,,,,,,,,,,,,, -63900,0.20454209,1.705637,,,,,,,,,,,,,,,,, -64000,0.21153572,1.6549667,,,,,,,,,,,,,,,,, -64100,0.20193638,1.7010505,,,,,,,,,,,,,,,,, -64200,0.20681214,1.71213,,,,,,,,,,,,,,,,, -64300,0.20106822,1.6760632,,,,,,,,,,,,,,,,, -64400,0.20348896,1.7191453,,,,,,,,,,,,,,,,, -64500,0.24245709,1.7458925,,,,,,,,,,,,,,,,, -64600,0.23298378,1.6935613,,,,,,,,,,,,,,,,, -64700,0.1969276,1.7401918,,,,,,,,,,,,,,,,, -64800,0.22426093,1.7157058,,,,,,,,,,,,,,,,, -64900,0.1944873,1.6446137,,,,,,,,,,,,,,,,, -65000,0.21756203,1.6662954,,,,,,,,,,,,,,,,, -65100,0.19972464,1.7024579,,,,,,,,,,,,,,,,, -65200,0.20762345,1.6406918,,,,,,,,,,,,,,,,, -65300,0.2002557,1.6791635,,,,,,,,,,,,,,,,, -65400,0.19422041,1.730781,,,,,,,,,,,,,,,,, -65500,0.22082408,1.660903,,,,,,,,,,,,,,,,, -65600,0.2022698,1.7241302,,,,,,,,,,,,,,,,, -65700,0.2064012,1.7382718,,,,,,,,,,,,,,,,, -65800,0.21288976,1.6640011,,,,,,,,,,,,,,,,, -65900,0.20286986,1.6729239,,,,,,,,,,,,,,,,, -66000,0.20788334,1.6247119,,,,,,,,,,,,,,,,, -66045,,,0.6583556532859802,1.5996129512786863,32.41859989742564,0.6759990453720093,1.4937245845794678,29.61073050124265,3000.0,0.6908256411552429,1.397244215011597,29.21380668381085,3003.0,23553.378241539,40180.534745931625,23553.378241539,16623.977352142334,0.9587399959564208,0.0 -66100,0.22596253,1.6857412,,,,,,,,,,,,,,,,, -66200,0.20325504,1.688861,,,,,,,,,,,,,,,,, -66300,0.21267271,1.728454,,,,,,,,,,,,,,,,, -66400,0.19739446,1.6109798,,,,,,,,,,,,,,,,, -66500,0.21052499,1.6644961,,,,,,,,,,,,,,,,, -66600,0.20275182,1.6869218,,,,,,,,,,,,,,,,, -66700,0.21793313,1.7422285,,,,,,,,,,,,,,,,, -66800,0.18964097,1.6369983,,,,,,,,,,,,,,,,, -66900,0.26552358,1.7028757,,,,,,,,,,,,,,,,, -67000,0.20993179,1.6473299,,,,,,,,,,,,,,,,, -67100,0.22605352,1.7111803,,,,,,,,,,,,,,,,, -67200,0.2053819,1.6470927,,,,,,,,,,,,,,,,, -67300,0.19613208,1.7006665,,,,,,,,,,,,,,,,, -67400,0.2106931,1.7105505,,,,,,,,,,,,,,,,, -67500,0.21038452,1.7601073,,,,,,,,,,,,,,,,, -67600,0.20693411,1.6082038,,,,,,,,,,,,,,,,, -67700,0.34789848,1.6381035,,,,,,,,,,,,,,,,, -67800,0.22124563,1.6994978,,,,,,,,,,,,,,,,, -67900,0.21408421,1.7129189,,,,,,,,,,,,,,,,, -68000,0.33633938,1.6544925,,,,,,,,,,,,,,,,, -68100,0.20792641,1.6434832,,,,,,,,,,,,,,,,, -68200,0.23180252,1.7546598,,,,,,,,,,,,,,,,, -68300,0.18349917,1.6215199,,,,,,,,,,,,,,,,, -68400,0.21081947,1.628204,,,,,,,,,,,,,,,,, -68403,,,0.6556589603424072,1.61857008934021,32.53331232232554,0.675317108631134,1.4912937879562378,29.18859526057956,3000.0,0.6906862258911133,1.391940951347351,28.9712420024266,3003.0,24393.38697886467,41567.93536186218,24393.38697886467,17171.25036072731,0.9950568675994872,0.0 -68500,0.2503337,1.6940701,,,,,,,,,,,,,,,,, -68600,0.21482088,1.6372173,,,,,,,,,,,,,,,,, -68700,0.22575963,1.7216998,,,,,,,,,,,,,,,,, -68800,0.22252475,1.5878319,,,,,,,,,,,,,,,,, -68900,0.20667784,1.6825652,,,,,,,,,,,,,,,,, -69000,0.20153768,1.6740253,,,,,,,,,,,,,,,,, -69100,0.18560031,1.7313279,,,,,,,,,,,,,,,,, -69200,0.21529427,1.732718,,,,,,,,,,,,,,,,, -69300,0.20144114,1.708933,,,,,,,,,,,,,,,,, -69400,0.2025315,1.6221223,,,,,,,,,,,,,,,,, -69500,0.20602253,1.6324903,,,,,,,,,,,,,,,,, -69600,0.20580943,1.6343502,,,,,,,,,,,,,,,,, -69700,0.21193059,1.7138408,,,,,,,,,,,,,,,,, -69800,0.23004419,1.6301996,,,,,,,,,,,,,,,,, -69900,0.20822558,1.7029337,,,,,,,,,,,,,,,,, -70000,0.19524762,1.669668,,,,,,,,,,,,,,,,, -70100,0.22907405,1.6540862,,,,,,,,,,,,,,,,, -70200,0.2016347,1.7101182,,,,,,,,,,,,,,,,, -70300,0.20023336,1.641435,,,,,,,,,,,,,,,,, -70400,0.21881008,1.593432,,,,,,,,,,,,,,,,, -70500,0.21102066,1.6469853,,,,,,,,,,,,,,,,, -70600,0.2237645,1.6397364,,,,,,,,,,,,,,,,, -70700,0.20738266,1.6428525,,,,,,,,,,,,,,,,, -70762,,,0.6649782657623291,1.5656242370605469,33.27862937609742,0.6770901679992676,1.4805506467819214,29.4606829320584,3000.0,0.6927662491798401,1.3877900838851929,29.39340593968376,3003.0,25233.515276670456,43111.59431147576,25233.515276670456,17874.665945768356,1.0316176414489746,0.0 -70800,0.22461993,1.6372963,,,,,,,,,,,,,,,,, -70900,0.20857409,1.6736447,,,,,,,,,,,,,,,,, -71000,0.20607394,1.5595896,,,,,,,,,,,,,,,,, -71100,0.20582189,1.6532079,,,,,,,,,,,,,,,,, -71200,0.22236037,1.7383726,,,,,,,,,,,,,,,,, -71300,0.19983806,1.6571602,,,,,,,,,,,,,,,,, -71400,0.22415954,1.6652211,,,,,,,,,,,,,,,,, -71500,0.20006216,1.6791648,,,,,,,,,,,,,,,,, -71600,0.20309862,1.6505195,,,,,,,,,,,,,,,,, -71700,0.21926081,1.6671555,,,,,,,,,,,,,,,,, -71800,0.22168691,1.678051,,,,,,,,,,,,,,,,, -71900,0.20151,1.6459059,,,,,,,,,,,,,,,,, -72000,0.21129182,1.699556,,,,,,,,,,,,,,,,, -72100,0.20106931,1.6532608,,,,,,,,,,,,,,,,, -72200,0.22002059,1.5996567,,,,,,,,,,,,,,,,, -72300,0.19938447,1.6245052,,,,,,,,,,,,,,,,, -72400,0.2183568,1.7135537,,,,,,,,,,,,,,,,, -72500,0.26758015,1.6077622,,,,,,,,,,,,,,,,, -72600,0.19331458,1.6282316,,,,,,,,,,,,,,,,, -72700,0.1922275,1.6545848,,,,,,,,,,,,,,,,, -72800,0.1924008,1.6725981,,,,,,,,,,,,,,,,, -72900,0.20412497,1.6838243,,,,,,,,,,,,,,,,, -73000,0.22778463,1.6733779,,,,,,,,,,,,,,,,, -73100,0.21405329,1.6922011,,,,,,,,,,,,,,,,, -73121,,,0.6601581573486328,1.5850191116333008,33.035396256855584,0.6781564950942993,1.4745867252349854,29.27033968859275,3000.0,0.6933588981628418,1.37883198261261,29.37666021325784,3003.0,26073.75528979301,44533.76924061775,26073.75528979301,18456.486599206924,1.0666627883911133,0.0 -73200,0.20318054,1.6106291,,,,,,,,,,,,,,,,, -73300,0.19156148,1.566556,,,,,,,,,,,,,,,,, -73400,0.214102,1.5846915,,,,,,,,,,,,,,,,, -73500,0.19347851,1.6552415,,,,,,,,,,,,,,,,, -73600,0.23397711,1.7070047,,,,,,,,,,,,,,,,, -73700,0.21487537,1.6142606,,,,,,,,,,,,,,,,, -73800,0.21034794,1.6624848,,,,,,,,,,,,,,,,, -73900,0.19787908,1.6360892,,,,,,,,,,,,,,,,, -74000,0.2030309,1.6905096,,,,,,,,,,,,,,,,, -74100,0.21859722,1.6984782,,,,,,,,,,,,,,,,, -74200,0.21953711,1.7030959,,,,,,,,,,,,,,,,, -74300,0.2701776,1.7057122,,,,,,,,,,,,,,,,, -74400,0.24389751,1.6173775,,,,,,,,,,,,,,,,, -74500,0.19367279,1.5560526,,,,,,,,,,,,,,,,, -74600,0.21889557,1.5918058,,,,,,,,,,,,,,,,, -74700,0.23350462,1.5952206,,,,,,,,,,,,,,,,, -74800,0.20779026,1.7028934,,,,,,,,,,,,,,,,, -74900,0.20282741,1.6431462,,,,,,,,,,,,,,,,, -75000,0.19172429,1.614856,,,,,,,,,,,,,,,,, -75100,0.20231949,1.544397,,,,,,,,,,,,,,,,, -75200,0.21825409,1.6218978,,,,,,,,,,,,,,,,, -75300,0.22107293,1.6676499,,,,,,,,,,,,,,,,, -75400,0.18633717,1.6073656,,,,,,,,,,,,,,,,, -75480,,,0.6792148947715759,1.4641389846801758,33.80764561643805,0.6794211864471436,1.4671216011047363,29.555809985213795,3000.0,0.6943350434303284,1.3725438117980957,29.297592249645188,3003.0,26913.816690921783,46007.44537067413,26913.816690921783,19089.9792573452,1.1101996898651123,0.0 -75500,0.2026922,1.6499774,,,,,,,,,,,,,,,,, -75600,0.25390914,1.6722171,,,,,,,,,,,,,,,,, -75700,0.20972802,1.6645216,,,,,,,,,,,,,,,,, -75800,0.1980391,1.6536735,,,,,,,,,,,,,,,,, -75900,0.20256883,1.6247739,,,,,,,,,,,,,,,,, -76000,0.19194297,1.5973822,,,,,,,,,,,,,,,,, -76100,0.19564895,1.6227041,,,,,,,,,,,,,,,,, -76200,0.27588078,1.6105621,,,,,,,,,,,,,,,,, -76300,0.19753529,1.6549867,,,,,,,,,,,,,,,,, -76400,0.20585276,1.7272282,,,,,,,,,,,,,,,,, -76500,0.19960098,1.6214049,,,,,,,,,,,,,,,,, -76600,0.21564204,1.5933601,,,,,,,,,,,,,,,,, -76700,0.21895514,1.6498854,,,,,,,,,,,,,,,,, -76800,0.19515921,1.6358708,,,,,,,,,,,,,,,,, -76900,0.20615736,1.6264962,,,,,,,,,,,,,,,,, -77000,0.20076732,1.6850455,,,,,,,,,,,,,,,,, -77100,0.21131143,1.6235801,,,,,,,,,,,,,,,,, -77200,0.20919444,1.6264561,,,,,,,,,,,,,,,,, -77300,0.19282034,1.6402864,,,,,,,,,,,,,,,,, -77400,0.2232701,1.6316633,,,,,,,,,,,,,,,,, -77500,0.21599731,1.6548189,,,,,,,,,,,,,,,,, -77600,0.21625955,1.6289297,,,,,,,,,,,,,,,,, -77700,0.20403485,1.7951714,,,,,,,,,,,,,,,,, -77800,0.2099043,1.6935778,,,,,,,,,,,,,,,,, -77838,,,0.6665636897087097,1.5482903718948364,33.07441180399439,0.68016517162323,1.459350347518921,29.50371177493752,3000.0,0.6929754614830017,1.3658453226089478,29.21646204970246,3003.0,27753.721145629883,47643.87626647949,27753.721145629883,19886.381717681885,1.1533830165863037,0.0 -77900,0.21195626,1.6139011,,,,,,,,,,,,,,,,, -78000,0.1932589,1.5791793,,,,,,,,,,,,,,,,, -78100,0.21058586,1.6931337,,,,,,,,,,,,,,,,, -78200,0.22887093,1.6783626,,,,,,,,,,,,,,,,, -78300,0.20633605,1.6476786,,,,,,,,,,,,,,,,, -78400,0.1880814,1.6148571,,,,,,,,,,,,,,,,, -78500,0.20485388,1.6472578,,,,,,,,,,,,,,,,, -78600,0.19914626,1.6385729,,,,,,,,,,,,,,,,, -78700,0.21323144,1.644571,,,,,,,,,,,,,,,,, -78800,0.2002234,1.5608659,,,,,,,,,,,,,,,,, -78900,0.20016748,1.59443,,,,,,,,,,,,,,,,, -79000,0.20650913,1.6165233,,,,,,,,,,,,,,,,, -79100,0.198599,1.6760242,,,,,,,,,,,,,,,,, -79200,0.21551807,1.6756889,,,,,,,,,,,,,,,,, -79300,0.20451978,1.6435869,,,,,,,,,,,,,,,,, -79400,0.24513522,1.6160725,,,,,,,,,,,,,,,,, -79500,0.2147662,1.7125243,,,,,,,,,,,,,,,,, -79600,0.21027896,1.6484591,,,,,,,,,,,,,,,,, -79700,0.19925691,1.6677123,,,,,,,,,,,,,,,,, -79800,0.21119589,1.6533741,,,,,,,,,,,,,,,,, -79900,0.21351029,1.6685715,,,,,,,,,,,,,,,,, -80000,0.21478939,1.536551,,,,,,,,,,,,,,,,, -80100,0.21151975,1.7084562,,,,,,,,,,,,,,,,, -80197,,,0.664623498916626,1.563930869102478,33.14200471817913,0.6821985840797424,1.4505184888839722,29.735981389350844,3000.0,0.6975771188735962,1.3530352115631104,29.435464433153378,3003.0,28593.74077177048,49170.16017580032,28593.74077177048,20572.528188943863,1.1913011074066162,0.0 -80200,0.19928631,1.5878159,,,,,,,,,,,,,,,,, -80300,0.21306685,1.6820742,,,,,,,,,,,,,,,,, -80400,0.20672235,1.6123768,,,,,,,,,,,,,,,,, -80500,0.21906829,1.7207459,,,,,,,,,,,,,,,,, -80600,0.20027944,1.60123,,,,,,,,,,,,,,,,, -80700,0.19914529,1.5912389,,,,,,,,,,,,,,,,, -80800,0.22467484,1.6696467,,,,,,,,,,,,,,,,, -80900,0.5030672,1.637798,,,,,,,,,,,,,,,,, -81000,0.21218686,1.6815333,,,,,,,,,,,,,,,,, -81100,0.20091417,1.575227,,,,,,,,,,,,,,,,, -81200,0.20878223,1.62461,,,,,,,,,,,,,,,,, -81300,0.20740029,1.7206821,,,,,,,,,,,,,,,,, -81400,0.21018542,1.6017311,,,,,,,,,,,,,,,,, -81500,0.19218273,1.5521,,,,,,,,,,,,,,,,, -81600,0.21657749,1.6539237,,,,,,,,,,,,,,,,, -81700,0.201437,1.599345,,,,,,,,,,,,,,,,, -81800,0.23646569,1.618078,,,,,,,,,,,,,,,,, -81900,0.20705183,1.6626742,,,,,,,,,,,,,,,,, -82000,0.20885256,1.6854329,,,,,,,,,,,,,,,,, -82100,0.22390965,1.566714,,,,,,,,,,,,,,,,, -82200,0.19800785,1.6706429,,,,,,,,,,,,,,,,, -82300,0.20063217,1.7175041,,,,,,,,,,,,,,,,, -82400,0.20633885,1.6549724,,,,,,,,,,,,,,,,, -82500,0.20335647,1.6450166,,,,,,,,,,,,,,,,, -82557,,,0.6770812273025513,1.4800242185592651,34.02139815328523,0.6815042495727539,1.4466660022735596,29.45909323576052,3000.0,0.6959037780761719,1.3536893129348757,29.501150496830565,3003.0,29433.655618667603,50747.67319107056,29433.655618667603,21310.009718179703,1.229506015777588,0.0 -82600,0.2365033,1.5650688,,,,,,,,,,,,,,,,, -82700,0.22640306,1.6304575,,,,,,,,,,,,,,,,, -82800,0.2026016,1.6361713,,,,,,,,,,,,,,,,, -82900,0.21002834,1.6229032,,,,,,,,,,,,,,,,, -83000,0.1978605,1.5734882,,,,,,,,,,,,,,,,, -83100,0.19831385,1.6289847,,,,,,,,,,,,,,,,, -83200,0.20781158,1.6624913,,,,,,,,,,,,,,,,, -83300,0.2052174,1.5580571,,,,,,,,,,,,,,,,, -83400,0.22021198,1.6229378,,,,,,,,,,,,,,,,, -83500,0.2079455,1.6346824,,,,,,,,,,,,,,,,, -83600,0.22174212,1.6276648,,,,,,,,,,,,,,,,, -83700,0.21839048,1.5818037,,,,,,,,,,,,,,,,, -83800,0.1954205,1.5488639,,,,,,,,,,,,,,,,, -83900,0.20409136,1.6620069,,,,,,,,,,,,,,,,, -84000,0.21702826,1.5780383,,,,,,,,,,,,,,,,, -84100,0.23145045,1.583626,,,,,,,,,,,,,,,,, -84200,0.19443282,1.6399091,,,,,,,,,,,,,,,,, -84300,0.20531881,1.5839195,,,,,,,,,,,,,,,,, -84400,0.20209569,1.6601198,,,,,,,,,,,,,,,,, -84500,0.21609913,1.5975558,,,,,,,,,,,,,,,,, -84600,0.206925,1.6150959,,,,,,,,,,,,,,,,, -84700,0.21999565,1.5934829,,,,,,,,,,,,,,,,, -84800,0.19609325,1.5431167,,,,,,,,,,,,,,,,, -84900,0.20431161,1.6439506,,,,,,,,,,,,,,,,, -84916,,,0.6718239784240723,1.5177863836288452,33.53921780192777,0.6841204762458801,1.4370832443237305,29.954040391928764,3000.0,0.698843777179718,1.3375470638275146,29.460929968466225,3003.0,30273.7618765831,52317.95136976242,30273.7618765831,22040.064561843872,1.2669031620025637,0.0 -85000,0.20557107,1.5934461,,,,,,,,,,,,,,,,, -85100,0.19104934,1.5645871,,,,,,,,,,,,,,,,, -85200,0.23556896,1.6565531,,,,,,,,,,,,,,,,, -85300,0.22654793,1.6346457,,,,,,,,,,,,,,,,, -85400,0.200428,1.6072663,,,,,,,,,,,,,,,,, -85500,0.22590351,1.6649808,,,,,,,,,,,,,,,,, -85600,0.20922399,1.5448028,,,,,,,,,,,,,,,,, -85700,0.20002352,1.6473428,,,,,,,,,,,,,,,,, -85800,0.21107799,1.7032201,,,,,,,,,,,,,,,,, -85900,0.2375818,1.6847826,,,,,,,,,,,,,,,,, -86000,0.20363742,1.6497207,,,,,,,,,,,,,,,,, -86100,0.2116824,1.5701044,,,,,,,,,,,,,,,,, -86200,0.21868417,1.5862253,,,,,,,,,,,,,,,,, -86300,0.2044086,1.5818199,,,,,,,,,,,,,,,,, -86400,0.21449533,1.7067363,,,,,,,,,,,,,,,,, -86500,0.21152471,1.6001185,,,,,,,,,,,,,,,,, -86600,0.21914385,1.5682052,,,,,,,,,,,,,,,,, -86700,0.23712263,1.6192014,,,,,,,,,,,,,,,,, -86800,0.22329426,1.603443,,,,,,,,,,,,,,,,, -86900,0.21228619,1.5774002,,,,,,,,,,,,,,,,, -87000,0.19650528,1.563559,,,,,,,,,,,,,,,,, -87100,0.20277022,1.553132,,,,,,,,,,,,,,,,, -87200,0.22046265,1.6141409,,,,,,,,,,,,,,,,, -87275,,,0.6702315211296082,1.5227426290512085,33.57381970171942,0.6851620078086853,1.4330618381500244,30.223012227748367,3000.0,0.699785053730011,1.3343571424484253,29.740632616160205,3003.0,31113.80547237396,53728.91593122482,31113.80547237396,22610.86890363693,1.3044085502624512,0.0 -87300,0.32098335,1.5852269,,,,,,,,,,,,,,,,, -87400,0.209108,1.5842407,,,,,,,,,,,,,,,,, -87500,0.22774863,1.652441,,,,,,,,,,,,,,,,, -87600,0.22072816,1.5624876,,,,,,,,,,,,,,,,, -87700,0.21386859,1.5957104,,,,,,,,,,,,,,,,, -87800,0.20128793,1.5611246,,,,,,,,,,,,,,,,, -87900,0.21547018,1.5663769,,,,,,,,,,,,,,,,, -88000,0.2138759,1.5340158,,,,,,,,,,,,,,,,, -88100,0.21492875,1.6208907,,,,,,,,,,,,,,,,, -88200,0.20678438,1.5983366,,,,,,,,,,,,,,,,, -88300,0.22027214,1.6407049,,,,,,,,,,,,,,,,, -88400,0.21631691,1.6120858,,,,,,,,,,,,,,,,, -88500,0.21041335,1.6077319,,,,,,,,,,,,,,,,, -88600,0.19732048,1.5105792,,,,,,,,,,,,,,,,, -88700,0.2194345,1.5948633,,,,,,,,,,,,,,,,, -88800,0.22040468,1.6360165,,,,,,,,,,,,,,,,, -88900,0.2138873,1.6305012,,,,,,,,,,,,,,,,, -89000,0.21688682,1.6008162,,,,,,,,,,,,,,,,, -89100,0.22737022,1.6055924,,,,,,,,,,,,,,,,, -89200,0.21295322,1.6109147,,,,,,,,,,,,,,,,, -89300,0.23470987,1.5865241,,,,,,,,,,,,,,,,, -89400,0.2092536,1.5599965,,,,,,,,,,,,,,,,, -89500,0.21236941,1.5358448,,,,,,,,,,,,,,,,, -89600,0.23767684,1.5494673,,,,,,,,,,,,,,,,, -89633,,,0.6788753867149353,1.4781488180160522,34.11023781864019,0.6842444539070129,1.4293118715286257,30.039313991534137,3000.0,0.7009587287902832,1.3281270265579224,29.58023972500876,3003.0,31953.73739719391,55190.8995449543,31953.73739719391,23232.804119586945,1.3434412479400637,0.0 -89700,0.2714491,1.6120439,,,,,,,,,,,,,,,,, -89800,0.20804016,1.5692202,,,,,,,,,,,,,,,,, -89900,0.20118281,1.5586094,,,,,,,,,,,,,,,,, -90000,0.22605495,1.6649864,,,,,,,,,,,,,,,,, -90100,0.23474583,1.6165075,,,,,,,,,,,,,,,,, -90200,0.20745039,1.6405687,,,,,,,,,,,,,,,,, -90300,0.20733209,1.6438704,,,,,,,,,,,,,,,,, -90400,0.21694496,1.5467895,,,,,,,,,,,,,,,,, -90500,0.2203031,1.522592,,,,,,,,,,,,,,,,, -90600,0.2424289,1.6002709,,,,,,,,,,,,,,,,, -90700,0.20629264,1.6759903,,,,,,,,,,,,,,,,, -90800,0.24497135,1.6171103,,,,,,,,,,,,,,,,, -90900,0.20458542,1.6122533,,,,,,,,,,,,,,,,, -91000,0.20707394,1.5835977,,,,,,,,,,,,,,,,, -91100,0.20920184,1.590933,,,,,,,,,,,,,,,,, -91200,0.23694777,1.6304103,,,,,,,,,,,,,,,,, -91300,0.21355885,1.6433084,,,,,,,,,,,,,,,,, -91400,0.2077679,1.536138,,,,,,,,,,,,,,,,, -91500,0.21011458,1.551192,,,,,,,,,,,,,,,,, -91600,0.20817089,1.5808293,,,,,,,,,,,,,,,,, -91700,0.2316423,1.644951,,,,,,,,,,,,,,,,, -91800,0.22847427,1.5895683,,,,,,,,,,,,,,,,, -91900,0.21444647,1.5422689,,,,,,,,,,,,,,,,, -91992,,,0.674981951713562,1.497982621192932,33.107271415973,0.6863523125648499,1.421521544456482,29.94934206485915,3000.0,0.7017953991889954,1.3203881978988647,30.07968705955243,3003.0,32793.78981876373,56641.69893527031,32793.78981876373,23843.425915002823,1.3887052536010742,0.0 -92000,0.21774201,1.6358246,,,,,,,,,,,,,,,,, -92100,0.2121851,1.5883131,,,,,,,,,,,,,,,,, -92200,0.3920177,1.6798229,,,,,,,,,,,,,,,,, -92300,0.21365921,1.5619121,,,,,,,,,,,,,,,,, -92400,0.2197099,1.6798003,,,,,,,,,,,,,,,,, -92500,0.22195195,1.5499614,,,,,,,,,,,,,,,,, -92600,0.20000221,1.5112369,,,,,,,,,,,,,,,,, -92700,0.2204048,1.6100209,,,,,,,,,,,,,,,,, -92800,0.23476453,1.6426158,,,,,,,,,,,,,,,,, -92900,0.23367262,1.6449254,,,,,,,,,,,,,,,,, -93000,0.325488,1.639581,,,,,,,,,,,,,,,,, -93100,0.21565133,1.587516,,,,,,,,,,,,,,,,, -93200,0.21991765,1.5565653,,,,,,,,,,,,,,,,, -93300,0.211434,1.5121686,,,,,,,,,,,,,,,,, -93400,0.3247167,1.5921842,,,,,,,,,,,,,,,,, -93500,0.21070646,1.5114697,,,,,,,,,,,,,,,,, -93600,0.21328409,1.550024,,,,,,,,,,,,,,,,, -93700,0.20653398,1.6034472,,,,,,,,,,,,,,,,, -93800,0.21548839,1.5130926,,,,,,,,,,,,,,,,, -93900,0.23891354,1.5787696,,,,,,,,,,,,,,,,, -94000,0.21548158,1.6024964,,,,,,,,,,,,,,,,, -94100,0.21324949,1.6723584,,,,,,,,,,,,,,,,, -94200,0.22398046,1.5483289,,,,,,,,,,,,,,,,, -94300,0.20104004,1.5172291,,,,,,,,,,,,,,,,, -94351,,,0.6841570734977722,1.4367806911468506,34.501937700227685,0.6876170039176941,1.4109599590301514,30.217930488187925,3000.0,0.704317033290863,1.3113782405853271,30.00001747016476,3003.0,33633.84631562233,58084.538499593735,33633.84631562233,24446.093323946,1.426877498626709,0.0 -94400,0.215297,1.5442133,,,,,,,,,,,,,,,,, -94500,0.2124004,1.5353081,,,,,,,,,,,,,,,,, -94600,0.23971887,1.5863088,,,,,,,,,,,,,,,,, -94700,0.22621877,1.5315135,,,,,,,,,,,,,,,,, -94800,0.21952137,1.5559361,,,,,,,,,,,,,,,,, -94900,0.2102375,1.6364019,,,,,,,,,,,,,,,,, -95000,0.21497978,1.5343902,,,,,,,,,,,,,,,,, -95100,0.22056781,1.5300909,,,,,,,,,,,,,,,,, -95200,0.2184267,1.560136,,,,,,,,,,,,,,,,, -95300,0.23205316,1.6392708,,,,,,,,,,,,,,,,, -95400,0.2179081,1.5734491,,,,,,,,,,,,,,,,, -95500,0.20753072,1.5468497,,,,,,,,,,,,,,,,, -95600,0.22496279,1.5823845,,,,,,,,,,,,,,,,, -95700,0.23472485,1.633984,,,,,,,,,,,,,,,,, -95800,0.2118989,1.4898111,,,,,,,,,,,,,,,,, -95900,0.21489702,1.5131485,,,,,,,,,,,,,,,,, -96000,0.22948071,1.5887864,,,,,,,,,,,,,,,,, -96100,0.21860926,1.5866013,,,,,,,,,,,,,,,,, -96200,0.21868506,1.447076,,,,,,,,,,,,,,,,, -96300,0.2184223,1.5502424,,,,,,,,,,,,,,,,, -96400,0.20748805,1.5314517,,,,,,,,,,,,,,,,, -96500,0.2125403,1.5543993,,,,,,,,,,,,,,,,, -96600,0.21263865,1.6098601,,,,,,,,,,,,,,,,, -96700,0.2304086,1.5388428,,,,,,,,,,,,,,,,, -96710,,,0.6783720254898071,1.473360896110535,34.322575911748935,0.6874558329582214,1.4094278812408447,30.34860323432693,3000.0,0.7056649923324585,1.3058409690856934,30.114120610001063,3003.0,34473.7677295208,59600.05470824242,34473.7677295208,25121.57190084457,1.466022491455078,0.0 -96800,0.2262166,1.5765692,,,,,,,,,,,,,,,,, -96900,0.21817236,1.5697682,,,,,,,,,,,,,,,,, -97000,0.22776836,1.5628642,,,,,,,,,,,,,,,,, -97100,0.21622853,1.4786544,,,,,,,,,,,,,,,,, -97200,0.22252508,1.5762186,,,,,,,,,,,,,,,,, -97300,0.26432657,1.6059105,,,,,,,,,,,,,,,,, -97400,0.21007241,1.5250899,,,,,,,,,,,,,,,,, -97500,0.21971029,1.5022203,,,,,,,,,,,,,,,,, -97600,0.21914493,1.5768592,,,,,,,,,,,,,,,,, -97700,0.22902982,1.5574938,,,,,,,,,,,,,,,,, -97800,0.21649753,1.5872905,,,,,,,,,,,,,,,,, -97900,0.22578065,1.5557092,,,,,,,,,,,,,,,,, -98000,0.20861268,1.5383583,,,,,,,,,,,,,,,,, -98100,0.2310033,1.5974355,,,,,,,,,,,,,,,,, -98200,0.21511045,1.5435832,,,,,,,,,,,,,,,,, -98300,0.22512926,1.5344801,,,,,,,,,,,,,,,,, -98400,0.20543717,1.5010574,,,,,,,,,,,,,,,,, -98500,0.21381126,1.5368357,,,,,,,,,,,,,,,,, -98600,0.25486583,1.5777199,,,,,,,,,,,,,,,,, -98700,0.21867381,1.5538439,,,,,,,,,,,,,,,,, -98800,0.21273229,1.480211,,,,,,,,,,,,,,,,, -98900,0.22600864,1.5803648,,,,,,,,,,,,,,,,, -99000,0.22512846,1.5034367,,,,,,,,,,,,,,,,, -99070,,,0.6781623959541321,1.4802968502044678,33.968795052975345,0.6889064908027649,1.4011088609695437,30.316761453508192,3000.0,0.7047004699707031,1.3031493425369265,30.178915366389266,3003.0,35313.945922613144,60997.73252558708,35313.945922613144,25678.95575976372,1.5046508312225342,0.0 -99100,0.23119691,1.5539625,,,,,,,,,,,,,,,,, -99200,0.23277783,1.4920635,,,,,,,,,,,,,,,,, -99300,0.22584239,1.5611699,,,,,,,,,,,,,,,,, -99400,0.2406451,1.5318071,,,,,,,,,,,,,,,,, -99500,0.33418792,1.5099115,,,,,,,,,,,,,,,,, -99600,0.22003335,1.5337924,,,,,,,,,,,,,,,,, -99700,0.22528468,1.4859276,,,,,,,,,,,,,,,,, -99800,0.22670662,1.4970412,,,,,,,,,,,,,,,,, -99900,0.23854886,1.5656314,,,,,,,,,,,,,,,,, -100000,0.21496932,1.557597,,,,,,,,,,,,,,,,, -100100,0.22658156,1.5752037,,,,,,,,,,,,,,,,, -100200,0.22448541,1.5448488,,,,,,,,,,,,,,,,, -100300,0.21681038,1.492982,,,,,,,,,,,,,,,,, -100400,0.22846332,1.5428952,,,,,,,,,,,,,,,,, -100500,0.23982136,1.5492828,,,,,,,,,,,,,,,,, -100600,0.20751138,1.489036,,,,,,,,,,,,,,,,, -100700,0.21602957,1.522428,,,,,,,,,,,,,,,,, -100800,0.24179216,1.5435894,,,,,,,,,,,,,,,,, -100900,0.2178445,1.5445231,,,,,,,,,,,,,,,,, -101000,0.23756523,1.559452,,,,,,,,,,,,,,,,, -101100,0.21826264,1.4996027,,,,,,,,,,,,,,,,, -101200,0.22892658,1.525725,,,,,,,,,,,,,,,,, -101300,0.21412979,1.5037698,,,,,,,,,,,,,,,,, -101400,0.21199392,1.5254432,,,,,,,,,,,,,,,,, -101430,,,0.6870272159576416,1.419952392578125,34.51947111345885,0.6896132826805115,1.3957315683364868,30.23538142734516,3000.0,0.7055255770683289,1.2928892374038696,30.34395915705975,3003.0,36153.99291825295,62491.293865680695,36153.99291825295,26332.34897923469,1.548503875732422,0.0 -101500,0.2265825,1.510189,,,,,,,,,,,,,,,,, -101600,0.2161583,1.6018199,,,,,,,,,,,,,,,,, -101700,0.23761323,1.5222213,,,,,,,,,,,,,,,,, -101800,0.22786695,1.5068733,,,,,,,,,,,,,,,,, -101900,0.22678532,1.6176231,,,,,,,,,,,,,,,,, -102000,0.21896619,1.516591,,,,,,,,,,,,,,,,, -102100,0.24934384,1.4873346,,,,,,,,,,,,,,,,, -102200,0.21861094,1.5548179,,,,,,,,,,,,,,,,, -102300,0.22711796,1.6460204,,,,,,,,,,,,,,,,, -102400,0.21949662,1.5919327,,,,,,,,,,,,,,,,, -102500,0.23634747,1.507264,,,,,,,,,,,,,,,,, -102600,0.21253514,1.4532567,,,,,,,,,,,,,,,,, -102700,0.23036538,1.4803427,,,,,,,,,,,,,,,,, -102800,0.23212235,1.4750819,,,,,,,,,,,,,,,,, -102900,0.21592486,1.5678018,,,,,,,,,,,,,,,,, -103000,0.23431553,1.5212659,,,,,,,,,,,,,,,,, -103100,0.22685951,1.4712633,,,,,,,,,,,,,,,,, -103200,0.20790277,1.4081813,,,,,,,,,,,,,,,,, -103300,0.20878202,1.4443039,,,,,,,,,,,,,,,,, -103400,0.21643999,1.5362458,,,,,,,,,,,,,,,,, -103500,0.23402263,1.501576,,,,,,,,,,,,,,,,, -103600,0.21857636,1.5062753,,,,,,,,,,,,,,,,, -103700,0.24143353,1.539152,,,,,,,,,,,,,,,,, -103788,,,0.6809467077255249,1.462726354598999,34.7810708399674,0.6901588439941406,1.3913047313690186,30.3594091547342,3000.0,0.7094416618347168,1.28492534160614,30.334914627920924,3003.0,36993.88757181168,63936.5658159256,36993.88757181168,26937.60739517212,1.5876469612121582,0.0 -103800,0.22357038,1.524225,,,,,,,,,,,,,,,,, -103900,0.23445496,1.5877358,,,,,,,,,,,,,,,,, -104000,0.23900443,1.5346577,,,,,,,,,,,,,,,,, -104100,0.21975978,1.5353309,,,,,,,,,,,,,,,,, -104200,0.23600943,1.5608423,,,,,,,,,,,,,,,,, -104300,0.8027412,1.572852,,,,,,,,,,,,,,,,, -104400,0.2088219,1.4769316,,,,,,,,,,,,,,,,, -104500,0.29767084,1.5498278,,,,,,,,,,,,,,,,, -104600,0.25176919,1.5032642,,,,,,,,,,,,,,,,, -104700,0.222654,1.5353236,,,,,,,,,,,,,,,,, -104800,0.22346935,1.5661894,,,,,,,,,,,,,,,,, -104900,0.22948198,1.5339236,,,,,,,,,,,,,,,,, -105000,0.22548717,1.5765768,,,,,,,,,,,,,,,,, -105100,0.22800395,1.5291781,,,,,,,,,,,,,,,,, -105200,0.23638459,1.5768939,,,,,,,,,,,,,,,,, -105300,0.22705887,1.4705452,,,,,,,,,,,,,,,,, -105400,0.21176448,1.4608705,,,,,,,,,,,,,,,,, -105500,0.23679279,1.4931087,,,,,,,,,,,,,,,,, -105600,0.2319286,1.5534399,,,,,,,,,,,,,,,,, -105700,0.20968382,1.5017321,,,,,,,,,,,,,,,,, -105800,0.23283947,1.5536369,,,,,,,,,,,,,,,,, -105900,0.21905619,1.5198543,,,,,,,,,,,,,,,,, -106000,0.22591068,1.5136421,,,,,,,,,,,,,,,,, -106100,0.22582656,1.6123323,,,,,,,,,,,,,,,,, -106147,,,0.687261164188385,1.4270676374435425,34.439128262689984,0.6915971040725708,1.3868253231048584,30.57444599493856,3000.0,0.7098251581192017,1.2795196771621704,30.52362190743765,3003.0,37833.93421292305,65344.22820472717,37833.93421292305,27505.089040517807,1.6410527229309082,0.0 -106200,0.21733756,1.550387,,,,,,,,,,,,,,,,, -106300,0.23216546,1.5693797,,,,,,,,,,,,,,,,, -106400,0.22844262,1.5284805,,,,,,,,,,,,,,,,, -106500,0.24108934,1.5151834,,,,,,,,,,,,,,,,, -106600,0.21739869,1.4699292,,,,,,,,,,,,,,,,, -106700,0.2368319,1.4996315,,,,,,,,,,,,,,,,, -106800,0.224456,1.5519735,,,,,,,,,,,,,,,,, -106900,0.21769124,1.4828537,,,,,,,,,,,,,,,,, -107000,0.23469304,1.5962816,,,,,,,,,,,,,,,,, -107100,0.22963667,1.4593103,,,,,,,,,,,,,,,,, -107200,0.21606322,1.4719306,,,,,,,,,,,,,,,,, -107300,0.23063147,1.4854993,,,,,,,,,,,,,,,,, -107400,0.23177236,1.5445949,,,,,,,,,,,,,,,,, -107500,0.22577226,1.4802587,,,,,,,,,,,,,,,,, -107600,0.22428313,1.4215468,,,,,,,,,,,,,,,,, -107700,0.23546475,1.5383685,,,,,,,,,,,,,,,,, -107800,0.24092203,1.4708148,,,,,,,,,,,,,,,,, -107900,0.22806709,1.4651438,,,,,,,,,,,,,,,,, -108000,0.2316175,1.5542871,,,,,,,,,,,,,,,,, -108100,0.22138107,1.5140319,,,,,,,,,,,,,,,,, -108200,0.24346642,1.5629362,,,,,,,,,,,,,,,,, -108300,0.23128767,1.5422121,,,,,,,,,,,,,,,,, -108400,0.23341882,1.4692028,,,,,,,,,,,,,,,,, -108500,0.25135916,1.5186399,,,,,,,,,,,,,,,,, -108506,,,0.6925578117370605,1.3921180963516235,35.037549363106514,0.6933826208114624,1.3761391639709473,30.82730396732934,3000.0,0.7099180817604065,1.2736002206802368,30.54607527793401,3003.0,38673.95379757881,66828.72117614746,38673.95379757881,28149.444666147232,1.6813023090362549,0.0 -108600,0.22173338,1.4934231,,,,,,,,,,,,,,,,, -108700,0.22718608,1.4567924,,,,,,,,,,,,,,,,, -108800,0.22620633,1.4322807,,,,,,,,,,,,,,,,, -108900,0.22286493,1.4606879,,,,,,,,,,,,,,,,, -109000,0.23233801,1.5615047,,,,,,,,,,,,,,,,, -109100,0.22908472,1.4801477,,,,,,,,,,,,,,,,, -109200,0.22818549,1.4935685,,,,,,,,,,,,,,,,, -109300,0.2254384,1.458047,,,,,,,,,,,,,,,,, -109400,0.22852997,1.4113911,,,,,,,,,,,,,,,,, -109500,0.22563982,1.5194033,,,,,,,,,,,,,,,,, -109600,0.22594987,1.4673569,,,,,,,,,,,,,,,,, -109700,0.22267233,1.4882879,,,,,,,,,,,,,,,,, -109800,0.22963098,1.4603263,,,,,,,,,,,,,,,,, -109900,0.23730421,1.5220133,,,,,,,,,,,,,,,,, -110000,0.21572061,1.4277148,,,,,,,,,,,,,,,,, -110100,0.22681817,1.5051118,,,,,,,,,,,,,,,,, -110200,0.23747921,1.5363934,,,,,,,,,,,,,,,,, -110300,0.22077495,1.4385487,,,,,,,,,,,,,,,,, -110400,0.2308924,1.4766682,,,,,,,,,,,,,,,,, -110500,0.2311827,1.4893961,,,,,,,,,,,,,,,,, -110600,0.23205756,1.412606,,,,,,,,,,,,,,,,, -110700,0.23294248,1.4650576,,,,,,,,,,,,,,,,, -110800,0.23456302,1.5298691,,,,,,,,,,,,,,,,, -110865,,,0.6871196031570435,1.423920512199402,34.8735757618847,0.6927750110626221,1.379041075706482,30.65946671057725,3000.0,0.7104293704032898,1.2698314189910889,30.823263329626307,3003.0,39513.917432546616,68266.67229032516,39513.917432546616,28747.312561512,1.722498893737793,0.0 -110900,0.22480223,1.5480268,,,,,,,,,,,,,,,,, -111000,0.21886212,1.5427023,,,,,,,,,,,,,,,,, -111100,0.22417487,1.5310265,,,,,,,,,,,,,,,,, -111200,0.21794342,1.4815371,,,,,,,,,,,,,,,,, -111300,0.22413895,1.452186,,,,,,,,,,,,,,,,, -111400,0.22876017,1.5153579,,,,,,,,,,,,,,,,, -111500,0.23490526,1.5226628,,,,,,,,,,,,,,,,, -111600,0.24277933,1.5048962,,,,,,,,,,,,,,,,, -111700,0.23792209,1.5045214,,,,,,,,,,,,,,,,, -111800,0.22905686,1.4311429,,,,,,,,,,,,,,,,, -111900,0.23594199,1.4972279,,,,,,,,,,,,,,,,, -112000,0.24628432,1.4262744,,,,,,,,,,,,,,,,, -112100,0.2324305,1.4370862,,,,,,,,,,,,,,,,, -112200,0.23606682,1.5903819,,,,,,,,,,,,,,,,, -112300,0.25166076,1.5260067,,,,,,,,,,,,,,,,, -112400,0.22513257,1.4283854,,,,,,,,,,,,,,,,, -112500,0.22389548,1.4808189,,,,,,,,,,,,,,,,, -112600,0.24402153,1.5013977,,,,,,,,,,,,,,,,, -112700,0.21049766,1.4318062,,,,,,,,,,,,,,,,, -112800,0.22927542,1.4622234,,,,,,,,,,,,,,,,, -112900,0.25242317,1.4935273,,,,,,,,,,,,,,,,, -113000,0.22716595,1.5335612,,,,,,,,,,,,,,,,, -113100,0.22428472,1.4225888,,,,,,,,,,,,,,,,, -113200,0.23489375,1.5154289,,,,,,,,,,,,,,,,, -113222,,,0.6949504017829895,1.366610407829285,35.29129947522768,0.6940769553184509,1.373827338218689,30.822025096723628,3000.0,0.7110685110092163,1.2664237022399902,30.57820511062706,3003.0,40353.82215619087,69755.15716338158,40353.82215619087,29395.77043747902,1.7632851600646973,0.0 -113300,0.24144603,1.5614009,,,,,,,,,,,,,,,,, -113400,0.22980176,1.5037143,,,,,,,,,,,,,,,,, -113500,0.22877932,1.4874866,,,,,,,,,,,,,,,,, -113600,0.23076943,1.5385176,,,,,,,,,,,,,,,,, -113700,0.23944242,1.5545199,,,,,,,,,,,,,,,,, -113800,0.2361035,1.4709777,,,,,,,,,,,,,,,,, -113900,0.24017279,1.5610356,,,,,,,,,,,,,,,,, -114000,0.23799405,1.4458673,,,,,,,,,,,,,,,,, -114100,0.22642244,1.413123,,,,,,,,,,,,,,,,, -114200,0.24622092,1.4221313,,,,,,,,,,,,,,,,, -114300,0.22701177,1.5069344,,,,,,,,,,,,,,,,, -114400,0.22514051,1.4752232,,,,,,,,,,,,,,,,, -114500,0.22876331,1.4621263,,,,,,,,,,,,,,,,, -114600,0.23982927,1.4836448,,,,,,,,,,,,,,,,, -114700,0.24169426,1.5386436,,,,,,,,,,,,,,,,, -114800,0.23152405,1.4414711,,,,,,,,,,,,,,,,, -114900,0.27248552,1.4830328,,,,,,,,,,,,,,,,, -115000,0.23061268,1.467392,,,,,,,,,,,,,,,,, -115100,0.3315788,1.4698586,,,,,,,,,,,,,,,,, -115200,0.26083496,1.4301012,,,,,,,,,,,,,,,,, -115300,0.24455962,1.5067812,,,,,,,,,,,,,,,,, -115400,0.23540846,1.4820182,,,,,,,,,,,,,,,,, -115500,0.25438643,1.4490308,,,,,,,,,,,,,,,,, -115581,,,0.692932665348053,1.3845998048782349,34.87558627803715,0.6947588920593262,1.3719394207000732,30.69022831385768,3000.0,0.7114287614822388,1.2623103857040403,30.915548870642528,3003.0,41193.96239209175,71226.29557180405,41193.96239209175,30026.637234210968,1.8142666816711424,0.0 -115600,0.24012831,1.4735683,,,,,,,,,,,,,,,,, -115700,0.2400922,1.4935424,,,,,,,,,,,,,,,,, -115800,0.24224839,1.4401613,,,,,,,,,,,,,,,,, -115900,0.2338954,1.4358946,,,,,,,,,,,,,,,,, -116000,0.23173769,1.5015194,,,,,,,,,,,,,,,,, -116100,0.21680643,1.4143444,,,,,,,,,,,,,,,,, -116200,0.22524372,1.4584838,,,,,,,,,,,,,,,,, -116300,0.2326616,1.4103053,,,,,,,,,,,,,,,,, -116400,0.24096961,1.4335854,,,,,,,,,,,,,,,,, -116500,0.22670688,1.5348148,,,,,,,,,,,,,,,,, -116600,0.23576504,1.5081478,,,,,,,,,,,,,,,,, -116700,0.22503246,1.438626,,,,,,,,,,,,,,,,, -116800,0.23430109,1.5349921,,,,,,,,,,,,,,,,, -116900,0.22651058,1.5220058,,,,,,,,,,,,,,,,, -117000,0.23848048,1.524722,,,,,,,,,,,,,,,,, -117100,0.2346474,1.4365757,,,,,,,,,,,,,,,,, -117200,0.24441195,1.503507,,,,,,,,,,,,,,,,, -117300,0.23138182,1.4113786,,,,,,,,,,,,,,,,, -117400,0.2458329,1.5670615,,,,,,,,,,,,,,,,, -117500,0.23067477,1.5036005,,,,,,,,,,,,,,,,, -117600,0.22717816,1.4333428,,,,,,,,,,,,,,,,, -117700,0.25240833,1.583502,,,,,,,,,,,,,,,,, -117800,0.23847786,1.4945195,,,,,,,,,,,,,,,,, -117900,0.23063873,1.4301127,,,,,,,,,,,,,,,,, -117940,,,0.6903288960456848,1.406310796737671,35.18875162329316,0.6950441002845764,1.3678932189941406,30.774742153044187,3000.0,0.7126837372779846,1.2599252462387085,31.04379746971764,3003.0,42033.97531580925,72615.43852734566,42033.97531580925,30575.64464998245,1.8567523956298828,0.0 -118000,0.23446275,1.4088039,,,,,,,,,,,,,,,,, -118100,0.23138946,1.4115466,,,,,,,,,,,,,,,,, -118200,0.23733744,1.4620217,,,,,,,,,,,,,,,,, -118300,0.22361566,1.4605843,,,,,,,,,,,,,,,,, -118400,0.24090792,1.455029,,,,,,,,,,,,,,,,, -118500,0.24267146,1.4646153,,,,,,,,,,,,,,,,, -118600,0.23613434,1.506762,,,,,,,,,,,,,,,,, -118700,0.24237192,1.5412707,,,,,,,,,,,,,,,,, -118800,0.22902831,1.4447346,,,,,,,,,,,,,,,,, -118900,0.23366295,1.4675773,,,,,,,,,,,,,,,,, -119000,0.23182966,1.4453555,,,,,,,,,,,,,,,,, -119100,0.23822881,1.4298509,,,,,,,,,,,,,,,,, -119200,0.23426446,1.4036164,,,,,,,,,,,,,,,,, -119300,0.22918771,1.4192493,,,,,,,,,,,,,,,,, -119400,0.22533952,1.4128302,,,,,,,,,,,,,,,,, -119500,0.2300107,1.4055089,,,,,,,,,,,,,,,,, -119600,0.22414716,1.419624,,,,,,,,,,,,,,,,, -119700,0.23010485,1.4922863,,,,,,,,,,,,,,,,, -119800,0.23487742,1.4546994,,,,,,,,,,,,,,,,, -119900,0.23200871,1.4188596,,,,,,,,,,,,,,,,, -120000,0.23400584,1.436497,,,,,,,,,,,,,,,,, -120100,0.23198836,1.5208205,,,,,,,,,,,,,,,,, -120200,0.23628134,1.4730848,,,,,,,,,,,,,,,,, -120299,,,0.6964380145072937,1.3709895610809326,35.66247077415662,0.6952300667762756,1.3651654720306396,30.875901980810767,3000.0,0.7123816609382629,1.2575608491897583,30.803521074857283,3003.0,42873.95320272446,74017.72199606895,42873.95320272446,31137.830321788788,1.8995048999786377,0.0 -120300,0.2398367,1.4347695,,,,,,,,,,,,,,,,, -120400,0.24346213,1.5134002,,,,,,,,,,,,,,,,, -120500,0.23444147,1.4334748,,,,,,,,,,,,,,,,, -120600,0.23725449,1.4848526,,,,,,,,,,,,,,,,, -120700,0.2247539,1.4006282,,,,,,,,,,,,,,,,, -120800,0.2367481,1.409355,,,,,,,,,,,,,,,,, -120900,0.24992862,1.4383879,,,,,,,,,,,,,,,,, -121000,0.24214949,1.4328638,,,,,,,,,,,,,,,,, -121100,0.2236883,1.429405,,,,,,,,,,,,,,,,, -121200,0.2401887,1.4555871,,,,,,,,,,,,,,,,, -121300,0.22620547,1.4850445,,,,,,,,,,,,,,,,, -121400,0.24587922,1.4634504,,,,,,,,,,,,,,,,, -121500,0.24143891,1.4806892,,,,,,,,,,,,,,,,, -121600,0.23810986,1.4731824,,,,,,,,,,,,,,,,, -121700,0.2366444,1.4365759,,,,,,,,,,,,,,,,, -121800,0.2502376,1.5287832,,,,,,,,,,,,,,,,, -121900,0.22835395,1.507115,,,,,,,,,,,,,,,,, -122000,0.23150335,1.4554695,,,,,,,,,,,,,,,,, -122100,0.22850421,1.4377244,,,,,,,,,,,,,,,,, -122200,0.24471648,1.5403104,,,,,,,,,,,,,,,,, -122300,0.23195869,1.4557981,,,,,,,,,,,,,,,,, -122400,0.22744235,1.4631264,,,,,,,,,,,,,,,,, -122500,0.23112075,1.4441442,,,,,,,,,,,,,,,,, -122600,0.2428244,1.4648547,,,,,,,,,,,,,,,,, -122659,,,0.6943261623382568,1.382710337638855,35.10050834963898,0.6955276131629944,1.3625982999801636,31.04729680749209,3000.0,0.7139387726783752,1.2523609399795532,30.83641365004823,3003.0,43713.98013854027,75454.82591438293,43713.98013854027,31734.78758907318,1.942683458328247,0.0 -122700,0.24951504,1.5124983,,,,,,,,,,,,,,,,, -122800,0.23154993,1.4307337,,,,,,,,,,,,,,,,, -122900,0.23999895,1.3974175,,,,,,,,,,,,,,,,, -123000,0.2394764,1.4721495,,,,,,,,,,,,,,,,, -123100,0.22791335,1.4630563,,,,,,,,,,,,,,,,, -123200,0.23781003,1.4070206,,,,,,,,,,,,,,,,, -123300,0.22378941,1.4252791,,,,,,,,,,,,,,,,, -123400,0.23742495,1.4988726,,,,,,,,,,,,,,,,, -123500,0.24227601,1.4879997,,,,,,,,,,,,,,,,, -123600,0.24133314,1.5020783,,,,,,,,,,,,,,,,, -123700,0.23966907,1.4747283,,,,,,,,,,,,,,,,, -123800,0.23980775,1.3996394,,,,,,,,,,,,,,,,, -123900,0.23418596,1.4693037,,,,,,,,,,,,,,,,, -124000,0.22801185,1.4541037,,,,,,,,,,,,,,,,, -124100,0.241552,1.4790109,,,,,,,,,,,,,,,,, -124200,0.23752889,1.4419625,,,,,,,,,,,,,,,,, -124300,0.23916757,1.4673036,,,,,,,,,,,,,,,,, -124400,0.24007097,1.4232284,,,,,,,,,,,,,,,,, -124500,0.23155734,1.4237094,,,,,,,,,,,,,,,,, -124600,0.24481598,1.4887016,,,,,,,,,,,,,,,,, -124700,0.23777542,1.4467448,,,,,,,,,,,,,,,,, -124800,0.24092667,1.4302602,,,,,,,,,,,,,,,,, -124900,0.24353966,1.4570204,,,,,,,,,,,,,,,,, -125000,0.23413125,1.4976736,,,,,,,,,,,,,,,,, -125019,,,0.6947394609451294,1.377518892288208,35.855026133964195,0.6967179775238037,1.361031174659729,30.87607412885397,3000.0,0.713927149772644,1.2523263692855835,30.98107247465367,3003.0,44553.89228606224,76870.24442744255,44553.89228606224,32310.166967868805,1.9930167198181152,0.0 -125100,0.2396596,1.4926839,,,,,,,,,,,,,,,,, -125200,0.23850217,1.4294983,,,,,,,,,,,,,,,,, -125300,0.23392461,1.3847637,,,,,,,,,,,,,,,,, -125400,0.23023412,1.469189,,,,,,,,,,,,,,,,, -125500,0.23275325,1.4286939,,,,,,,,,,,,,,,,, -125600,0.23690923,1.4839368,,,,,,,,,,,,,,,,, -125700,0.23986247,1.4522094,,,,,,,,,,,,,,,,, -125800,0.23376043,1.4573689,,,,,,,,,,,,,,,,, -125900,0.25432473,1.413128,,,,,,,,,,,,,,,,, -126000,0.23846596,1.4286085,,,,,,,,,,,,,,,,, -126100,0.23116633,1.4720955,,,,,,,,,,,,,,,,, -126200,0.2396752,1.4392463,,,,,,,,,,,,,,,,, -126300,0.22619964,1.3984898,,,,,,,,,,,,,,,,, -126400,0.21728677,1.4085145,,,,,,,,,,,,,,,,, -126500,0.24066241,1.4841411,,,,,,,,,,,,,,,,, -126600,0.2327468,1.4690341,,,,,,,,,,,,,,,,, -126700,0.23013936,1.3825402,,,,,,,,,,,,,,,,, -126800,0.2408612,1.4228947,,,,,,,,,,,,,,,,, -126900,0.23247992,1.4050814,,,,,,,,,,,,,,,,, -127000,0.22600438,1.4215232,,,,,,,,,,,,,,,,, -127100,0.24249686,1.4226606,,,,,,,,,,,,,,,,, -127200,0.23736769,1.4356376,,,,,,,,,,,,,,,,, -127300,0.23790035,1.4547493,,,,,,,,,,,,,,,,, -127379,,,0.6961644291877747,1.368328094482422,35.79156931045429,0.6965195536613464,1.3605613708496094,30.94567958438413,3000.0,0.7140550017356873,1.251923441886902,30.88051741175642,3003.0,45394.04832172394,78304.93042588234,45394.04832172394,32904.57277917862,2.0373668670654297,0.0 -127400,0.22919276,1.4334195,,,,,,,,,,,,,,,,, -127500,0.24090037,1.4662979,,,,,,,,,,,,,,,,, -127600,0.23329377,1.4000616,,,,,,,,,,,,,,,,, -127700,0.23093185,1.4795552,,,,,,,,,,,,,,,,, -127800,0.22412238,1.4200349,,,,,,,,,,,,,,,,, -127900,0.22927155,1.4410304,,,,,,,,,,,,,,,,, -128000,0.23397692,1.4314722,,,,,,,,,,,,,,,,, -128100,0.22732024,1.4349483,,,,,,,,,,,,,,,,, -128200,0.23050258,1.4187354,,,,,,,,,,,,,,,,, -128300,0.23112136,1.4396155,,,,,,,,,,,,,,,,, -128400,0.23456629,1.4089063,,,,,,,,,,,,,,,,, -128500,0.24249722,1.4327351,,,,,,,,,,,,,,,,, -128600,0.22684649,1.3851182,,,,,,,,,,,,,,,,, -128700,0.23412687,1.389767,,,,,,,,,,,,,,,,, -128800,0.23715636,1.4850552,,,,,,,,,,,,,,,,, -128900,0.24588643,1.4171755,,,,,,,,,,,,,,,,, -129000,0.24642657,1.4537859,,,,,,,,,,,,,,,,, -129100,0.22405453,1.4445329,,,,,,,,,,,,,,,,, -129200,0.23895217,1.4859728,,,,,,,,,,,,,,,,, -129300,0.24214347,1.4696404,,,,,,,,,,,,,,,,, -129400,0.2341847,1.4407861,,,,,,,,,,,,,,,,, -129500,0.23580594,1.3801876,,,,,,,,,,,,,,,,, -129600,0.2282278,1.4409475,,,,,,,,,,,,,,,,, -129700,0.23383997,1.5431648,,,,,,,,,,,,,,,,, -129739,,,0.6961685419082642,1.3700485229492188,35.88213839829184,0.6964327692985535,1.36008882522583,30.98190345178364,3000.0,0.7140898704528809,1.2510143518447876,30.877321535383,3003.0,46234.16683530808,79724.03761744499,46234.16683530808,33483.43272519112,2.08833909034729,0.0 -129800,0.2378554,1.4452418,,,,,,,,,,,,,,,,, -129900,0.22399831,1.3824023,,,,,,,,,,,,,,,,, -130000,0.234925,1.4600474,,,,,,,,,,,,,,,,, -130100,0.23684974,1.43439,,,,,,,,,,,,,,,,, -130200,0.23786119,1.4562972,,,,,,,,,,,,,,,,, -130300,0.23402439,1.4539194,,,,,,,,,,,,,,,,, -130400,0.22890757,1.4410388,,,,,,,,,,,,,,,,, -130500,0.22748783,1.4434682,,,,,,,,,,,,,,,,, -130600,0.23085558,1.4457487,,,,,,,,,,,,,,,,, -130700,0.23538145,1.4555236,,,,,,,,,,,,,,,,, -130800,0.23718973,1.4936802,,,,,,,,,,,,,,,,, -130900,0.23331432,1.4109058,,,,,,,,,,,,,,,,, -131000,0.23291366,1.4332087,,,,,,,,,,,,,,,,, -131100,0.2357533,1.4345778,,,,,,,,,,,,,,,,, -131200,0.2267108,1.3623766,,,,,,,,,,,,,,,,, -131300,0.24103475,1.4441903,,,,,,,,,,,,,,,,, -131400,0.22531253,1.397573,,,,,,,,,,,,,,,,, -131500,0.23457058,1.4335449,,,,,,,,,,,,,,,,, -131600,0.23746203,1.5011586,,,,,,,,,,,,,,,,, -131700,0.24317609,1.4937798,,,,,,,,,,,,,,,,, -131800,0.23326364,1.5022688,,,,,,,,,,,,,,,,, -131900,0.23004778,1.3954433,,,,,,,,,,,,,,,,, -132000,0.2303924,1.4160051,,,,,,,,,,,,,,,,, -132099,,,0.6940845251083374,1.3850692510604858,35.896859808845385,0.6964203715324402,1.359937071800232,30.921410299106665,3000.0,0.7141827940940857,1.2505288124084473,30.886567051931777,3003.0,47074.08837556839,81159.41278767586,47074.08837556839,34078.76060009003,2.137282371520996,0.0 -132100,0.23819841,1.4552639,,,,,,,,,,,,,,,,, -132200,0.23248462,1.4586922,,,,,,,,,,,,,,,,, -132300,0.2387679,1.5190489,,,,,,,,,,,,,,,,, -132400,0.23352113,1.4173598,,,,,,,,,,,,,,,,, -132500,0.2427483,1.4283396,,,,,,,,,,,,,,,,, -132600,0.22773889,1.5001873,,,,,,,,,,,,,,,,, -132700,0.23215678,1.3867995,,,,,,,,,,,,,,,,, -132800,0.23097154,1.3392466,,,,,,,,,,,,,,,,, -132900,0.22690254,1.4488834,,,,,,,,,,,,,,,,, -133000,0.24937621,1.4740685,,,,,,,,,,,,,,,,, -133100,0.22662936,1.4496013,,,,,,,,,,,,,,,,, -133200,0.2410606,1.5091478,,,,,,,,,,,,,,,,, -133300,0.22854903,1.4606477,,,,,,,,,,,,,,,,, -133333,,,0.6966027021408081,1.3738631010055542,35.481369630996994,0.6963583827018738,1.360085368156433,30.92441491532882,3000.0,0.7141595482826233,1.2506054639816284,30.925173570681565,3003.0,47513.10943126679,82169.46109032631,47513.10943126679,34649.70331478119,2.181474447250366,0.0 -133333,,,,,,,,,,,,,,47513.109431266785,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 40611c881..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -785.1935272216797,0.0,27.716558933258057,1,0,27.716558933258057,1.1383302977796053,95000000,812.9101238250732,1.1387101996619746,1.1412394287140513,83274637 -1433.7186193466189,0.0321598052978515,1227.957686662674,1576,0,1227.957686662674,0.1283590380139802,95000000,2661.7595295906067,0.1265354322374991,0.1257825652357692,83274637 -2006.2125263214111,0.0580251216888427,2427.9918830394745,3157,0,2427.9918830394745,0.1273610610094572,95000000,4434.363398075104,0.1250145559509595,0.1248898546215668,83274637 -2546.6115198135376,0.0854334831237793,3628.334785699845,4736,0,3628.334785699845,0.1271178178659539,95000000,6175.182409763336,0.1245123667630759,0.1246454173882349,83274637 -3048.265507698059,0.107569932937622,4828.327179908752,6300,0,4828.327179908752,0.1267488249794408,95000000,7876.901391744614,0.1229369122089829,0.1243104364736678,83274637 -3489.507905483246,0.137164831161499,6028.836967468262,7873,0,6028.836967468262,0.126348890234375,95000000,9518.732946872711,0.1235138203644152,0.1239335534513497,83274637 -3810.321284532547,0.16239404678344727,7228.878541469574,9450,0,7228.878541469574,0.12599259650493422,95000000,11039.662580251694,0.12572382367069615,0.1236452393286926,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index cfa214078..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,110 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.7002125,1.1401201,,,,,,,,,,, -1,,,1.1387101996619746,1.1412394287140513,83274637.0,1.1383302977796053,95000000.0,27.716558933258057,812.9101238250732,27.716558933258057,785.1935272216797,0.0,0.0 -100,0.24383008,0.14878187,,,,,,,,,,, -200,0.013091584,0.12994152,,,,,,,,,,, -300,0.010891981,0.12662204,,,,,,,,,,, -400,0.015032606,0.13138969,,,,,,,,,,, -500,0.03152271,0.125022,,,,,,,,,,, -600,0.053681355,0.12564978,,,,,,,,,,, -700,0.011523451,0.12797575,,,,,,,,,,, -800,0.008856374,0.12639582,,,,,,,,,,, -900,0.023055794,0.12295103,,,,,,,,,,, -1000,0.030516803,0.12980601,,,,,,,,,,, -1100,0.020605383,0.123496965,,,,,,,,,,, -1200,0.028999712,0.12137407,,,,,,,,,,, -1300,0.01464291,0.12002707,,,,,,,,,,, -1400,0.025066461,0.12370108,,,,,,,,,,, -1500,0.008459271,0.1271017,,,,,,,,,,, -1576,,,0.1265354322374991,0.1257825652357692,83274637.0,0.1283590380139802,95000000.0,1227.957686662674,2661.7595295906067,1227.957686662674,1433.7186193466189,0.0321598052978515,0.0 -1600,0.01162522,0.12531254,,,,,,,,,,, -1700,0.017693082,0.1265181,,,,,,,,,,, -1800,0.030237064,0.12776184,,,,,,,,,,, -1900,0.015421996,0.12462021,,,,,,,,,,, -2000,0.01435982,0.12458937,,,,,,,,,,, -2100,0.0067298403,0.122516975,,,,,,,,,,, -2200,0.010286536,0.12670936,,,,,,,,,,, -2300,0.03861526,0.117453806,,,,,,,,,,, -2400,0.0137084965,0.124727495,,,,,,,,,,, -2500,0.0057432214,0.13027062,,,,,,,,,,, -2600,0.008523842,0.120722644,,,,,,,,,,, -2700,0.006940018,0.12078036,,,,,,,,,,, -2800,0.007575915,0.123394564,,,,,,,,,,, -2900,0.007605333,0.11718398,,,,,,,,,,, -3000,0.005237864,0.12716387,,,,,,,,,,, -3100,0.010437182,0.12145832,,,,,,,,,,, -3157,,,0.1250145559509595,0.1248898546215668,83274637.0,0.1273610610094572,95000000.0,2427.9918830394745,4434.363398075104,2427.9918830394745,2006.2125263214111,0.0580251216888427,0.0 -3200,0.012158472,0.11746885,,,,,,,,,,, -3300,0.013264378,0.123476185,,,,,,,,,,, -3400,0.0064012823,0.12646101,,,,,,,,,,, -3500,0.018339187,0.11728511,,,,,,,,,,, -3600,0.008660787,0.114668064,,,,,,,,,,, -3700,0.011241067,0.1228232,,,,,,,,,,, -3800,0.014488466,0.12458989,,,,,,,,,,, -3900,0.006825056,0.12284803,,,,,,,,,,, -4000,0.017386807,0.12693106,,,,,,,,,,, -4100,0.029129008,0.115885384,,,,,,,,,,, -4200,0.042656157,0.12910168,,,,,,,,,,, -4300,0.015078864,0.119770035,,,,,,,,,,, -4400,0.011522433,0.12070279,,,,,,,,,,, -4500,0.016065499,0.12676589,,,,,,,,,,, -4600,0.0066766036,0.1260093,,,,,,,,,,, -4700,0.0123857055,0.11825952,,,,,,,,,,, -4736,,,0.1245123667630759,0.1246454173882349,83274637.0,0.1271178178659539,95000000.0,3628.334785699845,6175.182409763336,3628.334785699845,2546.6115198135376,0.0854334831237793,0.0 -4800,0.006367566,0.12626922,,,,,,,,,,, -4900,0.0065239687,0.119995154,,,,,,,,,,, -5000,0.01444806,0.122415766,,,,,,,,,,, -5100,0.01151765,0.12505803,,,,,,,,,,, -5200,0.00941134,0.11976142,,,,,,,,,,, -5300,0.01467388,0.12235555,,,,,,,,,,, -5400,0.012974384,0.122432455,,,,,,,,,,, -5500,0.016527114,0.122912124,,,,,,,,,,, -5600,0.014621125,0.13048,,,,,,,,,,, -5700,0.007303681,0.12623139,,,,,,,,,,, -5800,0.02923589,0.13564952,,,,,,,,,,, -5900,0.013886849,0.13469706,,,,,,,,,,, -6000,0.011891468,0.13042505,,,,,,,,,,, -6100,0.0064127087,0.12450581,,,,,,,,,,, -6200,0.0057824897,0.12877831,,,,,,,,,,, -6300,,,0.1229369122089829,0.1243104364736678,83274637.0,0.1267488249794408,95000000.0,4828.327179908752,7876.901391744614,4828.327179908752,3048.265507698059,0.107569932937622,0.0 -6300,0.0072490536,0.11581876,,,,,,,,,,, -6400,0.013236335,0.11770435,,,,,,,,,,, -6500,0.019342236,0.114047,,,,,,,,,,, -6600,0.009538265,0.1202817,,,,,,,,,,, -6700,0.008865009,0.1235397,,,,,,,,,,, -6800,0.012066671,0.12931274,,,,,,,,,,, -6900,0.015932625,0.1273464,,,,,,,,,,, -7000,0.007357311,0.12648769,,,,,,,,,,, -7100,0.010450959,0.12574467,,,,,,,,,,, -7200,0.010972499,0.12807378,,,,,,,,,,, -7300,0.0062460583,0.118562154,,,,,,,,,,, -7400,0.009980612,0.12098844,,,,,,,,,,, -7500,0.00541103,0.12361343,,,,,,,,,,, -7600,0.018055994,0.12365401,,,,,,,,,,, -7700,0.0114403805,0.12564743,,,,,,,,,,, -7800,0.0076588653,0.12206203,,,,,,,,,,, -7873,,,0.1235138203644152,0.1239335534513497,83274637.0,0.126348890234375,95000000.0,6028.836967468262,9518.732946872711,6028.836967468262,3489.507905483246,0.137164831161499,0.0 -7900,0.017706854,0.12267377,,,,,,,,,,, -8000,0.015573464,0.11957896,,,,,,,,,,, -8100,0.0090615535,0.13428412,,,,,,,,,,, -8200,0.0074274964,0.1307882,,,,,,,,,,, -8300,0.005932137,0.11898576,,,,,,,,,,, -8400,0.0064847604,0.12572427,,,,,,,,,,, -8500,0.0077352314,0.115640104,,,,,,,,,,, -8600,0.00708441,0.11689823,,,,,,,,,,, -8700,0.016511008,0.1345076,,,,,,,,,,, -8800,0.008876984,0.120986514,,,,,,,,,,, -8900,0.019547682,0.12798372,,,,,,,,,,, -9000,0.0116292,0.12477104,,,,,,,,,,, -9100,0.0073604668,0.12239869,,,,,,,,,,, -9200,0.011416247,0.11910376,,,,,,,,,,, -9300,0.027518792,0.11765481,,,,,,,,,,, -9400,0.010459235,0.11933068,,,,,,,,,,, -9450,,,0.1257238236706961,0.1236452393286926,83274637.0,0.1259925965049342,95000000.0,7228.878541469574,11039.662580251694,7228.878541469574,3810.321284532547,0.1623940467834472,0.0 -9500,0.0154415015,0.11936553,,,,,,,,,,, -9600,0.01433595,0.11523857,,,,,,,,,,, -9700,0.007792307,0.11882177,,,,,,,,,,, -9800,0.009383341,0.12641138,,,,,,,,,,, -9900,0.006927387,0.12728961,,,,,,,,,,, -10000,0.009838309,0.12015448,,,,,,,,,,, -10082,,,,,,,,7703.8553948402405,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/eval_measurements.csv deleted file mode 100644 index a66c3055e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -142.94095396995544,0.0,5.999122381210327,1,0,5.999122381210327,1.1383302977796053,95000000,148.94012141227722,1.1381916647437234,1.1412394287140513,83274637 -164.9314045906067,0.0208840370178222,1206.7365925312042,1412,0,1206.7365925312042,0.1283274101356908,95000000,1371.735160112381,0.1242735869479629,0.1258216658077416,83274637 -188.3298783302307,0.0456488132476806,2406.797998189926,2865,0,2406.797998189926,0.1273177606393914,95000000,2595.266819000244,0.1248169975273264,0.1249308702142706,83274637 -210.46006774902344,0.0704514980316162,3606.999135971069,4321,0,3606.999135971069,0.1269606765522204,95000000,3817.669621944428,0.1232700200294548,0.1247298527514145,83274637 -233.3674511909485,0.0937106609344482,4807.333256721497,5755,0,4807.333256721497,0.1267128283511513,95000000,5040.979163885117,0.1226670890109344,0.1243506332396261,83274637 -255.59379959106445,0.1180720329284668,6007.554527759552,7201,0,6007.554527759552,0.1262388908717105,95000000,6263.496790409088,0.124997078575803,0.1239542023236318,83274637 -278.29405975341797,0.14840149879455566,7208.236174821854,8653,0,7208.236174821854,0.1260754088199013,95000000,7486.954886198044,0.12412071260828642,0.12380786007471578,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/measurements.csv deleted file mode 100644 index bd7589513..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/measurements.csv +++ /dev/null @@ -1,101 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.895897,1.1361253,,,,,,,,,,, -1,,,1.1381916647437234,1.1412394287140513,83274637.0,1.1383302977796053,95000000.0,5.999122381210327,148.94012141227722,5.999122381210327,142.94095396995544,0.0,0.0 -100,0.15877701,0.13580965,,,,,,,,,,, -200,0.45911163,0.13721,,,,,,,,,,, -300,0.080517955,0.12535727,,,,,,,,,,, -400,0.04639653,0.12652424,,,,,,,,,,, -500,0.069940194,0.13342834,,,,,,,,,,, -600,0.03939476,0.12435555,,,,,,,,,,, -700,0.021190166,0.12412526,,,,,,,,,,, -800,0.0762932,0.13642465,,,,,,,,,,, -900,0.011331771,0.124189004,,,,,,,,,,, -1000,0.04241869,0.12683073,,,,,,,,,,, -1100,0.106556594,0.1233192,,,,,,,,,,, -1200,0.007934796,0.1149347,,,,,,,,,,, -1300,0.10373124,0.13081092,,,,,,,,,,, -1400,0.04536738,0.13309678,,,,,,,,,,, -1412,,,0.1242735869479629,0.1258216658077416,83274637.0,0.1283274101356908,95000000.0,1206.7365925312042,1371.735160112381,1206.7365925312042,164.9314045906067,0.0208840370178222,0.0 -1500,0.009404721,0.11950159,,,,,,,,,,, -1600,0.03353843,0.12368977,,,,,,,,,,, -1700,0.05589022,0.13267964,,,,,,,,,,, -1800,0.024120728,0.12651376,,,,,,,,,,, -1900,0.00846579,0.12007589,,,,,,,,,,, -2000,0.013400545,0.11880896,,,,,,,,,,, -2100,0.027414037,0.13303886,,,,,,,,,,, -2200,0.024530385,0.12395034,,,,,,,,,,, -2300,0.034361623,0.122060575,,,,,,,,,,, -2400,0.02173608,0.11964709,,,,,,,,,,, -2500,0.011439231,0.12664361,,,,,,,,,,, -2600,0.046755638,0.13233712,,,,,,,,,,, -2700,0.025655124,0.119589336,,,,,,,,,,, -2800,0.008127233,0.12308384,,,,,,,,,,, -2865,,,0.1248169975273264,0.1249308702142706,83274637.0,0.1273177606393914,95000000.0,2406.797998189926,2595.266819000244,2406.797998189926,188.3298783302307,0.0456488132476806,0.0 -2900,0.017054843,0.12132361,,,,,,,,,,, -3000,0.020355346,0.11852489,,,,,,,,,,, -3100,0.02414406,0.13113335,,,,,,,,,,, -3200,0.033654887,0.117789485,,,,,,,,,,, -3300,0.013705828,0.116813466,,,,,,,,,,, -3400,0.055415366,0.12538119,,,,,,,,,,, -3500,0.03051774,0.1219767,,,,,,,,,,, -3600,0.06419036,0.13297233,,,,,,,,,,, -3700,0.029351214,0.12778388,,,,,,,,,,, -3800,0.012299313,0.13315862,,,,,,,,,,, -3900,0.026741387,0.1282285,,,,,,,,,,, -4000,0.03411944,0.122529365,,,,,,,,,,, -4100,0.015100866,0.11782851,,,,,,,,,,, -4200,0.054741293,0.1345721,,,,,,,,,,, -4300,0.032419518,0.122347705,,,,,,,,,,, -4321,,,0.1232700200294548,0.1247298527514145,83274637.0,0.1269606765522204,95000000.0,3606.999135971069,3817.669621944428,3606.999135971069,210.46006774902344,0.0704514980316162,0.0 -4400,0.008869184,0.12781206,,,,,,,,,,, -4500,0.037103366,0.12935355,,,,,,,,,,, -4600,0.009751366,0.116102904,,,,,,,,,,, -4700,0.03279877,0.11457239,,,,,,,,,,, -4800,0.0072406605,0.11847854,,,,,,,,,,, -4900,0.00993664,0.11627845,,,,,,,,,,, -5000,0.008296701,0.11821434,,,,,,,,,,, -5100,0.0060408874,0.119850904,,,,,,,,,,, -5200,0.027774801,0.12422663,,,,,,,,,,, -5300,0.037620228,0.11813497,,,,,,,,,,, -5400,0.013191805,0.11860917,,,,,,,,,,, -5500,0.01795246,0.11938642,,,,,,,,,,, -5600,0.019858131,0.12678492,,,,,,,,,,, -5700,0.011351689,0.1173955,,,,,,,,,,, -5755,,,0.1226670890109344,0.1243506332396261,83274637.0,0.1267128283511513,95000000.0,4807.333256721497,5040.979163885117,4807.333256721497,233.3674511909485,0.0937106609344482,0.0 -5800,0.015840111,0.12131603,,,,,,,,,,, -5900,0.007170521,0.12596737,,,,,,,,,,, -6000,0.013529948,0.12422128,,,,,,,,,,, -6100,0.0075196004,0.12437001,,,,,,,,,,, -6200,0.018679123,0.115760915,,,,,,,,,,, -6300,0.008042818,0.12231227,,,,,,,,,,, -6400,0.014198091,0.12760024,,,,,,,,,,, -6500,0.012393793,0.12401425,,,,,,,,,,, -6600,0.0066165016,0.12427472,,,,,,,,,,, -6700,0.006429489,0.13413441,,,,,,,,,,, -6800,0.01766555,0.11932086,,,,,,,,,,, -6900,0.013569157,0.11508713,,,,,,,,,,, -7000,0.0070113647,0.12870851,,,,,,,,,,, -7100,0.0149563905,0.11951274,,,,,,,,,,, -7200,0.009439767,0.1199363,,,,,,,,,,, -7201,,,0.124997078575803,0.1239542023236318,83274637.0,0.1262388908717105,95000000.0,6007.554527759552,6263.496790409088,6007.554527759552,255.59379959106445,0.1180720329284668,0.0 -7300,0.010612552,0.12582852,,,,,,,,,,, -7400,0.021272158,0.13399625,,,,,,,,,,, -7500,0.0074924896,0.12727384,,,,,,,,,,, -7600,0.0067822016,0.11996639,,,,,,,,,,, -7700,0.011045068,0.11808551,,,,,,,,,,, -7800,0.0080609005,0.11921682,,,,,,,,,,, -7900,0.022696882,0.11664951,,,,,,,,,,, -8000,0.014387902,0.12707104,,,,,,,,,,, -8100,0.0072890404,0.12212642,,,,,,,,,,, -8200,0.008439851,0.12702027,,,,,,,,,,, -8300,0.0067999535,0.12197433,,,,,,,,,,, -8400,0.00905231,0.122835845,,,,,,,,,,, -8500,0.009853225,0.12117256,,,,,,,,,,, -8600,0.007189799,0.124038644,,,,,,,,,,, -8653,,,0.1241207126082864,0.1238078600747157,83274637.0,0.1260754088199013,95000000.0,7208.236174821854,7486.954886198044,7208.236174821854,278.294059753418,0.1484014987945556,0.0 -8700,0.012390366,0.13025525,,,,,,,,,,, -8800,0.008760139,0.12253609,,,,,,,,,,, -8900,0.0075650504,0.121103965,,,,,,,,,,, -9000,0.0068234783,0.11931207,,,,,,,,,,, -9100,0.0063867574,0.12154928,,,,,,,,,,, -9184,,,,,,,,7703.111186981201,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 8165a89b7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.36259961128235,0.0,5.777649641036987,1,0,5.777649641036987,1.1383302977796053,95000000,28.140287399291992,1.136454695425693,1.1412394287140513,83274637 -44.922887086868286,0.0194365978240966,1206.3347568511963,1420,0,1206.3347568511963,0.1284445489617598,95000000,1251.3243670463562,0.1234569041140424,0.1260957703123941,83274637 -67.28594326972961,0.0445485115051269,2406.601612329483,2839,0,2406.601612329483,0.1278019597039473,95000000,2474.02592253685,0.1252233332151887,0.1253235501757576,83274637 -89.56600451469421,0.0698604583740234,3606.578542947769,4254,0,3606.578542947769,0.1268322721628289,95000000,3696.3545751571655,0.1229900993734785,0.1244401118738139,83274637 -111.88411116600037,0.0961744785308837,4807.285554647446,5661,0,4807.285554647446,0.1266506056332237,95000000,4919.452340841293,0.1233027539913009,0.1243415355660001,83274637 -134.33966636657715,0.1190066337585449,6007.639580249786,7089,0,6007.639580249786,0.1265344593852796,95000000,6142.330835580826,0.1228763087646766,0.1242025253251779,83274637 -156.77400660514832,0.14187884330749512,7207.912177801132,8502,0,7207.912177801132,0.12635440872738488,95000000,7365.107407331467,0.12026732218153072,0.12393889111721676,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/measurements.csv deleted file mode 100644 index 5dff2dba5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/measurements.csv +++ /dev/null @@ -1,100 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,9.004341,1.1355703,,,,,,,,,,, -1,,,1.136454695425693,1.1412394287140513,83274637.0,1.1383302977796053,95000000.0,5.777649641036987,28.140287399291992,5.777649641036987,22.36259961128235,0.0,0.0 -100,0.2195468,0.14842516,,,,,,,,,,, -200,0.023573576,0.12922555,,,,,,,,,,, -300,0.008325954,0.12937993,,,,,,,,,,, -400,0.0411794,0.12470615,,,,,,,,,,, -500,0.020833697,0.12782061,,,,,,,,,,, -600,0.013026058,0.12351695,,,,,,,,,,, -700,0.017509408,0.11993018,,,,,,,,,,, -800,0.019172125,0.122616716,,,,,,,,,,, -900,0.023660768,0.1255528,,,,,,,,,,, -1000,0.008256648,0.12215307,,,,,,,,,,, -1100,0.02494568,0.12786335,,,,,,,,,,, -1200,0.069961414,0.12710826,,,,,,,,,,, -1300,0.04093263,0.12489389,,,,,,,,,,, -1400,0.00858516,0.12036084,,,,,,,,,,, -1420,,,0.1234569041140424,0.1260957703123941,83274637.0,0.1284445489617598,95000000.0,1206.3347568511963,1251.3243670463562,1206.3347568511963,44.922887086868286,0.0194365978240966,0.0 -1500,0.014007774,0.1259068,,,,,,,,,,, -1600,0.0074059935,0.12434555,,,,,,,,,,, -1700,0.023707965,0.122452036,,,,,,,,,,, -1800,0.011978514,0.12264089,,,,,,,,,,, -1900,0.010092269,0.12329286,,,,,,,,,,, -2000,0.018841522,0.12177594,,,,,,,,,,, -2100,0.009631566,0.1253958,,,,,,,,,,, -2200,0.009466187,0.1245709,,,,,,,,,,, -2300,0.012612258,0.12185256,,,,,,,,,,, -2400,0.016221603,0.12543194,,,,,,,,,,, -2500,0.013394631,0.12128663,,,,,,,,,,, -2600,0.008456676,0.12731266,,,,,,,,,,, -2700,0.05017735,0.12494735,,,,,,,,,,, -2800,0.007918957,0.121055804,,,,,,,,,,, -2839,,,0.1252233332151887,0.1253235501757576,83274637.0,0.1278019597039473,95000000.0,2406.601612329483,2474.02592253685,2406.601612329483,67.28594326972961,0.0445485115051269,0.0 -2900,0.025157863,0.1268828,,,,,,,,,,, -3000,0.020252218,0.12631463,,,,,,,,,,, -3100,0.008045888,0.12241657,,,,,,,,,,, -3200,0.022065837,0.1237916,,,,,,,,,,, -3300,0.013211815,0.122445226,,,,,,,,,,, -3400,0.006591093,0.12134187,,,,,,,,,,, -3500,0.017819544,0.12689398,,,,,,,,,,, -3600,0.01856412,0.12022172,,,,,,,,,,, -3700,0.006771855,0.11965509,,,,,,,,,,, -3800,0.023850325,0.12646161,,,,,,,,,,, -3900,0.0166568,0.116867475,,,,,,,,,,, -4000,0.019986942,0.117117666,,,,,,,,,,, -4100,0.018358916,0.12413156,,,,,,,,,,, -4200,0.005855414,0.12258653,,,,,,,,,,, -4254,,,0.1229900993734785,0.1244401118738139,83274637.0,0.1268322721628289,95000000.0,3606.578542947769,3696.3545751571655,3606.578542947769,89.56600451469421,0.0698604583740234,0.0 -4300,0.0066447346,0.1228343,,,,,,,,,,, -4400,0.011771175,0.121072724,,,,,,,,,,, -4500,0.01876903,0.12540253,,,,,,,,,,, -4600,0.02065401,0.12733772,,,,,,,,,,, -4700,0.01762381,0.1370727,,,,,,,,,,, -4800,0.007966693,0.13400961,,,,,,,,,,, -4900,0.0064255134,0.120518856,,,,,,,,,,, -5000,0.010503186,0.13262567,,,,,,,,,,, -5100,0.027004369,0.12771668,,,,,,,,,,, -5200,0.0084518725,0.12547094,,,,,,,,,,, -5300,0.011434232,0.12760395,,,,,,,,,,, -5400,0.0062366044,0.11931784,,,,,,,,,,, -5500,0.010800204,0.12088795,,,,,,,,,,, -5600,0.0058805044,0.12211694,,,,,,,,,,, -5661,,,0.1233027539913009,0.1243415355660001,83274637.0,0.1266506056332237,95000000.0,4807.285554647446,4919.452340841293,4807.285554647446,111.88411116600037,0.0961744785308837,0.0 -5700,0.01090348,0.118979156,,,,,,,,,,, -5800,0.008188949,0.11734389,,,,,,,,,,, -5900,0.023772221,0.14013359,,,,,,,,,,, -6000,0.015817512,0.122405745,,,,,,,,,,, -6100,0.012542065,0.11993846,,,,,,,,,,, -6200,0.016172329,0.12270611,,,,,,,,,,, -6300,0.011993517,0.12993109,,,,,,,,,,, -6400,0.006008287,0.12234146,,,,,,,,,,, -6500,0.012413178,0.1199281,,,,,,,,,,, -6600,0.017183267,0.11853789,,,,,,,,,,, -6700,0.007481068,0.124759614,,,,,,,,,,, -6800,0.011720652,0.1348174,,,,,,,,,,, -6900,0.0082824705,0.119925976,,,,,,,,,,, -7000,0.010304082,0.12507293,,,,,,,,,,, -7089,,,0.1228763087646766,0.1242025253251779,83274637.0,0.1265344593852796,95000000.0,6007.639580249786,6142.330835580826,6007.639580249786,134.33966636657715,0.1190066337585449,0.0 -7100,0.015812052,0.12466569,,,,,,,,,,, -7200,0.009570475,0.11468696,,,,,,,,,,, -7300,0.013501475,0.11700627,,,,,,,,,,, -7400,0.008453471,0.12060854,,,,,,,,,,, -7500,0.0096919015,0.1191439,,,,,,,,,,, -7600,0.012694883,0.12025249,,,,,,,,,,, -7700,0.009797173,0.12320998,,,,,,,,,,, -7800,0.006456026,0.1266503,,,,,,,,,,, -7900,0.013408278,0.11911102,,,,,,,,,,, -8000,0.012345686,0.11923166,,,,,,,,,,, -8100,0.011385912,0.11976627,,,,,,,,,,, -8200,0.0072099757,0.124963544,,,,,,,,,,, -8300,0.013417654,0.115980834,,,,,,,,,,, -8400,0.010523681,0.13886325,,,,,,,,,,, -8500,0.018550169,0.12516,,,,,,,,,,, -8502,,,0.1202673221815307,0.1239388911172167,83274637.0,0.1263544087273848,95000000.0,7207.912177801132,7365.107407331467,7207.912177801132,156.77400660514832,0.1418788433074951,0.0 -8600,0.029755535,0.12626682,,,,,,,,,,, -8700,0.009580715,0.12110446,,,,,,,,,,, -8800,0.006268214,0.11674887,,,,,,,,,,, -8900,0.0077448403,0.12248897,,,,,,,,,,, -9000,0.009515113,0.12364941,,,,,,,,,,, -9016,,,,,,,,7703.492475032806,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 258ce2c01..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -22.5659658908844,0.0,9.183252811431885,1,0,9.183252811431885,1.1383302977796053,95000000,31.74926257133484,1.1377923391150229,1.1412394287140513,83274637 -44.02743005752564,0.0185277462005615,1209.4919574260712,1553,0,1209.4919574260712,0.1286738535259046,95000000,1253.5894060134888,0.1237834237469067,0.1264740821536094,83274637 -65.60438871383667,0.0516526699066162,2409.6908428668976,3102,0,2409.6908428668976,0.1286282376130756,95000000,2475.4475286006927,0.1235720918238537,0.1262985270315558,83274637 -87.3217556476593,0.0744392871856689,3609.815655231476,4651,0,3609.815655231476,0.1278445281044408,95000000,3697.361325979233,0.12652800419608,0.1256186724844264,83274637 -109.03082489967346,0.1045811176300048,4810.387387752533,6206,0,4810.387387752533,0.1277123699013158,95000000,4919.721008300781,0.1242330909337637,0.1254181505581045,83274637 -130.81375217437744,0.1347107887268066,6010.884887456894,7754,0,6010.884887456894,0.1278925601151316,95000000,6142.080229520798,0.1240850572222433,0.1256028640058797,83274637 -152.58654189109802,0.16337895393371582,7211.590596199036,9298,0,7211.590596199036,0.1276181388877467,95000000,7364.635870933533,0.12417525559101465,0.1254012833650959,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/measurements.csv deleted file mode 100644 index 6e8daa448..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/measurements.csv +++ /dev/null @@ -1,109 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.763918,1.1391962,,,,,,,,,,, -1,,,1.1377923391150229,1.1412394287140513,83274637.0,1.1383302977796053,95000000.0,9.183252811431885,31.74926257133484,9.183252811431885,22.5659658908844,0.0,0.0 -100,0.19830255,0.1337387,,,,,,,,,,, -200,0.016985236,0.12486619,,,,,,,,,,, -300,0.07834911,0.12710041,,,,,,,,,,, -400,0.025604676,0.12618966,,,,,,,,,,, -500,0.022816028,0.12298228,,,,,,,,,,, -600,0.010240288,0.1208424,,,,,,,,,,, -700,0.016066676,0.12828052,,,,,,,,,,, -800,0.020151164,0.1275284,,,,,,,,,,, -900,0.024337199,0.12483462,,,,,,,,,,, -1000,0.026390001,0.12451852,,,,,,,,,,, -1100,0.04945678,0.12914066,,,,,,,,,,, -1200,0.017253544,0.1270198,,,,,,,,,,, -1300,0.023387881,0.12443112,,,,,,,,,,, -1400,0.03321936,0.13824785,,,,,,,,,,, -1500,0.0064427233,0.13292089,,,,,,,,,,, -1553,,,0.1237834237469067,0.1264740821536094,83274637.0,0.1286738535259046,95000000.0,1209.4919574260712,1253.5894060134888,1209.4919574260712,44.02743005752564,0.0185277462005615,0.0 -1600,0.023811081,0.12790519,,,,,,,,,,, -1700,0.03825195,0.12031356,,,,,,,,,,, -1800,0.038240496,0.12879834,,,,,,,,,,, -1900,0.029522432,0.12503941,,,,,,,,,,, -2000,0.043613683,0.12648207,,,,,,,,,,, -2100,0.06493681,0.12199803,,,,,,,,,,, -2200,0.04524204,0.12401636,,,,,,,,,,, -2300,0.028200705,0.121770136,,,,,,,,,,, -2400,0.04447812,0.11692427,,,,,,,,,,, -2500,0.045758054,0.13541695,,,,,,,,,,, -2600,0.07071561,0.12731346,,,,,,,,,,, -2700,0.025534075,0.12954465,,,,,,,,,,, -2800,0.044604577,0.13275628,,,,,,,,,,, -2900,0.03768813,0.12518254,,,,,,,,,,, -3000,0.023652477,0.12614803,,,,,,,,,,, -3100,0.017927222,0.12366004,,,,,,,,,,, -3102,,,0.1235720918238537,0.1262985270315558,83274637.0,0.1286282376130756,95000000.0,2409.6908428668976,2475.4475286006927,2409.6908428668976,65.60438871383667,0.0516526699066162,0.0 -3200,0.034077983,0.1305116,,,,,,,,,,, -3300,0.0048245513,0.12677793,,,,,,,,,,, -3400,0.00995345,0.121668905,,,,,,,,,,, -3500,0.03314028,0.13287833,,,,,,,,,,, -3600,0.013366056,0.12118204,,,,,,,,,,, -3700,0.043387406,0.1246627,,,,,,,,,,, -3800,0.011837643,0.113343984,,,,,,,,,,, -3900,0.029817507,0.12574758,,,,,,,,,,, -4000,0.01671156,0.12180291,,,,,,,,,,, -4100,0.028413849,0.13612133,,,,,,,,,,, -4200,0.039853748,0.12043022,,,,,,,,,,, -4300,0.032695197,0.12863006,,,,,,,,,,, -4400,0.012523953,0.12443,,,,,,,,,,, -4500,0.029920865,0.1252161,,,,,,,,,,, -4600,0.0052468874,0.1364731,,,,,,,,,,, -4651,,,0.12652800419608,0.1256186724844264,83274637.0,0.1278445281044408,95000000.0,3609.815655231476,3697.361325979233,3609.815655231476,87.3217556476593,0.0744392871856689,0.0 -4700,0.008438607,0.12334362,,,,,,,,,,, -4800,0.01179665,0.12854259,,,,,,,,,,, -4900,0.044395536,0.12989402,,,,,,,,,,, -5000,0.028644659,0.11893052,,,,,,,,,,, -5100,0.027208906,0.13276115,,,,,,,,,,, -5200,0.023622215,0.1258018,,,,,,,,,,, -5300,0.016072892,0.12410818,,,,,,,,,,, -5400,0.012244789,0.13096038,,,,,,,,,,, -5500,0.022370858,0.12526558,,,,,,,,,,, -5600,0.020041905,0.12589602,,,,,,,,,,, -5700,0.013850278,0.13422054,,,,,,,,,,, -5800,0.0039385427,0.119611055,,,,,,,,,,, -5900,0.026738789,0.124559306,,,,,,,,,,, -6000,0.010244146,0.11666626,,,,,,,,,,, -6100,0.02368105,0.1242594,,,,,,,,,,, -6200,0.025635734,0.122609586,,,,,,,,,,, -6206,,,0.1242330909337637,0.1254181505581045,83274637.0,0.1277123699013158,95000000.0,4810.387387752533,4919.721008300781,4810.387387752533,109.03082489967346,0.1045811176300048,0.0 -6300,0.022283131,0.11926788,,,,,,,,,,, -6400,0.026061008,0.12367189,,,,,,,,,,, -6500,0.032404434,0.12770139,,,,,,,,,,, -6600,0.017177846,0.13770564,,,,,,,,,,, -6700,0.010503659,0.1250424,,,,,,,,,,, -6800,0.028449427,0.12175281,,,,,,,,,,, -6900,0.030188175,0.118504025,,,,,,,,,,, -7000,0.005900583,0.12135411,,,,,,,,,,, -7100,0.013872769,0.12175412,,,,,,,,,,, -7200,0.020152014,0.13665913,,,,,,,,,,, -7300,0.007852465,0.12787865,,,,,,,,,,, -7400,0.017541753,0.1302454,,,,,,,,,,, -7500,0.014518543,0.12764856,,,,,,,,,,, -7600,0.022455493,0.1293567,,,,,,,,,,, -7700,0.015124062,0.12716654,,,,,,,,,,, -7754,,,0.1240850572222433,0.1256028640058797,83274637.0,0.1278925601151316,95000000.0,6010.884887456894,6142.080229520798,6010.884887456894,130.81375217437744,0.1347107887268066,0.0 -7800,0.008033981,0.124090105,,,,,,,,,,, -7900,0.009362594,0.12154714,,,,,,,,,,, -8000,0.016628982,0.12018354,,,,,,,,,,, -8100,0.00807453,0.1316793,,,,,,,,,,, -8200,0.013170577,0.13622367,,,,,,,,,,, -8300,0.00727621,0.124016464,,,,,,,,,,, -8400,0.0064783595,0.12410295,,,,,,,,,,, -8500,0.010508219,0.11868785,,,,,,,,,,, -8600,0.008760848,0.12177724,,,,,,,,,,, -8700,0.0122955395,0.1275083,,,,,,,,,,, -8800,0.010535799,0.1285533,,,,,,,,,,, -8900,0.008096058,0.124013916,,,,,,,,,,, -9000,0.007316603,0.12576094,,,,,,,,,,, -9100,0.016246716,0.12384556,,,,,,,,,,, -9200,0.008161498,0.1220302,,,,,,,,,,, -9298,,,0.1241752555910146,0.1254012833650959,83274637.0,0.1276181388877467,95000000.0,7211.590596199036,7364.635870933533,7211.590596199036,152.58654189109802,0.1633789539337158,0.0 -9300,0.010514734,0.1194999,,,,,,,,,,, -9400,0.008689071,0.11493563,,,,,,,,,,, -9500,0.011726927,0.12052154,,,,,,,,,,, -9600,0.010867343,0.12623936,,,,,,,,,,, -9700,0.010039786,0.122174226,,,,,,,,,,, -9800,0.010349579,0.12325975,,,,,,,,,,, -9900,0.006870278,0.12077733,,,,,,,,,,, -9935,,,,,,,,7703.547810316086,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 51a8ec596..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,8 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -21.726317644119263,0.0,5.916582107543945,1,0,5.916582107543945,1.1383302977796053,95000000,27.642942667007446,1.137724521025172,1.1412394287140513,83274637 -43.38259959220886,0.0167696475982666,1206.0222396850586,1435,0,1206.0222396850586,0.1285183155016447,95000000,1249.4685735702517,0.1271230274776242,0.1260156512863274,83274637 -65.00065398216248,0.0475471019744873,2406.620073080063,2861,0,2406.620073080063,0.1274107140419408,95000000,2471.7603373527527,0.1212015245226944,0.124953155848791,83274637 -87.16549277305603,0.0716829299926757,3606.948037385941,4300,0,3606.948037385941,0.1269433251644736,95000000,3694.322910785675,0.1248393994551034,0.1245742178790584,83274637 -108.96741604804993,0.0965254306793212,4807.114329099655,5742,0,4807.114329099655,0.1265665383018092,95000000,4916.36280465126,0.1219415247159184,0.1242236964246688,83274637 -130.712637424469,0.1242105960845947,6007.315401077271,7188,0,6007.315401077271,0.1262289715049342,95000000,6138.382369041443,0.1219883938349268,0.123905986112254,83274637 -152.55200576782227,0.149766206741333,7207.853322982788,8613,0,7207.853322982788,0.1259711364925987,95000000,7360.8301429748535,0.12188524190547331,0.12371287828309596,83274637 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/measurements.csv deleted file mode 100644 index df174dcb5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/measurements.csv +++ /dev/null @@ -1,101 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,8.671969,1.1413251,,,,,,,,,,, -1,,,1.137724521025172,1.1412394287140513,83274637.0,1.1383302977796053,95000000.0,5.916582107543945,27.642942667007446,5.916582107543945,21.726317644119263,0.0,0.0 -100,0.13462387,0.1362634,,,,,,,,,,, -200,0.059699122,0.13274145,,,,,,,,,,, -300,0.006679597,0.11512595,,,,,,,,,,, -400,0.0706696,0.12462514,,,,,,,,,,, -500,0.03192076,0.12389795,,,,,,,,,,, -600,0.043816697,0.12412751,,,,,,,,,,, -700,0.030718103,0.12410068,,,,,,,,,,, -800,0.04061881,0.118558116,,,,,,,,,,, -900,0.019394653,0.12702745,,,,,,,,,,, -1000,0.010277657,0.11911905,,,,,,,,,,, -1100,0.057754513,0.12513629,,,,,,,,,,, -1200,0.08134076,0.13263501,,,,,,,,,,, -1300,0.0125672035,0.117888935,,,,,,,,,,, -1400,0.006504627,0.11849068,,,,,,,,,,, -1435,,,0.1271230274776242,0.1260156512863274,83274637.0,0.1285183155016447,95000000.0,1206.0222396850586,1249.4685735702517,1206.0222396850586,43.38259959220886,0.0167696475982666,0.0 -1500,0.04974661,0.12322987,,,,,,,,,,, -1600,0.08361404,0.12624812,,,,,,,,,,, -1700,0.021132987,0.12742752,,,,,,,,,,, -1800,0.013470068,0.12925835,,,,,,,,,,, -1900,0.005722752,0.124191254,,,,,,,,,,, -2000,0.0064390735,0.13315055,,,,,,,,,,, -2100,0.038050007,0.13641626,,,,,,,,,,, -2200,0.03665783,0.13503213,,,,,,,,,,, -2300,0.027556209,0.11932279,,,,,,,,,,, -2400,0.04256562,0.12071201,,,,,,,,,,, -2500,0.004875055,0.118422806,,,,,,,,,,, -2600,0.026754027,0.121923216,,,,,,,,,,, -2700,0.013047592,0.12080453,,,,,,,,,,, -2800,0.0072613414,0.122929364,,,,,,,,,,, -2861,,,0.1212015245226944,0.124953155848791,83274637.0,0.1274107140419408,95000000.0,2406.620073080063,2471.7603373527527,2406.620073080063,65.00065398216248,0.0475471019744873,0.0 -2900,0.016805857,0.124512464,,,,,,,,,,, -3000,0.008125109,0.121179834,,,,,,,,,,, -3100,0.016649855,0.13074899,,,,,,,,,,, -3200,0.027017772,0.12254907,,,,,,,,,,, -3300,0.028269196,0.13463037,,,,,,,,,,, -3400,0.0068470114,0.13177039,,,,,,,,,,, -3500,0.0071451045,0.124129094,,,,,,,,,,, -3600,0.018497063,0.12372634,,,,,,,,,,, -3700,0.05595312,0.13486427,,,,,,,,,,, -3800,0.0047906553,0.12319681,,,,,,,,,,, -3900,0.050147813,0.13150844,,,,,,,,,,, -4000,0.017117035,0.11632915,,,,,,,,,,, -4100,0.027553564,0.12186057,,,,,,,,,,, -4200,0.024093512,0.12603953,,,,,,,,,,, -4300,,,0.1248393994551034,0.1245742178790584,83274637.0,0.1269433251644736,95000000.0,3606.948037385941,3694.322910785675,3606.948037385941,87.16549277305603,0.0716829299926757,0.0 -4300,0.01691768,0.12224764,,,,,,,,,,, -4400,0.022380123,0.11965147,,,,,,,,,,, -4500,0.023158545,0.12395055,,,,,,,,,,, -4600,0.01982721,0.124709606,,,,,,,,,,, -4700,0.015328315,0.13292359,,,,,,,,,,, -4800,0.0058420035,0.11989681,,,,,,,,,,, -4900,0.0125515815,0.1200019,,,,,,,,,,, -5000,0.0122667495,0.11410335,,,,,,,,,,, -5100,0.01862362,0.13179469,,,,,,,,,,, -5200,0.021843841,0.12794648,,,,,,,,,,, -5300,0.019687844,0.13191879,,,,,,,,,,, -5400,0.01071819,0.13088776,,,,,,,,,,, -5500,0.02593827,0.116496235,,,,,,,,,,, -5600,0.008005957,0.12677659,,,,,,,,,,, -5700,0.007883259,0.12338661,,,,,,,,,,, -5742,,,0.1219415247159184,0.1242236964246688,83274637.0,0.1265665383018092,95000000.0,4807.114329099655,4916.36280465126,4807.114329099655,108.96741604804993,0.0965254306793212,0.0 -5800,0.0076325736,0.12930506,,,,,,,,,,, -5900,0.0059721423,0.1251838,,,,,,,,,,, -6000,0.011473841,0.12960564,,,,,,,,,,, -6100,0.010943541,0.12219045,,,,,,,,,,, -6200,0.018864471,0.120760776,,,,,,,,,,, -6300,0.02994056,0.12145783,,,,,,,,,,, -6400,0.0064306073,0.12025509,,,,,,,,,,, -6500,0.005746634,0.1225956,,,,,,,,,,, -6600,0.011634675,0.116622984,,,,,,,,,,, -6700,0.011714685,0.12989667,,,,,,,,,,, -6800,0.005627184,0.11932426,,,,,,,,,,, -6900,0.0069993297,0.11648046,,,,,,,,,,, -7000,0.006960346,0.11746988,,,,,,,,,,, -7100,0.005308711,0.12166672,,,,,,,,,,, -7188,,,0.1219883938349268,0.123905986112254,83274637.0,0.1262289715049342,95000000.0,6007.315401077271,6138.382369041443,6007.315401077271,130.712637424469,0.1242105960845947,0.0 -7200,0.0053221453,0.13411845,,,,,,,,,,, -7300,0.008481452,0.12936312,,,,,,,,,,, -7400,0.006203591,0.11430548,,,,,,,,,,, -7500,0.015377502,0.1176561,,,,,,,,,,, -7600,0.00821529,0.12890367,,,,,,,,,,, -7700,0.00803069,0.121003956,,,,,,,,,,, -7800,0.0078036413,0.12135284,,,,,,,,,,, -7900,0.008822322,0.12371291,,,,,,,,,,, -8000,0.009999202,0.12414397,,,,,,,,,,, -8100,0.012582995,0.122517556,,,,,,,,,,, -8200,0.0070723277,0.12364673,,,,,,,,,,, -8300,0.0058145714,0.123066284,,,,,,,,,,, -8400,0.009978924,0.12054944,,,,,,,,,,, -8500,0.013111201,0.11862197,,,,,,,,,,, -8600,0.009075691,0.11733858,,,,,,,,,,, -8613,,,0.1218852419054733,0.1237128782830959,83274637.0,0.1259711364925987,95000000.0,7207.853322982788,7360.830142974853,7207.853322982788,152.5520057678223,0.149766206741333,0.0 -8700,0.006510841,0.11861403,,,,,,,,,,, -8800,0.008120617,0.12258181,,,,,,,,,,, -8900,0.010429396,0.13334247,,,,,,,,,,, -9000,0.008148508,0.114138246,,,,,,,,,,, -9100,0.0065826736,0.119948946,,,,,,,,,,, -9138,,,,,,,,7703.856807947159,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 231c4065c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -208.3280806541443,0.0,71.32777667045593,1,0,71.32777667045593,1.0323986403937448,3581,0.2128455943150481,279.655978679657,1.0327176366533553,0.1975462777273995,1.0370871866426914,3554,0.1897597777152504 -212.7399332523346,0.0284140110015869,151.2991268634796,333,0,151.2991268634796,0.3496438723907428,3581,0.6765431241927883,364.0809512138367,0.3291159697941371,0.6802842276436942,0.3486846082319217,3554,0.6573601735060847 -216.74945425987244,0.0638709068298339,231.5786828994751,564,0,231.5786828994751,0.32667249622574,3581,0.700672003717537,448.4144470691681,0.3065744127546038,0.7045314652579171,0.3248303037180817,3554,0.6825660100195202 -220.76053142547607,0.1021690368652343,311.58346366882324,795,0,311.58346366882324,0.3133649913933782,3581,0.7146922614580424,532.4777381420135,0.293259859085083,0.7191570145743233,0.3110030415223164,3554,0.6973806884628235 -224.7690806388855,0.1421968936920166,391.5698597431183,1054,0,391.5698597431183,0.3066303644124372,3581,0.7205648627303826,616.5227851867676,0.2865138053894043,0.7251708166939872,0.304333276271015,3554,0.7035528299978897 -228.7809481620789,0.1752948760986328,471.7077603340149,1398,0,471.7077603340149,0.2996619596590512,3581,0.7265395244650587,700.7194221019745,0.2799362965992519,0.731006281716483,0.2978182112694323,3554,0.7094610468618107 -232.7886414527893,0.2010126113891601,551.8113622665405,1744,0,551.8113622665405,0.3003049678424322,3581,0.7262638180457623,784.8702282905579,0.279796940939767,0.7316139766148159,0.2986845189289357,3554,0.7091029418876618 -236.79975581169128,0.2301688194274902,631.8748698234558,2086,0,631.8748698234558,0.2955346809223157,3581,0.7289850213147515,868.9877972602844,0.2761511291776384,0.7330848830086845,0.2939654743664708,3554,0.7116767226408625 -240.8074486255645,0.2555031776428222,711.8415286540985,2429,0,711.8415286540985,0.2950749997818346,3581,0.7307537966219282,953.00092792511,0.2751845802579607,0.7360131399972099,0.2933108491312605,3554,0.713600927212296 -244.8169682025909,0.2803800106048584,792.1041221618652,2776,0,792.1041221618652,0.2931648601669401,3581,0.7335982632949944,1037.3116459846497,0.2733841112681797,0.7387440545218331,0.2914826796171567,3554,0.7164845206325618 -248.82698583602905,0.3063969612121582,872.119637966156,3120,0,872.119637966156,0.2928364872765987,3581,0.7353520397366309,1121.3769915103912,0.2726921013423374,0.7408569880894252,0.2912389855070871,3554,0.7182428964283202 -252.83643865585327,0.3326916694641113,952.2769184112548,3465,0,952.2769184112548,0.2922635987896188,3581,0.7351641448574071,1205.583834886551,0.2719330617359706,0.7407792636326381,0.2906374268539849,3554,0.7179621415790307 -256.84815645217896,0.3579692840576172,1032.3964014053345,3811,0,1032.3964014053345,0.2929896120584334,3581,0.7342503048860305,1289.7542309761047,0.2725422552653721,0.7399152347019741,0.2913451186713034,3554,0.7170441068822102 -260.85873436927795,0.3834872245788574,1112.4486014842987,4156,0,1112.4486014842987,0.2919820632657602,3581,0.7362064296591385,1373.8563287258148,0.2715973513466971,0.7418276923043388,0.2905141543859032,3554,0.7188788023837578 -264.8728246688843,0.4092543125152588,1192.4837381839752,4497,0,1192.4837381839752,0.2915290293497801,3581,0.7369858934306059,1457.9447252750397,0.2710503510066441,0.7427105222429548,0.2898978607400992,3554,0.7198047369601154 -268.8852117061615,0.4353601932525635,1272.489012479782,4844,0,1272.489012479782,0.291571639763247,3581,0.7367711369467328,1542.002141237259,0.2719627789088658,0.7410882541111538,0.290022987962507,3554,0.719926257715778 -272.89883518218994,0.461669921875,1352.4867494106293,5191,0,1352.4867494106293,0.2904303965372801,3581,0.7386012030726403,1626.053575515747,0.2707222359521048,0.7438113348824638,0.2888372847385868,3554,0.7214453015088632 -276.9070551395416,0.4866147041320801,1432.6245748996737,5532,0,1432.6245748996737,0.2914324571086987,3581,0.7371778107328609,1710.2379422187803,0.2709032297134399,0.7429601124354771,0.2898524879537141,3554,0.7198999476821891 -280.91675758361816,0.5123147964477539,1512.714156150818,5878,0,1512.714156150818,0.2924137237983454,3581,0.7340223903064786,1794.376507282257,0.2728471074785505,0.7389003208705357,0.2909343249353721,3554,0.71665186068954 -284.9300625324249,0.5374350547790527,1592.7204446792605,6220,0,1592.7204446792605,0.2902518418606883,3581,0.7380594713199874,1878.43439412117,0.2703012909208025,0.7431642668587821,0.2887784134623839,3554,0.7207949009918402 -288.945408821106,0.563525915145874,1672.7735142707825,6563,0,1672.7735142707825,0.2902247075493926,3581,0.7378159442849413,1962.5421574115755,0.2698183059692383,0.7439885820661273,0.2887259651317178,3554,0.720698522461487 -292.95811128616333,0.5912113189697266,1752.762172460556,6909,0,1752.762172460556,0.2902472058477031,3581,0.7378185349980801,2046.5846438407896,0.2705033676964896,0.7427389962332589,0.2888722159450619,3554,0.7205512412290729 -296.9700815677643,0.6176929473876953,1832.948479890824,7254,0,1832.948479890824,0.2919200565920832,3581,0.7345816434611491,2130.8227894306183,0.2724441460200718,0.7382341793605259,0.2904418189671848,3554,0.7179365871860931 -300.9809060096741,0.643923282623291,1912.942586898804,7598,0,1912.942586898804,0.2890893956929803,3581,0.7394905676225216,2214.8672075271606,0.2690965959003993,0.744854313986642,0.2875853599157639,3554,0.722372472588105 -304.9908983707428,0.6710951328277588,1992.9886183738708,7945,0,1992.9886183738708,0.2891729461917062,3581,0.7386821969465582,2298.9635775089264,0.2691159078053066,0.7442718233380999,0.2878344808996025,3554,0.7212834570202589 -309.0040822029114,0.6984851360321045,2072.9793784618378,8291,0,2072.9793784618378,0.2895576670928163,3581,0.7391878632452528,2383.0082693099976,0.2695486886160714,0.7449304035731724,0.2882051226391038,3554,0.7218863895654544 -313.0164098739624,0.7257647514343262,2152.94308423996,8634,0,2152.94308423996,0.2887361042328434,3581,0.7405721221813041,2467.0250334739685,0.2686595065253122,0.746199539729527,0.2872973577856816,3554,0.7233574158298748 -317.02690291404724,0.753577470779419,2232.981694459915,8978,0,2232.981694459915,0.2885050194385297,3581,0.7401969460128107,2551.1151769161224,0.2681561027254377,0.7457756996154785,0.2870447505506559,3554,0.7230410771753658 -321.0403423309326,0.7809731960296631,2313.036667108536,9322,0,2313.036667108536,0.2883283737084613,3581,0.7405238531049287,2635.224442481994,0.2680826868329729,0.7460549899509975,0.286880192625167,3554,0.7233535002374085 -325.05160641670227,0.8076949119567871,2393.2060899734497,9668,0,2393.2060899734497,0.2886850740016755,3581,0.7397892495767593,2719.44512963295,0.268416234425136,0.7457577160426548,0.2872653632735562,3554,0.7224750336328785 -329.06395959854126,0.8351755142211914,2473.2667071819305,10012,0,2473.2667071819305,0.2882206204948862,3581,0.7403182322937029,2803.558928966522,0.2677499226161411,0.7463709967476981,0.2868000088478651,3554,0.7230605177484877 -333.07476782798767,0.8613989353179932,2553.332664489746,10358,0,2553.332664489746,0.2883491335019024,3581,0.7408877119476054,2887.675325155258,0.2682306596211025,0.7464085987636021,0.2870103173800647,3554,0.7235777881216587 -337.0884153842926,0.8884561061859131,2633.5612363815308,10703,0,2633.5612363815308,0.2885392782109571,3581,0.7402850302595294,2971.9580538272858,0.268431510244097,0.7456763812473842,0.2871879616277434,3554,0.7229905866409327 -341.10114550590515,0.9152963161468506,2713.557804107666,11047,0,2713.557804107666,0.289556201294593,3581,0.7390198759512008,3056.007774591446,0.2691765342439924,0.7449214799063546,0.2880697255732977,3554,0.721823053139948 -345.11279916763306,0.943565845489502,2793.704563856125,11393,0,2793.704563856125,0.2887941566601508,3581,0.7398553809384599,3140.207786083221,0.2686616012028285,0.7457105772835868,0.2873810793351241,3554,0.7226912842483821 -349.1259129047394,0.970717191696167,2873.7954342365265,11739,0,2873.7954342365265,0.2888523795291119,3581,0.7401566536058364,3224.352249622345,0.2691163335527692,0.7449744769505092,0.2874639250283571,3554,0.7231983878200618 -353.1361691951752,0.9994661808013916,2953.9432995319366,12081,0,2953.9432995319366,0.2884994289522828,3581,0.7412144826864004,3308.552498102188,0.2677397727966308,0.7475405420575824,0.2871470368170107,3554,0.7239450294782288 -357.14885449409485,1.0281689167022705,3034.0391058921814,12428,0,3034.0391058921814,0.2879170979976787,3581,0.7423336707623569,3392.7031519412994,0.2676975556782314,0.7480308668954032,0.2865567784263769,3554,0.7251667630222988 -361.1650941371918,1.0573818683624268,3114.203953027725,12772,0,3114.203953027725,0.2880633369366971,3581,0.740602392619031,3476.926920652389,0.2679569721221924,0.7462588718959263,0.2866768737689927,3554,0.7233728721159257 -365.1799545288086,1.085845947265625,3194.259214401245,13113,0,3194.259214401245,0.288360621269373,3581,0.7415505936278274,3561.0386533737183,0.2680355651038034,0.7472739900861468,0.286952081528955,3554,0.724447667900605 -369.19701194763184,1.113793134689331,3274.239005088806,13458,0,3274.239005088806,0.2886832332318137,3581,0.7403828637688494,3645.07679104805,0.2685546534402029,0.7458792413984027,0.2873326496388312,3554,0.7230923233504502 -373.212171792984,1.1418728828430176,3354.394580364228,13805,0,3354.394580364228,0.2884788736888264,3581,0.7408170127495811,3729.289011478424,0.2679947103772844,0.7467846189226423,0.2870484772329593,3554,0.7235428912624859 -377.22587037086487,1.1696898937225342,3434.455368757248,14147,0,3434.455368757248,0.287787903224047,3581,0.7413705390646816,3813.404773712158,0.2674491916384016,0.7473293713160923,0.2863896444532129,3554,0.7241186207442318 -381.2357349395752,1.2031478881835938,3514.48440861702,14488,0,3514.48440861702,0.2880733248176138,3581,0.7404717661355068,3897.490335702896,0.2679217542920794,0.7461852346147809,0.2866953526176491,3554,0.723222568320906 -385.2501049041748,1.2313237190246582,3594.580799818039,14832,0,3594.580799818039,0.2878396493101612,3581,0.7417353523806199,3981.64256811142,0.2674539940697806,0.7476575715201241,0.28643901870032,3554,0.7245439090417487 -389.26374077796936,1.2599167823791504,3674.611796617508,15175,0,3674.611796617508,0.2881857481325048,3581,0.7397276178747207,4065.728991985321,0.2678097827093942,0.7454748834882464,0.2868108969427054,3554,0.7224384881031936 -393.27994561195374,1.2880818843841553,3754.784963130951,15522,0,3754.784963130951,0.28855615193469,3581,0.7403652741901704,4149.960080385208,0.2679358380181448,0.7466159548078265,0.2872402038745955,3554,0.7230707532445836 -397.2954897880554,1.3162243366241455,3834.939851999283,15867,0,3834.939851999283,0.2872456601464325,3581,0.7412455712440659,4234.172024965286,0.2669087648391723,0.7471408843994141,0.285905776831563,3554,0.7238568943004361 -401.3085203170776,1.3447983264923096,3915.069020032882,16213,0,3915.069020032882,0.2872687379463662,3581,0.7419904012714674,4318.356410264969,0.2671341555459158,0.7476457868303571,0.2859542752224606,3554,0.7246648115459693 -405.32308530807495,1.3752338886260986,3995.056948661804,16560,0,3995.056948661804,0.2871150336629084,3581,0.7418862955092851,4402.402938842773,0.2664672647203718,0.7480151993887765,0.2857950583024848,3554,0.7246860381788126 -409.3372468948364,1.4041290283203125,4075.1593718528734,16906,0,4075.1593718528734,0.2871182720543319,3581,0.7420793036381248,4486.561875104904,0.2666489567075457,0.7479713984898159,0.2857232724406039,3554,0.7247964304085186 -413.3470900058746,1.4336450099945068,4155.271572828293,17249,0,4155.271572828293,0.2874048867394757,3581,0.742540450576829,4570.726595163345,0.2669464349746704,0.7486212594168526,0.2860727734155177,3554,0.7252796282577729 -417.3585879802704,1.4628279209136963,4235.25742316246,17592,0,4235.25742316246,0.2877851079809236,3581,0.7413886740566532,4654.7667927742,0.2669038772583008,0.7479096140180316,0.2864979930184299,3554,0.7240776100652434 -421.3698930740357,1.4926958084106443,4315.329802036285,17940,0,4315.329802036285,0.2874236694097319,3581,0.7414530328251536,4738.894271850586,0.267059462411063,0.7473158836364746,0.2861336196616049,3554,0.7240636650604952 -425.3839862346649,1.5210142135620115,4395.332178592682,18286,0,4395.332178592682,0.2873145867512566,3581,0.7421564796189961,4822.952711105347,0.2668248585292271,0.7483020509992327,0.2860413799811832,3554,0.7248403949555079 -429.3969392776489,1.5507056713104248,4475.431641340256,18630,0,4475.431641340256,0.2870992507657602,3581,0.7419122708173346,4907.108276367188,0.2661434582301548,0.7485624722072056,0.2857404632654227,3554,0.7246146644845597 -433.4125940799713,1.5795207023620603,4555.57452249527,18976,0,4555.57452249527,0.2873082463217327,3581,0.741142760838453,4991.308980464935,0.2668733426502773,0.7473161561148507,0.2860027392660558,3554,0.7237860701630205 -437.425502538681,1.6093058586120603,4635.542121648788,19323,0,4635.542121648788,0.2871508945868821,3581,0.7423779855923625,5075.3327651023865,0.2665856054850987,0.748532772064209,0.2857810274294808,3554,0.7251577640290869 -441.4396963119507,1.6389601230621338,4715.623168230057,19666,0,4715.623168230057,0.2870662191732407,3581,0.7424777280482058,5159.4710512161255,0.2659586497715541,0.7492211205618722,0.2857543224019151,3554,0.7252296872801772 -445.4508848190308,1.6683268547058103,4795.6775233745575,20011,0,4795.6775233745575,0.2869267638120462,3581,0.7418216640341385,5243.579426765442,0.2665020397731236,0.747901439666748,0.2856981817362479,3554,0.7244954106508511 -449.4621365070343,1.7000653743743896,4875.658908843994,20354,0,4875.658908843994,0.2872837368119066,3581,0.7420431018308433,5327.617372274399,0.2666335957390921,0.7483769825526646,0.2859574351742754,3554,0.7248445853263928 -453.4756505489349,1.7338788509368896,4955.702982664108,20696,0,4955.702982664108,0.2869335473898701,3581,0.7427288226926836,5411.722235918045,0.2660750320979527,0.7491252762930733,0.2855740162383054,3554,0.7255758393931134 -457.4879927635193,1.763725996017456,5035.683332443237,21044,0,5035.683332443237,0.2868228625798485,3581,0.7424461622539095,5495.758540868759,0.2661369187491281,0.7485916273934501,0.2855152480040095,3554,0.7251855853439786 -461.4994411468506,1.793873310089111,5115.678194522858,21390,0,5115.678194522858,0.2872540458758028,3581,0.7419769704691427,5579.808661937714,0.2669078111648559,0.7477847508021763,0.2859515789592273,3554,0.7248390210634145 -465.5098142623901,1.823737382888794,5195.80691409111,21731,0,5195.80691409111,0.287159791641214,3581,0.7431472910412594,5663.991146087647,0.2665330682482038,0.7492568152291434,0.285770912148943,3554,0.7260055241453293 -469.5228455066681,1.8534579277038568,5275.870712280273,22078,0,5275.870712280273,0.28738392241605,3581,0.7400780459150726,5748.111387014389,0.2668310574122837,0.7461775371006557,0.286123092213439,3554,0.7226040421004502 -473.5297937393189,1.8833134174346924,5356.035150527954,22424,0,5356.035150527954,0.2867656282724797,3581,0.7425572902122313,5832.326321363449,0.2662124293191092,0.7486978939601353,0.2854405769687324,3554,0.7253482541678391 -477.544869184494,1.915790557861328,5436.060939788818,22767,0,5436.060939788818,0.2872492394211637,3581,0.7433064153693103,5916.412758111954,0.2666067055293492,0.749427046094622,0.2857713586638734,3554,0.7262519316922833 -481.5577425956726,1.9451351165771484,5516.128949642181,23114,0,5516.128949642181,0.2866858615784697,3581,0.7426980068416643,6000.536830663681,0.265721321105957,0.7493207114083427,0.2853577312754994,3554,0.7254548681942882 -485.5675754547119,1.9756977558135984,5596.276057720184,23459,0,5596.276057720184,0.2865019891222773,3581,0.7426928254153867,6084.737917661667,0.2658481938498361,0.7489502089364188,0.2852362105198368,3554,0.7253882344277575 -489.5242583751679,2.007225275039673,5676.235087871552,23801,0,5676.235087871552,0.2866005384865435,3581,0.7432441019006563,6168.698595523834,0.2658059937613351,0.7496252059936523,0.2852406241481869,3554,0.7260273003350098 -493.53691935539246,2.037910223007202,5756.289666891098,24146,0,5756.289666891098,0.2864032693163572,3581,0.7430751601333426,6252.809968471527,0.2652804510934012,0.7498554502214704,0.2851335979541098,3554,0.7258368101962578 -497.5527238845825,2.0684311389923096,5836.392609119415,24492,0,5836.392609119415,0.2864860016951445,3581,0.7429374432770176,6336.972864627838,0.2656843491962978,0.7493213926042829,0.2851750036270751,3554,0.7257174189733399 -501.56517338752747,2.099144697189331,5916.42840719223,24837,0,5916.42840719223,0.2864755024892663,3581,0.743295916163432,6421.065683841705,0.2655792576926095,0.7498811313084194,0.2851052957769854,3554,0.7261310978826674 -505.57781767845154,2.129843711853028,5996.445634841919,25180,0,5996.445634841919,0.2863555456532742,3581,0.7432925755070162,6505.139622926712,0.2650695528302874,0.7501208441598075,0.2850197538205191,3554,0.7260279185864519 -509.5925693511963,2.161004543304444,6076.453567028046,25527,0,6076.453567028046,0.2865228852690414,3581,0.7433385265768989,6589.206932544708,0.265564067023141,0.7498316083635602,0.2852243950478334,3554,0.7260694788222777 -513.6058740615845,2.1919939517974854,6156.694766283035,25872,0,6156.694766283035,0.2863676470106988,3581,0.743424156463802,6673.505701780319,0.2655244214194162,0.7498658725193569,0.2850607473258564,3554,0.7262190956712508 -517.6208217144012,2.224311351776123,6236.789935827255,26214,0,6236.789935827255,0.2863786575415387,3581,0.7430811596795588,6757.661597251892,0.264920881816319,0.7502429825919015,0.2850451708242473,3554,0.7258607159186832 -521.6307320594788,2.2558205127716064,6316.855994939804,26559,0,6316.855994939804,0.2862669500815938,3581,0.7435952798842851,6841.782725572586,0.2651557922363281,0.7503128732953753,0.2849239763679568,3554,0.7263939234401379 -525.6441569328308,2.287712812423706,6396.952573060989,26907,0,6396.952573060989,0.2864094733925579,3581,0.743497310021642,6925.938338279724,0.2654227529253278,0.7502010890415737,0.285112319800313,3554,0.7263083986573228 -529.6607377529144,2.31883192062378,6476.943389892578,27249,0,6476.943389892578,0.2863592271929978,3581,0.7435778948355907,7009.990304231644,0.2648429019110543,0.7506561960492816,0.2849978745889315,3554,0.7264343845622889 -533.6739375591278,2.35146427154541,6556.95263504982,27596,0,6556.95263504982,0.2862810285622033,3581,0.7427244593863446,7094.058998584747,0.2651593174253191,0.7495948246547154,0.2850242189698227,3554,0.7254652410795934 -537.6880068778992,2.38265347480774,6636.923789262772,27939,0,6636.923789262772,0.2861486976621404,3581,0.7433330724439752,7178.088750839233,0.265045166015625,0.750136307307652,0.2848596610443338,3554,0.7260886446169809 -541.7042119503021,2.416171073913574,6716.911543369293,28283,0,6716.911543369293,0.2862953115727974,3581,0.7430446851656312,7262.139762163162,0.2648871626172747,0.7500929151262555,0.284950715742825,3554,0.7258392145074212 -545.7160577774048,2.4494388103485107,6796.866130590439,28628,0,6796.866130590439,0.286252019392715,3581,0.7434871516990715,7346.153185129166,0.2649558271680559,0.7504279272896903,0.2849141530394889,3554,0.7262554351171215 -549.7289471626282,2.482147455215454,6876.82315993309,28972,0,6876.82315993309,0.2861457319773631,3581,0.7436901817971586,7430.169394493103,0.2649299076625279,0.7505592618669782,0.284866135510824,3554,0.7264863863780248 -553.7419407367706,2.515422344207764,6956.778971672058,29314,0,6956.778971672058,0.2863870432709089,3581,0.7430681379372033,7514.184953689575,0.2651236738477434,0.7498602867126465,0.2850154947550295,3554,0.7258943762749719 -557.7553491592407,2.5471506118774414,7036.799525976181,29658,0,7036.799525976181,0.2861635942626885,3581,0.7433210051748813,7598.264146327972,0.2647002935409546,0.7504299027579171,0.2848177229881823,3554,0.7261024522325197 -561.7724099159241,2.5802698135375977,7116.821959257126,30003,0,7116.821959257126,0.2862120337807177,3581,0.7437669487180606,7682.350178480148,0.2649395125252859,0.7505853516714913,0.2848703602290113,3554,0.7266091436365715 -565.7843985557556,2.616807460784912,7196.888969659805,30348,0,7196.888969659805,0.2862066137361246,3581,0.7439271638726962,7766.479268789291,0.26484443460192,0.7509135518755231,0.2848333166634426,3554,0.726777857585643 -569.8019735813141,2.649852514266968,7276.879731178284,30693,0,7276.879731178284,0.2861527882618332,3581,0.7436694560920483,7850.534174203873,0.264484030859811,0.7508979524884906,0.2848165551799029,3554,0.726506582591798 -573.8156337738037,2.681325674057007,7356.900419712067,31041,0,7356.900419712067,0.2860615678886833,3581,0.7439709332894093,7934.61389875412,0.2646525587354387,0.7510170936584473,0.284739548528067,3554,0.7267936573447172 -577.8291063308716,2.7132408618927,7436.869582891464,31385,0,7436.869582891464,0.2860809982372242,3581,0.7435804173720678,8018.64168548584,0.2646515199116298,0.7506763594491142,0.2847565504427229,3554,0.726357583994267 -581.842268705368,2.745941162109375,7516.999427556992,31729,0,7516.999427556992,0.2860875091084019,3581,0.7437989235723261,8102.830830574036,0.2642843893596104,0.7512480872017997,0.2847422276176491,3554,0.726590664787915 -585.8551957607269,2.778761148452759,7596.9842693805695,32075,0,7596.9842693805695,0.2860553979008133,3581,0.7436901817971586,8186.874871730804,0.2644829750061035,0.750896862574986,0.2847329194987162,3554,0.7264721665948579 -589.8650617599487,2.811578750610352,7677.118875980377,32419,0,7677.118875980377,0.2860293544161023,3581,0.7439336406555431,8271.065751314163,0.2645196914672851,0.751089368547712,0.2846938666159609,3554,0.7267552570607062 -593.8802864551544,2.8437328338623047,7757.2637186050415,32765,0,7757.2637186050415,0.2860738396877618,3581,0.7439089607040631,8355.271630525589,0.2642244781766619,0.7514265605381557,0.2846990358849623,3554,0.7267815670942952 -597.8902106285095,2.876116275787353,7837.221108436584,33109,0,7837.221108436584,0.2859970045901983,3581,0.7439951360042586,8439.284693956375,0.2643719060080392,0.7512797628130231,0.2846539722242983,3554,0.7268199673783061 -601.9045717716217,2.9088120460510254,7917.2318749427795,33454,0,7917.2318749427795,0.2859811535163886,3581,0.7438003552822187,8523.35584950447,0.2643978595733642,0.7510183198111398,0.2846558956732291,3554,0.726607151493036 -605.9201576709747,2.9417948722839355,7997.312318086624,33796,0,7997.312318086624,0.2859906641606743,3581,0.7438691455337196,8607.498377799988,0.2641600370407104,0.7513384137834821,0.2846327284178039,3554,0.7267140402979038 -609.9347627162933,2.974903583526612,8077.364030599594,34141,0,8077.364030599594,0.2859547691483698,3581,0.744121058298136,8691.611410140991,0.2642370292118617,0.7514914103916713,0.2846216342391495,3554,0.7269534409951814 -613.9488220214844,3.0089025497436523,8157.329024076462,34487,0,8157.329024076462,0.2859628139944324,3581,0.7439071881108629,8775.637843370438,0.2642880848475865,0.7512409346444267,0.2846199512213351,3554,0.7267562187851716 -617.962637424469,3.041879415512085,8237.486559867859,34829,0,8237.486559867859,0.2860020155748219,3581,0.7442693425370008,8859.855407238007,0.2642233712332589,0.751697267804827,0.2846372794353633,3554,0.7271451676368177 -621.9767315387726,3.076141119003296,8317.607504606247,35174,0,8317.607504606247,0.2859003130399591,3581,0.7440546542297891,8944.038051128387,0.2641376086643764,0.7514427730015346,0.2845476158026168,3554,0.7268877689531162 -625.9883260726929,3.109600782394409,8397.610383033752,35519,0,8397.610383033752,0.2858832006979108,3581,0.7441674866046495,9028.099632501602,0.2641528333936419,0.7515405927385602,0.2845413645935917,3554,0.7270060610623593 -629.9989938735962,3.143000602722168,8477.655651330948,35862,0,8477.655651330948,0.2858744911293982,3581,0.7438509423650865,9112.202283859251,0.264157908303397,0.7511764253888812,0.2845379298633582,3554,0.7266679462181697 -634.0129742622375,3.176217794418335,8553.608373880386,36189,0,8553.608373880386,0.2858859788968689,3581,0.7440562222930047,9192.214786529541,0.26414341585976736,0.7514442716326032,0.2845460530003605,3554,0.7268797316843697 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 39179cc66..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.6054354,1.0139427,,,,,,,,,,,,,, -1,,,0.1975462777273995,1.0327176366533553,0.1897597777152504,1.0370871866426914,3554.0,0.2128455943150481,1.0323986403937448,3581.0,71.32777667045593,279.655978679657,71.32777667045593,208.3280806541443,0.0,0.0 -100,0.8439901,0.4717341,,,,,,,,,,,,,, -200,0.19660556,0.36092052,,,,,,,,,,,,,, -300,0.16759156,0.3407624,,,,,,,,,,,,,, -333,,,0.6802842276436942,0.3291159697941371,0.6573601735060847,0.3486846082319217,3554.0,0.6765431241927883,0.3496438723907428,3581.0,151.2991268634796,364.0809512138367,151.2991268634796,212.7399332523346,0.0284140110015869,0.0 -400,0.07975911,0.3675091,,,,,,,,,,,,,, -500,0.08872041,0.24359351,,,,,,,,,,,,,, -564,,,0.7045314652579171,0.3065744127546038,0.6825660100195202,0.3248303037180817,3554.0,0.700672003717537,0.32667249622574,3581.0,231.5786828994751,448.4144470691681,231.5786828994751,216.74945425987244,0.0638709068298339,0.0 -600,0.11685214,0.3673199,,,,,,,,,,,,,, -700,0.19322443,0.32410625,,,,,,,,,,,,,, -795,,,0.7191570145743233,0.293259859085083,0.6973806884628235,0.3110030415223164,3554.0,0.7146922614580424,0.3133649913933782,3581.0,311.58346366882324,532.4777381420135,311.58346366882324,220.76053142547607,0.1021690368652343,0.0 -800,0.17018808,0.25184354,,,,,,,,,,,,,, -900,0.09580772,0.31866455,,,,,,,,,,,,,, -1000,0.11593685,0.26299876,,,,,,,,,,,,,, -1054,,,0.7251708166939872,0.2865138053894043,0.7035528299978897,0.304333276271015,3554.0,0.7205648627303826,0.3066303644124372,3581.0,391.5698597431183,616.5227851867676,391.5698597431183,224.7690806388855,0.1421968936920166,0.0 -1100,0.42259818,0.26593363,,,,,,,,,,,,,, -1200,0.21053404,0.254252,,,,,,,,,,,,,, -1300,0.2007704,0.3659557,,,,,,,,,,,,,, -1398,,,0.731006281716483,0.2799362965992519,0.7094610468618107,0.2978182112694323,3554.0,0.7265395244650587,0.2996619596590512,3581.0,471.7077603340149,700.7194221019745,471.7077603340149,228.7809481620789,0.1752948760986328,0.0 -1400,0.2217138,0.27175143,,,,,,,,,,,,,, -1500,0.25691208,0.3366652,,,,,,,,,,,,,, -1600,0.4118706,0.1931799,,,,,,,,,,,,,, -1700,0.20768581,0.33588028,,,,,,,,,,,,,, -1744,,,0.7316139766148159,0.279796940939767,0.7091029418876618,0.2986845189289357,3554.0,0.7262638180457623,0.3003049678424322,3581.0,551.8113622665405,784.8702282905579,551.8113622665405,232.7886414527893,0.2010126113891601,0.0 -1800,0.1941869,0.27803332,,,,,,,,,,,,,, -1900,0.073492505,0.21703637,,,,,,,,,,,,,, -2000,0.07117865,0.3356461,,,,,,,,,,,,,, -2086,,,0.7330848830086845,0.2761511291776384,0.7116767226408625,0.2939654743664708,3554.0,0.7289850213147515,0.2955346809223157,3581.0,631.8748698234558,868.9877972602844,631.8748698234558,236.79975581169128,0.2301688194274902,0.0 -2100,0.17710307,0.28989267,,,,,,,,,,,,,, -2200,0.77988905,0.29444596,,,,,,,,,,,,,, -2300,0.10385427,0.22358008,,,,,,,,,,,,,, -2400,0.17266253,0.28450483,,,,,,,,,,,,,, -2429,,,0.7360131399972099,0.2751845802579607,0.713600927212296,0.2933108491312605,3554.0,0.7307537966219282,0.2950749997818346,3581.0,711.8415286540985,953.00092792511,711.8415286540985,240.8074486255645,0.2555031776428222,0.0 -2500,0.09749519,0.2578752,,,,,,,,,,,,,, -2600,0.31171212,0.25376716,,,,,,,,,,,,,, -2700,0.11786953,0.29415905,,,,,,,,,,,,,, -2776,,,0.7387440545218331,0.2733841112681797,0.7164845206325618,0.2914826796171567,3554.0,0.7335982632949944,0.2931648601669401,3581.0,792.1041221618652,1037.3116459846497,792.1041221618652,244.8169682025909,0.2803800106048584,0.0 -2800,0.16608167,0.3505792,,,,,,,,,,,,,, -2900,0.32257107,0.22649127,,,,,,,,,,,,,, -3000,0.12707305,0.26016477,,,,,,,,,,,,,, -3100,0.10115211,0.2391883,,,,,,,,,,,,,, -3120,,,0.7408569880894252,0.2726921013423374,0.7182428964283202,0.2912389855070871,3554.0,0.7353520397366309,0.2928364872765987,3581.0,872.119637966156,1121.3769915103912,872.119637966156,248.82698583602905,0.3063969612121582,0.0 -3200,0.112595655,0.27411264,,,,,,,,,,,,,, -3300,0.10494702,0.3079452,,,,,,,,,,,,,, -3400,0.08049615,0.3773296,,,,,,,,,,,,,, -3465,,,0.7407792636326381,0.2719330617359706,0.7179621415790307,0.2906374268539849,3554.0,0.7351641448574071,0.2922635987896188,3581.0,952.2769184112548,1205.583834886551,952.2769184112548,252.83643865585327,0.3326916694641113,0.0 -3500,0.10477073,0.32274526,,,,,,,,,,,,,, -3600,0.15814093,0.28800178,,,,,,,,,,,,,, -3700,0.1639419,0.22571999,,,,,,,,,,,,,, -3800,0.10755436,0.25518405,,,,,,,,,,,,,, -3811,,,0.7399152347019741,0.2725422552653721,0.7170441068822102,0.2913451186713034,3554.0,0.7342503048860305,0.2929896120584334,3581.0,1032.3964014053345,1289.7542309761047,1032.3964014053345,256.84815645217896,0.3579692840576172,0.0 -3900,0.3170732,0.28003186,,,,,,,,,,,,,, -4000,0.3328216,0.22416572,,,,,,,,,,,,,, -4100,0.17416371,0.29284498,,,,,,,,,,,,,, -4156,,,0.7418276923043388,0.2715973513466971,0.7188788023837578,0.2905141543859032,3554.0,0.7362064296591385,0.2919820632657602,3581.0,1112.4486014842987,1373.8563287258148,1112.4486014842987,260.85873436927795,0.3834872245788574,0.0 -4200,0.107646726,0.23598811,,,,,,,,,,,,,, -4300,0.24453965,0.3175575,,,,,,,,,,,,,, -4400,0.1307289,0.2588099,,,,,,,,,,,,,, -4497,,,0.7427105222429548,0.2710503510066441,0.7198047369601154,0.2898978607400992,3554.0,0.7369858934306059,0.2915290293497801,3581.0,1192.4837381839752,1457.9447252750397,1192.4837381839752,264.8728246688843,0.4092543125152588,0.0 -4500,0.1305047,0.2636383,,,,,,,,,,,,,, -4600,0.14902826,0.23350392,,,,,,,,,,,,,, -4700,0.085729055,0.2536161,,,,,,,,,,,,,, -4800,0.27356172,0.26640457,,,,,,,,,,,,,, -4844,,,0.7410882541111538,0.2719627789088658,0.719926257715778,0.290022987962507,3554.0,0.7367711369467328,0.291571639763247,3581.0,1272.489012479782,1542.002141237259,1272.489012479782,268.8852117061615,0.4353601932525635,0.0 -4900,0.12747565,0.36817902,,,,,,,,,,,,,, -5000,0.07534247,0.29661632,,,,,,,,,,,,,, -5100,0.0810999,0.2401969,,,,,,,,,,,,,, -5191,,,0.7438113348824638,0.2707222359521048,0.7214453015088632,0.2888372847385868,3554.0,0.7386012030726403,0.2904303965372801,3581.0,1352.4867494106293,1626.053575515747,1352.4867494106293,272.89883518218994,0.461669921875,0.0 -5200,0.056213472,0.31195095,,,,,,,,,,,,,, -5300,0.20209323,0.19886126,,,,,,,,,,,,,, -5400,0.108140714,0.23907691,,,,,,,,,,,,,, -5500,0.174144,0.2711107,,,,,,,,,,,,,, -5532,,,0.7429601124354771,0.2709032297134399,0.7198999476821891,0.2898524879537141,3554.0,0.7371778107328609,0.2914324571086987,3581.0,1432.6245748996737,1710.2379422187803,1432.6245748996737,276.9070551395416,0.4866147041320801,0.0 -5600,0.14600798,0.22420977,,,,,,,,,,,,,, -5700,0.16605675,0.29131728,,,,,,,,,,,,,, -5800,0.1703988,0.4370351,,,,,,,,,,,,,, -5878,,,0.7389003208705357,0.2728471074785505,0.71665186068954,0.2909343249353721,3554.0,0.7340223903064786,0.2924137237983454,3581.0,1512.714156150818,1794.376507282257,1512.714156150818,280.91675758361816,0.5123147964477539,0.0 -5900,0.09326728,0.2812853,,,,,,,,,,,,,, -6000,0.13415593,0.23604867,,,,,,,,,,,,,, -6100,0.1988807,0.3355318,,,,,,,,,,,,,, -6200,0.17175522,0.23259096,,,,,,,,,,,,,, -6220,,,0.7431642668587821,0.2703012909208025,0.7207949009918402,0.2887784134623839,3554.0,0.7380594713199874,0.2902518418606883,3581.0,1592.7204446792605,1878.43439412117,1592.7204446792605,284.9300625324249,0.5374350547790527,0.0 -6300,0.25503793,0.28284305,,,,,,,,,,,,,, -6400,0.17906813,0.28625494,,,,,,,,,,,,,, -6500,0.06954017,0.24931723,,,,,,,,,,,,,, -6563,,,0.7439885820661273,0.2698183059692383,0.720698522461487,0.2887259651317178,3554.0,0.7378159442849413,0.2902247075493926,3581.0,1672.7735142707825,1962.5421574115755,1672.7735142707825,288.945408821106,0.563525915145874,0.0 -6600,0.10342173,0.3104647,,,,,,,,,,,,,, -6700,0.1705505,0.2602036,,,,,,,,,,,,,, -6800,0.1908491,0.27729028,,,,,,,,,,,,,, -6900,0.07121154,0.31384987,,,,,,,,,,,,,, -6909,,,0.7427389962332589,0.2705033676964896,0.7205512412290729,0.2888722159450619,3554.0,0.7378185349980801,0.2902472058477031,3581.0,1752.762172460556,2046.5846438407896,1752.762172460556,292.95811128616333,0.5912113189697266,0.0 -7000,0.18182966,0.22660503,,,,,,,,,,,,,, -7100,0.08052672,0.24121055,,,,,,,,,,,,,, -7200,0.18994653,0.23606512,,,,,,,,,,,,,, -7254,,,0.7382341793605259,0.2724441460200718,0.7179365871860931,0.2904418189671848,3554.0,0.7345816434611491,0.2919200565920832,3581.0,1832.948479890824,2130.8227894306183,1832.948479890824,296.9700815677643,0.6176929473876953,0.0 -7300,0.122491024,0.24767981,,,,,,,,,,,,,, -7400,0.28306022,0.2586438,,,,,,,,,,,,,, -7500,0.13872166,0.19056834,,,,,,,,,,,,,, -7598,,,0.744854313986642,0.2690965959003993,0.722372472588105,0.2875853599157639,3554.0,0.7394905676225216,0.2890893956929803,3581.0,1912.942586898804,2214.8672075271606,1912.942586898804,300.9809060096741,0.643923282623291,0.0 -7600,0.32882842,0.25358653,,,,,,,,,,,,,, -7700,0.19552575,0.24291447,,,,,,,,,,,,,, -7800,0.0738211,0.35088697,,,,,,,,,,,,,, -7900,0.18304966,0.29106763,,,,,,,,,,,,,, -7945,,,0.7442718233380999,0.2691159078053066,0.7212834570202589,0.2878344808996025,3554.0,0.7386821969465582,0.2891729461917062,3581.0,1992.9886183738708,2298.9635775089264,1992.9886183738708,304.9908983707428,0.6710951328277588,0.0 -8000,0.28002042,0.26514277,,,,,,,,,,,,,, -8100,0.17096905,0.25717533,,,,,,,,,,,,,, -8200,0.06882553,0.32186508,,,,,,,,,,,,,, -8291,,,0.7449304035731724,0.2695486886160714,0.7218863895654544,0.2882051226391038,3554.0,0.7391878632452528,0.2895576670928163,3581.0,2072.9793784618378,2383.0082693099976,2072.9793784618378,309.0040822029114,0.6984851360321045,0.0 -8300,0.34690037,0.29607686,,,,,,,,,,,,,, -8400,0.10297743,0.23231864,,,,,,,,,,,,,, -8500,0.137743,0.2761314,,,,,,,,,,,,,, -8600,0.20796749,0.19594225,,,,,,,,,,,,,, -8634,,,0.746199539729527,0.2686595065253122,0.7233574158298748,0.2872973577856816,3554.0,0.7405721221813041,0.2887361042328434,3581.0,2152.94308423996,2467.0250334739685,2152.94308423996,313.0164098739624,0.7257647514343262,0.0 -8700,0.11488095,0.26779178,,,,,,,,,,,,,, -8800,0.105490744,0.2598878,,,,,,,,,,,,,, -8900,0.14557381,0.2921134,,,,,,,,,,,,,, -8978,,,0.7457756996154785,0.2681561027254377,0.7230410771753658,0.2870447505506559,3554.0,0.7401969460128107,0.2885050194385297,3581.0,2232.981694459915,2551.1151769161224,2232.981694459915,317.02690291404724,0.753577470779419,0.0 -9000,0.17351562,0.33173075,,,,,,,,,,,,,, -9100,0.19118074,0.28395385,,,,,,,,,,,,,, -9200,0.22900693,0.22577854,,,,,,,,,,,,,, -9300,0.052359417,0.42702293,,,,,,,,,,,,,, -9322,,,0.7460549899509975,0.2680826868329729,0.7233535002374085,0.286880192625167,3554.0,0.7405238531049287,0.2883283737084613,3581.0,2313.036667108536,2635.224442481994,2313.036667108536,321.0403423309326,0.7809731960296631,0.0 -9400,0.14105514,0.23464781,,,,,,,,,,,,,, -9500,0.07675516,0.26939094,,,,,,,,,,,,,, -9600,0.236194,0.32392263,,,,,,,,,,,,,, -9668,,,0.7457577160426548,0.268416234425136,0.7224750336328785,0.2872653632735562,3554.0,0.7397892495767593,0.2886850740016755,3581.0,2393.2060899734497,2719.44512963295,2393.2060899734497,325.05160641670227,0.8076949119567871,0.0 -9700,0.20126289,0.3238584,,,,,,,,,,,,,, -9800,0.2275961,0.2310619,,,,,,,,,,,,,, -9900,0.15093447,0.24885313,,,,,,,,,,,,,, -10000,0.1469692,0.30946657,,,,,,,,,,,,,, -10012,,,0.7463709967476981,0.2677499226161411,0.7230605177484877,0.2868000088478651,3554.0,0.7403182322937029,0.2882206204948862,3581.0,2473.2667071819305,2803.558928966522,2473.2667071819305,329.06395959854126,0.8351755142211914,0.0 -10100,0.1712269,0.2539611,,,,,,,,,,,,,, -10200,0.1411321,0.36653268,,,,,,,,,,,,,, -10300,0.148583,0.24215584,,,,,,,,,,,,,, -10358,,,0.7464085987636021,0.2682306596211025,0.7235777881216587,0.2870103173800647,3554.0,0.7408877119476054,0.2883491335019024,3581.0,2553.332664489746,2887.675325155258,2553.332664489746,333.07476782798767,0.8613989353179932,0.0 -10400,0.16423696,0.2929941,,,,,,,,,,,,,, -10500,0.15924813,0.33367202,,,,,,,,,,,,,, -10600,0.118518606,0.29405087,,,,,,,,,,,,,, -10700,0.112980895,0.22881751,,,,,,,,,,,,,, -10703,,,0.7456763812473842,0.268431510244097,0.7229905866409327,0.2871879616277434,3554.0,0.7402850302595294,0.2885392782109571,3581.0,2633.5612363815308,2971.9580538272858,2633.5612363815308,337.0884153842926,0.8884561061859131,0.0 -10800,0.2554862,0.324826,,,,,,,,,,,,,, -10900,0.12552696,0.26592463,,,,,,,,,,,,,, -11000,0.06889286,0.2678408,,,,,,,,,,,,,, -11047,,,0.7449214799063546,0.2691765342439924,0.721823053139948,0.2880697255732977,3554.0,0.7390198759512008,0.289556201294593,3581.0,2713.557804107666,3056.007774591446,2713.557804107666,341.10114550590515,0.9152963161468506,0.0 -11100,0.053238854,0.28704855,,,,,,,,,,,,,, -11200,0.24418457,0.27712628,,,,,,,,,,,,,, -11300,0.24172449,0.33409107,,,,,,,,,,,,,, -11393,,,0.7457105772835868,0.2686616012028285,0.7226912842483821,0.2873810793351241,3554.0,0.7398553809384599,0.2887941566601508,3581.0,2793.704563856125,3140.207786083221,2793.704563856125,345.11279916763306,0.943565845489502,0.0 -11400,0.09332794,0.30522862,,,,,,,,,,,,,, -11500,0.12549429,0.26357746,,,,,,,,,,,,,, -11600,0.21318154,0.2359715,,,,,,,,,,,,,, -11700,0.11091858,0.29667097,,,,,,,,,,,,,, -11739,,,0.7449744769505092,0.2691163335527692,0.7231983878200618,0.2874639250283571,3554.0,0.7401566536058364,0.2888523795291119,3581.0,2873.7954342365265,3224.352249622345,2873.7954342365265,349.1259129047394,0.970717191696167,0.0 -11800,0.1211634,0.31342676,,,,,,,,,,,,,, -11900,0.34354162,0.27475965,,,,,,,,,,,,,, -12000,0.10549306,0.25544485,,,,,,,,,,,,,, -12081,,,0.7475405420575824,0.2677397727966308,0.7239450294782288,0.2871470368170107,3554.0,0.7412144826864004,0.2884994289522828,3581.0,2953.9432995319366,3308.552498102188,2953.9432995319366,353.1361691951752,0.9994661808013916,0.0 -12100,0.1633615,0.29279208,,,,,,,,,,,,,, -12200,0.15437329,0.25868732,,,,,,,,,,,,,, -12300,0.13122451,0.28648984,,,,,,,,,,,,,, -12400,0.13041751,0.2571452,,,,,,,,,,,,,, -12428,,,0.7480308668954032,0.2676975556782314,0.7251667630222988,0.2865567784263769,3554.0,0.7423336707623569,0.2879170979976787,3581.0,3034.0391058921814,3392.7031519412994,3034.0391058921814,357.14885449409485,1.0281689167022705,0.0 -12500,0.11639456,0.2991579,,,,,,,,,,,,,, -12600,0.19761471,0.30106956,,,,,,,,,,,,,, -12700,0.2445715,0.27184874,,,,,,,,,,,,,, -12772,,,0.7462588718959263,0.2679569721221924,0.7233728721159257,0.2866768737689927,3554.0,0.740602392619031,0.2880633369366971,3581.0,3114.203953027725,3476.926920652389,3114.203953027725,361.1650941371918,1.0573818683624268,0.0 -12800,0.2966622,0.2536059,,,,,,,,,,,,,, -12900,0.13392887,0.2088636,,,,,,,,,,,,,, -13000,0.16261296,0.302184,,,,,,,,,,,,,, -13100,0.33819675,0.25570992,,,,,,,,,,,,,, -13113,,,0.7472739900861468,0.2680355651038034,0.724447667900605,0.286952081528955,3554.0,0.7415505936278274,0.288360621269373,3581.0,3194.259214401245,3561.0386533737183,3194.259214401245,365.1799545288086,1.085845947265625,0.0 -13200,0.22148305,0.23002297,,,,,,,,,,,,,, -13300,0.16687258,0.34874958,,,,,,,,,,,,,, -13400,0.24717985,0.40837023,,,,,,,,,,,,,, -13458,,,0.7458792413984027,0.2685546534402029,0.7230923233504502,0.2873326496388312,3554.0,0.7403828637688494,0.2886832332318137,3581.0,3274.239005088806,3645.07679104805,3274.239005088806,369.19701194763184,1.113793134689331,0.0 -13500,0.100021355,0.32857576,,,,,,,,,,,,,, -13600,0.16654955,0.19771981,,,,,,,,,,,,,, -13700,0.11486679,0.31667924,,,,,,,,,,,,,, -13800,0.28489754,0.26452437,,,,,,,,,,,,,, -13805,,,0.7467846189226423,0.2679947103772844,0.7235428912624859,0.2870484772329593,3554.0,0.7408170127495811,0.2884788736888264,3581.0,3354.394580364228,3729.289011478424,3354.394580364228,373.212171792984,1.1418728828430176,0.0 -13900,0.10292225,0.23444828,,,,,,,,,,,,,, -14000,0.15645857,0.21726798,,,,,,,,,,,,,, -14100,0.15074694,0.24408843,,,,,,,,,,,,,, -14147,,,0.7473293713160923,0.2674491916384016,0.7241186207442318,0.2863896444532129,3554.0,0.7413705390646816,0.287787903224047,3581.0,3434.455368757248,3813.404773712158,3434.455368757248,377.22587037086487,1.1696898937225342,0.0 -14200,0.11549186,0.35532627,,,,,,,,,,,,,, -14300,0.18369733,0.3305636,,,,,,,,,,,,,, -14400,0.14630124,0.28794438,,,,,,,,,,,,,, -14488,,,0.7461852346147809,0.2679217542920794,0.723222568320906,0.2866953526176491,3554.0,0.7404717661355068,0.2880733248176138,3581.0,3514.48440861702,3897.490335702896,3514.48440861702,381.2357349395752,1.2031478881835938,0.0 -14500,0.10743938,0.34145367,,,,,,,,,,,,,, -14600,0.22533117,0.21210954,,,,,,,,,,,,,, -14700,0.16534324,0.2609677,,,,,,,,,,,,,, -14800,0.094615035,0.34032154,,,,,,,,,,,,,, -14832,,,0.7476575715201241,0.2674539940697806,0.7245439090417487,0.28643901870032,3554.0,0.7417353523806199,0.2878396493101612,3581.0,3594.580799818039,3981.64256811142,3594.580799818039,385.2501049041748,1.2313237190246582,0.0 -14900,0.17427836,0.35185936,,,,,,,,,,,,,, -15000,0.13250169,0.21882327,,,,,,,,,,,,,, -15100,0.108874604,0.19621904,,,,,,,,,,,,,, -15175,,,0.7454748834882464,0.2678097827093942,0.7224384881031936,0.2868108969427054,3554.0,0.7397276178747207,0.2881857481325048,3581.0,3674.611796617508,4065.728991985321,3674.611796617508,389.26374077796936,1.2599167823791504,0.0 -15200,0.21028116,0.25783736,,,,,,,,,,,,,, -15300,0.21451426,0.24168779,,,,,,,,,,,,,, -15400,0.10146689,0.27438867,,,,,,,,,,,,,, -15500,0.12290342,0.2559791,,,,,,,,,,,,,, -15522,,,0.7466159548078265,0.2679358380181448,0.7230707532445836,0.2872402038745955,3554.0,0.7403652741901704,0.28855615193469,3581.0,3754.784963130951,4149.960080385208,3754.784963130951,393.27994561195374,1.2880818843841553,0.0 -15600,0.1556547,0.2758312,,,,,,,,,,,,,, -15700,0.30676994,0.2080411,,,,,,,,,,,,,, -15800,0.11060604,0.24947165,,,,,,,,,,,,,, -15867,,,0.7471408843994141,0.2669087648391723,0.7238568943004361,0.285905776831563,3554.0,0.7412455712440659,0.2872456601464325,3581.0,3834.939851999283,4234.172024965286,3834.939851999283,397.2954897880554,1.3162243366241455,0.0 -15900,0.4408808,0.21032651,,,,,,,,,,,,,, -16000,0.20339614,0.2354865,,,,,,,,,,,,,, -16100,0.13822514,0.36448172,,,,,,,,,,,,,, -16200,0.12110754,0.22594972,,,,,,,,,,,,,, -16213,,,0.7476457868303571,0.2671341555459158,0.7246648115459693,0.2859542752224606,3554.0,0.7419904012714674,0.2872687379463662,3581.0,3915.069020032882,4318.356410264969,3915.069020032882,401.3085203170776,1.3447983264923096,0.0 -16300,0.15293775,0.24439405,,,,,,,,,,,,,, -16400,0.12044056,0.23146352,,,,,,,,,,,,,, -16500,0.10600627,0.2747072,,,,,,,,,,,,,, -16560,,,0.7480151993887765,0.2664672647203718,0.7246860381788126,0.2857950583024848,3554.0,0.7418862955092851,0.2871150336629084,3581.0,3995.056948661804,4402.402938842773,3995.056948661804,405.32308530807495,1.3752338886260986,0.0 -16600,0.13000426,0.23354441,,,,,,,,,,,,,, -16700,0.2694467,0.32104096,,,,,,,,,,,,,, -16800,0.4187804,0.29702452,,,,,,,,,,,,,, -16900,0.055698145,0.33047408,,,,,,,,,,,,,, -16906,,,0.7479713984898159,0.2666489567075457,0.7247964304085186,0.2857232724406039,3554.0,0.7420793036381248,0.2871182720543319,3581.0,4075.1593718528734,4486.561875104904,4075.1593718528734,409.3372468948364,1.4041290283203125,0.0 -17000,0.36062288,0.22258335,,,,,,,,,,,,,, -17100,0.15912426,0.31060782,,,,,,,,,,,,,, -17200,0.08833836,0.25682086,,,,,,,,,,,,,, -17249,,,0.7486212594168526,0.2669464349746704,0.7252796282577729,0.2860727734155177,3554.0,0.742540450576829,0.2874048867394757,3581.0,4155.271572828293,4570.726595163345,4155.271572828293,413.3470900058746,1.4336450099945068,0.0 -17300,0.30954504,0.25420246,,,,,,,,,,,,,, -17400,0.29264385,0.23460847,,,,,,,,,,,,,, -17500,0.1706674,0.25132337,,,,,,,,,,,,,, -17592,,,0.7479096140180316,0.2669038772583008,0.7240776100652434,0.2864979930184299,3554.0,0.7413886740566532,0.2877851079809236,3581.0,4235.25742316246,4654.7667927742,4235.25742316246,417.3585879802704,1.4628279209136963,0.0 -17600,0.23132199,0.3020588,,,,,,,,,,,,,, -17700,0.1765479,0.26152092,,,,,,,,,,,,,, -17800,0.2658667,0.28799626,,,,,,,,,,,,,, -17900,0.13409227,0.20781684,,,,,,,,,,,,,, -17940,,,0.7473158836364746,0.267059462411063,0.7240636650604952,0.2861336196616049,3554.0,0.7414530328251536,0.2874236694097319,3581.0,4315.329802036285,4738.894271850586,4315.329802036285,421.3698930740357,1.4926958084106443,0.0 -18000,0.22700448,0.3144531,,,,,,,,,,,,,, -18100,0.13544801,0.26906568,,,,,,,,,,,,,, -18200,0.1082821,0.21350774,,,,,,,,,,,,,, -18286,,,0.7483020509992327,0.2668248585292271,0.7248403949555079,0.2860413799811832,3554.0,0.7421564796189961,0.2873145867512566,3581.0,4395.332178592682,4822.952711105347,4395.332178592682,425.3839862346649,1.5210142135620115,0.0 -18300,0.10475261,0.25567588,,,,,,,,,,,,,, -18400,0.1915848,0.22087878,,,,,,,,,,,,,, -18500,0.16220479,0.33679816,,,,,,,,,,,,,, -18600,0.20196983,0.2586003,,,,,,,,,,,,,, -18630,,,0.7485624722072056,0.2661434582301548,0.7246146644845597,0.2857404632654227,3554.0,0.7419122708173346,0.2870992507657602,3581.0,4475.431641340256,4907.108276367188,4475.431641340256,429.3969392776489,1.5507056713104248,0.0 -18700,0.21994801,0.28920618,,,,,,,,,,,,,, -18800,0.30876175,0.26278037,,,,,,,,,,,,,, -18900,0.11272208,0.24469036,,,,,,,,,,,,,, -18976,,,0.7473161561148507,0.2668733426502773,0.7237860701630205,0.2860027392660558,3554.0,0.741142760838453,0.2873082463217327,3581.0,4555.57452249527,4991.308980464935,4555.57452249527,433.4125940799713,1.5795207023620603,0.0 -19000,0.068878554,0.280694,,,,,,,,,,,,,, -19100,0.100798205,0.2741893,,,,,,,,,,,,,, -19200,0.1355478,0.23201819,,,,,,,,,,,,,, -19300,0.45148578,0.2572658,,,,,,,,,,,,,, -19323,,,0.748532772064209,0.2665856054850987,0.7251577640290869,0.2857810274294808,3554.0,0.7423779855923625,0.2871508945868821,3581.0,4635.542121648788,5075.3327651023865,4635.542121648788,437.425502538681,1.6093058586120603,0.0 -19400,0.33241874,0.2450596,,,,,,,,,,,,,, -19500,0.79459536,0.22781573,,,,,,,,,,,,,, -19600,0.21605726,0.34352732,,,,,,,,,,,,,, -19666,,,0.7492211205618722,0.2659586497715541,0.7252296872801772,0.2857543224019151,3554.0,0.7424777280482058,0.2870662191732407,3581.0,4715.623168230057,5159.4710512161255,4715.623168230057,441.4396963119507,1.6389601230621338,0.0 -19700,0.31780922,0.20294575,,,,,,,,,,,,,, -19800,0.18930279,0.22042017,,,,,,,,,,,,,, -19900,0.11534004,0.31495705,,,,,,,,,,,,,, -20000,0.20146757,0.24188964,,,,,,,,,,,,,, -20011,,,0.747901439666748,0.2665020397731236,0.7244954106508511,0.2856981817362479,3554.0,0.7418216640341385,0.2869267638120462,3581.0,4795.6775233745575,5243.579426765442,4795.6775233745575,445.4508848190308,1.6683268547058103,0.0 -20100,0.13708699,0.38506863,,,,,,,,,,,,,, -20200,0.17794755,0.35859427,,,,,,,,,,,,,, -20300,0.08608258,0.2691367,,,,,,,,,,,,,, -20354,,,0.7483769825526646,0.2666335957390921,0.7248445853263928,0.2859574351742754,3554.0,0.7420431018308433,0.2872837368119066,3581.0,4875.658908843994,5327.617372274399,4875.658908843994,449.4621365070343,1.7000653743743896,0.0 -20400,0.16476578,0.23533514,,,,,,,,,,,,,, -20500,0.34767506,0.23347224,,,,,,,,,,,,,, -20600,0.29876706,0.36532605,,,,,,,,,,,,,, -20696,,,0.7491252762930733,0.2660750320979527,0.7255758393931134,0.2855740162383054,3554.0,0.7427288226926836,0.2869335473898701,3581.0,4955.702982664108,5411.722235918045,4955.702982664108,453.4756505489349,1.7338788509368896,0.0 -20700,0.13825753,0.25319558,,,,,,,,,,,,,, -20800,0.12117965,0.23596787,,,,,,,,,,,,,, -20900,0.18279305,0.30542746,,,,,,,,,,,,,, -21000,0.22901596,0.31042948,,,,,,,,,,,,,, -21044,,,0.7485916273934501,0.2661369187491281,0.7251855853439786,0.2855152480040095,3554.0,0.7424461622539095,0.2868228625798485,3581.0,5035.683332443237,5495.758540868759,5035.683332443237,457.4879927635193,1.763725996017456,0.0 -21100,0.053816915,0.27026263,,,,,,,,,,,,,, -21200,0.15267068,0.22055998,,,,,,,,,,,,,, -21300,0.2901799,0.27343252,,,,,,,,,,,,,, -21390,,,0.7477847508021763,0.2669078111648559,0.7248390210634145,0.2859515789592273,3554.0,0.7419769704691427,0.2872540458758028,3581.0,5115.678194522858,5579.808661937714,5115.678194522858,461.4994411468506,1.793873310089111,0.0 -21400,0.10703955,0.25787884,,,,,,,,,,,,,, -21500,0.2501168,0.21533695,,,,,,,,,,,,,, -21600,0.11260508,0.22329587,,,,,,,,,,,,,, -21700,0.17934625,0.30209637,,,,,,,,,,,,,, -21731,,,0.7492568152291434,0.2665330682482038,0.7260055241453293,0.285770912148943,3554.0,0.7431472910412594,0.287159791641214,3581.0,5195.80691409111,5663.991146087647,5195.80691409111,465.5098142623901,1.823737382888794,0.0 -21800,0.3550008,0.24701463,,,,,,,,,,,,,, -21900,0.21410318,0.3495797,,,,,,,,,,,,,, -22000,0.17676419,0.28611574,,,,,,,,,,,,,, -22078,,,0.7461775371006557,0.2668310574122837,0.7226040421004502,0.286123092213439,3554.0,0.7400780459150726,0.28738392241605,3581.0,5275.870712280273,5748.111387014389,5275.870712280273,469.5228455066681,1.8534579277038568,0.0 -22100,0.24765879,0.27081487,,,,,,,,,,,,,, -22200,0.05999771,0.28202572,,,,,,,,,,,,,, -22300,0.14534286,0.3292086,,,,,,,,,,,,,, -22400,0.11247185,0.25352722,,,,,,,,,,,,,, -22424,,,0.7486978939601353,0.2662124293191092,0.7253482541678391,0.2854405769687324,3554.0,0.7425572902122313,0.2867656282724797,3581.0,5356.035150527954,5832.326321363449,5356.035150527954,473.5297937393189,1.8833134174346924,0.0 -22500,0.16958289,0.3094387,,,,,,,,,,,,,, -22600,0.13401769,0.22415721,,,,,,,,,,,,,, -22700,0.34008792,0.21274264,,,,,,,,,,,,,, -22767,,,0.749427046094622,0.2666067055293492,0.7262519316922833,0.2857713586638734,3554.0,0.7433064153693103,0.2872492394211637,3581.0,5436.060939788818,5916.412758111954,5436.060939788818,477.544869184494,1.915790557861328,0.0 -22800,0.21451265,0.23309992,,,,,,,,,,,,,, -22900,0.178579,0.21045078,,,,,,,,,,,,,, -23000,0.07399294,0.27579463,,,,,,,,,,,,,, -23100,0.08298527,0.26708913,,,,,,,,,,,,,, -23114,,,0.7493207114083427,0.265721321105957,0.7254548681942882,0.2853577312754994,3554.0,0.7426980068416643,0.2866858615784697,3581.0,5516.128949642181,6000.536830663681,5516.128949642181,481.5577425956726,1.9451351165771484,0.0 -23200,0.056148753,0.30427146,,,,,,,,,,,,,, -23300,0.32589945,0.24107254,,,,,,,,,,,,,, -23400,0.111876056,0.29347,,,,,,,,,,,,,, -23459,,,0.7489502089364188,0.2658481938498361,0.7253882344277575,0.2852362105198368,3554.0,0.7426928254153867,0.2865019891222773,3581.0,5596.276057720184,6084.737917661667,5596.276057720184,485.5675754547119,1.9756977558135984,0.0 -23500,0.15962598,0.24268425,,,,,,,,,,,,,, -23600,0.1901026,0.28377813,,,,,,,,,,,,,, -23700,0.09305254,0.33548304,,,,,,,,,,,,,, -23800,0.1391469,0.21118394,,,,,,,,,,,,,, -23801,,,0.7496252059936523,0.2658059937613351,0.7260273003350098,0.2852406241481869,3554.0,0.7432441019006563,0.2866005384865435,3581.0,5676.235087871552,6168.698595523834,5676.235087871552,489.5242583751679,2.007225275039673,0.0 -23900,0.13517225,0.2845946,,,,,,,,,,,,,, -24000,0.105908684,0.18721294,,,,,,,,,,,,,, -24100,0.16843483,0.29314464,,,,,,,,,,,,,, -24146,,,0.7498554502214704,0.2652804510934012,0.7258368101962578,0.2851335979541098,3554.0,0.7430751601333426,0.2864032693163572,3581.0,5756.289666891098,6252.809968471527,5756.289666891098,493.53691935539246,2.037910223007202,0.0 -24200,0.10233781,0.2704727,,,,,,,,,,,,,, -24300,0.11958076,0.24853541,,,,,,,,,,,,,, -24400,0.38419712,0.28383172,,,,,,,,,,,,,, -24492,,,0.7493213926042829,0.2656843491962978,0.7257174189733399,0.2851750036270751,3554.0,0.7429374432770176,0.2864860016951445,3581.0,5836.392609119415,6336.972864627838,5836.392609119415,497.5527238845825,2.0684311389923096,0.0 -24500,0.14297295,0.33727828,,,,,,,,,,,,,, -24600,0.22230265,0.22202697,,,,,,,,,,,,,, -24700,0.09487795,0.2718642,,,,,,,,,,,,,, -24800,0.086535424,0.2714364,,,,,,,,,,,,,, -24837,,,0.7498811313084194,0.2655792576926095,0.7261310978826674,0.2851052957769854,3554.0,0.743295916163432,0.2864755024892663,3581.0,5916.42840719223,6421.065683841705,5916.42840719223,501.56517338752747,2.099144697189331,0.0 -24900,0.24180096,0.27559564,,,,,,,,,,,,,, -25000,0.102626674,0.29372045,,,,,,,,,,,,,, -25100,0.10886837,0.3340153,,,,,,,,,,,,,, -25180,,,0.7501208441598075,0.2650695528302874,0.7260279185864519,0.2850197538205191,3554.0,0.7432925755070162,0.2863555456532742,3581.0,5996.445634841919,6505.139622926712,5996.445634841919,505.57781767845154,2.129843711853028,0.0 -25200,0.11032737,0.2898512,,,,,,,,,,,,,, -25300,0.09964888,0.28215227,,,,,,,,,,,,,, -25400,0.3083769,0.21645255,,,,,,,,,,,,,, -25500,0.09389299,0.25941354,,,,,,,,,,,,,, -25527,,,0.7498316083635602,0.265564067023141,0.7260694788222777,0.2852243950478334,3554.0,0.7433385265768989,0.2865228852690414,3581.0,6076.453567028046,6589.206932544708,6076.453567028046,509.5925693511963,2.161004543304444,0.0 -25600,0.1116176,0.3492518,,,,,,,,,,,,,, -25700,0.16583033,0.24136925,,,,,,,,,,,,,, -25800,0.10141674,0.25393915,,,,,,,,,,,,,, -25872,,,0.7498658725193569,0.2655244214194162,0.7262190956712508,0.2850607473258564,3554.0,0.743424156463802,0.2863676470106988,3581.0,6156.694766283035,6673.505701780319,6156.694766283035,513.6058740615845,2.1919939517974854,0.0 -25900,0.11826235,0.21466693,,,,,,,,,,,,,, -26000,0.11593978,0.32389042,,,,,,,,,,,,,, -26100,0.06832033,0.2930458,,,,,,,,,,,,,, -26200,0.06971088,0.28083926,,,,,,,,,,,,,, -26214,,,0.7502429825919015,0.264920881816319,0.7258607159186832,0.2850451708242473,3554.0,0.7430811596795588,0.2863786575415387,3581.0,6236.789935827255,6757.661597251892,6236.789935827255,517.6208217144012,2.224311351776123,0.0 -26300,0.09539934,0.25783464,,,,,,,,,,,,,, -26400,0.07995587,0.4077147,,,,,,,,,,,,,, -26500,0.12902206,0.2071976,,,,,,,,,,,,,, -26559,,,0.7503128732953753,0.2651557922363281,0.7263939234401379,0.2849239763679568,3554.0,0.7435952798842851,0.2862669500815938,3581.0,6316.855994939804,6841.782725572586,6316.855994939804,521.6307320594788,2.2558205127716064,0.0 -26600,0.19735086,0.3098378,,,,,,,,,,,,,, -26700,0.12537874,0.32691088,,,,,,,,,,,,,, -26800,0.09974973,0.2605342,,,,,,,,,,,,,, -26900,0.118202664,0.2298297,,,,,,,,,,,,,, -26907,,,0.7502010890415737,0.2654227529253278,0.7263083986573228,0.285112319800313,3554.0,0.743497310021642,0.2864094733925579,3581.0,6396.952573060989,6925.938338279724,6396.952573060989,525.6441569328308,2.287712812423706,0.0 -27000,0.084880546,0.22189137,,,,,,,,,,,,,, -27100,0.25142643,0.2829168,,,,,,,,,,,,,, -27200,0.11502063,0.3543246,,,,,,,,,,,,,, -27249,,,0.7506561960492816,0.2648429019110543,0.7264343845622889,0.2849978745889315,3554.0,0.7435778948355907,0.2863592271929978,3581.0,6476.943389892578,7009.990304231644,6476.943389892578,529.6607377529144,2.31883192062378,0.0 -27300,0.16816677,0.22081986,,,,,,,,,,,,,, -27400,0.081198804,0.25939524,,,,,,,,,,,,,, -27500,0.09464563,0.3366724,,,,,,,,,,,,,, -27596,,,0.7495948246547154,0.2651593174253191,0.7254652410795934,0.2850242189698227,3554.0,0.7427244593863446,0.2862810285622033,3581.0,6556.95263504982,7094.058998584747,6556.95263504982,533.6739375591278,2.35146427154541,0.0 -27600,0.06760711,0.29902637,,,,,,,,,,,,,, -27700,0.06919422,0.37004474,,,,,,,,,,,,,, -27800,0.08763061,0.3127361,,,,,,,,,,,,,, -27900,0.0454972,0.30454025,,,,,,,,,,,,,, -27939,,,0.750136307307652,0.265045166015625,0.7260886446169809,0.2848596610443338,3554.0,0.7433330724439752,0.2861486976621404,3581.0,6636.923789262772,7178.088750839233,6636.923789262772,537.6880068778992,2.38265347480774,0.0 -28000,0.09160372,0.21107121,,,,,,,,,,,,,, -28100,0.1879302,0.2185714,,,,,,,,,,,,,, -28200,0.09927773,0.3675546,,,,,,,,,,,,,, -28283,,,0.7500929151262555,0.2648871626172747,0.7258392145074212,0.284950715742825,3554.0,0.7430446851656312,0.2862953115727974,3581.0,6716.911543369293,7262.139762163162,6716.911543369293,541.7042119503021,2.416171073913574,0.0 -28300,0.09790352,0.26382023,,,,,,,,,,,,,, -28400,0.0988913,0.22896267,,,,,,,,,,,,,, -28500,0.07606695,0.29534245,,,,,,,,,,,,,, -28600,0.07063542,0.31223053,,,,,,,,,,,,,, -28628,,,0.7504279272896903,0.2649558271680559,0.7262554351171215,0.2849141530394889,3554.0,0.7434871516990715,0.286252019392715,3581.0,6796.866130590439,7346.153185129166,6796.866130590439,545.7160577774048,2.4494388103485107,0.0 -28700,0.06482243,0.29391828,,,,,,,,,,,,,, -28800,0.12354587,0.26981315,,,,,,,,,,,,,, -28900,0.078614615,0.30349874,,,,,,,,,,,,,, -28972,,,0.7505592618669782,0.2649299076625279,0.7264863863780248,0.284866135510824,3554.0,0.7436901817971586,0.2861457319773631,3581.0,6876.82315993309,7430.169394493103,6876.82315993309,549.7289471626282,2.482147455215454,0.0 -29000,0.1079665,0.22541787,,,,,,,,,,,,,, -29100,0.13958995,0.18494342,,,,,,,,,,,,,, -29200,0.116177306,0.18821374,,,,,,,,,,,,,, -29300,0.091862746,0.2909005,,,,,,,,,,,,,, -29314,,,0.7498602867126465,0.2651236738477434,0.7258943762749719,0.2850154947550295,3554.0,0.7430681379372033,0.2863870432709089,3581.0,6956.778971672058,7514.184953689575,6956.778971672058,553.7419407367706,2.515422344207764,0.0 -29400,0.15920103,0.2793925,,,,,,,,,,,,,, -29500,0.09381959,0.21323323,,,,,,,,,,,,,, -29600,0.09913959,0.25416657,,,,,,,,,,,,,, -29658,,,0.7504299027579171,0.2647002935409546,0.7261024522325197,0.2848177229881823,3554.0,0.7433210051748813,0.2861635942626885,3581.0,7036.799525976181,7598.264146327972,7036.799525976181,557.7553491592407,2.5471506118774414,0.0 -29700,0.076735795,0.30583292,,,,,,,,,,,,,, -29800,0.123963766,0.2855879,,,,,,,,,,,,,, -29900,0.11671577,0.24242532,,,,,,,,,,,,,, -30000,0.06043138,0.29563832,,,,,,,,,,,,,, -30003,,,0.7505853516714913,0.2649395125252859,0.7266091436365715,0.2848703602290113,3554.0,0.7437669487180606,0.2862120337807177,3581.0,7116.821959257126,7682.350178480148,7116.821959257126,561.7724099159241,2.5802698135375977,0.0 -30100,0.06707992,0.23600432,,,,,,,,,,,,,, -30200,0.098719954,0.27850342,,,,,,,,,,,,,, -30300,0.079628415,0.2186565,,,,,,,,,,,,,, -30348,,,0.7509135518755231,0.26484443460192,0.726777857585643,0.2848333166634426,3554.0,0.7439271638726962,0.2862066137361246,3581.0,7196.888969659805,7766.479268789291,7196.888969659805,565.7843985557556,2.616807460784912,0.0 -30400,0.13999903,0.25127244,,,,,,,,,,,,,, -30500,0.14470682,0.22538167,,,,,,,,,,,,,, -30600,0.07626016,0.3221715,,,,,,,,,,,,,, -30693,,,0.7508979524884906,0.264484030859811,0.726506582591798,0.2848165551799029,3554.0,0.7436694560920483,0.2861527882618332,3581.0,7276.879731178284,7850.534174203873,7276.879731178284,569.8019735813141,2.649852514266968,0.0 -30700,0.09227602,0.2907714,,,,,,,,,,,,,, -30800,0.1457602,0.3562202,,,,,,,,,,,,,, -30900,0.06108871,0.21006104,,,,,,,,,,,,,, -31000,0.058954936,0.27542242,,,,,,,,,,,,,, -31041,,,0.7510170936584473,0.2646525587354387,0.7267936573447172,0.284739548528067,3554.0,0.7439709332894093,0.2860615678886833,3581.0,7356.900419712067,7934.61389875412,7356.900419712067,573.8156337738037,2.681325674057007,0.0 -31100,0.10891128,0.20000122,,,,,,,,,,,,,, -31200,0.08716727,0.21530862,,,,,,,,,,,,,, -31300,0.19916424,0.33873013,,,,,,,,,,,,,, -31385,,,0.7506763594491142,0.2646515199116298,0.726357583994267,0.2847565504427229,3554.0,0.7435804173720678,0.2860809982372242,3581.0,7436.869582891464,8018.64168548584,7436.869582891464,577.8291063308716,2.7132408618927,0.0 -31400,0.14137319,0.21948513,,,,,,,,,,,,,, -31500,0.13820691,0.26091275,,,,,,,,,,,,,, -31600,0.12083399,0.21858746,,,,,,,,,,,,,, -31700,0.0763123,0.23527984,,,,,,,,,,,,,, -31729,,,0.7512480872017997,0.2642843893596104,0.726590664787915,0.2847422276176491,3554.0,0.7437989235723261,0.2860875091084019,3581.0,7516.999427556992,8102.830830574036,7516.999427556992,581.842268705368,2.745941162109375,0.0 -31800,0.23740049,0.2185856,,,,,,,,,,,,,, -31900,0.107172936,0.30544832,,,,,,,,,,,,,, -32000,0.12675527,0.1903483,,,,,,,,,,,,,, -32075,,,0.750896862574986,0.2644829750061035,0.7264721665948579,0.2847329194987162,3554.0,0.7436901817971586,0.2860553979008133,3581.0,7596.9842693805695,8186.874871730804,7596.9842693805695,585.8551957607269,2.778761148452759,0.0 -32100,0.09658695,0.32705128,,,,,,,,,,,,,, -32200,0.07374917,0.3216127,,,,,,,,,,,,,, -32300,0.08625913,0.29519618,,,,,,,,,,,,,, -32400,0.06459737,0.2684552,,,,,,,,,,,,,, -32419,,,0.751089368547712,0.2645196914672851,0.7267552570607062,0.2846938666159609,3554.0,0.7439336406555431,0.2860293544161023,3581.0,7677.118875980377,8271.065751314163,7677.118875980377,589.8650617599487,2.811578750610352,0.0 -32500,0.059665382,0.2492563,,,,,,,,,,,,,, -32600,0.11873022,0.23357198,,,,,,,,,,,,,, -32700,0.07827826,0.24712658,,,,,,,,,,,,,, -32765,,,0.7514265605381557,0.2642244781766619,0.7267815670942952,0.2846990358849623,3554.0,0.7439089607040631,0.2860738396877618,3581.0,7757.2637186050415,8355.271630525589,7757.2637186050415,593.8802864551544,2.8437328338623047,0.0 -32800,0.05249224,0.23246816,,,,,,,,,,,,,, -32900,0.080200896,0.21377191,,,,,,,,,,,,,, -33000,0.0707767,0.21415225,,,,,,,,,,,,,, -33100,0.054861583,0.23445037,,,,,,,,,,,,,, -33109,,,0.7512797628130231,0.2643719060080392,0.7268199673783061,0.2846539722242983,3554.0,0.7439951360042586,0.2859970045901983,3581.0,7837.221108436584,8439.284693956375,7837.221108436584,597.8902106285095,2.876116275787353,0.0 -33200,0.066703446,0.28876844,,,,,,,,,,,,,, -33300,0.071409,0.3217846,,,,,,,,,,,,,, -33400,0.04852037,0.26386783,,,,,,,,,,,,,, -33454,,,0.7510183198111398,0.2643978595733642,0.726607151493036,0.2846558956732291,3554.0,0.7438003552822187,0.2859811535163886,3581.0,7917.2318749427795,8523.35584950447,7917.2318749427795,601.9045717716217,2.9088120460510254,0.0 -33500,0.048323326,0.22225621,,,,,,,,,,,,,, -33600,0.04796484,0.2981352,,,,,,,,,,,,,, -33700,0.078159034,0.24874553,,,,,,,,,,,,,, -33796,,,0.7513384137834821,0.2641600370407104,0.7267140402979038,0.2846327284178039,3554.0,0.7438691455337196,0.2859906641606743,3581.0,7997.312318086624,8607.498377799988,7997.312318086624,605.9201576709747,2.9417948722839355,0.0 -33800,0.06830641,0.23818192,,,,,,,,,,,,,, -33900,0.08424096,0.28584245,,,,,,,,,,,,,, -34000,0.07741934,0.3971374,,,,,,,,,,,,,, -34100,0.085014105,0.21735327,,,,,,,,,,,,,, -34141,,,0.7514914103916713,0.2642370292118617,0.7269534409951814,0.2846216342391495,3554.0,0.744121058298136,0.2859547691483698,3581.0,8077.364030599594,8691.611410140991,8077.364030599594,609.9347627162933,2.974903583526612,0.0 -34200,0.07030004,0.311491,,,,,,,,,,,,,, -34300,0.062623724,0.23524018,,,,,,,,,,,,,, -34400,0.07045588,0.20643955,,,,,,,,,,,,,, -34487,,,0.7512409346444267,0.2642880848475865,0.7267562187851716,0.2846199512213351,3554.0,0.7439071881108629,0.2859628139944324,3581.0,8157.329024076462,8775.637843370438,8157.329024076462,613.9488220214844,3.0089025497436523,0.0 -34500,0.07577722,0.23875128,,,,,,,,,,,,,, -34600,0.051018447,0.24254078,,,,,,,,,,,,,, -34700,0.2631335,0.34890726,,,,,,,,,,,,,, -34800,0.049871895,0.3138941,,,,,,,,,,,,,, -34829,,,0.751697267804827,0.2642233712332589,0.7271451676368177,0.2846372794353633,3554.0,0.7442693425370008,0.2860020155748219,3581.0,8237.486559867859,8859.855407238007,8237.486559867859,617.962637424469,3.041879415512085,0.0 -34900,0.08058839,0.17860915,,,,,,,,,,,,,, -35000,0.05877336,0.2401844,,,,,,,,,,,,,, -35100,0.0913731,0.22024743,,,,,,,,,,,,,, -35174,,,0.7514427730015346,0.2641376086643764,0.7268877689531162,0.2845476158026168,3554.0,0.7440546542297891,0.2859003130399591,3581.0,8317.607504606247,8944.038051128387,8317.607504606247,621.9767315387726,3.076141119003296,0.0 -35200,0.08064503,0.2172639,,,,,,,,,,,,,, -35300,0.049467493,0.28993237,,,,,,,,,,,,,, -35400,0.054633625,0.2459361,,,,,,,,,,,,,, -35500,0.047566928,0.3186196,,,,,,,,,,,,,, -35519,,,0.7515405927385602,0.2641528333936419,0.7270060610623593,0.2845413645935917,3554.0,0.7441674866046495,0.2858832006979108,3581.0,8397.610383033752,9028.099632501602,8397.610383033752,625.9883260726929,3.109600782394409,0.0 -35600,0.092521414,0.20179902,,,,,,,,,,,,,, -35700,0.075997256,0.2089142,,,,,,,,,,,,,, -35800,0.10385868,0.2349109,,,,,,,,,,,,,, -35862,,,0.7511764253888812,0.264157908303397,0.7266679462181697,0.2845379298633582,3554.0,0.7438509423650865,0.2858744911293982,3581.0,8477.655651330948,9112.202283859251,8477.655651330948,629.9989938735962,3.143000602722168,0.0 -35900,0.070997015,0.19004494,,,,,,,,,,,,,, -36000,0.058141056,0.30186462,,,,,,,,,,,,,, -36100,0.059343427,0.24438658,,,,,,,,,,,,,, -36189,,,0.7514442716326032,0.2641434158597673,0.7268797316843697,0.2845460530003605,3554.0,0.7440562222930047,0.2858859788968689,3581.0,8553.608373880386,9192.21478652954,8553.608373880386,634.0129742622375,3.176217794418335,0.0 -36189,,,,,,,,,,,8553.608373880386,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/eval_measurements.csv deleted file mode 100644 index b1d6882f7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.95663857460022,0.0,28.97892165184021,1,0,28.97892165184021,1.0323986403937448,3581,0.2128455943150481,32.935672760009766,1.0327176366533553,0.1975462777273995,1.0370871866426914,3554,0.1897597777152504 -7.961894273757935,0.0180349349975585,109.12077069282532,344,0,109.12077069282532,0.3352491202483594,3581,0.689466828508971,117.11503720283508,0.3145843914576939,0.6935643468584333,0.3335539002598304,3554,0.6709493402020611 -11.97603702545166,0.0421395301818847,189.2466561794281,686,0,189.2466561794281,0.3128802553297787,3581,0.7157858151092572,201.29289603233337,0.2929177284240722,0.7199602808271136,0.3106455547996096,3554,0.6985059747819359 -21.19350576400757,0.0658392906188964,269.41495633125305,1026,0,269.41495633125305,0.3078806221202178,3581,0.7185578781677604,290.7165720462799,0.2882750374930246,0.7229129927498954,0.3057712604305888,3554,0.7015023647430711 -25.20110893249512,0.0901281833648681,349.4108908176422,1372,0,349.4108908176422,0.3003164896982337,3581,0.7253081857808573,374.75886607170105,0.280628102166312,0.7301158223833356,0.298456384146824,3554,0.7081162812807752 -29.209868669509888,0.1150052547454834,429.4600641727448,1718,0,429.4600641727448,0.299394673057456,3581,0.7260523340416434,458.8567945957184,0.2795077732631138,0.7308858462742397,0.2976524168410593,3554,0.7087743068989167 -33.27567386627197,0.1403758525848388,509.4550085067749,2062,0,509.4550085067749,0.2966512101084718,3581,0.7291415549296635,542.9577581882477,0.2772624662944248,0.733593395778111,0.2949803341085748,3554,0.7119433264015897 -37.286170959472656,0.1653282642364502,589.4221696853638,2407,0,589.4221696853638,0.2977179362433678,3581,0.7290389490540352,626.9751617908478,0.2773699590138027,0.734459672655378,0.2958892667702764,3554,0.7120967901484243 -41.29489207267761,0.1909987926483154,669.4945640563965,2754,0,669.4945640563965,0.2960094631933294,3581,0.730638373533929,711.0966486930847,0.2762782233101981,0.7358302388872419,0.2941995168845843,3554,0.7136425561427265 -45.30148363113403,0.2160556316375732,749.650426864624,3099,0,749.650426864624,0.2937576562390917,3581,0.733525927857093,795.2993674278259,0.2740841593061174,0.7383593831743512,0.2921760829567037,3554,0.7162388687262592 -49.31132197380066,0.2427656650543213,829.7700300216675,3446,0,829.7700300216675,0.2923777265210486,3581,0.7349211632356535,879.4692540168762,0.2721844060080392,0.7405399594988141,0.2907790064342114,3554,0.7177803069604671 -53.31908226013184,0.2685549259185791,909.8821904659272,3794,0,909.8821904659272,0.2926671705376466,3581,0.7352366848252933,963.6286108493804,0.2730499165398733,0.7399829455784389,0.2910478084222883,3554,0.718139854521314 -57.33399033546448,0.2938432693481445,990.0270144939424,4139,0,990.0270144939424,0.2926931117573653,3581,0.7357221708321697,1047.8274364471436,0.2725841488157,0.740879876273019,0.2911185295177969,3554,0.7185715314170653 -61.34604811668396,0.3189194202423095,1070.1156814098358,4482,0,1070.1156814098358,0.291115708339151,3581,0.7364936579342363,1131.9665739536283,0.2707956518445696,0.7423781667436872,0.2895648292966551,3554,0.7193415979354248 -65.3628830909729,0.3446106910705566,1150.3168470859528,4829,0,1150.3168470859528,0.2924860592362469,3581,0.7358825223401284,1216.2242138385773,0.2728027275630406,0.7408492905752999,0.2910806444433209,3554,0.7187132483865011 -69.37438464164734,0.3716864585876465,1230.4913923740387,5173,0,1230.4913923740387,0.290889634529461,3581,0.7370365486901355,1300.4513957500458,0.2708584240504673,0.7423981257847377,0.2893769839001829,3554,0.719932715008617 -73.38616728782654,0.3977348804473877,1310.5437569618225,5516,0,1310.5437569618225,0.2912650152279391,3581,0.736393915478393,1384.5558178424835,0.2711343594959804,0.7416613442557198,0.2897987687728615,3554,0.7192231684369724 -77.39232563972473,0.4241127967834472,1390.7154486179352,5862,0,1390.7154486179352,0.2916534517571034,3581,0.736499725657114,1468.7736823558807,0.2712622029440744,0.742295469556536,0.290077462784011,3554,0.7193490169527293 -81.4054057598114,0.45113205909729,1470.9021430015564,6208,0,1470.9021430015564,0.2910343394935946,3581,0.737248850814193,1553.0145823955536,0.2708876473563058,0.7428860664367676,0.2894863113635164,3554,0.7201298685240223 -85.41960525512695,0.4771711826324463,1551.0724306106567,6552,0,1551.0724306106567,0.2930313361753002,3581,0.7366234662978218,1637.238983631134,0.2728444167545863,0.7417918613978794,0.2914096572523916,3554,0.7196870631023143 -89.430344581604,0.5038919448852539,1631.3222754001615,6899,0,1631.3222754001615,0.2900458460778239,3581,0.7397011653300405,1721.5402591228485,0.2698579856327602,0.7451910972595215,0.288501505510956,3554,0.7226475944798115 -93.44122433662416,0.5299606323242188,1711.325751543045,7245,0,1711.325751543045,0.2896346044553721,3581,0.7388482752940868,1805.5942947864528,0.2694618701934814,0.7445575850350517,0.2881723553126759,3554,0.7216776953564645 -97.45112013816832,0.5566811561584473,1791.4902153015137,7589,0,1791.4902153015137,0.2897060195083426,3581,0.7392248831724728,1889.8089570999143,0.2695349454879761,0.7447182791573661,0.2881819725573297,3554,0.7220469288565701 -101.46298694610596,0.5842244625091553,1871.5134434700008,7935,0,1871.5134434700008,0.2894398237333321,3581,0.7399333750392697,1973.8852660655973,0.2689814908163888,0.7456581251961845,0.2879128099225784,3554,0.722861578173361 -105.47749543190002,0.6115193367004395,1951.6899888515472,8282,0,1951.6899888515472,0.2890387063451201,3581,0.7398233379075329,2058.11745595932,0.2688144786017282,0.7454793793814523,0.2876635858968328,3554,0.7226394885164603 -109.49516701698305,0.6374204158782959,2031.8050429821008,8625,0,2031.8050429821008,0.2900541977188634,3581,0.7396019682874895,2142.290014743805,0.2696679660252162,0.7453502927507673,0.2886448368036016,3554,0.7224195283923045 -113.5051245689392,0.6649777889251709,2111.8292939662933,8971,0,2111.8292939662933,0.2892369299885681,3581,0.7398632894311994,2226.365335702896,0.2687513147081647,0.7453242710658482,0.2877739266055852,3554,0.7227986539154826 -117.51276755332948,0.6933310031890869,2191.909622192383,9316,0,2191.909622192383,0.2901880625938111,3581,0.7397393442605068,2310.495637178421,0.270092248916626,0.7449782235281808,0.2887455774413513,3554,0.7226653863824212 -121.52645993232728,0.7202789783477783,2271.916765451432,9658,0,2271.916765451432,0.2886718136410046,3581,0.7397383216105836,2394.557331323624,0.2685591323035104,0.7451615333557129,0.2872988003723797,3554,0.7224965350441404 -125.54268622398376,0.7473556995391846,2351.937286376953,10003,0,2351.937286376953,0.2888142346869764,3581,0.7395914690816113,2478.6354858875275,0.2682241712297712,0.7457009043012347,0.2873909713581967,3554,0.7223576345534961 -129.55689525604248,0.7765707969665527,2431.914649963379,10348,0,2431.914649963379,0.289362784105784,3581,0.7397801820807736,2562.67041516304,0.268939733505249,0.7455603054591587,0.2879494069732168,3554,0.7225909901255627 -133.56716871261597,0.8033936023712158,2511.8980689048767,10692,0,2511.8980689048767,0.2888140642453225,3581,0.7395841060021642,2646.7047917842865,0.2683139187949044,0.7455687522888184,0.2873520215173484,3554,0.722452364413337 -137.57607746124268,0.8305361270904541,2592.050028562546,11038,0,2592.050028562546,0.288732559046443,3581,0.7401103616526459,2730.907357931137,0.2678736788885934,0.7464522634233747,0.2872696051653946,3554,0.7229895562218627 -141.58618927001953,0.8576686382293701,2672.0285749435425,11384,0,2672.0285749435425,0.2887165716193102,3581,0.7403277088496579,2814.937951564789,0.2681691816874912,0.7461264474051339,0.2872296249054762,3554,0.7231992121553179 -145.6027750968933,0.8846747875213623,2752.0991699695587,11730,0,2752.0991699695587,0.2882289721359257,3581,0.7403952719212511,2899.066791296005,0.2681073631559099,0.7456574440002441,0.2868558060405089,3554,0.7232962776317178 -149.6131820678711,0.9115986824035645,2832.174956798553,12073,0,2832.174956798553,0.288034907268832,3581,0.7406891815091804,2983.193598508835,0.2673248733792986,0.7469180652073452,0.2866296290546303,3554,0.7235540197884426 -153.62973427772522,0.938593864440918,2912.2765226364136,12418,0,2912.2765226364136,0.2881162420260577,3581,0.7404398594579028,3067.35256266594,0.2677193880081177,0.746204035622733,0.2867117191072119,3554,0.7232028529693655 -157.64133405685425,0.9663434028625488,2992.2794008255005,12764,0,2992.2794008255005,0.2880752337641371,3581,0.7402508055754329,3151.408617973328,0.2676114354814802,0.7461694989885602,0.2867143466758406,3554,0.7230209496561972 -161.65591073036194,0.9946553707122804,3072.431079149246,13108,0,3072.431079149246,0.2917841464172891,3581,0.7357027404836288,3235.6166298389435,0.2709654739924839,0.7421060970851353,0.2904571035167241,3554,0.7184103738745076 -165.66799879074097,1.0246577262878418,3152.4668333530426,13454,0,3152.4668333530426,0.2877401795609641,3581,0.7405587595556409,3319.708043813705,0.2673053400857108,0.746272087097168,0.286368658251486,3554,0.7233196337973059 -169.683251619339,1.0551369190216064,3232.474317073822,13800,0,3232.474317073822,0.2878812370737049,3581,0.7401016350399678,3403.774663925171,0.2676271370479038,0.7456895964486259,0.2865288197222759,3554,0.7229075348638858 -173.69537568092346,1.0848171710968018,3312.635585308075,14143,0,3312.635585308075,0.288486918534889,3581,0.740409725373499,3487.991771221161,0.2676312753132411,0.7468343462262835,0.2871341909259373,3554,0.7231794281091728 -177.7116355895996,1.1130201816558838,3392.603483438492,14489,0,3392.603483438492,0.287666889649801,3581,0.7417685544147934,3572.0182886123657,0.2670495850699289,0.7479549816676548,0.2862707512661789,3554,0.7246055281021384 -181.7254831790924,1.142401933670044,3472.621673107147,14835,0,3472.621673107147,0.2877663253106674,3581,0.7418121874781834,3656.093423604965,0.2671400308609009,0.7479761668613979,0.2863954834946099,3554,0.7246696201682963 -185.7420694828033,1.1750538349151611,3552.6018946170807,15176,0,3552.6018946170807,0.28812868426679,3581,0.7418471621055571,3740.136210441589,0.2673849718911307,0.7476153373718262,0.2866646117820589,3554,0.7247315140071047 -189.75790739059448,1.2031304836273191,3632.597405195236,15521,0,3632.597405195236,0.2877476108170727,3581,0.7398874921460485,3824.189927577973,0.2672596148082188,0.7456714085170201,0.286428267994689,3554,0.7225538263444359 -193.76922011375427,1.2327890396118164,3712.6631059646606,15867,0,3712.6631059646606,0.2873642534491936,3581,0.7413994459691776,3908.31063747406,0.2666402203696115,0.7476047788347516,0.2860132838878728,3554,0.7241376491497257 -197.7851824760437,1.260986328125,3792.858276128769,16210,0,3792.858276128769,0.2875945201235688,3581,0.7413657666983734,3992.563661813736,0.2667592423302786,0.7477357728140694,0.2862234206835607,3554,0.7241149112355796 -201.79962134361267,1.2895026206970217,3873.008520364762,16558,0,3873.008520364762,0.2875985084582693,3581,0.7419231109065205,4076.7703862190247,0.2664368493216378,0.748434339250837,0.2862051650923695,3554,0.724732063563942 -205.81161665916443,1.3195016384124756,3953.220718622208,16903,0,3953.220718622208,0.2871007165639835,3581,0.7421911815397235,4161.03843665123,0.2663480384009225,0.7483385631016323,0.2857050511967149,3554,0.7249650756629854 -209.8240213394165,1.347517728805542,4033.388636350632,17247,0,4033.388636350632,0.2876641966716699,3581,0.7422807656729964,4245.26022362709,0.2669499261038644,0.7485357693263462,0.2862374515565647,3554,0.7251728768421145 -213.8384931087494,1.375917673110962,4113.347852706909,17592,0,4113.347852706909,0.2882725370226542,3581,0.7403153688739179,4329.27641749382,0.2669779573168073,0.7472812107631138,0.2868957691267761,3554,0.723099261505522 -217.8497090339661,1.4057002067565918,4193.354947090149,17936,0,4193.354947090149,0.287055038200747,3581,0.7417701224780089,4413.337802171707,0.2662241969789777,0.7481169700622559,0.2856556254286543,3554,0.7245039287818303 -221.86294984817505,1.4366326332092283,4273.456275224686,18281,0,4273.456275224686,0.2870435163449455,3581,0.7418551387749581,4497.49751830101,0.2663684742791312,0.7480468068804059,0.2857335594576533,3554,0.7245921326542276 -225.87409567832947,4.3728437423706055,4350.548347711563,18609,0,4350.548347711563,0.2870939670744903,3581,0.741691855670553,4581.550496578217,0.2659213542938232,0.7484518459865025,0.2857109761063678,3554,0.724492731561269 -229.88594794273376,4.401664972305298,4430.576882123947,18955,0,4430.576882123947,0.2875681016672193,3581,0.7419325874624756,4665.634213447571,0.2668883630207607,0.74802337374006,0.2862276969227015,3554,0.7247163324994724 -233.89514589309687,4.43175482749939,4510.648891210556,19298,0,4510.648891210556,0.2872580342105033,3581,0.7416037714238342,4749.759995222092,0.266499434198652,0.7479914937700544,0.2859593242759039,3554,0.7243800724096089 -237.90518808364868,4.460954427719116,4590.72519493103,19640,0,4590.72519493103,0.2870987053524679,3581,0.742760524840303,4833.889722824097,0.2658189194543021,0.7494755472455706,0.2856957087304797,3554,0.7257194111168753 -241.92007851600647,4.490314960479736,4670.917221784592,19985,0,4670.917221784592,0.2873735595634948,3581,0.7427205051399749,4918.140657424927,0.2660778931209019,0.7494589260646275,0.2859783526813977,3554,0.7255726107466939 -245.93207383155823,4.521239280700684,4750.951722860336,20329,0,4750.951722860336,0.2870360850888369,3581,0.7423481923912664,5002.232574462891,0.2660074234008789,0.7488336563110352,0.2856191142462718,3554,0.7252096971502181 -249.9386088848114,4.551448345184326,4830.941958904266,20668,0,4830.941958904266,0.2871444859806967,3581,0.7419362690021991,5086.27272605896,0.2657781328473772,0.7490506853376117,0.2857364618047007,3554,0.7247555571187394 -253.94873142242432,4.5808281898498535,4911.088220119476,21014,0,4911.088220119476,0.2872181849518291,3581,0.7425711300745252,5170.472066164017,0.266131009374346,0.7492361749921527,0.2857910224944605,3554,0.7255093430157921 -257.9545512199402,4.610207319259644,4991.229874610901,21362,0,4991.229874610901,0.286876074464186,3581,0.7426757130733385,5254.662563562393,0.2658153431756155,0.7492140361240932,0.2855267028293384,3554,0.7254861929340181 -261.9677035808563,4.641288995742798,5071.210415363312,21703,0,5071.210415363312,0.2871951071518954,3581,0.7431506316976753,5338.700535297394,0.2659606592995779,0.7498057229178292,0.2857190992433701,3554,0.7261010783404263 -265.9774160385132,4.673285722732544,5151.294460535049,22048,0,5151.294460535049,0.287159041697937,3581,0.7429184219884459,5422.839916706085,0.2656627552849905,0.749934468950544,0.2857178970877884,3554,0.7257356230435776 -269.9894750118256,4.706132888793945,5231.44545173645,22391,0,5231.44545173645,0.2871350094247417,3581,0.7431317467624267,5507.0493676662445,0.2657770940235683,0.7499521800449916,0.2857854582314821,3554,0.7259103134232555 -274.00053787231445,4.736352205276489,5311.462661266327,22735,0,5311.462661266327,0.2868370433254503,3581,0.7426772129598925,5591.121385335922,0.2656928300857544,0.749309744153704,0.2854609105717149,3554,0.7254960849570906 -278.0168051719665,4.767531394958496,5391.550899028778,23080,0,5391.550899028778,0.2868782220290247,3581,0.7429138541521223,5675.270745754242,0.2653392212731497,0.750044754573277,0.2854270441316123,3554,0.7257962116848973 -282.02875447273254,4.797436714172363,5471.681828737259,23426,0,5471.681828737259,0.2870010422848017,3581,0.7429287848410011,5759.457330942154,0.2657884699957711,0.749577454158238,0.2855529956892761,3554,0.7258460152732836 -286.0441243648529,4.827166795730591,5551.710270643234,23770,0,5551.710270643234,0.2867749343867809,3581,0.7428168387627408,5843.544717788696,0.2655000175748552,0.7495931216648647,0.2853825815487391,3554,0.7256565555536015 -290.05428886413574,4.858988285064697,5631.85106420517,24116,0,5631.85106420517,0.2867777296299043,3581,0.7428803794113027,5927.740929841995,0.2650260414396013,0.750204154423305,0.2853486979349852,3554,0.7257194111168753 -294.06663846969604,4.8908774852752686,5712.004364490509,24462,0,5712.004364490509,0.2868571895289374,3581,0.7423995294174114,6011.95213842392,0.2653989621571132,0.749311992100307,0.2854617692542733,3554,0.7251623665675999 -298.0805068016052,4.921252250671387,5792.194799900055,24808,0,5792.194799900055,0.2868880735566183,3581,0.7429736450842991,6096.200660705566,0.2654122965676443,0.7499324253627232,0.2854906038145839,3554,0.7258076836838773 -302.09289360046387,4.95358943939209,5872.191735506058,25153,0,5872.191735506058,0.2868404521585276,3581,0.7430454351089081,6180.25568819046,0.265025121825082,0.7504647118704659,0.2854275078201937,3554,0.7259484389288478 -306.1063861846924,4.98514199256897,5952.1544880867,25497,0,5952.1544880867,0.2869330701532393,3581,0.742860471826131,6264.277378082275,0.2653808423451015,0.7500345366341727,0.2855252602426403,3554,0.7257026496333356 -310.115487575531,5.016194820404053,6032.137490034103,25845,0,6032.137490034103,0.2867713892003805,3581,0.742732231525761,6348.31400680542,0.2653955902372087,0.7496070861816406,0.2853795418124824,3554,0.7255828462427898 -314.12628722190857,5.047453880310059,6112.283630371094,26189,0,6112.283630371094,0.2868577349422298,3581,0.743048912118647,6432.515614271164,0.2648829051426479,0.7505412101745605,0.2854377261426385,3554,0.7259142977103263 -318.13490319252014,5.079059839248657,6192.388805150986,26532,0,6192.388805150986,0.2868616891885995,3581,0.7429106498490295,6516.674349784851,0.2651327848434448,0.7501976149422782,0.2854400617591974,3554,0.7257329439539955 -322.150488615036,5.109807968139648,6272.537399768829,26880,0,6272.537399768829,0.2867100642933189,3581,0.7427324360557456,6600.883124828339,0.265214272907802,0.7498160089765277,0.2852989973885059,3554,0.7256105988630768 -326.1638045310974,5.161976337432861,6352.623078107834,27223,0,6352.623078107834,0.2868899825031415,3581,0.7427158691269896,6685.047988176346,0.2650226354598999,0.7501500674656459,0.2854420195554305,3554,0.725652296488112 -330.1715428829193,5.19452714920044,6432.78490281105,27569,0,6432.78490281105,0.2866826913637078,3581,0.7429066956026599,6769.26359128952,0.2649794135774885,0.7501449584960938,0.2852976921910171,3554,0.725748056767023 -334.1822066307068,5.227466106414795,6512.966456890106,27914,0,6512.966456890106,0.2867274493420134,3581,0.7428141798729405,6853.502288341522,0.2651003428867885,0.7500529289245605,0.2853273510865838,3554,0.7256652110737901 -338.1944320201874,5.260152816772461,6593.085450172424,28257,0,6593.085450172424,0.2868046934995462,3581,0.7432074910334054,6937.679908514023,0.2649598121643066,0.7506343296595982,0.2853751453577834,3554,0.7260884385331668 -342.20906805992126,5.2943854331970215,6673.190561294556,28604,0,6673.190561294556,0.2867792636047891,3581,0.743204695790282,7021.847796201706,0.2649067299706595,0.7506190027509417,0.2853277460805606,3554,0.7260718144388365 -346.21965074539185,5.327981233596802,6753.176616668701,28946,0,6753.176616668701,0.2866358880855557,3581,0.7428719255052709,7105.89125585556,0.2649789367403303,0.7500931876046317,0.2852500868299803,3554,0.7257242197392023 -350.23344373703003,5.359550714492798,6833.300783395767,29290,0,6833.300783395767,0.2869654199791434,3581,0.7430520482450782,7190.074701547623,0.2651293788637434,0.7503688676016671,0.2854834767493493,3554,0.7259679481965743 -354.2454254627228,5.391098737716675,6913.516995429993,29635,0,6913.516995429993,0.2868040458212615,3581,0.7431846518517872,7274.34780049324,0.2648529495511736,0.750690392085484,0.2853850373808561,3554,0.7260376732203151 -358.2557344436645,5.422429800033569,6993.540855407715,29980,0,6993.540855407715,0.2867590492246404,3581,0.7432026504904357,7358.42671585083,0.2649602379117693,0.7506021772112165,0.2853407465344946,3554,0.7260679675409749 -362.267263174057,5.454789638519287,7073.571084022522,30322,0,7073.571084022522,0.2867245177455669,3581,0.743074205660081,7442.514234781265,0.2649151768003191,0.750415733882359,0.2853185925244882,3554,0.7259121681775816 -366.28090047836304,5.489109992980957,7153.6601984500885,30666,0,7153.6601984500885,0.2866318315741936,3581,0.7429935526694709,7526.665081501007,0.2646281208310808,0.750591686793736,0.2852345790229759,3554,0.7258340624120709 -370.2935276031494,5.52242112159729,7233.649788141251,31010,0,7233.649788141251,0.2867750025634424,3581,0.7430826595661129,7610.714449882507,0.2648442813328334,0.7505792209080288,0.2853307686431661,3554,0.7259497441263365 -374.3013157844544,5.554333209991455,7313.791854381561,31354,0,7313.791854381561,0.2866630905735129,3581,0.7429784856272689,7694.910228490829,0.2648018257958548,0.7504358291625977,0.2852551015361212,3554,0.7258527473445414 -378.3068976402282,5.58677077293396,7393.845289468765,31697,0,7393.845289468765,0.2866952699577632,3581,0.7431663123298311,7779.015561819076,0.2645924602236066,0.750840391431536,0.2852751947079874,3554,0.726041245339758 -382.3198854923248,5.620407104492188,7473.926406145096,32043,0,7473.926406145096,0.2866786007640149,3581,0.7427708876928582,7863.157017469406,0.2647336891719273,0.7502557209559849,0.2852741471152663,3554,0.7255834644942318 -386.3294751644135,5.653552770614624,7554.078770637512,32388,0,7554.078770637512,0.2867119050631806,3581,0.7431155207169785,7947.365916252136,0.2648208141326904,0.7505275181361607,0.2853099026569974,3554,0.7259816184229038 -390.34309220314026,5.686913251876831,7634.058696508408,32730,0,7634.058696508408,0.2867138140097039,3581,0.7432014914871893,8031.405979156494,0.264573335647583,0.7509356907435826,0.2852920077124806,3554,0.7260761421989308 -394.3575599193573,5.719174385070801,7714.161678314209,33076,0,7714.161678314209,0.28670041729571,3581,0.7435912574612539,8115.569339990616,0.2646849836621965,0.7511607578822544,0.2852806215817565,3554,0.7264695561998804 -398.3710045814514,5.753046751022339,7794.147730350494,33421,0,7794.147730350494,0.286716813782812,3581,0.7432023777837894,8199.616106510162,0.2647251742226736,0.7507865088326591,0.2852847604316879,3554,0.7260885072277715 -402.38356733322144,5.791030645370483,7874.208901643753,33763,0,7874.208901643753,0.2866793507072919,3581,0.7431103392907009,8283.741166591644,0.2645835195268903,0.750774656023298,0.2852442649622344,3554,0.726001814636677 -406.400333404541,5.826385736465454,7954.261415958404,34108,0,7954.261415958404,0.286617480386938,3581,0.7431852654417411,8367.859296798706,0.2645749023982456,0.7508078983851841,0.2851960585234067,3554,0.726052923422552 -410.41228795051575,5.861198663711548,8034.315863609314,34452,0,8034.315863609314,0.2866588977088278,3581,0.7431785841289095,8451.97430229187,0.2646217857088361,0.7507697514125279,0.2852324323165799,3554,0.7260702344629291 -414.4228415489197,5.896820545196533,8114.298360586166,34795,0,8114.298360586166,0.2866116853707065,3581,0.7429303529042167,8536.016276597977,0.2645829575402396,0.7505358287266323,0.2851932420446152,3554,0.725791059589547 -418.4358830451965,5.930956602096558,8194.391567707062,35140,0,8194.391567707062,0.2865960729152122,3581,0.7432427383674253,8620.169978618622,0.2645574637821742,0.7508696147373745,0.2851793829081229,3554,0.7261189389376407 -422.4488196372986,5.9632744789123535,8274.552117347717,35487,0,8274.552117347717,0.2866312520725705,3581,0.7433055290727102,8704.389466524124,0.2645804711750575,0.750948292868478,0.2852134039110861,3554,0.7261771232677968 -426.4578926563263,5.996220350265503,8354.554366588593,35830,0,8354.554366588593,0.2866460464081262,3581,0.7432949616901704,8788.44710612297,0.2645731312888009,0.7509431838989258,0.2852237767963914,3554,0.7261721085616559 -430.4693253040314,6.029922008514404,8434.533107995987,36175,0,8434.533107995987,0.2866142079071837,3581,0.7432375569411477,8872.484619140625,0.2645613465990339,0.7508655956813267,0.285195783744988,3554,0.7261112451419176 -434.4867742061615,6.064191818237305,8435.488661050797,36189,0,8435.488661050797,0.2866142079071837,3581,0.7432375569411477,8877.492777347565,0.2645613465990339,0.7508655956813267,0.28519578374498805,3554,0.7261113138365223 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/measurements.csv deleted file mode 100644 index 6b8af9185..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.6054354,1.0139427,,,,,,,,,,,,,, -1,,,0.1975462777273995,1.0327176366533553,0.1897597777152504,1.0370871866426914,3554.0,0.2128455943150481,1.0323986403937448,3581.0,28.97892165184021,32.935672760009766,28.97892165184021,3.95663857460022,0.0,0.0 -100,0.46014294,0.44008237,,,,,,,,,,,,,, -200,0.17156607,0.35069218,,,,,,,,,,,,,, -300,0.22710404,0.32800668,,,,,,,,,,,,,, -344,,,0.6935643468584333,0.3145843914576939,0.6709493402020611,0.3335539002598304,3554.0,0.689466828508971,0.3352491202483594,3581.0,109.12077069282532,117.11503720283508,109.12077069282532,7.961894273757935,0.0180349349975585,0.0 -400,0.19933636,0.35333467,,,,,,,,,,,,,, -500,0.36611632,0.23368965,,,,,,,,,,,,,, -600,0.18406083,0.35911918,,,,,,,,,,,,,, -686,,,0.7199602808271136,0.2929177284240722,0.6985059747819359,0.3106455547996096,3554.0,0.7157858151092572,0.3128802553297787,3581.0,189.2466561794281,201.29289603233337,189.2466561794281,11.97603702545166,0.0421395301818847,0.0 -700,0.7224801,0.32139847,,,,,,,,,,,,,, -800,0.18461631,0.24806102,,,,,,,,,,,,,, -900,0.28746685,0.31711465,,,,,,,,,,,,,, -1000,0.3476334,0.26068014,,,,,,,,,,,,,, -1026,,,0.7229129927498954,0.2882750374930246,0.7015023647430711,0.3057712604305888,3554.0,0.7185578781677604,0.3078806221202178,3581.0,269.41495633125305,290.7165720462799,269.41495633125305,21.19350576400757,0.0658392906188964,0.0 -1100,0.3442797,0.266749,,,,,,,,,,,,,, -1200,0.395839,0.255566,,,,,,,,,,,,,, -1300,0.27808088,0.3658355,,,,,,,,,,,,,, -1372,,,0.7301158223833356,0.280628102166312,0.7081162812807752,0.298456384146824,3554.0,0.7253081857808573,0.3003164896982337,3581.0,349.4108908176422,374.75886607170105,349.4108908176422,25.20110893249512,0.0901281833648681,0.0 -1400,0.68760854,0.27748096,,,,,,,,,,,,,, -1500,0.34403676,0.33915222,,,,,,,,,,,,,, -1600,0.44919017,0.19533512,,,,,,,,,,,,,, -1700,0.30041483,0.33824524,,,,,,,,,,,,,, -1718,,,0.7308858462742397,0.2795077732631138,0.7087743068989167,0.2976524168410593,3554.0,0.7260523340416434,0.299394673057456,3581.0,429.4600641727448,458.8567945957184,429.4600641727448,29.209868669509888,0.1150052547454834,0.0 -1800,0.11609503,0.278703,,,,,,,,,,,,,, -1900,0.11796071,0.21721575,,,,,,,,,,,,,, -2000,0.17952067,0.3402842,,,,,,,,,,,,,, -2062,,,0.733593395778111,0.2772624662944248,0.7119433264015897,0.2949803341085748,3554.0,0.7291415549296635,0.2966512101084718,3581.0,509.4550085067749,542.9577581882477,509.4550085067749,33.27567386627197,0.1403758525848388,0.0 -2100,0.21787606,0.29328954,,,,,,,,,,,,,, -2200,0.20142296,0.2815698,,,,,,,,,,,,,, -2300,0.069177955,0.2233255,,,,,,,,,,,,,, -2400,0.08221592,0.28752202,,,,,,,,,,,,,, -2407,,,0.734459672655378,0.2773699590138027,0.7120967901484243,0.2958892667702764,3554.0,0.7290389490540352,0.2977179362433678,3581.0,589.4221696853638,626.9751617908478,589.4221696853638,37.286170959472656,0.1653282642364502,0.0 -2500,0.21364903,0.2587566,,,,,,,,,,,,,, -2600,0.12284463,0.25367984,,,,,,,,,,,,,, -2700,0.17458507,0.2958204,,,,,,,,,,,,,, -2754,,,0.7358302388872419,0.2762782233101981,0.7136425561427265,0.2941995168845843,3554.0,0.730638373533929,0.2960094631933294,3581.0,669.4945640563965,711.0966486930847,669.4945640563965,41.29489207267761,0.1909987926483154,0.0 -2800,0.1800853,0.3528636,,,,,,,,,,,,,, -2900,0.41024315,0.2302421,,,,,,,,,,,,,, -3000,0.09617321,0.26109502,,,,,,,,,,,,,, -3099,,,0.7383593831743512,0.2740841593061174,0.7162388687262592,0.2921760829567037,3554.0,0.733525927857093,0.2937576562390917,3581.0,749.650426864624,795.2993674278259,749.650426864624,45.30148363113403,0.2160556316375732,0.0 -3100,0.2205155,0.24024099,,,,,,,,,,,,,, -3200,0.109892786,0.27468383,,,,,,,,,,,,,, -3300,0.22322683,0.30835438,,,,,,,,,,,,,, -3400,0.06980253,0.3774644,,,,,,,,,,,,,, -3446,,,0.7405399594988141,0.2721844060080392,0.7177803069604671,0.2907790064342114,3554.0,0.7349211632356535,0.2923777265210486,3581.0,829.7700300216675,879.4692540168762,829.7700300216675,49.31132197380066,0.2427656650543213,0.0 -3500,0.18342339,0.32513162,,,,,,,,,,,,,, -3600,0.13692334,0.2886062,,,,,,,,,,,,,, -3700,0.13257875,0.2272046,,,,,,,,,,,,,, -3794,,,0.7399829455784389,0.2730499165398733,0.718139854521314,0.2910478084222883,3554.0,0.7352366848252933,0.2926671705376466,3581.0,909.8821904659272,963.6286108493804,909.8821904659272,53.31908226013184,0.2685549259185791,0.0 -3800,0.108519204,0.25614938,,,,,,,,,,,,,, -3900,0.17545946,0.2804445,,,,,,,,,,,,,, -4000,0.24233429,0.22441685,,,,,,,,,,,,,, -4100,0.15373844,0.29351375,,,,,,,,,,,,,, -4139,,,0.740879876273019,0.2725841488157,0.7185715314170653,0.2911185295177969,3554.0,0.7357221708321697,0.2926931117573653,3581.0,990.0270144939424,1047.8274364471436,990.0270144939424,57.33399033546448,0.2938432693481445,0.0 -4200,0.27899495,0.23751697,,,,,,,,,,,,,, -4300,0.18552618,0.3198182,,,,,,,,,,,,,, -4400,0.07305534,0.25737107,,,,,,,,,,,,,, -4482,,,0.7423781667436872,0.2707956518445696,0.7193415979354248,0.2895648292966551,3554.0,0.7364936579342363,0.291115708339151,3581.0,1070.1156814098358,1131.9665739536283,1070.1156814098358,61.34604811668396,0.3189194202423095,0.0 -4500,0.15751861,0.27111596,,,,,,,,,,,,,, -4600,0.14451692,0.23472889,,,,,,,,,,,,,, -4700,0.10266353,0.25455412,,,,,,,,,,,,,, -4800,0.46930295,0.2717571,,,,,,,,,,,,,, -4829,,,0.7408492905752999,0.2728027275630406,0.7187132483865011,0.2910806444433209,3554.0,0.7358825223401284,0.2924860592362469,3581.0,1150.3168470859528,1216.2242138385773,1150.3168470859528,65.3628830909729,0.3446106910705566,0.0 -4900,0.124408446,0.36937314,,,,,,,,,,,,,, -5000,0.071282744,0.2968708,,,,,,,,,,,,,, -5100,0.11529784,0.24066842,,,,,,,,,,,,,, -5173,,,0.7423981257847377,0.2708584240504673,0.719932715008617,0.2893769839001829,3554.0,0.7370365486901355,0.290889634529461,3581.0,1230.4913923740387,1300.4513957500458,1230.4913923740387,69.37438464164734,0.3716864585876465,0.0 -5200,0.12751108,0.31272444,,,,,,,,,,,,,, -5300,0.1317952,0.19875747,,,,,,,,,,,,,, -5400,0.2449198,0.23892084,,,,,,,,,,,,,, -5500,0.18162549,0.2716129,,,,,,,,,,,,,, -5516,,,0.7416613442557198,0.2711343594959804,0.7192231684369724,0.2897987687728615,3554.0,0.736393915478393,0.2912650152279391,3581.0,1310.5437569618225,1384.5558178424835,1310.5437569618225,73.38616728782654,0.3977348804473877,0.0 -5600,0.21089117,0.224529,,,,,,,,,,,,,, -5700,0.15577714,0.29133177,,,,,,,,,,,,,, -5800,0.21688992,0.44977176,,,,,,,,,,,,,, -5862,,,0.742295469556536,0.2712622029440744,0.7193490169527293,0.290077462784011,3554.0,0.736499725657114,0.2916534517571034,3581.0,1390.7154486179352,1468.7736823558807,1390.7154486179352,77.39232563972473,0.4241127967834472,0.0 -5900,0.043249257,0.2806212,,,,,,,,,,,,,, -6000,0.1416535,0.23608632,,,,,,,,,,,,,, -6100,0.18822221,0.33605537,,,,,,,,,,,,,, -6200,0.11453406,0.23252724,,,,,,,,,,,,,, -6208,,,0.7428860664367676,0.2708876473563058,0.7201298685240223,0.2894863113635164,3554.0,0.737248850814193,0.2910343394935946,3581.0,1470.9021430015564,1553.0145823955536,1470.9021430015564,81.4054057598114,0.45113205909729,0.0 -6300,0.1742122,0.28286213,,,,,,,,,,,,,, -6400,0.15070842,0.28736225,,,,,,,,,,,,,, -6500,0.081162445,0.24953656,,,,,,,,,,,,,, -6552,,,0.7417918613978794,0.2728444167545863,0.7196870631023143,0.2914096572523916,3554.0,0.7366234662978218,0.2930313361753002,3581.0,1551.0724306106567,1637.238983631134,1551.0724306106567,85.41960525512695,0.4771711826324463,0.0 -6600,0.056554187,0.3107561,,,,,,,,,,,,,, -6700,0.11578721,0.2604481,,,,,,,,,,,,,, -6800,0.14505707,0.2774807,,,,,,,,,,,,,, -6899,,,0.7451910972595215,0.2698579856327602,0.7226475944798115,0.288501505510956,3554.0,0.7397011653300405,0.2900458460778239,3581.0,1631.3222754001615,1721.5402591228485,1631.3222754001615,89.430344581604,0.5038919448852539,0.0 -6900,0.054991107,0.3135753,,,,,,,,,,,,,, -7000,0.1134655,0.22688186,,,,,,,,,,,,,, -7100,0.082183704,0.24123569,,,,,,,,,,,,,, -7200,0.11701771,0.23626035,,,,,,,,,,,,,, -7245,,,0.7445575850350517,0.2694618701934814,0.7216776953564645,0.2881723553126759,3554.0,0.7388482752940868,0.2896346044553721,3581.0,1711.325751543045,1805.5942947864528,1711.325751543045,93.44122433662416,0.5299606323242188,0.0 -7300,0.10653821,0.24694645,,,,,,,,,,,,,, -7400,0.09472008,0.2575663,,,,,,,,,,,,,, -7500,0.14672491,0.19063582,,,,,,,,,,,,,, -7589,,,0.7447182791573661,0.2695349454879761,0.7220469288565701,0.2881819725573297,3554.0,0.7392248831724728,0.2897060195083426,3581.0,1791.4902153015137,1889.8089570999143,1791.4902153015137,97.45112013816832,0.5566811561584473,0.0 -7600,0.25741863,0.25393775,,,,,,,,,,,,,, -7700,0.21740617,0.24355263,,,,,,,,,,,,,, -7800,0.04738099,0.350919,,,,,,,,,,,,,, -7900,0.06685608,0.2901274,,,,,,,,,,,,,, -7935,,,0.7456581251961845,0.2689814908163888,0.722861578173361,0.2879128099225784,3554.0,0.7399333750392697,0.2894398237333321,3581.0,1871.5134434700008,1973.8852660655973,1871.5134434700008,101.46298694610596,0.5842244625091553,0.0 -8000,0.22155318,0.2708894,,,,,,,,,,,,,, -8100,0.08521787,0.25833192,,,,,,,,,,,,,, -8200,0.037213378,0.32230633,,,,,,,,,,,,,, -8282,,,0.7454793793814523,0.2688144786017282,0.7226394885164603,0.2876635858968328,3554.0,0.7398233379075329,0.2890387063451201,3581.0,1951.6899888515472,2058.11745595932,1951.6899888515472,105.47749543190002,0.6115193367004395,0.0 -8300,0.097957045,0.29549447,,,,,,,,,,,,,, -8400,0.19417867,0.23366103,,,,,,,,,,,,,, -8500,0.04156631,0.2761132,,,,,,,,,,,,,, -8600,0.13735707,0.19570473,,,,,,,,,,,,,, -8625,,,0.7453502927507673,0.2696679660252162,0.7224195283923045,0.2886448368036016,3554.0,0.7396019682874895,0.2900541977188634,3581.0,2031.8050429821008,2142.290014743805,2031.8050429821008,109.49516701698305,0.6374204158782959,0.0 -8700,0.11896521,0.26741275,,,,,,,,,,,,,, -8800,0.06175625,0.25978598,,,,,,,,,,,,,, -8900,0.13679132,0.29243454,,,,,,,,,,,,,, -8971,,,0.7453242710658482,0.2687513147081647,0.7227986539154826,0.2877739266055852,3554.0,0.7398632894311994,0.2892369299885681,3581.0,2111.8292939662933,2226.365335702896,2111.8292939662933,113.5051245689392,0.6649777889251709,0.0 -9000,0.100745626,0.33183894,,,,,,,,,,,,,, -9100,0.13594219,0.28395075,,,,,,,,,,,,,, -9200,0.47910872,0.24226445,,,,,,,,,,,,,, -9300,0.033404943,0.4276595,,,,,,,,,,,,,, -9316,,,0.7449782235281808,0.270092248916626,0.7226653863824212,0.2887455774413513,3554.0,0.7397393442605068,0.2901880625938111,3581.0,2191.909622192383,2310.495637178421,2191.909622192383,117.51276755332948,0.6933310031890869,0.0 -9400,0.12752156,0.23524323,,,,,,,,,,,,,, -9500,0.047337856,0.26933724,,,,,,,,,,,,,, -9600,0.14274022,0.32421097,,,,,,,,,,,,,, -9658,,,0.7451615333557129,0.2685591323035104,0.7224965350441404,0.2872988003723797,3554.0,0.7397383216105836,0.2886718136410046,3581.0,2271.916765451432,2394.557331323624,2271.916765451432,121.52645993232728,0.7202789783477783,0.0 -9700,0.1084427,0.32393736,,,,,,,,,,,,,, -9800,0.12744807,0.23101902,,,,,,,,,,,,,, -9900,0.08356779,0.24962273,,,,,,,,,,,,,, -10000,0.080816105,0.30968407,,,,,,,,,,,,,, -10003,,,0.7457009043012347,0.2682241712297712,0.7223576345534961,0.2873909713581967,3554.0,0.7395914690816113,0.2888142346869764,3581.0,2351.937286376953,2478.6354858875275,2351.937286376953,125.54268622398376,0.7473556995391846,0.0 -10100,0.093516625,0.2535346,,,,,,,,,,,,,, -10200,0.059487842,0.36637664,,,,,,,,,,,,,, -10300,0.09706198,0.24246445,,,,,,,,,,,,,, -10348,,,0.7455603054591587,0.268939733505249,0.7225909901255627,0.2879494069732168,3554.0,0.7397801820807736,0.289362784105784,3581.0,2431.914649963379,2562.67041516304,2431.914649963379,129.55689525604248,0.7765707969665527,0.0 -10400,0.18505481,0.29350454,,,,,,,,,,,,,, -10500,0.10882665,0.33424562,,,,,,,,,,,,,, -10600,0.05089213,0.29405862,,,,,,,,,,,,,, -10692,,,0.7455687522888184,0.2683139187949044,0.722452364413337,0.2873520215173484,3554.0,0.7395841060021642,0.2888140642453225,3581.0,2511.8980689048767,2646.7047917842865,2511.8980689048767,133.56716871261597,0.8033936023712158,0.0 -10700,0.1270182,0.22932284,,,,,,,,,,,,,, -10800,0.08989605,0.32469943,,,,,,,,,,,,,, -10900,0.08566543,0.26010948,,,,,,,,,,,,,, -11000,0.04503743,0.26593435,,,,,,,,,,,,,, -11038,,,0.7464522634233747,0.2678736788885934,0.7229895562218627,0.2872696051653946,3554.0,0.7401103616526459,0.288732559046443,3581.0,2592.050028562546,2730.907357931137,2592.050028562546,137.57607746124268,0.8305361270904541,0.0 -11100,0.044031613,0.28624958,,,,,,,,,,,,,, -11200,0.08881977,0.2759828,,,,,,,,,,,,,, -11300,0.11251663,0.33326617,,,,,,,,,,,,,, -11384,,,0.7461264474051339,0.2681691816874912,0.7231992121553179,0.2872296249054762,3554.0,0.7403277088496579,0.2887165716193102,3581.0,2672.0285749435425,2814.937951564789,2672.0285749435425,141.58618927001953,0.8576686382293701,0.0 -11400,0.069720946,0.30487078,,,,,,,,,,,,,, -11500,0.047782194,0.26350546,,,,,,,,,,,,,, -11600,0.12262273,0.23557493,,,,,,,,,,,,,, -11700,0.058943905,0.29625484,,,,,,,,,,,,,, -11730,,,0.7456574440002441,0.2681073631559099,0.7232962776317178,0.2868558060405089,3554.0,0.7403952719212511,0.2882289721359257,3581.0,2752.0991699695587,2899.066791296005,2752.0991699695587,145.6027750968933,0.8846747875213623,0.0 -11800,0.107402764,0.31397286,,,,,,,,,,,,,, -11900,0.16497011,0.2750877,,,,,,,,,,,,,, -12000,0.06731988,0.2542731,,,,,,,,,,,,,, -12073,,,0.7469180652073452,0.2673248733792986,0.7235540197884426,0.2866296290546303,3554.0,0.7406891815091804,0.288034907268832,3581.0,2832.174956798553,2983.193598508835,2832.174956798553,149.6131820678711,0.9115986824035645,0.0 -12100,0.14753933,0.29326475,,,,,,,,,,,,,, -12200,0.057572093,0.25830775,,,,,,,,,,,,,, -12300,0.13774471,0.29070124,,,,,,,,,,,,,, -12400,0.12096552,0.25823417,,,,,,,,,,,,,, -12418,,,0.746204035622733,0.2677193880081177,0.7232028529693655,0.2867117191072119,3554.0,0.7404398594579028,0.2881162420260577,3581.0,2912.2765226364136,3067.35256266594,2912.2765226364136,153.62973427772522,0.938593864440918,0.0 -12500,0.09351084,0.2989472,,,,,,,,,,,,,, -12600,0.09660736,0.30144334,,,,,,,,,,,,,, -12700,0.17530856,0.27215606,,,,,,,,,,,,,, -12764,,,0.7461694989885602,0.2676114354814802,0.7230209496561972,0.2867143466758406,3554.0,0.7402508055754329,0.2880752337641371,3581.0,2992.2794008255005,3151.408617973328,2992.2794008255005,157.64133405685425,0.9663434028625488,0.0 -12800,0.17889623,0.25453144,,,,,,,,,,,,,, -12900,0.114855714,0.20899612,,,,,,,,,,,,,, -13000,0.084953316,0.30120334,,,,,,,,,,,,,, -13100,0.08455686,0.2546575,,,,,,,,,,,,,, -13108,,,0.7421060970851353,0.2709654739924839,0.7184103738745076,0.2904571035167241,3554.0,0.7357027404836288,0.2917841464172891,3581.0,3072.431079149246,3235.6166298389435,3072.431079149246,161.65591073036194,0.9946553707122804,0.0 -13200,0.13418624,0.22968258,,,,,,,,,,,,,, -13300,0.08240846,0.34815547,,,,,,,,,,,,,, -13400,0.06934228,0.4057738,,,,,,,,,,,,,, -13454,,,0.746272087097168,0.2673053400857108,0.7233196337973059,0.286368658251486,3554.0,0.7405587595556409,0.2877401795609641,3581.0,3152.4668333530426,3319.708043813705,3152.4668333530426,165.66799879074097,1.0246577262878418,0.0 -13500,0.09723597,0.32815954,,,,,,,,,,,,,, -13600,0.1386146,0.19731516,,,,,,,,,,,,,, -13700,0.075598724,0.3164143,,,,,,,,,,,,,, -13800,,,0.7456895964486259,0.2676271370479038,0.7229075348638858,0.2865288197222759,3554.0,0.7401016350399678,0.2878812370737049,3581.0,3232.474317073822,3403.774663925171,3232.474317073822,169.683251619339,1.0551369190216064,0.0 -13800,0.12370031,0.26394925,,,,,,,,,,,,,, -13900,0.067662194,0.2338199,,,,,,,,,,,,,, -14000,0.055978313,0.21650182,,,,,,,,,,,,,, -14100,0.101470165,0.24379902,,,,,,,,,,,,,, -14143,,,0.7468343462262835,0.2676312753132411,0.7231794281091728,0.2871341909259373,3554.0,0.740409725373499,0.288486918534889,3581.0,3312.635585308075,3487.991771221161,3312.635585308075,173.69537568092346,1.0848171710968018,0.0 -14200,0.0748333,0.3551912,,,,,,,,,,,,,, -14300,0.13042276,0.33052874,,,,,,,,,,,,,, -14400,0.07499179,0.28722146,,,,,,,,,,,,,, -14489,,,0.7479549816676548,0.2670495850699289,0.7246055281021384,0.2862707512661789,3554.0,0.7417685544147934,0.287666889649801,3581.0,3392.603483438492,3572.0182886123657,3392.603483438492,177.7116355895996,1.1130201816558838,0.0 -14500,0.039989565,0.34051314,,,,,,,,,,,,,, -14600,0.08925442,0.21134424,,,,,,,,,,,,,, -14700,0.05985427,0.26022625,,,,,,,,,,,,,, -14800,0.040908262,0.34014103,,,,,,,,,,,,,, -14835,,,0.7479761668613979,0.2671400308609009,0.7246696201682963,0.2863954834946099,3554.0,0.7418121874781834,0.2877663253106674,3581.0,3472.621673107147,3656.093423604965,3472.621673107147,181.7254831790924,1.142401933670044,0.0 -14900,0.07779694,0.3520344,,,,,,,,,,,,,, -15000,0.07009795,0.21835986,,,,,,,,,,,,,, -15100,0.08531571,0.19659589,,,,,,,,,,,,,, -15176,,,0.7476153373718262,0.2673849718911307,0.7247315140071047,0.2866646117820589,3554.0,0.7418471621055571,0.28812868426679,3581.0,3552.6018946170807,3740.136210441589,3552.6018946170807,185.7420694828033,1.1750538349151611,0.0 -15200,0.09990789,0.25700134,,,,,,,,,,,,,, -15300,0.073916696,0.24144293,,,,,,,,,,,,,, -15400,0.06508741,0.27408427,,,,,,,,,,,,,, -15500,0.07156115,0.25554824,,,,,,,,,,,,,, -15521,,,0.7456714085170201,0.2672596148082188,0.7225538263444359,0.286428267994689,3554.0,0.7398874921460485,0.2877476108170727,3581.0,3632.597405195236,3824.189927577973,3632.597405195236,189.75790739059448,1.2031304836273191,0.0 -15600,0.09340302,0.27547887,,,,,,,,,,,,,, -15700,0.14574723,0.20728937,,,,,,,,,,,,,, -15800,0.04595083,0.24917538,,,,,,,,,,,,,, -15867,,,0.7476047788347516,0.2666402203696115,0.7241376491497257,0.2860132838878728,3554.0,0.7413994459691776,0.2873642534491936,3581.0,3712.6631059646606,3908.31063747406,3712.6631059646606,193.76922011375427,1.2327890396118164,0.0 -15900,0.08458953,0.20501715,,,,,,,,,,,,,, -16000,0.21377723,0.23487669,,,,,,,,,,,,,, -16100,0.060022615,0.36407855,,,,,,,,,,,,,, -16200,0.10401424,0.22592686,,,,,,,,,,,,,, -16210,,,0.7477357728140694,0.2667592423302786,0.7241149112355796,0.2862234206835607,3554.0,0.7413657666983734,0.2875945201235688,3581.0,3792.858276128769,3992.563661813736,3792.858276128769,197.7851824760437,1.260986328125,0.0 -16300,0.15627438,0.24365187,,,,,,,,,,,,,, -16400,0.10073914,0.23089108,,,,,,,,,,,,,, -16500,0.06916296,0.27440536,,,,,,,,,,,,,, -16558,,,0.748434339250837,0.2664368493216378,0.724732063563942,0.2862051650923695,3554.0,0.7419231109065205,0.2875985084582693,3581.0,3873.008520364762,4076.7703862190247,3873.008520364762,201.79962134361267,1.2895026206970217,0.0 -16600,0.12363648,0.23328412,,,,,,,,,,,,,, -16700,0.1584701,0.32035512,,,,,,,,,,,,,, -16800,0.21335018,0.29693595,,,,,,,,,,,,,, -16900,0.055942662,0.33017477,,,,,,,,,,,,,, -16903,,,0.7483385631016323,0.2663480384009225,0.7249650756629854,0.2857050511967149,3554.0,0.7421911815397235,0.2871007165639835,3581.0,3953.220718622208,4161.03843665123,3953.220718622208,205.81161665916443,1.3195016384124756,0.0 -17000,0.12779766,0.22168764,,,,,,,,,,,,,, -17100,0.05498756,0.3098395,,,,,,,,,,,,,, -17200,0.052595425,0.25633276,,,,,,,,,,,,,, -17247,,,0.7485357693263462,0.2669499261038644,0.7251728768421145,0.2862374515565647,3554.0,0.7422807656729964,0.2876641966716699,3581.0,4033.388636350632,4245.26022362709,4033.388636350632,209.8240213394165,1.347517728805542,0.0 -17300,0.062552944,0.25323647,,,,,,,,,,,,,, -17400,0.07851403,0.23237942,,,,,,,,,,,,,, -17500,0.09070766,0.25037262,,,,,,,,,,,,,, -17592,,,0.7472812107631138,0.2669779573168073,0.723099261505522,0.2868957691267761,3554.0,0.7403153688739179,0.2882725370226542,3581.0,4113.347852706909,4329.27641749382,4113.347852706909,213.8384931087494,1.375917673110962,0.0 -17600,0.06789575,0.30132744,,,,,,,,,,,,,, -17700,0.09746674,0.26094493,,,,,,,,,,,,,, -17800,0.13679667,0.28623748,,,,,,,,,,,,,, -17900,0.07093092,0.20620728,,,,,,,,,,,,,, -17936,,,0.7481169700622559,0.2662241969789777,0.7245039287818303,0.2856556254286543,3554.0,0.7417701224780089,0.287055038200747,3581.0,4193.354947090149,4413.337802171707,4193.354947090149,217.8497090339661,1.4057002067565918,0.0 -18000,0.08247538,0.31354678,,,,,,,,,,,,,, -18100,0.07561741,0.26765049,,,,,,,,,,,,,, -18200,0.06799011,0.21282817,,,,,,,,,,,,,, -18281,,,0.7480468068804059,0.2663684742791312,0.7245921326542276,0.2857335594576533,3554.0,0.7418551387749581,0.2870435163449455,3581.0,4273.456275224686,4497.49751830101,4273.456275224686,221.86294984817505,1.4366326332092283,0.0 -18300,0.076979935,0.25524303,,,,,,,,,,,,,, -18400,0.15354382,0.22064765,,,,,,,,,,,,,, -18500,0.066255756,0.3360023,,,,,,,,,,,,,, -18600,0.09066053,0.25749624,,,,,,,,,,,,,, -18609,,,0.7484518459865025,0.2659213542938232,0.724492731561269,0.2857109761063678,3554.0,0.741691855670553,0.2870939670744903,3581.0,4350.548347711563,4581.550496578217,4350.548347711563,225.87409567832947,4.3728437423706055,0.0 -18700,0.080460384,0.28863534,,,,,,,,,,,,,, -18800,0.088369764,0.26186922,,,,,,,,,,,,,, -18900,0.059722204,0.24334326,,,,,,,,,,,,,, -18955,,,0.74802337374006,0.2668883630207607,0.7247163324994724,0.2862276969227015,3554.0,0.7419325874624756,0.2875681016672193,3581.0,4430.576882123947,4665.634213447571,4430.576882123947,229.88594794273376,4.401664972305298,0.0 -19000,0.033495456,0.27970418,,,,,,,,,,,,,, -19100,0.06990355,0.27330673,,,,,,,,,,,,,, -19200,0.112757795,0.2315897,,,,,,,,,,,,,, -19298,,,0.7479914937700544,0.266499434198652,0.7243800724096089,0.2859593242759039,3554.0,0.7416037714238342,0.2872580342105033,3581.0,4510.648891210556,4749.759995222092,4510.648891210556,233.89514589309687,4.43175482749939,0.0 -19300,0.13428488,0.2564213,,,,,,,,,,,,,, -19400,0.14784938,0.24409468,,,,,,,,,,,,,, -19500,0.26131648,0.22488175,,,,,,,,,,,,,, -19600,0.06037121,0.34287298,,,,,,,,,,,,,, -19640,,,0.7494755472455706,0.2658189194543021,0.7257194111168753,0.2856957087304797,3554.0,0.742760524840303,0.2870987053524679,3581.0,4590.72519493103,4833.889722824097,4590.72519493103,237.90518808364868,4.460954427719116,0.0 -19700,0.14568543,0.20181224,,,,,,,,,,,,,, -19800,0.08149253,0.21865582,,,,,,,,,,,,,, -19900,0.059637286,0.31406644,,,,,,,,,,,,,, -19985,,,0.7494589260646275,0.2660778931209019,0.7255726107466939,0.2859783526813977,3554.0,0.7427205051399749,0.2873735595634948,3581.0,4670.917221784592,4918.140657424927,4670.917221784592,241.92007851600647,4.490314960479736,0.0 -20000,0.13853548,0.24102524,,,,,,,,,,,,,, -20100,0.040902857,0.38406754,,,,,,,,,,,,,, -20200,0.051475774,0.35781577,,,,,,,,,,,,,, -20300,0.05305914,0.26848578,,,,,,,,,,,,,, -20329,,,0.7488336563110352,0.2660074234008789,0.7252096971502181,0.2856191142462718,3554.0,0.7423481923912664,0.2870360850888369,3581.0,4750.951722860336,5002.232574462891,4750.951722860336,245.93207383155823,4.521239280700684,0.0 -20400,0.080372974,0.23440984,,,,,,,,,,,,,, -20500,0.1548762,0.23255952,,,,,,,,,,,,,, -20600,0.07271692,0.364141,,,,,,,,,,,,,, -20668,,,0.7490506853376117,0.2657781328473772,0.7247555571187394,0.2857364618047007,3554.0,0.7419362690021991,0.2871444859806967,3581.0,4830.941958904266,5086.27272605896,4830.941958904266,249.9386088848114,4.551448345184326,0.0 -20700,0.18514144,0.26004374,,,,,,,,,,,,,, -20800,0.02836722,0.23595153,,,,,,,,,,,,,, -20900,0.07113936,0.30526245,,,,,,,,,,,,,, -21000,0.064729705,0.31074882,,,,,,,,,,,,,, -21014,,,0.7492361749921527,0.266131009374346,0.7255093430157921,0.2857910224944605,3554.0,0.7425711300745252,0.2872181849518291,3581.0,4911.088220119476,5170.472066164017,4911.088220119476,253.94873142242432,4.5808281898498535,0.0 -21100,0.028768359,0.26955682,,,,,,,,,,,,,, -21200,0.045229185,0.22004344,,,,,,,,,,,,,, -21300,0.05301781,0.27257216,,,,,,,,,,,,,, -21362,,,0.7492140361240932,0.2658153431756155,0.7254861929340181,0.2855267028293384,3554.0,0.7426757130733385,0.286876074464186,3581.0,4991.229874610901,5254.662563562393,4991.229874610901,257.9545512199402,4.610207319259644,0.0 -21400,0.038423102,0.2572821,,,,,,,,,,,,,, -21500,0.07004351,0.21430665,,,,,,,,,,,,,, -21600,0.06935016,0.22302486,,,,,,,,,,,,,, -21700,0.10041389,0.3010721,,,,,,,,,,,,,, -21703,,,0.7498057229178292,0.2659606592995779,0.7261010783404263,0.2857190992433701,3554.0,0.7431506316976753,0.2871951071518954,3581.0,5071.210415363312,5338.700535297394,5071.210415363312,261.9677035808563,4.641288995742798,0.0 -21800,0.12746862,0.24641888,,,,,,,,,,,,,, -21900,0.059539985,0.3472636,,,,,,,,,,,,,, -22000,0.0960343,0.2847372,,,,,,,,,,,,,, -22048,,,0.749934468950544,0.2656627552849905,0.7257356230435776,0.2857178970877884,3554.0,0.7429184219884459,0.287159041697937,3581.0,5151.294460535049,5422.839916706085,5151.294460535049,265.9774160385132,4.673285722732544,0.0 -22100,0.054434173,0.2694831,,,,,,,,,,,,,, -22200,0.049009472,0.28095767,,,,,,,,,,,,,, -22300,0.06899909,0.3273716,,,,,,,,,,,,,, -22391,,,0.7499521800449916,0.2657770940235683,0.7259103134232555,0.2857854582314821,3554.0,0.7431317467624267,0.2871350094247417,3581.0,5231.44545173645,5507.0493676662445,5231.44545173645,269.9894750118256,4.706132888793945,0.0 -22400,0.041777786,0.25252438,,,,,,,,,,,,,, -22500,0.06573207,0.30819565,,,,,,,,,,,,,, -22600,0.053089835,0.22310866,,,,,,,,,,,,,, -22700,0.1271063,0.21164152,,,,,,,,,,,,,, -22735,,,0.749309744153704,0.2656928300857544,0.7254960849570906,0.2854609105717149,3554.0,0.7426772129598925,0.2868370433254503,3581.0,5311.462661266327,5591.121385335922,5311.462661266327,274.00053787231445,4.736352205276489,0.0 -22800,0.0666697,0.23184639,,,,,,,,,,,,,, -22900,0.07671483,0.20958394,,,,,,,,,,,,,, -23000,0.01938953,0.27511927,,,,,,,,,,,,,, -23080,,,0.750044754573277,0.2653392212731497,0.7257962116848973,0.2854270441316123,3554.0,0.7429138541521223,0.2868782220290247,3581.0,5391.550899028778,5675.270745754242,5391.550899028778,278.0168051719665,4.767531394958496,0.0 -23100,0.034687553,0.26594588,,,,,,,,,,,,,, -23200,0.066526376,0.3033148,,,,,,,,,,,,,, -23300,0.07329226,0.23979105,,,,,,,,,,,,,, -23400,0.03158706,0.2934369,,,,,,,,,,,,,, -23426,,,0.749577454158238,0.2657884699957711,0.7258460152732836,0.2855529956892761,3554.0,0.7429287848410011,0.2870010422848017,3581.0,5471.681828737259,5759.457330942154,5471.681828737259,282.02875447273254,4.797436714172363,0.0 -23500,0.107422985,0.24220772,,,,,,,,,,,,,, -23600,0.05997599,0.2831827,,,,,,,,,,,,,, -23700,0.061486986,0.33474565,,,,,,,,,,,,,, -23770,,,0.7495931216648647,0.2655000175748552,0.7256565555536015,0.2853825815487391,3554.0,0.7428168387627408,0.2867749343867809,3581.0,5551.710270643234,5843.544717788696,5551.710270643234,286.0441243648529,4.827166795730591,0.0 -23800,0.031965952,0.210415,,,,,,,,,,,,,, -23900,0.03758025,0.28367373,,,,,,,,,,,,,, -24000,0.039889645,0.18648136,,,,,,,,,,,,,, -24100,0.0911831,0.29215902,,,,,,,,,,,,,, -24116,,,0.750204154423305,0.2650260414396013,0.7257194111168753,0.2853486979349852,3554.0,0.7428803794113027,0.2867777296299043,3581.0,5631.85106420517,5927.740929841995,5631.85106420517,290.05428886413574,4.858988285064697,0.0 -24200,0.06586418,0.2696791,,,,,,,,,,,,,, -24300,0.05993486,0.24773946,,,,,,,,,,,,,, -24400,0.06805564,0.28298077,,,,,,,,,,,,,, -24462,,,0.749311992100307,0.2653989621571132,0.7251623665675999,0.2854617692542733,3554.0,0.7423995294174114,0.2868571895289374,3581.0,5712.004364490509,6011.95213842392,5712.004364490509,294.06663846969604,4.8908774852752686,0.0 -24500,0.056466937,0.33659813,,,,,,,,,,,,,, -24600,0.06111087,0.22128813,,,,,,,,,,,,,, -24700,0.03310193,0.2711314,,,,,,,,,,,,,, -24800,0.053187836,0.27060184,,,,,,,,,,,,,, -24808,,,0.7499324253627232,0.2654122965676443,0.7258076836838773,0.2854906038145839,3554.0,0.7429736450842991,0.2868880735566183,3581.0,5792.194799900055,6096.200660705566,5792.194799900055,298.0805068016052,4.921252250671387,0.0 -24900,0.0399461,0.27462986,,,,,,,,,,,,,, -25000,0.049528796,0.29323417,,,,,,,,,,,,,, -25100,0.029418122,0.33350778,,,,,,,,,,,,,, -25153,,,0.7504647118704659,0.265025121825082,0.7259484389288478,0.2854275078201937,3554.0,0.7430454351089081,0.2868404521585276,3581.0,5872.191735506058,6180.25568819046,5872.191735506058,302.09289360046387,4.95358943939209,0.0 -25200,0.052064445,0.28895605,,,,,,,,,,,,,, -25300,0.065513395,0.281506,,,,,,,,,,,,,, -25400,0.10160772,0.21487582,,,,,,,,,,,,,, -25497,,,0.7500345366341727,0.2653808423451015,0.7257026496333356,0.2855252602426403,3554.0,0.742860471826131,0.2869330701532393,3581.0,5952.1544880867,6264.277378082275,5952.1544880867,306.1063861846924,4.98514199256897,0.0 -25500,0.036894493,0.258602,,,,,,,,,,,,,, -25600,0.024650928,0.34842837,,,,,,,,,,,,,, -25700,0.042566072,0.24044102,,,,,,,,,,,,,, -25800,0.041283622,0.25313765,,,,,,,,,,,,,, -25845,,,0.7496070861816406,0.2653955902372087,0.7255828462427898,0.2853795418124824,3554.0,0.742732231525761,0.2867713892003805,3581.0,6032.137490034103,6348.31400680542,6032.137490034103,310.115487575531,5.016194820404053,0.0 -25900,0.04540183,0.21393256,,,,,,,,,,,,,, -26000,0.040155184,0.32333577,,,,,,,,,,,,,, -26100,0.037009835,0.29187036,,,,,,,,,,,,,, -26189,,,0.7505412101745605,0.2648829051426479,0.7259142977103263,0.2854377261426385,3554.0,0.743048912118647,0.2868577349422298,3581.0,6112.283630371094,6432.515614271164,6112.283630371094,314.12628722190857,5.047453880310059,0.0 -26200,0.020775864,0.2801876,,,,,,,,,,,,,, -26300,0.06771035,0.25711063,,,,,,,,,,,,,, -26400,0.025016062,0.4069217,,,,,,,,,,,,,, -26500,0.049923617,0.2065919,,,,,,,,,,,,,, -26532,,,0.7501976149422782,0.2651327848434448,0.7257329439539955,0.2854400617591974,3554.0,0.7429106498490295,0.2868616891885995,3581.0,6192.388805150986,6516.674349784851,6192.388805150986,318.13490319252014,5.079059839248657,0.0 -26600,0.07555848,0.30926472,,,,,,,,,,,,,, -26700,0.046785392,0.3261274,,,,,,,,,,,,,, -26800,0.037487682,0.25942743,,,,,,,,,,,,,, -26880,,,0.7498160089765277,0.265214272907802,0.7256105988630768,0.2852989973885059,3554.0,0.7427324360557456,0.2867100642933189,3581.0,6272.537399768829,6600.883124828339,6272.537399768829,322.150488615036,5.109807968139648,0.0 -26900,0.035356067,0.22886884,,,,,,,,,,,,,, -27000,0.033739097,0.22108869,,,,,,,,,,,,,, -27100,0.034199703,0.28203833,,,,,,,,,,,,,, -27200,0.031390373,0.35367924,,,,,,,,,,,,,, -27223,,,0.7501500674656459,0.2650226354598999,0.725652296488112,0.2854420195554305,3554.0,0.7427158691269896,0.2868899825031415,3581.0,6352.623078107834,6685.047988176346,6352.623078107834,326.1638045310974,5.161976337432861,0.0 -27300,0.031696465,0.22030799,,,,,,,,,,,,,, -27400,0.029654784,0.25872427,,,,,,,,,,,,,, -27500,0.02807896,0.33616325,,,,,,,,,,,,,, -27569,,,0.7501449584960938,0.2649794135774885,0.725748056767023,0.2852976921910171,3554.0,0.7429066956026599,0.2866826913637078,3581.0,6432.78490281105,6769.26359128952,6432.78490281105,330.1715428829193,5.19452714920044,0.0 -27600,0.044820532,0.29861745,,,,,,,,,,,,,, -27700,0.022343444,0.36928508,,,,,,,,,,,,,, -27800,0.020957813,0.31204262,,,,,,,,,,,,,, -27900,0.034188665,0.30386657,,,,,,,,,,,,,, -27914,,,0.7500529289245605,0.2651003428867885,0.7256652110737901,0.2853273510865838,3554.0,0.7428141798729405,0.2867274493420134,3581.0,6512.966456890106,6853.502288341522,6512.966456890106,334.1822066307068,5.227466106414795,0.0 -28000,0.058382247,0.21049,,,,,,,,,,,,,, -28100,0.05819993,0.21816026,,,,,,,,,,,,,, -28200,0.0330271,0.36721736,,,,,,,,,,,,,, -28257,,,0.7506343296595982,0.2649598121643066,0.7260884385331668,0.2853751453577834,3554.0,0.7432074910334054,0.2868046934995462,3581.0,6593.085450172424,6937.679908514023,6593.085450172424,338.1944320201874,5.260152816772461,0.0 -28300,0.04306503,0.263417,,,,,,,,,,,,,, -28400,0.04015496,0.22843805,,,,,,,,,,,,,, -28500,0.023310194,0.2947671,,,,,,,,,,,,,, -28600,0.025998827,0.31183523,,,,,,,,,,,,,, -28604,,,0.7506190027509417,0.2649067299706595,0.7260718144388365,0.2853277460805606,3554.0,0.743204695790282,0.2867792636047891,3581.0,6673.190561294556,7021.847796201706,6673.190561294556,342.20906805992126,5.2943854331970215,0.0 -28700,0.020722931,0.29348344,,,,,,,,,,,,,, -28800,0.051606283,0.26908565,,,,,,,,,,,,,, -28900,0.020917628,0.30314824,,,,,,,,,,,,,, -28946,,,0.7500931876046317,0.2649789367403303,0.7257242197392023,0.2852500868299803,3554.0,0.7428719255052709,0.2866358880855557,3581.0,6753.176616668701,7105.89125585556,6753.176616668701,346.21965074539185,5.327981233596802,0.0 -29000,0.031828053,0.22491731,,,,,,,,,,,,,, -29100,0.046500117,0.18436529,,,,,,,,,,,,,, -29200,0.037120417,0.18778555,,,,,,,,,,,,,, -29290,,,0.7503688676016671,0.2651293788637434,0.7259679481965743,0.2854834767493493,3554.0,0.7430520482450782,0.2869654199791434,3581.0,6833.300783395767,7190.074701547623,6833.300783395767,350.23344373703003,5.359550714492798,0.0 -29300,0.032220304,0.29069066,,,,,,,,,,,,,, -29400,0.04051544,0.27921897,,,,,,,,,,,,,, -29500,0.033801038,0.21288934,,,,,,,,,,,,,, -29600,0.030282555,0.25353357,,,,,,,,,,,,,, -29635,,,0.750690392085484,0.2648529495511736,0.7260376732203151,0.2853850373808561,3554.0,0.7431846518517872,0.2868040458212615,3581.0,6913.516995429993,7274.34780049324,6913.516995429993,354.2454254627228,5.391098737716675,0.0 -29700,0.018876176,0.30552825,,,,,,,,,,,,,, -29800,0.03046469,0.2852043,,,,,,,,,,,,,, -29900,0.028518107,0.24181578,,,,,,,,,,,,,, -29980,,,0.7506021772112165,0.2649602379117693,0.7260679675409749,0.2853407465344946,3554.0,0.7432026504904357,0.2867590492246404,3581.0,6993.540855407715,7358.42671585083,6993.540855407715,358.2557344436645,5.422429800033569,0.0 -30000,0.019956855,0.29512072,,,,,,,,,,,,,, -30100,0.0186549,0.23580031,,,,,,,,,,,,,, -30200,0.022686223,0.27810374,,,,,,,,,,,,,, -30300,0.026818395,0.2183017,,,,,,,,,,,,,, -30322,,,0.750415733882359,0.2649151768003191,0.7259121681775816,0.2853185925244882,3554.0,0.743074205660081,0.2867245177455669,3581.0,7073.571084022522,7442.514234781265,7073.571084022522,362.267263174057,5.454789638519287,0.0 -30400,0.08160042,0.2507101,,,,,,,,,,,,,, -30500,0.038572922,0.22509308,,,,,,,,,,,,,, -30600,0.032159775,0.32202622,,,,,,,,,,,,,, -30666,,,0.750591686793736,0.2646281208310808,0.7258340624120709,0.2852345790229759,3554.0,0.7429935526694709,0.2866318315741936,3581.0,7153.6601984500885,7526.665081501007,7153.6601984500885,366.28090047836304,5.489109992980957,0.0 -30700,0.024127033,0.29039592,,,,,,,,,,,,,, -30800,0.03336898,0.35617894,,,,,,,,,,,,,, -30900,0.027347673,0.20991318,,,,,,,,,,,,,, -31000,0.021889886,0.27529496,,,,,,,,,,,,,, -31010,,,0.7505792209080288,0.2648442813328334,0.7259497441263365,0.2853307686431661,3554.0,0.7430826595661129,0.2867750025634424,3581.0,7233.649788141251,7610.714449882507,7233.649788141251,370.2935276031494,5.52242112159729,0.0 -31100,0.04082823,0.1997267,,,,,,,,,,,,,, -31200,0.032211497,0.2151066,,,,,,,,,,,,,, -31300,0.054636754,0.33849505,,,,,,,,,,,,,, -31354,,,0.7504358291625977,0.2648018257958548,0.7258527473445414,0.2852551015361212,3554.0,0.7429784856272689,0.2866630905735129,3581.0,7313.791854381561,7694.910228490829,7313.791854381561,374.3013157844544,5.554333209991455,0.0 -31400,0.052832667,0.21930814,,,,,,,,,,,,,, -31500,0.05410925,0.26047173,,,,,,,,,,,,,, -31600,0.039220903,0.2184774,,,,,,,,,,,,,, -31697,,,0.750840391431536,0.2645924602236066,0.726041245339758,0.2852751947079874,3554.0,0.7431663123298311,0.2866952699577632,3581.0,7393.845289468765,7779.015561819076,7393.845289468765,378.3068976402282,5.58677077293396,0.0 -31700,0.028008142,0.23507483,,,,,,,,,,,,,, -31800,0.056671765,0.21840125,,,,,,,,,,,,,, -31900,0.032221846,0.30514032,,,,,,,,,,,,,, -32000,0.028379384,0.19017304,,,,,,,,,,,,,, -32043,,,0.7502557209559849,0.2647336891719273,0.7255834644942318,0.2852741471152663,3554.0,0.7427708876928582,0.2866786007640149,3581.0,7473.926406145096,7863.157017469406,7473.926406145096,382.3198854923248,5.620407104492188,0.0 -32100,0.02932376,0.32672018,,,,,,,,,,,,,, -32200,0.02366762,0.3216047,,,,,,,,,,,,,, -32300,0.0348081,0.29487833,,,,,,,,,,,,,, -32388,,,0.7505275181361607,0.2648208141326904,0.7259816184229038,0.2853099026569974,3554.0,0.7431155207169785,0.2867119050631806,3581.0,7554.078770637512,7947.365916252136,7554.078770637512,386.3294751644135,5.653552770614624,0.0 -32400,0.026992338,0.26841766,,,,,,,,,,,,,, -32500,0.016482146,0.24908632,,,,,,,,,,,,,, -32600,0.0288203,0.23347089,,,,,,,,,,,,,, -32700,0.020549515,0.24718396,,,,,,,,,,,,,, -32730,,,0.7509356907435826,0.264573335647583,0.7260761421989308,0.2852920077124806,3554.0,0.7432014914871893,0.2867138140097039,3581.0,7634.058696508408,8031.405979156494,7634.058696508408,390.34309220314026,5.686913251876831,0.0 -32800,0.015662499,0.23234165,,,,,,,,,,,,,, -32900,0.025592782,0.21376562,,,,,,,,,,,,,, -33000,0.022345632,0.21396402,,,,,,,,,,,,,, -33076,,,0.7511607578822544,0.2646849836621965,0.7264695561998804,0.2852806215817565,3554.0,0.7435912574612539,0.28670041729571,3581.0,7714.161678314209,8115.569339990616,7714.161678314209,394.3575599193573,5.719174385070801,0.0 -33100,0.015710458,0.23435253,,,,,,,,,,,,,, -33200,0.01630042,0.28860128,,,,,,,,,,,,,, -33300,0.019731432,0.3218053,,,,,,,,,,,,,, -33400,0.015363047,0.26377547,,,,,,,,,,,,,, -33421,,,0.7507865088326591,0.2647251742226736,0.7260885072277715,0.2852847604316879,3554.0,0.7432023777837894,0.286716813782812,3581.0,7794.147730350494,8199.616106510162,7794.147730350494,398.3710045814514,5.753046751022339,0.0 -33500,0.026006196,0.22221377,,,,,,,,,,,,,, -33600,0.018936105,0.29815567,,,,,,,,,,,,,, -33700,0.020066317,0.24866763,,,,,,,,,,,,,, -33763,,,0.750774656023298,0.2645835195268903,0.726001814636677,0.2852442649622344,3554.0,0.7431103392907009,0.2866793507072919,3581.0,7874.208901643753,8283.741166591644,7874.208901643753,402.38356733322144,5.791030645370483,0.0 -33800,0.020498684,0.23827618,,,,,,,,,,,,,, -33900,0.018311918,0.28577158,,,,,,,,,,,,,, -34000,0.020466156,0.3971158,,,,,,,,,,,,,, -34100,0.021997044,0.217166,,,,,,,,,,,,,, -34108,,,0.7508078983851841,0.2645749023982456,0.726052923422552,0.2851960585234067,3554.0,0.7431852654417411,0.286617480386938,3581.0,7954.261415958404,8367.859296798706,7954.261415958404,406.400333404541,5.826385736465454,0.0 -34200,0.018991455,0.31136453,,,,,,,,,,,,,, -34300,0.019029034,0.23514087,,,,,,,,,,,,,, -34400,0.019084131,0.20652096,,,,,,,,,,,,,, -34452,,,0.7507697514125279,0.2646217857088361,0.7260702344629291,0.2852324323165799,3554.0,0.7431785841289095,0.2866588977088278,3581.0,8034.315863609314,8451.97430229187,8034.315863609314,410.41228795051575,5.861198663711548,0.0 -34500,0.021958716,0.23864886,,,,,,,,,,,,,, -34600,0.016958416,0.2424548,,,,,,,,,,,,,, -34700,0.031273432,0.34583953,,,,,,,,,,,,,, -34795,,,0.7505358287266323,0.2645829575402396,0.725791059589547,0.2851932420446152,3554.0,0.7429303529042167,0.2866116853707065,3581.0,8114.298360586166,8536.016276597977,8114.298360586166,414.4228415489197,5.896820545196533,0.0 -34800,0.02006997,0.31392694,,,,,,,,,,,,,, -34900,0.023107871,0.1787213,,,,,,,,,,,,,, -35000,0.017305572,0.24014443,,,,,,,,,,,,,, -35100,0.018597638,0.22017461,,,,,,,,,,,,,, -35140,,,0.7508696147373745,0.2645574637821742,0.7261189389376407,0.2851793829081229,3554.0,0.7432427383674253,0.2865960729152122,3581.0,8194.391567707062,8620.169978618622,8194.391567707062,418.4358830451965,5.930956602096558,0.0 -35200,0.018108353,0.21710587,,,,,,,,,,,,,, -35300,0.02305203,0.2899525,,,,,,,,,,,,,, -35400,0.013728579,0.24601346,,,,,,,,,,,,,, -35487,,,0.750948292868478,0.2645804711750575,0.7261771232677968,0.2852134039110861,3554.0,0.7433055290727102,0.2866312520725705,3581.0,8274.552117347717,8704.389466524124,8274.552117347717,422.4488196372986,5.9632744789123535,0.0 -35500,0.022492796,0.31879503,,,,,,,,,,,,,, -35600,0.017089337,0.20186773,,,,,,,,,,,,,, -35700,0.027065825,0.20897359,,,,,,,,,,,,,, -35800,0.02494968,0.23503317,,,,,,,,,,,,,, -35830,,,0.7509431838989258,0.2645731312888009,0.7261721085616559,0.2852237767963914,3554.0,0.7432949616901704,0.2866460464081262,3581.0,8354.554366588593,8788.44710612297,8354.554366588593,426.4578926563263,5.996220350265503,0.0 -35900,0.023638135,0.19030464,,,,,,,,,,,,,, -36000,0.024383951,0.30215755,,,,,,,,,,,,,, -36100,0.017743133,0.24440815,,,,,,,,,,,,,, -36175,,,0.7508655956813267,0.2645613465990339,0.7261112451419176,0.285195783744988,3554.0,0.7432375569411477,0.2866142079071837,3581.0,8434.533107995987,8872.484619140625,8434.533107995987,430.4693253040314,6.029922008514404,0.0 -36189,,,0.7508655956813267,0.2645613465990339,0.7261113138365223,0.285195783744988,3554.0,0.7432375569411477,0.2866142079071837,3581.0,8435.488661050797,8877.492777347565,8435.488661050797,434.4867742061615,6.064191818237305,0.0 -36189,,,,,,,,,,,8435.488661050797,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/eval_measurements.csv deleted file mode 100644 index ce31daeb5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.965081214904785,0.0,44.76288342475891,1,0,44.76288342475891,1.0323986403937448,3581,0.2128455943150481,48.72807693481445,1.0327176366533553,0.1975462777273995,1.0370871866426914,3554,0.1897597777152504 -7.969703435897827,0.0244109630584716,125.10523748397829,248,0,125.10523748397829,0.3618895938961707,3581,0.6574814068608629,133.1096351146698,0.3414253166743687,0.6612253870282855,0.3610390924508476,3554,0.637590141060249 -11.9761803150177,0.0585534572601318,205.28294610977173,483,0,205.28294610977173,0.3340769929947116,3581,0.693326854732442,217.33776664733887,0.3137917518615722,0.6970335415431431,0.3325878136870427,3554,0.674776522712085 -15.985013246536257,0.0937292575836181,285.4695768356323,715,0,285.4695768356323,0.3169058485758168,3581,0.7114660736569743,301.5783360004425,0.2966330392020089,0.716080801827567,0.3145881785927827,3554,0.6942070664216375 -19.99083566665649,0.1238799095153808,365.430878162384,1034,0,365.430878162384,0.3066954731242146,3581,0.7197285396231848,385.5891439914704,0.2870752811431885,0.7235066550118583,0.3045320097623276,3554,0.7026445499349324 -23.99786949157715,0.1478610038757324,445.5725281238556,1381,0,445.5725281238556,0.2998524793397444,3581,0.7259140717720259,469.7762997150421,0.2799033607755388,0.7308674539838519,0.2979709537229178,3554,0.7087941596396665 -28.005931615829468,0.1723105907440185,525.693502664566,1727,0,525.693502664566,0.298256906841228,3581,0.7284042243350322,553.9444513320923,0.2783775159290859,0.7333673749651227,0.2964568902886712,3554,0.7113109925655952 -32.01269316673279,0.1958217620849609,605.7808966636658,2073,0,605.7808966636658,0.2952709054187901,3581,0.7319312075668458,638.0774807929993,0.275574905531747,0.7364885466439384,0.2935554706184932,3554,0.7148465664787915 -36.02655911445618,0.2213416099548339,685.8993790149689,2420,0,685.8993790149689,0.2949160118071069,3581,0.7310254806181933,722.2499129772186,0.2753708021981375,0.7353577613830566,0.2931602705578222,3554,0.7139438506788126 -40.04101777076721,0.2482287883758545,765.968150138855,2765,0,765.968150138855,0.2939401310737224,3581,0.7326973768893117,806.3747057914734,0.2741092954363142,0.7376354762486049,0.292292039449388,3554,0.7155301464898706 -44.046605587005615,0.2724692821502685,846.0964727401733,3110,0,846.0964727401733,0.2943855292036093,3581,0.7312285788929419,890.5473172664642,0.2750832523618425,0.7347473417009626,0.2927851636690173,3554,0.7142884915104459 -48.05672740936279,0.2985961437225342,926.2872281074524,3458,0,926.2872281074524,0.2926600801648457,3581,0.7346316851307246,974.789454460144,0.2726844719478062,0.7395678247724261,0.2909971118040412,3554,0.7174643117789814 -52.06747007369995,2.796729803085327,1003.8125767707824,3793,0,1003.8125767707824,0.2918852183180326,3581,0.7358856584665596,1058.8378336429596,0.2718532766614641,0.7412997654506138,0.290306353206774,3554,0.7187269186128307 -56.0785710811615,2.821539163589477,1083.916685819626,4139,0,1083.916685819626,0.2931617240405089,3581,0.7350120427254957,1142.9926023483276,0.2731419290815081,0.7402080808367048,0.2915027040944182,3554,0.7178762046285875 -60.08868336677551,2.8460841178894043,1164.0630095005035,4484,0,1164.0630095005035,0.2912875135262496,3581,0.7369192848322745,1227.1880266666412,0.2710321971348354,0.7428110667637416,0.2897092597029755,3554,0.71978186165676 -64.09915161132812,2.870636463165283,1244.148692369461,4831,0,1244.148692369461,0.2908246621710067,3581,0.737703248263404,1311.323415517807,0.2707383973257882,0.7431929452078683,0.2893053354275112,3554,0.7205519281751196 -68.11032629013062,2.897249937057495,1324.1467669010162,5175,0,1324.1467669010162,0.2903081557831262,3581,0.7370804544601718,1395.3740241527555,0.2704767840249197,0.7424209458487374,0.2888622552273846,3554,0.7198196436893289 -72.12192797660828,2.9218204021453857,1404.2759974002838,5517,0,1404.2759974002838,0.2907500087266126,3581,0.7387610773439681,1479.5537617206571,0.2705959592546735,0.7446633747645787,0.2892542953362408,3554,0.7215675779051772 -76.13256049156189,2.946476936340332,1484.4351799488068,5863,0,1484.4351799488068,0.2902668407262287,3581,0.737032253560458,1563.7628009319303,0.2705733605793544,0.7418193135942731,0.288847142414357,3554,0.7196740111274268 -80.14620423316956,2.9725465774536133,1564.6058371067047,6210,0,1564.6058371067047,0.2908501261540945,3581,0.7394353445266685,1647.987946987152,0.2705646753311157,0.745011465890067,0.289374133074089,3554,0.7222998623909679 -84.1552243232727,2.997762680053711,1644.8221879005432,6553,0,1644.8221879005432,0.2899777375929384,3581,0.7388201865095294,1732.2529256343842,0.2694678476878575,0.7449780872889927,0.2884729972500175,3554,0.7216078329435144 -88.16625618934631,3.023563861846924,1724.981882095337,6899,0,1724.981882095337,0.2901134432377478,3581,0.7392115887234711,1816.464188098908,0.2704054968697684,0.7441498211451939,0.2887626823979143,3554,0.722005643399163 -92.17454433441162,3.0495445728302,1805.1859049797056,7247,0,1805.1859049797056,0.2893399790124965,3581,0.7395222015934795,1900.717324256897,0.2693194661821638,0.7448177337646484,0.2880321152772404,3554,0.7221843180659117 -96.1854498386383,3.076326608657837,1885.1832218170168,7588,0,1885.1832218170168,0.2893281162733873,3581,0.7393584412524434,1984.7666964530945,0.2691126380647932,0.7447967529296875,0.2878792010872432,3554,0.7220263204751688 -100.19572973251344,3.102079153060913,1965.366005659104,7934,0,1965.366005659104,0.2891115871963138,3581,0.7399581231674114,2069.000062465668,0.2690371445247105,0.7456778798784528,0.2876706442674627,3554,0.7227637570563098 -104.20498085021973,3.1299993991851807,2045.4119424819944,8281,0,2045.4119424819944,0.2892600418768326,3581,0.7401415865636345,2153.09862613678,0.2691262619835989,0.745875358581543,0.2879217230475345,3554,0.7229084278937464 -108.21316409111024,3.1573917865753174,2125.377302646637,8624,0,2125.377302646637,0.2892616781167097,3581,0.739326602751501,2237.113946437836,0.2690394095012119,0.7451011112758091,0.2877977979807083,3554,0.7221248972328714 -112.22468495368958,3.184272289276123,2205.461377620697,8969,0,2205.461377620697,0.2884087539924253,3581,0.7405642136885646,2321.250861644745,0.2680323464529855,0.7463887759617397,0.2870491298317037,3554,0.7232508018034257 -116.23565125465392,3.211730480194092,2285.5357501506805,9314,0,2285.5357501506805,0.2889439066972389,3581,0.740029367778728,2405.378289461136,0.2686294998441423,0.7459329877580915,0.2875499993680096,3554,0.7227771525042206 -120.24630069732666,3.239140033721924,2365.571274280548,9658,0,2365.571274280548,0.2885687646170762,3581,0.7404210426993159,2489.4663603305817,0.2684325831277029,0.746016229901995,0.2871666834739466,3554,0.7231320288319499 -124.25722122192384,3.26558256149292,2445.607465028763,10000,0,2445.607465028763,0.2884534437940694,3581,0.7405252848148213,2573.5549368858337,0.2678298609597342,0.7468297140938895,0.2870614776868933,3554,0.7232698302089196 -128.26841831207275,3.2932775020599365,2525.712516069412,10347,0,2525.712516069412,0.288304341435266,3581,0.7405346931941148,2657.7135181427,0.2679857185908726,0.746333122253418,0.2869289142735298,3554,0.7232405663073298 -132.2769341468811,3.3216447830200195,2605.689066648484,10691,0,2605.689066648484,0.2885059057351298,3581,0.7402006957291958,2741.7419786453247,0.2683102062770298,0.7459057399204799,0.2871290903515405,3554,0.7228599981974536 -136.29143595695496,3.349501371383667,2685.8272409439087,11037,0,2685.8272409439087,0.2883143974928441,3581,0.7406535251151913,2825.9370296001434,0.2679037196295602,0.7467265129089355,0.2869203617952483,3554,0.7233592018895962 -140.30236983299255,3.3763203620910645,2765.9728231430054,11383,0,2765.9728231430054,0.2887903387671041,3581,0.7408546462667551,2910.1348733901978,0.2682475192206247,0.7467336654663086,0.287365915001143,3554,0.7235851384443585 -144.31327152252197,3.4031455516815186,2846.0457010269165,11728,0,2846.0457010269165,0.2882078714591769,3581,0.7404165430396538,2994.259813785553,0.2683146340506417,0.7459479059491839,0.2869273342976224,3554,0.723168093499402 -148.32446479797363,3.429948806762696,2926.048833608628,12072,0,2926.048833608628,0.2879914446470958,3581,0.7406919085756423,3078.3154022693634,0.2674442529678345,0.7469075066702706,0.2866193076902785,3554,0.7234120967351927 -152.33849143981934,3.458447933197021,3006.178083181381,12417,0,3006.178083181381,0.2879815590311714,3581,0.7412790459848855,3162.5018265247345,0.2675682646887643,0.7473031452723912,0.2866649896023846,3554,0.7239558832257668 -156.35280227661133,3.4851222038269043,3086.1382641792297,12762,0,3086.1382641792297,0.2877508151201655,3581,0.7415175279469771,3246.5176918506622,0.2675473690032959,0.7473674501691546,0.286402455996984,3554,0.7242794348137662 -160.36889815330505,3.513031244277954,3166.1795699596405,13105,0,3166.1795699596405,0.2902166967916608,3581,0.7376949307106954,3330.6173605918884,0.2700824226651873,0.7421697889055524,0.2887568777038196,3554,0.720902820215778 -164.37267231941223,3.5402119159698486,3246.334449291229,13452,0,3246.334449291229,0.2877930505619938,3581,0.7413308602476613,3414.818110704422,0.2673826728548322,0.7474530764988491,0.286526982141601,3554,0.7240228604653207 -168.38208889961243,3.567617654800415,3326.4810423851013,13798,0,3326.4810423851013,0.2885023605487294,3581,0.7404483133639347,3499.016035795212,0.268414991242545,0.7462870052882603,0.2871614970312939,3554,0.7231948843952237 -172.3943543434143,3.595734119415283,3406.449381589889,14141,0,3406.449381589889,0.2880931642261239,3581,0.7413019533431653,3583.039057970047,0.2679232358932495,0.7470973559788295,0.2867153255739572,3554,0.7241787972179234 -176.40333008766174,3.623152017593384,3486.5816123485565,14487,0,3486.5816123485565,0.2881711583269338,3581,0.7403454347816601,3667.222679138184,0.2677939278738839,0.7464298520769391,0.286742957978686,3554,0.7232238048237901 -180.41909766197205,3.652091026306152,3566.6675159931183,14832,0,3566.6675159931183,0.2877265442286547,3581,0.7416326101516685,3751.367563962936,0.267444235937936,0.7475502831595284,0.286362716168182,3554,0.7244118780115715 -184.4270977973938,3.679409980773926,3646.81110739708,15174,0,3646.81110739708,0.2883633483358349,3581,0.742216406904496,3835.5609562397,0.2678801161902291,0.7479367256164551,0.2868504135140423,3554,0.7252343585132949 -188.43514847755432,3.70670485496521,3726.9752271175385,15520,0,3726.9752271175385,0.2879875585773876,3581,0.7409330494275342,3919.774846792221,0.2672254358019147,0.7471392495291573,0.2866708114701304,3554,0.7236271108478123 -192.44737243652344,3.734318971633911,3806.976819515228,15866,0,3806.976819515228,0.2874218968165317,3581,0.7422716981770107,4003.8309524059296,0.2670131070273263,0.748380184173584,0.2861659576467537,3554,0.724918844194042 -196.4577419757843,3.761636972427368,3886.973289012909,16209,0,3886.973289012909,0.2874222036115086,3581,0.7419779931190659,4087.879930019378,0.2671475410461426,0.7478961263384137,0.2861368311343732,3554,0.7246443405537775 -200.4681260585785,3.789839267730713,3967.0818524360657,16556,0,3967.0818524360657,0.2871162608428163,3581,0.7418374128429559,4172.041687011719,0.2665354013442993,0.748145307813372,0.2858139321451182,3554,0.7245252241092783 -204.47909569740293,3.817901611328125,4047.131493330002,16903,0,4047.131493330002,0.2870625035451864,3581,0.7418483892854649,4256.145545244217,0.2666094814028059,0.7478715351649693,0.2857410128222601,3554,0.724559914884637 -208.4960463047028,3.8455328941345215,4127.304555654526,17247,0,4127.304555654526,0.2876970578225356,3581,0.7423234442631248,4340.377782583237,0.2673543861934117,0.7482375417436872,0.2863902798783061,3554,0.7250987553636747 -212.5071542263031,3.8736438751220703,4207.318914651871,17592,0,4207.318914651871,0.2876395508085206,3581,0.7414175809611491,4424.445866823196,0.2666993652071271,0.748047011239188,0.2863386730565472,3554,0.7240940280757597 -216.51686763763428,3.901830911636353,4287.361625432968,17937,0,4287.361625432968,0.2872101741940973,3581,0.7426705998237224,4508.540894031525,0.2666618824005127,0.7487597465515137,0.2859063950830051,3554,0.7253592453045864 -220.5280795097351,3.9302728176116943,4367.358929157257,18283,0,4367.358929157257,0.2869363426329935,3581,0.7427411626684236,4592.592378616333,0.2666053431374686,0.7487377439226423,0.2856880149347566,3554,0.7254902459156936 -224.53717684745789,3.959471225738525,4447.448851585388,18626,0,4447.448851585388,0.2870357101171984,3581,0.7427425262016546,4676.735166788101,0.2661102328981672,0.7493752070835659,0.2856842539051509,3554,0.7254873607422974 -228.54846930503845,3.989834785461426,4527.552319765091,18972,0,4527.552319765091,0.2880688592462824,3581,0.7408810306347738,4760.894608259201,0.2676242760249546,0.7464577811104911,0.2866407060596335,3554,0.7239104760920794 -232.5619733333588,4.01876974105835,4607.723635435104,19318,0,4607.723635435104,0.2872599090686959,3581,0.7427173008368821,4845.123031139374,0.266740185873849,0.7489397185189384,0.2859350407331528,3554,0.7255840140510692 -236.56711435318,4.048708200454712,4687.796002864838,19663,0,4687.796002864838,0.2868840170452562,3581,0.7418566386615122,4929.245357036591,0.2659630945750645,0.7484968049185616,0.2855817787286332,3554,0.7245838893016672 -240.57340145111084,4.0773022174835205,4767.805709838867,20008,0,4767.805709838867,0.2869510687918877,3581,0.7426997794348645,5013.3049693107605,0.2663997071129935,0.7490484373910087,0.2857362041999332,3554,0.725382051913337 -244.5843369960785,4.106105089187622,4847.913654327393,20354,0,4847.913654327393,0.2873005082706472,3581,0.7417742812543633,5097.467324256897,0.2667103154318673,0.7480840001787458,0.2860161518876178,3554,0.7245500228615644 -248.5912554264069,4.135280132293701,4927.873211860657,20697,0,4927.873211860657,0.2869389333461323,3581,0.742291810292167,5181.477455615997,0.2661643028259277,0.7485865865434919,0.28558687930303,3554,0.7251484215628518 -252.6009838581085,4.1652820110321045,5007.937152862549,21042,0,5007.937152862549,0.2870958760210136,3581,0.742089802844003,5265.59605717659,0.2663168055670602,0.7484387670244489,0.285725024153023,3554,0.7248215039392234 -256.6096124649048,4.198110818862915,5088.057301521301,21388,0,5088.057301521301,0.2872167873302674,3581,0.7421052107695127,5349.772326469421,0.2667201246534075,0.7481115886143276,0.2858996458380962,3554,0.7249347126477209 -260.62133407592773,4.228944778442383,5168.070353746414,21732,0,5168.070353746414,0.2870578334438704,3581,0.7425255198879503,5433.842468261719,0.2662261894771031,0.7490191459655762,0.2857492218275183,3554,0.7253119834165729 -264.62954115867615,4.259636402130127,5248.222505569458,22079,0,5248.222505569458,0.2868992204407812,3581,0.7428523588034068,5518.048063755035,0.265928966658456,0.749504634312221,0.2855818130759355,3554,0.7256125223120076 -268.64070224761963,4.2899558544158936,5328.270455598831,22423,0,5328.270455598831,0.2867292901118752,3581,0.7429833943469003,5602.152354717255,0.2660556520734514,0.749286447252546,0.2854290019278453,3554,0.7258175757069499 -272.6491093635559,4.319113254547119,5408.337042331696,22767,0,5408.337042331696,0.2871282258469178,3581,0.7428747207483943,5686.270875692368,0.2664235830307007,0.7490869930812291,0.2856929265989906,3554,0.725718861560038 -276.6583559513092,4.348245620727539,5488.437965869904,23111,0,5488.437965869904,0.2865338276232198,3581,0.743165426033231,5770.424674987793,0.2655370916639055,0.7498129435947963,0.285219534904553,3554,0.7259281053258653 -280.66591787338257,4.377592325210571,5568.480888128281,23459,0,5568.480888128281,0.2864100188058503,3581,0.742678781023108,5854.519265413284,0.2656686987195696,0.7490844045366559,0.2851312795112021,3554,0.7254319241963281 -284.6766347885132,4.406728506088257,5648.613835811615,23804,0,5648.613835811615,0.2864074621810423,3581,0.7428889014939961,5938.706800699234,0.2656446354729788,0.7493758201599121,0.285176532082029,3554,0.7255900591762803 -288.6837303638458,4.437592029571533,5728.598200559616,24148,0,5728.598200559616,0.2864313581009145,3581,0.7429643048816671,6022.743552207947,0.2651958976473127,0.7498668261936733,0.2851227613802229,3554,0.7257230519309229 -292.6920075416565,4.466193199157715,5808.61803650856,24494,0,5808.61803650856,0.2865370319263124,3581,0.7430404582126152,6106.814738750458,0.2655548027583531,0.7496851512363979,0.2852120128453415,3554,0.7258203921857415 -296.69993567466736,4.499108552932739,5888.69374203682,24838,0,5888.69374203682,0.2867215520607896,3581,0.7424380492311854,6190.94571518898,0.2659956046513149,0.748795781816755,0.2854291908380082,3554,0.7251729455367192 -300.70812129974365,4.529009819030762,5968.84553027153,25185,0,5968.84553027153,0.2864891037332449,3581,0.7436003249572396,6275.150093793869,0.265080486025129,0.7506846700395856,0.2851709162980972,3554,0.7263689872986424 -304.7168712615967,4.558919191360474,6048.872050285339,25531,0,6048.872050285339,0.2862789150856953,3581,0.742812748163048,6359.229727506638,0.2652970893042428,0.7494222096034459,0.2849968785171637,3554,0.7254723853184791 -308.72206234931946,4.588968515396118,6128.933758020401,25876,0,6128.933758020401,0.2862819148588034,3581,0.7430234140472284,6443.341115951538,0.2653527430125645,0.7495801789419991,0.2849699502321328,3554,0.7257883804999649 -312.73252749443054,4.620581388473511,6209.109485387802,26222,0,6209.109485387802,0.2864005763382261,3581,0.74321253610636,6527.5736310482025,0.2648673909051077,0.7504499980381557,0.2850825063418859,3554,0.7260060737021665 -316.7389891147613,4.651438236236572,6289.2570996284485,26568,0,6289.2570996284485,0.2864089279792656,3581,0.7435167403701829,6611.773455381393,0.2652456249509539,0.7503705705915179,0.2850904405687254,3554,0.7262834625158272 -320.7466149330139,4.683363676071167,6369.344420433044,26915,0,6369.344420433044,0.2861802634564367,3581,0.7430864774591595,6695.914905786514,0.2651932580130441,0.7497830390930176,0.2848746021208497,3554,0.7259059856631612 -324.75905179977417,4.714396953582764,6449.312697172165,27255,0,6449.312697172165,0.2862984136108978,3581,0.7434242928171251,6779.940864801407,0.2648061513900757,0.7505489758082798,0.2849321338522615,3554,0.7262866224676421 -328.7665092945099,4.74639892578125,6529.513382673264,27600,0,6529.513382673264,0.2861646850892732,3581,0.74350603663432,6864.195533752441,0.2649491514478411,0.750410965510777,0.2848861943353879,3554,0.7263029030889491 -332.7754480838776,4.778889656066895,6609.545207738876,27946,0,6609.545207738876,0.2862289075044505,3581,0.7435222626797682,6948.283418178558,0.2650993721825735,0.7503207751682827,0.2849315842954241,3554,0.7263553170723129 -336.7862060070038,4.811024188995361,6689.76785159111,28290,0,6689.76785159111,0.2862642571034627,3581,0.7434201340407708,7032.563189744949,0.2648235048566545,0.7504880087716239,0.2849101515787668,3554,0.7263116273037422 -340.7982308864593,4.841078758239746,6769.840493440628,28637,0,6769.840493440628,0.2862636776018396,3581,0.7436845231342503,7116.692368984222,0.2649060317448207,0.7506187983921596,0.2849294547626794,3554,0.7264817151449071 -344.80379843711853,4.873269319534302,6849.891664505005,28982,0,6849.891664505005,0.2862514739794226,3581,0.7437460866596272,7200.795670986176,0.2649570022310529,0.7507256780351911,0.2849720625912264,3554,0.7265615382755346 -348.8160009384155,4.905006408691406,6929.979115247726,29323,0,6929.979115247726,0.2862659956083321,3581,0.7433762282707345,7284.941651105881,0.2648698602403913,0.750253541128976,0.2849260543797481,3554,0.7261388603729952 -352.8241741657257,4.935630083084106,7010.161997318268,29668,0,7010.161997318268,0.2861837404661757,3581,0.7433797734571349,7369.177745580673,0.264641353062221,0.7505779947553363,0.2848554878471001,3554,0.7261684677476083 -356.831734418869,4.967212677001953,7090.252366781235,30014,0,7090.252366781235,0.2861041101254887,3581,0.7432790083513683,7453.321771860123,0.2647990328924997,0.7503046308244977,0.2848229781254396,3554,0.7261086347469401 -360.8413755893707,4.998633146286011,7170.308975458145,30357,0,7170.308975458145,0.2865039321571314,3581,0.7438623960442264,7537.433724164963,0.2650235380445208,0.7509540830339704,0.2851525576649989,3554,0.7267261992429305 -364.797967672348,5.029951572418213,7250.800181627274,30701,0,7250.800181627274,0.2861409255227241,3581,0.7436667972022479,7621.926305532455,0.264449988092695,0.7509610312325614,0.2848017686662475,3554,0.7264982705446328 -368.8075482845306,5.063923120498657,7330.806107759476,31047,0,7330.806107759476,0.2861016898540037,3581,0.7439442762147445,7705.99044585228,0.2646209682737078,0.7511175700596401,0.2847992269658747,3554,0.7267902226144837 -372.81839632987976,5.095531225204468,7410.819982767105,31390,0,7410.819982767105,0.2860398877103113,3581,0.7441750542140813,7790.061582565308,0.2646078722817557,0.7512401853288923,0.2847450784437429,3554,0.7270120374929657 -376.8252596855164,5.130434513092041,7490.946867465973,31735,0,7490.946867465973,0.2860473871430815,3581,0.7438293303633762,7874.244492769241,0.2642558813095093,0.7512694086347308,0.284735255115275,3554,0.7266450709148143 -380.8364975452423,5.161880731582642,7571.110435009003,32079,0,7571.110435009003,0.2861123254132051,3581,0.743691477153728,7958.465414047241,0.2644948278154646,0.7509448868887765,0.2848054438275974,3554,0.7264786238876969 -384.8424386978149,5.193868160247803,7651.189124345779,32422,0,7651.189124345779,0.2859859599710276,3581,0.7440023627303826,8042.596368312836,0.2644805908203125,0.7511148452758789,0.2846810378985386,3554,0.7268427739870569 -388.8557326793671,5.225993633270264,7731.246908187866,32767,0,7731.246908187866,0.2860509664178127,3581,0.7438860533457833,8126.714017152786,0.2641618422099522,0.7514227458408901,0.2846828067846089,3554,0.7267701637899198 -392.8660671710968,5.263420820236206,7811.215454101562,33112,0,7811.215454101562,0.2860263205546635,3581,0.743903847454447,8210.744921445847,0.2643602745873587,0.7512646402631488,0.2847045829742895,3554,0.7267347860685144 -396.87922167778015,5.295862913131714,7891.278911113739,33459,0,7891.278911113739,0.2859879711825432,3581,0.7437773797472773,8294.86878991127,0.2643741539546421,0.7510176386151995,0.2846880104009127,3554,0.7265887413389842 -400.88340163230896,5.328710556030273,7971.397326469421,33805,0,7971.397326469421,0.2859865394726508,3581,0.7439129149504329,8379.038766145706,0.2641398906707763,0.7513986315046038,0.2846557926313221,3554,0.7267900165306697 -404.8903458118439,5.362544059753418,8051.56923866272,34151,0,8051.56923866272,0.2859658137675405,3581,0.7440828793676696,8463.266416311264,0.2642030886241368,0.7514948163713727,0.2846536802722285,3554,0.7269228718961029 -408.90083360672,5.398136138916016,8131.788681983948,34497,0,8131.788681983948,0.2859455141665648,3581,0.7439314590023737,8547.546661138535,0.2642321756907871,0.751321724482945,0.2846361974953397,3554,0.7267886426385762 -412.909912109375,5.430059194564819,8211.779695034027,34841,0,8211.779695034027,0.2859060398795291,3581,0.7437752662707693,8631.593449831009,0.2641234397888183,0.7512396403721401,0.2845885577870005,3554,0.7266201347733188 -416.91813564300537,5.466964721679688,8291.951369524002,35188,0,8291.951369524002,0.2859155675679803,3581,0.7440539724631737,8715.825322628021,0.2641071251460484,0.7515051705496651,0.2845823409252778,3554,0.7268994470359103 -420.9276645183563,5.500535249710083,8372.11243891716,35534,0,8372.11243891716,0.285922385234135,3581,0.7441068093758727,8800.044318675995,0.2641344581331525,0.751544748033796,0.2846019704085625,3554,0.7269539218574141 -424.93412804603577,5.533951044082642,8452.082240343094,35877,0,8452.082240343094,0.2858525723327108,3581,0.7439154374869101,8884.068562984467,0.2641035829271589,0.7513072150094169,0.2845437860784063,3554,0.7267433728940982 -428.94553327560425,5.568581819534302,8523.938553333282,36189,0,8523.938553333282,0.285875275161006,3581,0.7441089228523806,8959.984071493149,0.264105030468532,0.7515296254839215,0.2845614921127603,3554,0.7269513801570414 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/measurements.csv deleted file mode 100644 index fe7fb7b3b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/measurements.csv +++ /dev/null @@ -1,471 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.6054354,1.0139427,,,,,,,,,,,,,, -1,,,0.1975462777273995,1.0327176366533553,0.1897597777152504,1.0370871866426914,3554.0,0.2128455943150481,1.0323986403937448,3581.0,44.76288342475891,48.72807693481445,44.76288342475891,3.965081214904785,0.0,0.0 -100,0.84390146,0.47173285,,,,,,,,,,,,,, -200,0.1966194,0.36092034,,,,,,,,,,,,,, -248,,,0.6612253870282855,0.3414253166743687,0.637590141060249,0.3610390924508476,3554.0,0.6574814068608629,0.3618895938961707,3581.0,125.10523748397829,133.1096351146698,125.10523748397829,7.969703435897827,0.0244109630584716,0.0 -300,0.1676544,0.34076256,,,,,,,,,,,,,, -400,0.07945391,0.36750776,,,,,,,,,,,,,, -483,,,0.6970335415431431,0.3137917518615722,0.674776522712085,0.3325878136870427,3554.0,0.693326854732442,0.3340769929947116,3581.0,205.28294610977173,217.33776664733887,205.28294610977173,11.9761803150177,0.0585534572601318,0.0 -500,0.08901173,0.24360149,,,,,,,,,,,,,, -600,0.11258794,0.36730403,,,,,,,,,,,,,, -700,0.16454776,0.32400733,,,,,,,,,,,,,, -715,,,0.716080801827567,0.2966330392020089,0.6942070664216375,0.3145881785927827,3554.0,0.7114660736569743,0.3169058485758168,3581.0,285.4695768356323,301.5783360004425,285.4695768356323,15.985013246536257,0.0937292575836181,0.0 -800,0.10720289,0.25179094,,,,,,,,,,,,,, -900,0.1125827,0.3182636,,,,,,,,,,,,,, -1000,0.12952697,0.26284254,,,,,,,,,,,,,, -1034,,,0.7235066550118583,0.2870752811431885,0.7026445499349324,0.3045320097623276,3554.0,0.7197285396231848,0.3066954731242146,3581.0,365.430878162384,385.5891439914704,365.430878162384,19.99083566665649,0.1238799095153808,0.0 -1100,0.50322896,0.26626754,,,,,,,,,,,,,, -1200,0.16637957,0.25427282,,,,,,,,,,,,,, -1300,0.13284822,0.3656017,,,,,,,,,,,,,, -1381,,,0.7308674539838519,0.2799033607755388,0.7087941596396665,0.2979709537229178,3554.0,0.7259140717720259,0.2998524793397444,3581.0,445.5725281238556,469.7762997150421,445.5725281238556,23.99786949157715,0.1478610038757324,0.0 -1400,0.25165415,0.27188602,,,,,,,,,,,,,, -1500,0.13117155,0.33635268,,,,,,,,,,,,,, -1600,0.15575157,0.19268724,,,,,,,,,,,,,, -1700,0.24059288,0.33701628,,,,,,,,,,,,,, -1727,,,0.7333673749651227,0.2783775159290859,0.7113109925655952,0.2964568902886712,3554.0,0.7284042243350322,0.298256906841228,3581.0,525.693502664566,553.9444513320923,525.693502664566,28.005931615829468,0.1723105907440185,0.0 -1800,0.10381038,0.27820826,,,,,,,,,,,,,, -1900,0.1611684,0.21781519,,,,,,,,,,,,,, -2000,0.08994849,0.33580065,,,,,,,,,,,,,, -2073,,,0.7364885466439384,0.275574905531747,0.7148465664787915,0.2935554706184932,3554.0,0.7319312075668458,0.2952709054187901,3581.0,605.7808966636658,638.0774807929993,605.7808966636658,32.01269316673279,0.1958217620849609,0.0 -2100,0.21240646,0.28997383,,,,,,,,,,,,,, -2200,0.4481415,0.2866003,,,,,,,,,,,,,, -2300,0.1572935,0.22425552,,,,,,,,,,,,,, -2400,0.15656275,0.28382713,,,,,,,,,,,,,, -2420,,,0.7353577613830566,0.2753708021981375,0.7139438506788126,0.2931602705578222,3554.0,0.7310254806181933,0.2949160118071069,3581.0,685.8993790149689,722.2499129772186,685.8993790149689,36.02655911445618,0.2213416099548339,0.0 -2500,0.15742415,0.25789192,,,,,,,,,,,,,, -2600,0.059832536,0.2531016,,,,,,,,,,,,,, -2700,0.10355433,0.29418573,,,,,,,,,,,,,, -2765,,,0.7376354762486049,0.2741092954363142,0.7155301464898706,0.292292039449388,3554.0,0.7326973768893117,0.2939401310737224,3581.0,765.968150138855,806.3747057914734,765.968150138855,40.04101777076721,0.2482287883758545,0.0 -2800,0.1964678,0.3510169,,,,,,,,,,,,,, -2900,0.1773262,0.22603571,,,,,,,,,,,,,, -3000,0.08115314,0.2601655,,,,,,,,,,,,,, -3100,0.21948422,0.2400251,,,,,,,,,,,,,, -3110,,,0.7347473417009626,0.2750832523618425,0.7142884915104459,0.2927851636690173,3554.0,0.7312285788929419,0.2943855292036093,3581.0,846.0964727401733,890.5473172664642,846.0964727401733,44.046605587005615,0.2724692821502685,0.0 -3200,0.13561998,0.27402082,,,,,,,,,,,,,, -3300,0.100827776,0.30675927,,,,,,,,,,,,,, -3400,0.081274934,0.37726882,,,,,,,,,,,,,, -3458,,,0.7395678247724261,0.2726844719478062,0.7174643117789814,0.2909971118040412,3554.0,0.7346316851307246,0.2926600801648457,3581.0,926.2872281074524,974.789454460144,926.2872281074524,48.05672740936279,0.2985961437225342,0.0 -3500,0.1629443,0.32436255,,,,,,,,,,,,,, -3600,0.16498853,0.2881899,,,,,,,,,,,,,, -3700,0.105972685,0.22544882,,,,,,,,,,,,,, -3793,,,0.7412997654506138,0.2718532766614641,0.7187269186128307,0.290306353206774,3554.0,0.7358856584665596,0.2918852183180326,3581.0,1003.8125767707824,1058.8378336429596,1003.8125767707824,52.06747007369995,2.796729803085327,0.0 -3800,0.0658441,0.25506678,,,,,,,,,,,,,, -3900,0.22062273,0.2791026,,,,,,,,,,,,,, -4000,0.27998105,0.22281809,,,,,,,,,,,,,, -4100,0.14826415,0.2948167,,,,,,,,,,,,,, -4139,,,0.7402080808367048,0.2731419290815081,0.7178762046285875,0.2915027040944182,3554.0,0.7350120427254957,0.2931617240405089,3581.0,1083.916685819626,1142.9926023483276,1083.916685819626,56.0785710811615,2.821539163589477,0.0 -4200,0.0505143,0.23660786,,,,,,,,,,,,,, -4300,0.28023562,0.31804404,,,,,,,,,,,,,, -4400,0.133203,0.2587689,,,,,,,,,,,,,, -4484,,,0.7428110667637416,0.2710321971348354,0.71978186165676,0.2897092597029755,3554.0,0.7369192848322745,0.2912875135262496,3581.0,1164.0630095005035,1227.1880266666412,1164.0630095005035,60.08868336677551,2.8460841178894043,0.0 -4500,0.17478141,0.26338005,,,,,,,,,,,,,, -4600,0.20538962,0.23357011,,,,,,,,,,,,,, -4700,0.10334098,0.25372523,,,,,,,,,,,,,, -4800,0.18212804,0.2661491,,,,,,,,,,,,,, -4831,,,0.7431929452078683,0.2707383973257882,0.7205519281751196,0.2893053354275112,3554.0,0.737703248263404,0.2908246621710067,3581.0,1244.148692369461,1311.323415517807,1244.148692369461,64.09915161132812,2.870636463165283,0.0 -4900,0.10414284,0.36851612,,,,,,,,,,,,,, -5000,0.08228541,0.29665118,,,,,,,,,,,,,, -5100,0.11773542,0.2402722,,,,,,,,,,,,,, -5175,,,0.7424209458487374,0.2704767840249197,0.7198196436893289,0.2888622552273846,3554.0,0.7370804544601718,0.2903081557831262,3581.0,1324.1467669010162,1395.3740241527555,1324.1467669010162,68.11032629013062,2.897249937057495,0.0 -5200,0.0447298,0.31177118,,,,,,,,,,,,,, -5300,0.15223892,0.19814645,,,,,,,,,,,,,, -5400,0.1445914,0.23647082,,,,,,,,,,,,,, -5500,0.25003827,0.2706975,,,,,,,,,,,,,, -5517,,,0.7446633747645787,0.2705959592546735,0.7215675779051772,0.2892542953362408,3554.0,0.7387610773439681,0.2907500087266126,3581.0,1404.2759974002838,1479.5537617206571,1404.2759974002838,72.12192797660828,2.9218204021453857,0.0 -5600,0.27883887,0.2274453,,,,,,,,,,,,,, -5700,0.13397337,0.29188132,,,,,,,,,,,,,, -5800,0.12288248,0.4359888,,,,,,,,,,,,,, -5863,,,0.7418193135942731,0.2705733605793544,0.7196740111274268,0.288847142414357,3554.0,0.737032253560458,0.2902668407262287,3581.0,1484.4351799488068,1563.7628009319303,1484.4351799488068,76.13256049156189,2.946476936340332,0.0 -5900,0.080931984,0.2798031,,,,,,,,,,,,,, -6000,0.19637188,0.23700649,,,,,,,,,,,,,, -6100,0.21609513,0.33553,,,,,,,,,,,,,, -6200,0.11678326,0.23252983,,,,,,,,,,,,,, -6210,,,0.745011465890067,0.2705646753311157,0.7222998623909679,0.289374133074089,3554.0,0.7394353445266685,0.2908501261540945,3581.0,1564.6058371067047,1647.987946987152,1564.6058371067047,80.14620423316956,2.9725465774536133,0.0 -6300,0.115628116,0.2817505,,,,,,,,,,,,,, -6400,0.17083305,0.2861244,,,,,,,,,,,,,, -6500,0.12691468,0.24930231,,,,,,,,,,,,,, -6553,,,0.7449780872889927,0.2694678476878575,0.7216078329435144,0.2884729972500175,3554.0,0.7388201865095294,0.2899777375929384,3581.0,1644.8221879005432,1732.2529256343842,1644.8221879005432,84.1552243232727,2.997762680053711,0.0 -6600,0.07415754,0.31032324,,,,,,,,,,,,,, -6700,0.14891696,0.25993422,,,,,,,,,,,,,, -6800,0.13444155,0.2769519,,,,,,,,,,,,,, -6899,,,0.7441498211451939,0.2704054968697684,0.722005643399163,0.2887626823979143,3554.0,0.7392115887234711,0.2901134432377478,3581.0,1724.981882095337,1816.464188098908,1724.981882095337,88.16625618934631,3.023563861846924,0.0 -6900,0.110454135,0.31374586,,,,,,,,,,,,,, -7000,0.3045167,0.22728078,,,,,,,,,,,,,, -7100,0.11343157,0.24132673,,,,,,,,,,,,,, -7200,0.16813363,0.2361075,,,,,,,,,,,,,, -7247,,,0.7448177337646484,0.2693194661821638,0.7221843180659117,0.2880321152772404,3554.0,0.7395222015934795,0.2893399790124965,3581.0,1805.1859049797056,1900.717324256897,1805.1859049797056,92.17454433441162,3.0495445728302,0.0 -7300,0.14688042,0.24696614,,,,,,,,,,,,,, -7400,0.15121162,0.2574913,,,,,,,,,,,,,, -7500,0.19695531,0.19041549,,,,,,,,,,,,,, -7588,,,0.7447967529296875,0.2691126380647932,0.7220263204751688,0.2878792010872432,3554.0,0.7393584412524434,0.2893281162733873,3581.0,1885.1832218170168,1984.7666964530945,1885.1832218170168,96.1854498386383,3.076326608657837,0.0 -7600,0.36932236,0.25377506,,,,,,,,,,,,,, -7700,0.16152297,0.24275066,,,,,,,,,,,,,, -7800,0.06791858,0.35064343,,,,,,,,,,,,,, -7900,0.11564387,0.29016098,,,,,,,,,,,,,, -7934,,,0.7456778798784528,0.2690371445247105,0.7227637570563098,0.2876706442674627,3554.0,0.7399581231674114,0.2891115871963138,3581.0,1965.366005659104,2069.000062465668,1965.366005659104,100.19572973251344,3.102079153060913,0.0 -8000,0.123067446,0.265065,,,,,,,,,,,,,, -8100,0.33715972,0.26118612,,,,,,,,,,,,,, -8200,0.075005986,0.32266766,,,,,,,,,,,,,, -8281,,,0.745875358581543,0.2691262619835989,0.7229084278937464,0.2879217230475345,3554.0,0.7401415865636345,0.2892600418768326,3581.0,2045.4119424819944,2153.09862613678,2045.4119424819944,104.20498085021973,3.1299993991851807,0.0 -8300,0.31310704,0.29663238,,,,,,,,,,,,,, -8400,0.16292892,0.23295836,,,,,,,,,,,,,, -8500,0.13153794,0.2765822,,,,,,,,,,,,,, -8600,0.20216204,0.19614862,,,,,,,,,,,,,, -8624,,,0.7451011112758091,0.2690394095012119,0.7221248972328714,0.2877977979807083,3554.0,0.739326602751501,0.2892616781167097,3581.0,2125.377302646637,2237.113946437836,2125.377302646637,108.21316409111024,3.1573917865753174,0.0 -8700,0.3582313,0.26848707,,,,,,,,,,,,,, -8800,0.122452065,0.25996846,,,,,,,,,,,,,, -8900,0.19363575,0.29238018,,,,,,,,,,,,,, -8969,,,0.7463887759617397,0.2680323464529855,0.7232508018034257,0.2870491298317037,3554.0,0.7405642136885646,0.2884087539924253,3581.0,2205.461377620697,2321.250861644745,2205.461377620697,112.22468495368958,3.184272289276123,0.0 -9000,0.26854452,0.3321055,,,,,,,,,,,,,, -9100,0.10950642,0.2835707,,,,,,,,,,,,,, -9200,0.38067544,0.2276018,,,,,,,,,,,,,, -9300,0.061346956,0.42695588,,,,,,,,,,,,,, -9314,,,0.7459329877580915,0.2686294998441423,0.7227771525042206,0.2875499993680096,3554.0,0.740029367778728,0.2889439066972389,3581.0,2285.5357501506805,2405.378289461136,2285.5357501506805,116.23565125465392,3.211730480194092,0.0 -9400,0.09034091,0.23434378,,,,,,,,,,,,,, -9500,0.078703634,0.26972395,,,,,,,,,,,,,, -9600,0.22396602,0.32400316,,,,,,,,,,,,,, -9658,,,0.746016229901995,0.2684325831277029,0.7231320288319499,0.2871666834739466,3554.0,0.7404210426993159,0.2885687646170762,3581.0,2365.571274280548,2489.4663603305817,2365.571274280548,120.24630069732666,3.239140033721924,0.0 -9700,0.19417037,0.3243738,,,,,,,,,,,,,, -9800,0.2679386,0.23135346,,,,,,,,,,,,,, -9900,0.19116345,0.24892169,,,,,,,,,,,,,, -10000,,,0.7468297140938895,0.2678298609597342,0.7232698302089196,0.2870614776868933,3554.0,0.7405252848148213,0.2884534437940694,3581.0,2445.607465028763,2573.5549368858337,2445.607465028763,124.25722122192384,3.26558256149292,0.0 -10000,0.10737297,0.30922168,,,,,,,,,,,,,, -10100,0.107263185,0.25357756,,,,,,,,,,,,,, -10200,0.32807153,0.36754638,,,,,,,,,,,,,, -10300,0.0914137,0.24205206,,,,,,,,,,,,,, -10347,,,0.746333122253418,0.2679857185908726,0.7232405663073298,0.2869289142735298,3554.0,0.7405346931941148,0.288304341435266,3581.0,2525.712516069412,2657.7135181427,2525.712516069412,128.26841831207275,3.2932775020599365,0.0 -10400,0.19574808,0.292877,,,,,,,,,,,,,, -10500,0.33130944,0.33418182,,,,,,,,,,,,,, -10600,0.11449234,0.29481605,,,,,,,,,,,,,, -10691,,,0.7459057399204799,0.2683102062770298,0.7228599981974536,0.2871290903515405,3554.0,0.7402006957291958,0.2885059057351298,3581.0,2605.689066648484,2741.7419786453247,2605.689066648484,132.2769341468811,3.3216447830200195,0.0 -10700,0.09018858,0.22927827,,,,,,,,,,,,,, -10800,0.2071347,0.3251693,,,,,,,,,,,,,, -10900,0.16561267,0.26079232,,,,,,,,,,,,,, -11000,0.07455951,0.26611903,,,,,,,,,,,,,, -11037,,,0.7467265129089355,0.2679037196295602,0.7233592018895962,0.2869203617952483,3554.0,0.7406535251151913,0.2883143974928441,3581.0,2685.8272409439087,2825.9370296001434,2685.8272409439087,136.29143595695496,3.349501371383667,0.0 -11100,0.10205932,0.28628284,,,,,,,,,,,,,, -11200,0.1538509,0.27629042,,,,,,,,,,,,,, -11300,0.086395375,0.33275974,,,,,,,,,,,,,, -11383,,,0.7467336654663086,0.2682475192206247,0.7235851384443585,0.287365915001143,3554.0,0.7408546462667551,0.2887903387671041,3581.0,2765.9728231430054,2910.1348733901978,2765.9728231430054,140.30236983299255,3.3763203620910645,0.0 -11400,0.08024561,0.304924,,,,,,,,,,,,,, -11500,0.10085827,0.26556963,,,,,,,,,,,,,, -11600,0.23541068,0.23665379,,,,,,,,,,,,,, -11700,0.104158975,0.2966473,,,,,,,,,,,,,, -11728,,,0.7459479059491839,0.2683146340506417,0.723168093499402,0.2869273342976224,3554.0,0.7404165430396538,0.2882078714591769,3581.0,2846.0457010269165,2994.259813785553,2846.0457010269165,144.31327152252197,3.4031455516815186,0.0 -11800,0.09477404,0.31333777,,,,,,,,,,,,,, -11900,0.3044729,0.2761524,,,,,,,,,,,,,, -12000,0.12547258,0.25475115,,,,,,,,,,,,,, -12072,,,0.7469075066702706,0.2674442529678345,0.7234120967351927,0.2866193076902785,3554.0,0.7406919085756423,0.2879914446470958,3581.0,2926.048833608628,3078.3154022693634,2926.048833608628,148.32446479797363,3.429948806762696,0.0 -12100,0.104249045,0.29234064,,,,,,,,,,,,,, -12200,0.10394856,0.25843507,,,,,,,,,,,,,, -12300,0.13019535,0.28580084,,,,,,,,,,,,,, -12400,0.2805933,0.25769624,,,,,,,,,,,,,, -12417,,,0.7473031452723912,0.2675682646887643,0.7239558832257668,0.2866649896023846,3554.0,0.7412790459848855,0.2879815590311714,3581.0,3006.178083181381,3162.5018265247345,3006.178083181381,152.33849143981934,3.458447933197021,0.0 -12500,0.11586354,0.29904795,,,,,,,,,,,,,, -12600,0.093753844,0.30059975,,,,,,,,,,,,,, -12700,0.32059032,0.27203655,,,,,,,,,,,,,, -12762,,,0.7473674501691546,0.2675473690032959,0.7242794348137662,0.286402455996984,3554.0,0.7415175279469771,0.2877508151201655,3581.0,3086.1382641792297,3246.5176918506622,3086.1382641792297,156.35280227661133,3.4851222038269043,0.0 -12800,0.2679424,0.2529104,,,,,,,,,,,,,, -12900,0.11717543,0.2088125,,,,,,,,,,,,,, -13000,0.097635,0.30156064,,,,,,,,,,,,,, -13100,0.19882368,0.25637856,,,,,,,,,,,,,, -13105,,,0.7421697889055524,0.2700824226651873,0.720902820215778,0.2887568777038196,3554.0,0.7376949307106954,0.2902166967916608,3581.0,3166.1795699596405,3330.6173605918884,3166.1795699596405,160.36889815330505,3.513031244277954,0.0 -13200,0.100060716,0.2298256,,,,,,,,,,,,,, -13300,0.1252943,0.3485277,,,,,,,,,,,,,, -13400,0.13245283,0.40599766,,,,,,,,,,,,,, -13452,,,0.7474530764988491,0.2673826728548322,0.7240228604653207,0.286526982141601,3554.0,0.7413308602476613,0.2877930505619938,3581.0,3246.334449291229,3414.818110704422,3246.334449291229,164.37267231941223,3.5402119159698486,0.0 -13500,0.15782167,0.3281097,,,,,,,,,,,,,, -13600,0.19275263,0.19758028,,,,,,,,,,,,,, -13700,0.16740574,0.31933212,,,,,,,,,,,,,, -13798,,,0.7462870052882603,0.268414991242545,0.7231948843952237,0.2871614970312939,3554.0,0.7404483133639347,0.2885023605487294,3581.0,3326.4810423851013,3499.016035795212,3326.4810423851013,168.38208889961243,3.567617654800415,0.0 -13800,0.29078376,0.26511052,,,,,,,,,,,,,, -13900,0.11855131,0.23484951,,,,,,,,,,,,,, -14000,0.10780757,0.21726434,,,,,,,,,,,,,, -14100,0.17023093,0.24431846,,,,,,,,,,,,,, -14141,,,0.7470973559788295,0.2679232358932495,0.7241787972179234,0.2867153255739572,3554.0,0.7413019533431653,0.2880931642261239,3581.0,3406.449381589889,3583.039057970047,3406.449381589889,172.3943543434143,3.595734119415283,0.0 -14200,0.114203386,0.35542998,,,,,,,,,,,,,, -14300,0.117034875,0.33050328,,,,,,,,,,,,,, -14400,0.12215927,0.28791296,,,,,,,,,,,,,, -14487,,,0.7464298520769391,0.2677939278738839,0.7232238048237901,0.286742957978686,3554.0,0.7403454347816601,0.2881711583269338,3581.0,3486.5816123485565,3667.222679138184,3486.5816123485565,176.40333008766174,3.623152017593384,0.0 -14500,0.14296825,0.34107152,,,,,,,,,,,,,, -14600,0.12592992,0.21196273,,,,,,,,,,,,,, -14700,0.13685928,0.26077658,,,,,,,,,,,,,, -14800,0.062188767,0.34026027,,,,,,,,,,,,,, -14832,,,0.7475502831595284,0.267444235937936,0.7244118780115715,0.286362716168182,3554.0,0.7416326101516685,0.2877265442286547,3581.0,3566.6675159931183,3751.367563962936,3566.6675159931183,180.41909766197205,3.652091026306152,0.0 -14900,0.20293924,0.35204574,,,,,,,,,,,,,, -15000,0.24266681,0.21899146,,,,,,,,,,,,,, -15100,0.11685601,0.19635344,,,,,,,,,,,,,, -15174,,,0.7479367256164551,0.2678801161902291,0.7252343585132949,0.2868504135140423,3554.0,0.742216406904496,0.2883633483358349,3581.0,3646.81110739708,3835.5609562397,3646.81110739708,184.4270977973938,3.679409980773926,0.0 -15200,0.193588,0.25747773,,,,,,,,,,,,,, -15300,0.2597506,0.24159577,,,,,,,,,,,,,, -15400,0.12661485,0.2742805,,,,,,,,,,,,,, -15500,0.14495146,0.25578263,,,,,,,,,,,,,, -15520,,,0.7471392495291573,0.2672254358019147,0.7236271108478123,0.2866708114701304,3554.0,0.7409330494275342,0.2879875585773876,3581.0,3726.9752271175385,3919.774846792221,3726.9752271175385,188.43514847755432,3.70670485496521,0.0 -15600,0.10829298,0.27544987,,,,,,,,,,,,,, -15700,0.23102704,0.20776966,,,,,,,,,,,,,, -15800,0.10379294,0.2494108,,,,,,,,,,,,,, -15866,,,0.748380184173584,0.2670131070273263,0.724918844194042,0.2861659576467537,3554.0,0.7422716981770107,0.2874218968165317,3581.0,3806.976819515228,4003.8309524059296,3806.976819515228,192.44737243652344,3.734318971633911,0.0 -15900,0.320158,0.20631102,,,,,,,,,,,,,, -16000,0.2198444,0.23462254,,,,,,,,,,,,,, -16100,0.21040629,0.36575517,,,,,,,,,,,,,, -16200,0.18173924,0.22633928,,,,,,,,,,,,,, -16209,,,0.7478961263384137,0.2671475410461426,0.7246443405537775,0.2861368311343732,3554.0,0.7419779931190659,0.2874222036115086,3581.0,3886.973289012909,4087.879930019378,3886.973289012909,196.4577419757843,3.761636972427368,0.0 -16300,0.11281878,0.24424969,,,,,,,,,,,,,, -16400,0.21360327,0.23180121,,,,,,,,,,,,,, -16500,0.096113525,0.27499405,,,,,,,,,,,,,, -16556,,,0.748145307813372,0.2665354013442993,0.7245252241092783,0.2858139321451182,3554.0,0.7418374128429559,0.2871162608428163,3581.0,3967.0818524360657,4172.041687011719,3967.0818524360657,200.4681260585785,3.789839267730713,0.0 -16600,0.183891,0.23364902,,,,,,,,,,,,,, -16700,0.24722579,0.32089162,,,,,,,,,,,,,, -16800,0.26688462,0.2963,,,,,,,,,,,,,, -16900,0.0891093,0.33063114,,,,,,,,,,,,,, -16903,,,0.7478715351649693,0.2666094814028059,0.724559914884637,0.2857410128222601,3554.0,0.7418483892854649,0.2870625035451864,3581.0,4047.131493330002,4256.145545244217,4047.131493330002,204.47909569740293,3.817901611328125,0.0 -17000,0.4985463,0.22526217,,,,,,,,,,,,,, -17100,0.19349693,0.31122902,,,,,,,,,,,,,, -17200,0.14006262,0.25713915,,,,,,,,,,,,,, -17247,,,0.7482375417436872,0.2673543861934117,0.7250987553636747,0.2863902798783061,3554.0,0.7423234442631248,0.2876970578225356,3581.0,4127.304555654526,4340.377782583237,4127.304555654526,208.4960463047028,3.8455328941345215,0.0 -17300,0.2220793,0.25439996,,,,,,,,,,,,,, -17400,0.1789679,0.2332861,,,,,,,,,,,,,, -17500,0.25820807,0.25133502,,,,,,,,,,,,,, -17592,,,0.748047011239188,0.2666993652071271,0.7240940280757597,0.2863386730565472,3554.0,0.7414175809611491,0.2876395508085206,3581.0,4207.318914651871,4424.445866823196,4207.318914651871,212.5071542263031,3.8736438751220703,0.0 -17600,0.18178326,0.30193043,,,,,,,,,,,,,, -17700,0.15011993,0.26136133,,,,,,,,,,,,,, -17800,0.18272103,0.28616938,,,,,,,,,,,,,, -17900,0.15618291,0.20692194,,,,,,,,,,,,,, -17937,,,0.7487597465515137,0.2666618824005127,0.7253592453045864,0.2859063950830051,3554.0,0.7426705998237224,0.2872101741940973,3581.0,4287.361625432968,4508.540894031525,4287.361625432968,216.51686763763428,3.901830911636353,0.0 -18000,0.156664,0.31392437,,,,,,,,,,,,,, -18100,0.2612757,0.26890498,,,,,,,,,,,,,, -18200,0.08412183,0.21346137,,,,,,,,,,,,,, -18283,,,0.7487377439226423,0.2666053431374686,0.7254902459156936,0.2856880149347566,3554.0,0.7427411626684236,0.2869363426329935,3581.0,4367.358929157257,4592.592378616333,4367.358929157257,220.5280795097351,3.9302728176116943,0.0 -18300,0.20695944,0.25581676,,,,,,,,,,,,,, -18400,0.1332053,0.22045204,,,,,,,,,,,,,, -18500,0.099108264,0.336829,,,,,,,,,,,,,, -18600,0.13168895,0.25814015,,,,,,,,,,,,,, -18626,,,0.7493752070835659,0.2661102328981672,0.7254873607422974,0.2856842539051509,3554.0,0.7427425262016546,0.2870357101171984,3581.0,4447.448851585388,4676.735166788101,4447.448851585388,224.53717684745789,3.959471225738525,0.0 -18700,0.22400458,0.28920448,,,,,,,,,,,,,, -18800,0.24729884,0.2622141,,,,,,,,,,,,,, -18900,0.11342347,0.24405476,,,,,,,,,,,,,, -18972,,,0.7464577811104911,0.2676242760249546,0.7239104760920794,0.2866407060596335,3554.0,0.7408810306347738,0.2880688592462824,3581.0,4527.552319765091,4760.894608259201,4527.552319765091,228.54846930503845,3.989834785461426,0.0 -19000,0.1191656,0.28044412,,,,,,,,,,,,,, -19100,0.11206551,0.27397862,,,,,,,,,,,,,, -19200,0.1793488,0.23212272,,,,,,,,,,,,,, -19300,0.5400874,0.25769204,,,,,,,,,,,,,, -19318,,,0.7489397185189384,0.266740185873849,0.7255840140510692,0.2859350407331528,3554.0,0.7427173008368821,0.2872599090686959,3581.0,4607.723635435104,4845.123031139374,4607.723635435104,232.5619733333588,4.01876974105835,0.0 -19400,0.29781052,0.24489781,,,,,,,,,,,,,, -19500,0.8636816,0.22960931,,,,,,,,,,,,,, -19600,0.13067433,0.34349582,,,,,,,,,,,,,, -19663,,,0.7484968049185616,0.2659630945750645,0.7245838893016672,0.2855817787286332,3554.0,0.7418566386615122,0.2868840170452562,3581.0,4687.796002864838,4929.245357036591,4687.796002864838,236.56711435318,4.048708200454712,0.0 -19700,0.21729776,0.20273411,,,,,,,,,,,,,, -19800,0.112342015,0.21936911,,,,,,,,,,,,,, -19900,0.09588619,0.31456137,,,,,,,,,,,,,, -20000,0.3574423,0.24239188,,,,,,,,,,,,,, -20008,,,0.7490484373910087,0.2663997071129935,0.725382051913337,0.2857362041999332,3554.0,0.7426997794348645,0.2869510687918877,3581.0,4767.805709838867,5013.3049693107605,4767.805709838867,240.57340145111084,4.0773022174835205,0.0 -20100,0.16321418,0.38477704,,,,,,,,,,,,,, -20200,0.1706354,0.3598729,,,,,,,,,,,,,, -20300,0.18422405,0.26978707,,,,,,,,,,,,,, -20354,,,0.7480840001787458,0.2667103154318673,0.7245500228615644,0.2860161518876178,3554.0,0.7417742812543633,0.2873005082706472,3581.0,4847.913654327393,5097.467324256897,4847.913654327393,244.5843369960785,4.106105089187622,0.0 -20400,0.16263144,0.2354979,,,,,,,,,,,,,, -20500,0.3036992,0.23320502,,,,,,,,,,,,,, -20600,0.2135711,0.36511225,,,,,,,,,,,,,, -20697,,,0.7485865865434919,0.2661643028259277,0.7251484215628518,0.28558687930303,3554.0,0.742291810292167,0.2869389333461323,3581.0,4927.873211860657,5181.477455615997,4927.873211860657,248.5912554264069,4.135280132293701,0.0 -20700,0.14595181,0.25300115,,,,,,,,,,,,,, -20800,0.14131396,0.23597269,,,,,,,,,,,,,, -20900,0.12019162,0.30512193,,,,,,,,,,,,,, -21000,0.11334468,0.30996096,,,,,,,,,,,,,, -21042,,,0.7484387670244489,0.2663168055670602,0.7248215039392234,0.285725024153023,3554.0,0.742089802844003,0.2870958760210136,3581.0,5007.937152862549,5265.59605717659,5007.937152862549,252.6009838581085,4.1652820110321045,0.0 -21100,0.084834024,0.27013993,,,,,,,,,,,,,, -21200,0.19851455,0.22033644,,,,,,,,,,,,,, -21300,0.11668342,0.273114,,,,,,,,,,,,,, -21388,,,0.7481115886143276,0.2667201246534075,0.7249347126477209,0.2858996458380962,3554.0,0.7421052107695127,0.2872167873302674,3581.0,5088.057301521301,5349.772326469421,5088.057301521301,256.6096124649048,4.198110818862915,0.0 -21400,0.11730545,0.25770974,,,,,,,,,,,,,, -21500,0.3777059,0.21556088,,,,,,,,,,,,,, -21600,0.13610296,0.22334614,,,,,,,,,,,,,, -21700,0.1624791,0.30207258,,,,,,,,,,,,,, -21732,,,0.7490191459655762,0.2662261894771031,0.7253119834165729,0.2857492218275183,3554.0,0.7425255198879503,0.2870578334438704,3581.0,5168.070353746414,5433.842468261719,5168.070353746414,260.62133407592773,4.228944778442383,0.0 -21800,0.28955325,0.24692145,,,,,,,,,,,,,, -21900,0.07805652,0.34767976,,,,,,,,,,,,,, -22000,0.17949161,0.28501135,,,,,,,,,,,,,, -22079,,,0.749504634312221,0.265928966658456,0.7256125223120076,0.2855818130759355,3554.0,0.7428523588034068,0.2868992204407812,3581.0,5248.222505569458,5518.048063755035,5248.222505569458,264.62954115867615,4.259636402130127,0.0 -22100,0.089947574,0.26990128,,,,,,,,,,,,,, -22200,0.06582751,0.2814898,,,,,,,,,,,,,, -22300,0.15339968,0.3278006,,,,,,,,,,,,,, -22400,0.082788624,0.25321501,,,,,,,,,,,,,, -22423,,,0.749286447252546,0.2660556520734514,0.7258175757069499,0.2854290019278453,3554.0,0.7429833943469003,0.2867292901118752,3581.0,5328.270455598831,5602.152354717255,5328.270455598831,268.64070224761963,4.2899558544158936,0.0 -22500,0.05556611,0.3091541,,,,,,,,,,,,,, -22600,0.11529505,0.22384597,,,,,,,,,,,,,, -22700,0.20950253,0.2123117,,,,,,,,,,,,,, -22767,,,0.7490869930812291,0.2664235830307007,0.725718861560038,0.2856929265989906,3554.0,0.7428747207483943,0.2871282258469178,3581.0,5408.337042331696,5686.270875692368,5408.337042331696,272.6491093635559,4.319113254547119,0.0 -22800,0.14630352,0.23275721,,,,,,,,,,,,,, -22900,0.17703485,0.21019135,,,,,,,,,,,,,, -23000,0.077035315,0.27557212,,,,,,,,,,,,,, -23100,0.123870105,0.26685128,,,,,,,,,,,,,, -23111,,,0.7498129435947963,0.2655370916639055,0.7259281053258653,0.285219534904553,3554.0,0.743165426033231,0.2865338276232198,3581.0,5488.437965869904,5770.424674987793,5488.437965869904,276.6583559513092,4.348245620727539,0.0 -23200,0.10670986,0.3039124,,,,,,,,,,,,,, -23300,0.25556076,0.24038266,,,,,,,,,,,,,, -23400,0.12406866,0.2932509,,,,,,,,,,,,,, -23459,,,0.7490844045366559,0.2656686987195696,0.7254319241963281,0.2851312795112021,3554.0,0.742678781023108,0.2864100188058503,3581.0,5568.480888128281,5854.519265413284,5568.480888128281,280.66591787338257,4.377592325210571,0.0 -23500,0.08958929,0.24224219,,,,,,,,,,,,,, -23600,0.114573464,0.28356415,,,,,,,,,,,,,, -23700,0.0840198,0.33542177,,,,,,,,,,,,,, -23800,0.12852702,0.21107307,,,,,,,,,,,,,, -23804,,,0.7493758201599121,0.2656446354729788,0.7255900591762803,0.285176532082029,3554.0,0.7428889014939961,0.2864074621810423,3581.0,5648.613835811615,5938.706800699234,5648.613835811615,284.6766347885132,4.406728506088257,0.0 -23900,0.10152931,0.28435576,,,,,,,,,,,,,, -24000,0.13676634,0.18699889,,,,,,,,,,,,,, -24100,0.25441065,0.29285586,,,,,,,,,,,,,, -24148,,,0.7498668261936733,0.2651958976473127,0.7257230519309229,0.2851227613802229,3554.0,0.7429643048816671,0.2864313581009145,3581.0,5728.598200559616,6022.743552207947,5728.598200559616,288.6837303638458,4.437592029571533,0.0 -24200,0.13995601,0.2699609,,,,,,,,,,,,,, -24300,0.15370849,0.24823152,,,,,,,,,,,,,, -24400,0.21190411,0.283138,,,,,,,,,,,,,, -24494,,,0.7496851512363979,0.2655548027583531,0.7258203921857415,0.2852120128453415,3554.0,0.7430404582126152,0.2865370319263124,3581.0,5808.61803650856,6106.814738750458,5808.61803650856,292.6920075416565,4.466193199157715,0.0 -24500,0.24519305,0.3370556,,,,,,,,,,,,,, -24600,0.2515455,0.22201246,,,,,,,,,,,,,, -24700,0.13978836,0.27165475,,,,,,,,,,,,,, -24800,0.1465586,0.27275497,,,,,,,,,,,,,, -24838,,,0.748795781816755,0.2659956046513149,0.7251729455367192,0.2854291908380082,3554.0,0.7424380492311854,0.2867215520607896,3581.0,5888.69374203682,6190.94571518898,5888.69374203682,296.69993567466736,4.499108552932739,0.0 -24900,0.099361464,0.27558535,,,,,,,,,,,,,, -25000,0.10691447,0.2938633,,,,,,,,,,,,,, -25100,0.09202402,0.3338361,,,,,,,,,,,,,, -25185,,,0.7506846700395856,0.265080486025129,0.7263689872986424,0.2851709162980972,3554.0,0.7436003249572396,0.2864891037332449,3581.0,5968.84553027153,6275.150093793869,5968.84553027153,300.70812129974365,4.529009819030762,0.0 -25200,0.09248675,0.28952765,,,,,,,,,,,,,, -25300,0.11944865,0.28211665,,,,,,,,,,,,,, -25400,0.25668725,0.21562642,,,,,,,,,,,,,, -25500,0.114221595,0.25897783,,,,,,,,,,,,,, -25531,,,0.7494222096034459,0.2652970893042428,0.7254723853184791,0.2849968785171637,3554.0,0.742812748163048,0.2862789150856953,3581.0,6048.872050285339,6359.229727506638,6048.872050285339,304.7168712615967,4.558919191360474,0.0 -25600,0.14245977,0.34903368,,,,,,,,,,,,,, -25700,0.11616817,0.2410528,,,,,,,,,,,,,, -25800,0.06295588,0.25390688,,,,,,,,,,,,,, -25876,,,0.7495801789419991,0.2653527430125645,0.7257883804999649,0.2849699502321328,3554.0,0.7430234140472284,0.2862819148588034,3581.0,6128.933758020401,6443.341115951538,6128.933758020401,308.72206234931946,4.588968515396118,0.0 -25900,0.13627258,0.21472052,,,,,,,,,,,,,, -26000,0.14770417,0.32377344,,,,,,,,,,,,,, -26100,0.08450182,0.29296905,,,,,,,,,,,,,, -26200,0.07683172,0.28062025,,,,,,,,,,,,,, -26222,,,0.7504499980381557,0.2648673909051077,0.7260060737021665,0.2850825063418859,3554.0,0.74321253610636,0.2864005763382261,3581.0,6209.109485387802,6527.5736310482025,6209.109485387802,312.73252749443054,4.620581388473511,0.0 -26300,0.13792159,0.2575318,,,,,,,,,,,,,, -26400,0.081336826,0.4076262,,,,,,,,,,,,,, -26500,0.20714156,0.20702918,,,,,,,,,,,,,, -26568,,,0.7503705705915179,0.2652456249509539,0.7262834625158272,0.2850904405687254,3554.0,0.7435167403701829,0.2864089279792656,3581.0,6289.2570996284485,6611.773455381393,6289.2570996284485,316.7389891147613,4.651438236236572,0.0 -26600,0.18871842,0.30970073,,,,,,,,,,,,,, -26700,0.11443943,0.32657692,,,,,,,,,,,,,, -26800,0.10678103,0.26020348,,,,,,,,,,,,,, -26900,0.11281224,0.22959356,,,,,,,,,,,,,, -26915,,,0.7497830390930176,0.2651932580130441,0.7259059856631612,0.2848746021208497,3554.0,0.7430864774591595,0.2861802634564367,3581.0,6369.344420433044,6695.914905786514,6369.344420433044,320.7466149330139,4.683363676071167,0.0 -27000,0.118194416,0.2218634,,,,,,,,,,,,,, -27100,0.30284166,0.28297707,,,,,,,,,,,,,, -27200,0.10415873,0.3543284,,,,,,,,,,,,,, -27255,,,0.7505489758082798,0.2648061513900757,0.7262866224676421,0.2849321338522615,3554.0,0.7434242928171251,0.2862984136108978,3581.0,6449.312697172165,6779.940864801407,6449.312697172165,324.75905179977417,4.714396953582764,0.0 -27300,0.124631494,0.22068313,,,,,,,,,,,,,, -27400,0.12898232,0.25911146,,,,,,,,,,,,,, -27500,0.10665877,0.33639586,,,,,,,,,,,,,, -27600,,,0.750410965510777,0.2649491514478411,0.7263029030889491,0.2848861943353879,3554.0,0.74350603663432,0.2861646850892732,3581.0,6529.513382673264,6864.195533752441,6529.513382673264,328.7665092945099,4.74639892578125,0.0 -27600,0.07855577,0.29877764,,,,,,,,,,,,,, -27700,0.10831075,0.36969784,,,,,,,,,,,,,, -27800,0.055755343,0.31232116,,,,,,,,,,,,,, -27900,0.08407621,0.3044426,,,,,,,,,,,,,, -27946,,,0.7503207751682827,0.2650993721825735,0.7263553170723129,0.2849315842954241,3554.0,0.7435222626797682,0.2862289075044505,3581.0,6609.545207738876,6948.283418178558,6609.545207738876,332.7754480838776,4.778889656066895,0.0 -28000,0.14448087,0.21095353,,,,,,,,,,,,,, -28100,0.14388646,0.21849178,,,,,,,,,,,,,, -28200,0.13273694,0.36752653,,,,,,,,,,,,,, -28290,,,0.7504880087716239,0.2648235048566545,0.7263116273037422,0.2849101515787668,3554.0,0.7434201340407708,0.2862642571034627,3581.0,6689.76785159111,7032.563189744949,6689.76785159111,336.7862060070038,4.811024188995361,0.0 -28300,0.1300694,0.26380166,,,,,,,,,,,,,, -28400,0.1355474,0.22883672,,,,,,,,,,,,,, -28500,0.09271958,0.29517913,,,,,,,,,,,,,, -28600,0.075543694,0.31206626,,,,,,,,,,,,,, -28637,,,0.7506187983921596,0.2649060317448207,0.7264817151449071,0.2849294547626794,3554.0,0.7436845231342503,0.2862636776018396,3581.0,6769.840493440628,7116.692368984222,6769.840493440628,340.7982308864593,4.841078758239746,0.0 -28700,0.056128692,0.29374015,,,,,,,,,,,,,, -28800,0.18947327,0.26947626,,,,,,,,,,,,,, -28900,0.07169009,0.30325168,,,,,,,,,,,,,, -28982,,,0.7507256780351911,0.2649570022310529,0.7265615382755346,0.2849720625912264,3554.0,0.7437460866596272,0.2862514739794226,3581.0,6849.891664505005,7200.795670986176,6849.891664505005,344.80379843711853,4.873269319534302,0.0 -29000,0.11685075,0.22530116,,,,,,,,,,,,,, -29100,0.14643754,0.18489885,,,,,,,,,,,,,, -29200,0.1346478,0.18837172,,,,,,,,,,,,,, -29300,0.060803257,0.29089046,,,,,,,,,,,,,, -29323,,,0.750253541128976,0.2648698602403913,0.7261388603729952,0.2849260543797481,3554.0,0.7433762282707345,0.2862659956083321,3581.0,6929.979115247726,7284.941651105881,6929.979115247726,348.8160009384155,4.905006408691406,0.0 -29400,0.16077127,0.27930707,,,,,,,,,,,,,, -29500,0.12318245,0.21306622,,,,,,,,,,,,,, -29600,0.15041202,0.25401458,,,,,,,,,,,,,, -29668,,,0.7505779947553363,0.264641353062221,0.7261684677476083,0.2848554878471001,3554.0,0.7433797734571349,0.2861837404661757,3581.0,7010.161997318268,7369.177745580673,7010.161997318268,352.8241741657257,4.935630083084106,0.0 -29700,0.097821675,0.30548045,,,,,,,,,,,,,, -29800,0.09199255,0.28551543,,,,,,,,,,,,,, -29900,0.10681657,0.2421251,,,,,,,,,,,,,, -30000,0.0750126,0.29543698,,,,,,,,,,,,,, -30014,,,0.7503046308244977,0.2647990328924997,0.7261086347469401,0.2848229781254396,3554.0,0.7432790083513683,0.2861041101254887,3581.0,7090.252366781235,7453.321771860123,7090.252366781235,356.831734418869,4.967212677001953,0.0 -30100,0.08102454,0.23588632,,,,,,,,,,,,,, -30200,0.11431559,0.2784277,,,,,,,,,,,,,, -30300,0.10567957,0.21874556,,,,,,,,,,,,,, -30357,,,0.7509540830339704,0.2650235380445208,0.7267261992429305,0.2851525576649989,3554.0,0.7438623960442264,0.2865039321571314,3581.0,7170.308975458145,7537.433724164963,7170.308975458145,360.8413755893707,4.998633146286011,0.0 -30400,0.16817991,0.25114143,,,,,,,,,,,,,, -30500,0.08860607,0.22513062,,,,,,,,,,,,,, -30600,0.06678559,0.32206947,,,,,,,,,,,,,, -30700,0.06609531,0.29046157,,,,,,,,,,,,,, -30701,,,0.7509610312325614,0.264449988092695,0.7264982705446328,0.2848017686662475,3554.0,0.7436667972022479,0.2861409255227241,3581.0,7250.800181627274,7621.926305532455,7250.800181627274,364.797967672348,5.029951572418213,0.0 -30800,0.14246574,0.3558893,,,,,,,,,,,,,, -30900,0.087057404,0.20984904,,,,,,,,,,,,,, -31000,0.059917398,0.27536705,,,,,,,,,,,,,, -31047,,,0.7511175700596401,0.2646209682737078,0.7267902226144837,0.2847992269658747,3554.0,0.7439442762147445,0.2861016898540037,3581.0,7330.806107759476,7705.99044585228,7330.806107759476,368.8075482845306,5.063923120498657,0.0 -31100,0.12874512,0.19981019,,,,,,,,,,,,,, -31200,0.083921246,0.21519877,,,,,,,,,,,,,, -31300,0.16084492,0.33862773,,,,,,,,,,,,,, -31390,,,0.7512401853288923,0.2646078722817557,0.7270120374929657,0.2847450784437429,3554.0,0.7441750542140813,0.2860398877103113,3581.0,7410.819982767105,7790.061582565308,7410.819982767105,372.81839632987976,5.095531225204468,0.0 -31400,0.13974907,0.21941036,,,,,,,,,,,,,, -31500,0.12852445,0.2608719,,,,,,,,,,,,,, -31600,0.09863351,0.2184745,,,,,,,,,,,,,, -31700,0.078849725,0.23523735,,,,,,,,,,,,,, -31735,,,0.7512694086347308,0.2642558813095093,0.7266450709148143,0.284735255115275,3554.0,0.7438293303633762,0.2860473871430815,3581.0,7490.946867465973,7874.244492769241,7490.946867465973,376.8252596855164,5.130434513092041,0.0 -31800,0.23169385,0.21829052,,,,,,,,,,,,,, -31900,0.10058998,0.30495661,,,,,,,,,,,,,, -32000,0.09423399,0.19014792,,,,,,,,,,,,,, -32079,,,0.7509448868887765,0.2644948278154646,0.7264786238876969,0.2848054438275974,3554.0,0.743691477153728,0.2861123254132051,3581.0,7571.110435009003,7958.465414047241,7571.110435009003,380.8364975452423,5.161880731582642,0.0 -32100,0.1066134,0.32657838,,,,,,,,,,,,,, -32200,0.107812226,0.3215056,,,,,,,,,,,,,, -32300,0.08474012,0.2949779,,,,,,,,,,,,,, -32400,0.06596663,0.26835495,,,,,,,,,,,,,, -32422,,,0.7511148452758789,0.2644805908203125,0.7268427739870569,0.2846810378985386,3554.0,0.7440023627303826,0.2859859599710276,3581.0,7651.189124345779,8042.596368312836,7651.189124345779,384.8424386978149,5.193868160247803,0.0 -32500,0.0869715,0.24936128,,,,,,,,,,,,,, -32600,0.12805475,0.23364584,,,,,,,,,,,,,, -32700,0.10083702,0.24698406,,,,,,,,,,,,,, -32767,,,0.7514227458408901,0.2641618422099522,0.7267701637899198,0.2846828067846089,3554.0,0.7438860533457833,0.2860509664178127,3581.0,7731.246908187866,8126.714017152786,7731.246908187866,388.8557326793671,5.225993633270264,0.0 -32800,0.059453342,0.2321873,,,,,,,,,,,,,, -32900,0.078406245,0.21365218,,,,,,,,,,,,,, -33000,0.07116738,0.21387267,,,,,,,,,,,,,, -33100,0.049601924,0.23410012,,,,,,,,,,,,,, -33112,,,0.7512646402631488,0.2643602745873587,0.7267347860685144,0.2847045829742895,3554.0,0.743903847454447,0.2860263205546635,3581.0,7811.215454101562,8210.744921445847,7811.215454101562,392.8660671710968,5.263420820236206,0.0 -33200,0.06729653,0.28849226,,,,,,,,,,,,,, -33300,0.11081559,0.3216474,,,,,,,,,,,,,, -33400,0.058062293,0.26383787,,,,,,,,,,,,,, -33459,,,0.7510176386151995,0.2643741539546421,0.7265887413389842,0.2846880104009127,3554.0,0.7437773797472773,0.2859879711825432,3581.0,7891.278911113739,8294.86878991127,7891.278911113739,396.87922167778015,5.295862913131714,0.0 -33500,0.05928546,0.22216599,,,,,,,,,,,,,, -33600,0.05376747,0.2979876,,,,,,,,,,,,,, -33700,0.08463442,0.24869639,,,,,,,,,,,,,, -33800,0.06861097,0.23817044,,,,,,,,,,,,,, -33805,,,0.7513986315046038,0.2641398906707763,0.7267900165306697,0.2846557926313221,3554.0,0.7439129149504329,0.2859865394726508,3581.0,7971.397326469421,8379.038766145706,7971.397326469421,400.88340163230896,5.328710556030273,0.0 -33900,0.068383254,0.28575647,,,,,,,,,,,,,, -34000,0.09506289,0.397062,,,,,,,,,,,,,, -34100,0.07973934,0.21696644,,,,,,,,,,,,,, -34151,,,0.7514948163713727,0.2642030886241368,0.7269228718961029,0.2846536802722285,3554.0,0.7440828793676696,0.2859658137675405,3581.0,8051.56923866272,8463.266416311264,8051.56923866272,404.8903458118439,5.362544059753418,0.0 -34200,0.087377146,0.3111246,,,,,,,,,,,,,, -34300,0.06442911,0.23496847,,,,,,,,,,,,,, -34400,0.09489902,0.20653374,,,,,,,,,,,,,, -34497,,,0.751321724482945,0.2642321756907871,0.7267886426385762,0.2846361974953397,3554.0,0.7439314590023737,0.2859455141665648,3581.0,8131.788681983948,8547.546661138535,8131.788681983948,408.90083360672,5.398136138916016,0.0 -34500,0.06924,0.2387158,,,,,,,,,,,,,, -34600,0.06067658,0.2424585,,,,,,,,,,,,,, -34700,0.29905677,0.34930354,,,,,,,,,,,,,, -34800,0.066308014,0.31405404,,,,,,,,,,,,,, -34841,,,0.7512396403721401,0.2641234397888183,0.7266201347733188,0.2845885577870005,3554.0,0.7437752662707693,0.2859060398795291,3581.0,8211.779695034027,8631.593449831009,8211.779695034027,412.909912109375,5.430059194564819,0.0 -34900,0.059870157,0.17856152,,,,,,,,,,,,,, -35000,0.05290939,0.24003756,,,,,,,,,,,,,, -35100,0.09194636,0.22000015,,,,,,,,,,,,,, -35188,,,0.7515051705496651,0.2641071251460484,0.7268994470359103,0.2845823409252778,3554.0,0.7440539724631737,0.2859155675679803,3581.0,8291.951369524002,8715.825322628021,8291.951369524002,416.91813564300537,5.466964721679688,0.0 -35200,0.057521332,0.21689603,,,,,,,,,,,,,, -35300,0.07623596,0.28961277,,,,,,,,,,,,,, -35400,0.06523494,0.24569917,,,,,,,,,,,,,, -35500,0.04413355,0.31853712,,,,,,,,,,,,,, -35534,,,0.751544748033796,0.2641344581331525,0.7269539218574141,0.2846019704085625,3554.0,0.7441068093758727,0.285922385234135,3581.0,8372.11243891716,8800.044318675995,8372.11243891716,420.9276645183563,5.500535249710083,0.0 -35600,0.08743892,0.20174131,,,,,,,,,,,,,, -35700,0.06909921,0.20884772,,,,,,,,,,,,,, -35800,0.11066647,0.23480292,,,,,,,,,,,,,, -35877,,,0.7513072150094169,0.2641035829271589,0.7267433728940982,0.2845437860784063,3554.0,0.7439154374869101,0.2858525723327108,3581.0,8452.082240343094,8884.068562984467,8452.082240343094,424.93412804603577,5.533951044082642,0.0 -35900,0.08718354,0.19019195,,,,,,,,,,,,,, -36000,0.05159236,0.30177307,,,,,,,,,,,,,, -36100,0.06667509,0.2442748,,,,,,,,,,,,,, -36189,,,0.7515296254839215,0.264105030468532,0.7269513801570414,0.2845614921127603,3554.0,0.7441089228523806,0.285875275161006,3581.0,8523.938553333282,8959.984071493149,8523.938553333282,428.94553327560425,5.568581819534302,0.0 -36189,,,,,,,,,,,8523.938553333282,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/eval_measurements.csv deleted file mode 100644 index b8578cca8..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,107 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -3.9573137760162354,0.0,31.401532649993896,1,0,31.401532649993896,1.0323986403937448,3581,0.2128455943150481,35.358983516693115,1.0327176366533553,0.1975462777273995,1.0370871866426914,3554,0.1897597777152504 -7.96499228477478,0.017967939376831,111.60064959526062,342,0,111.60064959526062,0.3123536587959892,3581,0.7183853912140463,119.59745144844057,0.2925906862531389,0.7213667460850307,0.3103522631849676,3554,0.7017441697515123 -11.974652290344238,0.0440468788146972,191.57904171943665,682,0,191.57904171943665,0.302153373490296,3581,0.7247404105434935,203.62491083145144,0.2831108910696847,0.7287031582423619,0.3006681786807119,3554,0.7073842715734032 -15.98682451248169,0.0700306892395019,271.53961205482483,1019,0,271.53961205482483,0.2996568464094352,3581,0.7234966637068906,287.6374468803406,0.280535272189549,0.7273684910365513,0.2981226657573332,3554,0.7060646482176772 -19.99767827987671,0.0958211421966552,351.7352349758148,1366,0,351.7352349758148,0.2950062436186644,3581,0.7332023614213907,371.88407588005066,0.2756518295833042,0.7377146993364606,0.2935038809703855,3554,0.715898143487092 -24.00688052177429,0.1219170093536377,431.7387228012085,1710,0,431.7387228012085,0.2947562738891022,3581,0.7352617738367425,455.9366509914398,0.2755191666739328,0.7396605355398995,0.2932262517256084,3554,0.7181294816360088 -28.0211033821106,0.146493911743164,511.8938019275665,2054,0,511.8938019275665,0.2953716023478951,3581,0.73088387769216,540.144252538681,0.2752586773463658,0.736421925680978,0.2937887574959553,3554,0.7134809177379361 -31.982479333877563,0.1711635589599609,591.8965113162994,2401,0,591.8965113162994,0.2993610619633133,3581,0.7316117317308364,624.146678686142,0.2788510152271816,0.7367238317217145,0.297630159789146,3554,0.7148442308622327 -35.997042179107666,0.1950883865356445,671.9206922054291,2748,0,671.9206922054291,0.2926768857119171,3581,0.7344828554785674,708.2232940196991,0.2730081421988351,0.7392657143729073,0.2913048636329663,3554,0.7169806330674944 -40.01491498947144,0.218900442123413,751.973610162735,3094,0,751.973610162735,0.2927859342820615,3581,0.7372292159356674,792.3318250179291,0.2729813711983817,0.7426317759922573,0.2911620818971581,3554,0.7201212130038337 -44.0278480052948,0.2433154582977295,832.1201596260071,3441,0,832.1201596260071,0.2922184317513439,3581,0.7372187167297891,876.5295839309692,0.2722130332674299,0.742382322038923,0.2907499142691334,3554,0.7201152365732274 -48.04212021827698,0.2679710388183594,912.2310099601746,3788,0,912.2310099601746,0.2925956532196837,3581,0.734154175793249,960.69327378273,0.2730766023908342,0.7390454837254116,0.2911471751679446,3554,0.7167553147641742 -52.056846380233765,0.2930951118469238,992.3586266040802,4135,0,992.3586266040802,0.2926088454036931,3581,0.7359224056871335,1044.8746247291565,0.2732838732855661,0.7394935062953404,0.2911408552643149,3554,0.7192164363657146 -56.06653499603272,0.318572998046875,1072.367573738098,4480,0,1072.367573738098,0.2930826732014451,3581,0.7337420478741972,1128.9326810836792,0.2730314220700945,0.7384296144757952,0.2914937394485087,3554,0.7167116249956036 -60.07899618148804,0.3449559211730957,1152.3579506874084,4826,0,1152.3579506874084,0.2923433995719596,3581,0.7363882568154845,1212.97585606575,0.2724244935171945,0.7417031696864537,0.2909856741523635,3554,0.7190104899409117 -64.09178137779236,0.3717594146728515,1232.4112536907196,5173,0,1232.4112536907196,0.2921543116011589,3581,0.7354368515035954,1297.0824031829834,0.2726188898086548,0.7403593744550433,0.2907309202109419,3554,0.7179201004809721 -68.104318857193,0.3967089653015136,1312.3778328895569,5515,0,1312.3778328895569,0.2941171858637601,3581,0.7297423958478777,1381.1000657081604,0.2744945628302438,0.7342469351632255,0.2928571899620146,3554,0.7122155631199001 -72.11777949333191,0.4217469692230224,1392.5162916183472,5861,0,1392.5162916183472,0.2932253669540631,3581,0.7366074447823583,1465.2908942699432,0.2733808415276663,0.741619518824986,0.2918860886830859,3554,0.7192837570782921 -76.12851810455322,0.4470601081848144,1472.5103845596311,6209,0,1472.5103845596311,0.2928307945253595,3581,0.7324459413615261,1549.3350069522858,0.2732622453144618,0.7370784623282296,0.2914096229050893,3554,0.7148053497159891 -80.14044666290283,0.4731578826904297,1552.4796981811523,6552,0,1552.4796981811523,0.2915882748686644,3581,0.7380670389294192,1633.3560178279877,0.2714024782180786,0.7439711434500558,0.2901221486243493,3554,0.7208899056300999 -84.14881253242493,0.4980559349060058,1632.6826753616333,6897,0,1632.6826753616333,0.2932220944743088,3581,0.7310526149294889,1717.6061108112335,0.2739830527986799,0.7350197519574847,0.2920506809558771,3554,0.7134123605224747 -88.164067029953,0.5237960815429688,1712.859538078308,7245,0,1712.859538078308,0.2924798551600461,3581,0.7354655538781066,1801.8379180431368,0.272863643510001,0.7398674147469657,0.2909971461513435,3554,0.718265909120885 -92.1750111579895,0.550663948059082,1792.8237624168396,7585,0,1792.8237624168396,0.2920952024355976,3581,0.7389684707483943,1885.8538382053373,0.2723337071282523,0.7441009793962751,0.2907232264152187,3554,0.7217527785593697 -96.18921422958374,0.5797338485717773,1872.8409950733185,7931,0,1872.8409950733185,0.292417541691392,3581,0.7343427524390882,1969.928893327713,0.2723456450871059,0.7397144862583706,0.2911539072392023,3554,0.7167655502602701 -100.20110750198364,0.6055564880371094,1952.8641078472133,8278,0,1952.8641078472133,0.292765242665282,3581,0.733583128076131,2054.003679037094,0.2730504444667271,0.7389503206525531,0.291527159373681,3554,0.715964502475204 -104.21291184425354,0.6329255104064941,2032.9702162742608,8622,0,2032.9702162742608,0.2912397216865052,3581,0.7363723034766825,2138.162857532501,0.2715661185128348,0.7411728586469378,0.2899193964986635,3554,0.7190179776528207 -108.22664284706116,0.6587364673614502,2112.99298119545,8967,0,2112.99298119545,0.2928140571549497,3581,0.7353939683834823,2222.238981962204,0.2726804869515555,0.7409847123282296,0.2913322040856254,3554,0.7181451440058737 -112.237300157547,0.6848461627960205,2193.1075069904327,9313,0,2193.1075069904327,0.2915215640053407,3581,0.7360125352336987,2306.40425825119,0.2718922240393502,0.7409087589808873,0.2901541946574282,3554,0.7185364284740785 -116.2509913444519,0.7118649482727051,2273.0946094989777,9656,0,2273.0946094989777,0.2914483081825084,3581,0.7363588726743577,2390.445656776428,0.2718787704195295,0.7410191808428083,0.2900739593591727,3554,0.7191517947427195 -120.26565432548524,0.7373523712158203,2353.26428937912,10003,0,2353.26428937912,0.2902696018810213,3581,0.7367913172385507,2474.669335842133,0.2702041523797171,0.742525441305978,0.2888822797046462,3554,0.7192932369337366 -124.27734661102296,0.7643404006958008,2433.438845872879,10348,0,2433.438845872879,0.2909743099431024,3581,0.7398191791311785,2558.896376132965,0.2710064649581909,0.7450870786394391,0.2895298637428777,3554,0.7226098811418472 -128.28723120689392,0.7917084693908691,2513.490818977356,10692,0,2513.490818977356,0.2907895171019792,3581,0.7358138684419506,2642.999470949173,0.2711236817496164,0.7411471775599888,0.2894940738538442,3554,0.7182149377242192 -132.29781484603882,0.81842041015625,2593.563329219818,11038,0,2593.563329219818,0.2911124358593968,3581,0.7375919157750978,2727.123106479645,0.2708985464913504,0.7429281643458775,0.2897779199603439,3554,0.7203224881955191 -136.3119599819183,0.844599723815918,2673.632261276245,11382,0,2673.632261276245,0.2932149359248464,3581,0.728263507705599,2811.2464735507965,0.2733525208064488,0.7329865183149066,0.2919654652987831,3554,0.7107803954391179 -140.32230615615845,0.871314287185669,2753.709707736969,11728,0,2753.709707736969,0.2899191056640079,3581,0.7364855449115122,2895.374767303467,0.2701068094798496,0.7416582788739886,0.2886116573095456,3554,0.7189289494451674 -144.3353407382965,0.8990838527679443,2833.909121990204,12074,0,2833.909121990204,0.2903468119502234,3581,0.7362733109641162,2979.629019737244,0.270159363746643,0.7418908391680036,0.2889761165346264,3554,0.7188359369504431 -148.34514021873474,0.9278120994567872,2913.879408121109,12421,0,2913.879408121109,0.2904481906459438,3581,0.7358350032070302,3063.651765346527,0.2705170256750924,0.7414828709193638,0.2892454680795406,3554,0.7181956345403067 -152.35568261146545,0.9562292098999025,2993.9373548030853,12762,0,2993.9373548030853,0.2901792337161407,3581,0.7369440329604161,3147.762551546097,0.2706202268600464,0.7418837547302246,0.288879497573157,3554,0.7193922945536719 -156.36930775642395,0.9846577644348145,3074.032609462738,13105,0,3074.032609462738,0.2905407745523247,3581,0.7382352307534558,3231.913425922394,0.2700550215584891,0.7434845651899066,0.2892409685829347,3554,0.7209310536982977 -160.3839066028595,1.0118327140808103,3154.0704021453857,13450,0,3154.0704021453857,0.2893762489964395,3581,0.7398077254520385,3316.0069653987885,0.2694966622761318,0.7449960708618164,0.2880865729250932,3554,0.7224813535365081 -164.39890503883362,1.0392420291900637,3234.164406776428,13797,0,3234.164406776428,0.2902129129869449,3581,0.740477765681723,3400.156908750534,0.2703155108860561,0.7459932054792132,0.2888293848590497,3554,0.7233974647843978 -168.41451406478882,1.0669360160827637,3314.325216770172,14143,0,3314.325216770172,0.289407030759128,3581,0.7413721753045588,3484.374987363816,0.2687981128692627,0.7477308000837054,0.2880507830360597,3554,0.7242183653102139 -172.42876362800598,1.09513521194458,3394.4950959682465,14487,0,3394.4950959682465,0.2915374832558119,3581,0.7343974301216489,3568.601181983948,0.2714332682745797,0.7401647567749023,0.2901178552115574,3554,0.7168947648116559 -176.43772840499878,1.122864007949829,3474.6364142894745,14833,0,3474.6364142894745,0.2900032697526878,3581,0.7382288221472704,3652.793179273605,0.2703208412442888,0.7431543213980538,0.2886230606139209,3554,0.7210589630521947 -180.44852137565613,1.1519262790679932,3554.608986616134,15177,0,3554.608986616134,0.2892966186557525,3581,0.7372273069891441,3736.8194127082825,0.2690542936325073,0.7430147443498883,0.287806264590734,3554,0.7197655123408483 -184.462073802948,4.738783836364746,3631.154922485352,15507,0,3631.154922485352,0.2892636552398945,3581,0.7414995974849903,3820.978665113449,0.2690369401659284,0.7467621394566127,0.2878441324915588,3554,0.7243016231710748 -188.4758760929108,4.7664830684661865,3711.2650315761566,15854,0,3711.2650315761566,0.2902920660910011,3581,0.7394906357991832,3905.144039869309,0.2699409893580845,0.7453158242361886,0.2888884622190665,3554,0.7223393617886537 -192.4837954044342,4.794270038604736,3791.233234167099,16197,0,3791.233234167099,0.2896448991312657,3581,0.7357913019669785,3989.16166472435,0.2696803978511265,0.741283689226423,0.2883882624551561,3554,0.7182117090777996 -196.49588203430176,4.823106288909912,3871.200170278549,16543,0,3871.200170278549,0.2904213631296251,3581,0.738705104304838,4073.183162927628,0.2695753744670323,0.7450265884399414,0.2891437657173255,3554,0.7212872352235158 -200.5108027458191,4.8511247634887695,3951.185159444809,16890,0,3951.185159444809,0.2889543036381248,3581,0.7379346398526948,4157.22482085228,0.2689415046146938,0.7434981209891183,0.2876258897325197,3554,0.7204480619328574 -204.5243673324585,4.8789660930633545,4031.2938227653503,17233,0,4031.2938227653503,0.2889681435004189,3581,0.7388480025874407,4241.38879776001,0.2689672538212367,0.7445393289838519,0.2876462748564558,3554,0.721334771889948 -208.53724908828733,4.909036159515381,4111.324229717255,17578,0,4111.324229717255,0.2883759610182211,3581,0.7402261938006144,4325.475947856903,0.2680603095463344,0.7461472238813128,0.2870571671004502,3554,0.7228117745849747 -212.548864364624,4.937071323394775,4191.43989610672,17922,0,4191.43989610672,0.289038569991797,3581,0.7409845909836638,4409.6447513103485,0.2687728575297764,0.7468926565987724,0.2877300822741541,3554,0.7236925768060636 -216.56116938591003,4.966598510742188,4271.575967788696,18269,0,4271.575967788696,0.2891711395101752,3581,0.7402092178118891,4493.836472988129,0.2687146833964756,0.7458718163626534,0.2878685019025657,3554,0.7227657491998453 -220.5689299106598,4.99765419960022,4351.592973470688,18612,0,4351.592973470688,0.288235244388788,3581,0.7391915447849763,4577.905798196793,0.2677571092333112,0.7453145980834961,0.2869282273274831,3554,0.7217790885929586 -224.58108806610107,5.0261549949646,4431.69629073143,18958,0,4431.69629073143,0.2881993834648143,3581,0.7389786972476263,4662.063468456268,0.2680262838091169,0.7447752271379743,0.2869150723106886,3554,0.7214467440955613 -228.59129190444943,5.055289030075073,4511.751484632492,19304,0,4511.751484632492,0.2890040726010541,3581,0.7427710922228428,4746.171770095825,0.268930333001273,0.748098645891462,0.2876355413244759,3554,0.7257016879088702 -232.6044859886169,5.084141492843628,4591.897384405136,19646,0,4591.897384405136,0.2885423461607267,3581,0.7385132551792446,4830.373497962952,0.2679600204740252,0.7443575177873883,0.287134877871984,3554,0.721302622814962 -236.61859679222107,5.113516569137573,4671.86004781723,19992,0,4671.86004781723,0.2878015044680257,3581,0.7415709102729684,4914.393493652344,0.267649906022208,0.7473198345729283,0.2864868473188221,3554,0.7242911815911649 -240.63190078735352,5.14256739616394,4752.01032614708,20337,0,4752.01032614708,0.2885085646249302,3581,0.7389167928389416,4998.600042819977,0.268222553389413,0.7448357854570661,0.2872340385338263,3554,0.7214832896252462 -244.6449682712555,5.173840522766113,4832.023688554764,20681,0,4832.023688554764,0.2886559284788641,3581,0.7401653120418529,5082.671446084976,0.2680410317012242,0.7465263775416783,0.2873804954309844,3554,0.7229192816412845 -248.65569019317627,5.203402757644653,4912.122037649155,21026,0,4912.122037649155,0.2883169200293214,3581,0.7379010287585521,5166.823796510696,0.2681511810847691,0.7433980533054897,0.2870477731132614,3554,0.7203869237347004 -252.6671848297119,5.232451677322388,4992.088282823563,21369,0,4992.088282823563,0.2877285213518396,3581,0.7432473743804106,5250.844703674316,0.2676476751055036,0.7488290241786412,0.2864730053559809,3554,0.7261367308402504 -256.68112564086914,5.262057304382324,5072.231734991074,21713,0,5072.231734991074,0.2888652308298136,3581,0.7417139449088942,5335.045144796372,0.2682633740561349,0.7478596142360142,0.2876049893990486,3554,0.7245296205639772 -260.69290471076965,5.292056083679199,5152.4338619709015,22057,0,5152.4338619709015,0.288064257321628,3581,0.7409195504485478,5419.302825212479,0.26737151827131,0.7472703797476632,0.2867679971620885,3554,0.7236122728132034 -264.7044961452484,5.322473049163818,5232.52370929718,22403,0,5232.52370929718,0.2876894220364423,3581,0.7407228607799846,5503.448872804642,0.2673004354749407,0.746769768851144,0.2865335081290447,3554,0.7232089667891812 -268.7178628444672,5.352055549621582,5312.687821388245,22747,0,5312.687821388245,0.2877149541961917,3581,0.7426483742320581,5587.670019626617,0.2672470297132219,0.7487727573939732,0.2863474487922939,3554,0.7256340924178742 -272.73209524154663,5.382074594497681,5392.799517869949,23094,0,5392.799517869949,0.288260537930222,3581,0.7398897419758796,5671.839567661285,0.267437253679548,0.7463985170636859,0.28699774626741,3554,0.722510411354284 -276.7411725521088,5.412250757217407,5472.933496236801,23441,0,5472.933496236801,0.2874584395071209,3581,0.7420801217580634,5756.026417970657,0.2669349568230765,0.7481884275163923,0.2861438036367473,3554,0.7248329072435987 -280.7495219707489,5.442123651504517,5553.040453672409,23786,0,5553.040453672409,0.2872618861918807,3581,0.7407065665578749,5840.18537902832,0.2668683358601161,0.7467733110700335,0.2860006269069622,3554,0.7233074748522791 -284.7582213878632,5.472839117050171,5633.144581317902,24132,0,5633.144581317902,0.2871942549436261,3581,0.7431660396231848,5924.342694759369,0.2663380248206002,0.7495803151811872,0.2858426636635217,3554,0.726013904887099 -288.76734495162964,5.504015684127808,5713.145759820938,24482,0,5713.145759820938,0.2870046897361945,3581,0.7414027184489319,6008.39787364006,0.2665562970297677,0.7474779401506696,0.2857435030016794,3554,0.7240453922956528 -292.7781751155853,5.533693075180054,5793.313705205917,24825,0,5793.313705205917,0.2870229610814891,3581,0.7442264594168877,6092.62002158165,0.2664697340556553,0.7502619198390416,0.2857424038880047,3554,0.7271315661050929 -296.7916488647461,5.566459655761719,5873.269628047943,25169,0,5873.269628047943,0.2872016180230731,3581,0.7421172780386065,6176.635795116425,0.2660852500370571,0.7488915579659599,0.2859390250202236,3554,0.7247901105048888 -300.80610942840576,5.5982396602630615,5953.41180896759,25516,0,5953.41180896759,0.2870041784112329,3581,0.7428563812264382,6260.837786436081,0.2663029602595738,0.7491566794259208,0.2857857501835519,3554,0.7255977529720034 -304.81979846954346,5.628521680831909,6033.534689426422,25862,0,6033.534689426422,0.286961567997766,3581,0.7434290651834334,6345.01830124855,0.2662316220147269,0.7497294970921108,0.2857248695901624,3554,0.7261816571117051 -308.8319561481476,5.659749031066895,6113.551764726639,26206,0,6113.551764726639,0.2869879523657846,3581,0.742172296604475,6429.092185497284,0.2659216778618948,0.7487617901393345,0.2857339029306767,3554,0.7249687851716375 -312.83840131759644,5.691235303878784,6193.714996337891,26552,0,6193.714996337891,0.2868769266724553,3581,0.7426279894102555,6513.307105064392,0.2660137755530221,0.7490940093994141,0.2856169675398758,3554,0.7254139262099043 -316.84984970092773,5.721496343612671,6273.787319660187,26896,0,6273.787319660187,0.2867733322352345,3581,0.7420892574307107,6597.434686660767,0.2661850452423095,0.7483999388558524,0.2855410256544123,3554,0.7247854392717712 -320.86294293403625,5.753942012786865,6353.8551030159,27238,0,6353.8551030159,0.2866217755166155,3581,0.743832057429838,6681.561727285385,0.2653639316558838,0.7507381439208984,0.2852600647213087,3554,0.7267112925137169 -324.8746993541717,5.79008412361145,6433.951121091843,27583,0,6433.951121091843,0.2866703513879677,3581,0.7437995371622801,6765.719286441803,0.2658776044845581,0.7501236370631627,0.2853921472724395,3554,0.7266756400138928 -328.8831934928894,5.821483612060547,6514.103759050369,27929,0,6514.103759050369,0.2863536026184201,3581,0.7433997492189681,6849.925580263138,0.2655642032623291,0.749838011605399,0.285093068137354,3554,0.7262190269766461 -332.8914361000061,5.852307319641113,6594.274115562439,28272,0,6594.274115562439,0.2865293279635577,3581,0.7428893105539653,6934.148802280426,0.265284776687622,0.749901294708252,0.2852201188086927,3554,0.7256960549512873 -336.89839482307434,5.885340690612793,6674.250906705856,28616,0,6674.250906705856,0.2864145184655124,3581,0.7427862956183677,7018.179235696793,0.2654579707554408,0.7493587221418109,0.2851458427673923,3554,0.7255252801640757 -340.9145243167877,5.916884183883667,6754.433007955551,28963,0,6754.433007955551,0.286447243263055,3581,0.7433507301993159,7102.422847747803,0.2655364956174578,0.7498035430908203,0.2851735782140282,3554,0.7261129625070343 -344.9258232116699,5.949899196624756,6834.417027235031,29308,0,6834.417027235031,0.2863590567513439,3581,0.743136519128735,7186.464999914169,0.2654083626610892,0.7496084485735212,0.2850448101775728,3554,0.7259997537985369 -348.9375340938568,5.982892036437988,6914.371897935867,29652,0,6914.371897935867,0.2863031518888753,3581,0.7429690772479755,7270.478240966797,0.2650468349456787,0.7499695505414691,0.2850546850269942,3554,0.7257326691755768 -352.9492189884186,6.013725280761719,6994.351936340332,29999,0,6994.351936340332,0.2863494097537349,3581,0.7440159298860305,7354.514622688293,0.265220182282584,0.7508008139474052,0.2850975504603088,3554,0.7267960616558807 -356.95574593544006,6.045287609100342,7074.379700660706,30341,0,7074.379700660706,0.2861762069450747,3581,0.7433912271362748,7438.5940990448,0.2650173732212612,0.7502134186880929,0.2848916383828081,3554,0.7262326285083709 -360.9666941165924,6.0774827003479,7154.448456525803,30688,0,7154.448456525803,0.2861211542908755,3581,0.7436298454516894,7522.719612360001,0.2646711894444057,0.7507006100245884,0.284826086556301,3554,0.7264637171584833 -364.9779260158539,6.110843896865845,7234.592389583588,31036,0,7234.592389583588,0.2861302217868612,3581,0.7438296712466839,7606.922004938126,0.264853971345084,0.750760691506522,0.2848708239175928,3554,0.7266390944842079 -368.9930951595306,6.142259836196899,7314.720451116562,31381,0,7314.720451116562,0.285995538791975,3581,0.7439185736133412,7691.110218048096,0.2647814410073416,0.7507956368582589,0.2847459542999525,3554,0.7267478380434018 -373.00463938713074,6.173776626586914,7394.727523565292,31728,0,7394.727523565292,0.2861455615357093,3581,0.743655343523108,7775.174228429794,0.2644347974232265,0.751133850642613,0.2848719573785699,3554,0.7264723039840673 -377.02014422416687,6.205841541290283,7474.850043296814,32074,0,7474.850043296814,0.2859557236216315,3581,0.7434837428659942,7859.357992172241,0.2645878280912127,0.7505387578691755,0.2847132041471757,3554,0.7262119514323649 -381.0288183689117,6.23901629447937,7554.948711395264,32417,0,7554.948711395264,0.2859554509149853,3581,0.7438118089613586,7943.512165546417,0.2646485567092895,0.7507827622549874,0.2846971467833339,3554,0.7266321563291361 -385.0398862361908,6.273144245147705,7635.059431314468,32762,0,7635.059431314468,0.2859545305300544,3581,0.7439133921870636,8027.681664466858,0.2641993931361607,0.7513869830540248,0.2846715065221405,3554,0.7267633630240574 -389.0552260875702,6.30461573600769,7715.03692984581,33107,0,7715.03692984581,0.2859292881211167,3581,0.7437964010358489,8111.719587564468,0.2644737107413156,0.7509965215410505,0.2846685183068373,3554,0.7265825588245639 -393.0689616203308,6.34293532371521,7794.994208097458,33451,0,7794.994208097458,0.2859337366482826,3581,0.7437085213191148,8195.742512226105,0.2644854273114885,0.7508912086486816,0.2846644653251617,3554,0.7264931184492825 -397.07849621772766,6.376688718795776,7875.045525312424,33794,0,7875.045525312424,0.2859363614497521,3581,0.7439062336376012,8279.850937843323,0.2641565118517194,0.7513776506696429,0.284640010045899,3554,0.726762607383406 -401.0880081653595,6.409821510314941,7955.107652187347,34140,0,7955.107652187347,0.2859049149646135,3581,0.7438452155255166,8363.969428777695,0.2642498016357422,0.7512272426060268,0.2846393917944569,3554,0.7266439031065349 -405.1007053852081,6.443702220916748,8035.217735528946,34485,0,8035.217735528946,0.2858476636130794,3581,0.7438459654687937,8448.139996290207,0.2643188067844936,0.7511179787772042,0.2845892962540007,3554,0.7266925388866418 -409.1116394996643,6.4771833419799805,8115.333026409149,34828,0,8115.333026409149,0.2859044888604789,3581,0.7439223233297263,8532.313452720642,0.2641722304480416,0.7513903209141323,0.284604168635912,3554,0.7267753158852701 -413.1216921806336,6.510056257247925,8195.418248653412,35173,0,8195.418248653412,0.2858396869436784,3581,0.7439750238891022,8616.45526599884,0.2641371658870152,0.7513918195452008,0.284552355730339,3554,0.7268032058947664 -417.1336545944214,6.544367790222168,8275.447546958923,35519,0,8275.447546958923,0.2858166432320755,3581,0.7440558132330355,8700.54442024231,0.264160360608782,0.751441410609654,0.284526028523099,3554,0.7268765717325548 -421.14607095718384,6.579448938369751,8355.44581079483,35863,0,8355.44581079483,0.2857884180941951,3581,0.7440089758665527,8784.603637218475,0.264140316418239,0.751368590763637,0.2844962150646719,3554,0.7268266994495639 -425.1565029621124,6.6140077114105225,8430.592392921448,36189,0,8430.592392921448,0.28579595161529603,3581,0.744079675064577,8863.808090209961,0.26413134166172575,0.7514665467398507,0.28450164193844085,3554,0.7269076903884707 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/measurements.csv deleted file mode 100644 index 5e48a339c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/measurements.csv +++ /dev/null @@ -1,470 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.6054354,1.0139427,,,,,,,,,,,,,, -1,,,0.1975462777273995,1.0327176366533553,0.1897597777152504,1.0370871866426914,3554.0,0.2128455943150481,1.0323986403937448,3581.0,31.401532649993896,35.358983516693115,31.401532649993896,3.9573137760162354,0.0,0.0 -100,0.6652692,0.35165003,,,,,,,,,,,,,, -200,0.13961153,0.30814108,,,,,,,,,,,,,, -300,0.2843543,0.30594993,,,,,,,,,,,,,, -342,,,0.7213667460850307,0.2925906862531389,0.7017441697515123,0.3103522631849676,3554.0,0.7183853912140463,0.3123536587959892,3581.0,111.60064959526062,119.59745144844057,111.60064959526062,7.96499228477478,0.017967939376831,0.0 -400,0.33448353,0.33655044,,,,,,,,,,,,,, -500,0.14699563,0.22371851,,,,,,,,,,,,,, -600,0.19604123,0.344944,,,,,,,,,,,,,, -682,,,0.7287031582423619,0.2831108910696847,0.7073842715734032,0.3006681786807119,3554.0,0.7247404105434935,0.302153373490296,3581.0,191.57904171943665,203.62491083145144,191.57904171943665,11.974652290344238,0.0440468788146972,0.0 -700,0.28474367,0.3072062,,,,,,,,,,,,,, -800,0.17133608,0.2419311,,,,,,,,,,,,,, -900,0.06763553,0.31217209,,,,,,,,,,,,,, -1000,0.34074533,0.25395182,,,,,,,,,,,,,, -1019,,,0.7273684910365513,0.280535272189549,0.7060646482176772,0.2981226657573332,3554.0,0.7234966637068906,0.2996568464094352,3581.0,271.53961205482483,287.6374468803406,271.53961205482483,15.98682451248169,0.0700306892395019,0.0 -1100,0.17882983,0.26258785,,,,,,,,,,,,,, -1200,0.3205233,0.25311646,,,,,,,,,,,,,, -1300,0.2357555,0.35968897,,,,,,,,,,,,,, -1366,,,0.7377146993364606,0.2756518295833042,0.715898143487092,0.2935038809703855,3554.0,0.7332023614213907,0.2950062436186644,3581.0,351.7352349758148,371.88407588005066,351.7352349758148,19.99767827987671,0.0958211421966552,0.0 -1400,0.2883193,0.27001345,,,,,,,,,,,,,, -1500,0.21421173,0.33364204,,,,,,,,,,,,,, -1600,0.087087035,0.1910802,,,,,,,,,,,,,, -1700,0.22225282,0.33617243,,,,,,,,,,,,,, -1710,,,0.7396605355398995,0.2755191666739328,0.7181294816360088,0.2932262517256084,3554.0,0.7352617738367425,0.2947562738891022,3581.0,431.7387228012085,455.9366509914398,431.7387228012085,24.00688052177429,0.1219170093536377,0.0 -1800,0.09719013,0.2751081,,,,,,,,,,,,,, -1900,0.18462433,0.21706043,,,,,,,,,,,,,, -2000,0.07877084,0.33444253,,,,,,,,,,,,,, -2054,,,0.736421925680978,0.2752586773463658,0.7134809177379361,0.2937887574959553,3554.0,0.73088387769216,0.2953716023478951,3581.0,511.8938019275665,540.144252538681,511.8938019275665,28.0211033821106,0.146493911743164,0.0 -2100,0.33867586,0.2922589,,,,,,,,,,,,,, -2200,0.29248154,0.2815805,,,,,,,,,,,,,, -2300,0.20225936,0.22330923,,,,,,,,,,,,,, -2400,0.31177396,0.2862909,,,,,,,,,,,,,, -2401,,,0.7367238317217145,0.2788510152271816,0.7148442308622327,0.297630159789146,3554.0,0.7316117317308364,0.2993610619633133,3581.0,591.8965113162994,624.146678686142,591.8965113162994,31.982479333877563,0.1711635589599609,0.0 -2500,0.09072322,0.25631154,,,,,,,,,,,,,, -2600,0.09069383,0.2528142,,,,,,,,,,,,,, -2700,0.11911584,0.29436603,,,,,,,,,,,,,, -2748,,,0.7392657143729073,0.2730081421988351,0.7169806330674944,0.2913048636329663,3554.0,0.7344828554785674,0.2926768857119171,3581.0,671.9206922054291,708.2232940196991,671.9206922054291,35.997042179107666,0.1950883865356445,0.0 -2800,0.17434996,0.3524504,,,,,,,,,,,,,, -2900,0.14203066,0.2269283,,,,,,,,,,,,,, -3000,0.14188072,0.26104635,,,,,,,,,,,,,, -3094,,,0.7426317759922573,0.2729813711983817,0.7201212130038337,0.2911620818971581,3554.0,0.7372292159356674,0.2927859342820615,3581.0,751.973610162735,792.3318250179291,751.973610162735,40.01491498947144,0.218900442123413,0.0 -3100,0.10389619,0.2385836,,,,,,,,,,,,,, -3200,0.08094021,0.2736367,,,,,,,,,,,,,, -3300,0.103687696,0.30712253,,,,,,,,,,,,,, -3400,0.24584563,0.3810093,,,,,,,,,,,,,, -3441,,,0.742382322038923,0.2722130332674299,0.7201152365732274,0.2907499142691334,3554.0,0.7372187167297891,0.2922184317513439,3581.0,832.1201596260071,876.5295839309692,832.1201596260071,44.0278480052948,0.2433154582977295,0.0 -3500,0.16695826,0.3254162,,,,,,,,,,,,,, -3600,0.10189231,0.28890267,,,,,,,,,,,,,, -3700,0.0846492,0.226023,,,,,,,,,,,,,, -3788,,,0.7390454837254116,0.2730766023908342,0.7167553147641742,0.2911471751679446,3554.0,0.734154175793249,0.2925956532196837,3581.0,912.2310099601746,960.69327378273,912.2310099601746,48.04212021827698,0.2679710388183594,0.0 -3800,0.08598889,0.2563287,,,,,,,,,,,,,, -3900,0.12191042,0.28035188,,,,,,,,,,,,,, -4000,0.22934553,0.2256539,,,,,,,,,,,,,, -4100,0.12672529,0.2932511,,,,,,,,,,,,,, -4135,,,0.7394935062953404,0.2732838732855661,0.7192164363657146,0.2911408552643149,3554.0,0.7359224056871335,0.2926088454036931,3581.0,992.3586266040802,1044.8746247291565,992.3586266040802,52.056846380233765,0.2930951118469238,0.0 -4200,0.18100853,0.23813498,,,,,,,,,,,,,, -4300,0.27057883,0.32319835,,,,,,,,,,,,,, -4400,0.1612618,0.25921932,,,,,,,,,,,,,, -4480,,,0.7384296144757952,0.2730314220700945,0.7167116249956036,0.2914937394485087,3554.0,0.7337420478741972,0.2930826732014451,3581.0,1072.367573738098,1128.9326810836792,1072.367573738098,56.06653499603272,0.318572998046875,0.0 -4500,0.16086808,0.26650244,,,,,,,,,,,,,, -4600,0.07054327,0.23437588,,,,,,,,,,,,,, -4700,0.1202957,0.25629124,,,,,,,,,,,,,, -4800,0.1296697,0.2669937,,,,,,,,,,,,,, -4826,,,0.7417031696864537,0.2724244935171945,0.7190104899409117,0.2909856741523635,3554.0,0.7363882568154845,0.2923433995719596,3581.0,1152.3579506874084,1212.97585606575,1152.3579506874084,60.07899618148804,0.3449559211730957,0.0 -4900,0.11859083,0.3711417,,,,,,,,,,,,,, -5000,0.19547123,0.29968005,,,,,,,,,,,,,, -5100,0.21844584,0.24285471,,,,,,,,,,,,,, -5173,,,0.7403593744550433,0.2726188898086548,0.7179201004809721,0.2907309202109419,3554.0,0.7354368515035954,0.2921543116011589,3581.0,1232.4112536907196,1297.0824031829834,1232.4112536907196,64.09178137779236,0.3717594146728515,0.0 -5200,0.14817807,0.3156859,,,,,,,,,,,,,, -5300,0.12778425,0.20114538,,,,,,,,,,,,,, -5400,0.1001609,0.23901352,,,,,,,,,,,,,, -5500,0.089591734,0.27187926,,,,,,,,,,,,,, -5515,,,0.7342469351632255,0.2744945628302438,0.7122155631199001,0.2928571899620146,3554.0,0.7297423958478777,0.2941171858637601,3581.0,1312.3778328895569,1381.1000657081604,1312.3778328895569,68.104318857193,0.3967089653015136,0.0 -5600,0.27835408,0.22714305,,,,,,,,,,,,,, -5700,0.17122506,0.29456207,,,,,,,,,,,,,, -5800,0.12712558,0.44035074,,,,,,,,,,,,,, -5861,,,0.741619518824986,0.2733808415276663,0.7192837570782921,0.2918860886830859,3554.0,0.7366074447823583,0.2932253669540631,3581.0,1392.5162916183472,1465.2908942699432,1392.5162916183472,72.11777949333191,0.4217469692230224,0.0 -5900,0.07244157,0.28134134,,,,,,,,,,,,,, -6000,0.10287179,0.2376793,,,,,,,,,,,,,, -6100,0.13215174,0.33813164,,,,,,,,,,,,,, -6200,0.17418242,0.23516005,,,,,,,,,,,,,, -6209,,,0.7370784623282296,0.2732622453144618,0.7148053497159891,0.2914096229050893,3554.0,0.7324459413615261,0.2928307945253595,3581.0,1472.5103845596311,1549.3350069522858,1472.5103845596311,76.12851810455322,0.4470601081848144,0.0 -6300,0.13472474,0.28407413,,,,,,,,,,,,,, -6400,0.11157195,0.28725132,,,,,,,,,,,,,, -6500,0.19159093,0.2516826,,,,,,,,,,,,,, -6552,,,0.7439711434500558,0.2714024782180786,0.7208899056300999,0.2901221486243493,3554.0,0.7380670389294192,0.2915882748686644,3581.0,1552.4796981811523,1633.3560178279877,1552.4796981811523,80.14044666290283,0.4731578826904297,0.0 -6600,0.20006286,0.3130654,,,,,,,,,,,,,, -6700,0.1093883,0.2621145,,,,,,,,,,,,,, -6800,0.11132916,0.2791722,,,,,,,,,,,,,, -6897,,,0.7350197519574847,0.2739830527986799,0.7134123605224747,0.2920506809558771,3554.0,0.7310526149294889,0.2932220944743088,3581.0,1632.6826753616333,1717.6061108112335,1632.6826753616333,84.14881253242493,0.4980559349060058,0.0 -6900,0.11156279,0.31630814,,,,,,,,,,,,,, -7000,0.18072486,0.23053017,,,,,,,,,,,,,, -7100,0.110213146,0.24463688,,,,,,,,,,,,,, -7200,0.1655159,0.24005154,,,,,,,,,,,,,, -7245,,,0.7398674147469657,0.272863643510001,0.718265909120885,0.2909971461513435,3554.0,0.7354655538781066,0.2924798551600461,3581.0,1712.859538078308,1801.8379180431368,1712.859538078308,88.164067029953,0.5237960815429688,0.0 -7300,0.18188351,0.24982424,,,,,,,,,,,,,, -7400,0.16953091,0.26303172,,,,,,,,,,,,,, -7500,0.17447111,0.19403961,,,,,,,,,,,,,, -7585,,,0.7441009793962751,0.2723337071282523,0.7217527785593697,0.2907232264152187,3554.0,0.7389684707483943,0.2920952024355976,3581.0,1792.8237624168396,1885.8538382053373,1792.8237624168396,92.1750111579895,0.550663948059082,0.0 -7600,0.14580236,0.25559098,,,,,,,,,,,,,, -7700,0.15164694,0.2469222,,,,,,,,,,,,,, -7800,0.1008895,0.35398316,,,,,,,,,,,,,, -7900,0.086170055,0.29368967,,,,,,,,,,,,,, -7931,,,0.7397144862583706,0.2723456450871059,0.7167655502602701,0.2911539072392023,3554.0,0.7343427524390882,0.292417541691392,3581.0,1872.8409950733185,1969.928893327713,1872.8409950733185,96.18921422958374,0.5797338485717773,0.0 -8000,0.19226973,0.26827827,,,,,,,,,,,,,, -8100,0.09307154,0.25950363,,,,,,,,,,,,,, -8200,0.084366925,0.3232333,,,,,,,,,,,,,, -8278,,,0.7389503206525531,0.2730504444667271,0.715964502475204,0.291527159373681,3554.0,0.733583128076131,0.292765242665282,3581.0,1952.8641078472133,2054.003679037094,1952.8641078472133,100.20110750198364,0.6055564880371094,0.0 -8300,0.20714702,0.29904678,,,,,,,,,,,,,, -8400,0.08532756,0.23511031,,,,,,,,,,,,,, -8500,0.18251586,0.28093678,,,,,,,,,,,,,, -8600,0.30187988,0.20089908,,,,,,,,,,,,,, -8622,,,0.7411728586469378,0.2715661185128348,0.7190179776528207,0.2899193964986635,3554.0,0.7363723034766825,0.2912397216865052,3581.0,2032.9702162742608,2138.162857532501,2032.9702162742608,104.21291184425354,0.6329255104064941,0.0 -8700,0.19466954,0.27020133,,,,,,,,,,,,,, -8800,0.14035012,0.26174223,,,,,,,,,,,,,, -8900,0.4037744,0.29932633,,,,,,,,,,,,,, -8967,,,0.7409847123282296,0.2726804869515555,0.7181451440058737,0.2913322040856254,3554.0,0.7353939683834823,0.2928140571549497,3581.0,2112.99298119545,2222.238981962204,2112.99298119545,108.22664284706116,0.6587364673614502,0.0 -9000,0.30269915,0.33769065,,,,,,,,,,,,,, -9100,0.12305968,0.28640938,,,,,,,,,,,,,, -9200,0.13831191,0.22804946,,,,,,,,,,,,,, -9300,0.06500063,0.42891264,,,,,,,,,,,,,, -9313,,,0.7409087589808873,0.2718922240393502,0.7185364284740785,0.2901541946574282,3554.0,0.7360125352336987,0.2915215640053407,3581.0,2193.1075069904327,2306.40425825119,2193.1075069904327,112.237300157547,0.6848461627960205,0.0 -9400,0.1626478,0.23858336,,,,,,,,,,,,,, -9500,0.14585322,0.27204677,,,,,,,,,,,,,, -9600,0.07952797,0.3257777,,,,,,,,,,,,,, -9656,,,0.7410191808428083,0.2718787704195295,0.7191517947427195,0.2900739593591727,3554.0,0.7363588726743577,0.2914483081825084,3581.0,2273.0946094989777,2390.445656776428,2273.0946094989777,116.2509913444519,0.7118649482727051,0.0 -9700,0.10276529,0.32576197,,,,,,,,,,,,,, -9800,0.08004225,0.23270248,,,,,,,,,,,,,, -9900,0.12473177,0.25145254,,,,,,,,,,,,,, -10000,0.19562583,0.31304982,,,,,,,,,,,,,, -10003,,,0.742525441305978,0.2702041523797171,0.7192932369337366,0.2888822797046462,3554.0,0.7367913172385507,0.2902696018810213,3581.0,2353.26428937912,2474.669335842133,2353.26428937912,120.26565432548524,0.7373523712158203,0.0 -10100,0.098491155,0.25694305,,,,,,,,,,,,,, -10200,0.15485935,0.36891913,,,,,,,,,,,,,, -10300,0.15041696,0.24424124,,,,,,,,,,,,,, -10348,,,0.7450870786394391,0.2710064649581909,0.7226098811418472,0.2895298637428777,3554.0,0.7398191791311785,0.2909743099431024,3581.0,2433.438845872879,2558.896376132965,2433.438845872879,124.27734661102296,0.7643404006958008,0.0 -10400,0.06146402,0.29428643,,,,,,,,,,,,,, -10500,0.20836382,0.33674002,,,,,,,,,,,,,, -10600,0.13323264,0.29800352,,,,,,,,,,,,,, -10692,,,0.7411471775599888,0.2711236817496164,0.7182149377242192,0.2894940738538442,3554.0,0.7358138684419506,0.2907895171019792,3581.0,2513.490818977356,2642.999470949173,2513.490818977356,128.28723120689392,0.7917084693908691,0.0 -10700,0.12706031,0.23260681,,,,,,,,,,,,,, -10800,0.12101266,0.32745638,,,,,,,,,,,,,, -10900,0.22932556,0.2624903,,,,,,,,,,,,,, -11000,0.18198456,0.26816997,,,,,,,,,,,,,, -11038,,,0.7429281643458775,0.2708985464913504,0.7203224881955191,0.2897779199603439,3554.0,0.7375919157750978,0.2911124358593968,3581.0,2593.563329219818,2727.123106479645,2593.563329219818,132.29781484603882,0.81842041015625,0.0 -11100,0.08067144,0.2883662,,,,,,,,,,,,,, -11200,0.24967106,0.27992988,,,,,,,,,,,,,, -11300,0.14988506,0.33685964,,,,,,,,,,,,,, -11382,,,0.7329865183149066,0.2733525208064488,0.7107803954391179,0.2919654652987831,3554.0,0.728263507705599,0.2932149359248464,3581.0,2673.632261276245,2811.2464735507965,2673.632261276245,136.3119599819183,0.844599723815918,0.0 -11400,0.049371805,0.30676112,,,,,,,,,,,,,, -11500,0.20001024,0.26630747,,,,,,,,,,,,,, -11600,0.26576483,0.23966847,,,,,,,,,,,,,, -11700,0.11406075,0.2992141,,,,,,,,,,,,,, -11728,,,0.7416582788739886,0.2701068094798496,0.7189289494451674,0.2886116573095456,3554.0,0.7364855449115122,0.2899191056640079,3581.0,2753.709707736969,2895.374767303467,2753.709707736969,140.32230615615845,0.871314287185669,0.0 -11800,0.08792556,0.31536362,,,,,,,,,,,,,, -11900,0.11453134,0.27535775,,,,,,,,,,,,,, -12000,0.14409368,0.26081488,,,,,,,,,,,,,, -12074,,,0.7418908391680036,0.270159363746643,0.7188359369504431,0.2889761165346264,3554.0,0.7362733109641162,0.2903468119502234,3581.0,2833.909121990204,2979.629019737244,2833.909121990204,144.3353407382965,0.8990838527679443,0.0 -12100,0.047997497,0.29443154,,,,,,,,,,,,,, -12200,0.21034247,0.26241088,,,,,,,,,,,,,, -12300,0.19884522,0.29052716,,,,,,,,,,,,,, -12400,0.099057704,0.25924414,,,,,,,,,,,,,, -12421,,,0.7414828709193638,0.2705170256750924,0.7181956345403067,0.2892454680795406,3554.0,0.7358350032070302,0.2904481906459438,3581.0,2913.879408121109,3063.651765346527,2913.879408121109,148.34514021873474,0.9278120994567872,0.0 -12500,0.18781383,0.30244255,,,,,,,,,,,,,, -12600,0.17794512,0.30440867,,,,,,,,,,,,,, -12700,0.18861195,0.27458888,,,,,,,,,,,,,, -12762,,,0.7418837547302246,0.2706202268600464,0.7193922945536719,0.288879497573157,3554.0,0.7369440329604161,0.2901792337161407,3581.0,2993.9373548030853,3147.762551546097,2993.9373548030853,152.35568261146545,0.9562292098999025,0.0 -12800,0.06671159,0.2555067,,,,,,,,,,,,,, -12900,0.1924562,0.21301341,,,,,,,,,,,,,, -13000,0.15764523,0.30434602,,,,,,,,,,,,,, -13100,0.23211922,0.25929296,,,,,,,,,,,,,, -13105,,,0.7434845651899066,0.2700550215584891,0.7209310536982977,0.2892409685829347,3554.0,0.7382352307534558,0.2905407745523247,3581.0,3074.032609462738,3231.913425922394,3074.032609462738,156.36930775642395,0.9846577644348145,0.0 -13200,0.09389281,0.23234996,,,,,,,,,,,,,, -13300,0.10346756,0.35201922,,,,,,,,,,,,,, -13400,0.1612183,0.41009766,,,,,,,,,,,,,, -13450,,,0.7449960708618164,0.2694966622761318,0.7224813535365081,0.2880865729250932,3554.0,0.7398077254520385,0.2893762489964395,3581.0,3154.0704021453857,3316.0069653987885,3154.0704021453857,160.3839066028595,1.0118327140808103,0.0 -13500,0.079985045,0.33001217,,,,,,,,,,,,,, -13600,0.21654946,0.20144662,,,,,,,,,,,,,, -13700,0.106447875,0.31946328,,,,,,,,,,,,,, -13797,,,0.7459932054792132,0.2703155108860561,0.7233974647843978,0.2888293848590497,3554.0,0.740477765681723,0.2902129129869449,3581.0,3234.164406776428,3400.156908750534,3234.164406776428,164.39890503883362,1.0392420291900637,0.0 -13800,0.1920763,0.2672009,,,,,,,,,,,,,, -13900,0.17872638,0.23724934,,,,,,,,,,,,,, -14000,0.14460474,0.21930121,,,,,,,,,,,,,, -14100,0.08468113,0.24609241,,,,,,,,,,,,,, -14143,,,0.7477308000837054,0.2687981128692627,0.7242183653102139,0.2880507830360597,3554.0,0.7413721753045588,0.289407030759128,3581.0,3314.325216770172,3484.374987363816,3314.325216770172,168.41451406478882,1.0669360160827637,0.0 -14200,0.08620784,0.35726765,,,,,,,,,,,,,, -14300,0.111771315,0.33347845,,,,,,,,,,,,,, -14400,0.118026495,0.2906502,,,,,,,,,,,,,, -14487,,,0.7401647567749023,0.2714332682745797,0.7168947648116559,0.2901178552115574,3554.0,0.7343974301216489,0.2915374832558119,3581.0,3394.4950959682465,3568.601181983948,3394.4950959682465,172.42876362800598,1.09513521194458,0.0 -14500,0.07494738,0.3437567,,,,,,,,,,,,,, -14600,0.07796797,0.21519972,,,,,,,,,,,,,, -14700,0.121637985,0.26335198,,,,,,,,,,,,,, -14800,0.06911011,0.3423081,,,,,,,,,,,,,, -14833,,,0.7431543213980538,0.2703208412442888,0.7210589630521947,0.2886230606139209,3554.0,0.7382288221472704,0.2900032697526878,3581.0,3474.6364142894745,3652.793179273605,3474.6364142894745,176.43772840499878,1.122864007949829,0.0 -14900,0.115041435,0.35450456,,,,,,,,,,,,,, -15000,0.076761745,0.22156191,,,,,,,,,,,,,, -15100,0.09120382,0.19944121,,,,,,,,,,,,,, -15177,,,0.7430147443498883,0.2690542936325073,0.7197655123408483,0.287806264590734,3554.0,0.7372273069891441,0.2892966186557525,3581.0,3554.608986616134,3736.8194127082825,3554.608986616134,180.44852137565613,1.1519262790679932,0.0 -15200,0.07077936,0.2612246,,,,,,,,,,,,,, -15300,0.10002171,0.24333075,,,,,,,,,,,,,, -15400,0.14471899,0.27699018,,,,,,,,,,,,,, -15500,0.09791843,0.25835,,,,,,,,,,,,,, -15507,,,0.7467621394566127,0.2690369401659284,0.7243016231710748,0.2878441324915588,3554.0,0.7414995974849903,0.2892636552398945,3581.0,3631.154922485352,3820.978665113449,3631.154922485352,184.462073802948,4.738783836364746,0.0 -15600,0.07811509,0.27786848,,,,,,,,,,,,,, -15700,0.16095746,0.21025534,,,,,,,,,,,,,, -15800,0.14795397,0.2518962,,,,,,,,,,,,,, -15854,,,0.7453158242361886,0.2699409893580845,0.7223393617886537,0.2888884622190665,3554.0,0.7394906357991832,0.2902920660910011,3581.0,3711.2650315761566,3905.144039869309,3711.2650315761566,188.4758760929108,4.7664830684661865,0.0 -15900,0.060728054,0.20829082,,,,,,,,,,,,,, -16000,0.13306515,0.23805296,,,,,,,,,,,,,, -16100,0.03987447,0.36553857,,,,,,,,,,,,,, -16197,,,0.741283689226423,0.2696803978511265,0.7182117090777996,0.2883882624551561,3554.0,0.7357913019669785,0.2896448991312657,3581.0,3791.233234167099,3989.16166472435,3791.233234167099,192.4837954044342,4.794270038604736,0.0 -16200,0.088134766,0.22853445,,,,,,,,,,,,,, -16300,0.1420502,0.24608348,,,,,,,,,,,,,, -16400,0.10623398,0.23360476,,,,,,,,,,,,,, -16500,0.04265522,0.276103,,,,,,,,,,,,,, -16543,,,0.7450265884399414,0.2695753744670323,0.7212872352235158,0.2891437657173255,3554.0,0.738705104304838,0.2904213631296251,3581.0,3871.200170278549,4073.183162927628,3871.200170278549,196.49588203430176,4.823106288909912,0.0 -16600,0.1472183,0.23639107,,,,,,,,,,,,,, -16700,0.09502449,0.32233778,,,,,,,,,,,,,, -16800,0.20191409,0.29891342,,,,,,,,,,,,,, -16890,,,0.7434981209891183,0.2689415046146938,0.7204480619328574,0.2876258897325197,3554.0,0.7379346398526948,0.2889543036381248,3581.0,3951.185159444809,4157.22482085228,3951.185159444809,200.5108027458191,4.8511247634887695,0.0 -16900,0.14240308,0.33321288,,,,,,,,,,,,,, -17000,0.08517201,0.2242471,,,,,,,,,,,,,, -17100,0.08963023,0.31203753,,,,,,,,,,,,,, -17200,0.13776301,0.25961983,,,,,,,,,,,,,, -17233,,,0.7445393289838519,0.2689672538212367,0.721334771889948,0.2876462748564558,3554.0,0.7388480025874407,0.2889681435004189,3581.0,4031.2938227653503,4241.38879776001,4031.2938227653503,204.5243673324585,4.8789660930633545,0.0 -17300,0.08723789,0.25660264,,,,,,,,,,,,,, -17400,0.11024636,0.2353129,,,,,,,,,,,,,, -17500,0.07834395,0.25219864,,,,,,,,,,,,,, -17578,,,0.7461472238813128,0.2680603095463344,0.7228117745849747,0.2870571671004502,3554.0,0.7402261938006144,0.2883759610182211,3581.0,4111.324229717255,4325.475947856903,4111.324229717255,208.53724908828733,4.909036159515381,0.0 -17600,0.07724469,0.30457994,,,,,,,,,,,,,, -17700,0.12541933,0.26379138,,,,,,,,,,,,,, -17800,0.12760572,0.28903562,,,,,,,,,,,,,, -17900,0.08785931,0.2090554,,,,,,,,,,,,,, -17922,,,0.7468926565987724,0.2687728575297764,0.7236925768060636,0.2877300822741541,3554.0,0.7409845909836638,0.289038569991797,3581.0,4191.43989610672,4409.6447513103485,4191.43989610672,212.548864364624,4.937071323394775,0.0 -18000,0.049005806,0.31594282,,,,,,,,,,,,,, -18100,0.09189122,0.27116305,,,,,,,,,,,,,, -18200,0.101346895,0.2152969,,,,,,,,,,,,,, -18269,,,0.7458718163626534,0.2687146833964756,0.7227657491998453,0.2878685019025657,3554.0,0.7402092178118891,0.2891711395101752,3581.0,4271.575967788696,4493.836472988129,4271.575967788696,216.56116938591003,4.966598510742188,0.0 -18300,0.0989907,0.25724843,,,,,,,,,,,,,, -18400,0.17011733,0.22363421,,,,,,,,,,,,,, -18500,0.06981425,0.3390261,,,,,,,,,,,,,, -18600,0.15691209,0.26034865,,,,,,,,,,,,,, -18612,,,0.7453145980834961,0.2677571092333112,0.7217790885929586,0.2869282273274831,3554.0,0.7391915447849763,0.288235244388788,3581.0,4351.592973470688,4577.905798196793,4351.592973470688,220.5689299106598,4.99765419960022,0.0 -18700,0.14148018,0.2911642,,,,,,,,,,,,,, -18800,0.16066258,0.26583216,,,,,,,,,,,,,, -18900,0.055432785,0.24578768,,,,,,,,,,,,,, -18958,,,0.7447752271379743,0.2680262838091169,0.7214467440955613,0.2869150723106886,3554.0,0.7389786972476263,0.2881993834648143,3581.0,4431.69629073143,4662.063468456268,4431.69629073143,224.58108806610107,5.0261549949646,0.0 -19000,0.070974395,0.2822093,,,,,,,,,,,,,, -19100,0.09298475,0.276066,,,,,,,,,,,,,, -19200,0.066675,0.23400201,,,,,,,,,,,,,, -19300,0.09884617,0.25821257,,,,,,,,,,,,,, -19304,,,0.748098645891462,0.268930333001273,0.7257016879088702,0.2876355413244759,3554.0,0.7427710922228428,0.2890040726010541,3581.0,4511.751484632492,4746.171770095825,4511.751484632492,228.59129190444943,5.055289030075073,0.0 -19400,0.15031916,0.24727628,,,,,,,,,,,,,, -19500,0.19395842,0.22661515,,,,,,,,,,,,,, -19600,0.11470514,0.34506986,,,,,,,,,,,,,, -19646,,,0.7443575177873883,0.2679600204740252,0.721302622814962,0.287134877871984,3554.0,0.7385132551792446,0.2885423461607267,3581.0,4591.897384405136,4830.373497962952,4591.897384405136,232.6044859886169,5.084141492843628,0.0 -19700,0.061308052,0.20440558,,,,,,,,,,,,,, -19800,0.16609944,0.22203478,,,,,,,,,,,,,, -19900,0.18180612,0.3176582,,,,,,,,,,,,,, -19992,,,0.7473198345729283,0.267649906022208,0.7242911815911649,0.2864868473188221,3554.0,0.7415709102729684,0.2878015044680257,3581.0,4671.86004781723,4914.393493652344,4671.86004781723,236.61859679222107,5.113516569137573,0.0 -20000,0.12282909,0.24372165,,,,,,,,,,,,,, -20100,0.061919734,0.38682348,,,,,,,,,,,,,, -20200,0.107326694,0.3605698,,,,,,,,,,,,,, -20300,0.071510024,0.270511,,,,,,,,,,,,,, -20337,,,0.7448357854570661,0.268222553389413,0.7214832896252462,0.2872340385338263,3554.0,0.7389167928389416,0.2885085646249302,3581.0,4752.01032614708,4998.600042819977,4752.01032614708,240.63190078735352,5.14256739616394,0.0 -20400,0.17535979,0.23779236,,,,,,,,,,,,,, -20500,0.057096686,0.23427509,,,,,,,,,,,,,, -20600,0.063610375,0.36659065,,,,,,,,,,,,,, -20681,,,0.7465263775416783,0.2680410317012242,0.7229192816412845,0.2873804954309844,3554.0,0.7401653120418529,0.2886559284788641,3581.0,4832.023688554764,5082.671446084976,4832.023688554764,244.6449682712555,5.173840522766113,0.0 -20700,0.06534628,0.25509822,,,,,,,,,,,,,, -20800,0.15507859,0.23758832,,,,,,,,,,,,,, -20900,0.08007307,0.30732417,,,,,,,,,,,,,, -21000,0.09991182,0.31468564,,,,,,,,,,,,,, -21026,,,0.7433980533054897,0.2681511810847691,0.7203869237347004,0.2870477731132614,3554.0,0.7379010287585521,0.2883169200293214,3581.0,4912.122037649155,5166.823796510696,4912.122037649155,248.65569019317627,5.203402757644653,0.0 -21100,0.10848597,0.27186278,,,,,,,,,,,,,, -21200,0.0767449,0.22271553,,,,,,,,,,,,,, -21300,0.06048272,0.2744095,,,,,,,,,,,,,, -21369,,,0.7488290241786412,0.2676476751055036,0.7261367308402504,0.2864730053559809,3554.0,0.7432473743804106,0.2877285213518396,3581.0,4992.088282823563,5250.844703674316,4992.088282823563,252.6671848297119,5.232451677322388,0.0 -21400,0.09548473,0.25959867,,,,,,,,,,,,,, -21500,0.08986076,0.21690658,,,,,,,,,,,,,, -21600,0.12273617,0.22525537,,,,,,,,,,,,,, -21700,0.13673721,0.30361855,,,,,,,,,,,,,, -21713,,,0.7478596142360142,0.2682633740561349,0.7245296205639772,0.2876049893990486,3554.0,0.7417139449088942,0.2888652308298136,3581.0,5072.231734991074,5335.045144796372,5072.231734991074,256.68112564086914,5.262057304382324,0.0 -21800,0.08617126,0.24849729,,,,,,,,,,,,,, -21900,0.052547332,0.34925148,,,,,,,,,,,,,, -22000,0.09131719,0.28725296,,,,,,,,,,,,,, -22057,,,0.7472703797476632,0.26737151827131,0.7236122728132034,0.2867679971620885,3554.0,0.7409195504485478,0.288064257321628,3581.0,5152.4338619709015,5419.302825212479,5152.4338619709015,260.69290471076965,5.292056083679199,0.0 -22100,0.11143046,0.27198493,,,,,,,,,,,,,, -22200,0.16023248,0.2840195,,,,,,,,,,,,,, -22300,0.08700913,0.3295805,,,,,,,,,,,,,, -22400,0.08245134,0.25472537,,,,,,,,,,,,,, -22403,,,0.746769768851144,0.2673004354749407,0.7232089667891812,0.2865335081290447,3554.0,0.7407228607799846,0.2876894220364423,3581.0,5232.52370929718,5503.448872804642,5232.52370929718,264.7044961452484,5.322473049163818,0.0 -22500,0.062740475,0.31020638,,,,,,,,,,,,,, -22600,0.11449087,0.22551617,,,,,,,,,,,,,, -22700,0.080029696,0.2146793,,,,,,,,,,,,,, -22747,,,0.7487727573939732,0.2672470297132219,0.7256340924178742,0.2863474487922939,3554.0,0.7426483742320581,0.2877149541961917,3581.0,5312.687821388245,5587.670019626617,5312.687821388245,268.7178628444672,5.352055549621582,0.0 -22800,0.117641434,0.23529515,,,,,,,,,,,,,, -22900,0.09090475,0.21186057,,,,,,,,,,,,,, -23000,0.08331454,0.27683604,,,,,,,,,,,,,, -23094,,,0.7463985170636859,0.267437253679548,0.722510411354284,0.28699774626741,3554.0,0.7398897419758796,0.288260537930222,3581.0,5392.799517869949,5671.839567661285,5392.799517869949,272.73209524154663,5.382074594497681,0.0 -23100,0.046469964,0.26837957,,,,,,,,,,,,,, -23200,0.04468746,0.30555356,,,,,,,,,,,,,, -23300,0.04039472,0.2421088,,,,,,,,,,,,,, -23400,0.04909189,0.2945602,,,,,,,,,,,,,, -23441,,,0.7481884275163923,0.2669349568230765,0.7248329072435987,0.2861438036367473,3554.0,0.7420801217580634,0.2874584395071209,3581.0,5472.933496236801,5756.026417970657,5472.933496236801,276.7411725521088,5.412250757217407,0.0 -23500,0.040307276,0.24339807,,,,,,,,,,,,,, -23600,0.05580763,0.28531244,,,,,,,,,,,,,, -23700,0.06072623,0.3375381,,,,,,,,,,,,,, -23786,,,0.7467733110700335,0.2668683358601161,0.7233074748522791,0.2860006269069622,3554.0,0.7407065665578749,0.2872618861918807,3581.0,5553.040453672409,5840.18537902832,5553.040453672409,280.7495219707489,5.442123651504517,0.0 -23800,0.047595683,0.21261275,,,,,,,,,,,,,, -23900,0.043947194,0.28559422,,,,,,,,,,,,,, -24000,0.08526703,0.18816139,,,,,,,,,,,,,, -24100,0.07294347,0.2946077,,,,,,,,,,,,,, -24132,,,0.7495803151811872,0.2663380248206002,0.726013904887099,0.2858426636635217,3554.0,0.7431660396231848,0.2871942549436261,3581.0,5633.144581317902,5924.342694759369,5633.144581317902,284.7582213878632,5.472839117050171,0.0 -24200,0.12191965,0.2723906,,,,,,,,,,,,,, -24300,0.0690752,0.24936192,,,,,,,,,,,,,, -24400,0.1255177,0.28515998,,,,,,,,,,,,,, -24482,,,0.7474779401506696,0.2665562970297677,0.7240453922956528,0.2857435030016794,3554.0,0.7414027184489319,0.2870046897361945,3581.0,5713.145759820938,6008.39787364006,5713.145759820938,288.76734495162964,5.504015684127808,0.0 -24500,0.096125536,0.3383948,,,,,,,,,,,,,, -24600,0.080781184,0.22328302,,,,,,,,,,,,,, -24700,0.07467315,0.27284324,,,,,,,,,,,,,, -24800,0.11735817,0.27337936,,,,,,,,,,,,,, -24825,,,0.7502619198390416,0.2664697340556553,0.7271315661050929,0.2857424038880047,3554.0,0.7442264594168877,0.2870229610814891,3581.0,5793.313705205917,6092.62002158165,5793.313705205917,292.7781751155853,5.533693075180054,0.0 -24900,0.09172441,0.2769463,,,,,,,,,,,,,, -25000,0.07829477,0.29496264,,,,,,,,,,,,,, -25100,0.080001876,0.33467487,,,,,,,,,,,,,, -25169,,,0.7488915579659599,0.2660852500370571,0.7247901105048888,0.2859390250202236,3554.0,0.7421172780386065,0.2872016180230731,3581.0,5873.269628047943,6176.635795116425,5873.269628047943,296.7916488647461,5.566459655761719,0.0 -25200,0.089209616,0.29129583,,,,,,,,,,,,,, -25300,0.081897154,0.28333935,,,,,,,,,,,,,, -25400,0.1267098,0.21745102,,,,,,,,,,,,,, -25500,0.08673865,0.26051658,,,,,,,,,,,,,, -25516,,,0.7491566794259208,0.2663029602595738,0.7255977529720034,0.2857857501835519,3554.0,0.7428563812264382,0.2870041784112329,3581.0,5953.41180896759,6260.837786436081,5953.41180896759,300.80610942840576,5.5982396602630615,0.0 -25600,0.029941205,0.3502701,,,,,,,,,,,,,, -25700,0.13524449,0.24314474,,,,,,,,,,,,,, -25800,0.043462176,0.2550799,,,,,,,,,,,,,, -25862,,,0.7497294970921108,0.2662316220147269,0.7261816571117051,0.2857248695901624,3554.0,0.7434290651834334,0.286961567997766,3581.0,6033.534689426422,6345.01830124855,6033.534689426422,304.81979846954346,5.628521680831909,0.0 -25900,0.116904885,0.21584056,,,,,,,,,,,,,, -26000,0.0872043,0.3255185,,,,,,,,,,,,,, -26100,0.089020655,0.29472658,,,,,,,,,,,,,, -26200,0.028589629,0.28168136,,,,,,,,,,,,,, -26206,,,0.7487617901393345,0.2659216778618948,0.7249687851716375,0.2857339029306767,3554.0,0.742172296604475,0.2869879523657846,3581.0,6113.551764726639,6429.092185497284,6113.551764726639,308.8319561481476,5.659749031066895,0.0 -26300,0.08888221,0.25877815,,,,,,,,,,,,,, -26400,0.06903823,0.40875673,,,,,,,,,,,,,, -26500,0.12160117,0.20846985,,,,,,,,,,,,,, -26552,,,0.7490940093994141,0.2660137755530221,0.7254139262099043,0.2856169675398758,3554.0,0.7426279894102555,0.2868769266724553,3581.0,6193.714996337891,6513.307105064392,6193.714996337891,312.83840131759644,5.691235303878784,0.0 -26600,0.044821642,0.3105908,,,,,,,,,,,,,, -26700,0.06834458,0.32788777,,,,,,,,,,,,,, -26800,0.11315011,0.2620054,,,,,,,,,,,,,, -26896,,,0.7483999388558524,0.2661850452423095,0.7247854392717712,0.2855410256544123,3554.0,0.7420892574307107,0.2867733322352345,3581.0,6273.787319660187,6597.434686660767,6273.787319660187,316.84984970092773,5.721496343612671,0.0 -26900,0.065356284,0.23126426,,,,,,,,,,,,,, -27000,0.03675696,0.22300242,,,,,,,,,,,,,, -27100,0.052937005,0.28382963,,,,,,,,,,,,,, -27200,0.062055968,0.35530275,,,,,,,,,,,,,, -27238,,,0.7507381439208984,0.2653639316558838,0.7267112925137169,0.2852600647213087,3554.0,0.743832057429838,0.2866217755166155,3581.0,6353.8551030159,6681.561727285385,6353.8551030159,320.86294293403625,5.753942012786865,0.0 -27300,0.082864255,0.22187945,,,,,,,,,,,,,, -27400,0.039017938,0.26053992,,,,,,,,,,,,,, -27500,0.04418468,0.33760804,,,,,,,,,,,,,, -27583,,,0.7501236370631627,0.2658776044845581,0.7266756400138928,0.2853921472724395,3554.0,0.7437995371622801,0.2866703513879677,3581.0,6433.951121091843,6765.719286441803,6433.951121091843,324.8746993541717,5.79008412361145,0.0 -27600,0.0473717,0.2995636,,,,,,,,,,,,,, -27700,0.05279986,0.37092897,,,,,,,,,,,,,, -27800,0.108443916,0.3137864,,,,,,,,,,,,,, -27900,0.034239046,0.30576026,,,,,,,,,,,,,, -27929,,,0.749838011605399,0.2655642032623291,0.7262190269766461,0.285093068137354,3554.0,0.7433997492189681,0.2863536026184201,3581.0,6514.103759050369,6849.925580263138,6514.103759050369,328.8831934928894,5.821483612060547,0.0 -28000,0.039885774,0.21225274,,,,,,,,,,,,,, -28100,0.07758919,0.21983753,,,,,,,,,,,,,, -28200,0.032159768,0.36830378,,,,,,,,,,,,,, -28272,,,0.749901294708252,0.265284776687622,0.7256960549512873,0.2852201188086927,3554.0,0.7428893105539653,0.2865293279635577,3581.0,6594.274115562439,6934.148802280426,6594.274115562439,332.8914361000061,5.852307319641113,0.0 -28300,0.11407425,0.26468566,,,,,,,,,,,,,, -28400,0.048578408,0.22966997,,,,,,,,,,,,,, -28500,0.06639101,0.2958942,,,,,,,,,,,,,, -28600,0.042699996,0.31307828,,,,,,,,,,,,,, -28616,,,0.7493587221418109,0.2654579707554408,0.7255252801640757,0.2851458427673923,3554.0,0.7427862956183677,0.2864145184655124,3581.0,6674.250906705856,7018.179235696793,6674.250906705856,336.89839482307434,5.885340690612793,0.0 -28700,0.042195007,0.2944864,,,,,,,,,,,,,, -28800,0.060862463,0.27083427,,,,,,,,,,,,,, -28900,0.03899379,0.30407572,,,,,,,,,,,,,, -28963,,,0.7498035430908203,0.2655364956174578,0.7261129625070343,0.2851735782140282,3554.0,0.7433507301993159,0.286447243263055,3581.0,6754.433007955551,7102.422847747803,6754.433007955551,340.9145243167877,5.916884183883667,0.0 -29000,0.059519593,0.22629997,,,,,,,,,,,,,, -29100,0.062501706,0.1857876,,,,,,,,,,,,,, -29200,0.104766235,0.189808,,,,,,,,,,,,,, -29300,0.04864347,0.2914729,,,,,,,,,,,,,, -29308,,,0.7496084485735212,0.2654083626610892,0.7259997537985369,0.2850448101775728,3554.0,0.743136519128735,0.2863590567513439,3581.0,6834.417027235031,7186.464999914169,6834.417027235031,344.9258232116699,5.949899196624756,0.0 -29400,0.040923465,0.279948,,,,,,,,,,,,,, -29500,0.04401119,0.21383479,,,,,,,,,,,,,, -29600,0.07629326,0.25480738,,,,,,,,,,,,,, -29652,,,0.7499695505414691,0.2650468349456787,0.7257326691755768,0.2850546850269942,3554.0,0.7429690772479755,0.2863031518888753,3581.0,6914.371897935867,7270.478240966797,6914.371897935867,348.9375340938568,5.982892036437988,0.0 -29700,0.042000547,0.30646393,,,,,,,,,,,,,, -29800,0.045462653,0.2866103,,,,,,,,,,,,,, -29900,0.044169243,0.24300753,,,,,,,,,,,,,, -29999,,,0.7508008139474052,0.265220182282584,0.7267960616558807,0.2850975504603088,3554.0,0.7440159298860305,0.2863494097537349,3581.0,6994.351936340332,7354.514622688293,6994.351936340332,352.9492189884186,6.013725280761719,0.0 -30000,0.0353692,0.2964908,,,,,,,,,,,,,, -30100,0.044856,0.23670569,,,,,,,,,,,,,, -30200,0.053703364,0.2794368,,,,,,,,,,,,,, -30300,0.05629635,0.21963844,,,,,,,,,,,,,, -30341,,,0.7502134186880929,0.2650173732212612,0.7262326285083709,0.2848916383828081,3554.0,0.7433912271362748,0.2861762069450747,3581.0,7074.379700660706,7438.5940990448,7074.379700660706,356.95574593544006,6.045287609100342,0.0 -30400,0.087292254,0.25211942,,,,,,,,,,,,,, -30500,0.08213622,0.22576113,,,,,,,,,,,,,, -30600,0.025576724,0.3225608,,,,,,,,,,,,,, -30688,,,0.7507006100245884,0.2646711894444057,0.7264637171584833,0.284826086556301,3554.0,0.7436298454516894,0.2861211542908755,3581.0,7154.448456525803,7522.719612360001,7154.448456525803,360.9666941165924,6.0774827003479,0.0 -30700,0.037957728,0.29127774,,,,,,,,,,,,,, -30800,0.049569182,0.35710806,,,,,,,,,,,,,, -30900,0.03409789,0.21051168,,,,,,,,,,,,,, -31000,0.022465382,0.27576464,,,,,,,,,,,,,, -31036,,,0.750760691506522,0.264853971345084,0.7266390944842079,0.2848708239175928,3554.0,0.7438296712466839,0.2861302217868612,3581.0,7234.592389583588,7606.922004938126,7234.592389583588,364.9779260158539,6.110843896865845,0.0 -31100,0.037761975,0.20042406,,,,,,,,,,,,,, -31200,0.059893124,0.21587211,,,,,,,,,,,,,, -31300,0.04327139,0.33905715,,,,,,,,,,,,,, -31381,,,0.7507956368582589,0.2647814410073416,0.7267478380434018,0.2847459542999525,3554.0,0.7439185736133412,0.285995538791975,3581.0,7314.720451116562,7691.110218048096,7314.720451116562,368.9930951595306,6.142259836196899,0.0 -31400,0.06630708,0.21994859,,,,,,,,,,,,,, -31500,0.067541175,0.26141503,,,,,,,,,,,,,, -31600,0.024640022,0.21866415,,,,,,,,,,,,,, -31700,0.03071734,0.23565696,,,,,,,,,,,,,, -31728,,,0.751133850642613,0.2644347974232265,0.7264723039840673,0.2848719573785699,3554.0,0.743655343523108,0.2861455615357093,3581.0,7394.727523565292,7775.174228429794,7394.727523565292,373.00463938713074,6.173776626586914,0.0 -31800,0.066161126,0.21908273,,,,,,,,,,,,,, -31900,0.025841251,0.3054301,,,,,,,,,,,,,, -32000,0.05549202,0.19069183,,,,,,,,,,,,,, -32074,,,0.7505387578691755,0.2645878280912127,0.7262119514323649,0.2847132041471757,3554.0,0.7434837428659942,0.2859557236216315,3581.0,7474.850043296814,7859.357992172241,7474.850043296814,377.02014422416687,6.205841541290283,0.0 -32100,0.033687886,0.32723862,,,,,,,,,,,,,, -32200,0.021971855,0.32196814,,,,,,,,,,,,,, -32300,0.03523714,0.29532075,,,,,,,,,,,,,, -32400,0.028929634,0.26865503,,,,,,,,,,,,,, -32417,,,0.7507827622549874,0.2646485567092895,0.7266321563291361,0.2846971467833339,3554.0,0.7438118089613586,0.2859554509149853,3581.0,7554.948711395264,7943.512165546417,7554.948711395264,381.0288183689117,6.23901629447937,0.0 -32500,0.01762483,0.24977073,,,,,,,,,,,,,, -32600,0.019168045,0.23371619,,,,,,,,,,,,,, -32700,0.024477446,0.24723557,,,,,,,,,,,,,, -32762,,,0.7513869830540248,0.2641993931361607,0.7267633630240574,0.2846715065221405,3554.0,0.7439133921870636,0.2859545305300544,3581.0,7635.059431314468,8027.681664466858,7635.059431314468,385.0398862361908,6.273144245147705,0.0 -32800,0.023012852,0.23229286,,,,,,,,,,,,,, -32900,0.02600868,0.21409112,,,,,,,,,,,,,, -33000,0.022527559,0.21421762,,,,,,,,,,,,,, -33100,0.022557767,0.23462267,,,,,,,,,,,,,, -33107,,,0.7509965215410505,0.2644737107413156,0.7265825588245639,0.2846685183068373,3554.0,0.7437964010358489,0.2859292881211167,3581.0,7715.03692984581,8111.719587564468,7715.03692984581,389.0552260875702,6.30461573600769,0.0 -33200,0.02247897,0.28879055,,,,,,,,,,,,,, -33300,0.016325217,0.3220177,,,,,,,,,,,,,, -33400,0.0144841885,0.2641151,,,,,,,,,,,,,, -33451,,,0.7508912086486816,0.2644854273114885,0.7264931184492825,0.2846644653251617,3554.0,0.7437085213191148,0.2859337366482826,3581.0,7794.994208097458,8195.742512226105,7794.994208097458,393.0689616203308,6.34293532371521,0.0 -33500,0.02426196,0.22226647,,,,,,,,,,,,,, -33600,0.025309244,0.298475,,,,,,,,,,,,,, -33700,0.026796851,0.24912438,,,,,,,,,,,,,, -33794,,,0.7513776506696429,0.2641565118517194,0.726762607383406,0.284640010045899,3554.0,0.7439062336376012,0.2859363614497521,3581.0,7875.045525312424,8279.850937843323,7875.045525312424,397.07849621772766,6.376688718795776,0.0 -33800,0.017096289,0.23839214,,,,,,,,,,,,,, -33900,0.018065317,0.2858318,,,,,,,,,,,,,, -34000,0.026724052,0.39730406,,,,,,,,,,,,,, -34100,0.02704514,0.21725225,,,,,,,,,,,,,, -34140,,,0.7512272426060268,0.2642498016357422,0.7266439031065349,0.2846393917944569,3554.0,0.7438452155255166,0.2859049149646135,3581.0,7955.107652187347,8363.969428777695,7955.107652187347,401.0880081653595,6.409821510314941,0.0 -34200,0.018421216,0.31162915,,,,,,,,,,,,,, -34300,0.019912232,0.2353409,,,,,,,,,,,,,, -34400,0.017959297,0.20668735,,,,,,,,,,,,,, -34485,,,0.7511179787772042,0.2643188067844936,0.7266925388866418,0.2845892962540007,3554.0,0.7438459654687937,0.2858476636130794,3581.0,8035.217735528946,8448.139996290207,8035.217735528946,405.1007053852081,6.443702220916748,0.0 -34500,0.024131602,0.23920837,,,,,,,,,,,,,, -34600,0.020863816,0.24269192,,,,,,,,,,,,,, -34700,0.05776164,0.35152122,,,,,,,,,,,,,, -34800,0.01749556,0.31407523,,,,,,,,,,,,,, -34828,,,0.7513903209141323,0.2641722304480416,0.7267753158852701,0.284604168635912,3554.0,0.7439223233297263,0.2859044888604789,3581.0,8115.333026409149,8532.313452720642,8115.333026409149,409.1116394996643,6.4771833419799805,0.0 -34900,0.025834054,0.17864096,,,,,,,,,,,,,, -35000,0.02121043,0.24012008,,,,,,,,,,,,,, -35100,0.02038718,0.22031723,,,,,,,,,,,,,, -35173,,,0.7513918195452008,0.2641371658870152,0.7268032058947664,0.284552355730339,3554.0,0.7439750238891022,0.2858396869436784,3581.0,8195.418248653412,8616.45526599884,8195.418248653412,413.1216921806336,6.510056257247925,0.0 -35200,0.013102603,0.21717474,,,,,,,,,,,,,, -35300,0.017587066,0.29003933,,,,,,,,,,,,,, -35400,0.01683473,0.24597171,,,,,,,,,,,,,, -35500,0.017224776,0.31872004,,,,,,,,,,,,,, -35519,,,0.751441410609654,0.264160360608782,0.7268765717325548,0.284526028523099,3554.0,0.7440558132330355,0.2858166432320755,3581.0,8275.447546958923,8700.54442024231,8275.447546958923,417.1336545944214,6.544367790222168,0.0 -35600,0.01659229,0.20208849,,,,,,,,,,,,,, -35700,0.024763552,0.20901486,,,,,,,,,,,,,, -35800,0.02967869,0.23518467,,,,,,,,,,,,,, -35863,,,0.751368590763637,0.264140316418239,0.7268266994495639,0.2844962150646719,3554.0,0.7440089758665527,0.2857884180941951,3581.0,8355.44581079483,8784.603637218475,8355.44581079483,421.14607095718384,6.579448938369751,0.0 -35900,0.018733252,0.19046234,,,,,,,,,,,,,, -36000,0.020906506,0.3020697,,,,,,,,,,,,,, -36100,0.015121041,0.24443705,,,,,,,,,,,,,, -36189,,,0.7514665467398507,0.2641313416617257,0.7269076903884707,0.2845016419384408,3554.0,0.744079675064577,0.285795951615296,3581.0,8430.592392921448,8863.808090209961,8430.592392921448,425.1565029621124,6.6140077114105225,0.0 -36189,,,,,,,,,,,8430.592392921448,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/eval_measurements.csv deleted file mode 100644 index f9a9ad0e1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,97 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -201.8237082958221,0.0,55.27614974975586,1,0,55.27614974975586,0.7883392137976473,3581,0.2942129740732337,257.10017442703247,0.7874368940080915,0.2754161528178623,0.7887035019959905,3554,0.2711523739756261 -206.1908564567566,0.029102087020874,135.5868763923645,335,0,135.5868763923645,0.3293810186356813,3581,0.700335142832833,341.8186767101288,0.3088119370596749,0.7016571589878627,0.3278064974654439,3554,0.6821352948482344 -210.20913743972773,0.0634787082672119,215.7821829319,570,0,215.7821829319,0.3174902589185981,3581,0.7096741864615331,426.0742771625519,0.2976064000810896,0.7112112726484027,0.3155408353703573,3554,0.6918897913222777 -214.2245299816132,0.1001718044281005,296.0674078464508,809,0,296.0674078464508,0.3132424779325782,3581,0.7159655969657568,510.41920804977417,0.2927102702004568,0.7181434631347656,0.3111197880029544,3554,0.6987013422376196 -218.2427520751953,0.1342196464538574,376.2434275150299,1085,0,376.2434275150299,0.3076909887361246,3581,0.7175513861133412,594.6569521427155,0.2878707477024623,0.7194068091256278,0.3055842050220702,3554,0.7001576678566404 -222.26331734657288,0.1570632457733154,456.43198323249817,1433,0,456.43198323249817,0.3031562862699839,3581,0.7231047842563181,678.9004311561584,0.2832481009619577,0.725543635232108,0.3011932458915131,3554,0.7061628128077518 -226.2819182872772,0.183495283126831,536.503087759018,1779,0,536.503087759018,0.3014893328068277,3581,0.7249152836803616,763.0277693271637,0.2818348578044346,0.727372373853411,0.2996103161160839,3554,0.7079292258722566 -230.29897046089167,0.2066645622253418,616.6463549137115,2124,0,616.6463549137115,0.3017955482821663,3581,0.7273907100844736,847.2227547168732,0.2815104722976684,0.7304275376456124,0.299705458143553,3554,0.7106021329399972 -234.31987977027893,0.2298219203948974,696.7213439941406,2470,0,696.7213439941406,0.2977084596874127,3581,0.7302026564899818,931.3529715538024,0.2781985827854701,0.7321715354919434,0.2960028189517972,3554,0.7131566105092854 -238.3439779281616,0.2531630992889404,776.725414276123,2814,0,776.725414276123,0.2958164209761589,3581,0.7317454943407917,1015.4160070419312,0.2757815633501325,0.7344308580671038,0.294157201008107,3554,0.7145839470051351 -242.36777329444885,0.2781813144683838,856.8844676017761,3159,0,856.8844676017761,0.2954341203465337,3581,0.7313013233908127,1099.6351277828217,0.2757396868297032,0.734090873173305,0.2939341152794386,3554,0.713944606319464 -246.39089345932007,0.3060920238494873,936.8562984466552,3505,0,936.8562984466552,0.2947556602991483,3581,0.7299829912864773,1183.6691632270813,0.2749878849302019,0.7327982357570103,0.2931661782938238,3554,0.7125102629739378 -250.41059923172,0.32985520362854,1017.0803380012512,3853,0,1017.0803380012512,0.2951208485867251,3581,0.7323862186060109,1267.948300600052,0.2754707336425781,0.7350469316755023,0.2935341409437429,3554,0.7151173606104038 -254.43634033203125,0.3551664352416992,1097.1687524318695,4197,0,1097.1687524318695,0.294075802630201,3581,0.7330801888438984,1352.0996026992798,0.2739982094083513,0.7359296253749302,0.2924870634320484,3554,0.7157386346150464 -258.4599637985229,0.3792397975921631,1177.3439629077911,4542,0,1177.3439629077911,0.2935119475508762,3581,0.7331703183904635,1436.3335757255554,0.2733580384935651,0.7363841874258858,0.2920301756163829,3554,0.7157795079048256 -262.4826068878174,0.4045064449310303,1257.5012693405151,4889,0,1257.5012693405151,0.2929092317744694,3581,0.7352506610409104,1520.5500855445862,0.2730294125420706,0.7380552291870117,0.2913801185723832,3554,0.7179574016513084 -266.5050644874573,0.4283595085144043,1337.5284023284912,5237,0,1337.5284023284912,0.2938717157938599,3581,0.7328710910229336,1604.6348433494568,0.2737502711159842,0.7361578941345215,0.2923185212194886,3554,0.7155856517304445 -270.5280725955963,0.4526963233947754,1417.6836984157562,5582,0,1417.6836984157562,0.2939875138534976,3581,0.732240047843654,1688.8488776683807,0.2737070322036743,0.7356574194771903,0.2923291345359102,3554,0.7151166049697524 -274.5515446662903,0.4803798198699951,1497.8876945972445,5930,0,1497.8876945972445,0.2922278401306374,3581,0.7344219055431443,1773.1156358718872,0.2719990525926862,0.7378465107509068,0.2905588058789392,3554,0.7171564225608469 -278.56909227371216,0.5047817230224609,1577.9124248027802,6274,0,1577.9124248027802,0.2925560766676557,3581,0.7346790679104999,1857.193552732468,0.2726098469325474,0.7375033923557827,0.2910405267941932,3554,0.7173943806714266 -282.5923149585724,0.5298738479614258,1658.0193195343018,6618,0,1658.0193195343018,0.2922553835019024,3581,0.7346147091419994,1941.3599860668185,0.2716326372964041,0.7380969864981515,0.2906893943224184,3554,0.7171778552775042 -286.6163022518158,0.5550334453582764,1738.11696600914,6965,0,1738.11696600914,0.2917652273937098,3581,0.7359029071619311,2025.51864361763,0.2713968753814697,0.7388999802725655,0.2902418489729881,3554,0.7183991079593416 -290.63995122909546,0.5848760604858398,1818.217258691788,7311,0,1818.217258691788,0.2925642237787105,3581,0.7355315488864842,2109.6839339733124,0.2722315958568028,0.7389180319649833,0.2909853306793402,3554,0.7183624250404473 -294.6629104614258,0.6104896068572998,1898.3242774009705,7655,0,1898.3242774009705,0.2917681249018256,3581,0.7366848934698758,2193.850837945938,0.2715520858764648,0.7400620324271066,0.2903149400323579,3554,0.7194567987874578 -298.62927532196045,0.6351032257080078,1978.3652493953705,8001,0,1978.3652493953705,0.2913099436478986,3581,0.7354586680352905,2277.894032239914,0.2712582860674177,0.7384415354047503,0.2899239990371764,3554,0.7179328089828363 -302.6508135795593,0.659902811050415,2058.3442318439484,8349,0,2058.3442318439484,0.2907514404365052,3581,0.7374274054907847,2361.930545568466,0.2707258292606899,0.7404849188668388,0.2892531962225661,3554,0.7200534801236284 -306.67108273506165,0.6853137016296387,2138.530520439148,8694,0,2138.530520439148,0.2910767453770769,3581,0.7370651828879852,2446.174116373062,0.2708195107323782,0.7401842389787946,0.2895578567942811,3554,0.7198569448596651 -310.69659090042114,0.710352897644043,2218.502999305725,9041,0,2218.502999305725,0.2917106860644722,3581,0.7357169894058923,2530.208495140076,0.2717066322054182,0.7388166018894741,0.2902902786692811,3554,0.7182672143183737 -314.72029852867126,0.7382581233978271,2298.4649019241333,9385,0,2298.4649019241333,0.291312602537699,3581,0.7357533957431583,2614.2332146167755,0.271361129624503,0.7385387420654297,0.2897665853505733,3554,0.718523170415377 -318.74656081199646,0.763106107711792,2378.51877951622,9731,0,2378.51877951622,0.2919103414178127,3581,0.7369561002295099,2698.349310874939,0.2714030572346279,0.7405737468174526,0.2901874084987866,3554,0.7199406148881542 -322.7638728618622,0.787938117980957,2458.4906923770905,10076,0,2458.4906923770905,0.2907320782646258,3581,0.7354978014390184,2782.374828338623,0.2708113874707903,0.7381803648812431,0.2894058012868423,3554,0.7178138299275464 -326.7850124835968,0.813319206237793,2538.6012523174286,10425,0,2538.6012523174286,0.2913566105727276,3581,0.7357311019748325,2866.543663740158,0.270908168384007,0.7391643524169922,0.2897740043678777,3554,0.7184152511914392 -330.80628180503845,0.8382964134216309,2618.594559669494,10769,0,2618.594559669494,0.2910979824071488,3581,0.7358581150952946,2950.5946168899536,0.2707622562135969,0.7393864904131208,0.2896606239228686,3554,0.7183939558639912 -334.8300590515137,0.8649129867553711,2698.567401409149,11116,0,2698.567401409149,0.2914646705812797,3581,0.735458736211952,3034.629213809967,0.2708562101636614,0.7392473902021136,0.2900183510766917,3554,0.717951219136888 -338.85645866394043,0.8903038501739502,2778.720671415329,11464,0,2778.720671415329,0.2926540465302988,3581,0.7351206481473401,3118.845671892166,0.2724556241716657,0.7386783191135952,0.2911367335880346,3554,0.7178643204619795 -342.8772482872009,0.9197978973388672,2858.910383224488,11809,0,2858.910383224488,0.2903113259978881,3581,0.7360054448608978,3203.0968782901764,0.2700128214699881,0.7395752498081752,0.2889353462867544,3554,0.7184557123135903 -346.899793624878,0.9456956386566162,2939.020454645157,12156,0,2939.020454645157,0.2907112162061924,3581,0.7359343366029042,3287.2667756080627,0.2702807358333042,0.7396104676382882,0.2892972294641601,3554,0.7184459576797271 -350.86227202415466,0.9773619174957277,3019.049448251724,12501,0,3019.049448251724,0.2926618186697152,3581,0.7345638493524853,3371.301284790039,0.2721385274614606,0.7382897649492536,0.2910956885617438,3554,0.7172479237742684 -354.8867268562317,1.0039846897125244,3099.242963552475,12846,0,3099.242963552475,0.2908907935327073,3581,0.7372216483262357,3455.5569927692413,0.2703834431512015,0.7408561025347028,0.2893926462700478,3554,0.719871782894274 -358.91018557548523,1.0313971042633057,3179.2792825698853,13193,0,3179.2792825698853,0.291466272732826,3581,0.7359990362547124,3539.655587911606,0.2707656451633998,0.7397722516741071,0.2900265600819499,3554,0.7187002651062183 -362.9328968524933,1.0576093196868896,3259.4413664340973,13540,0,3259.4413664340973,0.2903057014233105,3581,0.7370741822073094,3623.878072023392,0.2699011053357805,0.7408096449715751,0.2887142870489237,3554,0.7198803010252532 -366.9540934562683,1.0833580493927002,3339.4830079078674,13890,0,3339.4830079078674,0.2898894488162349,3581,0.7361964758665527,3707.978204965592,0.269716739654541,0.7395924840654645,0.2885267507781724,3554,0.7187545338439083 -370.97393560409546,1.112830638885498,3419.608423233032,14234,0,3419.608423233032,0.2901660415321314,3581,0.7397115281825957,3792.164309263229,0.2695150886263166,0.7435992785862514,0.2886484776176491,3554,0.722461569490363 -374.9954266548157,1.1386516094207764,3499.815920114517,14581,0,3499.815920114517,0.2897205070489214,3581,0.738634200576829,3876.430303573608,0.269396322114127,0.741945607321603,0.2883350928311409,3554,0.7212877160857485 -379.01821637153625,1.1643576622009275,3579.806512117386,14929,0,3579.806512117386,0.2895457702653763,3581,0.7374406317631248,3960.480762004852,0.2694524015699114,0.7403436388288226,0.2881380423576428,3554,0.7200152172288267 -383.04474329948425,1.1917848587036133,3659.9155492782593,15273,0,3659.9155492782593,0.290189426127042,3581,0.7385734551713907,4044.655151128769,0.2695294788905552,0.7423791204180036,0.2886782567287739,3554,0.7212960281329136 -387.0701777935028,1.2184903621673584,3740.0139966011047,15618,0,3740.0139966011047,0.2900657195746648,3581,0.7362009755262148,4128.817002296448,0.2698123455047607,0.7393215043204171,0.2887839090307576,3554,0.7186708638154192 -391.0984275341034,1.2453598976135254,3820.050070762634,15963,0,3820.050070762634,0.2911154356325048,3581,0.7371156336175301,4212.91955447197,0.2705531631197248,0.7407644816807338,0.2897495147413126,3554,0.7197459343785172 -395.1222703456879,1.2719926834106443,3900.068152904512,16306,0,3900.068152904512,0.2900724349758273,3581,0.7390276480906172,4296.999444723129,0.2694735697337559,0.742877687726702,0.2886311322299698,3554,0.7217908353703574 -399.14530396461487,1.300278902053833,3980.2133026123047,16653,0,3980.2133026123047,0.2897662535888194,3581,0.7379176638639695,4381.207130908966,0.269013694354466,0.7418242863246373,0.2883058632768535,3554,0.7205337241048818 -403.1664481163025,1.3272216320037842,4060.215700864792,16997,0,4060.215700864792,0.2902862028981081,3581,0.737869872024225,4465.268810033798,0.2694042580468314,0.7416016033717564,0.2887377119091165,3554,0.7206827227024127 -407.1881792545319,1.3539175987243652,4140.239362716675,17341,0,4140.239362716675,0.2898127159836637,3581,0.737314027702632,4549.352287769318,0.2689962557383946,0.7411505835396903,0.2881981501367297,3554,0.7200577391891179 -411.21010637283325,1.3826377391815186,4220.379903554916,17687,0,4220.379903554916,0.2892712228493263,3581,0.738631064450398,4633.554921627045,0.2687692982809884,0.7420896802629743,0.2878214117510639,3554,0.721324536393852 -415.2354400157929,1.4109349250793457,4300.405557632446,18034,0,4300.405557632446,0.2894379488751396,3581,0.738011065890289,4717.645820856094,0.2690235035760062,0.7414772169930595,0.2880781406623698,3554,0.7205499360315841 -419.25944995880127,1.4389152526855469,4380.436009883881,18376,0,4380.436009883881,0.2907369188075956,3581,0.7376423665046425,4801.739559173584,0.2698216608592442,0.7417355946132115,0.2891783534507773,3554,0.7205434100441404 -423.2813596725464,1.4707441329956057,4460.510647535324,18722,0,4460.510647535324,0.2893671133237922,3581,0.7383829014023666,4885.879251480103,0.2683524574552263,0.7425732612609863,0.2878672825733329,3554,0.7210469414963773 -427.30615234375,1.4997599124908447,4540.564539909363,19071,0,4540.564539909363,0.2900805820868821,3581,0.7375544867879084,4969.998610019684,0.269416298185076,0.7413437707083566,0.2886675060231429,3554,0.7201945101470174 -431.3310775756836,1.5268681049346924,4620.680071592331,19414,0,4620.680071592331,0.289402224304489,3581,0.7385687509817439,5054.177350997925,0.268864699772426,0.7421463557652065,0.2880018553038829,3554,0.7212591391302055 -435.3556377887726,1.553640365600586,4700.793580293655,19761,0,4700.793580293655,0.2890437855064053,3581,0.7378884160761658,5138.35369515419,0.2684040069580078,0.7418694496154785,0.2876706099201603,3554,0.7203742152328363 -439.3786163330078,1.5818867683410645,4780.86026096344,20106,0,4780.86026096344,0.2893898502404182,3581,0.7390157171748464,5222.482827663422,0.2686488287789481,0.7428394726344517,0.2879764383001547,3554,0.7216941820615855 -443.4014058113098,1.6094567775726318,4860.848987817764,20450,0,4860.848987817764,0.2890304910574036,3581,0.7386649482511868,5306.533141613007,0.268344555582319,0.7424825940813337,0.2875891896399743,3554,0.7212864795828644 -447.4226930141449,1.6419434547424316,4940.809394598007,20795,0,4940.809394598007,0.289724018146991,3581,0.7381565548860305,5390.558655261993,0.268923418862479,0.7422796658107212,0.2882749163574493,3554,0.7209059801675929 -451.4455780982971,1.6706173419952393,5020.819479465485,21143,0,5020.819479465485,0.2891526977232267,3581,0.7366146033318207,5474.631682395935,0.2688467161996024,0.7402078083583287,0.2878423636054885,3554,0.7190345330525464 -455.46900844573975,1.699165105819702,5100.909521818161,21492,0,5100.909521818161,0.2888301198491169,3581,0.7382451163693801,5558.785286426544,0.2684067487716675,0.7418524878365653,0.2874485717842132,3554,0.7208238214204066 -459.4944705963135,1.727196455001831,5180.924647092819,21835,0,5180.924647092819,0.2895198631339884,3581,0.7389221106185423,5642.864949226379,0.268767237663269,0.7428131103515625,0.2881664475766741,3554,0.7215138587243247 -463.5176067352295,1.7548956871032717,5261.056898117065,22183,0,5261.056898117065,0.2891085874232058,3581,0.7386228150743508,5727.059560060501,0.2683233533586774,0.7425223759242466,0.287667398447392,3554,0.7213983143992684 -467.53966760635376,1.7834970951080322,5341.214492321014,22531,0,5341.214492321014,0.2886374866919157,3581,0.7405354431373918,5811.279220342636,0.2679621321814401,0.7441225733075824,0.2872024218420266,3554,0.7233546680456879 -471.5667657852173,1.812317132949829,5421.3310425281525,22876,0,5421.3310425281525,0.2891892745021467,3581,0.7389800607808573,5895.463119029999,0.2679972989218576,0.7432924679347447,0.2877167726944991,3554,0.7215693639648987 -475.58651852607727,1.84131932258606,5501.379611253738,23222,0,5501.379611253738,0.2884190486683189,3581,0.7390121038117844,5979.571918964386,0.267835259437561,0.7425169263567243,0.2870750276976646,3554,0.7215940940225802 -479.6103277206421,1.8737876415252688,5581.438790082932,23567,0,5581.438790082932,0.288459818311924,3581,0.7396102858401983,6063.698635339737,0.2676151139395578,0.7436034338814872,0.2870295518693725,3554,0.7222664081184933 -483.6366469860077,1.902566194534301,5661.581089258194,23910,0,5661.581089258194,0.2890792373704098,3581,0.7397383897872452,6147.907499790192,0.2682526792798723,0.7438061577933175,0.2877475993983451,3554,0.7224687137292487 -487.6585056781769,1.932025671005249,5741.601750612259,24254,0,5741.601750612259,0.2885350171696104,3581,0.7383813333391511,6231.99070930481,0.267869165965489,0.742143086024693,0.2871526354272914,3554,0.7209467160681626 -491.6800878047943,1.9605860710144043,5821.691652059555,24600,0,5821.691652059555,0.2893403880724658,3581,0.7388236635192683,6316.142263650894,0.2684909956795828,0.7429958071027484,0.2879671645285242,3554,0.7214709245964055 -495.7059428691864,1.9899907112121584,5901.788957595825,24942,0,5901.788957595825,0.2886966640341385,3581,0.7389417454970678,6400.3060693740845,0.2679623876299177,0.7429577963692802,0.2873232556516425,3554,0.7216223275050999 -499.72960019111633,2.020745277404785,5981.759289741516,25286,0,5981.759289741516,0.2882990918323269,3581,0.7401606760288676,6484.3422310352325,0.2674739871706281,0.7442034993852887,0.2869437351344875,3554,0.7227554450091446 -503.7518606185913,2.050223588943481,6061.882197141647,25632,0,6061.882197141647,0.2882901947779949,3581,0.7397795003141581,6568.528303146362,0.2674898760659354,0.7437772750854492,0.2869256169325056,3554,0.722484719572137 -507.7743980884552,2.079303741455078,6141.866056203842,25975,0,6141.866056203842,0.2887357292612049,3581,0.7398986731185423,6652.575384140015,0.2677900450570242,0.7442035675048828,0.2873275490644344,3554,0.7226171627699424 -511.7950129508972,2.1087591648101807,6222.008386850357,26317,0,6222.008386850357,0.2884458420963069,3581,0.7379803182159314,6736.779012918472,0.2675496169498988,0.7420498303004673,0.2871192326757703,3554,0.7204110355409398 -515.8137700557709,2.137401580810547,6302.204536914825,26663,0,6302.204536914825,0.2882063374842921,3581,0.7390444877260193,6821.033949136734,0.2673583030700683,0.7431620870317731,0.2868381687007597,3554,0.7216088633625845 -519.8362936973572,2.1678225994110107,6382.178173780441,27009,0,6382.178173780441,0.2884547391506388,3581,0.7397913630532672,6905.071727991104,0.2674673965999058,0.7441284315926688,0.2871214309031197,3554,0.7224542191676632 -523.8579688072205,2.198601007461548,6462.302304506302,27354,0,6462.302304506302,0.2886721886126431,3581,0.7403103238009634,6989.259694099426,0.2674651827130999,0.7447470937456403,0.2872797032722812,3554,0.7230006160532146 -527.8807153701782,2.23337459564209,6542.314068317413,27700,0,6542.314068317413,0.288021510554838,3581,0.7397503207030159,7073.3403215408325,0.2670155423028128,0.7439596993582589,0.2866732844758986,3554,0.7222767810037986 -531.9025478363037,2.262672185897827,6622.435270547867,28045,0,6622.435270547867,0.2889350778195685,3581,0.7397511388229545,7157.523993968964,0.267963205065046,0.7440105165754046,0.2876472022336188,3554,0.7223733656179657 -535.9265110492706,2.2928481101989746,6702.4539477825165,28391,0,6702.4539477825165,0.2899664543554524,3581,0.7392949006038816,7241.608174085617,0.268479687826974,0.7441020011901855,0.2886336052357379,3554,0.7219329645074212 -539.9468801021576,2.322210788726806,6782.633191108704,28738,0,6782.633191108704,0.2879308696833112,3581,0.7395879920718724,7325.848667621613,0.2669568913323538,0.7438011169433594,0.2865940452494109,3554,0.7221119139525887 -543.9663681983948,2.3514530658721924,6862.738671541214,29081,0,6862.738671541214,0.287915700376117,3581,0.7403883179017733,7410.014344930649,0.267032333782741,0.7445215497698102,0.2865559025701674,3554,0.7230511752822524 -547.9826490879059,2.3811745643615723,6942.728103876114,29421,0,6942.728103876114,0.2883681888788048,3581,0.7399836894154915,7494.060920238495,0.2672220127923148,0.7445074490138462,0.2870260999654878,3554,0.7226337181696679 -552.0078539848328,2.4109811782836914,7022.759325265884,29768,0,7022.759325265884,0.2881304227716594,3581,0.7407431092484641,7578.158590555191,0.2670083386557443,0.7450000217982701,0.2868083895896349,3554,0.723371773002251 -556.0319905281067,2.440687417984009,7102.75586271286,30115,0,7102.75586271286,0.2878097538440728,3581,0.7410003397924811,7662.220293045044,0.2668940850666591,0.745185102735247,0.2865342294223937,3554,0.7236418114932118 -560.0553650856018,2.4702839851379395,7182.74059510231,30457,0,7182.74059510231,0.2879080986783545,3581,0.7395849241221027,7746.269132852554,0.2668058701923915,0.7440342221941266,0.2865863686273389,3554,0.7220873899787211 -564.0749621391296,2.5015549659729004,7262.764033317566,30802,0,7262.764033317566,0.2881001182456018,3581,0.740367728549986,7830.354666471481,0.2671293360846383,0.7446126937866211,0.2867912846330719,3554,0.7229679174213914 -568.098906993866,2.532049655914306,7342.755488395691,31147,0,7342.755488395691,0.2880672230064053,3581,0.7406983853584892,7914.412054538727,0.2669330665043422,0.7450896671840123,0.2867378574042892,3554,0.7233025288407429 -572.1239104270935,2.563407182693481,7422.793350458145,31489,0,7422.793350458145,0.2878837596101822,3581,0.74044149569778,7998.517551660538,0.2668275151933942,0.7447887829371861,0.2865797052506858,3554,0.7230282999788971 -576.1479756832123,2.593926191329956,7502.960831642151,31837,0,7502.960831642151,0.2877985728715792,3581,0.740141654740296,8082.751039028168,0.2666935239519392,0.7444804736546108,0.2864793424332618,3554,0.7226634629334905 -580.1747455596924,2.6239583492279053,7583.132245779037,32183,0,7583.132245779037,0.2877483607603497,3581,0.7408316707318138,8166.990717411041,0.2666273287364414,0.7452107838221959,0.286437782197436,3554,0.7234276217158483 -584.1944150924683,2.6562201976776123,7663.143029689789,32527,0,7663.143029689789,0.2880450655914025,3581,0.7412730464386693,8251.064739465714,0.2667953797749111,0.7456954547337123,0.286699027778999,3554,0.7239288862461312 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/measurements.csv deleted file mode 100644 index 00152386b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/measurements.csv +++ /dev/null @@ -1,424 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.7338097,0.8675922,,,,,,,,,,,,,, -1,,,0.2754161528178623,0.7874368940080915,0.2711523739756261,0.7887035019959905,3554.0,0.2942129740732337,0.7883392137976473,3581.0,55.27614974975586,257.10017442703247,55.27614974975586,201.8237082958221,0.0,0.0 -100,0.27257302,0.29150486,,,,,,,,,,,,,, -200,0.2589629,0.38168514,,,,,,,,,,,,,, -300,0.17125995,0.38091353,,,,,,,,,,,,,, -335,,,0.7016571589878627,0.3088119370596749,0.6821352948482344,0.3278064974654439,3554.0,0.700335142832833,0.3293810186356813,3581.0,135.5868763923645,341.8186767101288,135.5868763923645,206.1908564567566,0.029102087020874,0.0 -400,0.08627533,0.30232683,,,,,,,,,,,,,, -500,0.12504089,0.29936007,,,,,,,,,,,,,, -570,,,0.7112112726484027,0.2976064000810896,0.6918897913222777,0.3155408353703573,3554.0,0.7096741864615331,0.3174902589185981,3581.0,215.7821829319,426.0742771625519,215.7821829319,210.20913743972773,0.0634787082672119,0.0 -600,0.107785344,0.37188053,,,,,,,,,,,,,, -700,0.37273005,0.28078055,,,,,,,,,,,,,, -800,0.22444786,0.2863211,,,,,,,,,,,,,, -809,,,0.7181434631347656,0.2927102702004568,0.6987013422376196,0.3111197880029544,3554.0,0.7159655969657568,0.3132424779325782,3581.0,296.0674078464508,510.41920804977417,296.0674078464508,214.2245299816132,0.1001718044281005,0.0 -900,0.26899448,0.30442417,,,,,,,,,,,,,, -1000,0.15249602,0.23173304,,,,,,,,,,,,,, -1085,,,0.7194068091256278,0.2878707477024623,0.7001576678566404,0.3055842050220702,3554.0,0.7175513861133412,0.3076909887361246,3581.0,376.2434275150299,594.6569521427155,376.2434275150299,218.2427520751953,0.1342196464538574,0.0 -1100,0.14004692,0.26989838,,,,,,,,,,,,,, -1200,0.34471637,0.24770275,,,,,,,,,,,,,, -1300,0.23452508,0.3452118,,,,,,,,,,,,,, -1400,0.084050134,0.3243309,,,,,,,,,,,,,, -1433,,,0.725543635232108,0.2832481009619577,0.7061628128077518,0.3011932458915131,3554.0,0.7231047842563181,0.3031562862699839,3581.0,456.43198323249817,678.9004311561584,456.43198323249817,222.26331734657288,0.1570632457733154,0.0 -1500,0.17049678,0.33319467,,,,,,,,,,,,,, -1600,0.23990354,0.21673386,,,,,,,,,,,,,, -1700,0.16468829,0.35379025,,,,,,,,,,,,,, -1779,,,0.727372373853411,0.2818348578044346,0.7079292258722566,0.2996103161160839,3554.0,0.7249152836803616,0.3014893328068277,3581.0,536.503087759018,763.0277693271637,536.503087759018,226.2819182872772,0.183495283126831,0.0 -1800,0.23538998,0.36362392,,,,,,,,,,,,,, -1900,0.29376575,0.32109317,,,,,,,,,,,,,, -2000,0.1181399,0.32342735,,,,,,,,,,,,,, -2100,0.24886413,0.30132243,,,,,,,,,,,,,, -2124,,,0.7304275376456124,0.2815104722976684,0.7106021329399972,0.299705458143553,3554.0,0.7273907100844736,0.3017955482821663,3581.0,616.6463549137115,847.2227547168732,616.6463549137115,230.29897046089167,0.2066645622253418,0.0 -2200,0.08359691,0.2635861,,,,,,,,,,,,,, -2300,0.21410973,0.259557,,,,,,,,,,,,,, -2400,0.21092258,0.43794677,,,,,,,,,,,,,, -2470,,,0.7321715354919434,0.2781985827854701,0.7131566105092854,0.2960028189517972,3554.0,0.7302026564899818,0.2977084596874127,3581.0,696.7213439941406,931.3529715538024,696.7213439941406,234.31987977027893,0.2298219203948974,0.0 -2500,0.113017194,0.33653015,,,,,,,,,,,,,, -2600,0.22998418,0.28623986,,,,,,,,,,,,,, -2700,0.09801333,0.29073066,,,,,,,,,,,,,, -2800,0.10790854,0.29849824,,,,,,,,,,,,,, -2814,,,0.7344308580671038,0.2757815633501325,0.7145839470051351,0.294157201008107,3554.0,0.7317454943407917,0.2958164209761589,3581.0,776.725414276123,1015.4160070419312,776.725414276123,238.3439779281616,0.2531630992889404,0.0 -2900,0.1413616,0.34451628,,,,,,,,,,,,,, -3000,0.16347328,0.2094135,,,,,,,,,,,,,, -3100,0.14418936,0.26691166,,,,,,,,,,,,,, -3159,,,0.734090873173305,0.2757396868297032,0.713944606319464,0.2939341152794386,3554.0,0.7313013233908127,0.2954341203465337,3581.0,856.8844676017761,1099.6351277828217,856.8844676017761,242.36777329444885,0.2781813144683838,0.0 -3200,0.09630538,0.22705914,,,,,,,,,,,,,, -3300,0.17130451,0.27657968,,,,,,,,,,,,,, -3400,0.12797052,0.3154865,,,,,,,,,,,,,, -3500,0.14940108,0.35790664,,,,,,,,,,,,,, -3505,,,0.7327982357570103,0.2749878849302019,0.7125102629739378,0.2931661782938238,3554.0,0.7299829912864773,0.2947556602991483,3581.0,936.8562984466552,1183.6691632270813,936.8562984466552,246.39089345932007,0.3060920238494873,0.0 -3600,0.07286058,0.3215726,,,,,,,,,,,,,, -3700,0.27302802,0.23081434,,,,,,,,,,,,,, -3800,0.22575867,0.2798206,,,,,,,,,,,,,, -3853,,,0.7350469316755023,0.2754707336425781,0.7151173606104038,0.2935341409437429,3554.0,0.7323862186060109,0.2951208485867251,3581.0,1017.0803380012512,1267.948300600052,1017.0803380012512,250.41059923172,0.32985520362854,0.0 -3900,0.083343744,0.32231152,,,,,,,,,,,,,, -4000,0.16785318,0.31869742,,,,,,,,,,,,,, -4100,0.15610832,0.30462134,,,,,,,,,,,,,, -4197,,,0.7359296253749302,0.2739982094083513,0.7157386346150464,0.2924870634320484,3554.0,0.7330801888438984,0.294075802630201,3581.0,1097.1687524318695,1352.0996026992798,1097.1687524318695,254.43634033203125,0.3551664352416992,0.0 -4200,0.104223445,0.25913507,,,,,,,,,,,,,, -4300,0.1217054,0.44295844,,,,,,,,,,,,,, -4400,0.11147143,0.2774009,,,,,,,,,,,,,, -4500,0.24267848,0.28160962,,,,,,,,,,,,,, -4542,,,0.7363841874258858,0.2733580384935651,0.7157795079048256,0.2920301756163829,3554.0,0.7331703183904635,0.2935119475508762,3581.0,1177.3439629077911,1436.3335757255554,1177.3439629077911,258.4599637985229,0.3792397975921631,0.0 -4600,0.06410903,0.25155675,,,,,,,,,,,,,, -4700,0.1399433,0.2843745,,,,,,,,,,,,,, -4800,0.14531812,0.27947623,,,,,,,,,,,,,, -4889,,,0.7380552291870117,0.2730294125420706,0.7179574016513084,0.2913801185723832,3554.0,0.7352506610409104,0.2929092317744694,3581.0,1257.5012693405151,1520.5500855445862,1257.5012693405151,262.4826068878174,0.4045064449310303,0.0 -4900,0.09568358,0.30742106,,,,,,,,,,,,,, -5000,0.15664577,0.35746595,,,,,,,,,,,,,, -5100,0.14128824,0.2804004,,,,,,,,,,,,,, -5200,0.0925189,0.2525345,,,,,,,,,,,,,, -5237,,,0.7361578941345215,0.2737502711159842,0.7155856517304445,0.2923185212194886,3554.0,0.7328710910229336,0.2938717157938599,3581.0,1337.5284023284912,1604.6348433494568,1337.5284023284912,266.5050644874573,0.4283595085144043,0.0 -5300,0.075950295,0.23319116,,,,,,,,,,,,,, -5400,0.18100786,0.37970266,,,,,,,,,,,,,, -5500,0.13503936,0.2987436,,,,,,,,,,,,,, -5582,,,0.7356574194771903,0.2737070322036743,0.7151166049697524,0.2923291345359102,3554.0,0.732240047843654,0.2939875138534976,3581.0,1417.6836984157562,1688.8488776683807,1417.6836984157562,270.5280725955963,0.4526963233947754,0.0 -5600,0.11677386,0.28088474,,,,,,,,,,,,,, -5700,0.106305905,0.28334194,,,,,,,,,,,,,, -5800,0.12183828,0.3038551,,,,,,,,,,,,,, -5900,0.11007997,0.3649738,,,,,,,,,,,,,, -5930,,,0.7378465107509068,0.2719990525926862,0.7171564225608469,0.2905588058789392,3554.0,0.7344219055431443,0.2922278401306374,3581.0,1497.8876945972445,1773.1156358718872,1497.8876945972445,274.5515446662903,0.4803798198699951,0.0 -6000,0.14248799,0.32109618,,,,,,,,,,,,,, -6100,0.06840112,0.2798931,,,,,,,,,,,,,, -6200,0.11377606,0.28471223,,,,,,,,,,,,,, -6274,,,0.7375033923557827,0.2726098469325474,0.7173943806714266,0.2910405267941932,3554.0,0.7346790679104999,0.2925560766676557,3581.0,1577.9124248027802,1857.193552732468,1577.9124248027802,278.56909227371216,0.5047817230224609,0.0 -6300,0.22460598,0.23435624,,,,,,,,,,,,,, -6400,0.12944655,0.2987815,,,,,,,,,,,,,, -6500,0.2526553,0.33517358,,,,,,,,,,,,,, -6600,0.27014425,0.2755266,,,,,,,,,,,,,, -6618,,,0.7380969864981515,0.2716326372964041,0.7171778552775042,0.2906893943224184,3554.0,0.7346147091419994,0.2922553835019024,3581.0,1658.0193195343018,1941.3599860668185,1658.0193195343018,282.5923149585724,0.5298738479614258,0.0 -6700,0.17532629,0.34974432,,,,,,,,,,,,,, -6800,0.09619828,0.27894506,,,,,,,,,,,,,, -6900,0.10469079,0.3049395,,,,,,,,,,,,,, -6965,,,0.7388999802725655,0.2713968753814697,0.7183991079593416,0.2902418489729881,3554.0,0.7359029071619311,0.2917652273937098,3581.0,1738.11696600914,2025.51864361763,1738.11696600914,286.6163022518158,0.5550334453582764,0.0 -7000,0.28173387,0.268291,,,,,,,,,,,,,, -7100,0.18702354,0.2696089,,,,,,,,,,,,,, -7200,0.11917335,0.36671662,,,,,,,,,,,,,, -7300,0.058589473,0.25894114,,,,,,,,,,,,,, -7311,,,0.7389180319649833,0.2722315958568028,0.7183624250404473,0.2909853306793402,3554.0,0.7355315488864842,0.2925642237787105,3581.0,1818.217258691788,2109.6839339733124,1818.217258691788,290.63995122909546,0.5848760604858398,0.0 -7400,0.1194047,0.24322388,,,,,,,,,,,,,, -7500,0.1469257,0.25937995,,,,,,,,,,,,,, -7600,0.20462121,0.31089926,,,,,,,,,,,,,, -7655,,,0.7400620324271066,0.2715520858764648,0.7194567987874578,0.2903149400323579,3554.0,0.7366848934698758,0.2917681249018256,3581.0,1898.3242774009705,2193.850837945938,1898.3242774009705,294.6629104614258,0.6104896068572998,0.0 -7700,0.13607526,0.27140564,,,,,,,,,,,,,, -7800,0.14254807,0.245042,,,,,,,,,,,,,, -7900,0.07503486,0.25084135,,,,,,,,,,,,,, -8000,0.19122943,0.27727503,,,,,,,,,,,,,, -8001,,,0.7384415354047503,0.2712582860674177,0.7179328089828363,0.2899239990371764,3554.0,0.7354586680352905,0.2913099436478986,3581.0,1978.3652493953705,2277.894032239914,1978.3652493953705,298.62927532196045,0.6351032257080078,0.0 -8100,0.17420909,0.2177901,,,,,,,,,,,,,, -8200,0.12935306,0.25691468,,,,,,,,,,,,,, -8300,0.25498992,0.22240311,,,,,,,,,,,,,, -8349,,,0.7404849188668388,0.2707258292606899,0.7200534801236284,0.2892531962225661,3554.0,0.7374274054907847,0.2907514404365052,3581.0,2058.3442318439484,2361.930545568466,2058.3442318439484,302.6508135795593,0.659902811050415,0.0 -8400,0.07179847,0.2667808,,,,,,,,,,,,,, -8500,0.103897445,0.31444967,,,,,,,,,,,,,, -8600,0.17644006,0.23727529,,,,,,,,,,,,,, -8694,,,0.7401842389787946,0.2708195107323782,0.7198569448596651,0.2895578567942811,3554.0,0.7370651828879852,0.2910767453770769,3581.0,2138.530520439148,2446.174116373062,2138.530520439148,306.67108273506165,0.6853137016296387,0.0 -8700,0.10887247,0.25715426,,,,,,,,,,,,,, -8800,0.14147115,0.2917154,,,,,,,,,,,,,, -8900,0.34630296,0.29216316,,,,,,,,,,,,,, -9000,0.2255619,0.23291713,,,,,,,,,,,,,, -9041,,,0.7388166018894741,0.2717066322054182,0.7182672143183737,0.2902902786692811,3554.0,0.7357169894058923,0.2917106860644722,3581.0,2218.502999305725,2530.208495140076,2218.502999305725,310.69659090042114,0.710352897644043,0.0 -9100,0.31153294,0.25740436,,,,,,,,,,,,,, -9200,0.39285576,0.22667448,,,,,,,,,,,,,, -9300,0.098324046,0.34733933,,,,,,,,,,,,,, -9385,,,0.7385387420654297,0.271361129624503,0.718523170415377,0.2897665853505733,3554.0,0.7357533957431583,0.291312602537699,3581.0,2298.4649019241333,2614.2332146167755,2298.4649019241333,314.72029852867126,0.7382581233978271,0.0 -9400,0.19905408,0.23771982,,,,,,,,,,,,,, -9500,0.15127607,0.30108434,,,,,,,,,,,,,, -9600,0.1956144,0.29205525,,,,,,,,,,,,,, -9700,0.079704754,0.33365136,,,,,,,,,,,,,, -9731,,,0.7405737468174526,0.2714030572346279,0.7199406148881542,0.2901874084987866,3554.0,0.7369561002295099,0.2919103414178127,3581.0,2378.51877951622,2698.349310874939,2378.51877951622,318.74656081199646,0.763106107711792,0.0 -9800,0.16968875,0.29035503,,,,,,,,,,,,,, -9900,0.11456586,0.25099304,,,,,,,,,,,,,, -10000,0.10075717,0.23337907,,,,,,,,,,,,,, -10076,,,0.7381803648812431,0.2708113874707903,0.7178138299275464,0.2894058012868423,3554.0,0.7354978014390184,0.2907320782646258,3581.0,2458.4906923770905,2782.374828338623,2458.4906923770905,322.7638728618622,0.787938117980957,0.0 -10100,0.14250432,0.27302718,,,,,,,,,,,,,, -10200,0.082512446,0.28709608,,,,,,,,,,,,,, -10300,0.18266828,0.29185367,,,,,,,,,,,,,, -10400,0.21751863,0.34183314,,,,,,,,,,,,,, -10425,,,0.7391643524169922,0.270908168384007,0.7184152511914392,0.2897740043678777,3554.0,0.7357311019748325,0.2913566105727276,3581.0,2538.6012523174286,2866.543663740158,2538.6012523174286,326.7850124835968,0.813319206237793,0.0 -10500,0.20463173,0.30688503,,,,,,,,,,,,,, -10600,0.11220282,0.21668713,,,,,,,,,,,,,, -10700,0.18874013,0.25209522,,,,,,,,,,,,,, -10769,,,0.7393864904131208,0.2707622562135969,0.7183939558639912,0.2896606239228686,3554.0,0.7358581150952946,0.2910979824071488,3581.0,2618.594559669494,2950.5946168899536,2618.594559669494,330.80628180503845,0.8382964134216309,0.0 -10800,0.26332942,0.26148468,,,,,,,,,,,,,, -10900,0.2109802,0.37564817,,,,,,,,,,,,,, -11000,0.1391919,0.30604312,,,,,,,,,,,,,, -11100,0.2534981,0.32696652,,,,,,,,,,,,,, -11116,,,0.7392473902021136,0.2708562101636614,0.717951219136888,0.2900183510766917,3554.0,0.735458736211952,0.2914646705812797,3581.0,2698.567401409149,3034.629213809967,2698.567401409149,334.8300590515137,0.8649129867553711,0.0 -11200,0.08294305,0.31233832,,,,,,,,,,,,,, -11300,0.050271243,0.27084023,,,,,,,,,,,,,, -11400,0.12691492,0.33141208,,,,,,,,,,,,,, -11464,,,0.7386783191135952,0.2724556241716657,0.7178643204619795,0.2911367335880346,3554.0,0.7351206481473401,0.2926540465302988,3581.0,2778.720671415329,3118.845671892166,2778.720671415329,338.85645866394043,0.8903038501739502,0.0 -11500,0.08878478,0.3096151,,,,,,,,,,,,,, -11600,0.25059012,0.22720282,,,,,,,,,,,,,, -11700,0.33693627,0.2581617,,,,,,,,,,,,,, -11800,0.14722486,0.1980044,,,,,,,,,,,,,, -11809,,,0.7395752498081752,0.2700128214699881,0.7184557123135903,0.2889353462867544,3554.0,0.7360054448608978,0.2903113259978881,3581.0,2858.910383224488,3203.0968782901764,2858.910383224488,342.8772482872009,0.9197978973388672,0.0 -11900,0.13529497,0.271892,,,,,,,,,,,,,, -12000,0.11521444,0.32360345,,,,,,,,,,,,,, -12100,0.07583731,0.26212293,,,,,,,,,,,,,, -12156,,,0.7396104676382882,0.2702807358333042,0.7184459576797271,0.2892972294641601,3554.0,0.7359343366029042,0.2907112162061924,3581.0,2939.020454645157,3287.2667756080627,2939.020454645157,346.899793624878,0.9456956386566162,0.0 -12200,0.118709266,0.28075927,,,,,,,,,,,,,, -12300,0.18205543,0.2852394,,,,,,,,,,,,,, -12400,0.1473377,0.34171838,,,,,,,,,,,,,, -12500,0.19862232,0.41089877,,,,,,,,,,,,,, -12501,,,0.7382897649492536,0.2721385274614606,0.7172479237742684,0.2910956885617438,3554.0,0.7345638493524853,0.2926618186697152,3581.0,3019.049448251724,3371.301284790039,3019.049448251724,350.86227202415466,0.9773619174957277,0.0 -12600,0.16489857,0.338531,,,,,,,,,,,,,, -12700,0.21349859,0.25434712,,,,,,,,,,,,,, -12800,0.094494596,0.2661044,,,,,,,,,,,,,, -12846,,,0.7408561025347028,0.2703834431512015,0.719871782894274,0.2893926462700478,3554.0,0.7372216483262357,0.2908907935327073,3581.0,3099.242963552475,3455.5569927692413,3099.242963552475,354.8867268562317,1.0039846897125244,0.0 -12900,0.17650332,0.23748617,,,,,,,,,,,,,, -13000,0.1965141,0.34195268,,,,,,,,,,,,,, -13100,0.119899884,0.27504075,,,,,,,,,,,,,, -13193,,,0.7397722516741071,0.2707656451633998,0.7187002651062183,0.2900265600819499,3554.0,0.7359990362547124,0.291466272732826,3581.0,3179.2792825698853,3539.655587911606,3179.2792825698853,358.91018557548523,1.0313971042633057,0.0 -13200,0.21341307,0.2965862,,,,,,,,,,,,,, -13300,0.13037635,0.3040249,,,,,,,,,,,,,, -13400,0.20401634,0.35466242,,,,,,,,,,,,,, -13500,0.15026434,0.27802023,,,,,,,,,,,,,, -13540,,,0.7408096449715751,0.2699011053357805,0.7198803010252532,0.2887142870489237,3554.0,0.7370741822073094,0.2903057014233105,3581.0,3259.4413664340973,3623.878072023392,3259.4413664340973,362.9328968524933,1.0576093196868896,0.0 -13600,0.22111844,0.20542541,,,,,,,,,,,,,, -13700,0.21618098,0.2909969,,,,,,,,,,,,,, -13800,0.0816786,0.242582,,,,,,,,,,,,,, -13890,,,0.7395924840654645,0.269716739654541,0.7187545338439083,0.2885267507781724,3554.0,0.7361964758665527,0.2898894488162349,3581.0,3339.4830079078674,3707.978204965592,3339.4830079078674,366.9540934562683,1.0833580493927002,0.0 -13900,0.19462277,0.26288313,,,,,,,,,,,,,, -14000,0.08202086,0.28277785,,,,,,,,,,,,,, -14100,0.3180646,0.28381538,,,,,,,,,,,,,, -14200,0.15163277,0.19840872,,,,,,,,,,,,,, -14234,,,0.7435992785862514,0.2695150886263166,0.722461569490363,0.2886484776176491,3554.0,0.7397115281825957,0.2901660415321314,3581.0,3419.608423233032,3792.164309263229,3419.608423233032,370.97393560409546,1.112830638885498,0.0 -14300,0.13378182,0.3225734,,,,,,,,,,,,,, -14400,0.052422855,0.28971088,,,,,,,,,,,,,, -14500,0.17382516,0.3729127,,,,,,,,,,,,,, -14581,,,0.741945607321603,0.269396322114127,0.7212877160857485,0.2883350928311409,3554.0,0.738634200576829,0.2897205070489214,3581.0,3499.815920114517,3876.430303573608,3499.815920114517,374.9954266548157,1.1386516094207764,0.0 -14600,0.18556099,0.23041809,,,,,,,,,,,,,, -14700,0.17080048,0.27423584,,,,,,,,,,,,,, -14800,0.09100302,0.3272407,,,,,,,,,,,,,, -14900,0.13280275,0.2300187,,,,,,,,,,,,,, -14929,,,0.7403436388288226,0.2694524015699114,0.7200152172288267,0.2881380423576428,3554.0,0.7374406317631248,0.2895457702653763,3581.0,3579.806512117386,3960.480762004852,3579.806512117386,379.01821637153625,1.1643576622009275,0.0 -15000,0.2206107,0.22874415,,,,,,,,,,,,,, -15100,0.08497248,0.26331297,,,,,,,,,,,,,, -15200,0.21039905,0.27113175,,,,,,,,,,,,,, -15273,,,0.7423791204180036,0.2695294788905552,0.7212960281329136,0.2886782567287739,3554.0,0.7385734551713907,0.290189426127042,3581.0,3659.9155492782593,4044.655151128769,3659.9155492782593,383.04474329948425,1.1917848587036133,0.0 -15300,0.08835022,0.28860062,,,,,,,,,,,,,, -15400,0.15241951,0.34656423,,,,,,,,,,,,,, -15500,0.16861589,0.28044677,,,,,,,,,,,,,, -15600,0.12083387,0.38843244,,,,,,,,,,,,,, -15618,,,0.7393215043204171,0.2698123455047607,0.7186708638154192,0.2887839090307576,3554.0,0.7362009755262148,0.2900657195746648,3581.0,3740.0139966011047,4128.817002296448,3740.0139966011047,387.0701777935028,1.2184903621673584,0.0 -15700,0.2928507,0.24429591,,,,,,,,,,,,,, -15800,0.11274435,0.30476883,,,,,,,,,,,,,, -15900,0.109470084,0.21162432,,,,,,,,,,,,,, -15963,,,0.7407644816807338,0.2705531631197248,0.7197459343785172,0.2897495147413126,3554.0,0.7371156336175301,0.2911154356325048,3581.0,3820.050070762634,4212.91955447197,3820.050070762634,391.0984275341034,1.2453598976135254,0.0 -16000,0.27598655,0.21340945,,,,,,,,,,,,,, -16100,0.12812962,0.356476,,,,,,,,,,,,,, -16200,0.44131002,0.23451447,,,,,,,,,,,,,, -16300,0.26910016,0.22634073,,,,,,,,,,,,,, -16306,,,0.742877687726702,0.2694735697337559,0.7217908353703574,0.2886311322299698,3554.0,0.7390276480906172,0.2900724349758273,3581.0,3900.068152904512,4296.999444723129,3900.068152904512,395.1222703456879,1.2719926834106443,0.0 -16400,0.079423174,0.29164854,,,,,,,,,,,,,, -16500,0.16566363,0.2373803,,,,,,,,,,,,,, -16600,0.11871176,0.27711642,,,,,,,,,,,,,, -16653,,,0.7418242863246373,0.269013694354466,0.7205337241048818,0.2883058632768535,3554.0,0.7379176638639695,0.2897662535888194,3581.0,3980.2133026123047,4381.207130908966,3980.2133026123047,399.14530396461487,1.300278902053833,0.0 -16700,0.13814586,0.3886914,,,,,,,,,,,,,, -16800,0.2665535,0.22516488,,,,,,,,,,,,,, -16900,0.15037069,0.390898,,,,,,,,,,,,,, -16997,,,0.7416016033717564,0.2694042580468314,0.7206827227024127,0.2887377119091165,3554.0,0.737869872024225,0.2902862028981081,3581.0,4060.215700864792,4465.268810033798,4060.215700864792,403.1664481163025,1.3272216320037842,0.0 -17000,0.13953607,0.2342499,,,,,,,,,,,,,, -17100,0.16351318,0.31549528,,,,,,,,,,,,,, -17200,0.14614855,0.30925643,,,,,,,,,,,,,, -17300,0.19594847,0.30444428,,,,,,,,,,,,,, -17341,,,0.7411505835396903,0.2689962557383946,0.7200577391891179,0.2881981501367297,3554.0,0.737314027702632,0.2898127159836637,3581.0,4140.239362716675,4549.352287769318,4140.239362716675,407.1881792545319,1.3539175987243652,0.0 -17400,0.1481361,0.25534528,,,,,,,,,,,,,, -17500,0.17245603,0.32806966,,,,,,,,,,,,,, -17600,0.21885078,0.24799469,,,,,,,,,,,,,, -17687,,,0.7420896802629743,0.2687692982809884,0.721324536393852,0.2878214117510639,3554.0,0.738631064450398,0.2892712228493263,3581.0,4220.379903554916,4633.554921627045,4220.379903554916,411.21010637283325,1.3826377391815186,0.0 -17700,0.12800635,0.2510383,,,,,,,,,,,,,, -17800,0.15638018,0.32395473,,,,,,,,,,,,,, -17900,0.12764733,0.26266772,,,,,,,,,,,,,, -18000,0.09272873,0.41921952,,,,,,,,,,,,,, -18034,,,0.7414772169930595,0.2690235035760062,0.7205499360315841,0.2880781406623698,3554.0,0.738011065890289,0.2894379488751396,3581.0,4300.405557632446,4717.645820856094,4300.405557632446,415.2354400157929,1.4109349250793457,0.0 -18100,0.18604067,0.2710367,,,,,,,,,,,,,, -18200,0.16881555,0.26821923,,,,,,,,,,,,,, -18300,0.3371175,0.26516217,,,,,,,,,,,,,, -18376,,,0.7417355946132115,0.2698216608592442,0.7205434100441404,0.2891783534507773,3554.0,0.7376423665046425,0.2907369188075956,3581.0,4380.436009883881,4801.739559173584,4380.436009883881,419.25944995880127,1.4389152526855469,0.0 -18400,0.079882935,0.25487012,,,,,,,,,,,,,, -18500,0.0968857,0.3540184,,,,,,,,,,,,,, -18600,0.2500149,0.333723,,,,,,,,,,,,,, -18700,0.13640295,0.25236213,,,,,,,,,,,,,, -18722,,,0.7425732612609863,0.2683524574552263,0.7210469414963773,0.2878672825733329,3554.0,0.7383829014023666,0.2893671133237922,3581.0,4460.510647535324,4885.879251480103,4460.510647535324,423.2813596725464,1.4707441329956057,0.0 -18800,0.20833164,0.28508028,,,,,,,,,,,,,, -18900,0.17927279,0.24836722,,,,,,,,,,,,,, -19000,0.20833334,0.31053624,,,,,,,,,,,,,, -19071,,,0.7413437707083566,0.269416298185076,0.7201945101470174,0.2886675060231429,3554.0,0.7375544867879084,0.2900805820868821,3581.0,4540.564539909363,4969.998610019684,4540.564539909363,427.30615234375,1.4997599124908447,0.0 -19100,0.22940952,0.3233062,,,,,,,,,,,,,, -19200,0.18496849,0.22217262,,,,,,,,,,,,,, -19300,0.12639852,0.243649,,,,,,,,,,,,,, -19400,0.178091,0.3115976,,,,,,,,,,,,,, -19414,,,0.7421463557652065,0.268864699772426,0.7212591391302055,0.2880018553038829,3554.0,0.7385687509817439,0.289402224304489,3581.0,4620.680071592331,5054.177350997925,4620.680071592331,431.3310775756836,1.5268681049346924,0.0 -19500,0.14423086,0.30965364,,,,,,,,,,,,,, -19600,0.22648062,0.3679847,,,,,,,,,,,,,, -19700,0.2071091,0.24056305,,,,,,,,,,,,,, -19761,,,0.7418694496154785,0.2684040069580078,0.7203742152328363,0.2876706099201603,3554.0,0.7378884160761658,0.2890437855064053,3581.0,4700.793580293655,5138.35369515419,4700.793580293655,435.3556377887726,1.553640365600586,0.0 -19800,0.25964788,0.22019272,,,,,,,,,,,,,, -19900,0.10258712,0.2668504,,,,,,,,,,,,,, -20000,0.21333958,0.3565205,,,,,,,,,,,,,, -20100,0.077249736,0.36217073,,,,,,,,,,,,,, -20106,,,0.7428394726344517,0.2686488287789481,0.7216941820615855,0.2879764383001547,3554.0,0.7390157171748464,0.2893898502404182,3581.0,4780.86026096344,5222.482827663422,4780.86026096344,439.3786163330078,1.5818867683410645,0.0 -20200,0.14843787,0.27091753,,,,,,,,,,,,,, -20300,0.121638976,0.2769406,,,,,,,,,,,,,, -20400,0.1732567,0.26218092,,,,,,,,,,,,,, -20450,,,0.7424825940813337,0.268344555582319,0.7212864795828644,0.2875891896399743,3554.0,0.7386649482511868,0.2890304910574036,3581.0,4860.848987817764,5306.533141613007,4860.848987817764,443.4014058113098,1.6094567775726318,0.0 -20500,0.2317639,0.25093967,,,,,,,,,,,,,, -20600,0.1885758,0.27318877,,,,,,,,,,,,,, -20700,0.14837745,0.28257376,,,,,,,,,,,,,, -20795,,,0.7422796658107212,0.268923418862479,0.7209059801675929,0.2882749163574493,3554.0,0.7381565548860305,0.289724018146991,3581.0,4940.809394598007,5390.558655261993,4940.809394598007,447.4226930141449,1.6419434547424316,0.0 -20800,0.18226236,0.2284774,,,,,,,,,,,,,, -20900,0.07879915,0.3265239,,,,,,,,,,,,,, -21000,0.1561565,0.36973757,,,,,,,,,,,,,, -21100,0.07723149,0.34102514,,,,,,,,,,,,,, -21143,,,0.7402078083583287,0.2688467161996024,0.7190345330525464,0.2878423636054885,3554.0,0.7366146033318207,0.2891526977232267,3581.0,5020.819479465485,5474.631682395935,5020.819479465485,451.4455780982971,1.6706173419952393,0.0 -21200,0.14562725,0.26734486,,,,,,,,,,,,,, -21300,0.18653752,0.3478473,,,,,,,,,,,,,, -21400,0.17461611,0.25075755,,,,,,,,,,,,,, -21492,,,0.7418524878365653,0.2684067487716675,0.7208238214204066,0.2874485717842132,3554.0,0.7382451163693801,0.2888301198491169,3581.0,5100.909521818161,5558.785286426544,5100.909521818161,455.46900844573975,1.699165105819702,0.0 -21500,0.13443425,0.22556287,,,,,,,,,,,,,, -21600,0.16493487,0.3666648,,,,,,,,,,,,,, -21700,0.090249024,0.27065387,,,,,,,,,,,,,, -21800,0.19275588,0.20022056,,,,,,,,,,,,,, -21835,,,0.7428131103515625,0.268767237663269,0.7215138587243247,0.2881664475766741,3554.0,0.7389221106185423,0.2895198631339884,3581.0,5180.924647092819,5642.864949226379,5180.924647092819,459.4944705963135,1.727196455001831,0.0 -21900,0.12567613,0.34225947,,,,,,,,,,,,,, -22000,0.14932069,0.30938035,,,,,,,,,,,,,, -22100,0.10675633,0.34285858,,,,,,,,,,,,,, -22183,,,0.7425223759242466,0.2683233533586774,0.7213983143992684,0.287667398447392,3554.0,0.7386228150743508,0.2891085874232058,3581.0,5261.056898117065,5727.059560060501,5261.056898117065,463.5176067352295,1.7548956871032717,0.0 -22200,0.11451594,0.3013047,,,,,,,,,,,,,, -22300,0.13045,0.2583173,,,,,,,,,,,,,, -22400,0.0938535,0.30854163,,,,,,,,,,,,,, -22500,0.16748287,0.23890932,,,,,,,,,,,,,, -22531,,,0.7441225733075824,0.2679621321814401,0.7233546680456879,0.2872024218420266,3554.0,0.7405354431373918,0.2886374866919157,3581.0,5341.214492321014,5811.279220342636,5341.214492321014,467.53966760635376,1.7834970951080322,0.0 -22600,0.3216787,0.21631248,,,,,,,,,,,,,, -22700,0.14996247,0.24601418,,,,,,,,,,,,,, -22800,0.111857355,0.22233495,,,,,,,,,,,,,, -22876,,,0.7432924679347447,0.2679972989218576,0.7215693639648987,0.2877167726944991,3554.0,0.7389800607808573,0.2891892745021467,3581.0,5421.3310425281525,5895.463119029999,5421.3310425281525,471.5667657852173,1.812317132949829,0.0 -22900,0.18233022,0.19371803,,,,,,,,,,,,,, -23000,0.32794294,0.4189075,,,,,,,,,,,,,, -23100,0.23776324,0.30611634,,,,,,,,,,,,,, -23200,0.12220548,0.33897018,,,,,,,,,,,,,, -23222,,,0.7425169263567243,0.267835259437561,0.7215940940225802,0.2870750276976646,3554.0,0.7390121038117844,0.2884190486683189,3581.0,5501.379611253738,5979.571918964386,5501.379611253738,475.58651852607727,1.84131932258606,0.0 -23300,0.09284054,0.22518323,,,,,,,,,,,,,, -23400,0.12126483,0.35455918,,,,,,,,,,,,,, -23500,0.19685179,0.23818202,,,,,,,,,,,,,, -23567,,,0.7436034338814872,0.2676151139395578,0.7222664081184933,0.2870295518693725,3554.0,0.7396102858401983,0.288459818311924,3581.0,5581.438790082932,6063.698635339737,5581.438790082932,479.6103277206421,1.8737876415252688,0.0 -23600,0.16346294,0.2629887,,,,,,,,,,,,,, -23700,0.13129163,0.31892747,,,,,,,,,,,,,, -23800,0.22334845,0.21513513,,,,,,,,,,,,,, -23900,0.17152353,0.2677871,,,,,,,,,,,,,, -23910,,,0.7438061577933175,0.2682526792798723,0.7224687137292487,0.2877475993983451,3554.0,0.7397383897872452,0.2890792373704098,3581.0,5661.581089258194,6147.907499790192,5661.581089258194,483.6366469860077,1.902566194534301,0.0 -24000,0.07815804,0.22444965,,,,,,,,,,,,,, -24100,0.14065552,0.277445,,,,,,,,,,,,,, -24200,0.14024204,0.24236745,,,,,,,,,,,,,, -24254,,,0.742143086024693,0.267869165965489,0.7209467160681626,0.2871526354272914,3554.0,0.7383813333391511,0.2885350171696104,3581.0,5741.601750612259,6231.99070930481,5741.601750612259,487.6585056781769,1.932025671005249,0.0 -24300,0.14498931,0.320309,,,,,,,,,,,,,, -24400,0.24181964,0.22425458,,,,,,,,,,,,,, -24500,0.16470343,0.33399206,,,,,,,,,,,,,, -24600,,,0.7429958071027484,0.2684909956795828,0.7214709245964055,0.2879671645285242,3554.0,0.7388236635192683,0.2893403880724658,3581.0,5821.691652059555,6316.142263650894,5821.691652059555,491.6800878047943,1.9605860710144043,0.0 -24600,0.116361596,0.24352539,,,,,,,,,,,,,, -24700,0.10673512,0.29744977,,,,,,,,,,,,,, -24800,0.080180764,0.3316514,,,,,,,,,,,,,, -24900,0.22697961,0.33479375,,,,,,,,,,,,,, -24942,,,0.7429577963692802,0.2679623876299177,0.7216223275050999,0.2873232556516425,3554.0,0.7389417454970678,0.2886966640341385,3581.0,5901.788957595825,6400.3060693740845,5901.788957595825,495.7059428691864,1.9899907112121584,0.0 -25000,0.21879636,0.27466834,,,,,,,,,,,,,, -25100,0.15927882,0.29735464,,,,,,,,,,,,,, -25200,0.12966484,0.28921503,,,,,,,,,,,,,, -25286,,,0.7442034993852887,0.2674739871706281,0.7227554450091446,0.2869437351344875,3554.0,0.7401606760288676,0.2882990918323269,3581.0,5981.759289741516,6484.3422310352325,5981.759289741516,499.72960019111633,2.020745277404785,0.0 -25300,0.13721597,0.22409856,,,,,,,,,,,,,, -25400,0.1405699,0.25523245,,,,,,,,,,,,,, -25500,0.106038384,0.33406937,,,,,,,,,,,,,, -25600,0.053671815,0.36659864,,,,,,,,,,,,,, -25632,,,0.7437772750854492,0.2674898760659354,0.722484719572137,0.2869256169325056,3554.0,0.7397795003141581,0.2882901947779949,3581.0,6061.882197141647,6568.528303146362,6061.882197141647,503.7518606185913,2.050223588943481,0.0 -25700,0.15386026,0.28283662,,,,,,,,,,,,,, -25800,0.12390102,0.27321595,,,,,,,,,,,,,, -25900,0.11564208,0.26610628,,,,,,,,,,,,,, -25975,,,0.7442035675048828,0.2677900450570242,0.7226171627699424,0.2873275490644344,3554.0,0.7398986731185423,0.2887357292612049,3581.0,6141.866056203842,6652.575384140015,6141.866056203842,507.7743980884552,2.079303741455078,0.0 -26000,0.16900428,0.29583472,,,,,,,,,,,,,, -26100,0.074763395,0.27838483,,,,,,,,,,,,,, -26200,0.20160536,0.38238385,,,,,,,,,,,,,, -26300,0.090991855,0.30336642,,,,,,,,,,,,,, -26317,,,0.7420498303004673,0.2675496169498988,0.7204110355409398,0.2871192326757703,3554.0,0.7379803182159314,0.2884458420963069,3581.0,6222.008386850357,6736.779012918472,6222.008386850357,511.7950129508972,2.1087591648101807,0.0 -26400,0.20219098,0.288974,,,,,,,,,,,,,, -26500,0.10709347,0.27056557,,,,,,,,,,,,,, -26600,0.09476083,0.29943267,,,,,,,,,,,,,, -26663,,,0.7431620870317731,0.2673583030700683,0.7216088633625845,0.2868381687007597,3554.0,0.7390444877260193,0.2882063374842921,3581.0,6302.204536914825,6821.033949136734,6302.204536914825,515.8137700557709,2.137401580810547,0.0 -26700,0.13407883,0.29384238,,,,,,,,,,,,,, -26800,0.22589236,0.20937598,,,,,,,,,,,,,, -26900,0.1311447,0.26475435,,,,,,,,,,,,,, -27000,0.065362595,0.22730796,,,,,,,,,,,,,, -27009,,,0.7441284315926688,0.2674673965999058,0.7224542191676632,0.2871214309031197,3554.0,0.7397913630532672,0.2884547391506388,3581.0,6382.178173780441,6905.071727991104,6382.178173780441,519.8362936973572,2.1678225994110107,0.0 -27100,0.096852005,0.41389874,,,,,,,,,,,,,, -27200,0.08735115,0.29748344,,,,,,,,,,,,,, -27300,0.121531725,0.26327807,,,,,,,,,,,,,, -27354,,,0.7447470937456403,0.2674651827130999,0.7230006160532146,0.2872797032722812,3554.0,0.7403103238009634,0.2886721886126431,3581.0,6462.302304506302,6989.259694099426,6462.302304506302,523.8579688072205,2.198601007461548,0.0 -27400,0.15517871,0.30645707,,,,,,,,,,,,,, -27500,0.14403902,0.34435183,,,,,,,,,,,,,, -27600,0.11952214,0.29184395,,,,,,,,,,,,,, -27700,,,0.7439596993582589,0.2670155423028128,0.7222767810037986,0.2866732844758986,3554.0,0.7397503207030159,0.288021510554838,3581.0,6542.314068317413,7073.3403215408325,6542.314068317413,527.8807153701782,2.23337459564209,0.0 -27700,0.12413268,0.3104113,,,,,,,,,,,,,, -27800,0.1452175,0.27381253,,,,,,,,,,,,,, -27900,0.17061211,0.26885042,,,,,,,,,,,,,, -28000,0.09685055,0.24809606,,,,,,,,,,,,,, -28045,,,0.7440105165754046,0.267963205065046,0.7223733656179657,0.2876472022336188,3554.0,0.7397511388229545,0.2889350778195685,3581.0,6622.435270547867,7157.523993968964,6622.435270547867,531.9025478363037,2.262672185897827,0.0 -28100,0.14277826,0.17731693,,,,,,,,,,,,,, -28200,0.1797633,0.33885357,,,,,,,,,,,,,, -28300,0.11987242,0.22273883,,,,,,,,,,,,,, -28391,,,0.7441020011901855,0.268479687826974,0.7219329645074212,0.2886336052357379,3554.0,0.7392949006038816,0.2899664543554524,3581.0,6702.4539477825165,7241.608174085617,6702.4539477825165,535.9265110492706,2.2928481101989746,0.0 -28400,0.13124885,0.23532236,,,,,,,,,,,,,, -28500,0.107759,0.226653,,,,,,,,,,,,,, -28600,0.108998865,0.34364918,,,,,,,,,,,,,, -28700,0.07914831,0.3436704,,,,,,,,,,,,,, -28738,,,0.7438011169433594,0.2669568913323538,0.7221119139525887,0.2865940452494109,3554.0,0.7395879920718724,0.2879308696833112,3581.0,6782.633191108704,7325.848667621613,6782.633191108704,539.9468801021576,2.322210788726806,0.0 -28800,0.18806243,0.2528129,,,,,,,,,,,,,, -28900,0.13752073,0.29754227,,,,,,,,,,,,,, -29000,0.19324996,0.33811823,,,,,,,,,,,,,, -29081,,,0.7445215497698102,0.267032333782741,0.7230511752822524,0.2865559025701674,3554.0,0.7403883179017733,0.287915700376117,3581.0,6862.738671541214,7410.014344930649,6862.738671541214,543.9663681983948,2.3514530658721924,0.0 -29100,0.10909572,0.240905,,,,,,,,,,,,,, -29200,0.09276578,0.24607162,,,,,,,,,,,,,, -29300,0.18192264,0.28615034,,,,,,,,,,,,,, -29400,0.18082616,0.24728061,,,,,,,,,,,,,, -29421,,,0.7445074490138462,0.2672220127923148,0.7226337181696679,0.2870260999654878,3554.0,0.7399836894154915,0.2883681888788048,3581.0,6942.728103876114,7494.060920238495,6942.728103876114,547.9826490879059,2.3811745643615723,0.0 -29500,0.247162,0.30357873,,,,,,,,,,,,,, -29600,0.11120415,0.3188827,,,,,,,,,,,,,, -29700,0.22506793,0.33779934,,,,,,,,,,,,,, -29768,,,0.7450000217982701,0.2670083386557443,0.723371773002251,0.2868083895896349,3554.0,0.7407431092484641,0.2881304227716594,3581.0,7022.759325265884,7578.158590555191,7022.759325265884,552.0078539848328,2.4109811782836914,0.0 -29800,0.077187195,0.2907893,,,,,,,,,,,,,, -29900,0.15376155,0.25364453,,,,,,,,,,,,,, -30000,0.10834406,0.24841303,,,,,,,,,,,,,, -30100,0.13799955,0.21941662,,,,,,,,,,,,,, -30115,,,0.745185102735247,0.2668940850666591,0.7236418114932118,0.2865342294223937,3554.0,0.7410003397924811,0.2878097538440728,3581.0,7102.75586271286,7662.220293045044,7102.75586271286,556.0319905281067,2.440687417984009,0.0 -30200,0.10560653,0.2170438,,,,,,,,,,,,,, -30300,0.19362636,0.2216832,,,,,,,,,,,,,, -30400,0.22186567,0.22607619,,,,,,,,,,,,,, -30457,,,0.7440342221941266,0.2668058701923915,0.7220873899787211,0.2865863686273389,3554.0,0.7395849241221027,0.2879080986783545,3581.0,7182.74059510231,7746.269132852554,7182.74059510231,560.0553650856018,2.4702839851379395,0.0 -30500,0.32861656,0.30424425,,,,,,,,,,,,,, -30600,0.16237372,0.30872843,,,,,,,,,,,,,, -30700,0.096711084,0.36784738,,,,,,,,,,,,,, -30800,0.13140272,0.25476995,,,,,,,,,,,,,, -30802,,,0.7446126937866211,0.2671293360846383,0.7229679174213914,0.2867912846330719,3554.0,0.740367728549986,0.2881001182456018,3581.0,7262.764033317566,7830.354666471481,7262.764033317566,564.0749621391296,2.5015549659729004,0.0 -30900,0.10527988,0.26957473,,,,,,,,,,,,,, -31000,0.12383668,0.3035981,,,,,,,,,,,,,, -31100,0.12717453,0.3293863,,,,,,,,,,,,,, -31147,,,0.7450896671840123,0.2669330665043422,0.7233025288407429,0.2867378574042892,3554.0,0.7406983853584892,0.2880672230064053,3581.0,7342.755488395691,7914.412054538727,7342.755488395691,568.098906993866,2.532049655914306,0.0 -31200,0.15564811,0.23887715,,,,,,,,,,,,,, -31300,0.28185317,0.2655444,,,,,,,,,,,,,, -31400,0.17719316,0.19776307,,,,,,,,,,,,,, -31489,,,0.7447887829371861,0.2668275151933942,0.7230282999788971,0.2865797052506858,3554.0,0.74044149569778,0.2878837596101822,3581.0,7422.793350458145,7998.517551660538,7422.793350458145,572.1239104270935,2.563407182693481,0.0 -31500,0.18116084,0.2795556,,,,,,,,,,,,,, -31600,0.18998718,0.25632745,,,,,,,,,,,,,, -31700,0.12698217,0.26556256,,,,,,,,,,,,,, -31800,0.13163576,0.2580038,,,,,,,,,,,,,, -31837,,,0.7444804736546108,0.2666935239519392,0.7226634629334905,0.2864793424332618,3554.0,0.740141654740296,0.2877985728715792,3581.0,7502.960831642151,8082.751039028168,7502.960831642151,576.1479756832123,2.593926191329956,0.0 -31900,0.21200821,0.280298,,,,,,,,,,,,,, -32000,0.08379043,0.37313622,,,,,,,,,,,,,, -32100,0.09420667,0.3403038,,,,,,,,,,,,,, -32183,,,0.7452107838221959,0.2666273287364414,0.7234276217158483,0.286437782197436,3554.0,0.7408316707318138,0.2877483607603497,3581.0,7583.132245779037,8166.990717411041,7583.132245779037,580.1747455596924,2.6239583492279053,0.0 -32200,0.08459192,0.26754558,,,,,,,,,,,,,, -32300,0.18477747,0.26377806,,,,,,,,,,,,,, -32400,0.11211273,0.29017225,,,,,,,,,,,,,, -32500,0.1418771,0.2735875,,,,,,,,,,,,,, -32527,,,0.7456954547337123,0.2667953797749111,0.7239288862461312,0.286699027778999,3554.0,0.7412730464386693,0.2880450655914025,3581.0,7663.143029689789,8251.064739465714,7663.143029689789,584.1944150924683,2.6562201976776123,0.0 -32527,,,,,,,,,,,7663.143029689789,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index b7a23bf8a..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -37.67181849479675,0.0,69.66642546653748,1,0,69.66642546653748,0.0006000000284984,6.910250186920166,10000,107.33833169937134,0.0007573341717943,6.910542011260986,0.0009599999757483,6.910243988037109,50000 -55.593273878097534,0.0293140411376953,579.8638372421265,1514,0,579.8638372421265,0.049600001424551,5.629796504974365,10000,635.5374145507812,0.0735411345958709,5.336194038391113,0.0678599998354911,5.403800964355469,50000 -73.83026790618896,0.0556843280792236,1089.8733115196228,3025,0,1089.8733115196228,0.1212000027298927,4.7993268966674805,10000,1163.8596937656405,0.1801857352256775,4.261962890625,0.1609999984502792,4.37910270690918,50000 -92.06435537338255,0.0844905376434326,1600.124408245087,4537,0,1600.124408245087,0.1918000131845474,4.191176891326904,10000,1692.4212460517883,0.2841398119926452,3.504873514175415,0.2563000023365021,3.6541614532470703,50000 -110.3130383491516,0.1100184917449951,2110.33571267128,6049,0,2110.33571267128,0.2526000142097473,3.744076728820801,10000,2220.955129146576,0.3693797886371612,2.9464893341064453,0.3428999781608581,3.0969789028167725,50000 -128.61637711524963,0.1401574611663818,2620.27745461464,7561,0,2620.27745461464,0.3183000087738037,3.374671220779419,10000,2749.279436826706,0.443459004163742,2.543766975402832,0.4113999903202057,2.717856168746948,50000 -147.15200424194336,0.1761560440063476,3130.3148624897003,9073,0,3130.3148624897003,0.3585000038146972,3.096574783325196,10000,3277.937031984329,0.5113998651504517,2.2248189449310303,0.4605399966239929,2.471454620361328,50000 -165.99961352348328,0.2035794258117675,3640.3234446048737,10586,0,3640.3234446048737,0.3758000135421753,3.0121326446533203,10000,3806.8680925369263,0.5440847873687744,2.0609612464904785,0.492819994688034,2.328207492828369,50000 -186.73877835273743,0.2456672191619873,4150.755066394806,12101,0,4150.755066394806,0.410500019788742,2.809180498123169,10000,4338.130095720291,0.5639150142669678,1.929271221160889,0.521619975566864,2.1376399993896484,50000 -206.2655758857727,0.2724974155426025,4660.733189105988,13616,0,4660.733189105988,0.4193000197410583,2.7065205574035645,10000,4867.710218667984,0.5889668464660645,1.818697810173035,0.5448799729347229,2.03045392036438,50000 -228.1629378795624,0.3139839172363281,5170.75431728363,15132,0,5170.75431728363,0.4246000349521637,2.713749647140503,10000,5399.7187423706055,0.5876913070678711,1.8149843215942385,0.5457000136375427,2.0031583309173584,50000 -249.90737676620483,0.3474881649017334,5680.732532739639,16648,0,5680.732532739639,0.4403000175952911,2.5987136363983154,10000,5931.523034095764,0.5989118218421936,1.7425923347473145,0.5619199872016907,1.9272526502609253,50000 -274.1750280857086,0.3943376541137695,6190.90416431427,18165,0,6190.90416431427,0.4456000328063965,2.6102218627929688,10000,6466.05579328537,0.6433154940605164,1.562894344329834,0.5627399682998657,1.9308120012283323,50000 -299.4436390399933,0.4216992855072021,6700.854817867279,19681,0,6700.854817867279,0.4467000067234039,2.569241523742676,10000,7001.347071886063,0.6237842440605164,1.6315232515335083,0.5619800090789795,1.91173791885376,50000 -325.11325001716614,0.4727559089660644,7210.977495670319,21198,0,7210.977495670319,0.4509000182151794,2.527494430541992,10000,7537.237438201904,0.6259167790412903,1.6084551811218262,0.5781800150871277,1.842873334884644,50000 -348.44293189048767,0.514392614364624,7720.921734571457,22715,0,7720.921734571457,0.4554000198841095,2.526982307434082,10000,8070.599623918533,0.6239436864852905,1.6318671703338623,0.5826199650764465,1.8319292068481443,50000 -372.9615008831024,0.5425989627838135,8231.1480448246,24232,0,8231.1480448246,0.4675000309944153,2.4928462505340576,10000,8605.418276548386,0.6275310516357422,1.646264910697937,0.5839399695396423,1.8555593490600584,50000 -395.8874454498291,0.5697648525238037,8741.327996253967,25749,0,8741.327996253967,0.4647000133991241,2.5192012786865234,10000,9138.597157239914,0.6228276491165161,1.6504992246627808,0.5827800035476685,1.833041071891785,50000 -418.5490050315857,0.6001632213592529,9251.424172401428,27266,0,9251.424172401428,0.4601000249385834,2.5271754264831543,10000,9671.430490970612,0.6548548936843872,1.5148216485977173,0.581820011138916,1.857088208198548,50000 -440.63156938552856,0.6332635879516602,9761.575938463213,28784,0,9761.575938463213,0.4591000080108642,2.5145013332366943,10000,10203.743408441544,0.6424983739852905,1.578094244003296,0.5882999897003174,1.832597017288208,50000 -464.7254109382629,0.6648633480072021,10271.727610588074,30303,0,10271.727610588074,0.4589000344276428,2.4995713233947754,10000,10738.066817998886,0.6371771097183228,1.577159404754639,0.5897799730300903,1.7975454330444336,50000 -488.7787292003632,0.6953849792480469,10781.824187994003,31821,0,10781.824187994003,0.4662000238895416,2.477463960647583,10000,11272.292582035065,0.6344068646430969,1.5880506038665771,0.5906800031661987,1.8002487421035769,50000 -512.6822199821472,0.7257883548736572,11291.99487566948,33339,0,11291.99487566948,0.4714000225067138,2.437567949295044,10000,11806.442516088486,0.6326330900192261,1.577653169631958,0.5985000133514404,1.7677619457244873,50000 -536.9916772842407,0.756098747253418,11802.263848781586,34858,0,11802.263848781586,0.4784000217914581,2.4060091972351074,10000,12341.097544431686,0.6404655575752258,1.569533348083496,0.5995999574661255,1.7648115158081057,50000 -561.7181787490845,0.7879447937011719,12312.29682803154,36376,0,12312.29682803154,0.4791000187397003,2.39719295501709,10000,12875.93498826027,0.670340359210968,1.3972032070159912,0.6022799611091614,1.7204241752624512,50000 -584.714658498764,0.8208250999450684,12822.334090471268,37894,0,12822.334090471268,0.4779000282287597,2.4191813468933105,10000,13409.04744195938,0.6550741195678711,1.4708644151687622,0.6022599935531616,1.7340105772018433,50000 -607.5823495388031,0.8514549732208252,13332.26151394844,39412,0,13332.26151394844,0.4886000156402588,2.380843162536621,10000,13941.919185638428,0.6563097834587097,1.4915953874588013,0.6089400053024292,1.7161346673965454,50000 -629.6153049468994,0.883293628692627,13842.18695950508,40930,0,13842.18695950508,0.4817000329494476,2.3924968242645264,10000,14473.955300807953,0.6541174650192261,1.521090388298035,0.6050199866294861,1.733443260192871,50000 -651.8660788536072,0.916776180267334,14352.171369075775,42448,0,14352.171369075775,0.4720000326633453,2.444088220596313,10000,15006.26985692978,0.6384526491165161,1.568603515625,0.5953999757766724,1.7692608833312988,50000 -674.709219455719,0.9466893672943116,14862.203121185305,43967,0,14862.203121185305,0.4828000366687774,2.402573585510254,10000,15539.22065258026,0.6627271771430969,1.4629147052764893,0.6109600067138672,1.699565887451172,50000 -696.4635140895844,0.9763743877410888,15372.199818134308,45486,0,15372.199818134308,0.4945000112056732,2.347254753112793,10000,16071.047409534454,0.674246609210968,1.4069665670394895,0.6114199757575989,1.696818232536316,50000 -717.4448945522308,1.00885272026062,15882.18303823471,47004,0,15882.18303823471,0.4917000234127044,2.3279991149902344,10000,16602.090711593628,0.6636040806770325,1.4506380558013916,0.6128799915313721,1.6783645153045654,50000 -738.927268743515,1.0390520095825195,16392.207607269287,48523,0,16392.207607269287,0.4991000294685364,2.3448281288146973,10000,17133.673773765564,0.6721540093421936,1.4366706609725952,0.6180399656295776,1.684572458267212,50000 -759.0310838222504,1.070474624633789,16902.16853904724,50042,0,16902.16853904724,0.491100013256073,2.3817460536956787,10000,17663.81556916237,0.6578244566917419,1.4930152893066406,0.6173999905586243,1.6958266496658323,50000 -778.0412209033966,1.103534698486328,17412.375111818314,51561,0,17412.375111818314,0.490200012922287,2.355665445327759,10000,18193.11154937744,0.6567083597183228,1.477152943611145,0.613099992275238,1.6816418170928955,50000 -795.7929601669312,1.141834735870361,17922.586621522903,53081,0,17922.586621522903,0.496800035238266,2.343891859054565,10000,18721.159491062164,0.7010921239852905,1.2759617567062378,0.6177399754524231,1.6607871055603027,50000 -813.3257002830505,1.177684307098389,18432.66418099404,54600,0,18432.66418099404,0.4958000183105469,2.347243547439575,10000,19248.851968050003,0.6756815910339355,1.4029406309127808,0.6170799732208252,1.6762901544570925,50000 -832.5534996986389,1.2281639575958252,18942.804827928543,56119,0,18942.804827928543,0.5015000104904175,2.316536903381348,10000,19778.317579746246,0.6758211255073547,1.4040615558624268,0.6243000030517578,1.6517601013183594,50000 -850.4759373664856,1.2610328197479248,19453.03369998932,57639,0,19453.03369998932,0.5049999952316284,2.2848803997039795,10000,20306.5486035347,0.6713966727256775,1.409226417541504,0.6262199878692627,1.636522889137268,50000 -867.8226597309113,1.2980821132659912,19962.99881720543,59157,0,19962.99881720543,0.4987000226974487,2.3283464908599854,10000,20833.94366264344,0.6717753410339355,1.4401389360427856,0.6253199577331543,1.6600350141525269,50000 -885.1255774497986,1.3364100456237793,20473.223114967343,60677,0,20473.223114967343,0.495600014925003,2.357531309127808,10000,21361.55599737168,0.6564891338348389,1.4987339973449707,0.614579975605011,1.6961729526519775,50000 -902.6221942901612,1.3790311813354492,20983.24218249321,62196,0,20983.24218249321,0.5074000358581543,2.304157495498657,10000,21889.161187648773,0.7053770422935486,1.2767313718795776,0.6273599863052368,1.6219308376312256,50000 -920.0448322296144,1.4266955852508545,21493.27866005897,63715,0,21493.27866005897,0.5064000487327576,2.267906904220581,10000,22416.715700864792,0.686922013759613,1.3271760940551758,0.626800000667572,1.617149829864502,50000 -937.196320772171,1.4666364192962646,22003.35453414917,65234,0,22003.35453414917,0.4989000260829925,2.329805850982666,10000,22944.03020715713,0.6711774468421936,1.4122507572174072,0.6250199675559998,1.6446616649627686,50000 -954.4305288791656,1.5082590579986572,22513.46359586716,66753,0,22513.46359586716,0.5105000138282776,2.2873687744140625,10000,23471.462081432343,0.6834542155265808,1.3994112014770508,0.635159969329834,1.6157879829406738,50000 -971.5997366905212,1.5481698513031006,23023.515419483185,68272,0,23023.515419483185,0.5113000273704529,2.254912614822388,10000,23998.77006816864,0.6808832883834839,1.3870117664337158,0.6350199580192566,1.601508378982544,50000 -988.909274339676,1.5890882015228271,23533.512261867523,69791,0,23533.512261867523,0.5146000385284424,2.2334415912628174,10000,24526.164582252502,0.6853276491165161,1.3273032903671265,0.6344199776649475,1.5748940706253052,50000 -1006.218918800354,1.6315371990203855,24043.516547441483,71310,0,24043.516547441483,0.5209000110626221,2.210066556930542,10000,25053.56804537773,0.7115951776504517,1.2286536693572998,0.6373599767684937,1.564759612083435,50000 -1023.3500711917876,1.678883075714111,24553.690885543823,72830,0,24553.690885543823,0.509600043296814,2.260319232940674,10000,25580.96841239929,0.6828762888908386,1.3698737621307373,0.6262800097465515,1.6336623430252075,50000 -1040.6843490600586,1.718044996261597,25063.78478884697,74349,0,25063.78478884697,0.5154000520706177,2.2236592769622803,10000,26108.48323750496,0.6959502696990967,1.3203473091125488,0.6404399871826172,1.5740586519241333,50000 -1058.6442294120789,1.7606401443481443,25573.829756498337,75868,0,25573.829756498337,0.5042999982833862,2.291550636291504,10000,26636.57822918892,0.6812818646430969,1.3611277341842651,0.6279599666595459,1.6080148220062256,50000 -1075.7697627544403,1.8028457164764404,26084.01960873604,77387,0,26084.01960873604,0.5112000107765198,2.2878735065460205,10000,27163.98302912712,0.6804049611091614,1.39091157913208,0.6298800110816956,1.6200053691864014,50000 -1093.0664336681366,2.7603235244750977,26593.105378627777,78903,0,26593.105378627777,0.5195000171661377,2.204761743545532,10000,27691.36990666389,0.7225167155265808,1.2051185369491575,0.6467399597167969,1.5430620908737185,50000 -1110.3081967830658,2.8014683723449707,27103.18165063858,80423,0,27103.18165063858,0.5088000297546387,2.290923833847046,10000,28218.77651834488,0.6974050998687744,1.309516429901123,0.6312400102615356,1.6141302585601809,50000 -1127.628051996231,2.843165397644043,27613.21877503395,81943,0,27613.21877503395,0.5279000401496887,2.181778907775879,10000,28746.222467899323,0.7071707248687744,1.2736690044403076,0.6467799544334412,1.539712905883789,50000 -1144.8795185089111,2.9040091037750244,28123.213754415512,83462,0,28123.213754415512,0.5254000425338745,2.181201457977295,10000,29273.577413082123,0.7016302347183228,1.2925941944122314,0.6476199626922607,1.5447707176208496,50000 -1162.0792186260223,2.945914030075073,28633.28370141983,84982,0,28633.28370141983,0.5200999975204468,2.230949878692627,10000,29800.936242103577,0.6917649507522583,1.3234705924987793,0.6398599743843079,1.566019058227539,50000 -1179.1734466552734,2.9869160652160645,29143.66135954857,86501,0,29143.66135954857,0.5342000126838684,2.135103940963745,10000,30328.49675798416,0.7071906924247742,1.2525839805603027,0.6568399667739868,1.488780856132507,50000 -1196.4498274326324,3.030723810195923,29653.78978037834,88020,0,29653.78978037834,0.5270000100135803,2.18676495552063,10000,30855.992782831192,0.7352319955825806,1.1445754766464231,0.649459958076477,1.5333545207977295,50000 -1213.8182995319366,3.075812101364136,30164.02828192711,89541,0,30164.02828192711,0.523900032043457,2.197510004043579,10000,31383.692828655243,0.7146444320678711,1.2600167989730835,0.6505399942398071,1.539999008178711,50000 -1231.279939413071,3.1218960285186768,30674.164899349213,91061,0,30674.164899349213,0.5308000445365906,2.186326503753662,10000,31911.384786367416,0.7093231678009033,1.2750592231750488,0.6477000117301941,1.54591703414917,50000 -1248.3998510837555,3.161928653717041,31184.35217189789,92581,0,31184.35217189789,0.5351000428199768,2.1427531242370605,10000,32438.77952504158,0.720723032951355,1.2018429040908811,0.6581999659538269,1.4694814682006836,50000 -1265.8118696212769,3.207626342773437,31694.415167331696,94100,0,31694.415167331696,0.5314000248908997,2.120694875717163,10000,32966.347737550735,0.711355984210968,1.227055311203003,0.6619600057601929,1.455871343612671,50000 -1283.4004225730896,3.2659912109375,32204.58331465721,95620,0,32204.58331465721,0.5458000302314758,2.099045753479004,10000,33494.21009898186,0.7209023833274841,1.197264552116394,0.6657999753952026,1.4429244995117188,50000 -1300.787608385086,3.3119215965271,32714.638087511063,97141,0,32714.638087511063,0.5304000377655029,2.161627769470215,10000,34021.7456908226,0.7374441623687744,1.1343486309051514,0.6576399803161621,1.4939663410186768,50000 -1317.972489118576,3.3565704822540283,33224.680389881134,98661,0,33224.680389881134,0.5508000254631042,2.060675621032715,10000,34549.06488656998,0.7361487150192261,1.1169815063476562,0.6668399572372437,1.4288816452026367,50000 -1335.168357372284,3.4204201698303223,33734.73139023781,100181,0,33734.73139023781,0.5432000160217285,2.1067442893981934,10000,35076.42337989807,0.7201650142669678,1.1991384029388428,0.6631199717521667,1.4586812257766724,50000 -1352.396959066391,3.4652857780456543,34244.87643456459,101700,0,34244.87643456459,0.5473999977111816,2.062584638595581,10000,35603.890127658844,0.7339365482330322,1.128197193145752,0.6735199689865112,1.408560276031494,50000 -1369.6633830070496,3.5199079513549805,34754.832845926285,103220,0,34754.832845926285,0.539900004863739,2.114496946334839,10000,36131.21515202522,0.7251275181770325,1.1850506067276,0.668179988861084,1.444494366645813,50000 -1387.0266127586365,3.5705230236053467,35265.055674791336,104740,0,35265.055674791336,0.5522000193595886,2.052886724472046,10000,36658.89939570427,0.7435427308082581,1.0935951471328735,0.6726999878883362,1.40591561794281,50000 -1404.2709031105042,3.6195414066314697,35775.128962278366,106260,0,35775.128962278366,0.5509999990463257,2.0571823120117188,10000,37186.31377243996,0.7564970850944519,1.039839744567871,0.6774199604988098,1.3940984010696411,50000 -1421.402559518814,3.66951847076416,36285.17335796356,107781,0,36285.17335796356,0.5616000294685364,2.0226991176605225,10000,37713.587841272354,0.7531289458274841,1.063181757926941,0.6829400062561035,1.3601897954940796,50000 -1438.9501745700836,3.7160933017730713,36795.37970161438,109301,0,36795.37970161438,0.5556000471115112,2.055934190750122,10000,38241.43623971939,0.7407127022743225,1.1150754690170288,0.6787599921226501,1.402061939239502,50000 -1456.2903575897217,3.763372182846069,37305.41358447075,110820,0,37305.41358447075,0.5631000399589539,2.0043318271636963,10000,38768.90533566475,0.746113657951355,1.0706056356430054,0.685479998588562,1.3558672666549685,50000 -1473.4958517551422,3.8107833862304688,37815.575922966,112340,0,37815.575922966,0.5627000331878662,2.016084909439087,10000,39296.36831307411,0.7463727593421936,1.0985363721847534,0.6815400123596191,1.3779244422912598,50000 -1490.8633544445038,3.8569798469543457,38325.62386679649,113860,0,38325.62386679649,0.5616000294685364,2.0014493465423584,10000,39823.877307891846,0.7833625674247742,0.9273353815078736,0.6845200061798096,1.3640166521072388,50000 -1507.9782707691193,3.90530037879944,38835.71640062332,115379,0,38835.71640062332,0.5652000308036804,1.992992281913757,10000,40351.181077718735,0.7703882455825806,1.0078136920928955,0.6881399750709534,1.3621941804885864,50000 -1525.088080406189,3.955566167831421,39345.8971452713,116900,0,39345.8971452713,0.5777000188827515,1.9591028690338133,10000,40878.56984305382,0.7623764276504517,1.0332750082015991,0.694599986076355,1.3372550010681152,50000 -1542.2515261173248,4.002415657043457,39856.08269357681,118420,0,39856.08269357681,0.567300021648407,1.957592248916626,10000,41406.01397848129,0.7677973508834839,0.99867582321167,0.6958000063896179,1.3079572916030884,50000 -1559.540601491928,4.057379245758057,40366.02552604675,119939,0,40366.02552604675,0.5699000358581543,1.961030006408692,10000,41933.3487534523,0.760184109210968,1.0345245599746704,0.6921399831771851,1.3318723440170288,50000 -1576.865121126175,4.10741400718689,40876.20910453797,121459,0,40876.20910453797,0.5731000304222107,1.9369956254959104,10000,42460.95462989807,0.7603236436843872,1.0235108137130735,0.6966999769210815,1.3048481941223145,50000 -1594.183688402176,4.156118869781494,41386.174983263016,122979,0,41386.174983263016,0.5800999999046326,1.9291913509368896,10000,42988.33508872986,0.7931680083274841,0.8888539671897888,0.6993399858474731,1.2950785160064695,50000 -1611.4827795028689,4.206348896026611,41896.13381576538,124498,0,41896.13381576538,0.5751000046730042,1.938994646072388,10000,43515.6908364296,0.7772042155265808,0.9544240832328796,0.6949399709701538,1.311497926712036,50000 -1628.591871261597,4.261731863021851,42406.0859875679,126016,0,42406.0859875679,0.5834000110626221,1.9185949563980105,10000,44042.85533833504,0.7820471525192261,0.9358399510383606,0.7050999999046326,1.2812731266021729,50000 -1646.0038397312164,4.311935186386108,42916.115965127945,127536,0,42916.115965127945,0.579200029373169,1.9385350942611688,10000,44570.39530873299,0.7782206535339355,0.9608622193336488,0.7054600119590759,1.2842296361923218,50000 -1663.313729763031,4.684342861175537,43425.82639288902,129054,0,43425.82639288902,0.5873000025749207,1.8815577030181885,10000,45097.83585691452,0.7837212681770325,0.9352360963821412,0.712179958820343,1.254599690437317,50000 -1680.3904614448547,4.7375593185424805,43935.96074128151,130573,0,43935.96074128151,0.5934000015258789,1.860069990158081,10000,45625.14785838127,0.7919722199440002,0.8858956694602966,0.7138800024986267,1.2260390520095823,50000 -1697.4430515766144,4.790385007858276,44446.12138080597,132093,0,44446.12138080597,0.5836000442504883,1.886076092720032,10000,46152.46186089516,0.8088328838348389,0.8323812484741211,0.7155199646949768,1.2397068738937378,50000 -1714.5795366764069,4.842229604721069,44956.11803674698,133612,0,44956.11803674698,0.588200032711029,1.905611991882324,10000,46679.69448518753,0.7983298897743225,0.8808016180992126,0.7109400033950806,1.252920150756836,50000 -1731.6397836208344,4.896440029144287,45466.05125045776,135131,0,45466.05125045776,0.5926000475883484,1.841191053390503,10000,47206.79020857811,0.8057836294174194,0.8274676203727722,0.7206000089645386,1.196054220199585,50000 -1748.9506647586825,4.948869228363037,45976.103113889694,136650,0,45976.103113889694,0.5906000137329102,1.861374497413636,10000,47734.25307178497,0.7958585619926453,0.875244677066803,0.7117800116539001,1.2368409633636477,50000 -1766.2522943019867,5.00182056427002,46486.0466632843,138169,0,46486.0466632843,0.6002000570297241,1.803093433380127,10000,48261.59885954857,0.8053850531578064,0.8225789070129395,0.723859965801239,1.1837297677993774,50000 -1783.7273569107056,5.05206823348999,46996.25559186936,139689,0,46996.25559186936,0.6009000539779663,1.8257156610488887,10000,48789.38086462021,0.8387476205825806,0.7008848190307617,0.7271199822425842,1.168049693107605,50000 -1800.8658833503723,5.111589431762695,47506.24500584602,141209,0,47506.24500584602,0.6070000529289246,1.7983014583587646,10000,49316.61631822586,0.8303571343421936,0.7429234385490417,0.725600004196167,1.187226176261902,50000 -1818.0516149997711,5.166984558105469,48016.3487842083,142728,0,48016.3487842083,0.5998000502586365,1.849668025970459,10000,49844.00881743431,0.8176219463348389,0.8000547885894775,0.7247999906539917,1.205077886581421,50000 -1835.3061537742608,5.220117092132568,48526.44045686722,144248,0,48526.44045686722,0.6071000099182129,1.7840397357940674,10000,50371.45580005646,0.8232421875,0.7640501856803894,0.7290999889373779,1.1696761846542358,50000 -1852.686731815338,5.2726123332977295,49036.527978897095,145767,0,49036.527978897095,0.6126000285148621,1.7603540420532229,10000,50899.024411439896,0.8306760191917419,0.7338430881500244,0.735040009021759,1.146825909614563,50000 -1869.9471807479856,5.328448534011841,49546.45242190361,147287,0,49546.45242190361,0.6184000372886658,1.7458066940307615,10000,51426.31285953522,0.8342036008834839,0.7212420701980591,0.7377399802207947,1.13205885887146,50000 -1887.1576430797577,5.3819334506988525,50056.575719833374,148806,0,50056.575719833374,0.6178000569343567,1.7484245300292969,10000,51953.74762201309,0.8618662357330322,0.6142693758010864,0.7396999597549438,1.1240257024765017,50000 -1904.453207731247,5.434524297714233,50566.62089514732,150325,0,50566.62089514732,0.6203000545501709,1.736547350883484,10000,52481.188838005066,0.8557676672935486,0.6431095004081726,0.7424399852752686,1.1167047023773191,50000 -1921.752840518952,5.488680601119995,51076.7715549469,151844,0,51076.7715549469,0.6206000447273254,1.7190169095993042,10000,53008.741042375565,0.8517019748687744,0.6508660316467285,0.740399956703186,1.1089304685592651,50000 -1939.816710472107,5.542165279388428,51586.68243670464,153363,0,51586.68243670464,0.6276000142097473,1.702033519744873,10000,53536.81693482399,0.8570631146430969,0.6400169134140015,0.7451599836349487,1.0995455980300903,50000 -1956.984309911728,5.594337701797485,52096.776774168015,154882,0,52096.776774168015,0.6283000111579895,1.6885766983032229,10000,54064.17822790146,0.8557876348495483,0.6262343525886536,0.7460199594497681,1.0894845724105835,50000 -1974.597580432892,5.648033142089844,52606.98158121109,156401,0,52606.98158121109,0.6301000118255615,1.6781634092330933,10000,54592.09770488739,0.861348032951355,0.6097630858421326,0.7497199773788452,1.077207088470459,50000 -1991.7971727848053,5.691704273223877,53117.21468949318,157920,0,53117.21468949318,0.631600022315979,1.692663550376892,10000,55119.62186551094,0.8819953799247742,0.5402708649635315,0.7523799538612366,1.0711145401000977,50000 -2009.16032409668,5.752228260040283,53627.40317606926,159440,0,53627.40317606926,0.6307000517845154,1.6828455924987793,10000,55647.28178501129,0.8802216053009033,0.5432620644569397,0.7541999816894531,1.059921383857727,50000 -2026.33642411232,5.809492588043213,54137.527206897736,160959,0,54137.527206897736,0.6326000094413757,1.689836025238037,10000,56174.68740797043,0.8743024468421936,0.5627148747444153,0.7519999742507935,1.0669498443603516,50000 -2043.4988808631897,5.86617112159729,54647.75112223625,162480,0,54647.75112223625,0.6355000138282776,1.671195149421692,10000,56702.178926467896,0.8794642686843872,0.5540726184844971,0.7535199522972107,1.0689977407455444,50000 -2060.928759098053,5.921685457229614,55157.73432970047,163999,0,55157.73432970047,0.6357000470161438,1.6739016771316528,10000,57229.695892095566,0.881257951259613,0.5390360951423645,0.7562599778175354,1.0575190782546997,50000 -2078.2167851924896,5.992365121841431,55667.66546201706,165518,0,55667.66546201706,0.6420000195503235,1.6526086330413818,10000,57757.03380203247,0.8997927308082581,0.4756694734096527,0.7611799836158752,1.0342700481414795,50000 -2095.539171934128,6.0488669872283936,56177.75947856903,167037,0,56177.75947856903,0.6413000226020813,1.6404330730438232,10000,58284.55493855477,0.9026426672935486,0.4636174738407135,0.7644599676132202,1.0262614488601685,50000 -2112.800806760788,6.107359647750855,56687.92796278,168557,0,56687.92796278,0.6444000601768494,1.6343672275543213,10000,58812.09150671959,0.901147961616516,0.4678551256656647,0.7643799781799316,1.0218467712402344,50000 -2130.0928435325623,6.1661036014556885,57198.12523698807,170076,0,57198.12523698807,0.6461000442504883,1.621721267700195,10000,59339.6875641346,0.904117465019226,0.4593851864337921,0.7635599970817566,1.023459792137146,50000 -2147.3965339660645,6.22303581237793,57708.20014286041,171595,0,57708.20014286041,0.6449000239372253,1.6215829849243164,10000,59867.17143726349,0.906648576259613,0.4511556625366211,0.7665199637413025,1.014053463935852,50000 -2164.6122002601624,6.283660173416138,58218.26253461838,173116,0,58218.26253461838,0.6490000486373901,1.6089009046554563,10000,60394.55778956413,0.9085220098495485,0.4437970519065857,0.7672799825668335,1.0095661878585815,50000 -2181.95579123497,6.341675043106079,58728.41574501991,174635,0,58728.41574501991,0.6513000130653381,1.6108440160751345,10000,60922.16036057472,0.9193837642669678,0.4103900194168091,0.7671399712562561,1.0063893795013428,50000 -2199.349052429199,6.401141405105591,59238.44852399826,176154,0,59238.44852399826,0.651900053024292,1.6077901124954224,10000,61449.693984270096,0.917390763759613,0.4171870648860931,0.76801997423172,1.0084744691848757,50000 -2216.56134724617,6.457212448120117,59748.627321243286,177673,0,59748.627321243286,0.6516000032424927,1.602932333946228,10000,61977.189665555954,0.9159757494926452,0.4144531488418579,0.7691400051116943,1.002310037612915,50000 -2233.763617515564,7.341777801513672,60257.7345225811,179190,0,60257.7345225811,0.6502000093460083,1.6043004989624023,10000,62504.43201804161,0.9170718789100648,0.4057277143001556,0.7702599763870239,0.9993435740470886,50000 -2251.149762868881,7.402913808822632,60767.89223456383,180709,0,60767.89223456383,0.6528000235557556,1.6050145626068115,10000,63032.08485865593,0.9195631146430968,0.4090628623962402,0.7701999545097351,1.0007290840148926,50000 -2268.478595972061,7.46382737159729,61277.79031038284,182227,0,61277.79031038284,0.6515000462532043,1.602009892463684,10000,63559.42090892792,0.9197823405265808,0.4075490236282348,0.7706800103187561,0.9976078867912292,50000 -2285.850029706955,7.525523662567139,61788.00178670883,183747,0,61788.00178670883,0.6534000039100647,1.599919080734253,10000,64087.113674640656,0.9213767051696776,0.4001802504062652,0.7707799673080444,0.9960906505584716,50000 -2303.0237517356877,7.586014032363892,62297.926362752914,185266,0,62297.926362752914,0.6541000604629517,1.599553108215332,10000,64614.32078671456,0.9223333597183228,0.3986599445343017,0.7706599831581116,0.9960277080535888,50000 -2320.299160003662,7.647452354431152,62767.63660430908,186666,0,62767.63660430908,0.6535000205039978,1.5979526042938232,10000,65101.41163563728,0.9204400181770325,0.39730027318000793,0.7709999680519104,0.9936330318450928,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index c0098ff8f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1993 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.59813327,6.920137,,,,,,,,,,,,,, -1,,,0.0007573341717943,6.910542011260986,0.0009599999757483,6.910243988037109,50000.0,0.0006000000284984,6.910250186920166,10000.0,69.66642546653748,107.33833169937134,69.66642546653748,37.67181849479675,0.0,0.0 -100,0.5784756,6.90276,,,,,,,,,,,,,, -200,0.58761346,6.8650274,,,,,,,,,,,,,, -300,0.6389509,6.787943,,,,,,,,,,,,,, -400,0.6632045,6.685791,,,,,,,,,,,,,, -500,0.7111429,6.5861616,,,,,,,,,,,,,, -600,0.7550149,6.5469246,,,,,,,,,,,,,, -700,0.8050522,6.4201927,,,,,,,,,,,,,, -800,0.8635853,6.269799,,,,,,,,,,,,,, -900,1.2339413,6.202387,,,,,,,,,,,,,, -1000,1.6126767,6.181346,,,,,,,,,,,,,, -1100,1.2455163,6.081448,,,,,,,,,,,,,, -1200,1.7211617,6.0478253,,,,,,,,,,,,,, -1300,2.2765565,5.9073815,,,,,,,,,,,,,, -1400,3.0953815,5.8354073,,,,,,,,,,,,,, -1500,3.1859355,5.7966833,,,,,,,,,,,,,, -1514,,,0.0735411345958709,5.336194038391113,0.0678599998354911,5.403800964355469,50000.0,0.049600001424551,5.629796504974365,10000.0,579.8638372421265,635.5374145507812,579.8638372421265,55.593273878097534,0.0293140411376953,0.0 -1600,2.3478713,5.7338147,,,,,,,,,,,,,, -1700,2.951114,5.6781974,,,,,,,,,,,,,, -1800,2.9106114,5.582506,,,,,,,,,,,,,, -1900,3.1299734,5.53599,,,,,,,,,,,,,, -2000,4.6852636,5.4280934,,,,,,,,,,,,,, -2100,4.3807793,5.401987,,,,,,,,,,,,,, -2200,5.381535,5.359907,,,,,,,,,,,,,, -2300,4.170131,5.3671107,,,,,,,,,,,,,, -2400,2.9330976,5.2454777,,,,,,,,,,,,,, -2500,3.5452044,5.2498126,,,,,,,,,,,,,, -2600,4.9461064,5.195735,,,,,,,,,,,,,, -2700,4.6976447,5.182199,,,,,,,,,,,,,, -2800,4.1785808,5.124696,,,,,,,,,,,,,, -2900,4.1695085,5.0975275,,,,,,,,,,,,,, -3000,3.7559206,4.9112096,,,,,,,,,,,,,, -3025,,,0.1801857352256775,4.261962890625,0.1609999984502792,4.37910270690918,50000.0,0.1212000027298927,4.7993268966674805,10000.0,1089.8733115196228,1163.8596937656405,1089.8733115196228,73.83026790618896,0.0556843280792236,0.0 -3100,3.5064385,4.966599,,,,,,,,,,,,,, -3200,8.116688,4.887416,,,,,,,,,,,,,, -3300,4.7777667,4.916379,,,,,,,,,,,,,, -3400,4.327111,4.8855925,,,,,,,,,,,,,, -3500,5.6726074,4.7504888,,,,,,,,,,,,,, -3600,3.565009,4.7258787,,,,,,,,,,,,,, -3700,5.705304,4.714261,,,,,,,,,,,,,, -3800,5.6393085,4.5837374,,,,,,,,,,,,,, -3900,5.773008,4.588313,,,,,,,,,,,,,, -4000,4.557863,4.6197863,,,,,,,,,,,,,, -4100,3.6206903,4.5917473,,,,,,,,,,,,,, -4200,7.8607655,4.5149345,,,,,,,,,,,,,, -4300,4.5061173,4.53312,,,,,,,,,,,,,, -4400,4.2826114,4.463257,,,,,,,,,,,,,, -4500,6.539879,4.374696,,,,,,,,,,,,,, -4537,,,0.2841398119926452,3.504873514175415,0.2563000023365021,3.6541614532470703,50000.0,0.1918000131845474,4.191176891326904,10000.0,1600.124408245087,1692.4212460517883,1600.124408245087,92.06435537338255,0.0844905376434326,0.0 -4600,5.5683947,4.3239183,,,,,,,,,,,,,, -4700,7.69759,4.35105,,,,,,,,,,,,,, -4800,7.027551,4.2546754,,,,,,,,,,,,,, -4900,6.056701,4.3057594,,,,,,,,,,,,,, -5000,6.2473044,4.2111154,,,,,,,,,,,,,, -5100,14.415817,4.159472,,,,,,,,,,,,,, -5200,7.9198074,4.157467,,,,,,,,,,,,,, -5300,4.521769,4.2275705,,,,,,,,,,,,,, -5400,8.319544,4.141226,,,,,,,,,,,,,, -5500,9.417938,4.0804386,,,,,,,,,,,,,, -5600,8.642919,4.115401,,,,,,,,,,,,,, -5700,6.314146,4.0153217,,,,,,,,,,,,,, -5800,5.866005,4.057621,,,,,,,,,,,,,, -5900,4.3345313,3.982259,,,,,,,,,,,,,, -6000,6.5053244,3.9104426,,,,,,,,,,,,,, -6049,,,0.3693797886371612,2.9464893341064453,0.3428999781608581,3.0969789028167725,50000.0,0.2526000142097473,3.744076728820801,10000.0,2110.33571267128,2220.955129146576,2110.33571267128,110.3130383491516,0.1100184917449951,0.0 -6100,5.329354,3.8790207,,,,,,,,,,,,,, -6200,10.139266,3.8215442,,,,,,,,,,,,,, -6300,5.65833,3.9238005,,,,,,,,,,,,,, -6400,5.300355,3.9027529,,,,,,,,,,,,,, -6500,4.1092615,3.9444437,,,,,,,,,,,,,, -6600,5.530079,3.7791855,,,,,,,,,,,,,, -6700,5.2907166,3.8456283,,,,,,,,,,,,,, -6800,6.7961936,3.7654877,,,,,,,,,,,,,, -6900,4.963301,3.7655385,,,,,,,,,,,,,, -7000,4.9337416,3.7876313,,,,,,,,,,,,,, -7100,5.7584205,3.770959,,,,,,,,,,,,,, -7200,7.0297146,3.7816074,,,,,,,,,,,,,, -7300,5.8564396,3.698736,,,,,,,,,,,,,, -7400,8.228339,3.7312536,,,,,,,,,,,,,, -7500,4.5835557,3.6942956,,,,,,,,,,,,,, -7561,,,0.443459004163742,2.543766975402832,0.4113999903202057,2.717856168746948,50000.0,0.3183000087738037,3.374671220779419,10000.0,2620.27745461464,2749.279436826706,2620.27745461464,128.61637711524963,0.1401574611663818,0.0 -7600,4.857524,3.675047,,,,,,,,,,,,,, -7700,6.8615294,3.6108627,,,,,,,,,,,,,, -7800,7.1400585,3.5724025,,,,,,,,,,,,,, -7900,7.4964757,3.6447296,,,,,,,,,,,,,, -8000,6.4764442,3.5702486,,,,,,,,,,,,,, -8100,4.517153,3.5451782,,,,,,,,,,,,,, -8200,4.555914,3.5575294,,,,,,,,,,,,,, -8300,5.2080283,3.5356731,,,,,,,,,,,,,, -8400,4.553494,3.5657375,,,,,,,,,,,,,, -8500,4.123767,3.476673,,,,,,,,,,,,,, -8600,3.719042,3.5087924,,,,,,,,,,,,,, -8700,8.396599,3.464552,,,,,,,,,,,,,, -8800,5.22273,3.4820442,,,,,,,,,,,,,, -8900,6.4251823,3.4609869,,,,,,,,,,,,,, -9000,7.4829874,3.5129132,,,,,,,,,,,,,, -9073,,,0.5113998651504517,2.2248189449310303,0.4605399966239929,2.471454620361328,50000.0,0.3585000038146972,3.096574783325196,10000.0,3130.3148624897003,3277.937031984329,3130.3148624897003,147.15200424194336,0.1761560440063476,0.0 -9100,5.442669,3.3303902,,,,,,,,,,,,,, -9200,6.7003546,3.3748014,,,,,,,,,,,,,, -9300,5.0501227,3.465205,,,,,,,,,,,,,, -9400,7.131802,3.4227257,,,,,,,,,,,,,, -9500,5.3882837,3.322351,,,,,,,,,,,,,, -9600,6.927008,3.395998,,,,,,,,,,,,,, -9700,7.452097,3.4393973,,,,,,,,,,,,,, -9800,5.568983,3.3358173,,,,,,,,,,,,,, -9900,5.522243,3.3562808,,,,,,,,,,,,,, -10000,5.596778,3.360685,,,,,,,,,,,,,, -10100,4.557776,3.2858896,,,,,,,,,,,,,, -10200,3.6334293,3.41425,,,,,,,,,,,,,, -10300,5.526923,3.334156,,,,,,,,,,,,,, -10400,5.6264386,3.352252,,,,,,,,,,,,,, -10500,4.3169312,3.296801,,,,,,,,,,,,,, -10586,,,0.5440847873687744,2.0609612464904785,0.492819994688034,2.328207492828369,50000.0,0.3758000135421753,3.0121326446533203,10000.0,3640.3234446048737,3806.8680925369263,3640.3234446048737,165.99961352348328,0.2035794258117675,0.0 -10600,7.150444,3.222268,,,,,,,,,,,,,, -10700,5.3143816,3.2672,,,,,,,,,,,,,, -10800,6.44803,3.3309546,,,,,,,,,,,,,, -10900,6.594354,3.2648814,,,,,,,,,,,,,, -11000,4.791562,3.1936128,,,,,,,,,,,,,, -11100,5.017349,3.2511225,,,,,,,,,,,,,, -11200,4.166161,3.1445878,,,,,,,,,,,,,, -11300,3.5438168,3.2159963,,,,,,,,,,,,,, -11400,4.5320377,3.2705822,,,,,,,,,,,,,, -11500,8.766891,3.2284417,,,,,,,,,,,,,, -11600,3.1683223,3.3576064,,,,,,,,,,,,,, -11700,5.305745,3.1972706,,,,,,,,,,,,,, -11800,5.436514,3.1510878,,,,,,,,,,,,,, -11900,5.8975534,3.2156296,,,,,,,,,,,,,, -12000,4.6742263,3.0982254,,,,,,,,,,,,,, -12100,5.4401846,3.0711644,,,,,,,,,,,,,, -12101,,,0.5639150142669678,1.929271221160889,0.521619975566864,2.1376399993896484,50000.0,0.410500019788742,2.809180498123169,10000.0,4150.755066394806,4338.130095720291,4150.755066394806,186.73877835273743,0.2456672191619873,0.0 -12200,5.535553,3.0356364,,,,,,,,,,,,,, -12300,5.642607,3.0231316,,,,,,,,,,,,,, -12400,7.0755644,3.1943417,,,,,,,,,,,,,, -12500,3.7946122,3.1422799,,,,,,,,,,,,,, -12600,4.629423,3.1680658,,,,,,,,,,,,,, -12700,6.510559,3.1041307,,,,,,,,,,,,,, -12800,4.21491,3.1201143,,,,,,,,,,,,,, -12900,6.276537,3.1499584,,,,,,,,,,,,,, -13000,5.1208034,3.2101061,,,,,,,,,,,,,, -13100,4.403644,3.1699808,,,,,,,,,,,,,, -13200,4.4651318,3.0167282,,,,,,,,,,,,,, -13300,4.840335,3.0762444,,,,,,,,,,,,,, -13400,4.863168,3.1439326,,,,,,,,,,,,,, -13500,6.101066,3.0080266,,,,,,,,,,,,,, -13600,9.281346,3.0626433,,,,,,,,,,,,,, -13616,,,0.5889668464660645,1.818697810173035,0.5448799729347229,2.03045392036438,50000.0,0.4193000197410583,2.7065205574035645,10000.0,4660.733189105988,4867.710218667984,4660.733189105988,206.2655758857727,0.2724974155426025,0.0 -13700,5.6754622,3.1068606,,,,,,,,,,,,,, -13800,4.759749,3.0110123,,,,,,,,,,,,,, -13900,6.90599,3.1366413,,,,,,,,,,,,,, -14000,7.30408,3.180294,,,,,,,,,,,,,, -14100,6.1858153,3.2117329,,,,,,,,,,,,,, -14200,7.296192,3.19318,,,,,,,,,,,,,, -14300,6.0175614,3.1105158,,,,,,,,,,,,,, -14400,4.545892,2.979792,,,,,,,,,,,,,, -14500,3.285814,3.0138152,,,,,,,,,,,,,, -14600,5.293021,3.126555,,,,,,,,,,,,,, -14700,5.0640383,3.1914217,,,,,,,,,,,,,, -14800,5.026855,2.9894218,,,,,,,,,,,,,, -14900,5.303098,2.9950266,,,,,,,,,,,,,, -15000,4.9399185,3.0224984,,,,,,,,,,,,,, -15100,8.017794,3.1622064,,,,,,,,,,,,,, -15132,,,0.5876913070678711,1.8149843215942385,0.5457000136375427,2.0031583309173584,50000.0,0.4246000349521637,2.713749647140503,10000.0,5170.75431728363,5399.7187423706055,5170.75431728363,228.1629378795624,0.3139839172363281,0.0 -15200,5.229208,3.0523543,,,,,,,,,,,,,, -15300,3.6536937,3.0645478,,,,,,,,,,,,,, -15400,4.9341726,3.0016313,,,,,,,,,,,,,, -15500,4.013448,3.0839016,,,,,,,,,,,,,, -15600,4.1399107,3.1005938,,,,,,,,,,,,,, -15700,5.1235356,2.9412673,,,,,,,,,,,,,, -15800,4.751747,3.0372393,,,,,,,,,,,,,, -15900,4.210514,3.0393248,,,,,,,,,,,,,, -16000,4.4295373,3.0386817,,,,,,,,,,,,,, -16100,4.2688675,3.0357423,,,,,,,,,,,,,, -16200,7.700673,3.1786,,,,,,,,,,,,,, -16300,5.8963327,3.0563717,,,,,,,,,,,,,, -16400,3.822324,3.0197124,,,,,,,,,,,,,, -16500,4.5444117,3.0249333,,,,,,,,,,,,,, -16600,5.8650446,2.9363751,,,,,,,,,,,,,, -16648,,,0.5989118218421936,1.7425923347473145,0.5619199872016907,1.9272526502609253,50000.0,0.4403000175952911,2.5987136363983154,10000.0,5680.732532739639,5931.523034095764,5680.732532739639,249.90737676620483,0.3474881649017334,0.0 -16700,4.1770406,3.0415003,,,,,,,,,,,,,, -16800,4.2030945,3.021429,,,,,,,,,,,,,, -16900,4.59682,2.9474175,,,,,,,,,,,,,, -17000,3.3692145,3.08843,,,,,,,,,,,,,, -17100,4.419429,2.9692738,,,,,,,,,,,,,, -17200,4.422788,3.0227032,,,,,,,,,,,,,, -17300,3.2975566,2.9581978,,,,,,,,,,,,,, -17400,4.5129657,3.0351605,,,,,,,,,,,,,, -17500,4.252215,2.9794016,,,,,,,,,,,,,, -17600,4.201011,3.0144043,,,,,,,,,,,,,, -17700,4.0545607,3.021765,,,,,,,,,,,,,, -17800,3.2706752,3.048366,,,,,,,,,,,,,, -17900,5.299096,3.0035603,,,,,,,,,,,,,, -18000,4.33046,2.9735057,,,,,,,,,,,,,, -18100,3.5759928,2.9204657,,,,,,,,,,,,,, -18165,,,0.6433154940605164,1.562894344329834,0.5627399682998657,1.9308120012283323,50000.0,0.4456000328063965,2.6102218627929688,10000.0,6190.90416431427,6466.05579328537,6190.90416431427,274.1750280857086,0.3943376541137695,0.0 -18200,4.697024,3.0468354,,,,,,,,,,,,,, -18300,5.1901507,2.943434,,,,,,,,,,,,,, -18400,4.3421593,2.8620696,,,,,,,,,,,,,, -18500,4.443903,3.0099854,,,,,,,,,,,,,, -18600,5.792394,3.0529766,,,,,,,,,,,,,, -18700,3.074601,2.936655,,,,,,,,,,,,,, -18800,2.9545646,3.0412638,,,,,,,,,,,,,, -18900,5.3324413,2.973302,,,,,,,,,,,,,, -19000,4.5455165,2.8937392,,,,,,,,,,,,,, -19100,2.9495428,2.9096103,,,,,,,,,,,,,, -19200,5.942278,2.975845,,,,,,,,,,,,,, -19300,3.551962,3.0455103,,,,,,,,,,,,,, -19400,3.04152,2.9198518,,,,,,,,,,,,,, -19500,4.7364273,2.9545016,,,,,,,,,,,,,, -19600,3.4539218,2.908428,,,,,,,,,,,,,, -19681,,,0.6237842440605164,1.6315232515335083,0.5619800090789795,1.91173791885376,50000.0,0.4467000067234039,2.569241523742676,10000.0,6700.854817867279,7001.347071886063,6700.854817867279,299.4436390399933,0.4216992855072021,0.0 -19700,4.485539,3.0156894,,,,,,,,,,,,,, -19800,3.417483,2.9940658,,,,,,,,,,,,,, -19900,4.666079,2.906215,,,,,,,,,,,,,, -20000,3.8580115,3.0357697,,,,,,,,,,,,,, -20100,4.3204136,2.9899616,,,,,,,,,,,,,, -20200,3.1506097,3.1268733,,,,,,,,,,,,,, -20300,4.095629,2.90863,,,,,,,,,,,,,, -20400,3.46495,2.9484322,,,,,,,,,,,,,, -20500,3.4443417,2.9595964,,,,,,,,,,,,,, -20600,3.0316985,2.8859901,,,,,,,,,,,,,, -20700,4.0645328,2.9399414,,,,,,,,,,,,,, -20800,3.6035912,3.036006,,,,,,,,,,,,,, -20900,3.3489072,2.9389703,,,,,,,,,,,,,, -21000,2.633006,3.0164816,,,,,,,,,,,,,, -21100,3.572615,2.983708,,,,,,,,,,,,,, -21198,,,0.6259167790412903,1.6084551811218262,0.5781800150871277,1.842873334884644,50000.0,0.4509000182151794,2.527494430541992,10000.0,7210.977495670319,7537.237438201904,7210.977495670319,325.11325001716614,0.4727559089660644,0.0 -21200,3.3721197,2.89607,,,,,,,,,,,,,, -21300,2.9601352,2.8917933,,,,,,,,,,,,,, -21400,3.3887644,2.84565,,,,,,,,,,,,,, -21500,3.963041,2.9495695,,,,,,,,,,,,,, -21600,3.4428194,2.9072204,,,,,,,,,,,,,, -21700,3.3016157,2.9256494,,,,,,,,,,,,,, -21800,3.5161622,2.8657236,,,,,,,,,,,,,, -21900,3.50958,2.8323882,,,,,,,,,,,,,, -22000,3.3627372,2.83529,,,,,,,,,,,,,, -22100,3.0587058,2.8185043,,,,,,,,,,,,,, -22200,4.192312,2.9222157,,,,,,,,,,,,,, -22300,3.0000608,2.8834672,,,,,,,,,,,,,, -22400,5.1023173,2.9050088,,,,,,,,,,,,,, -22500,3.5002463,2.8976617,,,,,,,,,,,,,, -22600,2.8444526,2.8681405,,,,,,,,,,,,,, -22700,3.932093,2.9307852,,,,,,,,,,,,,, -22715,,,0.6239436864852905,1.6318671703338623,0.5826199650764465,1.8319292068481443,50000.0,0.4554000198841095,2.526982307434082,10000.0,7720.921734571457,8070.599623918533,7720.921734571457,348.44293189048767,0.514392614364624,0.0 -22800,3.013572,2.814649,,,,,,,,,,,,,, -22900,3.0310047,2.8297894,,,,,,,,,,,,,, -23000,2.9584322,2.9184523,,,,,,,,,,,,,, -23100,2.980702,3.0659378,,,,,,,,,,,,,, -23200,4.157487,2.9928522,,,,,,,,,,,,,, -23300,3.1580627,2.8785453,,,,,,,,,,,,,, -23400,3.0330436,2.8824978,,,,,,,,,,,,,, -23500,3.879431,2.8886404,,,,,,,,,,,,,, -23600,2.6800315,3.0118988,,,,,,,,,,,,,, -23700,3.0172656,2.890328,,,,,,,,,,,,,, -23800,2.9567122,2.8715088,,,,,,,,,,,,,, -23900,2.8848462,2.8969247,,,,,,,,,,,,,, -24000,2.9254441,2.8376634,,,,,,,,,,,,,, -24100,2.557264,2.927843,,,,,,,,,,,,,, -24200,2.8124673,2.9842148,,,,,,,,,,,,,, -24232,,,0.6275310516357422,1.646264910697937,0.5839399695396423,1.8555593490600584,50000.0,0.4675000309944153,2.4928462505340576,10000.0,8231.1480448246,8605.418276548386,8231.1480448246,372.9615008831024,0.5425989627838135,0.0 -24300,2.7425938,2.94865,,,,,,,,,,,,,, -24400,3.2823362,2.899798,,,,,,,,,,,,,, -24500,2.5804021,2.9181113,,,,,,,,,,,,,, -24600,2.779456,2.8163164,,,,,,,,,,,,,, -24700,3.6736894,2.9145947,,,,,,,,,,,,,, -24800,2.6983654,2.8827696,,,,,,,,,,,,,, -24900,3.5426157,2.841652,,,,,,,,,,,,,, -25000,3.7243164,2.9409513,,,,,,,,,,,,,, -25100,2.866906,2.8005378,,,,,,,,,,,,,, -25200,3.02565,2.8620868,,,,,,,,,,,,,, -25300,2.7103968,2.8279843,,,,,,,,,,,,,, -25400,3.741021,2.8927739,,,,,,,,,,,,,, -25500,3.0015025,2.8150518,,,,,,,,,,,,,, -25600,3.8149147,2.8147635,,,,,,,,,,,,,, -25700,2.6194742,2.782652,,,,,,,,,,,,,, -25749,,,0.6228276491165161,1.6504992246627808,0.5827800035476685,1.833041071891785,50000.0,0.4647000133991241,2.5192012786865234,10000.0,8741.327996253967,9138.597157239914,8741.327996253967,395.8874454498291,0.5697648525238037,0.0 -25800,3.5934353,2.93631,,,,,,,,,,,,,, -25900,2.8797715,2.915139,,,,,,,,,,,,,, -26000,2.962639,2.8613653,,,,,,,,,,,,,, -26100,3.070852,2.8985436,,,,,,,,,,,,,, -26200,2.5745058,2.8089614,,,,,,,,,,,,,, -26300,3.2028847,2.9198058,,,,,,,,,,,,,, -26400,4.4876733,2.8942447,,,,,,,,,,,,,, -26500,2.6683233,2.8963597,,,,,,,,,,,,,, -26600,3.2021587,2.9662142,,,,,,,,,,,,,, -26700,2.6234486,2.9182959,,,,,,,,,,,,,, -26800,3.5235312,2.8797603,,,,,,,,,,,,,, -26900,3.1068976,2.8350208,,,,,,,,,,,,,, -27000,2.8249013,2.6993592,,,,,,,,,,,,,, -27100,2.9842892,3.0518188,,,,,,,,,,,,,, -27200,2.9913132,2.7548518,,,,,,,,,,,,,, -27266,,,0.6548548936843872,1.5148216485977173,0.581820011138916,1.857088208198548,50000.0,0.4601000249385834,2.5271754264831543,10000.0,9251.424172401428,9671.430490970612,9251.424172401428,418.5490050315857,0.6001632213592529,0.0 -27300,2.8112285,2.805828,,,,,,,,,,,,,, -27400,3.7151773,2.8435128,,,,,,,,,,,,,, -27500,2.6946445,2.875058,,,,,,,,,,,,,, -27600,2.8457656,2.8611553,,,,,,,,,,,,,, -27700,2.766184,2.8312173,,,,,,,,,,,,,, -27800,3.120278,2.885089,,,,,,,,,,,,,, -27900,2.8766563,2.8225296,,,,,,,,,,,,,, -28000,3.122327,2.845011,,,,,,,,,,,,,, -28100,2.7005346,2.947173,,,,,,,,,,,,,, -28200,3.3821235,2.810206,,,,,,,,,,,,,, -28300,3.273401,2.8156335,,,,,,,,,,,,,, -28400,3.6650832,2.9275508,,,,,,,,,,,,,, -28500,3.2535975,2.7786412,,,,,,,,,,,,,, -28600,2.9468079,2.7662792,,,,,,,,,,,,,, -28700,2.9061832,2.8656335,,,,,,,,,,,,,, -28784,,,0.6424983739852905,1.578094244003296,0.5882999897003174,1.832597017288208,50000.0,0.4591000080108642,2.5145013332366943,10000.0,9761.575938463213,10203.743408441544,9761.575938463213,440.63156938552856,0.6332635879516602,0.0 -28800,3.2885602,2.845531,,,,,,,,,,,,,, -28900,3.554302,2.896261,,,,,,,,,,,,,, -29000,3.229445,2.8014078,,,,,,,,,,,,,, -29100,3.0415173,2.8052402,,,,,,,,,,,,,, -29200,3.110641,2.835215,,,,,,,,,,,,,, -29300,3.7168868,2.7383118,,,,,,,,,,,,,, -29400,3.0057964,2.824453,,,,,,,,,,,,,, -29500,3.159038,2.8759334,,,,,,,,,,,,,, -29600,3.4764352,2.7946198,,,,,,,,,,,,,, -29700,2.7221923,2.8445168,,,,,,,,,,,,,, -29800,3.2791185,2.827394,,,,,,,,,,,,,, -29900,3.0653126,2.8674586,,,,,,,,,,,,,, -30000,3.2921524,2.8792558,,,,,,,,,,,,,, -30100,4.0452933,2.789028,,,,,,,,,,,,,, -30200,2.6357627,2.7766943,,,,,,,,,,,,,, -30300,3.205807,2.9355266,,,,,,,,,,,,,, -30303,,,0.6371771097183228,1.577159404754639,0.5897799730300903,1.7975454330444336,50000.0,0.4589000344276428,2.4995713233947754,10000.0,10271.727610588074,10738.066817998886,10271.727610588074,464.7254109382629,0.6648633480072021,0.0 -30400,2.973303,2.813099,,,,,,,,,,,,,, -30500,3.029966,2.854542,,,,,,,,,,,,,, -30600,2.6326444,2.773185,,,,,,,,,,,,,, -30700,2.852663,2.9151096,,,,,,,,,,,,,, -30800,2.8181608,2.8423514,,,,,,,,,,,,,, -30900,3.0487742,2.87476,,,,,,,,,,,,,, -31000,3.9006765,2.8041706,,,,,,,,,,,,,, -31100,2.9844859,2.893469,,,,,,,,,,,,,, -31200,3.0387864,2.7852168,,,,,,,,,,,,,, -31300,2.7015688,2.7482202,,,,,,,,,,,,,, -31400,2.5056393,2.8117633,,,,,,,,,,,,,, -31500,3.6037822,2.75814,,,,,,,,,,,,,, -31600,3.245549,2.7723157,,,,,,,,,,,,,, -31700,2.9128304,2.8227184,,,,,,,,,,,,,, -31800,2.6039343,2.9582136,,,,,,,,,,,,,, -31821,,,0.6344068646430969,1.5880506038665771,0.5906800031661987,1.8002487421035769,50000.0,0.4662000238895416,2.477463960647583,10000.0,10781.824187994003,11272.292582035065,10781.824187994003,488.7787292003632,0.6953849792480469,0.0 -31900,2.6142688,2.893355,,,,,,,,,,,,,, -32000,2.7319288,2.9122658,,,,,,,,,,,,,, -32100,2.964312,2.770778,,,,,,,,,,,,,, -32200,3.0364208,2.7947068,,,,,,,,,,,,,, -32300,2.9851751,2.7279522,,,,,,,,,,,,,, -32400,2.8831267,2.7726727,,,,,,,,,,,,,, -32500,3.0549746,2.8607187,,,,,,,,,,,,,, -32600,2.9484982,2.8358536,,,,,,,,,,,,,, -32700,3.3274527,2.7798204,,,,,,,,,,,,,, -32800,3.166761,2.8615723,,,,,,,,,,,,,, -32900,2.642981,2.7488036,,,,,,,,,,,,,, -33000,3.3205216,2.8119662,,,,,,,,,,,,,, -33100,2.7077224,2.8137949,,,,,,,,,,,,,, -33200,2.777371,2.8372145,,,,,,,,,,,,,, -33300,3.2975624,2.8716242,,,,,,,,,,,,,, -33339,,,0.6326330900192261,1.577653169631958,0.5985000133514404,1.7677619457244873,50000.0,0.4714000225067138,2.437567949295044,10000.0,11291.99487566948,11806.442516088486,11291.99487566948,512.6822199821472,0.7257883548736572,0.0 -33400,2.6637394,2.74178,,,,,,,,,,,,,, -33500,2.7066991,2.76567,,,,,,,,,,,,,, -33600,2.6194146,2.786594,,,,,,,,,,,,,, -33700,2.842778,2.8998394,,,,,,,,,,,,,, -33800,3.3560512,2.7843406,,,,,,,,,,,,,, -33900,3.0359988,2.8657243,,,,,,,,,,,,,, -34000,2.823162,2.874697,,,,,,,,,,,,,, -34100,3.1852024,2.8701792,,,,,,,,,,,,,, -34200,3.0422006,2.858323,,,,,,,,,,,,,, -34300,3.9364498,2.8956757,,,,,,,,,,,,,, -34400,2.9436505,2.8914464,,,,,,,,,,,,,, -34500,2.9238126,2.8662944,,,,,,,,,,,,,, -34600,3.0667703,2.8184,,,,,,,,,,,,,, -34700,2.9246209,2.7867897,,,,,,,,,,,,,, -34800,3.3937972,2.8277743,,,,,,,,,,,,,, -34858,,,0.6404655575752258,1.569533348083496,0.5995999574661255,1.7648115158081057,50000.0,0.4784000217914581,2.4060091972351074,10000.0,11802.263848781586,12341.097544431686,11802.263848781586,536.9916772842407,0.756098747253418,0.0 -34900,3.0742428,2.8651602,,,,,,,,,,,,,, -35000,2.7804925,2.873223,,,,,,,,,,,,,, -35100,4.031746,2.8403602,,,,,,,,,,,,,, -35200,2.6308317,2.7755342,,,,,,,,,,,,,, -35300,3.0979002,2.7789187,,,,,,,,,,,,,, -35400,2.8467336,2.8484824,,,,,,,,,,,,,, -35500,2.6036239,2.7693906,,,,,,,,,,,,,, -35600,2.9816785,2.7794304,,,,,,,,,,,,,, -35700,3.3847985,2.788282,,,,,,,,,,,,,, -35800,3.0653138,2.8155181,,,,,,,,,,,,,, -35900,3.1314714,2.7918496,,,,,,,,,,,,,, -36000,3.0283532,2.8330748,,,,,,,,,,,,,, -36100,3.213848,2.7741919,,,,,,,,,,,,,, -36200,2.8768497,2.7705357,,,,,,,,,,,,,, -36300,3.5106494,2.8176208,,,,,,,,,,,,,, -36376,,,0.670340359210968,1.3972032070159912,0.6022799611091614,1.7204241752624512,50000.0,0.4791000187397003,2.39719295501709,10000.0,12312.29682803154,12875.93498826027,12312.29682803154,561.7181787490845,0.7879447937011719,0.0 -36400,2.5356376,2.8369226,,,,,,,,,,,,,, -36500,3.7145936,2.9052994,,,,,,,,,,,,,, -36600,2.5890834,2.660098,,,,,,,,,,,,,, -36700,2.8842504,2.8041534,,,,,,,,,,,,,, -36800,2.681959,2.7119086,,,,,,,,,,,,,, -36900,3.4539113,2.8024027,,,,,,,,,,,,,, -37000,3.0448499,2.8154259,,,,,,,,,,,,,, -37100,2.6343384,2.7437842,,,,,,,,,,,,,, -37200,3.2680907,2.817811,,,,,,,,,,,,,, -37300,2.8850002,2.8054447,,,,,,,,,,,,,, -37400,3.1058714,2.748653,,,,,,,,,,,,,, -37500,2.9694178,2.8210056,,,,,,,,,,,,,, -37600,2.9935791,2.7652583,,,,,,,,,,,,,, -37700,3.1397843,2.835846,,,,,,,,,,,,,, -37800,3.3675425,2.7497234,,,,,,,,,,,,,, -37894,,,0.6550741195678711,1.4708644151687622,0.6022599935531616,1.7340105772018433,50000.0,0.4779000282287597,2.4191813468933105,10000.0,12822.334090471268,13409.04744195938,12822.334090471268,584.714658498764,0.8208250999450684,0.0 -37900,2.815681,2.8477907,,,,,,,,,,,,,, -38000,2.8918412,2.7924438,,,,,,,,,,,,,, -38100,3.9437697,2.700056,,,,,,,,,,,,,, -38200,3.012506,2.810175,,,,,,,,,,,,,, -38300,3.1701806,2.8699026,,,,,,,,,,,,,, -38400,2.9150534,2.7501893,,,,,,,,,,,,,, -38500,3.3376374,2.7661488,,,,,,,,,,,,,, -38600,2.793446,2.7805796,,,,,,,,,,,,,, -38700,3.0797966,2.7341785,,,,,,,,,,,,,, -38800,3.0518513,2.7212837,,,,,,,,,,,,,, -38900,2.9295487,2.8146124,,,,,,,,,,,,,, -39000,2.9281878,2.6547143,,,,,,,,,,,,,, -39100,3.2488036,2.8570952,,,,,,,,,,,,,, -39200,3.1393943,2.8118494,,,,,,,,,,,,,, -39300,3.5872016,2.7617571,,,,,,,,,,,,,, -39400,3.321256,2.7268443,,,,,,,,,,,,,, -39412,,,0.6563097834587097,1.4915953874588013,0.6089400053024292,1.7161346673965454,50000.0,0.4886000156402588,2.380843162536621,10000.0,13332.26151394844,13941.919185638428,13332.26151394844,607.5823495388031,0.8514549732208252,0.0 -39500,3.5166569,2.784994,,,,,,,,,,,,,, -39600,3.8479218,2.7818272,,,,,,,,,,,,,, -39700,3.1261048,2.797941,,,,,,,,,,,,,, -39800,3.0409408,2.8739917,,,,,,,,,,,,,, -39900,2.839186,2.7745676,,,,,,,,,,,,,, -40000,3.4124045,2.8525057,,,,,,,,,,,,,, -40100,3.1805835,2.6950629,,,,,,,,,,,,,, -40200,2.7189996,2.7483969,,,,,,,,,,,,,, -40300,2.68721,2.7620826,,,,,,,,,,,,,, -40400,3.4188952,2.8701355,,,,,,,,,,,,,, -40500,2.8002098,2.6770313,,,,,,,,,,,,,, -40600,3.0950959,2.8106313,,,,,,,,,,,,,, -40700,2.6289937,2.9127882,,,,,,,,,,,,,, -40800,3.278613,2.7820187,,,,,,,,,,,,,, -40900,3.1459544,2.8084893,,,,,,,,,,,,,, -40930,,,0.6541174650192261,1.521090388298035,0.6050199866294861,1.733443260192871,50000.0,0.4817000329494476,2.3924968242645264,10000.0,13842.18695950508,14473.955300807953,13842.18695950508,629.6153049468994,0.883293628692627,0.0 -41000,2.8340895,2.8169026,,,,,,,,,,,,,, -41100,3.2128205,2.741751,,,,,,,,,,,,,, -41200,3.5503798,2.8443804,,,,,,,,,,,,,, -41300,2.758454,2.8100255,,,,,,,,,,,,,, -41400,3.5597153,2.7184305,,,,,,,,,,,,,, -41500,3.3256624,2.854391,,,,,,,,,,,,,, -41600,2.6332376,2.7926762,,,,,,,,,,,,,, -41700,2.87642,2.8095696,,,,,,,,,,,,,, -41800,3.3183928,2.6777072,,,,,,,,,,,,,, -41900,3.5741508,2.8077252,,,,,,,,,,,,,, -42000,3.2628636,2.7722025,,,,,,,,,,,,,, -42100,3.5445635,2.7207236,,,,,,,,,,,,,, -42200,3.682322,2.7781987,,,,,,,,,,,,,, -42300,2.8218784,2.6714532,,,,,,,,,,,,,, -42400,3.2408628,2.7917855,,,,,,,,,,,,,, -42448,,,0.6384526491165161,1.568603515625,0.5953999757766724,1.7692608833312988,50000.0,0.4720000326633453,2.444088220596313,10000.0,14352.171369075775,15006.26985692978,14352.171369075775,651.8660788536072,0.916776180267334,0.0 -42500,3.4639132,2.7404506,,,,,,,,,,,,,, -42600,3.1990857,2.7709777,,,,,,,,,,,,,, -42700,3.0987906,2.7400014,,,,,,,,,,,,,, -42800,3.256473,2.753178,,,,,,,,,,,,,, -42900,3.1989431,2.7525785,,,,,,,,,,,,,, -43000,3.262111,2.7795436,,,,,,,,,,,,,, -43100,3.0311406,2.7731972,,,,,,,,,,,,,, -43200,2.8399317,2.7577958,,,,,,,,,,,,,, -43300,2.7793083,2.7374573,,,,,,,,,,,,,, -43400,3.9346812,2.7064874,,,,,,,,,,,,,, -43500,2.594609,2.7712433,,,,,,,,,,,,,, -43600,2.9122243,2.837769,,,,,,,,,,,,,, -43700,2.8936975,2.7297387,,,,,,,,,,,,,, -43800,2.8058026,2.8003032,,,,,,,,,,,,,, -43900,2.8255146,2.7756672,,,,,,,,,,,,,, -43967,,,0.6627271771430969,1.4629147052764893,0.6109600067138672,1.699565887451172,50000.0,0.4828000366687774,2.402573585510254,10000.0,14862.203121185305,15539.22065258026,14862.203121185305,674.709219455719,0.9466893672943116,0.0 -44000,3.2775047,2.814465,,,,,,,,,,,,,, -44100,3.248098,2.7836943,,,,,,,,,,,,,, -44200,2.7908397,2.7993333,,,,,,,,,,,,,, -44300,2.8194468,2.7644649,,,,,,,,,,,,,, -44400,3.1004028,2.893168,,,,,,,,,,,,,, -44500,2.8259616,2.8492794,,,,,,,,,,,,,, -44600,3.6397638,2.7618513,,,,,,,,,,,,,, -44700,3.3427386,2.777899,,,,,,,,,,,,,, -44800,2.828147,2.7449062,,,,,,,,,,,,,, -44900,2.946581,2.7751813,,,,,,,,,,,,,, -45000,2.8207557,2.7090123,,,,,,,,,,,,,, -45100,3.1974998,2.7937155,,,,,,,,,,,,,, -45200,2.6213322,2.6744263,,,,,,,,,,,,,, -45300,3.4493968,2.8249412,,,,,,,,,,,,,, -45400,3.0557163,2.6272552,,,,,,,,,,,,,, -45486,,,0.674246609210968,1.4069665670394895,0.6114199757575989,1.696818232536316,50000.0,0.4945000112056732,2.347254753112793,10000.0,15372.199818134308,16071.047409534454,15372.199818134308,696.4635140895844,0.9763743877410888,0.0 -45500,3.0272305,2.6642501,,,,,,,,,,,,,, -45600,3.016403,2.7958498,,,,,,,,,,,,,, -45700,2.9106238,2.7663984,,,,,,,,,,,,,, -45800,3.176219,2.7214565,,,,,,,,,,,,,, -45900,3.484111,2.8061752,,,,,,,,,,,,,, -46000,3.202485,2.7663102,,,,,,,,,,,,,, -46100,3.4072568,2.7367122,,,,,,,,,,,,,, -46200,2.8961732,2.8799572,,,,,,,,,,,,,, -46300,2.8696203,2.6312568,,,,,,,,,,,,,, -46400,3.4000747,2.660903,,,,,,,,,,,,,, -46500,2.975131,2.747763,,,,,,,,,,,,,, -46600,3.2380247,2.7972481,,,,,,,,,,,,,, -46700,3.0301352,2.7237134,,,,,,,,,,,,,, -46800,3.079138,2.74705,,,,,,,,,,,,,, -46900,2.9851358,2.7430425,,,,,,,,,,,,,, -47000,3.1908379,2.8111064,,,,,,,,,,,,,, -47004,,,0.6636040806770325,1.4506380558013916,0.6128799915313721,1.6783645153045654,50000.0,0.4917000234127044,2.3279991149902344,10000.0,15882.18303823471,16602.090711593628,15882.18303823471,717.4448945522308,1.00885272026062,0.0 -47100,2.9464762,2.7099829,,,,,,,,,,,,,, -47200,3.0941744,2.794462,,,,,,,,,,,,,, -47300,2.8696232,2.7587452,,,,,,,,,,,,,, -47400,3.6521964,2.8001213,,,,,,,,,,,,,, -47500,3.61277,2.820639,,,,,,,,,,,,,, -47600,3.146996,2.6421645,,,,,,,,,,,,,, -47700,4.362054,2.7266068,,,,,,,,,,,,,, -47800,3.3202975,2.8137085,,,,,,,,,,,,,, -47900,2.7278135,2.6882374,,,,,,,,,,,,,, -48000,3.1605983,2.718526,,,,,,,,,,,,,, -48100,2.8117554,2.6426694,,,,,,,,,,,,,, -48200,2.8253887,2.7748187,,,,,,,,,,,,,, -48300,3.3558328,2.8566747,,,,,,,,,,,,,, -48400,2.9183176,2.8852165,,,,,,,,,,,,,, -48500,3.0825295,2.7151563,,,,,,,,,,,,,, -48523,,,0.6721540093421936,1.4366706609725952,0.6180399656295776,1.684572458267212,50000.0,0.4991000294685364,2.3448281288146973,10000.0,16392.207607269287,17133.673773765564,16392.207607269287,738.927268743515,1.0390520095825195,0.0 -48600,2.8723063,2.7261019,,,,,,,,,,,,,, -48700,3.5971053,2.6949122,,,,,,,,,,,,,, -48800,2.882011,2.722211,,,,,,,,,,,,,, -48900,3.2409706,2.7996936,,,,,,,,,,,,,, -49000,3.5870638,2.7100804,,,,,,,,,,,,,, -49100,3.3111067,2.7233365,,,,,,,,,,,,,, -49200,3.2652261,2.6355648,,,,,,,,,,,,,, -49300,3.0427983,2.7705297,,,,,,,,,,,,,, -49400,3.1091309,2.8040533,,,,,,,,,,,,,, -49500,3.3242002,2.6781936,,,,,,,,,,,,,, -49600,3.1542425,2.751696,,,,,,,,,,,,,, -49700,2.8563528,2.8090522,,,,,,,,,,,,,, -49800,3.2501485,2.7671335,,,,,,,,,,,,,, -49900,2.7000654,2.7352471,,,,,,,,,,,,,, -50000,2.8554122,2.6630235,,,,,,,,,,,,,, -50042,,,0.6578244566917419,1.4930152893066406,0.6173999905586243,1.6958266496658323,50000.0,0.491100013256073,2.3817460536956787,10000.0,16902.16853904724,17663.81556916237,16902.16853904724,759.0310838222504,1.070474624633789,0.0 -50100,3.331784,2.7362099,,,,,,,,,,,,,, -50200,3.047197,2.8133378,,,,,,,,,,,,,, -50300,3.088796,2.7646165,,,,,,,,,,,,,, -50400,3.1154633,2.6728573,,,,,,,,,,,,,, -50500,3.5369427,2.8349502,,,,,,,,,,,,,, -50600,3.4488077,2.7641354,,,,,,,,,,,,,, -50700,3.4099386,2.7402015,,,,,,,,,,,,,, -50800,3.2975688,2.8139915,,,,,,,,,,,,,, -50900,3.2714036,2.7539499,,,,,,,,,,,,,, -51000,3.1643481,2.7692714,,,,,,,,,,,,,, -51100,2.8541422,2.746109,,,,,,,,,,,,,, -51200,3.275006,2.75633,,,,,,,,,,,,,, -51300,2.8821685,2.8501723,,,,,,,,,,,,,, -51400,3.0464401,2.7549708,,,,,,,,,,,,,, -51500,3.4587638,2.7786455,,,,,,,,,,,,,, -51561,,,0.6567083597183228,1.477152943611145,0.613099992275238,1.6816418170928955,50000.0,0.490200012922287,2.355665445327759,10000.0,17412.375111818314,18193.11154937744,17412.375111818314,778.0412209033966,1.103534698486328,0.0 -51600,3.4853048,2.6955268,,,,,,,,,,,,,, -51700,2.8809588,2.7508624,,,,,,,,,,,,,, -51800,3.2237654,2.7389894,,,,,,,,,,,,,, -51900,3.0825105,2.8501809,,,,,,,,,,,,,, -52000,2.8335674,2.7760973,,,,,,,,,,,,,, -52100,3.1878586,2.6609666,,,,,,,,,,,,,, -52200,2.7719336,2.6288745,,,,,,,,,,,,,, -52300,3.026237,2.755781,,,,,,,,,,,,,, -52400,3.084577,2.7348394,,,,,,,,,,,,,, -52500,3.0294147,2.7155375,,,,,,,,,,,,,, -52600,3.3533428,2.6272166,,,,,,,,,,,,,, -52700,2.8821704,2.7539902,,,,,,,,,,,,,, -52800,2.919266,2.6949546,,,,,,,,,,,,,, -52900,3.1888971,2.733759,,,,,,,,,,,,,, -53000,3.1525056,2.6922626,,,,,,,,,,,,,, -53081,,,0.7010921239852905,1.2759617567062378,0.6177399754524231,1.6607871055603027,50000.0,0.496800035238266,2.343891859054565,10000.0,17922.586621522903,18721.159491062164,17922.586621522903,795.7929601669312,1.141834735870361,0.0 -53100,3.115572,2.8583753,,,,,,,,,,,,,, -53200,3.0533328,2.6903253,,,,,,,,,,,,,, -53300,3.2218702,2.7572873,,,,,,,,,,,,,, -53400,3.126929,2.631845,,,,,,,,,,,,,, -53500,3.1887162,2.7046747,,,,,,,,,,,,,, -53600,3.3660467,2.727378,,,,,,,,,,,,,, -53700,3.5234566,2.758866,,,,,,,,,,,,,, -53800,2.9423246,2.8023784,,,,,,,,,,,,,, -53900,3.1104555,2.6625648,,,,,,,,,,,,,, -54000,2.890812,2.7474768,,,,,,,,,,,,,, -54100,3.187417,2.7940748,,,,,,,,,,,,,, -54200,3.517624,2.7345104,,,,,,,,,,,,,, -54300,3.0241435,2.8346992,,,,,,,,,,,,,, -54400,3.3429036,2.7130427,,,,,,,,,,,,,, -54500,3.42828,2.715285,,,,,,,,,,,,,, -54600,,,0.6756815910339355,1.4029406309127808,0.6170799732208252,1.6762901544570925,50000.0,0.4958000183105469,2.347243547439575,10000.0,18432.66418099404,19248.851968050003,18432.66418099404,813.3257002830505,1.177684307098389,0.0 -54600,2.968297,2.6660187,,,,,,,,,,,,,, -54700,2.9988387,2.6588423,,,,,,,,,,,,,, -54800,3.02494,2.7819653,,,,,,,,,,,,,, -54900,3.091757,2.7306166,,,,,,,,,,,,,, -55000,2.9070153,2.7307966,,,,,,,,,,,,,, -55100,3.1194491,2.724791,,,,,,,,,,,,,, -55200,3.7527575,2.7787108,,,,,,,,,,,,,, -55300,2.7406025,2.644202,,,,,,,,,,,,,, -55400,2.7507968,2.6521955,,,,,,,,,,,,,, -55500,3.0897086,2.6325417,,,,,,,,,,,,,, -55600,3.408838,2.626838,,,,,,,,,,,,,, -55700,2.8142202,2.7580304,,,,,,,,,,,,,, -55800,2.940592,2.7143517,,,,,,,,,,,,,, -55900,3.395384,2.7265217,,,,,,,,,,,,,, -56000,2.8325546,2.6543193,,,,,,,,,,,,,, -56100,3.2942874,2.5949314,,,,,,,,,,,,,, -56119,,,0.6758211255073547,1.4040615558624268,0.6243000030517578,1.6517601013183594,50000.0,0.5015000104904175,2.316536903381348,10000.0,18942.804827928543,19778.317579746246,18942.804827928543,832.5534996986389,1.2281639575958252,0.0 -56200,3.862801,2.7184343,,,,,,,,,,,,,, -56300,3.3707426,2.6519508,,,,,,,,,,,,,, -56400,3.1097758,2.7419355,,,,,,,,,,,,,, -56500,2.8601954,2.7861755,,,,,,,,,,,,,, -56600,2.8359635,2.8588681,,,,,,,,,,,,,, -56700,3.2986116,2.7850287,,,,,,,,,,,,,, -56800,3.7151172,2.8288023,,,,,,,,,,,,,, -56900,2.7749658,2.721249,,,,,,,,,,,,,, -57000,3.0480258,2.8001993,,,,,,,,,,,,,, -57100,3.2007492,2.672203,,,,,,,,,,,,,, -57200,3.27588,2.7382817,,,,,,,,,,,,,, -57300,2.8738983,2.7608109,,,,,,,,,,,,,, -57400,3.3715541,2.7239394,,,,,,,,,,,,,, -57500,3.270095,2.7087317,,,,,,,,,,,,,, -57600,3.2475448,2.6971502,,,,,,,,,,,,,, -57639,,,0.6713966727256775,1.409226417541504,0.6262199878692627,1.636522889137268,50000.0,0.5049999952316284,2.2848803997039795,10000.0,19453.03369998932,20306.5486035347,19453.03369998932,850.4759373664856,1.2610328197479248,0.0 -57700,3.0573266,2.6210952,,,,,,,,,,,,,, -57800,3.2499511,2.6416588,,,,,,,,,,,,,, -57900,2.8917294,2.7071629,,,,,,,,,,,,,, -58000,3.4559906,2.7459595,,,,,,,,,,,,,, -58100,3.2019527,2.5576797,,,,,,,,,,,,,, -58200,3.2349236,2.7445912,,,,,,,,,,,,,, -58300,2.7782166,2.6068902,,,,,,,,,,,,,, -58400,2.8687348,2.652622,,,,,,,,,,,,,, -58500,3.100201,2.6294208,,,,,,,,,,,,,, -58600,3.0737612,2.6926587,,,,,,,,,,,,,, -58700,2.7582881,2.6872044,,,,,,,,,,,,,, -58800,3.4604087,2.7298644,,,,,,,,,,,,,, -58900,4.118969,2.6293244,,,,,,,,,,,,,, -59000,3.2283287,2.773205,,,,,,,,,,,,,, -59100,2.8142054,2.7280183,,,,,,,,,,,,,, -59157,,,0.6717753410339355,1.4401389360427856,0.6253199577331543,1.6600350141525269,50000.0,0.4987000226974487,2.3283464908599854,10000.0,19962.99881720543,20833.94366264344,19962.99881720543,867.8226597309113,1.2980821132659912,0.0 -59200,3.502357,2.6949923,,,,,,,,,,,,,, -59300,3.2215302,2.6635313,,,,,,,,,,,,,, -59400,3.097181,2.6413822,,,,,,,,,,,,,, -59500,2.933855,2.6419706,,,,,,,,,,,,,, -59600,3.153286,2.7790625,,,,,,,,,,,,,, -59700,2.9917305,2.719509,,,,,,,,,,,,,, -59800,2.9062486,2.639762,,,,,,,,,,,,,, -59900,3.5060225,2.7091672,,,,,,,,,,,,,, -60000,3.5352805,2.7887163,,,,,,,,,,,,,, -60100,3.2954307,2.7467067,,,,,,,,,,,,,, -60200,2.9821506,2.685206,,,,,,,,,,,,,, -60300,3.0709503,2.687862,,,,,,,,,,,,,, -60400,3.3240879,2.656283,,,,,,,,,,,,,, -60500,3.9266613,2.6763344,,,,,,,,,,,,,, -60600,2.931058,2.5629406,,,,,,,,,,,,,, -60677,,,0.6564891338348389,1.4987339973449707,0.614579975605011,1.6961729526519775,50000.0,0.495600014925003,2.357531309127808,10000.0,20473.223114967343,21361.55599737168,20473.223114967343,885.1255774497986,1.3364100456237793,0.0 -60700,3.1214743,2.6699784,,,,,,,,,,,,,, -60800,3.3451865,2.806826,,,,,,,,,,,,,, -60900,3.033336,2.6807885,,,,,,,,,,,,,, -61000,3.1539774,2.6514764,,,,,,,,,,,,,, -61100,2.8516114,2.6621041,,,,,,,,,,,,,, -61200,3.3925397,2.684791,,,,,,,,,,,,,, -61300,2.9683278,2.5924687,,,,,,,,,,,,,, -61400,3.2253978,2.693295,,,,,,,,,,,,,, -61500,3.2161758,2.840297,,,,,,,,,,,,,, -61600,3.2172925,2.6655574,,,,,,,,,,,,,, -61700,3.6230347,2.7318597,,,,,,,,,,,,,, -61800,3.0570493,2.7491698,,,,,,,,,,,,,, -61900,3.1691144,2.7758934,,,,,,,,,,,,,, -62000,3.0847316,2.7316964,,,,,,,,,,,,,, -62100,3.3162673,2.6919558,,,,,,,,,,,,,, -62196,,,0.7053770422935486,1.2767313718795776,0.6273599863052368,1.6219308376312256,50000.0,0.5074000358581543,2.304157495498657,10000.0,20983.24218249321,21889.161187648773,20983.24218249321,902.6221942901612,1.3790311813354492,0.0 -62200,3.321483,2.824497,,,,,,,,,,,,,, -62300,2.8062332,2.6317782,,,,,,,,,,,,,, -62400,3.173839,2.6951642,,,,,,,,,,,,,, -62500,3.302977,2.7014732,,,,,,,,,,,,,, -62600,3.28245,2.7095976,,,,,,,,,,,,,, -62700,3.104095,2.7880232,,,,,,,,,,,,,, -62800,3.0805404,2.6639447,,,,,,,,,,,,,, -62900,3.0397108,2.6415944,,,,,,,,,,,,,, -63000,3.0782623,2.6478202,,,,,,,,,,,,,, -63100,3.239245,2.7904787,,,,,,,,,,,,,, -63200,3.5452147,2.6596215,,,,,,,,,,,,,, -63300,3.0511343,2.668938,,,,,,,,,,,,,, -63400,3.1170738,2.689823,,,,,,,,,,,,,, -63500,3.137996,2.6299362,,,,,,,,,,,,,, -63600,3.2365067,2.6850443,,,,,,,,,,,,,, -63700,3.3751578,2.5986273,,,,,,,,,,,,,, -63715,,,0.686922013759613,1.3271760940551758,0.626800000667572,1.617149829864502,50000.0,0.5064000487327576,2.267906904220581,10000.0,21493.27866005897,22416.715700864792,21493.27866005897,920.0448322296144,1.4266955852508545,0.0 -63800,2.9320545,2.6509893,,,,,,,,,,,,,, -63900,3.1622753,2.5902615,,,,,,,,,,,,,, -64000,3.229412,2.7374768,,,,,,,,,,,,,, -64100,2.9212842,2.5953226,,,,,,,,,,,,,, -64200,2.9820116,2.6971142,,,,,,,,,,,,,, -64300,3.9450185,2.7023854,,,,,,,,,,,,,, -64400,3.1690297,2.715599,,,,,,,,,,,,,, -64500,3.5915234,2.6520247,,,,,,,,,,,,,, -64600,3.0553606,2.6010923,,,,,,,,,,,,,, -64700,3.0927637,2.6925552,,,,,,,,,,,,,, -64800,3.313803,2.6601048,,,,,,,,,,,,,, -64900,3.3358233,2.7779012,,,,,,,,,,,,,, -65000,3.111949,2.5929492,,,,,,,,,,,,,, -65100,3.0285513,2.659643,,,,,,,,,,,,,, -65200,3.9431307,2.7586148,,,,,,,,,,,,,, -65234,,,0.6711774468421936,1.4122507572174072,0.6250199675559998,1.6446616649627686,50000.0,0.4989000260829925,2.329805850982666,10000.0,22003.35453414917,22944.03020715713,22003.35453414917,937.196320772171,1.4666364192962646,0.0 -65300,3.6340604,2.6957974,,,,,,,,,,,,,, -65400,3.030704,2.6227837,,,,,,,,,,,,,, -65500,3.1658568,2.6815293,,,,,,,,,,,,,, -65600,3.5233455,2.6580336,,,,,,,,,,,,,, -65700,3.0761468,2.5817628,,,,,,,,,,,,,, -65800,2.9927692,2.7043724,,,,,,,,,,,,,, -65900,3.4180834,2.6861978,,,,,,,,,,,,,, -66000,3.3604417,2.6731753,,,,,,,,,,,,,, -66100,3.2625926,2.6552076,,,,,,,,,,,,,, -66200,3.4059227,2.6170962,,,,,,,,,,,,,, -66300,3.2338736,2.6731741,,,,,,,,,,,,,, -66400,3.4620671,2.7413273,,,,,,,,,,,,,, -66500,3.1849015,2.7106643,,,,,,,,,,,,,, -66600,3.117788,2.6598496,,,,,,,,,,,,,, -66700,3.5499969,2.6865506,,,,,,,,,,,,,, -66753,,,0.6834542155265808,1.3994112014770508,0.635159969329834,1.6157879829406738,50000.0,0.5105000138282776,2.2873687744140625,10000.0,22513.46359586716,23471.462081432343,22513.46359586716,954.4305288791656,1.5082590579986572,0.0 -66800,2.9101675,2.6955786,,,,,,,,,,,,,, -66900,3.1345587,2.6985714,,,,,,,,,,,,,, -67000,3.647182,2.704792,,,,,,,,,,,,,, -67100,3.1340525,2.687077,,,,,,,,,,,,,, -67200,3.6778858,2.6914978,,,,,,,,,,,,,, -67300,3.0395308,2.6541834,,,,,,,,,,,,,, -67400,3.0426495,2.6568844,,,,,,,,,,,,,, -67500,3.3740137,2.6365905,,,,,,,,,,,,,, -67600,3.215955,2.6729066,,,,,,,,,,,,,, -67700,3.3198826,2.6532152,,,,,,,,,,,,,, -67800,3.2918434,2.696945,,,,,,,,,,,,,, -67900,3.6058826,2.5984578,,,,,,,,,,,,,, -68000,3.167934,2.7744272,,,,,,,,,,,,,, -68100,3.1729522,2.687634,,,,,,,,,,,,,, -68200,3.2601392,2.654445,,,,,,,,,,,,,, -68272,,,0.6808832883834839,1.3870117664337158,0.6350199580192566,1.601508378982544,50000.0,0.5113000273704529,2.254912614822388,10000.0,23023.515419483185,23998.77006816864,23023.515419483185,971.5997366905212,1.5481698513031006,0.0 -68300,3.2294693,2.6653688,,,,,,,,,,,,,, -68400,3.3639526,2.6744127,,,,,,,,,,,,,, -68500,3.3894994,2.7684267,,,,,,,,,,,,,, -68600,3.0799246,2.6164045,,,,,,,,,,,,,, -68700,3.5219138,2.6394734,,,,,,,,,,,,,, -68800,3.4987464,2.6535454,,,,,,,,,,,,,, -68900,3.1906857,2.7049677,,,,,,,,,,,,,, -69000,2.96437,2.6685414,,,,,,,,,,,,,, -69100,3.8255444,2.697846,,,,,,,,,,,,,, -69200,3.8333786,2.669737,,,,,,,,,,,,,, -69300,2.931448,2.6657925,,,,,,,,,,,,,, -69400,3.0243628,2.6478744,,,,,,,,,,,,,, -69500,3.503728,2.625247,,,,,,,,,,,,,, -69600,3.339135,2.5673053,,,,,,,,,,,,,, -69700,3.4862678,2.6290312,,,,,,,,,,,,,, -69791,,,0.6853276491165161,1.3273032903671265,0.6344199776649475,1.5748940706253052,50000.0,0.5146000385284424,2.2334415912628174,10000.0,23533.512261867523,24526.164582252502,23533.512261867523,988.909274339676,1.5890882015228271,0.0 -69800,3.7517145,2.6728988,,,,,,,,,,,,,, -69900,3.0789645,2.6291397,,,,,,,,,,,,,, -70000,3.1696432,2.6717398,,,,,,,,,,,,,, -70100,3.3445356,2.731591,,,,,,,,,,,,,, -70200,3.4812675,2.6791334,,,,,,,,,,,,,, -70300,3.1542783,2.5509808,,,,,,,,,,,,,, -70400,3.2314322,2.5495067,,,,,,,,,,,,,, -70500,3.0423515,2.6824667,,,,,,,,,,,,,, -70600,3.4029846,2.6532557,,,,,,,,,,,,,, -70700,3.589729,2.7194273,,,,,,,,,,,,,, -70800,3.0763516,2.675208,,,,,,,,,,,,,, -70900,3.099114,2.6813745,,,,,,,,,,,,,, -71000,3.5503592,2.573063,,,,,,,,,,,,,, -71100,3.2484705,2.5875528,,,,,,,,,,,,,, -71200,3.8366807,2.7943504,,,,,,,,,,,,,, -71300,2.9179308,2.6362526,,,,,,,,,,,,,, -71310,,,0.7115951776504517,1.2286536693572998,0.6373599767684937,1.564759612083435,50000.0,0.5209000110626221,2.210066556930542,10000.0,24043.516547441483,25053.56804537773,24043.516547441483,1006.218918800354,1.6315371990203855,0.0 -71400,3.307387,2.6284337,,,,,,,,,,,,,, -71500,3.1961327,2.5796082,,,,,,,,,,,,,, -71600,3.1019166,2.6647613,,,,,,,,,,,,,, -71700,3.8021336,2.7194548,,,,,,,,,,,,,, -71800,3.4966052,2.657986,,,,,,,,,,,,,, -71900,3.210798,2.704482,,,,,,,,,,,,,, -72000,3.226344,2.5979428,,,,,,,,,,,,,, -72100,3.1365783,2.6833649,,,,,,,,,,,,,, -72200,3.31222,2.6760123,,,,,,,,,,,,,, -72300,3.2833352,2.67504,,,,,,,,,,,,,, -72400,3.9192038,2.629921,,,,,,,,,,,,,, -72500,3.3552926,2.6077502,,,,,,,,,,,,,, -72600,3.405842,2.7101476,,,,,,,,,,,,,, -72700,3.4115791,2.754103,,,,,,,,,,,,,, -72800,3.4222732,2.612867,,,,,,,,,,,,,, -72830,,,0.6828762888908386,1.3698737621307373,0.6262800097465515,1.6336623430252075,50000.0,0.509600043296814,2.260319232940674,10000.0,24553.690885543823,25580.96841239929,24553.690885543823,1023.3500711917876,1.678883075714111,0.0 -72900,3.2124915,2.7151885,,,,,,,,,,,,,, -73000,3.611786,2.6301715,,,,,,,,,,,,,, -73100,3.4298666,2.615415,,,,,,,,,,,,,, -73200,3.3000073,2.5843291,,,,,,,,,,,,,, -73300,3.367941,2.6483302,,,,,,,,,,,,,, -73400,3.1355548,2.4568467,,,,,,,,,,,,,, -73500,3.169355,2.532209,,,,,,,,,,,,,, -73600,3.4830623,2.6003156,,,,,,,,,,,,,, -73700,3.4314866,2.7064788,,,,,,,,,,,,,, -73800,3.208462,2.62958,,,,,,,,,,,,,, -73900,3.118713,2.6427052,,,,,,,,,,,,,, -74000,2.928728,2.5458453,,,,,,,,,,,,,, -74100,3.216616,2.5877216,,,,,,,,,,,,,, -74200,3.1711152,2.6249602,,,,,,,,,,,,,, -74300,3.432116,2.6741593,,,,,,,,,,,,,, -74349,,,0.6959502696990967,1.3203473091125488,0.6404399871826172,1.5740586519241333,50000.0,0.5154000520706177,2.2236592769622803,10000.0,25063.78478884697,26108.48323750496,25063.78478884697,1040.6843490600586,1.718044996261597,0.0 -74400,3.9054554,2.5973058,,,,,,,,,,,,,, -74500,3.2745066,2.6687632,,,,,,,,,,,,,, -74600,3.5451899,2.6577826,,,,,,,,,,,,,, -74700,3.4691799,2.6096303,,,,,,,,,,,,,, -74800,3.5463057,2.578267,,,,,,,,,,,,,, -74900,3.5627477,2.6801078,,,,,,,,,,,,,, -75000,3.5488868,2.5949142,,,,,,,,,,,,,, -75100,3.4997342,2.6261945,,,,,,,,,,,,,, -75200,3.013261,2.5664942,,,,,,,,,,,,,, -75300,3.4282851,2.6442106,,,,,,,,,,,,,, -75400,3.124086,2.5427442,,,,,,,,,,,,,, -75500,3.6598725,2.634669,,,,,,,,,,,,,, -75600,3.3971574,2.5893533,,,,,,,,,,,,,, -75700,3.368744,2.6651182,,,,,,,,,,,,,, -75800,3.1272645,2.580422,,,,,,,,,,,,,, -75868,,,0.6812818646430969,1.3611277341842651,0.6279599666595459,1.6080148220062256,50000.0,0.5042999982833862,2.291550636291504,10000.0,25573.829756498337,26636.57822918892,25573.829756498337,1058.6442294120789,1.7606401443481443,0.0 -75900,3.2356477,2.586609,,,,,,,,,,,,,, -76000,3.2451203,2.5476255,,,,,,,,,,,,,, -76100,3.5382767,2.6801374,,,,,,,,,,,,,, -76200,3.608812,2.703284,,,,,,,,,,,,,, -76300,3.1476293,2.6027904,,,,,,,,,,,,,, -76400,3.6437006,2.562685,,,,,,,,,,,,,, -76500,3.664266,2.7159011,,,,,,,,,,,,,, -76600,3.832306,2.61534,,,,,,,,,,,,,, -76700,2.9524043,2.6281419,,,,,,,,,,,,,, -76800,3.8366861,2.5852983,,,,,,,,,,,,,, -76900,3.6754227,2.7143044,,,,,,,,,,,,,, -77000,3.2875667,2.6117585,,,,,,,,,,,,,, -77100,3.7108314,2.6013536,,,,,,,,,,,,,, -77200,3.1075847,2.6219592,,,,,,,,,,,,,, -77300,3.3704615,2.665635,,,,,,,,,,,,,, -77387,,,0.6804049611091614,1.39091157913208,0.6298800110816956,1.6200053691864014,50000.0,0.5112000107765198,2.2878735065460205,10000.0,26084.01960873604,27163.98302912712,26084.01960873604,1075.7697627544403,1.8028457164764404,0.0 -77400,3.1851447,2.6036832,,,,,,,,,,,,,, -77500,3.5652044,2.6928296,,,,,,,,,,,,,, -77600,3.2292793,2.600222,,,,,,,,,,,,,, -77700,3.289493,2.6426516,,,,,,,,,,,,,, -77800,3.2126534,2.6486034,,,,,,,,,,,,,, -77900,3.5386868,2.6657314,,,,,,,,,,,,,, -78000,3.1473343,2.6006644,,,,,,,,,,,,,, -78100,3.297023,2.5933013,,,,,,,,,,,,,, -78200,3.7838743,2.5320191,,,,,,,,,,,,,, -78300,3.4811382,2.6417167,,,,,,,,,,,,,, -78400,4.127418,2.5756469,,,,,,,,,,,,,, -78500,4.0628076,2.6257868,,,,,,,,,,,,,, -78600,3.5783668,2.57938,,,,,,,,,,,,,, -78700,3.629978,2.6658876,,,,,,,,,,,,,, -78800,3.3552291,2.6513958,,,,,,,,,,,,,, -78900,4.1260924,2.6003466,,,,,,,,,,,,,, -78903,,,0.7225167155265808,1.2051185369491575,0.6467399597167969,1.5430620908737185,50000.0,0.5195000171661377,2.204761743545532,10000.0,26593.105378627777,27691.36990666389,26593.105378627777,1093.0664336681366,2.7603235244750977,0.0 -79000,3.4578714,2.6435266,,,,,,,,,,,,,, -79100,3.7602482,2.5946598,,,,,,,,,,,,,, -79200,3.2924922,2.671077,,,,,,,,,,,,,, -79300,3.4654288,2.615162,,,,,,,,,,,,,, -79400,3.2679918,2.6648917,,,,,,,,,,,,,, -79500,3.6579998,2.7110133,,,,,,,,,,,,,, -79600,3.6300642,2.617704,,,,,,,,,,,,,, -79700,3.429107,2.4778953,,,,,,,,,,,,,, -79800,3.2435243,2.645797,,,,,,,,,,,,,, -79900,3.18371,2.4867117,,,,,,,,,,,,,, -80000,3.4109354,2.6705818,,,,,,,,,,,,,, -80100,3.643121,2.6246173,,,,,,,,,,,,,, -80200,3.6993377,2.6506019,,,,,,,,,,,,,, -80300,3.2520964,2.5906043,,,,,,,,,,,,,, -80400,3.5286763,2.5990775,,,,,,,,,,,,,, -80423,,,0.6974050998687744,1.309516429901123,0.6312400102615356,1.6141302585601809,50000.0,0.5088000297546387,2.290923833847046,10000.0,27103.18165063858,28218.77651834488,27103.18165063858,1110.3081967830658,2.8014683723449707,0.0 -80500,3.7710443,2.717369,,,,,,,,,,,,,, -80600,3.206258,2.589959,,,,,,,,,,,,,, -80700,3.4155352,2.592369,,,,,,,,,,,,,, -80800,4.14236,2.6919284,,,,,,,,,,,,,, -80900,3.681954,2.614753,,,,,,,,,,,,,, -81000,3.4908912,2.5095224,,,,,,,,,,,,,, -81100,3.116677,2.64479,,,,,,,,,,,,,, -81200,3.7961164,2.5438795,,,,,,,,,,,,,, -81300,3.2971172,2.6459475,,,,,,,,,,,,,, -81400,4.410897,2.5726364,,,,,,,,,,,,,, -81500,3.192226,2.5624945,,,,,,,,,,,,,, -81600,3.3580613,2.5212655,,,,,,,,,,,,,, -81700,3.571458,2.5362012,,,,,,,,,,,,,, -81800,3.719464,2.662195,,,,,,,,,,,,,, -81900,3.5739806,2.5659754,,,,,,,,,,,,,, -81943,,,0.7071707248687744,1.2736690044403076,0.6467799544334412,1.539712905883789,50000.0,0.5279000401496887,2.181778907775879,10000.0,27613.21877503395,28746.222467899323,27613.21877503395,1127.628051996231,2.843165397644043,0.0 -82000,3.3400497,2.608196,,,,,,,,,,,,,, -82100,3.7261488,2.5559952,,,,,,,,,,,,,, -82200,3.5466273,2.5670123,,,,,,,,,,,,,, -82300,3.2051105,2.6501439,,,,,,,,,,,,,, -82400,3.6934776,2.7137742,,,,,,,,,,,,,, -82500,3.348707,2.5798616,,,,,,,,,,,,,, -82600,3.5584192,2.7057345,,,,,,,,,,,,,, -82700,3.791451,2.6367912,,,,,,,,,,,,,, -82800,3.85537,2.6208215,,,,,,,,,,,,,, -82900,3.8584974,2.5262852,,,,,,,,,,,,,, -83000,4.214519,2.559742,,,,,,,,,,,,,, -83100,3.5471175,2.5195296,,,,,,,,,,,,,, -83200,3.7406666,2.6503801,,,,,,,,,,,,,, -83300,3.4109013,2.565165,,,,,,,,,,,,,, -83400,3.2842796,2.6155148,,,,,,,,,,,,,, -83462,,,0.7016302347183228,1.2925941944122314,0.6476199626922607,1.5447707176208496,50000.0,0.5254000425338745,2.181201457977295,10000.0,28123.213754415512,29273.577413082123,28123.213754415512,1144.8795185089111,2.9040091037750244,0.0 -83500,3.3622832,2.6017992,,,,,,,,,,,,,, -83600,3.9520872,2.640228,,,,,,,,,,,,,, -83700,3.6515296,2.5424275,,,,,,,,,,,,,, -83800,4.077261,2.6677976,,,,,,,,,,,,,, -83900,3.9708114,2.624649,,,,,,,,,,,,,, -84000,3.626649,2.5448189,,,,,,,,,,,,,, -84100,3.3888216,2.5562422,,,,,,,,,,,,,, -84200,3.8934424,2.6098473,,,,,,,,,,,,,, -84300,3.588349,2.5545673,,,,,,,,,,,,,, -84400,3.3756375,2.619127,,,,,,,,,,,,,, -84500,3.4054668,2.6030707,,,,,,,,,,,,,, -84600,3.7248342,2.606024,,,,,,,,,,,,,, -84700,3.0906236,2.5825632,,,,,,,,,,,,,, -84800,3.2857032,2.6589437,,,,,,,,,,,,,, -84900,3.4563181,2.568029,,,,,,,,,,,,,, -84982,,,0.6917649507522583,1.3234705924987793,0.6398599743843079,1.566019058227539,50000.0,0.5200999975204468,2.230949878692627,10000.0,28633.28370141983,29800.936242103577,28633.28370141983,1162.0792186260223,2.945914030075073,0.0 -85000,3.9239206,2.6314878,,,,,,,,,,,,,, -85100,3.71255,2.542151,,,,,,,,,,,,,, -85200,3.5931647,2.5685356,,,,,,,,,,,,,, -85300,3.2430458,2.5462813,,,,,,,,,,,,,, -85400,3.3922687,2.5767968,,,,,,,,,,,,,, -85500,3.4934168,2.6415749,,,,,,,,,,,,,, -85600,3.5675209,2.534369,,,,,,,,,,,,,, -85700,3.899515,2.6067667,,,,,,,,,,,,,, -85800,3.723112,2.5579295,,,,,,,,,,,,,, -85900,3.4457006,2.5479012,,,,,,,,,,,,,, -86000,3.3022778,2.5371966,,,,,,,,,,,,,, -86100,3.4258828,2.535953,,,,,,,,,,,,,, -86200,3.9672213,2.553029,,,,,,,,,,,,,, -86300,3.707299,2.595179,,,,,,,,,,,,,, -86400,3.4468107,2.5477245,,,,,,,,,,,,,, -86500,4.314533,2.6982152,,,,,,,,,,,,,, -86501,,,0.7071906924247742,1.2525839805603027,0.6568399667739868,1.488780856132507,50000.0,0.5342000126838684,2.135103940963745,10000.0,29143.66135954857,30328.49675798416,29143.66135954857,1179.1734466552734,2.9869160652160645,0.0 -86600,3.923064,2.5504045,,,,,,,,,,,,,, -86700,3.8352013,2.5528512,,,,,,,,,,,,,, -86800,4.012416,2.6378427,,,,,,,,,,,,,, -86900,3.5095792,2.5595496,,,,,,,,,,,,,, -87000,3.0261133,2.638667,,,,,,,,,,,,,, -87100,3.620567,2.565269,,,,,,,,,,,,,, -87200,3.8127172,2.5705664,,,,,,,,,,,,,, -87300,3.8034196,2.6367283,,,,,,,,,,,,,, -87400,3.2588046,2.60852,,,,,,,,,,,,,, -87500,3.618086,2.6016827,,,,,,,,,,,,,, -87600,3.6888134,2.607814,,,,,,,,,,,,,, -87700,3.4586356,2.5182338,,,,,,,,,,,,,, -87800,3.463712,2.5194514,,,,,,,,,,,,,, -87900,3.911951,2.6152513,,,,,,,,,,,,,, -88000,3.6704,2.5432873,,,,,,,,,,,,,, -88020,,,0.7352319955825806,1.1445754766464231,0.649459958076477,1.5333545207977295,50000.0,0.5270000100135803,2.18676495552063,10000.0,29653.78978037834,30855.992782831192,29653.78978037834,1196.4498274326324,3.030723810195923,0.0 -88100,3.7940972,2.4833705,,,,,,,,,,,,,, -88200,3.6314852,2.599252,,,,,,,,,,,,,, -88300,3.3212383,2.6088061,,,,,,,,,,,,,, -88400,4.1303277,2.4916115,,,,,,,,,,,,,, -88500,3.5610318,2.5623538,,,,,,,,,,,,,, -88600,4.5074334,2.545831,,,,,,,,,,,,,, -88700,3.9864018,2.5981987,,,,,,,,,,,,,, -88800,3.8212051,2.4216042,,,,,,,,,,,,,, -88900,3.8067913,2.498829,,,,,,,,,,,,,, -89000,4.1081777,2.5538363,,,,,,,,,,,,,, -89100,4.0580835,2.5606508,,,,,,,,,,,,,, -89200,3.1928663,2.545319,,,,,,,,,,,,,, -89300,3.2617517,2.49722,,,,,,,,,,,,,, -89400,3.4976766,2.4546573,,,,,,,,,,,,,, -89500,3.2974555,2.4161406,,,,,,,,,,,,,, -89541,,,0.7146444320678711,1.2600167989730835,0.6505399942398071,1.539999008178711,50000.0,0.523900032043457,2.197510004043579,10000.0,30164.02828192711,31383.692828655243,30164.02828192711,1213.8182995319366,3.075812101364136,0.0 -89600,3.7402809,2.6342673,,,,,,,,,,,,,, -89700,3.455018,2.6481395,,,,,,,,,,,,,, -89800,3.7787986,2.6184683,,,,,,,,,,,,,, -89900,3.42629,2.5194018,,,,,,,,,,,,,, -90000,3.6497445,2.5354424,,,,,,,,,,,,,, -90100,3.5982647,2.5573869,,,,,,,,,,,,,, -90200,3.722896,2.5365033,,,,,,,,,,,,,, -90300,3.524422,2.6152697,,,,,,,,,,,,,, -90400,3.6785138,2.604533,,,,,,,,,,,,,, -90500,3.8676283,2.623309,,,,,,,,,,,,,, -90600,4.0632043,2.5456834,,,,,,,,,,,,,, -90700,3.3468838,2.4422336,,,,,,,,,,,,,, -90800,3.6471877,2.5496297,,,,,,,,,,,,,, -90900,3.988566,2.495093,,,,,,,,,,,,,, -91000,3.9491203,2.554894,,,,,,,,,,,,,, -91061,,,0.7093231678009033,1.2750592231750488,0.6477000117301941,1.54591703414917,50000.0,0.5308000445365906,2.186326503753662,10000.0,30674.164899349213,31911.384786367416,30674.164899349213,1231.279939413071,3.1218960285186768,0.0 -91100,3.551661,2.5791192,,,,,,,,,,,,,, -91200,3.4917781,2.6391108,,,,,,,,,,,,,, -91300,3.7045026,2.540405,,,,,,,,,,,,,, -91400,3.744286,2.5004728,,,,,,,,,,,,,, -91500,3.753117,2.6218839,,,,,,,,,,,,,, -91600,4.150305,2.596522,,,,,,,,,,,,,, -91700,3.074454,2.5038443,,,,,,,,,,,,,, -91800,3.8465033,2.575709,,,,,,,,,,,,,, -91900,4.113621,2.5855742,,,,,,,,,,,,,, -92000,4.0867133,2.5628357,,,,,,,,,,,,,, -92100,4.540432,2.639235,,,,,,,,,,,,,, -92200,3.905249,2.5812583,,,,,,,,,,,,,, -92300,3.9582534,2.509418,,,,,,,,,,,,,, -92400,3.3530498,2.5452604,,,,,,,,,,,,,, -92500,3.851612,2.5513716,,,,,,,,,,,,,, -92581,,,0.720723032951355,1.2018429040908811,0.6581999659538269,1.4694814682006836,50000.0,0.5351000428199768,2.1427531242370605,10000.0,31184.35217189789,32438.77952504158,31184.35217189789,1248.3998510837555,3.161928653717041,0.0 -92600,3.519851,2.4776978,,,,,,,,,,,,,, -92700,3.8768833,2.5395663,,,,,,,,,,,,,, -92800,3.9824412,2.5445938,,,,,,,,,,,,,, -92900,3.753307,2.5152593,,,,,,,,,,,,,, -93000,3.4849396,2.50014,,,,,,,,,,,,,, -93100,3.601769,2.5443084,,,,,,,,,,,,,, -93200,3.3530903,2.469831,,,,,,,,,,,,,, -93300,4.1121125,2.5526366,,,,,,,,,,,,,, -93400,3.688838,2.4466076,,,,,,,,,,,,,, -93500,4.0129766,2.524873,,,,,,,,,,,,,, -93600,3.384864,2.5151129,,,,,,,,,,,,,, -93700,3.8899848,2.5925112,,,,,,,,,,,,,, -93800,3.6468043,2.532473,,,,,,,,,,,,,, -93900,3.8679724,2.4562907,,,,,,,,,,,,,, -94000,3.7344277,2.5401905,,,,,,,,,,,,,, -94100,,,0.711355984210968,1.227055311203003,0.6619600057601929,1.455871343612671,50000.0,0.5314000248908997,2.120694875717163,10000.0,31694.415167331696,32966.347737550735,31694.415167331696,1265.8118696212769,3.207626342773437,0.0 -94100,3.9945536,2.581709,,,,,,,,,,,,,, -94200,4.0660973,2.5362122,,,,,,,,,,,,,, -94300,3.574709,2.5327945,,,,,,,,,,,,,, -94400,4.293809,2.5483932,,,,,,,,,,,,,, -94500,3.7354467,2.5128176,,,,,,,,,,,,,, -94600,3.857862,2.5319166,,,,,,,,,,,,,, -94700,3.8120506,2.503473,,,,,,,,,,,,,, -94800,3.6239839,2.4737406,,,,,,,,,,,,,, -94900,4.0399733,2.6259944,,,,,,,,,,,,,, -95000,3.618215,2.5434031,,,,,,,,,,,,,, -95100,3.6916409,2.5372815,,,,,,,,,,,,,, -95200,4.193561,2.5830886,,,,,,,,,,,,,, -95300,3.9714406,2.498018,,,,,,,,,,,,,, -95400,3.7985814,2.5920634,,,,,,,,,,,,,, -95500,3.709389,2.5681212,,,,,,,,,,,,,, -95600,3.9667072,2.5676916,,,,,,,,,,,,,, -95620,,,0.7209023833274841,1.197264552116394,0.6657999753952026,1.4429244995117188,50000.0,0.5458000302314758,2.099045753479004,10000.0,32204.58331465721,33494.21009898186,32204.58331465721,1283.4004225730896,3.2659912109375,0.0 -95700,3.9194262,2.568372,,,,,,,,,,,,,, -95800,3.7926235,2.4451993,,,,,,,,,,,,,, -95900,3.3438723,2.578673,,,,,,,,,,,,,, -96000,4.19305,2.5038855,,,,,,,,,,,,,, -96100,3.267881,2.521634,,,,,,,,,,,,,, -96200,3.850525,2.5312748,,,,,,,,,,,,,, -96300,4.0643253,2.555669,,,,,,,,,,,,,, -96400,3.838285,2.420397,,,,,,,,,,,,,, -96500,5.3538303,2.5616,,,,,,,,,,,,,, -96600,3.777653,2.5482457,,,,,,,,,,,,,, -96700,3.9907613,2.4933507,,,,,,,,,,,,,, -96800,3.698634,2.4640908,,,,,,,,,,,,,, -96900,3.956921,2.5119941,,,,,,,,,,,,,, -97000,3.7455168,2.5484374,,,,,,,,,,,,,, -97100,3.8032262,2.5479715,,,,,,,,,,,,,, -97141,,,0.7374441623687744,1.1343486309051514,0.6576399803161621,1.4939663410186768,50000.0,0.5304000377655029,2.161627769470215,10000.0,32714.638087511063,34021.7456908226,32714.638087511063,1300.787608385086,3.3119215965271,0.0 -97200,4.446449,2.5246012,,,,,,,,,,,,,, -97300,3.720823,2.539235,,,,,,,,,,,,,, -97400,3.8476682,2.4727848,,,,,,,,,,,,,, -97500,3.6431913,2.463958,,,,,,,,,,,,,, -97600,4.518397,2.5742304,,,,,,,,,,,,,, -97700,3.699897,2.5216935,,,,,,,,,,,,,, -97800,3.6192708,2.5099633,,,,,,,,,,,,,, -97900,3.632117,2.4637413,,,,,,,,,,,,,, -98000,3.80545,2.464478,,,,,,,,,,,,,, -98100,3.7955027,2.4650388,,,,,,,,,,,,,, -98200,4.0275474,2.5031717,,,,,,,,,,,,,, -98300,4.5350375,2.5448468,,,,,,,,,,,,,, -98400,4.401247,2.4741704,,,,,,,,,,,,,, -98500,4.0556374,2.4907565,,,,,,,,,,,,,, -98600,3.4091246,2.4664917,,,,,,,,,,,,,, -98661,,,0.7361487150192261,1.1169815063476562,0.6668399572372437,1.4288816452026367,50000.0,0.5508000254631042,2.060675621032715,10000.0,33224.680389881134,34549.06488656998,33224.680389881134,1317.972489118576,3.3565704822540283,0.0 -98700,3.6547074,2.4164774,,,,,,,,,,,,,, -98800,4.736819,2.5002227,,,,,,,,,,,,,, -98900,3.570105,2.483798,,,,,,,,,,,,,, -99000,4.3684134,2.5234141,,,,,,,,,,,,,, -99100,3.5279884,2.5091114,,,,,,,,,,,,,, -99200,3.6652987,2.533033,,,,,,,,,,,,,, -99300,3.733586,2.4857788,,,,,,,,,,,,,, -99400,4.3810754,2.457613,,,,,,,,,,,,,, -99500,3.908003,2.6340013,,,,,,,,,,,,,, -99600,3.9638557,2.5382323,,,,,,,,,,,,,, -99700,3.7092676,2.5630848,,,,,,,,,,,,,, -99800,3.618607,2.349113,,,,,,,,,,,,,, -99900,4.1436787,2.4535444,,,,,,,,,,,,,, -100000,3.8697138,2.4633362,,,,,,,,,,,,,, -100100,3.8645363,2.5283113,,,,,,,,,,,,,, -100181,,,0.7201650142669678,1.1991384029388428,0.6631199717521667,1.4586812257766724,50000.0,0.5432000160217285,2.1067442893981934,10000.0,33734.73139023781,35076.42337989807,33734.73139023781,1335.168357372284,3.4204201698303223,0.0 -100200,3.7888558,2.4819274,,,,,,,,,,,,,, -100300,3.7328618,2.5308676,,,,,,,,,,,,,, -100400,3.625893,2.5507226,,,,,,,,,,,,,, -100500,3.7913263,2.5298016,,,,,,,,,,,,,, -100600,4.4614058,2.3911374,,,,,,,,,,,,,, -100700,4.105508,2.4886723,,,,,,,,,,,,,, -100800,3.6582072,2.4400966,,,,,,,,,,,,,, -100900,4.1054645,2.5200772,,,,,,,,,,,,,, -101000,4.127303,2.437655,,,,,,,,,,,,,, -101100,3.9032834,2.452229,,,,,,,,,,,,,, -101200,3.9996257,2.3826213,,,,,,,,,,,,,, -101300,4.062743,2.494574,,,,,,,,,,,,,, -101400,3.8656566,2.4636354,,,,,,,,,,,,,, -101500,3.632176,2.4998944,,,,,,,,,,,,,, -101600,4.22472,2.43177,,,,,,,,,,,,,, -101700,,,0.7339365482330322,1.128197193145752,0.6735199689865112,1.408560276031494,50000.0,0.5473999977111816,2.062584638595581,10000.0,34244.87643456459,35603.890127658844,34244.87643456459,1352.396959066391,3.4652857780456543,0.0 -101700,3.7826607,2.4881198,,,,,,,,,,,,,, -101800,4.002224,2.4907246,,,,,,,,,,,,,, -101900,4.201103,2.5749595,,,,,,,,,,,,,, -102000,4.0919533,2.4783177,,,,,,,,,,,,,, -102100,3.9796033,2.5209365,,,,,,,,,,,,,, -102200,3.910566,2.4989078,,,,,,,,,,,,,, -102300,3.8366537,2.5280745,,,,,,,,,,,,,, -102400,3.538597,2.414731,,,,,,,,,,,,,, -102500,4.1081004,2.5043395,,,,,,,,,,,,,, -102600,3.827143,2.4784665,,,,,,,,,,,,,, -102700,4.2900724,2.4367046,,,,,,,,,,,,,, -102800,5.4668307,2.5198026,,,,,,,,,,,,,, -102900,4.1505394,2.4710732,,,,,,,,,,,,,, -103000,4.068418,2.5637884,,,,,,,,,,,,,, -103100,4.0966268,2.534036,,,,,,,,,,,,,, -103200,3.9762366,2.5164208,,,,,,,,,,,,,, -103220,,,0.7251275181770325,1.1850506067276,0.668179988861084,1.444494366645813,50000.0,0.539900004863739,2.114496946334839,10000.0,34754.832845926285,36131.21515202522,34754.832845926285,1369.6633830070496,3.5199079513549805,0.0 -103300,3.7597933,2.441399,,,,,,,,,,,,,, -103400,3.801249,2.5022879,,,,,,,,,,,,,, -103500,3.8514905,2.3902025,,,,,,,,,,,,,, -103600,4.720424,2.4432302,,,,,,,,,,,,,, -103700,4.369379,2.3937342,,,,,,,,,,,,,, -103800,4.065407,2.439421,,,,,,,,,,,,,, -103900,4.1959453,2.463069,,,,,,,,,,,,,, -104000,4.214337,2.4855857,,,,,,,,,,,,,, -104100,4.280576,2.5214548,,,,,,,,,,,,,, -104200,4.065221,2.3593903,,,,,,,,,,,,,, -104300,4.3628216,2.571648,,,,,,,,,,,,,, -104400,4.0179505,2.5026379,,,,,,,,,,,,,, -104500,3.6360319,2.4210773,,,,,,,,,,,,,, -104600,4.378494,2.4068918,,,,,,,,,,,,,, -104700,4.223994,2.4345775,,,,,,,,,,,,,, -104740,,,0.7435427308082581,1.0935951471328735,0.6726999878883362,1.40591561794281,50000.0,0.5522000193595886,2.052886724472046,10000.0,35265.055674791336,36658.89939570427,35265.055674791336,1387.0266127586365,3.5705230236053467,0.0 -104800,4.2951717,2.415825,,,,,,,,,,,,,, -104900,4.133156,2.4341488,,,,,,,,,,,,,, -105000,3.9008439,2.4244447,,,,,,,,,,,,,, -105100,4.348023,2.5737455,,,,,,,,,,,,,, -105200,4.1640024,2.4008503,,,,,,,,,,,,,, -105300,4.059875,2.5079727,,,,,,,,,,,,,, -105400,3.7000015,2.4943233,,,,,,,,,,,,,, -105500,3.7729802,2.4522471,,,,,,,,,,,,,, -105600,4.4396906,2.4332492,,,,,,,,,,,,,, -105700,3.5687184,2.4924364,,,,,,,,,,,,,, -105800,4.109091,2.41593,,,,,,,,,,,,,, -105900,3.7177672,2.4312954,,,,,,,,,,,,,, -106000,3.934297,2.4277532,,,,,,,,,,,,,, -106100,3.8956218,2.5090044,,,,,,,,,,,,,, -106200,4.5715218,2.4215825,,,,,,,,,,,,,, -106260,,,0.7564970850944519,1.039839744567871,0.6774199604988098,1.3940984010696411,50000.0,0.5509999990463257,2.0571823120117188,10000.0,35775.128962278366,37186.31377243996,35775.128962278366,1404.2709031105042,3.6195414066314697,0.0 -106300,4.361489,2.5022745,,,,,,,,,,,,,, -106400,3.993371,2.4387367,,,,,,,,,,,,,, -106500,4.056973,2.4777749,,,,,,,,,,,,,, -106600,4.005858,2.510572,,,,,,,,,,,,,, -106700,4.035038,2.3885336,,,,,,,,,,,,,, -106800,3.8931024,2.4025617,,,,,,,,,,,,,, -106900,4.0829062,2.5234306,,,,,,,,,,,,,, -107000,3.730756,2.356446,,,,,,,,,,,,,, -107100,3.8766875,2.3568506,,,,,,,,,,,,,, -107200,4.0643253,2.4233363,,,,,,,,,,,,,, -107300,4.5222664,2.51414,,,,,,,,,,,,,, -107400,4.8430066,2.491042,,,,,,,,,,,,,, -107500,4.2055106,2.4755123,,,,,,,,,,,,,, -107600,4.5955877,2.5461931,,,,,,,,,,,,,, -107700,4.3612833,2.4827964,,,,,,,,,,,,,, -107781,,,0.7531289458274841,1.063181757926941,0.6829400062561035,1.3601897954940796,50000.0,0.5616000294685364,2.0226991176605225,10000.0,36285.17335796356,37713.587841272354,36285.17335796356,1421.402559518814,3.66951847076416,0.0 -107800,4.441686,2.5492885,,,,,,,,,,,,,, -107900,4.5635853,2.4845283,,,,,,,,,,,,,, -108000,4.25235,2.4817119,,,,,,,,,,,,,, -108100,4.347503,2.5122774,,,,,,,,,,,,,, -108200,4.320313,2.4459944,,,,,,,,,,,,,, -108300,4.012342,2.3788178,,,,,,,,,,,,,, -108400,4.072155,2.473683,,,,,,,,,,,,,, -108500,4.0806985,2.4302413,,,,,,,,,,,,,, -108600,4.2384086,2.3864255,,,,,,,,,,,,,, -108700,4.033303,2.445557,,,,,,,,,,,,,, -108800,3.8776336,2.4616582,,,,,,,,,,,,,, -108900,4.352316,2.5443232,,,,,,,,,,,,,, -109000,3.8607626,2.4231634,,,,,,,,,,,,,, -109100,3.854872,2.407878,,,,,,,,,,,,,, -109200,4.370511,2.4270403,,,,,,,,,,,,,, -109300,4.094281,2.4338083,,,,,,,,,,,,,, -109301,,,0.7407127022743225,1.1150754690170288,0.6787599921226501,1.402061939239502,50000.0,0.5556000471115112,2.055934190750122,10000.0,36795.37970161438,38241.43623971939,36795.37970161438,1438.9501745700836,3.7160933017730713,0.0 -109400,4.3230424,2.3611722,,,,,,,,,,,,,, -109500,3.934649,2.4276125,,,,,,,,,,,,,, -109600,4.581505,2.4288144,,,,,,,,,,,,,, -109700,3.8817906,2.4653168,,,,,,,,,,,,,, -109800,4.723576,2.3980165,,,,,,,,,,,,,, -109900,4.8145623,2.516254,,,,,,,,,,,,,, -110000,4.5229454,2.5537965,,,,,,,,,,,,,, -110100,3.965079,2.408658,,,,,,,,,,,,,, -110200,4.276529,2.347596,,,,,,,,,,,,,, -110300,4.1135373,2.5083795,,,,,,,,,,,,,, -110400,4.3107247,2.3803768,,,,,,,,,,,,,, -110500,4.231165,2.3735423,,,,,,,,,,,,,, -110600,4.3713617,2.4941921,,,,,,,,,,,,,, -110700,4.0574346,2.335199,,,,,,,,,,,,,, -110800,4.321993,2.2790754,,,,,,,,,,,,,, -110820,,,0.746113657951355,1.0706056356430054,0.685479998588562,1.3558672666549685,50000.0,0.5631000399589539,2.0043318271636963,10000.0,37305.41358447075,38768.90533566475,37305.41358447075,1456.2903575897217,3.763372182846069,0.0 -110900,4.619423,2.4417312,,,,,,,,,,,,,, -111000,4.4726005,2.4210238,,,,,,,,,,,,,, -111100,4.2964206,2.3940625,,,,,,,,,,,,,, -111200,4.3435307,2.3529367,,,,,,,,,,,,,, -111300,4.000729,2.430921,,,,,,,,,,,,,, -111400,4.3761244,2.4859588,,,,,,,,,,,,,, -111500,4.6199274,2.4316754,,,,,,,,,,,,,, -111600,4.072749,2.3994389,,,,,,,,,,,,,, -111700,5.282599,2.4588115,,,,,,,,,,,,,, -111800,4.350537,2.4790132,,,,,,,,,,,,,, -111900,4.352868,2.3579965,,,,,,,,,,,,,, -112000,4.0835686,2.4228175,,,,,,,,,,,,,, -112100,4.259002,2.4731898,,,,,,,,,,,,,, -112200,4.268332,2.4514608,,,,,,,,,,,,,, -112300,4.4794927,2.377452,,,,,,,,,,,,,, -112340,,,0.7463727593421936,1.0985363721847534,0.6815400123596191,1.3779244422912598,50000.0,0.5627000331878662,2.016084909439087,10000.0,37815.575922966,39296.36831307411,37815.575922966,1473.4958517551422,3.8107833862304688,0.0 -112400,4.3821383,2.4543464,,,,,,,,,,,,,, -112500,3.8235583,2.3737974,,,,,,,,,,,,,, -112600,4.764964,2.4709346,,,,,,,,,,,,,, -112700,4.101064,2.4733405,,,,,,,,,,,,,, -112800,4.334201,2.4249618,,,,,,,,,,,,,, -112900,4.2628856,2.4515328,,,,,,,,,,,,,, -113000,3.6730232,2.4371152,,,,,,,,,,,,,, -113100,4.926871,2.4119225,,,,,,,,,,,,,, -113200,4.6212482,2.3850884,,,,,,,,,,,,,, -113300,4.434845,2.5121317,,,,,,,,,,,,,, -113400,4.028161,2.4502532,,,,,,,,,,,,,, -113500,3.8339484,2.4153843,,,,,,,,,,,,,, -113600,4.2688556,2.4332929,,,,,,,,,,,,,, -113700,4.098999,2.398339,,,,,,,,,,,,,, -113800,3.7617931,2.3641717,,,,,,,,,,,,,, -113860,,,0.7833625674247742,0.9273353815078736,0.6845200061798096,1.3640166521072388,50000.0,0.5616000294685364,2.0014493465423584,10000.0,38325.62386679649,39823.877307891846,38325.62386679649,1490.8633544445038,3.8569798469543457,0.0 -113900,3.9714267,2.424382,,,,,,,,,,,,,, -114000,4.2993426,2.4611588,,,,,,,,,,,,,, -114100,4.223847,2.439193,,,,,,,,,,,,,, -114200,4.196726,2.371871,,,,,,,,,,,,,, -114300,4.7868733,2.3773742,,,,,,,,,,,,,, -114400,4.0275316,2.4419396,,,,,,,,,,,,,, -114500,4.666713,2.4199102,,,,,,,,,,,,,, -114600,4.3819704,2.4290435,,,,,,,,,,,,,, -114700,4.622657,2.3689723,,,,,,,,,,,,,, -114800,4.67364,2.3798668,,,,,,,,,,,,,, -114900,4.461263,2.3998826,,,,,,,,,,,,,, -115000,3.815698,2.2675653,,,,,,,,,,,,,, -115100,4.524545,2.48829,,,,,,,,,,,,,, -115200,4.2501736,2.3975217,,,,,,,,,,,,,, -115300,4.0640574,2.3346815,,,,,,,,,,,,,, -115379,,,0.7703882455825806,1.0078136920928955,0.6881399750709534,1.3621941804885864,50000.0,0.5652000308036804,1.992992281913757,10000.0,38835.71640062332,40351.181077718735,38835.71640062332,1507.9782707691193,3.90530037879944,0.0 -115400,4.043722,2.3766632,,,,,,,,,,,,,, -115500,4.982244,2.3975704,,,,,,,,,,,,,, -115600,4.4587917,2.3632452,,,,,,,,,,,,,, -115700,4.2111845,2.4031382,,,,,,,,,,,,,, -115800,4.098812,2.3027577,,,,,,,,,,,,,, -115900,4.549873,2.4331129,,,,,,,,,,,,,, -116000,4.2901683,2.3200727,,,,,,,,,,,,,, -116100,4.83173,2.3386378,,,,,,,,,,,,,, -116200,3.9761667,2.349185,,,,,,,,,,,,,, -116300,4.1765723,2.3235404,,,,,,,,,,,,,, -116400,4.2666483,2.4330707,,,,,,,,,,,,,, -116500,4.051759,2.3488295,,,,,,,,,,,,,, -116600,4.5901294,2.3820786,,,,,,,,,,,,,, -116700,4.585827,2.3726814,,,,,,,,,,,,,, -116800,4.0305347,2.3744404,,,,,,,,,,,,,, -116900,,,0.7623764276504517,1.0332750082015991,0.694599986076355,1.3372550010681152,50000.0,0.5777000188827515,1.9591028690338133,10000.0,39345.8971452713,40878.56984305382,39345.8971452713,1525.088080406189,3.955566167831421,0.0 -116900,4.3354926,2.4011557,,,,,,,,,,,,,, -117000,5.4685936,2.3608088,,,,,,,,,,,,,, -117100,4.2760944,2.3815355,,,,,,,,,,,,,, -117200,4.3841295,2.3952057,,,,,,,,,,,,,, -117300,5.2549357,2.4188223,,,,,,,,,,,,,, -117400,4.9726596,2.4247189,,,,,,,,,,,,,, -117500,4.1911592,2.3257334,,,,,,,,,,,,,, -117600,4.19388,2.4466295,,,,,,,,,,,,,, -117700,4.7021155,2.3974965,,,,,,,,,,,,,, -117800,4.1313815,2.3542929,,,,,,,,,,,,,, -117900,4.3144035,2.3736875,,,,,,,,,,,,,, -118000,4.2384257,2.3822122,,,,,,,,,,,,,, -118100,4.647048,2.3983407,,,,,,,,,,,,,, -118200,4.3641553,2.447015,,,,,,,,,,,,,, -118300,4.2783017,2.3462615,,,,,,,,,,,,,, -118400,5.3341184,2.2878377,,,,,,,,,,,,,, -118420,,,0.7677973508834839,0.99867582321167,0.6958000063896179,1.3079572916030884,50000.0,0.567300021648407,1.957592248916626,10000.0,39856.08269357681,41406.01397848129,39856.08269357681,1542.2515261173248,4.002415657043457,0.0 -118500,4.326778,2.359215,,,,,,,,,,,,,, -118600,4.324835,2.3052106,,,,,,,,,,,,,, -118700,4.6535053,2.1903272,,,,,,,,,,,,,, -118800,4.4404535,2.325766,,,,,,,,,,,,,, -118900,4.296357,2.3673885,,,,,,,,,,,,,, -119000,4.392503,2.2562602,,,,,,,,,,,,,, -119100,4.2204566,2.3889775,,,,,,,,,,,,,, -119200,4.5279408,2.3995912,,,,,,,,,,,,,, -119300,5.3795433,2.3451276,,,,,,,,,,,,,, -119400,4.3689127,2.3707743,,,,,,,,,,,,,, -119500,4.1735854,2.3684978,,,,,,,,,,,,,, -119600,4.3740845,2.4693565,,,,,,,,,,,,,, -119700,4.1358786,2.3566477,,,,,,,,,,,,,, -119800,4.4560065,2.417316,,,,,,,,,,,,,, -119900,4.613804,2.3138874,,,,,,,,,,,,,, -119939,,,0.760184109210968,1.0345245599746704,0.6921399831771851,1.3318723440170288,50000.0,0.5699000358581543,1.961030006408692,10000.0,40366.02552604675,41933.3487534523,40366.02552604675,1559.540601491928,4.057379245758057,0.0 -120000,4.0602965,2.3907535,,,,,,,,,,,,,, -120100,4.4338984,2.3228755,,,,,,,,,,,,,, -120200,4.268498,2.352964,,,,,,,,,,,,,, -120300,4.519915,2.3491185,,,,,,,,,,,,,, -120400,4.461141,2.3537383,,,,,,,,,,,,,, -120500,4.537351,2.3234048,,,,,,,,,,,,,, -120600,4.6948867,2.3649268,,,,,,,,,,,,,, -120700,4.1645145,2.3919764,,,,,,,,,,,,,, -120800,4.532566,2.3550067,,,,,,,,,,,,,, -120900,4.9539084,2.4942737,,,,,,,,,,,,,, -121000,4.9711933,2.35326,,,,,,,,,,,,,, -121100,4.2771354,2.2926254,,,,,,,,,,,,,, -121200,4.7150044,2.324749,,,,,,,,,,,,,, -121300,4.265977,2.315932,,,,,,,,,,,,,, -121400,4.7101774,2.3612745,,,,,,,,,,,,,, -121459,,,0.7603236436843872,1.0235108137130735,0.6966999769210815,1.3048481941223145,50000.0,0.5731000304222107,1.9369956254959104,10000.0,40876.20910453797,42460.95462989807,40876.20910453797,1576.865121126175,4.10741400718689,0.0 -121500,4.8399734,2.3908165,,,,,,,,,,,,,, -121600,4.8241124,2.3426661,,,,,,,,,,,,,, -121700,4.6825633,2.3457017,,,,,,,,,,,,,, -121800,4.965373,2.4400394,,,,,,,,,,,,,, -121900,4.486461,2.362556,,,,,,,,,,,,,, -122000,4.477391,2.3045528,,,,,,,,,,,,,, -122100,4.6885686,2.3588226,,,,,,,,,,,,,, -122200,4.465575,2.3205872,,,,,,,,,,,,,, -122300,4.328023,2.35833,,,,,,,,,,,,,, -122400,4.556774,2.2771637,,,,,,,,,,,,,, -122500,4.7481894,2.382339,,,,,,,,,,,,,, -122600,4.9837103,2.339833,,,,,,,,,,,,,, -122700,4.776748,2.3819346,,,,,,,,,,,,,, -122800,4.916699,2.3315883,,,,,,,,,,,,,, -122900,4.9444304,2.3275886,,,,,,,,,,,,,, -122979,,,0.7931680083274841,0.8888539671897888,0.6993399858474731,1.2950785160064695,50000.0,0.5800999999046326,1.9291913509368896,10000.0,41386.174983263016,42988.33508872986,41386.174983263016,1594.183688402176,4.156118869781494,0.0 -123000,4.606162,2.4195945,,,,,,,,,,,,,, -123100,4.7832885,2.2241228,,,,,,,,,,,,,, -123200,4.321276,2.2188847,,,,,,,,,,,,,, -123300,5.0840034,2.3035815,,,,,,,,,,,,,, -123400,4.608668,2.28877,,,,,,,,,,,,,, -123500,4.772956,2.3702476,,,,,,,,,,,,,, -123600,4.4310374,2.2772837,,,,,,,,,,,,,, -123700,4.7922378,2.3561282,,,,,,,,,,,,,, -123800,4.822565,2.3889396,,,,,,,,,,,,,, -123900,4.6172314,2.2961912,,,,,,,,,,,,,, -124000,4.7665014,2.358868,,,,,,,,,,,,,, -124100,5.074289,2.3512797,,,,,,,,,,,,,, -124200,5.209639,2.3381035,,,,,,,,,,,,,, -124300,4.300107,2.3163316,,,,,,,,,,,,,, -124400,4.5318117,2.31145,,,,,,,,,,,,,, -124498,,,0.7772042155265808,0.9544240832328796,0.6949399709701538,1.311497926712036,50000.0,0.5751000046730042,1.938994646072388,10000.0,41896.13381576538,43515.6908364296,41896.13381576538,1611.4827795028689,4.206348896026611,0.0 -124500,4.6252522,2.345531,,,,,,,,,,,,,, -124600,4.3847866,2.2942994,,,,,,,,,,,,,, -124700,4.9692974,2.35385,,,,,,,,,,,,,, -124800,4.9131284,2.3739314,,,,,,,,,,,,,, -124900,5.2950306,2.2725768,,,,,,,,,,,,,, -125000,5.0698133,2.230204,,,,,,,,,,,,,, -125100,4.955683,2.391083,,,,,,,,,,,,,, -125200,4.8929977,2.3398504,,,,,,,,,,,,,, -125300,5.686682,2.3647158,,,,,,,,,,,,,, -125400,4.2524657,2.2745419,,,,,,,,,,,,,, -125500,4.9650126,2.4135184,,,,,,,,,,,,,, -125600,4.952874,2.3090801,,,,,,,,,,,,,, -125700,4.671826,2.3071394,,,,,,,,,,,,,, -125800,5.0860744,2.3023643,,,,,,,,,,,,,, -125900,5.0891013,2.3080647,,,,,,,,,,,,,, -126000,4.5843644,2.2319083,,,,,,,,,,,,,, -126016,,,0.7820471525192261,0.9358399510383606,0.7050999999046326,1.2812731266021729,50000.0,0.5834000110626221,1.9185949563980105,10000.0,42406.0859875679,44042.85533833504,42406.0859875679,1628.591871261597,4.261731863021851,0.0 -126100,4.8057346,2.2728474,,,,,,,,,,,,,, -126200,5.049076,2.3819244,,,,,,,,,,,,,, -126300,4.543946,2.26883,,,,,,,,,,,,,, -126400,4.487829,2.2578273,,,,,,,,,,,,,, -126500,5.035155,2.2991195,,,,,,,,,,,,,, -126600,4.611149,2.2813318,,,,,,,,,,,,,, -126700,5.1038404,2.2507076,,,,,,,,,,,,,, -126800,5.2019773,2.223702,,,,,,,,,,,,,, -126900,4.7151427,2.356111,,,,,,,,,,,,,, -127000,4.9282565,2.3088124,,,,,,,,,,,,,, -127100,4.735312,2.2763646,,,,,,,,,,,,,, -127200,4.7890944,2.3180194,,,,,,,,,,,,,, -127300,4.5527043,2.3070238,,,,,,,,,,,,,, -127400,5.0106506,2.244101,,,,,,,,,,,,,, -127500,4.732182,2.225045,,,,,,,,,,,,,, -127536,,,0.7782206535339355,0.9608622193336488,0.7054600119590759,1.2842296361923218,50000.0,0.579200029373169,1.9385350942611688,10000.0,42916.115965127945,44570.39530873299,42916.115965127945,1646.0038397312164,4.311935186386108,0.0 -127600,4.694719,2.2128112,,,,,,,,,,,,,, -127700,4.5457864,2.3047252,,,,,,,,,,,,,, -127800,4.991712,2.3404636,,,,,,,,,,,,,, -127900,4.7922034,2.2872415,,,,,,,,,,,,,, -128000,5.033386,2.2871127,,,,,,,,,,,,,, -128100,4.806264,2.273433,,,,,,,,,,,,,, -128200,4.9845467,2.3186135,,,,,,,,,,,,,, -128300,4.920137,2.2873797,,,,,,,,,,,,,, -128400,5.5672636,2.2741437,,,,,,,,,,,,,, -128500,4.8181906,2.1710718,,,,,,,,,,,,,, -128600,4.839081,2.3107572,,,,,,,,,,,,,, -128700,4.7439475,2.266403,,,,,,,,,,,,,, -128800,5.146401,2.2764785,,,,,,,,,,,,,, -128900,5.012849,2.310279,,,,,,,,,,,,,, -129000,4.635491,2.2077837,,,,,,,,,,,,,, -129054,,,0.7837212681770325,0.9352360963821412,0.712179958820343,1.254599690437317,50000.0,0.5873000025749207,1.8815577030181885,10000.0,43425.82639288902,45097.83585691452,43425.82639288902,1663.313729763031,4.684342861175537,0.0 -129100,5.521641,2.247207,,,,,,,,,,,,,, -129200,5.6794243,2.2926462,,,,,,,,,,,,,, -129300,5.2235923,2.3624277,,,,,,,,,,,,,, -129400,4.936046,2.221695,,,,,,,,,,,,,, -129500,5.110801,2.2641702,,,,,,,,,,,,,, -129600,5.448273,2.204492,,,,,,,,,,,,,, -129700,4.6824217,2.2629304,,,,,,,,,,,,,, -129800,5.040794,2.30348,,,,,,,,,,,,,, -129900,5.4470177,2.2629912,,,,,,,,,,,,,, -130000,6.0060143,2.2916162,,,,,,,,,,,,,, -130100,4.969082,2.1672251,,,,,,,,,,,,,, -130200,4.7903557,2.2315114,,,,,,,,,,,,,, -130300,4.9055376,2.3006103,,,,,,,,,,,,,, -130400,4.845527,2.2484558,,,,,,,,,,,,,, -130500,4.5932903,2.275641,,,,,,,,,,,,,, -130573,,,0.7919722199440002,0.8858956694602966,0.7138800024986267,1.2260390520095823,50000.0,0.5934000015258789,1.860069990158081,10000.0,43935.96074128151,45625.14785838127,43935.96074128151,1680.3904614448547,4.7375593185424805,0.0 -130600,4.963958,2.238114,,,,,,,,,,,,,, -130700,5.191378,2.2302766,,,,,,,,,,,,,, -130800,5.0445204,2.21512,,,,,,,,,,,,,, -130900,4.730966,2.3385265,,,,,,,,,,,,,, -131000,5.1231804,2.252289,,,,,,,,,,,,,, -131100,4.7030034,2.2952528,,,,,,,,,,,,,, -131200,5.537406,2.2387881,,,,,,,,,,,,,, -131300,4.718815,2.2713375,,,,,,,,,,,,,, -131400,4.978921,2.2655015,,,,,,,,,,,,,, -131500,5.3065405,2.311059,,,,,,,,,,,,,, -131600,5.6014934,2.2392318,,,,,,,,,,,,,, -131700,5.2806954,2.2051144,,,,,,,,,,,,,, -131800,5.156736,2.2209895,,,,,,,,,,,,,, -131900,4.8885865,2.1713223,,,,,,,,,,,,,, -132000,5.031744,2.2385404,,,,,,,,,,,,,, -132093,,,0.8088328838348389,0.8323812484741211,0.7155199646949768,1.2397068738937378,50000.0,0.5836000442504883,1.886076092720032,10000.0,44446.12138080597,46152.46186089516,44446.12138080597,1697.4430515766144,4.790385007858276,0.0 -132100,4.9275427,2.1626093,,,,,,,,,,,,,, -132200,5.324383,2.2097,,,,,,,,,,,,,, -132300,5.096101,2.2177005,,,,,,,,,,,,,, -132400,5.0267725,2.3426375,,,,,,,,,,,,,, -132500,4.5145555,2.21554,,,,,,,,,,,,,, -132600,5.379557,2.2209504,,,,,,,,,,,,,, -132700,5.6201878,2.1920614,,,,,,,,,,,,,, -132800,5.219214,2.207868,,,,,,,,,,,,,, -132900,5.62936,2.35617,,,,,,,,,,,,,, -133000,5.4990716,2.210721,,,,,,,,,,,,,, -133100,5.2981076,2.2326736,,,,,,,,,,,,,, -133200,5.3021035,2.291174,,,,,,,,,,,,,, -133300,5.2036405,2.2153473,,,,,,,,,,,,,, -133400,5.089058,2.1841888,,,,,,,,,,,,,, -133500,5.162953,2.2519615,,,,,,,,,,,,,, -133600,5.9348626,2.2241087,,,,,,,,,,,,,, -133612,,,0.7983298897743225,0.8808016180992126,0.7109400033950806,1.252920150756836,50000.0,0.588200032711029,1.905611991882324,10000.0,44956.11803674698,46679.69448518753,44956.11803674698,1714.5795366764069,4.842229604721069,0.0 -133700,4.7282386,2.2346392,,,,,,,,,,,,,, -133800,4.9560413,2.191286,,,,,,,,,,,,,, -133900,5.3236284,2.256712,,,,,,,,,,,,,, -134000,5.538314,2.1831274,,,,,,,,,,,,,, -134100,5.1736045,2.231873,,,,,,,,,,,,,, -134200,5.3846045,2.142976,,,,,,,,,,,,,, -134300,5.1994686,2.2661679,,,,,,,,,,,,,, -134400,4.9828405,2.2413294,,,,,,,,,,,,,, -134500,5.915025,2.2547011,,,,,,,,,,,,,, -134600,5.1533713,2.2401948,,,,,,,,,,,,,, -134700,5.5764546,2.2271307,,,,,,,,,,,,,, -134800,5.4991603,2.1651146,,,,,,,,,,,,,, -134900,5.570747,2.1925328,,,,,,,,,,,,,, -135000,4.9568706,2.2204683,,,,,,,,,,,,,, -135100,5.289269,2.2706041,,,,,,,,,,,,,, -135131,,,0.8057836294174194,0.8274676203727722,0.7206000089645386,1.196054220199585,50000.0,0.5926000475883484,1.841191053390503,10000.0,45466.05125045776,47206.79020857811,45466.05125045776,1731.6397836208344,4.896440029144287,0.0 -135200,4.9648223,2.1527712,,,,,,,,,,,,,, -135300,5.321404,2.1824718,,,,,,,,,,,,,, -135400,5.485292,2.2132201,,,,,,,,,,,,,, -135500,5.0850606,2.1469755,,,,,,,,,,,,,, -135600,5.5627217,2.1739566,,,,,,,,,,,,,, -135700,5.2547035,2.2951224,,,,,,,,,,,,,, -135800,4.9461412,2.2010365,,,,,,,,,,,,,, -135900,5.128889,2.1962814,,,,,,,,,,,,,, -136000,5.630661,2.246327,,,,,,,,,,,,,, -136100,5.6354504,2.207539,,,,,,,,,,,,,, -136200,5.741675,2.1450279,,,,,,,,,,,,,, -136300,5.199582,2.1929204,,,,,,,,,,,,,, -136400,5.7338595,2.2495036,,,,,,,,,,,,,, -136500,5.3085413,2.2493277,,,,,,,,,,,,,, -136600,5.8766456,2.1965673,,,,,,,,,,,,,, -136650,,,0.7958585619926453,0.875244677066803,0.7117800116539001,1.2368409633636477,50000.0,0.5906000137329102,1.861374497413636,10000.0,45976.103113889694,47734.25307178497,45976.103113889694,1748.9506647586825,4.948869228363037,0.0 -136700,5.617618,2.207928,,,,,,,,,,,,,, -136800,5.093631,2.1795511,,,,,,,,,,,,,, -136900,6.083588,2.2498136,,,,,,,,,,,,,, -137000,5.483583,2.1536608,,,,,,,,,,,,,, -137100,5.49023,2.1665213,,,,,,,,,,,,,, -137200,5.3676195,2.1721756,,,,,,,,,,,,,, -137300,5.1481805,2.118138,,,,,,,,,,,,,, -137400,5.5837398,2.1219997,,,,,,,,,,,,,, -137500,5.993428,2.3148954,,,,,,,,,,,,,, -137600,5.1139555,2.1316783,,,,,,,,,,,,,, -137700,5.6874676,2.1576715,,,,,,,,,,,,,, -137800,5.5479302,2.141884,,,,,,,,,,,,,, -137900,5.551352,2.1328554,,,,,,,,,,,,,, -138000,5.2618475,2.1448398,,,,,,,,,,,,,, -138100,6.083709,2.2051773,,,,,,,,,,,,,, -138169,,,0.8053850531578064,0.8225789070129395,0.723859965801239,1.1837297677993774,50000.0,0.6002000570297241,1.803093433380127,10000.0,46486.0466632843,48261.59885954857,46486.0466632843,1766.2522943019867,5.00182056427002,0.0 -138200,5.1812487,2.1568487,,,,,,,,,,,,,, -138300,5.7941074,2.1423812,,,,,,,,,,,,,, -138400,5.6717997,2.1784785,,,,,,,,,,,,,, -138500,5.4646754,2.1919897,,,,,,,,,,,,,, -138600,5.697605,2.1626737,,,,,,,,,,,,,, -138700,5.795968,2.1936693,,,,,,,,,,,,,, -138800,5.6060867,2.193667,,,,,,,,,,,,,, -138900,5.364061,2.2161505,,,,,,,,,,,,,, -139000,5.230962,2.241355,,,,,,,,,,,,,, -139100,5.490634,2.2019522,,,,,,,,,,,,,, -139200,5.298153,2.1591349,,,,,,,,,,,,,, -139300,6.343692,2.1306088,,,,,,,,,,,,,, -139400,5.7278366,2.1898918,,,,,,,,,,,,,, -139500,5.7244225,2.1348,,,,,,,,,,,,,, -139600,5.1798806,2.1945071,,,,,,,,,,,,,, -139689,,,0.8387476205825806,0.7008848190307617,0.7271199822425842,1.168049693107605,50000.0,0.6009000539779663,1.8257156610488887,10000.0,46996.25559186936,48789.38086462021,46996.25559186936,1783.7273569107056,5.05206823348999,0.0 -139700,5.630965,2.1230443,,,,,,,,,,,,,, -139800,5.8750544,2.128902,,,,,,,,,,,,,, -139900,5.4907856,2.1307862,,,,,,,,,,,,,, -140000,5.0078993,2.1445036,,,,,,,,,,,,,, -140100,5.7691307,2.1562445,,,,,,,,,,,,,, -140200,5.9662337,2.1780188,,,,,,,,,,,,,, -140300,5.699384,2.123268,,,,,,,,,,,,,, -140400,6.02024,2.1894884,,,,,,,,,,,,,, -140500,5.407648,2.1437056,,,,,,,,,,,,,, -140600,5.85032,2.1773207,,,,,,,,,,,,,, -140700,5.7114577,2.1072428,,,,,,,,,,,,,, -140800,5.5463686,2.0929844,,,,,,,,,,,,,, -140900,5.783397,2.1415546,,,,,,,,,,,,,, -141000,5.48595,2.0179694,,,,,,,,,,,,,, -141100,5.6871214,2.1626487,,,,,,,,,,,,,, -141200,5.986566,2.0451446,,,,,,,,,,,,,, -141209,,,0.8303571343421936,0.7429234385490417,0.725600004196167,1.187226176261902,50000.0,0.6070000529289246,1.7983014583587646,10000.0,47506.24500584602,49316.61631822586,47506.24500584602,1800.8658833503723,5.111589431762695,0.0 -141300,6.0067906,2.2624469,,,,,,,,,,,,,, -141400,5.835401,2.2187428,,,,,,,,,,,,,, -141500,5.654168,2.18863,,,,,,,,,,,,,, -141600,6.2113533,2.1929562,,,,,,,,,,,,,, -141700,5.716936,2.1618934,,,,,,,,,,,,,, -141800,5.5521164,2.161343,,,,,,,,,,,,,, -141900,5.6856346,2.1408217,,,,,,,,,,,,,, -142000,5.324606,2.145416,,,,,,,,,,,,,, -142100,5.815415,2.1423702,,,,,,,,,,,,,, -142200,5.6761904,2.170516,,,,,,,,,,,,,, -142300,5.9099274,2.0948253,,,,,,,,,,,,,, -142400,5.9270906,2.19269,,,,,,,,,,,,,, -142500,5.5143223,2.070938,,,,,,,,,,,,,, -142600,5.835468,2.0536182,,,,,,,,,,,,,, -142700,5.3520536,2.1507275,,,,,,,,,,,,,, -142728,,,0.8176219463348389,0.8000547885894775,0.7247999906539917,1.205077886581421,50000.0,0.5998000502586365,1.849668025970459,10000.0,48016.3487842083,49844.00881743431,48016.3487842083,1818.0516149997711,5.166984558105469,0.0 -142800,5.572998,2.0563002,,,,,,,,,,,,,, -142900,5.5050225,2.177796,,,,,,,,,,,,,, -143000,6.059038,2.1619453,,,,,,,,,,,,,, -143100,5.799712,2.0724106,,,,,,,,,,,,,, -143200,5.7869515,2.1652749,,,,,,,,,,,,,, -143300,5.786191,2.1368291,,,,,,,,,,,,,, -143400,6.1694946,2.0840816,,,,,,,,,,,,,, -143500,5.788983,2.0978062,,,,,,,,,,,,,, -143600,6.0376296,2.0734816,,,,,,,,,,,,,, -143700,5.8543787,2.1644707,,,,,,,,,,,,,, -143800,6.41563,2.1672454,,,,,,,,,,,,,, -143900,5.7422504,2.127778,,,,,,,,,,,,,, -144000,5.747271,2.1500678,,,,,,,,,,,,,, -144100,5.7014704,2.1177535,,,,,,,,,,,,,, -144200,5.7747145,2.107178,,,,,,,,,,,,,, -144248,,,0.8232421875,0.7640501856803894,0.7290999889373779,1.1696761846542358,50000.0,0.6071000099182129,1.7840397357940674,10000.0,48526.44045686722,50371.45580005646,48526.44045686722,1835.3061537742608,5.220117092132568,0.0 -144300,6.396789,2.1765904,,,,,,,,,,,,,, -144400,6.039709,2.1050727,,,,,,,,,,,,,, -144500,6.049261,2.1157117,,,,,,,,,,,,,, -144600,6.062245,2.062451,,,,,,,,,,,,,, -144700,5.8833447,2.1765757,,,,,,,,,,,,,, -144800,6.4989605,2.2131236,,,,,,,,,,,,,, -144900,5.5381875,2.120582,,,,,,,,,,,,,, -145000,5.872467,2.1319075,,,,,,,,,,,,,, -145100,5.9802327,2.1268542,,,,,,,,,,,,,, -145200,6.1214767,2.1266491,,,,,,,,,,,,,, -145300,5.8400917,2.0749166,,,,,,,,,,,,,, -145400,5.739708,2.0578866,,,,,,,,,,,,,, -145500,5.7794685,2.1072538,,,,,,,,,,,,,, -145600,5.692926,2.019064,,,,,,,,,,,,,, -145700,5.601421,2.0949636,,,,,,,,,,,,,, -145767,,,0.8306760191917419,0.7338430881500244,0.735040009021759,1.146825909614563,50000.0,0.6126000285148621,1.7603540420532229,10000.0,49036.527978897095,50899.024411439896,49036.527978897095,1852.686731815338,5.2726123332977295,0.0 -145800,6.3226337,2.146533,,,,,,,,,,,,,, -145900,5.8549986,1.9798644,,,,,,,,,,,,,, -146000,6.837796,2.1199658,,,,,,,,,,,,,, -146100,5.7165494,2.169597,,,,,,,,,,,,,, -146200,5.9850035,2.1112318,,,,,,,,,,,,,, -146300,6.633581,2.0854232,,,,,,,,,,,,,, -146400,6.818362,2.1105063,,,,,,,,,,,,,, -146500,5.929259,2.075274,,,,,,,,,,,,,, -146600,6.0483823,2.1153846,,,,,,,,,,,,,, -146700,6.190461,2.0757794,,,,,,,,,,,,,, -146800,5.9727983,2.0871391,,,,,,,,,,,,,, -146900,6.9502087,2.1523492,,,,,,,,,,,,,, -147000,6.624789,2.1227596,,,,,,,,,,,,,, -147100,6.2495556,2.1504917,,,,,,,,,,,,,, -147200,6.1661587,2.0482333,,,,,,,,,,,,,, -147287,,,0.8342036008834839,0.7212420701980591,0.7377399802207947,1.13205885887146,50000.0,0.6184000372886658,1.7458066940307615,10000.0,49546.45242190361,51426.31285953522,49546.45242190361,1869.9471807479856,5.328448534011841,0.0 -147300,6.0467763,2.1111984,,,,,,,,,,,,,, -147400,6.333242,2.1665363,,,,,,,,,,,,,, -147500,6.1882405,2.0994704,,,,,,,,,,,,,, -147600,5.984387,2.031375,,,,,,,,,,,,,, -147700,5.7860556,2.0833426,,,,,,,,,,,,,, -147800,5.9378366,2.0485451,,,,,,,,,,,,,, -147900,6.331355,2.085014,,,,,,,,,,,,,, -148000,6.2268763,2.0525963,,,,,,,,,,,,,, -148100,5.9955,2.0905075,,,,,,,,,,,,,, -148200,6.147116,2.0572944,,,,,,,,,,,,,, -148300,6.3968554,2.0606022,,,,,,,,,,,,,, -148400,6.121329,2.1065483,,,,,,,,,,,,,, -148500,6.0745044,2.0457628,,,,,,,,,,,,,, -148600,5.935619,2.1502502,,,,,,,,,,,,,, -148700,5.998023,2.0692093,,,,,,,,,,,,,, -148800,6.050649,2.131508,,,,,,,,,,,,,, -148806,,,0.8618662357330322,0.6142693758010864,0.7396999597549438,1.1240257024765017,50000.0,0.6178000569343567,1.7484245300292969,10000.0,50056.575719833374,51953.74762201309,50056.575719833374,1887.1576430797577,5.3819334506988525,0.0 -148900,6.0531735,2.0428777,,,,,,,,,,,,,, -149000,5.7815638,2.0436893,,,,,,,,,,,,,, -149100,6.173334,1.9993786,,,,,,,,,,,,,, -149200,6.4097433,2.0148337,,,,,,,,,,,,,, -149300,6.502058,2.0804496,,,,,,,,,,,,,, -149400,6.4794116,2.100615,,,,,,,,,,,,,, -149500,6.654972,2.0853007,,,,,,,,,,,,,, -149600,6.2718577,2.1160717,,,,,,,,,,,,,, -149700,6.3601317,2.0597708,,,,,,,,,,,,,, -149800,6.9069185,2.1076443,,,,,,,,,,,,,, -149900,6.4166207,2.0676608,,,,,,,,,,,,,, -150000,5.9962406,2.1041496,,,,,,,,,,,,,, -150100,6.0594816,2.0174298,,,,,,,,,,,,,, -150200,7.308798,2.0933614,,,,,,,,,,,,,, -150300,6.7283096,2.1388814,,,,,,,,,,,,,, -150325,,,0.8557676672935486,0.6431095004081726,0.7424399852752686,1.1167047023773191,50000.0,0.6203000545501709,1.736547350883484,10000.0,50566.62089514732,52481.188838005066,50566.62089514732,1904.453207731247,5.434524297714233,0.0 -150400,6.2387967,2.0537255,,,,,,,,,,,,,, -150500,5.76597,2.0559673,,,,,,,,,,,,,, -150600,6.141671,2.0716677,,,,,,,,,,,,,, -150700,5.9152865,2.020249,,,,,,,,,,,,,, -150800,6.457854,2.0550218,,,,,,,,,,,,,, -150900,6.5163236,2.0591934,,,,,,,,,,,,,, -151000,6.9778,1.9984536,,,,,,,,,,,,,, -151100,6.706812,1.9960763,,,,,,,,,,,,,, -151200,6.593831,1.9992046,,,,,,,,,,,,,, -151300,6.5353703,2.0001245,,,,,,,,,,,,,, -151400,6.436152,2.044379,,,,,,,,,,,,,, -151500,6.774479,2.0665026,,,,,,,,,,,,,, -151600,7.4791765,2.1287503,,,,,,,,,,,,,, -151700,6.0463457,2.0350463,,,,,,,,,,,,,, -151800,6.6090226,2.0203876,,,,,,,,,,,,,, -151844,,,0.8517019748687744,0.6508660316467285,0.740399956703186,1.1089304685592651,50000.0,0.6206000447273254,1.7190169095993042,10000.0,51076.7715549469,53008.741042375565,51076.7715549469,1921.752840518952,5.488680601119995,0.0 -151900,6.511478,1.9877491,,,,,,,,,,,,,, -152000,6.49924,2.0040395,,,,,,,,,,,,,, -152100,6.6163816,2.015258,,,,,,,,,,,,,, -152200,6.042445,2.0029538,,,,,,,,,,,,,, -152300,6.1609845,2.0002887,,,,,,,,,,,,,, -152400,6.7776303,1.9715335,,,,,,,,,,,,,, -152500,5.81999,1.961873,,,,,,,,,,,,,, -152600,6.3556376,2.0685449,,,,,,,,,,,,,, -152700,5.9909577,2.0197177,,,,,,,,,,,,,, -152800,6.851267,2.0491264,,,,,,,,,,,,,, -152900,6.239209,2.0381753,,,,,,,,,,,,,, -153000,6.721208,1.9991672,,,,,,,,,,,,,, -153100,6.7126093,2.0345707,,,,,,,,,,,,,, -153200,6.8767524,2.0075898,,,,,,,,,,,,,, -153300,6.483681,2.1172676,,,,,,,,,,,,,, -153363,,,0.8570631146430969,0.6400169134140015,0.7451599836349487,1.0995455980300903,50000.0,0.6276000142097473,1.702033519744873,10000.0,51586.68243670464,53536.81693482399,51586.68243670464,1939.816710472107,5.542165279388428,0.0 -153400,6.277579,2.0033305,,,,,,,,,,,,,, -153500,6.102159,2.05825,,,,,,,,,,,,,, -153600,6.3341937,1.972417,,,,,,,,,,,,,, -153700,6.325297,2.0149033,,,,,,,,,,,,,, -153800,6.303059,2.0292842,,,,,,,,,,,,,, -153900,6.125492,1.9776585,,,,,,,,,,,,,, -154000,6.5632234,1.9991648,,,,,,,,,,,,,, -154100,6.5408907,2.0396564,,,,,,,,,,,,,, -154200,7.1791177,2.0982542,,,,,,,,,,,,,, -154300,6.790456,2.0090055,,,,,,,,,,,,,, -154400,6.858543,2.0005665,,,,,,,,,,,,,, -154500,6.5489116,2.0112936,,,,,,,,,,,,,, -154600,6.303145,2.0233126,,,,,,,,,,,,,, -154700,7.5090356,2.0724814,,,,,,,,,,,,,, -154800,6.139082,1.965163,,,,,,,,,,,,,, -154882,,,0.8557876348495483,0.6262343525886536,0.7460199594497681,1.0894845724105835,50000.0,0.6283000111579895,1.6885766983032229,10000.0,52096.776774168015,54064.17822790146,52096.776774168015,1956.984309911728,5.594337701797485,0.0 -154900,6.5851755,2.0127223,,,,,,,,,,,,,, -155000,6.3553314,2.0039225,,,,,,,,,,,,,, -155100,7.067167,2.063815,,,,,,,,,,,,,, -155200,7.3140125,2.0402448,,,,,,,,,,,,,, -155300,6.1121936,1.9672109,,,,,,,,,,,,,, -155400,7.7249837,2.007883,,,,,,,,,,,,,, -155500,7.0026474,1.942051,,,,,,,,,,,,,, -155600,6.4703474,1.9605554,,,,,,,,,,,,,, -155700,6.6637993,2.0128973,,,,,,,,,,,,,, -155800,6.3250346,1.9823678,,,,,,,,,,,,,, -155900,6.9889,1.9908319,,,,,,,,,,,,,, -156000,7.1298203,2.0170443,,,,,,,,,,,,,, -156100,6.5615907,1.9986789,,,,,,,,,,,,,, -156200,7.213985,2.0137403,,,,,,,,,,,,,, -156300,6.513552,1.981386,,,,,,,,,,,,,, -156400,6.67859,1.9685421,,,,,,,,,,,,,, -156401,,,0.861348032951355,0.6097630858421326,0.7497199773788452,1.077207088470459,50000.0,0.6301000118255615,1.6781634092330933,10000.0,52606.98158121109,54592.09770488739,52606.98158121109,1974.597580432892,5.648033142089844,0.0 -156500,6.839601,1.9042916,,,,,,,,,,,,,, -156600,6.5045547,1.920516,,,,,,,,,,,,,, -156700,6.804585,1.9051102,,,,,,,,,,,,,, -156800,7.1477222,1.9692963,,,,,,,,,,,,,, -156900,6.7215796,1.9688468,,,,,,,,,,,,,, -157000,7.9561095,2.0737216,,,,,,,,,,,,,, -157100,6.406181,1.9250054,,,,,,,,,,,,,, -157200,7.2010202,1.9659655,,,,,,,,,,,,,, -157300,6.992986,2.0175095,,,,,,,,,,,,,, -157400,6.896511,1.9337413,,,,,,,,,,,,,, -157500,6.568652,1.8749018,,,,,,,,,,,,,, -157600,7.038903,1.9975803,,,,,,,,,,,,,, -157700,7.580393,1.8888192,,,,,,,,,,,,,, -157800,6.4860253,2.0000367,,,,,,,,,,,,,, -157900,6.4733267,1.9741186,,,,,,,,,,,,,, -157920,,,0.8819953799247742,0.5402708649635315,0.7523799538612366,1.0711145401000977,50000.0,0.631600022315979,1.692663550376892,10000.0,53117.21468949318,55119.62186551094,53117.21468949318,1991.7971727848053,5.691704273223877,0.0 -158000,6.68543,2.01959,,,,,,,,,,,,,, -158100,6.7728148,1.9575051,,,,,,,,,,,,,, -158200,6.9092546,2.0266063,,,,,,,,,,,,,, -158300,6.393918,1.9310383,,,,,,,,,,,,,, -158400,6.3078423,1.918718,,,,,,,,,,,,,, -158500,6.474618,1.9230542,,,,,,,,,,,,,, -158600,6.3978066,1.9356636,,,,,,,,,,,,,, -158700,6.712446,1.9767615,,,,,,,,,,,,,, -158800,7.1930532,1.9829777,,,,,,,,,,,,,, -158900,6.8498178,2.053051,,,,,,,,,,,,,, -159000,6.842329,1.9571248,,,,,,,,,,,,,, -159100,7.562759,2.0069473,,,,,,,,,,,,,, -159200,6.7635555,1.9680407,,,,,,,,,,,,,, -159300,6.5352736,1.9132085,,,,,,,,,,,,,, -159400,6.3583164,1.9536699,,,,,,,,,,,,,, -159440,,,0.8802216053009033,0.5432620644569397,0.7541999816894531,1.059921383857727,50000.0,0.6307000517845154,1.6828455924987793,10000.0,53627.40317606926,55647.28178501129,53627.40317606926,2009.16032409668,5.752228260040283,0.0 -159500,6.8284106,1.9426711,,,,,,,,,,,,,, -159600,6.9822373,1.9236069,,,,,,,,,,,,,, -159700,6.934785,1.9333293,,,,,,,,,,,,,, -159800,6.7579656,1.9408625,,,,,,,,,,,,,, -159900,6.8590174,1.9032946,,,,,,,,,,,,,, -160000,7.3001666,2.001725,,,,,,,,,,,,,, -160100,6.357571,1.8967694,,,,,,,,,,,,,, -160200,7.027849,2.0715795,,,,,,,,,,,,,, -160300,6.986273,1.9807136,,,,,,,,,,,,,, -160400,6.667827,1.9202259,,,,,,,,,,,,,, -160500,7.2341676,1.9341005,,,,,,,,,,,,,, -160600,6.5417233,1.9681456,,,,,,,,,,,,,, -160700,7.250298,1.9989278,,,,,,,,,,,,,, -160800,6.7071023,1.9652537,,,,,,,,,,,,,, -160900,6.5349154,1.8974721,,,,,,,,,,,,,, -160959,,,0.8743024468421936,0.5627148747444153,0.7519999742507935,1.0669498443603516,50000.0,0.6326000094413757,1.689836025238037,10000.0,54137.527206897736,56174.68740797043,54137.527206897736,2026.33642411232,5.809492588043213,0.0 -161000,6.501002,1.840754,,,,,,,,,,,,,, -161100,6.788576,1.9274359,,,,,,,,,,,,,, -161200,7.3991523,1.9564669,,,,,,,,,,,,,, -161300,7.0153995,1.9657717,,,,,,,,,,,,,, -161400,7.0055037,1.8960835,,,,,,,,,,,,,, -161500,6.4797225,1.8902843,,,,,,,,,,,,,, -161600,7.277196,1.8955554,,,,,,,,,,,,,, -161700,7.480453,1.9799453,,,,,,,,,,,,,, -161800,7.309668,1.9725583,,,,,,,,,,,,,, -161900,6.6739187,1.9103941,,,,,,,,,,,,,, -162000,7.251643,1.9345319,,,,,,,,,,,,,, -162100,6.988608,1.8854167,,,,,,,,,,,,,, -162200,6.762894,1.938432,,,,,,,,,,,,,, -162300,6.782809,1.9111205,,,,,,,,,,,,,, -162400,7.4684525,1.9272183,,,,,,,,,,,,,, -162480,,,0.8794642686843872,0.5540726184844971,0.7535199522972107,1.0689977407455444,50000.0,0.6355000138282776,1.671195149421692,10000.0,54647.75112223625,56702.178926467896,54647.75112223625,2043.4988808631897,5.86617112159729,0.0 -162500,6.977249,1.9651036,,,,,,,,,,,,,, -162600,6.3353724,1.8720514,,,,,,,,,,,,,, -162700,7.6122346,1.9564884,,,,,,,,,,,,,, -162800,7.4542627,1.9175748,,,,,,,,,,,,,, -162900,7.6923437,1.9345549,,,,,,,,,,,,,, -163000,7.3188972,1.9277184,,,,,,,,,,,,,, -163100,7.421398,1.8757023,,,,,,,,,,,,,, -163200,6.5504036,1.8524833,,,,,,,,,,,,,, -163300,7.662947,1.8840818,,,,,,,,,,,,,, -163400,7.31435,1.9670933,,,,,,,,,,,,,, -163500,7.087874,1.8777268,,,,,,,,,,,,,, -163600,7.23274,1.8940425,,,,,,,,,,,,,, -163700,7.5030804,1.8933514,,,,,,,,,,,,,, -163800,7.4838376,1.881369,,,,,,,,,,,,,, -163900,6.9939814,1.8842629,,,,,,,,,,,,,, -163999,,,0.881257951259613,0.5390360951423645,0.7562599778175354,1.0575190782546997,50000.0,0.6357000470161438,1.6739016771316528,10000.0,55157.73432970047,57229.695892095566,55157.73432970047,2060.928759098053,5.921685457229614,0.0 -164000,7.6780186,1.9443616,,,,,,,,,,,,,, -164100,7.8619685,1.9681535,,,,,,,,,,,,,, -164200,7.8229556,1.9587591,,,,,,,,,,,,,, -164300,6.803936,1.8367915,,,,,,,,,,,,,, -164400,7.4154453,1.9254463,,,,,,,,,,,,,, -164500,6.784055,1.7746302,,,,,,,,,,,,,, -164600,7.413943,1.8788626,,,,,,,,,,,,,, -164700,7.0796213,1.9661646,,,,,,,,,,,,,, -164800,7.5911503,1.9952112,,,,,,,,,,,,,, -164900,7.2439685,1.9031019,,,,,,,,,,,,,, -165000,6.8125825,1.8456676,,,,,,,,,,,,,, -165100,8.039653,1.9176656,,,,,,,,,,,,,, -165200,6.8284326,1.792429,,,,,,,,,,,,,, -165300,7.155534,1.9117825,,,,,,,,,,,,,, -165400,6.9482837,1.923248,,,,,,,,,,,,,, -165500,7.0987372,1.8739588,,,,,,,,,,,,,, -165518,,,0.8997927308082581,0.4756694734096527,0.7611799836158752,1.0342700481414795,50000.0,0.6420000195503235,1.6526086330413818,10000.0,55667.66546201706,57757.03380203247,55667.66546201706,2078.2167851924896,5.992365121841431,0.0 -165600,7.9397545,1.8358647,,,,,,,,,,,,,, -165700,7.2924433,1.8444147,,,,,,,,,,,,,, -165800,6.7019277,1.8824003,,,,,,,,,,,,,, -165900,6.9571786,1.8715928,,,,,,,,,,,,,, -166000,7.5281916,1.9318974,,,,,,,,,,,,,, -166100,7.6096754,1.8420265,,,,,,,,,,,,,, -166200,7.6814814,1.9090822,,,,,,,,,,,,,, -166300,7.4470296,1.835525,,,,,,,,,,,,,, -166400,6.814378,1.8614558,,,,,,,,,,,,,, -166500,8.216946,1.866768,,,,,,,,,,,,,, -166600,7.0585732,1.9076712,,,,,,,,,,,,,, -166700,7.7047596,1.8787899,,,,,,,,,,,,,, -166800,7.6832814,1.8854038,,,,,,,,,,,,,, -166900,7.613139,1.8400935,,,,,,,,,,,,,, -167000,8.242848,1.8815013,,,,,,,,,,,,,, -167037,,,0.9026426672935486,0.4636174738407135,0.7644599676132202,1.0262614488601685,50000.0,0.6413000226020813,1.6404330730438232,10000.0,56177.75947856903,58284.55493855477,56177.75947856903,2095.539171934128,6.0488669872283936,0.0 -167100,7.439317,1.8498048,,,,,,,,,,,,,, -167200,7.0748944,1.8867456,,,,,,,,,,,,,, -167300,8.501499,1.9334219,,,,,,,,,,,,,, -167400,6.650343,1.8354383,,,,,,,,,,,,,, -167500,7.4319158,1.8023727,,,,,,,,,,,,,, -167600,7.3352575,1.8471837,,,,,,,,,,,,,, -167700,7.058202,1.8681961,,,,,,,,,,,,,, -167800,7.623615,1.8867916,,,,,,,,,,,,,, -167900,7.621887,1.901387,,,,,,,,,,,,,, -168000,7.9264884,1.882586,,,,,,,,,,,,,, -168100,7.462793,1.8264946,,,,,,,,,,,,,, -168200,7.641774,1.8700709,,,,,,,,,,,,,, -168300,7.8673677,1.8735385,,,,,,,,,,,,,, -168400,7.7047606,1.808874,,,,,,,,,,,,,, -168500,7.5093713,1.8302728,,,,,,,,,,,,,, -168557,,,0.901147961616516,0.4678551256656647,0.7643799781799316,1.0218467712402344,50000.0,0.6444000601768494,1.6343672275543213,10000.0,56687.92796278,58812.09150671959,56687.92796278,2112.800806760788,6.107359647750855,0.0 -168600,7.627701,1.8350618,,,,,,,,,,,,,, -168700,7.003669,1.7870586,,,,,,,,,,,,,, -168800,7.88422,1.8114747,,,,,,,,,,,,,, -168900,7.5217133,1.8926116,,,,,,,,,,,,,, -169000,7.2970977,1.8196703,,,,,,,,,,,,,, -169100,7.646092,1.824031,,,,,,,,,,,,,, -169200,8.118395,1.8752842,,,,,,,,,,,,,, -169300,7.4017534,1.8463666,,,,,,,,,,,,,, -169400,6.896088,1.8222432,,,,,,,,,,,,,, -169500,7.859507,1.7940309,,,,,,,,,,,,,, -169600,7.5427103,1.8897591,,,,,,,,,,,,,, -169700,7.9081798,1.8342102,,,,,,,,,,,,,, -169800,7.735322,1.940164,,,,,,,,,,,,,, -169900,7.0807633,1.8293898,,,,,,,,,,,,,, -170000,7.4896646,1.8555794,,,,,,,,,,,,,, -170076,,,0.904117465019226,0.4593851864337921,0.7635599970817566,1.023459792137146,50000.0,0.6461000442504883,1.621721267700195,10000.0,57198.12523698807,59339.6875641346,57198.12523698807,2130.0928435325623,6.1661036014556885,0.0 -170100,7.5636044,1.8223825,,,,,,,,,,,,,, -170200,6.7092786,1.758053,,,,,,,,,,,,,, -170300,7.6174226,1.8765678,,,,,,,,,,,,,, -170400,8.518094,1.7764074,,,,,,,,,,,,,, -170500,7.8840804,1.8808906,,,,,,,,,,,,,, -170600,7.5871744,1.8619909,,,,,,,,,,,,,, -170700,8.131195,1.8191468,,,,,,,,,,,,,, -170800,8.21841,1.879717,,,,,,,,,,,,,, -170900,7.6958594,1.8092511,,,,,,,,,,,,,, -171000,7.6343074,1.8609465,,,,,,,,,,,,,, -171100,7.854899,1.828902,,,,,,,,,,,,,, -171200,7.7921176,1.8289119,,,,,,,,,,,,,, -171300,7.835812,1.7792121,,,,,,,,,,,,,, -171400,7.010798,1.7749794,,,,,,,,,,,,,, -171500,8.46703,1.7798296,,,,,,,,,,,,,, -171595,,,0.906648576259613,0.4511556625366211,0.7665199637413025,1.014053463935852,50000.0,0.6449000239372253,1.6215829849243164,10000.0,57708.20014286041,59867.17143726349,57708.20014286041,2147.3965339660645,6.22303581237793,0.0 -171600,6.8395386,1.7360286,,,,,,,,,,,,,, -171700,8.128484,1.8395789,,,,,,,,,,,,,, -171800,6.9198947,1.7827818,,,,,,,,,,,,,, -171900,7.536597,1.8098649,,,,,,,,,,,,,, -172000,7.196162,1.7650518,,,,,,,,,,,,,, -172100,8.233695,1.8133917,,,,,,,,,,,,,, -172200,7.9370174,1.7729392,,,,,,,,,,,,,, -172300,8.667828,1.87095,,,,,,,,,,,,,, -172400,9.009292,1.865098,,,,,,,,,,,,,, -172500,7.635249,1.8011782,,,,,,,,,,,,,, -172600,7.551719,1.8144315,,,,,,,,,,,,,, -172700,7.416642,1.8196046,,,,,,,,,,,,,, -172800,7.637421,1.7545223,,,,,,,,,,,,,, -172900,8.108677,1.7625802,,,,,,,,,,,,,, -173000,6.987015,1.8283646,,,,,,,,,,,,,, -173100,6.6748986,1.7728757,,,,,,,,,,,,,, -173116,,,0.9085220098495485,0.4437970519065857,0.7672799825668335,1.0095661878585815,50000.0,0.6490000486373901,1.6089009046554563,10000.0,58218.26253461838,60394.55778956413,58218.26253461838,2164.6122002601624,6.283660173416138,0.0 -173200,8.04297,1.8316147,,,,,,,,,,,,,, -173300,7.8401394,1.7965682,,,,,,,,,,,,,, -173400,7.73989,1.9006824,,,,,,,,,,,,,, -173500,7.9319744,1.806553,,,,,,,,,,,,,, -173600,8.2001,1.8280624,,,,,,,,,,,,,, -173700,8.0911255,1.7957578,,,,,,,,,,,,,, -173800,7.690886,1.8200433,,,,,,,,,,,,,, -173900,7.049603,1.8098978,,,,,,,,,,,,,, -174000,7.751727,1.7872269,,,,,,,,,,,,,, -174100,7.494134,1.7721232,,,,,,,,,,,,,, -174200,7.5896907,1.7809021,,,,,,,,,,,,,, -174300,8.0763855,1.8377401,,,,,,,,,,,,,, -174400,7.4949803,1.8657417,,,,,,,,,,,,,, -174500,8.250185,1.8206111,,,,,,,,,,,,,, -174600,7.442977,1.838533,,,,,,,,,,,,,, -174635,,,0.9193837642669678,0.4103900194168091,0.7671399712562561,1.0063893795013428,50000.0,0.6513000130653381,1.6108440160751345,10000.0,58728.41574501991,60922.16036057472,58728.41574501991,2181.95579123497,6.341675043106079,0.0 -174700,7.7409525,1.8353465,,,,,,,,,,,,,, -174800,8.257768,1.8428948,,,,,,,,,,,,,, -174900,7.397556,1.7723945,,,,,,,,,,,,,, -175000,8.681555,1.8614161,,,,,,,,,,,,,, -175100,7.5019503,1.8477038,,,,,,,,,,,,,, -175200,8.125679,1.8768377,,,,,,,,,,,,,, -175300,8.013229,1.8083392,,,,,,,,,,,,,, -175400,7.7259507,1.8501017,,,,,,,,,,,,,, -175500,7.4700007,1.766601,,,,,,,,,,,,,, -175600,7.7354164,1.799541,,,,,,,,,,,,,, -175700,8.022114,1.8572674,,,,,,,,,,,,,, -175800,8.901298,1.8985816,,,,,,,,,,,,,, -175900,7.52125,1.884317,,,,,,,,,,,,,, -176000,7.645848,1.7551539,,,,,,,,,,,,,, -176100,7.9353437,1.8074619,,,,,,,,,,,,,, -176154,,,0.917390763759613,0.4171870648860931,0.76801997423172,1.0084744691848757,50000.0,0.651900053024292,1.6077901124954224,10000.0,59238.44852399826,61449.693984270096,59238.44852399826,2199.349052429199,6.401141405105591,0.0 -176200,8.838115,1.7976618,,,,,,,,,,,,,, -176300,7.747396,1.7612984,,,,,,,,,,,,,, -176400,8.012155,1.7613804,,,,,,,,,,,,,, -176500,7.3000693,1.758668,,,,,,,,,,,,,, -176600,7.488168,1.7594469,,,,,,,,,,,,,, -176700,7.6894155,1.7864447,,,,,,,,,,,,,, -176800,8.069612,1.7941917,,,,,,,,,,,,,, -176900,8.049813,1.8627195,,,,,,,,,,,,,, -177000,8.70122,1.8174484,,,,,,,,,,,,,, -177100,7.95406,1.7923083,,,,,,,,,,,,,, -177200,7.76264,1.8126273,,,,,,,,,,,,,, -177300,7.403956,1.7880661,,,,,,,,,,,,,, -177400,8.073953,1.8257663,,,,,,,,,,,,,, -177500,7.8048496,1.7646544,,,,,,,,,,,,,, -177600,7.5346103,1.7517769,,,,,,,,,,,,,, -177673,,,0.9159757494926452,0.4144531488418579,0.7691400051116943,1.002310037612915,50000.0,0.6516000032424927,1.602932333946228,10000.0,59748.627321243286,61977.189665555954,59748.627321243286,2216.56134724617,6.457212448120117,0.0 -177700,7.6684694,1.8095181,,,,,,,,,,,,,, -177800,7.4286084,1.8169882,,,,,,,,,,,,,, -177900,7.7299185,1.8111105,,,,,,,,,,,,,, -178000,7.819138,1.823604,,,,,,,,,,,,,, -178100,7.8716693,1.7547469,,,,,,,,,,,,,, -178200,8.073553,1.7848892,,,,,,,,,,,,,, -178300,7.970077,1.8300377,,,,,,,,,,,,,, -178400,7.237824,1.7219492,,,,,,,,,,,,,, -178500,7.8555923,1.8375454,,,,,,,,,,,,,, -178600,7.3732233,1.6994681,,,,,,,,,,,,,, -178700,7.9835978,1.8071337,,,,,,,,,,,,,, -178800,8.132243,1.7879603,,,,,,,,,,,,,, -178900,8.093951,1.7662784,,,,,,,,,,,,,, -179000,7.399889,1.8043095,,,,,,,,,,,,,, -179100,7.740109,1.8305736,,,,,,,,,,,,,, -179190,,,0.9170718789100648,0.4057277143001556,0.7702599763870239,0.9993435740470886,50000.0,0.6502000093460083,1.6043004989624023,10000.0,60257.7345225811,62504.43201804161,60257.7345225811,2233.763617515564,7.341777801513672,0.0 -179200,7.883016,1.8554507,,,,,,,,,,,,,, -179300,7.864165,1.7229078,,,,,,,,,,,,,, -179400,8.112187,1.8005704,,,,,,,,,,,,,, -179500,7.910514,1.7620972,,,,,,,,,,,,,, -179600,7.841455,1.7891384,,,,,,,,,,,,,, -179700,7.910571,1.7304156,,,,,,,,,,,,,, -179800,7.7569885,1.7790692,,,,,,,,,,,,,, -179900,7.793644,1.809824,,,,,,,,,,,,,, -180000,7.0771093,1.7419257,,,,,,,,,,,,,, -180100,7.874744,1.7670668,,,,,,,,,,,,,, -180200,8.614332,1.8783652,,,,,,,,,,,,,, -180300,8.564838,1.8095487,,,,,,,,,,,,,, -180400,8.135068,1.7179854,,,,,,,,,,,,,, -180500,7.7094746,1.755209,,,,,,,,,,,,,, -180600,8.76136,1.819928,,,,,,,,,,,,,, -180700,7.9141235,1.7595154,,,,,,,,,,,,,, -180709,,,0.9195631146430968,0.4090628623962402,0.7701999545097351,1.0007290840148926,50000.0,0.6528000235557556,1.6050145626068115,10000.0,60767.89223456383,63032.08485865593,60767.89223456383,2251.149762868881,7.402913808822632,0.0 -180800,7.8416224,1.798039,,,,,,,,,,,,,, -180900,7.4056354,1.8172106,,,,,,,,,,,,,, -181000,7.22735,1.7676737,,,,,,,,,,,,,, -181100,7.448198,1.7685925,,,,,,,,,,,,,, -181200,8.506167,1.7889411,,,,,,,,,,,,,, -181300,7.949966,1.7881956,,,,,,,,,,,,,, -181400,7.5340796,1.7337675,,,,,,,,,,,,,, -181500,7.774035,1.8473912,,,,,,,,,,,,,, -181600,8.399174,1.7222648,,,,,,,,,,,,,, -181700,7.603586,1.8072783,,,,,,,,,,,,,, -181800,7.919949,1.8270779,,,,,,,,,,,,,, -181900,7.099539,1.7383122,,,,,,,,,,,,,, -182000,7.4230003,1.7960868,,,,,,,,,,,,,, -182100,7.9123893,1.7803918,,,,,,,,,,,,,, -182200,8.331726,1.7987239,,,,,,,,,,,,,, -182227,,,0.9197823405265808,0.4075490236282348,0.7706800103187561,0.9976078867912292,50000.0,0.6515000462532043,1.602009892463684,10000.0,61277.79031038284,63559.42090892792,61277.79031038284,2268.478595972061,7.46382737159729,0.0 -182300,7.874671,1.8118242,,,,,,,,,,,,,, -182400,7.6140647,1.7457621,,,,,,,,,,,,,, -182500,7.266966,1.7236307,,,,,,,,,,,,,, -182600,8.014106,1.7690814,,,,,,,,,,,,,, -182700,7.8403063,1.8031541,,,,,,,,,,,,,, -182800,9.568518,1.8049781,,,,,,,,,,,,,, -182900,7.414927,1.7898448,,,,,,,,,,,,,, -183000,8.454659,1.7766409,,,,,,,,,,,,,, -183100,7.516519,1.7369449,,,,,,,,,,,,,, -183200,7.4473553,1.7459184,,,,,,,,,,,,,, -183300,7.790861,1.7780006,,,,,,,,,,,,,, -183400,7.771125,1.788592,,,,,,,,,,,,,, -183500,9.031929,1.7606751,,,,,,,,,,,,,, -183600,7.8796434,1.7581819,,,,,,,,,,,,,, -183700,7.0832953,1.6764054,,,,,,,,,,,,,, -183747,,,0.9213767051696776,0.4001802504062652,0.7707799673080444,0.9960906505584716,50000.0,0.6534000039100647,1.599919080734253,10000.0,61788.00178670883,64087.113674640656,61788.00178670883,2285.850029706955,7.525523662567139,0.0 -183800,7.463984,1.773828,,,,,,,,,,,,,, -183900,8.515172,1.7664026,,,,,,,,,,,,,, -184000,8.423305,1.7519361,,,,,,,,,,,,,, -184100,7.949738,1.8835055,,,,,,,,,,,,,, -184200,9.085539,1.8136407,,,,,,,,,,,,,, -184300,7.7978272,1.8500359,,,,,,,,,,,,,, -184400,7.330957,1.7732419,,,,,,,,,,,,,, -184500,7.6853714,1.7482538,,,,,,,,,,,,,, -184600,8.117434,1.7765977,,,,,,,,,,,,,, -184700,7.999035,1.7801349,,,,,,,,,,,,,, -184800,7.460981,1.727588,,,,,,,,,,,,,, -184900,7.0069337,1.7681124,,,,,,,,,,,,,, -185000,8.309968,1.7584388,,,,,,,,,,,,,, -185100,8.766921,1.7713866,,,,,,,,,,,,,, -185200,7.7581987,1.7633052,,,,,,,,,,,,,, -185266,,,0.9223333597183228,0.3986599445343017,0.7706599831581116,0.9960277080535888,50000.0,0.6541000604629517,1.599553108215332,10000.0,62297.926362752914,64614.32078671456,62297.926362752914,2303.023751735688,7.586014032363892,0.0 -185300,8.692481,1.7430415,,,,,,,,,,,,,, -185400,8.109518,1.8019377,,,,,,,,,,,,,, -185500,8.013169,1.8058133,,,,,,,,,,,,,, -185600,7.2735453,1.7754544,,,,,,,,,,,,,, -185700,7.468111,1.8264095,,,,,,,,,,,,,, -185800,7.738557,1.7898293,,,,,,,,,,,,,, -185900,7.4030995,1.7896749,,,,,,,,,,,,,, -186000,7.5444603,1.7787814,,,,,,,,,,,,,, -186100,7.662561,1.838992,,,,,,,,,,,,,, -186200,7.268826,1.743108,,,,,,,,,,,,,, -186300,7.0792794,1.8033485,,,,,,,,,,,,,, -186400,7.967802,1.812693,,,,,,,,,,,,,, -186500,7.9945145,1.7916042,,,,,,,,,,,,,, -186600,7.8250723,1.7380371,,,,,,,,,,,,,, -186666,,,0.9204400181770324,0.3973002731800079,0.7709999680519104,0.9936330318450928,50000.0,0.6535000205039978,1.5979526042938232,10000.0,62767.63660430908,65101.41163563728,62767.63660430908,2320.299160003662,7.647452354431152,0.0 -186666,,,,,,,,,,,62767.63660430908,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/eval_measurements.csv deleted file mode 100644 index d55bf0b72..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,126 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.192532777786255,0.0,31.64558982849121,1,0,31.64558982849121,0.0006000000284984,6.910250186920166,10000,48.83828067779541,0.0004783163312822,6.910229682922363,0.0009599999757483,6.910243988037109,50000 -34.846673011779785,0.0188093185424804,541.6560969352722,1514,0,541.6560969352722,0.055000003427267,5.537538528442383,10000,576.5695464611053,0.0840242356061935,5.227553844451904,0.0806799978017807,5.2884297370910645,50000 -52.72998690605164,0.0470173358917236,1051.752347946167,3026,0,1051.752347946167,0.1348000019788742,4.705352783203125,10000,1104.6250972747805,0.199597418308258,4.143094062805176,0.1790200024843216,4.270390033721924,50000 -71.29328966140747,0.0761914253234863,1561.9691922664642,4538,0,1561.9691922664642,0.2043000161647796,4.173789024353027,10000,1633.4822070598602,0.3047672212123871,3.462144136428833,0.2767399847507477,3.6103830337524414,50000 -89.09815335273743,0.1078622341156005,2072.226632356644,6051,0,2072.226632356644,0.2770000100135803,3.678683042526245,10000,2161.6236131191254,0.3840282261371612,2.9529519081115723,0.3597399890422821,3.094580173492432,50000 -106.90979647636414,0.140455961227417,2582.2532529830933,7563,0,2582.2532529830933,0.3039000034332275,3.5461227893829346,10000,2689.541732311249,0.4412667453289032,2.6990714073181152,0.4002600014209747,2.899838447570801,50000 -124.54310297966003,0.1691508293151855,3092.440190553665,9077,0,3092.440190553665,0.3500000238418579,3.250772476196289,10000,3217.438376188278,0.5027104616165161,2.3729019165039062,0.452919989824295,2.6144769191741943,50000 -142.70739150047302,0.1985282897949218,3602.418808460236,10591,0,3602.418808460236,0.4058000147342682,2.908501148223877,10000,3745.6586899757385,0.5700932741165161,2.06361722946167,0.5237799882888794,2.2730724811553955,50000 -160.36835479736328,0.2296433448791504,4112.355051517487,12105,0,4112.355051517487,0.4216000139713287,2.831441640853882,10000,4273.334962368012,0.5975366830825806,1.9064337015151973,0.5470799803733826,2.1341938972473145,50000 -177.8351345062256,0.2601957321166992,4622.293032407761,13620,0,4622.293032407761,0.4493000209331512,2.669410228729248,10000,4800.818227529526,0.615632951259613,1.8004908561706543,0.5709399580955505,2.0190188884735107,50000 -195.62596201896667,0.2902054786682129,5132.363230466843,15135,0,5132.363230466843,0.4436000287532806,2.772144794464112,10000,5328.756629705429,0.6227080821990967,1.854886531829834,0.5773000121116638,2.06612229347229,50000 -213.26491689682007,0.3256824016571045,5642.407160997391,16650,0,5642.407160997391,0.4802000224590301,2.534341096878052,10000,5856.522654294968,0.6803053021430969,1.539376974105835,0.6041600108146667,1.868208289146424,50000 -230.8041076660156,0.3559126853942871,6152.345567941666,18165,0,6152.345567941666,0.4698000252246856,2.5441064834594727,10000,6384.078463315964,0.6703802347183228,1.531219720840454,0.5999000072479248,1.8504778146743768,50000 -248.65882778167725,0.3883523941040039,6662.31097984314,19681,0,6662.31097984314,0.4904000163078308,2.440124988555908,10000,6911.979278564453,0.6851084232330322,1.5148324966430664,0.6193000078201294,1.7937859296798706,50000 -266.68954062461853,0.4194791316986084,7172.2588493824005,21197,0,7172.2588493824005,0.4988000094890594,2.369382619857788,10000,7440.037328243256,0.6892338991165161,1.42905855178833,0.627020001411438,1.7129013538360596,50000 -284.39879989624023,0.4545462131500244,7682.28437256813,22713,0,7682.28437256813,0.4872000217437744,2.462309598922729,10000,7967.854699373245,0.684012234210968,1.504253387451172,0.6255599856376648,1.766654372215271,50000 -301.9742715358734,0.4906270503997803,8192.348526477814,24229,0,8192.348526477814,0.5049999952316284,2.389516592025757,10000,8495.57875084877,0.6905492544174194,1.4637855291366575,0.6354599595069885,1.7087171077728271,50000 -319.3741111755371,0.5262205600738525,8702.524030208588,25745,0,8702.524030208588,0.5076000094413757,2.4090499877929688,10000,9023.23773431778,0.7331194281578064,1.299314022064209,0.6287800073623657,1.7513171434402466,50000 -337.14146423339844,0.562938928604126,9212.554991006851,27261,0,9212.554991006851,0.510200023651123,2.3702125549316406,10000,9551.120630264282,0.7091438174247742,1.3814681768417358,0.6342599987983704,1.7121782302856443,50000 -354.7211463451385,0.5966382026672363,9722.621646165848,28777,0,9722.621646165848,0.5174000263214111,2.340773820877075,10000,10078.848618268968,0.7155014276504517,1.364587664604187,0.6462999582290649,1.6733760833740234,50000 -372.41778016090393,0.6293659210205078,10232.737368822098,30293,0,10232.737368822098,0.5148000121116638,2.338981866836548,10000,10606.741770744324,0.7161391973495483,1.3571972846984863,0.6464399695396423,1.6608171463012695,50000 -389.93855023384094,0.665184497833252,10742.950670957563,31809,0,10742.950670957563,0.5184000134468079,2.3254666328430176,10000,11134.55994296074,0.70511794090271,1.4040005207061768,0.64028000831604,1.698283553123474,50000 -407.4942150115967,0.7008495330810547,11253.040585517883,33325,0,11253.040585517883,0.518500030040741,2.3262481689453125,10000,11662.289780378342,0.7089444994926453,1.3775570392608645,0.6474399566650391,1.6512587070465088,50000 -425.2882845401764,0.7346818447113037,11763.106231689451,34841,0,11763.106231689451,0.5170000195503235,2.30928897857666,10000,12190.231878995895,0.7420878410339355,1.2352484464645386,0.6431399583816528,1.672220230102539,50000 -443.1710669994354,0.7699494361877441,12273.201929330826,36358,0,12273.201929330826,0.5149000287055969,2.3678078651428223,10000,12718.294162273409,0.7210817933082581,1.3398000001907349,0.6416599750518799,1.693834900856018,50000 -460.7906458377838,0.8074545860290527,12783.18024611473,37874,0,12783.18024611473,0.5271000266075134,2.265986442565918,10000,13245.977816104887,0.7284956574440002,1.312329649925232,0.6582199931144714,1.6255125999450684,50000 -478.4318549633026,0.8491504192352295,13293.384312152864,39391,0,13293.384312152864,0.5283000469207764,2.272613525390625,10000,13773.91310286522,0.7284956574440002,1.298724889755249,0.6593199968338013,1.6137521266937256,50000 -496.1568441390991,0.8838469982147217,13803.354566812515,40907,0,13803.354566812515,0.524399995803833,2.278727054595948,10000,14301.69158387184,0.7203842401504517,1.3398540019989014,0.6532599925994873,1.6414886713027954,50000 -514.7730877399445,0.921715259552002,14313.371778011322,42423,0,14313.371778011322,0.5348000526428223,2.26089096069336,10000,14830.411584377289,0.7283163070678711,1.3241567611694336,0.6582599878311157,1.6258599758148191,50000 -532.4613721370697,0.9581685066223145,14823.624761104584,43940,0,14823.624761104584,0.5181000232696533,2.357792377471924,10000,15358.438279390337,0.7344746589660645,1.301184892654419,0.6468999981880188,1.689677119255066,50000 -550.1557083129883,0.9941191673278807,15333.79275083542,45458,0,15333.79275083542,0.5314000248908997,2.2523021697998047,10000,15886.385159492493,0.7339963316917419,1.248931646347046,0.6551799774169922,1.5993354320526123,50000 -567.648931980133,1.0327842235565186,15843.919181346891,46975,0,15843.919181346891,0.5320000052452087,2.250394344329834,10000,16414.092190027237,0.7340162396430969,1.2588564157485962,0.6627399921417236,1.5765480995178225,50000 -585.4615631103516,1.0696442127227783,16353.850271701813,48491,0,16353.850271701813,0.5386000275611877,2.2234253883361816,10000,16941.92165875435,0.7350525856018066,1.2592726945877075,0.665619969367981,1.5760258436203003,50000 -603.4375021457672,1.108870029449463,16863.88741350174,50008,0,16863.88741350174,0.5308000445365906,2.27528715133667,10000,17470.022649526596,0.7184112071990967,1.2996251583099363,0.6577799916267395,1.5904572010040283,50000 -621.0799465179443,1.1506271362304688,17374.02089715004,51525,0,17374.02089715004,0.5425000190734863,2.1815757751464844,10000,17997.889188289642,0.7494618892669678,1.2071516513824463,0.6755399703979492,1.532850742340088,50000 -638.7280640602112,1.1915946006774902,17884.068472623825,53042,0,17884.068472623825,0.536300003528595,2.2342357635498047,10000,18525.674897909164,0.7538065910339355,1.20997416973114,0.6650999784469604,1.5931397676467896,50000 -656.6700406074524,1.2295169830322266,18394.20597648621,54560,0,18394.20597648621,0.5281000137329102,2.285375595092773,10000,19053.84075427056,0.7381616830825806,1.2561745643615725,0.659559965133667,1.598878264427185,50000 -674.2945744991302,1.2716209888458252,18904.25890445709,56077,0,18904.25890445709,0.5379000306129456,2.216094732284546,10000,19581.60926795005,0.7382214665412903,1.2343106269836426,0.6646400094032288,1.5662380456924438,50000 -691.946305513382,1.3120112419128418,19414.452545642853,57594,0,19414.452545642853,0.5416000485420227,2.204932928085327,10000,20109.54371070861,0.7438416481018066,1.227565050125122,0.6708399653434753,1.5553630590438845,50000 -709.6492412090302,1.3549623489379885,19924.543339967728,59112,0,19924.543339967728,0.52920001745224,2.269146203994751,10000,20637.42899608612,0.7326809763908386,1.2575324773788452,0.6616599559783936,1.569987416267395,50000 -727.0632431507111,1.396129131317139,20434.725960969925,60630,0,20434.725960969925,0.5357000231742859,2.234541177749634,10000,21165.11572289467,0.7663623690605164,1.119907021522522,0.6607800126075745,1.5795252323150637,50000 -744.6991124153137,1.4345617294311523,20944.98304605484,62148,0,20944.98304605484,0.5370000004768372,2.229222059249878,10000,21693.09570145607,0.7565170526504517,1.1901875734329224,0.6680799722671509,1.5773533582687378,50000 -762.3969528675079,1.4752840995788574,21455.00278377533,63665,0,21455.00278377533,0.5432000160217285,2.200866222381592,10000,22220.902873277664,0.7577726244926453,1.1620841026306152,0.6750199794769287,1.522871017456055,50000 -780.350729227066,1.5167927742004397,21964.96987080574,65182,0,21964.96987080574,0.5532000064849854,2.1577084064483643,10000,22748.91353034973,0.7588488459587097,1.1486854553222656,0.6837799549102783,1.4862542152404783,50000 -797.9648485183716,1.5609960556030271,22475.05553460121,66699,0,22475.05553460121,0.5546000003814697,2.189296245574951,10000,23276.70662856102,0.7502591013908386,1.213298797607422,0.6759399771690369,1.5442116260528564,50000 -815.652755022049,1.6021060943603516,22985.083032131195,68216,0,22985.083032131195,0.5452000498771667,2.196040153503418,10000,23804.51207590103,0.7473692297935486,1.19962477684021,0.6743800044059753,1.5287421941757202,50000 -833.3205726146698,1.642003059387207,23495.33017706871,69734,0,23495.33017706871,0.5418000221252441,2.243528127670288,10000,24332.515582323074,0.7683553695678711,1.1409224271774292,0.6607999801635742,1.5949475765228271,50000 -850.9681005477905,1.6843080520629885,24005.278024673466,71251,0,24005.278024673466,0.5508000254631042,2.167146682739258,10000,24860.202106952667,0.7674585580825806,1.128574013710022,0.6775799989700317,1.5199021100997925,50000 -868.7539627552032,1.7253928184509275,24515.25668501854,72768,0,24515.25668501854,0.5521000027656555,2.178276538848877,10000,25388.05660820008,0.7587292790412903,1.159325361251831,0.6759200096130371,1.519149661064148,50000 -886.672360420227,1.7654447555541992,25025.17741537094,74285,0,25025.17741537094,0.5637000203132629,2.108211040496826,10000,25915.98461484909,0.7687141299247742,1.109268307685852,0.6859599947929382,1.4653456211090088,50000 -904.7744419574738,1.8068821430206297,25535.3845744133,75803,0,25535.3845744133,0.5576000213623047,2.1966655254364014,10000,26444.38476872444,0.7549425959587097,1.1974339485168457,0.6801599860191345,1.5412228107452393,50000 -922.5809574127196,1.847849607467652,26045.36720395088,77320,0,26045.36720395088,0.5628000497817993,2.1257145404815674,10000,26972.264142274857,0.7669403553009033,1.1282862424850464,0.6887999773025513,1.4690423011779783,50000 -940.5901341438292,1.890838623046875,26555.412185668945,78837,0,26555.412185668945,0.5586000084877014,2.133809804916382,10000,27500.41028022766,0.7891621589660645,1.0352303981781006,0.684719979763031,1.4883527755737305,50000 -958.6321983337402,1.9472503662109373,27065.31203103065,80354,0,27065.31203103065,0.5583000183105469,2.126771688461304,10000,28028.457494974136,0.7771245241165161,1.0599044561386108,0.686739981174469,1.464746356010437,50000 -977.0377764701844,1.992159843444824,27575.38053822517,81870,0,27575.38053822517,0.55840003490448,2.112879991531372,10000,28557.02544283867,0.7718032598495483,1.0625839233398438,0.6845200061798096,1.44517719745636,50000 -994.7666764259338,2.0380008220672607,28085.509213924408,83388,0,28085.509213924408,0.5593000054359436,2.149788618087769,10000,29084.97811937332,0.7626355290412903,1.1467963457107544,0.6807599663734436,1.5086842775344849,50000 -1012.3052713871002,2.0864999294281006,28595.469446659088,84905,0,28595.469446659088,0.5598000288009644,2.1390833854675293,10000,29612.57447457313,0.7693120241165161,1.1221765279769895,0.6900999546051025,1.4688431024551392,50000 -1029.8799359798431,2.1302878856658936,29105.538024902344,86423,0,29105.538024902344,0.5702000260353088,2.054215669631958,10000,30140.310687065125,0.78125,1.0466645956039429,0.6961999535560608,1.407152771949768,50000 -1047.508416891098,2.1748242378234863,29615.679981470108,87940,0,29615.679981470108,0.5648000240325928,2.1412951946258545,10000,30668.1747674942,0.7981704473495483,1.0359044075012207,0.6924200057983398,1.4844554662704468,50000 -1065.0068988800049,2.219353675842285,30125.72606277466,89457,0,30125.72606277466,0.5681000351905823,2.101494789123535,10000,31195.812687158585,0.7861925959587097,1.0689496994018557,0.692359983921051,1.4794917106628418,50000 -1082.5966680049896,2.263647556304932,30635.79270672798,90974,0,30635.79270672798,0.5699000358581543,2.0710158348083496,10000,31723.56230640412,0.783621609210968,1.0466140508651731,0.6930199861526489,1.4299108982086182,50000 -1100.3993997573853,2.308422565460205,31145.91519737244,92491,0,31145.91519737244,0.5730000138282776,2.0835328102111816,10000,32251.58094215393,0.7875877022743225,1.0378481149673462,0.6992799639701843,1.4238126277923584,50000 -1118.535723209381,2.3523941040039062,31656.0486536026,94008,0,31656.0486536026,0.5737000107765198,2.070194721221924,10000,32779.9430668354,0.7824856638908386,1.0473655462265017,0.6983399987220764,1.4176557064056396,50000 -1136.5134971141815,2.400253295898437,32166.133952617645,95526,0,32166.133952617645,0.5761000514030457,2.0659029483795166,10000,33308.10327959061,0.8123804330825806,0.9565143585205078,0.6977799534797668,1.4305815696716309,50000 -1154.3426752090454,2.4465677738189697,32676.36093187332,97045,0,32676.36093187332,0.5766000151634216,2.0669610500335693,10000,33836.25459456444,0.7994260191917419,0.9897308945655824,0.6978600025177002,1.441440463066101,50000 -1171.936819076538,2.495965480804444,33186.46888136864,98562,0,33186.46888136864,0.5819000005722046,1.995455741882324,10000,34364.05491042137,0.8053650856018066,0.9407090544700624,0.7060999870300293,1.369234323501587,50000 -1189.5462565422058,2.544489622116089,33696.61943101883,100080,0,33696.61943101883,0.5737000107765198,2.07080340385437,10000,34891.91277551651,0.7974131107330322,1.000417709350586,0.7032999992370605,1.4029895067214966,50000 -1207.4842991828918,2.5927844047546387,34206.52481389046,101597,0,34206.52481389046,0.5823000073432922,2.036078691482544,10000,35419.85327458382,0.7987882494926453,0.9974828362464904,0.7042399644851685,1.3996078968048096,50000 -1225.1640086174011,2.6405186653137207,34716.570257902145,103115,0,34716.570257902145,0.5835000276565552,2.0236704349517822,10000,35947.67550730705,0.7995654940605164,0.9808319807052612,0.7089200019836426,1.38383686542511,50000 -1242.8338084220886,2.702409267425537,35226.643216609955,104632,0,35226.643216609955,0.581000030040741,2.044727802276612,10000,36475.528554201126,0.8382493257522583,0.8604050874710083,0.7122600078582764,1.3814303874969482,50000 -1260.3823611736298,2.751145124435425,35736.733736753464,106149,0,35736.733736753464,0.5835000276565552,2.0411343574523926,10000,37003.26521921158,0.8146125674247742,0.930175006389618,0.7082599997520447,1.390093445777893,50000 -1277.9910361766815,2.79943585395813,36246.702476263046,107667,0,36246.702476263046,0.5896000266075134,2.011181116104126,10000,37530.93986582756,0.8170041441917419,0.9116354584693908,0.7121999859809875,1.3599272966384888,50000 -1296.5142834186554,2.8462047576904297,36756.72316074371,109184,0,36756.72316074371,0.5759000182151794,2.094308614730835,10000,38059.57973694801,0.8053650856018066,0.9917887449264526,0.7057399749755859,1.4216396808624268,50000 -1314.1434371471405,2.8875298500061035,37266.88493037224,110701,0,37266.88493037224,0.5900000333786011,1.984406352043152,10000,38587.46114182472,0.8196946382522583,0.9097425937652588,0.7184199690818787,1.3392237424850464,50000 -1332.1986136436462,2.9358580112457275,37777.11248350144,112219,0,37777.11248350144,0.5934000015258789,1.9667091369628904,10000,39115.8409371376,0.8258330225944519,0.8885695338249207,0.7242000102996826,1.3239599466323853,50000 -1349.967808008194,2.9852118492126465,38287.02067565918,113736,0,38287.02067565918,0.5962000489234924,1.992884516716004,10000,39643.61619019508,0.8484932780265808,0.8184272646903992,0.7218599915504456,1.3505167961120603,50000 -1367.7323701381683,3.036560297012329,38796.95326471329,115253,0,38796.95326471329,0.5914000272750854,2.007413864135742,10000,40171.41336941719,0.8364556431770325,0.8760803937911987,0.7188000082969666,1.3691059350967407,50000 -1385.2446112632751,3.0897364616394043,39307.07882666588,116770,0,39307.07882666588,0.5920000076293945,1.972323298454285,10000,40699.153477191925,0.8357780575752258,0.8407694101333618,0.7210800051689148,1.322500228881836,50000 -1403.0130491256714,3.140377759933472,39817.02921366692,118288,0,39817.02921366692,0.5981000065803528,1.9613901376724243,10000,41226.97157096863,0.8320910334587097,0.842000424861908,0.7231999635696411,1.3130041360855105,50000 -1421.460168838501,3.1964597702026367,40327.05902671814,119806,0,40327.05902671814,0.6025000214576721,1.9448192119598389,10000,41755.55356431008,0.8359375,0.8376394510269165,0.7249799966812134,1.3017497062683103,50000 -1438.906097650528,3.2463817596435547,40836.98566198349,121323,0,40836.98566198349,0.5984000563621521,1.9204200506210327,10000,42283.02461075783,0.8392458558082581,0.7929951548576355,0.7260000109672546,1.2777409553527832,50000 -1456.579344749451,3.2976536750793457,41347.065252542496,122840,0,41347.065252542496,0.6027000546455383,1.939791440963745,10000,42810.877432107925,0.856465220451355,0.7566264867782593,0.7263199687004089,1.3030215501785278,50000 -1474.6264395713806,3.3471579551696777,41857.15559768677,124357,0,41857.15559768677,0.6040000319480896,1.9366973638534544,10000,43339.1133646965,0.8516820669174194,0.7917265295982361,0.7299000024795532,1.3059499263763428,50000 -1492.354014635086,3.394263982772827,42367.3814136982,125875,0,42367.3814136982,0.6016000509262085,1.963367819786072,10000,43867.16249752045,0.8467992544174194,0.8255721926689148,0.7300599813461304,1.3159984350204468,50000 -1510.0505783557892,3.447601556777954,42877.36577987671,127392,0,42877.36577987671,0.6080000400543213,1.9545058012008667,10000,44394.94526076317,0.8463209271430969,0.8251389861106873,0.7315599918365479,1.3149410486221311,50000 -1527.7779560089111,3.500355958938598,43387.268907785416,128909,0,43387.268907785416,0.6148000359535217,1.911783456802368,10000,44922.67690491676,0.8551897406578064,0.7878064513206482,0.7348999977111816,1.2902276515960691,50000 -1545.6414897441864,3.5544705390930176,43897.26140189171,130426,0,43897.26140189171,0.6135000586509705,1.874533653259277,10000,45450.63581967354,0.8754384517669678,0.675972580909729,0.7377600073814392,1.2478464841842651,50000 -1563.994446992874,3.609476804733277,44407.33543777466,131942,0,44407.33543777466,0.610200047492981,1.9361129999160769,10000,45979.16612029076,0.8747608065605164,0.712469220161438,0.7366399765014648,1.283627986907959,50000 -1581.8467528820038,3.656475305557251,44917.43026852608,133460,0,44917.43026852608,0.617900013923645,1.9060282707214355,10000,46507.20842504501,0.872468888759613,0.7190256714820862,0.7376599907875061,1.2690998315811155,50000 -1599.4554841518402,3.710766077041626,45427.43486452103,134977,0,45427.43486452103,0.6167000532150269,1.884137749671936,10000,47034.92509889603,0.8716716766357422,0.7018413543701172,0.7393400073051453,1.2499456405639648,50000 -1617.0773251056671,3.766352891921997,45937.36474323273,136494,0,45937.36474323273,0.6141000390052795,1.8976376056671145,10000,47562.58168148994,0.8722297549247742,0.7077077031135559,0.7393400073051453,1.2572718858718872,50000 -1634.5576055049896,3.823425769805908,46447.27852892876,138011,0,46447.27852892876,0.6186000108718872,1.8796846866607664,10000,48090.08140492439,0.87015700340271,0.7002887725830078,0.7409999966621399,1.2513980865478516,50000 -1652.4860928058624,3.877878189086914,46957.439453840256,139529,0,46957.439453840256,0.6126000285148621,1.902232766151428,10000,48618.27424407005,0.8952487111091614,0.6294954419136047,0.7395399808883667,1.2603678703308103,50000 -1670.0611016750336,3.935212373733521,47467.50153207779,141046,0,47467.50153207779,0.6173000335693359,1.898269772529602,10000,49146.01747059822,0.8868981003761292,0.6444560289382935,0.7418599724769592,1.247159719467163,50000 -1687.8899908065796,3.9951000213623047,47977.68201828003,142564,0,47977.68201828003,0.6184000372886658,1.8916077613830569,10000,49674.135682582855,0.8874959945678711,0.652847409248352,0.741599977016449,1.251569747924805,50000 -1705.4653916358948,4.055452346801758,48487.624255895615,144081,0,48487.624255895615,0.6184000372886658,1.899420022964477,10000,50201.76250863075,0.8844068646430969,0.6725389957427979,0.7425000071525574,1.2619949579238892,50000 -1723.1903104782104,4.112022161483765,48997.56184363365,145597,0,48997.56184363365,0.6247000098228455,1.8494324684143064,10000,50729.53057074547,0.893973171710968,0.6204026341438293,0.750499963760376,1.2134588956832886,50000 -1740.7729632854462,4.185078859329224,49507.54261040688,147114,0,49507.54261040688,0.6261000037193298,1.849046230316162,10000,51257.21557235718,0.8909637928009033,0.6227953433990479,0.7475199699401855,1.2216311693191528,50000 -1758.690044879913,4.2591400146484375,50017.68539023399,148632,0,50017.68539023399,0.6324000358581543,1.8357360363006592,10000,51785.398183107376,0.9108936190605164,0.5526172518730164,0.7495799660682678,1.2158610820770264,50000 -1776.3126661777496,4.313075065612793,50527.5884706974,150148,0,50527.5884706974,0.6328000426292419,1.8467143774032595,10000,52313.0264942646,0.906887710094452,0.5834668278694153,0.751039981842041,1.21971595287323,50000 -1794.0457956790924,4.367350816726685,51037.737275362015,151666,0,51037.737275362015,0.6306000351905823,1.8276084661483765,10000,52841.0115032196,0.906628668308258,0.5731101632118225,0.753279983997345,1.204783320426941,50000 -1811.6900601387024,4.42540454864502,51547.88285636902,153183,0,51547.88285636902,0.6339000463485718,1.8135104179382324,10000,53368.90864777565,0.9112324714660645,0.56153404712677,0.7550399899482727,1.1990495920181274,50000 -1829.5243427753448,4.483033180236816,52058.02461576462,154701,0,52058.02461576462,0.6313000321388245,1.840086817741394,10000,53896.99075818062,0.9130659699440002,0.5653855800628662,0.7539599537849426,1.209978461265564,50000 -1847.399951696396,4.56777286529541,52568.0820877552,156219,0,52568.0820877552,0.6291000247001648,1.8272104263305664,10000,54425.057513952255,0.9151187539100648,0.5467893481254578,0.7532999515533447,1.2041923999786377,50000 -1865.1478443145752,4.623232364654541,53078.011219501495,157736,0,53078.011219501495,0.6354000568389893,1.8231844902038568,10000,54952.83845996857,0.9242864847183228,0.5116089582443237,0.7549600005149841,1.1921504735946655,50000 -1883.357558965683,4.680333614349365,53588.11385130882,159253,0,53588.11385130882,0.6379000544548035,1.805983066558838,10000,55481.25697994232,0.9227519035339355,0.5066460371017456,0.7560999989509583,1.1838161945343018,50000 -1900.907883644104,4.739240646362305,54098.28979110718,160772,0,54098.28979110718,0.6315000057220459,1.8179877996444704,10000,56009.091359615326,0.9209582209587096,0.5264387726783752,0.7565599679946899,1.198883295059204,50000 -1918.6634063720703,4.798632621765137,54608.4274187088,162289,0,54608.4274187088,0.6361000537872314,1.8157532215118408,10000,56537.093678474426,0.9260801672935486,0.5113489627838135,0.7582199573516846,1.190921425819397,50000 -1936.598509311676,4.890327453613281,55118.39876580238,163806,0,55118.39876580238,0.6383000016212463,1.8151981830596924,10000,57065.14196181297,0.9243263602256776,0.5113154053688049,0.7582599520683289,1.1876296997070312,50000 -1954.2799079418185,4.950265169143677,55628.49181294441,165323,0,55628.49181294441,0.6380000114440918,1.814751029014588,10000,57593.02518582344,0.9312419891357422,0.4898253381252289,0.7599200010299683,1.1844065189361572,50000 -1971.945927143097,5.009655952453613,56138.64579749107,166840,0,56138.64579749107,0.6389000415802002,1.793910026550293,10000,58120.9536485672,0.9327367544174194,0.4756486117839813,0.7603200078010559,1.176532745361328,50000 -1989.7520174980164,5.06985878944397,56648.71825623512,168357,0,56648.71825623512,0.6421000361442566,1.8000293970108032,10000,58648.94144535065,0.9333944320678712,0.4795606434345245,0.7605199813842773,1.178816318511963,50000 -2007.657867193222,5.131593704223633,57158.62016701698,169874,0,57158.62016701698,0.6429000496864319,1.791094422340393,10000,59176.86003422737,0.9356465339660645,0.4672803878784179,0.7620399594306946,1.1741501092910769,50000 -2025.074841976165,5.192117214202881,57668.82580900192,171391,0,57668.82580900192,0.6449000239372253,1.793129324913025,10000,59704.59183359146,0.9360251426696776,0.4715806841850281,0.7620599865913391,1.177618145942688,50000 -2042.8664588928225,5.255893468856812,58178.83932805061,172908,0,58178.83932805061,0.640500009059906,1.8016676902771,10000,60232.509991168976,0.935546875,0.4717677533626556,0.7625799775123596,1.175244688987732,50000 -2060.579292535782,5.320931911468506,58688.77640485764,174424,0,58688.77640485764,0.6444000601768494,1.7961137294769287,10000,60760.27409243584,0.9393534660339355,0.4597503542900085,0.7628600001335144,1.1763155460357666,50000 -2078.3175649642944,5.382373809814453,59198.87612986565,175941,0,59198.87612986565,0.6434000134468079,1.7967631816864014,10000,61288.2229487896,0.939233899116516,0.4566811025142669,0.7627999782562256,1.177257061004639,50000 -2095.821103572845,5.4390709400177,59708.8821656704,177458,0,59708.8821656704,0.6452000141143799,1.7910445928573608,10000,61815.838076114655,0.9397919178009032,0.4564704298973083,0.7625600099563599,1.1731715202331543,50000 -2113.4979746341705,5.501019239425659,60219.034625291824,178976,0,60219.034625291824,0.6419000029563904,1.791907548904419,10000,62343.77858257294,0.9402303695678712,0.4584923982620239,0.7633799910545349,1.1745049953460691,50000 -2131.249956130981,5.567572593688965,60729.01041841507,180493,0,60729.01041841507,0.6457000374794006,1.7949339151382446,10000,62871.621817588806,0.9389349222183228,0.4585332870483398,0.7633999586105347,1.175327181816101,50000 -2148.8214042186737,5.627314567565918,61239.120810985565,182010,0,61239.120810985565,0.6449000239372253,1.789433240890503,10000,63399.412459373474,0.93949294090271,0.4553861320018768,0.7636399865150452,1.1701064109802246,50000 -2166.364020586014,5.686914920806885,61749.29934263229,183527,0,61749.29934263229,0.6444000601768494,1.7934223413467407,10000,63927.24230384827,0.9393534660339355,0.4550036489963531,0.7640799880027771,1.1725746393203735,50000 -2184.0561985969543,5.7481689453125,62259.41258692741,185044,0,62259.41258692741,0.6449000239372253,1.789334416389465,10000,64455.15799450874,0.9412667155265808,0.4489424228668213,0.763979971408844,1.1695361137390137,50000 -2201.5659034252167,5.813524484634399,62769.59384036064,186561,0,62769.59384036064,0.6455000042915344,1.7891464233398438,10000,64982.96338915825,0.9403699040412904,0.452570378780365,0.7641800045967102,1.1703591346740725,50000 -2219.1971287727356,5.877562046051025,62804.61339139938,186666,0,62804.61339139938,0.64410001039505,1.787272572517395,10000,65035.68202996254,0.9392936825752258,0.4516395628452301,0.7646999955177307,1.167160987854004,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/measurements.csv deleted file mode 100644 index 19a6f4ac1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1994 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.5329929,6.920775,,,,,,,,,,,,,, -1,,,0.0004783163312822,6.910229682922363,0.0009599999757483,6.910243988037109,50000.0,0.0006000000284984,6.910250186920166,10000.0,31.64558982849121,48.83828067779541,31.64558982849121,17.192532777786255,0.0,0.0 -100,0.5142984,6.9001083,,,,,,,,,,,,,, -200,0.5290886,6.8554277,,,,,,,,,,,,,, -300,0.59028065,6.7649374,,,,,,,,,,,,,, -400,0.61277163,6.6682243,,,,,,,,,,,,,, -500,0.64812905,6.5798864,,,,,,,,,,,,,, -600,0.7033887,6.543291,,,,,,,,,,,,,, -700,1.5174623,6.413944,,,,,,,,,,,,,, -800,1.3738877,6.302283,,,,,,,,,,,,,, -900,2.0624993,6.2393694,,,,,,,,,,,,,, -1000,1.3498743,6.2225704,,,,,,,,,,,,,, -1100,1.9189479,6.1569295,,,,,,,,,,,,,, -1200,1.4128162,6.1113524,,,,,,,,,,,,,, -1300,1.8241144,5.9920263,,,,,,,,,,,,,, -1400,3.4546568,5.942293,,,,,,,,,,,,,, -1500,2.077801,5.9008546,,,,,,,,,,,,,, -1514,,,0.0840242356061935,5.227553844451904,0.0806799978017807,5.2884297370910645,50000.0,0.055000003427267,5.537538528442383,10000.0,541.6560969352722,576.5695464611053,541.6560969352722,34.846673011779785,0.0188093185424804,0.0 -1600,2.5705738,5.8502455,,,,,,,,,,,,,, -1700,3.0496166,5.834356,,,,,,,,,,,,,, -1800,3.102287,5.7493496,,,,,,,,,,,,,, -1900,2.6538856,5.7278924,,,,,,,,,,,,,, -2000,4.594701,5.6174026,,,,,,,,,,,,,, -2100,2.8848546,5.5767937,,,,,,,,,,,,,, -2200,4.131198,5.547429,,,,,,,,,,,,,, -2300,2.6348245,5.5714884,,,,,,,,,,,,,, -2400,2.4428873,5.441094,,,,,,,,,,,,,, -2500,2.5140896,5.44436,,,,,,,,,,,,,, -2600,4.0546417,5.4254236,,,,,,,,,,,,,, -2700,4.261032,5.4207296,,,,,,,,,,,,,, -2800,3.3500729,5.352999,,,,,,,,,,,,,, -2900,3.4894447,5.317024,,,,,,,,,,,,,, -3000,4.2178493,5.151644,,,,,,,,,,,,,, -3026,,,0.199597418308258,4.143094062805176,0.1790200024843216,4.270390033721924,50000.0,0.1348000019788742,4.705352783203125,10000.0,1051.752347946167,1104.6250972747805,1051.752347946167,52.72998690605164,0.0470173358917236,0.0 -3100,5.3901725,5.228229,,,,,,,,,,,,,, -3200,3.2450285,5.137745,,,,,,,,,,,,,, -3300,2.227993,5.141033,,,,,,,,,,,,,, -3400,5.7631264,5.1335135,,,,,,,,,,,,,, -3500,4.1648755,5.020753,,,,,,,,,,,,,, -3600,5.2285867,4.9944286,,,,,,,,,,,,,, -3700,3.2501023,4.939705,,,,,,,,,,,,,, -3800,4.570208,4.8324885,,,,,,,,,,,,,, -3900,5.7761197,4.8931007,,,,,,,,,,,,,, -4000,4.1255674,4.8974915,,,,,,,,,,,,,, -4100,3.2145367,4.867345,,,,,,,,,,,,,, -4200,4.340966,4.8177624,,,,,,,,,,,,,, -4300,7.455496,4.8567867,,,,,,,,,,,,,, -4400,2.8838506,4.7786946,,,,,,,,,,,,,, -4500,5.677103,4.689653,,,,,,,,,,,,,, -4538,,,0.3047672212123871,3.462144136428833,0.2767399847507477,3.6103830337524414,50000.0,0.2043000161647796,4.173789024353027,10000.0,1561.9691922664642,1633.4822070598602,1561.9691922664642,71.29328966140747,0.0761914253234863,0.0 -4600,4.3045845,4.6523623,,,,,,,,,,,,,, -4700,4.397726,4.659617,,,,,,,,,,,,,, -4800,3.7682989,4.6113515,,,,,,,,,,,,,, -4900,3.6233974,4.630513,,,,,,,,,,,,,, -5000,3.926668,4.5374146,,,,,,,,,,,,,, -5100,2.4813452,4.483537,,,,,,,,,,,,,, -5200,6.0025773,4.491422,,,,,,,,,,,,,, -5300,4.0567627,4.543038,,,,,,,,,,,,,, -5400,3.4377253,4.5155563,,,,,,,,,,,,,, -5500,3.644877,4.43493,,,,,,,,,,,,,, -5600,3.454095,4.478875,,,,,,,,,,,,,, -5700,2.8381164,4.380797,,,,,,,,,,,,,, -5800,4.7541666,4.434137,,,,,,,,,,,,,, -5900,3.8845704,4.3857875,,,,,,,,,,,,,, -6000,4.120787,4.3305116,,,,,,,,,,,,,, -6051,,,0.3840282261371612,2.9529519081115723,0.3597399890422821,3.094580173492432,50000.0,0.2770000100135803,3.678683042526245,10000.0,2072.226632356644,2161.6236131191254,2072.226632356644,89.09815335273743,0.1078622341156005,0.0 -6100,2.4832888,4.28971,,,,,,,,,,,,,, -6200,3.9970987,4.2722206,,,,,,,,,,,,,, -6300,3.385373,4.331168,,,,,,,,,,,,,, -6400,3.6025305,4.3224607,,,,,,,,,,,,,, -6500,2.8914185,4.318966,,,,,,,,,,,,,, -6600,2.2950459,4.2434583,,,,,,,,,,,,,, -6700,2.645514,4.2972145,,,,,,,,,,,,,, -6800,3.500608,4.2070875,,,,,,,,,,,,,, -6900,3.312379,4.1972494,,,,,,,,,,,,,, -7000,3.795389,4.205227,,,,,,,,,,,,,, -7100,4.512348,4.201164,,,,,,,,,,,,,, -7200,3.2093163,4.251354,,,,,,,,,,,,,, -7300,1.9303346,4.133174,,,,,,,,,,,,,, -7400,2.7437508,4.157182,,,,,,,,,,,,,, -7500,2.5938723,4.1474943,,,,,,,,,,,,,, -7563,,,0.4412667453289032,2.6990714073181152,0.4002600014209747,2.899838447570801,50000.0,0.3039000034332275,3.5461227893829346,10000.0,2582.2532529830933,2689.541732311249,2582.2532529830933,106.90979647636414,0.140455961227417,0.0 -7600,3.9009998,4.128934,,,,,,,,,,,,,, -7700,2.2752469,4.0656743,,,,,,,,,,,,,, -7800,3.1460295,4.0581536,,,,,,,,,,,,,, -7900,3.891442,4.0954146,,,,,,,,,,,,,, -8000,2.850617,4.0425935,,,,,,,,,,,,,, -8100,2.5896487,4.0568695,,,,,,,,,,,,,, -8200,1.6386756,4.043369,,,,,,,,,,,,,, -8300,2.5398624,3.9973679,,,,,,,,,,,,,, -8400,2.52665,4.0693645,,,,,,,,,,,,,, -8500,2.3394332,3.9548173,,,,,,,,,,,,,, -8600,2.8965836,3.982637,,,,,,,,,,,,,, -8700,2.213532,3.9524746,,,,,,,,,,,,,, -8800,2.380564,3.9879265,,,,,,,,,,,,,, -8900,2.499839,3.9765239,,,,,,,,,,,,,, -9000,3.790597,4.0251875,,,,,,,,,,,,,, -9077,,,0.5027104616165161,2.3729019165039062,0.452919989824295,2.6144769191741943,50000.0,0.3500000238418579,3.250772476196289,10000.0,3092.440190553665,3217.438376188278,3092.440190553665,124.54310297966003,0.1691508293151855,0.0 -9100,2.6172037,3.8468373,,,,,,,,,,,,,, -9200,2.4059517,3.8832607,,,,,,,,,,,,,, -9300,3.59445,3.9690242,,,,,,,,,,,,,, -9400,2.832182,3.9127848,,,,,,,,,,,,,, -9500,2.6590612,3.850317,,,,,,,,,,,,,, -9600,2.4086814,3.931879,,,,,,,,,,,,,, -9700,2.6804698,3.9083302,,,,,,,,,,,,,, -9800,2.7995653,3.8277917,,,,,,,,,,,,,, -9900,2.164042,3.8256779,,,,,,,,,,,,,, -10000,1.9430019,3.8416817,,,,,,,,,,,,,, -10100,2.3368094,3.819652,,,,,,,,,,,,,, -10200,1.6291173,3.9235215,,,,,,,,,,,,,, -10300,2.0728807,3.856732,,,,,,,,,,,,,, -10400,2.1316915,3.8570242,,,,,,,,,,,,,, -10500,1.868098,3.7980802,,,,,,,,,,,,,, -10591,,,0.5700932741165161,2.06361722946167,0.5237799882888794,2.2730724811553955,50000.0,0.4058000147342682,2.908501148223877,10000.0,3602.418808460236,3745.6586899757385,3602.418808460236,142.70739150047302,0.1985282897949218,0.0 -10600,1.9822761,3.705033,,,,,,,,,,,,,, -10700,1.7569,3.7656243,,,,,,,,,,,,,, -10800,1.8056644,3.842495,,,,,,,,,,,,,, -10900,2.5785992,3.716211,,,,,,,,,,,,,, -11000,1.9918677,3.7156177,,,,,,,,,,,,,, -11100,2.0957057,3.7581146,,,,,,,,,,,,,, -11200,2.049396,3.6425586,,,,,,,,,,,,,, -11300,1.4761354,3.7347038,,,,,,,,,,,,,, -11400,1.8824708,3.7718992,,,,,,,,,,,,,, -11500,1.6955446,3.6968002,,,,,,,,,,,,,, -11600,1.8025264,3.8410692,,,,,,,,,,,,,, -11700,2.1115327,3.7056985,,,,,,,,,,,,,, -11800,1.7000954,3.6511068,,,,,,,,,,,,,, -11900,1.6949332,3.6859465,,,,,,,,,,,,,, -12000,2.0966945,3.6167736,,,,,,,,,,,,,, -12100,1.6238754,3.5507076,,,,,,,,,,,,,, -12105,,,0.5975366830825806,1.9064337015151973,0.5470799803733826,2.1341938972473145,50000.0,0.4216000139713287,2.831441640853882,10000.0,4112.355051517487,4273.334962368012,4112.355051517487,160.36835479736328,0.2296433448791504,0.0 -12200,1.3364317,3.553541,,,,,,,,,,,,,, -12300,1.5221012,3.554859,,,,,,,,,,,,,, -12400,1.9035159,3.6783504,,,,,,,,,,,,,, -12500,2.181125,3.6672142,,,,,,,,,,,,,, -12600,1.8005174,3.6270685,,,,,,,,,,,,,, -12700,1.7545755,3.5810485,,,,,,,,,,,,,, -12800,1.4178866,3.602148,,,,,,,,,,,,,, -12900,1.8234706,3.631659,,,,,,,,,,,,,, -13000,1.3985083,3.698904,,,,,,,,,,,,,, -13100,2.0063736,3.6698499,,,,,,,,,,,,,, -13200,1.3381014,3.5283144,,,,,,,,,,,,,, -13300,1.6038805,3.573575,,,,,,,,,,,,,, -13400,1.5729443,3.6446683,,,,,,,,,,,,,, -13500,1.4788879,3.524342,,,,,,,,,,,,,, -13600,1.6841189,3.5442176,,,,,,,,,,,,,, -13620,,,0.615632951259613,1.8004908561706543,0.5709399580955505,2.0190188884735107,50000.0,0.4493000209331512,2.669410228729248,10000.0,4622.293032407761,4800.818227529526,4622.293032407761,177.8351345062256,0.2601957321166992,0.0 -13700,1.5666579,3.5891275,,,,,,,,,,,,,, -13800,1.7900366,3.531038,,,,,,,,,,,,,, -13900,1.7016798,3.6001744,,,,,,,,,,,,,, -14000,1.6363806,3.6251361,,,,,,,,,,,,,, -14100,1.8694508,3.675049,,,,,,,,,,,,,, -14200,1.9246894,3.6913254,,,,,,,,,,,,,, -14300,2.1325371,3.5713058,,,,,,,,,,,,,, -14400,1.6073534,3.5039866,,,,,,,,,,,,,, -14500,1.5013305,3.5011222,,,,,,,,,,,,,, -14600,1.7415622,3.5921392,,,,,,,,,,,,,, -14700,1.6281989,3.684927,,,,,,,,,,,,,, -14800,1.5748689,3.4688444,,,,,,,,,,,,,, -14900,1.7221854,3.4402797,,,,,,,,,,,,,, -15000,1.4200892,3.523467,,,,,,,,,,,,,, -15100,1.3205806,3.5758607,,,,,,,,,,,,,, -15135,,,0.6227080821990967,1.854886531829834,0.5773000121116638,2.06612229347229,50000.0,0.4436000287532806,2.772144794464112,10000.0,5132.363230466843,5328.756629705429,5132.363230466843,195.62596201896667,0.2902054786682129,0.0 -15200,1.5443032,3.5636272,,,,,,,,,,,,,, -15300,1.2453654,3.526633,,,,,,,,,,,,,, -15400,1.5273204,3.4726162,,,,,,,,,,,,,, -15500,1.5498677,3.558669,,,,,,,,,,,,,, -15600,1.5751405,3.5949125,,,,,,,,,,,,,, -15700,1.3031582,3.4385917,,,,,,,,,,,,,, -15800,1.2004144,3.4628954,,,,,,,,,,,,,, -15900,1.5575879,3.5135975,,,,,,,,,,,,,, -16000,2.1331506,3.464448,,,,,,,,,,,,,, -16100,1.6732254,3.4755356,,,,,,,,,,,,,, -16200,2.1034787,3.5464742,,,,,,,,,,,,,, -16300,1.3755659,3.493001,,,,,,,,,,,,,, -16400,1.4389987,3.4545462,,,,,,,,,,,,,, -16500,1.2508142,3.4896054,,,,,,,,,,,,,, -16600,1.4075367,3.4437597,,,,,,,,,,,,,, -16650,,,0.6803053021430969,1.539376974105835,0.6041600108146667,1.868208289146424,50000.0,0.4802000224590301,2.534341096878052,10000.0,5642.407160997391,5856.522654294968,5642.407160997391,213.26491689682007,0.3256824016571045,0.0 -16700,1.1861953,3.4462044,,,,,,,,,,,,,, -16800,1.3835166,3.5060844,,,,,,,,,,,,,, -16900,1.7445966,3.4261627,,,,,,,,,,,,,, -17000,1.7251418,3.5269842,,,,,,,,,,,,,, -17100,1.3732264,3.447615,,,,,,,,,,,,,, -17200,1.4014268,3.4750373,,,,,,,,,,,,,, -17300,1.2181332,3.4292421,,,,,,,,,,,,,, -17400,1.6514289,3.4825444,,,,,,,,,,,,,, -17500,1.2629032,3.4127502,,,,,,,,,,,,,, -17600,1.1953363,3.4594436,,,,,,,,,,,,,, -17700,1.5579867,3.4603574,,,,,,,,,,,,,, -17800,1.1734895,3.4932718,,,,,,,,,,,,,, -17900,1.3974684,3.4009366,,,,,,,,,,,,,, -18000,1.4012809,3.4228458,,,,,,,,,,,,,, -18100,1.3150643,3.3499038,,,,,,,,,,,,,, -18165,,,0.6703802347183228,1.531219720840454,0.5999000072479248,1.8504778146743768,50000.0,0.4698000252246856,2.5441064834594727,10000.0,6152.345567941666,6384.078463315964,6152.345567941666,230.8041076660156,0.3559126853942871,0.0 -18200,1.6546787,3.4524345,,,,,,,,,,,,,, -18300,1.9533113,3.3808696,,,,,,,,,,,,,, -18400,1.357933,3.334373,,,,,,,,,,,,,, -18500,1.4668144,3.427084,,,,,,,,,,,,,, -18600,1.1781367,3.4358594,,,,,,,,,,,,,, -18700,1.2741504,3.3908515,,,,,,,,,,,,,, -18800,1.2493999,3.4768755,,,,,,,,,,,,,, -18900,1.7527839,3.4065006,,,,,,,,,,,,,, -19000,1.2360318,3.3348262,,,,,,,,,,,,,, -19100,1.4199747,3.3555706,,,,,,,,,,,,,, -19200,1.4449904,3.4217408,,,,,,,,,,,,,, -19300,1.3546638,3.4555955,,,,,,,,,,,,,, -19400,1.3413953,3.4061394,,,,,,,,,,,,,, -19500,1.2297329,3.3556376,,,,,,,,,,,,,, -19600,1.235578,3.3384495,,,,,,,,,,,,,, -19681,,,0.6851084232330322,1.5148324966430664,0.6193000078201294,1.7937859296798706,50000.0,0.4904000163078308,2.440124988555908,10000.0,6662.31097984314,6911.979278564453,6662.31097984314,248.65882778167725,0.3883523941040039,0.0 -19700,1.3077668,3.4536085,,,,,,,,,,,,,, -19800,1.4239614,3.4218316,,,,,,,,,,,,,, -19900,1.516628,3.3594902,,,,,,,,,,,,,, -20000,1.1239126,3.439583,,,,,,,,,,,,,, -20100,1.1641282,3.3929188,,,,,,,,,,,,,, -20200,1.308635,3.5064538,,,,,,,,,,,,,, -20300,1.2951595,3.345814,,,,,,,,,,,,,, -20400,1.3430996,3.3744805,,,,,,,,,,,,,, -20500,1.7307305,3.3731604,,,,,,,,,,,,,, -20600,1.6239358,3.3369722,,,,,,,,,,,,,, -20700,1.4288422,3.3718007,,,,,,,,,,,,,, -20800,1.7475624,3.488295,,,,,,,,,,,,,, -20900,1.5667171,3.3656707,,,,,,,,,,,,,, -21000,1.2987723,3.4257514,,,,,,,,,,,,,, -21100,1.5999175,3.4148884,,,,,,,,,,,,,, -21197,,,0.6892338991165161,1.42905855178833,0.627020001411438,1.7129013538360596,50000.0,0.4988000094890594,2.369382619857788,10000.0,7172.2588493824005,7440.037328243256,7172.2588493824005,266.68954062461853,0.4194791316986084,0.0 -21200,1.4342017,3.3254492,,,,,,,,,,,,,, -21300,1.3957314,3.3527362,,,,,,,,,,,,,, -21400,1.2970052,3.292965,,,,,,,,,,,,,, -21500,1.5190036,3.3689482,,,,,,,,,,,,,, -21600,1.3075783,3.3489466,,,,,,,,,,,,,, -21700,1.2385787,3.368726,,,,,,,,,,,,,, -21800,1.4260156,3.316865,,,,,,,,,,,,,, -21900,1.2538593,3.29966,,,,,,,,,,,,,, -22000,1.3238925,3.2463148,,,,,,,,,,,,,, -22100,1.5201663,3.3216379,,,,,,,,,,,,,, -22200,1.2907774,3.354905,,,,,,,,,,,,,, -22300,1.3953265,3.3500571,,,,,,,,,,,,,, -22400,1.5999659,3.3584547,,,,,,,,,,,,,, -22500,1.3397617,3.2959101,,,,,,,,,,,,,, -22600,1.4702237,3.2951293,,,,,,,,,,,,,, -22700,1.810478,3.3740888,,,,,,,,,,,,,, -22713,,,0.684012234210968,1.504253387451172,0.6255599856376648,1.766654372215271,50000.0,0.4872000217437744,2.462309598922729,10000.0,7682.28437256813,7967.854699373245,7682.28437256813,284.39879989624023,0.4545462131500244,0.0 -22800,1.4939449,3.2927573,,,,,,,,,,,,,, -22900,1.434183,3.2796593,,,,,,,,,,,,,, -23000,1.3718485,3.3473911,,,,,,,,,,,,,, -23100,1.6246545,3.482356,,,,,,,,,,,,,, -23200,1.2885597,3.4060867,,,,,,,,,,,,,, -23300,1.7692211,3.3316956,,,,,,,,,,,,,, -23400,1.3510528,3.2816265,,,,,,,,,,,,,, -23500,1.2163706,3.2899158,,,,,,,,,,,,,, -23600,1.2806885,3.3880837,,,,,,,,,,,,,, -23700,1.6704539,3.302158,,,,,,,,,,,,,, -23800,1.4444244,3.3597898,,,,,,,,,,,,,, -23900,1.487505,3.329415,,,,,,,,,,,,,, -24000,1.2601919,3.2767751,,,,,,,,,,,,,, -24100,1.3093666,3.3465261,,,,,,,,,,,,,, -24200,1.4066945,3.393867,,,,,,,,,,,,,, -24229,,,0.6905492544174194,1.4637855291366575,0.6354599595069885,1.7087171077728271,50000.0,0.5049999952316284,2.389516592025757,10000.0,8192.348526477814,8495.57875084877,8192.348526477814,301.9742715358734,0.4906270503997803,0.0 -24300,1.360961,3.3258004,,,,,,,,,,,,,, -24400,1.5101805,3.342997,,,,,,,,,,,,,, -24500,1.2572109,3.3431456,,,,,,,,,,,,,, -24600,1.1390395,3.2844253,,,,,,,,,,,,,, -24700,1.5010175,3.378875,,,,,,,,,,,,,, -24800,1.3711184,3.3189187,,,,,,,,,,,,,, -24900,1.2491783,3.3106413,,,,,,,,,,,,,, -25000,1.4164093,3.3525028,,,,,,,,,,,,,, -25100,1.1977544,3.2400863,,,,,,,,,,,,,, -25200,1.2201208,3.2863262,,,,,,,,,,,,,, -25300,1.1966432,3.2765148,,,,,,,,,,,,,, -25400,1.6946633,3.303173,,,,,,,,,,,,,, -25500,1.5188773,3.250499,,,,,,,,,,,,,, -25600,1.1453044,3.2728436,,,,,,,,,,,,,, -25700,1.2247227,3.2283714,,,,,,,,,,,,,, -25745,,,0.7331194281578064,1.299314022064209,0.6287800073623657,1.7513171434402466,50000.0,0.5076000094413757,2.4090499877929688,10000.0,8702.524030208588,9023.23773431778,8702.524030208588,319.3741111755371,0.5262205600738525,0.0 -25800,1.437329,3.3389535,,,,,,,,,,,,,, -25900,1.4339492,3.3233058,,,,,,,,,,,,,, -26000,1.349757,3.3118837,,,,,,,,,,,,,, -26100,1.4631621,3.3359559,,,,,,,,,,,,,, -26200,1.3079575,3.2347112,,,,,,,,,,,,,, -26300,1.4514904,3.3361478,,,,,,,,,,,,,, -26400,1.3830531,3.3233569,,,,,,,,,,,,,, -26500,1.4366707,3.3163314,,,,,,,,,,,,,, -26600,1.4476783,3.3515234,,,,,,,,,,,,,, -26700,1.2175709,3.3284762,,,,,,,,,,,,,, -26800,1.2346222,3.2936263,,,,,,,,,,,,,, -26900,1.4488527,3.2456844,,,,,,,,,,,,,, -27000,1.1739887,3.1607893,,,,,,,,,,,,,, -27100,1.3903091,3.3987927,,,,,,,,,,,,,, -27200,1.2465442,3.2111287,,,,,,,,,,,,,, -27261,,,0.7091438174247742,1.3814681768417358,0.6342599987983704,1.7121782302856443,50000.0,0.510200023651123,2.3702125549316406,10000.0,9212.554991006851,9551.120630264282,9212.554991006851,337.14146423339844,0.562938928604126,0.0 -27300,1.3116304,3.2321672,,,,,,,,,,,,,, -27400,1.2247863,3.2879639,,,,,,,,,,,,,, -27500,1.4195215,3.2987275,,,,,,,,,,,,,, -27600,1.2603302,3.285247,,,,,,,,,,,,,, -27700,1.3100725,3.2994323,,,,,,,,,,,,,, -27800,1.5272442,3.3266084,,,,,,,,,,,,,, -27900,1.4277242,3.2779708,,,,,,,,,,,,,, -28000,1.3379636,3.289972,,,,,,,,,,,,,, -28100,1.3801438,3.3763793,,,,,,,,,,,,,, -28200,1.2243533,3.2643468,,,,,,,,,,,,,, -28300,1.5302323,3.2564664,,,,,,,,,,,,,, -28400,1.756814,3.3306928,,,,,,,,,,,,,, -28500,1.4058245,3.246001,,,,,,,,,,,,,, -28600,1.2346827,3.2029595,,,,,,,,,,,,,, -28700,1.244756,3.2765396,,,,,,,,,,,,,, -28777,,,0.7155014276504517,1.364587664604187,0.6462999582290649,1.6733760833740234,50000.0,0.5174000263214111,2.340773820877075,10000.0,9722.621646165848,10078.848618268968,9722.621646165848,354.7211463451385,0.5966382026672363,0.0 -28800,1.4564668,3.2890077,,,,,,,,,,,,,, -28900,1.4423323,3.3012867,,,,,,,,,,,,,, -29000,1.6250304,3.2624774,,,,,,,,,,,,,, -29100,1.2154744,3.223061,,,,,,,,,,,,,, -29200,1.3846482,3.2323534,,,,,,,,,,,,,, -29300,1.2463404,3.172532,,,,,,,,,,,,,, -29400,1.8127089,3.2527692,,,,,,,,,,,,,, -29500,1.2690604,3.299761,,,,,,,,,,,,,, -29600,1.2443448,3.251122,,,,,,,,,,,,,, -29700,1.3065242,3.2719586,,,,,,,,,,,,,, -29800,1.5257875,3.2546473,,,,,,,,,,,,,, -29900,1.4996513,3.2625704,,,,,,,,,,,,,, -30000,1.3123577,3.3122308,,,,,,,,,,,,,, -30100,1.6772285,3.2453427,,,,,,,,,,,,,, -30200,1.5671818,3.2388253,,,,,,,,,,,,,, -30293,,,0.7161391973495483,1.3571972846984863,0.6464399695396423,1.6608171463012695,50000.0,0.5148000121116638,2.338981866836548,10000.0,10232.737368822098,10606.741770744324,10232.737368822098,372.41778016090393,0.6293659210205078,0.0 -30300,1.5376334,3.3329067,,,,,,,,,,,,,, -30400,1.4672863,3.2296176,,,,,,,,,,,,,, -30500,1.3045275,3.2458565,,,,,,,,,,,,,, -30600,1.4143264,3.2330582,,,,,,,,,,,,,, -30700,1.3136097,3.3205142,,,,,,,,,,,,,, -30800,1.4036644,3.238543,,,,,,,,,,,,,, -30900,1.230583,3.2867386,,,,,,,,,,,,,, -31000,1.4051605,3.2268968,,,,,,,,,,,,,, -31100,1.3774364,3.3316212,,,,,,,,,,,,,, -31200,1.4354473,3.183695,,,,,,,,,,,,,, -31300,1.3692701,3.2091725,,,,,,,,,,,,,, -31400,1.2719495,3.262236,,,,,,,,,,,,,, -31500,1.3789423,3.202118,,,,,,,,,,,,,, -31600,1.6460512,3.1993835,,,,,,,,,,,,,, -31700,1.2672371,3.2491524,,,,,,,,,,,,,, -31800,1.4527292,3.3321655,,,,,,,,,,,,,, -31809,,,0.70511794090271,1.4040005207061768,0.64028000831604,1.698283553123474,50000.0,0.5184000134468079,2.3254666328430176,10000.0,10742.950670957563,11134.55994296074,10742.950670957563,389.93855023384094,0.665184497833252,0.0 -31900,1.5110054,3.252922,,,,,,,,,,,,,, -32000,1.4368469,3.330002,,,,,,,,,,,,,, -32100,1.2699258,3.2010705,,,,,,,,,,,,,, -32200,1.5805508,3.2182958,,,,,,,,,,,,,, -32300,1.5324161,3.1727219,,,,,,,,,,,,,, -32400,1.3531667,3.1865656,,,,,,,,,,,,,, -32500,1.4562817,3.2767549,,,,,,,,,,,,,, -32600,1.5765687,3.2217736,,,,,,,,,,,,,, -32700,1.4789366,3.2132924,,,,,,,,,,,,,, -32800,1.422169,3.26182,,,,,,,,,,,,,, -32900,1.341327,3.1728761,,,,,,,,,,,,,, -33000,1.4446043,3.2212148,,,,,,,,,,,,,, -33100,1.4120486,3.2770638,,,,,,,,,,,,,, -33200,1.4188209,3.2258592,,,,,,,,,,,,,, -33300,1.376284,3.2399492,,,,,,,,,,,,,, -33325,,,0.7089444994926453,1.3775570392608645,0.6474399566650391,1.6512587070465088,50000.0,0.518500030040741,2.3262481689453125,10000.0,11253.040585517883,11662.289780378342,11253.040585517883,407.4942150115967,0.7008495330810547,0.0 -33400,1.5264184,3.2032907,,,,,,,,,,,,,, -33500,1.3822601,3.1974168,,,,,,,,,,,,,, -33600,1.4331284,3.2153535,,,,,,,,,,,,,, -33700,1.4787234,3.2879977,,,,,,,,,,,,,, -33800,1.388907,3.2246027,,,,,,,,,,,,,, -33900,1.2957925,3.2221274,,,,,,,,,,,,,, -34000,1.4797199,3.2820008,,,,,,,,,,,,,, -34100,1.3006301,3.2515528,,,,,,,,,,,,,, -34200,1.5805026,3.280498,,,,,,,,,,,,,, -34300,1.4639467,3.3202734,,,,,,,,,,,,,, -34400,1.3824323,3.2924776,,,,,,,,,,,,,, -34500,1.3160561,3.2907858,,,,,,,,,,,,,, -34600,1.3996881,3.2416146,,,,,,,,,,,,,, -34700,1.6215373,3.2402976,,,,,,,,,,,,,, -34800,1.4816314,3.27217,,,,,,,,,,,,,, -34841,,,0.7420878410339355,1.2352484464645386,0.6431399583816528,1.672220230102539,50000.0,0.5170000195503235,2.30928897857666,10000.0,11763.106231689451,12190.231878995895,11763.106231689451,425.2882845401764,0.7346818447113037,0.0 -34900,1.2961477,3.2706964,,,,,,,,,,,,,, -35000,1.6045885,3.278119,,,,,,,,,,,,,, -35100,1.5155982,3.2326756,,,,,,,,,,,,,, -35200,1.334519,3.1693287,,,,,,,,,,,,,, -35300,1.4295335,3.2264173,,,,,,,,,,,,,, -35400,1.5447023,3.2410812,,,,,,,,,,,,,, -35500,1.3071908,3.2057786,,,,,,,,,,,,,, -35600,1.2937943,3.1839666,,,,,,,,,,,,,, -35700,1.3982148,3.2343428,,,,,,,,,,,,,, -35800,1.4670608,3.235567,,,,,,,,,,,,,, -35900,1.4272593,3.2270544,,,,,,,,,,,,,, -36000,1.473872,3.2306912,,,,,,,,,,,,,, -36100,1.5650904,3.189078,,,,,,,,,,,,,, -36200,1.5113381,3.2227511,,,,,,,,,,,,,, -36300,1.4414666,3.225537,,,,,,,,,,,,,, -36358,,,0.7210817933082581,1.3398000001907349,0.6416599750518799,1.693834900856018,50000.0,0.5149000287055969,2.3678078651428223,10000.0,12273.201929330826,12718.294162273409,12273.201929330826,443.1710669994354,0.7699494361877441,0.0 -36400,1.4151798,3.2253413,,,,,,,,,,,,,, -36500,1.7448282,3.3029878,,,,,,,,,,,,,, -36600,1.393265,3.1102605,,,,,,,,,,,,,, -36700,1.5505592,3.233467,,,,,,,,,,,,,, -36800,1.5153152,3.202129,,,,,,,,,,,,,, -36900,1.4192606,3.219098,,,,,,,,,,,,,, -37000,1.4620262,3.2198844,,,,,,,,,,,,,, -37100,1.4518723,3.156436,,,,,,,,,,,,,, -37200,1.5040702,3.2382722,,,,,,,,,,,,,, -37300,1.3640627,3.2556603,,,,,,,,,,,,,, -37400,1.599204,3.2024333,,,,,,,,,,,,,, -37500,1.465242,3.231794,,,,,,,,,,,,,, -37600,1.4431205,3.208495,,,,,,,,,,,,,, -37700,1.6535091,3.2704234,,,,,,,,,,,,,, -37800,1.4721111,3.1346676,,,,,,,,,,,,,, -37874,,,0.7284956574440002,1.312329649925232,0.6582199931144714,1.6255125999450684,50000.0,0.5271000266075134,2.265986442565918,10000.0,12783.18024611473,13245.977816104887,12783.18024611473,460.7906458377838,0.8074545860290527,0.0 -37900,1.5939552,3.2603998,,,,,,,,,,,,,, -38000,1.485445,3.216421,,,,,,,,,,,,,, -38100,1.3906283,3.1585197,,,,,,,,,,,,,, -38200,1.4504747,3.2229276,,,,,,,,,,,,,, -38300,1.7446929,3.2690563,,,,,,,,,,,,,, -38400,2.0959256,3.2038116,,,,,,,,,,,,,, -38500,1.5855283,3.2260723,,,,,,,,,,,,,, -38600,1.5450759,3.212714,,,,,,,,,,,,,, -38700,1.4351479,3.1605136,,,,,,,,,,,,,, -38800,1.568167,3.1819003,,,,,,,,,,,,,, -38900,1.480488,3.1993392,,,,,,,,,,,,,, -39000,1.6268883,3.1008253,,,,,,,,,,,,,, -39100,1.4790684,3.2572324,,,,,,,,,,,,,, -39200,1.4209107,3.2011132,,,,,,,,,,,,,, -39300,1.8092017,3.2210078,,,,,,,,,,,,,, -39391,,,0.7284956574440002,1.298724889755249,0.6593199968338013,1.6137521266937256,50000.0,0.5283000469207764,2.272613525390625,10000.0,13293.384312152864,13773.91310286522,13293.384312152864,478.4318549633026,0.8491504192352295,0.0 -39400,1.6811675,3.1964371,,,,,,,,,,,,,, -39500,1.7957897,3.1927962,,,,,,,,,,,,,, -39600,1.4715319,3.200491,,,,,,,,,,,,,, -39700,1.5889046,3.1951346,,,,,,,,,,,,,, -39800,1.5296142,3.298726,,,,,,,,,,,,,, -39900,1.7133547,3.1993532,,,,,,,,,,,,,, -40000,1.5063673,3.2376132,,,,,,,,,,,,,, -40100,1.7920642,3.12699,,,,,,,,,,,,,, -40200,1.4563575,3.204048,,,,,,,,,,,,,, -40300,1.5597849,3.1264265,,,,,,,,,,,,,, -40400,1.4814911,3.1977663,,,,,,,,,,,,,, -40500,1.5332527,3.1221693,,,,,,,,,,,,,, -40600,1.7724353,3.1884718,,,,,,,,,,,,,, -40700,1.5320765,3.3167202,,,,,,,,,,,,,, -40800,1.7444791,3.2400758,,,,,,,,,,,,,, -40900,1.7521484,3.2099419,,,,,,,,,,,,,, -40907,,,0.7203842401504517,1.3398540019989014,0.6532599925994873,1.6414886713027954,50000.0,0.524399995803833,2.278727054595948,10000.0,13803.354566812515,14301.69158387184,13803.354566812515,496.1568441390991,0.8838469982147217,0.0 -41000,1.6195328,3.2248445,,,,,,,,,,,,,, -41100,1.6380574,3.1926296,,,,,,,,,,,,,, -41200,1.6313137,3.276992,,,,,,,,,,,,,, -41300,1.5613791,3.2298474,,,,,,,,,,,,,, -41400,1.5805808,3.1638026,,,,,,,,,,,,,, -41500,1.5584184,3.2650666,,,,,,,,,,,,,, -41600,1.8183066,3.2265193,,,,,,,,,,,,,, -41700,1.4987617,3.1955633,,,,,,,,,,,,,, -41800,1.7427131,3.137658,,,,,,,,,,,,,, -41900,1.871544,3.2345781,,,,,,,,,,,,,, -42000,1.72155,3.1844983,,,,,,,,,,,,,, -42100,1.6478212,3.136401,,,,,,,,,,,,,, -42200,1.5727028,3.1727629,,,,,,,,,,,,,, -42300,1.5431265,3.1302583,,,,,,,,,,,,,, -42400,1.6943973,3.2119765,,,,,,,,,,,,,, -42423,,,0.7283163070678711,1.3241567611694336,0.6582599878311157,1.6258599758148191,50000.0,0.5348000526428223,2.26089096069336,10000.0,14313.371778011322,14830.411584377289,14313.371778011322,514.7730877399445,0.921715259552002,0.0 -42500,1.6226456,3.1392572,,,,,,,,,,,,,, -42600,1.55776,3.1630301,,,,,,,,,,,,,, -42700,1.9645038,3.1320758,,,,,,,,,,,,,, -42800,1.5347365,3.155033,,,,,,,,,,,,,, -42900,1.7619662,3.173472,,,,,,,,,,,,,, -43000,1.6390682,3.2123787,,,,,,,,,,,,,, -43100,1.6819155,3.1677692,,,,,,,,,,,,,, -43200,1.8179183,3.1864178,,,,,,,,,,,,,, -43300,1.4862009,3.1494994,,,,,,,,,,,,,, -43400,1.5490805,3.1631522,,,,,,,,,,,,,, -43500,1.730004,3.1807647,,,,,,,,,,,,,, -43600,1.6960934,3.224859,,,,,,,,,,,,,, -43700,1.7166563,3.1152256,,,,,,,,,,,,,, -43800,1.6424341,3.2217498,,,,,,,,,,,,,, -43900,1.7318566,3.222218,,,,,,,,,,,,,, -43940,,,0.7344746589660645,1.301184892654419,0.6468999981880188,1.689677119255066,50000.0,0.5181000232696533,2.357792377471924,10000.0,14823.624761104584,15358.438279390337,14823.624761104584,532.4613721370697,0.9581685066223145,0.0 -44000,1.558325,3.2293994,,,,,,,,,,,,,, -44100,1.7781007,3.1632152,,,,,,,,,,,,,, -44200,1.5568736,3.2257087,,,,,,,,,,,,,, -44300,1.6204486,3.183357,,,,,,,,,,,,,, -44400,1.6009998,3.240704,,,,,,,,,,,,,, -44500,1.7636801,3.2617366,,,,,,,,,,,,,, -44600,1.7431865,3.148035,,,,,,,,,,,,,, -44700,1.5775126,3.1498356,,,,,,,,,,,,,, -44800,1.4868582,3.1593924,,,,,,,,,,,,,, -44900,1.7021549,3.1469696,,,,,,,,,,,,,, -45000,1.7439004,3.1302202,,,,,,,,,,,,,, -45100,1.8700061,3.1867814,,,,,,,,,,,,,, -45200,1.6698577,3.088698,,,,,,,,,,,,,, -45300,1.5446804,3.161797,,,,,,,,,,,,,, -45400,1.6130602,3.096987,,,,,,,,,,,,,, -45458,,,0.7339963316917419,1.248931646347046,0.6551799774169922,1.5993354320526123,50000.0,0.5314000248908997,2.2523021697998047,10000.0,15333.79275083542,15886.385159492493,15333.79275083542,550.1557083129883,0.9941191673278807,0.0 -45500,1.5900491,3.0847492,,,,,,,,,,,,,, -45600,1.6739545,3.1892452,,,,,,,,,,,,,, -45700,1.5986009,3.179084,,,,,,,,,,,,,, -45800,1.7173777,3.1414757,,,,,,,,,,,,,, -45900,1.6843532,3.1955204,,,,,,,,,,,,,, -46000,1.7263392,3.2222097,,,,,,,,,,,,,, -46100,1.5932385,3.1584086,,,,,,,,,,,,,, -46200,1.8554633,3.190392,,,,,,,,,,,,,, -46300,1.7115889,3.097552,,,,,,,,,,,,,, -46400,1.7268945,3.1066308,,,,,,,,,,,,,, -46500,1.7831832,3.1793742,,,,,,,,,,,,,, -46600,1.665795,3.1846795,,,,,,,,,,,,,, -46700,1.719738,3.1283262,,,,,,,,,,,,,, -46800,1.6718894,3.1674864,,,,,,,,,,,,,, -46900,1.7069256,3.1823454,,,,,,,,,,,,,, -46975,,,0.7340162396430969,1.2588564157485962,0.6627399921417236,1.5765480995178225,50000.0,0.5320000052452087,2.250394344329834,10000.0,15843.919181346891,16414.092190027237,15843.919181346891,567.648931980133,1.0327842235565186,0.0 -47000,1.5670621,3.1889744,,,,,,,,,,,,,, -47100,2.0670488,3.1604471,,,,,,,,,,,,,, -47200,1.828313,3.1842504,,,,,,,,,,,,,, -47300,1.9016223,3.1675394,,,,,,,,,,,,,, -47400,1.6207539,3.1813993,,,,,,,,,,,,,, -47500,1.7551626,3.2192059,,,,,,,,,,,,,, -47600,1.772724,3.1290991,,,,,,,,,,,,,, -47700,1.7384123,3.127329,,,,,,,,,,,,,, -47800,1.777128,3.2120404,,,,,,,,,,,,,, -47900,1.6702193,3.1336117,,,,,,,,,,,,,, -48000,1.8629491,3.1534202,,,,,,,,,,,,,, -48100,1.6739188,3.0926144,,,,,,,,,,,,,, -48200,1.7508881,3.1749234,,,,,,,,,,,,,, -48300,1.8162737,3.2442517,,,,,,,,,,,,,, -48400,1.6072011,3.2754078,,,,,,,,,,,,,, -48491,,,0.7350525856018066,1.2592726945877075,0.665619969367981,1.5760258436203003,50000.0,0.5386000275611877,2.2234253883361816,10000.0,16353.850271701813,16941.92165875435,16353.850271701813,585.4615631103516,1.0696442127227783,0.0 -48500,1.7637689,3.1150894,,,,,,,,,,,,,, -48600,1.7133389,3.1148868,,,,,,,,,,,,,, -48700,1.8630763,3.1354728,,,,,,,,,,,,,, -48800,1.7038959,3.1477437,,,,,,,,,,,,,, -48900,1.7851619,3.1894524,,,,,,,,,,,,,, -49000,1.8024498,3.1681137,,,,,,,,,,,,,, -49100,1.670204,3.1446075,,,,,,,,,,,,,, -49200,1.541809,3.1023827,,,,,,,,,,,,,, -49300,1.6567435,3.1784556,,,,,,,,,,,,,, -49400,1.6558203,3.2109199,,,,,,,,,,,,,, -49500,1.8152832,3.0977974,,,,,,,,,,,,,, -49600,2.0288498,3.191792,,,,,,,,,,,,,, -49700,1.6636548,3.1935556,,,,,,,,,,,,,, -49800,1.8750278,3.1754978,,,,,,,,,,,,,, -49900,1.9781121,3.1468358,,,,,,,,,,,,,, -50000,1.7497331,3.119277,,,,,,,,,,,,,, -50008,,,0.7184112071990967,1.2996251583099363,0.6577799916267395,1.5904572010040283,50000.0,0.5308000445365906,2.27528715133667,10000.0,16863.88741350174,17470.022649526596,16863.88741350174,603.4375021457672,1.108870029449463,0.0 -50100,1.9633487,3.1905172,,,,,,,,,,,,,, -50200,1.8151752,3.2331994,,,,,,,,,,,,,, -50300,1.8872628,3.1864624,,,,,,,,,,,,,, -50400,1.7865021,3.1019142,,,,,,,,,,,,,, -50500,1.8387656,3.2517505,,,,,,,,,,,,,, -50600,1.7768689,3.1585333,,,,,,,,,,,,,, -50700,1.6847484,3.1219308,,,,,,,,,,,,,, -50800,1.9530301,3.2012565,,,,,,,,,,,,,, -50900,1.6740882,3.146811,,,,,,,,,,,,,, -51000,1.8378361,3.1921172,,,,,,,,,,,,,, -51100,1.7796739,3.1836057,,,,,,,,,,,,,, -51200,1.7870456,3.147934,,,,,,,,,,,,,, -51300,1.7699996,3.257841,,,,,,,,,,,,,, -51400,2.1131155,3.160845,,,,,,,,,,,,,, -51500,1.7655929,3.1977527,,,,,,,,,,,,,, -51525,,,0.7494618892669678,1.2071516513824463,0.6755399703979492,1.532850742340088,50000.0,0.5425000190734863,2.1815757751464844,10000.0,17374.02089715004,17997.889188289642,17374.02089715004,621.0799465179443,1.1506271362304688,0.0 -51600,1.8324832,3.1032586,,,,,,,,,,,,,, -51700,1.9996579,3.1631513,,,,,,,,,,,,,, -51800,2.093597,3.167992,,,,,,,,,,,,,, -51900,1.8541747,3.2300406,,,,,,,,,,,,,, -52000,1.7856835,3.1782458,,,,,,,,,,,,,, -52100,1.780593,3.1132712,,,,,,,,,,,,,, -52200,1.7749718,3.0629475,,,,,,,,,,,,,, -52300,1.8785896,3.14997,,,,,,,,,,,,,, -52400,1.665602,3.1211812,,,,,,,,,,,,,, -52500,1.7502078,3.1111655,,,,,,,,,,,,,, -52600,1.9925703,3.0491936,,,,,,,,,,,,,, -52700,1.8914812,3.2075963,,,,,,,,,,,,,, -52800,1.8231992,3.1093175,,,,,,,,,,,,,, -52900,1.6857561,3.1346917,,,,,,,,,,,,,, -53000,1.7950141,3.0983567,,,,,,,,,,,,,, -53042,,,0.7538065910339355,1.20997416973114,0.6650999784469604,1.5931397676467896,50000.0,0.536300003528595,2.2342357635498047,10000.0,17884.068472623825,18525.674897909164,17884.068472623825,638.7280640602112,1.1915946006774902,0.0 -53100,1.8236562,3.2253778,,,,,,,,,,,,,, -53200,1.9069055,3.0883718,,,,,,,,,,,,,, -53300,1.8725485,3.180408,,,,,,,,,,,,,, -53400,1.8638327,3.0567064,,,,,,,,,,,,,, -53500,2.004006,3.0994391,,,,,,,,,,,,,, -53600,1.9672874,3.1386466,,,,,,,,,,,,,, -53700,1.6947609,3.1637678,,,,,,,,,,,,,, -53800,1.841114,3.1907578,,,,,,,,,,,,,, -53900,1.7854145,3.0929594,,,,,,,,,,,,,, -54000,1.8735365,3.1413376,,,,,,,,,,,,,, -54100,1.8766844,3.1883092,,,,,,,,,,,,,, -54200,1.894156,3.1539595,,,,,,,,,,,,,, -54300,2.0011544,3.2038026,,,,,,,,,,,,,, -54400,1.8578413,3.1278517,,,,,,,,,,,,,, -54500,1.8418522,3.137619,,,,,,,,,,,,,, -54560,,,0.7381616830825806,1.2561745643615725,0.659559965133667,1.598878264427185,50000.0,0.5281000137329102,2.285375595092773,10000.0,18394.20597648621,19053.84075427056,18394.20597648621,656.6700406074524,1.2295169830322266,0.0 -54600,1.8302646,3.0857544,,,,,,,,,,,,,, -54700,1.9434832,3.11174,,,,,,,,,,,,,, -54800,1.9094528,3.2002382,,,,,,,,,,,,,, -54900,1.8243304,3.1549752,,,,,,,,,,,,,, -55000,1.8155558,3.1534674,,,,,,,,,,,,,, -55100,1.8935444,3.1674018,,,,,,,,,,,,,, -55200,1.8096197,3.1648273,,,,,,,,,,,,,, -55300,2.070495,3.0894165,,,,,,,,,,,,,, -55400,1.7949758,3.12292,,,,,,,,,,,,,, -55500,1.7699443,3.0771437,,,,,,,,,,,,,, -55600,1.9399444,3.080159,,,,,,,,,,,,,, -55700,1.8326249,3.1466234,,,,,,,,,,,,,, -55800,2.0896013,3.1657658,,,,,,,,,,,,,, -55900,1.7328906,3.1092863,,,,,,,,,,,,,, -56000,1.9301741,3.112057,,,,,,,,,,,,,, -56077,,,0.7382214665412903,1.2343106269836426,0.6646400094032288,1.5662380456924438,50000.0,0.5379000306129456,2.216094732284546,10000.0,18904.25890445709,19581.60926795005,18904.25890445709,674.2945744991302,1.2716209888458252,0.0 -56100,1.7226774,3.0431216,,,,,,,,,,,,,, -56200,2.0480676,3.1213522,,,,,,,,,,,,,, -56300,1.968742,3.043983,,,,,,,,,,,,,, -56400,1.8893664,3.1563718,,,,,,,,,,,,,, -56500,1.8333504,3.1411057,,,,,,,,,,,,,, -56600,1.9247093,3.2197518,,,,,,,,,,,,,, -56700,2.0731363,3.178748,,,,,,,,,,,,,, -56800,2.0195353,3.234672,,,,,,,,,,,,,, -56900,1.8058517,3.0984225,,,,,,,,,,,,,, -57000,1.9067065,3.2085712,,,,,,,,,,,,,, -57100,2.0571983,3.1050482,,,,,,,,,,,,,, -57200,1.9924092,3.1052496,,,,,,,,,,,,,, -57300,1.994103,3.1842713,,,,,,,,,,,,,, -57400,1.9388018,3.0863938,,,,,,,,,,,,,, -57500,1.9111265,3.1478987,,,,,,,,,,,,,, -57594,,,0.7438416481018066,1.227565050125122,0.6708399653434753,1.5553630590438845,50000.0,0.5416000485420227,2.204932928085327,10000.0,19414.452545642853,20109.54371070861,19414.452545642853,691.946305513382,1.3120112419128418,0.0 -57600,1.8966837,3.0895813,,,,,,,,,,,,,, -57700,1.739485,3.0760956,,,,,,,,,,,,,, -57800,1.8549569,3.0970492,,,,,,,,,,,,,, -57900,1.9176674,3.1315165,,,,,,,,,,,,,, -58000,2.1136537,3.1765966,,,,,,,,,,,,,, -58100,1.8207704,3.0094817,,,,,,,,,,,,,, -58200,1.8518287,3.141031,,,,,,,,,,,,,, -58300,1.8949771,3.0375178,,,,,,,,,,,,,, -58400,2.047641,3.0917497,,,,,,,,,,,,,, -58500,2.0738811,3.116818,,,,,,,,,,,,,, -58600,1.9990185,3.0861588,,,,,,,,,,,,,, -58700,1.86271,3.1449473,,,,,,,,,,,,,, -58800,1.9544132,3.1733117,,,,,,,,,,,,,, -58900,1.8855071,3.04754,,,,,,,,,,,,,, -59000,1.8675497,3.1833975,,,,,,,,,,,,,, -59100,1.9088722,3.1246808,,,,,,,,,,,,,, -59112,,,0.7326809763908386,1.2575324773788452,0.6616599559783936,1.569987416267395,50000.0,0.52920001745224,2.269146203994751,10000.0,19924.543339967728,20637.42899608612,19924.543339967728,709.6492412090302,1.3549623489379885,0.0 -59200,1.924348,3.1094708,,,,,,,,,,,,,, -59300,1.7655991,3.0284407,,,,,,,,,,,,,, -59400,1.955027,3.0976229,,,,,,,,,,,,,, -59500,1.9063612,3.0855618,,,,,,,,,,,,,, -59600,2.0518878,3.1479604,,,,,,,,,,,,,, -59700,1.9696034,3.140294,,,,,,,,,,,,,, -59800,1.866983,3.0875685,,,,,,,,,,,,,, -59900,2.0332546,3.1307602,,,,,,,,,,,,,, -60000,2.0328434,3.175029,,,,,,,,,,,,,, -60100,2.1039066,3.1566749,,,,,,,,,,,,,, -60200,2.0204322,3.1360278,,,,,,,,,,,,,, -60300,1.9418176,3.1026225,,,,,,,,,,,,,, -60400,1.984041,3.0670779,,,,,,,,,,,,,, -60500,1.9973352,3.10364,,,,,,,,,,,,,, -60600,1.8696264,3.0299647,,,,,,,,,,,,,, -60630,,,0.7663623690605164,1.119907021522522,0.6607800126075745,1.5795252323150637,50000.0,0.5357000231742859,2.234541177749634,10000.0,20434.725960969925,21165.11572289467,20434.725960969925,727.0632431507111,1.396129131317139,0.0 -60700,1.9515189,3.106301,,,,,,,,,,,,,, -60800,1.9776311,3.2224765,,,,,,,,,,,,,, -60900,1.8878697,3.0962029,,,,,,,,,,,,,, -61000,2.2047858,3.0765584,,,,,,,,,,,,,, -61100,1.9323847,3.0541728,,,,,,,,,,,,,, -61200,2.345471,3.1448612,,,,,,,,,,,,,, -61300,1.9908555,3.0508513,,,,,,,,,,,,,, -61400,2.0786233,3.1241035,,,,,,,,,,,,,, -61500,1.9358487,3.2171462,,,,,,,,,,,,,, -61600,1.8737705,3.047175,,,,,,,,,,,,,, -61700,1.8432308,3.1236606,,,,,,,,,,,,,, -61800,2.0997767,3.1539507,,,,,,,,,,,,,, -61900,1.8946551,3.150521,,,,,,,,,,,,,, -62000,2.0914485,3.127915,,,,,,,,,,,,,, -62100,2.0467184,3.0982237,,,,,,,,,,,,,, -62148,,,0.7565170526504517,1.1901875734329224,0.6680799722671509,1.5773533582687378,50000.0,0.5370000004768372,2.229222059249878,10000.0,20944.98304605484,21693.09570145607,20944.98304605484,744.6991124153137,1.4345617294311523,0.0 -62200,2.152977,3.1944683,,,,,,,,,,,,,, -62300,1.8610843,3.0848157,,,,,,,,,,,,,, -62400,2.0362272,3.1162455,,,,,,,,,,,,,, -62500,1.9386252,3.1171455,,,,,,,,,,,,,, -62600,2.092871,3.1257017,,,,,,,,,,,,,, -62700,2.1421168,3.1866436,,,,,,,,,,,,,, -62800,1.9313426,3.1085343,,,,,,,,,,,,,, -62900,1.9449188,3.0747015,,,,,,,,,,,,,, -63000,1.9464899,3.0233655,,,,,,,,,,,,,, -63100,2.110198,3.1857274,,,,,,,,,,,,,, -63200,1.9010983,3.0628798,,,,,,,,,,,,,, -63300,1.905334,3.0776436,,,,,,,,,,,,,, -63400,1.9701046,3.1248102,,,,,,,,,,,,,, -63500,2.0382721,3.0643377,,,,,,,,,,,,,, -63600,2.1654847,3.152009,,,,,,,,,,,,,, -63665,,,0.7577726244926453,1.1620841026306152,0.6750199794769287,1.522871017456055,50000.0,0.5432000160217285,2.200866222381592,10000.0,21455.00278377533,22220.902873277664,21455.00278377533,762.3969528675079,1.4752840995788574,0.0 -63700,1.8306097,3.0120492,,,,,,,,,,,,,, -63800,2.0354517,3.0854049,,,,,,,,,,,,,, -63900,1.7555792,3.0586543,,,,,,,,,,,,,, -64000,2.0210516,3.1396155,,,,,,,,,,,,,, -64100,1.9076346,3.0497115,,,,,,,,,,,,,, -64200,1.9653714,3.120048,,,,,,,,,,,,,, -64300,2.0630958,3.1116188,,,,,,,,,,,,,, -64400,2.024042,3.137461,,,,,,,,,,,,,, -64500,1.9576316,3.0993094,,,,,,,,,,,,,, -64600,1.9879149,3.0537658,,,,,,,,,,,,,, -64700,1.9204608,3.0906284,,,,,,,,,,,,,, -64800,2.0858164,3.0829172,,,,,,,,,,,,,, -64900,2.2203894,3.1458578,,,,,,,,,,,,,, -65000,1.939116,3.028929,,,,,,,,,,,,,, -65100,2.0447361,3.093336,,,,,,,,,,,,,, -65182,,,0.7588488459587097,1.1486854553222656,0.6837799549102783,1.4862542152404783,50000.0,0.5532000064849854,2.1577084064483643,10000.0,21964.96987080574,22748.91353034973,21964.96987080574,780.350729227066,1.5167927742004397,0.0 -65200,1.9478171,3.1496916,,,,,,,,,,,,,, -65300,2.1058586,3.097624,,,,,,,,,,,,,, -65400,1.9532945,3.046491,,,,,,,,,,,,,, -65500,2.1076493,3.1105692,,,,,,,,,,,,,, -65600,2.094365,3.1083593,,,,,,,,,,,,,, -65700,2.0337284,3.0279713,,,,,,,,,,,,,, -65800,1.9996157,3.1128128,,,,,,,,,,,,,, -65900,2.3001661,3.112097,,,,,,,,,,,,,, -66000,2.0371065,3.079528,,,,,,,,,,,,,, -66100,1.9990275,3.0695195,,,,,,,,,,,,,, -66200,2.0290172,3.0544622,,,,,,,,,,,,,, -66300,2.1391172,3.123134,,,,,,,,,,,,,, -66400,2.17678,3.1578248,,,,,,,,,,,,,, -66500,2.1492434,3.106083,,,,,,,,,,,,,, -66600,1.9808867,3.0633862,,,,,,,,,,,,,, -66699,,,0.7502591013908386,1.213298797607422,0.6759399771690369,1.5442116260528564,50000.0,0.5546000003814697,2.189296245574951,10000.0,22475.05553460121,23276.70662856102,22475.05553460121,797.9648485183716,1.5609960556030271,0.0 -66700,2.208358,3.161524,,,,,,,,,,,,,, -66800,2.0846486,3.0788474,,,,,,,,,,,,,, -66900,2.1084173,3.097886,,,,,,,,,,,,,, -67000,2.2186809,3.120093,,,,,,,,,,,,,, -67100,1.9705213,3.0993292,,,,,,,,,,,,,, -67200,1.9980577,3.098298,,,,,,,,,,,,,, -67300,2.052359,3.0924096,,,,,,,,,,,,,, -67400,2.0320852,3.0620313,,,,,,,,,,,,,, -67500,2.200867,3.068252,,,,,,,,,,,,,, -67600,2.2352042,3.1037986,,,,,,,,,,,,,, -67700,2.0107675,3.063333,,,,,,,,,,,,,, -67800,1.96825,3.061724,,,,,,,,,,,,,, -67900,1.9328705,3.0077899,,,,,,,,,,,,,, -68000,2.1226737,3.1373875,,,,,,,,,,,,,, -68100,2.033343,3.0926132,,,,,,,,,,,,,, -68200,2.036779,3.0661342,,,,,,,,,,,,,, -68216,,,0.7473692297935486,1.19962477684021,0.6743800044059753,1.5287421941757202,50000.0,0.5452000498771667,2.196040153503418,10000.0,22985.083032131195,23804.51207590103,22985.083032131195,815.652755022049,1.6021060943603516,0.0 -68300,2.0825343,3.1186016,,,,,,,,,,,,,, -68400,1.975938,3.115993,,,,,,,,,,,,,, -68500,2.3276477,3.1850243,,,,,,,,,,,,,, -68600,2.1079516,3.0389407,,,,,,,,,,,,,, -68700,2.2602148,3.0684295,,,,,,,,,,,,,, -68800,2.095688,3.0724406,,,,,,,,,,,,,, -68900,2.063653,3.134804,,,,,,,,,,,,,, -69000,2.038592,3.0654957,,,,,,,,,,,,,, -69100,2.1245835,3.098182,,,,,,,,,,,,,, -69200,2.0182745,3.0885997,,,,,,,,,,,,,, -69300,2.0288713,3.0758486,,,,,,,,,,,,,, -69400,2.0081067,3.0746105,,,,,,,,,,,,,, -69500,2.1554956,3.0594838,,,,,,,,,,,,,, -69600,1.9753015,2.9725223,,,,,,,,,,,,,, -69700,1.9633797,3.024328,,,,,,,,,,,,,, -69734,,,0.7683553695678711,1.1409224271774292,0.6607999801635742,1.5949475765228271,50000.0,0.5418000221252441,2.243528127670288,10000.0,23495.33017706871,24332.515582323074,23495.33017706871,833.3205726146698,1.642003059387207,0.0 -69800,2.1606283,3.0622284,,,,,,,,,,,,,, -69900,2.131188,3.0853696,,,,,,,,,,,,,, -70000,1.9398049,3.0473223,,,,,,,,,,,,,, -70100,2.0521147,3.151136,,,,,,,,,,,,,, -70200,2.2311404,3.092851,,,,,,,,,,,,,, -70300,2.0389752,3.0280495,,,,,,,,,,,,,, -70400,2.0258145,3.0076616,,,,,,,,,,,,,, -70500,2.0686707,3.136625,,,,,,,,,,,,,, -70600,2.0879226,3.0721748,,,,,,,,,,,,,, -70700,1.9454399,3.0781288,,,,,,,,,,,,,, -70800,1.9976397,3.066856,,,,,,,,,,,,,, -70900,2.113023,3.0669947,,,,,,,,,,,,,, -71000,1.9606227,3.0221438,,,,,,,,,,,,,, -71100,2.0192537,3.031138,,,,,,,,,,,,,, -71200,2.312808,3.1976635,,,,,,,,,,,,,, -71251,,,0.7674585580825806,1.128574013710022,0.6775799989700317,1.5199021100997925,50000.0,0.5508000254631042,2.167146682739258,10000.0,24005.278024673466,24860.202106952667,24005.278024673466,850.9681005477905,1.6843080520629885,0.0 -71300,2.2066138,3.0933714,,,,,,,,,,,,,, -71400,2.0875576,3.0690007,,,,,,,,,,,,,, -71500,2.1668797,3.036999,,,,,,,,,,,,,, -71600,1.9651326,3.0993655,,,,,,,,,,,,,, -71700,2.140811,3.133189,,,,,,,,,,,,,, -71800,2.175717,3.068952,,,,,,,,,,,,,, -71900,2.120707,3.1141198,,,,,,,,,,,,,, -72000,1.9563704,2.9955597,,,,,,,,,,,,,, -72100,2.251072,3.092145,,,,,,,,,,,,,, -72200,2.0658956,3.112697,,,,,,,,,,,,,, -72300,2.1796653,3.101417,,,,,,,,,,,,,, -72400,2.2058933,3.0508263,,,,,,,,,,,,,, -72500,2.114175,3.0258348,,,,,,,,,,,,,, -72600,2.1826804,3.1439035,,,,,,,,,,,,,, -72700,2.1939502,3.1695786,,,,,,,,,,,,,, -72768,,,0.7587292790412903,1.159325361251831,0.6759200096130371,1.519149661064148,50000.0,0.5521000027656555,2.178276538848877,10000.0,24515.25668501854,25388.05660820008,24515.25668501854,868.7539627552032,1.7253928184509275,0.0 -72800,2.1013653,3.0409932,,,,,,,,,,,,,, -72900,2.2271373,3.0875673,,,,,,,,,,,,,, -73000,2.1853454,3.0128617,,,,,,,,,,,,,, -73100,2.234686,3.0251808,,,,,,,,,,,,,, -73200,2.0973537,3.0627236,,,,,,,,,,,,,, -73300,2.1587956,3.0413013,,,,,,,,,,,,,, -73400,2.0494783,2.9397197,,,,,,,,,,,,,, -73500,1.9872024,2.9894836,,,,,,,,,,,,,, -73600,2.046431,3.006402,,,,,,,,,,,,,, -73700,2.1440208,3.1292927,,,,,,,,,,,,,, -73800,2.2485168,3.0069423,,,,,,,,,,,,,, -73900,2.128453,3.0777946,,,,,,,,,,,,,, -74000,2.087603,3.003365,,,,,,,,,,,,,, -74100,2.0679219,3.0226405,,,,,,,,,,,,,, -74200,2.3542106,3.0622473,,,,,,,,,,,,,, -74285,,,0.7687141299247742,1.109268307685852,0.6859599947929382,1.4653456211090088,50000.0,0.5637000203132629,2.108211040496826,10000.0,25025.17741537094,25915.98461484909,25025.17741537094,886.672360420227,1.7654447555541992,0.0 -74300,2.1905432,3.1030035,,,,,,,,,,,,,, -74400,2.1186924,2.9947097,,,,,,,,,,,,,, -74500,2.2259133,3.0808535,,,,,,,,,,,,,, -74600,2.2063084,3.0543437,,,,,,,,,,,,,, -74700,2.0720901,3.010749,,,,,,,,,,,,,, -74800,2.092124,3.0146942,,,,,,,,,,,,,, -74900,2.3171358,3.0950203,,,,,,,,,,,,,, -75000,2.1807296,2.9906957,,,,,,,,,,,,,, -75100,2.262527,3.0451865,,,,,,,,,,,,,, -75200,2.0510342,3.011818,,,,,,,,,,,,,, -75300,2.228371,3.0606833,,,,,,,,,,,,,, -75400,1.9613389,3.0152864,,,,,,,,,,,,,, -75500,2.100998,3.0771785,,,,,,,,,,,,,, -75600,2.3783145,3.0268102,,,,,,,,,,,,,, -75700,2.2539747,3.11365,,,,,,,,,,,,,, -75800,2.0475464,3.0184588,,,,,,,,,,,,,, -75803,,,0.7549425959587097,1.1974339485168457,0.6801599860191345,1.5412228107452393,50000.0,0.5576000213623047,2.1966655254364014,10000.0,25535.3845744133,26444.38476872444,25535.3845744133,904.7744419574738,1.8068821430206297,0.0 -75900,2.072973,3.0312936,,,,,,,,,,,,,, -76000,2.0586703,3.0018964,,,,,,,,,,,,,, -76100,2.3614116,3.1314688,,,,,,,,,,,,,, -76200,2.083956,3.0565975,,,,,,,,,,,,,, -76300,2.1027257,3.0564377,,,,,,,,,,,,,, -76400,2.1419916,2.9698534,,,,,,,,,,,,,, -76500,2.2193582,3.1215682,,,,,,,,,,,,,, -76600,2.1456735,3.0534203,,,,,,,,,,,,,, -76700,2.2965329,3.0349822,,,,,,,,,,,,,, -76800,2.1116052,3.0247123,,,,,,,,,,,,,, -76900,2.1768048,3.1057973,,,,,,,,,,,,,, -77000,2.0604627,3.0520375,,,,,,,,,,,,,, -77100,2.4141726,3.0113823,,,,,,,,,,,,,, -77200,2.2317376,3.019712,,,,,,,,,,,,,, -77300,2.2735412,3.092348,,,,,,,,,,,,,, -77320,,,0.7669403553009033,1.1282862424850464,0.6887999773025513,1.4690423011779783,50000.0,0.5628000497817993,2.1257145404815674,10000.0,26045.36720395088,26972.264142274857,26045.36720395088,922.5809574127196,1.847849607467652,0.0 -77400,2.1614296,3.0206563,,,,,,,,,,,,,, -77500,2.146158,3.103865,,,,,,,,,,,,,, -77600,2.2667134,3.0156646,,,,,,,,,,,,,, -77700,2.2314732,3.0848677,,,,,,,,,,,,,, -77800,2.2142985,3.0705543,,,,,,,,,,,,,, -77900,2.1851785,3.062079,,,,,,,,,,,,,, -78000,2.233951,3.0222034,,,,,,,,,,,,,, -78100,2.3212488,3.013853,,,,,,,,,,,,,, -78200,2.0701396,2.9535441,,,,,,,,,,,,,, -78300,2.1629398,3.038521,,,,,,,,,,,,,, -78400,2.1729715,3.008637,,,,,,,,,,,,,, -78500,2.2216408,3.0557606,,,,,,,,,,,,,, -78600,2.1277874,2.993104,,,,,,,,,,,,,, -78700,2.2258718,3.0782394,,,,,,,,,,,,,, -78800,2.2355907,3.0676982,,,,,,,,,,,,,, -78837,,,0.7891621589660645,1.0352303981781006,0.684719979763031,1.4883527755737305,50000.0,0.5586000084877014,2.133809804916382,10000.0,26555.412185668945,27500.41028022766,26555.412185668945,940.5901341438292,1.890838623046875,0.0 -78900,2.2173727,3.0038419,,,,,,,,,,,,,, -79000,2.161345,3.0896692,,,,,,,,,,,,,, -79100,2.3549685,3.0228345,,,,,,,,,,,,,, -79200,2.2981555,3.0752234,,,,,,,,,,,,,, -79300,2.0890687,3.0511222,,,,,,,,,,,,,, -79400,2.1294298,3.072628,,,,,,,,,,,,,, -79500,2.246822,3.1189947,,,,,,,,,,,,,, -79600,2.2911224,3.0368438,,,,,,,,,,,,,, -79700,2.060902,2.9577675,,,,,,,,,,,,,, -79800,2.2294571,3.0798244,,,,,,,,,,,,,, -79900,2.122645,2.9931033,,,,,,,,,,,,,, -80000,2.2527556,3.05435,,,,,,,,,,,,,, -80100,2.1898205,3.0474195,,,,,,,,,,,,,, -80200,2.229372,3.061255,,,,,,,,,,,,,, -80300,2.2320192,3.0217187,,,,,,,,,,,,,, -80354,,,0.7771245241165161,1.0599044561386108,0.686739981174469,1.464746356010437,50000.0,0.5583000183105469,2.126771688461304,10000.0,27065.31203103065,28028.457494974136,27065.31203103065,958.6321983337402,1.9472503662109373,0.0 -80400,2.249658,3.058409,,,,,,,,,,,,,, -80500,2.2988918,3.0657277,,,,,,,,,,,,,, -80600,2.282632,3.006978,,,,,,,,,,,,,, -80700,2.1796844,3.05197,,,,,,,,,,,,,, -80800,2.3184078,3.0901947,,,,,,,,,,,,,, -80900,2.3138032,3.0294116,,,,,,,,,,,,,, -81000,2.2746623,2.9806414,,,,,,,,,,,,,, -81100,2.208297,3.031731,,,,,,,,,,,,,, -81200,2.182088,2.9917858,,,,,,,,,,,,,, -81300,2.1733465,3.0420694,,,,,,,,,,,,,, -81400,2.3180206,2.9719975,,,,,,,,,,,,,, -81500,2.2749467,3.007201,,,,,,,,,,,,,, -81600,2.2023802,2.9971042,,,,,,,,,,,,,, -81700,2.2957869,2.96251,,,,,,,,,,,,,, -81800,2.315411,3.1026611,,,,,,,,,,,,,, -81870,,,0.7718032598495483,1.0625839233398438,0.6845200061798096,1.44517719745636,50000.0,0.55840003490448,2.112879991531372,10000.0,27575.38053822517,28557.02544283867,27575.38053822517,977.0377764701844,1.992159843444824,0.0 -81900,2.1609714,2.975588,,,,,,,,,,,,,, -82000,2.2453198,3.0102072,,,,,,,,,,,,,, -82100,2.258967,2.9780693,,,,,,,,,,,,,, -82200,2.248197,3.0263836,,,,,,,,,,,,,, -82300,2.211605,3.0501094,,,,,,,,,,,,,, -82400,2.4626932,3.1142414,,,,,,,,,,,,,, -82500,2.2292473,3.0004928,,,,,,,,,,,,,, -82600,2.2386472,3.1254044,,,,,,,,,,,,,, -82700,2.2421513,3.0516028,,,,,,,,,,,,,, -82800,2.2654243,3.0298405,,,,,,,,,,,,,, -82900,2.131243,2.9451475,,,,,,,,,,,,,, -83000,2.376966,2.96775,,,,,,,,,,,,,, -83100,2.379438,2.9642298,,,,,,,,,,,,,, -83200,2.3423185,3.060166,,,,,,,,,,,,,, -83300,2.3920794,3.0002878,,,,,,,,,,,,,, -83388,,,0.7626355290412903,1.1467963457107544,0.6807599663734436,1.5086842775344849,50000.0,0.5593000054359436,2.149788618087769,10000.0,28085.509213924408,29084.97811937332,28085.509213924408,994.7666764259338,2.0380008220672607,0.0 -83400,2.3078356,3.0401077,,,,,,,,,,,,,, -83500,2.3104074,3.030933,,,,,,,,,,,,,, -83600,2.3283293,3.0244992,,,,,,,,,,,,,, -83700,2.195049,3.0014036,,,,,,,,,,,,,, -83800,2.3779268,3.0671706,,,,,,,,,,,,,, -83900,2.368892,3.0357044,,,,,,,,,,,,,, -84000,2.263747,2.995558,,,,,,,,,,,,,, -84100,2.1396437,2.99521,,,,,,,,,,,,,, -84200,2.3012238,3.0536804,,,,,,,,,,,,,, -84300,2.3512945,2.9911532,,,,,,,,,,,,,, -84400,2.3279812,3.0345054,,,,,,,,,,,,,, -84500,2.2430708,3.0359635,,,,,,,,,,,,,, -84600,2.5482779,3.0362642,,,,,,,,,,,,,, -84700,2.3103971,3.0403433,,,,,,,,,,,,,, -84800,2.5337365,3.0977976,,,,,,,,,,,,,, -84900,2.3836057,3.071011,,,,,,,,,,,,,, -84905,,,0.7693120241165161,1.1221765279769895,0.6900999546051025,1.4688431024551392,50000.0,0.5598000288009644,2.1390833854675293,10000.0,28595.469446659088,29612.57447457313,28595.469446659088,1012.3052713871002,2.0864999294281006,0.0 -85000,2.3606524,3.0233316,,,,,,,,,,,,,, -85100,2.2537642,2.9926016,,,,,,,,,,,,,, -85200,2.2614439,3.000733,,,,,,,,,,,,,, -85300,2.1725194,2.9711094,,,,,,,,,,,,,, -85400,2.4117014,3.0358815,,,,,,,,,,,,,, -85500,2.4182992,3.025525,,,,,,,,,,,,,, -85600,2.1982722,2.956938,,,,,,,,,,,,,, -85700,2.3617008,3.0405188,,,,,,,,,,,,,, -85800,2.2576404,3.0243795,,,,,,,,,,,,,, -85900,2.2135525,2.98066,,,,,,,,,,,,,, -86000,2.1297863,2.9754252,,,,,,,,,,,,,, -86100,2.5318727,2.992313,,,,,,,,,,,,,, -86200,2.2112143,3.01558,,,,,,,,,,,,,, -86300,2.9102,3.0414257,,,,,,,,,,,,,, -86400,2.478347,3.018966,,,,,,,,,,,,,, -86423,,,0.78125,1.0466645956039429,0.6961999535560608,1.407152771949768,50000.0,0.5702000260353088,2.054215669631958,10000.0,29105.538024902344,30140.310687065125,29105.538024902344,1029.8799359798431,2.1302878856658936,0.0 -86500,2.5944488,3.0883389,,,,,,,,,,,,,, -86600,2.3703215,3.0311418,,,,,,,,,,,,,, -86700,2.4116569,3.0118573,,,,,,,,,,,,,, -86800,2.5055888,3.0701048,,,,,,,,,,,,,, -86900,2.2749064,2.9948227,,,,,,,,,,,,,, -87000,2.1426456,3.0446615,,,,,,,,,,,,,, -87100,2.1598098,2.9735594,,,,,,,,,,,,,, -87200,2.2767806,2.9725795,,,,,,,,,,,,,, -87300,2.333353,3.0811467,,,,,,,,,,,,,, -87400,2.532741,2.9826381,,,,,,,,,,,,,, -87500,2.3973632,3.0358632,,,,,,,,,,,,,, -87600,2.414385,3.0162897,,,,,,,,,,,,,, -87700,2.3301795,2.9548774,,,,,,,,,,,,,, -87800,2.3136296,2.9679208,,,,,,,,,,,,,, -87900,2.3601332,3.0593884,,,,,,,,,,,,,, -87940,,,0.7981704473495483,1.0359044075012207,0.6924200057983398,1.4844554662704468,50000.0,0.5648000240325928,2.1412951946258545,10000.0,29615.679981470108,30668.1747674942,29615.679981470108,1047.508416891098,2.1748242378234863,0.0 -88000,2.338068,2.9910648,,,,,,,,,,,,,, -88100,2.264422,2.925311,,,,,,,,,,,,,, -88200,2.447739,3.0345776,,,,,,,,,,,,,, -88300,2.5174432,3.022156,,,,,,,,,,,,,, -88400,2.3039167,2.9557998,,,,,,,,,,,,,, -88500,2.380027,2.9939756,,,,,,,,,,,,,, -88600,2.2298462,2.9502006,,,,,,,,,,,,,, -88700,2.2861097,3.0203044,,,,,,,,,,,,,, -88800,2.2683127,2.8489935,,,,,,,,,,,,,, -88900,2.2374597,2.947393,,,,,,,,,,,,,, -89000,2.437134,3.0379226,,,,,,,,,,,,,, -89100,2.46327,2.974725,,,,,,,,,,,,,, -89200,2.3545454,3.0003657,,,,,,,,,,,,,, -89300,2.3580003,2.9580593,,,,,,,,,,,,,, -89400,2.455151,2.9157536,,,,,,,,,,,,,, -89457,,,0.7861925959587097,1.0689496994018557,0.692359983921051,1.4794917106628418,50000.0,0.5681000351905823,2.101494789123535,10000.0,30125.72606277466,31195.812687158585,30125.72606277466,1065.0068988800049,2.219353675842285,0.0 -89500,2.2565174,2.8918567,,,,,,,,,,,,,, -89600,2.4698238,3.0176206,,,,,,,,,,,,,, -89700,2.5190325,3.092708,,,,,,,,,,,,,, -89800,2.3928995,3.0099435,,,,,,,,,,,,,, -89900,2.4313836,2.9517255,,,,,,,,,,,,,, -90000,2.339307,2.9545612,,,,,,,,,,,,,, -90100,2.3345873,2.9741209,,,,,,,,,,,,,, -90200,2.2738905,2.9586997,,,,,,,,,,,,,, -90300,2.398544,3.016753,,,,,,,,,,,,,, -90400,2.4880004,3.0558543,,,,,,,,,,,,,, -90500,2.4139276,3.0232708,,,,,,,,,,,,,, -90600,2.4769228,2.97043,,,,,,,,,,,,,, -90700,2.3214655,2.8899493,,,,,,,,,,,,,, -90800,2.3178997,2.972115,,,,,,,,,,,,,, -90900,2.2474964,2.8994071,,,,,,,,,,,,,, -90974,,,0.783621609210968,1.0466140508651731,0.6930199861526489,1.4299108982086182,50000.0,0.5699000358581543,2.0710158348083496,10000.0,30635.79270672798,31723.56230640412,30635.79270672798,1082.5966680049896,2.263647556304932,0.0 -91000,2.4772441,2.990598,,,,,,,,,,,,,, -91100,2.338124,3.0065563,,,,,,,,,,,,,, -91200,2.7832475,3.0032854,,,,,,,,,,,,,, -91300,2.4083214,3.0119615,,,,,,,,,,,,,, -91400,2.3194687,2.8971803,,,,,,,,,,,,,, -91500,2.566691,3.0429618,,,,,,,,,,,,,, -91600,2.4419272,2.982436,,,,,,,,,,,,,, -91700,2.404436,2.9949849,,,,,,,,,,,,,, -91800,2.355594,2.969494,,,,,,,,,,,,,, -91900,2.354622,3.0010915,,,,,,,,,,,,,, -92000,2.4928527,2.9593887,,,,,,,,,,,,,, -92100,2.6196096,3.065776,,,,,,,,,,,,,, -92200,2.2924795,2.9744852,,,,,,,,,,,,,, -92300,2.5205152,2.916887,,,,,,,,,,,,,, -92400,2.4532733,2.9837165,,,,,,,,,,,,,, -92491,,,0.7875877022743225,1.0378481149673462,0.6992799639701843,1.4238126277923584,50000.0,0.5730000138282776,2.0835328102111816,10000.0,31145.91519737244,32251.58094215393,31145.91519737244,1100.3993997573853,2.308422565460205,0.0 -92500,2.3558848,2.9753232,,,,,,,,,,,,,, -92600,2.3180685,2.9370515,,,,,,,,,,,,,, -92700,2.493745,3.0240335,,,,,,,,,,,,,, -92800,2.563459,2.974431,,,,,,,,,,,,,, -92900,2.3566384,2.9257526,,,,,,,,,,,,,, -93000,2.3126903,2.9389691,,,,,,,,,,,,,, -93100,2.39106,2.9339433,,,,,,,,,,,,,, -93200,2.3072193,2.9341018,,,,,,,,,,,,,, -93300,2.2571936,2.9714553,,,,,,,,,,,,,, -93400,2.3568947,2.8940752,,,,,,,,,,,,,, -93500,2.4687915,2.9865425,,,,,,,,,,,,,, -93600,2.4425545,2.9480202,,,,,,,,,,,,,, -93700,2.6322193,3.0259156,,,,,,,,,,,,,, -93800,2.6054149,2.956845,,,,,,,,,,,,,, -93900,2.3866842,2.9224648,,,,,,,,,,,,,, -94000,2.300748,3.0329962,,,,,,,,,,,,,, -94008,,,0.7824856638908386,1.0473655462265017,0.6983399987220764,1.4176557064056396,50000.0,0.5737000107765198,2.070194721221924,10000.0,31656.0486536026,32779.9430668354,31656.0486536026,1118.535723209381,2.3523941040039062,0.0 -94100,2.6000354,3.0228026,,,,,,,,,,,,,, -94200,2.5273085,2.9598756,,,,,,,,,,,,,, -94300,2.4818647,2.9970264,,,,,,,,,,,,,, -94400,2.9281547,2.979559,,,,,,,,,,,,,, -94500,2.5900545,2.9478984,,,,,,,,,,,,,, -94600,2.3933675,2.9460287,,,,,,,,,,,,,, -94700,2.4931755,2.9301891,,,,,,,,,,,,,, -94800,2.6404355,2.9146078,,,,,,,,,,,,,, -94900,2.799525,3.0265942,,,,,,,,,,,,,, -95000,2.4236789,2.9865255,,,,,,,,,,,,,, -95100,2.5269244,2.9855294,,,,,,,,,,,,,, -95200,2.487831,2.9870396,,,,,,,,,,,,,, -95300,2.4823909,2.936939,,,,,,,,,,,,,, -95400,2.5945916,3.0264099,,,,,,,,,,,,,, -95500,2.462522,2.9913507,,,,,,,,,,,,,, -95526,,,0.8123804330825806,0.9565143585205078,0.6977799534797668,1.4305815696716309,50000.0,0.5761000514030457,2.0659029483795166,10000.0,32166.133952617645,33308.10327959061,32166.133952617645,1136.5134971141815,2.400253295898437,0.0 -95600,2.3224647,2.959736,,,,,,,,,,,,,, -95700,2.8698325,2.986237,,,,,,,,,,,,,, -95800,2.3465528,2.9167037,,,,,,,,,,,,,, -95900,2.4529462,2.9966385,,,,,,,,,,,,,, -96000,2.4771712,2.9244146,,,,,,,,,,,,,, -96100,2.5105164,2.990706,,,,,,,,,,,,,, -96200,2.5612352,2.9831917,,,,,,,,,,,,,, -96300,2.4064379,2.990273,,,,,,,,,,,,,, -96400,2.4500713,2.86985,,,,,,,,,,,,,, -96500,2.666628,2.9847667,,,,,,,,,,,,,, -96600,2.4226744,2.9867737,,,,,,,,,,,,,, -96700,2.5958366,2.9212193,,,,,,,,,,,,,, -96800,2.5059574,2.9137564,,,,,,,,,,,,,, -96900,2.4118404,2.9300182,,,,,,,,,,,,,, -97000,2.4614267,2.9945807,,,,,,,,,,,,,, -97045,,,0.7994260191917419,0.9897308945655824,0.6978600025177002,1.441440463066101,50000.0,0.5766000151634216,2.0669610500335693,10000.0,32676.36093187332,33836.25459456444,32676.36093187332,1154.3426752090454,2.4465677738189697,0.0 -97100,2.5384076,2.9665284,,,,,,,,,,,,,, -97200,2.5054064,2.981584,,,,,,,,,,,,,, -97300,2.488123,2.9435,,,,,,,,,,,,,, -97400,2.56678,2.911882,,,,,,,,,,,,,, -97500,2.523904,2.866065,,,,,,,,,,,,,, -97600,2.5068622,2.96529,,,,,,,,,,,,,, -97700,2.6012776,2.9635153,,,,,,,,,,,,,, -97800,2.7249014,2.9281034,,,,,,,,,,,,,, -97900,2.508023,2.9267886,,,,,,,,,,,,,, -98000,2.393203,2.9006245,,,,,,,,,,,,,, -98100,2.498132,2.9303164,,,,,,,,,,,,,, -98200,2.49269,2.9204926,,,,,,,,,,,,,, -98300,2.862156,2.9775136,,,,,,,,,,,,,, -98400,2.421305,2.9136295,,,,,,,,,,,,,, -98500,2.497718,2.9505193,,,,,,,,,,,,,, -98562,,,0.8053650856018066,0.9407090544700624,0.7060999870300293,1.369234323501587,50000.0,0.5819000005722046,1.995455741882324,10000.0,33186.46888136864,34364.05491042137,33186.46888136864,1171.936819076538,2.495965480804444,0.0 -98600,2.5600128,2.956658,,,,,,,,,,,,,, -98700,2.5029817,2.8837786,,,,,,,,,,,,,, -98800,2.79912,2.965109,,,,,,,,,,,,,, -98900,2.4728405,2.9389825,,,,,,,,,,,,,, -99000,2.4796774,2.9414036,,,,,,,,,,,,,, -99100,2.5600128,2.9837968,,,,,,,,,,,,,, -99200,2.3656802,2.968768,,,,,,,,,,,,,, -99300,2.5275831,2.9285927,,,,,,,,,,,,,, -99400,2.6474712,2.9134853,,,,,,,,,,,,,, -99500,2.5328774,2.9827926,,,,,,,,,,,,,, -99600,2.6075168,2.946445,,,,,,,,,,,,,, -99700,2.4485738,3.008473,,,,,,,,,,,,,, -99800,2.4773161,2.8415856,,,,,,,,,,,,,, -99900,2.6952968,2.948205,,,,,,,,,,,,,, -100000,2.7402954,2.926172,,,,,,,,,,,,,, -100080,,,0.7974131107330322,1.000417709350586,0.7032999992370605,1.4029895067214966,50000.0,0.5737000107765198,2.07080340385437,10000.0,33696.61943101883,34891.91277551651,33696.61943101883,1189.5462565422058,2.544489622116089,0.0 -100100,2.5287073,2.9210608,,,,,,,,,,,,,, -100200,2.7186885,2.9153395,,,,,,,,,,,,,, -100300,2.4543242,2.9720905,,,,,,,,,,,,,, -100400,2.5298684,2.9616997,,,,,,,,,,,,,, -100500,2.5051923,2.9802175,,,,,,,,,,,,,, -100600,2.755884,2.8791392,,,,,,,,,,,,,, -100700,2.4871216,2.893655,,,,,,,,,,,,,, -100800,2.3810327,2.8869376,,,,,,,,,,,,,, -100900,2.6568117,2.9298007,,,,,,,,,,,,,, -101000,2.7194815,2.9188793,,,,,,,,,,,,,, -101100,2.517883,2.9303598,,,,,,,,,,,,,, -101200,2.4626966,2.8194606,,,,,,,,,,,,,, -101300,2.6062927,2.9274654,,,,,,,,,,,,,, -101400,2.557495,2.9219863,,,,,,,,,,,,,, -101500,2.7992148,2.9544568,,,,,,,,,,,,,, -101597,,,0.7987882494926453,0.9974828362464904,0.7042399644851685,1.3996078968048096,50000.0,0.5823000073432922,2.036078691482544,10000.0,34206.52481389046,35419.85327458382,34206.52481389046,1207.4842991828918,2.5927844047546387,0.0 -101600,2.5251303,2.868364,,,,,,,,,,,,,, -101700,2.5857706,2.9601026,,,,,,,,,,,,,, -101800,2.8649893,2.9611664,,,,,,,,,,,,,, -101900,2.8107278,3.025955,,,,,,,,,,,,,, -102000,2.390316,2.8653493,,,,,,,,,,,,,, -102100,2.5767784,2.9741468,,,,,,,,,,,,,, -102200,2.5268798,2.8978853,,,,,,,,,,,,,, -102300,2.607628,2.979111,,,,,,,,,,,,,, -102400,2.5133855,2.884283,,,,,,,,,,,,,, -102500,2.682895,2.934723,,,,,,,,,,,,,, -102600,2.5720012,2.9312928,,,,,,,,,,,,,, -102700,2.5882816,2.8895917,,,,,,,,,,,,,, -102800,2.783187,2.9386487,,,,,,,,,,,,,, -102900,2.6866868,2.956502,,,,,,,,,,,,,, -103000,2.7652442,2.9934125,,,,,,,,,,,,,, -103100,2.4909832,2.978461,,,,,,,,,,,,,, -103115,,,0.7995654940605164,0.9808319807052612,0.7089200019836426,1.38383686542511,50000.0,0.5835000276565552,2.0236704349517822,10000.0,34716.570257902145,35947.67550730705,34716.570257902145,1225.1640086174011,2.6405186653137207,0.0 -103200,3.037011,2.9535117,,,,,,,,,,,,,, -103300,2.715563,2.8809497,,,,,,,,,,,,,, -103400,2.5828843,2.9453573,,,,,,,,,,,,,, -103500,2.6230648,2.86658,,,,,,,,,,,,,, -103600,2.8083653,2.9201856,,,,,,,,,,,,,, -103700,2.5096707,2.8456032,,,,,,,,,,,,,, -103800,2.7434387,2.8770208,,,,,,,,,,,,,, -103900,2.710824,2.8804238,,,,,,,,,,,,,, -104000,2.7904603,2.9579532,,,,,,,,,,,,,, -104100,2.7614253,2.9269798,,,,,,,,,,,,,, -104200,2.820611,2.841114,,,,,,,,,,,,,, -104300,2.7720218,2.9836674,,,,,,,,,,,,,, -104400,2.559037,2.9263072,,,,,,,,,,,,,, -104500,2.568793,2.8692236,,,,,,,,,,,,,, -104600,2.6088183,2.888239,,,,,,,,,,,,,, -104632,,,0.8382493257522583,0.8604050874710083,0.7122600078582764,1.3814303874969482,50000.0,0.581000030040741,2.044727802276612,10000.0,35226.643216609955,36475.528554201126,35226.643216609955,1242.8338084220886,2.702409267425537,0.0 -104700,2.7101955,2.8963256,,,,,,,,,,,,,, -104800,2.4532878,2.8994694,,,,,,,,,,,,,, -104900,2.679469,2.8900654,,,,,,,,,,,,,, -105000,2.4822056,2.8782113,,,,,,,,,,,,,, -105100,2.925888,2.9606574,,,,,,,,,,,,,, -105200,2.5482345,2.8474655,,,,,,,,,,,,,, -105300,2.6762242,2.9382513,,,,,,,,,,,,,, -105400,2.714702,2.9584978,,,,,,,,,,,,,, -105500,2.4234982,2.856034,,,,,,,,,,,,,, -105600,2.6423106,2.877427,,,,,,,,,,,,,, -105700,2.5482209,2.962107,,,,,,,,,,,,,, -105800,2.606327,2.8483706,,,,,,,,,,,,,, -105900,2.538033,2.8742986,,,,,,,,,,,,,, -106000,2.5581086,2.8353286,,,,,,,,,,,,,, -106100,2.7201111,2.919501,,,,,,,,,,,,,, -106149,,,0.8146125674247742,0.930175006389618,0.7082599997520447,1.390093445777893,50000.0,0.5835000276565552,2.0411343574523926,10000.0,35736.733736753464,37003.26521921158,35736.733736753464,1260.3823611736298,2.751145124435425,0.0 -106200,2.7574446,2.8860068,,,,,,,,,,,,,, -106300,2.8436267,2.9415016,,,,,,,,,,,,,, -106400,2.7228062,2.912542,,,,,,,,,,,,,, -106500,2.819675,2.9134698,,,,,,,,,,,,,, -106600,2.8162823,2.9402719,,,,,,,,,,,,,, -106700,2.7213385,2.880596,,,,,,,,,,,,,, -106800,2.8093204,2.8755991,,,,,,,,,,,,,, -106900,2.943488,2.9684927,,,,,,,,,,,,,, -107000,2.5645528,2.862852,,,,,,,,,,,,,, -107100,2.5770667,2.8217692,,,,,,,,,,,,,, -107200,2.7442553,2.8898335,,,,,,,,,,,,,, -107300,2.9114313,2.9864795,,,,,,,,,,,,,, -107400,2.8309002,2.9461548,,,,,,,,,,,,,, -107500,2.6609895,2.9192226,,,,,,,,,,,,,, -107600,2.9646695,2.9373343,,,,,,,,,,,,,, -107667,,,0.8170041441917419,0.9116354584693908,0.7121999859809875,1.3599272966384888,50000.0,0.5896000266075134,2.011181116104126,10000.0,36246.702476263046,37530.93986582756,36246.702476263046,1277.9910361766815,2.79943585395813,0.0 -107700,2.6534133,2.95632,,,,,,,,,,,,,, -107800,2.8374841,3.0022635,,,,,,,,,,,,,, -107900,2.8169603,2.933568,,,,,,,,,,,,,, -108000,2.7967117,2.9506578,,,,,,,,,,,,,, -108100,2.8665233,2.9336216,,,,,,,,,,,,,, -108200,2.7469482,2.877435,,,,,,,,,,,,,, -108300,2.7523642,2.8505273,,,,,,,,,,,,,, -108400,2.675357,2.94375,,,,,,,,,,,,,, -108500,2.7492335,2.885545,,,,,,,,,,,,,, -108600,2.9397683,2.8612714,,,,,,,,,,,,,, -108700,2.658571,2.8780608,,,,,,,,,,,,,, -108800,2.596616,2.8845434,,,,,,,,,,,,,, -108900,2.8826,2.9753206,,,,,,,,,,,,,, -109000,2.5862179,2.8822036,,,,,,,,,,,,,, -109100,2.550169,2.8533292,,,,,,,,,,,,,, -109184,,,0.8053650856018066,0.9917887449264526,0.7057399749755859,1.4216396808624268,50000.0,0.5759000182151794,2.094308614730835,10000.0,36756.72316074371,38059.57973694801,36756.72316074371,1296.5142834186554,2.8462047576904297,0.0 -109200,2.683643,2.871597,,,,,,,,,,,,,, -109300,2.8342202,2.9052234,,,,,,,,,,,,,, -109400,2.843645,2.853522,,,,,,,,,,,,,, -109500,2.5072434,2.8564608,,,,,,,,,,,,,, -109600,2.9150643,2.8769727,,,,,,,,,,,,,, -109700,2.9925597,2.9568994,,,,,,,,,,,,,, -109800,2.8459237,2.8581953,,,,,,,,,,,,,, -109900,3.066873,2.9695866,,,,,,,,,,,,,, -110000,2.912346,2.990928,,,,,,,,,,,,,, -110100,2.908204,2.885863,,,,,,,,,,,,,, -110200,2.6249838,2.8138742,,,,,,,,,,,,,, -110300,2.989464,2.97504,,,,,,,,,,,,,, -110400,2.7109046,2.850048,,,,,,,,,,,,,, -110500,2.7717185,2.8264463,,,,,,,,,,,,,, -110600,2.7412944,2.9700494,,,,,,,,,,,,,, -110700,2.7433567,2.853251,,,,,,,,,,,,,, -110701,,,0.8196946382522583,0.9097425937652588,0.7184199690818787,1.3392237424850464,50000.0,0.5900000333786011,1.984406352043152,10000.0,37266.88493037224,38587.46114182472,37266.88493037224,1314.1434371471405,2.8875298500061035,0.0 -110800,2.6250713,2.7464786,,,,,,,,,,,,,, -110900,2.7691576,2.9178102,,,,,,,,,,,,,, -111000,2.716284,2.8776927,,,,,,,,,,,,,, -111100,2.8351474,2.8810546,,,,,,,,,,,,,, -111200,2.8301568,2.8257291,,,,,,,,,,,,,, -111300,2.8249297,2.8373628,,,,,,,,,,,,,, -111400,2.7367547,2.9456196,,,,,,,,,,,,,, -111500,2.7971935,2.8511977,,,,,,,,,,,,,, -111600,2.692951,2.8615415,,,,,,,,,,,,,, -111700,2.7973928,2.879279,,,,,,,,,,,,,, -111800,2.8048375,2.9458241,,,,,,,,,,,,,, -111900,2.6998255,2.8385768,,,,,,,,,,,,,, -112000,2.7690187,2.866998,,,,,,,,,,,,,, -112100,2.7455664,2.9028404,,,,,,,,,,,,,, -112200,2.7739909,2.9269986,,,,,,,,,,,,,, -112219,,,0.8258330225944519,0.8885695338249207,0.7242000102996826,1.3239599466323853,50000.0,0.5934000015258789,1.9667091369628904,10000.0,37777.11248350144,39115.8409371376,37777.11248350144,1332.1986136436462,2.9358580112457275,0.0 -112300,2.8685472,2.8447719,,,,,,,,,,,,,, -112400,3.0214393,2.946935,,,,,,,,,,,,,, -112500,2.7283738,2.8599095,,,,,,,,,,,,,, -112600,2.9539952,2.945599,,,,,,,,,,,,,, -112700,2.7606318,2.950482,,,,,,,,,,,,,, -112800,2.7822485,2.876006,,,,,,,,,,,,,, -112900,2.9769015,2.9133778,,,,,,,,,,,,,, -113000,2.7779942,2.9034238,,,,,,,,,,,,,, -113100,2.8920853,2.8813505,,,,,,,,,,,,,, -113200,2.870047,2.8660088,,,,,,,,,,,,,, -113300,3.1079113,2.9785688,,,,,,,,,,,,,, -113400,2.794618,2.8954031,,,,,,,,,,,,,, -113500,2.7331738,2.8695886,,,,,,,,,,,,,, -113600,2.8893445,2.8828676,,,,,,,,,,,,,, -113700,2.789114,2.8996062,,,,,,,,,,,,,, -113736,,,0.8484932780265808,0.8184272646903992,0.7218599915504456,1.3505167961120603,50000.0,0.5962000489234924,1.992884516716004,10000.0,38287.02067565918,39643.61619019508,38287.02067565918,1349.967808008194,2.9852118492126465,0.0 -113800,2.6252317,2.8387635,,,,,,,,,,,,,, -113900,2.854733,2.8939738,,,,,,,,,,,,,, -114000,2.900682,2.891078,,,,,,,,,,,,,, -114100,3.065844,2.8961954,,,,,,,,,,,,,, -114200,2.9195392,2.8527951,,,,,,,,,,,,,, -114300,2.900616,2.8615732,,,,,,,,,,,,,, -114400,2.8889613,2.8849263,,,,,,,,,,,,,, -114500,2.8629892,2.8773732,,,,,,,,,,,,,, -114600,3.1242175,2.893673,,,,,,,,,,,,,, -114700,2.7440448,2.8211238,,,,,,,,,,,,,, -114800,3.029905,2.8348646,,,,,,,,,,,,,, -114900,2.9672325,2.8880415,,,,,,,,,,,,,, -115000,2.731505,2.771812,,,,,,,,,,,,,, -115100,2.9322467,2.924613,,,,,,,,,,,,,, -115200,2.8419192,2.8479934,,,,,,,,,,,,,, -115253,,,0.8364556431770325,0.8760803937911987,0.7188000082969666,1.3691059350967407,50000.0,0.5914000272750854,2.007413864135742,10000.0,38796.95326471329,40171.41336941719,38796.95326471329,1367.7323701381683,3.036560297012329,0.0 -115300,2.9321725,2.8425708,,,,,,,,,,,,,, -115400,2.969312,2.8592153,,,,,,,,,,,,,, -115500,2.7539456,2.809803,,,,,,,,,,,,,, -115600,2.8020058,2.8340542,,,,,,,,,,,,,, -115700,2.819115,2.8728662,,,,,,,,,,,,,, -115800,2.882468,2.7718437,,,,,,,,,,,,,, -115900,2.877201,2.8652,,,,,,,,,,,,,, -116000,2.9071417,2.8027425,,,,,,,,,,,,,, -116100,2.793646,2.8290882,,,,,,,,,,,,,, -116200,2.8523316,2.8130739,,,,,,,,,,,,,, -116300,3.0075347,2.8177404,,,,,,,,,,,,,, -116400,3.1522152,2.886409,,,,,,,,,,,,,, -116500,2.7420778,2.8288722,,,,,,,,,,,,,, -116600,3.0200934,2.8570173,,,,,,,,,,,,,, -116700,2.9103763,2.829515,,,,,,,,,,,,,, -116770,,,0.8357780575752258,0.8407694101333618,0.7210800051689148,1.322500228881836,50000.0,0.5920000076293945,1.972323298454285,10000.0,39307.07882666588,40699.153477191925,39307.07882666588,1385.2446112632751,3.0897364616394043,0.0 -116800,2.8556292,2.8381662,,,,,,,,,,,,,, -116900,2.8369431,2.8737376,,,,,,,,,,,,,, -117000,2.961077,2.8314397,,,,,,,,,,,,,, -117100,2.8158717,2.8308914,,,,,,,,,,,,,, -117200,2.8394916,2.8590412,,,,,,,,,,,,,, -117300,3.3585713,2.8518648,,,,,,,,,,,,,, -117400,2.9417264,2.8844037,,,,,,,,,,,,,, -117500,2.7958374,2.810457,,,,,,,,,,,,,, -117600,2.8850696,2.9087536,,,,,,,,,,,,,, -117700,2.928195,2.8189065,,,,,,,,,,,,,, -117800,2.9437745,2.8332171,,,,,,,,,,,,,, -117900,3.01139,2.8470984,,,,,,,,,,,,,, -118000,3.1621585,2.877093,,,,,,,,,,,,,, -118100,3.0497913,2.865515,,,,,,,,,,,,,, -118200,3.0187376,2.876367,,,,,,,,,,,,,, -118288,,,0.8320910334587097,0.842000424861908,0.7231999635696411,1.3130041360855105,50000.0,0.5981000065803528,1.9613901376724243,10000.0,39817.02921366692,41226.97157096863,39817.02921366692,1403.0130491256714,3.140377759933472,0.0 -118300,3.1686337,2.8243353,,,,,,,,,,,,,, -118400,2.9903333,2.7948334,,,,,,,,,,,,,, -118500,3.0721636,2.847772,,,,,,,,,,,,,, -118600,2.898941,2.7946186,,,,,,,,,,,,,, -118700,2.7193522,2.7381735,,,,,,,,,,,,,, -118800,2.8203607,2.8215356,,,,,,,,,,,,,, -118900,2.833891,2.8392246,,,,,,,,,,,,,, -119000,2.8245356,2.7677426,,,,,,,,,,,,,, -119100,3.0643985,2.8649607,,,,,,,,,,,,,, -119200,2.8427513,2.8498564,,,,,,,,,,,,,, -119300,2.8886197,2.8051353,,,,,,,,,,,,,, -119400,2.977174,2.8367558,,,,,,,,,,,,,, -119500,3.0543625,2.8262172,,,,,,,,,,,,,, -119600,2.931167,2.9079165,,,,,,,,,,,,,, -119700,3.0368006,2.845384,,,,,,,,,,,,,, -119800,3.002157,2.8435056,,,,,,,,,,,,,, -119806,,,0.8359375,0.8376394510269165,0.7249799966812134,1.3017497062683103,50000.0,0.6025000214576721,1.9448192119598389,10000.0,40327.05902671814,41755.55356431008,40327.05902671814,1421.460168838501,3.1964597702026367,0.0 -119900,2.8905916,2.805486,,,,,,,,,,,,,, -120000,3.0175648,2.8485296,,,,,,,,,,,,,, -120100,2.945576,2.8322966,,,,,,,,,,,,,, -120200,2.97951,2.8331118,,,,,,,,,,,,,, -120300,2.971188,2.791728,,,,,,,,,,,,,, -120400,3.1092699,2.831803,,,,,,,,,,,,,, -120500,3.1257875,2.8094332,,,,,,,,,,,,,, -120600,2.8478627,2.8234704,,,,,,,,,,,,,, -120700,2.937364,2.8542895,,,,,,,,,,,,,, -120800,3.1806452,2.82182,,,,,,,,,,,,,, -120900,3.0996823,2.9018393,,,,,,,,,,,,,, -121000,3.08114,2.8092396,,,,,,,,,,,,,, -121100,3.0744529,2.8353643,,,,,,,,,,,,,, -121200,2.8834074,2.7918766,,,,,,,,,,,,,, -121300,2.994573,2.7802205,,,,,,,,,,,,,, -121323,,,0.8392458558082581,0.7929951548576355,0.7260000109672546,1.2777409553527832,50000.0,0.5984000563621521,1.9204200506210327,10000.0,40836.98566198349,42283.02461075783,40836.98566198349,1438.906097650528,3.2463817596435547,0.0 -121400,3.057418,2.8321729,,,,,,,,,,,,,, -121500,3.2336042,2.8807268,,,,,,,,,,,,,, -121600,3.1183846,2.7836523,,,,,,,,,,,,,, -121700,3.2775633,2.8374104,,,,,,,,,,,,,, -121800,3.1485713,2.8970265,,,,,,,,,,,,,, -121900,2.88005,2.8071718,,,,,,,,,,,,,, -122000,2.8483014,2.7815125,,,,,,,,,,,,,, -122100,3.1027389,2.8336163,,,,,,,,,,,,,, -122200,3.0165522,2.776326,,,,,,,,,,,,,, -122300,2.9650924,2.8485618,,,,,,,,,,,,,, -122400,3.1051593,2.789177,,,,,,,,,,,,,, -122500,3.0583594,2.880232,,,,,,,,,,,,,, -122600,2.966259,2.796255,,,,,,,,,,,,,, -122700,3.0995376,2.8383024,,,,,,,,,,,,,, -122800,3.0512788,2.783259,,,,,,,,,,,,,, -122840,,,0.856465220451355,0.7566264867782593,0.7263199687004089,1.3030215501785278,50000.0,0.6027000546455383,1.939791440963745,10000.0,41347.065252542496,42810.877432107925,41347.065252542496,1456.579344749451,3.2976536750793457,0.0 -122900,3.1572847,2.8128314,,,,,,,,,,,,,, -123000,3.0475695,2.8813586,,,,,,,,,,,,,, -123100,3.0849435,2.7343905,,,,,,,,,,,,,, -123200,2.9270651,2.7401955,,,,,,,,,,,,,, -123300,2.9875803,2.744205,,,,,,,,,,,,,, -123400,3.215843,2.792992,,,,,,,,,,,,,, -123500,2.927447,2.8294077,,,,,,,,,,,,,, -123600,2.9925928,2.811465,,,,,,,,,,,,,, -123700,3.2681565,2.8251452,,,,,,,,,,,,,, -123800,3.0988955,2.8169746,,,,,,,,,,,,,, -123900,3.0299404,2.8153543,,,,,,,,,,,,,, -124000,3.1881244,2.840572,,,,,,,,,,,,,, -124100,3.170367,2.7994456,,,,,,,,,,,,,, -124200,3.1338103,2.8247573,,,,,,,,,,,,,, -124300,3.0940416,2.8251066,,,,,,,,,,,,,, -124357,,,0.8516820669174194,0.7917265295982361,0.7299000024795532,1.3059499263763428,50000.0,0.6040000319480896,1.9366973638534544,10000.0,41857.15559768677,43339.1133646965,41857.15559768677,1474.6264395713806,3.3471579551696777,0.0 -124400,3.1111739,2.8066473,,,,,,,,,,,,,, -124500,3.0508943,2.823958,,,,,,,,,,,,,, -124600,3.0192032,2.7755485,,,,,,,,,,,,,, -124700,3.1197476,2.8690684,,,,,,,,,,,,,, -124800,3.094603,2.8377476,,,,,,,,,,,,,, -124900,3.0901191,2.7969656,,,,,,,,,,,,,, -125000,3.3435807,2.7395775,,,,,,,,,,,,,, -125100,3.0644753,2.848413,,,,,,,,,,,,,, -125200,3.3543546,2.8364222,,,,,,,,,,,,,, -125300,3.2502792,2.857495,,,,,,,,,,,,,, -125400,3.014863,2.8091319,,,,,,,,,,,,,, -125500,3.1801076,2.8510227,,,,,,,,,,,,,, -125600,3.185674,2.7705665,,,,,,,,,,,,,, -125700,3.1692443,2.7917368,,,,,,,,,,,,,, -125800,3.1664903,2.8128705,,,,,,,,,,,,,, -125875,,,0.8467992544174194,0.8255721926689148,0.7300599813461304,1.3159984350204468,50000.0,0.6016000509262085,1.963367819786072,10000.0,42367.3814136982,43867.16249752045,42367.3814136982,1492.354014635086,3.394263982772827,0.0 -125900,3.1259396,2.7916105,,,,,,,,,,,,,, -126000,3.033216,2.747188,,,,,,,,,,,,,, -126100,3.4997916,2.7973795,,,,,,,,,,,,,, -126200,3.1800957,2.8515124,,,,,,,,,,,,,, -126300,2.8942177,2.7831645,,,,,,,,,,,,,, -126400,3.2092373,2.7608354,,,,,,,,,,,,,, -126500,3.169075,2.7718341,,,,,,,,,,,,,, -126600,3.1496258,2.7972455,,,,,,,,,,,,,, -126700,3.3886087,2.730897,,,,,,,,,,,,,, -126800,3.1313806,2.7173855,,,,,,,,,,,,,, -126900,3.3372254,2.8426557,,,,,,,,,,,,,, -127000,3.3781314,2.8075614,,,,,,,,,,,,,, -127100,3.2051005,2.7711954,,,,,,,,,,,,,, -127200,3.0522969,2.8152986,,,,,,,,,,,,,, -127300,3.1234448,2.7880461,,,,,,,,,,,,,, -127392,,,0.8463209271430969,0.8251389861106873,0.7315599918365479,1.3149410486221311,50000.0,0.6080000400543213,1.9545058012008667,10000.0,42877.36577987671,44394.94526076317,42877.36577987671,1510.0505783557892,3.447601556777954,0.0 -127400,3.183491,2.730644,,,,,,,,,,,,,, -127500,3.2173624,2.7383475,,,,,,,,,,,,,, -127600,3.1510563,2.727718,,,,,,,,,,,,,, -127700,3.2184703,2.7904015,,,,,,,,,,,,,, -127800,3.4302607,2.8254547,,,,,,,,,,,,,, -127900,3.0834794,2.7975707,,,,,,,,,,,,,, -128000,3.042398,2.7846622,,,,,,,,,,,,,, -128100,3.2151673,2.7563777,,,,,,,,,,,,,, -128200,3.2227764,2.7780445,,,,,,,,,,,,,, -128300,3.3985338,2.7753768,,,,,,,,,,,,,, -128400,3.1437702,2.766822,,,,,,,,,,,,,, -128500,3.174974,2.7039864,,,,,,,,,,,,,, -128600,3.0652306,2.7885313,,,,,,,,,,,,,, -128700,3.021096,2.7551267,,,,,,,,,,,,,, -128800,3.3372161,2.7823544,,,,,,,,,,,,,, -128900,3.0716932,2.8044152,,,,,,,,,,,,,, -128909,,,0.8551897406578064,0.7878064513206482,0.7348999977111816,1.2902276515960691,50000.0,0.6148000359535217,1.911783456802368,10000.0,43387.268907785416,44922.67690491676,43387.268907785416,1527.7779560089111,3.500355958938598,0.0 -129000,3.217732,2.7474723,,,,,,,,,,,,,, -129100,3.2971184,2.773878,,,,,,,,,,,,,, -129200,3.239505,2.7287374,,,,,,,,,,,,,, -129300,3.2437859,2.8307226,,,,,,,,,,,,,, -129400,3.4329267,2.7660284,,,,,,,,,,,,,, -129500,3.4180596,2.7496724,,,,,,,,,,,,,, -129600,3.120835,2.7175581,,,,,,,,,,,,,, -129700,3.172758,2.7567868,,,,,,,,,,,,,, -129800,3.3222804,2.799285,,,,,,,,,,,,,, -129900,3.3273897,2.7573335,,,,,,,,,,,,,, -130000,3.2442534,2.7356317,,,,,,,,,,,,,, -130100,3.143951,2.6761706,,,,,,,,,,,,,, -130200,3.2187858,2.7839794,,,,,,,,,,,,,, -130300,3.207181,2.7965658,,,,,,,,,,,,,, -130400,3.1831067,2.7271695,,,,,,,,,,,,,, -130426,,,0.8754384517669678,0.675972580909729,0.7377600073814392,1.2478464841842651,50000.0,0.6135000586509705,1.874533653259277,10000.0,43897.26140189171,45450.63581967354,43897.26140189171,1545.6414897441864,3.5544705390930176,0.0 -130500,3.4243197,2.81987,,,,,,,,,,,,,, -130600,3.1352975,2.752422,,,,,,,,,,,,,, -130700,3.3774166,2.7062764,,,,,,,,,,,,,, -130800,3.3405101,2.724394,,,,,,,,,,,,,, -130900,3.248004,2.809533,,,,,,,,,,,,,, -131000,3.0576735,2.73277,,,,,,,,,,,,,, -131100,3.2492406,2.7895448,,,,,,,,,,,,,, -131200,3.2748375,2.7530077,,,,,,,,,,,,,, -131300,3.2183783,2.7610908,,,,,,,,,,,,,, -131400,3.5519278,2.7707608,,,,,,,,,,,,,, -131500,3.335539,2.807872,,,,,,,,,,,,,, -131600,3.2514555,2.7396126,,,,,,,,,,,,,, -131700,3.2879715,2.7154315,,,,,,,,,,,,,, -131800,3.3130097,2.764637,,,,,,,,,,,,,, -131900,3.126372,2.6972194,,,,,,,,,,,,,, -131942,,,0.8747608065605164,0.712469220161438,0.7366399765014648,1.283627986907959,50000.0,0.610200047492981,1.9361129999160769,10000.0,44407.33543777466,45979.16612029076,44407.33543777466,1563.994446992874,3.609476804733277,0.0 -132000,3.335428,2.7395203,,,,,,,,,,,,,, -132100,3.1355639,2.6728225,,,,,,,,,,,,,, -132200,3.2934184,2.7301714,,,,,,,,,,,,,, -132300,3.3538556,2.6946547,,,,,,,,,,,,,, -132400,3.3364947,2.8466938,,,,,,,,,,,,,, -132500,3.2271307,2.7356243,,,,,,,,,,,,,, -132600,3.1987624,2.7311947,,,,,,,,,,,,,, -132700,3.3754375,2.7180932,,,,,,,,,,,,,, -132800,3.2605538,2.722168,,,,,,,,,,,,,, -132900,3.6034439,2.8187013,,,,,,,,,,,,,, -133000,3.3268247,2.7500417,,,,,,,,,,,,,, -133100,3.399614,2.7536025,,,,,,,,,,,,,, -133200,3.248289,2.765872,,,,,,,,,,,,,, -133300,3.386804,2.7396123,,,,,,,,,,,,,, -133400,3.0692034,2.6913571,,,,,,,,,,,,,, -133460,,,0.872468888759613,0.7190256714820862,0.7376599907875061,1.2690998315811155,50000.0,0.617900013923645,1.9060282707214355,10000.0,44917.43026852608,46507.20842504501,44917.43026852608,1581.8467528820038,3.656475305557251,0.0 -133500,3.5187871,2.749935,,,,,,,,,,,,,, -133600,3.6468077,2.7512937,,,,,,,,,,,,,, -133700,3.3184357,2.742194,,,,,,,,,,,,,, -133800,3.4041517,2.716039,,,,,,,,,,,,,, -133900,3.3582146,2.7603798,,,,,,,,,,,,,, -134000,3.4942513,2.715288,,,,,,,,,,,,,, -134100,3.4716942,2.7542686,,,,,,,,,,,,,, -134200,3.4007819,2.6852608,,,,,,,,,,,,,, -134300,3.3808825,2.7599847,,,,,,,,,,,,,, -134400,3.575386,2.773661,,,,,,,,,,,,,, -134500,3.3672357,2.7374015,,,,,,,,,,,,,, -134600,3.399452,2.7729065,,,,,,,,,,,,,, -134700,3.336278,2.7291062,,,,,,,,,,,,,, -134800,3.4802065,2.7293217,,,,,,,,,,,,,, -134900,3.4317396,2.6973443,,,,,,,,,,,,,, -134977,,,0.8716716766357422,0.7018413543701172,0.7393400073051453,1.2499456405639648,50000.0,0.6167000532150269,1.884137749671936,10000.0,45427.43486452103,47034.92509889603,45427.43486452103,1599.4554841518402,3.710766077041626,0.0 -135000,3.3210626,2.7470298,,,,,,,,,,,,,, -135100,3.566021,2.7876868,,,,,,,,,,,,,, -135200,3.4865067,2.7095933,,,,,,,,,,,,,, -135300,3.2368412,2.694469,,,,,,,,,,,,,, -135400,3.448157,2.7255375,,,,,,,,,,,,,, -135500,3.2974272,2.6711235,,,,,,,,,,,,,, -135600,3.3096638,2.7150013,,,,,,,,,,,,,, -135700,3.3305497,2.7911222,,,,,,,,,,,,,, -135800,3.2263126,2.7388868,,,,,,,,,,,,,, -135900,3.3853626,2.7429943,,,,,,,,,,,,,, -136000,3.2709897,2.7508943,,,,,,,,,,,,,, -136100,3.1031604,2.7033165,,,,,,,,,,,,,, -136200,3.3527157,2.7181625,,,,,,,,,,,,,, -136300,3.3672895,2.7302098,,,,,,,,,,,,,, -136400,3.3655415,2.744705,,,,,,,,,,,,,, -136494,,,0.8722297549247742,0.7077077031135559,0.7393400073051453,1.2572718858718872,50000.0,0.6141000390052795,1.8976376056671145,10000.0,45937.36474323273,47562.58168148994,45937.36474323273,1617.0773251056671,3.766352891921997,0.0 -136500,3.416628,2.7331998,,,,,,,,,,,,,, -136600,3.5766633,2.7261968,,,,,,,,,,,,,, -136700,3.2480378,2.7218602,,,,,,,,,,,,,, -136800,3.3906507,2.699233,,,,,,,,,,,,,, -136900,3.3963828,2.7424655,,,,,,,,,,,,,, -137000,3.5523643,2.6968608,,,,,,,,,,,,,, -137100,3.6239321,2.7195334,,,,,,,,,,,,,, -137200,3.4972217,2.7242846,,,,,,,,,,,,,, -137300,3.4684975,2.6568162,,,,,,,,,,,,,, -137400,3.1998281,2.647927,,,,,,,,,,,,,, -137500,3.6254497,2.7805474,,,,,,,,,,,,,, -137600,3.3637848,2.6765842,,,,,,,,,,,,,, -137700,3.6183593,2.6980572,,,,,,,,,,,,,, -137800,3.495407,2.686266,,,,,,,,,,,,,, -137900,3.459477,2.6872404,,,,,,,,,,,,,, -138000,3.5597298,2.6933937,,,,,,,,,,,,,, -138011,,,0.87015700340271,0.7002887725830078,0.7409999966621399,1.2513980865478516,50000.0,0.6186000108718872,1.8796846866607664,10000.0,46447.27852892876,48090.08140492439,46447.27852892876,1634.5576055049896,3.823425769805908,0.0 -138100,3.353378,2.7146316,,,,,,,,,,,,,, -138200,3.2124326,2.6777303,,,,,,,,,,,,,, -138300,3.3541253,2.6630113,,,,,,,,,,,,,, -138400,3.3192751,2.6780634,,,,,,,,,,,,,, -138500,3.60417,2.765835,,,,,,,,,,,,,, -138600,3.6418371,2.7190673,,,,,,,,,,,,,, -138700,3.3358324,2.6933603,,,,,,,,,,,,,, -138800,3.6284807,2.7190535,,,,,,,,,,,,,, -138900,3.5421169,2.7388067,,,,,,,,,,,,,, -139000,3.413825,2.7357984,,,,,,,,,,,,,, -139100,3.5134087,2.7322218,,,,,,,,,,,,,, -139200,3.3006418,2.686655,,,,,,,,,,,,,, -139300,3.3797996,2.6710413,,,,,,,,,,,,,, -139400,3.5654297,2.7545815,,,,,,,,,,,,,, -139500,3.3396184,2.6870399,,,,,,,,,,,,,, -139529,,,0.8952487111091614,0.6294954419136047,0.7395399808883667,1.2603678703308103,50000.0,0.6126000285148621,1.902232766151428,10000.0,46957.439453840256,48618.27424407005,46957.439453840256,1652.4860928058624,3.877878189086914,0.0 -139600,3.650565,2.7172577,,,,,,,,,,,,,, -139700,3.2513819,2.6575048,,,,,,,,,,,,,, -139800,3.4759867,2.6731305,,,,,,,,,,,,,, -139900,3.1354403,2.6547775,,,,,,,,,,,,,, -140000,3.841896,2.7041068,,,,,,,,,,,,,, -140100,3.4466722,2.6974237,,,,,,,,,,,,,, -140200,3.6429744,2.7066643,,,,,,,,,,,,,, -140300,3.7245836,2.69765,,,,,,,,,,,,,, -140400,3.632567,2.6936438,,,,,,,,,,,,,, -140500,3.420841,2.694687,,,,,,,,,,,,,, -140600,3.6393273,2.7205567,,,,,,,,,,,,,, -140700,3.762559,2.6814008,,,,,,,,,,,,,, -140800,3.3010411,2.6464298,,,,,,,,,,,,,, -140900,3.44387,2.6630692,,,,,,,,,,,,,, -141000,3.396325,2.6067696,,,,,,,,,,,,,, -141046,,,0.8868981003761292,0.6444560289382935,0.7418599724769592,1.247159719467163,50000.0,0.6173000335693359,1.898269772529602,10000.0,47467.50153207779,49146.01747059822,47467.50153207779,1670.0611016750336,3.935212373733521,0.0 -141100,3.7081692,2.6925755,,,,,,,,,,,,,, -141200,3.379502,2.5936983,,,,,,,,,,,,,, -141300,3.4173787,2.770205,,,,,,,,,,,,,, -141400,3.8427882,2.7288601,,,,,,,,,,,,,, -141500,3.648598,2.7488678,,,,,,,,,,,,,, -141600,3.5189633,2.706734,,,,,,,,,,,,,, -141700,3.5653028,2.6808538,,,,,,,,,,,,,, -141800,3.686695,2.7157183,,,,,,,,,,,,,, -141900,3.6744123,2.6662204,,,,,,,,,,,,,, -142000,3.8627057,2.694522,,,,,,,,,,,,,, -142100,3.5915346,2.714109,,,,,,,,,,,,,, -142200,3.5510974,2.7328305,,,,,,,,,,,,,, -142300,3.4741495,2.654805,,,,,,,,,,,,,, -142400,3.6211667,2.7129734,,,,,,,,,,,,,, -142500,3.6540654,2.6078348,,,,,,,,,,,,,, -142564,,,0.8874959945678711,0.652847409248352,0.741599977016449,1.251569747924805,50000.0,0.6184000372886658,1.8916077613830569,10000.0,47977.68201828003,49674.135682582855,47977.68201828003,1687.8899908065796,3.9951000213623047,0.0 -142600,3.223568,2.5888262,,,,,,,,,,,,,, -142700,3.3734503,2.682685,,,,,,,,,,,,,, -142800,3.3396165,2.611032,,,,,,,,,,,,,, -142900,3.4650235,2.7082095,,,,,,,,,,,,,, -143000,3.6499236,2.6949034,,,,,,,,,,,,,, -143100,3.6359327,2.6760993,,,,,,,,,,,,,, -143200,3.6111612,2.6864417,,,,,,,,,,,,,, -143300,3.5347347,2.6794121,,,,,,,,,,,,,, -143400,3.7214034,2.6654224,,,,,,,,,,,,,, -143500,3.782486,2.6605186,,,,,,,,,,,,,, -143600,3.6948733,2.6557517,,,,,,,,,,,,,, -143700,3.4705758,2.7062604,,,,,,,,,,,,,, -143800,3.5596557,2.6966567,,,,,,,,,,,,,, -143900,3.6891475,2.6649094,,,,,,,,,,,,,, -144000,3.7128532,2.7139833,,,,,,,,,,,,,, -144081,,,0.8844068646430969,0.6725389957427979,0.7425000071525574,1.2619949579238892,50000.0,0.6184000372886658,1.899420022964477,10000.0,48487.624255895615,50201.76250863075,48487.624255895615,1705.4653916358948,4.055452346801758,0.0 -144100,3.4019217,2.6549766,,,,,,,,,,,,,, -144200,3.5218606,2.6876972,,,,,,,,,,,,,, -144300,3.792307,2.6822298,,,,,,,,,,,,,, -144400,3.6079378,2.65949,,,,,,,,,,,,,, -144500,3.5778012,2.6882184,,,,,,,,,,,,,, -144600,3.3933706,2.6275115,,,,,,,,,,,,,, -144700,3.8220465,2.7003524,,,,,,,,,,,,,, -144800,3.7727327,2.7634509,,,,,,,,,,,,,, -144900,3.6240315,2.676989,,,,,,,,,,,,,, -145000,3.6698143,2.6628175,,,,,,,,,,,,,, -145100,3.6056156,2.6913927,,,,,,,,,,,,,, -145200,3.6938648,2.6761665,,,,,,,,,,,,,, -145300,3.4364862,2.6137094,,,,,,,,,,,,,, -145400,3.6477926,2.6450675,,,,,,,,,,,,,, -145500,3.7981913,2.677507,,,,,,,,,,,,,, -145597,,,0.893973171710968,0.6204026341438293,0.750499963760376,1.2134588956832886,50000.0,0.6247000098228455,1.8494324684143064,10000.0,48997.56184363365,50729.53057074547,48997.56184363365,1723.1903104782104,4.112022161483765,0.0 -145600,3.6623366,2.591011,,,,,,,,,,,,,, -145700,3.4824142,2.6382833,,,,,,,,,,,,,, -145800,3.7190008,2.6955109,,,,,,,,,,,,,, -145900,3.2706735,2.5608032,,,,,,,,,,,,,, -146000,3.787554,2.68517,,,,,,,,,,,,,, -146100,3.8899906,2.749223,,,,,,,,,,,,,, -146200,3.5865157,2.634353,,,,,,,,,,,,,, -146300,3.6553347,2.6621273,,,,,,,,,,,,,, -146400,3.8330905,2.6633348,,,,,,,,,,,,,, -146500,3.7128048,2.6331396,,,,,,,,,,,,,, -146600,3.8251553,2.6653,,,,,,,,,,,,,, -146700,3.6220407,2.6338959,,,,,,,,,,,,,, -146800,3.6271572,2.658144,,,,,,,,,,,,,, -146900,3.8834202,2.717127,,,,,,,,,,,,,, -147000,3.691262,2.7060223,,,,,,,,,,,,,, -147100,3.8379304,2.6778007,,,,,,,,,,,,,, -147114,,,0.8909637928009033,0.6227953433990479,0.7475199699401855,1.2216311693191528,50000.0,0.6261000037193298,1.849046230316162,10000.0,49507.54261040688,51257.21557235718,49507.54261040688,1740.7729632854462,4.185078859329224,0.0 -147200,3.7257428,2.5965316,,,,,,,,,,,,,, -147300,3.875766,2.6564941,,,,,,,,,,,,,, -147400,3.722391,2.7008529,,,,,,,,,,,,,, -147500,3.663048,2.6609163,,,,,,,,,,,,,, -147600,3.6096287,2.6231704,,,,,,,,,,,,,, -147700,3.6405241,2.6522944,,,,,,,,,,,,,, -147800,3.884256,2.6422298,,,,,,,,,,,,,, -147900,3.5272353,2.6468267,,,,,,,,,,,,,, -148000,3.743457,2.6157074,,,,,,,,,,,,,, -148100,3.6734638,2.6651502,,,,,,,,,,,,,, -148200,3.9148128,2.6407015,,,,,,,,,,,,,, -148300,3.7101943,2.6445935,,,,,,,,,,,,,, -148400,3.7275307,2.6678958,,,,,,,,,,,,,, -148500,3.6970637,2.6064017,,,,,,,,,,,,,, -148600,3.8355532,2.7036586,,,,,,,,,,,,,, -148632,,,0.9108936190605164,0.5526172518730164,0.7495799660682678,1.2158610820770264,50000.0,0.6324000358581543,1.8357360363006592,10000.0,50017.68539023399,51785.398183107376,50017.68539023399,1758.690044879913,4.2591400146484375,0.0 -148700,3.6188958,2.6542122,,,,,,,,,,,,,, -148800,3.7145092,2.6862578,,,,,,,,,,,,,, -148900,3.9698863,2.6267953,,,,,,,,,,,,,, -149000,3.7912679,2.6212456,,,,,,,,,,,,,, -149100,3.6774023,2.5782003,,,,,,,,,,,,,, -149200,3.5027127,2.5697389,,,,,,,,,,,,,, -149300,3.7991834,2.637763,,,,,,,,,,,,,, -149400,3.7553146,2.6547163,,,,,,,,,,,,,, -149500,3.911331,2.6536353,,,,,,,,,,,,,, -149600,3.7420669,2.6572099,,,,,,,,,,,,,, -149700,3.639297,2.6134253,,,,,,,,,,,,,, -149800,3.8994043,2.6858351,,,,,,,,,,,,,, -149900,3.7018235,2.6307735,,,,,,,,,,,,,, -150000,3.843244,2.692387,,,,,,,,,,,,,, -150100,3.7854626,2.6273608,,,,,,,,,,,,,, -150148,,,0.906887710094452,0.5834668278694153,0.751039981842041,1.21971595287323,50000.0,0.6328000426292419,1.8467143774032595,10000.0,50527.5884706974,52313.0264942646,50527.5884706974,1776.3126661777496,4.313075065612793,0.0 -150200,3.762062,2.6727757,,,,,,,,,,,,,, -150300,3.821278,2.6944199,,,,,,,,,,,,,, -150400,3.7326865,2.6257741,,,,,,,,,,,,,, -150500,3.6351573,2.6257055,,,,,,,,,,,,,, -150600,3.8087487,2.640324,,,,,,,,,,,,,, -150700,3.6391587,2.624787,,,,,,,,,,,,,, -150800,3.7139685,2.652455,,,,,,,,,,,,,, -150900,4.047852,2.655116,,,,,,,,,,,,,, -151000,3.8068528,2.6080756,,,,,,,,,,,,,, -151100,3.6661339,2.5987625,,,,,,,,,,,,,, -151200,3.4967797,2.5923946,,,,,,,,,,,,,, -151300,3.6233175,2.5843537,,,,,,,,,,,,,, -151400,3.785189,2.613419,,,,,,,,,,,,,, -151500,3.6252213,2.6417425,,,,,,,,,,,,,, -151600,3.9445574,2.682771,,,,,,,,,,,,,, -151666,,,0.906628668308258,0.5731101632118225,0.753279983997345,1.204783320426941,50000.0,0.6306000351905823,1.8276084661483765,10000.0,51037.737275362015,52841.0115032196,51037.737275362015,1794.0457956790924,4.367350816726685,0.0 -151700,3.765385,2.6405976,,,,,,,,,,,,,, -151800,3.6847155,2.6190631,,,,,,,,,,,,,, -151900,3.7548249,2.6046476,,,,,,,,,,,,,, -152000,3.5826766,2.58651,,,,,,,,,,,,,, -152100,3.7470942,2.619168,,,,,,,,,,,,,, -152200,3.7599204,2.5937624,,,,,,,,,,,,,, -152300,3.6580765,2.6119545,,,,,,,,,,,,,, -152400,3.8722296,2.5773544,,,,,,,,,,,,,, -152500,3.910245,2.5718079,,,,,,,,,,,,,, -152600,4.356039,2.6382973,,,,,,,,,,,,,, -152700,3.7733843,2.643044,,,,,,,,,,,,,, -152800,3.8502753,2.640646,,,,,,,,,,,,,, -152900,3.9432812,2.653877,,,,,,,,,,,,,, -153000,3.9502678,2.6226103,,,,,,,,,,,,,, -153100,3.5857463,2.6436572,,,,,,,,,,,,,, -153183,,,0.9112324714660645,0.56153404712677,0.7550399899482727,1.1990495920181274,50000.0,0.6339000463485718,1.8135104179382324,10000.0,51547.88285636902,53368.90864777565,51547.88285636902,1811.6900601387024,4.42540454864502,0.0 -153200,3.9900045,2.6036403,,,,,,,,,,,,,, -153300,3.9569259,2.7259889,,,,,,,,,,,,,, -153400,3.8059208,2.6199899,,,,,,,,,,,,,, -153500,4.0366106,2.6461828,,,,,,,,,,,,,, -153600,3.6999197,2.5683043,,,,,,,,,,,,,, -153700,3.6804178,2.5994005,,,,,,,,,,,,,, -153800,3.7121174,2.6207871,,,,,,,,,,,,,, -153900,3.756894,2.5668344,,,,,,,,,,,,,, -154000,3.9219456,2.6023576,,,,,,,,,,,,,, -154100,3.7572963,2.6400952,,,,,,,,,,,,,, -154200,3.9799426,2.6914773,,,,,,,,,,,,,, -154300,3.935169,2.6006145,,,,,,,,,,,,,, -154400,3.72945,2.6036725,,,,,,,,,,,,,, -154500,3.860763,2.6036162,,,,,,,,,,,,,, -154600,3.8490152,2.617501,,,,,,,,,,,,,, -154700,4.0789313,2.6814153,,,,,,,,,,,,,, -154701,,,0.9130659699440002,0.5653855800628662,0.7539599537849426,1.209978461265564,50000.0,0.6313000321388245,1.840086817741394,10000.0,52058.02461576462,53896.99075818062,52058.02461576462,1829.5243427753448,4.483033180236816,0.0 -154800,3.420407,2.5542798,,,,,,,,,,,,,, -154900,3.9013054,2.6225367,,,,,,,,,,,,,, -155000,3.990796,2.6109328,,,,,,,,,,,,,, -155100,4.130071,2.6604564,,,,,,,,,,,,,, -155200,3.766012,2.623285,,,,,,,,,,,,,, -155300,3.6572075,2.5865393,,,,,,,,,,,,,, -155400,3.8404994,2.6366713,,,,,,,,,,,,,, -155500,3.8489313,2.5542543,,,,,,,,,,,,,, -155600,4.079722,2.6038864,,,,,,,,,,,,,, -155700,3.707154,2.619715,,,,,,,,,,,,,, -155800,3.8032355,2.5979578,,,,,,,,,,,,,, -155900,3.887192,2.618746,,,,,,,,,,,,,, -156000,4.0319595,2.6299796,,,,,,,,,,,,,, -156100,3.9096138,2.618599,,,,,,,,,,,,,, -156200,4.033185,2.6363423,,,,,,,,,,,,,, -156219,,,0.9151187539100648,0.5467893481254578,0.7532999515533447,1.2041923999786377,50000.0,0.6291000247001648,1.8272104263305664,10000.0,52568.0820877552,54425.057513952255,52568.0820877552,1847.399951696396,4.56777286529541,0.0 -156300,3.8358376,2.5887916,,,,,,,,,,,,,, -156400,3.8146126,2.600931,,,,,,,,,,,,,, -156500,3.9194407,2.5302923,,,,,,,,,,,,,, -156600,3.824545,2.5522554,,,,,,,,,,,,,, -156700,3.7507715,2.5520945,,,,,,,,,,,,,, -156800,3.812228,2.5874407,,,,,,,,,,,,,, -156900,4.0441656,2.591,,,,,,,,,,,,,, -157000,4.3580194,2.6583753,,,,,,,,,,,,,, -157100,3.8665726,2.568111,,,,,,,,,,,,,, -157200,4.1591644,2.6263204,,,,,,,,,,,,,, -157300,4.0046744,2.6078024,,,,,,,,,,,,,, -157400,3.7794683,2.5615382,,,,,,,,,,,,,, -157500,3.5739045,2.5296474,,,,,,,,,,,,,, -157600,3.9763947,2.5910463,,,,,,,,,,,,,, -157700,3.6996953,2.5345035,,,,,,,,,,,,,, -157736,,,0.9242864847183228,0.5116089582443237,0.7549600005149841,1.1921504735946655,50000.0,0.6354000568389893,1.8231844902038568,10000.0,53078.011219501495,54952.83845996857,53078.011219501495,1865.1478443145752,4.623232364654541,0.0 -157800,3.8053334,2.615975,,,,,,,,,,,,,, -157900,4.0582547,2.614287,,,,,,,,,,,,,, -158000,3.9049895,2.637058,,,,,,,,,,,,,, -158100,3.8762372,2.607463,,,,,,,,,,,,,, -158200,3.9030392,2.6064825,,,,,,,,,,,,,, -158300,3.8657856,2.5536218,,,,,,,,,,,,,, -158400,3.8817928,2.5478888,,,,,,,,,,,,,, -158500,3.492782,2.5649786,,,,,,,,,,,,,, -158600,3.746259,2.5829723,,,,,,,,,,,,,, -158700,4.0726275,2.589477,,,,,,,,,,,,,, -158800,4.252314,2.610471,,,,,,,,,,,,,, -158900,4.0526905,2.648048,,,,,,,,,,,,,, -159000,3.797669,2.581402,,,,,,,,,,,,,, -159100,4.0949445,2.61585,,,,,,,,,,,,,, -159200,3.6165028,2.5848408,,,,,,,,,,,,,, -159253,,,0.9227519035339355,0.5066460371017456,0.7560999989509583,1.1838161945343018,50000.0,0.6379000544548035,1.805983066558838,10000.0,53588.11385130882,55481.25697994232,53588.11385130882,1883.357558965683,4.680333614349365,0.0 -159300,3.8678396,2.572038,,,,,,,,,,,,,, -159400,4.0320396,2.579951,,,,,,,,,,,,,, -159500,3.9752147,2.5875928,,,,,,,,,,,,,, -159600,3.7553492,2.5529504,,,,,,,,,,,,,, -159700,3.8455527,2.561417,,,,,,,,,,,,,, -159800,3.7247825,2.5606139,,,,,,,,,,,,,, -159900,3.762817,2.553697,,,,,,,,,,,,,, -160000,4.214921,2.6384037,,,,,,,,,,,,,, -160100,3.8997593,2.5641265,,,,,,,,,,,,,, -160200,4.131313,2.6958659,,,,,,,,,,,,,, -160300,3.8751209,2.5817432,,,,,,,,,,,,,, -160400,4.1002474,2.5920177,,,,,,,,,,,,,, -160500,3.9116027,2.5789766,,,,,,,,,,,,,, -160600,3.9917572,2.6254559,,,,,,,,,,,,,, -160700,4.4523625,2.6348948,,,,,,,,,,,,,, -160772,,,0.9209582209587096,0.5264387726783752,0.7565599679946899,1.198883295059204,50000.0,0.6315000057220459,1.8179877996444704,10000.0,54098.28979110718,56009.091359615326,54098.28979110718,1900.907883644104,4.739240646362305,0.0 -160800,4.062788,2.5937712,,,,,,,,,,,,,, -160900,3.800435,2.580545,,,,,,,,,,,,,, -161000,3.8665528,2.5211537,,,,,,,,,,,,,, -161100,3.7469208,2.566515,,,,,,,,,,,,,, -161200,4.0486507,2.6103663,,,,,,,,,,,,,, -161300,3.9542694,2.60307,,,,,,,,,,,,,, -161400,3.6111138,2.5443847,,,,,,,,,,,,,, -161500,3.822579,2.5340614,,,,,,,,,,,,,, -161600,4.0035505,2.5537763,,,,,,,,,,,,,, -161700,3.806733,2.613994,,,,,,,,,,,,,, -161800,3.9720476,2.5735853,,,,,,,,,,,,,, -161900,4.0750217,2.563557,,,,,,,,,,,,,, -162000,4.07923,2.5836272,,,,,,,,,,,,,, -162100,3.9358947,2.529513,,,,,,,,,,,,,, -162200,3.928958,2.593698,,,,,,,,,,,,,, -162289,,,0.9260801672935486,0.5113489627838135,0.7582199573516846,1.190921425819397,50000.0,0.6361000537872314,1.8157532215118408,10000.0,54608.4274187088,56537.093678474426,54608.4274187088,1918.6634063720703,4.798632621765137,0.0 -162300,3.8982928,2.579312,,,,,,,,,,,,,, -162400,4.10976,2.5951405,,,,,,,,,,,,,, -162500,4.005661,2.5866175,,,,,,,,,,,,,, -162600,3.8286166,2.5273998,,,,,,,,,,,,,, -162700,4.0337696,2.6001735,,,,,,,,,,,,,, -162800,4.261741,2.5650625,,,,,,,,,,,,,, -162900,4.02585,2.570735,,,,,,,,,,,,,, -163000,4.0159345,2.5783186,,,,,,,,,,,,,, -163100,4.0634713,2.5436187,,,,,,,,,,,,,, -163200,3.8508418,2.5308912,,,,,,,,,,,,,, -163300,4.199178,2.543128,,,,,,,,,,,,,, -163400,4.197121,2.619444,,,,,,,,,,,,,, -163500,3.9569588,2.5623507,,,,,,,,,,,,,, -163600,3.9315648,2.5165117,,,,,,,,,,,,,, -163700,3.9504662,2.564075,,,,,,,,,,,,,, -163800,3.8503208,2.5491557,,,,,,,,,,,,,, -163806,,,0.9243263602256776,0.5113154053688049,0.7582599520683289,1.1876296997070312,50000.0,0.6383000016212463,1.8151981830596924,10000.0,55118.39876580238,57065.14196181297,55118.39876580238,1936.598509311676,4.890327453613281,0.0 -163900,3.8828144,2.5231433,,,,,,,,,,,,,, -164000,4.018997,2.5830998,,,,,,,,,,,,,, -164100,4.120202,2.6099327,,,,,,,,,,,,,, -164200,4.32063,2.6027944,,,,,,,,,,,,,, -164300,3.7479446,2.5199165,,,,,,,,,,,,,, -164400,3.9966798,2.5916135,,,,,,,,,,,,,, -164500,3.950153,2.4927244,,,,,,,,,,,,,, -164600,4.1367245,2.5423493,,,,,,,,,,,,,, -164700,3.9128659,2.599139,,,,,,,,,,,,,, -164800,4.143602,2.6242785,,,,,,,,,,,,,, -164900,3.993884,2.5608106,,,,,,,,,,,,,, -165000,3.936447,2.556545,,,,,,,,,,,,,, -165100,4.17286,2.596151,,,,,,,,,,,,,, -165200,4.220912,2.501153,,,,,,,,,,,,,, -165300,4.088166,2.6062684,,,,,,,,,,,,,, -165323,,,0.9312419891357422,0.4898253381252289,0.7599200010299683,1.1844065189361572,50000.0,0.6380000114440918,1.814751029014588,10000.0,55628.49181294441,57593.02518582344,55628.49181294441,1954.2799079418185,4.950265169143677,0.0 -165400,3.9821053,2.5626452,,,,,,,,,,,,,, -165500,4.133637,2.5577216,,,,,,,,,,,,,, -165600,4.164848,2.532356,,,,,,,,,,,,,, -165700,3.9307442,2.5236795,,,,,,,,,,,,,, -165800,4.1601996,2.5493815,,,,,,,,,,,,,, -165900,4.195288,2.5811062,,,,,,,,,,,,,, -166000,4.1149583,2.5896134,,,,,,,,,,,,,, -166100,3.9420464,2.4895756,,,,,,,,,,,,,, -166200,4.167524,2.5728743,,,,,,,,,,,,,, -166300,3.7945802,2.5121086,,,,,,,,,,,,,, -166400,3.92818,2.546522,,,,,,,,,,,,,, -166500,4.0702624,2.550125,,,,,,,,,,,,,, -166600,3.9739864,2.5694804,,,,,,,,,,,,,, -166700,4.3444614,2.5701733,,,,,,,,,,,,,, -166800,4.4433393,2.5635805,,,,,,,,,,,,,, -166840,,,0.9327367544174194,0.4756486117839813,0.7603200078010559,1.176532745361328,50000.0,0.6389000415802002,1.793910026550293,10000.0,56138.64579749107,58120.9536485672,56138.64579749107,1971.945927143097,5.009655952453613,0.0 -166900,4.0841594,2.5352643,,,,,,,,,,,,,, -167000,4.181535,2.5556016,,,,,,,,,,,,,, -167100,4.0854974,2.5442438,,,,,,,,,,,,,, -167200,4.026479,2.5489125,,,,,,,,,,,,,, -167300,4.2511706,2.6103811,,,,,,,,,,,,,, -167400,3.8425891,2.5173566,,,,,,,,,,,,,, -167500,4.2812567,2.4913666,,,,,,,,,,,,,, -167600,4.3158493,2.573004,,,,,,,,,,,,,, -167700,4.1950326,2.5558825,,,,,,,,,,,,,, -167800,4.2136335,2.5841687,,,,,,,,,,,,,, -167900,3.9370751,2.581985,,,,,,,,,,,,,, -168000,4.7877607,2.5841777,,,,,,,,,,,,,, -168100,3.9750097,2.522625,,,,,,,,,,,,,, -168200,3.8544376,2.5650475,,,,,,,,,,,,,, -168300,4.0113544,2.5608041,,,,,,,,,,,,,, -168357,,,0.9333944320678712,0.4795606434345245,0.7605199813842773,1.178816318511963,50000.0,0.6421000361442566,1.8000293970108032,10000.0,56648.71825623512,58648.94144535065,56648.71825623512,1989.7520174980164,5.06985878944397,0.0 -168400,3.9363623,2.51125,,,,,,,,,,,,,, -168500,3.926016,2.546909,,,,,,,,,,,,,, -168600,4.2241173,2.574369,,,,,,,,,,,,,, -168700,3.9134047,2.455518,,,,,,,,,,,,,, -168800,4.006506,2.5235877,,,,,,,,,,,,,, -168900,4.0217447,2.5889573,,,,,,,,,,,,,, -169000,3.9799757,2.5203156,,,,,,,,,,,,,, -169100,4.1734333,2.5397475,,,,,,,,,,,,,, -169200,4.384988,2.5896668,,,,,,,,,,,,,, -169300,3.9552238,2.524857,,,,,,,,,,,,,, -169400,4.212606,2.5306015,,,,,,,,,,,,,, -169500,3.840797,2.4989629,,,,,,,,,,,,,, -169600,4.1165705,2.6021557,,,,,,,,,,,,,, -169700,4.186625,2.5278215,,,,,,,,,,,,,, -169800,4.5251265,2.6394877,,,,,,,,,,,,,, -169874,,,0.9356465339660645,0.4672803878784179,0.7620399594306946,1.1741501092910769,50000.0,0.6429000496864319,1.791094422340393,10000.0,57158.62016701698,59176.86003422737,57158.62016701698,2007.657867193222,5.131593704223633,0.0 -169900,3.681483,2.5441146,,,,,,,,,,,,,, -170000,4.05866,2.560075,,,,,,,,,,,,,, -170100,3.8662035,2.532405,,,,,,,,,,,,,, -170200,4.0882244,2.4854794,,,,,,,,,,,,,, -170300,3.9771068,2.5679128,,,,,,,,,,,,,, -170400,4.112166,2.4961643,,,,,,,,,,,,,, -170500,4.3786864,2.564448,,,,,,,,,,,,,, -170600,4.0657883,2.5796528,,,,,,,,,,,,,, -170700,3.9842634,2.5220993,,,,,,,,,,,,,, -170800,4.3427052,2.5698233,,,,,,,,,,,,,, -170900,4.1908817,2.502882,,,,,,,,,,,,,, -171000,4.6136823,2.5479763,,,,,,,,,,,,,, -171100,4.124865,2.5435543,,,,,,,,,,,,,, -171200,3.9173503,2.5360994,,,,,,,,,,,,,, -171300,4.0668545,2.5015848,,,,,,,,,,,,,, -171391,,,0.9360251426696776,0.4715806841850281,0.7620599865913391,1.177618145942688,50000.0,0.6449000239372253,1.793129324913025,10000.0,57668.82580900192,59704.59183359146,57668.82580900192,2025.074841976165,5.192117214202881,0.0 -171400,3.9937854,2.5107393,,,,,,,,,,,,,, -171500,3.9757621,2.4825625,,,,,,,,,,,,,, -171600,3.8447223,2.473648,,,,,,,,,,,,,, -171700,3.9654598,2.5459595,,,,,,,,,,,,,, -171800,4.2181473,2.536467,,,,,,,,,,,,,, -171900,4.118085,2.5551195,,,,,,,,,,,,,, -172000,3.7425048,2.479811,,,,,,,,,,,,,, -172100,4.236323,2.5257766,,,,,,,,,,,,,, -172200,3.7790093,2.4806113,,,,,,,,,,,,,, -172300,4.113661,2.5860298,,,,,,,,,,,,,, -172400,4.1478314,2.56927,,,,,,,,,,,,,, -172500,4.034888,2.5134623,,,,,,,,,,,,,, -172600,3.9909627,2.522847,,,,,,,,,,,,,, -172700,3.9569383,2.5243971,,,,,,,,,,,,,, -172800,4.067848,2.4764388,,,,,,,,,,,,,, -172900,4.022931,2.4831805,,,,,,,,,,,,,, -172908,,,0.935546875,0.4717677533626556,0.7625799775123596,1.175244688987732,50000.0,0.640500009059906,1.8016676902771,10000.0,58178.83932805061,60232.509991168976,58178.83932805061,2042.8664588928225,5.255893468856812,0.0 -173000,3.9351528,2.5554357,,,,,,,,,,,,,, -173100,4.043646,2.5272758,,,,,,,,,,,,,, -173200,4.132558,2.5467927,,,,,,,,,,,,,, -173300,4.0577297,2.5010545,,,,,,,,,,,,,, -173400,4.161959,2.5930476,,,,,,,,,,,,,, -173500,4.164402,2.5210748,,,,,,,,,,,,,, -173600,4.28649,2.5431178,,,,,,,,,,,,,, -173700,4.1092033,2.5497923,,,,,,,,,,,,,, -173800,4.2699533,2.5424142,,,,,,,,,,,,,, -173900,4.221042,2.527278,,,,,,,,,,,,,, -174000,4.1647425,2.48844,,,,,,,,,,,,,, -174100,3.9503198,2.511745,,,,,,,,,,,,,, -174200,3.9499238,2.5002036,,,,,,,,,,,,,, -174300,4.166147,2.5665705,,,,,,,,,,,,,, -174400,4.5143,2.5702558,,,,,,,,,,,,,, -174424,,,0.9393534660339355,0.4597503542900085,0.7628600001335144,1.1763155460357666,50000.0,0.6444000601768494,1.7961137294769287,10000.0,58688.77640485764,60760.27409243584,58688.77640485764,2060.579292535782,5.320931911468506,0.0 -174500,4.2952476,2.5267866,,,,,,,,,,,,,, -174600,4.1438065,2.5394418,,,,,,,,,,,,,, -174700,4.015923,2.52936,,,,,,,,,,,,,, -174800,4.3212376,2.5301898,,,,,,,,,,,,,, -174900,4.3164673,2.5101206,,,,,,,,,,,,,, -175000,4.244491,2.5743706,,,,,,,,,,,,,, -175100,4.195031,2.550129,,,,,,,,,,,,,, -175200,4.2442126,2.581092,,,,,,,,,,,,,, -175300,4.014283,2.5171776,,,,,,,,,,,,,, -175400,4.300995,2.5525377,,,,,,,,,,,,,, -175500,3.9014294,2.517024,,,,,,,,,,,,,, -175600,4.117464,2.5496182,,,,,,,,,,,,,, -175700,4.320457,2.5768437,,,,,,,,,,,,,, -175800,4.2998605,2.5849686,,,,,,,,,,,,,, -175900,4.380407,2.5707808,,,,,,,,,,,,,, -175941,,,0.939233899116516,0.4566811025142669,0.7627999782562256,1.177257061004639,50000.0,0.6434000134468079,1.7967631816864014,10000.0,59198.87612986565,61288.2229487896,59198.87612986565,2078.3175649642944,5.382373809814453,0.0 -176000,4.0694985,2.4795606,,,,,,,,,,,,,, -176100,4.267848,2.5188873,,,,,,,,,,,,,, -176200,4.212778,2.558085,,,,,,,,,,,,,, -176300,4.125289,2.5236857,,,,,,,,,,,,,, -176400,3.8838255,2.5040083,,,,,,,,,,,,,, -176500,4.1352506,2.5240042,,,,,,,,,,,,,, -176600,3.9865825,2.4778388,,,,,,,,,,,,,, -176700,3.8071437,2.5074828,,,,,,,,,,,,,, -176800,4.0281067,2.5402546,,,,,,,,,,,,,, -176900,4.1881456,2.57853,,,,,,,,,,,,,, -177000,4.1403184,2.546742,,,,,,,,,,,,,, -177100,4.305186,2.4978428,,,,,,,,,,,,,, -177200,4.2729764,2.538172,,,,,,,,,,,,,, -177300,3.9236436,2.5278213,,,,,,,,,,,,,, -177400,4.3078074,2.5408993,,,,,,,,,,,,,, -177458,,,0.9397919178009032,0.4564704298973083,0.7625600099563599,1.1731715202331543,50000.0,0.6452000141143799,1.7910445928573608,10000.0,59708.8821656704,61815.838076114655,59708.8821656704,2095.821103572845,5.4390709400177,0.0 -177500,4.1999116,2.5269685,,,,,,,,,,,,,, -177600,3.94897,2.5118876,,,,,,,,,,,,,, -177700,3.971899,2.5309227,,,,,,,,,,,,,, -177800,4.225193,2.5561366,,,,,,,,,,,,,, -177900,3.9495142,2.5449204,,,,,,,,,,,,,, -178000,3.9928942,2.5352576,,,,,,,,,,,,,, -178100,4.123058,2.5237865,,,,,,,,,,,,,, -178200,4.2093353,2.5287008,,,,,,,,,,,,,, -178300,4.3655543,2.5383172,,,,,,,,,,,,,, -178400,4.194995,2.4727244,,,,,,,,,,,,,, -178500,4.3150387,2.551554,,,,,,,,,,,,,, -178600,3.7717502,2.4589617,,,,,,,,,,,,,, -178700,3.9820788,2.544698,,,,,,,,,,,,,, -178800,4.166389,2.5191195,,,,,,,,,,,,,, -178900,3.9917457,2.515371,,,,,,,,,,,,,, -178976,,,0.9402303695678712,0.4584923982620239,0.7633799910545349,1.1745049953460691,50000.0,0.6419000029563904,1.791907548904419,10000.0,60219.034625291824,62343.77858257294,60219.034625291824,2113.4979746341705,5.501019239425659,0.0 -179000,4.1642466,2.5108817,,,,,,,,,,,,,, -179100,4.05912,2.5582097,,,,,,,,,,,,,, -179200,4.1132994,2.56887,,,,,,,,,,,,,, -179300,3.9449878,2.4892964,,,,,,,,,,,,,, -179400,4.376652,2.5323734,,,,,,,,,,,,,, -179500,4.079877,2.5014384,,,,,,,,,,,,,, -179600,3.9025574,2.4814181,,,,,,,,,,,,,, -179700,4.033746,2.4583175,,,,,,,,,,,,,, -179800,4.0813828,2.5040135,,,,,,,,,,,,,, -179900,4.263307,2.5564756,,,,,,,,,,,,,, -180000,3.7765644,2.4737208,,,,,,,,,,,,,, -180100,4.2069497,2.5230305,,,,,,,,,,,,,, -180200,4.2058997,2.5644913,,,,,,,,,,,,,, -180300,4.2770767,2.558032,,,,,,,,,,,,,, -180400,4.103121,2.4626076,,,,,,,,,,,,,, -180493,,,0.9389349222183228,0.4585332870483398,0.7633999586105347,1.175327181816101,50000.0,0.6457000374794006,1.7949339151382446,10000.0,60729.01041841507,62871.621817588806,60729.01041841507,2131.249956130981,5.567572593688965,0.0 -180500,3.9938574,2.4842088,,,,,,,,,,,,,, -180600,4.3943634,2.5619962,,,,,,,,,,,,,, -180700,4.077052,2.502729,,,,,,,,,,,,,, -180800,4.434006,2.5288239,,,,,,,,,,,,,, -180900,3.9252486,2.54657,,,,,,,,,,,,,, -181000,4.0904574,2.5108438,,,,,,,,,,,,,, -181100,3.9822257,2.5089269,,,,,,,,,,,,,, -181200,4.044971,2.5152,,,,,,,,,,,,,, -181300,4.1255517,2.5124664,,,,,,,,,,,,,, -181400,4.0750165,2.4875028,,,,,,,,,,,,,, -181500,4.166805,2.5621512,,,,,,,,,,,,,, -181600,3.9668686,2.4602113,,,,,,,,,,,,,, -181700,4.4133205,2.5388486,,,,,,,,,,,,,, -181800,4.110458,2.5378351,,,,,,,,,,,,,, -181900,3.8130887,2.4778748,,,,,,,,,,,,,, -182000,4.0664115,2.528366,,,,,,,,,,,,,, -182010,,,0.93949294090271,0.4553861320018768,0.7636399865150452,1.1701064109802246,50000.0,0.6449000239372253,1.789433240890503,10000.0,61239.120810985565,63399.412459373474,61239.120810985565,2148.8214042186737,5.627314567565918,0.0 -182100,3.951418,2.5404978,,,,,,,,,,,,,, -182200,4.2385697,2.5450428,,,,,,,,,,,,,, -182300,3.959441,2.5426304,,,,,,,,,,,,,, -182400,3.9744961,2.503992,,,,,,,,,,,,,, -182500,4.0031757,2.4947028,,,,,,,,,,,,,, -182600,3.998686,2.5147943,,,,,,,,,,,,,, -182700,3.8482158,2.5347528,,,,,,,,,,,,,, -182800,4.521802,2.5602674,,,,,,,,,,,,,, -182900,4.142565,2.5400305,,,,,,,,,,,,,, -183000,4.1040025,2.5309603,,,,,,,,,,,,,, -183100,4.0787835,2.4733617,,,,,,,,,,,,,, -183200,3.986768,2.491179,,,,,,,,,,,,,, -183300,4.23004,2.522276,,,,,,,,,,,,,, -183400,4.059473,2.5360615,,,,,,,,,,,,,, -183500,4.379259,2.482521,,,,,,,,,,,,,, -183527,,,0.9393534660339355,0.4550036489963531,0.7640799880027771,1.1725746393203735,50000.0,0.6444000601768494,1.7934223413467407,10000.0,61749.29934263229,63927.24230384827,61749.29934263229,2166.364020586014,5.686914920806885,0.0 -183600,4.1785645,2.505999,,,,,,,,,,,,,, -183700,3.8663712,2.4531457,,,,,,,,,,,,,, -183800,3.8170857,2.5150468,,,,,,,,,,,,,, -183900,4.1289067,2.5224402,,,,,,,,,,,,,, -184000,3.934199,2.4910107,,,,,,,,,,,,,, -184100,4.2572117,2.5960026,,,,,,,,,,,,,, -184200,4.1805654,2.5462282,,,,,,,,,,,,,, -184300,4.2779903,2.614036,,,,,,,,,,,,,, -184400,4.330847,2.5193691,,,,,,,,,,,,,, -184500,4.1019325,2.5103445,,,,,,,,,,,,,, -184600,4.0211034,2.5069385,,,,,,,,,,,,,, -184700,4.3595552,2.516392,,,,,,,,,,,,,, -184800,3.9374394,2.4592109,,,,,,,,,,,,,, -184900,4.1468787,2.5216503,,,,,,,,,,,,,, -185000,4.4544635,2.5314553,,,,,,,,,,,,,, -185044,,,0.9412667155265808,0.4489424228668213,0.763979971408844,1.1695361137390137,50000.0,0.6449000239372253,1.789334416389465,10000.0,62259.41258692741,64455.15799450874,62259.41258692741,2184.0561985969543,5.7481689453125,0.0 -185100,4.280186,2.5037303,,,,,,,,,,,,,, -185200,3.8952866,2.5013843,,,,,,,,,,,,,, -185300,4.1587896,2.4926417,,,,,,,,,,,,,, -185400,4.32398,2.5150049,,,,,,,,,,,,,, -185500,4.0940995,2.545905,,,,,,,,,,,,,, -185600,3.876801,2.5188851,,,,,,,,,,,,,, -185700,4.159071,2.5193214,,,,,,,,,,,,,, -185800,4.4247637,2.5313015,,,,,,,,,,,,,, -185900,4.196279,2.549603,,,,,,,,,,,,,, -186000,3.9949033,2.5168839,,,,,,,,,,,,,, -186100,3.9723177,2.5456507,,,,,,,,,,,,,, -186200,4.03613,2.48597,,,,,,,,,,,,,, -186300,4.240896,2.5505562,,,,,,,,,,,,,, -186400,4.1572194,2.5187511,,,,,,,,,,,,,, -186500,4.120488,2.5275826,,,,,,,,,,,,,, -186561,,,0.9403699040412904,0.452570378780365,0.7641800045967102,1.1703591346740725,50000.0,0.6455000042915344,1.7891464233398438,10000.0,62769.59384036064,64982.96338915825,62769.59384036064,2201.5659034252167,5.813524484634399,0.0 -186600,3.7103186,2.4688673,,,,,,,,,,,,,, -186666,,,0.9392936825752258,0.4516395628452301,0.7646999955177307,1.167160987854004,50000.0,0.64410001039505,1.787272572517395,10000.0,62804.61339139938,65035.68202996254,62804.61339139938,2219.197128772736,5.877562046051025,0.0 -186666,,,,,,,,,,,62804.61339139938,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 194244fd1..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.657403707504272,0.0,30.54241108894348,1,0,30.54241108894348,0.0006000000284984,6.910250186920166,10000,48.19990372657776,0.0006576849264092,6.909994125366211,0.0009599999757483,6.910243988037109,50000 -35.58681917190552,0.0257160663604736,540.7001566886902,1514,0,540.7001566886902,0.0486000031232833,5.617886543273926,10000,576.3619163036346,0.0730229541659355,5.302392959594727,0.0668999999761581,5.365999221801758,50000 -53.62484097480774,0.0550439357757568,1050.9415996074677,3027,0,1050.9415996074677,0.1161000058054924,4.863234043121338,10000,1104.719381570816,0.1718152016401291,4.267473220825195,0.1560599952936172,4.383291721343994,50000 -71.31632661819458,0.0830860137939453,1560.8828208446505,4539,0,1560.8828208446505,0.1802000105381012,4.247800827026367,10000,1632.4283664226532,0.2708067595958709,3.4995758533477783,0.2490399926900863,3.6413185596466064,50000 -89.48260450363159,0.1137394905090332,2071.0617468357086,6051,0,2071.0617468357086,0.2535000145435333,3.7423603534698486,10000,2160.853071212769,0.3694595098495483,2.88207745552063,0.3414799869060516,3.04559588432312,50000 -108.95680356025696,0.144883394241333,2581.108203172684,7563,0,2581.108203172684,0.2979000210762024,3.458079099655152,10000,2690.453903913498,0.4502351582050323,2.4000375270843506,0.3928200006484985,2.7543821334838867,50000 -126.6402349472046,0.1710314750671386,3091.184275150299,9076,0,3091.184275150299,0.3509000241756439,3.098984479904175,10000,3218.2887029647827,0.4960738122463226,2.1706013679504395,0.4490000009536743,2.427058219909668,50000 -144.45068073272705,0.2024247646331787,3601.261110067368,10590,0,3601.261110067368,0.3641000092029571,3.029717445373535,10000,3746.2567975521088,0.5216438174247742,2.0399041175842285,0.4847599864006042,2.257338762283325,50000 -162.51656198501587,0.2360508441925048,4111.291719198227,12104,0,4111.291719198227,0.3975000083446502,2.803204298019409,10000,4274.436697483063,0.5573580861091614,1.8680070638656616,0.5182600021362305,2.074094533920288,50000 -180.25822043418884,0.2662384510040283,4621.447257757187,13620,0,4621.447257757187,0.4131000339984894,2.741434097290039,10000,4802.413382053375,0.5631775856018066,1.82098388671875,0.5280199646949768,2.02920913696289,50000 -198.1563205718994,0.3011949062347412,5131.6218910217285,15136,0,5131.6218910217285,0.4170000255107879,2.7061638832092285,10000,5330.570430994034,0.5808952450752258,1.7555890083312988,0.5410000085830688,1.951683759689331,50000 -216.2797749042511,0.3344736099243164,5641.749358415604,16652,0,5641.749358415604,0.4243000149726867,2.689218759536743,10000,5858.903384923935,0.6128427982330322,1.5677685737609863,0.5491600036621094,1.9117090702056885,50000 -234.1956398487091,0.3683803081512451,6151.70788192749,18168,0,6151.70788192749,0.4319000244140625,2.617641925811768,10000,6386.860626220703,0.6028379797935486,1.6164047718048096,0.5559599995613098,1.8733510971069336,50000 -252.2565195560456,0.4041624069213867,6661.777126550674,19684,0,6661.777126550674,0.4443000257015228,2.5521981716156006,10000,6915.075645208359,0.6129025816917419,1.5699564218521118,0.5642799735069275,1.820169448852539,50000 -270.24809288978577,0.4388735294342041,7172.218660593033,21201,0,7172.218660593033,0.4500000178813934,2.5406386852264404,10000,7443.592651605606,0.6129623651504517,1.5876972675323486,0.5711599588394165,1.807672500610352,50000 -287.95379066467285,0.4722630977630615,7682.367743968964,22718,0,7682.367743968964,0.4487000107765198,2.5585756301879883,10000,7971.530035734177,0.6120854616165161,1.585247039794922,0.5758599638938904,1.7821930646896362,50000 -305.84978795051575,0.5091888904571533,8192.398866891861,24235,0,8192.398866891861,0.4481000304222107,2.5390639305114746,10000,8499.543118476868,0.6190210580825806,1.554908275604248,0.5786600112915039,1.7810800075531006,50000 -323.8264684677124,0.5421411991119385,8702.434856891632,25753,0,8702.434856891632,0.4615000188350677,2.4996163845062256,10000,9027.637979269028,0.644551157951355,1.4304097890853882,0.5829799771308899,1.7570055723190308,50000 -341.5526399612427,0.5781388282775879,9212.556456327438,27271,0,9212.556456327438,0.4508000314235687,2.5254218578338623,10000,9555.571353673937,0.62890625,1.5057846307754517,0.5762799978256226,1.7802006006240845,50000 -359.44764280319214,0.6132323741912842,9722.534570932388,28788,0,9722.534570932388,0.4659000337123871,2.4030990600585938,10000,10083.528599262238,0.6389309763908386,1.4611449241638184,0.5892400145530701,1.712038516998291,50000 -377.1987464427948,0.6481332778930664,10232.679866552353,30306,0,10232.679866552353,0.4637000262737274,2.4629597663879395,10000,10611.508907079697,0.6354631781578064,1.4758671522140503,0.5902199745178223,1.7170330286026,50000 -395.0763454437256,0.6836209297180176,10742.683179616928,31824,0,10742.683179616928,0.4649000167846679,2.449855327606201,10000,11139.474148750303,0.6270925998687744,1.5033053159713743,0.5899400115013123,1.7156438827514648,50000 -413.0981290340424,0.716942310333252,11252.691735982897,33342,0,11252.691735982897,0.4624000191688537,2.4317657947540283,10000,11667.586690425873,0.6690449714660645,1.3072504997253418,0.5861600041389465,1.723919153213501,50000 -431.0991203784943,0.7553849220275879,11762.682977676392,34860,0,11762.682977676392,0.4726000130176544,2.405595541000366,10000,12195.666572332382,0.6469627022743225,1.3967673778533936,0.5877000093460083,1.7086846828460691,50000 -449.1371431350708,0.7958519458770752,12272.765510082245,36377,0,12272.765510082245,0.4793000221252441,2.3845810890197754,10000,12723.876703977585,0.6559311151504517,1.3725244998931885,0.6061399579048157,1.6266961097717283,50000 -466.961225271225,0.8347640037536621,12782.802032232285,37896,0,12782.802032232285,0.4638000130653381,2.486398220062256,10000,13251.825589179993,0.6361208558082581,1.4662803411483765,0.5847199559211731,1.7373977899551392,50000 -484.7647216320038,0.873910665512085,13292.97028517723,39414,0,13292.97028517723,0.48130002617836,2.3708035945892334,10000,13779.885905742643,0.6498525142669678,1.4036788940429688,0.6021599769592285,1.6521224975585938,50000 -502.8075284957886,0.9099681377410888,13803.110150814056,40933,0,13803.110150814056,0.4783000349998474,2.3724279403686523,10000,14308.154326200483,0.6477798223495483,1.4108294248580933,0.6057999730110168,1.6259907484054563,50000 -520.8403308391571,0.9460337162017822,14313.330046653748,42453,0,14313.330046653748,0.4844000339508056,2.3288164138793945,10000,14836.49200987816,0.6938576102256775,1.2015491724014282,0.6082800030708313,1.6163575649261477,50000 -538.9052393436432,0.9865646362304688,14823.266336917875,43972,0,14823.266336917875,0.4779000282287597,2.388891696929932,10000,15364.582607030869,0.6569873690605164,1.3638534545898438,0.5990399718284607,1.6688741445541382,50000 -556.6109170913696,1.0245742797851562,15333.213256835938,45491,0,15333.213256835938,0.4779000282287597,2.395018577575684,10000,15892.32246518135,0.6528021097183228,1.3831080198287964,0.6025800108909607,1.646975874900818,50000 -575.4105768203735,1.063220739364624,15843.245793819427,47009,0,15843.245793819427,0.4746000170707702,2.371513843536377,10000,16421.24270606041,0.6518455147743225,1.3824917078018188,0.6039599776268005,1.638573408126831,50000 -593.1785054206848,1.1000022888183594,16353.201929330826,48528,0,16353.201929330826,0.4837000370025635,2.336787700653076,10000,16949.052980184555,0.6499919891357422,1.4085599184036257,0.602620005607605,1.6432632207870483,50000 -611.1795015335083,1.139614820480347,16863.183659791946,50046,0,16863.183659791946,0.4767000079154968,2.3756377696990967,10000,17477.124361991882,0.6486168503761292,1.416173219680786,0.6020399928092957,1.6556391716003418,50000 -629.0958392620087,1.1804418563842771,17373.11920595169,51565,0,17373.11920595169,0.4915000200271606,2.3638558387756348,10000,18005.066600322723,0.6805644035339355,1.269840955734253,0.6078000068664551,1.6356245279312134,50000 -646.7576491832733,1.2222459316253662,17883.04052567482,53084,0,17883.04052567482,0.4881000220775604,2.3278653621673584,10000,18532.7405500412,0.6640226244926453,1.340054988861084,0.6099599599838257,1.6146225929260254,50000 -664.7858171463013,1.262617111206055,18393.22454881668,54603,0,18393.22454881668,0.496500015258789,2.257696866989136,10000,19061.04249453545,0.6684470772743225,1.310733675956726,0.6151599884033203,1.5864496231079102,50000 -682.5275778770447,1.305849313735962,18903.22807765007,56122,0,18903.22807765007,0.4861000180244446,2.338015079498291,10000,19588.880088090897,0.6581233739852905,1.3662872314453125,0.6144399642944336,1.60483980178833,50000 -700.529173374176,1.3492977619171145,19413.14436841011,57640,0,19413.14436841011,0.4958000183105469,2.2923803329467773,10000,20116.891170024872,0.6615911722183228,1.3483868837356567,0.615559995174408,1.5824005603790283,50000 -718.4421739578247,1.3909647464752195,19923.271131277084,59161,0,19923.271131277084,0.4762000143527984,2.395241498947144,10000,20645.021767616272,0.6617307066917419,1.3514286279678345,0.6044600009918213,1.642971396446228,50000 -736.025773525238,1.4345412254333496,20433.406439065933,60682,0,20433.406439065933,0.489300012588501,2.322281837463379,10000,21172.83334803581,0.6777144074440002,1.2699861526489258,0.6165800094604492,1.5854922533035278,50000 -753.8541922569275,1.4773738384246826,20943.60581111908,62201,0,20943.60581111908,0.4894000291824341,2.293773651123047,10000,21700.95312690735,0.6781728267669678,1.2760369777679443,0.6189999580383301,1.5650577545166016,50000 -771.4942946434021,1.524343490600586,21453.79224085808,63721,0,21453.79224085808,0.4957000315189361,2.286545991897583,10000,22228.8758354187,0.6701809763908386,1.300941824913025,0.6164000034332275,1.5907514095306396,50000 -789.372960805893,1.5687105655670166,21963.94081664085,65241,0,21963.94081664085,0.4811000227928161,2.425404787063598,10000,22756.996761083603,0.6570671200752258,1.3706705570220947,0.6152200102806091,1.615429162979126,50000 -807.3948268890381,1.6115176677703855,22474.06643724441,66761,0,22474.06643724441,0.4582000076770782,2.52813458442688,10000,23285.236786603928,0.6267139315605164,1.5229088068008425,0.5832399725914001,1.7533199787139893,50000 -825.0936350822449,1.6541671752929688,22984.056749343872,68281,0,22984.056749343872,0.5125000476837158,2.1850202083587646,10000,23813.01803874969,0.7253667116165161,1.0721691846847534,0.6342999935150146,1.5084197521209717,50000 -842.8929722309113,1.6980137825012207,23493.977559566498,69799,0,23493.977559566498,0.4927000105381012,2.3095078468322754,10000,24340.83118534088,0.6788305044174194,1.2620773315429688,0.6158999800682068,1.5898820161819458,50000 -860.7975602149963,1.7430338859558103,24004.088791370392,71319,0,24004.088791370392,0.5060000419616699,2.2284297943115234,10000,24868.941687583923,0.6802256107330322,1.253443956375122,0.6263799667358398,1.537427306175232,50000 -878.8791308403015,1.7854652404785156,24514.16850876808,72838,0,24514.16850876808,0.5053000450134277,2.216010332107544,10000,25397.194969654083,0.6843112111091614,1.24107563495636,0.6298199892044067,1.5166643857955933,50000 -896.9066410064697,1.8266286849975584,25024.354562044144,74359,0,25024.354562044144,0.5045000314712524,2.2394707202911377,10000,25925.49914598465,0.6735291481018066,1.2930454015731812,0.6256600022315979,1.5487301349639893,50000 -914.560183763504,1.8696751594543457,25534.386019468307,75878,0,25534.386019468307,0.5094000101089478,2.201269149780273,10000,26453.276960134503,0.6875996589660645,1.2251676321029663,0.6364799737930298,1.4780266284942627,50000 -932.3965713977814,1.9160151481628416,26044.426952838898,77398,0,26044.426952838898,0.4964000284671783,2.239720582962036,10000,26981.249837636948,0.7134685516357422,1.1057634353637695,0.6320599913597107,1.5262938737869265,50000 -950.2284531593324,1.964091300964356,26554.584349632263,78918,0,26554.584349632263,0.511900007724762,2.2021243572235107,10000,27509.33651995659,0.7004145383834839,1.1687120199203491,0.6345599889755249,1.4978222846984863,50000 -968.1738333702089,2.00515365600586,27064.51063919068,80438,0,27064.51063919068,0.4999000132083893,2.279034376144409,10000,28037.29874444008,0.6895527839660645,1.2138910293579102,0.6326000094413757,1.5095300674438477,50000 -986.3525202274324,2.056394100189209,27574.732538461685,81958,0,27574.732538461685,0.5067999958992004,2.243538379669189,10000,28565.79950976372,0.6863241195678711,1.2330068349838257,0.6323399543762207,1.5172523260116575,50000 -1004.3920419216156,2.1011264324188232,28084.9185359478,83478,0,28084.9185359478,0.5056000351905823,2.229276418685913,10000,29094.119473934174,0.6814213991165161,1.2530291080474854,0.6282199621200562,1.5206059217453003,50000 -1022.558931350708,2.1451663970947266,28594.97287845612,84998,0,28594.97287845612,0.5214000344276428,2.1729736328125,10000,29622.433605909348,0.7023875713348389,1.1722116470336914,0.6421599984169006,1.463862419128418,50000 -1040.2748339176178,2.191434383392334,29105.079597473145,86518,0,29105.079597473145,0.5195000171661377,2.1262402534484863,10000,30150.35160303116,0.7252471446990967,1.0614542961120603,0.6484799981117249,1.429309368133545,50000 -1058.1022355556488,2.241151094436645,29615.25696110725,88037,0,29615.25696110725,0.518500030040741,2.1785125732421875,10000,30678.455330610275,0.7028061151504517,1.1444787979125977,0.6430799961090088,1.4514235258102417,50000 -1076.1300678253174,2.2871439456939697,30125.36184811592,89557,0,30125.36184811592,0.5265000462532043,2.140740156173706,10000,31206.683420419693,0.7121930718421936,1.1185708045959473,0.6528599858283997,1.4161499738693235,50000 -1093.9589052200315,2.341646194458008,30635.47507429123,91078,0,30635.47507429123,0.5210000276565552,2.142979145050049,10000,31734.72960019112,0.7091438174247742,1.1337120532989502,0.6495800018310547,1.4375072717666626,50000 -1111.7064609527588,2.390122890472412,31145.39885497093,92598,0,31145.39885497093,0.5238000154495239,2.1197195053100586,10000,32262.498690605164,0.7068518400192261,1.1458536386489868,0.6547399759292603,1.4125068187713623,50000 -1129.5105464458466,2.4372799396514893,31655.37549734116,94119,0,31655.37549734116,0.5188000202178955,2.1304824352264404,10000,32790.37596988678,0.7403738498687744,1.0028444528579712,0.6489399671554565,1.430841326713562,50000 -1147.253174781799,2.4813926219940186,32165.38191127777,95639,0,32165.38191127777,0.5311000347137451,2.102889060974121,10000,33318.218418598175,0.7263033986091614,1.055260419845581,0.6544399857521057,1.4071980714797974,50000 -1165.2235417366028,2.5274858474731445,32675.390652418137,97159,0,32675.390652418137,0.5247000455856323,2.1374266147613525,10000,33846.29317569733,0.7153021097183228,1.094274640083313,0.647059977054596,1.4412496089935305,50000 -1183.1489372253418,2.5802924633026123,33185.30989527702,98679,0,33185.30989527702,0.5229000449180603,2.1488847732543945,10000,34374.24014735222,0.7119140625,1.1197112798690796,0.6540799736976624,1.4109952449798584,50000 -1200.9140849113464,2.625534772872925,33695.51848888397,100199,0,33695.51848888397,0.5344000458717346,2.116262435913086,10000,34902.30881524086,0.7194873690605164,1.0825355052947998,0.6628199815750122,1.372917652130127,50000 -1218.8947851657867,2.6776158809661865,34205.558450460434,101718,0,34205.558450460434,0.5348000526428223,2.0898184776306152,10000,35430.43071985245,0.7230349183082581,1.0770888328552246,0.6646599769592285,1.3627052307128906,50000 -1236.8908894062042,2.725346565246582,34715.77932167053,103239,0,34715.77932167053,0.537600040435791,2.0718114376068115,10000,35958.745235681534,0.7520726919174194,0.94245183467865,0.662559986114502,1.3745607137680054,50000 -1254.5282595157623,2.777066469192505,35225.935331106186,104760,0,35225.935331106186,0.541700005531311,2.064157247543335,10000,36486.63972115517,0.7437419891357422,0.9715170860290528,0.6689199805259705,1.338512659072876,50000 -1272.1391229629517,2.83117151260376,35736.099576711655,106281,0,35736.099576711655,0.5430000424385071,2.052748441696167,10000,37014.51857161522,0.7352120280265808,1.0203943252563477,0.6659799814224243,1.3560991287231443,50000 -1290.1619033813477,2.8791487216949463,36246.10291814804,107801,0,36246.10291814804,0.5539000034332275,1.9979313611984253,10000,37542.64195537567,0.7411311864852905,0.9817953109741212,0.675279974937439,1.3169835805892944,50000 -1308.0986626148224,2.928966760635376,36756.17011857033,109321,0,36756.17011857033,0.5401000380516052,2.053678274154663,10000,38070.74506354332,0.7330197691917419,1.0413142442703247,0.670960009098053,1.3357013463974,50000 -1325.7961132526398,2.9773285388946533,37266.25584864616,110842,0,37266.25584864616,0.5503000020980835,2.027350187301636,10000,38598.62584900856,0.7404536008834839,0.9842280745506288,0.675819993019104,1.313747763633728,50000 -1343.4937970638275,3.02756404876709,37776.244634628296,112363,0,37776.244634628296,0.5544000267982483,1.9689432382583616,10000,39126.41235136986,0.7700294852256775,0.8577762246131897,0.6807999610900879,1.2912861108779907,50000 -1361.454597234726,3.0750892162323,38286.34936618805,113883,0,38286.34936618805,0.542900025844574,2.041203260421753,10000,39654.57463693619,0.7518534660339355,0.9311055541038512,0.6769199967384338,1.313949704170227,50000 -1379.3148682117462,3.123174905776977,38796.49595832825,115404,0,38796.49595832825,0.5494000315666199,1.9755940437316888,10000,40182.67863607407,0.7578722834587097,0.92294579744339,0.6810199618339539,1.289506435394287,50000 -1397.2055568695068,3.1755475997924805,39306.57546567917,116925,0,39306.57546567917,0.563800036907196,1.953534245491028,10000,40710.75027012825,0.7563177347183228,0.92183256149292,0.6868000030517578,1.2699737548828125,50000 -1414.9953515529633,3.2276132106781006,39816.80427622795,118447,0,39816.80427622795,0.5552999973297119,1.979101061820984,10000,41238.8702480793,0.7603236436843872,0.9137925505638124,0.6899399757385254,1.2629296779632568,50000 -1432.9292194843292,3.300215721130371,40326.80095338821,119967,0,40326.80095338821,0.5565000176429749,1.9592362642288208,10000,41766.92306900024,0.7940250039100647,0.7722108364105225,0.6904199719429016,1.25036358833313,50000 -1450.694310426712,3.352198362350464,40836.91822504997,121487,0,40836.91822504997,0.5522000193595886,1.993802189826965,10000,42294.90674686432,0.7747727632522583,0.8482697010040283,0.6867799758911133,1.2761807441711426,50000 -1468.680810213089,3.4059770107269287,41346.84619688988,123007,0,41346.84619688988,0.5526000261306763,2.0052297115325928,10000,42822.92417263985,0.762137234210968,0.8999484181404114,0.6813399791717529,1.2865992784500122,50000 -1486.8664498329165,3.4564247131347656,41856.85910201073,124527,0,41856.85910201073,0.5652000308036804,1.933935284614563,10000,43351.22230386734,0.7736367583274841,0.8483324646949768,0.692579984664917,1.2330009937286377,50000 -1504.823757648468,3.5072178840637207,42366.83264923096,126047,0,42366.83264923096,0.570900022983551,1.8856645822525024,10000,43879.25383400917,0.7742147445678711,0.844782292842865,0.6972999572753906,1.2068699598312378,50000 -1522.5617182254791,3.56189227104187,42876.88329672813,127567,0,42876.88329672813,0.5746000409126282,1.928336262702942,10000,44407.14696741104,0.7678172588348389,0.8690696358680725,0.6976400017738342,1.232379674911499,50000 -1540.3736391067505,3.616084575653076,43387.10359358788,129088,0,43387.10359358788,0.5788000226020813,1.8793182373046875,10000,44935.282566308975,0.8156289458274841,0.6751767992973328,0.7046799659729004,1.1848829984664917,50000 -1558.0184445381165,3.6664631366729736,43897.194222450256,130608,0,43897.194222450256,0.5804000496864319,1.8776094913482664,10000,45463.11754608154,0.7964963316917419,0.7662532925605774,0.7048999667167664,1.1896097660064695,50000 -1575.7394952774048,3.723146915435791,44407.26432728768,132129,0,44407.26432728768,0.579800009727478,1.8872145414352417,10000,45991.0145611763,0.7948620915412903,0.7571321725845337,0.7081599831581116,1.174660325050354,50000 -1593.8583455085754,3.777401447296143,44917.38735723496,133651,0,44917.38735723496,0.5773000121116638,1.888568878173828,10000,46519.36021995544,0.7874082922935486,0.7908145189285278,0.7020399570465088,1.193451166152954,50000 -1611.7913794517517,3.833778142929077,45427.400640010834,135172,0,45427.400640010834,0.5895000100135803,1.8212326765060425,10000,47047.412162303925,0.8009805083274841,0.7329120635986328,0.7127000093460083,1.1449772119522097,50000 -1629.5676186084747,3.88659930229187,45937.60295057297,136692,0,45937.60295057297,0.5819000005722046,1.8573424816131592,10000,47575.49288535118,0.7917529940605164,0.7613667845726013,0.7073599696159363,1.1776471138000488,50000 -1647.4524323940277,3.945059061050415,46447.6070895195,138213,0,46447.6070895195,0.589900016784668,1.820005178451538,10000,48103.48945856094,0.8251953125,0.6313984990119934,0.7180399894714355,1.1386464834213257,50000 -1665.1089329719543,4.003594875335693,46957.71826648712,139735,0,46957.71826648712,0.5910000205039978,1.8218023777008057,10000,48631.36492419243,0.8220463991165161,0.6473369002342224,0.7188199758529663,1.1304951906204224,50000 -1683.0212621688845,4.060739040374756,47467.91790008545,141255,0,47467.91790008545,0.597100019454956,1.824455499649048,10000,49159.58340501785,0.8222058415412903,0.6530354619026184,0.7206199765205383,1.1245300769805908,50000 -1700.9479422569275,4.118076086044312,47978.107671022415,142775,0,47978.107671022415,0.5913000106811523,1.8181012868881223,10000,49687.80672287941,0.8192163705825806,0.6568843126296997,0.722819983959198,1.1111185550689695,50000 -1718.752968788147,4.174353361129761,48488.26706719399,144296,0,48488.26706719399,0.5961000323295593,1.785703420639038,10000,50215.87666511536,0.8274872303009033,0.6358198523521423,0.7251600027084351,1.0935237407684326,50000 -1736.8405735492706,4.234951734542847,48998.32432341576,145817,0,48998.32432341576,0.6053000092506409,1.75275981426239,10000,50744.13141441345,0.8562061190605164,0.5220614075660706,0.7297799587249756,1.0798618793487549,50000 -1754.6182186603546,4.2909040451049805,49508.459768772125,147337,0,49508.459768772125,0.604200005531311,1.7619547843933103,10000,51272.14950942993,0.8483338356018066,0.5386354327201843,0.7303999662399292,1.0842643976211548,50000 -1772.192732810974,4.347080707550049,50018.51382923126,148857,0,50018.51382923126,0.6065000295639038,1.7632725238800049,10000,51799.88316822052,0.8486328125,0.5442114472389221,0.7340599894523621,1.065670132637024,50000 -1789.946456670761,4.403222799301148,50528.65114068985,150377,0,50528.65114068985,0.6093000173568726,1.749652624130249,10000,52327.87962150574,0.8487922549247742,0.5424256324768066,0.7318800091743469,1.0758512020111084,50000 -1807.7406420707705,4.458520889282227,51038.67733478546,151898,0,51038.67733478546,0.6142000555992126,1.7403059005737305,10000,52855.8038623333,0.8529974222183228,0.52959805727005,0.7369999885559082,1.0513129234313965,50000 -1825.565360069275,4.516251564025879,51548.73529744148,153418,0,51548.73529744148,0.609000027179718,1.75309419631958,10000,53383.79339146614,0.8550103306770325,0.5078051090240479,0.7404800057411194,1.0449867248535156,50000 -1843.183394432068,4.576179504394531,52058.83040165901,154938,0,52058.83040165901,0.617400050163269,1.7246414422988892,10000,53911.61543941498,0.8839883208274841,0.4073780179023742,0.7401799559593201,1.042005181312561,50000 -1860.8247520923608,4.636116027832031,52568.99254751205,156460,0,52568.99254751205,0.619100034236908,1.7108631134033203,10000,54439.52820849419,0.8816167116165161,0.4281372129917145,0.7430399656295776,1.0246366262435913,50000 -1878.545128583908,4.690703630447388,53079.10682106018,157980,0,53079.10682106018,0.6183000206947327,1.7170692682266235,10000,54967.46635508537,0.8804408311843872,0.4245630204677582,0.7468999624252319,1.0233200788497925,50000 -1896.2643899917605,4.749074697494507,53589.45731592178,159501,0,53589.45731592178,0.6223000288009644,1.720950484275818,10000,55495.64396739006,0.8826530575752258,0.4148502647876739,0.7451199889183044,1.0272034406661987,50000 -1914.0802400112152,4.812868118286133,54099.56357502937,161022,0,54099.56357502937,0.6247000098228455,1.695212960243225,10000,56023.67906188965,0.8878945708274841,0.3984717428684234,0.7486799955368042,1.0120301246643066,50000 -1932.141437768936,4.868780136108398,54609.65087771416,162541,0,54609.65087771416,0.6285000443458557,1.6943116188049316,10000,56551.9337553978,0.8908840417861938,0.3822648227214813,0.751259982585907,1.008744239807129,50000 -1949.9157021045685,4.9245476722717285,55119.58855962753,164062,0,55119.58855962753,0.6237000226974487,1.7114315032958984,10000,57079.75123023987,0.9062101244926452,0.3281794488430023,0.7523199915885925,1.0044119358062744,50000 -1967.455517053604,4.982961177825928,55629.49757552147,165582,0,55629.49757552147,0.6300000548362732,1.6897133588790894,10000,57607.30726504326,0.9041573405265808,0.3352792263031006,0.7546600103378296,0.99237322807312,50000 -1985.1188147068024,5.042438745498657,56139.54078269005,167102,0,56139.54078269005,0.6339000463485718,1.6737549304962158,10000,58135.12322330475,0.9091398119926452,0.3203354179859161,0.7573999762535095,0.987073004245758,50000 -2002.999568939209,5.100426197052002,56649.45196771622,168623,0,56649.45196771622,0.6332000494003296,1.6756737232208252,10000,58663.02194476128,0.9107740521430968,0.3188508749008178,0.7586399912834167,0.9830759763717652,50000 -2020.919572353363,5.164439678192139,57159.36017847061,170143,0,57159.36017847061,0.6368000507354736,1.6767271757125854,10000,59190.96336650848,0.9118303656578064,0.3059313297271728,0.7596799731254578,0.977874517440796,50000 -2038.8429753780365,5.226194143295288,57669.37860417366,171663,0,57669.37860417366,0.6338000297546387,1.6702979803085327,10000,59719.01620292664,0.9235889315605164,0.2766060531139374,0.7594000101089478,0.9787272214889526,50000 -2056.581691026688,5.2877233028411865,58179.60224986076,173183,0,58179.60224986076,0.6374000310897827,1.6677082777023315,10000,60247.089400053024,0.9255819320678712,0.2669574022293091,0.7631999850273132,0.9670901298522948,50000 -2074.306261062622,5.346015691757202,58689.71545791626,174703,0,58689.71545791626,0.6367000341415405,1.670340895652771,10000,60775.03505182266,0.927754282951355,0.2576439678668976,0.7628799676895142,0.9702308177947998,50000 -2092.1598665714264,5.409138917922974,59199.76896715164,176222,0,59199.76896715164,0.6407000422477722,1.6606649160385132,10000,61303.05489516258,0.9254623651504515,0.265313446521759,0.7633199691772461,0.9638875722885132,50000 -2110.070912361145,5.473035573959351,59709.75693559647,177741,0,59709.75693559647,0.6406000256538391,1.659266233444214,10000,61831.0673494339,0.9296875,0.254580557346344,0.7645799517631531,0.9602543711662292,50000 -2127.696268796921,5.533898830413818,60219.763149023056,179261,0,60219.763149023056,0.6409000158309937,1.654666543006897,10000,62358.80902075768,0.9309031963348388,0.2519624531269073,0.7645399570465088,0.9593108296394348,50000 -2145.725523948669,5.593425512313843,60729.83857727051,180781,0,60729.83857727051,0.6391000151634216,1.6557667255401611,10000,62887.02244234085,0.9323779940605164,0.244142547249794,0.76419997215271,0.958231508731842,50000 -2163.610629081726,5.656862735748291,61240.16095900536,182301,0,61240.16095900536,0.6412000060081482,1.6545675992965698,10000,63415.34263134003,0.934351086616516,0.239544078707695,0.7648999691009521,0.956494688987732,50000 -2181.3248648643494,5.717596530914307,61750.23178052902,183820,0,61750.23178052902,0.6407000422477722,1.6530959606170654,10000,63943.23792815208,0.9310028553009032,0.2465308010578155,0.7655199766159058,0.9564112424850464,50000 -2199.1994581222534,5.779508590698242,62260.25123000145,185340,0,62260.25123000145,0.6410000324249268,1.6519469022750854,10000,64471.24368929863,0.9342713356018066,0.2396325170993805,0.765779972076416,0.95612633228302,50000 -2216.782294511795,5.841466188430786,62705.71089506149,186666,0,62705.71089506149,0.640500009059906,1.6542025804519653,10000,64934.39106178284,0.9329161047935486,0.24513345956802368,0.7657399773597717,0.9568769335746765,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/measurements.csv deleted file mode 100644 index 39ae7847d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1993 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.66344464,6.9194984,,,,,,,,,,,,,, -1,,,0.0006576849264092,6.909994125366211,0.0009599999757483,6.910243988037109,50000.0,0.0006000000284984,6.910250186920166,10000.0,30.54241108894348,48.19990372657776,30.54241108894348,17.657403707504272,0.0,0.0 -100,0.64229107,6.9012103,,,,,,,,,,,,,, -200,0.6539818,6.859667,,,,,,,,,,,,,, -300,0.7117052,6.77052,,,,,,,,,,,,,, -400,0.74145323,6.650942,,,,,,,,,,,,,, -500,0.79783905,6.528565,,,,,,,,,,,,,, -600,0.85097677,6.475006,,,,,,,,,,,,,, -700,0.9111472,6.316205,,,,,,,,,,,,,, -800,0.99889374,6.1250863,,,,,,,,,,,,,, -900,1.4699045,6.0376167,,,,,,,,,,,,,, -1000,1.9983644,6.005973,,,,,,,,,,,,,, -1100,1.532888,5.875403,,,,,,,,,,,,,, -1200,2.2010064,5.830358,,,,,,,,,,,,,, -1300,2.9731917,5.657202,,,,,,,,,,,,,, -1400,3.9692576,5.5699625,,,,,,,,,,,,,, -1500,4.2432356,5.518051,,,,,,,,,,,,,, -1514,,,0.0730229541659355,5.302392959594727,0.0668999999761581,5.365999221801758,50000.0,0.0486000031232833,5.617886543273926,10000.0,540.7001566886902,576.3619163036346,540.7001566886902,35.58681917190552,0.0257160663604736,0.0 -1600,2.8372788,5.4453554,,,,,,,,,,,,,, -1700,3.909633,5.373728,,,,,,,,,,,,,, -1800,3.9527872,5.2500787,,,,,,,,,,,,,, -1900,4.5248384,5.198055,,,,,,,,,,,,,, -2000,6.239444,5.0628057,,,,,,,,,,,,,, -2100,5.420957,5.026223,,,,,,,,,,,,,, -2200,6.896294,4.9780016,,,,,,,,,,,,,, -2300,5.73369,5.00435,,,,,,,,,,,,,, -2400,4.3841996,4.8369412,,,,,,,,,,,,,, -2500,4.5742083,4.837573,,,,,,,,,,,,,, -2600,7.4229875,4.7798166,,,,,,,,,,,,,, -2700,5.3491616,4.771052,,,,,,,,,,,,,, -2800,5.2802505,4.703374,,,,,,,,,,,,,, -2900,5.701994,4.6690054,,,,,,,,,,,,,, -3000,5.478309,4.457527,,,,,,,,,,,,,, -3027,,,0.1718152016401291,4.267473220825195,0.1560599952936172,4.383291721343994,50000.0,0.1161000058054924,4.863234043121338,10000.0,1050.9415996074677,1104.719381570816,1050.9415996074677,53.62484097480774,0.0550439357757568,0.0 -3100,4.7477403,4.513783,,,,,,,,,,,,,, -3200,9.5809765,4.415099,,,,,,,,,,,,,, -3300,5.848755,4.471198,,,,,,,,,,,,,, -3400,4.986415,4.4330096,,,,,,,,,,,,,, -3500,7.357249,4.259865,,,,,,,,,,,,,, -3600,4.785697,4.229288,,,,,,,,,,,,,, -3700,7.3319116,4.2278323,,,,,,,,,,,,,, -3800,5.8523026,4.0617,,,,,,,,,,,,,, -3900,8.050775,4.0989156,,,,,,,,,,,,,, -4000,5.0890193,4.1064305,,,,,,,,,,,,,, -4100,4.7472353,4.087098,,,,,,,,,,,,,, -4200,9.39798,3.994732,,,,,,,,,,,,,, -4300,7.2358723,4.0012665,,,,,,,,,,,,,, -4400,5.433345,3.942111,,,,,,,,,,,,,, -4500,10.488512,3.8341753,,,,,,,,,,,,,, -4539,,,0.2708067595958709,3.4995758533477783,0.2490399926900863,3.6413185596466064,50000.0,0.1802000105381012,4.247800827026367,10000.0,1560.8828208446505,1632.4283664226532,1560.8828208446505,71.31632661819458,0.0830860137939453,0.0 -4600,6.1557455,3.7647352,,,,,,,,,,,,,, -4700,6.2713327,3.7976654,,,,,,,,,,,,,, -4800,10.177706,3.6958346,,,,,,,,,,,,,, -4900,6.725826,3.725985,,,,,,,,,,,,,, -5000,6.042552,3.6257672,,,,,,,,,,,,,, -5100,13.520611,3.5595753,,,,,,,,,,,,,, -5200,8.716533,3.5516338,,,,,,,,,,,,,, -5300,5.9864554,3.6616845,,,,,,,,,,,,,, -5400,10.541709,3.5494719,,,,,,,,,,,,,, -5500,11.52424,3.4691749,,,,,,,,,,,,,, -5600,11.421967,3.527454,,,,,,,,,,,,,, -5700,8.050197,3.404698,,,,,,,,,,,,,, -5800,7.4512124,3.4675577,,,,,,,,,,,,,, -5900,6.5606184,3.3519764,,,,,,,,,,,,,, -6000,9.858003,3.2797592,,,,,,,,,,,,,, -6051,,,0.3694595098495483,2.88207745552063,0.3414799869060516,3.04559588432312,50000.0,0.2535000145435333,3.7423603534698486,10000.0,2071.0617468357086,2160.853071212769,2071.0617468357086,89.48260450363159,0.1137394905090332,0.0 -6100,5.776933,3.2724938,,,,,,,,,,,,,, -6200,8.773853,3.19632,,,,,,,,,,,,,, -6300,6.95712,3.2976801,,,,,,,,,,,,,, -6400,7.7809286,3.2559676,,,,,,,,,,,,,, -6500,8.688369,3.3469307,,,,,,,,,,,,,, -6600,10.737656,3.1196854,,,,,,,,,,,,,, -6700,7.820778,3.187452,,,,,,,,,,,,,, -6800,7.2895126,3.0931568,,,,,,,,,,,,,, -6900,6.474232,3.0870261,,,,,,,,,,,,,, -7000,4.5541224,3.0956357,,,,,,,,,,,,,, -7100,6.5173087,3.1187437,,,,,,,,,,,,,, -7200,10.091579,3.1106195,,,,,,,,,,,,,, -7300,6.7298436,3.0502596,,,,,,,,,,,,,, -7400,15.514636,3.1092,,,,,,,,,,,,,, -7500,6.0942345,2.9992151,,,,,,,,,,,,,, -7563,,,0.4502351582050323,2.4000375270843506,0.3928200006484985,2.7543821334838867,50000.0,0.2979000210762024,3.458079099655152,10000.0,2581.108203172684,2690.453903913498,2581.108203172684,108.95680356025696,0.144883394241333,0.0 -7600,7.5616455,2.9986763,,,,,,,,,,,,,, -7700,10.442184,2.9352074,,,,,,,,,,,,,, -7800,7.954777,2.8753567,,,,,,,,,,,,,, -7900,10.137074,2.964498,,,,,,,,,,,,,, -8000,9.97002,2.8845787,,,,,,,,,,,,,, -8100,6.155618,2.8630998,,,,,,,,,,,,,, -8200,8.087205,2.8663707,,,,,,,,,,,,,, -8300,7.020241,2.8279963,,,,,,,,,,,,,, -8400,6.486458,2.8618097,,,,,,,,,,,,,, -8500,7.682197,2.7598398,,,,,,,,,,,,,, -8600,5.047332,2.7779698,,,,,,,,,,,,,, -8700,4.5162435,2.7558265,,,,,,,,,,,,,, -8800,8.412603,2.7918231,,,,,,,,,,,,,, -8900,8.325638,2.768439,,,,,,,,,,,,,, -9000,8.601308,2.8059506,,,,,,,,,,,,,, -9076,,,0.4960738122463226,2.1706013679504395,0.4490000009536743,2.427058219909668,50000.0,0.3509000241756439,3.098984479904175,10000.0,3091.184275150299,3218.2887029647827,3091.184275150299,126.6402349472046,0.1710314750671386,0.0 -9100,6.8098793,2.5526583,,,,,,,,,,,,,, -9200,6.641185,2.6549282,,,,,,,,,,,,,, -9300,4.949934,2.7134385,,,,,,,,,,,,,, -9400,7.284762,2.6836371,,,,,,,,,,,,,, -9500,11.637571,2.590277,,,,,,,,,,,,,, -9600,8.590896,2.6536224,,,,,,,,,,,,,, -9700,8.545928,2.7496073,,,,,,,,,,,,,, -9800,6.9966526,2.5508907,,,,,,,,,,,,,, -9900,7.5556726,2.5586278,,,,,,,,,,,,,, -10000,4.160585,2.6021397,,,,,,,,,,,,,, -10100,7.8081093,2.5415175,,,,,,,,,,,,,, -10200,4.951346,2.6774917,,,,,,,,,,,,,, -10300,5.4653597,2.5903325,,,,,,,,,,,,,, -10400,6.2236514,2.5962114,,,,,,,,,,,,,, -10500,5.767752,2.5234423,,,,,,,,,,,,,, -10590,,,0.5216438174247742,2.0399041175842285,0.4847599864006042,2.257338762283325,50000.0,0.3641000092029571,3.029717445373535,10000.0,3601.261110067368,3746.2567975521088,3601.261110067368,144.45068073272705,0.2024247646331787,0.0 -10600,6.8278003,2.4769368,,,,,,,,,,,,,, -10700,5.4774475,2.489579,,,,,,,,,,,,,, -10800,8.024878,2.6084495,,,,,,,,,,,,,, -10900,6.431628,2.4763844,,,,,,,,,,,,,, -11000,6.2554703,2.3890796,,,,,,,,,,,,,, -11100,4.8193007,2.4473343,,,,,,,,,,,,,, -11200,5.3845577,2.3766952,,,,,,,,,,,,,, -11300,5.047831,2.4434173,,,,,,,,,,,,,, -11400,6.5494184,2.5224843,,,,,,,,,,,,,, -11500,6.2570834,2.4551444,,,,,,,,,,,,,, -11600,5.2744403,2.630372,,,,,,,,,,,,,, -11700,8.781606,2.4292922,,,,,,,,,,,,,, -11800,6.8670917,2.391424,,,,,,,,,,,,,, -11900,7.566243,2.5019927,,,,,,,,,,,,,, -12000,8.023102,2.3258462,,,,,,,,,,,,,, -12100,6.1048827,2.236165,,,,,,,,,,,,,, -12104,,,0.5573580861091614,1.8680070638656616,0.5182600021362305,2.074094533920288,50000.0,0.3975000083446502,2.803204298019409,10000.0,4111.291719198227,4274.436697483063,4111.291719198227,162.51656198501587,0.2360508441925048,0.0 -12200,6.4914436,2.2796965,,,,,,,,,,,,,, -12300,5.3560596,2.2023213,,,,,,,,,,,,,, -12400,6.557903,2.4302108,,,,,,,,,,,,,, -12500,7.2835917,2.372607,,,,,,,,,,,,,, -12600,8.238485,2.4129105,,,,,,,,,,,,,, -12700,7.635244,2.3011055,,,,,,,,,,,,,, -12800,5.2954545,2.3300767,,,,,,,,,,,,,, -12900,7.2032228,2.390574,,,,,,,,,,,,,, -13000,6.6995463,2.4587498,,,,,,,,,,,,,, -13100,6.536041,2.3600736,,,,,,,,,,,,,, -13200,8.265873,2.2045257,,,,,,,,,,,,,, -13300,5.598678,2.2842672,,,,,,,,,,,,,, -13400,5.160575,2.3449845,,,,,,,,,,,,,, -13500,6.5018435,2.1848445,,,,,,,,,,,,,, -13600,7.499996,2.2293322,,,,,,,,,,,,,, -13620,,,0.5631775856018066,1.82098388671875,0.5280199646949768,2.02920913696289,50000.0,0.4131000339984894,2.741434097290039,10000.0,4621.447257757187,4802.413382053375,4621.447257757187,180.25822043418884,0.2662384510040283,0.0 -13700,4.948839,2.3084931,,,,,,,,,,,,,, -13800,6.5416775,2.205112,,,,,,,,,,,,,, -13900,6.332095,2.3709588,,,,,,,,,,,,,, -14000,11.166296,2.436258,,,,,,,,,,,,,, -14100,5.488044,2.4349923,,,,,,,,,,,,,, -14200,5.1130986,2.4407578,,,,,,,,,,,,,, -14300,5.150279,2.2474022,,,,,,,,,,,,,, -14400,9.367592,2.1426709,,,,,,,,,,,,,, -14500,7.0994735,2.1940978,,,,,,,,,,,,,, -14600,5.598675,2.2992492,,,,,,,,,,,,,, -14700,3.7619398,2.4475462,,,,,,,,,,,,,, -14800,5.634398,2.1402833,,,,,,,,,,,,,, -14900,4.47044,2.187807,,,,,,,,,,,,,, -15000,9.725892,2.2240927,,,,,,,,,,,,,, -15100,6.62748,2.3644009,,,,,,,,,,,,,, -15136,,,0.5808952450752258,1.7555890083312988,0.5410000085830688,1.951683759689331,50000.0,0.4170000255107879,2.7061638832092285,10000.0,5131.6218910217285,5330.570430994034,5131.6218910217285,198.1563205718994,0.3011949062347412,0.0 -15200,4.84423,2.253782,,,,,,,,,,,,,, -15300,6.9652963,2.2227144,,,,,,,,,,,,,, -15400,5.135825,2.207864,,,,,,,,,,,,,, -15500,7.717747,2.2760396,,,,,,,,,,,,,, -15600,6.666287,2.290134,,,,,,,,,,,,,, -15700,11.554185,2.1551988,,,,,,,,,,,,,, -15800,5.5670156,2.2038128,,,,,,,,,,,,,, -15900,5.645806,2.2182508,,,,,,,,,,,,,, -16000,4.727754,2.1587968,,,,,,,,,,,,,, -16100,5.866091,2.200336,,,,,,,,,,,,,, -16200,7.313699,2.3414743,,,,,,,,,,,,,, -16300,5.1904335,2.2850473,,,,,,,,,,,,,, -16400,8.992191,2.2399738,,,,,,,,,,,,,, -16500,4.234595,2.2473278,,,,,,,,,,,,,, -16600,6.279004,2.1084197,,,,,,,,,,,,,, -16652,,,0.6128427982330322,1.5677685737609863,0.5491600036621094,1.9117090702056885,50000.0,0.4243000149726867,2.689218759536743,10000.0,5641.749358415604,5858.903384923935,5641.749358415604,216.2797749042511,0.3344736099243164,0.0 -16700,5.695759,2.2098286,,,,,,,,,,,,,, -16800,4.8475323,2.2364757,,,,,,,,,,,,,, -16900,6.502743,2.0634015,,,,,,,,,,,,,, -17000,6.06117,2.315965,,,,,,,,,,,,,, -17100,5.813855,2.1701813,,,,,,,,,,,,,, -17200,5.3442354,2.241709,,,,,,,,,,,,,, -17300,4.836659,2.11784,,,,,,,,,,,,,, -17400,3.6460752,2.1896386,,,,,,,,,,,,,, -17500,5.1439133,2.1273994,,,,,,,,,,,,,, -17600,5.091394,2.215223,,,,,,,,,,,,,, -17700,3.938591,2.1915169,,,,,,,,,,,,,, -17800,4.8451614,2.2670047,,,,,,,,,,,,,, -17900,4.8766794,2.1525564,,,,,,,,,,,,,, -18000,4.7803206,2.1798148,,,,,,,,,,,,,, -18100,4.237253,2.096371,,,,,,,,,,,,,, -18168,,,0.6028379797935486,1.6164047718048096,0.5559599995613098,1.8733510971069336,50000.0,0.4319000244140625,2.617641925811768,10000.0,6151.70788192749,6386.860626220703,6151.70788192749,234.1956398487091,0.3683803081512451,0.0 -18200,4.5107684,2.2742114,,,,,,,,,,,,,, -18300,4.3979497,2.0895092,,,,,,,,,,,,,, -18400,5.1579814,2.024323,,,,,,,,,,,,,, -18500,7.562296,2.2061172,,,,,,,,,,,,,, -18600,4.5107403,2.205957,,,,,,,,,,,,,, -18700,2.860305,2.1170607,,,,,,,,,,,,,, -18800,3.6729193,2.2128682,,,,,,,,,,,,,, -18900,4.9270883,2.1912203,,,,,,,,,,,,,, -19000,4.951211,2.0523152,,,,,,,,,,,,,, -19100,3.8101308,2.0627465,,,,,,,,,,,,,, -19200,5.7638106,2.1521134,,,,,,,,,,,,,, -19300,4.5886483,2.2157238,,,,,,,,,,,,,, -19400,8.235357,2.129821,,,,,,,,,,,,,, -19500,4.1484585,2.089851,,,,,,,,,,,,,, -19600,4.365752,2.06458,,,,,,,,,,,,,, -19684,,,0.6129025816917419,1.5699564218521118,0.5642799735069275,1.820169448852539,50000.0,0.4443000257015228,2.5521981716156006,10000.0,6661.777126550674,6915.075645208359,6661.777126550674,252.2565195560456,0.4041624069213867,0.0 -19700,3.792399,2.190165,,,,,,,,,,,,,, -19800,5.053295,2.1861145,,,,,,,,,,,,,, -19900,4.755706,2.0227566,,,,,,,,,,,,,, -20000,3.6086082,2.1800184,,,,,,,,,,,,,, -20100,5.0941267,2.1558633,,,,,,,,,,,,,, -20200,3.7421079,2.3240094,,,,,,,,,,,,,, -20300,4.1602516,2.0662327,,,,,,,,,,,,,, -20400,3.4104733,2.0764542,,,,,,,,,,,,,, -20500,3.81122,2.1573644,,,,,,,,,,,,,, -20600,4.1539283,2.0717115,,,,,,,,,,,,,, -20700,5.7549386,2.1268716,,,,,,,,,,,,,, -20800,3.631183,2.2541804,,,,,,,,,,,,,, -20900,3.6761827,2.0858405,,,,,,,,,,,,,, -21000,4.5049496,2.2393131,,,,,,,,,,,,,, -21100,3.8000379,2.2226388,,,,,,,,,,,,,, -21200,3.9978094,2.0237136,,,,,,,,,,,,,, -21201,,,0.6129623651504517,1.5876972675323486,0.5711599588394165,1.807672500610352,50000.0,0.4500000178813934,2.5406386852264404,10000.0,7172.218660593033,7443.592651605606,7172.218660593033,270.24809288978577,0.4388735294342041,0.0 -21300,4.7058835,2.035882,,,,,,,,,,,,,, -21400,4.691328,1.9929354,,,,,,,,,,,,,, -21500,4.3821297,2.1207807,,,,,,,,,,,,,, -21600,4.0384803,2.0837476,,,,,,,,,,,,,, -21700,3.1255846,2.1180487,,,,,,,,,,,,,, -21800,4.0008926,2.0405953,,,,,,,,,,,,,, -21900,3.6148784,1.9755331,,,,,,,,,,,,,, -22000,3.0654557,2.0129743,,,,,,,,,,,,,, -22100,4.5281854,2.0321262,,,,,,,,,,,,,, -22200,3.9754016,2.1124208,,,,,,,,,,,,,, -22300,2.9131687,2.0914416,,,,,,,,,,,,,, -22400,5.605189,2.064253,,,,,,,,,,,,,, -22500,3.2933445,2.0353754,,,,,,,,,,,,,, -22600,4.04043,2.004042,,,,,,,,,,,,,, -22700,3.839373,2.0781178,,,,,,,,,,,,,, -22718,,,0.6120854616165161,1.585247039794922,0.5758599638938904,1.7821930646896362,50000.0,0.4487000107765198,2.5585756301879883,10000.0,7682.367743968964,7971.530035734177,7682.367743968964,287.95379066467285,0.4722630977630615,0.0 -22800,2.9923398,1.9485993,,,,,,,,,,,,,, -22900,3.6599455,2.0269816,,,,,,,,,,,,,, -23000,3.9990222,2.0850449,,,,,,,,,,,,,, -23100,4.1070156,2.3098822,,,,,,,,,,,,,, -23200,3.0888999,2.1620264,,,,,,,,,,,,,, -23300,4.2157016,2.0009418,,,,,,,,,,,,,, -23400,3.8387682,2.0257757,,,,,,,,,,,,,, -23500,5.359292,2.0615797,,,,,,,,,,,,,, -23600,3.4379284,2.1614473,,,,,,,,,,,,,, -23700,3.2293892,2.0156043,,,,,,,,,,,,,, -23800,4.095798,2.0359392,,,,,,,,,,,,,, -23900,3.817497,2.0844371,,,,,,,,,,,,,, -24000,3.55436,1.9312326,,,,,,,,,,,,,, -24100,3.793011,2.114719,,,,,,,,,,,,,, -24200,4.3678646,2.1511605,,,,,,,,,,,,,, -24235,,,0.6190210580825806,1.554908275604248,0.5786600112915039,1.7810800075531006,50000.0,0.4481000304222107,2.5390639305114746,10000.0,8192.398866891861,8499.543118476868,8192.398866891861,305.84978795051575,0.5091888904571533,0.0 -24300,3.5360959,2.0878787,,,,,,,,,,,,,, -24400,3.3488214,2.0601761,,,,,,,,,,,,,, -24500,4.1422577,2.07714,,,,,,,,,,,,,, -24600,4.0924807,1.9644318,,,,,,,,,,,,,, -24700,4.9461303,2.1278586,,,,,,,,,,,,,, -24800,3.7118123,2.00495,,,,,,,,,,,,,, -24900,3.4096231,2.0212553,,,,,,,,,,,,,, -25000,3.390708,2.1133008,,,,,,,,,,,,,, -25100,3.1206417,1.9023359,,,,,,,,,,,,,, -25200,4.440089,2.013566,,,,,,,,,,,,,, -25300,3.6381364,2.0264692,,,,,,,,,,,,,, -25400,4.606075,2.0227823,,,,,,,,,,,,,, -25500,3.3672369,1.9913166,,,,,,,,,,,,,, -25600,4.11993,2.052827,,,,,,,,,,,,,, -25700,3.515091,1.9430774,,,,,,,,,,,,,, -25753,,,0.644551157951355,1.4304097890853882,0.5829799771308899,1.7570055723190308,50000.0,0.4615000188350677,2.4996163845062256,10000.0,8702.434856891632,9027.637979269028,8702.434856891632,323.8264684677124,0.5421411991119385,0.0 -25800,4.1263404,2.0784993,,,,,,,,,,,,,, -25900,4.207412,2.0726037,,,,,,,,,,,,,, -26000,3.325205,2.0007515,,,,,,,,,,,,,, -26100,3.6112862,2.0633426,,,,,,,,,,,,,, -26200,3.098807,1.9696438,,,,,,,,,,,,,, -26300,3.5613077,2.0636394,,,,,,,,,,,,,, -26400,5.9191213,2.0393052,,,,,,,,,,,,,, -26500,3.1391966,2.0727487,,,,,,,,,,,,,, -26600,3.0110073,2.0654109,,,,,,,,,,,,,, -26700,3.053235,2.0911891,,,,,,,,,,,,,, -26800,3.667218,1.9899178,,,,,,,,,,,,,, -26900,3.25883,1.9485096,,,,,,,,,,,,,, -27000,3.8500218,1.8387341,,,,,,,,,,,,,, -27100,4.1248407,2.2132423,,,,,,,,,,,,,, -27200,2.8367276,1.9141315,,,,,,,,,,,,,, -27271,,,0.62890625,1.5057846307754517,0.5762799978256226,1.7802006006240845,50000.0,0.4508000314235687,2.5254218578338623,10000.0,9212.556456327438,9555.571353673937,9212.556456327438,341.5526399612427,0.5781388282775879,0.0 -27300,3.3889475,1.8878657,,,,,,,,,,,,,, -27400,3.4630508,2.0119169,,,,,,,,,,,,,, -27500,3.4896445,2.0268936,,,,,,,,,,,,,, -27600,3.8829026,1.9935461,,,,,,,,,,,,,, -27700,4.898137,1.9495804,,,,,,,,,,,,,, -27800,3.8549466,2.0180104,,,,,,,,,,,,,, -27900,4.293144,1.9946654,,,,,,,,,,,,,, -28000,4.117047,2.0121264,,,,,,,,,,,,,, -28100,3.3481867,2.0806935,,,,,,,,,,,,,, -28200,4.638963,1.9378495,,,,,,,,,,,,,, -28300,4.0905924,1.958646,,,,,,,,,,,,,, -28400,3.6267931,2.063279,,,,,,,,,,,,,, -28500,3.4144917,1.9378825,,,,,,,,,,,,,, -28600,4.6799984,1.912162,,,,,,,,,,,,,, -28700,3.51959,2.0318472,,,,,,,,,,,,,, -28788,,,0.6389309763908386,1.4611449241638184,0.5892400145530701,1.712038516998291,50000.0,0.4659000337123871,2.4030990600585938,10000.0,9722.534570932388,10083.528599262238,9722.534570932388,359.44764280319214,0.6132323741912842,0.0 -28800,3.283134,2.0098417,,,,,,,,,,,,,, -28900,3.8662105,2.0553062,,,,,,,,,,,,,, -29000,3.280289,1.9435041,,,,,,,,,,,,,, -29100,3.5482879,1.9297484,,,,,,,,,,,,,, -29200,4.1325917,1.8916011,,,,,,,,,,,,,, -29300,3.9040756,1.8745564,,,,,,,,,,,,,, -29400,3.873099,1.9406731,,,,,,,,,,,,,, -29500,3.6775668,2.0365927,,,,,,,,,,,,,, -29600,3.7553017,1.9519633,,,,,,,,,,,,,, -29700,3.714191,1.9928281,,,,,,,,,,,,,, -29800,3.3016922,1.9334496,,,,,,,,,,,,,, -29900,3.495455,2.0366309,,,,,,,,,,,,,, -30000,3.4833248,2.073817,,,,,,,,,,,,,, -30100,3.4356759,1.890629,,,,,,,,,,,,,, -30200,3.621591,1.9664006,,,,,,,,,,,,,, -30300,3.199253,2.0993235,,,,,,,,,,,,,, -30306,,,0.6354631781578064,1.4758671522140503,0.5902199745178223,1.7170330286026,50000.0,0.4637000262737274,2.4629597663879395,10000.0,10232.679866552353,10611.508907079697,10232.679866552353,377.1987464427948,0.6481332778930664,0.0 -30400,3.609605,1.9465433,,,,,,,,,,,,,, -30500,3.4645264,1.9949938,,,,,,,,,,,,,, -30600,3.4334652,1.9444922,,,,,,,,,,,,,, -30700,3.8832984,2.0568876,,,,,,,,,,,,,, -30800,3.0231917,1.9543439,,,,,,,,,,,,,, -30900,4.3573747,2.0176504,,,,,,,,,,,,,, -31000,3.6850967,1.9369606,,,,,,,,,,,,,, -31100,3.1057103,2.0404024,,,,,,,,,,,,,, -31200,3.4838731,1.8741318,,,,,,,,,,,,,, -31300,3.3644607,1.8829349,,,,,,,,,,,,,, -31400,3.153661,2.0240347,,,,,,,,,,,,,, -31500,3.677576,1.8810287,,,,,,,,,,,,,, -31600,4.250645,1.862546,,,,,,,,,,,,,, -31700,4.4837337,1.9808589,,,,,,,,,,,,,, -31800,4.282507,2.0934112,,,,,,,,,,,,,, -31824,,,0.6270925998687744,1.5033053159713743,0.5899400115013123,1.7156438827514648,50000.0,0.4649000167846679,2.449855327606201,10000.0,10742.683179616928,11139.474148750303,10742.683179616928,395.0763454437256,0.6836209297180176,0.0 -31900,3.3526738,2.0822966,,,,,,,,,,,,,, -32000,3.5947006,2.113827,,,,,,,,,,,,,, -32100,3.542409,1.8804495,,,,,,,,,,,,,, -32200,3.722868,1.9347875,,,,,,,,,,,,,, -32300,3.9447706,1.8216177,,,,,,,,,,,,,, -32400,3.7418098,1.8332517,,,,,,,,,,,,,, -32500,3.4440362,2.0442684,,,,,,,,,,,,,, -32600,3.3381083,1.9618568,,,,,,,,,,,,,, -32700,3.5767767,1.9509729,,,,,,,,,,,,,, -32800,3.20679,1.9852287,,,,,,,,,,,,,, -32900,3.1579666,1.8729655,,,,,,,,,,,,,, -33000,3.5713022,1.9013705,,,,,,,,,,,,,, -33100,3.5544293,2.0029316,,,,,,,,,,,,,, -33200,3.5561888,1.964426,,,,,,,,,,,,,, -33300,4.775095,2.0174603,,,,,,,,,,,,,, -33342,,,0.6690449714660645,1.3072504997253418,0.5861600041389465,1.723919153213501,50000.0,0.4624000191688537,2.4317657947540283,10000.0,11252.691735982897,11667.586690425873,11252.691735982897,413.0981290340424,0.716942310333252,0.0 -33400,3.513268,1.8858372,,,,,,,,,,,,,, -33500,2.8413382,1.8884625,,,,,,,,,,,,,, -33600,3.3072772,1.9364378,,,,,,,,,,,,,, -33700,3.001985,2.0421214,,,,,,,,,,,,,, -33800,3.8746188,1.8989528,,,,,,,,,,,,,, -33900,3.4823565,1.96976,,,,,,,,,,,,,, -34000,3.706768,2.0467782,,,,,,,,,,,,,, -34100,3.500449,1.9785496,,,,,,,,,,,,,, -34200,3.9056854,2.0184321,,,,,,,,,,,,,, -34300,4.2156124,2.1204,,,,,,,,,,,,,, -34400,3.6881948,2.0479724,,,,,,,,,,,,,, -34500,3.363704,1.9805307,,,,,,,,,,,,,, -34600,3.881011,2.0107083,,,,,,,,,,,,,, -34700,4.1476912,1.911232,,,,,,,,,,,,,, -34800,4.2728662,1.9797498,,,,,,,,,,,,,, -34860,,,0.6469627022743225,1.3967673778533936,0.5877000093460083,1.7086846828460691,50000.0,0.4726000130176544,2.405595541000366,10000.0,11762.682977676392,12195.666572332382,11762.682977676392,431.0991203784943,0.7553849220275879,0.0 -34900,3.54697,2.025051,,,,,,,,,,,,,, -35000,3.4358919,2.0274975,,,,,,,,,,,,,, -35100,3.6843648,1.931459,,,,,,,,,,,,,, -35200,3.8846707,1.9117637,,,,,,,,,,,,,, -35300,3.369575,1.908534,,,,,,,,,,,,,, -35400,3.3830252,1.9773376,,,,,,,,,,,,,, -35500,2.88232,1.931866,,,,,,,,,,,,,, -35600,4.0156107,1.919166,,,,,,,,,,,,,, -35700,4.696578,1.9322875,,,,,,,,,,,,,, -35800,3.5487177,2.0048156,,,,,,,,,,,,,, -35900,3.6910367,1.9479654,,,,,,,,,,,,,, -36000,3.4150896,1.9267197,,,,,,,,,,,,,, -36100,4.0700264,1.8844693,,,,,,,,,,,,,, -36200,4.1177263,1.933927,,,,,,,,,,,,,, -36300,4.4518876,2.0053773,,,,,,,,,,,,,, -36377,,,0.6559311151504517,1.3725244998931885,0.6061399579048157,1.6266961097717283,50000.0,0.4793000221252441,2.3845810890197754,10000.0,12272.765510082245,12723.876703977585,12272.765510082245,449.1371431350708,0.7958519458770752,0.0 -36400,4.2016873,1.9869175,,,,,,,,,,,,,, -36500,4.406279,2.0827188,,,,,,,,,,,,,, -36600,4.1498513,1.7717148,,,,,,,,,,,,,, -36700,4.610656,1.9593308,,,,,,,,,,,,,, -36800,3.348992,1.8819786,,,,,,,,,,,,,, -36900,3.4401305,1.9909382,,,,,,,,,,,,,, -37000,3.4980752,1.9270984,,,,,,,,,,,,,, -37100,3.2833848,1.9199407,,,,,,,,,,,,,, -37200,3.4888134,1.950829,,,,,,,,,,,,,, -37300,3.6626987,1.9456959,,,,,,,,,,,,,, -37400,3.5170877,1.9210898,,,,,,,,,,,,,, -37500,3.1019635,1.9928129,,,,,,,,,,,,,, -37600,3.402336,1.9297,,,,,,,,,,,,,, -37700,3.2735603,1.9835728,,,,,,,,,,,,,, -37800,3.2768247,1.8505374,,,,,,,,,,,,,, -37896,,,0.6361208558082581,1.4662803411483765,0.5847199559211731,1.7373977899551392,50000.0,0.4638000130653381,2.486398220062256,10000.0,12782.802032232285,13251.825589179993,12782.802032232285,466.961225271225,0.8347640037536621,0.0 -37900,3.8628304,2.0063612,,,,,,,,,,,,,, -38000,2.9135828,1.9180944,,,,,,,,,,,,,, -38100,3.385727,1.8736048,,,,,,,,,,,,,, -38200,3.0577672,1.9711945,,,,,,,,,,,,,, -38300,4.0919175,1.9950122,,,,,,,,,,,,,, -38400,4.332729,1.8920867,,,,,,,,,,,,,, -38500,4.2519107,1.9261664,,,,,,,,,,,,,, -38600,3.3692648,1.9337884,,,,,,,,,,,,,, -38700,4.1646514,1.9114434,,,,,,,,,,,,,, -38800,4.5056376,1.9242084,,,,,,,,,,,,,, -38900,3.227515,1.9277151,,,,,,,,,,,,,, -39000,3.4172635,1.7845542,,,,,,,,,,,,,, -39100,3.5733397,2.0713964,,,,,,,,,,,,,, -39200,3.4093876,1.9543619,,,,,,,,,,,,,, -39300,3.9112875,1.9521513,,,,,,,,,,,,,, -39400,3.6329143,1.8834504,,,,,,,,,,,,,, -39414,,,0.6498525142669678,1.4036788940429688,0.6021599769592285,1.6521224975585938,50000.0,0.48130002617836,2.3708035945892334,10000.0,13292.97028517723,13779.885905742643,13292.97028517723,484.7647216320038,0.873910665512085,0.0 -39500,3.9728742,1.92113,,,,,,,,,,,,,, -39600,3.4474688,1.9171946,,,,,,,,,,,,,, -39700,4.6121774,1.9277534,,,,,,,,,,,,,, -39800,3.3631954,2.0339675,,,,,,,,,,,,,, -39900,4.0581884,1.9310801,,,,,,,,,,,,,, -40000,3.6739876,2.0347555,,,,,,,,,,,,,, -40100,4.23702,1.8298408,,,,,,,,,,,,,, -40200,3.610909,1.9440651,,,,,,,,,,,,,, -40300,4.208341,1.8703692,,,,,,,,,,,,,, -40400,4.1009145,1.9846752,,,,,,,,,,,,,, -40500,3.4423559,1.8300936,,,,,,,,,,,,,, -40600,4.204532,1.9450676,,,,,,,,,,,,,, -40700,3.9757557,2.0790043,,,,,,,,,,,,,, -40800,3.5579019,1.9589893,,,,,,,,,,,,,, -40900,4.263715,1.964957,,,,,,,,,,,,,, -40933,,,0.6477798223495483,1.4108294248580933,0.6057999730110168,1.6259907484054563,50000.0,0.4783000349998474,2.3724279403686523,10000.0,13803.110150814056,14308.154326200483,13803.110150814056,502.8075284957886,0.9099681377410888,0.0 -41000,3.4723585,2.0458968,,,,,,,,,,,,,, -41100,3.2928765,1.9382871,,,,,,,,,,,,,, -41200,4.214198,1.9976547,,,,,,,,,,,,,, -41300,4.0723314,1.9725695,,,,,,,,,,,,,, -41400,3.7390616,1.8549995,,,,,,,,,,,,,, -41500,4.413446,2.0124178,,,,,,,,,,,,,, -41600,3.5336282,1.9070991,,,,,,,,,,,,,, -41700,3.6356447,1.9360744,,,,,,,,,,,,,, -41800,3.6052954,1.857422,,,,,,,,,,,,,, -41900,3.7944274,1.9530752,,,,,,,,,,,,,, -42000,3.7311006,1.89852,,,,,,,,,,,,,, -42100,3.405686,1.7542468,,,,,,,,,,,,,, -42200,4.0099487,1.9066979,,,,,,,,,,,,,, -42300,3.4645872,1.856407,,,,,,,,,,,,,, -42400,3.840232,1.9485524,,,,,,,,,,,,,, -42453,,,0.6938576102256775,1.2015491724014282,0.6082800030708313,1.6163575649261477,50000.0,0.4844000339508056,2.3288164138793945,10000.0,14313.330046653748,14836.49200987816,14313.330046653748,520.8403308391571,0.9460337162017822,0.0 -42500,3.5546396,1.8469548,,,,,,,,,,,,,, -42600,4.0991373,1.88497,,,,,,,,,,,,,, -42700,3.2374637,1.8786385,,,,,,,,,,,,,, -42800,3.3215861,1.8530977,,,,,,,,,,,,,, -42900,4.150152,1.920571,,,,,,,,,,,,,, -43000,3.2062323,1.8996536,,,,,,,,,,,,,, -43100,3.3408802,1.9121997,,,,,,,,,,,,,, -43200,5.092698,1.9336972,,,,,,,,,,,,,, -43300,3.1381898,1.8919333,,,,,,,,,,,,,, -43400,3.438231,1.7796397,,,,,,,,,,,,,, -43500,3.2917447,1.8746153,,,,,,,,,,,,,, -43600,3.8653083,1.9877803,,,,,,,,,,,,,, -43700,3.7701313,1.8308111,,,,,,,,,,,,,, -43800,4.0993257,1.9478128,,,,,,,,,,,,,, -43900,3.2965803,1.9469317,,,,,,,,,,,,,, -43972,,,0.6569873690605164,1.3638534545898438,0.5990399718284607,1.6688741445541382,50000.0,0.4779000282287597,2.388891696929932,10000.0,14823.266336917875,15364.582607030869,14823.266336917875,538.9052393436432,0.9865646362304688,0.0 -44000,4.0188975,1.9939057,,,,,,,,,,,,,, -44100,3.399561,1.9106953,,,,,,,,,,,,,, -44200,3.447333,1.9926527,,,,,,,,,,,,,, -44300,3.725882,1.9289398,,,,,,,,,,,,,, -44400,3.5099955,1.9752955,,,,,,,,,,,,,, -44500,3.9381735,2.0978482,,,,,,,,,,,,,, -44600,3.8523772,1.9193332,,,,,,,,,,,,,, -44700,3.2590828,1.9436543,,,,,,,,,,,,,, -44800,3.4670258,1.9467586,,,,,,,,,,,,,, -44900,3.7359345,1.8762612,,,,,,,,,,,,,, -45000,4.0926538,1.8491712,,,,,,,,,,,,,, -45100,3.6228237,1.9555323,,,,,,,,,,,,,, -45200,2.940621,1.7715309,,,,,,,,,,,,,, -45300,4.1028137,1.9810545,,,,,,,,,,,,,, -45400,4.1514444,1.8017117,,,,,,,,,,,,,, -45491,,,0.6528021097183228,1.3831080198287964,0.6025800108909607,1.646975874900818,50000.0,0.4779000282287597,2.395018577575684,10000.0,15333.213256835938,15892.32246518135,15333.213256835938,556.6109170913696,1.0245742797851562,0.0 -45500,3.291067,1.7514174,,,,,,,,,,,,,, -45600,3.6896946,1.9284326,,,,,,,,,,,,,, -45700,3.1103702,1.8927721,,,,,,,,,,,,,, -45800,3.7177083,1.8517731,,,,,,,,,,,,,, -45900,4.116309,1.9358755,,,,,,,,,,,,,, -46000,3.212949,1.9451929,,,,,,,,,,,,,, -46100,3.8028681,1.9203258,,,,,,,,,,,,,, -46200,3.3618262,2.043818,,,,,,,,,,,,,, -46300,3.7512784,1.7058637,,,,,,,,,,,,,, -46400,3.2148066,1.8215388,,,,,,,,,,,,,, -46500,3.7200952,1.8744107,,,,,,,,,,,,,, -46600,3.2653196,1.9614391,,,,,,,,,,,,,, -46700,3.5651946,1.8327278,,,,,,,,,,,,,, -46800,3.557309,1.9229157,,,,,,,,,,,,,, -46900,4.1383843,1.9074557,,,,,,,,,,,,,, -47000,3.2955809,2.0012667,,,,,,,,,,,,,, -47009,,,0.6518455147743225,1.3824917078018188,0.6039599776268005,1.638573408126831,50000.0,0.4746000170707702,2.371513843536377,10000.0,15843.245793819427,16421.24270606041,15843.245793819427,575.4105768203735,1.063220739364624,0.0 -47100,5.5233536,1.8488047,,,,,,,,,,,,,, -47200,3.8022358,1.9884473,,,,,,,,,,,,,, -47300,3.3751297,1.9318366,,,,,,,,,,,,,, -47400,3.3428915,1.9195956,,,,,,,,,,,,,, -47500,4.368177,1.9721432,,,,,,,,,,,,,, -47600,3.3819623,1.8365788,,,,,,,,,,,,,, -47700,4.068681,1.8325833,,,,,,,,,,,,,, -47800,3.612035,1.9659872,,,,,,,,,,,,,, -47900,3.8018398,1.8446337,,,,,,,,,,,,,, -48000,3.7921293,1.851865,,,,,,,,,,,,,, -48100,3.6768281,1.7743821,,,,,,,,,,,,,, -48200,3.3586268,1.8822381,,,,,,,,,,,,,, -48300,4.161514,2.0066285,,,,,,,,,,,,,, -48400,3.3258832,2.0833983,,,,,,,,,,,,,, -48500,4.9673877,1.8034767,,,,,,,,,,,,,, -48528,,,0.6499919891357422,1.4085599184036257,0.602620005607605,1.6432632207870483,50000.0,0.4837000370025635,2.336787700653076,10000.0,16353.201929330826,16949.052980184555,16353.201929330826,593.1785054206848,1.1000022888183594,0.0 -48600,3.6075912,1.7998906,,,,,,,,,,,,,, -48700,3.6165822,1.8156507,,,,,,,,,,,,,, -48800,3.313027,1.8282826,,,,,,,,,,,,,, -48900,3.4431112,1.9645156,,,,,,,,,,,,,, -49000,3.538068,1.8493978,,,,,,,,,,,,,, -49100,3.8227248,1.8120896,,,,,,,,,,,,,, -49200,3.2462,1.7801126,,,,,,,,,,,,,, -49300,3.3727293,1.9389923,,,,,,,,,,,,,, -49400,3.5646431,1.9617711,,,,,,,,,,,,,, -49500,3.790097,1.7960666,,,,,,,,,,,,,, -49600,3.613824,1.9536786,,,,,,,,,,,,,, -49700,3.5802126,1.9237286,,,,,,,,,,,,,, -49800,3.3998291,1.9301981,,,,,,,,,,,,,, -49900,3.502161,1.908607,,,,,,,,,,,,,, -50000,4.2248034,1.8333399,,,,,,,,,,,,,, -50046,,,0.6486168503761292,1.416173219680786,0.6020399928092957,1.6556391716003418,50000.0,0.4767000079154968,2.3756377696990967,10000.0,16863.183659791946,17477.124361991882,16863.183659791946,611.1795015335083,1.139614820480347,0.0 -50100,3.1099446,1.8306247,,,,,,,,,,,,,, -50200,3.7906418,2.0164897,,,,,,,,,,,,,, -50300,3.427038,1.9524176,,,,,,,,,,,,,, -50400,3.4849148,1.7805024,,,,,,,,,,,,,, -50500,4.218922,2.0297587,,,,,,,,,,,,,, -50600,3.6176364,1.8864932,,,,,,,,,,,,,, -50700,4.7073298,1.8683902,,,,,,,,,,,,,, -50800,3.625626,1.9485977,,,,,,,,,,,,,, -50900,3.8516533,1.9060439,,,,,,,,,,,,,, -51000,3.4018357,1.9543362,,,,,,,,,,,,,, -51100,3.9977674,1.9313728,,,,,,,,,,,,,, -51200,3.63985,1.8772488,,,,,,,,,,,,,, -51300,4.044308,2.0330772,,,,,,,,,,,,,, -51400,3.4312832,1.8963075,,,,,,,,,,,,,, -51500,4.531955,1.8539393,,,,,,,,,,,,,, -51565,,,0.6805644035339355,1.269840955734253,0.6078000068664551,1.6356245279312134,50000.0,0.4915000200271606,2.3638558387756348,10000.0,17373.11920595169,18005.066600322723,17373.11920595169,629.0958392620087,1.1804418563842771,0.0 -51600,3.9844193,1.8048047,,,,,,,,,,,,,, -51700,4.6285663,1.9418366,,,,,,,,,,,,,, -51800,4.471687,1.9325762,,,,,,,,,,,,,, -51900,3.983232,1.9801593,,,,,,,,,,,,,, -52000,3.5602348,1.8893986,,,,,,,,,,,,,, -52100,3.4653828,1.7401863,,,,,,,,,,,,,, -52200,4.007313,1.7765688,,,,,,,,,,,,,, -52300,3.4825869,1.9069271,,,,,,,,,,,,,, -52400,3.3549082,1.8705813,,,,,,,,,,,,,, -52500,3.8144596,1.8310027,,,,,,,,,,,,,, -52600,4.1925426,1.762232,,,,,,,,,,,,,, -52700,3.7494752,1.9219031,,,,,,,,,,,,,, -52800,3.7341163,1.8149079,,,,,,,,,,,,,, -52900,3.6486158,1.9085692,,,,,,,,,,,,,, -53000,4.011844,1.8441744,,,,,,,,,,,,,, -53084,,,0.6640226244926453,1.340054988861084,0.6099599599838257,1.6146225929260254,50000.0,0.4881000220775604,2.3278653621673584,10000.0,17883.04052567482,18532.7405500412,17883.04052567482,646.7576491832733,1.2222459316253662,0.0 -53100,3.6957545,2.0044367,,,,,,,,,,,,,, -53200,3.5378258,1.8206639,,,,,,,,,,,,,, -53300,3.4835067,1.8940905,,,,,,,,,,,,,, -53400,3.882062,1.7778549,,,,,,,,,,,,,, -53500,3.1976175,1.8348404,,,,,,,,,,,,,, -53600,3.3737519,1.8602256,,,,,,,,,,,,,, -53700,4.330946,1.8828385,,,,,,,,,,,,,, -53800,4.509856,1.9694396,,,,,,,,,,,,,, -53900,3.3464315,1.8462837,,,,,,,,,,,,,, -54000,3.604339,1.8390205,,,,,,,,,,,,,, -54100,4.2151976,1.9466361,,,,,,,,,,,,,, -54200,3.3525808,1.8694549,,,,,,,,,,,,,, -54300,3.982766,1.9967198,,,,,,,,,,,,,, -54400,3.3735008,1.8585494,,,,,,,,,,,,,, -54500,3.6094124,1.8154048,,,,,,,,,,,,,, -54600,3.7505722,1.7860703,,,,,,,,,,,,,, -54603,,,0.6684470772743225,1.310733675956726,0.6151599884033203,1.5864496231079102,50000.0,0.496500015258789,2.257696866989136,10000.0,18393.22454881668,19061.04249453545,18393.22454881668,664.7858171463013,1.262617111206055,0.0 -54700,3.6162972,1.776396,,,,,,,,,,,,,, -54800,4.554474,2.0131016,,,,,,,,,,,,,, -54900,3.981033,1.9008673,,,,,,,,,,,,,, -55000,3.5867248,1.8797393,,,,,,,,,,,,,, -55100,3.7782764,1.9076,,,,,,,,,,,,,, -55200,3.889647,1.9137123,,,,,,,,,,,,,, -55300,3.5216103,1.7577226,,,,,,,,,,,,,, -55400,3.64254,1.8018539,,,,,,,,,,,,,, -55500,3.7864072,1.7735617,,,,,,,,,,,,,, -55600,4.0871816,1.7388016,,,,,,,,,,,,,, -55700,3.6525133,1.827925,,,,,,,,,,,,,, -55800,4.7467737,1.8813974,,,,,,,,,,,,,, -55900,3.6091623,1.9288917,,,,,,,,,,,,,, -56000,3.5159554,1.8575207,,,,,,,,,,,,,, -56100,3.8635883,1.7336955,,,,,,,,,,,,,, -56122,,,0.6581233739852905,1.3662872314453125,0.6144399642944336,1.60483980178833,50000.0,0.4861000180244446,2.338015079498291,10000.0,18903.22807765007,19588.880088090897,18903.22807765007,682.5275778770447,1.305849313735962,0.0 -56200,3.8762667,1.9097798,,,,,,,,,,,,,, -56300,4.0597115,1.7511157,,,,,,,,,,,,,, -56400,4.3650513,1.9226787,,,,,,,,,,,,,, -56500,3.5836504,1.8805363,,,,,,,,,,,,,, -56600,3.725535,2.0083847,,,,,,,,,,,,,, -56700,3.7141569,1.9341531,,,,,,,,,,,,,, -56800,3.8783271,1.9745643,,,,,,,,,,,,,, -56900,3.069598,1.8087012,,,,,,,,,,,,,, -57000,3.824121,1.9337211,,,,,,,,,,,,,, -57100,3.527149,1.7956951,,,,,,,,,,,,,, -57200,3.6037982,1.8787332,,,,,,,,,,,,,, -57300,3.5135927,1.947017,,,,,,,,,,,,,, -57400,4.0879474,1.8701384,,,,,,,,,,,,,, -57500,3.8849285,1.8754423,,,,,,,,,,,,,, -57600,4.0757446,1.8597419,,,,,,,,,,,,,, -57640,,,0.6615911722183228,1.3483868837356567,0.615559995174408,1.5824005603790283,50000.0,0.4958000183105469,2.2923803329467773,10000.0,19413.14436841011,20116.891170024872,19413.14436841011,700.529173374176,1.3492977619171145,0.0 -57700,3.7986066,1.8207417,,,,,,,,,,,,,, -57800,4.409816,1.7875688,,,,,,,,,,,,,, -57900,3.3991382,1.85433,,,,,,,,,,,,,, -58000,3.5731866,1.8637054,,,,,,,,,,,,,, -58100,4.191022,1.6666925,,,,,,,,,,,,,, -58200,4.190449,1.9051359,,,,,,,,,,,,,, -58300,3.350644,1.6857579,,,,,,,,,,,,,, -58400,3.6658156,1.7563664,,,,,,,,,,,,,, -58500,3.5434039,1.8140777,,,,,,,,,,,,,, -58600,3.3274512,1.8192216,,,,,,,,,,,,,, -58700,3.391162,1.8661034,,,,,,,,,,,,,, -58800,3.6793306,1.8693669,,,,,,,,,,,,,, -58900,4.3800025,1.739418,,,,,,,,,,,,,, -59000,3.882518,1.9365661,,,,,,,,,,,,,, -59100,4.276572,1.8743536,,,,,,,,,,,,,, -59161,,,0.6617307066917419,1.3514286279678345,0.6044600009918213,1.642971396446228,50000.0,0.4762000143527984,2.395241498947144,10000.0,19923.271131277084,20645.021767616272,19923.271131277084,718.4421739578247,1.3909647464752195,0.0 -59200,4.081921,1.8159701,,,,,,,,,,,,,, -59300,3.801328,1.7606429,,,,,,,,,,,,,, -59400,3.9787576,1.788253,,,,,,,,,,,,,, -59500,3.8848846,1.7343482,,,,,,,,,,,,,, -59600,3.8757682,1.913284,,,,,,,,,,,,,, -59700,3.5365222,1.8556058,,,,,,,,,,,,,, -59800,3.9180048,1.8114775,,,,,,,,,,,,,, -59900,3.5243144,1.8752235,,,,,,,,,,,,,, -60000,3.8871722,1.9325064,,,,,,,,,,,,,, -60100,3.777822,1.9030952,,,,,,,,,,,,,, -60200,3.5433424,1.8746221,,,,,,,,,,,,,, -60300,3.4099174,1.7791417,,,,,,,,,,,,,, -60400,4.0392303,1.7848272,,,,,,,,,,,,,, -60500,4.339727,1.8092034,,,,,,,,,,,,,, -60600,3.4181979,1.7025274,,,,,,,,,,,,,, -60682,,,0.6777144074440002,1.2699861526489258,0.6165800094604492,1.5854922533035278,50000.0,0.489300012588501,2.322281837463379,10000.0,20433.406439065933,21172.83334803581,20433.406439065933,736.025773525238,1.4345412254333496,0.0 -60700,3.6678798,1.812801,,,,,,,,,,,,,, -60800,3.3640597,1.9777341,,,,,,,,,,,,,, -60900,3.769485,1.8115509,,,,,,,,,,,,,, -61000,3.8635516,1.8219105,,,,,,,,,,,,,, -61100,3.649453,1.772014,,,,,,,,,,,,,, -61200,3.4142337,1.8008518,,,,,,,,,,,,,, -61300,3.5140617,1.7338951,,,,,,,,,,,,,, -61400,3.2191036,1.8530741,,,,,,,,,,,,,, -61500,3.7330284,1.9734869,,,,,,,,,,,,,, -61600,3.5550375,1.7487725,,,,,,,,,,,,,, -61700,3.6044247,1.8300899,,,,,,,,,,,,,, -61800,3.5806057,1.8941582,,,,,,,,,,,,,, -61900,3.6275034,1.9035337,,,,,,,,,,,,,, -62000,4.2184515,1.9400188,,,,,,,,,,,,,, -62100,4.326461,1.7983445,,,,,,,,,,,,,, -62200,4.3737326,1.9758688,,,,,,,,,,,,,, -62201,,,0.6781728267669678,1.2760369777679443,0.6189999580383301,1.5650577545166016,50000.0,0.4894000291824341,2.293773651123047,10000.0,20943.60581111908,21700.95312690735,20943.60581111908,753.8541922569275,1.4773738384246826,0.0 -62300,4.208108,1.7866964,,,,,,,,,,,,,, -62400,3.8096325,1.8206673,,,,,,,,,,,,,, -62500,4.7099366,1.8343761,,,,,,,,,,,,,, -62600,4.132096,1.832621,,,,,,,,,,,,,, -62700,4.407948,1.9907213,,,,,,,,,,,,,, -62800,3.7217603,1.8128456,,,,,,,,,,,,,, -62900,3.4758055,1.755108,,,,,,,,,,,,,, -63000,4.274367,1.7351526,,,,,,,,,,,,,, -63100,4.5409164,1.9275179,,,,,,,,,,,,,, -63200,3.7831395,1.7703514,,,,,,,,,,,,,, -63300,3.7120988,1.826069,,,,,,,,,,,,,, -63400,3.589015,1.8473537,,,,,,,,,,,,,, -63500,3.3097992,1.7673726,,,,,,,,,,,,,, -63600,4.0114646,1.8795687,,,,,,,,,,,,,, -63700,3.5381355,1.7243572,,,,,,,,,,,,,, -63721,,,0.6701809763908386,1.300941824913025,0.6164000034332275,1.5907514095306396,50000.0,0.4957000315189361,2.286545991897583,10000.0,21453.79224085808,22228.8758354187,21453.79224085808,771.4942946434021,1.524343490600586,0.0 -63800,4.1626234,1.7596889,,,,,,,,,,,,,, -63900,4.535683,1.7607479,,,,,,,,,,,,,, -64000,3.7722077,1.8974969,,,,,,,,,,,,,, -64100,4.2942476,1.7654871,,,,,,,,,,,,,, -64200,3.8205576,1.8426511,,,,,,,,,,,,,, -64300,4.1757436,1.8774358,,,,,,,,,,,,,, -64400,3.7358398,1.8110272,,,,,,,,,,,,,, -64500,3.914098,1.8605733,,,,,,,,,,,,,, -64600,3.6599774,1.7103992,,,,,,,,,,,,,, -64700,3.4385045,1.820394,,,,,,,,,,,,,, -64800,4.029056,1.8554478,,,,,,,,,,,,,, -64900,3.4239862,1.9117309,,,,,,,,,,,,,, -65000,3.9246554,1.7253311,,,,,,,,,,,,,, -65100,3.8168674,1.820014,,,,,,,,,,,,,, -65200,3.6646116,1.8417988,,,,,,,,,,,,,, -65241,,,0.6570671200752258,1.3706705570220947,0.6152200102806091,1.615429162979126,50000.0,0.4811000227928161,2.425404787063598,10000.0,21963.94081664085,22756.996761083603,21963.94081664085,789.372960805893,1.5687105655670166,0.0 -65300,4.0291615,1.7911887,,,,,,,,,,,,,, -65400,4.1052165,1.8275893,,,,,,,,,,,,,, -65500,3.865405,1.8835945,,,,,,,,,,,,,, -65600,3.5871434,1.8277164,,,,,,,,,,,,,, -65700,4.0447626,1.7087791,,,,,,,,,,,,,, -65800,3.8824337,1.8673565,,,,,,,,,,,,,, -65900,3.9914262,1.8443292,,,,,,,,,,,,,, -66000,3.2586305,1.7934703,,,,,,,,,,,,,, -66100,3.8995905,1.8023126,,,,,,,,,,,,,, -66200,3.9970794,1.777616,,,,,,,,,,,,,, -66300,3.7520857,1.8352425,,,,,,,,,,,,,, -66400,3.458847,1.8605889,,,,,,,,,,,,,, -66500,3.5524712,1.8090897,,,,,,,,,,,,,, -66600,3.3721845,1.8047348,,,,,,,,,,,,,, -66700,3.3760474,1.8768753,,,,,,,,,,,,,, -66761,,,0.6267139315605164,1.5229088068008425,0.5832399725914001,1.7533199787139893,50000.0,0.4582000076770782,2.52813458442688,10000.0,22474.06643724441,23285.236786603928,22474.06643724441,807.3948268890381,1.6115176677703855,0.0 -66800,3.8400407,1.8212647,,,,,,,,,,,,,, -66900,3.6930187,1.8069266,,,,,,,,,,,,,, -67000,4.186905,1.8628223,,,,,,,,,,,,,, -67100,3.5330703,1.8212095,,,,,,,,,,,,,, -67200,3.816036,1.871605,,,,,,,,,,,,,, -67300,3.8346746,1.8462234,,,,,,,,,,,,,, -67400,3.4720433,1.8232468,,,,,,,,,,,,,, -67500,3.7220864,1.770128,,,,,,,,,,,,,, -67600,4.3479037,1.7801046,,,,,,,,,,,,,, -67700,4.203117,1.810617,,,,,,,,,,,,,, -67800,4.2931466,1.8095782,,,,,,,,,,,,,, -67900,3.9761937,1.6886172,,,,,,,,,,,,,, -68000,4.2397857,1.949881,,,,,,,,,,,,,, -68100,3.391909,1.8042744,,,,,,,,,,,,,, -68200,3.8437898,1.8033576,,,,,,,,,,,,,, -68281,,,0.7253667116165161,1.0721691846847534,0.6342999935150146,1.5084197521209717,50000.0,0.5125000476837158,2.1850202083587646,10000.0,22984.056749343872,23813.01803874969,22984.056749343872,825.0936350822449,1.6541671752929688,0.0 -68300,3.448289,1.8236309,,,,,,,,,,,,,, -68400,3.9923086,1.8150803,,,,,,,,,,,,,, -68500,3.9901066,1.8980923,,,,,,,,,,,,,, -68600,3.664592,1.7692915,,,,,,,,,,,,,, -68700,3.7729304,1.7722237,,,,,,,,,,,,,, -68800,4.3667393,1.7737948,,,,,,,,,,,,,, -68900,3.6895168,1.8805836,,,,,,,,,,,,,, -69000,3.8233082,1.7667027,,,,,,,,,,,,,, -69100,3.97967,1.8510394,,,,,,,,,,,,,, -69200,3.9433131,1.8326151,,,,,,,,,,,,,, -69300,3.6696858,1.7830627,,,,,,,,,,,,,, -69400,4.354436,1.8054805,,,,,,,,,,,,,, -69500,3.9119148,1.7221777,,,,,,,,,,,,,, -69600,4.154936,1.6630836,,,,,,,,,,,,,, -69700,3.3913682,1.714173,,,,,,,,,,,,,, -69799,,,0.6788305044174194,1.2620773315429688,0.6158999800682068,1.5898820161819458,50000.0,0.4927000105381012,2.3095078468322754,10000.0,23493.977559566498,24340.83118534088,23493.977559566498,842.8929722309113,1.6980137825012207,0.0 -69800,4.0284257,1.7671578,,,,,,,,,,,,,, -69900,3.7416885,1.7698958,,,,,,,,,,,,,, -70000,4.055188,1.7853038,,,,,,,,,,,,,, -70100,3.839172,1.8981968,,,,,,,,,,,,,, -70200,4.3837867,1.8086736,,,,,,,,,,,,,, -70300,4.047992,1.6855639,,,,,,,,,,,,,, -70400,3.918314,1.725853,,,,,,,,,,,,,, -70500,3.9257069,1.8401757,,,,,,,,,,,,,, -70600,3.6937082,1.8402354,,,,,,,,,,,,,, -70700,3.9599044,1.8711205,,,,,,,,,,,,,, -70800,3.6173658,1.828374,,,,,,,,,,,,,, -70900,3.5476296,1.7796259,,,,,,,,,,,,,, -71000,3.9328039,1.7336841,,,,,,,,,,,,,, -71100,3.6567202,1.7771777,,,,,,,,,,,,,, -71200,3.5820897,1.8984562,,,,,,,,,,,,,, -71300,4.0184417,1.7758521,,,,,,,,,,,,,, -71319,,,0.6802256107330322,1.253443956375122,0.6263799667358398,1.537427306175232,50000.0,0.5060000419616699,2.2284297943115234,10000.0,24004.088791370392,24868.941687583923,24004.088791370392,860.7975602149963,1.7430338859558103,0.0 -71400,4.5538263,1.7711755,,,,,,,,,,,,,, -71500,4.3013725,1.7229548,,,,,,,,,,,,,, -71600,3.5904217,1.7757181,,,,,,,,,,,,,, -71700,3.7805893,1.8925285,,,,,,,,,,,,,, -71800,3.408559,1.8090559,,,,,,,,,,,,,, -71900,3.571613,1.8095893,,,,,,,,,,,,,, -72000,4.0350266,1.7213804,,,,,,,,,,,,,, -72100,4.037159,1.8146951,,,,,,,,,,,,,, -72200,3.8258674,1.8440255,,,,,,,,,,,,,, -72300,3.8871074,1.8231821,,,,,,,,,,,,,, -72400,3.7645373,1.7851832,,,,,,,,,,,,,, -72500,3.9583867,1.7059698,,,,,,,,,,,,,, -72600,3.8277853,1.8631219,,,,,,,,,,,,,, -72700,3.99653,1.9168043,,,,,,,,,,,,,, -72800,3.9146373,1.6812013,,,,,,,,,,,,,, -72838,,,0.6843112111091614,1.24107563495636,0.6298199892044067,1.5166643857955933,50000.0,0.5053000450134277,2.216010332107544,10000.0,24514.16850876808,25397.194969654083,24514.16850876808,878.8791308403015,1.7854652404785156,0.0 -72900,4.2722006,1.8509524,,,,,,,,,,,,,, -73000,3.726557,1.7026888,,,,,,,,,,,,,, -73100,4.3042436,1.7144347,,,,,,,,,,,,,, -73200,4.0587907,1.7497056,,,,,,,,,,,,,, -73300,4.271354,1.7370219,,,,,,,,,,,,,, -73400,3.7070892,1.6267675,,,,,,,,,,,,,, -73500,3.8810802,1.6726015,,,,,,,,,,,,,, -73600,3.7620845,1.684486,,,,,,,,,,,,,, -73700,3.6692564,1.878479,,,,,,,,,,,,,, -73800,4.1647935,1.7446519,,,,,,,,,,,,,, -73900,4.5329623,1.7635928,,,,,,,,,,,,,, -74000,3.7836463,1.669728,,,,,,,,,,,,,, -74100,3.7331386,1.6843107,,,,,,,,,,,,,, -74200,3.8796642,1.7769607,,,,,,,,,,,,,, -74300,4.2642536,1.8101553,,,,,,,,,,,,,, -74359,,,0.6735291481018066,1.2930454015731812,0.6256600022315979,1.5487301349639893,50000.0,0.5045000314712524,2.2394707202911377,10000.0,25024.354562044144,25925.49914598465,25024.354562044144,896.9066410064697,1.8266286849975584,0.0 -74400,4.2329726,1.6764553,,,,,,,,,,,,,, -74500,3.8007154,1.7835637,,,,,,,,,,,,,, -74600,4.4309916,1.8397202,,,,,,,,,,,,,, -74700,4.936646,1.7119536,,,,,,,,,,,,,, -74800,3.9582012,1.7578487,,,,,,,,,,,,,, -74900,3.8632898,1.7742903,,,,,,,,,,,,,, -75000,5.6951613,1.745834,,,,,,,,,,,,,, -75100,3.8347926,1.7622414,,,,,,,,,,,,,, -75200,4.1410575,1.6489525,,,,,,,,,,,,,, -75300,3.7426763,1.7831191,,,,,,,,,,,,,, -75400,4.3912916,1.644034,,,,,,,,,,,,,, -75500,4.136029,1.7574726,,,,,,,,,,,,,, -75600,4.123883,1.732134,,,,,,,,,,,,,, -75700,3.5110364,1.8196244,,,,,,,,,,,,,, -75800,3.499537,1.7158589,,,,,,,,,,,,,, -75878,,,0.6875996589660645,1.2251676321029663,0.6364799737930298,1.4780266284942627,50000.0,0.5094000101089478,2.201269149780273,10000.0,25534.386019468307,26453.276960134503,25534.386019468307,914.560183763504,1.8696751594543457,0.0 -75900,4.118294,1.7481393,,,,,,,,,,,,,, -76000,3.92424,1.6754093,,,,,,,,,,,,,, -76100,4.4584785,1.8452384,,,,,,,,,,,,,, -76200,4.4753184,1.8090703,,,,,,,,,,,,,, -76300,3.7691298,1.7613158,,,,,,,,,,,,,, -76400,3.5729032,1.6858088,,,,,,,,,,,,,, -76500,4.206225,1.8996304,,,,,,,,,,,,,, -76600,4.337208,1.7427205,,,,,,,,,,,,,, -76700,3.8520133,1.7734237,,,,,,,,,,,,,, -76800,4.468652,1.6759719,,,,,,,,,,,,,, -76900,4.1157556,1.8289578,,,,,,,,,,,,,, -77000,3.8520725,1.7461507,,,,,,,,,,,,,, -77100,3.8110318,1.6923392,,,,,,,,,,,,,, -77200,3.8579085,1.7580221,,,,,,,,,,,,,, -77300,4.4239707,1.7932882,,,,,,,,,,,,,, -77398,,,0.7134685516357422,1.1057634353637695,0.6320599913597107,1.5262938737869265,50000.0,0.4964000284671783,2.239720582962036,10000.0,26044.426952838898,26981.249837636948,26044.426952838898,932.3965713977814,1.9160151481628416,0.0 -77400,3.611305,1.7700534,,,,,,,,,,,,,, -77500,4.1686044,1.8245945,,,,,,,,,,,,,, -77600,3.8884416,1.7012973,,,,,,,,,,,,,, -77700,4.1073995,1.8160815,,,,,,,,,,,,,, -77800,4.1591597,1.7437509,,,,,,,,,,,,,, -77900,3.6797018,1.7957777,,,,,,,,,,,,,, -78000,3.7737648,1.7316364,,,,,,,,,,,,,, -78100,3.8964508,1.730469,,,,,,,,,,,,,, -78200,4.052955,1.6134192,,,,,,,,,,,,,, -78300,3.8738916,1.7539546,,,,,,,,,,,,,, -78400,3.781688,1.7092702,,,,,,,,,,,,,, -78500,5.4364195,1.7508658,,,,,,,,,,,,,, -78600,4.2529383,1.694495,,,,,,,,,,,,,, -78700,4.286309,1.839614,,,,,,,,,,,,,, -78800,4.357081,1.7782739,,,,,,,,,,,,,, -78900,4.150944,1.7509685,,,,,,,,,,,,,, -78918,,,0.7004145383834839,1.1687120199203491,0.6345599889755249,1.4978222846984863,50000.0,0.511900007724762,2.2021243572235107,10000.0,26554.584349632263,27509.33651995659,26554.584349632263,950.2284531593324,1.964091300964356,0.0 -79000,4.0361586,1.7885898,,,,,,,,,,,,,, -79100,3.7668068,1.7437356,,,,,,,,,,,,,, -79200,4.833593,1.8435402,,,,,,,,,,,,,, -79300,4.0524397,1.744182,,,,,,,,,,,,,, -79400,4.227996,1.842555,,,,,,,,,,,,,, -79500,4.095618,1.8204457,,,,,,,,,,,,,, -79600,4.1703296,1.732023,,,,,,,,,,,,,, -79700,4.222421,1.5934803,,,,,,,,,,,,,, -79800,4.3676057,1.7747433,,,,,,,,,,,,,, -79900,3.8389146,1.6339306,,,,,,,,,,,,,, -80000,3.9165297,1.7619237,,,,,,,,,,,,,, -80100,4.297904,1.7842588,,,,,,,,,,,,,, -80200,3.7330492,1.7453697,,,,,,,,,,,,,, -80300,4.282861,1.7186675,,,,,,,,,,,,,, -80400,4.4951944,1.7654858,,,,,,,,,,,,,, -80438,,,0.6895527839660645,1.2138910293579102,0.6326000094413757,1.5095300674438477,50000.0,0.4999000132083893,2.279034376144409,10000.0,27064.51063919068,28037.29874444008,27064.51063919068,968.1738333702089,2.00515365600586,0.0 -80500,4.2163033,1.7829081,,,,,,,,,,,,,, -80600,4.1549325,1.7138898,,,,,,,,,,,,,, -80700,3.936899,1.7195371,,,,,,,,,,,,,, -80800,4.482318,1.8663416,,,,,,,,,,,,,, -80900,4.087675,1.7333198,,,,,,,,,,,,,, -81000,4.162828,1.6332533,,,,,,,,,,,,,, -81100,3.6648386,1.7697494,,,,,,,,,,,,,, -81200,4.8101125,1.6690829,,,,,,,,,,,,,, -81300,3.8896954,1.7743895,,,,,,,,,,,,,, -81400,4.127283,1.6607296,,,,,,,,,,,,,, -81500,4.437369,1.6824853,,,,,,,,,,,,,, -81600,4.6876483,1.6775784,,,,,,,,,,,,,, -81700,4.164914,1.6408668,,,,,,,,,,,,,, -81800,4.9102182,1.7971786,,,,,,,,,,,,,, -81900,4.354825,1.6967775,,,,,,,,,,,,,, -81958,,,0.6863241195678711,1.2330068349838257,0.6323399543762207,1.5172523260116575,50000.0,0.5067999958992004,2.243538379669189,10000.0,27574.732538461685,28565.79950976372,27574.732538461685,986.3525202274324,2.056394100189209,0.0 -82000,4.3012857,1.7189814,,,,,,,,,,,,,, -82100,5.464732,1.6856706,,,,,,,,,,,,,, -82200,4.257582,1.6965847,,,,,,,,,,,,,, -82300,4.074801,1.7583514,,,,,,,,,,,,,, -82400,4.2271237,1.87606,,,,,,,,,,,,,, -82500,4.0036254,1.6953769,,,,,,,,,,,,,, -82600,3.792963,1.8981501,,,,,,,,,,,,,, -82700,4.2855134,1.7687898,,,,,,,,,,,,,, -82800,4.401528,1.7189139,,,,,,,,,,,,,, -82900,4.2960305,1.6716983,,,,,,,,,,,,,, -83000,4.6166472,1.6783983,,,,,,,,,,,,,, -83100,3.7237172,1.6230905,,,,,,,,,,,,,, -83200,3.7966027,1.7600757,,,,,,,,,,,,,, -83300,4.182965,1.6812879,,,,,,,,,,,,,, -83400,4.863439,1.70887,,,,,,,,,,,,,, -83478,,,0.6814213991165161,1.2530291080474854,0.6282199621200562,1.5206059217453003,50000.0,0.5056000351905823,2.229276418685913,10000.0,28084.9185359478,29094.119473934174,28084.9185359478,1004.3920419216156,2.1011264324188232,0.0 -83500,4.4018784,1.7701983,,,,,,,,,,,,,, -83600,4.3083353,1.743619,,,,,,,,,,,,,, -83700,3.9713428,1.6750101,,,,,,,,,,,,,, -83800,4.175895,1.7894536,,,,,,,,,,,,,, -83900,4.675874,1.7952816,,,,,,,,,,,,,, -84000,3.9957564,1.6772853,,,,,,,,,,,,,, -84100,4.2889504,1.653345,,,,,,,,,,,,,, -84200,3.9991367,1.736485,,,,,,,,,,,,,, -84300,4.2340636,1.7011127,,,,,,,,,,,,,, -84400,5.039123,1.7269963,,,,,,,,,,,,,, -84500,3.942379,1.708672,,,,,,,,,,,,,, -84600,4.414907,1.7891287,,,,,,,,,,,,,, -84700,4.050233,1.7193319,,,,,,,,,,,,,, -84800,4.124746,1.8346077,,,,,,,,,,,,,, -84900,4.9267144,1.7433939,,,,,,,,,,,,,, -84998,,,0.7023875713348389,1.1722116470336914,0.6421599984169006,1.463862419128418,50000.0,0.5214000344276428,2.1729736328125,10000.0,28594.97287845612,29622.433605909348,28594.97287845612,1022.558931350708,2.1451663970947266,0.0 -85000,4.619613,1.7440883,,,,,,,,,,,,,, -85100,4.554713,1.6924957,,,,,,,,,,,,,, -85200,4.3858056,1.6861517,,,,,,,,,,,,,, -85300,4.373217,1.6299798,,,,,,,,,,,,,, -85400,4.35831,1.6828996,,,,,,,,,,,,,, -85500,3.862309,1.7679565,,,,,,,,,,,,,, -85600,4.310859,1.631858,,,,,,,,,,,,,, -85700,4.159386,1.7340773,,,,,,,,,,,,,, -85800,4.42987,1.6744602,,,,,,,,,,,,,, -85900,5.253299,1.6681769,,,,,,,,,,,,,, -86000,4.6091995,1.680499,,,,,,,,,,,,,, -86100,4.726867,1.6440004,,,,,,,,,,,,,, -86200,4.9727573,1.6521175,,,,,,,,,,,,,, -86300,4.8141537,1.7247663,,,,,,,,,,,,,, -86400,4.4048867,1.729829,,,,,,,,,,,,,, -86500,4.1401415,1.812526,,,,,,,,,,,,,, -86518,,,0.7252471446990967,1.0614542961120603,0.6484799981117249,1.429309368133545,50000.0,0.5195000171661377,2.1262402534484863,10000.0,29105.079597473145,30150.35160303116,29105.079597473145,1040.2748339176178,2.191434383392334,0.0 -86600,4.2612963,1.70118,,,,,,,,,,,,,, -86700,4.4140453,1.6698592,,,,,,,,,,,,,, -86800,4.5776525,1.8071042,,,,,,,,,,,,,, -86900,3.9882882,1.6696374,,,,,,,,,,,,,, -87000,5.1383896,1.750253,,,,,,,,,,,,,, -87100,4.1826215,1.641399,,,,,,,,,,,,,, -87200,3.9661279,1.6505065,,,,,,,,,,,,,, -87300,3.9009035,1.7686743,,,,,,,,,,,,,, -87400,4.4296517,1.7024589,,,,,,,,,,,,,, -87500,4.2796426,1.6973861,,,,,,,,,,,,,, -87600,4.185544,1.7434547,,,,,,,,,,,,,, -87700,4.518136,1.6334945,,,,,,,,,,,,,, -87800,3.8372169,1.6492996,,,,,,,,,,,,,, -87900,4.309948,1.7163306,,,,,,,,,,,,,, -88000,4.024069,1.6744658,,,,,,,,,,,,,, -88037,,,0.7028061151504517,1.1444787979125977,0.6430799961090088,1.4514235258102417,50000.0,0.518500030040741,2.1785125732421875,10000.0,29615.25696110725,30678.455330610275,29615.25696110725,1058.1022355556488,2.241151094436645,0.0 -88100,3.8756912,1.5914148,,,,,,,,,,,,,, -88200,4.094457,1.702629,,,,,,,,,,,,,, -88300,3.7400923,1.6784804,,,,,,,,,,,,,, -88400,4.169649,1.6226726,,,,,,,,,,,,,, -88500,4.676279,1.655206,,,,,,,,,,,,,, -88600,3.8232315,1.625191,,,,,,,,,,,,,, -88700,4.8670855,1.7119433,,,,,,,,,,,,,, -88800,4.0826197,1.5527443,,,,,,,,,,,,,, -88900,3.9968603,1.6213113,,,,,,,,,,,,,, -89000,4.652014,1.6820297,,,,,,,,,,,,,, -89100,4.1416225,1.6595039,,,,,,,,,,,,,, -89200,3.916604,1.6871147,,,,,,,,,,,,,, -89300,3.9937847,1.621095,,,,,,,,,,,,,, -89400,3.9992614,1.6045965,,,,,,,,,,,,,, -89500,3.8210986,1.487849,,,,,,,,,,,,,, -89557,,,0.7121930718421936,1.1185708045959473,0.6528599858283997,1.4161499738693235,50000.0,0.5265000462532043,2.140740156173706,10000.0,30125.36184811592,31206.683420419693,30125.36184811592,1076.1300678253174,2.2871439456939697,0.0 -89600,3.986921,1.7052116,,,,,,,,,,,,,, -89700,3.9540682,1.8162062,,,,,,,,,,,,,, -89800,4.2370553,1.6958563,,,,,,,,,,,,,, -89900,4.3509326,1.676114,,,,,,,,,,,,,, -90000,4.2271876,1.6208112,,,,,,,,,,,,,, -90100,4.795469,1.7144481,,,,,,,,,,,,,, -90200,3.8722086,1.6597524,,,,,,,,,,,,,, -90300,4.288927,1.7086148,,,,,,,,,,,,,, -90400,4.2714314,1.7408352,,,,,,,,,,,,,, -90500,4.0947785,1.7157874,,,,,,,,,,,,,, -90600,5.0681634,1.6256697,,,,,,,,,,,,,, -90700,4.3540993,1.551333,,,,,,,,,,,,,, -90800,4.650166,1.6730525,,,,,,,,,,,,,, -90900,4.1109056,1.5573018,,,,,,,,,,,,,, -91000,4.0464487,1.6791954,,,,,,,,,,,,,, -91078,,,0.7091438174247742,1.1337120532989502,0.6495800018310547,1.4375072717666626,50000.0,0.5210000276565552,2.142979145050049,10000.0,30635.47507429123,31734.72960019112,30635.47507429123,1093.9589052200315,2.341646194458008,0.0 -91100,4.3103423,1.7234144,,,,,,,,,,,,,, -91200,5.099743,1.7452154,,,,,,,,,,,,,, -91300,4.5503826,1.7055984,,,,,,,,,,,,,, -91400,3.9219139,1.567265,,,,,,,,,,,,,, -91500,4.107526,1.7364756,,,,,,,,,,,,,, -91600,3.8692572,1.7455597,,,,,,,,,,,,,, -91700,4.5964365,1.6549335,,,,,,,,,,,,,, -91800,5.08137,1.679714,,,,,,,,,,,,,, -91900,4.766801,1.6644992,,,,,,,,,,,,,, -92000,4.280865,1.6953685,,,,,,,,,,,,,, -92100,4.391095,1.7276015,,,,,,,,,,,,,, -92200,4.287134,1.7154548,,,,,,,,,,,,,, -92300,4.8255305,1.6205411,,,,,,,,,,,,,, -92400,4.2650986,1.6668599,,,,,,,,,,,,,, -92500,4.3506017,1.624679,,,,,,,,,,,,,, -92598,,,0.7068518400192261,1.1458536386489868,0.6547399759292603,1.4125068187713623,50000.0,0.5238000154495239,2.1197195053100586,10000.0,31145.39885497093,32262.498690605164,31145.39885497093,1111.7064609527588,2.390122890472412,0.0 -92600,4.1112404,1.6001555,,,,,,,,,,,,,, -92700,4.078218,1.6840488,,,,,,,,,,,,,, -92800,4.250196,1.6256126,,,,,,,,,,,,,, -92900,4.567244,1.5834496,,,,,,,,,,,,,, -93000,4.344425,1.6560369,,,,,,,,,,,,,, -93100,3.9001112,1.6701859,,,,,,,,,,,,,, -93200,4.2551303,1.5640296,,,,,,,,,,,,,, -93300,4.3229637,1.7025183,,,,,,,,,,,,,, -93400,3.860143,1.5392216,,,,,,,,,,,,,, -93500,4.421257,1.6437876,,,,,,,,,,,,,, -93600,4.903606,1.5962759,,,,,,,,,,,,,, -93700,3.9515774,1.7101842,,,,,,,,,,,,,, -93800,4.614161,1.6166309,,,,,,,,,,,,,, -93900,4.5797544,1.5542637,,,,,,,,,,,,,, -94000,4.3634562,1.644058,,,,,,,,,,,,,, -94100,4.646974,1.6753193,,,,,,,,,,,,,, -94119,,,0.7403738498687744,1.0028444528579712,0.6489399671554565,1.430841326713562,50000.0,0.5188000202178955,2.1304824352264404,10000.0,31655.37549734116,32790.37596988678,31655.37549734116,1129.5105464458466,2.4372799396514893,0.0 -94200,4.0050898,1.6263965,,,,,,,,,,,,,, -94300,4.3751693,1.6847092,,,,,,,,,,,,,, -94400,4.2876554,1.6371726,,,,,,,,,,,,,, -94500,4.3145375,1.6312261,,,,,,,,,,,,,, -94600,4.9973164,1.6548759,,,,,,,,,,,,,, -94700,4.192212,1.5514731,,,,,,,,,,,,,, -94800,4.3263054,1.5634222,,,,,,,,,,,,,, -94900,4.292208,1.7149935,,,,,,,,,,,,,, -95000,5.6426764,1.6948949,,,,,,,,,,,,,, -95100,4.152498,1.65763,,,,,,,,,,,,,, -95200,4.5127206,1.7351727,,,,,,,,,,,,,, -95300,4.003155,1.5667179,,,,,,,,,,,,,, -95400,4.5268307,1.7104656,,,,,,,,,,,,,, -95500,4.34657,1.6920639,,,,,,,,,,,,,, -95600,4.756909,1.6774249,,,,,,,,,,,,,, -95639,,,0.7263033986091614,1.055260419845581,0.6544399857521057,1.4071980714797974,50000.0,0.5311000347137451,2.102889060974121,10000.0,32165.38191127777,33318.218418598175,32165.38191127777,1147.253174781799,2.4813926219940186,0.0 -95700,4.4397335,1.6461805,,,,,,,,,,,,,, -95800,3.9768708,1.5537739,,,,,,,,,,,,,, -95900,3.9903846,1.6712832,,,,,,,,,,,,,, -96000,4.340566,1.6010618,,,,,,,,,,,,,, -96100,3.9161098,1.6927772,,,,,,,,,,,,,, -96200,4.444535,1.69349,,,,,,,,,,,,,, -96300,4.4461174,1.7008214,,,,,,,,,,,,,, -96400,4.1334395,1.4957597,,,,,,,,,,,,,, -96500,5.0587354,1.7000746,,,,,,,,,,,,,, -96600,3.9840639,1.672687,,,,,,,,,,,,,, -96700,4.4908466,1.5846673,,,,,,,,,,,,,, -96800,4.4325304,1.5864946,,,,,,,,,,,,,, -96900,4.50927,1.6537591,,,,,,,,,,,,,, -97000,4.4010572,1.6987814,,,,,,,,,,,,,, -97100,4.352824,1.65182,,,,,,,,,,,,,, -97159,,,0.7153021097183228,1.094274640083313,0.647059977054596,1.4412496089935305,50000.0,0.5247000455856323,2.1374266147613525,10000.0,32675.390652418137,33846.29317569733,32675.390652418137,1165.2235417366028,2.5274858474731445,0.0 -97200,4.5234194,1.6821133,,,,,,,,,,,,,, -97300,4.3522105,1.6649712,,,,,,,,,,,,,, -97400,4.1758237,1.5503311,,,,,,,,,,,,,, -97500,4.2926736,1.5154063,,,,,,,,,,,,,, -97600,5.3499427,1.6780428,,,,,,,,,,,,,, -97700,4.9570246,1.6500547,,,,,,,,,,,,,, -97800,4.03756,1.5898252,,,,,,,,,,,,,, -97900,5.0926857,1.5871272,,,,,,,,,,,,,, -98000,4.1533766,1.5547712,,,,,,,,,,,,,, -98100,3.7018638,1.5554888,,,,,,,,,,,,,, -98200,5.012796,1.6305113,,,,,,,,,,,,,, -98300,4.990362,1.697422,,,,,,,,,,,,,, -98400,4.2250195,1.549315,,,,,,,,,,,,,, -98500,4.582785,1.6062707,,,,,,,,,,,,,, -98600,4.718408,1.5475365,,,,,,,,,,,,,, -98679,,,0.7119140625,1.1197112798690796,0.6540799736976624,1.4109952449798584,50000.0,0.5229000449180603,2.1488847732543945,10000.0,33185.30989527702,34374.24014735222,33185.30989527702,1183.1489372253418,2.5802924633026123,0.0 -98700,4.6614976,1.5011926,,,,,,,,,,,,,, -98800,4.3532,1.5781641,,,,,,,,,,,,,, -98900,4.965194,1.6574879,,,,,,,,,,,,,, -99000,5.85382,1.5906787,,,,,,,,,,,,,, -99100,5.0477924,1.6550044,,,,,,,,,,,,,, -99200,4.1845317,1.6370611,,,,,,,,,,,,,, -99300,4.3790298,1.5310833,,,,,,,,,,,,,, -99400,4.820716,1.5996267,,,,,,,,,,,,,, -99500,4.5160007,1.7668226,,,,,,,,,,,,,, -99600,4.644123,1.6810635,,,,,,,,,,,,,, -99700,4.59786,1.7461643,,,,,,,,,,,,,, -99800,4.3165545,1.4582345,,,,,,,,,,,,,, -99900,3.7939029,1.57432,,,,,,,,,,,,,, -100000,4.593425,1.5505121,,,,,,,,,,,,,, -100100,4.400203,1.612308,,,,,,,,,,,,,, -100199,,,0.7194873690605164,1.0825355052947998,0.6628199815750122,1.372917652130127,50000.0,0.5344000458717346,2.116262435913086,10000.0,33695.51848888397,34902.30881524086,33695.51848888397,1200.9140849113464,2.625534772872925,0.0 -100200,4.8619328,1.594763,,,,,,,,,,,,,, -100300,4.268246,1.6292098,,,,,,,,,,,,,, -100400,4.1847644,1.6814624,,,,,,,,,,,,,, -100500,4.283506,1.6455793,,,,,,,,,,,,,, -100600,5.328192,1.523672,,,,,,,,,,,,,, -100700,4.10611,1.5600579,,,,,,,,,,,,,, -100800,4.2508826,1.4830239,,,,,,,,,,,,,, -100900,4.439103,1.6708782,,,,,,,,,,,,,, -101000,4.6319094,1.5851445,,,,,,,,,,,,,, -101100,4.2335854,1.5158422,,,,,,,,,,,,,, -101200,4.159153,1.443467,,,,,,,,,,,,,, -101300,4.2828894,1.5458417,,,,,,,,,,,,,, -101400,4.615174,1.5222982,,,,,,,,,,,,,, -101500,4.7075877,1.6048232,,,,,,,,,,,,,, -101600,4.3509974,1.5698341,,,,,,,,,,,,,, -101700,4.6765027,1.6174467,,,,,,,,,,,,,, -101718,,,0.7230349183082581,1.0770888328552246,0.6646599769592285,1.3627052307128906,50000.0,0.5348000526428223,2.0898184776306152,10000.0,34205.558450460434,35430.43071985245,34205.558450460434,1218.8947851657867,2.6776158809661865,0.0 -101800,5.364622,1.6080422,,,,,,,,,,,,,, -101900,4.8778048,1.6911877,,,,,,,,,,,,,, -102000,4.9950194,1.5621102,,,,,,,,,,,,,, -102100,4.369626,1.6141295,,,,,,,,,,,,,, -102200,5.067764,1.577311,,,,,,,,,,,,,, -102300,4.9805026,1.6753386,,,,,,,,,,,,,, -102400,5.055211,1.5007982,,,,,,,,,,,,,, -102500,4.6396537,1.6610391,,,,,,,,,,,,,, -102600,4.524176,1.5463678,,,,,,,,,,,,,, -102700,4.4169116,1.5247215,,,,,,,,,,,,,, -102800,4.65861,1.6614683,,,,,,,,,,,,,, -102900,4.7059293,1.5527377,,,,,,,,,,,,,, -103000,4.5017614,1.7214454,,,,,,,,,,,,,, -103100,4.3249283,1.6071616,,,,,,,,,,,,,, -103200,5.279138,1.6091573,,,,,,,,,,,,,, -103239,,,0.7520726919174194,0.94245183467865,0.662559986114502,1.3745607137680054,50000.0,0.537600040435791,2.0718114376068115,10000.0,34715.77932167053,35958.745235681534,34715.77932167053,1236.8908894062042,2.725346565246582,0.0 -103300,5.5983725,1.534237,,,,,,,,,,,,,, -103400,4.5236297,1.5762851,,,,,,,,,,,,,, -103500,4.550759,1.4879237,,,,,,,,,,,,,, -103600,4.724788,1.5104542,,,,,,,,,,,,,, -103700,4.3950763,1.492906,,,,,,,,,,,,,, -103800,5.038676,1.5593889,,,,,,,,,,,,,, -103900,4.6726165,1.5537052,,,,,,,,,,,,,, -104000,5.0450983,1.6289873,,,,,,,,,,,,,, -104100,5.4546075,1.5874963,,,,,,,,,,,,,, -104200,4.6682954,1.4743528,,,,,,,,,,,,,, -104300,5.097286,1.668077,,,,,,,,,,,,,, -104400,4.6406817,1.5852221,,,,,,,,,,,,,, -104500,5.1651173,1.5124661,,,,,,,,,,,,,, -104600,4.630506,1.493574,,,,,,,,,,,,,, -104700,4.254242,1.5307791,,,,,,,,,,,,,, -104760,,,0.7437419891357422,0.9715170860290528,0.6689199805259705,1.338512659072876,50000.0,0.541700005531311,2.064157247543335,10000.0,35225.935331106186,36486.63972115517,35225.935331106186,1254.5282595157623,2.777066469192505,0.0 -104800,5.1582327,1.5269113,,,,,,,,,,,,,, -104900,4.5168233,1.5514892,,,,,,,,,,,,,, -105000,5.0455427,1.535837,,,,,,,,,,,,,, -105100,4.537739,1.6423761,,,,,,,,,,,,,, -105200,4.5399756,1.5014298,,,,,,,,,,,,,, -105300,4.6168213,1.5729733,,,,,,,,,,,,,, -105400,4.7884045,1.6447542,,,,,,,,,,,,,, -105500,4.8817234,1.5285147,,,,,,,,,,,,,, -105600,5.4466276,1.5158031,,,,,,,,,,,,,, -105700,5.72843,1.6332119,,,,,,,,,,,,,, -105800,5.1140747,1.5188724,,,,,,,,,,,,,, -105900,4.709098,1.5471404,,,,,,,,,,,,,, -106000,4.7322593,1.4820802,,,,,,,,,,,,,, -106100,4.682618,1.616651,,,,,,,,,,,,,, -106200,4.2925863,1.5083101,,,,,,,,,,,,,, -106281,,,0.7352120280265808,1.0203943252563477,0.6659799814224243,1.3560991287231443,50000.0,0.5430000424385071,2.052748441696167,10000.0,35736.099576711655,37014.51857161522,35736.099576711655,1272.1391229629517,2.83117151260376,0.0 -106300,4.932644,1.5886112,,,,,,,,,,,,,, -106400,4.525142,1.5812542,,,,,,,,,,,,,, -106500,4.8971295,1.634745,,,,,,,,,,,,,, -106600,5.116842,1.6257768,,,,,,,,,,,,,, -106700,4.970265,1.4443967,,,,,,,,,,,,,, -106800,5.2575355,1.5131675,,,,,,,,,,,,,, -106900,5.1859827,1.6068678,,,,,,,,,,,,,, -107000,4.555709,1.4455158,,,,,,,,,,,,,, -107100,4.5639853,1.4894397,,,,,,,,,,,,,, -107200,4.406751,1.5087469,,,,,,,,,,,,,, -107300,4.331492,1.5301387,,,,,,,,,,,,,, -107400,4.6226153,1.5927359,,,,,,,,,,,,,, -107500,4.684986,1.5448085,,,,,,,,,,,,,, -107600,5.795269,1.696339,,,,,,,,,,,,,, -107700,4.639847,1.5695469,,,,,,,,,,,,,, -107800,4.682379,1.6261051,,,,,,,,,,,,,, -107801,,,0.7411311864852905,0.9817953109741212,0.675279974937439,1.3169835805892944,50000.0,0.5539000034332275,1.9979313611984253,10000.0,36246.10291814804,37542.64195537567,36246.10291814804,1290.1619033813477,2.8791487216949463,0.0 -107900,5.131617,1.5412441,,,,,,,,,,,,,, -108000,5.1124344,1.614605,,,,,,,,,,,,,, -108100,4.6048,1.5804343,,,,,,,,,,,,,, -108200,4.4955072,1.5222604,,,,,,,,,,,,,, -108300,4.212692,1.4362758,,,,,,,,,,,,,, -108400,4.5743976,1.583664,,,,,,,,,,,,,, -108500,4.9483953,1.5315063,,,,,,,,,,,,,, -108600,4.6088657,1.491706,,,,,,,,,,,,,, -108700,5.482743,1.5625659,,,,,,,,,,,,,, -108800,4.5685635,1.5450667,,,,,,,,,,,,,, -108900,4.988223,1.6791133,,,,,,,,,,,,,, -109000,5.129179,1.4958739,,,,,,,,,,,,,, -109100,4.766852,1.4965634,,,,,,,,,,,,,, -109200,4.1507363,1.4501162,,,,,,,,,,,,,, -109300,4.7236037,1.5314974,,,,,,,,,,,,,, -109321,,,0.7330197691917419,1.0413142442703247,0.670960009098053,1.3357013463974,50000.0,0.5401000380516052,2.053678274154663,10000.0,36756.17011857033,38070.74506354332,36756.17011857033,1308.0986626148224,2.928966760635376,0.0 -109400,5.576943,1.4594996,,,,,,,,,,,,,, -109500,4.983717,1.5119786,,,,,,,,,,,,,, -109600,5.6158204,1.531674,,,,,,,,,,,,,, -109700,4.84938,1.5960783,,,,,,,,,,,,,, -109800,5.409438,1.5672112,,,,,,,,,,,,,, -109900,5.4507895,1.6888096,,,,,,,,,,,,,, -110000,4.650571,1.6368809,,,,,,,,,,,,,, -110100,4.689189,1.4588531,,,,,,,,,,,,,, -110200,5.588014,1.435004,,,,,,,,,,,,,, -110300,5.1000023,1.6087188,,,,,,,,,,,,,, -110400,5.029582,1.4626642,,,,,,,,,,,,,, -110500,4.448529,1.4202403,,,,,,,,,,,,,, -110600,7.290776,1.6398299,,,,,,,,,,,,,, -110700,5.0742707,1.442868,,,,,,,,,,,,,, -110800,5.3138595,1.4072669,,,,,,,,,,,,,, -110842,,,0.7404536008834839,0.9842280745506288,0.675819993019104,1.313747763633728,50000.0,0.5503000020980835,2.027350187301636,10000.0,37266.25584864616,38598.62584900856,37266.25584864616,1325.7961132526398,2.9773285388946533,0.0 -110900,4.8355923,1.5887568,,,,,,,,,,,,,, -111000,4.9191527,1.4765978,,,,,,,,,,,,,, -111100,4.5880976,1.5042405,,,,,,,,,,,,,, -111200,5.2612205,1.4339625,,,,,,,,,,,,,, -111300,4.850444,1.480263,,,,,,,,,,,,,, -111400,4.9792786,1.5688162,,,,,,,,,,,,,, -111500,4.8223825,1.5065119,,,,,,,,,,,,,, -111600,4.987326,1.4760206,,,,,,,,,,,,,, -111700,5.00122,1.5674167,,,,,,,,,,,,,, -111800,4.73873,1.5558825,,,,,,,,,,,,,, -111900,4.7289877,1.4271941,,,,,,,,,,,,,, -112000,5.330058,1.508058,,,,,,,,,,,,,, -112100,5.2098417,1.6053424,,,,,,,,,,,,,, -112200,4.9426775,1.5937884,,,,,,,,,,,,,, -112300,4.4185686,1.4550683,,,,,,,,,,,,,, -112363,,,0.7700294852256775,0.8577762246131897,0.6807999610900879,1.2912861108779907,50000.0,0.5544000267982483,1.9689432382583616,10000.0,37776.244634628296,39126.41235136986,37776.244634628296,1343.4937970638275,3.02756404876709,0.0 -112400,4.919943,1.5388806,,,,,,,,,,,,,, -112500,5.704542,1.4378859,,,,,,,,,,,,,, -112600,4.8532224,1.5555425,,,,,,,,,,,,,, -112700,4.9631987,1.5831182,,,,,,,,,,,,,, -112800,5.6777673,1.5039309,,,,,,,,,,,,,, -112900,5.167291,1.5868173,,,,,,,,,,,,,, -113000,4.9226255,1.5656102,,,,,,,,,,,,,, -113100,5.7606926,1.517543,,,,,,,,,,,,,, -113200,5.08952,1.4303929,,,,,,,,,,,,,, -113300,4.8139367,1.6374167,,,,,,,,,,,,,, -113400,4.928392,1.5457997,,,,,,,,,,,,,, -113500,4.992895,1.51208,,,,,,,,,,,,,, -113600,5.020014,1.5136871,,,,,,,,,,,,,, -113700,4.9759307,1.5328727,,,,,,,,,,,,,, -113800,4.7712703,1.4654908,,,,,,,,,,,,,, -113883,,,0.7518534660339355,0.9311055541038512,0.6769199967384338,1.313949704170227,50000.0,0.542900025844574,2.041203260421753,10000.0,38286.34936618805,39654.57463693619,38286.34936618805,1361.454597234726,3.0750892162323,0.0 -113900,5.496258,1.5492085,,,,,,,,,,,,,, -114000,5.0423293,1.5076436,,,,,,,,,,,,,, -114100,4.9186454,1.5017254,,,,,,,,,,,,,, -114200,4.743636,1.4218194,,,,,,,,,,,,,, -114300,5.618883,1.457062,,,,,,,,,,,,,, -114400,6.1070876,1.5126348,,,,,,,,,,,,,, -114500,5.362962,1.5101118,,,,,,,,,,,,,, -114600,4.7276645,1.5050043,,,,,,,,,,,,,, -114700,5.188598,1.4497771,,,,,,,,,,,,,, -114800,5.0266213,1.4767749,,,,,,,,,,,,,, -114900,4.945123,1.4761266,,,,,,,,,,,,,, -115000,4.803269,1.3553045,,,,,,,,,,,,,, -115100,4.800389,1.5965991,,,,,,,,,,,,,, -115200,5.1058598,1.4709588,,,,,,,,,,,,,, -115300,4.7262783,1.4212788,,,,,,,,,,,,,, -115400,5.227575,1.4571882,,,,,,,,,,,,,, -115404,,,0.7578722834587097,0.92294579744339,0.6810199618339539,1.289506435394287,50000.0,0.5494000315666199,1.9755940437316888,10000.0,38796.49595832825,40182.67863607407,38796.49595832825,1379.3148682117462,3.123174905776977,0.0 -115500,5.1283865,1.4527837,,,,,,,,,,,,,, -115600,4.7546,1.4817514,,,,,,,,,,,,,, -115700,5.826596,1.4729806,,,,,,,,,,,,,, -115800,5.5489345,1.4320964,,,,,,,,,,,,,, -115900,4.995012,1.5084628,,,,,,,,,,,,,, -116000,4.8333535,1.390986,,,,,,,,,,,,,, -116100,5.1566772,1.3895481,,,,,,,,,,,,,, -116200,5.2969923,1.4082024,,,,,,,,,,,,,, -116300,5.2031355,1.3962243,,,,,,,,,,,,,, -116400,5.1754293,1.4974037,,,,,,,,,,,,,, -116500,5.5170064,1.4684346,,,,,,,,,,,,,, -116600,4.9315424,1.4871113,,,,,,,,,,,,,, -116700,4.9917364,1.4374301,,,,,,,,,,,,,, -116800,5.8160815,1.4072369,,,,,,,,,,,,,, -116900,4.831204,1.4797679,,,,,,,,,,,,,, -116925,,,0.7563177347183228,0.92183256149292,0.6868000030517578,1.2699737548828125,50000.0,0.563800036907196,1.953534245491028,10000.0,39306.57546567917,40710.75027012825,39306.57546567917,1397.2055568695068,3.1755475997924805,0.0 -117000,5.427405,1.41916,,,,,,,,,,,,,, -117100,4.738563,1.4423027,,,,,,,,,,,,,, -117200,5.204974,1.4910337,,,,,,,,,,,,,, -117300,5.2146077,1.4967058,,,,,,,,,,,,,, -117400,5.2885785,1.4650085,,,,,,,,,,,,,, -117500,5.044953,1.40198,,,,,,,,,,,,,, -117600,5.174506,1.5355127,,,,,,,,,,,,,, -117700,5.0866303,1.4539026,,,,,,,,,,,,,, -117800,4.9734807,1.434403,,,,,,,,,,,,,, -117900,5.8921075,1.4507838,,,,,,,,,,,,,, -118000,6.414428,1.4673665,,,,,,,,,,,,,, -118100,5.536702,1.4523964,,,,,,,,,,,,,, -118200,6.004834,1.5125263,,,,,,,,,,,,,, -118300,5.1553965,1.4149985,,,,,,,,,,,,,, -118400,5.275629,1.3130199,,,,,,,,,,,,,, -118447,,,0.7603236436843872,0.9137925505638124,0.6899399757385254,1.2629296779632568,50000.0,0.5552999973297119,1.979101061820984,10000.0,39816.80427622795,41238.8702480793,39816.80427622795,1414.9953515529633,3.2276132106781006,0.0 -118500,5.437137,1.4569614,,,,,,,,,,,,,, -118600,5.0126595,1.399343,,,,,,,,,,,,,, -118700,4.5069685,1.233762,,,,,,,,,,,,,, -118800,4.803383,1.3880421,,,,,,,,,,,,,, -118900,4.841435,1.4287317,,,,,,,,,,,,,, -119000,5.161687,1.3666664,,,,,,,,,,,,,, -119100,4.8919997,1.4674363,,,,,,,,,,,,,, -119200,5.291096,1.4958907,,,,,,,,,,,,,, -119300,5.1459455,1.441596,,,,,,,,,,,,,, -119400,4.7495947,1.4233279,,,,,,,,,,,,,, -119500,4.772388,1.4040644,,,,,,,,,,,,,, -119600,5.4993467,1.5670533,,,,,,,,,,,,,, -119700,5.0861516,1.4358696,,,,,,,,,,,,,, -119800,5.5601087,1.4739494,,,,,,,,,,,,,, -119900,5.3020625,1.366816,,,,,,,,,,,,,, -119967,,,0.7940250039100647,0.7722108364105225,0.6904199719429016,1.25036358833313,50000.0,0.5565000176429749,1.9592362642288208,10000.0,40326.80095338821,41766.92306900024,40326.80095338821,1432.9292194843292,3.300215721130371,0.0 -120000,5.2239304,1.4835173,,,,,,,,,,,,,, -120100,5.069397,1.4047596,,,,,,,,,,,,,, -120200,5.3621116,1.4549277,,,,,,,,,,,,,, -120300,5.2105165,1.4191307,,,,,,,,,,,,,, -120400,5.2624407,1.448667,,,,,,,,,,,,,, -120500,5.4684815,1.3884281,,,,,,,,,,,,,, -120600,6.505848,1.4649751,,,,,,,,,,,,,, -120700,6.5763936,1.4925327,,,,,,,,,,,,,, -120800,5.979877,1.4520446,,,,,,,,,,,,,, -120900,6.144285,1.5848296,,,,,,,,,,,,,, -121000,5.304216,1.4076117,,,,,,,,,,,,,, -121100,5.2634516,1.3641057,,,,,,,,,,,,,, -121200,5.166865,1.435003,,,,,,,,,,,,,, -121300,5.5215564,1.399999,,,,,,,,,,,,,, -121400,5.734982,1.5027719,,,,,,,,,,,,,, -121487,,,0.7747727632522583,0.8482697010040283,0.6867799758911133,1.2761807441711426,50000.0,0.5522000193595886,1.993802189826965,10000.0,40836.91822504997,42294.90674686432,40836.91822504997,1450.694310426712,3.352198362350464,0.0 -121500,5.2730284,1.4893909,,,,,,,,,,,,,, -121600,5.8511157,1.4314913,,,,,,,,,,,,,, -121700,4.915263,1.3516709,,,,,,,,,,,,,, -121800,5.298792,1.5243773,,,,,,,,,,,,,, -121900,5.216375,1.4205643,,,,,,,,,,,,,, -122000,5.082584,1.3682047,,,,,,,,,,,,,, -122100,5.5550804,1.3881791,,,,,,,,,,,,,, -122200,5.3703284,1.4061276,,,,,,,,,,,,,, -122300,5.656774,1.5108651,,,,,,,,,,,,,, -122400,5.5944133,1.3215889,,,,,,,,,,,,,, -122500,5.2301483,1.4333256,,,,,,,,,,,,,, -122600,5.755854,1.4520448,,,,,,,,,,,,,, -122700,5.361142,1.5078369,,,,,,,,,,,,,, -122800,4.796778,1.3191035,,,,,,,,,,,,,, -122900,5.316392,1.429079,,,,,,,,,,,,,, -123000,6.501104,1.5135494,,,,,,,,,,,,,, -123007,,,0.762137234210968,0.8999484181404114,0.6813399791717529,1.2865992784500122,50000.0,0.5526000261306763,2.0052297115325928,10000.0,41346.84619688988,42822.92417263985,41346.84619688988,1468.680810213089,3.4059770107269287,0.0 -123100,5.2791004,1.2743037,,,,,,,,,,,,,, -123200,5.813441,1.3292534,,,,,,,,,,,,,, -123300,5.7588925,1.3451331,,,,,,,,,,,,,, -123400,5.3892546,1.3358943,,,,,,,,,,,,,, -123500,5.588474,1.4320064,,,,,,,,,,,,,, -123600,5.1076927,1.3638605,,,,,,,,,,,,,, -123700,5.1904063,1.4268233,,,,,,,,,,,,,, -123800,5.765261,1.4541911,,,,,,,,,,,,,, -123900,6.0743384,1.4001713,,,,,,,,,,,,,, -124000,5.719925,1.4272124,,,,,,,,,,,,,, -124100,5.497149,1.406096,,,,,,,,,,,,,, -124200,5.8063707,1.4319146,,,,,,,,,,,,,, -124300,5.272955,1.4077148,,,,,,,,,,,,,, -124400,5.8014073,1.4510362,,,,,,,,,,,,,, -124500,5.17594,1.4191748,,,,,,,,,,,,,, -124527,,,0.7736367583274841,0.8483324646949768,0.692579984664917,1.2330009937286377,50000.0,0.5652000308036804,1.933935284614563,10000.0,41856.85910201073,43351.22230386734,41856.85910201073,1486.8664498329165,3.4564247131347656,0.0 -124600,5.2996154,1.3401108,,,,,,,,,,,,,, -124700,5.4098096,1.4612464,,,,,,,,,,,,,, -124800,6.1936255,1.4541858,,,,,,,,,,,,,, -124900,5.327618,1.3586279,,,,,,,,,,,,,, -125000,5.8875985,1.3344718,,,,,,,,,,,,,, -125100,5.8935323,1.4504502,,,,,,,,,,,,,, -125200,6.250264,1.4433608,,,,,,,,,,,,,, -125300,5.8268795,1.4419534,,,,,,,,,,,,,, -125400,4.8754506,1.3669376,,,,,,,,,,,,,, -125500,5.604123,1.4983808,,,,,,,,,,,,,, -125600,5.314623,1.3682564,,,,,,,,,,,,,, -125700,5.1082783,1.4054826,,,,,,,,,,,,,, -125800,5.530798,1.3820741,,,,,,,,,,,,,, -125900,5.454633,1.3845347,,,,,,,,,,,,,, -126000,5.4342895,1.360733,,,,,,,,,,,,,, -126047,,,0.7742147445678711,0.844782292842865,0.6972999572753906,1.2068699598312378,50000.0,0.570900022983551,1.8856645822525024,10000.0,42366.83264923096,43879.25383400917,42366.83264923096,1504.823757648468,3.5072178840637207,0.0 -126100,5.5389414,1.3549448,,,,,,,,,,,,,, -126200,6.0408845,1.491767,,,,,,,,,,,,,, -126300,5.5560746,1.3428652,,,,,,,,,,,,,, -126400,5.387577,1.3209627,,,,,,,,,,,,,, -126500,5.50525,1.3647655,,,,,,,,,,,,,, -126600,5.1540723,1.3370131,,,,,,,,,,,,,, -126700,5.4336886,1.2862419,,,,,,,,,,,,,, -126800,6.3409696,1.2796454,,,,,,,,,,,,,, -126900,5.709616,1.461086,,,,,,,,,,,,,, -127000,5.50761,1.4049482,,,,,,,,,,,,,, -127100,5.316426,1.2736984,,,,,,,,,,,,,, -127200,5.51425,1.418719,,,,,,,,,,,,,, -127300,5.7857018,1.3878543,,,,,,,,,,,,,, -127400,6.55133,1.3016425,,,,,,,,,,,,,, -127500,5.6081066,1.2686344,,,,,,,,,,,,,, -127567,,,0.7678172588348389,0.8690696358680725,0.6976400017738342,1.232379674911499,50000.0,0.5746000409126282,1.928336262702942,10000.0,42876.88329672813,44407.14696741104,42876.88329672813,1522.5617182254791,3.56189227104187,0.0 -127600,5.3843603,1.2715802,,,,,,,,,,,,,, -127700,6.0752325,1.3943326,,,,,,,,,,,,,, -127800,5.835962,1.4349238,,,,,,,,,,,,,, -127900,5.5136666,1.3342062,,,,,,,,,,,,,, -128000,5.533693,1.33337,,,,,,,,,,,,,, -128100,5.346227,1.3408633,,,,,,,,,,,,,, -128200,5.8047915,1.3988537,,,,,,,,,,,,,, -128300,6.1867332,1.310621,,,,,,,,,,,,,, -128400,5.881641,1.3364185,,,,,,,,,,,,,, -128500,5.4582505,1.2154548,,,,,,,,,,,,,, -128600,5.6034045,1.3589275,,,,,,,,,,,,,, -128700,6.5395417,1.3231016,,,,,,,,,,,,,, -128800,5.8283043,1.3579922,,,,,,,,,,,,,, -128900,5.7668343,1.4121919,,,,,,,,,,,,,, -129000,5.9649515,1.2834358,,,,,,,,,,,,,, -129088,,,0.8156289458274841,0.6751767992973328,0.7046799659729004,1.1848829984664917,50000.0,0.5788000226020813,1.8793182373046875,10000.0,43387.10359358788,44935.282566308975,43387.10359358788,1540.3736391067505,3.616084575653076,0.0 -129100,6.045297,1.3093238,,,,,,,,,,,,,, -129200,6.0125604,1.3434188,,,,,,,,,,,,,, -129300,6.0852165,1.4195485,,,,,,,,,,,,,, -129400,5.7218547,1.2579136,,,,,,,,,,,,,, -129500,6.2736754,1.3334062,,,,,,,,,,,,,, -129600,5.357958,1.2749567,,,,,,,,,,,,,, -129700,5.268927,1.2847314,,,,,,,,,,,,,, -129800,5.429336,1.3160613,,,,,,,,,,,,,, -129900,5.5094576,1.2819494,,,,,,,,,,,,,, -130000,6.0805793,1.2850512,,,,,,,,,,,,,, -130100,4.8149676,1.1608566,,,,,,,,,,,,,, -130200,5.5595264,1.3454332,,,,,,,,,,,,,, -130300,6.173154,1.3563789,,,,,,,,,,,,,, -130400,6.236894,1.2969155,,,,,,,,,,,,,, -130500,6.1995487,1.3455786,,,,,,,,,,,,,, -130600,5.4764423,1.2447613,,,,,,,,,,,,,, -130608,,,0.7964963316917419,0.7662532925605774,0.7048999667167664,1.1896097660064695,50000.0,0.5804000496864319,1.8776094913482664,10000.0,43897.194222450256,45463.11754608154,43897.194222450256,1558.0184445381165,3.6664631366729736,0.0 -130700,5.9166994,1.2854612,,,,,,,,,,,,,, -130800,5.8241854,1.2496505,,,,,,,,,,,,,, -130900,5.318327,1.4009426,,,,,,,,,,,,,, -131000,5.3383574,1.2861409,,,,,,,,,,,,,, -131100,6.3767133,1.3684762,,,,,,,,,,,,,, -131200,5.673643,1.2955867,,,,,,,,,,,,,, -131300,5.5549307,1.3061571,,,,,,,,,,,,,, -131400,6.1060357,1.359961,,,,,,,,,,,,,, -131500,6.0723047,1.4264959,,,,,,,,,,,,,, -131600,6.233939,1.3034995,,,,,,,,,,,,,, -131700,6.3045964,1.2708873,,,,,,,,,,,,,, -131800,5.6944017,1.2927009,,,,,,,,,,,,,, -131900,6.3299336,1.255634,,,,,,,,,,,,,, -132000,5.600589,1.3139217,,,,,,,,,,,,,, -132100,5.981488,1.1908714,,,,,,,,,,,,,, -132129,,,0.7948620915412903,0.7571321725845337,0.7081599831581116,1.174660325050354,50000.0,0.579800009727478,1.8872145414352417,10000.0,44407.26432728768,45991.0145611763,44407.26432728768,1575.7394952774048,3.723146915435791,0.0 -132200,5.9007792,1.2849545,,,,,,,,,,,,,, -132300,5.491867,1.2443047,,,,,,,,,,,,,, -132400,5.9923573,1.4172599,,,,,,,,,,,,,, -132500,5.993754,1.2575684,,,,,,,,,,,,,, -132600,5.946121,1.2893207,,,,,,,,,,,,,, -132700,6.0917463,1.227012,,,,,,,,,,,,,, -132800,5.6686053,1.2143356,,,,,,,,,,,,,, -132900,6.521313,1.4402918,,,,,,,,,,,,,, -133000,5.827082,1.2948499,,,,,,,,,,,,,, -133100,6.6113014,1.2748787,,,,,,,,,,,,,, -133200,6.133615,1.334729,,,,,,,,,,,,,, -133300,5.8718977,1.2825683,,,,,,,,,,,,,, -133400,6.0461574,1.206145,,,,,,,,,,,,,, -133500,6.370552,1.3018496,,,,,,,,,,,,,, -133600,7.0523496,1.3224328,,,,,,,,,,,,,, -133651,,,0.7874082922935486,0.7908145189285278,0.7020399570465088,1.193451166152954,50000.0,0.5773000121116638,1.888568878173828,10000.0,44917.38735723496,46519.36021995544,44917.38735723496,1593.8583455085754,3.777401447296143,0.0 -133700,5.815239,1.2613628,,,,,,,,,,,,,, -133800,6.326844,1.270997,,,,,,,,,,,,,, -133900,6.691176,1.3478818,,,,,,,,,,,,,, -134000,6.085673,1.242549,,,,,,,,,,,,,, -134100,6.1440673,1.2725242,,,,,,,,,,,,,, -134200,5.809487,1.2086151,,,,,,,,,,,,,, -134300,6.166375,1.3394816,,,,,,,,,,,,,, -134400,5.7614603,1.3040155,,,,,,,,,,,,,, -134500,5.9872537,1.2582632,,,,,,,,,,,,,, -134600,7.5088744,1.3425095,,,,,,,,,,,,,, -134700,6.0212145,1.2488457,,,,,,,,,,,,,, -134800,6.09808,1.1787299,,,,,,,,,,,,,, -134900,5.971096,1.2165678,,,,,,,,,,,,,, -135000,6.116651,1.3156986,,,,,,,,,,,,,, -135100,6.2341557,1.3114175,,,,,,,,,,,,,, -135172,,,0.8009805083274841,0.7329120635986328,0.7127000093460083,1.1449772119522097,50000.0,0.5895000100135803,1.8212326765060425,10000.0,45427.400640010834,47047.412162303925,45427.400640010834,1611.7913794517517,3.833778142929077,0.0 -135200,6.305688,1.1954374,,,,,,,,,,,,,, -135300,5.4395885,1.233763,,,,,,,,,,,,,, -135400,6.2657065,1.288575,,,,,,,,,,,,,, -135500,6.327631,1.187333,,,,,,,,,,,,,, -135600,5.9057636,1.2436302,,,,,,,,,,,,,, -135700,6.6636853,1.3453721,,,,,,,,,,,,,, -135800,5.660606,1.2586787,,,,,,,,,,,,,, -135900,6.380171,1.279412,,,,,,,,,,,,,, -136000,6.2620077,1.3224514,,,,,,,,,,,,,, -136100,6.886325,1.2800281,,,,,,,,,,,,,, -136200,6.152735,1.2233199,,,,,,,,,,,,,, -136300,8.078597,1.2837024,,,,,,,,,,,,,, -136400,6.0714755,1.2835326,,,,,,,,,,,,,, -136500,6.848284,1.3164554,,,,,,,,,,,,,, -136600,6.3453197,1.2025484,,,,,,,,,,,,,, -136692,,,0.7917529940605164,0.7613667845726013,0.7073599696159363,1.1776471138000488,50000.0,0.5819000005722046,1.8573424816131592,10000.0,45937.60295057297,47575.49288535118,45937.60295057297,1629.5676186084747,3.88659930229187,0.0 -136700,6.555153,1.2990459,,,,,,,,,,,,,, -136800,5.8480425,1.2290139,,,,,,,,,,,,,, -136900,6.0293074,1.279501,,,,,,,,,,,,,, -137000,5.963282,1.183792,,,,,,,,,,,,,, -137100,7.2167826,1.2568898,,,,,,,,,,,,,, -137200,6.4388647,1.2061841,,,,,,,,,,,,,, -137300,6.1086965,1.1665988,,,,,,,,,,,,,, -137400,6.201234,1.1726948,,,,,,,,,,,,,, -137500,6.5695667,1.3443388,,,,,,,,,,,,,, -137600,5.6099477,1.1751046,,,,,,,,,,,,,, -137700,6.5988197,1.227653,,,,,,,,,,,,,, -137800,6.197578,1.2069839,,,,,,,,,,,,,, -137900,5.80782,1.124731,,,,,,,,,,,,,, -138000,5.900431,1.1604698,,,,,,,,,,,,,, -138100,6.019742,1.2662561,,,,,,,,,,,,,, -138200,6.028393,1.2136277,,,,,,,,,,,,,, -138213,,,0.8251953125,0.6313984990119934,0.7180399894714355,1.1386464834213257,50000.0,0.589900016784668,1.820005178451538,10000.0,46447.6070895195,48103.48945856094,46447.6070895195,1647.4524323940277,3.945059061050415,0.0 -138300,6.1551075,1.1785573,,,,,,,,,,,,,, -138400,6.1030507,1.1547867,,,,,,,,,,,,,, -138500,6.5519595,1.2781945,,,,,,,,,,,,,, -138600,6.896092,1.2648249,,,,,,,,,,,,,, -138700,6.4105372,1.2424129,,,,,,,,,,,,,, -138800,6.7641945,1.2403456,,,,,,,,,,,,,, -138900,6.3197117,1.247541,,,,,,,,,,,,,, -139000,6.5927887,1.2854048,,,,,,,,,,,,,, -139100,6.498002,1.2604525,,,,,,,,,,,,,, -139200,6.4922504,1.2508544,,,,,,,,,,,,,, -139300,6.1212416,1.1098918,,,,,,,,,,,,,, -139400,6.350478,1.2771262,,,,,,,,,,,,,, -139500,6.0864816,1.1456816,,,,,,,,,,,,,, -139600,5.9905725,1.2286122,,,,,,,,,,,,,, -139700,6.083026,1.167473,,,,,,,,,,,,,, -139735,,,0.8220463991165161,0.6473369002342224,0.7188199758529663,1.1304951906204224,50000.0,0.5910000205039978,1.8218023777008057,10000.0,46957.71826648712,48631.36492419243,46957.71826648712,1665.1089329719543,4.003594875335693,0.0 -139800,6.689617,1.1985533,,,,,,,,,,,,,, -139900,6.5195403,1.2046694,,,,,,,,,,,,,, -140000,6.192082,1.2114156,,,,,,,,,,,,,, -140100,6.912593,1.2117174,,,,,,,,,,,,,, -140200,6.841891,1.1733246,,,,,,,,,,,,,, -140300,6.818425,1.1897222,,,,,,,,,,,,,, -140400,7.5047717,1.2145157,,,,,,,,,,,,,, -140500,6.7659645,1.186156,,,,,,,,,,,,,, -140600,6.234972,1.2160671,,,,,,,,,,,,,, -140700,6.9121537,1.1780791,,,,,,,,,,,,,, -140800,7.5825124,1.1230229,,,,,,,,,,,,,, -140900,6.609918,1.1446939,,,,,,,,,,,,,, -141000,6.7356052,1.0610803,,,,,,,,,,,,,, -141100,6.7336855,1.209306,,,,,,,,,,,,,, -141200,6.071196,1.0855156,,,,,,,,,,,,,, -141255,,,0.8222058415412903,0.6530354619026184,0.7206199765205383,1.1245300769805908,50000.0,0.597100019454956,1.824455499649048,10000.0,47467.91790008545,49159.58340501785,47467.91790008545,1683.0212621688845,4.060739040374756,0.0 -141300,6.824692,1.3369682,,,,,,,,,,,,,, -141400,6.4702754,1.2990603,,,,,,,,,,,,,, -141500,6.5497026,1.2750458,,,,,,,,,,,,,, -141600,6.6435537,1.2263988,,,,,,,,,,,,,, -141700,7.0495243,1.1682546,,,,,,,,,,,,,, -141800,6.8906007,1.2360742,,,,,,,,,,,,,, -141900,6.7699633,1.1764468,,,,,,,,,,,,,, -142000,5.9766517,1.1689215,,,,,,,,,,,,,, -142100,6.7054024,1.2312553,,,,,,,,,,,,,, -142200,6.199609,1.1896688,,,,,,,,,,,,,, -142300,6.0325875,1.087845,,,,,,,,,,,,,, -142400,6.8514533,1.2069255,,,,,,,,,,,,,, -142500,6.273793,1.045672,,,,,,,,,,,,,, -142600,6.1441646,1.0794923,,,,,,,,,,,,,, -142700,6.3707824,1.212277,,,,,,,,,,,,,, -142775,,,0.8192163705825806,0.6568843126296997,0.722819983959198,1.1111185550689695,50000.0,0.5913000106811523,1.8181012868881223,10000.0,47978.107671022415,49687.80672287941,47978.107671022415,1700.9479422569275,4.118076086044312,0.0 -142800,6.2363806,1.0692341,,,,,,,,,,,,,, -142900,7.458333,1.2565968,,,,,,,,,,,,,, -143000,6.5052447,1.1810865,,,,,,,,,,,,,, -143100,6.7608547,1.1387649,,,,,,,,,,,,,, -143200,6.913659,1.1961768,,,,,,,,,,,,,, -143300,7.0370655,1.1937399,,,,,,,,,,,,,, -143400,6.1915574,1.1634425,,,,,,,,,,,,,, -143500,6.413745,1.111309,,,,,,,,,,,,,, -143600,6.33123,1.1065412,,,,,,,,,,,,,, -143700,7.0614877,1.2137256,,,,,,,,,,,,,, -143800,6.619461,1.1454442,,,,,,,,,,,,,, -143900,6.205642,1.148216,,,,,,,,,,,,,, -144000,6.6215963,1.1864297,,,,,,,,,,,,,, -144100,7.066426,1.1238736,,,,,,,,,,,,,, -144200,6.9510765,1.1213658,,,,,,,,,,,,,, -144296,,,0.8274872303009033,0.6358198523521423,0.7251600027084351,1.0935237407684326,50000.0,0.5961000323295593,1.785703420639038,10000.0,48488.26706719399,50215.87666511536,48488.26706719399,1718.752968788147,4.174353361129761,0.0 -144300,6.6094594,1.1704272,,,,,,,,,,,,,, -144400,7.482783,1.1657887,,,,,,,,,,,,,, -144500,6.497854,1.1428747,,,,,,,,,,,,,, -144600,6.4193707,1.0805582,,,,,,,,,,,,,, -144700,7.3702703,1.2479095,,,,,,,,,,,,,, -144800,7.6913795,1.301417,,,,,,,,,,,,,, -144900,6.445246,1.1697382,,,,,,,,,,,,,, -145000,7.0942264,1.1743861,,,,,,,,,,,,,, -145100,6.7435236,1.1363739,,,,,,,,,,,,,, -145200,6.9903793,1.1375345,,,,,,,,,,,,,, -145300,6.441129,1.0457983,,,,,,,,,,,,,, -145400,6.9177337,1.0777202,,,,,,,,,,,,,, -145500,7.6503615,1.193754,,,,,,,,,,,,,, -145600,6.561568,1.0466655,,,,,,,,,,,,,, -145700,7.6112723,1.1344634,,,,,,,,,,,,,, -145800,6.640249,1.2149595,,,,,,,,,,,,,, -145817,,,0.8562061190605164,0.5220614075660706,0.7297799587249756,1.0798618793487549,50000.0,0.6053000092506409,1.75275981426239,10000.0,48998.32432341576,50744.13141441345,48998.32432341576,1736.8405735492706,4.234951734542847,0.0 -145900,6.7597985,0.9967865,,,,,,,,,,,,,, -146000,7.462428,1.1429406,,,,,,,,,,,,,, -146100,7.2780523,1.1911855,,,,,,,,,,,,,, -146200,6.6481676,1.1256849,,,,,,,,,,,,,, -146300,6.838799,1.1271988,,,,,,,,,,,,,, -146400,6.550163,1.1268911,,,,,,,,,,,,,, -146500,7.2090063,1.0915418,,,,,,,,,,,,,, -146600,7.20686,1.1638722,,,,,,,,,,,,,, -146700,7.373352,1.1367568,,,,,,,,,,,,,, -146800,7.514644,1.125685,,,,,,,,,,,,,, -146900,7.2053013,1.1894538,,,,,,,,,,,,,, -147000,7.0168457,1.1472336,,,,,,,,,,,,,, -147100,7.907373,1.153845,,,,,,,,,,,,,, -147200,7.1749682,1.0347828,,,,,,,,,,,,,, -147300,7.5579767,1.1213715,,,,,,,,,,,,,, -147337,,,0.8483338356018066,0.5386354327201843,0.7303999662399292,1.0842643976211548,50000.0,0.604200005531311,1.7619547843933103,10000.0,49508.459768772125,51272.14950942993,49508.459768772125,1754.6182186603546,4.2909040451049805,0.0 -147400,7.5110755,1.1951431,,,,,,,,,,,,,, -147500,7.0510726,1.1265715,,,,,,,,,,,,,, -147600,7.0528455,1.0551095,,,,,,,,,,,,,, -147700,6.5654297,1.0890611,,,,,,,,,,,,,, -147800,6.980405,1.0676498,,,,,,,,,,,,,, -147900,7.2196484,1.1226941,,,,,,,,,,,,,, -148000,7.8050466,1.084992,,,,,,,,,,,,,, -148100,6.5346065,1.0797421,,,,,,,,,,,,,, -148200,6.7099624,1.059613,,,,,,,,,,,,,, -148300,7.333182,1.0805498,,,,,,,,,,,,,, -148400,6.976727,1.1271577,,,,,,,,,,,,,, -148500,6.7612443,1.0623986,,,,,,,,,,,,,, -148600,8.029651,1.1689963,,,,,,,,,,,,,, -148700,7.365466,1.153847,,,,,,,,,,,,,, -148800,8.074935,1.192898,,,,,,,,,,,,,, -148857,,,0.8486328125,0.5442114472389221,0.7340599894523621,1.065670132637024,50000.0,0.6065000295639038,1.7632725238800049,10000.0,50018.51382923126,51799.88316822052,50018.51382923126,1772.192732810974,4.347080707550049,0.0 -148900,7.081242,1.0533732,,,,,,,,,,,,,, -149000,6.9090924,1.0684996,,,,,,,,,,,,,, -149100,6.506569,0.9856515,,,,,,,,,,,,,, -149200,6.7707124,1.0591309,,,,,,,,,,,,,, -149300,7.2575994,1.0888605,,,,,,,,,,,,,, -149400,7.6506414,1.1098244,,,,,,,,,,,,,, -149500,7.324782,1.1034425,,,,,,,,,,,,,, -149600,7.019372,1.1698114,,,,,,,,,,,,,, -149700,6.9008746,1.0923109,,,,,,,,,,,,,, -149800,7.952594,1.2039412,,,,,,,,,,,,,, -149900,7.0388064,1.0941061,,,,,,,,,,,,,, -150000,7.1403823,1.1741765,,,,,,,,,,,,,, -150100,6.893858,1.0327003,,,,,,,,,,,,,, -150200,7.176679,1.1085905,,,,,,,,,,,,,, -150300,8.5536,1.1228962,,,,,,,,,,,,,, -150377,,,0.8487922549247742,0.5424256324768066,0.7318800091743469,1.0758512020111084,50000.0,0.6093000173568726,1.749652624130249,10000.0,50528.65114068985,52327.87962150574,50528.65114068985,1789.946456670761,4.403222799301148,0.0 -150400,7.665934,1.0611637,,,,,,,,,,,,,, -150500,7.2345543,1.0169617,,,,,,,,,,,,,, -150600,8.098672,1.1290874,,,,,,,,,,,,,, -150700,7.971371,1.0684516,,,,,,,,,,,,,, -150800,6.6219697,1.0569218,,,,,,,,,,,,,, -150900,7.4203954,1.0819834,,,,,,,,,,,,,, -151000,7.615633,1.0130093,,,,,,,,,,,,,, -151100,7.2249246,1.0141315,,,,,,,,,,,,,, -151200,7.20214,1.0384533,,,,,,,,,,,,,, -151300,7.6367846,1.0241691,,,,,,,,,,,,,, -151400,6.9505315,1.0167783,,,,,,,,,,,,,, -151500,7.7453165,1.1307155,,,,,,,,,,,,,, -151600,7.051702,1.1503487,,,,,,,,,,,,,, -151700,7.6650743,1.073758,,,,,,,,,,,,,, -151800,7.4612756,1.0152024,,,,,,,,,,,,,, -151898,,,0.8529974222183228,0.52959805727005,0.7369999885559082,1.0513129234313965,50000.0,0.6142000555992126,1.7403059005737305,10000.0,51038.67733478546,52855.8038623333,51038.67733478546,1807.7406420707705,4.458520889282227,0.0 -151900,7.916336,0.99722505,,,,,,,,,,,,,, -152000,6.9476223,1.00537,,,,,,,,,,,,,, -152100,7.629451,1.029616,,,,,,,,,,,,,, -152200,7.5587697,1.0119678,,,,,,,,,,,,,, -152300,7.7537727,1.017278,,,,,,,,,,,,,, -152400,7.348613,0.9639098,,,,,,,,,,,,,, -152500,7.0530386,0.9622003,,,,,,,,,,,,,, -152600,7.043322,1.0671244,,,,,,,,,,,,,, -152700,6.9810658,1.0145146,,,,,,,,,,,,,, -152800,7.5770416,1.0699234,,,,,,,,,,,,,, -152900,7.0991406,1.0568326,,,,,,,,,,,,,, -153000,7.4489374,1.0188909,,,,,,,,,,,,,, -153100,7.2894955,1.0663018,,,,,,,,,,,,,, -153200,7.687516,1.0045128,,,,,,,,,,,,,, -153300,7.570177,1.1806877,,,,,,,,,,,,,, -153400,8.141639,1.0422833,,,,,,,,,,,,,, -153418,,,0.8550103306770325,0.5078051090240479,0.7404800057411194,1.0449867248535156,50000.0,0.609000027179718,1.75309419631958,10000.0,51548.73529744148,53383.79339146614,51548.73529744148,1825.565360069275,4.516251564025879,0.0 -153500,7.1209292,1.0514021,,,,,,,,,,,,,, -153600,7.011548,0.96292186,,,,,,,,,,,,,, -153700,7.9749823,1.0213771,,,,,,,,,,,,,, -153800,7.2538767,1.0425577,,,,,,,,,,,,,, -153900,7.3035526,0.9715157,,,,,,,,,,,,,, -154000,7.760364,1.0099195,,,,,,,,,,,,,, -154100,7.4587216,1.0856514,,,,,,,,,,,,,, -154200,8.020633,1.1070545,,,,,,,,,,,,,, -154300,7.404535,0.9987148,,,,,,,,,,,,,, -154400,7.4139404,1.028195,,,,,,,,,,,,,, -154500,6.9761195,1.0177399,,,,,,,,,,,,,, -154600,7.122276,1.0117406,,,,,,,,,,,,,, -154700,7.7874393,1.110429,,,,,,,,,,,,,, -154800,6.784028,0.95763135,,,,,,,,,,,,,, -154900,7.1815734,0.9837691,,,,,,,,,,,,,, -154938,,,0.8839883208274841,0.4073780179023742,0.7401799559593201,1.042005181312561,50000.0,0.617400050163269,1.7246414422988892,10000.0,52058.83040165901,53911.61543941498,52058.83040165901,1843.183394432068,4.576179504394531,0.0 -155000,7.7992306,1.0680746,,,,,,,,,,,,,, -155100,8.665993,1.1090279,,,,,,,,,,,,,, -155200,6.8536525,0.97572154,,,,,,,,,,,,,, -155300,7.458855,0.95827305,,,,,,,,,,,,,, -155400,8.173074,0.9844105,,,,,,,,,,,,,, -155500,7.9088645,0.9277397,,,,,,,,,,,,,, -155600,7.7900147,0.96999115,,,,,,,,,,,,,, -155700,7.1228113,1.0097052,,,,,,,,,,,,,, -155800,6.99643,0.96488535,,,,,,,,,,,,,, -155900,8.79849,0.9982407,,,,,,,,,,,,,, -156000,7.258815,0.9835827,,,,,,,,,,,,,, -156100,7.5809293,1.0418406,,,,,,,,,,,,,, -156200,8.969956,1.0398746,,,,,,,,,,,,,, -156300,8.10352,0.95627224,,,,,,,,,,,,,, -156400,7.1756206,0.9708073,,,,,,,,,,,,,, -156460,,,0.8816167116165161,0.4281372129917145,0.7430399656295776,1.0246366262435913,50000.0,0.619100034236908,1.7108631134033203,10000.0,52568.99254751205,54439.52820849419,52568.99254751205,1860.8247520923608,4.636116027832031,0.0 -156500,8.048233,0.93612856,,,,,,,,,,,,,, -156600,7.729235,0.9317549,,,,,,,,,,,,,, -156700,7.6718397,0.9176017,,,,,,,,,,,,,, -156800,7.7403855,0.99844164,,,,,,,,,,,,,, -156900,8.071155,0.98592937,,,,,,,,,,,,,, -157000,8.049398,1.0697434,,,,,,,,,,,,,, -157100,7.846315,0.9171846,,,,,,,,,,,,,, -157200,7.77987,0.97251993,,,,,,,,,,,,,, -157300,8.847234,1.0442308,,,,,,,,,,,,,, -157400,7.274682,0.94223017,,,,,,,,,,,,,, -157500,8.411144,0.8675906,,,,,,,,,,,,,, -157600,8.336882,0.9858216,,,,,,,,,,,,,, -157700,8.338648,0.89812845,,,,,,,,,,,,,, -157800,8.32965,0.9937744,,,,,,,,,,,,,, -157900,8.330664,0.9870262,,,,,,,,,,,,,, -157980,,,0.8804408311843872,0.4245630204677582,0.7468999624252319,1.0233200788497925,50000.0,0.6183000206947327,1.7170692682266235,10000.0,53079.10682106018,54967.46635508537,53079.10682106018,1878.545128583908,4.690703630447388,0.0 -158000,7.734338,1.0217425,,,,,,,,,,,,,, -158100,8.214979,0.9685371,,,,,,,,,,,,,, -158200,8.071318,0.99027,,,,,,,,,,,,,, -158300,8.2283745,0.90300286,,,,,,,,,,,,,, -158400,7.784784,0.9177984,,,,,,,,,,,,,, -158500,7.461352,0.9125709,,,,,,,,,,,,,, -158600,7.8228807,0.91970354,,,,,,,,,,,,,, -158700,7.7184978,0.96973085,,,,,,,,,,,,,, -158800,8.993899,1.0052378,,,,,,,,,,,,,, -158900,8.050899,1.0546458,,,,,,,,,,,,,, -159000,7.65073,0.96315825,,,,,,,,,,,,,, -159100,8.013671,1.0125834,,,,,,,,,,,,,, -159200,7.294822,0.92814744,,,,,,,,,,,,,, -159300,7.8761396,0.9612565,,,,,,,,,,,,,, -159400,8.436352,0.98497945,,,,,,,,,,,,,, -159500,8.594856,0.92515075,,,,,,,,,,,,,, -159501,,,0.8826530575752258,0.4148502647876739,0.7451199889183044,1.0272034406661987,50000.0,0.6223000288009644,1.720950484275818,10000.0,53589.45731592178,55495.64396739006,53589.45731592178,1896.2643899917605,4.749074697494507,0.0 -159600,7.6556997,0.90418756,,,,,,,,,,,,,, -159700,8.046015,0.92294633,,,,,,,,,,,,,, -159800,7.3611794,0.94602895,,,,,,,,,,,,,, -159900,8.399784,0.91322935,,,,,,,,,,,,,, -160000,8.229829,1.037054,,,,,,,,,,,,,, -160100,7.778813,0.9279625,,,,,,,,,,,,,, -160200,8.783558,1.1213989,,,,,,,,,,,,,, -160300,8.252551,0.9743259,,,,,,,,,,,,,, -160400,8.59433,0.9575031,,,,,,,,,,,,,, -160500,8.184634,0.90350467,,,,,,,,,,,,,, -160600,9.124529,0.98597425,,,,,,,,,,,,,, -160700,8.359482,0.99751395,,,,,,,,,,,,,, -160800,7.691263,0.9463871,,,,,,,,,,,,,, -160900,7.896627,0.91805923,,,,,,,,,,,,,, -161000,7.365195,0.81695837,,,,,,,,,,,,,, -161022,,,0.8878945708274841,0.3984717428684234,0.7486799955368042,1.0120301246643066,50000.0,0.6247000098228455,1.695212960243225,10000.0,54099.56357502937,56023.67906188965,54099.56357502937,1914.0802400112152,4.812868118286133,0.0 -161100,7.994489,0.90797216,,,,,,,,,,,,,, -161200,8.503497,0.97883755,,,,,,,,,,,,,, -161300,7.9823847,0.96430767,,,,,,,,,,,,,, -161400,8.640429,0.885293,,,,,,,,,,,,,, -161500,7.674361,0.9137065,,,,,,,,,,,,,, -161600,8.194473,0.8768036,,,,,,,,,,,,,, -161700,7.8477726,0.9687901,,,,,,,,,,,,,, -161800,8.547701,0.9311415,,,,,,,,,,,,,, -161900,8.021891,0.9266467,,,,,,,,,,,,,, -162000,8.346653,0.9296392,,,,,,,,,,,,,, -162100,8.219187,0.88925326,,,,,,,,,,,,,, -162200,8.550979,0.9287102,,,,,,,,,,,,,, -162300,8.158391,0.9015325,,,,,,,,,,,,,, -162400,7.9325175,0.9160242,,,,,,,,,,,,,, -162500,7.764219,0.9281361,,,,,,,,,,,,,, -162541,,,0.8908840417861938,0.3822648227214813,0.751259982585907,1.008744239807129,50000.0,0.6285000443458557,1.6943116188049316,10000.0,54609.65087771416,56551.9337553978,54609.65087771416,1932.141437768936,4.868780136108398,0.0 -162600,7.825351,0.8461309,,,,,,,,,,,,,, -162700,8.742094,0.9328469,,,,,,,,,,,,,, -162800,8.636414,0.8934393,,,,,,,,,,,,,, -162900,9.732201,0.9519253,,,,,,,,,,,,,, -163000,7.6207166,0.8975877,,,,,,,,,,,,,, -163100,8.834878,0.8575602,,,,,,,,,,,,,, -163200,8.548216,0.87931967,,,,,,,,,,,,,, -163300,8.16502,0.8619945,,,,,,,,,,,,,, -163400,8.524278,0.933334,,,,,,,,,,,,,, -163500,7.978543,0.88681066,,,,,,,,,,,,,, -163600,8.112935,0.8805227,,,,,,,,,,,,,, -163700,8.656164,0.8858791,,,,,,,,,,,,,, -163800,9.166006,0.88713163,,,,,,,,,,,,,, -163900,8.160124,0.83531165,,,,,,,,,,,,,, -164000,8.700835,0.9448555,,,,,,,,,,,,,, -164062,,,0.9062101244926452,0.3281794488430023,0.7523199915885925,1.0044119358062744,50000.0,0.6237000226974487,1.7114315032958984,10000.0,55119.58855962753,57079.75123023987,55119.58855962753,1949.9157021045685,4.9245476722717285,0.0 -164100,8.54737,0.94777495,,,,,,,,,,,,,, -164200,9.913177,0.95811003,,,,,,,,,,,,,, -164300,8.641831,0.82678694,,,,,,,,,,,,,, -164400,8.185765,0.91295147,,,,,,,,,,,,,, -164500,7.576968,0.76971716,,,,,,,,,,,,,, -164600,8.086475,0.869836,,,,,,,,,,,,,, -164700,7.832255,0.9237424,,,,,,,,,,,,,, -164800,8.834487,0.9734859,,,,,,,,,,,,,, -164900,7.7865734,0.8649235,,,,,,,,,,,,,, -165000,8.437174,0.88824284,,,,,,,,,,,,,, -165100,8.138433,0.88487965,,,,,,,,,,,,,, -165200,7.977282,0.8001899,,,,,,,,,,,,,, -165300,8.497664,0.9245448,,,,,,,,,,,,,, -165400,7.961735,0.88674015,,,,,,,,,,,,,, -165500,8.205657,0.8564154,,,,,,,,,,,,,, -165582,,,0.9041573405265808,0.3352792263031006,0.7546600103378296,0.99237322807312,50000.0,0.6300000548362732,1.6897133588790894,10000.0,55629.49757552147,57607.30726504326,55629.49757552147,1967.455517053604,4.982961177825928,0.0 -165600,8.791739,0.81742907,,,,,,,,,,,,,, -165700,8.18533,0.84426326,,,,,,,,,,,,,, -165800,8.3309145,0.8610584,,,,,,,,,,,,,, -165900,7.885664,0.83755875,,,,,,,,,,,,,, -166000,9.092741,0.91630757,,,,,,,,,,,,,, -166100,7.497652,0.77233744,,,,,,,,,,,,,, -166200,8.521186,0.87218,,,,,,,,,,,,,, -166300,7.9865675,0.81516004,,,,,,,,,,,,,, -166400,8.366286,0.84149003,,,,,,,,,,,,,, -166500,9.510424,0.86490035,,,,,,,,,,,,,, -166600,8.4293995,0.8550256,,,,,,,,,,,,,, -166700,9.274034,0.89270234,,,,,,,,,,,,,, -166800,8.286131,0.8489355,,,,,,,,,,,,,, -166900,7.897375,0.81574565,,,,,,,,,,,,,, -167000,8.76293,0.8772056,,,,,,,,,,,,,, -167100,8.570079,0.83154947,,,,,,,,,,,,,, -167102,,,0.9091398119926452,0.3203354179859161,0.7573999762535095,0.987073004245758,50000.0,0.6339000463485718,1.6737549304962158,10000.0,56139.54078269005,58135.12322330475,56139.54078269005,1985.1188147068024,5.042438745498657,0.0 -167200,8.659145,0.82719016,,,,,,,,,,,,,, -167300,9.859706,0.96985495,,,,,,,,,,,,,, -167400,7.918386,0.7882518,,,,,,,,,,,,,, -167500,10.320787,0.76334584,,,,,,,,,,,,,, -167600,8.922447,0.8491784,,,,,,,,,,,,,, -167700,8.220794,0.8621598,,,,,,,,,,,,,, -167800,8.312394,0.8843664,,,,,,,,,,,,,, -167900,8.893004,0.92479026,,,,,,,,,,,,,, -168000,9.104518,0.8983045,,,,,,,,,,,,,, -168100,7.932446,0.7607261,,,,,,,,,,,,,, -168200,8.602667,0.84498245,,,,,,,,,,,,,, -168300,9.004729,0.8643909,,,,,,,,,,,,,, -168400,8.838233,0.81985366,,,,,,,,,,,,,, -168500,8.78312,0.8217758,,,,,,,,,,,,,, -168600,7.9502234,0.8146856,,,,,,,,,,,,,, -168623,,,0.9107740521430968,0.3188508749008178,0.7586399912834167,0.9830759763717652,50000.0,0.6332000494003296,1.6756737232208252,10000.0,56649.45196771622,58663.02194476128,56649.45196771622,2002.999568939209,5.100426197052002,0.0 -168700,8.395843,0.7048867,,,,,,,,,,,,,, -168800,9.361465,0.7933667,,,,,,,,,,,,,, -168900,8.03152,0.87135184,,,,,,,,,,,,,, -169000,7.4895806,0.759768,,,,,,,,,,,,,, -169100,9.175068,0.81340206,,,,,,,,,,,,,, -169200,8.684114,0.8455219,,,,,,,,,,,,,, -169300,9.103326,0.8014587,,,,,,,,,,,,,, -169400,7.942005,0.78117263,,,,,,,,,,,,,, -169500,8.722683,0.73474175,,,,,,,,,,,,,, -169600,8.693499,0.879341,,,,,,,,,,,,,, -169700,8.553928,0.8000196,,,,,,,,,,,,,, -169800,10.942819,0.9098121,,,,,,,,,,,,,, -169900,8.451206,0.8100232,,,,,,,,,,,,,, -170000,9.034894,0.8409343,,,,,,,,,,,,,, -170100,7.9986453,0.7846889,,,,,,,,,,,,,, -170143,,,0.9118303656578064,0.3059313297271728,0.7596799731254578,0.977874517440796,50000.0,0.6368000507354736,1.6767271757125854,10000.0,57159.36017847061,59190.96336650848,57159.36017847061,2020.919572353363,5.164439678192139,0.0 -170200,8.687379,0.72563183,,,,,,,,,,,,,, -170300,8.912919,0.8533821,,,,,,,,,,,,,, -170400,8.809127,0.7167394,,,,,,,,,,,,,, -170500,9.139967,0.8374655,,,,,,,,,,,,,, -170600,8.734698,0.87142694,,,,,,,,,,,,,, -170700,8.928726,0.79461205,,,,,,,,,,,,,, -170800,9.2755,0.83581024,,,,,,,,,,,,,, -170900,9.293777,0.78978837,,,,,,,,,,,,,, -171000,8.531932,0.82813144,,,,,,,,,,,,,, -171100,9.755216,0.8255007,,,,,,,,,,,,,, -171200,8.135157,0.79641914,,,,,,,,,,,,,, -171300,7.9810457,0.74423397,,,,,,,,,,,,,, -171400,8.205371,0.7353299,,,,,,,,,,,,,, -171500,9.21824,0.7249124,,,,,,,,,,,,,, -171600,8.514689,0.70170957,,,,,,,,,,,,,, -171663,,,0.9235889315605164,0.2766060531139374,0.7594000101089478,0.9787272214889526,50000.0,0.6338000297546387,1.6702979803085327,10000.0,57669.37860417366,59719.01620292664,57669.37860417366,2038.8429753780365,5.226194143295288,0.0 -171700,9.447249,0.84093654,,,,,,,,,,,,,, -171800,8.408668,0.789803,,,,,,,,,,,,,, -171900,8.639223,0.7947562,,,,,,,,,,,,,, -172000,8.846258,0.72244644,,,,,,,,,,,,,, -172100,8.797195,0.7560742,,,,,,,,,,,,,, -172200,9.07432,0.7879968,,,,,,,,,,,,,, -172300,9.225238,0.8362934,,,,,,,,,,,,,, -172400,10.167885,0.8537555,,,,,,,,,,,,,, -172500,8.626703,0.74267966,,,,,,,,,,,,,, -172600,9.313258,0.8027136,,,,,,,,,,,,,, -172700,8.881813,0.78107786,,,,,,,,,,,,,, -172800,8.828807,0.71724164,,,,,,,,,,,,,, -172900,8.238751,0.7197254,,,,,,,,,,,,,, -173000,8.846154,0.799702,,,,,,,,,,,,,, -173100,8.244522,0.74446064,,,,,,,,,,,,,, -173183,,,0.9255819320678712,0.2669574022293091,0.7631999850273132,0.9670901298522948,50000.0,0.6374000310897827,1.6677082777023315,10000.0,58179.60224986076,60247.089400053024,58179.60224986076,2056.581691026688,5.2877233028411865,0.0 -173200,8.696186,0.79355985,,,,,,,,,,,,,, -173300,8.785423,0.74121094,,,,,,,,,,,,,, -173400,9.081696,0.8537807,,,,,,,,,,,,,, -173500,8.825797,0.7911338,,,,,,,,,,,,,, -173600,9.222673,0.79248214,,,,,,,,,,,,,, -173700,10.02177,0.79440993,,,,,,,,,,,,,, -173800,9.354471,0.7878873,,,,,,,,,,,,,, -173900,8.462375,0.78706074,,,,,,,,,,,,,, -174000,8.61691,0.7159212,,,,,,,,,,,,,, -174100,8.543243,0.7199428,,,,,,,,,,,,,, -174200,8.174123,0.768136,,,,,,,,,,,,,, -174300,8.678372,0.8238469,,,,,,,,,,,,,, -174400,9.496028,0.85105073,,,,,,,,,,,,,, -174500,9.041541,0.77410233,,,,,,,,,,,,,, -174600,8.637101,0.7723735,,,,,,,,,,,,,, -174700,9.729905,0.8096251,,,,,,,,,,,,,, -174703,,,0.927754282951355,0.2576439678668976,0.7628799676895142,0.9702308177947998,50000.0,0.6367000341415405,1.670340895652771,10000.0,58689.71545791626,60775.03505182266,58689.71545791626,2074.306261062622,5.346015691757202,0.0 -174800,8.701895,0.77274656,,,,,,,,,,,,,, -174900,8.368253,0.74327934,,,,,,,,,,,,,, -175000,9.393048,0.8579416,,,,,,,,,,,,,, -175100,9.164649,0.80626535,,,,,,,,,,,,,, -175200,9.114425,0.8322074,,,,,,,,,,,,,, -175300,9.054181,0.7459121,,,,,,,,,,,,,, -175400,9.55878,0.82626396,,,,,,,,,,,,,, -175500,9.12575,0.7206116,,,,,,,,,,,,,, -175600,8.759724,0.7838259,,,,,,,,,,,,,, -175700,9.897583,0.8690847,,,,,,,,,,,,,, -175800,9.876089,0.84677774,,,,,,,,,,,,,, -175900,8.441187,0.85681313,,,,,,,,,,,,,, -176000,9.489293,0.7357916,,,,,,,,,,,,,, -176100,9.142776,0.75854135,,,,,,,,,,,,,, -176200,9.607492,0.75896364,,,,,,,,,,,,,, -176222,,,0.9254623651504515,0.265313446521759,0.7633199691772461,0.9638875722885132,50000.0,0.6407000422477722,1.6606649160385132,10000.0,59199.76896715164,61303.05489516258,59199.76896715164,2092.1598665714264,5.409138917922974,0.0 -176300,8.548893,0.69868916,,,,,,,,,,,,,, -176400,8.671703,0.6970081,,,,,,,,,,,,,, -176500,8.435236,0.7379859,,,,,,,,,,,,,, -176600,8.563951,0.7364764,,,,,,,,,,,,,, -176700,8.93058,0.7554042,,,,,,,,,,,,,, -176800,9.259004,0.7809149,,,,,,,,,,,,,, -176900,9.788003,0.81020457,,,,,,,,,,,,,, -177000,9.380465,0.7763378,,,,,,,,,,,,,, -177100,8.851947,0.7460374,,,,,,,,,,,,,, -177200,10.365875,0.81162107,,,,,,,,,,,,,, -177300,8.32801,0.7279837,,,,,,,,,,,,,, -177400,8.989866,0.79655814,,,,,,,,,,,,,, -177500,8.659847,0.7398681,,,,,,,,,,,,,, -177600,8.750589,0.72766733,,,,,,,,,,,,,, -177700,8.581536,0.78949136,,,,,,,,,,,,,, -177741,,,0.9296875,0.254580557346344,0.7645799517631531,0.9602543711662292,50000.0,0.6406000256538391,1.659266233444214,10000.0,59709.75693559647,61831.0673494339,59709.75693559647,2110.070912361145,5.473035573959351,0.0 -177800,8.316523,0.78847545,,,,,,,,,,,,,, -177900,8.019721,0.74333465,,,,,,,,,,,,,, -178000,9.335449,0.7796296,,,,,,,,,,,,,, -178100,9.271668,0.75152236,,,,,,,,,,,,,, -178200,9.401339,0.7680622,,,,,,,,,,,,,, -178300,9.494626,0.813486,,,,,,,,,,,,,, -178400,8.427086,0.6933455,,,,,,,,,,,,,, -178500,9.2066765,0.80139804,,,,,,,,,,,,,, -178600,9.029535,0.6965015,,,,,,,,,,,,,, -178700,8.853108,0.79254925,,,,,,,,,,,,,, -178800,8.495119,0.73162764,,,,,,,,,,,,,, -178900,9.082917,0.7711045,,,,,,,,,,,,,, -179000,9.157055,0.7450099,,,,,,,,,,,,,, -179100,8.780266,0.7833694,,,,,,,,,,,,,, -179200,8.683585,0.8124346,,,,,,,,,,,,,, -179261,,,0.9309031963348388,0.2519624531269073,0.7645399570465088,0.9593108296394348,50000.0,0.6409000158309937,1.654666543006897,10000.0,60219.763149023056,62358.80902075768,60219.763149023056,2127.696268796921,5.533898830413818,0.0 -179300,8.742837,0.6762099,,,,,,,,,,,,,, -179400,8.824966,0.8032103,,,,,,,,,,,,,, -179500,8.905093,0.70964277,,,,,,,,,,,,,, -179600,7.964823,0.7250987,,,,,,,,,,,,,, -179700,7.929661,0.6617675,,,,,,,,,,,,,, -179800,8.744877,0.7199681,,,,,,,,,,,,,, -179900,9.309359,0.7719854,,,,,,,,,,,,,, -180000,8.4467745,0.68241936,,,,,,,,,,,,,, -180100,9.120375,0.717438,,,,,,,,,,,,,, -180200,8.974774,0.84073985,,,,,,,,,,,,,, -180300,8.736292,0.7488515,,,,,,,,,,,,,, -180400,8.361872,0.6739903,,,,,,,,,,,,,, -180500,9.875502,0.71900505,,,,,,,,,,,,,, -180600,9.233982,0.7876171,,,,,,,,,,,,,, -180700,8.790683,0.72256595,,,,,,,,,,,,,, -180781,,,0.9323779940605164,0.244142547249794,0.76419997215271,0.958231508731842,50000.0,0.6391000151634216,1.6557667255401611,10000.0,60729.83857727051,62887.02244234085,60729.83857727051,2145.725523948669,5.593425512313843,0.0 -180800,8.626132,0.74326885,,,,,,,,,,,,,, -180900,8.738793,0.7999241,,,,,,,,,,,,,, -181000,9.076204,0.74876904,,,,,,,,,,,,,, -181100,8.89312,0.7279144,,,,,,,,,,,,,, -181200,8.485153,0.7947411,,,,,,,,,,,,,, -181300,8.616778,0.741228,,,,,,,,,,,,,, -181400,9.573494,0.6923493,,,,,,,,,,,,,, -181500,8.984036,0.75248885,,,,,,,,,,,,,, -181600,8.493473,0.67383343,,,,,,,,,,,,,, -181700,8.374314,0.76902103,,,,,,,,,,,,,, -181800,9.72083,0.80056924,,,,,,,,,,,,,, -181900,8.262237,0.65700746,,,,,,,,,,,,,, -182000,9.836583,0.78004354,,,,,,,,,,,,,, -182100,8.638242,0.73486465,,,,,,,,,,,,,, -182200,9.099422,0.722327,,,,,,,,,,,,,, -182300,9.4216,0.7457604,,,,,,,,,,,,,, -182301,,,0.934351086616516,0.239544078707695,0.7648999691009521,0.956494688987732,50000.0,0.6412000060081482,1.6545675992965698,10000.0,61240.16095900536,63415.34263134003,61240.16095900536,2163.610629081726,5.656862735748291,0.0 -182400,8.93499,0.7097217,,,,,,,,,,,,,, -182500,8.360045,0.7167377,,,,,,,,,,,,,, -182600,8.593301,0.7388789,,,,,,,,,,,,,, -182700,9.055047,0.7744607,,,,,,,,,,,,,, -182800,9.540671,0.7689548,,,,,,,,,,,,,, -182900,8.674394,0.72330344,,,,,,,,,,,,,, -183000,9.187567,0.7106183,,,,,,,,,,,,,, -183100,8.480136,0.6583921,,,,,,,,,,,,,, -183200,8.319396,0.7007441,,,,,,,,,,,,,, -183300,8.925547,0.73032445,,,,,,,,,,,,,, -183400,8.471183,0.7499076,,,,,,,,,,,,,, -183500,8.814888,0.67949015,,,,,,,,,,,,,, -183600,8.836851,0.7250329,,,,,,,,,,,,,, -183700,8.332611,0.6266285,,,,,,,,,,,,,, -183800,9.027062,0.74795985,,,,,,,,,,,,,, -183820,,,0.9310028553009032,0.2465308010578155,0.7655199766159058,0.9564112424850464,50000.0,0.6407000422477722,1.6530959606170654,10000.0,61750.23178052902,63943.23792815208,61750.23178052902,2181.3248648643494,5.717596530914307,0.0 -183900,9.186302,0.7284936,,,,,,,,,,,,,, -184000,9.616059,0.7094568,,,,,,,,,,,,,, -184100,9.287849,0.88848126,,,,,,,,,,,,,, -184200,8.954775,0.7618219,,,,,,,,,,,,,, -184300,9.287289,0.85019517,,,,,,,,,,,,,, -184400,9.16975,0.72721225,,,,,,,,,,,,,, -184500,8.735806,0.6831282,,,,,,,,,,,,,, -184600,8.591207,0.7241469,,,,,,,,,,,,,, -184700,9.26939,0.7630915,,,,,,,,,,,,,, -184800,7.935357,0.68909025,,,,,,,,,,,,,, -184900,9.123902,0.75430673,,,,,,,,,,,,,, -185000,8.061105,0.74014485,,,,,,,,,,,,,, -185100,9.326918,0.7182082,,,,,,,,,,,,,, -185200,8.731881,0.7339276,,,,,,,,,,,,,, -185300,9.138712,0.67870486,,,,,,,,,,,,,, -185340,,,0.9342713356018066,0.2396325170993805,0.765779972076416,0.95612633228302,50000.0,0.6410000324249268,1.6519469022750854,10000.0,62260.25123000145,64471.24368929863,62260.25123000145,2199.1994581222534,5.779508590698242,0.0 -185400,9.283564,0.74943393,,,,,,,,,,,,,, -185500,9.598545,0.7751515,,,,,,,,,,,,,, -185600,8.52133,0.7107654,,,,,,,,,,,,,, -185700,9.540915,0.7686684,,,,,,,,,,,,,, -185800,8.9235935,0.7616459,,,,,,,,,,,,,, -185900,9.422894,0.7554031,,,,,,,,,,,,,, -186000,9.034007,0.7361592,,,,,,,,,,,,,, -186100,9.696645,0.80294234,,,,,,,,,,,,,, -186200,8.787662,0.69182515,,,,,,,,,,,,,, -186300,9.312724,0.7634795,,,,,,,,,,,,,, -186400,9.166252,0.74996763,,,,,,,,,,,,,, -186500,9.064055,0.7801989,,,,,,,,,,,,,, -186600,9.190293,0.70130605,,,,,,,,,,,,,, -186666,,,0.9329161047935486,0.2451334595680236,0.7657399773597717,0.9568769335746764,50000.0,0.640500009059906,1.654202580451965,10000.0,62705.71089506149,64934.39106178284,62705.71089506149,2216.782294511795,5.841466188430786,0.0 -186666,,,,,,,,,,,62705.71089506149,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 49ecf0bbb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.99710965156555,0.0,33.11548709869385,1,0,33.11548709869385,0.0006000000284984,6.910250186920166,10000,51.11273193359375,0.0009167729294858,6.91040563583374,0.0009599999757483,6.910243988037109,50000 -35.76667332649231,0.0208404064178466,543.3378028869629,1515,0,543.3378028869629,0.1337999999523162,4.854593276977539,10000,579.1754071712494,0.1935586631298065,4.136880397796631,0.181439995765686,4.266019344329834,50000 -53.82983016967773,0.052786111831665,1053.532371044159,3030,0,1053.532371044159,0.2606000006198883,3.755113124847412,10000,1107.5142476558683,0.3678053319454193,2.9227986335754395,0.3442199826240539,3.063002109527588,50000 -71.71818590164185,0.0823340415954589,1563.7822008132937,4547,0,1563.7822008132937,0.2830000221729278,3.582597494125366,10000,1635.7310452461245,0.409877210855484,2.6684727668762207,0.3856199979782104,2.8410749435424805,50000 -89.78872632980347,0.1114284992218017,2073.916249513626,6064,0,2073.916249513626,0.2252000123262405,4.131947994232178,10000,2164.014529466629,0.3488520383834839,3.0972437858581543,0.3138200044631958,3.3472726345062256,50000 -107.7860050201416,0.1452906131744384,2583.992983341217,7581,0,2583.992983341217,0.2040000110864639,4.534960746765137,10000,2692.171564102173,0.291693240404129,3.627654552459717,0.2701799869537353,3.8371315002441406,50000 -125.98636031150818,0.1741261482238769,3094.0958971977234,9099,0,3094.0958971977234,0.0473000034689903,7.12650203704834,10000,3220.553163051605,0.0675223171710968,6.560290813446045,0.0609200000762939,6.683966159820557,50000 -143.99125719070437,0.2043311595916748,3604.2537105083466,10618,0,3604.2537105083466,0.2546000182628631,3.9828693866729736,10000,3748.795606613159,0.3647162020206451,3.049172878265381,0.3426199853420257,3.222270965576172,50000 -162.03544116020203,0.2363069057464599,4114.4356808662415,12138,0,4114.4356808662415,0.1862000077962875,4.577045440673828,10000,4277.1030423641205,0.2653858363628387,3.797953367233277,0.2460999935865402,3.96741247177124,50000 -180.38438057899475,0.2737338542938232,4624.410368680954,13658,0,4624.410368680954,0.0982000082731247,5.658753395080566,10000,4805.513270378113,0.1414421200752258,5.066132068634033,0.1360199898481369,5.142930030822754,50000 -198.250910282135,0.3105723857879638,5134.47070813179,15178,0,5134.47070813179,0.2077000141143798,4.33635950088501,10000,5333.526302337647,0.2964963316917419,3.482040882110596,0.2676399946212768,3.7009527683258057,50000 -216.19500756263733,0.347388744354248,5644.488646507263,16699,0,5644.488646507263,0.1782000064849853,4.823949813842773,10000,5861.574481487274,0.2660634517669678,3.838587760925293,0.2495799958705902,3.98842978477478,50000 -233.9950804710388,0.3794333934783935,6154.5932059288025,18220,0,6154.5932059288025,0.1228000074625015,5.50702428817749,10000,6389.560833454132,0.1853675097227096,4.6938252449035645,0.1709599941968917,4.821983337402344,50000 -252.04337310791016,0.4124011993408203,6664.773620843887,19743,0,6664.773620843887,0.2044000029563903,4.545632839202881,10000,6917.872552871704,0.3005022406578064,3.5624327659606934,0.2856999933719635,3.7141499519348153,50000 -269.8241550922394,0.4445338249206543,7174.803897380829,21265,0,7174.803897380829,0.1683000028133392,5.043724060058594,10000,7445.765621185303,0.2461535334587097,4.195694923400879,0.2303600013256073,4.367095470428467,50000 -287.60385155677795,0.482598066329956,7685.032256126404,22789,0,7685.032256126404,0.0043000001460313,10.739754676818848,10000,7973.861493825912,0.006437340285629,10.417054176330566,0.005719999782741,10.540876388549805,50000 -305.51457262039185,0.5152654647827148,8195.057115793228,24313,0,8195.057115793228,0.2270000129938125,4.0918073654174805,10000,8501.879445791245,0.3395248651504516,3.138223648071289,0.3097600042819977,3.3382935523986816,50000 -323.1879549026489,0.5490307807922363,8705.275005102158,25837,0,8705.275005102158,0.2086000144481659,4.288328170776367,10000,9029.854134559631,0.3012396395206451,3.451519012451172,0.2791000008583069,3.64577889442444,50000 -341.062353849411,0.5862481594085693,9215.481248617172,27362,0,9215.481248617172,0.0480000004172325,9.050444602966309,10000,9558.02182650566,0.0728037282824516,8.18244743347168,0.0666399970650672,8.396821975708008,50000 -359.5661907196045,0.6190822124481201,9725.596413373947,28887,0,9725.596413373947,0.1783000081777572,5.061367988586426,10000,10086.723033189774,0.2673788070678711,4.047645092010498,0.2487599998712539,4.2008957862854,50000 -377.2316431999207,0.6528291702270508,10235.736080169678,30412,0,10235.736080169678,0.1219000071287155,5.848219871520996,10000,10614.611628293993,0.1772759854793548,5.089120864868164,0.161420002579689,5.261129379272461,50000 -394.8913614749909,0.6915583610534668,10745.85445523262,31937,0,10745.85445523262,0.1275000125169754,5.495242595672607,10000,11142.47832274437,0.1874202787876129,4.548458099365234,0.1766799986362457,4.695439338684082,50000 -412.54050064086914,0.7317273616790771,11255.838641881945,33462,0,11255.838641881945,0.1143000051379203,6.097242832183838,10000,11670.20167684555,0.177973523736,5.004164218902588,0.1741800010204315,5.03352689743042,50000 -430.7237157821655,0.7693831920623779,11765.846685171127,34987,0,11765.846685171127,0.1218000054359436,5.849211692810059,10000,12198.48053264618,0.1864835768938064,4.9204301834106445,0.179639995098114,5.009549140930176,50000 -448.6538984775543,0.8097381591796875,12275.982605934145,36513,0,12275.982605934145,0.2223000079393386,4.251679420471191,10000,12726.636957645416,0.3043088316917419,3.459326267242432,0.290719985961914,3.5893800258636475,50000 -466.66515278816223,0.8490216732025146,12786.193783521652,38039,0,12786.193783521652,0.1794000118970871,4.738935470581055,10000,13254.94819188118,0.2504782974720001,3.9678797721862793,0.2319799959659576,4.118331909179688,50000 -484.4862456321716,0.8864550590515137,13296.231772899628,39563,0,13296.231772899628,0.0627000033855438,6.510247230529785,10000,13782.895034313202,0.0961415767669677,5.963015556335449,0.0889599993824958,6.041711807250977,50000 -502.60431265830994,0.9239518642425536,13806.275705337524,41089,0,13806.275705337524,0.0402000024914741,7.291037082672119,10000,14311.14434671402,0.0666852667927742,6.601611614227295,0.0598399974405765,6.742616653442383,50000 -520.4744794368744,0.961716651916504,14316.291334152222,42615,0,14316.291334152222,0.2201000154018402,4.310806751251221,10000,14839.117769956589,0.322963148355484,3.35842227935791,0.3021000027656555,3.5063271522521973,50000 -538.5152575969696,0.9980008602142334,14826.530678033829,44049,0,14826.530678033829,0.184900015592575,4.619652271270752,10000,15367.481809616089,0.2813097834587097,3.6863856315612793,0.2552799880504608,3.895809650421143,50000 -556.3334038257599,1.0387768745422363,15336.509346485138,45575,0,15336.509346485138,0.2482000142335891,4.0240983963012695,10000,15895.369309186935,0.3519212305545807,3.07643723487854,0.3336399793624878,3.243494272232056,50000 -573.8821873664856,1.0766091346740725,15846.612461805344,47101,0,15846.612461805344,0.1379000097513198,5.32618522644043,10000,16423.10874414444,0.2008330672979354,4.566847324371338,0.1885399967432022,4.639566898345947,50000 -591.6564452648163,1.1180686950683594,16356.553442955015,48627,0,16356.553442955015,0.2035000026226043,4.547379016876221,10000,16950.91544485092,0.2834024131298065,3.696853637695313,0.2582399845123291,3.9303948879241934,50000 -609.6140928268433,1.1585710048675537,16866.786379098892,50154,0,16866.786379098892,0.1962000131607055,4.586901664733887,10000,17479.196828603745,0.2775829136371612,3.7395527362823486,0.2506999969482422,3.980403423309326,50000 -627.7574996948242,1.1959059238433838,17376.93217921257,51681,0,17376.93217921257,0.2612999975681305,3.977926254272461,10000,18007.572848796844,0.3650350570678711,3.0436782836914062,0.3366599977016449,3.240746259689331,50000 -645.4750876426697,1.2352380752563477,17887.1039853096,53208,0,17887.1039853096,0.2336000055074691,4.169589042663574,10000,18535.551746606827,0.3320710957050323,3.289456367492676,0.3113200068473816,3.4361612796783447,50000 -663.5008449554443,1.2788963317871094,18397.301684379578,54735,0,18397.301684379578,0.1551000028848648,5.019404411315918,10000,19063.868786096573,0.2272998988628387,4.260939121246338,0.2242999970912933,4.226137161254883,50000 -681.4662253856659,1.322425127029419,18907.414145231247,56262,0,18907.414145231247,0.1467000097036361,5.59663200378418,10000,19592.039578437805,0.1985411345958709,4.842617511749268,0.1861799955368042,5.003037452697754,50000 -699.392019033432,1.363011598587036,19417.5827600956,57789,0,19417.5827600956,0.2788000106811523,3.648837566375733,10000,20120.224281072617,0.405652105808258,2.680437088012696,0.3668600022792816,2.947593688964844,50000 -717.2902994155884,1.4048044681549072,19927.652045965195,59316,0,19927.652045965195,0.284600019454956,3.6405067443847656,10000,20648.283844470978,0.4073062837123871,2.710867166519165,0.3699599802494049,2.94526481628418,50000 -735.1041345596313,1.4447991847991943,20437.66402554512,60843,0,20437.66402554512,0.1575000137090683,5.075519561767578,10000,21176.199315071102,0.2323022931814193,4.145813941955566,0.221119999885559,4.304041862487793,50000 -753.0401477813721,1.4879536628723145,20947.76675963401,62370,0,20947.76675963401,0.1399000138044357,5.52890157699585,10000,21704.33132839203,0.2048588991165161,4.66219425201416,0.1911199986934662,4.752766132354736,50000 -770.624137878418,1.533905029296875,21457.712785243988,63897,0,21457.712785243988,0.2681000232696533,3.77394700050354,10000,22231.957134723663,0.3743223845958709,2.886444091796875,0.3484399914741516,3.05894422531128,50000 -788.2158498764038,1.5760498046875,21967.712792396545,65424,0,21967.712792396545,0.2158000171184539,4.417821884155273,10000,22759.64086675644,0.307995855808258,3.498842716217041,0.284879982471466,3.684895277023315,50000 -805.8626515865326,1.61775541305542,22477.6426115036,66951,0,22477.6426115036,0.2025000154972076,4.604314804077148,10000,23287.309163093567,0.3006616532802582,3.679957628250122,0.2753999829292297,3.886249542236328,50000 -823.9564123153687,1.6605603694915771,22987.82523560524,68479,0,22987.82523560524,0.2540000081062317,3.8654260635375977,10000,23815.67833900452,0.3583784997463226,3.008211612701416,0.3360199928283691,3.1703059673309326,50000 -841.6116020679474,1.704545497894287,23497.832139492035,70006,0,23497.832139492035,0.214800015091896,4.404978275299072,10000,24343.43438887596,0.3116430044174194,3.471470355987549,0.2888000011444092,3.624599695205689,50000 -859.4020702838898,1.7458949089050293,24007.7573299408,71533,0,24007.7573299408,0.1634000092744827,5.040404319763184,10000,24871.24115753174,0.2196069806814193,4.325572967529297,0.2080599963665008,4.440552234649658,50000 -877.2425088882446,1.785841703414917,24517.887838363647,73061,0,24517.887838363647,0.2539000213146209,4.007626056671143,10000,25399.302145957947,0.3602519035339355,3.1127405166625977,0.338619977235794,3.271209239959717,50000 -894.9329364299774,1.829528570175171,25027.81852698326,74588,0,25027.81852698326,0.1949000060558319,4.602164268493652,10000,25927.017527341843,0.287488043308258,3.622225284576416,0.262719988822937,3.84192419052124,50000 -912.58482670784,1.876232624053955,25537.98294305801,76116,0,25537.98294305801,0.2672000229358673,3.9029479026794434,10000,26454.930029153824,0.3895487785339355,2.8953745365142822,0.3669599890708923,3.0703601837158203,50000 -930.2763559818268,1.9222619533538816,26048.15784502029,77644,0,26048.15784502029,0.3033000230789184,3.6137118339538574,10000,26982.89257240296,0.4274752736091614,2.703529357910156,0.3989799916744232,2.865347623825073,50000 -948.2588489055634,1.9660618305206297,26558.10323810577,79171,0,26558.10323810577,0.315200001001358,3.397288084030152,10000,27510.913942098618,0.445691168308258,2.4578397274017334,0.4142599999904632,2.646533966064453,50000 -966.1588563919069,2.0083022117614746,27068.290987730023,80699,0,27068.290987730023,0.3410000205039978,3.322407484054565,10000,28039.09343481064,0.4790935814380646,2.3329217433929443,0.4555400013923645,2.490853309631348,50000 -983.8452835083008,2.0514883995056152,27578.20561933517,82226,0,27578.20561933517,0.2016000151634216,4.77598237991333,10000,28566.787320137024,0.306740254163742,3.739010810852051,0.2848399877548218,3.9017951488494873,50000 -1001.8912694454192,2.093227863311768,28088.363196611404,83754,0,28088.363196611404,0.2864000201225281,3.7166709899902335,10000,29095.08299231529,0.4067083895206451,2.743592500686645,0.370739996433258,3.0049564838409424,50000 -1019.915079832077,2.1403162479400635,28598.490182876587,85282,0,28598.490182876587,0.1824000030755996,4.670032978057861,10000,29623.33106446266,0.2871890962123871,3.671375036239624,0.2664200067520141,3.834371089935303,50000 -1037.8243143558502,2.196073055267334,29108.441494226456,86809,0,29108.441494226456,0.2788000106811523,3.904389619827272,10000,30151.29687857628,0.3782086968421936,3.034822463989258,0.3511799871921539,3.27678656578064,50000 -1055.6407935619354,2.24765419960022,29618.62277984619,88337,0,29618.62277984619,0.2198000103235244,4.452763080596924,10000,30679.39599967003,0.3022560477256775,3.619099140167236,0.2789599895477295,3.844160556793213,50000 -1073.5947699546814,2.2941675186157227,30128.824651002884,89864,0,30128.824651002884,0.3614000082015991,3.230496644973755,10000,31207.64864993096,0.5006975531578064,2.200382947921753,0.4702999889850616,2.404654264450073,50000 -1092.3768472671509,2.3449578285217285,30638.732491970062,91391,0,30638.732491970062,0.2346000075340271,4.22288990020752,10000,31736.439160823826,0.3496890962123871,3.236147165298462,0.3291199803352356,3.4062445163726807,50000 -1110.28062915802,2.3948986530303955,31148.847939491272,92919,0,31148.847939491272,0.2875000238418579,3.9007530212402335,10000,32264.55832004547,0.4011878073215484,2.894664764404297,0.3725799918174743,3.125652313232422,50000 -1128.0174877643583,2.43977689743042,31658.8655025959,94446,0,31658.8655025959,0.325300008058548,3.4505057334899902,10000,32792.40767073631,0.460339605808258,2.431842565536499,0.4266199767589569,2.65660047531128,50000 -1145.614995956421,2.484170436859131,32168.929438829426,95974,0,32168.929438829426,0.1424000114202499,5.283186912536621,10000,33320.16366028786,0.2023078650236129,4.5234880447387695,0.1915399879217147,4.669929027557373,50000 -1163.6204626560211,2.5338714122772217,32678.92079949379,97501,0,32678.92079949379,0.2884000241756439,3.7778568267822266,10000,33848.26008415222,0.3995336294174194,2.832772731781006,0.3749800026416778,3.0190296173095703,50000 -1181.5033974647522,2.582381010055542,33189.129398584366,99029,0,33189.129398584366,0.3244000077247619,3.5730552673339844,10000,34376.450227975845,0.4629504084587097,2.432036638259888,0.4339199960231781,2.634920597076416,50000 -1199.5894927978516,2.630051374435425,33699.24235534668,100557,0,33699.24235534668,0.2790000140666961,4.015759944915772,10000,34904.74674367905,0.4109534323215484,2.8825113773345947,0.3824599981307983,3.089705467224121,50000 -1217.1798260211945,2.681713342666626,34209.41844010353,102085,0,34209.41844010353,0.347100019454956,3.2543532848358154,10000,35432.6150135994,0.4940409660339355,2.256666421890259,0.4540799856185913,2.4757895469665527,50000 -1235.0998673439026,2.732177972793579,34719.36843562126,103612,0,34719.36843562126,0.2385000139474868,4.211745738983154,10000,35960.58519363403,0.3447664082050323,3.2952005863189697,0.3284199833869934,3.4189095497131348,50000 -1252.949723482132,2.781178951263428,35229.46646118164,105140,0,35229.46646118164,0.3194000124931335,3.649909496307373,10000,36488.632420539856,0.4314213991165161,2.7033612728118896,0.3980000019073486,2.949122190475464,50000 -1271.185317993164,2.8301868438720703,35739.57062864304,106668,0,35739.57062864304,0.4075000286102295,2.7543745040893555,10000,37017.07039570808,0.562898576259613,1.828281283378601,0.5299199819564819,2.0261363983154297,50000 -1289.1059653759005,2.881099939346313,36249.67180633545,108196,0,36249.67180633545,0.359000027179718,3.246283531188965,10000,37545.1928293705,0.5349768996238708,1.9775022268295288,0.4815199971199035,2.32653546333313,50000 -1306.8855466842651,2.931341648101806,36759.76053881645,109724,0,36759.76053881645,0.3776000142097473,3.059465885162353,10000,38073.16121888161,0.5242546200752258,2.0457510948181152,0.4815599918365478,2.3079276084899902,50000 -1324.7671279907229,2.98144006729126,37269.95954012871,111252,0,37269.95954012871,0.3883000314235687,2.9401986598968506,10000,38601.34153342247,0.5350366830825806,1.9782723188400269,0.5010600090026855,2.1930699348449707,50000 -1342.4133460521698,3.02941370010376,37780.089233636856,112780,0,37780.089233636856,0.3680000305175781,3.1187973022460938,10000,39129.21529483795,0.5105628371238708,2.158193826675415,0.4708399772644043,2.4166712760925293,50000 -1360.0454897880554,3.0832841396331787,38290.16973924637,114308,0,38290.16973924637,0.3813000321388244,3.057910680770874,10000,39657.031764507294,0.5438257455825806,1.9540691375732424,0.5044999718666077,2.177619695663452,50000 -1378.0283389091492,3.134504556655884,38800.36287164688,115836,0,38800.36287164688,0.457500010728836,2.545141696929932,10000,40185.308972120285,0.6219108700752258,1.5522531270980835,0.5729599595069885,1.8273431062698364,50000 -1395.919328212738,3.18233060836792,39310.3577773571,117364,0,39310.3577773571,0.435200035572052,2.6567647457122803,10000,40713.292598724365,0.6141780614852905,1.578540563583374,0.551800012588501,1.9139325618743896,50000 -1413.4967761039734,3.233930826187134,39820.552035331726,118892,0,39820.552035331726,0.4172000288963318,2.772226333618164,10000,41241.165909051895,0.5724250674247742,1.76844584941864,0.526639997959137,2.034866571426392,50000 -1431.083517551422,3.284346342086792,40330.61541700363,120420,0,40330.61541700363,0.443200021982193,2.578282833099365,10000,41768.916761636734,0.6161710619926453,1.567164182662964,0.5663999915122986,1.838868498802185,50000 -1448.8743512630465,3.3349971771240234,40840.63898897171,121948,0,40840.63898897171,0.4316000342369079,2.7272050380706787,10000,42296.831644296646,0.5903021097183228,1.6958752870559692,0.5471799969673157,1.95624577999115,50000 -1466.5170137882233,3.384669065475464,41350.85647821426,123476,0,41350.85647821426,0.4378000199794769,2.6645984649658203,10000,42824.79166054726,0.6090362071990967,1.6039897203445437,0.5696600079536438,1.830997109413147,50000 -1484.1569118499756,3.4393198490142822,41860.99411845207,125004,0,41860.99411845207,0.4446000158786773,2.612132787704468,10000,43352.67394042015,0.6317163705825806,1.4943809509277344,0.5685999989509583,1.83841335773468,50000 -1501.814185142517,3.492196083068848,42371.09545564652,126532,0,42371.09545564652,0.4226000308990478,2.7985992431640625,10000,43880.53527808189,0.5991111397743225,1.6552882194519043,0.5497399568557739,1.9696252346038816,50000 -1519.5065422058103,3.5481600761413574,42881.19850087166,128060,0,42881.19850087166,0.4629000127315521,2.504871129989624,10000,44408.43687057495,0.6313177347183228,1.4843790531158447,0.5854200124740601,1.7508049011230469,50000 -1537.0586075782776,3.6010334491729736,43391.20353627205,129588,0,43391.20353627205,0.4502000212669372,2.617003917694092,10000,44936.09682822228,0.620137095451355,1.572115778923035,0.5705400109291077,1.8495970964431765,50000 -1554.9034173488617,3.656316041946411,43901.29053258896,131116,0,43901.29053258896,0.433100014925003,2.7411255836486816,10000,45464.13428735733,0.6050900816917419,1.635583519935608,0.5597000122070312,1.9223034381866453,50000 -1572.5753400325775,3.7136833667755127,44411.38590765,132644,0,44411.38590765,0.457500010728836,2.617147445678711,10000,45992.00921392441,0.6350247263908386,1.4889466762542725,0.5819799900054932,1.7952096462249756,50000 -1590.098914861679,3.770865440368652,44921.41864657402,134172,0,44921.41864657402,0.4809000194072723,2.401205778121948,10000,46519.67265796661,0.6775948405265808,1.2720587253570557,0.6111199855804443,1.6209615468978882,50000 -1607.9817507266998,3.824446439743042,45431.4598543644,135700,0,45431.4598543644,0.4642000198364258,2.5052618980407715,10000,47047.70048522949,0.6362802982330322,1.4815256595611572,0.5810799598693848,1.777900457382202,50000 -1625.8593764305117,3.8787858486175537,45941.48925709725,137228,0,45941.48925709725,0.497700035572052,2.3209965229034424,10000,47575.711570978165,0.672273576259613,1.3039746284484863,0.6099199652671814,1.6179298162460327,50000 -1643.7108445167542,3.932407379150391,46451.5012075901,138756,0,46451.5012075901,0.4925000369548797,2.3523638248443604,10000,48103.67853784561,0.6704400181770325,1.3004449605941772,0.6202799677848816,1.5722438097000122,50000 -1661.262745141983,3.98689866065979,46961.43891119957,140282,0,46961.43891119957,0.355400025844574,3.248229503631592,10000,48631.27277398109,0.4956353604793548,2.2361655235290527,0.4678599834442138,2.429154634475708,50000 -1679.1552288532257,4.043623924255371,47471.4825797081,141810,0,47471.4825797081,0.4934000372886657,2.3541271686553955,10000,49159.315249443054,0.7058752775192261,1.1420633792877195,0.6265599727630615,1.5559496879577637,50000 -1697.2029082775116,4.101156234741211,47981.5309278965,143338,0,47981.5309278965,0.5059000253677368,2.282422065734864,10000,49687.51860022545,0.7002949714660645,1.1588937044143677,0.6348199844360352,1.5233947038650513,50000 -1714.9963533878326,4.159027576446533,48491.59139537811,144866,0,48491.59139537811,0.5162000060081482,2.229666948318481,10000,50215.48052382469,0.7127311825752258,1.1116129159927368,0.642300009727478,1.4744572639465332,50000 -1732.4942715168,4.217477798461914,49001.64282894135,146394,0,49001.64282894135,0.5260000228881836,2.159235954284668,10000,50743.13808107376,0.7197863459587097,1.077343225479126,0.6534799933433533,1.424755334854126,50000 -1750.3120419979095,4.277265548706055,49511.83917331696,147922,0,49511.83917331696,0.526900053024292,2.1575634479522705,10000,51271.2620010376,0.7180524468421936,1.0938856601715088,0.6541599631309509,1.4333339929580688,50000 -1768.307421207428,4.336929559707642,50021.862213134766,149450,0,50021.862213134766,0.5319000482559204,2.132917881011963,10000,51799.39016842842,0.7408322691917419,0.9889224767684937,0.6670199632644653,1.3719351291656494,50000 -1785.8828003406525,4.391806602478027,50531.87656021118,150978,0,50531.87656021118,0.5469000339508057,2.0598459243774414,10000,52327.08454442024,0.7679368257522583,0.8675883412361145,0.6796999573707581,1.3259252309799194,50000 -1803.7932722568512,4.44978928565979,51042.05988812447,152506,0,51042.05988812447,0.5541000366210938,2.066560745239258,10000,52855.28619623184,0.7592275142669678,0.9073374271392822,0.6794599890708923,1.3260866403579712,50000 -1821.650677442551,4.508504867553711,51552.22501087189,154034,0,51552.22501087189,0.5524000525474548,2.03236985206604,10000,53383.41737794876,0.7609016299247742,0.9098778963088988,0.6835599541664124,1.3010308742523191,50000 -1839.724959373474,4.563934326171875,52062.35582947731,155562,0,52062.35582947731,0.5663000345230103,1.9566088914871216,10000,53911.72821640968,0.7763671875,0.8441674709320068,0.6982199549674988,1.2281646728515625,50000 -1857.1919219493864,4.625617980957031,52572.40886282921,157090,0,52572.40886282921,0.5633000135421753,1.987952828407288,10000,54439.35976409912,0.7760881781578064,0.8410550951957703,0.6970399618148804,1.2410316467285156,50000 -1875.256622314453,4.684526920318604,53082.36680340767,158618,0,53082.36680340767,0.579800009727478,1.921416163444519,10000,54967.49105381966,0.8107461333274841,0.691912055015564,0.7024999856948853,1.2118477821350098,50000 -1893.1282756328585,4.742824554443359,53592.58766222,160146,0,53592.58766222,0.5776000022888184,1.909941554069519,10000,55495.69156050682,0.8052654266357422,0.7151588201522827,0.7060399651527405,1.1937520503997805,50000 -1911.025577545166,4.800848007202148,54102.61943149567,161674,0,54102.61943149567,0.5815000534057617,1.902944564819336,10000,56023.72791719437,0.8055843114852905,0.7204368710517883,0.7084199786186218,1.1815768480300903,50000 -1928.901858329773,4.860100746154785,54612.54327106476,163201,0,54612.54327106476,0.5901000499725342,1.8701530694961548,10000,56551.63760781288,0.8162667155265808,0.6751604676246643,0.7181999683380127,1.147418975830078,50000 -1946.5123527050016,4.916692018508911,55122.71467757225,164729,0,55122.71467757225,0.5978000164031982,1.8753104209899905,10000,57079.52602314949,0.8210498690605164,0.6503438353538513,0.7199400067329407,1.1440032720565796,50000 -1964.1750729084013,4.975133895874023,55632.91140437126,166257,0,55632.91140437126,0.5962000489234924,1.851595997810364,10000,57607.49398231506,0.8339245915412903,0.5999334454536438,0.720579981803894,1.1352955102920532,50000 -1982.2370581626888,5.035584211349487,56142.81165266037,167784,0,56142.81165266037,0.5995000004768372,1.822930932044983,10000,58135.56666016579,0.8476362824440002,0.5580589175224304,0.7263399958610535,1.1028032302856443,50000 -2000.8068022727969,5.0948405265808105,56652.88702845573,169312,0,56652.88702845573,0.6077000498771667,1.800598382949829,10000,58664.32110500336,0.8518813848495483,0.5361858010292053,0.7312799692153931,1.0846476554870603,50000 -2018.65958237648,5.155996322631836,57162.94032907486,170840,0,57162.94032907486,0.6133000254631042,1.7768418788909912,10000,59192.33838939667,0.8516421914100647,0.5271843671798706,0.7345199584960938,1.0806338787078855,50000 -2036.3830211162567,5.2170140743255615,57673.15584611893,172368,0,57673.15584611893,0.6114000082015991,1.7621028423309326,10000,59720.38824558258,0.8561064600944519,0.51559978723526,0.7358399629592896,1.0692557096481323,50000 -2054.0894026756287,5.276186227798462,58183.20809054375,173896,0,58183.20809054375,0.6181000471115112,1.763548493385315,10000,60248.25574231148,0.8584582209587097,0.5045436024665833,0.7406799793243408,1.059166669845581,50000 -2071.8521168231964,5.343605995178223,58693.388276577,175424,0,58693.388276577,0.6177000403404236,1.7469100952148438,10000,60776.315844774246,0.8757573366165161,0.4480269849300384,0.7421999573707581,1.0490195751190186,50000 -2089.3616197109222,5.404559135437012,59203.56213974953,176952,0,59203.56213974953,0.6208000183105469,1.750800848007202,10000,61304.11006069183,0.8742027878761292,0.4469842612743377,0.7447999715805054,1.043401122093201,50000 -2106.9123561382294,5.46753716468811,59713.66224193573,178480,0,59713.66224193573,0.6232000589370728,1.740649938583374,10000,61831.87384080887,0.878348171710968,0.4332602918148041,0.7470600008964539,1.0316156148910522,50000 -2124.426818370819,5.530123233795166,60223.66440916061,180007,0,60223.66440916061,0.6262000203132629,1.7318052053451538,10000,62359.50334787369,0.8775510191917419,0.4316553771495819,0.7484999895095825,1.030386447906494,50000 -2142.2800900936127,5.592935800552368,60733.81407546997,181535,0,60733.81407546997,0.6265000104904175,1.7263579368591309,10000,62887.61942386627,0.883211076259613,0.4166988730430603,0.7495200037956238,1.027388334274292,50000 -2159.90277671814,5.653109788894653,61243.983364105225,183063,0,61243.983364105225,0.6269000172615051,1.7243411540985107,10000,63415.52132463455,0.8834103941917419,0.4161964356899261,0.7491599917411804,1.0240086317062378,50000 -2177.6936724185944,5.716378927230835,61753.94016075134,184590,0,61753.94016075134,0.6244000196456909,1.7232787609100342,10000,63943.381813287735,0.8830117583274841,0.4159443080425262,0.7496199607849121,1.0223222970962524,50000 -2195.205705881119,5.779400110244751,62264.0898706913,186118,0,62264.0898706913,0.6252000331878662,1.7218387126922607,10000,64471.15631699562,0.8834103941917419,0.4153479635715484,0.7500199675559998,1.0227417945861816,50000 -2213.0312852859497,5.842752933502197,62446.866736888885,186666,0,62446.866736888885,0.6258000135421753,1.7213950157165527,10000,64671.84037780762,0.8849848508834839,0.41121968626976013,0.7501599788665771,1.0221025943756104,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/measurements.csv deleted file mode 100644 index d495f3b48..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1993 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.66344464,6.9194984,,,,,,,,,,,,,, -1,,,0.0009167729294858,6.91040563583374,0.0009599999757483,6.910243988037109,50000.0,0.0006000000284984,6.910250186920166,10000.0,33.11548709869385,51.11273193359375,33.11548709869385,17.99710965156555,0.0,0.0 -100,0.7601842,6.6261435,,,,,,,,,,,,,, -200,0.9262269,6.2224007,,,,,,,,,,,,,, -300,2.5157611,5.9305778,,,,,,,,,,,,,, -400,2.5438652,5.675941,,,,,,,,,,,,,, -500,4.243795,5.4943447,,,,,,,,,,,,,, -600,3.8333776,5.335438,,,,,,,,,,,,,, -700,5.4438863,5.1705694,,,,,,,,,,,,,, -800,5.2511344,4.9309525,,,,,,,,,,,,,, -900,7.0671425,4.77565,,,,,,,,,,,,,, -1000,2.9867406,4.6419444,,,,,,,,,,,,,, -1100,2.3310397,4.602032,,,,,,,,,,,,,, -1200,2.7845414,4.4976683,,,,,,,,,,,,,, -1300,3.4910622,4.267277,,,,,,,,,,,,,, -1400,3.8865523,4.146941,,,,,,,,,,,,,, -1500,4.456491,4.07131,,,,,,,,,,,,,, -1515,,,0.1935586631298065,4.136880397796631,0.181439995765686,4.266019344329834,50000.0,0.1337999999523162,4.854593276977539,10000.0,543.3378028869629,579.1754071712494,543.3378028869629,35.76667332649231,0.0208404064178466,0.0 -1600,3.022937,4.000207,,,,,,,,,,,,,, -1700,2.9365256,3.8594542,,,,,,,,,,,,,, -1800,3.2997746,3.725142,,,,,,,,,,,,,, -1900,1.9644289,3.7295938,,,,,,,,,,,,,, -2000,2.494686,3.6215384,,,,,,,,,,,,,, -2100,1.5081024,3.5458653,,,,,,,,,,,,,, -2200,2.4457285,3.4643962,,,,,,,,,,,,,, -2300,1.9292928,3.507222,,,,,,,,,,,,,, -2400,1.2380189,3.266749,,,,,,,,,,,,,, -2500,2.0495877,3.3173437,,,,,,,,,,,,,, -2600,1.4405842,3.2978897,,,,,,,,,,,,,, -2700,1.0199248,3.2565703,,,,,,,,,,,,,, -2800,1.2615166,3.2269433,,,,,,,,,,,,,, -2900,1.2786475,3.218392,,,,,,,,,,,,,, -3000,1.3325933,3.0007062,,,,,,,,,,,,,, -3030,,,0.3678053319454193,2.9227986335754395,0.3442199826240539,3.063002109527588,50000.0,0.2606000006198883,3.755113124847412,10000.0,1053.532371044159,1107.5142476558683,1053.532371044159,53.82983016967773,0.052786111831665,0.0 -3100,2.1955154,3.159089,,,,,,,,,,,,,, -3200,1.0741245,3.0293994,,,,,,,,,,,,,, -3300,1.1120003,3.0249503,,,,,,,,,,,,,, -3400,1.2623445,3.0797665,,,,,,,,,,,,,, -3500,1.2239339,2.9407446,,,,,,,,,,,,,, -3600,1.0395179,2.9740052,,,,,,,,,,,,,, -3700,1.011735,2.9197016,,,,,,,,,,,,,, -3800,0.86937255,2.840313,,,,,,,,,,,,,, -3900,1.6027439,2.8777776,,,,,,,,,,,,,, -4000,0.860746,2.8489895,,,,,,,,,,,,,, -4100,0.9684569,2.9492111,,,,,,,,,,,,,, -4200,0.8857932,2.839797,,,,,,,,,,,,,, -4300,1.1907746,2.8780873,,,,,,,,,,,,,, -4400,1.1544917,2.8070264,,,,,,,,,,,,,, -4500,1.1185856,2.7731419,,,,,,,,,,,,,, -4547,,,0.409877210855484,2.6684727668762207,0.3856199979782104,2.8410749435424805,50000.0,0.2830000221729278,3.582597494125366,10000.0,1563.7822008132937,1635.7310452461245,1563.7822008132937,71.71818590164185,0.0823340415954589,0.0 -4600,0.8466758,2.677763,,,,,,,,,,,,,, -4700,0.91452056,2.7493556,,,,,,,,,,,,,, -4800,1.0936269,2.7047434,,,,,,,,,,,,,, -4900,0.77820086,2.7208335,,,,,,,,,,,,,, -5000,0.9872544,2.6415915,,,,,,,,,,,,,, -5100,0.76156735,2.6717381,,,,,,,,,,,,,, -5200,0.8703763,2.6688178,,,,,,,,,,,,,, -5300,0.8850041,2.7540205,,,,,,,,,,,,,, -5400,1.0073792,2.6130738,,,,,,,,,,,,,, -5500,0.9192385,2.601632,,,,,,,,,,,,,, -5600,0.83027524,2.7230065,,,,,,,,,,,,,, -5700,0.8174405,2.556299,,,,,,,,,,,,,, -5800,0.90187573,2.5921946,,,,,,,,,,,,,, -5900,1.0968889,2.60271,,,,,,,,,,,,,, -6000,0.91704065,2.56205,,,,,,,,,,,,,, -6064,,,0.3488520383834839,3.0972437858581543,0.3138200044631958,3.3472726345062256,50000.0,0.2252000123262405,4.131947994232178,10000.0,2073.916249513626,2164.014529466629,2073.916249513626,89.78872632980347,0.1114284992218017,0.0 -6100,0.99097264,2.5500765,,,,,,,,,,,,,, -6200,0.89848167,2.5509214,,,,,,,,,,,,,, -6300,0.91995114,2.702902,,,,,,,,,,,,,, -6400,0.9750557,2.6687813,,,,,,,,,,,,,, -6500,0.8653617,2.6542447,,,,,,,,,,,,,, -6600,1.289297,2.5477214,,,,,,,,,,,,,, -6700,0.85397583,2.6034925,,,,,,,,,,,,,, -6800,0.84921443,2.5540442,,,,,,,,,,,,,, -6900,1.0441005,2.54679,,,,,,,,,,,,,, -7000,1.0843856,2.587129,,,,,,,,,,,,,, -7100,0.9519341,2.6579595,,,,,,,,,,,,,, -7200,1.0796114,2.6783738,,,,,,,,,,,,,, -7300,0.90445215,2.4934974,,,,,,,,,,,,,, -7400,1.0662235,2.640772,,,,,,,,,,,,,, -7500,0.9204555,2.599927,,,,,,,,,,,,,, -7581,,,0.291693240404129,3.627654552459717,0.2701799869537353,3.8371315002441406,50000.0,0.2040000110864639,4.534960746765137,10000.0,2583.992983341217,2692.171564102173,2583.992983341217,107.7860050201416,0.1452906131744384,0.0 -7600,0.9240144,2.5167012,,,,,,,,,,,,,, -7700,1.0792023,2.4800994,,,,,,,,,,,,,, -7800,1.0351455,2.4826412,,,,,,,,,,,,,, -7900,1.0205649,2.5742397,,,,,,,,,,,,,, -8000,0.9645947,2.4255712,,,,,,,,,,,,,, -8100,0.99980086,2.495604,,,,,,,,,,,,,, -8200,0.9857035,2.6179862,,,,,,,,,,,,,, -8300,0.85827243,2.4591966,,,,,,,,,,,,,, -8400,0.9582239,2.5014117,,,,,,,,,,,,,, -8500,1.00211,2.4128776,,,,,,,,,,,,,, -8600,0.9306888,2.448164,,,,,,,,,,,,,, -8700,1.0589285,2.4502296,,,,,,,,,,,,,, -8800,1.0038171,2.5310173,,,,,,,,,,,,,, -8900,0.89129204,2.5433764,,,,,,,,,,,,,, -9000,1.0177574,2.6182876,,,,,,,,,,,,,, -9099,,,0.0675223171710968,6.560290813446045,0.0609200000762939,6.683966159820557,50000.0,0.0473000034689903,7.12650203704834,10000.0,3094.0958971977234,3220.553163051605,3094.0958971977234,125.98636031150818,0.1741261482238769,0.0 -9100,0.9788124,2.3702462,,,,,,,,,,,,,, -9200,1.0035299,2.423063,,,,,,,,,,,,,, -9300,1.0218054,2.5087616,,,,,,,,,,,,,, -9400,0.89880925,2.414271,,,,,,,,,,,,,, -9500,1.0353796,2.3546858,,,,,,,,,,,,,, -9600,0.89598626,2.5448103,,,,,,,,,,,,,, -9700,1.0023733,2.5439873,,,,,,,,,,,,,, -9800,0.9171606,2.3853283,,,,,,,,,,,,,, -9900,0.9806048,2.467494,,,,,,,,,,,,,, -10000,1.0175625,2.47308,,,,,,,,,,,,,, -10100,1.0363152,2.4494412,,,,,,,,,,,,,, -10200,1.0432148,2.5368328,,,,,,,,,,,,,, -10300,1.0255313,2.467517,,,,,,,,,,,,,, -10400,0.9066514,2.4961286,,,,,,,,,,,,,, -10500,1.0664347,2.3855536,,,,,,,,,,,,,, -10600,0.9946156,2.31742,,,,,,,,,,,,,, -10618,,,0.3647162020206451,3.049172878265381,0.3426199853420257,3.222270965576172,50000.0,0.2546000182628631,3.9828693866729736,10000.0,3604.2537105083466,3748.795606613159,3604.2537105083466,143.99125719070437,0.2043311595916748,0.0 -10700,0.9898416,2.4344578,,,,,,,,,,,,,, -10800,0.9979726,2.4907513,,,,,,,,,,,,,, -10900,0.9657495,2.4213328,,,,,,,,,,,,,, -11000,0.96676207,2.3963032,,,,,,,,,,,,,, -11100,0.99835026,2.4295084,,,,,,,,,,,,,, -11200,1.0862163,2.3008142,,,,,,,,,,,,,, -11300,1.0031433,2.428119,,,,,,,,,,,,,, -11400,1.0429077,2.5544553,,,,,,,,,,,,,, -11500,0.8226736,2.4063616,,,,,,,,,,,,,, -11600,0.988427,2.6187286,,,,,,,,,,,,,, -11700,1.0395187,2.487824,,,,,,,,,,,,,, -11800,1.303913,2.4131145,,,,,,,,,,,,,, -11900,1.0461701,2.4630344,,,,,,,,,,,,,, -12000,1.247397,2.3489127,,,,,,,,,,,,,, -12100,1.1095893,2.3057575,,,,,,,,,,,,,, -12138,,,0.2653858363628387,3.797953367233277,0.2460999935865402,3.96741247177124,50000.0,0.1862000077962875,4.577045440673828,10000.0,4114.4356808662415,4277.1030423641205,4114.4356808662415,162.03544116020203,0.2363069057464599,0.0 -12200,1.0468978,2.2898989,,,,,,,,,,,,,, -12300,0.9283813,2.279522,,,,,,,,,,,,,, -12400,1.1670592,2.4539669,,,,,,,,,,,,,, -12500,1.0791883,2.4349475,,,,,,,,,,,,,, -12600,0.9272256,2.404419,,,,,,,,,,,,,, -12700,0.9619364,2.3169923,,,,,,,,,,,,,, -12800,0.92554927,2.375463,,,,,,,,,,,,,, -12900,0.9894491,2.415782,,,,,,,,,,,,,, -13000,1.1504703,2.5415344,,,,,,,,,,,,,, -13100,0.8935953,2.4440446,,,,,,,,,,,,,, -13200,1.1767542,2.332811,,,,,,,,,,,,,, -13300,1.1555077,2.3788974,,,,,,,,,,,,,, -13400,1.0581013,2.5063384,,,,,,,,,,,,,, -13500,1.0058328,2.2899566,,,,,,,,,,,,,, -13600,1.1337761,2.3563685,,,,,,,,,,,,,, -13658,,,0.1414421200752258,5.066132068634033,0.1360199898481369,5.142930030822754,50000.0,0.0982000082731247,5.658753395080566,10000.0,4624.410368680954,4805.513270378113,4624.410368680954,180.38438057899475,0.2737338542938232,0.0 -13700,0.90298545,2.3931372,,,,,,,,,,,,,, -13800,0.99291116,2.2939925,,,,,,,,,,,,,, -13900,1.0609046,2.4299066,,,,,,,,,,,,,, -14000,1.0374941,2.4784684,,,,,,,,,,,,,, -14100,1.0865828,2.5594254,,,,,,,,,,,,,, -14200,1.107779,2.592993,,,,,,,,,,,,,, -14300,1.0796185,2.4391766,,,,,,,,,,,,,, -14400,1.0256163,2.266693,,,,,,,,,,,,,, -14500,1.0092365,2.305131,,,,,,,,,,,,,, -14600,1.1551261,2.4057155,,,,,,,,,,,,,, -14700,1.04229,2.5717595,,,,,,,,,,,,,, -14800,1.0033038,2.2758784,,,,,,,,,,,,,, -14900,0.94587713,2.3434665,,,,,,,,,,,,,, -15000,0.9928219,2.3551276,,,,,,,,,,,,,, -15100,0.9934149,2.5259092,,,,,,,,,,,,,, -15178,,,0.2964963316917419,3.482040882110596,0.2676399946212768,3.7009527683258057,50000.0,0.2077000141143798,4.33635950088501,10000.0,5134.47070813179,5333.526302337647,5134.47070813179,198.250910282135,0.3105723857879638,0.0 -15200,0.9151435,2.3673918,,,,,,,,,,,,,, -15300,0.9435786,2.3550794,,,,,,,,,,,,,, -15400,1.0909336,2.3145306,,,,,,,,,,,,,, -15500,1.1181068,2.40034,,,,,,,,,,,,,, -15600,1.0468397,2.4656024,,,,,,,,,,,,,, -15700,1.0450279,2.3267174,,,,,,,,,,,,,, -15800,1.0581582,2.375967,,,,,,,,,,,,,, -15900,1.023724,2.4068835,,,,,,,,,,,,,, -16000,1.0268213,2.33533,,,,,,,,,,,,,, -16100,1.0437037,2.3657074,,,,,,,,,,,,,, -16200,0.97758514,2.4749937,,,,,,,,,,,,,, -16300,0.96368575,2.3644652,,,,,,,,,,,,,, -16400,1.0595013,2.3913171,,,,,,,,,,,,,, -16500,1.0039574,2.3518586,,,,,,,,,,,,,, -16600,1.1195749,2.249255,,,,,,,,,,,,,, -16699,,,0.2660634517669678,3.838587760925293,0.2495799958705902,3.98842978477478,50000.0,0.1782000064849853,4.823949813842773,10000.0,5644.488646507263,5861.574481487274,5644.488646507263,216.19500756263733,0.347388744354248,0.0 -16700,0.99501073,2.3210201,,,,,,,,,,,,,, -16800,1.0217929,2.4231014,,,,,,,,,,,,,, -16900,1.138634,2.2897289,,,,,,,,,,,,,, -17000,1.0728316,2.4611115,,,,,,,,,,,,,, -17100,1.0210694,2.367828,,,,,,,,,,,,,, -17200,1.0926762,2.4504027,,,,,,,,,,,,,, -17300,1.0364234,2.396396,,,,,,,,,,,,,, -17400,1.1890035,2.396006,,,,,,,,,,,,,, -17500,1.1225337,2.3727117,,,,,,,,,,,,,, -17600,1.0295111,2.426317,,,,,,,,,,,,,, -17700,1.0799338,2.4350238,,,,,,,,,,,,,, -17800,1.1740134,2.4480903,,,,,,,,,,,,,, -17900,1.0510417,2.351195,,,,,,,,,,,,,, -18000,1.0882049,2.36311,,,,,,,,,,,,,, -18100,1.0503185,2.311427,,,,,,,,,,,,,, -18200,1.035147,2.4585373,,,,,,,,,,,,,, -18220,,,0.1853675097227096,4.6938252449035645,0.1709599941968917,4.821983337402344,50000.0,0.1228000074625015,5.50702428817749,10000.0,6154.5932059288025,6389.560833454132,6154.5932059288025,233.9950804710388,0.3794333934783935,0.0 -18300,1.0304377,2.2776885,,,,,,,,,,,,,, -18400,1.0097141,2.2215948,,,,,,,,,,,,,, -18500,1.1598216,2.3502507,,,,,,,,,,,,,, -18600,1.0370405,2.391615,,,,,,,,,,,,,, -18700,0.93887997,2.2900465,,,,,,,,,,,,,, -18800,1.0727623,2.4502175,,,,,,,,,,,,,, -18900,0.9905823,2.2942054,,,,,,,,,,,,,, -19000,0.9976333,2.2533038,,,,,,,,,,,,,, -19100,0.9739562,2.198058,,,,,,,,,,,,,, -19200,1.0411077,2.4094877,,,,,,,,,,,,,, -19300,1.009051,2.4296055,,,,,,,,,,,,,, -19400,1.1306283,2.3091657,,,,,,,,,,,,,, -19500,1.0233576,2.3511634,,,,,,,,,,,,,, -19600,0.9825148,2.2868507,,,,,,,,,,,,,, -19700,1.0833042,2.430819,,,,,,,,,,,,,, -19743,,,0.3005022406578064,3.5624327659606934,0.2856999933719635,3.7141499519348153,50000.0,0.2044000029563903,4.545632839202881,10000.0,6664.773620843887,6917.872552871704,6664.773620843887,252.04337310791016,0.4124011993408203,0.0 -19800,1.107991,2.4014943,,,,,,,,,,,,,, -19900,1.1659046,2.3234658,,,,,,,,,,,,,, -20000,1.1518536,2.4554753,,,,,,,,,,,,,, -20100,0.960918,2.3627071,,,,,,,,,,,,,, -20200,1.1623003,2.550498,,,,,,,,,,,,,, -20300,1.040656,2.2613678,,,,,,,,,,,,,, -20400,1.2262015,2.3041282,,,,,,,,,,,,,, -20500,1.1196747,2.3678885,,,,,,,,,,,,,, -20600,1.2590107,2.2891402,,,,,,,,,,,,,, -20700,1.0612159,2.3343582,,,,,,,,,,,,,, -20800,1.0246898,2.4588518,,,,,,,,,,,,,, -20900,1.166275,2.3099642,,,,,,,,,,,,,, -21000,1.0206138,2.4661655,,,,,,,,,,,,,, -21100,1.1291342,2.3950078,,,,,,,,,,,,,, -21200,1.0296601,2.3090403,,,,,,,,,,,,,, -21265,,,0.2461535334587097,4.195694923400879,0.2303600013256073,4.367095470428467,50000.0,0.1683000028133392,5.043724060058594,10000.0,7174.803897380829,7445.765621185303,7174.803897380829,269.8241550922394,0.4445338249206543,0.0 -21300,1.0643479,2.3543224,,,,,,,,,,,,,, -21400,1.0436708,2.3029907,,,,,,,,,,,,,, -21500,1.013088,2.3287308,,,,,,,,,,,,,, -21600,1.0675213,2.2896976,,,,,,,,,,,,,, -21700,1.0668341,2.3720546,,,,,,,,,,,,,, -21800,1.063422,2.2639594,,,,,,,,,,,,,, -21900,1.1230838,2.2738607,,,,,,,,,,,,,, -22000,1.0097294,2.2165766,,,,,,,,,,,,,, -22100,1.2052544,2.2931092,,,,,,,,,,,,,, -22200,0.9936671,2.348809,,,,,,,,,,,,,, -22300,1.2445483,2.3623786,,,,,,,,,,,,,, -22400,1.3270904,2.3710265,,,,,,,,,,,,,, -22500,1.177975,2.3364794,,,,,,,,,,,,,, -22600,1.0446303,2.215774,,,,,,,,,,,,,, -22700,1.3000889,2.3448503,,,,,,,,,,,,,, -22789,,,0.006437340285629,10.417054176330566,0.005719999782741,10.540876388549805,50000.0,0.0043000001460313,10.739754676818848,10000.0,7685.032256126404,7973.861493825912,7685.032256126404,287.60385155677795,0.482598066329956,0.0 -22800,1.0154666,2.2133293,,,,,,,,,,,,,, -22900,1.0322568,2.276514,,,,,,,,,,,,,, -23000,1.0582783,2.3315299,,,,,,,,,,,,,, -23100,1.1347464,2.5382512,,,,,,,,,,,,,, -23200,1.1215287,2.4063156,,,,,,,,,,,,,, -23300,1.2671442,2.3674715,,,,,,,,,,,,,, -23400,1.1087132,2.2562003,,,,,,,,,,,,,, -23500,1.0481113,2.2521515,,,,,,,,,,,,,, -23600,1.0492201,2.3875024,,,,,,,,,,,,,, -23700,1.0244181,2.2854204,,,,,,,,,,,,,, -23800,1.0921751,2.3438041,,,,,,,,,,,,,, -23900,1.1054776,2.3286421,,,,,,,,,,,,,, -24000,1.0540898,2.2048478,,,,,,,,,,,,,, -24100,1.0497231,2.4218745,,,,,,,,,,,,,, -24200,1.11785,2.4326913,,,,,,,,,,,,,, -24300,1.156436,2.4260178,,,,,,,,,,,,,, -24313,,,0.3395248651504516,3.138223648071289,0.3097600042819977,3.3382935523986816,50000.0,0.2270000129938125,4.0918073654174805,10000.0,8195.057115793228,8501.879445791245,8195.057115793228,305.51457262039185,0.5152654647827148,0.0 -24400,1.0662614,2.3427794,,,,,,,,,,,,,, -24500,1.1370885,2.3978753,,,,,,,,,,,,,, -24600,1.0936507,2.3136053,,,,,,,,,,,,,, -24700,1.0727413,2.3697577,,,,,,,,,,,,,, -24800,1.0979344,2.2678156,,,,,,,,,,,,,, -24900,1.0652878,2.2896528,,,,,,,,,,,,,, -25000,1.0059093,2.4371881,,,,,,,,,,,,,, -25100,1.1137978,2.2395809,,,,,,,,,,,,,, -25200,1.0333657,2.2937553,,,,,,,,,,,,,, -25300,1.0950563,2.3020697,,,,,,,,,,,,,, -25400,1.1325078,2.354395,,,,,,,,,,,,,, -25500,1.1764435,2.2302382,,,,,,,,,,,,,, -25600,1.0793077,2.3202605,,,,,,,,,,,,,, -25700,1.1287528,2.2028222,,,,,,,,,,,,,, -25800,1.198569,2.4180021,,,,,,,,,,,,,, -25837,,,0.3012396395206451,3.451519012451172,0.2791000008583069,3.64577889442444,50000.0,0.2086000144481659,4.288328170776367,10000.0,8705.275005102158,9029.854134559631,8705.275005102158,323.1879549026489,0.5490307807922363,0.0 -25900,1.119821,2.381172,,,,,,,,,,,,,, -26000,1.1101414,2.304091,,,,,,,,,,,,,, -26100,1.1708378,2.389518,,,,,,,,,,,,,, -26200,1.0113308,2.293298,,,,,,,,,,,,,, -26300,1.20259,2.4054067,,,,,,,,,,,,,, -26400,1.0573364,2.364298,,,,,,,,,,,,,, -26500,1.1166635,2.3532095,,,,,,,,,,,,,, -26600,1.202517,2.4006028,,,,,,,,,,,,,, -26700,1.0613838,2.3704426,,,,,,,,,,,,,, -26800,1.0458794,2.2691288,,,,,,,,,,,,,, -26900,1.1986438,2.2871482,,,,,,,,,,,,,, -27000,1.1234164,2.1620638,,,,,,,,,,,,,, -27100,1.1215239,2.5386558,,,,,,,,,,,,,, -27200,1.1076993,2.2804284,,,,,,,,,,,,,, -27300,1.1326375,2.2688656,,,,,,,,,,,,,, -27362,,,0.0728037282824516,8.18244743347168,0.0666399970650672,8.396821975708008,50000.0,0.0480000004172325,9.050444602966309,10000.0,9215.481248617172,9558.02182650566,9215.481248617172,341.062353849411,0.5862481594085693,0.0 -27400,1.1821109,2.3750062,,,,,,,,,,,,,, -27500,1.1085263,2.3483481,,,,,,,,,,,,,, -27600,1.1646671,2.2982092,,,,,,,,,,,,,, -27700,1.1539794,2.3429186,,,,,,,,,,,,,, -27800,1.201353,2.3542275,,,,,,,,,,,,,, -27900,1.1603037,2.2788856,,,,,,,,,,,,,, -28000,1.138704,2.2746587,,,,,,,,,,,,,, -28100,1.0584601,2.5000486,,,,,,,,,,,,,, -28200,1.0924028,2.2753315,,,,,,,,,,,,,, -28300,1.2532625,2.2803383,,,,,,,,,,,,,, -28400,1.1110773,2.34449,,,,,,,,,,,,,, -28500,1.1001046,2.2275405,,,,,,,,,,,,,, -28600,1.1814204,2.246149,,,,,,,,,,,,,, -28700,1.2177026,2.3376844,,,,,,,,,,,,,, -28800,1.0865407,2.372217,,,,,,,,,,,,,, -28887,,,0.2673788070678711,4.047645092010498,0.2487599998712539,4.2008957862854,50000.0,0.1783000081777572,5.061367988586426,10000.0,9725.596413373947,10086.723033189774,9725.596413373947,359.5661907196045,0.6190822124481201,0.0 -28900,1.1567644,2.378806,,,,,,,,,,,,,, -29000,1.0016904,2.2343113,,,,,,,,,,,,,, -29100,1.0311997,2.2583592,,,,,,,,,,,,,, -29200,1.190627,2.2507305,,,,,,,,,,,,,, -29300,1.1521589,2.158455,,,,,,,,,,,,,, -29400,1.1249585,2.2739522,,,,,,,,,,,,,, -29500,1.080335,2.28934,,,,,,,,,,,,,, -29600,1.2158862,2.2994869,,,,,,,,,,,,,, -29700,1.2482388,2.3008478,,,,,,,,,,,,,, -29800,1.0264773,2.2999935,,,,,,,,,,,,,, -29900,1.1684362,2.3876147,,,,,,,,,,,,,, -30000,1.1378287,2.3807952,,,,,,,,,,,,,, -30100,1.0826213,2.313408,,,,,,,,,,,,,, -30200,1.3081186,2.3155718,,,,,,,,,,,,,, -30300,1.0851644,2.460369,,,,,,,,,,,,,, -30400,1.11158,2.3137593,,,,,,,,,,,,,, -30412,,,0.1772759854793548,5.089120864868164,0.161420002579689,5.261129379272461,50000.0,0.1219000071287155,5.848219871520996,10000.0,10235.736080169678,10614.611628293993,10235.736080169678,377.2316431999207,0.6528291702270508,0.0 -30500,1.2224296,2.3603146,,,,,,,,,,,,,, -30600,1.0644476,2.2633579,,,,,,,,,,,,,, -30700,1.2560556,2.3809109,,,,,,,,,,,,,, -30800,1.0788236,2.3213625,,,,,,,,,,,,,, -30900,1.0863054,2.3081942,,,,,,,,,,,,,, -31000,1.1509017,2.2822468,,,,,,,,,,,,,, -31100,1.1870801,2.392474,,,,,,,,,,,,,, -31200,1.0865899,2.2084916,,,,,,,,,,,,,, -31300,1.120309,2.2042563,,,,,,,,,,,,,, -31400,1.1257826,2.3287864,,,,,,,,,,,,,, -31500,1.2454714,2.2424943,,,,,,,,,,,,,, -31600,1.3963575,2.2339468,,,,,,,,,,,,,, -31700,1.2175292,2.289865,,,,,,,,,,,,,, -31800,1.2342532,2.484667,,,,,,,,,,,,,, -31900,1.1331452,2.4198809,,,,,,,,,,,,,, -31937,,,0.1874202787876129,4.548458099365234,0.1766799986362457,4.695439338684082,50000.0,0.1275000125169754,5.495242595672607,10000.0,10745.85445523262,11142.47832274437,10745.85445523262,394.8913614749909,0.6915583610534668,0.0 -32000,1.1531894,2.435892,,,,,,,,,,,,,, -32100,1.121025,2.214783,,,,,,,,,,,,,, -32200,1.1276481,2.2957935,,,,,,,,,,,,,, -32300,1.1648659,2.1746676,,,,,,,,,,,,,, -32400,1.1422178,2.2349627,,,,,,,,,,,,,, -32500,1.0873439,2.2979794,,,,,,,,,,,,,, -32600,1.3865246,2.3400729,,,,,,,,,,,,,, -32700,1.1567343,2.2617533,,,,,,,,,,,,,, -32800,1.1455885,2.322612,,,,,,,,,,,,,, -32900,1.2777091,2.26341,,,,,,,,,,,,,, -33000,1.2234901,2.2728424,,,,,,,,,,,,,, -33100,1.2218484,2.313364,,,,,,,,,,,,,, -33200,1.1246867,2.3448524,,,,,,,,,,,,,, -33300,1.2355137,2.319494,,,,,,,,,,,,,, -33400,1.1241993,2.2385526,,,,,,,,,,,,,, -33462,,,0.177973523736,5.004164218902588,0.1741800010204315,5.03352689743042,50000.0,0.1143000051379203,6.097242832183838,10000.0,11255.838641881945,11670.20167684555,11255.838641881945,412.54050064086914,0.7317273616790771,0.0 -33500,1.1584715,2.187743,,,,,,,,,,,,,, -33600,1.0531058,2.2618437,,,,,,,,,,,,,, -33700,1.0616829,2.3790514,,,,,,,,,,,,,, -33800,1.0592078,2.2336428,,,,,,,,,,,,,, -33900,1.1915439,2.3909826,,,,,,,,,,,,,, -34000,1.1163182,2.4297333,,,,,,,,,,,,,, -34100,1.1590236,2.302815,,,,,,,,,,,,,, -34200,1.1895301,2.3473148,,,,,,,,,,,,,, -34300,1.2832733,2.409339,,,,,,,,,,,,,, -34400,1.2175312,2.3556128,,,,,,,,,,,,,, -34500,1.1845671,2.3876655,,,,,,,,,,,,,, -34600,1.1231135,2.2964551,,,,,,,,,,,,,, -34700,1.3302573,2.2860005,,,,,,,,,,,,,, -34800,1.2661631,2.3688006,,,,,,,,,,,,,, -34900,1.19675,2.382139,,,,,,,,,,,,,, -34987,,,0.1864835768938064,4.9204301834106445,0.179639995098114,5.009549140930176,50000.0,0.1218000054359436,5.849211692810059,10000.0,11765.846685171127,12198.48053264618,11765.846685171127,430.7237157821655,0.7693831920623779,0.0 -35000,1.1105429,2.375977,,,,,,,,,,,,,, -35100,1.1078378,2.3057718,,,,,,,,,,,,,, -35200,1.1290904,2.2155418,,,,,,,,,,,,,, -35300,1.1668333,2.280029,,,,,,,,,,,,,, -35400,1.0650036,2.3031275,,,,,,,,,,,,,, -35500,1.1086878,2.307689,,,,,,,,,,,,,, -35600,1.1242745,2.2660315,,,,,,,,,,,,,, -35700,1.1335534,2.245011,,,,,,,,,,,,,, -35800,1.1500418,2.3337657,,,,,,,,,,,,,, -35900,1.1358947,2.2917035,,,,,,,,,,,,,, -36000,1.2335484,2.3062313,,,,,,,,,,,,,, -36100,1.1584026,2.2066271,,,,,,,,,,,,,, -36200,1.0893723,2.285438,,,,,,,,,,,,,, -36300,1.3098596,2.3613436,,,,,,,,,,,,,, -36400,1.1695845,2.3726354,,,,,,,,,,,,,, -36500,1.3491158,2.4083452,,,,,,,,,,,,,, -36513,,,0.3043088316917419,3.459326267242432,0.290719985961914,3.5893800258636475,50000.0,0.2223000079393386,4.251679420471191,10000.0,12275.982605934145,12726.636957645416,12275.982605934145,448.6538984775543,0.8097381591796875,0.0 -36600,1.1197411,2.0658445,,,,,,,,,,,,,, -36700,1.4632589,2.319034,,,,,,,,,,,,,, -36800,1.1590935,2.2238383,,,,,,,,,,,,,, -36900,1.2132396,2.323768,,,,,,,,,,,,,, -37000,1.0874374,2.2269936,,,,,,,,,,,,,, -37100,1.2057691,2.2191613,,,,,,,,,,,,,, -37200,1.3767688,2.3225653,,,,,,,,,,,,,, -37300,1.0543247,2.2767324,,,,,,,,,,,,,, -37400,1.2044786,2.2341876,,,,,,,,,,,,,, -37500,1.2046664,2.3266964,,,,,,,,,,,,,, -37600,1.0936103,2.2440286,,,,,,,,,,,,,, -37700,1.2708516,2.3775115,,,,,,,,,,,,,, -37800,1.2988465,2.2266145,,,,,,,,,,,,,, -37900,1.2467743,2.4196107,,,,,,,,,,,,,, -38000,1.1766177,2.2319512,,,,,,,,,,,,,, -38039,,,0.2504782974720001,3.9678797721862793,0.2319799959659576,4.118331909179688,50000.0,0.1794000118970871,4.738935470581055,10000.0,12786.193783521652,13254.94819188118,12786.193783521652,466.66515278816223,0.8490216732025146,0.0 -38100,1.2926064,2.2233324,,,,,,,,,,,,,, -38200,1.2195928,2.295009,,,,,,,,,,,,,, -38300,1.2845927,2.3770404,,,,,,,,,,,,,, -38400,1.1113629,2.1817493,,,,,,,,,,,,,, -38500,1.284692,2.321534,,,,,,,,,,,,,, -38600,1.2060713,2.2971754,,,,,,,,,,,,,, -38700,1.2115223,2.215014,,,,,,,,,,,,,, -38800,1.1053066,2.238256,,,,,,,,,,,,,, -38900,1.2776713,2.3385315,,,,,,,,,,,,,, -39000,1.1426592,2.1591234,,,,,,,,,,,,,, -39100,1.1606023,2.4849072,,,,,,,,,,,,,, -39200,1.1728865,2.3421805,,,,,,,,,,,,,, -39300,1.1517736,2.2561843,,,,,,,,,,,,,, -39400,1.337415,2.2394092,,,,,,,,,,,,,, -39500,1.2901509,2.3120136,,,,,,,,,,,,,, -39563,,,0.0961415767669677,5.963015556335449,0.0889599993824958,6.041711807250977,50000.0,0.0627000033855438,6.510247230529785,10000.0,13296.231772899628,13782.895034313202,13296.231772899628,484.4862456321716,0.8864550590515137,0.0 -39600,1.1690035,2.3008509,,,,,,,,,,,,,, -39700,1.2465633,2.3047073,,,,,,,,,,,,,, -39800,1.3088063,2.428597,,,,,,,,,,,,,, -39900,1.2537053,2.2718973,,,,,,,,,,,,,, -40000,1.1332928,2.3677227,,,,,,,,,,,,,, -40100,1.3229358,2.1973236,,,,,,,,,,,,,, -40200,1.0834494,2.264813,,,,,,,,,,,,,, -40300,1.2724211,2.204773,,,,,,,,,,,,,, -40400,1.4153649,2.3560967,,,,,,,,,,,,,, -40500,1.3505914,2.164673,,,,,,,,,,,,,, -40600,1.2432636,2.3838701,,,,,,,,,,,,,, -40700,1.2359084,2.4401598,,,,,,,,,,,,,, -40800,1.1740834,2.296813,,,,,,,,,,,,,, -40900,1.2895148,2.3229933,,,,,,,,,,,,,, -41000,1.1484556,2.3262527,,,,,,,,,,,,,, -41089,,,0.0666852667927742,6.601611614227295,0.0598399974405765,6.742616653442383,50000.0,0.0402000024914741,7.291037082672119,10000.0,13806.275705337524,14311.14434671402,13806.275705337524,502.60431265830994,0.9239518642425536,0.0 -41100,1.2548907,2.2742262,,,,,,,,,,,,,, -41200,1.1768798,2.3633382,,,,,,,,,,,,,, -41300,1.2030442,2.283657,,,,,,,,,,,,,, -41400,1.2423346,2.1927056,,,,,,,,,,,,,, -41500,1.2322149,2.328959,,,,,,,,,,,,,, -41600,1.1978532,2.3524609,,,,,,,,,,,,,, -41700,1.1326467,2.2788255,,,,,,,,,,,,,, -41800,1.3296425,2.1826873,,,,,,,,,,,,,, -41900,1.2448361,2.336234,,,,,,,,,,,,,, -42000,1.2845516,2.2788367,,,,,,,,,,,,,, -42100,1.1537459,2.1217544,,,,,,,,,,,,,, -42200,1.4852295,2.3119707,,,,,,,,,,,,,, -42300,1.1858289,2.1872675,,,,,,,,,,,,,, -42400,1.2311761,2.3197832,,,,,,,,,,,,,, -42500,1.2730463,2.2139487,,,,,,,,,,,,,, -42600,1.1548927,2.2715786,,,,,,,,,,,,,, -42615,,,0.322963148355484,3.35842227935791,0.3021000027656555,3.5063271522521973,50000.0,0.2201000154018402,4.310806751251221,10000.0,14316.291334152222,14839.117769956589,14316.291334152222,520.4744794368744,0.961716651916504,0.0 -42700,1.273023,2.210261,,,,,,,,,,,,,, -42800,1.3803982,2.2577472,,,,,,,,,,,,,, -42900,1.2710238,2.2808616,,,,,,,,,,,,,, -43000,1.1615193,2.3318276,,,,,,,,,,,,,, -43100,1.3520598,2.3478801,,,,,,,,,,,,,, -43200,1.2319672,2.337573,,,,,,,,,,,,,, -43300,1.1808717,2.2520514,,,,,,,,,,,,,, -43400,1.281269,2.2021732,,,,,,,,,,,,,, -43500,1.3726789,2.25651,,,,,,,,,,,,,, -43600,1.2988393,2.3736234,,,,,,,,,,,,,, -43700,1.2129829,2.2024794,,,,,,,,,,,,,, -43800,1.1248975,2.3117483,,,,,,,,,,,,,, -43900,1.1687616,2.3067963,,,,,,,,,,,,,, -44000,1.1762315,2.2895331,,,,,,,,,,,,,, -44049,,,0.2813097834587097,3.6863856315612793,0.2552799880504608,3.895809650421143,50000.0,0.184900015592575,4.619652271270752,10000.0,14826.530678033829,15367.481809616089,14826.530678033829,538.5152575969696,0.9980008602142334,0.0 -44100,1.2933403,2.3004897,,,,,,,,,,,,,, -44200,1.2510486,2.3872838,,,,,,,,,,,,,, -44300,1.4705391,2.2951975,,,,,,,,,,,,,, -44400,1.3874164,2.3409488,,,,,,,,,,,,,, -44500,1.4191484,2.4380174,,,,,,,,,,,,,, -44600,1.1887776,2.2731674,,,,,,,,,,,,,, -44700,1.2934111,2.2918005,,,,,,,,,,,,,, -44800,1.2597377,2.229102,,,,,,,,,,,,,, -44900,1.4714581,2.2627296,,,,,,,,,,,,,, -45000,1.184263,2.1432405,,,,,,,,,,,,,, -45100,1.2144936,2.3015914,,,,,,,,,,,,,, -45200,1.1415831,2.1381707,,,,,,,,,,,,,, -45300,1.2424461,2.362986,,,,,,,,,,,,,, -45400,1.1575896,2.1386855,,,,,,,,,,,,,, -45500,1.2406577,2.1597106,,,,,,,,,,,,,, -45575,,,0.3519212305545807,3.07643723487854,0.3336399793624878,3.243494272232056,50000.0,0.2482000142335891,4.0240983963012695,10000.0,15336.509346485138,15895.369309186935,15336.509346485138,556.3334038257599,1.0387768745422363,0.0 -45600,1.3438393,2.290407,,,,,,,,,,,,,, -45700,1.151979,2.2584856,,,,,,,,,,,,,, -45800,1.2979414,2.2508123,,,,,,,,,,,,,, -45900,1.1792399,2.3494856,,,,,,,,,,,,,, -46000,1.3822598,2.3314903,,,,,,,,,,,,,, -46100,1.385575,2.2656405,,,,,,,,,,,,,, -46200,1.222012,2.3404493,,,,,,,,,,,,,, -46300,1.2275816,2.0826101,,,,,,,,,,,,,, -46400,1.2047976,2.1311595,,,,,,,,,,,,,, -46500,1.1910876,2.2015524,,,,,,,,,,,,,, -46600,1.2730135,2.2772574,,,,,,,,,,,,,, -46700,1.1304821,2.230239,,,,,,,,,,,,,, -46800,1.3958138,2.2964077,,,,,,,,,,,,,, -46900,1.2825098,2.2189596,,,,,,,,,,,,,, -47000,1.2390368,2.2586093,,,,,,,,,,,,,, -47100,1.3216683,2.2585092,,,,,,,,,,,,,, -47101,,,0.2008330672979354,4.566847324371338,0.1885399967432022,4.639566898345947,50000.0,0.1379000097513198,5.32618522644043,10000.0,15846.612461805344,16423.10874414444,15846.612461805344,573.8821873664856,1.0766091346740725,0.0 -47200,1.1903318,2.3000348,,,,,,,,,,,,,, -47300,1.3461264,2.293436,,,,,,,,,,,,,, -47400,1.286025,2.2905507,,,,,,,,,,,,,, -47500,1.3495985,2.3786821,,,,,,,,,,,,,, -47600,1.2091002,2.1678169,,,,,,,,,,,,,, -47700,1.3169543,2.2416804,,,,,,,,,,,,,, -47800,1.2997468,2.3533492,,,,,,,,,,,,,, -47900,1.3072636,2.2153752,,,,,,,,,,,,,, -48000,1.3292612,2.2483423,,,,,,,,,,,,,, -48100,1.168677,2.1273124,,,,,,,,,,,,,, -48200,1.1286013,2.3064272,,,,,,,,,,,,,, -48300,1.3155042,2.3965456,,,,,,,,,,,,,, -48400,1.3005639,2.434049,,,,,,,,,,,,,, -48500,1.2593775,2.2012386,,,,,,,,,,,,,, -48600,1.1758709,2.150243,,,,,,,,,,,,,, -48627,,,0.2834024131298065,3.696853637695313,0.2582399845123291,3.9303948879241934,50000.0,0.2035000026226043,4.547379016876221,10000.0,16356.553442955015,16950.91544485092,16356.553442955015,591.6564452648163,1.1180686950683594,0.0 -48700,1.2465749,2.1996198,,,,,,,,,,,,,, -48800,1.2123339,2.1198823,,,,,,,,,,,,,, -48900,1.200779,2.275646,,,,,,,,,,,,,, -49000,1.3362013,2.22465,,,,,,,,,,,,,, -49100,1.4034871,2.197996,,,,,,,,,,,,,, -49200,1.2699624,2.0895321,,,,,,,,,,,,,, -49300,1.2819762,2.2587671,,,,,,,,,,,,,, -49400,1.3029957,2.364995,,,,,,,,,,,,,, -49500,1.2560303,2.1396296,,,,,,,,,,,,,, -49600,1.2437207,2.284287,,,,,,,,,,,,,, -49700,1.229875,2.2893977,,,,,,,,,,,,,, -49800,1.3101747,2.2551923,,,,,,,,,,,,,, -49900,1.2836446,2.28657,,,,,,,,,,,,,, -50000,1.2785084,2.1506455,,,,,,,,,,,,,, -50100,1.3591752,2.237651,,,,,,,,,,,,,, -50154,,,0.2775829136371612,3.7395527362823486,0.2506999969482422,3.980403423309326,50000.0,0.1962000131607055,4.586901664733887,10000.0,16866.786379098892,17479.196828603745,16866.786379098892,609.6140928268433,1.1585710048675537,0.0 -50200,1.3278767,2.3213804,,,,,,,,,,,,,, -50300,1.2103496,2.2705843,,,,,,,,,,,,,, -50400,1.2912654,2.1286407,,,,,,,,,,,,,, -50500,1.3542981,2.407444,,,,,,,,,,,,,, -50600,1.5225738,2.2711923,,,,,,,,,,,,,, -50700,1.2268103,2.194124,,,,,,,,,,,,,, -50800,1.4012446,2.338363,,,,,,,,,,,,,, -50900,1.3249081,2.2147036,,,,,,,,,,,,,, -51000,1.3940969,2.2997563,,,,,,,,,,,,,, -51100,1.3450941,2.333525,,,,,,,,,,,,,, -51200,1.3999242,2.2724085,,,,,,,,,,,,,, -51300,1.2664965,2.3906643,,,,,,,,,,,,,, -51400,1.2186023,2.2436793,,,,,,,,,,,,,, -51500,1.3358908,2.2899973,,,,,,,,,,,,,, -51600,1.2973204,2.1780257,,,,,,,,,,,,,, -51681,,,0.3650350570678711,3.0436782836914062,0.3366599977016449,3.240746259689331,50000.0,0.2612999975681305,3.977926254272461,10000.0,17376.93217921257,18007.572848796844,17376.93217921257,627.7574996948242,1.1959059238433838,0.0 -51700,1.2857393,2.2962928,,,,,,,,,,,,,, -51800,1.4649835,2.241118,,,,,,,,,,,,,, -51900,1.2987263,2.3780487,,,,,,,,,,,,,, -52000,1.2805192,2.280034,,,,,,,,,,,,,, -52100,1.4322871,2.1138704,,,,,,,,,,,,,, -52200,1.257243,2.111865,,,,,,,,,,,,,, -52300,1.3485482,2.2160563,,,,,,,,,,,,,, -52400,1.2158,2.2356882,,,,,,,,,,,,,, -52500,1.2362328,2.195592,,,,,,,,,,,,,, -52600,1.3735228,2.0846474,,,,,,,,,,,,,, -52700,1.3419358,2.3257577,,,,,,,,,,,,,, -52800,1.3638742,2.2004514,,,,,,,,,,,,,, -52900,1.2081169,2.313389,,,,,,,,,,,,,, -53000,1.3290182,2.1280165,,,,,,,,,,,,,, -53100,1.3027815,2.4254436,,,,,,,,,,,,,, -53200,1.3304309,2.1781337,,,,,,,,,,,,,, -53208,,,0.3320710957050323,3.289456367492676,0.3113200068473816,3.4361612796783447,50000.0,0.2336000055074691,4.169589042663574,10000.0,17887.1039853096,18535.551746606827,17887.1039853096,645.4750876426697,1.2352380752563477,0.0 -53300,1.2272329,2.235675,,,,,,,,,,,,,, -53400,1.2946502,2.1702461,,,,,,,,,,,,,, -53500,1.3375056,2.2017105,,,,,,,,,,,,,, -53600,1.2208066,2.192106,,,,,,,,,,,,,, -53700,1.3769315,2.2533114,,,,,,,,,,,,,, -53800,1.241865,2.3313003,,,,,,,,,,,,,, -53900,1.4645877,2.2194774,,,,,,,,,,,,,, -54000,1.4207046,2.2162502,,,,,,,,,,,,,, -54100,1.3681754,2.3956356,,,,,,,,,,,,,, -54200,1.2026491,2.2323055,,,,,,,,,,,,,, -54300,1.3435618,2.374465,,,,,,,,,,,,,, -54400,1.2365811,2.2783122,,,,,,,,,,,,,, -54500,1.255695,2.199925,,,,,,,,,,,,,, -54600,1.3902875,2.1663883,,,,,,,,,,,,,, -54700,1.1716044,2.1993108,,,,,,,,,,,,,, -54735,,,0.2272998988628387,4.260939121246338,0.2242999970912933,4.226137161254883,50000.0,0.1551000028848648,5.019404411315918,10000.0,18397.301684379578,19063.868786096573,18397.301684379578,663.5008449554443,1.2788963317871094,0.0 -54800,1.3403889,2.322992,,,,,,,,,,,,,, -54900,1.5376633,2.2729435,,,,,,,,,,,,,, -55000,1.2695233,2.238634,,,,,,,,,,,,,, -55100,1.2374029,2.2595422,,,,,,,,,,,,,, -55200,1.400805,2.2601075,,,,,,,,,,,,,, -55300,1.3186884,2.1534243,,,,,,,,,,,,,, -55400,1.2907085,2.135377,,,,,,,,,,,,,, -55500,1.3136638,2.236246,,,,,,,,,,,,,, -55600,1.3004056,2.117956,,,,,,,,,,,,,, -55700,1.3986458,2.2428849,,,,,,,,,,,,,, -55800,1.3549279,2.2216291,,,,,,,,,,,,,, -55900,1.3006828,2.165901,,,,,,,,,,,,,, -56000,1.4561516,2.1928205,,,,,,,,,,,,,, -56100,1.2295394,2.1176453,,,,,,,,,,,,,, -56200,1.319384,2.1946826,,,,,,,,,,,,,, -56262,,,0.1985411345958709,4.842617511749268,0.1861799955368042,5.003037452697754,50000.0,0.1467000097036361,5.59663200378418,10000.0,18907.414145231247,19592.039578437805,18907.414145231247,681.4662253856659,1.322425127029419,0.0 -56300,1.397314,2.1117141,,,,,,,,,,,,,, -56400,1.3501434,2.2737105,,,,,,,,,,,,,, -56500,1.3601308,2.2892733,,,,,,,,,,,,,, -56600,1.3879051,2.3904057,,,,,,,,,,,,,, -56700,1.4037488,2.317845,,,,,,,,,,,,,, -56800,1.2944282,2.3037984,,,,,,,,,,,,,, -56900,1.2005904,2.2200034,,,,,,,,,,,,,, -57000,1.3679694,2.3152308,,,,,,,,,,,,,, -57100,1.2948582,2.1686208,,,,,,,,,,,,,, -57200,1.3292612,2.2149777,,,,,,,,,,,,,, -57300,1.2946168,2.292954,,,,,,,,,,,,,, -57400,1.4655268,2.1727571,,,,,,,,,,,,,, -57500,1.1989241,2.1768794,,,,,,,,,,,,,, -57600,1.3275968,2.2389479,,,,,,,,,,,,,, -57700,1.3136897,2.1387148,,,,,,,,,,,,,, -57789,,,0.405652105808258,2.680437088012696,0.3668600022792816,2.947593688964844,50000.0,0.2788000106811523,3.648837566375733,10000.0,19417.5827600956,20120.224281072617,19417.5827600956,699.392019033432,1.363011598587036,0.0 -57800,1.438265,2.1896844,,,,,,,,,,,,,, -57900,1.3857688,2.2367623,,,,,,,,,,,,,, -58000,1.3486025,2.1906195,,,,,,,,,,,,,, -58100,1.2161794,2.0119972,,,,,,,,,,,,,, -58200,1.3947905,2.2918673,,,,,,,,,,,,,, -58300,1.3777719,2.1126215,,,,,,,,,,,,,, -58400,1.4667053,2.181385,,,,,,,,,,,,,, -58500,1.3626692,2.1641078,,,,,,,,,,,,,, -58600,1.1896542,2.1625602,,,,,,,,,,,,,, -58700,1.1878903,2.2038758,,,,,,,,,,,,,, -58800,1.3567116,2.2257433,,,,,,,,,,,,,, -58900,1.3189621,2.089708,,,,,,,,,,,,,, -59000,1.4339597,2.289101,,,,,,,,,,,,,, -59100,1.2861456,2.247134,,,,,,,,,,,,,, -59200,1.1964432,2.1552203,,,,,,,,,,,,,, -59300,1.4499588,2.1135526,,,,,,,,,,,,,, -59316,,,0.4073062837123871,2.710867166519165,0.3699599802494049,2.94526481628418,50000.0,0.284600019454956,3.6405067443847656,10000.0,19927.652045965195,20648.283844470978,19927.652045965195,717.2902994155884,1.4048044681549072,0.0 -59400,1.294925,2.1858754,,,,,,,,,,,,,, -59500,1.2181906,2.163695,,,,,,,,,,,,,, -59600,1.5864762,2.3232036,,,,,,,,,,,,,, -59700,1.2207233,2.2742355,,,,,,,,,,,,,, -59800,1.389185,2.1828964,,,,,,,,,,,,,, -59900,1.3661066,2.253501,,,,,,,,,,,,,, -60000,1.490964,2.346963,,,,,,,,,,,,,, -60100,1.3020484,2.3132255,,,,,,,,,,,,,, -60200,1.4478496,2.282995,,,,,,,,,,,,,, -60300,1.2583791,2.2367065,,,,,,,,,,,,,, -60400,1.3073492,2.135393,,,,,,,,,,,,,, -60500,1.6465343,2.2104115,,,,,,,,,,,,,, -60600,1.4587953,2.0387301,,,,,,,,,,,,,, -60700,1.2646852,2.2045329,,,,,,,,,,,,,, -60800,1.3864144,2.3248203,,,,,,,,,,,,,, -60843,,,0.2323022931814193,4.145813941955566,0.221119999885559,4.304041862487793,50000.0,0.1575000137090683,5.075519561767578,10000.0,20437.66402554512,21176.199315071102,20437.66402554512,735.1041345596313,1.4447991847991943,0.0 -60900,1.4433872,2.1337423,,,,,,,,,,,,,, -61000,1.4815996,2.2512403,,,,,,,,,,,,,, -61100,1.2356359,2.158972,,,,,,,,,,,,,, -61200,1.3490435,2.1863527,,,,,,,,,,,,,, -61300,1.3485646,2.1225324,,,,,,,,,,,,,, -61400,1.468092,2.1865754,,,,,,,,,,,,,, -61500,1.2577312,2.3120568,,,,,,,,,,,,,, -61600,1.2619282,2.1148832,,,,,,,,,,,,,, -61700,1.390556,2.2678812,,,,,,,,,,,,,, -61800,1.6329057,2.2626526,,,,,,,,,,,,,, -61900,1.421523,2.3402698,,,,,,,,,,,,,, -62000,1.303877,2.2796264,,,,,,,,,,,,,, -62100,1.5513526,2.1646528,,,,,,,,,,,,,, -62200,1.4143466,2.35992,,,,,,,,,,,,,, -62300,1.2028021,2.165884,,,,,,,,,,,,,, -62370,,,0.2048588991165161,4.66219425201416,0.1911199986934662,4.752766132354736,50000.0,0.1399000138044357,5.52890157699585,10000.0,20947.76675963401,21704.33132839203,20947.76675963401,753.0401477813721,1.4879536628723145,0.0 -62400,1.4026942,2.1800606,,,,,,,,,,,,,, -62500,1.2521383,2.1671546,,,,,,,,,,,,,, -62600,1.4349723,2.2154725,,,,,,,,,,,,,, -62700,1.2870347,2.3490412,,,,,,,,,,,,,, -62800,1.3369043,2.179686,,,,,,,,,,,,,, -62900,1.2813108,2.1281939,,,,,,,,,,,,,, -63000,1.2392879,2.1769748,,,,,,,,,,,,,, -63100,1.4247555,2.2881825,,,,,,,,,,,,,, -63200,1.2134825,2.1370502,,,,,,,,,,,,,, -63300,1.5057503,2.14932,,,,,,,,,,,,,, -63400,1.4131985,2.1700418,,,,,,,,,,,,,, -63500,1.5140408,2.1073763,,,,,,,,,,,,,, -63600,1.4101058,2.2547958,,,,,,,,,,,,,, -63700,1.2761072,2.0191963,,,,,,,,,,,,,, -63800,1.3158556,2.1310797,,,,,,,,,,,,,, -63897,,,0.3743223845958709,2.886444091796875,0.3484399914741516,3.05894422531128,50000.0,0.2681000232696533,3.77394700050354,10000.0,21457.712785243988,22231.957134723663,21457.712785243988,770.624137878418,1.533905029296875,0.0 -63900,1.2712852,2.164464,,,,,,,,,,,,,, -64000,1.3952508,2.250727,,,,,,,,,,,,,, -64100,1.4604295,2.0249455,,,,,,,,,,,,,, -64200,1.2911906,2.1941073,,,,,,,,,,,,,, -64300,1.3672384,2.2492714,,,,,,,,,,,,,, -64400,1.5236502,2.2304657,,,,,,,,,,,,,, -64500,1.3763137,2.2774959,,,,,,,,,,,,,, -64600,1.4006999,2.124435,,,,,,,,,,,,,, -64700,1.4468024,2.1376677,,,,,,,,,,,,,, -64800,1.5792049,2.2353873,,,,,,,,,,,,,, -64900,1.5354435,2.2901826,,,,,,,,,,,,,, -65000,1.3440228,2.1016872,,,,,,,,,,,,,, -65100,1.4331124,2.1751702,,,,,,,,,,,,,, -65200,1.3596842,2.2472506,,,,,,,,,,,,,, -65300,1.3645469,2.241371,,,,,,,,,,,,,, -65400,1.44112,2.1230798,,,,,,,,,,,,,, -65424,,,0.307995855808258,3.498842716217041,0.284879982471466,3.684895277023315,50000.0,0.2158000171184539,4.417821884155273,10000.0,21967.712792396545,22759.64086675644,21967.712792396545,788.2158498764038,1.5760498046875,0.0 -65500,1.3818492,2.220086,,,,,,,,,,,,,, -65600,1.3625453,2.1552496,,,,,,,,,,,,,, -65700,1.2958833,2.071287,,,,,,,,,,,,,, -65800,1.4422835,2.1424446,,,,,,,,,,,,,, -65900,1.4464612,2.2011347,,,,,,,,,,,,,, -66000,1.3362741,2.2306423,,,,,,,,,,,,,, -66100,1.3768075,2.0954177,,,,,,,,,,,,,, -66200,1.4317064,2.13319,,,,,,,,,,,,,, -66300,1.3134559,2.178057,,,,,,,,,,,,,, -66400,1.2436078,2.2595568,,,,,,,,,,,,,, -66500,1.3690149,2.1914575,,,,,,,,,,,,,, -66600,1.2984283,2.1496503,,,,,,,,,,,,,, -66700,1.4746017,2.2918913,,,,,,,,,,,,,, -66800,1.331788,2.2106214,,,,,,,,,,,,,, -66900,1.4235585,2.244359,,,,,,,,,,,,,, -66951,,,0.3006616532802582,3.679957628250122,0.2753999829292297,3.886249542236328,50000.0,0.2025000154972076,4.604314804077148,10000.0,22477.6426115036,23287.309163093567,22477.6426115036,805.8626515865326,1.61775541305542,0.0 -67000,1.332018,2.19459,,,,,,,,,,,,,, -67100,1.3949177,2.2006373,,,,,,,,,,,,,, -67200,1.2891047,2.2096148,,,,,,,,,,,,,, -67300,1.3481256,2.1796126,,,,,,,,,,,,,, -67400,1.2407522,2.1691222,,,,,,,,,,,,,, -67500,1.3716244,2.1797538,,,,,,,,,,,,,, -67600,1.4275717,2.0899172,,,,,,,,,,,,,, -67700,1.4413621,2.1432958,,,,,,,,,,,,,, -67800,1.3994762,2.2083416,,,,,,,,,,,,,, -67900,1.3669064,2.0679877,,,,,,,,,,,,,, -68000,1.4034554,2.264525,,,,,,,,,,,,,, -68100,1.4049654,2.2681556,,,,,,,,,,,,,, -68200,1.3890768,2.1712673,,,,,,,,,,,,,, -68300,1.3876526,2.2026145,,,,,,,,,,,,,, -68400,1.5248652,2.187309,,,,,,,,,,,,,, -68479,,,0.3583784997463226,3.008211612701416,0.3360199928283691,3.1703059673309326,50000.0,0.2540000081062317,3.8654260635375977,10000.0,22987.82523560524,23815.67833900452,22987.82523560524,823.9564123153687,1.6605603694915771,0.0 -68500,1.481025,2.267365,,,,,,,,,,,,,, -68600,1.4532907,2.0846977,,,,,,,,,,,,,, -68700,1.3817188,2.1129892,,,,,,,,,,,,,, -68800,1.436149,2.1306,,,,,,,,,,,,,, -68900,1.389437,2.2626545,,,,,,,,,,,,,, -69000,1.5797487,2.2044697,,,,,,,,,,,,,, -69100,1.5484459,2.240792,,,,,,,,,,,,,, -69200,1.4479166,2.157006,,,,,,,,,,,,,, -69300,1.332184,2.1293423,,,,,,,,,,,,,, -69400,1.498309,2.0926485,,,,,,,,,,,,,, -69500,1.4487156,2.125689,,,,,,,,,,,,,, -69600,1.5701028,2.0815337,,,,,,,,,,,,,, -69700,1.3465501,2.0465932,,,,,,,,,,,,,, -69800,1.3500094,2.1662238,,,,,,,,,,,,,, -69900,1.3934908,2.1778574,,,,,,,,,,,,,, -70000,1.349898,2.120842,,,,,,,,,,,,,, -70006,,,0.3116430044174194,3.471470355987549,0.2888000011444092,3.624599695205689,50000.0,0.214800015091896,4.404978275299072,10000.0,23497.832139492035,24343.43438887596,23497.832139492035,841.6116020679474,1.704545497894287,0.0 -70100,1.4097865,2.264001,,,,,,,,,,,,,, -70200,1.4090977,2.122381,,,,,,,,,,,,,, -70300,1.5546752,2.069577,,,,,,,,,,,,,, -70400,1.6327531,2.0705292,,,,,,,,,,,,,, -70500,1.4488997,2.2360165,,,,,,,,,,,,,, -70600,1.4411067,2.1932683,,,,,,,,,,,,,, -70700,1.3688122,2.2199085,,,,,,,,,,,,,, -70800,1.3901925,2.1632452,,,,,,,,,,,,,, -70900,1.4503449,2.159857,,,,,,,,,,,,,, -71000,1.3870147,2.085979,,,,,,,,,,,,,, -71100,1.3856174,2.1578176,,,,,,,,,,,,,, -71200,1.4399037,2.3354511,,,,,,,,,,,,,, -71300,1.5199525,2.1748204,,,,,,,,,,,,,, -71400,1.3703992,2.1397567,,,,,,,,,,,,,, -71500,1.3273652,2.0490086,,,,,,,,,,,,,, -71533,,,0.2196069806814193,4.325572967529297,0.2080599963665008,4.440552234649658,50000.0,0.1634000092744827,5.040404319763184,10000.0,24007.7573299408,24871.24115753174,24007.7573299408,859.4020702838898,1.7458949089050293,0.0 -71600,1.3547101,2.1288009,,,,,,,,,,,,,, -71700,1.3868504,2.2309735,,,,,,,,,,,,,, -71800,1.8219594,2.2213078,,,,,,,,,,,,,, -71900,1.3528926,2.1917906,,,,,,,,,,,,,, -72000,1.3327341,2.0188968,,,,,,,,,,,,,, -72100,1.5031139,2.2287738,,,,,,,,,,,,,, -72200,1.4963716,2.163612,,,,,,,,,,,,,, -72300,1.3549645,2.1840684,,,,,,,,,,,,,, -72400,1.4677217,2.1407936,,,,,,,,,,,,,, -72500,1.4072214,2.1227088,,,,,,,,,,,,,, -72600,1.557236,2.253151,,,,,,,,,,,,,, -72700,1.5302304,2.272094,,,,,,,,,,,,,, -72800,1.4818723,2.0739906,,,,,,,,,,,,,, -72900,1.4530413,2.1949556,,,,,,,,,,,,,, -73000,1.3208212,2.098336,,,,,,,,,,,,,, -73061,,,0.3602519035339355,3.1127405166625977,0.338619977235794,3.271209239959717,50000.0,0.2539000213146209,4.007626056671143,10000.0,24517.887838363647,25399.302145957947,24517.887838363647,877.2425088882446,1.785841703414917,0.0 -73100,1.6013019,2.0519462,,,,,,,,,,,,,, -73200,1.3768395,2.1360621,,,,,,,,,,,,,, -73300,1.4765251,2.1485057,,,,,,,,,,,,,, -73400,1.3970855,1.954958,,,,,,,,,,,,,, -73500,1.399804,2.0442264,,,,,,,,,,,,,, -73600,1.3836305,2.038352,,,,,,,,,,,,,, -73700,1.7707994,2.3118553,,,,,,,,,,,,,, -73800,1.4177644,2.0716853,,,,,,,,,,,,,, -73900,1.4848515,2.1555417,,,,,,,,,,,,,, -74000,1.3671436,2.0364473,,,,,,,,,,,,,, -74100,1.4114982,2.0691304,,,,,,,,,,,,,, -74200,1.3798448,2.112492,,,,,,,,,,,,,, -74300,1.4357837,2.1472888,,,,,,,,,,,,,, -74400,1.3123717,2.0225956,,,,,,,,,,,,,, -74500,1.5879769,2.1239254,,,,,,,,,,,,,, -74588,,,0.287488043308258,3.622225284576416,0.262719988822937,3.84192419052124,50000.0,0.1949000060558319,4.602164268493652,10000.0,25027.81852698326,25927.017527341843,25027.81852698326,894.9329364299774,1.829528570175171,0.0 -74600,1.9724352,2.206193,,,,,,,,,,,,,, -74700,1.3685431,2.1240606,,,,,,,,,,,,,, -74800,1.4176494,2.126788,,,,,,,,,,,,,, -74900,1.2483314,2.144037,,,,,,,,,,,,,, -75000,1.5215708,2.0877135,,,,,,,,,,,,,, -75100,1.5164967,2.1354163,,,,,,,,,,,,,, -75200,1.479092,2.0464718,,,,,,,,,,,,,, -75300,1.4851347,2.1458895,,,,,,,,,,,,,, -75400,1.4783171,2.0283263,,,,,,,,,,,,,, -75500,1.4282881,2.166397,,,,,,,,,,,,,, -75600,1.5224403,2.1304731,,,,,,,,,,,,,, -75700,1.4371688,2.2320254,,,,,,,,,,,,,, -75800,1.45068,2.0753713,,,,,,,,,,,,,, -75900,1.4099778,2.0292833,,,,,,,,,,,,,, -76000,1.3455943,2.0127773,,,,,,,,,,,,,, -76100,1.557987,2.2066302,,,,,,,,,,,,,, -76116,,,0.3895487785339355,2.8953745365142822,0.3669599890708923,3.0703601837158203,50000.0,0.2672000229358673,3.9029479026794434,10000.0,25537.98294305801,26454.930029153824,25537.98294305801,912.58482670784,1.876232624053955,0.0 -76200,1.6388862,2.16094,,,,,,,,,,,,,, -76300,1.4149319,2.121554,,,,,,,,,,,,,, -76400,1.5746266,2.0383768,,,,,,,,,,,,,, -76500,1.4134011,2.2679887,,,,,,,,,,,,,, -76600,1.6417534,2.2181578,,,,,,,,,,,,,, -76700,1.4937198,2.1222792,,,,,,,,,,,,,, -76800,1.5286238,2.0236654,,,,,,,,,,,,,, -76900,1.4273689,2.188391,,,,,,,,,,,,,, -77000,1.3931653,2.1200132,,,,,,,,,,,,,, -77100,1.7613732,2.0408106,,,,,,,,,,,,,, -77200,1.5118823,2.1339967,,,,,,,,,,,,,, -77300,1.4640685,2.191541,,,,,,,,,,,,,, -77400,1.4817638,2.090534,,,,,,,,,,,,,, -77500,1.5532563,2.2026908,,,,,,,,,,,,,, -77600,1.4046339,2.043109,,,,,,,,,,,,,, -77644,,,0.4274752736091614,2.703529357910156,0.3989799916744232,2.865347623825073,50000.0,0.3033000230789184,3.6137118339538574,10000.0,26048.15784502029,26982.89257240296,26048.15784502029,930.2763559818268,1.9222619533538816,0.0 -77700,1.480257,2.1378753,,,,,,,,,,,,,, -77800,1.4737712,2.1250992,,,,,,,,,,,,,, -77900,1.4477638,2.1227171,,,,,,,,,,,,,, -78000,1.5573539,2.0893216,,,,,,,,,,,,,, -78100,1.6833578,2.1080368,,,,,,,,,,,,,, -78200,1.5020742,1.9282444,,,,,,,,,,,,,, -78300,1.5109634,2.2042341,,,,,,,,,,,,,, -78400,1.5523918,2.0976543,,,,,,,,,,,,,, -78500,1.4709368,2.1214137,,,,,,,,,,,,,, -78600,1.5791332,2.042912,,,,,,,,,,,,,, -78700,1.5525348,2.2204216,,,,,,,,,,,,,, -78800,1.3799895,2.1732805,,,,,,,,,,,,,, -78900,1.4732069,2.0615828,,,,,,,,,,,,,, -79000,1.4269804,2.2012324,,,,,,,,,,,,,, -79100,1.5569474,2.047599,,,,,,,,,,,,,, -79171,,,0.445691168308258,2.4578397274017334,0.4142599999904632,2.646533966064453,50000.0,0.315200001001358,3.397288084030152,10000.0,26558.10323810577,27510.913942098618,26558.10323810577,948.2588489055634,1.9660618305206297,0.0 -79200,1.4233286,2.1560297,,,,,,,,,,,,,, -79300,1.5304718,2.1396923,,,,,,,,,,,,,, -79400,1.7480223,2.2629437,,,,,,,,,,,,,, -79500,1.6812881,2.1392167,,,,,,,,,,,,,, -79600,1.6001475,2.0951688,,,,,,,,,,,,,, -79700,1.4542899,1.9401767,,,,,,,,,,,,,, -79800,1.4393034,2.1435761,,,,,,,,,,,,,, -79900,1.4857261,2.0558288,,,,,,,,,,,,,, -80000,1.5004516,2.1292226,,,,,,,,,,,,,, -80100,1.4354683,2.1460032,,,,,,,,,,,,,, -80200,1.5337623,2.1711938,,,,,,,,,,,,,, -80300,1.3872137,2.0581064,,,,,,,,,,,,,, -80400,1.5337632,2.1322112,,,,,,,,,,,,,, -80500,1.6683521,2.1895428,,,,,,,,,,,,,, -80600,1.4471933,2.0664937,,,,,,,,,,,,,, -80699,,,0.4790935814380646,2.3329217433929443,0.4555400013923645,2.490853309631348,50000.0,0.3410000205039978,3.322407484054565,10000.0,27068.290987730023,28039.09343481064,27068.290987730023,966.1588563919069,2.0083022117614746,0.0 -80700,1.4565792,2.060479,,,,,,,,,,,,,, -80800,1.8259989,2.2385755,,,,,,,,,,,,,, -80900,1.7517816,2.1334026,,,,,,,,,,,,,, -81000,1.6530292,1.9690528,,,,,,,,,,,,,, -81100,1.5410911,2.1784778,,,,,,,,,,,,,, -81200,1.7824465,2.0741549,,,,,,,,,,,,,, -81300,1.8967423,2.1253083,,,,,,,,,,,,,, -81400,1.6290034,2.0286064,,,,,,,,,,,,,, -81500,1.5040047,2.1055758,,,,,,,,,,,,,, -81600,1.7567092,2.0348,,,,,,,,,,,,,, -81700,1.5187118,2.0080488,,,,,,,,,,,,,, -81800,1.4948413,2.1381223,,,,,,,,,,,,,, -81900,1.6568425,2.0013785,,,,,,,,,,,,,, -82000,1.485038,2.052234,,,,,,,,,,,,,, -82100,1.4404953,2.02291,,,,,,,,,,,,,, -82200,1.5810851,2.0357702,,,,,,,,,,,,,, -82226,,,0.306740254163742,3.739010810852051,0.2848399877548218,3.9017951488494873,50000.0,0.2016000151634216,4.77598237991333,10000.0,27578.20561933517,28566.787320137024,27578.20561933517,983.8452835083008,2.0514883995056152,0.0 -82300,1.7767705,2.1623862,,,,,,,,,,,,,, -82400,1.632872,2.2599962,,,,,,,,,,,,,, -82500,1.6808245,2.0147972,,,,,,,,,,,,,, -82600,1.5256184,2.2479546,,,,,,,,,,,,,, -82700,1.4844774,2.1488554,,,,,,,,,,,,,, -82800,1.9956288,2.1646743,,,,,,,,,,,,,, -82900,1.5811074,2.0061617,,,,,,,,,,,,,, -83000,1.7883315,2.0823379,,,,,,,,,,,,,, -83100,1.5013857,1.9667726,,,,,,,,,,,,,, -83200,1.7120702,2.2127163,,,,,,,,,,,,,, -83300,1.5894778,2.022734,,,,,,,,,,,,,, -83400,1.4973352,2.1298237,,,,,,,,,,,,,, -83500,1.64319,2.1411605,,,,,,,,,,,,,, -83600,1.5755371,2.1188002,,,,,,,,,,,,,, -83700,1.5389913,2.0444288,,,,,,,,,,,,,, -83754,,,0.4067083895206451,2.743592500686645,0.370739996433258,3.0049564838409424,50000.0,0.2864000201225281,3.7166709899902335,10000.0,28088.363196611404,29095.08299231529,28088.363196611404,1001.8912694454192,2.093227863311768,0.0 -83800,1.5491714,2.1398935,,,,,,,,,,,,,, -83900,1.6560527,2.1598864,,,,,,,,,,,,,, -84000,1.6422048,2.052145,,,,,,,,,,,,,, -84100,1.3970104,2.0047479,,,,,,,,,,,,,, -84200,1.5199778,2.1518815,,,,,,,,,,,,,, -84300,1.7788037,2.0863767,,,,,,,,,,,,,, -84400,1.6027774,2.0778208,,,,,,,,,,,,,, -84500,1.4686924,2.076556,,,,,,,,,,,,,, -84600,1.5877849,2.1539812,,,,,,,,,,,,,, -84700,1.4054791,2.059834,,,,,,,,,,,,,, -84800,1.616838,2.1881642,,,,,,,,,,,,,, -84900,1.548868,2.1510475,,,,,,,,,,,,,, -85000,1.6610466,2.1127539,,,,,,,,,,,,,, -85100,1.5622745,2.0577369,,,,,,,,,,,,,, -85200,1.648147,2.0381744,,,,,,,,,,,,,, -85282,,,0.2871890962123871,3.671375036239624,0.2664200067520141,3.834371089935303,50000.0,0.1824000030755996,4.670032978057861,10000.0,28598.490182876587,29623.33106446266,28598.490182876587,1019.915079832077,2.1403162479400635,0.0 -85300,1.4999821,2.015328,,,,,,,,,,,,,, -85400,1.5708272,2.1085932,,,,,,,,,,,,,, -85500,1.7636135,2.0828311,,,,,,,,,,,,,, -85600,1.7539123,1.9823856,,,,,,,,,,,,,, -85700,1.8748591,2.0901687,,,,,,,,,,,,,, -85800,1.6272464,2.0497334,,,,,,,,,,,,,, -85900,1.5858977,2.0354342,,,,,,,,,,,,,, -86000,1.5150574,1.9889551,,,,,,,,,,,,,, -86100,1.6657073,2.1165788,,,,,,,,,,,,,, -86200,1.5314567,2.062848,,,,,,,,,,,,,, -86300,1.974378,2.169032,,,,,,,,,,,,,, -86400,1.6644524,2.110477,,,,,,,,,,,,,, -86500,1.6330314,2.2105646,,,,,,,,,,,,,, -86600,1.6375389,2.0262558,,,,,,,,,,,,,, -86700,1.6898146,2.063106,,,,,,,,,,,,,, -86800,1.5979019,2.1110468,,,,,,,,,,,,,, -86809,,,0.3782086968421936,3.034822463989258,0.3511799871921539,3.27678656578064,50000.0,0.2788000106811523,3.904389619827272,10000.0,29108.441494226456,30151.29687857628,29108.441494226456,1037.8243143558502,2.196073055267334,0.0 -86900,1.6146427,2.00719,,,,,,,,,,,,,, -87000,1.6217266,2.120719,,,,,,,,,,,,,, -87100,1.5077791,2.015449,,,,,,,,,,,,,, -87200,1.7755004,2.0805278,,,,,,,,,,,,,, -87300,1.918865,2.1378653,,,,,,,,,,,,,, -87400,1.6905258,2.0470872,,,,,,,,,,,,,, -87500,1.6102449,2.0588257,,,,,,,,,,,,,, -87600,1.824906,2.1342494,,,,,,,,,,,,,, -87700,1.542862,2.0000074,,,,,,,,,,,,,, -87800,1.5417964,1.9692693,,,,,,,,,,,,,, -87900,1.5804579,2.0971017,,,,,,,,,,,,,, -88000,1.7036244,2.023839,,,,,,,,,,,,,, -88100,1.7830884,1.9945714,,,,,,,,,,,,,, -88200,1.7519273,2.0989952,,,,,,,,,,,,,, -88300,1.7006345,2.0589461,,,,,,,,,,,,,, -88337,,,0.3022560477256775,3.619099140167236,0.2789599895477295,3.844160556793213,50000.0,0.2198000103235244,4.452763080596924,10000.0,29618.62277984619,30679.39599967003,29618.62277984619,1055.6407935619354,2.24765419960022,0.0 -88400,1.53189,2.0585964,,,,,,,,,,,,,, -88500,1.6239923,2.029209,,,,,,,,,,,,,, -88600,1.6972947,2.036867,,,,,,,,,,,,,, -88700,1.8344907,2.0466242,,,,,,,,,,,,,, -88800,1.637323,1.8555522,,,,,,,,,,,,,, -88900,1.5772123,1.9833462,,,,,,,,,,,,,, -89000,1.6213982,2.0476701,,,,,,,,,,,,,, -89100,1.5918988,2.0021875,,,,,,,,,,,,,, -89200,1.4955755,2.0533712,,,,,,,,,,,,,, -89300,1.6430697,2.0116606,,,,,,,,,,,,,, -89400,1.6944461,2.0125427,,,,,,,,,,,,,, -89500,1.5713885,1.8984268,,,,,,,,,,,,,, -89600,1.6112926,2.106373,,,,,,,,,,,,,, -89700,1.6349347,2.1165302,,,,,,,,,,,,,, -89800,1.7298659,2.0608716,,,,,,,,,,,,,, -89864,,,0.5006975531578064,2.200382947921753,0.4702999889850616,2.404654264450073,50000.0,0.3614000082015991,3.230496644973755,10000.0,30128.824651002884,31207.64864993096,30128.824651002884,1073.5947699546814,2.2941675186157227,0.0 -89900,1.8274693,2.0483665,,,,,,,,,,,,,, -90000,1.745817,2.0000439,,,,,,,,,,,,,, -90100,1.6956462,2.068831,,,,,,,,,,,,,, -90200,1.5090921,1.9952877,,,,,,,,,,,,,, -90300,1.5053319,2.1316648,,,,,,,,,,,,,, -90400,1.5491807,2.1133053,,,,,,,,,,,,,, -90500,1.6739167,2.0817685,,,,,,,,,,,,,, -90600,1.8084793,2.0296144,,,,,,,,,,,,,, -90700,1.6646469,1.8613868,,,,,,,,,,,,,, -90800,1.5601592,2.0364108,,,,,,,,,,,,,, -90900,1.6752136,1.976819,,,,,,,,,,,,,, -91000,1.7330441,2.0610545,,,,,,,,,,,,,, -91100,1.6159644,2.091124,,,,,,,,,,,,,, -91200,1.7258648,2.119232,,,,,,,,,,,,,, -91300,1.6155292,2.0975556,,,,,,,,,,,,,, -91391,,,0.3496890962123871,3.236147165298462,0.3291199803352356,3.4062445163726807,50000.0,0.2346000075340271,4.22288990020752,10000.0,30638.732491970062,31736.439160823826,30638.732491970062,1092.3768472671509,2.3449578285217285,0.0 -91400,1.9060296,1.9052787,,,,,,,,,,,,,, -91500,1.7100992,2.1215098,,,,,,,,,,,,,, -91600,1.7495443,2.1302521,,,,,,,,,,,,,, -91700,1.6482211,2.043589,,,,,,,,,,,,,, -91800,1.6001967,2.064038,,,,,,,,,,,,,, -91900,1.8606304,1.9962175,,,,,,,,,,,,,, -92000,2.2878096,2.0205395,,,,,,,,,,,,,, -92100,1.9118462,2.0977762,,,,,,,,,,,,,, -92200,2.1216567,2.0766623,,,,,,,,,,,,,, -92300,1.6763793,1.9696813,,,,,,,,,,,,,, -92400,1.6412382,2.0458362,,,,,,,,,,,,,, -92500,1.6687918,2.0084648,,,,,,,,,,,,,, -92600,1.7152451,1.9883213,,,,,,,,,,,,,, -92700,1.6291894,2.0736744,,,,,,,,,,,,,, -92800,1.6586413,2.070769,,,,,,,,,,,,,, -92900,1.6835917,1.9638913,,,,,,,,,,,,,, -92919,,,0.4011878073215484,2.894664764404297,0.3725799918174743,3.125652313232422,50000.0,0.2875000238418579,3.9007530212402335,10000.0,31148.847939491272,32264.55832004547,31148.847939491272,1110.28062915802,2.3948986530303955,0.0 -93000,1.7100439,2.011741,,,,,,,,,,,,,, -93100,1.5640633,2.0282,,,,,,,,,,,,,, -93200,1.4986415,1.900401,,,,,,,,,,,,,, -93300,1.7639226,2.0022497,,,,,,,,,,,,,, -93400,1.6651363,1.8877337,,,,,,,,,,,,,, -93500,1.7465085,2.0257776,,,,,,,,,,,,,, -93600,1.687052,1.9971037,,,,,,,,,,,,,, -93700,1.650499,2.1168983,,,,,,,,,,,,,, -93800,1.7889135,2.0269558,,,,,,,,,,,,,, -93900,1.8390702,1.8974447,,,,,,,,,,,,,, -94000,1.7633903,2.043675,,,,,,,,,,,,,, -94100,1.8356225,2.0402222,,,,,,,,,,,,,, -94200,1.6366816,1.9675907,,,,,,,,,,,,,, -94300,1.7309163,2.0725458,,,,,,,,,,,,,, -94400,1.7212771,2.100402,,,,,,,,,,,,,, -94446,,,0.460339605808258,2.431842565536499,0.4266199767589569,2.65660047531128,50000.0,0.325300008058548,3.4505057334899902,10000.0,31658.8655025959,32792.40767073631,31658.8655025959,1128.0174877643583,2.43977689743042,0.0 -94500,1.762791,1.9857235,,,,,,,,,,,,,, -94600,1.6726484,1.9992654,,,,,,,,,,,,,, -94700,1.6799611,1.9799916,,,,,,,,,,,,,, -94800,1.6920646,1.933675,,,,,,,,,,,,,, -94900,2.1179636,2.157917,,,,,,,,,,,,,, -95000,1.785623,2.0965347,,,,,,,,,,,,,, -95100,1.6615382,2.0291169,,,,,,,,,,,,,, -95200,1.6608813,2.0636694,,,,,,,,,,,,,, -95300,1.9566166,2.0019813,,,,,,,,,,,,,, -95400,1.7652067,2.068088,,,,,,,,,,,,,, -95500,1.64082,2.0348923,,,,,,,,,,,,,, -95600,1.697548,2.069869,,,,,,,,,,,,,, -95700,1.8865439,2.0282502,,,,,,,,,,,,,, -95800,1.6550801,1.9271765,,,,,,,,,,,,,, -95900,1.639146,2.0295382,,,,,,,,,,,,,, -95974,,,0.2023078650236129,4.5234880447387695,0.1915399879217147,4.669929027557373,50000.0,0.1424000114202499,5.283186912536621,10000.0,32168.929438829426,33320.16366028786,32168.929438829426,1145.614995956421,2.484170436859131,0.0 -96000,1.976241,2.0087194,,,,,,,,,,,,,, -96100,1.6169691,2.0356722,,,,,,,,,,,,,, -96200,1.8544804,2.131629,,,,,,,,,,,,,, -96300,1.7434361,2.067712,,,,,,,,,,,,,, -96400,1.687802,1.9307977,,,,,,,,,,,,,, -96500,1.8701785,2.0511906,,,,,,,,,,,,,, -96600,1.7281197,2.079546,,,,,,,,,,,,,, -96700,1.7354221,1.9325476,,,,,,,,,,,,,, -96800,1.5956544,1.915365,,,,,,,,,,,,,, -96900,1.8412527,2.0490105,,,,,,,,,,,,,, -97000,1.9201684,2.0611813,,,,,,,,,,,,,, -97100,1.7555344,2.0230608,,,,,,,,,,,,,, -97200,2.0491705,2.0715997,,,,,,,,,,,,,, -97300,1.6989009,2.0609856,,,,,,,,,,,,,, -97400,2.1149318,1.8970201,,,,,,,,,,,,,, -97500,1.850651,1.9279358,,,,,,,,,,,,,, -97501,,,0.3995336294174194,2.832772731781006,0.3749800026416778,3.0190296173095703,50000.0,0.2884000241756439,3.7778568267822266,10000.0,32678.92079949379,33848.26008415222,32678.92079949379,1163.6204626560211,2.5338714122772217,0.0 -97600,1.609261,2.0431612,,,,,,,,,,,,,, -97700,1.8835341,2.0171177,,,,,,,,,,,,,, -97800,1.9016569,1.9841499,,,,,,,,,,,,,, -97900,1.8761395,1.9546516,,,,,,,,,,,,,, -98000,1.870604,1.959357,,,,,,,,,,,,,, -98100,1.8444377,1.9508889,,,,,,,,,,,,,, -98200,1.966368,1.958278,,,,,,,,,,,,,, -98300,1.7959116,2.03952,,,,,,,,,,,,,, -98400,1.7735513,1.9428407,,,,,,,,,,,,,, -98500,1.9121023,1.967095,,,,,,,,,,,,,, -98600,1.7102191,1.96493,,,,,,,,,,,,,, -98700,1.9781892,1.8861086,,,,,,,,,,,,,, -98800,1.8937632,2.00521,,,,,,,,,,,,,, -98900,1.9517425,1.9592977,,,,,,,,,,,,,, -99000,1.7169297,1.9579418,,,,,,,,,,,,,, -99029,,,0.4629504084587097,2.432036638259888,0.4339199960231781,2.634920597076416,50000.0,0.3244000077247619,3.5730552673339844,10000.0,33189.129398584366,34376.450227975845,33189.129398584366,1181.5033974647522,2.582381010055542,0.0 -99100,1.8239394,2.0098212,,,,,,,,,,,,,, -99200,1.6981466,2.0103831,,,,,,,,,,,,,, -99300,1.736154,1.9555326,,,,,,,,,,,,,, -99400,1.7699454,1.898302,,,,,,,,,,,,,, -99500,2.0507324,2.1287332,,,,,,,,,,,,,, -99600,1.9014363,2.0253232,,,,,,,,,,,,,, -99700,1.9274256,2.0647514,,,,,,,,,,,,,, -99800,1.8893869,1.8392372,,,,,,,,,,,,,, -99900,1.6807393,1.9477954,,,,,,,,,,,,,, -100000,1.8086568,1.898092,,,,,,,,,,,,,, -100100,1.9718941,2.0044427,,,,,,,,,,,,,, -100200,1.7648553,1.9774644,,,,,,,,,,,,,, -100300,1.6239561,1.9810983,,,,,,,,,,,,,, -100400,1.9472173,1.9963963,,,,,,,,,,,,,, -100500,1.8571999,1.9890205,,,,,,,,,,,,,, -100557,,,0.4109534323215484,2.8825113773345947,0.3824599981307983,3.089705467224121,50000.0,0.2790000140666961,4.015759944915772,10000.0,33699.24235534668,34904.74674367905,33699.24235534668,1199.5894927978516,2.630051374435425,0.0 -100600,1.8353859,1.8598067,,,,,,,,,,,,,, -100700,1.7423596,1.9027617,,,,,,,,,,,,,, -100800,1.714536,1.8568761,,,,,,,,,,,,,, -100900,1.9534508,1.9826615,,,,,,,,,,,,,, -101000,1.9252905,1.9454324,,,,,,,,,,,,,, -101100,1.7664218,1.9137645,,,,,,,,,,,,,, -101200,1.6775473,1.8582008,,,,,,,,,,,,,, -101300,1.8687392,1.9344864,,,,,,,,,,,,,, -101400,1.7566097,1.9378543,,,,,,,,,,,,,, -101500,1.8957812,2.0242682,,,,,,,,,,,,,, -101600,1.9334062,1.8849651,,,,,,,,,,,,,, -101700,1.7542965,2.0354133,,,,,,,,,,,,,, -101800,1.855176,1.9774832,,,,,,,,,,,,,, -101900,1.8575333,2.0996177,,,,,,,,,,,,,, -102000,2.0801578,1.9641715,,,,,,,,,,,,,, -102085,,,0.4940409660339355,2.256666421890259,0.4540799856185913,2.4757895469665527,50000.0,0.347100019454956,3.2543532848358154,10000.0,34209.41844010353,35432.6150135994,34209.41844010353,1217.1798260211945,2.681713342666626,0.0 -102100,2.014998,2.0657165,,,,,,,,,,,,,, -102200,1.7904408,1.9724605,,,,,,,,,,,,,, -102300,1.8683211,2.04072,,,,,,,,,,,,,, -102400,1.8760351,1.8441842,,,,,,,,,,,,,, -102500,1.8371929,2.0040681,,,,,,,,,,,,,, -102600,1.9138957,1.9596894,,,,,,,,,,,,,, -102700,1.7860868,1.9096956,,,,,,,,,,,,,, -102800,2.084104,2.0523396,,,,,,,,,,,,,, -102900,1.9607835,2.0044222,,,,,,,,,,,,,, -103000,1.9292741,2.086213,,,,,,,,,,,,,, -103100,1.7927538,1.9188044,,,,,,,,,,,,,, -103200,1.8926109,1.9789152,,,,,,,,,,,,,, -103300,1.798144,1.9278036,,,,,,,,,,,,,, -103400,1.9150125,1.9595181,,,,,,,,,,,,,, -103500,2.0394897,1.8968931,,,,,,,,,,,,,, -103600,1.909847,1.939153,,,,,,,,,,,,,, -103612,,,0.3447664082050323,3.2952005863189697,0.3284199833869934,3.4189095497131348,50000.0,0.2385000139474868,4.211745738983154,10000.0,34719.36843562126,35960.58519363403,34719.36843562126,1235.0998673439026,2.732177972793579,0.0 -103700,1.85398,1.852328,,,,,,,,,,,,,, -103800,1.9466835,1.9163231,,,,,,,,,,,,,, -103900,1.762373,1.8932102,,,,,,,,,,,,,, -104000,1.9658897,2.027692,,,,,,,,,,,,,, -104100,1.8308847,1.9491229,,,,,,,,,,,,,, -104200,2.0194025,1.8025562,,,,,,,,,,,,,, -104300,1.9395736,2.0149777,,,,,,,,,,,,,, -104400,1.8300326,1.9636896,,,,,,,,,,,,,, -104500,1.674848,1.895077,,,,,,,,,,,,,, -104600,1.9151734,1.9139919,,,,,,,,,,,,,, -104700,1.8066481,1.891084,,,,,,,,,,,,,, -104800,1.8851858,1.8810065,,,,,,,,,,,,,, -104900,1.8404948,1.9215962,,,,,,,,,,,,,, -105000,1.8696707,1.89911,,,,,,,,,,,,,, -105100,1.9210286,2.0690007,,,,,,,,,,,,,, -105140,,,0.4314213991165161,2.7033612728118896,0.3980000019073486,2.949122190475464,50000.0,0.3194000124931335,3.649909496307373,10000.0,35229.46646118164,36488.632420539856,35229.46646118164,1252.949723482132,2.781178951263428,0.0 -105200,1.8424152,1.9080681,,,,,,,,,,,,,, -105300,1.9277953,1.9294868,,,,,,,,,,,,,, -105400,3.6352887,2.0212128,,,,,,,,,,,,,, -105500,1.9045419,1.9093955,,,,,,,,,,,,,, -105600,2.028124,1.9207487,,,,,,,,,,,,,, -105700,1.9470425,2.050925,,,,,,,,,,,,,, -105800,1.8451138,1.8546965,,,,,,,,,,,,,, -105900,1.8200015,1.8355374,,,,,,,,,,,,,, -106000,2.0837939,1.8847643,,,,,,,,,,,,,, -106100,1.820961,2.0675018,,,,,,,,,,,,,, -106200,2.0691988,1.9147103,,,,,,,,,,,,,, -106300,1.9555115,1.9889796,,,,,,,,,,,,,, -106400,1.9784328,1.9214461,,,,,,,,,,,,,, -106500,1.9761469,1.9618335,,,,,,,,,,,,,, -106600,2.0554523,1.9847304,,,,,,,,,,,,,, -106668,,,0.562898576259613,1.828281283378601,0.5299199819564819,2.0261363983154297,50000.0,0.4075000286102295,2.7543745040893555,10000.0,35739.57062864304,37017.07039570808,35739.57062864304,1271.185317993164,2.8301868438720703,0.0 -106700,2.3448555,1.8472996,,,,,,,,,,,,,, -106800,1.892981,1.8161254,,,,,,,,,,,,,, -106900,2.1538475,2.0267622,,,,,,,,,,,,,, -107000,1.924293,1.826633,,,,,,,,,,,,,, -107100,2.0061905,1.9153806,,,,,,,,,,,,,, -107200,1.8330071,1.8991461,,,,,,,,,,,,,, -107300,2.31325,1.9885817,,,,,,,,,,,,,, -107400,2.0564585,2.016055,,,,,,,,,,,,,, -107500,1.8300244,1.8793126,,,,,,,,,,,,,, -107600,1.9634335,2.0036306,,,,,,,,,,,,,, -107700,1.9721178,1.93808,,,,,,,,,,,,,, -107800,1.9506533,2.0390716,,,,,,,,,,,,,, -107900,2.3239343,1.971535,,,,,,,,,,,,,, -108000,2.2345357,2.0320787,,,,,,,,,,,,,, -108100,1.8273772,1.9464972,,,,,,,,,,,,,, -108196,,,0.5349768996238708,1.9775022268295288,0.4815199971199035,2.32653546333313,50000.0,0.359000027179718,3.246283531188965,10000.0,36249.67180633545,37545.1928293705,36249.67180633545,1289.1059653759005,2.881099939346313,0.0 -108200,1.8446648,1.8732275,,,,,,,,,,,,,, -108300,1.922732,1.7525649,,,,,,,,,,,,,, -108400,1.9250228,1.9074128,,,,,,,,,,,,,, -108500,2.009018,1.8340274,,,,,,,,,,,,,, -108600,1.8576603,1.8529485,,,,,,,,,,,,,, -108700,1.8858565,1.8579404,,,,,,,,,,,,,, -108800,2.1011417,1.8766705,,,,,,,,,,,,,, -108900,2.4733188,2.0147069,,,,,,,,,,,,,, -109000,1.8258333,1.9202733,,,,,,,,,,,,,, -109100,1.8935933,1.858121,,,,,,,,,,,,,, -109200,2.1010005,1.8542459,,,,,,,,,,,,,, -109300,1.9344927,1.8971105,,,,,,,,,,,,,, -109400,1.96584,1.8493866,,,,,,,,,,,,,, -109500,1.9590551,1.8261687,,,,,,,,,,,,,, -109600,2.134711,1.9147166,,,,,,,,,,,,,, -109700,1.9324322,1.9414103,,,,,,,,,,,,,, -109724,,,0.5242546200752258,2.0457510948181152,0.4815599918365478,2.3079276084899902,50000.0,0.3776000142097473,3.059465885162353,10000.0,36759.76053881645,38073.16121888161,36759.76053881645,1306.8855466842651,2.931341648101806,0.0 -109800,2.2244182,1.8969678,,,,,,,,,,,,,, -109900,2.050075,2.0591643,,,,,,,,,,,,,, -110000,1.9216887,1.993064,,,,,,,,,,,,,, -110100,1.8992316,1.8879616,,,,,,,,,,,,,, -110200,1.9435867,1.7939458,,,,,,,,,,,,,, -110300,1.9920393,2.002417,,,,,,,,,,,,,, -110400,2.253606,1.8862382,,,,,,,,,,,,,, -110500,1.9497972,1.8245434,,,,,,,,,,,,,, -110600,1.9494412,2.005947,,,,,,,,,,,,,, -110700,2.4815986,1.8233802,,,,,,,,,,,,,, -110800,1.9374495,1.7678447,,,,,,,,,,,,,, -110900,2.2389472,1.9654953,,,,,,,,,,,,,, -111000,2.123419,1.8317772,,,,,,,,,,,,,, -111100,2.1524572,1.8622066,,,,,,,,,,,,,, -111200,2.1883118,1.816597,,,,,,,,,,,,,, -111252,,,0.5350366830825806,1.9782723188400269,0.5010600090026855,2.1930699348449707,50000.0,0.3883000314235687,2.9401986598968506,10000.0,37269.95954012871,38601.34153342247,37269.95954012871,1324.7671279907229,2.98144006729126,0.0 -111300,1.9951568,1.8463502,,,,,,,,,,,,,, -111400,1.8889002,1.9503874,,,,,,,,,,,,,, -111500,1.9536201,1.8823365,,,,,,,,,,,,,, -111600,2.201176,1.8605783,,,,,,,,,,,,,, -111700,2.1553934,1.9166173,,,,,,,,,,,,,, -111800,2.1128514,1.952915,,,,,,,,,,,,,, -111900,2.0934887,1.7485019,,,,,,,,,,,,,, -112000,2.205268,1.8464884,,,,,,,,,,,,,, -112100,2.0976043,1.9803048,,,,,,,,,,,,,, -112200,1.90905,1.8839339,,,,,,,,,,,,,, -112300,2.0279546,1.8355935,,,,,,,,,,,,,, -112400,2.1712983,1.9535437,,,,,,,,,,,,,, -112500,2.111739,1.8240216,,,,,,,,,,,,,, -112600,2.4175668,2.022398,,,,,,,,,,,,,, -112700,2.15368,2.031859,,,,,,,,,,,,,, -112780,,,0.5105628371238708,2.158193826675415,0.4708399772644043,2.4166712760925293,50000.0,0.3680000305175781,3.1187973022460938,10000.0,37780.089233636856,39129.21529483795,37780.089233636856,1342.4133460521698,3.02941370010376,0.0 -112800,1.9211067,1.8694032,,,,,,,,,,,,,, -112900,2.3800502,1.9630253,,,,,,,,,,,,,, -113000,2.2577946,1.9567665,,,,,,,,,,,,,, -113100,2.1981828,1.8895948,,,,,,,,,,,,,, -113200,2.4154696,1.865134,,,,,,,,,,,,,, -113300,2.1055496,1.9190575,,,,,,,,,,,,,, -113400,2.3498785,1.981464,,,,,,,,,,,,,, -113500,2.2783132,1.9033991,,,,,,,,,,,,,, -113600,2.1315317,1.8996594,,,,,,,,,,,,,, -113700,2.2478447,1.8890693,,,,,,,,,,,,,, -113800,2.131074,1.8309109,,,,,,,,,,,,,, -113900,2.1172843,1.9285748,,,,,,,,,,,,,, -114000,2.1990266,1.9113581,,,,,,,,,,,,,, -114100,2.1856644,1.8812861,,,,,,,,,,,,,, -114200,2.1198745,1.7611356,,,,,,,,,,,,,, -114300,2.2589586,1.7953731,,,,,,,,,,,,,, -114308,,,0.5438257455825806,1.9540691375732424,0.5044999718666077,2.177619695663452,50000.0,0.3813000321388244,3.057910680770874,10000.0,38290.16973924637,39657.031764507294,38290.16973924637,1360.0454897880554,3.0832841396331787,0.0 -114400,2.2456293,1.8982661,,,,,,,,,,,,,, -114500,2.2631698,1.8780074,,,,,,,,,,,,,, -114600,2.1819217,1.8915917,,,,,,,,,,,,,, -114700,2.3066022,1.8174572,,,,,,,,,,,,,, -114800,2.4533322,1.8597132,,,,,,,,,,,,,, -114900,2.2196853,1.8590207,,,,,,,,,,,,,, -115000,1.9957861,1.744562,,,,,,,,,,,,,, -115100,2.0408354,1.9086075,,,,,,,,,,,,,, -115200,2.375861,1.9108918,,,,,,,,,,,,,, -115300,2.025238,1.8282964,,,,,,,,,,,,,, -115400,2.0474126,1.8844429,,,,,,,,,,,,,, -115500,2.1141214,1.7952508,,,,,,,,,,,,,, -115600,2.2364094,1.8701874,,,,,,,,,,,,,, -115700,2.1041923,1.9106839,,,,,,,,,,,,,, -115800,2.3885636,1.8022398,,,,,,,,,,,,,, -115836,,,0.6219108700752258,1.5522531270980835,0.5729599595069885,1.8273431062698364,50000.0,0.457500010728836,2.545141696929932,10000.0,38800.36287164688,40185.308972120285,38800.36287164688,1378.0283389091492,3.134504556655884,0.0 -115900,2.1898627,1.8347489,,,,,,,,,,,,,, -116000,2.1011612,1.7311622,,,,,,,,,,,,,, -116100,2.2764077,1.7954054,,,,,,,,,,,,,, -116200,2.2489967,1.788477,,,,,,,,,,,,,, -116300,1.9242736,1.7482548,,,,,,,,,,,,,, -116400,2.1944535,1.8614618,,,,,,,,,,,,,, -116500,2.1054258,1.7550467,,,,,,,,,,,,,, -116600,2.0487893,1.8358334,,,,,,,,,,,,,, -116700,1.9168344,1.782139,,,,,,,,,,,,,, -116800,2.091012,1.7962947,,,,,,,,,,,,,, -116900,2.1598117,1.8576143,,,,,,,,,,,,,, -117000,2.2262402,1.7755985,,,,,,,,,,,,,, -117100,2.2506845,1.8493686,,,,,,,,,,,,,, -117200,2.1156638,1.8329812,,,,,,,,,,,,,, -117300,2.4734364,1.8752127,,,,,,,,,,,,,, -117364,,,0.6141780614852905,1.578540563583374,0.551800012588501,1.9139325618743896,50000.0,0.435200035572052,2.6567647457122803,10000.0,39310.3577773571,40713.292598724365,39310.3577773571,1395.919328212738,3.18233060836792,0.0 -117400,2.5331414,1.8484242,,,,,,,,,,,,,, -117500,2.152506,1.778971,,,,,,,,,,,,,, -117600,2.2518485,1.9630977,,,,,,,,,,,,,, -117700,2.2028544,1.9026223,,,,,,,,,,,,,, -117800,2.259809,1.8150916,,,,,,,,,,,,,, -117900,2.7349985,1.8602164,,,,,,,,,,,,,, -118000,2.341344,1.8525455,,,,,,,,,,,,,, -118100,2.4980233,1.8719857,,,,,,,,,,,,,, -118200,2.3695068,1.8932886,,,,,,,,,,,,,, -118300,2.1998408,1.7848805,,,,,,,,,,,,,, -118400,2.2413864,1.703075,,,,,,,,,,,,,, -118500,2.2506528,1.8393018,,,,,,,,,,,,,, -118600,2.1806393,1.7680634,,,,,,,,,,,,,, -118700,2.0993874,1.6135434,,,,,,,,,,,,,, -118800,2.3287666,1.7703378,,,,,,,,,,,,,, -118892,,,0.5724250674247742,1.76844584941864,0.526639997959137,2.034866571426392,50000.0,0.4172000288963318,2.772226333618164,10000.0,39820.552035331726,41241.165909051895,39820.552035331726,1413.4967761039734,3.233930826187134,0.0 -118900,2.2752364,1.7952077,,,,,,,,,,,,,, -119000,2.2384605,1.7329304,,,,,,,,,,,,,, -119100,2.0595405,1.7961638,,,,,,,,,,,,,, -119200,2.199928,1.869015,,,,,,,,,,,,,, -119300,2.2038271,1.8020178,,,,,,,,,,,,,, -119400,2.2920625,1.7916825,,,,,,,,,,,,,, -119500,2.426537,1.7536161,,,,,,,,,,,,,, -119600,2.2527392,1.8929915,,,,,,,,,,,,,, -119700,2.3528616,1.8777221,,,,,,,,,,,,,, -119800,2.3390026,1.8719498,,,,,,,,,,,,,, -119900,2.34263,1.769001,,,,,,,,,,,,,, -120000,2.145081,1.8770726,,,,,,,,,,,,,, -120100,2.214665,1.828067,,,,,,,,,,,,,, -120200,2.5124645,1.8162364,,,,,,,,,,,,,, -120300,2.3331249,1.7296873,,,,,,,,,,,,,, -120400,2.456708,1.788643,,,,,,,,,,,,,, -120420,,,0.6161710619926453,1.567164182662964,0.5663999915122986,1.838868498802185,50000.0,0.443200021982193,2.578282833099365,10000.0,40330.61541700363,41768.916761636734,40330.61541700363,1431.083517551422,3.284346342086792,0.0 -120500,2.2529895,1.7675209,,,,,,,,,,,,,, -120600,2.3582098,1.7942734,,,,,,,,,,,,,, -120700,2.392437,1.8680155,,,,,,,,,,,,,, -120800,2.3671463,1.8395282,,,,,,,,,,,,,, -120900,2.18374,1.9170924,,,,,,,,,,,,,, -121000,2.428143,1.768426,,,,,,,,,,,,,, -121100,2.3208327,1.8041047,,,,,,,,,,,,,, -121200,2.478747,1.7699118,,,,,,,,,,,,,, -121300,2.4301925,1.8197329,,,,,,,,,,,,,, -121400,2.1457107,1.8362519,,,,,,,,,,,,,, -121500,2.2642386,1.8646377,,,,,,,,,,,,,, -121600,2.4465132,1.7919426,,,,,,,,,,,,,, -121700,2.0882702,1.7450275,,,,,,,,,,,,,, -121800,2.5412152,1.9629319,,,,,,,,,,,,,, -121900,2.5620766,1.8199048,,,,,,,,,,,,,, -121948,,,0.5903021097183228,1.6958752870559692,0.5471799969673157,1.95624577999115,50000.0,0.4316000342369079,2.7272050380706787,10000.0,40840.63898897171,42296.831644296646,40840.63898897171,1448.8743512630465,3.3349971771240234,0.0 -122000,2.2579195,1.7277447,,,,,,,,,,,,,, -122100,2.4746654,1.7505343,,,,,,,,,,,,,, -122200,2.235776,1.7821625,,,,,,,,,,,,,, -122300,2.362409,1.8191183,,,,,,,,,,,,,, -122400,2.5895119,1.7644573,,,,,,,,,,,,,, -122500,2.4328651,1.8248295,,,,,,,,,,,,,, -122600,2.441477,1.8341712,,,,,,,,,,,,,, -122700,2.7338533,1.9117601,,,,,,,,,,,,,, -122800,2.3049583,1.7205927,,,,,,,,,,,,,, -122900,2.2424805,1.7519473,,,,,,,,,,,,,, -123000,2.424321,1.867925,,,,,,,,,,,,,, -123100,2.2707534,1.6386391,,,,,,,,,,,,,, -123200,2.4941938,1.692978,,,,,,,,,,,,,, -123300,2.5078397,1.7473298,,,,,,,,,,,,,, -123400,2.4995081,1.7702138,,,,,,,,,,,,,, -123476,,,0.6090362071990967,1.6039897203445437,0.5696600079536438,1.830997109413147,50000.0,0.4378000199794769,2.6645984649658203,10000.0,41350.85647821426,42824.79166054726,41350.85647821426,1466.5170137882233,3.384669065475464,0.0 -123500,2.2244933,1.8318207,,,,,,,,,,,,,, -123600,2.3995407,1.7427446,,,,,,,,,,,,,, -123700,2.3410258,1.8297298,,,,,,,,,,,,,, -123800,2.3675194,1.8337557,,,,,,,,,,,,,, -123900,2.2021003,1.7660246,,,,,,,,,,,,,, -124000,2.2913427,1.8124262,,,,,,,,,,,,,, -124100,2.3289673,1.7691896,,,,,,,,,,,,,, -124200,2.641665,1.8063786,,,,,,,,,,,,,, -124300,2.327832,1.7796729,,,,,,,,,,,,,, -124400,2.3725588,1.789113,,,,,,,,,,,,,, -124500,2.5212402,1.8205287,,,,,,,,,,,,,, -124600,2.4200985,1.7235718,,,,,,,,,,,,,, -124700,2.4189026,1.7859305,,,,,,,,,,,,,, -124800,2.878249,1.8528806,,,,,,,,,,,,,, -124900,2.4411082,1.7058914,,,,,,,,,,,,,, -125000,2.6442845,1.643774,,,,,,,,,,,,,, -125004,,,0.6317163705825806,1.4943809509277344,0.5685999989509583,1.83841335773468,50000.0,0.4446000158786773,2.612132787704468,10000.0,41860.99411845207,43352.67394042015,41860.99411845207,1484.1569118499756,3.4393198490142822,0.0 -125100,2.4638894,1.8502431,,,,,,,,,,,,,, -125200,2.3914714,1.7383065,,,,,,,,,,,,,, -125300,2.2958977,1.7917582,,,,,,,,,,,,,, -125400,2.2790513,1.7017782,,,,,,,,,,,,,, -125500,2.6710336,1.9166731,,,,,,,,,,,,,, -125600,2.5010521,1.7983887,,,,,,,,,,,,,, -125700,2.4868221,1.7595904,,,,,,,,,,,,,, -125800,2.571256,1.7299587,,,,,,,,,,,,,, -125900,2.7055688,1.769448,,,,,,,,,,,,,, -126000,2.6803916,1.736564,,,,,,,,,,,,,, -126100,2.4738743,1.6971436,,,,,,,,,,,,,, -126200,2.5390153,1.8277335,,,,,,,,,,,,,, -126300,2.386506,1.6924621,,,,,,,,,,,,,, -126400,2.3391452,1.6787345,,,,,,,,,,,,,, -126500,2.5383048,1.7401273,,,,,,,,,,,,,, -126532,,,0.5991111397743225,1.6552882194519043,0.5497399568557739,1.9696252346038816,50000.0,0.4226000308990478,2.7985992431640625,10000.0,42371.09545564652,43880.53527808189,42371.09545564652,1501.814185142517,3.492196083068848,0.0 -126600,2.4085426,1.7266977,,,,,,,,,,,,,, -126700,2.6580486,1.6969165,,,,,,,,,,,,,, -126800,2.4349868,1.652873,,,,,,,,,,,,,, -126900,2.790207,1.832709,,,,,,,,,,,,,, -127000,2.6935964,1.8264413,,,,,,,,,,,,,, -127100,2.471872,1.6963996,,,,,,,,,,,,,, -127200,2.3330376,1.7657793,,,,,,,,,,,,,, -127300,2.4027548,1.7426724,,,,,,,,,,,,,, -127400,2.5526857,1.694185,,,,,,,,,,,,,, -127500,2.5311005,1.6430079,,,,,,,,,,,,,, -127600,2.6974764,1.6222363,,,,,,,,,,,,,, -127700,2.382717,1.769012,,,,,,,,,,,,,, -127800,2.773698,1.8260347,,,,,,,,,,,,,, -127900,2.663723,1.7589666,,,,,,,,,,,,,, -128000,2.7249503,1.7690759,,,,,,,,,,,,,, -128060,,,0.6313177347183228,1.4843790531158447,0.5854200124740601,1.7508049011230469,50000.0,0.4629000127315521,2.504871129989624,10000.0,42881.19850087166,44408.43687057495,42881.19850087166,1519.5065422058103,3.5481600761413574,0.0 -128100,2.4385223,1.6614745,,,,,,,,,,,,,, -128200,2.533914,1.7775456,,,,,,,,,,,,,, -128300,2.5903347,1.7252969,,,,,,,,,,,,,, -128400,2.686695,1.6885617,,,,,,,,,,,,,, -128500,2.3811843,1.6018244,,,,,,,,,,,,,, -128600,2.7003593,1.7788525,,,,,,,,,,,,,, -128700,2.4444945,1.7740618,,,,,,,,,,,,,, -128800,2.4407027,1.7075708,,,,,,,,,,,,,, -128900,2.4503229,1.7660913,,,,,,,,,,,,,, -129000,2.60954,1.6784675,,,,,,,,,,,,,, -129100,2.7461512,1.6922382,,,,,,,,,,,,,, -129200,2.6418986,1.7056916,,,,,,,,,,,,,, -129300,2.5251791,1.8542962,,,,,,,,,,,,,, -129400,2.5943975,1.6405505,,,,,,,,,,,,,, -129500,2.5679243,1.6085589,,,,,,,,,,,,,, -129588,,,0.620137095451355,1.572115778923035,0.5705400109291077,1.8495970964431765,50000.0,0.4502000212669372,2.617003917694092,10000.0,43391.20353627205,44936.09682822228,43391.20353627205,1537.0586075782776,3.6010334491729736,0.0 -129600,2.4722137,1.6368251,,,,,,,,,,,,,, -129700,2.7415814,1.6865941,,,,,,,,,,,,,, -129800,2.4949238,1.7645056,,,,,,,,,,,,,, -129900,2.55542,1.6858531,,,,,,,,,,,,,, -130000,2.5377932,1.7252985,,,,,,,,,,,,,, -130100,2.7618973,1.5834155,,,,,,,,,,,,,, -130200,2.4220488,1.6711756,,,,,,,,,,,,,, -130300,2.6700997,1.7875854,,,,,,,,,,,,,, -130400,2.6466713,1.6941593,,,,,,,,,,,,,, -130500,2.5745497,1.7490411,,,,,,,,,,,,,, -130600,2.744397,1.6847067,,,,,,,,,,,,,, -130700,2.612016,1.6152071,,,,,,,,,,,,,, -130800,2.6954403,1.669284,,,,,,,,,,,,,, -130900,2.4601295,1.791188,,,,,,,,,,,,,, -131000,2.7028453,1.665195,,,,,,,,,,,,,, -131100,2.8108864,1.772662,,,,,,,,,,,,,, -131116,,,0.6050900816917419,1.635583519935608,0.5597000122070312,1.9223034381866453,50000.0,0.433100014925003,2.7411255836486816,10000.0,43901.29053258896,45464.13428735733,43901.29053258896,1554.9034173488617,3.656316041946411,0.0 -131200,2.6736135,1.6752917,,,,,,,,,,,,,, -131300,2.6985548,1.703407,,,,,,,,,,,,,, -131400,2.6114502,1.6734643,,,,,,,,,,,,,, -131500,3.2438402,1.8729569,,,,,,,,,,,,,, -131600,2.6802516,1.7111241,,,,,,,,,,,,,, -131700,2.944066,1.6830742,,,,,,,,,,,,,, -131800,2.62069,1.6994381,,,,,,,,,,,,,, -131900,2.6351027,1.6181285,,,,,,,,,,,,,, -132000,2.5685315,1.6732265,,,,,,,,,,,,,, -132100,2.5972524,1.5896506,,,,,,,,,,,,,, -132200,2.7336648,1.6432683,,,,,,,,,,,,,, -132300,2.9022195,1.6526126,,,,,,,,,,,,,, -132400,2.8917236,1.8369495,,,,,,,,,,,,,, -132500,2.5403664,1.6030254,,,,,,,,,,,,,, -132600,2.8270252,1.6172873,,,,,,,,,,,,,, -132644,,,0.6350247263908386,1.4889466762542725,0.5819799900054932,1.7952096462249756,50000.0,0.457500010728836,2.617147445678711,10000.0,44411.38590765,45992.00921392441,44411.38590765,1572.5753400325775,3.7136833667755127,0.0 -132700,3.0006588,1.6462651,,,,,,,,,,,,,, -132800,2.516755,1.6104329,,,,,,,,,,,,,, -132900,2.9655354,1.8279808,,,,,,,,,,,,,, -133000,2.7040126,1.6518723,,,,,,,,,,,,,, -133100,2.6501887,1.6540558,,,,,,,,,,,,,, -133200,2.7300386,1.7046357,,,,,,,,,,,,,, -133300,2.7296872,1.6815501,,,,,,,,,,,,,, -133400,2.9106593,1.5912852,,,,,,,,,,,,,, -133500,2.6928895,1.6887097,,,,,,,,,,,,,, -133600,3.1483939,1.7430099,,,,,,,,,,,,,, -133700,2.5913503,1.6464533,,,,,,,,,,,,,, -133800,2.474372,1.6167556,,,,,,,,,,,,,, -133900,2.82757,1.7073648,,,,,,,,,,,,,, -134000,2.855151,1.587257,,,,,,,,,,,,,, -134100,2.82174,1.7322986,,,,,,,,,,,,,, -134172,,,0.6775948405265808,1.2720587253570557,0.6111199855804443,1.6209615468978882,50000.0,0.4809000194072723,2.401205778121948,10000.0,44921.41864657402,46519.67265796661,44921.41864657402,1590.098914861679,3.770865440368652,0.0 -134200,2.5212543,1.5610938,,,,,,,,,,,,,, -134300,2.797297,1.686489,,,,,,,,,,,,,, -134400,2.6236887,1.637596,,,,,,,,,,,,,, -134500,2.9265566,1.6764319,,,,,,,,,,,,,, -134600,2.887561,1.7263649,,,,,,,,,,,,,, -134700,2.9410405,1.6756574,,,,,,,,,,,,,, -134800,2.7205029,1.6014922,,,,,,,,,,,,,, -134900,2.6284473,1.5655496,,,,,,,,,,,,,, -135000,2.934686,1.6998305,,,,,,,,,,,,,, -135100,3.020195,1.7438995,,,,,,,,,,,,,, -135200,2.6544523,1.6203407,,,,,,,,,,,,,, -135300,2.8188438,1.615161,,,,,,,,,,,,,, -135400,2.780084,1.6210002,,,,,,,,,,,,,, -135500,2.7856407,1.5499249,,,,,,,,,,,,,, -135600,2.9144526,1.6022587,,,,,,,,,,,,,, -135700,,,0.6362802982330322,1.4815256595611572,0.5810799598693848,1.777900457382202,50000.0,0.4642000198364258,2.5052618980407715,10000.0,45431.4598543644,47047.70048522949,45431.4598543644,1607.9817507266998,3.824446439743042,0.0 -135700,3.0287101,1.7160194,,,,,,,,,,,,,, -135800,2.8902888,1.5998505,,,,,,,,,,,,,, -135900,2.749891,1.5730274,,,,,,,,,,,,,, -136000,3.2104447,1.6955876,,,,,,,,,,,,,, -136100,3.032385,1.6380377,,,,,,,,,,,,,, -136200,2.877368,1.6165714,,,,,,,,,,,,,, -136300,3.0256696,1.6798252,,,,,,,,,,,,,, -136400,2.8195293,1.6482947,,,,,,,,,,,,,, -136500,3.0675502,1.7023883,,,,,,,,,,,,,, -136600,3.1075501,1.6096314,,,,,,,,,,,,,, -136700,3.0722823,1.6124115,,,,,,,,,,,,,, -136800,2.7554667,1.6107347,,,,,,,,,,,,,, -136900,3.229811,1.6546586,,,,,,,,,,,,,, -137000,2.9152932,1.5777736,,,,,,,,,,,,,, -137100,3.0282953,1.5781432,,,,,,,,,,,,,, -137200,2.7352037,1.5484806,,,,,,,,,,,,,, -137228,,,0.672273576259613,1.3039746284484863,0.6099199652671814,1.6179298162460327,50000.0,0.497700035572052,2.3209965229034424,10000.0,45941.48925709725,47575.711570978165,45941.48925709725,1625.8593764305117,3.8787858486175537,0.0 -137300,3.0687852,1.547954,,,,,,,,,,,,,, -137400,2.8690598,1.5930772,,,,,,,,,,,,,, -137500,3.133787,1.7102641,,,,,,,,,,,,,, -137600,2.9677863,1.5554594,,,,,,,,,,,,,, -137700,2.8986847,1.5872473,,,,,,,,,,,,,, -137800,3.0010672,1.5887539,,,,,,,,,,,,,, -137900,2.984956,1.5116217,,,,,,,,,,,,,, -138000,2.8934367,1.513199,,,,,,,,,,,,,, -138100,3.1986294,1.6975802,,,,,,,,,,,,,, -138200,3.0192492,1.5790504,,,,,,,,,,,,,, -138300,2.752325,1.5195503,,,,,,,,,,,,,, -138400,2.9425302,1.6043444,,,,,,,,,,,,,, -138500,3.3937063,1.6465522,,,,,,,,,,,,,, -138600,3.0723267,1.590416,,,,,,,,,,,,,, -138700,2.729466,1.5730166,,,,,,,,,,,,,, -138756,,,0.6704400181770325,1.3004449605941772,0.6202799677848816,1.5722438097000122,50000.0,0.4925000369548797,2.3523638248443604,10000.0,46451.5012075901,48103.67853784561,46451.5012075901,1643.7108445167542,3.932407379150391,0.0 -138800,3.224176,1.6503463,,,,,,,,,,,,,, -138900,3.2829268,1.6736528,,,,,,,,,,,,,, -139000,3.1185534,1.6455348,,,,,,,,,,,,,, -139100,2.918933,1.5478152,,,,,,,,,,,,,, -139200,2.9824595,1.6109216,,,,,,,,,,,,,, -139300,3.224503,1.4689002,,,,,,,,,,,,,, -139400,3.0840042,1.5871223,,,,,,,,,,,,,, -139500,2.859714,1.4961333,,,,,,,,,,,,,, -139600,3.03953,1.5673233,,,,,,,,,,,,,, -139700,2.8846922,1.5069371,,,,,,,,,,,,,, -139800,3.3957891,1.5485114,,,,,,,,,,,,,, -139900,2.8742247,1.5341132,,,,,,,,,,,,,, -140000,3.1050286,1.4994721,,,,,,,,,,,,,, -140100,3.1964495,1.5959445,,,,,,,,,,,,,, -140200,2.986613,1.514946,,,,,,,,,,,,,, -140282,,,0.4956353604793548,2.2361655235290527,0.4678599834442138,2.429154634475708,50000.0,0.355400025844574,3.248229503631592,10000.0,46961.43891119957,48631.27277398109,46961.43891119957,1661.262745141983,3.98689866065979,0.0 -140300,2.7780185,1.4859116,,,,,,,,,,,,,, -140400,3.4080486,1.5779979,,,,,,,,,,,,,, -140500,2.938702,1.502888,,,,,,,,,,,,,, -140600,3.1588087,1.6381274,,,,,,,,,,,,,, -140700,3.1950262,1.5736265,,,,,,,,,,,,,, -140800,3.1951308,1.4414344,,,,,,,,,,,,,, -140900,3.5381353,1.5287737,,,,,,,,,,,,,, -141000,2.9650931,1.3954195,,,,,,,,,,,,,, -141100,3.2475505,1.5116659,,,,,,,,,,,,,, -141200,3.0768106,1.4522499,,,,,,,,,,,,,, -141300,2.923404,1.6803528,,,,,,,,,,,,,, -141400,3.435262,1.6692615,,,,,,,,,,,,,, -141500,3.1158576,1.6364992,,,,,,,,,,,,,, -141600,3.1172373,1.610973,,,,,,,,,,,,,, -141700,3.2159803,1.5509104,,,,,,,,,,,,,, -141800,3.4287531,1.57624,,,,,,,,,,,,,, -141810,,,0.7058752775192261,1.1420633792877195,0.6265599727630615,1.5559496879577637,50000.0,0.4934000372886657,2.3541271686553955,10000.0,47471.4825797081,49159.315249443054,47471.4825797081,1679.1552288532257,4.043623924255371,0.0 -141900,3.2625434,1.5410798,,,,,,,,,,,,,, -142000,3.1668708,1.5876756,,,,,,,,,,,,,, -142100,3.4441078,1.6255957,,,,,,,,,,,,,, -142200,3.0837107,1.5403069,,,,,,,,,,,,,, -142300,3.077242,1.4373021,,,,,,,,,,,,,, -142400,3.0531776,1.5480589,,,,,,,,,,,,,, -142500,2.9929388,1.4051802,,,,,,,,,,,,,, -142600,3.0052288,1.3977093,,,,,,,,,,,,,, -142700,2.9647155,1.531146,,,,,,,,,,,,,, -142800,3.1558955,1.4243169,,,,,,,,,,,,,, -142900,3.0340495,1.5966419,,,,,,,,,,,,,, -143000,3.3990076,1.571065,,,,,,,,,,,,,, -143100,3.2170799,1.4780905,,,,,,,,,,,,,, -143200,3.2541149,1.5355904,,,,,,,,,,,,,, -143300,3.1281025,1.5151168,,,,,,,,,,,,,, -143338,,,0.7002949714660645,1.1588937044143677,0.6348199844360352,1.5233947038650513,50000.0,0.5059000253677368,2.282422065734864,10000.0,47981.5309278965,49687.51860022545,47981.5309278965,1697.2029082775116,4.101156234741211,0.0 -143400,3.1848583,1.4828304,,,,,,,,,,,,,, -143500,3.476747,1.4816968,,,,,,,,,,,,,, -143600,3.0341492,1.4101357,,,,,,,,,,,,,, -143700,3.282734,1.539171,,,,,,,,,,,,,, -143800,3.2678943,1.5590553,,,,,,,,,,,,,, -143900,3.350779,1.4846694,,,,,,,,,,,,,, -144000,2.9933836,1.5442922,,,,,,,,,,,,,, -144100,3.1432638,1.5062042,,,,,,,,,,,,,, -144200,3.2906735,1.5293049,,,,,,,,,,,,,, -144300,3.3556678,1.5622213,,,,,,,,,,,,,, -144400,3.3885355,1.5063717,,,,,,,,,,,,,, -144500,3.4042573,1.5411189,,,,,,,,,,,,,, -144600,3.2355795,1.4264023,,,,,,,,,,,,,, -144700,3.2819238,1.5983837,,,,,,,,,,,,,, -144800,3.410215,1.6434413,,,,,,,,,,,,,, -144866,,,0.7127311825752258,1.1116129159927368,0.642300009727478,1.4744572639465332,50000.0,0.5162000060081482,2.229666948318481,10000.0,48491.59139537811,50215.48052382469,48491.59139537811,1714.9963533878326,4.159027576446533,0.0 -144900,3.250657,1.4988652,,,,,,,,,,,,,, -145000,3.4350228,1.5415045,,,,,,,,,,,,,, -145100,3.4996781,1.4565574,,,,,,,,,,,,,, -145200,3.4637089,1.4780256,,,,,,,,,,,,,, -145300,3.0885422,1.477702,,,,,,,,,,,,,, -145400,3.1518836,1.4620752,,,,,,,,,,,,,, -145500,3.1553035,1.4767466,,,,,,,,,,,,,, -145600,3.3956618,1.4029013,,,,,,,,,,,,,, -145700,3.4405007,1.4726067,,,,,,,,,,,,,, -145800,3.3049884,1.5616598,,,,,,,,,,,,,, -145900,3.0535939,1.3365953,,,,,,,,,,,,,, -146000,3.4679582,1.4650413,,,,,,,,,,,,,, -146100,3.5967708,1.5322576,,,,,,,,,,,,,, -146200,3.6016755,1.500392,,,,,,,,,,,,,, -146300,3.5360887,1.4785473,,,,,,,,,,,,,, -146394,,,0.7197863459587097,1.077343225479126,0.6534799933433533,1.424755334854126,50000.0,0.5260000228881836,2.159235954284668,10000.0,49001.64282894135,50743.13808107376,49001.64282894135,1732.4942715168,4.217477798461914,0.0 -146400,3.518324,1.4960555,,,,,,,,,,,,,, -146500,3.300567,1.4081265,,,,,,,,,,,,,, -146600,3.315117,1.5334563,,,,,,,,,,,,,, -146700,3.3806877,1.4939185,,,,,,,,,,,,,, -146800,3.7461522,1.4559479,,,,,,,,,,,,,, -146900,3.513479,1.561782,,,,,,,,,,,,,, -147000,3.4388032,1.4713237,,,,,,,,,,,,,, -147100,3.5148082,1.5564609,,,,,,,,,,,,,, -147200,3.3997302,1.3924047,,,,,,,,,,,,,, -147300,3.3816333,1.4952042,,,,,,,,,,,,,, -147400,3.5132616,1.5768445,,,,,,,,,,,,,, -147500,3.8737986,1.5095688,,,,,,,,,,,,,, -147600,3.3912876,1.4274677,,,,,,,,,,,,,, -147700,3.4420671,1.4654238,,,,,,,,,,,,,, -147800,3.4524002,1.4238341,,,,,,,,,,,,,, -147900,3.4032428,1.449625,,,,,,,,,,,,,, -147922,,,0.7180524468421936,1.0938856601715088,0.6541599631309509,1.4333339929580688,50000.0,0.526900053024292,2.1575634479522705,10000.0,49511.83917331696,51271.2620010376,49511.83917331696,1750.3120419979095,4.277265548706055,0.0 -148000,3.5539443,1.434267,,,,,,,,,,,,,, -148100,3.508266,1.4902407,,,,,,,,,,,,,, -148200,3.3434052,1.4287385,,,,,,,,,,,,,, -148300,3.680936,1.4042926,,,,,,,,,,,,,, -148400,3.5886033,1.4997989,,,,,,,,,,,,,, -148500,3.4613147,1.369122,,,,,,,,,,,,,, -148600,3.7533152,1.4943297,,,,,,,,,,,,,, -148700,3.2644482,1.4369495,,,,,,,,,,,,,, -148800,3.3046594,1.5347111,,,,,,,,,,,,,, -148900,3.5151608,1.3852398,,,,,,,,,,,,,, -149000,3.489998,1.3944354,,,,,,,,,,,,,, -149100,3.2897274,1.326438,,,,,,,,,,,,,, -149200,3.7942748,1.3770233,,,,,,,,,,,,,, -149300,3.2812514,1.4356735,,,,,,,,,,,,,, -149400,3.6674092,1.4805586,,,,,,,,,,,,,, -149450,,,0.7408322691917419,0.9889224767684937,0.6670199632644653,1.3719351291656494,50000.0,0.5319000482559204,2.132917881011963,10000.0,50021.862213134766,51799.39016842842,50021.862213134766,1768.307421207428,4.336929559707642,0.0 -149500,3.7283993,1.4748318,,,,,,,,,,,,,, -149600,3.8876302,1.5195013,,,,,,,,,,,,,, -149700,3.8208914,1.4122937,,,,,,,,,,,,,, -149800,3.7571993,1.4993703,,,,,,,,,,,,,, -149900,3.9060583,1.4073344,,,,,,,,,,,,,, -150000,3.5660605,1.4733083,,,,,,,,,,,,,, -150100,3.633433,1.3683366,,,,,,,,,,,,,, -150200,3.4719667,1.4832606,,,,,,,,,,,,,, -150300,4.6679893,1.5402867,,,,,,,,,,,,,, -150400,3.512021,1.3888543,,,,,,,,,,,,,, -150500,3.4323688,1.4082742,,,,,,,,,,,,,, -150600,3.7565017,1.4712298,,,,,,,,,,,,,, -150700,3.7625272,1.4056088,,,,,,,,,,,,,, -150800,3.7264144,1.3866496,,,,,,,,,,,,,, -150900,3.5983062,1.4209002,,,,,,,,,,,,,, -150978,,,0.7679368257522583,0.8675883412361145,0.6796999573707581,1.3259252309799194,50000.0,0.5469000339508057,2.0598459243774414,10000.0,50531.87656021118,52327.08454442024,50531.87656021118,1785.8828003406525,4.391806602478027,0.0 -151000,4.043656,1.3759743,,,,,,,,,,,,,, -151100,3.5824645,1.3528113,,,,,,,,,,,,,, -151200,3.7877853,1.3745571,,,,,,,,,,,,,, -151300,3.6744297,1.3667989,,,,,,,,,,,,,, -151400,3.6386545,1.3836939,,,,,,,,,,,,,, -151500,3.62489,1.4899107,,,,,,,,,,,,,, -151600,3.8844194,1.5359954,,,,,,,,,,,,,, -151700,3.7461398,1.4369649,,,,,,,,,,,,,, -151800,3.6948087,1.3534832,,,,,,,,,,,,,, -151900,4.1968718,1.3883864,,,,,,,,,,,,,, -152000,3.8344424,1.3769218,,,,,,,,,,,,,, -152100,3.782,1.3680714,,,,,,,,,,,,,, -152200,3.7783675,1.3006086,,,,,,,,,,,,,, -152300,3.888095,1.3601241,,,,,,,,,,,,,, -152400,3.7810192,1.2759849,,,,,,,,,,,,,, -152500,3.8969285,1.3053724,,,,,,,,,,,,,, -152506,,,0.7592275142669678,0.9073374271392822,0.6794599890708923,1.3260866403579712,50000.0,0.5541000366210938,2.066560745239258,10000.0,51042.05988812447,52855.28619623184,51042.05988812447,1803.7932722568512,4.44978928565979,0.0 -152600,3.8839247,1.508321,,,,,,,,,,,,,, -152700,3.9714754,1.4125085,,,,,,,,,,,,,, -152800,3.5871413,1.4151304,,,,,,,,,,,,,, -152900,3.6861563,1.4018638,,,,,,,,,,,,,, -153000,4.2529945,1.3728755,,,,,,,,,,,,,, -153100,3.6911623,1.3810388,,,,,,,,,,,,,, -153200,3.904281,1.3915622,,,,,,,,,,,,,, -153300,3.9585478,1.5480402,,,,,,,,,,,,,, -153400,3.993317,1.3416363,,,,,,,,,,,,,, -153500,3.8594306,1.4030263,,,,,,,,,,,,,, -153600,3.828422,1.2335265,,,,,,,,,,,,,, -153700,3.7013423,1.321666,,,,,,,,,,,,,, -153800,3.6076777,1.3624358,,,,,,,,,,,,,, -153900,3.803061,1.3114787,,,,,,,,,,,,,, -154000,3.9131136,1.3114802,,,,,,,,,,,,,, -154034,,,0.7609016299247742,0.9098778963088988,0.6835599541664124,1.3010308742523191,50000.0,0.5524000525474548,2.03236985206604,10000.0,51552.22501087189,53383.41737794876,51552.22501087189,1821.650677442551,4.508504867553711,0.0 -154100,3.8813124,1.4006475,,,,,,,,,,,,,, -154200,4.1319857,1.5192384,,,,,,,,,,,,,, -154300,3.6351411,1.3359374,,,,,,,,,,,,,, -154400,3.930943,1.3020153,,,,,,,,,,,,,, -154500,3.8720725,1.3398538,,,,,,,,,,,,,, -154600,4.0528746,1.3676412,,,,,,,,,,,,,, -154700,4.1604905,1.4607053,,,,,,,,,,,,,, -154800,3.8601947,1.2671022,,,,,,,,,,,,,, -154900,3.9038901,1.3419635,,,,,,,,,,,,,, -155000,4.0701327,1.4188138,,,,,,,,,,,,,, -155100,4.1990366,1.4218563,,,,,,,,,,,,,, -155200,4.2713995,1.374924,,,,,,,,,,,,,, -155300,3.9476852,1.3030361,,,,,,,,,,,,,, -155400,4.1923647,1.353024,,,,,,,,,,,,,, -155500,4.1588354,1.2686629,,,,,,,,,,,,,, -155562,,,0.7763671875,0.8441674709320068,0.6982199549674988,1.2281646728515625,50000.0,0.5663000345230103,1.9566088914871216,10000.0,52062.35582947731,53911.72821640968,52062.35582947731,1839.724959373474,4.563934326171875,0.0 -155600,4.095711,1.3191545,,,,,,,,,,,,,, -155700,3.6911092,1.3614243,,,,,,,,,,,,,, -155800,3.7605145,1.2530776,,,,,,,,,,,,,, -155900,3.8313885,1.3365579,,,,,,,,,,,,,, -156000,4.6759624,1.3623393,,,,,,,,,,,,,, -156100,3.9315143,1.4093641,,,,,,,,,,,,,, -156200,4.325929,1.3769684,,,,,,,,,,,,,, -156300,4.0216355,1.2471715,,,,,,,,,,,,,, -156400,4.1825557,1.2819265,,,,,,,,,,,,,, -156500,4.079772,1.2183769,,,,,,,,,,,,,, -156600,3.9217937,1.2853256,,,,,,,,,,,,,, -156700,3.9804077,1.2343261,,,,,,,,,,,,,, -156800,4.140772,1.384304,,,,,,,,,,,,,, -156900,3.969664,1.3174322,,,,,,,,,,,,,, -157000,4.3655443,1.4178424,,,,,,,,,,,,,, -157090,,,0.7760881781578064,0.8410550951957703,0.6970399618148804,1.2410316467285156,50000.0,0.5633000135421753,1.987952828407288,10000.0,52572.40886282921,54439.35976409912,52572.40886282921,1857.1919219493864,4.625617980957031,0.0 -157100,3.9259055,1.2304542,,,,,,,,,,,,,, -157200,4.483663,1.331212,,,,,,,,,,,,,, -157300,4.471752,1.3943315,,,,,,,,,,,,,, -157400,4.170481,1.2577479,,,,,,,,,,,,,, -157500,3.930213,1.1647925,,,,,,,,,,,,,, -157600,4.144679,1.313896,,,,,,,,,,,,,, -157700,4.184872,1.1450942,,,,,,,,,,,,,, -157800,4.1640854,1.3246119,,,,,,,,,,,,,, -157900,4.0448804,1.291046,,,,,,,,,,,,,, -158000,3.9711025,1.3095865,,,,,,,,,,,,,, -158100,4.1938286,1.2567408,,,,,,,,,,,,,, -158200,4.249487,1.373259,,,,,,,,,,,,,, -158300,4.1847205,1.2766553,,,,,,,,,,,,,, -158400,3.9026246,1.1937972,,,,,,,,,,,,,, -158500,4.0527563,1.2436131,,,,,,,,,,,,,, -158600,4.498453,1.2260026,,,,,,,,,,,,,, -158618,,,0.8107461333274841,0.691912055015564,0.7024999856948853,1.2118477821350098,50000.0,0.579800009727478,1.921416163444519,10000.0,53082.36680340767,54967.49105381966,53082.36680340767,1875.256622314453,4.684526920318604,0.0 -158700,4.0368013,1.2421889,,,,,,,,,,,,,, -158800,4.401575,1.3285865,,,,,,,,,,,,,, -158900,4.2264385,1.4050311,,,,,,,,,,,,,, -159000,4.1264024,1.2591038,,,,,,,,,,,,,, -159100,4.4225235,1.3350103,,,,,,,,,,,,,, -159200,4.3053794,1.2553856,,,,,,,,,,,,,, -159300,3.8253856,1.2092128,,,,,,,,,,,,,, -159400,4.4568334,1.2992209,,,,,,,,,,,,,, -159500,4.358603,1.2301892,,,,,,,,,,,,,, -159600,4.1464467,1.1799451,,,,,,,,,,,,,, -159700,4.204201,1.1749715,,,,,,,,,,,,,, -159800,4.0499153,1.2384814,,,,,,,,,,,,,, -159900,3.9817898,1.1614771,,,,,,,,,,,,,, -160000,4.2921357,1.3155022,,,,,,,,,,,,,, -160100,4.101977,1.2031295,,,,,,,,,,,,,, -160146,,,0.8052654266357422,0.7151588201522827,0.7060399651527405,1.1937520503997805,50000.0,0.5776000022888184,1.909941554069519,10000.0,53592.58766222,55495.69156050682,53592.58766222,1893.1282756328585,4.742824554443359,0.0 -160200,4.477134,1.4479425,,,,,,,,,,,,,, -160300,4.8709955,1.3243765,,,,,,,,,,,,,, -160400,4.889298,1.2805921,,,,,,,,,,,,,, -160500,4.7654786,1.2024202,,,,,,,,,,,,,, -160600,4.5707006,1.3117278,,,,,,,,,,,,,, -160700,4.3175635,1.2784022,,,,,,,,,,,,,, -160800,4.594934,1.2573676,,,,,,,,,,,,,, -160900,4.331704,1.2400699,,,,,,,,,,,,,, -161000,4.333441,1.1471598,,,,,,,,,,,,,, -161100,4.4038863,1.2015219,,,,,,,,,,,,,, -161200,4.380925,1.2885507,,,,,,,,,,,,,, -161300,4.615591,1.2905071,,,,,,,,,,,,,, -161400,4.3359447,1.157913,,,,,,,,,,,,,, -161500,4.3044953,1.1971672,,,,,,,,,,,,,, -161600,4.2596617,1.18738,,,,,,,,,,,,,, -161674,,,0.8055843114852905,0.7204368710517883,0.7084199786186218,1.1815768480300903,50000.0,0.5815000534057617,1.902944564819336,10000.0,54102.61943149567,56023.72791719437,54102.61943149567,1911.025577545166,4.800848007202148,0.0 -161700,4.5142555,1.2754002,,,,,,,,,,,,,, -161800,4.7200828,1.2013979,,,,,,,,,,,,,, -161900,4.232346,1.1841106,,,,,,,,,,,,,, -162000,4.4854136,1.2464772,,,,,,,,,,,,,, -162100,4.417578,1.1704609,,,,,,,,,,,,,, -162200,4.46126,1.2374574,,,,,,,,,,,,,, -162300,4.411191,1.2528533,,,,,,,,,,,,,, -162400,4.206425,1.2079501,,,,,,,,,,,,,, -162500,4.859926,1.2654796,,,,,,,,,,,,,, -162600,4.076648,1.0858271,,,,,,,,,,,,,, -162700,4.951346,1.2912369,,,,,,,,,,,,,, -162800,5.0482755,1.2325745,,,,,,,,,,,,,, -162900,4.6036997,1.2325062,,,,,,,,,,,,,, -163000,4.813452,1.2078501,,,,,,,,,,,,,, -163100,5.0939064,1.1955961,,,,,,,,,,,,,, -163200,4.4678383,1.1318052,,,,,,,,,,,,,, -163201,,,0.8162667155265808,0.6751604676246643,0.7181999683380127,1.147418975830078,50000.0,0.5901000499725342,1.8701530694961548,10000.0,54612.54327106476,56551.63760781288,54612.54327106476,1928.901858329773,4.860100746154785,0.0 -163300,4.5154066,1.1413201,,,,,,,,,,,,,, -163400,4.853083,1.2573159,,,,,,,,,,,,,, -163500,4.3460073,1.1798013,,,,,,,,,,,,,, -163600,4.9158344,1.1533525,,,,,,,,,,,,,, -163700,4.496633,1.1982317,,,,,,,,,,,,,, -163800,4.8503246,1.1864666,,,,,,,,,,,,,, -163900,4.608239,1.159291,,,,,,,,,,,,,, -164000,4.6159983,1.2171557,,,,,,,,,,,,,, -164100,4.925199,1.2555516,,,,,,,,,,,,,, -164200,5.256984,1.2922716,,,,,,,,,,,,,, -164300,4.561747,1.1033314,,,,,,,,,,,,,, -164400,4.7067127,1.2032223,,,,,,,,,,,,,, -164500,4.4457684,1.0474097,,,,,,,,,,,,,, -164600,4.6278696,1.1521659,,,,,,,,,,,,,, -164700,4.722069,1.2324274,,,,,,,,,,,,,, -164729,,,0.8210498690605164,0.6503438353538513,0.7199400067329407,1.1440032720565796,50000.0,0.5978000164031982,1.8753104209899905,10000.0,55122.71467757225,57079.52602314949,55122.71467757225,1946.5123527050016,4.916692018508911,0.0 -164800,4.7020683,1.2454816,,,,,,,,,,,,,, -164900,4.8433,1.2008605,,,,,,,,,,,,,, -165000,4.3970137,1.131036,,,,,,,,,,,,,, -165100,4.641318,1.2163842,,,,,,,,,,,,,, -165200,4.748372,1.1099375,,,,,,,,,,,,,, -165300,4.463599,1.1926845,,,,,,,,,,,,,, -165400,4.650483,1.189434,,,,,,,,,,,,,, -165500,5.1311536,1.181847,,,,,,,,,,,,,, -165600,4.628192,1.0988511,,,,,,,,,,,,,, -165700,5.337611,1.1853421,,,,,,,,,,,,,, -165800,4.605799,1.1040074,,,,,,,,,,,,,, -165900,5.013404,1.1674162,,,,,,,,,,,,,, -166000,4.792987,1.1909297,,,,,,,,,,,,,, -166100,4.9042435,1.0685018,,,,,,,,,,,,,, -166200,4.9017167,1.1532269,,,,,,,,,,,,,, -166257,,,0.8339245915412903,0.5999334454536438,0.720579981803894,1.1352955102920532,50000.0,0.5962000489234924,1.851595997810364,10000.0,55632.91140437126,57607.49398231506,55632.91140437126,1964.1750729084013,4.975133895874023,0.0 -166300,5.111663,1.1283152,,,,,,,,,,,,,, -166400,4.457679,1.0721598,,,,,,,,,,,,,, -166500,4.8363686,1.1198963,,,,,,,,,,,,,, -166600,5.245491,1.180757,,,,,,,,,,,,,, -166700,4.8489366,1.2143738,,,,,,,,,,,,,, -166800,4.91731,1.1490396,,,,,,,,,,,,,, -166900,5.2360682,1.1289241,,,,,,,,,,,,,, -167000,5.3355384,1.1976373,,,,,,,,,,,,,, -167100,5.2364936,1.1329002,,,,,,,,,,,,,, -167200,5.403057,1.150934,,,,,,,,,,,,,, -167300,5.354717,1.2099597,,,,,,,,,,,,,, -167400,4.6568885,1.0737718,,,,,,,,,,,,,, -167500,5.4705844,1.0637914,,,,,,,,,,,,,, -167600,4.8683634,1.160645,,,,,,,,,,,,,, -167700,4.5573215,1.1225849,,,,,,,,,,,,,, -167784,,,0.8476362824440002,0.5580589175224304,0.7263399958610535,1.1028032302856443,50000.0,0.5995000004768372,1.822930932044983,10000.0,56142.81165266037,58135.56666016579,56142.81165266037,1982.2370581626888,5.035584211349487,0.0 -167800,5.1912365,1.1742156,,,,,,,,,,,,,, -167900,5.1676316,1.1734421,,,,,,,,,,,,,, -168000,5.368926,1.1626636,,,,,,,,,,,,,, -168100,5.4293237,1.0582117,,,,,,,,,,,,,, -168200,5.605987,1.1575453,,,,,,,,,,,,,, -168300,4.882879,1.1053548,,,,,,,,,,,,,, -168400,4.9939475,1.0389216,,,,,,,,,,,,,, -168500,4.885245,1.0670898,,,,,,,,,,,,,, -168600,4.9726133,1.0731343,,,,,,,,,,,,,, -168700,4.728001,0.963377,,,,,,,,,,,,,, -168800,5.2490025,1.104239,,,,,,,,,,,,,, -168900,4.9208994,1.1415648,,,,,,,,,,,,,, -169000,4.838613,1.0293871,,,,,,,,,,,,,, -169100,5.2045436,1.0827588,,,,,,,,,,,,,, -169200,5.2170944,1.1123273,,,,,,,,,,,,,, -169300,5.157056,1.0833089,,,,,,,,,,,,,, -169312,,,0.8518813848495483,0.5361858010292053,0.7312799692153931,1.0846476554870603,50000.0,0.6077000498771667,1.800598382949829,10000.0,56652.88702845573,58664.32110500336,56652.88702845573,2000.8068022727969,5.0948405265808105,0.0 -169400,4.986576,1.0328349,,,,,,,,,,,,,, -169500,5.4347386,1.0205836,,,,,,,,,,,,,, -169600,5.4918504,1.1537215,,,,,,,,,,,,,, -169700,4.969001,1.0728748,,,,,,,,,,,,,, -169800,5.285594,1.2143306,,,,,,,,,,,,,, -169900,5.3387127,1.0897869,,,,,,,,,,,,,, -170000,5.2219777,1.1326584,,,,,,,,,,,,,, -170100,5.3449965,1.0467895,,,,,,,,,,,,,, -170200,4.909469,0.9799069,,,,,,,,,,,,,, -170300,5.4492664,1.1665952,,,,,,,,,,,,,, -170400,4.787754,1.0255747,,,,,,,,,,,,,, -170500,5.324461,1.0793988,,,,,,,,,,,,,, -170600,5.4458175,1.1341285,,,,,,,,,,,,,, -170700,5.250943,1.0384865,,,,,,,,,,,,,, -170800,4.797803,1.0746033,,,,,,,,,,,,,, -170840,,,0.8516421914100647,0.5271843671798706,0.7345199584960938,1.0806338787078855,50000.0,0.6133000254631042,1.7768418788909912,10000.0,57162.94032907486,59192.33838939667,57162.94032907486,2018.65958237648,5.155996322631836,0.0 -170900,5.675335,1.0487092,,,,,,,,,,,,,, -171000,5.3636026,1.145878,,,,,,,,,,,,,, -171100,5.1836667,1.0549215,,,,,,,,,,,,,, -171200,5.11349,1.0485578,,,,,,,,,,,,,, -171300,5.3337564,1.0320591,,,,,,,,,,,,,, -171400,4.839941,0.9786469,,,,,,,,,,,,,, -171500,5.2448926,0.9839504,,,,,,,,,,,,,, -171600,5.114656,0.9765968,,,,,,,,,,,,,, -171700,5.270953,1.0657711,,,,,,,,,,,,,, -171800,5.7078648,1.0620364,,,,,,,,,,,,,, -171900,5.1711845,1.0572617,,,,,,,,,,,,,, -172000,4.9873877,0.92394483,,,,,,,,,,,,,, -172100,5.1219945,1.0118836,,,,,,,,,,,,,, -172200,5.823935,0.9956916,,,,,,,,,,,,,, -172300,5.5250525,1.1626725,,,,,,,,,,,,,, -172368,,,0.8561064600944519,0.51559978723526,0.7358399629592896,1.0692557096481323,50000.0,0.6114000082015991,1.7621028423309326,10000.0,57673.15584611893,59720.38824558258,57673.15584611893,2036.3830211162567,5.2170140743255615,0.0 -172400,5.1656723,1.1251204,,,,,,,,,,,,,, -172500,5.331632,1.004823,,,,,,,,,,,,,, -172600,5.4048343,1.0561559,,,,,,,,,,,,,, -172700,4.9590893,1.0132631,,,,,,,,,,,,,, -172800,5.4354362,0.9841922,,,,,,,,,,,,,, -172900,5.059493,0.95910823,,,,,,,,,,,,,, -173000,5.2110777,1.0736911,,,,,,,,,,,,,, -173100,4.9990377,0.9524005,,,,,,,,,,,,,, -173200,5.3946238,1.0595176,,,,,,,,,,,,,, -173300,5.549783,0.98106694,,,,,,,,,,,,,, -173400,5.428185,1.1017449,,,,,,,,,,,,,, -173500,5.2245336,1.0436592,,,,,,,,,,,,,, -173600,5.0341306,1.0340858,,,,,,,,,,,,,, -173700,5.6198454,1.0522618,,,,,,,,,,,,,, -173800,5.6227126,1.0430613,,,,,,,,,,,,,, -173896,,,0.8584582209587097,0.5045436024665833,0.7406799793243408,1.059166669845581,50000.0,0.6181000471115112,1.763548493385315,10000.0,58183.20809054375,60248.25574231148,58183.20809054375,2054.0894026756287,5.276186227798462,0.0 -173900,5.1488757,1.0450897,,,,,,,,,,,,,, -174000,5.1291018,1.0031672,,,,,,,,,,,,,, -174100,5.070605,1.000469,,,,,,,,,,,,,, -174200,5.1555414,1.0127347,,,,,,,,,,,,,, -174300,5.6059146,1.1020322,,,,,,,,,,,,,, -174400,5.7813334,1.1062872,,,,,,,,,,,,,, -174500,5.2010474,1.040008,,,,,,,,,,,,,, -174600,5.4451537,1.0639827,,,,,,,,,,,,,, -174700,5.1942205,1.0068984,,,,,,,,,,,,,, -174800,5.111308,1.0130148,,,,,,,,,,,,,, -174900,5.389519,0.9850752,,,,,,,,,,,,,, -175000,6.050408,1.0654835,,,,,,,,,,,,,, -175100,5.189109,1.0863793,,,,,,,,,,,,,, -175200,5.2627716,1.0432129,,,,,,,,,,,,,, -175300,5.177179,1.020823,,,,,,,,,,,,,, -175400,5.1131153,1.0343616,,,,,,,,,,,,,, -175424,,,0.8757573366165161,0.4480269849300384,0.7421999573707581,1.0490195751190186,50000.0,0.6177000403404236,1.7469100952148438,10000.0,58693.388276577,60776.315844774246,58693.388276577,2071.8521168231964,5.343605995178223,0.0 -175500,5.289681,0.9678076,,,,,,,,,,,,,, -175600,5.436795,1.02275,,,,,,,,,,,,,, -175700,5.888881,1.0694407,,,,,,,,,,,,,, -175800,5.9609647,1.082009,,,,,,,,,,,,,, -175900,5.4063005,1.0579774,,,,,,,,,,,,,, -176000,5.029891,0.95308965,,,,,,,,,,,,,, -176100,5.4935384,0.998365,,,,,,,,,,,,,, -176200,5.4474487,0.98348486,,,,,,,,,,,,,, -176300,5.6848917,0.9888946,,,,,,,,,,,,,, -176400,5.964604,0.9278185,,,,,,,,,,,,,, -176500,5.6580524,0.96643907,,,,,,,,,,,,,, -176600,5.0413656,0.9706842,,,,,,,,,,,,,, -176700,5.178955,0.98714614,,,,,,,,,,,,,, -176800,5.3284793,1.0029483,,,,,,,,,,,,,, -176900,5.473533,1.0599568,,,,,,,,,,,,,, -176952,,,0.8742027878761292,0.4469842612743377,0.7447999715805054,1.043401122093201,50000.0,0.6208000183105469,1.750800848007202,10000.0,59203.56213974953,61304.11006069183,59203.56213974953,2089.3616197109222,5.404559135437012,0.0 -177000,5.57144,1.0022899,,,,,,,,,,,,,, -177100,5.669481,0.97539437,,,,,,,,,,,,,, -177200,5.5357375,1.0449542,,,,,,,,,,,,,, -177300,5.3582363,0.98245215,,,,,,,,,,,,,, -177400,5.641441,1.0141922,,,,,,,,,,,,,, -177500,5.4692025,0.95901847,,,,,,,,,,,,,, -177600,5.4004016,0.96932214,,,,,,,,,,,,,, -177700,5.5267143,1.045264,,,,,,,,,,,,,, -177800,5.793558,1.0357233,,,,,,,,,,,,,, -177900,5.0220795,0.9808598,,,,,,,,,,,,,, -178000,5.0994663,0.9792727,,,,,,,,,,,,,, -178100,5.360412,0.9476554,,,,,,,,,,,,,, -178200,5.8044014,1.0097497,,,,,,,,,,,,,, -178300,6.303057,1.034655,,,,,,,,,,,,,, -178400,5.3231473,0.9020612,,,,,,,,,,,,,, -178480,,,0.878348171710968,0.4332602918148041,0.7470600008964539,1.0316156148910522,50000.0,0.6232000589370728,1.740649938583374,10000.0,59713.66224193573,61831.87384080887,59713.66224193573,2106.9123561382294,5.46753716468811,0.0 -178500,5.634752,1.0512288,,,,,,,,,,,,,, -178600,5.1528034,0.8685947,,,,,,,,,,,,,, -178700,5.2180705,1.0384046,,,,,,,,,,,,,, -178800,5.176982,0.9540355,,,,,,,,,,,,,, -178900,5.941168,0.9986248,,,,,,,,,,,,,, -179000,5.6248446,0.9763817,,,,,,,,,,,,,, -179100,5.201477,0.9853245,,,,,,,,,,,,,, -179200,5.181643,1.0344801,,,,,,,,,,,,,, -179300,5.059471,0.86281794,,,,,,,,,,,,,, -179400,5.3482065,1.019677,,,,,,,,,,,,,, -179500,5.442405,0.9450011,,,,,,,,,,,,,, -179600,5.5601788,0.9465632,,,,,,,,,,,,,, -179700,5.2677965,0.88277406,,,,,,,,,,,,,, -179800,5.4805593,0.9371414,,,,,,,,,,,,,, -179900,5.3626513,1.0234956,,,,,,,,,,,,,, -180000,5.798477,0.8804064,,,,,,,,,,,,,, -180007,,,0.8775510191917419,0.4316553771495819,0.7484999895095825,1.030386447906494,50000.0,0.6262000203132629,1.7318052053451538,10000.0,60223.66440916061,62359.50334787369,60223.66440916061,2124.426818370819,5.530123233795166,0.0 -180100,6.229684,0.9474802,,,,,,,,,,,,,, -180200,5.675374,1.0800332,,,,,,,,,,,,,, -180300,5.4161434,0.98770356,,,,,,,,,,,,,, -180400,5.2747226,0.8871811,,,,,,,,,,,,,, -180500,5.222515,0.9059088,,,,,,,,,,,,,, -180600,5.7351327,1.0287443,,,,,,,,,,,,,, -180700,5.3829536,0.98492044,,,,,,,,,,,,,, -180800,4.9770665,0.9436463,,,,,,,,,,,,,, -180900,5.4949455,0.98914605,,,,,,,,,,,,,, -181000,5.594131,0.9677215,,,,,,,,,,,,,, -181100,5.0586953,0.93819714,,,,,,,,,,,,,, -181200,5.6582894,1.0149183,,,,,,,,,,,,,, -181300,5.658546,0.95113355,,,,,,,,,,,,,, -181400,5.813507,0.90665454,,,,,,,,,,,,,, -181500,5.431168,0.99380004,,,,,,,,,,,,,, -181535,,,0.883211076259613,0.4166988730430603,0.7495200037956238,1.027388334274292,50000.0,0.6265000104904175,1.7263579368591309,10000.0,60733.81407546997,62887.61942386627,60733.81407546997,2142.2800900936127,5.592935800552368,0.0 -181600,5.604099,0.85601795,,,,,,,,,,,,,, -181700,5.5564294,0.99430716,,,,,,,,,,,,,, -181800,6.1377788,1.0169637,,,,,,,,,,,,,, -181900,5.452004,0.8841994,,,,,,,,,,,,,, -182000,5.8130455,0.9743918,,,,,,,,,,,,,, -182100,5.3226166,0.98196936,,,,,,,,,,,,,, -182200,6.0835476,0.9687281,,,,,,,,,,,,,, -182300,5.474063,0.97249913,,,,,,,,,,,,,, -182400,5.8595796,0.9013879,,,,,,,,,,,,,, -182500,5.8780274,0.9691226,,,,,,,,,,,,,, -182600,5.0817842,0.91980165,,,,,,,,,,,,,, -182700,5.7918286,0.99464834,,,,,,,,,,,,,, -182800,5.7438884,0.9789021,,,,,,,,,,,,,, -182900,5.706108,0.9575723,,,,,,,,,,,,,, -183000,5.504881,0.9527129,,,,,,,,,,,,,, -183063,,,0.8834103941917419,0.4161964356899261,0.7491599917411804,1.0240086317062378,50000.0,0.6269000172615051,1.7243411540985107,10000.0,61243.983364105225,63415.52132463455,61243.983364105225,2159.90277671814,5.653109788894653,0.0 -183100,5.8148313,0.8984689,,,,,,,,,,,,,, -183200,5.330656,0.9011503,,,,,,,,,,,,,, -183300,5.509353,0.97260344,,,,,,,,,,,,,, -183400,5.8346443,0.95485485,,,,,,,,,,,,,, -183500,5.4913707,0.9068182,,,,,,,,,,,,,, -183600,5.624695,0.94938153,,,,,,,,,,,,,, -183700,5.5994997,0.8364866,,,,,,,,,,,,,, -183800,5.4438305,0.9207173,,,,,,,,,,,,,, -183900,5.686797,0.9530353,,,,,,,,,,,,,, -184000,5.6980777,0.90488535,,,,,,,,,,,,,, -184100,5.5764484,1.1055546,,,,,,,,,,,,,, -184200,5.5223303,0.9774849,,,,,,,,,,,,,, -184300,5.6496096,1.0755455,,,,,,,,,,,,,, -184400,5.3836927,0.9405545,,,,,,,,,,,,,, -184500,5.4651814,0.9012302,,,,,,,,,,,,,, -184590,,,0.8830117583274841,0.4159443080425262,0.7496199607849121,1.0223222970962524,50000.0,0.6244000196456909,1.7232787609100342,10000.0,61753.94016075134,63943.381813287735,61753.94016075134,2177.6936724185944,5.716378927230835,0.0 -184600,5.568681,0.9536032,,,,,,,,,,,,,, -184700,5.276056,0.9586228,,,,,,,,,,,,,, -184800,5.415563,0.85756373,,,,,,,,,,,,,, -184900,5.450569,0.9496587,,,,,,,,,,,,,, -185000,5.2778974,0.92739457,,,,,,,,,,,,,, -185100,6.082137,0.9718807,,,,,,,,,,,,,, -185200,5.6095543,0.95683986,,,,,,,,,,,,,, -185300,6.114453,0.9134925,,,,,,,,,,,,,, -185400,6.073182,0.9660857,,,,,,,,,,,,,, -185500,6.0392036,1.0028553,,,,,,,,,,,,,, -185600,5.5578475,0.9396361,,,,,,,,,,,,,, -185700,5.4170494,0.98484355,,,,,,,,,,,,,, -185800,5.5023546,0.95355946,,,,,,,,,,,,,, -185900,5.583215,0.9914416,,,,,,,,,,,,,, -186000,5.2130947,0.97095525,,,,,,,,,,,,,, -186100,5.707249,1.0169647,,,,,,,,,,,,,, -186118,,,0.8834103941917419,0.4153479635715484,0.7500199675559998,1.0227417945861816,50000.0,0.6252000331878662,1.7218387126922607,10000.0,62264.0898706913,64471.15631699562,62264.0898706913,2195.205705881119,5.779400110244751,0.0 -186200,5.484187,0.9580122,,,,,,,,,,,,,, -186300,5.577255,0.96718633,,,,,,,,,,,,,, -186400,5.7413573,0.97784275,,,,,,,,,,,,,, -186500,5.9236283,1.00898,,,,,,,,,,,,,, -186600,5.5393305,0.91081595,,,,,,,,,,,,,, -186666,,,0.8849848508834839,0.4112196862697601,0.7501599788665771,1.0221025943756104,50000.0,0.6258000135421753,1.7213950157165527,10000.0,62446.866736888885,64671.84037780762,62446.866736888885,2213.0312852859497,5.842752933502197,0.0 -186666,,,,,,,,,,,62446.866736888885,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/eval_measurements.csv deleted file mode 100644 index f4be768bd..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,125 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -17.916585683822632,0.0,34.15979290008545,1,0,34.15979290008545,0.0006000000284984,6.910250186920166,10000,52.07652568817139,0.0007772640092298,6.909999847412109,0.0009599999757483,6.910243988037109,50000 -35.58520984649658,0.0199439525604248,544.2960352897644,1514,0,544.2960352897644,0.1199000030755996,4.764464855194092,10000,579.9500658512115,0.1859654039144516,4.123931407928467,0.1693599969148636,4.246126651763916,50000 -53.290247440338135,0.0474295616149902,1054.3922047615051,3027,0,1054.3922047615051,0.2267000079154968,3.950527906417847,10000,1107.8278141021729,0.3240991532802582,3.1349141597747803,0.3021599948406219,3.293057441711426,50000 -70.83682298660278,0.0788309574127197,1564.3699560165403,4541,0,1564.3699560165403,0.3175000250339508,3.3052544593811035,10000,1635.432152509689,0.5005580186843872,2.1658248901367188,0.4260199964046478,2.557955503463745,50000 -88.65340995788574,0.1107234954833984,2074.408198595047,6057,0,2074.408198595047,0.3827000260353088,2.921358346939087,10000,2163.3676924705505,0.554109513759613,1.864499568939209,0.5052399635314941,2.154295921325684,50000 -106.74323630332948,0.1415345668792724,2584.638957977295,7574,0,2584.638957977295,0.4020000100135803,2.817870855331421,10000,2691.767830848694,0.5734215378761292,1.780144214630127,0.5266799926757812,2.0440306663513184,50000 -124.827782869339,0.1721091270446777,3094.6172511577606,9090,0,3094.6172511577606,0.4276000261306762,2.675454378128052,10000,3219.9098541736603,0.5981544852256775,1.6540687084197998,0.5498999953269958,1.909760475158692,50000 -142.7768428325653,0.1999351978302002,3604.871292591095,10607,0,3604.871292591095,0.4430000185966491,2.572227001190185,10000,3748.189495563507,0.6220703125,1.550428032875061,0.5730400085449219,1.805853247642517,50000 -160.3536171913147,0.2311937808990478,4114.898060321808,12124,0,4114.898060321808,0.4637000262737274,2.4511096477508545,10000,4275.873243093491,0.6409637928009033,1.4560710191726685,0.5910199880599976,1.7010043859481812,50000 -177.99181604385376,0.2608301639556885,4624.97420334816,13641,0,4624.97420334816,0.4702000319957733,2.450937509536743,10000,4803.666056632996,0.684988796710968,1.264384627342224,0.5950599908828735,1.6960663795471191,50000 -196.0883026123047,0.2926368713378906,5134.960388422012,15157,0,5134.960388422012,0.469400018453598,2.4176716804504395,10000,5331.829082727432,0.671894907951355,1.306368708610535,0.6003199815750122,1.6621453762054443,50000 -214.07593870162964,0.3269875049591064,5644.989506959915,16674,0,5644.989506959915,0.4724000096321106,2.433629035949707,10000,5859.929432630539,0.6599569320678711,1.363683581352234,0.5981000065803528,1.6931911706924438,50000 -231.9118676185608,0.356968879699707,6154.933459997177,18191,0,6154.933459997177,0.4735000133514404,2.4417972564697266,10000,6387.788614749908,0.6560705900192261,1.3798298835754397,0.5999000072479248,1.6763943433761597,50000 -250.8443946838379,0.3887760639190674,6665.012505054474,19708,0,6665.012505054474,0.4883000254631042,2.366896867752075,10000,6916.881340265274,0.6623684763908386,1.3440089225769043,0.6084399819374084,1.626311182975769,50000 -268.487957239151,0.4199604988098144,7175.144921779633,21225,0,7175.144921779633,0.4772000312805176,2.443645000457764,10000,7444.737069368362,0.6512874364852905,1.397564172744751,0.6033200025558472,1.6574549674987793,50000 -286.0397162437439,0.4515516757965088,7685.214646100998,22742,0,7685.214646100998,0.4816000163555145,2.404182910919189,10000,7972.439017057419,0.6950533986091614,1.1915576457977295,0.6091399788856506,1.6365070343017578,50000 -304.0058841705322,0.4865305423736572,8195.194686412811,24260,0,8195.194686412811,0.4821000099182129,2.375473737716675,10000,8500.469371795654,0.6743263602256775,1.2909504175186155,0.6072799563407898,1.6269196271896362,50000 -321.99264454841614,0.5148470401763916,8705.145479679108,25777,0,8705.145479679108,0.5004000067710876,2.291343927383423,10000,9028.48357987404,0.6813815236091614,1.2523821592330933,0.6219599843025208,1.5588935613632202,50000 -339.64798951148987,0.5484886169433594,9215.205620288849,27294,0,9215.205620288849,0.4772000312805176,2.400562047958374,10000,9556.281441688538,0.6639030575752258,1.3441171646118164,0.6081399917602539,1.6323024034500122,50000 -357.6402759552002,0.5823171138763428,9725.12575149536,28812,0,9725.12575149536,0.4860000312328338,2.371846914291382,10000,10084.276794195175,0.6672313213348389,1.3261438608169556,0.6168999671936035,1.6034150123596191,50000 -375.4659821987152,0.6163196563720703,10235.195373535156,30329,0,10235.195373535156,0.4878000319004059,2.3278603553771973,10000,10612.25537633896,0.6734893321990967,1.2919286489486694,0.610260009765625,1.616947889328003,50000 -393.60303115844727,0.6519203186035156,10745.252250671389,31847,0,10745.252250671389,0.499500036239624,2.2625062465667725,10000,11140.534215211868,0.7060347199440002,1.1288166046142578,0.6247999668121338,1.5537689924240112,50000 -411.4780659675598,0.6850986480712891,11255.309514045715,33365,0,11255.309514045715,0.4976000189781189,2.288537979125977,10000,11668.54907798767,0.6925222873687744,1.2043423652648926,0.6267799735069275,1.553866982460022,50000 -429.2470765113831,0.7219088077545166,11765.304506778715,34883,0,11765.304506778715,0.4850000143051147,2.387573003768921,10000,12196.39869570732,0.6701012253761292,1.305039405822754,0.6144799590110779,1.6147181987762451,50000 -446.7454869747162,0.7614481449127197,12275.59494113922,36401,0,12275.59494113922,0.480400025844574,2.386909008026123,10000,12724.276602745056,0.6705994606018066,1.3002781867980957,0.6158999800682068,1.5967798233032229,50000 -464.6421930789948,0.7978343963623047,12785.589760780334,37918,0,12785.589760780334,0.5107000470161438,2.221066236495972,10000,13252.253804206848,0.6964285373687744,1.1979633569717407,0.6372399926185608,1.4970052242279053,50000 -482.5713183879852,0.8367643356323242,13295.721585988998,39436,0,13295.721585988998,0.5151000022888184,2.183366060256958,10000,13780.402867794037,0.7492027878761292,0.9707934260368348,0.644320011138916,1.4667904376983645,50000 -500.2351813316345,0.8699560165405273,13805.760528564451,40954,0,13805.760528564451,0.5002000331878662,2.254570484161377,10000,14308.188065290453,0.711933970451355,1.1202871799468994,0.6308199763298035,1.5269010066986084,50000 -517.9955246448517,0.9070932865142822,14315.898855924606,42473,0,14315.898855924606,0.5049000382423401,2.2341115474700928,10000,14836.172902822496,0.7022480964660645,1.17045259475708,0.6351400017738342,1.5160480737686155,50000 -535.8300864696503,0.9444942474365234,14826.120551109314,43991,0,14826.120551109314,0.5031999945640564,2.266105890274048,10000,15364.315973997116,0.6902702450752258,1.219701886177063,0.6301400065422058,1.5335793495178225,50000 -553.6056270599365,0.9802684783935548,15336.121745824814,45509,0,15336.121745824814,0.5153000354766846,2.2475781440734863,10000,15892.178544044496,0.7076291441917419,1.1384071111679075,0.6442599892616272,1.483494520187378,50000 -571.2046520709991,1.0184855461120603,15846.208704471588,47027,0,15846.208704471588,0.5103000402450562,2.221822500228882,10000,16419.95223212242,0.6973453164100647,1.1850770711898804,0.6376000046730042,1.4864510297775269,50000 -588.8579633235931,1.0557844638824463,16356.312096595764,48545,0,16356.312096595764,0.5067000389099121,2.224776268005371,10000,16947.796236276627,0.7341358065605164,1.0244120359420776,0.6324599981307983,1.5245925188064575,50000 -606.7944264411926,1.0985126495361328,16866.38495707512,50063,0,16866.38495707512,0.5229000449180603,2.150914430618286,10000,17475.897213935852,0.7353116869926453,1.0175830125808716,0.6514399647712708,1.4259246587753296,50000 -624.6651320457458,1.1381235122680664,17376.622447252274,51583,0,17376.622447252274,0.5144000053405762,2.2002127170562744,10000,18004.093755483627,0.7122129797935486,1.10273540019989,0.6414600014686584,1.4719308614730835,50000 -642.1593663692474,1.176835536956787,17886.875101804733,53101,0,17886.875101804733,0.5200999975204468,2.178856134414673,10000,18531.92908811569,0.7161391973495483,1.1013281345367432,0.6503399610519409,1.445090889930725,50000 -659.8881301879883,1.2191162109375,18396.95570278168,54619,0,18396.95570278168,0.5213000178337097,2.167120695114136,10000,19059.83021473885,0.7102997303009033,1.1230549812316897,0.6488800048828125,1.440788507461548,50000 -677.467814207077,1.266390085220337,18906.910097837448,56137,0,18906.910097837448,0.5105000138282776,2.2319228649139404,10000,19587.460858106613,0.6968669891357422,1.1886688470840454,0.6329799890518188,1.519242763519287,50000 -695.7363994121552,1.3087108135223389,19417.0595304966,57656,0,19417.0595304966,0.5085000395774841,2.2348201274871826,10000,20115.9705028534,0.7419084906578064,0.977180004119873,0.6464999914169312,1.4487738609313965,50000 -713.406112909317,1.3524727821350098,19927.296014785767,59175,0,19927.296014785767,0.5190000534057617,2.1798503398895264,10000,20643.969654798508,0.7238121628761292,1.0625795125961304,0.6496399641036987,1.435991883277893,50000 -731.1252071857452,1.3902521133422852,20437.448969364166,60694,0,20437.448969364166,0.5267000198364258,2.1441636085510254,10000,21171.929075479507,0.7162587642669678,1.1006048917770386,0.6487999558448792,1.4462645053863523,50000 -748.7553527355194,1.4329195022583008,20947.50677704811,62213,0,20947.50677704811,0.511900007724762,2.21729040145874,10000,21699.70861840248,0.7052375674247742,1.142183542251587,0.6395800113677979,1.4831955432891846,50000 -766.4331395626068,1.493783473968506,21457.440640211105,63731,0,21457.440640211105,0.5250000357627869,2.16992735862732,10000,22227.43039250374,0.7198262214660645,1.084575653076172,0.657039999961853,1.4061697721481323,50000 -784.1061565876007,1.5358808040618896,21967.380613327023,65249,0,21967.380613327023,0.5270000100135803,2.1229050159454346,10000,22755.134654521946,0.732421875,1.026896834373474,0.6575999855995178,1.399552345275879,50000 -802.0245227813721,1.5785470008850098,22477.35139322281,66768,0,22477.35139322281,0.5331000089645386,2.115618705749512,10000,23283.116079568863,0.7518733739852905,0.9370354413986206,0.6610400080680847,1.3839858770370483,50000 -819.579030752182,1.6207971572875977,22987.300876379013,68286,0,22987.300876379013,0.5396000146865845,2.088154554367065,10000,23810.711325645447,0.7399752736091614,0.9882227778434752,0.6643999814987183,1.3710088729858398,50000 -837.5111167430878,1.6648674011230469,23497.229088783264,69804,0,23497.229088783264,0.5301000475883484,2.137101888656616,10000,24338.66460514069,0.736726701259613,1.0046941041946411,0.661359965801239,1.3887338638305664,50000 -855.4055006504059,1.7085967063903809,24007.27220749855,71322,0,24007.27220749855,0.5324000120162964,2.085169553756714,10000,24866.694973945618,0.7341557741165161,1.0238789319992063,0.66211998462677,1.3770734071731567,50000 -873.3924815654755,1.7509582042694092,24517.43322777748,72840,0,24517.43322777748,0.525600016117096,2.1446402072906494,10000,25394.93424105644,0.7183912396430969,1.0752381086349487,0.6584999561309814,1.404468059539795,50000 -891.2145557403564,1.7948503494262695,25027.50243353844,74359,0,25027.50243353844,0.531000018119812,2.1150431632995605,10000,25922.918542146683,0.7762077450752258,0.8492001891136169,0.660539984703064,1.3998095989227295,50000 -908.893723487854,1.839226007461548,25537.48215198517,75877,0,25537.48215198517,0.5397000312805176,2.06687068939209,10000,26450.67085123062,0.7547233700752258,0.92606920003891,0.668940007686615,1.3541022539138794,50000 -926.574179649353,1.88204026222229,26047.61165094376,77396,0,26047.61165094376,0.5337000489234924,2.0874693393707275,10000,26978.57283329964,0.7463727593421936,0.9656037092208862,0.6665999889373779,1.3637856245040894,50000 -944.2308526039124,1.929426908493042,26557.65236115456,78915,0,26557.65236115456,0.525700032711029,2.150360107421875,10000,27506.36722517013,0.724609375,1.0671521425247192,0.6492800116539001,1.4398260116577148,50000 -961.969643831253,1.9752976894378664,27067.61016345024,80433,0,27067.61016345024,0.5360000133514404,2.104771375656128,10000,28034.15910959244,0.7413902878761292,0.985304832458496,0.6678000092506409,1.350680589675903,50000 -979.7718670368196,2.0286335945129395,27577.544088840485,81951,0,27577.544088840485,0.5307000279426575,2.1657660007476807,10000,28561.997648715973,0.7299306392669678,1.0427826642990112,0.6603999733924866,1.401990294456482,50000 -997.4572043418884,2.074682235717773,28087.777733802795,83470,0,28087.777733802795,0.5425000190734863,2.0561273097991943,10000,29090.011644124985,0.7867506146430969,0.7846525311470032,0.6734799742698669,1.3287465572357178,50000 -1015.4390184879304,2.119399070739746,28597.72906398773,84988,0,28597.72906398773,0.5521000027656555,2.043912649154663,10000,29618.0389854908,0.7635124325752258,0.8853896856307983,0.6748999953269958,1.3213706016540527,50000 -1032.9682595729828,2.1635024547576904,29107.663843154907,86506,0,29107.663843154907,0.54830002784729,2.0270941257476807,10000,30145.59640312195,0.7604631781578064,0.9026753902435304,0.6757599711418152,1.3243170976638794,50000 -1050.7145743370056,2.2094473838806152,29617.69456934929,88024,0,29617.69456934929,0.541100025177002,2.064603328704834,10000,30673.468591213223,0.753926157951355,0.9288226366043092,0.6761999726295471,1.318540334701538,50000 -1068.669147491455,2.254591703414917,30127.660895824432,89542,0,30127.660895824432,0.5512000322341919,2.026613473892212,10000,31201.48393774033,0.7547233700752258,0.921085238456726,0.6802200078964233,1.296499252319336,50000 -1086.2701907157898,2.303988218307495,30637.73820257187,91061,0,30637.73820257187,0.5562000274658203,1.9942518472671509,10000,31729.2611079216,0.7591477632522583,0.9124574065208436,0.6842399835586548,1.2905629873275757,50000 -1104.1770284175873,2.3518271446228027,31147.86074543,92579,0,31147.86074543,0.5516000390052795,2.0014238357543945,10000,32257.387805223465,0.7926697731018066,0.767687976360321,0.6845600008964539,1.2874754667282104,50000 -1121.8362724781036,2.3979732990264893,31657.77154326439,94096,0,31657.77154326439,0.553600013256073,2.048421621322632,10000,32785.05307340622,0.7694913744926453,0.8633120656013489,0.6793199777603149,1.320208191871643,50000 -1139.6173412799835,2.44614315032959,32167.696942329407,95614,0,32167.696942329407,0.5561000108718872,2.018871784210205,10000,33312.85690236092,0.7730388641357422,0.8482105135917664,0.6853799819946289,1.288509726524353,50000 -1158.4141011238098,2.492197275161743,32677.73971581459,97132,0,32677.73971581459,0.5425000190734863,2.0723519325256348,10000,33841.79132437706,0.751375138759613,0.9182205200195312,0.6756399869918823,1.318800449371338,50000 -1176.1916980743408,2.540599822998047,33187.793897628784,98651,0,33187.793897628784,0.5623000264167786,1.983777403831482,10000,34369.720563173294,0.7703284025192261,0.8573641180992126,0.6903600096702576,1.2503340244293213,50000 -1193.8929476737976,2.5904340744018555,33697.73100566864,100169,0,33697.73100566864,0.5621000528335571,1.986859440803528,10000,34897.45811223984,0.7879264950752258,0.7837117314338684,0.6886000037193298,1.271813988685608,50000 -1211.8229558467865,2.642286777496338,34207.68478536606,101687,0,34207.68478536606,0.5557000041007996,1.995941519737244,10000,35425.44287323952,0.7952805757522583,0.7596470713615417,0.6904799938201904,1.2509621381759644,50000 -1229.4830603599548,2.692025899887085,34717.76279425621,103206,0,34717.76279425621,0.5639000535011292,2.0179336071014404,10000,35953.27993154526,0.7831433415412903,0.8068510293960571,0.6899600028991699,1.2710587978363037,50000 -1247.4493174552915,2.7401788234710693,35227.90310502052,104726,0,35227.90310502052,0.5525000095367432,2.056422472000122,10000,36481.48370862007,0.7667809128761292,0.8686758279800415,0.6767399907112122,1.3189637660980225,50000 -1265.152621269226,2.7908036708831787,35738.055804252625,106245,0,35738.055804252625,0.5676000118255615,1.992361187934876,10000,37009.439423561096,0.7792769074440002,0.8142668008804321,0.6949399709701538,1.2566287517547607,50000 -1283.092089176178,2.841114997863769,36248.185584783554,107764,0,36248.185584783554,0.5662000179290771,1.9687929153442385,10000,37537.60767388344,0.7824258208274841,0.7984607219696045,0.6949399709701538,1.2417380809783936,50000 -1300.768966436386,2.890819549560547,36758.29444336891,109283,0,36758.29444336891,0.5707000494003296,1.9671730995178225,10000,38065.49232196808,0.8298588991165161,0.6174313426017761,0.7004799842834473,1.232813596725464,50000 -1318.7363233566284,2.9423089027404785,37268.264142751694,110801,0,37268.264142751694,0.5663000345230103,1.978838562965393,10000,38593.53087067604,0.8005221486091614,0.7223421931266785,0.695580005645752,1.2395744323730469,50000 -1336.601214170456,2.992850542068481,37778.35658097267,112320,0,37778.35658097267,0.5700000524520874,1.9467543363571167,10000,39121.58758187294,0.800223171710968,0.7208084464073181,0.6976799964904785,1.232202649116516,50000 -1354.453326463699,3.041289567947388,38288.51729607582,113839,0,38288.51729607582,0.5770000219345093,1.942642331123352,10000,39649.69808101654,0.8005420565605164,0.7312487959861755,0.7041599750518799,1.2149962186813354,50000 -1372.5982236862185,3.089034080505371,38798.464625597,115357,0,38798.464625597,0.5707000494003296,1.9666895866394043,10000,40177.886984825134,0.7947823405265808,0.752173125743866,0.6983799934387207,1.2410300970077517,50000 -1390.4349954128263,3.141059637069702,39308.584755182266,116875,0,39308.584755182266,0.5698000192642212,1.955405712127685,10000,40705.944717884064,0.7978315949440002,0.7409703135490417,0.7010599970817566,1.2206995487213137,50000 -1408.42196559906,3.193824052810669,39818.51747059822,118393,0,39818.51747059822,0.5749000310897827,1.9618397951126096,10000,41233.966715574265,0.8346220850944519,0.5950406193733215,0.7032999992370605,1.215368151664734,50000 -1426.150707244873,3.247591495513916,40328.73139810562,119912,0,40328.73139810562,0.5819000005722046,1.8898450136184688,10000,41762.01231408119,0.8252750039100647,0.6352033615112305,0.7068799734115601,1.1854684352874756,50000 -1443.8689014911652,3.300241231918335,40838.95229148865,121431,0,40838.95229148865,0.5742000341415405,1.977544069290161,10000,42290.05305337906,0.8082947731018066,0.6803027391433716,0.7044599652290344,1.2126598358154297,50000 -1461.8070714473724,3.3489990234375,41349.047131061554,122950,0,41349.047131061554,0.5859000086784363,1.9121652841567995,10000,42818.18391633034,0.8224449753761292,0.6331303119659424,0.7143599987030029,1.1689563989639282,50000 -1479.3437526226044,3.399578094482422,41859.04773306847,124469,0,41859.04773306847,0.5866000056266785,1.896184682846069,10000,43345.82070159912,0.8160474896430969,0.6538271903991699,0.7134999632835388,1.174975872039795,50000 -1497.206482887268,3.458393812179565,42368.96591067314,125988,0,42368.96591067314,0.5868000388145447,1.908347964286804,10000,43873.70950007439,0.8224050998687744,0.6366180181503296,0.7099800109863281,1.189315915107727,50000 -1514.8715977668762,3.5093042850494385,42879.00650596619,127506,0,42879.00650596619,0.5898000001907349,1.8952957391738887,10000,44401.515615940094,0.8494299650192261,0.529978334903717,0.7148199677467346,1.1761623620986938,50000 -1532.6676306724548,3.56146502494812,43389.02200245857,129025,0,43389.02200245857,0.5893000364303589,1.9311414957046509,10000,44929.42847776413,0.8404814600944519,0.564034104347229,0.7139599919319153,1.1836059093475342,50000 -1550.256118774414,3.613261222839356,43899.07996249199,130544,0,43899.07996249199,0.5923000574111938,1.8697551488876345,10000,45457.17548465729,0.8401227593421936,0.5696372985839844,0.7194199562072754,1.1422617435455322,50000 -1568.0413491725922,3.679718255996704,44409.24393892288,132064,0,44409.24393892288,0.588200032711029,1.8888368606567385,10000,45985.240468502045,0.8375318646430969,0.5773704051971436,0.7168599963188171,1.1653879880905151,50000 -1585.9611542224884,3.7378854751586914,44919.25254154205,133582,0,44919.25254154205,0.5958000421524048,1.894065499305725,10000,46513.27612757683,0.8428332209587097,0.5535558462142944,0.720579981803894,1.1483376026153564,50000 -1604.497786283493,3.788525342941284,45429.48771595955,135101,0,45429.48771595955,0.5998000502586365,1.867887020111084,10000,47042.1477367878,0.8835100531578064,0.4169529676437378,0.72461998462677,1.1391019821166992,50000 -1622.4473087787628,3.860964298248291,45939.46535348892,136619,0,45939.46535348892,0.5943000316619873,1.8947100639343264,10000,47570.19637298584,0.8668088316917419,0.4636785984039306,0.724399983882904,1.1427175998687744,50000 -1640.4148676395416,3.913733720779419,46449.43815970421,138138,0,46449.43815970421,0.6041000485420227,1.837876796722412,10000,48098.23912549019,0.8622449040412903,0.4757066071033478,0.7257999777793884,1.1301318407058716,50000 -1658.3606095314026,3.971649646759033,46959.3910138607,139656,0,46959.3910138607,0.6045000553131104,1.8522108793258667,10000,48626.24474787712,0.8647361397743225,0.466713011264801,0.7309799790382385,1.1106938123703003,50000 -1676.064817905426,4.028349161148071,47469.4056904316,141173,0,47469.4056904316,0.6008000373840332,1.848008275032044,10000,49154.06947255135,0.8649353981018066,0.4697670340538025,0.7278599739074707,1.1257988214492798,50000 -1694.0392887592316,4.082364082336426,47979.4405105114,142691,0,47979.4405105114,0.5978000164031982,1.939948320388794,10000,49682.18212604523,0.8581592440605164,0.4860461056232452,0.7240399718284607,1.1514825820922852,50000 -1711.6912882328031,4.139669179916382,48489.41442799568,144208,0,48489.41442799568,0.6085000038146973,1.879884600639344,10000,50209.91431188584,0.8985769748687744,0.3564726412296295,0.7285599708557129,1.129584789276123,50000 -1729.3212552070618,4.194288969039917,48999.39666056633,145726,0,48999.39666056633,0.6066000461578369,1.8463207483291624,10000,50737.63028144837,0.8902861475944519,0.382175862789154,0.7315999865531921,1.1145838499069214,50000 -1746.9731595516205,4.246668100357056,49509.46116828919,147244,0,49509.46116828919,0.6055000424385071,1.839654922485352,10000,51265.44833254814,0.8928371667861938,0.3668327927589416,0.7371399998664856,1.0903639793395996,50000 -1764.7672312259674,4.303528308868408,50019.50054001808,148762,0,50019.50054001808,0.6096000075340271,1.825580596923828,10000,51793.38860464096,0.8951091766357422,0.3621106743812561,0.738599956035614,1.090329647064209,50000 -1782.3478038311005,4.357873916625977,50529.406017541885,150280,0,50529.406017541885,0.6111000180244446,1.879051089286804,10000,52320.97790455818,0.8956672549247742,0.356880247592926,0.7396799921989441,1.1052157878875732,50000 -1800.3187172412872,4.414421319961548,51039.55972290039,151799,0,51039.55972290039,0.613800048828125,1.824073076248169,10000,52849.2077562809,0.904715359210968,0.3321236968040466,0.7416599988937378,1.0797576904296875,50000 -1817.896065711975,4.474432706832886,51549.5005209446,153317,0,51549.5005209446,0.6152999997138977,1.827741861343384,10000,53376.83444476128,0.9268175959587096,0.2608273029327392,0.7423999905586243,1.080202579498291,50000 -1835.5662310123444,4.531611204147339,52059.46150302887,154835,0,52059.46150302887,0.6127000451087952,1.845734715461731,10000,53904.571982860565,0.9178690910339355,0.285111129283905,0.7413199543952942,1.083351969718933,50000 -1853.1430249214168,4.589997291564941,52569.59485697746,156353,0,52569.59485697746,0.6220000386238098,1.8379580974578853,10000,54432.38970851898,0.917629897594452,0.2818046808242798,0.7437999844551086,1.084424614906311,50000 -1870.8468675613403,4.649376630783081,53079.72857952118,157873,0,53079.72857952118,0.615600049495697,1.8537700176239007,10000,54960.33554697037,0.919702649116516,0.2783809602260589,0.7439000010490417,1.0915087461471558,50000 -1888.887590646744,4.708525896072388,53589.63150548935,159391,0,53589.63150548935,0.6165000200271606,1.855073928833008,10000,55488.38742780685,0.919742465019226,0.274879902601242,0.7453799843788147,1.084672927856445,50000 -1906.48606300354,4.766141414642334,54099.56414723396,160908,0,54099.56414723396,0.6208000183105469,1.851715087890625,10000,56016.02576851845,0.9289500713348388,0.247159719467163,0.7472599744796753,1.0752619504928589,50000 -1924.5159137248995,4.82317328453064,54609.62043738365,162426,0,54609.62043738365,0.624500036239624,1.8417563438415527,10000,56544.21882486344,0.9414859414100648,0.2091090530157089,0.7475999593734741,1.0705816745758057,50000 -1942.612450838089,4.879745721817017,55119.633002758026,163945,0,55119.633002758026,0.6200000047683716,1.8490190505981443,10000,57072.43365240097,0.9399114847183228,0.2118914574384689,0.7474600076675415,1.0750255584716797,50000 -1960.2603447437289,4.9395411014556885,55629.809804201126,165463,0,55629.809804201126,0.6241000294685364,1.847937822341919,10000,57600.36788249016,0.9395328164100648,0.2096386849880218,0.7483199834823608,1.0747122764587402,50000 -1978.183670282364,4.996540307998657,56139.72080612183,166981,0,56139.72080612183,0.6247000098228455,1.840732455253601,10000,58128.30862569809,0.9429607391357422,0.2056682556867599,0.750059962272644,1.0746914148330688,50000 -1995.9945611953733,5.058220386505127,56649.69009900093,168498,0,56649.69009900093,0.625700056552887,1.835061192512512,10000,58656.19987845421,0.945133090019226,0.1923914104700088,0.7511999607086182,1.0681169033050537,50000 -2013.9880871772768,5.121249437332153,57159.68464636803,170016,0,57159.68464636803,0.6259000301361084,1.8455358743667605,10000,59184.30030846596,0.9563735723495485,0.1631961911916732,0.7527999877929688,1.0671993494033811,50000 -2031.8099205493927,5.179990291595459,57669.88044476509,171535,0,57669.88044476509,0.6290000081062317,1.833582758903504,10000,59712.42644166946,0.9563934803009032,0.1602298766374588,0.7529199719429016,1.065573811531067,50000 -2049.632456064224,5.242520332336426,58179.970638751984,173053,0,58179.970638751984,0.627500057220459,1.8320350646972656,10000,60240.450879096985,0.9550382494926452,0.1676359325647354,0.7534399628639221,1.057241439819336,50000 -2067.8815484046936,5.303300857543945,58689.94938135147,174572,0,58689.94938135147,0.6304000020027161,1.833614706993103,10000,60768.78918766976,0.954300820827484,0.1634673774242401,0.7548999786376953,1.0598605871200562,50000 -2085.516634464264,5.3632285594940186,59199.97342252731,176090,0,59199.97342252731,0.6302000284194946,1.832032561302185,10000,61296.55711436272,0.9560347199440002,0.160325139760971,0.7560399770736694,1.0561996698379517,50000 -2103.398220539093,5.424897432327271,59710.12469863892,177608,0,59710.12469863892,0.6304000020027161,1.8350627422332764,10000,61824.70168447495,0.9580675959587096,0.1550871878862381,0.7548399567604065,1.0575575828552246,50000 -2121.054502010345,5.489213466644287,60220.22158074379,179127,0,60220.22158074379,0.629800021648407,1.835172295570373,10000,62352.568135261536,0.9602997303009032,0.1473343223333358,0.7542799711227417,1.0583137273788452,50000 -2138.7541739940643,5.5535314083099365,60730.1836707592,180645,0,60730.1836707592,0.6306000351905823,1.829922795295716,10000,62880.34313893318,0.9606983065605164,0.1453530639410019,0.7556399703025818,1.055141806602478,50000 -2156.6417529582977,5.613237142562866,61240.36377048493,182163,0,61240.36377048493,0.6303000450134277,1.828680157661438,10000,63408.5197532177,0.959004282951355,0.1480942070484161,0.7565400004386902,1.0530595779418943,50000 -2174.271763563156,5.677387714385986,61750.54144191742,183681,0,61750.54144191742,0.6308000087738037,1.8285574913024905,10000,63936.44031405449,0.9598413109779358,0.150454580783844,0.7559999823570251,1.053408980369568,50000 -2192.182664632797,5.741932153701782,62260.458035469055,185198,0,62260.458035469055,0.6305000185966492,1.8302977085113523,10000,64464.38110399246,0.9613759517669678,0.1429975479841232,0.7555399537086487,1.053418755531311,50000 -2210.0775923728943,5.804227828979492,62753.75569033623,186666,0,62753.75569033623,0.6309000253677368,1.83073890209198,10000,64975.68322634697,0.9604591727256775,0.14778749644756317,0.7561999559402466,1.0542693138122559,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/measurements.csv deleted file mode 100644 index 5fc338605..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1993 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.66344464,6.9194984,,,,,,,,,,,,,, -1,,,0.0007772640092298,6.909999847412109,0.0009599999757483,6.910243988037109,50000.0,0.0006000000284984,6.910250186920166,10000.0,34.15979290008545,52.07652568817139,34.15979290008545,17.916585683822632,0.0,0.0 -100,0.6624111,6.8069854,,,,,,,,,,,,,, -200,0.79211426,6.5287066,,,,,,,,,,,,,, -300,0.93708897,6.2537465,,,,,,,,,,,,,, -400,1.5437571,5.9739623,,,,,,,,,,,,,, -500,2.9322286,5.7622867,,,,,,,,,,,,,, -600,2.3782432,5.613275,,,,,,,,,,,,,, -700,3.767036,5.4244046,,,,,,,,,,,,,, -800,9.465177,5.2084436,,,,,,,,,,,,,, -900,5.591413,5.0347943,,,,,,,,,,,,,, -1000,5.0477767,5.0377383,,,,,,,,,,,,,, -1100,3.9609933,4.9338927,,,,,,,,,,,,,, -1200,5.8830113,4.8742385,,,,,,,,,,,,,, -1300,5.939958,4.6633673,,,,,,,,,,,,,, -1400,7.277764,4.615408,,,,,,,,,,,,,, -1500,8.378456,4.5031548,,,,,,,,,,,,,, -1514,,,0.1859654039144516,4.123931407928467,0.1693599969148636,4.246126651763916,50000.0,0.1199000030755996,4.764464855194092,10000.0,544.2960352897644,579.9500658512115,544.2960352897644,35.58520984649658,0.0199439525604248,0.0 -1600,9.11259,4.4334493,,,,,,,,,,,,,, -1700,9.974108,4.3347588,,,,,,,,,,,,,, -1800,6.342649,4.126088,,,,,,,,,,,,,, -1900,3.1343286,4.228733,,,,,,,,,,,,,, -2000,8.294042,3.975506,,,,,,,,,,,,,, -2100,5.6185455,3.9138875,,,,,,,,,,,,,, -2200,7.4012957,3.8930373,,,,,,,,,,,,,, -2300,6.467326,3.9193401,,,,,,,,,,,,,, -2400,5.109313,3.6129632,,,,,,,,,,,,,, -2500,5.1801248,3.6855097,,,,,,,,,,,,,, -2600,5.696466,3.6703975,,,,,,,,,,,,,, -2700,4.598977,3.648963,,,,,,,,,,,,,, -2800,4.7828684,3.5821912,,,,,,,,,,,,,, -2900,3.7168326,3.516387,,,,,,,,,,,,,, -3000,3.6465104,3.268432,,,,,,,,,,,,,, -3027,,,0.3240991532802582,3.1349141597747803,0.3021599948406219,3.293057441711426,50000.0,0.2267000079154968,3.950527906417847,10000.0,1054.3922047615051,1107.8278141021729,1054.3922047615051,53.290247440338135,0.0474295616149902,0.0 -3100,5.346083,3.3608723,,,,,,,,,,,,,, -3200,2.6739566,3.209617,,,,,,,,,,,,,, -3300,2.4312909,3.2323472,,,,,,,,,,,,,, -3400,3.3343089,3.3187232,,,,,,,,,,,,,, -3500,3.6475077,3.0643811,,,,,,,,,,,,,, -3600,2.846418,3.1036189,,,,,,,,,,,,,, -3700,2.5820045,3.0617447,,,,,,,,,,,,,, -3800,3.209566,2.925579,,,,,,,,,,,,,, -3900,2.639815,2.979244,,,,,,,,,,,,,, -4000,2.1272454,2.9588394,,,,,,,,,,,,,, -4100,2.9611478,3.0068111,,,,,,,,,,,,,, -4200,2.8221183,2.8413343,,,,,,,,,,,,,, -4300,3.2046876,2.8464928,,,,,,,,,,,,,, -4400,2.796657,2.8099089,,,,,,,,,,,,,, -4500,3.6736262,2.7503407,,,,,,,,,,,,,, -4541,,,0.5005580186843872,2.1658248901367188,0.4260199964046478,2.557955503463745,50000.0,0.3175000250339508,3.3052544593811035,10000.0,1564.3699560165403,1635.432152509689,1564.3699560165403,70.83682298660278,0.0788309574127197,0.0 -4600,2.5113957,2.6956809,,,,,,,,,,,,,, -4700,2.8868377,2.7234128,,,,,,,,,,,,,, -4800,2.6611104,2.680222,,,,,,,,,,,,,, -4900,2.2165415,2.6554422,,,,,,,,,,,,,, -5000,2.492279,2.5548556,,,,,,,,,,,,,, -5100,2.4955807,2.556593,,,,,,,,,,,,,, -5200,3.2575483,2.60256,,,,,,,,,,,,,, -5300,2.2708087,2.639179,,,,,,,,,,,,,, -5400,2.5401742,2.569795,,,,,,,,,,,,,, -5500,3.0794337,2.478432,,,,,,,,,,,,,, -5600,2.1609592,2.5947354,,,,,,,,,,,,,, -5700,1.5806814,2.3928747,,,,,,,,,,,,,, -5800,1.8063723,2.4509718,,,,,,,,,,,,,, -5900,2.0052826,2.4172447,,,,,,,,,,,,,, -6000,2.214011,2.3786488,,,,,,,,,,,,,, -6057,,,0.554109513759613,1.864499568939209,0.5052399635314941,2.154295921325684,50000.0,0.3827000260353088,2.921358346939087,10000.0,2074.408198595047,2163.3676924705505,2074.408198595047,88.65340995788574,0.1107234954833984,0.0 -6100,1.7314911,2.3389537,,,,,,,,,,,,,, -6200,2.143561,2.3900106,,,,,,,,,,,,,, -6300,2.1886587,2.5406396,,,,,,,,,,,,,, -6400,2.0689726,2.4784007,,,,,,,,,,,,,, -6500,2.0200646,2.419765,,,,,,,,,,,,,, -6600,2.4602368,2.3428793,,,,,,,,,,,,,, -6700,2.430586,2.3999438,,,,,,,,,,,,,, -6800,2.3869214,2.319611,,,,,,,,,,,,,, -6900,2.2819912,2.2878582,,,,,,,,,,,,,, -7000,2.4657786,2.383002,,,,,,,,,,,,,, -7100,1.786688,2.3659508,,,,,,,,,,,,,, -7200,1.8617152,2.3780584,,,,,,,,,,,,,, -7300,1.4208926,2.2417738,,,,,,,,,,,,,, -7400,2.217002,2.3265328,,,,,,,,,,,,,, -7500,1.8320768,2.3154292,,,,,,,,,,,,,, -7574,,,0.5734215378761292,1.780144214630127,0.5266799926757812,2.0440306663513184,50000.0,0.4020000100135803,2.817870855331421,10000.0,2584.638957977295,2691.767830848694,2584.638957977295,106.74323630332948,0.1415345668792724,0.0 -7600,1.6779605,2.271511,,,,,,,,,,,,,, -7700,2.1261392,2.1974764,,,,,,,,,,,,,, -7800,2.3497806,2.1489775,,,,,,,,,,,,,, -7900,2.0947967,2.2333019,,,,,,,,,,,,,, -8000,2.0676827,2.176919,,,,,,,,,,,,,, -8100,1.8988563,2.242451,,,,,,,,,,,,,, -8200,3.398979,2.2393453,,,,,,,,,,,,,, -8300,1.5630947,2.148047,,,,,,,,,,,,,, -8400,1.6981411,2.2412593,,,,,,,,,,,,,, -8500,1.7169671,2.1591585,,,,,,,,,,,,,, -8600,1.9160553,2.1612015,,,,,,,,,,,,,, -8700,1.6090511,2.1298008,,,,,,,,,,,,,, -8800,1.8163502,2.2138546,,,,,,,,,,,,,, -8900,1.9520241,2.142529,,,,,,,,,,,,,, -9000,2.0694134,2.2077782,,,,,,,,,,,,,, -9090,,,0.5981544852256775,1.6540687084197998,0.5498999953269958,1.909760475158692,50000.0,0.4276000261306762,2.675454378128052,10000.0,3094.6172511577606,3219.9098541736603,3094.6172511577606,124.827782869339,0.1721091270446777,0.0 -9100,2.0720146,1.9992706,,,,,,,,,,,,,, -9200,1.6421903,2.0575643,,,,,,,,,,,,,, -9300,1.79058,2.2049189,,,,,,,,,,,,,, -9400,1.8403804,2.0604007,,,,,,,,,,,,,, -9500,1.9472613,2.0188408,,,,,,,,,,,,,, -9600,1.6302847,2.144133,,,,,,,,,,,,,, -9700,2.406375,2.1813118,,,,,,,,,,,,,, -9800,1.7310781,2.0023592,,,,,,,,,,,,,, -9900,2.1148357,2.0967207,,,,,,,,,,,,,, -10000,1.8011159,2.0901148,,,,,,,,,,,,,, -10100,1.6116532,2.0566509,,,,,,,,,,,,,, -10200,2.4402978,2.228544,,,,,,,,,,,,,, -10300,1.550419,2.1220908,,,,,,,,,,,,,, -10400,2.0045924,2.1522827,,,,,,,,,,,,,, -10500,1.4904696,1.9428312,,,,,,,,,,,,,, -10600,1.4582161,1.934867,,,,,,,,,,,,,, -10607,,,0.6220703125,1.550428032875061,0.5730400085449219,1.805853247642517,50000.0,0.4430000185966491,2.572227001190185,10000.0,3604.871292591095,3748.189495563507,3604.871292591095,142.7768428325653,0.1999351978302002,0.0 -10700,1.4409064,2.0088747,,,,,,,,,,,,,, -10800,1.7649857,2.1191523,,,,,,,,,,,,,, -10900,1.326355,1.9730787,,,,,,,,,,,,,, -11000,1.9983298,1.970454,,,,,,,,,,,,,, -11100,1.4808112,2.00052,,,,,,,,,,,,,, -11200,1.5703595,1.9087675,,,,,,,,,,,,,, -11300,1.6704204,2.0628405,,,,,,,,,,,,,, -11400,1.9557422,2.071817,,,,,,,,,,,,,, -11500,1.6124359,2.018315,,,,,,,,,,,,,, -11600,1.3612239,2.1294968,,,,,,,,,,,,,, -11700,1.9638371,2.0376267,,,,,,,,,,,,,, -11800,2.045148,1.9492851,,,,,,,,,,,,,, -11900,1.6652535,2.0165584,,,,,,,,,,,,,, -12000,1.8643689,1.887658,,,,,,,,,,,,,, -12100,1.4087145,1.8321851,,,,,,,,,,,,,, -12124,,,0.6409637928009033,1.4560710191726685,0.5910199880599976,1.7010043859481812,50000.0,0.4637000262737274,2.4511096477508545,10000.0,4114.898060321808,4275.873243093491,4114.898060321808,160.3536171913147,0.2311937808990478,0.0 -12200,1.479215,1.8312459,,,,,,,,,,,,,, -12300,1.3904804,1.7978433,,,,,,,,,,,,,, -12400,1.6568143,2.005314,,,,,,,,,,,,,, -12500,1.7820321,2.0208483,,,,,,,,,,,,,, -12600,1.8933858,2.0123281,,,,,,,,,,,,,, -12700,1.4318682,1.8865738,,,,,,,,,,,,,, -12800,1.4032509,1.9326957,,,,,,,,,,,,,, -12900,1.385304,2.0155513,,,,,,,,,,,,,, -13000,1.2322575,2.0349772,,,,,,,,,,,,,, -13100,1.1984185,2.0205235,,,,,,,,,,,,,, -13200,1.4748129,1.8524374,,,,,,,,,,,,,, -13300,1.6486516,1.8630954,,,,,,,,,,,,,, -13400,1.5364486,2.004592,,,,,,,,,,,,,, -13500,1.6635022,1.8040292,,,,,,,,,,,,,, -13600,1.641962,1.8703647,,,,,,,,,,,,,, -13641,,,0.684988796710968,1.264384627342224,0.5950599908828735,1.6960663795471191,50000.0,0.4702000319957733,2.450937509536743,10000.0,4624.97420334816,4803.666056632996,4624.97420334816,177.99181604385376,0.2608301639556885,0.0 -13700,1.529274,1.9265256,,,,,,,,,,,,,, -13800,1.4419895,1.8732853,,,,,,,,,,,,,, -13900,1.5493745,1.9776479,,,,,,,,,,,,,, -14000,1.7391728,1.991881,,,,,,,,,,,,,, -14100,1.4758543,2.0345201,,,,,,,,,,,,,, -14200,1.4859619,2.0267754,,,,,,,,,,,,,, -14300,1.3742003,1.9528188,,,,,,,,,,,,,, -14400,1.5642077,1.7999678,,,,,,,,,,,,,, -14500,1.4734119,1.8920437,,,,,,,,,,,,,, -14600,1.9096205,2.011089,,,,,,,,,,,,,, -14700,1.6115074,2.0950997,,,,,,,,,,,,,, -14800,1.7018667,1.8088453,,,,,,,,,,,,,, -14900,1.6382072,1.8409315,,,,,,,,,,,,,, -15000,1.6749653,1.9071767,,,,,,,,,,,,,, -15100,1.5882016,1.9903555,,,,,,,,,,,,,, -15157,,,0.671894907951355,1.306368708610535,0.6003199815750122,1.6621453762054443,50000.0,0.469400018453598,2.4176716804504395,10000.0,5134.960388422012,5331.829082727432,5134.960388422012,196.0883026123047,0.2926368713378906,0.0 -15200,1.4817896,1.9261894,,,,,,,,,,,,,, -15300,1.5778269,1.9250278,,,,,,,,,,,,,, -15400,1.4223394,1.8219543,,,,,,,,,,,,,, -15500,1.4438282,1.8685801,,,,,,,,,,,,,, -15600,1.8429892,1.9696788,,,,,,,,,,,,,, -15700,1.5408559,1.8115867,,,,,,,,,,,,,, -15800,1.7811825,1.8655764,,,,,,,,,,,,,, -15900,1.5011618,1.8922862,,,,,,,,,,,,,, -16000,1.534951,1.8418033,,,,,,,,,,,,,, -16100,1.6132334,1.8680835,,,,,,,,,,,,,, -16200,2.064982,1.9738123,,,,,,,,,,,,,, -16300,1.720173,1.8970513,,,,,,,,,,,,,, -16400,1.631113,1.8643432,,,,,,,,,,,,,, -16500,1.6353827,1.8861177,,,,,,,,,,,,,, -16600,1.962767,1.776882,,,,,,,,,,,,,, -16674,,,0.6599569320678711,1.363683581352234,0.5981000065803528,1.6931911706924438,50000.0,0.4724000096321106,2.433629035949707,10000.0,5644.989506959915,5859.929432630539,5644.989506959915,214.07593870162964,0.3269875049591064,0.0 -16700,1.2947342,1.87536,,,,,,,,,,,,,, -16800,1.4627415,1.9153506,,,,,,,,,,,,,, -16900,1.6346954,1.7917632,,,,,,,,,,,,,, -17000,1.9608189,1.9785047,,,,,,,,,,,,,, -17100,1.4457879,1.8591094,,,,,,,,,,,,,, -17200,2.1580632,1.88503,,,,,,,,,,,,,, -17300,1.7560476,1.8420784,,,,,,,,,,,,,, -17400,1.5518769,1.8623167,,,,,,,,,,,,,, -17500,1.619396,1.8341622,,,,,,,,,,,,,, -17600,1.6555371,1.8853629,,,,,,,,,,,,,, -17700,1.7738916,1.8740646,,,,,,,,,,,,,, -17800,1.4917712,1.8989458,,,,,,,,,,,,,, -17900,1.6262993,1.7964662,,,,,,,,,,,,,, -18000,1.9979616,1.8853014,,,,,,,,,,,,,, -18100,1.6929624,1.7398126,,,,,,,,,,,,,, -18191,,,0.6560705900192261,1.3798298835754397,0.5999000072479248,1.6763943433761597,50000.0,0.4735000133514404,2.4417972564697266,10000.0,6154.933459997177,6387.788614749908,6154.933459997177,231.9118676185608,0.356968879699707,0.0 -18200,1.7593797,1.9185884,,,,,,,,,,,,,, -18300,1.6201738,1.7491367,,,,,,,,,,,,,, -18400,1.6866326,1.7325525,,,,,,,,,,,,,, -18500,1.6332737,1.8243914,,,,,,,,,,,,,, -18600,1.7979324,1.8454487,,,,,,,,,,,,,, -18700,1.710421,1.7847699,,,,,,,,,,,,,, -18800,1.6306641,1.9588494,,,,,,,,,,,,,, -18900,1.5796932,1.8452286,,,,,,,,,,,,,, -19000,1.6870387,1.7270548,,,,,,,,,,,,,, -19100,1.5426409,1.7528191,,,,,,,,,,,,,, -19200,1.7069323,1.8270367,,,,,,,,,,,,,, -19300,1.4800068,1.8790228,,,,,,,,,,,,,, -19400,1.7020825,1.7150434,,,,,,,,,,,,,, -19500,1.7130017,1.7554597,,,,,,,,,,,,,, -19600,1.5785666,1.7288499,,,,,,,,,,,,,, -19700,1.4516242,1.8475779,,,,,,,,,,,,,, -19708,,,0.6623684763908386,1.3440089225769043,0.6084399819374084,1.626311182975769,50000.0,0.4883000254631042,2.366896867752075,10000.0,6665.012505054474,6916.881340265274,6665.012505054474,250.8443946838379,0.3887760639190674,0.0 -19800,1.7225894,1.8879912,,,,,,,,,,,,,, -19900,1.5688943,1.7361807,,,,,,,,,,,,,, -20000,1.7550044,1.8617681,,,,,,,,,,,,,, -20100,1.6782769,1.844011,,,,,,,,,,,,,, -20200,1.7124578,1.9933736,,,,,,,,,,,,,, -20300,2.0009077,1.7610687,,,,,,,,,,,,,, -20400,2.0002182,1.7883955,,,,,,,,,,,,,, -20500,1.6148998,1.816967,,,,,,,,,,,,,, -20600,1.6884875,1.7524558,,,,,,,,,,,,,, -20700,2.1044939,1.8345742,,,,,,,,,,,,,, -20800,1.8021199,1.9537928,,,,,,,,,,,,,, -20900,1.8955414,1.8098612,,,,,,,,,,,,,, -21000,1.7120397,1.8797067,,,,,,,,,,,,,, -21100,1.7283866,1.9044232,,,,,,,,,,,,,, -21200,1.5935943,1.7805293,,,,,,,,,,,,,, -21225,,,0.6512874364852905,1.397564172744751,0.6033200025558472,1.6574549674987793,50000.0,0.4772000312805176,2.443645000457764,10000.0,7175.144921779633,7444.737069368362,7175.144921779633,268.487957239151,0.4199604988098144,0.0 -21300,1.6306225,1.7278197,,,,,,,,,,,,,, -21400,1.8954518,1.7121041,,,,,,,,,,,,,, -21500,1.5503947,1.79165,,,,,,,,,,,,,, -21600,1.5239083,1.7786541,,,,,,,,,,,,,, -21700,1.6466169,1.8331417,,,,,,,,,,,,,, -21800,1.6027143,1.7564492,,,,,,,,,,,,,, -21900,1.6931686,1.6811651,,,,,,,,,,,,,, -22000,1.5900224,1.7220078,,,,,,,,,,,,,, -22100,2.0097377,1.7083522,,,,,,,,,,,,,, -22200,1.9630853,1.8135875,,,,,,,,,,,,,, -22300,1.7339733,1.7593253,,,,,,,,,,,,,, -22400,2.3990164,1.7484283,,,,,,,,,,,,,, -22500,1.7634221,1.7676735,,,,,,,,,,,,,, -22600,1.6935054,1.720268,,,,,,,,,,,,,, -22700,1.8258785,1.8042456,,,,,,,,,,,,,, -22742,,,0.6950533986091614,1.1915576457977295,0.6091399788856506,1.6365070343017578,50000.0,0.4816000163555145,2.404182910919189,10000.0,7685.214646100998,7972.439017057419,7685.214646100998,286.0397162437439,0.4515516757965088,0.0 -22800,1.5819869,1.6327103,,,,,,,,,,,,,, -22900,1.5471214,1.7213402,,,,,,,,,,,,,, -23000,1.8179138,1.7974714,,,,,,,,,,,,,, -23100,1.746962,1.9609082,,,,,,,,,,,,,, -23200,1.90077,1.9082875,,,,,,,,,,,,,, -23300,1.8474897,1.8042699,,,,,,,,,,,,,, -23400,1.7355382,1.7205622,,,,,,,,,,,,,, -23500,1.6805203,1.7360951,,,,,,,,,,,,,, -23600,1.5588757,1.8662082,,,,,,,,,,,,,, -23700,1.6498388,1.7501858,,,,,,,,,,,,,, -23800,1.6122153,1.802878,,,,,,,,,,,,,, -23900,1.7782897,1.7618223,,,,,,,,,,,,,, -24000,1.545955,1.6701089,,,,,,,,,,,,,, -24100,1.5643371,1.8583658,,,,,,,,,,,,,, -24200,1.7018429,1.8414402,,,,,,,,,,,,,, -24260,,,0.6743263602256775,1.2909504175186155,0.6072799563407898,1.6269196271896362,50000.0,0.4821000099182129,2.375473737716675,10000.0,8195.194686412811,8500.469371795654,8195.194686412811,304.0058841705322,0.4865305423736572,0.0 -24300,1.697612,1.8395197,,,,,,,,,,,,,, -24400,1.507696,1.804873,,,,,,,,,,,,,, -24500,1.7032492,1.7851233,,,,,,,,,,,,,, -24600,1.6192722,1.7174988,,,,,,,,,,,,,, -24700,1.840575,1.8916125,,,,,,,,,,,,,, -24800,1.7779461,1.7646433,,,,,,,,,,,,,, -24900,1.7757595,1.7376484,,,,,,,,,,,,,, -25000,1.8883162,1.8524221,,,,,,,,,,,,,, -25100,1.6728451,1.6741099,,,,,,,,,,,,,, -25200,1.8736565,1.7201345,,,,,,,,,,,,,, -25300,1.887365,1.7679275,,,,,,,,,,,,,, -25400,1.8361812,1.7475517,,,,,,,,,,,,,, -25500,1.6433231,1.6739737,,,,,,,,,,,,,, -25600,1.6087388,1.7418158,,,,,,,,,,,,,, -25700,1.7214711,1.6238303,,,,,,,,,,,,,, -25777,,,0.6813815236091614,1.2523821592330933,0.6219599843025208,1.5588935613632202,50000.0,0.5004000067710876,2.291343927383423,10000.0,8705.145479679108,9028.48357987404,8705.145479679108,321.99264454841614,0.5148470401763916,0.0 -25800,1.9731874,1.802073,,,,,,,,,,,,,, -25900,1.7477868,1.8126915,,,,,,,,,,,,,, -26000,1.8245376,1.7439853,,,,,,,,,,,,,, -26100,1.6964384,1.7599959,,,,,,,,,,,,,, -26200,1.6922188,1.683889,,,,,,,,,,,,,, -26300,1.679767,1.7908212,,,,,,,,,,,,,, -26400,1.679549,1.7853835,,,,,,,,,,,,,, -26500,1.8519632,1.7670686,,,,,,,,,,,,,, -26600,1.6306708,1.8513759,,,,,,,,,,,,,, -26700,1.8973712,1.8429371,,,,,,,,,,,,,, -26800,1.7068073,1.7055805,,,,,,,,,,,,,, -26900,1.6229832,1.6789591,,,,,,,,,,,,,, -27000,1.7151203,1.567615,,,,,,,,,,,,,, -27100,1.7296673,1.892921,,,,,,,,,,,,,, -27200,1.6331791,1.6643875,,,,,,,,,,,,,, -27294,,,0.6639030575752258,1.3441171646118164,0.6081399917602539,1.6323024034500122,50000.0,0.4772000312805176,2.400562047958374,10000.0,9215.205620288849,9556.281441688538,9215.205620288849,339.64798951148987,0.5484886169433594,0.0 -27300,1.7209911,1.690364,,,,,,,,,,,,,, -27400,1.7125944,1.7304378,,,,,,,,,,,,,, -27500,1.5966582,1.7584422,,,,,,,,,,,,,, -27600,1.96898,1.73113,,,,,,,,,,,,,, -27700,1.7755308,1.7459738,,,,,,,,,,,,,, -27800,1.6804446,1.8004183,,,,,,,,,,,,,, -27900,2.229513,1.6577578,,,,,,,,,,,,,, -28000,1.613449,1.7273306,,,,,,,,,,,,,, -28100,1.7018701,1.842226,,,,,,,,,,,,,, -28200,1.8569434,1.7321366,,,,,,,,,,,,,, -28300,1.9311036,1.6994867,,,,,,,,,,,,,, -28400,1.8873328,1.8178034,,,,,,,,,,,,,, -28500,1.6704626,1.6435996,,,,,,,,,,,,,, -28600,1.7626472,1.6413454,,,,,,,,,,,,,, -28700,2.038272,1.7974008,,,,,,,,,,,,,, -28800,1.8794999,1.7636006,,,,,,,,,,,,,, -28812,,,0.6672313213348389,1.3261438608169556,0.6168999671936035,1.6034150123596191,50000.0,0.4860000312328338,2.371846914291382,10000.0,9725.12575149536,10084.276794195175,9725.12575149536,357.6402759552002,0.5823171138763428,0.0 -28900,1.8246449,1.7730347,,,,,,,,,,,,,, -29000,1.7055045,1.7131215,,,,,,,,,,,,,, -29100,1.8613131,1.6210659,,,,,,,,,,,,,, -29200,1.5365342,1.6409302,,,,,,,,,,,,,, -29300,1.6848435,1.6255305,,,,,,,,,,,,,, -29400,1.9738936,1.7090095,,,,,,,,,,,,,, -29500,1.7433579,1.7818749,,,,,,,,,,,,,, -29600,1.7980528,1.6869643,,,,,,,,,,,,,, -29700,1.6664851,1.7706264,,,,,,,,,,,,,, -29800,1.7806431,1.7704692,,,,,,,,,,,,,, -29900,1.7421279,1.7487915,,,,,,,,,,,,,, -30000,1.8130629,1.7664326,,,,,,,,,,,,,, -30100,1.879065,1.7034923,,,,,,,,,,,,,, -30200,2.1216857,1.6373385,,,,,,,,,,,,,, -30300,1.7389685,1.8370123,,,,,,,,,,,,,, -30329,,,0.6734893321990967,1.2919286489486694,0.610260009765625,1.616947889328003,50000.0,0.4878000319004059,2.3278603553771973,10000.0,10235.195373535156,10612.25537633896,10235.195373535156,375.4659821987152,0.6163196563720703,0.0 -30400,1.6676463,1.7094344,,,,,,,,,,,,,, -30500,2.0218968,1.7180321,,,,,,,,,,,,,, -30600,1.6565769,1.6373562,,,,,,,,,,,,,, -30700,1.7092136,1.7936039,,,,,,,,,,,,,, -30800,1.7362825,1.6838392,,,,,,,,,,,,,, -30900,1.8735955,1.7634125,,,,,,,,,,,,,, -31000,1.9905239,1.6756679,,,,,,,,,,,,,, -31100,1.6933302,1.7775443,,,,,,,,,,,,,, -31200,1.6833613,1.592122,,,,,,,,,,,,,, -31300,1.735799,1.6429563,,,,,,,,,,,,,, -31400,1.6845059,1.7396138,,,,,,,,,,,,,, -31500,1.6355087,1.6259706,,,,,,,,,,,,,, -31600,1.7621626,1.6161433,,,,,,,,,,,,,, -31700,1.8645747,1.733796,,,,,,,,,,,,,, -31800,1.9096862,1.843061,,,,,,,,,,,,,, -31847,,,0.7060347199440002,1.1288166046142578,0.6247999668121338,1.5537689924240112,50000.0,0.499500036239624,2.2625062465667725,10000.0,10745.252250671389,11140.534215211868,10745.252250671389,393.60303115844727,0.6519203186035156,0.0 -31900,1.7807997,1.8096926,,,,,,,,,,,,,, -32000,1.7268801,1.8329866,,,,,,,,,,,,,, -32100,1.784284,1.6061664,,,,,,,,,,,,,, -32200,1.7388287,1.6860657,,,,,,,,,,,,,, -32300,1.727747,1.6121638,,,,,,,,,,,,,, -32400,1.7352904,1.6093049,,,,,,,,,,,,,, -32500,1.9213762,1.7129713,,,,,,,,,,,,,, -32600,1.7940867,1.7149341,,,,,,,,,,,,,, -32700,2.123987,1.6638093,,,,,,,,,,,,,, -32800,2.1220288,1.6846832,,,,,,,,,,,,,, -32900,1.6646996,1.658567,,,,,,,,,,,,,, -33000,1.6462947,1.6677854,,,,,,,,,,,,,, -33100,1.8668808,1.7344579,,,,,,,,,,,,,, -33200,1.6072257,1.6946034,,,,,,,,,,,,,, -33300,2.1922705,1.751282,,,,,,,,,,,,,, -33365,,,0.6925222873687744,1.2043423652648926,0.6267799735069275,1.553866982460022,50000.0,0.4976000189781189,2.288537979125977,10000.0,11255.309514045715,11668.54907798767,11255.309514045715,411.4780659675598,0.6850986480712891,0.0 -33400,1.73513,1.6448255,,,,,,,,,,,,,, -33500,1.6948476,1.6539507,,,,,,,,,,,,,, -33600,1.735615,1.6839488,,,,,,,,,,,,,, -33700,1.6505597,1.8133593,,,,,,,,,,,,,, -33800,1.8093507,1.6373755,,,,,,,,,,,,,, -33900,1.6590669,1.6750737,,,,,,,,,,,,,, -34000,1.9991962,1.7741704,,,,,,,,,,,,,, -34100,1.527206,1.7118405,,,,,,,,,,,,,, -34200,1.856235,1.7407873,,,,,,,,,,,,,, -34300,1.8308469,1.8272076,,,,,,,,,,,,,, -34400,2.109743,1.7605869,,,,,,,,,,,,,, -34500,1.7472322,1.7186483,,,,,,,,,,,,,, -34600,1.7214799,1.7362733,,,,,,,,,,,,,, -34700,1.8045293,1.6518476,,,,,,,,,,,,,, -34800,1.6805174,1.7123576,,,,,,,,,,,,,, -34883,,,0.6701012253761292,1.305039405822754,0.6144799590110779,1.6147181987762451,50000.0,0.4850000143051147,2.387573003768921,10000.0,11765.304506778715,12196.39869570732,11765.304506778715,429.2470765113831,0.7219088077545166,0.0 -34900,1.9160686,1.799295,,,,,,,,,,,,,, -35000,1.9356877,1.7624612,,,,,,,,,,,,,, -35100,1.9311123,1.6993833,,,,,,,,,,,,,, -35200,1.8119289,1.6528776,,,,,,,,,,,,,, -35300,1.9235287,1.6560423,,,,,,,,,,,,,, -35400,1.7487334,1.7052193,,,,,,,,,,,,,, -35500,1.8032908,1.6342956,,,,,,,,,,,,,, -35600,1.8849291,1.6681292,,,,,,,,,,,,,, -35700,1.8254733,1.6241095,,,,,,,,,,,,,, -35800,1.6002496,1.6992729,,,,,,,,,,,,,, -35900,1.7959044,1.6780711,,,,,,,,,,,,,, -36000,1.7188033,1.7135983,,,,,,,,,,,,,, -36100,1.7118483,1.626118,,,,,,,,,,,,,, -36200,1.6172764,1.6327566,,,,,,,,,,,,,, -36300,1.649187,1.71584,,,,,,,,,,,,,, -36400,1.7616303,1.753985,,,,,,,,,,,,,, -36401,,,0.6705994606018066,1.3002781867980957,0.6158999800682068,1.5967798233032229,50000.0,0.480400025844574,2.386909008026123,10000.0,12275.59494113922,12724.276602745056,12275.59494113922,446.7454869747162,0.7614481449127197,0.0 -36500,1.9768198,1.8079735,,,,,,,,,,,,,, -36600,1.6071268,1.5172949,,,,,,,,,,,,,, -36700,1.9462801,1.6929413,,,,,,,,,,,,,, -36800,1.6878059,1.6273263,,,,,,,,,,,,,, -36900,1.7803113,1.7005153,,,,,,,,,,,,,, -37000,1.6075677,1.6626639,,,,,,,,,,,,,, -37100,1.5479972,1.5868745,,,,,,,,,,,,,, -37200,1.7414774,1.7103257,,,,,,,,,,,,,, -37300,1.8225609,1.701253,,,,,,,,,,,,,, -37400,2.126115,1.6613259,,,,,,,,,,,,,, -37500,1.8479836,1.696564,,,,,,,,,,,,,, -37600,1.6955462,1.6536225,,,,,,,,,,,,,, -37700,1.8279146,1.7127842,,,,,,,,,,,,,, -37800,1.7848513,1.586743,,,,,,,,,,,,,, -37900,1.7913139,1.751209,,,,,,,,,,,,,, -37918,,,0.6964285373687744,1.1979633569717407,0.6372399926185608,1.4970052242279053,50000.0,0.5107000470161438,2.221066236495972,10000.0,12785.589760780334,13252.253804206848,12785.589760780334,464.6421930789948,0.7978343963623047,0.0 -38000,1.7003258,1.6549995,,,,,,,,,,,,,, -38100,1.881621,1.5806873,,,,,,,,,,,,,, -38200,1.6179727,1.7126378,,,,,,,,,,,,,, -38300,1.9117817,1.7331982,,,,,,,,,,,,,, -38400,1.7906215,1.6153696,,,,,,,,,,,,,, -38500,1.8614137,1.6481211,,,,,,,,,,,,,, -38600,1.8177704,1.669172,,,,,,,,,,,,,, -38700,1.7094353,1.5970267,,,,,,,,,,,,,, -38800,1.9109297,1.6618745,,,,,,,,,,,,,, -38900,1.9571171,1.6760995,,,,,,,,,,,,,, -39000,1.8331558,1.4864004,,,,,,,,,,,,,, -39100,1.8346037,1.7435393,,,,,,,,,,,,,, -39200,1.7326118,1.6501926,,,,,,,,,,,,,, -39300,1.9925566,1.6261971,,,,,,,,,,,,,, -39400,1.8841301,1.6308763,,,,,,,,,,,,,, -39436,,,0.7492027878761292,0.9707934260368348,0.644320011138916,1.4667904376983645,50000.0,0.5151000022888184,2.183366060256958,10000.0,13295.721585988998,13780.402867794037,13295.721585988998,482.5713183879852,0.8367643356323242,0.0 -39500,1.816121,1.633582,,,,,,,,,,,,,, -39600,2.1246963,1.6638566,,,,,,,,,,,,,, -39700,1.8872383,1.6879768,,,,,,,,,,,,,, -39800,2.0558863,1.7811688,,,,,,,,,,,,,, -39900,1.9134314,1.6393429,,,,,,,,,,,,,, -40000,1.688912,1.6934319,,,,,,,,,,,,,, -40100,1.8283771,1.6093144,,,,,,,,,,,,,, -40200,1.7486985,1.6710098,,,,,,,,,,,,,, -40300,1.8382276,1.5654228,,,,,,,,,,,,,, -40400,1.9143403,1.7296337,,,,,,,,,,,,,, -40500,1.7294457,1.5671277,,,,,,,,,,,,,, -40600,2.8121812,1.703895,,,,,,,,,,,,,, -40700,1.7524246,1.7723322,,,,,,,,,,,,,, -40800,2.010658,1.7209893,,,,,,,,,,,,,, -40900,1.9718999,1.6829065,,,,,,,,,,,,,, -40954,,,0.711933970451355,1.1202871799468994,0.6308199763298035,1.5269010066986084,50000.0,0.5002000331878662,2.254570484161377,10000.0,13805.760528564451,14308.188065290453,13805.760528564451,500.2351813316345,0.8699560165405273,0.0 -41000,1.8214948,1.7265092,,,,,,,,,,,,,, -41100,1.7897288,1.6288437,,,,,,,,,,,,,, -41200,1.9011043,1.7581973,,,,,,,,,,,,,, -41300,2.029094,1.7139332,,,,,,,,,,,,,, -41400,1.879817,1.6055192,,,,,,,,,,,,,, -41500,1.8257176,1.7609553,,,,,,,,,,,,,, -41600,2.150733,1.65622,,,,,,,,,,,,,, -41700,2.0246828,1.6738099,,,,,,,,,,,,,, -41800,1.938521,1.5657686,,,,,,,,,,,,,, -41900,1.8390136,1.7076179,,,,,,,,,,,,,, -42000,1.7758952,1.5995505,,,,,,,,,,,,,, -42100,1.7834445,1.5689777,,,,,,,,,,,,,, -42200,2.0355883,1.6516593,,,,,,,,,,,,,, -42300,1.5845879,1.5628235,,,,,,,,,,,,,, -42400,2.1180782,1.7105912,,,,,,,,,,,,,, -42473,,,0.7022480964660645,1.17045259475708,0.6351400017738342,1.5160480737686155,50000.0,0.5049000382423401,2.2341115474700928,10000.0,14315.898855924606,14836.172902822496,14315.898855924606,517.9955246448517,0.9070932865142822,0.0 -42500,1.7195851,1.5751802,,,,,,,,,,,,,, -42600,1.8594745,1.6071126,,,,,,,,,,,,,, -42700,1.7778596,1.6511588,,,,,,,,,,,,,, -42800,2.2408187,1.6559634,,,,,,,,,,,,,, -42900,1.9815247,1.6542226,,,,,,,,,,,,,, -43000,1.8365386,1.691568,,,,,,,,,,,,,, -43100,1.8358436,1.6624337,,,,,,,,,,,,,, -43200,1.8407048,1.6293237,,,,,,,,,,,,,, -43300,1.6792988,1.6657096,,,,,,,,,,,,,, -43400,2.072963,1.5834758,,,,,,,,,,,,,, -43500,1.7668277,1.576503,,,,,,,,,,,,,, -43600,1.8191838,1.7083716,,,,,,,,,,,,,, -43700,2.070064,1.6041623,,,,,,,,,,,,,, -43800,1.7123715,1.7297735,,,,,,,,,,,,,, -43900,1.9784944,1.6289338,,,,,,,,,,,,,, -43991,,,0.6902702450752258,1.219701886177063,0.6301400065422058,1.5335793495178225,50000.0,0.5031999945640564,2.266105890274048,10000.0,14826.120551109314,15364.315973997116,14826.120551109314,535.8300864696503,0.9444942474365234,0.0 -44000,1.966048,1.6815963,,,,,,,,,,,,,, -44100,1.7816385,1.6417632,,,,,,,,,,,,,, -44200,1.9497297,1.7377212,,,,,,,,,,,,,, -44300,2.1530797,1.6360632,,,,,,,,,,,,,, -44400,1.9753865,1.7046794,,,,,,,,,,,,,, -44500,2.2057667,1.8028934,,,,,,,,,,,,,, -44600,1.7903287,1.6129138,,,,,,,,,,,,,, -44700,1.9245603,1.6838548,,,,,,,,,,,,,, -44800,1.7328713,1.6023086,,,,,,,,,,,,,, -44900,1.8551894,1.6201344,,,,,,,,,,,,,, -45000,1.7211262,1.5728983,,,,,,,,,,,,,, -45100,1.9782118,1.6855713,,,,,,,,,,,,,, -45200,1.7358586,1.5461361,,,,,,,,,,,,,, -45300,1.8042833,1.7159156,,,,,,,,,,,,,, -45400,1.6740447,1.5291489,,,,,,,,,,,,,, -45500,1.7785774,1.5076704,,,,,,,,,,,,,, -45509,,,0.7076291441917419,1.1384071111679075,0.6442599892616272,1.483494520187378,50000.0,0.5153000354766846,2.2475781440734863,10000.0,15336.121745824814,15892.178544044496,15336.121745824814,553.6056270599365,0.9802684783935548,0.0 -45600,1.7985115,1.6356901,,,,,,,,,,,,,, -45700,1.8492888,1.6351485,,,,,,,,,,,,,, -45800,2.096976,1.6190665,,,,,,,,,,,,,, -45900,1.9388297,1.6724772,,,,,,,,,,,,,, -46000,1.7999402,1.6571646,,,,,,,,,,,,,, -46100,1.7687987,1.5642034,,,,,,,,,,,,,, -46200,2.0837867,1.7362262,,,,,,,,,,,,,, -46300,1.7356812,1.4638367,,,,,,,,,,,,,, -46400,1.6477364,1.5274695,,,,,,,,,,,,,, -46500,1.6236386,1.6183896,,,,,,,,,,,,,, -46600,1.7876174,1.633049,,,,,,,,,,,,,, -46700,1.7187285,1.5544157,,,,,,,,,,,,,, -46800,1.81386,1.6284823,,,,,,,,,,,,,, -46900,1.8013829,1.6039684,,,,,,,,,,,,,, -47000,1.8012445,1.7190123,,,,,,,,,,,,,, -47027,,,0.6973453164100647,1.1850770711898804,0.6376000046730042,1.4864510297775269,50000.0,0.5103000402450562,2.221822500228882,10000.0,15846.208704471588,16419.95223212242,15846.208704471588,571.2046520709991,1.0184855461120603,0.0 -47100,1.9407839,1.5982251,,,,,,,,,,,,,, -47200,1.8304727,1.6540662,,,,,,,,,,,,,, -47300,2.04409,1.6757183,,,,,,,,,,,,,, -47400,1.9528738,1.6660197,,,,,,,,,,,,,, -47500,1.7958174,1.6824921,,,,,,,,,,,,,, -47600,1.668046,1.549953,,,,,,,,,,,,,, -47700,1.9422771,1.5733249,,,,,,,,,,,,,, -47800,1.9118327,1.6743715,,,,,,,,,,,,,, -47900,1.8836858,1.5789136,,,,,,,,,,,,,, -48000,1.819126,1.5625178,,,,,,,,,,,,,, -48100,1.6598558,1.5030446,,,,,,,,,,,,,, -48200,2.0352705,1.640595,,,,,,,,,,,,,, -48300,1.9092362,1.7686943,,,,,,,,,,,,,, -48400,1.814877,1.7898213,,,,,,,,,,,,,, -48500,2.0199738,1.5373608,,,,,,,,,,,,,, -48545,,,0.7341358065605164,1.0244120359420776,0.6324599981307983,1.5245925188064575,50000.0,0.5067000389099121,2.224776268005371,10000.0,16356.312096595764,16947.796236276627,16356.312096595764,588.8579633235931,1.0557844638824463,0.0 -48600,1.679946,1.5551212,,,,,,,,,,,,,, -48700,1.9481996,1.5584941,,,,,,,,,,,,,, -48800,1.9209868,1.589592,,,,,,,,,,,,,, -48900,1.7899909,1.641105,,,,,,,,,,,,,, -49000,2.1151965,1.5561231,,,,,,,,,,,,,, -49100,1.8841976,1.6094778,,,,,,,,,,,,,, -49200,2.1251926,1.5148872,,,,,,,,,,,,,, -49300,1.8746125,1.6593709,,,,,,,,,,,,,, -49400,1.9131987,1.7263904,,,,,,,,,,,,,, -49500,1.868555,1.5140021,,,,,,,,,,,,,, -49600,1.9097666,1.6674343,,,,,,,,,,,,,, -49700,1.9156669,1.6698825,,,,,,,,,,,,,, -49800,1.8777659,1.6795104,,,,,,,,,,,,,, -49900,1.8176186,1.6121223,,,,,,,,,,,,,, -50000,1.7071095,1.521069,,,,,,,,,,,,,, -50063,,,0.7353116869926453,1.0175830125808716,0.6514399647712708,1.4259246587753296,50000.0,0.5229000449180603,2.150914430618286,10000.0,16866.38495707512,17475.897213935852,16866.38495707512,606.7944264411926,1.0985126495361328,0.0 -50100,1.9927087,1.6448076,,,,,,,,,,,,,, -50200,1.7165222,1.6835982,,,,,,,,,,,,,, -50300,2.0626507,1.6527343,,,,,,,,,,,,,, -50400,1.9329213,1.5129985,,,,,,,,,,,,,, -50500,1.9157923,1.7693348,,,,,,,,,,,,,, -50600,2.0516949,1.6056069,,,,,,,,,,,,,, -50700,1.8186868,1.6095066,,,,,,,,,,,,,, -50800,1.9114405,1.62648,,,,,,,,,,,,,, -50900,1.917592,1.5499327,,,,,,,,,,,,,, -51000,1.7579722,1.6576884,,,,,,,,,,,,,, -51100,1.7236997,1.706196,,,,,,,,,,,,,, -51200,1.898023,1.62145,,,,,,,,,,,,,, -51300,1.7007691,1.7067152,,,,,,,,,,,,,, -51400,1.8295734,1.6305526,,,,,,,,,,,,,, -51500,1.9876603,1.6422622,,,,,,,,,,,,,, -51583,,,0.7122129797935486,1.10273540019989,0.6414600014686584,1.4719308614730835,50000.0,0.5144000053405762,2.2002127170562744,10000.0,17376.622447252274,18004.093755483627,17376.622447252274,624.6651320457458,1.1381235122680664,0.0 -51600,1.8889618,1.5175142,,,,,,,,,,,,,, -51700,2.0152993,1.6562033,,,,,,,,,,,,,, -51800,1.8129957,1.5988348,,,,,,,,,,,,,, -51900,2.004021,1.6531947,,,,,,,,,,,,,, -52000,1.8759596,1.6533196,,,,,,,,,,,,,, -52100,1.7328273,1.5228572,,,,,,,,,,,,,, -52200,2.0832043,1.4951059,,,,,,,,,,,,,, -52300,2.063225,1.6621959,,,,,,,,,,,,,, -52400,1.7860612,1.5941473,,,,,,,,,,,,,, -52500,1.9942813,1.5585781,,,,,,,,,,,,,, -52600,1.8669719,1.4744773,,,,,,,,,,,,,, -52700,1.8276782,1.6894586,,,,,,,,,,,,,, -52800,1.8370415,1.5531466,,,,,,,,,,,,,, -52900,1.7580898,1.5866845,,,,,,,,,,,,,, -53000,1.8062586,1.5283896,,,,,,,,,,,,,, -53100,1.9336963,1.7311959,,,,,,,,,,,,,, -53101,,,0.7161391973495483,1.1013281345367432,0.6503399610519409,1.445090889930725,50000.0,0.5200999975204468,2.178856134414673,10000.0,17886.875101804733,18531.92908811569,17886.875101804733,642.1593663692474,1.176835536956787,0.0 -53200,1.8229729,1.5193725,,,,,,,,,,,,,, -53300,1.8500069,1.6421709,,,,,,,,,,,,,, -53400,2.0272782,1.5112047,,,,,,,,,,,,,, -53500,1.9760464,1.5363163,,,,,,,,,,,,,, -53600,2.0567336,1.5518911,,,,,,,,,,,,,, -53700,1.9180503,1.6439672,,,,,,,,,,,,,, -53800,1.9848936,1.686425,,,,,,,,,,,,,, -53900,1.6902505,1.5536994,,,,,,,,,,,,,, -54000,1.7580223,1.5581872,,,,,,,,,,,,,, -54100,1.885272,1.681694,,,,,,,,,,,,,, -54200,1.8173496,1.6288662,,,,,,,,,,,,,, -54300,1.9961462,1.6732004,,,,,,,,,,,,,, -54400,2.0882616,1.5976646,,,,,,,,,,,,,, -54500,1.8555181,1.5435412,,,,,,,,,,,,,, -54600,1.7660192,1.4967513,,,,,,,,,,,,,, -54619,,,0.7102997303009033,1.1230549812316897,0.6488800048828125,1.440788507461548,50000.0,0.5213000178337097,2.167120695114136,10000.0,18396.95570278168,19059.83021473885,18396.95570278168,659.8881301879883,1.2191162109375,0.0 -54700,2.1254969,1.4984102,,,,,,,,,,,,,, -54800,1.8924887,1.6953007,,,,,,,,,,,,,, -54900,1.7515669,1.6191034,,,,,,,,,,,,,, -55000,1.817169,1.5965656,,,,,,,,,,,,,, -55100,2.0102997,1.6537611,,,,,,,,,,,,,, -55200,1.8417581,1.5828036,,,,,,,,,,,,,, -55300,1.7404966,1.4995093,,,,,,,,,,,,,, -55400,1.8124743,1.5866842,,,,,,,,,,,,,, -55500,1.7575953,1.5190845,,,,,,,,,,,,,, -55600,1.8677838,1.4847051,,,,,,,,,,,,,, -55700,1.943709,1.5792341,,,,,,,,,,,,,, -55800,2.0155551,1.5846648,,,,,,,,,,,,,, -55900,2.0768285,1.6379094,,,,,,,,,,,,,, -56000,2.0814338,1.5670248,,,,,,,,,,,,,, -56100,1.9315655,1.491107,,,,,,,,,,,,,, -56137,,,0.6968669891357422,1.1886688470840454,0.6329799890518188,1.519242763519287,50000.0,0.5105000138282776,2.2319228649139404,10000.0,18906.910097837448,19587.460858106613,18906.910097837448,677.467814207077,1.266390085220337,0.0 -56200,2.0227046,1.5789273,,,,,,,,,,,,,, -56300,2.053976,1.4771985,,,,,,,,,,,,,, -56400,1.8874779,1.6241806,,,,,,,,,,,,,, -56500,1.8760841,1.6148031,,,,,,,,,,,,,, -56600,1.8286705,1.7115784,,,,,,,,,,,,,, -56700,2.0001357,1.6666217,,,,,,,,,,,,,, -56800,1.8866557,1.7032045,,,,,,,,,,,,,, -56900,1.7601248,1.535034,,,,,,,,,,,,,, -57000,1.9514345,1.6769838,,,,,,,,,,,,,, -57100,2.0996284,1.5306375,,,,,,,,,,,,,, -57200,1.7814412,1.5631111,,,,,,,,,,,,,, -57300,1.9763523,1.6285508,,,,,,,,,,,,,, -57400,2.094291,1.5768249,,,,,,,,,,,,,, -57500,2.0562792,1.5738525,,,,,,,,,,,,,, -57600,1.9512424,1.6041629,,,,,,,,,,,,,, -57656,,,0.7419084906578064,0.977180004119873,0.6464999914169312,1.4487738609313965,50000.0,0.5085000395774841,2.2348201274871826,10000.0,19417.0595304966,20115.9705028534,19417.0595304966,695.7363994121552,1.3087108135223389,0.0 -57700,1.9013517,1.5629897,,,,,,,,,,,,,, -57800,1.7672137,1.5264984,,,,,,,,,,,,,, -57900,1.860235,1.551216,,,,,,,,,,,,,, -58000,1.7884399,1.6441791,,,,,,,,,,,,,, -58100,2.0097969,1.4680642,,,,,,,,,,,,,, -58200,2.0659833,1.61793,,,,,,,,,,,,,, -58300,1.8681378,1.461103,,,,,,,,,,,,,, -58400,1.9919143,1.524148,,,,,,,,,,,,,, -58500,2.1289294,1.4933616,,,,,,,,,,,,,, -58600,2.0725975,1.5751462,,,,,,,,,,,,,, -58700,1.7776253,1.6427312,,,,,,,,,,,,,, -58800,1.9424839,1.6180433,,,,,,,,,,,,,, -58900,1.8017801,1.4858161,,,,,,,,,,,,,, -59000,1.8783988,1.6473677,,,,,,,,,,,,,, -59100,1.7941014,1.546119,,,,,,,,,,,,,, -59175,,,0.7238121628761292,1.0625795125961304,0.6496399641036987,1.435991883277893,50000.0,0.5190000534057617,2.1798503398895264,10000.0,19927.296014785767,20643.969654798508,19927.296014785767,713.406112909317,1.3524727821350098,0.0 -59200,2.3041787,1.5929805,,,,,,,,,,,,,, -59300,1.8007852,1.4787352,,,,,,,,,,,,,, -59400,1.846621,1.541946,,,,,,,,,,,,,, -59500,1.8187732,1.4665836,,,,,,,,,,,,,, -59600,2.280161,1.6391335,,,,,,,,,,,,,, -59700,1.7944559,1.5806713,,,,,,,,,,,,,, -59800,2.140723,1.5535498,,,,,,,,,,,,,, -59900,1.8034421,1.5594634,,,,,,,,,,,,,, -60000,1.935454,1.6432397,,,,,,,,,,,,,, -60100,2.177756,1.6326746,,,,,,,,,,,,,, -60200,2.135015,1.6334043,,,,,,,,,,,,,, -60300,1.955721,1.5197705,,,,,,,,,,,,,, -60400,1.8731353,1.5224905,,,,,,,,,,,,,, -60500,2.1236384,1.5294147,,,,,,,,,,,,,, -60600,2.0718882,1.4481754,,,,,,,,,,,,,, -60694,,,0.7162587642669678,1.1006048917770386,0.6487999558448792,1.4462645053863523,50000.0,0.5267000198364258,2.1441636085510254,10000.0,20437.448969364166,21171.929075479507,20437.448969364166,731.1252071857452,1.3902521133422852,0.0 -60700,1.7568313,1.5478528,,,,,,,,,,,,,, -60800,2.1409514,1.7181782,,,,,,,,,,,,,, -60900,1.8553076,1.5383765,,,,,,,,,,,,,, -61000,2.0032086,1.5105519,,,,,,,,,,,,,, -61100,1.9750217,1.5151124,,,,,,,,,,,,,, -61200,2.1521246,1.5358086,,,,,,,,,,,,,, -61300,1.771805,1.4592764,,,,,,,,,,,,,, -61400,1.9167467,1.6237655,,,,,,,,,,,,,, -61500,2.1287894,1.7067051,,,,,,,,,,,,,, -61600,1.9920355,1.4917054,,,,,,,,,,,,,, -61700,2.1277988,1.6442151,,,,,,,,,,,,,, -61800,2.1846788,1.6145672,,,,,,,,,,,,,, -61900,2.0438852,1.6863816,,,,,,,,,,,,,, -62000,2.096682,1.6029913,,,,,,,,,,,,,, -62100,2.0214434,1.5188599,,,,,,,,,,,,,, -62200,2.0360823,1.6702917,,,,,,,,,,,,,, -62213,,,0.7052375674247742,1.142183542251587,0.6395800113677979,1.4831955432891846,50000.0,0.511900007724762,2.21729040145874,10000.0,20947.50677704811,21699.70861840248,20947.50677704811,748.7553527355194,1.4329195022583008,0.0 -62300,1.8716029,1.5230052,,,,,,,,,,,,,, -62400,1.9092687,1.5547495,,,,,,,,,,,,,, -62500,1.8304931,1.5421134,,,,,,,,,,,,,, -62600,2.0530431,1.5391624,,,,,,,,,,,,,, -62700,2.1855962,1.6770898,,,,,,,,,,,,,, -62800,1.8911501,1.5706244,,,,,,,,,,,,,, -62900,1.8697497,1.4810208,,,,,,,,,,,,,, -63000,1.9109521,1.4685698,,,,,,,,,,,,,, -63100,2.0953474,1.6992121,,,,,,,,,,,,,, -63200,1.8026153,1.5138524,,,,,,,,,,,,,, -63300,1.9341018,1.5335302,,,,,,,,,,,,,, -63400,1.9636191,1.5411783,,,,,,,,,,,,,, -63500,1.9349511,1.4578048,,,,,,,,,,,,,, -63600,2.2474513,1.6667106,,,,,,,,,,,,,, -63700,1.9840763,1.4511299,,,,,,,,,,,,,, -63731,,,0.7198262214660645,1.084575653076172,0.657039999961853,1.4061697721481323,50000.0,0.5250000357627869,2.16992735862732,10000.0,21457.440640211105,22227.43039250374,21457.440640211105,766.4331395626068,1.493783473968506,0.0 -63800,1.984464,1.5607946,,,,,,,,,,,,,, -63900,2.0674965,1.4696283,,,,,,,,,,,,,, -64000,1.7573748,1.6158154,,,,,,,,,,,,,, -64100,2.0280583,1.4481944,,,,,,,,,,,,,, -64200,1.8030365,1.5208136,,,,,,,,,,,,,, -64300,1.9931782,1.5873392,,,,,,,,,,,,,, -64400,1.8734094,1.5483068,,,,,,,,,,,,,, -64500,2.230864,1.567487,,,,,,,,,,,,,, -64600,1.992608,1.4290572,,,,,,,,,,,,,, -64700,1.8684262,1.5223167,,,,,,,,,,,,,, -64800,2.083523,1.5207245,,,,,,,,,,,,,, -64900,2.0561385,1.646748,,,,,,,,,,,,,, -65000,1.8306298,1.5049347,,,,,,,,,,,,,, -65100,2.1361318,1.5783383,,,,,,,,,,,,,, -65200,1.8502642,1.5477066,,,,,,,,,,,,,, -65249,,,0.732421875,1.026896834373474,0.6575999855995178,1.399552345275879,50000.0,0.5270000100135803,2.1229050159454346,10000.0,21967.380613327023,22755.134654521946,21967.380613327023,784.1061565876007,1.5358808040618896,0.0 -65300,1.9202161,1.5455687,,,,,,,,,,,,,, -65400,1.8716741,1.4964063,,,,,,,,,,,,,, -65500,1.9171506,1.572365,,,,,,,,,,,,,, -65600,1.9290283,1.5506388,,,,,,,,,,,,,, -65700,2.072381,1.4236329,,,,,,,,,,,,,, -65800,2.2477553,1.5541378,,,,,,,,,,,,,, -65900,2.1130507,1.5762985,,,,,,,,,,,,,, -66000,1.9314095,1.5236487,,,,,,,,,,,,,, -66100,2.0308821,1.5128875,,,,,,,,,,,,,, -66200,1.995945,1.4939573,,,,,,,,,,,,,, -66300,1.8104849,1.577023,,,,,,,,,,,,,, -66400,1.9831151,1.6309975,,,,,,,,,,,,,, -66500,2.0220568,1.5490488,,,,,,,,,,,,,, -66600,2.0369105,1.4837229,,,,,,,,,,,,,, -66700,2.0631745,1.6482155,,,,,,,,,,,,,, -66768,,,0.7518733739852905,0.9370354413986206,0.6610400080680847,1.3839858770370483,50000.0,0.5331000089645386,2.115618705749512,10000.0,22477.35139322281,23283.116079568863,22477.35139322281,802.0245227813721,1.5785470008850098,0.0 -66800,2.1544344,1.5559448,,,,,,,,,,,,,, -66900,1.8060113,1.52725,,,,,,,,,,,,,, -67000,2.2240357,1.5754514,,,,,,,,,,,,,, -67100,1.8227533,1.5421036,,,,,,,,,,,,,, -67200,2.1517274,1.5681121,,,,,,,,,,,,,, -67300,1.8900943,1.544702,,,,,,,,,,,,,, -67400,2.069358,1.5674406,,,,,,,,,,,,,, -67500,1.8678012,1.5026491,,,,,,,,,,,,,, -67600,1.9609331,1.5227212,,,,,,,,,,,,,, -67700,1.9903563,1.4838951,,,,,,,,,,,,,, -67800,2.1093025,1.560178,,,,,,,,,,,,,, -67900,2.0724485,1.4287548,,,,,,,,,,,,,, -68000,2.2120795,1.6264337,,,,,,,,,,,,,, -68100,2.1174877,1.5597094,,,,,,,,,,,,,, -68200,1.885245,1.4724398,,,,,,,,,,,,,, -68286,,,0.7399752736091614,0.9882227778434752,0.6643999814987183,1.3710088729858398,50000.0,0.5396000146865845,2.088154554367065,10000.0,22987.300876379013,23810.711325645447,22987.300876379013,819.579030752182,1.6207971572875977,0.0 -68300,2.069899,1.5481979,,,,,,,,,,,,,, -68400,2.05246,1.5638114,,,,,,,,,,,,,, -68500,2.1203048,1.6800866,,,,,,,,,,,,,, -68600,1.7959915,1.4738885,,,,,,,,,,,,,, -68700,1.8715858,1.5014588,,,,,,,,,,,,,, -68800,2.0400138,1.4787041,,,,,,,,,,,,,, -68900,2.0392332,1.5877405,,,,,,,,,,,,,, -69000,2.1263871,1.5544251,,,,,,,,,,,,,, -69100,2.2040794,1.5612104,,,,,,,,,,,,,, -69200,1.9414787,1.5534391,,,,,,,,,,,,,, -69300,1.9790131,1.5174993,,,,,,,,,,,,,, -69400,1.930435,1.4738429,,,,,,,,,,,,,, -69500,2.0946665,1.533291,,,,,,,,,,,,,, -69600,1.9758123,1.4376581,,,,,,,,,,,,,, -69700,2.0326035,1.4924397,,,,,,,,,,,,,, -69800,2.2176478,1.5215175,,,,,,,,,,,,,, -69804,,,0.736726701259613,1.0046941041946411,0.661359965801239,1.3887338638305664,50000.0,0.5301000475883484,2.137101888656616,10000.0,23497.229088783264,24338.66460514069,23497.229088783264,837.5111167430878,1.6648674011230469,0.0 -69900,1.8290431,1.5047715,,,,,,,,,,,,,, -70000,2.183005,1.5170197,,,,,,,,,,,,,, -70100,2.0741289,1.6206248,,,,,,,,,,,,,, -70200,2.2224624,1.5325521,,,,,,,,,,,,,, -70300,1.9889599,1.4068857,,,,,,,,,,,,,, -70400,2.0629551,1.40295,,,,,,,,,,,,,, -70500,2.062053,1.6028824,,,,,,,,,,,,,, -70600,2.0403242,1.5363997,,,,,,,,,,,,,, -70700,1.9259266,1.556135,,,,,,,,,,,,,, -70800,1.9269471,1.5249207,,,,,,,,,,,,,, -70900,1.9014947,1.518314,,,,,,,,,,,,,, -71000,1.958759,1.484026,,,,,,,,,,,,,, -71100,2.1238675,1.5137507,,,,,,,,,,,,,, -71200,2.098538,1.6354141,,,,,,,,,,,,,, -71300,2.1584206,1.5153667,,,,,,,,,,,,,, -71322,,,0.7341557741165161,1.0238789319992063,0.66211998462677,1.3770734071731567,50000.0,0.5324000120162964,2.085169553756714,10000.0,24007.27220749855,24866.694973945618,24007.27220749855,855.4055006504059,1.7085967063903809,0.0 -71400,2.1451423,1.5300889,,,,,,,,,,,,,, -71500,2.0903754,1.4582952,,,,,,,,,,,,,, -71600,1.9102921,1.5192883,,,,,,,,,,,,,, -71700,2.3106542,1.6003858,,,,,,,,,,,,,, -71800,1.8692858,1.4940383,,,,,,,,,,,,,, -71900,2.1448355,1.6096117,,,,,,,,,,,,,, -72000,1.8497827,1.3750274,,,,,,,,,,,,,, -72100,2.0101979,1.5544835,,,,,,,,,,,,,, -72200,2.030291,1.5937804,,,,,,,,,,,,,, -72300,1.904131,1.5232873,,,,,,,,,,,,,, -72400,2.0565398,1.4922351,,,,,,,,,,,,,, -72500,2.1233711,1.4700196,,,,,,,,,,,,,, -72600,2.053354,1.5749398,,,,,,,,,,,,,, -72700,2.0296354,1.5978397,,,,,,,,,,,,,, -72800,1.940137,1.458493,,,,,,,,,,,,,, -72840,,,0.7183912396430969,1.0752381086349487,0.6584999561309814,1.404468059539795,50000.0,0.525600016117096,2.1446402072906494,10000.0,24517.43322777748,25394.93424105644,24517.43322777748,873.3924815654755,1.7509582042694092,0.0 -72900,2.0281634,1.5762699,,,,,,,,,,,,,, -73000,1.9807961,1.4312296,,,,,,,,,,,,,, -73100,2.1195323,1.4607813,,,,,,,,,,,,,, -73200,1.9711037,1.4644581,,,,,,,,,,,,,, -73300,2.0781724,1.4623321,,,,,,,,,,,,,, -73400,2.2420442,1.3405783,,,,,,,,,,,,,, -73500,1.9411547,1.4265541,,,,,,,,,,,,,, -73600,2.1406696,1.4530782,,,,,,,,,,,,,, -73700,2.3296583,1.5791118,,,,,,,,,,,,,, -73800,2.2056777,1.4629,,,,,,,,,,,,,, -73900,2.1512568,1.5087998,,,,,,,,,,,,,, -74000,1.9768564,1.4519997,,,,,,,,,,,,,, -74100,1.9644194,1.4588182,,,,,,,,,,,,,, -74200,2.1269116,1.5031238,,,,,,,,,,,,,, -74300,2.0644333,1.5408728,,,,,,,,,,,,,, -74359,,,0.7762077450752258,0.8492001891136169,0.660539984703064,1.3998095989227295,50000.0,0.531000018119812,2.1150431632995605,10000.0,25027.50243353844,25922.918542146683,25027.50243353844,891.2145557403564,1.7948503494262695,0.0 -74400,2.123879,1.4478726,,,,,,,,,,,,,, -74500,2.0684388,1.4742182,,,,,,,,,,,,,, -74600,2.0926936,1.5398859,,,,,,,,,,,,,, -74700,2.1295238,1.4463519,,,,,,,,,,,,,, -74800,2.158454,1.5360113,,,,,,,,,,,,,, -74900,2.0443046,1.5620104,,,,,,,,,,,,,, -75000,2.2523623,1.4309807,,,,,,,,,,,,,, -75100,2.1660273,1.4359543,,,,,,,,,,,,,, -75200,2.0139194,1.4346602,,,,,,,,,,,,,, -75300,2.1048434,1.4928128,,,,,,,,,,,,,, -75400,1.9249827,1.4193803,,,,,,,,,,,,,, -75500,2.3861382,1.5265772,,,,,,,,,,,,,, -75600,2.1421459,1.4485639,,,,,,,,,,,,,, -75700,2.2954051,1.54925,,,,,,,,,,,,,, -75800,1.9644865,1.4431105,,,,,,,,,,,,,, -75877,,,0.7547233700752258,0.92606920003891,0.668940007686615,1.3541022539138794,50000.0,0.5397000312805176,2.06687068939209,10000.0,25537.48215198517,26450.67085123062,25537.48215198517,908.893723487854,1.839226007461548,0.0 -75900,2.1546855,1.4595854,,,,,,,,,,,,,, -76000,2.1409905,1.3936999,,,,,,,,,,,,,, -76100,2.3308887,1.5687002,,,,,,,,,,,,,, -76200,2.1112282,1.4928405,,,,,,,,,,,,,, -76300,2.2368786,1.4717543,,,,,,,,,,,,,, -76400,2.3932037,1.4377546,,,,,,,,,,,,,, -76500,2.1821673,1.6088636,,,,,,,,,,,,,, -76600,2.130633,1.4645568,,,,,,,,,,,,,, -76700,2.0666173,1.4809322,,,,,,,,,,,,,, -76800,2.1956608,1.4129852,,,,,,,,,,,,,, -76900,2.086215,1.531825,,,,,,,,,,,,,, -77000,2.1039133,1.4984179,,,,,,,,,,,,,, -77100,2.1202486,1.4107716,,,,,,,,,,,,,, -77200,2.2112648,1.4413723,,,,,,,,,,,,,, -77300,2.0533357,1.501067,,,,,,,,,,,,,, -77396,,,0.7463727593421936,0.9656037092208862,0.6665999889373779,1.3637856245040894,50000.0,0.5337000489234924,2.0874693393707275,10000.0,26047.61165094376,26978.57283329964,26047.61165094376,926.574179649353,1.88204026222229,0.0 -77400,2.1642807,1.4656287,,,,,,,,,,,,,, -77500,2.2540116,1.5367733,,,,,,,,,,,,,, -77600,2.1996295,1.4493674,,,,,,,,,,,,,, -77700,2.0885208,1.4707735,,,,,,,,,,,,,, -77800,2.149648,1.4873633,,,,,,,,,,,,,, -77900,2.0358558,1.5259607,,,,,,,,,,,,,, -78000,2.3764427,1.4671068,,,,,,,,,,,,,, -78100,2.2706673,1.4360584,,,,,,,,,,,,,, -78200,1.9438668,1.3233147,,,,,,,,,,,,,, -78300,2.2672234,1.4992874,,,,,,,,,,,,,, -78400,2.3095932,1.4565955,,,,,,,,,,,,,, -78500,2.42149,1.5029596,,,,,,,,,,,,,, -78600,2.1280518,1.466517,,,,,,,,,,,,,, -78700,2.1548228,1.5342534,,,,,,,,,,,,,, -78800,2.5037565,1.4883724,,,,,,,,,,,,,, -78900,2.1472938,1.4173455,,,,,,,,,,,,,, -78915,,,0.724609375,1.0671521425247192,0.6492800116539001,1.4398260116577148,50000.0,0.525700032711029,2.150360107421875,10000.0,26557.65236115456,27506.36722517013,26557.65236115456,944.2308526039124,1.929426908493042,0.0 -79000,2.2496636,1.534208,,,,,,,,,,,,,, -79100,2.041482,1.4331374,,,,,,,,,,,,,, -79200,2.179174,1.5460566,,,,,,,,,,,,,, -79300,2.3789053,1.5099652,,,,,,,,,,,,,, -79400,2.1925628,1.5332729,,,,,,,,,,,,,, -79500,2.192154,1.5851939,,,,,,,,,,,,,, -79600,2.3120754,1.4742892,,,,,,,,,,,,,, -79700,2.0203543,1.3330144,,,,,,,,,,,,,, -79800,2.2696784,1.4738274,,,,,,,,,,,,,, -79900,2.1633294,1.3956332,,,,,,,,,,,,,, -80000,2.2574134,1.5090551,,,,,,,,,,,,,, -80100,2.0275319,1.4724236,,,,,,,,,,,,,, -80200,2.253828,1.5143837,,,,,,,,,,,,,, -80300,2.1826782,1.5260532,,,,,,,,,,,,,, -80400,2.0627153,1.5286181,,,,,,,,,,,,,, -80433,,,0.7413902878761292,0.985304832458496,0.6678000092506409,1.350680589675903,50000.0,0.5360000133514404,2.104771375656128,10000.0,27067.61016345024,28034.15910959244,27067.61016345024,961.969643831253,1.9752976894378664,0.0 -80500,2.1646118,1.5878878,,,,,,,,,,,,,, -80600,2.3889825,1.4446945,,,,,,,,,,,,,, -80700,2.0667562,1.5118899,,,,,,,,,,,,,, -80800,2.1530163,1.5814357,,,,,,,,,,,,,, -80900,2.2921908,1.4494437,,,,,,,,,,,,,, -81000,2.0307837,1.3678632,,,,,,,,,,,,,, -81100,2.3400455,1.4564352,,,,,,,,,,,,,, -81200,2.176955,1.3882798,,,,,,,,,,,,,, -81300,2.260977,1.5270346,,,,,,,,,,,,,, -81400,2.5728269,1.4136811,,,,,,,,,,,,,, -81500,2.2112684,1.4161147,,,,,,,,,,,,,, -81600,2.265188,1.4124506,,,,,,,,,,,,,, -81700,2.321598,1.3759305,,,,,,,,,,,,,, -81800,2.08254,1.5236914,,,,,,,,,,,,,, -81900,2.1204937,1.4203774,,,,,,,,,,,,,, -81951,,,0.7299306392669678,1.0427826642990112,0.6603999733924866,1.401990294456482,50000.0,0.5307000279426575,2.1657660007476807,10000.0,27577.544088840485,28561.997648715973,27577.544088840485,979.7718670368196,2.0286335945129395,0.0 -82000,2.1607769,1.4261155,,,,,,,,,,,,,, -82100,2.4125009,1.38486,,,,,,,,,,,,,, -82200,2.1966312,1.4278672,,,,,,,,,,,,,, -82300,2.1732662,1.50922,,,,,,,,,,,,,, -82400,2.239318,1.5645839,,,,,,,,,,,,,, -82500,2.2584856,1.402373,,,,,,,,,,,,,, -82600,2.2453034,1.6237111,,,,,,,,,,,,,, -82700,2.226736,1.4752741,,,,,,,,,,,,,, -82800,2.1497703,1.4324088,,,,,,,,,,,,,, -82900,2.3639612,1.3921593,,,,,,,,,,,,,, -83000,2.2215385,1.3821985,,,,,,,,,,,,,, -83100,2.0739985,1.3387939,,,,,,,,,,,,,, -83200,2.2072556,1.5070119,,,,,,,,,,,,,, -83300,2.2878497,1.400567,,,,,,,,,,,,,, -83400,2.5613165,1.4884052,,,,,,,,,,,,,, -83470,,,0.7867506146430969,0.7846525311470032,0.6734799742698669,1.3287465572357178,50000.0,0.5425000190734863,2.0561273097991943,10000.0,28087.777733802795,29090.011644124985,28087.777733802795,997.4572043418884,2.074682235717773,0.0 -83500,2.2211099,1.5404507,,,,,,,,,,,,,, -83600,2.167523,1.5014894,,,,,,,,,,,,,, -83700,2.13839,1.3768729,,,,,,,,,,,,,, -83800,2.1463783,1.5183386,,,,,,,,,,,,,, -83900,2.2333822,1.5079726,,,,,,,,,,,,,, -84000,2.181637,1.4191432,,,,,,,,,,,,,, -84100,2.1693814,1.4213157,,,,,,,,,,,,,, -84200,2.2431333,1.4733405,,,,,,,,,,,,,, -84300,2.16897,1.413113,,,,,,,,,,,,,, -84400,2.2862918,1.447336,,,,,,,,,,,,,, -84500,1.9738706,1.4555146,,,,,,,,,,,,,, -84600,2.2594876,1.4841881,,,,,,,,,,,,,, -84700,2.208607,1.4441731,,,,,,,,,,,,,, -84800,2.2952106,1.5714464,,,,,,,,,,,,,, -84900,2.4556973,1.5299311,,,,,,,,,,,,,, -84988,,,0.7635124325752258,0.8853896856307983,0.6748999953269958,1.3213706016540527,50000.0,0.5521000027656555,2.043912649154663,10000.0,28597.72906398773,29618.0389854908,28597.72906398773,1015.4390184879304,2.119399070739746,0.0 -85000,2.2289376,1.4665426,,,,,,,,,,,,,, -85100,2.146265,1.4336283,,,,,,,,,,,,,, -85200,2.2940047,1.4039516,,,,,,,,,,,,,, -85300,2.1920073,1.4100658,,,,,,,,,,,,,, -85400,2.228873,1.4120102,,,,,,,,,,,,,, -85500,2.284901,1.4842517,,,,,,,,,,,,,, -85600,2.2994204,1.3792368,,,,,,,,,,,,,, -85700,2.2836494,1.4996773,,,,,,,,,,,,,, -85800,2.3339753,1.4572705,,,,,,,,,,,,,, -85900,2.1391532,1.3857529,,,,,,,,,,,,,, -86000,2.0296266,1.3752023,,,,,,,,,,,,,, -86100,2.1657398,1.4100112,,,,,,,,,,,,,, -86200,2.2652566,1.4129102,,,,,,,,,,,,,, -86300,2.5990303,1.4737661,,,,,,,,,,,,,, -86400,2.3647218,1.4579495,,,,,,,,,,,,,, -86500,2.4207015,1.5280285,,,,,,,,,,,,,, -86506,,,0.7604631781578064,0.9026753902435304,0.6757599711418152,1.3243170976638794,50000.0,0.54830002784729,2.0270941257476807,10000.0,29107.663843154907,30145.59640312195,29107.663843154907,1032.9682595729828,2.1635024547576904,0.0 -86600,2.1826603,1.4454408,,,,,,,,,,,,,, -86700,2.364616,1.4209362,,,,,,,,,,,,,, -86800,2.217778,1.5340718,,,,,,,,,,,,,, -86900,2.0039744,1.4021819,,,,,,,,,,,,,, -87000,2.2379975,1.4651868,,,,,,,,,,,,,, -87100,2.1424646,1.3921143,,,,,,,,,,,,,, -87200,2.4071398,1.366274,,,,,,,,,,,,,, -87300,2.3956814,1.5328579,,,,,,,,,,,,,, -87400,2.3597143,1.4536778,,,,,,,,,,,,,, -87500,2.3130572,1.4720504,,,,,,,,,,,,,, -87600,2.2292664,1.4910722,,,,,,,,,,,,,, -87700,2.448085,1.3628805,,,,,,,,,,,,,, -87800,2.319602,1.3749593,,,,,,,,,,,,,, -87900,2.4806948,1.490741,,,,,,,,,,,,,, -88000,2.175437,1.423365,,,,,,,,,,,,,, -88024,,,0.753926157951355,0.9288226366043092,0.6761999726295471,1.318540334701538,50000.0,0.541100025177002,2.064603328704834,10000.0,29617.69456934929,30673.468591213223,29617.69456934929,1050.7145743370056,2.2094473838806152,0.0 -88100,2.307543,1.3810494,,,,,,,,,,,,,, -88200,2.3297362,1.4317737,,,,,,,,,,,,,, -88300,2.2312305,1.4608753,,,,,,,,,,,,,, -88400,2.2398763,1.4081299,,,,,,,,,,,,,, -88500,2.4439628,1.4047413,,,,,,,,,,,,,, -88600,2.3832214,1.373968,,,,,,,,,,,,,, -88700,2.1463625,1.461516,,,,,,,,,,,,,, -88800,2.457516,1.2944802,,,,,,,,,,,,,, -88900,2.201662,1.3243771,,,,,,,,,,,,,, -89000,2.2510066,1.4168892,,,,,,,,,,,,,, -89100,2.3470078,1.4031564,,,,,,,,,,,,,, -89200,2.166659,1.414113,,,,,,,,,,,,,, -89300,2.540328,1.4145534,,,,,,,,,,,,,, -89400,2.2256825,1.2934334,,,,,,,,,,,,,, -89500,2.4054673,1.2516028,,,,,,,,,,,,,, -89542,,,0.7547233700752258,0.921085238456726,0.6802200078964233,1.296499252319336,50000.0,0.5512000322341919,2.026613473892212,10000.0,30127.660895824432,31201.48393774033,30127.660895824432,1068.669147491455,2.254591703414917,0.0 -89600,2.419553,1.4420922,,,,,,,,,,,,,, -89700,2.3822837,1.4913383,,,,,,,,,,,,,, -89800,2.3049052,1.4589527,,,,,,,,,,,,,, -89900,2.3775928,1.3359642,,,,,,,,,,,,,, -90000,2.2237482,1.3495519,,,,,,,,,,,,,, -90100,2.4642303,1.4326022,,,,,,,,,,,,,, -90200,2.2444305,1.3931745,,,,,,,,,,,,,, -90300,2.4016826,1.4295266,,,,,,,,,,,,,, -90400,2.4881506,1.4467914,,,,,,,,,,,,,, -90500,2.3345726,1.4512061,,,,,,,,,,,,,, -90600,2.385635,1.3883207,,,,,,,,,,,,,, -90700,2.2874887,1.3055819,,,,,,,,,,,,,, -90800,2.4509857,1.4111283,,,,,,,,,,,,,, -90900,2.2167182,1.3088789,,,,,,,,,,,,,, -91000,2.2690072,1.4168818,,,,,,,,,,,,,, -91061,,,0.7591477632522583,0.9124574065208436,0.6842399835586548,1.2905629873275757,50000.0,0.5562000274658203,1.9942518472671509,10000.0,30637.73820257187,31729.2611079216,30637.73820257187,1086.2701907157898,2.303988218307495,0.0 -91100,2.143249,1.4330297,,,,,,,,,,,,,, -91200,2.4238598,1.4342977,,,,,,,,,,,,,, -91300,2.5078614,1.450523,,,,,,,,,,,,,, -91400,2.226236,1.2845365,,,,,,,,,,,,,, -91500,2.349547,1.5383077,,,,,,,,,,,,,, -91600,2.44927,1.4133364,,,,,,,,,,,,,, -91700,2.339528,1.4072866,,,,,,,,,,,,,, -91800,2.4276226,1.444509,,,,,,,,,,,,,, -91900,2.6101449,1.4344833,,,,,,,,,,,,,, -92000,2.4618928,1.4083608,,,,,,,,,,,,,, -92100,2.2723234,1.4441062,,,,,,,,,,,,,, -92200,2.171644,1.3844361,,,,,,,,,,,,,, -92300,2.2947993,1.365236,,,,,,,,,,,,,, -92400,2.3840408,1.442111,,,,,,,,,,,,,, -92500,2.3872561,1.4002943,,,,,,,,,,,,,, -92579,,,0.7926697731018066,0.767687976360321,0.6845600008964539,1.2874754667282104,50000.0,0.5516000390052795,2.0014238357543945,10000.0,31147.86074543,32257.387805223465,31147.86074543,1104.1770284175873,2.3518271446228027,0.0 -92600,2.3658836,1.4110818,,,,,,,,,,,,,, -92700,2.6081834,1.4913207,,,,,,,,,,,,,, -92800,2.3301716,1.4480891,,,,,,,,,,,,,, -92900,2.510739,1.3896676,,,,,,,,,,,,,, -93000,2.2504978,1.3828862,,,,,,,,,,,,,, -93100,2.1906335,1.3789043,,,,,,,,,,,,,, -93200,2.3372529,1.3166592,,,,,,,,,,,,,, -93300,2.2697392,1.3890957,,,,,,,,,,,,,, -93400,2.420518,1.2891158,,,,,,,,,,,,,, -93500,2.4333377,1.3903279,,,,,,,,,,,,,, -93600,2.4852726,1.4091116,,,,,,,,,,,,,, -93700,2.5570652,1.4148215,,,,,,,,,,,,,, -93800,2.2492237,1.3459682,,,,,,,,,,,,,, -93900,2.1827545,1.2937692,,,,,,,,,,,,,, -94000,2.2496212,1.4545041,,,,,,,,,,,,,, -94096,,,0.7694913744926453,0.8633120656013489,0.6793199777603149,1.320208191871643,50000.0,0.553600013256073,2.048421621322632,10000.0,31657.77154326439,32785.05307340622,31657.77154326439,1121.8362724781036,2.3979732990264893,0.0 -94100,2.404697,1.419958,,,,,,,,,,,,,, -94200,2.2900078,1.3427694,,,,,,,,,,,,,, -94300,2.419888,1.4068955,,,,,,,,,,,,,, -94400,2.3525918,1.3859562,,,,,,,,,,,,,, -94500,2.193801,1.3692931,,,,,,,,,,,,,, -94600,2.4867325,1.3396909,,,,,,,,,,,,,, -94700,2.5861971,1.3506052,,,,,,,,,,,,,, -94800,2.6970417,1.3302212,,,,,,,,,,,,,, -94900,2.389108,1.4725927,,,,,,,,,,,,,, -95000,2.3184247,1.4142854,,,,,,,,,,,,,, -95100,2.4377666,1.4311253,,,,,,,,,,,,,, -95200,2.5823014,1.4481125,,,,,,,,,,,,,, -95300,2.333721,1.31189,,,,,,,,,,,,,, -95400,2.4357839,1.454011,,,,,,,,,,,,,, -95500,2.35961,1.4134595,,,,,,,,,,,,,, -95600,2.4267457,1.4651885,,,,,,,,,,,,,, -95614,,,0.7730388641357422,0.8482105135917664,0.6853799819946289,1.288509726524353,50000.0,0.5561000108718872,2.018871784210205,10000.0,32167.696942329407,33312.85690236092,32167.696942329407,1139.6173412799835,2.44614315032959,0.0 -95700,2.3820379,1.3871804,,,,,,,,,,,,,, -95800,2.459399,1.3812219,,,,,,,,,,,,,, -95900,2.238978,1.3953271,,,,,,,,,,,,,, -96000,2.3058233,1.3343326,,,,,,,,,,,,,, -96100,2.3100107,1.3925253,,,,,,,,,,,,,, -96200,2.5606341,1.4522108,,,,,,,,,,,,,, -96300,2.535591,1.4625165,,,,,,,,,,,,,, -96400,2.2295957,1.2749678,,,,,,,,,,,,,, -96500,2.6131036,1.4612334,,,,,,,,,,,,,, -96600,2.4329417,1.4596997,,,,,,,,,,,,,, -96700,2.5261998,1.3201641,,,,,,,,,,,,,, -96800,2.321192,1.3080387,,,,,,,,,,,,,, -96900,2.4680693,1.3819541,,,,,,,,,,,,,, -97000,2.5832722,1.4505997,,,,,,,,,,,,,, -97100,2.3461306,1.3783042,,,,,,,,,,,,,, -97132,,,0.751375138759613,0.9182205200195312,0.6756399869918823,1.318800449371338,50000.0,0.5425000190734863,2.0723519325256348,10000.0,32677.73971581459,33841.79132437706,32677.73971581459,1158.4141011238098,2.492197275161743,0.0 -97200,2.4137967,1.4329093,,,,,,,,,,,,,, -97300,2.521053,1.368022,,,,,,,,,,,,,, -97400,2.436407,1.29884,,,,,,,,,,,,,, -97500,2.2228496,1.2634976,,,,,,,,,,,,,, -97600,2.3129306,1.4053422,,,,,,,,,,,,,, -97700,2.4944022,1.4295136,,,,,,,,,,,,,, -97800,2.6416595,1.3695444,,,,,,,,,,,,,, -97900,2.3109753,1.3234295,,,,,,,,,,,,,, -98000,2.3448367,1.327305,,,,,,,,,,,,,, -98100,2.388745,1.3039756,,,,,,,,,,,,,, -98200,3.267564,1.3088033,,,,,,,,,,,,,, -98300,2.388425,1.3636311,,,,,,,,,,,,,, -98400,2.3696268,1.3220786,,,,,,,,,,,,,, -98500,2.411223,1.3461504,,,,,,,,,,,,,, -98600,2.3044617,1.3182602,,,,,,,,,,,,,, -98651,,,0.7703284025192261,0.8573641180992126,0.6903600096702576,1.2503340244293213,50000.0,0.5623000264167786,1.983777403831482,10000.0,33187.793897628784,34369.720563173294,33187.793897628784,1176.1916980743408,2.540599822998047,0.0 -98700,2.3701415,1.2633135,,,,,,,,,,,,,, -98800,2.554069,1.3651223,,,,,,,,,,,,,, -98900,2.3808522,1.3482674,,,,,,,,,,,,,, -99000,2.5457053,1.3441997,,,,,,,,,,,,,, -99100,2.328934,1.3770884,,,,,,,,,,,,,, -99200,2.2825556,1.3952541,,,,,,,,,,,,,, -99300,2.878551,1.3157268,,,,,,,,,,,,,, -99400,2.5632067,1.33356,,,,,,,,,,,,,, -99500,2.5090036,1.4908775,,,,,,,,,,,,,, -99600,2.3261132,1.3565187,,,,,,,,,,,,,, -99700,2.441928,1.4273283,,,,,,,,,,,,,, -99800,2.2738552,1.2157257,,,,,,,,,,,,,, -99900,2.311896,1.3163638,,,,,,,,,,,,,, -100000,2.4290385,1.309575,,,,,,,,,,,,,, -100100,2.4145858,1.3387384,,,,,,,,,,,,,, -100169,,,0.7879264950752258,0.7837117314338684,0.6886000037193298,1.271813988685608,50000.0,0.5621000528335571,1.986859440803528,10000.0,33697.73100566864,34897.45811223984,33697.73100566864,1193.8929476737976,2.5904340744018555,0.0 -100200,2.5483422,1.3184459,,,,,,,,,,,,,, -100300,2.388889,1.3854096,,,,,,,,,,,,,, -100400,2.4837844,1.38347,,,,,,,,,,,,,, -100500,2.2770581,1.3374872,,,,,,,,,,,,,, -100600,2.6284733,1.2583524,,,,,,,,,,,,,, -100700,2.3625226,1.272754,,,,,,,,,,,,,, -100800,2.3778524,1.2798895,,,,,,,,,,,,,, -100900,2.5109744,1.3472676,,,,,,,,,,,,,, -101000,2.4682574,1.2838984,,,,,,,,,,,,,, -101100,2.5444965,1.2274611,,,,,,,,,,,,,, -101200,2.4904501,1.2384176,,,,,,,,,,,,,, -101300,2.4150262,1.2901386,,,,,,,,,,,,,, -101400,2.5782592,1.3421319,,,,,,,,,,,,,, -101500,2.4919538,1.358412,,,,,,,,,,,,,, -101600,2.6594625,1.3030988,,,,,,,,,,,,,, -101687,,,0.7952805757522583,0.7596470713615417,0.6904799938201904,1.2509621381759644,50000.0,0.5557000041007996,1.995941519737244,10000.0,34207.68478536606,35425.44287323952,34207.68478536606,1211.8229558467865,2.642286777496338,0.0 -101700,2.5491292,1.383604,,,,,,,,,,,,,, -101800,2.6246548,1.4128424,,,,,,,,,,,,,, -101900,2.608169,1.4941139,,,,,,,,,,,,,, -102000,2.46081,1.2991552,,,,,,,,,,,,,, -102100,2.411296,1.3640096,,,,,,,,,,,,,, -102200,2.5698614,1.2942164,,,,,,,,,,,,,, -102300,2.4708016,1.3965929,,,,,,,,,,,,,, -102400,2.5958915,1.283557,,,,,,,,,,,,,, -102500,2.5542557,1.3772241,,,,,,,,,,,,,, -102600,2.4258394,1.3126643,,,,,,,,,,,,,, -102700,2.4456766,1.2653754,,,,,,,,,,,,,, -102800,2.853768,1.39182,,,,,,,,,,,,,, -102900,2.490349,1.3613073,,,,,,,,,,,,,, -103000,2.4135008,1.425635,,,,,,,,,,,,,, -103100,2.5140595,1.3722298,,,,,,,,,,,,,, -103200,2.988763,1.3606069,,,,,,,,,,,,,, -103206,,,0.7831433415412903,0.8068510293960571,0.6899600028991699,1.2710587978363037,50000.0,0.5639000535011292,2.0179336071014404,10000.0,34717.76279425621,35953.27993154526,34717.76279425621,1229.4830603599548,2.692025899887085,0.0 -103300,2.3894305,1.2850013,,,,,,,,,,,,,, -103400,2.5253565,1.3565984,,,,,,,,,,,,,, -103500,2.536957,1.2523934,,,,,,,,,,,,,, -103600,2.4791844,1.2968731,,,,,,,,,,,,,, -103700,2.6500652,1.23609,,,,,,,,,,,,,, -103800,2.532377,1.2774633,,,,,,,,,,,,,, -103900,2.5855644,1.2818526,,,,,,,,,,,,,, -104000,2.7327094,1.3697937,,,,,,,,,,,,,, -104100,2.5080025,1.3467618,,,,,,,,,,,,,, -104200,3.3673406,1.1964359,,,,,,,,,,,,,, -104300,2.9505477,1.3791904,,,,,,,,,,,,,, -104400,2.6570868,1.3557913,,,,,,,,,,,,,, -104500,2.5665896,1.249437,,,,,,,,,,,,,, -104600,2.701323,1.268945,,,,,,,,,,,,,, -104700,2.5984628,1.2822828,,,,,,,,,,,,,, -104726,,,0.7667809128761292,0.8686758279800415,0.6767399907112122,1.3189637660980225,50000.0,0.5525000095367432,2.056422472000122,10000.0,35227.90310502052,36481.48370862007,35227.90310502052,1247.4493174552915,2.7401788234710693,0.0 -104800,2.3797562,1.3105092,,,,,,,,,,,,,, -104900,2.411681,1.2870355,,,,,,,,,,,,,, -105000,2.4519758,1.2980727,,,,,,,,,,,,,, -105100,2.5945945,1.3510028,,,,,,,,,,,,,, -105200,2.4608696,1.2618665,,,,,,,,,,,,,, -105300,2.3218017,1.3201427,,,,,,,,,,,,,, -105400,2.811233,1.4056041,,,,,,,,,,,,,, -105500,2.5694978,1.3138475,,,,,,,,,,,,,, -105600,2.5038857,1.240166,,,,,,,,,,,,,, -105700,3.17336,1.3665369,,,,,,,,,,,,,, -105800,2.6396508,1.2721807,,,,,,,,,,,,,, -105900,2.4876704,1.2638371,,,,,,,,,,,,,, -106000,2.5457575,1.2471975,,,,,,,,,,,,,, -106100,2.75884,1.3484744,,,,,,,,,,,,,, -106200,2.7558038,1.2656487,,,,,,,,,,,,,, -106245,,,0.7792769074440002,0.8142668008804321,0.6949399709701538,1.2566287517547607,50000.0,0.5676000118255615,1.992361187934876,10000.0,35738.055804252625,37009.439423561096,35738.055804252625,1265.152621269226,2.7908036708831787,0.0 -106300,2.8407292,1.3540125,,,,,,,,,,,,,, -106400,2.5907335,1.3099357,,,,,,,,,,,,,, -106500,2.541218,1.3379304,,,,,,,,,,,,,, -106600,2.9072146,1.3320088,,,,,,,,,,,,,, -106700,2.3574834,1.1746207,,,,,,,,,,,,,, -106800,2.455966,1.2667317,,,,,,,,,,,,,, -106900,2.7632763,1.4051783,,,,,,,,,,,,,, -107000,2.3221328,1.2164714,,,,,,,,,,,,,, -107100,2.4358265,1.1994265,,,,,,,,,,,,,, -107200,2.5588489,1.2819908,,,,,,,,,,,,,, -107300,2.7845376,1.3519328,,,,,,,,,,,,,, -107400,2.7087617,1.3453153,,,,,,,,,,,,,, -107500,2.568517,1.2728195,,,,,,,,,,,,,, -107600,2.9068105,1.3663149,,,,,,,,,,,,,, -107700,2.4647717,1.3425592,,,,,,,,,,,,,, -107764,,,0.7824258208274841,0.7984607219696045,0.6949399709701538,1.2417380809783936,50000.0,0.5662000179290771,1.9687929153442385,10000.0,36248.185584783554,37537.60767388344,36248.185584783554,1283.092089176178,2.841114997863769,0.0 -107800,2.8264513,1.3725078,,,,,,,,,,,,,, -107900,2.7099671,1.3354659,,,,,,,,,,,,,, -108000,3.1120677,1.3945333,,,,,,,,,,,,,, -108100,2.728942,1.3868856,,,,,,,,,,,,,, -108200,2.5580494,1.3043082,,,,,,,,,,,,,, -108300,2.4376333,1.1460232,,,,,,,,,,,,,, -108400,2.625995,1.3275621,,,,,,,,,,,,,, -108500,2.5941608,1.2871164,,,,,,,,,,,,,, -108600,2.562912,1.2166858,,,,,,,,,,,,,, -108700,2.8259292,1.2812781,,,,,,,,,,,,,, -108800,2.5240073,1.3259029,,,,,,,,,,,,,, -108900,2.8195183,1.4459801,,,,,,,,,,,,,, -109000,2.6946492,1.26991,,,,,,,,,,,,,, -109100,2.5101018,1.2461501,,,,,,,,,,,,,, -109200,2.642542,1.2316431,,,,,,,,,,,,,, -109283,,,0.8298588991165161,0.6174313426017761,0.7004799842834473,1.232813596725464,50000.0,0.5707000494003296,1.9671730995178225,10000.0,36758.29444336891,38065.49232196808,36758.29444336891,1300.768966436386,2.890819549560547,0.0 -109300,2.7865903,1.3160512,,,,,,,,,,,,,, -109400,2.7224212,1.1985027,,,,,,,,,,,,,, -109500,2.5450704,1.2478791,,,,,,,,,,,,,, -109600,2.7318556,1.2810113,,,,,,,,,,,,,, -109700,2.7714446,1.3161465,,,,,,,,,,,,,, -109800,2.756103,1.2955114,,,,,,,,,,,,,, -109900,3.124198,1.4115319,,,,,,,,,,,,,, -110000,2.6106615,1.4026128,,,,,,,,,,,,,, -110100,2.435603,1.230008,,,,,,,,,,,,,, -110200,3.0671747,1.1518717,,,,,,,,,,,,,, -110300,2.8600998,1.3721995,,,,,,,,,,,,,, -110400,2.7922332,1.2594981,,,,,,,,,,,,,, -110500,2.533334,1.1925368,,,,,,,,,,,,,, -110600,2.9390974,1.4038757,,,,,,,,,,,,,, -110700,2.7909367,1.2242166,,,,,,,,,,,,,, -110800,2.4512355,1.1228623,,,,,,,,,,,,,, -110801,,,0.8005221486091614,0.7223421931266785,0.695580005645752,1.2395744323730469,50000.0,0.5663000345230103,1.978838562965393,10000.0,37268.264142751694,38593.53087067604,37268.264142751694,1318.7363233566284,2.9423089027404785,0.0 -110900,2.8332117,1.349357,,,,,,,,,,,,,, -111000,2.7618988,1.2644813,,,,,,,,,,,,,, -111100,2.6300566,1.2429239,,,,,,,,,,,,,, -111200,2.7190819,1.2272918,,,,,,,,,,,,,, -111300,2.5375788,1.2414532,,,,,,,,,,,,,, -111400,2.7350883,1.3886673,,,,,,,,,,,,,, -111500,2.708979,1.3063374,,,,,,,,,,,,,, -111600,3.1082249,1.2861814,,,,,,,,,,,,,, -111700,2.7627492,1.2503192,,,,,,,,,,,,,, -111800,2.7177153,1.3297935,,,,,,,,,,,,,, -111900,2.539016,1.1764455,,,,,,,,,,,,,, -112000,2.6231415,1.2391412,,,,,,,,,,,,,, -112100,2.9416707,1.3359878,,,,,,,,,,,,,, -112200,2.6422102,1.3454018,,,,,,,,,,,,,, -112300,2.7051558,1.2277124,,,,,,,,,,,,,, -112320,,,0.800223171710968,0.7208084464073181,0.6976799964904785,1.232202649116516,50000.0,0.5700000524520874,1.9467543363571167,10000.0,37778.35658097267,39121.58758187294,37778.35658097267,1336.601214170456,2.992850542068481,0.0 -112400,3.2405765,1.3579329,,,,,,,,,,,,,, -112500,2.6258593,1.2495654,,,,,,,,,,,,,, -112600,3.000234,1.3219547,,,,,,,,,,,,,, -112700,2.8860018,1.3782533,,,,,,,,,,,,,, -112800,2.8822653,1.265945,,,,,,,,,,,,,, -112900,2.9932115,1.3234377,,,,,,,,,,,,,, -113000,2.98889,1.3494843,,,,,,,,,,,,,, -113100,3.116378,1.2712044,,,,,,,,,,,,,, -113200,2.744719,1.2099519,,,,,,,,,,,,,, -113300,2.7103024,1.2876997,,,,,,,,,,,,,, -113400,2.9481895,1.3269291,,,,,,,,,,,,,, -113500,2.6993234,1.2596159,,,,,,,,,,,,,, -113600,2.8359249,1.2973416,,,,,,,,,,,,,, -113700,3.077586,1.2885894,,,,,,,,,,,,,, -113800,2.9541833,1.1868596,,,,,,,,,,,,,, -113839,,,0.8005420565605164,0.7312487959861755,0.7041599750518799,1.2149962186813354,50000.0,0.5770000219345093,1.942642331123352,10000.0,38288.51729607582,39649.69808101654,38288.51729607582,1354.453326463699,3.041289567947388,0.0 -113900,2.878481,1.2993715,,,,,,,,,,,,,, -114000,2.8740633,1.2578712,,,,,,,,,,,,,, -114100,3.1421494,1.2825913,,,,,,,,,,,,,, -114200,2.8060513,1.2433208,,,,,,,,,,,,,, -114300,2.8379653,1.204117,,,,,,,,,,,,,, -114400,2.756039,1.2994096,,,,,,,,,,,,,, -114500,2.954776,1.2630769,,,,,,,,,,,,,, -114600,2.7031467,1.3016291,,,,,,,,,,,,,, -114700,2.9332206,1.216738,,,,,,,,,,,,,, -114800,2.8806365,1.2502453,,,,,,,,,,,,,, -114900,3.1964169,1.2783421,,,,,,,,,,,,,, -115000,2.5622606,1.1155132,,,,,,,,,,,,,, -115100,2.81519,1.3205564,,,,,,,,,,,,,, -115200,2.8502085,1.2771118,,,,,,,,,,,,,, -115300,2.6527193,1.1567656,,,,,,,,,,,,,, -115357,,,0.7947823405265808,0.752173125743866,0.6983799934387207,1.2410300970077517,50000.0,0.5707000494003296,1.9666895866394043,10000.0,38798.464625597,40177.886984825134,38798.464625597,1372.5982236862185,3.089034080505371,0.0 -115400,2.8514194,1.2544458,,,,,,,,,,,,,, -115500,2.759974,1.1436663,,,,,,,,,,,,,, -115600,3.0992665,1.279004,,,,,,,,,,,,,, -115700,2.8164864,1.200238,,,,,,,,,,,,,, -115800,3.218374,1.1541295,,,,,,,,,,,,,, -115900,2.9443872,1.2714763,,,,,,,,,,,,,, -116000,2.7954264,1.1429212,,,,,,,,,,,,,, -116100,2.8038342,1.134409,,,,,,,,,,,,,, -116200,2.7846518,1.1638486,,,,,,,,,,,,,, -116300,2.9882667,1.160869,,,,,,,,,,,,,, -116400,3.1126883,1.3054397,,,,,,,,,,,,,, -116500,2.713856,1.177138,,,,,,,,,,,,,, -116600,2.8228085,1.2347826,,,,,,,,,,,,,, -116700,2.7977555,1.2183357,,,,,,,,,,,,,, -116800,2.963993,1.2028852,,,,,,,,,,,,,, -116875,,,0.7978315949440002,0.7409703135490417,0.7010599970817566,1.2206995487213137,50000.0,0.5698000192642212,1.955405712127685,10000.0,39308.584755182266,40705.944717884064,39308.584755182266,1390.4349954128263,3.141059637069702,0.0 -116900,2.732097,1.2371962,,,,,,,,,,,,,, -117000,2.9216294,1.1757208,,,,,,,,,,,,,, -117100,2.8669524,1.1851147,,,,,,,,,,,,,, -117200,2.6750991,1.2268757,,,,,,,,,,,,,, -117300,2.7803354,1.2168831,,,,,,,,,,,,,, -117400,3.2099938,1.2448092,,,,,,,,,,,,,, -117500,2.8229384,1.1820041,,,,,,,,,,,,,, -117600,3.0314338,1.2857596,,,,,,,,,,,,,, -117700,2.854911,1.2167852,,,,,,,,,,,,,, -117800,2.9417908,1.232763,,,,,,,,,,,,,, -117900,2.9254572,1.2024205,,,,,,,,,,,,,, -118000,3.3494008,1.2449949,,,,,,,,,,,,,, -118100,3.1286957,1.2319742,,,,,,,,,,,,,, -118200,3.0646744,1.2733275,,,,,,,,,,,,,, -118300,2.9096396,1.2086565,,,,,,,,,,,,,, -118393,,,0.8346220850944519,0.5950406193733215,0.7032999992370605,1.215368151664734,50000.0,0.5749000310897827,1.9618397951126096,10000.0,39818.51747059822,41233.966715574265,39818.51747059822,1408.42196559906,3.193824052810669,0.0 -118400,2.6599245,1.0987713,,,,,,,,,,,,,, -118500,2.9977977,1.2064521,,,,,,,,,,,,,, -118600,3.130947,1.1492248,,,,,,,,,,,,,, -118700,2.5838993,1.0366173,,,,,,,,,,,,,, -118800,2.6953177,1.170045,,,,,,,,,,,,,, -118900,2.7372503,1.2086153,,,,,,,,,,,,,, -119000,2.832352,1.1004837,,,,,,,,,,,,,, -119100,3.020504,1.2327585,,,,,,,,,,,,,, -119200,2.9126751,1.2606791,,,,,,,,,,,,,, -119300,3.0495687,1.2097268,,,,,,,,,,,,,, -119400,2.7633631,1.2117877,,,,,,,,,,,,,, -119500,2.755499,1.1184459,,,,,,,,,,,,,, -119600,3.214926,1.3289677,,,,,,,,,,,,,, -119700,2.8965573,1.2015865,,,,,,,,,,,,,, -119800,2.9694595,1.222896,,,,,,,,,,,,,, -119900,2.8919244,1.198822,,,,,,,,,,,,,, -119912,,,0.8252750039100647,0.6352033615112305,0.7068799734115601,1.1854684352874756,50000.0,0.5819000005722046,1.8898450136184688,10000.0,40328.73139810562,41762.01231408119,40328.73139810562,1426.150707244873,3.247591495513916,0.0 -120000,3.0414162,1.2400337,,,,,,,,,,,,,, -120100,2.9617991,1.2269249,,,,,,,,,,,,,, -120200,3.1583207,1.1833415,,,,,,,,,,,,,, -120300,3.0162706,1.1701028,,,,,,,,,,,,,, -120400,3.2329679,1.1742542,,,,,,,,,,,,,, -120500,2.814059,1.1528372,,,,,,,,,,,,,, -120600,2.9858985,1.2213035,,,,,,,,,,,,,, -120700,2.7429328,1.1882796,,,,,,,,,,,,,, -120800,3.1375237,1.1890974,,,,,,,,,,,,,, -120900,3.2418082,1.2589538,,,,,,,,,,,,,, -121000,3.2345107,1.1899107,,,,,,,,,,,,,, -121100,2.880526,1.1821221,,,,,,,,,,,,,, -121200,2.9737597,1.1926119,,,,,,,,,,,,,, -121300,3.1307144,1.1792941,,,,,,,,,,,,,, -121400,3.2316873,1.209508,,,,,,,,,,,,,, -121431,,,0.8082947731018066,0.6803027391433716,0.7044599652290344,1.2126598358154297,50000.0,0.5742000341415405,1.977544069290161,10000.0,40838.95229148865,42290.05305337906,40838.95229148865,1443.8689014911652,3.300241231918335,0.0 -121500,3.297025,1.2533425,,,,,,,,,,,,,, -121600,2.835026,1.158763,,,,,,,,,,,,,, -121700,3.0734909,1.1669012,,,,,,,,,,,,,, -121800,3.07349,1.2734799,,,,,,,,,,,,,, -121900,2.8717272,1.1903679,,,,,,,,,,,,,, -122000,3.0724342,1.1489751,,,,,,,,,,,,,, -122100,3.097862,1.1466748,,,,,,,,,,,,,, -122200,3.003956,1.1825032,,,,,,,,,,,,,, -122300,3.021268,1.2349199,,,,,,,,,,,,,, -122400,2.8732038,1.110122,,,,,,,,,,,,,, -122500,3.452316,1.2332476,,,,,,,,,,,,,, -122600,2.8617535,1.1504993,,,,,,,,,,,,,, -122700,3.294184,1.2586157,,,,,,,,,,,,,, -122800,2.9684212,1.1333127,,,,,,,,,,,,,, -122900,3.3313146,1.1636729,,,,,,,,,,,,,, -122950,,,0.8224449753761292,0.6331303119659424,0.7143599987030029,1.1689563989639282,50000.0,0.5859000086784363,1.9121652841567995,10000.0,41349.047131061554,42818.18391633034,41349.047131061554,1461.8070714473724,3.3489990234375,0.0 -123000,3.1161416,1.2754694,,,,,,,,,,,,,, -123100,3.0235443,1.0783201,,,,,,,,,,,,,, -123200,2.848422,1.0502505,,,,,,,,,,,,,, -123300,3.3084054,1.1106435,,,,,,,,,,,,,, -123400,3.4739273,1.1511896,,,,,,,,,,,,,, -123500,2.914164,1.1923032,,,,,,,,,,,,,, -123600,3.026652,1.1336324,,,,,,,,,,,,,, -123700,3.2753127,1.1778792,,,,,,,,,,,,,, -123800,3.0066488,1.2017597,,,,,,,,,,,,,, -123900,3.0002265,1.1566707,,,,,,,,,,,,,, -124000,2.9585075,1.1767561,,,,,,,,,,,,,, -124100,2.9131598,1.1707687,,,,,,,,,,,,,, -124200,3.1606805,1.1739293,,,,,,,,,,,,,, -124300,2.9521015,1.1744442,,,,,,,,,,,,,, -124400,3.080518,1.1686519,,,,,,,,,,,,,, -124469,,,0.8160474896430969,0.6538271903991699,0.7134999632835388,1.174975872039795,50000.0,0.5866000056266785,1.896184682846069,10000.0,41859.04773306847,43345.82070159912,41859.04773306847,1479.3437526226044,3.399578094482422,0.0 -124500,3.1099377,1.2090316,,,,,,,,,,,,,, -124600,3.0972219,1.0723703,,,,,,,,,,,,,, -124700,3.4704416,1.2602024,,,,,,,,,,,,,, -124800,3.096213,1.2168413,,,,,,,,,,,,,, -124900,3.3669944,1.1568604,,,,,,,,,,,,,, -125000,3.7023468,1.0973191,,,,,,,,,,,,,, -125100,3.298198,1.2101331,,,,,,,,,,,,,, -125200,3.150797,1.2026349,,,,,,,,,,,,,, -125300,3.2448146,1.193379,,,,,,,,,,,,,, -125400,3.0416145,1.1327314,,,,,,,,,,,,,, -125500,3.3611329,1.2941372,,,,,,,,,,,,,, -125600,3.318155,1.1631076,,,,,,,,,,,,,, -125700,3.1712065,1.1594644,,,,,,,,,,,,,, -125800,3.031396,1.1420512,,,,,,,,,,,,,, -125900,3.352072,1.2090824,,,,,,,,,,,,,, -125988,,,0.8224050998687744,0.6366180181503296,0.7099800109863281,1.189315915107727,50000.0,0.5868000388145447,1.908347964286804,10000.0,42368.96591067314,43873.70950007439,42368.96591067314,1497.206482887268,3.458393812179565,0.0 -126000,2.9857593,1.1108305,,,,,,,,,,,,,, -126100,2.9426212,1.0983403,,,,,,,,,,,,,, -126200,3.262341,1.2291752,,,,,,,,,,,,,, -126300,3.078442,1.1094568,,,,,,,,,,,,,, -126400,3.149294,1.0445403,,,,,,,,,,,,,, -126500,3.085289,1.1054422,,,,,,,,,,,,,, -126600,3.3465314,1.1611336,,,,,,,,,,,,,, -126700,3.0265267,1.0914764,,,,,,,,,,,,,, -126800,3.260128,1.0936872,,,,,,,,,,,,,, -126900,3.41858,1.223012,,,,,,,,,,,,,, -127000,3.395997,1.1476574,,,,,,,,,,,,,, -127100,3.2800884,1.0887704,,,,,,,,,,,,,, -127200,3.3666208,1.1409353,,,,,,,,,,,,,, -127300,3.142963,1.1486055,,,,,,,,,,,,,, -127400,3.3534863,1.1101344,,,,,,,,,,,,,, -127500,3.239631,1.0513461,,,,,,,,,,,,,, -127506,,,0.8494299650192261,0.529978334903717,0.7148199677467346,1.1761623620986938,50000.0,0.5898000001907349,1.8952957391738887,10000.0,42879.00650596619,44401.515615940094,42879.00650596619,1514.8715977668762,3.5093042850494385,0.0 -127600,3.0871584,1.075111,,,,,,,,,,,,,, -127700,3.4231448,1.1535707,,,,,,,,,,,,,, -127800,3.3989868,1.1962821,,,,,,,,,,,,,, -127900,3.263961,1.1217443,,,,,,,,,,,,,, -128000,3.294988,1.1046851,,,,,,,,,,,,,, -128100,3.1524353,1.0964828,,,,,,,,,,,,,, -128200,3.4295285,1.1468248,,,,,,,,,,,,,, -128300,3.5289392,1.1404349,,,,,,,,,,,,,, -128400,3.0809155,1.1144555,,,,,,,,,,,,,, -128500,3.2099938,1.0222179,,,,,,,,,,,,,, -128600,3.0511894,1.1311505,,,,,,,,,,,,,, -128700,3.4589407,1.1135631,,,,,,,,,,,,,, -128800,3.126314,1.1183236,,,,,,,,,,,,,, -128900,3.177179,1.1708986,,,,,,,,,,,,,, -129000,3.149951,1.05772,,,,,,,,,,,,,, -129025,,,0.8404814600944519,0.564034104347229,0.7139599919319153,1.1836059093475342,50000.0,0.5893000364303589,1.9311414957046509,10000.0,43389.02200245857,44929.42847776413,43389.02200245857,1532.6676306724548,3.56146502494812,0.0 -129100,3.2176318,1.0640814,,,,,,,,,,,,,, -129200,3.6992857,1.1019018,,,,,,,,,,,,,, -129300,3.3925319,1.2035675,,,,,,,,,,,,,, -129400,3.2934616,1.0376717,,,,,,,,,,,,,, -129500,3.367826,1.073234,,,,,,,,,,,,,, -129600,3.2716572,1.0105882,,,,,,,,,,,,,, -129700,3.1917894,1.1029546,,,,,,,,,,,,,, -129800,3.2795718,1.1441207,,,,,,,,,,,,,, -129900,3.345004,1.1099381,,,,,,,,,,,,,, -130000,3.4761126,1.0616271,,,,,,,,,,,,,, -130100,3.0834463,0.99254894,,,,,,,,,,,,,, -130200,3.3431003,1.1173706,,,,,,,,,,,,,, -130300,3.454204,1.1435149,,,,,,,,,,,,,, -130400,3.4589405,1.0632346,,,,,,,,,,,,,, -130500,3.093883,1.1330186,,,,,,,,,,,,,, -130544,,,0.8401227593421936,0.5696372985839844,0.7194199562072754,1.1422617435455322,50000.0,0.5923000574111938,1.8697551488876345,10000.0,43899.07996249199,45457.17548465729,43899.07996249199,1550.256118774414,3.613261222839356,0.0 -130600,3.7868655,1.0485023,,,,,,,,,,,,,, -130700,3.118196,1.0448594,,,,,,,,,,,,,, -130800,3.2739534,1.0348717,,,,,,,,,,,,,, -130900,3.420061,1.1864084,,,,,,,,,,,,,, -131000,3.2849813,1.1044893,,,,,,,,,,,,,, -131100,3.1488707,1.1142783,,,,,,,,,,,,,, -131200,3.0248663,1.0768237,,,,,,,,,,,,,, -131300,3.2786741,1.0940404,,,,,,,,,,,,,, -131400,3.0615866,1.0873871,,,,,,,,,,,,,, -131500,3.8524685,1.1398762,,,,,,,,,,,,,, -131600,3.5729783,1.0500376,,,,,,,,,,,,,, -131700,3.3080535,1.050489,,,,,,,,,,,,,, -131800,3.1875563,1.0534191,,,,,,,,,,,,,, -131900,3.2543387,1.0181506,,,,,,,,,,,,,, -132000,3.31981,1.0730247,,,,,,,,,,,,,, -132064,,,0.8375318646430969,0.5773704051971436,0.7168599963188171,1.1653879880905151,50000.0,0.588200032711029,1.8888368606567385,10000.0,44409.24393892288,45985.240468502045,44409.24393892288,1568.0413491725922,3.679718255996704,0.0 -132100,3.2112737,0.9890417,,,,,,,,,,,,,, -132200,3.4018803,1.0789347,,,,,,,,,,,,,, -132300,3.301607,1.0186492,,,,,,,,,,,,,, -132400,3.5352097,1.1927838,,,,,,,,,,,,,, -132500,3.3917644,1.0709114,,,,,,,,,,,,,, -132600,3.5264728,1.0828409,,,,,,,,,,,,,, -132700,3.1865067,0.9966833,,,,,,,,,,,,,, -132800,3.1130865,0.97359765,,,,,,,,,,,,,, -132900,3.9673648,1.216785,,,,,,,,,,,,,, -133000,3.257157,1.1021439,,,,,,,,,,,,,, -133100,3.3322103,1.0360183,,,,,,,,,,,,,, -133200,3.5359385,1.1137418,,,,,,,,,,,,,, -133300,3.5464492,1.0585687,,,,,,,,,,,,,, -133400,3.28884,1.0008137,,,,,,,,,,,,,, -133500,3.2921646,1.1267169,,,,,,,,,,,,,, -133582,,,0.8428332209587097,0.5535558462142944,0.720579981803894,1.1483376026153564,50000.0,0.5958000421524048,1.894065499305725,10000.0,44919.25254154205,46513.27612757683,44919.25254154205,1585.9611542224884,3.7378854751586914,0.0 -133600,3.4468215,1.0747781,,,,,,,,,,,,,, -133700,3.4950294,1.0991544,,,,,,,,,,,,,, -133800,3.2673967,1.0047619,,,,,,,,,,,,,, -133900,3.4884472,1.0560211,,,,,,,,,,,,,, -134000,3.3485916,1.0365391,,,,,,,,,,,,,, -134100,3.3221037,1.0446053,,,,,,,,,,,,,, -134200,3.1709437,0.99439156,,,,,,,,,,,,,, -134300,3.4520593,1.0537452,,,,,,,,,,,,,, -134400,3.2946103,1.0889046,,,,,,,,,,,,,, -134500,3.2857163,1.0514507,,,,,,,,,,,,,, -134600,3.6306026,1.0927345,,,,,,,,,,,,,, -134700,3.4841373,1.0197823,,,,,,,,,,,,,, -134800,3.5523782,0.9910298,,,,,,,,,,,,,, -134900,3.5235157,1.0068036,,,,,,,,,,,,,, -135000,3.4997473,1.0652672,,,,,,,,,,,,,, -135100,3.385601,1.0950339,,,,,,,,,,,,,, -135101,,,0.8835100531578064,0.4169529676437378,0.72461998462677,1.1391019821166992,50000.0,0.5998000502586365,1.867887020111084,10000.0,45429.48771595955,47042.1477367878,45429.48771595955,1604.497786283493,3.788525342941284,0.0 -135200,3.5578613,0.96898305,,,,,,,,,,,,,, -135300,3.4640756,1.0263362,,,,,,,,,,,,,, -135400,3.7158737,1.0476965,,,,,,,,,,,,,, -135500,3.497574,0.9639981,,,,,,,,,,,,,, -135600,3.2911458,1.0315783,,,,,,,,,,,,,, -135700,3.4402294,1.0900346,,,,,,,,,,,,,, -135800,3.3275983,1.0328679,,,,,,,,,,,,,, -135900,3.3366463,1.0185672,,,,,,,,,,,,,, -136000,3.824102,1.0388583,,,,,,,,,,,,,, -136100,3.3589368,1.018349,,,,,,,,,,,,,, -136200,3.6047003,0.98133224,,,,,,,,,,,,,, -136300,3.3773265,1.038914,,,,,,,,,,,,,, -136400,3.4724557,1.045478,,,,,,,,,,,,,, -136500,3.6412961,1.0511471,,,,,,,,,,,,,, -136600,3.5272753,1.0405484,,,,,,,,,,,,,, -136619,,,0.8668088316917419,0.4636785984039306,0.724399983882904,1.1427175998687744,50000.0,0.5943000316619873,1.8947100639343264,10000.0,45939.46535348892,47570.19637298584,45939.46535348892,1622.4473087787628,3.860964298248291,0.0 -136700,3.489896,1.0546832,,,,,,,,,,,,,, -136800,3.54866,1.020922,,,,,,,,,,,,,, -136900,3.6229417,1.1254245,,,,,,,,,,,,,, -137000,3.4350576,0.9645974,,,,,,,,,,,,,, -137100,3.6405873,1.0306543,,,,,,,,,,,,,, -137200,3.6536384,1.0092338,,,,,,,,,,,,,, -137300,3.569007,0.9533385,,,,,,,,,,,,,, -137400,3.3756328,0.9158541,,,,,,,,,,,,,, -137500,3.6803048,1.0796558,,,,,,,,,,,,,, -137600,3.4047072,0.9382963,,,,,,,,,,,,,, -137700,3.9652464,0.9789696,,,,,,,,,,,,,, -137800,3.6442432,0.97614133,,,,,,,,,,,,,, -137900,3.3734603,0.93543977,,,,,,,,,,,,,, -138000,3.3002837,0.9471048,,,,,,,,,,,,,, -138100,3.5696428,1.0447328,,,,,,,,,,,,,, -138138,,,0.8622449040412903,0.4757066071033478,0.7257999777793884,1.1301318407058716,50000.0,0.6041000485420227,1.837876796722412,10000.0,46449.43815970421,48098.23912549019,46449.43815970421,1640.4148676395416,3.913733720779419,0.0 -138200,3.430273,0.97209835,,,,,,,,,,,,,, -138300,3.405898,0.95773494,,,,,,,,,,,,,, -138400,3.442748,0.9413355,,,,,,,,,,,,,, -138500,3.7722673,1.1089429,,,,,,,,,,,,,, -138600,3.7230978,1.0296786,,,,,,,,,,,,,, -138700,3.5114157,1.0118973,,,,,,,,,,,,,, -138800,3.703127,1.0156654,,,,,,,,,,,,,, -138900,3.9600801,1.0469899,,,,,,,,,,,,,, -139000,3.9134767,1.0631112,,,,,,,,,,,,,, -139100,3.5056002,1.0050292,,,,,,,,,,,,,, -139200,3.490993,0.989186,,,,,,,,,,,,,, -139300,3.6667962,0.9518376,,,,,,,,,,,,,, -139400,3.5752697,1.0108181,,,,,,,,,,,,,, -139500,3.5886896,0.9505569,,,,,,,,,,,,,, -139600,3.432286,1.0119319,,,,,,,,,,,,,, -139656,,,0.8647361397743225,0.466713011264801,0.7309799790382385,1.1106938123703003,50000.0,0.6045000553131104,1.8522108793258667,10000.0,46959.3910138607,48626.24474787712,46959.3910138607,1658.3606095314026,3.971649646759033,0.0 -139700,3.278439,0.9536708,,,,,,,,,,,,,, -139800,3.8651626,0.96499085,,,,,,,,,,,,,, -139900,3.3149245,0.96323043,,,,,,,,,,,,,, -140000,3.520407,0.9682575,,,,,,,,,,,,,, -140100,3.8469713,0.9854961,,,,,,,,,,,,,, -140200,3.718662,0.9600978,,,,,,,,,,,,,, -140300,3.6679075,0.95018965,,,,,,,,,,,,,, -140400,3.6088872,0.97793955,,,,,,,,,,,,,, -140500,3.4444711,0.974653,,,,,,,,,,,,,, -140600,3.9978592,1.0011365,,,,,,,,,,,,,, -140700,3.6342456,0.9368608,,,,,,,,,,,,,, -140800,3.3748245,0.91070545,,,,,,,,,,,,,, -140900,3.753624,0.9213956,,,,,,,,,,,,,, -141000,3.6848006,0.8550924,,,,,,,,,,,,,, -141100,3.7906547,0.92512476,,,,,,,,,,,,,, -141173,,,0.8649353981018066,0.4697670340538025,0.7278599739074707,1.1257988214492798,50000.0,0.6008000373840332,1.848008275032044,10000.0,47469.4056904316,49154.06947255135,47469.4056904316,1676.064817905426,4.028349161148071,0.0 -141200,3.6265895,0.88782114,,,,,,,,,,,,,, -141300,3.6788342,1.1216844,,,,,,,,,,,,,, -141400,3.7266026,1.0672902,,,,,,,,,,,,,, -141500,3.8802736,1.0929232,,,,,,,,,,,,,, -141600,3.6316288,1.0162191,,,,,,,,,,,,,, -141700,3.606584,0.94692826,,,,,,,,,,,,,, -141800,3.7965136,0.97882295,,,,,,,,,,,,,, -141900,3.7905023,0.94232345,,,,,,,,,,,,,, -142000,3.3917391,0.9759438,,,,,,,,,,,,,, -142100,3.6028273,1.0047705,,,,,,,,,,,,,, -142200,3.4696689,1.0105532,,,,,,,,,,,,,, -142300,3.6018922,0.9151257,,,,,,,,,,,,,, -142400,3.7919066,0.99028677,,,,,,,,,,,,,, -142500,3.7727933,0.832261,,,,,,,,,,,,,, -142600,3.5386794,0.8680389,,,,,,,,,,,,,, -142691,,,0.8581592440605164,0.4860461056232452,0.7240399718284607,1.1514825820922852,50000.0,0.5978000164031982,1.939948320388794,10000.0,47979.4405105114,49682.18212604523,47979.4405105114,1694.0392887592316,4.082364082336426,0.0 -142700,3.8103755,0.98719066,,,,,,,,,,,,,, -142800,3.7814848,0.89668095,,,,,,,,,,,,,, -142900,3.8154905,0.9995258,,,,,,,,,,,,,, -143000,3.7355201,0.99836445,,,,,,,,,,,,,, -143100,3.7245142,0.9190706,,,,,,,,,,,,,, -143200,3.771955,0.9910025,,,,,,,,,,,,,, -143300,3.913062,0.97462016,,,,,,,,,,,,,, -143400,3.9680882,0.9053555,,,,,,,,,,,,,, -143500,3.8124382,0.8939053,,,,,,,,,,,,,, -143600,3.7605078,0.9270955,,,,,,,,,,,,,, -143700,3.7735803,0.9810468,,,,,,,,,,,,,, -143800,3.9207065,0.99516094,,,,,,,,,,,,,, -143900,3.9969497,0.9276066,,,,,,,,,,,,,, -144000,3.6293235,0.98759174,,,,,,,,,,,,,, -144100,3.5945585,0.9449196,,,,,,,,,,,,,, -144200,3.899049,0.91078585,,,,,,,,,,,,,, -144208,,,0.8985769748687744,0.3564726412296295,0.7285599708557129,1.129584789276123,50000.0,0.6085000038146973,1.879884600639344,10000.0,48489.41442799568,50209.91431188584,48489.41442799568,1711.6912882328031,4.139669179916382,0.0 -144300,3.9106338,0.96379274,,,,,,,,,,,,,, -144400,4.0557303,0.9373338,,,,,,,,,,,,,, -144500,3.7317545,0.94284654,,,,,,,,,,,,,, -144600,3.5862875,0.856071,,,,,,,,,,,,,, -144700,3.7377443,1.0576113,,,,,,,,,,,,,, -144800,4.080341,1.0433596,,,,,,,,,,,,,, -144900,3.8146734,0.9470725,,,,,,,,,,,,,, -145000,4.0244904,0.91486657,,,,,,,,,,,,,, -145100,3.6986213,0.9316283,,,,,,,,,,,,,, -145200,3.7781718,0.9184451,,,,,,,,,,,,,, -145300,3.867433,0.8590422,,,,,,,,,,,,,, -145400,3.8285983,0.878793,,,,,,,,,,,,,, -145500,3.981874,0.9741588,,,,,,,,,,,,,, -145600,3.451337,0.8428425,,,,,,,,,,,,,, -145700,4.012144,0.89540935,,,,,,,,,,,,,, -145726,,,0.8902861475944519,0.382175862789154,0.7315999865531921,1.1145838499069214,50000.0,0.6066000461578369,1.8463207483291624,10000.0,48999.39666056633,50737.63028144837,48999.39666056633,1729.3212552070618,4.194288969039917,0.0 -145800,4.0103393,0.9889054,,,,,,,,,,,,,, -145900,3.3962977,0.76962733,,,,,,,,,,,,,, -146000,3.9169679,0.92714685,,,,,,,,,,,,,, -146100,3.7229285,0.99470717,,,,,,,,,,,,,, -146200,4.071855,0.9251411,,,,,,,,,,,,,, -146300,4.1546974,0.90392804,,,,,,,,,,,,,, -146400,3.934606,0.93280566,,,,,,,,,,,,,, -146500,3.7524452,0.8893352,,,,,,,,,,,,,, -146600,4.043342,0.930272,,,,,,,,,,,,,, -146700,3.9617026,0.8975068,,,,,,,,,,,,,, -146800,3.9007003,0.9426927,,,,,,,,,,,,,, -146900,3.907159,0.9421273,,,,,,,,,,,,,, -147000,3.8085446,0.940589,,,,,,,,,,,,,, -147100,4.800347,1.0226254,,,,,,,,,,,,,, -147200,3.7255454,0.8243873,,,,,,,,,,,,,, -147244,,,0.8928371667861938,0.3668327927589416,0.7371399998664856,1.0903639793395996,50000.0,0.6055000424385071,1.839654922485352,10000.0,49509.46116828919,51265.44833254814,49509.46116828919,1746.9731595516205,4.246668100357056,0.0 -147300,3.908116,0.90785104,,,,,,,,,,,,,, -147400,4.0994635,1.0054861,,,,,,,,,,,,,, -147500,3.9072716,0.8772979,,,,,,,,,,,,,, -147600,3.8237154,0.84100354,,,,,,,,,,,,,, -147700,3.6976748,0.9052324,,,,,,,,,,,,,, -147800,3.780517,0.889747,,,,,,,,,,,,,, -147900,4.0297084,0.8966047,,,,,,,,,,,,,, -148000,3.7129595,0.87261283,,,,,,,,,,,,,, -148100,3.9180918,0.89019436,,,,,,,,,,,,,, -148200,4.384202,0.8551718,,,,,,,,,,,,,, -148300,4.1296926,0.88633984,,,,,,,,,,,,,, -148400,3.7862208,0.917254,,,,,,,,,,,,,, -148500,4.1796336,0.8986194,,,,,,,,,,,,,, -148600,3.9064019,0.9778116,,,,,,,,,,,,,, -148700,4.001844,0.8960148,,,,,,,,,,,,,, -148762,,,0.8951091766357422,0.3621106743812561,0.738599956035614,1.090329647064209,50000.0,0.6096000075340271,1.825580596923828,10000.0,50019.50054001808,51793.38860464096,50019.50054001808,1764.7672312259674,4.303528308868408,0.0 -148800,4.0853662,0.9693299,,,,,,,,,,,,,, -148900,3.7695785,0.8316464,,,,,,,,,,,,,, -149000,3.8538942,0.8583837,,,,,,,,,,,,,, -149100,3.5394568,0.7754705,,,,,,,,,,,,,, -149200,3.5758286,0.8142903,,,,,,,,,,,,,, -149300,4.147621,0.8663743,,,,,,,,,,,,,, -149400,4.1056147,0.9030801,,,,,,,,,,,,,, -149500,4.011642,0.892162,,,,,,,,,,,,,, -149600,3.9591303,0.90252864,,,,,,,,,,,,,, -149700,3.7661643,0.84634787,,,,,,,,,,,,,, -149800,4.1540265,0.9209551,,,,,,,,,,,,,, -149900,4.0727835,0.8946074,,,,,,,,,,,,,, -150000,3.8213954,0.9331251,,,,,,,,,,,,,, -150100,3.85225,0.8282331,,,,,,,,,,,,,, -150200,4.280439,0.9256818,,,,,,,,,,,,,, -150280,,,0.8956672549247742,0.356880247592926,0.7396799921989441,1.1052157878875732,50000.0,0.6111000180244446,1.879051089286804,10000.0,50529.406017541885,52320.97790455818,50529.406017541885,1782.3478038311005,4.357873916625977,0.0 -150300,4.290598,0.94327426,,,,,,,,,,,,,, -150400,4.16163,0.8724563,,,,,,,,,,,,,, -150500,3.983008,0.8951747,,,,,,,,,,,,,, -150600,4.041538,0.87900305,,,,,,,,,,,,,, -150700,3.7473767,0.83260345,,,,,,,,,,,,,, -150800,4.2125535,0.8789591,,,,,,,,,,,,,, -150900,4.1151,0.88661635,,,,,,,,,,,,,, -151000,4.220882,0.8263239,,,,,,,,,,,,,, -151100,4.2053723,0.80065596,,,,,,,,,,,,,, -151200,3.9772015,0.85669744,,,,,,,,,,,,,, -151300,3.8536232,0.77607065,,,,,,,,,,,,,, -151400,3.9927092,0.82877207,,,,,,,,,,,,,, -151500,4.064959,0.8930925,,,,,,,,,,,,,, -151600,4.216803,0.95387614,,,,,,,,,,,,,, -151700,3.7365613,0.8464184,,,,,,,,,,,,,, -151799,,,0.904715359210968,0.3321236968040466,0.7416599988937378,1.0797576904296875,50000.0,0.613800048828125,1.824073076248169,10000.0,51039.55972290039,52849.2077562809,51039.55972290039,1800.3187172412872,4.414421319961548,0.0 -151800,3.855538,0.81332713,,,,,,,,,,,,,, -151900,3.8687043,0.8208679,,,,,,,,,,,,,, -152000,3.7649572,0.8159202,,,,,,,,,,,,,, -152100,4.0660667,0.82842004,,,,,,,,,,,,,, -152200,3.9405649,0.823586,,,,,,,,,,,,,, -152300,4.107398,0.872795,,,,,,,,,,,,,, -152400,3.9415014,0.7781279,,,,,,,,,,,,,, -152500,3.709161,0.744434,,,,,,,,,,,,,, -152600,4.43393,0.8838512,,,,,,,,,,,,,, -152700,4.1563053,0.86394024,,,,,,,,,,,,,, -152800,4.5381336,0.8227209,,,,,,,,,,,,,, -152900,3.8135886,0.8483772,,,,,,,,,,,,,, -153000,4.0781217,0.8622616,,,,,,,,,,,,,, -153100,4.0343437,0.84435034,,,,,,,,,,,,,, -153200,4.0044665,0.8235664,,,,,,,,,,,,,, -153300,4.1230063,0.94001245,,,,,,,,,,,,,, -153317,,,0.9268175959587096,0.2608273029327392,0.7423999905586243,1.080202579498291,50000.0,0.6152999997138977,1.827741861343384,10000.0,51549.5005209446,53376.83444476128,51549.5005209446,1817.896065711975,4.474432706832886,0.0 -153400,4.2759485,0.8519272,,,,,,,,,,,,,, -153500,3.8790872,0.856229,,,,,,,,,,,,,, -153600,3.9923112,0.7815639,,,,,,,,,,,,,, -153700,4.0951014,0.84117573,,,,,,,,,,,,,, -153800,4.2448316,0.8683777,,,,,,,,,,,,,, -153900,3.8832612,0.79607004,,,,,,,,,,,,,, -154000,4.168441,0.8186695,,,,,,,,,,,,,, -154100,4.3871703,0.9131021,,,,,,,,,,,,,, -154200,4.251385,0.95654917,,,,,,,,,,,,,, -154300,3.829846,0.7896858,,,,,,,,,,,,,, -154400,4.3040814,0.83719516,,,,,,,,,,,,,, -154500,4.051118,0.80663615,,,,,,,,,,,,,, -154600,4.898046,0.862944,,,,,,,,,,,,,, -154700,4.04018,0.9076749,,,,,,,,,,,,,, -154800,3.975752,0.7645061,,,,,,,,,,,,,, -154835,,,0.9178690910339355,0.285111129283905,0.7413199543952942,1.083351969718933,50000.0,0.6127000451087952,1.845734715461731,10000.0,52059.46150302887,53904.571982860565,52059.46150302887,1835.5662310123444,4.531611204147339,0.0 -154900,4.0581646,0.8130056,,,,,,,,,,,,,, -155000,4.0319095,0.8649241,,,,,,,,,,,,,, -155100,4.6254225,0.8838602,,,,,,,,,,,,,, -155200,4.1826005,0.85160553,,,,,,,,,,,,,, -155300,3.8920803,0.8052373,,,,,,,,,,,,,, -155400,4.198257,0.8324632,,,,,,,,,,,,,, -155500,4.2655554,0.7510985,,,,,,,,,,,,,, -155600,4.064442,0.74470836,,,,,,,,,,,,,, -155700,4.112457,0.8001349,,,,,,,,,,,,,, -155800,3.8388946,0.7789154,,,,,,,,,,,,,, -155900,4.243874,0.82810354,,,,,,,,,,,,,, -156000,4.3102407,0.79547244,,,,,,,,,,,,,, -156100,4.3594117,0.8213169,,,,,,,,,,,,,, -156200,4.3890862,0.8396925,,,,,,,,,,,,,, -156300,4.3429933,0.7830018,,,,,,,,,,,,,, -156353,,,0.917629897594452,0.2818046808242798,0.7437999844551086,1.084424614906311,50000.0,0.6220000386238098,1.8379580974578853,10000.0,52569.59485697746,54432.38970851898,52569.59485697746,1853.1430249214168,4.589997291564941,0.0 -156400,4.147407,0.7713611,,,,,,,,,,,,,, -156500,4.067699,0.7293019,,,,,,,,,,,,,, -156600,4.317819,0.7557428,,,,,,,,,,,,,, -156700,3.919115,0.7239666,,,,,,,,,,,,,, -156800,4.5331755,0.8565703,,,,,,,,,,,,,, -156900,4.285721,0.7920735,,,,,,,,,,,,,, -157000,4.603735,0.8960345,,,,,,,,,,,,,, -157100,4.307884,0.75515574,,,,,,,,,,,,,, -157200,4.259166,0.8203618,,,,,,,,,,,,,, -157300,4.415821,0.822821,,,,,,,,,,,,,, -157400,4.029235,0.7354069,,,,,,,,,,,,,, -157500,3.9225547,0.6951902,,,,,,,,,,,,,, -157600,4.779216,0.8118714,,,,,,,,,,,,,, -157700,4.08079,0.70283026,,,,,,,,,,,,,, -157800,4.225633,0.79223967,,,,,,,,,,,,,, -157873,,,0.919702649116516,0.2783809602260589,0.7439000010490417,1.0915087461471558,50000.0,0.615600049495697,1.8537700176239007,10000.0,53079.72857952118,54960.33554697037,53079.72857952118,1870.8468675613403,4.649376630783081,0.0 -157900,4.384731,0.79794216,,,,,,,,,,,,,, -158000,4.4226317,0.85650754,,,,,,,,,,,,,, -158100,4.6448174,0.818311,,,,,,,,,,,,,, -158200,4.292173,0.8095093,,,,,,,,,,,,,, -158300,4.110618,0.7366523,,,,,,,,,,,,,, -158400,4.025226,0.77747947,,,,,,,,,,,,,, -158500,4.0477557,0.77071077,,,,,,,,,,,,,, -158600,4.167478,0.7849536,,,,,,,,,,,,,, -158700,4.17913,0.7805871,,,,,,,,,,,,,, -158800,4.5617375,0.7777722,,,,,,,,,,,,,, -158900,4.795452,0.89540064,,,,,,,,,,,,,, -159000,4.2011414,0.79481834,,,,,,,,,,,,,, -159100,4.532026,0.8431607,,,,,,,,,,,,,, -159200,4.4650497,0.8012545,,,,,,,,,,,,,, -159300,3.9785576,0.7171426,,,,,,,,,,,,,, -159391,,,0.919742465019226,0.274879902601242,0.7453799843788147,1.084672927856445,50000.0,0.6165000200271606,1.855073928833008,10000.0,53589.63150548935,55488.38742780685,53589.63150548935,1888.887590646744,4.708525896072388,0.0 -159400,4.1713543,0.7674001,,,,,,,,,,,,,, -159500,4.327895,0.7515154,,,,,,,,,,,,,, -159600,4.124542,0.73333216,,,,,,,,,,,,,, -159700,3.997137,0.73288107,,,,,,,,,,,,,, -159800,4.1220274,0.7620884,,,,,,,,,,,,,, -159900,4.071447,0.7078293,,,,,,,,,,,,,, -160000,4.2503257,0.84088075,,,,,,,,,,,,,, -160100,4.0477643,0.7274859,,,,,,,,,,,,,, -160200,4.795703,0.92304325,,,,,,,,,,,,,, -160300,4.5686045,0.7820179,,,,,,,,,,,,,, -160400,4.6858115,0.7985905,,,,,,,,,,,,,, -160500,4.314461,0.7327218,,,,,,,,,,,,,, -160600,4.1876135,0.7820707,,,,,,,,,,,,,, -160700,4.513283,0.85235476,,,,,,,,,,,,,, -160800,4.0700464,0.7403672,,,,,,,,,,,,,, -160900,4.200387,0.75110066,,,,,,,,,,,,,, -160908,,,0.9289500713348388,0.247159719467163,0.7472599744796753,1.0752619504928589,50000.0,0.6208000183105469,1.851715087890625,10000.0,54099.56414723396,56016.02576851845,54099.56414723396,1906.48606300354,4.766141414642334,0.0 -161000,3.9311614,0.68367666,,,,,,,,,,,,,, -161100,4.2553086,0.7325652,,,,,,,,,,,,,, -161200,4.3772454,0.7681659,,,,,,,,,,,,,, -161300,4.2795672,0.7853658,,,,,,,,,,,,,, -161400,4.4202394,0.7343191,,,,,,,,,,,,,, -161500,4.4512587,0.7043363,,,,,,,,,,,,,, -161600,4.1795406,0.7101362,,,,,,,,,,,,,, -161700,4.6262136,0.8107891,,,,,,,,,,,,,, -161800,4.3531775,0.77341247,,,,,,,,,,,,,, -161900,4.1677494,0.76394594,,,,,,,,,,,,,, -162000,4.534788,0.78234076,,,,,,,,,,,,,, -162100,4.6639643,0.7019062,,,,,,,,,,,,,, -162200,4.3483233,0.7756993,,,,,,,,,,,,,, -162300,4.5970573,0.76469797,,,,,,,,,,,,,, -162400,4.656219,0.7465953,,,,,,,,,,,,,, -162426,,,0.9414859414100648,0.2091090530157089,0.7475999593734741,1.0705816745758057,50000.0,0.624500036239624,1.8417563438415527,10000.0,54609.62043738365,56544.21882486344,54609.62043738365,1924.5159137248995,4.82317328453064,0.0 -162500,4.852631,0.80927837,,,,,,,,,,,,,, -162600,4.1103764,0.6872818,,,,,,,,,,,,,, -162700,4.8863707,0.8088526,,,,,,,,,,,,,, -162800,4.3834105,0.74336964,,,,,,,,,,,,,, -162900,4.5160103,0.7193938,,,,,,,,,,,,,, -163000,4.5890436,0.76718944,,,,,,,,,,,,,, -163100,4.4377356,0.7084687,,,,,,,,,,,,,, -163200,4.254345,0.6697124,,,,,,,,,,,,,, -163300,4.5040474,0.70669466,,,,,,,,,,,,,, -163400,4.585265,0.74218583,,,,,,,,,,,,,, -163500,4.0451226,0.732045,,,,,,,,,,,,,, -163600,4.3647914,0.7070023,,,,,,,,,,,,,, -163700,4.429706,0.7313297,,,,,,,,,,,,,, -163800,4.501663,0.72361743,,,,,,,,,,,,,, -163900,4.2928424,0.6859384,,,,,,,,,,,,,, -163945,,,0.9399114847183228,0.2118914574384689,0.7474600076675415,1.0750255584716797,50000.0,0.6200000047683716,1.8490190505981443,10000.0,55119.633002758026,57072.43365240097,55119.633002758026,1942.612450838089,4.879745721817017,0.0 -164000,4.472169,0.7867625,,,,,,,,,,,,,, -164100,4.35756,0.7651987,,,,,,,,,,,,,, -164200,4.963368,0.75115705,,,,,,,,,,,,,, -164300,4.2101183,0.6520821,,,,,,,,,,,,,, -164400,4.624988,0.7339078,,,,,,,,,,,,,, -164500,4.1907506,0.6378065,,,,,,,,,,,,,, -164600,4.6477036,0.7348669,,,,,,,,,,,,,, -164700,3.9286425,0.73155737,,,,,,,,,,,,,, -164800,4.528963,0.7725135,,,,,,,,,,,,,, -164900,4.2410307,0.7349117,,,,,,,,,,,,,, -165000,4.1601605,0.7152376,,,,,,,,,,,,,, -165100,4.5265284,0.7872636,,,,,,,,,,,,,, -165200,4.6595516,0.65969336,,,,,,,,,,,,,, -165300,4.7542205,0.75531137,,,,,,,,,,,,,, -165400,4.499912,0.7385239,,,,,,,,,,,,,, -165463,,,0.9395328164100648,0.2096386849880218,0.7483199834823608,1.0747122764587402,50000.0,0.6241000294685364,1.847937822341919,10000.0,55629.809804201126,57600.36788249016,55629.809804201126,1960.2603447437289,4.9395411014556885,0.0 -165500,4.1755505,0.7381796,,,,,,,,,,,,,, -165600,4.395035,0.66546047,,,,,,,,,,,,,, -165700,4.5432267,0.697908,,,,,,,,,,,,,, -165800,4.1956687,0.71628827,,,,,,,,,,,,,, -165900,4.597966,0.7260038,,,,,,,,,,,,,, -166000,4.4404182,0.78947216,,,,,,,,,,,,,, -166100,4.2688713,0.661343,,,,,,,,,,,,,, -166200,4.4577513,0.7050557,,,,,,,,,,,,,, -166300,4.260931,0.67206067,,,,,,,,,,,,,, -166400,4.3151493,0.7035108,,,,,,,,,,,,,, -166500,4.54894,0.66345257,,,,,,,,,,,,,, -166600,4.435088,0.7128221,,,,,,,,,,,,,, -166700,4.5874267,0.74296474,,,,,,,,,,,,,, -166800,4.2504754,0.69098556,,,,,,,,,,,,,, -166900,4.37502,0.66462463,,,,,,,,,,,,,, -166981,,,0.9429607391357422,0.2056682556867599,0.750059962272644,1.0746914148330688,50000.0,0.6247000098228455,1.840732455253601,10000.0,56139.72080612183,58128.30862569809,56139.72080612183,1978.183670282364,4.996540307998657,0.0 -167000,4.6390176,0.70169437,,,,,,,,,,,,,, -167100,4.5733247,0.69976676,,,,,,,,,,,,,, -167200,4.4268365,0.6909421,,,,,,,,,,,,,, -167300,5.111729,0.8119328,,,,,,,,,,,,,, -167400,4.280455,0.6560622,,,,,,,,,,,,,, -167500,4.761509,0.63131356,,,,,,,,,,,,,, -167600,5.3586698,0.7675456,,,,,,,,,,,,,, -167700,4.4760647,0.7324142,,,,,,,,,,,,,, -167800,4.6686053,0.73200446,,,,,,,,,,,,,, -167900,4.3848486,0.7520102,,,,,,,,,,,,,, -168000,5.305427,0.7161015,,,,,,,,,,,,,, -168100,4.6800814,0.652064,,,,,,,,,,,,,, -168200,4.287256,0.71441525,,,,,,,,,,,,,, -168300,4.5081177,0.7274011,,,,,,,,,,,,,, -168400,4.3883195,0.6723928,,,,,,,,,,,,,, -168498,,,0.945133090019226,0.1923914104700088,0.7511999607086182,1.0681169033050537,50000.0,0.625700056552887,1.835061192512512,10000.0,56649.69009900093,58656.19987845421,56649.69009900093,1995.9945611953733,5.058220386505127,0.0 -168500,4.1898885,0.6566721,,,,,,,,,,,,,, -168600,4.452724,0.6882482,,,,,,,,,,,,,, -168700,4.4153113,0.54956913,,,,,,,,,,,,,, -168800,4.581071,0.6674348,,,,,,,,,,,,,, -168900,4.678371,0.7470333,,,,,,,,,,,,,, -169000,4.244549,0.63715494,,,,,,,,,,,,,, -169100,4.298596,0.67931765,,,,,,,,,,,,,, -169200,4.311736,0.68885183,,,,,,,,,,,,,, -169300,5.0096292,0.69944,,,,,,,,,,,,,, -169400,4.4815683,0.65622437,,,,,,,,,,,,,, -169500,4.1142535,0.60247576,,,,,,,,,,,,,, -169600,4.7589564,0.7527282,,,,,,,,,,,,,, -169700,5.0591054,0.6759418,,,,,,,,,,,,,, -169800,4.9900837,0.77322567,,,,,,,,,,,,,, -169900,4.673542,0.72016627,,,,,,,,,,,,,, -170000,4.5402265,0.67638427,,,,,,,,,,,,,, -170016,,,0.9563735723495485,0.1631961911916732,0.7527999877929688,1.0671993494033811,50000.0,0.6259000301361084,1.8455358743667605,10000.0,57159.68464636803,59184.30030846596,57159.68464636803,2013.9880871772768,5.121249437332153,0.0 -170100,4.277624,0.65375656,,,,,,,,,,,,,, -170200,4.0677495,0.6082261,,,,,,,,,,,,,, -170300,4.5658092,0.68689495,,,,,,,,,,,,,, -170400,4.935953,0.5923596,,,,,,,,,,,,,, -170500,4.791982,0.70082444,,,,,,,,,,,,,, -170600,4.7759132,0.7143913,,,,,,,,,,,,,, -170700,4.3765845,0.6243905,,,,,,,,,,,,,, -170800,4.42229,0.71169674,,,,,,,,,,,,,, -170900,4.7658153,0.62861335,,,,,,,,,,,,,, -171000,4.2050114,0.6736596,,,,,,,,,,,,,, -171100,4.4378233,0.6745172,,,,,,,,,,,,,, -171200,4.288138,0.6652855,,,,,,,,,,,,,, -171300,4.5417395,0.5950274,,,,,,,,,,,,,, -171400,4.7485476,0.62013876,,,,,,,,,,,,,, -171500,4.301819,0.6124931,,,,,,,,,,,,,, -171535,,,0.9563934803009032,0.1602298766374588,0.7529199719429016,1.065573811531067,50000.0,0.6290000081062317,1.833582758903504,10000.0,57669.88044476509,59712.42644166946,57669.88044476509,2031.8099205493927,5.179990291595459,0.0 -171600,4.2821817,0.57709324,,,,,,,,,,,,,, -171700,4.7609944,0.6690032,,,,,,,,,,,,,, -171800,4.397058,0.65145284,,,,,,,,,,,,,, -171900,4.596008,0.6870944,,,,,,,,,,,,,, -172000,4.14668,0.58837414,,,,,,,,,,,,,, -172100,4.5895886,0.63610995,,,,,,,,,,,,,, -172200,4.580012,0.6439815,,,,,,,,,,,,,, -172300,4.5653577,0.7158263,,,,,,,,,,,,,, -172400,4.8819327,0.70330566,,,,,,,,,,,,,, -172500,4.4858966,0.63342947,,,,,,,,,,,,,, -172600,4.484122,0.6531336,,,,,,,,,,,,,, -172700,4.3146963,0.65580624,,,,,,,,,,,,,, -172800,4.187025,0.5714103,,,,,,,,,,,,,, -172900,4.930545,0.5955264,,,,,,,,,,,,,, -173000,4.437134,0.671165,,,,,,,,,,,,,, -173053,,,0.9550382494926452,0.1676359325647354,0.7534399628639221,1.057241439819336,50000.0,0.627500057220459,1.8320350646972656,10000.0,58179.970638751984,60240.450879096985,58179.970638751984,2049.632456064224,5.242520332336426,0.0 -173100,4.8056993,0.66418666,,,,,,,,,,,,,, -173200,4.8035107,0.68726814,,,,,,,,,,,,,, -173300,4.56288,0.64789426,,,,,,,,,,,,,, -173400,4.724403,0.7392304,,,,,,,,,,,,,, -173500,4.2150373,0.6318365,,,,,,,,,,,,,, -173600,4.736138,0.6340217,,,,,,,,,,,,,, -173700,4.431741,0.6604414,,,,,,,,,,,,,, -173800,4.7956195,0.6703165,,,,,,,,,,,,,, -173900,4.74915,0.6349366,,,,,,,,,,,,,, -174000,4.3690543,0.59961134,,,,,,,,,,,,,, -174100,4.24185,0.56436217,,,,,,,,,,,,,, -174200,4.06733,0.6274868,,,,,,,,,,,,,, -174300,4.9185014,0.6855537,,,,,,,,,,,,,, -174400,5.5025663,0.6918892,,,,,,,,,,,,,, -174500,4.4965315,0.64890367,,,,,,,,,,,,,, -174572,,,0.954300820827484,0.1634673774242401,0.7548999786376953,1.0598605871200562,50000.0,0.6304000020027161,1.833614706993103,10000.0,58689.94938135147,60768.78918766976,58689.94938135147,2067.8815484046936,5.303300857543945,0.0 -174600,4.710434,0.6389763,,,,,,,,,,,,,, -174700,4.2014737,0.65719473,,,,,,,,,,,,,, -174800,4.294554,0.6412582,,,,,,,,,,,,,, -174900,4.39321,0.6487477,,,,,,,,,,,,,, -175000,4.8304777,0.7150998,,,,,,,,,,,,,, -175100,4.536862,0.64029557,,,,,,,,,,,,,, -175200,4.536063,0.6877639,,,,,,,,,,,,,, -175300,4.555101,0.636876,,,,,,,,,,,,,, -175400,4.8433666,0.6965122,,,,,,,,,,,,,, -175500,4.578622,0.6078561,,,,,,,,,,,,,, -175600,4.328395,0.6212616,,,,,,,,,,,,,, -175700,5.2289643,0.72217226,,,,,,,,,,,,,, -175800,4.874781,0.7375535,,,,,,,,,,,,,, -175900,4.357349,0.70404446,,,,,,,,,,,,,, -176000,4.3597465,0.6099902,,,,,,,,,,,,,, -176090,,,0.9560347199440002,0.160325139760971,0.7560399770736694,1.0561996698379517,50000.0,0.6302000284194946,1.832032561302185,10000.0,59199.97342252731,61296.55711436272,59199.97342252731,2085.516634464264,5.3632285594940186,0.0 -176100,4.71195,0.6474534,,,,,,,,,,,,,, -176200,4.6240497,0.66566443,,,,,,,,,,,,,, -176300,3.9061553,0.60955065,,,,,,,,,,,,,, -176400,4.666757,0.59850585,,,,,,,,,,,,,, -176500,4.549137,0.6432718,,,,,,,,,,,,,, -176600,4.372125,0.59857196,,,,,,,,,,,,,, -176700,4.5246334,0.6316218,,,,,,,,,,,,,, -176800,4.5129585,0.6497772,,,,,,,,,,,,,, -176900,5.5404434,0.6765971,,,,,,,,,,,,,, -177000,4.328712,0.66560566,,,,,,,,,,,,,, -177100,4.568686,0.63351655,,,,,,,,,,,,,, -177200,4.496566,0.6832261,,,,,,,,,,,,,, -177300,4.291744,0.61120874,,,,,,,,,,,,,, -177400,4.5252867,0.6487252,,,,,,,,,,,,,, -177500,4.8333845,0.6322167,,,,,,,,,,,,,, -177600,4.7822123,0.6221291,,,,,,,,,,,,,, -177608,,,0.9580675959587096,0.1550871878862381,0.7548399567604065,1.0575575828552246,50000.0,0.6304000020027161,1.8350627422332764,10000.0,59710.12469863892,61824.70168447495,59710.12469863892,2103.398220539093,5.424897432327271,0.0 -177700,4.5763907,0.64784795,,,,,,,,,,,,,, -177800,4.336284,0.67313546,,,,,,,,,,,,,, -177900,4.376255,0.64957565,,,,,,,,,,,,,, -178000,4.8494205,0.67158806,,,,,,,,,,,,,, -178100,4.4117684,0.6326178,,,,,,,,,,,,,, -178200,4.7808876,0.67180455,,,,,,,,,,,,,, -178300,4.9267554,0.6999314,,,,,,,,,,,,,, -178400,4.1986637,0.58868563,,,,,,,,,,,,,, -178500,4.9303503,0.6836279,,,,,,,,,,,,,, -178600,4.7210994,0.577256,,,,,,,,,,,,,, -178700,4.319908,0.6188925,,,,,,,,,,,,,, -178800,4.5197515,0.5971606,,,,,,,,,,,,,, -178900,4.450411,0.6238265,,,,,,,,,,,,,, -179000,4.3575206,0.61646056,,,,,,,,,,,,,, -179100,4.641731,0.65549797,,,,,,,,,,,,,, -179127,,,0.9602997303009032,0.1473343223333358,0.7542799711227417,1.0583137273788452,50000.0,0.629800021648407,1.835172295570373,10000.0,60220.22158074379,62352.568135261536,60220.22158074379,2121.054502010345,5.489213466644287,0.0 -179200,4.328329,0.70020497,,,,,,,,,,,,,, -179300,4.586617,0.58429486,,,,,,,,,,,,,, -179400,4.613016,0.66777223,,,,,,,,,,,,,, -179500,4.3422055,0.6244595,,,,,,,,,,,,,, -179600,4.5085025,0.64090693,,,,,,,,,,,,,, -179700,4.44059,0.5880321,,,,,,,,,,,,,, -179800,5.5160074,0.60581625,,,,,,,,,,,,,, -179900,4.7325215,0.6803167,,,,,,,,,,,,,, -180000,4.438186,0.60993356,,,,,,,,,,,,,, -180100,4.8686013,0.6213161,,,,,,,,,,,,,, -180200,4.5896587,0.71504754,,,,,,,,,,,,,, -180300,4.555829,0.6366839,,,,,,,,,,,,,, -180400,4.537146,0.5829802,,,,,,,,,,,,,, -180500,4.6774125,0.57956475,,,,,,,,,,,,,, -180600,4.9953876,0.67436194,,,,,,,,,,,,,, -180645,,,0.9606983065605164,0.1453530639410019,0.7556399703025818,1.055141806602478,50000.0,0.6306000351905823,1.829922795295716,10000.0,60730.1836707592,62880.34313893318,60730.1836707592,2138.7541739940643,5.5535314083099365,0.0 -180700,4.6401687,0.57754874,,,,,,,,,,,,,, -180800,4.202494,0.6135435,,,,,,,,,,,,,, -180900,4.5687375,0.6677391,,,,,,,,,,,,,, -181000,4.7406,0.64761513,,,,,,,,,,,,,, -181100,4.3242455,0.63496256,,,,,,,,,,,,,, -181200,4.6411886,0.60688305,,,,,,,,,,,,,, -181300,4.554166,0.65593326,,,,,,,,,,,,,, -181400,4.359284,0.58013153,,,,,,,,,,,,,, -181500,4.389636,0.6631408,,,,,,,,,,,,,, -181600,4.7524605,0.57063097,,,,,,,,,,,,,, -181700,4.1942077,0.6292343,,,,,,,,,,,,,, -181800,4.8088126,0.6602154,,,,,,,,,,,,,, -181900,4.216315,0.580317,,,,,,,,,,,,,, -182000,5.129227,0.67562824,,,,,,,,,,,,,, -182100,4.3764796,0.6337172,,,,,,,,,,,,,, -182163,,,0.959004282951355,0.1480942070484161,0.7565400004386902,1.0530595779418943,50000.0,0.6303000450134277,1.828680157661438,10000.0,61240.36377048493,63408.5197532177,61240.36377048493,2156.6417529582977,5.613237142562866,0.0 -182200,5.2184763,0.6271939,,,,,,,,,,,,,, -182300,4.586573,0.64148414,,,,,,,,,,,,,, -182400,4.8625875,0.6117109,,,,,,,,,,,,,, -182500,4.4195113,0.61325306,,,,,,,,,,,,,, -182600,4.7857203,0.62716883,,,,,,,,,,,,,, -182700,4.2425585,0.6358813,,,,,,,,,,,,,, -182800,4.938696,0.6382276,,,,,,,,,,,,,, -182900,4.3933263,0.6072759,,,,,,,,,,,,,, -183000,5.0668044,0.6107155,,,,,,,,,,,,,, -183100,4.4194655,0.5701294,,,,,,,,,,,,,, -183200,4.1251645,0.5700763,,,,,,,,,,,,,, -183300,4.7802806,0.6480717,,,,,,,,,,,,,, -183400,4.814106,0.6562469,,,,,,,,,,,,,, -183500,4.3466353,0.60716975,,,,,,,,,,,,,, -183600,4.349889,0.59137136,,,,,,,,,,,,,, -183681,,,0.9598413109779358,0.150454580783844,0.7559999823570251,1.053408980369568,50000.0,0.6308000087738037,1.8285574913024905,10000.0,61750.54144191742,63936.44031405449,61750.54144191742,2174.271763563156,5.677387714385986,0.0 -183700,3.9961345,0.49574363,,,,,,,,,,,,,, -183800,5.045462,0.62838835,,,,,,,,,,,,,, -183900,4.53881,0.6290709,,,,,,,,,,,,,, -184000,4.800845,0.58329964,,,,,,,,,,,,,, -184100,4.2941804,0.72224057,,,,,,,,,,,,,, -184200,4.5690417,0.6125653,,,,,,,,,,,,,, -184300,4.59969,0.75475895,,,,,,,,,,,,,, -184400,4.3854995,0.6091924,,,,,,,,,,,,,, -184500,4.544567,0.5734937,,,,,,,,,,,,,, -184600,4.493949,0.63085747,,,,,,,,,,,,,, -184700,4.646603,0.6145623,,,,,,,,,,,,,, -184800,4.542182,0.6083689,,,,,,,,,,,,,, -184900,5.2109847,0.6210394,,,,,,,,,,,,,, -185000,4.904721,0.65819985,,,,,,,,,,,,,, -185100,4.366572,0.6115645,,,,,,,,,,,,,, -185198,,,0.9613759517669678,0.1429975479841232,0.7555399537086487,1.053418755531311,50000.0,0.6305000185966492,1.8302977085113523,10000.0,62260.458035469055,64464.38110399246,62260.458035469055,2192.182664632797,5.741932153701782,0.0 -185200,4.3346725,0.6466571,,,,,,,,,,,,,, -185300,4.34202,0.5722515,,,,,,,,,,,,,, -185400,4.476545,0.61087114,,,,,,,,,,,,,, -185500,4.9781137,0.64011616,,,,,,,,,,,,,, -185600,4.5493746,0.62493324,,,,,,,,,,,,,, -185700,4.4498057,0.6515099,,,,,,,,,,,,,, -185800,5.369878,0.6760565,,,,,,,,,,,,,, -185900,4.3667808,0.6569587,,,,,,,,,,,,,, -186000,3.9616296,0.6419591,,,,,,,,,,,,,, -186100,5.3378196,0.7218009,,,,,,,,,,,,,, -186200,4.3588195,0.59130615,,,,,,,,,,,,,, -186300,4.4360204,0.6412943,,,,,,,,,,,,,, -186400,4.6323814,0.63568074,,,,,,,,,,,,,, -186500,4.879674,0.65707505,,,,,,,,,,,,,, -186600,4.2419066,0.56811094,,,,,,,,,,,,,, -186666,,,0.9604591727256776,0.1477874964475631,0.7561999559402466,1.054269313812256,50000.0,0.6309000253677368,1.83073890209198,10000.0,62753.75569033623,64975.68322634697,62753.75569033623,2210.0775923728943,5.804227828979492,0.0 -186666,,,,,,,,,,,62753.75569033623,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 4d297020b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.31481671333313,0.0,56.03143048286438,1,0,56.03143048286438,0.0010000000474974,6.907756805419922,10000,96.34635639190674,0.0009179687476716,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -62.29332900047302,0.0311734676361084,476.1916465759277,884,0,476.1916465759277,0.0112000005319714,6.467188358306885,10000,538.5675106048584,0.0150976562872529,6.414700031280518,0.0146199995651841,6.424112796783447,50000 -84.03109979629517,0.0588784217834472,896.3407437801361,1816,0,896.3407437801361,0.0388000011444091,5.978135585784912,10000,980.535098552704,0.0458203107118606,5.830249309539795,0.044039998203516,5.860832214355469,50000 -105.91550326347352,0.0854272842407226,1316.5516102313995,2749,0,1316.5516102313995,0.0491000041365623,5.647575378417969,10000,1422.709864616394,0.0702343732118606,5.432772636413574,0.0639199987053871,5.477494716644287,50000 -129.37765622138977,0.1137945652008056,1736.8272049427032,3681,0,1736.8272049427032,0.0680000036954879,5.371665954589844,10000,1866.528002262116,0.0948632806539535,5.102994441986084,0.0889399945735931,5.145900249481201,50000 -151.5914740562439,0.1422510147094726,2157.039966583252,4611,0,2157.039966583252,0.0988000035285949,5.062837600708008,10000,2309.0352096557617,0.1406054645776748,4.725737571716309,0.1301199942827224,4.775650501251221,50000 -173.5701711177826,0.1694018840789795,2577.1580624580383,5539,0,2577.1580624580383,0.1281000077724456,4.742315292358398,10000,2751.210842370987,0.1874218732118606,4.307502269744873,0.1694400012493133,4.408010959625244,50000 -197.8914442062378,0.1972701549530029,2997.12113070488,6464,0,2997.12113070488,0.1588000059127807,4.514123916625977,10000,3195.574226856232,0.2234570235013961,4.075367450714111,0.2088599950075149,4.152421951293945,50000 -224.1046841144561,0.232349157333374,3417.268706560135,7389,0,3417.268706560135,0.1915000081062317,4.239701271057129,10000,3642.021924495697,0.2674999833106994,3.717461109161377,0.2469999939203262,3.828011751174927,50000 -254.00139904022217,0.2779793739318847,3837.3803448677054,8319,0,3837.3803448677054,0.2145000100135803,4.04166316986084,10000,4092.129539489746,0.3072656095027923,3.447201728820801,0.2823199927806854,3.5882530212402344,50000 -287.7465693950653,0.3111603260040283,4257.711834192276,9246,0,4257.711834192276,0.245600014925003,3.852262020111084,10000,4546.292115211487,0.3461132645606994,3.243919849395752,0.3188599944114685,3.357321262359619,50000 -316.598552942276,0.3597798347473144,4677.799368858337,10176,0,4677.799368858337,0.2625000178813934,3.703757524490357,10000,4995.334156990051,0.3649413883686065,3.0633111000061035,0.3436799943447113,3.1808507442474365,50000 -346.9868311882019,0.3919997215270996,5098.222493410111,11101,0,5098.222493410111,0.2785000205039978,3.62067985534668,10000,5446.229580163956,0.3976367115974426,2.9350926876068115,0.3646000027656555,3.088977336883545,50000 -372.8019599914551,0.4224827289581299,5518.611259937286,12025,0,5518.611259937286,0.2924000024795532,3.486982107162476,10000,5892.516352653503,0.4351562261581421,2.699234962463379,0.3865399956703186,2.927830457687378,50000 -400.3198812007904,0.4549741744995117,5938.5728850364685,12947,0,5938.5728850364685,0.3106000125408172,3.41860294342041,10000,6340.079912424088,0.429980456829071,2.715768575668335,0.3989799916744232,2.8612937927246094,50000 -434.1339132785797,0.499969482421875,6358.816624403,13871,0,6358.816624403,0.3206000030040741,3.336005210876465,10000,6794.234518766403,0.447558581829071,2.601588010787964,0.4180800020694732,2.759405851364136,50000 -470.73082447052,0.5254685878753662,6778.866580486298,14794,0,6778.866580486298,0.3368000090122223,3.2613182067871094,10000,7250.958543777466,0.4675585925579071,2.4731810092926025,0.4262399971485138,2.686066150665283,50000 -508.96109414100647,0.5512490272521973,7199.002569198608,15716,0,7199.002569198608,0.3315000236034393,3.258542060852051,10000,7709.4025728702545,0.4583398401737213,2.544507026672364,0.4334799945354461,2.6762075424194336,50000 -546.786703824997,0.5834970474243164,7619.164701223373,16636,0,7619.164701223373,0.3485000133514404,3.135037422180176,10000,8167.473934412003,0.4844140410423279,2.364227533340454,0.4491399824619293,2.542757511138916,50000 -582.9615287780762,0.6108376979827881,8039.487378358841,17552,0,8039.487378358841,0.3505000174045563,3.126204013824463,10000,8624.050165176392,0.488085925579071,2.3479232788085938,0.4491199851036072,2.5377163887023926,50000 -621.6249876022339,0.6407480239868164,8459.713441371918,18470,0,8459.713441371918,0.3570000231266022,3.123909950256348,10000,9083.020942211151,0.494140625,2.3677914142608643,0.4590199887752533,2.532339334487915,50000 -661.0747628211975,0.6683251857757568,8879.765430927277,19389,0,8879.765430927277,0.3579000234603882,3.091227769851685,10000,9542.602039337158,0.500292956829071,2.304877996444702,0.4635799825191498,2.468919515609741,50000 -698.6457741260529,1.5207412242889404,9298.976362466812,20307,0,9298.976362466812,0.3741000294685364,3.045219898223877,10000,10000.288239002228,0.5157226324081421,2.245275020599365,0.471780002117157,2.446197986602783,50000 -733.1075923442841,1.548471212387085,9719.232017755508,21226,0,9719.232017755508,0.3777000308036804,2.9841349124908447,10000,10455.08547091484,0.5420507788658142,2.0856730937957764,0.4818599820137024,2.369181871414185,50000 -766.7560062408447,1.5872962474822998,10139.521540164948,22142,0,10139.521540164948,0.380700021982193,2.985949754714966,10000,10909.113496303558,0.5222460627555847,2.2017922401428223,0.4871599972248077,2.363506555557251,50000 -799.8872475624084,1.6160616874694824,10559.62688589096,23064,0,10559.62688589096,0.3899000287055969,2.9014742374420166,10000,11362.43018102646,0.5399999618530273,2.108938455581665,0.499099999666214,2.298219919204712,50000 -833.3303182125092,1.6461646556854248,10979.964556217194,23984,0,10979.964556217194,0.3901000320911407,2.90531587600708,10000,11816.292253255844,0.5517382621765137,2.039448499679565,0.5036199688911438,2.2866790294647217,50000 -866.7247035503387,1.6784873008728027,11400.30337524414,24904,0,11400.30337524414,0.3970000147819519,2.8567094802856445,10000,12270.108603954315,0.541796863079071,2.0832176208496094,0.5054000020027161,2.2420337200164795,50000 -900.2320485115051,1.71818208694458,11820.599084615707,25811,0,11820.599084615707,0.4013000130653381,2.819454431533813,10000,12724.000076532364,0.5504101514816284,2.0340700149536133,0.5138599872589111,2.201979875564575,50000 -933.998238325119,1.7490291595458984,12240.525072574615,26725,0,12240.525072574615,0.4070000052452087,2.832810878753662,10000,13177.773655176165,0.566210925579071,2.00087571144104,0.5141199827194214,2.229660987854004,50000 -966.7126722335817,1.7793781757354736,12660.572672367096,27643,0,12660.572672367096,0.4043000340461731,2.829466819763184,10000,13630.616124868391,0.5484570264816284,2.040403366088867,0.519760012626648,2.1909496784210205,50000 -998.4982385635376,1.811950922012329,13080.651557922363,28560,0,13080.651557922363,0.4077000319957733,2.8053174018859863,10000,14082.56403541565,0.560546875,1.999878168106079,0.5216999650001526,2.1805710792541504,50000 -1030.136039018631,1.8442060947418213,13500.786858081818,29477,0,13500.786858081818,0.4166000187397003,2.758517026901245,10000,14534.419796228409,0.5726562142372131,1.928755760192871,0.5296599864959717,2.13201904296875,50000 -1063.1991493701937,1.8761727809906008,13920.732541561129,30393,0,13920.732541561129,0.4150000214576721,2.7555928230285645,10000,14987.511257886888,0.5676171779632568,1.9483190774917605,0.5301600098609924,2.114281415939331,50000 -1096.492238998413,1.905264139175415,14340.659535884855,31309,0,14340.659535884855,0.4228000342845917,2.728455305099488,10000,15440.811559438704,0.5728124976158142,1.927043080329895,0.537559986114502,2.104546546936035,50000 -1128.7679710388184,1.9361648559570312,14760.61864376068,32227,0,14760.61864376068,0.4244000315666199,2.728424072265625,10000,15893.128350257874,0.5775781273841858,1.9075417518615725,0.5331599712371826,2.1145474910736084,50000 -1161.712742805481,1.9648699760437007,15180.854577302933,33142,0,15180.854577302933,0.4270000159740448,2.7143259048461914,10000,16346.38900589943,0.6056835651397705,1.7929567098617554,0.5388399958610535,2.0905473232269287,50000 -1193.7472488880155,1.9952008724212649,15601.136335372925,34059,0,15601.136335372925,0.4227000176906585,2.717620372772217,10000,16798.78573513031,0.577441394329071,1.9018783569335933,0.5426999926567078,2.0791501998901367,50000 -1225.170911550522,2.0287718772888184,16021.365535497664,34977,0,16021.365535497664,0.429500013589859,2.695765256881714,10000,17250.523627996445,0.5922070145606995,1.842816710472107,0.5476399660110474,2.05157208442688,50000 -1257.3156542778015,2.0581812858581543,16441.70537519455,35896,0,16441.70537519455,0.4358000159263611,2.656701803207397,10000,17703.087596178055,0.6042382717132568,1.7828890085220337,0.5482999682426453,2.0370826721191406,50000 -1288.7657787799835,2.0886483192443848,16862.025440216064,36813,0,16862.025440216064,0.4314000308513641,2.669766902923584,10000,18154.93812942505,0.5854101181030273,1.868322730064392,0.5462200045585632,2.047626256942749,50000 -1320.7677319049835,2.1223950386047363,17282.239025354385,37732,0,17282.239025354385,0.4401000142097473,2.634584426879883,10000,18607.237397432327,0.5966015458106995,1.8163317441940308,0.5523599982261658,2.01404070854187,50000 -1353.4470887184143,2.154141902923584,17702.3341217041,38651,0,17702.3341217041,0.4394000172615051,2.6386351585388184,10000,19060.093980789185,0.60693359375,1.7717552185058594,0.5542600154876709,2.0164029598236084,50000 -1387.2440507411957,2.182854175567627,18122.364139556885,39571,0,18122.364139556885,0.4341000318527221,2.667387723922729,10000,19513.99955034256,0.590136706829071,1.8692361116409304,0.5558800101280212,2.0332841873168945,50000 -1418.446210384369,2.216303586959839,18542.35852384568,40490,0,18542.35852384568,0.4387000203132629,2.6926157474517822,10000,19965.28075647354,0.5978320240974426,1.8688842058181765,0.5544999837875366,2.060164213180542,50000 -1449.7089776992798,2.25161361694336,18962.650825738907,41407,0,18962.650825738907,0.4399000108242035,2.66598916053772,10000,20416.922464370728,0.6039062142372131,1.8160500526428225,0.5564199686050415,2.0289831161499023,50000 -1482.2264337539673,2.291311264038086,19382.72688031197,42326,0,19382.72688031197,0.4429000318050384,2.6232457160949707,10000,20869.60631251335,0.5984765291213989,1.8087574243545528,0.5619800090789795,1.9925109148025515,50000 -1514.9667398929596,2.3243231773376465,19802.75421524048,43244,0,19802.75421524048,0.4507000148296356,2.5570688247680664,10000,21322.457344055176,0.6040819883346558,1.7556627988815308,0.5645599961280823,1.9334279298782349,50000 -1547.9607965946198,2.358118295669556,20223.016721487045,44162,0,20223.016721487045,0.4501000344753265,2.5805513858795166,10000,21775.79777216912,0.6100195050239563,1.7446229457855225,0.5627399682998657,1.9554500579833984,50000 -1581.5847356319427,2.3912997245788574,20643.23784804344,45079,0,20643.23784804344,0.4431000351905823,2.618972063064575,10000,22229.7259953022,0.6299804449081421,1.65792977809906,0.564079999923706,1.9666937589645384,50000 -1611.773603439331,2.4242405891418457,21063.51750826836,45998,0,21063.51750826836,0.4560000300407409,2.58458948135376,10000,22680.277902126312,0.6064062118530273,1.7620118856430054,0.5678600072860718,1.9412288665771484,50000 -1642.3246562480929,2.4671669006347656,21483.74810171128,46912,0,21483.74810171128,0.4544000327587127,2.5762476921081543,10000,23131.15274167061,0.6177538633346558,1.7487009763717651,0.5697599649429321,1.956942081451416,50000 -1673.292890548706,2.5029873847961426,21903.900598526,47827,0,21903.900598526,0.4462000131607055,2.595702171325684,10000,23582.35942387581,0.6209570169448853,1.705472469329834,0.5658599734306335,1.9722998142242432,50000 -1705.5355398654938,2.5407254695892334,22323.86534023285,48743,0,22323.86534023285,0.4463000297546386,2.602745771408081,10000,24034.65631222725,0.604785144329071,1.7722070217132568,0.568120002746582,1.9483237266540527,50000 -1738.979534626007,2.573563814163208,22743.817096471783,49662,0,22743.817096471783,0.4552000164985657,2.5426297187805176,10000,24488.13495373726,0.6148828268051147,1.721850037574768,0.5727399587631226,1.9160726070404053,50000 -1770.112517118454,2.60744047164917,23163.91581749916,50580,0,23163.91581749916,0.4517000317573547,2.5810353755950928,10000,24939.45225691796,0.61865234375,1.7359861135482788,0.5676199793815613,1.9560025930404663,50000 -1802.488485097885,2.6418302059173584,23584.244931459427,51498,0,23584.244931459427,0.4583000242710113,2.5383946895599365,10000,25392.24175477028,0.6126366853713989,1.7147948741912842,0.5749599933624268,1.904335618019104,50000 -1833.9402480125427,2.6762707233428955,24004.364980459213,52414,0,24004.364980459213,0.4573000073432922,2.529676914215088,10000,25843.898319244385,0.6123437285423279,1.7380605936050415,0.5724999904632568,1.917758822441101,50000 -1865.5457472801208,2.711615800857544,24424.607334136963,53330,0,24424.607334136963,0.45210000872612,2.568796396255493,10000,26295.831995487213,0.6242578029632568,1.6958975791931152,0.5759199857711792,1.9176926612854004,50000 -1897.9944295883176,2.7577033042907715,24844.78872013092,54244,0,24844.78872013092,0.4621000289916992,2.5444371700286865,10000,26748.55784964561,0.6257421970367432,1.7023617029190063,0.5797600150108337,1.9058101177215576,50000 -1929.7939734458923,2.7902755737304688,25264.96487569809,55162,0,25264.96487569809,0.4599000215530395,2.5232367515563965,10000,27200.61718583107,0.6188085675239563,1.706288456916809,0.5832399725914001,1.87911856174469,50000 -1961.75643324852,2.822684526443481,25685.004106283188,56078,0,25685.004106283188,0.4586000144481659,2.569799423217773,10000,27652.70127034188,0.6207226514816284,1.7265022993087769,0.5751399993896484,1.932764768600464,50000 -1994.321251630783,2.8584837913513184,26105.01860499382,56994,0,26105.01860499382,0.4688000082969665,2.537362575531006,10000,28105.36644411087,0.6523046493530273,1.600400686264038,0.580299973487854,1.9175043106079104,50000 -2026.0364346504207,2.890997171401977,26525.360887289047,57911,0,26525.360887289047,0.4709000289440155,2.4993984699249268,10000,28557.5066075325,0.6321093440055847,1.6738526821136477,0.5862999558448792,1.871827960014344,50000 -2058.912977695465,2.924208402633667,26945.452782392505,58829,0,26945.452782392505,0.4593000113964081,2.511036157608032,10000,29010.55766916275,0.6249608993530273,1.6657127141952517,0.5823599696159363,1.864213824272156,50000 -2092.420877933502,2.961343050003052,27365.50388979912,59747,0,27365.50388979912,0.4684000313282013,2.4895777702331543,10000,29464.2038834095,0.643359363079071,1.587910771369934,0.5870400071144104,1.8554558753967283,50000 -2124.993516921997,2.995525598526001,27785.721799850464,60664,0,27785.721799850464,0.4672000110149383,2.470869541168213,10000,29917.078412532806,0.6283788681030273,1.662112593650818,0.585319995880127,1.851845622062683,50000 -2158.5834772586823,3.031461477279663,28205.97324538231,61582,0,28205.97324538231,0.4756000339984894,2.4673588275909424,10000,30371.00706982613,0.6338085532188416,1.6124237775802612,0.591219961643219,1.8236290216445925,50000 -2189.576506137848,3.069997310638428,28626.62084031105,62501,0,28626.62084031105,0.4697000086307525,2.5128331184387207,10000,30822.73681116104,0.6354882717132568,1.626826286315918,0.5827599763870239,1.8661490678787231,50000 -2222.766315698624,3.1039247512817383,29046.67929458618,63419,0,29046.67929458618,0.4674000144004822,2.485158920288086,10000,31276.06823301316,0.6289843320846558,1.6643264293670654,0.583899974822998,1.8471851348876955,50000 -2255.2734336853027,3.1402459144592285,29466.69498157501,64335,0,29466.69498157501,0.4731000363826751,2.420977830886841,10000,31728.67764186859,0.6406640410423279,1.5750188827514648,0.5931800007820129,1.7844353914260864,50000 -2288.2237129211426,3.1775012016296387,29886.907014369965,65251,0,29886.907014369965,0.4781000316143036,2.480360984802246,10000,32181.92774629593,0.6419140696525574,1.613297700881958,0.5929799675941467,1.845779657363892,50000 -2320.53994345665,3.215022325515747,30307.044764518738,66167,0,30307.044764518738,0.4720000326633453,2.46311092376709,10000,32634.469495773315,0.6465820074081421,1.5871703624725342,0.596560001373291,1.818373203277588,50000 -2352.242401361465,3.251006841659546,30727.09972333908,67084,0,30727.09972333908,0.4830000102519989,2.422799825668335,10000,33086.314239263535,0.6468554735183716,1.5900592803955078,0.6029599905014038,1.7880793809890747,50000 -2384.98783993721,3.290488004684448,31147.37824487686,68001,0,31147.37824487686,0.4832000136375427,2.4284214973449707,10000,33539.42764592171,0.6465234160423279,1.5876259803771973,0.5993599891662598,1.807966232299805,50000 -2416.788944482804,3.324157476425171,31567.33975481987,68916,0,31567.33975481987,0.4843000173568725,2.398958206176758,10000,33991.27369570732,0.6715624928474426,1.442800521850586,0.6010000109672546,1.7613905668258667,50000 -2448.853832483292,3.36092495918274,31987.69392681122,69831,0,31987.69392681122,0.4826000332832336,2.457373142242432,10000,34443.77940249443,0.6374413967132568,1.655277967453003,0.5964199900627136,1.832472443580628,50000 -2478.640973091125,3.3965108394622803,32407.676120996475,70747,0,32407.676120996475,0.4763000309467315,2.390446662902832,10000,34893.63464021683,0.6484960913658142,1.5486077070236206,0.6013599634170532,1.7578142881393433,50000 -2510.5186746120453,3.442301034927368,32827.82526350021,71662,0,32827.82526350021,0.4798000156879425,2.4019968509674072,10000,35345.7574763298,0.6622851490974426,1.493794083595276,0.6013199687004089,1.7728713750839231,50000 -2542.784086465836,3.47953462600708,33247.85134124756,72578,0,33247.85134124756,0.4820000231266022,2.406587600708008,10000,35798.13573217392,0.6486914157867432,1.5760599374771118,0.6078199744224548,1.7598813772201538,50000 -2574.4699428081512,3.524752616882324,33668.151161670685,73498,0,33668.151161670685,0.4812000095844269,2.388460159301758,10000,36250.21720933914,0.650683581829071,1.5488523244857788,0.602840006351471,1.7498250007629397,50000 -2606.0012538433075,3.561843156814575,34088.156512498856,74415,0,34088.156512498856,0.4891000092029571,2.39551329612732,10000,36701.84150767326,0.6654492020606995,1.506258487701416,0.6062999963760376,1.7637112140655518,50000 -2640.042801618576,3.6005239486694336,34508.42599821091,75333,0,34508.42599821091,0.4856000244617462,2.3694193363189697,10000,37156.241079092026,0.6486523151397705,1.547737956047058,0.612500011920929,1.729038119316101,50000 -2673.288207530976,3.637963056564331,34928.72370290756,76251,0,34928.72370290756,0.4861000180244446,2.3885414600372314,10000,37609.87198114395,0.6528710722923279,1.5526872873306274,0.6100199818611145,1.7567421197891235,50000 -2708.1096754074097,3.672785997390747,35348.81794667244,77167,0,35348.81794667244,0.4903000295162201,2.388901948928833,10000,38064.8754966259,0.6625585556030273,1.513722538948059,0.6113399863243103,1.7449710369110107,50000 -2740.87061214447,3.71204137802124,35768.99364566803,78083,0,35768.99364566803,0.4865000247955322,2.363787174224853,10000,38517.90196752548,0.6606640219688416,1.493207335472107,0.6092599630355835,1.7219427824020386,50000 -2774.249065160752,3.752545356750488,36189.29687142372,78998,0,36189.29687142372,0.4889000356197357,2.359149217605591,10000,38971.67401695252,0.659375011920929,1.5171035528182983,0.6125400066375732,1.7171244621276855,50000 -2807.375717639923,3.791879415512085,36609.45451402664,79915,0,36609.45451402664,0.4943000376224518,2.326002836227417,10000,39425.04754161835,0.66845703125,1.4566603899002075,0.6201399564743042,1.6755112409591677,50000 -2837.8364627361298,3.83516001701355,37029.78786754608,80832,0,37029.78786754608,0.4910000264644623,2.404242038726806,10000,39875.934716939926,0.6868945360183716,1.438260197639465,0.6148200035095215,1.7528892755508425,50000 -2871.551378250122,3.878218173980713,37449.77145719528,81747,0,37449.77145719528,0.4950000345706939,2.3316457271575928,10000,40329.72599077225,0.6599804759025574,1.493338108062744,0.6163600087165833,1.6874927282333374,50000 -2903.697116851806,3.915627241134644,37869.73485660553,82663,0,37869.73485660553,0.4934000372886657,2.3003392219543457,10000,40781.92210578919,0.6682031154632568,1.454301834106445,0.6208800077438354,1.668615698814392,50000 -2935.783272266388,3.954903364181519,38289.80323362351,83581,0,38289.80323362351,0.4946000277996063,2.320722818374634,10000,41234.16646409035,0.6833788752555847,1.3950210809707642,0.6195200085639954,1.6812617778778076,50000 -2969.496173620224,3.991729259490967,38709.94033956528,84496,0,38709.94033956528,0.4934000372886657,2.347602367401123,10000,41688.10350394249,0.6632617115974426,1.50346839427948,0.6191200017929077,1.6938670873641968,50000 -3000.645429611206,4.033493041992188,39129.99418258667,85412,0,39129.99418258667,0.4971000254154205,2.3224990367889404,10000,42139.39819288254,0.6689453125,1.475509762763977,0.6204000115394592,1.6864943504333496,50000 -3034.7353508472443,4.08236575126648,39550.02034115791,86326,0,39550.02034115791,0.5010000467300415,2.304348945617676,10000,42593.61324310303,0.679492175579071,1.427093505859375,0.627299964427948,1.6685994863510132,50000 -3067.972613334656,4.12157940864563,39969.94578671456,87244,0,39969.94578671456,0.5063000321388245,2.290591239929199,10000,43046.86581158638,0.6720312237739563,1.4575451612472534,0.6271199584007263,1.6576414108276367,50000 -3101.585325717926,4.163185358047485,40389.9857711792,88162,0,40389.9857711792,0.4997000098228454,2.302229642868042,10000,43500.610604286194,0.6704687476158142,1.4411784410476685,0.622439980506897,1.6661144495010376,50000 -3131.58740067482,4.202484369277954,40810.23924612999,89080,0,40810.23924612999,0.5085000395774841,2.2891640663146973,10000,43950.955739974976,0.6823828220367432,1.4180799722671509,0.6261999607086182,1.6651922464370728,50000 -3164.264275074005,4.245952367782593,41230.43558573723,89995,0,41230.43558573723,0.5055000185966492,2.2911794185638428,10000,44403.9231262207,0.6862109303474426,1.411870360374451,0.6287199854850769,1.6493552923202517,50000 -3195.636889457702,4.286664247512817,41650.82543206215,90912,0,41650.82543206215,0.5124000310897827,2.259521722793579,10000,44855.7765583992,0.6795117259025574,1.41821026802063,0.6332799792289734,1.625242829322815,50000 -3228.6166894435883,4.32750678062439,42070.84104323387,91828,0,42070.84104323387,0.5081000328063965,2.3019185066223145,10000,45308.86335706711,0.6826757788658142,1.4430053234100342,0.6317399740219116,1.6762853860855105,50000 -3259.2013654708862,4.36621356010437,42490.79935574532,92745,0,42490.79935574532,0.5100000500679016,2.3166470527648926,10000,45759.4947450161,0.7028710842132568,1.3684577941894531,0.6319199800491333,1.6790177822113037,50000 -3292.632405996322,4.411803960800171,42910.90014410019,93656,0,42910.90014410019,0.5134000182151794,2.2607595920562744,10000,46213.121638059616,0.6768945455551147,1.43509840965271,0.6339799761772156,1.6315512657165527,50000 -3325.326532363892,4.453327894210815,43331.148206710815,94571,0,43331.148206710815,0.5184000134468079,2.2514212131500244,10000,46666.15451836586,0.6860546469688416,1.4129973649978638,0.6372999548912048,1.6310029029846191,50000 -3356.9392170906067,4.495721101760864,43751.27952218056,95488,0,43751.27952218056,0.5152000188827515,2.261467218399048,10000,47117.99061465264,0.6981640458106995,1.355225682258606,0.6358199715614319,1.629407286643982,50000 -3391.115294933319,4.54344367980957,44171.42354559898,96404,0,44171.42354559898,0.5163000226020813,2.2728638648986816,10000,47572.40820598602,0.6825000047683716,1.4293886423110962,0.6386799812316895,1.6273837089538574,50000 -3423.7243151664734,4.584112644195557,44591.63816213608,97322,0,44591.63816213608,0.515500009059906,2.2348742485046387,10000,48025.32202601433,0.6897070407867432,1.3623371124267578,0.6406199932098389,1.586683988571167,50000 -3457.3588194847107,4.624876022338867,45011.66071987152,98238,0,45011.66071987152,0.5222000479698181,2.2077906131744385,10000,48479.07020926476,0.7014843821525574,1.3189449310302734,0.646399974822998,1.5679948329925537,50000 -3490.893681287765,4.664891719818115,45431.88610672951,99156,0,45431.88610672951,0.5197000503540039,2.2255241870880127,10000,48932.92041492462,0.69189453125,1.3856916427612305,0.6421399712562561,1.6010757684707642,50000 -3522.6747205257416,4.70860743522644,45852.21845984459,100073,0,45852.21845984459,0.527999997138977,2.16940712928772,10000,49385.12797021866,0.6983593702316284,1.3234851360321045,0.646619975566864,1.549912452697754,50000 -3552.8196897506714,4.7579333782196045,46272.5213842392,100985,0,46272.5213842392,0.5232000350952148,2.1762855052948,10000,49835.674525260925,0.6984570026397705,1.3088330030441284,0.6455999612808228,1.5476953983306885,50000 -3584.373185634613,4.80739164352417,46692.711477041245,101900,0,46692.711477041245,0.5252000093460083,2.188363552093506,10000,50287.51734471321,0.7024999856948853,1.3207366466522217,0.6457799673080444,1.5591189861297607,50000 -3616.940553426743,4.861057758331299,47113.01589202881,102818,0,47113.01589202881,0.5247000455856323,2.24755334854126,10000,50740.492428302765,0.6942577958106995,1.4096049070358276,0.644599974155426,1.627078652381897,50000 -3650.198092460632,4.902195453643799,47533.03374886513,103735,0,47533.03374886513,0.5296000242233276,2.18750262260437,10000,51193.85984563828,0.7060937285423279,1.303206443786621,0.6481800079345703,1.5560694932937622,50000 -3683.071009159088,4.946327209472656,47953.09908533096,104652,0,47953.09908533096,0.5229000449180603,2.1618435382843018,10000,51646.89317679405,0.728710949420929,1.209591507911682,0.655519962310791,1.5235939025878906,50000 -3715.691724061966,4.996261835098267,48373.37495970726,105569,0,48373.37495970726,0.535800039768219,2.1648764610290527,10000,52099.889456510544,0.7003515362739563,1.328979730606079,0.6550799608230591,1.5396244525909424,50000 -3749.424998044968,5.040008306503296,48793.45449113846,106485,0,48793.45449113846,0.5314000248908997,2.192826986312866,10000,52553.79598546028,0.706250011920929,1.3207530975341797,0.6521199941635132,1.5590405464172363,50000 -3782.543523073197,5.082547903060913,49214.00630617142,107401,0,49214.00630617142,0.5392000079154968,2.1324949264526367,10000,53007.55867099762,0.7234960794448853,1.20840322971344,0.6581199765205383,1.5009328126907349,50000 -3814.051218271256,5.121634244918823,49634.25281214714,108316,0,49634.25281214714,0.5303000211715698,2.155637502670288,10000,53459.40159130096,0.7032812237739563,1.2869420051574707,0.6589800119400024,1.5009398460388184,50000 -3845.278760910034,5.172155141830444,50054.5297896862,109232,0,50054.5297896862,0.5329000353813171,2.154524326324463,10000,53911.006695985794,0.7122656106948853,1.288287878036499,0.657759964466095,1.5199744701385498,50000 -3878.277625799179,5.222049713134766,50474.51218295097,110148,0,50474.51218295097,0.5366000533103943,2.149949550628662,10000,54364.08873963356,0.7259570360183716,1.2455788850784302,0.6635199785232544,1.5217573642730713,50000 -3911.306578159332,5.262262582778931,50894.50727891922,111065,0,50894.50727891922,0.5421000123023987,2.113887310028076,10000,54817.20326185226,0.71337890625,1.262039303779602,0.6626600027084351,1.4916067123413086,50000 -3943.933856487274,5.305820465087891,51314.64157676697,111980,0,51314.64157676697,0.5409000515937805,2.1428349018096924,10000,55270.05773234368,0.7187108993530273,1.2647498846054075,0.6654999852180481,1.5069676637649536,50000 -3976.276055574417,5.350057363510132,51734.58570885658,112896,0,51734.58570885658,0.5398000478744507,2.1151599884033203,10000,55722.437901735306,0.7238085865974426,1.2229537963867188,0.6610999703407288,1.4882365465164185,50000 -4006.590903520584,5.393388271331787,52154.66342067719,113811,0,52154.66342067719,0.5451000332832336,2.094324350357056,10000,56172.92315030098,0.7215625047683716,1.2210685014724731,0.6649199724197388,1.4647303819656372,50000 -4038.273143053055,5.442919015884399,52575.01256537437,114726,0,52575.01256537437,0.5450000166893005,2.083441257476806,10000,56625.05432701111,0.7206835746765137,1.2099860906600952,0.67249995470047,1.443393349647522,50000 -4072.856017351152,5.488822221755981,52995.04573082924,115640,0,52995.04573082924,0.5496000051498413,2.064040422439575,10000,57079.766466617584,0.7305468320846558,1.191508173942566,0.67249995470047,1.4450451135635376,50000 -4106.33264541626,5.534053087234497,53415.35868215561,116556,0,53415.35868215561,0.5498000383377075,2.0630745887756348,10000,57533.651420116425,0.7452148199081421,1.1201274394989014,0.6720199584960938,1.4388738870620728,50000 -4139.674651861191,5.577041864395142,53835.411581754684,117472,0,53835.411581754684,0.5496000051498413,2.128594160079956,10000,57987.13926744461,0.7250390648841858,1.2660032510757446,0.667199969291687,1.5061935186386108,50000 -4172.6747174263,5.619581699371338,54255.57951760292,118386,0,54255.57951760292,0.5502000451087952,2.0854063034057617,10000,58440.40096735954,0.73046875,1.1939589977264404,0.6730200052261353,1.4451491832733154,50000 -4206.0634133815765,5.661855459213257,54675.84220218658,119299,0,54675.84220218658,0.5516000390052795,2.039374589920044,10000,58894.145411252975,0.7431640625,1.1251846551895142,0.6783999800682068,1.4041777849197388,50000 -4239.874573707581,5.706495046615601,55095.76208996773,120216,0,55095.76208996773,0.5534999966621399,2.0480198860168457,10000,59347.97211909294,0.7290429472923279,1.183713674545288,0.6782199740409851,1.414423584938049,50000 -4272.361785888672,5.750799179077148,55516.00203704834,121132,0,55516.00203704834,0.5556000471115112,2.0637381076812744,10000,59800.793695926666,0.7370507717132568,1.166053056716919,0.6767799854278564,1.426482319831848,50000 -4305.591577529907,5.79656982421875,55936.13686299324,122049,0,55936.13686299324,0.5579000115394592,2.025280237197876,10000,60254.25495505333,0.7417968511581421,1.1477159261703491,0.6811599731445312,1.4139902591705322,50000 -4337.223828554153,5.842200517654419,56356.257717609406,122963,0,56356.257717609406,0.5614000558853149,2.017380475997925,10000,60706.10312247276,0.7372655868530273,1.1513622999191284,0.688979983329773,1.3695464134216309,50000 -4370.009130716324,5.895697832107544,56776.523052454,123878,0,56776.523052454,0.5649000406265259,2.015331745147705,10000,61159.25761413574,0.7439843416213989,1.1398800611495972,0.6844399571418762,1.3942458629608154,50000 -4403.7414972782135,5.944585561752319,57196.867312431335,124794,0,57196.867312431335,0.562000036239624,2.016331672668457,10000,61613.43428826332,0.74818354845047,1.1148760318756104,0.6860199570655823,1.3866878747940063,50000 -4434.629065275192,5.990480661392212,57616.94558739662,125709,0,57616.94558739662,0.5626000165939331,2.021228790283203,10000,62064.4964826107,0.7470703125,1.1385987997055054,0.6839199662208557,1.4028888940811155,50000 -4466.095130681992,6.0411882400512695,58037.063480854034,126620,0,58037.063480854034,0.5684000253677368,1.9932000637054443,10000,62516.18090867996,0.7459765672683716,1.1232553720474243,0.6888200044631958,1.3786057233810425,50000 -4499.679343938828,6.095107316970825,58457.12394356728,127535,0,58457.12394356728,0.5659000277519226,1.9629532098770144,10000,62969.93001675606,0.7511913776397705,1.0778286457061768,0.6904999613761902,1.3422291278839111,50000 -4531.347820997238,6.140000343322754,58877.354343652725,128452,0,58877.354343652725,0.5690000057220459,1.971156001091004,10000,63421.92477989197,0.7696874737739563,1.024175047874451,0.6924600005149841,1.3428523540496826,50000 -4564.715627908707,6.195974826812744,59297.305203437805,129367,0,59297.305203437805,0.5629000067710876,1.973888039588928,10000,63875.34980750084,0.7485156059265137,1.1012896299362185,0.6915199756622314,1.3491532802581787,50000 -4597.110694646835,6.243996620178223,59717.62939977646,130283,0,59717.62939977646,0.5763000249862671,1.9376977682113647,10000,64328.16715955734,0.7594531178474426,1.0408467054367063,0.6985399723052979,1.3164602518081665,50000 -4630.069223642349,6.288846492767334,60138.00821852684,131200,0,60138.00821852684,0.5750000476837158,1.946197748184204,10000,64781.60162329674,0.7681640386581421,1.0163657665252686,0.6958799958229065,1.330001711845398,50000 -4662.193197011948,6.333850622177124,60558.07671093941,132116,0,60558.07671093941,0.5717000365257263,1.9606136083602903,10000,65233.88895082474,0.7521288990974426,1.0870712995529177,0.6966399550437927,1.3381699323654177,50000 -4694.585594415665,6.377013444900513,60978.0214304924,133031,0,60978.0214304924,0.5728000402450562,1.9535140991210933,10000,65686.31872987747,0.7635741829872131,1.053499698638916,0.6963199973106384,1.3278359174728394,50000 -4727.902453184128,6.423879384994507,61398.24069237709,133946,0,61398.24069237709,0.5840000510215759,1.9154057502746584,10000,66139.95198273659,0.7712304592132568,0.9981160759925842,0.7023999691009521,1.3043373823165894,50000 -4761.82052397728,6.472502708435059,61818.17729949951,134862,0,61818.17729949951,0.5812000036239624,1.9075292348861688,10000,66593.90523028374,0.7629492282867432,1.0372766256332395,0.702739953994751,1.304007887840271,50000 -4794.236758947372,6.520557165145874,62238.25712633133,135777,0,62238.25712633133,0.579800009727478,1.9201245307922363,10000,67046.49923276901,0.7660546898841858,1.0418373346328735,0.7041400074958801,1.3056179285049438,50000 -4826.84645652771,6.569271087646484,62658.68343901634,136693,0,62658.68343901634,0.5826000571250916,1.9220960140228271,10000,67499.63382172585,0.7722460627555847,1.0183627605438232,0.7065199613571167,1.3071863651275637,50000 -4860.415402173996,6.61378026008606,63078.72762656212,137608,0,63078.72762656212,0.5869000554084778,1.897950291633606,10000,67953.34159827232,0.7710937261581421,1.0141106843948364,0.7095400094985962,1.2850812673568726,50000 -4893.19992518425,6.666823387145996,63498.872061014175,138523,0,63498.872061014175,0.5896000266075134,1.89242959022522,10000,68406.37291073799,0.7734375,1.0056006908416748,0.7101199626922607,1.2799828052520752,50000 -4926.0291867256165,6.712033033370972,63919.03820538521,139438,0,63919.03820538521,0.5873000025749207,1.8855514526367188,10000,68859.4631357193,0.7773827910423279,0.9855165481567384,0.7126399874687195,1.2628765106201172,50000 -4957.572335958481,6.756109237670898,64339.187994003296,140354,0,64339.187994003296,0.5871000289916992,1.891028642654419,10000,69311.2506814003,0.78968745470047,0.938185453414917,0.7109599709510803,1.2699692249298096,50000 -4990.778796195984,6.810421466827393,64759.18298387528,141268,0,64759.18298387528,0.5908000469207764,1.856269598007202,10000,69764.55665397644,0.775585949420929,0.9803733825683594,0.7143399715423584,1.2429838180541992,50000 -5022.137254714966,6.856017589569092,65179.46191477776,142185,0,65179.46191477776,0.596500039100647,1.858317494392395,10000,70216.29029083252,0.7842773199081421,0.9769614338874816,0.7149400115013123,1.2617579698562622,50000 -5053.564244508743,6.910325288772583,65600.1053712368,143101,0,65600.1053712368,0.5972000360488892,1.844268798828125,10000,70668.46542716026,0.7922265529632568,0.911641538143158,0.7174999713897705,1.237351417541504,50000 -5086.700350761414,6.957584381103516,66020.01851248741,144016,0,66020.01851248741,0.6003000140190125,1.8443188667297363,10000,71121.61238765717,0.7851757407188416,0.9561517238616944,0.7189399600028992,1.2434120178222656,50000 -5120.010430335999,7.005602598190308,66440.11183166504,144931,0,66440.11183166504,0.5937000513076782,1.8433793783187864,10000,71575.11425447464,0.7807421684265137,0.9476374387741088,0.7181999683380127,1.2305222749710083,50000 -5151.019411563873,7.051945686340332,66860.22039985657,145846,0,66860.22039985657,0.6014000177383423,1.8521920442581177,10000,72026.32795882225,0.7958202958106995,0.9202547073364258,0.7206000089645386,1.2377933263778689,50000 -5184.43443775177,7.107837200164795,67280.31985378265,146759,0,67280.31985378265,0.6012000441551208,1.833222389221192,10000,72479.94853115082,0.7870507836341858,0.943701982498169,0.7204200029373169,1.2290470600128174,50000 -5218.43242764473,7.157444715499878,67700.53570318222,147673,0,67700.53570318222,0.5987000465393066,1.8122233152389529,10000,72934.26190567017,0.79359370470047,0.9170867204666138,0.7231999635696411,1.2131633758544922,50000 -5252.327347993851,7.205964088439941,68120.48303842545,148592,0,68120.48303842545,0.6028000116348267,1.801270842552185,10000,73388.2038257122,0.7974804639816284,0.8924198150634766,0.7242599725723267,1.2011919021606443,50000 -5285.950101613998,7.251769781112671,68540.57233738899,149508,0,68540.57233738899,0.6076000332832336,1.780668020248413,10000,73842.01174998283,0.7968554496765137,0.8914695382118225,0.7272399663925171,1.1844379901885986,50000 -5317.539261579514,7.299367904663086,68960.80321502686,150424,0,68960.80321502686,0.6101000308990479,1.7993606328964231,10000,74293.92950153351,0.7983788847923279,0.8947476744651794,0.7258999943733215,1.197722315788269,50000 -5349.239114522934,7.360641241073608,69380.76213383675,151339,0,69380.76213383675,0.6135000586509705,1.7731417417526243,10000,74745.69967579842,0.8019726276397705,0.8627512454986572,0.7283200025558472,1.176479458808899,50000 -5382.729545354843,7.409675359725952,69800.7703230381,152253,0,69800.7703230381,0.6112000346183777,1.764483094215393,10000,75199.2972612381,0.8101562261581421,0.8360294103622437,0.7314800024032593,1.171595811843872,50000 -5414.739239692688,7.460782766342163,70220.83098077774,153168,0,70220.83098077774,0.6140000224113464,1.7503547668457031,10000,75651.4686627388,0.8065429329872131,0.850128710269928,0.7335399985313416,1.1555325984954834,50000 -5448.728743553162,7.51036524772644,70640.93357086182,154083,0,70640.93357086182,0.6159000396728516,1.768608331680298,10000,76105.66011571884,0.8090429306030273,0.8482322692871094,0.7348799705505371,1.165394306182861,50000 -5481.284249305725,7.567243576049805,71061.22500658035,154996,0,71061.22500658035,0.614300012588501,1.7606123685836792,10000,76558.61492061615,0.8161718845367432,0.8130109310150146,0.7342199683189392,1.1588497161865234,50000 -5513.983215093613,7.616268157958984,71481.48496103287,155912,0,71481.48496103287,0.6164000034332275,1.768826603889465,10000,77011.67231607437,0.8117773532867432,0.8497369289398193,0.7390799522399902,1.1490964889526367,50000 -5546.199652433395,7.663997650146484,71901.40081095695,156825,0,71901.40081095695,0.617900013923645,1.7345569133758545,10000,77463.90251994133,0.8138867020606995,0.8235838413238525,0.7387599945068359,1.1445491313934326,50000 -5579.793093681335,7.714296340942383,72321.37640357018,157740,0,72321.37640357018,0.6212000250816345,1.7404789924621582,10000,77917.57216620445,0.8207616806030273,0.800974428653717,0.7397199869155884,1.1447432041168213,50000 -5613.552079439163,7.765320539474487,72741.29397368431,158654,0,72741.29397368431,0.625700056552887,1.7191606760025024,10000,78371.34893417358,0.81507807970047,0.812286913394928,0.7413399815559387,1.1248809099197388,50000 -5645.6390788555145,7.827660083770752,73161.59970808029,159569,0,73161.59970808029,0.6260000467300415,1.720393419265747,10000,78823.85418653488,0.8191210627555847,0.8072444200515747,0.7432799935340881,1.1267979145050049,50000 -5680.3631727695465,7.879194021224976,73581.53504562378,160442,0,73581.53504562378,0.6271000504493713,1.7195051908493042,10000,79278.61271739006,0.8229101300239563,0.7834881544113159,0.744659960269928,1.1210800409317017,50000 -5711.764150619507,7.929988384246826,74001.81647443771,161358,0,74001.81647443771,0.6326000094413757,1.683404803276062,10000,79730.39583301544,0.82289057970047,0.7732518315315247,0.7462199926376343,1.0985074043273926,50000 -5743.91107749939,7.989051818847656,74421.98902630806,162271,0,74421.98902630806,0.632900059223175,1.7059128284454346,10000,80182.82422280312,0.8252733945846558,0.7839264273643494,0.7461000084877014,1.112027645111084,50000 -5777.016916275024,8.03922438621521,74842.11705088615,163186,0,74842.11705088615,0.6315000057220459,1.6792042255401611,10000,80636.16126894951,0.8279882669448853,0.7568796277046204,0.7485599517822266,1.0927319526672363,50000 -5809.274814844132,8.088002920150757,75262.68575310707,164101,0,75262.68575310707,0.6350000500679016,1.6938539743423462,10000,81089.08715343475,0.8231250047683716,0.784327507019043,0.7511199712753296,1.100296974182129,50000 -5841.812193155289,8.13952898979187,75682.98915481567,165017,0,75682.98915481567,0.6292000412940979,1.6791881322860718,10000,81542.0298383236,0.8282421827316284,0.7752244472503662,0.7502599954605103,1.0974425077438354,50000 -5875.553389310837,8.199171781539917,76103.24325037003,165934,0,76103.24325037003,0.6335000395774841,1.672433853149414,10000,81996.13554024696,0.8337695002555847,0.7408263087272644,0.7530800104141235,1.0839965343475342,50000 -5910.257388830185,8.253775358200073,76523.2434270382,166850,0,76523.2434270382,0.6446000337600708,1.6342931985855105,10000,82450.94435310364,0.8430468440055847,0.7069917917251587,0.7552599906921387,1.0686745643615725,50000 -5943.479510307312,8.302427291870117,76943.4230298996,167765,0,76943.4230298996,0.6406000256538391,1.656589150428772,10000,82904.4444053173,0.8338086009025574,0.7452972531318665,0.7548399567604065,1.074892282485962,50000 -5977.2169399261475,8.360501766204834,77363.56269574165,168679,0,77363.56269574165,0.6388000249862671,1.652498722076416,10000,83358.4293923378,0.83363276720047,0.7375044822692871,0.7561999559402466,1.0720763206481934,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 7445235c6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1878 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.3291099,6.907757,,,,,,,,,,,,,, -1,,,0.0009179687476716,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,56.03143048286438,96.34635639190674,56.03143048286438,40.31481671333313,0.0,0.0 -100,0.33430946,6.906837,,,,,,,,,,,,,, -200,0.4102002,6.893395,,,,,,,,,,,,,, -300,0.5273446,6.8571315,,,,,,,,,,,,,, -400,0.65252805,6.8071923,,,,,,,,,,,,,, -500,0.7201225,6.78598,,,,,,,,,,,,,, -600,1.1000947,6.741387,,,,,,,,,,,,,, -700,0.9100101,6.8100257,,,,,,,,,,,,,, -800,1.1927352,6.6562433,,,,,,,,,,,,,, -884,,,0.0150976562872529,6.414700031280518,0.0146199995651841,6.424112796783447,50000.0,0.0112000005319714,6.467188358306885,10000.0,476.1916465759277,538.5675106048584,476.1916465759277,62.29332900047302,0.0311734676361084,0.0 -900,3.1867454,6.574435,,,,,,,,,,,,,, -1000,2.533499,6.567875,,,,,,,,,,,,,, -1100,2.0868247,6.4655604,,,,,,,,,,,,,, -1200,1.6976291,6.436185,,,,,,,,,,,,,, -1300,1.8970184,6.4515514,,,,,,,,,,,,,, -1400,1.2821145,6.417068,,,,,,,,,,,,,, -1500,2.5363595,6.4378414,,,,,,,,,,,,,, -1600,1.7365093,6.2836,,,,,,,,,,,,,, -1700,1.3382031,6.5716286,,,,,,,,,,,,,, -1800,1.4495664,6.4994383,,,,,,,,,,,,,, -1816,,,0.0458203107118606,5.830249309539795,0.044039998203516,5.860832214355469,50000.0,0.0388000011444091,5.978135585784912,10000.0,896.3407437801361,980.535098552704,896.3407437801361,84.03109979629517,0.0588784217834472,0.0 -1900,2.2629044,6.695959,,,,,,,,,,,,,, -2000,1.9959128,6.1162496,,,,,,,,,,,,,, -2100,1.6380744,6.1367197,,,,,,,,,,,,,, -2200,2.0715973,6.113777,,,,,,,,,,,,,, -2300,2.3674517,6.0722375,,,,,,,,,,,,,, -2400,2.1478984,6.0662518,,,,,,,,,,,,,, -2500,1.5049723,6.0285892,,,,,,,,,,,,,, -2600,1.7792007,6.012537,,,,,,,,,,,,,, -2700,1.4447225,6.4912806,,,,,,,,,,,,,, -2749,,,0.0702343732118606,5.432772636413574,0.0639199987053871,5.477494716644287,50000.0,0.0491000041365623,5.647575378417969,10000.0,1316.5516102313995,1422.709864616394,1316.5516102313995,105.91550326347352,0.0854272842407226,0.0 -2800,1.7777814,5.910261,,,,,,,,,,,,,, -2900,2.0118456,5.89996,,,,,,,,,,,,,, -3000,1.4685583,6.007889,,,,,,,,,,,,,, -3100,1.7601773,5.8884377,,,,,,,,,,,,,, -3200,1.3264701,6.4988403,,,,,,,,,,,,,, -3300,1.6396443,5.8199024,,,,,,,,,,,,,, -3400,1.9351851,5.884766,,,,,,,,,,,,,, -3500,1.4570885,6.048254,,,,,,,,,,,,,, -3600,1.6980745,6.408496,,,,,,,,,,,,,, -3681,,,0.0948632806539535,5.102994441986084,0.0889399945735931,5.145900249481201,50000.0,0.0680000036954879,5.371665954589844,10000.0,1736.8272049427032,1866.528002262116,1736.8272049427032,129.37765622138977,0.1137945652008056,0.0 -3700,1.537025,5.7396336,,,,,,,,,,,,,, -3800,1.5486251,5.7652674,,,,,,,,,,,,,, -3900,1.3460654,5.7984304,,,,,,,,,,,,,, -4000,1.916542,5.734318,,,,,,,,,,,,,, -4100,1.9846426,5.690876,,,,,,,,,,,,,, -4200,1.5840681,5.6398635,,,,,,,,,,,,,, -4300,2.0064654,5.615769,,,,,,,,,,,,,, -4400,1.1998067,6.424536,,,,,,,,,,,,,, -4500,1.9522899,5.549839,,,,,,,,,,,,,, -4600,1.7450657,5.5205092,,,,,,,,,,,,,, -4611,,,0.1406054645776748,4.725737571716309,0.1301199942827224,4.775650501251221,50000.0,0.0988000035285949,5.062837600708008,10000.0,2157.039966583252,2309.0352096557617,2157.039966583252,151.5914740562439,0.1422510147094726,0.0 -4700,1.752772,5.5340657,,,,,,,,,,,,,, -4800,1.3503635,6.449416,,,,,,,,,,,,,, -4900,1.8452933,5.408026,,,,,,,,,,,,,, -5000,1.5928313,5.586424,,,,,,,,,,,,,, -5100,1.7655579,5.4565334,,,,,,,,,,,,,, -5200,1.8649502,5.372303,,,,,,,,,,,,,, -5300,1.358459,6.3541465,,,,,,,,,,,,,, -5400,1.5860063,6.494045,,,,,,,,,,,,,, -5500,1.5220419,5.700618,,,,,,,,,,,,,, -5539,,,0.1874218732118606,4.307502269744873,0.1694400012493133,4.408010959625244,50000.0,0.1281000077724456,4.742315292358398,10000.0,2577.1580624580383,2751.210842370987,2577.1580624580383,173.5701711177826,0.1694018840789795,0.0 -5600,1.646654,5.283615,,,,,,,,,,,,,, -5700,1.647145,5.5679173,,,,,,,,,,,,,, -5800,1.3429315,5.816905,,,,,,,,,,,,,, -5900,1.5587778,5.284742,,,,,,,,,,,,,, -6000,1.9667062,5.573971,,,,,,,,,,,,,, -6100,1.7129866,5.0972795,,,,,,,,,,,,,, -6200,1.4702239,5.626332,,,,,,,,,,,,,, -6300,1.6484748,5.072614,,,,,,,,,,,,,, -6400,1.6401852,5.2614055,,,,,,,,,,,,,, -6464,,,0.2234570235013961,4.075367450714111,0.2088599950075149,4.152421951293945,50000.0,0.1588000059127807,4.514123916625977,10000.0,2997.12113070488,3195.574226856232,2997.12113070488,197.8914442062378,0.1972701549530029,0.0 -6500,1.5626813,5.306756,,,,,,,,,,,,,, -6600,1.5431539,5.5785427,,,,,,,,,,,,,, -6700,1.0858132,6.381526,,,,,,,,,,,,,, -6800,1.6122162,5.2973704,,,,,,,,,,,,,, -6900,1.8433222,4.9413776,,,,,,,,,,,,,, -7000,1.4324809,6.39707,,,,,,,,,,,,,, -7100,1.4266453,6.339537,,,,,,,,,,,,,, -7200,1.472529,6.294532,,,,,,,,,,,,,, -7300,1.277693,6.226548,,,,,,,,,,,,,, -7389,,,0.2674999833106994,3.717461109161377,0.2469999939203262,3.828011751174927,50000.0,0.1915000081062317,4.239701271057129,10000.0,3417.268706560135,3642.021924495697,3417.268706560135,224.1046841144561,0.232349157333374,0.0 -7400,1.6891117,4.7622,,,,,,,,,,,,,, -7500,1.4720782,5.463061,,,,,,,,,,,,,, -7600,1.5113028,4.8205576,,,,,,,,,,,,,, -7700,1.4760004,4.8126483,,,,,,,,,,,,,, -7800,1.8742754,4.8400903,,,,,,,,,,,,,, -7900,1.2618273,5.6116705,,,,,,,,,,,,,, -8000,1.6610467,4.6678514,,,,,,,,,,,,,, -8100,1.5596471,4.754899,,,,,,,,,,,,,, -8200,1.1524535,6.22902,,,,,,,,,,,,,, -8300,1.6814334,4.805926,,,,,,,,,,,,,, -8319,,,0.3072656095027923,3.447201728820801,0.2823199927806854,3.5882530212402344,50000.0,0.2145000100135803,4.04166316986084,10000.0,3837.3803448677054,4092.129539489746,3837.3803448677054,254.00139904022217,0.2779793739318847,0.0 -8400,1.588669,4.6294737,,,,,,,,,,,,,, -8500,1.6573488,4.632052,,,,,,,,,,,,,, -8600,1.6787689,4.7463217,,,,,,,,,,,,,, -8700,1.455972,5.1351705,,,,,,,,,,,,,, -8800,1.9948889,4.55826,,,,,,,,,,,,,, -8900,1.6879791,5.6590085,,,,,,,,,,,,,, -9000,1.6319952,4.593475,,,,,,,,,,,,,, -9100,1.1331518,6.2024226,,,,,,,,,,,,,, -9200,1.1327292,6.141655,,,,,,,,,,,,,, -9246,,,0.3461132645606994,3.243919849395752,0.3188599944114685,3.357321262359619,50000.0,0.245600014925003,3.852262020111084,10000.0,4257.711834192276,4546.292115211487,4257.711834192276,287.7465693950653,0.3111603260040283,0.0 -9300,1.1808684,5.932537,,,,,,,,,,,,,, -9400,1.1117712,6.1585636,,,,,,,,,,,,,, -9500,1.0278672,6.219107,,,,,,,,,,,,,, -9600,1.6878737,4.620962,,,,,,,,,,,,,, -9700,1.668733,4.5585203,,,,,,,,,,,,,, -9800,2.7196667,4.4976273,,,,,,,,,,,,,, -9900,1.6019452,4.4561415,,,,,,,,,,,,,, -10000,1.4438128,4.5270243,,,,,,,,,,,,,, -10100,1.5247498,5.2830276,,,,,,,,,,,,,, -10176,,,0.3649413883686065,3.0633111000061035,0.3436799943447113,3.1808507442474365,50000.0,0.2625000178813934,3.703757524490357,10000.0,4677.799368858337,4995.334156990051,4677.799368858337,316.598552942276,0.3597798347473144,0.0 -10200,1.0803195,5.6228275,,,,,,,,,,,,,, -10300,1.6231889,4.4333086,,,,,,,,,,,,,, -10400,1.9277245,4.6640906,,,,,,,,,,,,,, -10500,1.5543206,4.3387356,,,,,,,,,,,,,, -10600,1.4223602,4.688762,,,,,,,,,,,,,, -10700,1.6198804,4.9180207,,,,,,,,,,,,,, -10800,1.0385597,6.127368,,,,,,,,,,,,,, -10900,1.5829189,4.3938065,,,,,,,,,,,,,, -11000,1.3569801,4.6185927,,,,,,,,,,,,,, -11100,1.6510075,4.3560643,,,,,,,,,,,,,, -11101,,,0.3976367115974426,2.9350926876068115,0.3646000027656555,3.088977336883545,50000.0,0.2785000205039978,3.62067985534668,10000.0,5098.222493410111,5446.229580163956,5098.222493410111,346.9868311882019,0.3919997215270996,0.0 -11200,0.96423906,5.881105,,,,,,,,,,,,,, -11300,1.8593858,4.5126185,,,,,,,,,,,,,, -11400,1.3111153,5.8767433,,,,,,,,,,,,,, -11500,1.5711607,4.228761,,,,,,,,,,,,,, -11600,1.6854869,4.2637606,,,,,,,,,,,,,, -11700,1.5589702,4.3119273,,,,,,,,,,,,,, -11800,1.4923192,4.3401012,,,,,,,,,,,,,, -11900,1.6231695,4.139243,,,,,,,,,,,,,, -12000,1.4598291,4.210677,,,,,,,,,,,,,, -12025,,,0.4351562261581421,2.699234962463379,0.3865399956703186,2.927830457687378,50000.0,0.2924000024795532,3.486982107162476,10000.0,5518.611259937286,5892.516352653503,5518.611259937286,372.8019599914551,0.4224827289581299,0.0 -12100,1.6456257,4.1695495,,,,,,,,,,,,,, -12200,1.2103641,5.913884,,,,,,,,,,,,,, -12300,1.1635107,5.2973413,,,,,,,,,,,,,, -12400,1.1336277,5.969995,,,,,,,,,,,,,, -12500,1.6688778,4.21784,,,,,,,,,,,,,, -12600,1.1335307,6.034396,,,,,,,,,,,,,, -12700,1.4787582,4.2312675,,,,,,,,,,,,,, -12800,1.356152,4.2306166,,,,,,,,,,,,,, -12900,1.6841148,4.0924144,,,,,,,,,,,,,, -12947,,,0.429980456829071,2.715768575668335,0.3989799916744232,2.8612937927246094,50000.0,0.3106000125408172,3.41860294342041,10000.0,5938.5728850364685,6340.079912424088,5938.5728850364685,400.3198812007904,0.4549741744995117,0.0 -13000,0.9236607,6.0503087,,,,,,,,,,,,,, -13100,1.3215206,4.304862,,,,,,,,,,,,,, -13200,1.4519806,4.19007,,,,,,,,,,,,,, -13300,1.3251724,4.0830173,,,,,,,,,,,,,, -13400,1.7520819,4.284338,,,,,,,,,,,,,, -13500,1.2603621,5.210645,,,,,,,,,,,,,, -13600,1.5705853,4.071068,,,,,,,,,,,,,, -13700,1.3978353,4.0536327,,,,,,,,,,,,,, -13800,1.5643032,4.07439,,,,,,,,,,,,,, -13871,,,0.447558581829071,2.601588010787964,0.4180800020694732,2.759405851364136,50000.0,0.3206000030040741,3.336005210876465,10000.0,6358.816624403,6794.234518766403,6358.816624403,434.1339132785797,0.499969482421875,0.0 -13900,1.1864665,4.8744125,,,,,,,,,,,,,, -14000,1.5946541,4.3741965,,,,,,,,,,,,,, -14100,1.5286967,4.0751643,,,,,,,,,,,,,, -14200,1.6467168,3.997131,,,,,,,,,,,,,, -14300,1.4210927,4.145789,,,,,,,,,,,,,, -14400,0.89955574,5.568205,,,,,,,,,,,,,, -14500,1.5422927,4.0095844,,,,,,,,,,,,,, -14600,1.4897007,3.9984832,,,,,,,,,,,,,, -14700,0.9604812,5.7521143,,,,,,,,,,,,,, -14794,,,0.4675585925579071,2.4731810092926025,0.4262399971485138,2.686066150665283,50000.0,0.3368000090122223,3.2613182067871094,10000.0,6778.866580486298,7250.958543777466,6778.866580486298,470.73082447052,0.5254685878753662,0.0 -14800,1.5328525,3.97562,,,,,,,,,,,,,, -14900,0.9846098,5.6829095,,,,,,,,,,,,,, -15000,0.9518223,5.9251328,,,,,,,,,,,,,, -15100,1.3863554,3.9816911,,,,,,,,,,,,,, -15200,1.4890785,3.8780737,,,,,,,,,,,,,, -15300,1.258868,4.2326245,,,,,,,,,,,,,, -15400,1.4512782,4.048507,,,,,,,,,,,,,, -15500,1.2176901,4.342952,,,,,,,,,,,,,, -15600,1.5127548,3.918115,,,,,,,,,,,,,, -15700,1.3477426,4.1934524,,,,,,,,,,,,,, -15716,,,0.4583398401737213,2.544507026672364,0.4334799945354461,2.6762075424194336,50000.0,0.3315000236034393,3.258542060852051,10000.0,7199.002569198608,7709.4025728702545,7199.002569198608,508.96109414100647,0.5512490272521973,0.0 -15800,0.9804586,5.252469,,,,,,,,,,,,,, -15900,1.5217483,3.9479043,,,,,,,,,,,,,, -16000,1.3212177,3.9140615,,,,,,,,,,,,,, -16100,1.3514192,4.0138226,,,,,,,,,,,,,, -16200,1.2376333,4.4423866,,,,,,,,,,,,,, -16300,1.4876841,4.0040846,,,,,,,,,,,,,, -16400,1.4812545,4.1528897,,,,,,,,,,,,,, -16500,1.2455932,4.325816,,,,,,,,,,,,,, -16600,1.437074,3.884824,,,,,,,,,,,,,, -16636,,,0.4844140410423279,2.364227533340454,0.4491399824619293,2.542757511138916,50000.0,0.3485000133514404,3.135037422180176,10000.0,7619.164701223373,8167.473934412003,7619.164701223373,546.786703824997,0.5834970474243164,0.0 -16700,1.5926385,4.29595,,,,,,,,,,,,,, -16800,1.3831434,4.46525,,,,,,,,,,,,,, -16900,1.4914445,4.0138674,,,,,,,,,,,,,, -17000,1.1082401,4.6842465,,,,,,,,,,,,,, -17100,1.6043628,3.872,,,,,,,,,,,,,, -17200,1.4651881,3.9099412,,,,,,,,,,,,,, -17300,1.4362326,3.9508667,,,,,,,,,,,,,, -17400,1.2810241,3.9638124,,,,,,,,,,,,,, -17500,1.0740163,4.7271113,,,,,,,,,,,,,, -17552,,,0.488085925579071,2.3479232788085938,0.4491199851036072,2.5377163887023926,50000.0,0.3505000174045563,3.126204013824463,10000.0,8039.487378358841,8624.050165176392,8039.487378358841,582.9615287780762,0.6108376979827881,0.0 -17600,1.0354174,5.420446,,,,,,,,,,,,,, -17700,1.500348,3.973505,,,,,,,,,,,,,, -17800,1.4032788,3.93089,,,,,,,,,,,,,, -17900,0.8548171,5.908079,,,,,,,,,,,,,, -18000,1.3372358,4.0071373,,,,,,,,,,,,,, -18100,1.4698308,3.8919241,,,,,,,,,,,,,, -18200,1.541442,3.86503,,,,,,,,,,,,,, -18300,1.6699436,3.9934187,,,,,,,,,,,,,, -18400,1.5415769,3.8339825,,,,,,,,,,,,,, -18470,,,0.494140625,2.3677914142608643,0.4590199887752533,2.532339334487915,50000.0,0.3570000231266022,3.123909950256348,10000.0,8459.713441371918,9083.020942211151,8459.713441371918,621.6249876022339,0.6407480239868164,0.0 -18500,1.1713141,3.993101,,,,,,,,,,,,,, -18600,1.3021505,4.6481514,,,,,,,,,,,,,, -18700,0.97857046,5.8300514,,,,,,,,,,,,,, -18800,1.6943791,3.8895967,,,,,,,,,,,,,, -18900,1.2155277,4.189431,,,,,,,,,,,,,, -19000,1.3895806,3.8477588,,,,,,,,,,,,,, -19100,1.588638,3.884274,,,,,,,,,,,,,, -19200,1.6104912,4.5736613,,,,,,,,,,,,,, -19300,1.3938943,4.1641626,,,,,,,,,,,,,, -19389,,,0.500292956829071,2.304877996444702,0.4635799825191498,2.468919515609741,50000.0,0.3579000234603882,3.091227769851685,10000.0,8879.765430927277,9542.602039337158,8879.765430927277,661.0747628211975,0.6683251857757568,0.0 -19400,1.2309151,4.1107416,,,,,,,,,,,,,, -19500,1.3228763,3.9737895,,,,,,,,,,,,,, -19600,1.2795075,3.8661604,,,,,,,,,,,,,, -19700,1.0593756,4.532345,,,,,,,,,,,,,, -19800,1.3464409,3.7665803,,,,,,,,,,,,,, -19900,1.350464,3.7825427,,,,,,,,,,,,,, -20000,0.93134165,5.5829196,,,,,,,,,,,,,, -20100,1.4095168,3.9532552,,,,,,,,,,,,,, -20200,0.975221,5.1784315,,,,,,,,,,,,,, -20300,1.4536507,3.8991988,,,,,,,,,,,,,, -20307,,,0.5157226324081421,2.245275020599365,0.471780002117157,2.446197986602783,50000.0,0.3741000294685364,3.045219898223877,10000.0,9298.976362466812,10000.288239002228,9298.976362466812,698.6457741260529,1.5207412242889404,0.0 -20400,1.3575912,3.9521213,,,,,,,,,,,,,, -20500,1.2579746,3.7419922,,,,,,,,,,,,,, -20600,1.2402129,3.9819317,,,,,,,,,,,,,, -20700,0.86096406,5.7708435,,,,,,,,,,,,,, -20800,0.829622,5.704003,,,,,,,,,,,,,, -20900,1.1395689,5.6904726,,,,,,,,,,,,,, -21000,1.2990304,4.7781515,,,,,,,,,,,,,, -21100,1.0826696,5.6505523,,,,,,,,,,,,,, -21200,1.3302497,3.9725838,,,,,,,,,,,,,, -21226,,,0.5420507788658142,2.0856730937957764,0.4818599820137024,2.369181871414185,50000.0,0.3777000308036804,2.9841349124908447,10000.0,9719.232017755508,10455.08547091484,9719.232017755508,733.1075923442841,1.548471212387085,0.0 -21300,1.3740554,3.857836,,,,,,,,,,,,,, -21400,1.0680952,5.1987333,,,,,,,,,,,,,, -21500,1.8044488,3.7078288,,,,,,,,,,,,,, -21600,0.8537771,5.6241016,,,,,,,,,,,,,, -21700,1.1113476,4.66068,,,,,,,,,,,,,, -21800,1.2205443,4.024323,,,,,,,,,,,,,, -21900,1.3377239,3.7676985,,,,,,,,,,,,,, -22000,1.0081714,4.874508,,,,,,,,,,,,,, -22100,1.3025512,3.712394,,,,,,,,,,,,,, -22142,,,0.5222460627555847,2.2017922401428223,0.4871599972248077,2.363506555557251,50000.0,0.380700021982193,2.985949754714966,10000.0,10139.521540164948,10909.113496303558,10139.521540164948,766.7560062408447,1.5872962474822998,0.0 -22200,1.4561176,3.8140907,,,,,,,,,,,,,, -22300,1.245024,3.718606,,,,,,,,,,,,,, -22400,0.95670533,5.6590943,,,,,,,,,,,,,, -22500,0.910919,5.727264,,,,,,,,,,,,,, -22600,1.32852,4.366585,,,,,,,,,,,,,, -22700,1.3661344,3.9521039,,,,,,,,,,,,,, -22800,1.1857448,4.028304,,,,,,,,,,,,,, -22900,1.3208408,3.7644854,,,,,,,,,,,,,, -23000,0.8921906,5.771597,,,,,,,,,,,,,, -23064,,,0.5399999618530273,2.108938455581665,0.499099999666214,2.298219919204712,50000.0,0.3899000287055969,2.9014742374420166,10000.0,10559.62688589096,11362.43018102646,10559.62688589096,799.8872475624084,1.6160616874694824,0.0 -23100,1.4246439,3.7365737,,,,,,,,,,,,,, -23200,1.0457933,5.377886,,,,,,,,,,,,,, -23300,1.6765325,3.7834036,,,,,,,,,,,,,, -23400,1.3494543,4.2397685,,,,,,,,,,,,,, -23500,1.5218214,3.7060819,,,,,,,,,,,,,, -23600,1.5086657,3.655724,,,,,,,,,,,,,, -23700,1.1809566,4.1407766,,,,,,,,,,,,,, -23800,1.1213423,4.168758,,,,,,,,,,,,,, -23900,0.90710646,5.70887,,,,,,,,,,,,,, -23984,,,0.5517382621765137,2.039448499679565,0.5036199688911438,2.2866790294647217,50000.0,0.3901000320911407,2.90531587600708,10000.0,10979.964556217194,11816.292253255844,10979.964556217194,833.3303182125092,1.6461646556854248,0.0 -24000,1.1446731,4.319785,,,,,,,,,,,,,, -24100,1.576364,3.6969872,,,,,,,,,,,,,, -24200,1.0352113,5.712245,,,,,,,,,,,,,, -24300,1.0702307,4.54146,,,,,,,,,,,,,, -24400,1.0796603,4.835906,,,,,,,,,,,,,, -24500,1.3475224,3.7413297,,,,,,,,,,,,,, -24600,1.3509624,3.7593663,,,,,,,,,,,,,, -24700,1.1199574,5.7508345,,,,,,,,,,,,,, -24800,1.2994294,3.542072,,,,,,,,,,,,,, -24900,1.4912039,3.777523,,,,,,,,,,,,,, -24904,,,0.541796863079071,2.0832176208496094,0.5054000020027161,2.2420337200164795,50000.0,0.3970000147819519,2.8567094802856445,10000.0,11400.30337524414,12270.108603954315,11400.30337524414,866.7247035503387,1.6784873008728027,0.0 -25000,0.9513382,5.307952,,,,,,,,,,,,,, -25100,0.92234945,5.0024996,,,,,,,,,,,,,, -25200,1.3481005,3.7473433,,,,,,,,,,,,,, -25300,1.4711946,3.641522,,,,,,,,,,,,,, -25400,1.3865274,3.7607734,,,,,,,,,,,,,, -25500,1.3930278,4.375123,,,,,,,,,,,,,, -25600,1.0428213,5.7385674,,,,,,,,,,,,,, -25700,1.8290999,3.7013476,,,,,,,,,,,,,, -25800,1.4972556,3.673467,,,,,,,,,,,,,, -25811,,,0.5504101514816284,2.0340700149536133,0.5138599872589111,2.201979875564575,50000.0,0.4013000130653381,2.819454431533813,10000.0,11820.599084615707,12724.000076532364,11820.599084615707,900.2320485115051,1.71818208694458,0.0 -25900,1.3727567,3.666178,,,,,,,,,,,,,, -26000,1.07012,4.643422,,,,,,,,,,,,,, -26100,1.5500178,3.6275735,,,,,,,,,,,,,, -26200,1.50368,3.6958652,,,,,,,,,,,,,, -26300,1.2372766,4.761891,,,,,,,,,,,,,, -26400,1.1872265,4.5932612,,,,,,,,,,,,,, -26500,1.3204228,3.7259808,,,,,,,,,,,,,, -26600,1.2213749,3.7330027,,,,,,,,,,,,,, -26700,1.1156632,4.372858,,,,,,,,,,,,,, -26725,,,0.566210925579071,2.00087571144104,0.5141199827194214,2.229660987854004,50000.0,0.4070000052452087,2.832810878753662,10000.0,12240.525072574615,13177.773655176165,12240.525072574615,933.998238325119,1.7490291595458984,0.0 -26800,1.1974324,3.9650457,,,,,,,,,,,,,, -26900,2.056407,3.6323993,,,,,,,,,,,,,, -27000,1.1331531,5.1867266,,,,,,,,,,,,,, -27100,1.1577648,4.960825,,,,,,,,,,,,,, -27200,1.2449707,4.9820013,,,,,,,,,,,,,, -27300,1.3365911,3.5217881,,,,,,,,,,,,,, -27400,1.0242274,5.2381535,,,,,,,,,,,,,, -27500,1.2648429,3.9948158,,,,,,,,,,,,,, -27600,1.225906,3.8986566,,,,,,,,,,,,,, -27643,,,0.5484570264816284,2.040403366088867,0.519760012626648,2.1909496784210205,50000.0,0.4043000340461731,2.829466819763184,10000.0,12660.572672367096,13630.616124868391,12660.572672367096,966.7126722335817,1.7793781757354736,0.0 -27700,1.3406552,3.5953865,,,,,,,,,,,,,, -27800,1.4391555,3.5451255,,,,,,,,,,,,,, -27900,1.3829192,3.645707,,,,,,,,,,,,,, -28000,1.176431,5.344029,,,,,,,,,,,,,, -28100,1.4736257,3.503264,,,,,,,,,,,,,, -28200,1.395238,3.6596477,,,,,,,,,,,,,, -28300,1.4951646,3.5234318,,,,,,,,,,,,,, -28400,1.490239,3.936549,,,,,,,,,,,,,, -28500,1.1397563,5.2202096,,,,,,,,,,,,,, -28560,,,0.560546875,1.999878168106079,0.5216999650001526,2.1805710792541504,50000.0,0.4077000319957733,2.8053174018859863,10000.0,13080.651557922363,14082.56403541565,13080.651557922363,998.4982385635376,1.811950922012329,0.0 -28600,1.366369,3.4624233,,,,,,,,,,,,,, -28700,1.5302266,3.654175,,,,,,,,,,,,,, -28800,1.7598654,3.7782996,,,,,,,,,,,,,, -28900,1.328548,3.696287,,,,,,,,,,,,,, -29000,1.0499932,4.556059,,,,,,,,,,,,,, -29100,1.117115,5.7528515,,,,,,,,,,,,,, -29200,1.5785695,3.5062265,,,,,,,,,,,,,, -29300,0.96413517,5.8067446,,,,,,,,,,,,,, -29400,1.1775266,4.37958,,,,,,,,,,,,,, -29477,,,0.5726562142372131,1.928755760192871,0.5296599864959717,2.13201904296875,50000.0,0.4166000187397003,2.758517026901245,10000.0,13500.786858081818,14534.419796228409,13500.786858081818,1030.136039018631,1.8442060947418213,0.0 -29500,1.3200682,4.2226086,,,,,,,,,,,,,, -29600,1.0148313,5.0064683,,,,,,,,,,,,,, -29700,0.9999203,5.6886654,,,,,,,,,,,,,, -29800,1.3581282,3.5497658,,,,,,,,,,,,,, -29900,1.3372606,3.674982,,,,,,,,,,,,,, -30000,1.127438,5.716855,,,,,,,,,,,,,, -30100,1.5506757,3.6002653,,,,,,,,,,,,,, -30200,1.2707953,4.112058,,,,,,,,,,,,,, -30300,1.4008746,3.5749743,,,,,,,,,,,,,, -30393,,,0.5676171779632568,1.9483190774917605,0.5301600098609924,2.114281415939331,50000.0,0.4150000214576721,2.7555928230285645,10000.0,13920.732541561129,14987.511257886888,13920.732541561129,1063.1991493701937,1.8761727809906008,0.0 -30400,1.4005921,3.629017,,,,,,,,,,,,,, -30500,1.3922793,3.4909637,,,,,,,,,,,,,, -30600,0.9752747,5.6927867,,,,,,,,,,,,,, -30700,1.3935742,3.409047,,,,,,,,,,,,,, -30800,1.4203463,3.5381951,,,,,,,,,,,,,, -30900,1.4020886,3.527884,,,,,,,,,,,,,, -31000,1.4327757,3.5428653,,,,,,,,,,,,,, -31100,1.4422367,3.5083923,,,,,,,,,,,,,, -31200,1.3385077,3.8594947,,,,,,,,,,,,,, -31300,1.2566639,3.8417022,,,,,,,,,,,,,, -31309,,,0.5728124976158142,1.927043080329895,0.537559986114502,2.104546546936035,50000.0,0.4228000342845917,2.728455305099488,10000.0,14340.659535884855,15440.811559438704,14340.659535884855,1096.492238998413,1.905264139175415,0.0 -31400,1.3133359,4.3308773,,,,,,,,,,,,,, -31500,1.5484396,3.4388905,,,,,,,,,,,,,, -31600,1.6738315,3.5187156,,,,,,,,,,,,,, -31700,1.3389418,3.4994874,,,,,,,,,,,,,, -31800,1.5468432,3.59757,,,,,,,,,,,,,, -31900,1.5155793,3.479515,,,,,,,,,,,,,, -32000,1.2783794,4.518468,,,,,,,,,,,,,, -32100,1.6957769,3.4955025,,,,,,,,,,,,,, -32200,1.6225657,3.4220147,,,,,,,,,,,,,, -32227,,,0.5775781273841858,1.9075417518615725,0.5331599712371826,2.1145474910736084,50000.0,0.4244000315666199,2.728424072265625,10000.0,14760.61864376068,15893.128350257874,14760.61864376068,1128.7679710388184,1.9361648559570312,0.0 -32300,1.3552105,3.5313325,,,,,,,,,,,,,, -32400,1.3515077,4.007159,,,,,,,,,,,,,, -32500,1.4628683,3.5213904,,,,,,,,,,,,,, -32600,1.3237787,3.9015217,,,,,,,,,,,,,, -32700,1.1848127,3.9644284,,,,,,,,,,,,,, -32800,1.2855122,3.665643,,,,,,,,,,,,,, -32900,1.6598111,3.5450711,,,,,,,,,,,,,, -33000,1.5732356,3.542544,,,,,,,,,,,,,, -33100,1.4175321,3.5631866,,,,,,,,,,,,,, -33142,,,0.6056835651397705,1.7929567098617554,0.5388399958610535,2.0905473232269287,50000.0,0.4270000159740448,2.7143259048461914,10000.0,15180.854577302933,16346.38900589943,15180.854577302933,1161.712742805481,1.9648699760437007,0.0 -33200,1.2306992,4.337687,,,,,,,,,,,,,, -33300,1.591588,3.526279,,,,,,,,,,,,,, -33400,1.1359341,4.3419585,,,,,,,,,,,,,, -33500,1.5568762,3.6077485,,,,,,,,,,,,,, -33600,1.2773572,3.9859715,,,,,,,,,,,,,, -33700,1.4688545,3.6058893,,,,,,,,,,,,,, -33800,1.4206654,3.5174844,,,,,,,,,,,,,, -33900,1.1559008,5.0494833,,,,,,,,,,,,,, -34000,1.3346902,3.460751,,,,,,,,,,,,,, -34059,,,0.577441394329071,1.9018783569335933,0.5426999926567078,2.0791501998901367,50000.0,0.4227000176906585,2.717620372772217,10000.0,15601.136335372925,16798.78573513031,15601.136335372925,1193.7472488880155,1.9952008724212649,0.0 -34100,1.1991158,5.381578,,,,,,,,,,,,,, -34200,0.9982546,5.6742587,,,,,,,,,,,,,, -34300,1.5884286,3.4483685,,,,,,,,,,,,,, -34400,1.2656833,4.0771713,,,,,,,,,,,,,, -34500,1.239813,3.9619799,,,,,,,,,,,,,, -34600,1.3786324,3.705836,,,,,,,,,,,,,, -34700,1.4012644,3.5323606,,,,,,,,,,,,,, -34800,1.4380276,3.3493214,,,,,,,,,,,,,, -34900,1.216013,5.448099,,,,,,,,,,,,,, -34977,,,0.5922070145606995,1.842816710472107,0.5476399660110474,2.05157208442688,50000.0,0.429500013589859,2.695765256881714,10000.0,16021.365535497664,17250.523627996445,16021.365535497664,1225.170911550522,2.0287718772888184,0.0 -35000,1.4173183,3.4411898,,,,,,,,,,,,,, -35100,1.5310234,3.4860845,,,,,,,,,,,,,, -35200,1.0531123,5.6133575,,,,,,,,,,,,,, -35300,0.9168505,5.3993406,,,,,,,,,,,,,, -35400,1.3615404,3.9912784,,,,,,,,,,,,,, -35500,1.4740646,3.4400134,,,,,,,,,,,,,, -35600,1.8937165,3.6068664,,,,,,,,,,,,,, -35700,1.5501194,3.4352722,,,,,,,,,,,,,, -35800,1.2622725,4.7090626,,,,,,,,,,,,,, -35896,,,0.6042382717132568,1.7828890085220337,0.5482999682426453,2.0370826721191406,50000.0,0.4358000159263611,2.656701803207397,10000.0,16441.70537519455,17703.087596178055,16441.70537519455,1257.3156542778015,2.0581812858581543,0.0 -35900,1.5348475,3.7812412,,,,,,,,,,,,,, -36000,1.2773991,4.2300315,,,,,,,,,,,,,, -36100,1.4860132,3.413454,,,,,,,,,,,,,, -36200,1.8254994,3.5282516,,,,,,,,,,,,,, -36300,0.98717564,4.886425,,,,,,,,,,,,,, -36400,1.8363765,3.4763901,,,,,,,,,,,,,, -36500,1.5348132,3.6219153,,,,,,,,,,,,,, -36600,1.5022737,3.4580622,,,,,,,,,,,,,, -36700,1.7807162,3.4356759,,,,,,,,,,,,,, -36800,1.4562402,3.5281043,,,,,,,,,,,,,, -36813,,,0.5854101181030273,1.868322730064392,0.5462200045585632,2.047626256942749,50000.0,0.4314000308513641,2.669766902923584,10000.0,16862.025440216064,18154.93812942505,16862.025440216064,1288.7657787799835,2.0886483192443848,0.0 -36900,1.3961796,3.5975423,,,,,,,,,,,,,, -37000,1.1212623,5.122786,,,,,,,,,,,,,, -37100,1.3764403,3.562281,,,,,,,,,,,,,, -37200,1.3522748,5.022589,,,,,,,,,,,,,, -37300,1.5016373,3.6483412,,,,,,,,,,,,,, -37400,1.520567,3.3646305,,,,,,,,,,,,,, -37500,1.4059182,3.4330933,,,,,,,,,,,,,, -37600,1.2291768,4.7476463,,,,,,,,,,,,,, -37700,1.3865092,3.362024,,,,,,,,,,,,,, -37732,,,0.5966015458106995,1.8163317441940308,0.5523599982261658,2.01404070854187,50000.0,0.4401000142097473,2.634584426879883,10000.0,17282.239025354385,18607.237397432327,17282.239025354385,1320.7677319049835,2.1223950386047363,0.0 -37800,1.5826997,3.5295427,,,,,,,,,,,,,, -37900,1.0322819,5.4082065,,,,,,,,,,,,,, -38000,1.2350311,3.8465014,,,,,,,,,,,,,, -38100,1.5029644,3.5068905,,,,,,,,,,,,,, -38200,1.6685119,3.5291142,,,,,,,,,,,,,, -38300,1.4358902,5.675362,,,,,,,,,,,,,, -38400,1.4617552,3.5254674,,,,,,,,,,,,,, -38500,1.5066112,3.56098,,,,,,,,,,,,,, -38600,1.0380509,5.2500553,,,,,,,,,,,,,, -38651,,,0.60693359375,1.7717552185058594,0.5542600154876709,2.0164029598236084,50000.0,0.4394000172615051,2.6386351585388184,10000.0,17702.3341217041,19060.093980789185,17702.3341217041,1353.4470887184143,2.154141902923584,0.0 -38700,1.4960352,3.7929182,,,,,,,,,,,,,, -38800,1.0874872,4.7621803,,,,,,,,,,,,,, -38900,1.265389,4.9719086,,,,,,,,,,,,,, -39000,1.7187343,3.3934696,,,,,,,,,,,,,, -39100,1.1785223,4.180703,,,,,,,,,,,,,, -39200,1.1751474,5.647996,,,,,,,,,,,,,, -39300,1.4534736,3.6206303,,,,,,,,,,,,,, -39400,1.3279576,5.573661,,,,,,,,,,,,,, -39500,1.0608361,5.690296,,,,,,,,,,,,,, -39571,,,0.590136706829071,1.8692361116409304,0.5558800101280212,2.0332841873168945,50000.0,0.4341000318527221,2.667387723922729,10000.0,18122.364139556885,19513.99955034256,18122.364139556885,1387.2440507411957,2.182854175567627,0.0 -39600,1.4871228,3.4918032,,,,,,,,,,,,,, -39700,1.3630452,3.480262,,,,,,,,,,,,,, -39800,1.178164,4.872717,,,,,,,,,,,,,, -39900,1.2228616,4.147446,,,,,,,,,,,,,, -40000,1.4743465,3.6319308,,,,,,,,,,,,,, -40100,1.5026672,3.5653977,,,,,,,,,,,,,, -40200,1.0641049,5.593485,,,,,,,,,,,,,, -40300,1.7059624,3.4562953,,,,,,,,,,,,,, -40400,1.3918805,3.452987,,,,,,,,,,,,,, -40490,,,0.5978320240974426,1.8688842058181765,0.5544999837875366,2.060164213180542,50000.0,0.4387000203132629,2.6926157474517822,10000.0,18542.35852384568,19965.28075647354,18542.35852384568,1418.446210384369,2.216303586959839,0.0 -40500,1.4024516,3.734883,,,,,,,,,,,,,, -40600,1.5418285,3.4582586,,,,,,,,,,,,,, -40700,1.3558136,4.212097,,,,,,,,,,,,,, -40800,1.4523107,3.3578637,,,,,,,,,,,,,, -40900,1.6070836,3.567349,,,,,,,,,,,,,, -41000,1.7209182,3.4246616,,,,,,,,,,,,,, -41100,1.4459726,3.4891632,,,,,,,,,,,,,, -41200,1.5629797,3.740427,,,,,,,,,,,,,, -41300,1.0529515,5.443426,,,,,,,,,,,,,, -41400,1.3600373,3.3270755,,,,,,,,,,,,,, -41407,,,0.6039062142372131,1.8160500526428225,0.5564199686050415,2.0289831161499023,50000.0,0.4399000108242035,2.66598916053772,10000.0,18962.650825738907,20416.922464370728,18962.650825738907,1449.7089776992798,2.25161361694336,0.0 -41500,1.2673082,5.0301933,,,,,,,,,,,,,, -41600,1.522531,3.5328557,,,,,,,,,,,,,, -41700,1.1333443,4.6304965,,,,,,,,,,,,,, -41800,1.6963879,3.5100558,,,,,,,,,,,,,, -41900,1.5610914,3.6794872,,,,,,,,,,,,,, -42000,1.4280792,3.3787982,,,,,,,,,,,,,, -42100,1.4695992,3.6069984,,,,,,,,,,,,,, -42200,1.5438976,3.5179834,,,,,,,,,,,,,, -42300,1.2559767,5.1661444,,,,,,,,,,,,,, -42326,,,0.5984765291213989,1.8087574243545528,0.5619800090789795,1.9925109148025515,50000.0,0.4429000318050384,2.6232457160949707,10000.0,19382.72688031197,20869.60631251335,19382.72688031197,1482.2264337539673,2.291311264038086,0.0 -42400,1.2438567,4.110463,,,,,,,,,,,,,, -42500,1.4777873,3.4145763,,,,,,,,,,,,,, -42600,1.6823769,3.4404159,,,,,,,,,,,,,, -42700,1.2284335,4.676367,,,,,,,,,,,,,, -42800,1.2133771,4.8602324,,,,,,,,,,,,,, -42900,1.5348483,3.4870129,,,,,,,,,,,,,, -43000,1.5763654,3.4268327,,,,,,,,,,,,,, -43100,1.6281253,3.413531,,,,,,,,,,,,,, -43200,1.234294,3.856714,,,,,,,,,,,,,, -43244,,,0.6040819883346558,1.7556627988815308,0.5645599961280823,1.9334279298782349,50000.0,0.4507000148296356,2.5570688247680664,10000.0,19802.75421524048,21322.457344055176,19802.75421524048,1514.9667398929596,2.3243231773376465,0.0 -43300,1.5204716,3.4968183,,,,,,,,,,,,,, -43400,1.5030233,3.3681047,,,,,,,,,,,,,, -43500,1.6076912,3.4890618,,,,,,,,,,,,,, -43600,1.4495608,3.5060015,,,,,,,,,,,,,, -43700,1.1264848,5.049402,,,,,,,,,,,,,, -43800,1.473061,3.510929,,,,,,,,,,,,,, -43900,1.4663076,3.433294,,,,,,,,,,,,,, -44000,1.3679975,3.4734993,,,,,,,,,,,,,, -44100,1.4132239,4.918461,,,,,,,,,,,,,, -44162,,,0.6100195050239563,1.7446229457855225,0.5627399682998657,1.9554500579833984,50000.0,0.4501000344753265,2.5805513858795166,10000.0,20223.016721487045,21775.79777216912,20223.016721487045,1547.9607965946198,2.358118295669556,0.0 -44200,1.4566072,3.4106567,,,,,,,,,,,,,, -44300,1.5348556,3.4672291,,,,,,,,,,,,,, -44400,1.2774116,4.618615,,,,,,,,,,,,,, -44500,1.5066069,3.3904037,,,,,,,,,,,,,, -44600,1.7019216,3.4399223,,,,,,,,,,,,,, -44700,1.4737653,3.955035,,,,,,,,,,,,,, -44800,1.321194,4.8641486,,,,,,,,,,,,,, -44900,1.522794,3.382295,,,,,,,,,,,,,, -45000,1.1824316,5.149306,,,,,,,,,,,,,, -45079,,,0.6299804449081421,1.65792977809906,0.564079999923706,1.9666937589645384,50000.0,0.4431000351905823,2.618972063064575,10000.0,20643.23784804344,22229.7259953022,20643.23784804344,1581.5847356319427,2.3912997245788574,0.0 -45100,1.3552954,4.9055805,,,,,,,,,,,,,, -45200,1.164786,5.6092405,,,,,,,,,,,,,, -45300,1.5067319,3.4841452,,,,,,,,,,,,,, -45400,1.4448247,3.389394,,,,,,,,,,,,,, -45500,1.7063441,3.420042,,,,,,,,,,,,,, -45600,1.7096314,3.3872368,,,,,,,,,,,,,, -45700,1.1291896,5.4316216,,,,,,,,,,,,,, -45800,1.1045488,4.9971514,,,,,,,,,,,,,, -45900,1.6523165,3.4666877,,,,,,,,,,,,,, -45998,,,0.6064062118530273,1.7620118856430054,0.5678600072860718,1.9412288665771484,50000.0,0.4560000300407409,2.58458948135376,10000.0,21063.51750826836,22680.277902126312,21063.51750826836,1611.773603439331,2.4242405891418457,0.0 -46000,1.3469255,5.2146664,,,,,,,,,,,,,, -46100,1.5328436,3.42958,,,,,,,,,,,,,, -46200,1.7341324,3.4447749,,,,,,,,,,,,,, -46300,1.489264,4.0969424,,,,,,,,,,,,,, -46400,1.4901268,3.3417194,,,,,,,,,,,,,, -46500,1.434655,3.5169497,,,,,,,,,,,,,, -46600,1.6742198,3.3171494,,,,,,,,,,,,,, -46700,1.557619,3.4668844,,,,,,,,,,,,,, -46800,1.1305867,5.4072027,,,,,,,,,,,,,, -46900,1.4981146,4.696863,,,,,,,,,,,,,, -46912,,,0.6177538633346558,1.7487009763717651,0.5697599649429321,1.956942081451416,50000.0,0.4544000327587127,2.5762476921081543,10000.0,21483.74810171128,23131.15274167061,21483.74810171128,1642.3246562480929,2.4671669006347656,0.0 -47000,1.3198414,4.612858,,,,,,,,,,,,,, -47100,1.641439,3.4345746,,,,,,,,,,,,,, -47200,1.5070084,3.2995985,,,,,,,,,,,,,, -47300,1.1676389,4.954886,,,,,,,,,,,,,, -47400,0.99538654,5.558426,,,,,,,,,,,,,, -47500,1.572184,3.5948045,,,,,,,,,,,,,, -47600,1.4827806,3.4662676,,,,,,,,,,,,,, -47700,1.3819321,4.333733,,,,,,,,,,,,,, -47800,1.3357761,4.1410584,,,,,,,,,,,,,, -47827,,,0.6209570169448853,1.705472469329834,0.5658599734306335,1.9722998142242432,50000.0,0.4462000131607055,2.595702171325684,10000.0,21903.900598526,23582.35942387581,21903.900598526,1673.292890548706,2.5029873847961426,0.0 -47900,1.6017423,3.401939,,,,,,,,,,,,,, -48000,1.4979172,4.1985188,,,,,,,,,,,,,, -48100,1.4862152,3.444151,,,,,,,,,,,,,, -48200,1.5541459,3.5740254,,,,,,,,,,,,,, -48300,1.4973286,3.4670286,,,,,,,,,,,,,, -48400,1.5707868,3.3863096,,,,,,,,,,,,,, -48500,1.585821,3.2909195,,,,,,,,,,,,,, -48600,1.4593344,3.3759847,,,,,,,,,,,,,, -48700,1.499682,3.3561893,,,,,,,,,,,,,, -48743,,,0.604785144329071,1.7722070217132568,0.568120002746582,1.9483237266540527,50000.0,0.4463000297546386,2.602745771408081,10000.0,22323.86534023285,24034.65631222725,22323.86534023285,1705.5355398654938,2.5407254695892334,0.0 -48800,1.573317,3.388006,,,,,,,,,,,,,, -48900,1.4965612,4.0902386,,,,,,,,,,,,,, -49000,1.5461962,4.5849137,,,,,,,,,,,,,, -49100,1.6819954,3.297087,,,,,,,,,,,,,, -49200,1.5698278,3.4179988,,,,,,,,,,,,,, -49300,1.3285198,4.569746,,,,,,,,,,,,,, -49400,1.5678785,3.4202433,,,,,,,,,,,,,, -49500,1.7696354,3.2264585,,,,,,,,,,,,,, -49600,1.3850915,5.49875,,,,,,,,,,,,,, -49662,,,0.6148828268051147,1.721850037574768,0.5727399587631226,1.9160726070404053,50000.0,0.4552000164985657,2.5426297187805176,10000.0,22743.817096471783,24488.13495373726,22743.817096471783,1738.979534626007,2.573563814163208,0.0 -49700,1.7171229,3.3803363,,,,,,,,,,,,,, -49800,1.4569743,3.2708335,,,,,,,,,,,,,, -49900,1.4875256,3.4834661,,,,,,,,,,,,,, -50000,1.555485,3.3542876,,,,,,,,,,,,,, -50100,1.0631858,5.4601264,,,,,,,,,,,,,, -50200,1.4267907,3.3368812,,,,,,,,,,,,,, -50300,1.709091,3.6951108,,,,,,,,,,,,,, -50400,1.2624861,5.568554,,,,,,,,,,,,,, -50500,1.556989,3.4705412,,,,,,,,,,,,,, -50580,,,0.61865234375,1.7359861135482788,0.5676199793815613,1.9560025930404663,50000.0,0.4517000317573547,2.5810353755950928,10000.0,23163.91581749916,24939.45225691796,23163.91581749916,1770.112517118454,2.60744047164917,0.0 -50600,1.5419164,3.4415534,,,,,,,,,,,,,, -50700,1.4793766,4.0796075,,,,,,,,,,,,,, -50800,1.7644651,3.2841227,,,,,,,,,,,,,, -50900,1.589322,3.3186433,,,,,,,,,,,,,, -51000,1.5900834,3.3651905,,,,,,,,,,,,,, -51100,1.6353259,3.360404,,,,,,,,,,,,,, -51200,1.4390594,5.0207477,,,,,,,,,,,,,, -51300,1.5175154,3.3321788,,,,,,,,,,,,,, -51400,1.4703751,3.4237845,,,,,,,,,,,,,, -51498,,,0.6126366853713989,1.7147948741912842,0.5749599933624268,1.904335618019104,50000.0,0.4583000242710113,2.5383946895599365,10000.0,23584.244931459427,25392.24175477028,23584.244931459427,1802.488485097885,2.6418302059173584,0.0 -51500,1.4742122,3.314662,,,,,,,,,,,,,, -51600,1.2159933,5.5686464,,,,,,,,,,,,,, -51700,1.4709166,3.9997272,,,,,,,,,,,,,, -51800,1.2320896,5.1593733,,,,,,,,,,,,,, -51900,1.6626241,3.245416,,,,,,,,,,,,,, -52000,1.4654317,3.758319,,,,,,,,,,,,,, -52100,1.065374,5.0023575,,,,,,,,,,,,,, -52200,1.5825535,3.3184993,,,,,,,,,,,,,, -52300,1.562338,3.290513,,,,,,,,,,,,,, -52400,1.2020496,4.2206554,,,,,,,,,,,,,, -52414,,,0.6123437285423279,1.7380605936050415,0.5724999904632568,1.917758822441101,50000.0,0.4573000073432922,2.529676914215088,10000.0,24004.364980459213,25843.898319244385,24004.364980459213,1833.9402480125427,2.6762707233428955,0.0 -52500,1.526335,3.760602,,,,,,,,,,,,,, -52600,1.4710318,5.187855,,,,,,,,,,,,,, -52700,1.1174605,5.451118,,,,,,,,,,,,,, -52800,1.6338079,3.6643493,,,,,,,,,,,,,, -52900,1.3553071,4.2444506,,,,,,,,,,,,,, -53000,1.4719884,3.4658866,,,,,,,,,,,,,, -53100,1.7312845,3.4841673,,,,,,,,,,,,,, -53200,1.4284487,3.4632292,,,,,,,,,,,,,, -53300,1.3911797,3.7805953,,,,,,,,,,,,,, -53330,,,0.6242578029632568,1.6958975791931152,0.5759199857711792,1.9176926612854004,50000.0,0.45210000872612,2.568796396255493,10000.0,24424.607334136963,26295.831995487213,24424.607334136963,1865.5457472801208,2.711615800857544,0.0 -53400,1.5007131,3.6788783,,,,,,,,,,,,,, -53500,1.2996526,5.58373,,,,,,,,,,,,,, -53600,1.2276765,4.5345545,,,,,,,,,,,,,, -53700,1.526002,3.3672922,,,,,,,,,,,,,, -53800,1.2683412,5.2956476,,,,,,,,,,,,,, -53900,1.1839136,5.0093374,,,,,,,,,,,,,, -54000,1.3787965,4.2298093,,,,,,,,,,,,,, -54100,1.5190911,3.2980747,,,,,,,,,,,,,, -54200,1.672736,3.374403,,,,,,,,,,,,,, -54244,,,0.6257421970367432,1.7023617029190063,0.5797600150108337,1.9058101177215576,50000.0,0.4621000289916992,2.5444371700286865,10000.0,24844.78872013092,26748.55784964561,24844.78872013092,1897.9944295883176,2.7577033042907715,0.0 -54300,1.4274645,3.7123246,,,,,,,,,,,,,, -54400,1.850888,3.5134566,,,,,,,,,,,,,, -54500,1.6566178,3.1497438,,,,,,,,,,,,,, -54600,1.5757823,3.3277454,,,,,,,,,,,,,, -54700,1.421313,3.6207266,,,,,,,,,,,,,, -54800,1.4645401,4.010674,,,,,,,,,,,,,, -54900,1.4798223,3.3963315,,,,,,,,,,,,,, -55000,1.5148294,3.2474911,,,,,,,,,,,,,, -55100,1.4701324,4.220255,,,,,,,,,,,,,, -55162,,,0.6188085675239563,1.706288456916809,0.5832399725914001,1.87911856174469,50000.0,0.4599000215530395,2.5232367515563965,10000.0,25264.96487569809,27200.61718583107,25264.96487569809,1929.7939734458923,2.7902755737304688,0.0 -55200,1.5917721,3.3795276,,,,,,,,,,,,,, -55300,1.5512059,3.3808527,,,,,,,,,,,,,, -55400,1.6579902,3.3329573,,,,,,,,,,,,,, -55500,1.1479219,5.506762,,,,,,,,,,,,,, -55600,1.485941,3.440207,,,,,,,,,,,,,, -55700,1.4439579,4.5063744,,,,,,,,,,,,,, -55800,1.3989165,4.5761695,,,,,,,,,,,,,, -55900,1.5483122,3.2861369,,,,,,,,,,,,,, -56000,1.1715782,5.3272448,,,,,,,,,,,,,, -56078,,,0.6207226514816284,1.7265022993087769,0.5751399993896484,1.932764768600464,50000.0,0.4586000144481659,2.569799423217773,10000.0,25685.004106283188,27652.70127034188,25685.004106283188,1961.75643324852,2.822684526443481,0.0 -56100,1.4635913,3.5831366,,,,,,,,,,,,,, -56200,1.5100672,3.429462,,,,,,,,,,,,,, -56300,1.3808672,4.3546505,,,,,,,,,,,,,, -56400,1.6637408,3.3987281,,,,,,,,,,,,,, -56500,1.4056547,3.7963235,,,,,,,,,,,,,, -56600,1.4944886,3.23712,,,,,,,,,,,,,, -56700,1.5510169,3.281042,,,,,,,,,,,,,, -56800,1.6808589,3.3897707,,,,,,,,,,,,,, -56900,1.6987568,3.320021,,,,,,,,,,,,,, -56994,,,0.6523046493530273,1.600400686264038,0.580299973487854,1.9175043106079104,50000.0,0.4688000082969665,2.537362575531006,10000.0,26105.01860499382,28105.36644411087,26105.01860499382,1994.321251630783,2.8584837913513184,0.0 -57000,1.6479291,3.2794771,,,,,,,,,,,,,, -57100,1.6774228,3.334105,,,,,,,,,,,,,, -57200,1.3409693,3.716226,,,,,,,,,,,,,, -57300,1.4899281,3.4566522,,,,,,,,,,,,,, -57400,1.5076426,3.6265461,,,,,,,,,,,,,, -57500,1.6138575,3.281505,,,,,,,,,,,,,, -57600,1.2396686,4.5903263,,,,,,,,,,,,,, -57700,1.3647746,3.5231087,,,,,,,,,,,,,, -57800,1.5098995,3.456223,,,,,,,,,,,,,, -57900,1.5525599,3.2900114,,,,,,,,,,,,,, -57911,,,0.6321093440055847,1.6738526821136477,0.5862999558448792,1.871827960014344,50000.0,0.4709000289440155,2.4993984699249268,10000.0,26525.360887289047,28557.5066075325,26525.360887289047,2026.0364346504207,2.890997171401977,0.0 -58000,1.2573082,5.469857,,,,,,,,,,,,,, -58100,1.2713494,5.471533,,,,,,,,,,,,,, -58200,1.5827017,3.2857194,,,,,,,,,,,,,, -58300,1.6768351,3.3730657,,,,,,,,,,,,,, -58400,1.1835821,4.849215,,,,,,,,,,,,,, -58500,1.501511,3.454235,,,,,,,,,,,,,, -58600,1.240306,5.1452723,,,,,,,,,,,,,, -58700,1.9301594,3.3488495,,,,,,,,,,,,,, -58800,1.1533949,5.090337,,,,,,,,,,,,,, -58829,,,0.6249608993530273,1.6657127141952517,0.5823599696159363,1.864213824272156,50000.0,0.4593000113964081,2.511036157608032,10000.0,26945.452782392505,29010.55766916275,26945.452782392505,2058.912977695465,2.924208402633667,0.0 -58900,1.1379532,5.3050876,,,,,,,,,,,,,, -59000,1.6353639,3.3065126,,,,,,,,,,,,,, -59100,1.7773212,3.2477574,,,,,,,,,,,,,, -59200,1.9620662,3.3938391,,,,,,,,,,,,,, -59300,1.1813407,4.836458,,,,,,,,,,,,,, -59400,1.5299909,3.3331485,,,,,,,,,,,,,, -59500,1.2876353,3.8396664,,,,,,,,,,,,,, -59600,1.7249727,3.340471,,,,,,,,,,,,,, -59700,1.6780605,3.4979653,,,,,,,,,,,,,, -59747,,,0.643359363079071,1.587910771369934,0.5870400071144104,1.8554558753967283,50000.0,0.4684000313282013,2.4895777702331543,10000.0,27365.50388979912,29464.2038834095,27365.50388979912,2092.420877933502,2.961343050003052,0.0 -59800,1.2237308,5.287975,,,,,,,,,,,,,, -59900,1.7629968,3.3856697,,,,,,,,,,,,,, -60000,1.3262681,4.701234,,,,,,,,,,,,,, -60100,1.5615702,3.3551311,,,,,,,,,,,,,, -60200,1.534223,3.5195143,,,,,,,,,,,,,, -60300,1.6304824,3.2406127,,,,,,,,,,,,,, -60400,1.564869,3.4659796,,,,,,,,,,,,,, -60500,1.4199876,3.9778087,,,,,,,,,,,,,, -60600,1.7126867,3.353728,,,,,,,,,,,,,, -60664,,,0.6283788681030273,1.662112593650818,0.585319995880127,1.851845622062683,50000.0,0.4672000110149383,2.470869541168213,10000.0,27785.721799850464,29917.078412532806,27785.721799850464,2124.993516921997,2.995525598526001,0.0 -60700,1.2651597,4.27878,,,,,,,,,,,,,, -60800,1.6744623,3.299244,,,,,,,,,,,,,, -60900,1.3661796,5.3912163,,,,,,,,,,,,,, -61000,1.3674312,3.8534722,,,,,,,,,,,,,, -61100,1.5394219,3.1948094,,,,,,,,,,,,,, -61200,1.716275,3.2888973,,,,,,,,,,,,,, -61300,1.3270947,5.044715,,,,,,,,,,,,,, -61400,1.620699,3.3547983,,,,,,,,,,,,,, -61500,1.579594,3.4058688,,,,,,,,,,,,,, -61582,,,0.6338085532188416,1.6124237775802612,0.591219961643219,1.8236290216445925,50000.0,0.4756000339984894,2.4673588275909424,10000.0,28205.97324538231,30371.00706982613,28205.97324538231,2158.5834772586823,3.031461477279663,0.0 -61600,1.318848,4.710682,,,,,,,,,,,,,, -61700,1.4615494,3.289896,,,,,,,,,,,,,, -61800,1.7577074,3.3448684,,,,,,,,,,,,,, -61900,1.6087992,3.5951414,,,,,,,,,,,,,, -62000,1.2588181,4.7708797,,,,,,,,,,,,,, -62100,1.6307672,3.3280149,,,,,,,,,,,,,, -62200,1.2683376,4.9712687,,,,,,,,,,,,,, -62300,1.3588164,3.588304,,,,,,,,,,,,,, -62400,1.5837841,3.5283942,,,,,,,,,,,,,, -62500,1.4510052,4.0186725,,,,,,,,,,,,,, -62501,,,0.6354882717132568,1.626826286315918,0.5827599763870239,1.8661490678787231,50000.0,0.4697000086307525,2.5128331184387207,10000.0,28626.62084031105,30822.73681116104,28626.62084031105,2189.576506137848,3.069997310638428,0.0 -62600,1.4946817,4.709477,,,,,,,,,,,,,, -62700,1.5174055,3.476295,,,,,,,,,,,,,, -62800,1.3736507,5.429424,,,,,,,,,,,,,, -62900,1.5250502,3.2226224,,,,,,,,,,,,,, -63000,1.5880467,3.3031383,,,,,,,,,,,,,, -63100,1.6276113,3.1916244,,,,,,,,,,,,,, -63200,1.2047137,4.880866,,,,,,,,,,,,,, -63300,1.3018368,5.327491,,,,,,,,,,,,,, -63400,1.2481849,5.3484173,,,,,,,,,,,,,, -63419,,,0.6289843320846558,1.6643264293670654,0.583899974822998,1.8471851348876955,50000.0,0.4674000144004822,2.485158920288086,10000.0,29046.67929458618,31276.06823301316,29046.67929458618,2222.766315698624,3.1039247512817383,0.0 -63500,1.5480659,3.1911337,,,,,,,,,,,,,, -63600,1.4060489,4.868266,,,,,,,,,,,,,, -63700,1.6969918,3.387775,,,,,,,,,,,,,, -63800,1.6677219,3.2921746,,,,,,,,,,,,,, -63900,1.3822404,4.012583,,,,,,,,,,,,,, -64000,1.7348183,3.3107417,,,,,,,,,,,,,, -64100,1.6575973,3.3037832,,,,,,,,,,,,,, -64200,1.4777083,4.3071923,,,,,,,,,,,,,, -64300,1.5163478,3.1289392,,,,,,,,,,,,,, -64335,,,0.6406640410423279,1.5750188827514648,0.5931800007820129,1.7844353914260864,50000.0,0.4731000363826751,2.420977830886841,10000.0,29466.69498157501,31728.67764186859,29466.69498157501,2255.2734336853027,3.1402459144592285,0.0 -64400,1.9681138,3.3965642,,,,,,,,,,,,,, -64500,1.8716023,3.1748013,,,,,,,,,,,,,, -64600,1.8503009,3.4183488,,,,,,,,,,,,,, -64700,1.3597658,5.1984186,,,,,,,,,,,,,, -64800,1.6754881,3.3230715,,,,,,,,,,,,,, -64900,1.437002,3.6512132,,,,,,,,,,,,,, -65000,1.6332254,3.3660078,,,,,,,,,,,,,, -65100,1.7520053,3.2256372,,,,,,,,,,,,,, -65200,1.7442238,3.262257,,,,,,,,,,,,,, -65251,,,0.6419140696525574,1.613297700881958,0.5929799675941467,1.845779657363892,50000.0,0.4781000316143036,2.480360984802246,10000.0,29886.907014369965,32181.92774629593,29886.907014369965,2288.2237129211426,3.1775012016296387,0.0 -65300,1.6833498,3.1952174,,,,,,,,,,,,,, -65400,1.9311984,3.2977822,,,,,,,,,,,,,, -65500,1.8352355,3.4253097,,,,,,,,,,,,,, -65600,1.5387063,4.179469,,,,,,,,,,,,,, -65700,1.630857,3.5839076,,,,,,,,,,,,,, -65800,1.7110554,3.1863055,,,,,,,,,,,,,, -65900,1.3162575,5.418335,,,,,,,,,,,,,, -66000,1.3790668,3.6491528,,,,,,,,,,,,,, -66100,1.4806806,5.1260138,,,,,,,,,,,,,, -66167,,,0.6465820074081421,1.5871703624725342,0.596560001373291,1.818373203277588,50000.0,0.4720000326633453,2.46311092376709,10000.0,30307.044764518738,32634.469495773315,30307.044764518738,2320.53994345665,3.215022325515747,0.0 -66200,1.3427719,4.7463703,,,,,,,,,,,,,, -66300,1.443774,3.98584,,,,,,,,,,,,,, -66400,1.3070178,4.822782,,,,,,,,,,,,,, -66500,1.448916,5.380732,,,,,,,,,,,,,, -66600,1.5246483,3.2284398,,,,,,,,,,,,,, -66700,1.7351897,3.2777636,,,,,,,,,,,,,, -66800,1.8433689,3.3067906,,,,,,,,,,,,,, -66900,1.696687,3.2452233,,,,,,,,,,,,,, -67000,1.5979866,3.3177762,,,,,,,,,,,,,, -67084,,,0.6468554735183716,1.5900592803955078,0.6029599905014038,1.7880793809890747,50000.0,0.4830000102519989,2.422799825668335,10000.0,30727.09972333908,33086.314239263535,30727.09972333908,2352.242401361465,3.251006841659546,0.0 -67100,1.5354953,3.5843954,,,,,,,,,,,,,, -67200,1.526425,3.440453,,,,,,,,,,,,,, -67300,1.670199,3.3618546,,,,,,,,,,,,,, -67400,1.3139894,4.8527117,,,,,,,,,,,,,, -67500,1.7535243,3.301386,,,,,,,,,,,,,, -67600,1.298258,4.3330097,,,,,,,,,,,,,, -67700,1.5317271,3.4187186,,,,,,,,,,,,,, -67800,1.700497,3.263038,,,,,,,,,,,,,, -67900,1.3918014,4.236579,,,,,,,,,,,,,, -68000,1.5477855,3.3963826,,,,,,,,,,,,,, -68001,,,0.6465234160423279,1.5876259803771973,0.5993599891662598,1.807966232299805,50000.0,0.4832000136375427,2.4284214973449707,10000.0,31147.37824487686,33539.42764592171,31147.37824487686,2384.98783993721,3.290488004684448,0.0 -68100,1.6550958,3.2521563,,,,,,,,,,,,,, -68200,1.6089232,3.247099,,,,,,,,,,,,,, -68300,1.4186826,4.490781,,,,,,,,,,,,,, -68400,1.7976002,3.2560167,,,,,,,,,,,,,, -68500,1.2962102,4.7075653,,,,,,,,,,,,,, -68600,1.3680483,5.401349,,,,,,,,,,,,,, -68700,1.485674,3.8374527,,,,,,,,,,,,,, -68800,1.6974443,3.2477887,,,,,,,,,,,,,, -68900,1.9216346,3.3511143,,,,,,,,,,,,,, -68916,,,0.6715624928474426,1.442800521850586,0.6010000109672546,1.7613905668258667,50000.0,0.4843000173568725,2.398958206176758,10000.0,31567.33975481987,33991.27369570732,31567.33975481987,2416.788944482804,3.324157476425171,0.0 -69000,1.3112519,4.032984,,,,,,,,,,,,,, -69100,1.2795674,4.870262,,,,,,,,,,,,,, -69200,1.3879379,4.43514,,,,,,,,,,,,,, -69300,1.6171017,3.537228,,,,,,,,,,,,,, -69400,1.816656,4.827389,,,,,,,,,,,,,, -69500,1.5313609,3.51971,,,,,,,,,,,,,, -69600,1.4693667,3.889073,,,,,,,,,,,,,, -69700,1.7106397,3.3528323,,,,,,,,,,,,,, -69800,2.0579116,3.2023935,,,,,,,,,,,,,, -69831,,,0.6374413967132568,1.655277967453003,0.5964199900627136,1.832472443580628,50000.0,0.4826000332832336,2.457373142242432,10000.0,31987.69392681122,34443.77940249443,31987.69392681122,2448.853832483292,3.36092495918274,0.0 -69900,1.6538374,3.124631,,,,,,,,,,,,,, -70000,1.5361758,4.0002017,,,,,,,,,,,,,, -70100,1.8116349,3.2063568,,,,,,,,,,,,,, -70200,1.4100574,3.7773194,,,,,,,,,,,,,, -70300,1.5995381,3.1858208,,,,,,,,,,,,,, -70400,1.4924531,4.5080028,,,,,,,,,,,,,, -70500,1.4441769,5.2876225,,,,,,,,,,,,,, -70600,1.6082212,3.1863875,,,,,,,,,,,,,, -70700,1.7517382,3.2918627,,,,,,,,,,,,,, -70747,,,0.6484960913658142,1.5486077070236206,0.6013599634170532,1.7578142881393433,50000.0,0.4763000309467315,2.390446662902832,10000.0,32407.676120996475,34893.63464021683,32407.676120996475,2478.640973091125,3.3965108394622803,0.0 -70800,1.8041427,3.1858451,,,,,,,,,,,,,, -70900,1.3972446,4.4930124,,,,,,,,,,,,,, -71000,1.5615801,3.4273138,,,,,,,,,,,,,, -71100,1.6609179,5.301161,,,,,,,,,,,,,, -71200,1.6653383,3.245964,,,,,,,,,,,,,, -71300,1.3330444,4.4274626,,,,,,,,,,,,,, -71400,1.8310331,3.234802,,,,,,,,,,,,,, -71500,1.6213106,3.2367725,,,,,,,,,,,,,, -71600,1.4131039,4.1584377,,,,,,,,,,,,,, -71662,,,0.6622851490974426,1.493794083595276,0.6013199687004089,1.7728713750839231,50000.0,0.4798000156879425,2.4019968509674072,10000.0,32827.82526350021,35345.7574763298,32827.82526350021,2510.5186746120453,3.442301034927368,0.0 -71700,1.7173661,3.1596913,,,,,,,,,,,,,, -71800,1.4064031,4.023472,,,,,,,,,,,,,, -71900,1.4953665,3.5334477,,,,,,,,,,,,,, -72000,1.5282952,3.842487,,,,,,,,,,,,,, -72100,1.6956303,3.6954548,,,,,,,,,,,,,, -72200,2.090251,3.375944,,,,,,,,,,,,,, -72300,1.8121653,3.2271855,,,,,,,,,,,,,, -72400,1.6822222,3.2701154,,,,,,,,,,,,,, -72500,1.3544328,5.1355405,,,,,,,,,,,,,, -72578,,,0.6486914157867432,1.5760599374771118,0.6078199744224548,1.7598813772201538,50000.0,0.4820000231266022,2.406587600708008,10000.0,33247.85134124756,35798.13573217392,33247.85134124756,2542.784086465836,3.47953462600708,0.0 -72600,1.6515391,3.1929765,,,,,,,,,,,,,, -72700,1.3360025,5.125786,,,,,,,,,,,,,, -72800,1.9727094,3.2502935,,,,,,,,,,,,,, -72900,1.5173566,4.865423,,,,,,,,,,,,,, -73000,1.4769864,3.851155,,,,,,,,,,,,,, -73100,1.6295668,3.105083,,,,,,,,,,,,,, -73200,1.6356999,3.1686652,,,,,,,,,,,,,, -73300,1.5309441,3.088005,,,,,,,,,,,,,, -73400,1.602505,3.1410594,,,,,,,,,,,,,, -73498,,,0.650683581829071,1.5488523244857788,0.602840006351471,1.7498250007629397,50000.0,0.4812000095844269,2.388460159301758,10000.0,33668.151161670685,36250.21720933914,33668.151161670685,2574.4699428081512,3.524752616882324,0.0 -73500,1.5013123,3.4901707,,,,,,,,,,,,,, -73600,1.6857681,3.1776223,,,,,,,,,,,,,, -73700,1.7782314,3.2293289,,,,,,,,,,,,,, -73800,1.6757787,3.196283,,,,,,,,,,,,,, -73900,1.5576495,4.83078,,,,,,,,,,,,,, -74000,1.5554283,3.9100277,,,,,,,,,,,,,, -74100,1.610766,3.3653078,,,,,,,,,,,,,, -74200,1.3812904,5.302803,,,,,,,,,,,,,, -74300,1.8349675,3.1484056,,,,,,,,,,,,,, -74400,1.8348508,3.3047774,,,,,,,,,,,,,, -74415,,,0.6654492020606995,1.506258487701416,0.6062999963760376,1.7637112140655518,50000.0,0.4891000092029571,2.39551329612732,10000.0,34088.156512498856,36701.84150767326,34088.156512498856,2606.0012538433075,3.561843156814575,0.0 -74500,1.8066292,3.3273344,,,,,,,,,,,,,, -74600,1.9016178,3.3327155,,,,,,,,,,,,,, -74700,1.7958194,3.2642765,,,,,,,,,,,,,, -74800,1.7868466,3.0291712,,,,,,,,,,,,,, -74900,1.7136289,3.912473,,,,,,,,,,,,,, -75000,1.6645925,3.1116073,,,,,,,,,,,,,, -75100,1.6822424,3.1841023,,,,,,,,,,,,,, -75200,1.6365032,3.465344,,,,,,,,,,,,,, -75300,1.4783136,4.9582224,,,,,,,,,,,,,, -75333,,,0.6486523151397705,1.547737956047058,0.612500011920929,1.729038119316101,50000.0,0.4856000244617462,2.3694193363189697,10000.0,34508.42599821091,37156.241079092026,34508.42599821091,2640.042801618576,3.6005239486694336,0.0 -75400,1.7463105,3.4336448,,,,,,,,,,,,,, -75500,1.5652117,3.0963082,,,,,,,,,,,,,, -75600,1.5790474,3.22674,,,,,,,,,,,,,, -75700,1.593408,5.3715615,,,,,,,,,,,,,, -75800,1.6040981,3.385611,,,,,,,,,,,,,, -75900,1.3925562,4.8287773,,,,,,,,,,,,,, -76000,1.6520145,3.1672382,,,,,,,,,,,,,, -76100,1.7808805,3.272676,,,,,,,,,,,,,, -76200,1.7843312,3.098418,,,,,,,,,,,,,, -76251,,,0.6528710722923279,1.5526872873306274,0.6100199818611145,1.7567421197891235,50000.0,0.4861000180244446,2.3885414600372314,10000.0,34928.72370290756,37609.87198114395,34928.72370290756,2673.288207530976,3.637963056564331,0.0 -76300,1.6771531,3.2957187,,,,,,,,,,,,,, -76400,1.540866,3.0724304,,,,,,,,,,,,,, -76500,1.7954369,3.1683767,,,,,,,,,,,,,, -76600,1.808926,3.2934957,,,,,,,,,,,,,, -76700,1.761253,3.1325526,,,,,,,,,,,,,, -76800,1.6368092,5.3034387,,,,,,,,,,,,,, -76900,1.7570955,3.1017563,,,,,,,,,,,,,, -77000,1.4927105,5.364896,,,,,,,,,,,,,, -77100,1.4475417,4.737903,,,,,,,,,,,,,, -77167,,,0.6625585556030273,1.513722538948059,0.6113399863243103,1.7449710369110107,50000.0,0.4903000295162201,2.388901948928833,10000.0,35348.81794667244,38064.8754966259,35348.81794667244,2708.1096754074097,3.672785997390747,0.0 -77200,1.3846581,5.1611314,,,,,,,,,,,,,, -77300,1.4909252,4.557972,,,,,,,,,,,,,, -77400,1.7104437,3.1513536,,,,,,,,,,,,,, -77500,1.4061776,4.8387136,,,,,,,,,,,,,, -77600,1.6678412,3.8369513,,,,,,,,,,,,,, -77700,1.6322591,3.1935363,,,,,,,,,,,,,, -77800,1.7802988,4.2238297,,,,,,,,,,,,,, -77900,1.6367863,3.230103,,,,,,,,,,,,,, -78000,1.8328434,3.1595683,,,,,,,,,,,,,, -78083,,,0.6606640219688416,1.493207335472107,0.6092599630355835,1.7219427824020386,50000.0,0.4865000247955322,2.363787174224853,10000.0,35768.99364566803,38517.90196752548,35768.99364566803,2740.87061214447,3.71204137802124,0.0 -78100,1.7008513,3.108632,,,,,,,,,,,,,, -78200,1.5362841,3.2880263,,,,,,,,,,,,,, -78300,1.4002304,4.867584,,,,,,,,,,,,,, -78400,1.8607972,3.1387215,,,,,,,,,,,,,, -78500,1.4380219,5.1060815,,,,,,,,,,,,,, -78600,1.8783494,3.192944,,,,,,,,,,,,,, -78700,1.6532404,3.3974066,,,,,,,,,,,,,, -78800,1.7567469,3.1325,,,,,,,,,,,,,, -78900,1.6619434,3.2356403,,,,,,,,,,,,,, -78998,,,0.659375011920929,1.5171035528182983,0.6125400066375732,1.7171244621276855,50000.0,0.4889000356197357,2.359149217605591,10000.0,36189.29687142372,38971.67401695252,36189.29687142372,2774.249065160752,3.752545356750488,0.0 -79000,1.7125497,3.4889216,,,,,,,,,,,,,, -79100,1.7348938,3.175332,,,,,,,,,,,,,, -79200,1.7991322,3.3263655,,,,,,,,,,,,,, -79300,1.5703436,4.238258,,,,,,,,,,,,,, -79400,1.7587876,3.4126391,,,,,,,,,,,,,, -79500,1.5176748,3.7809567,,,,,,,,,,,,,, -79600,1.6136312,3.4804492,,,,,,,,,,,,,, -79700,1.7107482,3.1413915,,,,,,,,,,,,,, -79800,1.9440541,3.1619968,,,,,,,,,,,,,, -79900,1.9154184,3.2237267,,,,,,,,,,,,,, -79915,,,0.66845703125,1.4566603899002075,0.6201399564743042,1.6755112409591677,50000.0,0.4943000376224518,2.326002836227417,10000.0,36609.45451402664,39425.04754161835,36609.45451402664,2807.375717639923,3.791879415512085,0.0 -80000,1.5263772,4.1394877,,,,,,,,,,,,,, -80100,1.7164487,3.2786498,,,,,,,,,,,,,, -80200,1.4235659,4.8695354,,,,,,,,,,,,,, -80300,1.8184259,3.3340988,,,,,,,,,,,,,, -80400,1.8950832,3.3001728,,,,,,,,,,,,,, -80500,1.6425521,3.1222847,,,,,,,,,,,,,, -80600,1.7065797,3.1176844,,,,,,,,,,,,,, -80700,1.5987556,3.2979832,,,,,,,,,,,,,, -80800,1.729171,3.2525027,,,,,,,,,,,,,, -80832,,,0.6868945360183716,1.438260197639465,0.6148200035095215,1.7528892755508425,50000.0,0.4910000264644623,2.404242038726806,10000.0,37029.78786754608,39875.934716939926,37029.78786754608,2837.8364627361298,3.83516001701355,0.0 -80900,1.7866976,3.1709406,,,,,,,,,,,,,, -81000,1.9300715,3.1410923,,,,,,,,,,,,,, -81100,2.009627,3.6841145,,,,,,,,,,,,,, -81200,1.8738381,3.1413586,,,,,,,,,,,,,, -81300,1.9000417,3.130574,,,,,,,,,,,,,, -81400,1.682197,3.1830628,,,,,,,,,,,,,, -81500,1.6566184,3.1163807,,,,,,,,,,,,,, -81600,1.8719013,3.5020156,,,,,,,,,,,,,, -81700,1.4329119,5.2545795,,,,,,,,,,,,,, -81747,,,0.6599804759025574,1.493338108062744,0.6163600087165833,1.6874927282333374,50000.0,0.4950000345706939,2.3316457271575928,10000.0,37449.77145719528,40329.72599077225,37449.77145719528,2871.551378250122,3.878218173980713,0.0 -81800,1.6999662,3.0709598,,,,,,,,,,,,,, -81900,1.7938137,3.1134987,,,,,,,,,,,,,, -82000,1.5911617,3.8260655,,,,,,,,,,,,,, -82100,1.8832688,3.1328912,,,,,,,,,,,,,, -82200,1.4259541,4.519056,,,,,,,,,,,,,, -82300,1.7714416,3.102735,,,,,,,,,,,,,, -82400,1.9398216,3.105968,,,,,,,,,,,,,, -82500,1.7497257,5.1976914,,,,,,,,,,,,,, -82600,1.7057831,3.2352521,,,,,,,,,,,,,, -82663,,,0.6682031154632568,1.454301834106445,0.6208800077438354,1.668615698814392,50000.0,0.4934000372886657,2.3003392219543457,10000.0,37869.73485660553,40781.92210578919,37869.73485660553,2903.697116851806,3.915627241134644,0.0 -82700,1.9541173,3.014242,,,,,,,,,,,,,, -82800,1.8958436,3.0342577,,,,,,,,,,,,,, -82900,1.6256971,3.9347975,,,,,,,,,,,,,, -83000,1.773745,3.1484628,,,,,,,,,,,,,, -83100,1.5523107,4.8857093,,,,,,,,,,,,,, -83200,1.489171,4.5022426,,,,,,,,,,,,,, -83300,1.6247008,4.780937,,,,,,,,,,,,,, -83400,1.6618234,3.4225078,,,,,,,,,,,,,, -83500,1.5515345,3.946847,,,,,,,,,,,,,, -83581,,,0.6833788752555847,1.3950210809707642,0.6195200085639954,1.6812617778778076,50000.0,0.4946000277996063,2.320722818374634,10000.0,38289.80323362351,41234.16646409035,38289.80323362351,2935.783272266388,3.954903364181519,0.0 -83600,1.4751854,4.5644755,,,,,,,,,,,,,, -83700,1.6661078,3.13657,,,,,,,,,,,,,, -83800,1.7694424,3.1006641,,,,,,,,,,,,,, -83900,1.63269,4.464981,,,,,,,,,,,,,, -84000,1.818161,3.0963776,,,,,,,,,,,,,, -84100,1.6487902,3.6737604,,,,,,,,,,,,,, -84200,1.6231992,3.6641688,,,,,,,,,,,,,, -84300,1.7712626,3.1225448,,,,,,,,,,,,,, -84400,1.8578254,3.0541022,,,,,,,,,,,,,, -84496,,,0.6632617115974426,1.50346839427948,0.6191200017929077,1.6938670873641968,50000.0,0.4934000372886657,2.347602367401123,10000.0,38709.94033956528,41688.10350394249,38709.94033956528,2969.496173620224,3.991729259490967,0.0 -84500,1.857396,3.4148302,,,,,,,,,,,,,, -84600,1.5686443,4.6696677,,,,,,,,,,,,,, -84700,1.579488,4.5236416,,,,,,,,,,,,,, -84800,1.813728,3.2532675,,,,,,,,,,,,,, -84900,1.7431366,3.0899367,,,,,,,,,,,,,, -85000,1.4472826,4.513537,,,,,,,,,,,,,, -85100,1.5714025,4.0094814,,,,,,,,,,,,,, -85200,1.4606314,5.177404,,,,,,,,,,,,,, -85300,1.6681443,3.4199374,,,,,,,,,,,,,, -85400,2.0904365,3.181889,,,,,,,,,,,,,, -85412,,,0.6689453125,1.475509762763977,0.6204000115394592,1.6864943504333496,50000.0,0.4971000254154205,2.3224990367889404,10000.0,39129.99418258667,42139.39819288254,39129.99418258667,3000.645429611206,4.033493041992188,0.0 -85500,2.073373,3.4492984,,,,,,,,,,,,,, -85600,2.0453105,3.1513371,,,,,,,,,,,,,, -85700,1.7754952,3.311388,,,,,,,,,,,,,, -85800,1.6200517,3.7312033,,,,,,,,,,,,,, -85900,1.6858072,3.3159795,,,,,,,,,,,,,, -86000,1.4455512,4.8978934,,,,,,,,,,,,,, -86100,1.7814286,3.148181,,,,,,,,,,,,,, -86200,1.66406,3.5769422,,,,,,,,,,,,,, -86300,1.9575465,3.2006037,,,,,,,,,,,,,, -86326,,,0.679492175579071,1.427093505859375,0.627299964427948,1.6685994863510132,50000.0,0.5010000467300415,2.304348945617676,10000.0,39550.02034115791,42593.61324310303,39550.02034115791,3034.7353508472443,4.08236575126648,0.0 -86400,1.5268238,3.9260383,,,,,,,,,,,,,, -86500,1.6930631,3.6267214,,,,,,,,,,,,,, -86600,1.8843715,3.094463,,,,,,,,,,,,,, -86700,1.573886,3.8975554,,,,,,,,,,,,,, -86800,1.8453099,3.1979723,,,,,,,,,,,,,, -86900,1.6792601,5.3313003,,,,,,,,,,,,,, -87000,1.7089127,3.0176234,,,,,,,,,,,,,, -87100,1.9477415,5.243828,,,,,,,,,,,,,, -87200,2.0886698,3.212211,,,,,,,,,,,,,, -87244,,,0.6720312237739563,1.4575451612472534,0.6271199584007263,1.6576414108276367,50000.0,0.5063000321388245,2.290591239929199,10000.0,39969.94578671456,43046.86581158638,39969.94578671456,3067.972613334656,4.12157940864563,0.0 -87300,1.5044844,4.874069,,,,,,,,,,,,,, -87400,1.7673476,3.1313372,,,,,,,,,,,,,, -87500,1.7032719,3.5387788,,,,,,,,,,,,,, -87600,1.7423089,3.473089,,,,,,,,,,,,,, -87700,1.8202572,3.031968,,,,,,,,,,,,,, -87800,2.0540924,3.121814,,,,,,,,,,,,,, -87900,1.9999533,3.0470192,,,,,,,,,,,,,, -88000,2.0466268,3.0592623,,,,,,,,,,,,,, -88100,1.5849838,3.9166749,,,,,,,,,,,,,, -88162,,,0.6704687476158142,1.4411784410476685,0.622439980506897,1.6661144495010376,50000.0,0.4997000098228454,2.302229642868042,10000.0,40389.9857711792,43500.610604286194,40389.9857711792,3101.585325717926,4.163185358047485,0.0 -88200,1.7275318,3.5022588,,,,,,,,,,,,,, -88300,1.7023572,3.6133816,,,,,,,,,,,,,, -88400,1.6385465,4.0689597,,,,,,,,,,,,,, -88500,1.8807175,3.0619617,,,,,,,,,,,,,, -88600,1.9351336,3.128247,,,,,,,,,,,,,, -88700,1.9307436,4.6901937,,,,,,,,,,,,,, -88800,1.6834779,2.982091,,,,,,,,,,,,,, -88900,1.8792859,3.162881,,,,,,,,,,,,,, -89000,1.670045,3.7497852,,,,,,,,,,,,,, -89080,,,0.6823828220367432,1.4180799722671509,0.6261999607086182,1.6651922464370728,50000.0,0.5085000395774841,2.2891640663146973,10000.0,40810.23924612999,43950.955739974976,40810.23924612999,3131.58740067482,4.202484369277954,0.0 -89100,1.5898082,3.7972019,,,,,,,,,,,,,, -89200,1.778465,3.0896015,,,,,,,,,,,,,, -89300,1.9327985,3.2354598,,,,,,,,,,,,,, -89400,1.4522066,4.2763977,,,,,,,,,,,,,, -89500,1.9757093,3.1004636,,,,,,,,,,,,,, -89600,1.8276465,5.2205024,,,,,,,,,,,,,, -89700,1.6149659,3.861472,,,,,,,,,,,,,, -89800,1.591981,3.4207382,,,,,,,,,,,,,, -89900,1.8032732,3.1241684,,,,,,,,,,,,,, -89995,,,0.6862109303474426,1.411870360374451,0.6287199854850769,1.6493552923202517,50000.0,0.5055000185966492,2.2911794185638428,10000.0,41230.43558573723,44403.9231262207,41230.43558573723,3164.264275074005,4.245952367782593,0.0 -90000,1.9237894,3.144911,,,,,,,,,,,,,, -90100,1.6711049,3.4559126,,,,,,,,,,,,,, -90200,1.6886497,3.3763921,,,,,,,,,,,,,, -90300,2.1422412,3.21084,,,,,,,,,,,,,, -90400,1.5709434,4.111516,,,,,,,,,,,,,, -90500,1.9964609,2.985724,,,,,,,,,,,,,, -90600,1.4250097,4.332646,,,,,,,,,,,,,, -90700,1.4697549,4.4849358,,,,,,,,,,,,,, -90800,1.8195878,3.039358,,,,,,,,,,,,,, -90900,1.7521703,3.0402362,,,,,,,,,,,,,, -90912,,,0.6795117259025574,1.41821026802063,0.6332799792289734,1.625242829322815,50000.0,0.5124000310897827,2.259521722793579,10000.0,41650.82543206215,44855.7765583992,41650.82543206215,3195.636889457702,4.286664247512817,0.0 -91000,1.8252832,3.8497577,,,,,,,,,,,,,, -91100,1.4967742,4.0167227,,,,,,,,,,,,,, -91200,1.6543773,4.668436,,,,,,,,,,,,,, -91300,1.9072093,3.268948,,,,,,,,,,,,,, -91400,1.5828292,4.8129873,,,,,,,,,,,,,, -91500,1.4781895,4.585326,,,,,,,,,,,,,, -91600,1.8268625,3.1910338,,,,,,,,,,,,,, -91700,1.8607349,3.1701195,,,,,,,,,,,,,, -91800,1.9459441,3.0380487,,,,,,,,,,,,,, -91828,,,0.6826757788658142,1.4430053234100342,0.6317399740219116,1.6762853860855105,50000.0,0.5081000328063965,2.3019185066223145,10000.0,42070.84104323387,45308.86335706711,42070.84104323387,3228.6166894435883,4.32750678062439,0.0 -91900,1.5810889,4.329269,,,,,,,,,,,,,, -92000,1.8541641,3.0342693,,,,,,,,,,,,,, -92100,1.7673451,3.2764783,,,,,,,,,,,,,, -92200,1.5162559,5.006115,,,,,,,,,,,,,, -92300,1.7266312,2.9446657,,,,,,,,,,,,,, -92400,2.0502923,3.3051753,,,,,,,,,,,,,, -92500,1.9138244,2.925034,,,,,,,,,,,,,, -92600,1.7468055,3.0518088,,,,,,,,,,,,,, -92700,1.6373065,4.948288,,,,,,,,,,,,,, -92745,,,0.7028710842132568,1.3684577941894531,0.6319199800491333,1.6790177822113037,50000.0,0.5100000500679016,2.3166470527648926,10000.0,42490.79935574532,45759.4947450161,42490.79935574532,3259.2013654708862,4.36621356010437,0.0 -92800,1.6779122,3.317134,,,,,,,,,,,,,, -92900,1.9269094,3.1403446,,,,,,,,,,,,,, -93000,1.6535066,5.225213,,,,,,,,,,,,,, -93100,1.9810171,3.04186,,,,,,,,,,,,,, -93200,2.1595745,3.3198602,,,,,,,,,,,,,, -93300,1.6219512,4.6957326,,,,,,,,,,,,,, -93400,1.8712302,3.0251944,,,,,,,,,,,,,, -93500,1.8650787,3.0395236,,,,,,,,,,,,,, -93600,1.6212376,3.7185643,,,,,,,,,,,,,, -93656,,,0.6768945455551147,1.43509840965271,0.6339799761772156,1.6315512657165527,50000.0,0.5134000182151794,2.2607595920562744,10000.0,42910.90014410019,46213.121638059616,42910.90014410019,3292.632405996322,4.411803960800171,0.0 -93700,1.8323809,3.0906346,,,,,,,,,,,,,, -93800,1.5843822,4.7836475,,,,,,,,,,,,,, -93900,1.8235154,4.10947,,,,,,,,,,,,,, -94000,1.8370657,3.032273,,,,,,,,,,,,,, -94100,1.7449023,3.0711663,,,,,,,,,,,,,, -94200,1.9021008,3.0541282,,,,,,,,,,,,,, -94300,1.9575794,3.0287976,,,,,,,,,,,,,, -94400,1.7159189,3.5216603,,,,,,,,,,,,,, -94500,2.15999,3.113888,,,,,,,,,,,,,, -94571,,,0.6860546469688416,1.4129973649978638,0.6372999548912048,1.6310029029846191,50000.0,0.5184000134468079,2.2514212131500244,10000.0,43331.148206710815,46666.15451836586,43331.148206710815,3325.326532363892,4.453327894210815,0.0 -94600,1.8417447,3.0690253,,,,,,,,,,,,,, -94700,1.5541819,5.039957,,,,,,,,,,,,,, -94800,1.8007588,3.0764854,,,,,,,,,,,,,, -94900,1.9081283,3.12521,,,,,,,,,,,,,, -95000,1.8608348,3.1032803,,,,,,,,,,,,,, -95100,1.8925775,2.9299088,,,,,,,,,,,,,, -95200,1.8904431,3.1656244,,,,,,,,,,,,,, -95300,1.5877168,5.182809,,,,,,,,,,,,,, -95400,1.6068149,4.5429325,,,,,,,,,,,,,, -95488,,,0.6981640458106995,1.355225682258606,0.6358199715614319,1.629407286643982,50000.0,0.5152000188827515,2.261467218399048,10000.0,43751.27952218056,47117.99061465264,43751.27952218056,3356.9392170906067,4.495721101760864,0.0 -95500,1.998692,3.7730792,,,,,,,,,,,,,, -95600,1.9583299,2.9481397,,,,,,,,,,,,,, -95700,1.9999796,2.9554007,,,,,,,,,,,,,, -95800,2.060577,3.2475476,,,,,,,,,,,,,, -95900,1.721148,5.1139455,,,,,,,,,,,,,, -96000,1.5277853,4.525172,,,,,,,,,,,,,, -96100,1.859595,2.986272,,,,,,,,,,,,,, -96200,1.7249631,3.2686813,,,,,,,,,,,,,, -96300,2.1690824,3.1786556,,,,,,,,,,,,,, -96400,1.7621473,4.9967103,,,,,,,,,,,,,, -96404,,,0.6825000047683716,1.4293886423110962,0.6386799812316895,1.6273837089538574,50000.0,0.5163000226020813,2.2728638648986816,10000.0,44171.42354559898,47572.40820598602,44171.42354559898,3391.115294933319,4.54344367980957,0.0 -96500,1.805404,2.9585443,,,,,,,,,,,,,, -96600,1.8234571,4.36382,,,,,,,,,,,,,, -96700,2.0271585,3.0424485,,,,,,,,,,,,,, -96800,1.7485065,4.141193,,,,,,,,,,,,,, -96900,2.0270138,3.0275247,,,,,,,,,,,,,, -97000,1.9030422,3.11171,,,,,,,,,,,,,, -97100,1.6372303,3.92334,,,,,,,,,,,,,, -97200,1.8048861,4.36329,,,,,,,,,,,,,, -97300,1.9445271,2.9936504,,,,,,,,,,,,,, -97322,,,0.6897070407867432,1.3623371124267578,0.6406199932098389,1.586683988571167,50000.0,0.515500009059906,2.2348742485046387,10000.0,44591.63816213608,48025.32202601433,44591.63816213608,3423.7243151664734,4.584112644195557,0.0 -97400,1.9519929,2.9969914,,,,,,,,,,,,,, -97500,2.0344884,3.066665,,,,,,,,,,,,,, -97600,1.785989,3.0924175,,,,,,,,,,,,,, -97700,2.0306857,2.9960716,,,,,,,,,,,,,, -97800,1.877346,3.0262487,,,,,,,,,,,,,, -97900,2.0011058,3.0414865,,,,,,,,,,,,,, -98000,1.972406,2.9145408,,,,,,,,,,,,,, -98100,1.8661084,3.0395648,,,,,,,,,,,,,, -98200,1.7759202,4.609671,,,,,,,,,,,,,, -98238,,,0.7014843821525574,1.3189449310302734,0.646399974822998,1.5679948329925537,50000.0,0.5222000479698181,2.2077906131744385,10000.0,45011.66071987152,48479.07020926476,45011.66071987152,3457.3588194847107,4.624876022338867,0.0 -98300,1.8170394,3.270753,,,,,,,,,,,,,, -98400,2.0869005,3.07439,,,,,,,,,,,,,, -98500,2.055218,2.8581991,,,,,,,,,,,,,, -98600,1.5469198,5.0850496,,,,,,,,,,,,,, -98700,1.948411,2.9265606,,,,,,,,,,,,,, -98800,1.9302582,3.0180814,,,,,,,,,,,,,, -98900,2.0676193,2.9755027,,,,,,,,,,,,,, -99000,2.1714776,3.0792646,,,,,,,,,,,,,, -99100,1.7312099,4.2981853,,,,,,,,,,,,,, -99156,,,0.69189453125,1.3856916427612305,0.6421399712562561,1.6010757684707642,50000.0,0.5197000503540039,2.2255241870880127,10000.0,45431.88610672951,48932.92041492462,45431.88610672951,3490.893681287765,4.664891719818115,0.0 -99200,1.6811771,3.9744174,,,,,,,,,,,,,, -99300,1.9194515,2.9926686,,,,,,,,,,,,,, -99400,1.7766179,3.6786318,,,,,,,,,,,,,, -99500,1.7585001,5.2197495,,,,,,,,,,,,,, -99600,2.0125616,3.0850575,,,,,,,,,,,,,, -99700,1.9610108,3.35201,,,,,,,,,,,,,, -99800,1.9187,3.1424212,,,,,,,,,,,,,, -99900,1.8717906,3.0248637,,,,,,,,,,,,,, -100000,1.9749019,2.993653,,,,,,,,,,,,,, -100073,,,0.6983593702316284,1.3234851360321045,0.646619975566864,1.549912452697754,50000.0,0.527999997138977,2.16940712928772,10000.0,45852.21845984459,49385.12797021866,45852.21845984459,3522.6747205257416,4.70860743522644,0.0 -100100,2.362532,3.0746112,,,,,,,,,,,,,, -100200,1.6555529,3.9543707,,,,,,,,,,,,,, -100300,2.081286,3.0256467,,,,,,,,,,,,,, -100400,2.2976036,3.0048144,,,,,,,,,,,,,, -100500,1.6932651,4.4228816,,,,,,,,,,,,,, -100600,1.7250723,3.8039484,,,,,,,,,,,,,, -100700,1.8400393,3.854417,,,,,,,,,,,,,, -100800,1.8101884,3.9149854,,,,,,,,,,,,,, -100900,1.83073,4.230287,,,,,,,,,,,,,, -100985,,,0.6984570026397705,1.3088330030441284,0.6455999612808228,1.5476953983306885,50000.0,0.5232000350952148,2.1762855052948,10000.0,46272.5213842392,49835.674525260925,46272.5213842392,3552.8196897506714,4.7579333782196045,0.0 -101000,2.0767918,2.9717054,,,,,,,,,,,,,, -101100,2.0470817,2.9908192,,,,,,,,,,,,,, -101200,2.2245007,2.9072146,,,,,,,,,,,,,, -101300,1.896166,3.2783825,,,,,,,,,,,,,, -101400,2.0966766,2.909473,,,,,,,,,,,,,, -101500,1.9720939,3.201041,,,,,,,,,,,,,, -101600,2.2632263,3.2558236,,,,,,,,,,,,,, -101700,2.0615597,3.55846,,,,,,,,,,,,,, -101800,1.9388481,5.189231,,,,,,,,,,,,,, -101900,,,0.7024999856948853,1.3207366466522217,0.6457799673080444,1.5591189861297607,50000.0,0.5252000093460083,2.188363552093506,10000.0,46692.711477041245,50287.51734471321,46692.711477041245,3584.373185634613,4.80739164352417,0.0 -101900,1.6751113,4.834204,,,,,,,,,,,,,, -102000,1.7967346,3.1581633,,,,,,,,,,,,,, -102100,1.997042,3.0159101,,,,,,,,,,,,,, -102200,2.096006,3.092787,,,,,,,,,,,,,, -102300,2.1053898,3.0729384,,,,,,,,,,,,,, -102400,2.2001147,2.873367,,,,,,,,,,,,,, -102500,1.9005642,4.9677157,,,,,,,,,,,,,, -102600,2.1079993,2.9626198,,,,,,,,,,,,,, -102700,1.6309438,4.297977,,,,,,,,,,,,,, -102800,1.6221445,5.0483384,,,,,,,,,,,,,, -102818,,,0.6942577958106995,1.4096049070358276,0.644599974155426,1.627078652381897,50000.0,0.5247000455856323,2.24755334854126,10000.0,47113.01589202881,50740.492428302765,47113.01589202881,3616.940553426743,4.861057758331299,0.0 -102900,1.663025,3.7120156,,,,,,,,,,,,,, -103000,1.86833,5.0822263,,,,,,,,,,,,,, -103100,1.9383935,3.0041146,,,,,,,,,,,,,, -103200,1.6865897,3.7613585,,,,,,,,,,,,,, -103300,1.908456,2.8494968,,,,,,,,,,,,,, -103400,2.2871177,2.898346,,,,,,,,,,,,,, -103500,2.190132,3.0175567,,,,,,,,,,,,,, -103600,1.7339725,3.9964225,,,,,,,,,,,,,, -103700,1.972418,2.9104888,,,,,,,,,,,,,, -103735,,,0.7060937285423279,1.303206443786621,0.6481800079345703,1.5560694932937622,50000.0,0.5296000242233276,2.18750262260437,10000.0,47533.03374886513,51193.85984563828,47533.03374886513,3650.198092460632,4.902195453643799,0.0 -103800,2.0900636,2.9591768,,,,,,,,,,,,,, -103900,2.253842,3.0173886,,,,,,,,,,,,,, -104000,1.7368335,4.6871333,,,,,,,,,,,,,, -104100,2.3101704,3.000988,,,,,,,,,,,,,, -104200,1.9050711,3.1550539,,,,,,,,,,,,,, -104300,2.2513306,2.9437554,,,,,,,,,,,,,, -104400,1.8698698,4.692953,,,,,,,,,,,,,, -104500,2.1991804,3.0336943,,,,,,,,,,,,,, -104600,1.919568,3.160541,,,,,,,,,,,,,, -104652,,,0.728710949420929,1.209591507911682,0.655519962310791,1.5235939025878906,50000.0,0.5229000449180603,2.1618435382843018,10000.0,47953.09908533096,51646.89317679405,47953.09908533096,3683.071009159088,4.946327209472656,0.0 -104700,2.1879315,2.93031,,,,,,,,,,,,,, -104800,1.9927906,3.5691662,,,,,,,,,,,,,, -104900,2.3323514,2.971722,,,,,,,,,,,,,, -105000,2.1553633,2.9293776,,,,,,,,,,,,,, -105100,1.9875803,3.0079,,,,,,,,,,,,,, -105200,1.7439221,4.728069,,,,,,,,,,,,,, -105300,1.9703729,3.2784636,,,,,,,,,,,,,, -105400,1.9635085,3.398453,,,,,,,,,,,,,, -105500,1.9142814,2.9143102,,,,,,,,,,,,,, -105569,,,0.7003515362739563,1.328979730606079,0.6550799608230591,1.5396244525909424,50000.0,0.535800039768219,2.1648764610290527,10000.0,48373.37495970726,52099.889456510544,48373.37495970726,3715.691724061966,4.996261835098267,0.0 -105600,2.1901567,2.952439,,,,,,,,,,,,,, -105700,1.6395129,3.9455562,,,,,,,,,,,,,, -105800,2.1440504,3.074468,,,,,,,,,,,,,, -105900,1.7045064,4.4290166,,,,,,,,,,,,,, -106000,1.7875191,4.5366344,,,,,,,,,,,,,, -106100,2.0815425,3.0358477,,,,,,,,,,,,,, -106200,1.8777717,3.1037757,,,,,,,,,,,,,, -106300,2.0713904,5.11956,,,,,,,,,,,,,, -106400,2.0473058,3.0201309,,,,,,,,,,,,,, -106485,,,0.706250011920929,1.3207530975341797,0.6521199941635132,1.5590405464172363,50000.0,0.5314000248908997,2.192826986312866,10000.0,48793.45449113846,52553.79598546028,48793.45449113846,3749.424998044968,5.040008306503296,0.0 -106500,2.053234,3.344307,,,,,,,,,,,,,, -106600,2.0658395,2.9999475,,,,,,,,,,,,,, -106700,2.1367183,2.890286,,,,,,,,,,,,,, -106800,2.2208574,3.1045833,,,,,,,,,,,,,, -106900,1.9528918,3.7729,,,,,,,,,,,,,, -107000,1.8706429,3.296503,,,,,,,,,,,,,, -107100,2.2028425,3.0041616,,,,,,,,,,,,,, -107200,2.1066577,2.9137826,,,,,,,,,,,,,, -107300,2.0334237,3.0280936,,,,,,,,,,,,,, -107400,2.0043085,3.23385,,,,,,,,,,,,,, -107401,,,0.7234960794448853,1.20840322971344,0.6581199765205383,1.5009328126907349,50000.0,0.5392000079154968,2.1324949264526367,10000.0,49214.00630617142,53007.55867099762,49214.00630617142,3782.543523073197,5.082547903060913,0.0 -107500,2.0802073,2.9620779,,,,,,,,,,,,,, -107600,2.075156,2.8674917,,,,,,,,,,,,,, -107700,2.0715735,3.0896163,,,,,,,,,,,,,, -107800,2.3301027,2.964892,,,,,,,,,,,,,, -107900,1.8586771,3.949773,,,,,,,,,,,,,, -108000,2.1983438,2.9673235,,,,,,,,,,,,,, -108100,1.9227443,3.3870378,,,,,,,,,,,,,, -108200,2.1210322,5.1079936,,,,,,,,,,,,,, -108300,2.2924168,2.9161518,,,,,,,,,,,,,, -108316,,,0.7032812237739563,1.2869420051574707,0.6589800119400024,1.5009398460388184,50000.0,0.5303000211715698,2.155637502670288,10000.0,49634.25281214714,53459.40159130096,49634.25281214714,3814.051218271256,5.121634244918823,0.0 -108400,2.0610304,2.9120603,,,,,,,,,,,,,, -108500,1.9721817,4.874876,,,,,,,,,,,,,, -108600,1.8673983,3.9782243,,,,,,,,,,,,,, -108700,2.0431974,2.9024246,,,,,,,,,,,,,, -108800,2.2224016,2.927854,,,,,,,,,,,,,, -108900,2.1936994,2.999804,,,,,,,,,,,,,, -109000,1.9904274,3.0862787,,,,,,,,,,,,,, -109100,2.1105487,3.025109,,,,,,,,,,,,,, -109200,2.129536,3.2262073,,,,,,,,,,,,,, -109232,,,0.7122656106948853,1.288287878036499,0.657759964466095,1.5199744701385498,50000.0,0.5329000353813171,2.154524326324463,10000.0,50054.5297896862,53911.006695985794,50054.5297896862,3845.278760910034,5.172155141830444,0.0 -109300,1.9289229,3.7935276,,,,,,,,,,,,,, -109400,2.4889157,2.894831,,,,,,,,,,,,,, -109500,2.0061977,4.0118732,,,,,,,,,,,,,, -109600,2.0361788,2.9438376,,,,,,,,,,,,,, -109700,1.8903507,3.6059005,,,,,,,,,,,,,, -109800,2.172364,2.8680897,,,,,,,,,,,,,, -109900,2.1708856,2.9477072,,,,,,,,,,,,,, -110000,1.8807585,3.471553,,,,,,,,,,,,,, -110100,2.070049,4.610359,,,,,,,,,,,,,, -110148,,,0.7259570360183716,1.2455788850784302,0.6635199785232544,1.5217573642730713,50000.0,0.5366000533103943,2.149949550628662,10000.0,50474.51218295097,54364.08873963356,50474.51218295097,3878.277625799179,5.222049713134766,0.0 -110200,2.1069665,2.9128165,,,,,,,,,,,,,, -110300,1.8469425,4.674434,,,,,,,,,,,,,, -110400,1.9595946,4.697834,,,,,,,,,,,,,, -110500,1.8046205,4.3608317,,,,,,,,,,,,,, -110600,2.2280517,2.7874372,,,,,,,,,,,,,, -110700,2.1433012,2.8868806,,,,,,,,,,,,,, -110800,1.9922392,2.721102,,,,,,,,,,,,,, -110900,2.0052164,3.1248553,,,,,,,,,,,,,, -111000,2.0627701,2.9723606,,,,,,,,,,,,,, -111065,,,0.71337890625,1.262039303779602,0.6626600027084351,1.4916067123413086,50000.0,0.5421000123023987,2.113887310028076,10000.0,50894.50727891922,54817.20326185226,50894.50727891922,3911.306578159332,5.262262582778931,0.0 -111100,2.1299143,5.071566,,,,,,,,,,,,,, -111200,1.9167181,4.9829946,,,,,,,,,,,,,, -111300,2.1851108,3.3210306,,,,,,,,,,,,,, -111400,1.9349356,3.63045,,,,,,,,,,,,,, -111500,2.1253572,2.8588676,,,,,,,,,,,,,, -111600,2.0781376,3.937958,,,,,,,,,,,,,, -111700,2.0727565,3.1051202,,,,,,,,,,,,,, -111800,1.8935729,5.0813932,,,,,,,,,,,,,, -111900,1.848171,3.8983223,,,,,,,,,,,,,, -111980,,,0.7187108993530273,1.2647498846054075,0.6654999852180481,1.5069676637649536,50000.0,0.5409000515937805,2.1428349018096924,10000.0,51314.64157676697,55270.05773234368,51314.64157676697,3943.933856487274,5.305820465087891,0.0 -112000,2.2349925,2.9530113,,,,,,,,,,,,,, -112100,2.0619106,4.083522,,,,,,,,,,,,,, -112200,1.9183807,4.9369426,,,,,,,,,,,,,, -112300,1.9748597,3.5199728,,,,,,,,,,,,,, -112400,2.271259,3.0775123,,,,,,,,,,,,,, -112500,2.0796068,2.9087725,,,,,,,,,,,,,, -112600,1.980987,4.0096903,,,,,,,,,,,,,, -112700,2.1427145,2.884217,,,,,,,,,,,,,, -112800,1.809557,4.2235746,,,,,,,,,,,,,, -112896,,,0.7238085865974426,1.2229537963867188,0.6610999703407288,1.4882365465164185,50000.0,0.5398000478744507,2.1151599884033203,10000.0,51734.58570885658,55722.437901735306,51734.58570885658,3976.276055574417,5.350057363510132,0.0 -112900,2.2563994,2.9623563,,,,,,,,,,,,,, -113000,1.8827775,4.6016245,,,,,,,,,,,,,, -113100,2.2640827,2.9430857,,,,,,,,,,,,,, -113200,2.07736,3.9865801,,,,,,,,,,,,,, -113300,2.3478925,2.882763,,,,,,,,,,,,,, -113400,2.0549312,2.8436446,,,,,,,,,,,,,, -113500,2.0972643,3.054733,,,,,,,,,,,,,, -113600,2.3660924,2.8628416,,,,,,,,,,,,,, -113700,2.150429,2.849454,,,,,,,,,,,,,, -113800,2.3624055,2.903973,,,,,,,,,,,,,, -113811,,,0.7215625047683716,1.2210685014724731,0.6649199724197388,1.4647303819656372,50000.0,0.5451000332832336,2.094324350357056,10000.0,52154.66342067719,56172.92315030098,52154.66342067719,4006.590903520584,5.393388271331787,0.0 -113900,2.0478568,4.534926,,,,,,,,,,,,,, -114000,2.016459,4.673222,,,,,,,,,,,,,, -114100,2.1895323,2.9139109,,,,,,,,,,,,,, -114200,2.134849,3.576498,,,,,,,,,,,,,, -114300,2.1590824,4.287754,,,,,,,,,,,,,, -114400,2.2983284,2.943359,,,,,,,,,,,,,, -114500,2.137359,3.4094493,,,,,,,,,,,,,, -114600,2.2778792,2.9109674,,,,,,,,,,,,,, -114700,2.4238794,2.937031,,,,,,,,,,,,,, -114726,,,0.7206835746765137,1.2099860906600952,0.67249995470047,1.443393349647522,50000.0,0.5450000166893005,2.083441257476806,10000.0,52575.01256537437,56625.05432701111,52575.01256537437,4038.273143053055,5.442919015884399,0.0 -114800,2.093063,4.119184,,,,,,,,,,,,,, -114900,2.195865,2.8120773,,,,,,,,,,,,,, -115000,2.0710347,4.7652144,,,,,,,,,,,,,, -115100,2.1757872,2.8688633,,,,,,,,,,,,,, -115200,2.252555,2.94843,,,,,,,,,,,,,, -115300,2.367999,2.7923625,,,,,,,,,,,,,, -115400,2.2700326,3.0599189,,,,,,,,,,,,,, -115500,2.0464265,3.4930189,,,,,,,,,,,,,, -115600,2.064847,4.3904185,,,,,,,,,,,,,, -115640,,,0.7305468320846558,1.191508173942566,0.67249995470047,1.4450451135635376,50000.0,0.5496000051498413,2.064040422439575,10000.0,52995.04573082924,57079.766466617584,52995.04573082924,4072.856017351152,5.488822221755981,0.0 -115700,2.4245207,2.850495,,,,,,,,,,,,,, -115800,1.9932692,3.2648675,,,,,,,,,,,,,, -115900,2.6393726,2.7732306,,,,,,,,,,,,,, -116000,2.1175785,2.7810504,,,,,,,,,,,,,, -116100,2.1763887,2.8192992,,,,,,,,,,,,,, -116200,2.147251,2.98288,,,,,,,,,,,,,, -116300,2.1989284,2.7547436,,,,,,,,,,,,,, -116400,2.285991,2.819863,,,,,,,,,,,,,, -116500,2.3276865,3.041174,,,,,,,,,,,,,, -116556,,,0.7452148199081421,1.1201274394989014,0.6720199584960938,1.4388738870620728,50000.0,0.5498000383377075,2.0630745887756348,10000.0,53415.35868215561,57533.651420116425,53415.35868215561,4106.33264541626,5.534053087234497,0.0 -116600,2.0575514,2.9640644,,,,,,,,,,,,,, -116700,2.1191535,3.2174714,,,,,,,,,,,,,, -116800,2.4109006,2.889019,,,,,,,,,,,,,, -116900,2.2384505,2.9534416,,,,,,,,,,,,,, -117000,2.0479417,4.464362,,,,,,,,,,,,,, -117100,2.085236,2.79464,,,,,,,,,,,,,, -117200,2.3040075,3.0136893,,,,,,,,,,,,,, -117300,2.2070966,2.8364377,,,,,,,,,,,,,, -117400,1.9881364,3.4483392,,,,,,,,,,,,,, -117472,,,0.7250390648841858,1.2660032510757446,0.667199969291687,1.5061935186386108,50000.0,0.5496000051498413,2.128594160079956,10000.0,53835.411581754684,57987.13926744461,53835.411581754684,4139.674651861191,5.577041864395142,0.0 -117500,2.2349148,4.597686,,,,,,,,,,,,,, -117600,2.342607,2.893077,,,,,,,,,,,,,, -117700,2.3029792,4.8999453,,,,,,,,,,,,,, -117800,2.6155925,2.7843118,,,,,,,,,,,,,, -117900,2.4724948,3.1154556,,,,,,,,,,,,,, -118000,2.3836703,2.8482819,,,,,,,,,,,,,, -118100,2.4261484,4.944702,,,,,,,,,,,,,, -118200,2.1074336,4.285712,,,,,,,,,,,,,, -118300,2.0506022,3.637355,,,,,,,,,,,,,, -118386,,,0.73046875,1.1939589977264404,0.6730200052261353,1.4451491832733154,50000.0,0.5502000451087952,2.0854063034057617,10000.0,54255.57951760292,58440.40096735954,54255.57951760292,4172.6747174263,5.619581699371338,0.0 -118400,2.4538004,2.8516078,,,,,,,,,,,,,, -118500,2.1831121,4.4505277,,,,,,,,,,,,,, -118600,2.116106,3.4913397,,,,,,,,,,,,,, -118700,2.1791143,2.8923032,,,,,,,,,,,,,, -118800,2.0467467,3.7759354,,,,,,,,,,,,,, -118900,2.2949288,2.8532732,,,,,,,,,,,,,, -119000,2.1883867,2.8211484,,,,,,,,,,,,,, -119100,2.4106622,5.0105677,,,,,,,,,,,,,, -119200,2.5686967,2.807901,,,,,,,,,,,,,, -119299,,,0.7431640625,1.1251846551895142,0.6783999800682068,1.4041777849197388,50000.0,0.5516000390052795,2.039374589920044,10000.0,54675.84220218658,58894.145411252975,54675.84220218658,4206.0634133815765,5.661855459213257,0.0 -119300,2.2467842,3.26292,,,,,,,,,,,,,, -119400,2.252701,3.0991788,,,,,,,,,,,,,, -119500,2.5635428,4.789781,,,,,,,,,,,,,, -119600,2.2366319,2.7990193,,,,,,,,,,,,,, -119700,2.4903166,2.8018672,,,,,,,,,,,,,, -119800,2.351305,2.7137735,,,,,,,,,,,,,, -119900,2.470078,2.7529235,,,,,,,,,,,,,, -120000,2.4714313,2.7988741,,,,,,,,,,,,,, -120100,2.3012025,3.0134487,,,,,,,,,,,,,, -120200,2.2981775,2.7687545,,,,,,,,,,,,,, -120216,,,0.7290429472923279,1.183713674545288,0.6782199740409851,1.414423584938049,50000.0,0.5534999966621399,2.0480198860168457,10000.0,55095.76208996773,59347.97211909294,55095.76208996773,4239.874573707581,5.706495046615601,0.0 -120300,2.4463935,2.8286982,,,,,,,,,,,,,, -120400,2.260557,3.0529993,,,,,,,,,,,,,, -120500,2.5783644,2.9641986,,,,,,,,,,,,,, -120600,2.1338453,3.4038033,,,,,,,,,,,,,, -120700,2.076126,3.9622426,,,,,,,,,,,,,, -120800,2.2198946,2.9123838,,,,,,,,,,,,,, -120900,2.1249664,3.268167,,,,,,,,,,,,,, -121000,2.8016346,4.8982167,,,,,,,,,,,,,, -121100,2.127125,3.8136644,,,,,,,,,,,,,, -121132,,,0.7370507717132568,1.166053056716919,0.6767799854278564,1.426482319831848,50000.0,0.5556000471115112,2.0637381076812744,10000.0,55516.00203704834,59800.793695926666,55516.00203704834,4272.361785888672,5.750799179077148,0.0 -121200,2.4837778,4.8317237,,,,,,,,,,,,,, -121300,2.592319,5.087405,,,,,,,,,,,,,, -121400,2.1966004,4.7951846,,,,,,,,,,,,,, -121500,2.1895883,3.6870496,,,,,,,,,,,,,, -121600,2.0424922,4.893917,,,,,,,,,,,,,, -121700,2.4286172,2.765622,,,,,,,,,,,,,, -121800,2.031274,4.1925406,,,,,,,,,,,,,, -121900,2.788019,2.6573179,,,,,,,,,,,,,, -122000,2.3650074,3.0075088,,,,,,,,,,,,,, -122049,,,0.7417968511581421,1.1477159261703491,0.6811599731445312,1.4139902591705322,50000.0,0.5579000115394592,2.025280237197876,10000.0,55936.13686299324,60254.25495505333,55936.13686299324,4305.591577529907,5.79656982421875,0.0 -122100,2.1519623,3.8011484,,,,,,,,,,,,,, -122200,2.3930566,2.7858717,,,,,,,,,,,,,, -122300,2.3836489,4.9436817,,,,,,,,,,,,,, -122400,2.4204865,2.9747355,,,,,,,,,,,,,, -122500,2.4233406,3.681344,,,,,,,,,,,,,, -122600,2.4095178,2.8056583,,,,,,,,,,,,,, -122700,2.450467,3.5590973,,,,,,,,,,,,,, -122800,2.2050622,4.1488276,,,,,,,,,,,,,, -122900,2.3916383,2.8899531,,,,,,,,,,,,,, -122963,,,0.7372655868530273,1.1513622999191284,0.688979983329773,1.3695464134216309,50000.0,0.5614000558853149,2.017380475997925,10000.0,56356.257717609406,60706.10312247276,56356.257717609406,4337.223828554153,5.842200517654419,0.0 -123000,2.5626724,2.8300009,,,,,,,,,,,,,, -123100,2.42184,4.4722185,,,,,,,,,,,,,, -123200,2.5522068,2.699675,,,,,,,,,,,,,, -123300,2.1875873,3.2943673,,,,,,,,,,,,,, -123400,2.4784389,2.643689,,,,,,,,,,,,,, -123500,2.2961836,2.6779797,,,,,,,,,,,,,, -123600,2.1297584,4.1762714,,,,,,,,,,,,,, -123700,2.4139206,2.7877882,,,,,,,,,,,,,, -123800,2.3660936,2.810378,,,,,,,,,,,,,, -123878,,,0.7439843416213989,1.1398800611495972,0.6844399571418762,1.3942458629608154,50000.0,0.5649000406265259,2.015331745147705,10000.0,56776.523052454,61159.25761413574,56776.523052454,4370.009130716324,5.895697832107544,0.0 -123900,2.6951087,2.835009,,,,,,,,,,,,,, -124000,2.318124,3.6651173,,,,,,,,,,,,,, -124100,2.269925,3.6966019,,,,,,,,,,,,,, -124200,2.5937653,2.9177318,,,,,,,,,,,,,, -124300,2.4615393,2.8251698,,,,,,,,,,,,,, -124400,2.5560737,2.7604723,,,,,,,,,,,,,, -124500,2.155866,4.065844,,,,,,,,,,,,,, -124600,2.3765304,3.097828,,,,,,,,,,,,,, -124700,2.3868995,4.9282355,,,,,,,,,,,,,, -124794,,,0.74818354845047,1.1148760318756104,0.6860199570655823,1.3866878747940063,50000.0,0.562000036239624,2.016331672668457,10000.0,57196.867312431335,61613.43428826332,57196.867312431335,4403.7414972782135,5.944585561752319,0.0 -124800,2.5045679,2.861461,,,,,,,,,,,,,, -124900,2.8269155,2.8688278,,,,,,,,,,,,,, -125000,2.195043,4.2035427,,,,,,,,,,,,,, -125100,2.319279,3.3831143,,,,,,,,,,,,,, -125200,2.3604305,2.7241423,,,,,,,,,,,,,, -125300,2.4046934,2.7074199,,,,,,,,,,,,,, -125400,2.5141704,4.1439853,,,,,,,,,,,,,, -125500,2.6714072,2.6777167,,,,,,,,,,,,,, -125600,2.548962,2.7948916,,,,,,,,,,,,,, -125700,2.6502855,2.7759354,,,,,,,,,,,,,, -125709,,,0.7470703125,1.1385987997055054,0.6839199662208557,1.4028888940811155,50000.0,0.5626000165939331,2.021228790283203,10000.0,57616.94558739662,62064.4964826107,57616.94558739662,4434.629065275192,5.990480661392212,0.0 -125800,2.4307892,2.9540405,,,,,,,,,,,,,, -125900,2.54669,2.7148812,,,,,,,,,,,,,, -126000,2.4232233,2.6265996,,,,,,,,,,,,,, -126100,2.2129648,3.4319055,,,,,,,,,,,,,, -126200,2.4617937,4.5058737,,,,,,,,,,,,,, -126300,2.5389526,2.7491455,,,,,,,,,,,,,, -126400,2.3164482,4.457447,,,,,,,,,,,,,, -126500,2.4562416,2.8474214,,,,,,,,,,,,,, -126600,2.6474304,2.8254619,,,,,,,,,,,,,, -126620,,,0.7459765672683716,1.1232553720474243,0.6888200044631958,1.3786057233810425,50000.0,0.5684000253677368,1.9932000637054443,10000.0,58037.063480854034,62516.18090867996,58037.063480854034,4466.095130681992,6.0411882400512695,0.0 -126700,2.8391235,4.9411273,,,,,,,,,,,,,, -126800,2.8468516,2.937809,,,,,,,,,,,,,, -126900,2.771958,2.6923962,,,,,,,,,,,,,, -127000,2.59019,4.7414837,,,,,,,,,,,,,, -127100,2.9116743,2.8467376,,,,,,,,,,,,,, -127200,2.491136,2.8129241,,,,,,,,,,,,,, -127300,2.7415843,2.7335088,,,,,,,,,,,,,, -127400,2.5225806,2.6434824,,,,,,,,,,,,,, -127500,2.4962628,4.8936334,,,,,,,,,,,,,, -127535,,,0.7511913776397705,1.0778286457061768,0.6904999613761902,1.3422291278839111,50000.0,0.5659000277519226,1.9629532098770144,10000.0,58457.12394356728,62969.93001675606,58457.12394356728,4499.679343938828,6.095107316970825,0.0 -127600,2.3420823,2.6917017,,,,,,,,,,,,,, -127700,2.5112267,4.1949234,,,,,,,,,,,,,, -127800,2.4970138,4.3949046,,,,,,,,,,,,,, -127900,2.40975,4.227607,,,,,,,,,,,,,, -128000,2.637268,2.7850628,,,,,,,,,,,,,, -128100,2.4586074,4.1233497,,,,,,,,,,,,,, -128200,2.621822,4.4018207,,,,,,,,,,,,,, -128300,2.789695,4.3926907,,,,,,,,,,,,,, -128400,2.7176,2.8171928,,,,,,,,,,,,,, -128452,,,0.7696874737739563,1.024175047874451,0.6924600005149841,1.3428523540496826,50000.0,0.5690000057220459,1.971156001091004,10000.0,58877.354343652725,63421.92477989197,58877.354343652725,4531.347820997238,6.140000343322754,0.0 -128500,2.4377291,2.9533093,,,,,,,,,,,,,, -128600,2.8624432,2.7828505,,,,,,,,,,,,,, -128700,2.715638,2.724787,,,,,,,,,,,,,, -128800,2.7598486,2.7359414,,,,,,,,,,,,,, -128900,2.3824446,3.8222623,,,,,,,,,,,,,, -129000,2.7038438,3.028814,,,,,,,,,,,,,, -129100,2.6573837,2.6466205,,,,,,,,,,,,,, -129200,2.3218105,3.6304944,,,,,,,,,,,,,, -129300,2.6169536,2.6795294,,,,,,,,,,,,,, -129367,,,0.7485156059265137,1.1012896299362185,0.6915199756622314,1.3491532802581787,50000.0,0.5629000067710876,1.973888039588928,10000.0,59297.305203437805,63875.34980750084,59297.305203437805,4564.715627908707,6.195974826812744,0.0 -129400,2.5404131,4.4644046,,,,,,,,,,,,,, -129500,2.3862321,4.167593,,,,,,,,,,,,,, -129600,2.6520827,4.727998,,,,,,,,,,,,,, -129700,3.0187569,2.708898,,,,,,,,,,,,,, -129800,2.6579409,4.837002,,,,,,,,,,,,,, -129900,2.6239688,2.732142,,,,,,,,,,,,,, -130000,2.6130996,3.0150673,,,,,,,,,,,,,, -130100,2.8636487,2.7627375,,,,,,,,,,,,,, -130200,2.625992,2.6961217,,,,,,,,,,,,,, -130283,,,0.7594531178474426,1.0408467054367063,0.6985399723052979,1.3164602518081665,50000.0,0.5763000249862671,1.9376977682113647,10000.0,59717.62939977646,64328.16715955734,59717.62939977646,4597.110694646835,6.243996620178223,0.0 -130300,2.313708,3.4919822,,,,,,,,,,,,,, -130400,2.36529,3.1979184,,,,,,,,,,,,,, -130500,2.6803064,2.78238,,,,,,,,,,,,,, -130600,2.5991309,2.6710913,,,,,,,,,,,,,, -130700,2.632587,2.9805224,,,,,,,,,,,,,, -130800,3.2074041,4.8758144,,,,,,,,,,,,,, -130900,2.3000472,4.000304,,,,,,,,,,,,,, -131000,2.6645877,2.936324,,,,,,,,,,,,,, -131100,2.8687782,2.8270123,,,,,,,,,,,,,, -131200,,,0.7681640386581421,1.0163657665252686,0.6958799958229065,1.330001711845398,50000.0,0.5750000476837158,1.946197748184204,10000.0,60138.00821852684,64781.60162329674,60138.00821852684,4630.069223642349,6.288846492767334,0.0 -131200,2.4259121,4.221766,,,,,,,,,,,,,, -131300,2.5886848,2.6254451,,,,,,,,,,,,,, -131400,2.7834835,3.1881342,,,,,,,,,,,,,, -131500,2.616458,2.6807156,,,,,,,,,,,,,, -131600,2.8213665,2.7645607,,,,,,,,,,,,,, -131700,2.779662,2.7390752,,,,,,,,,,,,,, -131800,2.608333,2.8150263,,,,,,,,,,,,,, -131900,2.5720258,2.809714,,,,,,,,,,,,,, -132000,2.634124,2.725158,,,,,,,,,,,,,, -132100,2.9804764,4.492136,,,,,,,,,,,,,, -132116,,,0.7521288990974426,1.0870712995529177,0.6966399550437927,1.3381699323654177,50000.0,0.5717000365257263,1.9606136083602903,10000.0,60558.07671093941,65233.88895082474,60558.07671093941,4662.193197011948,6.333850622177124,0.0 -132200,2.7536004,2.6880527,,,,,,,,,,,,,, -132300,2.565083,4.410695,,,,,,,,,,,,,, -132400,2.4615932,2.541579,,,,,,,,,,,,,, -132500,2.676594,4.477272,,,,,,,,,,,,,, -132600,3.0178316,2.6821752,,,,,,,,,,,,,, -132700,2.7022462,2.7630646,,,,,,,,,,,,,, -132800,2.84464,2.6903124,,,,,,,,,,,,,, -132900,2.7842638,2.7415547,,,,,,,,,,,,,, -133000,2.4866166,3.5703952,,,,,,,,,,,,,, -133031,,,0.7635741829872131,1.053499698638916,0.6963199973106384,1.3278359174728394,50000.0,0.5728000402450562,1.9535140991210933,10000.0,60978.0214304924,65686.31872987747,60978.0214304924,4694.585594415665,6.377013444900513,0.0 -133100,2.8821628,2.8103628,,,,,,,,,,,,,, -133200,2.5143108,4.7759542,,,,,,,,,,,,,, -133300,2.9112914,2.636758,,,,,,,,,,,,,, -133400,2.7643244,4.155533,,,,,,,,,,,,,, -133500,2.8602839,2.7690043,,,,,,,,,,,,,, -133600,2.8819025,2.650099,,,,,,,,,,,,,, -133700,2.493847,4.170912,,,,,,,,,,,,,, -133800,2.5466182,2.6282578,,,,,,,,,,,,,, -133900,3.2894318,2.7258258,,,,,,,,,,,,,, -133946,,,0.7712304592132568,0.9981160759925842,0.7023999691009521,1.3043373823165894,50000.0,0.5840000510215759,1.9154057502746584,10000.0,61398.24069237709,66139.95198273659,61398.24069237709,4727.902453184128,6.423879384994507,0.0 -134000,2.8940089,2.7511945,,,,,,,,,,,,,, -134100,2.888324,2.7540808,,,,,,,,,,,,,, -134200,2.6641572,2.5793328,,,,,,,,,,,,,, -134300,2.9078007,2.7148576,,,,,,,,,,,,,, -134400,2.6925828,3.0016615,,,,,,,,,,,,,, -134500,2.9108765,2.6343129,,,,,,,,,,,,,, -134600,3.1372268,4.717974,,,,,,,,,,,,,, -134700,2.9724143,2.8175328,,,,,,,,,,,,,, -134800,2.7058914,3.4395618,,,,,,,,,,,,,, -134862,,,0.7629492282867432,1.0372766256332395,0.702739953994751,1.304007887840271,50000.0,0.5812000036239624,1.9075292348861688,10000.0,61818.17729949951,66593.90523028374,61818.17729949951,4761.82052397728,6.472502708435059,0.0 -134900,2.8523753,2.6319113,,,,,,,,,,,,,, -135000,2.7662723,2.8642359,,,,,,,,,,,,,, -135100,2.4959104,4.4135065,,,,,,,,,,,,,, -135200,2.6384187,3.5972693,,,,,,,,,,,,,, -135300,3.160944,4.663546,,,,,,,,,,,,,, -135400,2.5288875,3.0679054,,,,,,,,,,,,,, -135500,2.8906617,2.6767707,,,,,,,,,,,,,, -135600,2.9688268,2.5925796,,,,,,,,,,,,,, -135700,3.2421627,2.634619,,,,,,,,,,,,,, -135777,,,0.7660546898841858,1.0418373346328735,0.7041400074958801,1.3056179285049438,50000.0,0.579800009727478,1.9201245307922363,10000.0,62238.25712633133,67046.49923276901,62238.25712633133,4794.236758947372,6.520557165145874,0.0 -135800,2.676668,3.8379793,,,,,,,,,,,,,, -135900,2.6930418,3.0288615,,,,,,,,,,,,,, -136000,2.5153935,3.912414,,,,,,,,,,,,,, -136100,2.938334,3.0574887,,,,,,,,,,,,,, -136200,2.8399835,2.624889,,,,,,,,,,,,,, -136300,3.0841763,2.722905,,,,,,,,,,,,,, -136400,3.3281913,2.697471,,,,,,,,,,,,,, -136500,3.4933548,2.6283438,,,,,,,,,,,,,, -136600,2.5152364,4.1651535,,,,,,,,,,,,,, -136693,,,0.7722460627555847,1.0183627605438232,0.7065199613571167,1.3071863651275637,50000.0,0.5826000571250916,1.9220960140228271,10000.0,62658.68343901634,67499.63382172585,62658.68343901634,4826.84645652771,6.569271087646484,0.0 -136700,3.0343382,2.5586193,,,,,,,,,,,,,, -136800,2.7585282,2.7714028,,,,,,,,,,,,,, -136900,3.135875,2.9427865,,,,,,,,,,,,,, -137000,2.9365697,2.6122806,,,,,,,,,,,,,, -137100,2.5826302,3.8169854,,,,,,,,,,,,,, -137200,3.1630986,2.6330533,,,,,,,,,,,,,, -137300,2.9039042,3.5604515,,,,,,,,,,,,,, -137400,3.182873,2.6244233,,,,,,,,,,,,,, -137500,3.3038812,2.6156855,,,,,,,,,,,,,, -137600,2.8643098,2.6916854,,,,,,,,,,,,,, -137608,,,0.7710937261581421,1.0141106843948364,0.7095400094985962,1.2850812673568726,50000.0,0.5869000554084778,1.897950291633606,10000.0,63078.72762656212,67953.34159827232,63078.72762656212,4860.415402173996,6.61378026008606,0.0 -137700,2.817639,4.2619467,,,,,,,,,,,,,, -137800,2.9641244,2.7633936,,,,,,,,,,,,,, -137900,2.725552,3.7001092,,,,,,,,,,,,,, -138000,2.886626,2.9362526,,,,,,,,,,,,,, -138100,3.1140144,3.4255276,,,,,,,,,,,,,, -138200,2.7393558,4.304074,,,,,,,,,,,,,, -138300,3.011894,2.6111174,,,,,,,,,,,,,, -138400,2.7607408,4.3927717,,,,,,,,,,,,,, -138500,3.3312986,4.7012215,,,,,,,,,,,,,, -138523,,,0.7734375,1.0056006908416748,0.7101199626922607,1.2799828052520752,50000.0,0.5896000266075134,1.89242959022522,10000.0,63498.872061014175,68406.37291073799,63498.872061014175,4893.19992518425,6.666823387145996,0.0 -138600,2.6335645,2.9858801,,,,,,,,,,,,,, -138700,2.930568,3.0918481,,,,,,,,,,,,,, -138800,3.0873625,3.055345,,,,,,,,,,,,,, -138900,2.9941592,2.7005363,,,,,,,,,,,,,, -139000,3.2139242,4.1123724,,,,,,,,,,,,,, -139100,3.0299962,2.59759,,,,,,,,,,,,,, -139200,3.0702772,2.7032063,,,,,,,,,,,,,, -139300,3.2198336,2.82255,,,,,,,,,,,,,, -139400,2.9311557,4.4664764,,,,,,,,,,,,,, -139438,,,0.7773827910423279,0.9855165481567384,0.7126399874687195,1.2628765106201172,50000.0,0.5873000025749207,1.8855514526367188,10000.0,63919.03820538521,68859.4631357193,63919.03820538521,4926.0291867256165,6.712033033370972,0.0 -139500,3.0155976,4.005044,,,,,,,,,,,,,, -139600,3.308058,2.5696447,,,,,,,,,,,,,, -139700,2.9181435,2.7036095,,,,,,,,,,,,,, -139800,3.2552025,2.6184995,,,,,,,,,,,,,, -139900,3.2035096,2.5127802,,,,,,,,,,,,,, -140000,2.9958107,2.5246375,,,,,,,,,,,,,, -140100,2.8538039,2.833423,,,,,,,,,,,,,, -140200,3.0317605,3.2154489,,,,,,,,,,,,,, -140300,3.0352576,2.6490166,,,,,,,,,,,,,, -140354,,,0.78968745470047,0.938185453414917,0.7109599709510803,1.2699692249298096,50000.0,0.5871000289916992,1.891028642654419,10000.0,64339.187994003296,69311.2506814003,64339.187994003296,4957.572335958481,6.756109237670898,0.0 -140400,2.8530838,2.9958405,,,,,,,,,,,,,, -140500,2.662365,3.8302972,,,,,,,,,,,,,, -140600,3.341677,2.5378103,,,,,,,,,,,,,, -140700,3.041165,2.5146894,,,,,,,,,,,,,, -140800,2.868174,2.5786123,,,,,,,,,,,,,, -140900,2.9718552,2.55307,,,,,,,,,,,,,, -141000,3.0470235,2.850027,,,,,,,,,,,,,, -141100,3.1162918,4.4847703,,,,,,,,,,,,,, -141200,3.1631107,3.568259,,,,,,,,,,,,,, -141268,,,0.775585949420929,0.9803733825683594,0.7143399715423584,1.2429838180541992,50000.0,0.5908000469207764,1.856269598007202,10000.0,64759.18298387528,69764.55665397644,64759.18298387528,4990.778796195984,6.810421466827393,0.0 -141300,3.237747,4.315854,,,,,,,,,,,,,, -141400,2.923087,3.1287787,,,,,,,,,,,,,, -141500,2.8405943,2.8371615,,,,,,,,,,,,,, -141600,3.1980855,4.134901,,,,,,,,,,,,,, -141700,3.261083,4.1913776,,,,,,,,,,,,,, -141800,3.4603615,2.4748104,,,,,,,,,,,,,, -141900,3.8995035,4.7509975,,,,,,,,,,,,,, -142000,3.2251358,2.602272,,,,,,,,,,,,,, -142100,3.1204042,3.4128857,,,,,,,,,,,,,, -142185,,,0.7842773199081421,0.9769614338874816,0.7149400115013123,1.2617579698562622,50000.0,0.596500039100647,1.858317494392395,10000.0,65179.46191477776,70216.29029083252,65179.46191477776,5022.137254714966,6.856017589569092,0.0 -142200,3.3047128,2.5757847,,,,,,,,,,,,,, -142300,2.888166,3.6296685,,,,,,,,,,,,,, -142400,3.0583706,2.6739774,,,,,,,,,,,,,, -142500,3.11009,2.4617085,,,,,,,,,,,,,, -142600,3.0058413,2.541302,,,,,,,,,,,,,, -142700,3.0266256,2.9091854,,,,,,,,,,,,,, -142800,3.0199594,3.3624196,,,,,,,,,,,,,, -142900,3.5002186,2.6160686,,,,,,,,,,,,,, -143000,3.529769,4.386165,,,,,,,,,,,,,, -143100,3.1513457,2.492602,,,,,,,,,,,,,, -143101,,,0.7922265529632568,0.911641538143158,0.7174999713897705,1.237351417541504,50000.0,0.5972000360488892,1.844268798828125,10000.0,65600.1053712368,70668.46542716026,65600.1053712368,5053.564244508743,6.910325288772583,0.0 -143200,3.4002209,2.5978503,,,,,,,,,,,,,, -143300,3.8772087,4.6630387,,,,,,,,,,,,,, -143400,3.430577,2.6503592,,,,,,,,,,,,,, -143500,3.1612132,2.620431,,,,,,,,,,,,,, -143600,3.081424,3.0458875,,,,,,,,,,,,,, -143700,2.944878,3.411071,,,,,,,,,,,,,, -143800,3.259803,2.556141,,,,,,,,,,,,,, -143900,2.8536177,3.4903655,,,,,,,,,,,,,, -144000,3.5132365,2.5581563,,,,,,,,,,,,,, -144016,,,0.7851757407188416,0.9561517238616944,0.7189399600028992,1.2434120178222656,50000.0,0.6003000140190125,1.8443188667297363,10000.0,66020.01851248741,71121.61238765717,66020.01851248741,5086.700350761414,6.957584381103516,0.0 -144100,3.2826045,4.3463244,,,,,,,,,,,,,, -144200,3.3040304,2.575667,,,,,,,,,,,,,, -144300,3.1820703,2.487117,,,,,,,,,,,,,, -144400,3.4259346,2.613141,,,,,,,,,,,,,, -144500,3.9328246,2.6053674,,,,,,,,,,,,,, -144600,3.610339,2.534114,,,,,,,,,,,,,, -144700,3.6013486,2.5328305,,,,,,,,,,,,,, -144800,3.1668715,2.9190624,,,,,,,,,,,,,, -144900,3.2965574,2.5122478,,,,,,,,,,,,,, -144931,,,0.7807421684265137,0.9476374387741088,0.7181999683380127,1.2305222749710083,50000.0,0.5937000513076782,1.8433793783187864,10000.0,66440.11183166504,71575.11425447464,66440.11183166504,5120.010430335999,7.005602598190308,0.0 -145000,3.099601,3.2681324,,,,,,,,,,,,,, -145100,3.4239693,2.6430364,,,,,,,,,,,,,, -145200,3.285455,2.589188,,,,,,,,,,,,,, -145300,3.3995776,2.4689567,,,,,,,,,,,,,, -145400,3.3298292,2.5609393,,,,,,,,,,,,,, -145500,3.3919413,2.4964497,,,,,,,,,,,,,, -145600,3.4217377,4.2501135,,,,,,,,,,,,,, -145700,3.2895286,2.4408462,,,,,,,,,,,,,, -145800,3.223227,3.5483146,,,,,,,,,,,,,, -145846,,,0.7958202958106995,0.9202547073364258,0.7206000089645386,1.2377933263778689,50000.0,0.6014000177383423,1.8521920442581177,10000.0,66860.22039985657,72026.32795882225,66860.22039985657,5151.019411563873,7.051945686340332,0.0 -145900,3.122467,3.9575129,,,,,,,,,,,,,, -146000,3.517692,2.6621926,,,,,,,,,,,,,, -146100,3.2208314,2.4994879,,,,,,,,,,,,,, -146200,3.438695,2.500023,,,,,,,,,,,,,, -146300,3.463728,2.7455323,,,,,,,,,,,,,, -146400,3.3901255,3.1002295,,,,,,,,,,,,,, -146500,3.924068,4.6278,,,,,,,,,,,,,, -146600,3.4102812,2.5455368,,,,,,,,,,,,,, -146700,3.4845488,2.511236,,,,,,,,,,,,,, -146759,,,0.7870507836341858,0.943701982498169,0.7204200029373169,1.2290470600128174,50000.0,0.6012000441551208,1.833222389221192,10000.0,67280.31985378265,72479.94853115082,67280.31985378265,5184.43443775177,7.107837200164795,0.0 -146800,4.7767687,4.644071,,,,,,,,,,,,,, -146900,3.5371597,2.8305774,,,,,,,,,,,,,, -147000,3.521345,2.5266337,,,,,,,,,,,,,, -147100,4.215097,2.611788,,,,,,,,,,,,,, -147200,3.2050128,3.7965355,,,,,,,,,,,,,, -147300,3.9705818,2.4835904,,,,,,,,,,,,,, -147400,3.400017,3.6125395,,,,,,,,,,,,,, -147500,3.4130628,3.2100925,,,,,,,,,,,,,, -147600,3.8785138,2.4879377,,,,,,,,,,,,,, -147673,,,0.79359370470047,0.9170867204666138,0.7231999635696411,1.2131633758544922,50000.0,0.5987000465393066,1.8122233152389529,10000.0,67700.53570318222,72934.26190567017,67700.53570318222,5218.43242764473,7.157444715499878,0.0 -147700,3.3168814,2.7829905,,,,,,,,,,,,,, -147800,3.448952,2.5072832,,,,,,,,,,,,,, -147900,3.3148584,2.4490278,,,,,,,,,,,,,, -148000,3.5529053,2.5146475,,,,,,,,,,,,,, -148100,3.5951693,2.7716048,,,,,,,,,,,,,, -148200,3.7980545,2.64786,,,,,,,,,,,,,, -148300,3.1977444,2.7433596,,,,,,,,,,,,,, -148400,3.5231338,3.6389418,,,,,,,,,,,,,, -148500,3.358323,4.012281,,,,,,,,,,,,,, -148592,,,0.7974804639816284,0.8924198150634766,0.7242599725723267,1.2011919021606443,50000.0,0.6028000116348267,1.801270842552185,10000.0,68120.48303842545,73388.2038257122,68120.48303842545,5252.327347993851,7.205964088439941,0.0 -148600,3.4261737,2.450588,,,,,,,,,,,,,, -148700,3.4302154,2.3860078,,,,,,,,,,,,,, -148800,3.8124049,4.438034,,,,,,,,,,,,,, -148900,3.5286517,2.6794782,,,,,,,,,,,,,, -149000,3.8713608,2.4261422,,,,,,,,,,,,,, -149100,4.467604,2.5031018,,,,,,,,,,,,,, -149200,3.74237,2.508608,,,,,,,,,,,,,, -149300,3.9340723,2.89718,,,,,,,,,,,,,, -149400,4.011509,4.480707,,,,,,,,,,,,,, -149500,4.0742116,2.3802567,,,,,,,,,,,,,, -149508,,,0.7968554496765137,0.8914695382118225,0.7272399663925171,1.1844379901885986,50000.0,0.6076000332832336,1.780668020248413,10000.0,68540.57233738899,73842.01174998283,68540.57233738899,5285.950101613998,7.251769781112671,0.0 -149600,3.3558037,3.112781,,,,,,,,,,,,,, -149700,3.2990403,3.460107,,,,,,,,,,,,,, -149800,3.7768755,2.5592718,,,,,,,,,,,,,, -149900,3.875803,2.5185325,,,,,,,,,,,,,, -150000,3.810129,2.557703,,,,,,,,,,,,,, -150100,3.5664303,3.4251592,,,,,,,,,,,,,, -150200,3.8722284,2.4983187,,,,,,,,,,,,,, -150300,4.064248,4.133264,,,,,,,,,,,,,, -150400,3.3555465,3.6547775,,,,,,,,,,,,,, -150424,,,0.7983788847923279,0.8947476744651794,0.7258999943733215,1.197722315788269,50000.0,0.6101000308990479,1.7993606328964231,10000.0,68960.80321502686,74293.92950153351,68960.80321502686,5317.539261579514,7.299367904663086,0.0 -150500,3.4944842,2.5083206,,,,,,,,,,,,,, -150600,3.4993403,2.5216808,,,,,,,,,,,,,, -150700,3.806022,4.269171,,,,,,,,,,,,,, -150800,3.49543,3.1247718,,,,,,,,,,,,,, -150900,3.7024105,2.3925474,,,,,,,,,,,,,, -151000,3.8256342,2.3888574,,,,,,,,,,,,,, -151100,3.7703485,2.515645,,,,,,,,,,,,,, -151200,4.0099077,2.4640453,,,,,,,,,,,,,, -151300,3.7498176,4.207905,,,,,,,,,,,,,, -151339,,,0.8019726276397705,0.8627512454986572,0.7283200025558472,1.176479458808899,50000.0,0.6135000586509705,1.7731417417526243,10000.0,69380.76213383675,74745.69967579842,69380.76213383675,5349.239114522934,7.360641241073608,0.0 -151400,3.7749147,3.8934736,,,,,,,,,,,,,, -151500,3.3028243,3.315192,,,,,,,,,,,,,, -151600,4.033526,2.5111241,,,,,,,,,,,,,, -151700,3.8073306,2.516864,,,,,,,,,,,,,, -151800,3.6970444,2.8812342,,,,,,,,,,,,,, -151900,4.1297293,2.404542,,,,,,,,,,,,,, -152000,4.4678154,2.468036,,,,,,,,,,,,,, -152100,4.412911,2.5396852,,,,,,,,,,,,,, -152200,3.5823727,2.445545,,,,,,,,,,,,,, -152253,,,0.8101562261581421,0.8360294103622437,0.7314800024032593,1.171595811843872,50000.0,0.6112000346183777,1.764483094215393,10000.0,69800.7703230381,75199.2972612381,69800.7703230381,5382.729545354843,7.409675359725952,0.0 -152300,3.7502358,3.6669862,,,,,,,,,,,,,, -152400,3.856808,2.6171813,,,,,,,,,,,,,, -152500,3.7116144,2.6436584,,,,,,,,,,,,,, -152600,3.6543796,2.8001182,,,,,,,,,,,,,, -152700,3.68359,2.3768768,,,,,,,,,,,,,, -152800,4.1255226,2.4739354,,,,,,,,,,,,,, -152900,4.571383,4.131557,,,,,,,,,,,,,, -153000,3.6961539,3.4516675,,,,,,,,,,,,,, -153100,4.21104,3.9976394,,,,,,,,,,,,,, -153168,,,0.8065429329872131,0.850128710269928,0.7335399985313416,1.1555325984954834,50000.0,0.6140000224113464,1.7503547668457031,10000.0,70220.83098077774,75651.4686627388,70220.83098077774,5414.739239692688,7.460782766342163,0.0 -153200,3.9453468,2.5756652,,,,,,,,,,,,,, -153300,4.2106423,2.899344,,,,,,,,,,,,,, -153400,3.9992075,2.3502529,,,,,,,,,,,,,, -153500,3.8163826,3.2580094,,,,,,,,,,,,,, -153600,4.213654,2.4980018,,,,,,,,,,,,,, -153700,3.81237,2.487009,,,,,,,,,,,,,, -153800,3.636984,3.138877,,,,,,,,,,,,,, -153900,3.8210566,2.3647206,,,,,,,,,,,,,, -154000,4.261025,2.3900764,,,,,,,,,,,,,, -154083,,,0.8090429306030273,0.8482322692871094,0.7348799705505371,1.165394306182861,50000.0,0.6159000396728516,1.768608331680298,10000.0,70640.93357086182,76105.66011571884,70640.93357086182,5448.728743553162,7.51036524772644,0.0 -154100,4.698037,4.430338,,,,,,,,,,,,,, -154200,4.2102003,2.573469,,,,,,,,,,,,,, -154300,4.166495,2.715942,,,,,,,,,,,,,, -154400,4.352486,2.358682,,,,,,,,,,,,,, -154500,3.8528013,3.0292425,,,,,,,,,,,,,, -154600,4.0460114,2.4038076,,,,,,,,,,,,,, -154700,4.215269,2.4256945,,,,,,,,,,,,,, -154800,4.2777762,2.5821605,,,,,,,,,,,,,, -154900,4.2051497,2.414612,,,,,,,,,,,,,, -154996,,,0.8161718845367432,0.8130109310150146,0.7342199683189392,1.1588497161865234,50000.0,0.614300012588501,1.7606123685836792,10000.0,71061.22500658035,76558.61492061615,71061.22500658035,5481.284249305725,7.567243576049805,0.0 -155000,3.828829,2.9394143,,,,,,,,,,,,,, -155100,3.9983547,2.5353122,,,,,,,,,,,,,, -155200,4.210378,2.4285452,,,,,,,,,,,,,, -155300,3.7871823,2.8433952,,,,,,,,,,,,,, -155400,4.0930166,2.4363196,,,,,,,,,,,,,, -155500,4.531346,4.4487114,,,,,,,,,,,,,, -155600,3.7177997,2.4211202,,,,,,,,,,,,,, -155700,4.0062265,3.015172,,,,,,,,,,,,,, -155800,3.7650452,2.4012754,,,,,,,,,,,,,, -155900,3.8369246,2.7903562,,,,,,,,,,,,,, -155912,,,0.8117773532867432,0.8497369289398193,0.7390799522399902,1.1490964889526367,50000.0,0.6164000034332275,1.768826603889465,10000.0,71481.48496103287,77011.67231607437,71481.48496103287,5513.983215093613,7.616268157958984,0.0 -156000,4.1174397,2.474412,,,,,,,,,,,,,, -156100,4.5024643,2.4924734,,,,,,,,,,,,,, -156200,4.224721,2.422073,,,,,,,,,,,,,, -156300,4.149583,3.8439054,,,,,,,,,,,,,, -156400,5.561456,4.4861937,,,,,,,,,,,,,, -156500,4.0577326,2.3945434,,,,,,,,,,,,,, -156600,4.132515,2.537324,,,,,,,,,,,,,, -156700,4.481554,2.4163868,,,,,,,,,,,,,, -156800,4.227215,2.3761861,,,,,,,,,,,,,, -156825,,,0.8138867020606995,0.8235838413238525,0.7387599945068359,1.1445491313934326,50000.0,0.617900013923645,1.7345569133758545,10000.0,71901.40081095695,77463.90251994133,71901.40081095695,5546.199652433395,7.663997650146484,0.0 -156900,4.3786526,2.489207,,,,,,,,,,,,,, -157000,4.6743903,2.296553,,,,,,,,,,,,,, -157100,4.1089725,2.3834121,,,,,,,,,,,,,, -157200,4.0562563,3.5382888,,,,,,,,,,,,,, -157300,4.543538,2.4285245,,,,,,,,,,,,,, -157400,4.913051,2.3632786,,,,,,,,,,,,,, -157500,4.9248867,2.3295996,,,,,,,,,,,,,, -157600,4.3213882,2.3713663,,,,,,,,,,,,,, -157700,4.2613926,2.3441932,,,,,,,,,,,,,, -157740,,,0.8207616806030273,0.800974428653717,0.7397199869155884,1.1447432041168213,50000.0,0.6212000250816345,1.7404789924621582,10000.0,72321.37640357018,77917.57216620445,72321.37640357018,5579.793093681335,7.714296340942383,0.0 -157800,4.2235456,2.5534506,,,,,,,,,,,,,, -157900,4.2125225,2.4537091,,,,,,,,,,,,,, -158000,4.2338133,2.3899674,,,,,,,,,,,,,, -158100,4.2682796,2.3182373,,,,,,,,,,,,,, -158200,3.9887562,3.6059134,,,,,,,,,,,,,, -158300,4.3335233,2.5270298,,,,,,,,,,,,,, -158400,4.360615,2.4242692,,,,,,,,,,,,,, -158500,4.5556574,2.595992,,,,,,,,,,,,,, -158600,4.9096985,4.0950127,,,,,,,,,,,,,, -158654,,,0.81507807970047,0.812286913394928,0.7413399815559387,1.1248809099197388,50000.0,0.625700056552887,1.7191606760025024,10000.0,72741.29397368431,78371.34893417358,72741.29397368431,5613.552079439163,7.765320539474487,0.0 -158700,4.3275623,3.3261714,,,,,,,,,,,,,, -158800,5.182661,4.0105076,,,,,,,,,,,,,, -158900,4.250817,3.063925,,,,,,,,,,,,,, -159000,4.480187,2.3901155,,,,,,,,,,,,,, -159100,3.9385498,2.7232895,,,,,,,,,,,,,, -159200,4.4002028,3.114353,,,,,,,,,,,,,, -159300,4.376199,2.4618301,,,,,,,,,,,,,, -159400,4.6946445,2.4310656,,,,,,,,,,,,,, -159500,4.4115553,2.319306,,,,,,,,,,,,,, -159569,,,0.8191210627555847,0.8072444200515747,0.7432799935340881,1.1267979145050049,50000.0,0.6260000467300415,1.720393419265747,10000.0,73161.59970808029,78823.85418653488,73161.59970808029,5645.6390788555145,7.827660083770752,0.0 -159600,4.1204267,3.6125839,,,,,,,,,,,,,, -159700,4.634108,3.74417,,,,,,,,,,,,,, -159800,4.3804255,2.3213081,,,,,,,,,,,,,, -159900,4.6273613,3.3315496,,,,,,,,,,,,,, -160000,4.9027715,4.0873656,,,,,,,,,,,,,, -160100,4.230355,3.58985,,,,,,,,,,,,,, -160200,4.9494886,3.7550235,,,,,,,,,,,,,, -160300,5.2208843,4.226954,,,,,,,,,,,,,, -160400,4.442016,2.449759,,,,,,,,,,,,,, -160442,,,0.8229101300239563,0.7834881544113159,0.744659960269928,1.1210800409317017,50000.0,0.6271000504493713,1.7195051908493042,10000.0,73581.53504562378,79278.61271739006,73581.53504562378,5680.3631727695465,7.879194021224976,0.0 -160500,4.391049,2.2305226,,,,,,,,,,,,,, -160600,5.195231,4.134737,,,,,,,,,,,,,, -160700,4.875384,2.451442,,,,,,,,,,,,,, -160800,4.734665,3.7968996,,,,,,,,,,,,,, -160900,4.7706327,2.5845597,,,,,,,,,,,,,, -161000,4.7686534,2.7071452,,,,,,,,,,,,,, -161100,4.5378895,2.2430353,,,,,,,,,,,,,, -161200,4.4794965,2.3434834,,,,,,,,,,,,,, -161300,4.768068,2.3828368,,,,,,,,,,,,,, -161358,,,0.82289057970047,0.7732518315315247,0.7462199926376343,1.0985074043273926,50000.0,0.6326000094413757,1.683404803276062,10000.0,74001.81647443771,79730.39583301544,74001.81647443771,5711.764150619507,7.929988384246826,0.0 -161400,4.196041,2.8996143,,,,,,,,,,,,,, -161500,4.246245,3.335432,,,,,,,,,,,,,, -161600,4.537725,2.3855171,,,,,,,,,,,,,, -161700,4.7841845,2.3898206,,,,,,,,,,,,,, -161800,4.4800596,2.2624183,,,,,,,,,,,,,, -161900,5.2313724,4.05805,,,,,,,,,,,,,, -162000,5.0408397,2.4286058,,,,,,,,,,,,,, -162100,6.156624,4.3218536,,,,,,,,,,,,,, -162200,4.320743,2.389998,,,,,,,,,,,,,, -162271,,,0.8252733945846558,0.7839264273643494,0.7461000084877014,1.112027645111084,50000.0,0.632900059223175,1.7059128284454346,10000.0,74421.98902630806,80182.82422280312,74421.98902630806,5743.91107749939,7.989051818847656,0.0 -162300,5.3419256,2.339717,,,,,,,,,,,,,, -162400,4.6352606,2.3238404,,,,,,,,,,,,,, -162500,4.869153,2.3675876,,,,,,,,,,,,,, -162600,4.6907244,2.323892,,,,,,,,,,,,,, -162700,4.9244046,2.3125324,,,,,,,,,,,,,, -162800,5.8455224,4.0541043,,,,,,,,,,,,,, -162900,4.9568796,3.629508,,,,,,,,,,,,,, -163000,4.859818,2.3388438,,,,,,,,,,,,,, -163100,4.318821,3.0724084,,,,,,,,,,,,,, -163186,,,0.8279882669448853,0.7568796277046204,0.7485599517822266,1.0927319526672363,50000.0,0.6315000057220459,1.6792042255401611,10000.0,74842.11705088615,80636.16126894951,74842.11705088615,5777.016916275024,8.03922438621521,0.0 -163200,4.730288,2.2930458,,,,,,,,,,,,,, -163300,5.1586876,2.4242024,,,,,,,,,,,,,, -163400,4.8682394,2.3573928,,,,,,,,,,,,,, -163500,5.072216,3.5809245,,,,,,,,,,,,,, -163600,5.10694,3.946532,,,,,,,,,,,,,, -163700,4.7455935,3.104183,,,,,,,,,,,,,, -163800,5.2272964,2.3490753,,,,,,,,,,,,,, -163900,4.9205446,2.384901,,,,,,,,,,,,,, -164000,5.081567,3.5595324,,,,,,,,,,,,,, -164100,6.056404,4.3720236,,,,,,,,,,,,,, -164101,,,0.8231250047683716,0.784327507019043,0.7511199712753296,1.100296974182129,50000.0,0.6350000500679016,1.6938539743423462,10000.0,75262.68575310707,81089.08715343475,75262.68575310707,5809.274814844132,8.088002920150757,0.0 -164200,5.6398892,4.2875805,,,,,,,,,,,,,, -164300,5.2638774,4.0132003,,,,,,,,,,,,,, -164400,4.870557,2.4171467,,,,,,,,,,,,,, -164500,4.61302,2.2430046,,,,,,,,,,,,,, -164600,5.3689013,2.3125122,,,,,,,,,,,,,, -164700,4.9183545,3.7023497,,,,,,,,,,,,,, -164800,4.537648,2.7750885,,,,,,,,,,,,,, -164900,5.2369103,2.5197003,,,,,,,,,,,,,, -165000,5.042342,2.373418,,,,,,,,,,,,,, -165017,,,0.8282421827316284,0.7752244472503662,0.7502599954605103,1.0974425077438354,50000.0,0.6292000412940979,1.6791881322860718,10000.0,75682.98915481567,81542.0298383236,75682.98915481567,5841.812193155289,8.13952898979187,0.0 -165100,4.770754,3.5231504,,,,,,,,,,,,,, -165200,4.7596083,2.4380221,,,,,,,,,,,,,, -165300,4.734285,2.524931,,,,,,,,,,,,,, -165400,4.86002,2.3368587,,,,,,,,,,,,,, -165500,4.7475796,3.2437222,,,,,,,,,,,,,, -165600,5.655956,4.2562637,,,,,,,,,,,,,, -165700,5.5778813,2.3191226,,,,,,,,,,,,,, -165800,5.221932,2.2731724,,,,,,,,,,,,,, -165900,4.9701157,2.6735435,,,,,,,,,,,,,, -165934,,,0.8337695002555847,0.7408263087272644,0.7530800104141235,1.0839965343475342,50000.0,0.6335000395774841,1.672433853149414,10000.0,76103.24325037003,81996.13554024696,76103.24325037003,5875.553389310837,8.199171781539917,0.0 -166000,4.9227448,2.4737453,,,,,,,,,,,,,, -166100,5.0983744,2.737607,,,,,,,,,,,,,, -166200,5.1221957,2.5105872,,,,,,,,,,,,,, -166300,6.687985,2.4605758,,,,,,,,,,,,,, -166400,5.02105,2.762753,,,,,,,,,,,,,, -166500,4.9829803,2.2688186,,,,,,,,,,,,,, -166600,5.181909,3.4327981,,,,,,,,,,,,,, -166700,5.5079284,2.60682,,,,,,,,,,,,,, -166800,4.809514,2.3458223,,,,,,,,,,,,,, -166850,,,0.8430468440055847,0.7069917917251587,0.7552599906921387,1.0686745643615725,50000.0,0.6446000337600708,1.6342931985855105,10000.0,76523.2434270382,82450.94435310364,76523.2434270382,5910.257388830185,8.253775358200073,0.0 -166900,5.388859,3.9067607,,,,,,,,,,,,,, -167000,5.0833106,3.7707791,,,,,,,,,,,,,, -167100,5.173587,2.3447607,,,,,,,,,,,,,, -167200,5.071762,2.8245008,,,,,,,,,,,,,, -167300,4.595765,2.4052374,,,,,,,,,,,,,, -167400,4.8020964,3.718265,,,,,,,,,,,,,, -167500,5.4971075,2.2624316,,,,,,,,,,,,,, -167600,5.2850475,2.274371,,,,,,,,,,,,,, -167700,5.5118318,3.7317371,,,,,,,,,,,,,, -167765,,,0.8338086009025574,0.7452972531318665,0.7548399567604065,1.074892282485962,50000.0,0.6406000256538391,1.656589150428772,10000.0,76943.4230298996,82904.4444053173,76943.4230298996,5943.479510307312,8.302427291870117,0.0 -167800,5.0213256,3.5008137,,,,,,,,,,,,,, -167900,4.5806794,3.036641,,,,,,,,,,,,,, -168000,5.791859,2.2566311,,,,,,,,,,,,,, -168100,5.455316,3.1214435,,,,,,,,,,,,,, -168200,5.024442,2.814005,,,,,,,,,,,,,, -168300,5.1516337,3.008329,,,,,,,,,,,,,, -168400,4.848825,3.166434,,,,,,,,,,,,,, -168500,5.4774704,2.9541168,,,,,,,,,,,,,, -168600,5.0681,2.3949292,,,,,,,,,,,,,, -168679,,,0.83363276720047,0.7375044822692871,0.7561999559402466,1.0720763206481934,50000.0,0.6388000249862671,1.652498722076416,10000.0,77363.56269574165,83358.4293923378,77363.56269574165,5977.216939926148,8.360501766204834,0.0 -168700,5.996604,4.105456,,,,,,,,,,,,,, -168800,4.8884335,2.345025,,,,,,,,,,,,,, -168900,5.3833513,2.8591452,,,,,,,,,,,,,, -169000,4.9739146,2.2644687,,,,,,,,,,,,,, -169029,,,,,,,,,,,77520.33018016815,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/eval_measurements.csv deleted file mode 100644 index d96dfc557..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -28.24734091758728,0.0,33.638116121292114,1,0,33.638116121292114,0.0010000000474974,6.907756805419922,10000,61.88614153862,0.000800781243015,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -63.434080839157104,0.0171256065368652,453.676509141922,861,0,453.676509141922,0.0138000007718801,6.452201843261719,10000,517.1768567562103,0.0163085926324129,6.390534400939941,0.0160399992018938,6.40871000289917,50000 -97.9622402191162,0.04842209815979,873.8910715579987,1772,0,873.8910715579987,0.036700002849102,5.948112487792969,10000,972.0015864372252,0.0479882806539535,5.804277896881104,0.0460199974477291,5.834495544433594,50000 -134.87805843353271,0.0764341354370117,1294.1512134075165,2684,0,1294.1512134075165,0.0516000017523765,5.641368389129639,10000,1429.2565631866455,0.0750781223177909,5.424489498138428,0.0673599988222122,5.472591876983643,50000 -168.3693425655365,0.1036822795867919,1714.4163200855255,3598,0,1714.4163200855255,0.077000007033348,5.386642932891846,10000,1883.0908553600311,0.1093554645776748,5.114729881286621,0.1017999947071075,5.160167217254639,50000 -204.27332472801208,0.1335217952728271,2134.3432302474976,4511,0,2134.3432302474976,0.1103000044822692,4.995343208312988,10000,2339.002152442932,0.1607226580381393,4.637515068054199,0.1462000012397766,4.709775924682617,50000 -239.3009958267212,0.158071756362915,2554.3411922454834,5424,0,2554.3411922454834,0.1410000026226043,4.718472480773926,10000,2794.102832555771,0.2049999982118606,4.282425880432129,0.1854999959468841,4.378244400024414,50000 -274.43955540657043,0.1884043216705322,2974.4840099811554,6337,0,2974.4840099811554,0.174900010228157,4.422224998474121,10000,3249.46443939209,0.2488867193460464,3.939825773239136,0.2300200015306472,4.036154747009277,50000 -308.14124870300293,0.2187526226043701,3394.577829360962,7251,0,3394.577829360962,0.2026000022888183,4.218814849853516,10000,3703.3408839702606,0.2895117104053497,3.659607172012329,0.2579399943351745,3.816094875335693,50000 -344.1261923313141,0.2454137802124023,3814.550952911377,8163,0,3814.550952911377,0.2290000170469284,4.028704643249512,10000,4159.377365589142,0.3244921863079071,3.4557206630706787,0.3014000058174133,3.572089910507202,50000 -377.9731593132019,0.2774481773376465,4234.478483200073,9074,0,4234.478483200073,0.2499000132083892,3.857383728027344,10000,4613.234064817429,0.3578320145606994,3.234337329864502,0.3291600048542022,3.37368392944336,50000 -411.2083342075348,0.3075971603393554,4654.747854232788,9986,0,4654.747854232788,0.2779000103473663,3.6841766834259033,10000,5066.819159269333,0.3981249928474426,2.985623359680176,0.3610599935054779,3.1759495735168457,50000 -445.02062249183655,0.3370425701141357,5074.991655111313,10899,0,5074.991655111313,0.3056000173091888,3.5289344787597656,10000,5520.955825805664,0.4224413931369781,2.8559553623199463,0.391759991645813,2.990187644958496,50000 -479.3563580513001,0.3674821853637695,5494.998826980591,11811,0,5494.998826980591,0.3346000015735626,3.379513740539551,10000,5975.380021810532,0.4569726586341858,2.680199146270752,0.4228200018405914,2.834578275680542,50000 -513.4131627082825,0.3969335556030273,5914.95384311676,12722,0,5914.95384311676,0.3421000242233276,3.313959836959839,10000,6429.472145318985,0.4777539074420929,2.5539870262146,0.4357999861240387,2.754807949066162,50000 -548.3734202384949,0.4250233173370361,6335.079059839249,13634,0,6335.079059839249,0.3613000214099884,3.21649432182312,10000,6884.635620117188,0.4913281202316284,2.4966847896575928,0.4591799974441528,2.6533918380737305,50000 -582.966285943985,0.4544599056243896,6755.355276584625,14550,0,6755.355276584625,0.3671000301837921,3.088716983795166,10000,7339.585505962372,0.514355480670929,2.332359790802002,0.4775199890136719,2.504133939743042,50000 -616.5665156841278,0.4815404415130615,7175.678120136261,15464,0,7175.678120136261,0.3820000290870666,3.105388879776001,10000,7793.58558678627,0.5285546779632568,2.330787420272827,0.4850399792194366,2.5376617908477783,50000 -651.0233614444733,0.5114836692810059,7596.013070583343,16379,0,7596.013070583343,0.3884000182151794,3.026716470718384,10000,8248.45794916153,0.5498241782188416,2.2195632457733154,0.5060399770736694,2.409464120864868,50000 -683.3534235954285,0.5411381721496582,8016.122918605804,17293,0,8016.122918605804,0.409600019454956,2.956143617630005,10000,8700.978935956955,0.551562488079071,2.1806676387786865,0.5150399804115295,2.34705114364624,50000 -717.8099949359894,0.5694966316223145,8436.345466375351,18206,0,8436.345466375351,0.4150000214576721,2.885838031768799,10000,9155.736872911451,0.5733984112739563,2.088829040527344,0.5263000130653381,2.2876055240631104,50000 -751.9210438728333,0.5966720581054688,8856.345707178116,19119,0,8856.345707178116,0.4210000336170196,2.853330612182617,10000,9609.926443576813,0.6010546684265137,1.936023116111756,0.5326799750328064,2.2370855808258057,50000 -786.1423320770264,0.6236920356750488,9276.2863240242,20024,0,9276.2863240242,0.4229000210762024,2.806061267852783,10000,10064.165142297745,0.581347644329071,2.0001742839813232,0.541920006275177,2.18528151512146,50000 -821.1112020015717,0.6576018333435059,9696.536899328232,20935,0,9696.536899328232,0.4413000345230102,2.791691541671753,10000,10519.46850657463,0.5981249809265137,1.97610092163086,0.5450199842453003,2.1948776245117188,50000 -855.9199452400208,0.6856215000152588,10116.994480848312,21847,0,10116.994480848312,0.4457000195980072,2.695769786834717,10000,10974.81406068802,0.6225780844688416,1.8207433223724363,0.5593400001525879,2.095564365386963,50000 -890.505363702774,0.7179117202758789,10537.3939139843,22761,0,10537.3939139843,0.448600023984909,2.720872640609741,10000,11429.881940364838,0.6095117330551147,1.905519366264344,0.5639399886131287,2.1037421226501465,50000 -923.5195500850676,0.7508177757263184,10957.388006210327,23672,0,10957.388006210327,0.4562000334262848,2.6657521724700928,10000,11882.973408699036,0.6233788728713989,1.81912624835968,0.5717200040817261,2.032801628112793,50000 -957.6350147724152,0.784376859664917,11377.64279603958,24585,0,11377.64279603958,0.456900030374527,2.673712015151977,10000,12337.427982330322,0.6335155963897705,1.789616942405701,0.5715799927711487,2.054870367050171,50000 -991.7295281887054,0.8135378360748291,11797.979766845703,25496,0,11797.979766845703,0.4708000123500824,2.616346597671509,10000,12791.94000005722,0.6327929496765137,1.804942488670349,0.5851399898529053,2.0230395793914795,50000 -1025.8188734054563,0.8469099998474121,12218.137630224228,26406,0,12218.137630224228,0.4729000329971313,2.6314289569854736,10000,13246.270922422407,0.6333593726158142,1.802080750465393,0.5868799686431885,2.013918876647949,50000 -1060.2494506835938,0.8771259784698486,12638.15647506714,27317,0,12638.15647506714,0.4727000296115875,2.574725866317749,10000,13700.801671981812,0.6479687094688416,1.6960623264312744,0.5909599661827087,1.9587397575378416,50000 -1093.944759130478,0.9052963256835938,13058.312211036682,28229,0,13058.312211036682,0.4753000140190124,2.569722890853882,10000,14154.731746673584,0.6434375047683716,1.7335150241851809,0.5952199697494507,1.9401875734329224,50000 -1127.892193555832,0.9366669654846193,13478.41717338562,29139,0,13478.41717338562,0.4758000373840332,2.5420379638671875,10000,14608.866518974304,0.6492577791213989,1.7050681114196775,0.601639986038208,1.9144740104675293,50000 -1161.8061337471008,0.969707489013672,13898.656394004822,30052,0,13898.656394004822,0.4879000186920166,2.49440336227417,10000,15063.104083538055,0.6590429544448853,1.6291348934173584,0.6089199781417847,1.8620059490203853,50000 -1195.0702483654022,1.0004019737243652,14318.683321475984,30963,0,14318.683321475984,0.4963000118732452,2.465266227722168,10000,15516.47614622116,0.6642773151397705,1.6285508871078491,0.6131199598312378,1.841789960861206,50000 -1228.8334305286407,1.0361483097076416,14739.014737844467,31875,0,14739.014737844467,0.4912000298500061,2.4833028316497803,10000,15970.657630205154,0.6638085842132568,1.6437461376190186,0.6118999719619751,1.870315432548523,50000 -1263.0607657432556,1.0695884227752686,15159.29290342331,32789,0,15159.29290342331,0.4969000220298767,2.4702181816101074,10000,16425.24866938591,0.6706835627555847,1.6040401458740234,0.614139974117279,1.843677759170532,50000 -1298.4260349273682,1.1072144508361816,15579.49324440956,33701,0,15579.49324440956,0.4958000183105469,2.447704315185547,10000,16880.903116226196,0.6911327838897705,1.4933984279632568,0.6180599927902222,1.8183931112289429,50000 -1331.7058236598969,1.141645908355713,15999.567883491516,34614,0,15999.567883491516,0.5020000338554382,2.422393560409546,10000,17334.343224287033,0.6721093654632568,1.5771911144256592,0.619219958782196,1.7937535047531128,50000 -1365.5460460186005,1.1748707294464111,16419.87308859825,35525,0,16419.87308859825,0.506600022315979,2.37952208518982,10000,17788.572286605835,0.6826562285423279,1.5192320346832275,0.6253799796104431,1.7672829627990725,50000 -1398.829402923584,1.20875883102417,16839.959003686905,36438,0,16839.959003686905,0.5051000118255615,2.390291929244995,10000,18242.025453329086,0.6946093440055847,1.4783991575241089,0.6278199553489685,1.7693872451782229,50000 -1432.1537964344025,1.2387840747833252,17260.044524669647,37351,0,17260.044524669647,0.5121999979019165,2.3797619342803955,10000,18695.51702475548,0.6821093559265137,1.5482853651046753,0.6320799589157104,1.763115167617798,50000 -1465.1609783172607,1.273481845855713,17680.28241252899,38264,0,17680.28241252899,0.5141000151634216,2.38298773765564,10000,19148.84790277481,0.6850000023841858,1.5315663814544678,0.6313199996948242,1.7667924165725708,50000 -1498.4337601661682,1.3083534240722656,18100.41104626656,39176,0,18100.41104626656,0.5213000178337097,2.35638165473938,10000,19602.334558963776,0.6991210579872131,1.461413025856018,0.6354999542236328,1.740901231765747,50000 -1531.565257549286,1.344674587249756,18520.329745292664,40088,0,18520.329745292664,0.5170000195503235,2.318955183029175,10000,20055.471518993378,0.6839648485183716,1.4925146102905271,0.6367599964141846,1.704120635986328,50000 -1565.088744878769,1.376600980758667,18940.719674110413,41001,0,18940.719674110413,0.5174000263214111,2.3340766429901123,10000,20509.467386484142,0.6875976324081421,1.500531792640686,0.6361799836158752,1.7311961650848389,50000 -1597.391491651535,1.410295486450195,19360.97429251671,41912,0,19360.97429251671,0.5248000025749207,2.298939943313598,10000,20962.10922479629,0.7040234208106995,1.4264371395111084,0.6458399891853333,1.6851632595062256,50000 -1631.407816171646,1.445805549621582,19781.33933091164,42821,0,19781.33933091164,0.5216000080108643,2.343894720077514,10000,21416.57683706284,0.6966992020606995,1.4827358722686768,0.6440399885177612,1.710224151611328,50000 -1665.082776069641,1.4831054210662842,20201.28215122223,43733,0,20201.28215122223,0.51910001039505,2.349139451980591,10000,21870.283143281937,0.6943163871765137,1.5012829303741455,0.6460599899291992,1.7175486087799072,50000 -1696.6110389232635,1.5203256607055664,20621.605316638947,44646,0,20621.605316638947,0.5242000222206116,2.3520002365112305,10000,22322.22277355194,0.70361328125,1.4847627878189087,0.643559992313385,1.7350969314575195,50000 -1730.922518491745,1.5553491115570068,21041.62238311768,45557,0,21041.62238311768,0.5333000421524048,2.260016918182373,10000,22776.63697385788,0.7167187333106995,1.356802463531494,0.6495999693870544,1.6522201299667358,50000 -1765.5656650066376,1.5932340621948242,21461.639271736145,46468,0,21461.639271736145,0.5254999995231628,2.357267379760742,10000,23231.38586783409,0.6974999904632568,1.5168081521987915,0.6491599678993225,1.7391897439956665,50000 -1798.4531662464142,1.6383380889892578,21881.792940855023,47381,0,21881.792940855023,0.5318000316619873,2.307551622390747,10000,23684.5236389637,0.7109375,1.4334065914154053,0.6492399573326111,1.697941541671753,50000 -1831.97790145874,1.6770923137664795,22301.990122556686,48293,0,22301.990122556686,0.5335000157356262,2.266916036605835,10000,24138.33581233025,0.73046875,1.311131715774536,0.6547600030899048,1.646582007408142,50000 -1865.7097523212435,1.711674690246582,22722.04015159607,49206,0,22722.04015159607,0.5358999967575073,2.262042999267578,10000,24592.202979803085,0.710156261920929,1.4143155813217163,0.6555399894714355,1.65362811088562,50000 -1900.3253211975093,1.7474756240844729,23142.09687113762,50115,0,23142.09687113762,0.5360000133514404,2.217803955078125,10000,25046.962195158005,0.7168163657188416,1.348970651626587,0.6584199666976929,1.6088238954544067,50000 -1935.4861042499545,1.7837295532226562,23562.426542520523,51027,0,23562.426542520523,0.5329000353813171,2.230520725250244,10000,25502.540413618088,0.7308984398841858,1.3000104427337646,0.6588199734687805,1.60606586933136,50000 -1970.5525019168847,1.8164739608764648,23982.65742635727,51941,0,23982.65742635727,0.5430000424385071,2.21612286567688,10000,25957.92196583748,0.71009761095047,1.3835265636444092,0.6599999666213989,1.6134287118911743,50000 -2004.1362063884733,1.8542118072509768,24402.8046002388,52851,0,24402.8046002388,0.5406000018119812,2.2067039012908936,10000,26411.741233587265,0.7220507860183716,1.3369845151901243,0.6620000004768372,1.5932549238204956,50000 -2037.2056503295896,1.89034390449524,24822.854472875595,53763,0,24822.854472875595,0.5420000553131104,2.219597339630127,10000,26864.94805812836,0.7294335961341858,1.317986011505127,0.6652799844741821,1.6028327941894531,50000 -2070.7375979423523,1.931077480316162,25243.066395998,54677,0,25243.066395998,0.5410000085830688,2.1876070499420166,10000,27318.78368639946,0.71937495470047,1.3365647792816162,0.6654199957847595,1.5735455751419067,50000 -2104.065089225769,1.9737062454223635,25663.03023004532,55591,0,25663.03023004532,0.5437000393867493,2.244570732116699,10000,27772.1686103344,0.7230077981948853,1.379983901977539,0.6684600114822388,1.6201144456863403,50000 -2139.5847957134247,2.0120248794555664,26083.24350643158,56506,0,26083.24350643158,0.5479000210762024,2.196012496948242,10000,28227.99153614044,0.7322070002555847,1.3016204833984375,0.66975998878479,1.5884922742843628,50000 -2174.6806008815765,2.048649311065674,26503.21027612686,57418,0,26503.21027612686,0.5449000000953674,2.177110433578491,10000,28683.141376018524,0.7277148365974426,1.3060152530670166,0.6685799956321716,1.5581419467926023,50000 -2209.344695329666,2.083533763885498,26923.5639295578,58331,0,26923.5639295578,0.5443000197410583,2.1926655769348145,10000,29138.24506425857,0.7244530916213989,1.3326257467269895,0.6692999601364136,1.571911096572876,50000 -2242.9132976531982,2.122434616088867,27343.48020362854,59242,0,27343.48020362854,0.5469000339508057,2.191420316696167,10000,29591.821184158325,0.7317968606948853,1.3039686679840088,0.671999990940094,1.571418523788452,50000 -2277.7794167995453,2.156461715698242,27763.48047947884,60156,0,27763.48047947884,0.5527999997138977,2.186643362045288,10000,30046.77208018303,0.7542577981948853,1.2266101837158203,0.671459972858429,1.5660921335220337,50000 -2311.4618620872498,2.193611860275269,28183.66246366501,61068,0,28183.66246366501,0.5565000176429749,2.116438865661621,10000,30500.72461438179,0.7308984398841858,1.2666162252426147,0.6772399544715881,1.4976990222930908,50000 -2346.8619673252106,2.235115766525269,28603.927173376083,61980,0,28603.927173376083,0.5525000095367432,2.141698837280273,10000,30956.482098340988,0.73646479845047,1.2581452131271362,0.6758399605751038,1.5217695236206057,50000 -2379.947853088379,2.274474620819092,29024.041388511658,62892,0,29024.041388511658,0.5552999973297119,2.162600517272949,10000,31409.77238416672,0.7489648461341858,1.2218483686447144,0.6783999800682068,1.5327138900756836,50000 -2412.2440707683563,2.648000955581665,29443.85234260559,63803,0,29443.85234260559,0.5556000471115112,2.131977796554565,10000,31862.30394887924,0.7317187190055847,1.283490777015686,0.677899956703186,1.513463854789734,50000 -2448.706707715988,2.6889126300811768,29863.933977127075,64715,0,29863.933977127075,0.5520000457763672,2.1582353115081787,10000,32318.9393632412,0.740527331829071,1.269747018814087,0.6794599890708923,1.536699652671814,50000 -2482.5249683856964,2.7266457080841064,30284.21901607513,65626,0,30284.21901607513,0.5551000237464905,2.1161949634552,10000,32773.13093471527,0.7494140267372131,1.1988508701324463,0.6782400012016296,1.5024621486663818,50000 -2517.2284500598907,2.763366460800171,30704.13455057144,66538,0,30704.13455057144,0.5455000400543213,2.175212383270264,10000,33227.839062690735,0.7292382717132568,1.3125627040863037,0.6787999868392944,1.5494202375411987,50000 -2550.1747620105743,2.8025145530700684,31124.17865896225,67449,0,31124.17865896225,0.558899998664856,2.096421718597412,10000,33680.919352293015,0.7469726204872131,1.2171990871429443,0.6873199939727783,1.4766743183135986,50000 -2584.592365026474,2.8403160572052,31544.106262922287,68359,0,31544.106262922287,0.5598000288009644,2.143460750579834,10000,34135.353620529175,0.7477148175239563,1.2257516384124756,0.6833999752998352,1.5142306089401243,50000 -2621.076464414597,2.879974603652954,31964.290115594864,69272,0,31964.290115594864,0.5688000321388245,2.0847983360290527,10000,34592.112513780594,0.7469335794448853,1.238759160041809,0.6876199841499329,1.490525484085083,50000 -2653.8648805618286,2.919417381286621,32384.55722093582,70182,0,32384.55722093582,0.5636000037193298,2.132345676422119,10000,35045.258445978165,0.7432226538658142,1.2668992280960083,0.6850199699401855,1.5202137231826782,50000 -2690.0571892261505,2.959210157394409,32804.878509521484,71092,0,32804.878509521484,0.5626000165939331,2.0810532569885254,10000,35501.86202931404,0.7544921636581421,1.1773465871810913,0.6906999945640564,1.4612739086151123,50000 -2723.2984120845795,2.999574899673462,33225.23718857765,72003,0,33225.23718857765,0.5637000203132629,2.097407341003418,10000,35955.55195403099,0.7497460842132568,1.197008490562439,0.6877399682998657,1.4694221019744873,50000 -2756.8011240959167,3.038689374923706,33645.199608802795,72913,0,33645.199608802795,0.5586000084877014,2.1017489433288574,10000,36409.107684612274,0.7474804520606995,1.2026795148849487,0.6865999698638916,1.473946213722229,50000 -2788.398811101913,3.081950187683105,34065.1739525795,73824,0,34065.1739525795,0.569100022315979,2.0638108253479004,10000,36860.7737197876,0.7582421898841858,1.1599042415618896,0.6906999945640564,1.4475252628326416,50000 -2823.236649990082,3.122302770614624,34485.24294900894,74735,0,34485.24294900894,0.5658000111579895,2.082012176513672,10000,37315.77154183388,0.7710937261581421,1.1149591207504272,0.6913999915122986,1.4553277492523191,50000 -2856.7374284267426,3.161123752593994,34905.34354329109,75645,0,34905.34354329109,0.5645000338554382,2.1078085899353027,10000,37769.46210241318,0.7518945336341858,1.226989984512329,0.6896199584007263,1.4911552667617798,50000 -2889.461337566376,3.200040578842163,35325.38757443428,76556,0,35325.38757443428,0.5637000203132629,2.063631772994995,10000,38222.31966996193,0.7568554282188416,1.1630686521530151,0.694379985332489,1.4409019947052002,50000 -2921.5647172927856,3.242464065551758,35745.66950130463,77469,0,35745.66950130463,0.572100043296814,2.0542802810668945,10000,38674.79872870445,0.7740234136581421,1.1011523008346558,0.6961199641227722,1.4330646991729736,50000 -2955.930624961853,3.281820297241211,36165.88693213463,78381,0,36165.88693213463,0.5669000148773193,2.084571599960327,10000,39129.471900224686,0.7498828172683716,1.1963554620742798,0.6926999688148499,1.452593445777893,50000 -2991.409699201584,3.3323724269866943,36585.79311347008,79292,0,36585.79311347008,0.5736000537872314,2.044193744659424,10000,39584.9595644474,0.7596288919448853,1.14617919921875,0.6948999762535095,1.4263193607330322,50000 -3024.742676258087,3.375904083251953,37006.06656885147,80204,0,37006.06656885147,0.5777000188827515,2.0484261512756348,10000,40038.66042852402,0.7778124809265137,1.1191484928131104,0.7011399865150452,1.4412455558776855,50000 -3061.377053260803,3.416645765304565,37425.9896273613,81112,0,37425.9896273613,0.570900022983551,2.0844738483428955,10000,40495.30886769295,0.75927734375,1.1983951330184937,0.6988799571990967,1.469262957572937,50000 -3097.956475019455,3.4548816680908203,37846.07184147835,82023,0,37846.07184147835,0.5773000121116638,2.0627498626708984,10000,40952.05909061432,0.7640624642372131,1.1621880531311035,0.702019989490509,1.4289867877960205,50000 -3130.1866416931152,3.492640733718872,38266.4472155571,82935,0,38266.4472155571,0.5746000409126282,2.0360336303710938,10000,41404.75353455544,0.7707226276397705,1.1098370552062988,0.7006799578666687,1.4164754152297974,50000 -3165.501063108444,3.5313708782196045,38686.41108036041,83846,0,38686.41108036041,0.5703000426292419,2.065577745437622,10000,41860.1210463047,0.7615429759025574,1.1838170289993286,0.700760006904602,1.4477368593215942,50000 -3200.868814229965,3.5725789070129395,39106.723685503006,84759,0,39106.723685503006,0.5855000019073486,2.0161125659942627,10000,42315.893428087234,0.7667187452316284,1.1391881704330444,0.7033799886703491,1.4073001146316528,50000 -3238.1050729751587,3.625221729278565,39526.67753863335,85671,0,39526.67753863335,0.5834000110626221,2.0242509841918945,10000,42773.18708443642,0.7760937213897705,1.1103529930114746,0.7053200006484985,1.4143160581588743,50000 -3276.332666158676,3.668002128601074,39946.76177215576,86581,0,39946.76177215576,0.5835000276565552,2.003753900527954,10000,43231.5920381546,0.7747851610183716,1.1178936958312988,0.7052599787712097,1.4013491868972778,50000 -3311.738111257553,3.710970640182495,40366.83447360992,87492,0,40366.83447360992,0.5782999992370605,2.0230345726013184,10000,43687.16769170761,0.7694140672683716,1.1339526176452637,0.7033799886703491,1.4059618711471558,50000 -3345.784725189209,3.752886056900024,40787.19777345657,88404,0,40787.19777345657,0.5807000398635864,1.9886709451675413,10000,44141.67059183121,0.7787304520606995,1.0719337463378906,0.7060399651527405,1.3819568157196045,50000 -3380.6273963451385,3.791879415512085,41207.16171503067,89314,0,41207.16171503067,0.5854000449180603,1.981370210647583,10000,44596.56769442558,0.7935351133346558,1.021241545677185,0.7061399817466736,1.3768731355667114,50000 -3416.499861478805,3.834437370300293,41627.17049217224,90225,0,41627.17049217224,0.581000030040741,2.001279354095459,10000,45052.54302406311,0.7718163728713989,1.1109107732772827,0.7108599543571472,1.373457670211792,50000 -3454.4704039096832,3.881521701812744,42047.39622235298,91136,0,42047.39622235298,0.5821000337600708,1.994511842727661,10000,45510.83752632141,0.7748632431030273,1.0888638496398926,0.7067999839782715,1.3886761665344238,50000 -3489.544613361358,3.9240567684173584,42467.585582733154,92046,0,42467.585582733154,0.585800051689148,1.991669774055481,10000,45966.19464612007,0.7909570336341858,1.043149471282959,0.7130199670791626,1.3840327262878418,50000 -3525.5610892772675,3.9683220386505127,42887.68548750877,92957,0,42887.68548750877,0.5952000021934509,1.9467597007751465,10000,46422.40642333031,0.7769726514816284,1.0807150602340698,0.7142399549484253,1.3572258949279783,50000 -3559.4507846832275,4.0101823806762695,43307.87548875809,93870,0,43307.87548875809,0.593000054359436,1.960601687431336,10000,46876.57886219025,0.784863293170929,1.0583343505859375,0.7152599692344666,1.3549997806549072,50000 -3594.6408054828644,4.049016237258911,43727.9484167099,94782,0,43727.9484167099,0.5889000296592712,1.9845054149627688,10000,47331.931569337845,0.7880077958106995,1.056152582168579,0.7136200070381165,1.3756757974624634,50000 -3629.39878821373,4.098080635070801,44147.98247885704,95694,0,44147.98247885704,0.5938000082969666,1.965236783027649,10000,47786.82453203201,0.7802929282188416,1.084344744682312,0.7172200083732605,1.356879949569702,50000 -3666.518232584,4.142136096954346,44568.36045074463,96606,0,44568.36045074463,0.5927000045776367,1.9558128118515008,10000,48244.41781044006,0.7871288657188416,1.0491591691970823,0.7164199948310852,1.3515568971633911,50000 -3702.0162086486816,4.1854517459869385,44988.40380692482,97517,0,44988.40380692482,0.593500018119812,1.929459810256958,10000,48700.0536942482,0.7954687476158142,1.002249240875244,0.7200999855995178,1.323193907737732,50000 -3736.6778705120087,4.226574182510376,45408.60495519638,98429,0,45408.60495519638,0.5968000292778015,1.9434531927108765,10000,49155.00846171379,0.7873241901397705,1.0458989143371582,0.7177599668502808,1.3348171710968018,50000 -3771.2031738758087,4.268864870071411,45828.62591218949,99341,0,45828.62591218949,0.6041000485420227,1.933828592300415,10000,49609.64790344238,0.7878515720367432,1.0436352491378784,0.7171199917793274,1.3446561098098757,50000 -3805.623178482056,4.309793710708618,46248.61977934837,100253,0,46248.61977934837,0.5974000096321106,1.94913387298584,10000,50064.15418791771,0.7922070026397705,1.0412719249725342,0.7170599699020386,1.350500226020813,50000 -3841.550982236862,4.3530638217926025,46668.6330909729,101164,0,46668.6330909729,0.5995000004768372,1.9695676565170288,10000,50520.18922710419,0.8041796684265137,1.0106101036071775,0.7200999855995178,1.3571895360946655,50000 -3878.951035261154,4.401503562927246,47088.92100191116,102076,0,47088.92100191116,0.5991000533103943,1.948511242866516,10000,50977.975981235504,0.7922070026397705,1.0534024238586426,0.7217999696731567,1.349803447723389,50000 -3916.5116069316864,4.4449450969696045,47508.928409576416,102988,0,47508.928409576416,0.5969000458717346,1.9463989734649656,10000,51435.638216257095,0.79798823595047,1.0308090448379517,0.7218199968338013,1.3451817035675049,50000 -3951.357241153717,4.487648963928223,47929.27530050278,103898,0,47929.27530050278,0.6037000417709351,1.9161393642425537,10000,51890.92503976822,0.8116992115974426,0.9391869306564332,0.7234399914741516,1.3024083375930786,50000 -3985.942140340805,4.530482769012451,48349.42050933838,104812,0,48349.42050933838,0.605400025844574,1.936792254447937,10000,52345.748777627945,0.7913671731948853,1.040604829788208,0.7239199876785278,1.3334083557128906,50000 -4022.8122441768646,4.570789575576782,48769.416988134384,105725,0,48769.416988134384,0.5963000059127808,1.9204466342926023,10000,52802.70588493347,0.7983007431030273,1.0020376443862915,0.7238799929618835,1.3144477605819702,50000 -4058.433498620987,4.963132858276367,49189.34255599976,106635,0,49189.34255599976,0.5995000004768372,1.908978581428528,10000,53258.69718146324,0.8101366758346558,0.9570842385292052,0.725820004940033,1.2992676496505735,50000 -4094.8872702121735,5.010452508926392,49609.5889275074,107544,0,49609.5889275074,0.6050000190734863,1.9033678770065308,10000,53715.4951851368,0.7962890267372131,1.0105478763580322,0.7277799844741821,1.2957388162612915,50000 -4132.766643047333,5.052933216094971,50029.638154029846,108455,0,50029.638154029846,0.6094000339508057,1.8905043601989744,10000,54173.51714682579,0.802539050579071,0.989858627319336,0.7298199534416199,1.2934014797210691,50000 -4171.423996925354,5.098599672317505,50449.68918180466,109365,0,50449.68918180466,0.6034000515937805,1.914489388465881,10000,54632.3221013546,0.8080468773841858,0.9750061631202698,0.7301200032234192,1.2988848686218262,50000 -4207.95077753067,5.143129587173462,50869.71073126793,110274,0,50869.71073126793,0.6038000583648682,1.9229018688201904,10000,55088.965970516205,0.7994921803474426,1.0312963724136353,0.7298199534416199,1.3248251676559448,50000 -4244.599578619003,5.188409090042114,51289.89997005463,111184,0,51289.89997005463,0.6066000461578369,1.912151455879212,10000,55545.89949464798,0.8047069907188416,0.9976529479026794,0.7334799766540527,1.296475529670715,50000 -4282.025975942612,5.237752914428711,51709.84427022934,112093,0,51709.84427022934,0.6106000542640686,1.889366149902344,10000,56003.3701505661,0.8110741972923279,0.9474124908447266,0.7343599796295166,1.2785613536834717,50000 -4317.279319286346,5.281320095062256,52129.879022836685,113001,0,52129.879022836685,0.6092000007629395,1.8907074928283687,10000,56458.75296974182,0.8006640672683716,0.9824928045272828,0.7322399616241455,1.2828460931777954,50000 -4352.38493680954,5.327167272567749,52549.861157894135,113912,0,52549.861157894135,0.6098000407218933,1.8746048212051392,10000,56913.93776369095,0.8100976347923279,0.9563019871711732,0.7339800000190735,1.2726399898529053,50000 -4387.382179737091,5.3694212436676025,52970.15888476372,114823,0,52970.15888476372,0.6060000061988831,1.896116852760315,10000,57369.3252120018,0.8125,0.9517192840576172,0.733460009098053,1.2863125801086426,50000 -4424.193821430206,5.413953542709351,53390.33986020088,115733,0,53390.33986020088,0.6184000372886658,1.8812453746795648,10000,57826.413011312485,0.8245702981948853,0.9153898358345032,0.7346000075340271,1.282979965209961,50000 -4460.06673002243,5.46102499961853,53810.46590876579,116642,0,53810.46590876579,0.615600049495697,1.860821962356568,10000,58282.51025009155,0.8141601085662842,0.9461965560913086,0.7366200089454651,1.266655683517456,50000 -4494.974093198776,5.504605293273926,54230.78221082688,117554,0,54230.78221082688,0.6165000200271606,1.889684081077576,10000,58737.82883524895,0.8140038847923279,0.9657190442085266,0.7342599630355835,1.2892900705337524,50000 -4531.71052479744,5.548776865005493,54650.84451293945,118465,0,54650.84451293945,0.6170000433921814,1.8633793592453003,10000,59194.7225048542,0.8282226324081421,0.9008296132087708,0.7376599907875061,1.2724199295043943,50000 -4566.349142313004,5.601795434951782,55070.77024292946,119375,0,55070.77024292946,0.6144000291824341,1.868046522140503,10000,59649.38969564438,0.81787109375,0.9437991380691528,0.7403199672698975,1.265620231628418,50000 -4602.237959384918,5.651767730712891,55490.71073937416,120285,0,55490.71073937416,0.6229000091552734,1.870540738105774,10000,60105.31962871552,0.8151366710662842,0.9468899369239808,0.7399799823760986,1.2717607021331787,50000 -4639.570498466492,5.699680805206299,55910.66052460671,121196,0,55910.66052460671,0.6231000423431396,1.803972601890564,10000,60562.70079231262,0.8320898413658142,0.8478189706802368,0.742419958114624,1.217381715774536,50000 -4674.371244430542,5.749354600906372,56330.80289840698,122107,0,56330.80289840698,0.6258000135421753,1.83869731426239,10000,61017.7448694706,0.8207421898841858,0.926133930683136,0.7434799671173096,1.2434664964675903,50000 -4710.162944316864,5.793039083480835,56751.03601574898,123019,0,56751.03601574898,0.6247000098228455,1.8158106803894043,10000,61473.86522865296,0.8250390291213989,0.8884112238883972,0.7450599670410156,1.2220823764801023,50000 -4743.8915066719055,5.839264869689941,57171.007404088974,123930,0,57171.007404088974,0.6234000325202942,1.8210102319717407,10000,61927.66263628006,0.8338476419448853,0.8774459362030029,0.7470399737358093,1.2402421236038208,50000 -4781.690023899078,5.884131908416748,57591.029952049255,124839,0,57591.029952049255,0.6278000473976135,1.8021245002746584,10000,62385.57867670059,0.8262499570846558,0.8751652240753174,0.7460199594497681,1.2193933725357056,50000 -4817.354856729507,5.930292129516602,58010.99251627922,125749,0,58010.99251627922,0.6221000552177429,1.8037065267562864,10000,62841.30251932144,0.8309765458106995,0.8596699833869934,0.7472400069236755,1.2096494436264038,50000 -4853.018300771713,5.975062370300293,58431.05421996117,126660,0,58431.05421996117,0.6221000552177429,1.84426748752594,10000,63297.124264001846,0.83447265625,0.8702915906906128,0.7460399866104126,1.243055820465088,50000 -4889.977033615112,6.026872634887695,58851.27061104775,127571,0,58851.27061104775,0.6247000098228455,1.8229761123657229,10000,63754.40138721466,0.8270898461341858,0.900422215461731,0.7506600022315979,1.2298741340637207,50000 -4924.630076885223,6.076805591583252,59271.248304367065,128484,0,59271.248304367065,0.6278000473976135,1.7846035957336426,10000,64209.13311076164,0.8315820097923279,0.8549019694328308,0.7501399517059326,1.189975619316101,50000 -4962.272361755371,6.125463962554932,59691.69730448723,129393,0,59691.69730448723,0.6319000124931335,1.7807656526565552,10000,64667.32375311852,0.8374413847923279,0.8287783861160278,0.7497599720954895,1.192121505737305,50000 -4996.821247339249,6.174273729324341,60111.69879961014,130301,0,60111.69879961014,0.6328000426292419,1.7882546186447144,10000,65121.97418880463,0.8481640219688416,0.8119232058525085,0.7521599531173706,1.201475977897644,50000 -5031.673780918121,6.2202746868133545,60531.89884757996,131212,0,60531.89884757996,0.6363000273704529,1.7767421007156372,10000,65577.1236114502,0.8323046565055847,0.855158269405365,0.750819981098175,1.1929678916931152,50000 -5068.405083656311,6.265010595321655,60951.83554720879,132124,0,60951.83554720879,0.6365000009536743,1.767372488975525,10000,66033.88747620583,0.8398241996765137,0.8325179815292358,0.7538399696350098,1.1827151775360107,50000 -5103.814422369003,6.311906576156616,61371.80930924416,133034,0,61371.80930924416,0.636400043964386,1.7902463674545288,10000,66489.36799764633,0.8498437404632568,0.8029429912567139,0.7543999552726746,1.2014082670211792,50000 -5139.275134801865,6.361809492111206,61791.847019433975,133945,0,61791.847019433975,0.6361000537872314,1.7660489082336426,10000,66944.96662092209,0.8379882574081421,0.8267109394073486,0.7539199590682983,1.1843620538711548,50000 -5174.021430253983,6.409978866577148,62211.88814759255,134857,0,62211.88814759255,0.6372000575065613,1.7756555080413818,10000,67399.85404849052,0.8421288728713989,0.8246394395828247,0.7559399604797363,1.183579444885254,50000 -5210.1441123485565,6.4558424949646,62631.87271881104,135766,0,62631.87271881104,0.635200023651123,1.788246989250183,10000,67856.05798530579,0.847460925579071,0.8194810152053833,0.7557799816131592,1.199291110038757,50000 -5248.283586978912,6.5085837841033936,63052.01003956795,136677,0,63052.01003956795,0.6416000127792358,1.7473360300064087,10000,68314.43851542473,0.8434374928474426,0.8096722364425659,0.75764000415802,1.1676242351531982,50000 -5285.890654087067,6.560576677322388,63472.32473063469,137587,0,63472.32473063469,0.6402000188827515,1.75056791305542,10000,68772.46301627159,0.8507421612739563,0.7920365929603577,0.7602199912071228,1.160401463508606,50000 -5321.47384428978,6.611251592636108,63892.652406454086,138498,0,63892.652406454086,0.6422000527381897,1.75530207157135,10000,69228.47512078285,0.8550781011581421,0.7727980613708496,0.7599799633026123,1.1643576622009275,50000 -5361.627897024155,6.65903902053833,64312.980843544006,139410,0,64312.980843544006,0.65010005235672,1.7389004230499268,10000,69689.05562400818,0.8495312333106995,0.8162821531295776,0.7620399594306946,1.1751127243041992,50000 -5397.175455093384,6.705945253372192,64732.98385691643,140320,0,64732.98385691643,0.6412000060081482,1.7588180303573608,10000,70144.70323824883,0.8509179353713989,0.797985315322876,0.7607600092887878,1.1660995483398438,50000 -5432.35400223732,6.752318620681763,65152.95014190674,141230,0,65152.95014190674,0.648900032043457,1.7374924421310425,10000,70599.94521164894,0.8560156226158142,0.7755244374275208,0.7628999948501587,1.1578243970870972,50000 -5470.536109209061,6.80316686630249,65573.02896523476,142141,0,65573.02896523476,0.6499000191688538,1.714654803276062,10000,71058.30751371384,0.85267573595047,0.7707575559616089,0.7630599737167358,1.145895004272461,50000 -5509.510308504105,6.851216554641724,65993.15498352051,143053,0,65993.15498352051,0.6479000449180603,1.7201541662216189,10000,71517.50735139847,0.8552343845367432,0.7671908140182495,0.7658599615097046,1.1423100233078003,50000 -5547.255858421326,6.90261173248291,66413.09290719032,143963,0,66413.09290719032,0.648300051689148,1.7129008769989014,10000,71975.2932009697,0.8605859279632568,0.7456603646278381,0.7659599781036377,1.1388903856277466,50000 -5584.90097284317,6.950519561767578,66833.13764357567,144873,0,66833.13764357567,0.6476000547409058,1.7198455333709717,10000,72433.08190202713,0.8684179782867432,0.724058210849762,0.7653599977493286,1.1371018886566162,50000 -5625.206836462021,7.006183862686157,67253.06000828743,145782,0,67253.06000828743,0.6463000178337097,1.7320133447647097,10000,72893.4165430069,0.8575780987739563,0.7682866454124451,0.764519989490509,1.1535091400146484,50000 -5662.5311868190765,7.058150768280029,67673.04480743408,146692,0,67673.04480743408,0.6510000228881836,1.7285139560699463,10000,73350.82883524895,0.8640234470367432,0.74609375,0.7672399878501892,1.1380559206008911,50000 -5698.334320783615,7.110531806945801,68093.3377597332,147604,0,68093.3377597332,0.6548000574111938,1.6992754936218262,10000,73807.02789735794,0.8709960579872131,0.6939271092414856,0.7691799998283386,1.1134551763534546,50000 -5735.344721794128,7.158812522888184,68513.33066034317,148514,0,68513.33066034317,0.6541000604629517,1.6979526281356812,10000,74264.12940835953,0.8633202910423279,0.7357399463653564,0.7691999673843384,1.1223732233047483,50000 -5773.777179718018,7.209390163421631,68933.41969394684,149422,0,68933.41969394684,0.65420001745224,1.6978703737258911,10000,74722.75263619423,0.8650000095367432,0.7219473123550415,0.7705199718475342,1.1146769523620603,50000 -5812.420975446701,7.256996393203735,69353.35374689102,150330,0,69353.35374689102,0.652400016784668,1.7470016479492188,10000,75181.42854118347,0.8721874952316284,0.741978645324707,0.7699599862098694,1.1527972221374512,50000 -5851.233624219894,7.306897401809692,69773.56124687195,151238,0,69773.56124687195,0.6549000144004822,1.6994355916976929,10000,75640.54940104485,0.8633593320846558,0.7349731922149658,0.7711600065231323,1.12005615234375,50000 -5886.390437841415,7.358208894729614,70193.82568836212,152146,0,70193.82568836212,0.6581000089645386,1.682053804397583,10000,76096.07240009308,0.8692968487739563,0.7008588314056396,0.7726799845695496,1.1004048585891724,50000 -5924.112850666046,7.41174578666687,70613.74158143997,153053,0,70613.74158143997,0.6557000279426575,1.7197359800338743,10000,76553.8149971962,0.8724218606948853,0.7150508761405945,0.7715199589729309,1.1307733058929443,50000 -5959.418373346329,7.46028733253479,71033.70966124535,153959,0,71033.70966124535,0.6573000550270081,1.6898996829986572,10000,77009.18772745132,0.8653124570846558,0.7207576632499695,0.7726999521255493,1.1108520030975342,50000 -5993.421750068665,7.513123273849487,71454.04873561859,154870,0,71454.04873561859,0.6573000550270081,1.684720754623413,10000,77463.63437724113,0.8743554353713989,0.6960504651069641,0.7752000093460083,1.1026690006256104,50000 -6029.695031404495,7.564982175827026,71874.4541144371,155782,0,71874.4541144371,0.6583000421524048,1.697216510772705,10000,77920.41535067558,0.872851550579071,0.6959668397903442,0.7761799693107605,1.1069529056549072,50000 -6064.638352155685,7.620237112045288,72294.57926630974,156692,0,72294.57926630974,0.6554000377655029,1.695989727973938,10000,78375.58976507187,0.8742382526397705,0.7012655735015869,0.7750200033187866,1.1078037023544312,50000 -6100.034034729004,7.676239013671875,72714.48941850662,157601,0,72714.48941850662,0.6605000495910645,1.682713747024536,10000,78831.00256085396,0.8741992115974426,0.7010714411735535,0.776919960975647,1.1080583333969116,50000 -6135.559215307236,7.725278377532959,73134.6789803505,158513,0,73134.6789803505,0.6610000133514404,1.6656324863433838,10000,79286.81630277634,0.8779296875,0.6749905943870544,0.7773799896240234,1.0918878316879272,50000 -6171.688840389252,7.782891035079956,73554.84628415108,159424,0,73554.84628415108,0.6589000225067139,1.6929190158843994,10000,79743.22312545776,0.8819335699081421,0.6788507103919983,0.7778399586677551,1.1027580499649048,50000 -6209.042106866837,7.834174156188965,73974.78279566765,160336,0,73974.78279566765,0.6620000600814819,1.66422700881958,10000,80200.6146595478,0.8783202767372131,0.6747952103614807,0.7796799540519714,1.08854079246521,50000 -6244.829773902893,7.885188102722168,74395.06458425522,161245,0,74395.06458425522,0.6619000434875488,1.676234245300293,10000,80656.78582334518,0.8847070336341858,0.6641880869865417,0.7793799638748169,1.094588279724121,50000 -6280.536881446838,7.93939733505249,74815.33997106552,162155,0,74815.33997106552,0.6637000441551208,1.6746059656143188,10000,81112.87297201157,0.8860546946525574,0.6558742523193359,0.7773999571800232,1.094112992286682,50000 -6315.99707698822,7.991906642913818,75235.40973472595,163065,0,75235.40973472595,0.6628000140190125,1.6721868515014648,10000,81568.50692629814,0.8809179663658142,0.6672109961509705,0.7792999744415283,1.0870453119277954,50000 -6352.588169336319,8.046016693115234,75655.63436961174,163977,0,75655.63436961174,0.6634000539779663,1.6664263010025024,10000,82025.42798662186,0.8843554258346558,0.6567904353141785,0.7809799909591675,1.085119605064392,50000 -6390.728995323181,8.098177194595337,76075.67520737648,164889,0,76075.67520737648,0.6682000160217285,1.664968967437744,10000,82483.71307039261,0.8860155940055847,0.6490659117698669,0.7811999917030334,1.085237979888916,50000 -6429.695057630539,8.155537843704224,76495.9494357109,165800,0,76495.9494357109,0.6676000356674194,1.659148931503296,10000,82943.06176996231,0.8830273151397705,0.6618872880935669,0.7819199562072754,1.0855050086975098,50000 -6464.591763496399,8.211509466171265,76916.21601438522,166711,0,76916.21601438522,0.6674000024795532,1.6567219495773315,10000,83398.33122730255,0.8886913657188416,0.645067036151886,0.7826399803161621,1.0786688327789309,50000 -6501.212018728256,8.268371820449829,77336.24432229996,167620,0,77336.24432229996,0.6629000306129456,1.6669039726257324,10000,83855.08708715439,0.88818359375,0.6586396098136902,0.7811399698257446,1.0905895233154297,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/measurements.csv deleted file mode 100644 index a1f028677..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1868 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.29378036,6.907754,,,,,,,,,,,,,, -1,,,0.000800781243015,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,33.638116121292114,61.88614153862,33.638116121292114,28.24734091758728,0.0,0.0 -100,0.32093313,6.9054646,,,,,,,,,,,,,, -200,0.43192223,6.8820705,,,,,,,,,,,,,, -300,0.49081007,6.8423805,,,,,,,,,,,,,, -400,0.5400204,6.8066015,,,,,,,,,,,,,, -500,0.5087024,6.787075,,,,,,,,,,,,,, -600,0.6336219,6.7353587,,,,,,,,,,,,,, -700,0.6481859,6.8159084,,,,,,,,,,,,,, -800,0.8692016,6.6820574,,,,,,,,,,,,,, -861,,,0.0163085926324129,6.390534400939941,0.0160399992018938,6.40871000289917,50000.0,0.0138000007718801,6.452201843261719,10000.0,453.676509141922,517.1768567562103,453.676509141922,63.434080839157104,0.0171256065368652,0.0 -900,1.3987123,6.598748,,,,,,,,,,,,,, -1000,1.3596609,6.5997133,,,,,,,,,,,,,, -1100,1.3272564,6.5207043,,,,,,,,,,,,,, -1200,1.0169679,6.481376,,,,,,,,,,,,,, -1300,1.1300431,6.5050597,,,,,,,,,,,,,, -1400,1.0779511,6.4661975,,,,,,,,,,,,,, -1500,1.3657479,6.488914,,,,,,,,,,,,,, -1600,1.1397434,6.359528,,,,,,,,,,,,,, -1700,0.8195744,6.625597,,,,,,,,,,,,,, -1772,,,0.0479882806539535,5.804277896881104,0.0460199974477291,5.834495544433594,50000.0,0.036700002849102,5.948112487792969,10000.0,873.8910715579987,972.0015864372252,873.8910715579987,97.9622402191162,0.04842209815979,0.0 -1800,1.0069796,6.569232,,,,,,,,,,,,,, -1900,0.98239475,6.7253504,,,,,,,,,,,,,, -2000,1.0651392,6.2220435,,,,,,,,,,,,,, -2100,1.2516338,6.252002,,,,,,,,,,,,,, -2200,1.0460929,6.2282367,,,,,,,,,,,,,, -2300,1.3127251,6.2030625,,,,,,,,,,,,,, -2400,1.1925633,6.197405,,,,,,,,,,,,,, -2500,1.2375587,6.178865,,,,,,,,,,,,,, -2600,0.94939065,6.1316533,,,,,,,,,,,,,, -2684,,,0.0750781223177909,5.424489498138428,0.0673599988222122,5.472591876983643,50000.0,0.0516000017523765,5.641368389129639,10000.0,1294.1512134075165,1429.2565631866455,1294.1512134075165,134.87805843353271,0.0764341354370117,0.0 -2700,0.9956859,6.568632,,,,,,,,,,,,,, -2800,1.3829552,6.1138964,,,,,,,,,,,,,, -2900,1.2814574,6.0841503,,,,,,,,,,,,,, -3000,0.9298249,6.1438475,,,,,,,,,,,,,, -3100,0.9341716,6.04572,,,,,,,,,,,,,, -3200,1.2432295,6.597198,,,,,,,,,,,,,, -3300,1.0589963,5.973632,,,,,,,,,,,,,, -3400,1.8107166,6.0419407,,,,,,,,,,,,,, -3500,1.3874265,6.208271,,,,,,,,,,,,,, -3598,,,0.1093554645776748,5.114729881286621,0.1017999947071075,5.160167217254639,50000.0,0.077000007033348,5.386642932891846,10000.0,1714.4163200855255,1883.0908553600311,1714.4163200855255,168.3693425655365,0.1036822795867919,0.0 -3600,1.0308965,6.5021505,,,,,,,,,,,,,, -3700,1.2052927,5.8895397,,,,,,,,,,,,,, -3800,1.692345,5.9520235,,,,,,,,,,,,,, -3900,1.0130464,5.9367647,,,,,,,,,,,,,, -4000,0.9911086,5.8610077,,,,,,,,,,,,,, -4100,1.0323453,5.818923,,,,,,,,,,,,,, -4200,1.0513152,5.774354,,,,,,,,,,,,,, -4300,1.2098572,5.7651615,,,,,,,,,,,,,, -4400,0.8186908,6.474824,,,,,,,,,,,,,, -4500,1.1849225,5.733221,,,,,,,,,,,,,, -4511,,,0.1607226580381393,4.637515068054199,0.1462000012397766,4.709775924682617,50000.0,0.1103000044822692,4.995343208312988,10000.0,2134.3432302474976,2339.002152442932,2134.3432302474976,204.27332472801208,0.1335217952728271,0.0 -4600,1.2857487,5.7103915,,,,,,,,,,,,,, -4700,1.251294,5.7016344,,,,,,,,,,,,,, -4800,0.9432155,6.554584,,,,,,,,,,,,,, -4900,1.1558504,5.6393437,,,,,,,,,,,,,, -5000,0.99342924,5.7486925,,,,,,,,,,,,,, -5100,1.5304261,5.701103,,,,,,,,,,,,,, -5200,1.3217181,5.552513,,,,,,,,,,,,,, -5300,0.8622296,6.424939,,,,,,,,,,,,,, -5400,1.1694787,6.5840235,,,,,,,,,,,,,, -5424,,,0.2049999982118606,4.282425880432129,0.1854999959468841,4.378244400024414,50000.0,0.1410000026226043,4.718472480773926,10000.0,2554.3411922454834,2794.102832555771,2554.3411922454834,239.3009958267212,0.158071756362915,0.0 -5500,0.9021115,5.8201356,,,,,,,,,,,,,, -5600,1.2970493,5.477948,,,,,,,,,,,,,, -5700,0.9790323,5.731115,,,,,,,,,,,,,, -5800,0.8329662,5.938535,,,,,,,,,,,,,, -5900,0.9916728,5.49112,,,,,,,,,,,,,, -6000,1.1580856,5.753063,,,,,,,,,,,,,, -6100,1.1538005,5.3209343,,,,,,,,,,,,,, -6200,1.1247162,5.803546,,,,,,,,,,,,,, -6300,1.0695916,5.3154407,,,,,,,,,,,,,, -6337,,,0.2488867193460464,3.939825773239136,0.2300200015306472,4.036154747009277,50000.0,0.174900010228157,4.422224998474121,10000.0,2974.4840099811554,3249.46443939209,2974.4840099811554,274.43955540657043,0.1884043216705322,0.0 -6400,1.0529957,5.458369,,,,,,,,,,,,,, -6500,1.2539123,5.552887,,,,,,,,,,,,,, -6600,0.925616,5.720832,,,,,,,,,,,,,, -6700,0.7664814,6.4468184,,,,,,,,,,,,,, -6800,0.99296504,5.532189,,,,,,,,,,,,,, -6900,1.213226,5.2165093,,,,,,,,,,,,,, -7000,0.9429502,6.482198,,,,,,,,,,,,,, -7100,1.1178781,6.4524403,,,,,,,,,,,,,, -7200,0.93158615,6.342278,,,,,,,,,,,,,, -7251,,,0.2895117104053497,3.659607172012329,0.2579399943351745,3.816094875335693,50000.0,0.2026000022888183,4.218814849853516,10000.0,3394.577829360962,3703.3408839702606,3394.577829360962,308.14124870300293,0.2187526226043701,0.0 -7300,0.77762914,6.315944,,,,,,,,,,,,,, -7400,1.2180384,5.1089144,,,,,,,,,,,,,, -7500,0.8897846,5.6507983,,,,,,,,,,,,,, -7600,0.95682216,5.1000366,,,,,,,,,,,,,, -7700,1.0242583,5.1163054,,,,,,,,,,,,,, -7800,1.2472111,5.161247,,,,,,,,,,,,,, -7900,0.8376461,5.781951,,,,,,,,,,,,,, -8000,1.4289161,5.0425367,,,,,,,,,,,,,, -8100,1.0636063,5.0309377,,,,,,,,,,,,,, -8163,,,0.3244921863079071,3.4557206630706787,0.3014000058174133,3.572089910507202,50000.0,0.2290000170469284,4.028704643249512,10000.0,3814.550952911377,4159.377365589142,3814.550952911377,344.1261923313141,0.2454137802124023,0.0 -8200,0.6923196,6.3016787,,,,,,,,,,,,,, -8300,0.96808475,5.0613375,,,,,,,,,,,,,, -8400,1.1968992,4.9376535,,,,,,,,,,,,,, -8500,1.4005997,4.9933805,,,,,,,,,,,,,, -8600,1.0790908,5.0912867,,,,,,,,,,,,,, -8700,0.8800077,5.340657,,,,,,,,,,,,,, -8800,1.0271211,4.9056416,,,,,,,,,,,,,, -8900,0.9069841,5.858944,,,,,,,,,,,,,, -9000,1.0951066,4.8969684,,,,,,,,,,,,,, -9074,,,0.3578320145606994,3.234337329864502,0.3291600048542022,3.37368392944336,50000.0,0.2499000132083892,3.857383728027344,10000.0,4234.478483200073,4613.234064817429,4234.478483200073,377.9731593132019,0.2774481773376465,0.0 -9100,0.681334,6.2926087,,,,,,,,,,,,,, -9200,0.707688,6.2649384,,,,,,,,,,,,,, -9300,0.6810168,6.0618277,,,,,,,,,,,,,, -9400,0.8330123,6.3152504,,,,,,,,,,,,,, -9500,0.722674,6.2927284,,,,,,,,,,,,,, -9600,0.9734195,4.972359,,,,,,,,,,,,,, -9700,1.3159074,4.92693,,,,,,,,,,,,,, -9800,1.0409005,4.8286533,,,,,,,,,,,,,, -9900,0.9504471,4.7760925,,,,,,,,,,,,,, -9986,,,0.3981249928474426,2.985623359680176,0.3610599935054779,3.1759495735168457,50000.0,0.2779000103473663,3.6841766834259033,10000.0,4654.747854232788,5066.819159269333,4654.747854232788,411.2083342075348,0.3075971603393554,0.0 -10000,0.96230155,4.8009567,,,,,,,,,,,,,, -10100,0.86380297,5.4900217,,,,,,,,,,,,,, -10200,0.7995179,5.772729,,,,,,,,,,,,,, -10300,0.97974324,4.709177,,,,,,,,,,,,,, -10400,0.93970644,4.9333105,,,,,,,,,,,,,, -10500,1.0012776,4.648302,,,,,,,,,,,,,, -10600,0.95370597,4.997157,,,,,,,,,,,,,, -10700,0.9006133,5.163813,,,,,,,,,,,,,, -10800,0.8498352,6.249497,,,,,,,,,,,,,, -10899,,,0.4224413931369781,2.8559553623199463,0.391759991645813,2.990187644958496,50000.0,0.3056000173091888,3.5289344787597656,10000.0,5074.991655111313,5520.955825805664,5074.991655111313,445.02062249183655,0.3370425701141357,0.0 -10900,0.90817505,4.721223,,,,,,,,,,,,,, -11000,0.8591127,4.904366,,,,,,,,,,,,,, -11100,0.98950475,4.6460238,,,,,,,,,,,,,, -11200,0.6480378,5.9756026,,,,,,,,,,,,,, -11300,1.0499123,4.7907314,,,,,,,,,,,,,, -11400,0.6655779,5.9381585,,,,,,,,,,,,,, -11500,1.1542008,4.545001,,,,,,,,,,,,,, -11600,0.85418046,4.5301933,,,,,,,,,,,,,, -11700,0.92468035,4.6056576,,,,,,,,,,,,,, -11800,0.95207477,4.6491227,,,,,,,,,,,,,, -11811,,,0.4569726586341858,2.680199146270752,0.4228200018405914,2.834578275680542,50000.0,0.3346000015735626,3.379513740539551,10000.0,5494.998826980591,5975.380021810532,5494.998826980591,479.3563580513001,0.3674821853637695,0.0 -11900,0.9532515,4.454133,,,,,,,,,,,,,, -12000,0.9593404,4.4709077,,,,,,,,,,,,,, -12100,1.1640481,4.518027,,,,,,,,,,,,,, -12200,0.6647208,5.9849195,,,,,,,,,,,,,, -12300,0.7031978,5.4470453,,,,,,,,,,,,,, -12400,0.6744128,6.0640745,,,,,,,,,,,,,, -12500,0.95290685,4.49859,,,,,,,,,,,,,, -12600,0.72124505,6.119844,,,,,,,,,,,,,, -12700,1.0149512,4.504995,,,,,,,,,,,,,, -12722,,,0.4777539074420929,2.5539870262146,0.4357999861240387,2.754807949066162,50000.0,0.3421000242233276,3.313959836959839,10000.0,5914.95384311676,6429.472145318985,5914.95384311676,513.4131627082825,0.3969335556030273,0.0 -12800,0.92112327,4.512724,,,,,,,,,,,,,, -12900,0.93616486,4.3652925,,,,,,,,,,,,,, -13000,0.6465939,6.119732,,,,,,,,,,,,,, -13100,0.94249487,4.5583014,,,,,,,,,,,,,, -13200,0.92576426,4.456536,,,,,,,,,,,,,, -13300,0.9342562,4.4053516,,,,,,,,,,,,,, -13400,0.8967752,4.5743375,,,,,,,,,,,,,, -13500,0.7396695,5.334022,,,,,,,,,,,,,, -13600,0.942825,4.3497133,,,,,,,,,,,,,, -13634,,,0.4913281202316284,2.4966847896575928,0.4591799974441528,2.6533918380737305,50000.0,0.3613000214099884,3.21649432182312,10000.0,6335.079059839249,6884.635620117188,6335.079059839249,548.3734202384949,0.4250233173370361,0.0 -13700,0.97315115,4.391658,,,,,,,,,,,,,, -13800,1.0868237,4.420297,,,,,,,,,,,,,, -13900,0.7744133,5.038829,,,,,,,,,,,,,, -14000,0.84004563,4.6203976,,,,,,,,,,,,,, -14100,0.92475104,4.3037524,,,,,,,,,,,,,, -14200,1.0509282,4.319062,,,,,,,,,,,,,, -14300,0.9897817,4.428795,,,,,,,,,,,,,, -14400,0.6051523,5.667149,,,,,,,,,,,,,, -14500,0.8552441,4.25669,,,,,,,,,,,,,, -14550,,,0.514355480670929,2.332359790802002,0.4775199890136719,2.504133939743042,50000.0,0.3671000301837921,3.088716983795166,10000.0,6755.355276584625,7339.585505962372,6755.355276584625,582.966285943985,0.4544599056243896,0.0 -14600,0.8779924,4.2177153,,,,,,,,,,,,,, -14700,0.7497479,5.81092,,,,,,,,,,,,,, -14800,0.9369479,4.2695236,,,,,,,,,,,,,, -14900,0.6469347,5.745248,,,,,,,,,,,,,, -15000,0.70725507,6.003191,,,,,,,,,,,,,, -15100,0.8517229,4.2128625,,,,,,,,,,,,,, -15200,0.8995087,4.170624,,,,,,,,,,,,,, -15300,0.9242904,4.4560537,,,,,,,,,,,,,, -15400,0.93761593,4.250752,,,,,,,,,,,,,, -15464,,,0.5285546779632568,2.330787420272827,0.4850399792194366,2.5376617908477783,50000.0,0.3820000290870666,3.105388879776001,10000.0,7175.678120136261,7793.58558678627,7175.678120136261,616.5665156841278,0.4815404415130615,0.0 -15500,0.8600569,4.520481,,,,,,,,,,,,,, -15600,0.9078675,4.2066994,,,,,,,,,,,,,, -15700,0.89040065,4.4838314,,,,,,,,,,,,,, -15800,0.725104,5.3526497,,,,,,,,,,,,,, -15900,0.8761096,4.1849594,,,,,,,,,,,,,, -16000,0.9208193,4.1465487,,,,,,,,,,,,,, -16100,0.9053377,4.253268,,,,,,,,,,,,,, -16200,0.8381008,4.6533723,,,,,,,,,,,,,, -16300,0.9085048,4.2640085,,,,,,,,,,,,,, -16379,,,0.5498241782188416,2.2195632457733154,0.5060399770736694,2.409464120864868,50000.0,0.3884000182151794,3.026716470718384,10000.0,7596.013070583343,8248.45794916153,7596.013070583343,651.0233614444733,0.5114836692810059,0.0 -16400,0.92678446,4.361484,,,,,,,,,,,,,, -16500,0.8465249,4.524189,,,,,,,,,,,,,, -16600,1.0277636,4.148208,,,,,,,,,,,,,, -16700,0.9911067,4.4627476,,,,,,,,,,,,,, -16800,1.0328287,4.6585255,,,,,,,,,,,,,, -16900,0.94972163,4.1903996,,,,,,,,,,,,,, -17000,0.7361724,4.81029,,,,,,,,,,,,,, -17100,0.93293285,4.096544,,,,,,,,,,,,,, -17200,1.0747104,4.14526,,,,,,,,,,,,,, -17293,,,0.551562488079071,2.1806676387786865,0.5150399804115295,2.34705114364624,50000.0,0.409600019454956,2.956143617630005,10000.0,8016.122918605804,8700.978935956955,8016.122918605804,683.3534235954285,0.5411381721496582,0.0 -17300,0.9485876,4.1884785,,,,,,,,,,,,,, -17400,0.9583888,4.1961203,,,,,,,,,,,,,, -17500,0.7941146,4.857909,,,,,,,,,,,,,, -17600,0.792579,5.4380674,,,,,,,,,,,,,, -17700,0.9199699,4.160411,,,,,,,,,,,,,, -17800,1.0176833,4.160239,,,,,,,,,,,,,, -17900,0.695653,5.9279704,,,,,,,,,,,,,, -18000,1.1068286,4.2699094,,,,,,,,,,,,,, -18100,0.90857154,4.0576096,,,,,,,,,,,,,, -18200,0.9356694,4.0997505,,,,,,,,,,,,,, -18206,,,0.5733984112739563,2.088829040527344,0.5263000130653381,2.2876055240631104,50000.0,0.4150000214576721,2.885838031768799,10000.0,8436.345466375351,9155.736872911451,8436.345466375351,717.8099949359894,0.5694966316223145,0.0 -18300,0.94634336,4.2023373,,,,,,,,,,,,,, -18400,1.0655799,4.0250154,,,,,,,,,,,,,, -18500,0.92376816,4.228206,,,,,,,,,,,,,, -18600,0.85671383,4.7922897,,,,,,,,,,,,,, -18700,0.69200003,5.7817535,,,,,,,,,,,,,, -18800,0.96515167,4.0537786,,,,,,,,,,,,,, -18900,0.92617923,4.398779,,,,,,,,,,,,,, -19000,0.92323256,4.041466,,,,,,,,,,,,,, -19100,0.90305847,4.09908,,,,,,,,,,,,,, -19119,,,0.6010546684265137,1.936023116111756,0.5326799750328064,2.2370855808258057,50000.0,0.4210000336170196,2.853330612182617,10000.0,8856.345707178116,9609.926443576813,8856.345707178116,751.9210438728333,0.5966720581054688,0.0 -19200,0.8058081,4.6880145,,,,,,,,,,,,,, -19300,0.88443136,4.327604,,,,,,,,,,,,,, -19400,0.91233677,4.248243,,,,,,,,,,,,,, -19500,1.0255277,4.1521063,,,,,,,,,,,,,, -19600,0.8950974,4.0547338,,,,,,,,,,,,,, -19700,0.78663474,4.646948,,,,,,,,,,,,,, -19800,0.9517504,3.9954712,,,,,,,,,,,,,, -19900,0.92420423,4.016351,,,,,,,,,,,,,, -20000,0.6610921,5.566089,,,,,,,,,,,,,, -20024,,,0.581347644329071,2.0001742839813232,0.541920006275177,2.18528151512146,50000.0,0.4229000210762024,2.806061267852783,10000.0,9276.2863240242,10064.165142297745,9276.2863240242,786.1423320770264,0.6236920356750488,0.0 -20100,0.9010285,4.1208324,,,,,,,,,,,,,, -20200,0.7461008,5.1886587,,,,,,,,,,,,,, -20300,0.9369366,4.0841193,,,,,,,,,,,,,, -20400,0.9276465,4.1808653,,,,,,,,,,,,,, -20500,0.97487134,3.9524648,,,,,,,,,,,,,, -20600,0.9178257,4.2288933,,,,,,,,,,,,,, -20700,0.695153,5.750394,,,,,,,,,,,,,, -20800,0.6716325,5.677337,,,,,,,,,,,,,, -20900,0.74231976,5.6552486,,,,,,,,,,,,,, -20935,,,0.5981249809265137,1.97610092163086,0.5450199842453003,2.1948776245117188,50000.0,0.4413000345230102,2.791691541671753,10000.0,9696.536899328232,10519.46850657463,9696.536899328232,821.1112020015717,0.6576018333435059,0.0 -21000,0.825948,4.807064,,,,,,,,,,,,,, -21100,0.81154966,5.633216,,,,,,,,,,,,,, -21200,0.90361387,4.1796694,,,,,,,,,,,,,, -21300,0.9554571,4.0809364,,,,,,,,,,,,,, -21400,0.7530163,5.1642365,,,,,,,,,,,,,, -21500,1.0195984,3.9494815,,,,,,,,,,,,,, -21600,0.67373496,5.624689,,,,,,,,,,,,,, -21700,0.7389797,4.772542,,,,,,,,,,,,,, -21800,0.84822696,4.183843,,,,,,,,,,,,,, -21847,,,0.6225780844688416,1.8207433223724363,0.5593400001525879,2.095564365386963,50000.0,0.4457000195980072,2.695769786834717,10000.0,10116.994480848312,10974.81406068802,10116.994480848312,855.9199452400208,0.6856215000152588,0.0 -21900,0.9192051,3.9683576,,,,,,,,,,,,,, -22000,0.72682077,4.921135,,,,,,,,,,,,,, -22100,1.0391148,3.9162078,,,,,,,,,,,,,, -22200,0.97700506,4.0089397,,,,,,,,,,,,,, -22300,0.97195065,3.9500198,,,,,,,,,,,,,, -22400,0.7960354,5.6149507,,,,,,,,,,,,,, -22500,0.64593744,5.657979,,,,,,,,,,,,,, -22600,0.9295452,4.5530705,,,,,,,,,,,,,, -22700,0.9074857,4.151124,,,,,,,,,,,,,, -22761,,,0.6095117330551147,1.905519366264344,0.5639399886131287,2.1037421226501465,50000.0,0.448600023984909,2.720872640609741,10000.0,10537.3939139843,11429.881940364838,10537.3939139843,890.505363702774,0.7179117202758789,0.0 -22800,1.0147501,4.2088814,,,,,,,,,,,,,, -22900,0.88501996,3.9680648,,,,,,,,,,,,,, -23000,0.723352,5.7327995,,,,,,,,,,,,,, -23100,0.90488315,3.8730764,,,,,,,,,,,,,, -23200,0.80754656,5.341397,,,,,,,,,,,,,, -23300,0.98895305,3.9865515,,,,,,,,,,,,,, -23400,0.8944821,4.3467484,,,,,,,,,,,,,, -23500,1.0195292,3.9574642,,,,,,,,,,,,,, -23600,1.0106058,3.8812478,,,,,,,,,,,,,, -23672,,,0.6233788728713989,1.81912624835968,0.5717200040817261,2.032801628112793,50000.0,0.4562000334262848,2.6657521724700928,10000.0,10957.388006210327,11882.973408699036,10957.388006210327,923.5195500850676,0.7508177757263184,0.0 -23700,0.9555288,4.3004627,,,,,,,,,,,,,, -23800,0.88453925,4.322708,,,,,,,,,,,,,, -23900,0.69513744,5.6702404,,,,,,,,,,,,,, -24000,0.79307485,4.437585,,,,,,,,,,,,,, -24100,1.0884377,3.9791045,,,,,,,,,,,,,, -24200,0.88077843,5.678124,,,,,,,,,,,,,, -24300,0.8063163,4.656919,,,,,,,,,,,,,, -24400,0.7276219,4.8741174,,,,,,,,,,,,,, -24500,0.94800866,4.0040207,,,,,,,,,,,,,, -24585,,,0.6335155963897705,1.789616942405701,0.5715799927711487,2.054870367050171,50000.0,0.456900030374527,2.673712015151977,10000.0,11377.64279603958,12337.427982330322,11377.64279603958,957.6350147724152,0.784376859664917,0.0 -24600,0.95787275,3.9469788,,,,,,,,,,,,,, -24700,0.7162851,5.7086687,,,,,,,,,,,,,, -24800,1.0306379,3.8243089,,,,,,,,,,,,,, -24900,1.0842366,3.9471023,,,,,,,,,,,,,, -25000,0.73619246,5.3067145,,,,,,,,,,,,,, -25100,0.7757311,5.0722737,,,,,,,,,,,,,, -25200,0.96749055,3.9594815,,,,,,,,,,,,,, -25300,1.0279956,3.8586116,,,,,,,,,,,,,, -25400,0.947784,4.0097637,,,,,,,,,,,,,, -25496,,,0.6327929496765137,1.804942488670349,0.5851399898529053,2.0230395793914795,50000.0,0.4708000123500824,2.616346597671509,10000.0,11797.979766845703,12791.94000005722,11797.979766845703,991.7295281887054,0.8135378360748291,0.0 -25500,0.88046914,4.4722357,,,,,,,,,,,,,, -25600,0.750736,5.6858773,,,,,,,,,,,,,, -25700,1.0771359,3.9751077,,,,,,,,,,,,,, -25800,1.0007553,3.9122925,,,,,,,,,,,,,, -25900,0.9658046,3.906466,,,,,,,,,,,,,, -26000,0.7550489,4.721419,,,,,,,,,,,,,, -26100,0.95624477,3.8307104,,,,,,,,,,,,,, -26200,0.93678564,3.848211,,,,,,,,,,,,,, -26300,0.8124624,4.791731,,,,,,,,,,,,,, -26400,0.9300761,4.6895175,,,,,,,,,,,,,, -26406,,,0.6333593726158142,1.802080750465393,0.5868799686431885,2.013918876647949,50000.0,0.4729000329971313,2.6314289569854736,10000.0,12218.137630224228,13246.270922422407,12218.137630224228,1025.8188734054563,0.8469099998474121,0.0 -26500,0.9563252,3.9078846,,,,,,,,,,,,,, -26600,0.9313797,3.9816551,,,,,,,,,,,,,, -26700,0.83092046,4.4899025,,,,,,,,,,,,,, -26800,0.8723821,4.1340837,,,,,,,,,,,,,, -26900,0.9997833,3.8035774,,,,,,,,,,,,,, -27000,0.84581363,5.1518583,,,,,,,,,,,,,, -27100,0.7940672,4.955906,,,,,,,,,,,,,, -27200,0.8090547,5.0241733,,,,,,,,,,,,,, -27300,0.99982727,3.7751226,,,,,,,,,,,,,, -27317,,,0.6479687094688416,1.6960623264312744,0.5909599661827087,1.9587397575378416,50000.0,0.4727000296115875,2.574725866317749,10000.0,12638.15647506714,13700.801671981812,12638.15647506714,1060.2494506835938,0.8771259784698486,0.0 -27400,0.8145634,5.2344503,,,,,,,,,,,,,, -27500,0.8134555,4.187135,,,,,,,,,,,,,, -27600,0.90218675,4.0553813,,,,,,,,,,,,,, -27700,0.99420464,3.836877,,,,,,,,,,,,,, -27800,1.0217036,3.811782,,,,,,,,,,,,,, -27900,0.9466677,3.8295176,,,,,,,,,,,,,, -28000,0.814898,5.2914248,,,,,,,,,,,,,, -28100,0.99483055,3.7642295,,,,,,,,,,,,,, -28200,0.97817725,3.795567,,,,,,,,,,,,,, -28229,,,0.6434375047683716,1.7335150241851809,0.5952199697494507,1.9401875734329224,50000.0,0.4753000140190124,2.569722890853882,10000.0,13058.312211036682,14154.731746673584,13058.312211036682,1093.944759130478,0.9052963256835938,0.0 -28300,0.94184315,3.7853284,,,,,,,,,,,,,, -28400,0.962939,4.130245,,,,,,,,,,,,,, -28500,0.82019806,5.1827273,,,,,,,,,,,,,, -28600,1.0667571,3.7311513,,,,,,,,,,,,,, -28700,0.97422737,3.8274877,,,,,,,,,,,,,, -28800,1.0152117,3.9377482,,,,,,,,,,,,,, -28900,0.96858144,3.9606323,,,,,,,,,,,,,, -29000,0.7571354,4.6563897,,,,,,,,,,,,,, -29100,0.8038996,5.689344,,,,,,,,,,,,,, -29139,,,0.6492577791213989,1.7050681114196775,0.601639986038208,1.9144740104675293,50000.0,0.4758000373840332,2.5420379638671875,10000.0,13478.41717338562,14608.866518974304,13478.41717338562,1127.892193555832,0.9366669654846193,0.0 -29200,0.99830794,3.7236261,,,,,,,,,,,,,, -29300,0.6838887,5.6948814,,,,,,,,,,,,,, -29400,0.8228846,4.5481606,,,,,,,,,,,,,, -29500,0.8983951,4.337779,,,,,,,,,,,,,, -29600,0.8015922,5.013868,,,,,,,,,,,,,, -29700,0.83566785,5.6438503,,,,,,,,,,,,,, -29800,0.9776448,3.7364178,,,,,,,,,,,,,, -29900,1.1933945,3.8384721,,,,,,,,,,,,,, -30000,0.84032553,5.621847,,,,,,,,,,,,,, -30052,,,0.6590429544448853,1.6291348934173584,0.6089199781417847,1.8620059490203853,50000.0,0.4879000186920166,2.49440336227417,10000.0,13898.656394004822,15063.104083538055,13898.656394004822,1161.8061337471008,0.969707489013672,0.0 -30100,1.015758,3.8271997,,,,,,,,,,,,,, -30200,0.9407461,4.274619,,,,,,,,,,,,,, -30300,1.095638,3.8430161,,,,,,,,,,,,,, -30400,0.9830837,3.8393056,,,,,,,,,,,,,, -30500,1.1437868,3.791187,,,,,,,,,,,,,, -30600,0.82076746,5.6334996,,,,,,,,,,,,,, -30700,1.0766568,3.6940625,,,,,,,,,,,,,, -30800,1.0706431,3.7713263,,,,,,,,,,,,,, -30900,1.0017338,3.7382164,,,,,,,,,,,,,, -30963,,,0.6642773151397705,1.6285508871078491,0.6131199598312378,1.841789960861206,50000.0,0.4963000118732452,2.465266227722168,10000.0,14318.683321475984,15516.47614622116,14318.683321475984,1195.0702483654022,1.0004019737243652,0.0 -31000,1.0162926,3.8143806,,,,,,,,,,,,,, -31100,0.9977527,3.7574363,,,,,,,,,,,,,, -31200,0.86414325,4.036622,,,,,,,,,,,,,, -31300,0.9337176,4.0255375,,,,,,,,,,,,,, -31400,0.92655295,4.4521456,,,,,,,,,,,,,, -31500,0.9871528,3.6656446,,,,,,,,,,,,,, -31600,1.0531467,3.7441535,,,,,,,,,,,,,, -31700,0.9997963,3.721343,,,,,,,,,,,,,, -31800,1.0321394,3.8289642,,,,,,,,,,,,,, -31875,,,0.6638085842132568,1.6437461376190186,0.6118999719619751,1.870315432548523,50000.0,0.4912000298500061,2.4833028316497803,10000.0,14739.014737844467,15970.657630205154,14739.014737844467,1228.8334305286407,1.0361483097076416,0.0 -31900,1.0295974,3.681894,,,,,,,,,,,,,, -32000,0.83510584,4.540277,,,,,,,,,,,,,, -32100,1.2044717,3.7288651,,,,,,,,,,,,,, -32200,1.0276271,3.6930518,,,,,,,,,,,,,, -32300,0.93610996,3.7173524,,,,,,,,,,,,,, -32400,0.90130067,4.14003,,,,,,,,,,,,,, -32500,1.077094,3.740079,,,,,,,,,,,,,, -32600,1.0647401,4.0581193,,,,,,,,,,,,,, -32700,1.0777428,4.1575174,,,,,,,,,,,,,, -32789,,,0.6706835627555847,1.6040401458740234,0.614139974117279,1.843677759170532,50000.0,0.4969000220298767,2.4702181816101074,10000.0,15159.29290342331,16425.24866938591,15159.29290342331,1263.0607657432556,1.0695884227752686,0.0 -32800,1.0140908,3.8576107,,,,,,,,,,,,,, -32900,1.0660218,3.716414,,,,,,,,,,,,,, -33000,0.98572123,3.75045,,,,,,,,,,,,,, -33100,0.9378372,3.7998664,,,,,,,,,,,,,, -33200,0.84703296,4.437536,,,,,,,,,,,,,, -33300,0.9665396,3.726871,,,,,,,,,,,,,, -33400,0.8628668,4.418552,,,,,,,,,,,,,, -33500,0.9935027,3.812138,,,,,,,,,,,,,, -33600,0.8923848,4.1414347,,,,,,,,,,,,,, -33700,0.9762397,3.8137527,,,,,,,,,,,,,, -33701,,,0.6911327838897705,1.4933984279632568,0.6180599927902222,1.8183931112289429,50000.0,0.4958000183105469,2.447704315185547,10000.0,15579.49324440956,16880.903116226196,15579.49324440956,1298.4260349273682,1.1072144508361816,0.0 -33800,1.0758538,3.7319653,,,,,,,,,,,,,, -33900,0.9235855,5.036478,,,,,,,,,,,,,, -34000,1.0246674,3.6873033,,,,,,,,,,,,,, -34100,0.9593179,5.291744,,,,,,,,,,,,,, -34200,0.76278853,5.5566163,,,,,,,,,,,,,, -34300,0.9958557,3.6612804,,,,,,,,,,,,,, -34400,0.86098266,4.1883283,,,,,,,,,,,,,, -34500,0.88777554,4.1653695,,,,,,,,,,,,,, -34600,0.9594669,3.8510056,,,,,,,,,,,,,, -34614,,,0.6721093654632568,1.5771911144256592,0.619219958782196,1.7937535047531128,50000.0,0.5020000338554382,2.422393560409546,10000.0,15999.567883491516,17334.343224287033,15999.567883491516,1331.7058236598969,1.141645908355713,0.0 -34700,1.0761002,3.7504194,,,,,,,,,,,,,, -34800,0.9341453,3.5889907,,,,,,,,,,,,,, -34900,0.8938735,5.38656,,,,,,,,,,,,,, -35000,1.0596812,3.7073665,,,,,,,,,,,,,, -35100,0.99291456,3.6550453,,,,,,,,,,,,,, -35200,0.9308388,5.5343995,,,,,,,,,,,,,, -35300,0.75484526,5.3257174,,,,,,,,,,,,,, -35400,0.88085335,4.097284,,,,,,,,,,,,,, -35500,0.92687434,3.6551785,,,,,,,,,,,,,, -35525,,,0.6826562285423279,1.5192320346832275,0.6253799796104431,1.7672829627990725,50000.0,0.506600022315979,2.37952208518982,10000.0,16419.87308859825,17788.572286605835,16419.87308859825,1365.5460460186005,1.1748707294464111,0.0 -35600,1.0123595,3.754843,,,,,,,,,,,,,, -35700,1.0096056,3.6472507,,,,,,,,,,,,,, -35800,0.8430053,4.688071,,,,,,,,,,,,,, -35900,0.897045,3.9575605,,,,,,,,,,,,,, -36000,0.9465888,4.3634777,,,,,,,,,,,,,, -36100,1.0504867,3.6839235,,,,,,,,,,,,,, -36200,0.95731837,3.700465,,,,,,,,,,,,,, -36300,0.9030247,4.9133234,,,,,,,,,,,,,, -36400,1.0651187,3.6768627,,,,,,,,,,,,,, -36438,,,0.6946093440055847,1.4783991575241089,0.6278199553489685,1.7693872451782229,50000.0,0.5051000118255615,2.390291929244995,10000.0,16839.959003686905,18242.025453329086,16839.959003686905,1398.829402923584,1.20875883102417,0.0 -36500,1.0388892,3.7558417,,,,,,,,,,,,,, -36600,0.98996246,3.6922064,,,,,,,,,,,,,, -36700,1.109313,3.627881,,,,,,,,,,,,,, -36800,1.1273612,3.7496774,,,,,,,,,,,,,, -36900,1.0080537,3.748917,,,,,,,,,,,,,, -37000,0.878548,5.070827,,,,,,,,,,,,,, -37100,1.0533988,3.7547138,,,,,,,,,,,,,, -37200,0.90104306,4.963248,,,,,,,,,,,,,, -37300,1.0133222,3.8609385,,,,,,,,,,,,,, -37351,,,0.6821093559265137,1.5482853651046753,0.6320799589157104,1.763115167617798,50000.0,0.5121999979019165,2.3797619342803955,10000.0,17260.044524669647,18695.51702475548,17260.044524669647,1432.1537964344025,1.2387840747833252,0.0 -37400,0.94756687,3.5933964,,,,,,,,,,,,,, -37500,0.9909565,3.6750689,,,,,,,,,,,,,, -37600,0.9341703,4.7863016,,,,,,,,,,,,,, -37700,1.1346735,3.6058753,,,,,,,,,,,,,, -37800,1.0069202,3.6857572,,,,,,,,,,,,,, -37900,0.8315595,5.299811,,,,,,,,,,,,,, -38000,0.92302686,4.016286,,,,,,,,,,,,,, -38100,1.0525451,3.6508918,,,,,,,,,,,,,, -38200,1.0846261,3.7501714,,,,,,,,,,,,,, -38264,,,0.6850000023841858,1.5315663814544678,0.6313199996948242,1.7667924165725708,50000.0,0.5141000151634216,2.38298773765564,10000.0,17680.28241252899,19148.84790277481,17680.28241252899,1465.1609783172607,1.273481845855713,0.0 -38300,1.0115418,5.51764,,,,,,,,,,,,,, -38400,1.0805676,3.6834881,,,,,,,,,,,,,, -38500,1.1932398,3.7274854,,,,,,,,,,,,,, -38600,0.79113245,5.139277,,,,,,,,,,,,,, -38700,0.9760189,3.9539604,,,,,,,,,,,,,, -38800,0.8657337,4.7477574,,,,,,,,,,,,,, -38900,0.851344,4.925381,,,,,,,,,,,,,, -39000,1.2731341,3.6379547,,,,,,,,,,,,,, -39100,0.9459153,4.3488383,,,,,,,,,,,,,, -39176,,,0.6991210579872131,1.461413025856018,0.6354999542236328,1.740901231765747,50000.0,0.5213000178337097,2.35638165473938,10000.0,18100.41104626656,19602.334558963776,18100.41104626656,1498.4337601661682,1.3083534240722656,0.0 -39200,0.8900294,5.4899087,,,,,,,,,,,,,, -39300,1.0357648,3.803254,,,,,,,,,,,,,, -39400,0.9436845,5.435916,,,,,,,,,,,,,, -39500,0.8819464,5.5546675,,,,,,,,,,,,,, -39600,1.0008479,3.7093635,,,,,,,,,,,,,, -39700,1.0692579,3.6656685,,,,,,,,,,,,,, -39800,0.8959252,4.836136,,,,,,,,,,,,,, -39900,0.92403364,4.2599473,,,,,,,,,,,,,, -40000,1.0965062,3.8431325,,,,,,,,,,,,,, -40088,,,0.6839648485183716,1.4925146102905271,0.6367599964141846,1.704120635986328,50000.0,0.5170000195503235,2.318955183029175,10000.0,18520.329745292664,20055.471518993378,18520.329745292664,1531.565257549286,1.344674587249756,0.0 -40100,1.0197946,3.7455068,,,,,,,,,,,,,, -40200,0.85503817,5.4571147,,,,,,,,,,,,,, -40300,1.0177406,3.64888,,,,,,,,,,,,,, -40400,1.0387981,3.6002183,,,,,,,,,,,,,, -40500,1.0293076,3.94336,,,,,,,,,,,,,, -40600,1.0851125,3.6704094,,,,,,,,,,,,,, -40700,0.936051,4.3394938,,,,,,,,,,,,,, -40800,1.1305882,3.5601068,,,,,,,,,,,,,, -40900,1.0890516,3.7843564,,,,,,,,,,,,,, -41000,1.0390897,3.675615,,,,,,,,,,,,,, -41001,,,0.6875976324081421,1.500531792640686,0.6361799836158752,1.7311961650848389,50000.0,0.5174000263214111,2.3340766429901123,10000.0,18940.719674110413,20509.467386484142,18940.719674110413,1565.088744878769,1.376600980758667,0.0 -41100,1.0043372,3.6969602,,,,,,,,,,,,,, -41200,0.94753164,3.840611,,,,,,,,,,,,,, -41300,0.8298231,5.301342,,,,,,,,,,,,,, -41400,1.0010569,3.555497,,,,,,,,,,,,,, -41500,0.9042818,4.972782,,,,,,,,,,,,,, -41600,1.0331782,3.73209,,,,,,,,,,,,,, -41700,0.9311222,4.6684213,,,,,,,,,,,,,, -41800,1.1184709,3.7079785,,,,,,,,,,,,,, -41900,0.97637874,3.9131136,,,,,,,,,,,,,, -41912,,,0.7040234208106995,1.4264371395111084,0.6458399891853333,1.6851632595062256,50000.0,0.5248000025749207,2.298939943313598,10000.0,19360.97429251671,20962.10922479629,19360.97429251671,1597.391491651535,1.410295486450195,0.0 -42000,1.0116541,3.5914993,,,,,,,,,,,,,, -42100,1.0913081,3.732717,,,,,,,,,,,,,, -42200,1.1088535,3.6779768,,,,,,,,,,,,,, -42300,0.9255049,5.063106,,,,,,,,,,,,,, -42400,0.9333391,4.186743,,,,,,,,,,,,,, -42500,0.992465,3.5924377,,,,,,,,,,,,,, -42600,1.068792,3.583392,,,,,,,,,,,,,, -42700,0.953891,4.681044,,,,,,,,,,,,,, -42800,0.90389454,4.8297973,,,,,,,,,,,,,, -42821,,,0.6966992020606995,1.4827358722686768,0.6440399885177612,1.710224151611328,50000.0,0.5216000080108643,2.343894720077514,10000.0,19781.33933091164,21416.57683706284,19781.33933091164,1631.407816171646,1.445805549621582,0.0 -42900,1.0605546,3.692893,,,,,,,,,,,,,, -43000,1.0262673,3.566764,,,,,,,,,,,,,, -43100,0.983,3.568571,,,,,,,,,,,,,, -43200,0.93325776,3.99238,,,,,,,,,,,,,, -43300,0.996946,3.6540732,,,,,,,,,,,,,, -43400,1.0405816,3.582852,,,,,,,,,,,,,, -43500,1.0899541,3.65833,,,,,,,,,,,,,, -43600,0.98655736,3.6448653,,,,,,,,,,,,,, -43700,0.8667639,4.963787,,,,,,,,,,,,,, -43733,,,0.6943163871765137,1.5012829303741455,0.6460599899291992,1.7175486087799072,50000.0,0.51910001039505,2.349139451980591,10000.0,20201.28215122223,21870.283143281937,20201.28215122223,1665.082776069641,1.4831054210662842,0.0 -43800,1.0099415,3.6615534,,,,,,,,,,,,,, -43900,1.0399079,3.6536002,,,,,,,,,,,,,, -44000,1.0131247,3.6480803,,,,,,,,,,,,,, -44100,0.9442152,4.8612003,,,,,,,,,,,,,, -44200,1.0436672,3.6021128,,,,,,,,,,,,,, -44300,1.1916286,3.634911,,,,,,,,,,,,,, -44400,0.9490848,4.584207,,,,,,,,,,,,,, -44500,1.1396201,3.5731595,,,,,,,,,,,,,, -44600,1.1477675,3.6694694,,,,,,,,,,,,,, -44646,,,0.70361328125,1.4847627878189087,0.643559992313385,1.7350969314575195,50000.0,0.5242000222206116,2.3520002365112305,10000.0,20621.605316638947,22322.22277355194,20621.605316638947,1696.6110389232635,1.5203256607055664,0.0 -44700,1.0323985,4.09127,,,,,,,,,,,,,, -44800,0.9900128,4.8144526,,,,,,,,,,,,,, -44900,1.1752156,3.579792,,,,,,,,,,,,,, -45000,0.86230147,5.0312495,,,,,,,,,,,,,, -45100,0.9489953,4.865878,,,,,,,,,,,,,, -45200,0.9446361,5.4672413,,,,,,,,,,,,,, -45300,1.0464358,3.700841,,,,,,,,,,,,,, -45400,0.9808841,3.62951,,,,,,,,,,,,,, -45500,1.0933805,3.5686433,,,,,,,,,,,,,, -45557,,,0.7167187333106995,1.356802463531494,0.6495999693870544,1.6522201299667358,50000.0,0.5333000421524048,2.260016918182373,10000.0,21041.62238311768,22776.63697385788,21041.62238311768,1730.922518491745,1.5553491115570068,0.0 -45600,1.0547656,3.6016877,,,,,,,,,,,,,, -45700,0.8921203,5.32537,,,,,,,,,,,,,, -45800,0.9090803,4.9128976,,,,,,,,,,,,,, -45900,1.0196489,3.6472907,,,,,,,,,,,,,, -46000,0.9401467,5.0681963,,,,,,,,,,,,,, -46100,1.0825887,3.5986483,,,,,,,,,,,,,, -46200,1.098298,3.5583208,,,,,,,,,,,,,, -46300,1.0119085,4.1493344,,,,,,,,,,,,,, -46400,1.0691149,3.5460496,,,,,,,,,,,,,, -46468,,,0.6974999904632568,1.5168081521987915,0.6491599678993225,1.7391897439956665,50000.0,0.5254999995231628,2.357267379760742,10000.0,21461.639271736145,23231.38586783409,21461.639271736145,1765.5656650066376,1.5932340621948242,0.0 -46500,1.0022128,3.688054,,,,,,,,,,,,,, -46600,1.0292602,3.5029342,,,,,,,,,,,,,, -46700,1.0732577,3.6279635,,,,,,,,,,,,,, -46800,0.91979414,5.2416697,,,,,,,,,,,,,, -46900,1.0239141,4.6803904,,,,,,,,,,,,,, -47000,0.9787612,4.55979,,,,,,,,,,,,,, -47100,1.0900257,3.6147683,,,,,,,,,,,,,, -47200,1.0320812,3.5405567,,,,,,,,,,,,,, -47300,0.8967352,4.8941793,,,,,,,,,,,,,, -47381,,,0.7109375,1.4334065914154053,0.6492399573326111,1.697941541671753,50000.0,0.5318000316619873,2.307551622390747,10000.0,21881.792940855023,23684.5236389637,21881.792940855023,1798.4531662464142,1.6383380889892578,0.0 -47400,0.8914073,5.424882,,,,,,,,,,,,,, -47500,1.0225726,3.7156236,,,,,,,,,,,,,, -47600,1.0446572,3.6403103,,,,,,,,,,,,,, -47700,0.98037076,4.417347,,,,,,,,,,,,,, -47800,0.99749583,4.2212315,,,,,,,,,,,,,, -47900,1.139144,3.5891893,,,,,,,,,,,,,, -48000,1.0054939,4.2513924,,,,,,,,,,,,,, -48100,1.1407785,3.6347115,,,,,,,,,,,,,, -48200,1.0222267,3.7346187,,,,,,,,,,,,,, -48293,,,0.73046875,1.311131715774536,0.6547600030899048,1.646582007408142,50000.0,0.5335000157356262,2.266916036605835,10000.0,22301.990122556686,24138.33581233025,22301.990122556686,1831.97790145874,1.6770923137664795,0.0 -48300,1.0090399,3.6921773,,,,,,,,,,,,,, -48400,1.1182871,3.6112347,,,,,,,,,,,,,, -48500,1.0720456,3.472618,,,,,,,,,,,,,, -48600,0.9840437,3.582828,,,,,,,,,,,,,, -48700,1.0754938,3.5606873,,,,,,,,,,,,,, -48800,1.0619363,3.61534,,,,,,,,,,,,,, -48900,0.8897755,4.169423,,,,,,,,,,,,,, -49000,0.9948544,4.601507,,,,,,,,,,,,,, -49100,1.0288067,3.4980116,,,,,,,,,,,,,, -49200,1.0428673,3.5967977,,,,,,,,,,,,,, -49206,,,0.710156261920929,1.4143155813217163,0.6555399894714355,1.65362811088562,50000.0,0.5358999967575073,2.262042999267578,10000.0,22722.04015159607,24592.202979803085,22722.04015159607,1865.7097523212435,1.711674690246582,0.0 -49300,0.9167611,4.5846796,,,,,,,,,,,,,, -49400,1.1044239,3.5877311,,,,,,,,,,,,,, -49500,1.0055296,3.4097579,,,,,,,,,,,,,, -49600,0.9773659,5.272257,,,,,,,,,,,,,, -49700,1.0404078,3.534213,,,,,,,,,,,,,, -49800,0.9837915,3.4976962,,,,,,,,,,,,,, -49900,0.9630673,3.683572,,,,,,,,,,,,,, -50000,1.0451609,3.5931613,,,,,,,,,,,,,, -50100,0.908849,5.2689085,,,,,,,,,,,,,, -50115,,,0.7168163657188416,1.348970651626587,0.6584199666976929,1.6088238954544067,50000.0,0.5360000133514404,2.217803955078125,10000.0,23142.09687113762,25046.962195158005,23142.09687113762,1900.3253211975093,1.7474756240844729,0.0 -50200,1.062489,3.540635,,,,,,,,,,,,,, -50300,1.11654,3.802226,,,,,,,,,,,,,, -50400,1.0415219,5.3532434,,,,,,,,,,,,,, -50500,1.0815477,3.6541047,,,,,,,,,,,,,, -50600,1.1638042,3.5660563,,,,,,,,,,,,,, -50700,1.0031009,4.171278,,,,,,,,,,,,,, -50800,1.1106963,3.4855475,,,,,,,,,,,,,, -50900,1.2105525,3.5198796,,,,,,,,,,,,,, -51000,1.0746074,3.5584087,,,,,,,,,,,,,, -51027,,,0.7308984398841858,1.3000104427337646,0.6588199734687805,1.60606586933136,50000.0,0.5329000353813171,2.230520725250244,10000.0,23562.426542520523,25502.540413618088,23562.426542520523,1935.4861042499545,1.7837295532226562,0.0 -51100,1.0857327,3.5331252,,,,,,,,,,,,,, -51200,0.97100985,4.8528113,,,,,,,,,,,,,, -51300,1.1001557,3.5071535,,,,,,,,,,,,,, -51400,1.095813,3.684476,,,,,,,,,,,,,, -51500,1.0961845,3.4938,,,,,,,,,,,,,, -51600,1.0085684,5.3492436,,,,,,,,,,,,,, -51700,0.98563683,4.1060944,,,,,,,,,,,,,, -51800,0.9793874,5.076541,,,,,,,,,,,,,, -51900,1.1387475,3.4467847,,,,,,,,,,,,,, -51941,,,0.71009761095047,1.3835265636444092,0.6599999666213989,1.6134287118911743,50000.0,0.5430000424385071,2.21612286567688,10000.0,23982.65742635727,25957.92196583748,23982.65742635727,1970.5525019168847,1.8164739608764648,0.0 -52000,0.99632686,3.9196887,,,,,,,,,,,,,, -52100,0.9715728,4.907781,,,,,,,,,,,,,, -52200,1.1523438,3.4681532,,,,,,,,,,,,,, -52300,1.2597407,3.4911213,,,,,,,,,,,,,, -52400,0.92547613,4.272914,,,,,,,,,,,,,, -52500,1.067289,3.8942645,,,,,,,,,,,,,, -52600,1.0718081,5.0360975,,,,,,,,,,,,,, -52700,0.91322047,5.3144073,,,,,,,,,,,,,, -52800,1.1049489,3.7761788,,,,,,,,,,,,,, -52851,,,0.7220507860183716,1.3369845151901243,0.6620000004768372,1.5932549238204956,50000.0,0.5406000018119812,2.2067039012908936,10000.0,24402.8046002388,26411.741233587265,24402.8046002388,2004.1362063884733,1.8542118072509768,0.0 -52900,1.0379539,4.3321657,,,,,,,,,,,,,, -53000,1.0761764,3.6774693,,,,,,,,,,,,,, -53100,1.159798,3.6610782,,,,,,,,,,,,,, -53200,0.98130405,3.6301734,,,,,,,,,,,,,, -53300,0.9935719,3.9119804,,,,,,,,,,,,,, -53400,1.0300885,3.7744741,,,,,,,,,,,,,, -53500,1.0202285,5.36139,,,,,,,,,,,,,, -53600,0.9795552,4.501462,,,,,,,,,,,,,, -53700,1.1146425,3.547374,,,,,,,,,,,,,, -53763,,,0.7294335961341858,1.317986011505127,0.6652799844741821,1.6028327941894531,50000.0,0.5420000553131104,2.219597339630127,10000.0,24822.854472875595,26864.94805812836,24822.854472875595,2037.2056503295896,1.89034390449524,0.0 -53800,0.98365515,5.1558,,,,,,,,,,,,,, -53900,0.8931533,4.8482766,,,,,,,,,,,,,, -54000,0.9733683,4.311698,,,,,,,,,,,,,, -54100,1.0470159,3.499494,,,,,,,,,,,,,, -54200,1.1289966,3.589755,,,,,,,,,,,,,, -54300,1.0142336,3.8403807,,,,,,,,,,,,,, -54400,1.1494497,3.6394696,,,,,,,,,,,,,, -54500,1.1003228,3.425428,,,,,,,,,,,,,, -54600,1.0798798,3.5321665,,,,,,,,,,,,,, -54677,,,0.71937495470047,1.3365647792816162,0.6654199957847595,1.5735455751419067,50000.0,0.5410000085830688,2.1876070499420166,10000.0,25243.066395998,27318.78368639946,25243.066395998,2070.7375979423523,1.931077480316162,0.0 -54700,1.0636854,3.7014236,,,,,,,,,,,,,, -54800,1.0094075,4.075866,,,,,,,,,,,,,, -54900,1.1361334,3.6041152,,,,,,,,,,,,,, -55000,1.0926508,3.443311,,,,,,,,,,,,,, -55100,1.0224869,4.263559,,,,,,,,,,,,,, -55200,1.1061014,3.5704181,,,,,,,,,,,,,, -55300,1.1295335,3.5588021,,,,,,,,,,,,,, -55400,1.2519699,3.5203798,,,,,,,,,,,,,, -55500,1.0254971,5.3272843,,,,,,,,,,,,,, -55591,,,0.7230077981948853,1.379983901977539,0.6684600114822388,1.6201144456863403,50000.0,0.5437000393867493,2.244570732116699,10000.0,25663.03023004532,27772.1686103344,25663.03023004532,2104.065089225769,1.9737062454223635,0.0 -55600,1.0766263,3.644771,,,,,,,,,,,,,, -55700,1.0120999,4.5451603,,,,,,,,,,,,,, -55800,0.9533396,4.5840006,,,,,,,,,,,,,, -55900,1.0602889,3.5646813,,,,,,,,,,,,,, -56000,1.0343091,5.1217527,,,,,,,,,,,,,, -56100,1.0634713,3.7104893,,,,,,,,,,,,,, -56200,1.07725,3.6305792,,,,,,,,,,,,,, -56300,1.0060178,4.392029,,,,,,,,,,,,,, -56400,1.1199938,3.5042572,,,,,,,,,,,,,, -56500,1.1221868,3.9638972,,,,,,,,,,,,,, -56506,,,0.7322070002555847,1.3016204833984375,0.66975998878479,1.5884922742843628,50000.0,0.5479000210762024,2.196012496948242,10000.0,26083.24350643158,28227.99153614044,26083.24350643158,2139.5847957134247,2.0120248794555664,0.0 -56600,1.0748863,3.4999082,,,,,,,,,,,,,, -56700,1.1335825,3.494278,,,,,,,,,,,,,, -56800,1.1414815,3.4982007,,,,,,,,,,,,,, -56900,1.1181859,3.495444,,,,,,,,,,,,,, -57000,1.1541163,3.4520466,,,,,,,,,,,,,, -57100,1.1590161,3.5669444,,,,,,,,,,,,,, -57200,1.0678264,3.8138962,,,,,,,,,,,,,, -57300,0.9972155,3.6785064,,,,,,,,,,,,,, -57400,1.26878,3.767336,,,,,,,,,,,,,, -57418,,,0.7277148365974426,1.3060152530670166,0.6685799956321716,1.5581419467926023,50000.0,0.5449000000953674,2.177110433578491,10000.0,26503.21027612686,28683.141376018524,26503.21027612686,2174.6806008815765,2.048649311065674,0.0 -57500,1.0983055,3.4324722,,,,,,,,,,,,,, -57600,0.9974094,4.561544,,,,,,,,,,,,,, -57700,1.0599166,3.740426,,,,,,,,,,,,,, -57800,1.2016896,3.6535442,,,,,,,,,,,,,, -57900,1.0753313,3.428563,,,,,,,,,,,,,, -58000,1.0986065,5.2831807,,,,,,,,,,,,,, -58100,1.1181434,5.2838607,,,,,,,,,,,,,, -58200,1.1058913,3.4850047,,,,,,,,,,,,,, -58300,1.0989715,3.559523,,,,,,,,,,,,,, -58331,,,0.7244530916213989,1.3326257467269895,0.6692999601364136,1.571911096572876,50000.0,0.5443000197410583,2.1926655769348145,10000.0,26923.5639295578,29138.24506425857,26923.5639295578,2209.344695329666,2.083533763885498,0.0 -58400,0.9223677,4.779751,,,,,,,,,,,,,, -58500,1.1414965,3.5881715,,,,,,,,,,,,,, -58600,1.001815,5.009034,,,,,,,,,,,,,, -58700,1.2771351,3.4604902,,,,,,,,,,,,,, -58800,1.0296177,4.9279437,,,,,,,,,,,,,, -58900,1.0401037,5.115907,,,,,,,,,,,,,, -59000,1.1574671,3.5285852,,,,,,,,,,,,,, -59100,1.1806866,3.4190104,,,,,,,,,,,,,, -59200,1.2004064,3.578754,,,,,,,,,,,,,, -59242,,,0.7317968606948853,1.3039686679840088,0.671999990940094,1.571418523788452,50000.0,0.5469000339508057,2.191420316696167,10000.0,27343.48020362854,29591.821184158325,27343.48020362854,2242.9132976531982,2.122434616088867,0.0 -59300,1.0017263,4.722234,,,,,,,,,,,,,, -59400,1.2267008,3.5161133,,,,,,,,,,,,,, -59500,0.9610067,3.9470282,,,,,,,,,,,,,, -59600,1.2895058,3.5531585,,,,,,,,,,,,,, -59700,1.0861076,3.6449032,,,,,,,,,,,,,, -59800,1.1261663,5.145629,,,,,,,,,,,,,, -59900,1.223604,3.5432231,,,,,,,,,,,,,, -60000,1.0126486,4.646271,,,,,,,,,,,,,, -60100,1.138889,3.5368946,,,,,,,,,,,,,, -60156,,,0.7542577981948853,1.2266101837158203,0.671459972858429,1.5660921335220337,50000.0,0.5527999997138977,2.186643362045288,10000.0,27763.48047947884,30046.77208018303,27763.48047947884,2277.7794167995453,2.156461715698242,0.0 -60200,1.05113,3.6792374,,,,,,,,,,,,,, -60300,1.1535614,3.4392593,,,,,,,,,,,,,, -60400,1.019253,3.595828,,,,,,,,,,,,,, -60500,1.06439,4.059339,,,,,,,,,,,,,, -60600,1.1464736,3.507936,,,,,,,,,,,,,, -60700,1.0518191,4.3532887,,,,,,,,,,,,,, -60800,1.1094321,3.5015252,,,,,,,,,,,,,, -60900,1.0195208,5.175598,,,,,,,,,,,,,, -61000,1.0334793,3.9269805,,,,,,,,,,,,,, -61068,,,0.7308984398841858,1.2666162252426147,0.6772399544715881,1.4976990222930908,50000.0,0.5565000176429749,2.116438865661621,10000.0,28183.66246366501,30500.72461438179,28183.66246366501,2311.4618620872498,2.193611860275269,0.0 -61100,1.1020598,3.422868,,,,,,,,,,,,,, -61200,1.1744219,3.4624262,,,,,,,,,,,,,, -61300,1.1200243,4.8999143,,,,,,,,,,,,,, -61400,1.2328681,3.5369515,,,,,,,,,,,,,, -61500,1.0383376,3.5929565,,,,,,,,,,,,,, -61600,1.0314422,4.657843,,,,,,,,,,,,,, -61700,1.0416988,3.546937,,,,,,,,,,,,,, -61800,1.1734245,3.5344613,,,,,,,,,,,,,, -61900,1.0729918,3.7437782,,,,,,,,,,,,,, -61980,,,0.73646479845047,1.2581452131271362,0.6758399605751038,1.5217695236206057,50000.0,0.5525000095367432,2.141698837280273,10000.0,28603.927173376083,30956.482098340988,28603.927173376083,2346.8619673252106,2.235115766525269,0.0 -62000,0.98076797,4.7314925,,,,,,,,,,,,,, -62100,1.187319,3.5077555,,,,,,,,,,,,,, -62200,1.0190117,4.8583374,,,,,,,,,,,,,, -62300,1.030557,3.7592223,,,,,,,,,,,,,, -62400,1.0888507,3.6446843,,,,,,,,,,,,,, -62500,1.0522795,4.0813217,,,,,,,,,,,,,, -62600,0.99070835,4.665809,,,,,,,,,,,,,, -62700,1.1977342,3.6002321,,,,,,,,,,,,,, -62800,1.1287081,5.2175455,,,,,,,,,,,,,, -62892,,,0.7489648461341858,1.2218483686447144,0.6783999800682068,1.5327138900756836,50000.0,0.5552999973297119,2.162600517272949,10000.0,29024.041388511658,31409.77238416672,29024.041388511658,2379.947853088379,2.274474620819092,0.0 -62900,1.0964706,3.4504368,,,,,,,,,,,,,, -63000,1.139547,3.444248,,,,,,,,,,,,,, -63100,1.1283145,3.373954,,,,,,,,,,,,,, -63200,1.0814172,4.7921495,,,,,,,,,,,,,, -63300,1.1610047,5.1374516,,,,,,,,,,,,,, -63400,1.2195085,5.243383,,,,,,,,,,,,,, -63500,1.1651887,3.4171603,,,,,,,,,,,,,, -63600,1.101699,4.7388277,,,,,,,,,,,,,, -63700,1.2691705,3.5314322,,,,,,,,,,,,,, -63800,1.0883243,3.457578,,,,,,,,,,,,,, -63803,,,0.7317187190055847,1.283490777015686,0.677899956703186,1.513463854789734,50000.0,0.5556000471115112,2.131977796554565,10000.0,29443.85234260559,31862.30394887924,29443.85234260559,2412.2440707683563,2.648000955581665,0.0 -63900,1.046932,4.0347743,,,,,,,,,,,,,, -64000,1.1156294,3.5147138,,,,,,,,,,,,,, -64100,1.0877986,3.473004,,,,,,,,,,,,,, -64200,1.0670024,4.3169503,,,,,,,,,,,,,, -64300,1.0894129,3.3560574,,,,,,,,,,,,,, -64400,1.1786548,3.5162559,,,,,,,,,,,,,, -64500,1.1741436,3.4533262,,,,,,,,,,,,,, -64600,1.1565858,3.5345137,,,,,,,,,,,,,, -64700,1.1094036,5.03642,,,,,,,,,,,,,, -64715,,,0.740527331829071,1.269747018814087,0.6794599890708923,1.536699652671814,50000.0,0.5520000457763672,2.1582353115081787,10000.0,29863.933977127075,32318.9393632412,29863.933977127075,2448.706707715988,2.6889126300811768,0.0 -64800,1.2649907,3.5198963,,,,,,,,,,,,,, -64900,1.1327009,3.8381827,,,,,,,,,,,,,, -65000,1.1017672,3.5543644,,,,,,,,,,,,,, -65100,1.2689488,3.4233456,,,,,,,,,,,,,, -65200,1.1348295,3.3924143,,,,,,,,,,,,,, -65300,1.1194096,3.4267936,,,,,,,,,,,,,, -65400,1.1857973,3.416155,,,,,,,,,,,,,, -65500,1.3161231,3.5922103,,,,,,,,,,,,,, -65600,1.0633745,4.2308683,,,,,,,,,,,,,, -65626,,,0.7494140267372131,1.1988508701324463,0.6782400012016296,1.5024621486663818,50000.0,0.5551000237464905,2.1161949634552,10000.0,30284.21901607513,32773.13093471527,30284.21901607513,2482.5249683856964,2.7266457080841064,0.0 -65700,1.218395,3.7593944,,,,,,,,,,,,,, -65800,1.2454245,3.3596919,,,,,,,,,,,,,, -65900,1.1592089,5.224226,,,,,,,,,,,,,, -66000,1.0571073,3.8438976,,,,,,,,,,,,,, -66100,1.184334,4.975856,,,,,,,,,,,,,, -66200,1.0731968,4.6202474,,,,,,,,,,,,,, -66300,1.0499047,4.0707254,,,,,,,,,,,,,, -66400,1.0387594,4.7353516,,,,,,,,,,,,,, -66500,1.1227646,5.1661577,,,,,,,,,,,,,, -66538,,,0.7292382717132568,1.3125627040863037,0.6787999868392944,1.5494202375411987,50000.0,0.5455000400543213,2.175212383270264,10000.0,30704.13455057144,33227.839062690735,30704.13455057144,2517.2284500598907,2.763366460800171,0.0 -66600,1.1519086,3.4024096,,,,,,,,,,,,,, -66700,1.2131382,3.444592,,,,,,,,,,,,,, -66800,1.1287224,3.47417,,,,,,,,,,,,,, -66900,1.1736116,3.4249232,,,,,,,,,,,,,, -67000,1.2910345,3.4841523,,,,,,,,,,,,,, -67100,1.1327912,3.6938756,,,,,,,,,,,,,, -67200,1.095946,3.6408834,,,,,,,,,,,,,, -67300,1.1929132,3.5667782,,,,,,,,,,,,,, -67400,1.2078676,4.750591,,,,,,,,,,,,,, -67449,,,0.7469726204872131,1.2171990871429443,0.6873199939727783,1.4766743183135986,50000.0,0.558899998664856,2.096421718597412,10000.0,31124.17865896225,33680.919352293015,31124.17865896225,2550.1747620105743,2.8025145530700684,0.0 -67500,1.1424446,3.3975134,,,,,,,,,,,,,, -67600,1.0531992,4.3058505,,,,,,,,,,,,,, -67700,1.1902293,3.5649574,,,,,,,,,,,,,, -67800,1.2023171,3.49551,,,,,,,,,,,,,, -67900,1.1083218,4.2431817,,,,,,,,,,,,,, -68000,1.1695207,3.5846114,,,,,,,,,,,,,, -68100,1.2266887,3.4192815,,,,,,,,,,,,,, -68200,1.2414272,3.4498773,,,,,,,,,,,,,, -68300,1.2434218,4.441929,,,,,,,,,,,,,, -68359,,,0.7477148175239563,1.2257516384124756,0.6833999752998352,1.5142306089401243,50000.0,0.5598000288009644,2.143460750579834,10000.0,31544.106262922287,34135.353620529175,31544.106262922287,2584.592365026474,2.8403160572052,0.0 -68400,1.0993282,3.4245079,,,,,,,,,,,,,, -68500,1.0628256,4.59041,,,,,,,,,,,,,, -68600,1.2020011,5.124017,,,,,,,,,,,,,, -68700,1.1131996,3.899511,,,,,,,,,,,,,, -68800,1.2232356,3.4153373,,,,,,,,,,,,,, -68900,1.1193675,3.475464,,,,,,,,,,,,,, -69000,1.0172544,4.0961394,,,,,,,,,,,,,, -69100,1.1203308,4.7409935,,,,,,,,,,,,,, -69200,1.1114061,4.364735,,,,,,,,,,,,,, -69272,,,0.7469335794448853,1.238759160041809,0.6876199841499329,1.490525484085083,50000.0,0.5688000321388245,2.0847983360290527,10000.0,31964.290115594864,34592.112513780594,31964.290115594864,2621.076464414597,2.879974603652954,0.0 -69300,1.1329066,3.7232141,,,,,,,,,,,,,, -69400,1.1808839,4.707025,,,,,,,,,,,,,, -69500,1.1245707,3.624711,,,,,,,,,,,,,, -69600,1.0426803,3.9655223,,,,,,,,,,,,,, -69700,1.2540386,3.5155878,,,,,,,,,,,,,, -69800,1.107602,3.3272173,,,,,,,,,,,,,, -69900,1.1757765,3.345587,,,,,,,,,,,,,, -70000,1.1000034,4.039454,,,,,,,,,,,,,, -70100,1.1108066,3.4042032,,,,,,,,,,,,,, -70182,,,0.7432226538658142,1.2668992280960083,0.6850199699401855,1.5202137231826782,50000.0,0.5636000037193298,2.132345676422119,10000.0,32384.55722093582,35045.258445978165,32384.55722093582,2653.8648805618286,2.919417381286621,0.0 -70200,1.0401722,3.8903093,,,,,,,,,,,,,, -70300,1.1770111,3.4298353,,,,,,,,,,,,,, -70400,1.1314467,4.4608464,,,,,,,,,,,,,, -70500,1.1729524,5.077889,,,,,,,,,,,,,, -70600,1.1657996,3.3793495,,,,,,,,,,,,,, -70700,1.278247,3.4419847,,,,,,,,,,,,,, -70800,1.223805,3.4111884,,,,,,,,,,,,,, -70900,1.1282498,4.462816,,,,,,,,,,,,,, -71000,1.0553074,3.581389,,,,,,,,,,,,,, -71092,,,0.7544921636581421,1.1773465871810913,0.6906999945640564,1.4612739086151123,50000.0,0.5626000165939331,2.0810532569885254,10000.0,32804.878509521484,35501.86202931404,32804.878509521484,2690.0571892261505,2.959210157394409,0.0 -71100,1.1351312,5.076595,,,,,,,,,,,,,, -71200,1.1731162,3.4204686,,,,,,,,,,,,,, -71300,1.0483173,4.35411,,,,,,,,,,,,,, -71400,1.2190181,3.4421072,,,,,,,,,,,,,, -71500,1.1573997,3.4730442,,,,,,,,,,,,,, -71600,1.068702,4.2064767,,,,,,,,,,,,,, -71700,1.192746,3.3998165,,,,,,,,,,,,,, -71800,1.1238236,4.130922,,,,,,,,,,,,,, -71900,1.0777473,3.6299682,,,,,,,,,,,,,, -72000,1.1785061,3.9427733,,,,,,,,,,,,,, -72003,,,0.7497460842132568,1.197008490562439,0.6877399682998657,1.4694221019744873,50000.0,0.5637000203132629,2.097407341003418,10000.0,33225.23718857765,35955.55195403099,33225.23718857765,2723.2984120845795,2.999574899673462,0.0 -72100,1.184416,3.8139174,,,,,,,,,,,,,, -72200,1.2101573,3.4767504,,,,,,,,,,,,,, -72300,1.150569,3.4084435,,,,,,,,,,,,,, -72400,1.1318176,3.4595873,,,,,,,,,,,,,, -72500,1.1439009,4.946765,,,,,,,,,,,,,, -72600,1.2083075,3.3638248,,,,,,,,,,,,,, -72700,1.1177409,4.952543,,,,,,,,,,,,,, -72800,1.1228155,3.4829679,,,,,,,,,,,,,, -72900,1.1333269,4.684021,,,,,,,,,,,,,, -72913,,,0.7474804520606995,1.2026795148849487,0.6865999698638916,1.473946213722229,50000.0,0.5586000084877014,2.1017489433288574,10000.0,33645.199608802795,36409.107684612274,33645.199608802795,2756.8011240959167,3.038689374923706,0.0 -73000,1.1823255,3.9553053,,,,,,,,,,,,,, -73100,1.227556,3.3481822,,,,,,,,,,,,,, -73200,1.2048196,3.376064,,,,,,,,,,,,,, -73300,1.1361609,3.3063323,,,,,,,,,,,,,, -73400,1.1752088,3.3287334,,,,,,,,,,,,,, -73500,1.0639842,3.637258,,,,,,,,,,,,,, -73600,1.2063851,3.3602502,,,,,,,,,,,,,, -73700,1.1873804,3.3780553,,,,,,,,,,,,,, -73800,1.1562095,3.3411763,,,,,,,,,,,,,, -73824,,,0.7582421898841858,1.1599042415618896,0.6906999945640564,1.4475252628326416,50000.0,0.569100022315979,2.0638108253479004,10000.0,34065.1739525795,36860.7737197876,34065.1739525795,2788.398811101913,3.081950187683105,0.0 -73900,1.1446712,4.6942387,,,,,,,,,,,,,, -74000,1.1423197,3.9715369,,,,,,,,,,,,,, -74100,1.1969917,3.533245,,,,,,,,,,,,,, -74200,1.199578,5.130947,,,,,,,,,,,,,, -74300,1.1033725,3.365479,,,,,,,,,,,,,, -74400,1.251233,3.5163302,,,,,,,,,,,,,, -74500,1.1957673,3.4882617,,,,,,,,,,,,,, -74600,1.248758,3.4600909,,,,,,,,,,,,,, -74700,1.2374326,3.4533374,,,,,,,,,,,,,, -74735,,,0.7710937261581421,1.1149591207504272,0.6913999915122986,1.4553277492523191,50000.0,0.5658000111579895,2.082012176513672,10000.0,34485.24294900894,37315.77154183388,34485.24294900894,2823.236649990082,3.122302770614624,0.0 -74800,1.2969224,3.291121,,,,,,,,,,,,,, -74900,1.0539595,4.024419,,,,,,,,,,,,,, -75000,1.2087328,3.3584118,,,,,,,,,,,,,, -75100,1.1094669,3.4065838,,,,,,,,,,,,,, -75200,1.1447678,3.6374722,,,,,,,,,,,,,, -75300,1.129746,4.8239117,,,,,,,,,,,,,, -75400,1.1757356,3.6018062,,,,,,,,,,,,,, -75500,1.2232543,3.3466232,,,,,,,,,,,,,, -75600,1.2665651,3.3816223,,,,,,,,,,,,,, -75645,,,0.7518945336341858,1.226989984512329,0.6896199584007263,1.4911552667617798,50000.0,0.5645000338554382,2.1078085899353027,10000.0,34905.34354329109,37769.46210241318,34905.34354329109,2856.7374284267426,3.161123752593994,0.0 -75700,1.1819037,5.0665236,,,,,,,,,,,,,, -75800,1.2625349,3.4989932,,,,,,,,,,,,,, -75900,1.2018504,4.68438,,,,,,,,,,,,,, -76000,1.1679077,3.4046073,,,,,,,,,,,,,, -76100,1.1812338,3.4521391,,,,,,,,,,,,,, -76200,1.2366782,3.3133001,,,,,,,,,,,,,, -76300,1.1531498,3.4608572,,,,,,,,,,,,,, -76400,1.1488808,3.249195,,,,,,,,,,,,,, -76500,1.2460409,3.3394668,,,,,,,,,,,,,, -76556,,,0.7568554282188416,1.1630686521530151,0.694379985332489,1.4409019947052002,50000.0,0.5637000203132629,2.063631772994995,10000.0,35325.38757443428,38222.31966996193,35325.38757443428,2889.461337566376,3.200040578842163,0.0 -76600,1.2188076,3.4536638,,,,,,,,,,,,,, -76700,1.2978786,3.3618562,,,,,,,,,,,,,, -76800,1.2715309,4.99067,,,,,,,,,,,,,, -76900,1.2801429,3.3634634,,,,,,,,,,,,,, -77000,1.315834,5.1545486,,,,,,,,,,,,,, -77100,1.154297,4.668353,,,,,,,,,,,,,, -77200,1.2954103,4.990496,,,,,,,,,,,,,, -77300,1.0658691,4.5007463,,,,,,,,,,,,,, -77400,1.1821382,3.3687406,,,,,,,,,,,,,, -77469,,,0.7740234136581421,1.1011523008346558,0.6961199641227722,1.4330646991729736,50000.0,0.572100043296814,2.0542802810668945,10000.0,35745.66950130463,38674.79872870445,35745.66950130463,2921.5647172927856,3.242464065551758,0.0 -77500,1.2726082,4.6735277,,,,,,,,,,,,,, -77600,1.0497439,3.8871357,,,,,,,,,,,,,, -77700,1.1450971,3.381996,,,,,,,,,,,,,, -77800,1.1384315,4.2642064,,,,,,,,,,,,,, -77900,1.2109015,3.4405951,,,,,,,,,,,,,, -78000,1.204829,3.3401496,,,,,,,,,,,,,, -78100,1.200351,3.310546,,,,,,,,,,,,,, -78200,1.3048315,3.4724941,,,,,,,,,,,,,, -78300,1.1786038,4.722002,,,,,,,,,,,,,, -78381,,,0.7498828172683716,1.1963554620742798,0.6926999688148499,1.452593445777893,50000.0,0.5669000148773193,2.084571599960327,10000.0,36165.88693213463,39129.471900224686,36165.88693213463,2955.930624961853,3.281820297241211,0.0 -78400,1.466477,3.332872,,,,,,,,,,,,,, -78500,1.1936845,4.8840027,,,,,,,,,,,,,, -78600,1.2540447,3.348827,,,,,,,,,,,,,, -78700,1.262657,3.558653,,,,,,,,,,,,,, -78800,1.2060906,3.3591914,,,,,,,,,,,,,, -78900,1.222131,3.3437948,,,,,,,,,,,,,, -79000,1.1275369,3.635667,,,,,,,,,,,,,, -79100,1.1944422,3.3206763,,,,,,,,,,,,,, -79200,1.2226729,3.5552292,,,,,,,,,,,,,, -79292,,,0.7596288919448853,1.14617919921875,0.6948999762535095,1.4263193607330322,50000.0,0.5736000537872314,2.044193744659424,10000.0,36585.79311347008,39584.9595644474,36585.79311347008,2991.409699201584,3.3323724269866943,0.0 -79300,1.2057372,4.217871,,,,,,,,,,,,,, -79400,1.2151839,3.5421984,,,,,,,,,,,,,, -79500,1.1765164,3.8728967,,,,,,,,,,,,,, -79600,1.0881366,3.6283836,,,,,,,,,,,,,, -79700,1.2731645,3.308113,,,,,,,,,,,,,, -79800,1.3028146,3.395541,,,,,,,,,,,,,, -79900,1.2522557,3.384532,,,,,,,,,,,,,, -80000,1.1771187,4.183384,,,,,,,,,,,,,, -80100,1.2534842,3.4862432,,,,,,,,,,,,,, -80200,1.164983,4.687069,,,,,,,,,,,,,, -80204,,,0.7778124809265137,1.1191484928131104,0.7011399865150452,1.4412455558776855,50000.0,0.5777000188827515,2.0484261512756348,10000.0,37006.06656885147,40038.66042852402,37006.06656885147,3024.742676258087,3.375904083251953,0.0 -80300,1.2105404,3.5049748,,,,,,,,,,,,,, -80400,1.2739655,3.5011039,,,,,,,,,,,,,, -80500,1.2704405,3.3351169,,,,,,,,,,,,,, -80600,1.2761588,3.3100576,,,,,,,,,,,,,, -80700,1.258003,3.4819617,,,,,,,,,,,,,, -80800,1.2641404,3.4924242,,,,,,,,,,,,,, -80900,1.2858338,3.370059,,,,,,,,,,,,,, -81000,1.2975293,3.2988644,,,,,,,,,,,,,, -81100,1.2054052,3.8182328,,,,,,,,,,,,,, -81112,,,0.75927734375,1.1983951330184937,0.6988799571990967,1.469262957572937,50000.0,0.570900022983551,2.0844738483428955,10000.0,37425.9896273613,40495.30886769295,37425.9896273613,3061.377053260803,3.416645765304565,0.0 -81200,1.3466024,3.308766,,,,,,,,,,,,,, -81300,1.3310065,3.3530064,,,,,,,,,,,,,, -81400,1.3410872,3.400844,,,,,,,,,,,,,, -81500,1.2738044,3.3640447,,,,,,,,,,,,,, -81600,1.2735964,3.6224685,,,,,,,,,,,,,, -81700,1.3660324,4.984288,,,,,,,,,,,,,, -81800,1.2703995,3.3291838,,,,,,,,,,,,,, -81900,1.2501416,3.326803,,,,,,,,,,,,,, -82000,1.2143489,3.879685,,,,,,,,,,,,,, -82023,,,0.7640624642372131,1.1621880531311035,0.702019989490509,1.4289867877960205,50000.0,0.5773000121116638,2.0627498626708984,10000.0,37846.07184147835,40952.05909061432,37846.07184147835,3097.956475019455,3.4548816680908203,0.0 -82100,1.3478558,3.317656,,,,,,,,,,,,,, -82200,1.1583064,4.443121,,,,,,,,,,,,,, -82300,1.3697435,3.2595763,,,,,,,,,,,,,, -82400,1.3169432,3.359966,,,,,,,,,,,,,, -82500,1.3610953,5.02068,,,,,,,,,,,,,, -82600,1.247911,3.393779,,,,,,,,,,,,,, -82700,1.2182392,3.2973204,,,,,,,,,,,,,, -82800,1.3079569,3.3054297,,,,,,,,,,,,,, -82900,1.298116,4.0340443,,,,,,,,,,,,,, -82935,,,0.7707226276397705,1.1098370552062988,0.7006799578666687,1.4164754152297974,50000.0,0.5746000409126282,2.0360336303710938,10000.0,38266.4472155571,41404.75353455544,38266.4472155571,3130.1866416931152,3.492640733718872,0.0 -83000,1.2723728,3.4192722,,,,,,,,,,,,,, -83100,1.1182729,4.6923785,,,,,,,,,,,,,, -83200,1.1546073,4.437574,,,,,,,,,,,,,, -83300,1.1957502,4.62879,,,,,,,,,,,,,, -83400,1.2194628,3.6699083,,,,,,,,,,,,,, -83500,1.2086238,4.035552,,,,,,,,,,,,,, -83600,1.2709267,4.5568457,,,,,,,,,,,,,, -83700,1.237329,3.3278275,,,,,,,,,,,,,, -83800,1.33327,3.2833128,,,,,,,,,,,,,, -83846,,,0.7615429759025574,1.1838170289993286,0.700760006904602,1.4477368593215942,50000.0,0.5703000426292419,2.065577745437622,10000.0,38686.41108036041,41860.1210463047,38686.41108036041,3165.501063108444,3.5313708782196045,0.0 -83900,1.2743952,4.381671,,,,,,,,,,,,,, -84000,1.2105495,3.338024,,,,,,,,,,,,,, -84100,1.1774949,3.8019104,,,,,,,,,,,,,, -84200,1.2284925,3.8180811,,,,,,,,,,,,,, -84300,1.2451577,3.3656647,,,,,,,,,,,,,, -84400,1.2398916,3.274108,,,,,,,,,,,,,, -84500,1.2950035,3.6133149,,,,,,,,,,,,,, -84600,1.1555177,4.5889893,,,,,,,,,,,,,, -84700,1.1624684,4.429122,,,,,,,,,,,,,, -84759,,,0.7667187452316284,1.1391881704330444,0.7033799886703491,1.4073001146316528,50000.0,0.5855000019073486,2.0161125659942627,10000.0,39106.723685503006,42315.893428087234,39106.723685503006,3200.868814229965,3.5725789070129395,0.0 -84800,1.3308575,3.416764,,,,,,,,,,,,,, -84900,1.4340316,3.2996783,,,,,,,,,,,,,, -85000,1.2268591,4.4236927,,,,,,,,,,,,,, -85100,1.2223787,4.1033354,,,,,,,,,,,,,, -85200,1.2789336,4.9180574,,,,,,,,,,,,,, -85300,1.1071478,3.581493,,,,,,,,,,,,,, -85400,1.3001704,3.3725448,,,,,,,,,,,,,, -85500,1.1542773,3.563245,,,,,,,,,,,,,, -85600,1.3316326,3.4318566,,,,,,,,,,,,,, -85671,,,0.7760937213897705,1.1103529930114746,0.7053200006484985,1.4143160581588743,50000.0,0.5834000110626221,2.0242509841918945,10000.0,39526.67753863335,42773.18708443642,39526.67753863335,3238.1050729751587,3.625221729278565,0.0 -85700,1.4073669,3.4685934,,,,,,,,,,,,,, -85800,1.3077344,3.8436089,,,,,,,,,,,,,, -85900,1.235984,3.5374596,,,,,,,,,,,,,, -86000,1.3344046,4.7378097,,,,,,,,,,,,,, -86100,1.3001666,3.3211784,,,,,,,,,,,,,, -86200,1.2034863,3.7075922,,,,,,,,,,,,,, -86300,1.2281743,3.3657875,,,,,,,,,,,,,, -86400,1.1396139,3.986714,,,,,,,,,,,,,, -86500,1.2404789,3.7878544,,,,,,,,,,,,,, -86581,,,0.7747851610183716,1.1178936958312988,0.7052599787712097,1.4013491868972778,50000.0,0.5835000276565552,2.003753900527954,10000.0,39946.76177215576,43231.5920381546,39946.76177215576,3276.332666158676,3.668002128601074,0.0 -86600,1.4381679,3.3175123,,,,,,,,,,,,,, -86700,1.2355077,3.9658136,,,,,,,,,,,,,, -86800,1.2560395,3.398929,,,,,,,,,,,,,, -86900,1.435577,5.075554,,,,,,,,,,,,,, -87000,1.2864974,3.259512,,,,,,,,,,,,,, -87100,1.4698524,5.0221457,,,,,,,,,,,,,, -87200,1.208372,3.413322,,,,,,,,,,,,,, -87300,1.2456608,4.7321634,,,,,,,,,,,,,, -87400,1.3568703,3.3755867,,,,,,,,,,,,,, -87492,,,0.7694140672683716,1.1339526176452637,0.7033799886703491,1.4059618711471558,50000.0,0.5782999992370605,2.0230345726013184,10000.0,40366.83447360992,43687.16769170761,40366.83447360992,3311.738111257553,3.710970640182495,0.0 -87500,1.1508493,3.6688514,,,,,,,,,,,,,, -87600,1.2644036,3.6242635,,,,,,,,,,,,,, -87700,1.2486604,3.2802134,,,,,,,,,,,,,, -87800,1.2494981,3.2841454,,,,,,,,,,,,,, -87900,1.2260613,3.2644196,,,,,,,,,,,,,, -88000,1.2885592,3.2741709,,,,,,,,,,,,,, -88100,1.2208999,4.0532475,,,,,,,,,,,,,, -88200,1.2652998,3.634357,,,,,,,,,,,,,, -88300,1.1719912,3.7541509,,,,,,,,,,,,,, -88400,1.2075605,4.143376,,,,,,,,,,,,,, -88404,,,0.7787304520606995,1.0719337463378906,0.7060399651527405,1.3819568157196045,50000.0,0.5807000398635864,1.9886709451675413,10000.0,40787.19777345657,44141.67059183121,40787.19777345657,3345.784725189209,3.752886056900024,0.0 -88500,1.3435801,3.3171134,,,,,,,,,,,,,, -88600,1.417671,3.3105316,,,,,,,,,,,,,, -88700,1.3576488,4.5245066,,,,,,,,,,,,,, -88800,1.3264084,3.2743993,,,,,,,,,,,,,, -88900,1.3963621,3.3278298,,,,,,,,,,,,,, -89000,1.3608633,3.8742988,,,,,,,,,,,,,, -89100,1.2391198,3.8574684,,,,,,,,,,,,,, -89200,1.3458027,3.324493,,,,,,,,,,,,,, -89300,1.2482611,3.3683152,,,,,,,,,,,,,, -89314,,,0.7935351133346558,1.021241545677185,0.7061399817466736,1.3768731355667114,50000.0,0.5854000449180603,1.981370210647583,10000.0,41207.16171503067,44596.56769442558,41207.16171503067,3380.6273963451385,3.791879415512085,0.0 -89400,1.3170105,4.312498,,,,,,,,,,,,,, -89500,1.3411057,3.318781,,,,,,,,,,,,,, -89600,1.2425518,5.014036,,,,,,,,,,,,,, -89700,1.260978,3.9330497,,,,,,,,,,,,,, -89800,1.1422269,3.586834,,,,,,,,,,,,,, -89900,1.2409801,3.3182354,,,,,,,,,,,,,, -90000,1.2623492,3.3346608,,,,,,,,,,,,,, -90100,1.1474049,3.6464558,,,,,,,,,,,,,, -90200,1.2346138,3.5283043,,,,,,,,,,,,,, -90225,,,0.7718163728713989,1.1109107732772827,0.7108599543571472,1.373457670211792,50000.0,0.581000030040741,2.001279354095459,10000.0,41627.17049217224,45052.54302406311,41627.17049217224,3416.499861478805,3.834437370300293,0.0 -90300,1.4047835,3.3799307,,,,,,,,,,,,,, -90400,1.2741735,4.111584,,,,,,,,,,,,,, -90500,1.405669,3.2179623,,,,,,,,,,,,,, -90600,1.2327754,4.307452,,,,,,,,,,,,,, -90700,1.3024349,4.397399,,,,,,,,,,,,,, -90800,1.394289,3.2566092,,,,,,,,,,,,,, -90900,1.2412908,3.259766,,,,,,,,,,,,,, -91000,1.3357531,3.9615362,,,,,,,,,,,,,, -91100,1.1681585,4.093935,,,,,,,,,,,,,, -91136,,,0.7748632431030273,1.0888638496398926,0.7067999839782715,1.3886761665344238,50000.0,0.5821000337600708,1.994511842727661,10000.0,42047.39622235298,45510.83752632141,42047.39622235298,3454.4704039096832,3.881521701812744,0.0 -91200,1.2737888,4.5804014,,,,,,,,,,,,,, -91300,1.3370208,3.4534748,,,,,,,,,,,,,, -91400,1.2672102,4.6363273,,,,,,,,,,,,,, -91500,1.1888406,4.4990864,,,,,,,,,,,,,, -91600,1.324671,3.429083,,,,,,,,,,,,,, -91700,1.4910866,3.3341436,,,,,,,,,,,,,, -91800,1.3114153,3.2731013,,,,,,,,,,,,,, -91900,1.2988411,4.3043137,,,,,,,,,,,,,, -92000,1.434278,3.2959049,,,,,,,,,,,,,, -92046,,,0.7909570336341858,1.043149471282959,0.7130199670791626,1.3840327262878418,50000.0,0.585800051689148,1.991669774055481,10000.0,42467.585582733154,45966.19464612007,42467.585582733154,3489.544613361358,3.9240567684173584,0.0 -92100,1.2665595,3.4959886,,,,,,,,,,,,,, -92200,1.323195,4.813529,,,,,,,,,,,,,, -92300,1.3466003,3.234602,,,,,,,,,,,,,, -92400,1.25867,3.416337,,,,,,,,,,,,,, -92500,1.3941064,3.1648493,,,,,,,,,,,,,, -92600,1.5031036,3.3035028,,,,,,,,,,,,,, -92700,1.3793294,4.8204823,,,,,,,,,,,,,, -92800,1.2681916,3.5084155,,,,,,,,,,,,,, -92900,1.2855731,3.3372822,,,,,,,,,,,,,, -92957,,,0.7769726514816284,1.0807150602340698,0.7142399549484253,1.3572258949279783,50000.0,0.5952000021934509,1.9467597007751465,10000.0,42887.68548750877,46422.40642333031,42887.68548750877,3525.5610892772675,3.9683220386505127,0.0 -93000,1.426684,5.065611,,,,,,,,,,,,,, -93100,1.376254,3.2119412,,,,,,,,,,,,,, -93200,1.3098419,3.5107198,,,,,,,,,,,,,, -93300,1.4346745,4.609699,,,,,,,,,,,,,, -93400,1.3383327,3.168836,,,,,,,,,,,,,, -93500,1.4391937,3.2465627,,,,,,,,,,,,,, -93600,1.2716991,3.8264346,,,,,,,,,,,,,, -93700,1.4571067,3.279106,,,,,,,,,,,,,, -93800,1.298708,4.669185,,,,,,,,,,,,,, -93870,,,0.784863293170929,1.0583343505859375,0.7152599692344666,1.3549997806549072,50000.0,0.593000054359436,1.960601687431336,10000.0,43307.87548875809,46876.57886219025,43307.87548875809,3559.4507846832275,4.0101823806762695,0.0 -93900,1.2484283,4.0986757,,,,,,,,,,,,,, -94000,1.3722473,3.2114816,,,,,,,,,,,,,, -94100,1.289353,3.2828994,,,,,,,,,,,,,, -94200,1.3356715,3.286113,,,,,,,,,,,,,, -94300,1.3412216,3.3076365,,,,,,,,,,,,,, -94400,1.2556646,3.651221,,,,,,,,,,,,,, -94500,1.3623903,3.3116226,,,,,,,,,,,,,, -94600,1.280127,3.2922814,,,,,,,,,,,,,, -94700,1.3560494,4.819556,,,,,,,,,,,,,, -94782,,,0.7880077958106995,1.056152582168579,0.7136200070381165,1.3756757974624634,50000.0,0.5889000296592712,1.9845054149627688,10000.0,43727.9484167099,47331.931569337845,43727.9484167099,3594.6408054828644,4.049016237258911,0.0 -94800,1.4140306,3.2937903,,,,,,,,,,,,,, -94900,1.2712264,3.350438,,,,,,,,,,,,,, -95000,1.3408096,3.3436358,,,,,,,,,,,,,, -95100,1.3335136,3.1936533,,,,,,,,,,,,,, -95200,1.392681,3.3785484,,,,,,,,,,,,,, -95300,1.4857899,4.9765716,,,,,,,,,,,,,, -95400,1.3377771,4.50295,,,,,,,,,,,,,, -95500,1.187339,3.8494415,,,,,,,,,,,,,, -95600,1.3622065,3.254087,,,,,,,,,,,,,, -95694,,,0.7802929282188416,1.084344744682312,0.7172200083732605,1.356879949569702,50000.0,0.5938000082969666,1.965236783027649,10000.0,44147.98247885704,47786.82453203201,44147.98247885704,3629.39878821373,4.098080635070801,0.0 -95700,1.315841,3.1853728,,,,,,,,,,,,,, -95800,1.3261828,3.4953887,,,,,,,,,,,,,, -95900,1.5310043,4.92097,,,,,,,,,,,,,, -96000,1.3607506,4.5063004,,,,,,,,,,,,,, -96100,1.4105198,3.2059083,,,,,,,,,,,,,, -96200,1.3919269,3.4719968,,,,,,,,,,,,,, -96300,1.3661594,3.341379,,,,,,,,,,,,,, -96400,1.5805098,4.8607674,,,,,,,,,,,,,, -96500,1.3387054,3.2619932,,,,,,,,,,,,,, -96600,1.3409742,4.3141847,,,,,,,,,,,,,, -96606,,,0.7871288657188416,1.0491591691970823,0.7164199948310852,1.3515568971633911,50000.0,0.5927000045776367,1.9558128118515008,10000.0,44568.36045074463,48244.41781044006,44568.36045074463,3666.518232584,4.142136096954346,0.0 -96700,1.3754191,3.2663648,,,,,,,,,,,,,, -96800,1.2504373,4.160149,,,,,,,,,,,,,, -96900,1.4254688,3.2910688,,,,,,,,,,,,,, -97000,1.41241,3.317151,,,,,,,,,,,,,, -97100,1.2754599,3.9935346,,,,,,,,,,,,,, -97200,1.330873,4.359921,,,,,,,,,,,,,, -97300,1.344493,3.241509,,,,,,,,,,,,,, -97400,1.4065205,3.255163,,,,,,,,,,,,,, -97500,1.2747507,3.2584996,,,,,,,,,,,,,, -97517,,,0.7954687476158142,1.002249240875244,0.7200999855995178,1.323193907737732,50000.0,0.593500018119812,1.929459810256958,10000.0,44988.40380692482,48700.0536942482,44988.40380692482,3702.0162086486816,4.1854517459869385,0.0 -97600,1.3940973,3.2878447,,,,,,,,,,,,,, -97700,1.3442034,3.2222447,,,,,,,,,,,,,, -97800,1.5230179,3.23581,,,,,,,,,,,,,, -97900,1.4015259,3.3099897,,,,,,,,,,,,,, -98000,1.2402271,3.1999383,,,,,,,,,,,,,, -98100,1.490185,3.3260174,,,,,,,,,,,,,, -98200,1.4044763,4.4511833,,,,,,,,,,,,,, -98300,1.272347,3.457089,,,,,,,,,,,,,, -98400,1.4067442,3.2707627,,,,,,,,,,,,,, -98429,,,0.7873241901397705,1.0458989143371582,0.7177599668502808,1.3348171710968018,50000.0,0.5968000292778015,1.9434531927108765,10000.0,45408.60495519638,49155.00846171379,45408.60495519638,3736.6778705120087,4.226574182510376,0.0 -98500,1.3553905,3.1150103,,,,,,,,,,,,,, -98600,1.5856874,4.8534226,,,,,,,,,,,,,, -98700,1.3293846,3.191881,,,,,,,,,,,,,, -98800,1.405955,3.2950075,,,,,,,,,,,,,, -98900,1.4569649,3.2503445,,,,,,,,,,,,,, -99000,1.3222759,3.184925,,,,,,,,,,,,,, -99100,1.4542687,4.2765865,,,,,,,,,,,,,, -99200,1.2275714,4.029261,,,,,,,,,,,,,, -99300,1.3983591,3.2523797,,,,,,,,,,,,,, -99341,,,0.7878515720367432,1.0436352491378784,0.7171199917793274,1.3446561098098757,50000.0,0.6041000485420227,1.933828592300415,10000.0,45828.62591218949,49609.64790344238,45828.62591218949,3771.2031738758087,4.268864870071411,0.0 -99400,1.3243502,3.7908537,,,,,,,,,,,,,, -99500,1.5120931,4.9923034,,,,,,,,,,,,,, -99600,1.299592,3.3377476,,,,,,,,,,,,,, -99700,1.2741758,3.504257,,,,,,,,,,,,,, -99800,1.4210767,3.4001758,,,,,,,,,,,,,, -99900,1.342543,3.2467752,,,,,,,,,,,,,, -100000,1.3940861,3.2250483,,,,,,,,,,,,,, -100100,1.4161134,3.293825,,,,,,,,,,,,,, -100200,1.2768831,4.00827,,,,,,,,,,,,,, -100253,,,0.7922070026397705,1.0412719249725342,0.7170599699020386,1.350500226020813,50000.0,0.5974000096321106,1.94913387298584,10000.0,46248.61977934837,50064.15418791771,46248.61977934837,3805.623178482056,4.309793710708618,0.0 -100300,1.4596457,3.2691748,,,,,,,,,,,,,, -100400,1.4543599,3.314582,,,,,,,,,,,,,, -100500,1.3369036,4.3444624,,,,,,,,,,,,,, -100600,1.3501704,3.909946,,,,,,,,,,,,,, -100700,1.4436531,3.934168,,,,,,,,,,,,,, -100800,1.3615001,3.9451911,,,,,,,,,,,,,, -100900,1.3397036,4.2100916,,,,,,,,,,,,,, -101000,1.3845023,3.2604275,,,,,,,,,,,,,, -101100,1.4793237,3.3111463,,,,,,,,,,,,,, -101164,,,0.8041796684265137,1.0106101036071775,0.7200999855995178,1.3571895360946655,50000.0,0.5995000004768372,1.9695676565170288,10000.0,46668.6330909729,50520.18922710419,46668.6330909729,3841.550982236862,4.3530638217926025,0.0 -101200,1.4364395,3.1266317,,,,,,,,,,,,,, -101300,1.3854595,3.4274702,,,,,,,,,,,,,, -101400,1.4572432,3.1920056,,,,,,,,,,,,,, -101500,1.4252886,3.4032,,,,,,,,,,,,,, -101600,1.4166344,3.4366233,,,,,,,,,,,,,, -101700,1.52052,3.712556,,,,,,,,,,,,,, -101800,1.4029986,4.897826,,,,,,,,,,,,,, -101900,1.5602884,4.7119465,,,,,,,,,,,,,, -102000,1.4390225,3.4277947,,,,,,,,,,,,,, -102076,,,0.7922070026397705,1.0534024238586426,0.7217999696731567,1.349803447723389,50000.0,0.5991000533103943,1.948511242866516,10000.0,47088.92100191116,50977.975981235504,47088.92100191116,3878.951035261154,4.401503562927246,0.0 -102100,1.5136675,3.2822735,,,,,,,,,,,,,, -102200,1.4331065,3.3359532,,,,,,,,,,,,,, -102300,1.4167597,3.2540586,,,,,,,,,,,,,, -102400,1.4232546,3.039833,,,,,,,,,,,,,, -102500,1.5515791,4.7878065,,,,,,,,,,,,,, -102600,1.6146207,3.2459385,,,,,,,,,,,,,, -102700,1.4391629,4.3198485,,,,,,,,,,,,,, -102800,1.5615014,4.823719,,,,,,,,,,,,,, -102900,1.4278758,3.8603115,,,,,,,,,,,,,, -102988,,,0.79798823595047,1.0308090448379517,0.7218199968338013,1.3451817035675049,50000.0,0.5969000458717346,1.9463989734649656,10000.0,47508.928409576416,51435.638216257095,47508.928409576416,3916.5116069316864,4.4449450969696045,0.0 -103000,1.9081906,4.8929415,,,,,,,,,,,,,, -103100,1.4787645,3.2950127,,,,,,,,,,,,,, -103200,1.3680785,3.895692,,,,,,,,,,,,,, -103300,1.4414228,3.1496718,,,,,,,,,,,,,, -103400,1.3881278,3.1882725,,,,,,,,,,,,,, -103500,1.3794003,3.2296307,,,,,,,,,,,,,, -103600,1.3323784,4.027294,,,,,,,,,,,,,, -103700,1.5289668,3.1453764,,,,,,,,,,,,,, -103800,1.558273,3.2520068,,,,,,,,,,,,,, -103898,,,0.8116992115974426,0.9391869306564332,0.7234399914741516,1.3024083375930786,50000.0,0.6037000417709351,1.9161393642425537,10000.0,47929.27530050278,51890.92503976822,47929.27530050278,3951.357241153717,4.487648963928223,0.0 -103900,1.4224467,3.226323,,,,,,,,,,,,,, -104000,1.5053567,4.5503435,,,,,,,,,,,,,, -104100,1.4467409,3.2295063,,,,,,,,,,,,,, -104200,1.4161624,3.4131887,,,,,,,,,,,,,, -104300,1.5789183,3.1789906,,,,,,,,,,,,,, -104400,1.5174834,4.6106606,,,,,,,,,,,,,, -104500,1.4648486,3.2418616,,,,,,,,,,,,,, -104600,1.4472127,3.3571074,,,,,,,,,,,,,, -104700,1.6554428,3.1899223,,,,,,,,,,,,,, -104800,1.3622713,3.7484648,,,,,,,,,,,,,, -104812,,,0.7913671731948853,1.040604829788208,0.7239199876785278,1.3334083557128906,50000.0,0.605400025844574,1.936792254447937,10000.0,48349.42050933838,52345.748777627945,48349.42050933838,3985.942140340805,4.530482769012451,0.0 -104900,1.4982697,3.2121966,,,,,,,,,,,,,, -105000,1.4290664,3.2474835,,,,,,,,,,,,,, -105100,1.3907125,3.228625,,,,,,,,,,,,,, -105200,1.4300487,4.571459,,,,,,,,,,,,,, -105300,1.358539,3.4779127,,,,,,,,,,,,,, -105400,1.4512628,3.5399394,,,,,,,,,,,,,, -105500,1.4588623,3.1654952,,,,,,,,,,,,,, -105600,1.4175745,3.1990204,,,,,,,,,,,,,, -105700,1.3457953,4.047627,,,,,,,,,,,,,, -105725,,,0.7983007431030273,1.0020376443862915,0.7238799929618835,1.3144477605819702,50000.0,0.5963000059127808,1.9204466342926023,10000.0,48769.416988134384,52802.70588493347,48769.416988134384,4022.8122441768646,4.570789575576782,0.0 -105800,1.5716876,3.251966,,,,,,,,,,,,,, -105900,1.5269352,4.3994355,,,,,,,,,,,,,, -106000,1.5430764,4.461462,,,,,,,,,,,,,, -106100,1.647856,3.2390184,,,,,,,,,,,,,, -106200,1.3604333,3.3404922,,,,,,,,,,,,,, -106300,1.7038236,4.8916936,,,,,,,,,,,,,, -106400,1.5095328,3.2712483,,,,,,,,,,,,,, -106500,1.520266,3.5055616,,,,,,,,,,,,,, -106600,1.419148,3.2212567,,,,,,,,,,,,,, -106635,,,0.8101366758346558,0.9570842385292052,0.725820004940033,1.2992676496505735,50000.0,0.5995000004768372,1.908978581428528,10000.0,49189.34255599976,53258.69718146324,49189.34255599976,4058.433498620987,4.963132858276367,0.0 -106700,1.4645886,3.197486,,,,,,,,,,,,,, -106800,1.640712,3.3126998,,,,,,,,,,,,,, -106900,1.3729053,3.908736,,,,,,,,,,,,,, -107000,1.4883044,3.4953067,,,,,,,,,,,,,, -107100,1.5270116,3.2322755,,,,,,,,,,,,,, -107200,1.4258351,3.159677,,,,,,,,,,,,,, -107300,1.4389135,3.2824867,,,,,,,,,,,,,, -107400,1.4819654,3.3814673,,,,,,,,,,,,,, -107500,1.4569716,3.1800528,,,,,,,,,,,,,, -107544,,,0.7962890267372131,1.0105478763580322,0.7277799844741821,1.2957388162612915,50000.0,0.6050000190734863,1.9033678770065308,10000.0,49609.5889275074,53715.4951851368,49609.5889275074,4094.8872702121735,5.010452508926392,0.0 -107600,1.3552091,3.0994298,,,,,,,,,,,,,, -107700,1.4388095,3.2609372,,,,,,,,,,,,,, -107800,1.5062852,3.2107232,,,,,,,,,,,,,, -107900,1.4145229,4.010944,,,,,,,,,,,,,, -108000,1.4037483,3.1901972,,,,,,,,,,,,,, -108100,1.3735883,3.5742579,,,,,,,,,,,,,, -108200,1.5508361,4.8292594,,,,,,,,,,,,,, -108300,1.4889582,3.1836166,,,,,,,,,,,,,, -108400,1.6038542,3.1932693,,,,,,,,,,,,,, -108455,,,0.802539050579071,0.989858627319336,0.7298199534416199,1.2934014797210691,50000.0,0.6094000339508057,1.8905043601989744,10000.0,50029.638154029846,54173.51714682579,50029.638154029846,4132.766643047333,5.052933216094971,0.0 -108500,1.5837299,4.6891465,,,,,,,,,,,,,, -108600,1.4290097,4.0354257,,,,,,,,,,,,,, -108700,1.5591044,3.1627798,,,,,,,,,,,,,, -108800,1.5596874,3.1239867,,,,,,,,,,,,,, -108900,1.5654353,3.2488086,,,,,,,,,,,,,, -109000,1.4673567,3.3684764,,,,,,,,,,,,,, -109100,1.3850272,3.254669,,,,,,,,,,,,,, -109200,1.6204323,3.4022114,,,,,,,,,,,,,, -109300,1.4426874,3.9461048,,,,,,,,,,,,,, -109365,,,0.8080468773841858,0.9750061631202698,0.7301200032234192,1.2988848686218262,50000.0,0.6034000515937805,1.914489388465881,10000.0,50449.68918180466,54632.3221013546,50449.68918180466,4171.423996925354,5.098599672317505,0.0 -109400,1.5103426,3.1212933,,,,,,,,,,,,,, -109500,1.435635,4.06249,,,,,,,,,,,,,, -109600,1.5120871,3.261739,,,,,,,,,,,,,, -109700,1.3858805,3.715804,,,,,,,,,,,,,, -109800,1.4879144,3.1207795,,,,,,,,,,,,,, -109900,1.6387688,3.153074,,,,,,,,,,,,,, -110000,1.3958001,3.6557546,,,,,,,,,,,,,, -110100,1.6870707,4.4440117,,,,,,,,,,,,,, -110200,1.3941835,3.149356,,,,,,,,,,,,,, -110274,,,0.7994921803474426,1.0312963724136353,0.7298199534416199,1.3248251676559448,50000.0,0.6038000583648682,1.9229018688201904,10000.0,50869.71073126793,55088.965970516205,50869.71073126793,4207.95077753067,5.143129587173462,0.0 -110300,1.5314304,4.5908213,,,,,,,,,,,,,, -110400,1.5859746,4.5730567,,,,,,,,,,,,,, -110500,1.4729798,4.2745457,,,,,,,,,,,,,, -110600,1.5124998,3.13647,,,,,,,,,,,,,, -110700,1.5821525,3.1502242,,,,,,,,,,,,,, -110800,1.5007634,3.0836625,,,,,,,,,,,,,, -110900,1.512691,3.3839326,,,,,,,,,,,,,, -111000,1.5116109,3.2621548,,,,,,,,,,,,,, -111100,1.7065598,4.8513427,,,,,,,,,,,,,, -111184,,,0.8047069907188416,0.9976529479026794,0.7334799766540527,1.296475529670715,50000.0,0.6066000461578369,1.912151455879212,10000.0,51289.89997005463,55545.89949464798,51289.89997005463,4244.599578619003,5.188409090042114,0.0 -111200,1.7801636,4.7616806,,,,,,,,,,,,,, -111300,1.5230458,3.554769,,,,,,,,,,,,,, -111400,1.4711431,3.774589,,,,,,,,,,,,,, -111500,1.5298976,3.1303906,,,,,,,,,,,,,, -111600,1.4741622,3.969026,,,,,,,,,,,,,, -111700,1.5632044,3.3254075,,,,,,,,,,,,,, -111800,1.7100703,4.873798,,,,,,,,,,,,,, -111900,1.6637977,3.9740055,,,,,,,,,,,,,, -112000,1.5011072,3.2170372,,,,,,,,,,,,,, -112093,,,0.8110741972923279,0.9474124908447266,0.7343599796295166,1.2785613536834717,50000.0,0.6106000542640686,1.889366149902344,10000.0,51709.84427022934,56003.3701505661,51709.84427022934,4282.025975942612,5.237752914428711,0.0 -112100,1.468841,4.087143,,,,,,,,,,,,,, -112200,1.5174912,4.720995,,,,,,,,,,,,,, -112300,1.3923742,3.6819766,,,,,,,,,,,,,, -112400,1.649106,3.3404346,,,,,,,,,,,,,, -112500,1.499037,3.2005816,,,,,,,,,,,,,, -112600,1.704267,4.0755157,,,,,,,,,,,,,, -112700,1.4602019,3.1506567,,,,,,,,,,,,,, -112800,1.718875,4.231906,,,,,,,,,,,,,, -112900,1.6114597,3.268847,,,,,,,,,,,,,, -113000,1.6905516,4.5077753,,,,,,,,,,,,,, -113001,,,0.8006640672683716,0.9824928045272828,0.7322399616241455,1.2828460931777954,50000.0,0.6092000007629395,1.8907074928283687,10000.0,52129.879022836685,56458.75296974182,52129.879022836685,4317.279319286346,5.281320095062256,0.0 -113100,1.6394962,3.2693367,,,,,,,,,,,,,, -113200,1.515438,4.048202,,,,,,,,,,,,,, -113300,1.5334276,3.2183166,,,,,,,,,,,,,, -113400,1.4615891,3.1527765,,,,,,,,,,,,,, -113500,1.5685449,3.2757845,,,,,,,,,,,,,, -113600,1.6154753,3.1215098,,,,,,,,,,,,,, -113700,1.7290124,3.1240497,,,,,,,,,,,,,, -113800,1.5503836,3.1941447,,,,,,,,,,,,,, -113900,1.5723164,4.4390955,,,,,,,,,,,,,, -113912,,,0.8100976347923279,0.9563019871711732,0.7339800000190735,1.2726399898529053,50000.0,0.6098000407218933,1.8746048212051392,10000.0,52549.861157894135,56913.93776369095,52549.861157894135,4352.38493680954,5.327167272567749,0.0 -114000,1.639262,4.524612,,,,,,,,,,,,,, -114100,1.4931427,3.1590123,,,,,,,,,,,,,, -114200,1.4761593,3.7309675,,,,,,,,,,,,,, -114300,1.5554876,4.2960825,,,,,,,,,,,,,, -114400,1.4786671,3.1664057,,,,,,,,,,,,,, -114500,1.4423789,3.6290739,,,,,,,,,,,,,, -114600,1.6358523,3.1897554,,,,,,,,,,,,,, -114700,1.5738143,3.1897106,,,,,,,,,,,,,, -114800,1.4863505,4.1414137,,,,,,,,,,,,,, -114823,,,0.8125,0.9517192840576172,0.733460009098053,1.2863125801086426,50000.0,0.6060000061988831,1.896116852760315,10000.0,52970.15888476372,57369.3252120018,52970.15888476372,4387.382179737091,5.3694212436676025,0.0 -114900,1.4595358,3.134584,,,,,,,,,,,,,, -115000,1.6368021,4.617882,,,,,,,,,,,,,, -115100,1.6232604,3.1103837,,,,,,,,,,,,,, -115200,1.5478626,3.1308484,,,,,,,,,,,,,, -115300,1.5639583,3.0997474,,,,,,,,,,,,,, -115400,1.5030433,3.3059983,,,,,,,,,,,,,, -115500,1.4911952,3.68225,,,,,,,,,,,,,, -115600,1.5758488,4.3202467,,,,,,,,,,,,,, -115700,1.5485198,3.1460977,,,,,,,,,,,,,, -115733,,,0.8245702981948853,0.9153898358345032,0.7346000075340271,1.282979965209961,50000.0,0.6184000372886658,1.8812453746795648,10000.0,53390.33986020088,57826.413011312485,53390.33986020088,4424.193821430206,5.413953542709351,0.0 -115800,1.4892782,3.4733536,,,,,,,,,,,,,, -115900,1.5270933,3.0425959,,,,,,,,,,,,,, -116000,1.5899584,3.0780563,,,,,,,,,,,,,, -116100,1.608597,3.1365504,,,,,,,,,,,,,, -116200,1.5549468,3.178758,,,,,,,,,,,,,, -116300,1.4797791,3.0664227,,,,,,,,,,,,,, -116400,1.7152264,3.1525583,,,,,,,,,,,,,, -116500,1.6392719,3.2800791,,,,,,,,,,,,,, -116600,1.5421233,3.2033026,,,,,,,,,,,,,, -116642,,,0.8141601085662842,0.9461965560913086,0.7366200089454651,1.266655683517456,50000.0,0.615600049495697,1.860821962356568,10000.0,53810.46590876579,58282.51025009155,53810.46590876579,4460.06673002243,5.46102499961853,0.0 -116700,1.5032917,3.3558238,,,,,,,,,,,,,, -116800,1.532956,3.1549945,,,,,,,,,,,,,, -116900,1.758556,3.212014,,,,,,,,,,,,,, -117000,1.6353378,4.4365387,,,,,,,,,,,,,, -117100,1.6329479,3.164271,,,,,,,,,,,,,, -117200,1.5717506,3.3175712,,,,,,,,,,,,,, -117300,1.5885222,3.1267505,,,,,,,,,,,,,, -117400,1.3994422,3.6296253,,,,,,,,,,,,,, -117500,1.6634152,4.4653106,,,,,,,,,,,,,, -117554,,,0.8140038847923279,0.9657190442085266,0.7342599630355835,1.2892900705337524,50000.0,0.6165000200271606,1.889684081077576,10000.0,54230.78221082688,58737.82883524895,54230.78221082688,4494.974093198776,5.504605293273926,0.0 -117600,1.636894,3.1848733,,,,,,,,,,,,,, -117700,1.9198632,4.6836414,,,,,,,,,,,,,, -117800,1.6062887,3.0898428,,,,,,,,,,,,,, -117900,1.5613658,3.3625863,,,,,,,,,,,,,, -118000,1.7251632,3.1093905,,,,,,,,,,,,,, -118100,1.921901,4.6914883,,,,,,,,,,,,,, -118200,1.5697162,4.264983,,,,,,,,,,,,,, -118300,1.5212055,3.7627752,,,,,,,,,,,,,, -118400,1.7505717,3.1320307,,,,,,,,,,,,,, -118465,,,0.8282226324081421,0.9008296132087708,0.7376599907875061,1.2724199295043943,50000.0,0.6170000433921814,1.8633793592453003,10000.0,54650.84451293945,59194.7225048542,54650.84451293945,4531.71052479744,5.548776865005493,0.0 -118500,1.6718067,4.449377,,,,,,,,,,,,,, -118600,1.5113382,3.6413877,,,,,,,,,,,,,, -118700,1.5861713,3.2033734,,,,,,,,,,,,,, -118800,1.458669,3.880893,,,,,,,,,,,,,, -118900,1.8054178,3.1814594,,,,,,,,,,,,,, -119000,1.6284755,3.0953817,,,,,,,,,,,,,, -119100,1.8440466,4.735669,,,,,,,,,,,,,, -119200,1.6494577,3.0825365,,,,,,,,,,,,,, -119300,1.4699482,3.48562,,,,,,,,,,,,,, -119375,,,0.81787109375,0.9437991380691528,0.7403199672698975,1.265620231628418,50000.0,0.6144000291824341,1.868046522140503,10000.0,55070.77024292946,59649.38969564438,55070.77024292946,4566.349142313004,5.601795434951782,0.0 -119400,1.5721488,3.4175644,,,,,,,,,,,,,, -119500,2.0594149,4.5669823,,,,,,,,,,,,,, -119600,1.5571879,3.078189,,,,,,,,,,,,,, -119700,1.6434621,3.0950472,,,,,,,,,,,,,, -119800,1.6040559,2.9632893,,,,,,,,,,,,,, -119900,1.6611133,3.0419211,,,,,,,,,,,,,, -120000,1.6965499,3.1155536,,,,,,,,,,,,,, -120100,1.5905204,3.2598276,,,,,,,,,,,,,, -120200,1.5836561,3.1258752,,,,,,,,,,,,,, -120285,,,0.8151366710662842,0.9468899369239808,0.7399799823760986,1.2717607021331787,50000.0,0.6229000091552734,1.870540738105774,10000.0,55490.71073937416,60105.31962871552,55490.71073937416,4602.237959384918,5.651767730712891,0.0 -120300,1.6857295,3.080517,,,,,,,,,,,,,, -120400,1.6088028,3.296302,,,,,,,,,,,,,, -120500,1.6811746,3.1927185,,,,,,,,,,,,,, -120600,1.5012858,3.5885465,,,,,,,,,,,,,, -120700,1.5917616,4.0215945,,,,,,,,,,,,,, -120800,1.4972048,3.1622286,,,,,,,,,,,,,, -120900,1.4105209,3.4814231,,,,,,,,,,,,,, -121000,1.9529921,4.686514,,,,,,,,,,,,,, -121100,1.5911252,3.896095,,,,,,,,,,,,,, -121196,,,0.8320898413658142,0.8478189706802368,0.742419958114624,1.217381715774536,50000.0,0.6231000423431396,1.803972601890564,10000.0,55910.66052460671,60562.70079231262,55910.66052460671,4639.570498466492,5.699680805206299,0.0 -121200,1.9013916,4.645633,,,,,,,,,,,,,, -121300,1.8474864,4.806991,,,,,,,,,,,,,, -121400,1.7202811,4.6297846,,,,,,,,,,,,,, -121500,1.5753113,3.839492,,,,,,,,,,,,,, -121600,1.9584183,4.692464,,,,,,,,,,,,,, -121700,1.815149,3.091277,,,,,,,,,,,,,, -121800,1.5346123,4.2041554,,,,,,,,,,,,,, -121900,1.6934385,3.0156991,,,,,,,,,,,,,, -122000,1.6192343,3.3030057,,,,,,,,,,,,,, -122100,1.4918507,3.890985,,,,,,,,,,,,,, -122107,,,0.8207421898841858,0.926133930683136,0.7434799671173096,1.2434664964675903,50000.0,0.6258000135421753,1.83869731426239,10000.0,56330.80289840698,61017.7448694706,56330.80289840698,4674.371244430542,5.749354600906372,0.0 -122200,1.6005267,3.0616431,,,,,,,,,,,,,, -122300,1.8604783,4.737118,,,,,,,,,,,,,, -122400,1.5989918,3.2189796,,,,,,,,,,,,,, -122500,1.6710955,3.7833462,,,,,,,,,,,,,, -122600,1.6452434,3.1125515,,,,,,,,,,,,,, -122700,1.678557,3.6878095,,,,,,,,,,,,,, -122800,1.7127391,4.171491,,,,,,,,,,,,,, -122900,1.6501526,3.2132447,,,,,,,,,,,,,, -123000,1.6765553,3.0742521,,,,,,,,,,,,,, -123019,,,0.8250390291213989,0.8884112238883972,0.7450599670410156,1.2220823764801023,50000.0,0.6247000098228455,1.8158106803894043,10000.0,56751.03601574898,61473.86522865296,56751.03601574898,4710.162944316864,5.793039083480835,0.0 -123100,1.7938595,4.382135,,,,,,,,,,,,,, -123200,1.6202291,3.0283954,,,,,,,,,,,,,, -123300,1.6262932,3.488864,,,,,,,,,,,,,, -123400,1.7178631,2.9877298,,,,,,,,,,,,,, -123500,1.573661,3.034999,,,,,,,,,,,,,, -123600,1.7181083,4.1809506,,,,,,,,,,,,,, -123700,1.7030371,3.0848076,,,,,,,,,,,,,, -123800,1.5831268,3.128779,,,,,,,,,,,,,, -123900,1.8355645,3.0888455,,,,,,,,,,,,,, -123930,,,0.8338476419448853,0.8774459362030029,0.7470399737358093,1.2402421236038208,50000.0,0.6234000325202942,1.8210102319717407,10000.0,57171.007404088974,61927.66263628006,57171.007404088974,4743.8915066719055,5.839264869689941,0.0 -124000,1.7286787,3.7706351,,,,,,,,,,,,,, -124100,1.626473,3.804004,,,,,,,,,,,,,, -124200,1.6632813,3.2092264,,,,,,,,,,,,,, -124300,1.6744562,3.1259022,,,,,,,,,,,,,, -124400,1.7333676,3.0096126,,,,,,,,,,,,,, -124500,1.8741782,4.0650086,,,,,,,,,,,,,, -124600,1.7358652,3.3759031,,,,,,,,,,,,,, -124700,2.0034754,4.6825976,,,,,,,,,,,,,, -124800,1.8181709,3.0831158,,,,,,,,,,,,,, -124839,,,0.8262499570846558,0.8751652240753174,0.7460199594497681,1.2193933725357056,50000.0,0.6278000473976135,1.8021245002746584,10000.0,57591.029952049255,62385.57867670059,57591.029952049255,4781.690023899078,5.884131908416748,0.0 -124900,1.7859615,3.134729,,,,,,,,,,,,,, -125000,1.6972122,4.169699,,,,,,,,,,,,,, -125100,1.5485467,3.6172082,,,,,,,,,,,,,, -125200,1.588092,3.0217843,,,,,,,,,,,,,, -125300,1.6838042,3.07545,,,,,,,,,,,,,, -125400,1.7402098,4.0992975,,,,,,,,,,,,,, -125500,1.7426145,3.0788896,,,,,,,,,,,,,, -125600,1.6939515,3.079276,,,,,,,,,,,,,, -125700,1.7253457,3.017863,,,,,,,,,,,,,, -125749,,,0.8309765458106995,0.8596699833869934,0.7472400069236755,1.2096494436264038,50000.0,0.6221000552177429,1.8037065267562864,10000.0,58010.99251627922,62841.30251932144,58010.99251627922,4817.354856729507,5.930292129516602,0.0 -125800,1.5891403,3.267531,,,,,,,,,,,,,, -125900,1.75926,3.020614,,,,,,,,,,,,,, -126000,1.7671174,2.9688077,,,,,,,,,,,,,, -126100,1.585538,3.6730974,,,,,,,,,,,,,, -126200,1.8437177,4.364389,,,,,,,,,,,,,, -126300,1.7546252,3.0449314,,,,,,,,,,,,,, -126400,1.7916591,4.378639,,,,,,,,,,,,,, -126500,1.6849222,3.1237445,,,,,,,,,,,,,, -126600,1.7378017,3.0681453,,,,,,,,,,,,,, -126660,,,0.83447265625,0.8702915906906128,0.7460399866104126,1.243055820465088,50000.0,0.6221000552177429,1.84426748752594,10000.0,58431.05421996117,63297.124264001846,58431.05421996117,4853.018300771713,5.975062370300293,0.0 -126700,2.090869,4.6610527,,,,,,,,,,,,,, -126800,1.7795914,3.1778295,,,,,,,,,,,,,, -126900,1.7096407,3.0289836,,,,,,,,,,,,,, -127000,2.1956995,4.622472,,,,,,,,,,,,,, -127100,1.675094,3.1061492,,,,,,,,,,,,,, -127200,1.6424166,3.094842,,,,,,,,,,,,,, -127300,1.7850708,3.0199995,,,,,,,,,,,,,, -127400,1.7686367,2.969593,,,,,,,,,,,,,, -127500,2.0535827,4.654274,,,,,,,,,,,,,, -127571,,,0.8270898461341858,0.900422215461731,0.7506600022315979,1.2298741340637207,50000.0,0.6247000098228455,1.8229761123657229,10000.0,58851.27061104775,63754.40138721466,58851.27061104775,4889.977033615112,6.026872634887695,0.0 -127600,1.8253182,3.081821,,,,,,,,,,,,,, -127700,1.7632213,4.230262,,,,,,,,,,,,,, -127800,2.0031652,4.3287477,,,,,,,,,,,,,, -127900,1.7587465,4.202478,,,,,,,,,,,,,, -128000,1.8514895,3.0726128,,,,,,,,,,,,,, -128100,1.7680513,4.123122,,,,,,,,,,,,,, -128200,2.0038242,4.271864,,,,,,,,,,,,,, -128300,2.0451362,4.2913723,,,,,,,,,,,,,, -128400,1.7771415,3.148745,,,,,,,,,,,,,, -128484,,,0.8315820097923279,0.8549019694328308,0.7501399517059326,1.189975619316101,50000.0,0.6278000473976135,1.7846035957336426,10000.0,59271.248304367065,64209.13311076164,59271.248304367065,4924.630076885223,6.076805591583252,0.0 -128500,1.6817642,3.2284155,,,,,,,,,,,,,, -128600,1.7610993,3.0945957,,,,,,,,,,,,,, -128700,1.9746851,3.0744498,,,,,,,,,,,,,, -128800,1.8300486,3.0270662,,,,,,,,,,,,,, -128900,1.9637669,3.9531028,,,,,,,,,,,,,, -129000,1.7691659,3.2737253,,,,,,,,,,,,,, -129100,1.6689339,2.9364147,,,,,,,,,,,,,, -129200,1.7847161,3.8051414,,,,,,,,,,,,,, -129300,1.7883364,3.0412693,,,,,,,,,,,,,, -129393,,,0.8374413847923279,0.8287783861160278,0.7497599720954895,1.192121505737305,50000.0,0.6319000124931335,1.7807656526565552,10000.0,59691.69730448723,64667.32375311852,59691.69730448723,4962.272361755371,6.125463962554932,0.0 -129400,1.8561088,4.408467,,,,,,,,,,,,,, -129500,1.7817657,4.194348,,,,,,,,,,,,,, -129600,2.0220475,4.570338,,,,,,,,,,,,,, -129700,1.9587464,3.0923152,,,,,,,,,,,,,, -129800,2.550052,4.692494,,,,,,,,,,,,,, -129900,1.7298945,3.0061066,,,,,,,,,,,,,, -130000,1.7255136,3.3135777,,,,,,,,,,,,,, -130100,1.9975723,3.0620725,,,,,,,,,,,,,, -130200,1.661058,3.0107584,,,,,,,,,,,,,, -130300,1.5532461,3.684693,,,,,,,,,,,,,, -130301,,,0.8481640219688416,0.8119232058525085,0.7521599531173706,1.201475977897644,50000.0,0.6328000426292419,1.7882546186447144,10000.0,60111.69879961014,65121.97418880463,60111.69879961014,4996.821247339249,6.174273729324341,0.0 -130400,1.6109612,3.433545,,,,,,,,,,,,,, -130500,1.8330797,3.0720909,,,,,,,,,,,,,, -130600,1.7723571,2.975987,,,,,,,,,,,,,, -130700,1.7631333,3.253689,,,,,,,,,,,,,, -130800,2.1949778,4.6420555,,,,,,,,,,,,,, -130900,2.081808,4.0633035,,,,,,,,,,,,,, -131000,1.67473,3.2364514,,,,,,,,,,,,,, -131100,1.6555071,3.113465,,,,,,,,,,,,,, -131200,2.0446362,4.2407393,,,,,,,,,,,,,, -131212,,,0.8323046565055847,0.855158269405365,0.750819981098175,1.1929678916931152,50000.0,0.6363000273704529,1.7767421007156372,10000.0,60531.89884757996,65577.1236114502,60531.89884757996,5031.673780918121,6.2202746868133545,0.0 -131300,1.7782062,2.9875593,,,,,,,,,,,,,, -131400,1.8542188,3.4428844,,,,,,,,,,,,,, -131500,1.8368787,2.9960127,,,,,,,,,,,,,, -131600,1.7633047,3.0475748,,,,,,,,,,,,,, -131700,1.6820651,3.0614972,,,,,,,,,,,,,, -131800,1.7688308,3.1767368,,,,,,,,,,,,,, -131900,1.8224701,3.1177206,,,,,,,,,,,,,, -132000,1.881875,3.061416,,,,,,,,,,,,,, -132100,2.199421,4.381676,,,,,,,,,,,,,, -132124,,,0.8398241996765137,0.8325179815292358,0.7538399696350098,1.1827151775360107,50000.0,0.6365000009536743,1.767372488975525,10000.0,60951.83554720879,66033.88747620583,60951.83554720879,5068.405083656311,6.265010595321655,0.0 -132200,1.8608937,3.0307279,,,,,,,,,,,,,, -132300,1.8610202,4.372232,,,,,,,,,,,,,, -132400,1.729989,2.871678,,,,,,,,,,,,,, -132500,1.9596288,4.341441,,,,,,,,,,,,,, -132600,1.8741249,3.0165677,,,,,,,,,,,,,, -132700,1.7275331,3.112442,,,,,,,,,,,,,, -132800,1.8560601,3.0503824,,,,,,,,,,,,,, -132900,1.9420623,3.0707245,,,,,,,,,,,,,, -133000,1.9588801,3.7017326,,,,,,,,,,,,,, -133034,,,0.8498437404632568,0.8029429912567139,0.7543999552726746,1.2014082670211792,50000.0,0.636400043964386,1.7902463674545288,10000.0,61371.80930924416,66489.36799764633,61371.80930924416,5103.814422369003,6.311906576156616,0.0 -133100,1.8468945,3.172113,,,,,,,,,,,,,, -133200,2.193943,4.597551,,,,,,,,,,,,,, -133300,1.8681086,2.9835186,,,,,,,,,,,,,, -133400,1.9310172,4.1395807,,,,,,,,,,,,,, -133500,1.9194841,3.0624704,,,,,,,,,,,,,, -133600,1.8633312,3.02385,,,,,,,,,,,,,, -133700,1.8515239,4.17922,,,,,,,,,,,,,, -133800,1.8604883,2.947484,,,,,,,,,,,,,, -133900,1.9003803,3.0406487,,,,,,,,,,,,,, -133945,,,0.8379882574081421,0.8267109394073486,0.7539199590682983,1.1843620538711548,50000.0,0.6361000537872314,1.7660489082336426,10000.0,61791.847019433975,66944.96662092209,61791.847019433975,5139.275134801865,6.361809492111206,0.0 -134000,1.7615244,3.0827825,,,,,,,,,,,,,, -134100,1.8502672,3.0543387,,,,,,,,,,,,,, -134200,1.9937787,2.9428942,,,,,,,,,,,,,, -134300,1.8598775,3.0427446,,,,,,,,,,,,,, -134400,1.7779399,3.2798567,,,,,,,,,,,,,, -134500,1.7876703,3.0080562,,,,,,,,,,,,,, -134600,2.313984,4.5590577,,,,,,,,,,,,,, -134700,1.8193148,3.158963,,,,,,,,,,,,,, -134800,1.7814873,3.6610954,,,,,,,,,,,,,, -134857,,,0.8421288728713989,0.8246394395828247,0.7559399604797363,1.183579444885254,50000.0,0.6372000575065613,1.7756555080413818,10000.0,62211.88814759255,67399.85404849052,62211.88814759255,5174.021430253983,6.409978866577148,0.0 -134900,1.842498,2.9765694,,,,,,,,,,,,,, -135000,1.7893023,3.1844025,,,,,,,,,,,,,, -135100,2.0342298,4.3212566,,,,,,,,,,,,,, -135200,1.9471186,3.7499413,,,,,,,,,,,,,, -135300,2.2798233,4.505261,,,,,,,,,,,,,, -135400,1.7672031,3.3250742,,,,,,,,,,,,,, -135500,1.8575153,2.9876103,,,,,,,,,,,,,, -135600,1.9758472,2.901029,,,,,,,,,,,,,, -135700,1.9090269,3.0116088,,,,,,,,,,,,,, -135766,,,0.847460925579071,0.8194810152053833,0.7557799816131592,1.199291110038757,50000.0,0.635200023651123,1.788246989250183,10000.0,62631.87271881104,67856.05798530579,62631.87271881104,5210.1441123485565,6.4558424949646,0.0 -135800,1.8460827,3.9084911,,,,,,,,,,,,,, -135900,1.8437523,3.3207078,,,,,,,,,,,,,, -136000,2.0641267,4.007006,,,,,,,,,,,,,, -136100,1.8412012,3.3344617,,,,,,,,,,,,,, -136200,2.0245495,3.025289,,,,,,,,,,,,,, -136300,2.0604455,3.0498881,,,,,,,,,,,,,, -136400,1.9100901,3.038397,,,,,,,,,,,,,, -136500,2.0543153,2.9850552,,,,,,,,,,,,,, -136600,2.0854564,4.173213,,,,,,,,,,,,,, -136677,,,0.8434374928474426,0.8096722364425659,0.75764000415802,1.1676242351531982,50000.0,0.6416000127792358,1.7473360300064087,10000.0,63052.01003956795,68314.43851542473,63052.01003956795,5248.283586978912,6.5085837841033936,0.0 -136700,2.0653088,2.9571486,,,,,,,,,,,,,, -136800,1.978332,3.113801,,,,,,,,,,,,,, -136900,1.9655857,3.2229757,,,,,,,,,,,,,, -137000,1.8098725,2.9892492,,,,,,,,,,,,,, -137100,1.9283804,3.906632,,,,,,,,,,,,,, -137200,2.0054054,2.9546638,,,,,,,,,,,,,, -137300,2.056617,3.746877,,,,,,,,,,,,,, -137400,1.7935863,2.949948,,,,,,,,,,,,,, -137500,1.964451,2.9966207,,,,,,,,,,,,,, -137587,,,0.8507421612739563,0.7920365929603577,0.7602199912071228,1.160401463508606,50000.0,0.6402000188827515,1.75056791305542,10000.0,63472.32473063469,68772.46301627159,63472.32473063469,5285.890654087067,6.560576677322388,0.0 -137600,1.775043,3.0449865,,,,,,,,,,,,,, -137700,1.9665804,4.1788135,,,,,,,,,,,,,, -137800,2.0719013,3.0808144,,,,,,,,,,,,,, -137900,2.0004857,3.8491926,,,,,,,,,,,,,, -138000,1.8808249,3.2543387,,,,,,,,,,,,,, -138100,1.8450216,3.5993316,,,,,,,,,,,,,, -138200,2.3560221,4.3000727,,,,,,,,,,,,,, -138300,1.8998997,2.981722,,,,,,,,,,,,,, -138400,2.3421426,4.349445,,,,,,,,,,,,,, -138498,,,0.8550781011581421,0.7727980613708496,0.7599799633026123,1.1643576622009275,50000.0,0.6422000527381897,1.75530207157135,10000.0,63892.652406454086,69228.47512078285,63892.652406454086,5321.47384428978,6.611251592636108,0.0 -138500,3.2135227,4.52667,,,,,,,,,,,,,, -138600,1.9043753,3.2835789,,,,,,,,,,,,,, -138700,1.955138,3.3842683,,,,,,,,,,,,,, -138800,1.8376358,3.3470273,,,,,,,,,,,,,, -138900,2.0185292,3.052918,,,,,,,,,,,,,, -139000,2.0015519,4.1208615,,,,,,,,,,,,,, -139100,1.9074506,2.9806275,,,,,,,,,,,,,, -139200,1.9188727,3.0441527,,,,,,,,,,,,,, -139300,1.9807521,3.1342838,,,,,,,,,,,,,, -139400,2.3863933,4.3891664,,,,,,,,,,,,,, -139410,,,0.8495312333106995,0.8162821531295776,0.7620399594306946,1.1751127243041992,50000.0,0.65010005235672,1.7389004230499268,10000.0,64312.980843544006,69689.05562400818,64312.980843544006,5361.627897024155,6.65903902053833,0.0 -139500,2.0795138,4.0476503,,,,,,,,,,,,,, -139600,1.9320122,2.8847158,,,,,,,,,,,,,, -139700,1.9913675,3.088325,,,,,,,,,,,,,, -139800,1.9511999,2.9969487,,,,,,,,,,,,,, -139900,1.9922891,2.8459747,,,,,,,,,,,,,, -140000,1.7785468,2.9166465,,,,,,,,,,,,,, -140100,1.8897533,3.1573262,,,,,,,,,,,,,, -140200,1.8808887,3.397255,,,,,,,,,,,,,, -140300,1.9305648,2.977721,,,,,,,,,,,,,, -140320,,,0.8509179353713989,0.797985315322876,0.7607600092887878,1.1660995483398438,50000.0,0.6412000060081482,1.7588180303573608,10000.0,64732.98385691643,70144.70323824883,64732.98385691643,5397.175455093384,6.705945253372192,0.0 -140400,1.8382736,3.2887497,,,,,,,,,,,,,, -140500,1.9024541,3.9555297,,,,,,,,,,,,,, -140600,2.0963402,2.86147,,,,,,,,,,,,,, -140700,2.114427,2.9486856,,,,,,,,,,,,,, -140800,1.8712828,2.9385166,,,,,,,,,,,,,, -140900,2.002651,2.934032,,,,,,,,,,,,,, -141000,2.0246332,3.2113175,,,,,,,,,,,,,, -141100,2.2243567,4.339758,,,,,,,,,,,,,, -141200,2.1178324,3.7931678,,,,,,,,,,,,,, -141230,,,0.8560156226158142,0.7755244374275208,0.7628999948501587,1.1578243970870972,50000.0,0.648900032043457,1.7374924421310425,10000.0,65152.95014190674,70599.94521164894,65152.95014190674,5432.35400223732,6.752318620681763,0.0 -141300,2.2197924,4.241903,,,,,,,,,,,,,, -141400,1.9888006,3.4374335,,,,,,,,,,,,,, -141500,1.8562369,3.234419,,,,,,,,,,,,,, -141600,2.3320248,4.1130443,,,,,,,,,,,,,, -141700,2.434499,4.2041254,,,,,,,,,,,,,, -141800,2.1226542,2.9068205,,,,,,,,,,,,,, -141900,2.4930904,4.5499063,,,,,,,,,,,,,, -142000,2.0544276,3.0060456,,,,,,,,,,,,,, -142100,2.000438,3.6135938,,,,,,,,,,,,,, -142141,,,0.85267573595047,0.7707575559616089,0.7630599737167358,1.145895004272461,50000.0,0.6499000191688538,1.714654803276062,10000.0,65573.02896523476,71058.30751371384,65573.02896523476,5470.536109209061,6.80316686630249,0.0 -142200,1.9656831,2.9612117,,,,,,,,,,,,,, -142300,2.077454,3.7806668,,,,,,,,,,,,,, -142400,2.0182703,3.064938,,,,,,,,,,,,,, -142500,1.9105885,2.8010545,,,,,,,,,,,,,, -142600,1.9953558,2.9023077,,,,,,,,,,,,,, -142700,2.002463,3.2152987,,,,,,,,,,,,,, -142800,1.936676,3.5570474,,,,,,,,,,,,,, -142900,2.1324186,3.004218,,,,,,,,,,,,,, -143000,2.3562083,4.291907,,,,,,,,,,,,,, -143053,,,0.8552343845367432,0.7671908140182495,0.7658599615097046,1.1423100233078003,50000.0,0.6479000449180603,1.7201541662216189,10000.0,65993.15498352051,71517.50735139847,65993.15498352051,5509.510308504105,6.851216554641724,0.0 -143100,2.1874068,2.868282,,,,,,,,,,,,,, -143200,1.8866147,2.9405675,,,,,,,,,,,,,, -143300,2.570569,4.434076,,,,,,,,,,,,,, -143400,2.1523108,3.0371103,,,,,,,,,,,,,, -143500,2.0620382,2.993281,,,,,,,,,,,,,, -143600,2.234574,3.2993708,,,,,,,,,,,,,, -143700,1.9258355,3.6281452,,,,,,,,,,,,,, -143800,2.015136,2.940074,,,,,,,,,,,,,, -143900,2.0522892,3.658233,,,,,,,,,,,,,, -143963,,,0.8605859279632568,0.7456603646278381,0.7659599781036377,1.1388903856277466,50000.0,0.648300051689148,1.7129008769989014,10000.0,66413.09290719032,71975.2932009697,66413.09290719032,5547.255858421326,6.90261173248291,0.0 -144000,2.3536334,2.894527,,,,,,,,,,,,,, -144100,2.0964708,4.3189487,,,,,,,,,,,,,, -144200,2.205914,2.9984684,,,,,,,,,,,,,, -144300,2.0343108,2.8293192,,,,,,,,,,,,,, -144400,2.3431427,2.919578,,,,,,,,,,,,,, -144500,2.25456,2.9892907,,,,,,,,,,,,,, -144600,1.9967266,2.9035697,,,,,,,,,,,,,, -144700,2.0569763,2.871564,,,,,,,,,,,,,, -144800,2.164325,3.2686923,,,,,,,,,,,,,, -144873,,,0.8684179782867432,0.724058210849762,0.7653599977493286,1.1371018886566162,50000.0,0.6476000547409058,1.7198455333709717,10000.0,66833.13764357567,72433.08190202713,66833.13764357567,5584.90097284317,6.950519561767578,0.0 -144900,2.0341635,2.930642,,,,,,,,,,,,,, -145000,2.0845053,3.5298533,,,,,,,,,,,,,, -145100,1.8685719,3.0399961,,,,,,,,,,,,,, -145200,2.070672,2.9610813,,,,,,,,,,,,,, -145300,2.1477516,2.8752654,,,,,,,,,,,,,, -145400,1.9874903,2.9138165,,,,,,,,,,,,,, -145500,1.982653,2.8696742,,,,,,,,,,,,,, -145600,2.6055684,4.278634,,,,,,,,,,,,,, -145700,1.9567246,2.8238432,,,,,,,,,,,,,, -145782,,,0.8575780987739563,0.7682866454124451,0.764519989490509,1.1535091400146484,50000.0,0.6463000178337097,1.7320133447647097,10000.0,67253.06000828743,72893.4165430069,67253.06000828743,5625.206836462021,7.006183862686157,0.0 -145800,2.1484,3.7570586,,,,,,,,,,,,,, -145900,2.558361,4.0186124,,,,,,,,,,,,,, -146000,2.120771,2.9952118,,,,,,,,,,,,,, -146100,2.0757694,2.9174035,,,,,,,,,,,,,, -146200,2.2355533,2.864677,,,,,,,,,,,,,, -146300,2.0086095,3.096239,,,,,,,,,,,,,, -146400,1.8513843,3.3650844,,,,,,,,,,,,,, -146500,2.7367525,4.4576454,,,,,,,,,,,,,, -146600,2.1555758,2.9139385,,,,,,,,,,,,,, -146692,,,0.8640234470367432,0.74609375,0.7672399878501892,1.1380559206008911,50000.0,0.6510000228881836,1.7285139560699463,10000.0,67673.04480743408,73350.82883524895,67673.04480743408,5662.5311868190765,7.058150768280029,0.0 -146700,2.0196128,2.9331083,,,,,,,,,,,,,, -146800,2.9104257,4.4575586,,,,,,,,,,,,,, -146900,2.0639045,3.1180887,,,,,,,,,,,,,, -147000,2.1031458,2.8801718,,,,,,,,,,,,,, -147100,2.1043513,2.9560542,,,,,,,,,,,,,, -147200,2.2356215,3.8986135,,,,,,,,,,,,,, -147300,2.0777936,2.8670642,,,,,,,,,,,,,, -147400,2.0821218,3.7857602,,,,,,,,,,,,,, -147500,1.9840493,3.4809616,,,,,,,,,,,,,, -147600,2.1919549,2.8480744,,,,,,,,,,,,,, -147604,,,0.8709960579872131,0.6939271092414856,0.7691799998283386,1.1134551763534546,50000.0,0.6548000574111938,1.6992754936218262,10000.0,68093.3377597332,73807.02789735794,68093.3377597332,5698.334320783615,7.110531806945801,0.0 -147700,2.0981596,3.116476,,,,,,,,,,,,,, -147800,2.1742027,2.9216626,,,,,,,,,,,,,, -147900,2.1287613,2.9106452,,,,,,,,,,,,,, -148000,2.201289,2.922781,,,,,,,,,,,,,, -148100,2.0696688,3.0961897,,,,,,,,,,,,,, -148200,2.2124598,2.958764,,,,,,,,,,,,,, -148300,2.0856957,3.092037,,,,,,,,,,,,,, -148400,2.0994349,3.7663932,,,,,,,,,,,,,, -148500,2.267539,4.0448856,,,,,,,,,,,,,, -148514,,,0.8633202910423279,0.7357399463653564,0.7691999673843384,1.1223732233047483,50000.0,0.6541000604629517,1.6979526281356812,10000.0,68513.33066034317,74264.12940835953,68513.33066034317,5735.344721794128,7.158812522888184,0.0 -148600,2.114255,2.865368,,,,,,,,,,,,,, -148700,2.0026715,2.8082976,,,,,,,,,,,,,, -148800,2.5842733,4.3616333,,,,,,,,,,,,,, -148900,2.2620354,3.0276294,,,,,,,,,,,,,, -149000,2.0636327,2.8034503,,,,,,,,,,,,,, -149100,2.2040036,2.8711426,,,,,,,,,,,,,, -149200,2.311219,2.9025369,,,,,,,,,,,,,, -149300,2.4403055,3.2147336,,,,,,,,,,,,,, -149400,3.0239148,4.3701067,,,,,,,,,,,,,, -149422,,,0.8650000095367432,0.7219473123550415,0.7705199718475342,1.1146769523620603,50000.0,0.65420001745224,1.6978703737258911,10000.0,68933.41969394684,74722.75263619423,68933.41969394684,5773.777179718018,7.209390163421631,0.0 -149500,2.0970042,2.7717085,,,,,,,,,,,,,, -149600,2.0282023,3.4291716,,,,,,,,,,,,,, -149700,2.226192,3.6600668,,,,,,,,,,,,,, -149800,2.2275224,2.9760013,,,,,,,,,,,,,, -149900,2.1264758,2.9277172,,,,,,,,,,,,,, -150000,2.2343912,2.9574103,,,,,,,,,,,,,, -150100,2.1597288,3.6245387,,,,,,,,,,,,,, -150200,2.2171798,2.8992338,,,,,,,,,,,,,, -150300,2.408702,4.1405873,,,,,,,,,,,,,, -150330,,,0.8721874952316284,0.741978645324707,0.7699599862098694,1.1527972221374512,50000.0,0.652400016784668,1.7470016479492188,10000.0,69353.35374689102,75181.42854118347,69353.35374689102,5812.420975446701,7.256996393203735,0.0 -150400,2.1240563,3.8088477,,,,,,,,,,,,,, -150500,2.0078504,2.8943906,,,,,,,,,,,,,, -150600,2.1041303,2.9726408,,,,,,,,,,,,,, -150700,2.7705798,4.24671,,,,,,,,,,,,,, -150800,2.080184,3.4036214,,,,,,,,,,,,,, -150900,2.0176456,2.815829,,,,,,,,,,,,,, -151000,2.230612,2.7590528,,,,,,,,,,,,,, -151100,1.9869938,2.9364512,,,,,,,,,,,,,, -151200,2.1055849,2.8345807,,,,,,,,,,,,,, -151238,,,0.8633593320846558,0.7349731922149658,0.7711600065231323,1.12005615234375,50000.0,0.6549000144004822,1.6994355916976929,10000.0,69773.56124687195,75640.54940104485,69773.56124687195,5851.233624219894,7.306897401809692,0.0 -151300,2.6057067,4.2006598,,,,,,,,,,,,,, -151400,2.2750537,3.9652243,,,,,,,,,,,,,, -151500,2.233904,3.5447237,,,,,,,,,,,,,, -151600,2.3479118,2.8792944,,,,,,,,,,,,,, -151700,2.152716,2.9123864,,,,,,,,,,,,,, -151800,2.0386865,3.21441,,,,,,,,,,,,,, -151900,2.0667107,2.8185763,,,,,,,,,,,,,, -152000,2.4441714,2.943715,,,,,,,,,,,,,, -152100,2.303912,2.9129071,,,,,,,,,,,,,, -152146,,,0.8692968487739563,0.7008588314056396,0.7726799845695496,1.1004048585891724,50000.0,0.6581000089645386,1.682053804397583,10000.0,70193.82568836212,76096.07240009308,70193.82568836212,5886.390437841415,7.358208894729614,0.0 -152200,2.2109423,2.8922005,,,,,,,,,,,,,, -152300,2.4091504,3.8469038,,,,,,,,,,,,,, -152400,2.1562173,2.9663994,,,,,,,,,,,,,, -152500,2.3132842,3.033348,,,,,,,,,,,,,, -152600,2.2882662,3.1125257,,,,,,,,,,,,,, -152700,2.2454593,2.8270884,,,,,,,,,,,,,, -152800,2.3555417,2.8636475,,,,,,,,,,,,,, -152900,2.3877823,4.096744,,,,,,,,,,,,,, -153000,2.348904,3.682967,,,,,,,,,,,,,, -153053,,,0.8724218606948853,0.7150508761405945,0.7715199589729309,1.1307733058929443,50000.0,0.6557000279426575,1.7197359800338743,10000.0,70613.74158143997,76553.8149971962,70613.74158143997,5924.112850666046,7.41174578666687,0.0 -153100,2.4631004,3.9935765,,,,,,,,,,,,,, -153200,2.4535952,2.9401639,,,,,,,,,,,,,, -153300,2.0564332,3.2324193,,,,,,,,,,,,,, -153400,2.2662377,2.8453016,,,,,,,,,,,,,, -153500,2.225906,3.564443,,,,,,,,,,,,,, -153600,2.2297459,2.9042983,,,,,,,,,,,,,, -153700,2.2425282,2.928686,,,,,,,,,,,,,, -153800,2.1461291,3.434473,,,,,,,,,,,,,, -153900,2.248521,2.8068259,,,,,,,,,,,,,, -153959,,,0.8653124570846558,0.7207576632499695,0.7726999521255493,1.1108520030975342,50000.0,0.6573000550270081,1.6898996829986572,10000.0,71033.70966124535,77009.18772745132,71033.70966124535,5959.418373346329,7.46028733253479,0.0 -154000,2.2873034,2.8612256,,,,,,,,,,,,,, -154100,2.996412,4.3292513,,,,,,,,,,,,,, -154200,2.348401,2.9359446,,,,,,,,,,,,,, -154300,2.3726847,3.0819888,,,,,,,,,,,,,, -154400,2.3726277,2.8009062,,,,,,,,,,,,,, -154500,2.308328,3.322178,,,,,,,,,,,,,, -154600,2.1505876,2.8010106,,,,,,,,,,,,,, -154700,2.3289604,2.8416727,,,,,,,,,,,,,, -154800,2.2814918,2.9967372,,,,,,,,,,,,,, -154870,,,0.8743554353713989,0.6960504651069641,0.7752000093460083,1.1026690006256104,50000.0,0.6573000550270081,1.684720754623413,10000.0,71454.04873561859,77463.63437724113,71454.04873561859,5993.421750068665,7.513123273849487,0.0 -154900,2.103045,2.8724532,,,,,,,,,,,,,, -155000,2.3345687,3.258585,,,,,,,,,,,,,, -155100,2.215027,2.9714289,,,,,,,,,,,,,, -155200,2.2572207,2.8373995,,,,,,,,,,,,,, -155300,2.2493603,3.1951797,,,,,,,,,,,,,, -155400,2.402625,2.872893,,,,,,,,,,,,,, -155500,3.0651634,4.338187,,,,,,,,,,,,,, -155600,2.0862367,2.8728046,,,,,,,,,,,,,, -155700,2.3206074,3.3491392,,,,,,,,,,,,,, -155782,,,0.872851550579071,0.6959668397903442,0.7761799693107605,1.1069529056549072,50000.0,0.6583000421524048,1.697216510772705,10000.0,71874.4541144371,77920.41535067558,71874.4541144371,6029.695031404495,7.564982175827026,0.0 -155800,2.3784451,2.8746042,,,,,,,,,,,,,, -155900,2.303643,3.1533897,,,,,,,,,,,,,, -156000,2.5839422,2.908265,,,,,,,,,,,,,, -156100,2.4550042,2.9421012,,,,,,,,,,,,,, -156200,2.3279135,2.852824,,,,,,,,,,,,,, -156300,2.36565,3.9506621,,,,,,,,,,,,,, -156400,3.1951127,4.413699,,,,,,,,,,,,,, -156500,2.2056541,2.868913,,,,,,,,,,,,,, -156600,2.4440114,2.9418418,,,,,,,,,,,,,, -156692,,,0.8742382526397705,0.7012655735015869,0.7750200033187866,1.1078037023544312,50000.0,0.6554000377655029,1.695989727973938,10000.0,72294.57926630974,78375.58976507187,72294.57926630974,6064.638352155685,7.620237112045288,0.0 -156700,2.3258214,2.859114,,,,,,,,,,,,,, -156800,2.3698397,2.805331,,,,,,,,,,,,,, -156900,2.2765205,2.9034224,,,,,,,,,,,,,, -157000,2.461865,2.7755847,,,,,,,,,,,,,, -157100,2.801206,2.8500104,,,,,,,,,,,,,, -157200,2.3683214,3.7007196,,,,,,,,,,,,,, -157300,2.2567613,2.8640642,,,,,,,,,,,,,, -157400,2.3604572,2.8507316,,,,,,,,,,,,,, -157500,2.3241692,2.7646775,,,,,,,,,,,,,, -157600,2.3501382,2.8416748,,,,,,,,,,,,,, -157601,,,0.8741992115974426,0.7010714411735535,0.776919960975647,1.1080583333969116,50000.0,0.6605000495910645,1.682713747024536,10000.0,72714.48941850662,78831.00256085396,72714.48941850662,6100.034034729004,7.676239013671875,0.0 -157700,2.344385,2.819013,,,,,,,,,,,,,, -157800,2.2376914,2.9719272,,,,,,,,,,,,,, -157900,2.3989506,2.8703349,,,,,,,,,,,,,, -158000,2.2486417,2.8444242,,,,,,,,,,,,,, -158100,2.356475,2.7989285,,,,,,,,,,,,,, -158200,6.249833,3.837706,,,,,,,,,,,,,, -158300,2.5570822,2.947917,,,,,,,,,,,,,, -158400,2.3291514,2.8989825,,,,,,,,,,,,,, -158500,2.3483827,2.9994743,,,,,,,,,,,,,, -158513,,,0.8779296875,0.6749905943870544,0.7773799896240234,1.0918878316879272,50000.0,0.6610000133514404,1.6656324863433838,10000.0,73134.6789803505,79286.81630277634,73134.6789803505,6135.559215307236,7.725278377532959,0.0 -158600,2.7713451,4.174698,,,,,,,,,,,,,, -158700,2.3121257,3.5740168,,,,,,,,,,,,,, -158800,2.733803,4.08171,,,,,,,,,,,,,, -158900,2.2557492,3.3972046,,,,,,,,,,,,,, -159000,2.3858883,2.834886,,,,,,,,,,,,,, -159100,2.170416,3.1181614,,,,,,,,,,,,,, -159200,2.4649544,3.4414825,,,,,,,,,,,,,, -159300,2.4456782,2.9056828,,,,,,,,,,,,,, -159400,2.3584232,2.8932493,,,,,,,,,,,,,, -159424,,,0.8819335699081421,0.6788507103919983,0.7778399586677551,1.1027580499649048,50000.0,0.6589000225067139,1.6929190158843994,10000.0,73554.84628415108,79743.22312545776,73554.84628415108,6171.688840389252,7.782891035079956,0.0 -159500,2.4132862,2.7889202,,,,,,,,,,,,,, -159600,2.3587418,3.8222127,,,,,,,,,,,,,, -159700,2.4933238,3.8593206,,,,,,,,,,,,,, -159800,2.3137784,2.8121219,,,,,,,,,,,,,, -159900,2.3718765,3.5954218,,,,,,,,,,,,,, -160000,2.6968186,4.158178,,,,,,,,,,,,,, -160100,2.4864058,3.762695,,,,,,,,,,,,,, -160200,2.7940855,3.8792892,,,,,,,,,,,,,, -160300,3.0521026,4.2296076,,,,,,,,,,,,,, -160336,,,0.8783202767372131,0.6747952103614807,0.7796799540519714,1.08854079246521,50000.0,0.6620000600814819,1.66422700881958,10000.0,73974.78279566765,80200.6146595478,73974.78279566765,6209.042106866837,7.834174156188965,0.0 -160400,2.380001,2.8981197,,,,,,,,,,,,,, -160500,2.2920685,2.699819,,,,,,,,,,,,,, -160600,2.7686257,4.135027,,,,,,,,,,,,,, -160700,2.4115725,2.8316,,,,,,,,,,,,,, -160800,2.6154926,3.9134705,,,,,,,,,,,,,, -160900,2.5200307,2.9857242,,,,,,,,,,,,,, -161000,2.306012,3.082279,,,,,,,,,,,,,, -161100,2.193573,2.745028,,,,,,,,,,,,,, -161200,2.4122224,2.8632324,,,,,,,,,,,,,, -161245,,,0.8847070336341858,0.6641880869865417,0.7793799638748169,1.094588279724121,50000.0,0.6619000434875488,1.676234245300293,10000.0,74395.06458425522,80656.78582334518,74395.06458425522,6244.829773902893,7.885188102722168,0.0 -161300,2.609011,2.8416681,,,,,,,,,,,,,, -161400,2.209622,3.221414,,,,,,,,,,,,,, -161500,2.3281395,3.5959053,,,,,,,,,,,,,, -161600,2.555103,2.818286,,,,,,,,,,,,,, -161700,2.5325503,2.8217258,,,,,,,,,,,,,, -161800,2.3935335,2.756759,,,,,,,,,,,,,, -161900,3.1363385,4.138384,,,,,,,,,,,,,, -162000,2.5978217,2.8492546,,,,,,,,,,,,,, -162100,3.396204,4.3000016,,,,,,,,,,,,,, -162155,,,0.8860546946525574,0.6558742523193359,0.7773999571800232,1.094112992286682,50000.0,0.6637000441551208,1.6746059656143188,10000.0,74815.33997106552,81112.87297201157,74815.33997106552,6280.536881446838,7.93939733505249,0.0 -162200,2.6536655,2.886056,,,,,,,,,,,,,, -162300,2.500281,2.791034,,,,,,,,,,,,,, -162400,2.3703878,2.745278,,,,,,,,,,,,,, -162500,2.5085983,2.841587,,,,,,,,,,,,,, -162600,2.3066247,2.7705986,,,,,,,,,,,,,, -162700,2.431174,2.7208734,,,,,,,,,,,,,, -162800,2.7202046,4.105852,,,,,,,,,,,,,, -162900,2.8223596,3.8354044,,,,,,,,,,,,,, -163000,2.379993,2.844086,,,,,,,,,,,,,, -163065,,,0.8809179663658142,0.6672109961509705,0.7792999744415283,1.0870453119277954,50000.0,0.6628000140190125,1.6721868515014648,10000.0,75235.40973472595,81568.50692629814,75235.40973472595,6315.99707698822,7.991906642913818,0.0 -163100,2.5678031,3.4044614,,,,,,,,,,,,,, -163200,2.5595238,2.763332,,,,,,,,,,,,,, -163300,2.570089,2.8930671,,,,,,,,,,,,,, -163400,2.2900288,2.8516774,,,,,,,,,,,,,, -163500,2.527768,3.7343874,,,,,,,,,,,,,, -163600,3.1164448,4.061781,,,,,,,,,,,,,, -163700,2.3648562,3.4290252,,,,,,,,,,,,,, -163800,2.6082103,2.8438926,,,,,,,,,,,,,, -163900,2.5164642,2.84584,,,,,,,,,,,,,, -163977,,,0.8843554258346558,0.6567904353141785,0.7809799909591675,1.085119605064392,50000.0,0.6634000539779663,1.6664263010025024,10000.0,75655.63436961174,82025.42798662186,75655.63436961174,6352.588169336319,8.046016693115234,0.0 -164000,2.626476,3.737489,,,,,,,,,,,,,, -164100,3.275879,4.344206,,,,,,,,,,,,,, -164200,3.5007482,4.2510743,,,,,,,,,,,,,, -164300,2.776705,4.015199,,,,,,,,,,,,,, -164400,2.6753457,2.8783963,,,,,,,,,,,,,, -164500,2.3031297,2.7583232,,,,,,,,,,,,,, -164600,2.580354,2.8079677,,,,,,,,,,,,,, -164700,2.931025,3.8675458,,,,,,,,,,,,,, -164800,2.4108858,3.1572487,,,,,,,,,,,,,, -164889,,,0.8860155940055847,0.6490659117698669,0.7811999917030334,1.085237979888916,50000.0,0.6682000160217285,1.664968967437744,10000.0,76075.67520737648,82483.71307039261,76075.67520737648,6390.728995323181,8.098177194595337,0.0 -164900,2.32671,2.9471703,,,,,,,,,,,,,, -165000,2.848328,2.8280864,,,,,,,,,,,,,, -165100,2.6519785,3.761668,,,,,,,,,,,,,, -165200,2.368853,2.9139497,,,,,,,,,,,,,, -165300,2.445151,2.9946303,,,,,,,,,,,,,, -165400,2.2125895,2.83423,,,,,,,,,,,,,, -165500,2.492956,3.5574148,,,,,,,,,,,,,, -165600,3.347341,4.24133,,,,,,,,,,,,,, -165700,2.5413141,2.8340077,,,,,,,,,,,,,, -165800,,,0.8830273151397705,0.6618872880935669,0.7819199562072754,1.0855050086975098,50000.0,0.6676000356674194,1.659148931503296,10000.0,76495.9494357109,82943.06176996231,76495.9494357109,6429.695057630539,8.155537843704224,0.0 -165800,2.3996265,2.7934594,,,,,,,,,,,,,, -165900,2.4421701,3.1134968,,,,,,,,,,,,,, -166000,2.477282,2.926253,,,,,,,,,,,,,, -166100,2.552682,3.1593497,,,,,,,,,,,,,, -166200,2.5290587,2.966774,,,,,,,,,,,,,, -166300,2.4406042,2.9179235,,,,,,,,,,,,,, -166400,2.5675945,3.1614568,,,,,,,,,,,,,, -166500,2.7827535,2.7531314,,,,,,,,,,,,,, -166600,2.578835,3.694612,,,,,,,,,,,,,, -166700,2.5761487,3.0320048,,,,,,,,,,,,,, -166711,,,0.8886913657188416,0.645067036151886,0.7826399803161621,1.0786688327789309,50000.0,0.6674000024795532,1.6567219495773315,10000.0,76916.21601438522,83398.33122730255,76916.21601438522,6464.591763496399,8.211509466171265,0.0 -166800,2.521521,2.840773,,,,,,,,,,,,,, -166900,3.1130347,3.9973686,,,,,,,,,,,,,, -167000,2.9646134,3.9172068,,,,,,,,,,,,,, -167100,2.6200054,2.8452241,,,,,,,,,,,,,, -167200,2.4702673,3.201235,,,,,,,,,,,,,, -167300,2.551314,2.9166014,,,,,,,,,,,,,, -167400,2.8536022,3.8739684,,,,,,,,,,,,,, -167500,2.6823373,2.753675,,,,,,,,,,,,,, -167600,2.5442946,2.774345,,,,,,,,,,,,,, -167620,,,0.88818359375,0.6586396098136902,0.7811399698257446,1.0905895233154297,50000.0,0.6629000306129456,1.6669039726257324,10000.0,77336.24432229996,83855.08708715439,77336.24432229996,6501.212018728256,8.268371820449829,0.0 -167700,2.658464,3.9387946,,,,,,,,,,,,,, -167800,2.5290515,3.7527976,,,,,,,,,,,,,, -167900,2.4550323,3.397627,,,,,,,,,,,,,, -168000,2.8336961,2.7597208,,,,,,,,,,,,,, -168028,,,,,,,,,,,77520.28070497513,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 47358f594..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -26.919355869293213,0.0,35.492053747177124,1,0,35.492053747177124,0.0010000000474974,6.907756805419922,10000,62.411550998687744,0.0008398437057621,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -63.456319093704224,0.0174732208251953,455.7455909252167,857,0,455.7455909252167,0.0090000005438923,6.5088372230529785,10000,519.2673494815826,0.0123046869412064,6.458219051361084,0.0127399992197752,6.469983100891113,50000 -102.13130497932434,0.0477561950683593,875.765299320221,1766,0,875.765299320221,0.0318999998271465,5.995128154754639,10000,978.044549703598,0.036210935562849,5.856694221496582,0.0349400006234645,5.884265422821045,50000 -140.7256236076355,0.073580265045166,1296.171561002731,2677,0,1296.171561002731,0.0478000007569789,5.662439346313477,10000,1437.1218678951263,0.0633203089237213,5.436893939971924,0.0602799989283084,5.480632305145264,50000 -179.481369972229,0.0996489524841308,1716.1118450164795,3587,0,1716.1118450164795,0.068400003015995,5.372107982635498,10000,1895.894778966904,0.0926367193460464,5.093984603881836,0.0870999991893768,5.130605220794678,50000 -218.3541798591613,0.1258656978607177,2136.0499653816223,4496,0,2136.0499653816223,0.0940000042319297,5.104016304016113,10000,2354.783221721649,0.1288671791553497,4.745182514190674,0.120679996907711,4.8043293952941895,50000 -257.7483706474304,0.1558928489685058,2555.98603963852,5405,0,2555.98603963852,0.1173000037670135,4.828131675720215,10000,2814.194491147995,0.1678320318460464,4.406538486480713,0.1526799947023391,4.486951351165772,50000 -291.5754690170288,0.181589126586914,2976.0876848697662,6314,0,2976.0876848697662,0.1439000070095062,4.525965213775635,10000,3268.1993370056152,0.2112695276737213,4.0574235916137695,0.1956399977207183,4.138259410858154,50000 -329.72999024391174,0.2092578411102295,3396.280373096466,7227,0,3396.280373096466,0.1714000105857849,4.288461685180664,10000,3726.625602722168,0.2490820288658142,3.757694959640503,0.23157998919487,3.856245756149292,50000 -366.4525589942932,0.2342324256896972,3816.497978210449,8138,0,3816.497978210449,0.2063000053167343,4.033924579620361,10000,4183.641686677933,0.2967578172683716,3.435223340988159,0.2703399956226349,3.561870574951172,50000 -405.5273621082306,0.2614433765411377,4236.516438961029,9048,0,4236.516438961029,0.2297000139951706,3.847393989562988,10000,4642.813379764557,0.3414452970027923,3.161966562271118,0.3065399825572967,3.330437660217285,50000 -441.491233587265,0.2879626750946045,4656.703306913376,9958,0,4656.703306913376,0.2514000236988067,3.730279922485352,10000,5099.041333913803,0.3555664122104645,3.039619445800781,0.3312200009822845,3.173012971878052,50000 -477.1607825756073,0.3193974494934082,5076.962865829468,10870,0,5076.962865829468,0.2725000083446502,3.626384973526001,10000,5555.052897453308,0.3751562535762787,2.954991340637207,0.3479799926280975,3.092005491256714,50000 -516.0527064800262,0.3488492965698242,5497.338998794556,11782,0,5497.338998794556,0.2871000170707702,3.488808870315552,10000,6014.403256416321,0.4147265553474426,2.723894596099853,0.371099978685379,2.9334120750427246,50000 -551.9732129573822,0.3765082359313965,5917.761694431305,12691,0,5917.761694431305,0.3032000064849853,3.3758649826049805,10000,6470.824687480927,0.4253320097923279,2.623037099838257,0.3962000012397766,2.782031536102295,50000 -587.4360723495483,0.4064729213714599,6337.694573402405,13601,0,6337.694573402405,0.3160000145435333,3.312729358673096,10000,6926.301305532455,0.4386132657527923,2.5734517574310303,0.4082199931144714,2.7287282943725586,50000 -622.858304977417,0.4353101253509521,6757.871794700623,14512,0,6757.871794700623,0.3279000222682953,3.24593186378479,10000,7381.980210542679,0.46240234375,2.450325965881348,0.4207199811935425,2.656090021133423,50000 -660.8948771953583,0.4631009101867676,7177.840661525726,15422,0,7177.840661525726,0.32710000872612,3.216073989868164,10000,7840.064656257629,0.4571093618869781,2.456397771835327,0.4252599775791168,2.6171255111694336,50000 -700.3218250274658,0.4916057586669922,7597.769626379013,16333,0,7597.769626379013,0.3414000272750854,3.161466598510742,10000,8299.500366926193,0.4747656285762787,2.3626480102539062,0.4391599893569946,2.550366163253784,50000 -739.2799081802368,0.520301103591919,8018.110145807266,17241,0,8018.110145807266,0.344400018453598,3.137245178222656,10000,8758.87922167778,0.4889843761920929,2.295726776123047,0.4437599778175354,2.507998466491699,50000 -781.0245745182037,0.5498020648956299,8438.437040090561,18151,0,8438.437040090561,0.3503000140190124,3.0864689350128174,10000,9221.031891822817,0.4826171696186065,2.3039026260375977,0.4477799832820892,2.470780372619629,50000 -821.6779868602753,0.579658031463623,8858.829084157944,19061,0,8858.829084157944,0.3585000038146972,3.022629976272583,10000,9682.15875673294,0.4994726479053497,2.22053599357605,0.462039977312088,2.404991149902344,50000 -860.1880419254303,0.6096100807189941,9278.822909593582,19970,0,9278.822909593582,0.3583000302314758,3.0018560886383057,10000,10140.743117809296,0.5083202719688416,2.1640000343322754,0.4681599736213684,2.3670108318328857,50000 -895.5525405406952,0.6403217315673828,9699.075863838196,20882,0,9699.075863838196,0.3707000315189361,2.948025941848755,10000,10596.444150209429,0.5099218487739563,2.1630897521972656,0.4744599759578705,2.34031629562378,50000 -932.010551214218,0.6730847358703613,10119.355597019196,21793,0,10119.355597019196,0.3747000098228454,2.934434413909912,10000,11053.265462636948,0.5142382979393005,2.1343255043029785,0.4792400002479553,2.311788320541382,50000 -967.9775204658508,0.7074997425079346,10539.649500846865,22703,0,10539.649500846865,0.3823000192642212,2.8995566368103027,10000,11509.611449241638,0.5282812118530273,2.0714240074157715,0.4855599999427795,2.277931928634644,50000 -1007.3027441501616,0.7358355522155762,10959.790457248688,23613,0,10959.790457248688,0.3823000192642212,2.859225273132324,10000,11969.157633304596,0.5564648509025574,1.9199002981185915,0.4958999752998352,2.209649085998535,50000 -1046.9547855854034,0.7669098377227783,11379.97221159935,24523,0,11379.97221159935,0.3924000263214111,2.822327852249145,10000,12429.073031425476,0.5373241901397705,2.020864248275757,0.5016999840736389,2.1879594326019287,50000 -1082.676115512848,0.7968857288360596,11800.332841157911,25434,0,11800.332841157911,0.3924000263214111,2.835370779037476,10000,12885.237624645231,0.5425586104393005,1.988130569458008,0.5017399787902832,2.190019130706787,50000 -1118.181747674942,0.8260171413421631,12220.630915403366,26343,0,12220.630915403366,0.4079000055789947,2.7394766807556152,10000,13341.120736837389,0.573046863079071,1.8160487413406368,0.5158599615097046,2.1130800247192383,50000 -1154.6123263835907,0.8571968078613281,12640.863479614258,27254,0,12640.863479614258,0.4070000052452087,2.7597310543060303,10000,13797.866409778597,0.5526952743530273,1.9304035902023315,0.5169399976730347,2.108940362930298,50000 -1188.6037278175354,0.8905489444732666,13061.040511369703,28164,0,13061.040511369703,0.4034000337123871,2.77888560295105,10000,14252.11918401718,0.5553905963897705,1.9481656551361084,0.51528000831604,2.139265298843384,50000 -1223.0425362586975,0.9243690967559814,13481.007174015043,29074,0,13481.007174015043,0.4146000146865845,2.690381050109864,10000,14706.609322786331,0.576171875,1.81451153755188,0.5225200057029724,2.0690627098083496,50000 -1256.6402442455292,0.9586703777313232,13901.04438996315,29985,0,13901.04438996315,0.4086000323295593,2.7167224884033203,10000,15160.329586267471,0.5616992115974426,1.872113943099976,0.5232599973678589,2.0561981201171875,50000 -1295.0405583381653,0.9944887161254884,14321.219009399414,30896,0,14321.219009399414,0.4124000072479248,2.723940849304199,10000,15618.991045713425,0.5642382502555847,1.8705335855484009,0.519819974899292,2.07366156578064,50000 -1329.095343351364,1.027569055557251,14741.544480085371,31807,0,14741.544480085371,0.4239000082015991,2.674529790878296,10000,16073.455638170242,0.5814062356948853,1.8077421188354488,0.5366799831390381,2.019350290298462,50000 -1368.20298743248,1.0579285621643066,15161.686965703964,32718,0,15161.686965703964,0.4264000058174133,2.6434519290924072,10000,16532.786767959595,0.5716406106948853,1.8151863813400269,0.5365599989891052,1.983070731163025,50000 -1405.1028938293457,1.087789535522461,15581.740196228027,33628,0,15581.740196228027,0.4292000234127044,2.636946439743042,10000,16989.820979833603,0.5784569978713989,1.8230026960372925,0.5387799739837646,2.005711555480957,50000 -1439.4901642799375,1.1178858280181885,16001.98867702484,34539,0,16001.98867702484,0.425100028514862,2.602484941482544,10000,17444.537808418274,0.5900781154632568,1.7298932075500488,0.5426799654960632,1.94635820388794,50000 -1475.937474489212,1.149817943572998,16422.16807770729,35449,0,16422.16807770729,0.433100014925003,2.59553599357605,10000,17901.24688887596,0.5878710746765137,1.7523179054260254,0.5478000044822693,1.945866227149964,50000 -1514.2002153396606,1.1828999519348145,16842.540033340454,36359,0,16842.540033340454,0.4359000325202942,2.588318109512329,10000,18359.965654611588,0.5913280844688416,1.7497471570968628,0.5470399856567383,1.9463144540786743,50000 -1553.5912311077118,1.2173235416412354,17262.59283065796,37270,0,17262.59283065796,0.4330000281333923,2.5716423988342285,10000,18819.49461555481,0.5938476324081421,1.7227494716644287,0.5518400073051453,1.9233359098434448,50000 -1587.3950009346008,1.255021095275879,17682.696738004684,38182,0,17682.696738004684,0.4394000172615051,2.5792202949523926,10000,19273.49109721184,0.6218359470367432,1.614800214767456,0.5530799627304077,1.933805584907532,50000 -1622.6679162979126,1.2887091636657717,18102.84414291382,39093,0,18102.84414291382,0.4311000108718872,2.6091208457946777,10000,19728.99629807472,0.5911718606948853,1.776065468788147,0.5468999743461609,1.9785503149032595,50000 -1658.0126931667328,1.3234169483184814,18522.978043079376,40001,0,18522.978043079376,0.4356000125408172,2.5480105876922607,10000,20184.561008930206,0.6050390601158142,1.6630072593688965,0.5593199729919434,1.883844017982483,50000 -1695.5409202575684,1.355790376663208,18943.06448030472,40914,0,18943.06448030472,0.4389000236988067,2.5651254653930664,10000,20642.2593035698,0.6123827695846558,1.6324337720870972,0.5535199642181396,1.9175604581832888,50000 -1732.9484958648682,1.3894102573394775,19363.120608329773,41821,0,19363.120608329773,0.4374000132083893,2.5530781745910645,10000,21099.80816817284,0.5993554592132568,1.6984341144561768,0.5589599609375,1.893854022026062,50000 -1769.8629813194275,1.42185640335083,19783.279280662537,42732,0,19783.279280662537,0.4466000199317932,2.5125553607940674,10000,21556.965598344803,0.6033788919448853,1.66139018535614,0.5579400062561035,1.870342493057251,50000 -1804.986863613129,1.4544131755828855,20203.238243579865,43642,0,20203.238243579865,0.4421000182628631,2.542107105255127,10000,22012.131967306137,0.6136132478713989,1.6449350118637085,0.5623199939727783,1.889671802520752,50000 -1839.381169557572,1.4913108348846436,20623.36201238632,44550,0,20623.36201238632,0.4440000355243683,2.530787944793701,10000,22466.73867583275,0.6031249761581421,1.6811929941177368,0.5647599697113037,1.867743134498596,50000 -1873.501362800598,1.5242390632629397,21043.66229391098,45461,0,21043.66229391098,0.4482000172138214,2.4963107109069824,10000,22921.24355506897,0.6108788847923279,1.6417521238327026,0.567579984664917,1.847375750541687,50000 -1912.0473115444183,1.5611541271209717,21463.67381906509,46372,0,21463.67381906509,0.4461000263690948,2.515526056289673,10000,23379.889157772064,0.6156249642372131,1.6239218711853027,0.5699399709701538,1.8485853672027588,50000 -1946.0973546504968,1.593156099319458,21883.85037612915,47284,0,21883.85037612915,0.4461000263690948,2.4953420162200928,10000,23834.19975042343,0.6090039014816284,1.6575889587402344,0.5688999891281128,1.851066827774048,50000 -1982.7054772377007,1.629871845245361,22304.195620536804,48197,0,22304.195620536804,0.4580000340938568,2.478607177734375,10000,24291.240940332413,0.611621081829071,1.6552009582519531,0.5723199844360352,1.836769938468933,50000 -2023.9590499401093,1.666623592376709,22724.1991057396,49108,0,22724.1991057396,0.4492000341415405,2.4828150272369385,10000,24752.58571910858,0.6163281202316284,1.6063055992126465,0.5687999725341797,1.8405550718307493,50000 -2059.097989797592,1.7026152610778809,23144.24242377281,50019,0,23144.24242377281,0.4542000293731689,2.4651927947998047,10000,25207.855845689774,0.6085546612739563,1.6156573295593262,0.5709399580955505,1.8096331357955933,50000 -2092.429440975189,1.737779140472412,23564.4644010067,50931,0,23564.4644010067,0.4537000358104706,2.484137296676636,10000,25661.495080709457,0.6145898103713989,1.6268950700759888,0.5734800100326538,1.8265553712844849,50000 -2132.76446890831,1.7713980674743652,23984.826218366623,51842,0,23984.826218366623,0.4580000340938568,2.465365409851074,10000,26122.27624464035,0.619140625,1.61297345161438,0.5745399594306946,1.8239986896514893,50000 -2173.295962333679,1.8063812255859373,24404.973826408383,52752,0,24404.973826408383,0.4611000120639801,2.4369454383850098,10000,26583.041292905807,0.6456249952316284,1.4840998649597168,0.5798799991607666,1.7855687141418457,50000 -2208.069890022278,1.8420429229736328,24825.06359386444,53663,0,24825.06359386444,0.4559000134468078,2.470672130584717,10000,27037.99173426628,0.6150586009025574,1.626201033592224,0.5754199624061584,1.8177956342697144,50000 -2247.7225930690765,1.88374662399292,25245.0795879364,54573,0,25245.0795879364,0.4609000086784363,2.430996894836426,10000,27497.753811597824,0.6263281106948853,1.5781675577163696,0.5829600095748901,1.770058512687683,50000 -2288.314208507538,1.9204604625701904,25665.39813661576,55484,0,25665.39813661576,0.4611000120639801,2.447931289672852,10000,27958.75144600868,0.6351562142372131,1.5078805685043335,0.5780400037765503,1.778692603111267,50000 -2327.961247444153,1.957954168319702,26085.56778717041,56394,0,26085.56778717041,0.4609000086784363,2.4148685932159424,10000,28418.65738463401,0.6251562237739563,1.554768204689026,0.5821200013160706,1.764073371887207,50000 -2364.1540179252625,1.9971649646759035,26505.91107916832,57304,0,26505.91107916832,0.4624000191688537,2.455163955688477,10000,28875.283742427822,0.6250194907188416,1.5971978902816772,0.5831599831581116,1.794812798500061,50000 -2404.646691799164,2.0316531658172607,26926.0757522583,58212,0,26926.0757522583,0.4683000147342682,2.404117584228516,10000,29336.025983572006,0.6395898461341858,1.5047837495803833,0.5877599716186523,1.7553316354751587,50000 -2446.963950157165,2.0650646686553955,27346.409598112103,59122,0,27346.409598112103,0.4666000306606293,2.407226085662842,10000,29798.76133680344,0.6267773509025574,1.5606385469436646,0.5853599905967712,1.7512341737747192,50000 -2484.952459335327,2.1022024154663086,27766.34730625153,60032,0,27766.34730625153,0.4752000272274017,2.382453680038452,10000,30256.77499818802,0.6316601634025574,1.5331214666366575,0.5899999737739563,1.7361007928848269,50000 -2524.839050769806,2.1383426189422607,28186.30168414116,60941,0,28186.30168414116,0.4711000323295593,2.39124083518982,10000,30716.703328847885,0.6475585699081421,1.4857375621795654,0.5921599864959717,1.733931541442871,50000 -2564.9455330371857,2.175709009170532,28606.5775911808,61851,0,28606.5775911808,0.4680000245571136,2.374927759170532,10000,31177.174296855927,0.6359570026397705,1.5239770412445068,0.5936599969863892,1.7153202295303345,50000 -2607.876157522201,2.214055299758911,29026.926725625992,62761,0,29026.926725625992,0.4731000363826751,2.3883395195007324,10000,31640.543329954147,0.6364062428474426,1.5226380825042725,0.592960000038147,1.7290226221084597,50000 -2643.59489440918,2.254435777664185,29447.01801109314,63670,0,29447.01801109314,0.4786000251770019,2.3399128913879395,10000,32096.44493150711,0.6468554735183716,1.4551808834075928,0.5985000133514404,1.690281867980957,50000 -2683.8333218097687,2.2939212322235107,29867.20395731926,64580,0,29867.20395731926,0.4704000353813171,2.3861141204833984,10000,32556.95909023285,0.6409569978713989,1.5243161916732788,0.5908799767494202,1.7485175132751465,50000 -2724.1434757709503,2.335206985473633,30287.157656669617,65489,0,30287.157656669617,0.4764000177383423,2.362998485565185,10000,33017.3158800602,0.63232421875,1.529699206352234,0.5950599908828735,1.7079485654830933,50000 -2762.0575156211853,2.374621152877808,30707.10316038132,66399,0,30707.10316038132,0.4799000322818756,2.360117197036743,10000,33475.26572585106,0.6452734470367432,1.4815714359283447,0.5956999659538269,1.7074819803237915,50000 -2803.6267223358154,2.4154183864593506,31127.178329706192,67309,0,31127.178329706192,0.4728000164031982,2.3525753021240234,10000,33937.002141714096,0.6687890291213989,1.3784257173538208,0.5956799983978271,1.7057280540466309,50000 -2845.501267194748,2.461395263671875,31547.497469186783,68219,0,31547.497469186783,0.4745000302791595,2.364424467086792,10000,34399.2928917408,0.6403710842132568,1.5155787467956543,0.5939199924468994,1.7145018577575684,50000 -2885.74742937088,2.4994466304779053,31967.57464170456,69130,0,31967.57464170456,0.4783000349998474,2.311363458633423,10000,34859.705701351166,0.6488866806030273,1.4620481729507446,0.6020599603652954,1.6721429824829102,50000 -2925.956077575684,2.5358903408050537,32387.68676304817,70042,0,32387.68676304817,0.4807000160217285,2.3237833976745605,10000,35320.11451506615,0.6694726347923279,1.3626357316970823,0.6027399897575378,1.665413737297058,50000 -2965.8574674129486,2.5720467567443848,32807.90108847618,70951,0,32807.90108847618,0.4836000204086303,2.3453047275543213,10000,35780.31765007973,0.6404492259025574,1.498771071434021,0.6004399657249451,1.6853318214416504,50000 -3006.7480852603912,2.6135470867156982,33228.07022809982,71860,0,33228.07022809982,0.4808000326156616,2.3266289234161377,10000,36241.46924185753,0.6525781154632568,1.4528508186340332,0.604699969291687,1.674795389175415,50000 -3043.734041452408,2.652384042739868,33648.16423654556,72770,0,33648.16423654556,0.4833000302314758,2.316112756729126,10000,36698.63916397095,0.6627148389816284,1.4005804061889648,0.6051599979400635,1.6614387035369873,50000 -3084.372179746628,2.68951416015625,34068.218707084656,73682,0,34068.218707084656,0.4877000153064728,2.2946367263793945,10000,37159.42027497292,0.6526562571525574,1.441790223121643,0.6084200143814087,1.6414271593093872,50000 -3122.4718701839447,2.7291154861450195,34488.36898994446,74591,0,34488.36898994446,0.48580002784729,2.3193719387054443,10000,37617.76039338112,0.6527929306030273,1.446118950843811,0.6092000007629395,1.650182604789734,50000 -3162.1323087215424,2.7706854343414307,34908.69756317139,75503,0,34908.69756317139,0.4943000376224518,2.2832729816436768,10000,38077.84191131592,0.6614648103713989,1.3912386894226074,0.606719970703125,1.632909655570984,50000 -3197.922151327133,2.8070130348205566,35328.6748585701,76412,0,35328.6748585701,0.4888000190258026,2.271522760391236,10000,38533.69681978226,0.6544336080551147,1.4055346250534058,0.6148200035095215,1.6018518209457395,50000 -3237.468738079071,2.8454105854034424,35748.62918996811,77323,0,35748.62918996811,0.48580002784729,2.295750617980957,10000,38993.28647398949,0.6539453268051147,1.4418593645095823,0.6126799583435059,1.63576340675354,50000 -3273.759335756302,2.88728928565979,36168.83649253845,78235,0,36168.83649253845,0.4935000240802765,2.2690629959106445,10000,39449.87701368332,0.665722668170929,1.3743427991867063,0.6133399605751038,1.6258063316345217,50000 -3313.05346608162,2.9241793155670166,36589.174001932144,79146,0,36589.174001932144,0.4942000210285187,2.2950570583343506,10000,39909.59665203095,0.6544336080551147,1.4474139213562012,0.6101999878883362,1.6504663228988647,50000 -3349.5856976509094,2.9685537815093994,37009.08654427528,80055,0,37009.08654427528,0.4994000196456909,2.22688364982605,10000,40366.136556625366,0.6633593440055847,1.3997061252593994,0.6201199889183044,1.590294361114502,50000 -3384.9342653751373,3.0103461742401123,37429.40017914772,80968,0,37429.40017914772,0.4915000200271606,2.279834032058716,10000,40821.89309167862,0.6633203029632568,1.383533239364624,0.6125400066375732,1.6144942045211792,50000 -3424.0457706451416,3.052287578582764,37849.3330321312,81878,0,37849.3330321312,0.4943000376224518,2.2313928604125977,10000,41281.0299179554,0.6926171779632568,1.2652881145477295,0.6225599646568298,1.5752804279327393,50000 -3461.037127256393,3.0954391956329346,38269.31256151199,82788,0,38269.31256151199,0.4970000088214874,2.224552154541016,10000,41738.09440588951,0.6661523580551147,1.387743592262268,0.6181600093841553,1.5939165353775024,50000 -3500.086168289185,3.1364245414733887,38689.46293449402,83698,0,38689.46293449402,0.5004000067710876,2.207161903381348,10000,42197.38641667366,0.6722851395606995,1.3354723453521729,0.6222999691963196,1.5551286935806274,50000 -3539.470131397248,3.183518648147583,39109.49723100662,84609,0,39109.49723100662,0.5009000301361084,2.2192747592926025,10000,42656.90248131752,0.6911913752555847,1.2636808156967163,0.6240800023078918,1.56609046459198,50000 -3582.189652442932,3.224731683731079,39529.7694542408,85519,0,39529.7694542408,0.5010000467300415,2.2568812370300293,10000,43119.98700237274,0.6674023270606995,1.393503189086914,0.6214599609375,1.5956655740737915,50000 -3620.007707118988,3.262495994567871,39950.09155750275,86427,0,39950.09155750275,0.5054000020027161,2.208587408065796,10000,43578.21590304375,0.6750780940055847,1.347376823425293,0.6284599900245667,1.55901837348938,50000 -3659.488579750061,3.3086326122283936,40370.41019010544,87338,0,40370.41019010544,0.503000020980835,2.1890664100646973,10000,44038.112585783005,0.6869726181030273,1.2832924127578735,0.6260600090026855,1.5575437545776367,50000 -3698.100270032882,3.3508591651916504,40790.58680820465,88245,0,40790.58680820465,0.5083000063896179,2.1848161220550537,10000,44496.99372458458,0.6767382621765137,1.3393456935882568,0.6292399764060974,1.5438076257705688,50000 -3734.2811844348894,3.396424531936645,41210.69068980217,89155,0,41210.69068980217,0.4978000223636627,2.233126878738404,10000,44953.37469482422,0.6765820384025574,1.3568689823150637,0.6251199841499329,1.592879056930542,50000 -3774.460196733474,3.4379467964172363,41631.16453623772,90066,0,41631.16453623772,0.508400022983551,2.170599460601806,10000,45414.1202609539,0.6912695169448853,1.2604010105133057,0.6340000033378601,1.5193389654159546,50000 -3808.74009847641,3.4811456203460693,42051.34705209732,90976,0,42051.34705209732,0.5042999982833862,2.1914098262786865,10000,45868.677743434906,0.6769921779632568,1.3207937479019165,0.6307799816131592,1.5352376699447632,50000 -3845.205990314484,3.5260469913482666,42471.574355363846,91885,0,42471.574355363846,0.5106000304222107,2.133218050003052,10000,46325.46682262421,0.6859374642372131,1.277827262878418,0.6394599676132202,1.4927690029144287,50000 -3884.377197265625,3.5714974403381348,42891.8378636837,92793,0,42891.8378636837,0.5193000435829163,2.167919874191284,10000,46784.998166799545,0.6937304735183716,1.268678903579712,0.6373999714851379,1.5291402339935305,50000 -3920.189643383026,3.617255449295044,43311.91257548332,93700,0,43311.91257548332,0.5074000358581543,2.174853563308716,10000,47240.98172211647,0.684374988079071,1.302620768547058,0.6352800130844116,1.5236302614212036,50000 -3958.876034975052,3.6599841117858887,43732.20620751381,94609,0,43732.20620751381,0.5190000534057617,2.115467071533203,10000,47700.05508303642,0.6906640529632568,1.2588465213775637,0.6386399865150452,1.485780119895935,50000 -3995.7055275440216,3.703071117401123,44152.35373473168,95519,0,44152.35373473168,0.516800045967102,2.149538278579712,10000,48157.12686371803,0.6943945288658142,1.260308861732483,0.6425999999046326,1.499654769897461,50000 -4033.443476676941,3.7422258853912354,44572.31914424896,96429,0,44572.31914424896,0.5157000422477722,2.13046932220459,10000,48614.92049980164,0.7122656106948853,1.174549221992493,0.6404199600219727,1.485430121421814,50000 -4072.826028347016,3.780719995498657,44992.56587338448,97339,0,44992.56587338448,0.5144000053405762,2.127562522888184,10000,49074.64042210579,0.6898632645606995,1.2658751010894775,0.6431399583816528,1.4787806272506714,50000 -4109.709174156189,3.827104806900024,45412.90539312363,98248,0,45412.90539312363,0.520300030708313,2.1551709175109863,10000,49531.960334301,0.6936132907867432,1.285739541053772,0.6411199569702148,1.5151735544204712,50000 -4149.903354167938,3.872529983520508,45832.94949889183,99158,0,45832.94949889183,0.5199000239372253,2.1205554008483887,10000,49992.29464268685,0.7100781202316284,1.1822881698608398,0.6423599720001221,1.4824674129486084,50000 -4188.944055318832,3.9171104431152335,46253.27742695808,100068,0,46253.27742695808,0.5268000364303589,2.0834946632385254,10000,50451.75859832764,0.69677734375,1.2403640747070312,0.6484400033950806,1.4505292177200315,50000 -4226.306841373444,3.9611897468566895,46673.61992907524,100978,0,46673.61992907524,0.5200000405311584,2.102366209030152,10000,50909.56090068817,0.7009961009025574,1.234709620475769,0.6490600109100342,1.46269428730011,50000 -4267.212629556656,4.426010370254517,47093.20279479027,101885,0,47093.20279479027,0.5250000357627869,2.06763768196106,10000,51370.56505489349,0.7134765386581421,1.1484699249267578,0.6495400071144104,1.434063196182251,50000 -4303.410413980484,4.468097686767578,47513.41642546654,102794,0,47513.41642546654,0.5301000475883484,2.074958086013794,10000,51827.06900215149,0.69691401720047,1.217390060424805,0.6519399881362915,1.433598875999451,50000 -4346.339977025986,4.515044212341309,47933.50048518181,103704,0,47933.50048518181,0.5266000032424927,2.095916271209717,10000,52290.1813583374,0.7031054496765137,1.2242692708969116,0.6475200057029724,1.4630873203277588,50000 -4390.538145542145,4.555092096328735,48353.644204854965,104614,0,48353.644204854965,0.5293000340461731,2.080199718475342,10000,52754.61422896385,0.7122460603713989,1.175422430038452,0.6545599699020386,1.444462776184082,50000 -4432.014835357666,4.596234560012817,48773.55797100067,105522,0,48773.55797100067,0.525600016117096,2.077850341796875,10000,53216.09664773941,0.7056835889816284,1.2042386531829834,0.6565399765968323,1.425550937652588,50000 -4475.56670665741,4.640961647033691,49193.51614046097,106430,0,49193.51614046097,0.539400041103363,2.0336849689483643,10000,53679.70312547684,0.7126562595367432,1.1635273694992063,0.6616799831390381,1.3938157558441162,50000 -4515.323534250259,4.687347173690796,49613.643409490585,107337,0,49613.643409490585,0.532200038433075,2.072439670562744,10000,54139.6849834919,0.717968761920929,1.1395562887191772,0.6581799983978271,1.4131308794021606,50000 -4553.532001495361,4.729357481002808,50033.96087384224,108246,0,50033.96087384224,0.5309000015258789,2.076124429702759,10000,54598.303564071655,0.7106640338897705,1.1932311058044434,0.6593999862670898,1.424089789390564,50000 -4594.851831197739,4.773675441741943,50454.120992183685,109155,0,50454.120992183685,0.5367000102996826,2.0257108211517334,10000,55059.87839961052,0.7136523127555847,1.161201238632202,0.6612799763679504,1.3903865814208984,50000 -4635.680241346359,4.820109605789185,50874.046491622925,110063,0,50874.046491622925,0.5388000011444092,2.0300371646881104,10000,55520.72982406616,0.724804699420929,1.1234993934631348,0.6626799702644348,1.3978769779205322,50000 -4673.234929800034,4.870450496673584,51293.97769546509,110973,0,51293.97769546509,0.5364000201225281,2.026243209838867,10000,55978.31710648537,0.7303124666213989,1.0987309217453003,0.6647199988365173,1.3945322036743164,50000 -4714.676728963852,4.919151782989502,51714.3005130291,111884,0,51714.3005130291,0.5379000306129456,2.036576271057129,10000,56440.18104696274,0.7140820026397705,1.1694566011428833,0.6618799567222595,1.3972446918487549,50000 -4753.851475954056,4.964050054550171,52134.38846230507,112793,0,52134.38846230507,0.5406000018119812,2.0242462158203125,10000,56899.53965449333,0.7234960794448853,1.1330419778823853,0.6687399744987488,1.3698722124099731,50000 -4793.7336666584015,5.01096510887146,52555.08496952057,113701,0,52555.08496952057,0.5449000000953674,2.020291805267334,10000,57360.21570634842,0.74085932970047,1.0626214742660522,0.6674599647521973,1.3853868246078491,50000 -4835.897415399551,5.053654193878174,52975.2583398819,114607,0,52975.2583398819,0.5461000204086304,2.014665603637696,10000,57822.64588141441,0.7202929258346558,1.1391555070877075,0.6705399751663208,1.3693618774414062,50000 -4877.004498958588,5.096050024032593,53395.23126530647,115514,0,53395.23126530647,0.5449000000953674,1.9930286407470703,10000,58283.81888151169,0.7275585532188416,1.085247039794922,0.6701200008392334,1.3525265455245972,50000 -4919.428336620331,5.142207384109497,53815.36892461777,116421,0,53815.36892461777,0.5443000197410583,2.023348093032837,10000,58746.47779226303,0.7390038967132568,1.067587971687317,0.6695599555969238,1.372552514076233,50000 -4960.681174516678,5.188116788864136,54235.27755284309,117327,0,54235.27755284309,0.5481000542640686,1.9819422960281368,10000,59207.73613762856,0.7241405844688416,1.1141245365142822,0.6716200113296509,1.3542362451553345,50000 -4997.757047891617,5.235001087188721,54655.42571687698,118236,0,54655.42571687698,0.5499000549316406,1.9686278104782104,10000,59665.05784249306,0.7311132550239563,1.0923360586166382,0.674079954624176,1.3469245433807373,50000 -5039.916709423065,5.278695106506348,55075.748200416565,119146,0,55075.748200416565,0.5534000396728516,1.954906702041626,10000,60127.63447976112,0.7421093583106995,1.022621989250183,0.6779999732971191,1.313728094100952,50000 -5078.214918136597,5.3276238441467285,55495.910600185394,120055,0,55495.910600185394,0.5569000244140625,1.9266986846923828,10000,60586.19439768791,0.7331249713897705,1.0569225549697876,0.6812399625778198,1.2919137477874756,50000 -5119.290406227112,5.378458261489868,55916.22533154488,120963,0,55916.22533154488,0.5561000108718872,1.9345511198043823,10000,61047.6866889,0.7388867139816284,1.049467921257019,0.6840800046920776,1.2940298318862915,50000 -5156.877597808838,5.425642251968384,56336.53861951828,121870,0,56336.53861951828,0.5592000484466553,1.9299076795578003,10000,61505.68466210365,0.7439062595367432,1.030897855758667,0.6851599812507629,1.2949132919311523,50000 -5195.670909404755,5.473036527633667,56756.696388721466,122779,0,56756.696388721466,0.5611000061035156,1.8941129446029663,10000,61964.73351264,0.7379491925239563,1.0452035665512085,0.685979962348938,1.2754040956497192,50000 -5239.167343139648,5.522738933563232,57176.67815852165,123689,0,57176.67815852165,0.5586000084877014,1.901072382926941,10000,62428.31165552139,0.7442382574081421,1.0326616764068604,0.6847800016403198,1.2847895622253418,50000 -5276.72448348999,5.570141315460205,57596.78480887413,124598,0,57596.78480887413,0.5652000308036804,1.9213535785675049,10000,62886.07406878472,0.7477929592132568,1.0206478834152222,0.6859599947929382,1.2963595390319824,50000 -5315.758058547974,5.622182130813599,58016.77628803253,125508,0,58016.77628803253,0.5609000325202942,1.9046622514724727,10000,63345.20164251328,0.7463671565055847,1.0201128721237185,0.6872999668121338,1.2751214504241943,50000 -5353.759917020798,5.668635845184326,58437.0435795784,126415,0,58437.0435795784,0.5678000450134277,1.879265069961548,10000,63803.56805491448,0.7504687309265137,0.9982839822769164,0.6909799575805664,1.252796649932861,50000 -5391.798512220383,5.717786550521851,58857.124438762665,127324,0,58857.124438762665,0.567300021648407,1.8777422904968264,10000,64261.78774547577,0.7506445050239563,0.9825835824012756,0.6895599961280823,1.252158761024475,50000 -5429.511273860931,5.765959978103638,59277.34511780739,128236,0,59277.34511780739,0.5665000081062317,1.882158637046814,10000,64719.820784807205,0.7712695002555847,0.9164721369743348,0.6928399801254272,1.2563698291778564,50000 -5471.143674373627,5.814197540283203,59697.71044564247,129143,0,59697.71044564247,0.5640000104904175,1.883506298065185,10000,65181.91693449021,0.753222644329071,1.0042492151260376,0.6917200088500977,1.2587978839874268,50000 -5506.229387283325,5.861069679260254,60117.83626127243,130052,0,60117.83626127243,0.5714000463485718,1.8770899772644043,10000,65637.22610259056,0.7607616782188416,0.9540216326713562,0.6956999897956848,1.2329747676849363,50000 -5547.110866785049,5.910343647003174,60537.82030844688,130962,0,60537.82030844688,0.5742000341415405,1.8582838773727417,10000,66098.19182014465,0.7698827981948853,0.9147235155105592,0.6962400078773499,1.2309191226959229,50000 -5589.005449295044,5.961735725402832,60957.97349905968,131870,0,60957.97349905968,0.5725000500679016,1.835734844207764,10000,66560.34111618996,0.7575585842132568,0.9641546607017516,0.7005599737167358,1.219005107879639,50000 -5628.361680984497,6.011536359786987,61378.23578214645,132778,0,61378.23578214645,0.5773000121116638,1.856716513633728,10000,67020.0600142479,0.7624413967132568,0.9442371129989624,0.6990199685096741,1.221044421195984,50000 -5665.5401656627655,6.066869735717773,61798.40344524384,133688,0,61798.40344524384,0.57750004529953,1.8353841304779053,10000,67477.51272082329,0.7743359208106995,0.8995473384857178,0.7016800045967102,1.211005687713623,50000 -5707.8211581707,6.115529537200928,62218.46045303345,134599,0,62218.46045303345,0.5789000391960144,1.8349767923355105,10000,67939.94962263107,0.76185542345047,0.9467484951019288,0.7027999758720398,1.2054047584533691,50000 -5748.042767763138,6.164484024047852,62638.41057395935,135507,0,62638.41057395935,0.5776000022888184,1.8514093160629272,10000,68400.22089409828,0.7659569978713989,0.9456735253334044,0.7011799812316895,1.2311320304870603,50000 -5784.257081747055,6.215606927871704,63058.55126523972,136414,0,63058.55126523972,0.5824000239372253,1.8206918239593504,10000,68856.67860889435,0.7757421731948853,0.90151709318161,0.704539954662323,1.1993683576583862,50000 -5824.872405529022,6.267154216766357,63478.52360081673,137322,0,63478.52360081673,0.5777000188827515,1.8227862119674685,10000,69317.36960935593,0.7714648246765137,0.9188026189804076,0.7074999809265137,1.1961042881011963,50000 -5863.821855783463,6.314066648483276,63898.66561055184,138229,0,63898.66561055184,0.5875000357627869,1.7642264366149902,10000,69776.55884242058,0.777636706829071,0.8658644556999207,0.7101799845695496,1.1642639636993408,50000 -5903.501656532288,6.361291170120239,64318.93400526047,139141,0,64318.93400526047,0.5849000215530396,1.7911489009857178,10000,70236.60565376282,0.7803515195846558,0.868783712387085,0.7113199830055237,1.1666555404663086,50000 -5945.039888620377,6.41525673866272,64738.83408522606,140050,0,64738.83408522606,0.5907000303268433,1.7813829183578491,10000,70698.14928531647,0.7785937190055847,0.8846719861030579,0.7123799920082092,1.1717824935913086,50000 -5981.03254365921,6.46516752243042,65158.86788415909,140959,0,65158.86788415909,0.5898000001907349,1.7759381532669067,10000,71154.27746748924,0.775390625,0.8824504017829895,0.7128199934959412,1.150181770324707,50000 -6021.227711677551,6.5184853076934814,65579.00881290436,141867,0,65579.00881290436,0.5962000489234924,1.753408432006836,10000,71614.71821403503,0.7857617139816284,0.8551717400550842,0.7181199789047241,1.148403525352478,50000 -6063.670241594315,6.572426080703735,65999.13509559631,142775,0,65999.13509559631,0.5949000120162964,1.7673008441925049,10000,72077.39252829552,0.7930663824081421,0.8114429116249084,0.7167999744415283,1.1399496793746948,50000 -6103.460964918137,6.625840425491333,66419.16562604904,143682,0,66419.16562604904,0.6010000109672546,1.724225640296936,10000,72537.31822061539,0.7851952910423279,0.8426395058631897,0.7227999567985535,1.118364691734314,50000 -6141.835455179215,7.105395555496216,66839.12981963158,144588,0,66839.12981963158,0.5928000211715698,1.753752827644348,10000,72996.18788266182,0.7878710627555847,0.8365593552589417,0.7193599939346313,1.1317459344863892,50000 -6183.42341208458,7.158540964126587,67259.35423231125,145495,0,67259.35423231125,0.6044000387191772,1.7381134033203125,10000,73458.10443782806,0.7982421517372131,0.8103423714637756,0.7236799597740173,1.131292462348938,50000 -6224.7675149440765,7.207941770553589,67679.5624153614,146404,0,67679.5624153614,0.6007000207901001,1.729142427444458,10000,73919.75740528107,0.7885546684265137,0.8147965669631958,0.7246999740600586,1.102521538734436,50000 -6268.529381752014,7.259737014770508,68099.6663825512,147315,0,68099.6663825512,0.6016000509262085,1.728710412979126,10000,74383.72610616684,0.7919921875,0.8182185888290405,0.7255399823188782,1.113688826560974,50000 -6304.80890417099,7.309703350067139,68519.65882515907,148224,0,68519.65882515907,0.6040000319480896,1.6969178915023804,10000,74840.09953451157,0.8016796708106995,0.7689365744590759,0.7294999957084656,1.0898467302322388,50000 -6346.617129325867,7.357788801193237,68939.72495675087,149132,0,68939.72495675087,0.6046000123023987,1.711686372756958,10000,75302.07324838638,0.7948632836341858,0.8024986982345581,0.7267000079154968,1.0969750881195068,50000 -6385.988118648529,7.408616781234741,69359.74498486519,150039,0,69359.74498486519,0.6082000136375427,1.716665267944336,10000,75761.56579613686,0.7957812547683716,0.812127411365509,0.7278800010681152,1.094635009765625,50000 -6429.201789140701,7.461533546447754,69780.05288887024,150949,0,69780.05288887024,0.6103000044822693,1.6876506805419922,10000,76225.19110178947,0.8050976395606995,0.7490441799163818,0.7310199737548828,1.0718308687210083,50000 -6471.581538200378,7.511557102203369,70200.3424217701,151859,0,70200.3424217701,0.6122000217437744,1.6726754903793335,10000,76687.96187376976,0.8020703196525574,0.7700572609901428,0.7359799742698669,1.066151738166809,50000 -6510.923154830933,7.5636162757873535,70620.28472137451,152769,0,70620.28472137451,0.614300012588501,1.6564711332321167,10000,77147.348580122,0.8058202862739563,0.7566875219345093,0.7363399863243103,1.0597726106643677,50000 -6548.374161958695,7.616273880004883,71040.23987174034,153678,0,71040.23987174034,0.6166000366210938,1.6612064838409424,10000,77604.85888195038,0.81298828125,0.7200682759284973,0.7379800081253052,1.0499589443206787,50000 -6586.756381750107,7.667668581008911,71460.52765202522,154587,0,71460.52765202522,0.6158000230789185,1.6725051403045654,10000,78063.63143539429,0.8049609065055847,0.746151864528656,0.737339973449707,1.0490286350250244,50000 -6627.746407985687,7.721628189086914,71880.68959569931,155497,0,71880.68959569931,0.6164000034332275,1.652098298072815,10000,78524.88885855675,0.810839831829071,0.723305344581604,0.7399199604988098,1.035240888595581,50000 -6664.082603693008,7.774799346923828,72300.75505638123,156406,0,72300.75505638123,0.62090003490448,1.653515338897705,10000,78981.39456152916,0.8157812356948853,0.7162489295005798,0.7403799891471863,1.0414212942123413,50000 -6705.670013189316,7.826829195022583,72720.79807853699,157314,0,72720.79807853699,0.6237000226974487,1.6517783403396606,10000,79443.12803387642,0.8150194883346558,0.7227945327758789,0.7415800094604492,1.0487637519836426,50000 -6748.654703617096,7.878124952316284,73141.04412603378,158223,0,73141.04412603378,0.626800000667572,1.6282966136932373,10000,79906.46152997017,0.8199414014816284,0.7127540707588196,0.7445399761199951,1.0261250734329224,50000 -6792.603569984436,7.927542448043823,73561.13456368446,159130,0,73561.13456368446,0.6269000172615051,1.622787356376648,10000,80370.60028123856,0.8240038752555847,0.6871118545532227,0.747439980506897,1.0210548639297483,50000 -6831.573830366135,7.979724168777466,73981.15388822556,160039,0,73981.15388822556,0.6220000386238098,1.6190569400787354,10000,80829.69215226173,0.8271484375,0.6677748560905457,0.7458199858665466,1.0177862644195557,50000 -6869.350848913193,8.031014919281006,74401.07062625885,160948,0,74401.07062625885,0.6239000558853149,1.608596682548523,10000,81287.48914146423,0.8216406106948853,0.6819668412208557,0.7469599843025208,1.0042126178741455,50000 -6911.723528146744,8.083473920822144,74821.04405879974,161856,0,74821.04405879974,0.6320000290870667,1.5975457429885864,10000,81749.93900322914,0.8240429759025574,0.6735599637031555,0.7506600022315979,0.9978360533714294,50000 -6956.957853317261,8.137528657913208,75240.97953391075,162765,0,75240.97953391075,0.6312000155448914,1.600247502326965,10000,82215.21632957458,0.8299023509025574,0.6468636989593506,0.7511199712753296,0.9937317967414856,50000 -6995.772277355194,8.188964128494263,75661.26156401634,163675,0,75661.26156401634,0.6281000375747681,1.6094584465026855,10000,82674.41547107697,0.8234374523162842,0.6833599209785461,0.7499200105667114,0.9993503093719482,50000 -7036.61114192009,8.245900869369507,76081.20295095444,164583,0,76081.20295095444,0.6324000358581543,1.5833725929260254,10000,83135.30335402489,0.8306640386581421,0.6563959121704102,0.7524600028991699,0.9888281226158142,50000 -7075.693947792053,8.304190874099731,76501.53271389008,165493,0,76501.53271389008,0.6305000185966492,1.5695737600326538,10000,83594.82507920265,0.8374218344688416,0.6225919723510742,0.7544599771499634,0.9747494459152222,50000 -7118.194668292999,8.359277486801147,76921.90054345131,166403,0,76921.90054345131,0.6349000334739685,1.5555449724197388,10000,84057.80018854141,0.8303124904632568,0.6463891267776489,0.7556399703025818,0.9667149782180786,50000 -7155.896743297577,8.412142515182495,77342.16409659386,167311,0,77342.16409659386,0.6397000551223755,1.5751323699951172,10000,84515.86970067024,0.8325781226158142,0.6565069556236267,0.7567799687385559,0.9810318350791931,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/measurements.csv deleted file mode 100644 index d32398e56..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1865 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36637583,6.9077563,,,,,,,,,,,,,, -1,,,0.0008398437057621,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,35.492053747177124,62.411550998687744,35.492053747177124,26.919355869293213,0.0,0.0 -100,0.37267134,6.906714,,,,,,,,,,,,,, -200,0.45638955,6.891657,,,,,,,,,,,,,, -300,0.6337917,6.8492975,,,,,,,,,,,,,, -400,0.7851419,6.79194,,,,,,,,,,,,,, -500,0.69633317,6.762593,,,,,,,,,,,,,, -600,1.2280363,6.7156596,,,,,,,,,,,,,, -700,0.8617648,6.795324,,,,,,,,,,,,,, -800,1.2471343,6.6191792,,,,,,,,,,,,,, -857,,,0.0123046869412064,6.458219051361084,0.0127399992197752,6.469983100891113,50000.0,0.0090000005438923,6.5088372230529785,10000.0,455.7455909252167,519.2673494815826,455.7455909252167,63.456319093704224,0.0174732208251953,0.0 -900,3.3303752,6.514044,,,,,,,,,,,,,, -1000,3.46582,6.5008893,,,,,,,,,,,,,, -1100,2.1978655,6.372173,,,,,,,,,,,,,, -1200,1.7159153,6.344222,,,,,,,,,,,,,, -1300,2.268237,6.3654814,,,,,,,,,,,,,, -1400,1.5309668,6.317447,,,,,,,,,,,,,, -1500,2.8203628,6.337349,,,,,,,,,,,,,, -1600,1.7355629,6.15565,,,,,,,,,,,,,, -1700,1.5362016,6.5116453,,,,,,,,,,,,,, -1766,,,0.036210935562849,5.856694221496582,0.0349400006234645,5.884265422821045,50000.0,0.0318999998271465,5.995128154754639,10000.0,875.765299320221,978.044549703598,875.765299320221,102.13130497932434,0.0477561950683593,0.0 -1800,1.780816,6.434475,,,,,,,,,,,,,, -1900,2.494838,6.653183,,,,,,,,,,,,,, -2000,2.1567938,5.952784,,,,,,,,,,,,,, -2100,1.640092,5.95557,,,,,,,,,,,,,, -2200,2.1688035,5.949238,,,,,,,,,,,,,, -2300,2.9563422,5.882807,,,,,,,,,,,,,, -2400,2.7088563,5.8906817,,,,,,,,,,,,,, -2500,2.113705,5.860345,,,,,,,,,,,,,, -2600,2.2340653,5.8045893,,,,,,,,,,,,,, -2677,,,0.0633203089237213,5.436893939971924,0.0602799989283084,5.480632305145264,50000.0,0.0478000007569789,5.662439346313477,10000.0,1296.171561002731,1437.1218678951263,1296.171561002731,140.7256236076355,0.073580265045166,0.0 -2700,1.412259,6.414305,,,,,,,,,,,,,, -2800,2.2242897,5.702967,,,,,,,,,,,,,, -2900,2.1788282,5.691255,,,,,,,,,,,,,, -3000,1.6709917,5.8159747,,,,,,,,,,,,,, -3100,2.0272858,5.6754932,,,,,,,,,,,,,, -3200,1.8537449,6.4260354,,,,,,,,,,,,,, -3300,2.056754,5.554441,,,,,,,,,,,,,, -3400,2.113476,5.6552205,,,,,,,,,,,,,, -3500,1.6596628,5.884234,,,,,,,,,,,,,, -3587,,,0.0926367193460464,5.093984603881836,0.0870999991893768,5.130605220794678,50000.0,0.068400003015995,5.372107982635498,10000.0,1716.1118450164795,1895.894778966904,1716.1118450164795,179.481369972229,0.0996489524841308,0.0 -3600,1.6473718,6.3072777,,,,,,,,,,,,,, -3700,1.8785701,5.4949293,,,,,,,,,,,,,, -3800,1.9505922,5.521139,,,,,,,,,,,,,, -3900,1.5967665,5.602399,,,,,,,,,,,,,, -4000,2.3040364,5.489657,,,,,,,,,,,,,, -4100,2.3844995,5.429547,,,,,,,,,,,,,, -4200,2.101097,5.3505154,,,,,,,,,,,,,, -4300,2.0933318,5.3171864,,,,,,,,,,,,,, -4400,1.4331657,6.345272,,,,,,,,,,,,,, -4496,,,0.1288671791553497,4.745182514190674,0.120679996907711,4.8043293952941895,50000.0,0.0940000042319297,5.104016304016113,10000.0,2136.0499653816223,2354.783221721649,2136.0499653816223,218.3541798591613,0.1258656978607177,0.0 -4500,2.1178987,5.2812815,,,,,,,,,,,,,, -4600,2.3009987,5.2433343,,,,,,,,,,,,,, -4700,1.7357867,5.2139187,,,,,,,,,,,,,, -4800,1.5750878,6.380578,,,,,,,,,,,,,, -4900,1.9872174,5.124793,,,,,,,,,,,,,, -5000,2.3038847,5.372861,,,,,,,,,,,,,, -5100,1.7576816,5.1554613,,,,,,,,,,,,,, -5200,2.1110468,5.058659,,,,,,,,,,,,,, -5300,1.6145927,6.25284,,,,,,,,,,,,,, -5400,2.0652876,6.4228783,,,,,,,,,,,,,, -5405,,,0.1678320318460464,4.406538486480713,0.1526799947023391,4.486951351165772,50000.0,0.1173000037670135,4.828131675720215,10000.0,2555.98603963852,2814.194491147995,2555.98603963852,257.7483706474304,0.1558928489685058,0.0 -5500,1.8964496,5.5128407,,,,,,,,,,,,,, -5600,1.8536466,4.936146,,,,,,,,,,,,,, -5700,1.9571908,5.3340845,,,,,,,,,,,,,, -5800,1.5802017,5.6489,,,,,,,,,,,,,, -5900,1.9654428,4.9690814,,,,,,,,,,,,,, -6000,2.7130888,5.3793416,,,,,,,,,,,,,, -6100,2.1198342,4.7741575,,,,,,,,,,,,,, -6200,1.7685767,5.4163294,,,,,,,,,,,,,, -6300,1.9791529,4.7457447,,,,,,,,,,,,,, -6314,,,0.2112695276737213,4.0574235916137695,0.1956399977207183,4.138259410858154,50000.0,0.1439000070095062,4.525965213775635,10000.0,2976.0876848697662,3268.1993370056152,2976.0876848697662,291.5754690170288,0.181589126586914,0.0 -6400,1.9061272,4.9429283,,,,,,,,,,,,,, -6500,2.2208362,5.071496,,,,,,,,,,,,,, -6600,2.070998,5.310601,,,,,,,,,,,,,, -6700,1.1671882,6.2726135,,,,,,,,,,,,,, -6800,2.1357198,5.0288334,,,,,,,,,,,,,, -6900,2.0398102,4.533389,,,,,,,,,,,,,, -7000,1.5711331,6.3173547,,,,,,,,,,,,,, -7100,1.6643807,6.2178774,,,,,,,,,,,,,, -7200,1.7987877,6.1786814,,,,,,,,,,,,,, -7227,,,0.2490820288658142,3.757694959640503,0.23157998919487,3.856245756149292,50000.0,0.1714000105857849,4.288461685180664,10000.0,3396.280373096466,3726.625602722168,3396.280373096466,329.72999024391174,0.2092578411102295,0.0 -7300,1.4165337,6.098141,,,,,,,,,,,,,, -7400,1.7240698,4.3486347,,,,,,,,,,,,,, -7500,1.6236767,5.232667,,,,,,,,,,,,,, -7600,1.9619926,4.4498186,,,,,,,,,,,,,, -7700,1.8880677,4.437031,,,,,,,,,,,,,, -7800,2.1256924,4.4521575,,,,,,,,,,,,,, -7900,1.5328115,5.415993,,,,,,,,,,,,,, -8000,2.1512706,4.2602177,,,,,,,,,,,,,, -8100,1.8038691,4.2859697,,,,,,,,,,,,,, -8138,,,0.2967578172683716,3.435223340988159,0.2703399956226349,3.561870574951172,50000.0,0.2063000053167343,4.033924579620361,10000.0,3816.497978210449,4183.641686677933,3816.497978210449,366.4525589942932,0.2342324256896972,0.0 -8200,1.4337878,6.110095,,,,,,,,,,,,,, -8300,1.8845947,4.403481,,,,,,,,,,,,,, -8400,2.3766098,4.206101,,,,,,,,,,,,,, -8500,1.970301,4.1341023,,,,,,,,,,,,,, -8600,2.1805224,4.406675,,,,,,,,,,,,,, -8700,1.9098258,4.7990484,,,,,,,,,,,,,, -8800,2.057741,4.1085386,,,,,,,,,,,,,, -8900,1.5480711,5.4610834,,,,,,,,,,,,,, -9000,1.886082,4.1916995,,,,,,,,,,,,,, -9048,,,0.3414452970027923,3.161966562271118,0.3065399825572967,3.330437660217285,50000.0,0.2297000139951706,3.847393989562988,10000.0,4236.516438961029,4642.813379764557,4236.516438961029,405.5273621082306,0.2614433765411377,0.0 -9100,1.3707361,6.0372095,,,,,,,,,,,,,, -9200,1.244833,5.9873257,,,,,,,,,,,,,, -9300,1.4534646,5.7480187,,,,,,,,,,,,,, -9400,1.1946943,6.052431,,,,,,,,,,,,,, -9500,1.3053265,6.0512714,,,,,,,,,,,,,, -9600,1.8705636,4.186033,,,,,,,,,,,,,, -9700,2.113959,4.1003504,,,,,,,,,,,,,, -9800,1.8161825,4.05628,,,,,,,,,,,,,, -9900,2.0596938,3.963779,,,,,,,,,,,,,, -9958,,,0.3555664122104645,3.039619445800781,0.3312200009822845,3.173012971878052,50000.0,0.2514000236988067,3.730279922485352,10000.0,4656.703306913376,5099.041333913803,4656.703306913376,441.491233587265,0.2879626750946045,0.0 -10000,1.7953275,4.0166903,,,,,,,,,,,,,, -10100,2.3416119,4.9929376,,,,,,,,,,,,,, -10200,1.4188648,5.4206915,,,,,,,,,,,,,, -10300,1.8505689,3.936671,,,,,,,,,,,,,, -10400,1.9195908,4.2501636,,,,,,,,,,,,,, -10500,1.6864752,3.8272152,,,,,,,,,,,,,, -10600,1.8336288,4.2719464,,,,,,,,,,,,,, -10700,1.495034,4.5685263,,,,,,,,,,,,,, -10800,1.2626909,5.9719243,,,,,,,,,,,,,, -10870,,,0.3751562535762787,2.954991340637207,0.3479799926280975,3.092005491256714,50000.0,0.2725000083446502,3.626384973526001,10000.0,5076.962865829468,5555.052897453308,5076.962865829468,477.1607825756073,0.3193974494934082,0.0 -10900,1.8446767,3.9203937,,,,,,,,,,,,,, -11000,1.7880055,4.2141113,,,,,,,,,,,,,, -11100,1.7590146,3.820333,,,,,,,,,,,,,, -11200,1.3012772,5.7359695,,,,,,,,,,,,,, -11300,1.6812917,4.0531063,,,,,,,,,,,,,, -11400,1.3513033,5.6987944,,,,,,,,,,,,,, -11500,1.9252981,3.6680305,,,,,,,,,,,,,, -11600,1.9258398,3.7257137,,,,,,,,,,,,,, -11700,1.8729278,3.7735023,,,,,,,,,,,,,, -11782,,,0.4147265553474426,2.723894596099853,0.371099978685379,2.9334120750427246,50000.0,0.2871000170707702,3.488808870315552,10000.0,5497.338998794556,6014.403256416321,5497.338998794556,516.0527064800262,0.3488492965698242,0.0 -11800,1.8882431,3.8529353,,,,,,,,,,,,,, -11900,1.9244883,3.6291852,,,,,,,,,,,,,, -12000,1.6739806,3.6318765,,,,,,,,,,,,,, -12100,2.0421114,3.6492493,,,,,,,,,,,,,, -12200,1.3770887,5.742474,,,,,,,,,,,,,, -12300,1.2895458,5.059752,,,,,,,,,,,,,, -12400,1.2089399,5.80412,,,,,,,,,,,,,, -12500,2.131541,3.6936376,,,,,,,,,,,,,, -12600,1.3197058,5.8496804,,,,,,,,,,,,,, -12691,,,0.4253320097923279,2.623037099838257,0.3962000012397766,2.782031536102295,50000.0,0.3032000064849853,3.3758649826049805,10000.0,5917.761694431305,6470.824687480927,5917.761694431305,551.9732129573822,0.3765082359313965,0.0 -12700,2.0262005,3.7460535,,,,,,,,,,,,,, -12800,1.7467446,3.686485,,,,,,,,,,,,,, -12900,1.8301735,3.5035934,,,,,,,,,,,,,, -13000,0.9955839,5.869458,,,,,,,,,,,,,, -13100,1.8121318,3.865427,,,,,,,,,,,,,, -13200,1.6957921,3.559269,,,,,,,,,,,,,, -13300,1.7606071,3.5481915,,,,,,,,,,,,,, -13400,2.1685858,3.85589,,,,,,,,,,,,,, -13500,1.233332,4.889925,,,,,,,,,,,,,, -13600,1.6691729,3.5395558,,,,,,,,,,,,,, -13601,,,0.4386132657527923,2.5734517574310303,0.4082199931144714,2.7287282943725586,50000.0,0.3160000145435333,3.312729358673096,10000.0,6337.694573402405,6926.301305532455,6337.694573402405,587.4360723495483,0.4064729213714599,0.0 -13700,1.7091638,3.5526478,,,,,,,,,,,,,, -13800,1.9304028,3.5250056,,,,,,,,,,,,,, -13900,1.9637195,4.530271,,,,,,,,,,,,,, -14000,1.5009553,3.922643,,,,,,,,,,,,,, -14100,1.8027785,3.5461538,,,,,,,,,,,,,, -14200,1.6978422,3.390644,,,,,,,,,,,,,, -14300,1.9330916,3.6381178,,,,,,,,,,,,,, -14400,1.0512421,5.3696785,,,,,,,,,,,,,, -14500,1.9089445,3.454385,,,,,,,,,,,,,, -14512,,,0.46240234375,2.450325965881348,0.4207199811935425,2.656090021133423,50000.0,0.3279000222682953,3.24593186378479,10000.0,6757.871794700623,7381.980210542679,6757.871794700623,622.858304977417,0.4353101253509521,0.0 -14600,1.7512122,3.4504766,,,,,,,,,,,,,, -14700,1.1757699,5.554591,,,,,,,,,,,,,, -14800,1.6316845,3.375704,,,,,,,,,,,,,, -14900,1.0511746,5.4622946,,,,,,,,,,,,,, -15000,1.3362353,5.804858,,,,,,,,,,,,,, -15100,1.5399075,3.407262,,,,,,,,,,,,,, -15200,1.6626357,3.3190384,,,,,,,,,,,,,, -15300,1.7118502,3.6859176,,,,,,,,,,,,,, -15400,1.8511895,3.4906223,,,,,,,,,,,,,, -15422,,,0.4571093618869781,2.456397771835327,0.4252599775791168,2.6171255111694336,50000.0,0.32710000872612,3.216073989868164,10000.0,7177.840661525726,7840.064656257629,7177.840661525726,660.8948771953583,0.4631009101867676,0.0 -15500,1.5392892,3.860465,,,,,,,,,,,,,, -15600,1.4500531,3.3466365,,,,,,,,,,,,,, -15700,1.615786,3.746085,,,,,,,,,,,,,, -15800,1.1659439,4.9714894,,,,,,,,,,,,,, -15900,1.5997196,3.298102,,,,,,,,,,,,,, -16000,1.6251103,3.3101723,,,,,,,,,,,,,, -16100,1.664619,3.4168873,,,,,,,,,,,,,, -16200,1.3511586,4.016933,,,,,,,,,,,,,, -16300,2.0010896,3.4711986,,,,,,,,,,,,,, -16333,,,0.4747656285762787,2.3626480102539062,0.4391599893569946,2.550366163253784,50000.0,0.3414000272750854,3.161466598510742,10000.0,7597.769626379013,8299.500366926193,7597.769626379013,700.3218250274658,0.4916057586669922,0.0 -16400,1.7841231,3.6544337,,,,,,,,,,,,,, -16500,1.493423,3.8453143,,,,,,,,,,,,,, -16600,1.6501155,3.3086324,,,,,,,,,,,,,, -16700,1.7545756,3.8494778,,,,,,,,,,,,,, -16800,1.6515391,4.050986,,,,,,,,,,,,,, -16900,1.7877554,3.4198887,,,,,,,,,,,,,, -17000,1.5767502,4.3109565,,,,,,,,,,,,,, -17100,1.487646,3.2692688,,,,,,,,,,,,,, -17200,1.8691221,3.3457117,,,,,,,,,,,,,, -17241,,,0.4889843761920929,2.295726776123047,0.4437599778175354,2.507998466491699,50000.0,0.344400018453598,3.137245178222656,10000.0,8018.110145807266,8758.87922167778,8018.110145807266,739.2799081802368,0.520301103591919,0.0 -17300,1.6062546,3.3299918,,,,,,,,,,,,,, -17400,1.537104,3.3367715,,,,,,,,,,,,,, -17500,1.2577612,4.399012,,,,,,,,,,,,,, -17600,1.149565,5.1390023,,,,,,,,,,,,,, -17700,1.8468401,3.3862276,,,,,,,,,,,,,, -17800,1.5259715,3.3698583,,,,,,,,,,,,,, -17900,1.1436763,5.761128,,,,,,,,,,,,,, -18000,1.522131,3.4946969,,,,,,,,,,,,,, -18100,1.6804856,3.342877,,,,,,,,,,,,,, -18151,,,0.4826171696186065,2.3039026260375977,0.4477799832820892,2.470780372619629,50000.0,0.3503000140190124,3.0864689350128174,10000.0,8438.437040090561,9221.031891822817,8438.437040090561,781.0245745182037,0.5498020648956299,0.0 -18200,1.6317384,3.2263353,,,,,,,,,,,,,, -18300,1.5395006,3.480889,,,,,,,,,,,,,, -18400,1.971139,3.153539,,,,,,,,,,,,,, -18500,1.50196,3.5158165,,,,,,,,,,,,,, -18600,1.1193131,4.3162847,,,,,,,,,,,,,, -18700,1.3531222,5.6877866,,,,,,,,,,,,,, -18800,1.7546481,3.2173119,,,,,,,,,,,,,, -18900,1.4256542,3.6583312,,,,,,,,,,,,,, -19000,1.4493533,3.2391582,,,,,,,,,,,,,, -19061,,,0.4994726479053497,2.22053599357605,0.462039977312088,2.404991149902344,50000.0,0.3585000038146972,3.022629976272583,10000.0,8858.829084157944,9682.15875673294,8858.829084157944,821.6779868602753,0.579658031463623,0.0 -19100,1.8067793,3.2599497,,,,,,,,,,,,,, -19200,1.4496942,4.1813974,,,,,,,,,,,,,, -19300,1.642419,3.6322434,,,,,,,,,,,,,, -19400,1.4763405,3.5802858,,,,,,,,,,,,,, -19500,1.6251116,3.4132018,,,,,,,,,,,,,, -19600,1.7935976,3.2961571,,,,,,,,,,,,,, -19700,1.3560511,4.118657,,,,,,,,,,,,,, -19800,1.7229288,3.245439,,,,,,,,,,,,,, -19900,2.224242,3.1998262,,,,,,,,,,,,,, -19970,,,0.5083202719688416,2.1640000343322754,0.4681599736213684,2.3670108318328857,50000.0,0.3583000302314758,3.0018560886383057,10000.0,9278.822909593582,10140.743117809296,9278.822909593582,860.1880419254303,0.6096100807189941,0.0 -20000,0.9684226,5.3563643,,,,,,,,,,,,,, -20100,1.6041733,3.3049314,,,,,,,,,,,,,, -20200,1.2313266,4.892512,,,,,,,,,,,,,, -20300,1.4479297,3.2517414,,,,,,,,,,,,,, -20400,1.4501532,3.3564448,,,,,,,,,,,,,, -20500,1.6847068,3.1368418,,,,,,,,,,,,,, -20600,1.4809246,3.4613338,,,,,,,,,,,,,, -20700,1.0864612,5.5548334,,,,,,,,,,,,,, -20800,0.8501078,5.494459,,,,,,,,,,,,,, -20882,,,0.5099218487739563,2.1630897521972656,0.4744599759578705,2.34031629562378,50000.0,0.3707000315189361,2.948025941848755,10000.0,9699.075863838196,10596.444150209429,9699.075863838196,895.5525405406952,0.6403217315673828,0.0 -20900,1.1377941,5.453763,,,,,,,,,,,,,, -21000,1.2665824,4.414874,,,,,,,,,,,,,, -21100,1.3510925,5.4754696,,,,,,,,,,,,,, -21200,1.4724754,3.4296808,,,,,,,,,,,,,, -21300,1.4853019,3.2196438,,,,,,,,,,,,,, -21400,1.1043429,4.8882804,,,,,,,,,,,,,, -21500,1.6279193,3.0911741,,,,,,,,,,,,,, -21600,1.1048827,5.415373,,,,,,,,,,,,,, -21700,1.2973729,4.297576,,,,,,,,,,,,,, -21793,,,0.5142382979393005,2.1343255043029785,0.4792400002479553,2.311788320541382,50000.0,0.3747000098228454,2.934434413909912,10000.0,10119.355597019196,11053.265462636948,10119.355597019196,932.010551214218,0.6730847358703613,0.0 -21800,1.5180378,3.4329066,,,,,,,,,,,,,, -21900,1.5916129,3.0554624,,,,,,,,,,,,,, -22000,1.224946,4.52197,,,,,,,,,,,,,, -22100,1.7034836,3.060234,,,,,,,,,,,,,, -22200,1.6931517,3.2206469,,,,,,,,,,,,,, -22300,1.5183634,3.1164699,,,,,,,,,,,,,, -22400,1.12549,5.457399,,,,,,,,,,,,,, -22500,1.0063373,5.488881,,,,,,,,,,,,,, -22600,1.3165451,3.9419544,,,,,,,,,,,,,, -22700,1.4208382,3.3927898,,,,,,,,,,,,,, -22703,,,0.5282812118530273,2.0714240074157715,0.4855599999427795,2.277931928634644,50000.0,0.3823000192642212,2.8995566368103027,10000.0,10539.649500846865,11509.611449241638,10539.649500846865,967.9775204658508,0.7074997425079346,0.0 -22800,1.6011431,3.5050685,,,,,,,,,,,,,, -22900,1.4595133,3.1161287,,,,,,,,,,,,,, -23000,1.1570804,5.5398555,,,,,,,,,,,,,, -23100,1.6036235,3.060254,,,,,,,,,,,,,, -23200,1.3120482,5.126998,,,,,,,,,,,,,, -23300,1.586272,3.1730444,,,,,,,,,,,,,, -23400,1.3541522,3.6973734,,,,,,,,,,,,,, -23500,1.8990902,3.0876398,,,,,,,,,,,,,, -23600,1.5586566,3.0082345,,,,,,,,,,,,,, -23613,,,0.5564648509025574,1.9199002981185915,0.4958999752998352,2.209649085998535,50000.0,0.3823000192642212,2.859225273132324,10000.0,10959.790457248688,11969.157633304596,10959.790457248688,1007.3027441501616,0.7358355522155762,0.0 -23700,1.3608227,3.6936343,,,,,,,,,,,,,, -23800,1.3174425,3.7211938,,,,,,,,,,,,,, -23900,1.0888174,5.5042353,,,,,,,,,,,,,, -24000,1.2087523,3.7964542,,,,,,,,,,,,,, -24100,1.9538693,3.068621,,,,,,,,,,,,,, -24200,1.3707019,5.4907804,,,,,,,,,,,,,, -24300,1.1030629,4.0965323,,,,,,,,,,,,,, -24400,1.2098666,4.4897084,,,,,,,,,,,,,, -24500,1.6569275,3.1283877,,,,,,,,,,,,,, -24523,,,0.5373241901397705,2.020864248275757,0.5016999840736389,2.1879594326019287,50000.0,0.3924000263214111,2.822327852249145,10000.0,11379.97221159935,12429.073031425476,11379.97221159935,1046.9547855854034,0.7669098377227783,0.0 -24600,1.5792496,3.1768136,,,,,,,,,,,,,, -24700,1.1413271,5.5639086,,,,,,,,,,,,,, -24800,1.6756955,2.899881,,,,,,,,,,,,,, -24900,1.4510556,3.074454,,,,,,,,,,,,,, -25000,1.0564171,5.091308,,,,,,,,,,,,,, -25100,1.2170091,4.746859,,,,,,,,,,,,,, -25200,1.547977,3.0958862,,,,,,,,,,,,,, -25300,1.6390455,2.9785395,,,,,,,,,,,,,, -25400,1.4078116,3.159856,,,,,,,,,,,,,, -25434,,,0.5425586104393005,1.988130569458008,0.5017399787902832,2.190019130706787,50000.0,0.3924000263214111,2.835370779037476,10000.0,11800.332841157911,12885.237624645231,11800.332841157911,1082.676115512848,0.7968857288360596,0.0 -25500,1.3154274,3.9738998,,,,,,,,,,,,,, -25600,1.1683106,5.537079,,,,,,,,,,,,,, -25700,2.3787134,3.1349752,,,,,,,,,,,,,, -25800,1.6664733,2.9952793,,,,,,,,,,,,,, -25900,1.6032499,3.010532,,,,,,,,,,,,,, -26000,1.3213439,4.2879057,,,,,,,,,,,,,, -26100,1.6194786,2.9955008,,,,,,,,,,,,,, -26200,1.6024219,3.0194218,,,,,,,,,,,,,, -26300,1.2459067,4.4470415,,,,,,,,,,,,,, -26343,,,0.573046863079071,1.8160487413406368,0.5158599615097046,2.1130800247192383,50000.0,0.4079000055789947,2.7394766807556152,10000.0,12220.630915403366,13341.120736837389,12220.630915403366,1118.181747674942,0.8260171413421631,0.0 -26400,1.3062036,4.2471557,,,,,,,,,,,,,, -26500,1.6053629,3.0563931,,,,,,,,,,,,,, -26600,1.6111983,3.1228778,,,,,,,,,,,,,, -26700,1.3788861,3.9608035,,,,,,,,,,,,,, -26800,1.4228013,3.4432893,,,,,,,,,,,,,, -26900,1.5724394,2.8918629,,,,,,,,,,,,,, -27000,1.3406336,4.8653626,,,,,,,,,,,,,, -27100,1.2240058,4.643665,,,,,,,,,,,,,, -27200,1.283646,4.6346765,,,,,,,,,,,,,, -27254,,,0.5526952743530273,1.9304035902023315,0.5169399976730347,2.108940362930298,50000.0,0.4070000052452087,2.7597310543060303,10000.0,12640.863479614258,13797.866409778597,12640.863479614258,1154.6123263835907,0.8571968078613281,0.0 -27300,1.5410564,2.8288665,,,,,,,,,,,,,, -27400,1.315371,5.026712,,,,,,,,,,,,,, -27500,1.5359033,3.485869,,,,,,,,,,,,,, -27600,1.4999665,3.3507934,,,,,,,,,,,,,, -27700,1.7307075,2.9638655,,,,,,,,,,,,,, -27800,1.5391666,2.9156797,,,,,,,,,,,,,, -27900,1.5972958,2.9743526,,,,,,,,,,,,,, -28000,1.211753,5.0480123,,,,,,,,,,,,,, -28100,1.7038171,2.865641,,,,,,,,,,,,,, -28164,,,0.5553905963897705,1.9481656551361084,0.51528000831604,2.139265298843384,50000.0,0.4034000337123871,2.77888560295105,10000.0,13061.040511369703,14252.11918401718,13061.040511369703,1188.6037278175354,0.8905489444732666,0.0 -28200,1.736405,2.9676132,,,,,,,,,,,,,, -28300,1.6122209,2.9333313,,,,,,,,,,,,,, -28400,1.6650273,3.3699603,,,,,,,,,,,,,, -28500,1.353749,4.9680567,,,,,,,,,,,,,, -28600,1.5764663,2.7761786,,,,,,,,,,,,,, -28700,1.628904,2.945949,,,,,,,,,,,,,, -28800,1.9378557,3.118143,,,,,,,,,,,,,, -28900,1.595987,3.2198322,,,,,,,,,,,,,, -29000,1.2234367,4.2351456,,,,,,,,,,,,,, -29074,,,0.576171875,1.81451153755188,0.5225200057029724,2.0690627098083496,50000.0,0.4146000146865845,2.690381050109864,10000.0,13481.007174015043,14706.609322786331,13481.007174015043,1223.0425362586975,0.9243690967559814,0.0 -29100,1.2059809,5.5462832,,,,,,,,,,,,,, -29200,1.7093481,2.8652768,,,,,,,,,,,,,, -29300,1.1502218,5.6053386,,,,,,,,,,,,,, -29400,1.2791879,3.983767,,,,,,,,,,,,,, -29500,1.4725751,3.7462842,,,,,,,,,,,,,, -29600,1.1997586,4.6854777,,,,,,,,,,,,,, -29700,1.0998073,5.4291315,,,,,,,,,,,,,, -29800,1.7005752,2.9092426,,,,,,,,,,,,,, -29900,1.6660206,3.0520391,,,,,,,,,,,,,, -29985,,,0.5616992115974426,1.872113943099976,0.5232599973678589,2.0561981201171875,50000.0,0.4086000323295593,2.7167224884033203,10000.0,13901.04438996315,15160.329586267471,13901.04438996315,1256.6402442455292,0.9586703777313232,0.0 -30000,1.6015475,5.5099144,,,,,,,,,,,,,, -30100,1.7816643,2.9161937,,,,,,,,,,,,,, -30200,1.5747019,3.6541114,,,,,,,,,,,,,, -30300,1.6599051,2.9418967,,,,,,,,,,,,,, -30400,1.5843793,3.0111609,,,,,,,,,,,,,, -30500,1.6430256,2.811545,,,,,,,,,,,,,, -30600,1.1492312,5.4783525,,,,,,,,,,,,,, -30700,1.7227511,2.7587771,,,,,,,,,,,,,, -30800,1.6679543,2.826537,,,,,,,,,,,,,, -30896,,,0.5642382502555847,1.8705335855484009,0.519819974899292,2.07366156578064,50000.0,0.4124000072479248,2.723940849304199,10000.0,14321.219009399414,15618.991045713425,14321.219009399414,1295.0405583381653,0.9944887161254884,0.0 -30900,1.4792086,2.855171,,,,,,,,,,,,,, -31000,1.7950869,2.9881272,,,,,,,,,,,,,, -31100,1.6567565,2.8308477,,,,,,,,,,,,,, -31200,1.5779243,3.291578,,,,,,,,,,,,,, -31300,1.5589213,3.2711742,,,,,,,,,,,,,, -31400,1.5615405,3.9385452,,,,,,,,,,,,,, -31500,1.5007427,2.7809489,,,,,,,,,,,,,, -31600,1.797527,2.8726234,,,,,,,,,,,,,, -31700,1.9083321,2.8268757,,,,,,,,,,,,,, -31800,1.6126436,2.9104354,,,,,,,,,,,,,, -31807,,,0.5814062356948853,1.8077421188354488,0.5366799831390381,2.019350290298462,50000.0,0.4239000082015991,2.674529790878296,10000.0,14741.544480085371,16073.455638170242,14741.544480085371,1329.095343351364,1.027569055557251,0.0 -31900,1.6321774,2.7347093,,,,,,,,,,,,,, -32000,1.2647034,4.110183,,,,,,,,,,,,,, -32100,1.6756854,2.7831874,,,,,,,,,,,,,, -32200,1.6547925,2.789516,,,,,,,,,,,,,, -32300,1.6284198,2.8510404,,,,,,,,,,,,,, -32400,1.5183259,3.4183497,,,,,,,,,,,,,, -32500,1.814081,2.8989294,,,,,,,,,,,,,, -32600,1.7315044,3.3543093,,,,,,,,,,,,,, -32700,1.5142187,3.4208338,,,,,,,,,,,,,, -32718,,,0.5716406106948853,1.8151863813400269,0.5365599989891052,1.983070731163025,50000.0,0.4264000058174133,2.6434519290924072,10000.0,15161.686965703964,16532.786767959595,15161.686965703964,1368.20298743248,1.0579285621643066,0.0 -32800,1.5900035,3.0383945,,,,,,,,,,,,,, -32900,2.0582883,2.8558064,,,,,,,,,,,,,, -33000,1.8775009,2.8682518,,,,,,,,,,,,,, -33100,1.5866894,2.9513657,,,,,,,,,,,,,, -33200,1.2865767,3.9354825,,,,,,,,,,,,,, -33300,1.7265437,2.75678,,,,,,,,,,,,,, -33400,1.3488798,3.8684533,,,,,,,,,,,,,, -33500,1.7125889,2.9813232,,,,,,,,,,,,,, -33600,1.5428042,3.4436407,,,,,,,,,,,,,, -33628,,,0.5784569978713989,1.8230026960372925,0.5387799739837646,2.005711555480957,50000.0,0.4292000234127044,2.636946439743042,10000.0,15581.740196228027,16989.820979833603,15581.740196228027,1405.1028938293457,1.087789535522461,0.0 -33700,1.6279681,2.8828788,,,,,,,,,,,,,, -33800,1.6341746,2.8037047,,,,,,,,,,,,,, -33900,1.2938776,4.7642593,,,,,,,,,,,,,, -34000,1.6245124,2.7421472,,,,,,,,,,,,,, -34100,1.3791178,5.1385055,,,,,,,,,,,,,, -34200,1.2007254,5.4293375,,,,,,,,,,,,,, -34300,1.5762763,2.7977698,,,,,,,,,,,,,, -34400,1.4889967,3.5275688,,,,,,,,,,,,,, -34500,1.4215648,3.5081658,,,,,,,,,,,,,, -34539,,,0.5900781154632568,1.7298932075500488,0.5426799654960632,1.94635820388794,50000.0,0.425100028514862,2.602484941482544,10000.0,16001.98867702484,17444.537808418274,16001.98867702484,1439.4901642799375,1.1178858280181885,0.0 -34600,1.5613424,3.0990753,,,,,,,,,,,,,, -34700,1.9664152,2.8559315,,,,,,,,,,,,,, -34800,1.5820713,2.7119558,,,,,,,,,,,,,, -34900,1.3003061,5.175588,,,,,,,,,,,,,, -35000,1.658272,2.7455416,,,,,,,,,,,,,, -35100,1.5983312,2.766315,,,,,,,,,,,,,, -35200,1.3096918,5.3910728,,,,,,,,,,,,,, -35300,1.3297175,5.1567802,,,,,,,,,,,,,, -35400,1.6030436,3.4425387,,,,,,,,,,,,,, -35449,,,0.5878710746765137,1.7523179054260254,0.5478000044822693,1.945866227149964,50000.0,0.433100014925003,2.59553599357605,10000.0,16422.16807770729,17901.24688887596,16422.16807770729,1475.937474489212,1.149817943572998,0.0 -35500,1.5344203,2.710189,,,,,,,,,,,,,, -35600,1.9281348,2.9556932,,,,,,,,,,,,,, -35700,1.5815694,2.703555,,,,,,,,,,,,,, -35800,1.3371769,4.3372793,,,,,,,,,,,,,, -35900,1.65361,3.1947112,,,,,,,,,,,,,, -36000,1.3599926,3.7684102,,,,,,,,,,,,,, -36100,1.7023541,2.7330434,,,,,,,,,,,,,, -36200,2.0378478,2.7741203,,,,,,,,,,,,,, -36300,1.3492488,4.58777,,,,,,,,,,,,,, -36359,,,0.5913280844688416,1.7497471570968628,0.5470399856567383,1.9463144540786743,50000.0,0.4359000325202942,2.588318109512329,10000.0,16842.540033340454,18359.965654611588,16842.540033340454,1514.2002153396606,1.1828999519348145,0.0 -36400,1.984558,2.838181,,,,,,,,,,,,,, -36500,1.8881189,2.9754815,,,,,,,,,,,,,, -36600,1.7190813,2.773356,,,,,,,,,,,,,, -36700,1.830337,2.7869859,,,,,,,,,,,,,, -36800,1.8265312,2.8966804,,,,,,,,,,,,,, -36900,1.7650964,2.9491067,,,,,,,,,,,,,, -37000,1.4946654,4.822768,,,,,,,,,,,,,, -37100,1.6317787,2.8707008,,,,,,,,,,,,,, -37200,1.4918858,4.688072,,,,,,,,,,,,,, -37270,,,0.5938476324081421,1.7227494716644287,0.5518400073051453,1.9233359098434448,50000.0,0.4330000281333923,2.5716423988342285,10000.0,17262.59283065796,18819.49461555481,17262.59283065796,1553.5912311077118,1.2173235416412354,0.0 -37300,1.7867521,3.0915372,,,,,,,,,,,,,, -37400,1.5787448,2.6612954,,,,,,,,,,,,,, -37500,1.5811323,2.7829833,,,,,,,,,,,,,, -37600,1.333849,4.428947,,,,,,,,,,,,,, -37700,1.6630857,2.6832483,,,,,,,,,,,,,, -37800,1.7005457,2.7376506,,,,,,,,,,,,,, -37900,1.1327873,5.153896,,,,,,,,,,,,,, -38000,1.5290327,3.2746687,,,,,,,,,,,,,, -38100,1.6938447,2.7828176,,,,,,,,,,,,,, -38182,,,0.6218359470367432,1.614800214767456,0.5530799627304077,1.933805584907532,50000.0,0.4394000172615051,2.5792202949523926,10000.0,17682.696738004684,19273.49109721184,17682.696738004684,1587.3950009346008,1.255021095275879,0.0 -38200,1.6328092,2.8984394,,,,,,,,,,,,,, -38300,1.9571934,5.455042,,,,,,,,,,,,,, -38400,1.7388511,2.7980983,,,,,,,,,,,,,, -38500,1.6974984,2.8229854,,,,,,,,,,,,,, -38600,1.2188994,5.012413,,,,,,,,,,,,,, -38700,1.6693603,3.197184,,,,,,,,,,,,,, -38800,1.3993025,4.3840857,,,,,,,,,,,,,, -38900,1.3809489,4.652706,,,,,,,,,,,,,, -39000,1.7844858,2.6579566,,,,,,,,,,,,,, -39093,,,0.5911718606948853,1.776065468788147,0.5468999743461609,1.9785503149032595,50000.0,0.4311000108718872,2.6091208457946777,10000.0,18102.84414291382,19728.99629807472,18102.84414291382,1622.6679162979126,1.2887091636657717,0.0 -39100,1.3709065,3.7552662,,,,,,,,,,,,,, -39200,1.4993587,5.4620695,,,,,,,,,,,,,, -39300,1.7651036,3.0162597,,,,,,,,,,,,,, -39400,1.4793842,5.318906,,,,,,,,,,,,,, -39500,1.3192011,5.4709816,,,,,,,,,,,,,, -39600,1.6134212,2.8672452,,,,,,,,,,,,,, -39700,1.5968301,2.7875986,,,,,,,,,,,,,, -39800,1.4252275,4.544666,,,,,,,,,,,,,, -39900,1.3807899,3.6519282,,,,,,,,,,,,,, -40000,1.7129354,2.992948,,,,,,,,,,,,,, -40001,,,0.6050390601158142,1.6630072593688965,0.5593199729919434,1.883844017982483,50000.0,0.4356000125408172,2.5480105876922607,10000.0,18522.978043079376,20184.561008930206,18522.978043079376,1658.0126931667328,1.3234169483184814,0.0 -40100,1.9492304,2.8428903,,,,,,,,,,,,,, -40200,1.2021704,5.330967,,,,,,,,,,,,,, -40300,1.8697718,2.764602,,,,,,,,,,,,,, -40400,1.8051897,2.7918003,,,,,,,,,,,,,, -40500,1.4395976,3.0885162,,,,,,,,,,,,,, -40600,1.8942683,2.7801745,,,,,,,,,,,,,, -40700,1.396852,3.738425,,,,,,,,,,,,,, -40800,1.6640033,2.636268,,,,,,,,,,,,,, -40900,1.7288529,2.9027257,,,,,,,,,,,,,, -40914,,,0.6123827695846558,1.6324337720870972,0.5535199642181396,1.9175604581832888,50000.0,0.4389000236988067,2.5651254653930664,10000.0,18943.06448030472,20642.2593035698,18943.06448030472,1695.5409202575684,1.355790376663208,0.0 -41000,1.6571009,2.7831142,,,,,,,,,,,,,, -41100,1.7780051,2.8576832,,,,,,,,,,,,,, -41200,1.6795939,3.152385,,,,,,,,,,,,,, -41300,1.2632855,5.2496095,,,,,,,,,,,,,, -41400,1.60368,2.6525424,,,,,,,,,,,,,, -41500,1.2572626,4.7054515,,,,,,,,,,,,,, -41600,1.8715419,2.9387496,,,,,,,,,,,,,, -41700,1.2755783,4.2897224,,,,,,,,,,,,,, -41800,1.7774627,2.826569,,,,,,,,,,,,,, -41821,,,0.5993554592132568,1.6984341144561768,0.5589599609375,1.893854022026062,50000.0,0.4374000132083893,2.5530781745910645,10000.0,19363.120608329773,21099.80816817284,19363.120608329773,1732.9484958648682,1.3894102573394775,0.0 -41900,1.5953587,3.0746427,,,,,,,,,,,,,, -42000,1.7386845,2.7066145,,,,,,,,,,,,,, -42100,1.7311615,2.949666,,,,,,,,,,,,,, -42200,1.8875563,2.83003,,,,,,,,,,,,,, -42300,1.2202463,4.836658,,,,,,,,,,,,,, -42400,1.5441697,3.603648,,,,,,,,,,,,,, -42500,1.6464595,2.7057292,,,,,,,,,,,,,, -42600,1.6257535,2.727677,,,,,,,,,,,,,, -42700,1.3652319,4.3054185,,,,,,,,,,,,,, -42732,,,0.6033788919448853,1.66139018535614,0.5579400062561035,1.870342493057251,50000.0,0.4466000199317932,2.5125553607940674,10000.0,19783.279280662537,21556.965598344803,19783.279280662537,1769.8629813194275,1.42185640335083,0.0 -42800,1.5860666,4.516199,,,,,,,,,,,,,, -42900,1.7430313,2.7955353,,,,,,,,,,,,,, -43000,1.7466079,2.6851664,,,,,,,,,,,,,, -43100,1.7022988,2.643727,,,,,,,,,,,,,, -43200,1.4735059,3.2375493,,,,,,,,,,,,,, -43300,1.7071408,2.84711,,,,,,,,,,,,,, -43400,1.6390951,2.754539,,,,,,,,,,,,,, -43500,1.9600055,2.7678576,,,,,,,,,,,,,, -43600,1.6885668,2.7891865,,,,,,,,,,,,,, -43642,,,0.6136132478713989,1.6449350118637085,0.5623199939727783,1.889671802520752,50000.0,0.4421000182628631,2.542107105255127,10000.0,20203.238243579865,22012.131967306137,20203.238243579865,1804.986863613129,1.4544131755828855,0.0 -43700,1.2274396,4.7890873,,,,,,,,,,,,,, -43800,1.8663632,2.8422368,,,,,,,,,,,,,, -43900,1.7413708,2.7987378,,,,,,,,,,,,,, -44000,1.7284131,2.7701278,,,,,,,,,,,,,, -44100,1.3983403,4.5768585,,,,,,,,,,,,,, -44200,1.6707911,2.6104875,,,,,,,,,,,,,, -44300,1.6843212,2.7961874,,,,,,,,,,,,,, -44400,1.4650003,4.1935234,,,,,,,,,,,,,, -44500,1.8428217,2.68499,,,,,,,,,,,,,, -44550,,,0.6031249761581421,1.6811929941177368,0.5647599697113037,1.867743134498596,50000.0,0.4440000355243683,2.530787944793701,10000.0,20623.36201238632,22466.73867583275,20623.36201238632,1839.381169557572,1.4913108348846436,0.0 -44600,1.8443552,2.7756815,,,,,,,,,,,,,, -44700,1.6049365,3.4213152,,,,,,,,,,,,,, -44800,1.621707,4.5242414,,,,,,,,,,,,,, -44900,1.7132803,2.6987514,,,,,,,,,,,,,, -45000,1.3498601,4.7983446,,,,,,,,,,,,,, -45100,1.5026194,4.523445,,,,,,,,,,,,,, -45200,1.4618397,5.407559,,,,,,,,,,,,,, -45300,1.7158103,2.9009013,,,,,,,,,,,,,, -45400,1.6311312,2.7880476,,,,,,,,,,,,,, -45461,,,0.6108788847923279,1.6417521238327026,0.567579984664917,1.847375750541687,50000.0,0.4482000172138214,2.4963107109069824,10000.0,21043.66229391098,22921.24355506897,21043.66229391098,1873.501362800598,1.5242390632629397,0.0 -45500,1.7252542,2.635551,,,,,,,,,,,,,, -45600,1.6780229,2.5788863,,,,,,,,,,,,,, -45700,1.2824881,5.1749086,,,,,,,,,,,,,, -45800,1.3342122,4.6539116,,,,,,,,,,,,,, -45900,1.6561664,2.7544217,,,,,,,,,,,,,, -46000,1.4921314,4.9041133,,,,,,,,,,,,,, -46100,1.7193873,2.7153454,,,,,,,,,,,,,, -46200,1.6659051,2.7241213,,,,,,,,,,,,,, -46300,1.6294681,3.5497246,,,,,,,,,,,,,, -46372,,,0.6156249642372131,1.6239218711853027,0.5699399709701538,1.8485853672027588,50000.0,0.4461000263690948,2.515526056289673,10000.0,21463.67381906509,23379.889157772064,21463.67381906509,1912.0473115444183,1.5611541271209717,0.0 -46400,1.7133671,2.6480296,,,,,,,,,,,,,, -46500,1.6622857,2.861866,,,,,,,,,,,,,, -46600,1.6475428,2.548046,,,,,,,,,,,,,, -46700,1.8800125,2.7597656,,,,,,,,,,,,,, -46800,1.2158401,5.1145926,,,,,,,,,,,,,, -46900,1.5302569,4.324436,,,,,,,,,,,,,, -47000,1.3993368,4.190272,,,,,,,,,,,,,, -47100,1.7086304,2.7367494,,,,,,,,,,,,,, -47200,1.7204189,2.6269598,,,,,,,,,,,,,, -47284,,,0.6090039014816284,1.6575889587402344,0.5688999891281128,1.851066827774048,50000.0,0.4461000263690948,2.4953420162200928,10000.0,21883.85037612915,23834.19975042343,21883.85037612915,1946.0973546504968,1.593156099319458,0.0 -47300,1.5912879,4.670757,,,,,,,,,,,,,, -47400,1.2763901,5.301762,,,,,,,,,,,,,, -47500,1.6515298,2.918499,,,,,,,,,,,,,, -47600,1.5881476,2.7576978,,,,,,,,,,,,,, -47700,1.5543593,3.9154506,,,,,,,,,,,,,, -47800,1.4995122,3.6263537,,,,,,,,,,,,,, -47900,1.8220615,2.68131,,,,,,,,,,,,,, -48000,1.6684135,3.7041476,,,,,,,,,,,,,, -48100,1.7586789,2.7457387,,,,,,,,,,,,,, -48197,,,0.611621081829071,1.6552009582519531,0.5723199844360352,1.836769938468933,50000.0,0.4580000340938568,2.478607177734375,10000.0,22304.195620536804,24291.240940332413,22304.195620536804,1982.7054772377007,1.629871845245361,0.0 -48200,1.7940243,2.9820592,,,,,,,,,,,,,, -48300,1.878441,2.8136592,,,,,,,,,,,,,, -48400,1.6910728,2.7659833,,,,,,,,,,,,,, -48500,1.8022894,2.5278077,,,,,,,,,,,,,, -48600,1.6695191,2.7389688,,,,,,,,,,,,,, -48700,1.9809449,2.6253767,,,,,,,,,,,,,, -48800,1.6427158,2.6161761,,,,,,,,,,,,,, -48900,1.6627506,3.5641823,,,,,,,,,,,,,, -49000,1.6519567,4.1617565,,,,,,,,,,,,,, -49100,1.746542,2.672263,,,,,,,,,,,,,, -49108,,,0.6163281202316284,1.6063055992126465,0.5687999725341797,1.8405550718307493,50000.0,0.4492000341415405,2.4828150272369385,10000.0,22724.1991057396,24752.58571910858,22724.1991057396,2023.9590499401093,1.666623592376709,0.0 -49200,1.6953058,2.669308,,,,,,,,,,,,,, -49300,1.440712,4.2126427,,,,,,,,,,,,,, -49400,1.9197828,2.725619,,,,,,,,,,,,,, -49500,1.6392517,2.504624,,,,,,,,,,,,,, -49600,1.5285975,5.25028,,,,,,,,,,,,,, -49700,1.9542751,2.6339953,,,,,,,,,,,,,, -49800,1.6881019,2.56412,,,,,,,,,,,,,, -49900,1.5732359,2.8557491,,,,,,,,,,,,,, -50000,1.825138,2.6812923,,,,,,,,,,,,,, -50019,,,0.6085546612739563,1.6156573295593262,0.5709399580955505,1.8096331357955933,50000.0,0.4542000293731689,2.4651927947998047,10000.0,23144.24242377281,25207.855845689774,23144.24242377281,2059.097989797592,1.7026152610778809,0.0 -50100,1.3414533,5.1705804,,,,,,,,,,,,,, -50200,1.8795847,2.6120481,,,,,,,,,,,,,, -50300,1.8040116,3.0626674,,,,,,,,,,,,,, -50400,1.3958262,5.268229,,,,,,,,,,,,,, -50500,1.924773,2.7814,,,,,,,,,,,,,, -50600,1.8366936,2.6900625,,,,,,,,,,,,,, -50700,1.5052459,3.5438108,,,,,,,,,,,,,, -50800,2.0102446,2.5073788,,,,,,,,,,,,,, -50900,1.9064841,2.5927348,,,,,,,,,,,,,, -50931,,,0.6145898103713989,1.6268950700759888,0.5734800100326538,1.8265553712844849,50000.0,0.4537000358104706,2.484137296676636,10000.0,23564.4644010067,25661.495080709457,23564.4644010067,2092.429440975189,1.737779140472412,0.0 -51000,1.7592221,2.680138,,,,,,,,,,,,,, -51100,1.8086977,2.575561,,,,,,,,,,,,,, -51200,1.4825432,4.648968,,,,,,,,,,,,,, -51300,1.855346,2.6162095,,,,,,,,,,,,,, -51400,1.730155,2.868468,,,,,,,,,,,,,, -51500,1.8572834,2.5990574,,,,,,,,,,,,,, -51600,1.3049332,5.275103,,,,,,,,,,,,,, -51700,1.6214559,3.513315,,,,,,,,,,,,,, -51800,1.3412251,4.870785,,,,,,,,,,,,,, -51842,,,0.619140625,1.61297345161438,0.5745399594306946,1.8239986896514893,50000.0,0.4580000340938568,2.465365409851074,10000.0,23984.826218366623,26122.27624464035,23984.826218366623,2132.76446890831,1.7713980674743652,0.0 -51900,1.9708791,2.526569,,,,,,,,,,,,,, -52000,2.086911,3.226871,,,,,,,,,,,,,, -52100,1.3680278,4.7265053,,,,,,,,,,,,,, -52200,1.8545259,2.445588,,,,,,,,,,,,,, -52300,1.8891215,2.556409,,,,,,,,,,,,,, -52400,1.5765878,3.8165216,,,,,,,,,,,,,, -52500,1.7005165,3.1740685,,,,,,,,,,,,,, -52600,1.5522556,4.8882837,,,,,,,,,,,,,, -52700,1.3798988,5.1983943,,,,,,,,,,,,,, -52752,,,0.6456249952316284,1.4840998649597168,0.5798799991607666,1.7855687141418457,50000.0,0.4611000120639801,2.4369454383850098,10000.0,24404.973826408383,26583.041292905807,24404.973826408383,2173.295962333679,1.8063812255859373,0.0 -52800,1.5452133,3.041666,,,,,,,,,,,,,, -52900,1.4183288,3.7670312,,,,,,,,,,,,,, -53000,1.8121082,2.8286138,,,,,,,,,,,,,, -53100,1.8814019,2.8396547,,,,,,,,,,,,,, -53200,1.7156299,2.8407135,,,,,,,,,,,,,, -53300,1.535369,3.239737,,,,,,,,,,,,,, -53400,1.5798736,3.0966563,,,,,,,,,,,,,, -53500,1.5123178,5.323857,,,,,,,,,,,,,, -53600,1.3987784,4.144755,,,,,,,,,,,,,, -53663,,,0.6150586009025574,1.626201033592224,0.5754199624061584,1.8177956342697144,50000.0,0.4559000134468078,2.470672130584717,10000.0,24825.06359386444,27037.99173426628,24825.06359386444,2208.069890022278,1.8420429229736328,0.0 -53700,1.6395054,2.669969,,,,,,,,,,,,,, -53800,1.4864106,4.9771366,,,,,,,,,,,,,, -53900,1.3720016,4.6086497,,,,,,,,,,,,,, -54000,1.4401853,3.7846482,,,,,,,,,,,,,, -54100,1.7562732,2.607356,,,,,,,,,,,,,, -54200,1.8767444,2.6885555,,,,,,,,,,,,,, -54300,1.601259,3.1657753,,,,,,,,,,,,,, -54400,1.9025977,2.8350515,,,,,,,,,,,,,, -54500,1.6734883,2.442162,,,,,,,,,,,,,, -54573,,,0.6263281106948853,1.5781675577163696,0.5829600095748901,1.770058512687683,50000.0,0.4609000086784363,2.430996894836426,10000.0,25245.0795879364,27497.753811597824,25245.0795879364,2247.7225930690765,1.88374662399292,0.0 -54600,1.6838363,2.6269946,,,,,,,,,,,,,, -54700,1.7985458,2.9416246,,,,,,,,,,,,,, -54800,1.6538886,3.5332382,,,,,,,,,,,,,, -54900,1.9846961,2.725306,,,,,,,,,,,,,, -55000,1.5585887,2.492642,,,,,,,,,,,,,, -55100,1.5136272,3.6830475,,,,,,,,,,,,,, -55200,1.9912786,2.6881216,,,,,,,,,,,,,, -55300,1.7444317,2.6995513,,,,,,,,,,,,,, -55400,1.8888285,2.6065624,,,,,,,,,,,,,, -55484,,,0.6351562142372131,1.5078805685043335,0.5780400037765503,1.778692603111267,50000.0,0.4611000120639801,2.447931289672852,10000.0,25665.39813661576,27958.75144600868,25665.39813661576,2288.314208507538,1.9204604625701904,0.0 -55500,1.3031441,5.2493553,,,,,,,,,,,,,, -55600,1.6739911,2.7938614,,,,,,,,,,,,,, -55700,1.4432834,4.0933332,,,,,,,,,,,,,, -55800,1.503126,4.1556478,,,,,,,,,,,,,, -55900,1.7759026,2.6113973,,,,,,,,,,,,,, -56000,1.3943694,5.040288,,,,,,,,,,,,,, -56100,1.7105744,2.8797016,,,,,,,,,,,,,, -56200,1.8685343,2.7497978,,,,,,,,,,,,,, -56300,1.359545,3.978483,,,,,,,,,,,,,, -56394,,,0.6251562237739563,1.554768204689026,0.5821200013160706,1.764073371887207,50000.0,0.4609000086784363,2.4148685932159424,10000.0,26085.56778717041,28418.65738463401,26085.56778717041,2327.961247444153,1.957954168319702,0.0 -56400,1.9045985,2.7257953,,,,,,,,,,,,,, -56500,1.7910981,3.2984834,,,,,,,,,,,,,, -56600,1.9368477,2.5494256,,,,,,,,,,,,,, -56700,1.9173331,2.5880635,,,,,,,,,,,,,, -56800,2.0587258,2.6685996,,,,,,,,,,,,,, -56900,2.0881865,2.6261952,,,,,,,,,,,,,, -57000,1.8113391,2.5643046,,,,,,,,,,,,,, -57100,1.9367458,2.6404357,,,,,,,,,,,,,, -57200,1.580144,3.1605313,,,,,,,,,,,,,, -57300,1.6270841,2.8034072,,,,,,,,,,,,,, -57304,,,0.6250194907188416,1.5971978902816772,0.5831599831581116,1.794812798500061,50000.0,0.4624000191688537,2.455163955688477,10000.0,26505.91107916832,28875.283742427822,26505.91107916832,2364.1540179252625,1.9971649646759035,0.0 -57400,1.6838974,2.9910178,,,,,,,,,,,,,, -57500,1.9230101,2.4499855,,,,,,,,,,,,,, -57600,1.4598475,4.1724634,,,,,,,,,,,,,, -57700,1.7675735,2.9538476,,,,,,,,,,,,,, -57800,1.710482,2.7790422,,,,,,,,,,,,,, -57900,1.7551391,2.4991655,,,,,,,,,,,,,, -58000,1.4803097,5.2406616,,,,,,,,,,,,,, -58100,1.2771069,5.181564,,,,,,,,,,,,,, -58200,1.8801743,2.5638592,,,,,,,,,,,,,, -58212,,,0.6395898461341858,1.5047837495803833,0.5877599716186523,1.7553316354751587,50000.0,0.4683000147342682,2.404117584228516,10000.0,26926.0757522583,29336.025983572006,26926.0757522583,2404.646691799164,2.0316531658172607,0.0 -58300,1.8649347,2.6793416,,,,,,,,,,,,,, -58400,1.3427411,4.552004,,,,,,,,,,,,,, -58500,1.8644696,2.7695532,,,,,,,,,,,,,, -58600,1.563057,4.833414,,,,,,,,,,,,,, -58700,1.984955,2.523409,,,,,,,,,,,,,, -58800,1.316956,4.7177215,,,,,,,,,,,,,, -58900,1.2388763,5.0080895,,,,,,,,,,,,,, -59000,1.8136778,2.6644268,,,,,,,,,,,,,, -59100,1.920592,2.4626265,,,,,,,,,,,,,, -59122,,,0.6267773509025574,1.5606385469436646,0.5853599905967712,1.7512341737747192,50000.0,0.4666000306606293,2.407226085662842,10000.0,27346.409598112103,29798.76133680344,27346.409598112103,2446.963950157165,2.0650646686553955,0.0 -59200,1.9325209,2.6008072,,,,,,,,,,,,,, -59300,1.3749084,4.4310827,,,,,,,,,,,,,, -59400,1.912842,2.6055467,,,,,,,,,,,,,, -59500,1.5441487,3.2563531,,,,,,,,,,,,,, -59600,2.0277665,2.635071,,,,,,,,,,,,,, -59700,1.9882114,2.9201162,,,,,,,,,,,,,, -59800,1.3882248,5.0610046,,,,,,,,,,,,,, -59900,1.9012437,2.6674304,,,,,,,,,,,,,, -60000,1.5167667,4.363447,,,,,,,,,,,,,, -60032,,,0.6316601634025574,1.5331214666366575,0.5899999737739563,1.7361007928848269,50000.0,0.4752000272274017,2.382453680038452,10000.0,27766.34730625153,30256.77499818802,27766.34730625153,2484.952459335327,2.1022024154663086,0.0 -60100,1.823226,2.6336217,,,,,,,,,,,,,, -60200,1.8011087,2.9416509,,,,,,,,,,,,,, -60300,1.9395298,2.503772,,,,,,,,,,,,,, -60400,1.825069,2.7742336,,,,,,,,,,,,,, -60500,1.7051821,3.4419744,,,,,,,,,,,,,, -60600,2.0422046,2.6556478,,,,,,,,,,,,,, -60700,1.5307221,3.8313384,,,,,,,,,,,,,, -60800,1.9157062,2.6010847,,,,,,,,,,,,,, -60900,1.5656612,5.155205,,,,,,,,,,,,,, -60941,,,0.6475585699081421,1.4857375621795654,0.5921599864959717,1.733931541442871,50000.0,0.4711000323295593,2.39124083518982,10000.0,28186.30168414116,30716.703328847885,28186.30168414116,2524.839050769806,2.1383426189422607,0.0 -61000,1.536297,3.2733467,,,,,,,,,,,,,, -61100,1.7864877,2.4027846,,,,,,,,,,,,,, -61200,2.2941995,2.5715632,,,,,,,,,,,,,, -61300,1.5054144,4.679489,,,,,,,,,,,,,, -61400,2.1724775,2.6362793,,,,,,,,,,,,,, -61500,2.102707,2.712834,,,,,,,,,,,,,, -61600,1.4717234,4.324958,,,,,,,,,,,,,, -61700,1.7163655,2.650218,,,,,,,,,,,,,, -61800,1.9966358,2.5878983,,,,,,,,,,,,,, -61851,,,0.6359570026397705,1.5239770412445068,0.5936599969863892,1.7153202295303345,50000.0,0.4680000245571136,2.374927759170532,10000.0,28606.5775911808,31177.174296855927,28606.5775911808,2564.9455330371857,2.175709009170532,0.0 -61900,1.7566676,2.945058,,,,,,,,,,,,,, -62000,1.6486399,4.4642625,,,,,,,,,,,,,, -62100,1.725396,2.6132376,,,,,,,,,,,,,, -62200,1.4715567,4.6294775,,,,,,,,,,,,,, -62300,1.6603166,2.9949403,,,,,,,,,,,,,, -62400,1.7440588,2.8281803,,,,,,,,,,,,,, -62500,1.5213163,3.5633545,,,,,,,,,,,,,, -62600,1.4682521,4.326314,,,,,,,,,,,,,, -62700,1.852948,2.781773,,,,,,,,,,,,,, -62761,,,0.6364062428474426,1.5226380825042725,0.592960000038147,1.7290226221084597,50000.0,0.4731000363826751,2.3883395195007324,10000.0,29026.926725625992,31640.543329954147,29026.926725625992,2607.876157522201,2.214055299758911,0.0 -62800,1.5329999,5.167841,,,,,,,,,,,,,, -62900,2.1286476,2.5194838,,,,,,,,,,,,,, -63000,2.0235686,2.5909986,,,,,,,,,,,,,, -63100,1.8760692,2.4395437,,,,,,,,,,,,,, -63200,1.7362618,4.5878124,,,,,,,,,,,,,, -63300,1.5605365,5.0470033,,,,,,,,,,,,,, -63400,1.5078063,5.127654,,,,,,,,,,,,,, -63500,2.0363586,2.4737957,,,,,,,,,,,,,, -63600,1.6830726,4.556839,,,,,,,,,,,,,, -63670,,,0.6468554735183716,1.4551808834075928,0.5985000133514404,1.690281867980957,50000.0,0.4786000251770019,2.3399128913879395,10000.0,29447.01801109314,32096.44493150711,29447.01801109314,2643.59489440918,2.254435777664185,0.0 -63700,1.9695206,2.6554565,,,,,,,,,,,,,, -63800,2.1406484,2.5384169,,,,,,,,,,,,,, -63900,1.527345,3.465512,,,,,,,,,,,,,, -64000,1.9421028,2.6049018,,,,,,,,,,,,,, -64100,1.8492194,2.565617,,,,,,,,,,,,,, -64200,1.7693993,3.8738482,,,,,,,,,,,,,, -64300,1.6797316,2.3795865,,,,,,,,,,,,,, -64400,2.2329133,2.6944826,,,,,,,,,,,,,, -64500,1.9847649,2.5455484,,,,,,,,,,,,,, -64580,,,0.6409569978713989,1.5243161916732788,0.5908799767494202,1.7485175132751465,50000.0,0.4704000353813171,2.3861141204833984,10000.0,29867.20395731926,32556.95909023285,29867.20395731926,2683.8333218097687,2.2939212322235107,0.0 -64600,1.9157809,2.6831524,,,,,,,,,,,,,, -64700,1.7092533,4.8938456,,,,,,,,,,,,,, -64800,2.0074449,2.63153,,,,,,,,,,,,,, -64900,1.8159046,3.104327,,,,,,,,,,,,,, -65000,1.7997004,2.7370138,,,,,,,,,,,,,, -65100,1.9048061,2.5458899,,,,,,,,,,,,,, -65200,1.8837577,2.5619082,,,,,,,,,,,,,, -65300,1.9116886,2.4756675,,,,,,,,,,,,,, -65400,2.0121603,2.53934,,,,,,,,,,,,,, -65489,,,0.63232421875,1.529699206352234,0.5950599908828735,1.7079485654830933,50000.0,0.4764000177383423,2.362998485565185,10000.0,30287.157656669617,33017.3158800602,30287.157656669617,2724.1434757709503,2.335206985473633,0.0 -65500,1.8849347,2.7145886,,,,,,,,,,,,,, -65600,1.6007338,3.7415318,,,,,,,,,,,,,, -65700,1.6242385,2.9961245,,,,,,,,,,,,,, -65800,1.9653229,2.4674869,,,,,,,,,,,,,, -65900,1.5055115,5.155983,,,,,,,,,,,,,, -66000,1.6290834,3.0695465,,,,,,,,,,,,,, -66100,1.446966,4.8224673,,,,,,,,,,,,,, -66200,1.3602306,4.3077,,,,,,,,,,,,,, -66300,1.6526759,3.4916587,,,,,,,,,,,,,, -66399,,,0.6452734470367432,1.4815714359283447,0.5956999659538269,1.7074819803237915,50000.0,0.4799000322818756,2.360117197036743,10000.0,30707.10316038132,33475.26572585106,30707.10316038132,2762.0575156211853,2.374621152877808,0.0 -66400,1.5393595,4.4760337,,,,,,,,,,,,,, -66500,1.5135854,5.020714,,,,,,,,,,,,,, -66600,1.8241553,2.497988,,,,,,,,,,,,,, -66700,2.0021188,2.5787735,,,,,,,,,,,,,, -66800,1.99572,2.5873964,,,,,,,,,,,,,, -66900,2.042955,2.5088775,,,,,,,,,,,,,, -67000,1.907834,2.5977738,,,,,,,,,,,,,, -67100,1.7355187,2.9685943,,,,,,,,,,,,,, -67200,1.7384708,2.809717,,,,,,,,,,,,,, -67300,1.8289031,2.6616673,,,,,,,,,,,,,, -67309,,,0.6687890291213989,1.3784257173538208,0.5956799983978271,1.7057280540466309,50000.0,0.4728000164031982,2.3525753021240234,10000.0,31127.178329706192,33937.002141714096,31127.178329706192,2803.6267223358154,2.4154183864593506,0.0 -67400,1.7748821,4.5035315,,,,,,,,,,,,,, -67500,1.8736043,2.589116,,,,,,,,,,,,,, -67600,1.7255094,3.9069605,,,,,,,,,,,,,, -67700,2.014597,2.7806714,,,,,,,,,,,,,, -67800,1.9919984,2.4920342,,,,,,,,,,,,,, -67900,1.4886584,3.724688,,,,,,,,,,,,,, -68000,1.8280116,2.7823908,,,,,,,,,,,,,, -68100,1.9254345,2.529502,,,,,,,,,,,,,, -68200,2.0617692,2.5133426,,,,,,,,,,,,,, -68219,,,0.6403710842132568,1.5155787467956543,0.5939199924468994,1.7145018577575684,50000.0,0.4745000302791595,2.364424467086792,10000.0,31547.497469186783,34399.2928917408,31547.497469186783,2845.501267194748,2.461395263671875,0.0 -68300,1.6521418,4.077239,,,,,,,,,,,,,, -68400,2.0029225,2.5296736,,,,,,,,,,,,,, -68500,1.559731,4.3011165,,,,,,,,,,,,,, -68600,1.4177157,5.083455,,,,,,,,,,,,,, -68700,1.7397588,3.267697,,,,,,,,,,,,,, -68800,1.9144063,2.4826896,,,,,,,,,,,,,, -68900,2.0621865,2.6624417,,,,,,,,,,,,,, -69000,1.5784978,3.482864,,,,,,,,,,,,,, -69100,1.3876514,4.5036774,,,,,,,,,,,,,, -69130,,,0.6488866806030273,1.4620481729507446,0.6020599603652954,1.6721429824829102,50000.0,0.4783000349998474,2.311363458633423,10000.0,31967.57464170456,34859.705701351166,31967.57464170456,2885.74742937088,2.4994466304779053,0.0 -69200,1.492345,3.997345,,,,,,,,,,,,,, -69300,1.720865,2.9520473,,,,,,,,,,,,,, -69400,1.5657765,4.4258556,,,,,,,,,,,,,, -69500,1.7743917,2.8714046,,,,,,,,,,,,,, -69600,1.582601,3.435451,,,,,,,,,,,,,, -69700,1.8826721,2.6175704,,,,,,,,,,,,,, -69800,1.867959,2.3606138,,,,,,,,,,,,,, -69900,1.9913172,2.352653,,,,,,,,,,,,,, -70000,1.806622,3.5090382,,,,,,,,,,,,,, -70042,,,0.6694726347923279,1.3626357316970823,0.6027399897575378,1.665413737297058,50000.0,0.4807000160217285,2.3237833976745605,10000.0,32387.68676304817,35320.11451506615,32387.68676304817,2925.956077575684,2.5358903408050537,0.0 -70100,1.9168781,2.5122616,,,,,,,,,,,,,, -70200,1.6340027,3.2262113,,,,,,,,,,,,,, -70300,1.7923404,2.462511,,,,,,,,,,,,,, -70400,1.6626997,4.106618,,,,,,,,,,,,,, -70500,1.9471576,4.973587,,,,,,,,,,,,,, -70600,2.0712428,2.5227365,,,,,,,,,,,,,, -70700,1.8104792,2.5126688,,,,,,,,,,,,,, -70800,2.039227,2.4915283,,,,,,,,,,,,,, -70900,1.7177495,4.0860095,,,,,,,,,,,,,, -70951,,,0.6404492259025574,1.498771071434021,0.6004399657249451,1.6853318214416504,50000.0,0.4836000204086303,2.3453047275543213,10000.0,32807.90108847618,35780.31765007973,32807.90108847618,2965.8574674129486,2.5720467567443848,0.0 -71000,1.8913509,2.8300683,,,,,,,,,,,,,, -71100,1.9942638,5.043316,,,,,,,,,,,,,, -71200,2.1296763,2.520159,,,,,,,,,,,,,, -71300,1.4822519,3.9768405,,,,,,,,,,,,,, -71400,2.111579,2.536354,,,,,,,,,,,,,, -71500,2.2461102,2.5525296,,,,,,,,,,,,,, -71600,1.560946,3.7140813,,,,,,,,,,,,,, -71700,1.9482841,2.4451022,,,,,,,,,,,,,, -71800,1.6677021,3.5010283,,,,,,,,,,,,,, -71860,,,0.6525781154632568,1.4528508186340332,0.604699969291687,1.674795389175415,50000.0,0.4808000326156616,2.3266289234161377,10000.0,33228.07022809982,36241.46924185753,33228.07022809982,3006.7480852603912,2.6135470867156982,0.0 -71900,1.8877068,2.8841407,,,,,,,,,,,,,, -72000,1.5755537,3.2620087,,,,,,,,,,,,,, -72100,1.6625205,3.0913796,,,,,,,,,,,,,, -72200,1.950402,2.6655164,,,,,,,,,,,,,, -72300,2.2273512,2.483082,,,,,,,,,,,,,, -72400,1.9997447,2.5857146,,,,,,,,,,,,,, -72500,1.6918986,4.8146048,,,,,,,,,,,,,, -72600,1.9343303,2.4411497,,,,,,,,,,,,,, -72700,1.512987,4.8293786,,,,,,,,,,,,,, -72770,,,0.6627148389816284,1.4005804061889648,0.6051599979400635,1.6614387035369873,50000.0,0.4833000302314758,2.316112756729126,10000.0,33648.16423654556,36698.63916397095,33648.16423654556,3043.734041452408,2.652384042739868,0.0 -72800,1.9294506,2.554411,,,,,,,,,,,,,, -72900,1.7223192,4.5052137,,,,,,,,,,,,,, -73000,1.762363,3.328615,,,,,,,,,,,,,, -73100,1.9423518,2.381424,,,,,,,,,,,,,, -73200,1.8248651,2.4421177,,,,,,,,,,,,,, -73300,2.131548,2.3569603,,,,,,,,,,,,,, -73400,1.9003702,2.4024498,,,,,,,,,,,,,, -73500,1.8635161,2.8863223,,,,,,,,,,,,,, -73600,1.8393495,2.4211485,,,,,,,,,,,,,, -73682,,,0.6526562571525574,1.441790223121643,0.6084200143814087,1.6414271593093872,50000.0,0.4877000153064728,2.2946367263793945,10000.0,34068.218707084656,37159.42027497292,34068.218707084656,3084.372179746628,2.68951416015625,0.0 -73700,1.8611859,2.5106997,,,,,,,,,,,,,, -73800,1.8508404,2.493032,,,,,,,,,,,,,, -73900,1.5990311,4.445746,,,,,,,,,,,,,, -74000,2.082964,3.4143927,,,,,,,,,,,,,, -74100,2.015884,2.7368379,,,,,,,,,,,,,, -74200,1.6955984,5.107168,,,,,,,,,,,,,, -74300,2.1245108,2.4240732,,,,,,,,,,,,,, -74400,1.9870776,2.574022,,,,,,,,,,,,,, -74500,2.01962,2.629094,,,,,,,,,,,,,, -74591,,,0.6527929306030273,1.446118950843811,0.6092000007629395,1.650182604789734,50000.0,0.48580002784729,2.3193719387054443,10000.0,34488.36898994446,37617.76039338112,34488.36898994446,3122.4718701839447,2.7291154861450195,0.0 -74600,2.1141796,2.5732267,,,,,,,,,,,,,, -74700,1.9756747,2.5337262,,,,,,,,,,,,,, -74800,2.030619,2.3026648,,,,,,,,,,,,,, -74900,1.8172178,3.4002404,,,,,,,,,,,,,, -75000,1.9274327,2.4363894,,,,,,,,,,,,,, -75100,1.8461248,2.502933,,,,,,,,,,,,,, -75200,1.813464,2.8353097,,,,,,,,,,,,,, -75300,1.560026,4.641141,,,,,,,,,,,,,, -75400,1.9575088,2.7747016,,,,,,,,,,,,,, -75500,1.9811015,2.3208356,,,,,,,,,,,,,, -75503,,,0.6614648103713989,1.3912386894226074,0.606719970703125,1.632909655570984,50000.0,0.4943000376224518,2.2832729816436768,10000.0,34908.69756317139,38077.84191131592,34908.69756317139,3162.1323087215424,2.7706854343414307,0.0 -75600,1.9357227,2.4767733,,,,,,,,,,,,,, -75700,1.575106,5.058093,,,,,,,,,,,,,, -75800,1.9678053,2.6674001,,,,,,,,,,,,,, -75900,1.5871556,4.439356,,,,,,,,,,,,,, -76000,1.8203967,2.4241428,,,,,,,,,,,,,, -76100,1.8320683,2.5885434,,,,,,,,,,,,,, -76200,1.9764802,2.4527879,,,,,,,,,,,,,, -76300,1.9442916,2.53926,,,,,,,,,,,,,, -76400,1.8577508,2.344494,,,,,,,,,,,,,, -76412,,,0.6544336080551147,1.4055346250534058,0.6148200035095215,1.6018518209457395,50000.0,0.4888000190258026,2.271522760391236,10000.0,35328.6748585701,38533.69681978226,35328.6748585701,3197.922151327133,2.8070130348205566,0.0 -76500,2.295555,2.446481,,,,,,,,,,,,,, -76600,1.989992,2.575156,,,,,,,,,,,,,, -76700,1.8915856,2.3773263,,,,,,,,,,,,,, -76800,1.7258902,4.9068055,,,,,,,,,,,,,, -76900,2.0170908,2.442694,,,,,,,,,,,,,, -77000,1.5243404,5.016209,,,,,,,,,,,,,, -77100,1.6515852,4.3839664,,,,,,,,,,,,,, -77200,1.9601954,4.827261,,,,,,,,,,,,,, -77300,1.5889448,4.18512,,,,,,,,,,,,,, -77323,,,0.6539453268051147,1.4418593645095823,0.6126799583435059,1.63576340675354,50000.0,0.48580002784729,2.295750617980957,10000.0,35748.62918996811,38993.28647398949,35748.62918996811,3237.468738079071,2.8454105854034424,0.0 -77400,1.9059725,2.419179,,,,,,,,,,,,,, -77500,1.6158991,4.4780283,,,,,,,,,,,,,, -77600,1.6487097,3.2966666,,,,,,,,,,,,,, -77700,1.9464507,2.4538693,,,,,,,,,,,,,, -77800,1.6824421,3.8014216,,,,,,,,,,,,,, -77900,2.0054543,2.5190709,,,,,,,,,,,,,, -78000,2.4140997,2.4369502,,,,,,,,,,,,,, -78100,2.0003638,2.4052923,,,,,,,,,,,,,, -78200,2.03275,2.6465688,,,,,,,,,,,,,, -78235,,,0.665722668170929,1.3743427991867063,0.6133399605751038,1.6258063316345217,50000.0,0.4935000240802765,2.2690629959106445,10000.0,36168.83649253845,39449.87701368332,36168.83649253845,3273.759335756302,2.88728928565979,0.0 -78300,1.6811911,4.5233216,,,,,,,,,,,,,, -78400,1.9620273,2.3573692,,,,,,,,,,,,,, -78500,1.6316259,4.7776556,,,,,,,,,,,,,, -78600,2.0887697,2.4586554,,,,,,,,,,,,,, -78700,1.968734,2.6410334,,,,,,,,,,,,,, -78800,2.2068694,2.4715588,,,,,,,,,,,,,, -78900,1.9358071,2.512922,,,,,,,,,,,,,, -79000,1.8908657,2.9536662,,,,,,,,,,,,,, -79100,2.017804,2.4311624,,,,,,,,,,,,,, -79146,,,0.6544336080551147,1.4474139213562012,0.6101999878883362,1.6504663228988647,50000.0,0.4942000210285187,2.2950570583343506,10000.0,36589.174001932144,39909.59665203095,36589.174001932144,3313.05346608162,2.9241793155670166,0.0 -79200,1.9281669,2.6988134,,,,,,,,,,,,,, -79300,1.702232,3.7816863,,,,,,,,,,,,,, -79400,1.9939076,2.7211375,,,,,,,,,,,,,, -79500,1.7976454,3.309289,,,,,,,,,,,,,, -79600,1.821813,2.8588078,,,,,,,,,,,,,, -79700,1.8694352,2.3899019,,,,,,,,,,,,,, -79800,2.161076,2.4243093,,,,,,,,,,,,,, -79900,2.1138139,2.4262388,,,,,,,,,,,,,, -80000,1.7259854,3.7289233,,,,,,,,,,,,,, -80055,,,0.6633593440055847,1.3997061252593994,0.6201199889183044,1.590294361114502,50000.0,0.4994000196456909,2.22688364982605,10000.0,37009.08654427528,40366.136556625366,37009.08654427528,3349.5856976509094,2.9685537815093994,0.0 -80100,2.00108,2.571973,,,,,,,,,,,,,, -80200,1.6364243,4.485181,,,,,,,,,,,,,, -80300,1.8008368,2.7187808,,,,,,,,,,,,,, -80400,2.0532196,2.6336467,,,,,,,,,,,,,, -80500,1.9893676,2.4244218,,,,,,,,,,,,,, -80600,1.9929996,2.3189359,,,,,,,,,,,,,, -80700,1.8323827,2.630984,,,,,,,,,,,,,, -80800,2.0388188,2.5190196,,,,,,,,,,,,,, -80900,1.9046855,2.443984,,,,,,,,,,,,,, -80968,,,0.6633203029632568,1.383533239364624,0.6125400066375732,1.6144942045211792,50000.0,0.4915000200271606,2.279834032058716,10000.0,37429.40017914772,40821.89309167862,37429.40017914772,3384.9342653751373,3.0103461742401123,0.0 -81000,2.129009,2.4135003,,,,,,,,,,,,,, -81100,1.9667357,3.1215017,,,,,,,,,,,,,, -81200,2.2314086,2.3894851,,,,,,,,,,,,,, -81300,2.0712807,2.3886986,,,,,,,,,,,,,, -81400,1.9364039,2.4187052,,,,,,,,,,,,,, -81500,2.037444,2.38345,,,,,,,,,,,,,, -81600,2.0240777,2.8313649,,,,,,,,,,,,,, -81700,1.6174678,4.929395,,,,,,,,,,,,,, -81800,1.9096859,2.3361793,,,,,,,,,,,,,, -81878,,,0.6926171779632568,1.2652881145477295,0.6225599646568298,1.5752804279327393,50000.0,0.4943000376224518,2.2313928604125977,10000.0,37849.3330321312,41281.0299179554,37849.3330321312,3424.0457706451416,3.052287578582764,0.0 -81900,2.1953819,2.4420896,,,,,,,,,,,,,, -82000,1.7230664,3.252509,,,,,,,,,,,,,, -82100,2.464428,2.4513793,,,,,,,,,,,,,, -82200,1.6582408,4.0718856,,,,,,,,,,,,,, -82300,2.1247838,2.331684,,,,,,,,,,,,,, -82400,2.0428746,2.3562949,,,,,,,,,,,,,, -82500,1.7345707,4.9757824,,,,,,,,,,,,,, -82600,2.0330837,2.4947915,,,,,,,,,,,,,, -82700,2.060842,2.2859416,,,,,,,,,,,,,, -82788,,,0.6661523580551147,1.387743592262268,0.6181600093841553,1.5939165353775024,50000.0,0.4970000088214874,2.224552154541016,10000.0,38269.31256151199,41738.09440588951,38269.31256151199,3461.037127256393,3.0954391956329346,0.0 -82800,2.164855,2.2926583,,,,,,,,,,,,,, -82900,1.8975282,3.453058,,,,,,,,,,,,,, -83000,2.143551,2.4877,,,,,,,,,,,,,, -83100,1.8093524,4.4982085,,,,,,,,,,,,,, -83200,1.7777278,4.1342297,,,,,,,,,,,,,, -83300,1.753518,4.41823,,,,,,,,,,,,,, -83400,1.8923426,2.882932,,,,,,,,,,,,,, -83500,1.6966574,3.4267778,,,,,,,,,,,,,, -83600,1.7746679,4.238009,,,,,,,,,,,,,, -83698,,,0.6722851395606995,1.3354723453521729,0.6222999691963196,1.5551286935806274,50000.0,0.5004000067710876,2.207161903381348,10000.0,38689.46293449402,42197.38641667366,38689.46293449402,3500.086168289185,3.1364245414733887,0.0 -83700,2.2238522,2.3808415,,,,,,,,,,,,,, -83800,2.0639877,2.2913618,,,,,,,,,,,,,, -83900,1.6920495,4.023455,,,,,,,,,,,,,, -84000,1.9428613,2.4302409,,,,,,,,,,,,,, -84100,1.7623533,3.080134,,,,,,,,,,,,,, -84200,1.7051566,3.11094,,,,,,,,,,,,,, -84300,2.010363,2.4244843,,,,,,,,,,,,,, -84400,2.1418052,2.2777493,,,,,,,,,,,,,, -84500,1.9719658,2.8370855,,,,,,,,,,,,,, -84600,1.773105,4.2459645,,,,,,,,,,,,,, -84609,,,0.6911913752555847,1.2636808156967163,0.6240800023078918,1.56609046459198,50000.0,0.5009000301361084,2.2192747592926025,10000.0,39109.49723100662,42656.90248131752,39109.49723100662,3539.470131397248,3.183518648147583,0.0 -84700,1.7747642,4.1186876,,,,,,,,,,,,,, -84800,2.0114026,2.506188,,,,,,,,,,,,,, -84900,2.0598617,2.3339956,,,,,,,,,,,,,, -85000,1.9310671,4.007517,,,,,,,,,,,,,, -85100,1.7215196,3.5805526,,,,,,,,,,,,,, -85200,1.8151237,4.885353,,,,,,,,,,,,,, -85300,1.7731822,2.8583727,,,,,,,,,,,,,, -85400,2.2108514,2.521857,,,,,,,,,,,,,, -85500,2.1251016,2.7562394,,,,,,,,,,,,,, -85519,,,0.6674023270606995,1.393503189086914,0.6214599609375,1.5956655740737915,50000.0,0.5010000467300415,2.2568812370300293,10000.0,39529.7694542408,43119.98700237274,39529.7694542408,3582.189652442932,3.224731683731079,0.0 -85600,2.3013957,2.4224741,,,,,,,,,,,,,, -85700,2.293868,2.6095345,,,,,,,,,,,,,, -85800,1.9521737,3.177922,,,,,,,,,,,,,, -85900,2.0604618,2.6746902,,,,,,,,,,,,,, -86000,1.7425835,4.538514,,,,,,,,,,,,,, -86100,1.9693011,2.433979,,,,,,,,,,,,,, -86200,1.9584486,2.979013,,,,,,,,,,,,,, -86300,2.075599,2.3963664,,,,,,,,,,,,,, -86400,1.8844395,3.365381,,,,,,,,,,,,,, -86427,,,0.6750780940055847,1.347376823425293,0.6284599900245667,1.55901837348938,50000.0,0.5054000020027161,2.208587408065796,10000.0,39950.09155750275,43578.21590304375,39950.09155750275,3620.007707118988,3.262495994567871,0.0 -86500,1.8604532,3.078366,,,,,,,,,,,,,, -86600,2.1199627,2.3007157,,,,,,,,,,,,,, -86700,1.8369594,3.3647914,,,,,,,,,,,,,, -86800,2.171396,2.5071132,,,,,,,,,,,,,, -86900,1.8384954,5.0313735,,,,,,,,,,,,,, -87000,2.1316543,2.2071373,,,,,,,,,,,,,, -87100,2.1162899,4.951618,,,,,,,,,,,,,, -87200,1.97592,2.485434,,,,,,,,,,,,,, -87300,1.5756905,4.507495,,,,,,,,,,,,,, -87338,,,0.6869726181030273,1.2832924127578735,0.6260600090026855,1.5575437545776367,50000.0,0.503000020980835,2.1890664100646973,10000.0,40370.41019010544,44038.112585783005,40370.41019010544,3659.488579750061,3.3086326122283936,0.0 -87400,2.2955701,2.3581846,,,,,,,,,,,,,, -87500,1.8632336,2.9016716,,,,,,,,,,,,,, -87600,1.9964052,2.8169954,,,,,,,,,,,,,, -87700,2.171495,2.309449,,,,,,,,,,,,,, -87800,2.0504403,2.3404772,,,,,,,,,,,,,, -87900,1.9332674,2.2298448,,,,,,,,,,,,,, -88000,2.2077453,2.3648922,,,,,,,,,,,,,, -88100,1.6593555,3.4459805,,,,,,,,,,,,,, -88200,2.0095744,2.8802872,,,,,,,,,,,,,, -88245,,,0.6767382621765137,1.3393456935882568,0.6292399764060974,1.5438076257705688,50000.0,0.5083000063896179,2.1848161220550537,10000.0,40790.58680820465,44496.99372458458,40790.58680820465,3698.100270032882,3.3508591651916504,0.0 -88300,1.9921714,3.097447,,,,,,,,,,,,,, -88400,1.7379369,3.622111,,,,,,,,,,,,,, -88500,2.1513789,2.367212,,,,,,,,,,,,,, -88600,2.444734,2.367032,,,,,,,,,,,,,, -88700,2.4598258,4.292954,,,,,,,,,,,,,, -88800,2.0457988,2.2998767,,,,,,,,,,,,,, -88900,2.3538,2.4421446,,,,,,,,,,,,,, -89000,1.8789291,3.2426283,,,,,,,,,,,,,, -89100,1.8573458,3.2052398,,,,,,,,,,,,,, -89155,,,0.6765820384025574,1.3568689823150637,0.6251199841499329,1.592879056930542,50000.0,0.4978000223636627,2.233126878738404,10000.0,41210.69068980217,44953.37469482422,41210.69068980217,3734.2811844348894,3.396424531936645,0.0 -89200,2.075542,2.3594742,,,,,,,,,,,,,, -89300,2.1661017,2.4962273,,,,,,,,,,,,,, -89400,1.750301,3.8102503,,,,,,,,,,,,,, -89500,2.1618583,2.3587184,,,,,,,,,,,,,, -89600,1.6564292,4.8581514,,,,,,,,,,,,,, -89700,1.863582,3.3755255,,,,,,,,,,,,,, -89800,1.7317111,2.754301,,,,,,,,,,,,,, -89900,2.1272593,2.398138,,,,,,,,,,,,,, -90000,2.0860653,2.386715,,,,,,,,,,,,,, -90066,,,0.6912695169448853,1.2604010105133057,0.6340000033378601,1.5193389654159546,50000.0,0.508400022983551,2.170599460601806,10000.0,41631.16453623772,45414.1202609539,41631.16453623772,3774.460196733474,3.4379467964172363,0.0 -90100,1.8158253,2.8243027,,,,,,,,,,,,,, -90200,2.0576813,2.7363806,,,,,,,,,,,,,, -90300,2.266749,2.4513752,,,,,,,,,,,,,, -90400,1.7550405,3.5717068,,,,,,,,,,,,,, -90500,2.344426,2.268132,,,,,,,,,,,,,, -90600,1.7746308,3.8667912,,,,,,,,,,,,,, -90700,1.8571424,4.04774,,,,,,,,,,,,,, -90800,2.3591084,2.2889252,,,,,,,,,,,,,, -90900,2.1219738,2.297681,,,,,,,,,,,,,, -90976,,,0.6769921779632568,1.3207937479019165,0.6307799816131592,1.5352376699447632,50000.0,0.5042999982833862,2.1914098262786865,10000.0,42051.34705209732,45868.677743434906,42051.34705209732,3808.74009847641,3.4811456203460693,0.0 -91000,1.8139321,3.3657243,,,,,,,,,,,,,, -91100,1.6467489,3.5310485,,,,,,,,,,,,,, -91200,1.7444587,4.2592893,,,,,,,,,,,,,, -91300,2.1795497,2.5899742,,,,,,,,,,,,,, -91400,2.0115027,4.439393,,,,,,,,,,,,,, -91500,1.6685606,4.232569,,,,,,,,,,,,,, -91600,2.1022909,2.5029552,,,,,,,,,,,,,, -91700,2.1740232,2.3937428,,,,,,,,,,,,,, -91800,2.0360956,2.2829974,,,,,,,,,,,,,, -91885,,,0.6859374642372131,1.277827262878418,0.6394599676132202,1.4927690029144287,50000.0,0.5106000304222107,2.133218050003052,10000.0,42471.574355363846,46325.46682262421,42471.574355363846,3845.205990314484,3.5260469913482666,0.0 -91900,1.8078814,3.8908837,,,,,,,,,,,,,, -92000,2.2108428,2.2640405,,,,,,,,,,,,,, -92100,1.9714714,2.6660633,,,,,,,,,,,,,, -92200,1.8022314,4.646629,,,,,,,,,,,,,, -92300,2.063649,2.1881702,,,,,,,,,,,,,, -92400,2.1160476,2.5274675,,,,,,,,,,,,,, -92500,2.4460428,2.1810923,,,,,,,,,,,,,, -92600,2.060929,2.3232756,,,,,,,,,,,,,, -92700,1.7874918,4.6164675,,,,,,,,,,,,,, -92793,,,0.6937304735183716,1.268678903579712,0.6373999714851379,1.5291402339935305,50000.0,0.5193000435829163,2.167919874191284,10000.0,42891.8378636837,46784.998166799545,42891.8378636837,3884.377197265625,3.5714974403381348,0.0 -92800,2.050287,2.654745,,,,,,,,,,,,,, -92900,1.9976331,2.3233929,,,,,,,,,,,,,, -93000,1.9228191,4.9435077,,,,,,,,,,,,,, -93100,2.4718127,2.297248,,,,,,,,,,,,,, -93200,2.1157873,2.6787262,,,,,,,,,,,,,, -93300,2.1502116,4.3749943,,,,,,,,,,,,,, -93400,2.1458302,2.1816864,,,,,,,,,,,,,, -93500,2.1883235,2.235424,,,,,,,,,,,,,, -93600,1.916912,3.1521037,,,,,,,,,,,,,, -93700,,,0.684374988079071,1.302620768547058,0.6352800130844116,1.5236302614212036,50000.0,0.5074000358581543,2.174853563308716,10000.0,43311.91257548332,47240.98172211647,43311.91257548332,3920.189643383026,3.617255449295044,0.0 -93700,2.088538,2.3102374,,,,,,,,,,,,,, -93800,1.6556972,4.4127607,,,,,,,,,,,,,, -93900,1.8383044,3.615847,,,,,,,,,,,,,, -94000,2.2275324,2.236187,,,,,,,,,,,,,, -94100,2.1793425,2.3018522,,,,,,,,,,,,,, -94200,2.0714028,2.3087852,,,,,,,,,,,,,, -94300,2.2504816,2.318582,,,,,,,,,,,,,, -94400,2.012848,2.8852663,,,,,,,,,,,,,, -94500,2.3992238,2.360674,,,,,,,,,,,,,, -94600,2.2689822,2.3104467,,,,,,,,,,,,,, -94609,,,0.6906640529632568,1.2588465213775637,0.6386399865150452,1.485780119895935,50000.0,0.5190000534057617,2.115467071533203,10000.0,43732.20620751381,47700.05508303642,43732.20620751381,3958.876034975052,3.6599841117858887,0.0 -94700,1.7068652,4.765961,,,,,,,,,,,,,, -94800,2.218803,2.3174148,,,,,,,,,,,,,, -94900,2.1438773,2.3721466,,,,,,,,,,,,,, -95000,2.0306578,2.374915,,,,,,,,,,,,,, -95100,2.18364,2.123113,,,,,,,,,,,,,, -95200,2.1761923,2.4192433,,,,,,,,,,,,,, -95300,1.9373693,4.9359856,,,,,,,,,,,,,, -95400,1.8238722,4.154378,,,,,,,,,,,,,, -95500,1.8631135,3.2246108,,,,,,,,,,,,,, -95519,,,0.6943945288658142,1.260308861732483,0.6425999999046326,1.499654769897461,50000.0,0.516800045967102,2.149538278579712,10000.0,44152.35373473168,48157.12686371803,44152.35373473168,3995.7055275440216,3.703071117401123,0.0 -95600,2.8141656,2.2180147,,,,,,,,,,,,,, -95700,2.4550488,2.1932359,,,,,,,,,,,,,, -95800,2.246812,2.659143,,,,,,,,,,,,,, -95900,2.1093392,4.787561,,,,,,,,,,,,,, -96000,1.8327633,4.150689,,,,,,,,,,,,,, -96100,2.3302941,2.2859447,,,,,,,,,,,,,, -96200,2.144699,2.6367953,,,,,,,,,,,,,, -96300,2.547133,2.4437706,,,,,,,,,,,,,, -96400,2.0979018,4.7579303,,,,,,,,,,,,,, -96429,,,0.7122656106948853,1.174549221992493,0.6404199600219727,1.485430121421814,50000.0,0.5157000422477722,2.13046932220459,10000.0,44572.31914424896,48614.92049980164,44572.31914424896,4033.443476676941,3.7422258853912354,0.0 -96500,2.1439538,2.2108135,,,,,,,,,,,,,, -96600,1.8950802,3.943699,,,,,,,,,,,,,, -96700,2.3743086,2.2723548,,,,,,,,,,,,,, -96800,1.8604437,3.7359436,,,,,,,,,,,,,, -96900,2.4525774,2.2697663,,,,,,,,,,,,,, -97000,2.2127714,2.2746718,,,,,,,,,,,,,, -97100,2.0132983,3.3583453,,,,,,,,,,,,,, -97200,1.9635416,3.9684894,,,,,,,,,,,,,, -97300,2.1114478,2.2660568,,,,,,,,,,,,,, -97339,,,0.6898632645606995,1.2658751010894775,0.6431399583816528,1.4787806272506714,50000.0,0.5144000053405762,2.127562522888184,10000.0,44992.56587338448,49074.64042210579,44992.56587338448,4072.826028347016,3.780719995498657,0.0 -97400,2.4104886,2.2700217,,,,,,,,,,,,,, -97500,2.2807415,2.2739916,,,,,,,,,,,,,, -97600,2.1714265,2.2838948,,,,,,,,,,,,,, -97700,2.3223007,2.2691064,,,,,,,,,,,,,, -97800,2.1453803,2.2396443,,,,,,,,,,,,,, -97900,2.3449173,2.2951565,,,,,,,,,,,,,, -98000,2.3104224,2.183446,,,,,,,,,,,,,, -98100,2.1721532,2.3214884,,,,,,,,,,,,,, -98200,1.9149586,4.16228,,,,,,,,,,,,,, -98248,,,0.6936132907867432,1.285739541053772,0.6411199569702148,1.5151735544204712,50000.0,0.520300030708313,2.1551709175109863,10000.0,45412.90539312363,49531.960334301,45412.90539312363,4109.709174156189,3.827104806900024,0.0 -98300,2.1147516,2.5728014,,,,,,,,,,,,,, -98400,2.3037364,2.2918923,,,,,,,,,,,,,, -98500,2.1855767,2.0990982,,,,,,,,,,,,,, -98600,1.9419391,4.7384377,,,,,,,,,,,,,, -98700,2.2633448,2.1950467,,,,,,,,,,,,,, -98800,2.5949016,2.2557125,,,,,,,,,,,,,, -98900,2.1630335,2.1978083,,,,,,,,,,,,,, -99000,2.3338103,2.242234,,,,,,,,,,,,,, -99100,1.9506406,3.8779137,,,,,,,,,,,,,, -99158,,,0.7100781202316284,1.1822881698608398,0.6423599720001221,1.4824674129486084,50000.0,0.5199000239372253,2.1205554008483887,10000.0,45832.94949889183,49992.29464268685,45832.94949889183,4149.903354167938,3.872529983520508,0.0 -99200,2.0163004,3.4799364,,,,,,,,,,,,,, -99300,2.2701082,2.2355914,,,,,,,,,,,,,, -99400,2.0764701,3.1106904,,,,,,,,,,,,,, -99500,2.0821495,4.8959303,,,,,,,,,,,,,, -99600,2.3119025,2.3379903,,,,,,,,,,,,,, -99700,2.2698019,2.7225406,,,,,,,,,,,,,, -99800,2.0770776,2.4285436,,,,,,,,,,,,,, -99900,2.1948988,2.2555635,,,,,,,,,,,,,, -100000,2.3874605,2.195687,,,,,,,,,,,,,, -100068,,,0.69677734375,1.2403640747070312,0.6484400033950806,1.4505292177200315,50000.0,0.5268000364303589,2.0834946632385254,10000.0,46253.27742695808,50451.75859832764,46253.27742695808,4188.944055318832,3.9171104431152335,0.0 -100100,2.3644545,2.3037257,,,,,,,,,,,,,, -100200,1.803041,3.4582133,,,,,,,,,,,,,, -100300,2.3982828,2.247395,,,,,,,,,,,,,, -100400,2.3904629,2.2880414,,,,,,,,,,,,,, -100500,2.2706294,3.9921746,,,,,,,,,,,,,, -100600,1.9731871,3.3103685,,,,,,,,,,,,,, -100700,1.9418285,3.2665544,,,,,,,,,,,,,, -100800,1.8928654,3.3901474,,,,,,,,,,,,,, -100900,2.181731,3.743672,,,,,,,,,,,,,, -100978,,,0.7009961009025574,1.234709620475769,0.6490600109100342,1.46269428730011,50000.0,0.5200000405311584,2.102366209030152,10000.0,46673.61992907524,50909.56090068817,46673.61992907524,4226.306841373444,3.9611897468566895,0.0 -101000,2.3578663,2.2396202,,,,,,,,,,,,,, -101100,2.2411165,2.3123407,,,,,,,,,,,,,, -101200,2.19387,2.1426454,,,,,,,,,,,,,, -101300,2.0680652,2.5490923,,,,,,,,,,,,,, -101400,2.1816604,2.1454945,,,,,,,,,,,,,, -101500,2.4182012,2.5274925,,,,,,,,,,,,,, -101600,2.2130513,2.5294452,,,,,,,,,,,,,, -101700,2.2277334,3.0230713,,,,,,,,,,,,,, -101800,2.2346125,4.847573,,,,,,,,,,,,,, -101885,,,0.7134765386581421,1.1484699249267578,0.6495400071144104,1.434063196182251,50000.0,0.5250000357627869,2.06763768196106,10000.0,47093.20279479027,51370.56505489349,47093.20279479027,4267.212629556656,4.426010370254517,0.0 -101900,2.1565132,4.4602113,,,,,,,,,,,,,, -102000,2.1145465,2.4592361,,,,,,,,,,,,,, -102100,1.9411408,2.1664147,,,,,,,,,,,,,, -102200,2.1522028,2.3662462,,,,,,,,,,,,,, -102300,2.220144,2.3121321,,,,,,,,,,,,,, -102400,2.4916115,2.0726337,,,,,,,,,,,,,, -102500,2.2618604,4.6330385,,,,,,,,,,,,,, -102600,2.3293116,2.1636384,,,,,,,,,,,,,, -102700,2.0195057,3.8973246,,,,,,,,,,,,,, -102794,,,0.69691401720047,1.217390060424805,0.6519399881362915,1.433598875999451,50000.0,0.5301000475883484,2.074958086013794,10000.0,47513.41642546654,51827.06900215149,47513.41642546654,4303.410413980484,4.468097686767578,0.0 -102800,1.8880303,4.7149553,,,,,,,,,,,,,, -102900,2.0616004,3.1841013,,,,,,,,,,,,,, -103000,2.0179424,4.784566,,,,,,,,,,,,,, -103100,2.2447667,2.3273096,,,,,,,,,,,,,, -103200,2.1397676,3.2386208,,,,,,,,,,,,,, -103300,2.3035147,2.0911822,,,,,,,,,,,,,, -103400,2.33175,2.1303327,,,,,,,,,,,,,, -103500,2.2292986,2.263915,,,,,,,,,,,,,, -103600,2.1217089,3.5446358,,,,,,,,,,,,,, -103700,2.2772202,2.1353147,,,,,,,,,,,,,, -103704,,,0.7031054496765137,1.2242692708969116,0.6475200057029724,1.4630873203277588,50000.0,0.5266000032424927,2.095916271209717,10000.0,47933.50048518181,52290.1813583374,47933.50048518181,4346.339977025986,4.515044212341309,0.0 -103800,2.385793,2.167869,,,,,,,,,,,,,, -103900,2.3373578,2.2448008,,,,,,,,,,,,,, -104000,1.9285198,4.317104,,,,,,,,,,,,,, -104100,2.3526855,2.2287767,,,,,,,,,,,,,, -104200,2.2945182,2.503358,,,,,,,,,,,,,, -104300,2.3417523,2.1323364,,,,,,,,,,,,,, -104400,2.2147863,4.325581,,,,,,,,,,,,,, -104500,2.3574922,2.2081182,,,,,,,,,,,,,, -104600,2.2395182,2.4314408,,,,,,,,,,,,,, -104614,,,0.7122460603713989,1.175422430038452,0.6545599699020386,1.444462776184082,50000.0,0.5293000340461731,2.080199718475342,10000.0,48353.644204854965,52754.61422896385,48353.644204854965,4390.538145542145,4.555092096328735,0.0 -104700,2.11531,2.1652124,,,,,,,,,,,,,, -104800,2.0593112,3.0418425,,,,,,,,,,,,,, -104900,2.4419372,2.183391,,,,,,,,,,,,,, -105000,2.2759876,2.213323,,,,,,,,,,,,,, -105100,2.557606,2.2175713,,,,,,,,,,,,,, -105200,1.8568529,4.3572655,,,,,,,,,,,,,, -105300,2.0527143,2.5765738,,,,,,,,,,,,,, -105400,2.1462176,2.690263,,,,,,,,,,,,,, -105500,2.314221,2.1329405,,,,,,,,,,,,,, -105522,,,0.7056835889816284,1.2042386531829834,0.6565399765968323,1.425550937652588,50000.0,0.525600016117096,2.077850341796875,10000.0,48773.55797100067,53216.09664773941,48773.55797100067,4432.014835357666,4.596234560012817,0.0 -105600,2.5742216,2.1432242,,,,,,,,,,,,,, -105700,2.0193052,3.442046,,,,,,,,,,,,,, -105800,2.4276316,2.2415152,,,,,,,,,,,,,, -105900,2.067924,3.9631395,,,,,,,,,,,,,, -106000,2.0034256,4.0637283,,,,,,,,,,,,,, -106100,2.447034,2.2544096,,,,,,,,,,,,,, -106200,2.118731,2.4235625,,,,,,,,,,,,,, -106300,2.4649806,4.7353625,,,,,,,,,,,,,, -106400,2.665403,2.2788494,,,,,,,,,,,,,, -106430,,,0.7126562595367432,1.1635273694992063,0.6616799831390381,1.3938157558441162,50000.0,0.539400041103363,2.0336849689483643,10000.0,49193.51614046097,53679.70312547684,49193.51614046097,4475.56670665741,4.640961647033691,0.0 -106500,2.2355137,2.6839545,,,,,,,,,,,,,, -106600,2.4214413,2.2449684,,,,,,,,,,,,,, -106700,2.4445193,2.118531,,,,,,,,,,,,,, -106800,2.2241464,2.2955744,,,,,,,,,,,,,, -106900,2.163634,3.2518282,,,,,,,,,,,,,, -107000,2.181361,2.58858,,,,,,,,,,,,,, -107100,2.3070524,2.2294621,,,,,,,,,,,,,, -107200,2.4527462,2.0971258,,,,,,,,,,,,,, -107300,2.4913518,2.2692218,,,,,,,,,,,,,, -107337,,,0.717968761920929,1.1395562887191772,0.6581799983978271,1.4131308794021606,50000.0,0.532200038433075,2.072439670562744,10000.0,49613.643409490585,54139.6849834919,49613.643409490585,4515.323534250259,4.687347173690796,0.0 -107400,2.2916026,2.5294058,,,,,,,,,,,,,, -107500,2.2596688,2.2283506,,,,,,,,,,,,,, -107600,2.281271,2.0091915,,,,,,,,,,,,,, -107700,2.369147,2.3007393,,,,,,,,,,,,,, -107800,2.7555013,2.1956635,,,,,,,,,,,,,, -107900,2.0061944,3.479309,,,,,,,,,,,,,, -108000,2.6516612,2.1461122,,,,,,,,,,,,,, -108100,2.1538665,2.7733774,,,,,,,,,,,,,, -108200,2.2531302,4.7625175,,,,,,,,,,,,,, -108246,,,0.7106640338897705,1.1932311058044434,0.6593999862670898,1.424089789390564,50000.0,0.5309000015258789,2.076124429702759,10000.0,50033.96087384224,54598.303564071655,50033.96087384224,4553.532001495361,4.729357481002808,0.0 -108300,2.4267073,2.1038864,,,,,,,,,,,,,, -108400,2.327236,2.181145,,,,,,,,,,,,,, -108500,2.357583,4.495598,,,,,,,,,,,,,, -108600,1.9634091,3.4863,,,,,,,,,,,,,, -108700,2.6440928,2.0553222,,,,,,,,,,,,,, -108800,2.5775244,2.1331847,,,,,,,,,,,,,, -108900,2.3434005,2.2106528,,,,,,,,,,,,,, -109000,2.1159973,2.4225512,,,,,,,,,,,,,, -109100,2.1687326,2.319337,,,,,,,,,,,,,, -109155,,,0.7136523127555847,1.161201238632202,0.6612799763679504,1.3903865814208984,50000.0,0.5367000102996826,2.0257108211517334,10000.0,50454.120992183685,55059.87839961052,50454.120992183685,4594.851831197739,4.773675441741943,0.0 -109200,2.2786043,2.519749,,,,,,,,,,,,,, -109300,2.0669,3.3607593,,,,,,,,,,,,,, -109400,2.559827,2.1878335,,,,,,,,,,,,,, -109500,2.3044693,3.5522559,,,,,,,,,,,,,, -109600,2.389916,2.204735,,,,,,,,,,,,,, -109700,2.0241547,3.0620444,,,,,,,,,,,,,, -109800,2.2425528,2.0627437,,,,,,,,,,,,,, -109900,2.2955725,2.151613,,,,,,,,,,,,,, -110000,2.3048086,2.9386106,,,,,,,,,,,,,, -110063,,,0.724804699420929,1.1234993934631348,0.6626799702644348,1.3978769779205322,50000.0,0.5388000011444092,2.0300371646881104,10000.0,50874.046491622925,55520.72982406616,50874.046491622925,4635.680241346359,4.820109605789185,0.0 -110100,2.4713666,4.217562,,,,,,,,,,,,,, -110200,2.3338754,2.1319125,,,,,,,,,,,,,, -110300,1.9951055,4.3255773,,,,,,,,,,,,,, -110400,2.2049427,4.3490176,,,,,,,,,,,,,, -110500,2.144986,3.8993921,,,,,,,,,,,,,, -110600,2.3567045,2.0243506,,,,,,,,,,,,,, -110700,2.7084527,2.0995088,,,,,,,,,,,,,, -110800,2.5762024,1.9809089,,,,,,,,,,,,,, -110900,2.3338509,2.448204,,,,,,,,,,,,,, -110973,,,0.7303124666213989,1.0987309217453003,0.6647199988365173,1.3945322036743164,50000.0,0.5364000201225281,2.026243209838867,10000.0,51293.97769546509,55978.31710648537,51293.97769546509,4673.234929800034,4.870450496673584,0.0 -111000,2.5981135,2.2646265,,,,,,,,,,,,,, -111100,2.2566867,4.6864495,,,,,,,,,,,,,, -111200,2.2319064,4.602122,,,,,,,,,,,,,, -111300,2.2485187,2.6742034,,,,,,,,,,,,,, -111400,2.1220415,3.0838253,,,,,,,,,,,,,, -111500,2.4901023,2.1039357,,,,,,,,,,,,,, -111600,2.5246248,3.4522269,,,,,,,,,,,,,, -111700,2.30736,2.3929596,,,,,,,,,,,,,, -111800,2.210093,4.814972,,,,,,,,,,,,,, -111884,,,0.7140820026397705,1.1694566011428833,0.6618799567222595,1.3972446918487549,50000.0,0.5379000306129456,2.036576271057129,10000.0,51714.3005130291,56440.18104696274,51714.3005130291,4714.676728963852,4.919151782989502,0.0 -111900,2.1954885,3.3636398,,,,,,,,,,,,,, -112000,2.4962173,2.1138623,,,,,,,,,,,,,, -112100,2.3019843,3.5337644,,,,,,,,,,,,,, -112200,2.1481106,4.5973735,,,,,,,,,,,,,, -112300,2.2245138,2.9541907,,,,,,,,,,,,,, -112400,2.7749546,2.3596067,,,,,,,,,,,,,, -112500,2.4636383,2.0950003,,,,,,,,,,,,,, -112600,2.3753588,3.5226784,,,,,,,,,,,,,, -112700,2.5869339,2.1404529,,,,,,,,,,,,,, -112793,,,0.7234960794448853,1.1330419778823853,0.6687399744987488,1.3698722124099731,50000.0,0.5406000018119812,2.0242462158203125,10000.0,52134.38846230507,56899.53965449333,52134.38846230507,4753.851475954056,4.964050054550171,0.0 -112800,1.998228,3.7659452,,,,,,,,,,,,,, -112900,2.5623472,2.1785054,,,,,,,,,,,,,, -113000,2.1890004,4.2133045,,,,,,,,,,,,,, -113100,2.5521333,2.1783266,,,,,,,,,,,,,, -113200,2.2580726,3.4969537,,,,,,,,,,,,,, -113300,2.478101,2.1005769,,,,,,,,,,,,,, -113400,2.4500299,2.0070162,,,,,,,,,,,,,, -113500,2.268887,2.3069687,,,,,,,,,,,,,, -113600,2.5322075,2.0312145,,,,,,,,,,,,,, -113700,2.7383335,2.091927,,,,,,,,,,,,,, -113701,,,0.74085932970047,1.0626214742660522,0.6674599647521973,1.3853868246078491,50000.0,0.5449000000953674,2.020291805267334,10000.0,52555.08496952057,57360.21570634842,52555.08496952057,4793.7336666584015,5.01096510887146,0.0 -113800,2.7688863,2.1450887,,,,,,,,,,,,,, -113900,2.3018954,4.1398787,,,,,,,,,,,,,, -114000,2.2383273,4.3441467,,,,,,,,,,,,,, -114100,2.3642974,2.169136,,,,,,,,,,,,,, -114200,2.5061164,3.0475423,,,,,,,,,,,,,, -114300,2.3670502,3.8358276,,,,,,,,,,,,,, -114400,2.429719,2.1636753,,,,,,,,,,,,,, -114500,2.2748148,2.8627024,,,,,,,,,,,,,, -114600,2.651281,2.212771,,,,,,,,,,,,,, -114607,,,0.7202929258346558,1.1391555070877075,0.6705399751663208,1.3693618774414062,50000.0,0.5461000204086304,2.014665603637696,10000.0,52975.2583398819,57822.64588141441,52975.2583398819,4835.897415399551,5.053654193878174,0.0 -114700,2.4921923,2.196931,,,,,,,,,,,,,, -114800,2.2830982,3.6781905,,,,,,,,,,,,,, -114900,2.530628,2.0615149,,,,,,,,,,,,,, -115000,2.3304372,4.4384604,,,,,,,,,,,,,, -115100,2.932144,2.0661216,,,,,,,,,,,,,, -115200,2.796032,2.1590283,,,,,,,,,,,,,, -115300,2.8085258,2.0711174,,,,,,,,,,,,,, -115400,2.3663888,2.4144733,,,,,,,,,,,,,, -115500,2.2318132,2.904818,,,,,,,,,,,,,, -115514,,,0.7275585532188416,1.085247039794922,0.6701200008392334,1.3525265455245972,50000.0,0.5449000000953674,1.9930286407470703,10000.0,53395.23126530647,58283.81888151169,53395.23126530647,4877.004498958588,5.096050024032593,0.0 -115600,2.360339,3.9522834,,,,,,,,,,,,,, -115700,2.475482,2.0842705,,,,,,,,,,,,,, -115800,2.2296422,2.6242738,,,,,,,,,,,,,, -115900,2.706337,1.9746594,,,,,,,,,,,,,, -116000,2.483182,2.043411,,,,,,,,,,,,,, -116100,2.867999,2.0724812,,,,,,,,,,,,,, -116200,2.4677136,2.1662862,,,,,,,,,,,,,, -116300,2.3880372,1.9261843,,,,,,,,,,,,,, -116400,2.5599427,2.0961592,,,,,,,,,,,,,, -116421,,,0.7390038967132568,1.067587971687317,0.6695599555969238,1.372552514076233,50000.0,0.5443000197410583,2.023348093032837,10000.0,53815.36892461777,58746.47779226303,53815.36892461777,4919.428336620331,5.142207384109497,0.0 -116500,2.696959,2.2843714,,,,,,,,,,,,,, -116600,2.53282,2.2774506,,,,,,,,,,,,,, -116700,2.4169617,2.4368243,,,,,,,,,,,,,, -116800,2.541424,2.1111412,,,,,,,,,,,,,, -116900,2.6401289,2.1982288,,,,,,,,,,,,,, -117000,2.2295957,4.0640144,,,,,,,,,,,,,, -117100,2.6107836,1.9950445,,,,,,,,,,,,,, -117200,2.3631635,2.3335195,,,,,,,,,,,,,, -117300,2.5506568,2.083373,,,,,,,,,,,,,, -117327,,,0.7241405844688416,1.1141245365142822,0.6716200113296509,1.3542362451553345,50000.0,0.5481000542640686,1.9819422960281368,10000.0,54235.27755284309,59207.73613762856,54235.27755284309,4960.681174516678,5.188116788864136,0.0 -117400,2.201539,2.852634,,,,,,,,,,,,,, -117500,2.290897,4.169346,,,,,,,,,,,,,, -117600,2.7573228,2.1700258,,,,,,,,,,,,,, -117700,2.421967,4.507406,,,,,,,,,,,,,, -117800,3.2188497,2.0654242,,,,,,,,,,,,,, -117900,2.6163614,2.4561071,,,,,,,,,,,,,, -118000,2.71272,1.9951551,,,,,,,,,,,,,, -118100,2.4662135,4.4942946,,,,,,,,,,,,,, -118200,2.3285177,3.8485377,,,,,,,,,,,,,, -118236,,,0.7311132550239563,1.0923360586166382,0.674079954624176,1.3469245433807373,50000.0,0.5499000549316406,1.9686278104782104,10000.0,54655.42571687698,59665.05784249306,54655.42571687698,4997.757047891617,5.235001087188721,0.0 -118300,2.2865047,3.0213006,,,,,,,,,,,,,, -118400,2.8384666,2.047915,,,,,,,,,,,,,, -118500,2.4821963,4.0462823,,,,,,,,,,,,,, -118600,2.3035727,2.878275,,,,,,,,,,,,,, -118700,2.7534714,2.1322517,,,,,,,,,,,,,, -118800,2.3120403,3.302699,,,,,,,,,,,,,, -118900,2.8131373,2.1279495,,,,,,,,,,,,,, -119000,2.4719687,2.018856,,,,,,,,,,,,,, -119100,2.4319882,4.661434,,,,,,,,,,,,,, -119146,,,0.7421093583106995,1.022621989250183,0.6779999732971191,1.313728094100952,50000.0,0.5534000396728516,1.954906702041626,10000.0,55075.748200416565,60127.63447976112,55075.748200416565,5039.916709423065,5.278695106506348,0.0 -119200,2.8173783,1.9895618,,,,,,,,,,,,,, -119300,2.501731,2.6020494,,,,,,,,,,,,,, -119400,2.6397486,2.4981945,,,,,,,,,,,,,, -119500,3.1633492,4.408782,,,,,,,,,,,,,, -119600,2.7011375,2.0126445,,,,,,,,,,,,,, -119700,2.6942704,2.0355082,,,,,,,,,,,,,, -119800,2.4912996,1.8824692,,,,,,,,,,,,,, -119900,2.609367,1.9118564,,,,,,,,,,,,,, -120000,3.0232134,2.0492764,,,,,,,,,,,,,, -120055,,,0.7331249713897705,1.0569225549697876,0.6812399625778198,1.2919137477874756,50000.0,0.5569000244140625,1.9266986846923828,10000.0,55495.910600185394,60586.19439768791,55495.910600185394,5078.214918136597,5.3276238441467285,0.0 -120100,2.5703042,2.3213344,,,,,,,,,,,,,, -120200,2.4826813,1.9294518,,,,,,,,,,,,,, -120300,2.5014188,1.9480085,,,,,,,,,,,,,, -120400,2.786331,2.3493683,,,,,,,,,,,,,, -120500,2.8549137,2.183313,,,,,,,,,,,,,, -120600,2.3114023,2.7945619,,,,,,,,,,,,,, -120700,2.3402817,3.4687805,,,,,,,,,,,,,, -120800,2.8103516,2.2661018,,,,,,,,,,,,,, -120900,2.4229105,2.6637292,,,,,,,,,,,,,, -120963,,,0.7388867139816284,1.049467921257019,0.6840800046920776,1.2940298318862915,50000.0,0.5561000108718872,1.9345511198043823,10000.0,55916.22533154488,61047.6866889,55916.22533154488,5119.290406227112,5.378458261489868,0.0 -121000,2.7481444,4.4823604,,,,,,,,,,,,,, -121100,2.5444593,3.2522745,,,,,,,,,,,,,, -121200,2.454014,4.3765593,,,,,,,,,,,,,, -121300,2.9087312,4.6867514,,,,,,,,,,,,,, -121400,3.0962577,4.378363,,,,,,,,,,,,,, -121500,2.3036137,3.1776063,,,,,,,,,,,,,, -121600,2.3023908,4.5598755,,,,,,,,,,,,,, -121700,2.6939137,1.9511744,,,,,,,,,,,,,, -121800,2.2971137,3.8109546,,,,,,,,,,,,,, -121870,,,0.7439062595367432,1.030897855758667,0.6851599812507629,1.2949132919311523,50000.0,0.5592000484466553,1.9299076795578003,10000.0,56336.53861951828,61505.68466210365,56336.53861951828,5156.877597808838,5.425642251968384,0.0 -121900,2.954022,1.8372666,,,,,,,,,,,,,, -122000,2.6325884,2.3260775,,,,,,,,,,,,,, -122100,2.34752,3.3318439,,,,,,,,,,,,,, -122200,2.8115542,1.9912463,,,,,,,,,,,,,, -122300,2.5583155,4.556916,,,,,,,,,,,,,, -122400,2.536231,2.2033548,,,,,,,,,,,,,, -122500,2.3296142,3.1382394,,,,,,,,,,,,,, -122600,3.253787,2.0245664,,,,,,,,,,,,,, -122700,2.6374602,3.0342119,,,,,,,,,,,,,, -122779,,,0.7379491925239563,1.0452035665512085,0.685979962348938,1.2754040956497192,50000.0,0.5611000061035156,1.8941129446029663,10000.0,56756.696388721466,61964.73351264,56756.696388721466,5195.670909404755,5.473036527633667,0.0 -122800,2.5831223,3.6823077,,,,,,,,,,,,,, -122900,2.694488,2.1232753,,,,,,,,,,,,,, -123000,2.7669647,1.9401772,,,,,,,,,,,,,, -123100,2.695107,4.0636177,,,,,,,,,,,,,, -123200,2.8735971,1.8933649,,,,,,,,,,,,,, -123300,2.4391549,2.6633983,,,,,,,,,,,,,, -123400,3.1835616,1.8597866,,,,,,,,,,,,,, -123500,2.6545317,1.8653939,,,,,,,,,,,,,, -123600,2.438372,3.7326813,,,,,,,,,,,,,, -123689,,,0.7442382574081421,1.0326616764068604,0.6847800016403198,1.2847895622253418,50000.0,0.5586000084877014,1.901072382926941,10000.0,57176.67815852165,62428.31165552139,57176.67815852165,5239.167343139648,5.522738933563232,0.0 -123700,2.6904333,2.0512066,,,,,,,,,,,,,, -123800,2.5566497,1.989046,,,,,,,,,,,,,, -123900,2.7852747,2.0662253,,,,,,,,,,,,,, -124000,2.4003837,3.0861847,,,,,,,,,,,,,, -124100,2.418761,3.1182446,,,,,,,,,,,,,, -124200,2.8812902,2.220741,,,,,,,,,,,,,, -124300,2.9771595,2.004798,,,,,,,,,,,,,, -124400,2.7785125,1.8661627,,,,,,,,,,,,,, -124500,2.7316048,3.5711734,,,,,,,,,,,,,, -124598,,,0.7477929592132568,1.0206478834152222,0.6859599947929382,1.2963595390319824,50000.0,0.5652000308036804,1.9213535785675049,10000.0,57596.78480887413,62886.07406878472,57596.78480887413,5276.72448348999,5.570141315460205,0.0 -124600,2.853962,2.4840436,,,,,,,,,,,,,, -124700,3.1526997,4.463107,,,,,,,,,,,,,, -124800,2.9857879,2.0512047,,,,,,,,,,,,,, -124900,2.6183035,2.0457072,,,,,,,,,,,,,, -125000,2.3330224,3.7059865,,,,,,,,,,,,,, -125100,2.670055,2.8018446,,,,,,,,,,,,,, -125200,2.9378493,1.9249921,,,,,,,,,,,,,, -125300,2.5859234,1.8918887,,,,,,,,,,,,,, -125400,3.2377908,3.6443472,,,,,,,,,,,,,, -125500,2.8720179,1.9332377,,,,,,,,,,,,,, -125508,,,0.7463671565055847,1.0201128721237185,0.6872999668121338,1.2751214504241943,50000.0,0.5609000325202942,1.9046622514724727,10000.0,58016.77628803253,63345.20164251328,58016.77628803253,5315.758058547974,5.622182130813599,0.0 -125600,3.0323043,2.0195367,,,,,,,,,,,,,, -125700,2.9360342,1.9861412,,,,,,,,,,,,,, -125800,2.9736457,2.269913,,,,,,,,,,,,,, -125900,2.847453,1.9297504,,,,,,,,,,,,,, -126000,3.0867207,1.846676,,,,,,,,,,,,,, -126100,2.5803907,2.8708296,,,,,,,,,,,,,, -126200,2.630569,4.027163,,,,,,,,,,,,,, -126300,2.9001982,1.9368304,,,,,,,,,,,,,, -126400,3.046719,4.1014028,,,,,,,,,,,,,, -126415,,,0.7504687309265137,0.9982839822769164,0.6909799575805664,1.252796649932861,50000.0,0.5678000450134277,1.879265069961548,10000.0,58437.0435795784,63803.56805491448,58437.0435795784,5353.759917020798,5.668635845184326,0.0 -126500,3.0490003,2.1650684,,,,,,,,,,,,,, -126600,3.1001995,2.0219615,,,,,,,,,,,,,, -126700,2.7862108,4.5565467,,,,,,,,,,,,,, -126800,2.8572345,2.132348,,,,,,,,,,,,,, -126900,2.870737,1.8410182,,,,,,,,,,,,,, -127000,2.8608825,4.367957,,,,,,,,,,,,,, -127100,3.241544,2.0545506,,,,,,,,,,,,,, -127200,3.0228117,2.081241,,,,,,,,,,,,,, -127300,3.0719101,1.9233806,,,,,,,,,,,,,, -127324,,,0.7506445050239563,0.9825835824012756,0.6895599961280823,1.252158761024475,50000.0,0.567300021648407,1.8777422904968264,10000.0,58857.124438762665,64261.78774547577,58857.124438762665,5391.798512220383,5.717786550521851,0.0 -127400,2.9040012,1.8637698,,,,,,,,,,,,,, -127500,2.9013834,4.5239553,,,,,,,,,,,,,, -127600,2.8514109,1.9405026,,,,,,,,,,,,,, -127700,2.6238663,3.6982093,,,,,,,,,,,,,, -127800,3.5439157,4.024604,,,,,,,,,,,,,, -127900,2.6827931,3.7371352,,,,,,,,,,,,,, -128000,3.2378078,1.9579507,,,,,,,,,,,,,, -128100,2.8830693,3.6101375,,,,,,,,,,,,,, -128200,2.7496321,3.9511619,,,,,,,,,,,,,, -128236,,,0.7712695002555847,0.9164721369743348,0.6928399801254272,1.2563698291778564,50000.0,0.5665000081062317,1.882158637046814,10000.0,59277.34511780739,64719.820784807205,59277.34511780739,5429.511273860931,5.765959978103638,0.0 -128300,2.904199,3.9381728,,,,,,,,,,,,,, -128400,3.014847,2.0485642,,,,,,,,,,,,,, -128500,3.017563,2.1772938,,,,,,,,,,,,,, -128600,3.0262535,1.9709231,,,,,,,,,,,,,, -128700,2.8045535,1.9631038,,,,,,,,,,,,,, -128800,3.459558,1.973625,,,,,,,,,,,,,, -128900,2.591227,3.2967434,,,,,,,,,,,,,, -129000,2.8701055,2.3832197,,,,,,,,,,,,,, -129100,3.105059,1.8778386,,,,,,,,,,,,,, -129143,,,0.753222644329071,1.0042492151260376,0.6917200088500977,1.2587978839874268,50000.0,0.5640000104904175,1.883506298065185,10000.0,59697.71044564247,65181.91693449021,59697.71044564247,5471.143674373627,5.814197540283203,0.0 -129200,2.5697174,3.1246185,,,,,,,,,,,,,, -129300,2.9472454,1.8740749,,,,,,,,,,,,,, -129400,2.9410717,4.0111885,,,,,,,,,,,,,, -129500,2.699187,3.7312262,,,,,,,,,,,,,, -129600,2.8516831,4.3679557,,,,,,,,,,,,,, -129700,2.8787642,1.8562756,,,,,,,,,,,,,, -129800,3.0427222,4.5070767,,,,,,,,,,,,,, -129900,3.171801,1.9533387,,,,,,,,,,,,,, -130000,2.807222,2.3283043,,,,,,,,,,,,,, -130052,,,0.7607616782188416,0.9540216326713562,0.6956999897956848,1.2329747676849363,50000.0,0.5714000463485718,1.8770899772644043,10000.0,60117.83626127243,65637.22610259056,60117.83626127243,5506.229387283325,5.861069679260254,0.0 -130100,3.244104,1.9205863,,,,,,,,,,,,,, -130200,2.8507106,1.8306223,,,,,,,,,,,,,, -130300,2.504824,2.9101865,,,,,,,,,,,,,, -130400,2.7275882,2.5484009,,,,,,,,,,,,,, -130500,3.2510834,1.9945523,,,,,,,,,,,,,, -130600,2.866459,1.8238792,,,,,,,,,,,,,, -130700,2.9970875,2.2334018,,,,,,,,,,,,,, -130800,3.1917331,4.4424715,,,,,,,,,,,,,, -130900,2.9435222,3.5376244,,,,,,,,,,,,,, -130962,,,0.7698827981948853,0.9147235155105592,0.6962400078773499,1.2309191226959229,50000.0,0.5742000341415405,1.8582838773727417,10000.0,60537.82030844688,66098.19182014465,60537.82030844688,5547.110866785049,5.910343647003174,0.0 -131000,3.1626196,2.2744536,,,,,,,,,,,,,, -131100,2.896511,2.0722754,,,,,,,,,,,,,, -131200,2.7245038,3.753051,,,,,,,,,,,,,, -131300,3.1666288,1.7736256,,,,,,,,,,,,,, -131400,2.944957,2.5808306,,,,,,,,,,,,,, -131500,2.9260468,1.8257165,,,,,,,,,,,,,, -131600,3.1922722,1.9961925,,,,,,,,,,,,,, -131700,2.932169,1.950036,,,,,,,,,,,,,, -131800,3.2225747,2.0742407,,,,,,,,,,,,,, -131870,,,0.7575585842132568,0.9641546607017516,0.7005599737167358,1.219005107879639,50000.0,0.5725000500679016,1.835734844207764,10000.0,60957.97349905968,66560.34111618996,60957.97349905968,5589.005449295044,5.961735725402832,0.0 -131900,3.35084,2.1030672,,,,,,,,,,,,,, -132000,3.1949894,1.9700695,,,,,,,,,,,,,, -132100,3.314168,4.0630326,,,,,,,,,,,,,, -132200,3.1506457,1.929581,,,,,,,,,,,,,, -132300,2.7545955,4.029374,,,,,,,,,,,,,, -132400,3.1058888,1.6977057,,,,,,,,,,,,,, -132500,3.0246315,3.9755201,,,,,,,,,,,,,, -132600,3.2498124,1.8617234,,,,,,,,,,,,,, -132700,2.929804,1.9689965,,,,,,,,,,,,,, -132778,,,0.7624413967132568,0.9442371129989624,0.6990199685096741,1.221044421195984,50000.0,0.5773000121116638,1.856716513633728,10000.0,61378.23578214645,67020.0600142479,61378.23578214645,5628.361680984497,6.011536359786987,0.0 -132800,3.4185085,1.9496944,,,,,,,,,,,,,, -132900,3.0144305,1.9134628,,,,,,,,,,,,,, -133000,2.9134126,2.9644446,,,,,,,,,,,,,, -133100,3.0924032,2.031645,,,,,,,,,,,,,, -133200,2.9990125,4.364404,,,,,,,,,,,,,, -133300,3.1249154,1.7810466,,,,,,,,,,,,,, -133400,2.9567783,3.683311,,,,,,,,,,,,,, -133500,3.2642639,1.9700816,,,,,,,,,,,,,, -133600,3.3398726,1.8728378,,,,,,,,,,,,,, -133688,,,0.7743359208106995,0.8995473384857178,0.7016800045967102,1.211005687713623,50000.0,0.57750004529953,1.8353841304779053,10000.0,61798.40344524384,67477.51272082329,61798.40344524384,5665.5401656627655,6.066869735717773,0.0 -133700,2.8057065,3.7389605,,,,,,,,,,,,,, -133800,3.127058,1.797509,,,,,,,,,,,,,, -133900,3.0067072,1.8721466,,,,,,,,,,,,,, -134000,3.1295185,1.9869881,,,,,,,,,,,,,, -134100,3.1845462,1.9564532,,,,,,,,,,,,,, -134200,3.1916428,1.8052262,,,,,,,,,,,,,, -134300,3.5303626,1.9713161,,,,,,,,,,,,,, -134400,3.3313518,2.2889977,,,,,,,,,,,,,, -134500,3.13343,1.8774763,,,,,,,,,,,,,, -134599,,,0.76185542345047,0.9467484951019288,0.7027999758720398,1.2054047584533691,50000.0,0.5789000391960144,1.8349767923355105,10000.0,62218.46045303345,67939.94962263107,62218.46045303345,5707.8211581707,6.115529537200928,0.0 -134600,3.3213258,4.333021,,,,,,,,,,,,,, -134700,3.3129828,2.0847962,,,,,,,,,,,,,, -134800,2.9213085,2.8757224,,,,,,,,,,,,,, -134900,3.3565388,1.8050084,,,,,,,,,,,,,, -135000,3.041992,2.1213338,,,,,,,,,,,,,, -135100,3.0625322,3.9864638,,,,,,,,,,,,,, -135200,2.861083,3.0150142,,,,,,,,,,,,,, -135300,3.465546,4.2261276,,,,,,,,,,,,,, -135400,2.9647083,2.3389602,,,,,,,,,,,,,, -135500,3.2661343,1.7839102,,,,,,,,,,,,,, -135507,,,0.7659569978713989,0.9456735253334044,0.7011799812316895,1.2311320304870603,50000.0,0.5776000022888184,1.8514093160629272,10000.0,62638.41057395935,68400.22089409828,62638.41057395935,5748.042767763138,6.164484024047852,0.0 -135600,3.1835089,1.7539008,,,,,,,,,,,,,, -135700,3.3708746,1.8491056,,,,,,,,,,,,,, -135800,2.9225743,3.305976,,,,,,,,,,,,,, -135900,3.0335145,2.3996036,,,,,,,,,,,,,, -136000,2.9961066,3.4396396,,,,,,,,,,,,,, -136100,3.1987689,2.3515646,,,,,,,,,,,,,, -136200,3.3138142,1.8327122,,,,,,,,,,,,,, -136300,3.4751084,1.9123849,,,,,,,,,,,,,, -136400,3.4221973,1.8961334,,,,,,,,,,,,,, -136414,,,0.7757421731948853,0.90151709318161,0.704539954662323,1.1993683576583862,50000.0,0.5824000239372253,1.8206918239593504,10000.0,63058.55126523972,68856.67860889435,63058.55126523972,5784.257081747055,6.215606927871704,0.0 -136500,3.6500552,1.8249305,,,,,,,,,,,,,, -136600,3.063408,3.6891823,,,,,,,,,,,,,, -136700,3.3570313,1.7456868,,,,,,,,,,,,,, -136800,3.29939,2.0225801,,,,,,,,,,,,,, -136900,3.4687402,2.2005205,,,,,,,,,,,,,, -137000,3.2484953,1.8224477,,,,,,,,,,,,,, -137100,2.8824444,3.255744,,,,,,,,,,,,,, -137200,3.294549,1.7769045,,,,,,,,,,,,,, -137300,3.1243663,2.9828064,,,,,,,,,,,,,, -137322,,,0.7714648246765137,0.9188026189804076,0.7074999809265137,1.1961042881011963,50000.0,0.5777000188827515,1.8227862119674685,10000.0,63478.52360081673,69317.36960935593,63478.52360081673,5824.872405529022,6.267154216766357,0.0 -137400,3.3006864,1.76461,,,,,,,,,,,,,, -137500,3.2261267,1.7903428,,,,,,,,,,,,,, -137600,3.2334554,1.9310845,,,,,,,,,,,,,, -137700,3.450763,3.8738413,,,,,,,,,,,,,, -137800,3.6527565,1.9030812,,,,,,,,,,,,,, -137900,3.0848522,3.1852014,,,,,,,,,,,,,, -138000,3.2264454,2.2413237,,,,,,,,,,,,,, -138100,3.1536958,2.825197,,,,,,,,,,,,,, -138200,3.6658313,3.8499277,,,,,,,,,,,,,, -138229,,,0.777636706829071,0.8658644556999207,0.7101799845695496,1.1642639636993408,50000.0,0.5875000357627869,1.7642264366149902,10000.0,63898.66561055184,69776.55884242058,63898.66561055184,5863.821855783463,6.314066648483276,0.0 -138300,3.5073576,1.8038464,,,,,,,,,,,,,, -138400,3.053764,3.9824736,,,,,,,,,,,,,, -138500,4.0778146,4.3687057,,,,,,,,,,,,,, -138600,3.224355,2.2540643,,,,,,,,,,,,,, -138700,3.187636,2.4924526,,,,,,,,,,,,,, -138800,3.0249207,2.3945765,,,,,,,,,,,,,, -138900,3.35676,1.8649552,,,,,,,,,,,,,, -139000,3.5432537,3.6899173,,,,,,,,,,,,,, -139100,3.4250472,1.7593815,,,,,,,,,,,,,, -139141,,,0.7803515195846558,0.868783712387085,0.7113199830055237,1.1666555404663086,50000.0,0.5849000215530396,1.7911489009857178,10000.0,64318.93400526047,70236.60565376282,64318.93400526047,5903.501656532288,6.361291170120239,0.0 -139200,3.2991912,1.919726,,,,,,,,,,,,,, -139300,3.7338696,2.0800405,,,,,,,,,,,,,, -139400,3.3389952,4.0418715,,,,,,,,,,,,,, -139500,3.428679,3.553822,,,,,,,,,,,,,, -139600,3.4789488,1.7473277,,,,,,,,,,,,,, -139700,3.437701,1.9498451,,,,,,,,,,,,,, -139800,3.6645372,1.8198328,,,,,,,,,,,,,, -139900,3.3502038,1.6526806,,,,,,,,,,,,,, -140000,3.6347473,1.7301965,,,,,,,,,,,,,, -140050,,,0.7785937190055847,0.8846719861030579,0.7123799920082092,1.1717824935913086,50000.0,0.5907000303268433,1.7813829183578491,10000.0,64738.83408522606,70698.14928531647,64738.83408522606,5945.039888620377,6.41525673866272,0.0 -140100,3.405264,2.0875988,,,,,,,,,,,,,, -140200,3.1826198,2.5650315,,,,,,,,,,,,,, -140300,3.7178032,1.8049424,,,,,,,,,,,,,, -140400,3.1827629,2.3387256,,,,,,,,,,,,,, -140500,3.075081,3.3571908,,,,,,,,,,,,,, -140600,3.787926,1.706944,,,,,,,,,,,,,, -140700,3.4843147,1.7070202,,,,,,,,,,,,,, -140800,3.1754708,1.8025544,,,,,,,,,,,,,, -140900,3.2817833,1.7377452,,,,,,,,,,,,,, -140959,,,0.775390625,0.8824504017829895,0.7128199934959412,1.150181770324707,50000.0,0.5898000001907349,1.7759381532669067,10000.0,65158.86788415909,71154.27746748924,65158.86788415909,5981.03254365921,6.46516752243042,0.0 -141000,3.4117324,2.2224107,,,,,,,,,,,,,, -141100,3.785676,4.030558,,,,,,,,,,,,,, -141200,3.4073915,3.062592,,,,,,,,,,,,,, -141300,3.5933168,3.8917623,,,,,,,,,,,,,, -141400,3.4423358,2.5498924,,,,,,,,,,,,,, -141500,3.3040915,2.1668713,,,,,,,,,,,,,, -141600,4.0120263,3.655686,,,,,,,,,,,,,, -141700,3.6181426,3.7345657,,,,,,,,,,,,,, -141800,3.486729,1.6760796,,,,,,,,,,,,,, -141867,,,0.7857617139816284,0.8551717400550842,0.7181199789047241,1.148403525352478,50000.0,0.5962000489234924,1.753408432006836,10000.0,65579.00881290436,71614.71821403503,65579.00881290436,6021.227711677551,6.5184853076934814,0.0 -141900,4.317617,4.315064,,,,,,,,,,,,,, -142000,4.423187,1.8226557,,,,,,,,,,,,,, -142100,3.4179077,2.8631845,,,,,,,,,,,,,, -142200,3.3320687,1.6968288,,,,,,,,,,,,,, -142300,3.439332,3.0783129,,,,,,,,,,,,,, -142400,3.3486288,1.8486371,,,,,,,,,,,,,, -142500,3.6323736,1.5652124,,,,,,,,,,,,,, -142600,3.326433,1.725761,,,,,,,,,,,,,, -142700,3.4967563,2.1709864,,,,,,,,,,,,,, -142775,,,0.7930663824081421,0.8114429116249084,0.7167999744415283,1.1399496793746948,50000.0,0.5949000120162964,1.7673008441925049,10000.0,65999.13509559631,72077.39252829552,65999.13509559631,6063.670241594315,6.572426080703735,0.0 -142800,3.1527278,2.737857,,,,,,,,,,,,,, -142900,3.9912915,1.8611008,,,,,,,,,,,,,, -143000,3.6796527,3.9436336,,,,,,,,,,,,,, -143100,3.5471895,1.6809264,,,,,,,,,,,,,, -143200,3.6501634,1.7376963,,,,,,,,,,,,,, -143300,4.1388283,4.1855164,,,,,,,,,,,,,, -143400,4.2352405,1.8719531,,,,,,,,,,,,,, -143500,3.6412976,1.8269964,,,,,,,,,,,,,, -143600,3.3884423,2.278401,,,,,,,,,,,,,, -143682,,,0.7851952910423279,0.8426395058631897,0.7227999567985535,1.118364691734314,50000.0,0.6010000109672546,1.724225640296936,10000.0,66419.16562604904,72537.31822061539,66419.16562604904,6103.460964918137,6.625840425491333,0.0 -143700,3.9585543,2.8792143,,,,,,,,,,,,,, -143800,3.5626826,1.6989509,,,,,,,,,,,,,, -143900,3.4242678,2.9882827,,,,,,,,,,,,,, -144000,3.4927096,1.7088065,,,,,,,,,,,,,, -144100,3.6382825,3.8902164,,,,,,,,,,,,,, -144200,4.075186,1.737598,,,,,,,,,,,,,, -144300,3.8218956,1.6630232,,,,,,,,,,,,,, -144400,3.5338182,1.7488002,,,,,,,,,,,,,, -144500,4.044301,1.8142455,,,,,,,,,,,,,, -144588,,,0.7878710627555847,0.8365593552589417,0.7193599939346313,1.1317459344863892,50000.0,0.5928000211715698,1.753752827644348,10000.0,66839.12981963158,72996.18788266182,66839.12981963158,6141.835455179215,7.105395555496216,0.0 -144600,4.0171123,1.7283651,,,,,,,,,,,,,, -144700,4.0162473,1.7090766,,,,,,,,,,,,,, -144800,3.5283697,2.3051782,,,,,,,,,,,,,, -144900,3.693083,1.7201498,,,,,,,,,,,,,, -145000,3.272388,2.6642709,,,,,,,,,,,,,, -145100,3.5823634,1.852216,,,,,,,,,,,,,, -145200,3.5973651,1.735326,,,,,,,,,,,,,, -145300,3.9864278,1.6427633,,,,,,,,,,,,,, -145400,3.9167194,1.6741768,,,,,,,,,,,,,, -145495,,,0.7982421517372131,0.8103423714637756,0.7236799597740173,1.131292462348938,50000.0,0.6044000387191772,1.7381134033203125,10000.0,67259.35423231125,73458.10443782806,67259.35423231125,6183.42341208458,7.158540964126587,0.0 -145500,3.77582,1.6315219,,,,,,,,,,,,,, -145600,3.7553737,3.8617945,,,,,,,,,,,,,, -145700,3.9941478,1.5595428,,,,,,,,,,,,,, -145800,3.3861268,3.0271482,,,,,,,,,,,,,, -145900,3.6530972,3.4922523,,,,,,,,,,,,,, -146000,4.133284,1.8511393,,,,,,,,,,,,,, -146100,4.071462,1.7047105,,,,,,,,,,,,,, -146200,4.168551,1.687638,,,,,,,,,,,,,, -146300,3.8911417,2.0201273,,,,,,,,,,,,,, -146400,3.5749352,2.5363243,,,,,,,,,,,,,, -146404,,,0.7885546684265137,0.8147965669631958,0.7246999740600586,1.102521538734436,50000.0,0.6007000207901001,1.729142427444458,10000.0,67679.5624153614,73919.75740528107,67679.5624153614,6224.7675149440765,7.207941770553589,0.0 -146500,4.185532,4.155547,,,,,,,,,,,,,, -146600,4.1657796,1.7655354,,,,,,,,,,,,,, -146700,4.381468,1.6815915,,,,,,,,,,,,,, -146800,4.5490675,4.172939,,,,,,,,,,,,,, -146900,3.8292367,2.0774539,,,,,,,,,,,,,, -147000,4.1572027,1.691914,,,,,,,,,,,,,, -147100,4.2165637,1.7662064,,,,,,,,,,,,,, -147200,3.500645,3.241788,,,,,,,,,,,,,, -147300,3.9128006,1.6000947,,,,,,,,,,,,,, -147315,,,0.7919921875,0.8182185888290405,0.7255399823188782,1.113688826560974,50000.0,0.6016000509262085,1.728710412979126,10000.0,68099.6663825512,74383.72610616684,68099.6663825512,6268.529381752014,7.259737014770508,0.0 -147400,3.6681454,3.0551195,,,,,,,,,,,,,, -147500,4.0532823,2.6370752,,,,,,,,,,,,,, -147600,4.2881374,1.6669632,,,,,,,,,,,,,, -147700,4.5194182,2.1072485,,,,,,,,,,,,,, -147800,4.8636537,1.6474756,,,,,,,,,,,,,, -147900,4.004147,1.6251347,,,,,,,,,,,,,, -148000,3.984644,1.7044475,,,,,,,,,,,,,, -148100,3.858758,2.0260463,,,,,,,,,,,,,, -148200,4.307051,1.7806312,,,,,,,,,,,,,, -148224,,,0.8016796708106995,0.7689365744590759,0.7294999957084656,1.0898467302322388,50000.0,0.6040000319480896,1.6969178915023804,10000.0,68519.65882515907,74840.09953451157,68519.65882515907,6304.80890417099,7.309703350067139,0.0 -148300,4.1618643,1.9930743,,,,,,,,,,,,,, -148400,3.9899049,3.0790105,,,,,,,,,,,,,, -148500,3.8340743,3.484936,,,,,,,,,,,,,, -148600,3.785273,1.6516403,,,,,,,,,,,,,, -148700,4.080778,1.5277656,,,,,,,,,,,,,, -148800,4.2188635,4.0474305,,,,,,,,,,,,,, -148900,3.9911036,1.8335764,,,,,,,,,,,,,, -149000,4.8355837,1.5393677,,,,,,,,,,,,,, -149100,5.0049148,1.633723,,,,,,,,,,,,,, -149132,,,0.7948632836341858,0.8024986982345581,0.7267000079154968,1.0969750881195068,50000.0,0.6046000123023987,1.711686372756958,10000.0,68939.72495675087,75302.07324838638,68939.72495675087,6346.617129325867,7.357788801193237,0.0 -149200,4.0690603,1.6296881,,,,,,,,,,,,,, -149300,4.100327,2.204357,,,,,,,,,,,,,, -149400,4.640232,4.0065956,,,,,,,,,,,,,, -149500,4.838805,1.5069665,,,,,,,,,,,,,, -149600,3.937498,2.5101862,,,,,,,,,,,,,, -149700,4.587201,2.9227622,,,,,,,,,,,,,, -149800,4.3802485,1.7427579,,,,,,,,,,,,,, -149900,4.836405,1.6951879,,,,,,,,,,,,,, -150000,4.394236,1.6880882,,,,,,,,,,,,,, -150039,,,0.7957812547683716,0.812127411365509,0.7278800010681152,1.094635009765625,50000.0,0.6082000136375427,1.716665267944336,10000.0,69359.74498486519,75761.56579613686,69359.74498486519,6385.988118648529,7.408616781234741,0.0 -150100,4.2910633,2.8289282,,,,,,,,,,,,,, -150200,4.4675245,1.6522158,,,,,,,,,,,,,, -150300,4.304016,3.6371572,,,,,,,,,,,,,, -150400,3.8370566,3.08729,,,,,,,,,,,,,, -150500,4.060591,1.6591247,,,,,,,,,,,,,, -150600,4.101966,1.7997081,,,,,,,,,,,,,, -150700,4.5472155,3.810835,,,,,,,,,,,,,, -150800,3.8801556,2.5055654,,,,,,,,,,,,,, -150900,4.503926,1.5234711,,,,,,,,,,,,,, -150949,,,0.8050976395606995,0.7490441799163818,0.7310199737548828,1.0718308687210083,50000.0,0.6103000044822693,1.6876506805419922,10000.0,69780.05288887024,76225.19110178947,69780.05288887024,6429.201789140701,7.461533546447754,0.0 -151000,4.459345,1.5254548,,,,,,,,,,,,,, -151100,4.297399,1.7445929,,,,,,,,,,,,,, -151200,4.42548,1.6107942,,,,,,,,,,,,,, -151300,4.370276,3.673207,,,,,,,,,,,,,, -151400,4.4764824,3.3412383,,,,,,,,,,,,,, -151500,4.0430207,2.6792543,,,,,,,,,,,,,, -151600,4.3152113,1.5852178,,,,,,,,,,,,,, -151700,4.3411174,1.6762043,,,,,,,,,,,,,, -151800,4.118387,2.1760817,,,,,,,,,,,,,, -151859,,,0.8020703196525574,0.7700572609901428,0.7359799742698669,1.066151738166809,50000.0,0.6122000217437744,1.6726754903793335,10000.0,70200.3424217701,76687.96187376976,70200.3424217701,6471.581538200378,7.511557102203369,0.0 -151900,4.3448133,1.5924146,,,,,,,,,,,,,, -152000,4.825909,1.6209264,,,,,,,,,,,,,, -152100,4.5657396,1.6451149,,,,,,,,,,,,,, -152200,4.376588,1.6532152,,,,,,,,,,,,,, -152300,4.006399,3.102149,,,,,,,,,,,,,, -152400,4.1316557,1.7790675,,,,,,,,,,,,,, -152500,4.3925357,1.8440201,,,,,,,,,,,,,, -152600,4.133952,2.022254,,,,,,,,,,,,,, -152700,4.3733377,1.5221,,,,,,,,,,,,,, -152769,,,0.8058202862739563,0.7566875219345093,0.7363399863243103,1.0597726106643677,50000.0,0.614300012588501,1.6564711332321167,10000.0,70620.28472137451,77147.348580122,70620.28472137451,6510.923154830933,7.5636162757873535,0.0 -152800,4.381661,1.6448642,,,,,,,,,,,,,, -152900,4.571039,3.6765473,,,,,,,,,,,,,, -153000,4.4391737,2.9168692,,,,,,,,,,,,,, -153100,4.1622777,3.4541378,,,,,,,,,,,,,, -153200,4.626109,1.7370381,,,,,,,,,,,,,, -153300,4.1909776,2.1800063,,,,,,,,,,,,,, -153400,4.189559,1.503458,,,,,,,,,,,,,, -153500,4.205463,2.6508482,,,,,,,,,,,,,, -153600,4.3156047,1.6376638,,,,,,,,,,,,,, -153678,,,0.81298828125,0.7200682759284973,0.7379800081253052,1.0499589443206787,50000.0,0.6166000366210938,1.6612064838409424,10000.0,71040.23987174034,77604.85888195038,71040.23987174034,6548.374161958695,7.616273880004883,0.0 -153700,4.5750504,1.6558683,,,,,,,,,,,,,, -153800,4.2740383,2.5366712,,,,,,,,,,,,,, -153900,5.1978626,1.5907313,,,,,,,,,,,,,, -154000,4.7408886,1.602071,,,,,,,,,,,,,, -154100,5.4338226,3.9535432,,,,,,,,,,,,,, -154200,4.436411,1.8054934,,,,,,,,,,,,,, -154300,4.2933903,1.9138551,,,,,,,,,,,,,, -154400,4.570478,1.4989501,,,,,,,,,,,,,, -154500,4.237087,2.3442283,,,,,,,,,,,,,, -154587,,,0.8049609065055847,0.746151864528656,0.737339973449707,1.0490286350250244,50000.0,0.6158000230789185,1.6725051403045654,10000.0,71460.52765202522,78063.63143539429,71460.52765202522,6586.756381750107,7.667668581008911,0.0 -154600,4.7637596,1.536762,,,,,,,,,,,,,, -154700,4.7103767,1.5920825,,,,,,,,,,,,,, -154800,4.357322,1.8212079,,,,,,,,,,,,,, -154900,4.761257,1.6284277,,,,,,,,,,,,,, -155000,4.3252115,2.212307,,,,,,,,,,,,,, -155100,4.7854414,1.7812653,,,,,,,,,,,,,, -155200,4.6829295,1.6282969,,,,,,,,,,,,,, -155300,4.431113,2.1384892,,,,,,,,,,,,,, -155400,5.1335588,1.629634,,,,,,,,,,,,,, -155497,,,0.810839831829071,0.723305344581604,0.7399199604988098,1.035240888595581,50000.0,0.6164000034332275,1.652098298072815,10000.0,71880.68959569931,78524.88885855675,71880.68959569931,6627.746407985687,7.721628189086914,0.0 -155500,5.2436094,4.0010557,,,,,,,,,,,,,, -155600,4.317945,1.5426666,,,,,,,,,,,,,, -155700,5.259094,2.4000607,,,,,,,,,,,,,, -155800,4.603499,1.5239826,,,,,,,,,,,,,, -155900,4.4073234,2.0371335,,,,,,,,,,,,,, -156000,5.260271,1.7389948,,,,,,,,,,,,,, -156100,4.912743,1.8014874,,,,,,,,,,,,,, -156200,4.899172,1.5620217,,,,,,,,,,,,,, -156300,4.4102907,3.3068202,,,,,,,,,,,,,, -156400,4.7870183,3.9894648,,,,,,,,,,,,,, -156406,,,0.8157812356948853,0.7162489295005798,0.7403799891471863,1.0414212942123413,50000.0,0.62090003490448,1.653515338897705,10000.0,72300.75505638123,78981.39456152916,72300.75505638123,6664.082603693008,7.774799346923828,0.0 -156500,4.933386,1.565386,,,,,,,,,,,,,, -156600,4.6850643,1.7049301,,,,,,,,,,,,,, -156700,4.716921,1.5791115,,,,,,,,,,,,,, -156800,4.627549,1.4996693,,,,,,,,,,,,,, -156900,4.384672,1.6975529,,,,,,,,,,,,,, -157000,4.6018586,1.4279332,,,,,,,,,,,,,, -157100,4.807559,1.5033476,,,,,,,,,,,,,, -157200,4.4642777,3.0189102,,,,,,,,,,,,,, -157300,4.8375297,1.5308111,,,,,,,,,,,,,, -157314,,,0.8150194883346558,0.7227945327758789,0.7415800094604492,1.0487637519836426,50000.0,0.6237000226974487,1.6517783403396606,10000.0,72720.79807853699,79443.12803387642,72720.79807853699,6705.670013189316,7.826829195022583,0.0 -157400,5.083042,1.5723904,,,,,,,,,,,,,, -157500,5.054735,1.4954238,,,,,,,,,,,,,, -157600,5.1048365,1.5005306,,,,,,,,,,,,,, -157700,4.9169726,1.5438662,,,,,,,,,,,,,, -157800,4.66267,1.7648251,,,,,,,,,,,,,, -157900,5.167237,1.5799084,,,,,,,,,,,,,, -158000,4.4301825,1.5461783,,,,,,,,,,,,,, -158100,5.1051474,1.4561907,,,,,,,,,,,,,, -158200,4.7339253,3.0928075,,,,,,,,,,,,,, -158223,,,0.8199414014816284,0.7127540707588196,0.7445399761199951,1.0261250734329224,50000.0,0.626800000667572,1.6282966136932373,10000.0,73141.04412603378,79906.46152997017,73141.04412603378,6748.654703617096,7.878124952316284,0.0 -158300,4.9885206,1.719894,,,,,,,,,,,,,, -158400,4.8140855,1.635546,,,,,,,,,,,,,, -158500,4.592237,1.7996199,,,,,,,,,,,,,, -158600,5.6609836,3.5719275,,,,,,,,,,,,,, -158700,4.4263515,2.7198136,,,,,,,,,,,,,, -158800,4.9417386,3.5036554,,,,,,,,,,,,,, -158900,4.965267,2.3934453,,,,,,,,,,,,,, -159000,4.879825,1.5127543,,,,,,,,,,,,,, -159100,4.447586,1.9919944,,,,,,,,,,,,,, -159130,,,0.8240038752555847,0.6871118545532227,0.747439980506897,1.0210548639297483,50000.0,0.6269000172615051,1.622787356376648,10000.0,73561.13456368446,80370.60028123856,73561.13456368446,6792.603569984436,7.927542448043823,0.0 -159200,4.992254,2.5613496,,,,,,,,,,,,,, -159300,4.9318047,1.622144,,,,,,,,,,,,,, -159400,5.2600894,1.6406924,,,,,,,,,,,,,, -159500,4.9147186,1.5090659,,,,,,,,,,,,,, -159600,5.314946,3.1008663,,,,,,,,,,,,,, -159700,5.1003003,3.170702,,,,,,,,,,,,,, -159800,4.7947917,1.5447384,,,,,,,,,,,,,, -159900,4.7830105,2.7421494,,,,,,,,,,,,,, -160000,5.0875998,3.5860734,,,,,,,,,,,,,, -160039,,,0.8271484375,0.6677748560905457,0.7458199858665466,1.0177862644195557,50000.0,0.6220000386238098,1.6190569400787354,10000.0,73981.15388822556,80829.69215226173,73981.15388822556,6831.573830366135,7.979724168777466,0.0 -160100,5.2910523,3.0568175,,,,,,,,,,,,,, -160200,5.2283435,3.2066226,,,,,,,,,,,,,, -160300,5.4611297,3.755849,,,,,,,,,,,,,, -160400,4.8349543,1.5644649,,,,,,,,,,,,,, -160500,5.246345,1.3310789,,,,,,,,,,,,,, -160600,5.7018123,3.5882013,,,,,,,,,,,,,, -160700,5.021765,1.5847372,,,,,,,,,,,,,, -160800,5.106628,3.2684884,,,,,,,,,,,,,, -160900,5.2269993,1.8021342,,,,,,,,,,,,,, -160948,,,0.8216406106948853,0.6819668412208557,0.7469599843025208,1.0042126178741455,50000.0,0.6239000558853149,1.608596682548523,10000.0,74401.07062625885,81287.48914146423,74401.07062625885,6869.350848913193,8.031014919281006,0.0 -161000,4.899248,1.9066516,,,,,,,,,,,,,, -161100,5.1493874,1.3769774,,,,,,,,,,,,,, -161200,4.9332337,1.5423193,,,,,,,,,,,,,, -161300,5.0890765,1.5664178,,,,,,,,,,,,,, -161400,4.63897,2.1690705,,,,,,,,,,,,,, -161500,5.1738844,2.7759428,,,,,,,,,,,,,, -161600,5.3108034,1.5061911,,,,,,,,,,,,,, -161700,5.720708,1.5509307,,,,,,,,,,,,,, -161800,5.0242324,1.349074,,,,,,,,,,,,,, -161856,,,0.8240429759025574,0.6735599637031555,0.7506600022315979,0.9978360533714294,50000.0,0.6320000290870667,1.5975457429885864,10000.0,74821.04405879974,81749.93900322914,74821.04405879974,6911.723528146744,8.083473920822144,0.0 -161900,5.543312,3.5584917,,,,,,,,,,,,,, -162000,5.5760064,1.6285651,,,,,,,,,,,,,, -162100,6.3932514,3.86937,,,,,,,,,,,,,, -162200,5.119869,1.5673915,,,,,,,,,,,,,, -162300,6.869681,1.4612155,,,,,,,,,,,,,, -162400,5.218779,1.436313,,,,,,,,,,,,,, -162500,5.6511636,1.580131,,,,,,,,,,,,,, -162600,5.0883837,1.4819082,,,,,,,,,,,,,, -162700,5.5883026,1.4385885,,,,,,,,,,,,,, -162765,,,0.8299023509025574,0.6468636989593506,0.7511199712753296,0.9937317967414856,50000.0,0.6312000155448914,1.600247502326965,10000.0,75240.97953391075,82215.21632957458,75240.97953391075,6956.957853317261,8.137528657913208,0.0 -162800,5.72983,3.5386002,,,,,,,,,,,,,, -162900,5.0929904,3.0751824,,,,,,,,,,,,,, -163000,5.660496,1.5584016,,,,,,,,,,,,,, -163100,5.2451134,2.3880928,,,,,,,,,,,,,, -163200,4.985847,1.3981751,,,,,,,,,,,,,, -163300,5.616801,1.6031232,,,,,,,,,,,,,, -163400,5.117429,1.5816228,,,,,,,,,,,,,, -163500,5.5027785,3.0232997,,,,,,,,,,,,,, -163600,6.464455,3.4296384,,,,,,,,,,,,,, -163675,,,0.8234374523162842,0.6833599209785461,0.7499200105667114,0.9993503093719482,50000.0,0.6281000375747681,1.6094584465026855,10000.0,75661.26156401634,82674.41547107697,75661.26156401634,6995.772277355194,8.188964128494263,0.0 -163700,5.6128125,2.4686587,,,,,,,,,,,,,, -163800,5.495387,1.5202433,,,,,,,,,,,,,, -163900,5.949727,1.5296253,,,,,,,,,,,,,, -164000,5.5196867,2.9833534,,,,,,,,,,,,,, -164100,6.5849266,3.9122655,,,,,,,,,,,,,, -164200,6.3146987,3.808545,,,,,,,,,,,,,, -164300,5.5640345,3.4600623,,,,,,,,,,,,,, -164400,5.334046,1.5988978,,,,,,,,,,,,,, -164500,5.090824,1.34694,,,,,,,,,,,,,, -164583,,,0.8306640386581421,0.6563959121704102,0.7524600028991699,0.9888281226158142,50000.0,0.6324000358581543,1.5833725929260254,10000.0,76081.20295095444,83135.30335402489,76081.20295095444,7036.61114192009,8.245900869369507,0.0 -164600,5.59644,1.4760687,,,,,,,,,,,,,, -164700,5.725207,3.1455383,,,,,,,,,,,,,, -164800,5.2399116,2.0841968,,,,,,,,,,,,,, -164900,5.279284,1.733848,,,,,,,,,,,,,, -165000,6.4968457,1.5461051,,,,,,,,,,,,,, -165100,5.5568433,2.970376,,,,,,,,,,,,,, -165200,5.4510913,1.6186664,,,,,,,,,,,,,, -165300,5.599234,1.7513479,,,,,,,,,,,,,, -165400,5.3341794,1.4686086,,,,,,,,,,,,,, -165493,,,0.8374218344688416,0.6225919723510742,0.7544599771499634,0.9747494459152222,50000.0,0.6305000185966492,1.5695737600326538,10000.0,76501.53271389008,83594.82507920265,76501.53271389008,7075.693947792053,8.304190874099731,0.0 -165500,5.16528,2.6495345,,,,,,,,,,,,,, -165600,6.552411,3.7652197,,,,,,,,,,,,,, -165700,6.076065,1.4906249,,,,,,,,,,,,,, -165800,5.2210298,1.4358528,,,,,,,,,,,,,, -165900,5.8189673,1.9488078,,,,,,,,,,,,,, -166000,5.292177,1.6176746,,,,,,,,,,,,,, -166100,5.4325614,2.0321171,,,,,,,,,,,,,, -166200,6.2004113,1.741364,,,,,,,,,,,,,, -166300,5.5738406,1.6479778,,,,,,,,,,,,,, -166400,5.197275,2.037641,,,,,,,,,,,,,, -166403,,,0.8303124904632568,0.6463891267776489,0.7556399703025818,0.9667149782180786,50000.0,0.6349000334739685,1.5555449724197388,10000.0,76921.90054345131,84057.80018854141,76921.90054345131,7118.194668292999,8.359277486801147,0.0 -166500,5.813925,1.3654575,,,,,,,,,,,,,, -166600,5.202008,2.8747883,,,,,,,,,,,,,, -166700,5.4936676,1.8426344,,,,,,,,,,,,,, -166800,5.772556,1.4812406,,,,,,,,,,,,,, -166900,5.9586697,3.4224076,,,,,,,,,,,,,, -167000,6.308682,3.2181451,,,,,,,,,,,,,, -167100,5.532713,1.4931566,,,,,,,,,,,,,, -167200,6.29789,2.1704817,,,,,,,,,,,,,, -167300,5.520452,1.61934,,,,,,,,,,,,,, -167311,,,0.8325781226158142,0.6565069556236267,0.7567799687385559,0.9810318350791932,50000.0,0.6397000551223755,1.5751323699951172,10000.0,77342.16409659386,84515.86970067024,77342.16409659386,7155.896743297577,8.412142515182495,0.0 -167400,5.6412206,3.1710114,,,,,,,,,,,,,, -167500,6.367657,1.4503276,,,,,,,,,,,,,, -167600,5.897591,1.4307128,,,,,,,,,,,,,, -167700,6.0765357,3.2180262,,,,,,,,,,,,,, -167701,,,,,,,,,,,77520.09283590317,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 3995c072f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -28.721853494644165,0.0,36.40435767173767,1,0,36.40435767173767,0.0010000000474974,6.907756805419922,10000,65.12631821632385,0.0012109375093132,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -69.03572416305542,0.0187666416168212,456.5975193977356,850,0,456.5975193977356,0.0304000005125999,6.007944583892822,10000,525.6996972560883,0.0379492193460464,5.836334705352783,0.0373799987137317,5.868350028991699,50000 -110.14891505241394,0.0506336688995361,876.5749049186707,1753,0,876.5749049186707,0.0419000014662742,5.780269145965576,10000,986.8721923828124,0.0610937476158142,5.500517845153809,0.0549599975347518,5.583978176116943,50000 -153.70484042167664,0.0762069225311279,1296.616675615311,2659,0,1296.616675615311,0.0767000019550323,5.3251471519470215,10000,1450.5467166900637,0.1049999967217445,4.955112934112549,0.095660001039505,5.036952495574951,50000 -191.8669803142548,0.1044766902923584,1716.937399148941,3565,0,1716.937399148941,0.0967000052332878,5.13610315322876,10000,1909.107887744904,0.1318359375,4.715084075927734,0.1221200004220008,4.792366981506348,50000 -230.93383693695068,0.1339180469512939,2137.0898151397705,4470,0,2137.0898151397705,0.1157000064849853,4.900367259979248,10000,2368.407334804535,0.1639648377895355,4.44169282913208,0.1498399972915649,4.542600631713867,50000 -276.35117983818054,0.1603636741638183,2557.313781738281,5375,0,2557.313781738281,0.1370000094175338,4.693284034729004,10000,2834.125443696976,0.1900781244039535,4.193375587463379,0.1757399886846542,4.2860283851623535,50000 -319.9392533302307,0.1883072853088379,2977.3031663894653,6280,0,2977.3031663894653,0.1522000133991241,4.536256790161133,10000,3297.7814099788666,0.2101367115974426,4.0672478675842285,0.1936799883842468,4.1517534255981445,50000 -364.5954835414887,0.2181315422058105,3397.616981983185,7187,0,3397.616981983185,0.1632000058889389,4.448765754699707,10000,3762.831591129303,0.2340429574251175,3.8879802227020255,0.2138399928808212,4.022542476654053,50000 -404.7730875015259,0.2447774410247802,3817.795110702514,8094,0,3817.795110702514,0.1687000095844268,4.404850482940674,10000,4223.264901161194,0.2420898377895355,3.8262104988098153,0.2233599871397018,3.9299376010894775,50000 -448.5805284976959,0.2743458747863769,4237.933657169342,9001,0,4237.933657169342,0.1903000026941299,4.200378894805908,10000,4687.290641784668,0.2715038955211639,3.572762489318848,0.2484599947929382,3.701709032058716,50000 -490.6866438388825,0.2993443012237549,4658.288609981537,9908,0,4658.288609981537,0.2045000046491623,4.137157917022705,10000,5149.827425003052,0.2799023389816284,3.503488063812256,0.2557999789714813,3.6524834632873535,50000 -529.848639011383,0.8035287857055664,5077.988896608353,10808,0,5077.988896608353,0.1838000118732452,4.281069278717041,10000,5609.243244171143,0.2586914002895355,3.687410593032837,0.2386999875307083,3.820055484771729,50000 -574.7927012443542,0.8307521343231201,5498.148392438889,11712,0,5498.148392438889,0.2044000029563903,4.096872806549072,10000,6074.42378115654,0.2904882729053497,3.440134286880493,0.2700800001621246,3.5581839084625244,50000 -617.9171347618103,0.8616728782653809,5918.338252067566,12615,0,5918.338252067566,0.2028000056743621,4.120352745056152,10000,6537.819251060486,0.2845117151737213,3.495602369308472,0.262939989566803,3.6273386478424072,50000 -658.4205052852631,0.8963735103607178,6338.559593200684,13521,0,6338.559593200684,0.2183000147342682,3.9869000911712646,10000,6998.628676176071,0.3322265446186065,3.20548152923584,0.2886999845504761,3.4614808559417725,50000 -698.0081906318665,0.9258718490600586,6758.897345304489,14428,0,6758.897345304489,0.211200013756752,4.033987522125244,10000,7458.634178161621,0.301562488079071,3.388521671295166,0.280019998550415,3.510899305343628,50000 -742.7038342952728,0.957528829574585,7178.821423768997,15334,0,7178.821423768997,0.2208000123500824,3.9570510387420654,10000,7923.335958003998,0.3197460770606994,3.255380868911743,0.2985399961471557,3.3957231044769287,50000 -786.7944359779358,0.9867620468139648,7598.78031373024,16242,0,7598.78031373024,0.2251000106334686,3.943704128265381,10000,8387.46527838707,0.3459570109844208,3.120103120803833,0.2992799878120422,3.3753139972686768,50000 -823.8475241661072,1.0148942470550537,8019.018817424774,17149,0,8019.018817424774,0.2376000136137008,3.858373641967773,10000,8844.835416793823,0.3329882621765136,3.1438841819763184,0.3075200021266937,3.2924156188964844,50000 -865.6273958683014,1.045391082763672,8438.975415468216,18053,0,8438.975415468216,0.2346000075340271,3.882843255996704,10000,9306.65334200859,0.3213476538658142,3.226347208023072,0.2976999878883362,3.3574440479278564,50000 -910.0973136425018,1.0767600536346436,8859.113800764084,18960,0,8859.113800764084,0.2371000051498413,3.893336057662964,10000,9771.34357905388,0.3495312333106994,3.1147797107696533,0.3076199889183044,3.3317410945892334,50000 -954.149953365326,1.10762619972229,9279.488339185717,19867,0,9279.488339185717,0.2427000105381012,3.791513919830322,10000,10235.852065563202,0.3369531035423279,3.1348671913146973,0.3179399967193603,3.2607951164245605,50000 -997.4499151706696,1.1384267807006836,9699.60234951973,20771,0,9699.60234951973,0.2291000038385391,3.92555832862854,10000,10699.347157001495,0.3226367235183716,3.2174956798553467,0.2995399832725525,3.355835199356079,50000 -1043.6431443691254,1.169325590133667,10119.622790336609,21675,0,10119.622790336609,0.2485000044107437,3.8394834995269775,10000,11165.643399238586,0.3468359410762787,3.1015145778656006,0.3150199949741363,3.295905113220215,50000 -1088.2273619174955,1.2005112171173096,10539.923010349274,22581,0,10539.923010349274,0.2472000122070312,3.7654662132263184,10000,11630.60943365097,0.3467577993869781,3.073084592819214,0.3256799876689911,3.1851108074188232,50000 -1134.8523106575012,1.2339890003204346,10959.892746925354,23486,0,10959.892746925354,0.2502000033855438,3.758253812789917,10000,12097.288171768188,0.3531640470027923,3.034122467041016,0.3314200043678283,3.17514705657959,50000 -1180.8493909835815,1.2654187679290771,11380.030760765076,24390,0,11380.030760765076,0.2608000040054321,3.6874732971191406,10000,12563.504616975784,0.370410144329071,2.9094772338867188,0.3361800014972687,3.098283052444458,50000 -1226.3632354736328,1.2985377311706543,11800.12370634079,25296,0,11800.12370634079,0.252700001001358,3.7258048057556152,10000,13029.195685625076,0.3532421886920929,3.014591217041016,0.3339999914169311,3.1400809288024902,50000 -1271.7826147079468,1.3286545276641846,12220.688608169556,26202,0,12220.688608169556,0.2502000033855438,3.755117416381836,10000,13495.26099729538,0.3569726347923279,3.0196990966796875,0.3306399881839752,3.152756690979004,50000 -1319.2206687927246,1.3581278324127195,12640.947337388992,27109,0,12640.947337388992,0.2606000006198883,3.661961078643799,10000,13963.037720918655,0.3746679723262787,2.8854188919067383,0.3447200059890747,3.0645339488983154,50000 -1359.713604927063,1.389518976211548,13060.903878450394,28014,0,13060.903878450394,0.2556000053882599,3.7341465950012207,10000,14423.568979501724,0.36474609375,2.9812686443328857,0.3385799825191498,3.11944842338562,50000 -1404.037055015564,1.4189238548278809,13481.02701640129,28921,0,13481.02701640129,0.2603000104427337,3.740095853805542,10000,14888.096348762512,0.3594140410423279,3.011570692062378,0.3353599905967712,3.157099723815918,50000 -1448.2114634513855,1.4525518417358398,13901.145128250122,29827,0,13901.145128250122,0.2722000181674957,3.626361608505249,10000,15352.472299575806,0.3867773413658142,2.8524653911590576,0.3562400043010711,3.020305633544922,50000 -1492.4437718391418,1.489807367324829,14321.3808157444,30733,0,14321.3808157444,0.2699000239372253,3.6251919269561768,10000,15817.029076099396,0.40380859375,2.765493631362915,0.3507199883460998,3.034952402114868,50000 -1536.8387801647186,1.5209715366363523,14741.569679737093,31638,0,14741.569679737093,0.2657999992370605,3.6384551525115967,10000,16281.695380210876,0.372871071100235,2.894007921218872,0.3455199897289276,3.053222179412842,50000 -1580.6043591499329,1.5537612438201904,15161.727267503738,32541,0,15161.727267503738,0.2789000272750854,3.585503578186035,10000,16745.701989650726,0.3836914002895355,2.851547718048096,0.3562799990177154,3.018123149871826,50000 -1622.586046218872,1.5887606143951416,15581.699176073074,33446,0,15581.699176073074,0.2671000063419342,3.6647629737854,10000,17207.740622758865,0.4034960865974426,2.7900664806365967,0.3520599901676178,3.0779330730438232,50000 -1664.3043756484983,1.623687982559204,16001.772747516632,34350,0,16001.772747516632,0.2727999985218048,3.618873119354248,10000,17669.6182513237,0.3739843666553497,2.903052568435669,0.3515399992465973,3.0420234203338623,50000 -1703.8427374362946,1.655987024307251,16422.13542985916,35257,0,16422.13542985916,0.281000018119812,3.570101261138916,10000,18129.60253977776,0.3941406309604645,2.7809207439422607,0.3646000027656555,2.946488380432129,50000 -1750.1966817378998,1.688964605331421,16842.51434326172,36160,0,16842.51434326172,0.289000004529953,3.515514612197876,10000,18596.41860938072,0.4143359363079071,2.687903881072998,0.374699980020523,2.915180206298828,50000 -1794.428700685501,1.7227368354797363,17262.60005879402,37067,0,17262.60005879402,0.296500027179718,3.5063552856445312,10000,19060.820858955383,0.3991992175579071,2.74881911277771,0.3721599876880646,2.8969502449035645,50000 -1835.2586352825165,1.7550582885742188,17682.65801167488,37975,0,17682.65801167488,0.2865000069141388,3.4961540699005127,10000,19521.791479349136,0.4001367092132568,2.7493491172790527,0.3765600025653839,2.8872478008270264,50000 -1881.6751172542567,1.7953834533691406,18102.94758486748,38881,0,18102.94758486748,0.2857000231742859,3.532196283340454,10000,19988.588298797607,0.4105273485183716,2.71467924118042,0.3744199872016907,2.916626453399658,50000 -1925.902155160904,1.829215288162232,18522.871568918228,39787,0,18522.871568918228,0.2865000069141388,3.536340713500977,10000,20452.82318210601,0.3961132764816284,2.812042713165283,0.3680199980735779,2.944639205932617,50000 -1970.02938914299,1.8625555038452148,18942.88044810295,40692,0,18942.88044810295,0.2774000167846679,3.621984243392944,10000,20917.042991399765,0.3843359351158142,2.880928039550781,0.3576000034809112,3.033652782440185,50000 -2012.153751850128,1.900541305541992,19363.05272817612,41598,0,19363.05272817612,0.3025000095367431,3.3992068767547607,10000,21379.428280830383,0.4193945229053497,2.614825487136841,0.3873199820518493,2.794647216796875,50000 -2058.390007257461,1.9348745346069336,19783.368386268616,42501,0,19783.368386268616,0.2962000072002411,3.468750238418579,10000,21846.06435275078,0.4103906154632568,2.7060678005218506,0.3838399946689605,2.85146713256836,50000 -2098.456430912018,1.9667832851409912,20203.399605989456,43408,0,20203.399605989456,0.3060000240802765,3.412951707839966,10000,22306.24627304077,0.4212304651737213,2.6352226734161377,0.3901399970054626,2.7732999324798584,50000 -2140.2166497707367,1.9997265338897705,20623.64171051979,44314,0,20623.64171051979,0.3040000200271606,3.4005684852600098,10000,22768.33304667473,0.424628883600235,2.5954785346984863,0.3944399952888489,2.772815465927124,50000 -2186.9113490581512,2.037261247634888,21043.79127836228,45220,0,21043.79127836228,0.3058000206947326,3.4516892433166504,10000,23235.264661312103,0.4167773425579071,2.6995503902435303,0.3911199867725372,2.838123083114624,50000 -2230.829014539718,2.072231769561768,21463.8963804245,46125,0,21463.8963804245,0.2909000217914581,3.441751718521118,10000,23699.373474121094,0.4143945276737213,2.6916167736053467,0.3907199800014496,2.808105945587158,50000 -2276.559750556946,2.1068155765533447,21884.086091279984,47029,0,21884.086091279984,0.3026000261306762,3.3919708728790283,10000,24165.378796577454,0.4289257824420929,2.5802981853485107,0.3979199826717376,2.764410257339477,50000 -2320.15681385994,2.1419427394866943,22304.20796895027,47934,0,22304.20796895027,0.3137000203132629,3.3566060066223145,10000,24629.18269968033,0.4466210901737213,2.509028196334839,0.4015399813652038,2.7524359226226807,50000 -2363.196498155594,2.1787829399108887,22724.48389363289,48840,0,22724.48389363289,0.3037000000476837,3.428317070007324,10000,25092.585326194763,0.422656238079071,2.647350788116455,0.3947599828243255,2.812582969665528,50000 -2408.141443490982,2.212597131729126,23144.46818780899,49746,0,23144.46818780899,0.3099000155925751,3.355839252471924,10000,25557.598313570023,0.4342578053474426,2.547525405883789,0.3990799784660339,2.7272603511810303,50000 -2454.2899844646454,2.249370574951172,23564.61387562752,50654,0,23564.61387562752,0.313400000333786,3.36251187324524,10000,26023.98063802719,0.4611327946186065,2.438607215881348,0.4008199870586395,2.7578697204589844,50000 -2497.8036675453186,2.286559820175171,23984.93890619278,51560,0,23984.93890619278,0.314300000667572,3.3453259468078613,10000,26487.90759134293,0.4273632764816284,2.59151029586792,0.4017999768257141,2.755169630050659,50000 -2540.815685033798,2.323159694671631,24404.984939336777,52465,0,24404.984939336777,0.3096000254154205,3.3761465549468994,10000,26951.052780389786,0.4356249868869781,2.548950672149658,0.4038800001144409,2.7284858226776123,50000 -2585.6624703407288,2.361346244812012,24825.079872131348,53370,0,24825.079872131348,0.3069000244140625,3.367358684539795,10000,27416.082621097565,0.4490039050579071,2.478402137756348,0.4007599949836731,2.747389316558838,50000 -2628.652004241944,2.3979926109313965,25245.164329767227,54276,0,25245.164329767227,0.3165000081062317,3.3001091480255127,10000,27879.24419236183,0.4378320276737213,2.539738655090332,0.4072999954223633,2.691128969192505,50000 -2673.9957478046417,2.4337828159332275,25665.14268875122,55183,0,25665.14268875122,0.3144000172615051,3.315009593963623,10000,28344.652994155884,0.4406445324420929,2.527313470840454,0.4094799757003784,2.68900203704834,50000 -2719.850363969803,2.47387957572937,26085.327237844467,56089,0,26085.327237844467,0.3237000107765198,3.3011837005615234,10000,28810.783006429672,0.4562695324420929,2.470243453979492,0.4152199923992157,2.69305157661438,50000 -2765.686028242111,2.5117812156677246,26505.334926843643,56995,0,26505.334926843643,0.3207000195980072,3.299652099609375,10000,29276.715250492096,0.4382421672344208,2.575744867324829,0.4134999811649322,2.716792821884156,50000 -2810.310459375381,2.5485615730285645,26925.474437713623,57901,0,26925.474437713623,0.3211000263690948,3.273627281188965,10000,29741.56713962555,0.4457812309265136,2.49980092048645,0.4198599755764007,2.653440475463867,50000 -2854.4081478118896,2.5851545333862305,27345.811821460724,58809,0,27345.811821460724,0.3300000131130218,3.225630521774292,10000,30206.088725328445,0.4603320062160492,2.395381212234497,0.4214800000190735,2.604405164718628,50000 -2897.407002687454,2.621112585067749,27765.990793704987,59714,0,27765.990793704987,0.3210000097751617,3.3422865867614746,10000,30669.352987527847,0.4366796910762787,2.612844705581665,0.4098999798297882,2.748467445373535,50000 -2945.836772441864,2.6567232608795166,28186.21174645424,60619,0,28186.21174645424,0.3256000280380249,3.250809669494629,10000,31138.09031653404,0.4499804675579071,2.467857837677002,0.4191199839115143,2.639002799987793,50000 -2990.0949623584747,2.695803642272949,28606.36927008629,61528,0,28606.36927008629,0.3350000083446502,3.217402696609497,10000,31602.596896648407,0.4624609351158142,2.3810806274414062,0.4293999969959259,2.5703067779541016,50000 -3031.917558193207,2.7409818172454834,29026.63190627098,62435,0,29026.63190627098,0.3382000029087066,3.219500541687012,10000,32064.77761387825,0.4630273282527923,2.4418976306915283,0.4314000010490417,2.5994136333465576,50000 -3075.777815580368,2.7796471118927,29446.979049682617,63342,0,29446.979049682617,0.3413000106811523,3.1512272357940674,10000,32529.07609534264,0.4673046767711639,2.354846954345703,0.4369799792766571,2.523207664489746,50000 -3120.9421343803406,2.8182005882263184,29867.28816127777,64248,0,29867.28816127777,0.3419000208377838,3.1868813037872314,10000,32994.63881659508,0.4647851586341858,2.3934576511383057,0.4303799867630005,2.585111856460572,50000 -3160.80952334404,2.8595807552337646,30287.48583388329,65155,0,30287.48583388329,0.3257000148296356,3.3210971355438232,10000,33454.79567170143,0.4700976312160492,2.458969593048096,0.4227599799633026,2.704218626022339,50000 -3203.6639914512634,2.895918846130371,30707.519748210907,66059,0,30707.519748210907,0.3461000025272369,3.131437301635742,10000,33917.77133059502,0.4700585901737213,2.350682258605957,0.4400999844074249,2.517786741256714,50000 -3249.730548620224,2.936368942260742,31127.525916337967,66962,0,31127.525916337967,0.3354000151157379,3.2074503898620605,10000,34383.9345805645,0.4616210758686065,2.3884122371673584,0.4339599907398224,2.554633617401123,50000 -3294.203216075897,2.976597309112549,31547.554701805115,67868,0,31547.554701805115,0.3392000198364258,3.1619865894317627,10000,34848.52709269524,0.5029687285423279,2.1818504333496094,0.4394199848175049,2.515748977661133,50000 -3337.2650215625763,3.0178487300872803,31967.763231039047,68775,0,31967.763231039047,0.3438000082969665,3.122487545013428,10000,35311.88851070404,0.4720703065395355,2.326566219329834,0.4456599950790405,2.470954656600952,50000 -3381.8008399009705,3.0573742389678955,32387.87904167176,69682,0,32387.87904167176,0.34620001912117,3.122164487838745,10000,35776.63043308258,0.4788476526737213,2.295114278793335,0.4430199861526489,2.4969727993011475,50000 -3422.816102027893,3.100841999053955,32807.88691544533,70588,0,32807.88691544533,0.34620001912117,3.092029333114624,10000,36237.747968912125,0.4978906214237213,2.208686113357544,0.4494799971580505,2.4724645614624023,50000 -3468.0950469970703,3.1422665119171143,33228.21760249138,71496,0,33228.21760249138,0.3406000137329101,3.128873109817505,10000,36703.45001864433,0.4721484184265136,2.327339172363281,0.4462999999523163,2.475956678390503,50000 -3511.7916102409363,3.1832900047302246,33648.21060991287,72403,0,33648.21060991287,0.3487000167369842,3.151477575302124,10000,37167.23076486588,0.4733007848262787,2.381632089614868,0.4436999857425689,2.5497748851776123,50000 -3553.8385372161865,3.2199816703796387,34068.42133665085,73310,0,34068.42133665085,0.3387000262737274,3.1976158618927,10000,37629.576330661774,0.4833788871765136,2.329761266708374,0.4386399984359741,2.551053047180176,50000 -3599.0153257846832,3.2576658725738525,34488.45884680748,74216,0,34488.45884680748,0.3589000105857849,3.0651798248291016,10000,38094.87999844551,0.4860742092132568,2.254650354385376,0.4541399776935577,2.429095268249512,50000 -3645.705919742584,3.294504880905152,34908.47983670235,75122,0,34908.47983670235,0.3505000174045563,3.143037796020508,10000,38561.67894077301,0.4874609410762787,2.328404903411865,0.4536999762058258,2.506525754928589,50000 -3691.841574668884,3.332691431045532,35328.69760990143,76031,0,35328.69760990143,0.3543000221252441,3.110901117324829,10000,39028.121799230576,0.49755859375,2.247817039489746,0.4552599787712097,2.464518785476685,50000 -3732.874413013458,3.374640941619873,35748.71830582619,76940,0,35748.71830582619,0.3583000302314758,3.084025621414185,10000,39489.267315626144,0.4865234196186065,2.275959014892578,0.4581599831581116,2.418885946273804,50000 -3774.578455448151,3.415534734725952,36168.80930709839,77847,0,36168.80930709839,0.3667000234127044,3.028707265853882,10000,39951.15533995628,0.5015038847923279,2.2032933235168457,0.4641399979591369,2.393378496170044,50000 -3821.7439455986014,3.457021474838257,36588.75086402893,78755,0,36588.75086402893,0.3612000048160553,3.025886297225952,10000,40418.35488009453,0.5059765577316284,2.1536612510681152,0.466399997472763,2.365823268890381,50000 -3862.8010079860687,3.495908260345459,37009.04603528976,79662,0,37009.04603528976,0.3664000034332275,3.014451026916504,10000,40879.798095703125,0.5085351467132568,2.1568405628204346,0.4750399887561798,2.333594560623169,50000 -3907.055018186569,3.535282850265503,37429.33586239815,80571,0,37429.33586239815,0.358100026845932,3.056978940963745,10000,41344.432941913605,0.4918359220027923,2.267910480499268,0.4589399993419647,2.433488130569458,50000 -3951.705055236816,3.575528621673584,37849.65655899048,81479,0,37849.65655899048,0.3653000295162201,3.051608085632324,10000,41809.49540543556,0.5015624761581421,2.221553087234497,0.4664799869060516,2.406338691711426,50000 -3991.25189948082,3.620238780975342,38269.87842488289,82386,0,38269.87842488289,0.3764000236988067,2.9651641845703125,10000,42269.36018133164,0.5391015410423279,2.031850099563598,0.4758799970149994,2.33320951461792,50000 -4035.173714160919,3.662067651748657,38689.997086048126,83292,0,38689.997086048126,0.370600014925003,2.987058639526367,10000,42733.49368548393,0.509472668170929,2.1737000942230225,0.472599983215332,2.343513250350952,50000 -4080.785442113876,3.705613374710083,39110.11424565315,84198,0,39110.11424565315,0.3580000102519989,3.0713653564453125,10000,43199.31651544571,0.5048437118530273,2.2214019298553467,0.4667199850082397,2.4125983715057373,50000 -4120.204581737518,3.74727201461792,39530.06701803208,85105,0,39530.06701803208,0.3827000260353088,2.939318895339966,10000,43658.78203487396,0.5434765219688416,1.975115776062012,0.4802799820899963,2.3029189109802246,50000 -4163.963869810104,3.7868733406066895,39950.27555394173,86012,0,39950.27555394173,0.3664000034332275,3.0009925365448,10000,44122.84530091286,0.5088672041893005,2.185390949249268,0.4745599925518036,2.349569082260132,50000 -4209.52586555481,3.8280630111694336,40370.32043933869,86918,0,40370.32043933869,0.3716000318527221,2.941880702972412,10000,44588.54420852661,0.5166406035423279,2.1085727214813232,0.4803199768066406,2.2921464443206787,50000 -4254.7183582782745,3.870031595230103,40790.41343569756,87822,0,40790.41343569756,0.3785000145435333,2.930760622024536,10000,45053.92148447037,0.5345312356948853,2.0101184844970703,0.4843799769878387,2.270280122756958,50000 -4300.670683383942,3.910280704498291,41210.61529803276,88728,0,41210.61529803276,0.3844000101089477,2.910942554473877,10000,45520.16591215134,0.5253124833106995,2.1078312397003174,0.4916599988937378,2.2708756923675537,50000 -4343.729335069656,3.955927610397339,41630.89320707321,89635,0,41630.89320707321,0.3828000128269195,2.9128763675689697,10000,45983.5998442173,0.5248632431030273,2.0753753185272217,0.4898599982261657,2.257879734039306,50000 -4388.4740562438965,4.000329256057739,42050.96832442284,90538,0,42050.96832442284,0.3754000067710876,2.9723060131073,10000,46448.51464676857,0.5314648151397705,2.07725191116333,0.4858999848365783,2.3086907863616943,50000 -4433.668882369995,4.045256614685059,42471.20595598221,91442,0,42471.20595598221,0.394400030374527,2.8569884300231934,10000,46914.04198694229,0.5350781083106995,2.023699283599853,0.5030400156974792,2.1932621002197266,50000 -4475.179137229919,4.08874249458313,42891.15717124939,92349,0,42891.15717124939,0.3882000148296356,2.938272476196289,10000,47375.598071336746,0.5308398604393005,2.097571849822998,0.4921799898147583,2.277677059173584,50000 -4518.571710586548,4.132167816162109,43311.2950565815,93254,0,43311.2950565815,0.3873000144958496,2.8787736892700195,10000,47839.22241187096,0.5414257645606995,2.011777639389038,0.4967799782752991,2.2263717651367188,50000 -4564.084159851074,4.174811124801636,43731.3213903904,94158,0,43731.3213903904,0.395300030708313,2.828667640686035,10000,48304.85384202004,0.5350976586341858,2.0232887268066406,0.5008000135421753,2.1791908740997314,50000 -4609.533478021622,4.222652435302734,44151.28519535065,95065,0,44151.28519535065,0.4004000127315521,2.8348753452301025,10000,48770.36529827118,0.5455663800239563,1.9697304964065552,0.5080199837684631,2.1573288440704346,50000 -4650.271003246307,4.266034364700317,44571.42405152321,95972,0,44571.42405152321,0.3910000324249267,2.8307206630706787,10000,49231.33609867096,0.5479297041893005,1.936432957649231,0.5101799964904785,2.144627571105957,50000 -4691.759567499161,4.712217807769775,44991.071533203125,96873,0,44991.071533203125,0.4003000259399414,2.800229549407959,10000,49692.96869182587,0.5483007431030273,1.9771265983581543,0.5109999775886536,2.1710920333862305,50000 -4738.096249103546,4.756515741348267,45411.392758369446,97778,0,45411.392758369446,0.4104000329971313,2.800655603408813,10000,50159.7207980156,0.550976574420929,1.9703459739685056,0.5161399841308594,2.141354322433472,50000 -4782.682325363159,4.801445245742798,45831.61665439606,98684,0,45831.61665439606,0.4022000133991241,2.7948946952819824,10000,50624.62672114372,0.5575000047683716,1.9146215915679927,0.5131999850273132,2.129566669464112,50000 -4826.09882068634,4.854480981826782,46251.78324842453,99591,0,46251.78324842453,0.4103000164031982,2.728025436401367,10000,51088.31325173378,0.587890625,1.753287434577942,0.5281199812889099,2.0528218746185303,50000 -4874.545390844345,4.899670839309692,46672.05391287804,100497,0,46672.05391287804,0.4115000069141388,2.745382308959961,10000,51557.12585926056,0.5584765672683716,1.894142508506775,0.5205199718475342,2.08619236946106,50000 -4917.531066417694,4.94485330581665,47092.40252113342,101401,0,47092.40252113342,0.4159000217914581,2.720703601837158,10000,52020.5557975769,0.5635937452316284,1.8744672536849976,0.5206999778747559,2.0845861434936523,50000 -4961.000032663345,4.988691568374634,47512.490429639816,102306,0,47512.490429639816,0.417900025844574,2.692728281021118,10000,52484.20654010773,0.6005273461341858,1.6896134614944458,0.5290799736976624,2.035361051559448,50000 -5001.267645597458,5.030147075653076,47932.62344169617,103213,0,47932.62344169617,0.4070000052452087,2.746877193450928,10000,52944.69896483421,0.5564843416213989,1.9020169973373413,0.5239399671554565,2.0747263431549072,50000 -5046.142510414124,5.07442831993103,48352.95508027077,104120,0,48352.95508027077,0.41880002617836,2.700232744216919,10000,53410.00083827973,0.5759570002555847,1.842762351036072,0.5297799706459045,2.0636205673217773,50000 -5091.086785554886,5.116716623306274,48773.32455801964,105028,0,48773.32455801964,0.4181000292301178,2.695244312286377,10000,53875.40821886063,0.5927929282188416,1.7482671737670898,0.5322999954223633,2.0392508506774902,50000 -5135.292055368424,5.161932468414307,49193.34681630135,105936,0,49193.34681630135,0.422400027513504,2.671083688735962,10000,54339.73200273514,0.571972668170929,1.83349883556366,0.5362200140953064,2.0098884105682373,50000 -5176.745602607727,5.209371089935303,49613.69072461128,106841,0,49613.69072461128,0.4261000156402588,2.6580498218536377,10000,54801.62851881981,0.5798632502555847,1.801438570022583,0.5390200018882751,2.008123874664306,50000 -5224.058357954025,5.257244348526001,50034.01110982895,107744,0,50034.01110982895,0.4178000092506408,2.7498693466186523,10000,55269.35957503319,0.5798437595367432,1.8626606464385984,0.5293999910354614,2.1074492931365967,50000 -5264.811240434647,5.306406021118164,50454.28623723984,108651,0,50454.28623723984,0.4240000247955322,2.657770156860352,10000,55730.48776316643,0.5818749666213989,1.7912267446517944,0.5414400100708008,1.9961179494857788,50000 -5304.367951393127,5.35333776473999,50874.582459926605,109557,0,50874.582459926605,0.4294000267982483,2.6383376121521,10000,56190.43788194656,0.5807812213897705,1.7843594551086426,0.5475599765777588,1.9676730632781985,50000 -5349.13404250145,5.398436546325684,51294.605360507965,110464,0,51294.605360507965,0.4315000176429748,2.643044948577881,10000,56655.322309970856,0.5948437452316284,1.731536626815796,0.5445600152015686,1.979048132896424,50000 -5388.784100055695,5.452826261520386,51714.772094249725,111371,0,51714.772094249725,0.4382000267505646,2.608139991760254,10000,57115.24357366562,0.5899804830551147,1.7508074045181274,0.555079996585846,1.9349216222763064,50000 -5433.386849164963,5.500851154327393,52135.0346596241,112281,0,52135.0346596241,0.4413000345230102,2.598820447921753,10000,57580.20839238167,0.5955273509025574,1.7364590167999268,0.5569800138473511,1.9330426454544067,50000 -5479.865357398987,5.547860145568848,52555.18597626686,113187,0,52555.18597626686,0.4421000182628631,2.5687875747680664,10000,58046.9350438118,0.6030077934265137,1.6891416311264038,0.5582199692726135,1.9116791486740112,50000 -5526.562952518463,5.59241247177124,52975.51297545433,114093,0,52975.51297545433,0.4409000277519226,2.6083571910858154,10000,58514.05428671837,0.5959765315055847,1.7394243478775024,0.5590400099754333,1.9202549457550049,50000 -5572.246104717255,5.641361236572266,53395.86148428917,114999,0,53395.86148428917,0.4438000321388244,2.568891763687134,10000,58980.18489718437,0.5999413728713989,1.698953628540039,0.5589599609375,1.901583433151245,50000 -5615.125262737274,5.690145254135132,53815.892790317535,115903,0,53815.892790317535,0.447700023651123,2.532124757766724,10000,59443.194900512695,0.6122460961341858,1.6250638961791992,0.5637800097465515,1.8601083755493164,50000 -5657.219936609268,5.732917547225952,54236.21877121925,116809,0,54236.21877121925,0.4429000318050384,2.542642116546631,10000,59905.71079039574,0.6307226419448853,1.5751163959503174,0.5652799606323242,1.8742843866348269,50000 -5699.84024477005,5.780933141708374,54656.47298908234,117715,0,54656.47298908234,0.4442000091075897,2.5546884536743164,10000,60368.68414545059,0.606738269329071,1.6873126029968262,0.5666399598121643,1.881292462348938,50000 -5743.124422311783,5.825575351715088,55076.77208423615,118621,0,55076.77208423615,0.4595000147819519,2.4840798377990723,10000,60832.36305522919,0.6252343654632568,1.57450532913208,0.5776000022888184,1.8012522459030151,50000 -5788.15007519722,5.872597932815552,55496.69383692741,119523,0,55496.69383692741,0.4567000269889831,2.4974396228790283,10000,61297.40745139122,0.6426367163658142,1.501293420791626,0.5740599632263184,1.8339369297027588,50000 -5832.912647247314,5.921623706817627,55916.784044504166,120427,0,55916.784044504166,0.4561000168323517,2.485478162765503,10000,61762.35899710655,0.6160937547683716,1.6148782968521118,0.5806800127029419,1.8061535358428955,50000 -5881.302688598633,5.969411849975586,56336.82581615448,121333,0,56336.82581615448,0.458400011062622,2.487220764160156,10000,62230.889113903046,0.6248632669448853,1.5781731605529783,0.5776399970054626,1.8077296018600464,50000 -5926.17863202095,6.0188164710998535,56756.91334247589,122237,0,56756.91334247589,0.458400011062622,2.5103447437286377,10000,62695.95139026642,0.6359570026397705,1.568046808242798,0.5774999856948853,1.843104600906372,50000 -5967.953478097916,6.071156740188599,57177.12469792366,123144,0,57177.12469792366,0.4677000343799591,2.4270832538604736,10000,63158.04018783569,0.6289648413658142,1.563175916671753,0.5860599875450134,1.769154667854309,50000 -6014.114964962006,6.117965459823608,57597.355749607086,124049,0,57597.355749607086,0.4705000221729278,2.4021661281585693,10000,63624.529782772064,0.640429675579071,1.5160832405090332,0.5918599963188171,1.7468619346618652,50000 -6060.277943134308,6.1684510707855225,58017.52013874054,124954,0,58017.52013874054,0.4673000276088714,2.402538537979126,10000,64090.95868706703,0.6468749642372131,1.47525954246521,0.5937199592590332,1.733345627784729,50000 -6104.252175807953,6.215758085250855,58437.53057861328,125858,0,58437.53057861328,0.4754000306129455,2.3764123916625977,10000,64555.04180955887,0.6384179592132568,1.4969202280044556,0.5963000059127808,1.702694058418274,50000 -6149.214513778687,6.2652997970581055,58857.83850765228,126763,0,58857.83850765228,0.4791000187397003,2.3670125007629395,10000,65020.412427186966,0.6460155844688416,1.478101372718811,0.6011399626731873,1.703931212425232,50000 -6196.937285423279,6.313244581222534,59278.02184915543,127669,0,59278.02184915543,0.4782000184059143,2.389633655548096,10000,65488.41676783562,0.6504492163658142,1.4836300611495972,0.5961999893188477,1.7273157835006714,50000 -6241.280877828598,6.368543863296509,59698.199717760086,128576,0,59698.199717760086,0.4870000183582306,2.3387036323547363,10000,65953.0438401699,0.6483983993530273,1.4797208309173584,0.6062399744987488,1.687266826629639,50000 -6287.754509687424,6.416506052017212,60118.53882169724,129480,0,60118.53882169724,0.4841000139713287,2.354218006134033,10000,66419.9552268982,0.6539648175239563,1.4522738456726074,0.605679988861084,1.6788485050201416,50000 -6333.296562671661,6.466786623001099,60538.44319176674,130386,0,60538.44319176674,0.4922000169754028,2.3320703506469727,10000,66885.50264811516,0.6702734231948853,1.4016005992889404,0.6136199831962585,1.660366773605347,50000 -6374.600378513336,6.51424241065979,60958.65554857254,131293,0,60958.65554857254,0.4907000362873077,2.317131996154785,10000,67347.1164739132,0.6552929282188416,1.4327030181884766,0.6084200143814087,1.6539702415466309,50000 -6420.612956047058,6.562076568603516,61378.65146899223,132198,0,61378.65146899223,0.4917000234127044,2.3008532524108887,10000,67813.22391462326,0.6626366972923279,1.4006588459014893,0.6137999892234802,1.6258078813552856,50000 -6465.681235074997,6.609776496887207,61798.64251708984,133104,0,61798.64251708984,0.4901000261306762,2.3164613246917725,10000,68278.38157367706,0.6655077934265137,1.4070724248886108,0.6142599582672119,1.652784824371338,50000 -6509.769501447678,6.65800929069519,62218.54648470879,134008,0,62218.54648470879,0.497700035572052,2.2720818519592285,10000,68742.47394442558,0.6812499761581421,1.345516324043274,0.6206799745559692,1.6100428104400637,50000 -6556.1755702495575,6.708175182342529,62638.87726259232,134913,0,62638.87726259232,0.5078000426292419,2.2082858085632324,10000,69209.31068348885,0.6734570264816284,1.3531352281570437,0.6266799569129944,1.563097596168518,50000 -6598.604954004288,6.757187128067017,63059.16834306717,135817,0,63059.16834306717,0.5031000375747681,2.244506359100342,10000,69672.13042116165,0.6738085746765137,1.3599903583526611,0.6240599751472473,1.5829092264175415,50000 -6644.519300699234,6.807169198989868,63479.30086803436,136722,0,63479.30086803436,0.5020000338554382,2.246837615966797,10000,70138.27771234512,0.7079687118530273,1.2369790077209473,0.6272000074386597,1.5897619724273682,50000 -6685.11887216568,6.857578992843628,63899.45568084717,137630,0,63899.45568084717,0.5099000334739685,2.199458122253418,10000,70599.13270401955,0.675585925579071,1.340919017791748,0.6327599883079529,1.5627877712249756,50000 -6728.542403936386,6.904958009719849,64319.47856760025,138533,0,64319.47856760025,0.5120000243186951,2.1723082065582275,10000,71062.68217468262,0.6859960556030273,1.2904889583587646,0.6377399563789368,1.5166523456573486,50000 -6772.798098802567,6.95450758934021,64739.70188331604,139435,0,64739.70188331604,0.5063000321388245,2.2396981716156006,10000,71527.2613837719,0.6988281011581421,1.2646912336349487,0.6321799755096436,1.5676641464233398,50000 -6817.2175986766815,7.002084493637085,65159.62976360321,140342,0,65159.62976360321,0.5200999975204468,2.154827356338501,10000,71991.70762300491,0.6906836032867432,1.2769279479980469,0.6407399773597717,1.5057283639907837,50000 -6862.2890038490295,7.054901123046875,65579.96733903885,141247,0,65579.96733903885,0.5234000086784363,2.1413466930389404,10000,72457.2199523449,0.69837886095047,1.2515182495117188,0.6464599967002869,1.4908294677734375,50000 -6907.379415750504,7.107923269271851,66000.20277547836,142151,0,66000.20277547836,0.5210000276565552,2.128354549407959,10000,72922.64905381203,0.7085741758346558,1.1992639303207395,0.6479600071907043,1.4763096570968628,50000 -6955.451377868652,7.158038854598999,66420.13485479355,143055,0,66420.13485479355,0.5277000069618225,2.1308555603027344,10000,73390.75317406654,0.7023242115974426,1.2293777465820312,0.6479600071907043,1.4754472970962524,50000 -7001.00201010704,7.209303379058838,66840.06156110764,143961,0,66840.06156110764,0.5301000475883484,2.108567237854004,10000,73856.33251214027,0.70570307970047,1.21626615524292,0.6527599692344666,1.4616146087646484,50000 -7046.307956695557,7.260730981826782,67260.06018710136,144867,0,67260.06018710136,0.534500002861023,2.0833964347839355,10000,74321.73847436905,0.7190039157867432,1.1549161672592163,0.6559000015258789,1.4345026016235352,50000 -7092.093881845474,7.313398361206055,67680.40997552872,145774,0,67680.40997552872,0.5312000513076782,2.109943389892578,10000,74787.97717380524,0.7044140696525574,1.2267926931381226,0.6544199585914612,1.4555846452713013,50000 -7140.838541984558,7.363656282424927,68100.34601187706,146680,0,68100.34601187706,0.5371000170707703,2.0639452934265137,10000,75256.75840878487,0.7185156345367432,1.158596158027649,0.6607599854469299,1.4224454164505005,50000 -7182.027892827988,7.416749477386475,68520.35908174515,147585,0,68520.35908174515,0.541700005531311,2.0514132976531982,10000,75718.0649137497,0.7215625047683716,1.143648982048035,0.6613399982452393,1.4184284210205078,50000 -7227.158820390701,7.469902276992798,68940.4190402031,148493,0,68940.4190402031,0.5446000099182129,2.0252130031585693,10000,76183.35967731476,0.7193359136581421,1.1469899415969849,0.667140007019043,1.39079749584198,50000 -7272.813814640045,7.522792100906372,69360.67560601234,149400,0,69360.67560601234,0.542900025844574,2.033937692642212,10000,76649.3745880127,0.7261328101158142,1.1292051076889038,0.6677199602127075,1.3793786764144895,50000 -7314.125680685043,7.57967209815979,69780.98648524284,150306,0,69780.98648524284,0.5432000160217285,2.0280535221099854,10000,77111.10426402092,0.7345898151397705,1.0911519527435305,0.6721599698066711,1.3719079494476318,50000 -7360.41232419014,7.630008459091186,70201.0912911892,151210,0,70201.0912911892,0.5479000210762024,2.0155558586120605,10000,77577.59732437134,0.7315624952316284,1.0995270013809204,0.6723600029945374,1.3750596046447754,50000 -7405.974725008011,7.682317733764648,70620.99999403954,152116,0,70620.99999403954,0.5561000108718872,1.9908647537231443,10000,78043.17117524147,0.7346875071525574,1.0947026014328003,0.6771000027656555,1.352385401725769,50000 -7447.667580366135,7.731944561004639,71041.29298949242,153020,0,71041.29298949242,0.5586000084877014,1.9496073722839355,10000,78505.25744843483,0.7451757788658142,1.0271271467208862,0.6819199919700623,1.3150537014007568,50000 -7490.222132205963,7.780643224716186,71461.58991360664,153922,0,71461.58991360664,0.5543000102043152,1.9502512216567995,10000,78968.2080783844,0.758496105670929,0.978778839111328,0.6845600008964539,1.3084824085235596,50000 -7537.654074668884,7.830984830856323,71881.67134642601,154829,0,71881.67134642601,0.5617000460624695,1.938288211822509,10000,79435.82341265678,0.7406249642372131,1.0503544807434082,0.6873799562454224,1.2965481281280518,50000 -7582.568460941315,7.88397216796875,72301.95320534706,155736,0,72301.95320534706,0.5610000491142273,1.922777771949768,10000,79901.12423586845,0.7502343654632568,1.0065470933914185,0.687559962272644,1.2907850742340088,50000 -7627.477694272995,7.94585371017456,72721.92038750648,156641,0,72721.92038750648,0.5648000240325928,1.9105957746505733,10000,80366.11267876625,0.76318359375,0.952807605266571,0.6926800012588501,1.2713115215301514,50000 -7673.310721158981,8.001603126525879,73141.96051359177,157547,0,73141.96051359177,0.572700023651123,1.8889102935791016,10000,80832.09207010269,0.7520703077316284,1.0017552375793457,0.6958400011062622,1.251082420349121,50000 -7713.773415803909,8.056279420852661,73562.28885316849,158452,0,73562.28885316849,0.5720000267028809,1.892466902732849,10000,81292.98815059662,0.7565234303474426,0.9877767562866212,0.6963799595832825,1.252785563468933,50000 -7759.871674776077,8.111968994140625,73982.47636890411,159359,0,73982.47636890411,0.5732000470161438,1.8835759162902832,10000,81759.38010597229,0.7693945169448853,0.939318060874939,0.6983799934387207,1.2409757375717163,50000 -7804.115849733353,8.173494815826416,74402.69127678871,160264,0,74402.69127678871,0.5814000368118286,1.869041085243225,10000,82223.95091509819,0.7602343559265137,0.971134066581726,0.6994400024414062,1.2386878728866575,50000 -7851.924923419952,8.229990243911743,74822.95345377922,161168,0,74822.95345377922,0.5782999992370605,1.8563323020935056,10000,82692.12824940681,0.7683203220367432,0.9366870522499084,0.7050999999046326,1.2145330905914309,50000 -7895.971745014191,8.284348249435425,75242.92199492455,162070,0,75242.92199492455,0.5851000547409058,1.8382530212402344,10000,83156.2474398613,0.77783203125,0.8856746554374695,0.7069199681282043,1.2027339935302734,50000 -7937.003118753433,8.34020447731018,75663.21282863617,162976,0,75663.21282863617,0.5868000388145447,1.793388605117798,10000,83617.6762702465,0.7721288800239563,0.9069911241531372,0.7118799686431885,1.1784546375274658,50000 -7984.124661684036,8.398981094360352,76083.81787419319,163883,0,76083.81787419319,0.5900000333786011,1.8173900842666624,10000,84085.51246571541,0.7769335508346558,0.8996948599815369,0.7105799913406372,1.1845804452896118,50000 -8030.3428745269775,8.457646608352661,76504.19021439552,164791,0,76504.19021439552,0.5868000388145447,1.812342643737793,10000,84552.21259260178,0.7859765291213989,0.87266606092453,0.7135399580001831,1.186859130859375,50000 -8075.156850099564,8.51274585723877,76924.15655565262,165696,0,76924.15655565262,0.5914000272750854,1.7810871601104736,10000,85017.09811878204,0.7836523056030273,0.8662102222442627,0.7154799699783325,1.161203145980835,50000 -8122.107517242432,8.569658279418945,77344.08736562729,166602,0,77344.08736562729,0.595300018787384,1.7856409549713135,10000,85484.08644890785,0.7826757431030273,0.8744816184043884,0.7170400023460388,1.1657946109771729,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/measurements.csv deleted file mode 100644 index 273b2e013..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1857 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36613074,6.9077563,,,,,,,,,,,,,, -1,,,0.0012109375093132,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,36.40435767173767,65.12631821632385,36.40435767173767,28.721853494644165,0.0,0.0 -100,0.5026673,6.8646927,,,,,,,,,,,,,, -200,0.87780493,6.7224402,,,,,,,,,,,,,, -300,0.9597464,6.5656753,,,,,,,,,,,,,, -400,1.1293039,6.4527893,,,,,,,,,,,,,, -500,0.8986033,6.482828,,,,,,,,,,,,,, -600,1.3880694,6.357982,,,,,,,,,,,,,, -700,0.95491713,6.8025713,,,,,,,,,,,,,, -800,0.7636935,6.327052,,,,,,,,,,,,,, -850,,,0.0379492193460464,5.836334705352783,0.0373799987137317,5.868350028991699,50000.0,0.0304000005125999,6.007944583892822,10000.0,456.5975193977356,525.6996972560883,456.5975193977356,69.03572416305542,0.0187666416168212,0.0 -900,0.79328316,6.156171,,,,,,,,,,,,,, -1000,0.7723462,6.176821,,,,,,,,,,,,,, -1100,0.7683089,6.0873456,,,,,,,,,,,,,, -1200,0.71333855,5.969474,,,,,,,,,,,,,, -1300,0.6383229,6.043,,,,,,,,,,,,,, -1400,0.62578416,6.123153,,,,,,,,,,,,,, -1500,0.58902425,6.128996,,,,,,,,,,,,,, -1600,0.63934886,5.899905,,,,,,,,,,,,,, -1700,0.56833154,6.471216,,,,,,,,,,,,,, -1753,,,0.0610937476158142,5.500517845153809,0.0549599975347518,5.583978176116943,50000.0,0.0419000014662742,5.780269145965576,10000.0,876.5749049186707,986.8721923828124,876.5749049186707,110.14891505241394,0.0506336688995361,0.0 -1800,0.5575941,6.3596835,,,,,,,,,,,,,, -1900,0.49406505,6.646168,,,,,,,,,,,,,, -2000,0.49529424,5.7615747,,,,,,,,,,,,,, -2100,0.68372214,5.7773886,,,,,,,,,,,,,, -2200,0.4735354,5.6540813,,,,,,,,,,,,,, -2300,0.53925705,5.683591,,,,,,,,,,,,,, -2400,0.53789073,5.6139913,,,,,,,,,,,,,, -2500,0.46459708,5.6230807,,,,,,,,,,,,,, -2600,0.51330805,5.5380855,,,,,,,,,,,,,, -2659,,,0.1049999967217445,4.955112934112549,0.095660001039505,5.036952495574951,50000.0,0.0767000019550323,5.3251471519470215,10000.0,1296.616675615311,1450.5467166900637,1296.616675615311,153.70484042167664,0.0762069225311279,0.0 -2700,0.47981018,6.358533,,,,,,,,,,,,,, -2800,0.45399776,5.340251,,,,,,,,,,,,,, -2900,0.5613094,5.510897,,,,,,,,,,,,,, -3000,0.45191863,5.567216,,,,,,,,,,,,,, -3100,0.44126642,5.3349414,,,,,,,,,,,,,, -3200,0.533241,6.3702188,,,,,,,,,,,,,, -3300,0.90547657,5.3647075,,,,,,,,,,,,,, -3400,0.46495348,5.351703,,,,,,,,,,,,,, -3500,0.48612288,5.6960154,,,,,,,,,,,,,, -3565,,,0.1318359375,4.715084075927734,0.1221200004220008,4.792366981506348,50000.0,0.0967000052332878,5.13610315322876,10000.0,1716.937399148941,1909.107887744904,1716.937399148941,191.8669803142548,0.1044766902923584,0.0 -3600,0.6632005,6.2456455,,,,,,,,,,,,,, -3700,0.7527672,5.25135,,,,,,,,,,,,,, -3800,0.6764079,5.3071165,,,,,,,,,,,,,, -3900,0.636554,5.386167,,,,,,,,,,,,,, -4000,0.5921555,5.228377,,,,,,,,,,,,,, -4100,0.7115057,5.2806935,,,,,,,,,,,,,, -4200,0.65273815,5.0810227,,,,,,,,,,,,,, -4300,0.8480277,5.1317377,,,,,,,,,,,,,, -4400,0.7156684,6.2610073,,,,,,,,,,,,,, -4470,,,0.1639648377895355,4.44169282913208,0.1498399972915649,4.542600631713867,50000.0,0.1157000064849853,4.900367259979248,10000.0,2137.0898151397705,2368.407334804535,2137.0898151397705,230.93383693695068,0.1339180469512939,0.0 -4500,0.69790685,5.084432,,,,,,,,,,,,,, -4600,0.6151472,4.9844275,,,,,,,,,,,,,, -4700,0.64886427,4.9451194,,,,,,,,,,,,,, -4800,0.69446975,6.34702,,,,,,,,,,,,,, -4900,0.67583793,4.9779773,,,,,,,,,,,,,, -5000,0.6250452,5.156717,,,,,,,,,,,,,, -5100,0.82771015,4.997974,,,,,,,,,,,,,, -5200,0.89456666,4.8684826,,,,,,,,,,,,,, -5300,0.6642363,6.2059617,,,,,,,,,,,,,, -5375,,,0.1900781244039535,4.193375587463379,0.1757399886846542,4.2860283851623535,50000.0,0.1370000094175338,4.693284034729004,10000.0,2557.313781738281,2834.125443696976,2557.313781738281,276.35117983818054,0.1603636741638183,0.0 -5400,0.9133067,6.429204,,,,,,,,,,,,,, -5500,0.9006122,5.4857674,,,,,,,,,,,,,, -5600,0.9013867,4.922329,,,,,,,,,,,,,, -5700,0.7712226,5.2093062,,,,,,,,,,,,,, -5800,0.605332,5.572266,,,,,,,,,,,,,, -5900,0.82127833,5.016424,,,,,,,,,,,,,, -6000,0.89714545,5.3091817,,,,,,,,,,,,,, -6100,0.78441054,4.726543,,,,,,,,,,,,,, -6200,0.6973718,5.365483,,,,,,,,,,,,,, -6280,,,0.2101367115974426,4.0672478675842285,0.1936799883842468,4.1517534255981445,50000.0,0.1522000133991241,4.536256790161133,10000.0,2977.3031663894653,3297.7814099788666,2977.3031663894653,319.9392533302307,0.1883072853088379,0.0 -6300,1.0037487,4.8657293,,,,,,,,,,,,,, -6400,0.77639633,4.970251,,,,,,,,,,,,,, -6500,0.93192315,5.094144,,,,,,,,,,,,,, -6600,0.79254097,5.343972,,,,,,,,,,,,,, -6700,0.6548689,6.297595,,,,,,,,,,,,,, -6800,0.7227132,5.0579367,,,,,,,,,,,,,, -6900,0.7641245,4.5994177,,,,,,,,,,,,,, -7000,0.71927816,6.371228,,,,,,,,,,,,,, -7100,0.846105,6.311278,,,,,,,,,,,,,, -7187,,,0.2340429574251175,3.8879802227020255,0.2138399928808212,4.022542476654053,50000.0,0.1632000058889389,4.448765754699707,10000.0,3397.616981983185,3762.831591129303,3397.616981983185,364.5954835414887,0.2181315422058105,0.0 -7200,0.6176763,6.156828,,,,,,,,,,,,,, -7300,0.6328699,6.1376586,,,,,,,,,,,,,, -7400,1.0859587,4.6488423,,,,,,,,,,,,,, -7500,0.75156605,5.3527045,,,,,,,,,,,,,, -7600,0.76998717,4.5463004,,,,,,,,,,,,,, -7700,0.65441513,4.6346483,,,,,,,,,,,,,, -7800,0.8754725,4.6354427,,,,,,,,,,,,,, -7900,0.5972992,5.523309,,,,,,,,,,,,,, -8000,0.9660842,4.55575,,,,,,,,,,,,,, -8094,,,0.2420898377895355,3.8262104988098153,0.2233599871397018,3.9299376010894775,50000.0,0.1687000095844268,4.404850482940674,10000.0,3817.795110702514,4223.264901161194,3817.795110702514,404.7730875015259,0.2447774410247802,0.0 -8100,0.85145855,4.4891367,,,,,,,,,,,,,, -8200,0.58387274,6.1689396,,,,,,,,,,,,,, -8300,0.83013797,4.6730933,,,,,,,,,,,,,, -8400,0.7921017,4.473904,,,,,,,,,,,,,, -8500,0.8367611,4.451767,,,,,,,,,,,,,, -8600,0.78429675,4.6733017,,,,,,,,,,,,,, -8700,0.81557465,5.0522623,,,,,,,,,,,,,, -8800,1.0183419,4.5376096,,,,,,,,,,,,,, -8900,0.623726,5.602408,,,,,,,,,,,,,, -9000,0.84918475,4.5387154,,,,,,,,,,,,,, -9001,,,0.2715038955211639,3.572762489318848,0.2484599947929382,3.701709032058716,50000.0,0.1903000026941299,4.200378894805908,10000.0,4237.933657169342,4687.290641784668,4237.933657169342,448.5805284976959,0.2743458747863769,0.0 -9100,0.7528808,6.2309427,,,,,,,,,,,,,, -9200,0.7241117,6.155617,,,,,,,,,,,,,, -9300,0.565229,5.911604,,,,,,,,,,,,,, -9400,0.76389354,6.2479377,,,,,,,,,,,,,, -9500,0.63106406,6.253528,,,,,,,,,,,,,, -9600,0.89826673,4.6720343,,,,,,,,,,,,,, -9700,0.8671069,4.5949326,,,,,,,,,,,,,, -9800,0.9878972,4.496834,,,,,,,,,,,,,, -9900,0.9646971,4.507008,,,,,,,,,,,,,, -9908,,,0.2799023389816284,3.503488063812256,0.2557999789714813,3.6524834632873535,50000.0,0.2045000046491623,4.137157917022705,10000.0,4658.288609981537,5149.827425003052,4658.288609981537,490.6866438388825,0.2993443012237549,0.0 -10000,0.74393237,4.435141,,,,,,,,,,,,,, -10100,0.6669383,5.2992105,,,,,,,,,,,,,, -10200,0.6267431,5.6868815,,,,,,,,,,,,,, -10300,0.79999006,4.3638496,,,,,,,,,,,,,, -10400,0.93934244,4.7805357,,,,,,,,,,,,,, -10500,0.81657445,4.3793035,,,,,,,,,,,,,, -10600,0.85426307,4.717287,,,,,,,,,,,,,, -10700,0.7721536,4.9944034,,,,,,,,,,,,,, -10800,0.80107105,6.239219,,,,,,,,,,,,,, -10808,,,0.2586914002895355,3.687410593032837,0.2386999875307083,3.820055484771729,50000.0,0.1838000118732452,4.281069278717041,10000.0,5077.988896608353,5609.243244171143,5077.988896608353,529.848639011383,0.8035287857055664,0.0 -10900,0.85119206,4.456387,,,,,,,,,,,,,, -11000,0.74694294,4.6899905,,,,,,,,,,,,,, -11100,0.8500642,4.4152884,,,,,,,,,,,,,, -11200,0.6849278,5.9820046,,,,,,,,,,,,,, -11300,1.0889766,4.6416726,,,,,,,,,,,,,, -11400,0.58508563,5.9059668,,,,,,,,,,,,,, -11500,0.9213524,4.347582,,,,,,,,,,,,,, -11600,0.78769624,4.28015,,,,,,,,,,,,,, -11700,1.0696558,4.4087467,,,,,,,,,,,,,, -11712,,,0.2904882729053497,3.440134286880493,0.2700800001621246,3.5581839084625244,50000.0,0.2044000029563903,4.096872806549072,10000.0,5498.148392438889,6074.42378115654,5498.148392438889,574.7927012443542,0.8307521343231201,0.0 -11800,0.9459999,4.506239,,,,,,,,,,,,,, -11900,1.2871575,4.3490467,,,,,,,,,,,,,, -12000,0.97026426,4.3828483,,,,,,,,,,,,,, -12100,1.1354142,4.4071884,,,,,,,,,,,,,, -12200,0.8700796,6.0534725,,,,,,,,,,,,,, -12300,0.87124085,5.5002604,,,,,,,,,,,,,, -12400,0.59341764,6.020584,,,,,,,,,,,,,, -12500,1.0783069,4.424327,,,,,,,,,,,,,, -12600,0.77185607,6.1527824,,,,,,,,,,,,,, -12615,,,0.2845117151737213,3.495602369308472,0.262939989566803,3.6273386478424072,50000.0,0.2028000056743621,4.120352745056152,10000.0,5918.338252067566,6537.819251060486,5918.338252067566,617.9171347618103,0.8616728782653809,0.0 -12700,0.8616754,4.4631057,,,,,,,,,,,,,, -12800,0.9672055,4.342987,,,,,,,,,,,,,, -12900,0.85633194,4.199679,,,,,,,,,,,,,, -13000,0.6699335,6.1414065,,,,,,,,,,,,,, -13100,0.8575717,4.4777327,,,,,,,,,,,,,, -13200,1.6282679,4.4908495,,,,,,,,,,,,,, -13300,0.9184172,4.2209067,,,,,,,,,,,,,, -13400,0.8529115,4.4755087,,,,,,,,,,,,,, -13500,0.7884413,5.348583,,,,,,,,,,,,,, -13521,,,0.3322265446186065,3.20548152923584,0.2886999845504761,3.4614808559417725,50000.0,0.2183000147342682,3.9869000911712646,10000.0,6338.559593200684,6998.628676176071,6338.559593200684,658.4205052852631,0.8963735103607178,0.0 -13600,0.95960224,4.2342777,,,,,,,,,,,,,, -13700,0.907327,4.247072,,,,,,,,,,,,,, -13800,0.8328937,4.2370687,,,,,,,,,,,,,, -13900,0.78116477,5.01066,,,,,,,,,,,,,, -14000,0.7207399,4.5017633,,,,,,,,,,,,,, -14100,1.016316,4.2904897,,,,,,,,,,,,,, -14200,0.9744484,4.2568874,,,,,,,,,,,,,, -14300,0.7348255,4.3317533,,,,,,,,,,,,,, -14400,0.75532174,5.723024,,,,,,,,,,,,,, -14428,,,0.301562488079071,3.388521671295166,0.280019998550415,3.510899305343628,50000.0,0.211200013756752,4.033987522125244,10000.0,6758.897345304489,7458.634178161621,6758.897345304489,698.0081906318665,0.9258718490600586,0.0 -14500,1.3661306,4.217581,,,,,,,,,,,,,, -14600,0.98624307,4.2545214,,,,,,,,,,,,,, -14700,0.9103895,5.9654927,,,,,,,,,,,,,, -14800,1.0785432,4.1729293,,,,,,,,,,,,,, -14900,0.6937628,5.8571258,,,,,,,,,,,,,, -15000,0.67796665,6.086543,,,,,,,,,,,,,, -15100,1.0349058,4.220075,,,,,,,,,,,,,, -15200,1.1155249,4.1930914,,,,,,,,,,,,,, -15300,0.9337533,4.4711833,,,,,,,,,,,,,, -15334,,,0.3197460770606994,3.255380868911743,0.2985399961471557,3.3957231044769287,50000.0,0.2208000123500824,3.9570510387420654,10000.0,7178.821423768997,7923.335958003998,7178.821423768997,742.7038342952728,0.957528829574585,0.0 -15400,0.84367025,4.1974745,,,,,,,,,,,,,, -15500,0.82603437,4.4154243,,,,,,,,,,,,,, -15600,0.965935,4.154888,,,,,,,,,,,,,, -15700,0.9489647,4.5009193,,,,,,,,,,,,,, -15800,0.7582674,5.442158,,,,,,,,,,,,,, -15900,1.1784043,4.168757,,,,,,,,,,,,,, -16000,1.0756639,4.1615725,,,,,,,,,,,,,, -16100,0.86575645,4.257573,,,,,,,,,,,,,, -16200,1.0099248,4.7654543,,,,,,,,,,,,,, -16242,,,0.3459570109844208,3.120103120803833,0.2992799878120422,3.3753139972686768,50000.0,0.2251000106334686,3.943704128265381,10000.0,7598.78031373024,8387.46527838707,7598.78031373024,786.7944359779358,0.9867620468139648,0.0 -16300,0.8750512,4.2447433,,,,,,,,,,,,,, -16400,1.1025293,4.393198,,,,,,,,,,,,,, -16500,0.765895,4.518778,,,,,,,,,,,,,, -16600,1.0000968,4.1976175,,,,,,,,,,,,,, -16700,0.92757857,4.5975327,,,,,,,,,,,,,, -16800,0.8072069,4.72233,,,,,,,,,,,,,, -16900,0.8133741,4.1417875,,,,,,,,,,,,,, -17000,0.7676934,4.9365096,,,,,,,,,,,,,, -17100,0.94698286,4.120989,,,,,,,,,,,,,, -17149,,,0.3329882621765136,3.1438841819763184,0.3075200021266937,3.2924156188964844,50000.0,0.2376000136137008,3.858373641967773,10000.0,8019.018817424774,8844.835416793823,8019.018817424774,823.8475241661072,1.0148942470550537,0.0 -17200,0.9223108,4.181518,,,,,,,,,,,,,, -17300,0.97091323,4.2239943,,,,,,,,,,,,,, -17400,0.9788287,4.2639346,,,,,,,,,,,,,, -17500,0.7111505,4.9299417,,,,,,,,,,,,,, -17600,0.9057176,5.6346455,,,,,,,,,,,,,, -17700,0.8708755,4.217663,,,,,,,,,,,,,, -17800,0.8971,4.2291665,,,,,,,,,,,,,, -17900,0.7170459,6.1179695,,,,,,,,,,,,,, -18000,0.9340475,4.2333345,,,,,,,,,,,,,, -18053,,,0.3213476538658142,3.226347208023072,0.2976999878883362,3.3574440479278564,50000.0,0.2346000075340271,3.882843255996704,10000.0,8438.975415468216,9306.65334200859,8438.975415468216,865.6273958683014,1.045391082763672,0.0 -18100,0.84657365,4.108053,,,,,,,,,,,,,, -18200,1.2131792,4.2611895,,,,,,,,,,,,,, -18300,0.8956284,4.1733036,,,,,,,,,,,,,, -18400,1.1357977,4.0879,,,,,,,,,,,,,, -18500,0.7622826,4.277514,,,,,,,,,,,,,, -18600,0.84902525,4.8638163,,,,,,,,,,,,,, -18700,0.7023905,6.020735,,,,,,,,,,,,,, -18800,1.0984638,4.1316133,,,,,,,,,,,,,, -18900,0.93252563,4.4254823,,,,,,,,,,,,,, -18960,,,0.3495312333106994,3.1147797107696533,0.3076199889183044,3.3317410945892334,50000.0,0.2371000051498413,3.893336057662964,10000.0,8859.113800764084,9771.34357905388,8859.113800764084,910.0973136425018,1.0767600536346436,0.0 -19000,1.1658953,4.118202,,,,,,,,,,,,,, -19100,1.0948393,4.101457,,,,,,,,,,,,,, -19200,0.7495793,4.802417,,,,,,,,,,,,,, -19300,1.04694,4.399058,,,,,,,,,,,,,, -19400,0.8827125,4.3577943,,,,,,,,,,,,,, -19500,0.98019314,4.192881,,,,,,,,,,,,,, -19600,0.93969053,4.1531286,,,,,,,,,,,,,, -19700,1.0063589,4.830584,,,,,,,,,,,,,, -19800,1.0108098,4.135188,,,,,,,,,,,,,, -19867,,,0.3369531035423279,3.1348671913146973,0.3179399967193603,3.2607951164245605,50000.0,0.2427000105381012,3.791513919830322,10000.0,9279.488339185717,10235.852065563202,9279.488339185717,954.149953365326,1.10762619972229,0.0 -19900,1.0138499,4.10409,,,,,,,,,,,,,, -20000,0.5715135,5.755466,,,,,,,,,,,,,, -20100,0.98644215,4.2013116,,,,,,,,,,,,,, -20200,0.7365615,5.407537,,,,,,,,,,,,,, -20300,0.86395866,4.09158,,,,,,,,,,,,,, -20400,0.86583114,4.1627636,,,,,,,,,,,,,, -20500,0.9190978,3.9689689,,,,,,,,,,,,,, -20600,0.845826,4.293968,,,,,,,,,,,,,, -20700,0.95792687,6.115809,,,,,,,,,,,,,, -20771,,,0.3226367235183716,3.2174956798553467,0.2995399832725525,3.355835199356079,50000.0,0.2291000038385391,3.92555832862854,10000.0,9699.60234951973,10699.347157001495,9699.60234951973,997.4499151706696,1.1384267807006836,0.0 -20800,0.51618767,5.8922195,,,,,,,,,,,,,, -20900,0.86371607,5.9777093,,,,,,,,,,,,,, -21000,0.84715426,5.008268,,,,,,,,,,,,,, -21100,0.7373493,5.898023,,,,,,,,,,,,,, -21200,0.7986677,4.188284,,,,,,,,,,,,,, -21300,0.88646454,4.0136714,,,,,,,,,,,,,, -21400,0.9271022,5.5035963,,,,,,,,,,,,,, -21500,1.084656,4.1306963,,,,,,,,,,,,,, -21600,0.78644156,5.88802,,,,,,,,,,,,,, -21675,,,0.3468359410762787,3.1015145778656006,0.3150199949741363,3.295905113220215,50000.0,0.2485000044107437,3.8394834995269775,10000.0,10119.622790336609,11165.643399238586,10119.622790336609,1043.6431443691254,1.169325590133667,0.0 -21700,0.87209433,5.011794,,,,,,,,,,,,,, -21800,0.8648128,4.3185124,,,,,,,,,,,,,, -21900,0.92394567,3.9542432,,,,,,,,,,,,,, -22000,0.8149367,5.146124,,,,,,,,,,,,,, -22100,0.8641859,4.023528,,,,,,,,,,,,,, -22200,1.2726238,4.150646,,,,,,,,,,,,,, -22300,0.95336,4.024148,,,,,,,,,,,,,, -22400,0.8154483,5.980691,,,,,,,,,,,,,, -22500,0.6626675,5.9798603,,,,,,,,,,,,,, -22581,,,0.3467577993869781,3.073084592819214,0.3256799876689911,3.1851108074188232,50000.0,0.2472000122070312,3.7654662132263184,10000.0,10539.923010349274,11630.60943365097,10539.923010349274,1088.2273619174955,1.2005112171173096,0.0 -22600,1.0679544,4.749102,,,,,,,,,,,,,, -22700,0.96293986,4.32522,,,,,,,,,,,,,, -22800,0.8213071,4.338603,,,,,,,,,,,,,, -22900,0.817514,4.037875,,,,,,,,,,,,,, -23000,0.9976954,6.07852,,,,,,,,,,,,,, -23100,1.0401876,3.9577117,,,,,,,,,,,,,, -23200,0.7974133,5.6073704,,,,,,,,,,,,,, -23300,0.8652331,4.072365,,,,,,,,,,,,,, -23400,1.121957,4.4998264,,,,,,,,,,,,,, -23486,,,0.3531640470027923,3.034122467041016,0.3314200043678283,3.17514705657959,50000.0,0.2502000033855438,3.758253812789917,10000.0,10959.892746925354,12097.288171768188,10959.892746925354,1134.8523106575012,1.2339890003204346,0.0 -23500,0.8967546,4.047505,,,,,,,,,,,,,, -23600,0.9908957,4.0570235,,,,,,,,,,,,,, -23700,0.8737084,4.467829,,,,,,,,,,,,,, -23800,0.7321724,4.4323463,,,,,,,,,,,,,, -23900,0.82308656,6.056263,,,,,,,,,,,,,, -24000,0.926319,4.649256,,,,,,,,,,,,,, -24100,0.9305346,4.0668936,,,,,,,,,,,,,, -24200,0.8476685,5.968585,,,,,,,,,,,,,, -24300,0.8388769,4.8394046,,,,,,,,,,,,,, -24390,,,0.370410144329071,2.9094772338867188,0.3361800014972687,3.098283052444458,50000.0,0.2608000040054321,3.6874732971191406,10000.0,11380.030760765076,12563.504616975784,11380.030760765076,1180.8493909835815,1.2654187679290771,0.0 -24400,0.83422405,5.178413,,,,,,,,,,,,,, -24500,0.8878768,4.0299435,,,,,,,,,,,,,, -24600,1.0144956,4.038498,,,,,,,,,,,,,, -24700,0.72302383,6.031301,,,,,,,,,,,,,, -24800,1.0194254,3.875382,,,,,,,,,,,,,, -24900,0.9446152,4.043076,,,,,,,,,,,,,, -25000,0.7977149,5.6211143,,,,,,,,,,,,,, -25100,0.72382116,5.366373,,,,,,,,,,,,,, -25200,1.076248,4.179847,,,,,,,,,,,,,, -25296,,,0.3532421886920929,3.014591217041016,0.3339999914169311,3.1400809288024902,50000.0,0.252700001001358,3.7258048057556152,10000.0,11800.12370634079,13029.195685625076,11800.12370634079,1226.3632354736328,1.2985377311706543,0.0 -25300,1.008757,4.073501,,,,,,,,,,,,,, -25400,1.4674184,4.160227,,,,,,,,,,,,,, -25500,0.9332014,4.690338,,,,,,,,,,,,,, -25600,0.66335523,6.0422297,,,,,,,,,,,,,, -25700,0.98238814,4.207319,,,,,,,,,,,,,, -25800,0.902375,3.9195433,,,,,,,,,,,,,, -25900,0.9538682,4.141796,,,,,,,,,,,,,, -26000,0.7823022,4.933389,,,,,,,,,,,,,, -26100,1.0862799,3.9265103,,,,,,,,,,,,,, -26200,1.1162612,4.0500517,,,,,,,,,,,,,, -26202,,,0.3569726347923279,3.0196990966796875,0.3306399881839752,3.152756690979004,50000.0,0.2502000033855438,3.755117416381836,10000.0,12220.688608169556,13495.26099729538,12220.688608169556,1271.7826147079468,1.3286545276641846,0.0 -26300,0.77310973,5.101513,,,,,,,,,,,,,, -26400,0.89730483,4.9903727,,,,,,,,,,,,,, -26500,1.0380471,4.0643597,,,,,,,,,,,,,, -26600,1.0236772,4.008742,,,,,,,,,,,,,, -26700,0.94896096,4.7944937,,,,,,,,,,,,,, -26800,1.1464947,4.4151683,,,,,,,,,,,,,, -26900,0.95406216,3.8784955,,,,,,,,,,,,,, -27000,0.863099,5.514673,,,,,,,,,,,,,, -27100,0.77870846,5.316947,,,,,,,,,,,,,, -27109,,,0.3746679723262787,2.8854188919067383,0.3447200059890747,3.0645339488983154,50000.0,0.2606000006198883,3.661961078643799,10000.0,12640.947337388992,13963.037720918655,12640.947337388992,1319.2206687927246,1.3581278324127195,0.0 -27200,0.78452104,5.3260946,,,,,,,,,,,,,, -27300,0.8839861,3.8560312,,,,,,,,,,,,,, -27400,0.7745389,5.5757933,,,,,,,,,,,,,, -27500,0.805647,4.2912793,,,,,,,,,,,,,, -27600,1.0058842,4.3166604,,,,,,,,,,,,,, -27700,0.96503717,3.8791013,,,,,,,,,,,,,, -27800,0.90585256,3.9504983,,,,,,,,,,,,,, -27900,0.90572274,4.031497,,,,,,,,,,,,,, -28000,0.7005461,5.6979504,,,,,,,,,,,,,, -28014,,,0.36474609375,2.9812686443328857,0.3385799825191498,3.11944842338562,50000.0,0.2556000053882599,3.7341465950012207,10000.0,13060.903878450394,14423.568979501724,13060.903878450394,1359.713604927063,1.389518976211548,0.0 -28100,1.1190641,3.9213734,,,,,,,,,,,,,, -28200,0.9644594,3.9952035,,,,,,,,,,,,,, -28300,1.0256183,3.9329803,,,,,,,,,,,,,, -28400,1.0655751,4.3930817,,,,,,,,,,,,,, -28500,0.8386179,5.565655,,,,,,,,,,,,,, -28600,1.0211564,3.935107,,,,,,,,,,,,,, -28700,0.94564384,3.9886498,,,,,,,,,,,,,, -28800,1.066103,4.018146,,,,,,,,,,,,,, -28900,0.9932887,4.103701,,,,,,,,,,,,,, -28921,,,0.3594140410423279,3.011570692062378,0.3353599905967712,3.157099723815918,50000.0,0.2603000104427337,3.740095853805542,10000.0,13481.02701640129,14888.096348762512,13481.02701640129,1404.037055015564,1.4189238548278809,0.0 -29000,0.78608924,4.927191,,,,,,,,,,,,,, -29100,0.8204841,6.067853,,,,,,,,,,,,,, -29200,1.0538934,3.8708963,,,,,,,,,,,,,, -29300,0.62902784,6.0677414,,,,,,,,,,,,,, -29400,0.8360273,4.7066402,,,,,,,,,,,,,, -29500,0.8900874,4.5879536,,,,,,,,,,,,,, -29600,0.6728018,5.3896003,,,,,,,,,,,,,, -29700,0.6921838,6.0376015,,,,,,,,,,,,,, -29800,1.0201968,3.8898818,,,,,,,,,,,,,, -29827,,,0.3867773413658142,2.8524653911590576,0.3562400043010711,3.020305633544922,50000.0,0.2722000181674957,3.626361608505249,10000.0,13901.145128250122,15352.472299575806,13901.145128250122,1448.2114634513855,1.4525518417358398,0.0 -29900,0.96877944,3.9799411,,,,,,,,,,,,,, -30000,0.8207971,6.0463657,,,,,,,,,,,,,, -30100,0.87852186,3.975509,,,,,,,,,,,,,, -30200,1.0243695,4.511607,,,,,,,,,,,,,, -30300,1.0567857,3.9175901,,,,,,,,,,,,,, -30400,1.0005327,4.056678,,,,,,,,,,,,,, -30500,1.2555034,4.005375,,,,,,,,,,,,,, -30600,0.82219595,6.0555067,,,,,,,,,,,,,, -30700,0.86513865,3.736522,,,,,,,,,,,,,, -30733,,,0.40380859375,2.765493631362915,0.3507199883460998,3.034952402114868,50000.0,0.2699000239372253,3.6251919269561768,10000.0,14321.3808157444,15817.029076099396,14321.3808157444,1492.4437718391418,1.489807367324829,0.0 -30800,0.9013755,3.819302,,,,,,,,,,,,,, -30900,1.1172849,3.9863377,,,,,,,,,,,,,, -31000,1.0403138,3.8685017,,,,,,,,,,,,,, -31100,1.0837094,3.8200154,,,,,,,,,,,,,, -31200,1.0077212,4.265631,,,,,,,,,,,,,, -31300,0.88426876,4.158582,,,,,,,,,,,,,, -31400,0.83577484,4.693437,,,,,,,,,,,,,, -31500,1.0611925,3.7613697,,,,,,,,,,,,,, -31600,1.0432358,3.8795033,,,,,,,,,,,,,, -31638,,,0.372871071100235,2.894007921218872,0.3455199897289276,3.053222179412842,50000.0,0.2657999992370605,3.6384551525115967,10000.0,14741.569679737093,16281.695380210876,14741.569679737093,1536.8387801647186,1.5209715366363523,0.0 -31700,0.95083225,3.813273,,,,,,,,,,,,,, -31800,0.8113702,3.8326154,,,,,,,,,,,,,, -31900,0.97358805,3.8458004,,,,,,,,,,,,,, -32000,0.81759816,4.8224945,,,,,,,,,,,,,, -32100,1.0051626,3.811139,,,,,,,,,,,,,, -32200,1.0333861,3.8246884,,,,,,,,,,,,,, -32300,1.0216604,3.8735752,,,,,,,,,,,,,, -32400,0.95449305,4.392437,,,,,,,,,,,,,, -32500,0.9626884,3.9720078,,,,,,,,,,,,,, -32541,,,0.3836914002895355,2.851547718048096,0.3562799990177154,3.018123149871826,50000.0,0.2789000272750854,3.585503578186035,10000.0,15161.727267503738,16745.701989650726,15161.727267503738,1580.6043591499329,1.5537612438201904,0.0 -32600,0.80937886,4.2221017,,,,,,,,,,,,,, -32700,0.8714666,4.289056,,,,,,,,,,,,,, -32800,1.0732828,4.0641537,,,,,,,,,,,,,, -32900,1.1179577,3.9548903,,,,,,,,,,,,,, -33000,1.032812,3.855898,,,,,,,,,,,,,, -33100,0.8859893,3.9132357,,,,,,,,,,,,,, -33200,1.0894243,4.7134247,,,,,,,,,,,,,, -33300,1.2985588,3.8836565,,,,,,,,,,,,,, -33400,0.865427,4.6368704,,,,,,,,,,,,,, -33446,,,0.4034960865974426,2.7900664806365967,0.3520599901676178,3.0779330730438232,50000.0,0.2671000063419342,3.6647629737854,10000.0,15581.699176073074,17207.740622758865,15581.699176073074,1622.586046218872,1.5887606143951416,0.0 -33500,1.0971024,3.9673457,,,,,,,,,,,,,, -33600,0.9092609,4.29397,,,,,,,,,,,,,, -33700,1.0450102,3.9876728,,,,,,,,,,,,,, -33800,0.8712444,3.8722477,,,,,,,,,,,,,, -33900,0.905508,5.4135265,,,,,,,,,,,,,, -34000,1.2108855,3.7716002,,,,,,,,,,,,,, -34100,0.7897271,5.7364798,,,,,,,,,,,,,, -34200,0.8330647,6.0276957,,,,,,,,,,,,,, -34300,1.1664786,3.7982001,,,,,,,,,,,,,, -34350,,,0.3739843666553497,2.903052568435669,0.3515399992465973,3.0420234203338623,50000.0,0.2727999985218048,3.618873119354248,10000.0,16001.772747516632,17669.6182513237,16001.772747516632,1664.3043756484983,1.623687982559204,0.0 -34400,0.94250286,4.334894,,,,,,,,,,,,,, -34500,0.7803128,4.263715,,,,,,,,,,,,,, -34600,1.0653663,4.065797,,,,,,,,,,,,,, -34700,1.0344976,3.918528,,,,,,,,,,,,,, -34800,0.9308787,3.7883317,,,,,,,,,,,,,, -34900,0.879832,5.8315454,,,,,,,,,,,,,, -35000,1.0430491,3.8029053,,,,,,,,,,,,,, -35100,1.1721113,3.87502,,,,,,,,,,,,,, -35200,0.68930656,5.9347095,,,,,,,,,,,,,, -35257,,,0.3941406309604645,2.7809207439422607,0.3646000027656555,2.946488380432129,50000.0,0.281000018119812,3.570101261138916,10000.0,16422.13542985916,18129.60253977776,16422.13542985916,1703.8427374362946,1.655987024307251,0.0 -35300,0.7353336,5.8220077,,,,,,,,,,,,,, -35400,0.9827943,4.346022,,,,,,,,,,,,,, -35500,1.0220169,3.8051963,,,,,,,,,,,,,, -35600,1.1686375,4.069987,,,,,,,,,,,,,, -35700,1.0908738,3.7591903,,,,,,,,,,,,,, -35800,0.68405133,5.055809,,,,,,,,,,,,,, -35900,0.91085845,4.08763,,,,,,,,,,,,,, -36000,0.86606365,4.6329203,,,,,,,,,,,,,, -36100,1.1037617,3.7944841,,,,,,,,,,,,,, -36160,,,0.4143359363079071,2.687903881072998,0.374699980020523,2.915180206298828,50000.0,0.289000004529953,3.515514612197876,10000.0,16842.51434326172,18596.41860938072,16842.51434326172,1750.1966817378998,1.688964605331421,0.0 -36200,0.99904794,3.741406,,,,,,,,,,,,,, -36300,0.71945447,5.2783513,,,,,,,,,,,,,, -36400,1.208693,3.8712234,,,,,,,,,,,,,, -36500,1.0504916,4.0058885,,,,,,,,,,,,,, -36600,0.88236755,3.7954028,,,,,,,,,,,,,, -36700,1.0461606,3.6490574,,,,,,,,,,,,,, -36800,1.0732002,3.9031796,,,,,,,,,,,,,, -36900,1.1591997,4.077625,,,,,,,,,,,,,, -37000,0.7727172,5.4671855,,,,,,,,,,,,,, -37067,,,0.3991992175579071,2.74881911277771,0.3721599876880646,2.8969502449035645,50000.0,0.296500027179718,3.5063552856445312,10000.0,17262.60005879402,19060.820858955383,17262.60005879402,1794.428700685501,1.7227368354797363,0.0 -37100,1.143582,3.8978643,,,,,,,,,,,,,, -37200,0.81862694,5.376846,,,,,,,,,,,,,, -37300,1.0802437,4.031866,,,,,,,,,,,,,, -37400,1.0048363,3.6584404,,,,,,,,,,,,,, -37500,0.8823985,3.6704607,,,,,,,,,,,,,, -37600,1.0620458,5.1449957,,,,,,,,,,,,,, -37700,0.97909766,3.794752,,,,,,,,,,,,,, -37800,1.0254581,3.8397465,,,,,,,,,,,,,, -37900,0.7479708,5.784243,,,,,,,,,,,,,, -37975,,,0.4001367092132568,2.7493491172790527,0.3765600025653839,2.8872478008270264,50000.0,0.2865000069141388,3.4961540699005127,10000.0,17682.65801167488,19521.791479349136,17682.65801167488,1835.2586352825165,1.7550582885742188,0.0 -38000,0.9232344,4.115963,,,,,,,,,,,,,, -38100,1.0475839,3.8243399,,,,,,,,,,,,,, -38200,1.0254359,3.9191875,,,,,,,,,,,,,, -38300,1.0684509,6.0892324,,,,,,,,,,,,,, -38400,1.0365902,3.8363786,,,,,,,,,,,,,, -38500,0.89671904,3.8466787,,,,,,,,,,,,,, -38600,0.6823714,5.6403236,,,,,,,,,,,,,, -38700,1.0171503,4.115528,,,,,,,,,,,,,, -38800,0.75712633,5.0712953,,,,,,,,,,,,,, -38881,,,0.4105273485183716,2.71467924118042,0.3744199872016907,2.916626453399658,50000.0,0.2857000231742859,3.532196283340454,10000.0,18102.94758486748,19988.588298797607,18102.94758486748,1881.6751172542567,1.7953834533691406,0.0 -38900,0.7998069,5.3317814,,,,,,,,,,,,,, -39000,1.2248858,3.7048674,,,,,,,,,,,,,, -39100,0.83305264,4.4902477,,,,,,,,,,,,,, -39200,0.8463672,6.0034995,,,,,,,,,,,,,, -39300,0.91317356,4.0085077,,,,,,,,,,,,,, -39400,0.7169103,5.912732,,,,,,,,,,,,,, -39500,0.82733107,5.9850173,,,,,,,,,,,,,, -39600,1.1241442,3.8121257,,,,,,,,,,,,,, -39700,1.02437,3.6770215,,,,,,,,,,,,,, -39787,,,0.3961132764816284,2.812042713165283,0.3680199980735779,2.944639205932617,50000.0,0.2865000069141388,3.536340713500977,10000.0,18522.871568918228,20452.82318210601,18522.871568918228,1925.902155160904,1.829215288162232,0.0 -39800,0.75447965,5.1972446,,,,,,,,,,,,,, -39900,1.044009,4.4715824,,,,,,,,,,,,,, -40000,1.0213645,3.9048452,,,,,,,,,,,,,, -40100,1.0518017,3.9140737,,,,,,,,,,,,,, -40200,0.70189893,5.969816,,,,,,,,,,,,,, -40300,1.077737,3.8249352,,,,,,,,,,,,,, -40400,0.9582539,3.7339306,,,,,,,,,,,,,, -40500,1.0673748,3.9801984,,,,,,,,,,,,,, -40600,0.9168899,3.775505,,,,,,,,,,,,,, -40692,,,0.3843359351158142,2.880928039550781,0.3576000034809112,3.033652782440185,50000.0,0.2774000167846679,3.621984243392944,10000.0,18942.88044810295,20917.042991399765,18942.88044810295,1970.02938914299,1.8625555038452148,0.0 -40700,0.915003,4.587262,,,,,,,,,,,,,, -40800,0.98034376,3.7651753,,,,,,,,,,,,,, -40900,0.975759,3.859294,,,,,,,,,,,,,, -41000,1.0907348,3.8632898,,,,,,,,,,,,,, -41100,1.0752149,3.8657603,,,,,,,,,,,,,, -41200,1.2486737,4.1243806,,,,,,,,,,,,,, -41300,0.7436906,5.817233,,,,,,,,,,,,,, -41400,0.9822512,3.6063976,,,,,,,,,,,,,, -41500,1.0198755,5.4315543,,,,,,,,,,,,,, -41598,,,0.4193945229053497,2.614825487136841,0.3873199820518493,2.794647216796875,50000.0,0.3025000095367431,3.3992068767547607,10000.0,19363.05272817612,21379.428280830383,19363.05272817612,2012.153751850128,1.900541305541992,0.0 -41600,1.056055,3.8559375,,,,,,,,,,,,,, -41700,0.8605054,4.9504805,,,,,,,,,,,,,, -41800,1.2836057,3.966124,,,,,,,,,,,,,, -41900,0.94241714,3.9796429,,,,,,,,,,,,,, -42000,1.0980844,3.743333,,,,,,,,,,,,,, -42100,1.1092818,3.9498537,,,,,,,,,,,,,, -42200,1.0073456,3.8983028,,,,,,,,,,,,,, -42300,0.7294127,5.511849,,,,,,,,,,,,,, -42400,0.8887565,4.412218,,,,,,,,,,,,,, -42500,1.155432,3.6797416,,,,,,,,,,,,,, -42501,,,0.4103906154632568,2.7060678005218506,0.3838399946689605,2.85146713256836,50000.0,0.2962000072002411,3.468750238418579,10000.0,19783.368386268616,21846.06435275078,19783.368386268616,2058.390007257461,1.9348745346069336,0.0 -42600,0.9753873,3.7432349,,,,,,,,,,,,,, -42700,0.75216883,5.024841,,,,,,,,,,,,,, -42800,0.8110386,5.2070622,,,,,,,,,,,,,, -42900,1.1776309,3.8161347,,,,,,,,,,,,,, -43000,0.99902713,3.640152,,,,,,,,,,,,,, -43100,1.0281929,3.6426413,,,,,,,,,,,,,, -43200,0.91801107,4.205293,,,,,,,,,,,,,, -43300,1.0971446,3.7968235,,,,,,,,,,,,,, -43400,0.9283715,3.6941254,,,,,,,,,,,,,, -43408,,,0.4212304651737213,2.6352226734161377,0.3901399970054626,2.7732999324798584,50000.0,0.3060000240802765,3.412951707839966,10000.0,20203.399605989456,22306.24627304077,20203.399605989456,2098.456430912018,1.9667832851409912,0.0 -43500,1.0383253,3.7436848,,,,,,,,,,,,,, -43600,1.0937381,3.8247318,,,,,,,,,,,,,, -43700,0.7257154,5.389971,,,,,,,,,,,,,, -43800,1.0150609,3.7796144,,,,,,,,,,,,,, -43900,1.1967862,3.7384112,,,,,,,,,,,,,, -44000,1.1489253,3.7546906,,,,,,,,,,,,,, -44100,0.8602142,5.3265414,,,,,,,,,,,,,, -44200,1.0642176,3.7008123,,,,,,,,,,,,,, -44300,0.8934915,3.7511098,,,,,,,,,,,,,, -44314,,,0.424628883600235,2.5954785346984863,0.3944399952888489,2.772815465927124,50000.0,0.3040000200271606,3.4005684852600098,10000.0,20623.64171051979,22768.33304667473,20623.64171051979,2140.2166497707367,1.9997265338897705,0.0 -44400,0.9023878,4.9066367,,,,,,,,,,,,,, -44500,1.2310241,3.726229,,,,,,,,,,,,,, -44600,1.1705475,3.8955297,,,,,,,,,,,,,, -44700,1.1986057,4.275909,,,,,,,,,,,,,, -44800,0.8190817,5.1953077,,,,,,,,,,,,,, -44900,1.2825103,3.883586,,,,,,,,,,,,,, -45000,0.7539216,5.5166383,,,,,,,,,,,,,, -45100,0.7309912,5.2134204,,,,,,,,,,,,,, -45200,0.8481608,6.0143027,,,,,,,,,,,,,, -45220,,,0.4167773425579071,2.6995503902435303,0.3911199867725372,2.838123083114624,50000.0,0.3058000206947326,3.4516892433166504,10000.0,21043.79127836228,23235.264661312103,21043.79127836228,2186.9113490581512,2.037261247634888,0.0 -45300,0.9503035,3.7706404,,,,,,,,,,,,,, -45400,0.96935636,3.766499,,,,,,,,,,,,,, -45500,1.1166606,3.6542838,,,,,,,,,,,,,, -45600,0.9917832,3.6639266,,,,,,,,,,,,,, -45700,0.7095108,5.7600427,,,,,,,,,,,,,, -45800,0.75978553,5.411741,,,,,,,,,,,,,, -45900,0.97626305,3.810297,,,,,,,,,,,,,, -46000,0.91448706,5.6461124,,,,,,,,,,,,,, -46100,1.0206107,3.6861045,,,,,,,,,,,,,, -46125,,,0.4143945276737213,2.6916167736053467,0.3907199800014496,2.808105945587158,50000.0,0.2909000217914581,3.441751718521118,10000.0,21463.8963804245,23699.373474121094,21463.8963804245,2230.829014539718,2.072231769561768,0.0 -46200,1.055854,3.7520065,,,,,,,,,,,,,, -46300,0.84049445,4.3368907,,,,,,,,,,,,,, -46400,0.9696662,3.616074,,,,,,,,,,,,,, -46500,0.9523535,3.8337877,,,,,,,,,,,,,, -46600,1.1028153,3.5284858,,,,,,,,,,,,,, -46700,0.97184294,3.8017628,,,,,,,,,,,,,, -46800,0.7922237,5.753168,,,,,,,,,,,,,, -46900,0.8319672,4.960566,,,,,,,,,,,,,, -47000,0.75492495,4.971003,,,,,,,,,,,,,, -47029,,,0.4289257824420929,2.5802981853485107,0.3979199826717376,2.764410257339477,50000.0,0.3026000261306762,3.3919708728790283,10000.0,21884.086091279984,24165.378796577454,21884.086091279984,2276.559750556946,2.1068155765533447,0.0 -47100,1.02344,3.7506187,,,,,,,,,,,,,, -47200,0.9667769,3.5424845,,,,,,,,,,,,,, -47300,0.86147064,5.353689,,,,,,,,,,,,,, -47400,0.65462506,5.872892,,,,,,,,,,,,,, -47500,1.0504767,3.91816,,,,,,,,,,,,,, -47600,1.4160504,3.818594,,,,,,,,,,,,,, -47700,0.9114255,4.6664863,,,,,,,,,,,,,, -47800,0.9469828,4.4223585,,,,,,,,,,,,,, -47900,1.0211029,3.7175093,,,,,,,,,,,,,, -47934,,,0.4466210901737213,2.509028196334839,0.4015399813652038,2.7524359226226807,50000.0,0.3137000203132629,3.3566060066223145,10000.0,22304.20796895027,24629.18269968033,22304.20796895027,2320.15681385994,2.1419427394866943,0.0 -48000,0.98550963,4.4596844,,,,,,,,,,,,,, -48100,1.1470484,3.6943007,,,,,,,,,,,,,, -48200,0.99307716,3.8452077,,,,,,,,,,,,,, -48300,1.1343548,3.8190844,,,,,,,,,,,,,, -48400,1.0564089,3.6857734,,,,,,,,,,,,,, -48500,1.1605955,3.449694,,,,,,,,,,,,,, -48600,1.0910289,3.679259,,,,,,,,,,,,,, -48700,1.3142103,3.8013942,,,,,,,,,,,,,, -48800,0.9468855,3.6226702,,,,,,,,,,,,,, -48840,,,0.422656238079071,2.647350788116455,0.3947599828243255,2.812582969665528,50000.0,0.3037000000476837,3.428317070007324,10000.0,22724.48389363289,25092.585326194763,22724.48389363289,2363.196498155594,2.1787829399108887,0.0 -48900,1.0718968,4.3668003,,,,,,,,,,,,,, -49000,0.8829202,4.855129,,,,,,,,,,,,,, -49100,1.2032663,3.6495836,,,,,,,,,,,,,, -49200,0.9596657,3.6011326,,,,,,,,,,,,,, -49300,0.8322063,4.9493756,,,,,,,,,,,,,, -49400,1.053928,3.6763391,,,,,,,,,,,,,, -49500,0.97275555,3.5497713,,,,,,,,,,,,,, -49600,0.6027409,5.815443,,,,,,,,,,,,,, -49700,1.0540589,3.6812298,,,,,,,,,,,,,, -49746,,,0.4342578053474426,2.547525405883789,0.3990799784660339,2.7272603511810303,50000.0,0.3099000155925751,3.355839252471924,10000.0,23144.46818780899,25557.598313570023,23144.46818780899,2408.141443490982,2.212597131729126,0.0 -49800,0.99660665,3.5555549,,,,,,,,,,,,,, -49900,0.9727672,3.8070364,,,,,,,,,,,,,, -50000,0.9975298,3.5969923,,,,,,,,,,,,,, -50100,0.68948245,5.789691,,,,,,,,,,,,,, -50200,1.0194671,3.5672596,,,,,,,,,,,,,, -50300,0.9250716,3.9697847,,,,,,,,,,,,,, -50400,1.0146999,6.005576,,,,,,,,,,,,,, -50500,1.1007469,3.8321717,,,,,,,,,,,,,, -50600,1.0277863,3.7200346,,,,,,,,,,,,,, -50654,,,0.4611327946186065,2.438607215881348,0.4008199870586395,2.7578697204589844,50000.0,0.313400000333786,3.36251187324524,10000.0,23564.61387562752,26023.98063802719,23564.61387562752,2454.2899844646454,2.249370574951172,0.0 -50700,1.0131359,4.292828,,,,,,,,,,,,,, -50800,1.0822704,3.5602233,,,,,,,,,,,,,, -50900,1.072422,3.6268258,,,,,,,,,,,,,, -51000,1.0233507,3.663537,,,,,,,,,,,,,, -51100,1.3110554,3.5988765,,,,,,,,,,,,,, -51200,0.864685,5.3656235,,,,,,,,,,,,,, -51300,1.1694301,3.6540537,,,,,,,,,,,,,, -51400,0.9296033,3.622305,,,,,,,,,,,,,, -51500,0.99696136,3.5202644,,,,,,,,,,,,,, -51560,,,0.4273632764816284,2.59151029586792,0.4017999768257141,2.755169630050659,50000.0,0.314300000667572,3.3453259468078613,10000.0,23984.93890619278,26487.90759134293,23984.93890619278,2497.8036675453186,2.286559820175171,0.0 -51600,0.7541987,5.8813186,,,,,,,,,,,,,, -51700,0.9525409,4.3555975,,,,,,,,,,,,,, -51800,0.8469765,5.663669,,,,,,,,,,,,,, -51900,1.05682,3.4692323,,,,,,,,,,,,,, -52000,1.2690939,4.142297,,,,,,,,,,,,,, -52100,0.7899046,5.4411383,,,,,,,,,,,,,, -52200,1.0051132,3.5605865,,,,,,,,,,,,,, -52300,0.94512945,3.5787024,,,,,,,,,,,,,, -52400,0.944969,4.555226,,,,,,,,,,,,,, -52465,,,0.4356249868869781,2.548950672149658,0.4038800001144409,2.7284858226776123,50000.0,0.3096000254154205,3.3761465549468994,10000.0,24404.984939336777,26951.052780389786,24404.984939336777,2540.815685033798,2.323159694671631,0.0 -52500,0.97327745,3.9474645,,,,,,,,,,,,,, -52600,0.8892608,5.546914,,,,,,,,,,,,,, -52700,0.6359026,5.826551,,,,,,,,,,,,,, -52800,0.92259735,3.9613485,,,,,,,,,,,,,, -52900,1.1603988,4.5902047,,,,,,,,,,,,,, -53000,1.0314101,3.7220275,,,,,,,,,,,,,, -53100,1.1311727,3.7731857,,,,,,,,,,,,,, -53200,0.8787946,3.718727,,,,,,,,,,,,,, -53300,0.97710544,4.078141,,,,,,,,,,,,,, -53370,,,0.4490039050579071,2.478402137756348,0.4007599949836731,2.747389316558838,50000.0,0.3069000244140625,3.367358684539795,10000.0,24825.079872131348,27416.082621097565,24825.079872131348,2585.6624703407288,2.361346244812012,0.0 -53400,0.99824536,3.932377,,,,,,,,,,,,,, -53500,0.8657197,5.957645,,,,,,,,,,,,,, -53600,0.8457476,4.9005218,,,,,,,,,,,,,, -53700,0.98691475,3.5725052,,,,,,,,,,,,,, -53800,0.8779948,5.752632,,,,,,,,,,,,,, -53900,0.7183251,5.3810053,,,,,,,,,,,,,, -54000,0.9695945,4.5795846,,,,,,,,,,,,,, -54100,1.2081413,3.5920398,,,,,,,,,,,,,, -54200,1.0670743,3.5947537,,,,,,,,,,,,,, -54276,,,0.4378320276737213,2.539738655090332,0.4072999954223633,2.691128969192505,50000.0,0.3165000081062317,3.3001091480255127,10000.0,25245.164329767227,27879.24419236183,25245.164329767227,2628.652004241944,2.3979926109313965,0.0 -54300,0.8645099,3.9561367,,,,,,,,,,,,,, -54400,1.1916442,3.7786236,,,,,,,,,,,,,, -54500,1.1032304,3.4343681,,,,,,,,,,,,,, -54600,1.1588756,3.624021,,,,,,,,,,,,,, -54700,1.1761985,3.8059905,,,,,,,,,,,,,, -54800,0.86234766,4.249125,,,,,,,,,,,,,, -54900,1.0334973,3.547945,,,,,,,,,,,,,, -55000,0.9697398,3.4614959,,,,,,,,,,,,,, -55100,0.8548281,4.4312315,,,,,,,,,,,,,, -55183,,,0.4406445324420929,2.527313470840454,0.4094799757003784,2.68900203704834,50000.0,0.3144000172615051,3.315009593963623,10000.0,25665.14268875122,28344.652994155884,25665.14268875122,2673.9957478046417,2.4337828159332275,0.0 -55200,1.0323783,3.6012256,,,,,,,,,,,,,, -55300,0.89236987,3.5980058,,,,,,,,,,,,,, -55400,0.9765656,3.430606,,,,,,,,,,,,,, -55500,0.6943748,5.8352566,,,,,,,,,,,,,, -55600,1.0337698,3.7141328,,,,,,,,,,,,,, -55700,0.8734805,4.791636,,,,,,,,,,,,,, -55800,0.8957836,4.9006963,,,,,,,,,,,,,, -55900,0.9852515,3.5569344,,,,,,,,,,,,,, -56000,0.7637875,5.692889,,,,,,,,,,,,,, -56089,,,0.4562695324420929,2.470243453979492,0.4152199923992157,2.69305157661438,50000.0,0.3237000107765198,3.3011837005615234,10000.0,26085.327237844467,28810.783006429672,26085.327237844467,2719.850363969803,2.47387957572937,0.0 -56100,1.0691692,3.759647,,,,,,,,,,,,,, -56200,0.9857457,3.6353025,,,,,,,,,,,,,, -56300,0.85565567,4.649613,,,,,,,,,,,,,, -56400,1.0517482,3.676245,,,,,,,,,,,,,, -56500,0.89961296,4.0850153,,,,,,,,,,,,,, -56600,1.1262356,3.4867954,,,,,,,,,,,,,, -56700,1.0648844,3.6075552,,,,,,,,,,,,,, -56800,1.182407,3.6343071,,,,,,,,,,,,,, -56900,1.0937861,3.4895704,,,,,,,,,,,,,, -56995,,,0.4382421672344208,2.575744867324829,0.4134999811649322,2.716792821884156,50000.0,0.3207000195980072,3.299652099609375,10000.0,26505.334926843643,29276.715250492096,26505.334926843643,2765.686028242111,2.5117812156677246,0.0 -57000,1.1985177,3.6011698,,,,,,,,,,,,,, -57100,1.298841,3.6624339,,,,,,,,,,,,,, -57200,0.9135745,3.926844,,,,,,,,,,,,,, -57300,1.018082,3.7281384,,,,,,,,,,,,,, -57400,1.0241423,3.8595514,,,,,,,,,,,,,, -57500,0.95531213,3.360098,,,,,,,,,,,,,, -57600,0.7955826,4.843897,,,,,,,,,,,,,, -57700,0.9720583,3.8312807,,,,,,,,,,,,,, -57800,1.0710325,3.716051,,,,,,,,,,,,,, -57900,1.2126684,3.558702,,,,,,,,,,,,,, -57901,,,0.4457812309265136,2.49980092048645,0.4198599755764007,2.653440475463867,50000.0,0.3211000263690948,3.273627281188965,10000.0,26925.474437713623,29741.56713962555,26925.474437713623,2810.310459375381,2.5485615730285645,0.0 -58000,0.723683,5.7605686,,,,,,,,,,,,,, -58100,0.7242757,5.8540263,,,,,,,,,,,,,, -58200,1.1422821,3.5318038,,,,,,,,,,,,,, -58300,1.0681001,3.608974,,,,,,,,,,,,,, -58400,0.9391205,5.313471,,,,,,,,,,,,,, -58500,1.024844,3.6203449,,,,,,,,,,,,,, -58600,0.86880493,5.5195603,,,,,,,,,,,,,, -58700,0.9232465,3.556032,,,,,,,,,,,,,, -58800,0.7785976,5.4816093,,,,,,,,,,,,,, -58809,,,0.4603320062160492,2.395381212234497,0.4214800000190735,2.604405164718628,50000.0,0.3300000131130218,3.225630521774292,10000.0,27345.811821460724,30206.088725328445,27345.811821460724,2854.4081478118896,2.5851545333862305,0.0 -58900,0.6368253,5.6391926,,,,,,,,,,,,,, -59000,1.1511246,3.6400123,,,,,,,,,,,,,, -59100,1.0735899,3.4554377,,,,,,,,,,,,,, -59200,0.964837,3.5830147,,,,,,,,,,,,,, -59300,0.7678216,5.116944,,,,,,,,,,,,,, -59400,1.031897,3.6077778,,,,,,,,,,,,,, -59500,0.9882977,4.057981,,,,,,,,,,,,,, -59600,1.324414,3.5980318,,,,,,,,,,,,,, -59700,1.1573732,3.723999,,,,,,,,,,,,,, -59714,,,0.4366796910762787,2.612844705581665,0.4098999798297882,2.748467445373535,50000.0,0.3210000097751617,3.3422865867614746,10000.0,27765.990793704987,30669.352987527847,27765.990793704987,2897.407002687454,2.621112585067749,0.0 -59800,0.86382055,5.7426634,,,,,,,,,,,,,, -59900,1.2167453,3.659268,,,,,,,,,,,,,, -60000,0.7577666,5.0124044,,,,,,,,,,,,,, -60100,0.98023796,3.5831816,,,,,,,,,,,,,, -60200,0.9779506,3.7569766,,,,,,,,,,,,,, -60300,1.110737,3.43341,,,,,,,,,,,,,, -60400,0.9497982,3.6371393,,,,,,,,,,,,,, -60500,0.92628634,4.2085376,,,,,,,,,,,,,, -60600,0.9755059,3.6248133,,,,,,,,,,,,,, -60619,,,0.4499804675579071,2.467857837677002,0.4191199839115143,2.639002799987793,50000.0,0.3256000280380249,3.250809669494629,10000.0,28186.21174645424,31138.09031653404,28186.21174645424,2945.836772441864,2.6567232608795166,0.0 -60700,0.9939459,4.576383,,,,,,,,,,,,,, -60800,1.0837295,3.579424,,,,,,,,,,,,,, -60900,0.7894731,5.7610188,,,,,,,,,,,,,, -61000,0.97896165,4.0751343,,,,,,,,,,,,,, -61100,1.1029814,3.4028156,,,,,,,,,,,,,, -61200,1.1973932,3.5473723,,,,,,,,,,,,,, -61300,0.81598777,5.421104,,,,,,,,,,,,,, -61400,1.2488959,3.6955194,,,,,,,,,,,,,, -61500,1.1605903,3.58673,,,,,,,,,,,,,, -61528,,,0.4624609351158142,2.3810806274414062,0.4293999969959259,2.5703067779541016,50000.0,0.3350000083446502,3.217402696609497,10000.0,28606.36927008629,31602.596896648407,28606.36927008629,2990.0949623584747,2.695803642272949,0.0 -61600,0.9054037,5.0489864,,,,,,,,,,,,,, -61700,1.1819541,3.5872908,,,,,,,,,,,,,, -61800,1.1662145,3.5520787,,,,,,,,,,,,,, -61900,1.1766075,3.7921007,,,,,,,,,,,,,, -62000,0.8757917,5.1045585,,,,,,,,,,,,,, -62100,1.0551425,3.5337152,,,,,,,,,,,,,, -62200,0.8902194,5.3670144,,,,,,,,,,,,,, -62300,1.1508607,3.8203132,,,,,,,,,,,,,, -62400,1.0280949,3.6610312,,,,,,,,,,,,,, -62435,,,0.4630273282527923,2.4418976306915283,0.4314000010490417,2.5994136333465576,50000.0,0.3382000029087066,3.219500541687012,10000.0,29026.63190627098,32064.77761387825,29026.63190627098,3031.917558193207,2.7409818172454834,0.0 -62500,0.91935265,4.226321,,,,,,,,,,,,,, -62600,0.7903499,5.0284433,,,,,,,,,,,,,, -62700,1.1531777,3.6604905,,,,,,,,,,,,,, -62800,1.0241252,5.8394065,,,,,,,,,,,,,, -62900,1.2371739,3.3982873,,,,,,,,,,,,,, -63000,1.1665891,3.5264888,,,,,,,,,,,,,, -63100,1.1615652,3.5213299,,,,,,,,,,,,,, -63200,0.8246839,5.1921253,,,,,,,,,,,,,, -63300,0.8871682,5.680353,,,,,,,,,,,,,, -63342,,,0.4673046767711639,2.354846954345703,0.4369799792766571,2.523207664489746,50000.0,0.3413000106811523,3.1512272357940674,10000.0,29446.979049682617,32529.07609534264,29446.979049682617,3075.777815580368,2.7796471118927,0.0 -63400,0.87080973,5.8151464,,,,,,,,,,,,,, -63500,1.1448368,3.438215,,,,,,,,,,,,,, -63600,0.9150621,5.216172,,,,,,,,,,,,,, -63700,1.0591981,3.6111922,,,,,,,,,,,,,, -63800,1.0685792,3.5680566,,,,,,,,,,,,,, -63900,0.8599528,4.221594,,,,,,,,,,,,,, -64000,1.2639356,3.499056,,,,,,,,,,,,,, -64100,1.1095628,3.5534196,,,,,,,,,,,,,, -64200,0.92219394,4.582884,,,,,,,,,,,,,, -64248,,,0.4647851586341858,2.3934576511383057,0.4303799867630005,2.585111856460572,50000.0,0.3419000208377838,3.1868813037872314,10000.0,29867.28816127777,32994.63881659508,29867.28816127777,3120.9421343803406,2.8182005882263184,0.0 -64300,1.3743478,3.359719,,,,,,,,,,,,,, -64400,0.9814425,3.4541762,,,,,,,,,,,,,, -64500,1.0236288,3.3647072,,,,,,,,,,,,,, -64600,0.9990654,3.6177619,,,,,,,,,,,,,, -64700,1.0747633,5.705968,,,,,,,,,,,,,, -64800,1.1996152,3.4905515,,,,,,,,,,,,,, -64900,0.9227879,3.9739108,,,,,,,,,,,,,, -65000,0.9670937,3.5866232,,,,,,,,,,,,,, -65100,1.1545979,3.5152462,,,,,,,,,,,,,, -65155,,,0.4700976312160492,2.458969593048096,0.4227599799633026,2.704218626022339,50000.0,0.3257000148296356,3.3210971355438232,10000.0,30287.48583388329,33454.79567170143,30287.48583388329,3160.80952334404,2.8595807552337646,0.0 -65200,0.9885074,3.4665258,,,,,,,,,,,,,, -65300,1.1335584,3.4166412,,,,,,,,,,,,,, -65400,1.1375853,3.5314322,,,,,,,,,,,,,, -65500,1.2039036,3.639411,,,,,,,,,,,,,, -65600,1.0365667,4.4259977,,,,,,,,,,,,,, -65700,1.1054779,3.7927585,,,,,,,,,,,,,, -65800,1.0924237,3.373447,,,,,,,,,,,,,, -65900,0.8164733,5.832394,,,,,,,,,,,,,, -66000,1.0299051,3.8390484,,,,,,,,,,,,,, -66059,,,0.4700585901737213,2.350682258605957,0.4400999844074249,2.517786741256714,50000.0,0.3461000025272369,3.131437301635742,10000.0,30707.519748210907,33917.77133059502,30707.519748210907,3203.6639914512634,2.895918846130371,0.0 -66100,1.0553089,5.584874,,,,,,,,,,,,,, -66200,0.95800763,5.103706,,,,,,,,,,,,,, -66300,1.2166005,4.319576,,,,,,,,,,,,,, -66400,0.76543355,5.1423054,,,,,,,,,,,,,, -66500,0.9115604,5.7365403,,,,,,,,,,,,,, -66600,1.1480262,3.4682121,,,,,,,,,,,,,, -66700,1.2017384,3.4978983,,,,,,,,,,,,,, -66800,1.1728121,3.5930777,,,,,,,,,,,,,, -66900,1.1130588,3.4463978,,,,,,,,,,,,,, -66962,,,0.4616210758686065,2.3884122371673584,0.4339599907398224,2.554633617401123,50000.0,0.3354000151157379,3.2074503898620605,10000.0,31127.525916337967,34383.9345805645,31127.525916337967,3249.730548620224,2.936368942260742,0.0 -67000,1.0916976,3.4732277,,,,,,,,,,,,,, -67100,0.99645644,3.7021344,,,,,,,,,,,,,, -67200,1.0537391,3.6474373,,,,,,,,,,,,,, -67300,1.1021794,3.593914,,,,,,,,,,,,,, -67400,0.8528603,5.1595435,,,,,,,,,,,,,, -67500,1.0119917,3.475259,,,,,,,,,,,,,, -67600,0.9565813,4.571341,,,,,,,,,,,,,, -67700,1.0561761,3.5218303,,,,,,,,,,,,,, -67800,1.2867512,3.460105,,,,,,,,,,,,,, -67868,,,0.5029687285423279,2.1818504333496094,0.4394199848175049,2.515748977661133,50000.0,0.3392000198364258,3.1619865894317627,10000.0,31547.554701805115,34848.52709269524,31547.554701805115,3294.203216075897,2.976597309112549,0.0 -67900,0.9676418,4.5272136,,,,,,,,,,,,,, -68000,1.1799963,3.695206,,,,,,,,,,,,,, -68100,1.0865244,3.571284,,,,,,,,,,,,,, -68200,1.1682489,3.5021892,,,,,,,,,,,,,, -68300,0.83652526,4.7100205,,,,,,,,,,,,,, -68400,1.1055589,3.4174912,,,,,,,,,,,,,, -68500,0.8749938,4.9983,,,,,,,,,,,,,, -68600,0.9644507,5.7676916,,,,,,,,,,,,,, -68700,0.978895,4.0686865,,,,,,,,,,,,,, -68775,,,0.4720703065395355,2.326566219329834,0.4456599950790405,2.470954656600952,50000.0,0.3438000082969665,3.122487545013428,10000.0,31967.763231039047,35311.88851070404,31967.763231039047,3337.2650215625763,3.0178487300872803,0.0 -68800,1.1571374,3.365612,,,,,,,,,,,,,, -68900,0.9622443,3.5425096,,,,,,,,,,,,,, -69000,0.90553725,4.263489,,,,,,,,,,,,,, -69100,0.85596985,5.195759,,,,,,,,,,,,,, -69200,0.8753299,4.6754336,,,,,,,,,,,,,, -69300,0.8978737,3.8036742,,,,,,,,,,,,,, -69400,0.76761866,5.160802,,,,,,,,,,,,,, -69500,1.043835,3.7975264,,,,,,,,,,,,,, -69600,1.032449,4.1633563,,,,,,,,,,,,,, -69682,,,0.4788476526737213,2.295114278793335,0.4430199861526489,2.4969727993011475,50000.0,0.34620001912117,3.122164487838745,10000.0,32387.87904167176,35776.63043308258,32387.87904167176,3381.8008399009705,3.0573742389678955,0.0 -69700,1.1272113,3.565136,,,,,,,,,,,,,, -69800,0.9774087,3.2708075,,,,,,,,,,,,,, -69900,1.0775937,3.3376799,,,,,,,,,,,,,, -70000,1.0809832,4.3066497,,,,,,,,,,,,,, -70100,1.0515394,3.42159,,,,,,,,,,,,,, -70200,1.012904,4.0468984,,,,,,,,,,,,,, -70300,1.1349511,3.299947,,,,,,,,,,,,,, -70400,0.9127085,4.77391,,,,,,,,,,,,,, -70500,1.0431172,5.7014785,,,,,,,,,,,,,, -70588,,,0.4978906214237213,2.208686113357544,0.4494799971580505,2.4724645614624023,50000.0,0.34620001912117,3.092029333114624,10000.0,32807.88691544533,36237.747968912125,32807.88691544533,3422.816102027893,3.100841999053955,0.0 -70600,1.0743436,3.4325156,,,,,,,,,,,,,, -70700,1.1175828,3.488499,,,,,,,,,,,,,, -70800,1.238982,3.3929498,,,,,,,,,,,,,, -70900,0.8618867,4.808103,,,,,,,,,,,,,, -71000,1.3910247,3.7730057,,,,,,,,,,,,,, -71100,0.96300757,5.6821504,,,,,,,,,,,,,, -71200,1.1915892,3.5266848,,,,,,,,,,,,,, -71300,0.8037041,4.681423,,,,,,,,,,,,,, -71400,1.0456122,3.3952196,,,,,,,,,,,,,, -71496,,,0.4721484184265136,2.327339172363281,0.4462999999523163,2.475956678390503,50000.0,0.3406000137329101,3.128873109817505,10000.0,33228.21760249138,36703.45001864433,33228.21760249138,3468.0950469970703,3.1422665119171143,0.0 -71500,1.1168503,3.4091496,,,,,,,,,,,,,, -71600,0.9474678,4.4135175,,,,,,,,,,,,,, -71700,1.0266899,3.3407402,,,,,,,,,,,,,, -71800,1.0191532,4.23257,,,,,,,,,,,,,, -71900,1.0844606,3.706274,,,,,,,,,,,,,, -72000,1.1272466,4.030569,,,,,,,,,,,,,, -72100,0.9546287,3.9140737,,,,,,,,,,,,,, -72200,1.178017,3.6507957,,,,,,,,,,,,,, -72300,1.1814866,3.450322,,,,,,,,,,,,,, -72400,1.067583,3.5360355,,,,,,,,,,,,,, -72403,,,0.4733007848262787,2.381632089614868,0.4436999857425689,2.5497748851776123,50000.0,0.3487000167369842,3.151477575302124,10000.0,33648.21060991287,37167.23076486588,33648.21060991287,3511.7916102409363,3.1832900047302246,0.0 -72500,0.8899319,5.521339,,,,,,,,,,,,,, -72600,1.0266964,3.2491715,,,,,,,,,,,,,, -72700,0.7973518,5.4984837,,,,,,,,,,,,,, -72800,1.0655024,3.4116888,,,,,,,,,,,,,, -72900,0.9580737,5.2068763,,,,,,,,,,,,,, -73000,0.90273654,4.039364,,,,,,,,,,,,,, -73100,1.1833726,3.2913742,,,,,,,,,,,,,, -73200,1.061161,3.376883,,,,,,,,,,,,,, -73300,1.0980797,3.3097131,,,,,,,,,,,,,, -73310,,,0.4833788871765136,2.329761266708374,0.4386399984359741,2.551053047180176,50000.0,0.3387000262737274,3.1976158618927,10000.0,34068.42133665085,37629.576330661774,34068.42133665085,3553.8385372161865,3.2199816703796387,0.0 -73400,1.0545889,3.3033156,,,,,,,,,,,,,, -73500,0.91866654,3.6086755,,,,,,,,,,,,,, -73600,1.0833149,3.346887,,,,,,,,,,,,,, -73700,1.0706213,3.3363678,,,,,,,,,,,,,, -73800,1.1460166,3.4076216,,,,,,,,,,,,,, -73900,0.7410844,5.1944723,,,,,,,,,,,,,, -74000,0.9720412,4.0945635,,,,,,,,,,,,,, -74100,1.0982255,3.626571,,,,,,,,,,,,,, -74200,0.8930702,5.7851562,,,,,,,,,,,,,, -74216,,,0.4860742092132568,2.254650354385376,0.4541399776935577,2.429095268249512,50000.0,0.3589000105857849,3.0651798248291016,10000.0,34488.45884680748,38094.87999844551,34488.45884680748,3599.0153257846832,3.2576658725738525,0.0 -74300,1.053452,3.2792993,,,,,,,,,,,,,, -74400,1.042848,3.4062538,,,,,,,,,,,,,, -74500,1.0041181,3.4520779,,,,,,,,,,,,,, -74600,1.1705997,3.5401504,,,,,,,,,,,,,, -74700,1.0283891,3.4347067,,,,,,,,,,,,,, -74800,1.1141269,3.2393608,,,,,,,,,,,,,, -74900,1.1434513,4.187175,,,,,,,,,,,,,, -75000,1.1299952,3.1918716,,,,,,,,,,,,,, -75100,0.9624275,3.3441985,,,,,,,,,,,,,, -75122,,,0.4874609410762787,2.328404903411865,0.4536999762058258,2.506525754928589,50000.0,0.3505000174045563,3.143037796020508,10000.0,34908.47983670235,38561.67894077301,34908.47983670235,3645.705919742584,3.294504880905152,0.0 -75200,1.0676538,3.6463075,,,,,,,,,,,,,, -75300,0.7334498,5.2654376,,,,,,,,,,,,,, -75400,1.0214837,3.588861,,,,,,,,,,,,,, -75500,1.1015612,3.22614,,,,,,,,,,,,,, -75600,1.0823705,3.2956939,,,,,,,,,,,,,, -75700,0.8549253,5.7190285,,,,,,,,,,,,,, -75800,1.0857923,3.6131184,,,,,,,,,,,,,, -75900,0.7654585,5.1113663,,,,,,,,,,,,,, -76000,1.0026588,3.2368796,,,,,,,,,,,,,, -76031,,,0.49755859375,2.247817039489746,0.4552599787712097,2.464518785476685,50000.0,0.3543000221252441,3.110901117324829,10000.0,35328.69760990143,39028.121799230576,35328.69760990143,3691.841574668884,3.332691431045532,0.0 -76100,0.99579674,3.4168835,,,,,,,,,,,,,, -76200,1.4848527,3.3594112,,,,,,,,,,,,,, -76300,1.0111716,3.3318026,,,,,,,,,,,,,, -76400,1.0677953,3.1379912,,,,,,,,,,,,,, -76500,1.2644125,3.2999687,,,,,,,,,,,,,, -76600,1.1983963,3.4954603,,,,,,,,,,,,,, -76700,1.0607343,3.2864265,,,,,,,,,,,,,, -76800,0.9040037,5.6254597,,,,,,,,,,,,,, -76900,1.195885,3.3661995,,,,,,,,,,,,,, -76940,,,0.4865234196186065,2.275959014892578,0.4581599831581116,2.418885946273804,50000.0,0.3583000302314758,3.084025621414185,10000.0,35748.71830582619,39489.267315626144,35748.71830582619,3732.874413013458,3.374640941619873,0.0 -77000,0.9716465,5.7820387,,,,,,,,,,,,,, -77100,1.0471579,5.157127,,,,,,,,,,,,,, -77200,0.95326155,5.558619,,,,,,,,,,,,,, -77300,0.9701295,4.8690624,,,,,,,,,,,,,, -77400,1.1006603,3.27006,,,,,,,,,,,,,, -77500,0.9618105,5.2126255,,,,,,,,,,,,,, -77600,0.96981907,4.003352,,,,,,,,,,,,,, -77700,1.0945745,3.335896,,,,,,,,,,,,,, -77800,0.9510861,4.4390907,,,,,,,,,,,,,, -77847,,,0.5015038847923279,2.2032933235168457,0.4641399979591369,2.393378496170044,50000.0,0.3667000234127044,3.028707265853882,10000.0,36168.80930709839,39951.15533995628,36168.80930709839,3774.578455448151,3.415534734725952,0.0 -77900,1.1665002,3.3830674,,,,,,,,,,,,,, -78000,1.0444082,3.2121115,,,,,,,,,,,,,, -78100,1.1246703,3.3058949,,,,,,,,,,,,,, -78200,1.0518627,3.4299824,,,,,,,,,,,,,, -78300,0.96210724,5.2143645,,,,,,,,,,,,,, -78400,1.093456,3.1519485,,,,,,,,,,,,,, -78500,0.96191084,5.4916945,,,,,,,,,,,,,, -78600,1.240901,3.3903954,,,,,,,,,,,,,, -78700,0.99584335,3.4755592,,,,,,,,,,,,,, -78755,,,0.5059765577316284,2.1536612510681152,0.466399997472763,2.365823268890381,50000.0,0.3612000048160553,3.025886297225952,10000.0,36588.75086402893,40418.35488009453,36588.75086402893,3821.7439455986014,3.457021474838257,0.0 -78800,1.2012339,3.2886014,,,,,,,,,,,,,, -78900,1.3306831,3.3848886,,,,,,,,,,,,,, -79000,1.1426153,3.6546206,,,,,,,,,,,,,, -79100,1.124214,3.3311195,,,,,,,,,,,,,, -79200,1.0593005,3.462148,,,,,,,,,,,,,, -79300,0.9735219,4.520824,,,,,,,,,,,,,, -79400,1.0784079,3.5712159,,,,,,,,,,,,,, -79500,1.1508781,3.9934235,,,,,,,,,,,,,, -79600,1.0108871,3.6242933,,,,,,,,,,,,,, -79662,,,0.5085351467132568,2.1568405628204346,0.4750399887561798,2.333594560623169,50000.0,0.3664000034332275,3.014451026916504,10000.0,37009.04603528976,40879.798095703125,37009.04603528976,3862.8010079860687,3.495908260345459,0.0 -79700,1.067635,3.2534852,,,,,,,,,,,,,, -79800,1.3285303,3.2963855,,,,,,,,,,,,,, -79900,1.1684343,3.25244,,,,,,,,,,,,,, -80000,0.91829705,4.3734226,,,,,,,,,,,,,, -80100,1.2751489,3.4456291,,,,,,,,,,,,,, -80200,1.0088493,5.1303053,,,,,,,,,,,,,, -80300,0.99925065,3.5073118,,,,,,,,,,,,,, -80400,1.1164047,3.415131,,,,,,,,,,,,,, -80500,1.2895486,3.345904,,,,,,,,,,,,,, -80571,,,0.4918359220027923,2.267910480499268,0.4589399993419647,2.433488130569458,50000.0,0.358100026845932,3.056978940963745,10000.0,37429.33586239815,41344.432941913605,37429.33586239815,3907.055018186569,3.535282850265503,0.0 -80600,1.2411664,3.2228572,,,,,,,,,,,,,, -80700,1.0830647,3.3989067,,,,,,,,,,,,,, -80800,1.2424481,3.3874102,,,,,,,,,,,,,, -80900,0.95941293,3.1961026,,,,,,,,,,,,,, -81000,1.0998322,3.3136065,,,,,,,,,,,,,, -81100,0.9263935,3.7970378,,,,,,,,,,,,,, -81200,1.2247062,3.2130485,,,,,,,,,,,,,, -81300,1.1446098,3.2555585,,,,,,,,,,,,,, -81400,1.1103228,3.2964833,,,,,,,,,,,,,, -81479,,,0.5015624761581421,2.221553087234497,0.4664799869060516,2.406338691711426,50000.0,0.3653000295162201,3.051608085632324,10000.0,37849.65655899048,41809.49540543556,37849.65655899048,3951.705055236816,3.575528621673584,0.0 -81500,1.1034305,3.233485,,,,,,,,,,,,,, -81600,1.1336397,3.598157,,,,,,,,,,,,,, -81700,0.9240781,5.5990067,,,,,,,,,,,,,, -81800,1.1219192,3.139806,,,,,,,,,,,,,, -81900,1.2310997,3.3580804,,,,,,,,,,,,,, -82000,1.0566313,3.9775686,,,,,,,,,,,,,, -82100,1.2740092,3.2753448,,,,,,,,,,,,,, -82200,0.8774272,4.753035,,,,,,,,,,,,,, -82300,1.0291677,3.168604,,,,,,,,,,,,,, -82386,,,0.5391015410423279,2.031850099563598,0.4758799970149994,2.33320951461792,50000.0,0.3764000236988067,2.9651641845703125,10000.0,38269.87842488289,42269.36018133164,38269.87842488289,3991.25189948082,3.620238780975342,0.0 -82400,1.0685294,3.2342849,,,,,,,,,,,,,, -82500,0.9214802,5.6152244,,,,,,,,,,,,,, -82600,1.1335247,3.2981367,,,,,,,,,,,,,, -82700,1.2450161,3.2654212,,,,,,,,,,,,,, -82800,0.9809168,3.1089396,,,,,,,,,,,,,, -82900,0.9435194,4.1334004,,,,,,,,,,,,,, -83000,1.5431968,3.3429065,,,,,,,,,,,,,, -83100,1.0191975,5.186293,,,,,,,,,,,,,, -83200,0.909801,4.738863,,,,,,,,,,,,,, -83292,,,0.509472668170929,2.1737000942230225,0.472599983215332,2.343513250350952,50000.0,0.370600014925003,2.987058639526367,10000.0,38689.997086048126,42733.49368548393,38689.997086048126,4035.173714160919,3.662067651748657,0.0 -83300,1.0394248,5.1030965,,,,,,,,,,,,,, -83400,1.053228,3.609267,,,,,,,,,,,,,, -83500,1.1060015,4.1101155,,,,,,,,,,,,,, -83600,0.88899493,4.894029,,,,,,,,,,,,,, -83700,1.1406587,3.2299838,,,,,,,,,,,,,, -83800,1.224426,3.1541264,,,,,,,,,,,,,, -83900,0.8342769,4.6983485,,,,,,,,,,,,,, -84000,1.1041844,3.1365044,,,,,,,,,,,,,, -84100,1.0531932,3.7555768,,,,,,,,,,,,,, -84198,,,0.5048437118530273,2.2214019298553467,0.4667199850082397,2.4125983715057373,50000.0,0.3580000102519989,3.0713653564453125,10000.0,39110.11424565315,43199.31651544571,39110.11424565315,4080.785442113876,3.705613374710083,0.0 -84200,1.0149175,3.9252737,,,,,,,,,,,,,, -84300,1.0594671,3.2194927,,,,,,,,,,,,,, -84400,1.0935248,3.1799777,,,,,,,,,,,,,, -84500,1.0186841,3.5648623,,,,,,,,,,,,,, -84600,0.96327347,4.951681,,,,,,,,,,,,,, -84700,1.0477352,4.781257,,,,,,,,,,,,,, -84800,1.3125422,3.437125,,,,,,,,,,,,,, -84900,1.2117401,3.2404523,,,,,,,,,,,,,, -85000,0.9785527,4.745008,,,,,,,,,,,,,, -85100,0.88912785,4.1505685,,,,,,,,,,,,,, -85105,,,0.5434765219688416,1.975115776062012,0.4802799820899963,2.3029189109802246,50000.0,0.3827000260353088,2.939318895339966,10000.0,39530.06701803208,43658.78203487396,39530.06701803208,4120.204581737518,3.74727201461792,0.0 -85200,0.9744013,5.5989356,,,,,,,,,,,,,, -85300,0.96757185,3.5564013,,,,,,,,,,,,,, -85400,1.1859441,3.278109,,,,,,,,,,,,,, -85500,1.1997241,3.5051227,,,,,,,,,,,,,, -85600,0.98989815,3.301422,,,,,,,,,,,,,, -85700,1.0143868,3.3743405,,,,,,,,,,,,,, -85800,0.98747104,3.8117352,,,,,,,,,,,,,, -85900,1.064475,3.3973932,,,,,,,,,,,,,, -86000,0.815046,5.1878386,,,,,,,,,,,,,, -86012,,,0.5088672041893005,2.185390949249268,0.4745599925518036,2.349569082260132,50000.0,0.3664000034332275,3.0009925365448,10000.0,39950.27555394173,44122.84530091286,39950.27555394173,4163.963869810104,3.7868733406066895,0.0 -86100,1.1325973,3.2347894,,,,,,,,,,,,,, -86200,1.0868235,3.7138166,,,,,,,,,,,,,, -86300,1.2239354,3.3342881,,,,,,,,,,,,,, -86400,0.9933539,4.0598207,,,,,,,,,,,,,, -86500,1.0207278,3.7515938,,,,,,,,,,,,,, -86600,1.2020019,3.2733874,,,,,,,,,,,,,, -86700,0.9319085,4.1100063,,,,,,,,,,,,,, -86800,1.1062958,3.30655,,,,,,,,,,,,,, -86900,1.0530146,5.719544,,,,,,,,,,,,,, -86918,,,0.5166406035423279,2.1085727214813232,0.4803199768066406,2.2921464443206787,50000.0,0.3716000318527221,2.941880702972412,10000.0,40370.32043933869,44588.54420852661,40370.32043933869,4209.52586555481,3.8280630111694336,0.0 -87000,1.1201642,3.113048,,,,,,,,,,,,,, -87100,1.213924,5.675073,,,,,,,,,,,,,, -87200,1.2190567,3.4149117,,,,,,,,,,,,,, -87300,0.9638692,5.240411,,,,,,,,,,,,,, -87400,1.2616322,3.220594,,,,,,,,,,,,,, -87500,1.1035643,3.6675413,,,,,,,,,,,,,, -87600,1.1503298,3.6177764,,,,,,,,,,,,,, -87700,1.2249752,3.0919285,,,,,,,,,,,,,, -87800,1.0731236,3.195643,,,,,,,,,,,,,, -87822,,,0.5345312356948853,2.0101184844970703,0.4843799769878387,2.270280122756958,50000.0,0.3785000145435333,2.930760622024536,10000.0,40790.41343569756,45053.92148447037,40790.41343569756,4254.7183582782745,3.870031595230103,0.0 -87900,1.0447236,3.0731869,,,,,,,,,,,,,, -88000,1.3024359,3.20139,,,,,,,,,,,,,, -88100,1.1946918,4.1695604,,,,,,,,,,,,,, -88200,1.1199087,3.6895287,,,,,,,,,,,,,, -88300,0.96420234,3.6812105,,,,,,,,,,,,,, -88400,0.9275377,4.2971277,,,,,,,,,,,,,, -88500,1.1419533,3.089446,,,,,,,,,,,,,, -88600,1.244224,3.194999,,,,,,,,,,,,,, -88700,1.0298475,4.9641647,,,,,,,,,,,,,, -88728,,,0.5253124833106995,2.1078312397003174,0.4916599988937378,2.2708756923675537,50000.0,0.3844000101089477,2.910942554473877,10000.0,41210.61529803276,45520.16591215134,41210.61529803276,4300.670683383942,3.910280704498291,0.0 -88800,1.2629977,3.1271718,,,,,,,,,,,,,, -88900,1.1172454,3.1811917,,,,,,,,,,,,,, -89000,0.9762281,3.9231746,,,,,,,,,,,,,, -89100,1.1646453,3.9779453,,,,,,,,,,,,,, -89200,1.0556816,3.134107,,,,,,,,,,,,,, -89300,1.0593877,3.3444383,,,,,,,,,,,,,, -89400,0.8581181,4.4620204,,,,,,,,,,,,,, -89500,1.3754811,3.179278,,,,,,,,,,,,,, -89600,1.0281144,5.6458106,,,,,,,,,,,,,, -89635,,,0.5248632431030273,2.0753753185272217,0.4898599982261657,2.257879734039306,50000.0,0.3828000128269195,2.9128763675689697,10000.0,41630.89320707321,45983.5998442173,41630.89320707321,4343.729335069656,3.955927610397339,0.0 -89700,1.0993724,4.0581646,,,,,,,,,,,,,, -89800,1.1688826,3.4880319,,,,,,,,,,,,,, -89900,1.3002199,3.0996869,,,,,,,,,,,,,, -90000,1.1999246,3.214639,,,,,,,,,,,,,, -90100,1.0256746,3.4942555,,,,,,,,,,,,,, -90200,1.0578651,3.4991894,,,,,,,,,,,,,, -90300,1.1004124,3.1931896,,,,,,,,,,,,,, -90400,0.98655,4.220607,,,,,,,,,,,,,, -90500,1.2998312,3.2282646,,,,,,,,,,,,,, -90538,,,0.5314648151397705,2.07725191116333,0.4858999848365783,2.3086907863616943,50000.0,0.3754000067710876,2.9723060131073,10000.0,42050.96832442284,46448.51464676857,42050.96832442284,4388.4740562438965,4.000329256057739,0.0 -90600,1.1406159,4.571581,,,,,,,,,,,,,, -90700,0.9334535,4.7481303,,,,,,,,,,,,,, -90800,1.1422163,3.0208554,,,,,,,,,,,,,, -90900,1.0727404,3.1012168,,,,,,,,,,,,,, -91000,1.1657785,4.045657,,,,,,,,,,,,,, -91100,0.8988543,4.2109632,,,,,,,,,,,,,, -91200,1.0470195,4.998911,,,,,,,,,,,,,, -91300,1.180904,3.333056,,,,,,,,,,,,,, -91400,0.89536047,5.1511993,,,,,,,,,,,,,, -91442,,,0.5350781083106995,2.023699283599853,0.5030400156974792,2.1932621002197266,50000.0,0.394400030374527,2.8569884300231934,10000.0,42471.20595598221,46914.04198694229,42471.20595598221,4433.668882369995,4.045256614685059,0.0 -91500,0.84650105,4.8668385,,,,,,,,,,,,,, -91600,1.113795,3.2636702,,,,,,,,,,,,,, -91700,1.1952826,3.184683,,,,,,,,,,,,,, -91800,1.1718634,3.020596,,,,,,,,,,,,,, -91900,0.95371693,4.568018,,,,,,,,,,,,,, -92000,1.2783632,3.1169186,,,,,,,,,,,,,, -92100,1.0589061,3.3838067,,,,,,,,,,,,,, -92200,0.97439754,5.393852,,,,,,,,,,,,,, -92300,1.2102948,3.072164,,,,,,,,,,,,,, -92349,,,0.5308398604393005,2.097571849822998,0.4921799898147583,2.277677059173584,50000.0,0.3882000148296356,2.938272476196289,10000.0,42891.15717124939,47375.598071336746,42891.15717124939,4475.179137229919,4.08874249458313,0.0 -92400,1.0771943,3.2255814,,,,,,,,,,,,,, -92500,1.1618946,2.9818923,,,,,,,,,,,,,, -92600,1.074142,3.0128455,,,,,,,,,,,,,, -92700,0.972345,5.313184,,,,,,,,,,,,,, -92800,1.0412992,3.46099,,,,,,,,,,,,,, -92900,1.0701903,3.099752,,,,,,,,,,,,,, -93000,1.3549283,5.676144,,,,,,,,,,,,,, -93100,1.299871,3.15949,,,,,,,,,,,,,, -93200,1.3281845,3.4176621,,,,,,,,,,,,,, -93254,,,0.5414257645606995,2.011777639389038,0.4967799782752991,2.2263717651367188,50000.0,0.3873000144958496,2.8787736892700195,10000.0,43311.2950565815,47839.22241187096,43311.2950565815,4518.571710586548,4.132167816162109,0.0 -93300,1.1129006,5.0543294,,,,,,,,,,,,,, -93400,1.3294115,3.0332227,,,,,,,,,,,,,, -93500,1.0956718,3.1250508,,,,,,,,,,,,,, -93600,0.99690086,3.8468993,,,,,,,,,,,,,, -93700,1.0728035,3.0817513,,,,,,,,,,,,,, -93800,0.9622946,5.1140466,,,,,,,,,,,,,, -93900,0.9632842,4.2869005,,,,,,,,,,,,,, -94000,1.1981151,2.9752772,,,,,,,,,,,,,, -94100,1.1521407,3.11372,,,,,,,,,,,,,, -94158,,,0.5350976586341858,2.0232887268066406,0.5008000135421753,2.1791908740997314,50000.0,0.395300030708313,2.828667640686035,10000.0,43731.3213903904,48304.85384202004,43731.3213903904,4564.084159851074,4.174811124801636,0.0 -94200,1.1451415,3.112156,,,,,,,,,,,,,, -94300,1.1836789,3.19392,,,,,,,,,,,,,, -94400,1.0491,3.6333084,,,,,,,,,,,,,, -94500,1.1786768,3.1098042,,,,,,,,,,,,,, -94600,1.3148227,3.0424242,,,,,,,,,,,,,, -94700,0.89052093,5.3914065,,,,,,,,,,,,,, -94800,1.1466535,3.1215718,,,,,,,,,,,,,, -94900,1.0577075,3.0791698,,,,,,,,,,,,,, -95000,1.1225528,3.2147634,,,,,,,,,,,,,, -95065,,,0.5455663800239563,1.9697304964065552,0.5080199837684631,2.1573288440704346,50000.0,0.4004000127315521,2.8348753452301025,10000.0,44151.28519535065,48770.36529827118,44151.28519535065,4609.533478021622,4.222652435302734,0.0 -95100,1.3012931,2.978713,,,,,,,,,,,,,, -95200,1.1767035,3.1963174,,,,,,,,,,,,,, -95300,0.8561568,5.519674,,,,,,,,,,,,,, -95400,0.8673763,4.8071017,,,,,,,,,,,,,, -95500,1.0208964,3.927569,,,,,,,,,,,,,, -95600,1.1078327,2.9437795,,,,,,,,,,,,,, -95700,1.1801448,2.9732518,,,,,,,,,,,,,, -95800,1.1833725,3.327861,,,,,,,,,,,,,, -95900,0.94433266,5.406753,,,,,,,,,,,,,, -95972,,,0.5479297041893005,1.936432957649231,0.5101799964904785,2.144627571105957,50000.0,0.3910000324249267,2.8307206630706787,10000.0,44571.42405152321,49231.33609867096,44571.42405152321,4650.271003246307,4.266034364700317,0.0 -96000,0.84144306,4.8174033,,,,,,,,,,,,,, -96100,1.1582823,2.9949963,,,,,,,,,,,,,, -96200,1.0801575,3.361809,,,,,,,,,,,,,, -96300,1.2995594,3.2280996,,,,,,,,,,,,,, -96400,0.88065314,5.3608723,,,,,,,,,,,,,, -96500,1.2224427,2.9544573,,,,,,,,,,,,,, -96600,1.1057492,4.5825834,,,,,,,,,,,,,, -96700,1.3576896,3.1060274,,,,,,,,,,,,,, -96800,0.9357443,4.3582554,,,,,,,,,,,,,, -96873,,,0.5483007431030273,1.9771265983581543,0.5109999775886536,2.1710920333862305,50000.0,0.4003000259399414,2.800229549407959,10000.0,44991.071533203125,49692.96869182587,44991.071533203125,4691.759567499161,4.712217807769775,0.0 -96900,1.2942647,3.1009595,,,,,,,,,,,,,, -97000,1.1241918,3.136647,,,,,,,,,,,,,, -97100,1.0170202,4.0376825,,,,,,,,,,,,,, -97200,0.9770101,4.5664372,,,,,,,,,,,,,, -97300,1.3683361,2.989482,,,,,,,,,,,,,, -97400,1.2530463,2.9837286,,,,,,,,,,,,,, -97500,1.3603537,3.0952315,,,,,,,,,,,,,, -97600,1.2985398,3.1012995,,,,,,,,,,,,,, -97700,1.3031653,3.051296,,,,,,,,,,,,,, -97778,,,0.550976574420929,1.9703459739685056,0.5161399841308594,2.141354322433472,50000.0,0.4104000329971313,2.800655603408813,10000.0,45411.392758369446,50159.7207980156,45411.392758369446,4738.096249103546,4.756515741348267,0.0 -97800,1.2793489,3.105772,,,,,,,,,,,,,, -97900,1.3331919,3.0353103,,,,,,,,,,,,,, -98000,1.1807513,2.8897974,,,,,,,,,,,,,, -98100,1.2445656,3.1537793,,,,,,,,,,,,,, -98200,0.9802983,4.863607,,,,,,,,,,,,,, -98300,1.135536,3.3336966,,,,,,,,,,,,,, -98400,1.3159287,2.9947512,,,,,,,,,,,,,, -98500,1.2967107,2.8794894,,,,,,,,,,,,,, -98600,0.9468604,5.3665404,,,,,,,,,,,,,, -98684,,,0.5575000047683716,1.9146215915679927,0.5131999850273132,2.129566669464112,50000.0,0.4022000133991241,2.7948946952819824,10000.0,45831.61665439606,50624.62672114372,45831.61665439606,4782.682325363159,4.801445245742798,0.0 -98700,1.2380245,2.8746367,,,,,,,,,,,,,, -98800,1.1351687,3.0371325,,,,,,,,,,,,,, -98900,1.161325,2.9704232,,,,,,,,,,,,,, -99000,1.1375597,3.012941,,,,,,,,,,,,,, -99100,0.98086077,4.5380015,,,,,,,,,,,,,, -99200,0.97762716,4.062523,,,,,,,,,,,,,, -99300,1.2995663,3.1161387,,,,,,,,,,,,,, -99400,1.0006777,3.7892568,,,,,,,,,,,,,, -99500,1.0207226,5.5706053,,,,,,,,,,,,,, -99591,,,0.587890625,1.753287434577942,0.5281199812889099,2.0528218746185303,50000.0,0.4103000164031982,2.728025436401367,10000.0,46251.78324842453,51088.31325173378,46251.78324842453,4826.09882068634,4.854480981826782,0.0 -99600,1.0862452,3.0737543,,,,,,,,,,,,,, -99700,1.1687057,3.3811512,,,,,,,,,,,,,, -99800,1.19736,3.1038568,,,,,,,,,,,,,, -99900,1.1740347,3.0950942,,,,,,,,,,,,,, -100000,1.1714386,3.0225468,,,,,,,,,,,,,, -100100,1.2142521,3.1286674,,,,,,,,,,,,,, -100200,1.0135062,4.086232,,,,,,,,,,,,,, -100300,1.4212968,3.1219218,,,,,,,,,,,,,, -100400,1.1451411,3.03721,,,,,,,,,,,,,, -100497,,,0.5584765672683716,1.894142508506775,0.5205199718475342,2.08619236946106,50000.0,0.4115000069141388,2.745382308959961,10000.0,46672.05391287804,51557.12585926056,46672.05391287804,4874.545390844345,4.899670839309692,0.0 -100500,0.9912287,4.6585903,,,,,,,,,,,,,, -100600,0.96496624,3.9264374,,,,,,,,,,,,,, -100700,1.190036,3.989848,,,,,,,,,,,,,, -100800,1.1266226,4.0276217,,,,,,,,,,,,,, -100900,0.9359094,4.362176,,,,,,,,,,,,,, -101000,1.1303885,2.9996588,,,,,,,,,,,,,, -101100,1.1658428,3.0975955,,,,,,,,,,,,,, -101200,1.2741529,2.9625306,,,,,,,,,,,,,, -101300,1.0549334,3.2064245,,,,,,,,,,,,,, -101400,1.2372804,2.891221,,,,,,,,,,,,,, -101401,,,0.5635937452316284,1.8744672536849976,0.5206999778747559,2.0845861434936523,50000.0,0.4159000217914581,2.720703601837158,10000.0,47092.40252113342,52020.5557975769,47092.40252113342,4917.531066417694,4.94485330581665,0.0 -101500,1.1682202,3.2293072,,,,,,,,,,,,,, -101600,1.1733872,3.216804,,,,,,,,,,,,,, -101700,1.0530825,3.5372536,,,,,,,,,,,,,, -101800,1.0125787,5.475484,,,,,,,,,,,,,, -101900,1.0499104,5.1719036,,,,,,,,,,,,,, -102000,1.3077124,3.2239878,,,,,,,,,,,,,, -102100,1.3658335,2.9268703,,,,,,,,,,,,,, -102200,1.335852,3.036695,,,,,,,,,,,,,, -102300,1.2303914,3.022921,,,,,,,,,,,,,, -102306,,,0.6005273461341858,1.6896134614944458,0.5290799736976624,2.035361051559448,50000.0,0.417900025844574,2.692728281021118,10000.0,47512.490429639816,52484.20654010773,47512.490429639816,4961.000032663345,4.988691568374634,0.0 -102400,1.3929554,2.840539,,,,,,,,,,,,,, -102500,1.0037286,5.2873774,,,,,,,,,,,,,, -102600,1.1377548,2.9048922,,,,,,,,,,,,,, -102700,1.0392706,4.572704,,,,,,,,,,,,,, -102800,1.0134379,5.3985305,,,,,,,,,,,,,, -102900,1.120887,3.7849905,,,,,,,,,,,,,, -103000,0.96146697,5.4287167,,,,,,,,,,,,,, -103100,1.2193334,2.9484577,,,,,,,,,,,,,, -103200,1.1563355,3.768162,,,,,,,,,,,,,, -103213,,,0.5564843416213989,1.9020169973373413,0.5239399671554565,2.0747263431549072,50000.0,0.4070000052452087,2.746877193450928,10000.0,47932.62344169617,52944.69896483421,47932.62344169617,5001.267645597458,5.030147075653076,0.0 -103300,1.2060289,2.79393,,,,,,,,,,,,,, -103400,1.2749296,2.8966672,,,,,,,,,,,,,, -103500,1.2201542,2.9894814,,,,,,,,,,,,,, -103600,1.0812589,4.1384826,,,,,,,,,,,,,, -103700,1.304411,2.7691722,,,,,,,,,,,,,, -103800,1.2525865,2.9446194,,,,,,,,,,,,,, -103900,1.1029379,3.004035,,,,,,,,,,,,,, -104000,0.9053612,4.976848,,,,,,,,,,,,,, -104100,1.2787861,2.9796448,,,,,,,,,,,,,, -104120,,,0.5759570002555847,1.842762351036072,0.5297799706459045,2.0636205673217773,50000.0,0.41880002617836,2.700232744216919,10000.0,48352.95508027077,53410.00083827973,48352.95508027077,5046.142510414124,5.07442831993103,0.0 -104200,1.1961832,3.1847036,,,,,,,,,,,,,, -104300,1.3841276,2.881773,,,,,,,,,,,,,, -104400,1.2213577,5.014341,,,,,,,,,,,,,, -104500,1.1943283,2.9451106,,,,,,,,,,,,,, -104600,1.1395977,3.0195918,,,,,,,,,,,,,, -104700,1.1576508,2.918385,,,,,,,,,,,,,, -104800,1.0227208,3.6278732,,,,,,,,,,,,,, -104900,1.4395703,2.954237,,,,,,,,,,,,,, -105000,1.1536242,2.936545,,,,,,,,,,,,,, -105028,,,0.5927929282188416,1.7482671737670898,0.5322999954223633,2.0392508506774902,50000.0,0.4181000292301178,2.695244312286377,10000.0,48773.32455801964,53875.40821886063,48773.32455801964,5091.086785554886,5.116716623306274,0.0 -105100,1.2210453,2.9350863,,,,,,,,,,,,,, -105200,0.95772785,5.0220685,,,,,,,,,,,,,, -105300,1.0976421,3.2436771,,,,,,,,,,,,,, -105400,1.0661845,3.3989415,,,,,,,,,,,,,, -105500,1.2436637,2.8547347,,,,,,,,,,,,,, -105600,1.5009604,2.92443,,,,,,,,,,,,,, -105700,0.9897925,4.0489063,,,,,,,,,,,,,, -105800,1.3315654,3.0278678,,,,,,,,,,,,,, -105900,1.1028162,4.619576,,,,,,,,,,,,,, -105936,,,0.571972668170929,1.83349883556366,0.5362200140953064,2.0098884105682373,50000.0,0.422400027513504,2.671083688735962,10000.0,49193.34681630135,54339.73200273514,49193.34681630135,5135.292055368424,5.161932468414307,0.0 -106000,0.97162974,4.7408133,,,,,,,,,,,,,, -106100,1.2287517,2.9089458,,,,,,,,,,,,,, -106200,1.1426889,3.067884,,,,,,,,,,,,,, -106300,1.2104366,5.447899,,,,,,,,,,,,,, -106400,1.3064023,3.0032058,,,,,,,,,,,,,, -106500,1.1498169,3.3249366,,,,,,,,,,,,,, -106600,1.333461,3.0218415,,,,,,,,,,,,,, -106700,1.330111,2.8721602,,,,,,,,,,,,,, -106800,1.2559425,3.1168656,,,,,,,,,,,,,, -106841,,,0.5798632502555847,1.801438570022583,0.5390200018882751,2.008123874664306,50000.0,0.4261000156402588,2.6580498218536377,10000.0,49613.69072461128,54801.62851881981,49613.69072461128,5176.745602607727,5.209371089935303,0.0 -106900,1.1477402,3.775279,,,,,,,,,,,,,, -107000,1.1575176,3.229756,,,,,,,,,,,,,, -107100,1.1873552,2.9838452,,,,,,,,,,,,,, -107200,1.311551,2.9348617,,,,,,,,,,,,,, -107300,1.2364447,2.9422069,,,,,,,,,,,,,, -107400,1.2387464,3.1682374,,,,,,,,,,,,,, -107500,1.1962441,2.7983198,,,,,,,,,,,,,, -107600,1.4243121,2.8621278,,,,,,,,,,,,,, -107700,1.2878603,3.023969,,,,,,,,,,,,,, -107744,,,0.5798437595367432,1.8626606464385984,0.5293999910354614,2.1074492931365967,50000.0,0.4178000092506408,2.7498693466186523,10000.0,50034.01110982895,55269.35957503319,50034.01110982895,5224.058357954025,5.257244348526001,0.0 -107800,1.3335072,2.961997,,,,,,,,,,,,,, -107900,1.1021254,4.1086435,,,,,,,,,,,,,, -108000,1.2542471,2.9115043,,,,,,,,,,,,,, -108100,1.1859931,3.4512367,,,,,,,,,,,,,, -108200,0.93087083,5.336425,,,,,,,,,,,,,, -108300,1.2802832,2.845711,,,,,,,,,,,,,, -108400,1.3083769,2.8654425,,,,,,,,,,,,,, -108500,0.99305123,5.2286515,,,,,,,,,,,,,, -108600,1.0653121,4.025424,,,,,,,,,,,,,, -108651,,,0.5818749666213989,1.7912267446517944,0.5414400100708008,1.9961179494857788,50000.0,0.4240000247955322,2.657770156860352,10000.0,50454.28623723984,55730.48776316643,50454.28623723984,5264.811240434647,5.306406021118164,0.0 -108700,1.2125883,2.7457323,,,,,,,,,,,,,, -108800,1.3395106,2.8851416,,,,,,,,,,,,,, -108900,1.2962744,2.9194202,,,,,,,,,,,,,, -109000,1.3361915,3.117099,,,,,,,,,,,,,, -109100,1.2218763,2.914183,,,,,,,,,,,,,, -109200,1.2485601,3.204328,,,,,,,,,,,,,, -109300,1.0482619,3.856431,,,,,,,,,,,,,, -109400,1.4562871,2.848093,,,,,,,,,,,,,, -109500,1.1690999,4.094583,,,,,,,,,,,,,, -109557,,,0.5807812213897705,1.7843594551086426,0.5475599765777588,1.9676730632781985,50000.0,0.4294000267982483,2.6383376121521,10000.0,50874.582459926605,56190.43788194656,50874.582459926605,5304.367951393127,5.35333776473999,0.0 -109600,1.2314894,2.8017955,,,,,,,,,,,,,, -109700,1.0124905,3.6538768,,,,,,,,,,,,,, -109800,1.2421442,2.8455787,,,,,,,,,,,,,, -109900,1.2728603,2.8956568,,,,,,,,,,,,,, -110000,1.1725101,3.5377686,,,,,,,,,,,,,, -110100,0.9670069,4.8729324,,,,,,,,,,,,,, -110200,1.3899689,2.843139,,,,,,,,,,,,,, -110300,0.97244394,4.933617,,,,,,,,,,,,,, -110400,0.91870296,4.930254,,,,,,,,,,,,,, -110464,,,0.5948437452316284,1.731536626815796,0.5445600152015686,1.979048132896424,50000.0,0.4315000176429748,2.643044948577881,10000.0,51294.605360507965,56655.322309970856,51294.605360507965,5349.13404250145,5.398436546325684,0.0 -110500,1.1051745,4.5585756,,,,,,,,,,,,,, -110600,1.276053,2.666851,,,,,,,,,,,,,, -110700,1.2833216,2.808948,,,,,,,,,,,,,, -110800,1.2675139,2.7777703,,,,,,,,,,,,,, -110900,1.2918602,3.1657526,,,,,,,,,,,,,, -111000,1.1684266,2.8937783,,,,,,,,,,,,,, -111100,1.0482582,5.365305,,,,,,,,,,,,,, -111200,1.1338618,5.3016043,,,,,,,,,,,,,, -111300,1.14532,3.313119,,,,,,,,,,,,,, -111371,,,0.5899804830551147,1.7508074045181274,0.555079996585846,1.9349216222763064,50000.0,0.4382000267505646,2.608139991760254,10000.0,51714.772094249725,57115.24357366562,51714.772094249725,5388.784100055695,5.452826261520386,0.0 -111400,1.1033916,3.6230097,,,,,,,,,,,,,, -111500,1.4200569,2.8196821,,,,,,,,,,,,,, -111600,1.1664075,4.025983,,,,,,,,,,,,,, -111700,1.2521507,3.0924215,,,,,,,,,,,,,, -111800,1.013254,5.3960266,,,,,,,,,,,,,, -111900,1.0850954,4.0220075,,,,,,,,,,,,,, -112000,1.1940389,2.793003,,,,,,,,,,,,,, -112100,1.1214049,4.1339426,,,,,,,,,,,,,, -112200,0.9857859,5.2248755,,,,,,,,,,,,,, -112281,,,0.5955273509025574,1.7364590167999268,0.5569800138473511,1.9330426454544067,50000.0,0.4413000345230102,2.598820447921753,10000.0,52135.0346596241,57580.20839238167,52135.0346596241,5433.386849164963,5.500851154327393,0.0 -112300,1.3385417,3.5503767,,,,,,,,,,,,,, -112400,1.2576575,2.9661927,,,,,,,,,,,,,, -112500,1.4788185,2.826757,,,,,,,,,,,,,, -112600,1.0200844,4.070782,,,,,,,,,,,,,, -112700,1.4625196,2.7821848,,,,,,,,,,,,,, -112800,1.0511382,4.4132676,,,,,,,,,,,,,, -112900,1.3042744,2.7990432,,,,,,,,,,,,,, -113000,1.003803,4.8126993,,,,,,,,,,,,,, -113100,1.2751191,2.8728724,,,,,,,,,,,,,, -113187,,,0.6030077934265137,1.6891416311264038,0.5582199692726135,1.9116791486740112,50000.0,0.4421000182628631,2.5687875747680664,10000.0,52555.18597626686,58046.9350438118,52555.18597626686,5479.865357398987,5.547860145568848,0.0 -113200,1.3068638,4.1045322,,,,,,,,,,,,,, -113300,1.245599,2.833709,,,,,,,,,,,,,, -113400,1.3650191,2.7659497,,,,,,,,,,,,,, -113500,1.2064928,2.9576359,,,,,,,,,,,,,, -113600,1.369793,2.7394278,,,,,,,,,,,,,, -113700,1.3690383,2.7709017,,,,,,,,,,,,,, -113800,1.2377087,2.7480626,,,,,,,,,,,,,, -113900,1.1681759,4.763914,,,,,,,,,,,,,, -114000,1.1919552,4.892055,,,,,,,,,,,,,, -114093,,,0.5959765315055847,1.7394243478775024,0.5590400099754333,1.9202549457550049,50000.0,0.4409000277519226,2.6083571910858154,10000.0,52975.51297545433,58514.05428671837,52975.51297545433,5526.562952518463,5.59241247177124,0.0 -114100,1.2114735,2.8250718,,,,,,,,,,,,,, -114200,1.144917,3.6006362,,,,,,,,,,,,,, -114300,1.1083237,4.405704,,,,,,,,,,,,,, -114400,1.3273259,2.9078176,,,,,,,,,,,,,, -114500,1.1984621,3.3852873,,,,,,,,,,,,,, -114600,1.5699799,2.8827987,,,,,,,,,,,,,, -114700,1.2804495,2.8246036,,,,,,,,,,,,,, -114800,1.1081712,4.21522,,,,,,,,,,,,,, -114900,1.3719696,2.7326188,,,,,,,,,,,,,, -114999,,,0.5999413728713989,1.698953628540039,0.5589599609375,1.901583433151245,50000.0,0.4438000321388244,2.568891763687134,10000.0,53395.86148428917,58980.18489718437,53395.86148428917,5572.246104717255,5.641361236572266,0.0 -115000,1.1734211,5.0927424,,,,,,,,,,,,,, -115100,1.4669119,2.7900987,,,,,,,,,,,,,, -115200,1.4028256,2.938313,,,,,,,,,,,,,, -115300,1.31411,2.734905,,,,,,,,,,,,,, -115400,1.3686708,3.006544,,,,,,,,,,,,,, -115500,1.192561,3.5093694,,,,,,,,,,,,,, -115600,1.0426309,4.576954,,,,,,,,,,,,,, -115700,1.3546332,2.7408173,,,,,,,,,,,,,, -115800,1.2712812,3.254363,,,,,,,,,,,,,, -115900,1.4327873,2.611769,,,,,,,,,,,,,, -115903,,,0.6122460961341858,1.6250638961791992,0.5637800097465515,1.8601083755493164,50000.0,0.447700023651123,2.532124757766724,10000.0,53815.892790317535,59443.194900512695,53815.892790317535,5615.125262737274,5.690145254135132,0.0 -116000,1.3847528,2.7589092,,,,,,,,,,,,,, -116100,1.304974,2.739166,,,,,,,,,,,,,, -116200,1.446854,2.8803246,,,,,,,,,,,,,, -116300,1.2830276,2.567577,,,,,,,,,,,,,, -116400,1.4209781,2.7614515,,,,,,,,,,,,,, -116500,1.4470117,2.9365525,,,,,,,,,,,,,, -116600,1.3574955,2.8722324,,,,,,,,,,,,,, -116700,1.3170198,3.052093,,,,,,,,,,,,,, -116800,1.5289086,2.8547404,,,,,,,,,,,,,, -116809,,,0.6307226419448853,1.5751163959503174,0.5652799606323242,1.8742843866348269,50000.0,0.4429000318050384,2.542642116546631,10000.0,54236.21877121925,59905.71079039574,54236.21877121925,5657.219936609268,5.732917547225952,0.0 -116900,1.3501502,2.866686,,,,,,,,,,,,,, -117000,1.066421,4.665671,,,,,,,,,,,,,, -117100,1.3098722,2.713343,,,,,,,,,,,,,, -117200,1.2732196,2.9826841,,,,,,,,,,,,,, -117300,1.3013165,2.7901404,,,,,,,,,,,,,, -117400,1.2336173,3.4318063,,,,,,,,,,,,,, -117500,1.1020457,4.78695,,,,,,,,,,,,,, -117600,1.3389643,2.8483136,,,,,,,,,,,,,, -117700,1.145053,5.2015038,,,,,,,,,,,,,, -117715,,,0.606738269329071,1.6873126029968262,0.5666399598121643,1.881292462348938,50000.0,0.4442000091075897,2.5546884536743164,10000.0,54656.47298908234,60368.68414545059,54656.47298908234,5699.84024477005,5.780933141708374,0.0 -117800,1.32397,2.7311943,,,,,,,,,,,,,, -117900,1.2629267,3.0165062,,,,,,,,,,,,,, -118000,1.3305925,2.6992273,,,,,,,,,,,,,, -118100,1.3425456,5.2602944,,,,,,,,,,,,,, -118200,1.1394935,4.423386,,,,,,,,,,,,,, -118300,1.2148062,3.631989,,,,,,,,,,,,,, -118400,1.4276431,2.6939216,,,,,,,,,,,,,, -118500,1.1898886,4.733373,,,,,,,,,,,,,, -118600,1.2464261,3.416658,,,,,,,,,,,,,, -118621,,,0.6252343654632568,1.57450532913208,0.5776000022888184,1.8012522459030151,50000.0,0.4595000147819519,2.4840798377990723,10000.0,55076.77208423615,60832.36305522919,55076.77208423615,5743.124422311783,5.825575351715088,0.0 -118700,1.2710803,2.7400253,,,,,,,,,,,,,, -118800,1.1671271,3.8920317,,,,,,,,,,,,,, -118900,1.3645374,2.7200868,,,,,,,,,,,,,, -119000,1.329816,2.6459248,,,,,,,,,,,,,, -119100,1.1769543,5.252554,,,,,,,,,,,,,, -119200,1.4879917,2.6791053,,,,,,,,,,,,,, -119300,1.2205498,3.189848,,,,,,,,,,,,,, -119400,1.4652379,3.0931249,,,,,,,,,,,,,, -119500,1.3287677,5.0883775,,,,,,,,,,,,,, -119523,,,0.6426367163658142,1.501293420791626,0.5740599632263184,1.8339369297027588,50000.0,0.4567000269889831,2.4974396228790283,10000.0,55496.69383692741,61297.40745139122,55496.69383692741,5788.15007519722,5.872597932815552,0.0 -119600,1.5011498,2.6652608,,,,,,,,,,,,,, -119700,1.5666451,2.6364439,,,,,,,,,,,,,, -119800,1.399331,2.5272002,,,,,,,,,,,,,, -119900,1.3648726,2.5597389,,,,,,,,,,,,,, -120000,1.3541172,2.711422,,,,,,,,,,,,,, -120100,1.2543868,2.8517976,,,,,,,,,,,,,, -120200,1.4515623,2.5664797,,,,,,,,,,,,,, -120300,1.484315,2.5668797,,,,,,,,,,,,,, -120400,1.6929117,2.9694486,,,,,,,,,,,,,, -120427,,,0.6160937547683716,1.6148782968521118,0.5806800127029419,1.8061535358428955,50000.0,0.4561000168323517,2.485478162765503,10000.0,55916.784044504166,61762.35899710655,55916.784044504166,5832.912647247314,5.921623706817627,0.0 -120500,1.3420086,2.816339,,,,,,,,,,,,,, -120600,1.2468159,3.3800507,,,,,,,,,,,,,, -120700,1.087302,4.0487537,,,,,,,,,,,,,, -120800,1.356334,2.7859502,,,,,,,,,,,,,, -120900,1.2085229,3.1400778,,,,,,,,,,,,,, -121000,1.2167344,5.1654477,,,,,,,,,,,,,, -121100,1.3875684,3.8769512,,,,,,,,,,,,,, -121200,1.0456313,5.0526495,,,,,,,,,,,,,, -121300,1.4938062,5.3791957,,,,,,,,,,,,,, -121333,,,0.6248632669448853,1.5781731605529783,0.5776399970054626,1.8077296018600464,50000.0,0.458400011062622,2.487220764160156,10000.0,56336.82581615448,62230.889113903046,56336.82581615448,5881.302688598633,5.969411849975586,0.0 -121400,1.0730989,5.018385,,,,,,,,,,,,,, -121500,1.2568904,3.7108963,,,,,,,,,,,,,, -121600,1.0520805,5.1597977,,,,,,,,,,,,,, -121700,1.5293612,2.615373,,,,,,,,,,,,,, -121800,1.2349222,4.349253,,,,,,,,,,,,,, -121900,1.5394264,2.4250033,,,,,,,,,,,,,, -122000,1.3648082,2.8575122,,,,,,,,,,,,,, -122100,1.2263143,3.8385687,,,,,,,,,,,,,, -122200,1.3510883,2.569454,,,,,,,,,,,,,, -122237,,,0.6359570026397705,1.568046808242798,0.5774999856948853,1.843104600906372,50000.0,0.458400011062622,2.5103447437286377,10000.0,56756.91334247589,62695.95139026642,56756.91334247589,5926.17863202095,6.0188164710998535,0.0 -122300,1.2820902,5.2135243,,,,,,,,,,,,,, -122400,1.3401217,2.8345919,,,,,,,,,,,,,, -122500,1.2309275,3.6745133,,,,,,,,,,,,,, -122600,1.3044516,2.6066768,,,,,,,,,,,,,, -122700,1.3356105,3.6000066,,,,,,,,,,,,,, -122800,1.2113392,4.27045,,,,,,,,,,,,,, -122900,1.3721597,2.6682448,,,,,,,,,,,,,, -123000,1.4636594,2.5944068,,,,,,,,,,,,,, -123100,1.2538447,4.6749034,,,,,,,,,,,,,, -123144,,,0.6289648413658142,1.563175916671753,0.5860599875450134,1.769154667854309,50000.0,0.4677000343799591,2.4270832538604736,10000.0,57177.12469792366,63158.04018783569,57177.12469792366,5967.953478097916,6.071156740188599,0.0 -123200,1.4676385,2.5846148,,,,,,,,,,,,,, -123300,1.3118504,3.2657006,,,,,,,,,,,,,, -123400,1.4283948,2.4499729,,,,,,,,,,,,,, -123500,1.467729,2.525516,,,,,,,,,,,,,, -123600,1.2421762,4.3002095,,,,,,,,,,,,,, -123700,1.5841695,2.7058125,,,,,,,,,,,,,, -123800,1.4127759,2.5870087,,,,,,,,,,,,,, -123900,1.5668918,2.6615558,,,,,,,,,,,,,, -124000,1.2679182,3.6372788,,,,,,,,,,,,,, -124049,,,0.640429675579071,1.5160832405090332,0.5918599963188171,1.7468619346618652,50000.0,0.4705000221729278,2.4021661281585693,10000.0,57597.355749607086,63624.529782772064,57597.355749607086,6014.114964962006,6.117965459823608,0.0 -124100,1.5016057,3.685557,,,,,,,,,,,,,, -124200,1.4323322,2.8475518,,,,,,,,,,,,,, -124300,1.437598,2.638984,,,,,,,,,,,,,, -124400,1.3752759,2.5423007,,,,,,,,,,,,,, -124500,1.2595409,4.1074853,,,,,,,,,,,,,, -124600,1.3596197,3.0668612,,,,,,,,,,,,,, -124700,1.2603091,5.154341,,,,,,,,,,,,,, -124800,1.4486916,2.5797544,,,,,,,,,,,,,, -124900,1.6702058,2.7212925,,,,,,,,,,,,,, -124954,,,0.6468749642372131,1.47525954246521,0.5937199592590332,1.733345627784729,50000.0,0.4673000276088714,2.402538537979126,10000.0,58017.52013874054,64090.95868706703,58017.52013874054,6060.277943134308,6.1684510707855225,0.0 -125000,1.2087843,4.276552,,,,,,,,,,,,,, -125100,1.2612433,3.211564,,,,,,,,,,,,,, -125200,1.4618814,2.5467675,,,,,,,,,,,,,, -125300,1.585495,2.5503268,,,,,,,,,,,,,, -125400,1.4881693,4.239197,,,,,,,,,,,,,, -125500,1.5313216,2.5103583,,,,,,,,,,,,,, -125600,1.5873626,2.7241244,,,,,,,,,,,,,, -125700,1.468413,2.6261754,,,,,,,,,,,,,, -125800,1.2366124,2.820293,,,,,,,,,,,,,, -125858,,,0.6384179592132568,1.4969202280044556,0.5963000059127808,1.702694058418274,50000.0,0.4754000306129455,2.3764123916625977,10000.0,58437.53057861328,64555.04180955887,58437.53057861328,6104.252175807953,6.215758085250855,0.0 -125900,1.6031256,2.5154464,,,,,,,,,,,,,, -126000,1.4657747,2.448945,,,,,,,,,,,,,, -126100,1.3226589,3.4001613,,,,,,,,,,,,,, -126200,1.2442397,4.6346335,,,,,,,,,,,,,, -126300,1.5201769,2.5658014,,,,,,,,,,,,,, -126400,1.2183576,4.665,,,,,,,,,,,,,, -126500,1.4460251,2.743866,,,,,,,,,,,,,, -126600,1.5193906,2.6024127,,,,,,,,,,,,,, -126700,1.3575944,5.183425,,,,,,,,,,,,,, -126763,,,0.6460155844688416,1.478101372718811,0.6011399626731873,1.703931212425232,50000.0,0.4791000187397003,2.3670125007629395,10000.0,58857.83850765228,65020.412427186966,58857.83850765228,6149.214513778687,6.2652997970581055,0.0 -126800,1.7009276,2.7862337,,,,,,,,,,,,,, -126900,1.4529018,2.475863,,,,,,,,,,,,,, -127000,1.3229731,5.015596,,,,,,,,,,,,,, -127100,1.4723604,2.6904855,,,,,,,,,,,,,, -127200,1.5414044,2.6693392,,,,,,,,,,,,,, -127300,1.5571207,2.4703577,,,,,,,,,,,,,, -127400,1.5366119,2.4852865,,,,,,,,,,,,,, -127500,1.5917689,5.199087,,,,,,,,,,,,,, -127600,1.5204827,2.5587153,,,,,,,,,,,,,, -127669,,,0.6504492163658142,1.4836300611495972,0.5961999893188477,1.7273157835006714,50000.0,0.4782000184059143,2.389633655548096,10000.0,59278.02184915543,65488.41676783562,59278.02184915543,6196.937285423279,6.313244581222534,0.0 -127700,1.2248186,4.211808,,,,,,,,,,,,,, -127800,1.4265842,4.577612,,,,,,,,,,,,,, -127900,1.187023,4.2736115,,,,,,,,,,,,,, -128000,1.4664994,2.5754297,,,,,,,,,,,,,, -128100,1.1690868,4.186881,,,,,,,,,,,,,, -128200,1.27105,4.449048,,,,,,,,,,,,,, -128300,1.3578871,4.527561,,,,,,,,,,,,,, -128400,1.727993,2.6276083,,,,,,,,,,,,,, -128500,1.4648672,2.7804887,,,,,,,,,,,,,, -128576,,,0.6483983993530273,1.4797208309173584,0.6062399744987488,1.687266826629639,50000.0,0.4870000183582306,2.3387036323547363,10000.0,59698.199717760086,65953.0438401699,59698.199717760086,6241.280877828598,6.368543863296509,0.0 -128600,1.4945041,2.5407782,,,,,,,,,,,,,, -128700,1.4618552,2.5084805,,,,,,,,,,,,,, -128800,1.6698648,2.5241127,,,,,,,,,,,,,, -128900,1.3132386,3.8288858,,,,,,,,,,,,,, -129000,1.4942381,2.8834794,,,,,,,,,,,,,, -129100,1.5209966,2.4800506,,,,,,,,,,,,,, -129200,1.2395637,3.6153314,,,,,,,,,,,,,, -129300,1.520479,2.4189525,,,,,,,,,,,,,, -129400,1.3732648,4.6347795,,,,,,,,,,,,,, -129480,,,0.6539648175239563,1.4522738456726074,0.605679988861084,1.6788485050201416,50000.0,0.4841000139713287,2.354218006134033,10000.0,60118.53882169724,66419.9552268982,60118.53882169724,6287.754509687424,6.416506052017212,0.0 -129500,1.3418655,4.2701945,,,,,,,,,,,,,, -129600,1.2711122,4.986047,,,,,,,,,,,,,, -129700,1.5049648,2.5324488,,,,,,,,,,,,,, -129800,1.2800722,5.077081,,,,,,,,,,,,,, -129900,1.6943158,2.5816426,,,,,,,,,,,,,, -130000,1.4345919,2.8324573,,,,,,,,,,,,,, -130100,1.6159735,2.4712193,,,,,,,,,,,,,, -130200,1.7577698,2.423575,,,,,,,,,,,,,, -130300,1.3458626,3.4622667,,,,,,,,,,,,,, -130386,,,0.6702734231948853,1.4016005992889404,0.6136199831962585,1.660366773605347,50000.0,0.4922000169754028,2.3320703506469727,10000.0,60538.44319176674,66885.50264811516,60538.44319176674,6333.296562671661,6.466786623001099,0.0 -130400,1.2545295,3.0890088,,,,,,,,,,,,,, -130500,1.4915386,2.578549,,,,,,,,,,,,,, -130600,1.6139671,2.403954,,,,,,,,,,,,,, -130700,1.5396879,2.8198783,,,,,,,,,,,,,, -130800,1.3542578,5.063865,,,,,,,,,,,,,, -130900,1.3147482,4.099302,,,,,,,,,,,,,, -131000,1.509483,2.8525498,,,,,,,,,,,,,, -131100,1.4428021,2.5965679,,,,,,,,,,,,,, -131200,1.2688763,4.3870277,,,,,,,,,,,,,, -131293,,,0.6552929282188416,1.4327030181884766,0.6084200143814087,1.6539702415466309,50000.0,0.4907000362873077,2.317131996154785,10000.0,60958.65554857254,67347.1164739132,60958.65554857254,6374.600378513336,6.51424241065979,0.0 -131300,1.6469826,2.3540328,,,,,,,,,,,,,, -131400,1.5572983,3.0509465,,,,,,,,,,,,,, -131500,1.6432302,2.458194,,,,,,,,,,,,,, -131600,1.628844,2.536977,,,,,,,,,,,,,, -131700,1.6181936,2.5101228,,,,,,,,,,,,,, -131800,1.3948876,2.5283465,,,,,,,,,,,,,, -131900,1.5821314,2.5829194,,,,,,,,,,,,,, -132000,1.6531255,2.4318733,,,,,,,,,,,,,, -132100,1.3873883,4.6166153,,,,,,,,,,,,,, -132198,,,0.6626366972923279,1.4006588459014893,0.6137999892234802,1.6258078813552856,50000.0,0.4917000234127044,2.3008532524108887,10000.0,61378.65146899223,67813.22391462326,61378.65146899223,6420.612956047058,6.562076568603516,0.0 -132200,1.7042453,2.4587352,,,,,,,,,,,,,, -132300,1.3421308,4.5842514,,,,,,,,,,,,,, -132400,1.6774952,2.2944293,,,,,,,,,,,,,, -132500,1.3143622,4.6423965,,,,,,,,,,,,,, -132600,1.557563,2.4131248,,,,,,,,,,,,,, -132700,1.484788,2.5252666,,,,,,,,,,,,,, -132800,1.666434,2.5603962,,,,,,,,,,,,,, -132900,1.5426115,2.5117483,,,,,,,,,,,,,, -133000,1.4113213,3.4555626,,,,,,,,,,,,,, -133100,1.5779932,2.616608,,,,,,,,,,,,,, -133104,,,0.6655077934265137,1.4070724248886108,0.6142599582672119,1.652784824371338,50000.0,0.4901000261306762,2.3164613246917725,10000.0,61798.64251708984,68278.38157367706,61798.64251708984,6465.681235074997,6.609776496887207,0.0 -133200,1.351878,5.0517497,,,,,,,,,,,,,, -133300,1.6262199,2.3689725,,,,,,,,,,,,,, -133400,1.3094817,4.2222013,,,,,,,,,,,,,, -133500,1.7219083,2.5419497,,,,,,,,,,,,,, -133600,1.5854958,2.3993325,,,,,,,,,,,,,, -133700,1.3975794,4.3187695,,,,,,,,,,,,,, -133800,1.7304974,2.3787785,,,,,,,,,,,,,, -133900,1.5291551,2.3971534,,,,,,,,,,,,,, -134000,1.4478296,2.490668,,,,,,,,,,,,,, -134008,,,0.6812499761581421,1.345516324043274,0.6206799745559692,1.6100428104400637,50000.0,0.497700035572052,2.2720818519592285,10000.0,62218.54648470879,68742.47394442558,62218.54648470879,6509.769501447678,6.65800929069519,0.0 -134100,1.5989941,2.5471687,,,,,,,,,,,,,, -134200,1.6895727,2.3363576,,,,,,,,,,,,,, -134300,1.5430466,2.4928722,,,,,,,,,,,,,, -134400,1.5502565,2.8106,,,,,,,,,,,,,, -134500,1.6087757,2.3651116,,,,,,,,,,,,,, -134600,1.4292568,4.885476,,,,,,,,,,,,,, -134700,1.5515345,2.663675,,,,,,,,,,,,,, -134800,1.3714931,3.3429458,,,,,,,,,,,,,, -134900,1.6449345,2.2962544,,,,,,,,,,,,,, -134913,,,0.6734570264816284,1.3531352281570437,0.6266799569129944,1.563097596168518,50000.0,0.5078000426292419,2.2082858085632324,10000.0,62638.87726259232,69209.31068348885,62638.87726259232,6556.1755702495575,6.708175182342529,0.0 -135000,1.6159569,2.6278133,,,,,,,,,,,,,, -135100,1.4306087,4.5244627,,,,,,,,,,,,,, -135200,1.6277765,3.4994674,,,,,,,,,,,,,, -135300,1.5110856,4.8778706,,,,,,,,,,,,,, -135400,1.474823,2.8735504,,,,,,,,,,,,,, -135500,1.6188363,2.3603213,,,,,,,,,,,,,, -135600,1.6133993,2.3159454,,,,,,,,,,,,,, -135700,1.5594082,2.4251504,,,,,,,,,,,,,, -135800,1.378795,3.8476665,,,,,,,,,,,,,, -135817,,,0.6738085746765137,1.3599903583526611,0.6240599751472473,1.5829092264175415,50000.0,0.5031000375747681,2.244506359100342,10000.0,63059.16834306717,69672.13042116165,63059.16834306717,6598.604954004288,6.757187128067017,0.0 -135900,1.4928916,2.7670624,,,,,,,,,,,,,, -136000,1.4562509,3.9184313,,,,,,,,,,,,,, -136100,1.4568068,2.8707187,,,,,,,,,,,,,, -136200,1.7229849,2.3387537,,,,,,,,,,,,,, -136300,1.629709,2.4553673,,,,,,,,,,,,,, -136400,1.7838297,2.4479847,,,,,,,,,,,,,, -136500,1.7152457,2.265524,,,,,,,,,,,,,, -136600,1.3318148,4.191487,,,,,,,,,,,,,, -136700,1.6101224,2.251271,,,,,,,,,,,,,, -136722,,,0.7079687118530273,1.2369790077209473,0.6272000074386597,1.5897619724273682,50000.0,0.5020000338554382,2.246837615966797,10000.0,63479.30086803436,70138.27771234512,63479.30086803436,6644.519300699234,6.807169198989868,0.0 -136800,1.670566,2.4708655,,,,,,,,,,,,,, -136900,1.5781252,2.6939063,,,,,,,,,,,,,, -137000,1.5038306,2.3140736,,,,,,,,,,,,,, -137100,1.4358816,3.7253668,,,,,,,,,,,,,, -137200,1.8432568,2.342779,,,,,,,,,,,,,, -137300,1.4692944,3.4711924,,,,,,,,,,,,,, -137400,1.7007891,2.2842708,,,,,,,,,,,,,, -137500,1.6936723,2.265284,,,,,,,,,,,,,, -137600,1.7389976,2.4625196,,,,,,,,,,,,,, -137630,,,0.675585925579071,1.340919017791748,0.6327599883079529,1.5627877712249756,50000.0,0.5099000334739685,2.199458122253418,10000.0,63899.45568084717,70599.13270401955,63899.45568084717,6685.11887216568,6.857578992843628,0.0 -137700,1.7027807,4.386802,,,,,,,,,,,,,, -137800,1.6136678,2.4283218,,,,,,,,,,,,,, -137900,1.4984981,3.6925025,,,,,,,,,,,,,, -138000,1.6742557,2.7070804,,,,,,,,,,,,,, -138100,1.5269862,3.3080864,,,,,,,,,,,,,, -138200,1.4356017,4.4137444,,,,,,,,,,,,,, -138300,1.7661682,2.3225272,,,,,,,,,,,,,, -138400,1.4564099,4.546668,,,,,,,,,,,,,, -138500,1.5592228,4.905936,,,,,,,,,,,,,, -138533,,,0.6859960556030273,1.2904889583587646,0.6377399563789368,1.5166523456573486,50000.0,0.5120000243186951,2.1723082065582275,10000.0,64319.47856760025,71062.68217468262,64319.47856760025,6728.542403936386,6.904958009719849,0.0 -138600,1.5678705,2.7564435,,,,,,,,,,,,,, -138700,1.5579565,2.9532368,,,,,,,,,,,,,, -138800,1.5843538,2.7793705,,,,,,,,,,,,,, -138900,1.7236761,2.378051,,,,,,,,,,,,,, -139000,1.4917445,4.1731925,,,,,,,,,,,,,, -139100,1.6980864,2.3394628,,,,,,,,,,,,,, -139200,1.708577,2.3803132,,,,,,,,,,,,,, -139300,1.7951939,2.5500538,,,,,,,,,,,,,, -139400,1.7521628,4.6123657,,,,,,,,,,,,,, -139435,,,0.6988281011581421,1.2646912336349487,0.6321799755096436,1.5676641464233398,50000.0,0.5063000321388245,2.2396981716156006,10000.0,64739.70188331604,71527.2613837719,64739.70188331604,6772.798098802567,6.95450758934021,0.0 -139500,1.631434,4.06946,,,,,,,,,,,,,, -139600,1.9748617,2.2888684,,,,,,,,,,,,,, -139700,1.7190027,2.3955014,,,,,,,,,,,,,, -139800,1.7450622,2.3200965,,,,,,,,,,,,,, -139900,1.7149817,2.144475,,,,,,,,,,,,,, -140000,1.6663011,2.1913278,,,,,,,,,,,,,, -140100,1.637149,2.6890178,,,,,,,,,,,,,, -140200,1.6236824,3.016374,,,,,,,,,,,,,, -140300,1.919599,2.310785,,,,,,,,,,,,,, -140342,,,0.6906836032867432,1.2769279479980469,0.6407399773597717,1.5057283639907837,50000.0,0.5200999975204468,2.154827356338501,10000.0,65159.62976360321,71991.70762300491,65159.62976360321,6817.2175986766815,7.002084493637085,0.0 -140400,1.6433636,2.7965374,,,,,,,,,,,,,, -140500,1.5997591,3.8276553,,,,,,,,,,,,,, -140600,1.8543082,2.208207,,,,,,,,,,,,,, -140700,1.6939815,2.182856,,,,,,,,,,,,,, -140800,1.8017113,2.221386,,,,,,,,,,,,,, -140900,1.6729103,2.2588544,,,,,,,,,,,,,, -141000,1.7404299,2.6116161,,,,,,,,,,,,,, -141100,1.6365286,4.6135306,,,,,,,,,,,,,, -141200,1.5356765,3.5683014,,,,,,,,,,,,,, -141247,,,0.69837886095047,1.2515182495117188,0.6464599967002869,1.4908294677734375,50000.0,0.5234000086784363,2.1413466930389404,10000.0,65579.96733903885,72457.2199523449,65579.96733903885,6862.2890038490295,7.054901123046875,0.0 -141300,1.5743362,4.386913,,,,,,,,,,,,,, -141400,1.7123545,2.9814854,,,,,,,,,,,,,, -141500,1.6103799,2.5933778,,,,,,,,,,,,,, -141600,1.6122531,4.2167225,,,,,,,,,,,,,, -141700,1.5538338,4.3195796,,,,,,,,,,,,,, -141800,1.9320731,2.154416,,,,,,,,,,,,,, -141900,1.5895824,4.9000287,,,,,,,,,,,,,, -142000,2.0512815,2.3781548,,,,,,,,,,,,,, -142100,1.6542394,3.2602139,,,,,,,,,,,,,, -142151,,,0.7085741758346558,1.1992639303207395,0.6479600071907043,1.4763096570968628,50000.0,0.5210000276565552,2.128354549407959,10000.0,66000.20277547836,72922.64905381203,66000.20277547836,6907.379415750504,7.107923269271851,0.0 -142200,1.8508207,2.267456,,,,,,,,,,,,,, -142300,1.5889963,3.552074,,,,,,,,,,,,,, -142400,1.6951928,2.3404489,,,,,,,,,,,,,, -142500,1.8367057,2.0585492,,,,,,,,,,,,,, -142600,1.7894886,2.1742897,,,,,,,,,,,,,, -142700,1.7260022,2.6446338,,,,,,,,,,,,,, -142800,1.59797,3.1473625,,,,,,,,,,,,,, -142900,1.8307898,2.232771,,,,,,,,,,,,,, -143000,1.7663525,4.4806957,,,,,,,,,,,,,, -143055,,,0.7023242115974426,1.2293777465820312,0.6479600071907043,1.4754472970962524,50000.0,0.5277000069618225,2.1308555603027344,10000.0,66420.13485479355,73390.75317406654,66420.13485479355,6955.451377868652,7.158038854598999,0.0 -143100,1.779344,2.1638892,,,,,,,,,,,,,, -143200,1.9201916,2.2473009,,,,,,,,,,,,,, -143300,1.9373969,4.767539,,,,,,,,,,,,,, -143400,1.8254434,2.3343487,,,,,,,,,,,,,, -143500,1.9260131,2.333313,,,,,,,,,,,,,, -143600,1.6254305,2.7591708,,,,,,,,,,,,,, -143700,1.538968,3.315338,,,,,,,,,,,,,, -143800,1.7805927,2.2403662,,,,,,,,,,,,,, -143900,1.58946,3.3910568,,,,,,,,,,,,,, -143961,,,0.70570307970047,1.21626615524292,0.6527599692344666,1.4616146087646484,50000.0,0.5301000475883484,2.108567237854004,10000.0,66840.06156110764,73856.33251214027,66840.06156110764,7001.00201010704,7.209303379058838,0.0 -144000,2.0984483,2.1778886,,,,,,,,,,,,,, -144100,1.6029823,4.442468,,,,,,,,,,,,,, -144200,2.0191734,2.2677338,,,,,,,,,,,,,, -144300,1.9894856,2.2088053,,,,,,,,,,,,,, -144400,1.8890426,2.2149916,,,,,,,,,,,,,, -144500,1.6850301,2.2131279,,,,,,,,,,,,,, -144600,2.192859,2.2230432,,,,,,,,,,,,,, -144700,1.8415724,2.2294912,,,,,,,,,,,,,, -144800,1.7813064,2.6597204,,,,,,,,,,,,,, -144867,,,0.7190039157867432,1.1549161672592163,0.6559000015258789,1.4345026016235352,50000.0,0.534500002861023,2.0833964347839355,10000.0,67260.06018710136,74321.73847436905,67260.06018710136,7046.307956695557,7.260730981826782,0.0 -144900,2.0027847,2.1453955,,,,,,,,,,,,,, -145000,1.6523701,3.1128192,,,,,,,,,,,,,, -145100,1.8784236,2.332222,,,,,,,,,,,,,, -145200,1.916367,2.228688,,,,,,,,,,,,,, -145300,1.8295503,2.0724232,,,,,,,,,,,,,, -145400,1.8648958,2.1546528,,,,,,,,,,,,,, -145500,1.753258,2.1201108,,,,,,,,,,,,,, -145600,2.0000987,4.4164042,,,,,,,,,,,,,, -145700,1.9022927,1.974113,,,,,,,,,,,,,, -145774,,,0.7044140696525574,1.2267926931381226,0.6544199585914612,1.4555846452713013,50000.0,0.5312000513076782,2.109943389892578,10000.0,67680.40997552872,74787.97717380524,67680.40997552872,7092.093881845474,7.313398361206055,0.0 -145800,1.6438257,3.460515,,,,,,,,,,,,,, -145900,1.5342791,3.9586413,,,,,,,,,,,,,, -146000,1.8674183,2.3201861,,,,,,,,,,,,,, -146100,1.9608419,2.1347451,,,,,,,,,,,,,, -146200,2.2672408,2.1952696,,,,,,,,,,,,,, -146300,1.8667758,2.4374688,,,,,,,,,,,,,, -146400,1.6703644,2.8971004,,,,,,,,,,,,,, -146500,1.81466,4.7496886,,,,,,,,,,,,,, -146600,1.8581572,2.1774313,,,,,,,,,,,,,, -146680,,,0.7185156345367432,1.158596158027649,0.6607599854469299,1.4224454164505005,50000.0,0.5371000170707703,2.0639452934265137,10000.0,68100.34601187706,75256.75840878487,68100.34601187706,7140.838541984558,7.363656282424927,0.0 -146700,2.3555944,2.129681,,,,,,,,,,,,,, -146800,2.0399554,4.765086,,,,,,,,,,,,,, -146900,2.056958,2.531935,,,,,,,,,,,,,, -147000,2.0124793,2.119724,,,,,,,,,,,,,, -147100,1.8773527,2.22531,,,,,,,,,,,,,, -147200,1.6699748,3.7149167,,,,,,,,,,,,,, -147300,1.9994129,2.1220543,,,,,,,,,,,,,, -147400,1.7524326,3.4733462,,,,,,,,,,,,,, -147500,1.8297518,2.9908512,,,,,,,,,,,,,, -147585,,,0.7215625047683716,1.143648982048035,0.6613399982452393,1.4184284210205078,50000.0,0.541700005531311,2.0514132976531982,10000.0,68520.35908174515,75718.0649137497,68520.35908174515,7182.027892827988,7.416749477386475,0.0 -147600,2.0714638,2.1391659,,,,,,,,,,,,,, -147700,1.8533729,2.4457285,,,,,,,,,,,,,, -147800,1.9769719,2.1271155,,,,,,,,,,,,,, -147900,1.9688214,2.070985,,,,,,,,,,,,,, -148000,2.0109465,2.18828,,,,,,,,,,,,,, -148100,1.850211,2.409503,,,,,,,,,,,,,, -148200,1.8686397,2.2102222,,,,,,,,,,,,,, -148300,2.0105064,2.450851,,,,,,,,,,,,,, -148400,1.8251877,3.4424965,,,,,,,,,,,,,, -148493,,,0.7193359136581421,1.1469899415969849,0.667140007019043,1.39079749584198,50000.0,0.5446000099182129,2.0252130031585693,10000.0,68940.4190402031,76183.35967731476,68940.4190402031,7227.158820390701,7.469902276992798,0.0 -148500,1.8340827,4.025883,,,,,,,,,,,,,, -148600,2.009056,2.1074822,,,,,,,,,,,,,, -148700,1.9866968,1.9598007,,,,,,,,,,,,,, -148800,1.9638923,4.5198417,,,,,,,,,,,,,, -148900,1.8755534,2.2569659,,,,,,,,,,,,,, -149000,2.2263253,2.0146208,,,,,,,,,,,,,, -149100,2.242474,2.0624132,,,,,,,,,,,,,, -149200,2.1416597,2.0450952,,,,,,,,,,,,,, -149300,1.8916993,2.6169252,,,,,,,,,,,,,, -149400,,,0.7261328101158142,1.1292051076889038,0.6677199602127075,1.3793786764144895,50000.0,0.542900025844574,2.033937692642212,10000.0,69360.67560601234,76649.3745880127,69360.67560601234,7272.813814640045,7.522792100906372,0.0 -149400,2.0343595,4.5558977,,,,,,,,,,,,,, -149500,2.0831301,1.9708508,,,,,,,,,,,,,, -149600,1.8359096,2.901528,,,,,,,,,,,,,, -149700,1.8524373,3.3014824,,,,,,,,,,,,,, -149800,2.1033218,2.190166,,,,,,,,,,,,,, -149900,2.0839148,2.1342654,,,,,,,,,,,,,, -150000,1.9911919,2.1641433,,,,,,,,,,,,,, -150100,1.8389165,3.2698295,,,,,,,,,,,,,, -150200,2.0457518,2.0849652,,,,,,,,,,,,,, -150300,1.898977,4.151627,,,,,,,,,,,,,, -150306,,,0.7345898151397705,1.0911519527435305,0.6721599698066711,1.3719079494476318,50000.0,0.5432000160217285,2.0280535221099854,10000.0,69780.98648524284,77111.10426402092,69780.98648524284,7314.125680685043,7.57967209815979,0.0 -150400,1.8660499,3.5191662,,,,,,,,,,,,,, -150500,2.1571207,2.062399,,,,,,,,,,,,,, -150600,1.8985674,2.1497357,,,,,,,,,,,,,, -150700,2.3095012,4.3762646,,,,,,,,,,,,,, -150800,1.8425163,2.9056628,,,,,,,,,,,,,, -150900,2.105294,2.0240393,,,,,,,,,,,,,, -151000,2.0570068,1.8995944,,,,,,,,,,,,,, -151100,2.005332,2.125307,,,,,,,,,,,,,, -151200,2.3278563,2.0429223,,,,,,,,,,,,,, -151210,,,0.7315624952316284,1.0995270013809204,0.6723600029945374,1.3750596046447754,50000.0,0.5479000210762024,2.0155558586120605,10000.0,70201.0912911892,77577.59732437134,70201.0912911892,7360.41232419014,7.630008459091186,0.0 -151300,1.9472933,4.261952,,,,,,,,,,,,,, -151400,1.9322392,3.8089426,,,,,,,,,,,,,, -151500,1.9815153,3.0557382,,,,,,,,,,,,,, -151600,2.1451051,2.1062734,,,,,,,,,,,,,, -151700,1.9933788,2.0230029,,,,,,,,,,,,,, -151800,2.0989823,2.599053,,,,,,,,,,,,,, -151900,2.1481776,1.9661561,,,,,,,,,,,,,, -152000,2.129817,2.0301745,,,,,,,,,,,,,, -152100,2.1127088,2.0902643,,,,,,,,,,,,,, -152116,,,0.7346875071525574,1.0947026014328003,0.6771000027656555,1.352385401725769,50000.0,0.5561000108718872,1.9908647537231443,10000.0,70620.99999403954,78043.17117524147,70620.99999403954,7405.974725008011,7.682317733764648,0.0 -152200,1.9439275,1.9969985,,,,,,,,,,,,,, -152300,1.8502551,3.510057,,,,,,,,,,,,,, -152400,2.12776,2.1554098,,,,,,,,,,,,,, -152500,1.9702077,2.288209,,,,,,,,,,,,,, -152600,2.0116947,2.3958893,,,,,,,,,,,,,, -152700,2.1144626,1.953721,,,,,,,,,,,,,, -152800,2.3201318,2.1414413,,,,,,,,,,,,,, -152900,2.1290386,4.1745205,,,,,,,,,,,,,, -153000,1.945833,3.2800949,,,,,,,,,,,,,, -153020,,,0.7451757788658142,1.0271271467208862,0.6819199919700623,1.3150537014007568,50000.0,0.5586000084877014,1.9496073722839355,10000.0,71041.29298949242,78505.25744843483,71041.29298949242,7447.667580366135,7.731944561004639,0.0 -153100,2.1892917,3.9009628,,,,,,,,,,,,,, -153200,2.1370952,2.1615977,,,,,,,,,,,,,, -153300,1.8926504,2.5079417,,,,,,,,,,,,,, -153400,1.9909366,1.9191341,,,,,,,,,,,,,, -153500,1.7971338,3.0266697,,,,,,,,,,,,,, -153600,2.1906888,2.0306847,,,,,,,,,,,,,, -153700,2.3673725,2.0783901,,,,,,,,,,,,,, -153800,1.9565359,2.953045,,,,,,,,,,,,,, -153900,2.228876,1.9575003,,,,,,,,,,,,,, -153922,,,0.758496105670929,0.978778839111328,0.6845600008964539,1.3084824085235596,50000.0,0.5543000102043152,1.9502512216567995,10000.0,71461.58991360664,78968.2080783844,71461.58991360664,7490.222132205963,7.780643224716186,0.0 -154000,2.2620618,1.967571,,,,,,,,,,,,,, -154100,2.448861,4.5163636,,,,,,,,,,,,,, -154200,2.347002,2.1703875,,,,,,,,,,,,,, -154300,2.1227179,2.2583442,,,,,,,,,,,,,, -154400,2.2198377,1.8997064,,,,,,,,,,,,,, -154500,2.1934464,2.713244,,,,,,,,,,,,,, -154600,2.327889,1.9755961,,,,,,,,,,,,,, -154700,2.3691776,2.032339,,,,,,,,,,,,,, -154800,2.0453095,2.1701176,,,,,,,,,,,,,, -154829,,,0.7406249642372131,1.0503544807434082,0.6873799562454224,1.2965481281280518,50000.0,0.5617000460624695,1.938288211822509,10000.0,71881.67134642601,79435.82341265678,71881.67134642601,7537.654074668884,7.830984830856323,0.0 -154900,2.1997802,2.03142,,,,,,,,,,,,,, -155000,2.027988,2.5311818,,,,,,,,,,,,,, -155100,2.2632203,2.082763,,,,,,,,,,,,,, -155200,2.308052,1.9869653,,,,,,,,,,,,,, -155300,2.0642202,2.4782143,,,,,,,,,,,,,, -155400,2.3414369,2.0085793,,,,,,,,,,,,,, -155500,2.5021753,4.51127,,,,,,,,,,,,,, -155600,2.132657,1.9973276,,,,,,,,,,,,,, -155700,2.0843732,2.6983278,,,,,,,,,,,,,, -155736,,,0.7502343654632568,1.0065470933914185,0.687559962272644,1.2907850742340088,50000.0,0.5610000491142273,1.922777771949768,10000.0,72301.95320534706,79901.12423586845,72301.95320534706,7582.568460941315,7.88397216796875,0.0 -155800,2.3819485,1.9723498,,,,,,,,,,,,,, -155900,2.386361,2.395928,,,,,,,,,,,,,, -156000,2.4170926,2.1124701,,,,,,,,,,,,,, -156100,2.2755377,2.1451793,,,,,,,,,,,,,, -156200,2.489265,1.9801602,,,,,,,,,,,,,, -156300,2.1744397,3.7199254,,,,,,,,,,,,,, -156400,2.6275651,4.500008,,,,,,,,,,,,,, -156500,2.3697145,1.98062,,,,,,,,,,,,,, -156600,2.4093792,2.09664,,,,,,,,,,,,,, -156641,,,0.76318359375,0.952807605266571,0.6926800012588501,1.2713115215301514,50000.0,0.5648000240325928,1.9105957746505733,10000.0,72721.92038750648,80366.11267876625,72721.92038750648,7627.477694272995,7.94585371017456,0.0 -156700,2.392156,1.9336957,,,,,,,,,,,,,, -156800,2.2056913,1.885701,,,,,,,,,,,,,, -156900,2.2562425,2.0589955,,,,,,,,,,,,,, -157000,2.5472527,1.827702,,,,,,,,,,,,,, -157100,2.4607635,1.8799953,,,,,,,,,,,,,, -157200,2.0387523,3.3437073,,,,,,,,,,,,,, -157300,2.3922906,1.9671605,,,,,,,,,,,,,, -157400,2.303369,1.8711282,,,,,,,,,,,,,, -157500,2.444636,1.8536581,,,,,,,,,,,,,, -157547,,,0.7520703077316284,1.0017552375793457,0.6958400011062622,1.251082420349121,50000.0,0.572700023651123,1.8889102935791016,10000.0,73141.96051359177,80832.09207010269,73141.96051359177,7673.310721158981,8.001603126525879,0.0 -157600,2.3630848,1.8836861,,,,,,,,,,,,,, -157700,2.5095117,1.889449,,,,,,,,,,,,,, -157800,2.2158618,2.1624792,,,,,,,,,,,,,, -157900,2.5911481,1.9570959,,,,,,,,,,,,,, -158000,2.564672,1.8890163,,,,,,,,,,,,,, -158100,2.4034705,1.8431847,,,,,,,,,,,,,, -158200,2.4683247,3.4401267,,,,,,,,,,,,,, -158300,2.5698223,2.0560026,,,,,,,,,,,,,, -158400,2.5095963,2.0217185,,,,,,,,,,,,,, -158452,,,0.7565234303474426,0.9877767562866212,0.6963799595832825,1.252785563468933,50000.0,0.5720000267028809,1.892466902732849,10000.0,73562.28885316849,81292.98815059662,73562.28885316849,7713.773415803909,8.056279420852661,0.0 -158500,2.348444,2.184414,,,,,,,,,,,,,, -158600,2.6930578,4.113834,,,,,,,,,,,,,, -158700,2.1753294,3.1236403,,,,,,,,,,,,,, -158800,2.4824042,3.9852955,,,,,,,,,,,,,, -158900,2.2219129,2.7826154,,,,,,,,,,,,,, -159000,2.4294808,1.890028,,,,,,,,,,,,,, -159100,2.2852204,2.3657875,,,,,,,,,,,,,, -159200,2.281914,2.8482857,,,,,,,,,,,,,, -159300,2.4461894,1.9688437,,,,,,,,,,,,,, -159359,,,0.7693945169448853,0.939318060874939,0.6983799934387207,1.2409757375717163,50000.0,0.5732000470161438,1.8835759162902832,10000.0,73982.47636890411,81759.38010597229,73982.47636890411,7759.871674776077,8.111968994140625,0.0 -159400,2.4793262,1.9911664,,,,,,,,,,,,,, -159500,2.4460392,1.8063138,,,,,,,,,,,,,, -159600,2.3648398,3.4550536,,,,,,,,,,,,,, -159700,2.3751705,3.5103033,,,,,,,,,,,,,, -159800,2.3182263,1.8278515,,,,,,,,,,,,,, -159900,2.3400111,3.0967948,,,,,,,,,,,,,, -160000,2.5275989,3.993286,,,,,,,,,,,,,, -160100,2.3086252,3.3453271,,,,,,,,,,,,,, -160200,2.5455337,3.579295,,,,,,,,,,,,,, -160264,,,0.7602343559265137,0.971134066581726,0.6994400024414062,1.2386878728866575,50000.0,0.5814000368118286,1.869041085243225,10000.0,74402.69127678871,82223.95091509819,74402.69127678871,7804.115849733353,8.173494815826416,0.0 -160300,2.610319,4.2131286,,,,,,,,,,,,,, -160400,2.5310812,1.9626036,,,,,,,,,,,,,, -160500,2.2770617,1.6572069,,,,,,,,,,,,,, -160600,2.489326,4.0229764,,,,,,,,,,,,,, -160700,2.765037,1.9413154,,,,,,,,,,,,,, -160800,2.3341558,3.6080427,,,,,,,,,,,,,, -160900,2.4887314,2.1036458,,,,,,,,,,,,,, -161000,2.576216,2.3074212,,,,,,,,,,,,,, -161100,2.6876712,1.7641145,,,,,,,,,,,,,, -161168,,,0.7683203220367432,0.9366870522499084,0.7050999999046326,1.2145330905914309,50000.0,0.5782999992370605,1.8563323020935056,10000.0,74822.95345377922,82692.12824940681,74822.95345377922,7851.924923419952,8.229990243911743,0.0 -161200,2.9265168,1.8899605,,,,,,,,,,,,,, -161300,2.5520613,1.9182515,,,,,,,,,,,,,, -161400,2.341963,2.466602,,,,,,,,,,,,,, -161500,2.3826299,3.1073103,,,,,,,,,,,,,, -161600,2.6345608,1.8946917,,,,,,,,,,,,,, -161700,2.7331896,1.8711212,,,,,,,,,,,,,, -161800,2.7289758,1.7034802,,,,,,,,,,,,,, -161900,2.5534356,4.028376,,,,,,,,,,,,,, -162000,2.6663647,1.8712169,,,,,,,,,,,,,, -162070,,,0.77783203125,0.8856746554374695,0.7069199681282043,1.2027339935302734,50000.0,0.5851000547409058,1.8382530212402344,10000.0,75242.92199492455,83156.2474398613,75242.92199492455,7895.971745014191,8.284348249435425,0.0 -162100,2.8882244,4.336593,,,,,,,,,,,,,, -162200,2.670511,1.8819133,,,,,,,,,,,,,, -162300,2.7303593,1.8065746,,,,,,,,,,,,,, -162400,2.5849676,1.7790022,,,,,,,,,,,,,, -162500,2.5731585,1.878536,,,,,,,,,,,,,, -162600,2.7612085,1.8191983,,,,,,,,,,,,,, -162700,2.5404308,1.7514763,,,,,,,,,,,,,, -162800,2.6765478,3.9716625,,,,,,,,,,,,,, -162900,2.494664,3.4273028,,,,,,,,,,,,,, -162976,,,0.7721288800239563,0.9069911241531372,0.7118799686431885,1.1784546375274658,50000.0,0.5868000388145447,1.793388605117798,10000.0,75663.21282863617,83617.6762702465,75663.21282863617,7937.003118753433,8.34020447731018,0.0 -163000,2.8706598,1.8088189,,,,,,,,,,,,,, -163100,2.6978035,2.7380118,,,,,,,,,,,,,, -163200,2.4937246,1.7370195,,,,,,,,,,,,,, -163300,2.7081964,2.0031571,,,,,,,,,,,,,, -163400,2.5476308,1.8508635,,,,,,,,,,,,,, -163500,2.5926092,3.4077475,,,,,,,,,,,,,, -163600,2.920666,3.9102988,,,,,,,,,,,,,, -163700,2.8483994,2.7827234,,,,,,,,,,,,,, -163800,2.6475406,1.8310843,,,,,,,,,,,,,, -163883,,,0.7769335508346558,0.8996948599815369,0.7105799913406372,1.1845804452896118,50000.0,0.5900000333786011,1.8173900842666624,10000.0,76083.81787419319,84085.51246571541,76083.81787419319,7984.124661684036,8.398981094360352,0.0 -163900,2.8552864,1.840627,,,,,,,,,,,,,, -164000,3.05316,3.3524075,,,,,,,,,,,,,, -164100,2.8379333,4.343681,,,,,,,,,,,,,, -164200,2.7627804,4.215621,,,,,,,,,,,,,, -164300,3.0051591,3.8463058,,,,,,,,,,,,,, -164400,2.731777,1.9358041,,,,,,,,,,,,,, -164500,2.527088,1.705179,,,,,,,,,,,,,, -164600,2.779749,1.7580155,,,,,,,,,,,,,, -164700,2.7189133,3.5137172,,,,,,,,,,,,,, -164791,,,0.7859765291213989,0.87266606092453,0.7135399580001831,1.186859130859375,50000.0,0.5868000388145447,1.812342643737793,10000.0,76504.19021439552,84552.21259260178,76504.19021439552,8030.3428745269775,8.457646608352661,0.0 -164800,2.5903888,2.3815122,,,,,,,,,,,,,, -164900,2.7548199,2.097618,,,,,,,,,,,,,, -165000,2.66566,1.8158846,,,,,,,,,,,,,, -165100,2.458388,3.32417,,,,,,,,,,,,,, -165200,2.7744255,1.9150553,,,,,,,,,,,,,, -165300,2.8716114,2.096362,,,,,,,,,,,,,, -165400,2.853375,1.8295513,,,,,,,,,,,,,, -165500,2.6576877,2.9232683,,,,,,,,,,,,,, -165600,3.3543055,4.139017,,,,,,,,,,,,,, -165696,,,0.7836523056030273,0.8662102222442627,0.7154799699783325,1.161203145980835,50000.0,0.5914000272750854,1.7810871601104736,10000.0,76924.15655565262,85017.09811878204,76924.15655565262,8075.156850099564,8.51274585723877,0.0 -165700,2.8973875,1.7660471,,,,,,,,,,,,,, -165800,2.8175452,1.7661943,,,,,,,,,,,,,, -165900,2.8632572,2.2299726,,,,,,,,,,,,,, -166000,2.8396935,2.0053983,,,,,,,,,,,,,, -166100,2.6665509,2.2960649,,,,,,,,,,,,,, -166200,2.6024473,1.9467107,,,,,,,,,,,,,, -166300,2.795602,1.9061906,,,,,,,,,,,,,, -166400,2.718288,2.3209038,,,,,,,,,,,,,, -166500,2.8330302,1.7461257,,,,,,,,,,,,,, -166600,3.0425944,3.155744,,,,,,,,,,,,,, -166602,,,0.7826757431030273,0.8744816184043884,0.7170400023460388,1.1657946109771729,50000.0,0.595300018787384,1.7856409549713137,10000.0,77344.08736562729,85484.08644890785,77344.08736562729,8122.107517242432,8.569658279418945,0.0 -166700,2.654861,2.1126955,,,,,,,,,,,,,, -166800,2.8583913,1.8386326,,,,,,,,,,,,,, -166900,2.9256847,3.723069,,,,,,,,,,,,,, -166988,,,,,,,,,,,77520.1577796936,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 192e860ee..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,186 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -28.86489605903625,0.0,40.46163988113403,1,0,40.46163988113403,0.0010000000474974,6.907756805419922,10000,69.32669401168823,0.0008203124743886,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -69.4005184173584,0.017852783203125,460.644246339798,844,0,460.644246339798,0.021900001913309,6.086627006530762,10000,530.1097981929779,0.0341015607118606,5.897230625152588,0.0298199988901615,5.968258380889893,50000 -112.46707344055176,0.0463771820068359,880.6717927455902,1745,0,880.6717927455902,0.0512000024318695,5.565213203430176,10000,993.2836394309998,0.0720312520861625,5.3030266761779785,0.0671200007200241,5.3553972244262695,50000 -157.77245998382568,0.0728633403778076,1300.8934905529022,2649,0,1300.8934905529022,0.0862000063061714,5.133504390716553,10000,1458.8881137371063,0.1213671863079071,4.764578342437744,0.1108599975705146,4.83726167678833,50000 -202.7385392189026,0.1011176109313964,1720.9098417758942,3552,0,1720.9098417758942,0.1246000081300735,4.737356662750244,10000,1923.949401378632,0.1814257800579071,4.194916248321533,0.1604000031948089,4.344925403594971,50000 -245.4547533988953,0.1272704601287841,2141.054492235184,4456,0,2141.054492235184,0.1565000116825103,4.419209003448486,10000,2386.887092113495,0.2257421761751175,3.862603664398194,0.2096999883651733,3.967260837554932,50000 -289.7249255180359,0.1532239913940429,2560.998118877411,5359,0,2560.998118877411,0.1880000084638595,4.178256511688232,10000,2851.177732467652,0.2718164026737213,3.585689306259156,0.2515200078487396,3.714376211166382,50000 -331.9820320606232,0.1811397075653076,2980.9549465179443,6257,0,2980.9549465179443,0.2222000062465667,3.962292432785034,10000,3313.470268011093,0.3246679604053497,3.229868173599243,0.2886599898338318,3.425157070159912,50000 -375.0047626495361,0.2081298828125,3400.9355845451355,7164,0,3400.9355845451355,0.2367000132799148,3.815329551696777,10000,3776.551196575165,0.3366992175579071,3.113884210586548,0.3124600052833557,3.242727041244507,50000 -419.699774980545,0.2372419834136963,3821.070417881012,8068,0,3821.070417881012,0.2571000158786773,3.680340051651001,10000,4241.460793018341,0.3708398342132568,2.9305522441864014,0.338919997215271,3.110529899597168,50000 -458.3977072238922,0.2645635604858398,4241.001017093658,8974,0,4241.001017093658,0.2725000083446502,3.5983641147613525,10000,4700.1675000190735,0.3893359303474426,2.798448085784912,0.3562199771404266,2.991145610809326,50000 -500.7986707687378,0.2959358692169189,4661.186886072159,9879,0,4661.186886072159,0.2928000092506408,3.460809230804444,10000,5162.836849212647,0.405078113079071,2.719622850418091,0.3754999935626983,2.876948118209839,50000 -546.2405309677124,0.3265047073364258,5081.527356147766,10783,0,5081.527356147766,0.3021000027656555,3.441816568374634,10000,5628.701157808304,0.4206640422344208,2.6827948093414307,0.3864399790763855,2.8488309383392334,50000 -588.2387778759003,0.3575336933135986,5501.6132843494415,11688,0,5501.6132843494415,0.3106000125408172,3.323500633239746,10000,6090.867266654968,0.4415038824081421,2.5091471672058105,0.405379980802536,2.7154462337493896,50000 -634.9937858581543,0.3894164562225342,5922.011283874512,12590,0,5922.011283874512,0.3197000026702881,3.3049449920654297,10000,6558.102586507797,0.4447070360183716,2.521016836166382,0.4148999750614166,2.6761584281921387,50000 -675.3277657032013,0.4220564365386963,6342.245297193527,13496,0,6342.245297193527,0.3357000052928924,3.1664600372314453,10000,7018.754692077637,0.4694921672344208,2.356694221496582,0.435479998588562,2.5276846885681152,50000 -720.3443946838379,0.4542615413665771,6762.292719364166,14400,0,6762.292719364166,0.3365000188350677,3.197645664215088,10000,7483.901390790939,0.4701952934265136,2.3852572441101074,0.43121999502182,2.579622268676758,50000 -766.3428432941437,0.4822478294372558,7182.435878038406,15307,0,7182.435878038406,0.3420000076293945,3.1562328338623047,10000,7950.122545957565,0.4783593714237213,2.331885814666748,0.4416399896144867,2.534491539001465,50000 -813.2723207473755,0.5117042064666748,7602.47861289978,16215,0,7602.47861289978,0.3481000065803528,3.092832088470459,10000,8417.176041603088,0.4888867139816284,2.2850089073181152,0.4542399942874908,2.4670143127441406,50000 -856.1803793907166,0.5442097187042236,8022.711845397949,17122,0,8022.711845397949,0.3571000099182129,3.0820775032043457,10000,8880.401599168777,0.4998632669448852,2.2155532836914062,0.4608399868011474,2.4244303703308105,50000 -903.8825159072876,0.5782530307769775,8442.835622787476,18026,0,8442.835622787476,0.3575000166893005,3.0405938625335693,10000,9348.312663078308,0.5294336080551147,2.071880340576172,0.4652199745178222,2.3781328201293945,50000 -949.3529365062714,0.607285737991333,8862.765635967255,18931,0,8862.765635967255,0.368800014257431,3.0019314289093018,10000,9813.79310321808,0.5103710889816284,2.167488574981689,0.4759399890899658,2.3467721939086914,50000 -996.4339108467102,0.6371915340423584,9283.02355337143,19837,0,9283.02355337143,0.3788000047206878,2.899568557739258,10000,10281.213695764542,0.5275976657867432,2.005328416824341,0.4874999821186065,2.2312562465667725,50000 -1033.9986152648926,0.6677725315093994,9703.317378759384,20743,0,9703.317378759384,0.3737000226974487,2.998039484024048,10000,10739.15369963646,0.5421093702316284,2.062912940979004,0.4810599982738495,2.356579303741455,50000 -1081.1976935863495,0.6967108249664307,10123.384189367294,21646,0,10123.384189367294,0.3819000124931335,2.9162533283233643,10000,11206.49951314926,0.5312694907188416,2.0607404708862305,0.4913399815559387,2.2656807899475098,50000 -1126.02326130867,0.7283682823181152,10543.569134950638,22552,0,10543.569134950638,0.3917000293731689,2.8608169555664062,10000,11671.593814611437,0.5475195050239563,1.9644417762756348,0.5002999901771545,2.18509578704834,50000 -1174.560831785202,0.7606453895568848,10963.765152215958,23459,0,10963.765152215958,0.3927000164985657,2.830965995788574,10000,12140.410922527311,0.5622265338897705,1.8803318738937376,0.5024799704551697,2.188140630722046,50000 -1214.4176337718964,0.7908592224121094,11384.004220485687,24365,0,11384.004220485687,0.3934000134468078,2.869940996170044,10000,12600.587675094604,0.5406445264816284,2.031728982925415,0.5039199590682983,2.227212190628052,50000 -1261.837417602539,0.822918176651001,11804.188628435137,25272,0,11804.188628435137,0.3937000334262848,2.8316471576690674,10000,13068.275243759155,0.5468554496765137,1.95665979385376,0.5080199837684631,2.159663438796997,50000 -1304.0307462215424,0.8551509380340576,12224.11382508278,26178,0,12224.11382508278,0.4075000286102295,2.760247707366944,10000,13530.477799415588,0.5712695121765137,1.830441117286682,0.519760012626648,2.1027395725250244,50000 -1348.732283115387,0.8903157711029053,12644.132603883743,27083,0,12644.132603883743,0.4067000150680542,2.7455155849456787,10000,13995.284102916718,0.5553905963897705,1.8902047872543333,0.5142999887466431,2.094199657440185,50000 -1390.9314963817596,0.926861047744751,13064.056718111038,27989,0,13064.056718111038,0.4098000228404999,2.7530479431152344,10000,14457.495115756989,0.5625976324081421,1.860401391983032,0.5207599997520447,2.0770411491394043,50000 -1437.0800392627716,0.9647171497344972,13484.38296675682,28896,0,13484.38296675682,0.4182000160217285,2.730376005172729,10000,14924.05919790268,0.5817968845367432,1.820831298828125,0.5314399600028992,2.074846029281616,50000 -1482.9127361774445,0.9980545043945312,13904.436289548874,29803,0,13904.436289548874,0.415800005197525,2.7066798210144043,10000,15390.029060602188,0.5716015696525574,1.8176162242889404,0.5288999676704407,2.026429891586304,50000 -1530.08101439476,1.0294535160064695,14324.511536359789,30710,0,14324.511536359789,0.4155000150203705,2.723546504974365,10000,15857.354704141617,0.5737109184265137,1.8410923480987549,0.5341199636459351,2.0497286319732666,50000 -1576.4077117443085,1.0594823360443115,14744.565649032593,31613,0,14744.565649032593,0.4191000163555145,2.6509931087493896,10000,16323.81619167328,0.5854882597923279,1.7064871788024902,0.5389999747276306,1.964975118637085,50000 -1618.4065613746643,1.089357614517212,15164.7554564476,32519,0,15164.7554564476,0.4282000064849853,2.6211376190185547,10000,16786.085805416107,0.5890039205551147,1.7607898712158203,0.5417999625205994,1.980234980583191,50000 -1664.2355268001556,1.119279861450195,15585.092749595642,33426,0,15585.092749595642,0.4208000302314758,2.6821558475494385,10000,17252.33269262314,0.5855468511581421,1.7376009225845337,0.5393800139427185,1.9697725772857664,50000 -1709.8396821022034,1.1521375179290771,16005.261153697968,34332,0,16005.261153697968,0.4281000196933746,2.639055252075196,10000,17718.188776254654,0.5911523103713989,1.7255398035049438,0.5455799698829651,1.9740146398544312,50000 -1755.8981931209564,1.190577507019043,16425.55720925331,35240,0,16425.55720925331,0.4411000311374664,2.584881544113159,10000,18184.633157491684,0.6228320002555847,1.5777928829193115,0.555180013179779,1.9033994674682613,50000 -1801.0961050987244,1.2253484725952148,16845.653829813004,36145,0,16845.653829813004,0.4221000075340271,2.6815497875213623,10000,18650.0134100914,0.5803906321525574,1.8397955894470213,0.5388000011444092,2.0403988361358643,50000 -1848.0710878372192,1.2622802257537842,17265.834097862244,37050,0,17265.834097862244,0.4429000318050384,2.5216729640960693,10000,19117.25606751442,0.6094335913658142,1.6258376836776731,0.5605800151824951,1.867913126945496,50000 -1889.9757521152496,1.2984967231750488,17685.957972049713,37953,0,17685.957972049713,0.442300021648407,2.54549503326416,10000,19579.37179613113,0.6267187595367432,1.5315265655517578,0.5590599775314331,1.876887917518616,50000 -1938.1587483882904,1.3353419303894043,18105.92478275299,38858,0,18105.92478275299,0.4444000124931335,2.581925630569458,10000,20047.60985803604,0.6036718487739563,1.6944724321365356,0.5605999827384949,1.9091485738754272,50000 -1985.312269449234,1.3708126544952393,18526.282481193542,39766,0,18526.282481193542,0.4431000351905823,2.54802680015564,10000,20515.20761990547,0.6093164086341858,1.628745675086975,0.5579000115394592,1.878006219863892,50000 -2031.407597780228,1.408151388168335,18946.47481775284,40672,0,18946.47481775284,0.4397000074386596,2.577730178833008,10000,20981.583657741547,0.6217382550239563,1.5823453664779663,0.5575399994850159,1.8897854089736936,50000 -2079.980150461197,1.4446847438812256,19366.68581557274,41577,0,19366.68581557274,0.4427000284194946,2.5433125495910645,10000,21450.45510196685,0.6031640768051147,1.654468655586243,0.5612999796867371,1.8774502277374268,50000 -2122.748773813248,1.4790616035461426,19786.95946097374,42482,0,19786.95946097374,0.4410000145435333,2.589221715927124,10000,21913.58345246315,0.6094140410423279,1.6763795614242554,0.5593799948692322,1.911125898361206,50000 -2168.378532409668,1.5133049488067627,20207.0824007988,43388,0,20207.0824007988,0.4397000074386596,2.585341215133667,10000,22379.42211842537,0.6217382550239563,1.639617919921875,0.5581200122833252,1.9320470094680784,50000 -2216.601419687271,1.5490844249725342,20627.26855278015,44294,0,20627.26855278015,0.4519000351428985,2.499508619308472,10000,22847.91852927208,0.613964855670929,1.6308157444000244,0.5718799829483032,1.843588590621948,50000 -2260.163228034973,1.5819377899169922,21047.384136915207,45197,0,21047.384136915207,0.45660001039505,2.4988009929656982,10000,23311.679278612137,0.6221679449081421,1.5743632316589355,0.5730199813842773,1.8135331869125368,50000 -2307.684317350388,1.6177711486816406,21468.03368353844,46101,0,21468.03368353844,0.4600000083446502,2.440560340881348,10000,23779.937309980392,0.6398242115974426,1.471780776977539,0.5819199681282043,1.7632266283035278,50000 -2353.266669511795,1.649388551712036,21888.051579475403,47003,0,21888.051579475403,0.4534000158309936,2.5132648944854736,10000,24245.620250225067,0.6192382574081421,1.632784128189087,0.5776799917221069,1.838263750076294,50000 -2401.858140230179,1.6812567710876465,22308.01690769196,47905,0,22308.01690769196,0.4571000337600708,2.462394952774048,10000,24714.25982022285,0.6240624785423279,1.551044464111328,0.5730800032615662,1.7903846502304075,50000 -2441.2826936244965,1.717482089996338,22728.048118829727,48809,0,22728.048118829727,0.4568000137805938,2.4816720485687256,10000,25173.80234694481,0.6260156035423279,1.572181224822998,0.5747199654579163,1.8311009407043457,50000 -2487.4757957458496,2.1974527835845947,23147.635646104813,49711,0,23147.635646104813,0.4463000297546386,2.5759851932525635,10000,25640.11376547813,0.6138671636581421,1.6917517185211182,0.5685399770736694,1.90058434009552,50000 -2531.470423936844,2.233795166015625,23567.73213458061,50614,0,23567.73213458061,0.4588000178337097,2.4899773597717285,10000,26104.293008327484,0.627246081829071,1.6077184677124023,0.5752800107002258,1.8396003246307373,50000 -2577.3440272808075,2.2699053287506104,23987.8773355484,51518,0,23987.8773355484,0.4621000289916992,2.4717304706573486,10000,26570.3987429142,0.6355859041213989,1.560902118682861,0.5841400027275085,1.812380313873291,50000 -2624.098313808441,2.305191993713379,24408.21269702912,52423,0,24408.21269702912,0.4638000130653381,2.4787209033966064,10000,27037.57492542267,0.6373242139816284,1.560644030570984,0.5805400013923645,1.8203362226486208,50000 -2667.875230550766,2.341886043548584,24828.36497950554,53329,0,24828.36497950554,0.4708000123500824,2.432278871536255,10000,27501.591789245605,0.6320117115974426,1.5440112352371216,0.586359977722168,1.7737897634506226,50000 -2710.399166822433,2.375976324081421,25248.749005794525,54233,0,25248.749005794525,0.4805000126361847,2.3959012031555176,10000,27964.58424758911,0.6431054472923279,1.4797742366790771,0.5918599963188171,1.7450494766235352,50000 -2756.7333924770355,2.421182155609131,25668.903258800507,55137,0,25668.903258800507,0.4723000228404999,2.417467832565308,10000,28431.16887879372,0.6678124666213989,1.3960999250411987,0.5922799706459045,1.7475827932357788,50000 -2803.8037741184235,2.458335876464844,26088.894425868988,56041,0,26088.894425868988,0.4618000090122223,2.443195343017578,10000,28898.31833386421,0.6312109231948853,1.543101787567139,0.5813199877738953,1.7773977518081665,50000 -2847.0764875411987,2.49930739402771,26509.195707321167,56946,0,26509.195707321167,0.4783000349998474,2.411650896072388,10000,29361.984380960464,0.6483007669448853,1.5265979766845703,0.598039984703064,1.7617758512496948,50000 -2883.793509721756,2.5372140407562256,26929.53288602829,57850,0,26929.53288602829,0.4708000123500824,2.402419328689575,10000,29819.127604722977,0.6579492092132568,1.4141149520874023,0.590499997138977,1.743949294090271,50000 -2932.443725347519,2.576699495315552,27349.482364177704,58755,0,27349.482364177704,0.4786000251770019,2.3848416805267334,10000,30287.81723570824,0.6382421851158142,1.5162142515182495,0.5955199599266052,1.725347638130188,50000 -2980.876615524292,2.613621950149536,27769.68719244004,59662,0,27769.68719244004,0.4775000214576721,2.395569562911988,10000,30756.54388856888,0.6444531083106995,1.4967687129974363,0.5925599932670593,1.7437061071395874,50000 -3028.5183403491974,2.6536691188812256,28189.677449941635,60567,0,28189.677449941635,0.4758000373840332,2.3964974880218506,10000,31224.26690030098,0.6565039157867432,1.454769730567932,0.5924599766731262,1.755020260810852,50000 -3069.303905725479,2.693972587585449,28609.880881786343,61470,0,28609.880881786343,0.4808000326156616,2.325901508331299,10000,31685.34749984741,0.650585949420929,1.4309154748916626,0.6028599739074707,1.6541436910629272,50000 -3117.232168197632,2.733799457550049,29030.00518536568,62375,0,29030.00518536568,0.4886000156402588,2.3273630142211914,10000,32153.49118900299,0.6543945074081421,1.4196845293045044,0.60725998878479,1.649720549583435,50000 -3158.303008079529,2.775385618209839,29450.34034228325,63280,0,29450.34034228325,0.484000027179718,2.328487157821656,10000,32614.989812135696,0.6655077934265137,1.3871628046035769,0.6027199625968933,1.683455228805542,50000 -3200.9960539340973,2.816265344619751,29870.42541265488,64183,0,29870.42541265488,0.4806000292301178,2.3341472148895264,10000,33077.85925221443,0.6518945097923279,1.436215043067932,0.605139970779419,1.6577297449111938,50000 -3248.1766040325165,2.854301929473877,30290.650116682053,65089,0,30290.650116682053,0.4880000352859497,2.3240888118743896,10000,33545.353556633,0.6542773246765137,1.4176008701324463,0.6029399633407593,1.6662062406539917,50000 -3293.040856361389,2.891425609588623,30710.779942512512,65996,0,30710.779942512512,0.4899000227451324,2.314547300338745,10000,34010.43670129776,0.667675793170929,1.373803734779358,0.6104399561882019,1.6528773307800293,50000 -3333.1036410331726,2.93249773979187,31131.218547344208,66901,0,31131.218547344208,0.4816000163555145,2.354022741317749,10000,34471.03016901016,0.6533398032188416,1.4194730520248413,0.6050599813461304,1.660703182220459,50000 -3377.696927547455,2.971553564071656,31551.12981677056,67805,0,31551.12981677056,0.492900013923645,2.301361322402954,10000,34935.6247048378,0.6598047018051147,1.3968747854232788,0.6152600049972534,1.627685785293579,50000 -3423.8143379688263,3.0081887245178223,31971.418674469,68712,0,31971.418674469,0.4833000302314758,2.339237689971924,10000,35402.118267059326,0.6655663847923279,1.3830251693725586,0.6067999601364136,1.658873438835144,50000 -3466.132476091385,3.0513832569122314,32391.608671188354,69620,0,32391.608671188354,0.4856000244617462,2.3167552947998047,10000,35864.72148346901,0.6632031202316284,1.383965253829956,0.610539972782135,1.6313594579696655,50000 -3512.953349590301,3.0895795822143555,32811.70836234093,70524,0,32811.70836234093,0.4950000345706939,2.2632946968078613,10000,36331.73110461235,0.6672461032867432,1.3634191751480105,0.6170399785041809,1.6041995286941528,50000 -3556.7215859889984,3.130367517471313,33231.66958928108,71426,0,33231.66958928108,0.4943000376224518,2.282932043075561,10000,36795.55163145065,0.6733202934265137,1.3281680345535278,0.6173200011253357,1.5995622873306274,50000 -3604.856119155884,3.186277151107788,33651.71494960785,72330,0,33651.71494960785,0.4918000102043152,2.269059896469116,10000,37263.838456869125,0.6946874856948853,1.2652043104171753,0.6173799633979797,1.6047745943069458,50000 -3652.995265960693,3.228080034255981,34072.08019518852,73236,0,34072.08019518852,0.4927000105381012,2.279181241989136,10000,37732.43508505821,0.6619726419448853,1.3814362287521362,0.613860011100769,1.6136529445648191,50000 -3694.383110284805,3.269738912582397,34492.027354717255,74139,0,34492.027354717255,0.5021000504493713,2.2524607181549072,10000,38193.86260414124,0.6750390529632568,1.3150864839553833,0.6202999949455261,1.581875562667847,50000 -3741.594585418701,3.3109562397003174,34912.15011835098,75045,0,34912.15011835098,0.4984000325202942,2.2771427631378174,10000,38661.28926753998,0.6972851157188416,1.2580310106277466,0.6181199550628662,1.6094812154769895,50000 -3786.706786632538,3.353020668029785,35332.28189110756,75953,0,35332.28189110756,0.5069000124931335,2.2062301635742188,10000,39126.62657475472,0.6793164014816284,1.3091703653335571,0.6315000057220459,1.5481455326080322,50000 -3836.816582202912,3.397064447402954,35752.626155376434,76861,0,35752.626155376434,0.4974000155925751,2.2438337802886963,10000,39597.17609834671,0.6729882955551147,1.3250367641448977,0.6203599572181702,1.5800641775131226,50000 -3884.51095700264,3.4398858547210693,36172.57538366318,77766,0,36172.57538366318,0.5112000107765198,2.183667659759521,10000,40064.91317510605,0.7009375095367432,1.209425687789917,0.6351000070571899,1.5253475904464722,50000 -3933.235659122467,3.48156213760376,36592.812076091766,78670,0,36592.812076091766,0.5099000334739685,2.180517196655273,10000,40533.96717476845,0.6809374690055847,1.2866195440292358,0.6292799711227417,1.5328986644744873,50000 -3981.647619247437,3.5212857723236084,37012.992997169495,79574,0,37012.992997169495,0.5047000050544739,2.2638602256774902,10000,41002.65064263344,0.6755468845367432,1.3667091131210327,0.6249200105667114,1.604951024055481,50000 -4021.62891960144,3.560465812683105,37432.94675517082,80480,0,37432.94675517082,0.5015000104904175,2.2130167484283447,10000,41462.675387859344,0.6903710961341858,1.240017652511597,0.6298399567604065,1.541815161705017,50000 -4069.3755803108215,3.600275039672852,37852.8576259613,81384,0,37852.8576259613,0.5029000043869019,2.2255892753601074,10000,41930.42304897308,0.6767382621765137,1.3347312211990356,0.6283999681472778,1.5647926330566406,50000 -4117.19310426712,3.6395747661590576,38272.83093523979,82289,0,38272.83093523979,0.5054000020027161,2.210899353027344,10000,42398.30388045311,0.6830077767372131,1.3016310930252075,0.6261599659919739,1.565454125404358,50000 -4168.097243785858,3.682982444763184,38692.91235637665,83196,0,38692.91235637665,0.5112000107765198,2.164104700088501,10000,42869.38398528099,0.6936327815055847,1.2280399799346924,0.6358000040054321,1.5165338516235352,50000 -4214.460152864456,3.722850561141968,39113.014108896255,84100,0,39113.014108896255,0.513700008392334,2.1791257858276367,10000,43335.9387626648,0.6860546469688416,1.281497359275818,0.638260006904602,1.5209600925445557,50000 -4262.848652362824,3.765680074691773,39533.14155960083,85007,0,39533.14155960083,0.5118000507354736,2.1877145767211914,10000,43804.54938220978,0.6898437142372131,1.2605102062225342,0.6386399865150452,1.5136559009552002,50000 -4310.929792881012,3.8074779510498047,39953.23417115212,85914,0,39953.23417115212,0.5211000442504883,2.127619743347168,10000,44272.816187381744,0.6995507478713989,1.186415195465088,0.642300009727478,1.469892501831055,50000 -4358.529216766357,3.8530352115631104,40373.50329899788,86817,0,40373.50329899788,0.51910001039505,2.128563165664673,10000,44740.78130793572,0.6941601634025574,1.2347404956817627,0.6406199932098389,1.486888766288757,50000 -4404.2962164878845,3.895873308181762,40793.42196941376,87720,0,40793.42196941376,0.5139000415802002,2.182572841644287,10000,45206.56125569344,0.69593745470047,1.2813823223114014,0.6403999924659729,1.535065770149231,50000 -4456.234867095947,3.935400485992432,41213.36802792549,88625,0,41213.36802792549,0.5161000490188599,2.1546852588653564,10000,45678.53724288941,0.7049609422683716,1.2156800031661987,0.6411799788475037,1.4978916645050049,50000 -4504.405340433121,3.975401163101196,41633.58397102356,89532,0,41633.58397102356,0.5234000086784363,2.139137983322144,10000,46147.01468038559,0.7158789038658142,1.1632670164108276,0.6412999629974365,1.5002963542938232,50000 -4551.505216121674,4.02495002746582,42053.79886484146,90437,0,42053.79886484146,0.5220000147819519,2.13555645942688,10000,46614.43021249771,0.6942968368530273,1.242924690246582,0.6450600028038025,1.48419451713562,50000 -4598.104979038239,4.066986322402954,42474.13457107544,91344,0,42474.13457107544,0.5290000438690186,2.1351263523101807,10000,47081.45916390419,0.7076367139816284,1.2079362869262695,0.644320011138916,1.490636110305786,50000 -4647.107894182205,4.110424280166626,42894.400421381,92246,0,42894.400421381,0.5223000049591064,2.1604013442993164,10000,47550.82168245316,0.7279687523841858,1.1314303874969482,0.6431199908256531,1.4939802885055542,50000 -4695.176491975784,4.155153512954712,43314.599668979645,93152,0,43314.599668979645,0.5295000076293945,2.0929665565490723,10000,48019.18526363373,0.7049609422683716,1.1868587732315063,0.6511799693107605,1.4446868896484375,50000 -4743.113221168518,4.200689315795898,43734.53122782707,94057,0,43734.53122782707,0.5242000222206116,2.119861364364624,10000,48487.1493666172,0.7050195336341858,1.1846143007278442,0.6482399702072144,1.4538512229919434,50000 -4788.12454199791,4.246379375457764,44154.71737384796,94962,0,44154.71737384796,0.527400016784668,2.117228507995605,10000,48952.4436519146,0.7193359136581421,1.1572200059890747,0.6502999663352966,1.47944438457489,50000 -4836.9977996349335,4.29305100440979,44575.00277280808,95868,0,44575.00277280808,0.531000018119812,2.119685411453247,10000,49421.69991064072,0.7100390195846558,1.2094417810440063,0.6509999632835388,1.471266269683838,50000 -4885.990926504135,4.338782787322998,44995.451451301575,96774,0,44995.451451301575,0.525600016117096,2.174257516860962,10000,49891.23855257034,0.7082812190055847,1.2541940212249756,0.6438999772071838,1.5317820310592651,50000 -4933.726239919663,4.382209062576294,45415.57292270661,97681,0,45415.57292270661,0.5313000082969666,2.074703931808472,10000,50359.1891078949,0.7260546684265137,1.110514760017395,0.6534599661827087,1.432780385017395,50000 -4977.804324626923,4.428659439086914,45835.61195087433,98584,0,45835.61195087433,0.5313000082969666,2.124727725982666,10000,50823.40377354622,0.7095507383346558,1.2087146043777466,0.6518399715423584,1.4704029560089111,50000 -5025.0931096076965,4.4764626026153564,46255.78856515885,99487,0,46255.78856515885,0.5333000421524048,2.0649092197418213,10000,51290.96784090996,0.7159179449081421,1.143684983253479,0.660539984703064,1.4055538177490234,50000 -5073.639664173126,4.521659851074219,46676.1024954319,100394,0,46676.1024954319,0.5331000089645386,2.06860876083374,10000,51759.92477321625,0.7255663871765137,1.114113211631775,0.6623600125312805,1.4175740480422974,50000 -5120.609957456589,4.565702676773071,47096.401584386826,101303,0,47096.401584386826,0.5372000336647034,2.0668838024139404,10000,52227.28931951523,0.7194140553474426,1.1470216512680054,0.6613399982452393,1.3947124481201172,50000 -5165.949766874313,4.609235286712647,47516.692006111145,102211,0,47516.692006111145,0.5344000458717346,2.0949273109436035,10000,52693.0149166584,0.7141796946525574,1.1836493015289309,0.6551799774169922,1.4438341856002808,50000 -5214.342481613159,4.651084899902344,47937.02245235443,103115,0,47937.02245235443,0.5384000539779663,2.0607264041900635,10000,53161.830530405045,0.7273827791213989,1.117196559906006,0.66211998462677,1.4077105522155762,50000 -5259.345903396606,4.697351932525635,48357.16518211365,104018,0,48357.16518211365,0.5418000221252441,2.032897472381592,10000,53627.07335996628,0.7233007550239563,1.127516508102417,0.6676799654960632,1.3816344738006592,50000 -5306.26788520813,4.742557764053345,48777.29816532135,104923,0,48777.29816532135,0.538100004196167,2.138030767440796,10000,54094.22512769699,0.7226171493530273,1.2302206754684448,0.6633599996566772,1.492753267288208,50000 -5351.597653627396,4.792775630950928,49197.20958852768,105828,0,49197.20958852768,0.5491999983787537,2.021711826324463,10000,54559.56789302826,0.7344140410423279,1.0741634368896484,0.6691799759864807,1.373185396194458,50000 -5400.683298826218,4.839393138885498,49617.28916144371,106733,0,49617.28916144371,0.5445000529289246,2.057142734527588,10000,55028.830676317215,0.7335546612739563,1.0971087217330933,0.6657599806785583,1.4038810729980469,50000 -5444.214246273041,4.885617017745972,50037.69368457794,107638,0,50037.69368457794,0.5491999983787537,1.9994500875473025,10000,55492.862330675125,0.7317773103713989,1.0785207748413086,0.6732400059700012,1.3483539819717407,50000 -5492.055480241776,4.9278857707977295,50458.04724955559,108542,0,50458.04724955559,0.5520000457763672,2.020024299621582,10000,55961.15143156052,0.7343358993530273,1.076116919517517,0.6702600121498108,1.35907781124115,50000 -5539.1366448402405,4.976301908493042,50877.9556055069,109444,0,50877.9556055069,0.5466000437736511,2.0300540924072266,10000,56428.23986721039,0.7549999952316284,1.0015008449554443,0.6709199547767639,1.374889612197876,50000 -5586.264869213104,5.022613763809204,51298.14277744293,110352,0,51298.14277744293,0.5463000535964966,1.9860961437225344,10000,56895.654056310654,0.7322655916213989,1.0748839378356934,0.6771999597549438,1.3334211111068726,50000 -5634.615519762039,5.071029901504517,51718.23906922341,111258,0,51718.23906922341,0.5458000302314758,2.025487184524536,10000,57364.20047616959,0.7407421469688416,1.0589429140090942,0.6723399758338928,1.357522964477539,50000 -5682.680174589157,5.116419792175293,52138.199348688126,112164,0,52138.199348688126,0.550000011920929,1.9666587114334104,10000,57832.32154393196,0.7571874856948853,0.9590351581573486,0.6782799959182739,1.3155368566513062,50000 -5730.1758716106415,5.1636152267456055,52558.10915327072,113066,0,52558.10915327072,0.5481000542640686,2.024175882339477,10000,58299.82621335983,0.7348827719688416,1.1102678775787354,0.6744199991226196,1.374962568283081,50000 -5778.074056625366,5.209723234176636,52978.40332150459,113970,0,52978.40332150459,0.549500048160553,2.0046536922454834,10000,58768.11574149132,0.7422655820846558,1.044321060180664,0.6773599982261658,1.3351013660430908,50000 -5821.660964488983,5.261880159378052,53398.46142745018,114873,0,53398.46142745018,0.5515000224113464,1.9546692371368408,10000,59231.86329960823,0.7559765577316284,0.9663622975349426,0.6796199679374695,1.3073861598968506,50000 -5870.175407886505,5.30842137336731,53818.511486291885,115777,0,53818.511486291885,0.5634000301361084,1.946633100509644,10000,59700.5247297287,0.7392773032188416,1.056602954864502,0.6826800107955933,1.3141595125198364,50000 -5919.104496240616,5.356947898864746,54238.54416680336,116682,0,54238.54416680336,0.5532000064849854,1.9711726903915403,10000,60169.58668136597,0.7460546493530273,1.0224932432174685,0.6796199679374695,1.3186691999435425,50000 -5966.191679239273,5.405437469482422,54658.53148937225,117587,0,54658.53148937225,0.5639000535011292,1.97365140914917,10000,60636.760419130325,0.7542187571525574,1.0017627477645874,0.6813200116157532,1.3278917074203491,50000 -6014.26641201973,5.455573797225952,55078.50196790695,118493,0,55078.50196790695,0.55840003490448,1.9573954343795776,10000,61104.90820026398,0.7477148175239563,1.0368415117263794,0.6837999820709229,1.3157402276992798,50000 -6056.332176923752,5.500463962554932,55498.71514606476,119399,0,55498.71514606476,0.5649000406265259,1.9199223518371584,10000,61567.28264904022,0.7528710961341858,0.9768688082695008,0.6881799697875977,1.2712483406066897,50000 -6102.064264535904,5.548049688339233,55918.85990691185,120301,0,55918.85990691185,0.5645000338554382,1.9446827173233032,10000,62033.25811219216,0.7587109208106995,0.9645220637321472,0.6875999569892883,1.2869911193847656,50000 -6149.277095556259,5.596048831939697,56338.76570916176,121206,0,56338.76570916176,0.5652000308036804,1.9142930507659912,10000,62500.47589445114,0.749804675579071,0.995943248271942,0.6902399659156799,1.2766941785812378,50000 -6198.029419660568,5.6411542892456055,56758.85825443268,122110,0,56758.85825443268,0.5681000351905823,1.9204736948013303,10000,62969.41731142998,0.7586718797683716,0.9999610781669616,0.6941399574279785,1.28845477104187,50000 -6246.399075984955,5.687482833862305,57178.78510522842,123017,0,57178.78510522842,0.5654000043869019,1.9345817565917969,10000,63437.81168818474,0.7599804401397705,0.9783530235290528,0.6932199597358704,1.2936484813690186,50000 -6293.630912065506,5.734906911849976,57598.70680522919,123923,0,57598.70680522919,0.5713000297546387,1.8837766647338867,10000,63905.06403899193,0.7659960985183716,0.934769570827484,0.6953999996185303,1.2383933067321775,50000 -6343.900760173798,5.780289649963379,58018.94120192528,124829,0,58018.94120192528,0.5699000358581543,1.8792723417282104,10000,64375.66458725929,0.763378918170929,0.9280872344970704,0.6961399912834167,1.2279818058013916,50000 -6391.183966159821,5.8267786502838135,58438.90938568115,125733,0,58438.90938568115,0.5800999999046326,1.8638954162597656,10000,64843.013377428055,0.7677538990974426,0.924182951450348,0.6993599534034729,1.229907751083374,50000 -6440.705798387528,5.875086545944214,58859.0885219574,126635,0,58859.0885219574,0.5753000378608704,1.8391834497451784,10000,65312.81332349777,0.78236323595047,0.8538432121276855,0.7017399668693542,1.2105683088302612,50000 -6487.029499053955,5.920918226242065,59279.26599240303,127542,0,59279.26599240303,0.5718000531196594,1.9025813341140747,10000,65779.41101312637,0.759082019329071,0.964400053024292,0.6955400109291077,1.2564650774002075,50000 -6533.84437918663,5.970127820968628,59699.49506402016,128448,0,59699.49506402016,0.5759000182151794,1.8524872064590447,10000,66246.55462884903,0.774609386920929,0.8842771053314209,0.701259970664978,1.2053685188293457,50000 -6579.178344249725,6.016943454742432,60119.54176044464,129353,0,60119.54176044464,0.5803000330924988,1.8332209587097168,10000,66712.03279447556,0.7856835722923279,0.8250307440757751,0.7035599946975708,1.1953197717666626,50000 -6627.108244419098,6.065059423446655,60539.66762089729,130258,0,60539.66762089729,0.5763000249862671,1.872269868850708,10000,67180.1879966259,0.7689648270606995,0.9314327239990234,0.7040199637413025,1.226270079612732,50000 -6674.615309238434,6.117140293121338,60959.68307638168,131165,0,60959.68307638168,0.589400053024292,1.8679510354995728,10000,67647.81341028214,0.7744531035423279,0.9301571249961852,0.7041199803352356,1.241104245185852,50000 -6719.562696695328,6.168532371520996,61379.67488145828,132069,0,61379.67488145828,0.5800000429153442,1.8685365915298464,10000,68112.85487318039,0.7873827815055847,0.8701390624046326,0.7058999538421631,1.2221981287002563,50000 -6770.016849517822,6.2154810428619385,61799.59830021858,132973,0,61799.59830021858,0.5845000147819519,1.819374918937683,10000,68583.33038425446,0.7793359160423279,0.8791748881340027,0.7087000012397766,1.1858389377593994,50000 -6816.891093969345,6.26798677444458,62219.77271294594,133878,0,62219.77271294594,0.5937000513076782,1.7851358652114868,10000,69050.48225259781,0.7819726467132568,0.849389910697937,0.7128599882125854,1.159160017967224,50000 -6866.628557682037,6.316884279251099,62639.89648962021,134782,0,62639.89648962021,0.5934000015258789,1.7752066850662231,10000,69520.44361186028,0.792773425579071,0.7919089198112488,0.7138599753379822,1.1439038515090942,50000 -6910.820559263229,6.79111123085022,63059.385607004166,135682,0,63059.385607004166,0.5867000222206116,1.83801019191742,10000,69984.6502289772,0.7776952981948853,0.8922684788703918,0.7090799808502197,1.2007648944854736,50000 -6955.858748197556,6.841918230056763,63479.35899710655,136587,0,63479.35899710655,0.5918000340461731,1.7970510721206665,10000,70449.76348471642,0.7864648103713989,0.8539221286773682,0.715499997138977,1.170691967010498,50000 -7005.048368930817,6.893892765045166,63899.612374305725,137494,0,63899.612374305725,0.5878000259399414,1.812727212905884,10000,70919.30957722664,0.79066401720047,0.8283480405807495,0.7112599611282349,1.1805391311645508,50000 -7052.7244708538055,6.946126461029053,64319.660131692886,138400,0,64319.660131692886,0.5919000506401062,1.8064581155776973,10000,71387.13676571846,0.7880859375,0.8579350709915161,0.7173199653625488,1.1725926399230957,50000 -7100.193810462952,6.996117830276489,64739.67565011978,139301,0,64739.67565011978,0.5958999991416931,1.7701454162597656,10000,71854.72164583206,0.7901171445846558,0.8141428828239441,0.7179799675941467,1.1384248733520508,50000 -7149.690311908722,7.050333499908447,65159.74625015259,140204,0,65159.74625015259,0.5996000170707703,1.7611976861953735,10000,72324.3936612606,0.8003124594688416,0.7800408601760864,0.717739999294281,1.1377885341644287,50000 -7194.6404457092285,7.10044264793396,65579.95306015015,141109,0,65579.95306015015,0.5931000113487244,1.7584959268569946,10000,72789.65150928497,0.7925195097923279,0.8001589179039001,0.7177000045776367,1.1277657747268677,50000 -7244.440073251724,7.150071620941162,65999.96157240868,142014,0,65999.96157240868,0.6005000472068787,1.7380754947662354,10000,73259.56003165245,0.8015038967132568,0.7711244821548462,0.722599983215332,1.1164039373397827,50000 -7293.018522024155,7.19990348815918,66420.08053159714,142915,0,66420.08053159714,0.5999000072479248,1.7464487552642822,10000,73728.358700037,0.8013671636581421,0.7719994187355042,0.721340000629425,1.121065616607666,50000 -7341.1628839969635,7.251514196395874,66840.15920996666,143819,0,66840.15920996666,0.5993000268936157,1.7647473812103271,10000,74196.6847281456,0.8026562333106995,0.7779040932655334,0.7221399545669556,1.1373088359832764,50000 -7385.170194864273,7.303040027618408,67260.12916207314,144724,0,67260.12916207314,0.5981000065803528,1.7557260990142822,10000,74660.76443743706,0.7992382645606995,0.7773339748382568,0.7215799689292908,1.1161545515060425,50000 -7432.81751871109,7.352954149246216,67680.05412721634,145628,0,67680.05412721634,0.6051000356674194,1.739145040512085,10000,75128.43763279915,0.80517578125,0.7599520683288574,0.7254599928855896,1.1081088781356812,50000 -7477.329308271408,7.4053473472595215,68100.3693766594,146535,0,68100.3693766594,0.6033000349998474,1.7289841175079346,10000,75593.36808228493,0.8170703053474426,0.7173280119895935,0.7281399965286255,1.1051756143569946,50000 -7523.654905796051,7.458659648895264,68520.51922082901,147440,0,68520.51922082901,0.601900041103363,1.774980902671814,10000,76059.94834542274,0.8044531345367432,0.7989120483398438,0.7260000109672546,1.1396101713180542,50000 -7567.843139410019,7.50982141494751,68940.56979131699,148347,0,68940.56979131699,0.6075000166893005,1.727489352226257,10000,76524.28863739967,0.8073046803474426,0.7472729086875916,0.7283799648284912,1.0935417413711548,50000 -7611.454800367355,7.564750909805298,69360.65314507484,149250,0,69360.65314507484,0.6032000184059143,1.7510703802108765,10000,76988.09020090103,0.81214839220047,0.7340012192726135,0.7254999876022339,1.1147717237472534,50000 -7655.94317984581,7.613767862319946,69780.64332485199,150156,0,69780.64332485199,0.6111000180244446,1.714582920074463,10000,77452.66871452332,0.8089648485183716,0.7477630972862244,0.7283399701118469,1.0934544801712036,50000 -7701.411295890808,7.668791770935059,70200.61243200302,151062,0,70200.61243200302,0.6077000498771667,1.7483646869659424,10000,77918.21200990677,0.8106249570846558,0.7664503455162048,0.7301999926567078,1.119113564491272,50000 -7748.537898540497,7.726431369781494,70620.5970556736,151966,0,70620.5970556736,0.6164000034332275,1.7178739309310913,10000,78385.43194818497,0.8160937428474426,0.7283310294151306,0.7328400015830994,1.0975955724716189,50000 -7791.45180106163,7.7803356647491455,71040.94093084335,152872,0,71040.94093084335,0.6099000573158264,1.705723762512207,10000,78848.79572200775,0.8132226467132568,0.7306903004646301,0.7350199818611145,1.078985333442688,50000 -7838.883292198181,7.837963819503784,71460.94529819489,153775,0,71460.94529819489,0.6076000332832336,1.7089613676071167,10000,79316.3397808075,0.81068354845047,0.7280431389808655,0.7346599698066711,1.074668526649475,50000 -7887.490196943283,7.888981103897095,71881.10067653656,154682,0,71881.10067653656,0.6127000451087952,1.691803216934204,10000,79785.20401096344,0.8229491710662842,0.6811491250991821,0.7339999675750732,1.0581152439117432,50000 -7935.300303936005,7.945384502410889,72301.3064084053,155586,0,72301.3064084053,0.6158000230789185,1.7009111642837524,10000,80253.32729268074,0.8204687237739563,0.7243859767913818,0.7372999787330627,1.0768547058105469,50000 -7985.649563550949,7.997928142547607,72721.32059788704,156487,0,72721.32059788704,0.615600049495697,1.6857863664627075,10000,80723.79385137558,0.8235155940055847,0.6817896962165833,0.7389400005340576,1.045330047607422,50000 -8030.373905658722,8.049262046813965,73141.60916376114,157392,0,73141.60916376114,0.6150000095367432,1.6935813426971436,10000,81188.90914106369,0.8236913681030273,0.6922850012779236,0.7382199764251709,1.066370725631714,50000 -8078.557241678238,8.106578588485718,73561.83675098419,158297,0,73561.83675098419,0.6205000281333923,1.6642597913742063,10000,81657.4282822609,0.8206835985183716,0.6895581483840942,0.7396399974822998,1.041470289230347,50000 -8123.858737468719,8.167153358459473,73981.86839866638,159204,0,73981.86839866638,0.6186000108718872,1.7149657011032104,10000,82122.87282919884,0.82533198595047,0.7197791934013367,0.7388399839401245,1.0863151550292969,50000 -8168.378638029098,8.218732833862305,74401.79539680481,160110,0,74401.79539680481,0.6185000538825989,1.684755802154541,10000,82587.42238020897,0.8278124928474426,0.6878633499145508,0.7412199974060059,1.0593788623809814,50000 -8217.548476457596,8.271214008331299,74821.9603896141,161015,0,74821.9603896141,0.6220000386238098,1.668728590011597,10000,83056.85979914665,0.82972651720047,0.6675693392753601,0.7438399791717529,1.0392054319381714,50000 -8264.541722774506,8.323035478591919,75242.30540776253,161921,0,75242.30540776253,0.6206000447273254,1.6684740781784058,10000,83524.30050992966,0.8286718726158142,0.672528088092804,0.7439000010490417,1.04596745967865,50000 -8309.82877087593,8.38948941230774,75662.24251437187,162826,0,75662.24251437187,0.6234000325202942,1.676736831665039,10000,83989.64229750633,0.8308398127555847,0.6795992851257324,0.743399977684021,1.0521153211593628,50000 -8358.68628025055,8.442897319793701,76082.43580412865,163733,0,76082.43580412865,0.6276000142097473,1.6446024179458618,10000,84458.79799222946,0.8407226204872131,0.6239649653434753,0.7476399540901184,1.0164140462875366,50000 -8406.580387592316,8.498199224472046,76502.39468336105,164639,0,76502.39468336105,0.6267000436782837,1.6496158838272097,10000,84926.75717353821,0.8356835842132568,0.6445589065551758,0.7468599677085876,1.0233219861984253,50000 -8451.303115844727,8.553112983703613,76922.57128500938,165546,0,76922.57128500938,0.6241000294685364,1.640983819961548,10000,85391.76311731339,0.8334375023841858,0.6440165638923645,0.7469399571418762,1.0224330425262451,50000 -8500.320755720139,8.607258081436157,77342.91505241394,166451,0,77342.91505241394,0.6309000253677368,1.6426352262496948,10000,85861.23006987572,0.8373242020606995,0.6317007541656494,0.7473799586296082,1.0245435237884521,50000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/measurements.csv deleted file mode 100644 index f72a4ee32..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1856 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.34916806,6.9077563,,,,,,,,,,,,,, -1,,,0.0008203124743886,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,40.46163988113403,69.32669401168823,40.46163988113403,28.86489605903625,0.0,0.0 -100,0.44360203,6.8918695,,,,,,,,,,,,,, -200,0.5918577,6.7905817,,,,,,,,,,,,,, -300,0.7995335,6.7130694,,,,,,,,,,,,,, -400,1.0912279,6.557935,,,,,,,,,,,,,, -500,0.9816435,6.525278,,,,,,,,,,,,,, -600,0.97524804,6.396842,,,,,,,,,,,,,, -700,0.78115654,6.739863,,,,,,,,,,,,,, -800,1.3334811,6.3608046,,,,,,,,,,,,,, -844,,,0.0341015607118606,5.897230625152588,0.0298199988901615,5.968258380889893,50000.0,0.021900001913309,6.086627006530762,10000.0,460.644246339798,530.1097981929779,460.644246339798,69.4005184173584,0.017852783203125,0.0 -900,1.3952488,6.1328025,,,,,,,,,,,,,, -1000,1.2681668,6.1577997,,,,,,,,,,,,,, -1100,1.3139182,6.033288,,,,,,,,,,,,,, -1200,1.2634848,5.9491987,,,,,,,,,,,,,, -1300,1.0638866,6.0350623,,,,,,,,,,,,,, -1400,1.1219183,6.052869,,,,,,,,,,,,,, -1500,1.0971411,6.0787454,,,,,,,,,,,,,, -1600,1.1379021,5.8170657,,,,,,,,,,,,,, -1700,0.7618756,6.4211845,,,,,,,,,,,,,, -1745,,,0.0720312520861625,5.3030266761779785,0.0671200007200241,5.3553972244262695,50000.0,0.0512000024318695,5.565213203430176,10000.0,880.6717927455902,993.2836394309998,880.6717927455902,112.46707344055176,0.0463771820068359,0.0 -1800,1.1030118,6.320637,,,,,,,,,,,,,, -1900,1.0349852,6.65249,,,,,,,,,,,,,, -2000,0.98351836,5.614855,,,,,,,,,,,,,, -2100,0.9788415,5.6233916,,,,,,,,,,,,,, -2200,1.2596825,5.593454,,,,,,,,,,,,,, -2300,1.0677398,5.496742,,,,,,,,,,,,,, -2400,1.3896489,5.500402,,,,,,,,,,,,,, -2500,1.0889136,5.4818277,,,,,,,,,,,,,, -2600,1.0291852,5.3157864,,,,,,,,,,,,,, -2649,,,0.1213671863079071,4.764578342437744,0.1108599975705146,4.83726167678833,50000.0,0.0862000063061714,5.133504390716553,10000.0,1300.8934905529022,1458.8881137371063,1300.8934905529022,157.77245998382568,0.0728633403778076,0.0 -2700,0.8529948,6.2359996,,,,,,,,,,,,,, -2800,1.288594,5.2754025,,,,,,,,,,,,,, -2900,1.0545619,5.2855587,,,,,,,,,,,,,, -3000,0.8948273,5.433722,,,,,,,,,,,,,, -3100,0.83163136,5.196929,,,,,,,,,,,,,, -3200,1.2006154,6.36836,,,,,,,,,,,,,, -3300,1.0066563,5.0626063,,,,,,,,,,,,,, -3400,1.1036518,5.1737795,,,,,,,,,,,,,, -3500,1.0003617,5.5543656,,,,,,,,,,,,,, -3552,,,0.1814257800579071,4.194916248321533,0.1604000031948089,4.344925403594971,50000.0,0.1246000081300735,4.737356662750244,10000.0,1720.9098417758942,1923.949401378632,1720.9098417758942,202.7385392189026,0.1011176109313964,0.0 -3600,0.6704508,6.0586944,,,,,,,,,,,,,, -3700,1.0349457,4.8741474,,,,,,,,,,,,,, -3800,0.9550288,5.001726,,,,,,,,,,,,,, -3900,0.91601515,5.107234,,,,,,,,,,,,,, -4000,0.7848474,4.921911,,,,,,,,,,,,,, -4100,0.914854,4.8918056,,,,,,,,,,,,,, -4200,0.8592373,4.704129,,,,,,,,,,,,,, -4300,1.093174,4.728849,,,,,,,,,,,,,, -4400,0.6328603,6.0809627,,,,,,,,,,,,,, -4456,,,0.2257421761751175,3.862603664398194,0.2096999883651733,3.967260837554932,50000.0,0.1565000116825103,4.419209003448486,10000.0,2141.054492235184,2386.887092113495,2141.054492235184,245.4547533988953,0.1272704601287841,0.0 -4500,0.84871393,4.735203,,,,,,,,,,,,,, -4600,0.9934407,4.646956,,,,,,,,,,,,,, -4700,0.85977876,4.650463,,,,,,,,,,,,,, -4800,0.6869458,6.228036,,,,,,,,,,,,,, -4900,0.810125,4.5971966,,,,,,,,,,,,,, -5000,0.8106538,4.8528786,,,,,,,,,,,,,, -5100,0.79457444,4.6024647,,,,,,,,,,,,,, -5200,1.1726747,4.450419,,,,,,,,,,,,,, -5300,0.7723103,6.0962973,,,,,,,,,,,,,, -5359,,,0.2718164026737213,3.585689306259156,0.2515200078487396,3.714376211166382,50000.0,0.1880000084638595,4.178256511688232,10000.0,2560.998118877411,2851.177732467652,2560.998118877411,289.7249255180359,0.1532239913940429,0.0 -5400,1.0658298,6.3424892,,,,,,,,,,,,,, -5500,0.77038014,5.10406,,,,,,,,,,,,,, -5600,0.9342071,4.361048,,,,,,,,,,,,,, -5700,0.7967261,4.877899,,,,,,,,,,,,,, -5800,0.6337284,5.3219957,,,,,,,,,,,,,, -5900,0.7939241,4.5080023,,,,,,,,,,,,,, -6000,0.78327614,4.9131656,,,,,,,,,,,,,, -6100,0.89777124,4.2847767,,,,,,,,,,,,,, -6200,0.6607839,5.01295,,,,,,,,,,,,,, -6257,,,0.3246679604053497,3.229868173599243,0.2886599898338318,3.425157070159912,50000.0,0.2222000062465667,3.962292432785034,10000.0,2980.9549465179443,3313.470268011093,2980.9549465179443,331.9820320606232,0.1811397075653076,0.0 -6300,0.74377954,4.203767,,,,,,,,,,,,,, -6400,0.7821588,4.406228,,,,,,,,,,,,,, -6500,0.80293274,4.6246243,,,,,,,,,,,,,, -6600,0.8083724,5.01937,,,,,,,,,,,,,, -6700,0.5613438,6.117232,,,,,,,,,,,,,, -6800,0.7120822,4.608323,,,,,,,,,,,,,, -6900,0.8466134,3.9985619,,,,,,,,,,,,,, -7000,0.7214173,6.153179,,,,,,,,,,,,,, -7100,0.8199877,6.134425,,,,,,,,,,,,,, -7164,,,0.3366992175579071,3.113884210586548,0.3124600052833557,3.242727041244507,50000.0,0.2367000132799148,3.815329551696777,10000.0,3400.9355845451355,3776.551196575165,3400.9355845451355,375.0047626495361,0.2081298828125,0.0 -7200,0.785505,6.0704384,,,,,,,,,,,,,, -7300,0.6133499,5.939592,,,,,,,,,,,,,, -7400,1.001699,3.9619126,,,,,,,,,,,,,, -7500,0.785079,4.975891,,,,,,,,,,,,,, -7600,0.7696901,4.0122705,,,,,,,,,,,,,, -7700,0.82321435,4.0371547,,,,,,,,,,,,,, -7800,1.0265274,4.0500093,,,,,,,,,,,,,, -7900,0.6170253,5.201285,,,,,,,,,,,,,, -8000,1.0765839,3.941787,,,,,,,,,,,,,, -8068,,,0.3708398342132568,2.9305522441864014,0.338919997215271,3.110529899597168,50000.0,0.2571000158786773,3.680340051651001,10000.0,3821.070417881012,4241.460793018341,3821.070417881012,419.699774980545,0.2372419834136963,0.0 -8100,0.83997697,3.9420376,,,,,,,,,,,,,, -8200,0.50479174,5.9545856,,,,,,,,,,,,,, -8300,0.7928471,4.1151257,,,,,,,,,,,,,, -8400,0.86331034,3.887035,,,,,,,,,,,,,, -8500,0.85777515,3.8356318,,,,,,,,,,,,,, -8600,0.9933428,4.065798,,,,,,,,,,,,,, -8700,0.66635823,4.603259,,,,,,,,,,,,,, -8800,0.85319865,3.8046315,,,,,,,,,,,,,, -8900,0.753153,5.3062015,,,,,,,,,,,,,, -8974,,,0.3893359303474426,2.798448085784912,0.3562199771404266,2.991145610809326,50000.0,0.2725000083446502,3.5983641147613525,10000.0,4241.001017093658,4700.1675000190735,4241.001017093658,458.3977072238922,0.2645635604858398,0.0 -9000,0.7440718,3.857831,,,,,,,,,,,,,, -9100,0.7013073,6.014196,,,,,,,,,,,,,, -9200,0.68078375,5.9081116,,,,,,,,,,,,,, -9300,0.63570327,5.6338215,,,,,,,,,,,,,, -9400,0.5640709,5.927367,,,,,,,,,,,,,, -9500,0.6875703,5.9891357,,,,,,,,,,,,,, -9600,0.8009371,4.04372,,,,,,,,,,,,,, -9700,1.380053,3.8972116,,,,,,,,,,,,,, -9800,1.0116878,3.8253396,,,,,,,,,,,,,, -9879,,,0.405078113079071,2.719622850418091,0.3754999935626983,2.876948118209839,50000.0,0.2928000092506408,3.460809230804444,10000.0,4661.186886072159,5162.836849212647,4661.186886072159,500.7986707687378,0.2959358692169189,0.0 -9900,0.95720536,3.7325034,,,,,,,,,,,,,, -10000,0.89983195,3.7491753,,,,,,,,,,,,,, -10100,1.2818279,4.9357414,,,,,,,,,,,,,, -10200,0.66435486,5.3017454,,,,,,,,,,,,,, -10300,0.86877865,3.722014,,,,,,,,,,,,,, -10400,0.98700917,4.103431,,,,,,,,,,,,,, -10500,0.9021738,3.605653,,,,,,,,,,,,,, -10600,0.87770057,4.110489,,,,,,,,,,,,,, -10700,0.8297798,4.4662805,,,,,,,,,,,,,, -10783,,,0.4206640422344208,2.6827948093414307,0.3864399790763855,2.8488309383392334,50000.0,0.3021000027656555,3.441816568374634,10000.0,5081.527356147766,5628.701157808304,5081.527356147766,546.2405309677124,0.3265047073364258,0.0 -10800,0.96577203,5.954132,,,,,,,,,,,,,, -10900,0.9642635,3.7162917,,,,,,,,,,,,,, -11000,0.86983716,4.1107078,,,,,,,,,,,,,, -11100,1.1149484,3.7572362,,,,,,,,,,,,,, -11200,0.69549555,5.65575,,,,,,,,,,,,,, -11300,0.9800062,3.9244933,,,,,,,,,,,,,, -11400,0.6388872,5.593734,,,,,,,,,,,,,, -11500,0.9718998,3.5360858,,,,,,,,,,,,,, -11600,0.9143522,3.5948892,,,,,,,,,,,,,, -11688,,,0.4415038824081421,2.5091471672058105,0.405379980802536,2.7154462337493896,50000.0,0.3106000125408172,3.323500633239746,10000.0,5501.6132843494415,6090.867266654968,5501.6132843494415,588.2387778759003,0.3575336933135986,0.0 -11700,0.8492387,3.629385,,,,,,,,,,,,,, -11800,0.941833,3.7205884,,,,,,,,,,,,,, -11900,0.94983983,3.5205514,,,,,,,,,,,,,, -12000,0.85072774,3.5355654,,,,,,,,,,,,,, -12100,0.95049804,3.4281118,,,,,,,,,,,,,, -12200,0.73862636,5.6670084,,,,,,,,,,,,,, -12300,0.6644914,4.921721,,,,,,,,,,,,,, -12400,0.598598,5.7305703,,,,,,,,,,,,,, -12500,0.9413764,3.5176215,,,,,,,,,,,,,, -12590,,,0.4447070360183716,2.521016836166382,0.4148999750614166,2.6761584281921387,50000.0,0.3197000026702881,3.3049449920654297,10000.0,5922.011283874512,6558.102586507797,5922.011283874512,634.9937858581543,0.3894164562225342,0.0 -12600,0.85992104,5.8551474,,,,,,,,,,,,,, -12700,1.0816737,3.6405535,,,,,,,,,,,,,, -12800,0.9285709,3.6088865,,,,,,,,,,,,,, -12900,1.0220225,3.4657369,,,,,,,,,,,,,, -13000,0.71121156,5.842996,,,,,,,,,,,,,, -13100,0.99409115,3.754823,,,,,,,,,,,,,, -13200,1.1378882,3.5460696,,,,,,,,,,,,,, -13300,0.925706,3.4431856,,,,,,,,,,,,,, -13400,0.9392374,3.7132328,,,,,,,,,,,,,, -13496,,,0.4694921672344208,2.356694221496582,0.435479998588562,2.5276846885681152,50000.0,0.3357000052928924,3.1664600372314453,10000.0,6342.245297193527,7018.754692077637,6342.245297193527,675.3277657032013,0.4220564365386963,0.0 -13500,0.81010926,4.8434005,,,,,,,,,,,,,, -13600,1.1599064,3.4103653,,,,,,,,,,,,,, -13700,1.0510097,3.3908327,,,,,,,,,,,,,, -13800,1.1210108,3.476776,,,,,,,,,,,,,, -13900,0.8100934,4.456892,,,,,,,,,,,,,, -14000,0.8774613,3.8379338,,,,,,,,,,,,,, -14100,0.980321,3.3721683,,,,,,,,,,,,,, -14200,0.9247099,3.3326747,,,,,,,,,,,,,, -14300,1.0608996,3.5451872,,,,,,,,,,,,,, -14400,,,0.4701952934265136,2.3852572441101074,0.43121999502182,2.579622268676758,50000.0,0.3365000188350677,3.197645664215088,10000.0,6762.292719364166,7483.901390790939,6762.292719364166,720.3443946838379,0.4542615413665771,0.0 -14400,0.62283266,5.2848907,,,,,,,,,,,,,, -14500,1.0757194,3.3935256,,,,,,,,,,,,,, -14600,1.006991,3.2891047,,,,,,,,,,,,,, -14700,0.83614355,5.512132,,,,,,,,,,,,,, -14800,0.9971668,3.2917533,,,,,,,,,,,,,, -14900,0.6900709,5.3893127,,,,,,,,,,,,,, -15000,0.8618313,5.7540846,,,,,,,,,,,,,, -15100,0.9125649,3.2743769,,,,,,,,,,,,,, -15200,1.0895284,3.175331,,,,,,,,,,,,,, -15300,0.92790776,3.626689,,,,,,,,,,,,,, -15307,,,0.4783593714237213,2.331885814666748,0.4416399896144867,2.534491539001465,50000.0,0.3420000076293945,3.1562328338623047,10000.0,7182.435878038406,7950.122545957565,7182.435878038406,766.3428432941437,0.4822478294372558,0.0 -15400,1.110404,3.4285376,,,,,,,,,,,,,, -15500,0.8245906,3.7627015,,,,,,,,,,,,,, -15600,1.0594913,3.274434,,,,,,,,,,,,,, -15700,0.9886884,3.687938,,,,,,,,,,,,,, -15800,0.7171176,4.916171,,,,,,,,,,,,,, -15900,1.053988,3.2581835,,,,,,,,,,,,,, -16000,1.1479266,3.2629473,,,,,,,,,,,,,, -16100,1.0534431,3.332883,,,,,,,,,,,,,, -16200,0.88333124,3.9687886,,,,,,,,,,,,,, -16215,,,0.4888867139816284,2.2850089073181152,0.4542399942874908,2.4670143127441406,50000.0,0.3481000065803528,3.092832088470459,10000.0,7602.47861289978,8417.176041603088,7602.47861289978,813.2723207473755,0.5117042064666748,0.0 -16300,0.990678,3.3660686,,,,,,,,,,,,,, -16400,1.2097372,3.5828514,,,,,,,,,,,,,, -16500,0.99822944,3.8252263,,,,,,,,,,,,,, -16600,1.067413,3.2473993,,,,,,,,,,,,,, -16700,0.98817116,3.741584,,,,,,,,,,,,,, -16800,0.9397366,4.0011578,,,,,,,,,,,,,, -16900,1.0075245,3.2756283,,,,,,,,,,,,,, -17000,0.797722,4.203371,,,,,,,,,,,,,, -17100,0.9891348,3.2023997,,,,,,,,,,,,,, -17122,,,0.4998632669448852,2.2155532836914062,0.4608399868011474,2.4244303703308105,50000.0,0.3571000099182129,3.0820775032043457,10000.0,8022.711845397949,8880.401599168777,8022.711845397949,856.1803793907166,0.5442097187042236,0.0 -17200,1.0935954,3.2803636,,,,,,,,,,,,,, -17300,1.0820216,3.3162112,,,,,,,,,,,,,, -17400,1.051218,3.3161614,,,,,,,,,,,,,, -17500,0.8712337,4.3669233,,,,,,,,,,,,,, -17600,0.76581186,5.0833015,,,,,,,,,,,,,, -17700,1.0685524,3.3119974,,,,,,,,,,,,,, -17800,0.9320235,3.2260199,,,,,,,,,,,,,, -17900,0.76169693,5.7057915,,,,,,,,,,,,,, -18000,1.1108499,3.3594818,,,,,,,,,,,,,, -18026,,,0.5294336080551147,2.071880340576172,0.4652199745178222,2.3781328201293945,50000.0,0.3575000166893005,3.0405938625335693,10000.0,8442.835622787476,9348.312663078308,8442.835622787476,903.8825159072876,0.5782530307769775,0.0 -18100,1.0344173,3.207993,,,,,,,,,,,,,, -18200,1.1893489,3.1785474,,,,,,,,,,,,,, -18300,0.9773809,3.3355184,,,,,,,,,,,,,, -18400,1.172865,3.1009183,,,,,,,,,,,,,, -18500,1.0342684,3.3670917,,,,,,,,,,,,,, -18600,0.8834873,4.2308083,,,,,,,,,,,,,, -18700,0.8441664,5.584816,,,,,,,,,,,,,, -18800,1.0319018,3.1468794,,,,,,,,,,,,,, -18900,1.0231969,3.6540759,,,,,,,,,,,,,, -18931,,,0.5103710889816284,2.167488574981689,0.4759399890899658,2.3467721939086914,50000.0,0.368800014257431,3.0019314289093018,10000.0,8862.765635967255,9813.79310321808,8862.765635967255,949.3529365062714,0.607285737991333,0.0 -19000,1.1053841,3.158403,,,,,,,,,,,,,, -19100,1.0128815,3.2241855,,,,,,,,,,,,,, -19200,0.92208815,4.1174064,,,,,,,,,,,,,, -19300,1.2225792,3.6353297,,,,,,,,,,,,,, -19400,0.9188987,3.4645166,,,,,,,,,,,,,, -19500,1.0366881,3.3307095,,,,,,,,,,,,,, -19600,1.0173178,3.162242,,,,,,,,,,,,,, -19700,0.94170475,4.120306,,,,,,,,,,,,,, -19800,1.0768281,3.188411,,,,,,,,,,,,,, -19837,,,0.5275976657867432,2.005328416824341,0.4874999821186065,2.2312562465667725,50000.0,0.3788000047206878,2.899568557739258,10000.0,9283.02355337143,10281.213695764542,9283.02355337143,996.4339108467102,0.6371915340423584,0.0 -19900,1.0084505,3.1950512,,,,,,,,,,,,,, -20000,0.74762684,5.2892227,,,,,,,,,,,,,, -20100,1.1414279,3.301883,,,,,,,,,,,,,, -20200,0.8254015,4.8166,,,,,,,,,,,,,, -20300,1.0520532,3.1962192,,,,,,,,,,,,,, -20400,1.154934,3.3542194,,,,,,,,,,,,,, -20500,1.1410731,3.108071,,,,,,,,,,,,,, -20600,1.0319883,3.4787343,,,,,,,,,,,,,, -20700,0.8100566,5.5316644,,,,,,,,,,,,,, -20743,,,0.5421093702316284,2.062912940979004,0.4810599982738495,2.356579303741455,50000.0,0.3737000226974487,2.998039484024048,10000.0,9703.317378759384,10739.15369963646,9703.317378759384,1033.9986152648926,0.6677725315093994,0.0 -20800,0.6601565,5.4511676,,,,,,,,,,,,,, -20900,0.9562824,5.454737,,,,,,,,,,,,,, -21000,0.9532594,4.3518767,,,,,,,,,,,,,, -21100,0.88777274,5.364402,,,,,,,,,,,,,, -21200,1.0602126,3.4573376,,,,,,,,,,,,,, -21300,1.0370048,3.2148786,,,,,,,,,,,,,, -21400,0.8478764,4.9243054,,,,,,,,,,,,,, -21500,1.0526739,3.0860353,,,,,,,,,,,,,, -21600,0.73789614,5.416651,,,,,,,,,,,,,, -21646,,,0.5312694907188416,2.0607404708862305,0.4913399815559387,2.2656807899475098,50000.0,0.3819000124931335,2.9162533283233643,10000.0,10123.384189367294,11206.49951314926,10123.384189367294,1081.1976935863495,0.6967108249664307,0.0 -21700,0.84478176,4.3641877,,,,,,,,,,,,,, -21800,1.0919111,3.426642,,,,,,,,,,,,,, -21900,1.1677222,3.0889945,,,,,,,,,,,,,, -22000,0.84473884,4.4702954,,,,,,,,,,,,,, -22100,1.1444156,3.051804,,,,,,,,,,,,,, -22200,1.2890155,3.199666,,,,,,,,,,,,,, -22300,1.0949444,3.0853658,,,,,,,,,,,,,, -22400,0.9554744,5.434185,,,,,,,,,,,,,, -22500,0.6987645,5.4578347,,,,,,,,,,,,,, -22552,,,0.5475195050239563,1.9644417762756348,0.5002999901771545,2.18509578704834,50000.0,0.3917000293731689,2.8608169555664062,10000.0,10543.569134950638,11671.593814611437,10543.569134950638,1126.02326130867,0.7283682823181152,0.0 -22600,1.0195802,3.94189,,,,,,,,,,,,,, -22700,0.9882701,3.3890438,,,,,,,,,,,,,, -22800,0.97424895,3.5433035,,,,,,,,,,,,,, -22900,1.0455612,3.1134403,,,,,,,,,,,,,, -23000,0.8455005,5.5301266,,,,,,,,,,,,,, -23100,1.5186226,3.074084,,,,,,,,,,,,,, -23200,0.8112956,5.0879335,,,,,,,,,,,,,, -23300,1.1501558,3.1924477,,,,,,,,,,,,,, -23400,1.003722,3.697342,,,,,,,,,,,,,, -23459,,,0.5622265338897705,1.8803318738937376,0.5024799704551697,2.188140630722046,50000.0,0.3927000164985657,2.830965995788574,10000.0,10963.765152215958,12140.410922527311,10963.765152215958,1174.560831785202,0.7606453895568848,0.0 -23500,1.049174,3.0670667,,,,,,,,,,,,,, -23600,1.3584902,3.0147657,,,,,,,,,,,,,, -23700,1.2739545,3.7291813,,,,,,,,,,,,,, -23800,0.9526061,3.7240214,,,,,,,,,,,,,, -23900,0.78141135,5.4839044,,,,,,,,,,,,,, -24000,0.9104976,3.8919137,,,,,,,,,,,,,, -24100,1.1935048,3.0758097,,,,,,,,,,,,,, -24200,1.0678968,5.495455,,,,,,,,,,,,,, -24300,0.93468,4.1438975,,,,,,,,,,,,,, -24365,,,0.5406445264816284,2.031728982925415,0.5039199590682983,2.227212190628052,50000.0,0.3934000134468078,2.869940996170044,10000.0,11384.004220485687,12600.587675094604,11384.004220485687,1214.4176337718964,0.7908592224121094,0.0 -24400,0.86605805,4.4971347,,,,,,,,,,,,,, -24500,1.1278025,3.111362,,,,,,,,,,,,,, -24600,1.093293,3.1699085,,,,,,,,,,,,,, -24700,0.72678256,5.5116377,,,,,,,,,,,,,, -24800,1.082565,2.8673348,,,,,,,,,,,,,, -24900,1.1176357,3.0912943,,,,,,,,,,,,,, -25000,0.7834158,5.057205,,,,,,,,,,,,,, -25100,0.78077334,4.7264624,,,,,,,,,,,,,, -25200,1.1603842,3.09996,,,,,,,,,,,,,, -25272,,,0.5468554496765137,1.95665979385376,0.5080199837684631,2.159663438796997,50000.0,0.3937000334262848,2.8316471576690674,10000.0,11804.188628435137,13068.275243759155,11804.188628435137,1261.837417602539,0.822918176651001,0.0 -25300,1.1614841,3.047953,,,,,,,,,,,,,, -25400,1.1687466,3.1492176,,,,,,,,,,,,,, -25500,0.8834936,3.9293299,,,,,,,,,,,,,, -25600,0.83363926,5.5290155,,,,,,,,,,,,,, -25700,1.2832605,3.1618497,,,,,,,,,,,,,, -25800,1.2663158,2.9758687,,,,,,,,,,,,,, -25900,1.0308938,2.9547143,,,,,,,,,,,,,, -26000,1.0357907,4.2893248,,,,,,,,,,,,,, -26100,1.1294012,2.9508133,,,,,,,,,,,,,, -26178,,,0.5712695121765137,1.830441117286682,0.519760012626648,2.1027395725250244,50000.0,0.4075000286102295,2.760247707366944,10000.0,12224.11382508278,13530.477799415588,12224.11382508278,1304.0307462215424,0.8551509380340576,0.0 -26200,1.1184077,2.989462,,,,,,,,,,,,,, -26300,0.87414306,4.3720574,,,,,,,,,,,,,, -26400,0.9378818,4.2211943,,,,,,,,,,,,,, -26500,1.0550706,3.0561342,,,,,,,,,,,,,, -26600,1.0995985,3.1270483,,,,,,,,,,,,,, -26700,0.9546052,3.9952862,,,,,,,,,,,,,, -26800,0.9675384,3.4741454,,,,,,,,,,,,,, -26900,1.3440741,2.9033606,,,,,,,,,,,,,, -27000,1.0032892,4.9214654,,,,,,,,,,,,,, -27083,,,0.5553905963897705,1.8902047872543333,0.5142999887466431,2.094199657440185,50000.0,0.4067000150680542,2.7455155849456787,10000.0,12644.132603883743,13995.284102916718,12644.132603883743,1348.732283115387,0.8903157711029053,0.0 -27100,0.8733201,4.5961423,,,,,,,,,,,,,, -27200,0.86530256,4.6738234,,,,,,,,,,,,,, -27300,1.0697966,2.795797,,,,,,,,,,,,,, -27400,0.9433778,5.045333,,,,,,,,,,,,,, -27500,1.1764423,3.4702182,,,,,,,,,,,,,, -27600,1.1106248,3.3652723,,,,,,,,,,,,,, -27700,1.0302024,2.9051602,,,,,,,,,,,,,, -27800,1.1900859,2.9116914,,,,,,,,,,,,,, -27900,1.177879,2.9253325,,,,,,,,,,,,,, -27989,,,0.5625976324081421,1.860401391983032,0.5207599997520447,2.0770411491394043,50000.0,0.4098000228404999,2.7530479431152344,10000.0,13064.056718111038,14457.495115756989,13064.056718111038,1390.9314963817596,0.926861047744751,0.0 -28000,0.8254001,5.058073,,,,,,,,,,,,,, -28100,1.1532525,2.8572335,,,,,,,,,,,,,, -28200,1.0750158,2.8449218,,,,,,,,,,,,,, -28300,1.108691,2.942278,,,,,,,,,,,,,, -28400,1.0599259,3.4456046,,,,,,,,,,,,,, -28500,0.8638268,4.883448,,,,,,,,,,,,,, -28600,1.0728592,2.8677647,,,,,,,,,,,,,, -28700,1.0751843,2.9912453,,,,,,,,,,,,,, -28800,1.323053,3.101194,,,,,,,,,,,,,, -28896,,,0.5817968845367432,1.820831298828125,0.5314399600028992,2.074846029281616,50000.0,0.4182000160217285,2.730376005172729,10000.0,13484.38296675682,14924.05919790268,13484.38296675682,1437.0800392627716,0.9647171497344972,0.0 -28900,1.049192,3.159399,,,,,,,,,,,,,, -29000,0.8660136,4.1686797,,,,,,,,,,,,,, -29100,0.8983988,5.577222,,,,,,,,,,,,,, -29200,1.1402825,2.8613763,,,,,,,,,,,,,, -29300,0.872307,5.570551,,,,,,,,,,,,,, -29400,1.0707214,4.098223,,,,,,,,,,,,,, -29500,0.941323,3.747703,,,,,,,,,,,,,, -29600,0.86262584,4.7599683,,,,,,,,,,,,,, -29700,0.8457533,5.5026035,,,,,,,,,,,,,, -29800,1.1323509,2.8979633,,,,,,,,,,,,,, -29803,,,0.5716015696525574,1.8176162242889404,0.5288999676704407,2.026429891586304,50000.0,0.415800005197525,2.7066798210144043,10000.0,13904.436289548874,15390.029060602188,13904.436289548874,1482.9127361774445,0.9980545043945312,0.0 -29900,1.0781243,2.9914706,,,,,,,,,,,,,, -30000,1.0680863,5.538283,,,,,,,,,,,,,, -30100,1.2156551,2.9537182,,,,,,,,,,,,,, -30200,1.0207554,3.5971913,,,,,,,,,,,,,, -30300,1.3334419,2.9787445,,,,,,,,,,,,,, -30400,1.0502403,2.9847615,,,,,,,,,,,,,, -30500,1.248845,2.9266062,,,,,,,,,,,,,, -30600,0.8325709,5.456135,,,,,,,,,,,,,, -30700,1.2308906,2.8110874,,,,,,,,,,,,,, -30710,,,0.5737109184265137,1.8410923480987549,0.5341199636459351,2.0497286319732666,50000.0,0.4155000150203705,2.723546504974365,10000.0,14324.511536359789,15857.354704141617,14324.511536359789,1530.08101439476,1.0294535160064695,0.0 -30800,1.2396749,2.876717,,,,,,,,,,,,,, -30900,1.2702022,2.9710944,,,,,,,,,,,,,, -31000,1.1708893,2.9203148,,,,,,,,,,,,,, -31100,1.1028467,2.8715153,,,,,,,,,,,,,, -31200,0.9868847,3.2717865,,,,,,,,,,,,,, -31300,0.98715603,3.347869,,,,,,,,,,,,,, -31400,1.075509,3.9343727,,,,,,,,,,,,,, -31500,1.1083214,2.7471042,,,,,,,,,,,,,, -31600,1.126628,2.8952384,,,,,,,,,,,,,, -31613,,,0.5854882597923279,1.7064871788024902,0.5389999747276306,1.964975118637085,50000.0,0.4191000163555145,2.6509931087493896,10000.0,14744.565649032593,16323.81619167328,14744.565649032593,1576.4077117443085,1.0594823360443115,0.0 -31700,1.1669239,2.8164392,,,,,,,,,,,,,, -31800,1.2254916,2.9661303,,,,,,,,,,,,,, -31900,1.2922689,2.861941,,,,,,,,,,,,,, -32000,0.93603945,4.132213,,,,,,,,,,,,,, -32100,1.2556844,2.7875378,,,,,,,,,,,,,, -32200,1.1616957,2.7510722,,,,,,,,,,,,,, -32300,1.228829,2.828368,,,,,,,,,,,,,, -32400,1.0508931,3.5191708,,,,,,,,,,,,,, -32500,1.1613344,2.9149895,,,,,,,,,,,,,, -32519,,,0.5890039205551147,1.7607898712158203,0.5417999625205994,1.980234980583191,50000.0,0.4282000064849853,2.6211376190185547,10000.0,15164.7554564476,16786.085805416107,15164.7554564476,1618.4065613746643,1.089357614517212,0.0 -32600,1.1154097,3.2832992,,,,,,,,,,,,,, -32700,0.9693073,3.4807544,,,,,,,,,,,,,, -32800,1.0938779,3.035719,,,,,,,,,,,,,, -32900,1.5454831,2.9225142,,,,,,,,,,,,,, -33000,1.1884298,2.877024,,,,,,,,,,,,,, -33100,1.1309335,2.947952,,,,,,,,,,,,,, -33200,0.8861622,3.9727573,,,,,,,,,,,,,, -33300,1.1176188,2.857616,,,,,,,,,,,,,, -33400,0.9453057,3.8797967,,,,,,,,,,,,,, -33426,,,0.5855468511581421,1.7376009225845337,0.5393800139427185,1.9697725772857664,50000.0,0.4208000302314758,2.6821558475494385,10000.0,15585.092749595642,17252.33269262314,15585.092749595642,1664.2355268001556,1.119279861450195,0.0 -33500,1.1407803,2.9588783,,,,,,,,,,,,,, -33600,0.97236365,3.5065696,,,,,,,,,,,,,, -33700,1.2370957,2.97051,,,,,,,,,,,,,, -33800,1.1020839,2.9002922,,,,,,,,,,,,,, -33900,0.9176145,4.7530785,,,,,,,,,,,,,, -34000,1.1166412,2.7166476,,,,,,,,,,,,,, -34100,0.94345003,5.1388664,,,,,,,,,,,,,, -34200,0.87167144,5.3844004,,,,,,,,,,,,,, -34300,1.127367,2.7733417,,,,,,,,,,,,,, -34332,,,0.5911523103713989,1.7255398035049438,0.5455799698829651,1.9740146398544312,50000.0,0.4281000196933746,2.639055252075196,10000.0,16005.261153697968,17718.188776254654,16005.261153697968,1709.8396821022034,1.1521375179290771,0.0 -34400,1.0241934,3.5924063,,,,,,,,,,,,,, -34500,0.98459643,3.508066,,,,,,,,,,,,,, -34600,1.0880481,3.1282172,,,,,,,,,,,,,, -34700,1.1346797,2.8427758,,,,,,,,,,,,,, -34800,1.0381949,2.6954727,,,,,,,,,,,,,, -34900,1.2063789,5.2783346,,,,,,,,,,,,,, -35000,1.2040114,2.7247562,,,,,,,,,,,,,, -35100,1.204703,2.79909,,,,,,,,,,,,,, -35200,0.9322124,5.37925,,,,,,,,,,,,,, -35240,,,0.6228320002555847,1.5777928829193115,0.555180013179779,1.9033994674682613,50000.0,0.4411000311374664,2.584881544113159,10000.0,16425.55720925331,18184.633157491684,16425.55720925331,1755.8981931209564,1.190577507019043,0.0 -35300,1.0205629,5.162715,,,,,,,,,,,,,, -35400,1.018731,3.4873154,,,,,,,,,,,,,, -35500,1.1320609,2.8416722,,,,,,,,,,,,,, -35600,1.1735605,2.9486341,,,,,,,,,,,,,, -35700,1.0815824,2.7556224,,,,,,,,,,,,,, -35800,0.915492,4.390539,,,,,,,,,,,,,, -35900,1.0886095,3.1893153,,,,,,,,,,,,,, -36000,1.0315688,3.821695,,,,,,,,,,,,,, -36100,1.1594507,2.7141173,,,,,,,,,,,,,, -36145,,,0.5803906321525574,1.8397955894470213,0.5388000011444092,2.0403988361358643,50000.0,0.4221000075340271,2.6815497875213623,10000.0,16845.653829813004,18650.0134100914,16845.653829813004,1801.0961050987244,1.2253484725952148,0.0 -36200,1.2157171,2.8188741,,,,,,,,,,,,,, -36300,0.98199934,4.631789,,,,,,,,,,,,,, -36400,1.1025237,2.7940977,,,,,,,,,,,,,, -36500,1.1909714,2.9299703,,,,,,,,,,,,,, -36600,1.0783081,2.7351992,,,,,,,,,,,,,, -36700,1.31362,2.798564,,,,,,,,,,,,,, -36800,1.2421799,2.8854575,,,,,,,,,,,,,, -36900,1.112635,2.9618735,,,,,,,,,,,,,, -37000,0.87165254,4.8037977,,,,,,,,,,,,,, -37050,,,0.6094335913658142,1.6258376836776731,0.5605800151824951,1.867913126945496,50000.0,0.4429000318050384,2.5216729640960693,10000.0,17265.834097862244,19117.25606751442,17265.834097862244,1848.0710878372192,1.2622802257537842,0.0 -37100,1.3083448,2.9056907,,,,,,,,,,,,,, -37200,0.96137196,4.697594,,,,,,,,,,,,,, -37300,1.2081512,3.0987358,,,,,,,,,,,,,, -37400,1.1657672,2.6637802,,,,,,,,,,,,,, -37500,1.0354902,2.7950842,,,,,,,,,,,,,, -37600,0.90354973,4.4451776,,,,,,,,,,,,,, -37700,1.1271099,2.7641194,,,,,,,,,,,,,, -37800,1.1293254,2.8072505,,,,,,,,,,,,,, -37900,0.7912034,5.083433,,,,,,,,,,,,,, -37953,,,0.6267187595367432,1.5315265655517578,0.5590599775314331,1.876887917518616,50000.0,0.442300021648407,2.54549503326416,10000.0,17685.957972049713,19579.37179613113,17685.957972049713,1889.9757521152496,1.2984967231750488,0.0 -38000,1.1059916,3.3517394,,,,,,,,,,,,,, -38100,1.1944634,2.7651274,,,,,,,,,,,,,, -38200,1.1162457,2.860854,,,,,,,,,,,,,, -38300,1.375615,5.488253,,,,,,,,,,,,,, -38400,1.1278325,2.8006065,,,,,,,,,,,,,, -38500,1.124948,2.903851,,,,,,,,,,,,,, -38600,0.8328524,4.9953995,,,,,,,,,,,,,, -38700,1.0326022,3.2346373,,,,,,,,,,,,,, -38800,0.9034371,4.4263186,,,,,,,,,,,,,, -38858,,,0.6036718487739563,1.6944724321365356,0.5605999827384949,1.9091485738754272,50000.0,0.4444000124931335,2.581925630569458,10000.0,18105.92478275299,20047.60985803604,18105.92478275299,1938.1587483882904,1.3353419303894043,0.0 -38900,1.0973184,4.6615987,,,,,,,,,,,,,, -39000,1.1573176,2.736465,,,,,,,,,,,,,, -39100,0.9740647,3.7417846,,,,,,,,,,,,,, -39200,0.9898333,5.391466,,,,,,,,,,,,,, -39300,1.120783,3.0456262,,,,,,,,,,,,,, -39400,1.002215,5.331161,,,,,,,,,,,,,, -39500,0.9545083,5.4321094,,,,,,,,,,,,,, -39600,1.1649187,2.837703,,,,,,,,,,,,,, -39700,1.1722625,2.7746694,,,,,,,,,,,,,, -39766,,,0.6093164086341858,1.628745675086975,0.5579000115394592,1.878006219863892,50000.0,0.4431000351905823,2.54802680015564,10000.0,18526.282481193542,20515.20761990547,18526.282481193542,1985.312269449234,1.3708126544952393,0.0 -39800,1.0963749,4.543038,,,,,,,,,,,,,, -39900,0.99036103,3.7441578,,,,,,,,,,,,,, -40000,1.235985,3.0523102,,,,,,,,,,,,,, -40100,1.4149498,2.9361503,,,,,,,,,,,,,, -40200,1.0006144,5.3678126,,,,,,,,,,,,,, -40300,1.1807325,2.7409387,,,,,,,,,,,,,, -40400,1.1955247,2.8084981,,,,,,,,,,,,,, -40500,1.05398,3.1931977,,,,,,,,,,,,,, -40600,1.1444895,2.7929256,,,,,,,,,,,,,, -40672,,,0.6217382550239563,1.5823453664779663,0.5575399994850159,1.8897854089736936,50000.0,0.4397000074386596,2.577730178833008,10000.0,18946.47481775284,20981.583657741547,18946.47481775284,2031.407597780228,1.408151388168335,0.0 -40700,1.1420329,3.8116522,,,,,,,,,,,,,, -40800,1.172648,2.634239,,,,,,,,,,,,,, -40900,1.1223339,2.9794977,,,,,,,,,,,,,, -41000,1.2763859,2.7734852,,,,,,,,,,,,,, -41100,1.1225481,2.8838217,,,,,,,,,,,,,, -41200,1.1589652,3.1725476,,,,,,,,,,,,,, -41300,0.91016066,5.18435,,,,,,,,,,,,,, -41400,1.1437049,2.6584525,,,,,,,,,,,,,, -41500,0.99965507,4.742202,,,,,,,,,,,,,, -41577,,,0.6031640768051147,1.654468655586243,0.5612999796867371,1.8774502277374268,50000.0,0.4427000284194946,2.5433125495910645,10000.0,19366.68581557274,21450.45510196685,19366.68581557274,2079.980150461197,1.4446847438812256,0.0 -41600,1.1557956,2.9292798,,,,,,,,,,,,,, -41700,0.9041278,4.2547603,,,,,,,,,,,,,, -41800,1.342908,2.783907,,,,,,,,,,,,,, -41900,1.0365033,3.1303432,,,,,,,,,,,,,, -42000,1.2455124,2.698615,,,,,,,,,,,,,, -42100,1.1844631,2.9964724,,,,,,,,,,,,,, -42200,1.1583117,2.8124292,,,,,,,,,,,,,, -42300,0.9156639,4.8078485,,,,,,,,,,,,,, -42400,1.085858,3.6150799,,,,,,,,,,,,,, -42482,,,0.6094140410423279,1.6763795614242554,0.5593799948692322,1.911125898361206,50000.0,0.4410000145435333,2.589221715927124,10000.0,19786.95946097374,21913.58345246315,19786.95946097374,2122.748773813248,1.4790616035461426,0.0 -42500,1.0597513,2.720621,,,,,,,,,,,,,, -42600,1.2534431,2.709533,,,,,,,,,,,,,, -42700,0.95682406,4.2879043,,,,,,,,,,,,,, -42800,0.98387694,4.5165744,,,,,,,,,,,,,, -42900,1.20639,2.8833287,,,,,,,,,,,,,, -43000,1.1120764,2.659943,,,,,,,,,,,,,, -43100,1.1933343,2.685265,,,,,,,,,,,,,, -43200,1.1380267,3.303155,,,,,,,,,,,,,, -43300,1.1889683,2.7982757,,,,,,,,,,,,,, -43388,,,0.6217382550239563,1.639617919921875,0.5581200122833252,1.9320470094680784,50000.0,0.4397000074386596,2.585341215133667,10000.0,20207.0824007988,22379.42211842537,20207.0824007988,2168.378532409668,1.5133049488067627,0.0 -43400,1.1169707,2.7308674,,,,,,,,,,,,,, -43500,1.2399535,2.810278,,,,,,,,,,,,,, -43600,1.2732174,2.817771,,,,,,,,,,,,,, -43700,0.8430821,4.6969876,,,,,,,,,,,,,, -43800,1.2489007,2.756197,,,,,,,,,,,,,, -43900,1.2098675,2.8124523,,,,,,,,,,,,,, -44000,1.1879909,2.7358103,,,,,,,,,,,,,, -44100,0.98573023,4.5903296,,,,,,,,,,,,,, -44200,1.2175514,2.6802197,,,,,,,,,,,,,, -44294,,,0.613964855670929,1.6308157444000244,0.5718799829483032,1.843588590621948,50000.0,0.4519000351428985,2.499508619308472,10000.0,20627.26855278015,22847.91852927208,20627.26855278015,2216.601419687271,1.5490844249725342,0.0 -44300,1.2412211,2.7569287,,,,,,,,,,,,,, -44400,0.9791773,4.2241974,,,,,,,,,,,,,, -44500,1.291249,2.622689,,,,,,,,,,,,,, -44600,1.2126305,2.812227,,,,,,,,,,,,,, -44700,1.0792964,3.4566636,,,,,,,,,,,,,, -44800,1.1211636,4.4949946,,,,,,,,,,,,,, -44900,1.1733605,2.660806,,,,,,,,,,,,,, -45000,0.88627905,4.7804203,,,,,,,,,,,,,, -45100,1.0513858,4.5699654,,,,,,,,,,,,,, -45197,,,0.6221679449081421,1.5743632316589355,0.5730199813842773,1.8135331869125368,50000.0,0.45660001039505,2.4988009929656982,10000.0,21047.384136915207,23311.679278612137,21047.384136915207,2260.163228034973,1.5819377899169922,0.0 -45200,1.0484107,5.3918467,,,,,,,,,,,,,, -45300,1.109544,2.859095,,,,,,,,,,,,,, -45400,1.1362144,2.8166533,,,,,,,,,,,,,, -45500,1.3282595,2.675827,,,,,,,,,,,,,, -45600,1.3938663,2.668303,,,,,,,,,,,,,, -45700,0.96837175,5.1689434,,,,,,,,,,,,,, -45800,1.0228614,4.6815963,,,,,,,,,,,,,, -45900,1.1716352,2.7812057,,,,,,,,,,,,,, -46000,1.0836825,4.87192,,,,,,,,,,,,,, -46100,1.220679,2.6916304,,,,,,,,,,,,,, -46101,,,0.6398242115974426,1.471780776977539,0.5819199681282043,1.7632266283035278,50000.0,0.4600000083446502,2.440560340881348,10000.0,21468.03368353844,23779.937309980392,21468.03368353844,2307.684317350388,1.6177711486816406,0.0 -46200,1.1277366,2.7338533,,,,,,,,,,,,,, -46300,1.051018,3.5664485,,,,,,,,,,,,,, -46400,1.1867957,2.6617153,,,,,,,,,,,,,, -46500,1.0943735,2.81091,,,,,,,,,,,,,, -46600,1.1922059,2.5163653,,,,,,,,,,,,,, -46700,1.2002525,2.7953715,,,,,,,,,,,,,, -46800,0.88983536,5.1104584,,,,,,,,,,,,,, -46900,1.0642107,4.335901,,,,,,,,,,,,,, -47000,1.0363035,4.2464633,,,,,,,,,,,,,, -47003,,,0.6192382574081421,1.632784128189087,0.5776799917221069,1.838263750076294,50000.0,0.4534000158309936,2.5132648944854736,10000.0,21888.051579475403,24245.620250225067,21888.051579475403,2353.266669511795,1.649388551712036,0.0 -47100,1.0995497,2.627588,,,,,,,,,,,,,, -47200,1.5332805,2.6570628,,,,,,,,,,,,,, -47300,0.9509449,4.6724286,,,,,,,,,,,,,, -47400,0.9072458,5.2877827,,,,,,,,,,,,,, -47500,1.1219434,2.9193742,,,,,,,,,,,,,, -47600,1.3260721,2.742074,,,,,,,,,,,,,, -47700,1.0201262,3.9620585,,,,,,,,,,,,,, -47800,1.1572549,3.6709661,,,,,,,,,,,,,, -47900,1.2597692,2.653979,,,,,,,,,,,,,, -47905,,,0.6240624785423279,1.551044464111328,0.5730800032615662,1.7903846502304075,50000.0,0.4571000337600708,2.462394952774048,10000.0,22308.01690769196,24714.25982022285,22308.01690769196,2401.858140230179,1.6812567710876465,0.0 -48000,1.0258676,3.7070432,,,,,,,,,,,,,, -48100,1.2525321,2.7571583,,,,,,,,,,,,,, -48200,1.118605,2.9398036,,,,,,,,,,,,,, -48300,1.1638653,2.8351128,,,,,,,,,,,,,, -48400,1.1623044,2.644191,,,,,,,,,,,,,, -48500,1.2396475,2.570083,,,,,,,,,,,,,, -48600,1.0860468,2.699936,,,,,,,,,,,,,, -48700,1.1989775,2.6312406,,,,,,,,,,,,,, -48800,1.078038,2.599808,,,,,,,,,,,,,, -48809,,,0.6260156035423279,1.572181224822998,0.5747199654579163,1.8311009407043457,50000.0,0.4568000137805938,2.4816720485687256,10000.0,22728.048118829727,25173.80234694481,22728.048118829727,2441.2826936244965,1.717482089996338,0.0 -48900,1.0456822,3.592969,,,,,,,,,,,,,, -49000,1.0634189,4.199345,,,,,,,,,,,,,, -49100,1.1680238,2.643869,,,,,,,,,,,,,, -49200,1.2890092,2.68642,,,,,,,,,,,,,, -49300,0.9709727,4.2005405,,,,,,,,,,,,,, -49400,1.0976136,2.7563448,,,,,,,,,,,,,, -49500,1.2852178,2.5049684,,,,,,,,,,,,,, -49600,0.91016644,5.173407,,,,,,,,,,,,,, -49700,1.1916938,2.6295605,,,,,,,,,,,,,, -49711,,,0.6138671636581421,1.6917517185211182,0.5685399770736694,1.90058434009552,50000.0,0.4463000297546386,2.5759851932525635,10000.0,23147.635646104813,25640.11376547813,23147.635646104813,2487.4757957458496,2.1974527835845947,0.0 -49800,1.2849826,2.5097024,,,,,,,,,,,,,, -49900,1.2960458,2.8978624,,,,,,,,,,,,,, -50000,1.2551601,2.7280846,,,,,,,,,,,,,, -50100,0.9132762,5.1505456,,,,,,,,,,,,,, -50200,1.1965481,2.6021671,,,,,,,,,,,,,, -50300,1.2300501,3.052909,,,,,,,,,,,,,, -50400,1.2433003,5.3186893,,,,,,,,,,,,,, -50500,1.2420237,2.8106313,,,,,,,,,,,,,, -50600,1.1574373,2.7206805,,,,,,,,,,,,,, -50614,,,0.627246081829071,1.6077184677124023,0.5752800107002258,1.8396003246307373,50000.0,0.4588000178337097,2.4899773597717285,10000.0,23567.73213458061,26104.293008327484,23567.73213458061,2531.470423936844,2.233795166015625,0.0 -50700,1.0636271,3.564756,,,,,,,,,,,,,, -50800,1.2027423,2.597332,,,,,,,,,,,,,, -50900,1.2899693,2.5815434,,,,,,,,,,,,,, -51000,1.3417819,2.7183166,,,,,,,,,,,,,, -51100,1.173828,2.5394394,,,,,,,,,,,,,, -51200,1.0151078,4.655072,,,,,,,,,,,,,, -51300,1.1580335,2.6362488,,,,,,,,,,,,,, -51400,1.4848359,2.7705202,,,,,,,,,,,,,, -51500,1.250691,2.5304933,,,,,,,,,,,,,, -51518,,,0.6355859041213989,1.560902118682861,0.5841400027275085,1.812380313873291,50000.0,0.4621000289916992,2.4717304706573486,10000.0,23987.8773355484,26570.3987429142,23987.8773355484,2577.3440272808075,2.2699053287506104,0.0 -51600,0.9118603,5.2463617,,,,,,,,,,,,,, -51700,1.0778497,3.5026875,,,,,,,,,,,,,, -51800,1.0695906,4.9271393,,,,,,,,,,,,,, -51900,1.2470446,2.5055897,,,,,,,,,,,,,, -52000,1.0436797,3.2306113,,,,,,,,,,,,,, -52100,0.94327873,4.6714144,,,,,,,,,,,,,, -52200,1.4318223,2.577856,,,,,,,,,,,,,, -52300,1.1405329,2.494756,,,,,,,,,,,,,, -52400,1.0109793,3.8242693,,,,,,,,,,,,,, -52423,,,0.6373242139816284,1.560644030570984,0.5805400013923645,1.8203362226486208,50000.0,0.4638000130653381,2.4787209033966064,10000.0,24408.21269702912,27037.57492542267,24408.21269702912,2624.098313808441,2.305191993713379,0.0 -52500,1.20002,3.1957514,,,,,,,,,,,,,, -52600,1.1615419,4.851033,,,,,,,,,,,,,, -52700,0.9061332,5.116596,,,,,,,,,,,,,, -52800,1.1211582,3.0081995,,,,,,,,,,,,,, -52900,1.089858,3.7425406,,,,,,,,,,,,,, -53000,1.196018,2.79204,,,,,,,,,,,,,, -53100,1.2210289,2.8901598,,,,,,,,,,,,,, -53200,1.1342822,2.855911,,,,,,,,,,,,,, -53300,1.0956165,3.2050948,,,,,,,,,,,,,, -53329,,,0.6320117115974426,1.5440112352371216,0.586359977722168,1.7737897634506226,50000.0,0.4708000123500824,2.432278871536255,10000.0,24828.36497950554,27501.591789245605,24828.36497950554,2667.875230550766,2.341886043548584,0.0 -53400,1.1949788,3.0410593,,,,,,,,,,,,,, -53500,1.0455054,5.2865663,,,,,,,,,,,,,, -53600,1.0587558,4.092231,,,,,,,,,,,,,, -53700,1.1637199,2.6368864,,,,,,,,,,,,,, -53800,1.0779506,4.985638,,,,,,,,,,,,,, -53900,1.1279112,4.615459,,,,,,,,,,,,,, -54000,1.0548408,3.83177,,,,,,,,,,,,,, -54100,1.2335573,2.5595796,,,,,,,,,,,,,, -54200,1.1714798,2.6429393,,,,,,,,,,,,,, -54233,,,0.6431054472923279,1.4797742366790771,0.5918599963188171,1.7450494766235352,50000.0,0.4805000126361847,2.3959012031555176,10000.0,25248.749005794525,27964.58424758911,25248.749005794525,2710.399166822433,2.375976324081421,0.0 -54300,1.0695298,3.0945368,,,,,,,,,,,,,, -54400,1.3822994,2.7583308,,,,,,,,,,,,,, -54500,1.2863195,2.4857702,,,,,,,,,,,,,, -54600,1.294624,2.6863828,,,,,,,,,,,,,, -54700,1.1803379,2.9660182,,,,,,,,,,,,,, -54800,1.1647831,3.4810476,,,,,,,,,,,,,, -54900,1.1655322,2.630158,,,,,,,,,,,,,, -55000,1.1062056,2.4767566,,,,,,,,,,,,,, -55100,1.0418848,3.6800866,,,,,,,,,,,,,, -55137,,,0.6678124666213989,1.3960999250411987,0.5922799706459045,1.7475827932357788,50000.0,0.4723000228404999,2.417467832565308,10000.0,25668.903258800507,28431.16887879372,25668.903258800507,2756.7333924770355,2.421182155609131,0.0 -55200,1.2485181,2.7007458,,,,,,,,,,,,,, -55300,1.1795318,2.6583753,,,,,,,,,,,,,, -55400,1.2597867,2.5994396,,,,,,,,,,,,,, -55500,0.93392134,5.223372,,,,,,,,,,,,,, -55600,1.1300842,2.7920096,,,,,,,,,,,,,, -55700,1.0664614,4.139467,,,,,,,,,,,,,, -55800,1.0679083,4.1996446,,,,,,,,,,,,,, -55900,1.2508239,2.6005464,,,,,,,,,,,,,, -56000,0.9987563,5.0145206,,,,,,,,,,,,,, -56041,,,0.6312109231948853,1.543101787567139,0.5813199877738953,1.7773977518081665,50000.0,0.4618000090122223,2.443195343017578,10000.0,26088.894425868988,28898.31833386421,26088.894425868988,2803.8037741184235,2.458335876464844,0.0 -56100,1.4833972,2.9380095,,,,,,,,,,,,,, -56200,1.1762065,2.7054212,,,,,,,,,,,,,, -56300,0.94490147,3.9517474,,,,,,,,,,,,,, -56400,1.1994628,2.6620865,,,,,,,,,,,,,, -56500,1.1420085,3.3060193,,,,,,,,,,,,,, -56600,1.1106722,2.5146148,,,,,,,,,,,,,, -56700,1.1847799,2.581322,,,,,,,,,,,,,, -56800,1.3184311,2.6904922,,,,,,,,,,,,,, -56900,1.350188,2.553128,,,,,,,,,,,,,, -56946,,,0.6483007669448853,1.5265979766845703,0.598039984703064,1.7617758512496948,50000.0,0.4783000349998474,2.411650896072388,10000.0,26509.195707321167,29361.984380960464,26509.195707321167,2847.0764875411987,2.49930739402771,0.0 -57000,1.4203733,2.583954,,,,,,,,,,,,,, -57100,1.2210754,2.5767736,,,,,,,,,,,,,, -57200,1.0298742,3.0681992,,,,,,,,,,,,,, -57300,1.1213336,2.796384,,,,,,,,,,,,,, -57400,1.11986,2.987828,,,,,,,,,,,,,, -57500,1.3180418,2.4908614,,,,,,,,,,,,,, -57600,0.99949217,4.1604633,,,,,,,,,,,,,, -57700,1.1734732,2.9329095,,,,,,,,,,,,,, -57800,1.2034929,2.7918098,,,,,,,,,,,,,, -57850,,,0.6579492092132568,1.4141149520874023,0.590499997138977,1.743949294090271,50000.0,0.4708000123500824,2.402419328689575,10000.0,26929.53288602829,29819.127604722977,26929.53288602829,2883.793509721756,2.5372140407562256,0.0 -57900,1.2589105,2.4918635,,,,,,,,,,,,,, -58000,0.9995558,5.129338,,,,,,,,,,,,,, -58100,0.9835253,5.1977196,,,,,,,,,,,,,, -58200,1.1628447,2.481336,,,,,,,,,,,,,, -58300,1.2889651,2.640234,,,,,,,,,,,,,, -58400,1.0242974,4.5608754,,,,,,,,,,,,,, -58500,1.2456428,2.7288396,,,,,,,,,,,,,, -58600,0.9487177,4.7993665,,,,,,,,,,,,,, -58700,1.2859557,2.5964158,,,,,,,,,,,,,, -58755,,,0.6382421851158142,1.5162142515182495,0.5955199599266052,1.725347638130188,50000.0,0.4786000251770019,2.3848416805267334,10000.0,27349.482364177704,30287.81723570824,27349.482364177704,2932.443725347519,2.576699495315552,0.0 -58800,0.9187762,4.7169137,,,,,,,,,,,,,, -58900,0.92767584,4.974282,,,,,,,,,,,,,, -59000,1.2614673,2.664248,,,,,,,,,,,,,, -59100,1.2660401,2.4844365,,,,,,,,,,,,,, -59200,1.3615676,2.7130814,,,,,,,,,,,,,, -59300,0.95901203,4.367745,,,,,,,,,,,,,, -59400,1.3727033,2.620192,,,,,,,,,,,,,, -59500,1.021021,3.305465,,,,,,,,,,,,,, -59600,1.4317842,2.6949732,,,,,,,,,,,,,, -59662,,,0.6444531083106995,1.4967687129974363,0.5925599932670593,1.7437061071395874,50000.0,0.4775000214576721,2.395569562911988,10000.0,27769.68719244004,30756.54388856888,27769.68719244004,2980.876615524292,2.613621950149536,0.0 -59700,1.2336817,2.8427825,,,,,,,,,,,,,, -59800,1.1065642,5.029841,,,,,,,,,,,,,, -59900,1.2302152,2.6203933,,,,,,,,,,,,,, -60000,0.98421663,4.2952385,,,,,,,,,,,,,, -60100,1.3222756,2.6153846,,,,,,,,,,,,,, -60200,1.1530482,2.902824,,,,,,,,,,,,,, -60300,1.8459643,2.4823208,,,,,,,,,,,,,, -60400,1.1391802,2.7648559,,,,,,,,,,,,,, -60500,1.145718,3.3967426,,,,,,,,,,,,,, -60567,,,0.6565039157867432,1.454769730567932,0.5924599766731262,1.755020260810852,50000.0,0.4758000373840332,2.3964974880218506,10000.0,28189.677449941635,31224.26690030098,28189.677449941635,3028.5183403491974,2.6536691188812256,0.0 -60600,1.3323182,2.6291115,,,,,,,,,,,,,, -60700,1.0851331,3.8463445,,,,,,,,,,,,,, -60800,1.2716839,2.5654695,,,,,,,,,,,,,, -60900,1.1097225,5.0642376,,,,,,,,,,,,,, -61000,1.1270077,3.309669,,,,,,,,,,,,,, -61100,1.2383536,2.4338193,,,,,,,,,,,,,, -61200,1.2719573,2.6136005,,,,,,,,,,,,,, -61300,0.99281293,4.6110106,,,,,,,,,,,,,, -61400,1.3769503,2.6417112,,,,,,,,,,,,,, -61470,,,0.650585949420929,1.4309154748916626,0.6028599739074707,1.6541436910629272,50000.0,0.4808000326156616,2.325901508331299,10000.0,28609.880881786343,31685.34749984741,28609.880881786343,3069.303905725479,2.693972587585449,0.0 -61500,1.2901464,2.7081742,,,,,,,,,,,,,, -61600,1.0163176,4.325141,,,,,,,,,,,,,, -61700,1.1234406,2.6324437,,,,,,,,,,,,,, -61800,1.2351019,2.6230745,,,,,,,,,,,,,, -61900,1.1777494,2.9357786,,,,,,,,,,,,,, -62000,1.0793835,4.413041,,,,,,,,,,,,,, -62100,1.3439744,2.5913963,,,,,,,,,,,,,, -62200,1.0878729,4.6163645,,,,,,,,,,,,,, -62300,1.1270262,2.9750936,,,,,,,,,,,,,, -62375,,,0.6543945074081421,1.4196845293045044,0.60725998878479,1.649720549583435,50000.0,0.4886000156402588,2.3273630142211914,10000.0,29030.00518536568,32153.49118900299,29030.00518536568,3117.232168197632,2.733799457550049,0.0 -62400,1.1821795,2.824226,,,,,,,,,,,,,, -62500,1.0672907,3.4702098,,,,,,,,,,,,,, -62600,1.0001645,4.285569,,,,,,,,,,,,,, -62700,1.1033264,2.7741964,,,,,,,,,,,,,, -62800,1.1649895,5.183322,,,,,,,,,,,,,, -62900,1.5902892,2.5233371,,,,,,,,,,,,,, -63000,1.3043325,2.5631483,,,,,,,,,,,,,, -63100,1.2289183,2.417583,,,,,,,,,,,,,, -63200,0.95163125,4.5232563,,,,,,,,,,,,,, -63280,,,0.6655077934265137,1.3871628046035769,0.6027199625968933,1.683455228805542,50000.0,0.484000027179718,2.328487157821656,10000.0,29450.34034228325,32614.989812135696,29450.34034228325,3158.303008079529,2.775385618209839,0.0 -63300,1.1725181,5.026673,,,,,,,,,,,,,, -63400,0.97406447,5.0233617,,,,,,,,,,,,,, -63500,1.2413881,2.4269583,,,,,,,,,,,,,, -63600,1.175594,4.5323954,,,,,,,,,,,,,, -63700,1.2520983,2.6888475,,,,,,,,,,,,,, -63800,1.2622705,2.531438,,,,,,,,,,,,,, -63900,1.0551866,3.4442565,,,,,,,,,,,,,, -64000,1.3057425,2.6417158,,,,,,,,,,,,,, -64100,1.2474533,2.63941,,,,,,,,,,,,,, -64183,,,0.6518945097923279,1.436215043067932,0.605139970779419,1.6577297449111938,50000.0,0.4806000292301178,2.3341472148895264,10000.0,29870.42541265488,33077.85925221443,29870.42541265488,3200.9960539340973,2.816265344619751,0.0 -64200,1.144671,3.8640337,,,,,,,,,,,,,, -64300,1.149583,2.3899028,,,,,,,,,,,,,, -64400,1.3606273,2.6817513,,,,,,,,,,,,,, -64500,1.3270867,2.4376926,,,,,,,,,,,,,, -64600,1.5310303,2.5955994,,,,,,,,,,,,,, -64700,1.1989418,4.8801765,,,,,,,,,,,,,, -64800,1.4360603,2.5726395,,,,,,,,,,,,,, -64900,1.1834123,3.0747879,,,,,,,,,,,,,, -65000,1.2401551,2.669681,,,,,,,,,,,,,, -65089,,,0.6542773246765137,1.4176008701324463,0.6029399633407593,1.6662062406539917,50000.0,0.4880000352859497,2.3240888118743896,10000.0,30290.650116682053,33545.353556633,30290.650116682053,3248.1766040325165,2.854301929473877,0.0 -65100,1.2540168,2.4949002,,,,,,,,,,,,,, -65200,1.3447984,2.495296,,,,,,,,,,,,,, -65300,1.1554167,2.3823938,,,,,,,,,,,,,, -65400,1.2866114,2.5001059,,,,,,,,,,,,,, -65500,1.3707857,2.682911,,,,,,,,,,,,,, -65600,1.194392,3.6763554,,,,,,,,,,,,,, -65700,1.2755556,3.006279,,,,,,,,,,,,,, -65800,1.280132,2.3982148,,,,,,,,,,,,,, -65900,1.0106612,5.0987687,,,,,,,,,,,,,, -65996,,,0.667675793170929,1.373803734779358,0.6104399561882019,1.6528773307800293,50000.0,0.4899000227451324,2.314547300338745,10000.0,30710.779942512512,34010.43670129776,30710.779942512512,3293.040856361389,2.891425609588623,0.0 -66000,1.0564195,3.064025,,,,,,,,,,,,,, -66100,1.0019968,4.825779,,,,,,,,,,,,,, -66200,0.99949336,4.2746468,,,,,,,,,,,,,, -66300,1.1331086,3.4765337,,,,,,,,,,,,,, -66400,1.0654899,4.4367867,,,,,,,,,,,,,, -66500,1.206694,4.9823174,,,,,,,,,,,,,, -66600,1.2617729,2.4561296,,,,,,,,,,,,,, -66700,1.289724,2.5242596,,,,,,,,,,,,,, -66800,1.472148,2.582696,,,,,,,,,,,,,, -66900,1.3096892,2.4620159,,,,,,,,,,,,,, -66901,,,0.6533398032188416,1.4194730520248413,0.6050599813461304,1.660703182220459,50000.0,0.4816000163555145,2.354022741317749,10000.0,31131.218547344208,34471.03016901016,31131.218547344208,3333.1036410331726,2.93249773979187,0.0 -67000,1.3584666,2.537341,,,,,,,,,,,,,, -67100,1.1670984,2.9549499,,,,,,,,,,,,,, -67200,1.1887611,2.783517,,,,,,,,,,,,,, -67300,1.1515115,2.6587539,,,,,,,,,,,,,, -67400,1.0145614,4.469718,,,,,,,,,,,,,, -67500,1.321129,2.5434558,,,,,,,,,,,,,, -67600,0.99763507,3.8687654,,,,,,,,,,,,,, -67700,1.2442052,2.6891668,,,,,,,,,,,,,, -67800,1.2063226,2.4709663,,,,,,,,,,,,,, -67805,,,0.6598047018051147,1.3968747854232788,0.6152600049972534,1.627685785293579,50000.0,0.492900013923645,2.301361322402954,10000.0,31551.12981677056,34935.6247048378,31551.12981677056,3377.696927547455,2.971553564071656,0.0 -67900,1.1998379,3.755128,,,,,,,,,,,,,, -68000,1.2724879,2.7665575,,,,,,,,,,,,,, -68100,1.28148,2.450747,,,,,,,,,,,,,, -68200,1.3027024,2.5495036,,,,,,,,,,,,,, -68300,1.3559316,4.0904865,,,,,,,,,,,,,, -68400,1.3076544,2.4757018,,,,,,,,,,,,,, -68500,1.1124554,4.268525,,,,,,,,,,,,,, -68600,1.0089906,5.0429754,,,,,,,,,,,,,, -68700,1.097358,3.2183692,,,,,,,,,,,,,, -68712,,,0.6655663847923279,1.3830251693725586,0.6067999601364136,1.658873438835144,50000.0,0.4833000302314758,2.339237689971924,10000.0,31971.418674469,35402.118267059326,31971.418674469,3423.8143379688263,3.0081887245178223,0.0 -68800,1.4045893,2.5044487,,,,,,,,,,,,,, -68900,1.2924503,2.6584508,,,,,,,,,,,,,, -69000,1.126569,3.4566631,,,,,,,,,,,,,, -69100,0.99526757,4.4397745,,,,,,,,,,,,,, -69200,1.0583667,3.9482083,,,,,,,,,,,,,, -69300,1.2793614,2.9489665,,,,,,,,,,,,,, -69400,1.0959951,4.424648,,,,,,,,,,,,,, -69500,1.247396,2.8903139,,,,,,,,,,,,,, -69600,1.1089282,3.3764844,,,,,,,,,,,,,, -69620,,,0.6632031202316284,1.383965253829956,0.610539972782135,1.6313594579696655,50000.0,0.4856000244617462,2.3167552947998047,10000.0,32391.608671188354,35864.72148346901,32391.608671188354,3466.132476091385,3.0513832569122314,0.0 -69700,1.3191658,2.59715,,,,,,,,,,,,,, -69800,1.3380896,2.3346505,,,,,,,,,,,,,, -69900,1.3443836,2.3982315,,,,,,,,,,,,,, -70000,1.2615101,3.464433,,,,,,,,,,,,,, -70100,1.275307,2.490757,,,,,,,,,,,,,, -70200,1.14806,3.2120194,,,,,,,,,,,,,, -70300,1.1910089,2.4209092,,,,,,,,,,,,,, -70400,1.1452562,4.0740986,,,,,,,,,,,,,, -70500,1.1491808,4.924539,,,,,,,,,,,,,, -70524,,,0.6672461032867432,1.3634191751480105,0.6170399785041809,1.6041995286941528,50000.0,0.4950000345706939,2.2632946968078613,10000.0,32811.70836234093,36331.73110461235,32811.70836234093,3512.953349590301,3.0895795822143555,0.0 -70600,1.2093252,2.447476,,,,,,,,,,,,,, -70700,1.2834533,2.517764,,,,,,,,,,,,,, -70800,1.3460292,2.3896418,,,,,,,,,,,,,, -70900,1.0399812,4.07222,,,,,,,,,,,,,, -71000,1.2732289,2.7725863,,,,,,,,,,,,,, -71100,1.2408693,4.951352,,,,,,,,,,,,,, -71200,1.314388,2.5093088,,,,,,,,,,,,,, -71300,1.0605303,3.969292,,,,,,,,,,,,,, -71400,1.2475246,2.4202633,,,,,,,,,,,,,, -71426,,,0.6733202934265137,1.3281680345535278,0.6173200011253357,1.5995622873306274,50000.0,0.4943000376224518,2.282932043075561,10000.0,33231.66958928108,36795.55163145065,33231.66958928108,3556.7215859889984,3.130367517471313,0.0 -71500,1.4740741,2.597005,,,,,,,,,,,,,, -71600,1.0955188,3.7176116,,,,,,,,,,,,,, -71700,1.2916061,2.3976855,,,,,,,,,,,,,, -71800,1.1474408,3.5259652,,,,,,,,,,,,,, -71900,1.2034633,2.9003701,,,,,,,,,,,,,, -72000,1.2762576,3.2980254,,,,,,,,,,,,,, -72100,1.1276051,3.089871,,,,,,,,,,,,,, -72200,1.5495576,2.669768,,,,,,,,,,,,,, -72300,1.397594,2.4781404,,,,,,,,,,,,,, -72330,,,0.6946874856948853,1.2652043104171753,0.6173799633979797,1.6047745943069458,50000.0,0.4918000102043152,2.269059896469116,10000.0,33651.71494960785,37263.838456869125,33651.71494960785,3604.856119155884,3.186277151107788,0.0 -72400,1.3714079,2.6436338,,,,,,,,,,,,,, -72500,1.2052052,4.753511,,,,,,,,,,,,,, -72600,1.2819649,2.3938248,,,,,,,,,,,,,, -72700,1.0285404,4.773347,,,,,,,,,,,,,, -72800,1.4212704,2.538609,,,,,,,,,,,,,, -72900,1.0871155,4.485089,,,,,,,,,,,,,, -73000,1.1234531,3.3560438,,,,,,,,,,,,,, -73100,1.2756158,2.3869076,,,,,,,,,,,,,, -73200,1.3564498,2.444387,,,,,,,,,,,,,, -73236,,,0.6619726419448853,1.3814362287521362,0.613860011100769,1.6136529445648191,50000.0,0.4927000105381012,2.279181241989136,10000.0,34072.08019518852,37732.43508505821,34072.08019518852,3652.995265960693,3.228080034255981,0.0 -73300,1.4600294,2.3525689,,,,,,,,,,,,,, -73400,1.4919164,2.446317,,,,,,,,,,,,,, -73500,1.2525575,2.8757899,,,,,,,,,,,,,, -73600,1.4291236,2.4168997,,,,,,,,,,,,,, -73700,1.3943088,2.4988787,,,,,,,,,,,,,, -73800,1.2846283,2.3639624,,,,,,,,,,,,,, -73900,1.1592814,4.45893,,,,,,,,,,,,,, -74000,1.1939658,3.3584447,,,,,,,,,,,,,, -74100,1.2345059,2.6847138,,,,,,,,,,,,,, -74139,,,0.6750390529632568,1.3150864839553833,0.6202999949455261,1.581875562667847,50000.0,0.5021000504493713,2.2524607181549072,10000.0,34492.027354717255,38193.86260414124,34492.027354717255,3694.383110284805,3.269738912582397,0.0 -74200,1.1378622,5.0367937,,,,,,,,,,,,,, -74300,1.2406678,2.3371408,,,,,,,,,,,,,, -74400,1.5330786,2.6068978,,,,,,,,,,,,,, -74500,1.426733,2.622564,,,,,,,,,,,,,, -74600,1.3400162,2.5949225,,,,,,,,,,,,,, -74700,1.5791733,2.536065,,,,,,,,,,,,,, -74800,1.5586691,2.2600102,,,,,,,,,,,,,, -74900,1.1027122,3.404482,,,,,,,,,,,,,, -75000,1.2665292,2.3233032,,,,,,,,,,,,,, -75045,,,0.6972851157188416,1.2580310106277466,0.6181199550628662,1.6094812154769895,50000.0,0.4984000325202942,2.2771427631378174,10000.0,34912.15011835098,38661.28926753998,34912.15011835098,3741.594585418701,3.3109562397003174,0.0 -75100,1.4687306,2.4851074,,,,,,,,,,,,,, -75200,1.1905009,2.821324,,,,,,,,,,,,,, -75300,1.0029638,4.590486,,,,,,,,,,,,,, -75400,1.487781,2.778025,,,,,,,,,,,,,, -75500,1.3392086,2.303537,,,,,,,,,,,,,, -75600,1.3084117,2.4791327,,,,,,,,,,,,,, -75700,1.0981703,4.9601054,,,,,,,,,,,,,, -75800,1.3043692,2.6174042,,,,,,,,,,,,,, -75900,1.1340995,4.4148893,,,,,,,,,,,,,, -75953,,,0.6793164014816284,1.3091703653335571,0.6315000057220459,1.5481455326080322,50000.0,0.5069000124931335,2.2062301635742188,10000.0,35332.28189110756,39126.62657475472,35332.28189110756,3786.706786632538,3.353020668029785,0.0 -76000,1.3815182,2.4115436,,,,,,,,,,,,,, -76100,1.4107195,2.5005863,,,,,,,,,,,,,, -76200,1.220781,2.3921824,,,,,,,,,,,,,, -76300,1.2672086,2.5099149,,,,,,,,,,,,,, -76400,1.3021843,2.296963,,,,,,,,,,,,,, -76500,1.4170104,2.3742633,,,,,,,,,,,,,, -76600,1.4749513,2.5299492,,,,,,,,,,,,,, -76700,1.3957721,2.3159223,,,,,,,,,,,,,, -76800,1.1563327,4.9181685,,,,,,,,,,,,,, -76861,,,0.6729882955551147,1.3250367641448977,0.6203599572181702,1.5800641775131226,50000.0,0.4974000155925751,2.2438337802886963,10000.0,35752.626155376434,39597.17609834671,35752.626155376434,3836.816582202912,3.397064447402954,0.0 -76900,1.284343,2.3676515,,,,,,,,,,,,,, -77000,1.1970047,4.9436903,,,,,,,,,,,,,, -77100,1.0882329,4.3397574,,,,,,,,,,,,,, -77200,1.1380773,4.7883115,,,,,,,,,,,,,, -77300,1.0821031,4.1145606,,,,,,,,,,,,,, -77400,1.3527306,2.3186054,,,,,,,,,,,,,, -77500,1.0725416,4.4064984,,,,,,,,,,,,,, -77600,1.1724368,3.2653503,,,,,,,,,,,,,, -77700,1.2867383,2.39767,,,,,,,,,,,,,, -77766,,,0.7009375095367432,1.209425687789917,0.6351000070571899,1.5253475904464722,50000.0,0.5112000107765198,2.183667659759521,10000.0,36172.57538366318,40064.91317510605,36172.57538366318,3884.51095700264,3.4398858547210693,0.0 -77800,1.102162,3.7291937,,,,,,,,,,,,,, -77900,1.258687,2.4357967,,,,,,,,,,,,,, -78000,1.5579137,2.4185302,,,,,,,,,,,,,, -78100,1.4479213,2.4027562,,,,,,,,,,,,,, -78200,1.2464025,2.5786822,,,,,,,,,,,,,, -78300,1.2113698,4.4803395,,,,,,,,,,,,,, -78400,1.4473053,2.360982,,,,,,,,,,,,,, -78500,1.0910589,4.7261558,,,,,,,,,,,,,, -78600,1.3303974,2.409619,,,,,,,,,,,,,, -78670,,,0.6809374690055847,1.2866195440292358,0.6292799711227417,1.5328986644744873,50000.0,0.5099000334739685,2.180517196655273,10000.0,36592.812076091766,40533.96717476845,36592.812076091766,3933.235659122467,3.48156213760376,0.0 -78700,1.3339753,2.7377224,,,,,,,,,,,,,, -78800,1.4029348,2.4015298,,,,,,,,,,,,,, -78900,1.3714858,2.3892093,,,,,,,,,,,,,, -79000,1.2591133,2.888085,,,,,,,,,,,,,, -79100,1.3289737,2.3512297,,,,,,,,,,,,,, -79200,1.280598,2.714934,,,,,,,,,,,,,, -79300,1.3873168,3.808692,,,,,,,,,,,,,, -79400,1.3016672,2.7328439,,,,,,,,,,,,,, -79500,1.2187648,3.2604895,,,,,,,,,,,,,, -79574,,,0.6755468845367432,1.3667091131210327,0.6249200105667114,1.604951024055481,50000.0,0.5047000050544739,2.2638602256774902,10000.0,37012.992997169495,41002.65064263344,37012.992997169495,3981.647619247437,3.5212857723236084,0.0 -79600,1.274793,2.7943687,,,,,,,,,,,,,, -79700,1.50307,2.3118432,,,,,,,,,,,,,, -79800,1.51017,2.4430254,,,,,,,,,,,,,, -79900,1.2986097,2.32484,,,,,,,,,,,,,, -80000,1.2995574,3.7106557,,,,,,,,,,,,,, -80100,1.3466268,2.6079736,,,,,,,,,,,,,, -80200,1.1052684,4.4285827,,,,,,,,,,,,,, -80300,1.3351538,2.675133,,,,,,,,,,,,,, -80400,1.4333621,2.5691075,,,,,,,,,,,,,, -80480,,,0.6903710961341858,1.240017652511597,0.6298399567604065,1.541815161705017,50000.0,0.5015000104904175,2.2130167484283447,10000.0,37432.94675517082,41462.675387859344,37432.94675517082,4021.62891960144,3.560465812683105,0.0 -80500,1.4004929,2.3967059,,,,,,,,,,,,,, -80600,1.2899952,2.3492386,,,,,,,,,,,,,, -80700,1.4379336,2.626531,,,,,,,,,,,,,, -80800,1.4066937,2.5420196,,,,,,,,,,,,,, -80900,1.468743,2.397902,,,,,,,,,,,,,, -81000,1.2515029,2.3796093,,,,,,,,,,,,,, -81100,1.2302003,3.05629,,,,,,,,,,,,,, -81200,1.3691815,2.3235314,,,,,,,,,,,,,, -81300,1.3340976,2.3823328,,,,,,,,,,,,,, -81384,,,0.6767382621765137,1.3347312211990356,0.6283999681472778,1.5647926330566406,50000.0,0.5029000043869019,2.2255892753601074,10000.0,37852.8576259613,41930.42304897308,37852.8576259613,4069.3755803108215,3.600275039672852,0.0 -81400,1.3207719,2.4220405,,,,,,,,,,,,,, -81500,1.3501959,2.3429155,,,,,,,,,,,,,, -81600,1.3713589,2.784569,,,,,,,,,,,,,, -81700,1.0859939,4.898922,,,,,,,,,,,,,, -81800,1.3368582,2.2608423,,,,,,,,,,,,,, -81900,1.3383436,2.3986566,,,,,,,,,,,,,, -82000,1.2453303,3.2075844,,,,,,,,,,,,,, -82100,1.7247249,2.3198957,,,,,,,,,,,,,, -82200,1.1998096,4.098297,,,,,,,,,,,,,, -82289,,,0.6830077767372131,1.3016310930252075,0.6261599659919739,1.565454125404358,50000.0,0.5054000020027161,2.210899353027344,10000.0,38272.83093523979,42398.30388045311,38272.83093523979,4117.19310426712,3.6395747661590576,0.0 -82300,1.3585616,2.2816901,,,,,,,,,,,,,, -82400,1.3463401,2.3630643,,,,,,,,,,,,,, -82500,1.2506943,4.9262185,,,,,,,,,,,,,, -82600,1.4231113,2.4899688,,,,,,,,,,,,,, -82700,1.4435289,2.2543807,,,,,,,,,,,,,, -82800,1.6068901,2.248999,,,,,,,,,,,,,, -82900,1.3009727,3.4013405,,,,,,,,,,,,,, -83000,1.4423045,2.4354692,,,,,,,,,,,,,, -83100,1.2443163,4.4280086,,,,,,,,,,,,,, -83196,,,0.6936327815055847,1.2280399799346924,0.6358000040054321,1.5165338516235352,50000.0,0.5112000107765198,2.164104700088501,10000.0,38692.91235637665,42869.38398528099,38692.91235637665,4168.097243785858,3.682982444763184,0.0 -83200,1.2049651,4.111354,,,,,,,,,,,,,, -83300,1.2781869,4.380293,,,,,,,,,,,,,, -83400,1.335932,2.866299,,,,,,,,,,,,,, -83500,1.2351731,3.4308288,,,,,,,,,,,,,, -83600,1.1540211,4.1773705,,,,,,,,,,,,,, -83700,1.4794227,2.4168527,,,,,,,,,,,,,, -83800,1.4020114,2.2855182,,,,,,,,,,,,,, -83900,1.0694256,4.0041795,,,,,,,,,,,,,, -84000,1.5586469,2.3518043,,,,,,,,,,,,,, -84100,,,0.6860546469688416,1.281497359275818,0.638260006904602,1.5209600925445557,50000.0,0.513700008392334,2.1791257858276367,10000.0,39113.014108896255,43335.9387626648,39113.014108896255,4214.460152864456,3.722850561141968,0.0 -84100,1.1777526,3.0637577,,,,,,,,,,,,,, -84200,1.2326566,3.0435848,,,,,,,,,,,,,, -84300,1.3933525,2.3629274,,,,,,,,,,,,,, -84400,1.4964067,2.286275,,,,,,,,,,,,,, -84500,1.3177148,2.853018,,,,,,,,,,,,,, -84600,1.1544646,4.2499046,,,,,,,,,,,,,, -84700,1.3080478,4.0678635,,,,,,,,,,,,,, -84800,1.4803156,2.5613725,,,,,,,,,,,,,, -84900,1.4243855,2.3224428,,,,,,,,,,,,,, -85000,1.3187834,4.053368,,,,,,,,,,,,,, -85007,,,0.6898437142372131,1.2605102062225342,0.6386399865150452,1.5136559009552002,50000.0,0.5118000507354736,2.1877145767211914,10000.0,39533.14155960083,43804.54938220978,39533.14155960083,4262.848652362824,3.765680074691773,0.0 -85100,1.2413094,3.5253716,,,,,,,,,,,,,, -85200,1.2830142,4.8083897,,,,,,,,,,,,,, -85300,1.1183397,2.8291898,,,,,,,,,,,,,, -85400,1.3856614,2.4116304,,,,,,,,,,,,,, -85500,1.2909524,2.7517097,,,,,,,,,,,,,, -85600,1.5087103,2.4089148,,,,,,,,,,,,,, -85700,1.4161311,2.5657034,,,,,,,,,,,,,, -85800,1.3328769,3.1259341,,,,,,,,,,,,,, -85900,1.4891896,2.6598184,,,,,,,,,,,,,, -85914,,,0.6995507478713989,1.186415195465088,0.642300009727478,1.469892501831055,50000.0,0.5211000442504883,2.127619743347168,10000.0,39953.23417115212,44272.816187381744,39953.23417115212,4310.929792881012,3.8074779510498047,0.0 -86000,1.1567308,4.4780126,,,,,,,,,,,,,, -86100,1.6310089,2.3717952,,,,,,,,,,,,,, -86200,1.3061191,2.9800534,,,,,,,,,,,,,, -86300,1.3660477,2.3744435,,,,,,,,,,,,,, -86400,1.1189227,3.3280737,,,,,,,,,,,,,, -86500,1.3108078,3.06542,,,,,,,,,,,,,, -86600,1.3694664,2.318746,,,,,,,,,,,,,, -86700,1.1674193,3.3443217,,,,,,,,,,,,,, -86800,1.3480393,2.4664578,,,,,,,,,,,,,, -86817,,,0.6941601634025574,1.2347404956817627,0.6406199932098389,1.486888766288757,50000.0,0.51910001039505,2.128563165664673,10000.0,40373.50329899788,44740.78130793572,40373.50329899788,4358.529216766357,3.8530352115631104,0.0 -86900,1.2571967,4.977654,,,,,,,,,,,,,, -87000,1.7207991,2.3074858,,,,,,,,,,,,,, -87100,1.4940164,4.9087954,,,,,,,,,,,,,, -87200,1.5763191,2.4567096,,,,,,,,,,,,,, -87300,1.1691363,4.4625187,,,,,,,,,,,,,, -87400,1.3510847,2.3705602,,,,,,,,,,,,,, -87500,1.2831175,2.9257827,,,,,,,,,,,,,, -87600,1.2990605,2.7975442,,,,,,,,,,,,,, -87700,1.5126861,2.2587535,,,,,,,,,,,,,, -87720,,,0.69593745470047,1.2813823223114014,0.6403999924659729,1.535065770149231,50000.0,0.5139000415802002,2.182572841644287,10000.0,40793.42196941376,45206.56125569344,40793.42196941376,4404.2962164878845,3.895873308181762,0.0 -87800,1.3626455,2.289511,,,,,,,,,,,,,, -87900,1.4652791,2.3136625,,,,,,,,,,,,,, -88000,1.4278156,2.3223538,,,,,,,,,,,,,, -88100,1.4090565,3.4601567,,,,,,,,,,,,,, -88200,1.2438747,2.7909973,,,,,,,,,,,,,, -88300,1.2912596,3.0826368,,,,,,,,,,,,,, -88400,1.2195836,3.6157606,,,,,,,,,,,,,, -88500,1.2961322,2.3488266,,,,,,,,,,,,,, -88600,1.3969606,2.3509448,,,,,,,,,,,,,, -88625,,,0.7049609422683716,1.2156800031661987,0.6411799788475037,1.4978916645050049,50000.0,0.5161000490188599,2.1546852588653564,10000.0,41213.36802792549,45678.53724288941,41213.36802792549,4456.234867095947,3.935400485992432,0.0 -88700,1.4443421,4.236131,,,,,,,,,,,,,, -88800,1.3975985,2.2196765,,,,,,,,,,,,,, -88900,1.5364981,2.382819,,,,,,,,,,,,,, -89000,1.2879186,3.2030282,,,,,,,,,,,,,, -89100,1.4570314,3.2226348,,,,,,,,,,,,,, -89200,1.331146,2.2918906,,,,,,,,,,,,,, -89300,1.3336885,2.388173,,,,,,,,,,,,,, -89400,1.1615808,3.831967,,,,,,,,,,,,,, -89500,1.5566576,2.3369823,,,,,,,,,,,,,, -89532,,,0.7158789038658142,1.1632670164108276,0.6412999629974365,1.5002963542938232,50000.0,0.5234000086784363,2.139137983322144,10000.0,41633.58397102356,46147.01468038559,41633.58397102356,4504.405340433121,3.975401163101196,0.0 -89600,1.2365,4.907817,,,,,,,,,,,,,, -89700,1.2410457,3.3052838,,,,,,,,,,,,,, -89800,1.4596835,2.7563436,,,,,,,,,,,,,, -89900,1.6138816,2.365961,,,,,,,,,,,,,, -90000,1.3750763,2.3165364,,,,,,,,,,,,,, -90100,1.3033812,2.7986343,,,,,,,,,,,,,, -90200,1.2856885,2.710837,,,,,,,,,,,,,, -90300,1.6544905,2.4762988,,,,,,,,,,,,,, -90400,1.2437263,3.5741775,,,,,,,,,,,,,, -90437,,,0.6942968368530273,1.242924690246582,0.6450600028038025,1.48419451713562,50000.0,0.5220000147819519,2.13555645942688,10000.0,42053.79886484146,46614.43021249771,42053.79886484146,4551.505216121674,4.02495002746582,0.0 -90500,1.487367,2.2555208,,,,,,,,,,,,,, -90600,1.2742908,3.8787556,,,,,,,,,,,,,, -90700,1.1317806,3.992954,,,,,,,,,,,,,, -90800,1.398322,2.175939,,,,,,,,,,,,,, -90900,1.3796256,2.2272274,,,,,,,,,,,,,, -91000,1.3693179,3.336706,,,,,,,,,,,,,, -91100,1.1916213,3.506085,,,,,,,,,,,,,, -91200,1.1623318,4.2085137,,,,,,,,,,,,,, -91300,1.274411,2.54578,,,,,,,,,,,,,, -91344,,,0.7076367139816284,1.2079362869262695,0.644320011138916,1.490636110305786,50000.0,0.5290000438690186,2.1351263523101807,10000.0,42474.13457107544,47081.45916390419,42474.13457107544,4598.104979038239,4.066986322402954,0.0 -91400,1.3788412,4.412657,,,,,,,,,,,,,, -91500,1.2535853,4.162162,,,,,,,,,,,,,, -91600,1.4138963,2.410543,,,,,,,,,,,,,, -91700,1.3860918,2.2919674,,,,,,,,,,,,,, -91800,1.4288127,2.2263494,,,,,,,,,,,,,, -91900,1.2353649,3.853559,,,,,,,,,,,,,, -92000,1.6147943,2.2588425,,,,,,,,,,,,,, -92100,1.4321009,2.6455936,,,,,,,,,,,,,, -92200,1.224102,4.6321483,,,,,,,,,,,,,, -92246,,,0.7279687523841858,1.1314303874969482,0.6431199908256531,1.4939802885055542,50000.0,0.5223000049591064,2.1604013442993164,10000.0,42894.400421381,47550.82168245316,42894.400421381,4647.107894182205,4.110424280166626,0.0 -92300,1.5604446,2.1523426,,,,,,,,,,,,,, -92400,1.4503798,2.5441399,,,,,,,,,,,,,, -92500,1.4366578,2.103774,,,,,,,,,,,,,, -92600,1.3610011,2.3022947,,,,,,,,,,,,,, -92700,1.1802402,4.574692,,,,,,,,,,,,,, -92800,1.4164163,2.705374,,,,,,,,,,,,,, -92900,1.4129438,2.3318782,,,,,,,,,,,,,, -93000,1.367053,4.874253,,,,,,,,,,,,,, -93100,1.5355424,2.2644274,,,,,,,,,,,,,, -93152,,,0.7049609422683716,1.1868587732315063,0.6511799693107605,1.4446868896484375,50000.0,0.5295000076293945,2.0929665565490723,10000.0,43314.599668979645,48019.18526363373,43314.599668979645,4695.176491975784,4.155153512954712,0.0 -93200,1.4892652,2.6504314,,,,,,,,,,,,,, -93300,1.250735,4.311888,,,,,,,,,,,,,, -93400,1.5151051,2.1763566,,,,,,,,,,,,,, -93500,1.4485416,2.2275574,,,,,,,,,,,,,, -93600,1.3754342,3.152842,,,,,,,,,,,,,, -93700,1.4745972,2.2530065,,,,,,,,,,,,,, -93800,1.170549,4.372121,,,,,,,,,,,,,, -93900,1.2987554,3.6247954,,,,,,,,,,,,,, -94000,1.5226089,2.1669002,,,,,,,,,,,,,, -94057,,,0.7050195336341858,1.1846143007278442,0.6482399702072144,1.4538512229919434,50000.0,0.5242000222206116,2.119861364364624,10000.0,43734.53122782707,48487.1493666172,43734.53122782707,4743.113221168518,4.200689315795898,0.0 -94100,1.5163177,2.2691844,,,,,,,,,,,,,, -94200,1.4646881,2.308304,,,,,,,,,,,,,, -94300,1.5115186,2.2558684,,,,,,,,,,,,,, -94400,1.5252444,2.9181912,,,,,,,,,,,,,, -94500,1.6651869,2.3778503,,,,,,,,,,,,,, -94600,1.4378091,2.2638872,,,,,,,,,,,,,, -94700,1.1883472,4.6468906,,,,,,,,,,,,,, -94800,1.400396,2.2616673,,,,,,,,,,,,,, -94900,1.4124795,2.3265922,,,,,,,,,,,,,, -94962,,,0.7193359136581421,1.1572200059890747,0.6502999663352966,1.47944438457489,50000.0,0.527400016784668,2.117228507995605,10000.0,44154.71737384796,48952.4436519146,44154.71737384796,4788.12454199791,4.246379375457764,0.0 -95000,1.418282,2.3820949,,,,,,,,,,,,,, -95100,1.3811537,2.072721,,,,,,,,,,,,,, -95200,1.6550057,2.4087362,,,,,,,,,,,,,, -95300,1.2587022,4.8686543,,,,,,,,,,,,,, -95400,1.1607249,4.144033,,,,,,,,,,,,,, -95500,1.1604227,3.1963327,,,,,,,,,,,,,, -95600,1.5437657,2.1959631,,,,,,,,,,,,,, -95700,1.3844377,2.1204758,,,,,,,,,,,,,, -95800,1.4629205,2.6001604,,,,,,,,,,,,,, -95868,,,0.7100390195846558,1.2094417810440063,0.6509999632835388,1.471266269683838,50000.0,0.531000018119812,2.119685411453247,10000.0,44575.00277280808,49421.69991064072,44575.00277280808,4836.9977996349335,4.29305100440979,0.0 -95900,1.3697796,4.7130356,,,,,,,,,,,,,, -96000,1.1961007,4.09017,,,,,,,,,,,,,, -96100,1.5228363,2.1441605,,,,,,,,,,,,,, -96200,1.3721241,2.610653,,,,,,,,,,,,,, -96300,1.5043728,2.3944955,,,,,,,,,,,,,, -96400,1.368467,4.6899567,,,,,,,,,,,,,, -96500,1.4442393,2.1771972,,,,,,,,,,,,,, -96600,1.2528118,3.8434076,,,,,,,,,,,,,, -96700,1.5244771,2.27466,,,,,,,,,,,,,, -96774,,,0.7082812190055847,1.2541940212249756,0.6438999772071838,1.5317820310592651,50000.0,0.525600016117096,2.174257516860962,10000.0,44995.451451301575,49891.23855257034,44995.451451301575,4885.990926504135,4.338782787322998,0.0 -96800,1.3540275,3.7613068,,,,,,,,,,,,,, -96900,1.6031847,2.2531252,,,,,,,,,,,,,, -97000,1.4328768,2.2518673,,,,,,,,,,,,,, -97100,1.2525481,3.3872125,,,,,,,,,,,,,, -97200,1.2804992,3.9317172,,,,,,,,,,,,,, -97300,1.6908635,2.2399917,,,,,,,,,,,,,, -97400,1.5128455,2.1463304,,,,,,,,,,,,,, -97500,1.4420744,2.2583091,,,,,,,,,,,,,, -97600,1.4328519,2.2799864,,,,,,,,,,,,,, -97681,,,0.7260546684265137,1.110514760017395,0.6534599661827087,1.432780385017395,50000.0,0.5313000082969666,2.074703931808472,10000.0,45415.57292270661,50359.1891078949,45415.57292270661,4933.726239919663,4.382209062576294,0.0 -97700,1.4114171,2.2183475,,,,,,,,,,,,,, -97800,1.569591,2.2426105,,,,,,,,,,,,,, -97900,1.4403777,2.268855,,,,,,,,,,,,,, -98000,1.4231611,2.1375184,,,,,,,,,,,,,, -98100,1.456041,2.3237243,,,,,,,,,,,,,, -98200,1.2534568,4.1298485,,,,,,,,,,,,,, -98300,1.3909087,2.614203,,,,,,,,,,,,,, -98400,1.4256344,2.1716874,,,,,,,,,,,,,, -98500,1.6185507,2.0920439,,,,,,,,,,,,,, -98584,,,0.7095507383346558,1.2087146043777466,0.6518399715423584,1.4704029560089111,50000.0,0.5313000082969666,2.124727725982666,10000.0,45835.61195087433,50823.40377354622,45835.61195087433,4977.804324626923,4.428659439086914,0.0 -98600,1.2808522,4.697168,,,,,,,,,,,,,, -98700,1.5874648,2.112531,,,,,,,,,,,,,, -98800,1.4840813,2.2221937,,,,,,,,,,,,,, -98900,1.6121427,2.2333734,,,,,,,,,,,,,, -99000,1.5781347,2.1864448,,,,,,,,,,,,,, -99100,1.230674,3.8085642,,,,,,,,,,,,,, -99200,1.2184904,3.449766,,,,,,,,,,,,,, -99300,1.4639645,2.1566887,,,,,,,,,,,,,, -99400,1.4535816,3.1090922,,,,,,,,,,,,,, -99487,,,0.7159179449081421,1.143684983253479,0.660539984703064,1.4055538177490234,50000.0,0.5333000421524048,2.0649092197418213,10000.0,46255.78856515885,51290.96784090996,46255.78856515885,5025.0931096076965,4.4764626026153564,0.0 -99500,1.27144,4.80241,,,,,,,,,,,,,, -99600,1.5078206,2.2905328,,,,,,,,,,,,,, -99700,1.3550758,2.6889524,,,,,,,,,,,,,, -99800,1.5510553,2.4350233,,,,,,,,,,,,,, -99900,1.5858738,2.2992454,,,,,,,,,,,,,, -100000,1.6867251,2.183658,,,,,,,,,,,,,, -100100,1.6189119,2.3125622,,,,,,,,,,,,,, -100200,1.2592249,3.4285047,,,,,,,,,,,,,, -100300,1.6392088,2.2900333,,,,,,,,,,,,,, -100394,,,0.7255663871765137,1.114113211631775,0.6623600125312805,1.4175740480422974,50000.0,0.5331000089645386,2.06860876083374,10000.0,46676.1024954319,51759.92477321625,46676.1024954319,5073.639664173126,4.521659851074219,0.0 -100400,1.5163857,2.2438347,,,,,,,,,,,,,, -100500,1.2928962,3.9824877,,,,,,,,,,,,,, -100600,1.352112,3.2840896,,,,,,,,,,,,,, -100700,1.3645163,3.2523046,,,,,,,,,,,,,, -100800,1.3772192,3.3387053,,,,,,,,,,,,,, -100900,1.3838049,3.7004466,,,,,,,,,,,,,, -101000,1.595945,2.2326133,,,,,,,,,,,,,, -101100,1.5163652,2.2905116,,,,,,,,,,,,,, -101200,1.5662835,2.115925,,,,,,,,,,,,,, -101300,1.5570868,2.5978518,,,,,,,,,,,,,, -101303,,,0.7194140553474426,1.1470216512680054,0.6613399982452393,1.3947124481201172,50000.0,0.5372000336647034,2.0668838024139404,10000.0,47096.401584386826,52227.28931951523,47096.401584386826,5120.609957456589,4.565702676773071,0.0 -101400,1.5081319,2.0592573,,,,,,,,,,,,,, -101500,1.6398115,2.5294662,,,,,,,,,,,,,, -101600,1.5590866,2.508999,,,,,,,,,,,,,, -101700,1.3542658,2.8735662,,,,,,,,,,,,,, -101800,1.3630134,4.7531514,,,,,,,,,,,,,, -101900,1.417746,4.4657183,,,,,,,,,,,,,, -102000,1.4660108,2.444017,,,,,,,,,,,,,, -102100,1.5177901,2.193408,,,,,,,,,,,,,, -102200,1.4594883,2.3200645,,,,,,,,,,,,,, -102211,,,0.7141796946525574,1.1836493015289309,0.6551799774169922,1.4438341856002808,50000.0,0.5344000458717346,2.0949273109436035,10000.0,47516.692006111145,52693.0149166584,47516.692006111145,5165.949766874313,4.609235286712647,0.0 -102300,1.5099186,2.2441535,,,,,,,,,,,,,, -102400,1.9999173,2.1054013,,,,,,,,,,,,,, -102500,1.3564817,4.6069765,,,,,,,,,,,,,, -102600,1.4719601,2.1414533,,,,,,,,,,,,,, -102700,1.3353325,3.9328382,,,,,,,,,,,,,, -102800,1.4933449,4.6799803,,,,,,,,,,,,,, -102900,1.2624631,3.14584,,,,,,,,,,,,,, -103000,1.3945819,4.7099886,,,,,,,,,,,,,, -103100,1.6157062,2.256441,,,,,,,,,,,,,, -103115,,,0.7273827791213989,1.117196559906006,0.66211998462677,1.4077105522155762,50000.0,0.5384000539779663,2.0607264041900635,10000.0,47937.02245235443,53161.830530405045,47937.02245235443,5214.342481613159,4.651084899902344,0.0 -103200,1.2848305,3.170141,,,,,,,,,,,,,, -103300,1.465693,1.970102,,,,,,,,,,,,,, -103400,1.4565588,2.0503469,,,,,,,,,,,,,, -103500,1.5460492,2.2502282,,,,,,,,,,,,,, -103600,1.4254885,3.481452,,,,,,,,,,,,,, -103700,1.4407624,2.0257468,,,,,,,,,,,,,, -103800,1.6322455,2.1742442,,,,,,,,,,,,,, -103900,1.5120735,2.23457,,,,,,,,,,,,,, -104000,1.2971162,4.268691,,,,,,,,,,,,,, -104018,,,0.7233007550239563,1.127516508102417,0.6676799654960632,1.3816344738006592,50000.0,0.5418000221252441,2.032897472381592,10000.0,48357.16518211365,53627.07335996628,48357.16518211365,5259.345903396606,4.697351932525635,0.0 -104100,1.6316043,2.2066357,,,,,,,,,,,,,, -104200,1.4884074,2.4605987,,,,,,,,,,,,,, -104300,1.7322056,2.1890054,,,,,,,,,,,,,, -104400,1.8172734,4.26726,,,,,,,,,,,,,, -104500,1.5763749,2.2126262,,,,,,,,,,,,,, -104600,1.4758091,2.4252563,,,,,,,,,,,,,, -104700,1.5450392,2.1260803,,,,,,,,,,,,,, -104800,1.3131305,3.0243192,,,,,,,,,,,,,, -104900,1.5469798,2.102505,,,,,,,,,,,,,, -104923,,,0.7226171493530273,1.2302206754684448,0.6633599996566772,1.492753267288208,50000.0,0.538100004196167,2.138030767440796,10000.0,48777.29816532135,54094.22512769699,48777.29816532135,5306.26788520813,4.742557764053345,0.0 -105000,1.606687,2.129082,,,,,,,,,,,,,, -105100,1.5342667,2.2205997,,,,,,,,,,,,,, -105200,1.2610332,4.3415546,,,,,,,,,,,,,, -105300,1.3398933,2.5768113,,,,,,,,,,,,,, -105400,1.477011,2.7176676,,,,,,,,,,,,,, -105500,1.49173,2.0729172,,,,,,,,,,,,,, -105600,1.6943978,2.0976553,,,,,,,,,,,,,, -105700,1.3791454,3.4231448,,,,,,,,,,,,,, -105800,1.5469861,2.2216818,,,,,,,,,,,,,, -105828,,,0.7344140410423279,1.0741634368896484,0.6691799759864807,1.373185396194458,50000.0,0.5491999983787537,2.021711826324463,10000.0,49197.20958852768,54559.56789302826,49197.20958852768,5351.597653627396,4.792775630950928,0.0 -105900,1.3518549,3.9738894,,,,,,,,,,,,,, -106000,1.4582504,4.124207,,,,,,,,,,,,,, -106100,1.5904512,2.2191176,,,,,,,,,,,,,, -106200,1.4766712,2.4123561,,,,,,,,,,,,,, -106300,1.4572291,4.713183,,,,,,,,,,,,,, -106400,1.6745051,2.227385,,,,,,,,,,,,,, -106500,1.6970968,2.6820743,,,,,,,,,,,,,, -106600,1.5803328,2.1984453,,,,,,,,,,,,,, -106700,1.8129041,2.1233294,,,,,,,,,,,,,, -106733,,,0.7335546612739563,1.0971087217330933,0.6657599806785583,1.4038810729980469,50000.0,0.5445000529289246,2.057142734527588,10000.0,49617.28916144371,55028.830676317215,49617.28916144371,5400.683298826218,4.839393138885498,0.0 -106800,1.599324,2.311994,,,,,,,,,,,,,, -106900,1.5868281,3.2269135,,,,,,,,,,,,,, -107000,1.4834962,2.5897326,,,,,,,,,,,,,, -107100,1.6629547,2.2132006,,,,,,,,,,,,,, -107200,1.5916281,2.1182241,,,,,,,,,,,,,, -107300,1.6650949,2.2617056,,,,,,,,,,,,,, -107400,1.4578127,2.510912,,,,,,,,,,,,,, -107500,1.575311,2.191446,,,,,,,,,,,,,, -107600,1.5304865,2.0577893,,,,,,,,,,,,,, -107638,,,0.7317773103713989,1.0785207748413086,0.6732400059700012,1.3483539819717407,50000.0,0.5491999983787537,1.9994500875473025,10000.0,50037.69368457794,55492.862330675125,50037.69368457794,5444.214246273041,4.885617017745972,0.0 -107700,1.685572,2.2709951,,,,,,,,,,,,,, -107800,1.6447717,2.1370683,,,,,,,,,,,,,, -107900,1.336523,3.4661338,,,,,,,,,,,,,, -108000,1.7185875,2.1076155,,,,,,,,,,,,,, -108100,1.5883193,2.7476397,,,,,,,,,,,,,, -108200,1.3740026,4.6706653,,,,,,,,,,,,,, -108300,1.6617411,2.0819712,,,,,,,,,,,,,, -108400,1.8809779,2.157006,,,,,,,,,,,,,, -108500,1.3627875,4.457324,,,,,,,,,,,,,, -108542,,,0.7343358993530273,1.076116919517517,0.6702600121498108,1.35907781124115,50000.0,0.5520000457763672,2.020024299621582,10000.0,50458.04724955559,55961.15143156052,50458.04724955559,5492.055480241776,4.9278857707977295,0.0 -108600,1.4872943,3.4618492,,,,,,,,,,,,,, -108700,1.7161103,2.0947485,,,,,,,,,,,,,, -108800,1.524388,2.0794246,,,,,,,,,,,,,, -108900,1.6573738,2.1692314,,,,,,,,,,,,,, -109000,1.4933354,2.4050298,,,,,,,,,,,,,, -109100,1.5963616,2.3065758,,,,,,,,,,,,,, -109200,1.5679086,2.5627105,,,,,,,,,,,,,, -109300,1.4879053,3.3204975,,,,,,,,,,,,,, -109400,1.8016181,2.1001792,,,,,,,,,,,,,, -109444,,,0.7549999952316284,1.0015008449554443,0.6709199547767639,1.374889612197876,50000.0,0.5466000437736511,2.0300540924072266,10000.0,50877.9556055069,56428.23986721039,50877.9556055069,5539.1366448402405,4.976301908493042,0.0 -109500,1.5404395,3.5721529,,,,,,,,,,,,,, -109600,1.7085735,2.1329648,,,,,,,,,,,,,, -109700,1.2750411,2.9456363,,,,,,,,,,,,,, -109800,1.5360318,2.0121512,,,,,,,,,,,,,, -109900,1.7739766,2.1644237,,,,,,,,,,,,,, -110000,1.503434,2.915205,,,,,,,,,,,,,, -110100,1.6326251,4.1438212,,,,,,,,,,,,,, -110200,1.6961398,2.0801327,,,,,,,,,,,,,, -110300,1.4582403,4.289194,,,,,,,,,,,,,, -110352,,,0.7322655916213989,1.0748839378356934,0.6771999597549438,1.3334211111068726,50000.0,0.5463000535964966,1.9860961437225344,10000.0,51298.14277744293,56895.654056310654,51298.14277744293,5586.264869213104,5.022613763809204,0.0 -110400,1.3774551,4.311425,,,,,,,,,,,,,, -110500,1.5213661,3.9068413,,,,,,,,,,,,,, -110600,1.6199518,1.9658448,,,,,,,,,,,,,, -110700,1.7804025,2.068782,,,,,,,,,,,,,, -110800,1.6896648,1.9765936,,,,,,,,,,,,,, -110900,1.6526722,2.4339557,,,,,,,,,,,,,, -111000,1.563839,2.2745042,,,,,,,,,,,,,, -111100,1.5074371,4.6631126,,,,,,,,,,,,,, -111200,1.5265006,4.584398,,,,,,,,,,,,,, -111258,,,0.7407421469688416,1.0589429140090942,0.6723399758338928,1.357522964477539,50000.0,0.5458000302314758,2.025487184524536,10000.0,51718.23906922341,57364.20047616959,51718.23906922341,5634.615519762039,5.071029901504517,0.0 -111300,1.4542602,2.6645257,,,,,,,,,,,,,, -111400,1.3891863,3.1109016,,,,,,,,,,,,,, -111500,1.6987711,2.0733995,,,,,,,,,,,,,, -111600,1.6020324,3.4529226,,,,,,,,,,,,,, -111700,1.6010363,2.349532,,,,,,,,,,,,,, -111800,1.445871,4.727164,,,,,,,,,,,,,, -111900,1.5664269,3.3872597,,,,,,,,,,,,,, -112000,1.6873312,2.1337814,,,,,,,,,,,,,, -112100,1.517637,3.5478609,,,,,,,,,,,,,, -112164,,,0.7571874856948853,0.9590351581573486,0.6782799959182739,1.3155368566513062,50000.0,0.550000011920929,1.9666587114334104,10000.0,52138.199348688126,57832.32154393196,52138.199348688126,5682.680174589157,5.116419792175293,0.0 -112200,1.5582064,4.5096254,,,,,,,,,,,,,, -112300,1.3826556,2.952857,,,,,,,,,,,,,, -112400,1.7384325,2.3651457,,,,,,,,,,,,,, -112500,1.6792201,2.1888344,,,,,,,,,,,,,, -112600,1.5489258,3.5313613,,,,,,,,,,,,,, -112700,1.7194084,2.0887423,,,,,,,,,,,,,, -112800,1.3250738,3.7610044,,,,,,,,,,,,,, -112900,1.5571288,2.1129415,,,,,,,,,,,,,, -113000,1.4676839,4.2107487,,,,,,,,,,,,,, -113066,,,0.7348827719688416,1.1102678775787354,0.6744199991226196,1.374962568283081,50000.0,0.5481000542640686,2.024175882339477,10000.0,52558.10915327072,58299.82621335983,52558.10915327072,5730.1758716106415,5.1636152267456055,0.0 -113100,1.9157062,2.1939158,,,,,,,,,,,,,, -113200,1.4519712,3.4851692,,,,,,,,,,,,,, -113300,1.7692437,2.0853534,,,,,,,,,,,,,, -113400,1.6896774,2.0202014,,,,,,,,,,,,,, -113500,1.7081306,2.3744624,,,,,,,,,,,,,, -113600,1.7083231,2.030054,,,,,,,,,,,,,, -113700,1.673994,2.0271842,,,,,,,,,,,,,, -113800,1.8455836,2.0384133,,,,,,,,,,,,,, -113900,1.4031163,4.1267858,,,,,,,,,,,,,, -113970,,,0.7422655820846558,1.044321060180664,0.6773599982261658,1.3351013660430908,50000.0,0.549500048160553,2.0046536922454834,10000.0,52978.40332150459,58768.11574149132,52978.40332150459,5778.074056625366,5.209723234176636,0.0 -114000,1.5578349,4.272726,,,,,,,,,,,,,, -114100,1.6048828,2.1525722,,,,,,,,,,,,,, -114200,1.623055,3.0613282,,,,,,,,,,,,,, -114300,1.4678477,3.8113198,,,,,,,,,,,,,, -114400,1.7670871,2.0594478,,,,,,,,,,,,,, -114500,1.4316388,2.7777977,,,,,,,,,,,,,, -114600,1.7926667,2.1088006,,,,,,,,,,,,,, -114700,1.8612657,2.199725,,,,,,,,,,,,,, -114800,1.6824901,3.6859777,,,,,,,,,,,,,, -114873,,,0.7559765577316284,0.9663622975349426,0.6796199679374695,1.3073861598968506,50000.0,0.5515000224113464,1.9546692371368408,10000.0,53398.46142745018,59231.86329960823,53398.46142745018,5821.660964488983,5.261880159378052,0.0 -114900,1.6730769,2.0163736,,,,,,,,,,,,,, -115000,1.5195966,4.3875914,,,,,,,,,,,,,, -115100,1.6910352,2.095693,,,,,,,,,,,,,, -115200,1.6587236,2.1564333,,,,,,,,,,,,,, -115300,1.7116402,1.9810702,,,,,,,,,,,,,, -115400,1.6542336,2.4119473,,,,,,,,,,,,,, -115500,1.4180965,2.9050422,,,,,,,,,,,,,, -115600,1.5808958,3.9608107,,,,,,,,,,,,,, -115700,1.7490606,1.9964838,,,,,,,,,,,,,, -115777,,,0.7392773032188416,1.056602954864502,0.6826800107955933,1.3141595125198364,50000.0,0.5634000301361084,1.946633100509644,10000.0,53818.511486291885,59700.5247297287,53818.511486291885,5870.175407886505,5.30842137336731,0.0 -115800,1.5443977,2.624791,,,,,,,,,,,,,, -115900,1.679666,1.9194705,,,,,,,,,,,,,, -116000,1.7409333,2.007111,,,,,,,,,,,,,, -116100,1.6927989,2.0914416,,,,,,,,,,,,,, -116200,1.6732177,2.1802204,,,,,,,,,,,,,, -116300,1.570551,1.9153306,,,,,,,,,,,,,, -116400,1.6290556,1.9956261,,,,,,,,,,,,,, -116500,1.6705835,2.2738333,,,,,,,,,,,,,, -116600,1.5800146,2.1720235,,,,,,,,,,,,,, -116682,,,0.7460546493530273,1.0224932432174685,0.6796199679374695,1.3186691999435425,50000.0,0.5532000064849854,1.9711726903915403,10000.0,54238.54416680336,60169.58668136597,54238.54416680336,5919.104496240616,5.356947898864746,0.0 -116700,1.5965817,2.5045655,,,,,,,,,,,,,, -116800,1.6424173,2.1140652,,,,,,,,,,,,,, -116900,1.6577036,2.1486742,,,,,,,,,,,,,, -117000,1.4338981,4.0134544,,,,,,,,,,,,,, -117100,1.7190236,2.0419486,,,,,,,,,,,,,, -117200,1.8390212,2.3111832,,,,,,,,,,,,,, -117300,1.7589319,2.0482874,,,,,,,,,,,,,, -117400,1.5034045,2.8812318,,,,,,,,,,,,,, -117500,1.7346079,4.121623,,,,,,,,,,,,,, -117587,,,0.7542187571525574,1.0017627477645874,0.6813200116157532,1.3278917074203491,50000.0,0.5639000535011292,1.97365140914917,10000.0,54658.53148937225,60636.760419130325,54658.53148937225,5966.191679239273,5.405437469482422,0.0 -117600,1.668729,2.146583,,,,,,,,,,,,,, -117700,1.6101329,4.4946027,,,,,,,,,,,,,, -117800,1.781732,1.945276,,,,,,,,,,,,,, -117900,1.729059,2.4142303,,,,,,,,,,,,,, -118000,1.8997483,1.987328,,,,,,,,,,,,,, -118100,1.543379,4.4749756,,,,,,,,,,,,,, -118200,1.5360649,3.8580122,,,,,,,,,,,,,, -118300,1.6090158,3.0526097,,,,,,,,,,,,,, -118400,2.0840852,1.9561222,,,,,,,,,,,,,, -118493,,,0.7477148175239563,1.0368415117263794,0.6837999820709229,1.3157402276992798,50000.0,0.55840003490448,1.9573954343795776,10000.0,55078.50196790695,61104.90820026398,55078.50196790695,6014.26641201973,5.455573797225952,0.0 -118500,1.5501466,4.04613,,,,,,,,,,,,,, -118600,1.4954783,2.8432848,,,,,,,,,,,,,, -118700,1.6176856,2.0993204,,,,,,,,,,,,,, -118800,1.6389118,3.25388,,,,,,,,,,,,,, -118900,2.1899426,2.127779,,,,,,,,,,,,,, -119000,1.631167,1.9753217,,,,,,,,,,,,,, -119100,1.6003939,4.632907,,,,,,,,,,,,,, -119200,1.7167182,1.9649761,,,,,,,,,,,,,, -119300,1.6670018,2.58263,,,,,,,,,,,,,, -119399,,,0.7528710961341858,0.9768688082695008,0.6881799697875977,1.2712483406066897,50000.0,0.5649000406265259,1.9199223518371584,10000.0,55498.71514606476,61567.28264904022,55498.71514606476,6056.332176923752,5.500463962554932,0.0 -119400,1.7330202,2.432744,,,,,,,,,,,,,, -119500,1.8922441,4.374653,,,,,,,,,,,,,, -119600,1.5743347,1.9929203,,,,,,,,,,,,,, -119700,1.8065412,1.9711282,,,,,,,,,,,,,, -119800,1.6593692,1.7908716,,,,,,,,,,,,,, -119900,1.6329129,1.9259187,,,,,,,,,,,,,, -120000,1.7785835,2.022137,,,,,,,,,,,,,, -120100,1.5940014,2.2865963,,,,,,,,,,,,,, -120200,1.662694,1.968363,,,,,,,,,,,,,, -120300,1.746843,1.969656,,,,,,,,,,,,,, -120301,,,0.7587109208106995,0.9645220637321472,0.6875999569892883,1.2869911193847656,50000.0,0.5645000338554382,1.9446827173233032,10000.0,55918.85990691185,62033.25811219216,55918.85990691185,6102.064264535904,5.548049688339233,0.0 -120400,1.7780517,2.283558,,,,,,,,,,,,,, -120500,1.832225,2.102626,,,,,,,,,,,,,, -120600,1.4798805,2.7836325,,,,,,,,,,,,,, -120700,1.7765943,3.455512,,,,,,,,,,,,,, -120800,1.6924239,2.2059186,,,,,,,,,,,,,, -120900,1.4342997,2.6268287,,,,,,,,,,,,,, -121000,1.7890209,4.448208,,,,,,,,,,,,,, -121100,1.6912538,3.269558,,,,,,,,,,,,,, -121200,1.5865332,4.365386,,,,,,,,,,,,,, -121206,,,0.749804675579071,0.995943248271942,0.6902399659156799,1.2766941785812378,50000.0,0.5652000308036804,1.9142930507659912,10000.0,56338.76570916176,62500.47589445114,56338.76570916176,6149.277095556259,5.596048831939697,0.0 -121300,1.782107,4.6271634,,,,,,,,,,,,,, -121400,1.667571,4.414975,,,,,,,,,,,,,, -121500,1.6413059,3.1282208,,,,,,,,,,,,,, -121600,1.5838611,4.504275,,,,,,,,,,,,,, -121700,1.6584727,1.9811833,,,,,,,,,,,,,, -121800,1.5602086,3.780148,,,,,,,,,,,,,, -121900,1.9847519,1.8528953,,,,,,,,,,,,,, -122000,1.7168226,2.3103251,,,,,,,,,,,,,, -122100,1.7000932,3.2874045,,,,,,,,,,,,,, -122110,,,0.7586718797683716,0.9999610781669616,0.6941399574279785,1.28845477104187,50000.0,0.5681000351905823,1.9204736948013303,10000.0,56758.85825443268,62969.41731142998,56758.85825443268,6198.029419660568,5.6411542892456055,0.0 -122200,1.683741,1.8932116,,,,,,,,,,,,,, -122300,1.5770133,4.5348916,,,,,,,,,,,,,, -122400,1.8459586,2.206762,,,,,,,,,,,,,, -122500,1.755999,3.0773802,,,,,,,,,,,,,, -122600,1.6413304,1.9920017,,,,,,,,,,,,,, -122700,1.7291794,2.9767406,,,,,,,,,,,,,, -122800,1.7179959,3.7154953,,,,,,,,,,,,,, -122900,1.6671637,2.0505784,,,,,,,,,,,,,, -123000,1.8471297,1.9397213,,,,,,,,,,,,,, -123017,,,0.7599804401397705,0.9783530235290528,0.6932199597358704,1.2936484813690186,50000.0,0.5654000043869019,1.9345817565917969,10000.0,57178.78510522842,63437.81168818474,57178.78510522842,6246.399075984955,5.687482833862305,0.0 -123100,1.8025179,4.071594,,,,,,,,,,,,,, -123200,1.7389417,1.9019141,,,,,,,,,,,,,, -123300,1.566928,2.5870671,,,,,,,,,,,,,, -123400,1.7585819,1.8297256,,,,,,,,,,,,,, -123500,1.7538466,1.8519604,,,,,,,,,,,,,, -123600,1.7286795,3.6840034,,,,,,,,,,,,,, -123700,1.7126521,2.018257,,,,,,,,,,,,,, -123800,1.6859062,2.0103626,,,,,,,,,,,,,, -123900,1.8310004,1.9833796,,,,,,,,,,,,,, -123923,,,0.7659960985183716,0.934769570827484,0.6953999996185303,1.2383933067321775,50000.0,0.5713000297546387,1.8837766647338867,10000.0,57598.70680522919,63905.06403899193,57598.70680522919,6293.630912065506,5.734906911849976,0.0 -124000,1.9021101,3.100252,,,,,,,,,,,,,, -124100,1.6613852,3.1257367,,,,,,,,,,,,,, -124200,1.7442627,2.2124496,,,,,,,,,,,,,, -124300,1.8182404,1.9533224,,,,,,,,,,,,,, -124400,2.5913806,1.88041,,,,,,,,,,,,,, -124500,1.7949592,3.6063848,,,,,,,,,,,,,, -124600,1.7865584,2.4664552,,,,,,,,,,,,,, -124700,1.5762117,4.452853,,,,,,,,,,,,,, -124800,1.8753991,2.0419247,,,,,,,,,,,,,, -124829,,,0.763378918170929,0.9280872344970704,0.6961399912834167,1.2279818058013916,50000.0,0.5699000358581543,1.8792723417282104,10000.0,58018.94120192528,64375.66458725929,58018.94120192528,6343.900760173798,5.780289649963379,0.0 -124900,2.0290112,2.0101492,,,,,,,,,,,,,, -125000,1.6377606,3.7455697,,,,,,,,,,,,,, -125100,1.7130295,2.7928922,,,,,,,,,,,,,, -125200,2.0048609,1.8644447,,,,,,,,,,,,,, -125300,1.9357132,1.8867,,,,,,,,,,,,,, -125400,1.8278376,3.692988,,,,,,,,,,,,,, -125500,1.8013427,1.8863642,,,,,,,,,,,,,, -125600,1.872684,1.983203,,,,,,,,,,,,,, -125700,2.0573704,1.9537514,,,,,,,,,,,,,, -125733,,,0.7677538990974426,0.924182951450348,0.6993599534034729,1.229907751083374,50000.0,0.5800999999046326,1.8638954162597656,10000.0,58438.90938568115,64843.013377428055,58438.90938568115,6391.183966159821,5.8267786502838135,0.0 -125800,1.7613469,2.2380037,,,,,,,,,,,,,, -125900,1.807366,1.9088788,,,,,,,,,,,,,, -126000,2.0003254,1.7726853,,,,,,,,,,,,,, -126100,1.6102464,2.8972793,,,,,,,,,,,,,, -126200,1.5614784,3.9945023,,,,,,,,,,,,,, -126300,2.038605,1.9235095,,,,,,,,,,,,,, -126400,1.7178019,4.069713,,,,,,,,,,,,,, -126500,1.9336934,2.1023588,,,,,,,,,,,,,, -126600,1.8932122,1.9767658,,,,,,,,,,,,,, -126635,,,0.78236323595047,0.8538432121276855,0.7017399668693542,1.2105683088302612,50000.0,0.5753000378608704,1.8391834497451784,10000.0,58859.0885219574,65312.81332349777,58859.0885219574,6440.705798387528,5.875086545944214,0.0 -126700,1.7480558,4.5098467,,,,,,,,,,,,,, -126800,1.9191195,2.1009707,,,,,,,,,,,,,, -126900,1.8545257,1.8283366,,,,,,,,,,,,,, -127000,1.7023385,4.3605905,,,,,,,,,,,,,, -127100,1.8828903,2.05074,,,,,,,,,,,,,, -127200,1.768038,2.069767,,,,,,,,,,,,,, -127300,1.9147954,1.8366537,,,,,,,,,,,,,, -127400,1.8112999,1.8068424,,,,,,,,,,,,,, -127500,1.723386,4.4461575,,,,,,,,,,,,,, -127542,,,0.759082019329071,0.964400053024292,0.6955400109291077,1.2564650774002075,50000.0,0.5718000531196594,1.9025813341140747,10000.0,59279.26599240303,65779.41101312637,59279.26599240303,6487.029499053955,5.920918226242065,0.0 -127600,1.7323351,1.8511862,,,,,,,,,,,,,, -127700,1.6916435,3.6918645,,,,,,,,,,,,,, -127800,1.9207093,4.0015416,,,,,,,,,,,,,, -127900,1.6496576,3.7204752,,,,,,,,,,,,,, -128000,2.1194715,1.9558048,,,,,,,,,,,,,, -128100,1.7231393,3.6473427,,,,,,,,,,,,,, -128200,1.652262,3.888436,,,,,,,,,,,,,, -128300,1.7354196,3.9530942,,,,,,,,,,,,,, -128400,2.2115278,2.0471623,,,,,,,,,,,,,, -128448,,,0.774609386920929,0.8842771053314209,0.701259970664978,1.2053685188293457,50000.0,0.5759000182151794,1.8524872064590447,10000.0,59699.49506402016,66246.55462884903,59699.49506402016,6533.84437918663,5.970127820968628,0.0 -128500,1.9533776,2.2039828,,,,,,,,,,,,,, -128600,1.9368068,1.9818865,,,,,,,,,,,,,, -128700,1.7531165,1.8599324,,,,,,,,,,,,,, -128800,2.1948485,1.9079487,,,,,,,,,,,,,, -128900,1.6642025,3.2609718,,,,,,,,,,,,,, -129000,1.7574929,2.3387554,,,,,,,,,,,,,, -129100,1.9529116,1.8347266,,,,,,,,,,,,,, -129200,1.7935244,3.0963383,,,,,,,,,,,,,, -129300,2.0845673,1.8063939,,,,,,,,,,,,,, -129353,,,0.7856835722923279,0.8250307440757751,0.7035599946975708,1.1953197717666626,50000.0,0.5803000330924988,1.8332209587097168,10000.0,60119.54176044464,66712.03279447556,60119.54176044464,6579.178344249725,6.016943454742432,0.0 -129400,1.7208356,4.015048,,,,,,,,,,,,,, -129500,1.7274067,3.6826525,,,,,,,,,,,,,, -129600,1.7734026,4.2978497,,,,,,,,,,,,,, -129700,1.8942662,1.8825107,,,,,,,,,,,,,, -129800,1.6999454,4.43597,,,,,,,,,,,,,, -129900,2.2862027,1.926897,,,,,,,,,,,,,, -130000,1.9784008,2.381103,,,,,,,,,,,,,, -130100,1.9221497,1.9824299,,,,,,,,,,,,,, -130200,2.0062788,1.864022,,,,,,,,,,,,,, -130258,,,0.7689648270606995,0.9314327239990234,0.7040199637413025,1.226270079612732,50000.0,0.5763000249862671,1.872269868850708,10000.0,60539.66762089729,67180.1879966259,60539.66762089729,6627.108244419098,6.065059423446655,0.0 -130300,1.6443851,2.9644842,,,,,,,,,,,,,, -130400,1.8173109,2.578699,,,,,,,,,,,,,, -130500,2.0971675,1.9726446,,,,,,,,,,,,,, -130600,1.7624705,1.7905784,,,,,,,,,,,,,, -130700,2.1353445,2.25204,,,,,,,,,,,,,, -130800,2.1652994,4.451212,,,,,,,,,,,,,, -130900,1.7519401,3.526619,,,,,,,,,,,,,, -131000,1.9267604,2.2871106,,,,,,,,,,,,,, -131100,2.0659907,2.0699804,,,,,,,,,,,,,, -131165,,,0.7744531035423279,0.9301571249961852,0.7041199803352356,1.241104245185852,50000.0,0.589400053024292,1.8679510354995728,10000.0,60959.68307638168,67647.81341028214,60959.68307638168,6674.615309238434,6.117140293121338,0.0 -131200,1.7590218,3.821093,,,,,,,,,,,,,, -131300,2.1726463,1.7858436,,,,,,,,,,,,,, -131400,1.7532934,2.5797956,,,,,,,,,,,,,, -131500,1.9527302,1.8486102,,,,,,,,,,,,,, -131600,1.8462605,1.8950113,,,,,,,,,,,,,, -131700,1.8001935,1.9270241,,,,,,,,,,,,,, -131800,2.2245452,2.091934,,,,,,,,,,,,,, -131900,1.8225449,2.0719385,,,,,,,,,,,,,, -132000,1.8613954,1.9043046,,,,,,,,,,,,,, -132069,,,0.7873827815055847,0.8701390624046326,0.7058999538421631,1.2221981287002563,50000.0,0.5800000429153442,1.8685365915298464,10000.0,61379.67488145828,68112.85487318039,61379.67488145828,6719.562696695328,6.168532371520996,0.0 -132100,1.9027302,4.068385,,,,,,,,,,,,,, -132200,2.0140946,1.8556699,,,,,,,,,,,,,, -132300,1.7584102,3.980136,,,,,,,,,,,,,, -132400,2.0966775,1.734349,,,,,,,,,,,,,, -132500,1.8578271,4.0634747,,,,,,,,,,,,,, -132600,1.9412892,1.8018775,,,,,,,,,,,,,, -132700,2.0139215,2.0175607,,,,,,,,,,,,,, -132800,1.8860826,1.8860676,,,,,,,,,,,,,, -132900,1.9864136,1.9685858,,,,,,,,,,,,,, -132973,,,0.7793359160423279,0.8791748881340027,0.7087000012397766,1.1858389377593994,50000.0,0.5845000147819519,1.819374918937683,10000.0,61799.59830021858,68583.33038425446,61799.59830021858,6770.016849517822,6.2154810428619385,0.0 -133000,1.8543599,2.9447927,,,,,,,,,,,,,, -133100,2.2073226,2.0389135,,,,,,,,,,,,,, -133200,1.8022178,4.3799906,,,,,,,,,,,,,, -133300,1.9439385,1.8173912,,,,,,,,,,,,,, -133400,1.8080009,3.7095559,,,,,,,,,,,,,, -133500,2.4582968,1.8771796,,,,,,,,,,,,,, -133600,2.0911248,1.8215215,,,,,,,,,,,,,, -133700,1.7747927,3.727648,,,,,,,,,,,,,, -133800,1.9978507,1.7653097,,,,,,,,,,,,,, -133878,,,0.7819726467132568,0.849389910697937,0.7128599882125854,1.159160017967224,50000.0,0.5937000513076782,1.7851358652114868,10000.0,62219.77271294594,69050.48225259781,62219.77271294594,6816.891093969345,6.26798677444458,0.0 -133900,1.9755254,1.8531866,,,,,,,,,,,,,, -134000,1.9075017,1.9379134,,,,,,,,,,,,,, -134100,1.9852023,1.8862238,,,,,,,,,,,,,, -134200,1.9987637,1.7979426,,,,,,,,,,,,,, -134300,2.3014367,1.876924,,,,,,,,,,,,,, -134400,1.8936639,2.252575,,,,,,,,,,,,,, -134500,1.9215955,1.8085593,,,,,,,,,,,,,, -134600,2.033106,4.3401184,,,,,,,,,,,,,, -134700,2.0692916,2.0637963,,,,,,,,,,,,,, -134782,,,0.792773425579071,0.7919089198112488,0.7138599753379822,1.1439038515090942,50000.0,0.5934000015258789,1.7752066850662231,10000.0,62639.89648962021,69520.44361186028,62639.89648962021,6866.628557682037,6.316884279251099,0.0 -134800,1.7888057,2.8399167,,,,,,,,,,,,,, -134900,2.1486604,1.7714361,,,,,,,,,,,,,, -135000,2.1152704,2.0606918,,,,,,,,,,,,,, -135100,1.8545375,4.010672,,,,,,,,,,,,,, -135200,1.9249041,2.970177,,,,,,,,,,,,,, -135300,1.960162,4.255148,,,,,,,,,,,,,, -135400,1.96682,2.4062953,,,,,,,,,,,,,, -135500,2.2827954,1.836079,,,,,,,,,,,,,, -135600,1.8732485,1.6944729,,,,,,,,,,,,,, -135682,,,0.7776952981948853,0.8922684788703918,0.7090799808502197,1.2007648944854736,50000.0,0.5867000222206116,1.83801019191742,10000.0,63059.385607004166,69984.6502289772,63059.385607004166,6910.820559263229,6.79111123085022,0.0 -135700,2.173636,1.7948053,,,,,,,,,,,,,, -135800,2.1326835,3.379788,,,,,,,,,,,,,, -135900,1.7877185,2.3881493,,,,,,,,,,,,,, -136000,1.8137026,3.3744683,,,,,,,,,,,,,, -136100,1.9841547,2.3902714,,,,,,,,,,,,,, -136200,2.0393944,1.8646553,,,,,,,,,,,,,, -136300,2.1624377,1.8916457,,,,,,,,,,,,,, -136400,2.237017,1.9031236,,,,,,,,,,,,,, -136500,2.2305892,1.8206995,,,,,,,,,,,,,, -136587,,,0.7864648103713989,0.8539221286773682,0.715499997138977,1.170691967010498,50000.0,0.5918000340461731,1.7970510721206665,10000.0,63479.35899710655,70449.76348471642,63479.35899710655,6955.858748197556,6.841918230056763,0.0 -136600,1.8352135,3.7399142,,,,,,,,,,,,,, -136700,2.0850606,1.7271463,,,,,,,,,,,,,, -136800,1.8357885,1.9538516,,,,,,,,,,,,,, -136900,2.0106664,2.1845343,,,,,,,,,,,,,, -137000,2.0085015,1.7610587,,,,,,,,,,,,,, -137100,1.9066293,3.2854092,,,,,,,,,,,,,, -137200,2.0681152,1.769687,,,,,,,,,,,,,, -137300,1.8071618,2.980956,,,,,,,,,,,,,, -137400,1.8705885,1.725288,,,,,,,,,,,,,, -137494,,,0.79066401720047,0.8283480405807495,0.7112599611282349,1.1805391311645508,50000.0,0.5878000259399414,1.812727212905884,10000.0,63899.612374305725,70919.30957722664,63899.612374305725,7005.048368930817,6.893892765045166,0.0 -137500,2.00044,1.8165059,,,,,,,,,,,,,, -137600,1.9474279,1.9739189,,,,,,,,,,,,,, -137700,2.190761,3.8201377,,,,,,,,,,,,,, -137800,1.9155637,1.8392313,,,,,,,,,,,,,, -137900,1.9374299,3.2042594,,,,,,,,,,,,,, -138000,2.006053,2.2534795,,,,,,,,,,,,,, -138100,1.9784383,2.8845243,,,,,,,,,,,,,, -138200,2.0770333,3.8930635,,,,,,,,,,,,,, -138300,1.9573349,1.7873352,,,,,,,,,,,,,, -138400,,,0.7880859375,0.8579350709915161,0.7173199653625488,1.1725926399230957,50000.0,0.5919000506401062,1.8064581155776973,10000.0,64319.660131692886,71387.13676571846,64319.660131692886,7052.7244708538055,6.946126461029053,0.0 -138400,1.7993071,3.9962847,,,,,,,,,,,,,, -138500,2.0818403,4.3832054,,,,,,,,,,,,,, -138600,1.9448545,2.2873764,,,,,,,,,,,,,, -138700,1.9671786,2.4079285,,,,,,,,,,,,,, -138800,1.9778349,2.4226563,,,,,,,,,,,,,, -138900,2.0814557,1.8352844,,,,,,,,,,,,,, -139000,2.045182,3.6582375,,,,,,,,,,,,,, -139100,2.0601535,1.7515815,,,,,,,,,,,,,, -139200,2.1041653,1.898985,,,,,,,,,,,,,, -139300,2.5160148,2.086749,,,,,,,,,,,,,, -139301,,,0.7901171445846558,0.8141428828239441,0.7179799675941467,1.1384248733520508,50000.0,0.5958999991416931,1.7701454162597656,10000.0,64739.67565011978,71854.72164583206,64739.67565011978,7100.193810462952,6.996117830276489,0.0 -139400,1.924054,4.013544,,,,,,,,,,,,,, -139500,2.4442894,3.600393,,,,,,,,,,,,,, -139600,2.263777,1.7393968,,,,,,,,,,,,,, -139700,2.2988245,1.9417294,,,,,,,,,,,,,, -139800,2.1328053,1.7709404,,,,,,,,,,,,,, -139900,2.1047168,1.5998921,,,,,,,,,,,,,, -140000,2.040499,1.6788839,,,,,,,,,,,,,, -140100,2.0108242,2.085488,,,,,,,,,,,,,, -140200,1.9287641,2.548304,,,,,,,,,,,,,, -140204,,,0.8003124594688416,0.7800408601760864,0.717739999294281,1.1377885341644287,50000.0,0.5996000170707703,1.7611976861953735,10000.0,65159.74625015259,72324.3936612606,65159.74625015259,7149.690311908722,7.050333499908447,0.0 -140300,2.1763594,1.7938372,,,,,,,,,,,,,, -140400,1.8666023,2.277455,,,,,,,,,,,,,, -140500,2.0062523,3.344007,,,,,,,,,,,,,, -140600,2.3379843,1.7237843,,,,,,,,,,,,,, -140700,1.9030099,1.6677746,,,,,,,,,,,,,, -140800,2.2709858,1.7691935,,,,,,,,,,,,,, -140900,1.9167044,1.7021651,,,,,,,,,,,,,, -141000,2.217843,2.1963124,,,,,,,,,,,,,, -141100,2.0963817,4.098794,,,,,,,,,,,,,, -141109,,,0.7925195097923279,0.8001589179039001,0.7177000045776367,1.1277657747268677,50000.0,0.5931000113487244,1.7584959268569946,10000.0,65579.95306015015,72789.65150928497,65579.95306015015,7194.6404457092285,7.10044264793396,0.0 -141200,1.8133245,3.1000867,,,,,,,,,,,,,, -141300,2.0210586,3.843813,,,,,,,,,,,,,, -141400,2.1175363,2.5355654,,,,,,,,,,,,,, -141500,2.000119,2.126278,,,,,,,,,,,,,, -141600,2.316611,3.6823425,,,,,,,,,,,,,, -141700,2.1167264,3.76085,,,,,,,,,,,,,, -141800,2.0753357,1.6392962,,,,,,,,,,,,,, -141900,2.30667,4.305166,,,,,,,,,,,,,, -142000,2.2773757,1.7724237,,,,,,,,,,,,,, -142014,,,0.8015038967132568,0.7711244821548462,0.722599983215332,1.1164039373397827,50000.0,0.6005000472068787,1.7380754947662354,10000.0,65999.96157240868,73259.56003165245,65999.96157240868,7244.440073251724,7.150071620941162,0.0 -142100,1.9539695,2.8376288,,,,,,,,,,,,,, -142200,2.0727513,1.7477205,,,,,,,,,,,,,, -142300,2.001568,3.0860577,,,,,,,,,,,,,, -142400,1.9314666,1.8925124,,,,,,,,,,,,,, -142500,2.1218255,1.5268779,,,,,,,,,,,,,, -142600,2.310873,1.7571738,,,,,,,,,,,,,, -142700,1.9853195,2.2146876,,,,,,,,,,,,,, -142800,2.1163726,2.80634,,,,,,,,,,,,,, -142900,2.4252334,1.802265,,,,,,,,,,,,,, -142915,,,0.8013671636581421,0.7719994187355042,0.721340000629425,1.121065616607666,50000.0,0.5999000072479248,1.7464487552642822,10000.0,66420.08053159714,73728.358700037,66420.08053159714,7293.018522024155,7.19990348815918,0.0 -143000,2.134498,3.928156,,,,,,,,,,,,,, -143100,2.1243062,1.6326526,,,,,,,,,,,,,, -143200,2.1982892,1.7625033,,,,,,,,,,,,,, -143300,2.181384,4.134807,,,,,,,,,,,,,, -143400,2.2625694,1.7926991,,,,,,,,,,,,,, -143500,2.1752157,1.8548453,,,,,,,,,,,,,, -143600,2.1225173,2.357581,,,,,,,,,,,,,, -143700,1.9908874,2.8643062,,,,,,,,,,,,,, -143800,2.7078576,1.7668232,,,,,,,,,,,,,, -143819,,,0.8026562333106995,0.7779040932655334,0.7221399545669556,1.1373088359832764,50000.0,0.5993000268936157,1.7647473812103271,10000.0,66840.15920996666,74196.6847281456,66840.15920996666,7341.1628839969635,7.251514196395874,0.0 -143900,1.8852503,2.9991887,,,,,,,,,,,,,, -144000,2.478135,1.6429592,,,,,,,,,,,,,, -144100,1.9453684,3.929772,,,,,,,,,,,,,, -144200,3.3400862,1.780128,,,,,,,,,,,,,, -144300,2.146013,1.6257441,,,,,,,,,,,,,, -144400,2.1876822,1.7342533,,,,,,,,,,,,,, -144500,2.3140254,1.7428164,,,,,,,,,,,,,, -144600,2.4542525,1.7342339,,,,,,,,,,,,,, -144700,2.2732072,1.6885902,,,,,,,,,,,,,, -144724,,,0.7992382645606995,0.7773339748382568,0.7215799689292908,1.1161545515060425,50000.0,0.5981000065803528,1.7557260990142822,10000.0,67260.12916207314,74660.76443743706,67260.12916207314,7385.170194864273,7.303040027618408,0.0 -144800,2.2848423,2.2700086,,,,,,,,,,,,,, -144900,2.4284098,1.7074504,,,,,,,,,,,,,, -145000,2.055779,2.7483292,,,,,,,,,,,,,, -145100,2.4290984,1.8958269,,,,,,,,,,,,,, -145200,2.312082,1.6871952,,,,,,,,,,,,,, -145300,2.0101004,1.6586974,,,,,,,,,,,,,, -145400,2.171569,1.7260661,,,,,,,,,,,,,, -145500,1.9792542,1.5981543,,,,,,,,,,,,,, -145600,2.0794618,3.8606257,,,,,,,,,,,,,, -145628,,,0.80517578125,0.7599520683288574,0.7254599928855896,1.1081088781356812,50000.0,0.6051000356674194,1.739145040512085,10000.0,67680.05412721634,75128.43763279915,67680.05412721634,7432.81751871109,7.352954149246216,0.0 -145700,2.2456276,1.5612638,,,,,,,,,,,,,, -145800,2.1708639,3.065325,,,,,,,,,,,,,, -145900,2.0977526,3.5032187,,,,,,,,,,,,,, -146000,2.2674632,1.8636591,,,,,,,,,,,,,, -146100,2.2788248,1.6885289,,,,,,,,,,,,,, -146200,2.3135202,1.6204964,,,,,,,,,,,,,, -146300,2.1387484,2.036123,,,,,,,,,,,,,, -146400,1.9863898,2.4657614,,,,,,,,,,,,,, -146500,2.2777941,4.200493,,,,,,,,,,,,,, -146535,,,0.8170703053474426,0.7173280119895935,0.7281399965286255,1.1051756143569946,50000.0,0.6033000349998474,1.7289841175079346,10000.0,68100.3693766594,75593.36808228493,68100.3693766594,7477.329308271408,7.4053473472595215,0.0 -146600,2.6657186,1.7584391,,,,,,,,,,,,,, -146700,2.565392,1.6969187,,,,,,,,,,,,,, -146800,2.517545,4.198935,,,,,,,,,,,,,, -146900,2.2137616,2.104578,,,,,,,,,,,,,, -147000,2.2298794,1.6569513,,,,,,,,,,,,,, -147100,2.3407958,1.7619973,,,,,,,,,,,,,, -147200,2.2128901,3.260829,,,,,,,,,,,,,, -147300,2.6886034,1.6158122,,,,,,,,,,,,,, -147400,2.0830445,3.065189,,,,,,,,,,,,,, -147440,,,0.8044531345367432,0.7989120483398438,0.7260000109672546,1.1396101713180542,50000.0,0.601900041103363,1.774980902671814,10000.0,68520.51922082901,76059.94834542274,68520.51922082901,7523.654905796051,7.458659648895264,0.0 -147500,2.2162173,2.6078768,,,,,,,,,,,,,, -147600,2.2562037,1.6628529,,,,,,,,,,,,,, -147700,2.179348,2.063851,,,,,,,,,,,,,, -147800,2.2798781,1.6491771,,,,,,,,,,,,,, -147900,2.5242596,1.6377772,,,,,,,,,,,,,, -148000,2.999257,1.7017988,,,,,,,,,,,,,, -148100,2.333249,1.9752356,,,,,,,,,,,,,, -148200,2.2583046,1.7916104,,,,,,,,,,,,,, -148300,2.3690767,2.005187,,,,,,,,,,,,,, -148347,,,0.8073046803474426,0.7472729086875916,0.7283799648284912,1.0935417413711548,50000.0,0.6075000166893005,1.727489352226257,10000.0,68940.56979131699,76524.28863739967,68940.56979131699,7567.843139410019,7.50982141494751,0.0 -148400,2.1995902,3.1363673,,,,,,,,,,,,,, -148500,2.2661934,3.5681057,,,,,,,,,,,,,, -148600,2.2274292,1.6388726,,,,,,,,,,,,,, -148700,2.4052122,1.5416605,,,,,,,,,,,,,, -148800,2.2030904,4.0728674,,,,,,,,,,,,,, -148900,2.0785992,1.8576741,,,,,,,,,,,,,, -149000,2.7509224,1.5984472,,,,,,,,,,,,,, -149100,2.4980237,1.6604946,,,,,,,,,,,,,, -149200,2.4690864,1.6006365,,,,,,,,,,,,,, -149250,,,0.81214839220047,0.7340012192726135,0.7254999876022339,1.1147717237472534,50000.0,0.6032000184059143,1.7510703802108765,10000.0,69360.65314507484,76988.09020090103,69360.65314507484,7611.454800367355,7.564750909805298,0.0 -149300,2.314749,2.2772498,,,,,,,,,,,,,, -149400,2.3720603,4.075362,,,,,,,,,,,,,, -149500,2.3113382,1.4804351,,,,,,,,,,,,,, -149600,2.2838511,2.5769124,,,,,,,,,,,,,, -149700,2.1730473,2.9253688,,,,,,,,,,,,,, -149800,3.0127652,1.7428921,,,,,,,,,,,,,, -149900,2.4621475,1.6205372,,,,,,,,,,,,,, -150000,2.2417672,1.7181655,,,,,,,,,,,,,, -150100,2.462935,2.8600104,,,,,,,,,,,,,, -150156,,,0.8089648485183716,0.7477630972862244,0.7283399701118469,1.0934544801712036,50000.0,0.6111000180244446,1.714582920074463,10000.0,69780.64332485199,77452.66871452332,69780.64332485199,7655.94317984581,7.613767862319946,0.0 -150200,2.3514504,1.6644264,,,,,,,,,,,,,, -150300,2.4512527,3.7000039,,,,,,,,,,,,,, -150400,2.2135425,3.147434,,,,,,,,,,,,,, -150500,2.3896897,1.7203,,,,,,,,,,,,,, -150600,2.0731025,1.7465297,,,,,,,,,,,,,, -150700,2.2075307,3.8411222,,,,,,,,,,,,,, -150800,2.2208128,2.5070827,,,,,,,,,,,,,, -150900,2.2399032,1.5432843,,,,,,,,,,,,,, -151000,2.4350984,1.5411432,,,,,,,,,,,,,, -151062,,,0.8106249570846558,0.7664503455162048,0.7301999926567078,1.119113564491272,50000.0,0.6077000498771667,1.7483646869659424,10000.0,70200.61243200302,77918.21200990677,70200.61243200302,7701.411295890808,7.668791770935059,0.0 -151100,2.0199637,1.7466424,,,,,,,,,,,,,, -151200,2.4028013,1.5583371,,,,,,,,,,,,,, -151300,2.2076705,3.82895,,,,,,,,,,,,,, -151400,2.2418048,3.415556,,,,,,,,,,,,,, -151500,2.464292,2.730569,,,,,,,,,,,,,, -151600,2.6375012,1.6763319,,,,,,,,,,,,,, -151700,2.9318497,1.6712846,,,,,,,,,,,,,, -151800,2.3460164,2.2001684,,,,,,,,,,,,,, -151900,2.3805697,1.5581039,,,,,,,,,,,,,, -151966,,,0.8160937428474426,0.7283310294151306,0.7328400015830994,1.0975955724716189,50000.0,0.6164000034332275,1.7178739309310913,10000.0,70620.5970556736,78385.43194818497,70620.5970556736,7748.537898540497,7.726431369781494,0.0 -152000,2.2288208,1.6411141,,,,,,,,,,,,,, -152100,2.4604468,1.7317382,,,,,,,,,,,,,, -152200,2.85526,1.6464101,,,,,,,,,,,,,, -152300,2.2626662,3.1501617,,,,,,,,,,,,,, -152400,2.3555763,1.7965872,,,,,,,,,,,,,, -152500,2.5461705,1.8497858,,,,,,,,,,,,,, -152600,2.617598,2.0191987,,,,,,,,,,,,,, -152700,2.3265605,1.5202698,,,,,,,,,,,,,, -152800,2.324989,1.6929034,,,,,,,,,,,,,, -152872,,,0.8132226467132568,0.7306903004646301,0.7350199818611145,1.078985333442688,50000.0,0.6099000573158264,1.705723762512207,10000.0,71040.94093084335,78848.79572200775,71040.94093084335,7791.45180106163,7.7803356647491455,0.0 -152900,2.475632,3.686675,,,,,,,,,,,,,, -153000,2.2408652,2.9613411,,,,,,,,,,,,,, -153100,2.4255474,3.4870448,,,,,,,,,,,,,, -153200,2.4715025,1.7375029,,,,,,,,,,,,,, -153300,2.61987,2.2249002,,,,,,,,,,,,,, -153400,2.4728174,1.5294511,,,,,,,,,,,,,, -153500,2.3386028,2.74134,,,,,,,,,,,,,, -153600,2.319334,1.6270115,,,,,,,,,,,,,, -153700,2.4618957,1.6643771,,,,,,,,,,,,,, -153775,,,0.81068354845047,0.7280431389808655,0.7346599698066711,1.074668526649475,50000.0,0.6076000332832336,1.7089613676071167,10000.0,71460.94529819489,79316.3397808075,71460.94529819489,7838.883292198181,7.837963819503784,0.0 -153800,2.951817,2.6042037,,,,,,,,,,,,,, -153900,2.820663,1.4994283,,,,,,,,,,,,,, -154000,2.4917653,1.5615752,,,,,,,,,,,,,, -154100,2.4932866,4.00926,,,,,,,,,,,,,, -154200,2.4831262,1.7695425,,,,,,,,,,,,,, -154300,2.3717449,1.9661512,,,,,,,,,,,,,, -154400,2.7731633,1.5140263,,,,,,,,,,,,,, -154500,2.352212,2.3309562,,,,,,,,,,,,,, -154600,2.58267,1.5641872,,,,,,,,,,,,,, -154682,,,0.8229491710662842,0.6811491250991821,0.7339999675750732,1.0581152439117432,50000.0,0.6127000451087952,1.691803216934204,10000.0,71881.10067653656,79785.20401096344,71881.10067653656,7887.490196943283,7.888981103897095,0.0 -154700,2.5415308,1.5967637,,,,,,,,,,,,,, -154800,2.3001301,1.8472276,,,,,,,,,,,,,, -154900,2.6368313,1.7452952,,,,,,,,,,,,,, -155000,2.6931958,2.265443,,,,,,,,,,,,,, -155100,2.4858959,1.7696066,,,,,,,,,,,,,, -155200,2.4815347,1.6774534,,,,,,,,,,,,,, -155300,2.4385302,2.216432,,,,,,,,,,,,,, -155400,2.5246065,1.654176,,,,,,,,,,,,,, -155500,2.429861,4.0781674,,,,,,,,,,,,,, -155586,,,0.8204687237739563,0.7243859767913818,0.7372999787330627,1.0768547058105469,50000.0,0.6158000230789185,1.7009111642837524,10000.0,72301.3064084053,80253.32729268074,72301.3064084053,7935.300303936005,7.945384502410889,0.0 -155600,2.5629375,1.6359369,,,,,,,,,,,,,, -155700,2.3462574,2.4222407,,,,,,,,,,,,,, -155800,2.3677278,1.5920107,,,,,,,,,,,,,, -155900,2.5875647,2.1310327,,,,,,,,,,,,,, -156000,2.6161618,1.7378609,,,,,,,,,,,,,, -156100,2.627565,1.7959478,,,,,,,,,,,,,, -156200,2.573605,1.6322337,,,,,,,,,,,,,, -156300,2.4265163,3.384339,,,,,,,,,,,,,, -156400,2.7916648,4.086506,,,,,,,,,,,,,, -156487,,,0.8235155940055847,0.6817896962165833,0.7389400005340576,1.045330047607422,50000.0,0.615600049495697,1.6857863664627075,10000.0,72721.32059788704,80723.79385137558,72721.32059788704,7985.649563550949,7.997928142547607,0.0 -156500,2.3872864,1.5607367,,,,,,,,,,,,,, -156600,2.6116307,1.7505845,,,,,,,,,,,,,, -156700,2.6919775,1.6061988,,,,,,,,,,,,,, -156800,2.8755796,1.5779866,,,,,,,,,,,,,, -156900,2.349486,1.6882004,,,,,,,,,,,,,, -157000,2.60193,1.4213177,,,,,,,,,,,,,, -157100,2.406311,1.513712,,,,,,,,,,,,,, -157200,2.2130442,3.0293107,,,,,,,,,,,,,, -157300,2.6491468,1.5480964,,,,,,,,,,,,,, -157392,,,0.8236913681030273,0.6922850012779236,0.7382199764251709,1.066370725631714,50000.0,0.6150000095367432,1.6935813426971436,10000.0,73141.60916376114,81188.90914106369,73141.60916376114,8030.373905658722,8.049262046813965,0.0 -157400,2.59133,1.5178617,,,,,,,,,,,,,, -157500,2.9648824,1.473556,,,,,,,,,,,,,, -157600,2.5849419,1.5368751,,,,,,,,,,,,,, -157700,2.534377,1.5114127,,,,,,,,,,,,,, -157800,2.2928226,1.8482615,,,,,,,,,,,,,, -157900,2.6553586,1.596497,,,,,,,,,,,,,, -158000,2.6553364,1.5713406,,,,,,,,,,,,,, -158100,2.801177,1.5560515,,,,,,,,,,,,,, -158200,2.4296958,3.1540487,,,,,,,,,,,,,, -158297,,,0.8206835985183716,0.6895581483840942,0.7396399974822998,1.041470289230347,50000.0,0.6205000281333923,1.6642597913742063,10000.0,73561.83675098419,81657.4282822609,73561.83675098419,8078.557241678238,8.106578588485718,0.0 -158300,2.4879646,1.7304358,,,,,,,,,,,,,, -158400,2.6198456,1.6917108,,,,,,,,,,,,,, -158500,2.4512637,1.8243501,,,,,,,,,,,,,, -158600,2.502944,3.7485943,,,,,,,,,,,,,, -158700,2.335186,2.818204,,,,,,,,,,,,,, -158800,2.3495703,3.5785334,,,,,,,,,,,,,, -158900,2.5490153,2.5048635,,,,,,,,,,,,,, -159000,2.5020673,1.6320056,,,,,,,,,,,,,, -159100,2.5510213,2.059548,,,,,,,,,,,,,, -159200,2.7404933,2.5910025,,,,,,,,,,,,,, -159204,,,0.82533198595047,0.7197791934013367,0.7388399839401245,1.0863151550292969,50000.0,0.6186000108718872,1.7149657011032104,10000.0,73981.86839866638,82122.87282919884,73981.86839866638,8123.858737468719,8.167153358459473,0.0 -159300,2.4242892,1.652343,,,,,,,,,,,,,, -159400,2.5939693,1.6721592,,,,,,,,,,,,,, -159500,2.8064997,1.474087,,,,,,,,,,,,,, -159600,2.5467112,3.2160606,,,,,,,,,,,,,, -159700,2.329447,3.2289855,,,,,,,,,,,,,, -159800,2.453354,1.5151434,,,,,,,,,,,,,, -159900,2.4878109,2.8046093,,,,,,,,,,,,,, -160000,2.443356,3.6749115,,,,,,,,,,,,,, -160100,2.4448717,3.0969846,,,,,,,,,,,,,, -160110,,,0.8278124928474426,0.6878633499145508,0.7412199974060059,1.0593788623809814,50000.0,0.6185000538825989,1.684755802154541,10000.0,74401.79539680481,82587.42238020897,74401.79539680481,8168.378638029098,8.218732833862305,0.0 -160200,2.866104,3.3153915,,,,,,,,,,,,,, -160300,2.4861944,3.887039,,,,,,,,,,,,,, -160400,2.7260284,1.596635,,,,,,,,,,,,,, -160500,2.4328759,1.3459219,,,,,,,,,,,,,, -160600,2.5609152,3.6953132,,,,,,,,,,,,,, -160700,2.8033679,1.6300881,,,,,,,,,,,,,, -160800,3.1448376,3.338777,,,,,,,,,,,,,, -160900,2.580375,1.8355013,,,,,,,,,,,,,, -161000,2.8331892,1.9931189,,,,,,,,,,,,,, -161015,,,0.82972651720047,0.6675693392753601,0.7438399791717529,1.0392054319381714,50000.0,0.6220000386238098,1.668728590011597,10000.0,74821.9603896141,83056.85979914665,74821.9603896141,8217.548476457596,8.271214008331299,0.0 -161100,2.3862686,1.4645939,,,,,,,,,,,,,, -161200,2.6492212,1.5305679,,,,,,,,,,,,,, -161300,2.4472044,1.5873803,,,,,,,,,,,,,, -161400,2.680604,2.2315497,,,,,,,,,,,,,, -161500,2.364977,2.8146904,,,,,,,,,,,,,, -161600,2.7057214,1.5354612,,,,,,,,,,,,,, -161700,3.020027,1.5081874,,,,,,,,,,,,,, -161800,2.54242,1.4002428,,,,,,,,,,,,,, -161900,2.727247,3.7463105,,,,,,,,,,,,,, -161921,,,0.8286718726158142,0.672528088092804,0.7439000010490417,1.04596745967865,50000.0,0.6206000447273254,1.6684740781784058,10000.0,75242.30540776253,83524.30050992966,75242.30540776253,8264.541722774506,8.323035478591919,0.0 -162000,2.6415093,1.6487311,,,,,,,,,,,,,, -162100,2.7772079,3.9234717,,,,,,,,,,,,,, -162200,2.8710334,1.6781969,,,,,,,,,,,,,, -162300,2.9306502,1.4456129,,,,,,,,,,,,,, -162400,2.6138837,1.4481039,,,,,,,,,,,,,, -162500,2.499027,1.5483769,,,,,,,,,,,,,, -162600,2.7141876,1.6215271,,,,,,,,,,,,,, -162700,2.7048073,1.4659629,,,,,,,,,,,,,, -162800,2.623567,3.6754787,,,,,,,,,,,,,, -162826,,,0.8308398127555847,0.6795992851257324,0.743399977684021,1.0521153211593628,50000.0,0.6234000325202942,1.676736831665039,10000.0,75662.24251437187,83989.64229750633,75662.24251437187,8309.82877087593,8.38948941230774,0.0 -162900,2.5934088,3.1683888,,,,,,,,,,,,,, -163000,2.6659868,1.5362679,,,,,,,,,,,,,, -163100,2.7985964,2.524224,,,,,,,,,,,,,, -163200,2.6645005,1.4611084,,,,,,,,,,,,,, -163300,2.713196,1.59849,,,,,,,,,,,,,, -163400,2.5763364,1.6789249,,,,,,,,,,,,,, -163500,3.0339208,3.0908573,,,,,,,,,,,,,, -163600,2.8945649,3.5836556,,,,,,,,,,,,,, -163700,2.5748737,2.5499945,,,,,,,,,,,,,, -163733,,,0.8407226204872131,0.6239649653434753,0.7476399540901184,1.0164140462875366,50000.0,0.6276000142097473,1.6446024179458618,10000.0,76082.43580412865,84458.79799222946,76082.43580412865,8358.68628025055,8.442897319793701,0.0 -163800,2.5643587,1.5431798,,,,,,,,,,,,,, -163900,2.862733,1.5472476,,,,,,,,,,,,,, -164000,2.7156475,3.161572,,,,,,,,,,,,,, -164100,2.9359343,4.0380144,,,,,,,,,,,,,, -164200,2.8161724,3.909131,,,,,,,,,,,,,, -164300,2.6028411,3.5557663,,,,,,,,,,,,,, -164400,2.9105148,1.6674281,,,,,,,,,,,,,, -164500,2.7927504,1.4693933,,,,,,,,,,,,,, -164600,2.7456858,1.4720645,,,,,,,,,,,,,, -164639,,,0.8356835842132568,0.6445589065551758,0.7468599677085876,1.0233219861984253,50000.0,0.6267000436782837,1.6496158838272097,10000.0,76502.39468336105,84926.75717353821,76502.39468336105,8406.580387592316,8.498199224472046,0.0 -164700,2.712297,3.2963436,,,,,,,,,,,,,, -164800,2.7020025,2.157304,,,,,,,,,,,,,, -164900,2.5650668,1.8063874,,,,,,,,,,,,,, -165000,2.5914743,1.5693761,,,,,,,,,,,,,, -165100,2.713981,3.0876522,,,,,,,,,,,,,, -165200,2.727073,1.6616851,,,,,,,,,,,,,, -165300,3.2269084,1.8282915,,,,,,,,,,,,,, -165400,2.6678607,1.5579872,,,,,,,,,,,,,, -165500,2.626661,2.7488294,,,,,,,,,,,,,, -165546,,,0.8334375023841858,0.6440165638923645,0.7469399571418762,1.0224330425262451,50000.0,0.6241000294685364,1.640983819961548,10000.0,76922.57128500938,85391.76311731339,76922.57128500938,8451.303115844727,8.553112983703613,0.0 -165600,2.6849782,3.9360108,,,,,,,,,,,,,, -165700,2.8948543,1.5247092,,,,,,,,,,,,,, -165800,2.8603039,1.5037457,,,,,,,,,,,,,, -165900,2.885876,2.0013957,,,,,,,,,,,,,, -166000,2.6745927,1.7860904,,,,,,,,,,,,,, -166100,3.0869808,2.1446195,,,,,,,,,,,,,, -166200,2.6149948,1.819627,,,,,,,,,,,,,, -166300,2.927298,1.7149631,,,,,,,,,,,,,, -166400,2.9979203,2.1705992,,,,,,,,,,,,,, -166451,,,0.8373242020606995,0.6317007541656494,0.7473799586296082,1.024543523788452,50000.0,0.6309000253677368,1.6426352262496948,10000.0,77342.91505241394,85861.23006987572,77342.91505241394,8500.320755720139,8.607258081436157,0.0 -166500,2.751465,1.5014408,,,,,,,,,,,,,, -166600,2.7228882,2.9765782,,,,,,,,,,,,,, -166700,2.7619162,1.9201633,,,,,,,,,,,,,, -166800,2.916421,1.5794451,,,,,,,,,,,,,, -166841,,,,,,,,,,,77520.45635008812,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 2e5cfe82c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -178.22711658477783,0.0,75.61166572570801,1,0,75.61166572570801,31.23744,2472,1.10235004976337,253.83884978294373,32.12085,1.3685577743997668,31.08705,5348,1.0585651254622166 -283.8664665222168,0.0405006408691406,1515.929588317871,1844,0,1515.929588317871,6.2753453,2472,0.899579550301627,1799.9081389904022,6.1664042,0.944635537887994,6.2940016,5348,0.8966179750330672 -413.2503607273102,0.0917296409606933,2956.97966003418,3701,0,2956.97966003418,2.4001753,2472,0.5354335506672354,3370.4715399742126,2.3516572,0.5535412726740004,2.7404253,5348,0.5721250856850459 -544.7035682201385,0.1459474563598632,4397.298481225967,5551,0,4397.298481225967,0.78583676,2472,0.2550118822740844,4942.37772154808,0.7146427,0.2437602458476314,1.0680373,5348,0.3108122459619413 -676.2148461341858,0.1972596645355224,5837.907940626144,7377,0,5837.907940626144,0.5604038,2472,0.1868055978713464,6514.627334594727,0.5079412,0.1768737760370304,0.821356,5348,0.2468405147861011 -808.0982096195221,0.2487895488739013,7277.963893651962,9199,0,7277.963893651962,0.48268652,2472,0.1639956939451181,8086.693351507187,0.43121392,0.1545967513970732,0.7380511,5348,0.222839047278836 -941.0179760456084,0.3003580570220947,8718.414444446564,11020,0,8718.414444446564,0.4419891,2472,0.1509759714012958,9660.191576719284,0.41369617,0.1462302148294754,0.6845674,5348,0.2078839896888305 -1075.3007299900055,0.3552329540252685,10159.513425588608,12857,0,10159.513425588608,0.402794,2472,0.1361485182702659,11235.706638336182,0.35117,0.1268904345606534,0.6366001,5348,0.1944157486700715 -1209.0333228111267,0.4085783958435058,11599.84973692894,14670,0,11599.84973692894,0.38013542,2472,0.1297909938455913,12809.905858516691,0.301788,0.1133882836267193,0.6053062,5348,0.1848866060998098 -1342.5781121253967,0.4630181789398193,13039.76072025299,16457,0,13039.76072025299,0.36317125,2472,0.1238600125931793,14383.49019765854,0.27452624,0.1010592172113486,0.5743748,5348,0.1744981994072043 -1475.30983710289,0.512258768081665,14480.069234132769,18251,0,14480.069234132769,0.34133852,2472,0.1155119533646131,15956.655859947205,0.2666409,0.1004896308730231,0.5565517,5348,0.1694584705098622 -1608.389491558075,0.5656495094299316,15920.46349310875,20072,0,15920.46349310875,0.3296891,2472,0.113358925923669,17530.263286590576,0.27632526,0.0999444723551653,0.53343123,5348,0.1639939368778783 -1741.165496110916,0.6244997978210449,17360.60667347908,21871,0,17360.60667347908,0.3216067,2472,0.1086263278695184,19103.31728196144,0.2690903,0.0988207982694489,0.52979755,5348,0.1603830966334224 -1875.9892790317533,0.6764576435089111,18800.65876030922,23650,0,18800.65876030922,0.3166944,2472,0.1070420246582576,20678.319122314453,0.2543325,0.0942792468360942,0.51538944,5348,0.1583170008785734 -2007.8673095703125,0.731112003326416,20241.77404141426,25434,0,20241.77404141426,0.30282408,2472,0.1014157171003189,22251.44185328484,0.23121329,0.0862822515941049,0.4990429,5348,0.1501684736958977 -2141.1486024856567,0.7904298305511475,21681.949670553207,27251,0,21681.949670553207,0.29341123,2472,0.0982471106777974,23825.03625488281,0.21819447,0.082224366832729,0.4925936,5348,0.1478803209206677 -2275.6245658397675,0.8445065021514893,23122.05716252327,29050,0,23122.05716252327,0.28730962,2472,0.0981049296203765,25399.74922537804,0.21523209,0.0802436305243886,0.48609242,5348,0.146576942757562 -2408.1328547000885,0.8940913677215576,24562.581042051315,30830,0,24562.581042051315,0.27968472,2472,0.0959519021794324,26972.9074883461,0.22447392,0.0838804385508753,0.46955344,5348,0.1418944360234415 -2539.6788890361786,0.9472103118896484,26002.89443707466,32621,0,26002.89443707466,0.27777946,2472,0.0946316495033818,28544.89704513549,0.21071382,0.0778517517729342,0.4662035,5348,0.1414020487173793 -2673.288696050644,1.005300521850586,27443.38892006874,34442,0,27443.38892006874,0.26703775,2472,0.0911989925456502,30119.13695454597,0.21640244,0.0787028237675233,0.45866513,5348,0.1378974096565839 -2805.333705663681,1.062840461730957,28883.798013210297,36233,0,28883.798013210297,0.2658482,2472,0.0884772408750228,31691.72387599945,0.21171698,0.0761426729328516,0.4551876,5348,0.1361595721057763 -2939.579090833664,1.119497299194336,30324.08151173592,38013,0,30324.08151173592,0.25705752,2472,0.0859586050007109,33266.3851544857,0.16391788,0.0609964843229503,0.43490544,5348,0.1313708641879954 -3072.89638376236,1.1725895404815674,31764.34447383881,39801,0,31764.34447383881,0.25120482,2472,0.0843743017894501,34840.09559392929,0.1884352,0.0688739414162319,0.43504775,5348,0.1308591675758131 -3207.057755947113,1.2265305519104004,33204.71107053757,41617,0,33204.71107053757,0.24482445,2472,0.0829728027948733,36414.75404071808,0.23098333,0.0850767236376062,0.4273046,5348,0.1274800390047983 -3338.9981787204742,1.2806153297424316,34644.79412436485,43396,0,34644.79412436485,0.23914756,2472,0.0816119269595596,37986.90700483322,0.22959857,0.0844912483276617,0.42108253,5348,0.1264373364743138 -3470.2350482940674,1.3378307819366455,36085.03482174873,45174,0,36085.03482174873,0.23407225,2472,0.0779152194666179,39558.51782393456,0.26250243,0.0966651262642511,0.40786082,5348,0.122614093862537 -3602.4942483901978,1.3967421054840088,37525.23612308502,46972,0,37525.23612308502,0.22878855,2472,0.0766762131090934,41131.11450552941,0.22550887,0.0816548816890296,0.39984173,5348,0.1198914816996051 -3734.108068704605,1.460125207901001,38965.76935267448,48778,0,38965.76935267448,0.22130717,2472,0.0745231856681494,42703.40314888954,0.20397122,0.0767817576520102,0.3939008,5348,0.1173136893325738 -3866.873418092728,1.5228583812713623,40405.70114803314,50551,0,40405.70114803314,0.21482898,2472,0.0720857961123636,44276.238609075546,0.1726337,0.0656287724716962,0.38651025,5348,0.1130656419861552 -3999.761225223541,1.5781021118164062,41845.93539810181,52335,0,41845.93539810181,0.21250953,2472,0.0727357666605731,45849.49245977402,0.19079773,0.0716679248562133,0.37963584,5348,0.1131428792106355 -4131.691259860992,1.639273166656494,43285.85983061791,54147,0,43285.85983061791,0.20436308,2472,0.0686734507342636,47421.48486161232,0.16157928,0.0621564164232949,0.36532888,5348,0.1073983606399104 -4263.011621952057,1.696920394897461,44725.77019500733,55961,0,44725.77019500733,0.19516303,2472,0.0657892064265838,48992.8511402607,0.15674654,0.0598292452727844,0.35715374,5348,0.1051198625177404 -4394.548684120178,1.7538185119628906,46166.32611012459,57741,0,46166.32611012459,0.19236514,2472,0.063575244246745,50565.07560944557,0.14496757,0.0550945479641131,0.3530484,5348,0.1040964692933759 -4525.326953172684,1.8184218406677248,47606.610372543335,59535,0,47606.610372543335,0.1857698,2472,0.0610566083724331,52136.28013706207,0.1456565,0.0555097695820885,0.3391099,5348,0.1003890825183197 -4658.723962068558,1.8841652870178225,49047.04236245155,61341,0,49047.04236245155,0.1807087,2472,0.0595941746389616,53710.25221323967,0.1419333,0.0538886287151896,0.33404234,5348,0.0982457495389903 -4790.242144107819,1.94150710105896,50487.46106958389,63155,0,50487.46106958389,0.17809816,2472,0.0581520524851217,55282.32294297218,0.11702894,0.045682292016948,0.32601267,5348,0.0950790233352964 -4921.200897216797,1.9962365627288816,51927.87755322456,64934,0,51927.87755322456,0.16898389,2472,0.0554099892348627,56853.82898592949,0.11297734,0.0438991830587376,0.3239654,5348,0.0935728974579298 -5052.923782587051,2.061052083969116,53368.10543203354,66736,0,53368.10543203354,0.16851163,2472,0.0558162208274937,58425.920551776886,0.10090725,0.0388173458911752,0.31437576,5348,0.0914006005194203 -5182.85727763176,2.1215648651123047,54808.36094260216,68543,0,54808.36094260216,0.16452241,2472,0.054231917616233,59996.24846315384,0.09813993,0.0374010032054749,0.3097387,5348,0.089740000193093 -5314.483243465424,2.182187557220459,56248.92697405815,70351,0,56248.92697405815,0.1618215,2472,0.0524851217679198,61568.578528642654,0.08928266,0.0343127389816864,0.30275312,5348,0.0866118926016393 -5446.1889128685,2.245168685913086,57688.89278316498,72130,0,57688.89278316498,0.15875757,2472,0.0512461154103954,63140.387882232666,0.08532207,0.0323987469975537,0.29723024,5348,0.0853857516630139 -5576.509063482285,2.30984878540039,59128.87972784042,73933,0,59128.87972784042,0.15754385,2472,0.0511648690918692,64710.83512163162,0.07010609,0.0270024254691219,0.29594654,5348,0.0836865327244465 -5706.956416845322,2.3700406551361084,60569.61856889725,75739,0,60569.61856889725,0.15539056,2472,0.04945869640281925,66282.15744042397,0.066934444,0.025029311311617473,0.2931396,5348,0.08405340954072815 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index bd19f7c49..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,809 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,33.21241,32.966354,,,,,,,,,,,,,, -1,,,32.12085,1.3685577743997668,31.08705,1.0585651254622166,5348.0,31.23744,1.10235004976337,2472.0,75.61166572570801,253.83884978294373,75.61166572570801,178.22711658477783,0.0,0.0 -100,9.983563,12.461342,,,,,,,,,,,,,, -200,7.734189,8.108964,,,,,,,,,,,,,, -300,1.9566586,5.884061,,,,,,,,,,,,,, -400,0.297594,5.817329,,,,,,,,,,,,,, -500,0.25465888,5.805249,,,,,,,,,,,,,, -600,0.51671726,5.8326507,,,,,,,,,,,,,, -700,0.25856465,5.816505,,,,,,,,,,,,,, -800,0.36578134,5.7941756,,,,,,,,,,,,,, -900,0.310337,5.7866507,,,,,,,,,,,,,, -1000,0.3963677,5.784909,,,,,,,,,,,,,, -1100,0.38422942,5.7926316,,,,,,,,,,,,,, -1200,0.2434211,5.7883587,,,,,,,,,,,,,, -1300,0.32350108,5.7737074,,,,,,,,,,,,,, -1400,0.9222871,5.783597,,,,,,,,,,,,,, -1500,0.29649183,5.790073,,,,,,,,,,,,,, -1600,0.42886877,5.7762556,,,,,,,,,,,,,, -1700,0.58826274,5.740625,,,,,,,,,,,,,, -1800,0.9408768,5.7257366,,,,,,,,,,,,,, -1844,,,6.1664042,0.944635537887994,6.2940016,0.8966179750330672,5348.0,6.2753453,0.899579550301627,2472.0,1515.929588317871,1799.9081389904022,1515.929588317871,283.8664665222168,0.0405006408691406,0.0 -1900,0.42002887,5.564891,,,,,,,,,,,,,, -2000,0.40838405,5.4992003,,,,,,,,,,,,,, -2100,1.1358263,5.445907,,,,,,,,,,,,,, -2200,0.70550215,5.199731,,,,,,,,,,,,,, -2300,0.8540748,4.6469035,,,,,,,,,,,,,, -2400,0.8894241,4.130167,,,,,,,,,,,,,, -2500,1.0272408,3.777485,,,,,,,,,,,,,, -2600,0.90114623,3.6219575,,,,,,,,,,,,,, -2700,1.201091,3.4155712,,,,,,,,,,,,,, -2800,1.0838958,3.2711153,,,,,,,,,,,,,, -2900,1.473424,3.1744847,,,,,,,,,,,,,, -3000,1.4715248,3.1371152,,,,,,,,,,,,,, -3100,0.8342765,2.9878352,,,,,,,,,,,,,, -3200,1.3438548,2.9279299,,,,,,,,,,,,,, -3300,1.6347708,2.8715997,,,,,,,,,,,,,, -3400,0.9641251,2.7868562,,,,,,,,,,,,,, -3500,1.0264078,2.7086475,,,,,,,,,,,,,, -3600,0.8189949,2.7102041,,,,,,,,,,,,,, -3700,0.7207583,2.6528955,,,,,,,,,,,,,, -3701,,,2.3516572,0.5535412726740004,2.7404253,0.5721250856850459,5348.0,2.4001753,0.5354335506672354,2472.0,2956.97966003418,3370.4715399742126,2956.97966003418,413.2503607273102,0.0917296409606933,0.0 -3800,0.8361738,2.544724,,,,,,,,,,,,,, -3900,0.7557063,2.458973,,,,,,,,,,,,,, -4000,0.8845719,2.4472737,,,,,,,,,,,,,, -4100,0.89998704,2.4855928,,,,,,,,,,,,,, -4200,0.7802202,2.3884687,,,,,,,,,,,,,, -4300,1.0139604,2.3172941,,,,,,,,,,,,,, -4400,0.88325375,2.273337,,,,,,,,,,,,,, -4500,1.3582932,2.2643127,,,,,,,,,,,,,, -4600,0.9533034,2.1892579,,,,,,,,,,,,,, -4700,1.0467693,2.1984365,,,,,,,,,,,,,, -4800,0.9151458,2.1454082,,,,,,,,,,,,,, -4900,0.88426614,2.1759312,,,,,,,,,,,,,, -5000,1.0184906,2.0934825,,,,,,,,,,,,,, -5100,0.7641158,2.0912015,,,,,,,,,,,,,, -5200,0.88312745,2.0199742,,,,,,,,,,,,,, -5300,0.78910434,2.024872,,,,,,,,,,,,,, -5400,1.0291661,2.0031853,,,,,,,,,,,,,, -5500,0.7799565,1.9532471,,,,,,,,,,,,,, -5551,,,0.7146427,0.2437602458476314,1.0680373,0.3108122459619413,5348.0,0.78583676,0.2550118822740844,2472.0,4397.298481225967,4942.37772154808,4397.298481225967,544.7035682201385,0.1459474563598632,0.0 -5600,0.9299503,1.9500828,,,,,,,,,,,,,, -5700,0.6646159,1.9642781,,,,,,,,,,,,,, -5800,0.6984374,1.9510616,,,,,,,,,,,,,, -5900,0.8121982,1.965237,,,,,,,,,,,,,, -6000,0.83052415,1.9033912,,,,,,,,,,,,,, -6100,0.7201721,1.8969798,,,,,,,,,,,,,, -6200,0.7017038,1.8482485,,,,,,,,,,,,,, -6300,0.8703515,1.8341068,,,,,,,,,,,,,, -6400,0.6815411,1.8296561,,,,,,,,,,,,,, -6500,0.76086205,1.8802598,,,,,,,,,,,,,, -6600,0.7309242,1.7879442,,,,,,,,,,,,,, -6700,0.86315554,1.8960544,,,,,,,,,,,,,, -6800,0.8242069,1.773305,,,,,,,,,,,,,, -6900,0.74859375,1.819983,,,,,,,,,,,,,, -7000,0.74581224,1.8089203,,,,,,,,,,,,,, -7100,0.65829635,1.8341149,,,,,,,,,,,,,, -7200,0.7459347,1.8010924,,,,,,,,,,,,,, -7300,0.66916937,1.745122,,,,,,,,,,,,,, -7377,,,0.5079412,0.1768737760370304,0.821356,0.2468405147861011,5348.0,0.5604038,0.1868055978713464,2472.0,5837.907940626144,6514.627334594727,5837.907940626144,676.2148461341858,0.1972596645355224,0.0 -7400,0.6956134,1.7124631,,,,,,,,,,,,,, -7500,0.82036465,1.8332255,,,,,,,,,,,,,, -7600,0.8061602,1.7414794,,,,,,,,,,,,,, -7700,0.69524574,1.7902658,,,,,,,,,,,,,, -7800,0.8580008,1.7359397,,,,,,,,,,,,,, -7900,0.6678094,1.7229782,,,,,,,,,,,,,, -8000,0.62647486,1.748769,,,,,,,,,,,,,, -8100,0.67894346,1.6850411,,,,,,,,,,,,,, -8200,0.7328283,1.7245786,,,,,,,,,,,,,, -8300,0.67931277,1.7054114,,,,,,,,,,,,,, -8400,0.7050513,1.6118149,,,,,,,,,,,,,, -8500,0.70786244,1.6816897,,,,,,,,,,,,,, -8600,0.7337898,1.6840727,,,,,,,,,,,,,, -8700,0.730264,1.7077366,,,,,,,,,,,,,, -8800,0.6759494,1.6636859,,,,,,,,,,,,,, -8900,0.69227487,1.5920509,,,,,,,,,,,,,, -9000,0.8077222,1.5863843,,,,,,,,,,,,,, -9100,0.64429,1.6644491,,,,,,,,,,,,,, -9199,,,0.43121392,0.1545967513970732,0.7380511,0.222839047278836,5348.0,0.48268652,0.1639956939451181,2472.0,7277.963893651962,8086.693351507187,7277.963893651962,808.0982096195221,0.2487895488739013,0.0 -9200,0.82955617,1.7076337,,,,,,,,,,,,,, -9300,0.8451242,1.6367102,,,,,,,,,,,,,, -9400,0.8188917,1.6348354,,,,,,,,,,,,,, -9500,0.5944735,1.6273934,,,,,,,,,,,,,, -9600,0.7626678,1.577959,,,,,,,,,,,,,, -9700,0.66844124,1.634835,,,,,,,,,,,,,, -9800,0.7396521,1.658123,,,,,,,,,,,,,, -9900,0.6855019,1.5996308,,,,,,,,,,,,,, -10000,0.679845,1.6368026,,,,,,,,,,,,,, -10100,0.6879979,1.6687781,,,,,,,,,,,,,, -10200,0.66951025,1.5618879,,,,,,,,,,,,,, -10300,0.62210065,1.5982609,,,,,,,,,,,,,, -10400,0.69744843,1.5663896,,,,,,,,,,,,,, -10500,0.7093265,1.6075362,,,,,,,,,,,,,, -10600,0.60360116,1.6465988,,,,,,,,,,,,,, -10700,0.5906874,1.5199091,,,,,,,,,,,,,, -10800,0.6714099,1.5481186,,,,,,,,,,,,,, -10900,0.68710697,1.5724611,,,,,,,,,,,,,, -11000,0.75721526,1.5822247,,,,,,,,,,,,,, -11020,,,0.41369617,0.1462302148294754,0.6845674,0.2078839896888305,5348.0,0.4419891,0.1509759714012958,2472.0,8718.414444446564,9660.191576719284,8718.414444446564,941.0179760456084,0.3003580570220947,0.0 -11100,0.64220405,1.6013824,,,,,,,,,,,,,, -11200,0.63710934,1.5651814,,,,,,,,,,,,,, -11300,0.6604585,1.5361693,,,,,,,,,,,,,, -11400,0.72821933,1.5529604,,,,,,,,,,,,,, -11500,0.588901,1.5798535,,,,,,,,,,,,,, -11600,0.59391016,1.5661535,,,,,,,,,,,,,, -11700,0.7498773,1.5393685,,,,,,,,,,,,,, -11800,0.6446552,1.5622381,,,,,,,,,,,,,, -11900,0.6216778,1.5373589,,,,,,,,,,,,,, -12000,0.7101223,1.5210326,,,,,,,,,,,,,, -12100,0.7013983,1.5445713,,,,,,,,,,,,,, -12200,0.6694419,1.5935885,,,,,,,,,,,,,, -12300,0.74225825,1.508092,,,,,,,,,,,,,, -12400,0.6706859,1.5015607,,,,,,,,,,,,,, -12500,0.63230765,1.4774479,,,,,,,,,,,,,, -12600,0.6259533,1.5592904,,,,,,,,,,,,,, -12700,0.63905925,1.5376062,,,,,,,,,,,,,, -12800,0.78279376,1.5054674,,,,,,,,,,,,,, -12857,,,0.35117,0.1268904345606534,0.6366001,0.1944157486700715,5348.0,0.402794,0.1361485182702659,2472.0,10159.513425588608,11235.706638336182,10159.513425588608,1075.3007299900055,0.3552329540252685,0.0 -12900,0.6438839,1.5121471,,,,,,,,,,,,,, -13000,0.63576436,1.4404951,,,,,,,,,,,,,, -13100,0.5355639,1.4979258,,,,,,,,,,,,,, -13200,0.5850212,1.5138848,,,,,,,,,,,,,, -13300,0.65263045,1.4944075,,,,,,,,,,,,,, -13400,0.59487677,1.4523978,,,,,,,,,,,,,, -13500,0.57541114,1.4740708,,,,,,,,,,,,,, -13600,0.5705086,1.502896,,,,,,,,,,,,,, -13700,0.6734678,1.4717256,,,,,,,,,,,,,, -13800,0.7320103,1.433317,,,,,,,,,,,,,, -13900,0.6773922,1.4377099,,,,,,,,,,,,,, -14000,0.76067394,1.5383608,,,,,,,,,,,,,, -14100,0.7166124,1.4191934,,,,,,,,,,,,,, -14200,0.6110212,1.5002799,,,,,,,,,,,,,, -14300,0.70947284,1.4870098,,,,,,,,,,,,,, -14400,0.5934903,1.4740988,,,,,,,,,,,,,, -14500,0.7113027,1.4621658,,,,,,,,,,,,,, -14600,0.5932393,1.4886124,,,,,,,,,,,,,, -14670,,,0.301788,0.1133882836267193,0.6053062,0.1848866060998098,5348.0,0.38013542,0.1297909938455913,2472.0,11599.84973692894,12809.905858516691,11599.84973692894,1209.0333228111267,0.4085783958435058,0.0 -14700,0.6325256,1.5651873,,,,,,,,,,,,,, -14800,0.54839563,1.4477608,,,,,,,,,,,,,, -14900,0.55296016,1.4921806,,,,,,,,,,,,,, -15000,0.6049938,1.4435253,,,,,,,,,,,,,, -15100,0.730218,1.4539089,,,,,,,,,,,,,, -15200,0.783701,1.471281,,,,,,,,,,,,,, -15300,0.8188775,1.515104,,,,,,,,,,,,,, -15400,0.76587814,1.4438281,,,,,,,,,,,,,, -15500,0.55140036,1.4538571,,,,,,,,,,,,,, -15600,0.6075353,1.4034499,,,,,,,,,,,,,, -15700,0.560748,1.4672905,,,,,,,,,,,,,, -15800,0.6641683,1.4414243,,,,,,,,,,,,,, -15900,0.5570737,1.4714768,,,,,,,,,,,,,, -16000,0.7205768,1.3964936,,,,,,,,,,,,,, -16100,0.9125459,1.4647804,,,,,,,,,,,,,, -16200,0.66712445,1.4503301,,,,,,,,,,,,,, -16300,0.6570682,1.4211922,,,,,,,,,,,,,, -16400,0.6403898,1.474854,,,,,,,,,,,,,, -16457,,,0.27452624,0.1010592172113486,0.5743748,0.1744981994072043,5348.0,0.36317125,0.1238600125931793,2472.0,13039.76072025299,14383.49019765854,13039.76072025299,1342.5781121253967,0.4630181789398193,0.0 -16500,0.6326728,1.4486239,,,,,,,,,,,,,, -16600,0.6876586,1.4263575,,,,,,,,,,,,,, -16700,0.6223419,1.4167514,,,,,,,,,,,,,, -16800,0.5797487,1.4565861,,,,,,,,,,,,,, -16900,0.6142772,1.455088,,,,,,,,,,,,,, -17000,0.5806687,1.4421129,,,,,,,,,,,,,, -17100,0.62477326,1.465157,,,,,,,,,,,,,, -17200,0.85086316,1.4161229,,,,,,,,,,,,,, -17300,0.69475245,1.4360869,,,,,,,,,,,,,, -17400,0.71086514,1.4771632,,,,,,,,,,,,,, -17500,0.60658276,1.3951889,,,,,,,,,,,,,, -17600,0.70810825,1.4430337,,,,,,,,,,,,,, -17700,0.7256587,1.3916188,,,,,,,,,,,,,, -17800,0.5978139,1.4188223,,,,,,,,,,,,,, -17900,0.64270264,1.3858737,,,,,,,,,,,,,, -18000,0.63797575,1.436378,,,,,,,,,,,,,, -18100,1.061914,1.401363,,,,,,,,,,,,,, -18200,0.65917486,1.4360862,,,,,,,,,,,,,, -18251,,,0.2666409,0.1004896308730231,0.5565517,0.1694584705098622,5348.0,0.34133852,0.1155119533646131,2472.0,14480.069234132769,15956.655859947205,14480.069234132769,1475.30983710289,0.512258768081665,0.0 -18300,0.656222,1.4061584,,,,,,,,,,,,,, -18400,0.6367197,1.3642521,,,,,,,,,,,,,, -18500,0.610121,1.3918335,,,,,,,,,,,,,, -18600,0.5879678,1.3250277,,,,,,,,,,,,,, -18700,0.82015,1.4441663,,,,,,,,,,,,,, -18800,0.74687874,1.4184171,,,,,,,,,,,,,, -18900,0.66192347,1.4025211,,,,,,,,,,,,,, -19000,0.6455687,1.3498714,,,,,,,,,,,,,, -19100,0.8041407,1.4558516,,,,,,,,,,,,,, -19200,0.57925755,1.3684982,,,,,,,,,,,,,, -19300,0.658583,1.3753393,,,,,,,,,,,,,, -19400,0.68826425,1.4317068,,,,,,,,,,,,,, -19500,0.5927798,1.3891972,,,,,,,,,,,,,, -19600,0.6915652,1.4007895,,,,,,,,,,,,,, -19700,0.6495363,1.4284775,,,,,,,,,,,,,, -19800,0.57305443,1.4248096,,,,,,,,,,,,,, -19900,0.61951095,1.4127411,,,,,,,,,,,,,, -20000,0.75964385,1.4152641,,,,,,,,,,,,,, -20072,,,0.27632526,0.0999444723551653,0.53343123,0.1639939368778783,5348.0,0.3296891,0.113358925923669,2472.0,15920.46349310875,17530.263286590576,15920.46349310875,1608.389491558075,0.5656495094299316,0.0 -20100,0.664606,1.3915294,,,,,,,,,,,,,, -20200,0.70098555,1.3953317,,,,,,,,,,,,,, -20300,0.609935,1.3239253,,,,,,,,,,,,,, -20400,0.7171242,1.4278731,,,,,,,,,,,,,, -20500,0.78078717,1.3933072,,,,,,,,,,,,,, -20600,0.6033562,1.3406041,,,,,,,,,,,,,, -20700,0.6319208,1.3116006,,,,,,,,,,,,,, -20800,0.5528214,1.4039341,,,,,,,,,,,,,, -20900,0.6639984,1.3912313,,,,,,,,,,,,,, -21000,0.6082505,1.3852172,,,,,,,,,,,,,, -21100,0.7216615,1.3774658,,,,,,,,,,,,,, -21200,0.6404696,1.4024875,,,,,,,,,,,,,, -21300,0.6771079,1.364877,,,,,,,,,,,,,, -21400,0.65216887,1.4008081,,,,,,,,,,,,,, -21500,0.57535547,1.3872008,,,,,,,,,,,,,, -21600,0.69722885,1.361947,,,,,,,,,,,,,, -21700,0.70882845,1.3779174,,,,,,,,,,,,,, -21800,0.7367233,1.396208,,,,,,,,,,,,,, -21871,,,0.2690903,0.0988207982694489,0.52979755,0.1603830966334224,5348.0,0.3216067,0.1086263278695184,2472.0,17360.60667347908,19103.31728196144,17360.60667347908,1741.165496110916,0.6244997978210449,0.0 -21900,0.71164954,1.320838,,,,,,,,,,,,,, -22000,0.5755798,1.39628,,,,,,,,,,,,,, -22100,0.7252333,1.3483323,,,,,,,,,,,,,, -22200,0.6629494,1.3224941,,,,,,,,,,,,,, -22300,0.5888684,1.3724744,,,,,,,,,,,,,, -22400,0.59133303,1.3251847,,,,,,,,,,,,,, -22500,0.66334105,1.3444952,,,,,,,,,,,,,, -22600,0.57064384,1.3357228,,,,,,,,,,,,,, -22700,0.71023697,1.3255053,,,,,,,,,,,,,, -22800,0.7280311,1.3785155,,,,,,,,,,,,,, -22900,0.81746125,1.3318573,,,,,,,,,,,,,, -23000,0.6268785,1.3666021,,,,,,,,,,,,,, -23100,0.7031548,1.3302326,,,,,,,,,,,,,, -23200,0.65634316,1.3490411,,,,,,,,,,,,,, -23300,0.6846848,1.3691308,,,,,,,,,,,,,, -23400,0.5970543,1.3490459,,,,,,,,,,,,,, -23500,0.79717916,1.3279778,,,,,,,,,,,,,, -23600,0.6835988,1.367861,,,,,,,,,,,,,, -23650,,,0.2543325,0.0942792468360942,0.51538944,0.1583170008785734,5348.0,0.3166944,0.1070420246582576,2472.0,18800.65876030922,20678.319122314453,18800.65876030922,1875.9892790317533,0.6764576435089111,0.0 -23700,0.6952201,1.3520787,,,,,,,,,,,,,, -23800,0.5884546,1.3156643,,,,,,,,,,,,,, -23900,0.76377565,1.3431728,,,,,,,,,,,,,, -24000,0.60961384,1.344938,,,,,,,,,,,,,, -24100,0.63384724,1.2865212,,,,,,,,,,,,,, -24200,0.6302128,1.328498,,,,,,,,,,,,,, -24300,0.72927046,1.3353969,,,,,,,,,,,,,, -24400,0.60179234,1.3596597,,,,,,,,,,,,,, -24500,0.6510388,1.3234383,,,,,,,,,,,,,, -24600,0.6784646,1.3703467,,,,,,,,,,,,,, -24700,0.6767397,1.3140533,,,,,,,,,,,,,, -24800,0.796337,1.3647465,,,,,,,,,,,,,, -24900,0.6734172,1.3254627,,,,,,,,,,,,,, -25000,0.8014134,1.3467528,,,,,,,,,,,,,, -25100,0.5853746,1.3369939,,,,,,,,,,,,,, -25200,0.6244779,1.3454233,,,,,,,,,,,,,, -25300,0.6609678,1.2800108,,,,,,,,,,,,,, -25400,0.6887763,1.3754873,,,,,,,,,,,,,, -25434,,,0.23121329,0.0862822515941049,0.4990429,0.1501684736958977,5348.0,0.30282408,0.1014157171003189,2472.0,20241.77404141426,22251.44185328484,20241.77404141426,2007.8673095703125,0.731112003326416,0.0 -25500,0.63854057,1.292237,,,,,,,,,,,,,, -25600,0.81961244,1.2767388,,,,,,,,,,,,,, -25700,0.6893062,1.2770756,,,,,,,,,,,,,, -25800,0.70238626,1.278374,,,,,,,,,,,,,, -25900,0.7135637,1.3345103,,,,,,,,,,,,,, -26000,0.5464558,1.306543,,,,,,,,,,,,,, -26100,0.83733326,1.3261865,,,,,,,,,,,,,, -26200,0.78237617,1.3302304,,,,,,,,,,,,,, -26300,0.772075,1.322818,,,,,,,,,,,,,, -26400,0.6351698,1.2967541,,,,,,,,,,,,,, -26500,0.5505367,1.2870014,,,,,,,,,,,,,, -26600,0.71129113,1.3451467,,,,,,,,,,,,,, -26700,0.6774754,1.3433043,,,,,,,,,,,,,, -26800,0.62159854,1.3320112,,,,,,,,,,,,,, -26900,0.59327924,1.2640846,,,,,,,,,,,,,, -27000,0.7824734,1.3258755,,,,,,,,,,,,,, -27100,0.6222213,1.3554947,,,,,,,,,,,,,, -27200,0.71793634,1.326082,,,,,,,,,,,,,, -27251,,,0.21819447,0.082224366832729,0.4925936,0.1478803209206677,5348.0,0.29341123,0.0982471106777974,2472.0,21681.949670553207,23825.03625488281,21681.949670553207,2141.1486024856567,0.7904298305511475,0.0 -27300,0.68690056,1.3340352,,,,,,,,,,,,,, -27400,0.66974735,1.3058565,,,,,,,,,,,,,, -27500,0.59155405,1.3155594,,,,,,,,,,,,,, -27600,0.7031822,1.3715811,,,,,,,,,,,,,, -27700,0.69181466,1.28707,,,,,,,,,,,,,, -27800,0.67902225,1.3012716,,,,,,,,,,,,,, -27900,0.69939005,1.2620062,,,,,,,,,,,,,, -28000,0.5378313,1.2518698,,,,,,,,,,,,,, -28100,0.7396673,1.3102252,,,,,,,,,,,,,, -28200,0.6329663,1.3151097,,,,,,,,,,,,,, -28300,0.62774396,1.2986456,,,,,,,,,,,,,, -28400,0.6093675,1.3020955,,,,,,,,,,,,,, -28500,0.71918964,1.3324171,,,,,,,,,,,,,, -28600,0.6863877,1.2906263,,,,,,,,,,,,,, -28700,0.7913007,1.325129,,,,,,,,,,,,,, -28800,0.7693012,1.287236,,,,,,,,,,,,,, -28900,0.6832045,1.3450247,,,,,,,,,,,,,, -29000,0.6922307,1.2711005,,,,,,,,,,,,,, -29050,,,0.21523209,0.0802436305243886,0.48609242,0.146576942757562,5348.0,0.28730962,0.0981049296203765,2472.0,23122.05716252327,25399.74922537804,23122.05716252327,2275.6245658397675,0.8445065021514893,0.0 -29100,0.6057116,1.319761,,,,,,,,,,,,,, -29200,0.63362575,1.2463932,,,,,,,,,,,,,, -29300,0.6563654,1.297709,,,,,,,,,,,,,, -29400,0.7723191,1.2562077,,,,,,,,,,,,,, -29500,0.98613715,1.287303,,,,,,,,,,,,,, -29600,0.93226683,1.3284521,,,,,,,,,,,,,, -29700,0.652132,1.3214278,,,,,,,,,,,,,, -29800,0.7955818,1.2905141,,,,,,,,,,,,,, -29900,0.6235093,1.3168793,,,,,,,,,,,,,, -30000,0.68528426,1.2922082,,,,,,,,,,,,,, -30100,0.7793556,1.2823994,,,,,,,,,,,,,, -30200,0.6586088,1.3130344,,,,,,,,,,,,,, -30300,0.7148057,1.2886717,,,,,,,,,,,,,, -30400,0.9046023,1.2702492,,,,,,,,,,,,,, -30500,0.7262186,1.3341917,,,,,,,,,,,,,, -30600,0.6815282,1.273291,,,,,,,,,,,,,, -30700,0.6913267,1.3242055,,,,,,,,,,,,,, -30800,0.669817,1.289181,,,,,,,,,,,,,, -30830,,,0.22447392,0.0838804385508753,0.46955344,0.1418944360234415,5348.0,0.27968472,0.0959519021794324,2472.0,24562.581042051315,26972.9074883461,24562.581042051315,2408.1328547000885,0.8940913677215576,0.0 -30900,0.6139009,1.263562,,,,,,,,,,,,,, -31000,0.78689677,1.2744567,,,,,,,,,,,,,, -31100,0.68694854,1.3045069,,,,,,,,,,,,,, -31200,0.61114454,1.2079575,,,,,,,,,,,,,, -31300,0.6853625,1.2739555,,,,,,,,,,,,,, -31400,0.77612597,1.2955089,,,,,,,,,,,,,, -31500,0.68015456,1.2900016,,,,,,,,,,,,,, -31600,0.7413328,1.2760825,,,,,,,,,,,,,, -31700,0.7129882,1.2858181,,,,,,,,,,,,,, -31800,0.72186464,1.2700216,,,,,,,,,,,,,, -31900,0.74745446,1.3132725,,,,,,,,,,,,,, -32000,0.6065154,1.2109427,,,,,,,,,,,,,, -32100,0.71501565,1.2962703,,,,,,,,,,,,,, -32200,0.74575716,1.2470505,,,,,,,,,,,,,, -32300,0.6595567,1.2531246,,,,,,,,,,,,,, -32400,0.757074,1.2735647,,,,,,,,,,,,,, -32500,0.6410949,1.2379609,,,,,,,,,,,,,, -32600,0.6585343,1.2612723,,,,,,,,,,,,,, -32621,,,0.21071382,0.0778517517729342,0.4662035,0.1414020487173793,5348.0,0.27777946,0.0946316495033818,2472.0,26002.89443707466,28544.89704513549,26002.89443707466,2539.6788890361786,0.9472103118896484,0.0 -32700,0.76781976,1.2606198,,,,,,,,,,,,,, -32800,0.6849071,1.3374727,,,,,,,,,,,,,, -32900,0.6966297,1.2744901,,,,,,,,,,,,,, -33000,0.7286338,1.2215548,,,,,,,,,,,,,, -33100,0.8120851,1.2478704,,,,,,,,,,,,,, -33200,0.74058294,1.2382492,,,,,,,,,,,,,, -33300,0.76954526,1.3248918,,,,,,,,,,,,,, -33400,0.6378895,1.2739292,,,,,,,,,,,,,, -33500,0.71312195,1.2495555,,,,,,,,,,,,,, -33600,0.6771078,1.2438077,,,,,,,,,,,,,, -33700,0.7238767,1.2263379,,,,,,,,,,,,,, -33800,0.63211703,1.3363957,,,,,,,,,,,,,, -33900,0.71897185,1.230399,,,,,,,,,,,,,, -34000,0.6693177,1.2067074,,,,,,,,,,,,,, -34100,0.6493326,1.2982384,,,,,,,,,,,,,, -34200,0.59227854,1.2464068,,,,,,,,,,,,,, -34300,0.69750094,1.1991755,,,,,,,,,,,,,, -34400,0.6243921,1.2445729,,,,,,,,,,,,,, -34442,,,0.21640244,0.0787028237675233,0.45866513,0.1378974096565839,5348.0,0.26703775,0.0911989925456502,2472.0,27443.38892006874,30119.13695454597,27443.38892006874,2673.288696050644,1.005300521850586,0.0 -34500,0.6832928,1.2086782,,,,,,,,,,,,,, -34600,0.7775732,1.288613,,,,,,,,,,,,,, -34700,0.7365724,1.2297757,,,,,,,,,,,,,, -34800,0.742515,1.2816203,,,,,,,,,,,,,, -34900,0.6353212,1.2411366,,,,,,,,,,,,,, -35000,0.6385358,1.2807103,,,,,,,,,,,,,, -35100,0.6747811,1.2295756,,,,,,,,,,,,,, -35200,0.66464186,1.2272706,,,,,,,,,,,,,, -35300,1.1601167,1.2954485,,,,,,,,,,,,,, -35400,0.69886667,1.2338765,,,,,,,,,,,,,, -35500,0.8206623,1.2474096,,,,,,,,,,,,,, -35600,0.7071913,1.2487247,,,,,,,,,,,,,, -35700,0.63907677,1.2278738,,,,,,,,,,,,,, -35800,0.7459086,1.2469553,,,,,,,,,,,,,, -35900,0.8158639,1.3155061,,,,,,,,,,,,,, -36000,0.9018087,1.2237486,,,,,,,,,,,,,, -36100,0.91675705,1.2317419,,,,,,,,,,,,,, -36200,0.8100051,1.2319148,,,,,,,,,,,,,, -36233,,,0.21171698,0.0761426729328516,0.4551876,0.1361595721057763,5348.0,0.2658482,0.0884772408750228,2472.0,28883.798013210297,31691.72387599945,28883.798013210297,2805.333705663681,1.062840461730957,0.0 -36300,0.6782378,1.2369748,,,,,,,,,,,,,, -36400,0.6547641,1.2237695,,,,,,,,,,,,,, -36500,0.5931549,1.1966571,,,,,,,,,,,,,, -36600,0.6470267,1.2165796,,,,,,,,,,,,,, -36700,0.7192595,1.1925924,,,,,,,,,,,,,, -36800,0.68811625,1.2722257,,,,,,,,,,,,,, -36900,0.70381075,1.2537069,,,,,,,,,,,,,, -37000,0.61768526,1.2456264,,,,,,,,,,,,,, -37100,0.9249444,1.2133571,,,,,,,,,,,,,, -37200,0.7246305,1.2198843,,,,,,,,,,,,,, -37300,0.7089352,1.212176,,,,,,,,,,,,,, -37400,0.6739758,1.2113366,,,,,,,,,,,,,, -37500,0.6193462,1.2096297,,,,,,,,,,,,,, -37600,0.792353,1.2744397,,,,,,,,,,,,,, -37700,0.70069104,1.2727331,,,,,,,,,,,,,, -37800,0.68986624,1.225665,,,,,,,,,,,,,, -37900,0.67114836,1.18607,,,,,,,,,,,,,, -38000,0.74994504,1.2205448,,,,,,,,,,,,,, -38013,,,0.16391788,0.0609964843229503,0.43490544,0.1313708641879954,5348.0,0.25705752,0.0859586050007109,2472.0,30324.08151173592,33266.3851544857,30324.08151173592,2939.579090833664,1.119497299194336,0.0 -38100,0.80499166,1.2555063,,,,,,,,,,,,,, -38200,0.71571887,1.2021663,,,,,,,,,,,,,, -38300,0.73036975,1.2119684,,,,,,,,,,,,,, -38400,0.64768153,1.2111378,,,,,,,,,,,,,, -38500,0.7577542,1.1790105,,,,,,,,,,,,,, -38600,0.6683185,1.2446755,,,,,,,,,,,,,, -38700,0.815539,1.2192943,,,,,,,,,,,,,, -38800,0.66254073,1.229042,,,,,,,,,,,,,, -38900,0.7240636,1.1666101,,,,,,,,,,,,,, -39000,0.7132943,1.179785,,,,,,,,,,,,,, -39100,0.66694546,1.1873847,,,,,,,,,,,,,, -39200,0.7423563,1.2565495,,,,,,,,,,,,,, -39300,0.5975994,1.180793,,,,,,,,,,,,,, -39400,0.98128825,1.1759346,,,,,,,,,,,,,, -39500,0.78927815,1.2190825,,,,,,,,,,,,,, -39600,0.82168436,1.2136898,,,,,,,,,,,,,, -39700,0.7649733,1.2109811,,,,,,,,,,,,,, -39800,0.7682652,1.2595978,,,,,,,,,,,,,, -39801,,,0.1884352,0.0688739414162319,0.43504775,0.1308591675758131,5348.0,0.25120482,0.0843743017894501,2472.0,31764.34447383881,34840.09559392929,31764.34447383881,3072.89638376236,1.1725895404815674,0.0 -39900,0.71862316,1.2246957,,,,,,,,,,,,,, -40000,0.81785756,1.1958674,,,,,,,,,,,,,, -40100,0.83203655,1.2204506,,,,,,,,,,,,,, -40200,0.7167887,1.2100822,,,,,,,,,,,,,, -40300,0.7450099,1.162228,,,,,,,,,,,,,, -40400,0.7387628,1.2524943,,,,,,,,,,,,,, -40500,0.82257503,1.198882,,,,,,,,,,,,,, -40600,0.77749866,1.2045829,,,,,,,,,,,,,, -40700,0.7657988,1.1434518,,,,,,,,,,,,,, -40800,0.78495634,1.1995587,,,,,,,,,,,,,, -40900,0.82253116,1.204297,,,,,,,,,,,,,, -41000,0.70528203,1.2263068,,,,,,,,,,,,,, -41100,0.7638619,1.1975594,,,,,,,,,,,,,, -41200,0.6368992,1.1140423,,,,,,,,,,,,,, -41300,0.8149149,1.1727128,,,,,,,,,,,,,, -41400,0.71837777,1.2197907,,,,,,,,,,,,,, -41500,0.76676106,1.2337002,,,,,,,,,,,,,, -41600,0.66989595,1.2091744,,,,,,,,,,,,,, -41617,,,0.23098333,0.0850767236376062,0.4273046,0.1274800390047983,5348.0,0.24482445,0.0829728027948733,2472.0,33204.71107053757,36414.75404071808,33204.71107053757,3207.057755947113,1.2265305519104004,0.0 -41700,0.7340641,1.1829104,,,,,,,,,,,,,, -41800,0.70545155,1.1943785,,,,,,,,,,,,,, -41900,0.6777623,1.2064322,,,,,,,,,,,,,, -42000,0.6938196,1.1711942,,,,,,,,,,,,,, -42100,0.74164,1.1440481,,,,,,,,,,,,,, -42200,0.91565907,1.1934271,,,,,,,,,,,,,, -42300,0.8925627,1.1230589,,,,,,,,,,,,,, -42400,0.6918188,1.1977534,,,,,,,,,,,,,, -42500,1.0195647,1.1778327,,,,,,,,,,,,,, -42600,0.72959465,1.1939824,,,,,,,,,,,,,, -42700,0.9793472,1.1974763,,,,,,,,,,,,,, -42800,0.65602213,1.1705084,,,,,,,,,,,,,, -42900,0.9053371,1.1860207,,,,,,,,,,,,,, -43000,0.8871334,1.1929438,,,,,,,,,,,,,, -43100,0.685535,1.2029587,,,,,,,,,,,,,, -43200,0.81732106,1.2174618,,,,,,,,,,,,,, -43300,0.7977023,1.1581582,,,,,,,,,,,,,, -43396,,,0.22959857,0.0844912483276617,0.42108253,0.1264373364743138,5348.0,0.23914756,0.0816119269595596,2472.0,34644.79412436485,37986.90700483322,34644.79412436485,3338.9981787204742,1.2806153297424316,0.0 -43400,0.7481188,1.1798024,,,,,,,,,,,,,, -43500,0.68321204,1.2569697,,,,,,,,,,,,,, -43600,0.9461597,1.1579363,,,,,,,,,,,,,, -43700,0.81074464,1.1579173,,,,,,,,,,,,,, -43800,0.91537535,1.1727751,,,,,,,,,,,,,, -43900,0.7475325,1.2092676,,,,,,,,,,,,,, -44000,0.69294685,1.2184992,,,,,,,,,,,,,, -44100,0.84269035,1.1870658,,,,,,,,,,,,,, -44200,0.8592705,1.1644498,,,,,,,,,,,,,, -44300,0.82939786,1.1817526,,,,,,,,,,,,,, -44400,0.78633314,1.1620059,,,,,,,,,,,,,, -44500,0.71696436,1.1394361,,,,,,,,,,,,,, -44600,0.83332133,1.1625361,,,,,,,,,,,,,, -44700,0.73730177,1.1727608,,,,,,,,,,,,,, -44800,0.70617425,1.1379099,,,,,,,,,,,,,, -44900,0.8358946,1.0848048,,,,,,,,,,,,,, -45000,0.74565667,1.1629395,,,,,,,,,,,,,, -45100,0.7144433,1.159493,,,,,,,,,,,,,, -45174,,,0.26250243,0.0966651262642511,0.40786082,0.122614093862537,5348.0,0.23407225,0.0779152194666179,2472.0,36085.03482174873,39558.51782393456,36085.03482174873,3470.2350482940674,1.3378307819366455,0.0 -45200,0.7375281,1.2348514,,,,,,,,,,,,,, -45300,0.77227944,1.1690594,,,,,,,,,,,,,, -45400,0.8281146,1.1754831,,,,,,,,,,,,,, -45500,0.92708355,1.1377949,,,,,,,,,,,,,, -45600,0.74695337,1.1762854,,,,,,,,,,,,,, -45700,1.0548394,1.1466305,,,,,,,,,,,,,, -45800,0.8292622,1.1641665,,,,,,,,,,,,,, -45900,0.71246606,1.1799566,,,,,,,,,,,,,, -46000,0.769371,1.1808451,,,,,,,,,,,,,, -46100,0.67079246,1.1353905,,,,,,,,,,,,,, -46200,0.7990733,1.1202905,,,,,,,,,,,,,, -46300,0.6846289,1.1296982,,,,,,,,,,,,,, -46400,0.744569,1.1209371,,,,,,,,,,,,,, -46500,0.79894006,1.1454334,,,,,,,,,,,,,, -46600,0.8123377,1.1621149,,,,,,,,,,,,,, -46700,0.7909,1.1364977,,,,,,,,,,,,,, -46800,0.83620536,1.1871115,,,,,,,,,,,,,, -46900,0.76872015,1.124009,,,,,,,,,,,,,, -46972,,,0.22550887,0.0816548816890296,0.39984173,0.1198914816996051,5348.0,0.22878855,0.0766762131090934,2472.0,37525.23612308502,41131.11450552941,37525.23612308502,3602.4942483901978,1.3967421054840088,0.0 -47000,0.7664332,1.1751374,,,,,,,,,,,,,, -47100,0.8151763,1.1160816,,,,,,,,,,,,,, -47200,0.81347805,1.1416142,,,,,,,,,,,,,, -47300,0.85354424,1.1473632,,,,,,,,,,,,,, -47400,0.7361582,1.1400504,,,,,,,,,,,,,, -47500,0.8448109,1.1694481,,,,,,,,,,,,,, -47600,0.99389124,1.0694599,,,,,,,,,,,,,, -47700,0.734124,1.0828019,,,,,,,,,,,,,, -47800,0.7080079,1.0867568,,,,,,,,,,,,,, -47900,0.7537037,1.1528418,,,,,,,,,,,,,, -48000,0.68814635,1.1324143,,,,,,,,,,,,,, -48100,0.8081285,1.1367712,,,,,,,,,,,,,, -48200,0.6500959,1.1198394,,,,,,,,,,,,,, -48300,0.79004383,1.1188812,,,,,,,,,,,,,, -48400,0.77822125,1.1169056,,,,,,,,,,,,,, -48500,0.72016597,1.0947101,,,,,,,,,,,,,, -48600,0.7410624,1.0702006,,,,,,,,,,,,,, -48700,0.76932657,1.1487703,,,,,,,,,,,,,, -48778,,,0.20397122,0.0767817576520102,0.3939008,0.1173136893325738,5348.0,0.22130717,0.0745231856681494,2472.0,38965.76935267448,42703.40314888954,38965.76935267448,3734.108068704605,1.460125207901001,0.0 -48800,0.7633156,1.0923759,,,,,,,,,,,,,, -48900,0.8316284,1.1580577,,,,,,,,,,,,,, -49000,0.9062044,1.1314447,,,,,,,,,,,,,, -49100,0.84738034,1.0801007,,,,,,,,,,,,,, -49200,0.8186935,1.0881872,,,,,,,,,,,,,, -49300,0.80802876,1.0842528,,,,,,,,,,,,,, -49400,0.8035471,1.1197538,,,,,,,,,,,,,, -49500,0.81292397,1.141568,,,,,,,,,,,,,, -49600,0.68383795,1.0680908,,,,,,,,,,,,,, -49700,0.8181435,1.0648873,,,,,,,,,,,,,, -49800,0.82643497,1.1371866,,,,,,,,,,,,,, -49900,0.8166151,1.1118245,,,,,,,,,,,,,, -50000,0.7094359,1.1215965,,,,,,,,,,,,,, -50100,0.8029215,1.0905652,,,,,,,,,,,,,, -50200,0.7426601,1.0885625,,,,,,,,,,,,,, -50300,0.71524733,1.077479,,,,,,,,,,,,,, -50400,0.8867585,1.1491953,,,,,,,,,,,,,, -50500,0.76596385,1.0800233,,,,,,,,,,,,,, -50551,,,0.1726337,0.0656287724716962,0.38651025,0.1130656419861552,5348.0,0.21482898,0.0720857961123636,2472.0,40405.70114803314,44276.238609075546,40405.70114803314,3866.873418092728,1.5228583812713623,0.0 -50600,0.82649446,1.1258997,,,,,,,,,,,,,, -50700,0.85102814,1.0433955,,,,,,,,,,,,,, -50800,0.74703646,1.0793718,,,,,,,,,,,,,, -50900,0.7099571,1.0642993,,,,,,,,,,,,,, -51000,0.828307,1.0423329,,,,,,,,,,,,,, -51100,0.8487482,1.0895469,,,,,,,,,,,,,, -51200,0.8053677,1.1056076,,,,,,,,,,,,,, -51300,0.8006541,1.1278783,,,,,,,,,,,,,, -51400,0.9368798,1.0764251,,,,,,,,,,,,,, -51500,0.7749474,1.0429366,,,,,,,,,,,,,, -51600,0.97435313,1.1051526,,,,,,,,,,,,,, -51700,1.0403581,1.0874327,,,,,,,,,,,,,, -51800,0.7699827,1.0703686,,,,,,,,,,,,,, -51900,0.8957333,1.092003,,,,,,,,,,,,,, -52000,0.8254232,1.0920079,,,,,,,,,,,,,, -52100,0.89843595,1.0847797,,,,,,,,,,,,,, -52200,0.7659097,1.1096413,,,,,,,,,,,,,, -52300,0.89870816,1.1253927,,,,,,,,,,,,,, -52335,,,0.19079773,0.0716679248562133,0.37963584,0.1131428792106355,5348.0,0.21250953,0.0727357666605731,2472.0,41845.93539810181,45849.49245977402,41845.93539810181,3999.761225223541,1.5781021118164062,0.0 -52400,0.82786155,1.07984,,,,,,,,,,,,,, -52500,0.78411514,1.1091254,,,,,,,,,,,,,, -52600,0.89885813,1.0813835,,,,,,,,,,,,,, -52700,0.7444037,1.0226434,,,,,,,,,,,,,, -52800,0.96214974,1.0975673,,,,,,,,,,,,,, -52900,0.8749397,1.1122103,,,,,,,,,,,,,, -53000,0.852446,1.0534486,,,,,,,,,,,,,, -53100,0.95622987,1.0709786,,,,,,,,,,,,,, -53200,0.78819996,1.0661368,,,,,,,,,,,,,, -53300,0.7670037,1.0578115,,,,,,,,,,,,,, -53400,0.79667103,1.0424898,,,,,,,,,,,,,, -53500,0.83613867,1.0664469,,,,,,,,,,,,,, -53600,1.0270737,1.0168567,,,,,,,,,,,,,, -53700,1.0191206,1.0642105,,,,,,,,,,,,,, -53800,0.92754036,1.0528175,,,,,,,,,,,,,, -53900,0.8722436,1.0297365,,,,,,,,,,,,,, -54000,0.9416964,1.0597259,,,,,,,,,,,,,, -54100,0.8154734,1.1125091,,,,,,,,,,,,,, -54147,,,0.16157928,0.0621564164232949,0.36532888,0.1073983606399104,5348.0,0.20436308,0.0686734507342636,2472.0,43285.85983061791,47421.48486161232,43285.85983061791,4131.691259860992,1.639273166656494,0.0 -54200,0.8487352,1.05709,,,,,,,,,,,,,, -54300,0.9285391,1.0697279,,,,,,,,,,,,,, -54400,0.8149089,1.0938324,,,,,,,,,,,,,, -54500,0.8093273,1.0489631,,,,,,,,,,,,,, -54600,0.85541075,1.0712078,,,,,,,,,,,,,, -54700,0.9407604,1.0940882,,,,,,,,,,,,,, -54800,0.7461684,1.0560809,,,,,,,,,,,,,, -54900,0.80326366,1.042818,,,,,,,,,,,,,, -55000,0.8615629,1.0762109,,,,,,,,,,,,,, -55100,1.0492486,1.0119876,,,,,,,,,,,,,, -55200,0.89668226,1.1098075,,,,,,,,,,,,,, -55300,0.7949202,1.054017,,,,,,,,,,,,,, -55400,1.0031717,1.1039867,,,,,,,,,,,,,, -55500,0.7699608,1.0233736,,,,,,,,,,,,,, -55600,0.8234735,1.0114194,,,,,,,,,,,,,, -55700,0.9329308,1.0611924,,,,,,,,,,,,,, -55800,0.84598386,1.021841,,,,,,,,,,,,,, -55900,0.9189455,1.0245261,,,,,,,,,,,,,, -55961,,,0.15674654,0.0598292452727844,0.35715374,0.1051198625177404,5348.0,0.19516303,0.0657892064265838,2472.0,44725.77019500733,48992.8511402607,44725.77019500733,4263.011621952057,1.696920394897461,0.0 -56000,1.0251431,1.0195857,,,,,,,,,,,,,, -56100,0.9071408,1.1381862,,,,,,,,,,,,,, -56200,0.9814144,1.0367547,,,,,,,,,,,,,, -56300,0.8925828,1.0764349,,,,,,,,,,,,,, -56400,0.93519354,1.0609814,,,,,,,,,,,,,, -56500,0.88043684,1.0869714,,,,,,,,,,,,,, -56600,0.88841724,1.0477651,,,,,,,,,,,,,, -56700,0.94263476,1.0791134,,,,,,,,,,,,,, -56800,0.8155187,0.9867995,,,,,,,,,,,,,, -56900,0.99039555,1.0642833,,,,,,,,,,,,,, -57000,0.8582054,1.093601,,,,,,,,,,,,,, -57100,1.0169874,1.0909882,,,,,,,,,,,,,, -57200,0.88810515,1.0880698,,,,,,,,,,,,,, -57300,1.0805339,1.0180409,,,,,,,,,,,,,, -57400,0.8627996,0.9567117,,,,,,,,,,,,,, -57500,0.9050224,1.025896,,,,,,,,,,,,,, -57600,0.88032854,1.0452826,,,,,,,,,,,,,, -57700,0.9207276,0.9973476,,,,,,,,,,,,,, -57741,,,0.14496757,0.0550945479641131,0.3530484,0.1040964692933759,5348.0,0.19236514,0.063575244246745,2472.0,46166.32611012459,50565.07560944557,46166.32611012459,4394.548684120178,1.7538185119628906,0.0 -57800,0.88236344,1.0151411,,,,,,,,,,,,,, -57900,1.028959,1.0298911,,,,,,,,,,,,,, -58000,1.1677592,0.9935299,,,,,,,,,,,,,, -58100,1.1445979,1.0194376,,,,,,,,,,,,,, -58200,0.8724711,1.0232816,,,,,,,,,,,,,, -58300,0.8905098,1.0411035,,,,,,,,,,,,,, -58400,0.8917262,0.9680069,,,,,,,,,,,,,, -58500,0.7989872,1.0050827,,,,,,,,,,,,,, -58600,0.9273192,1.0738875,,,,,,,,,,,,,, -58700,0.82034916,1.0438164,,,,,,,,,,,,,, -58800,0.9416907,0.9993248,,,,,,,,,,,,,, -58900,1.0030884,1.0163369,,,,,,,,,,,,,, -59000,0.93341374,1.0193949,,,,,,,,,,,,,, -59100,1.0804497,1.0298114,,,,,,,,,,,,,, -59200,0.8792921,0.97882634,,,,,,,,,,,,,, -59300,0.9078601,0.99319524,,,,,,,,,,,,,, -59400,0.9621279,1.0077858,,,,,,,,,,,,,, -59500,0.95445603,0.9211779,,,,,,,,,,,,,, -59535,,,0.1456565,0.0555097695820885,0.3391099,0.1003890825183197,5348.0,0.1857698,0.0610566083724331,2472.0,47606.610372543335,52136.28013706207,47606.610372543335,4525.326953172684,1.8184218406677248,0.0 -59600,0.912952,0.9803317,,,,,,,,,,,,,, -59700,0.8467993,1.0204965,,,,,,,,,,,,,, -59800,0.8499305,0.9705929,,,,,,,,,,,,,, -59900,1.0365438,0.98100525,,,,,,,,,,,,,, -60000,1.0713097,0.975439,,,,,,,,,,,,,, -60100,0.97412086,1.0034112,,,,,,,,,,,,,, -60200,1.0591521,0.96896917,,,,,,,,,,,,,, -60300,0.84941566,0.9990296,,,,,,,,,,,,,, -60400,1.0414904,0.9755014,,,,,,,,,,,,,, -60500,0.96483624,1.0065275,,,,,,,,,,,,,, -60600,0.9560725,1.0088414,,,,,,,,,,,,,, -60700,0.93727267,0.9935269,,,,,,,,,,,,,, -60800,0.8407301,0.9559388,,,,,,,,,,,,,, -60900,0.9844163,0.93958884,,,,,,,,,,,,,, -61000,1.0068178,0.99383813,,,,,,,,,,,,,, -61100,0.9480136,0.9745725,,,,,,,,,,,,,, -61200,0.9132103,0.9976822,,,,,,,,,,,,,, -61300,0.9197554,0.97678465,,,,,,,,,,,,,, -61341,,,0.1419333,0.0538886287151896,0.33404234,0.0982457495389903,5348.0,0.1807087,0.0595941746389616,2472.0,49047.04236245155,53710.25221323967,49047.04236245155,4658.723962068558,1.8841652870178225,0.0 -61400,1.3378705,0.9670861,,,,,,,,,,,,,, -61500,1.0785056,0.9798936,,,,,,,,,,,,,, -61600,1.0022283,0.9521483,,,,,,,,,,,,,, -61700,0.86645734,0.93572474,,,,,,,,,,,,,, -61800,1.3229997,0.98353845,,,,,,,,,,,,,, -61900,1.0413169,0.9979813,,,,,,,,,,,,,, -62000,1.2199371,0.98031753,,,,,,,,,,,,,, -62100,0.9159052,0.9497696,,,,,,,,,,,,,, -62200,0.975309,0.9176093,,,,,,,,,,,,,, -62300,1.06908,0.98861915,,,,,,,,,,,,,, -62400,1.2293222,0.9648634,,,,,,,,,,,,,, -62500,0.9372394,0.9996474,,,,,,,,,,,,,, -62600,1.0516343,0.9741835,,,,,,,,,,,,,, -62700,1.0505066,1.0108341,,,,,,,,,,,,,, -62800,0.89150083,0.94134676,,,,,,,,,,,,,, -62900,1.1224191,0.9565505,,,,,,,,,,,,,, -63000,1.000043,0.961684,,,,,,,,,,,,,, -63100,1.0293927,0.96649224,,,,,,,,,,,,,, -63155,,,0.11702894,0.045682292016948,0.32601267,0.0950790233352964,5348.0,0.17809816,0.0581520524851217,2472.0,50487.46106958389,55282.32294297218,50487.46106958389,4790.242144107819,1.94150710105896,0.0 -63200,1.0122848,0.93478465,,,,,,,,,,,,,, -63300,1.05922,0.9634793,,,,,,,,,,,,,, -63400,1.1341442,1.0009964,,,,,,,,,,,,,, -63500,1.0025291,0.931114,,,,,,,,,,,,,, -63600,1.0164964,0.93588287,,,,,,,,,,,,,, -63700,1.6168966,0.9656656,,,,,,,,,,,,,, -63800,1.5913037,0.96766347,,,,,,,,,,,,,, -63900,1.0781907,0.93997854,,,,,,,,,,,,,, -64000,1.2980555,0.93140376,,,,,,,,,,,,,, -64100,1.0152798,0.9106613,,,,,,,,,,,,,, -64200,1.0779247,0.9188863,,,,,,,,,,,,,, -64300,1.148478,0.9106015,,,,,,,,,,,,,, -64400,1.0523032,0.9216065,,,,,,,,,,,,,, -64500,1.045583,0.9681268,,,,,,,,,,,,,, -64600,1.089725,0.9941494,,,,,,,,,,,,,, -64700,0.9608086,0.9644007,,,,,,,,,,,,,, -64800,1.0451026,0.95251524,,,,,,,,,,,,,, -64900,1.2878342,0.91514814,,,,,,,,,,,,,, -64934,,,0.11297734,0.0438991830587376,0.3239654,0.0935728974579298,5348.0,0.16898389,0.0554099892348627,2472.0,51927.87755322456,56853.82898592949,51927.87755322456,4921.200897216797,1.9962365627288816,0.0 -65000,0.924665,0.8935961,,,,,,,,,,,,,, -65100,0.99644494,0.9105684,,,,,,,,,,,,,, -65200,1.148592,0.949459,,,,,,,,,,,,,, -65300,1.0672013,0.94498384,,,,,,,,,,,,,, -65400,1.1557786,0.96138227,,,,,,,,,,,,,, -65500,1.1730595,0.9048701,,,,,,,,,,,,,, -65600,1.3592762,0.9388465,,,,,,,,,,,,,, -65700,0.99018055,0.955971,,,,,,,,,,,,,, -65800,1.0248313,0.9706855,,,,,,,,,,,,,, -65900,1.1674027,0.95483625,,,,,,,,,,,,,, -66000,1.1295105,0.8683919,,,,,,,,,,,,,, -66100,1.0878786,0.9786239,,,,,,,,,,,,,, -66200,1.1505482,0.9499042,,,,,,,,,,,,,, -66300,1.1031284,0.9028372,,,,,,,,,,,,,, -66400,1.4291686,0.88597,,,,,,,,,,,,,, -66500,1.2682323,0.8785719,,,,,,,,,,,,,, -66600,0.96201825,0.91252846,,,,,,,,,,,,,, -66700,1.0831935,0.86033463,,,,,,,,,,,,,, -66736,,,0.10090725,0.0388173458911752,0.31437576,0.0914006005194203,5348.0,0.16851163,0.0558162208274937,2472.0,53368.10543203354,58425.920551776886,53368.10543203354,5052.923782587051,2.061052083969116,0.0 -66800,0.9809439,0.91464186,,,,,,,,,,,,,, -66900,1.1551332,0.901714,,,,,,,,,,,,,, -67000,0.9711654,0.8872381,,,,,,,,,,,,,, -67100,1.2076058,0.9450692,,,,,,,,,,,,,, -67200,1.0389117,0.95000243,,,,,,,,,,,,,, -67300,1.2004241,0.89834917,,,,,,,,,,,,,, -67400,1.3413904,0.8484535,,,,,,,,,,,,,, -67500,1.3006092,0.9490264,,,,,,,,,,,,,, -67600,1.2190647,0.90232384,,,,,,,,,,,,,, -67700,1.0593147,0.90713143,,,,,,,,,,,,,, -67800,1.0417639,0.91888845,,,,,,,,,,,,,, -67900,1.2912661,0.9260692,,,,,,,,,,,,,, -68000,1.1038423,0.8669882,,,,,,,,,,,,,, -68100,0.9520313,0.9217618,,,,,,,,,,,,,, -68200,1.2826324,0.86963296,,,,,,,,,,,,,, -68300,1.2639021,0.8678953,,,,,,,,,,,,,, -68400,1.275292,0.90209156,,,,,,,,,,,,,, -68500,1.1668783,0.88239086,,,,,,,,,,,,,, -68543,,,0.09813993,0.0374010032054749,0.3097387,0.089740000193093,5348.0,0.16452241,0.054231917616233,2472.0,54808.36094260216,59996.24846315384,54808.36094260216,5182.85727763176,2.1215648651123047,0.0 -68600,1.2370334,0.89232826,,,,,,,,,,,,,, -68700,1.8566535,0.87829214,,,,,,,,,,,,,, -68800,1.0670822,0.91349393,,,,,,,,,,,,,, -68900,1.0224357,0.8955646,,,,,,,,,,,,,, -69000,1.1595826,0.9126392,,,,,,,,,,,,,, -69100,1.3206561,0.89642364,,,,,,,,,,,,,, -69200,1.2264444,0.8532629,,,,,,,,,,,,,, -69300,1.1787392,0.8939322,,,,,,,,,,,,,, -69400,1.0707084,0.9005837,,,,,,,,,,,,,, -69500,1.4279581,0.8810173,,,,,,,,,,,,,, -69600,1.059981,0.8485562,,,,,,,,,,,,,, -69700,1.3621486,0.90211886,,,,,,,,,,,,,, -69800,1.1648126,0.9196947,,,,,,,,,,,,,, -69900,1.0588622,0.8978059,,,,,,,,,,,,,, -70000,1.1688402,0.910931,,,,,,,,,,,,,, -70100,1.1958871,0.891673,,,,,,,,,,,,,, -70200,1.0902816,0.87486625,,,,,,,,,,,,,, -70300,1.0989238,0.91810954,,,,,,,,,,,,,, -70351,,,0.08928266,0.0343127389816864,0.30275312,0.0866118926016393,5348.0,0.1618215,0.0524851217679198,2472.0,56248.92697405815,61568.578528642654,56248.92697405815,5314.483243465424,2.182187557220459,0.0 -70400,1.2933328,0.9207161,,,,,,,,,,,,,, -70500,1.1131821,0.8708072,,,,,,,,,,,,,, -70600,1.1209323,0.84419537,,,,,,,,,,,,,, -70700,1.0881234,0.8400307,,,,,,,,,,,,,, -70800,1.3747278,0.84259576,,,,,,,,,,,,,, -70900,1.0678163,0.8659273,,,,,,,,,,,,,, -71000,2.3983603,0.8848964,,,,,,,,,,,,,, -71100,1.1364655,0.8858465,,,,,,,,,,,,,, -71200,0.9850198,0.8291259,,,,,,,,,,,,,, -71300,1.297078,0.8715832,,,,,,,,,,,,,, -71400,1.0609201,0.8570777,,,,,,,,,,,,,, -71500,1.3493713,0.85333204,,,,,,,,,,,,,, -71600,1.9002789,0.9259356,,,,,,,,,,,,,, -71700,1.3541689,0.8772607,,,,,,,,,,,,,, -71800,1.1917511,0.8853033,,,,,,,,,,,,,, -71900,1.304847,0.9432053,,,,,,,,,,,,,, -72000,1.136976,0.87403744,,,,,,,,,,,,,, -72100,1.3097508,0.8579079,,,,,,,,,,,,,, -72130,,,0.08532207,0.0323987469975537,0.29723024,0.0853857516630139,5348.0,0.15875757,0.0512461154103954,2472.0,57688.89278316498,63140.387882232666,57688.89278316498,5446.1889128685,2.245168685913086,0.0 -72200,1.1524137,0.829874,,,,,,,,,,,,,, -72300,1.305893,0.8800992,,,,,,,,,,,,,, -72400,1.0887119,0.91032684,,,,,,,,,,,,,, -72500,1.2815418,0.8696448,,,,,,,,,,,,,, -72600,1.5888809,0.8680385,,,,,,,,,,,,,, -72700,1.7865378,0.852099,,,,,,,,,,,,,, -72800,1.1144872,0.843207,,,,,,,,,,,,,, -72900,0.9854174,0.866006,,,,,,,,,,,,,, -73000,1.4248599,0.90552944,,,,,,,,,,,,,, -73100,1.6180325,0.85077167,,,,,,,,,,,,,, -73200,1.2498511,0.8145712,,,,,,,,,,,,,, -73300,1.0887351,0.9098062,,,,,,,,,,,,,, -73400,1.0270339,0.83710086,,,,,,,,,,,,,, -73500,1.1463615,0.7854283,,,,,,,,,,,,,, -73600,1.3425001,0.87476844,,,,,,,,,,,,,, -73700,1.0507349,0.81626594,,,,,,,,,,,,,, -73800,1.2230783,0.8798584,,,,,,,,,,,,,, -73900,1.2849123,0.82042384,,,,,,,,,,,,,, -73933,,,0.07010609,0.0270024254691219,0.29594654,0.0836865327244465,5348.0,0.15754385,0.0511648690918692,2472.0,59128.87972784042,64710.83512163162,59128.87972784042,5576.509063482285,2.30984878540039,0.0 -74000,1.1610496,0.8786777,,,,,,,,,,,,,, -74100,1.2510169,0.8945135,,,,,,,,,,,,,, -74200,1.1672747,0.8449276,,,,,,,,,,,,,, -74300,1.3827012,0.8910396,,,,,,,,,,,,,, -74400,1.2476773,0.88616896,,,,,,,,,,,,,, -74500,1.396785,0.80993426,,,,,,,,,,,,,, -74600,1.2228665,0.846684,,,,,,,,,,,,,, -74700,1.2152262,0.89976585,,,,,,,,,,,,,, -74800,1.1267622,0.8421262,,,,,,,,,,,,,, -74900,1.1420115,0.84651357,,,,,,,,,,,,,, -75000,1.0619425,0.8293541,,,,,,,,,,,,,, -75100,1.2456757,0.87484235,,,,,,,,,,,,,, -75200,1.1552224,0.9075517,,,,,,,,,,,,,, -75300,1.2331173,0.9079382,,,,,,,,,,,,,, -75400,1.0296392,0.8087959,,,,,,,,,,,,,, -75500,1.1612974,0.840243,,,,,,,,,,,,,, -75600,1.4043882,0.8678618,,,,,,,,,,,,,, -75700,1.2411091,0.87888473,,,,,,,,,,,,,, -75739,,,0.066934444,0.0250293113116174,0.2931396,0.0840534095407281,5348.0,0.15539056,0.0494586964028192,2472.0,60569.61856889725,66282.15744042397,60569.61856889725,5706.956416845322,2.3700406551361084,0.0 -75800,1.1517148,0.8525677,,,,,,,,,,,,,, -75900,1.1022316,0.8733092,,,,,,,,,,,,,, -76000,1.9537706,0.8270698,,,,,,,,,,,,,, -76100,1.1694807,0.91390544,,,,,,,,,,,,,, -76200,1.1093173,0.8674327,,,,,,,,,,,,,, -76300,1.1363078,0.85682887,,,,,,,,,,,,,, -76377,,,,,,,,,,,61068.13390135765,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 632dbec72..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -135.64001059532166,0.0,35.945107221603394,1,0,35.945107221603394,31.23744,2472,1.10235004976337,171.58516669273376,32.57337,1.3618151438861104,31.08705,5348,1.0585651254622166 -245.00649857521057,0.0288045406341552,1476.374610900879,1760,0,1476.374610900879,7.0403576,2472,0.899579550301627,1721.4816236495972,7.147141,0.938839590443686,7.041991,5348,0.8966179750330672 -353.3921241760254,0.0856506824493408,2916.5684781074524,3580,0,2916.5684781074524,5.784349,2472,0.899579550301627,3270.1946194171906,5.8638206,0.9417131519458763,5.820514,5348,0.8966179750330672 -463.0172669887543,0.1368649005889892,4356.562227487564,5374,0,4356.562227487564,5.7255216,2472,0.8994983039831007,4819.940863370895,5.8075776,0.940334510929565,5.716846,5348,0.8965503924616469 -571.2973206043243,0.1869440078735351,5797.4471888542175,7157,0,5797.4471888542175,5.6091366,2472,0.899579550301627,6369.234001159668,5.684661,0.9380313186211507,5.6359296,5348,0.8966179750330672 -679.9091982841492,0.2448720932006836,7237.664181232452,8957,0,7237.664181232452,5.4412518,2472,0.899579550301627,7918.197555303574,5.55107,0.940821302569804,5.475975,5348,0.8966179750330672 -829.0853753089905,0.2981421947479248,8677.619053125381,10800,0,8677.619053125381,4.28696,2472,0.7847175674852234,9507.45812034607,4.311075,0.812946352661112,4.4674177,5348,0.795195844637323 -962.355612039566,0.345757246017456,10117.509083271028,12614,0,10117.509083271028,2.5912423,2472,0.5990697296528751,11080.74349808693,2.5407991,0.6001672788101385,2.867275,5348,0.6317522229838671 -1095.8711938858032,0.395554780960083,11557.722923755646,14388,0,11557.722923755646,1.5566329,2472,0.4498202425202608,12654.597905158997,1.5154848,0.4560902772271994,1.8800801,5348,0.4955057590005503 -1229.3207650184631,0.4482710361480713,12998.31579875946,16164,0,12998.31579875946,1.2202077,2472,0.3743830357686917,14228.770209550858,1.1518191,0.3756104432757325,1.5427254,5348,0.4272473618660513 -1361.0830354690552,0.5029153823852539,14438.654992818832,17977,0,14438.654992818832,0.92897916,2472,0.2999410964190685,15801.002537488936,0.89231455,0.2999344248311306,1.2439867,5348,0.3619143246087452 -1493.616887807846,0.5617284774780273,15878.671095132828,19762,0,15878.671095132828,0.75496924,2472,0.2517214063737737,17373.685792684555,0.69894594,0.2439761932747069,1.0527449,5348,0.3146741071859583 -1626.2040507793429,0.6224701404571533,17318.76188802719,21532,0,17318.76188802719,0.6438148,2472,0.2174151483760892,18946.50092363357,0.62691694,0.2201369216241737,0.92000806,5348,0.2802745783330276 -1760.1241040229795,0.6804590225219727,18758.965651988983,23320,0,18758.965651988983,0.58487004,2472,0.1990534803891698,20520.760549545288,0.5428806,0.1951972010178117,0.85429704,5348,0.2618341909883468 -1893.8379728794096,0.7349987030029297,20199.36521506309,25128,0,20199.36521506309,0.534569,2472,0.1835760567099303,22095.005754709244,0.44918406,0.1670094697167437,0.79768723,5348,0.2459909053168174 -2025.173578500748,0.7892842292785645,21639.26774644852,26903,0,21639.26774644852,0.51442176,2472,0.176040460666626,23666.374056100845,0.4337378,0.1593523874575764,0.77539897,5348,0.2393967772768085 -2157.2508704662323,0.8437254428863525,23079.26572728157,28686,0,23079.26572728157,0.4679157,2472,0.1655596855767473,25238.579748630524,0.3969769,0.1507755277897458,0.7187384,5348,0.223302470625718 -2291.930029153824,0.8999524116516113,24519.34055161476,30480,0,24519.34055161476,0.44529587,2472,0.1558304389332358,26813.466148376465,0.37746027,0.1383267689800353,0.68599296,5348,0.213705745484036 -2425.518532037735,0.9573686122894288,25959.83496165276,32284,0,25959.83496165276,0.42393988,2472,0.1478276765584059,28387.68375515937,0.36455256,0.1364740165128703,0.659797,5348,0.2071888546685074 -2559.377106666565,1.0186927318572998,27400.28605747223,34050,0,27400.28605747223,0.39865455,2472,0.1377937562204212,29962.1297519207,0.34491673,0.1303387524763742,0.6365954,5348,0.1974473097309248 -2692.8508801460266,1.1042413711547852,28840.38019633293,35829,0,28840.38019633293,0.39754122,2472,0.1369406698758962,31535.858542203903,0.33196563,0.1249875557371981,0.6245508,5348,0.1948598627108335 -2827.135392189026,1.1598222255706787,30280.80466890335,37632,0,30280.80466890335,0.37677324,2472,0.1318221518087461,33110.70022511482,0.2906541,0.1126476899473298,0.6097289,5348,0.1901483920175328 -2960.3469581604004,1.2174606323242188,31721.29601264,39434,0,31721.29601264,0.37109178,2472,0.1293644506733288,34684.53674340248,0.2863048,0.1086859870300435,0.5883182,5348,0.184712822344729 -3091.5962493419647,1.2778022289276123,33161.7505197525,41198,0,33161.7505197525,0.351271,2472,0.1214226230373936,36256.37497663498,0.28539166,0.10729820385198,0.5655498,5348,0.176255346264132 -3224.325934410095,1.335150957107544,34602.12437748909,42961,0,34602.12437748909,0.3410652,2472,0.1194117766538703,37829.61024093628,0.2625836,0.0988772736434404,0.5573182,5348,0.1747299110806453 -3359.500274658203,1.389387607574463,36041.99537992477,44771,0,36041.99537992477,0.33385333,2472,0.1161822354924542,39404.787296772,0.25002244,0.0974915358571868,0.55120695,5348,0.1716018034891916 -3494.410943031311,1.4525690078735352,37482.20679521561,46549,0,37482.20679521561,0.32607883,2472,0.1123433469420916,40980.0468621254,0.27212,0.0963176872267781,0.52371204,5348,0.1627677959392529 -3629.584242582321,1.512967586517334,38922.34973311424,48311,0,38922.34973311424,0.30938765,2472,0.1074076330916255,42555.49884772301,0.19800447,0.0776606435998933,0.51308125,5348,0.1602093128783417 -3764.5926241874695,1.574493646621704,40362.35197472572,50086,0,40362.35197472572,0.30191362,2472,0.1057014604025755,44130.64695501328,0.21282442,0.0822324569887133,0.50448716,5348,0.157303262307269 -3898.094487667084,1.630784034729004,41802.36686468125,51885,0,41802.36686468125,0.29826474,2472,0.1038124834968415,45704.29691815376,0.27286735,0.1037494229120408,0.5019356,5348,0.1571680971644284 -4030.19356417656,1.6926136016845703,43242.77938055992,53658,0,43242.77938055992,0.28853884,2472,0.1002173339020575,47276.94588851929,0.28197432,0.1062609862536667,0.4842883,5348,0.1505739691244195 -4162.418050527573,1.748811960220337,44682.94412112236,55435,0,44682.94412112236,0.28182715,2472,0.0968456116832206,48849.46684360504,0.3120024,0.1183694542416929,0.47713417,5348,0.1493864467980343 -4293.939308643341,1.810908794403076,46123.28657770157,57224,0,46123.28657770157,0.2736266,2472,0.0924176873235431,50421.46864414215,0.27040535,0.1010153773917126,0.46657196,5348,0.1453218378597565 -4426.556207895279,1.867297887802124,47563.39091444016,59012,0,47563.39091444016,0.26979008,2472,0.0920317673105437,51994.320743083954,0.24490872,0.0938699363336498,0.45609862,5348,0.1409772439827374 -4559.163670539856,1.924031972885132,49003.68826889992,60762,0,49003.68826889992,0.26356447,2472,0.0903865293603883,53567.35671496391,0.20578378,0.0797272673202013,0.45228767,5348,0.1395773192890313 -4693.026634216309,1.9902966022491453,50443.94960784912,62525,0,50443.94960784912,0.2582945,2472,0.0898381167103365,55141.62321925163,0.23597239,0.0912774590432145,0.44328588,5348,0.1363526651669772 -4827.03906917572,2.052520036697388,51884.26449036598,64333,0,51884.26449036598,0.25623506,2472,0.0880710092823919,56716.0893816948,0.20355521,0.077820553951449,0.44492757,5348,0.1354644370854533 -4960.271555185318,2.1096343994140625,53324.19461846352,66115,0,53324.19461846352,0.2523467,2472,0.0857758007840269,58289.38478684425,0.20436361,0.0800104603080342,0.43434986,5348,0.1336397076571053 -5093.883127212524,2.1731672286987305,54764.53757071495,67890,0,54764.53757071495,0.24856624,2472,0.0845774175857656,59863.47870230675,0.19853008,0.0757643860245349,0.426306,5348,0.1307722756982728 -5224.875775098801,2.240447998046875,56204.93294286728,69674,0,56204.93294286728,0.24726346,2472,0.0837852659801352,61435.01011872292,0.20562021,0.0787849206121335,0.42373237,5348,0.1293337323923264 -5357.613756656647,2.299800157546997,57644.84786987305,71491,0,57644.84786987305,0.24452537,2472,0.0833790343875043,63007.79894852638,0.2064576,0.0766548233632427,0.41951063,5348,0.1278082972088398 -5488.593738079071,2.365021228790283,59085.31165289879,73269,0,59085.31165289879,0.24227335,2472,0.0833384112282412,64579.38373923302,0.18731399,0.0725551522855374,0.41804275,5348,0.1278662251272 -5620.283970355988,2.4329538345336914,60525.435987234116,75049,0,60525.435987234116,0.2408368,2472,0.08199784697255906,66151.34201812744,0.19145286,0.07430065901654453,0.41606402,5348,0.12715178080075693 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/measurements.csv deleted file mode 100644 index 3e3f9aac4..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/measurements.csv +++ /dev/null @@ -1,803 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,40.786842,32.304886,,,,,,,,,,,,,, -1,,,32.57337,1.3618151438861104,31.08705,1.0585651254622166,5348.0,31.23744,1.10235004976337,2472.0,35.945107221603394,171.58516669273376,35.945107221603394,135.64001059532166,0.0,0.0 -100,4.3110824,10.49865,,,,,,,,,,,,,, -200,0.9840107,6.046527,,,,,,,,,,,,,, -300,0.7924592,5.8594756,,,,,,,,,,,,,, -400,0.61243606,5.8249335,,,,,,,,,,,,,, -500,2.1186507,5.80933,,,,,,,,,,,,,, -600,0.517427,5.825859,,,,,,,,,,,,,, -700,2.8423176,5.800404,,,,,,,,,,,,,, -800,0.77811384,5.803669,,,,,,,,,,,,,, -900,2.1722612,5.812424,,,,,,,,,,,,,, -1000,0.42358577,5.8061047,,,,,,,,,,,,,, -1100,0.5556227,5.76439,,,,,,,,,,,,,, -1200,0.3796498,5.811295,,,,,,,,,,,,,, -1300,0.7388076,5.7928333,,,,,,,,,,,,,, -1400,0.60259855,5.658049,,,,,,,,,,,,,, -1500,0.41349247,5.52914,,,,,,,,,,,,,, -1600,0.7574542,5.530147,,,,,,,,,,,,,, -1700,0.5048182,5.52581,,,,,,,,,,,,,, -1760,,,7.147141,0.938839590443686,7.041991,0.8966179750330672,5348.0,7.0403576,0.899579550301627,2472.0,1476.374610900879,1721.4816236495972,1476.374610900879,245.00649857521057,0.0288045406341552,0.0 -1800,1.851145,5.5218544,,,,,,,,,,,,,, -1900,0.21175344,5.795949,,,,,,,,,,,,,, -2000,0.24839543,5.792898,,,,,,,,,,,,,, -2100,0.27693662,5.773966,,,,,,,,,,,,,, -2200,0.2351677,5.783643,,,,,,,,,,,,,, -2300,0.40745258,5.7925954,,,,,,,,,,,,,, -2400,0.32464695,5.780691,,,,,,,,,,,,,, -2500,0.29806462,5.7930956,,,,,,,,,,,,,, -2600,0.40399,5.784899,,,,,,,,,,,,,, -2700,0.6775421,5.7771173,,,,,,,,,,,,,, -2800,0.28869346,5.8112483,,,,,,,,,,,,,, -2900,0.70396304,5.7861633,,,,,,,,,,,,,, -3000,0.90836793,5.8076673,,,,,,,,,,,,,, -3100,0.9250737,5.7748437,,,,,,,,,,,,,, -3200,0.645942,5.7670746,,,,,,,,,,,,,, -3300,0.527208,5.797604,,,,,,,,,,,,,, -3400,0.49066448,5.602515,,,,,,,,,,,,,, -3500,4.9987264,5.812967,,,,,,,,,,,,,, -3580,,,5.8638206,0.9417131519458763,5.820514,0.8966179750330672,5348.0,5.784349,0.899579550301627,2472.0,2916.5684781074524,3270.1946194171906,2916.5684781074524,353.3921241760254,0.0856506824493408,0.0 -3600,1.2941027,5.5767083,,,,,,,,,,,,,, -3700,1.1472689,5.512039,,,,,,,,,,,,,, -3800,0.5707582,5.53448,,,,,,,,,,,,,, -3900,0.75363,5.525724,,,,,,,,,,,,,, -4000,1.5002002,5.538712,,,,,,,,,,,,,, -4100,0.45646977,5.509579,,,,,,,,,,,,,, -4200,0.493837,5.5168185,,,,,,,,,,,,,, -4300,1.3551065,5.526008,,,,,,,,,,,,,, -4400,0.48356164,5.5190635,,,,,,,,,,,,,, -4500,0.37526798,5.523414,,,,,,,,,,,,,, -4600,0.34308675,5.5265965,,,,,,,,,,,,,, -4700,0.71962214,5.518023,,,,,,,,,,,,,, -4800,0.7655572,5.523728,,,,,,,,,,,,,, -4900,0.63938516,5.4948616,,,,,,,,,,,,,, -5000,0.8109749,5.500837,,,,,,,,,,,,,, -5100,1.4886025,5.5100546,,,,,,,,,,,,,, -5200,9.531803,5.5384197,,,,,,,,,,,,,, -5300,0.30111703,5.5138974,,,,,,,,,,,,,, -5374,,,5.8075776,0.940334510929565,5.716846,0.8965503924616469,5348.0,5.7255216,0.8994983039831007,2472.0,4356.562227487564,4819.940863370895,4356.562227487564,463.0172669887543,0.1368649005889892,0.0 -5400,0.8504049,5.501849,,,,,,,,,,,,,, -5500,0.5702205,5.496133,,,,,,,,,,,,,, -5600,0.8408251,5.498366,,,,,,,,,,,,,, -5700,0.92733324,5.4969425,,,,,,,,,,,,,, -5800,1.4970317,5.5642776,,,,,,,,,,,,,, -5900,0.38682735,5.482661,,,,,,,,,,,,,, -6000,0.9348911,5.5184436,,,,,,,,,,,,,, -6100,0.35270655,5.472892,,,,,,,,,,,,,, -6200,0.27464524,5.476278,,,,,,,,,,,,,, -6300,0.41664717,5.4912057,,,,,,,,,,,,,, -6400,0.40592718,5.4514275,,,,,,,,,,,,,, -6500,0.35017753,5.4848995,,,,,,,,,,,,,, -6600,0.7976031,5.495403,,,,,,,,,,,,,, -6700,0.56747496,5.4703355,,,,,,,,,,,,,, -6800,0.3159434,5.490645,,,,,,,,,,,,,, -6900,0.9918735,5.486289,,,,,,,,,,,,,, -7000,0.5411699,5.462322,,,,,,,,,,,,,, -7100,0.59229445,5.4569573,,,,,,,,,,,,,, -7157,,,5.684661,0.9380313186211507,5.6359296,0.8966179750330672,5348.0,5.6091366,0.899579550301627,2472.0,5797.4471888542175,6369.234001159668,5797.4471888542175,571.2973206043243,0.1869440078735351,0.0 -7200,0.640166,5.469576,,,,,,,,,,,,,, -7300,0.5691814,5.457305,,,,,,,,,,,,,, -7400,0.8107567,5.477365,,,,,,,,,,,,,, -7500,1.5821881,5.4652224,,,,,,,,,,,,,, -7600,0.6196199,5.458225,,,,,,,,,,,,,, -7700,1.0208366,5.466438,,,,,,,,,,,,,, -7800,0.6164023,5.434881,,,,,,,,,,,,,, -7900,1.8003418,5.4557495,,,,,,,,,,,,,, -8000,1.3614093,5.4798474,,,,,,,,,,,,,, -8100,0.7599719,5.488925,,,,,,,,,,,,,, -8200,0.60451853,5.4476523,,,,,,,,,,,,,, -8300,0.5179183,5.4463835,,,,,,,,,,,,,, -8400,0.39556757,5.4241924,,,,,,,,,,,,,, -8500,0.9371677,5.4250426,,,,,,,,,,,,,, -8600,0.55863684,5.431657,,,,,,,,,,,,,, -8700,0.40197927,5.430676,,,,,,,,,,,,,, -8800,0.99356484,5.424018,,,,,,,,,,,,,, -8900,0.41892067,5.4108067,,,,,,,,,,,,,, -8957,,,5.55107,0.940821302569804,5.475975,0.8966179750330672,5348.0,5.4412518,0.899579550301627,2472.0,7237.664181232452,7918.197555303574,7237.664181232452,679.9091982841492,0.2448720932006836,0.0 -9000,0.5247257,5.407474,,,,,,,,,,,,,, -9100,0.60455054,5.420511,,,,,,,,,,,,,, -9200,0.9389545,5.3949347,,,,,,,,,,,,,, -9300,0.7159285,5.391618,,,,,,,,,,,,,, -9400,0.5770337,5.373419,,,,,,,,,,,,,, -9500,0.41257375,5.4110312,,,,,,,,,,,,,, -9600,0.7001106,5.3870773,,,,,,,,,,,,,, -9700,0.3951041,5.383956,,,,,,,,,,,,,, -9800,0.8747939,5.37667,,,,,,,,,,,,,, -9900,0.37294793,5.3629537,,,,,,,,,,,,,, -10000,1.1997865,5.3494906,,,,,,,,,,,,,, -10100,0.7553692,5.2780924,,,,,,,,,,,,,, -10200,0.70074654,5.1334753,,,,,,,,,,,,,, -10300,1.3534747,4.927495,,,,,,,,,,,,,, -10400,1.6800984,4.6144953,,,,,,,,,,,,,, -10500,0.7644579,4.399341,,,,,,,,,,,,,, -10600,0.8297515,4.468192,,,,,,,,,,,,,, -10700,0.93543196,4.2087,,,,,,,,,,,,,, -10800,,,4.311075,0.812946352661112,4.4674177,0.795195844637323,5348.0,4.28696,0.7847175674852234,2472.0,8677.619053125381,9507.45812034607,8677.619053125381,829.0853753089905,0.2981421947479248,0.0 -10800,0.7687471,4.1322217,,,,,,,,,,,,,, -10900,0.76139104,4.0064797,,,,,,,,,,,,,, -11000,0.76430017,3.9495707,,,,,,,,,,,,,, -11100,0.73712516,3.9503775,,,,,,,,,,,,,, -11200,1.0501137,3.8753476,,,,,,,,,,,,,, -11300,0.93455404,3.865771,,,,,,,,,,,,,, -11400,1.2338194,3.7659836,,,,,,,,,,,,,, -11500,1.1008257,3.7106385,,,,,,,,,,,,,, -11600,0.8421493,3.729955,,,,,,,,,,,,,, -11700,1.2371861,3.6325169,,,,,,,,,,,,,, -11800,2.418772,3.5933247,,,,,,,,,,,,,, -11900,1.0327973,3.5152252,,,,,,,,,,,,,, -12000,1.4041717,3.5536106,,,,,,,,,,,,,, -12100,0.84430134,3.433368,,,,,,,,,,,,,, -12200,1.0598441,3.3634257,,,,,,,,,,,,,, -12300,1.1304483,3.3593452,,,,,,,,,,,,,, -12400,1.0136127,3.330556,,,,,,,,,,,,,, -12500,1.4148915,3.2866535,,,,,,,,,,,,,, -12600,0.9236208,3.300476,,,,,,,,,,,,,, -12614,,,2.5407991,0.6001672788101385,2.867275,0.6317522229838671,5348.0,2.5912423,0.5990697296528751,2472.0,10117.509083271028,11080.74349808693,10117.509083271028,962.355612039566,0.345757246017456,0.0 -12700,1.0358553,3.2086313,,,,,,,,,,,,,, -12800,1.1377467,3.2100687,,,,,,,,,,,,,, -12900,1.0412422,3.158434,,,,,,,,,,,,,, -13000,0.9628804,3.0282772,,,,,,,,,,,,,, -13100,0.75938845,3.0661557,,,,,,,,,,,,,, -13200,0.95255965,3.0005994,,,,,,,,,,,,,, -13300,1.4232062,2.9263556,,,,,,,,,,,,,, -13400,1.1970283,2.948447,,,,,,,,,,,,,, -13500,1.1244439,2.991532,,,,,,,,,,,,,, -13600,1.0458843,2.9529932,,,,,,,,,,,,,, -13700,1.1336601,2.8779714,,,,,,,,,,,,,, -13800,0.84262854,2.8236425,,,,,,,,,,,,,, -13900,0.9536529,2.8055363,,,,,,,,,,,,,, -14000,1.5885041,2.83737,,,,,,,,,,,,,, -14100,1.0706663,2.7956257,,,,,,,,,,,,,, -14200,1.0307431,2.7811377,,,,,,,,,,,,,, -14300,0.8706629,2.739196,,,,,,,,,,,,,, -14388,,,1.5154848,0.4560902772271994,1.8800801,0.4955057590005503,5348.0,1.5566329,0.4498202425202608,2472.0,11557.722923755646,12654.597905158997,11557.722923755646,1095.8711938858032,0.395554780960083,0.0 -14400,0.72463036,2.7441168,,,,,,,,,,,,,, -14500,0.9690197,2.6943386,,,,,,,,,,,,,, -14600,0.9040949,2.6589987,,,,,,,,,,,,,, -14700,0.7923906,2.708236,,,,,,,,,,,,,, -14800,0.989968,2.6256034,,,,,,,,,,,,,, -14900,1.3962868,2.6326492,,,,,,,,,,,,,, -15000,0.9468417,2.659631,,,,,,,,,,,,,, -15100,0.8878627,2.622525,,,,,,,,,,,,,, -15200,1.0407087,2.6254284,,,,,,,,,,,,,, -15300,0.8467367,2.5923007,,,,,,,,,,,,,, -15400,0.8080323,2.4805481,,,,,,,,,,,,,, -15500,0.7812358,2.5200572,,,,,,,,,,,,,, -15600,2.426437,2.4687915,,,,,,,,,,,,,, -15700,0.94812995,2.5120876,,,,,,,,,,,,,, -15800,1.1230013,2.4368787,,,,,,,,,,,,,, -15900,0.8651819,2.4517248,,,,,,,,,,,,,, -16000,0.8364754,2.3542457,,,,,,,,,,,,,, -16100,0.9082248,2.4772954,,,,,,,,,,,,,, -16164,,,1.1518191,0.3756104432757325,1.5427254,0.4272473618660513,5348.0,1.2202077,0.3743830357686917,2472.0,12998.31579875946,14228.770209550858,12998.31579875946,1229.3207650184631,0.4482710361480713,0.0 -16200,0.87650913,2.370383,,,,,,,,,,,,,, -16300,1.0905447,2.3626442,,,,,,,,,,,,,, -16400,1.1967976,2.3921776,,,,,,,,,,,,,, -16500,0.90353125,2.393886,,,,,,,,,,,,,, -16600,0.7414656,2.2719638,,,,,,,,,,,,,, -16700,1.0563092,2.2835352,,,,,,,,,,,,,, -16800,0.79170954,2.2646127,,,,,,,,,,,,,, -16900,1.1536012,2.2848296,,,,,,,,,,,,,, -17000,0.7161067,2.2275043,,,,,,,,,,,,,, -17100,0.7628895,2.2788744,,,,,,,,,,,,,, -17200,0.66873163,2.2279067,,,,,,,,,,,,,, -17300,0.8235757,2.2080107,,,,,,,,,,,,,, -17400,0.74308956,2.218356,,,,,,,,,,,,,, -17500,1.0217409,2.1908643,,,,,,,,,,,,,, -17600,0.83938605,2.1501534,,,,,,,,,,,,,, -17700,0.79505646,2.1643322,,,,,,,,,,,,,, -17800,1.91434,2.2050579,,,,,,,,,,,,,, -17900,1.2030313,2.1127818,,,,,,,,,,,,,, -17977,,,0.89231455,0.2999344248311306,1.2439867,0.3619143246087452,5348.0,0.92897916,0.2999410964190685,2472.0,14438.654992818832,15801.002537488936,14438.654992818832,1361.0830354690552,0.5029153823852539,0.0 -18000,0.7969752,2.1053243,,,,,,,,,,,,,, -18100,0.9015946,2.077633,,,,,,,,,,,,,, -18200,0.85798025,2.1971788,,,,,,,,,,,,,, -18300,0.7462399,2.0459306,,,,,,,,,,,,,, -18400,0.9669405,2.0378752,,,,,,,,,,,,,, -18500,0.80244404,2.0685725,,,,,,,,,,,,,, -18600,0.8948533,2.0795588,,,,,,,,,,,,,, -18700,1.0276401,2.0731182,,,,,,,,,,,,,, -18800,0.7089602,2.0317829,,,,,,,,,,,,,, -18900,0.7935293,2.0123818,,,,,,,,,,,,,, -19000,1.0376363,2.004071,,,,,,,,,,,,,, -19100,0.96603733,2.025001,,,,,,,,,,,,,, -19200,0.8186861,1.9821671,,,,,,,,,,,,,, -19300,0.7142566,2.096302,,,,,,,,,,,,,, -19400,0.74147594,1.9887081,,,,,,,,,,,,,, -19500,0.7136231,2.0242286,,,,,,,,,,,,,, -19600,0.71207446,1.9988811,,,,,,,,,,,,,, -19700,1.0098206,2.0033352,,,,,,,,,,,,,, -19762,,,0.69894594,0.2439761932747069,1.0527449,0.3146741071859583,5348.0,0.75496924,0.2517214063737737,2472.0,15878.671095132828,17373.685792684555,15878.671095132828,1493.616887807846,0.5617284774780273,0.0 -19800,1.0285492,1.9603268,,,,,,,,,,,,,, -19900,0.7561634,1.9378786,,,,,,,,,,,,,, -20000,0.8698228,1.93064,,,,,,,,,,,,,, -20100,1.3813714,1.8631349,,,,,,,,,,,,,, -20200,0.7364758,1.9530984,,,,,,,,,,,,,, -20300,0.7882989,1.8933097,,,,,,,,,,,,,, -20400,0.6749902,1.9015697,,,,,,,,,,,,,, -20500,0.75490725,1.9584703,,,,,,,,,,,,,, -20600,0.68561625,1.9360505,,,,,,,,,,,,,, -20700,0.7584766,1.8723198,,,,,,,,,,,,,, -20800,1.0984206,1.9380244,,,,,,,,,,,,,, -20900,0.83063895,1.9438756,,,,,,,,,,,,,, -21000,0.7084007,1.9094476,,,,,,,,,,,,,, -21100,0.69582164,1.8813365,,,,,,,,,,,,,, -21200,0.7881478,1.8092453,,,,,,,,,,,,,, -21300,0.8179406,1.8751293,,,,,,,,,,,,,, -21400,1.0290294,1.8838878,,,,,,,,,,,,,, -21500,0.8358455,1.8699937,,,,,,,,,,,,,, -21532,,,0.62691694,0.2201369216241737,0.92000806,0.2802745783330276,5348.0,0.6438148,0.2174151483760892,2472.0,17318.76188802719,18946.50092363357,17318.76188802719,1626.2040507793429,0.6224701404571533,0.0 -21600,0.73536384,1.8286903,,,,,,,,,,,,,, -21700,0.812015,1.8002965,,,,,,,,,,,,,, -21800,0.9170946,1.8188787,,,,,,,,,,,,,, -21900,0.9825019,1.7864766,,,,,,,,,,,,,, -22000,1.1061513,1.8539294,,,,,,,,,,,,,, -22100,1.1242133,1.8350953,,,,,,,,,,,,,, -22200,0.8251492,1.8217294,,,,,,,,,,,,,, -22300,1.3698188,1.8303263,,,,,,,,,,,,,, -22400,0.779665,1.7926288,,,,,,,,,,,,,, -22500,0.69795483,1.7717108,,,,,,,,,,,,,, -22600,0.8123699,1.849199,,,,,,,,,,,,,, -22700,0.78596324,1.7937716,,,,,,,,,,,,,, -22800,0.84737104,1.7782376,,,,,,,,,,,,,, -22900,0.726195,1.743966,,,,,,,,,,,,,, -23000,1.2247798,1.8075092,,,,,,,,,,,,,, -23100,0.7105258,1.7657127,,,,,,,,,,,,,, -23200,0.9902371,1.7802997,,,,,,,,,,,,,, -23300,0.6455476,1.775783,,,,,,,,,,,,,, -23320,,,0.5428806,0.1951972010178117,0.85429704,0.2618341909883468,5348.0,0.58487004,0.1990534803891698,2472.0,18758.965651988983,20520.760549545288,18758.965651988983,1760.1241040229795,0.6804590225219727,0.0 -23400,0.9872751,1.7926563,,,,,,,,,,,,,, -23500,0.75796175,1.768993,,,,,,,,,,,,,, -23600,0.7228446,1.7775788,,,,,,,,,,,,,, -23700,0.7954681,1.7517316,,,,,,,,,,,,,, -23800,0.86205035,1.771678,,,,,,,,,,,,,, -23900,0.8632302,1.7292084,,,,,,,,,,,,,, -24000,0.73836386,1.7275032,,,,,,,,,,,,,, -24100,1.0026656,1.7584964,,,,,,,,,,,,,, -24200,0.718464,1.7708775,,,,,,,,,,,,,, -24300,1.314983,1.7168481,,,,,,,,,,,,,, -24400,0.7511684,1.7862602,,,,,,,,,,,,,, -24500,1.2878503,1.7218666,,,,,,,,,,,,,, -24600,0.76890194,1.7008477,,,,,,,,,,,,,, -24700,0.65554506,1.7457352,,,,,,,,,,,,,, -24800,0.74778265,1.7712879,,,,,,,,,,,,,, -24900,0.6885046,1.712362,,,,,,,,,,,,,, -25000,0.7985767,1.7673345,,,,,,,,,,,,,, -25100,0.6677719,1.659193,,,,,,,,,,,,,, -25128,,,0.44918406,0.1670094697167437,0.79768723,0.2459909053168174,5348.0,0.534569,0.1835760567099303,2472.0,20199.36521506309,22095.005754709244,20199.36521506309,1893.8379728794096,0.7349987030029297,0.0 -25200,1.2819151,1.6717904,,,,,,,,,,,,,, -25300,0.8019315,1.7216445,,,,,,,,,,,,,, -25400,0.64044607,1.6599982,,,,,,,,,,,,,, -25500,1.0831127,1.6601138,,,,,,,,,,,,,, -25600,0.93733394,1.7138394,,,,,,,,,,,,,, -25700,0.8787853,1.7195953,,,,,,,,,,,,,, -25800,1.0869471,1.6640548,,,,,,,,,,,,,, -25900,0.8734548,1.694394,,,,,,,,,,,,,, -26000,0.7136815,1.7186034,,,,,,,,,,,,,, -26100,0.80401635,1.6821748,,,,,,,,,,,,,, -26200,0.7531477,1.6625099,,,,,,,,,,,,,, -26300,0.76978976,1.6641816,,,,,,,,,,,,,, -26400,0.69039464,1.6522298,,,,,,,,,,,,,, -26500,0.67449063,1.7237548,,,,,,,,,,,,,, -26600,0.6994063,1.651277,,,,,,,,,,,,,, -26700,0.8507517,1.6899523,,,,,,,,,,,,,, -26800,0.87328905,1.7596157,,,,,,,,,,,,,, -26900,0.6821998,1.7136612,,,,,,,,,,,,,, -26903,,,0.4337378,0.1593523874575764,0.77539897,0.2393967772768085,5348.0,0.51442176,0.176040460666626,2472.0,21639.26774644852,23666.374056100845,21639.26774644852,2025.173578500748,0.7892842292785645,0.0 -27000,1.0881909,1.7363662,,,,,,,,,,,,,, -27100,0.6542847,1.6152455,,,,,,,,,,,,,, -27200,1.3033422,1.6036893,,,,,,,,,,,,,, -27300,0.6789921,1.6021941,,,,,,,,,,,,,, -27400,0.77734345,1.6134506,,,,,,,,,,,,,, -27500,1.1692277,1.5865101,,,,,,,,,,,,,, -27600,0.89855516,1.7136203,,,,,,,,,,,,,, -27700,0.7661427,1.6296346,,,,,,,,,,,,,, -27800,0.67444867,1.6474917,,,,,,,,,,,,,, -27900,0.70214146,1.6020889,,,,,,,,,,,,,, -28000,0.6227277,1.57631,,,,,,,,,,,,,, -28100,0.85214406,1.6539667,,,,,,,,,,,,,, -28200,1.2184554,1.6174204,,,,,,,,,,,,,, -28300,0.5755251,1.6189487,,,,,,,,,,,,,, -28400,0.6918505,1.607309,,,,,,,,,,,,,, -28500,0.7835231,1.6345867,,,,,,,,,,,,,, -28600,0.70756924,1.6119943,,,,,,,,,,,,,, -28686,,,0.3969769,0.1507755277897458,0.7187384,0.223302470625718,5348.0,0.4679157,0.1655596855767473,2472.0,23079.26572728157,25238.579748630524,23079.26572728157,2157.2508704662323,0.8437254428863525,0.0 -28700,0.71602297,1.578996,,,,,,,,,,,,,, -28800,0.6512488,1.5799415,,,,,,,,,,,,,, -28900,0.6235401,1.6087931,,,,,,,,,,,,,, -29000,0.8437107,1.5649581,,,,,,,,,,,,,, -29100,0.69660914,1.6031779,,,,,,,,,,,,,, -29200,0.9828238,1.5892466,,,,,,,,,,,,,, -29300,0.64799744,1.6147153,,,,,,,,,,,,,, -29400,0.9311987,1.5675422,,,,,,,,,,,,,, -29500,0.7234617,1.6220608,,,,,,,,,,,,,, -29600,0.8596421,1.5662636,,,,,,,,,,,,,, -29700,0.65928745,1.6139045,,,,,,,,,,,,,, -29800,0.63809097,1.5616997,,,,,,,,,,,,,, -29900,0.9060632,1.5228308,,,,,,,,,,,,,, -30000,0.7319001,1.5458928,,,,,,,,,,,,,, -30100,0.81958246,1.5042797,,,,,,,,,,,,,, -30200,0.68123776,1.5278322,,,,,,,,,,,,,, -30300,0.8249164,1.5518814,,,,,,,,,,,,,, -30400,0.75895715,1.5179793,,,,,,,,,,,,,, -30480,,,0.37746027,0.1383267689800353,0.68599296,0.213705745484036,5348.0,0.44529587,0.1558304389332358,2472.0,24519.34055161476,26813.466148376465,24519.34055161476,2291.930029153824,0.8999524116516113,0.0 -30500,0.69247246,1.5544819,,,,,,,,,,,,,, -30600,0.99213845,1.5287999,,,,,,,,,,,,,, -30700,0.6929848,1.5754012,,,,,,,,,,,,,, -30800,0.9142481,1.485097,,,,,,,,,,,,,, -30900,1.0178708,1.5461911,,,,,,,,,,,,,, -31000,0.77440405,1.5552567,,,,,,,,,,,,,, -31100,0.65054953,1.493115,,,,,,,,,,,,,, -31200,0.7542856,1.5083566,,,,,,,,,,,,,, -31300,0.6940626,1.5669292,,,,,,,,,,,,,, -31400,0.74574345,1.5755206,,,,,,,,,,,,,, -31500,0.7482579,1.4744998,,,,,,,,,,,,,, -31600,0.6924164,1.5602229,,,,,,,,,,,,,, -31700,1.036533,1.5712707,,,,,,,,,,,,,, -31800,1.4245418,1.5002882,,,,,,,,,,,,,, -31900,0.76453424,1.5663551,,,,,,,,,,,,,, -32000,0.94790024,1.516338,,,,,,,,,,,,,, -32100,0.7993086,1.5183187,,,,,,,,,,,,,, -32200,0.8347914,1.5280486,,,,,,,,,,,,,, -32284,,,0.36455256,0.1364740165128703,0.659797,0.2071888546685074,5348.0,0.42393988,0.1478276765584059,2472.0,25959.83496165276,28387.68375515937,25959.83496165276,2425.518532037735,0.9573686122894288,0.0 -32300,1.0039,1.5150212,,,,,,,,,,,,,, -32400,1.1239371,1.458501,,,,,,,,,,,,,, -32500,0.73172873,1.4929328,,,,,,,,,,,,,, -32600,0.67670476,1.5125141,,,,,,,,,,,,,, -32700,0.82767093,1.5210668,,,,,,,,,,,,,, -32800,0.7170565,1.5480136,,,,,,,,,,,,,, -32900,0.74177253,1.4541229,,,,,,,,,,,,,, -33000,0.79121333,1.5103151,,,,,,,,,,,,,, -33100,0.6926166,1.4866222,,,,,,,,,,,,,, -33200,0.94649523,1.5241446,,,,,,,,,,,,,, -33300,0.7191436,1.4976051,,,,,,,,,,,,,, -33400,0.67266196,1.5151798,,,,,,,,,,,,,, -33500,0.6787438,1.4637039,,,,,,,,,,,,,, -33600,2.5922334,1.4914283,,,,,,,,,,,,,, -33700,0.85100335,1.5512846,,,,,,,,,,,,,, -33800,0.9072776,1.5238736,,,,,,,,,,,,,, -33900,0.6878389,1.5153245,,,,,,,,,,,,,, -34000,0.7799429,1.4879041,,,,,,,,,,,,,, -34050,,,0.34491673,0.1303387524763742,0.6365954,0.1974473097309248,5348.0,0.39865455,0.1377937562204212,2472.0,27400.28605747223,29962.1297519207,27400.28605747223,2559.377106666565,1.0186927318572998,0.0 -34100,0.87054837,1.4972994,,,,,,,,,,,,,, -34200,0.85625833,1.4745005,,,,,,,,,,,,,, -34300,0.7626874,1.4673771,,,,,,,,,,,,,, -34400,0.7257872,1.5337331,,,,,,,,,,,,,, -34500,0.84161556,1.5064316,,,,,,,,,,,,,, -34600,0.92891324,1.5048724,,,,,,,,,,,,,, -34700,1.0044414,1.5204875,,,,,,,,,,,,,, -34800,0.95228237,1.524062,,,,,,,,,,,,,, -34900,0.66149217,1.4380646,,,,,,,,,,,,,, -35000,1.8088667,1.4794983,,,,,,,,,,,,,, -35100,1.1669502,1.4987828,,,,,,,,,,,,,, -35200,1.0783963,1.4714721,,,,,,,,,,,,,, -35300,1.8498379,1.4972134,,,,,,,,,,,,,, -35400,1.2418089,1.4817514,,,,,,,,,,,,,, -35500,0.7886706,1.4733691,,,,,,,,,,,,,, -35600,0.82335484,1.4399776,,,,,,,,,,,,,, -35700,0.7404618,1.5185901,,,,,,,,,,,,,, -35800,0.96755075,1.4844335,,,,,,,,,,,,,, -35829,,,0.33196563,0.1249875557371981,0.6245508,0.1948598627108335,5348.0,0.39754122,0.1369406698758962,2472.0,28840.38019633293,31535.858542203903,28840.38019633293,2692.8508801460266,1.1042413711547852,0.0 -35900,0.6661249,1.5097656,,,,,,,,,,,,,, -36000,0.71900195,1.5146258,,,,,,,,,,,,,, -36100,0.6962673,1.4869542,,,,,,,,,,,,,, -36200,1.0344379,1.4935805,,,,,,,,,,,,,, -36300,0.83111113,1.4566656,,,,,,,,,,,,,, -36400,0.75870854,1.4332098,,,,,,,,,,,,,, -36500,0.7235997,1.4347408,,,,,,,,,,,,,, -36600,1.0230696,1.457089,,,,,,,,,,,,,, -36700,0.9639714,1.496932,,,,,,,,,,,,,, -36800,1.4703492,1.44806,,,,,,,,,,,,,, -36900,0.8195961,1.4857445,,,,,,,,,,,,,, -37000,0.78247875,1.4419872,,,,,,,,,,,,,, -37100,0.69790125,1.4076862,,,,,,,,,,,,,, -37200,1.1687149,1.468667,,,,,,,,,,,,,, -37300,1.0295521,1.4421651,,,,,,,,,,,,,, -37400,1.1517855,1.4067239,,,,,,,,,,,,,, -37500,0.65449965,1.4366372,,,,,,,,,,,,,, -37600,1.1801274,1.4650635,,,,,,,,,,,,,, -37632,,,0.2906541,0.1126476899473298,0.6097289,0.1901483920175328,5348.0,0.37677324,0.1318221518087461,2472.0,30280.80466890335,33110.70022511482,30280.80466890335,2827.135392189026,1.1598222255706787,0.0 -37700,1.2580746,1.4622434,,,,,,,,,,,,,, -37800,0.65943664,1.4237275,,,,,,,,,,,,,, -37900,0.7038637,1.4357028,,,,,,,,,,,,,, -38000,0.8266633,1.4781679,,,,,,,,,,,,,, -38100,0.6854661,1.4059904,,,,,,,,,,,,,, -38200,0.8676422,1.4451393,,,,,,,,,,,,,, -38300,1.0972908,1.4423032,,,,,,,,,,,,,, -38400,0.7040989,1.3966112,,,,,,,,,,,,,, -38500,0.84724164,1.44502,,,,,,,,,,,,,, -38600,0.6718954,1.4314324,,,,,,,,,,,,,, -38700,0.88969857,1.4283737,,,,,,,,,,,,,, -38800,0.7893879,1.3877728,,,,,,,,,,,,,, -38900,0.8283111,1.4395698,,,,,,,,,,,,,, -39000,1.0343812,1.3921705,,,,,,,,,,,,,, -39100,0.677927,1.428751,,,,,,,,,,,,,, -39200,0.81430846,1.3970199,,,,,,,,,,,,,, -39300,0.6657112,1.3958492,,,,,,,,,,,,,, -39400,0.9001463,1.4300569,,,,,,,,,,,,,, -39434,,,0.2863048,0.1086859870300435,0.5883182,0.184712822344729,5348.0,0.37109178,0.1293644506733288,2472.0,31721.29601264,34684.53674340248,31721.29601264,2960.3469581604004,1.2174606323242188,0.0 -39500,0.6543179,1.4190292,,,,,,,,,,,,,, -39600,0.77720696,1.3813334,,,,,,,,,,,,,, -39700,0.77380645,1.4250891,,,,,,,,,,,,,, -39800,0.70041144,1.3653262,,,,,,,,,,,,,, -39900,0.7245723,1.398763,,,,,,,,,,,,,, -40000,0.72051185,1.3736415,,,,,,,,,,,,,, -40100,0.813132,1.3694956,,,,,,,,,,,,,, -40200,0.9955671,1.4644992,,,,,,,,,,,,,, -40300,0.76283115,1.405286,,,,,,,,,,,,,, -40400,0.666863,1.4390217,,,,,,,,,,,,,, -40500,0.79219794,1.4185452,,,,,,,,,,,,,, -40600,0.8347136,1.3943427,,,,,,,,,,,,,, -40700,0.89488643,1.3896887,,,,,,,,,,,,,, -40800,1.0339823,1.4138793,,,,,,,,,,,,,, -40900,0.80906326,1.41035,,,,,,,,,,,,,, -41000,0.87953115,1.3746278,,,,,,,,,,,,,, -41100,0.9468744,1.398208,,,,,,,,,,,,,, -41198,,,0.28539166,0.10729820385198,0.5655498,0.176255346264132,5348.0,0.351271,0.1214226230373936,2472.0,33161.7505197525,36256.37497663498,33161.7505197525,3091.5962493419647,1.2778022289276123,0.0 -41200,0.68095773,1.3795981,,,,,,,,,,,,,, -41300,1.2070272,1.3431224,,,,,,,,,,,,,, -41400,0.78887445,1.3827554,,,,,,,,,,,,,, -41500,0.92624015,1.3570012,,,,,,,,,,,,,, -41600,0.827173,1.4176157,,,,,,,,,,,,,, -41700,0.81153196,1.3718712,,,,,,,,,,,,,, -41800,0.6922123,1.3789029,,,,,,,,,,,,,, -41900,1.2611425,1.387617,,,,,,,,,,,,,, -42000,0.74499416,1.3859323,,,,,,,,,,,,,, -42100,0.793986,1.3512828,,,,,,,,,,,,,, -42200,0.7691312,1.3746554,,,,,,,,,,,,,, -42300,0.72772366,1.353974,,,,,,,,,,,,,, -42400,0.7416378,1.3948226,,,,,,,,,,,,,, -42500,1.0994961,1.3528199,,,,,,,,,,,,,, -42600,0.80659527,1.4007696,,,,,,,,,,,,,, -42700,0.67447215,1.352384,,,,,,,,,,,,,, -42800,0.7435555,1.3482267,,,,,,,,,,,,,, -42900,0.90793717,1.3303913,,,,,,,,,,,,,, -42961,,,0.2625836,0.0988772736434404,0.5573182,0.1747299110806453,5348.0,0.3410652,0.1194117766538703,2472.0,34602.12437748909,37829.61024093628,34602.12437748909,3224.325934410095,1.335150957107544,0.0 -43000,1.3721236,1.4260358,,,,,,,,,,,,,, -43100,0.92808074,1.4202255,,,,,,,,,,,,,, -43200,0.80410725,1.4102049,,,,,,,,,,,,,, -43300,0.6892419,1.3420222,,,,,,,,,,,,,, -43400,2.5720747,1.3584851,,,,,,,,,,,,,, -43500,0.7081877,1.345132,,,,,,,,,,,,,, -43600,0.75712126,1.4036018,,,,,,,,,,,,,, -43700,0.72708946,1.3517805,,,,,,,,,,,,,, -43800,0.7450126,1.3583881,,,,,,,,,,,,,, -43900,0.722541,1.4091585,,,,,,,,,,,,,, -44000,0.85090804,1.4488883,,,,,,,,,,,,,, -44100,0.7568694,1.3902986,,,,,,,,,,,,,, -44200,1.0787483,1.335858,,,,,,,,,,,,,, -44300,0.75730884,1.3435144,,,,,,,,,,,,,, -44400,0.75442255,1.3483974,,,,,,,,,,,,,, -44500,0.7399559,1.2779566,,,,,,,,,,,,,, -44600,0.9068672,1.3274193,,,,,,,,,,,,,, -44700,1.3208854,1.3344783,,,,,,,,,,,,,, -44771,,,0.25002244,0.0974915358571868,0.55120695,0.1716018034891916,5348.0,0.33385333,0.1161822354924542,2472.0,36041.99537992477,39404.787296772,36041.99537992477,3359.500274658203,1.389387607574463,0.0 -44800,0.8497704,1.3559756,,,,,,,,,,,,,, -44900,0.93849444,1.3403579,,,,,,,,,,,,,, -45000,0.8391726,1.2947804,,,,,,,,,,,,,, -45100,0.8577537,1.345215,,,,,,,,,,,,,, -45200,1.0016124,1.4009641,,,,,,,,,,,,,, -45300,0.90872973,1.3622271,,,,,,,,,,,,,, -45400,0.767611,1.3067495,,,,,,,,,,,,,, -45500,0.80345124,1.36023,,,,,,,,,,,,,, -45600,0.7060658,1.3227926,,,,,,,,,,,,,, -45700,0.8190684,1.3000982,,,,,,,,,,,,,, -45800,0.75148624,1.3061479,,,,,,,,,,,,,, -45900,0.8399595,1.3322165,,,,,,,,,,,,,, -46000,0.8686167,1.3745954,,,,,,,,,,,,,, -46100,0.7484481,1.3191854,,,,,,,,,,,,,, -46200,0.7366443,1.3640249,,,,,,,,,,,,,, -46300,0.8545639,1.2992389,,,,,,,,,,,,,, -46400,0.7781837,1.2738591,,,,,,,,,,,,,, -46500,0.85003537,1.3132006,,,,,,,,,,,,,, -46549,,,0.27212,0.0963176872267781,0.52371204,0.1627677959392529,5348.0,0.32607883,0.1123433469420916,2472.0,37482.20679521561,40980.0468621254,37482.20679521561,3494.410943031311,1.4525690078735352,0.0 -46600,0.97553986,1.3337609,,,,,,,,,,,,,, -46700,0.79523385,1.3257842,,,,,,,,,,,,,, -46800,0.8232192,1.3597435,,,,,,,,,,,,,, -46900,0.79508966,1.2909753,,,,,,,,,,,,,, -47000,0.8459107,1.303327,,,,,,,,,,,,,, -47100,1.3565166,1.3031619,,,,,,,,,,,,,, -47200,1.2543756,1.3235556,,,,,,,,,,,,,, -47300,0.8597055,1.3571558,,,,,,,,,,,,,, -47400,1.1894128,1.304255,,,,,,,,,,,,,, -47500,0.84811294,1.3218552,,,,,,,,,,,,,, -47600,0.7972808,1.3047779,,,,,,,,,,,,,, -47700,0.83295774,1.2170175,,,,,,,,,,,,,, -47800,1.0277458,1.284494,,,,,,,,,,,,,, -47900,0.84849983,1.3204899,,,,,,,,,,,,,, -48000,0.9081509,1.2933003,,,,,,,,,,,,,, -48100,1.3116995,1.3035666,,,,,,,,,,,,,, -48200,0.9480332,1.311769,,,,,,,,,,,,,, -48300,0.8769223,1.3453459,,,,,,,,,,,,,, -48311,,,0.19800447,0.0776606435998933,0.51308125,0.1602093128783417,5348.0,0.30938765,0.1074076330916255,2472.0,38922.34973311424,42555.49884772301,38922.34973311424,3629.584242582321,1.512967586517334,0.0 -48400,0.8301226,1.2915821,,,,,,,,,,,,,, -48500,1.0177343,1.2451226,,,,,,,,,,,,,, -48600,1.2026061,1.2875221,,,,,,,,,,,,,, -48700,0.7080909,1.3523659,,,,,,,,,,,,,, -48800,1.026619,1.3027354,,,,,,,,,,,,,, -48900,1.0519681,1.2850524,,,,,,,,,,,,,, -49000,0.93336535,1.3228027,,,,,,,,,,,,,, -49100,0.92050743,1.2649647,,,,,,,,,,,,,, -49200,1.4845589,1.3326286,,,,,,,,,,,,,, -49300,0.83068144,1.28135,,,,,,,,,,,,,, -49400,0.7140302,1.3247985,,,,,,,,,,,,,, -49500,0.9228073,1.3067831,,,,,,,,,,,,,, -49600,0.9156417,1.3008147,,,,,,,,,,,,,, -49700,1.0260539,1.2829771,,,,,,,,,,,,,, -49800,0.8590715,1.3259422,,,,,,,,,,,,,, -49900,1.0996052,1.278818,,,,,,,,,,,,,, -50000,1.0330216,1.3027706,,,,,,,,,,,,,, -50086,,,0.21282442,0.0822324569887133,0.50448716,0.157303262307269,5348.0,0.30191362,0.1057014604025755,2472.0,40362.35197472572,44130.64695501328,40362.35197472572,3764.5926241874695,1.574493646621704,0.0 -50100,0.94289637,1.2710952,,,,,,,,,,,,,, -50200,0.84131557,1.3037617,,,,,,,,,,,,,, -50300,0.8025203,1.3163633,,,,,,,,,,,,,, -50400,1.2044499,1.325301,,,,,,,,,,,,,, -50500,0.78390735,1.2264656,,,,,,,,,,,,,, -50600,0.82850945,1.2883652,,,,,,,,,,,,,, -50700,0.88374895,1.254686,,,,,,,,,,,,,, -50800,0.80297947,1.3011895,,,,,,,,,,,,,, -50900,1.1305332,1.285229,,,,,,,,,,,,,, -51000,0.7164274,1.2283826,,,,,,,,,,,,,, -51100,0.86144835,1.2776399,,,,,,,,,,,,,, -51200,0.8670879,1.318166,,,,,,,,,,,,,, -51300,1.0112116,1.2967622,,,,,,,,,,,,,, -51400,1.2233918,1.3323462,,,,,,,,,,,,,, -51500,1.167551,1.2595533,,,,,,,,,,,,,, -51600,1.1354479,1.2520521,,,,,,,,,,,,,, -51700,1.0170772,1.2865566,,,,,,,,,,,,,, -51800,0.8743835,1.2147653,,,,,,,,,,,,,, -51885,,,0.27286735,0.1037494229120408,0.5019356,0.1571680971644284,5348.0,0.29826474,0.1038124834968415,2472.0,41802.36686468125,45704.29691815376,41802.36686468125,3898.094487667084,1.630784034729004,0.0 -51900,1.5047225,1.2339686,,,,,,,,,,,,,, -52000,0.89575344,1.2969477,,,,,,,,,,,,,, -52100,0.95695627,1.2762923,,,,,,,,,,,,,, -52200,3.421035,1.2795917,,,,,,,,,,,,,, -52300,0.85406387,1.3085567,,,,,,,,,,,,,, -52400,0.8537201,1.256245,,,,,,,,,,,,,, -52500,0.83405507,1.3069042,,,,,,,,,,,,,, -52600,1.0485793,1.27246,,,,,,,,,,,,,, -52700,1.1799184,1.2973343,,,,,,,,,,,,,, -52800,1.255016,1.2763098,,,,,,,,,,,,,, -52900,1.6659304,1.277235,,,,,,,,,,,,,, -53000,0.7999328,1.2924513,,,,,,,,,,,,,, -53100,0.90744877,1.284803,,,,,,,,,,,,,, -53200,1.5299829,1.2511595,,,,,,,,,,,,,, -53300,0.8297081,1.260824,,,,,,,,,,,,,, -53400,0.7810176,1.2989686,,,,,,,,,,,,,, -53500,1.0687186,1.2837676,,,,,,,,,,,,,, -53600,0.9226444,1.2445807,,,,,,,,,,,,,, -53658,,,0.28197432,0.1062609862536667,0.4842883,0.1505739691244195,5348.0,0.28853884,0.1002173339020575,2472.0,43242.77938055992,47276.94588851929,43242.77938055992,4030.19356417656,1.6926136016845703,0.0 -53700,0.7705038,1.2325445,,,,,,,,,,,,,, -53800,0.85550267,1.2529261,,,,,,,,,,,,,, -53900,2.0082102,1.2189593,,,,,,,,,,,,,, -54000,0.90208876,1.2770705,,,,,,,,,,,,,, -54100,1.0696983,1.2872082,,,,,,,,,,,,,, -54200,0.88161075,1.2351179,,,,,,,,,,,,,, -54300,0.895895,1.210899,,,,,,,,,,,,,, -54400,0.80014545,1.2202919,,,,,,,,,,,,,, -54500,0.8220321,1.225491,,,,,,,,,,,,,, -54600,0.7842956,1.2368851,,,,,,,,,,,,,, -54700,0.8571842,1.1977837,,,,,,,,,,,,,, -54800,1.1841261,1.2208134,,,,,,,,,,,,,, -54900,1.08946,1.2424798,,,,,,,,,,,,,, -55000,0.8505579,1.2140728,,,,,,,,,,,,,, -55100,0.9736928,1.2500744,,,,,,,,,,,,,, -55200,0.9205712,1.2444733,,,,,,,,,,,,,, -55300,0.8398061,1.2500949,,,,,,,,,,,,,, -55400,0.8635072,1.2548792,,,,,,,,,,,,,, -55435,,,0.3120024,0.1183694542416929,0.47713417,0.1493864467980343,5348.0,0.28182715,0.0968456116832206,2472.0,44682.94412112236,48849.46684360504,44682.94412112236,4162.418050527573,1.748811960220337,0.0 -55500,1.0014083,1.241937,,,,,,,,,,,,,, -55600,0.8606537,1.2302442,,,,,,,,,,,,,, -55700,1.5032452,1.2133157,,,,,,,,,,,,,, -55800,0.91764134,1.2087048,,,,,,,,,,,,,, -55900,0.8091775,1.2316772,,,,,,,,,,,,,, -56000,1.0606312,1.2474867,,,,,,,,,,,,,, -56100,0.7679771,1.2554742,,,,,,,,,,,,,, -56200,0.84667027,1.2149445,,,,,,,,,,,,,, -56300,1.4689163,1.26865,,,,,,,,,,,,,, -56400,1.2579333,1.2425126,,,,,,,,,,,,,, -56500,1.1578971,1.2372264,,,,,,,,,,,,,, -56600,1.0138168,1.1904851,,,,,,,,,,,,,, -56700,0.95547926,1.3029519,,,,,,,,,,,,,, -56800,1.1298003,1.1855601,,,,,,,,,,,,,, -56900,0.7595231,1.2113042,,,,,,,,,,,,,, -57000,2.298792,1.2109778,,,,,,,,,,,,,, -57100,0.91842836,1.2019402,,,,,,,,,,,,,, -57200,0.99194247,1.1693943,,,,,,,,,,,,,, -57224,,,0.27040535,0.1010153773917126,0.46657196,0.1453218378597565,5348.0,0.2736266,0.0924176873235431,2472.0,46123.28657770157,50421.46864414215,46123.28657770157,4293.939308643341,1.810908794403076,0.0 -57300,0.68724537,1.2022574,,,,,,,,,,,,,, -57400,0.8468652,1.2526588,,,,,,,,,,,,,, -57500,1.4435172,1.2102457,,,,,,,,,,,,,, -57600,1.6325915,1.2067577,,,,,,,,,,,,,, -57700,1.88912,1.2399801,,,,,,,,,,,,,, -57800,1.0280318,1.2113342,,,,,,,,,,,,,, -57900,0.93524975,1.2293944,,,,,,,,,,,,,, -58000,0.9965137,1.1704031,,,,,,,,,,,,,, -58100,0.8132985,1.2327085,,,,,,,,,,,,,, -58200,0.82872117,1.1877369,,,,,,,,,,,,,, -58300,1.1485659,1.2170343,,,,,,,,,,,,,, -58400,0.8001562,1.2173371,,,,,,,,,,,,,, -58500,1.3074417,1.2609283,,,,,,,,,,,,,, -58600,1.5361702,1.1632901,,,,,,,,,,,,,, -58700,0.83108103,1.2120167,,,,,,,,,,,,,, -58800,0.9442075,1.1794165,,,,,,,,,,,,,, -58900,1.2290037,1.1688149,,,,,,,,,,,,,, -59000,0.9688812,1.1901093,,,,,,,,,,,,,, -59012,,,0.24490872,0.0938699363336498,0.45609862,0.1409772439827374,5348.0,0.26979008,0.0920317673105437,2472.0,47563.39091444016,51994.320743083954,47563.39091444016,4426.556207895279,1.867297887802124,0.0 -59100,0.88491994,1.1667488,,,,,,,,,,,,,, -59200,1.7112501,1.1993055,,,,,,,,,,,,,, -59300,1.4763722,1.1820304,,,,,,,,,,,,,, -59400,1.2346604,1.2122251,,,,,,,,,,,,,, -59500,0.9797051,1.1707977,,,,,,,,,,,,,, -59600,1.442152,1.2301029,,,,,,,,,,,,,, -59700,1.3742776,1.16908,,,,,,,,,,,,,, -59800,0.91348326,1.1683956,,,,,,,,,,,,,, -59900,1.0278995,1.1637144,,,,,,,,,,,,,, -60000,0.77842003,1.1484869,,,,,,,,,,,,,, -60100,0.8200382,1.1711396,,,,,,,,,,,,,, -60200,1.0051081,1.1675084,,,,,,,,,,,,,, -60300,0.865344,1.1868436,,,,,,,,,,,,,, -60400,1.1689441,1.1471913,,,,,,,,,,,,,, -60500,0.975508,1.2390244,,,,,,,,,,,,,, -60600,1.01788,1.2388812,,,,,,,,,,,,,, -60700,1.885535,1.1614422,,,,,,,,,,,,,, -60762,,,0.20578378,0.0797272673202013,0.45228767,0.1395773192890313,5348.0,0.26356447,0.0903865293603883,2472.0,49003.68826889992,53567.35671496391,49003.68826889992,4559.163670539856,1.924031972885132,0.0 -60800,0.8631258,1.1120389,,,,,,,,,,,,,, -60900,1.0905501,1.1395892,,,,,,,,,,,,,, -61000,1.2657704,1.2088182,,,,,,,,,,,,,, -61100,1.3566227,1.1320125,,,,,,,,,,,,,, -61200,1.1167736,1.1872009,,,,,,,,,,,,,, -61300,0.8687684,1.1763313,,,,,,,,,,,,,, -61400,0.8464244,1.1592969,,,,,,,,,,,,,, -61500,1.1134404,1.1899972,,,,,,,,,,,,,, -61600,1.1717569,1.1559443,,,,,,,,,,,,,, -61700,1.2491755,1.2025497,,,,,,,,,,,,,, -61800,0.9953974,1.1675216,,,,,,,,,,,,,, -61900,1.2912891,1.1432983,,,,,,,,,,,,,, -62000,0.8811149,1.1886703,,,,,,,,,,,,,, -62100,1.5258751,1.1661035,,,,,,,,,,,,,, -62200,1.1038299,1.1655604,,,,,,,,,,,,,, -62300,1.2389891,1.1649792,,,,,,,,,,,,,, -62400,0.9641043,1.1954534,,,,,,,,,,,,,, -62500,1.8563623,1.1793894,,,,,,,,,,,,,, -62525,,,0.23597239,0.0912774590432145,0.44328588,0.1363526651669772,5348.0,0.2582945,0.0898381167103365,2472.0,50443.94960784912,55141.62321925163,50443.94960784912,4693.026634216309,1.9902966022491453,0.0 -62600,1.57685,1.1399431,,,,,,,,,,,,,, -62700,0.9896387,1.2007887,,,,,,,,,,,,,, -62800,0.91548705,1.2188058,,,,,,,,,,,,,, -62900,1.0681537,1.1824892,,,,,,,,,,,,,, -63000,1.372391,1.117354,,,,,,,,,,,,,, -63100,1.284231,1.1887355,,,,,,,,,,,,,, -63200,1.4500169,1.1784928,,,,,,,,,,,,,, -63300,1.1657963,1.1565092,,,,,,,,,,,,,, -63400,1.0020677,1.2054534,,,,,,,,,,,,,, -63500,0.98427063,1.1805775,,,,,,,,,,,,,, -63600,1.8291788,1.1430047,,,,,,,,,,,,,, -63700,1.2234843,1.1333922,,,,,,,,,,,,,, -63800,0.9780483,1.1251608,,,,,,,,,,,,,, -63900,1.1809496,1.1459057,,,,,,,,,,,,,, -64000,1.0366415,1.0934849,,,,,,,,,,,,,, -64100,0.9674358,1.1585344,,,,,,,,,,,,,, -64200,0.9090862,1.1471237,,,,,,,,,,,,,, -64300,1.4006034,1.1315191,,,,,,,,,,,,,, -64333,,,0.20355521,0.077820553951449,0.44492757,0.1354644370854533,5348.0,0.25623506,0.0880710092823919,2472.0,51884.26449036598,56716.0893816948,51884.26449036598,4827.03906917572,2.052520036697388,0.0 -64400,0.98946637,1.130262,,,,,,,,,,,,,, -64500,1.7205497,1.1608248,,,,,,,,,,,,,, -64600,0.97011316,1.2045261,,,,,,,,,,,,,, -64700,1.3043971,1.1477185,,,,,,,,,,,,,, -64800,1.3620033,1.1673851,,,,,,,,,,,,,, -64900,0.91338104,1.1740564,,,,,,,,,,,,,, -65000,2.005644,1.1056191,,,,,,,,,,,,,, -65100,0.86391276,1.1328154,,,,,,,,,,,,,, -65200,1.1501312,1.1732321,,,,,,,,,,,,,, -65300,1.1596458,1.1614567,,,,,,,,,,,,,, -65400,1.023668,1.10579,,,,,,,,,,,,,, -65500,0.99620205,1.1377383,,,,,,,,,,,,,, -65600,1.3003029,1.1693869,,,,,,,,,,,,,, -65700,0.95839864,1.1718111,,,,,,,,,,,,,, -65800,1.2989722,1.1738943,,,,,,,,,,,,,, -65900,1.5175022,1.1514094,,,,,,,,,,,,,, -66000,1.0598711,1.1110951,,,,,,,,,,,,,, -66100,2.9464571,1.1581453,,,,,,,,,,,,,, -66115,,,0.20436361,0.0800104603080342,0.43434986,0.1336397076571053,5348.0,0.2523467,0.0857758007840269,2472.0,53324.19461846352,58289.38478684425,53324.19461846352,4960.271555185318,2.1096343994140625,0.0 -66200,0.987384,1.1901728,,,,,,,,,,,,,, -66300,0.951163,1.1372653,,,,,,,,,,,,,, -66400,0.9799125,1.1335998,,,,,,,,,,,,,, -66500,1.1536444,1.1486697,,,,,,,,,,,,,, -66600,1.1722728,1.1651622,,,,,,,,,,,,,, -66700,0.9105351,1.1596243,,,,,,,,,,,,,, -66800,1.2456717,1.1522193,,,,,,,,,,,,,, -66900,0.9347247,1.098455,,,,,,,,,,,,,, -67000,1.8162841,1.1131041,,,,,,,,,,,,,, -67100,1.1325039,1.1558332,,,,,,,,,,,,,, -67200,1.348379,1.1881651,,,,,,,,,,,,,, -67300,2.4729366,1.142817,,,,,,,,,,,,,, -67400,1.3502545,1.0725458,,,,,,,,,,,,,, -67500,1.5302466,1.1658118,,,,,,,,,,,,,, -67600,1.405123,1.128271,,,,,,,,,,,,,, -67700,1.3545206,1.1388927,,,,,,,,,,,,,, -67800,0.87563175,1.1043025,,,,,,,,,,,,,, -67890,,,0.19853008,0.0757643860245349,0.426306,0.1307722756982728,5348.0,0.24856624,0.0845774175857656,2472.0,54764.53757071495,59863.47870230675,54764.53757071495,5093.883127212524,2.1731672286987305,0.0 -67900,1.403048,1.1644363,,,,,,,,,,,,,, -68000,0.8489724,1.1222968,,,,,,,,,,,,,, -68100,1.3050473,1.0979079,,,,,,,,,,,,,, -68200,1.4963905,1.1364743,,,,,,,,,,,,,, -68300,1.0105275,1.0982046,,,,,,,,,,,,,, -68400,0.9317613,1.0955508,,,,,,,,,,,,,, -68500,0.8950527,1.1145629,,,,,,,,,,,,,, -68600,0.975305,1.1144973,,,,,,,,,,,,,, -68700,1.1324104,1.1276411,,,,,,,,,,,,,, -68800,1.3736874,1.1436839,,,,,,,,,,,,,, -68900,2.0875144,1.131707,,,,,,,,,,,,,, -69000,0.9050479,1.1676894,,,,,,,,,,,,,, -69100,1.8574377,1.0836742,,,,,,,,,,,,,, -69200,1.3694711,1.0390452,,,,,,,,,,,,,, -69300,1.1246725,1.0942507,,,,,,,,,,,,,, -69400,1.7510068,1.0998949,,,,,,,,,,,,,, -69500,4.5568852,1.0985507,,,,,,,,,,,,,, -69600,1.3284351,1.1342922,,,,,,,,,,,,,, -69674,,,0.20562021,0.0787849206121335,0.42373237,0.1293337323923264,5348.0,0.24726346,0.0837852659801352,2472.0,56204.93294286728,61435.01011872292,56204.93294286728,5224.875775098801,2.240447998046875,0.0 -69700,0.90989417,1.1152946,,,,,,,,,,,,,, -69800,2.8171272,1.159232,,,,,,,,,,,,,, -69900,1.3025298,1.1340032,,,,,,,,,,,,,, -70000,2.2857454,1.1395077,,,,,,,,,,,,,, -70100,2.5508015,1.0804902,,,,,,,,,,,,,, -70200,1.1663811,1.117942,,,,,,,,,,,,,, -70300,0.92691886,1.0768396,,,,,,,,,,,,,, -70400,1.0311704,1.1586846,,,,,,,,,,,,,, -70500,1.0528548,1.1157603,,,,,,,,,,,,,, -70600,1.0044049,1.1156023,,,,,,,,,,,,,, -70700,0.97877353,1.0671688,,,,,,,,,,,,,, -70800,1.6710566,1.0669926,,,,,,,,,,,,,, -70900,1.2173206,1.0960734,,,,,,,,,,,,,, -71000,1.0680557,1.1034197,,,,,,,,,,,,,, -71100,1.1077124,1.1081265,,,,,,,,,,,,,, -71200,1.698656,1.0443195,,,,,,,,,,,,,, -71300,1.2216724,1.1045254,,,,,,,,,,,,,, -71400,1.4019892,1.1019301,,,,,,,,,,,,,, -71491,,,0.2064576,0.0766548233632427,0.41951063,0.1278082972088398,5348.0,0.24452537,0.0833790343875043,2472.0,57644.84786987305,63007.79894852638,57644.84786987305,5357.613756656647,2.299800157546997,0.0 -71500,1.1849889,1.1125001,,,,,,,,,,,,,, -71600,1.3431203,1.1579822,,,,,,,,,,,,,, -71700,1.3492279,1.0969154,,,,,,,,,,,,,, -71800,1.0619526,1.1459092,,,,,,,,,,,,,, -71900,2.431025,1.1145885,,,,,,,,,,,,,, -72000,1.0786042,1.0582148,,,,,,,,,,,,,, -72100,3.2478697,1.1198038,,,,,,,,,,,,,, -72200,1.0907406,1.1156143,,,,,,,,,,,,,, -72300,1.1163063,1.0832727,,,,,,,,,,,,,, -72400,1.3596089,1.1395228,,,,,,,,,,,,,, -72500,1.1420995,1.1051086,,,,,,,,,,,,,, -72600,1.0763205,1.0862049,,,,,,,,,,,,,, -72700,1.3787739,1.1499741,,,,,,,,,,,,,, -72800,1.0461358,1.0940821,,,,,,,,,,,,,, -72900,1.4157646,1.07932,,,,,,,,,,,,,, -73000,1.9405068,1.1072867,,,,,,,,,,,,,, -73100,1.9804847,1.0550326,,,,,,,,,,,,,, -73200,3.1606994,1.1186842,,,,,,,,,,,,,, -73269,,,0.18731399,0.0725551522855374,0.41804275,0.1278662251272,5348.0,0.24227335,0.0833384112282412,2472.0,59085.31165289879,64579.38373923302,59085.31165289879,5488.593738079071,2.365021228790283,0.0 -73300,0.94842964,1.122646,,,,,,,,,,,,,, -73400,1.4284133,1.0967916,,,,,,,,,,,,,, -73500,0.83087134,1.0586754,,,,,,,,,,,,,, -73600,0.93225235,1.0833353,,,,,,,,,,,,,, -73700,1.062169,1.0926279,,,,,,,,,,,,,, -73800,0.91154283,1.1114063,,,,,,,,,,,,,, -73900,0.8852986,1.0565649,,,,,,,,,,,,,, -74000,0.9044471,1.0960307,,,,,,,,,,,,,, -74100,1.1378353,1.08147,,,,,,,,,,,,,, -74200,1.2435393,1.1140286,,,,,,,,,,,,,, -74300,1.1671859,1.0859872,,,,,,,,,,,,,, -74400,1.2467419,1.0439245,,,,,,,,,,,,,, -74500,1.3928221,1.0591267,,,,,,,,,,,,,, -74600,0.9315247,1.1051162,,,,,,,,,,,,,, -74700,0.9394461,1.1089478,,,,,,,,,,,,,, -74800,1.860185,1.1203421,,,,,,,,,,,,,, -74900,1.0642337,1.096853,,,,,,,,,,,,,, -75000,1.2401967,1.0773882,,,,,,,,,,,,,, -75049,,,0.19145286,0.0743006590165445,0.41606402,0.1271517808007569,5348.0,0.2408368,0.081997846972559,2472.0,60525.43598723412,66151.34201812744,60525.43598723412,5620.283970355988,2.4329538345336914,0.0 -75100,1.2117182,1.0950876,,,,,,,,,,,,,, -75200,4.896626,1.1133379,,,,,,,,,,,,,, -75300,0.9422386,1.1006441,,,,,,,,,,,,,, -75400,1.2184843,1.0692313,,,,,,,,,,,,,, -75500,1.2246271,1.1367856,,,,,,,,,,,,,, -75600,1.0998983,1.1138046,,,,,,,,,,,,,, -75700,1.2218217,1.0795918,,,,,,,,,,,,,, -75746,,,,,,,,,,,61068.55351090431,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 3694dbae7..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -135.31539034843445,0.0,34.33066487312317,1,0,34.33066487312317,31.23744,2472,1.10235004976337,169.64612746238708,32.871246,1.3708653155889337,31.08705,5348,1.0585651254622166 -244.99092769622803,0.028472900390625,1474.825115442276,1750,0,1474.825115442276,6.0945807,2472,0.899579550301627,1719.9163784980774,6.003965,0.9391896477614642,6.1192107,5348,0.8966179750330672 -365.9432199001312,0.0846376419067382,2914.8816883563995,3549,0,2914.8816883563995,2.8966093,2472,0.6532000893709504,3281.0594849586487,3.3109722,0.7425314465408805,3.2170153,5348,0.7051565501993686 -497.9646122455597,0.136824369430542,4355.136976242065,5327,0,4355.136976242065,0.85711485,2472,0.2735563544776877,4853.463213682175,1.1042546,0.342650224439735,1.1603082,5348,0.3323421222858356 -629.766964673996,0.1887216567993164,5796.596266269684,7101,0,5796.596266269684,0.58395475,2472,0.1956614465907013,6426.853343486786,0.71781844,0.2381636636211919,0.856381,5348,0.2575957982949882 -762.0983710289001,0.2411792278289795,7236.706146240234,8886,0,7236.706146240234,0.5030347,2472,0.1674892856417443,7999.423669338226,0.6037436,0.2008276839710539,0.7638324,5348,0.2303117487473087 -895.9172258377075,0.2942352294921875,8676.956310272217,10688,0,8676.956310272217,0.45312193,2472,0.1555663883980257,9573.624148845673,0.59205496,0.2036127467765384,0.702261,5348,0.2120741091168889 -1027.5034348964691,0.3512380123138428,10117.289225816729,12458,0,10117.289225816729,0.4115406,2472,0.1393374362724189,11145.675840854645,0.49941278,0.171514791173534,0.64830595,5348,0.1956998175270571 -1160.9660465717316,0.408066987991333,11557.23063492775,14248,0,11557.23063492775,0.38463014,2472,0.1315987244327991,12719.211951255798,0.49112177,0.166320429114731,0.6159968,5348,0.1868851192832385 -1288.4258918762207,0.4595286846160888,12997.853597402573,16095,0,12997.853597402573,0.36829486,2472,0.1253021347470193,14287.417057275772,0.41128248,0.1451225741954266,0.5923722,5348,0.1799820423453083 -1413.5200998783112,0.5142014026641846,14437.80846619606,17959,0,14437.80846619606,0.35398287,2472,0.1204679787947108,15852.590964317322,0.41662517,0.1468018951732306,0.57269543,5348,0.1724514129584753 -1538.7788639068604,0.561873197555542,15878.370507478714,19819,0,15878.370507478714,0.33988407,2472,0.115024475453456,17418.529887914658,0.3939056,0.138076927376175,0.55853176,5348,0.1676144317753941 -1677.0056114196775,0.6107773780822754,17318.98425388336,21678,0,17318.98425388336,0.32638323,2472,0.1087075741880446,18997.48906755448,0.26414925,0.0981504777709376,0.5420662,5348,0.1627291773270127 -1807.410040140152,0.6587982177734375,18759.508040905,23548,0,18759.508040905,0.3171775,2472,0.1089310015639916,20568.534235477448,0.24399593,0.0890187326286641,0.5178802,5348,0.1580466705928922 -1936.7809422016144,0.7138986587524414,20200.06555223465,25401,0,20200.06555223465,0.31287903,2472,0.1052952288099445,22138.588984251022,0.2405506,0.0882221486073646,0.51276976,5348,0.1538275872056537 -2066.3479647636414,0.7682766914367676,21640.561172246933,27264,0,21640.561172246933,0.29457772,2472,0.1012329128836349,23708.775426387787,0.22414844,0.083011886543259,0.50314236,5348,0.1510277378182415 -2194.294489145279,0.8199701309204102,23080.513560533524,29119,0,23080.513560533524,0.29227978,2472,0.0996892328316373,25276.79612565041,0.22378603,0.0847956333553655,0.4944983,5348,0.1492802456143738 -2322.846914052964,0.8701145648956299,24520.68779182434,30958,0,24520.68779182434,0.286033,2472,0.0981861759389027,26845.64391064644,0.20724949,0.0780107444682858,0.48233274,5348,0.1448487598598144 -2451.4838552474976,0.925905466079712,25960.884137392044,32805,0,25960.884137392044,0.27722582,2472,0.095748786383117,28414.607186079025,0.24316454,0.084999483434562,0.4656841,5348,0.1414696312887996 -2581.6163563728333,0.9803595542907716,27401.33624315262,34652,0,27401.33624315262,0.264707,2472,0.0922145715272276,29985.31670928001,0.20290689,0.0774240193362147,0.46335202,5348,0.1393938808808905 -2712.818051815033,1.0341756343841553,28841.840735912323,36509,0,28841.840735912323,0.2665545,2472,0.0891068998436008,31557.14866900444,0.18551725,0.0709965430233595,0.45179525,5348,0.135850623207855 -2842.51593208313,1.0907597541809082,30282.38442492485,38355,0,30282.38442492485,0.26117116,2472,0.0884366177157597,33127.51769065857,0.17950557,0.0672176894251468,0.4428782,5348,0.1336300530040453 -2972.221476078033,1.150517463684082,31722.82284355164,40185,0,31722.82284355164,0.25024888,2472,0.0846180407450287,34697.79404234886,0.16687053,0.0648847975882859,0.43387946,5348,0.1303378163105708 -3102.9151356220245,1.2059061527252195,33162.69558787346,42037,0,33162.69558787346,0.24411249,2472,0.0836430849227144,36268.49072861672,0.17537016,0.0659101794201083,0.42432463,5348,0.1273062552497176 -3233.0268499851227,1.269730806350708,34603.0143558979,43882,0,34603.0143558979,0.23707676,2472,0.0805963479779822,37839.05934667587,0.17138386,0.0645453444471942,0.4157252,5348,0.124429168637825 -3363.4156663417816,1.3258774280548096,36043.24026441574,45739,0,36043.24026441574,0.2356796,2472,0.080291674283509,39409.80263566971,0.15769911,0.0606665491460746,0.4071435,5348,0.122372727536036 -3495.040193796158,1.3885679244995115,37483.37846851349,47582,0,37483.37846851349,0.22712962,2472,0.0770824447017244,40981.70229816437,0.15129596,0.0577826683285268,0.40430495,5348,0.1194377130057831 -3626.74472117424,1.4479291439056396,38923.28881287575,49408,0,38923.28881287575,0.21716937,2472,0.075437206751569,42553.44931507111,0.13115764,0.0500343243385964,0.39019933,5348,0.1160199658225281 -3759.30832862854,1.5018019676208496,40363.78140926361,51227,0,40363.78140926361,0.21561463,2472,0.0724920277049946,44126.6337416172,0.13253723,0.0503795192415744,0.3822779,5348,0.1125539453739729 -3889.758169412613,1.571280002593994,41803.87991976738,53073,0,41803.87991976738,0.20504084,2472,0.0704608697418398,45697.32681298256,0.13393289,0.0510575079009211,0.3740585,5348,0.1119070836189501 -4020.283055782318,1.6276521682739258,43243.759679079056,54913,0,43243.759679079056,0.20187275,2472,0.0668657201470558,47267.8637509346,0.113810115,0.0432091109589262,0.36113426,5348,0.107765237456192 -4151.09471654892,1.6856026649475098,44683.91825604439,56736,0,44683.91825604439,0.19632728,2472,0.0671094591026344,48838.96536445618,0.11524229,0.0462044148115838,0.35849932,5348,0.1044343821504774 -4281.586913585663,1.746351718902588,46123.78596878052,58571,0,46123.78596878052,0.1881308,2472,0.0637174253041659,50409.45980882645,0.13256885,0.0473521011397905,0.35000604,5348,0.1022234665997277 -4413.043275594711,1.8080739974975584,47564.2847366333,60404,0,47564.2847366333,0.18706349,2472,0.0626815347429569,51981.551478385925,0.08553199,0.0332429879676399,0.34243563,5348,0.0999546231306178 -4544.038645029068,1.8715159893035889,49004.790016412735,62252,0,49004.790016412735,0.1801457,2472,0.0602644567668027,53553.19096279144,0.08559206,0.0322760184761349,0.33397695,5348,0.0979174913349488 -4673.833595752716,1.9265692234039309,50445.33984518051,64082,0,50445.33984518051,0.17913373,2472,0.0594723051611723,55123.66427946091,0.103173256,0.0402818574870457,0.32366726,5348,0.09361151607017 -4802.509584188461,1.989338636398316,51886.01150846481,65901,0,51886.01150846481,0.16896814,2472,0.0557349745089675,56693.15033197403,0.09804582,0.0370182862740134,0.31444234,5348,0.0911495795398592 -4932.19512963295,2.0536038875579834,53326.61968564987,67717,0,53326.61968564987,0.1688004,2472,0.0571161619239128,58263.58122968674,0.11457752,0.0450217154005156,0.3129322,5348,0.0905316817440165 -5060.983474969864,2.1116764545440674,54766.81351041794,69559,0,54766.81351041794,0.16448735,2472,0.0552474965978104,59832.69885110855,0.09059269,0.0335395084535696,0.30794328,5348,0.0875966672137636 -5190.269483566284,2.172675371170044,56206.77034330368,71385,0,56206.77034330368,0.16138907,2472,0.0539272439217597,61402.078199625015,0.08755792,0.0338997509127512,0.30202192,5348,0.0857912470915357 -5322.592356920242,2.23431396484375,57646.677629470825,73176,0,57646.677629470825,0.15961748,2472,0.0531554038957609,62974.4433233738,0.06485754,0.0246496489765177,0.29641142,5348,0.084719580601871 -5453.71427154541,2.295135021209717,59086.73326683045,74993,0,59086.73326683045,0.15675355,2472,0.0516320354233948,64545.75790762901,0.07795109,0.0302521100395187,0.29337308,5348,0.0834065477857053 -5585.158769369125,2.3584864139556885,60527.72573661804,76822,0,60527.72573661804,0.15502998,2472,0.05106331119371153,66118.33410978317,0.06765499,0.02511690046760187,0.29259604,5348,0.08351274896936578 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/measurements.csv deleted file mode 100644 index 3c897f02c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/measurements.csv +++ /dev/null @@ -1,821 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,38.067623,32.596924,,,,,,,,,,,,,, -1,,,32.871246,1.3708653155889337,31.08705,1.0585651254622166,5348.0,31.23744,1.10235004976337,2472.0,34.33066487312317,169.64612746238708,34.33066487312317,135.31539034843445,0.0,0.0 -100,13.501869,12.081586,,,,,,,,,,,,,, -200,10.620639,8.440752,,,,,,,,,,,,,, -300,1.7551786,5.9174623,,,,,,,,,,,,,, -400,0.27833554,5.839061,,,,,,,,,,,,,, -500,0.3771681,5.8433676,,,,,,,,,,,,,, -600,0.4493005,5.8160152,,,,,,,,,,,,,, -700,0.44495162,5.8150167,,,,,,,,,,,,,, -800,0.23562686,5.79615,,,,,,,,,,,,,, -900,0.4641355,5.7988553,,,,,,,,,,,,,, -1000,0.39493918,5.7759404,,,,,,,,,,,,,, -1100,0.23527026,5.7875795,,,,,,,,,,,,,, -1200,0.63074243,5.791107,,,,,,,,,,,,,, -1300,0.89518255,5.7951593,,,,,,,,,,,,,, -1400,0.3692652,5.7807884,,,,,,,,,,,,,, -1500,0.26808625,5.7658396,,,,,,,,,,,,,, -1600,0.808636,5.800333,,,,,,,,,,,,,, -1700,0.73911136,5.773538,,,,,,,,,,,,,, -1750,,,6.003965,0.9391896477614642,6.1192107,0.8966179750330672,5348.0,6.0945807,0.899579550301627,2472.0,1474.825115442276,1719.9163784980774,1474.825115442276,244.99092769622803,0.028472900390625,0.0 -1800,1.3290808,5.762103,,,,,,,,,,,,,, -1900,0.4113412,5.601807,,,,,,,,,,,,,, -2000,1.453527,5.532372,,,,,,,,,,,,,, -2100,1.2075151,5.4260216,,,,,,,,,,,,,, -2200,0.73887575,5.165704,,,,,,,,,,,,,, -2300,1.2570124,4.5386014,,,,,,,,,,,,,, -2400,0.89941365,4.017739,,,,,,,,,,,,,, -2500,1.9838961,3.749665,,,,,,,,,,,,,, -2600,1.0504307,3.5693886,,,,,,,,,,,,,, -2700,1.3168054,3.3104746,,,,,,,,,,,,,, -2800,1.4955388,3.284036,,,,,,,,,,,,,, -2900,0.98818,3.114577,,,,,,,,,,,,,, -3000,1.0053937,3.0817716,,,,,,,,,,,,,, -3100,1.1306428,2.9097674,,,,,,,,,,,,,, -3200,1.1879926,2.8740945,,,,,,,,,,,,,, -3300,1.4372923,2.9086888,,,,,,,,,,,,,, -3400,1.089657,2.7650452,,,,,,,,,,,,,, -3500,0.9943317,2.7442887,,,,,,,,,,,,,, -3549,,,3.3109722,0.7425314465408805,3.2170153,0.7051565501993686,5348.0,2.8966093,0.6532000893709504,2472.0,2914.8816883563995,3281.0594849586487,2914.8816883563995,365.9432199001312,0.0846376419067382,0.0 -3600,1.0778339,2.6886842,,,,,,,,,,,,,, -3700,1.4357258,2.7275007,,,,,,,,,,,,,, -3800,1.002112,2.6376054,,,,,,,,,,,,,, -3900,0.94468224,2.577868,,,,,,,,,,,,,, -4000,1.2682769,2.4712389,,,,,,,,,,,,,, -4100,1.9745675,2.5315347,,,,,,,,,,,,,, -4200,0.90517366,2.3955066,,,,,,,,,,,,,, -4300,1.0497807,2.31763,,,,,,,,,,,,,, -4400,0.74044263,2.310153,,,,,,,,,,,,,, -4500,0.82635415,2.2375548,,,,,,,,,,,,,, -4600,0.79872715,2.2377102,,,,,,,,,,,,,, -4700,0.761146,2.1870615,,,,,,,,,,,,,, -4800,0.76491493,2.1740913,,,,,,,,,,,,,, -4900,0.7878329,2.1852422,,,,,,,,,,,,,, -5000,1.2414566,2.1858275,,,,,,,,,,,,,, -5100,1.0197295,2.129431,,,,,,,,,,,,,, -5200,1.0027683,2.0483627,,,,,,,,,,,,,, -5300,0.74981123,1.9800481,,,,,,,,,,,,,, -5327,,,1.1042546,0.342650224439735,1.1603082,0.3323421222858356,5348.0,0.85711485,0.2735563544776877,2472.0,4355.136976242065,4853.463213682175,4355.136976242065,497.9646122455597,0.136824369430542,0.0 -5400,0.7469812,1.9782666,,,,,,,,,,,,,, -5500,0.8233426,1.9319053,,,,,,,,,,,,,, -5600,0.7867582,1.9201868,,,,,,,,,,,,,, -5700,0.8277224,1.998779,,,,,,,,,,,,,, -5800,0.79946905,1.9456948,,,,,,,,,,,,,, -5900,0.711874,1.9679855,,,,,,,,,,,,,, -6000,0.70437807,1.9520067,,,,,,,,,,,,,, -6100,0.79645133,1.9400115,,,,,,,,,,,,,, -6200,0.70470107,1.8728151,,,,,,,,,,,,,, -6300,0.74604046,1.8291994,,,,,,,,,,,,,, -6400,0.8053494,1.8402531,,,,,,,,,,,,,, -6500,0.7998385,1.8633916,,,,,,,,,,,,,, -6600,0.7639485,1.811744,,,,,,,,,,,,,, -6700,0.72123694,1.8853008,,,,,,,,,,,,,, -6800,0.6319283,1.7795986,,,,,,,,,,,,,, -6900,0.75212455,1.7970928,,,,,,,,,,,,,, -7000,0.6972738,1.8311677,,,,,,,,,,,,,, -7100,0.722061,1.8341091,,,,,,,,,,,,,, -7101,,,0.71781844,0.2381636636211919,0.856381,0.2575957982949882,5348.0,0.58395475,0.1956614465907013,2472.0,5796.596266269684,6426.853343486786,5796.596266269684,629.766964673996,0.1887216567993164,0.0 -7200,0.80422664,1.7944891,,,,,,,,,,,,,, -7300,0.84896344,1.7319748,,,,,,,,,,,,,, -7400,0.7175194,1.7838688,,,,,,,,,,,,,, -7500,0.6907898,1.7964553,,,,,,,,,,,,,, -7600,0.7355417,1.746382,,,,,,,,,,,,,, -7700,0.6524436,1.7277544,,,,,,,,,,,,,, -7800,0.7964155,1.7453856,,,,,,,,,,,,,, -7900,0.7555558,1.7302645,,,,,,,,,,,,,, -8000,0.788792,1.7952958,,,,,,,,,,,,,, -8100,0.762475,1.7191511,,,,,,,,,,,,,, -8200,0.6938627,1.7229921,,,,,,,,,,,,,, -8300,0.71169484,1.735665,,,,,,,,,,,,,, -8400,0.8083808,1.6977462,,,,,,,,,,,,,, -8500,0.6059254,1.6710553,,,,,,,,,,,,,, -8600,0.7094631,1.6895974,,,,,,,,,,,,,, -8700,0.627023,1.69691,,,,,,,,,,,,,, -8800,0.60847443,1.6847326,,,,,,,,,,,,,, -8886,,,0.6037436,0.2008276839710539,0.7638324,0.2303117487473087,5348.0,0.5030347,0.1674892856417443,2472.0,7236.706146240234,7999.423669338226,7236.706146240234,762.0983710289001,0.2411792278289795,0.0 -8900,0.61606544,1.6940259,,,,,,,,,,,,,, -9000,0.70615864,1.6805046,,,,,,,,,,,,,, -9100,0.7998855,1.6640577,,,,,,,,,,,,,, -9200,0.7339855,1.7566482,,,,,,,,,,,,,, -9300,0.68935657,1.6630195,,,,,,,,,,,,,, -9400,0.6685877,1.6848836,,,,,,,,,,,,,, -9500,0.7868973,1.658969,,,,,,,,,,,,,, -9600,0.629621,1.5985084,,,,,,,,,,,,,, -9700,0.73518604,1.6602304,,,,,,,,,,,,,, -9800,0.6656329,1.6597472,,,,,,,,,,,,,, -9900,0.71903074,1.6462239,,,,,,,,,,,,,, -10000,0.8285795,1.6729301,,,,,,,,,,,,,, -10100,0.6315865,1.6412369,,,,,,,,,,,,,, -10200,0.6990032,1.6145347,,,,,,,,,,,,,, -10300,0.6560351,1.5937527,,,,,,,,,,,,,, -10400,0.73578626,1.6628748,,,,,,,,,,,,,, -10500,0.674898,1.6405367,,,,,,,,,,,,,, -10600,0.73461646,1.6242203,,,,,,,,,,,,,, -10688,,,0.59205496,0.2036127467765384,0.702261,0.2120741091168889,5348.0,0.45312193,0.1555663883980257,2472.0,8676.956310272217,9573.624148845673,8676.956310272217,895.9172258377075,0.2942352294921875,0.0 -10700,0.5601898,1.614907,,,,,,,,,,,,,, -10800,0.603941,1.5457288,,,,,,,,,,,,,, -10900,0.8899335,1.6160513,,,,,,,,,,,,,, -11000,0.7376869,1.5589404,,,,,,,,,,,,,, -11100,0.8917353,1.628555,,,,,,,,,,,,,, -11200,0.6135945,1.5613979,,,,,,,,,,,,,, -11300,0.6216001,1.5962645,,,,,,,,,,,,,, -11400,0.65380025,1.5419189,,,,,,,,,,,,,, -11500,0.6276085,1.5783873,,,,,,,,,,,,,, -11600,0.5932217,1.5411848,,,,,,,,,,,,,, -11700,0.78808117,1.5490359,,,,,,,,,,,,,, -11800,0.7914728,1.6353649,,,,,,,,,,,,,, -11900,0.6169199,1.5470792,,,,,,,,,,,,,, -12000,0.55789447,1.5712328,,,,,,,,,,,,,, -12100,0.5617573,1.506038,,,,,,,,,,,,,, -12200,0.6774934,1.5879865,,,,,,,,,,,,,, -12300,0.56010014,1.4628534,,,,,,,,,,,,,, -12400,0.744919,1.5831109,,,,,,,,,,,,,, -12458,,,0.49941278,0.171514791173534,0.64830595,0.1956998175270571,5348.0,0.4115406,0.1393374362724189,2472.0,10117.289225816729,11145.675840854645,10117.289225816729,1027.5034348964691,0.3512380123138428,0.0 -12500,0.5998048,1.5038517,,,,,,,,,,,,,, -12600,0.66517854,1.551234,,,,,,,,,,,,,, -12700,0.6292232,1.5463313,,,,,,,,,,,,,, -12800,0.53584516,1.5177621,,,,,,,,,,,,,, -12900,0.74010247,1.5293041,,,,,,,,,,,,,, -13000,0.675258,1.4810723,,,,,,,,,,,,,, -13100,0.6829418,1.6093469,,,,,,,,,,,,,, -13200,0.5881966,1.5542865,,,,,,,,,,,,,, -13300,0.6495852,1.5639832,,,,,,,,,,,,,, -13400,0.61039335,1.4947652,,,,,,,,,,,,,, -13500,0.62846774,1.4898155,,,,,,,,,,,,,, -13600,0.77896976,1.5424212,,,,,,,,,,,,,, -13700,0.6035838,1.4954181,,,,,,,,,,,,,, -13800,0.59719306,1.5076691,,,,,,,,,,,,,, -13900,0.61777335,1.506693,,,,,,,,,,,,,, -14000,0.5902298,1.5356951,,,,,,,,,,,,,, -14100,0.58334345,1.5173187,,,,,,,,,,,,,, -14200,0.75793433,1.5090183,,,,,,,,,,,,,, -14248,,,0.49112177,0.166320429114731,0.6159968,0.1868851192832385,5348.0,0.38463014,0.1315987244327991,2472.0,11557.23063492775,12719.211951255798,11557.23063492775,1160.9660465717316,0.408066987991333,0.0 -14300,0.6300611,1.5219568,,,,,,,,,,,,,, -14400,0.6057978,1.5644038,,,,,,,,,,,,,, -14500,0.6505775,1.4832855,,,,,,,,,,,,,, -14600,0.61432946,1.5018972,,,,,,,,,,,,,, -14700,0.61926943,1.5311692,,,,,,,,,,,,,, -14800,0.6525829,1.4623793,,,,,,,,,,,,,, -14900,0.7295564,1.4952004,,,,,,,,,,,,,, -15000,0.65181476,1.5768275,,,,,,,,,,,,,, -15100,0.70369935,1.4980773,,,,,,,,,,,,,, -15200,0.6435203,1.545094,,,,,,,,,,,,,, -15300,0.9817212,1.5440928,,,,,,,,,,,,,, -15400,0.5999351,1.4427451,,,,,,,,,,,,,, -15500,0.7493273,1.4212914,,,,,,,,,,,,,, -15600,0.7311821,1.4474114,,,,,,,,,,,,,, -15700,0.6831851,1.4816774,,,,,,,,,,,,,, -15800,0.6974112,1.464098,,,,,,,,,,,,,, -15900,0.70603335,1.5197229,,,,,,,,,,,,,, -16000,0.72410643,1.437979,,,,,,,,,,,,,, -16095,,,0.41128248,0.1451225741954266,0.5923722,0.1799820423453083,5348.0,0.36829486,0.1253021347470193,2472.0,12997.853597402573,14287.417057275772,12997.853597402573,1288.4258918762207,0.4595286846160888,0.0 -16100,0.5917782,1.4802774,,,,,,,,,,,,,, -16200,0.74251866,1.4720122,,,,,,,,,,,,,, -16300,0.8395603,1.438899,,,,,,,,,,,,,, -16400,0.57731843,1.4581379,,,,,,,,,,,,,, -16500,0.5909291,1.4834741,,,,,,,,,,,,,, -16600,0.6626472,1.4379724,,,,,,,,,,,,,, -16700,0.5675664,1.4331834,,,,,,,,,,,,,, -16800,0.60891753,1.4170567,,,,,,,,,,,,,, -16900,0.58914113,1.4733866,,,,,,,,,,,,,, -17000,0.64524573,1.4084606,,,,,,,,,,,,,, -17100,0.5954276,1.4491202,,,,,,,,,,,,,, -17200,0.60367954,1.4417173,,,,,,,,,,,,,, -17300,0.6659915,1.4301305,,,,,,,,,,,,,, -17400,0.6372462,1.5289237,,,,,,,,,,,,,, -17500,0.56763095,1.3698362,,,,,,,,,,,,,, -17600,0.566029,1.4069247,,,,,,,,,,,,,, -17700,0.7605959,1.4832109,,,,,,,,,,,,,, -17800,0.62489843,1.427317,,,,,,,,,,,,,, -17900,0.7052804,1.4457226,,,,,,,,,,,,,, -17959,,,0.41662517,0.1468018951732306,0.57269543,0.1724514129584753,5348.0,0.35398287,0.1204679787947108,2472.0,14437.80846619606,15852.590964317322,14437.80846619606,1413.5200998783112,0.5142014026641846,0.0 -18000,0.6463114,1.439896,,,,,,,,,,,,,, -18100,0.66318715,1.4549752,,,,,,,,,,,,,, -18200,0.5771496,1.3811663,,,,,,,,,,,,,, -18300,0.6207645,1.405688,,,,,,,,,,,,,, -18400,0.69517505,1.4756124,,,,,,,,,,,,,, -18500,0.6316853,1.3971535,,,,,,,,,,,,,, -18600,0.6810769,1.4248012,,,,,,,,,,,,,, -18700,0.67810524,1.3677301,,,,,,,,,,,,,, -18800,0.59662586,1.4245226,,,,,,,,,,,,,, -18900,0.5773408,1.426345,,,,,,,,,,,,,, -19000,0.6377075,1.3863528,,,,,,,,,,,,,, -19100,0.67340547,1.463532,,,,,,,,,,,,,, -19200,0.5888234,1.3650998,,,,,,,,,,,,,, -19300,0.71095914,1.4703839,,,,,,,,,,,,,, -19400,0.60214645,1.3993196,,,,,,,,,,,,,, -19500,0.742252,1.4308501,,,,,,,,,,,,,, -19600,0.62430245,1.3877233,,,,,,,,,,,,,, -19700,0.7548382,1.4643434,,,,,,,,,,,,,, -19800,0.64981484,1.3700423,,,,,,,,,,,,,, -19819,,,0.3939056,0.138076927376175,0.55853176,0.1676144317753941,5348.0,0.33988407,0.115024475453456,2472.0,15878.370507478714,17418.529887914658,15878.370507478714,1538.7788639068604,0.561873197555542,0.0 -19900,0.65013075,1.3540012,,,,,,,,,,,,,, -20000,0.659319,1.3614931,,,,,,,,,,,,,, -20100,0.64273757,1.359357,,,,,,,,,,,,,, -20200,0.63329476,1.4109911,,,,,,,,,,,,,, -20300,0.76693594,1.3404152,,,,,,,,,,,,,, -20400,0.6265496,1.3872945,,,,,,,,,,,,,, -20500,0.70365745,1.5026863,,,,,,,,,,,,,, -20600,0.72156507,1.3659835,,,,,,,,,,,,,, -20700,0.6003838,1.3655319,,,,,,,,,,,,,, -20800,0.75279677,1.3474451,,,,,,,,,,,,,, -20900,0.66388524,1.3729993,,,,,,,,,,,,,, -21000,0.6060349,1.3217424,,,,,,,,,,,,,, -21100,0.649406,1.3334839,,,,,,,,,,,,,, -21200,0.65250474,1.4023438,,,,,,,,,,,,,, -21300,0.62154037,1.3781585,,,,,,,,,,,,,, -21400,0.62504363,1.3853111,,,,,,,,,,,,,, -21500,0.694984,1.3690903,,,,,,,,,,,,,, -21600,0.7149877,1.387831,,,,,,,,,,,,,, -21678,,,0.26414925,0.0981504777709376,0.5420662,0.1627291773270127,5348.0,0.32638323,0.1087075741880446,2472.0,17318.98425388336,18997.48906755448,17318.98425388336,1677.0056114196775,0.6107773780822754,0.0 -21700,0.75736856,1.4302669,,,,,,,,,,,,,, -21800,0.5676989,1.3762858,,,,,,,,,,,,,, -21900,0.5535851,1.3146111,,,,,,,,,,,,,, -22000,0.5986888,1.4108522,,,,,,,,,,,,,, -22100,0.76708233,1.3989954,,,,,,,,,,,,,, -22200,0.74730486,1.4225265,,,,,,,,,,,,,, -22300,0.701401,1.3720918,,,,,,,,,,,,,, -22400,0.62046474,1.3870689,,,,,,,,,,,,,, -22500,0.7297077,1.3724029,,,,,,,,,,,,,, -22600,0.6222844,1.4052104,,,,,,,,,,,,,, -22700,0.72881365,1.3410271,,,,,,,,,,,,,, -22800,0.58794636,1.3063873,,,,,,,,,,,,,, -22900,0.6612012,1.3432436,,,,,,,,,,,,,, -23000,0.6334787,1.4063088,,,,,,,,,,,,,, -23100,0.65697086,1.3503637,,,,,,,,,,,,,, -23200,0.66427785,1.3823142,,,,,,,,,,,,,, -23300,0.6662744,1.3829728,,,,,,,,,,,,,, -23400,0.73983073,1.4016099,,,,,,,,,,,,,, -23500,0.5752273,1.3444363,,,,,,,,,,,,,, -23548,,,0.24399593,0.0890187326286641,0.5178802,0.1580466705928922,5348.0,0.3171775,0.1089310015639916,2472.0,18759.508040905,20568.534235477448,18759.508040905,1807.410040140152,0.6587982177734375,0.0 -23600,0.71928865,1.40149,,,,,,,,,,,,,, -23700,0.6737293,1.3497956,,,,,,,,,,,,,, -23800,0.7739791,1.3310935,,,,,,,,,,,,,, -23900,0.587552,1.352515,,,,,,,,,,,,,, -24000,0.7897197,1.3552129,,,,,,,,,,,,,, -24100,0.56806165,1.3212072,,,,,,,,,,,,,, -24200,0.7572311,1.3586844,,,,,,,,,,,,,, -24300,0.6193611,1.3972142,,,,,,,,,,,,,, -24400,0.71221024,1.3574061,,,,,,,,,,,,,, -24500,0.61796355,1.318777,,,,,,,,,,,,,, -24600,0.6264741,1.4315728,,,,,,,,,,,,,, -24700,0.6311317,1.3374687,,,,,,,,,,,,,, -24800,0.6540306,1.3378801,,,,,,,,,,,,,, -24900,0.74578947,1.3526694,,,,,,,,,,,,,, -25000,0.65297794,1.3817496,,,,,,,,,,,,,, -25100,0.82658786,1.3609608,,,,,,,,,,,,,, -25200,0.6403841,1.300049,,,,,,,,,,,,,, -25300,0.6694469,1.3738921,,,,,,,,,,,,,, -25400,0.85188377,1.3677598,,,,,,,,,,,,,, -25401,,,0.2405506,0.0882221486073646,0.51276976,0.1538275872056537,5348.0,0.31287903,0.1052952288099445,2472.0,20200.06555223465,22138.588984251022,20200.06555223465,1936.7809422016144,0.7138986587524414,0.0 -25500,0.690993,1.3170185,,,,,,,,,,,,,, -25600,0.66770655,1.40771,,,,,,,,,,,,,, -25700,0.6393434,1.2994702,,,,,,,,,,,,,, -25800,0.66085666,1.3008633,,,,,,,,,,,,,, -25900,0.6657844,1.341293,,,,,,,,,,,,,, -26000,0.69573206,1.3074075,,,,,,,,,,,,,, -26100,0.6774055,1.4000239,,,,,,,,,,,,,, -26200,0.7391319,1.4108102,,,,,,,,,,,,,, -26300,0.6751267,1.2962098,,,,,,,,,,,,,, -26400,0.6027122,1.3920908,,,,,,,,,,,,,, -26500,0.6228839,1.3471861,,,,,,,,,,,,,, -26600,0.7020665,1.2812753,,,,,,,,,,,,,, -26700,0.6817833,1.3799545,,,,,,,,,,,,,, -26800,0.7321378,1.3342088,,,,,,,,,,,,,, -26900,0.6674952,1.3340682,,,,,,,,,,,,,, -27000,0.66976035,1.3477962,,,,,,,,,,,,,, -27100,0.7203674,1.3077985,,,,,,,,,,,,,, -27200,0.58586234,1.3010446,,,,,,,,,,,,,, -27264,,,0.22414844,0.083011886543259,0.50314236,0.1510277378182415,5348.0,0.29457772,0.1012329128836349,2472.0,21640.561172246933,23708.775426387787,21640.561172246933,2066.3479647636414,0.7682766914367676,0.0 -27300,0.75549453,1.3383538,,,,,,,,,,,,,, -27400,0.7232828,1.2960402,,,,,,,,,,,,,, -27500,0.6375971,1.2870725,,,,,,,,,,,,,, -27600,0.6963926,1.3746204,,,,,,,,,,,,,, -27700,0.61962026,1.2947545,,,,,,,,,,,,,, -27800,0.6358591,1.3471532,,,,,,,,,,,,,, -27900,0.5904726,1.2412144,,,,,,,,,,,,,, -28000,0.752607,1.29424,,,,,,,,,,,,,, -28100,0.7430916,1.350161,,,,,,,,,,,,,, -28200,0.6213535,1.3228744,,,,,,,,,,,,,, -28300,0.66517943,1.3419355,,,,,,,,,,,,,, -28400,0.7429977,1.3310378,,,,,,,,,,,,,, -28500,0.61558825,1.3548948,,,,,,,,,,,,,, -28600,0.75169307,1.3356141,,,,,,,,,,,,,, -28700,0.6338993,1.3339194,,,,,,,,,,,,,, -28800,0.6245799,1.2872151,,,,,,,,,,,,,, -28900,0.6428317,1.3374716,,,,,,,,,,,,,, -29000,0.92435324,1.2434173,,,,,,,,,,,,,, -29100,0.8025964,1.2738653,,,,,,,,,,,,,, -29119,,,0.22378603,0.0847956333553655,0.4944983,0.1492802456143738,5348.0,0.29227978,0.0996892328316373,2472.0,23080.513560533524,25276.79612565041,23080.513560533524,2194.294489145279,0.8199701309204102,0.0 -29200,0.62859684,1.2589357,,,,,,,,,,,,,, -29300,0.70426106,1.3259051,,,,,,,,,,,,,, -29400,0.69312614,1.2996019,,,,,,,,,,,,,, -29500,0.6764509,1.2730392,,,,,,,,,,,,,, -29600,0.6231527,1.2891711,,,,,,,,,,,,,, -29700,0.6870258,1.286375,,,,,,,,,,,,,, -29800,0.5930538,1.2749257,,,,,,,,,,,,,, -29900,0.63591784,1.2591252,,,,,,,,,,,,,, -30000,0.6385017,1.2768537,,,,,,,,,,,,,, -30100,0.63689315,1.2947468,,,,,,,,,,,,,, -30200,0.8007249,1.2916845,,,,,,,,,,,,,, -30300,0.74721956,1.289172,,,,,,,,,,,,,, -30400,0.59984344,1.2994342,,,,,,,,,,,,,, -30500,0.7394684,1.2997308,,,,,,,,,,,,,, -30600,0.7291419,1.3112876,,,,,,,,,,,,,, -30700,0.69577783,1.3253123,,,,,,,,,,,,,, -30800,0.73077595,1.3122804,,,,,,,,,,,,,, -30900,0.7630749,1.299649,,,,,,,,,,,,,, -30958,,,0.20724949,0.0780107444682858,0.48233274,0.1448487598598144,5348.0,0.286033,0.0981861759389027,2472.0,24520.68779182434,26845.64391064644,24520.68779182434,2322.846914052964,0.8701145648956299,0.0 -31000,0.6049328,1.2744688,,,,,,,,,,,,,, -31100,0.9067135,1.2961993,,,,,,,,,,,,,, -31200,0.70115334,1.2366923,,,,,,,,,,,,,, -31300,0.6374056,1.2835633,,,,,,,,,,,,,, -31400,0.8608088,1.3420544,,,,,,,,,,,,,, -31500,0.72300637,1.2775491,,,,,,,,,,,,,, -31600,0.6093959,1.3061401,,,,,,,,,,,,,, -31700,0.6902252,1.251288,,,,,,,,,,,,,, -31800,0.74835855,1.2809229,,,,,,,,,,,,,, -31900,0.7161158,1.3061818,,,,,,,,,,,,,, -32000,0.58105576,1.2081257,,,,,,,,,,,,,, -32100,0.8671501,1.2558377,,,,,,,,,,,,,, -32200,0.63938063,1.3304383,,,,,,,,,,,,,, -32300,0.5690446,1.2391902,,,,,,,,,,,,,, -32400,0.6496863,1.2141509,,,,,,,,,,,,,, -32500,0.70382744,1.2672693,,,,,,,,,,,,,, -32600,0.73444384,1.2924345,,,,,,,,,,,,,, -32700,0.75065184,1.2996079,,,,,,,,,,,,,, -32800,0.63078547,1.2958102,,,,,,,,,,,,,, -32805,,,0.24316454,0.084999483434562,0.4656841,0.1414696312887996,5348.0,0.27722582,0.095748786383117,2472.0,25960.884137392044,28414.607186079025,25960.884137392044,2451.4838552474976,0.925905466079712,0.0 -32900,0.82080895,1.2479317,,,,,,,,,,,,,, -33000,0.5838804,1.2193189,,,,,,,,,,,,,, -33100,0.6251597,1.219345,,,,,,,,,,,,,, -33200,0.63468546,1.2132131,,,,,,,,,,,,,, -33300,0.7588254,1.2978823,,,,,,,,,,,,,, -33400,0.7002371,1.2409163,,,,,,,,,,,,,, -33500,0.67580104,1.2588116,,,,,,,,,,,,,, -33600,0.89531773,1.2796621,,,,,,,,,,,,,, -33700,0.6872555,1.2304003,,,,,,,,,,,,,, -33800,0.6196623,1.2933247,,,,,,,,,,,,,, -33900,0.69446707,1.229989,,,,,,,,,,,,,, -34000,0.74656814,1.2145008,,,,,,,,,,,,,, -34100,0.7453095,1.2524725,,,,,,,,,,,,,, -34200,0.6773642,1.2487408,,,,,,,,,,,,,, -34300,0.6159156,1.2586863,,,,,,,,,,,,,, -34400,0.7783782,1.2744205,,,,,,,,,,,,,, -34500,0.7346658,1.2454883,,,,,,,,,,,,,, -34600,0.74977803,1.2000539,,,,,,,,,,,,,, -34652,,,0.20290689,0.0774240193362147,0.46335202,0.1393938808808905,5348.0,0.264707,0.0922145715272276,2472.0,27401.33624315262,29985.31670928001,27401.33624315262,2581.6163563728333,0.9803595542907716,0.0 -34700,0.75221515,1.3036262,,,,,,,,,,,,,, -34800,0.82302517,1.2147188,,,,,,,,,,,,,, -34900,0.66165876,1.3038571,,,,,,,,,,,,,, -35000,0.63921684,1.2589909,,,,,,,,,,,,,, -35100,0.69985026,1.261112,,,,,,,,,,,,,, -35200,0.67405057,1.1997349,,,,,,,,,,,,,, -35300,0.6665841,1.2487633,,,,,,,,,,,,,, -35400,0.796381,1.2686405,,,,,,,,,,,,,, -35500,0.6529525,1.250299,,,,,,,,,,,,,, -35600,0.6329977,1.270475,,,,,,,,,,,,,, -35700,0.593829,1.201245,,,,,,,,,,,,,, -35800,0.7656365,1.3093077,,,,,,,,,,,,,, -35900,0.582152,1.2270476,,,,,,,,,,,,,, -36000,0.81828725,1.279916,,,,,,,,,,,,,, -36100,0.6824217,1.2019366,,,,,,,,,,,,,, -36200,0.8763173,1.2477745,,,,,,,,,,,,,, -36300,0.903403,1.2194632,,,,,,,,,,,,,, -36400,0.7010616,1.2546883,,,,,,,,,,,,,, -36500,0.6402165,1.2284423,,,,,,,,,,,,,, -36509,,,0.18551725,0.0709965430233595,0.45179525,0.135850623207855,5348.0,0.2665545,0.0891068998436008,2472.0,28841.840735912323,31557.14866900444,28841.840735912323,2712.818051815033,1.0341756343841553,0.0 -36600,0.64106524,1.2189053,,,,,,,,,,,,,, -36700,0.61989725,1.1948217,,,,,,,,,,,,,, -36800,0.9172292,1.2286987,,,,,,,,,,,,,, -36900,0.86213243,1.2559983,,,,,,,,,,,,,, -37000,0.6952605,1.2649193,,,,,,,,,,,,,, -37100,0.80536515,1.2258471,,,,,,,,,,,,,, -37200,0.73621625,1.2395707,,,,,,,,,,,,,, -37300,0.7216827,1.2217239,,,,,,,,,,,,,, -37400,1.0339922,1.2279543,,,,,,,,,,,,,, -37500,0.8437639,1.190239,,,,,,,,,,,,,, -37600,0.615777,1.2519325,,,,,,,,,,,,,, -37700,0.62115747,1.204888,,,,,,,,,,,,,, -37800,0.8089975,1.229409,,,,,,,,,,,,,, -37900,0.89200425,1.2522203,,,,,,,,,,,,,, -38000,0.79480547,1.2426839,,,,,,,,,,,,,, -38100,0.7066829,1.253671,,,,,,,,,,,,,, -38200,0.68706185,1.2182807,,,,,,,,,,,,,, -38300,0.7324811,1.26978,,,,,,,,,,,,,, -38355,,,0.17950557,0.0672176894251468,0.4428782,0.1336300530040453,5348.0,0.26117116,0.0884366177157597,2472.0,30282.38442492485,33127.51769065857,30282.38442492485,2842.51593208313,1.0907597541809082,0.0 -38400,0.639853,1.1587144,,,,,,,,,,,,,, -38500,0.8953443,1.2038367,,,,,,,,,,,,,, -38600,0.64773417,1.2187802,,,,,,,,,,,,,, -38700,0.68708843,1.2142034,,,,,,,,,,,,,, -38800,0.7994206,1.2473717,,,,,,,,,,,,,, -38900,0.66670406,1.2707816,,,,,,,,,,,,,, -39000,0.83723086,1.2283709,,,,,,,,,,,,,, -39100,0.733281,1.2038312,,,,,,,,,,,,,, -39200,0.6842256,1.2083362,,,,,,,,,,,,,, -39300,0.6921915,1.1906172,,,,,,,,,,,,,, -39400,0.7287001,1.2019743,,,,,,,,,,,,,, -39500,0.7150089,1.2161101,,,,,,,,,,,,,, -39600,0.66011477,1.2014049,,,,,,,,,,,,,, -39700,0.72154504,1.1714542,,,,,,,,,,,,,, -39800,0.7226629,1.2292997,,,,,,,,,,,,,, -39900,0.6416689,1.2450619,,,,,,,,,,,,,, -40000,0.6578179,1.1804017,,,,,,,,,,,,,, -40100,0.6783434,1.1644384,,,,,,,,,,,,,, -40185,,,0.16687053,0.0648847975882859,0.43387946,0.1303378163105708,5348.0,0.25024888,0.0846180407450287,2472.0,31722.82284355164,34697.79404234886,31722.82284355164,2972.221476078033,1.150517463684082,0.0 -40200,0.7097952,1.187827,,,,,,,,,,,,,, -40300,0.73235995,1.1798604,,,,,,,,,,,,,, -40400,0.79493177,1.216021,,,,,,,,,,,,,, -40500,0.7668146,1.1738437,,,,,,,,,,,,,, -40600,0.70182425,1.193793,,,,,,,,,,,,,, -40700,0.6731058,1.1677785,,,,,,,,,,,,,, -40800,0.695106,1.1862216,,,,,,,,,,,,,, -40900,0.9222873,1.2505937,,,,,,,,,,,,,, -41000,0.7272726,1.2200338,,,,,,,,,,,,,, -41100,0.9034264,1.2104404,,,,,,,,,,,,,, -41200,0.86850303,1.1769726,,,,,,,,,,,,,, -41300,0.7011626,1.1247183,,,,,,,,,,,,,, -41400,0.7167945,1.1281203,,,,,,,,,,,,,, -41500,0.7327503,1.2382905,,,,,,,,,,,,,, -41600,0.78896254,1.2062873,,,,,,,,,,,,,, -41700,0.9314225,1.1817887,,,,,,,,,,,,,, -41800,0.76451856,1.17432,,,,,,,,,,,,,, -41900,0.8333698,1.2374855,,,,,,,,,,,,,, -42000,0.8089289,1.1890246,,,,,,,,,,,,,, -42037,,,0.17537016,0.0659101794201083,0.42432463,0.1273062552497176,5348.0,0.24411249,0.0836430849227144,2472.0,33162.69558787346,36268.49072861672,33162.69558787346,3102.9151356220245,1.2059061527252195,0.0 -42100,1.0070149,1.1849585,,,,,,,,,,,,,, -42200,0.7053371,1.1907278,,,,,,,,,,,,,, -42300,0.7859993,1.1883743,,,,,,,,,,,,,, -42400,0.7556804,1.1435016,,,,,,,,,,,,,, -42500,0.68772167,1.1481818,,,,,,,,,,,,,, -42600,0.7865188,1.2166315,,,,,,,,,,,,,, -42700,0.70008874,1.1372851,,,,,,,,,,,,,, -42800,0.7181674,1.1949705,,,,,,,,,,,,,, -42900,0.8525365,1.2021389,,,,,,,,,,,,,, -43000,0.70981216,1.2612828,,,,,,,,,,,,,, -43100,0.6457502,1.2170328,,,,,,,,,,,,,, -43200,0.7588104,1.1905271,,,,,,,,,,,,,, -43300,0.88044864,1.1717821,,,,,,,,,,,,,, -43400,0.8239336,1.1959449,,,,,,,,,,,,,, -43500,0.745618,1.1906334,,,,,,,,,,,,,, -43600,0.7581235,1.0972645,,,,,,,,,,,,,, -43700,0.80327743,1.1571721,,,,,,,,,,,,,, -43800,0.67487353,1.1093104,,,,,,,,,,,,,, -43882,,,0.17138386,0.0645453444471942,0.4157252,0.124429168637825,5348.0,0.23707676,0.0805963479779822,2472.0,34603.0143558979,37839.05934667587,34603.0143558979,3233.0268499851227,1.269730806350708,0.0 -43900,0.844245,1.2281442,,,,,,,,,,,,,, -44000,0.701031,1.2012984,,,,,,,,,,,,,, -44100,0.7470639,1.1653138,,,,,,,,,,,,,, -44200,0.630223,1.079858,,,,,,,,,,,,,, -44300,0.68684214,1.144929,,,,,,,,,,,,,, -44400,0.67989093,1.1316353,,,,,,,,,,,,,, -44500,0.75555915,1.154032,,,,,,,,,,,,,, -44600,0.87698215,1.1590431,,,,,,,,,,,,,, -44700,0.754633,1.1347048,,,,,,,,,,,,,, -44800,0.76184815,1.2246609,,,,,,,,,,,,,, -44900,0.67051613,1.1410114,,,,,,,,,,,,,, -45000,0.7770295,1.1927205,,,,,,,,,,,,,, -45100,0.8544224,1.1895646,,,,,,,,,,,,,, -45200,0.83770853,1.1827034,,,,,,,,,,,,,, -45300,0.6451028,1.088934,,,,,,,,,,,,,, -45400,0.70402133,1.1459161,,,,,,,,,,,,,, -45500,0.6826138,1.1701739,,,,,,,,,,,,,, -45600,0.707441,1.1073422,,,,,,,,,,,,,, -45700,0.6290699,1.1261177,,,,,,,,,,,,,, -45739,,,0.15769911,0.0606665491460746,0.4071435,0.122372727536036,5348.0,0.2356796,0.080291674283509,2472.0,36043.24026441574,39409.80263566971,36043.24026441574,3363.4156663417816,1.3258774280548096,0.0 -45800,0.8326336,1.160177,,,,,,,,,,,,,, -45900,0.7924003,1.1317629,,,,,,,,,,,,,, -46000,0.76787776,1.1989536,,,,,,,,,,,,,, -46100,0.74026185,1.1619141,,,,,,,,,,,,,, -46200,0.78904444,1.1370225,,,,,,,,,,,,,, -46300,0.73392534,1.1431044,,,,,,,,,,,,,, -46400,0.8290809,1.1017714,,,,,,,,,,,,,, -46500,0.8274174,1.2022563,,,,,,,,,,,,,, -46600,0.7893027,1.1660619,,,,,,,,,,,,,, -46700,0.7926067,1.1309419,,,,,,,,,,,,,, -46800,0.83989054,1.158132,,,,,,,,,,,,,, -46900,0.84986967,1.1005082,,,,,,,,,,,,,, -47000,0.8648887,1.1204345,,,,,,,,,,,,,, -47100,0.71167034,1.1218648,,,,,,,,,,,,,, -47200,0.7671475,1.1014208,,,,,,,,,,,,,, -47300,0.7087501,1.1367847,,,,,,,,,,,,,, -47400,0.70219815,1.1442343,,,,,,,,,,,,,, -47500,0.7399302,1.1261356,,,,,,,,,,,,,, -47582,,,0.15129596,0.0577826683285268,0.40430495,0.1194377130057831,5348.0,0.22712962,0.0770824447017244,2472.0,37483.37846851349,40981.70229816437,37483.37846851349,3495.040193796158,1.3885679244995115,0.0 -47600,0.86748946,1.111286,,,,,,,,,,,,,, -47700,0.79695445,1.0960612,,,,,,,,,,,,,, -47800,0.6892994,1.1151739,,,,,,,,,,,,,, -47900,0.80556166,1.1616689,,,,,,,,,,,,,, -48000,0.7621941,1.1729244,,,,,,,,,,,,,, -48100,0.75115037,1.1386518,,,,,,,,,,,,,, -48200,0.81557846,1.1057934,,,,,,,,,,,,,, -48300,0.70738775,1.0954769,,,,,,,,,,,,,, -48400,0.8433637,1.1256489,,,,,,,,,,,,,, -48500,1.0536139,1.0971782,,,,,,,,,,,,,, -48600,0.8680709,1.0936518,,,,,,,,,,,,,, -48700,0.7698705,1.1333617,,,,,,,,,,,,,, -48800,0.9233009,1.1301743,,,,,,,,,,,,,, -48900,0.8516311,1.1576707,,,,,,,,,,,,,, -49000,0.89929926,1.1270988,,,,,,,,,,,,,, -49100,0.7336141,1.103591,,,,,,,,,,,,,, -49200,0.9131273,1.1373768,,,,,,,,,,,,,, -49300,0.8161751,1.1311885,,,,,,,,,,,,,, -49400,0.7585375,1.1436563,,,,,,,,,,,,,, -49408,,,0.13115764,0.0500343243385964,0.39019933,0.1160199658225281,5348.0,0.21716937,0.075437206751569,2472.0,38923.28881287575,42553.44931507111,38923.28881287575,3626.74472117424,1.4479291439056396,0.0 -49500,0.72364044,1.0482519,,,,,,,,,,,,,, -49600,0.9489737,1.0803851,,,,,,,,,,,,,, -49700,1.0585021,1.1035168,,,,,,,,,,,,,, -49800,0.85946983,1.0726641,,,,,,,,,,,,,, -49900,0.8696627,1.1179206,,,,,,,,,,,,,, -50000,0.7718337,1.1570159,,,,,,,,,,,,,, -50100,0.8616587,1.1155403,,,,,,,,,,,,,, -50200,1.0813943,1.1319201,,,,,,,,,,,,,, -50300,0.75412315,1.1216205,,,,,,,,,,,,,, -50400,0.855985,1.1193973,,,,,,,,,,,,,, -50500,0.844067,1.1095253,,,,,,,,,,,,,, -50600,0.7409262,1.0480099,,,,,,,,,,,,,, -50700,0.83616245,1.1327218,,,,,,,,,,,,,, -50800,0.7950045,1.0725523,,,,,,,,,,,,,, -50900,0.9570945,1.082577,,,,,,,,,,,,,, -51000,0.824232,1.076516,,,,,,,,,,,,,, -51100,1.002029,1.1400607,,,,,,,,,,,,,, -51200,0.8284463,1.0937364,,,,,,,,,,,,,, -51227,,,0.13253723,0.0503795192415744,0.3822779,0.1125539453739729,5348.0,0.21561463,0.0724920277049946,2472.0,40363.78140926361,44126.6337416172,40363.78140926361,3759.30832862854,1.5018019676208496,0.0 -51300,0.76290435,1.1291006,,,,,,,,,,,,,, -51400,0.8293953,1.0537661,,,,,,,,,,,,,, -51500,0.7628894,1.1048189,,,,,,,,,,,,,, -51600,0.81994945,1.106908,,,,,,,,,,,,,, -51700,0.8028125,1.0436493,,,,,,,,,,,,,, -51800,0.92517287,1.0908924,,,,,,,,,,,,,, -51900,0.9763925,1.0691425,,,,,,,,,,,,,, -52000,0.8401335,1.1288139,,,,,,,,,,,,,, -52100,0.9353139,1.099799,,,,,,,,,,,,,, -52200,1.1689734,1.121436,,,,,,,,,,,,,, -52300,0.79671544,1.1007231,,,,,,,,,,,,,, -52400,0.90630215,1.131915,,,,,,,,,,,,,, -52500,0.75088394,1.1047145,,,,,,,,,,,,,, -52600,0.90507776,1.1053336,,,,,,,,,,,,,, -52700,0.8245503,1.046323,,,,,,,,,,,,,, -52800,0.7533507,1.0861062,,,,,,,,,,,,,, -52900,0.9129347,1.0830722,,,,,,,,,,,,,, -53000,0.7896434,1.0934696,,,,,,,,,,,,,, -53073,,,0.13393289,0.0510575079009211,0.3740585,0.1119070836189501,5348.0,0.20504084,0.0704608697418398,2472.0,41803.87991976738,45697.32681298256,41803.87991976738,3889.758169412613,1.571280002593994,0.0 -53100,0.8518459,1.1348603,,,,,,,,,,,,,, -53200,0.8880597,1.1154054,,,,,,,,,,,,,, -53300,0.8829306,1.1029993,,,,,,,,,,,,,, -53400,1.8154761,1.0637484,,,,,,,,,,,,,, -53500,1.1755073,1.1001545,,,,,,,,,,,,,, -53600,0.84576565,1.0163809,,,,,,,,,,,,,, -53700,0.8384095,1.0724943,,,,,,,,,,,,,, -53800,0.76973826,1.0973747,,,,,,,,,,,,,, -53900,0.7221666,1.0239867,,,,,,,,,,,,,, -54000,0.97929215,0.9988081,,,,,,,,,,,,,, -54100,1.0098262,1.0872957,,,,,,,,,,,,,, -54200,1.1486449,1.051357,,,,,,,,,,,,,, -54300,0.83853865,1.0545167,,,,,,,,,,,,,, -54400,0.76945966,1.0562327,,,,,,,,,,,,,, -54500,0.9058943,1.04067,,,,,,,,,,,,,, -54600,0.78982806,1.0190421,,,,,,,,,,,,,, -54700,0.81949514,1.0559267,,,,,,,,,,,,,, -54800,0.84535325,1.0537398,,,,,,,,,,,,,, -54900,0.9212629,1.0641096,,,,,,,,,,,,,, -54913,,,0.113810115,0.0432091109589262,0.36113426,0.107765237456192,5348.0,0.20187275,0.0668657201470558,2472.0,43243.759679079056,47267.8637509346,43243.759679079056,4020.283055782318,1.6276521682739258,0.0 -55000,0.8163352,1.0434299,,,,,,,,,,,,,, -55100,0.95715964,1.0744624,,,,,,,,,,,,,, -55200,1.0269104,1.0603089,,,,,,,,,,,,,, -55300,0.9230635,1.0771465,,,,,,,,,,,,,, -55400,0.9336215,1.1497579,,,,,,,,,,,,,, -55500,0.8309194,1.0268925,,,,,,,,,,,,,, -55600,0.83541274,1.0217155,,,,,,,,,,,,,, -55700,0.8410051,1.0561334,,,,,,,,,,,,,, -55800,0.9999559,1.041848,,,,,,,,,,,,,, -55900,0.87163985,1.0290227,,,,,,,,,,,,,, -56000,1.2817633,1.0788388,,,,,,,,,,,,,, -56100,0.9573587,1.0263317,,,,,,,,,,,,,, -56200,0.99141943,1.0372769,,,,,,,,,,,,,, -56300,0.88428783,1.0543349,,,,,,,,,,,,,, -56400,0.89741945,1.0615739,,,,,,,,,,,,,, -56500,0.95472795,1.0483111,,,,,,,,,,,,,, -56600,0.983611,1.0118566,,,,,,,,,,,,,, -56700,0.879112,1.0793567,,,,,,,,,,,,,, -56736,,,0.11524229,0.0462044148115838,0.35849932,0.1044343821504774,5348.0,0.19632728,0.0671094591026344,2472.0,44683.91825604439,48838.96536445618,44683.91825604439,4151.09471654892,1.6856026649475098,0.0 -56800,1.0460877,0.9497644,,,,,,,,,,,,,, -56900,1.2459282,1.0843196,,,,,,,,,,,,,, -57000,0.90331125,1.0408895,,,,,,,,,,,,,, -57100,0.8891808,1.0887768,,,,,,,,,,,,,, -57200,0.9439235,1.033061,,,,,,,,,,,,,, -57300,0.81889826,1.0132108,,,,,,,,,,,,,, -57400,0.8907105,1.0627817,,,,,,,,,,,,,, -57500,0.95480186,1.0382236,,,,,,,,,,,,,, -57600,0.90081203,1.0772347,,,,,,,,,,,,,, -57700,0.8819143,1.0370953,,,,,,,,,,,,,, -57800,0.87303007,1.0349046,,,,,,,,,,,,,, -57900,1.581399,1.0250336,,,,,,,,,,,,,, -58000,1.0343965,0.98809665,,,,,,,,,,,,,, -58100,1.1861075,1.0572621,,,,,,,,,,,,,, -58200,0.97534496,1.0011152,,,,,,,,,,,,,, -58300,1.163556,1.0016385,,,,,,,,,,,,,, -58400,0.94041157,1.0095353,,,,,,,,,,,,,, -58500,0.8785717,1.045316,,,,,,,,,,,,,, -58571,,,0.13256885,0.0473521011397905,0.35000604,0.1022234665997277,5348.0,0.1881308,0.0637174253041659,2472.0,46123.78596878052,50409.45980882645,46123.78596878052,4281.586913585663,1.746351718902588,0.0 -58600,0.9760976,1.0085783,,,,,,,,,,,,,, -58700,1.1388837,1.0477389,,,,,,,,,,,,,, -58800,1.1415881,1.0323973,,,,,,,,,,,,,, -58900,0.8359049,0.9650665,,,,,,,,,,,,,, -59000,0.93326247,1.0104984,,,,,,,,,,,,,, -59100,0.95175135,1.0049211,,,,,,,,,,,,,, -59200,1.0268638,1.024313,,,,,,,,,,,,,, -59300,1.1614778,1.0212512,,,,,,,,,,,,,, -59400,1.1766229,0.9824216,,,,,,,,,,,,,, -59500,0.88593185,1.0106755,,,,,,,,,,,,,, -59600,0.97635055,1.020209,,,,,,,,,,,,,, -59700,0.9604484,0.9722128,,,,,,,,,,,,,, -59800,0.95853925,0.9481418,,,,,,,,,,,,,, -59900,0.9970776,1.0190636,,,,,,,,,,,,,, -60000,1.0464118,0.9691196,,,,,,,,,,,,,, -60100,1.1217237,1.031437,,,,,,,,,,,,,, -60200,0.8905595,0.9767616,,,,,,,,,,,,,, -60300,1.2378445,1.0268162,,,,,,,,,,,,,, -60400,0.9489797,0.95529425,,,,,,,,,,,,,, -60404,,,0.08553199,0.0332429879676399,0.34243563,0.0999546231306178,5348.0,0.18706349,0.0626815347429569,2472.0,47564.2847366333,51981.551478385925,47564.2847366333,4413.043275594711,1.8080739974975584,0.0 -60500,0.9115344,1.0114623,,,,,,,,,,,,,, -60600,0.928089,0.9700241,,,,,,,,,,,,,, -60700,0.92839533,0.9684213,,,,,,,,,,,,,, -60800,1.0583546,1.0169787,,,,,,,,,,,,,, -60900,1.0197166,0.9660919,,,,,,,,,,,,,, -61000,0.9433969,0.9505972,,,,,,,,,,,,,, -61100,1.2596467,1.005475,,,,,,,,,,,,,, -61200,1.0703012,0.98046803,,,,,,,,,,,,,, -61300,1.1645576,0.9457174,,,,,,,,,,,,,, -61400,1.0194051,1.0298448,,,,,,,,,,,,,, -61500,0.9107358,0.9789805,,,,,,,,,,,,,, -61600,1.0019889,1.0185966,,,,,,,,,,,,,, -61700,1.094374,0.9661951,,,,,,,,,,,,,, -61800,0.9875588,0.9748438,,,,,,,,,,,,,, -61900,1.067913,0.9641621,,,,,,,,,,,,,, -62000,1.0152721,1.0083531,,,,,,,,,,,,,, -62100,1.5021092,0.9931929,,,,,,,,,,,,,, -62200,1.0386466,0.990354,,,,,,,,,,,,,, -62252,,,0.08559206,0.0322760184761349,0.33397695,0.0979174913349488,5348.0,0.1801457,0.0602644567668027,2472.0,49004.790016412735,53553.19096279144,49004.790016412735,4544.038645029068,1.8715159893035889,0.0 -62300,1.3746723,0.9305693,,,,,,,,,,,,,, -62400,1.05479,1.0038259,,,,,,,,,,,,,, -62500,1.7352939,1.0090702,,,,,,,,,,,,,, -62600,1.0786391,0.9016301,,,,,,,,,,,,,, -62700,0.9872352,0.9693835,,,,,,,,,,,,,, -62800,1.0977521,0.95878667,,,,,,,,,,,,,, -62900,1.1621128,0.92626804,,,,,,,,,,,,,, -63000,1.0074669,0.96911174,,,,,,,,,,,,,, -63100,0.8473162,0.9381742,,,,,,,,,,,,,, -63200,1.0557656,0.9795676,,,,,,,,,,,,,, -63300,1.0058587,0.9670023,,,,,,,,,,,,,, -63400,1.0582744,0.94785625,,,,,,,,,,,,,, -63500,1.1630304,0.97346693,,,,,,,,,,,,,, -63600,1.1307288,0.9670682,,,,,,,,,,,,,, -63700,1.1291753,0.96523845,,,,,,,,,,,,,, -63800,1.0240517,0.92206526,,,,,,,,,,,,,, -63900,1.1511915,0.93174994,,,,,,,,,,,,,, -64000,1.0290889,0.9690492,,,,,,,,,,,,,, -64082,,,0.103173256,0.0402818574870457,0.32366726,0.09361151607017,5348.0,0.17913373,0.0594723051611723,2472.0,50445.33984518051,55123.66427946091,50445.33984518051,4673.833595752716,1.9265692234039309,0.0 -64100,1.183317,0.94188,,,,,,,,,,,,,, -64200,1.1189758,0.90610605,,,,,,,,,,,,,, -64300,0.998261,0.95612806,,,,,,,,,,,,,, -64400,1.140074,0.9522707,,,,,,,,,,,,,, -64500,0.98648924,1.0044866,,,,,,,,,,,,,, -64600,0.9516051,0.9418199,,,,,,,,,,,,,, -64700,0.9536602,0.9712092,,,,,,,,,,,,,, -64800,1.1457791,0.9666433,,,,,,,,,,,,,, -64900,1.3647377,0.9851967,,,,,,,,,,,,,, -65000,1.1371723,0.9348711,,,,,,,,,,,,,, -65100,1.4460986,0.9334673,,,,,,,,,,,,,, -65200,1.0617318,0.9635227,,,,,,,,,,,,,, -65300,1.0218371,0.92080706,,,,,,,,,,,,,, -65400,1.0202351,0.9197851,,,,,,,,,,,,,, -65500,1.4800647,0.93924034,,,,,,,,,,,,,, -65600,1.1980907,0.9650959,,,,,,,,,,,,,, -65700,1.2595804,0.94310385,,,,,,,,,,,,,, -65800,0.9656548,0.93093324,,,,,,,,,,,,,, -65900,1.0770311,0.890279,,,,,,,,,,,,,, -65901,,,0.09804582,0.0370182862740134,0.31444234,0.0911495795398592,5348.0,0.16896814,0.0557349745089675,2472.0,51886.01150846481,56693.15033197403,51886.01150846481,4802.509584188461,1.989338636398316,0.0 -66000,1.238419,0.8806407,,,,,,,,,,,,,, -66100,1.1371869,0.9534455,,,,,,,,,,,,,, -66200,1.0692873,0.9270996,,,,,,,,,,,,,, -66300,1.3278662,0.9419632,,,,,,,,,,,,,, -66400,1.07562,0.9410213,,,,,,,,,,,,,, -66500,1.2244066,0.9099035,,,,,,,,,,,,,, -66600,1.1349667,0.92547816,,,,,,,,,,,,,, -66700,1.0195392,0.9526493,,,,,,,,,,,,,, -66800,1.4031456,0.92180437,,,,,,,,,,,,,, -66900,1.0377781,0.9409182,,,,,,,,,,,,,, -67000,1.1166682,0.88904613,,,,,,,,,,,,,, -67100,1.1475339,0.97155213,,,,,,,,,,,,,, -67200,1.1181043,0.9217334,,,,,,,,,,,,,, -67300,1.1411816,0.9583829,,,,,,,,,,,,,, -67400,1.2222046,0.8777964,,,,,,,,,,,,,, -67500,1.170589,0.9008273,,,,,,,,,,,,,, -67600,1.3416914,0.94812,,,,,,,,,,,,,, -67700,1.0168114,0.8915605,,,,,,,,,,,,,, -67717,,,0.11457752,0.0450217154005156,0.3129322,0.0905316817440165,5348.0,0.1688004,0.0571161619239128,2472.0,53326.61968564987,58263.58122968674,53326.61968564987,4932.19512963295,2.0536038875579834,0.0 -67800,1.0596288,0.8791014,,,,,,,,,,,,,, -67900,1.0818759,0.9290765,,,,,,,,,,,,,, -68000,1.1153189,0.9203769,,,,,,,,,,,,,, -68100,1.2112916,0.92140317,,,,,,,,,,,,,, -68200,1.1347692,0.91206414,,,,,,,,,,,,,, -68300,1.150979,0.9125483,,,,,,,,,,,,,, -68400,1.1602668,0.8934732,,,,,,,,,,,,,, -68500,1.1746353,0.90546423,,,,,,,,,,,,,, -68600,1.146326,0.89387465,,,,,,,,,,,,,, -68700,1.2215871,0.8955791,,,,,,,,,,,,,, -68800,1.2181212,0.91061974,,,,,,,,,,,,,, -68900,1.3070962,0.90227973,,,,,,,,,,,,,, -69000,1.1623261,0.9113028,,,,,,,,,,,,,, -69100,1.1390779,0.9171027,,,,,,,,,,,,,, -69200,1.3227943,0.8554483,,,,,,,,,,,,,, -69300,1.2994577,0.9142424,,,,,,,,,,,,,, -69400,1.1977859,0.89943594,,,,,,,,,,,,,, -69500,1.1126825,0.8601006,,,,,,,,,,,,,, -69559,,,0.09059269,0.0335395084535696,0.30794328,0.0875966672137636,5348.0,0.16448735,0.0552474965978104,2472.0,54766.81351041794,59832.69885110855,54766.81351041794,5060.983474969864,2.1116764545440674,0.0 -69600,1.1761684,0.92256975,,,,,,,,,,,,,, -69700,1.2756757,0.8722434,,,,,,,,,,,,,, -69800,1.1182141,0.88628465,,,,,,,,,,,,,, -69900,1.055009,0.8843253,,,,,,,,,,,,,, -70000,1.3026128,0.87135047,,,,,,,,,,,,,, -70100,1.1378407,0.8534074,,,,,,,,,,,,,, -70200,1.4702528,0.86093634,,,,,,,,,,,,,, -70300,1.1639535,0.8738993,,,,,,,,,,,,,, -70400,1.1389891,0.8640681,,,,,,,,,,,,,, -70500,1.261852,0.89745265,,,,,,,,,,,,,, -70600,1.5104121,0.8695445,,,,,,,,,,,,,, -70700,1.1958355,0.86138326,,,,,,,,,,,,,, -70800,1.0712212,0.8683896,,,,,,,,,,,,,, -70900,1.3792536,0.89765227,,,,,,,,,,,,,, -71000,1.0938518,0.89141256,,,,,,,,,,,,,, -71100,1.3070381,0.8344655,,,,,,,,,,,,,, -71200,1.1887197,0.8763798,,,,,,,,,,,,,, -71300,1.08832,0.86860025,,,,,,,,,,,,,, -71385,,,0.08755792,0.0338997509127512,0.30202192,0.0857912470915357,5348.0,0.16138907,0.0539272439217597,2472.0,56206.77034330368,61402.078199625015,56206.77034330368,5190.269483566284,2.172675371170044,0.0 -71400,1.2515657,0.85874325,,,,,,,,,,,,,, -71500,1.3935224,0.9086463,,,,,,,,,,,,,, -71600,1.2727195,0.88997763,,,,,,,,,,,,,, -71700,1.2204835,0.9014664,,,,,,,,,,,,,, -71800,1.1787835,0.9211409,,,,,,,,,,,,,, -71900,1.1593152,0.8986484,,,,,,,,,,,,,, -72000,1.2723081,0.89224184,,,,,,,,,,,,,, -72100,1.553436,0.8491749,,,,,,,,,,,,,, -72200,1.192348,0.8495351,,,,,,,,,,,,,, -72300,1.5882387,0.8576758,,,,,,,,,,,,,, -72400,1.2244265,0.9199903,,,,,,,,,,,,,, -72500,2.0821924,0.8707984,,,,,,,,,,,,,, -72600,1.2663996,0.8847715,,,,,,,,,,,,,, -72700,1.0946484,0.8607265,,,,,,,,,,,,,, -72800,1.2613341,0.86280024,,,,,,,,,,,,,, -72900,1.0886538,0.84526473,,,,,,,,,,,,,, -73000,1.0952668,0.875777,,,,,,,,,,,,,, -73100,1.1709232,0.85998195,,,,,,,,,,,,,, -73176,,,0.06485754,0.0246496489765177,0.29641142,0.084719580601871,5348.0,0.15961748,0.0531554038957609,2472.0,57646.677629470825,62974.4433233738,57646.677629470825,5322.592356920242,2.23431396484375,0.0 -73200,1.1883984,0.8737047,,,,,,,,,,,,,, -73300,1.1047981,0.88082564,,,,,,,,,,,,,, -73400,1.1932429,0.83655244,,,,,,,,,,,,,, -73500,1.1457406,0.85143286,,,,,,,,,,,,,, -73600,1.2621057,0.872128,,,,,,,,,,,,,, -73700,1.2248317,0.90113294,,,,,,,,,,,,,, -73800,1.2622439,0.8638007,,,,,,,,,,,,,, -73900,1.2948058,0.7985536,,,,,,,,,,,,,, -74000,1.3044555,0.870516,,,,,,,,,,,,,, -74100,1.3104775,0.88348734,,,,,,,,,,,,,, -74200,1.1002247,0.86791533,,,,,,,,,,,,,, -74300,1.2925651,0.8566726,,,,,,,,,,,,,, -74400,1.0278642,0.86553305,,,,,,,,,,,,,, -74500,1.3599646,0.84657174,,,,,,,,,,,,,, -74600,1.2390543,0.8643368,,,,,,,,,,,,,, -74700,1.3049723,0.8419931,,,,,,,,,,,,,, -74800,1.0948443,0.8406551,,,,,,,,,,,,,, -74900,1.113361,0.8408835,,,,,,,,,,,,,, -74993,,,0.07795109,0.0302521100395187,0.29337308,0.0834065477857053,5348.0,0.15675355,0.0516320354233948,2472.0,59086.73326683045,64545.75790762901,59086.73326683045,5453.71427154541,2.295135021209717,0.0 -75000,1.3533216,0.8318763,,,,,,,,,,,,,, -75100,1.395901,0.9046739,,,,,,,,,,,,,, -75200,1.1157862,0.89859444,,,,,,,,,,,,,, -75300,1.2781591,0.8684505,,,,,,,,,,,,,, -75400,1.2196894,0.8589132,,,,,,,,,,,,,, -75500,1.1451693,0.86438364,,,,,,,,,,,,,, -75600,1.3738146,0.8619721,,,,,,,,,,,,,, -75700,1.2627324,0.8209698,,,,,,,,,,,,,, -75800,1.3091686,0.84451735,,,,,,,,,,,,,, -75900,1.1276549,0.8707193,,,,,,,,,,,,,, -76000,1.4713327,0.83050907,,,,,,,,,,,,,, -76100,1.1478623,0.8534703,,,,,,,,,,,,,, -76200,1.023433,0.88083655,,,,,,,,,,,,,, -76300,1.3442043,0.8722925,,,,,,,,,,,,,, -76400,1.2873683,0.86238146,,,,,,,,,,,,,, -76500,1.2312558,0.838748,,,,,,,,,,,,,, -76600,1.1879016,0.83489263,,,,,,,,,,,,,, -76700,1.9607345,0.847357,,,,,,,,,,,,,, -76800,1.4551681,0.90043646,,,,,,,,,,,,,, -76822,,,0.06765499,0.0251169004676018,0.29259604,0.0835127489693657,5348.0,0.15502998,0.0510633111937115,2472.0,60527.72573661804,66118.33410978317,60527.72573661804,5585.158769369125,2.3584864139556885,0.0 -76900,1.1674323,0.83709556,,,,,,,,,,,,,, -77000,1.1500514,0.86864734,,,,,,,,,,,,,, -77100,1.1022197,0.83780634,,,,,,,,,,,,,, -77200,1.2303975,0.88244087,,,,,,,,,,,,,, -77300,1.2940042,0.859577,,,,,,,,,,,,,, -77400,1.084326,0.86033916,,,,,,,,,,,,,, -77500,1.2855059,0.9021666,,,,,,,,,,,,,, -77529,,,,,,,,,,,61068.68900227547,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 6a8ff3e96..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -135.83981275558472,0.0,35.09761643409729,1,0,35.09761643409729,31.23744,2472,1.10235004976337,170.93748545646667,31.98225,1.3614880623081345,31.08705,5348,1.0585747801152765 -244.7291920185089,0.0279905796051025,1476.0195994377136,1791,0,1476.0195994377136,6.0074077,2472,0.899579550301627,1720.848159313202,6.036521,0.9413900245298448,6.0377555,5348,0.8966179750330672 -353.1840178966522,0.0870876312255859,2916.5635409355164,3623,0,2916.5635409355164,5.9581666,2472,0.8983405439441026,3269.98179101944,6.206665,0.9382353276158586,6.116562,5348,0.8959904225841644 -461.63093519210815,0.1405568122863769,4356.7338581085205,5452,0,4356.7338581085205,12.230537,2472,0.899579550301627,4818.730646371841,12.30888,0.939860390191516,12.248377,5348,0.8966179750330672 -569.9528002738953,0.1885228157043457,5796.618039608002,7255,0,5796.618039608002,9.05541,2472,0.899579550301627,6367.059968471527,9.199602,0.9406809116337576,9.2184,5348,0.8966179750330672 -700.5129299163818,0.2425758838653564,7236.628142595291,9072,0,7236.628142595291,1.4274611,2472,0.403936384132594,7937.760040521622,1.9409115,0.5133557669993037,1.8373336,5348,0.4703650424322002 -832.8874080181122,0.2927138805389404,8677.431626081467,10901,0,8677.431626081467,1.2287663,2472,0.3752767452724798,9511.06428551674,1.5330179,0.4447662118572818,1.5998214,5348,0.4368054683954932 -965.0422995090483,0.3434820175170898,10117.52890110016,12737,0,10117.52890110016,1.1772243,2472,0.3542745719334593,11083.44418811798,1.5478851,0.4351763341723465,1.5292115,5348,0.4163086399490234 -1095.4794552326202,0.3978815078735351,11557.644403457642,14544,0,11557.644403457642,0.9283827,2472,0.2951881867852863,12654.127409219742,1.2687769,0.3805742176611467,1.2760807,5348,0.363468723751412 -1227.9253115653992,0.447887659072876,12997.894040346146,16362,0,12997.894040346146,0.8068257,2472,0.2676659963845388,14226.949261188509,1.0706007,0.3385405215249989,1.1228197,5348,0.3314442395512517 -1358.848022699356,0.5031416416168213,14437.993826389313,18176,0,14437.993826389313,0.7684886,2472,0.2561696423130827,15798.103556871414,0.96476525,0.3123660160216631,1.0905648,5348,0.3211427247361866 -1488.713681936264,0.5551145076751709,15878.171899795532,20012,0,15878.171899795532,0.715538,2472,0.2363861637519549,17368.276427268982,0.8736541,0.2842924222705095,1.0145409,5348,0.3012444847794394 -1620.1200077533722,0.6046969890594482,17318.47605085373,21832,0,17318.47605085373,0.72071874,2472,0.2370158227205329,18940.11133503914,0.9753044,0.3114894714417061,1.0329162,5348,0.3040057155546115 -1752.4782774448397,0.654998779296875,18758.69697093964,23646,0,18758.69697093964,0.6862627,2472,0.2272256413381268,20512.817051410675,0.8934462,0.2851090476244203,0.9852448,5348,0.291782924780598 -1884.7277166843407,0.7082116603851318,20198.639166355133,25454,0,20198.639166355133,0.66787153,2472,0.2244835780878679,22085.138073682785,0.8983038,0.2872240604474295,0.96987027,5348,0.2899871593114301 -2017.1603062152865,0.7608418464660645,21638.938226938248,27290,0,21638.938226938248,0.64174926,2472,0.2136778177238844,23658.000257968903,0.7946374,0.2607233290857483,0.9428824,5348,0.2765865008640914 -2148.321827411652,0.8176655769348145,23079.49067544937,29119,0,23079.49067544937,0.6265644,2472,0.2123778766274653,25229.84836292267,0.7964099,0.2679116939109455,0.9283386,5348,0.2796470258841249 -2279.065259218216,0.8696136474609375,24519.64860892296,30930,0,24519.64860892296,0.614802,2472,0.2055531858712652,26800.87813210488,0.81224275,0.2673072213330903,0.9076635,5348,0.2705909613138051 -2420.939267396927,0.9254543781280518,25960.040961503983,32775,0,25960.040961503983,0.5819507,2472,0.1957630044888591,28383.27775406837,0.5687455,0.1970680333286381,0.8738757,5348,0.2610714733966034 -2556.310618162155,0.9855234622955322,27400.785620450974,34601,0,27400.785620450974,0.5760456,2472,0.1954177076351227,29959.528460025787,0.52129424,0.1864079379382489,0.86678445,5348,0.2615735153557257 -2689.823467731476,1.0541713237762451,28841.119877815247,36435,0,28841.119877815247,0.5492272,2472,0.1873946336806613,31533.52166032791,0.4877663,0.1712832078885945,0.82872856,5348,0.2493024513164119 -2824.1240861415863,1.1081647872924805,30281.715041160583,38254,0,30281.715041160583,0.5308841,2472,0.1792496902484106,33108.547865867615,0.45671743,0.1629685196202695,0.79590267,5348,0.2387885341340259 -2956.557309150696,1.1653995513916016,31721.85123872757,40068,0,31721.85123872757,0.50365597,2472,0.1681595677695854,34681.24870944023,0.44774008,0.1597666977808232,0.7690591,5348,0.2315668536451142 -3090.524173974991,1.2250707149505615,33161.960060596466,41881,0,33161.960060596466,0.5015915,2472,0.1699469867771616,36255.45990538597,0.42459473,0.1542755805534277,0.7605355,5348,0.2281491064618593 -3223.048075437546,1.28309965133667,34602.280463695526,43718,0,34602.280463695526,0.492929,2472,0.1679767635529015,37828.44002938271,0.4951775,0.1707194429413918,0.7597602,5348,0.2293269741351844 -3354.965092897415,1.3371038436889648,36042.54721474648,45542,0,36042.54721474648,0.46914527,2472,0.1600146243373347,39400.75337386131,0.40693375,0.1485157320888572,0.7193552,5348,0.2186682371568977 -3486.938924312592,1.3953540325164795,37482.9290099144,47348,0,37482.9290099144,0.4517525,2472,0.1527633904088721,40973.24291205406,0.38006338,0.1380358563428216,0.69912624,5348,0.2128078627494521 -3619.7294538021088,1.4531559944152832,38923.53028583527,49157,0,38923.53028583527,0.43440887,2472,0.1462230617675136,42546.768450737,0.3494525,0.1273569686936377,0.67146903,5348,0.2035297411587514 -3751.937073707581,1.511134386062622,40363.68802213669,50996,0,40363.68802213669,0.4144835,2472,0.1407795584262588,44119.26902413368,0.33166713,0.1234891205972291,0.6463883,5348,0.1973024899350242 -3883.693731546402,1.566605806350708,41803.74843025208,52829,0,41803.74843025208,0.40218315,2472,0.1346860845367944,45691.219561100006,0.3443252,0.1249640157371176,0.6351694,5348,0.1935178659354876 -4017.97895359993,1.6243414878845217,43244.20497202873,54640,0,43244.20497202873,0.38517037,2472,0.131781528649483,47266.09590554237,0.33093026,0.1206635390150001,0.61256367,5348,0.1880533323035036 -4151.758671045303,1.6820077896118164,44684.27011537552,56468,0,44684.27011537552,0.3706889,2472,0.1253630694859139,48840.0756649971,0.31670836,0.1151420291651815,0.5975832,5348,0.1822508858144182 -4284.507459640503,1.7384145259857178,46124.81386065483,58294,0,46124.81386065483,0.35391665,2472,0.1200414356224483,50413.50209951401,0.2850827,0.1052168831373953,0.5681,5348,0.1727313978972165 -4417.584407091141,1.7928149700164795,47564.94284963608,60134,0,47564.94284963608,0.33913124,2472,0.1140495196311417,51986.8414990902,0.2588146,0.0954069945424412,0.5504874,5348,0.168261293530417 -4551.215311527252,1.854458570480347,49004.83890962601,61956,0,49004.83890962601,0.32282758,2472,0.1082607194361505,53560.50705432892,0.24200241,0.0895358863809282,0.52719754,5348,0.1611265049190457 -4684.490033864975,1.9127497673034668,50445.49860596657,63784,0,50445.49860596657,0.30937302,2472,0.1051733593321552,55134.57596540451,0.23850976,0.0884111856169462,0.50763637,5348,0.1557971364299024 -4816.434291601181,1.971311092376709,51885.54228544235,65607,0,51885.54228544235,0.29772234,2472,0.0994048707167956,56706.69957733154,0.21402435,0.0783076178700069,0.4928836,5348,0.1519352752058854 -4949.930072784424,2.032210111618042,53325.713799238205,67448,0,53325.713799238205,0.28477463,2472,0.095972213759064,58280.50552463532,0.21193847,0.0781770019697072,0.4739441,5348,0.144308099288452 -5084.901441574097,2.091064214706421,54765.77884674072,69280,0,54765.77884674072,0.27006716,2472,0.0910974346474925,59855.67841768265,0.22361787,0.0777870612249214,0.45318562,5348,0.1395869739420913 -5217.849287033081,2.1515071392059326,56206.375772714615,71090,0,56206.375772714615,0.26090908,2472,0.0878678934860764,61429.3595867157,0.15905462,0.0586980822030711,0.44548163,5348,0.1354644370854533 -5350.778831243515,2.2165493965148926,57646.29920887947,72927,0,57646.29920887947,0.25125402,2472,0.0844149249487132,63002.354064941406,0.15804432,0.0580403099802996,0.42757252,5348,0.1304922907595315 -5482.808529615402,2.2834372520446777,59086.25704813004,74758,0,59086.25704813004,0.24742267,2472,0.082221274348506,64574.48491859436,0.198389,0.0739333898908075,0.4172008,5348,0.1273159099027776 -5615.292538166046,2.351891279220581,60526.15056729317,76591,0,60526.15056729317,0.24104773,2472,0.0800276237482989,66147.00818371773,0.19787984,0.07074304504809609,0.41123158,5348,0.1253656699846491 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/measurements.csv deleted file mode 100644 index 9abf08fbe..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/measurements.csv +++ /dev/null @@ -1,818 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,41.225094,32.500175,,,,,,,,,,,,,, -1,,,31.98225,1.3614880623081345,31.08705,1.0585747801152765,5348.0,31.23744,1.10235004976337,2472.0,35.09761643409729,170.93748545646667,35.09761643409729,135.83981275558472,0.0,0.0 -100,1.2876325,5.811515,,,,,,,,,,,,,, -200,3.134496,5.7787166,,,,,,,,,,,,,, -300,4.4562244,5.6196585,,,,,,,,,,,,,, -400,4.854106,5.709814,,,,,,,,,,,,,, -500,1.6085688,5.5744348,,,,,,,,,,,,,, -600,1.5013859,5.52662,,,,,,,,,,,,,, -700,1.9265677,5.531251,,,,,,,,,,,,,, -800,2.6148605,5.51348,,,,,,,,,,,,,, -900,1.0518553,5.5081143,,,,,,,,,,,,,, -1000,0.45549318,5.514954,,,,,,,,,,,,,, -1100,0.5545591,5.555872,,,,,,,,,,,,,, -1200,1.7306392,5.5091414,,,,,,,,,,,,,, -1300,2.6999168,5.5501966,,,,,,,,,,,,,, -1400,1.5529575,5.5828,,,,,,,,,,,,,, -1500,0.71682805,5.4950104,,,,,,,,,,,,,, -1600,1.9646274,5.4907722,,,,,,,,,,,,,, -1700,0.879482,5.483246,,,,,,,,,,,,,, -1791,,,6.036521,0.9413900245298448,6.0377555,0.8966179750330672,5348.0,6.0074077,0.899579550301627,2472.0,1476.0195994377136,1720.848159313202,1476.0195994377136,244.7291920185089,0.0279905796051025,0.0 -1800,1.5402199,5.4785385,,,,,,,,,,,,,, -1900,0.27711138,5.4861174,,,,,,,,,,,,,, -2000,2.9759154,5.5092793,,,,,,,,,,,,,, -2100,1.152269,5.499337,,,,,,,,,,,,,, -2200,2.198111,5.4851923,,,,,,,,,,,,,, -2300,1.8656409,5.4653664,,,,,,,,,,,,,, -2400,0.5751339,5.4730806,,,,,,,,,,,,,, -2500,1.7430346,5.4899125,,,,,,,,,,,,,, -2600,1.3307626,5.485405,,,,,,,,,,,,,, -2700,1.1162171,5.4720926,,,,,,,,,,,,,, -2800,1.571605,5.4115787,,,,,,,,,,,,,, -2900,1.6550376,5.1972547,,,,,,,,,,,,,, -3000,0.72926605,4.8039947,,,,,,,,,,,,,, -3100,2.2000787,4.539493,,,,,,,,,,,,,, -3200,0.892168,4.3794036,,,,,,,,,,,,,, -3300,1.193834,4.249994,,,,,,,,,,,,,, -3400,0.47885427,4.103858,,,,,,,,,,,,,, -3500,1.0440756,3.947807,,,,,,,,,,,,,, -3600,2.0673862,3.9305491,,,,,,,,,,,,,, -3623,,,6.206665,0.9382353276158586,6.116562,0.8959904225841644,5348.0,5.9581666,0.8983405439441026,2472.0,2916.5635409355164,3269.98179101944,2916.5635409355164,353.1840178966522,0.0870876312255859,0.0 -3700,1.3138648,3.8190417,,,,,,,,,,,,,, -3800,1.1777034,3.748208,,,,,,,,,,,,,, -3900,0.6602829,3.7286022,,,,,,,,,,,,,, -4000,0.525895,3.5891275,,,,,,,,,,,,,, -4100,1.3864722,3.6336207,,,,,,,,,,,,,, -4200,1.4028348,3.532399,,,,,,,,,,,,,, -4300,0.5234232,3.4425738,,,,,,,,,,,,,, -4400,1.8323972,3.4627874,,,,,,,,,,,,,, -4500,0.8264446,3.3930526,,,,,,,,,,,,,, -4600,1.1651124,3.3789668,,,,,,,,,,,,,, -4700,0.6925115,3.2717655,,,,,,,,,,,,,, -4800,0.69962496,3.2199364,,,,,,,,,,,,,, -4900,0.9930499,3.2837806,,,,,,,,,,,,,, -5000,1.1754122,3.2547805,,,,,,,,,,,,,, -5100,0.9367324,3.2067444,,,,,,,,,,,,,, -5200,0.9595026,3.2182403,,,,,,,,,,,,,, -5300,1.5892767,3.1499534,,,,,,,,,,,,,, -5400,1.4208711,3.1131744,,,,,,,,,,,,,, -5452,,,12.30888,0.939860390191516,12.248377,0.8966179750330672,5348.0,12.230537,0.899579550301627,2472.0,4356.7338581085205,4818.730646371841,4356.7338581085205,461.63093519210815,0.1405568122863769,0.0 -5500,1.1445731,3.0848086,,,,,,,,,,,,,, -5600,1.1607579,3.0976229,,,,,,,,,,,,,, -5700,1.2232727,3.0623665,,,,,,,,,,,,,, -5800,0.93951166,3.0205336,,,,,,,,,,,,,, -5900,1.1034682,2.984054,,,,,,,,,,,,,, -6000,0.63513917,2.951689,,,,,,,,,,,,,, -6100,1.9355108,2.958307,,,,,,,,,,,,,, -6200,0.54365,2.9145818,,,,,,,,,,,,,, -6300,1.3776506,2.9723353,,,,,,,,,,,,,, -6400,0.7954254,2.8820412,,,,,,,,,,,,,, -6500,0.6892909,2.8726087,,,,,,,,,,,,,, -6600,1.5265375,2.828589,,,,,,,,,,,,,, -6700,1.3336536,2.9123352,,,,,,,,,,,,,, -6800,1.2913073,2.8740952,,,,,,,,,,,,,, -6900,1.3217506,2.7961388,,,,,,,,,,,,,, -7000,0.7536093,2.8529134,,,,,,,,,,,,,, -7100,1.0503409,2.83411,,,,,,,,,,,,,, -7200,1.8485929,2.758929,,,,,,,,,,,,,, -7255,,,9.199602,0.9406809116337576,9.2184,0.8966179750330672,5348.0,9.05541,0.899579550301627,2472.0,5796.618039608002,6367.059968471527,5796.618039608002,569.9528002738953,0.1885228157043457,0.0 -7300,1.1890254,2.6771924,,,,,,,,,,,,,, -7400,1.1786077,2.756926,,,,,,,,,,,,,, -7500,1.1619486,2.7168357,,,,,,,,,,,,,, -7600,0.84725946,2.7374005,,,,,,,,,,,,,, -7700,1.0352936,2.6631217,,,,,,,,,,,,,, -7800,1.524048,2.7350018,,,,,,,,,,,,,, -7900,5.965471,2.630027,,,,,,,,,,,,,, -8000,0.6886199,2.6954317,,,,,,,,,,,,,, -8100,1.039105,2.6757202,,,,,,,,,,,,,, -8200,1.2603432,2.614648,,,,,,,,,,,,,, -8300,1.0041118,2.5903084,,,,,,,,,,,,,, -8400,0.879227,2.621099,,,,,,,,,,,,,, -8500,0.8534201,2.5603657,,,,,,,,,,,,,, -8600,0.9452957,2.5938578,,,,,,,,,,,,,, -8700,1.2612625,2.4889565,,,,,,,,,,,,,, -8800,1.7143097,2.5207567,,,,,,,,,,,,,, -8900,1.6636199,2.50091,,,,,,,,,,,,,, -9000,0.81705576,2.4538448,,,,,,,,,,,,,, -9072,,,1.9409115,0.5133557669993037,1.8373336,0.4703650424322002,5348.0,1.4274611,0.403936384132594,2472.0,7236.628142595291,7937.760040521622,7236.628142595291,700.5129299163818,0.2425758838653564,0.0 -9100,0.7113036,2.4476967,,,,,,,,,,,,,, -9200,0.6106435,2.5016603,,,,,,,,,,,,,, -9300,0.5690951,2.4743652,,,,,,,,,,,,,, -9400,1.0486487,2.443634,,,,,,,,,,,,,, -9500,1.5932828,2.4568884,,,,,,,,,,,,,, -9600,0.7262469,2.3445637,,,,,,,,,,,,,, -9700,1.1869574,2.3941183,,,,,,,,,,,,,, -9800,0.9094433,2.3261585,,,,,,,,,,,,,, -9900,1.0279367,2.3791451,,,,,,,,,,,,,, -10000,0.7884894,2.3879097,,,,,,,,,,,,,, -10100,0.5291112,2.4176707,,,,,,,,,,,,,, -10200,0.62162596,2.3645868,,,,,,,,,,,,,, -10300,0.9036465,2.2969127,,,,,,,,,,,,,, -10400,0.9579142,2.3195815,,,,,,,,,,,,,, -10500,1.4865558,2.430054,,,,,,,,,,,,,, -10600,1.0665923,2.3454556,,,,,,,,,,,,,, -10700,1.0800734,2.281501,,,,,,,,,,,,,, -10800,1.1983714,2.263661,,,,,,,,,,,,,, -10900,0.7339348,2.339212,,,,,,,,,,,,,, -10901,,,1.5330179,0.4447662118572818,1.5998214,0.4368054683954932,5348.0,1.2287663,0.3752767452724798,2472.0,8677.431626081467,9511.06428551674,8677.431626081467,832.8874080181122,0.2927138805389404,0.0 -11000,1.0690856,2.2845345,,,,,,,,,,,,,, -11100,1.3879399,2.2752917,,,,,,,,,,,,,, -11200,0.59105784,2.2434509,,,,,,,,,,,,,, -11300,0.7910381,2.2353826,,,,,,,,,,,,,, -11400,1.0794815,2.2683206,,,,,,,,,,,,,, -11500,1.8573544,2.3310406,,,,,,,,,,,,,, -11600,0.9951427,2.2814262,,,,,,,,,,,,,, -11700,0.7393371,2.252943,,,,,,,,,,,,,, -11800,0.60616475,2.2919865,,,,,,,,,,,,,, -11900,1.11839,2.2511213,,,,,,,,,,,,,, -12000,0.72721297,2.2747316,,,,,,,,,,,,,, -12100,1.3002368,2.2079494,,,,,,,,,,,,,, -12200,1.0870413,2.2198803,,,,,,,,,,,,,, -12300,0.99360365,2.2298558,,,,,,,,,,,,,, -12400,0.80717045,2.2048862,,,,,,,,,,,,,, -12500,1.3012866,2.2025461,,,,,,,,,,,,,, -12600,0.8627979,2.2667606,,,,,,,,,,,,,, -12700,0.74313086,2.2140076,,,,,,,,,,,,,, -12737,,,1.5478851,0.4351763341723465,1.5292115,0.4163086399490234,5348.0,1.1772243,0.3542745719334593,2472.0,10117.52890110016,11083.44418811798,10117.52890110016,965.0422995090483,0.3434820175170898,0.0 -12800,0.9277956,2.2008486,,,,,,,,,,,,,, -12900,1.0962507,2.2641697,,,,,,,,,,,,,, -13000,0.91904306,2.1956987,,,,,,,,,,,,,, -13100,1.1817074,2.2129066,,,,,,,,,,,,,, -13200,1.7177947,2.2078614,,,,,,,,,,,,,, -13300,1.6135741,2.2446845,,,,,,,,,,,,,, -13400,0.8214699,2.1926663,,,,,,,,,,,,,, -13500,1.3857505,2.1871707,,,,,,,,,,,,,, -13600,0.6157643,2.2074037,,,,,,,,,,,,,, -13700,0.8258518,2.1294365,,,,,,,,,,,,,, -13800,0.6759928,2.1139941,,,,,,,,,,,,,, -13900,0.94419533,2.153175,,,,,,,,,,,,,, -14000,0.9427561,2.1630197,,,,,,,,,,,,,, -14100,0.50127715,2.1231253,,,,,,,,,,,,,, -14200,0.86952806,2.1769495,,,,,,,,,,,,,, -14300,1.4329954,2.1299825,,,,,,,,,,,,,, -14400,1.0764427,2.1708302,,,,,,,,,,,,,, -14500,0.8985648,2.053501,,,,,,,,,,,,,, -14544,,,1.2687769,0.3805742176611467,1.2760807,0.363468723751412,5348.0,0.9283827,0.2951881867852863,2472.0,11557.644403457642,12654.127409219742,11557.644403457642,1095.4794552326202,0.3978815078735351,0.0 -14600,0.7274431,2.1654973,,,,,,,,,,,,,, -14700,0.77866435,2.1474416,,,,,,,,,,,,,, -14800,0.6985087,2.166964,,,,,,,,,,,,,, -14900,0.62695235,2.0970237,,,,,,,,,,,,,, -15000,1.1251326,2.1216383,,,,,,,,,,,,,, -15100,0.9835066,2.1187816,,,,,,,,,,,,,, -15200,0.5577106,2.138873,,,,,,,,,,,,,, -15300,0.8352132,2.051838,,,,,,,,,,,,,, -15400,0.8326207,2.0462906,,,,,,,,,,,,,, -15500,0.80296576,2.1468036,,,,,,,,,,,,,, -15600,1.1368577,2.1010342,,,,,,,,,,,,,, -15700,0.7185681,2.0804236,,,,,,,,,,,,,, -15800,0.7331739,2.0554307,,,,,,,,,,,,,, -15900,0.6745645,2.0812974,,,,,,,,,,,,,, -16000,0.80077827,2.0729997,,,,,,,,,,,,,, -16100,0.64374036,2.0846095,,,,,,,,,,,,,, -16200,0.963808,2.0995715,,,,,,,,,,,,,, -16300,0.54944706,2.029422,,,,,,,,,,,,,, -16362,,,1.0706007,0.3385405215249989,1.1228197,0.3314442395512517,5348.0,0.8068257,0.2676659963845388,2472.0,12997.894040346146,14226.949261188509,12997.894040346146,1227.9253115653992,0.447887659072876,0.0 -16400,0.883458,2.1479228,,,,,,,,,,,,,, -16500,0.5965188,2.0545294,,,,,,,,,,,,,, -16600,0.5270486,2.0929656,,,,,,,,,,,,,, -16700,0.7662306,2.0365717,,,,,,,,,,,,,, -16800,1.0389216,2.0670342,,,,,,,,,,,,,, -16900,0.5044258,2.0840952,,,,,,,,,,,,,, -17000,0.5808797,2.0532491,,,,,,,,,,,,,, -17100,0.9204271,2.0894535,,,,,,,,,,,,,, -17200,0.52741915,2.0896401,,,,,,,,,,,,,, -17300,0.9459673,2.0094345,,,,,,,,,,,,,, -17400,0.990825,2.1051717,,,,,,,,,,,,,, -17500,0.8886914,1.992318,,,,,,,,,,,,,, -17600,0.97381246,2.0924199,,,,,,,,,,,,,, -17700,0.48578763,2.0989885,,,,,,,,,,,,,, -17800,1.1865833,2.0671885,,,,,,,,,,,,,, -17900,1.0931777,2.010439,,,,,,,,,,,,,, -18000,0.7513918,2.0277762,,,,,,,,,,,,,, -18100,0.5604362,2.001535,,,,,,,,,,,,,, -18176,,,0.96476525,0.3123660160216631,1.0905648,0.3211427247361866,5348.0,0.7684886,0.2561696423130827,2472.0,14437.993826389313,15798.103556871414,14437.993826389313,1358.848022699356,0.5031416416168213,0.0 -18200,0.92087764,2.0341175,,,,,,,,,,,,,, -18300,0.6971056,1.9663981,,,,,,,,,,,,,, -18400,0.6177012,2.0406961,,,,,,,,,,,,,, -18500,0.8592016,1.9885019,,,,,,,,,,,,,, -18600,0.5001023,1.9635255,,,,,,,,,,,,,, -18700,0.9326684,2.0545409,,,,,,,,,,,,,, -18800,0.5305278,1.955017,,,,,,,,,,,,,, -18900,0.5246686,1.9786359,,,,,,,,,,,,,, -19000,0.8414839,1.9728408,,,,,,,,,,,,,, -19100,1.0870076,2.007898,,,,,,,,,,,,,, -19200,0.8143119,1.9574041,,,,,,,,,,,,,, -19300,0.96921223,2.034297,,,,,,,,,,,,,, -19400,0.78350747,2.0137033,,,,,,,,,,,,,, -19500,0.6640954,1.9743041,,,,,,,,,,,,,, -19600,0.7920206,1.9935442,,,,,,,,,,,,,, -19700,0.66235673,1.9849741,,,,,,,,,,,,,, -19800,0.53035784,1.9527038,,,,,,,,,,,,,, -19900,0.8736927,1.9881564,,,,,,,,,,,,,, -20000,0.71667576,1.9443765,,,,,,,,,,,,,, -20012,,,0.8736541,0.2842924222705095,1.0145409,0.3012444847794394,5348.0,0.715538,0.2363861637519549,2472.0,15878.171899795532,17368.276427268982,15878.171899795532,1488.713681936264,0.5551145076751709,0.0 -20100,0.733755,1.9524573,,,,,,,,,,,,,, -20200,0.9373197,1.9939977,,,,,,,,,,,,,, -20300,0.582002,1.9002634,,,,,,,,,,,,,, -20400,0.71322596,1.9766234,,,,,,,,,,,,,, -20500,0.6978313,2.0485165,,,,,,,,,,,,,, -20600,0.5298064,1.8955266,,,,,,,,,,,,,, -20700,0.85201555,1.9386063,,,,,,,,,,,,,, -20800,0.6392401,2.0386384,,,,,,,,,,,,,, -20900,0.6782697,1.9508015,,,,,,,,,,,,,, -21000,1.3331383,1.9749277,,,,,,,,,,,,,, -21100,0.85468173,1.97976,,,,,,,,,,,,,, -21200,0.7756649,1.9812647,,,,,,,,,,,,,, -21300,0.5597084,1.9091904,,,,,,,,,,,,,, -21400,0.73987037,2.0088367,,,,,,,,,,,,,, -21500,1.0603302,1.9364369,,,,,,,,,,,,,, -21600,0.8134308,1.9781073,,,,,,,,,,,,,, -21700,0.5237167,1.9720789,,,,,,,,,,,,,, -21800,0.796098,1.9888867,,,,,,,,,,,,,, -21832,,,0.9753044,0.3114894714417061,1.0329162,0.3040057155546115,5348.0,0.72071874,0.2370158227205329,2472.0,17318.47605085373,18940.11133503914,17318.47605085373,1620.1200077533722,0.6046969890594482,0.0 -21900,0.7067625,1.9166948,,,,,,,,,,,,,, -22000,0.663868,2.024743,,,,,,,,,,,,,, -22100,0.75130206,1.9888513,,,,,,,,,,,,,, -22200,0.56893516,1.9371508,,,,,,,,,,,,,, -22300,0.6336576,2.0078218,,,,,,,,,,,,,, -22400,0.6958771,1.8998661,,,,,,,,,,,,,, -22500,0.53049624,1.9233098,,,,,,,,,,,,,, -22600,0.5536155,1.9738153,,,,,,,,,,,,,, -22700,0.996366,1.9903758,,,,,,,,,,,,,, -22800,1.0772432,1.9623405,,,,,,,,,,,,,, -22900,0.94753087,1.8665501,,,,,,,,,,,,,, -23000,0.60780215,1.9622685,,,,,,,,,,,,,, -23100,0.8643584,1.9167986,,,,,,,,,,,,,, -23200,0.8648585,1.9446955,,,,,,,,,,,,,, -23300,0.83568996,1.9775516,,,,,,,,,,,,,, -23400,0.7233279,1.878557,,,,,,,,,,,,,, -23500,0.92674184,1.889207,,,,,,,,,,,,,, -23600,0.7581581,1.9306539,,,,,,,,,,,,,, -23646,,,0.8934462,0.2851090476244203,0.9852448,0.291782924780598,5348.0,0.6862627,0.2272256413381268,2472.0,18758.69697093964,20512.817051410675,18758.69697093964,1752.4782774448397,0.654998779296875,0.0 -23700,0.7325564,1.9452311,,,,,,,,,,,,,, -23800,0.73263943,1.9448577,,,,,,,,,,,,,, -23900,0.916017,1.9149895,,,,,,,,,,,,,, -24000,0.6854348,1.9234529,,,,,,,,,,,,,, -24100,0.5725834,1.8831141,,,,,,,,,,,,,, -24200,0.6924946,1.9231508,,,,,,,,,,,,,, -24300,0.78126806,2.0534499,,,,,,,,,,,,,, -24400,0.9317804,1.9691049,,,,,,,,,,,,,, -24500,0.5611037,1.859214,,,,,,,,,,,,,, -24600,0.64446294,1.9208055,,,,,,,,,,,,,, -24700,0.864995,1.9002886,,,,,,,,,,,,,, -24800,0.61345786,1.976547,,,,,,,,,,,,,, -24900,0.6408587,1.8831028,,,,,,,,,,,,,, -25000,0.75694036,1.9287674,,,,,,,,,,,,,, -25100,0.8293033,1.8837843,,,,,,,,,,,,,, -25200,0.6882867,1.8803264,,,,,,,,,,,,,, -25300,0.6244137,1.8476324,,,,,,,,,,,,,, -25400,0.7303742,1.8598258,,,,,,,,,,,,,, -25454,,,0.8983038,0.2872240604474295,0.96987027,0.2899871593114301,5348.0,0.66787153,0.2244835780878679,2472.0,20198.639166355133,22085.138073682785,20198.639166355133,1884.7277166843407,0.7082116603851318,0.0 -25500,0.87811524,1.9058248,,,,,,,,,,,,,, -25600,1.0300561,1.8548677,,,,,,,,,,,,,, -25700,0.68546116,1.8555615,,,,,,,,,,,,,, -25800,0.85194963,1.8688738,,,,,,,,,,,,,, -25900,0.76366526,1.8974906,,,,,,,,,,,,,, -26000,0.77064836,1.8965248,,,,,,,,,,,,,, -26100,0.6148084,1.8447698,,,,,,,,,,,,,, -26200,0.5682249,1.9478513,,,,,,,,,,,,,, -26300,0.7564064,1.877763,,,,,,,,,,,,,, -26400,0.88353425,1.8563342,,,,,,,,,,,,,, -26500,0.58964556,1.8640379,,,,,,,,,,,,,, -26600,0.9186136,1.8770467,,,,,,,,,,,,,, -26700,0.71265465,1.9373465,,,,,,,,,,,,,, -26800,0.5427478,1.8401138,,,,,,,,,,,,,, -26900,0.5669902,1.8389596,,,,,,,,,,,,,, -27000,0.64256626,1.8796399,,,,,,,,,,,,,, -27100,0.6985716,1.8568535,,,,,,,,,,,,,, -27200,0.5987121,1.893422,,,,,,,,,,,,,, -27290,,,0.7946374,0.2607233290857483,0.9428824,0.2765865008640914,5348.0,0.64174926,0.2136778177238844,2472.0,21638.938226938248,23658.000257968903,21638.938226938248,2017.1603062152865,0.7608418464660645,0.0 -27300,0.5988986,1.8659858,,,,,,,,,,,,,, -27400,0.5792663,1.8854675,,,,,,,,,,,,,, -27500,0.59426105,1.8301251,,,,,,,,,,,,,, -27600,0.85069597,1.9690338,,,,,,,,,,,,,, -27700,0.52094007,1.864356,,,,,,,,,,,,,, -27800,0.6372333,1.8955877,,,,,,,,,,,,,, -27900,0.5851693,1.8120717,,,,,,,,,,,,,, -28000,0.85721344,1.8965886,,,,,,,,,,,,,, -28100,0.64689565,1.9051713,,,,,,,,,,,,,, -28200,0.8079185,1.8334858,,,,,,,,,,,,,, -28300,0.59787524,1.8802158,,,,,,,,,,,,,, -28400,0.8802003,1.8960003,,,,,,,,,,,,,, -28500,0.6709821,1.8885006,,,,,,,,,,,,,, -28600,0.6316101,1.9278462,,,,,,,,,,,,,, -28700,1.0063771,1.8823844,,,,,,,,,,,,,, -28800,0.5796868,1.9036418,,,,,,,,,,,,,, -28900,0.59153336,1.8739307,,,,,,,,,,,,,, -29000,0.67077154,1.818793,,,,,,,,,,,,,, -29100,0.5683894,1.793526,,,,,,,,,,,,,, -29119,,,0.7964099,0.2679116939109455,0.9283386,0.2796470258841249,5348.0,0.6265644,0.2123778766274653,2472.0,23079.49067544937,25229.84836292267,23079.49067544937,2148.321827411652,0.8176655769348145,0.0 -29200,0.97836465,1.940029,,,,,,,,,,,,,, -29300,0.67227036,1.8270792,,,,,,,,,,,,,, -29400,0.66857314,1.8049059,,,,,,,,,,,,,, -29500,0.70467144,1.8151708,,,,,,,,,,,,,, -29600,0.68646157,1.8687397,,,,,,,,,,,,,, -29700,0.8371204,1.8715708,,,,,,,,,,,,,, -29800,0.7512271,1.7829206,,,,,,,,,,,,,, -29900,0.8811357,1.8127241,,,,,,,,,,,,,, -30000,0.7354343,1.8015255,,,,,,,,,,,,,, -30100,0.59695816,1.8631405,,,,,,,,,,,,,, -30200,0.69753206,1.8204045,,,,,,,,,,,,,, -30300,0.95260346,1.8412261,,,,,,,,,,,,,, -30400,0.5894862,1.7890315,,,,,,,,,,,,,, -30500,0.6227132,1.8461522,,,,,,,,,,,,,, -30600,1.1767358,1.7942579,,,,,,,,,,,,,, -30700,0.65035987,1.8587469,,,,,,,,,,,,,, -30800,0.9218606,1.8000944,,,,,,,,,,,,,, -30900,0.7229354,1.8137647,,,,,,,,,,,,,, -30930,,,0.81224275,0.2673072213330903,0.9076635,0.2705909613138051,5348.0,0.614802,0.2055531858712652,2472.0,24519.64860892296,26800.87813210488,24519.64860892296,2279.065259218216,0.8696136474609375,0.0 -31000,0.63472646,1.7937536,,,,,,,,,,,,,, -31100,0.64737874,1.8025692,,,,,,,,,,,,,, -31200,0.64648885,1.7535285,,,,,,,,,,,,,, -31300,0.78756404,1.8623856,,,,,,,,,,,,,, -31400,0.891692,1.8702484,,,,,,,,,,,,,, -31500,0.66419035,1.843884,,,,,,,,,,,,,, -31600,0.6984339,1.780856,,,,,,,,,,,,,, -31700,0.7448504,1.8195761,,,,,,,,,,,,,, -31800,0.8245868,1.8001021,,,,,,,,,,,,,, -31900,0.74595046,1.9529351,,,,,,,,,,,,,, -32000,0.6138212,1.7499522,,,,,,,,,,,,,, -32100,0.6059198,1.803312,,,,,,,,,,,,,, -32200,0.7206038,1.8375812,,,,,,,,,,,,,, -32300,0.830439,1.7718784,,,,,,,,,,,,,, -32400,0.6341564,1.7525138,,,,,,,,,,,,,, -32500,0.754513,1.82752,,,,,,,,,,,,,, -32600,0.8705972,1.7983078,,,,,,,,,,,,,, -32700,0.60676533,1.8355343,,,,,,,,,,,,,, -32775,,,0.5687455,0.1970680333286381,0.8738757,0.2610714733966034,5348.0,0.5819507,0.1957630044888591,2472.0,25960.040961503983,28383.27775406837,25960.040961503983,2420.939267396927,0.9254543781280518,0.0 -32800,0.5627086,1.8417262,,,,,,,,,,,,,, -32900,0.8420296,1.7414149,,,,,,,,,,,,,, -33000,0.69334435,1.8237659,,,,,,,,,,,,,, -33100,0.8371852,1.826173,,,,,,,,,,,,,, -33200,0.8519983,1.7972277,,,,,,,,,,,,,, -33300,0.85963494,1.8113601,,,,,,,,,,,,,, -33400,0.64327514,1.758538,,,,,,,,,,,,,, -33500,0.73234105,1.7804615,,,,,,,,,,,,,, -33600,0.62562716,1.7466068,,,,,,,,,,,,,, -33700,0.89562595,1.7758464,,,,,,,,,,,,,, -33800,0.6369485,1.8048226,,,,,,,,,,,,,, -33900,0.75614995,1.7680963,,,,,,,,,,,,,, -34000,0.64243597,1.7422247,,,,,,,,,,,,,, -34100,0.60403013,1.7464801,,,,,,,,,,,,,, -34200,0.7033512,1.7933753,,,,,,,,,,,,,, -34300,0.610686,1.7281463,,,,,,,,,,,,,, -34400,0.89769626,1.7510704,,,,,,,,,,,,,, -34500,0.8103094,1.8002957,,,,,,,,,,,,,, -34600,0.6810785,1.8092382,,,,,,,,,,,,,, -34601,,,0.52129424,0.1864079379382489,0.86678445,0.2615735153557257,5348.0,0.5760456,0.1954177076351227,2472.0,27400.785620450974,29959.528460025787,27400.785620450974,2556.310618162155,0.9855234622955322,0.0 -34700,0.63832563,1.7371479,,,,,,,,,,,,,, -34800,0.6635974,1.7728769,,,,,,,,,,,,,, -34900,0.5309125,1.7929869,,,,,,,,,,,,,, -35000,0.8565769,1.7896662,,,,,,,,,,,,,, -35100,0.788858,1.7515057,,,,,,,,,,,,,, -35200,0.60626745,1.7532187,,,,,,,,,,,,,, -35300,0.82130146,1.7723806,,,,,,,,,,,,,, -35400,1.0252415,1.7421416,,,,,,,,,,,,,, -35500,0.6569538,1.7628075,,,,,,,,,,,,,, -35600,0.5411737,1.7200253,,,,,,,,,,,,,, -35700,0.7764942,1.7941998,,,,,,,,,,,,,, -35800,1.0322847,1.7958523,,,,,,,,,,,,,, -35900,0.7923087,1.8014077,,,,,,,,,,,,,, -36000,0.7409377,1.7378942,,,,,,,,,,,,,, -36100,0.8035399,1.7658702,,,,,,,,,,,,,, -36200,0.8531659,1.7105488,,,,,,,,,,,,,, -36300,0.88184315,1.8197256,,,,,,,,,,,,,, -36400,0.7148254,1.7368731,,,,,,,,,,,,,, -36435,,,0.4877663,0.1712832078885945,0.82872856,0.2493024513164119,5348.0,0.5492272,0.1873946336806613,2472.0,28841.119877815247,31533.52166032791,28841.119877815247,2689.823467731476,1.0541713237762451,0.0 -36500,0.62908524,1.7336336,,,,,,,,,,,,,, -36600,0.72996175,1.7327782,,,,,,,,,,,,,, -36700,0.6065952,1.7452564,,,,,,,,,,,,,, -36800,0.6301127,1.7309241,,,,,,,,,,,,,, -36900,0.7288125,1.8090786,,,,,,,,,,,,,, -37000,0.5830707,1.7816745,,,,,,,,,,,,,, -37100,0.83009404,1.7323103,,,,,,,,,,,,,, -37200,1.1182089,1.7615491,,,,,,,,,,,,,, -37300,0.7072661,1.7366619,,,,,,,,,,,,,, -37400,0.84292924,1.7167035,,,,,,,,,,,,,, -37500,0.6170663,1.6925004,,,,,,,,,,,,,, -37600,0.6627536,1.7975357,,,,,,,,,,,,,, -37700,0.58219403,1.7369092,,,,,,,,,,,,,, -37800,0.6648789,1.6986697,,,,,,,,,,,,,, -37900,0.69548744,1.7603046,,,,,,,,,,,,,, -38000,0.88758165,1.7330185,,,,,,,,,,,,,, -38100,0.7969849,1.7646399,,,,,,,,,,,,,, -38200,0.910226,1.6841954,,,,,,,,,,,,,, -38254,,,0.45671743,0.1629685196202695,0.79590267,0.2387885341340259,5348.0,0.5308841,0.1792496902484106,2472.0,30281.715041160583,33108.547865867615,30281.715041160583,2824.1240861415863,1.1081647872924805,0.0 -38300,0.66242015,1.6937134,,,,,,,,,,,,,, -38400,0.75524676,1.7319063,,,,,,,,,,,,,, -38500,0.7222726,1.7429188,,,,,,,,,,,,,, -38600,0.81705666,1.7276897,,,,,,,,,,,,,, -38700,0.6553376,1.720086,,,,,,,,,,,,,, -38800,0.6345532,1.7351943,,,,,,,,,,,,,, -38900,0.7064388,1.7845703,,,,,,,,,,,,,, -39000,0.76690716,1.6959156,,,,,,,,,,,,,, -39100,0.5950131,1.6411927,,,,,,,,,,,,,, -39200,0.61772686,1.7023877,,,,,,,,,,,,,, -39300,0.7158096,1.6752105,,,,,,,,,,,,,, -39400,0.6718128,1.6637263,,,,,,,,,,,,,, -39500,0.7404291,1.7221044,,,,,,,,,,,,,, -39600,0.7842389,1.632128,,,,,,,,,,,,,, -39700,0.78145397,1.6854352,,,,,,,,,,,,,, -39800,0.5784468,1.6975895,,,,,,,,,,,,,, -39900,0.70777506,1.6968069,,,,,,,,,,,,,, -40000,0.64739406,1.7033192,,,,,,,,,,,,,, -40068,,,0.44774008,0.1597666977808232,0.7690591,0.2315668536451142,5348.0,0.50365597,0.1681595677695854,2472.0,31721.85123872757,34681.24870944023,31721.85123872757,2956.557309150696,1.1653995513916016,0.0 -40100,0.7740686,1.6737742,,,,,,,,,,,,,, -40200,0.8184048,1.708366,,,,,,,,,,,,,, -40300,0.6988506,1.6906719,,,,,,,,,,,,,, -40400,0.6806684,1.71553,,,,,,,,,,,,,, -40500,0.65146416,1.6546414,,,,,,,,,,,,,, -40600,0.6874283,1.6665571,,,,,,,,,,,,,, -40700,0.645569,1.6312302,,,,,,,,,,,,,, -40800,0.80256766,1.6944089,,,,,,,,,,,,,, -40900,0.732359,1.7583268,,,,,,,,,,,,,, -41000,0.7687641,1.7172738,,,,,,,,,,,,,, -41100,0.5859225,1.6667281,,,,,,,,,,,,,, -41200,0.6761442,1.6463299,,,,,,,,,,,,,, -41300,0.6366219,1.6706222,,,,,,,,,,,,,, -41400,0.6496557,1.6923261,,,,,,,,,,,,,, -41500,0.8215423,1.6658977,,,,,,,,,,,,,, -41600,0.66296464,1.658484,,,,,,,,,,,,,, -41700,0.647141,1.6804631,,,,,,,,,,,,,, -41800,0.64221007,1.6368028,,,,,,,,,,,,,, -41881,,,0.42459473,0.1542755805534277,0.7605355,0.2281491064618593,5348.0,0.5015915,0.1699469867771616,2472.0,33161.960060596466,36255.45990538597,33161.960060596466,3090.524173974991,1.2250707149505615,0.0 -41900,0.7631117,1.7229337,,,,,,,,,,,,,, -42000,0.59608954,1.6589702,,,,,,,,,,,,,, -42100,0.7191452,1.6760477,,,,,,,,,,,,,, -42200,0.83297324,1.7002033,,,,,,,,,,,,,, -42300,0.77923715,1.6314259,,,,,,,,,,,,,, -42400,0.7734921,1.6591038,,,,,,,,,,,,,, -42500,0.7031919,1.6302546,,,,,,,,,,,,,, -42600,0.6456909,1.6262785,,,,,,,,,,,,,, -42700,0.62452215,1.6803907,,,,,,,,,,,,,, -42800,0.661934,1.6300466,,,,,,,,,,,,,, -42900,0.62311774,1.6760638,,,,,,,,,,,,,, -43000,0.6024751,1.724629,,,,,,,,,,,,,, -43100,0.7688785,1.6879073,,,,,,,,,,,,,, -43200,0.7068452,1.6788453,,,,,,,,,,,,,, -43300,0.697504,1.6364791,,,,,,,,,,,,,, -43400,0.8312285,1.6024987,,,,,,,,,,,,,, -43500,0.7092365,1.672339,,,,,,,,,,,,,, -43600,0.7190229,1.6509563,,,,,,,,,,,,,, -43700,0.60995054,1.5978733,,,,,,,,,,,,,, -43718,,,0.4951775,0.1707194429413918,0.7597602,0.2293269741351844,5348.0,0.492929,0.1679767635529015,2472.0,34602.280463695526,37828.44002938271,34602.280463695526,3223.048075437546,1.28309965133667,0.0 -43800,0.6716858,1.6043882,,,,,,,,,,,,,, -43900,0.7389186,1.6313243,,,,,,,,,,,,,, -44000,0.6939202,1.6389624,,,,,,,,,,,,,, -44100,0.686428,1.6273029,,,,,,,,,,,,,, -44200,0.8806577,1.6201007,,,,,,,,,,,,,, -44300,0.78639954,1.628456,,,,,,,,,,,,,, -44400,0.57917696,1.576263,,,,,,,,,,,,,, -44500,0.76286036,1.5914004,,,,,,,,,,,,,, -44600,0.80443555,1.6358702,,,,,,,,,,,,,, -44700,0.68117464,1.6103654,,,,,,,,,,,,,, -44800,0.6928775,1.6695268,,,,,,,,,,,,,, -44900,0.68970776,1.6177307,,,,,,,,,,,,,, -45000,0.6874708,1.6065978,,,,,,,,,,,,,, -45100,0.87025166,1.6342564,,,,,,,,,,,,,, -45200,0.6269151,1.6594983,,,,,,,,,,,,,, -45300,0.71443087,1.5957223,,,,,,,,,,,,,, -45400,0.7163579,1.6695062,,,,,,,,,,,,,, -45500,0.80698806,1.598367,,,,,,,,,,,,,, -45542,,,0.40693375,0.1485157320888572,0.7193552,0.2186682371568977,5348.0,0.46914527,0.1600146243373347,2472.0,36042.54721474648,39400.75337386131,36042.54721474648,3354.965092897415,1.3371038436889648,0.0 -45600,0.5566494,1.5464422,,,,,,,,,,,,,, -45700,0.63961256,1.5717813,,,,,,,,,,,,,, -45800,0.74641716,1.6064774,,,,,,,,,,,,,, -45900,0.74726754,1.6709344,,,,,,,,,,,,,, -46000,0.69234234,1.6592971,,,,,,,,,,,,,, -46100,0.73657984,1.6736579,,,,,,,,,,,,,, -46200,0.6397485,1.6411434,,,,,,,,,,,,,, -46300,0.71711296,1.6337045,,,,,,,,,,,,,, -46400,0.65501267,1.537223,,,,,,,,,,,,,, -46500,0.67702717,1.5804814,,,,,,,,,,,,,, -46600,0.70975363,1.5820656,,,,,,,,,,,,,, -46700,0.705534,1.6180828,,,,,,,,,,,,,, -46800,0.817078,1.6204133,,,,,,,,,,,,,, -46900,0.7475502,1.5244274,,,,,,,,,,,,,, -47000,0.6627005,1.5726502,,,,,,,,,,,,,, -47100,0.7010861,1.572298,,,,,,,,,,,,,, -47200,0.8141754,1.5411289,,,,,,,,,,,,,, -47300,0.6693883,1.5407547,,,,,,,,,,,,,, -47348,,,0.38006338,0.1380358563428216,0.69912624,0.2128078627494521,5348.0,0.4517525,0.1527633904088721,2472.0,37482.9290099144,40973.24291205406,37482.9290099144,3486.938924312592,1.3953540325164795,0.0 -47400,0.79531825,1.5204444,,,,,,,,,,,,,, -47500,0.86189777,1.5817138,,,,,,,,,,,,,, -47600,0.63794214,1.5468701,,,,,,,,,,,,,, -47700,0.69943297,1.5182409,,,,,,,,,,,,,, -47800,0.5969493,1.5253582,,,,,,,,,,,,,, -47900,0.6901073,1.6412411,,,,,,,,,,,,,, -48000,0.7097162,1.5830939,,,,,,,,,,,,,, -48100,0.6426713,1.5917857,,,,,,,,,,,,,, -48200,0.67550415,1.5861758,,,,,,,,,,,,,, -48300,0.6988833,1.5854367,,,,,,,,,,,,,, -48400,0.8321968,1.5681691,,,,,,,,,,,,,, -48500,0.8068834,1.4902618,,,,,,,,,,,,,, -48600,0.7482664,1.5404172,,,,,,,,,,,,,, -48700,0.66604775,1.611933,,,,,,,,,,,,,, -48800,0.86581236,1.5035459,,,,,,,,,,,,,, -48900,0.79671395,1.5597581,,,,,,,,,,,,,, -49000,0.69722736,1.543098,,,,,,,,,,,,,, -49100,0.74120575,1.6089245,,,,,,,,,,,,,, -49157,,,0.3494525,0.1273569686936377,0.67146903,0.2035297411587514,5348.0,0.43440887,0.1462230617675136,2472.0,38923.53028583527,42546.768450737,38923.53028583527,3619.7294538021088,1.4531559944152832,0.0 -49200,0.6415959,1.494449,,,,,,,,,,,,,, -49300,0.69639546,1.4976252,,,,,,,,,,,,,, -49400,0.6943454,1.5552273,,,,,,,,,,,,,, -49500,0.80149835,1.5300051,,,,,,,,,,,,,, -49600,0.65654236,1.4955972,,,,,,,,,,,,,, -49700,0.74570715,1.5585729,,,,,,,,,,,,,, -49800,0.73792684,1.4897081,,,,,,,,,,,,,, -49900,0.7948426,1.5724667,,,,,,,,,,,,,, -50000,0.7498459,1.5896134,,,,,,,,,,,,,, -50100,0.895532,1.5474874,,,,,,,,,,,,,, -50200,0.84178436,1.6105927,,,,,,,,,,,,,, -50300,0.776554,1.5133001,,,,,,,,,,,,,, -50400,0.69566506,1.5432377,,,,,,,,,,,,,, -50500,0.6862639,1.4729534,,,,,,,,,,,,,, -50600,0.70175767,1.5241427,,,,,,,,,,,,,, -50700,0.79295295,1.5377496,,,,,,,,,,,,,, -50800,0.65049297,1.4907974,,,,,,,,,,,,,, -50900,0.80239683,1.4948374,,,,,,,,,,,,,, -50996,,,0.33166713,0.1234891205972291,0.6463883,0.1973024899350242,5348.0,0.4144835,0.1407795584262588,2472.0,40363.68802213669,44119.26902413368,40363.68802213669,3751.937073707581,1.511134386062622,0.0 -51000,0.77335334,1.4886378,,,,,,,,,,,,,, -51100,0.71470773,1.5325958,,,,,,,,,,,,,, -51200,0.7375504,1.524286,,,,,,,,,,,,,, -51300,0.8370721,1.5265418,,,,,,,,,,,,,, -51400,0.7414057,1.5871708,,,,,,,,,,,,,, -51500,0.67673564,1.5230922,,,,,,,,,,,,,, -51600,0.6402762,1.4791832,,,,,,,,,,,,,, -51700,0.7294104,1.529914,,,,,,,,,,,,,, -51800,0.6966087,1.4953692,,,,,,,,,,,,,, -51900,0.77918184,1.5623724,,,,,,,,,,,,,, -52000,0.7887438,1.5947524,,,,,,,,,,,,,, -52100,0.7768247,1.5020533,,,,,,,,,,,,,, -52200,0.75598377,1.5303652,,,,,,,,,,,,,, -52300,0.7362757,1.5868287,,,,,,,,,,,,,, -52400,0.6298641,1.5245653,,,,,,,,,,,,,, -52500,0.8180415,1.4694691,,,,,,,,,,,,,, -52600,0.69686633,1.478281,,,,,,,,,,,,,, -52700,0.703503,1.4761161,,,,,,,,,,,,,, -52800,0.77305377,1.5409405,,,,,,,,,,,,,, -52829,,,0.3443252,0.1249640157371176,0.6351694,0.1935178659354876,5348.0,0.40218315,0.1346860845367944,2472.0,41803.74843025208,45691.219561100006,41803.74843025208,3883.693731546402,1.566605806350708,0.0 -52900,0.766528,1.4580071,,,,,,,,,,,,,, -53000,0.8737831,1.4909774,,,,,,,,,,,,,, -53100,0.90373963,1.580046,,,,,,,,,,,,,, -53200,0.79303247,1.5697695,,,,,,,,,,,,,, -53300,0.7736021,1.5036831,,,,,,,,,,,,,, -53400,0.8760794,1.5105128,,,,,,,,,,,,,, -53500,0.8656949,1.5073221,,,,,,,,,,,,,, -53600,0.78179824,1.483009,,,,,,,,,,,,,, -53700,0.8163442,1.4986175,,,,,,,,,,,,,, -53800,0.7107992,1.5199724,,,,,,,,,,,,,, -53900,0.7477528,1.5187151,,,,,,,,,,,,,, -54000,0.81681496,1.4589154,,,,,,,,,,,,,, -54100,0.6768469,1.5215098,,,,,,,,,,,,,, -54200,0.6843895,1.4896421,,,,,,,,,,,,,, -54300,0.70321393,1.4653609,,,,,,,,,,,,,, -54400,0.73349434,1.5158153,,,,,,,,,,,,,, -54500,0.7593411,1.4576639,,,,,,,,,,,,,, -54600,0.7850019,1.4399401,,,,,,,,,,,,,, -54640,,,0.33093026,0.1206635390150001,0.61256367,0.1880533323035036,5348.0,0.38517037,0.131781528649483,2472.0,43244.20497202873,47266.09590554237,43244.20497202873,4017.97895359993,1.6243414878845217,0.0 -54700,0.64920473,1.4918662,,,,,,,,,,,,,, -54800,0.81060416,1.4813159,,,,,,,,,,,,,, -54900,0.8091494,1.5223461,,,,,,,,,,,,,, -55000,0.732401,1.4523604,,,,,,,,,,,,,, -55100,0.75392866,1.4073205,,,,,,,,,,,,,, -55200,0.77708405,1.4745382,,,,,,,,,,,,,, -55300,0.84446436,1.4612721,,,,,,,,,,,,,, -55400,0.8581723,1.5041263,,,,,,,,,,,,,, -55500,0.78754205,1.4485492,,,,,,,,,,,,,, -55600,0.95516837,1.485112,,,,,,,,,,,,,, -55700,0.7620931,1.4879183,,,,,,,,,,,,,, -55800,0.7400791,1.415705,,,,,,,,,,,,,, -55900,0.8082041,1.4715141,,,,,,,,,,,,,, -56000,0.8755672,1.4781841,,,,,,,,,,,,,, -56100,0.78784174,1.4693481,,,,,,,,,,,,,, -56200,0.90338784,1.4487518,,,,,,,,,,,,,, -56300,0.72703004,1.459029,,,,,,,,,,,,,, -56400,0.7772079,1.4870517,,,,,,,,,,,,,, -56468,,,0.31670836,0.1151420291651815,0.5975832,0.1822508858144182,5348.0,0.3706889,0.1253630694859139,2472.0,44684.27011537552,48840.0756649971,44684.27011537552,4151.758671045303,1.6820077896118164,0.0 -56500,0.74926794,1.4123596,,,,,,,,,,,,,, -56600,0.74250317,1.4360331,,,,,,,,,,,,,, -56700,0.78390896,1.4661565,,,,,,,,,,,,,, -56800,0.7197212,1.4192778,,,,,,,,,,,,,, -56900,0.7364143,1.4538171,,,,,,,,,,,,,, -57000,0.9115261,1.3923582,,,,,,,,,,,,,, -57100,0.76648,1.4403806,,,,,,,,,,,,,, -57200,0.87843674,1.3914653,,,,,,,,,,,,,, -57300,0.70302725,1.4102943,,,,,,,,,,,,,, -57400,0.97452176,1.4781396,,,,,,,,,,,,,, -57500,0.77974004,1.4671943,,,,,,,,,,,,,, -57600,0.67942154,1.4005804,,,,,,,,,,,,,, -57700,0.68158144,1.3775944,,,,,,,,,,,,,, -57800,0.8526475,1.4222628,,,,,,,,,,,,,, -57900,0.8419723,1.3962625,,,,,,,,,,,,,, -58000,0.6945808,1.3665127,,,,,,,,,,,,,, -58100,0.764386,1.4563774,,,,,,,,,,,,,, -58200,0.7454016,1.425332,,,,,,,,,,,,,, -58294,,,0.2850827,0.1052168831373953,0.5681,0.1727313978972165,5348.0,0.35391665,0.1200414356224483,2472.0,46124.81386065483,50413.50209951401,46124.81386065483,4284.507459640503,1.7384145259857178,0.0 -58300,0.9150175,1.4069508,,,,,,,,,,,,,, -58400,0.7279937,1.416198,,,,,,,,,,,,,, -58500,0.6753889,1.4174738,,,,,,,,,,,,,, -58600,0.70375043,1.395094,,,,,,,,,,,,,, -58700,0.87395954,1.4455819,,,,,,,,,,,,,, -58800,0.7321965,1.4077461,,,,,,,,,,,,,, -58900,0.75334734,1.3358747,,,,,,,,,,,,,, -59000,0.9364174,1.4400644,,,,,,,,,,,,,, -59100,0.9315891,1.3925896,,,,,,,,,,,,,, -59200,0.7348055,1.414604,,,,,,,,,,,,,, -59300,0.8215118,1.3973218,,,,,,,,,,,,,, -59400,0.78362167,1.414869,,,,,,,,,,,,,, -59500,0.7716734,1.4051771,,,,,,,,,,,,,, -59600,0.92777616,1.428411,,,,,,,,,,,,,, -59700,0.81931895,1.3932647,,,,,,,,,,,,,, -59800,0.808023,1.3471926,,,,,,,,,,,,,, -59900,0.77192414,1.3969432,,,,,,,,,,,,,, -60000,0.9737228,1.4166385,,,,,,,,,,,,,, -60100,0.73668486,1.426182,,,,,,,,,,,,,, -60134,,,0.2588146,0.0954069945424412,0.5504874,0.168261293530417,5348.0,0.33913124,0.1140495196311417,2472.0,47564.94284963608,51986.8414990902,47564.94284963608,4417.584407091141,1.7928149700164795,0.0 -60200,0.7888806,1.3994329,,,,,,,,,,,,,, -60300,0.7703463,1.440927,,,,,,,,,,,,,, -60400,0.88736075,1.3708259,,,,,,,,,,,,,, -60500,0.7824625,1.4022267,,,,,,,,,,,,,, -60600,0.7852739,1.3855478,,,,,,,,,,,,,, -60700,0.79545856,1.3385154,,,,,,,,,,,,,, -60800,0.81801754,1.3804072,,,,,,,,,,,,,, -60900,0.67749214,1.3005776,,,,,,,,,,,,,, -61000,0.7727444,1.3554401,,,,,,,,,,,,,, -61100,1.0011687,1.3671488,,,,,,,,,,,,,, -61200,0.87054497,1.411682,,,,,,,,,,,,,, -61300,0.759155,1.3722906,,,,,,,,,,,,,, -61400,0.8108783,1.3747905,,,,,,,,,,,,,, -61500,0.8727451,1.3551958,,,,,,,,,,,,,, -61600,0.7933469,1.380457,,,,,,,,,,,,,, -61700,0.9814176,1.4108473,,,,,,,,,,,,,, -61800,0.869885,1.344394,,,,,,,,,,,,,, -61900,0.8145936,1.353908,,,,,,,,,,,,,, -61956,,,0.24200241,0.0895358863809282,0.52719754,0.1611265049190457,5348.0,0.32282758,0.1082607194361505,2472.0,49004.83890962601,53560.50705432892,49004.83890962601,4551.215311527252,1.854458570480347,0.0 -62000,0.7503039,1.33853,,,,,,,,,,,,,, -62100,0.8249803,1.3226124,,,,,,,,,,,,,, -62200,0.7878189,1.2785226,,,,,,,,,,,,,, -62300,0.9531099,1.2945633,,,,,,,,,,,,,, -62400,0.7569578,1.3335894,,,,,,,,,,,,,, -62500,0.8542624,1.3994225,,,,,,,,,,,,,, -62600,0.8326847,1.3575944,,,,,,,,,,,,,, -62700,0.9098115,1.3560966,,,,,,,,,,,,,, -62800,0.8380085,1.3162163,,,,,,,,,,,,,, -62900,0.8231059,1.3293849,,,,,,,,,,,,,, -63000,0.8870033,1.3003418,,,,,,,,,,,,,, -63100,0.7701644,1.3248657,,,,,,,,,,,,,, -63200,0.9162298,1.3178426,,,,,,,,,,,,,, -63300,0.9269153,1.3818265,,,,,,,,,,,,,, -63400,0.7778029,1.3967911,,,,,,,,,,,,,, -63500,0.8230495,1.317511,,,,,,,,,,,,,, -63600,0.8909978,1.3369913,,,,,,,,,,,,,, -63700,0.81714666,1.317986,,,,,,,,,,,,,, -63784,,,0.23850976,0.0884111856169462,0.50763637,0.1557971364299024,5348.0,0.30937302,0.1051733593321552,2472.0,50445.49860596657,55134.57596540451,50445.49860596657,4684.490033864975,1.9127497673034668,0.0 -63800,0.8543606,1.3120738,,,,,,,,,,,,,, -63900,1.0232078,1.303557,,,,,,,,,,,,,, -64000,0.9219251,1.3205187,,,,,,,,,,,,,, -64100,0.9369575,1.3030925,,,,,,,,,,,,,, -64200,0.84338206,1.3169327,,,,,,,,,,,,,, -64300,0.8271688,1.2490828,,,,,,,,,,,,,, -64400,0.91610867,1.326088,,,,,,,,,,,,,, -64500,0.76180786,1.2758503,,,,,,,,,,,,,, -64600,0.82055604,1.3602974,,,,,,,,,,,,,, -64700,0.8585578,1.350231,,,,,,,,,,,,,, -64800,0.84645355,1.342821,,,,,,,,,,,,,, -64900,0.8167995,1.3135023,,,,,,,,,,,,,, -65000,1.0892726,1.3204447,,,,,,,,,,,,,, -65100,0.88173527,1.2304342,,,,,,,,,,,,,, -65200,1.0550809,1.3679478,,,,,,,,,,,,,, -65300,1.1373695,1.3075973,,,,,,,,,,,,,, -65400,0.90493256,1.3753438,,,,,,,,,,,,,, -65500,1.0897366,1.2836204,,,,,,,,,,,,,, -65600,0.83701783,1.3030915,,,,,,,,,,,,,, -65607,,,0.21402435,0.0783076178700069,0.4928836,0.1519352752058854,5348.0,0.29772234,0.0994048707167956,2472.0,51885.54228544235,56706.69957733154,51885.54228544235,4816.434291601181,1.971311092376709,0.0 -65700,1.0334952,1.3187971,,,,,,,,,,,,,, -65800,0.802937,1.2774479,,,,,,,,,,,,,, -65900,0.8901277,1.3219742,,,,,,,,,,,,,, -66000,0.9146964,1.2358139,,,,,,,,,,,,,, -66100,0.9386261,1.3145791,,,,,,,,,,,,,, -66200,0.88387114,1.2586929,,,,,,,,,,,,,, -66300,0.9139986,1.3248757,,,,,,,,,,,,,, -66400,0.8274451,1.2415714,,,,,,,,,,,,,, -66500,0.9575238,1.2146274,,,,,,,,,,,,,, -66600,0.85809946,1.3101269,,,,,,,,,,,,,, -66700,0.9597189,1.2600168,,,,,,,,,,,,,, -66800,0.9461825,1.344988,,,,,,,,,,,,,, -66900,1.1453052,1.2505383,,,,,,,,,,,,,, -67000,0.84756744,1.229948,,,,,,,,,,,,,, -67100,1.0277003,1.3226284,,,,,,,,,,,,,, -67200,1.4420797,1.3420737,,,,,,,,,,,,,, -67300,0.91237414,1.2794331,,,,,,,,,,,,,, -67400,0.93940675,1.269898,,,,,,,,,,,,,, -67448,,,0.21193847,0.0781770019697072,0.4739441,0.144308099288452,5348.0,0.28477463,0.095972213759064,2472.0,53325.713799238205,58280.50552463532,53325.713799238205,4949.930072784424,2.032210111618042,0.0 -67500,0.8214328,1.2452147,,,,,,,,,,,,,, -67600,0.9507411,1.24451,,,,,,,,,,,,,, -67700,0.9361618,1.2821078,,,,,,,,,,,,,, -67800,1.0688052,1.2246767,,,,,,,,,,,,,, -67900,0.9282828,1.289877,,,,,,,,,,,,,, -68000,1.2059009,1.2700648,,,,,,,,,,,,,, -68100,1.0121107,1.2003678,,,,,,,,,,,,,, -68200,0.99555486,1.2609977,,,,,,,,,,,,,, -68300,1.1058258,1.2229316,,,,,,,,,,,,,, -68400,1.0955651,1.2482903,,,,,,,,,,,,,, -68500,1.0023713,1.2458477,,,,,,,,,,,,,, -68600,0.8382067,1.196644,,,,,,,,,,,,,, -68700,0.8945985,1.2789664,,,,,,,,,,,,,, -68800,1.1744251,1.2680485,,,,,,,,,,,,,, -68900,0.86205214,1.2252642,,,,,,,,,,,,,, -69000,0.9338068,1.258596,,,,,,,,,,,,,, -69100,0.9017036,1.2120882,,,,,,,,,,,,,, -69200,0.8692149,1.1713551,,,,,,,,,,,,,, -69280,,,0.22361787,0.0777870612249214,0.45318562,0.1395869739420913,5348.0,0.27006716,0.0910974346474925,2472.0,54765.77884674072,59855.67841768265,54765.77884674072,5084.901441574097,2.091064214706421,0.0 -69300,0.8183003,1.1968938,,,,,,,,,,,,,, -69400,0.92965084,1.2382421,,,,,,,,,,,,,, -69500,0.921414,1.1798193,,,,,,,,,,,,,, -69600,1.0082663,1.2163419,,,,,,,,,,,,,, -69700,1.0204512,1.203012,,,,,,,,,,,,,, -69800,1.0567508,1.21553,,,,,,,,,,,,,, -69900,0.8894067,1.2258381,,,,,,,,,,,,,, -70000,0.9682375,1.2557657,,,,,,,,,,,,,, -70100,0.9430733,1.2077224,,,,,,,,,,,,,, -70200,0.97250545,1.20237,,,,,,,,,,,,,, -70300,0.85558915,1.1618328,,,,,,,,,,,,,, -70400,1.0653046,1.1693124,,,,,,,,,,,,,, -70500,0.9825495,1.2174361,,,,,,,,,,,,,, -70600,1.1816872,1.2004316,,,,,,,,,,,,,, -70700,1.0089349,1.1914245,,,,,,,,,,,,,, -70800,0.9000536,1.2189088,,,,,,,,,,,,,, -70900,0.9366719,1.1591343,,,,,,,,,,,,,, -71000,1.2091864,1.2424039,,,,,,,,,,,,,, -71090,,,0.15905462,0.0586980822030711,0.44548163,0.1354644370854533,5348.0,0.26090908,0.0878678934860764,2472.0,56206.375772714615,61429.3595867157,56206.375772714615,5217.849287033081,2.1515071392059326,0.0 -71100,0.9790452,1.1960709,,,,,,,,,,,,,, -71200,0.98266864,1.1541039,,,,,,,,,,,,,, -71300,1.0241747,1.2458389,,,,,,,,,,,,,, -71400,1.0095437,1.1863893,,,,,,,,,,,,,, -71500,0.94828296,1.1604629,,,,,,,,,,,,,, -71600,1.2069082,1.2026066,,,,,,,,,,,,,, -71700,1.2916622,1.1420941,,,,,,,,,,,,,, -71800,1.108529,1.1981484,,,,,,,,,,,,,, -71900,1.1372008,1.2277657,,,,,,,,,,,,,, -72000,1.2299207,1.2137961,,,,,,,,,,,,,, -72100,0.88513625,1.1382256,,,,,,,,,,,,,, -72200,0.86668986,1.1384954,,,,,,,,,,,,,, -72300,1.1512867,1.182714,,,,,,,,,,,,,, -72400,1.022433,1.2226564,,,,,,,,,,,,,, -72500,1.0185477,1.1843097,,,,,,,,,,,,,, -72600,0.8978033,1.1615087,,,,,,,,,,,,,, -72700,1.0399857,1.1590511,,,,,,,,,,,,,, -72800,1.0496914,1.1771244,,,,,,,,,,,,,, -72900,1.0828512,1.1625313,,,,,,,,,,,,,, -72927,,,0.15804432,0.0580403099802996,0.42757252,0.1304922907595315,5348.0,0.25125402,0.0844149249487132,2472.0,57646.29920887947,63002.354064941406,57646.29920887947,5350.778831243515,2.2165493965148926,0.0 -73000,1.0883545,1.2416182,,,,,,,,,,,,,, -73100,1.349008,1.1230824,,,,,,,,,,,,,, -73200,1.070612,1.1967225,,,,,,,,,,,,,, -73300,0.9835518,1.1885734,,,,,,,,,,,,,, -73400,1.1293851,1.1426284,,,,,,,,,,,,,, -73500,1.0302378,1.1087742,,,,,,,,,,,,,, -73600,0.93055326,1.155047,,,,,,,,,,,,,, -73700,1.037638,1.1801916,,,,,,,,,,,,,, -73800,1.1514587,1.1724801,,,,,,,,,,,,,, -73900,1.1497732,1.1186233,,,,,,,,,,,,,, -74000,1.2115002,1.1122439,,,,,,,,,,,,,, -74100,1.033188,1.2038083,,,,,,,,,,,,,, -74200,1.0615423,1.1244098,,,,,,,,,,,,,, -74300,1.1613094,1.1210524,,,,,,,,,,,,,, -74400,1.0634835,1.1245662,,,,,,,,,,,,,, -74500,1.0533849,1.0806867,,,,,,,,,,,,,, -74600,1.2932583,1.1454453,,,,,,,,,,,,,, -74700,0.9698255,1.2008038,,,,,,,,,,,,,, -74758,,,0.198389,0.0739333898908075,0.4172008,0.1273159099027776,5348.0,0.24742267,0.082221274348506,2472.0,59086.25704813004,64574.48491859436,59086.25704813004,5482.808529615402,2.2834372520446777,0.0 -74800,0.92676455,1.1561073,,,,,,,,,,,,,, -74900,1.0312532,1.1778498,,,,,,,,,,,,,, -75000,0.96812415,1.1402808,,,,,,,,,,,,,, -75100,0.9839331,1.1572361,,,,,,,,,,,,,, -75200,0.96379906,1.163681,,,,,,,,,,,,,, -75300,0.97907907,1.1650239,,,,,,,,,,,,,, -75400,1.1350244,1.1079968,,,,,,,,,,,,,, -75500,1.1709363,1.1350012,,,,,,,,,,,,,, -75600,1.2556056,1.157522,,,,,,,,,,,,,, -75700,0.99732506,1.113299,,,,,,,,,,,,,, -75800,1.1169585,1.1595615,,,,,,,,,,,,,, -75900,1.1396351,1.1562247,,,,,,,,,,,,,, -76000,1.0311893,1.1116809,,,,,,,,,,,,,, -76100,1.0520691,1.1791714,,,,,,,,,,,,,, -76200,1.1893959,1.1594223,,,,,,,,,,,,,, -76300,1.1344719,1.1295619,,,,,,,,,,,,,, -76400,1.0663881,1.098636,,,,,,,,,,,,,, -76500,0.9373577,1.0787045,,,,,,,,,,,,,, -76591,,,0.19787984,0.070743045048096,0.41123158,0.1253656699846491,5348.0,0.24104773,0.0800276237482989,2472.0,60526.15056729317,66147.00818371773,60526.15056729317,5615.292538166046,2.351891279220581,0.0 -76600,1.0736743,1.0920305,,,,,,,,,,,,,, -76700,1.1077404,1.1122149,,,,,,,,,,,,,, -76800,1.0691088,1.1454505,,,,,,,,,,,,,, -76900,1.075644,1.1704329,,,,,,,,,,,,,, -77000,1.0688636,1.1882432,,,,,,,,,,,,,, -77100,0.9801665,1.1244105,,,,,,,,,,,,,, -77200,0.92448163,1.1227921,,,,,,,,,,,,,, -77276,,,,,,,,,,,61068.48909711838,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 564405202..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,44 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -133.37171983718872,0.0,36.564847230911255,1,0,36.564847230911255,31.237438,2472,1.1023297381837385,169.93663430213928,32.575264,1.3520278272013182,31.087046,5348,1.0585651254622166 -257.0766489505768,0.0305192470550537,1476.6239953041077,1784,0,1476.6239953041077,3.5804627,2472,0.5944590010765137,1733.8049068450928,4.0700784,0.6632959745030377,3.8945782,5348,0.6469486469003737 -386.9912831783295,0.0886681079864502,2917.59091591835,3613,0,2917.59091591835,0.6869836,2472,0.220746247435663,3304.8210270404816,0.928824,0.2913230185931187,0.9600111,5348,0.2800428666595866 -519.0642786026001,0.1420328617095947,4357.46843123436,5441,0,4357.46843123436,0.55067784,2472,0.1833323177543517,4876.902319431305,0.6443802,0.2176905895966201,0.83076894,5348,0.2493507245817121 -650.7103319168091,0.1908364295959472,5797.962051391602,7246,0,5797.962051391602,0.45089793,2472,0.1516259419495054,6449.165390253067,0.5867469,0.1951140565509825,0.7002347,5348,0.209592863280458 -782.7996170520782,0.2458448410034179,7238.451512336731,9070,0,7238.451512336731,0.405954,2472,0.1373875246277903,8021.87784409523,0.4774239,0.1646485756372009,0.6440632,5348,0.1956032709964567 -914.9536378383636,0.2943975925445556,8678.661425828934,10892,0,8678.661425828934,0.38973525,2472,0.1317409054902199,9594.367964029312,0.4616692,0.1635571770115438,0.62427807,5348,0.1891153441401083 -1045.6514217853546,0.3471841812133789,10119.200992584229,12728,0,10119.200992584229,0.36318412,2472,0.1225397599171287,11165.735782384872,0.42614308,0.1467377049180327,0.5865799,5348,0.1779545652026994 -1175.3477528095243,0.3970916271209717,11559.555562257769,14539,0,11559.555562257769,0.3582644,2472,0.1215444925151829,12735.91278553009,0.43106797,0.1506862616304639,0.58034813,5348,0.1754926286723886 -1307.0945451259613,0.4598226547241211,13000.484461307526,16355,0,13000.484461307526,0.3553724,2472,0.1196758271890805,14308.726927280426,0.45272714,0.1547102812009115,0.5875047,5348,0.1765160218967531 -1438.896250963211,0.5163545608520508,14440.638258457184,18164,0,14440.638258457184,0.34942833,2472,0.1183961976722929,15880.81642794609,0.43035367,0.1507336027182378,0.57992405,5348,0.1750292053255066 -1567.792672872543,0.5708911418914795,15881.029235124588,20003,0,15881.029235124588,0.32704383,2472,0.1088903784047285,17450.236847639084,0.42334485,0.1497484324507529,0.537798,5348,0.1622754086331907 -1700.3270015716553,0.6231732368469238,17321.37030696869,21819,0,17321.37030696869,0.31939596,2472,0.1072857636138362,19023.241703748703,0.35540873,0.1246373141148812,0.52952045,5348,0.1576218658582504 -1831.173261165619,0.6752007007598877,18761.539265871048,23623,0,18761.539265871048,0.30010766,2472,0.1024516076615278,20594.3846950531,0.35186237,0.1239148128052089,0.5021615,5348,0.1515297797773637 -1962.2557699680328,0.7344212532043457,20202.135964870453,25432,0,20202.135964870453,0.3096689,2472,0.1035281213819998,22166.200580835342,0.36941093,0.1303682806206556,0.5116533,5348,0.1539530976954343 -2095.230221509933,0.7898995876312256,21642.26782298088,27267,0,21642.26782298088,0.29647806,2472,0.0982064875185343,23739.43984937668,0.35059008,0.1234827597570218,0.49381366,5348,0.1492995549204939 -2225.6262097358704,0.8439171314239502,23082.85145688057,29093,0,23082.85145688057,0.28592485,2472,0.0953019316312229,25310.55036687851,0.31172824,0.1129408188082691,0.48430818,5348,0.1454956216148372 -2354.6739501953125,0.8952944278717041,24523.22280049324,30900,0,24523.22280049324,0.27936098,2472,0.0943269758089086,26880.09740138054,0.27994934,0.1000941861460741,0.4781998,5348,0.1433522886355079 -2485.4901201725006,0.9515819549560548,25964.00806760788,32700,0,25964.00806760788,0.27516112,2472,0.0914224199215973,28451.8306787014,0.31628087,0.114678250442593,0.46717182,5348,0.1399731600644931 -2617.0825748443604,1.0097923278808594,27403.91512775421,34518,0,27403.91512775421,0.26905754,2472,0.090264659882599,30023.46605920792,0.29610607,0.1073693235276106,0.47222495,5348,0.139036658717669 -2747.890196323395,1.0664589405059814,28844.04262661934,36337,0,28844.04262661934,0.26998752,2472,0.0893506387991794,31594.535401821136,0.31838846,0.1097850960381802,0.4541082,5348,0.1351844521467121 -2878.6182096004486,1.1174898147583008,30284.337661266327,38133,0,30284.337661266327,0.27322602,2472,0.0915036662401234,33165.68579244614,0.27280307,0.0998227803789436,0.47273842,5348,0.1405041659827954 -3008.73100566864,1.1783719062805176,31724.60327744484,39958,0,31724.60327744484,0.2446584,2472,0.0824853248837162,34736.20393896103,0.23454855,0.0854467962307291,0.43293995,5348,0.1286289427189434 -3138.707051753998,1.2340517044067385,33165.11921596527,41786,0,33165.11921596527,0.24258377,2472,0.0816728616984543,36306.827905893326,0.25626796,0.0934509031760979,0.42910972,5348,0.1268042132905954 -3280.400631427765,1.2985985279083252,34605.279014348984,43625,0,34605.279014348984,0.23896857,2472,0.0785448784351959,37888.82444357872,0.17323135,0.0646907749398744,0.4196442,5348,0.1238305801481023 -3414.3309786319733,1.3558061122894287,36045.82629442215,45444,0,36045.82629442215,0.23613536,2472,0.0795198342575102,39463.43499088287,0.1462238,0.0550591908416807,0.4182151,5348,0.1212624424341311 -3549.1361298561096,1.411712408065796,37485.83829760552,47262,0,37485.83829760552,0.22954966,2472,0.0757012572867791,41038.38579106331,0.15300256,0.0583098080857786,0.41734868,5348,0.1231837183930795 -3683.813942432404,1.4672369956970217,38926.85561108589,49083,0,38926.85561108589,0.22368722,2472,0.0727357666605731,42614.214336395264,0.14905518,0.0560649479656953,0.39878234,5348,0.1168309566795717 -3815.917148351669,1.5273003578186035,40367.21039605141,50921,0,40367.21039605141,0.21463418,2472,0.0714358255641541,44186.81035447121,0.13042888,0.0508238532410719,0.38371804,5348,0.1114436602720681 -3948.819551467896,1.5827083587646484,41807.4494600296,52735,0,41807.4494600296,0.21119802,2472,0.0702983771047874,45760.08290052414,0.1234728,0.048450242173839,0.38684124,5348,0.1116077893740888 -4080.977082490921,1.6382241249084473,43247.61201763153,54541,0,43247.61201763153,0.20237793,2472,0.0666626043507403,47332.53596878052,0.13599047,0.0502915692023852,0.36117485,5348,0.1053901928034216 -4214.5278577804565,1.6951165199279783,44687.94367861748,56349,0,44687.94367861748,0.1969371,2472,0.0647533158653748,48906.5517706871,0.12227172,0.0438607010829609,0.3628258,5348,0.1053033009258812 -4347.578888177872,1.7528400421142578,46128.45850539208,58181,0,46128.45850539208,0.19438767,2472,0.0630471431763248,50480.25352835655,0.11364357,0.0437140421904964,0.35202444,5348,0.1014028210896241 -4479.026484251022,1.8070552349090576,47569.04715514183,60009,0,47569.04715514183,0.19161133,2472,0.0622346799910629,52052.42256188393,0.09262989,0.0358730629181133,0.34524652,5348,0.0995298183959759 -4610.411191225052,1.8653712272644043,49009.17750668526,61811,0,49009.17750668526,0.18546745,2472,0.0596551093778563,53624.07222151756,0.08787621,0.0344356772032324,0.34125715,5348,0.0982264402328702 -4743.393732786179,1.925506830215454,50449.34701442719,63638,0,50449.34701442719,0.18322575,2472,0.0593301241037515,55197.36210608482,0.09771829,0.0379921302402006,0.33584002,5348,0.0955810652944186 -4875.155463933945,1.9842355251312256,51889.73348236084,65470,0,51889.73348236084,0.1784024,2472,0.0578473787906485,56769.64642548561,0.08789349,0.0339759476169358,0.32956982,5348,0.0938046091313708 -5009.825822591782,2.044980049133301,53329.91587328911,67306,0,53329.91587328911,0.1745069,2472,0.0571161619239128,58344.639607191086,0.08603654,0.0328172183847109,0.32403156,5348,0.0927425972947662 -5142.46967458725,2.107982397079468,54770.38207864761,69115,0,54770.38207864761,0.17174639,2472,0.0546584607884955,59917.89059305191,0.07548663,0.0299299109679863,0.31870773,5348,0.0905799550093167 -5275.688220500946,2.1689844131469727,56210.31742787361,70933,0,56210.31742787361,0.17032003,2472,0.0542522291958645,61491.18255186081,0.06823383,0.0263253865710832,0.31399453,5348,0.0886779883564884 -5408.91233420372,2.2319982051849365,57650.18681359291,72746,0,57650.18681359291,0.16882586,2472,0.0543740986736538,63064.41569805145,0.07446662,0.0274383379387984,0.31198248,5348,0.0877994149280245 -5540.094013929367,2.298288345336914,59090.40604901314,74586,0,59090.40604901314,0.16789155,2472,0.0534397660106026,64635.95995926857,0.07089892,0.0265824493211326,0.3090179,5348,0.0869304961526207 -5671.698009729385,2.359781265258789,60530.30942058563,76400,0,60530.30942058563,0.16666134,2472,0.0529116649401824,66207.60625338554,0.072326496,0.02635920509936258,0.3084427,5348,0.08698842407098101 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/measurements.csv deleted file mode 100644 index 1752249aa..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/measurements.csv +++ /dev/null @@ -1,816 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,31.602928,33.00271,,,,,,,,,,,,,, -1,,,32.575264,1.3520278272013182,31.087046,1.0585651254622166,5348.0,31.237438,1.1023297381837385,2472.0,36.564847230911255,169.93663430213928,36.564847230911255,133.37171983718872,0.0,0.0 -100,1.1111318,6.025469,,,,,,,,,,,,,, -200,0.53578866,5.8327827,,,,,,,,,,,,,, -300,1.0401261,5.7791805,,,,,,,,,,,,,, -400,1.0732434,5.8030543,,,,,,,,,,,,,, -500,2.5742862,5.8061514,,,,,,,,,,,,,, -600,0.5501649,5.758834,,,,,,,,,,,,,, -700,0.30672976,5.793676,,,,,,,,,,,,,, -800,2.9358513,5.600902,,,,,,,,,,,,,, -900,0.8309426,5.465377,,,,,,,,,,,,,, -1000,2.4962606,5.2033587,,,,,,,,,,,,,, -1100,0.81594056,3.8966248,,,,,,,,,,,,,, -1200,1.3236127,3.4096313,,,,,,,,,,,,,, -1300,1.3257629,3.132804,,,,,,,,,,,,,, -1400,2.0051126,3.1232078,,,,,,,,,,,,,, -1500,1.2721086,2.9802155,,,,,,,,,,,,,, -1600,0.78395563,2.7373302,,,,,,,,,,,,,, -1700,0.79064494,2.756326,,,,,,,,,,,,,, -1784,,,4.0700784,0.6632959745030377,3.8945782,0.6469486469003737,5348.0,3.5804627,0.5944590010765137,2472.0,1476.6239953041077,1733.8049068450928,1476.6239953041077,257.0766489505768,0.0305192470550537,0.0 -1800,0.9904402,2.6650915,,,,,,,,,,,,,, -1900,0.6191192,2.3814921,,,,,,,,,,,,,, -2000,1.2429163,2.351413,,,,,,,,,,,,,, -2100,0.9183548,2.3108675,,,,,,,,,,,,,, -2200,0.60882187,2.269346,,,,,,,,,,,,,, -2300,0.7068114,2.2023413,,,,,,,,,,,,,, -2400,0.7689561,2.226229,,,,,,,,,,,,,, -2500,0.5156541,2.1340647,,,,,,,,,,,,,, -2600,0.60388833,2.169797,,,,,,,,,,,,,, -2700,0.5384337,2.0663137,,,,,,,,,,,,,, -2800,0.6053293,2.0926137,,,,,,,,,,,,,, -2900,0.6014496,2.0394945,,,,,,,,,,,,,, -3000,0.662145,2.107972,,,,,,,,,,,,,, -3100,0.6100946,1.9963528,,,,,,,,,,,,,, -3200,0.55652857,1.9719387,,,,,,,,,,,,,, -3300,0.71054596,1.9391037,,,,,,,,,,,,,, -3400,0.7574368,2.0459647,,,,,,,,,,,,,, -3500,0.5858683,1.9827647,,,,,,,,,,,,,, -3600,0.45648336,1.8262678,,,,,,,,,,,,,, -3613,,,0.928824,0.2913230185931187,0.9600111,0.2800428666595866,5348.0,0.6869836,0.220746247435663,2472.0,2917.59091591835,3304.8210270404816,2917.59091591835,386.9912831783295,0.0886681079864502,0.0 -3700,0.52851355,1.9340434,,,,,,,,,,,,,, -3800,0.46137178,1.901358,,,,,,,,,,,,,, -3900,0.4855455,1.910009,,,,,,,,,,,,,, -4000,0.4671819,1.8086269,,,,,,,,,,,,,, -4100,0.58014387,1.8826944,,,,,,,,,,,,,, -4200,0.58361864,1.7977545,,,,,,,,,,,,,, -4300,0.8217641,1.8199221,,,,,,,,,,,,,, -4400,0.4858357,1.8772662,,,,,,,,,,,,,, -4500,0.7314132,1.7771114,,,,,,,,,,,,,, -4600,0.5291428,1.734885,,,,,,,,,,,,,, -4700,0.60047805,1.7821859,,,,,,,,,,,,,, -4800,0.6091509,1.7993428,,,,,,,,,,,,,, -4900,1.1580539,1.7967592,,,,,,,,,,,,,, -5000,0.6097497,1.8422728,,,,,,,,,,,,,, -5100,0.7366995,1.7910995,,,,,,,,,,,,,, -5200,0.61804175,1.8218628,,,,,,,,,,,,,, -5300,0.6710583,1.7666041,,,,,,,,,,,,,, -5400,0.5334122,1.7546585,,,,,,,,,,,,,, -5441,,,0.6443802,0.2176905895966201,0.83076894,0.2493507245817121,5348.0,0.55067784,0.1833323177543517,2472.0,4357.46843123436,4876.902319431305,4357.46843123436,519.0642786026001,0.1420328617095947,0.0 -5500,0.64987844,1.6825866,,,,,,,,,,,,,, -5600,0.43378296,1.7292026,,,,,,,,,,,,,, -5700,0.63566005,1.7460754,,,,,,,,,,,,,, -5800,0.47459912,1.7312068,,,,,,,,,,,,,, -5900,0.9540112,1.7277956,,,,,,,,,,,,,, -6000,0.53468543,1.7133937,,,,,,,,,,,,,, -6100,0.41851828,1.7601315,,,,,,,,,,,,,, -6200,0.5057246,1.6969436,,,,,,,,,,,,,, -6300,0.53741324,1.6587229,,,,,,,,,,,,,, -6400,0.5537949,1.7595611,,,,,,,,,,,,,, -6500,0.5641043,1.639412,,,,,,,,,,,,,, -6600,0.6759477,1.7552832,,,,,,,,,,,,,, -6700,0.56686187,1.744381,,,,,,,,,,,,,, -6800,0.5801516,1.6975154,,,,,,,,,,,,,, -6900,0.70683455,1.7698734,,,,,,,,,,,,,, -7000,0.58038026,1.7379806,,,,,,,,,,,,,, -7100,0.6160405,1.7486007,,,,,,,,,,,,,, -7200,0.5696513,1.7221715,,,,,,,,,,,,,, -7246,,,0.5867469,0.1951140565509825,0.7002347,0.209592863280458,5348.0,0.45089793,0.1516259419495054,2472.0,5797.962051391602,6449.165390253067,5797.962051391602,650.7103319168091,0.1908364295959472,0.0 -7300,0.4449401,1.630431,,,,,,,,,,,,,, -7400,0.54888916,1.6866659,,,,,,,,,,,,,, -7500,0.47326323,1.6755097,,,,,,,,,,,,,, -7600,0.45586377,1.6182003,,,,,,,,,,,,,, -7700,0.46713787,1.6476175,,,,,,,,,,,,,, -7800,0.3968945,1.6617173,,,,,,,,,,,,,, -7900,0.55933887,1.5783947,,,,,,,,,,,,,, -8000,0.7437589,1.6850373,,,,,,,,,,,,,, -8100,0.66172135,1.5370587,,,,,,,,,,,,,, -8200,0.49433246,1.6097249,,,,,,,,,,,,,, -8300,0.46190733,1.5857865,,,,,,,,,,,,,, -8400,0.45757043,1.5619035,,,,,,,,,,,,,, -8500,0.5258592,1.6108197,,,,,,,,,,,,,, -8600,0.578043,1.6231014,,,,,,,,,,,,,, -8700,0.85768616,1.6637841,,,,,,,,,,,,,, -8800,0.50138754,1.6461359,,,,,,,,,,,,,, -8900,0.53288835,1.6066208,,,,,,,,,,,,,, -9000,0.50485396,1.6194319,,,,,,,,,,,,,, -9070,,,0.4774239,0.1646485756372009,0.6440632,0.1956032709964567,5348.0,0.405954,0.1373875246277903,2472.0,7238.451512336731,8021.87784409523,7238.451512336731,782.7996170520782,0.2458448410034179,0.0 -9100,0.49009052,1.6072655,,,,,,,,,,,,,, -9200,0.5275915,1.6194502,,,,,,,,,,,,,, -9300,0.4790421,1.6176882,,,,,,,,,,,,,, -9400,0.4317319,1.538893,,,,,,,,,,,,,, -9500,0.5563259,1.6244624,,,,,,,,,,,,,, -9600,0.45479664,1.5289615,,,,,,,,,,,,,, -9700,0.49455547,1.5645429,,,,,,,,,,,,,, -9800,0.57610005,1.5828056,,,,,,,,,,,,,, -9900,0.44507146,1.6233594,,,,,,,,,,,,,, -10000,0.608769,1.6308467,,,,,,,,,,,,,, -10100,0.4885354,1.5651085,,,,,,,,,,,,,, -10200,0.40624127,1.5128345,,,,,,,,,,,,,, -10300,0.58632976,1.5426247,,,,,,,,,,,,,, -10400,0.50558126,1.5478276,,,,,,,,,,,,,, -10500,0.56118774,1.6053002,,,,,,,,,,,,,, -10600,0.43653753,1.5021502,,,,,,,,,,,,,, -10700,0.44031632,1.5380518,,,,,,,,,,,,,, -10800,0.4953972,1.5619498,,,,,,,,,,,,,, -10892,,,0.4616692,0.1635571770115438,0.62427807,0.1891153441401083,5348.0,0.38973525,0.1317409054902199,2472.0,8678.661425828934,9594.367964029312,8678.661425828934,914.9536378383636,0.2943975925445556,0.0 -10900,0.5642609,1.6039896,,,,,,,,,,,,,, -11000,0.59053874,1.5376683,,,,,,,,,,,,,, -11100,0.42980787,1.5042117,,,,,,,,,,,,,, -11200,0.580475,1.4989704,,,,,,,,,,,,,, -11300,0.48116577,1.4967308,,,,,,,,,,,,,, -11400,0.50420314,1.5044147,,,,,,,,,,,,,, -11500,0.558956,1.5145352,,,,,,,,,,,,,, -11600,0.5404307,1.5780083,,,,,,,,,,,,,, -11700,0.5867504,1.5292041,,,,,,,,,,,,,, -11800,0.565114,1.5491518,,,,,,,,,,,,,, -11900,0.52126634,1.4822533,,,,,,,,,,,,,, -12000,0.43745548,1.582057,,,,,,,,,,,,,, -12100,0.62347895,1.5566959,,,,,,,,,,,,,, -12200,0.5413187,1.5315084,,,,,,,,,,,,,, -12300,0.5130208,1.4596828,,,,,,,,,,,,,, -12400,0.43070468,1.4652274,,,,,,,,,,,,,, -12500,0.5802744,1.5257457,,,,,,,,,,,,,, -12600,0.52338177,1.5189706,,,,,,,,,,,,,, -12700,0.49448034,1.5483687,,,,,,,,,,,,,, -12728,,,0.42614308,0.1467377049180327,0.5865799,0.1779545652026994,5348.0,0.36318412,0.1225397599171287,2472.0,10119.200992584229,11165.735782384872,10119.200992584229,1045.6514217853546,0.3471841812133789,0.0 -12800,0.54595816,1.5175337,,,,,,,,,,,,,, -12900,0.58175534,1.5068913,,,,,,,,,,,,,, -13000,0.4565973,1.4885896,,,,,,,,,,,,,, -13100,0.43422934,1.4798515,,,,,,,,,,,,,, -13200,0.5191414,1.5479684,,,,,,,,,,,,,, -13300,0.44643992,1.4963049,,,,,,,,,,,,,, -13400,0.5663482,1.5277234,,,,,,,,,,,,,, -13500,0.439266,1.4299628,,,,,,,,,,,,,, -13600,0.6051183,1.5209631,,,,,,,,,,,,,, -13700,0.45150286,1.4467841,,,,,,,,,,,,,, -13800,0.46978647,1.4373124,,,,,,,,,,,,,, -13900,0.55410635,1.5013428,,,,,,,,,,,,,, -14000,0.5259518,1.4548314,,,,,,,,,,,,,, -14100,0.56728625,1.5142219,,,,,,,,,,,,,, -14200,0.52679396,1.4924494,,,,,,,,,,,,,, -14300,0.49178532,1.5031444,,,,,,,,,,,,,, -14400,0.47695932,1.4879787,,,,,,,,,,,,,, -14500,0.6239121,1.5349736,,,,,,,,,,,,,, -14539,,,0.43106797,0.1506862616304639,0.58034813,0.1754926286723886,5348.0,0.3582644,0.1215444925151829,2472.0,11559.555562257769,12735.91278553009,11559.555562257769,1175.3477528095243,0.3970916271209717,0.0 -14600,0.5542355,1.5107676,,,,,,,,,,,,,, -14700,0.5620632,1.4683096,,,,,,,,,,,,,, -14800,0.71357745,1.4650936,,,,,,,,,,,,,, -14900,0.60199213,1.4838794,,,,,,,,,,,,,, -15000,0.44953018,1.4936658,,,,,,,,,,,,,, -15100,0.6225218,1.4494853,,,,,,,,,,,,,, -15200,0.6314373,1.4659832,,,,,,,,,,,,,, -15300,0.55624765,1.4832286,,,,,,,,,,,,,, -15400,0.46175233,1.4419302,,,,,,,,,,,,,, -15500,0.53729373,1.4581587,,,,,,,,,,,,,, -15600,0.4887843,1.3987347,,,,,,,,,,,,,, -15700,0.46520072,1.4527684,,,,,,,,,,,,,, -15800,0.507695,1.4741328,,,,,,,,,,,,,, -15900,0.6578467,1.5268128,,,,,,,,,,,,,, -16000,0.56609505,1.4213288,,,,,,,,,,,,,, -16100,0.693711,1.4991436,,,,,,,,,,,,,, -16200,0.6053857,1.5160081,,,,,,,,,,,,,, -16300,0.66512316,1.5222147,,,,,,,,,,,,,, -16355,,,0.45272714,0.1547102812009115,0.5875047,0.1765160218967531,5348.0,0.3553724,0.1196758271890805,2472.0,13000.484461307526,14308.726927280426,13000.484461307526,1307.0945451259613,0.4598226547241211,0.0 -16400,0.46087638,1.5000457,,,,,,,,,,,,,, -16500,0.42368093,1.5101684,,,,,,,,,,,,,, -16600,0.49395144,1.4973779,,,,,,,,,,,,,, -16700,0.87506026,1.4713733,,,,,,,,,,,,,, -16800,1.082231,1.4619291,,,,,,,,,,,,,, -16900,0.91521686,1.5545048,,,,,,,,,,,,,, -17000,0.47996384,1.3983421,,,,,,,,,,,,,, -17100,0.57597643,1.4623344,,,,,,,,,,,,,, -17200,0.47631606,1.4902414,,,,,,,,,,,,,, -17300,0.76979524,1.4301744,,,,,,,,,,,,,, -17400,0.6215479,1.4733791,,,,,,,,,,,,,, -17500,0.4008885,1.4180325,,,,,,,,,,,,,, -17600,0.54363763,1.5165445,,,,,,,,,,,,,, -17700,0.7923452,1.5526891,,,,,,,,,,,,,, -17800,0.48811764,1.5118746,,,,,,,,,,,,,, -17900,0.46627027,1.4481019,,,,,,,,,,,,,, -18000,0.42892522,1.4268942,,,,,,,,,,,,,, -18100,0.52945346,1.4568865,,,,,,,,,,,,,, -18164,,,0.43035367,0.1507336027182378,0.57992405,0.1750292053255066,5348.0,0.34942833,0.1183961976722929,2472.0,14440.638258457184,15880.81642794609,14440.638258457184,1438.896250963211,0.5163545608520508,0.0 -18200,0.73908514,1.4164779,,,,,,,,,,,,,, -18300,0.52937293,1.4695807,,,,,,,,,,,,,, -18400,0.478456,1.4654373,,,,,,,,,,,,,, -18500,0.46352306,1.3954788,,,,,,,,,,,,,, -18600,0.52866495,1.497318,,,,,,,,,,,,,, -18700,0.49748594,1.501007,,,,,,,,,,,,,, -18800,0.5846558,1.491682,,,,,,,,,,,,,, -18900,0.4307417,1.4271896,,,,,,,,,,,,,, -19000,0.59536463,1.43649,,,,,,,,,,,,,, -19100,0.56945074,1.4404908,,,,,,,,,,,,,, -19200,0.5865462,1.5083082,,,,,,,,,,,,,, -19300,0.5790281,1.4692105,,,,,,,,,,,,,, -19400,0.6137594,1.4534465,,,,,,,,,,,,,, -19500,0.51593053,1.460375,,,,,,,,,,,,,, -19600,0.5289775,1.4308699,,,,,,,,,,,,,, -19700,0.6674195,1.4813225,,,,,,,,,,,,,, -19800,0.81059337,1.4027021,,,,,,,,,,,,,, -19900,0.52591354,1.406281,,,,,,,,,,,,,, -20000,0.4995658,1.4528624,,,,,,,,,,,,,, -20003,,,0.42334485,0.1497484324507529,0.537798,0.1622754086331907,5348.0,0.32704383,0.1088903784047285,2472.0,15881.029235124588,17450.236847639084,15881.029235124588,1567.792672872543,0.5708911418914795,0.0 -20100,0.6642165,1.4035331,,,,,,,,,,,,,, -20200,0.46698523,1.4255023,,,,,,,,,,,,,, -20300,0.55602944,1.3634764,,,,,,,,,,,,,, -20400,0.45323008,1.3731337,,,,,,,,,,,,,, -20500,0.7151332,1.4776325,,,,,,,,,,,,,, -20600,0.62505275,1.3602059,,,,,,,,,,,,,, -20700,0.5685678,1.4165444,,,,,,,,,,,,,, -20800,0.62774384,1.4217697,,,,,,,,,,,,,, -20900,0.56844497,1.3510629,,,,,,,,,,,,,, -21000,0.55166787,1.4119213,,,,,,,,,,,,,, -21100,0.5465281,1.3668958,,,,,,,,,,,,,, -21200,0.4738144,1.3955094,,,,,,,,,,,,,, -21300,0.6531449,1.4051249,,,,,,,,,,,,,, -21400,0.5458423,1.4476229,,,,,,,,,,,,,, -21500,0.6226678,1.3861016,,,,,,,,,,,,,, -21600,0.59086096,1.4241797,,,,,,,,,,,,,, -21700,0.5846583,1.4061997,,,,,,,,,,,,,, -21800,0.4857,1.4232968,,,,,,,,,,,,,, -21819,,,0.35540873,0.1246373141148812,0.52952045,0.1576218658582504,5348.0,0.31939596,0.1072857636138362,2472.0,17321.37030696869,19023.241703748703,17321.37030696869,1700.3270015716553,0.6231732368469238,0.0 -21900,0.55482554,1.4211706,,,,,,,,,,,,,, -22000,1.4134998,1.4352125,,,,,,,,,,,,,, -22100,0.50601214,1.4242874,,,,,,,,,,,,,, -22200,0.5712321,1.4232141,,,,,,,,,,,,,, -22300,0.6922635,1.4177514,,,,,,,,,,,,,, -22400,1.1465119,1.3204108,,,,,,,,,,,,,, -22500,0.7556807,1.3904531,,,,,,,,,,,,,, -22600,0.9122837,1.4635508,,,,,,,,,,,,,, -22700,0.72451717,1.4196459,,,,,,,,,,,,,, -22800,1.222276,1.340581,,,,,,,,,,,,,, -22900,0.56134355,1.4010665,,,,,,,,,,,,,, -23000,0.4611722,1.4155489,,,,,,,,,,,,,, -23100,0.5411689,1.3226919,,,,,,,,,,,,,, -23200,0.48621354,1.3539945,,,,,,,,,,,,,, -23300,0.6379012,1.4520074,,,,,,,,,,,,,, -23400,0.43128005,1.3544427,,,,,,,,,,,,,, -23500,0.8808856,1.3985339,,,,,,,,,,,,,, -23600,0.42200255,1.3335938,,,,,,,,,,,,,, -23623,,,0.35186237,0.1239148128052089,0.5021615,0.1515297797773637,5348.0,0.30010766,0.1024516076615278,2472.0,18761.539265871048,20594.3846950531,18761.539265871048,1831.173261165619,0.6752007007598877,0.0 -23700,0.58374727,1.3188448,,,,,,,,,,,,,, -23800,0.6032245,1.4049264,,,,,,,,,,,,,, -23900,0.6426251,1.3738872,,,,,,,,,,,,,, -24000,0.6217328,1.4357841,,,,,,,,,,,,,, -24100,0.5142518,1.4120365,,,,,,,,,,,,,, -24200,0.5864708,1.3971163,,,,,,,,,,,,,, -24300,0.61110055,1.3678938,,,,,,,,,,,,,, -24400,0.5185258,1.4570087,,,,,,,,,,,,,, -24500,0.43183503,1.3324305,,,,,,,,,,,,,, -24600,0.56249976,1.4003501,,,,,,,,,,,,,, -24700,0.77442676,1.3976915,,,,,,,,,,,,,, -24800,0.6559688,1.3704466,,,,,,,,,,,,,, -24900,0.5134965,1.4034514,,,,,,,,,,,,,, -25000,0.56010187,1.3744652,,,,,,,,,,,,,, -25100,0.57769054,1.3778034,,,,,,,,,,,,,, -25200,0.7183368,1.4210241,,,,,,,,,,,,,, -25300,0.869237,1.3896314,,,,,,,,,,,,,, -25400,0.53767157,1.4127257,,,,,,,,,,,,,, -25432,,,0.36941093,0.1303682806206556,0.5116533,0.1539530976954343,5348.0,0.3096689,0.1035281213819998,2472.0,20202.135964870453,22166.200580835342,20202.135964870453,1962.2557699680328,0.7344212532043457,0.0 -25500,0.540338,1.4151082,,,,,,,,,,,,,, -25600,0.6005689,1.4075583,,,,,,,,,,,,,, -25700,0.55207753,1.3196253,,,,,,,,,,,,,, -25800,0.48434618,1.3554797,,,,,,,,,,,,,, -25900,1.2415069,1.4423978,,,,,,,,,,,,,, -26000,0.4706924,1.3482307,,,,,,,,,,,,,, -26100,0.6011619,1.3899025,,,,,,,,,,,,,, -26200,0.57623523,1.3737375,,,,,,,,,,,,,, -26300,0.48889792,1.3867081,,,,,,,,,,,,,, -26400,0.60521555,1.4269897,,,,,,,,,,,,,, -26500,0.56331515,1.3842968,,,,,,,,,,,,,, -26600,0.5233056,1.3127006,,,,,,,,,,,,,, -26700,0.78670824,1.4264592,,,,,,,,,,,,,, -26800,0.60607547,1.436457,,,,,,,,,,,,,, -26900,0.69272554,1.3613194,,,,,,,,,,,,,, -27000,0.53887045,1.3525538,,,,,,,,,,,,,, -27100,0.42080766,1.3710749,,,,,,,,,,,,,, -27200,0.8844958,1.368384,,,,,,,,,,,,,, -27267,,,0.35059008,0.1234827597570218,0.49381366,0.1492995549204939,5348.0,0.29647806,0.0982064875185343,2472.0,21642.26782298088,23739.43984937668,21642.26782298088,2095.230221509933,0.7898995876312256,0.0 -27300,0.64918745,1.3086737,,,,,,,,,,,,,, -27400,0.54520583,1.4081028,,,,,,,,,,,,,, -27500,0.6249026,1.2857456,,,,,,,,,,,,,, -27600,0.6684991,1.4142379,,,,,,,,,,,,,, -27700,0.4373987,1.3423603,,,,,,,,,,,,,, -27800,0.6114307,1.3675768,,,,,,,,,,,,,, -27900,0.6437399,1.2956492,,,,,,,,,,,,,, -28000,0.47106197,1.2900538,,,,,,,,,,,,,, -28100,0.51814073,1.328582,,,,,,,,,,,,,, -28200,0.6254003,1.3381305,,,,,,,,,,,,,, -28300,0.5551865,1.3516585,,,,,,,,,,,,,, -28400,0.99253577,1.3464501,,,,,,,,,,,,,, -28500,0.5317108,1.3393722,,,,,,,,,,,,,, -28600,0.51138294,1.4252084,,,,,,,,,,,,,, -28700,0.91770774,1.3525258,,,,,,,,,,,,,, -28800,0.69543225,1.3526971,,,,,,,,,,,,,, -28900,0.4951501,1.343673,,,,,,,,,,,,,, -29000,0.8113992,1.2815439,,,,,,,,,,,,,, -29093,,,0.31172824,0.1129408188082691,0.48430818,0.1454956216148372,5348.0,0.28592485,0.0953019316312229,2472.0,23082.85145688057,25310.55036687851,23082.85145688057,2225.6262097358704,0.8439171314239502,0.0 -29100,0.59825194,1.3501132,,,,,,,,,,,,,, -29200,0.84765154,1.2974553,,,,,,,,,,,,,, -29300,0.5258946,1.3613572,,,,,,,,,,,,,, -29400,0.5832279,1.3280131,,,,,,,,,,,,,, -29500,0.92161393,1.333442,,,,,,,,,,,,,, -29600,0.5614407,1.3513981,,,,,,,,,,,,,, -29700,0.59411174,1.289391,,,,,,,,,,,,,, -29800,0.46508846,1.3059056,,,,,,,,,,,,,, -29900,0.59352845,1.3056133,,,,,,,,,,,,,, -30000,0.657469,1.3502512,,,,,,,,,,,,,, -30100,0.6577471,1.3281556,,,,,,,,,,,,,, -30200,1.0939163,1.3602451,,,,,,,,,,,,,, -30300,0.47425464,1.3251026,,,,,,,,,,,,,, -30400,0.966164,1.2981629,,,,,,,,,,,,,, -30500,0.61003846,1.3414463,,,,,,,,,,,,,, -30600,0.6675832,1.3504369,,,,,,,,,,,,,, -30700,0.7383178,1.3629749,,,,,,,,,,,,,, -30800,0.54626316,1.2972295,,,,,,,,,,,,,, -30900,,,0.27994934,0.1000941861460741,0.4781998,0.1433522886355079,5348.0,0.27936098,0.0943269758089086,2472.0,24523.22280049324,26880.09740138054,24523.22280049324,2354.6739501953125,0.8952944278717041,0.0 -30900,1.0561683,1.3233979,,,,,,,,,,,,,, -31000,0.6433433,1.3074445,,,,,,,,,,,,,, -31100,0.72668487,1.3052641,,,,,,,,,,,,,, -31200,0.6656382,1.3297552,,,,,,,,,,,,,, -31300,0.66464806,1.3779589,,,,,,,,,,,,,, -31400,0.4752533,1.3812169,,,,,,,,,,,,,, -31500,0.53794676,1.3361212,,,,,,,,,,,,,, -31600,0.48910344,1.3610307,,,,,,,,,,,,,, -31700,0.43444785,1.2645744,,,,,,,,,,,,,, -31800,0.59495586,1.3096583,,,,,,,,,,,,,, -31900,0.598289,1.3692826,,,,,,,,,,,,,, -32000,0.46292564,1.2343943,,,,,,,,,,,,,, -32100,0.62744635,1.3845668,,,,,,,,,,,,,, -32200,0.5206886,1.3608587,,,,,,,,,,,,,, -32300,0.52145153,1.2718991,,,,,,,,,,,,,, -32400,0.5693879,1.2943308,,,,,,,,,,,,,, -32500,0.57848716,1.2480438,,,,,,,,,,,,,, -32600,0.64482766,1.2968202,,,,,,,,,,,,,, -32700,,,0.31628087,0.114678250442593,0.46717182,0.1399731600644931,5348.0,0.27516112,0.0914224199215973,2472.0,25964.00806760788,28451.8306787014,25964.00806760788,2485.4901201725006,0.9515819549560548,0.0 -32700,0.5126682,1.3001195,,,,,,,,,,,,,, -32800,0.8943302,1.3992203,,,,,,,,,,,,,, -32900,0.584508,1.3141621,,,,,,,,,,,,,, -33000,0.50603205,1.3066733,,,,,,,,,,,,,, -33100,0.59818256,1.3365622,,,,,,,,,,,,,, -33200,0.56224513,1.2938828,,,,,,,,,,,,,, -33300,0.4867858,1.2715638,,,,,,,,,,,,,, -33400,0.4808053,1.2880722,,,,,,,,,,,,,, -33500,0.49691966,1.2949795,,,,,,,,,,,,,, -33600,0.6272752,1.3091267,,,,,,,,,,,,,, -33700,0.54319155,1.332421,,,,,,,,,,,,,, -33800,0.44288588,1.3105541,,,,,,,,,,,,,, -33900,0.6126447,1.3363589,,,,,,,,,,,,,, -34000,0.6389887,1.275755,,,,,,,,,,,,,, -34100,0.4653033,1.2477438,,,,,,,,,,,,,, -34200,0.5014164,1.268973,,,,,,,,,,,,,, -34300,0.5442547,1.2893343,,,,,,,,,,,,,, -34400,0.7115738,1.2694135,,,,,,,,,,,,,, -34500,0.74846196,1.3363249,,,,,,,,,,,,,, -34518,,,0.29610607,0.1073693235276106,0.47222495,0.139036658717669,5348.0,0.26905754,0.090264659882599,2472.0,27403.91512775421,30023.46605920792,27403.91512775421,2617.0825748443604,1.0097923278808594,0.0 -34600,0.55514854,1.3518335,,,,,,,,,,,,,, -34700,0.66424954,1.2967417,,,,,,,,,,,,,, -34800,1.3989145,1.2474697,,,,,,,,,,,,,, -34900,0.69725865,1.2747519,,,,,,,,,,,,,, -35000,0.7562657,1.2872884,,,,,,,,,,,,,, -35100,0.6627404,1.3035446,,,,,,,,,,,,,, -35200,0.66938424,1.3190763,,,,,,,,,,,,,, -35300,0.62428266,1.314657,,,,,,,,,,,,,, -35400,0.8509412,1.3204492,,,,,,,,,,,,,, -35500,0.6253007,1.2541255,,,,,,,,,,,,,, -35600,0.7761378,1.2964735,,,,,,,,,,,,,, -35700,0.48655587,1.341288,,,,,,,,,,,,,, -35800,0.86789846,1.3121696,,,,,,,,,,,,,, -35900,0.76196456,1.2772925,,,,,,,,,,,,,, -36000,0.54678947,1.2791842,,,,,,,,,,,,,, -36100,0.5047138,1.2991867,,,,,,,,,,,,,, -36200,0.5602455,1.2808487,,,,,,,,,,,,,, -36300,0.50760484,1.2665439,,,,,,,,,,,,,, -36337,,,0.31838846,0.1097850960381802,0.4541082,0.1351844521467121,5348.0,0.26998752,0.0893506387991794,2472.0,28844.04262661934,31594.535401821136,28844.04262661934,2747.890196323395,1.0664589405059814,0.0 -36400,0.46266696,1.2623112,,,,,,,,,,,,,, -36500,0.5658997,1.2968751,,,,,,,,,,,,,, -36600,0.6962751,1.2877039,,,,,,,,,,,,,, -36700,0.5371335,1.2537568,,,,,,,,,,,,,, -36800,0.6905447,1.2914083,,,,,,,,,,,,,, -36900,0.62150055,1.3221776,,,,,,,,,,,,,, -37000,0.70584744,1.2818184,,,,,,,,,,,,,, -37100,0.57847124,1.2621636,,,,,,,,,,,,,, -37200,0.6743866,1.2361903,,,,,,,,,,,,,, -37300,0.67514235,1.2577889,,,,,,,,,,,,,, -37400,0.5129046,1.2474484,,,,,,,,,,,,,, -37500,0.58950096,1.2235291,,,,,,,,,,,,,, -37600,0.76942074,1.2248418,,,,,,,,,,,,,, -37700,1.5179701,1.2901001,,,,,,,,,,,,,, -37800,0.8316952,1.2405257,,,,,,,,,,,,,, -37900,0.53507864,1.267607,,,,,,,,,,,,,, -38000,0.7138499,1.3135049,,,,,,,,,,,,,, -38100,0.68753695,1.3110757,,,,,,,,,,,,,, -38133,,,0.27280307,0.0998227803789436,0.47273842,0.1405041659827954,5348.0,0.27322602,0.0915036662401234,2472.0,30284.337661266327,33165.68579244614,30284.337661266327,2878.6182096004486,1.1174898147583008,0.0 -38200,0.4923909,1.2714901,,,,,,,,,,,,,, -38300,0.50178695,1.2906628,,,,,,,,,,,,,, -38400,0.62974083,1.2522159,,,,,,,,,,,,,, -38500,0.57601833,1.3176419,,,,,,,,,,,,,, -38600,0.4934987,1.2755452,,,,,,,,,,,,,, -38700,0.60111874,1.2671021,,,,,,,,,,,,,, -38800,0.49673766,1.3048358,,,,,,,,,,,,,, -38900,1.8055891,1.2633125,,,,,,,,,,,,,, -39000,0.60774994,1.2909776,,,,,,,,,,,,,, -39100,1.0983949,1.2564048,,,,,,,,,,,,,, -39200,0.6475919,1.2394027,,,,,,,,,,,,,, -39300,0.70171446,1.2030632,,,,,,,,,,,,,, -39400,0.65982115,1.2786739,,,,,,,,,,,,,, -39500,0.49459535,1.2250183,,,,,,,,,,,,,, -39600,0.7037264,1.2887698,,,,,,,,,,,,,, -39700,0.74705493,1.2050061,,,,,,,,,,,,,, -39800,0.5717866,1.2595054,,,,,,,,,,,,,, -39900,0.757763,1.2681761,,,,,,,,,,,,,, -39958,,,0.23454855,0.0854467962307291,0.43293995,0.1286289427189434,5348.0,0.2446584,0.0824853248837162,2472.0,31724.60327744484,34736.20393896103,31724.60327744484,3008.73100566864,1.1783719062805176,0.0 -40000,0.58336794,1.2665164,,,,,,,,,,,,,, -40100,0.6085652,1.2523516,,,,,,,,,,,,,, -40200,0.5918863,1.297555,,,,,,,,,,,,,, -40300,0.6834634,1.1943631,,,,,,,,,,,,,, -40400,0.5986727,1.2473699,,,,,,,,,,,,,, -40500,0.6052493,1.267425,,,,,,,,,,,,,, -40600,0.5090441,1.222527,,,,,,,,,,,,,, -40700,0.6818888,1.2513511,,,,,,,,,,,,,, -40800,0.62189794,1.1764091,,,,,,,,,,,,,, -40900,0.6913582,1.2806776,,,,,,,,,,,,,, -41000,0.8111242,1.197071,,,,,,,,,,,,,, -41100,0.5829287,1.2405392,,,,,,,,,,,,,, -41200,0.6769561,1.1970981,,,,,,,,,,,,,, -41300,0.5415767,1.2062078,,,,,,,,,,,,,, -41400,0.73365116,1.2020997,,,,,,,,,,,,,, -41500,0.7506019,1.2123419,,,,,,,,,,,,,, -41600,0.48714194,1.2188911,,,,,,,,,,,,,, -41700,0.51617205,1.2084726,,,,,,,,,,,,,, -41786,,,0.25626796,0.0934509031760979,0.42910972,0.1268042132905954,5348.0,0.24258377,0.0816728616984543,2472.0,33165.11921596527,36306.827905893326,33165.11921596527,3138.707051753998,1.2340517044067385,0.0 -41800,0.93743193,1.2112072,,,,,,,,,,,,,, -41900,0.65287536,1.2449553,,,,,,,,,,,,,, -42000,0.5275901,1.2273605,,,,,,,,,,,,,, -42100,0.86687183,1.263113,,,,,,,,,,,,,, -42200,0.47131106,1.2270391,,,,,,,,,,,,,, -42300,0.5413738,1.1934707,,,,,,,,,,,,,, -42400,0.5811048,1.2305397,,,,,,,,,,,,,, -42500,1.0399997,1.1698008,,,,,,,,,,,,,, -42600,0.55510116,1.204612,,,,,,,,,,,,,, -42700,0.6435042,1.2264708,,,,,,,,,,,,,, -42800,0.55932087,1.2093284,,,,,,,,,,,,,, -42900,0.52706116,1.1869767,,,,,,,,,,,,,, -43000,0.63601774,1.267712,,,,,,,,,,,,,, -43100,0.6518623,1.26543,,,,,,,,,,,,,, -43200,0.6398104,1.2431096,,,,,,,,,,,,,, -43300,0.7118821,1.154033,,,,,,,,,,,,,, -43400,0.52564245,1.1711513,,,,,,,,,,,,,, -43500,0.60972625,1.1892014,,,,,,,,,,,,,, -43600,1.4255298,1.1783575,,,,,,,,,,,,,, -43625,,,0.17323135,0.0646907749398744,0.4196442,0.1238305801481023,5348.0,0.23896857,0.0785448784351959,2472.0,34605.279014348984,37888.82444357872,34605.279014348984,3280.400631427765,1.2985985279083252,0.0 -43700,0.65515125,1.190147,,,,,,,,,,,,,, -43800,0.6962889,1.154139,,,,,,,,,,,,,, -43900,0.6110588,1.1986203,,,,,,,,,,,,,, -44000,1.0790193,1.2580397,,,,,,,,,,,,,, -44100,0.6363279,1.218389,,,,,,,,,,,,,, -44200,0.57914907,1.2130768,,,,,,,,,,,,,, -44300,0.7068135,1.1933748,,,,,,,,,,,,,, -44400,0.7118467,1.1407309,,,,,,,,,,,,,, -44500,0.5539703,1.1391633,,,,,,,,,,,,,, -44600,0.4941348,1.1465586,,,,,,,,,,,,,, -44700,0.60848933,1.2009529,,,,,,,,,,,,,, -44800,0.5736806,1.2224964,,,,,,,,,,,,,, -44900,0.81970495,1.2500284,,,,,,,,,,,,,, -45000,0.51004267,1.2088342,,,,,,,,,,,,,, -45100,0.60458547,1.1872091,,,,,,,,,,,,,, -45200,0.5165591,1.2524301,,,,,,,,,,,,,, -45300,0.7111892,1.2219088,,,,,,,,,,,,,, -45400,0.52335006,1.2007735,,,,,,,,,,,,,, -45444,,,0.1462238,0.0550591908416807,0.4182151,0.1212624424341311,5348.0,0.23613536,0.0795198342575102,2472.0,36045.82629442215,39463.43499088287,36045.82629442215,3414.3309786319733,1.3558061122894287,0.0 -45500,0.62015224,1.1831031,,,,,,,,,,,,,, -45600,0.52051824,1.144088,,,,,,,,,,,,,, -45700,0.5600447,1.1474409,,,,,,,,,,,,,, -45800,0.6665585,1.1865877,,,,,,,,,,,,,, -45900,0.70641446,1.206443,,,,,,,,,,,,,, -46000,0.53712887,1.1963881,,,,,,,,,,,,,, -46100,1.5995964,1.2570751,,,,,,,,,,,,,, -46200,0.78260833,1.2215751,,,,,,,,,,,,,, -46300,0.98477304,1.1684304,,,,,,,,,,,,,, -46400,0.7937488,1.1777118,,,,,,,,,,,,,, -46500,0.664781,1.199064,,,,,,,,,,,,,, -46600,1.2657421,1.2174791,,,,,,,,,,,,,, -46700,0.5687873,1.1931452,,,,,,,,,,,,,, -46800,0.5193123,1.1736327,,,,,,,,,,,,,, -46900,0.76833254,1.1426717,,,,,,,,,,,,,, -47000,0.57933384,1.1677885,,,,,,,,,,,,,, -47100,0.66230905,1.1260397,,,,,,,,,,,,,, -47200,1.0354167,1.2300098,,,,,,,,,,,,,, -47262,,,0.15300256,0.0583098080857786,0.41734868,0.1231837183930795,5348.0,0.22954966,0.0757012572867791,2472.0,37485.83829760552,41038.38579106331,37485.83829760552,3549.1361298561096,1.411712408065796,0.0 -47300,0.69305354,1.233677,,,,,,,,,,,,,, -47400,0.694101,1.1684823,,,,,,,,,,,,,, -47500,0.84489685,1.1635256,,,,,,,,,,,,,, -47600,1.0010254,1.1908727,,,,,,,,,,,,,, -47700,1.2164559,1.1328548,,,,,,,,,,,,,, -47800,0.5001442,1.1617839,,,,,,,,,,,,,, -47900,0.5725244,1.194973,,,,,,,,,,,,,, -48000,0.62761176,1.1975204,,,,,,,,,,,,,, -48100,0.71537644,1.1930732,,,,,,,,,,,,,, -48200,1.0163658,1.1647792,,,,,,,,,,,,,, -48300,0.89219695,1.1567814,,,,,,,,,,,,,, -48400,0.5222494,1.1670448,,,,,,,,,,,,,, -48500,0.60963213,1.1453183,,,,,,,,,,,,,, -48600,0.6592353,1.1345674,,,,,,,,,,,,,, -48700,0.7734271,1.1673017,,,,,,,,,,,,,, -48800,0.73352534,1.1907079,,,,,,,,,,,,,, -48900,0.6721895,1.2000339,,,,,,,,,,,,,, -49000,0.6079245,1.1521215,,,,,,,,,,,,,, -49083,,,0.14905518,0.0560649479656953,0.39878234,0.1168309566795717,5348.0,0.22368722,0.0727357666605731,2472.0,38926.85561108589,42614.214336395264,38926.85561108589,3683.813942432404,1.4672369956970217,0.0 -49100,0.51781785,1.1262754,,,,,,,,,,,,,, -49200,0.52575773,1.1051692,,,,,,,,,,,,,, -49300,0.51748174,1.1286498,,,,,,,,,,,,,, -49400,0.5165033,1.1647643,,,,,,,,,,,,,, -49500,0.57892925,1.1576517,,,,,,,,,,,,,, -49600,0.6079225,1.1442978,,,,,,,,,,,,,, -49700,0.86918867,1.1409274,,,,,,,,,,,,,, -49800,0.5973651,1.0924971,,,,,,,,,,,,,, -49900,0.63405836,1.1188631,,,,,,,,,,,,,, -50000,0.67691,1.148336,,,,,,,,,,,,,, -50100,0.7731224,1.1406037,,,,,,,,,,,,,, -50200,0.67801875,1.158069,,,,,,,,,,,,,, -50300,0.55592626,1.1428524,,,,,,,,,,,,,, -50400,0.7912814,1.1853372,,,,,,,,,,,,,, -50500,0.7616601,1.1189523,,,,,,,,,,,,,, -50600,0.6275433,1.1875961,,,,,,,,,,,,,, -50700,0.5495095,1.1527493,,,,,,,,,,,,,, -50800,0.66634727,1.1466491,,,,,,,,,,,,,, -50900,0.5489536,1.0947653,,,,,,,,,,,,,, -50921,,,0.13042888,0.0508238532410719,0.38371804,0.1114436602720681,5348.0,0.21463418,0.0714358255641541,2472.0,40367.21039605141,44186.81035447121,40367.21039605141,3815.917148351669,1.5273003578186035,0.0 -51000,0.8109634,1.1243249,,,,,,,,,,,,,, -51100,0.9675967,1.1868721,,,,,,,,,,,,,, -51200,0.59330547,1.1376231,,,,,,,,,,,,,, -51300,1.0386539,1.1615078,,,,,,,,,,,,,, -51400,0.6168064,1.1023839,,,,,,,,,,,,,, -51500,0.69496405,1.191491,,,,,,,,,,,,,, -51600,0.5902677,1.1400566,,,,,,,,,,,,,, -51700,0.5819357,1.1361983,,,,,,,,,,,,,, -51800,0.5863334,1.0993472,,,,,,,,,,,,,, -51900,1.0054106,1.1726351,,,,,,,,,,,,,, -52000,0.52995604,1.1907033,,,,,,,,,,,,,, -52100,0.8254183,1.1483605,,,,,,,,,,,,,, -52200,0.54403496,1.1630476,,,,,,,,,,,,,, -52300,0.76076865,1.1493449,,,,,,,,,,,,,, -52400,0.8177406,1.1462505,,,,,,,,,,,,,, -52500,0.767835,1.123802,,,,,,,,,,,,,, -52600,0.54443043,1.1323394,,,,,,,,,,,,,, -52700,0.7877121,1.0715057,,,,,,,,,,,,,, -52735,,,0.1234728,0.048450242173839,0.38684124,0.1116077893740888,5348.0,0.21119802,0.0702983771047874,2472.0,41807.4494600296,45760.08290052414,41807.4494600296,3948.819551467896,1.5827083587646484,0.0 -52800,1.1627009,1.1689492,,,,,,,,,,,,,, -52900,0.89413166,1.1523376,,,,,,,,,,,,,, -53000,0.6023535,1.1552854,,,,,,,,,,,,,, -53100,0.8791644,1.1780484,,,,,,,,,,,,,, -53200,0.60014427,1.1835785,,,,,,,,,,,,,, -53300,0.844148,1.1617296,,,,,,,,,,,,,, -53400,0.8849469,1.0937167,,,,,,,,,,,,,, -53500,0.5336052,1.1330017,,,,,,,,,,,,,, -53600,0.5601488,1.1113915,,,,,,,,,,,,,, -53700,0.66260576,1.0923287,,,,,,,,,,,,,, -53800,0.79759514,1.1104907,,,,,,,,,,,,,, -53900,0.5436174,1.0875404,,,,,,,,,,,,,, -54000,0.83048195,1.0820765,,,,,,,,,,,,,, -54100,0.60404575,1.1198527,,,,,,,,,,,,,, -54200,0.74459225,1.1036388,,,,,,,,,,,,,, -54300,0.5814472,1.0988879,,,,,,,,,,,,,, -54400,0.7492327,1.1086081,,,,,,,,,,,,,, -54500,0.62034804,1.0798388,,,,,,,,,,,,,, -54541,,,0.13599047,0.0502915692023852,0.36117485,0.1053901928034216,5348.0,0.20237793,0.0666626043507403,2472.0,43247.61201763153,47332.53596878052,43247.61201763153,4080.977082490921,1.6382241249084473,0.0 -54600,0.6591173,1.08502,,,,,,,,,,,,,, -54700,0.6980215,1.0963004,,,,,,,,,,,,,, -54800,0.5973332,1.0684103,,,,,,,,,,,,,, -54900,0.66280335,1.1127664,,,,,,,,,,,,,, -55000,0.5848707,1.1277301,,,,,,,,,,,,,, -55100,0.68825114,1.0703871,,,,,,,,,,,,,, -55200,0.5454152,1.1105695,,,,,,,,,,,,,, -55300,1.0452892,1.1157438,,,,,,,,,,,,,, -55400,0.7555475,1.208182,,,,,,,,,,,,,, -55500,0.7158094,1.1102096,,,,,,,,,,,,,, -55600,0.6712936,1.1135584,,,,,,,,,,,,,, -55700,0.7135366,1.060419,,,,,,,,,,,,,, -55800,0.6196655,1.0935061,,,,,,,,,,,,,, -55900,0.5754197,1.0876017,,,,,,,,,,,,,, -56000,0.6197679,1.1099104,,,,,,,,,,,,,, -56100,0.6871908,1.0686373,,,,,,,,,,,,,, -56200,0.9217257,1.0920118,,,,,,,,,,,,,, -56300,0.801978,1.1364679,,,,,,,,,,,,,, -56349,,,0.12227172,0.0438607010829609,0.3628258,0.1053033009258812,5348.0,0.1969371,0.0647533158653748,2472.0,44687.94367861748,48906.5517706871,44687.94367861748,4214.5278577804565,1.6951165199279783,0.0 -56400,0.7368044,1.1050099,,,,,,,,,,,,,, -56500,0.5662207,1.1097413,,,,,,,,,,,,,, -56600,0.83012986,1.0697595,,,,,,,,,,,,,, -56700,0.59354013,1.1045147,,,,,,,,,,,,,, -56800,0.70114,1.0170616,,,,,,,,,,,,,, -56900,0.84182614,1.1045364,,,,,,,,,,,,,, -57000,0.8677862,1.0616688,,,,,,,,,,,,,, -57100,0.6909482,1.1288383,,,,,,,,,,,,,, -57200,0.782821,1.0774027,,,,,,,,,,,,,, -57300,0.5647778,1.0781214,,,,,,,,,,,,,, -57400,0.7016367,1.1126561,,,,,,,,,,,,,, -57500,0.65189195,1.1014854,,,,,,,,,,,,,, -57600,0.75660294,1.0505321,,,,,,,,,,,,,, -57700,0.61249125,1.0245695,,,,,,,,,,,,,, -57800,1.0042341,1.0152451,,,,,,,,,,,,,, -57900,0.72138757,1.0881907,,,,,,,,,,,,,, -58000,0.6729968,1.0841266,,,,,,,,,,,,,, -58100,0.69136935,1.0366498,,,,,,,,,,,,,, -58181,,,0.11364357,0.0437140421904964,0.35202444,0.1014028210896241,5348.0,0.19438767,0.0630471431763248,2472.0,46128.45850539208,50480.25352835655,46128.45850539208,4347.578888177872,1.7528400421142578,0.0 -58200,0.6264593,1.0852839,,,,,,,,,,,,,, -58300,0.85124505,1.0963665,,,,,,,,,,,,,, -58400,0.82497215,1.0657014,,,,,,,,,,,,,, -58500,0.6361246,1.041618,,,,,,,,,,,,,, -58600,0.66157037,1.0527127,,,,,,,,,,,,,, -58700,0.78514516,1.0882049,,,,,,,,,,,,,, -58800,0.7308016,1.0375447,,,,,,,,,,,,,, -58900,0.61436737,1.0606182,,,,,,,,,,,,,, -59000,0.70344883,1.0548832,,,,,,,,,,,,,, -59100,1.3002615,1.0692551,,,,,,,,,,,,,, -59200,0.75751466,1.0730407,,,,,,,,,,,,,, -59300,0.55575603,1.0768931,,,,,,,,,,,,,, -59400,0.90368783,1.0778058,,,,,,,,,,,,,, -59500,0.8183459,1.0345501,,,,,,,,,,,,,, -59600,0.74295056,1.0481498,,,,,,,,,,,,,, -59700,0.71015406,1.0788815,,,,,,,,,,,,,, -59800,1.5594468,1.0428666,,,,,,,,,,,,,, -59900,0.5589547,1.0254985,,,,,,,,,,,,,, -60000,0.62887067,1.0268859,,,,,,,,,,,,,, -60009,,,0.09262989,0.0358730629181133,0.34524652,0.0995298183959759,5348.0,0.19161133,0.0622346799910629,2472.0,47569.04715514183,52052.42256188393,47569.04715514183,4479.026484251022,1.8070552349090576,0.0 -60100,0.68950534,1.0619982,,,,,,,,,,,,,, -60200,0.8263608,1.0688403,,,,,,,,,,,,,, -60300,0.92706287,1.0108432,,,,,,,,,,,,,, -60400,0.6357784,1.0295987,,,,,,,,,,,,,, -60500,0.5274041,1.0747819,,,,,,,,,,,,,, -60600,0.62410533,1.0343156,,,,,,,,,,,,,, -60700,0.61799276,1.0398781,,,,,,,,,,,,,, -60800,0.62375253,1.0645183,,,,,,,,,,,,,, -60900,1.0104467,0.984994,,,,,,,,,,,,,, -61000,1.5810189,1.0273336,,,,,,,,,,,,,, -61100,0.74819785,1.0191021,,,,,,,,,,,,,, -61200,0.7013473,0.9875921,,,,,,,,,,,,,, -61300,0.70177066,1.0529809,,,,,,,,,,,,,, -61400,0.745585,1.0189927,,,,,,,,,,,,,, -61500,0.6434776,1.0506923,,,,,,,,,,,,,, -61600,0.7554957,1.0209467,,,,,,,,,,,,,, -61700,0.8461314,1.0403569,,,,,,,,,,,,,, -61800,0.67502016,1.0181029,,,,,,,,,,,,,, -61811,,,0.08787621,0.0344356772032324,0.34125715,0.0982264402328702,5348.0,0.18546745,0.0596551093778563,2472.0,49009.17750668526,53624.07222151756,49009.17750668526,4610.411191225052,1.8653712272644043,0.0 -61900,1.1116387,1.0487543,,,,,,,,,,,,,, -62000,0.6340748,1.0346593,,,,,,,,,,,,,, -62100,0.7392094,1.0698829,,,,,,,,,,,,,, -62200,0.60342866,0.9931832,,,,,,,,,,,,,, -62300,0.5675756,0.994788,,,,,,,,,,,,,, -62400,0.57236695,1.0419414,,,,,,,,,,,,,, -62500,0.74185026,1.0393778,,,,,,,,,,,,,, -62600,0.6860265,0.97250795,,,,,,,,,,,,,, -62700,0.8898811,1.0494524,,,,,,,,,,,,,, -62800,0.7201442,1.0199832,,,,,,,,,,,,,, -62900,0.67110175,0.9887775,,,,,,,,,,,,,, -63000,0.6054132,1.0369476,,,,,,,,,,,,,, -63100,0.73468125,1.0378366,,,,,,,,,,,,,, -63200,1.0241556,1.0378917,,,,,,,,,,,,,, -63300,0.5954566,0.9784239,,,,,,,,,,,,,, -63400,0.67731816,1.031753,,,,,,,,,,,,,, -63500,2.2279303,1.045612,,,,,,,,,,,,,, -63600,0.90146047,0.9380697,,,,,,,,,,,,,, -63638,,,0.09771829,0.0379921302402006,0.33584002,0.0955810652944186,5348.0,0.18322575,0.0593301241037515,2472.0,50449.34701442719,55197.36210608482,50449.34701442719,4743.393732786179,1.925506830215454,0.0 -63700,0.81026435,1.0637083,,,,,,,,,,,,,, -63800,0.63305604,0.97371435,,,,,,,,,,,,,, -63900,0.86841905,0.96174777,,,,,,,,,,,,,, -64000,0.89729947,1.0063262,,,,,,,,,,,,,, -64100,0.9625889,0.96676624,,,,,,,,,,,,,, -64200,0.6701909,0.9846762,,,,,,,,,,,,,, -64300,0.5568936,0.97589195,,,,,,,,,,,,,, -64400,0.69009906,0.97430855,,,,,,,,,,,,,, -64500,0.7231031,1.0253814,,,,,,,,,,,,,, -64600,1.3993233,1.0546542,,,,,,,,,,,,,, -64700,0.9698656,1.0539455,,,,,,,,,,,,,, -64800,0.817785,1.0376002,,,,,,,,,,,,,, -64900,0.77863103,1.0536561,,,,,,,,,,,,,, -65000,0.89102787,0.993067,,,,,,,,,,,,,, -65100,0.9760473,0.9520222,,,,,,,,,,,,,, -65200,0.877143,1.0160618,,,,,,,,,,,,,, -65300,0.6687743,0.9872723,,,,,,,,,,,,,, -65400,1.1357957,1.0067198,,,,,,,,,,,,,, -65470,,,0.08789349,0.0339759476169358,0.32956982,0.0938046091313708,5348.0,0.1784024,0.0578473787906485,2472.0,51889.73348236084,56769.64642548561,51889.73348236084,4875.155463933945,1.9842355251312256,0.0 -65500,0.6664837,1.0234162,,,,,,,,,,,,,, -65600,0.85139525,1.0337402,,,,,,,,,,,,,, -65700,0.93329287,1.0187118,,,,,,,,,,,,,, -65800,0.679908,1.0011321,,,,,,,,,,,,,, -65900,0.59553885,1.0099959,,,,,,,,,,,,,, -66000,0.7547937,0.99755126,,,,,,,,,,,,,, -66100,0.83297145,1.0590732,,,,,,,,,,,,,, -66200,0.8244391,1.015378,,,,,,,,,,,,,, -66300,1.2066656,0.9776343,,,,,,,,,,,,,, -66400,0.5740201,0.97403526,,,,,,,,,,,,,, -66500,0.98836493,0.93613386,,,,,,,,,,,,,, -66600,2.5128791,0.98194826,,,,,,,,,,,,,, -66700,0.737051,1.0388398,,,,,,,,,,,,,, -66800,0.6357226,1.0231048,,,,,,,,,,,,,, -66900,0.6976282,0.9620524,,,,,,,,,,,,,, -67000,0.662373,0.99648976,,,,,,,,,,,,,, -67100,1.6229084,1.0103967,,,,,,,,,,,,,, -67200,0.59395534,1.0183275,,,,,,,,,,,,,, -67300,0.6392914,1.003107,,,,,,,,,,,,,, -67306,,,0.08603654,0.0328172183847109,0.32403156,0.0927425972947662,5348.0,0.1745069,0.0571161619239128,2472.0,53329.91587328911,58344.639607191086,53329.91587328911,5009.825822591782,2.044980049133301,0.0 -67400,0.64040107,0.9144003,,,,,,,,,,,,,, -67500,0.77911997,0.99850565,,,,,,,,,,,,,, -67600,0.6418719,0.985099,,,,,,,,,,,,,, -67700,0.6170795,0.97021663,,,,,,,,,,,,,, -67800,0.83711004,0.9496415,,,,,,,,,,,,,, -67900,0.7308821,1.0059388,,,,,,,,,,,,,, -68000,0.6969801,0.9801206,,,,,,,,,,,,,, -68100,0.5869476,0.9203848,,,,,,,,,,,,,, -68200,0.65836054,0.97286046,,,,,,,,,,,,,, -68300,0.8340695,0.9129697,,,,,,,,,,,,,, -68400,0.6149794,0.94686556,,,,,,,,,,,,,, -68500,0.63319904,0.98408437,,,,,,,,,,,,,, -68600,0.75872636,0.9926231,,,,,,,,,,,,,, -68700,0.62050503,0.97076076,,,,,,,,,,,,,, -68800,1.2193143,0.97051203,,,,,,,,,,,,,, -68900,0.59057885,0.96394706,,,,,,,,,,,,,, -69000,0.7728795,1.026112,,,,,,,,,,,,,, -69100,0.6677767,1.0102705,,,,,,,,,,,,,, -69115,,,0.07548663,0.0299299109679863,0.31870773,0.0905799550093167,5348.0,0.17174639,0.0546584607884955,2472.0,54770.38207864761,59917.89059305191,54770.38207864761,5142.46967458725,2.107982397079468,0.0 -69200,3.6367364,0.9550589,,,,,,,,,,,,,, -69300,1.0592197,0.98880255,,,,,,,,,,,,,, -69400,0.76014423,0.9613387,,,,,,,,,,,,,, -69500,0.8175445,0.9824551,,,,,,,,,,,,,, -69600,0.79658777,0.98412377,,,,,,,,,,,,,, -69700,1.26442,0.90310955,,,,,,,,,,,,,, -69800,0.8110343,1.0200622,,,,,,,,,,,,,, -69900,0.65053535,1.002046,,,,,,,,,,,,,, -70000,0.60230625,1.0155702,,,,,,,,,,,,,, -70100,0.77025133,0.9423735,,,,,,,,,,,,,, -70200,0.57731634,0.94834554,,,,,,,,,,,,,, -70300,0.98784274,0.943339,,,,,,,,,,,,,, -70400,0.8541968,0.93250245,,,,,,,,,,,,,, -70500,0.97369444,0.96241325,,,,,,,,,,,,,, -70600,0.70480794,0.9955228,,,,,,,,,,,,,, -70700,1.294171,0.8937514,,,,,,,,,,,,,, -70800,0.85858756,0.9534756,,,,,,,,,,,,,, -70900,2.3979735,0.9221618,,,,,,,,,,,,,, -70933,,,0.06823383,0.0263253865710832,0.31399453,0.0886779883564884,5348.0,0.17032003,0.0542522291958645,2472.0,56210.31742787361,61491.18255186081,56210.31742787361,5275.688220500946,2.1689844131469727,0.0 -71000,2.2184365,0.9840044,,,,,,,,,,,,,, -71100,0.72228813,0.92083806,,,,,,,,,,,,,, -71200,2.9415476,0.94639516,,,,,,,,,,,,,, -71300,0.6033143,0.9223115,,,,,,,,,,,,,, -71400,1.3472432,0.94896495,,,,,,,,,,,,,, -71500,0.7321791,0.9481487,,,,,,,,,,,,,, -71600,0.7251399,0.9540644,,,,,,,,,,,,,, -71700,0.84141314,0.98115426,,,,,,,,,,,,,, -71800,0.6579982,0.97085077,,,,,,,,,,,,,, -71900,0.74178874,1.017016,,,,,,,,,,,,,, -72000,1.6829388,0.9628964,,,,,,,,,,,,,, -72100,0.607095,0.9305166,,,,,,,,,,,,,, -72200,0.6679743,0.89362866,,,,,,,,,,,,,, -72300,0.7052585,0.9836006,,,,,,,,,,,,,, -72400,0.84773874,0.989993,,,,,,,,,,,,,, -72500,0.9532842,1.0047252,,,,,,,,,,,,,, -72600,0.8310069,0.96330655,,,,,,,,,,,,,, -72700,0.6605687,1.0013365,,,,,,,,,,,,,, -72746,,,0.07446662,0.0274383379387984,0.31198248,0.0877994149280245,5348.0,0.16882586,0.0543740986736538,2472.0,57650.18681359291,63064.41569805145,57650.18681359291,5408.91233420372,2.2319982051849365,0.0 -72800,0.7558392,0.9750668,,,,,,,,,,,,,, -72900,1.2512296,0.9105537,,,,,,,,,,,,,, -73000,1.3349184,1.008864,,,,,,,,,,,,,, -73100,0.9712462,0.95638394,,,,,,,,,,,,,, -73200,0.6812974,0.9161371,,,,,,,,,,,,,, -73300,0.8989541,0.9379458,,,,,,,,,,,,,, -73400,0.9124473,0.9532056,,,,,,,,,,,,,, -73500,0.5842369,0.95594466,,,,,,,,,,,,,, -73600,0.9897033,0.996143,,,,,,,,,,,,,, -73700,0.9119383,0.9264968,,,,,,,,,,,,,, -73800,0.80122274,0.9480283,,,,,,,,,,,,,, -73900,1.6521684,0.94835705,,,,,,,,,,,,,, -74000,0.56530696,0.94475466,,,,,,,,,,,,,, -74100,0.6636167,0.96922386,,,,,,,,,,,,,, -74200,0.5627129,0.9522148,,,,,,,,,,,,,, -74300,1.2339942,0.92568594,,,,,,,,,,,,,, -74400,0.6615651,0.9474824,,,,,,,,,,,,,, -74500,0.7544107,0.9107511,,,,,,,,,,,,,, -74586,,,0.07089892,0.0265824493211326,0.3090179,0.0869304961526207,5348.0,0.16789155,0.0534397660106026,2472.0,59090.40604901314,64635.95995926857,59090.40604901314,5540.094013929367,2.298288345336914,0.0 -74600,0.8646049,0.92270714,,,,,,,,,,,,,, -74700,1.4234923,0.9764876,,,,,,,,,,,,,, -74800,0.8992937,0.9372461,,,,,,,,,,,,,, -74900,0.6864019,0.9332594,,,,,,,,,,,,,, -75000,0.8526871,0.9475464,,,,,,,,,,,,,, -75100,0.7072118,0.9904003,,,,,,,,,,,,,, -75200,1.0011678,0.9624683,,,,,,,,,,,,,, -75300,1.3005698,1.00307,,,,,,,,,,,,,, -75400,1.0929226,0.88394666,,,,,,,,,,,,,, -75500,0.76217943,0.9341695,,,,,,,,,,,,,, -75600,0.6399228,0.95153713,,,,,,,,,,,,,, -75700,0.7753552,0.9266425,,,,,,,,,,,,,, -75800,0.6506649,0.94802296,,,,,,,,,,,,,, -75900,0.6190934,0.91813934,,,,,,,,,,,,,, -76000,0.70140404,0.9280233,,,,,,,,,,,,,, -76100,1.878941,0.99388117,,,,,,,,,,,,,, -76200,0.90536225,0.90522975,,,,,,,,,,,,,, -76300,1.3876231,0.93824124,,,,,,,,,,,,,, -76400,,,0.072326496,0.0263592050993625,0.3084427,0.086988424070981,5348.0,0.16666134,0.0529116649401824,2472.0,60530.30942058563,66207.60625338554,60530.30942058563,5671.698009729385,2.359781265258789,0.0 -76400,0.98348707,0.9438509,,,,,,,,,,,,,, -76500,0.5672375,0.93976206,,,,,,,,,,,,,, -76600,0.69045603,0.92545635,,,,,,,,,,,,,, -76700,1.5265235,0.93457276,,,,,,,,,,,,,, -76800,1.1314038,0.903381,,,,,,,,,,,,,, -76900,1.2969674,0.92154914,,,,,,,,,,,,,, -77000,0.7169641,0.975021,,,,,,,,,,,,,, -77088,,,,,,,,,,,61068.57802128792,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 70f582663..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -241.28403306007385,0.0,59.86646819114685,1,0,59.86646819114685,30.942284,2472,3.698758962484512,301.15058636665344,31.83468,3.780360452283281,30.891884,5348,3.324937003388784 -349.2952125072479,0.0406079292297363,1500.21915102005,1861,0,1500.21915102005,5.865311,2472,0.8945016553937399,1849.6291027069087,6.043958,0.9349704434512164,5.9683733,5348,0.8930167894416714 -466.6468231678009,0.0785510540008544,2940.3601796627045,3701,0,2940.3601796627045,3.1674793,2472,0.685657993622164,3407.2356901168823,3.299461,0.7144394427722878,3.655602,5348,0.7556600403564497 -601.7314755916595,0.1203293800354003,4380.670278787613,5552,0,4380.670278787613,0.6076796,2472,0.1981394593057502,4982.7510821819305,0.5416918,0.1891877645471269,0.92900485,5348,0.2685924481303764 -737.0722260475159,0.1613924503326416,5821.245131254196,7387,0,5821.245131254196,0.5291419,2472,0.1748420774683647,6558.784062385559,0.49491432,0.1698677334562096,0.8487835,5348,0.2476708149492648 -870.5358927249908,0.2001349925994873,7261.33238196373,9209,0,7261.33238196373,0.47206378,2472,0.1556679462961834,8132.448884963989,0.42755538,0.1478614451281828,0.773474,5348,0.2234376357685586 -1003.0489454269408,0.2415390014648437,8701.565061092377,11028,0,8701.565061092377,0.4390526,2472,0.1433591290394654,9705.31336402893,0.42699566,0.1452166166091405,0.7293994,5348,0.2142077874431582 -1140.7337267398834,0.2815468311309814,10142.004707336426,12867,0,10142.004707336426,0.41152832,2472,0.1355391708813194,11283.554337739944,0.37451404,0.1283474783570501,0.6910725,5348,0.2017050117304034 -1275.1978507041931,0.3295059204101562,11581.955075740814,14702,0,11581.955075740814,0.39033836,2472,0.1276989011435419,12858.095179080963,0.3195103,0.1121714033126186,0.6546651,5348,0.1937109589966884 -1411.4254655838013,0.3693804740905761,13021.907640695572,16509,0,13021.907640695572,0.37663108,2472,0.1234944041598115,14434.39049911499,0.3045598,0.1061558022865472,0.63967645,5348,0.1875995636096817 -1543.3572795391085,0.4117932319641113,14462.00174665451,18338,0,14462.00174665451,0.35803446,2472,0.1172384376332947,16006.535148620604,0.30234045,0.1066837563751226,0.6237858,5348,0.1815557507940952 -1677.659893035889,0.4502491950988769,15902.158890008926,20161,0,15902.158890008926,0.3493241,2472,0.1138057806755631,17581.109792232513,0.3074392,0.1046140828683994,0.59997153,5348,0.1748554215704258 -1811.598342895508,0.4881389141082763,17342.512769937515,22000,0,17342.512769937515,0.3443247,2472,0.112688643795828,19155.51647043228,0.30920258,0.1061334318199804,0.5998876,5348,0.1756181391621692 -1946.2173926830287,0.5289163589477539,18782.41963362693,23814,0,18782.41963362693,0.32907978,2472,0.1080982267990981,20730.158828258514,0.28427473,0.0983331618479267,0.5757108,5348,0.1699605124689844 -2080.084167957306,0.5681560039520264,20222.704214334488,25630,0,20222.704214334488,0.3139176,2472,0.1031422013690004,22304.42474460601,0.2570293,0.0890956885266669,0.5563869,5348,0.1642739218166195 -2215.4049229621887,0.6067309379577637,21662.970312833782,27442,0,21662.970312833782,0.3053218,2472,0.0992830012390063,23880.126302480698,0.24057451,0.0856308197918568,0.54482305,5348,0.1604217152456626 -2350.835747003556,0.6452951431274414,23103.443352222443,29281,0,23103.443352222443,0.29165354,2472,0.0950581926756443,25456.14620923996,0.22545318,0.0781571222550654,0.5211732,5348,0.1521959508385066 -2484.500387430191,0.6883645057678223,24543.64266705513,31100,0,24543.64266705513,0.2836596,2472,0.0927629841772794,27030.12952065468,0.23854691,0.0826887242156352,0.50667065,5348,0.1498305608387962 -2617.611071109772,0.7320413589477539,25983.56226277352,32914,0,25983.56226277352,0.2696539,2472,0.0879085166453395,28603.279937028885,0.21479145,0.0734133120635489,0.4919465,5348,0.1437095107987294 -2752.476796388626,0.7754151821136475,27424.42139530182,34715,0,27424.42139530182,0.25954035,2472,0.0839071354579245,30179.125629663467,0.2172395,0.0744648608663937,0.47097364,5348,0.1380036108402444 -2886.515316724777,0.8175325393676758,28864.815721035004,36545,0,28864.815721035004,0.24819948,2472,0.0804338553409298,31753.677274227142,0.20283253,0.0670258000575466,0.4575766,5348,0.135174797493652 -3028.468215703964,0.8610215187072754,30305.21221017837,38366,0,30305.21221017837,0.2412294,2472,0.0784433205370381,33336.14661240578,0.15163435,0.0533193936225823,0.44698212,5348,0.1318342875348774 -3162.2286465168,0.9062769412994384,31749.13193511963,40171,0,31749.13193511963,0.23414363,2472,0.0759653078219893,34913.94862866402,0.16758403,0.0578076492576714,0.43268782,5348,0.1269973063517962 -3298.460347652436,0.9550228118896484,33189.04080152512,41994,0,33189.04080152512,0.22567298,2472,0.0731216866735726,36490.21577978134,0.20786917,0.0725796390815944,0.4219524,5348,0.1231064811685992 -3431.107417821884,1.0002148151397705,34629.64849734306,43820,0,34629.64849734306,0.22200587,2472,0.0724920277049946,38063.593044281006,0.21207216,0.0729444811449063,0.41716936,5348,0.1221120519034148 -3561.2571907043457,1.044391393661499,36070.08312439919,45655,0,36070.08312439919,0.21940964,2472,0.0717404992586273,39634.29805088043,0.24418886,0.0848722733414838,0.4126984,5348,0.1206348899852283 -3701.1584811210632,1.0890610218048096,37509.99685645104,47439,0,37509.99685645104,0.21931432,2472,0.0715373834623118,41214.23322463036,0.22175954,0.0753316180842984,0.41188765,5348,0.1202004305975264 -3845.009263277054,1.1332612037658691,37935.802810907364,48000,0,37935.802810907364,0.21926215,2472,0.07163894136046961,41783.96035575867,0.20823953,0.07322993330586886,0.41214904,5348,0.12032594108730703 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index bc268d60d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,24.13289,33.282764,,,,,,,,,,,,,, -1,,,31.83468,3.780360452283281,30.891884,3.324937003388784,5348.0,30.942284,3.698758962484512,2472.0,59.86646819114685,301.15058636665344,59.86646819114685,241.28403306007385,0.0,0.0 -100,6.973602,10.8380575,,,,,,,,,,,,,, -200,1.4295144,6.3394523,,,,,,,,,,,,,, -300,0.6353589,5.897638,,,,,,,,,,,,,, -400,0.52055746,5.834236,,,,,,,,,,,,,, -500,0.3048621,5.8047986,,,,,,,,,,,,,, -600,0.31098303,5.7930202,,,,,,,,,,,,,, -700,0.33709225,5.705035,,,,,,,,,,,,,, -800,0.4574441,5.566649,,,,,,,,,,,,,, -900,0.56544685,5.340004,,,,,,,,,,,,,, -1000,0.8599819,5.0257134,,,,,,,,,,,,,, -1100,1.5764257,4.680764,,,,,,,,,,,,,, -1200,0.9357449,4.3398323,,,,,,,,,,,,,, -1300,1.0200728,4.0965033,,,,,,,,,,,,,, -1400,2.6621175,3.867077,,,,,,,,,,,,,, -1500,2.401539,3.6621127,,,,,,,,,,,,,, -1600,3.4418342,3.4585989,,,,,,,,,,,,,, -1700,3.1921313,3.271159,,,,,,,,,,,,,, -1800,1.5977002,3.2036245,,,,,,,,,,,,,, -1861,,,6.043958,0.9349704434512164,5.9683733,0.8930167894416714,5348.0,5.865311,0.8945016553937399,2472.0,1500.21915102005,1849.6291027069087,1500.21915102005,349.2952125072479,0.0406079292297363,0.0 -1900,2.2438982,3.0609593,,,,,,,,,,,,,, -2000,2.6607933,2.898207,,,,,,,,,,,,,, -2100,2.8247802,2.83694,,,,,,,,,,,,,, -2200,2.0146852,2.7630672,,,,,,,,,,,,,, -2300,2.7160647,2.6953828,,,,,,,,,,,,,, -2400,2.6143997,2.5858886,,,,,,,,,,,,,, -2500,2.4099174,2.5147507,,,,,,,,,,,,,, -2600,3.648985,2.4714737,,,,,,,,,,,,,, -2700,2.9737203,2.4700394,,,,,,,,,,,,,, -2800,3.6829255,2.3740172,,,,,,,,,,,,,, -2900,2.7775607,2.3610303,,,,,,,,,,,,,, -3000,2.2790108,2.3439746,,,,,,,,,,,,,, -3100,4.628522,2.34178,,,,,,,,,,,,,, -3200,3.1453831,2.239221,,,,,,,,,,,,,, -3300,3.9571505,2.1795192,,,,,,,,,,,,,, -3400,3.160029,2.1884575,,,,,,,,,,,,,, -3500,3.0040052,2.1356199,,,,,,,,,,,,,, -3600,3.689927,2.1379633,,,,,,,,,,,,,, -3700,3.678203,2.125854,,,,,,,,,,,,,, -3701,,,3.299461,0.7144394427722878,3.655602,0.7556600403564497,5348.0,3.1674793,0.685657993622164,2472.0,2940.3601796627045,3407.2356901168823,2940.3601796627045,466.6468231678009,0.0785510540008544,0.0 -3800,3.9817753,2.099019,,,,,,,,,,,,,, -3900,2.9464025,2.052976,,,,,,,,,,,,,, -4000,4.28328,2.1370797,,,,,,,,,,,,,, -4100,2.9281068,2.0989957,,,,,,,,,,,,,, -4200,3.2085805,2.0267704,,,,,,,,,,,,,, -4300,4.944285,1.9949522,,,,,,,,,,,,,, -4400,3.022923,2.0343919,,,,,,,,,,,,,, -4500,3.1790664,1.9681007,,,,,,,,,,,,,, -4600,3.241639,1.961379,,,,,,,,,,,,,, -4700,3.1460454,1.9920585,,,,,,,,,,,,,, -4800,3.024657,2.022919,,,,,,,,,,,,,, -4900,3.387593,1.9290082,,,,,,,,,,,,,, -5000,3.2021916,1.9357326,,,,,,,,,,,,,, -5100,3.720296,2.0002322,,,,,,,,,,,,,, -5200,3.2302275,1.884342,,,,,,,,,,,,,, -5300,3.6086972,1.8948468,,,,,,,,,,,,,, -5400,3.422835,1.7877145,,,,,,,,,,,,,, -5500,4.0004115,1.8769194,,,,,,,,,,,,,, -5552,,,0.5416918,0.1891877645471269,0.92900485,0.2685924481303764,5348.0,0.6076796,0.1981394593057502,2472.0,4380.670278787613,4982.7510821819305,4380.670278787613,601.7314755916595,0.1203293800354003,0.0 -5600,2.4610236,1.859694,,,,,,,,,,,,,, -5700,3.3910763,1.8747615,,,,,,,,,,,,,, -5800,3.1289184,1.832739,,,,,,,,,,,,,, -5900,3.3198981,1.8548487,,,,,,,,,,,,,, -6000,2.9464295,1.8275541,,,,,,,,,,,,,, -6100,4.3591847,1.8152666,,,,,,,,,,,,,, -6200,3.852103,1.7930698,,,,,,,,,,,,,, -6300,2.3026302,1.7060262,,,,,,,,,,,,,, -6400,3.5555265,1.7426182,,,,,,,,,,,,,, -6500,3.7864072,1.7120099,,,,,,,,,,,,,, -6600,3.2836099,1.802182,,,,,,,,,,,,,, -6700,2.7911975,1.77556,,,,,,,,,,,,,, -6800,3.3075132,1.847653,,,,,,,,,,,,,, -6900,3.9087262,1.7932286,,,,,,,,,,,,,, -7000,4.2528934,1.7656518,,,,,,,,,,,,,, -7100,2.8977063,1.7727433,,,,,,,,,,,,,, -7200,4.3177505,1.7901231,,,,,,,,,,,,,, -7300,2.734324,1.6912214,,,,,,,,,,,,,, -7387,,,0.49491432,0.1698677334562096,0.8487835,0.2476708149492648,5348.0,0.5291419,0.1748420774683647,2472.0,5821.245131254196,6558.784062385559,5821.245131254196,737.0722260475159,0.1613924503326416,0.0 -7400,3.29401,1.7467654,,,,,,,,,,,,,, -7500,3.3900552,1.710855,,,,,,,,,,,,,, -7600,3.1690142,1.7463762,,,,,,,,,,,,,, -7700,3.655635,1.7607126,,,,,,,,,,,,,, -7800,2.0171971,1.7906481,,,,,,,,,,,,,, -7900,3.034415,1.7527783,,,,,,,,,,,,,, -8000,3.6789086,1.6759479,,,,,,,,,,,,,, -8100,3.6419141,1.7845278,,,,,,,,,,,,,, -8200,3.0667803,1.6940137,,,,,,,,,,,,,, -8300,4.2801714,1.7092613,,,,,,,,,,,,,, -8400,3.1853375,1.7143588,,,,,,,,,,,,,, -8500,4.4233136,1.7419808,,,,,,,,,,,,,, -8600,2.935029,1.7294879,,,,,,,,,,,,,, -8700,3.1901245,1.7068683,,,,,,,,,,,,,, -8800,3.7358522,1.6773779,,,,,,,,,,,,,, -8900,3.0560265,1.6631082,,,,,,,,,,,,,, -9000,3.0043085,1.7347589,,,,,,,,,,,,,, -9100,4.957992,1.7019014,,,,,,,,,,,,,, -9200,3.8453848,1.7122322,,,,,,,,,,,,,, -9209,,,0.42755538,0.1478614451281828,0.773474,0.2234376357685586,5348.0,0.47206378,0.1556679462961834,2472.0,7261.33238196373,8132.448884963989,7261.33238196373,870.5358927249908,0.2001349925994873,0.0 -9300,3.8430421,1.6607723,,,,,,,,,,,,,, -9400,2.6533096,1.7373104,,,,,,,,,,,,,, -9500,2.716733,1.6447812,,,,,,,,,,,,,, -9600,3.561364,1.6722008,,,,,,,,,,,,,, -9700,2.0770152,1.688592,,,,,,,,,,,,,, -9800,3.219944,1.6775602,,,,,,,,,,,,,, -9900,3.1616788,1.6667879,,,,,,,,,,,,,, -10000,2.6621103,1.7439855,,,,,,,,,,,,,, -10100,3.4155312,1.670157,,,,,,,,,,,,,, -10200,3.1725726,1.6386088,,,,,,,,,,,,,, -10300,3.459162,1.5657713,,,,,,,,,,,,,, -10400,3.3549826,1.6304387,,,,,,,,,,,,,, -10500,3.9872541,1.6178367,,,,,,,,,,,,,, -10600,4.585502,1.6257217,,,,,,,,,,,,,, -10700,2.756471,1.6443685,,,,,,,,,,,,,, -10800,3.8489883,1.6188676,,,,,,,,,,,,,, -10900,2.4714463,1.6437492,,,,,,,,,,,,,, -11000,3.2517931,1.6821113,,,,,,,,,,,,,, -11028,,,0.42699566,0.1452166166091405,0.7293994,0.2142077874431582,5348.0,0.4390526,0.1433591290394654,2472.0,8701.565061092377,9705.31336402893,8701.565061092377,1003.0489454269408,0.2415390014648437,0.0 -11100,4.01467,1.6571591,,,,,,,,,,,,,, -11200,3.0142796,1.6270539,,,,,,,,,,,,,, -11300,4.1744814,1.6744093,,,,,,,,,,,,,, -11400,2.6764166,1.6499733,,,,,,,,,,,,,, -11500,3.2579153,1.6418906,,,,,,,,,,,,,, -11600,2.828197,1.5847648,,,,,,,,,,,,,, -11700,3.4640224,1.6864779,,,,,,,,,,,,,, -11800,2.502814,1.5984104,,,,,,,,,,,,,, -11900,3.0988228,1.609288,,,,,,,,,,,,,, -12000,4.013948,1.577355,,,,,,,,,,,,,, -12100,3.7116416,1.5848235,,,,,,,,,,,,,, -12200,2.9700782,1.6000996,,,,,,,,,,,,,, -12300,3.6299832,1.5894387,,,,,,,,,,,,,, -12400,2.113777,1.6367091,,,,,,,,,,,,,, -12500,3.9742458,1.6207759,,,,,,,,,,,,,, -12600,4.710164,1.6127763,,,,,,,,,,,,,, -12700,2.6113536,1.6300362,,,,,,,,,,,,,, -12800,3.0205226,1.6683629,,,,,,,,,,,,,, -12867,,,0.37451404,0.1283474783570501,0.6910725,0.2017050117304034,5348.0,0.41152832,0.1355391708813194,2472.0,10142.004707336426,11283.554337739944,10142.004707336426,1140.7337267398834,0.2815468311309814,0.0 -12900,3.193888,1.5916692,,,,,,,,,,,,,, -13000,2.0617507,1.5779073,,,,,,,,,,,,,, -13100,3.2679994,1.616774,,,,,,,,,,,,,, -13200,4.341536,1.5989168,,,,,,,,,,,,,, -13300,5.6058984,1.6152914,,,,,,,,,,,,,, -13400,2.6400485,1.596776,,,,,,,,,,,,,, -13500,3.007507,1.5939021,,,,,,,,,,,,,, -13600,3.0234463,1.5918217,,,,,,,,,,,,,, -13700,2.982665,1.5157385,,,,,,,,,,,,,, -13800,2.7567008,1.5374776,,,,,,,,,,,,,, -13900,3.066675,1.5656152,,,,,,,,,,,,,, -14000,3.2709608,1.6831268,,,,,,,,,,,,,, -14100,3.6852129,1.6152977,,,,,,,,,,,,,, -14200,2.1111002,1.6368484,,,,,,,,,,,,,, -14300,4.5181375,1.5705848,,,,,,,,,,,,,, -14400,3.2139728,1.5273302,,,,,,,,,,,,,, -14500,3.1215374,1.5762472,,,,,,,,,,,,,, -14600,3.2738695,1.5633458,,,,,,,,,,,,,, -14700,2.8381906,1.5350475,,,,,,,,,,,,,, -14702,,,0.3195103,0.1121714033126186,0.6546651,0.1937109589966884,5348.0,0.39033836,0.1276989011435419,2472.0,11581.955075740814,12858.095179080963,11581.955075740814,1275.1978507041931,0.3295059204101562,0.0 -14800,2.0849993,1.5989709,,,,,,,,,,,,,, -14900,3.0576153,1.5000896,,,,,,,,,,,,,, -15000,3.1981683,1.5352912,,,,,,,,,,,,,, -15100,4.2891893,1.5626504,,,,,,,,,,,,,, -15200,2.8686843,1.6218512,,,,,,,,,,,,,, -15300,4.073192,1.5629153,,,,,,,,,,,,,, -15400,2.5904508,1.6026407,,,,,,,,,,,,,, -15500,2.7824311,1.5636376,,,,,,,,,,,,,, -15600,3.6931577,1.5773234,,,,,,,,,,,,,, -15700,2.1511092,1.5681078,,,,,,,,,,,,,, -15800,3.2677908,1.5786622,,,,,,,,,,,,,, -15900,2.2802918,1.539816,,,,,,,,,,,,,, -16000,3.0888526,1.5433836,,,,,,,,,,,,,, -16100,2.6277707,1.5109732,,,,,,,,,,,,,, -16200,2.979862,1.4631165,,,,,,,,,,,,,, -16300,3.292534,1.562672,,,,,,,,,,,,,, -16400,3.0155418,1.5539849,,,,,,,,,,,,,, -16500,3.288834,1.4650472,,,,,,,,,,,,,, -16509,,,0.3045598,0.1061558022865472,0.63967645,0.1875995636096817,5348.0,0.37663108,0.1234944041598115,2472.0,13021.907640695572,14434.39049911499,13021.907640695572,1411.4254655838013,0.3693804740905761,0.0 -16600,2.6318617,1.5331415,,,,,,,,,,,,,, -16700,3.444716,1.4815712,,,,,,,,,,,,,, -16800,2.1922657,1.5017262,,,,,,,,,,,,,, -16900,2.5465353,1.5431769,,,,,,,,,,,,,, -17000,3.076037,1.5123825,,,,,,,,,,,,,, -17100,2.6875587,1.5088954,,,,,,,,,,,,,, -17200,2.5665596,1.511739,,,,,,,,,,,,,, -17300,2.2775967,1.5805506,,,,,,,,,,,,,, -17400,2.3831325,1.5001475,,,,,,,,,,,,,, -17500,2.9475024,1.5717361,,,,,,,,,,,,,, -17600,2.911095,1.4910127,,,,,,,,,,,,,, -17700,2.871904,1.5119714,,,,,,,,,,,,,, -17800,2.724191,1.5436646,,,,,,,,,,,,,, -17900,3.8322632,1.5595896,,,,,,,,,,,,,, -18000,2.2196848,1.5042994,,,,,,,,,,,,,, -18100,2.3414924,1.4874818,,,,,,,,,,,,,, -18200,2.267945,1.4898275,,,,,,,,,,,,,, -18300,3.9020877,1.5138875,,,,,,,,,,,,,, -18338,,,0.30234045,0.1066837563751226,0.6237858,0.1815557507940952,5348.0,0.35803446,0.1172384376332947,2472.0,14462.00174665451,16006.535148620604,14462.00174665451,1543.3572795391085,0.4117932319641113,0.0 -18400,2.3993955,1.482369,,,,,,,,,,,,,, -18500,3.54379,1.5365727,,,,,,,,,,,,,, -18600,3.4609113,1.453685,,,,,,,,,,,,,, -18700,2.7809231,1.5212727,,,,,,,,,,,,,, -18800,3.0653431,1.5143199,,,,,,,,,,,,,, -18900,2.0638964,1.5279526,,,,,,,,,,,,,, -19000,2.6115103,1.4446975,,,,,,,,,,,,,, -19100,3.150009,1.5084879,,,,,,,,,,,,,, -19200,3.6216297,1.5805323,,,,,,,,,,,,,, -19300,3.1500847,1.537873,,,,,,,,,,,,,, -19400,2.9507616,1.4298391,,,,,,,,,,,,,, -19500,2.832706,1.4274203,,,,,,,,,,,,,, -19600,2.275482,1.5150516,,,,,,,,,,,,,, -19700,2.6174598,1.5110344,,,,,,,,,,,,,, -19800,3.140279,1.5371901,,,,,,,,,,,,,, -19900,2.7379954,1.4879978,,,,,,,,,,,,,, -20000,2.6624832,1.4982264,,,,,,,,,,,,,, -20100,1.9972464,1.4727379,,,,,,,,,,,,,, -20161,,,0.3074392,0.1046140828683994,0.59997153,0.1748554215704258,5348.0,0.3493241,0.1138057806755631,2472.0,15902.158890008926,17581.109792232513,15902.158890008926,1677.659893035889,0.4502491950988769,0.0 -20200,2.7835073,1.5171111,,,,,,,,,,,,,, -20300,2.63223,1.494586,,,,,,,,,,,,,, -20400,1.7389991,1.4779685,,,,,,,,,,,,,, -20500,2.0952716,1.4342324,,,,,,,,,,,,,, -20600,2.8207064,1.4848061,,,,,,,,,,,,,, -20700,2.7696705,1.5009515,,,,,,,,,,,,,, -20800,3.3860893,1.4544034,,,,,,,,,,,,,, -20900,2.4124389,1.4718546,,,,,,,,,,,,,, -21000,2.2685993,1.4946519,,,,,,,,,,,,,, -21100,3.0495245,1.4569666,,,,,,,,,,,,,, -21200,2.9160178,1.5235236,,,,,,,,,,,,,, -21300,3.3684309,1.4617968,,,,,,,,,,,,,, -21400,2.5815337,1.4461726,,,,,,,,,,,,,, -21500,2.9499967,1.4862776,,,,,,,,,,,,,, -21600,2.9064643,1.4619308,,,,,,,,,,,,,, -21700,2.7869937,1.3612231,,,,,,,,,,,,,, -21800,2.8432949,1.5362144,,,,,,,,,,,,,, -21900,2.3458335,1.4498721,,,,,,,,,,,,,, -22000,,,0.30920258,0.1061334318199804,0.5998876,0.1756181391621692,5348.0,0.3443247,0.112688643795828,2472.0,17342.512769937515,19155.51647043228,17342.512769937515,1811.598342895508,0.4881389141082763,0.0 -22000,2.6971884,1.4217861,,,,,,,,,,,,,, -22100,2.512123,1.4303519,,,,,,,,,,,,,, -22200,2.190637,1.4524258,,,,,,,,,,,,,, -22300,3.1112154,1.4806136,,,,,,,,,,,,,, -22400,2.567705,1.4828421,,,,,,,,,,,,,, -22500,2.4912825,1.4764006,,,,,,,,,,,,,, -22600,2.2713401,1.4413849,,,,,,,,,,,,,, -22700,2.4845583,1.4324908,,,,,,,,,,,,,, -22800,3.2088413,1.4460424,,,,,,,,,,,,,, -22900,3.3614545,1.4408464,,,,,,,,,,,,,, -23000,1.9969802,1.5032132,,,,,,,,,,,,,, -23100,3.1955545,1.3881044,,,,,,,,,,,,,, -23200,2.9938536,1.4593424,,,,,,,,,,,,,, -23300,2.7136922,1.4433365,,,,,,,,,,,,,, -23400,2.5363302,1.3992914,,,,,,,,,,,,,, -23500,3.067186,1.4541054,,,,,,,,,,,,,, -23600,4.3632474,1.4599589,,,,,,,,,,,,,, -23700,3.2512884,1.4065672,,,,,,,,,,,,,, -23800,3.5997818,1.3878895,,,,,,,,,,,,,, -23814,,,0.28427473,0.0983331618479267,0.5757108,0.1699605124689844,5348.0,0.32907978,0.1080982267990981,2472.0,18782.41963362693,20730.158828258514,18782.41963362693,1946.2173926830287,0.5289163589477539,0.0 -23900,2.3973153,1.3939278,,,,,,,,,,,,,, -24000,2.5650365,1.4119474,,,,,,,,,,,,,, -24100,2.7661529,1.4196569,,,,,,,,,,,,,, -24200,2.979572,1.4152455,,,,,,,,,,,,,, -24300,3.3047414,1.4827243,,,,,,,,,,,,,, -24400,2.6667922,1.3950077,,,,,,,,,,,,,, -24500,3.3783097,1.4644732,,,,,,,,,,,,,, -24600,3.6313171,1.4529916,,,,,,,,,,,,,, -24700,3.714221,1.4184749,,,,,,,,,,,,,, -24800,2.5370953,1.369153,,,,,,,,,,,,,, -24900,2.9118576,1.3811746,,,,,,,,,,,,,, -25000,3.162071,1.3588945,,,,,,,,,,,,,, -25100,2.6170342,1.4060844,,,,,,,,,,,,,, -25200,3.2212565,1.4707613,,,,,,,,,,,,,, -25300,2.6852672,1.4493536,,,,,,,,,,,,,, -25400,2.741864,1.3808637,,,,,,,,,,,,,, -25500,2.2977226,1.3889186,,,,,,,,,,,,,, -25600,3.4045348,1.4733505,,,,,,,,,,,,,, -25630,,,0.2570293,0.0890956885266669,0.5563869,0.1642739218166195,5348.0,0.3139176,0.1031422013690004,2472.0,20222.704214334488,22304.42474460601,20222.704214334488,2080.084167957306,0.5681560039520264,0.0 -25700,2.7583466,1.401469,,,,,,,,,,,,,, -25800,2.0853739,1.3729112,,,,,,,,,,,,,, -25900,3.259947,1.3826528,,,,,,,,,,,,,, -26000,2.7967985,1.4573171,,,,,,,,,,,,,, -26100,2.4734943,1.3749956,,,,,,,,,,,,,, -26200,2.225152,1.3682371,,,,,,,,,,,,,, -26300,2.5147734,1.4806741,,,,,,,,,,,,,, -26400,3.2768667,1.3776363,,,,,,,,,,,,,, -26500,2.5216012,1.3772968,,,,,,,,,,,,,, -26600,2.1706593,1.4136168,,,,,,,,,,,,,, -26700,2.963314,1.3125958,,,,,,,,,,,,,, -26800,3.7186532,1.4166425,,,,,,,,,,,,,, -26900,2.5366597,1.3493876,,,,,,,,,,,,,, -27000,2.556785,1.3760921,,,,,,,,,,,,,, -27100,3.4289207,1.425246,,,,,,,,,,,,,, -27200,2.4650896,1.3572261,,,,,,,,,,,,,, -27300,4.4042935,1.3678321,,,,,,,,,,,,,, -27400,3.9125266,1.3659981,,,,,,,,,,,,,, -27442,,,0.24057451,0.0856308197918568,0.54482305,0.1604217152456626,5348.0,0.3053218,0.0992830012390063,2472.0,21662.970312833782,23880.126302480698,21662.970312833782,2215.4049229621887,0.6067309379577637,0.0 -27500,2.5345442,1.3304733,,,,,,,,,,,,,, -27600,3.5715134,1.4205987,,,,,,,,,,,,,, -27700,2.455182,1.4091684,,,,,,,,,,,,,, -27800,3.3981583,1.3812399,,,,,,,,,,,,,, -27900,1.6916361,1.4316213,,,,,,,,,,,,,, -28000,2.3861933,1.3403898,,,,,,,,,,,,,, -28100,4.083364,1.3731931,,,,,,,,,,,,,, -28200,3.3647048,1.3500015,,,,,,,,,,,,,, -28300,2.261224,1.3182201,,,,,,,,,,,,,, -28400,2.7645872,1.4014376,,,,,,,,,,,,,, -28500,2.998744,1.4071507,,,,,,,,,,,,,, -28600,2.520063,1.3787609,,,,,,,,,,,,,, -28700,2.6840577,1.4245074,,,,,,,,,,,,,, -28800,2.5103006,1.4477297,,,,,,,,,,,,,, -28900,3.0541224,1.3608327,,,,,,,,,,,,,, -29000,2.4347785,1.2655668,,,,,,,,,,,,,, -29100,2.1921232,1.3295768,,,,,,,,,,,,,, -29200,2.9562347,1.4069437,,,,,,,,,,,,,, -29281,,,0.22545318,0.0781571222550654,0.5211732,0.1521959508385066,5348.0,0.29165354,0.0950581926756443,2472.0,23103.443352222443,25456.14620923996,23103.443352222443,2350.835747003556,0.6452951431274414,0.0 -29300,2.5184069,1.3391004,,,,,,,,,,,,,, -29400,2.3596103,1.3749908,,,,,,,,,,,,,, -29500,3.047933,1.4074633,,,,,,,,,,,,,, -29600,2.5685828,1.3633108,,,,,,,,,,,,,, -29700,3.256958,1.2951276,,,,,,,,,,,,,, -29800,3.3487015,1.3521671,,,,,,,,,,,,,, -29900,3.1395736,1.3170885,,,,,,,,,,,,,, -30000,2.064467,1.2913074,,,,,,,,,,,,,, -30100,2.0406907,1.3557726,,,,,,,,,,,,,, -30200,3.4211972,1.3968796,,,,,,,,,,,,,, -30300,2.645461,1.3469942,,,,,,,,,,,,,, -30400,2.8198555,1.3222158,,,,,,,,,,,,,, -30500,2.084583,1.3756716,,,,,,,,,,,,,, -30600,2.6836078,1.3613147,,,,,,,,,,,,,, -30700,3.1492503,1.3510199,,,,,,,,,,,,,, -30800,3.1129339,1.3444182,,,,,,,,,,,,,, -30900,2.0112326,1.291837,,,,,,,,,,,,,, -31000,2.6197746,1.3647738,,,,,,,,,,,,,, -31100,,,0.23854691,0.0826887242156352,0.50667065,0.1498305608387962,5348.0,0.2836596,0.0927629841772794,2472.0,24543.64266705513,27030.12952065468,24543.64266705513,2484.500387430191,0.6883645057678223,0.0 -31100,2.2456205,1.3434252,,,,,,,,,,,,,, -31200,3.7499192,1.3246487,,,,,,,,,,,,,, -31300,3.303486,1.2902806,,,,,,,,,,,,,, -31400,2.482566,1.2860986,,,,,,,,,,,,,, -31500,2.8062584,1.2959808,,,,,,,,,,,,,, -31600,3.0687804,1.3198464,,,,,,,,,,,,,, -31700,2.6205347,1.2866945,,,,,,,,,,,,,, -31800,2.7695043,1.3940793,,,,,,,,,,,,,, -31900,3.3537908,1.3289479,,,,,,,,,,,,,, -32000,3.0370338,1.3064377,,,,,,,,,,,,,, -32100,2.8210926,1.3255275,,,,,,,,,,,,,, -32200,2.2691882,1.301455,,,,,,,,,,,,,, -32300,3.188768,1.2732121,,,,,,,,,,,,,, -32400,2.85854,1.3092146,,,,,,,,,,,,,, -32500,3.2532303,1.281815,,,,,,,,,,,,,, -32600,2.5443223,1.3289334,,,,,,,,,,,,,, -32700,3.7823956,1.2477179,,,,,,,,,,,,,, -32800,3.362277,1.3231955,,,,,,,,,,,,,, -32900,2.2313635,1.3128257,,,,,,,,,,,,,, -32914,,,0.21479145,0.0734133120635489,0.4919465,0.1437095107987294,5348.0,0.2696539,0.0879085166453395,2472.0,25983.56226277352,28603.279937028885,25983.56226277352,2617.611071109772,0.7320413589477539,0.0 -33000,3.0084393,1.2763088,,,,,,,,,,,,,, -33100,2.1601708,1.2758037,,,,,,,,,,,,,, -33200,4.278338,1.2755165,,,,,,,,,,,,,, -33300,2.4315028,1.3002943,,,,,,,,,,,,,, -33400,2.0825307,1.289373,,,,,,,,,,,,,, -33500,2.6229224,1.2931854,,,,,,,,,,,,,, -33600,2.9173357,1.2734342,,,,,,,,,,,,,, -33700,2.448111,1.2715394,,,,,,,,,,,,,, -33800,4.2796655,1.2397808,,,,,,,,,,,,,, -33900,2.6884978,1.2735434,,,,,,,,,,,,,, -34000,2.2319207,1.2319055,,,,,,,,,,,,,, -34100,2.7995808,1.2794155,,,,,,,,,,,,,, -34200,2.4755328,1.2227793,,,,,,,,,,,,,, -34300,3.0285022,1.2575458,,,,,,,,,,,,,, -34400,2.7605424,1.2795682,,,,,,,,,,,,,, -34500,2.36891,1.2867653,,,,,,,,,,,,,, -34600,3.450152,1.2301985,,,,,,,,,,,,,, -34700,5.0839505,1.2677253,,,,,,,,,,,,,, -34715,,,0.2172395,0.0744648608663937,0.47097364,0.1380036108402444,5348.0,0.25954035,0.0839071354579245,2472.0,27424.42139530182,30179.125629663467,27424.42139530182,2752.476796388626,0.7754151821136475,0.0 -34800,2.6172886,1.2066782,,,,,,,,,,,,,, -34900,2.4037428,1.2979947,,,,,,,,,,,,,, -35000,4.287908,1.2795779,,,,,,,,,,,,,, -35100,3.2091289,1.2846328,,,,,,,,,,,,,, -35200,2.3846555,1.265364,,,,,,,,,,,,,, -35300,3.9658124,1.2079654,,,,,,,,,,,,,, -35400,2.5271807,1.2200536,,,,,,,,,,,,,, -35500,4.128968,1.2313416,,,,,,,,,,,,,, -35600,2.7223957,1.2541955,,,,,,,,,,,,,, -35700,2.720332,1.1810983,,,,,,,,,,,,,, -35800,2.9665015,1.2461935,,,,,,,,,,,,,, -35900,2.297688,1.2309633,,,,,,,,,,,,,, -36000,3.0828202,1.2684939,,,,,,,,,,,,,, -36100,2.3821177,1.2121577,,,,,,,,,,,,,, -36200,3.1373405,1.248768,,,,,,,,,,,,,, -36300,2.398777,1.2151288,,,,,,,,,,,,,, -36400,2.96082,1.2024634,,,,,,,,,,,,,, -36500,2.5553377,1.2146565,,,,,,,,,,,,,, -36545,,,0.20283253,0.0670258000575466,0.4575766,0.135174797493652,5348.0,0.24819948,0.0804338553409298,2472.0,28864.815721035004,31753.677274227142,28864.815721035004,2886.515316724777,0.8175325393676758,0.0 -36600,2.6310658,1.2503294,,,,,,,,,,,,,, -36700,3.1356032,1.2640923,,,,,,,,,,,,,, -36800,3.0827131,1.2109039,,,,,,,,,,,,,, -36900,3.2950861,1.2368562,,,,,,,,,,,,,, -37000,4.065084,1.2167964,,,,,,,,,,,,,, -37100,3.606573,1.2317067,,,,,,,,,,,,,, -37200,2.6393807,1.2080654,,,,,,,,,,,,,, -37300,2.5314639,1.1884066,,,,,,,,,,,,,, -37400,2.694804,1.2420362,,,,,,,,,,,,,, -37500,3.1847222,1.1908973,,,,,,,,,,,,,, -37600,2.1981342,1.2138081,,,,,,,,,,,,,, -37700,2.8428211,1.256136,,,,,,,,,,,,,, -37800,3.711492,1.202703,,,,,,,,,,,,,, -37900,2.5296144,1.267932,,,,,,,,,,,,,, -38000,3.5724666,1.221358,,,,,,,,,,,,,, -38100,2.8216283,1.1882465,,,,,,,,,,,,,, -38200,3.46357,1.1322311,,,,,,,,,,,,,, -38300,3.5634162,1.1862897,,,,,,,,,,,,,, -38366,,,0.15163435,0.0533193936225823,0.44698212,0.1318342875348774,5348.0,0.2412294,0.0784433205370381,2472.0,30305.21221017837,33336.14661240578,30305.21221017837,3028.468215703964,0.8610215187072754,0.0 -38400,2.5394928,1.1753994,,,,,,,,,,,,,, -38500,3.7295723,1.1664433,,,,,,,,,,,,,, -38600,3.296945,1.1789428,,,,,,,,,,,,,, -38700,3.235008,1.2744023,,,,,,,,,,,,,, -38800,2.469632,1.2033774,,,,,,,,,,,,,, -38900,3.3925242,1.2199464,,,,,,,,,,,,,, -39000,3.876519,1.2301155,,,,,,,,,,,,,, -39100,3.9710076,1.2295985,,,,,,,,,,,,,, -39200,3.6521358,1.1860884,,,,,,,,,,,,,, -39300,4.4939284,1.1449592,,,,,,,,,,,,,, -39400,2.8270295,1.1885052,,,,,,,,,,,,,, -39500,3.405498,1.1636751,,,,,,,,,,,,,, -39600,2.663614,1.1770017,,,,,,,,,,,,,, -39700,2.936143,1.179067,,,,,,,,,,,,,, -39800,1.8128529,1.1484358,,,,,,,,,,,,,, -39900,3.721826,1.1477844,,,,,,,,,,,,,, -40000,2.7979515,1.1253908,,,,,,,,,,,,,, -40100,2.8464565,1.191008,,,,,,,,,,,,,, -40171,,,0.16758403,0.0578076492576714,0.43268782,0.1269973063517962,5348.0,0.23414363,0.0759653078219893,2472.0,31749.13193511963,34913.94862866402,31749.13193511963,3162.2286465168,0.9062769412994384,0.0 -40200,2.4673784,1.069814,,,,,,,,,,,,,, -40300,2.8761578,1.1614839,,,,,,,,,,,,,, -40400,3.5099828,1.1552995,,,,,,,,,,,,,, -40500,3.3375185,1.16057,,,,,,,,,,,,,, -40600,2.506947,1.1710929,,,,,,,,,,,,,, -40700,2.4615612,1.1275724,,,,,,,,,,,,,, -40800,2.8191485,1.1609582,,,,,,,,,,,,,, -40900,4.053442,1.1797976,,,,,,,,,,,,,, -41000,3.4291937,1.176202,,,,,,,,,,,,,, -41100,3.004897,1.1893436,,,,,,,,,,,,,, -41200,3.738324,1.1448823,,,,,,,,,,,,,, -41300,3.2068465,1.1238248,,,,,,,,,,,,,, -41400,2.1777253,1.1414156,,,,,,,,,,,,,, -41500,2.5160286,1.189605,,,,,,,,,,,,,, -41600,2.7956061,1.156822,,,,,,,,,,,,,, -41700,2.9803145,1.1534566,,,,,,,,,,,,,, -41800,2.6562455,1.1538624,,,,,,,,,,,,,, -41900,4.3160553,1.1321009,,,,,,,,,,,,,, -41994,,,0.20786917,0.0725796390815944,0.4219524,0.1231064811685992,5348.0,0.22567298,0.0731216866735726,2472.0,33189.04080152512,36490.21577978134,33189.04080152512,3298.460347652436,0.9550228118896484,0.0 -42000,2.715022,1.182788,,,,,,,,,,,,,, -42100,2.2297735,1.1463073,,,,,,,,,,,,,, -42200,3.1873,1.1987741,,,,,,,,,,,,,, -42300,3.569842,1.1234453,,,,,,,,,,,,,, -42400,2.6995952,1.1381323,,,,,,,,,,,,,, -42500,3.010986,1.1023104,,,,,,,,,,,,,, -42600,4.7238593,1.1727233,,,,,,,,,,,,,, -42700,4.4900312,1.1361382,,,,,,,,,,,,,, -42800,2.8681734,1.0950992,,,,,,,,,,,,,, -42900,4.2437057,1.1595175,,,,,,,,,,,,,, -43000,2.8019607,1.1211711,,,,,,,,,,,,,, -43100,4.149897,1.1189724,,,,,,,,,,,,,, -43200,3.406498,1.1142746,,,,,,,,,,,,,, -43300,2.7651458,1.119174,,,,,,,,,,,,,, -43400,3.8791566,1.1138681,,,,,,,,,,,,,, -43500,3.2424023,1.1859207,,,,,,,,,,,,,, -43600,5.010179,1.1128272,,,,,,,,,,,,,, -43700,3.7573016,1.143314,,,,,,,,,,,,,, -43800,4.0501876,1.1112411,,,,,,,,,,,,,, -43820,,,0.21207216,0.0729444811449063,0.41716936,0.1221120519034148,5348.0,0.22200587,0.0724920277049946,2472.0,34629.64849734306,38063.593044281006,34629.64849734306,3431.107417821884,1.0002148151397705,0.0 -43900,3.4637153,1.1466298,,,,,,,,,,,,,, -44000,3.7646704,1.1497222,,,,,,,,,,,,,, -44100,4.7304,1.1061654,,,,,,,,,,,,,, -44200,3.020114,1.1054742,,,,,,,,,,,,,, -44300,3.0230353,1.1612114,,,,,,,,,,,,,, -44400,2.0640972,1.1085074,,,,,,,,,,,,,, -44500,3.7673395,1.1186049,,,,,,,,,,,,,, -44600,2.3176687,1.1438162,,,,,,,,,,,,,, -44700,3.6076398,1.1327741,,,,,,,,,,,,,, -44800,3.7180727,1.106222,,,,,,,,,,,,,, -44900,5.125507,1.0705179,,,,,,,,,,,,,, -45000,4.0635033,1.1676177,,,,,,,,,,,,,, -45100,2.469315,1.1868075,,,,,,,,,,,,,, -45200,3.7221358,1.1861917,,,,,,,,,,,,,, -45300,3.2949278,1.1288661,,,,,,,,,,,,,, -45400,2.9468412,1.1384239,,,,,,,,,,,,,, -45500,2.613156,1.0879667,,,,,,,,,,,,,, -45600,3.4136322,1.0918121,,,,,,,,,,,,,, -45655,,,0.24418886,0.0848722733414838,0.4126984,0.1206348899852283,5348.0,0.21940964,0.0717404992586273,2472.0,36070.08312439919,39634.29805088043,36070.08312439919,3561.2571907043457,1.044391393661499,0.0 -45700,2.7197149,1.1132656,,,,,,,,,,,,,, -45800,2.7942228,1.0570076,,,,,,,,,,,,,, -45900,2.4014244,1.0952396,,,,,,,,,,,,,, -46000,3.844466,1.1331536,,,,,,,,,,,,,, -46100,2.6171002,1.1025649,,,,,,,,,,,,,, -46200,2.462153,1.1037176,,,,,,,,,,,,,, -46300,3.5795245,1.077631,,,,,,,,,,,,,, -46400,3.0557687,1.0897788,,,,,,,,,,,,,, -46500,2.672965,1.1427515,,,,,,,,,,,,,, -46600,3.10515,1.0941223,,,,,,,,,,,,,, -46700,4.385889,1.138784,,,,,,,,,,,,,, -46800,3.2898815,1.123021,,,,,,,,,,,,,, -46900,2.6825855,1.0654675,,,,,,,,,,,,,, -47000,3.1793685,1.0853684,,,,,,,,,,,,,, -47100,3.8303347,1.1614951,,,,,,,,,,,,,, -47200,3.1214168,1.1221646,,,,,,,,,,,,,, -47300,4.1428647,1.1102482,,,,,,,,,,,,,, -47400,3.8519106,1.1310966,,,,,,,,,,,,,, -47439,,,0.22175954,0.0753316180842984,0.41188765,0.1202004305975264,5348.0,0.21931432,0.0715373834623118,2472.0,37509.99685645104,41214.23322463036,37509.99685645104,3701.1584811210632,1.0890610218048096,0.0 -47500,4.1843677,1.1115756,,,,,,,,,,,,,, -47600,2.811217,1.1058462,,,,,,,,,,,,,, -47700,2.5444515,1.0804869,,,,,,,,,,,,,, -47800,6.3306375,1.0557166,,,,,,,,,,,,,, -47900,4.9606314,1.1182593,,,,,,,,,,,,,, -48000,,,0.20823953,0.0732299333058688,0.41214904,0.120325941087307,5348.0,0.21926215,0.0716389413604696,2472.0,37935.80281090736,41783.96035575867,37935.80281090736,3845.009263277054,1.1332612037658691,0.0 -48000,,,,,,,,,,,37935.802810907364,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 50fea448b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -192.73394989967343,0.0,16.34257674217224,1,0,16.34257674217224,30.942331,2472,3.698840208803039,209.0766015052796,32.25542,3.4835125993757408,30.891924,5348,3.325159060409164 -307.60274863243103,0.0288844108581542,1456.8843939304352,1812,0,1456.8843939304352,6.340809,2472,0.899396746084943,1764.5923347473145,6.4745483,0.9413882809148034,6.38616,5348,0.8965214285024667 -429.35713052749634,0.06998872756958,2897.059493541717,3641,0,2897.059493541717,2.53634,2472,0.5975869843397721,3326.6420726776123,3.3249242,0.7415177804500945,3.0489736,5348,0.6717804145708024 -562.3317792415619,0.11525559425354,4337.671471118927,5468,0,4337.671471118927,0.56013215,2472,0.1844291430544553,4900.35649228096,0.7368655,0.238899315462307,0.8666504,5348,0.2528746729486276 -698.1108963489532,0.1686549186706543,5777.897728443146,7267,0,5777.897728443146,0.47255045,2472,0.1568053947555501,6476.492695569992,0.63280827,0.2065148378191856,0.77585036,5348,0.226729872462033 -832.2964282035828,0.2222280502319336,7217.773640632629,9075,0,7217.773640632629,0.42079455,2472,0.1399061605021022,8050.689073801041,0.5822703,0.1911793322115769,0.71564424,5348,0.2078743350357705 -967.5596957206726,0.2755579948425293,8657.661759138107,10886,0,8657.661759138107,0.3917676,2472,0.129851928584486,9625.975238323212,0.5224885,0.1712662956468705,0.66655123,5348,0.19406818115991 -1102.7841680049896,0.329695463180542,10097.6127679348,12704,0,10097.6127679348,0.36725986,2472,0.1205289135336055,11201.28601717949,0.47296602,0.1585457439415338,0.6403958,5348,0.1855720864670728 -1234.0966272354126,0.3797941207885742,11538.042315006256,14498,0,11538.042315006256,0.3434072,2472,0.1144760628034042,12773.159367084503,0.43832353,0.148441422116745,0.60433036,5348,0.1755312472846288 -1368.8134655952454,0.4296009540557861,12978.146248579023,16299,0,12978.146248579023,0.34683496,2472,0.1129526943310381,14348.110781908035,0.42478886,0.144852251930545,0.6073105,5348,0.1771242650395358 -1502.5155773162842,0.478834867477417,14418.65439915657,18099,0,14418.65439915657,0.3229248,2472,0.105945199358154,15922.452051639557,0.41303357,0.1393726834323064,0.57551175,5348,0.1683674947140774 -1637.3995015621183,0.5337679386138916,15858.754093170166,19912,0,15858.754093170166,0.31086192,2472,0.1017203907947921,17497.57366490364,0.37763608,0.127199749558373,0.54934216,5348,0.1614644177761472 -1770.5106179714203,0.5898373126983643,17298.95994949341,21698,0,17298.95994949341,0.30279443,2472,0.098531472792639,19071.02762985229,0.36806503,0.1274199439912334,0.5384376,5348,0.1559033376135628 -1903.7190117836,0.6465094089508057,18738.897290468216,23505,0,18738.897290468216,0.2927988,2472,0.0941035484329616,20644.312327861782,0.32012546,0.1089708434865694,0.52728623,5348,0.152302152022167 -2036.9786303043363,0.6984333992004395,20178.92725777626,25316,0,20178.92725777626,0.2822816,2472,0.0930676578717526,22217.73483490944,0.28483084,0.0978063244464834,0.5107498,5348,0.1476775732064068 -2170.779187440872,0.7506082057952881,21619.27502799034,27130,0,21619.27502799034,0.27239585,2472,0.0885787987731805,23792.0169081688,0.3214634,0.1132856159048059,0.5000843,5348,0.1443853365129324 -2302.746828556061,0.8038895130157471,23059.81148886681,28917,0,23059.81148886681,0.26738203,2472,0.0866288871285519,25364.654709339145,0.29604378,0.1013367463026166,0.4922317,5348,0.1427730094519053 -2436.196858882904,0.8572819232940674,24500.349231004715,30727,0,24500.349231004715,0.25862825,2472,0.0829321796356102,26938.776789188385,0.2781695,0.0952528094455918,0.48324794,5348,0.1397897216563522 -2570.2126417160034,0.9131293296813964,25940.43915891648,32535,0,25940.43915891648,0.25137788,2472,0.0823837669855584,28513.021070718765,0.24887384,0.0883252373592404,0.46815404,5348,0.1354934010446334 -2701.2522921562195,0.965153694152832,27380.701331615448,34360,0,27380.701331615448,0.24266852,2472,0.0781183352629334,30084.456929922104,0.23843347,0.0856507230255839,0.45768866,5348,0.1327418249225214 -2833.7774863243103,1.021730661392212,28821.154938220978,36156,0,28821.154938220978,0.23776175,2472,0.0775902341925131,31657.57952785492,0.22635128,0.0811560540930281,0.4513795,5348,0.1310329513308939 -2988.8153936862946,1.078145980834961,30261.48275089264,37981,0,30261.48275089264,0.2351409,2472,0.0758840615034631,33253.08378100395,0.13815072,0.0506929093691549,0.4403994,5348,0.1277503692904795 -3125.0316421985626,1.131788969039917,31702.03856253624,39793,0,31702.03856253624,0.23304227,2472,0.0750715983182012,34829.99032139778,0.13410744,0.0472328863906739,0.43591997,5348,0.1262152794539328 -3259.6550998687744,1.1871211528778076,33142.06616973877,41616,0,33142.06616973877,0.22871488,2472,0.0732029329920987,36404.78058648109,0.14266028,0.0509104363223129,0.4308139,5348,0.1242457302296842 -3396.0110619068146,1.2427010536193848,34582.3564324379,43415,0,34582.3564324379,0.2267602,2472,0.0727154550809416,37981.5632686615,0.124804124,0.0457206361131981,0.42841774,5348,0.1241009104337835 -3531.409448623657,1.2997477054595947,36022.44448566437,45201,0,36022.44448566437,0.22566003,2472,0.0728576361383624,39557.18790698052,0.13416459,0.0478005661856043,0.42583966,5348,0.1228844241482182 -3667.886235713959,1.3547301292419434,37463.04101443291,46991,0,37463.04101443291,0.22541308,2472,0.0722686003290475,41134.39881324768,0.12691087,0.0466410933142423,0.42552513,5348,0.1227589136584376 -3803.54567193985,1.418626070022583,38255.23158955574,48000,0,38255.23158955574,0.22542043,2472,0.07234984664757378,42062.36226296425,0.15317775,0.05205918963001116,0.4256426,5348,0.12272994969925756 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/measurements.csv deleted file mode 100644 index 111633c02..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.78763,32.913765,,,,,,,,,,,,,, -1,,,32.25542,3.4835125993757408,30.891924,3.325159060409164,5348.0,30.942331,3.698840208803039,2472.0,16.34257674217224,209.0766015052796,16.34257674217224,192.73394989967343,0.0,0.0 -100,4.4774933,7.2492824,,,,,,,,,,,,,, -200,1.3349208,5.9699445,,,,,,,,,,,,,, -300,0.7663013,5.8242435,,,,,,,,,,,,,, -400,0.6974452,5.7334127,,,,,,,,,,,,,, -500,1.2205541,5.678003,,,,,,,,,,,,,, -600,1.0400324,5.4724474,,,,,,,,,,,,,, -700,1.6835965,5.1179633,,,,,,,,,,,,,, -800,1.7046168,4.7520065,,,,,,,,,,,,,, -900,1.4046686,4.3127637,,,,,,,,,,,,,, -1000,2.7943096,4.0375648,,,,,,,,,,,,,, -1100,2.0079381,3.7787504,,,,,,,,,,,,,, -1200,1.814777,3.5943806,,,,,,,,,,,,,, -1300,3.5461576,3.4309013,,,,,,,,,,,,,, -1400,3.7367198,3.3179047,,,,,,,,,,,,,, -1500,2.4837136,3.2482386,,,,,,,,,,,,,, -1600,5.0074744,3.086563,,,,,,,,,,,,,, -1700,3.6230555,2.9639351,,,,,,,,,,,,,, -1800,2.8228085,2.86632,,,,,,,,,,,,,, -1812,,,6.4745483,0.9413882809148034,6.38616,0.8965214285024667,5348.0,6.340809,0.899396746084943,2472.0,1456.8843939304352,1764.5923347473145,1456.8843939304352,307.60274863243103,0.0288844108581542,0.0 -1900,2.2509425,2.8137703,,,,,,,,,,,,,, -2000,2.8940759,2.7320895,,,,,,,,,,,,,, -2100,1.9355856,2.593678,,,,,,,,,,,,,, -2200,2.5790052,2.5557854,,,,,,,,,,,,,, -2300,2.906076,2.5428007,,,,,,,,,,,,,, -2400,2.1730206,2.4053557,,,,,,,,,,,,,, -2500,2.731586,2.4284673,,,,,,,,,,,,,, -2600,2.443274,2.3378265,,,,,,,,,,,,,, -2700,2.3771074,2.3271246,,,,,,,,,,,,,, -2800,4.1734967,2.3115296,,,,,,,,,,,,,, -2900,2.3372679,2.2732255,,,,,,,,,,,,,, -3000,2.227771,2.1976898,,,,,,,,,,,,,, -3100,3.1249208,2.1815803,,,,,,,,,,,,,, -3200,3.167398,2.1689928,,,,,,,,,,,,,, -3300,3.988006,2.129016,,,,,,,,,,,,,, -3400,2.4032407,2.1411405,,,,,,,,,,,,,, -3500,5.723097,2.0866642,,,,,,,,,,,,,, -3600,2.9210892,2.058356,,,,,,,,,,,,,, -3641,,,3.3249242,0.7415177804500945,3.0489736,0.6717804145708024,5348.0,2.53634,0.5975869843397721,2472.0,2897.059493541717,3326.6420726776123,2897.059493541717,429.35713052749634,0.06998872756958,0.0 -3700,3.8312495,2.0300944,,,,,,,,,,,,,, -3800,1.9960449,1.9501117,,,,,,,,,,,,,, -3900,3.5684154,2.0301197,,,,,,,,,,,,,, -4000,2.5394678,2.0361395,,,,,,,,,,,,,, -4100,3.033378,1.969026,,,,,,,,,,,,,, -4200,2.4869704,1.9483844,,,,,,,,,,,,,, -4300,3.2625873,1.9561815,,,,,,,,,,,,,, -4400,3.5879679,1.9705051,,,,,,,,,,,,,, -4500,2.9486134,1.8689724,,,,,,,,,,,,,, -4600,2.7956579,1.9277416,,,,,,,,,,,,,, -4700,2.3894794,1.8656678,,,,,,,,,,,,,, -4800,3.045536,1.8618041,,,,,,,,,,,,,, -4900,3.4142184,1.8365002,,,,,,,,,,,,,, -5000,2.992753,1.8777643,,,,,,,,,,,,,, -5100,2.9889145,1.8094829,,,,,,,,,,,,,, -5200,3.5895486,1.7672251,,,,,,,,,,,,,, -5300,2.8685663,1.8134192,,,,,,,,,,,,,, -5400,3.0525854,1.7304639,,,,,,,,,,,,,, -5468,,,0.7368655,0.238899315462307,0.8666504,0.2528746729486276,5348.0,0.56013215,0.1844291430544553,2472.0,4337.671471118927,4900.35649228096,4337.671471118927,562.3317792415619,0.11525559425354,0.0 -5500,3.420125,1.742717,,,,,,,,,,,,,, -5600,5.0763946,1.8251541,,,,,,,,,,,,,, -5700,4.546139,1.7923913,,,,,,,,,,,,,, -5800,4.60231,1.8366966,,,,,,,,,,,,,, -5900,3.1438982,1.7859616,,,,,,,,,,,,,, -6000,3.500972,1.8049006,,,,,,,,,,,,,, -6100,3.09005,1.7355376,,,,,,,,,,,,,, -6200,2.5828216,1.6995144,,,,,,,,,,,,,, -6300,1.753697,1.6995361,,,,,,,,,,,,,, -6400,2.6382964,1.7380397,,,,,,,,,,,,,, -6500,2.8703253,1.7740304,,,,,,,,,,,,,, -6600,2.7492476,1.7139684,,,,,,,,,,,,,, -6700,3.353859,1.6888007,,,,,,,,,,,,,, -6800,2.288644,1.7070049,,,,,,,,,,,,,, -6900,2.6505826,1.704978,,,,,,,,,,,,,, -7000,2.4917333,1.6217847,,,,,,,,,,,,,, -7100,2.246232,1.648182,,,,,,,,,,,,,, -7200,1.7352729,1.7117652,,,,,,,,,,,,,, -7267,,,0.63280827,0.2065148378191856,0.77585036,0.226729872462033,5348.0,0.47255045,0.1568053947555501,2472.0,5777.897728443146,6476.492695569992,5777.897728443146,698.1108963489532,0.1686549186706543,0.0 -7300,1.9846826,1.6202906,,,,,,,,,,,,,, -7400,2.547655,1.6288592,,,,,,,,,,,,,, -7500,2.9117093,1.6328871,,,,,,,,,,,,,, -7600,2.4439614,1.656302,,,,,,,,,,,,,, -7700,2.483525,1.65456,,,,,,,,,,,,,, -7800,3.7297742,1.6928427,,,,,,,,,,,,,, -7900,3.0344527,1.6395848,,,,,,,,,,,,,, -8000,2.5956392,1.6775082,,,,,,,,,,,,,, -8100,4.627241,1.6570212,,,,,,,,,,,,,, -8200,2.5982502,1.690181,,,,,,,,,,,,,, -8300,5.0289354,1.6942756,,,,,,,,,,,,,, -8400,2.7173355,1.6504359,,,,,,,,,,,,,, -8500,3.0500455,1.6930513,,,,,,,,,,,,,, -8600,2.1839442,1.6516799,,,,,,,,,,,,,, -8700,2.5530772,1.6549404,,,,,,,,,,,,,, -8800,2.6962943,1.6445426,,,,,,,,,,,,,, -8900,2.1019318,1.6196337,,,,,,,,,,,,,, -9000,3.0093377,1.6288755,,,,,,,,,,,,,, -9075,,,0.5822703,0.1911793322115769,0.71564424,0.2078743350357705,5348.0,0.42079455,0.1399061605021022,2472.0,7217.773640632629,8050.689073801041,7217.773640632629,832.2964282035828,0.2222280502319336,0.0 -9100,2.3033981,1.6807352,,,,,,,,,,,,,, -9200,3.234028,1.5820404,,,,,,,,,,,,,, -9300,2.6465588,1.6072575,,,,,,,,,,,,,, -9400,2.3157856,1.6150616,,,,,,,,,,,,,, -9500,3.2909567,1.5318547,,,,,,,,,,,,,, -9600,2.900956,1.6483701,,,,,,,,,,,,,, -9700,2.5430696,1.5606229,,,,,,,,,,,,,, -9800,2.3875203,1.6192666,,,,,,,,,,,,,, -9900,3.2627394,1.5615594,,,,,,,,,,,,,, -10000,3.6557748,1.5799015,,,,,,,,,,,,,, -10100,4.0142655,1.570082,,,,,,,,,,,,,, -10200,4.5163083,1.617841,,,,,,,,,,,,,, -10300,2.6774156,1.5036432,,,,,,,,,,,,,, -10400,2.0314605,1.5556836,,,,,,,,,,,,,, -10500,2.478811,1.5842742,,,,,,,,,,,,,, -10600,7.764841,1.5808142,,,,,,,,,,,,,, -10700,2.5303538,1.5771611,,,,,,,,,,,,,, -10800,2.4700747,1.502601,,,,,,,,,,,,,, -10886,,,0.5224885,0.1712662956468705,0.66655123,0.19406818115991,5348.0,0.3917676,0.129851928584486,2472.0,8657.661759138107,9625.975238323212,8657.661759138107,967.5596957206726,0.2755579948425293,0.0 -10900,3.9247017,1.5694004,,,,,,,,,,,,,, -11000,2.259468,1.5609794,,,,,,,,,,,,,, -11100,1.9558823,1.4531418,,,,,,,,,,,,,, -11200,3.1739404,1.4964098,,,,,,,,,,,,,, -11300,2.5093822,1.5217556,,,,,,,,,,,,,, -11400,2.0887682,1.5211768,,,,,,,,,,,,,, -11500,2.9614103,1.5200183,,,,,,,,,,,,,, -11600,2.8275526,1.545657,,,,,,,,,,,,,, -11700,1.8187097,1.5577157,,,,,,,,,,,,,, -11800,2.8365002,1.5842986,,,,,,,,,,,,,, -11900,3.200417,1.613,,,,,,,,,,,,,, -12000,2.8142073,1.5172932,,,,,,,,,,,,,, -12100,3.4837344,1.447014,,,,,,,,,,,,,, -12200,2.6515675,1.5154862,,,,,,,,,,,,,, -12300,2.624823,1.5901347,,,,,,,,,,,,,, -12400,2.8812692,1.5277976,,,,,,,,,,,,,, -12500,2.3910477,1.5517874,,,,,,,,,,,,,, -12600,2.972629,1.5433819,,,,,,,,,,,,,, -12700,2.5660827,1.5028023,,,,,,,,,,,,,, -12704,,,0.47296602,0.1585457439415338,0.6403958,0.1855720864670728,5348.0,0.36725986,0.1205289135336055,2472.0,10097.6127679348,11201.28601717949,10097.6127679348,1102.7841680049896,0.329695463180542,0.0 -12800,5.1555533,1.4905111,,,,,,,,,,,,,, -12900,1.6572422,1.4932581,,,,,,,,,,,,,, -13000,2.0698392,1.5567439,,,,,,,,,,,,,, -13100,2.481153,1.5982813,,,,,,,,,,,,,, -13200,3.09921,1.4930303,,,,,,,,,,,,,, -13300,2.7066677,1.5392909,,,,,,,,,,,,,, -13400,3.6483238,1.4807351,,,,,,,,,,,,,, -13500,2.2602837,1.4720837,,,,,,,,,,,,,, -13600,2.544429,1.5193412,,,,,,,,,,,,,, -13700,1.8812248,1.4445794,,,,,,,,,,,,,, -13800,2.038808,1.4438094,,,,,,,,,,,,,, -13900,1.9310932,1.5072912,,,,,,,,,,,,,, -14000,2.7403646,1.5025936,,,,,,,,,,,,,, -14100,2.7627192,1.4567078,,,,,,,,,,,,,, -14200,2.6453907,1.4911069,,,,,,,,,,,,,, -14300,2.1838942,1.4991989,,,,,,,,,,,,,, -14400,4.5670104,1.4763968,,,,,,,,,,,,,, -14498,,,0.43832353,0.148441422116745,0.60433036,0.1755312472846288,5348.0,0.3434072,0.1144760628034042,2472.0,11538.042315006256,12773.159367084503,11538.042315006256,1234.0966272354126,0.3797941207885742,0.0 -14500,2.5772922,1.4849247,,,,,,,,,,,,,, -14600,4.1758122,1.4726443,,,,,,,,,,,,,, -14700,2.6362734,1.4801989,,,,,,,,,,,,,, -14800,2.7510662,1.5239469,,,,,,,,,,,,,, -14900,3.6535423,1.4230326,,,,,,,,,,,,,, -15000,3.5201342,1.4244206,,,,,,,,,,,,,, -15100,2.549618,1.4906424,,,,,,,,,,,,,, -15200,2.473958,1.4426857,,,,,,,,,,,,,, -15300,3.382513,1.4415565,,,,,,,,,,,,,, -15400,3.04604,1.5103061,,,,,,,,,,,,,, -15500,2.226354,1.494233,,,,,,,,,,,,,, -15600,2.5359392,1.4347786,,,,,,,,,,,,,, -15700,2.193418,1.5101528,,,,,,,,,,,,,, -15800,2.713631,1.4846092,,,,,,,,,,,,,, -15900,1.9313406,1.49574,,,,,,,,,,,,,, -16000,2.091552,1.4894185,,,,,,,,,,,,,, -16100,2.7771518,1.4875818,,,,,,,,,,,,,, -16200,2.3597012,1.4488392,,,,,,,,,,,,,, -16299,,,0.42478886,0.144852251930545,0.6073105,0.1771242650395358,5348.0,0.34683496,0.1129526943310381,2472.0,12978.146248579023,14348.110781908035,12978.146248579023,1368.8134655952454,0.4296009540557861,0.0 -16300,2.4412184,1.5606024,,,,,,,,,,,,,, -16400,2.7535572,1.393798,,,,,,,,,,,,,, -16500,3.5274725,1.4138056,,,,,,,,,,,,,, -16600,1.9609637,1.4367876,,,,,,,,,,,,,, -16700,2.2480407,1.4071124,,,,,,,,,,,,,, -16800,2.036733,1.4411527,,,,,,,,,,,,,, -16900,2.6197288,1.4371232,,,,,,,,,,,,,, -17000,2.4969382,1.419488,,,,,,,,,,,,,, -17100,2.1540153,1.4609417,,,,,,,,,,,,,, -17200,1.7534873,1.3986961,,,,,,,,,,,,,, -17300,3.8188088,1.4367738,,,,,,,,,,,,,, -17400,2.6435897,1.4287271,,,,,,,,,,,,,, -17500,3.630775,1.4365761,,,,,,,,,,,,,, -17600,2.506574,1.3964562,,,,,,,,,,,,,, -17700,2.3764043,1.4277171,,,,,,,,,,,,,, -17800,2.092467,1.4831823,,,,,,,,,,,,,, -17900,2.2817411,1.4669124,,,,,,,,,,,,,, -18000,2.0819483,1.3683827,,,,,,,,,,,,,, -18099,,,0.41303357,0.1393726834323064,0.57551175,0.1683674947140774,5348.0,0.3229248,0.105945199358154,2472.0,14418.65439915657,15922.452051639557,14418.65439915657,1502.5155773162842,0.478834867477417,0.0 -18100,3.0825794,1.4139452,,,,,,,,,,,,,, -18200,2.1380584,1.4165651,,,,,,,,,,,,,, -18300,1.8297,1.4453721,,,,,,,,,,,,,, -18400,3.0842223,1.398817,,,,,,,,,,,,,, -18500,3.1667447,1.3875602,,,,,,,,,,,,,, -18600,3.8010087,1.3733215,,,,,,,,,,,,,, -18700,2.8633032,1.413715,,,,,,,,,,,,,, -18800,1.8663814,1.3510292,,,,,,,,,,,,,, -18900,2.0064158,1.3572179,,,,,,,,,,,,,, -19000,3.1938,1.3794377,,,,,,,,,,,,,, -19100,2.1990356,1.4039582,,,,,,,,,,,,,, -19200,2.213252,1.4518203,,,,,,,,,,,,,, -19300,2.0994203,1.3572131,,,,,,,,,,,,,, -19400,2.108618,1.4196265,,,,,,,,,,,,,, -19500,3.0521033,1.357278,,,,,,,,,,,,,, -19600,2.6516335,1.3940939,,,,,,,,,,,,,, -19700,2.3004813,1.4349021,,,,,,,,,,,,,, -19800,2.5968645,1.3967568,,,,,,,,,,,,,, -19900,2.1151555,1.3934339,,,,,,,,,,,,,, -19912,,,0.37763608,0.127199749558373,0.54934216,0.1614644177761472,5348.0,0.31086192,0.1017203907947921,2472.0,15858.754093170166,17497.57366490364,15858.754093170166,1637.3995015621183,0.5337679386138916,0.0 -20000,2.5187156,1.3730743,,,,,,,,,,,,,, -20100,3.5232086,1.3710836,,,,,,,,,,,,,, -20200,1.9735963,1.3895521,,,,,,,,,,,,,, -20300,2.3040643,1.3871655,,,,,,,,,,,,,, -20400,3.4932423,1.3395474,,,,,,,,,,,,,, -20500,4.0514765,1.342814,,,,,,,,,,,,,, -20600,2.2287107,1.3473519,,,,,,,,,,,,,, -20700,2.522094,1.3242118,,,,,,,,,,,,,, -20800,2.4351068,1.3731008,,,,,,,,,,,,,, -20900,2.211717,1.3645817,,,,,,,,,,,,,, -21000,1.9816672,1.3777344,,,,,,,,,,,,,, -21100,2.1821818,1.3864462,,,,,,,,,,,,,, -21200,2.2473888,1.3966819,,,,,,,,,,,,,, -21300,1.9304173,1.3934717,,,,,,,,,,,,,, -21400,2.7974584,1.4030423,,,,,,,,,,,,,, -21500,1.6304425,1.3888038,,,,,,,,,,,,,, -21600,3.2908988,1.366768,,,,,,,,,,,,,, -21698,,,0.36806503,0.1274199439912334,0.5384376,0.1559033376135628,5348.0,0.30279443,0.098531472792639,2472.0,17298.95994949341,19071.02762985229,17298.95994949341,1770.5106179714203,0.5898373126983643,0.0 -21700,2.6697254,1.3282948,,,,,,,,,,,,,, -21800,3.3975246,1.4406682,,,,,,,,,,,,,, -21900,2.2200608,1.344935,,,,,,,,,,,,,, -22000,2.6450882,1.3448251,,,,,,,,,,,,,, -22100,3.230626,1.3039978,,,,,,,,,,,,,, -22200,2.0931315,1.3191013,,,,,,,,,,,,,, -22300,4.2503853,1.3747733,,,,,,,,,,,,,, -22400,3.1089327,1.3622265,,,,,,,,,,,,,, -22500,2.0316968,1.3837355,,,,,,,,,,,,,, -22600,1.7239909,1.3814272,,,,,,,,,,,,,, -22700,2.420113,1.3171387,,,,,,,,,,,,,, -22800,2.6787472,1.3250034,,,,,,,,,,,,,, -22900,1.8909793,1.3289926,,,,,,,,,,,,,, -23000,2.4385076,1.3896157,,,,,,,,,,,,,, -23100,1.8941182,1.28519,,,,,,,,,,,,,, -23200,2.0497801,1.2808133,,,,,,,,,,,,,, -23300,2.0657651,1.3604945,,,,,,,,,,,,,, -23400,2.3214734,1.3027382,,,,,,,,,,,,,, -23500,2.185044,1.3039154,,,,,,,,,,,,,, -23505,,,0.32012546,0.1089708434865694,0.52728623,0.152302152022167,5348.0,0.2927988,0.0941035484329616,2472.0,18738.897290468216,20644.312327861782,18738.897290468216,1903.7190117836,0.6465094089508057,0.0 -23600,2.064898,1.2938837,,,,,,,,,,,,,, -23700,1.9924587,1.3249409,,,,,,,,,,,,,, -23800,2.2640688,1.2740283,,,,,,,,,,,,,, -23900,2.6360989,1.314354,,,,,,,,,,,,,, -24000,2.5016732,1.3768048,,,,,,,,,,,,,, -24100,2.3598843,1.2795949,,,,,,,,,,,,,, -24200,2.2908144,1.3542455,,,,,,,,,,,,,, -24300,2.5886166,1.3670201,,,,,,,,,,,,,, -24400,1.848461,1.2969255,,,,,,,,,,,,,, -24500,4.312531,1.3381038,,,,,,,,,,,,,, -24600,4.1511774,1.3111702,,,,,,,,,,,,,, -24700,2.0887294,1.2906119,,,,,,,,,,,,,, -24800,2.6052823,1.3893256,,,,,,,,,,,,,, -24900,2.0414937,1.3362398,,,,,,,,,,,,,, -25000,2.9907136,1.280356,,,,,,,,,,,,,, -25100,1.9021877,1.3220075,,,,,,,,,,,,,, -25200,2.1594903,1.3714929,,,,,,,,,,,,,, -25300,2.4150167,1.3509185,,,,,,,,,,,,,, -25316,,,0.28483084,0.0978063244464834,0.5107498,0.1476775732064068,5348.0,0.2822816,0.0930676578717526,2472.0,20178.92725777626,22217.73483490944,20178.92725777626,2036.9786303043363,0.6984333992004395,0.0 -25400,2.6260755,1.3170123,,,,,,,,,,,,,, -25500,3.3421905,1.3211538,,,,,,,,,,,,,, -25600,3.526776,1.3329298,,,,,,,,,,,,,, -25700,3.2314186,1.3306417,,,,,,,,,,,,,, -25800,1.653644,1.3201473,,,,,,,,,,,,,, -25900,2.404283,1.301958,,,,,,,,,,,,,, -26000,2.5434804,1.3494122,,,,,,,,,,,,,, -26100,1.7111071,1.3116468,,,,,,,,,,,,,, -26200,1.8404249,1.2861341,,,,,,,,,,,,,, -26300,2.1484072,1.3124818,,,,,,,,,,,,,, -26400,1.3898453,1.3374335,,,,,,,,,,,,,, -26500,1.9526825,1.2969719,,,,,,,,,,,,,, -26600,2.9733336,1.3027667,,,,,,,,,,,,,, -26700,3.2882178,1.3198304,,,,,,,,,,,,,, -26800,1.7968477,1.2580109,,,,,,,,,,,,,, -26900,2.396325,1.2801135,,,,,,,,,,,,,, -27000,2.006491,1.2927437,,,,,,,,,,,,,, -27100,2.31905,1.28891,,,,,,,,,,,,,, -27130,,,0.3214634,0.1132856159048059,0.5000843,0.1443853365129324,5348.0,0.27239585,0.0885787987731805,2472.0,21619.27502799034,23792.0169081688,21619.27502799034,2170.779187440872,0.7506082057952881,0.0 -27200,1.9587927,1.2112231,,,,,,,,,,,,,, -27300,1.9936885,1.2615575,,,,,,,,,,,,,, -27400,4.2525883,1.2807431,,,,,,,,,,,,,, -27500,1.9075942,1.3411835,,,,,,,,,,,,,, -27600,3.3530788,1.3171381,,,,,,,,,,,,,, -27700,2.6436923,1.2676532,,,,,,,,,,,,,, -27800,1.7301478,1.3269604,,,,,,,,,,,,,, -27900,2.429313,1.3083919,,,,,,,,,,,,,, -28000,1.9072136,1.2591352,,,,,,,,,,,,,, -28100,2.2532468,1.2637221,,,,,,,,,,,,,, -28200,2.3907757,1.2891952,,,,,,,,,,,,,, -28300,1.8600276,1.2206057,,,,,,,,,,,,,, -28400,3.0006561,1.2798612,,,,,,,,,,,,,, -28500,2.7456043,1.2657183,,,,,,,,,,,,,, -28600,2.8239453,1.333205,,,,,,,,,,,,,, -28700,3.2886379,1.3277344,,,,,,,,,,,,,, -28800,3.5931697,1.337908,,,,,,,,,,,,,, -28900,2.25613,1.2958183,,,,,,,,,,,,,, -28917,,,0.29604378,0.1013367463026166,0.4922317,0.1427730094519053,5348.0,0.26738203,0.0866288871285519,2472.0,23059.81148886681,25364.654709339145,23059.81148886681,2302.746828556061,0.8038895130157471,0.0 -29000,1.8743668,1.2604773,,,,,,,,,,,,,, -29100,3.1739879,1.2616891,,,,,,,,,,,,,, -29200,2.341958,1.2824537,,,,,,,,,,,,,, -29300,4.184691,1.2745438,,,,,,,,,,,,,, -29400,1.3962052,1.2461772,,,,,,,,,,,,,, -29500,1.7398034,1.2759571,,,,,,,,,,,,,, -29600,5.2272935,1.2596942,,,,,,,,,,,,,, -29700,2.7934122,1.2540041,,,,,,,,,,,,,, -29800,2.415404,1.1776129,,,,,,,,,,,,,, -29900,3.1491804,1.2467363,,,,,,,,,,,,,, -30000,4.9470577,1.1983153,,,,,,,,,,,,,, -30100,2.207154,1.2364696,,,,,,,,,,,,,, -30200,2.4903932,1.2777675,,,,,,,,,,,,,, -30300,1.7220417,1.247342,,,,,,,,,,,,,, -30400,2.9669504,1.2148094,,,,,,,,,,,,,, -30500,1.9986665,1.2872682,,,,,,,,,,,,,, -30600,2.0743616,1.2583822,,,,,,,,,,,,,, -30700,2.3887708,1.294175,,,,,,,,,,,,,, -30727,,,0.2781695,0.0952528094455918,0.48324794,0.1397897216563522,5348.0,0.25862825,0.0829321796356102,2472.0,24500.349231004715,26938.776789188385,24500.349231004715,2436.196858882904,0.8572819232940674,0.0 -30800,1.910796,1.2468239,,,,,,,,,,,,,, -30900,2.0474508,1.2380093,,,,,,,,,,,,,, -31000,3.3475227,1.2734472,,,,,,,,,,,,,, -31100,1.8778102,1.2324755,,,,,,,,,,,,,, -31200,2.401359,1.1858331,,,,,,,,,,,,,, -31300,2.6616719,1.2596244,,,,,,,,,,,,,, -31400,1.9396768,1.2359949,,,,,,,,,,,,,, -31500,2.655327,1.1778469,,,,,,,,,,,,,, -31600,2.1909459,1.2386774,,,,,,,,,,,,,, -31700,3.2371628,1.2275294,,,,,,,,,,,,,, -31800,1.966151,1.2626181,,,,,,,,,,,,,, -31900,2.180745,1.1560946,,,,,,,,,,,,,, -32000,2.1663933,1.2058476,,,,,,,,,,,,,, -32100,3.0655742,1.2372869,,,,,,,,,,,,,, -32200,1.9955612,1.1937207,,,,,,,,,,,,,, -32300,2.9905822,1.2499198,,,,,,,,,,,,,, -32400,2.2749307,1.2726705,,,,,,,,,,,,,, -32500,2.5147533,1.2370372,,,,,,,,,,,,,, -32535,,,0.24887384,0.0883252373592404,0.46815404,0.1354934010446334,5348.0,0.25137788,0.0823837669855584,2472.0,25940.43915891648,28513.021070718765,25940.43915891648,2570.2126417160034,0.9131293296813964,0.0 -32600,2.3547022,1.2369031,,,,,,,,,,,,,, -32700,2.3779364,1.2514741,,,,,,,,,,,,,, -32800,2.8747063,1.2059364,,,,,,,,,,,,,, -32900,3.4333425,1.2005785,,,,,,,,,,,,,, -33000,2.294537,1.1799369,,,,,,,,,,,,,, -33100,2.3867579,1.2621167,,,,,,,,,,,,,, -33200,1.9458594,1.2295648,,,,,,,,,,,,,, -33300,2.3463469,1.2592301,,,,,,,,,,,,,, -33400,2.5953243,1.2091588,,,,,,,,,,,,,, -33500,3.042392,1.2309467,,,,,,,,,,,,,, -33600,2.7753956,1.2333854,,,,,,,,,,,,,, -33700,2.5299873,1.2073534,,,,,,,,,,,,,, -33800,1.7268276,1.1707598,,,,,,,,,,,,,, -33900,1.7414007,1.165849,,,,,,,,,,,,,, -34000,2.2352636,1.2055873,,,,,,,,,,,,,, -34100,2.9568002,1.2609068,,,,,,,,,,,,,, -34200,2.8192892,1.1855336,,,,,,,,,,,,,, -34300,2.3410747,1.2451236,,,,,,,,,,,,,, -34360,,,0.23843347,0.0856507230255839,0.45768866,0.1327418249225214,5348.0,0.24266852,0.0781183352629334,2472.0,27380.701331615448,30084.456929922104,27380.701331615448,2701.2522921562195,0.965153694152832,0.0 -34400,3.412038,1.1937033,,,,,,,,,,,,,, -34500,2.550304,1.186504,,,,,,,,,,,,,, -34600,2.13714,1.2058493,,,,,,,,,,,,,, -34700,2.5047126,1.1898122,,,,,,,,,,,,,, -34800,2.2013042,1.1742618,,,,,,,,,,,,,, -34900,2.9545665,1.2479941,,,,,,,,,,,,,, -35000,2.730922,1.2196094,,,,,,,,,,,,,, -35100,1.8577342,1.2060758,,,,,,,,,,,,,, -35200,2.3643644,1.1543674,,,,,,,,,,,,,, -35300,2.2400734,1.1409662,,,,,,,,,,,,,, -35400,2.7991698,1.1407396,,,,,,,,,,,,,, -35500,3.0792975,1.211035,,,,,,,,,,,,,, -35600,1.9640996,1.162506,,,,,,,,,,,,,, -35700,2.6374516,1.1915683,,,,,,,,,,,,,, -35800,1.7224112,1.1326818,,,,,,,,,,,,,, -35900,1.9186693,1.1754991,,,,,,,,,,,,,, -36000,2.7415287,1.1885959,,,,,,,,,,,,,, -36100,1.7507392,1.1602198,,,,,,,,,,,,,, -36156,,,0.22635128,0.0811560540930281,0.4513795,0.1310329513308939,5348.0,0.23776175,0.0775902341925131,2472.0,28821.154938220978,31657.57952785492,28821.154938220978,2833.7774863243103,1.021730661392212,0.0 -36200,2.6861856,1.1717975,,,,,,,,,,,,,, -36300,2.6874804,1.1631193,,,,,,,,,,,,,, -36400,2.8600566,1.1978353,,,,,,,,,,,,,, -36500,2.6765745,1.1801562,,,,,,,,,,,,,, -36600,2.0191388,1.2006468,,,,,,,,,,,,,, -36700,2.5787387,1.152749,,,,,,,,,,,,,, -36800,2.0800552,1.1231458,,,,,,,,,,,,,, -36900,1.98818,1.125346,,,,,,,,,,,,,, -37000,2.6534672,1.1793804,,,,,,,,,,,,,, -37100,2.4143186,1.1662266,,,,,,,,,,,,,, -37200,2.4640417,1.1678263,,,,,,,,,,,,,, -37300,2.5064194,1.1839254,,,,,,,,,,,,,, -37400,1.928022,1.1746335,,,,,,,,,,,,,, -37500,3.5614414,1.1261997,,,,,,,,,,,,,, -37600,4.564933,1.1961573,,,,,,,,,,,,,, -37700,2.2035773,1.2247454,,,,,,,,,,,,,, -37800,5.3258343,1.1518518,,,,,,,,,,,,,, -37900,2.1552231,1.1485349,,,,,,,,,,,,,, -37981,,,0.13815072,0.0506929093691549,0.4403994,0.1277503692904795,5348.0,0.2351409,0.0758840615034631,2472.0,30261.48275089264,33253.08378100395,30261.48275089264,2988.8153936862946,1.078145980834961,0.0 -38000,2.8406672,1.1942657,,,,,,,,,,,,,, -38100,1.8016934,1.1572556,,,,,,,,,,,,,, -38200,2.2191331,1.1758374,,,,,,,,,,,,,, -38300,3.6824167,1.1792935,,,,,,,,,,,,,, -38400,2.1979253,1.2254528,,,,,,,,,,,,,, -38500,3.4643314,1.1610154,,,,,,,,,,,,,, -38600,2.0001216,1.1789502,,,,,,,,,,,,,, -38700,2.4066832,1.1617045,,,,,,,,,,,,,, -38800,1.9592751,1.1479583,,,,,,,,,,,,,, -38900,1.9016949,1.1789737,,,,,,,,,,,,,, -39000,3.2922106,1.1729615,,,,,,,,,,,,,, -39100,2.3913286,1.1393021,,,,,,,,,,,,,, -39200,3.874015,1.1629971,,,,,,,,,,,,,, -39300,1.9434277,1.1034985,,,,,,,,,,,,,, -39400,3.2225528,1.1592932,,,,,,,,,,,,,, -39500,2.7643292,1.1364964,,,,,,,,,,,,,, -39600,2.9059525,1.192086,,,,,,,,,,,,,, -39700,2.2064538,1.1355044,,,,,,,,,,,,,, -39793,,,0.13410744,0.0472328863906739,0.43591997,0.1262152794539328,5348.0,0.23304227,0.0750715983182012,2472.0,31702.03856253624,34829.99032139778,31702.03856253624,3125.0316421985626,1.131788969039917,0.0 -39800,2.907356,1.1973203,,,,,,,,,,,,,, -39900,2.048091,1.1628382,,,,,,,,,,,,,, -40000,1.8332583,1.0969931,,,,,,,,,,,,,, -40100,1.9743787,1.1224185,,,,,,,,,,,,,, -40200,3.8556669,1.123217,,,,,,,,,,,,,, -40300,2.0305433,1.1036106,,,,,,,,,,,,,, -40400,3.7781525,1.1739081,,,,,,,,,,,,,, -40500,2.9727829,1.1549926,,,,,,,,,,,,,, -40600,1.9829494,1.1526601,,,,,,,,,,,,,, -40700,1.8708662,1.1439856,,,,,,,,,,,,,, -40800,3.0030763,1.1044081,,,,,,,,,,,,,, -40900,2.3224614,1.158938,,,,,,,,,,,,,, -41000,2.0908546,1.1112221,,,,,,,,,,,,,, -41100,2.6537352,1.119181,,,,,,,,,,,,,, -41200,3.242029,1.1515832,,,,,,,,,,,,,, -41300,2.1310277,1.1096698,,,,,,,,,,,,,, -41400,3.503109,1.1461954,,,,,,,,,,,,,, -41500,2.3090842,1.1375383,,,,,,,,,,,,,, -41600,2.3289502,1.1331338,,,,,,,,,,,,,, -41616,,,0.14266028,0.0509104363223129,0.4308139,0.1242457302296842,5348.0,0.22871488,0.0732029329920987,2472.0,33142.06616973877,36404.78058648109,33142.06616973877,3259.6550998687744,1.1871211528778076,0.0 -41700,2.3562572,1.0695429,,,,,,,,,,,,,, -41800,2.4146912,1.0782968,,,,,,,,,,,,,, -41900,2.1856053,1.1255399,,,,,,,,,,,,,, -42000,1.9609665,1.1601455,,,,,,,,,,,,,, -42100,2.5934772,1.1480985,,,,,,,,,,,,,, -42200,2.035999,1.1535699,,,,,,,,,,,,,, -42300,2.441737,1.1397434,,,,,,,,,,,,,, -42400,3.1845162,1.1164265,,,,,,,,,,,,,, -42500,2.0572827,1.0754514,,,,,,,,,,,,,, -42600,2.724914,1.1899692,,,,,,,,,,,,,, -42700,2.9430764,1.1248411,,,,,,,,,,,,,, -42800,2.0197914,1.1150839,,,,,,,,,,,,,, -42900,4.532193,1.1305729,,,,,,,,,,,,,, -43000,3.1726007,1.1038998,,,,,,,,,,,,,, -43100,2.3670802,1.141449,,,,,,,,,,,,,, -43200,2.4524057,1.1750777,,,,,,,,,,,,,, -43300,2.5557218,1.1497719,,,,,,,,,,,,,, -43400,3.4587088,1.1287198,,,,,,,,,,,,,, -43415,,,0.124804124,0.0457206361131981,0.42841774,0.1241009104337835,5348.0,0.2267602,0.0727154550809416,2472.0,34582.3564324379,37981.5632686615,34582.3564324379,3396.0110619068146,1.2427010536193848,0.0 -43500,2.3521197,1.1165701,,,,,,,,,,,,,, -43600,2.6033816,1.149223,,,,,,,,,,,,,, -43700,2.229865,1.1397841,,,,,,,,,,,,,, -43800,4.970867,1.1402665,,,,,,,,,,,,,, -43900,2.4651563,1.1284676,,,,,,,,,,,,,, -44000,1.8197672,1.1287363,,,,,,,,,,,,,, -44100,1.6448603,1.1159264,,,,,,,,,,,,,, -44200,2.4253285,1.0376023,,,,,,,,,,,,,, -44300,2.0130863,1.1797125,,,,,,,,,,,,,, -44400,2.1034737,1.1336151,,,,,,,,,,,,,, -44500,6.711744,1.1030027,,,,,,,,,,,,,, -44600,2.7482302,1.142451,,,,,,,,,,,,,, -44700,3.9299943,1.1189927,,,,,,,,,,,,,, -44800,2.6591473,1.1222547,,,,,,,,,,,,,, -44900,2.8166592,1.1347935,,,,,,,,,,,,,, -45000,5.0245275,1.1558092,,,,,,,,,,,,,, -45100,2.0316222,1.120518,,,,,,,,,,,,,, -45200,3.0444736,1.0907016,,,,,,,,,,,,,, -45201,,,0.13416459,0.0478005661856043,0.42583966,0.1228844241482182,5348.0,0.22566003,0.0728576361383624,2472.0,36022.44448566437,39557.18790698052,36022.44448566437,3531.409448623657,1.2997477054595947,0.0 -45300,3.5277488,1.0696641,,,,,,,,,,,,,, -45400,1.8897532,1.1023448,,,,,,,,,,,,,, -45500,2.4502912,1.0669147,,,,,,,,,,,,,, -45600,2.3011503,1.0700334,,,,,,,,,,,,,, -45700,2.9809291,1.0913196,,,,,,,,,,,,,, -45800,1.858084,1.0616235,,,,,,,,,,,,,, -45900,1.8326764,1.1172221,,,,,,,,,,,,,, -46000,3.506964,1.1065161,,,,,,,,,,,,,, -46100,2.9190447,1.0937767,,,,,,,,,,,,,, -46200,2.3075707,1.1550014,,,,,,,,,,,,,, -46300,2.6082766,1.1076019,,,,,,,,,,,,,, -46400,2.4970536,1.1083391,,,,,,,,,,,,,, -46500,1.8709967,1.1402348,,,,,,,,,,,,,, -46600,2.313605,1.061643,,,,,,,,,,,,,, -46700,3.741661,1.1237872,,,,,,,,,,,,,, -46800,2.7368557,1.1187967,,,,,,,,,,,,,, -46900,2.082361,1.0612543,,,,,,,,,,,,,, -46991,,,0.12691087,0.0466410933142423,0.42552513,0.1227589136584376,5348.0,0.22541308,0.0722686003290475,2472.0,37463.04101443291,41134.39881324768,37463.04101443291,3667.886235713959,1.3547301292419434,0.0 -47000,2.5369205,1.1377493,,,,,,,,,,,,,, -47100,2.0934603,1.1525052,,,,,,,,,,,,,, -47200,3.7920523,1.1348448,,,,,,,,,,,,,, -47300,1.760273,1.0908843,,,,,,,,,,,,,, -47400,5.635401,1.153533,,,,,,,,,,,,,, -47500,4.741837,1.1389966,,,,,,,,,,,,,, -47600,2.4403315,1.2013888,,,,,,,,,,,,,, -47700,2.4218986,1.0489776,,,,,,,,,,,,,, -47800,2.2872188,1.1193806,,,,,,,,,,,,,, -47900,2.7132015,1.0743778,,,,,,,,,,,,,, -48000,,,0.15317775,0.0520591896300111,0.4256426,0.1227299496992575,5348.0,0.22542043,0.0723498466475737,2472.0,38255.23158955574,42062.36226296425,38255.23158955574,3803.54567193985,1.418626070022583,0.0 -48000,,,,,,,,,,,38255.23158955574,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/eval_measurements.csv deleted file mode 100644 index 80904b51d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -197.8555688858032,0.0,15.847177743911743,1,0,15.847177743911743,30.942295,2472,3.6990636361789857,213.7028248310089,30.26521,3.632442748091603,30.89189,5348,3.325081823184684 -308.7532677650452,0.0294642448425292,1456.7463419437408,1800,0,1456.7463419437408,5.869004,2472,0.898990514492312,1765.605096578598,5.9749413,0.9422576974075736,5.959475,5348,0.8962607528698456 -426.3852117061615,0.0845613479614257,2897.1611421108246,3622,0,2897.1611421108246,3.7333565,2472,0.7959092478622063,3323.78942489624,3.9224856,0.8321735789451216,4.115441,5348,0.8269982718171023 -561.1081030368805,0.1388754844665527,4337.0904133319855,5435,0,4337.0904133319855,0.6223996,2472,0.2067515690695265,4898.577420711517,0.58773357,0.2038345540715209,0.9542844,5348,0.2757272367417477 -698.7054476737976,0.1892588138580322,5777.671191692352,7216,0,5777.671191692352,0.5297743,2472,0.1731155931996831,6476.886586666107,0.47603443,0.1632538994348039,0.834246,5348,0.2414628730316576 -834.2313735485077,0.2436678409576416,7218.265148639679,9030,0,7218.265148639679,0.4683264,2472,0.1521540430199256,8053.141635894775,0.42759815,0.1447462888274172,0.7569138,5348,0.2205895131158461 -971.4258334636688,0.2975211143493652,8658.311223983765,10847,0,8658.311223983765,0.4358419,2472,0.1405358194706802,9630.516981840134,0.40683183,0.1396856103906852,0.72036105,5348,0.2091584038927561 -1107.0503504276276,0.3526680469512939,10098.84746146202,12664,0,10098.84746146202,0.40882304,2472,0.1320862023439563,11206.81583738327,0.36989608,0.1255429628348816,0.67882633,5348,0.1983934657308089 -1242.6684551239014,0.4081716537475586,11539.406084775925,14437,0,11539.406084775925,0.3992169,2472,0.1300347328011699,12783.127084732056,0.33032325,0.1163172018916966,0.66513646,5348,0.1929965146702453 -1381.285228252411,0.4616940021514892,12979.301619768145,16237,0,12979.301619768145,0.38088343,2472,0.1229459915097597,14361.773697376251,0.31750143,0.1080931915393774,0.641495,5348,0.1859099993241743 -1517.154412984848,0.5125668048858643,14419.182766199112,18055,0,14419.182766199112,0.36559522,2472,0.1184165092519245,15937.654762268066,0.33247074,0.1144936161004111,0.62319475,5348,0.1816040240593954 -1652.0571603775024,0.5724399089813232,15859.548187494278,19871,0,15859.548187494278,0.35727885,2472,0.1170353218369792,17513.06376671791,0.3052099,0.1027504344368119,0.6104939,5348,0.1763712021008525 -1789.0675375461578,0.6251442432403564,17299.65816307068,21649,0,17299.65816307068,0.3434821,2472,0.1111043405845672,19090.31603384018,0.28996828,0.1009951780034882,0.59253216,5348,0.1723065931625747 -1926.1365442276,0.6778745651245117,18740.54448866844,23449,0,18740.54448866844,0.3419359,2472,0.111429325858672,20668.40568065644,0.32587433,0.1060233787506514,0.5905726,5348,0.1707425393668478 -2063.8816010952,0.741168737411499,20180.9346408844,25254,0,20180.9346408844,0.32006806,2472,0.1041984035098409,22246.685710906982,0.23896928,0.0836136349652357,0.5691403,5348,0.1654421348368846 -2201.445280790329,0.8065252304077148,21621.25813126564,27049,0,21621.25813126564,0.31684002,2472,0.1034874982227367,23824.7185087204,0.25387847,0.0880248004307663,0.55523735,5348,0.1612520154088262 -2335.782987356186,0.8676490783691406,23061.37408900261,28819,0,23061.37408900261,0.29585266,2472,0.0954441126886437,25399.31268310547,0.32222232,0.1100863983116055,0.52448064,5348,0.1545034129198567 -2471.308205604553,0.9225730895996094,24501.35665678978,30598,0,24501.35665678978,0.28533491,2472,0.0923973757439116,26974.95587038994,0.33865708,0.1133199141616622,0.5113746,5348,0.1491161165123531 -2604.052140712738,0.9797863960266112,25941.70943188668,32416,0,25941.70943188668,0.2788384,2472,0.090264659882599,28548.19150686264,0.3645432,0.1253538611472062,0.5027272,5348,0.1458142251658186 -2737.744610309601,1.035611629486084,27381.691277742382,34215,0,27381.691277742382,0.2685194,2472,0.0848617797006073,30122.00229978561,0.31370944,0.10446648667684,0.4884356,5348,0.1406682950848161 -2872.178744316101,1.0904130935668943,28822.284260749817,35991,0,28822.284260749817,0.25631005,2472,0.0828915564763471,31697.164457798004,0.2759164,0.0957748694924111,0.46581116,5348,0.1357830406364347 -3007.717676639557,1.1477704048156738,30262.1793999672,37763,0,30262.1793999672,0.25000796,2472,0.0806166595576138,33272.73663830757,0.2330637,0.0813926453015253,0.4498272,5348,0.1304247081881112 -3142.482138156891,1.2042455673217771,31702.49013137817,39574,0,31702.49013137817,0.24149479,2472,0.0769808868035667,34847.95053982735,0.26204503,0.0910593137147968,0.4422166,5348,0.1287254892495438 -3278.741602897644,1.27134108543396,33142.4312608242,41352,0,33142.4312608242,0.23231012,2472,0.0752340909552535,36424.297860622406,0.22383782,0.0781360282322746,0.43122214,5348,0.1246898442704461 -3413.544237613678,1.331352949142456,34582.384875535965,43138,0,34582.384875535965,0.2263191,2472,0.0726951435013101,37999.1936173439,0.21453248,0.0768669198206493,0.42354012,5348,0.122614093862537 -3549.495146036148,1.3847966194152832,36022.423419713974,44926,0,36022.423419713974,0.2226167,2472,0.0722889119086791,39575.31591534615,0.21462,0.0751011324643432,0.41773412,5348,0.1206348899852283 -3682.490786552429,1.443239688873291,37462.40724658966,46737,0,37462.40724658966,0.2220608,2472,0.0719842382142059,41148.4344394207,0.22253259,0.0786585891067788,0.41605195,5348,0.1204514515770875 -3818.4000244140625,1.5066735744476318,38467.08563065529,48000,0,38467.08563065529,0.22198159,2472,0.071801433997522,42289.14387321472,0.23057689,0.07855449653186469,0.41583607,5348,0.12014250267916622 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/measurements.csv deleted file mode 100644 index 2b05063b0..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.000578,32.458633,,,,,,,,,,,,,, -1,,,30.26521,3.632442748091603,30.89189,3.325081823184684,5348.0,30.942295,3.6990636361789857,2472.0,15.847177743911743,213.7028248310089,15.847177743911743,197.8555688858032,0.0,0.0 -100,7.249695,10.579131,,,,,,,,,,,,,, -200,1.3991257,6.3177543,,,,,,,,,,,,,, -300,0.39116654,5.9199615,,,,,,,,,,,,,, -400,0.6269756,5.8376145,,,,,,,,,,,,,, -500,0.37736058,5.802742,,,,,,,,,,,,,, -600,0.34086668,5.783569,,,,,,,,,,,,,, -700,0.807894,5.7137446,,,,,,,,,,,,,, -800,0.5243511,5.588911,,,,,,,,,,,,,, -900,1.1410357,5.3543816,,,,,,,,,,,,,, -1000,0.59418446,5.0449142,,,,,,,,,,,,,, -1100,0.8398036,4.7288275,,,,,,,,,,,,,, -1200,0.8062568,4.3640385,,,,,,,,,,,,,, -1300,1.3668331,4.1074986,,,,,,,,,,,,,, -1400,1.3311298,3.848743,,,,,,,,,,,,,, -1500,2.103476,3.7081118,,,,,,,,,,,,,, -1600,2.348966,3.4250753,,,,,,,,,,,,,, -1700,2.2550116,3.2696147,,,,,,,,,,,,,, -1800,,,5.9749413,0.9422576974075736,5.959475,0.8962607528698456,5348.0,5.869004,0.898990514492312,2472.0,1456.7463419437408,1765.605096578598,1456.7463419437408,308.7532677650452,0.0294642448425292,0.0 -1800,2.9647417,3.1508555,,,,,,,,,,,,,, -1900,2.1865757,3.07186,,,,,,,,,,,,,, -2000,2.5842664,2.897386,,,,,,,,,,,,,, -2100,2.080323,2.8147805,,,,,,,,,,,,,, -2200,2.991465,2.760753,,,,,,,,,,,,,, -2300,2.658125,2.6928654,,,,,,,,,,,,,, -2400,2.4039876,2.591531,,,,,,,,,,,,,, -2500,2.7702074,2.557922,,,,,,,,,,,,,, -2600,2.4617057,2.4956481,,,,,,,,,,,,,, -2700,2.394612,2.4706056,,,,,,,,,,,,,, -2800,2.9394045,2.3780599,,,,,,,,,,,,,, -2900,3.1548126,2.310826,,,,,,,,,,,,,, -3000,2.7881613,2.291325,,,,,,,,,,,,,, -3100,3.2104542,2.2446651,,,,,,,,,,,,,, -3200,4.0704546,2.2359896,,,,,,,,,,,,,, -3300,3.3717859,2.1845617,,,,,,,,,,,,,, -3400,2.9781241,2.18731,,,,,,,,,,,,,, -3500,4.691005,2.1332474,,,,,,,,,,,,,, -3600,3.2097888,2.1219618,,,,,,,,,,,,,, -3622,,,3.9224856,0.8321735789451216,4.115441,0.8269982718171023,5348.0,3.7333565,0.7959092478622063,2472.0,2897.1611421108246,3323.78942489624,2897.1611421108246,426.3852117061615,0.0845613479614257,0.0 -3700,4.863398,2.1370249,,,,,,,,,,,,,, -3800,4.418569,2.0391622,,,,,,,,,,,,,, -3900,3.5450726,2.0385923,,,,,,,,,,,,,, -4000,2.9565234,2.1233118,,,,,,,,,,,,,, -4100,3.5074823,2.0480895,,,,,,,,,,,,,, -4200,3.9650688,2.038919,,,,,,,,,,,,,, -4300,2.4120347,2.0291574,,,,,,,,,,,,,, -4400,4.565085,2.0482724,,,,,,,,,,,,,, -4500,3.2011235,2.0028255,,,,,,,,,,,,,, -4600,2.9578874,2.0263183,,,,,,,,,,,,,, -4700,4.3959293,2.0617156,,,,,,,,,,,,,, -4800,3.9221265,1.9913028,,,,,,,,,,,,,, -4900,2.8772178,1.9770505,,,,,,,,,,,,,, -5000,3.8799367,1.930402,,,,,,,,,,,,,, -5100,3.238692,1.989314,,,,,,,,,,,,,, -5200,3.5099688,1.9223719,,,,,,,,,,,,,, -5300,3.6107485,1.9222814,,,,,,,,,,,,,, -5400,3.8399076,1.8629148,,,,,,,,,,,,,, -5435,,,0.58773357,0.2038345540715209,0.9542844,0.2757272367417477,5348.0,0.6223996,0.2067515690695265,2472.0,4337.0904133319855,4898.577420711517,4337.0904133319855,561.1081030368805,0.1388754844665527,0.0 -5500,3.523066,1.8043928,,,,,,,,,,,,,, -5600,3.1463444,1.8665112,,,,,,,,,,,,,, -5700,3.3814216,1.8913904,,,,,,,,,,,,,, -5800,2.9276066,1.8280536,,,,,,,,,,,,,, -5900,3.3606083,1.8195162,,,,,,,,,,,,,, -6000,3.111603,1.8321143,,,,,,,,,,,,,, -6100,7.5817847,1.806492,,,,,,,,,,,,,, -6200,3.6534212,1.8122681,,,,,,,,,,,,,, -6300,2.9897158,1.775582,,,,,,,,,,,,,, -6400,2.687231,1.7538056,,,,,,,,,,,,,, -6500,3.522685,1.7329001,,,,,,,,,,,,,, -6600,2.1044319,1.7692481,,,,,,,,,,,,,, -6700,2.943892,1.7080618,,,,,,,,,,,,,, -6800,4.64072,1.8032559,,,,,,,,,,,,,, -6900,2.9589322,1.7628129,,,,,,,,,,,,,, -7000,2.8684661,1.7756318,,,,,,,,,,,,,, -7100,3.9693313,1.7850727,,,,,,,,,,,,,, -7200,2.8698874,1.7783008,,,,,,,,,,,,,, -7216,,,0.47603443,0.1632538994348039,0.834246,0.2414628730316576,5348.0,0.5297743,0.1731155931996831,2472.0,5777.671191692352,6476.886586666107,5777.671191692352,698.7054476737976,0.1892588138580322,0.0 -7300,3.6803856,1.7692052,,,,,,,,,,,,,, -7400,3.2029588,1.7160448,,,,,,,,,,,,,, -7500,3.1443,1.8175445,,,,,,,,,,,,,, -7600,2.733222,1.7106104,,,,,,,,,,,,,, -7700,3.09327,1.669265,,,,,,,,,,,,,, -7800,2.6500769,1.7789745,,,,,,,,,,,,,, -7900,2.7053308,1.6916384,,,,,,,,,,,,,, -8000,3.7120583,1.6923766,,,,,,,,,,,,,, -8100,3.0427089,1.6795549,,,,,,,,,,,,,, -8200,2.877501,1.6877365,,,,,,,,,,,,,, -8300,3.6507614,1.7179914,,,,,,,,,,,,,, -8400,3.3538451,1.6753796,,,,,,,,,,,,,, -8500,3.2128925,1.698726,,,,,,,,,,,,,, -8600,2.2489204,1.6749727,,,,,,,,,,,,,, -8700,2.8503788,1.659698,,,,,,,,,,,,,, -8800,2.5211434,1.6634237,,,,,,,,,,,,,, -8900,2.2911298,1.6817322,,,,,,,,,,,,,, -9000,3.2555413,1.6615136,,,,,,,,,,,,,, -9030,,,0.42759815,0.1447462888274172,0.7569138,0.2205895131158461,5348.0,0.4683264,0.1521540430199256,2472.0,7218.265148639679,8053.141635894775,7218.265148639679,834.2313735485077,0.2436678409576416,0.0 -9100,2.9760451,1.7006474,,,,,,,,,,,,,, -9200,2.3072932,1.6624373,,,,,,,,,,,,,, -9300,3.9871416,1.6262535,,,,,,,,,,,,,, -9400,4.301894,1.6520513,,,,,,,,,,,,,, -9500,3.4338558,1.6265136,,,,,,,,,,,,,, -9600,3.52109,1.6413696,,,,,,,,,,,,,, -9700,2.9861064,1.6755103,,,,,,,,,,,,,, -9800,3.3711867,1.6809046,,,,,,,,,,,,,, -9900,2.6493435,1.6393096,,,,,,,,,,,,,, -10000,2.992992,1.6429162,,,,,,,,,,,,,, -10100,3.6914515,1.645543,,,,,,,,,,,,,, -10200,2.4327874,1.5744014,,,,,,,,,,,,,, -10300,2.545717,1.5936891,,,,,,,,,,,,,, -10400,3.4786336,1.6063871,,,,,,,,,,,,,, -10500,2.387031,1.6408054,,,,,,,,,,,,,, -10600,2.187433,1.6417289,,,,,,,,,,,,,, -10700,2.7547126,1.6074417,,,,,,,,,,,,,, -10800,2.7091835,1.6188158,,,,,,,,,,,,,, -10847,,,0.40683183,0.1396856103906852,0.72036105,0.2091584038927561,5348.0,0.4358419,0.1405358194706802,2472.0,8658.311223983765,9630.516981840134,8658.311223983765,971.4258334636688,0.2975211143493652,0.0 -10900,2.2483897,1.61425,,,,,,,,,,,,,, -11000,2.8435771,1.6444625,,,,,,,,,,,,,, -11100,3.1511455,1.5829959,,,,,,,,,,,,,, -11200,2.1352644,1.6223304,,,,,,,,,,,,,, -11300,1.8672657,1.648546,,,,,,,,,,,,,, -11400,4.9033976,1.6029553,,,,,,,,,,,,,, -11500,3.8832154,1.5359682,,,,,,,,,,,,,, -11600,2.4215908,1.5535357,,,,,,,,,,,,,, -11700,5.382492,1.6237195,,,,,,,,,,,,,, -11800,3.429153,1.6438869,,,,,,,,,,,,,, -11900,2.334247,1.6065443,,,,,,,,,,,,,, -12000,4.1094337,1.5960764,,,,,,,,,,,,,, -12100,2.4016814,1.5543716,,,,,,,,,,,,,, -12200,2.5660868,1.6194564,,,,,,,,,,,,,, -12300,4.0930443,1.6253163,,,,,,,,,,,,,, -12400,3.153123,1.654336,,,,,,,,,,,,,, -12500,2.4533734,1.5817076,,,,,,,,,,,,,, -12600,2.9742415,1.6201415,,,,,,,,,,,,,, -12664,,,0.36989608,0.1255429628348816,0.67882633,0.1983934657308089,5348.0,0.40882304,0.1320862023439563,2472.0,10098.84746146202,11206.81583738327,10098.84746146202,1107.0503504276276,0.3526680469512939,0.0 -12700,2.961497,1.6011851,,,,,,,,,,,,,, -12800,2.1461031,1.6693594,,,,,,,,,,,,,, -12900,2.7456276,1.5808558,,,,,,,,,,,,,, -13000,2.8634853,1.571799,,,,,,,,,,,,,, -13100,3.8669882,1.592296,,,,,,,,,,,,,, -13200,3.2521572,1.5803925,,,,,,,,,,,,,, -13300,3.2869542,1.6279626,,,,,,,,,,,,,, -13400,2.4997804,1.5610561,,,,,,,,,,,,,, -13500,3.0671353,1.5421613,,,,,,,,,,,,,, -13600,2.686801,1.5853225,,,,,,,,,,,,,, -13700,2.553191,1.5608035,,,,,,,,,,,,,, -13800,2.4521477,1.5160244,,,,,,,,,,,,,, -13900,3.7613218,1.5759093,,,,,,,,,,,,,, -14000,3.3468401,1.5731776,,,,,,,,,,,,,, -14100,2.9660845,1.5568954,,,,,,,,,,,,,, -14200,3.8656259,1.5780582,,,,,,,,,,,,,, -14300,2.1721156,1.5668894,,,,,,,,,,,,,, -14400,3.2157962,1.5625864,,,,,,,,,,,,,, -14437,,,0.33032325,0.1163172018916966,0.66513646,0.1929965146702453,5348.0,0.3992169,0.1300347328011699,2472.0,11539.406084775925,12783.127084732056,11539.406084775925,1242.6684551239014,0.4081716537475586,0.0 -14500,3.513356,1.583292,,,,,,,,,,,,,, -14600,3.0619366,1.610676,,,,,,,,,,,,,, -14700,4.374893,1.5404533,,,,,,,,,,,,,, -14800,2.2901042,1.5433625,,,,,,,,,,,,,, -14900,3.2988281,1.5584399,,,,,,,,,,,,,, -15000,5.370968,1.5397354,,,,,,,,,,,,,, -15100,2.1059482,1.5707256,,,,,,,,,,,,,, -15200,6.1785593,1.4801987,,,,,,,,,,,,,, -15300,3.812626,1.5962675,,,,,,,,,,,,,, -15400,1.9945444,1.5782536,,,,,,,,,,,,,, -15500,3.1418295,1.5424653,,,,,,,,,,,,,, -15600,2.2139359,1.477252,,,,,,,,,,,,,, -15700,2.2058892,1.5916145,,,,,,,,,,,,,, -15800,2.5383298,1.5411658,,,,,,,,,,,,,, -15900,2.9978147,1.4892852,,,,,,,,,,,,,, -16000,2.367647,1.5531645,,,,,,,,,,,,,, -16100,1.8035926,1.542727,,,,,,,,,,,,,, -16200,3.2415805,1.495852,,,,,,,,,,,,,, -16237,,,0.31750143,0.1080931915393774,0.641495,0.1859099993241743,5348.0,0.38088343,0.1229459915097597,2472.0,12979.301619768145,14361.773697376251,12979.301619768145,1381.285228252411,0.4616940021514892,0.0 -16300,3.1064365,1.5206815,,,,,,,,,,,,,, -16400,2.3086207,1.4828353,,,,,,,,,,,,,, -16500,3.0280244,1.4994954,,,,,,,,,,,,,, -16600,2.7224095,1.5297266,,,,,,,,,,,,,, -16700,2.7037585,1.5231307,,,,,,,,,,,,,, -16800,2.529875,1.5386199,,,,,,,,,,,,,, -16900,3.3727481,1.4947292,,,,,,,,,,,,,, -17000,3.5453584,1.4991431,,,,,,,,,,,,,, -17100,3.0552287,1.5519034,,,,,,,,,,,,,, -17200,3.1701453,1.469981,,,,,,,,,,,,,, -17300,2.9003155,1.5379254,,,,,,,,,,,,,, -17400,2.0294044,1.5083143,,,,,,,,,,,,,, -17500,2.8108294,1.562314,,,,,,,,,,,,,, -17600,3.4250424,1.4966013,,,,,,,,,,,,,, -17700,3.2140162,1.5429773,,,,,,,,,,,,,, -17800,3.4756014,1.5318022,,,,,,,,,,,,,, -17900,2.7070174,1.5346128,,,,,,,,,,,,,, -18000,2.8842256,1.5198293,,,,,,,,,,,,,, -18055,,,0.33247074,0.1144936161004111,0.62319475,0.1816040240593954,5348.0,0.36559522,0.1184165092519245,2472.0,14419.182766199112,15937.654762268066,14419.182766199112,1517.154412984848,0.5125668048858643,0.0 -18100,3.791962,1.5080892,,,,,,,,,,,,,, -18200,2.5914962,1.4818467,,,,,,,,,,,,,, -18300,4.07891,1.4900565,,,,,,,,,,,,,, -18400,2.7647521,1.5180945,,,,,,,,,,,,,, -18500,3.2444813,1.5847541,,,,,,,,,,,,,, -18600,2.2008228,1.4539579,,,,,,,,,,,,,, -18700,1.7589518,1.4641308,,,,,,,,,,,,,, -18800,2.8204467,1.4498383,,,,,,,,,,,,,, -18900,2.977772,1.5275259,,,,,,,,,,,,,, -19000,2.5497139,1.4466287,,,,,,,,,,,,,, -19100,2.509351,1.4140494,,,,,,,,,,,,,, -19200,2.4907687,1.4779897,,,,,,,,,,,,,, -19300,3.2746148,1.5265093,,,,,,,,,,,,,, -19400,2.2725632,1.4512091,,,,,,,,,,,,,, -19500,3.165803,1.4385283,,,,,,,,,,,,,, -19600,3.6634872,1.5334849,,,,,,,,,,,,,, -19700,3.896511,1.4764892,,,,,,,,,,,,,, -19800,2.2457066,1.4364378,,,,,,,,,,,,,, -19871,,,0.3052099,0.1027504344368119,0.6104939,0.1763712021008525,5348.0,0.35727885,0.1170353218369792,2472.0,15859.548187494278,17513.06376671791,15859.548187494278,1652.0571603775024,0.5724399089813232,0.0 -19900,5.4808517,1.4841564,,,,,,,,,,,,,, -20000,3.9828317,1.4502646,,,,,,,,,,,,,, -20100,3.424955,1.4684302,,,,,,,,,,,,,, -20200,3.1900437,1.4786587,,,,,,,,,,,,,, -20300,2.7070487,1.4457587,,,,,,,,,,,,,, -20400,3.643428,1.5361296,,,,,,,,,,,,,, -20500,2.8261485,1.5192161,,,,,,,,,,,,,, -20600,2.4655416,1.518433,,,,,,,,,,,,,, -20700,2.589166,1.4953346,,,,,,,,,,,,,, -20800,2.8188143,1.4535336,,,,,,,,,,,,,, -20900,3.841489,1.4299257,,,,,,,,,,,,,, -21000,1.9171609,1.4353814,,,,,,,,,,,,,, -21100,2.5587463,1.5000159,,,,,,,,,,,,,, -21200,3.5142417,1.5450745,,,,,,,,,,,,,, -21300,4.132081,1.4746677,,,,,,,,,,,,,, -21400,3.6426167,1.4996108,,,,,,,,,,,,,, -21500,2.6669102,1.5003232,,,,,,,,,,,,,, -21600,2.7093806,1.4937158,,,,,,,,,,,,,, -21649,,,0.28996828,0.1009951780034882,0.59253216,0.1723065931625747,5348.0,0.3434821,0.1111043405845672,2472.0,17299.65816307068,19090.31603384018,17299.65816307068,1789.0675375461578,0.6251442432403564,0.0 -21700,2.8839014,1.3902192,,,,,,,,,,,,,, -21800,3.194971,1.4648315,,,,,,,,,,,,,, -21900,2.620356,1.3947407,,,,,,,,,,,,,, -22000,3.164708,1.4360474,,,,,,,,,,,,,, -22100,4.4021907,1.4079536,,,,,,,,,,,,,, -22200,3.1523392,1.3993403,,,,,,,,,,,,,, -22300,3.039623,1.4774402,,,,,,,,,,,,,, -22400,2.7330303,1.4503485,,,,,,,,,,,,,, -22500,2.5246222,1.416047,,,,,,,,,,,,,, -22600,3.8019686,1.4512144,,,,,,,,,,,,,, -22700,2.9391453,1.4089266,,,,,,,,,,,,,, -22800,2.4588344,1.4894606,,,,,,,,,,,,,, -22900,1.9267988,1.4589578,,,,,,,,,,,,,, -23000,2.853233,1.4397279,,,,,,,,,,,,,, -23100,2.062811,1.464648,,,,,,,,,,,,,, -23200,3.0635362,1.5223627,,,,,,,,,,,,,, -23300,2.3182042,1.4683046,,,,,,,,,,,,,, -23400,4.9067354,1.4581703,,,,,,,,,,,,,, -23449,,,0.32587433,0.1060233787506514,0.5905726,0.1707425393668478,5348.0,0.3419359,0.111429325858672,2472.0,18740.54448866844,20668.40568065644,18740.54448866844,1926.1365442276,0.6778745651245117,0.0 -23500,3.5296776,1.4356405,,,,,,,,,,,,,, -23600,2.6174223,1.4121419,,,,,,,,,,,,,, -23700,2.0282211,1.4666716,,,,,,,,,,,,,, -23800,3.3894143,1.4449158,,,,,,,,,,,,,, -23900,4.4884152,1.4998198,,,,,,,,,,,,,, -24000,3.374515,1.4455647,,,,,,,,,,,,,, -24100,2.1740236,1.3932246,,,,,,,,,,,,,, -24200,3.043052,1.4401761,,,,,,,,,,,,,, -24300,2.7718432,1.4552101,,,,,,,,,,,,,, -24400,2.5590189,1.3857262,,,,,,,,,,,,,, -24500,2.1899662,1.4483261,,,,,,,,,,,,,, -24600,2.737518,1.4891812,,,,,,,,,,,,,, -24700,1.826145,1.4330903,,,,,,,,,,,,,, -24800,2.8059375,1.3834727,,,,,,,,,,,,,, -24900,2.257043,1.4332538,,,,,,,,,,,,,, -25000,2.1404736,1.3445115,,,,,,,,,,,,,, -25100,2.5787113,1.4200377,,,,,,,,,,,,,, -25200,2.2837267,1.4488077,,,,,,,,,,,,,, -25254,,,0.23896928,0.0836136349652357,0.5691403,0.1654421348368846,5348.0,0.32006806,0.1041984035098409,2472.0,20180.9346408844,22246.685710906982,20180.9346408844,2063.8816010952,0.741168737411499,0.0 -25300,3.018578,1.4224029,,,,,,,,,,,,,, -25400,3.315642,1.4433881,,,,,,,,,,,,,, -25500,3.29869,1.4628501,,,,,,,,,,,,,, -25600,2.9660752,1.4123363,,,,,,,,,,,,,, -25700,2.472792,1.4739237,,,,,,,,,,,,,, -25800,2.5559838,1.3722864,,,,,,,,,,,,,, -25900,2.3885982,1.3695157,,,,,,,,,,,,,, -26000,2.697711,1.4555256,,,,,,,,,,,,,, -26100,2.0553684,1.4205755,,,,,,,,,,,,,, -26200,3.2826307,1.3696015,,,,,,,,,,,,,, -26300,4.228216,1.4028417,,,,,,,,,,,,,, -26400,4.1034083,1.4137461,,,,,,,,,,,,,, -26500,2.7223945,1.3906833,,,,,,,,,,,,,, -26600,3.7700455,1.3605474,,,,,,,,,,,,,, -26700,2.6900215,1.3731241,,,,,,,,,,,,,, -26800,2.664823,1.3723997,,,,,,,,,,,,,, -26900,12.49764,1.3857179,,,,,,,,,,,,,, -27000,4.3726954,1.3826113,,,,,,,,,,,,,, -27049,,,0.25387847,0.0880248004307663,0.55523735,0.1612520154088262,5348.0,0.31684002,0.1034874982227367,2472.0,21621.25813126564,23824.7185087204,21621.25813126564,2201.445280790329,0.8065252304077148,0.0 -27100,2.265653,1.4287903,,,,,,,,,,,,,, -27200,2.9323685,1.4178299,,,,,,,,,,,,,, -27300,2.6908543,1.4226735,,,,,,,,,,,,,, -27400,1.9295348,1.3671494,,,,,,,,,,,,,, -27500,2.5131814,1.3753133,,,,,,,,,,,,,, -27600,3.2525864,1.3908482,,,,,,,,,,,,,, -27700,2.5676734,1.4720861,,,,,,,,,,,,,, -27800,2.087926,1.4368371,,,,,,,,,,,,,, -27900,2.4549663,1.4435116,,,,,,,,,,,,,, -28000,1.9360691,1.3148277,,,,,,,,,,,,,, -28100,2.5811534,1.4021549,,,,,,,,,,,,,, -28200,2.0750947,1.3203429,,,,,,,,,,,,,, -28300,3.2795537,1.3492688,,,,,,,,,,,,,, -28400,3.6156752,1.3968885,,,,,,,,,,,,,, -28500,3.2428858,1.3826029,,,,,,,,,,,,,, -28600,4.230747,1.4703771,,,,,,,,,,,,,, -28700,1.876872,1.3706232,,,,,,,,,,,,,, -28800,3.4345999,1.4100418,,,,,,,,,,,,,, -28819,,,0.32222232,0.1100863983116055,0.52448064,0.1545034129198567,5348.0,0.29585266,0.0954441126886437,2472.0,23061.37408900261,25399.31268310547,23061.37408900261,2335.782987356186,0.8676490783691406,0.0 -28900,3.367317,1.3482472,,,,,,,,,,,,,, -29000,2.5366886,1.3219692,,,,,,,,,,,,,, -29100,2.5348449,1.3426787,,,,,,,,,,,,,, -29200,2.9006636,1.3431683,,,,,,,,,,,,,, -29300,3.6103802,1.3077939,,,,,,,,,,,,,, -29400,2.5056074,1.3810328,,,,,,,,,,,,,, -29500,2.5458832,1.3581322,,,,,,,,,,,,,, -29600,3.5536873,1.4029484,,,,,,,,,,,,,, -29700,2.4017224,1.3013177,,,,,,,,,,,,,, -29800,2.8317707,1.3330193,,,,,,,,,,,,,, -29900,2.2202334,1.3432741,,,,,,,,,,,,,, -30000,3.022152,1.3408132,,,,,,,,,,,,,, -30100,2.673028,1.302871,,,,,,,,,,,,,, -30200,2.43091,1.4010389,,,,,,,,,,,,,, -30300,2.362267,1.3572549,,,,,,,,,,,,,, -30400,3.3989775,1.2798008,,,,,,,,,,,,,, -30500,2.4160106,1.3125525,,,,,,,,,,,,,, -30598,,,0.33865708,0.1133199141616622,0.5113746,0.1491161165123531,5348.0,0.28533491,0.0923973757439116,2472.0,24501.35665678978,26974.95587038994,24501.35665678978,2471.308205604553,0.9225730895996094,0.0 -30600,2.5146399,1.3133162,,,,,,,,,,,,,, -30700,2.373535,1.3506204,,,,,,,,,,,,,, -30800,2.7241426,1.3517636,,,,,,,,,,,,,, -30900,2.44678,1.3276924,,,,,,,,,,,,,, -31000,2.947295,1.3242407,,,,,,,,,,,,,, -31100,2.7751303,1.3152856,,,,,,,,,,,,,, -31200,2.040884,1.2988861,,,,,,,,,,,,,, -31300,4.0000205,1.3022381,,,,,,,,,,,,,, -31400,2.392371,1.2915637,,,,,,,,,,,,,, -31500,3.1752565,1.334214,,,,,,,,,,,,,, -31600,2.3847399,1.3575372,,,,,,,,,,,,,, -31700,3.0043595,1.3630741,,,,,,,,,,,,,, -31800,3.3342445,1.3745993,,,,,,,,,,,,,, -31900,2.552019,1.3088418,,,,,,,,,,,,,, -32000,2.5269322,1.3373922,,,,,,,,,,,,,, -32100,2.8513117,1.335521,,,,,,,,,,,,,, -32200,2.5054345,1.3463103,,,,,,,,,,,,,, -32300,2.3152668,1.311964,,,,,,,,,,,,,, -32400,2.4063804,1.3168677,,,,,,,,,,,,,, -32416,,,0.3645432,0.1253538611472062,0.5027272,0.1458142251658186,5348.0,0.2788384,0.090264659882599,2472.0,25941.70943188668,28548.19150686264,25941.70943188668,2604.052140712738,0.9797863960266112,0.0 -32500,2.8065684,1.3307188,,,,,,,,,,,,,, -32600,3.1601021,1.2168553,,,,,,,,,,,,,, -32700,3.1883428,1.2908803,,,,,,,,,,,,,, -32800,3.3066087,1.3184677,,,,,,,,,,,,,, -32900,2.9522674,1.3276194,,,,,,,,,,,,,, -33000,2.8268394,1.3034552,,,,,,,,,,,,,, -33100,3.3573747,1.3003823,,,,,,,,,,,,,, -33200,2.929951,1.3189942,,,,,,,,,,,,,, -33300,2.7203634,1.3381509,,,,,,,,,,,,,, -33400,2.8640895,1.3090435,,,,,,,,,,,,,, -33500,2.5509343,1.2963159,,,,,,,,,,,,,, -33600,2.9030569,1.3079326,,,,,,,,,,,,,, -33700,2.675323,1.3316606,,,,,,,,,,,,,, -33800,3.6119866,1.2923126,,,,,,,,,,,,,, -33900,2.7020946,1.2576419,,,,,,,,,,,,,, -34000,3.4987338,1.2573997,,,,,,,,,,,,,, -34100,2.7106383,1.2551572,,,,,,,,,,,,,, -34200,3.5241103,1.2740169,,,,,,,,,,,,,, -34215,,,0.31370944,0.10446648667684,0.4884356,0.1406682950848161,5348.0,0.2685194,0.0848617797006073,2472.0,27381.691277742382,30122.00229978561,27381.691277742382,2737.744610309601,1.035611629486084,0.0 -34300,2.527982,1.2902136,,,,,,,,,,,,,, -34400,2.8257892,1.2766396,,,,,,,,,,,,,, -34500,3.5271678,1.2766289,,,,,,,,,,,,,, -34600,2.3840404,1.241737,,,,,,,,,,,,,, -34700,2.5708086,1.2640417,,,,,,,,,,,,,, -34800,3.0743575,1.2215767,,,,,,,,,,,,,, -34900,3.606846,1.2986721,,,,,,,,,,,,,, -35000,3.554164,1.2794085,,,,,,,,,,,,,, -35100,2.47071,1.2487046,,,,,,,,,,,,,, -35200,2.9534194,1.2808018,,,,,,,,,,,,,, -35300,3.011921,1.2596499,,,,,,,,,,,,,, -35400,2.9202392,1.2124565,,,,,,,,,,,,,, -35500,3.5243156,1.2236214,,,,,,,,,,,,,, -35600,2.8925748,1.2258196,,,,,,,,,,,,,, -35700,3.326166,1.2577859,,,,,,,,,,,,,, -35800,3.6091032,1.2448821,,,,,,,,,,,,,, -35900,2.5934427,1.2208381,,,,,,,,,,,,,, -35991,,,0.2759164,0.0957748694924111,0.46581116,0.1357830406364347,5348.0,0.25631005,0.0828915564763471,2472.0,28822.284260749817,31697.164457798004,28822.284260749817,2872.178744316101,1.0904130935668943,0.0 -36000,2.2198057,1.295809,,,,,,,,,,,,,, -36100,3.1849008,1.1995168,,,,,,,,,,,,,, -36200,3.3844917,1.1864499,,,,,,,,,,,,,, -36300,3.0571685,1.250655,,,,,,,,,,,,,, -36400,2.6740909,1.2453105,,,,,,,,,,,,,, -36500,3.7412407,1.2454157,,,,,,,,,,,,,, -36600,2.9167671,1.2947683,,,,,,,,,,,,,, -36700,3.5282123,1.2418929,,,,,,,,,,,,,, -36800,2.1987922,1.2541989,,,,,,,,,,,,,, -36900,3.7388933,1.1860291,,,,,,,,,,,,,, -37000,4.4514666,1.2054828,,,,,,,,,,,,,, -37100,2.7640488,1.1732006,,,,,,,,,,,,,, -37200,3.1489263,1.2037482,,,,,,,,,,,,,, -37300,2.4885247,1.2121384,,,,,,,,,,,,,, -37400,3.4498386,1.1921024,,,,,,,,,,,,,, -37500,3.4417677,1.1803668,,,,,,,,,,,,,, -37600,4.4770355,1.2399743,,,,,,,,,,,,,, -37700,2.845928,1.2383734,,,,,,,,,,,,,, -37763,,,0.2330637,0.0813926453015253,0.4498272,0.1304247081881112,5348.0,0.25000796,0.0806166595576138,2472.0,30262.1793999672,33272.73663830757,30262.1793999672,3007.717676639557,1.1477704048156738,0.0 -37800,2.4363875,1.1911218,,,,,,,,,,,,,, -37900,2.186499,1.2678454,,,,,,,,,,,,,, -38000,2.6688259,1.2406853,,,,,,,,,,,,,, -38100,3.9775796,1.2144682,,,,,,,,,,,,,, -38200,3.348663,1.194133,,,,,,,,,,,,,, -38300,3.4398458,1.2752147,,,,,,,,,,,,,, -38400,2.6963716,1.1730391,,,,,,,,,,,,,, -38500,3.5454886,1.1733259,,,,,,,,,,,,,, -38600,2.4148765,1.1832198,,,,,,,,,,,,,, -38700,2.626239,1.2538288,,,,,,,,,,,,,, -38800,4.128517,1.2202945,,,,,,,,,,,,,, -38900,3.275925,1.2160264,,,,,,,,,,,,,, -39000,3.281356,1.1932775,,,,,,,,,,,,,, -39100,2.7427905,1.1949557,,,,,,,,,,,,,, -39200,4.488079,1.1935033,,,,,,,,,,,,,, -39300,3.0799565,1.1538845,,,,,,,,,,,,,, -39400,3.6779115,1.1420902,,,,,,,,,,,,,, -39500,2.4636967,1.1506305,,,,,,,,,,,,,, -39574,,,0.26204503,0.0910593137147968,0.4422166,0.1287254892495438,5348.0,0.24149479,0.0769808868035667,2472.0,31702.49013137817,34847.95053982735,31702.49013137817,3142.482138156891,1.2042455673217771,0.0 -39600,2.6583712,1.2808832,,,,,,,,,,,,,, -39700,2.278192,1.1248759,,,,,,,,,,,,,, -39800,3.025122,1.1966455,,,,,,,,,,,,,, -39900,3.9279914,1.1530343,,,,,,,,,,,,,, -40000,2.8386183,1.1281924,,,,,,,,,,,,,, -40100,2.8943908,1.1806595,,,,,,,,,,,,,, -40200,2.9007406,1.1628791,,,,,,,,,,,,,, -40300,2.6978405,1.1100303,,,,,,,,,,,,,, -40400,2.3010638,1.189942,,,,,,,,,,,,,, -40500,2.3441372,1.1738598,,,,,,,,,,,,,, -40600,2.9956584,1.1788793,,,,,,,,,,,,,, -40700,2.6556005,1.1508192,,,,,,,,,,,,,, -40800,4.4872975,1.1825448,,,,,,,,,,,,,, -40900,2.7787316,1.1592699,,,,,,,,,,,,,, -41000,3.4830446,1.1892573,,,,,,,,,,,,,, -41100,2.5584614,1.1955653,,,,,,,,,,,,,, -41200,3.1562338,1.1789361,,,,,,,,,,,,,, -41300,3.0145214,1.1691332,,,,,,,,,,,,,, -41352,,,0.22383782,0.0781360282322746,0.43122214,0.1246898442704461,5348.0,0.23231012,0.0752340909552535,2472.0,33142.4312608242,36424.297860622406,33142.4312608242,3278.741602897644,1.27134108543396,0.0 -41400,2.9533422,1.2048758,,,,,,,,,,,,,, -41500,3.2599652,1.1484078,,,,,,,,,,,,,, -41600,4.0926485,1.1640987,,,,,,,,,,,,,, -41700,2.2947125,1.1298016,,,,,,,,,,,,,, -41800,2.2625513,1.2048043,,,,,,,,,,,,,, -41900,3.5616248,1.1669408,,,,,,,,,,,,,, -42000,2.9519558,1.1796494,,,,,,,,,,,,,, -42100,4.3631005,1.1857793,,,,,,,,,,,,,, -42200,3.1201982,1.149436,,,,,,,,,,,,,, -42300,4.142092,1.1253359,,,,,,,,,,,,,, -42400,3.1697934,1.1382691,,,,,,,,,,,,,, -42500,2.5665565,1.1652634,,,,,,,,,,,,,, -42600,2.8655586,1.1909959,,,,,,,,,,,,,, -42700,3.9006624,1.1098286,,,,,,,,,,,,,, -42800,2.9249916,1.1843507,,,,,,,,,,,,,, -42900,3.1743655,1.1673198,,,,,,,,,,,,,, -43000,1.9808291,1.1203144,,,,,,,,,,,,,, -43100,3.4094574,1.0955058,,,,,,,,,,,,,, -43138,,,0.21453248,0.0768669198206493,0.42354012,0.122614093862537,5348.0,0.2263191,0.0726951435013101,2472.0,34582.384875535965,37999.1936173439,34582.384875535965,3413.544237613678,1.331352949142456,0.0 -43200,2.4176524,1.1250092,,,,,,,,,,,,,, -43300,3.4195561,1.148611,,,,,,,,,,,,,, -43400,3.273625,1.1500401,,,,,,,,,,,,,, -43500,3.5954623,1.1121962,,,,,,,,,,,,,, -43600,3.9452643,1.1704592,,,,,,,,,,,,,, -43700,2.4658272,1.0956365,,,,,,,,,,,,,, -43800,2.4280639,1.161257,,,,,,,,,,,,,, -43900,3.2656343,1.1899732,,,,,,,,,,,,,, -44000,3.2647216,1.1237893,,,,,,,,,,,,,, -44100,5.5656233,1.1085234,,,,,,,,,,,,,, -44200,3.4599612,1.0507314,,,,,,,,,,,,,, -44300,2.9413188,1.1453567,,,,,,,,,,,,,, -44400,4.4814363,1.1317601,,,,,,,,,,,,,, -44500,3.780688,1.1148475,,,,,,,,,,,,,, -44600,2.809259,1.1411611,,,,,,,,,,,,,, -44700,3.287651,1.1058387,,,,,,,,,,,,,, -44800,2.9691079,1.1624931,,,,,,,,,,,,,, -44900,2.3337126,1.1444921,,,,,,,,,,,,,, -44926,,,0.21462,0.0751011324643432,0.41773412,0.1206348899852283,5348.0,0.2226167,0.0722889119086791,2472.0,36022.423419713974,39575.31591534615,36022.423419713974,3549.495146036148,1.3847966194152832,0.0 -45000,3.6254375,1.1414121,,,,,,,,,,,,,, -45100,2.7608216,1.1219515,,,,,,,,,,,,,, -45200,2.7883792,1.0872359,,,,,,,,,,,,,, -45300,3.557738,1.1260848,,,,,,,,,,,,,, -45400,3.2364013,1.1508353,,,,,,,,,,,,,, -45500,2.5780241,1.1106838,,,,,,,,,,,,,, -45600,3.3751037,1.0789407,,,,,,,,,,,,,, -45700,2.7137601,1.1747137,,,,,,,,,,,,,, -45800,3.5746355,1.0824634,,,,,,,,,,,,,, -45900,3.0369697,1.1289793,,,,,,,,,,,,,, -46000,4.650227,1.1283854,,,,,,,,,,,,,, -46100,3.4959612,1.1460919,,,,,,,,,,,,,, -46200,2.4779413,1.1492808,,,,,,,,,,,,,, -46300,3.5317526,1.121662,,,,,,,,,,,,,, -46400,3.609678,1.1318266,,,,,,,,,,,,,, -46500,3.3455555,1.1474127,,,,,,,,,,,,,, -46600,4.0809746,1.0945421,,,,,,,,,,,,,, -46700,3.3620794,1.119127,,,,,,,,,,,,,, -46737,,,0.22253259,0.0786585891067788,0.41605195,0.1204514515770875,5348.0,0.2220608,0.0719842382142059,2472.0,37462.40724658966,41148.4344394207,37462.40724658966,3682.490786552429,1.443239688873291,0.0 -46800,2.665097,1.168215,,,,,,,,,,,,,, -46900,4.4028544,1.1258392,,,,,,,,,,,,,, -47000,3.2454083,1.1064439,,,,,,,,,,,,,, -47100,3.515598,1.1256648,,,,,,,,,,,,,, -47200,2.8767116,1.1183399,,,,,,,,,,,,,, -47300,2.2364423,1.1232252,,,,,,,,,,,,,, -47400,4.367217,1.113265,,,,,,,,,,,,,, -47500,2.9716718,1.1467284,,,,,,,,,,,,,, -47600,2.6441944,1.1537657,,,,,,,,,,,,,, -47700,2.3412628,1.1590865,,,,,,,,,,,,,, -47800,2.73238,1.1044688,,,,,,,,,,,,,, -47900,2.020733,1.0827094,,,,,,,,,,,,,, -48000,,,0.23057689,0.0785544965318646,0.41583607,0.1201425026791662,5348.0,0.22198159,0.071801433997522,2472.0,38467.08563065529,42289.14387321472,38467.08563065529,3818.4000244140634,1.5066735744476318,0.0 -48000,,,,,,,,,,,38467.08563065529,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/eval_measurements.csv deleted file mode 100644 index b80bc2d8c..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -193.0313441753388,0.0,16.281517267227173,1,0,16.281517267227173,30.942242,2472,3.698637093006723,209.31295228004456,31.853792,3.525787909704684,30.891838,5348,3.324647363796982 -323.79737067222595,0.028933048248291,1456.6854193210602,1789,0,1456.6854193210602,1.2356836,2472,0.3786890906505799,1780.5886964797974,1.6250488,0.4563804783630582,1.6839274,5348,0.4397694468849262 -459.8945791721344,0.0802168846130371,2897.160602331161,3594,0,2897.160602331161,0.6614137,2472,0.2108545081550992,3357.2956132888794,0.87861687,0.2692583798111436,1.0104467,5348,0.2839530011489037 -594.9467799663544,0.1359140872955322,4337.480946302414,5389,0,4337.480946302414,0.646533,2472,0.2001909288485365,4932.8055765628815,0.98097146,0.2836968079089654,1.0103266,5348,0.279453932822924 -729.7598900794983,0.1928915977478027,5777.349337816238,7157,0,5777.349337816238,0.57268214,2472,0.1808136818800398,6507.6251838207245,0.8190737,0.2471248876909254,0.892246,5348,0.2521023007038242 -866.3746891021729,0.244370698928833,7217.478336572647,8942,0,7217.478336572647,0.54710823,2472,0.1736233826904718,8084.5014843940735,0.79178494,0.2449146461528192,0.8644864,5348,0.2443496142966102 -1003.893921136856,0.2965955734252929,8658.395947933197,10758,0,8658.395947933197,0.55628407,2472,0.1745780269331545,9663.073320865631,0.7392486,0.2289254094419144,0.8693898,5348,0.246811550826921 -1139.6634182929993,0.3509695529937744,10098.912323236464,12547,0,10098.912323236464,0.5140238,2472,0.1648284687100115,11239.493403196337,0.65287876,0.2031998445234529,0.8087663,5348,0.2305241511146297 -1275.162227153778,0.4020240306854248,11539.125578641891,14325,0,11539.125578641891,0.48911905,2472,0.1573334958259704,12815.336868524551,0.71331424,0.2232145363099013,0.78945225,5348,0.2262278305029108 -1410.6499030590055,0.4547863006591797,12979.941724300385,16117,0,12979.941724300385,0.47254232,2472,0.1539414620275019,14391.774739265442,0.62898093,0.1986291381070438,0.76331466,5348,0.2208791527076474 -1546.1836066246033,0.5079362392425537,14420.404343128204,17933,0,14420.404343128204,0.45915446,2472,0.1470558365324071,15967.906356573105,0.6621947,0.2031150388005891,0.7468117,5348,0.2147677573206406 -1681.472809791565,0.5603365898132324,15860.634415864944,19714,0,15860.634415864944,0.4412394,2472,0.1405764426299433,17543.55829191208,0.5638722,0.1797114027233827,0.7201087,5348,0.2051034496075383 -1815.649961233139,0.6182372570037842,17300.85076236725,21494,0,17300.85076236725,0.42382613,2472,0.1354782361424248,19118.08944129944,0.5838771,0.1874216988223503,0.70004,5348,0.200015447444896 -1950.2114119529724,0.6731538772583008,18741.65048289299,23301,0,18741.65048289299,0.4096471,2472,0.131558101273536,20693.58714914322,0.5520931,0.1759146443701724,0.67959964,5348,0.1954101779352559 -2093.9861640930176,0.7318322658538818,20182.07522916794,25116,0,20182.07522916794,0.3809952,2472,0.12373814311539,22277.92815804481,0.33693334,0.1138513423088247,0.645735,5348,0.1852052096507912 -2229.720189809799,0.790858268737793,21621.94614171981,26912,0,21621.94614171981,0.3652431,2472,0.1191477261186602,23853.673711776733,0.31949133,0.1068421107012595,0.6205764,5348,0.180821997161532 -2365.1423330307007,0.8519728183746338,23062.360239744183,28704,0,23062.360239744183,0.3550856,2472,0.1157556923201917,25429.651166677475,0.30595523,0.1032869033910158,0.6050423,5348,0.1738706469583015 -2500.8551738262177,0.9049315452575684,24503.368768692017,30507,0,24503.368768692017,0.33707365,2472,0.1091544289399386,27006.50775051117,0.2792333,0.0960252405609945,0.5789832,5348,0.1674309933672533 -2635.5505130290985,0.9616458415985109,25943.38379716873,32326,0,25943.38379716873,0.31017295,2472,0.0994658054556903,28581.357215881348,0.26147306,0.0895716509242873,0.5437209,5348,0.1586549137356749 -2771.551521062851,1.019352912902832,27383.67330932617,34115,0,27383.67330932617,0.29474744,2472,0.0952816200515914,30157.786672353745,0.23398308,0.0817539105541467,0.51874906,5348,0.1500622725122372 -2907.5363302230835,1.075251579284668,28824.013179302216,35915,0,28824.013179302216,0.27777967,2472,0.0904677756789145,31734.248168230057,0.25606412,0.0835476626773314,0.49699828,5348,0.1436612375334292 -3044.5422701835632,1.1326208114624023,30264.20079755783,37721,0,30264.20079755783,0.26118982,2472,0.0833180996486096,33311.580392599106,0.20800158,0.0720675078764466,0.47449413,5348,0.1377815538198634 -3180.9032578468323,1.188124656677246,31704.79609966278,39546,0,31704.79609966278,0.24431218,2472,0.0793979647797209,34888.67411828041,0.17795631,0.0637573506654286,0.45436054,5348,0.1326259690858009 -3315.796919107437,1.2503759860992432,33145.08502173424,41345,0,33145.08502173424,0.23327044,2472,0.0745638088274125,36463.99954032898,0.16376394,0.0562627900178063,0.43039808,5348,0.1254139432499493 -3451.147789478302,1.3133165836334229,34585.27427482605,43139,0,34585.27427482605,0.22326158,2472,0.071618629780838,38039.684517383575,0.14803201,0.0525516795865633,0.41873074,5348,0.1214844994545121 -3587.4509439468384,1.3760840892791748,36025.31555771828,44943,0,36025.31555771828,0.2159714,2472,0.0694046676009993,39616.17360472679,0.16079587,0.054562870721733,0.4065272,5348,0.1176322928835552 -3723.850387096405,1.433948278427124,37465.57762217522,46770,0,37465.57762217522,0.2135498,2472,0.0683484654601588,41192.97604799271,0.15611503,0.0529189681591044,0.40304145,5348,0.1167440648020313 -3861.5062580108643,1.4898200035095215,38433.03055810928,48000,0,38433.03055810928,0.21367131,2472,0.06851095809721122,42298.19808506966,0.147639,0.05145846511193836,0.40327126,5348,0.11670544618979117 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/measurements.csv deleted file mode 100644 index 0f41a5e3d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.54699,32.61637,,,,,,,,,,,,,, -1,,,31.853792,3.525787909704684,30.891838,3.324647363796982,5348.0,30.942242,3.698637093006723,2472.0,16.281517267227173,209.31295228004456,16.281517267227173,193.0313441753388,0.0,0.0 -100,0.9575856,5.735673,,,,,,,,,,,,,, -200,1.3971944,4.5280766,,,,,,,,,,,,,, -300,63.26881,4.284371,,,,,,,,,,,,,, -400,2.3799198,3.1970003,,,,,,,,,,,,,, -500,2.9636633,2.9168386,,,,,,,,,,,,,, -600,2.157837,2.816291,,,,,,,,,,,,,, -700,2.3549454,2.6472092,,,,,,,,,,,,,, -800,2.510827,2.5534768,,,,,,,,,,,,,, -900,3.2353494,2.5338666,,,,,,,,,,,,,, -1000,3.4511628,2.509617,,,,,,,,,,,,,, -1100,2.3512344,2.420075,,,,,,,,,,,,,, -1200,2.0291066,2.343738,,,,,,,,,,,,,, -1300,2.5603266,2.3237734,,,,,,,,,,,,,, -1400,2.5725737,2.237245,,,,,,,,,,,,,, -1500,2.9105186,2.2393854,,,,,,,,,,,,,, -1600,2.4775112,2.213996,,,,,,,,,,,,,, -1700,3.425678,2.184754,,,,,,,,,,,,,, -1789,,,1.6250488,0.4563804783630582,1.6839274,0.4397694468849262,5348.0,1.2356836,0.3786890906505799,2472.0,1456.6854193210602,1780.5886964797974,1456.6854193210602,323.79737067222595,0.028933048248291,0.0 -1800,1.8834116,2.1895423,,,,,,,,,,,,,, -1900,4.231301,2.2134824,,,,,,,,,,,,,, -2000,2.5670903,2.120617,,,,,,,,,,,,,, -2100,2.332119,2.1184788,,,,,,,,,,,,,, -2200,2.653341,2.1391792,,,,,,,,,,,,,, -2300,6.1167936,2.0331054,,,,,,,,,,,,,, -2400,3.3007898,2.0564594,,,,,,,,,,,,,, -2500,3.4525912,2.0562315,,,,,,,,,,,,,, -2600,3.3196018,2.0393884,,,,,,,,,,,,,, -2700,3.5854933,2.1081173,,,,,,,,,,,,,, -2800,2.105704,2.0359068,,,,,,,,,,,,,, -2900,1.7365869,2.0214326,,,,,,,,,,,,,, -3000,2.2454698,1.9472492,,,,,,,,,,,,,, -3100,2.2280498,2.0595293,,,,,,,,,,,,,, -3200,2.9597843,1.956611,,,,,,,,,,,,,, -3300,2.9170053,1.9977441,,,,,,,,,,,,,, -3400,2.0953975,2.0075226,,,,,,,,,,,,,, -3500,2.4822304,1.9963216,,,,,,,,,,,,,, -3594,,,0.87861687,0.2692583798111436,1.0104467,0.2839530011489037,5348.0,0.6614137,0.2108545081550992,2472.0,2897.160602331161,3357.2956132888794,2897.160602331161,459.8945791721344,0.0802168846130371,0.0 -3600,4.26001,2.1079943,,,,,,,,,,,,,, -3700,3.9879885,1.9829823,,,,,,,,,,,,,, -3800,2.5993843,1.9725739,,,,,,,,,,,,,, -3900,3.063369,1.970152,,,,,,,,,,,,,, -4000,2.0065022,1.9855816,,,,,,,,,,,,,, -4100,2.9636679,1.9882481,,,,,,,,,,,,,, -4200,2.7486756,2.0095007,,,,,,,,,,,,,, -4300,6.688135,1.9274207,,,,,,,,,,,,,, -4400,3.5467212,1.9648402,,,,,,,,,,,,,, -4500,3.0031009,1.9181702,,,,,,,,,,,,,, -4600,3.008346,1.964307,,,,,,,,,,,,,, -4700,2.3915489,1.9464334,,,,,,,,,,,,,, -4800,2.261763,1.9002771,,,,,,,,,,,,,, -4900,2.527052,1.9471017,,,,,,,,,,,,,, -5000,3.638614,1.9105314,,,,,,,,,,,,,, -5100,2.437363,1.992387,,,,,,,,,,,,,, -5200,3.1353376,1.9615394,,,,,,,,,,,,,, -5300,2.6386538,1.9867024,,,,,,,,,,,,,, -5389,,,0.98097146,0.2836968079089654,1.0103266,0.279453932822924,5348.0,0.646533,0.2001909288485365,2472.0,4337.480946302414,4932.8055765628815,4337.480946302414,594.9467799663544,0.1359140872955322,0.0 -5400,1.880312,1.8790891,,,,,,,,,,,,,, -5500,1.9137745,1.927081,,,,,,,,,,,,,, -5600,3.3702047,1.9213492,,,,,,,,,,,,,, -5700,3.0901432,1.9981697,,,,,,,,,,,,,, -5800,1.7937202,1.9635477,,,,,,,,,,,,,, -5900,2.8418267,1.9392197,,,,,,,,,,,,,, -6000,2.786588,1.9234713,,,,,,,,,,,,,, -6100,2.554915,1.9467359,,,,,,,,,,,,,, -6200,2.3107872,1.8904766,,,,,,,,,,,,,, -6300,3.111662,2.055473,,,,,,,,,,,,,, -6400,2.996001,1.9436376,,,,,,,,,,,,,, -6500,2.4614258,1.9157586,,,,,,,,,,,,,, -6600,3.2838538,1.9538473,,,,,,,,,,,,,, -6700,3.472812,1.8590536,,,,,,,,,,,,,, -6800,2.455214,1.9492735,,,,,,,,,,,,,, -6900,2.1805415,1.8874133,,,,,,,,,,,,,, -7000,3.1230686,1.8721135,,,,,,,,,,,,,, -7100,2.4933062,1.8350745,,,,,,,,,,,,,, -7157,,,0.8190737,0.2471248876909254,0.892246,0.2521023007038242,5348.0,0.57268214,0.1808136818800398,2472.0,5777.349337816238,6507.6251838207245,5777.349337816238,729.7598900794983,0.1928915977478027,0.0 -7200,2.1326225,1.8570938,,,,,,,,,,,,,, -7300,4.7984734,1.9024143,,,,,,,,,,,,,, -7400,2.0409188,1.8838562,,,,,,,,,,,,,, -7500,3.0880525,1.9354765,,,,,,,,,,,,,, -7600,3.9142215,1.9053291,,,,,,,,,,,,,, -7700,2.8880103,1.913737,,,,,,,,,,,,,, -7800,2.044566,1.9178556,,,,,,,,,,,,,, -7900,2.9431078,1.8826551,,,,,,,,,,,,,, -8000,3.2730587,1.8652831,,,,,,,,,,,,,, -8100,2.0343542,1.8882145,,,,,,,,,,,,,, -8200,1.9699998,1.8281541,,,,,,,,,,,,,, -8300,3.1878078,1.8343315,,,,,,,,,,,,,, -8400,3.938653,1.9458963,,,,,,,,,,,,,, -8500,3.4593647,1.8647703,,,,,,,,,,,,,, -8600,2.9552717,1.889458,,,,,,,,,,,,,, -8700,3.1542876,1.9189965,,,,,,,,,,,,,, -8800,4.0816674,1.8405349,,,,,,,,,,,,,, -8900,3.2012844,1.8455746,,,,,,,,,,,,,, -8942,,,0.79178494,0.2449146461528192,0.8644864,0.2443496142966102,5348.0,0.54710823,0.1736233826904718,2472.0,7217.478336572647,8084.5014843940735,7217.478336572647,866.3746891021729,0.244370698928833,0.0 -9000,5.697097,2.4496942,,,,,,,,,,,,,, -9100,3.9782417,1.9364939,,,,,,,,,,,,,, -9200,3.2802165,1.8564115,,,,,,,,,,,,,, -9300,2.4129522,1.9177094,,,,,,,,,,,,,, -9400,3.328307,1.9163256,,,,,,,,,,,,,, -9500,2.9666874,1.7855198,,,,,,,,,,,,,, -9600,1.9597337,1.8684777,,,,,,,,,,,,,, -9700,2.43019,1.9149616,,,,,,,,,,,,,, -9800,2.314126,1.8836082,,,,,,,,,,,,,, -9900,2.9577649,1.8711576,,,,,,,,,,,,,, -10000,2.415025,1.839082,,,,,,,,,,,,,, -10100,2.192709,1.8291888,,,,,,,,,,,,,, -10200,2.2761915,1.7795694,,,,,,,,,,,,,, -10300,2.1340048,1.8398124,,,,,,,,,,,,,, -10400,2.643748,1.9180031,,,,,,,,,,,,,, -10500,3.9203627,1.8580401,,,,,,,,,,,,,, -10600,2.9858615,1.9255781,,,,,,,,,,,,,, -10700,3.0292017,1.8740855,,,,,,,,,,,,,, -10758,,,0.7392486,0.2289254094419144,0.8693898,0.246811550826921,5348.0,0.55628407,0.1745780269331545,2472.0,8658.395947933197,9663.073320865631,8658.395947933197,1003.893921136856,0.2965955734252929,0.0 -10800,2.0767882,1.8317219,,,,,,,,,,,,,, -10900,2.333391,1.8100208,,,,,,,,,,,,,, -11000,2.2019935,1.8188764,,,,,,,,,,,,,, -11100,2.1525977,1.83267,,,,,,,,,,,,,, -11200,3.1199212,1.793217,,,,,,,,,,,,,, -11300,3.9827683,1.8612924,,,,,,,,,,,,,, -11400,2.6836252,1.8735613,,,,,,,,,,,,,, -11500,1.8299406,1.8202825,,,,,,,,,,,,,, -11600,2.242928,1.8007371,,,,,,,,,,,,,, -11700,3.4324903,1.8386467,,,,,,,,,,,,,, -11800,2.9142098,1.8768868,,,,,,,,,,,,,, -11900,2.2841682,1.8976272,,,,,,,,,,,,,, -12000,3.4734943,1.7915025,,,,,,,,,,,,,, -12100,1.7968678,1.7125173,,,,,,,,,,,,,, -12200,2.8634043,1.7997851,,,,,,,,,,,,,, -12300,3.2489996,1.8070831,,,,,,,,,,,,,, -12400,2.8853543,1.9050739,,,,,,,,,,,,,, -12500,3.2254658,1.846768,,,,,,,,,,,,,, -12547,,,0.65287876,0.2031998445234529,0.8087663,0.2305241511146297,5348.0,0.5140238,0.1648284687100115,2472.0,10098.912323236464,11239.493403196337,10098.912323236464,1139.6634182929993,0.3509695529937744,0.0 -12600,2.3850136,1.7364913,,,,,,,,,,,,,, -12700,2.561576,1.779639,,,,,,,,,,,,,, -12800,2.1554089,1.8661412,,,,,,,,,,,,,, -12900,2.571332,1.8086572,,,,,,,,,,,,,, -13000,2.3998134,1.7421108,,,,,,,,,,,,,, -13100,2.8051004,1.8409265,,,,,,,,,,,,,, -13200,3.61711,1.8245889,,,,,,,,,,,,,, -13300,2.383348,1.8451359,,,,,,,,,,,,,, -13400,2.4686828,1.7744765,,,,,,,,,,,,,, -13500,2.8218699,1.8235028,,,,,,,,,,,,,, -13600,3.2198718,1.8023434,,,,,,,,,,,,,, -13700,2.4012654,1.7702552,,,,,,,,,,,,,, -13800,1.8278702,1.7366623,,,,,,,,,,,,,, -13900,2.660468,1.7809682,,,,,,,,,,,,,, -14000,2.8870568,1.7749273,,,,,,,,,,,,,, -14100,2.07422,1.6935555,,,,,,,,,,,,,, -14200,1.942051,1.8187919,,,,,,,,,,,,,, -14300,2.6889613,1.8009812,,,,,,,,,,,,,, -14325,,,0.71331424,0.2232145363099013,0.78945225,0.2262278305029108,5348.0,0.48911905,0.1573334958259704,2472.0,11539.125578641891,12815.336868524551,11539.125578641891,1275.162227153778,0.4020240306854248,0.0 -14400,2.0976374,1.7557777,,,,,,,,,,,,,, -14500,2.044434,1.7306406,,,,,,,,,,,,,, -14600,1.9758577,1.7132549,,,,,,,,,,,,,, -14700,2.9199884,1.788698,,,,,,,,,,,,,, -14800,3.4975166,1.7594938,,,,,,,,,,,,,, -14900,2.7001593,1.7215959,,,,,,,,,,,,,, -15000,2.3231282,1.7631202,,,,,,,,,,,,,, -15100,1.9373245,1.7732096,,,,,,,,,,,,,, -15200,1.5395864,1.7731265,,,,,,,,,,,,,, -15300,2.4391599,1.8399214,,,,,,,,,,,,,, -15400,1.9908731,1.7710354,,,,,,,,,,,,,, -15500,2.2023077,1.7900883,,,,,,,,,,,,,, -15600,2.4172325,1.8020089,,,,,,,,,,,,,, -15700,3.4329662,1.7561414,,,,,,,,,,,,,, -15800,1.9667262,1.7568959,,,,,,,,,,,,,, -15900,2.5744069,1.6917347,,,,,,,,,,,,,, -16000,5.852506,1.766893,,,,,,,,,,,,,, -16100,2.389939,1.6958481,,,,,,,,,,,,,, -16117,,,0.62898093,0.1986291381070438,0.76331466,0.2208791527076474,5348.0,0.47254232,0.1539414620275019,2472.0,12979.941724300385,14391.774739265442,12979.941724300385,1410.6499030590055,0.4547863006591797,0.0 -16200,1.8729246,1.651036,,,,,,,,,,,,,, -16300,2.3016443,1.7319129,,,,,,,,,,,,,, -16400,1.9993213,1.6611593,,,,,,,,,,,,,, -16500,2.1718044,1.6390165,,,,,,,,,,,,,, -16600,3.589682,1.7351308,,,,,,,,,,,,,, -16700,2.298212,1.7113657,,,,,,,,,,,,,, -16800,2.5371556,1.8053193,,,,,,,,,,,,,, -16900,3.4267187,1.726816,,,,,,,,,,,,,, -17000,3.1550643,1.7126545,,,,,,,,,,,,,, -17100,2.8193338,1.7004589,,,,,,,,,,,,,, -17200,2.2373044,1.6724887,,,,,,,,,,,,,, -17300,2.3006063,1.686886,,,,,,,,,,,,,, -17400,2.8341248,1.7748688,,,,,,,,,,,,,, -17500,2.8288312,1.7548255,,,,,,,,,,,,,, -17600,2.2772264,1.6646298,,,,,,,,,,,,,, -17700,2.0347097,1.7182989,,,,,,,,,,,,,, -17800,3.2641857,1.7614706,,,,,,,,,,,,,, -17900,2.490569,1.6852574,,,,,,,,,,,,,, -17933,,,0.6621947,0.2031150388005891,0.7468117,0.2147677573206406,5348.0,0.45915446,0.1470558365324071,2472.0,14420.404343128204,15967.906356573105,14420.404343128204,1546.1836066246033,0.5079362392425537,0.0 -18000,2.1281693,1.7189945,,,,,,,,,,,,,, -18100,2.0354218,1.6948831,,,,,,,,,,,,,, -18200,3.1184006,1.6702783,,,,,,,,,,,,,, -18300,1.8143419,1.6649643,,,,,,,,,,,,,, -18400,3.735671,1.7446811,,,,,,,,,,,,,, -18500,3.6219478,1.7219826,,,,,,,,,,,,,, -18600,1.8653822,1.7176231,,,,,,,,,,,,,, -18700,4.4598923,1.642849,,,,,,,,,,,,,, -18800,3.336395,1.6884857,,,,,,,,,,,,,, -18900,2.2820368,1.7198137,,,,,,,,,,,,,, -19000,2.3936079,1.6795193,,,,,,,,,,,,,, -19100,2.7894497,1.7389752,,,,,,,,,,,,,, -19200,2.5446384,1.6776805,,,,,,,,,,,,,, -19300,1.8186355,1.6935564,,,,,,,,,,,,,, -19400,3.2519708,1.6695313,,,,,,,,,,,,,, -19500,3.004108,1.5961806,,,,,,,,,,,,,, -19600,2.664751,1.6658831,,,,,,,,,,,,,, -19700,2.7382488,1.6905332,,,,,,,,,,,,,, -19714,,,0.5638722,0.1797114027233827,0.7201087,0.2051034496075383,5348.0,0.4412394,0.1405764426299433,2472.0,15860.634415864944,17543.55829191208,15860.634415864944,1681.472809791565,0.5603365898132324,0.0 -19800,3.1019254,1.6438166,,,,,,,,,,,,,, -19900,3.713493,1.6178144,,,,,,,,,,,,,, -20000,3.4739163,1.7397592,,,,,,,,,,,,,, -20100,3.2075334,1.6467113,,,,,,,,,,,,,, -20200,2.5785475,1.7338812,,,,,,,,,,,,,, -20300,2.972027,1.6817714,,,,,,,,,,,,,, -20400,2.0613854,1.6991502,,,,,,,,,,,,,, -20500,3.3287656,1.6984649,,,,,,,,,,,,,, -20600,2.866622,1.6297659,,,,,,,,,,,,,, -20700,2.3006737,1.6250931,,,,,,,,,,,,,, -20800,2.2270308,1.6624316,,,,,,,,,,,,,, -20900,2.2164445,1.6731019,,,,,,,,,,,,,, -21000,2.6497896,1.7032751,,,,,,,,,,,,,, -21100,2.492021,1.6267931,,,,,,,,,,,,,, -21200,2.4696465,1.6721495,,,,,,,,,,,,,, -21300,2.4997656,1.6319464,,,,,,,,,,,,,, -21400,1.5189568,1.6348926,,,,,,,,,,,,,, -21494,,,0.5838771,0.1874216988223503,0.70004,0.200015447444896,5348.0,0.42382613,0.1354782361424248,2472.0,17300.85076236725,19118.08944129944,17300.85076236725,1815.649961233139,0.6182372570037842,0.0 -21500,1.9542967,1.7041678,,,,,,,,,,,,,, -21600,4.319945,1.6676352,,,,,,,,,,,,,, -21700,8.582473,1.5390725,,,,,,,,,,,,,, -21800,4.072018,1.7071983,,,,,,,,,,,,,, -21900,2.4985502,1.663934,,,,,,,,,,,,,, -22000,3.3925073,1.6065192,,,,,,,,,,,,,, -22100,2.3643782,1.5952822,,,,,,,,,,,,,, -22200,2.6425323,1.6227412,,,,,,,,,,,,,, -22300,4.246853,1.6578791,,,,,,,,,,,,,, -22400,2.5256894,1.6392218,,,,,,,,,,,,,, -22500,4.091908,1.6801251,,,,,,,,,,,,,, -22600,1.9727869,1.5738033,,,,,,,,,,,,,, -22700,2.3637974,1.5555183,,,,,,,,,,,,,, -22800,2.1003594,1.5903733,,,,,,,,,,,,,, -22900,4.702016,1.6145908,,,,,,,,,,,,,, -23000,2.2818136,1.5871128,,,,,,,,,,,,,, -23100,2.7912297,1.5887818,,,,,,,,,,,,,, -23200,3.254649,1.6600498,,,,,,,,,,,,,, -23300,2.7824588,1.6845073,,,,,,,,,,,,,, -23301,,,0.5520931,0.1759146443701724,0.67959964,0.1954101779352559,5348.0,0.4096471,0.131558101273536,2472.0,18741.65048289299,20693.58714914322,18741.65048289299,1950.2114119529724,0.6731538772583008,0.0 -23400,2.0950701,1.6407579,,,,,,,,,,,,,, -23500,2.1319606,1.6015621,,,,,,,,,,,,,, -23600,2.4574044,1.6132256,,,,,,,,,,,,,, -23700,2.3144248,1.6132182,,,,,,,,,,,,,, -23800,3.65675,1.627581,,,,,,,,,,,,,, -23900,2.5907152,1.6602396,,,,,,,,,,,,,, -24000,4.863266,1.6496692,,,,,,,,,,,,,, -24100,4.105412,1.5984224,,,,,,,,,,,,,, -24200,3.836656,1.616959,,,,,,,,,,,,,, -24300,1.8987403,1.6398704,,,,,,,,,,,,,, -24400,2.3400373,1.544015,,,,,,,,,,,,,, -24500,2.105082,1.6122988,,,,,,,,,,,,,, -24600,2.2418785,1.63797,,,,,,,,,,,,,, -24700,4.4203763,1.5326868,,,,,,,,,,,,,, -24800,2.2009091,1.618324,,,,,,,,,,,,,, -24900,2.7122688,1.5830824,,,,,,,,,,,,,, -25000,3.3186336,1.530223,,,,,,,,,,,,,, -25100,1.8767952,1.5562209,,,,,,,,,,,,,, -25116,,,0.33693334,0.1138513423088247,0.645735,0.1852052096507912,5348.0,0.3809952,0.12373814311539,2472.0,20182.07522916794,22277.92815804481,20182.07522916794,2093.9861640930176,0.7318322658538818,0.0 -25200,3.2075558,1.5860267,,,,,,,,,,,,,, -25300,3.374045,1.6234643,,,,,,,,,,,,,, -25400,3.4791543,1.5941715,,,,,,,,,,,,,, -25500,2.993545,1.6255097,,,,,,,,,,,,,, -25600,2.461433,1.6195518,,,,,,,,,,,,,, -25700,2.408703,1.5627484,,,,,,,,,,,,,, -25800,2.9434774,1.5803328,,,,,,,,,,,,,, -25900,2.2767873,1.5624132,,,,,,,,,,,,,, -26000,3.5099776,1.594696,,,,,,,,,,,,,, -26100,1.7477845,1.6113962,,,,,,,,,,,,,, -26200,2.7859392,1.5766802,,,,,,,,,,,,,, -26300,2.383408,1.6370009,,,,,,,,,,,,,, -26400,2.0025282,1.5165825,,,,,,,,,,,,,, -26500,2.4207559,1.48415,,,,,,,,,,,,,, -26600,1.861664,1.5696009,,,,,,,,,,,,,, -26700,1.6786275,1.5366694,,,,,,,,,,,,,, -26800,2.8315737,1.5824832,,,,,,,,,,,,,, -26900,2.576414,1.4980696,,,,,,,,,,,,,, -26912,,,0.31949133,0.1068421107012595,0.6205764,0.180821997161532,5348.0,0.3652431,0.1191477261186602,2472.0,21621.94614171981,23853.673711776733,21621.94614171981,2229.720189809799,0.790858268737793,0.0 -27000,1.6711705,1.5040073,,,,,,,,,,,,,, -27100,1.529492,1.5455849,,,,,,,,,,,,,, -27200,2.285821,1.4800166,,,,,,,,,,,,,, -27300,1.9500811,1.5504388,,,,,,,,,,,,,, -27400,2.0054104,1.4912776,,,,,,,,,,,,,, -27500,2.6576695,1.532404,,,,,,,,,,,,,, -27600,2.17988,1.5228382,,,,,,,,,,,,,, -27700,2.6397376,1.5300585,,,,,,,,,,,,,, -27800,2.5312526,1.527482,,,,,,,,,,,,,, -27900,1.8676174,1.4772565,,,,,,,,,,,,,, -28000,4.514863,1.570009,,,,,,,,,,,,,, -28100,3.7404826,1.5887674,,,,,,,,,,,,,, -28200,2.9199238,1.54009,,,,,,,,,,,,,, -28300,2.3252935,1.4875449,,,,,,,,,,,,,, -28400,3.071763,1.5633391,,,,,,,,,,,,,, -28500,1.7497067,1.5236588,,,,,,,,,,,,,, -28600,2.266763,1.523601,,,,,,,,,,,,,, -28700,2.1790202,1.5323473,,,,,,,,,,,,,, -28704,,,0.30595523,0.1032869033910158,0.6050423,0.1738706469583015,5348.0,0.3550856,0.1157556923201917,2472.0,23062.360239744183,25429.651166677475,23062.360239744183,2365.1423330307007,0.8519728183746338,0.0 -28800,2.062605,1.5174,,,,,,,,,,,,,, -28900,1.99171,1.5385579,,,,,,,,,,,,,, -29000,3.1315007,1.459182,,,,,,,,,,,,,, -29100,2.1659596,1.4904736,,,,,,,,,,,,,, -29200,2.0810974,1.502389,,,,,,,,,,,,,, -29300,2.2628055,1.4698683,,,,,,,,,,,,,, -29400,2.7646112,1.5152928,,,,,,,,,,,,,, -29500,4.4674754,1.5515231,,,,,,,,,,,,,, -29600,1.6607236,1.4933678,,,,,,,,,,,,,, -29700,2.1607637,1.49834,,,,,,,,,,,,,, -29800,2.7831576,1.4955902,,,,,,,,,,,,,, -29900,2.8445954,1.4559997,,,,,,,,,,,,,, -30000,2.3502746,1.4033072,,,,,,,,,,,,,, -30100,1.9316044,1.4403894,,,,,,,,,,,,,, -30200,2.3136182,1.4765605,,,,,,,,,,,,,, -30300,2.8794346,1.4998139,,,,,,,,,,,,,, -30400,4.305969,1.4315038,,,,,,,,,,,,,, -30500,1.9486443,1.4750141,,,,,,,,,,,,,, -30507,,,0.2792333,0.0960252405609945,0.5789832,0.1674309933672533,5348.0,0.33707365,0.1091544289399386,2472.0,24503.368768692017,27006.50775051117,24503.368768692017,2500.8551738262177,0.9049315452575684,0.0 -30600,3.7370853,1.5398817,,,,,,,,,,,,,, -30700,2.2587337,1.4963022,,,,,,,,,,,,,, -30800,1.5663055,1.4776946,,,,,,,,,,,,,, -30900,2.5251553,1.4958942,,,,,,,,,,,,,, -31000,3.8962476,1.4739816,,,,,,,,,,,,,, -31100,3.7257826,1.5078332,,,,,,,,,,,,,, -31200,2.87347,1.460807,,,,,,,,,,,,,, -31300,3.3989947,1.4311584,,,,,,,,,,,,,, -31400,1.8615774,1.4425337,,,,,,,,,,,,,, -31500,2.3977904,1.4599686,,,,,,,,,,,,,, -31600,1.5960914,1.3917744,,,,,,,,,,,,,, -31700,1.971349,1.4857125,,,,,,,,,,,,,, -31800,2.7631917,1.5086154,,,,,,,,,,,,,, -31900,2.5959291,1.3562217,,,,,,,,,,,,,, -32000,2.2772655,1.3690877,,,,,,,,,,,,,, -32100,2.0310888,1.4536624,,,,,,,,,,,,,, -32200,3.7010303,1.4456321,,,,,,,,,,,,,, -32300,1.7712647,1.3969322,,,,,,,,,,,,,, -32326,,,0.26147306,0.0895716509242873,0.5437209,0.1586549137356749,5348.0,0.31017295,0.0994658054556903,2472.0,25943.38379716873,28581.357215881348,25943.38379716873,2635.5505130290985,0.9616458415985109,0.0 -32400,2.1949925,1.4715042,,,,,,,,,,,,,, -32500,2.5745492,1.4251665,,,,,,,,,,,,,, -32600,2.7661562,1.3397549,,,,,,,,,,,,,, -32700,2.5932288,1.3629625,,,,,,,,,,,,,, -32800,2.9929245,1.4021475,,,,,,,,,,,,,, -32900,1.8471522,1.3900422,,,,,,,,,,,,,, -33000,2.5382118,1.3974724,,,,,,,,,,,,,, -33100,1.8578038,1.4240631,,,,,,,,,,,,,, -33200,1.9299922,1.445138,,,,,,,,,,,,,, -33300,2.0470772,1.3943163,,,,,,,,,,,,,, -33400,2.4200585,1.3333107,,,,,,,,,,,,,, -33500,1.493197,1.3902557,,,,,,,,,,,,,, -33600,2.302345,1.37027,,,,,,,,,,,,,, -33700,1.7316784,1.329682,,,,,,,,,,,,,, -33800,3.110834,1.3936256,,,,,,,,,,,,,, -33900,3.1921253,1.3560069,,,,,,,,,,,,,, -34000,4.8770356,1.3903055,,,,,,,,,,,,,, -34100,1.7707328,1.3621941,,,,,,,,,,,,,, -34115,,,0.23398308,0.0817539105541467,0.51874906,0.1500622725122372,5348.0,0.29474744,0.0952816200515914,2472.0,27383.67330932617,30157.786672353745,27383.67330932617,2771.551521062851,1.019352912902832,0.0 -34200,1.3829465,1.354673,,,,,,,,,,,,,, -34300,3.3488522,1.3773419,,,,,,,,,,,,,, -34400,2.194156,1.3601433,,,,,,,,,,,,,, -34500,3.1184423,1.3659201,,,,,,,,,,,,,, -34600,1.6129599,1.3598609,,,,,,,,,,,,,, -34700,1.6443785,1.3623115,,,,,,,,,,,,,, -34800,1.7379116,1.3407234,,,,,,,,,,,,,, -34900,2.585359,1.4185568,,,,,,,,,,,,,, -35000,1.9152533,1.3707522,,,,,,,,,,,,,, -35100,2.43093,1.3749954,,,,,,,,,,,,,, -35200,1.9972389,1.3080704,,,,,,,,,,,,,, -35300,1.981178,1.3461201,,,,,,,,,,,,,, -35400,2.2937005,1.3561089,,,,,,,,,,,,,, -35500,2.4277942,1.3595736,,,,,,,,,,,,,, -35600,2.2316155,1.3238982,,,,,,,,,,,,,, -35700,1.8096721,1.3486605,,,,,,,,,,,,,, -35800,2.884621,1.3030735,,,,,,,,,,,,,, -35900,1.6321298,1.3576561,,,,,,,,,,,,,, -35915,,,0.25606412,0.0835476626773314,0.49699828,0.1436612375334292,5348.0,0.27777967,0.0904677756789145,2472.0,28824.013179302216,31734.248168230057,28824.013179302216,2907.5363302230835,1.075251579284668,0.0 -36000,3.154991,1.3818473,,,,,,,,,,,,,, -36100,6.4689326,1.2870562,,,,,,,,,,,,,, -36200,2.8296156,1.2765713,,,,,,,,,,,,,, -36300,3.3258622,1.3618257,,,,,,,,,,,,,, -36400,2.3314676,1.2969822,,,,,,,,,,,,,, -36500,1.352626,1.3387536,,,,,,,,,,,,,, -36600,2.334899,1.3703932,,,,,,,,,,,,,, -36700,3.5430813,1.2943281,,,,,,,,,,,,,, -36800,3.0539744,1.2882965,,,,,,,,,,,,,, -36900,1.7224653,1.298516,,,,,,,,,,,,,, -37000,2.7775319,1.2891002,,,,,,,,,,,,,, -37100,1.8925534,1.3020006,,,,,,,,,,,,,, -37200,3.1100032,1.2778343,,,,,,,,,,,,,, -37300,1.7656263,1.2982068,,,,,,,,,,,,,, -37400,1.8303001,1.3363458,,,,,,,,,,,,,, -37500,1.7455261,1.2712995,,,,,,,,,,,,,, -37600,1.936488,1.3402406,,,,,,,,,,,,,, -37700,2.4536138,1.300963,,,,,,,,,,,,,, -37721,,,0.20800158,0.0720675078764466,0.47449413,0.1377815538198634,5348.0,0.26118982,0.0833180996486096,2472.0,30264.20079755783,33311.580392599106,30264.20079755783,3044.5422701835632,1.1326208114624023,0.0 -37800,2.1587682,1.3662511,,,,,,,,,,,,,, -37900,3.026114,1.3491422,,,,,,,,,,,,,, -38000,1.4121199,1.2719474,,,,,,,,,,,,,, -38100,2.357528,1.3116807,,,,,,,,,,,,,, -38200,3.089932,1.2772626,,,,,,,,,,,,,, -38300,1.5428772,1.2614571,,,,,,,,,,,,,, -38400,1.6725423,1.2523215,,,,,,,,,,,,,, -38500,2.054454,1.2905996,,,,,,,,,,,,,, -38600,2.4416907,1.3162302,,,,,,,,,,,,,, -38700,3.041414,1.285817,,,,,,,,,,,,,, -38800,2.119477,1.2648721,,,,,,,,,,,,,, -38900,3.676389,1.2280129,,,,,,,,,,,,,, -39000,3.8962934,1.3158864,,,,,,,,,,,,,, -39100,2.4464972,1.3155171,,,,,,,,,,,,,, -39200,3.908872,1.2324454,,,,,,,,,,,,,, -39300,2.1682532,1.2373816,,,,,,,,,,,,,, -39400,2.1243463,1.2855235,,,,,,,,,,,,,, -39500,2.134438,1.2179582,,,,,,,,,,,,,, -39546,,,0.17795631,0.0637573506654286,0.45436054,0.1326259690858009,5348.0,0.24431218,0.0793979647797209,2472.0,31704.79609966278,34888.67411828041,31704.79609966278,3180.9032578468323,1.188124656677246,0.0 -39600,2.3110967,1.3051124,,,,,,,,,,,,,, -39700,1.6274064,1.2565491,,,,,,,,,,,,,, -39800,3.6014025,1.2649932,,,,,,,,,,,,,, -39900,2.7060425,1.2555735,,,,,,,,,,,,,, -40000,1.7134603,1.2702323,,,,,,,,,,,,,, -40100,2.6637144,1.2670678,,,,,,,,,,,,,, -40200,2.850623,1.2418011,,,,,,,,,,,,,, -40300,1.9462199,1.1266465,,,,,,,,,,,,,, -40400,2.6516163,1.2378963,,,,,,,,,,,,,, -40500,1.4750957,1.2052003,,,,,,,,,,,,,, -40600,1.8901627,1.250134,,,,,,,,,,,,,, -40700,2.280034,1.154497,,,,,,,,,,,,,, -40800,2.0306191,1.2374828,,,,,,,,,,,,,, -40900,2.5891988,1.238197,,,,,,,,,,,,,, -41000,3.0874891,1.2160735,,,,,,,,,,,,,, -41100,1.8458512,1.2052763,,,,,,,,,,,,,, -41200,2.5274432,1.1797996,,,,,,,,,,,,,, -41300,2.2692277,1.134357,,,,,,,,,,,,,, -41345,,,0.16376394,0.0562627900178063,0.43039808,0.1254139432499493,5348.0,0.23327044,0.0745638088274125,2472.0,33145.08502173424,36463.99954032898,33145.08502173424,3315.796919107437,1.2503759860992432,0.0 -41400,2.026322,1.2159897,,,,,,,,,,,,,, -41500,2.5507991,1.155378,,,,,,,,,,,,,, -41600,2.8639922,1.1885568,,,,,,,,,,,,,, -41700,2.1319587,1.1488266,,,,,,,,,,,,,, -41800,3.026354,1.1583183,,,,,,,,,,,,,, -41900,1.7501775,1.1770072,,,,,,,,,,,,,, -42000,3.0011516,1.2596434,,,,,,,,,,,,,, -42100,3.0139213,1.1940053,,,,,,,,,,,,,, -42200,2.2001698,1.2031432,,,,,,,,,,,,,, -42300,2.826396,1.2108456,,,,,,,,,,,,,, -42400,2.298977,1.1856351,,,,,,,,,,,,,, -42500,4.38965,1.1544303,,,,,,,,,,,,,, -42600,3.6833484,1.172604,,,,,,,,,,,,,, -42700,2.142073,1.1577048,,,,,,,,,,,,,, -42800,2.124016,1.204417,,,,,,,,,,,,,, -42900,1.9044127,1.1886647,,,,,,,,,,,,,, -43000,1.842496,1.1479641,,,,,,,,,,,,,, -43100,1.92904,1.1440982,,,,,,,,,,,,,, -43139,,,0.14803201,0.0525516795865633,0.41873074,0.1214844994545121,5348.0,0.22326158,0.071618629780838,2472.0,34585.27427482605,38039.684517383575,34585.27427482605,3451.147789478302,1.3133165836334229,0.0 -43200,3.1065314,1.220882,,,,,,,,,,,,,, -43300,3.1360178,1.1791732,,,,,,,,,,,,,, -43400,1.9567341,1.160031,,,,,,,,,,,,,, -43500,2.4616404,1.1761438,,,,,,,,,,,,,, -43600,2.4854069,1.1167446,,,,,,,,,,,,,, -43700,1.404655,1.1665609,,,,,,,,,,,,,, -43800,3.2132413,1.1379864,,,,,,,,,,,,,, -43900,2.2170324,1.1395457,,,,,,,,,,,,,, -44000,2.0296597,1.1873195,,,,,,,,,,,,,, -44100,1.6787992,1.1556461,,,,,,,,,,,,,, -44200,2.274571,1.1570493,,,,,,,,,,,,,, -44300,1.7600248,1.1501092,,,,,,,,,,,,,, -44400,2.2979972,1.1255097,,,,,,,,,,,,,, -44500,3.4415367,1.13202,,,,,,,,,,,,,, -44600,2.1065202,1.1567831,,,,,,,,,,,,,, -44700,2.6358225,1.1487006,,,,,,,,,,,,,, -44800,1.5920421,1.1372101,,,,,,,,,,,,,, -44900,2.61703,1.1990103,,,,,,,,,,,,,, -44943,,,0.16079587,0.054562870721733,0.4065272,0.1176322928835552,5348.0,0.2159714,0.0694046676009993,2472.0,36025.31555771828,39616.17360472679,36025.31555771828,3587.4509439468384,1.3760840892791748,0.0 -45000,2.234555,1.1661128,,,,,,,,,,,,,, -45100,2.9934654,1.0908625,,,,,,,,,,,,,, -45200,2.2050266,1.1344898,,,,,,,,,,,,,, -45300,2.4134192,1.1243521,,,,,,,,,,,,,, -45400,2.5848153,1.1590935,,,,,,,,,,,,,, -45500,2.2225797,1.1247996,,,,,,,,,,,,,, -45600,2.7916481,1.1271546,,,,,,,,,,,,,, -45700,2.6027915,1.160547,,,,,,,,,,,,,, -45800,3.7508261,1.0997449,,,,,,,,,,,,,, -45900,1.7265284,1.0806123,,,,,,,,,,,,,, -46000,3.1106763,1.1745782,,,,,,,,,,,,,, -46100,2.6759398,1.1553972,,,,,,,,,,,,,, -46200,2.0716398,1.1868604,,,,,,,,,,,,,, -46300,2.318609,1.1103486,,,,,,,,,,,,,, -46400,1.8633941,1.0828502,,,,,,,,,,,,,, -46500,2.5765097,1.1088759,,,,,,,,,,,,,, -46600,2.081222,1.1007026,,,,,,,,,,,,,, -46700,2.1766512,1.131804,,,,,,,,,,,,,, -46770,,,0.15611503,0.0529189681591044,0.40304145,0.1167440648020313,5348.0,0.2135498,0.0683484654601588,2472.0,37465.57762217522,41192.97604799271,37465.57762217522,3723.850387096405,1.433948278427124,0.0 -46800,2.6870112,1.1133848,,,,,,,,,,,,,, -46900,1.7230766,1.0830717,,,,,,,,,,,,,, -47000,2.2738938,1.1139296,,,,,,,,,,,,,, -47100,3.38438,1.0890893,,,,,,,,,,,,,, -47200,1.6940862,1.1408336,,,,,,,,,,,,,, -47300,2.783568,1.1386424,,,,,,,,,,,,,, -47400,1.8112502,1.1625122,,,,,,,,,,,,,, -47500,1.5330952,1.1742045,,,,,,,,,,,,,, -47600,2.2977943,1.1339858,,,,,,,,,,,,,, -47700,2.666989,1.0942851,,,,,,,,,,,,,, -47800,4.200168,1.1620668,,,,,,,,,,,,,, -47900,2.6885161,1.0706109,,,,,,,,,,,,,, -48000,,,0.147639,0.0514584651119383,0.40327126,0.1167054461897911,5348.0,0.21367131,0.0685109580972112,2472.0,38433.03055810928,42298.19808506966,38433.03055810928,3861.506258010864,1.4898200035095217,0.0 -48000,,,,,,,,,,,38433.03055810928,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 3315914d6..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,29 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -203.4074904918671,0.0,16.74648356437683,1,0,16.74648356437683,30.942286,2472,3.698819897223407,220.1540470123291,31.380108,3.862109067393991,30.891895,5348,3.324869420817363 -324.5598194599152,0.0317559242248535,1457.3477289676666,1803,0,1457.3477289676666,3.7096946,2472,0.7506753600227489,1782.016296625137,3.9911628,0.7952104345989333,4.117805,5348,0.7905519565154426 -462.5844323635101,0.0894415378570556,2897.5168826580048,3624,0,2897.5168826580048,0.59962285,2472,0.1914569496069709,3360.348862886429,0.5607584,0.1856732763288281,0.92604756,5348,0.2639678693146162 -598.7094995975494,0.1492354869842529,4337.565103292465,5448,0,4337.565103292465,0.493212,2472,0.1578615968963906,4936.664131879807,0.46810666,0.1586667026822614,0.79226905,5348,0.2307172441758305 -736.0015232563019,0.2118763923645019,5777.571934938431,7230,0,5777.571934938431,0.47449654,2472,0.1537180346515548,6514.106091976166,0.40930718,0.1378129600370995,0.7724912,5348,0.2218735819728318 -873.7335147857666,0.2663674354553222,7218.460584163666,9039,0,7218.460584163666,0.42413357,2472,0.1365344382832652,8092.8622970581055,0.39555684,0.1352120428499472,0.7158556,5348,0.2070440348726068 -1011.367288351059,0.3219921588897705,8658.568603992462,10850,0,8658.568603992462,0.40153503,2472,0.1301566022789592,9670.741458892822,0.39360082,0.1260683874330512,0.68587214,5348,0.1970128503432229 -1149.0821301937103,0.3768749237060547,10099.024131059648,12664,0,10099.024131059648,0.38939542,2472,0.1271504884934901,11249.047943592072,0.30456778,0.1065818652211408,0.66065574,5348,0.1916158992826593 -1288.3281433582306,0.4314930438995361,11539.4461581707,14453,0,11539.4461581707,0.37232095,2472,0.1210773261836573,12828.851173877716,0.29636636,0.1025290250093629,0.63199,5348,0.181507477528795 -1424.4248707294464,0.4784595966339111,12980.45679616928,16259,0,12980.45679616928,0.36600688,2472,0.1188633640038185,14406.084884166718,0.4332551,0.1454063390430927,0.6278105,5348,0.1829556754878013 -1561.9921565055847,0.5266251564025879,14420.758267641068,18069,0,14420.758267641068,0.33499104,2472,0.1091341173603071,15984.083220481873,0.41025263,0.136527974202211,0.59172356,5348,0.1722969385095146 -1696.8221898078918,0.5745553970336914,15860.820094823835,19886,0,15860.820094823835,0.32893935,2472,0.1075295025694148,17559.10546183586,0.4899871,0.1588533804027219,0.57561785,5348,0.1675661585100939 -1830.587270498276,0.6198842525482178,17300.74575161934,21663,0,17300.74575161934,0.31936333,2472,0.1038937298153677,19132.92129182816,0.38680267,0.1280546266632369,0.5477699,5348,0.1610685770006854 -1965.1153333187103,0.6744070053100586,18741.40641236305,23473,0,18741.40641236305,0.3084308,2472,0.0999532833668474,20708.24504303932,0.39550456,0.1330114534639051,0.54725,5348,0.1592341929192774 -2100.4818663597107,0.7313663959503174,20181.81382036209,25286,0,20181.81382036209,0.2982347,2472,0.0980643064611134,22284.15694141388,0.3043809,0.105151034241601,0.5312797,5348,0.1542330826341755 -2236.380726337433,0.7822191715240479,21621.99354362488,27090,0,21621.99354362488,0.28729317,2472,0.0920520788901752,23860.36754226685,0.3554216,0.120767605556711,0.51525754,5348,0.1495698852061751 -2373.104640483856,0.8374857902526855,23062.2101726532,28873,0,23062.2101726532,0.27378377,2472,0.0878882050657079,25437.44374489784,0.3024444,0.1036152304609218,0.49139738,5348,0.1432267781457273 -2507.414543867111,0.8959517478942871,24502.66463828087,30676,0,24502.66463828087,0.26265836,2472,0.085369569191396,27012.347834587097,0.2924098,0.102579539414451,0.48229715,5348,0.1403014182685345 -2642.467351436615,0.9536893367767334,25942.722469568253,32491,0,25942.722469568253,0.2568106,2472,0.082850933317084,28587.597064971924,0.2727315,0.0943036249659307,0.4702428,5348,0.137231238595441 -2778.475238084793,1.014662265777588,27382.652527332302,34304,0,27382.652527332302,0.24790536,2472,0.0808603985131923,30163.677755355835,0.26884595,0.0949476657441259,0.46222648,5348,0.133350068065304 -2911.465180158615,1.0729358196258545,28822.911303281784,36084,0,28822.911303281784,0.23915242,2472,0.0777527268295655,31737.064566612244,0.2914037,0.0972122785036692,0.44839144,5348,0.1292951137800863 -3044.3916516304016,1.1292433738708496,30262.89229130745,37885,0,30262.89229130745,0.23163031,2472,0.074360693031097,33310.108652830124,0.22805688,0.0809899755595646,0.43772352,5348,0.1265725016171544 -3177.021288871765,1.1870660781860352,31703.72644138336,39694,0,31703.72644138336,0.22602047,2472,0.0723904698068368,34883.71180319786,0.23828937,0.0854084474355999,0.4246287,5348,0.1229906253318787 -3311.470253229141,1.2442855834960938,33144.32629656792,41500,0,33144.32629656792,0.22364078,2472,0.0725732740235208,36458.89913678169,0.20024322,0.0706996884290118,0.42071438,5348,0.1205093794954478 -3445.4803664684296,1.304053783416748,34584.63703107834,43278,0,34584.63703107834,0.21895356,2472,0.0701765076269981,38033.35954880714,0.21650901,0.0750982776662236,0.41551012,5348,0.118954980352781 -3581.0494186878204,1.3660883903503418,36025.066982507706,45080,0,36025.066982507706,0.21681073,2472,0.0693031097028415,39609.50177645683,0.21757558,0.077070545603486,0.41123518,5348,0.1187715419446402 -3717.407273054123,1.4297056198120115,37465.26635026932,46887,0,37465.26635026932,0.21612823,2472,0.0696280949769463,41186.20255231857,0.22279812,0.0782672265951365,0.4108359,5348,0.118501211658959 -3851.2109656333923,1.4912919998168945,38341.64616441727,48000,0,38341.64616441727,0.21605648,2472,0.06948591391952552,42196.50052070618,0.18652456,0.06758433938846892,0.41070828,5348,0.11837570116917849 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/measurements.csv deleted file mode 100644 index 2f413cc9d..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/measurements.csv +++ /dev/null @@ -1,510 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,20.484734,33.18051,,,,,,,,,,,,,, -1,,,31.380108,3.862109067393991,30.891895,3.324869420817363,5348.0,30.942286,3.698819897223407,2472.0,16.74648356437683,220.1540470123291,16.74648356437683,203.4074904918671,0.0,0.0 -100,0.74213386,5.9834266,,,,,,,,,,,,,, -200,0.3134994,5.8143816,,,,,,,,,,,,,, -300,1.2168733,5.6298113,,,,,,,,,,,,,, -400,1.3223073,5.12092,,,,,,,,,,,,,, -500,1.421955,4.3497562,,,,,,,,,,,,,, -600,1.5679104,3.7516403,,,,,,,,,,,,,, -700,2.0190203,3.4216132,,,,,,,,,,,,,, -800,2.161063,3.0890331,,,,,,,,,,,,,, -900,3.2241173,3.0507622,,,,,,,,,,,,,, -1000,2.4136972,2.8234122,,,,,,,,,,,,,, -1100,3.961841,2.716893,,,,,,,,,,,,,, -1200,1.8294408,2.5896509,,,,,,,,,,,,,, -1300,2.8037698,2.518723,,,,,,,,,,,,,, -1400,2.0711799,2.4661734,,,,,,,,,,,,,, -1500,2.2929113,2.3612108,,,,,,,,,,,,,, -1600,1.8432629,2.3642464,,,,,,,,,,,,,, -1700,3.279991,2.3476741,,,,,,,,,,,,,, -1800,2.6625402,2.3193393,,,,,,,,,,,,,, -1803,,,3.9911628,0.7952104345989333,4.117805,0.7905519565154426,5348.0,3.7096946,0.7506753600227489,2472.0,1457.3477289676666,1782.016296625137,1457.3477289676666,324.5598194599152,0.0317559242248535,0.0 -1900,2.8865309,2.2797894,,,,,,,,,,,,,, -2000,2.8830686,2.2672808,,,,,,,,,,,,,, -2100,3.244476,2.166888,,,,,,,,,,,,,, -2200,3.3636234,2.1775897,,,,,,,,,,,,,, -2300,5.403248,2.1411693,,,,,,,,,,,,,, -2400,2.4775631,2.1575806,,,,,,,,,,,,,, -2500,2.1631606,2.0125382,,,,,,,,,,,,,, -2600,4.89714,2.0674145,,,,,,,,,,,,,, -2700,2.9258647,2.039038,,,,,,,,,,,,,, -2800,2.4834263,1.9755131,,,,,,,,,,,,,, -2900,2.6924477,1.9870903,,,,,,,,,,,,,, -3000,2.6437337,2.007395,,,,,,,,,,,,,, -3100,2.893535,1.9859551,,,,,,,,,,,,,, -3200,3.0869813,1.9112777,,,,,,,,,,,,,, -3300,2.623179,1.9507486,,,,,,,,,,,,,, -3400,3.1165605,1.908474,,,,,,,,,,,,,, -3500,1.6150048,1.939308,,,,,,,,,,,,,, -3600,3.782373,1.9490575,,,,,,,,,,,,,, -3624,,,0.5607584,0.1856732763288281,0.92604756,0.2639678693146162,5348.0,0.59962285,0.1914569496069709,2472.0,2897.5168826580048,3360.348862886429,2897.5168826580048,462.5844323635101,0.0894415378570556,0.0 -3700,3.4302635,1.9222826,,,,,,,,,,,,,, -3800,1.7222766,1.8209114,,,,,,,,,,,,,, -3900,2.0486147,1.851782,,,,,,,,,,,,,, -4000,2.144154,1.8710929,,,,,,,,,,,,,, -4100,2.9942563,1.8273567,,,,,,,,,,,,,, -4200,2.5021548,1.8555206,,,,,,,,,,,,,, -4300,3.624553,1.8424585,,,,,,,,,,,,,, -4400,2.2767656,1.8539667,,,,,,,,,,,,,, -4500,5.1179338,1.8521177,,,,,,,,,,,,,, -4600,2.932478,1.8701842,,,,,,,,,,,,,, -4700,2.9494843,1.7680036,,,,,,,,,,,,,, -4800,2.1669817,1.7684941,,,,,,,,,,,,,, -4900,2.842343,1.6998668,,,,,,,,,,,,,, -5000,3.1230576,1.8089263,,,,,,,,,,,,,, -5100,1.6238672,1.7453041,,,,,,,,,,,,,, -5200,3.285688,1.7287806,,,,,,,,,,,,,, -5300,1.9826094,1.7482527,,,,,,,,,,,,,, -5400,2.3720465,1.755631,,,,,,,,,,,,,, -5448,,,0.46810666,0.1586667026822614,0.79226905,0.2307172441758305,5348.0,0.493212,0.1578615968963906,2472.0,4337.565103292465,4936.664131879807,4337.565103292465,598.7094995975494,0.1492354869842529,0.0 -5500,2.2637331,1.7373216,,,,,,,,,,,,,, -5600,2.4813995,1.7088968,,,,,,,,,,,,,, -5700,3.2396135,1.7360872,,,,,,,,,,,,,, -5800,2.4064276,1.7973839,,,,,,,,,,,,,, -5900,2.4702818,1.6778524,,,,,,,,,,,,,, -6000,3.3211846,1.6601034,,,,,,,,,,,,,, -6100,4.82628,1.6988311,,,,,,,,,,,,,, -6200,2.6297405,1.6607755,,,,,,,,,,,,,, -6300,2.7572813,1.6555907,,,,,,,,,,,,,, -6400,2.4580243,1.6703781,,,,,,,,,,,,,, -6500,3.4329367,1.7071942,,,,,,,,,,,,,, -6600,2.1095543,1.6845961,,,,,,,,,,,,,, -6700,3.6276019,1.7252563,,,,,,,,,,,,,, -6800,2.0174396,1.7607723,,,,,,,,,,,,,, -6900,1.9617645,1.6370692,,,,,,,,,,,,,, -7000,3.9489675,1.684501,,,,,,,,,,,,,, -7100,2.8764756,1.6904887,,,,,,,,,,,,,, -7200,3.4878466,1.6961004,,,,,,,,,,,,,, -7230,,,0.40930718,0.1378129600370995,0.7724912,0.2218735819728318,5348.0,0.47449654,0.1537180346515548,2472.0,5777.571934938431,6514.106091976166,5777.571934938431,736.0015232563019,0.2118763923645019,0.0 -7300,2.6239345,1.66035,,,,,,,,,,,,,, -7400,2.102375,1.646288,,,,,,,,,,,,,, -7500,3.3941765,1.6848036,,,,,,,,,,,,,, -7600,2.6547225,1.6596928,,,,,,,,,,,,,, -7700,2.9297853,1.6785185,,,,,,,,,,,,,, -7800,2.7488253,1.703931,,,,,,,,,,,,,, -7900,3.1697154,1.6783984,,,,,,,,,,,,,, -8000,3.0410678,1.642336,,,,,,,,,,,,,, -8100,4.015056,1.641035,,,,,,,,,,,,,, -8200,1.5925337,1.594501,,,,,,,,,,,,,, -8300,2.783434,1.6028206,,,,,,,,,,,,,, -8400,2.8319905,1.6765181,,,,,,,,,,,,,, -8500,2.6669545,1.6176598,,,,,,,,,,,,,, -8600,2.44186,1.7106022,,,,,,,,,,,,,, -8700,2.438644,1.688127,,,,,,,,,,,,,, -8800,2.8529396,1.6040909,,,,,,,,,,,,,, -8900,5.088569,1.5476049,,,,,,,,,,,,,, -9000,2.8790438,1.6763254,,,,,,,,,,,,,, -9039,,,0.39555684,0.1352120428499472,0.7158556,0.2070440348726068,5348.0,0.42413357,0.1365344382832652,2472.0,7218.460584163666,8092.8622970581055,7218.460584163666,873.7335147857666,0.2663674354553222,0.0 -9100,2.937503,1.7114329,,,,,,,,,,,,,, -9200,5.1203036,1.6317747,,,,,,,,,,,,,, -9300,2.4055357,1.6098064,,,,,,,,,,,,,, -9400,2.645508,1.6756252,,,,,,,,,,,,,, -9500,3.7351894,1.6059096,,,,,,,,,,,,,, -9600,2.3249998,1.6001471,,,,,,,,,,,,,, -9700,2.4795587,1.5887771,,,,,,,,,,,,,, -9800,2.6701238,1.7111017,,,,,,,,,,,,,, -9900,2.4844608,1.5849646,,,,,,,,,,,,,, -10000,4.877182,1.6340953,,,,,,,,,,,,,, -10100,2.2107873,1.5429083,,,,,,,,,,,,,, -10200,2.85582,1.5845945,,,,,,,,,,,,,, -10300,2.5963986,1.5835495,,,,,,,,,,,,,, -10400,2.3718498,1.5309194,,,,,,,,,,,,,, -10500,2.8314161,1.6079403,,,,,,,,,,,,,, -10600,3.2849133,1.6032845,,,,,,,,,,,,,, -10700,3.0923133,1.6355208,,,,,,,,,,,,,, -10800,2.085449,1.5547279,,,,,,,,,,,,,, -10850,,,0.39360082,0.1260683874330512,0.68587214,0.1970128503432229,5348.0,0.40153503,0.1301566022789592,2472.0,8658.568603992462,9670.741458892822,8658.568603992462,1011.367288351059,0.3219921588897705,0.0 -10900,2.23743,1.6339291,,,,,,,,,,,,,, -11000,2.7490563,1.6056687,,,,,,,,,,,,,, -11100,4.058692,1.6007514,,,,,,,,,,,,,, -11200,2.77024,1.5279447,,,,,,,,,,,,,, -11300,2.3942122,1.6149831,,,,,,,,,,,,,, -11400,4.1460204,1.5347179,,,,,,,,,,,,,, -11500,2.8892715,1.5591502,,,,,,,,,,,,,, -11600,2.761071,1.5501885,,,,,,,,,,,,,, -11700,4.3123193,1.6545883,,,,,,,,,,,,,, -11800,2.8104167,1.6020694,,,,,,,,,,,,,, -11900,4.337468,1.5453861,,,,,,,,,,,,,, -12000,3.2830362,1.5584182,,,,,,,,,,,,,, -12100,3.7636077,1.5394853,,,,,,,,,,,,,, -12200,2.990751,1.6159576,,,,,,,,,,,,,, -12300,4.6938834,1.6395761,,,,,,,,,,,,,, -12400,2.7136185,1.5754529,,,,,,,,,,,,,, -12500,3.2462232,1.5377337,,,,,,,,,,,,,, -12600,3.2066526,1.5521545,,,,,,,,,,,,,, -12664,,,0.30456778,0.1065818652211408,0.66065574,0.1916158992826593,5348.0,0.38939542,0.1271504884934901,2472.0,10099.024131059648,11249.047943592072,10099.024131059648,1149.0821301937103,0.3768749237060547,0.0 -12700,2.4963315,1.5559157,,,,,,,,,,,,,, -12800,3.567303,1.635151,,,,,,,,,,,,,, -12900,2.73678,1.5770562,,,,,,,,,,,,,, -13000,3.982029,1.5504612,,,,,,,,,,,,,, -13100,2.2482026,1.5017103,,,,,,,,,,,,,, -13200,2.9714227,1.5704271,,,,,,,,,,,,,, -13300,2.55209,1.5564058,,,,,,,,,,,,,, -13400,4.793172,1.560251,,,,,,,,,,,,,, -13500,2.6873784,1.553968,,,,,,,,,,,,,, -13600,2.3085275,1.5724201,,,,,,,,,,,,,, -13700,3.4236276,1.4687093,,,,,,,,,,,,,, -13800,2.9649534,1.4536018,,,,,,,,,,,,,, -13900,2.4902115,1.5556295,,,,,,,,,,,,,, -14000,2.7686758,1.5971714,,,,,,,,,,,,,, -14100,4.537937,1.5164636,,,,,,,,,,,,,, -14200,2.4217994,1.5826789,,,,,,,,,,,,,, -14300,3.041007,1.5296273,,,,,,,,,,,,,, -14400,3.1058927,1.5730733,,,,,,,,,,,,,, -14453,,,0.29636636,0.1025290250093629,0.63199,0.181507477528795,5348.0,0.37232095,0.1210773261836573,2472.0,11539.4461581707,12828.851173877716,11539.4461581707,1288.3281433582306,0.4314930438995361,0.0 -14500,2.9308543,1.5139526,,,,,,,,,,,,,, -14600,2.1230931,1.5381941,,,,,,,,,,,,,, -14700,3.712639,1.4771504,,,,,,,,,,,,,, -14800,2.5284219,1.5265878,,,,,,,,,,,,,, -14900,1.7178209,1.537362,,,,,,,,,,,,,, -15000,2.4876642,1.5389534,,,,,,,,,,,,,, -15100,4.280358,1.5741764,,,,,,,,,,,,,, -15200,2.6780603,1.5300001,,,,,,,,,,,,,, -15300,2.0252445,1.5487648,,,,,,,,,,,,,, -15400,3.4261594,1.5386333,,,,,,,,,,,,,, -15500,1.9904454,1.5261174,,,,,,,,,,,,,, -15600,2.1034985,1.5391964,,,,,,,,,,,,,, -15700,3.4457152,1.5509164,,,,,,,,,,,,,, -15800,3.2624102,1.5115883,,,,,,,,,,,,,, -15900,2.4977043,1.4783716,,,,,,,,,,,,,, -16000,3.1500258,1.4993011,,,,,,,,,,,,,, -16100,1.9477059,1.4483807,,,,,,,,,,,,,, -16200,2.4484735,1.5024052,,,,,,,,,,,,,, -16259,,,0.4332551,0.1454063390430927,0.6278105,0.1829556754878013,5348.0,0.36600688,0.1188633640038185,2472.0,12980.45679616928,14406.084884166718,12980.45679616928,1424.4248707294464,0.4784595966339111,0.0 -16300,2.9445179,1.556127,,,,,,,,,,,,,, -16400,1.5814002,1.4474385,,,,,,,,,,,,,, -16500,2.760056,1.4817994,,,,,,,,,,,,,, -16600,2.3599474,1.5187898,,,,,,,,,,,,,, -16700,2.740259,1.4478036,,,,,,,,,,,,,, -16800,3.721732,1.5317696,,,,,,,,,,,,,, -16900,2.8361087,1.4481896,,,,,,,,,,,,,, -17000,2.6351645,1.4905326,,,,,,,,,,,,,, -17100,2.6458638,1.4119637,,,,,,,,,,,,,, -17200,2.2550678,1.3982569,,,,,,,,,,,,,, -17300,1.8983148,1.4439479,,,,,,,,,,,,,, -17400,2.7678218,1.5742773,,,,,,,,,,,,,, -17500,2.7560189,1.5363953,,,,,,,,,,,,,, -17600,2.5261464,1.4699354,,,,,,,,,,,,,, -17700,2.3296902,1.4236199,,,,,,,,,,,,,, -17800,3.3310304,1.5257113,,,,,,,,,,,,,, -17900,2.1868765,1.5336541,,,,,,,,,,,,,, -18000,2.0168037,1.4413775,,,,,,,,,,,,,, -18069,,,0.41025263,0.136527974202211,0.59172356,0.1722969385095146,5348.0,0.33499104,0.1091341173603071,2472.0,14420.758267641068,15984.083220481873,14420.758267641068,1561.9921565055847,0.5266251564025879,0.0 -18100,2.9702213,1.3913066,,,,,,,,,,,,,, -18200,2.7842171,1.4186984,,,,,,,,,,,,,, -18300,1.971073,1.4713314,,,,,,,,,,,,,, -18400,2.776698,1.4715827,,,,,,,,,,,,,, -18500,4.9375267,1.4421638,,,,,,,,,,,,,, -18600,2.0166075,1.4692299,,,,,,,,,,,,,, -18700,1.9725113,1.4542358,,,,,,,,,,,,,, -18800,2.5208864,1.385426,,,,,,,,,,,,,, -18900,2.0971575,1.4677912,,,,,,,,,,,,,, -19000,2.5848603,1.4372818,,,,,,,,,,,,,, -19100,2.3677788,1.4360021,,,,,,,,,,,,,, -19200,4.759363,1.4234055,,,,,,,,,,,,,, -19300,2.7758658,1.4758484,,,,,,,,,,,,,, -19400,3.211294,1.4525795,,,,,,,,,,,,,, -19500,2.5564566,1.4108793,,,,,,,,,,,,,, -19600,2.5925612,1.4623849,,,,,,,,,,,,,, -19700,2.2861657,1.3994086,,,,,,,,,,,,,, -19800,2.156554,1.4936469,,,,,,,,,,,,,, -19886,,,0.4899871,0.1588533804027219,0.57561785,0.1675661585100939,5348.0,0.32893935,0.1075295025694148,2472.0,15860.820094823835,17559.10546183586,15860.820094823835,1696.8221898078918,0.5745553970336914,0.0 -19900,2.1022818,1.4503834,,,,,,,,,,,,,, -20000,2.5608063,1.4695017,,,,,,,,,,,,,, -20100,1.9420046,1.4725575,,,,,,,,,,,,,, -20200,2.4107633,1.4462284,,,,,,,,,,,,,, -20300,3.052617,1.4316577,,,,,,,,,,,,,, -20400,2.8914318,1.4396073,,,,,,,,,,,,,, -20500,2.713993,1.4884845,,,,,,,,,,,,,, -20600,3.4488008,1.4210452,,,,,,,,,,,,,, -20700,3.5313873,1.4053211,,,,,,,,,,,,,, -20800,2.747072,1.4350194,,,,,,,,,,,,,, -20900,2.0603335,1.4666885,,,,,,,,,,,,,, -21000,1.5020522,1.418771,,,,,,,,,,,,,, -21100,3.223429,1.4537508,,,,,,,,,,,,,, -21200,4.116687,1.4687887,,,,,,,,,,,,,, -21300,2.528088,1.4515759,,,,,,,,,,,,,, -21400,2.6876504,1.4397634,,,,,,,,,,,,,, -21500,2.815967,1.4690922,,,,,,,,,,,,,, -21600,2.5630255,1.3909175,,,,,,,,,,,,,, -21663,,,0.38680267,0.1280546266632369,0.5477699,0.1610685770006854,5348.0,0.31936333,0.1038937298153677,2472.0,17300.74575161934,19132.92129182816,17300.74575161934,1830.587270498276,0.6198842525482178,0.0 -21700,3.073709,1.3484205,,,,,,,,,,,,,, -21800,2.8492184,1.4717222,,,,,,,,,,,,,, -21900,2.620919,1.4271389,,,,,,,,,,,,,, -22000,2.2077188,1.3820609,,,,,,,,,,,,,, -22100,2.3195934,1.3174173,,,,,,,,,,,,,, -22200,3.2770433,1.3367198,,,,,,,,,,,,,, -22300,2.4207094,1.4050943,,,,,,,,,,,,,, -22400,2.1346314,1.4450643,,,,,,,,,,,,,, -22500,3.4989536,1.4562424,,,,,,,,,,,,,, -22600,3.2246194,1.4411615,,,,,,,,,,,,,, -22700,2.2186487,1.4081804,,,,,,,,,,,,,, -22800,2.7294474,1.3984406,,,,,,,,,,,,,, -22900,3.9503217,1.4385365,,,,,,,,,,,,,, -23000,2.1580582,1.417938,,,,,,,,,,,,,, -23100,1.8190292,1.4084331,,,,,,,,,,,,,, -23200,3.411206,1.455835,,,,,,,,,,,,,, -23300,2.0607376,1.4343535,,,,,,,,,,,,,, -23400,1.7336923,1.3748772,,,,,,,,,,,,,, -23473,,,0.39550456,0.1330114534639051,0.54725,0.1592341929192774,5348.0,0.3084308,0.0999532833668474,2472.0,18741.40641236305,20708.24504303932,18741.40641236305,1965.1153333187103,0.6744070053100586,0.0 -23500,2.867396,1.387893,,,,,,,,,,,,,, -23600,2.0971534,1.4579465,,,,,,,,,,,,,, -23700,2.029802,1.383554,,,,,,,,,,,,,, -23800,1.8576688,1.4693025,,,,,,,,,,,,,, -23900,2.485642,1.4370034,,,,,,,,,,,,,, -24000,2.2901952,1.4326504,,,,,,,,,,,,,, -24100,3.1242292,1.3810128,,,,,,,,,,,,,, -24200,2.3221343,1.3739048,,,,,,,,,,,,,, -24300,2.355068,1.3916311,,,,,,,,,,,,,, -24400,3.4882274,1.3921068,,,,,,,,,,,,,, -24500,2.5843313,1.3776841,,,,,,,,,,,,,, -24600,2.730191,1.374827,,,,,,,,,,,,,, -24700,2.5497336,1.3345364,,,,,,,,,,,,,, -24800,2.381929,1.3076792,,,,,,,,,,,,,, -24900,2.8930852,1.397612,,,,,,,,,,,,,, -25000,2.1562238,1.3467286,,,,,,,,,,,,,, -25100,2.8757331,1.3534867,,,,,,,,,,,,,, -25200,4.8474183,1.4634635,,,,,,,,,,,,,, -25286,,,0.3043809,0.105151034241601,0.5312797,0.1542330826341755,5348.0,0.2982347,0.0980643064611134,2472.0,20181.81382036209,22284.15694141388,20181.81382036209,2100.4818663597107,0.7313663959503174,0.0 -25300,3.905393,1.4153953,,,,,,,,,,,,,, -25400,2.3251467,1.3716652,,,,,,,,,,,,,, -25500,2.0185344,1.3727343,,,,,,,,,,,,,, -25600,2.4301634,1.3740976,,,,,,,,,,,,,, -25700,3.247579,1.4104929,,,,,,,,,,,,,, -25800,3.2623405,1.375311,,,,,,,,,,,,,, -25900,2.3838582,1.3907167,,,,,,,,,,,,,, -26000,2.1275816,1.4050318,,,,,,,,,,,,,, -26100,3.0767057,1.3373694,,,,,,,,,,,,,, -26200,2.870218,1.3998995,,,,,,,,,,,,,, -26300,2.243197,1.3961213,,,,,,,,,,,,,, -26400,2.0791001,1.3138268,,,,,,,,,,,,,, -26500,2.3446627,1.3795649,,,,,,,,,,,,,, -26600,3.2726946,1.3807691,,,,,,,,,,,,,, -26700,2.1071742,1.3564427,,,,,,,,,,,,,, -26800,2.5856915,1.3435575,,,,,,,,,,,,,, -26900,2.3685493,1.3511559,,,,,,,,,,,,,, -27000,2.6901515,1.2821618,,,,,,,,,,,,,, -27090,,,0.3554216,0.120767605556711,0.51525754,0.1495698852061751,5348.0,0.28729317,0.0920520788901752,2472.0,21621.99354362488,23860.36754226685,21621.99354362488,2236.380726337433,0.7822191715240479,0.0 -27100,1.6100531,1.3470913,,,,,,,,,,,,,, -27200,3.0828419,1.3547667,,,,,,,,,,,,,, -27300,2.1616995,1.2760901,,,,,,,,,,,,,, -27400,2.4942226,1.36475,,,,,,,,,,,,,, -27500,2.4217381,1.3054855,,,,,,,,,,,,,, -27600,2.3235154,1.2727818,,,,,,,,,,,,,, -27700,2.3359857,1.3167448,,,,,,,,,,,,,, -27800,2.8568811,1.3618107,,,,,,,,,,,,,, -27900,2.1259816,1.3338097,,,,,,,,,,,,,, -28000,3.4142253,1.3161986,,,,,,,,,,,,,, -28100,3.8061569,1.3831513,,,,,,,,,,,,,, -28200,2.7813766,1.3440773,,,,,,,,,,,,,, -28300,2.8733642,1.3071958,,,,,,,,,,,,,, -28400,1.4970998,1.3811427,,,,,,,,,,,,,, -28500,2.2059634,1.2981328,,,,,,,,,,,,,, -28600,2.5109057,1.3387572,,,,,,,,,,,,,, -28700,3.1237745,1.4040014,,,,,,,,,,,,,, -28800,3.2232869,1.3756758,,,,,,,,,,,,,, -28873,,,0.3024444,0.1036152304609218,0.49139738,0.1432267781457273,5348.0,0.27378377,0.0878882050657079,2472.0,23062.2101726532,25437.44374489784,23062.2101726532,2373.104640483856,0.8374857902526855,0.0 -28900,2.333491,1.2787442,,,,,,,,,,,,,, -29000,2.092814,1.2834747,,,,,,,,,,,,,, -29100,2.6718438,1.294449,,,,,,,,,,,,,, -29200,2.5865645,1.3196459,,,,,,,,,,,,,, -29300,1.7925372,1.2974322,,,,,,,,,,,,,, -29400,3.1529748,1.3595623,,,,,,,,,,,,,, -29500,2.3504405,1.3158704,,,,,,,,,,,,,, -29600,2.673802,1.3365266,,,,,,,,,,,,,, -29700,1.9303893,1.3053864,,,,,,,,,,,,,, -29800,2.1552496,1.2939103,,,,,,,,,,,,,, -29900,3.115611,1.2435895,,,,,,,,,,,,,, -30000,2.0795627,1.2406651,,,,,,,,,,,,,, -30100,2.6078742,1.2398251,,,,,,,,,,,,,, -30200,2.50656,1.3216089,,,,,,,,,,,,,, -30300,2.167051,1.2506078,,,,,,,,,,,,,, -30400,2.3586528,1.3223811,,,,,,,,,,,,,, -30500,2.2406106,1.3123415,,,,,,,,,,,,,, -30600,3.266064,1.3525077,,,,,,,,,,,,,, -30676,,,0.2924098,0.102579539414451,0.48229715,0.1403014182685345,5348.0,0.26265836,0.085369569191396,2472.0,24502.66463828087,27012.347834587097,24502.66463828087,2507.414543867111,0.8959517478942871,0.0 -30700,2.7757852,1.2985802,,,,,,,,,,,,,, -30800,2.5525243,1.3007731,,,,,,,,,,,,,, -30900,3.0939586,1.3110178,,,,,,,,,,,,,, -31000,3.06867,1.297288,,,,,,,,,,,,,, -31100,2.4167268,1.3135039,,,,,,,,,,,,,, -31200,2.1063406,1.2709627,,,,,,,,,,,,,, -31300,2.4234335,1.3269049,,,,,,,,,,,,,, -31400,1.716328,1.2722394,,,,,,,,,,,,,, -31500,2.0752716,1.2893872,,,,,,,,,,,,,, -31600,2.1532784,1.2631004,,,,,,,,,,,,,, -31700,2.6406517,1.2614125,,,,,,,,,,,,,, -31800,2.3225403,1.3502748,,,,,,,,,,,,,, -31900,3.9404569,1.2716179,,,,,,,,,,,,,, -32000,2.4022527,1.2881474,,,,,,,,,,,,,, -32100,1.9171443,1.2385604,,,,,,,,,,,,,, -32200,2.50521,1.2969514,,,,,,,,,,,,,, -32300,2.7501125,1.2600809,,,,,,,,,,,,,, -32400,2.5165267,1.2394211,,,,,,,,,,,,,, -32491,,,0.2727315,0.0943036249659307,0.4702428,0.137231238595441,5348.0,0.2568106,0.082850933317084,2472.0,25942.722469568253,28587.597064971924,25942.722469568253,2642.467351436615,0.9536893367767334,0.0 -32500,1.7421784,1.2719815,,,,,,,,,,,,,, -32600,2.0693033,1.2769524,,,,,,,,,,,,,, -32700,2.0831702,1.2893006,,,,,,,,,,,,,, -32800,2.1377826,1.290442,,,,,,,,,,,,,, -32900,2.8482578,1.2500015,,,,,,,,,,,,,, -33000,2.9075303,1.2454444,,,,,,,,,,,,,, -33100,3.6687078,1.2650204,,,,,,,,,,,,,, -33200,2.5440166,1.2787261,,,,,,,,,,,,,, -33300,3.4827886,1.2539563,,,,,,,,,,,,,, -33400,1.7888727,1.2327907,,,,,,,,,,,,,, -33500,2.3226402,1.3369595,,,,,,,,,,,,,, -33600,2.0689516,1.2418759,,,,,,,,,,,,,, -33700,2.338434,1.2355341,,,,,,,,,,,,,, -33800,4.170946,1.2148923,,,,,,,,,,,,,, -33900,6.718029,1.218358,,,,,,,,,,,,,, -34000,2.4469385,1.2518789,,,,,,,,,,,,,, -34100,2.366356,1.229392,,,,,,,,,,,,,, -34200,2.8983011,1.2190636,,,,,,,,,,,,,, -34300,2.4741046,1.2184913,,,,,,,,,,,,,, -34304,,,0.26884595,0.0949476657441259,0.46222648,0.133350068065304,5348.0,0.24790536,0.0808603985131923,2472.0,27382.652527332302,30163.677755355835,27382.652527332302,2778.475238084793,1.014662265777588,0.0 -34400,2.9038942,1.2127821,,,,,,,,,,,,,, -34500,2.69638,1.2230058,,,,,,,,,,,,,, -34600,1.9302775,1.2112638,,,,,,,,,,,,,, -34700,1.5544317,1.259684,,,,,,,,,,,,,, -34800,2.343986,1.2517226,,,,,,,,,,,,,, -34900,2.3759894,1.2422199,,,,,,,,,,,,,, -35000,6.502762,1.2564758,,,,,,,,,,,,,, -35100,1.8244973,1.2448894,,,,,,,,,,,,,, -35200,3.5402763,1.1748743,,,,,,,,,,,,,, -35300,2.3627636,1.2401277,,,,,,,,,,,,,, -35400,3.4092116,1.2425231,,,,,,,,,,,,,, -35500,2.6997175,1.2319701,,,,,,,,,,,,,, -35600,2.216112,1.1975636,,,,,,,,,,,,,, -35700,3.1115782,1.2318149,,,,,,,,,,,,,, -35800,2.8301795,1.2217934,,,,,,,,,,,,,, -35900,2.531229,1.2437198,,,,,,,,,,,,,, -36000,1.804329,1.2129511,,,,,,,,,,,,,, -36084,,,0.2914037,0.0972122785036692,0.44839144,0.1292951137800863,5348.0,0.23915242,0.0777527268295655,2472.0,28822.911303281784,31737.064566612244,28822.911303281784,2911.465180158615,1.0729358196258545,0.0 -36100,3.2296257,1.1976902,,,,,,,,,,,,,, -36200,4.306237,1.1780405,,,,,,,,,,,,,, -36300,3.098053,1.2188053,,,,,,,,,,,,,, -36400,2.1002908,1.2328907,,,,,,,,,,,,,, -36500,2.3214164,1.2598916,,,,,,,,,,,,,, -36600,2.9527447,1.2140442,,,,,,,,,,,,,, -36700,2.0681772,1.1702893,,,,,,,,,,,,,, -36800,2.8610282,1.1516062,,,,,,,,,,,,,, -36900,3.2994058,1.2024357,,,,,,,,,,,,,, -37000,2.971394,1.1668364,,,,,,,,,,,,,, -37100,2.3341143,1.178437,,,,,,,,,,,,,, -37200,3.7743118,1.2089856,,,,,,,,,,,,,, -37300,3.2393434,1.2424238,,,,,,,,,,,,,, -37400,6.451042,1.1947763,,,,,,,,,,,,,, -37500,2.6085303,1.1549212,,,,,,,,,,,,,, -37600,2.4514382,1.1887486,,,,,,,,,,,,,, -37700,3.4476995,1.2339721,,,,,,,,,,,,,, -37800,1.9849676,1.2117778,,,,,,,,,,,,,, -37885,,,0.22805688,0.0809899755595646,0.43772352,0.1265725016171544,5348.0,0.23163031,0.074360693031097,2472.0,30262.89229130745,33310.108652830124,30262.89229130745,3044.3916516304016,1.1292433738708496,0.0 -37900,2.9568906,1.1939815,,,,,,,,,,,,,, -38000,2.4863153,1.2085061,,,,,,,,,,,,,, -38100,1.6776104,1.1471908,,,,,,,,,,,,,, -38200,3.0697293,1.2160221,,,,,,,,,,,,,, -38300,2.1022937,1.186768,,,,,,,,,,,,,, -38400,1.6945852,1.1905212,,,,,,,,,,,,,, -38500,2.7321246,1.1368432,,,,,,,,,,,,,, -38600,2.21231,1.1619871,,,,,,,,,,,,,, -38700,4.2510986,1.1652176,,,,,,,,,,,,,, -38800,2.3836534,1.2368381,,,,,,,,,,,,,, -38900,2.3629956,1.2032634,,,,,,,,,,,,,, -39000,2.1802764,1.2041448,,,,,,,,,,,,,, -39100,1.4155995,1.2168387,,,,,,,,,,,,,, -39200,3.1507847,1.1953125,,,,,,,,,,,,,, -39300,2.3354328,1.130953,,,,,,,,,,,,,, -39400,2.547296,1.1256735,,,,,,,,,,,,,, -39500,2.414185,1.1672585,,,,,,,,,,,,,, -39600,2.8311791,1.2049305,,,,,,,,,,,,,, -39694,,,0.23828937,0.0854084474355999,0.4246287,0.1229906253318787,5348.0,0.22602047,0.0723904698068368,2472.0,31703.72644138336,34883.71180319786,31703.72644138336,3177.021288871765,1.1870660781860352,0.0 -39700,2.758453,1.1794431,,,,,,,,,,,,,, -39800,3.8541803,1.2104427,,,,,,,,,,,,,, -39900,2.198061,1.1810976,,,,,,,,,,,,,, -40000,4.4814415,1.1192218,,,,,,,,,,,,,, -40100,2.2793596,1.1145285,,,,,,,,,,,,,, -40200,2.7202926,1.1484942,,,,,,,,,,,,,, -40300,2.3472207,1.1127305,,,,,,,,,,,,,, -40400,2.618201,1.2071478,,,,,,,,,,,,,, -40500,2.7303631,1.1987522,,,,,,,,,,,,,, -40600,2.0245335,1.2014037,,,,,,,,,,,,,, -40700,1.9778315,1.1347127,,,,,,,,,,,,,, -40800,2.3716629,1.1229354,,,,,,,,,,,,,, -40900,4.8291817,1.2156234,,,,,,,,,,,,,, -41000,3.22807,1.1504955,,,,,,,,,,,,,, -41100,3.0795043,1.1426481,,,,,,,,,,,,,, -41200,2.1254988,1.1448623,,,,,,,,,,,,,, -41300,2.0023699,1.1721568,,,,,,,,,,,,,, -41400,2.4532793,1.185914,,,,,,,,,,,,,, -41500,,,0.20024322,0.0706996884290118,0.42071438,0.1205093794954478,5348.0,0.22364078,0.0725732740235208,2472.0,33144.32629656792,36458.89913678169,33144.32629656792,3311.470253229141,1.2442855834960938,0.0 -41500,3.0420241,1.1042179,,,,,,,,,,,,,, -41600,5.275315,1.1546991,,,,,,,,,,,,,, -41700,2.4014018,1.1129272,,,,,,,,,,,,,, -41800,2.386598,1.153924,,,,,,,,,,,,,, -41900,3.8647106,1.1390226,,,,,,,,,,,,,, -42000,2.368711,1.1978734,,,,,,,,,,,,,, -42100,3.2184453,1.1714118,,,,,,,,,,,,,, -42200,1.6727208,1.1560347,,,,,,,,,,,,,, -42300,2.8810089,1.1461005,,,,,,,,,,,,,, -42400,3.3580494,1.1301644,,,,,,,,,,,,,, -42500,2.7303505,1.123092,,,,,,,,,,,,,, -42600,2.8428679,1.1588156,,,,,,,,,,,,,, -42700,2.0885053,1.140811,,,,,,,,,,,,,, -42800,3.4741545,1.1785357,,,,,,,,,,,,,, -42900,3.2462552,1.1846592,,,,,,,,,,,,,, -43000,8.090484,1.1317166,,,,,,,,,,,,,, -43100,2.506691,1.1103959,,,,,,,,,,,,,, -43200,3.5427673,1.1172754,,,,,,,,,,,,,, -43278,,,0.21650901,0.0750982776662236,0.41551012,0.118954980352781,5348.0,0.21895356,0.0701765076269981,2472.0,34584.63703107834,38033.35954880714,34584.63703107834,3445.4803664684296,1.304053783416748,0.0 -43300,4.159617,1.133108,,,,,,,,,,,,,, -43400,3.8250492,1.1284118,,,,,,,,,,,,,, -43500,2.9370804,1.1469729,,,,,,,,,,,,,, -43600,3.419183,1.1243393,,,,,,,,,,,,,, -43700,2.5737908,1.1630665,,,,,,,,,,,,,, -43800,4.096757,1.1470841,,,,,,,,,,,,,, -43900,2.872769,1.1652954,,,,,,,,,,,,,, -44000,2.7466993,1.1382365,,,,,,,,,,,,,, -44100,3.3706324,1.1184345,,,,,,,,,,,,,, -44200,3.2225966,1.098551,,,,,,,,,,,,,, -44300,3.3629549,1.1500386,,,,,,,,,,,,,, -44400,2.2344005,1.1274668,,,,,,,,,,,,,, -44500,3.5951831,1.131491,,,,,,,,,,,,,, -44600,3.1159465,1.1532732,,,,,,,,,,,,,, -44700,3.883373,1.1209599,,,,,,,,,,,,,, -44800,3.0590246,1.1032226,,,,,,,,,,,,,, -44900,3.64622,1.1005973,,,,,,,,,,,,,, -45000,3.5943773,1.1482221,,,,,,,,,,,,,, -45080,,,0.21757558,0.077070545603486,0.41123518,0.1187715419446402,5348.0,0.21681073,0.0693031097028415,2472.0,36025.066982507706,39609.50177645683,36025.066982507706,3581.0494186878204,1.3660883903503418,0.0 -45100,3.1854126,1.1027305,,,,,,,,,,,,,, -45200,2.5725987,1.0975407,,,,,,,,,,,,,, -45300,3.6176784,1.0790193,,,,,,,,,,,,,, -45400,3.4625742,1.1117942,,,,,,,,,,,,,, -45500,2.062218,1.1343453,,,,,,,,,,,,,, -45600,3.3383646,1.1100503,,,,,,,,,,,,,, -45700,3.111824,1.1276349,,,,,,,,,,,,,, -45800,2.2954488,1.1137878,,,,,,,,,,,,,, -45900,2.5935667,1.1000828,,,,,,,,,,,,,, -46000,3.1830933,1.1353642,,,,,,,,,,,,,, -46100,3.5510998,1.1224293,,,,,,,,,,,,,, -46200,2.4739437,1.1830022,,,,,,,,,,,,,, -46300,2.2618365,1.1101975,,,,,,,,,,,,,, -46400,2.3077817,1.10468,,,,,,,,,,,,,, -46500,4.147683,1.1165084,,,,,,,,,,,,,, -46600,2.4203622,1.1131771,,,,,,,,,,,,,, -46700,4.2061496,1.1637043,,,,,,,,,,,,,, -46800,4.597704,1.1317294,,,,,,,,,,,,,, -46887,,,0.22279812,0.0782672265951365,0.4108359,0.118501211658959,5348.0,0.21612823,0.0696280949769463,2472.0,37465.26635026932,41186.20255231857,37465.26635026932,3717.407273054123,1.4297056198120115,0.0 -46900,2.830828,1.082628,,,,,,,,,,,,,, -47000,2.2020867,1.0925554,,,,,,,,,,,,,, -47100,2.7898786,1.1304334,,,,,,,,,,,,,, -47200,2.9250956,1.0904872,,,,,,,,,,,,,, -47300,1.8802117,1.164326,,,,,,,,,,,,,, -47400,2.0821135,1.1695347,,,,,,,,,,,,,, -47500,4.0499825,1.1640494,,,,,,,,,,,,,, -47600,3.3625963,1.1132065,,,,,,,,,,,,,, -47700,2.283709,1.1426816,,,,,,,,,,,,,, -47800,2.4964707,1.1461287,,,,,,,,,,,,,, -47900,3.7416272,1.0958766,,,,,,,,,,,,,, -48000,,,0.18652456,0.0675843393884689,0.41070828,0.1183757011691784,5348.0,0.21605648,0.0694859139195255,2472.0,38341.64616441727,42196.50052070618,38341.64616441727,3851.210965633392,1.4912919998168943,0.0 -48000,,,,,,,,,,,38341.64616441727,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 2a1b35ea2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -301.979544878006,0.0,24.272481203079224,1,0,24.272481203079224,0.3947486877441406,0.7956756353378296,0.0263659558726878,43793,326.25207114219666,0.3887696266174316,0.7993155121803284,0.0215552107974815,0.3926452994346618,0.7974664568901062,0.0248266565510279,43793 -416.35071659088135,0.0298035144805908,264.3512933254242,763,0,264.3512933254242,0.983142077922821,0.0812012106180191,0.0383700018853028,43793,680.7509214878082,0.9866331219673156,0.0705618038773536,0.0359794845055697,0.9841179251670836,0.078437902033329,0.0371951699350063,43793 -536.3063020706177,0.0594303607940673,504.42682576179504,1516,0,504.42682576179504,0.9836373925209044,0.0605317205190658,0.0888645896201985,43793,1040.8308537006378,0.9874165058135986,0.0481969825923442,0.0838339449449183,0.9846513271331788,0.0573393926024436,0.0884273181317277,43793 -659.5478177070618,0.0860383510589599,744.610445022583,2272,0,744.610445022583,0.9840400815010072,0.0568913184106349,0.1302070232937932,43793,1404.3017330169678,0.9878497123718262,0.0441702157258987,0.1317534680029279,0.9849797487258912,0.0539003200829029,0.1300213241066441,43793 -783.8252048492432,0.1164727210998535,984.7043483257294,3024,0,984.7043483257294,0.9843626618385316,0.0538056381046772,0.1520338482806402,43793,1768.7225918769836,0.9880422353744508,0.042276244610548,0.1592302566590645,0.9852914810180664,0.0512494929134845,0.1531901395730448,43793 -905.8047299385072,0.1440849304199218,1224.8253166675568,3778,0,1224.8253166675568,0.9844376444816588,0.0531672686338424,0.1731884919959402,43793,2130.8698279857635,0.9882649779319764,0.0405199825763702,0.1913616464709115,0.9853755235671996,0.0502678081393241,0.1748420624477223,43793 -1028.835110425949,0.1708650588989257,1464.9182755947113,4538,0,1464.9182755947113,0.984628438949585,0.052729532122612,0.1795988521353688,43793,2494.039370775223,0.9881364107131958,0.0405524857342243,0.1995915133924315,0.985506236553192,0.0498411767184734,0.1833361099801978,43793 -1154.775056600571,0.2018733024597168,1704.9223086833954,5291,0,1704.9223086833954,0.9847902059555054,0.0514041110873222,0.1930594567232489,43793,2860.033614873886,0.9885979890823364,0.0390627458691597,0.235640214546389,0.9856986403465272,0.0487240441143512,0.1980871299196871,43793 -1280.962295293808,0.2301163673400879,1945.1591565608976,6048,0,1945.1591565608976,0.9850159883499146,0.0504659637808799,0.1988551882391724,43793,3226.505700349808,0.9889920353889464,0.0375846289098262,0.2471297526207681,0.9858882427215576,0.0478384830057621,0.2030285123094846,43793 -1405.2155866622925,0.2566075325012207,2185.3532433509827,6799,0,2185.3532433509827,0.9850117564201356,0.0501020587980747,0.2055948206630922,43793,3590.998979330063,0.9890881180763244,0.0374409407377243,0.2541532404260944,0.9859243631362916,0.0475738011300563,0.2074654936605272,43793 -1530.9100363254547,0.2840509414672851,2425.605710029602,7555,0,2425.605710029602,0.9853213429450988,0.0491655617952346,0.2203662111267828,43793,3956.9928188323975,0.9891881942749025,0.0366190671920776,0.2631767259260585,0.9861549735069276,0.0466803461313247,0.2189902151347329,43793 -1656.8199598789215,0.3133604526519775,2665.6187121868134,8313,0,2665.6187121868134,0.985276699066162,0.0492173172533512,0.2197312395648356,43793,4322.964905500412,0.989005982875824,0.0370355211198329,0.2554642359352475,0.9861886501312256,0.04653275385499,0.2260755061735975,43793 -1784.504117488861,0.3417127132415771,2905.712122440338,9070,0,2905.712122440338,0.9853150248527528,0.0493118911981582,0.2232171338319447,43793,4690.790143251419,0.9891762137413024,0.0364898554980754,0.2828834821324972,0.9861512780189514,0.0466004610061645,0.2250200468792361,43793 -1910.97766661644,0.371028184890747,3145.721575975418,9820,0,3145.721575975418,0.9853967428207396,0.0487997084856033,0.2273547198715588,43793,5057.321557760239,0.9896302819252014,0.035380981862545,0.2862586124008465,0.9862738847732544,0.0461940877139568,0.2245747281650726,43793 -2036.571834087372,0.3991975784301758,3385.802992105484,10569,0,3385.802992105484,0.9855096340179444,0.0483597517013549,0.2335138669890933,43793,5423.044348478317,0.9896215200424194,0.0348889753222465,0.3107923658999827,0.9863603711128236,0.045642539858818,0.2417199449464199,43793 -2166.000794649124,0.427293062210083,3625.966535568237,11323,0,3625.966535568237,0.9857492446899414,0.0474965870380401,0.2476235995592323,43793,5792.6845734119415,0.989838182926178,0.03400369733572,0.3266752975698576,0.9865649342536926,0.0450877733528614,0.2468068704228591,43793 -2292.3009836673737,0.4576456546783447,3865.95241522789,12080,0,3865.95241522789,0.9857353568077089,0.0474820621311664,0.2390921175235748,43793,6159.020653247833,0.990098476409912,0.0331880785524845,0.3468790956136938,0.9866153001785278,0.0448572300374507,0.2468721325324762,43793 -2419.3099246025085,0.4859592914581299,4106.013333797455,12832,0,4106.013333797455,0.9856481552124025,0.0474556349217891,0.2417152882689169,43793,6526.13848400116,0.9901254177093506,0.0328053459525108,0.355174481887847,0.9865312576293944,0.0449366346001625,0.2488322315631422,43793 -2546.054483890533,0.5157315731048584,4346.096912384033,13581,0,4346.096912384033,0.9857589602470398,0.0478064157068729,0.242833364604821,43793,6893.016060590744,0.9902234077453612,0.0321284905076026,0.3709276512140884,0.9865828156471252,0.0451812818646431,0.2497697210787227,43793 -2677.8932163715363,0.547095775604248,4586.202342510223,14331,0,4586.202342510223,0.9858056902885436,0.0471140295267105,0.2448461793444904,43793,7265.011176109314,0.9904870986938475,0.031724065542221,0.3738916030379458,0.9866850972175598,0.0445470623672008,0.2559132804697909,43793 -2807.9602975845337,0.5814604759216309,4826.174679040909,15078,0,4826.174679040909,0.9858945608139038,0.0471870973706245,0.2494500413095305,43793,7635.104958295822,0.9903684258461,0.0319615788757801,0.3792209517081584,0.9867200255393982,0.0446940958499908,0.2523525355948909,43793 -2938.505881547928,0.6110355854034424,5066.263954401016,15823,0,5066.263954401016,0.9858419299125672,0.0476639531552791,0.2480464015958148,43793,8005.788725376129,0.9903414845466614,0.0320120230317115,0.3866011311652768,0.9866615533828736,0.0450022481381893,0.2581168652912072,43793 -3067.133838415146,0.6414587497711182,5306.334696531296,16570,0,5306.334696531296,0.9859615564346312,0.0469288416206836,0.2569955946510767,43793,8374.538804292679,0.9903900623321532,0.0318858809769153,0.367818958616825,0.9866932034492492,0.0446525178849697,0.2544019229495246,43793 -3195.857089042664,0.6710550785064697,5546.373893737793,17320,0,5546.373893737793,0.9859164953231812,0.0470666550099849,0.2520827505637201,43793,8743.351408958435,0.9903742671012878,0.0317201763391494,0.3792666162288067,0.9867115020751952,0.0447704903781414,0.2558755533093419,43793 -3326.569318294525,0.7022576332092285,5786.353320121765,18060,0,5786.353320121765,0.9859927296638488,0.0470446906983852,0.2545460441147208,43793,9114.093678951263,0.9906107783317566,0.0310329459607601,0.3850302235403376,0.9868434071540833,0.0444827862083911,0.2634650572068787,43793 -3459.07012462616,0.7338111400604248,6026.574734687805,18801,0,6026.574734687805,0.9858486652374268,0.0473712384700775,0.2486702052961499,43793,9486.867498636246,0.9907031059265136,0.0309583060443401,0.3988686456391712,0.9866709113121032,0.0448623113334178,0.2583840982955815,43793 -3590.8943939208984,0.7656044960021973,6266.726683616638,19550,0,6266.726683616638,0.9858562350273132,0.0470095984637737,0.2529303810550452,43793,9858.89572906494,0.9908729195594788,0.0300773531198501,0.4276070078582515,0.986758589744568,0.0443111844360828,0.2655009222441444,43793 -3723.218013286591,0.7975211143493652,6506.863411426544,20294,0,6506.863411426544,0.9859615564346312,0.0468484051525592,0.2587649927248948,43793,10231.407826423643,0.9909902811050416,0.029482202604413,0.4394125373132448,0.9867805242538452,0.0443419106304645,0.2630536781553423,43793 -3850.82346367836,0.8276448249816895,6747.060025215149,21040,0,6747.060025215149,0.985920250415802,0.0470907315611839,0.2505250804658075,43793,10599.26006937027,0.9909581542015076,0.0293449759483337,0.4319782718431922,0.9868271946907043,0.044256966561079,0.2620292413708662,43793 -3979.278872251511,0.8579602241516113,6987.061242580414,21788,0,6987.061242580414,0.9858971238136292,0.047172050923109,0.245060863473351,43793,10967.76712179184,0.9909757971763612,0.0297860577702522,0.411273935375378,0.986751675605774,0.0445716865360736,0.2534629003852076,43793 -4110.835282087326,0.8883786201477051,7227.019123077393,22534,0,7227.019123077393,0.9860268235206604,0.0470134206116199,0.259165562861682,43793,11339.331525325775,0.9908212423324584,0.0303160864859819,0.4167117446792816,0.9868998527526855,0.0442841872572898,0.2639624370741263,43793 -4240.48300409317,0.9200398921966552,7466.973075628281,23281,0,7466.973075628281,0.9858128428459167,0.0469983145594596,0.2511379732354883,43793,11708.985038518906,0.9908854365348816,0.03038932941854,0.4145162980485438,0.986748456954956,0.0443939976394176,0.2643333702414174,43793 -4370.933499574661,1.214691162109375,7706.785302639008,24031,0,7706.785302639008,0.9858853220939636,0.0472135208547115,0.2514192693563914,43793,12079.562668561935,0.9907976984977722,0.0301222447305917,0.4307058110209221,0.986678183078766,0.0445076636970043,0.2598207593992417,43793 -4498.967143058777,1.2464373111724854,7946.956842422485,24784,0,7946.956842422485,0.9858975410461426,0.0471698269248008,0.2590929594246151,43793,12447.819665908812,0.9909395575523376,0.0297166630625724,0.4196510958908334,0.9867833256721495,0.0445629172027111,0.2656587042217679,43793 -4624.81149148941,1.2782447338104248,8187.189737558365,25543,0,8187.189737558365,0.9859825968742372,0.0472048185765743,0.2633457291123027,43793,12813.948610544205,0.991019070148468,0.0294891316443681,0.4367420533835061,0.986832082271576,0.0445814169943332,0.2679887542443359,43793 -4750.859451770783,1.3099524974822998,8427.414668560028,26303,0,8427.414668560028,0.9859455227851868,0.0469372197985649,0.2593809136492444,43793,13180.273483514786,0.9911377429962158,0.0290332436561584,0.4396199351537836,0.9868190884590148,0.0445399396121501,0.2687223142231175,43793 -4880.380812883377,1.3404219150543213,8667.366248846054,27058,0,8667.366248846054,0.9858773350715636,0.0470265299081802,0.2526834643991631,43793,13549.797183036804,0.9913957715034484,0.0284608248621225,0.4587331702582208,0.9867423176765442,0.0443950109183788,0.2620780825173272,43793 -5007.864626646042,1.3766028881072998,8907.356931447983,27812,0,8907.356931447983,0.9860082864761353,0.0469130389392375,0.256955366585077,43793,13917.328118801115,0.9915438890457152,0.0275515113025903,0.478820071746103,0.9868454337120056,0.0440764725208282,0.268497485961063,43793 -5137.846126794815,1.4091360569000244,9147.537345647812,28566,0,9147.537345647812,0.9860162734985352,0.0467939786612987,0.2587861733148328,43793,14287.542890071869,0.9915900230407716,0.0274300016462802,0.4838465650251311,0.9868852496147156,0.0441147089004516,0.2774226575024014,43793 -5269.571957588196,1.4414465427398682,9387.8157954216,29316,0,9387.8157954216,0.985925316810608,0.0473748035728931,0.2550472083789822,43793,14659.599410057068,0.9914113283157348,0.028087593615055,0.4503752334796423,0.9868101477622986,0.044549535959959,0.2645352341758079,43793 -5398.746550559998,1.4734342098236084,9627.81943321228,30073,0,9627.81943321228,0.9858600497245787,0.0476008988916873,0.2547086969015181,43793,15028.82996916771,0.9912699460983276,0.0286596063524484,0.44959692861203,0.986766278743744,0.0447003468871116,0.2619020218166186,43793 -5527.817716121674,1.5052387714385986,9867.785235404968,30829,0,9867.785235404968,0.985971212387085,0.0476885735988616,0.2549692084955743,43793,15397.919102430344,0.9910553097724916,0.0290402844548225,0.4345576765374209,0.9868341088294984,0.044944878667593,0.2651539653275671,43793 -5655.200493574143,1.5384142398834229,10107.885499715803,31588,0,10107.885499715803,0.985917329788208,0.0471377708017826,0.2580299895531414,43793,15765.455152750015,0.9913302063941956,0.0285799913108348,0.4591087979868709,0.9867050051689148,0.044590026140213,0.2681575482452986,43793 -5783.775832414627,1.572129487991333,10347.881727457048,32346,0,10347.881727457048,0.9860752820968628,0.0467753335833549,0.26618305739415,43793,16134.079976081848,0.9914727210998536,0.0278371162712574,0.4717979107615732,0.9869274497032166,0.043929535895586,0.274115570411685,43793 -5910.768916606903,1.6054506301879885,10587.90379667282,33102,0,10587.90379667282,0.985958993434906,0.0473562479019165,0.2575073350230395,43793,16501.14848446846,0.9913511276245116,0.0281847678124904,0.4569296601938585,0.9868227243423462,0.0446917377412319,0.2654920394579966,43793 -6036.926429271698,1.6387813091278076,10828.046881198885,33860,0,10828.046881198885,0.9860659837722778,0.0471661426126956,0.2590065779058145,43793,16867.5029835701,0.9915713667869568,0.0275131501257419,0.4805088503117457,0.986916482448578,0.0443529598414897,0.2733849468277443,43793 -6162.10580778122,1.6727240085601809,11068.133890151978,34617,0,11068.133890151978,0.9859358668327332,0.0476084761321544,0.2522764802677962,43793,17232.824068784714,0.9917721152305604,0.0267431810498237,0.4938607820037304,0.9869108200073242,0.0445934534072876,0.2638629801469858,43793 -6287.737751483917,1.7055885791778564,11308.152801513672,35373,0,11308.152801513672,0.9859430193901062,0.0474424511194229,0.2559087693871861,43793,17598.52809739113,0.9918096661567688,0.0265304185450077,0.5029054677780674,0.9868081212043762,0.0446687713265419,0.2637214181806488,43793 -6413.520535945892,1.7420799732208252,11548.278539657593,36130,0,11548.278539657593,0.9859346151351928,0.0476196333765983,0.254952374996752,43793,17964.49371266365,0.9921783208847046,0.0254643373191356,0.5405584099723635,0.9869144558906556,0.0447491630911827,0.2696709121978386,43793 -6541.208086490631,1.776219606399536,11788.357815027235,36881,0,11788.357815027235,0.985889971256256,0.047700397670269,0.2532243725665771,43793,18332.315428972244,0.9921891689300536,0.0256269536912441,0.5054668073603061,0.9867618083953856,0.0448433235287666,0.2666542209544567,43793 -6669.069313287735,1.80993127822876,12028.320405721664,37631,0,12028.320405721664,0.9858503341674804,0.0473716296255588,0.2548070281969129,43793,18700.193468809128,0.9919636249542236,0.0262818839401006,0.5000979311796215,0.9867650866508484,0.04451534897089,0.269276343714745,43793 -6792.561163425446,1.8454234600067136,12268.40572667122,38379,0,12268.40572667122,0.9860424399375916,0.0478706955909729,0.2579735518736352,43793,19063.826691389084,0.9917044639587402,0.0267996452748775,0.4938836258819183,0.986860454082489,0.0450091734528541,0.2717148167414828,43793 -6918.80362701416,1.879314661026001,12508.509371519089,39127,0,12508.509371519089,0.986038625240326,0.0473871938884258,0.2648304309454832,43793,19430.22696685791,0.9917911887168884,0.0266721807420253,0.4895440123990278,0.9869790077209472,0.044631291180849,0.2678226921038907,43793 -7039.660442113876,1.91489839553833,12748.566566467283,39885,0,12748.566566467283,0.9861629009246826,0.0473428405821323,0.2622531743337971,43793,19791.197275161743,0.9919468760490416,0.0260848235338926,0.5204706545550591,0.987012267112732,0.0442746318876743,0.2786177764963892,43793 -7168.117209196091,1.949426889419556,12988.536489725111,40639,0,12988.536489725111,0.985961139202118,0.0480013452470302,0.2567544314291233,43793,20159.67905855179,0.991911232471466,0.0260745268315076,0.5069046204867362,0.986780881881714,0.0451743900775909,0.2680662528745569,43793 -7291.893787145615,1.9835777282714844,13228.646829366684,41398,0,13228.646829366684,0.985961139202118,0.0479624792933464,0.2656677761975639,43793,20523.62076807022,0.9920253753662108,0.0256763193756341,0.508227948274832,0.98688805103302,0.0451797060668468,0.2732676072132827,43793 -7414.11568903923,2.017220973968506,13468.731812000276,42155,0,13468.731812000276,0.98598974943161,0.0482239499688148,0.2574946687375729,43793,20885.98131942749,0.992261290550232,0.0248192362487316,0.5357011778114732,0.986861288547516,0.0453098826110363,0.2641243823559999,43793 -7538.565611362457,2.051811456680298,13708.764877796171,42908,0,13708.764877796171,0.9860158562660216,0.0480639971792697,0.2653323038381081,43793,21250.51918125153,0.9926353096961976,0.0238346364349126,0.5570144195376461,0.986894965171814,0.045154895633459,0.2684881619213733,43793 -7664.496058940887,2.0864064693450928,13948.78694844246,43665,0,13948.78694844246,0.98598051071167,0.0484649688005447,0.2619368512501648,43793,21616.52692937851,0.9925816655158995,0.0239026378840208,0.569847598106749,0.9868986010551452,0.0452833510935306,0.2717307550178367,43793 -7793.158420085907,2.121481418609619,14188.775550365448,44417,0,14188.775550365448,0.9859729409217834,0.0484565906226635,0.2569524214786334,43793,21985.23408293724,0.9929485321044922,0.022788044065237,0.5749294484192354,0.9868471026420592,0.0452699847519397,0.2726284919209182,43793 -7917.267426967621,2.1592814922332764,14428.960973501204,45164,0,14428.960973501204,0.985913097858429,0.0485909059643745,0.2562790789794792,43793,22349.58697938919,0.9925280213356018,0.0241652112454175,0.5439482996360956,0.9867951273918152,0.0455055013298988,0.2725618970001102,43793 -8042.86420583725,2.19640588760376,14669.197919368744,45910,0,14669.197919368744,0.9858949780464172,0.0484018251299858,0.2579783262777751,43793,22715.47890305519,0.9924951195716858,0.024184413254261,0.5322754362561134,0.9868247509002686,0.04539680108428,0.2694207075858307,43793 -8168.112809181213,2.2335081100463867,14909.381882667542,46658,0,14909.381882667542,0.9860398769378662,0.0489005632698535,0.2569565723669573,43793,23080.968989133835,0.9923391938209534,0.0245573688298463,0.5505022282160861,0.9869205355644226,0.0457893349230289,0.2665850458991095,43793 -8294.17154455185,2.270673513412476,15149.46195435524,47410,0,15149.46195435524,0.986047863960266,0.0485950969159603,0.2584206174290291,43793,23447.16524910927,0.9923987984657288,0.0243546906858682,0.547945267187708,0.9869290590286256,0.0455531850457191,0.2730923089403295,43793 -8420.551263570786,2.3075530529022217,15389.69013428688,48164,0,15389.69013428688,0.9860175848007202,0.0482630915939807,0.2613342734660124,43793,23813.83040785789,0.9926859736442566,0.0235126633197069,0.5531412566917546,0.9868718385696412,0.0453873015940189,0.2729452931736282,43793 -8543.148166894913,2.350667953491211,15629.627718687056,48923,0,15629.627718687056,0.9859282970428468,0.0485388562083244,0.2617990529820947,43793,24176.42839407921,0.9927862286567688,0.0231304336339235,0.5690608496377537,0.9868023991584778,0.0456211678683757,0.2696094281540254,43793 -8668.656472444534,2.388645648956299,15869.585559368134,49677,0,15869.585559368134,0.9859762787818908,0.0492651611566543,0.2634007317030291,43793,24541.95281338692,0.9929873943328856,0.0223092995584011,0.5790065918357813,0.9868316650390624,0.0461039021611213,0.2703625890701598,43793 -8793.248821020126,2.4257941246032715,16109.806084394457,50435,0,16109.806084394457,0.9858596324920654,0.0491060465574264,0.2610190774407657,43793,24906.823295116425,0.993471086025238,0.0211111195385456,0.6146497697624851,0.986735463142395,0.0460940673947334,0.2744567062110209,43793 -8918.434713840485,2.461792469024658,16349.852685928345,51187,0,16349.852685928345,0.985773265361786,0.0498418360948562,0.2565580586568688,43793,25272.11195421219,0.9937585592269896,0.0203729905188083,0.6229511751791663,0.9867293238639832,0.0467021130025386,0.2689869825711359,43793 -9043.52482008934,2.4987967014312744,16589.82354283333,51940,0,16589.82354283333,0.9859463572502136,0.0497677139937877,0.2659970159599488,43793,25637.23035120964,0.9935181140899658,0.0207281652837991,0.6268248896791947,0.9868150353431702,0.0465990602970123,0.2722384098665055,43793 -9170.178788661957,2.5351483821868896,16829.831008911133,52690,0,16829.831008911133,0.9856839776039124,0.0501977689564228,0.2549874512181881,43793,26003.949209213257,0.9936492443084716,0.0206983480602502,0.6196390393205258,0.9866599440574646,0.0467642769217491,0.2696261098279769,43793 -9294.73175382614,2.572896957397461,17069.828882217407,53446,0,17069.828882217407,0.9859278798103333,0.0503723919391632,0.2595123515077037,43793,26368.558528900143,0.9932865500450134,0.0214157178997993,0.6066971485592483,0.986796736717224,0.0472288504242897,0.2679517868571796,43793 -9417.017122030258,2.609468460083008,17310.08274435997,54202,0,17310.08274435997,0.9859842658042908,0.0508865006268024,0.2572087545398565,43793,26731.15495181084,0.9932235479354858,0.0213823653757572,0.6007852406767729,0.9868023991584778,0.047565758228302,0.2681301990027719,43793 -9540.67410349846,2.645934820175171,17550.31499028206,54959,0,17550.31499028206,0.9857816696166992,0.0509912185370922,0.2596567648926789,43793,27095.101389169693,0.9933767318725586,0.0210478398948907,0.6104609656727764,0.9867898225784302,0.0478362925350666,0.2712276426357609,43793 -9662.156247615814,2.6831068992614746,17790.44306921959,55710,0,17790.44306921959,0.9857547283172609,0.0514043755829334,0.254248360382065,43793,27456.76895523072,0.993337333202362,0.0210698246955871,0.6063672334387846,0.9867419600486756,0.0481171533465385,0.2701161407026355,43793 -9783.057143211365,2.721312522888184,18030.64174723625,56463,0,18030.64174723625,0.9857159852981568,0.0516699701547622,0.2506636334811378,43793,27817.927065372467,0.9936086535453796,0.0201977919787168,0.6252101015049698,0.9866887331008912,0.0482036210596561,0.2638393455795906,43793 -9904.612380743027,2.7617759704589844,18270.87906050682,57216,0,18270.87906050682,0.9858503341674805,0.05154655501246452,0.26062947479450477,43793,28179.78066945076,0.9939561486244202,0.01918221265077591,0.6443242592758944,0.9866737127304077,0.04826799035072327,0.266641069144078,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index 457160747..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,658 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.1866565,0.7992708,,,,,,,,,,,,,,,,, -1,,,0.3887696266174316,0.7993155121803284,0.0215552107974815,0.3926452994346618,0.7974664568901062,0.0248266565510279,43793.0,0.3947486877441406,0.7956756353378296,0.0263659558726878,43793.0,24.272481203079224,326.25207114219666,24.272481203079224,301.979544878006,0.0,0.0 -100,0.7065518,0.50436944,,,,,,,,,,,,,,,,, -200,0.41144764,0.36354873,,,,,,,,,,,,,,,,, -300,0.30055618,0.2604981,,,,,,,,,,,,,,,,, -400,0.20152383,0.1751076,,,,,,,,,,,,,,,,, -500,0.12493377,0.1222978,,,,,,,,,,,,,,,,, -600,0.08027274,0.08971704,,,,,,,,,,,,,,,,, -700,0.07393093,0.084785126,,,,,,,,,,,,,,,,, -763,,,0.9866331219673156,0.0705618038773536,0.0359794845055697,0.9841179251670836,0.078437902033329,0.0371951699350063,43793.0,0.983142077922821,0.0812012106180191,0.0383700018853028,43793.0,264.3512933254242,680.7509214878082,264.3512933254242,416.35071659088135,0.0298035144805908,0.0 -800,0.11751059,0.06160292,,,,,,,,,,,,,,,,, -900,0.035384748,0.062139828,,,,,,,,,,,,,,,,, -1000,0.08266612,0.061075773,,,,,,,,,,,,,,,,, -1100,0.09521577,0.05519803,,,,,,,,,,,,,,,,, -1200,0.108206734,0.04897866,,,,,,,,,,,,,,,,, -1300,0.13927081,0.04677374,,,,,,,,,,,,,,,,, -1400,0.13944149,0.04518235,,,,,,,,,,,,,,,,, -1500,0.11165842,0.054504454,,,,,,,,,,,,,,,,, -1516,,,0.9874165058135986,0.0481969825923442,0.0838339449449183,0.9846513271331788,0.0573393926024436,0.0884273181317277,43793.0,0.9836373925209044,0.0605317205190658,0.0888645896201985,43793.0,504.42682576179504,1040.8308537006378,504.42682576179504,536.3063020706177,0.0594303607940673,0.0 -1600,0.14471407,0.052310333,,,,,,,,,,,,,,,,, -1700,0.19064789,0.04971242,,,,,,,,,,,,,,,,, -1800,0.17152084,0.04815773,,,,,,,,,,,,,,,,, -1900,0.077987924,0.045729335,,,,,,,,,,,,,,,,, -2000,0.17210911,0.050392468,,,,,,,,,,,,,,,,, -2100,0.10318856,0.044199087,,,,,,,,,,,,,,,,, -2200,0.18620503,0.045299646,,,,,,,,,,,,,,,,, -2272,,,0.9878497123718262,0.0441702157258987,0.1317534680029279,0.9849797487258912,0.0539003200829029,0.1300213241066441,43793.0,0.9840400815010072,0.0568913184106349,0.1302070232937932,43793.0,744.610445022583,1404.3017330169678,744.610445022583,659.5478177070618,0.0860383510589599,0.0 -2300,0.1542212,0.047942735,,,,,,,,,,,,,,,,, -2400,0.17344005,0.04739619,,,,,,,,,,,,,,,,, -2500,0.22667956,0.045106184,,,,,,,,,,,,,,,,, -2600,0.07314138,0.04713095,,,,,,,,,,,,,,,,, -2700,0.11542914,0.047433197,,,,,,,,,,,,,,,,, -2800,0.090182535,0.04331975,,,,,,,,,,,,,,,,, -2900,0.07331361,0.045388747,,,,,,,,,,,,,,,,, -3000,0.10026912,0.04630819,,,,,,,,,,,,,,,,, -3024,,,0.9880422353744508,0.042276244610548,0.1592302566590645,0.9852914810180664,0.0512494929134845,0.1531901395730448,43793.0,0.9843626618385316,0.0538056381046772,0.1520338482806402,43793.0,984.7043483257294,1768.7225918769836,984.7043483257294,783.8252048492432,0.1164727210998535,0.0 -3100,0.06467709,0.041971367,,,,,,,,,,,,,,,,, -3200,0.05693326,0.043666024,,,,,,,,,,,,,,,,, -3300,0.12190675,0.039540607,,,,,,,,,,,,,,,,, -3400,0.08756083,0.04175049,,,,,,,,,,,,,,,,, -3500,0.050066486,0.041732088,,,,,,,,,,,,,,,,, -3600,0.06480187,0.042604044,,,,,,,,,,,,,,,,, -3700,0.08216144,0.040774412,,,,,,,,,,,,,,,,, -3778,,,0.9882649779319764,0.0405199825763702,0.1913616464709115,0.9853755235671996,0.0502678081393241,0.1748420624477223,43793.0,0.9844376444816588,0.0531672686338424,0.1731884919959402,43793.0,1224.8253166675568,2130.8698279857635,1224.8253166675568,905.8047299385072,0.1440849304199218,0.0 -3800,0.10389339,0.045280002,,,,,,,,,,,,,,,,, -3900,0.04780066,0.04088153,,,,,,,,,,,,,,,,, -4000,0.061417166,0.04051753,,,,,,,,,,,,,,,,, -4100,0.0902683,0.046344183,,,,,,,,,,,,,,,,, -4200,0.07847788,0.04531592,,,,,,,,,,,,,,,,, -4300,0.054945122,0.03881543,,,,,,,,,,,,,,,,, -4400,0.077988684,0.03664126,,,,,,,,,,,,,,,,, -4500,0.034632146,0.041486215,,,,,,,,,,,,,,,,, -4538,,,0.9881364107131958,0.0405524857342243,0.1995915133924315,0.985506236553192,0.0498411767184734,0.1833361099801978,43793.0,0.984628438949585,0.052729532122612,0.1795988521353688,43793.0,1464.9182755947113,2494.039370775223,1464.9182755947113,1028.835110425949,0.1708650588989257,0.0 -4600,0.107409135,0.04289005,,,,,,,,,,,,,,,,, -4700,0.07356065,0.04207149,,,,,,,,,,,,,,,,, -4800,0.03563027,0.04026913,,,,,,,,,,,,,,,,, -4900,0.03404192,0.044142112,,,,,,,,,,,,,,,,, -5000,0.06880524,0.044252235,,,,,,,,,,,,,,,,, -5100,0.035076007,0.038851276,,,,,,,,,,,,,,,,, -5200,0.041872058,0.04003624,,,,,,,,,,,,,,,,, -5291,,,0.9885979890823364,0.0390627458691597,0.235640214546389,0.9856986403465272,0.0487240441143512,0.1980871299196871,43793.0,0.9847902059555054,0.0514041110873222,0.1930594567232489,43793.0,1704.9223086833954,2860.033614873886,1704.9223086833954,1154.775056600571,0.2018733024597168,0.0 -5300,0.055668127,0.043307245,,,,,,,,,,,,,,,,, -5400,0.050443307,0.03722988,,,,,,,,,,,,,,,,, -5500,0.0369667,0.045223027,,,,,,,,,,,,,,,,, -5600,0.047926527,0.03940282,,,,,,,,,,,,,,,,, -5700,0.0428191,0.0406016,,,,,,,,,,,,,,,,, -5800,0.0409922,0.039085217,,,,,,,,,,,,,,,,, -5900,0.11071988,0.03893374,,,,,,,,,,,,,,,,, -6000,0.035486978,0.041759986,,,,,,,,,,,,,,,,, -6048,,,0.9889920353889464,0.0375846289098262,0.2471297526207681,0.9858882427215576,0.0478384830057621,0.2030285123094846,43793.0,0.9850159883499146,0.0504659637808799,0.1988551882391724,43793.0,1945.1591565608976,3226.505700349808,1945.1591565608976,1280.962295293808,0.2301163673400879,0.0 -6100,0.037913598,0.04160702,,,,,,,,,,,,,,,,, -6200,0.04132923,0.038250867,,,,,,,,,,,,,,,,, -6300,0.035999306,0.039149392,,,,,,,,,,,,,,,,, -6400,0.031308565,0.04535444,,,,,,,,,,,,,,,,, -6500,0.032439735,0.041577313,,,,,,,,,,,,,,,,, -6600,0.036024306,0.038859885,,,,,,,,,,,,,,,,, -6700,0.02331911,0.039672457,,,,,,,,,,,,,,,,, -6799,,,0.9890881180763244,0.0374409407377243,0.2541532404260944,0.9859243631362916,0.0475738011300563,0.2074654936605272,43793.0,0.9850117564201356,0.0501020587980747,0.2055948206630922,43793.0,2185.3532433509827,3590.998979330063,2185.3532433509827,1405.2155866622925,0.2566075325012207,0.0 -6800,0.03033978,0.038202286,,,,,,,,,,,,,,,,, -6900,0.043846373,0.041878138,,,,,,,,,,,,,,,,, -7000,0.03363675,0.041541472,,,,,,,,,,,,,,,,, -7100,0.02973714,0.0379416,,,,,,,,,,,,,,,,, -7200,0.039425105,0.0380211,,,,,,,,,,,,,,,,, -7300,0.045745898,0.042528,,,,,,,,,,,,,,,,, -7400,0.030906606,0.04099234,,,,,,,,,,,,,,,,, -7500,0.02914314,0.04042081,,,,,,,,,,,,,,,,, -7555,,,0.9891881942749025,0.0366190671920776,0.2631767259260585,0.9861549735069276,0.0466803461313247,0.2189902151347329,43793.0,0.9853213429450988,0.0491655617952346,0.2203662111267828,43793.0,2425.605710029602,3956.9928188323975,2425.605710029602,1530.9100363254547,0.2840509414672851,0.0 -7600,0.039555725,0.04028752,,,,,,,,,,,,,,,,, -7700,0.0251068,0.039336804,,,,,,,,,,,,,,,,, -7800,0.0270616,0.03712251,,,,,,,,,,,,,,,,, -7900,0.02989884,0.039549686,,,,,,,,,,,,,,,,, -8000,0.031665877,0.040527765,,,,,,,,,,,,,,,,, -8100,0.034261774,0.04430368,,,,,,,,,,,,,,,,, -8200,0.036407378,0.037024193,,,,,,,,,,,,,,,,, -8300,0.021569734,0.037204046,,,,,,,,,,,,,,,,, -8313,,,0.989005982875824,0.0370355211198329,0.2554642359352475,0.9861886501312256,0.04653275385499,0.2260755061735975,43793.0,0.985276699066162,0.0492173172533512,0.2197312395648356,43793.0,2665.6187121868134,4322.964905500412,2665.6187121868134,1656.8199598789215,0.3133604526519775,0.0 -8400,0.033980988,0.04002229,,,,,,,,,,,,,,,,, -8500,0.0355665,0.039555013,,,,,,,,,,,,,,,,, -8600,0.03513325,0.039502785,,,,,,,,,,,,,,,,, -8700,0.028790966,0.040686406,,,,,,,,,,,,,,,,, -8800,0.02709945,0.03975329,,,,,,,,,,,,,,,,, -8900,0.030476779,0.04000039,,,,,,,,,,,,,,,,, -9000,0.028522441,0.043404143,,,,,,,,,,,,,,,,, -9070,,,0.9891762137413024,0.0364898554980754,0.2828834821324972,0.9861512780189514,0.0466004610061645,0.2250200468792361,43793.0,0.9853150248527528,0.0493118911981582,0.2232171338319447,43793.0,2905.712122440338,4690.790143251419,2905.712122440338,1784.504117488861,0.3417127132415771,0.0 -9100,0.026218774,0.04098729,,,,,,,,,,,,,,,,, -9200,0.026643902,0.03761563,,,,,,,,,,,,,,,,, -9300,0.030075697,0.042415645,,,,,,,,,,,,,,,,, -9400,0.037834436,0.04181869,,,,,,,,,,,,,,,,, -9500,0.021896688,0.036083557,,,,,,,,,,,,,,,,, -9600,0.028602367,0.04019165,,,,,,,,,,,,,,,,, -9700,0.028998202,0.041782968,,,,,,,,,,,,,,,,, -9800,0.027667558,0.036142044,,,,,,,,,,,,,,,,, -9820,,,0.9896302819252014,0.035380981862545,0.2862586124008465,0.9862738847732544,0.0461940877139568,0.2245747281650726,43793.0,0.9853967428207396,0.0487997084856033,0.2273547198715588,43793.0,3145.721575975418,5057.321557760239,3145.721575975418,1910.97766661644,0.371028184890747,0.0 -9900,0.025644585,0.03861752,,,,,,,,,,,,,,,,, -10000,0.020946562,0.03772664,,,,,,,,,,,,,,,,, -10100,0.020458164,0.03650102,,,,,,,,,,,,,,,,, -10200,0.025993804,0.039587554,,,,,,,,,,,,,,,,, -10300,0.022110578,0.035853222,,,,,,,,,,,,,,,,, -10400,0.02202253,0.0392696,,,,,,,,,,,,,,,,, -10500,0.024250206,0.036661945,,,,,,,,,,,,,,,,, -10569,,,0.9896215200424194,0.0348889753222465,0.3107923658999827,0.9863603711128236,0.045642539858818,0.2417199449464199,43793.0,0.9855096340179444,0.0483597517013549,0.2335138669890933,43793.0,3385.802992105484,5423.044348478317,3385.802992105484,2036.571834087372,0.3991975784301758,0.0 -10600,0.025585609,0.036317065,,,,,,,,,,,,,,,,, -10700,0.045643125,0.035796534,,,,,,,,,,,,,,,,, -10800,0.028183054,0.038288053,,,,,,,,,,,,,,,,, -10900,0.024030898,0.03820738,,,,,,,,,,,,,,,,, -11000,0.039464604,0.038838122,,,,,,,,,,,,,,,,, -11100,0.03528491,0.037656303,,,,,,,,,,,,,,,,, -11200,0.032598987,0.034800306,,,,,,,,,,,,,,,,, -11300,0.02702343,0.03844989,,,,,,,,,,,,,,,,, -11323,,,0.989838182926178,0.03400369733572,0.3266752975698576,0.9865649342536926,0.0450877733528614,0.2468068704228591,43793.0,0.9857492446899414,0.0474965870380401,0.2476235995592323,43793.0,3625.966535568237,5792.6845734119415,3625.966535568237,2166.000794649124,0.427293062210083,0.0 -11400,0.024788283,0.037898075,,,,,,,,,,,,,,,,, -11500,0.040271904,0.038879506,,,,,,,,,,,,,,,,, -11600,0.03177446,0.037941273,,,,,,,,,,,,,,,,, -11700,0.03746754,0.038640875,,,,,,,,,,,,,,,,, -11800,0.030740462,0.03448375,,,,,,,,,,,,,,,,, -11900,0.031848,0.038762197,,,,,,,,,,,,,,,,, -12000,0.033737916,0.037307795,,,,,,,,,,,,,,,,, -12080,,,0.990098476409912,0.0331880785524845,0.3468790956136938,0.9866153001785278,0.0448572300374507,0.2468721325324762,43793.0,0.9857353568077089,0.0474820621311664,0.2390921175235748,43793.0,3865.95241522789,6159.020653247833,3865.95241522789,2292.3009836673737,0.4576456546783447,0.0 -12100,0.04761622,0.03646014,,,,,,,,,,,,,,,,, -12200,0.03255898,0.03693257,,,,,,,,,,,,,,,,, -12300,0.05304555,0.03783487,,,,,,,,,,,,,,,,, -12400,0.027934803,0.035723656,,,,,,,,,,,,,,,,, -12500,0.045308836,0.037334472,,,,,,,,,,,,,,,,, -12600,0.04183007,0.035248425,,,,,,,,,,,,,,,,, -12700,0.03995151,0.039924435,,,,,,,,,,,,,,,,, -12800,0.031893067,0.03616774,,,,,,,,,,,,,,,,, -12832,,,0.9901254177093506,0.0328053459525108,0.355174481887847,0.9865312576293944,0.0449366346001625,0.2488322315631422,43793.0,0.9856481552124025,0.0474556349217891,0.2417152882689169,43793.0,4106.013333797455,6526.13848400116,4106.013333797455,2419.3099246025085,0.4859592914581299,0.0 -12900,0.034718852,0.035224773,,,,,,,,,,,,,,,,, -13000,0.03136088,0.03644007,,,,,,,,,,,,,,,,, -13100,0.046816558,0.033295136,,,,,,,,,,,,,,,,, -13200,0.051245585,0.03473206,,,,,,,,,,,,,,,,, -13300,0.042266224,0.038246363,,,,,,,,,,,,,,,,, -13400,0.02919406,0.035934336,,,,,,,,,,,,,,,,, -13500,0.04772646,0.038125668,,,,,,,,,,,,,,,,, -13581,,,0.9902234077453612,0.0321284905076026,0.3709276512140884,0.9865828156471252,0.0451812818646431,0.2497697210787227,43793.0,0.9857589602470398,0.0478064157068729,0.242833364604821,43793.0,4346.096912384033,6893.016060590744,4346.096912384033,2546.054483890533,0.5157315731048584,0.0 -13600,0.04317512,0.039373998,,,,,,,,,,,,,,,,, -13700,0.066223286,0.03330094,,,,,,,,,,,,,,,,, -13800,0.036428425,0.034377415,,,,,,,,,,,,,,,,, -13900,0.054297063,0.040771335,,,,,,,,,,,,,,,,, -14000,0.04925575,0.035273332,,,,,,,,,,,,,,,,, -14100,0.04019865,0.036246426,,,,,,,,,,,,,,,,, -14200,0.03903427,0.041649085,,,,,,,,,,,,,,,,, -14300,0.06797067,0.036065105,,,,,,,,,,,,,,,,, -14331,,,0.9904870986938475,0.031724065542221,0.3738916030379458,0.9866850972175598,0.0445470623672008,0.2559132804697909,43793.0,0.9858056902885436,0.0471140295267105,0.2448461793444904,43793.0,4586.202342510223,7265.011176109314,4586.202342510223,2677.8932163715363,0.547095775604248,0.0 -14400,0.04543383,0.040351234,,,,,,,,,,,,,,,,, -14500,0.045809463,0.03514759,,,,,,,,,,,,,,,,, -14600,0.081976295,0.03684784,,,,,,,,,,,,,,,,, -14700,0.038982693,0.03683134,,,,,,,,,,,,,,,,, -14800,0.04139058,0.03194149,,,,,,,,,,,,,,,,, -14900,0.0728403,0.038174067,,,,,,,,,,,,,,,,, -15000,0.06498065,0.037767984,,,,,,,,,,,,,,,,, -15078,,,0.9903684258461,0.0319615788757801,0.3792209517081584,0.9867200255393982,0.0446940958499908,0.2523525355948909,43793.0,0.9858945608139038,0.0471870973706245,0.2494500413095305,43793.0,4826.174679040909,7635.104958295822,4826.174679040909,2807.9602975845337,0.5814604759216309,0.0 -15100,0.05279883,0.033336546,,,,,,,,,,,,,,,,, -15200,0.07373848,0.03729742,,,,,,,,,,,,,,,,, -15300,0.04853734,0.03993529,,,,,,,,,,,,,,,,, -15400,0.05562851,0.034701373,,,,,,,,,,,,,,,,, -15500,0.061104268,0.03673313,,,,,,,,,,,,,,,,, -15600,0.039010536,0.037181396,,,,,,,,,,,,,,,,, -15700,0.06308284,0.03845905,,,,,,,,,,,,,,,,, -15800,0.047055326,0.035579894,,,,,,,,,,,,,,,,, -15823,,,0.9903414845466614,0.0320120230317115,0.3866011311652768,0.9866615533828736,0.0450022481381893,0.2581168652912072,43793.0,0.9858419299125672,0.0476639531552791,0.2480464015958148,43793.0,5066.263954401016,8005.788725376129,5066.263954401016,2938.505881547928,0.6110355854034424,0.0 -15900,0.046954837,0.03414049,,,,,,,,,,,,,,,,, -16000,0.0432432,0.031496007,,,,,,,,,,,,,,,,, -16100,0.05658664,0.03842983,,,,,,,,,,,,,,,,, -16200,0.04807265,0.033867348,,,,,,,,,,,,,,,,, -16300,0.068888836,0.033223785,,,,,,,,,,,,,,,,, -16400,0.04496865,0.03789689,,,,,,,,,,,,,,,,, -16500,0.044612404,0.031615905,,,,,,,,,,,,,,,,, -16570,,,0.9903900623321532,0.0318858809769153,0.367818958616825,0.9866932034492492,0.0446525178849697,0.2544019229495246,43793.0,0.9859615564346312,0.0469288416206836,0.2569955946510767,43793.0,5306.334696531296,8374.538804292679,5306.334696531296,3067.133838415146,0.6414587497711182,0.0 -16600,0.06421914,0.034598306,,,,,,,,,,,,,,,,, -16700,0.06784644,0.037347067,,,,,,,,,,,,,,,,, -16800,0.039677005,0.034406174,,,,,,,,,,,,,,,,, -16900,0.07198808,0.03571426,,,,,,,,,,,,,,,,, -17000,0.07116256,0.037139654,,,,,,,,,,,,,,,,, -17100,0.068356425,0.035261437,,,,,,,,,,,,,,,,, -17200,0.05158501,0.03313924,,,,,,,,,,,,,,,,, -17300,0.053248834,0.035424665,,,,,,,,,,,,,,,,, -17320,,,0.9903742671012878,0.0317201763391494,0.3792666162288067,0.9867115020751952,0.0447704903781414,0.2558755533093419,43793.0,0.9859164953231812,0.0470666550099849,0.2520827505637201,43793.0,5546.373893737793,8743.351408958435,5546.373893737793,3195.857089042664,0.6710550785064697,0.0 -17400,0.057249565,0.033515006,,,,,,,,,,,,,,,,, -17500,0.061238196,0.033163864,,,,,,,,,,,,,,,,, -17600,0.069893345,0.03402101,,,,,,,,,,,,,,,,, -17700,0.13418116,0.031158805,,,,,,,,,,,,,,,,, -17800,0.04509314,0.032487,,,,,,,,,,,,,,,,, -17900,0.089701794,0.034476895,,,,,,,,,,,,,,,,, -18000,0.05411909,0.035851214,,,,,,,,,,,,,,,,, -18060,,,0.9906107783317566,0.0310329459607601,0.3850302235403376,0.9868434071540833,0.0444827862083911,0.2634650572068787,43793.0,0.9859927296638488,0.0470446906983852,0.2545460441147208,43793.0,5786.353320121765,9114.093678951263,5786.353320121765,3326.569318294525,0.7022576332092285,0.0 -18100,0.10001166,0.034113035,,,,,,,,,,,,,,,,, -18200,0.05345493,0.03728614,,,,,,,,,,,,,,,,, -18300,0.060820278,0.03607399,,,,,,,,,,,,,,,,, -18400,0.049110115,0.03352506,,,,,,,,,,,,,,,,, -18500,0.056070477,0.03282519,,,,,,,,,,,,,,,,, -18600,0.0612769,0.035940353,,,,,,,,,,,,,,,,, -18700,0.05982638,0.03134213,,,,,,,,,,,,,,,,, -18800,0.0703241,0.034353804,,,,,,,,,,,,,,,,, -18801,,,0.9907031059265136,0.0309583060443401,0.3988686456391712,0.9866709113121032,0.0448623113334178,0.2583840982955815,43793.0,0.9858486652374268,0.0473712384700775,0.2486702052961499,43793.0,6026.574734687805,9486.867498636246,6026.574734687805,3459.07012462616,0.7338111400604248,0.0 -18900,0.05409127,0.034720767,,,,,,,,,,,,,,,,, -19000,0.065742016,0.034925453,,,,,,,,,,,,,,,,, -19100,0.061195973,0.036814183,,,,,,,,,,,,,,,,, -19200,0.10066567,0.038170226,,,,,,,,,,,,,,,,, -19300,0.060354214,0.036350973,,,,,,,,,,,,,,,,, -19400,0.056563955,0.036034506,,,,,,,,,,,,,,,,, -19500,0.073456354,0.032697354,,,,,,,,,,,,,,,,, -19550,,,0.9908729195594788,0.0300773531198501,0.4276070078582515,0.986758589744568,0.0443111844360828,0.2655009222441444,43793.0,0.9858562350273132,0.0470095984637737,0.2529303810550452,43793.0,6266.726683616638,9858.89572906494,6266.726683616638,3590.8943939208984,0.7656044960021973,0.0 -19600,0.0643927,0.0340296,,,,,,,,,,,,,,,,, -19700,0.06591787,0.032090567,,,,,,,,,,,,,,,,, -19800,0.054701954,0.03263684,,,,,,,,,,,,,,,,, -19900,0.05298004,0.034115274,,,,,,,,,,,,,,,,, -20000,0.08028686,0.032417152,,,,,,,,,,,,,,,,, -20100,0.0682291,0.03235371,,,,,,,,,,,,,,,,, -20200,0.10036147,0.029818946,,,,,,,,,,,,,,,,, -20294,,,0.9909902811050416,0.029482202604413,0.4394125373132448,0.9867805242538452,0.0443419106304645,0.2630536781553423,43793.0,0.9859615564346312,0.0468484051525592,0.2587649927248948,43793.0,6506.863411426544,10231.407826423643,6506.863411426544,3723.218013286591,0.7975211143493652,0.0 -20300,0.058959212,0.03360052,,,,,,,,,,,,,,,,, -20400,0.07633833,0.035196915,,,,,,,,,,,,,,,,, -20500,0.06596668,0.036795773,,,,,,,,,,,,,,,,, -20600,0.07473062,0.027961534,,,,,,,,,,,,,,,,, -20700,0.05642819,0.03563433,,,,,,,,,,,,,,,,, -20800,0.08205942,0.032535,,,,,,,,,,,,,,,,, -20900,0.09611622,0.03430699,,,,,,,,,,,,,,,,, -21000,0.09079676,0.035543356,,,,,,,,,,,,,,,,, -21040,,,0.9909581542015076,0.0293449759483337,0.4319782718431922,0.9868271946907043,0.044256966561079,0.2620292413708662,43793.0,0.985920250415802,0.0470907315611839,0.2505250804658075,43793.0,6747.060025215149,10599.26006937027,6747.060025215149,3850.82346367836,0.8276448249816895,0.0 -21100,0.0561941,0.03549791,,,,,,,,,,,,,,,,, -21200,0.10686528,0.034475565,,,,,,,,,,,,,,,,, -21300,0.060547877,0.033468306,,,,,,,,,,,,,,,,, -21400,0.06458549,0.03391161,,,,,,,,,,,,,,,,, -21500,0.108619474,0.037605714,,,,,,,,,,,,,,,,, -21600,0.058333687,0.033788174,,,,,,,,,,,,,,,,, -21700,0.059743512,0.029799804,,,,,,,,,,,,,,,,, -21788,,,0.9909757971763612,0.0297860577702522,0.411273935375378,0.986751675605774,0.0445716865360736,0.2534629003852076,43793.0,0.9858971238136292,0.047172050923109,0.245060863473351,43793.0,6987.061242580414,10967.76712179184,6987.061242580414,3979.278872251511,0.8579602241516113,0.0 -21800,0.084833145,0.03429757,,,,,,,,,,,,,,,,, -21900,0.06874618,0.03697433,,,,,,,,,,,,,,,,, -22000,0.08450579,0.036014255,,,,,,,,,,,,,,,,, -22100,0.085134365,0.03591014,,,,,,,,,,,,,,,,, -22200,0.061930947,0.03472844,,,,,,,,,,,,,,,,, -22300,0.06915531,0.03282841,,,,,,,,,,,,,,,,, -22400,0.06217612,0.03481266,,,,,,,,,,,,,,,,, -22500,0.12807895,0.03523011,,,,,,,,,,,,,,,,, -22534,,,0.9908212423324584,0.0303160864859819,0.4167117446792816,0.9868998527526855,0.0442841872572898,0.2639624370741263,43793.0,0.9860268235206604,0.0470134206116199,0.259165562861682,43793.0,7227.019123077393,11339.331525325775,7227.019123077393,4110.835282087326,0.8883786201477051,0.0 -22600,0.08415913,0.03513229,,,,,,,,,,,,,,,,, -22700,0.06396111,0.030983046,,,,,,,,,,,,,,,,, -22800,0.08447562,0.036169443,,,,,,,,,,,,,,,,, -22900,0.078796774,0.033062097,,,,,,,,,,,,,,,,, -23000,0.07681526,0.0323458,,,,,,,,,,,,,,,,, -23100,0.07244599,0.032766424,,,,,,,,,,,,,,,,, -23200,0.079423554,0.034303557,,,,,,,,,,,,,,,,, -23281,,,0.9908854365348816,0.03038932941854,0.4145162980485438,0.986748456954956,0.0443939976394176,0.2643333702414174,43793.0,0.9858128428459167,0.0469983145594596,0.2511379732354883,43793.0,7466.973075628281,11708.985038518906,7466.973075628281,4240.48300409317,0.9200398921966552,0.0 -23300,0.06889578,0.035946373,,,,,,,,,,,,,,,,, -23400,0.06303855,0.030401666,,,,,,,,,,,,,,,,, -23500,0.071410716,0.036315285,,,,,,,,,,,,,,,,, -23600,0.11960209,0.031013915,,,,,,,,,,,,,,,,, -23700,0.10843871,0.03470133,,,,,,,,,,,,,,,,, -23800,0.07318896,0.03506002,,,,,,,,,,,,,,,,, -23900,0.09048099,0.03437926,,,,,,,,,,,,,,,,, -24000,0.0786778,0.03582771,,,,,,,,,,,,,,,,, -24031,,,0.9907976984977722,0.0301222447305917,0.4307058110209221,0.986678183078766,0.0445076636970043,0.2598207593992417,43793.0,0.9858853220939636,0.0472135208547115,0.2514192693563914,43793.0,7706.785302639008,12079.562668561935,7706.785302639008,4370.933499574661,1.214691162109375,0.0 -24100,0.084643774,0.03411919,,,,,,,,,,,,,,,,, -24200,0.08478215,0.034157943,,,,,,,,,,,,,,,,, -24300,0.08671452,0.02978431,,,,,,,,,,,,,,,,, -24400,0.060444705,0.029163541,,,,,,,,,,,,,,,,, -24500,0.07462614,0.03333292,,,,,,,,,,,,,,,,, -24600,0.069873855,0.032653883,,,,,,,,,,,,,,,,, -24700,0.06985338,0.038115367,,,,,,,,,,,,,,,,, -24784,,,0.9909395575523376,0.0297166630625724,0.4196510958908334,0.9867833256721495,0.0445629172027111,0.2656587042217679,43793.0,0.9858975410461426,0.0471698269248008,0.2590929594246151,43793.0,7946.956842422485,12447.819665908812,7946.956842422485,4498.967143058777,1.2464373111724854,0.0 -24800,0.06647885,0.028219758,,,,,,,,,,,,,,,,, -24900,0.09443232,0.03446004,,,,,,,,,,,,,,,,, -25000,0.065645464,0.03286449,,,,,,,,,,,,,,,,, -25100,0.10241743,0.030639604,,,,,,,,,,,,,,,,, -25200,0.06747496,0.033479307,,,,,,,,,,,,,,,,, -25300,0.07287486,0.03283426,,,,,,,,,,,,,,,,, -25400,0.072604835,0.0322904,,,,,,,,,,,,,,,,, -25500,0.066437215,0.031077813,,,,,,,,,,,,,,,,, -25543,,,0.991019070148468,0.0294891316443681,0.4367420533835061,0.986832082271576,0.0445814169943332,0.2679887542443359,43793.0,0.9859825968742372,0.0472048185765743,0.2633457291123027,43793.0,8187.189737558365,12813.948610544205,8187.189737558365,4624.81149148941,1.2782447338104248,0.0 -25600,0.073403925,0.033604555,,,,,,,,,,,,,,,,, -25700,0.094422266,0.03655751,,,,,,,,,,,,,,,,, -25800,0.080665626,0.033426583,,,,,,,,,,,,,,,,, -25900,0.085349634,0.037831414,,,,,,,,,,,,,,,,, -26000,0.07593097,0.032697827,,,,,,,,,,,,,,,,, -26100,0.06881025,0.03502377,,,,,,,,,,,,,,,,, -26200,0.087186486,0.03504753,,,,,,,,,,,,,,,,, -26300,0.08727915,0.031691246,,,,,,,,,,,,,,,,, -26303,,,0.9911377429962158,0.0290332436561584,0.4396199351537836,0.9868190884590148,0.0445399396121501,0.2687223142231175,43793.0,0.9859455227851868,0.0469372197985649,0.2593809136492444,43793.0,8427.414668560028,13180.273483514786,8427.414668560028,4750.859451770783,1.3099524974822998,0.0 -26400,0.08343537,0.03393128,,,,,,,,,,,,,,,,, -26500,0.11214395,0.03289414,,,,,,,,,,,,,,,,, -26600,0.08120607,0.033356342,,,,,,,,,,,,,,,,, -26700,0.10414628,0.033864442,,,,,,,,,,,,,,,,, -26800,0.07586365,0.033971913,,,,,,,,,,,,,,,,, -26900,0.08717828,0.037790973,,,,,,,,,,,,,,,,, -27000,0.08312699,0.032831557,,,,,,,,,,,,,,,,, -27058,,,0.9913957715034484,0.0284608248621225,0.4587331702582208,0.9867423176765442,0.0443950109183788,0.2620780825173272,43793.0,0.9858773350715636,0.0470265299081802,0.2526834643991631,43793.0,8667.366248846054,13549.797183036804,8667.366248846054,4880.380812883377,1.3404219150543213,0.0 -27100,0.06991469,0.03411313,,,,,,,,,,,,,,,,, -27200,0.06946058,0.032307193,,,,,,,,,,,,,,,,, -27300,0.065436155,0.03183798,,,,,,,,,,,,,,,,, -27400,0.083995216,0.034911003,,,,,,,,,,,,,,,,, -27500,0.06788217,0.03481566,,,,,,,,,,,,,,,,, -27600,0.08448606,0.032415137,,,,,,,,,,,,,,,,, -27700,0.06785488,0.033262692,,,,,,,,,,,,,,,,, -27800,0.06967032,0.032843906,,,,,,,,,,,,,,,,, -27812,,,0.9915438890457152,0.0275515113025903,0.478820071746103,0.9868454337120056,0.0440764725208282,0.268497485961063,43793.0,0.9860082864761353,0.0469130389392375,0.256955366585077,43793.0,8907.356931447983,13917.328118801115,8907.356931447983,5007.864626646042,1.3766028881072998,0.0 -27900,0.09896196,0.03447991,,,,,,,,,,,,,,,,, -28000,0.085722886,0.034198713,,,,,,,,,,,,,,,,, -28100,0.07606207,0.035084963,,,,,,,,,,,,,,,,, -28200,0.069280885,0.035963833,,,,,,,,,,,,,,,,, -28300,0.09535423,0.03323931,,,,,,,,,,,,,,,,, -28400,0.09023384,0.032131948,,,,,,,,,,,,,,,,, -28500,0.07108946,0.03267786,,,,,,,,,,,,,,,,, -28566,,,0.9915900230407716,0.0274300016462802,0.4838465650251311,0.9868852496147156,0.0441147089004516,0.2774226575024014,43793.0,0.9860162734985352,0.0467939786612987,0.2587861733148328,43793.0,9147.537345647812,14287.542890071869,9147.537345647812,5137.846126794815,1.4091360569000244,0.0 -28600,0.089618295,0.0329853,,,,,,,,,,,,,,,,, -28700,0.101682134,0.033479746,,,,,,,,,,,,,,,,, -28800,0.07899015,0.03260769,,,,,,,,,,,,,,,,, -28900,0.08307875,0.030513598,,,,,,,,,,,,,,,,, -29000,0.06784397,0.029613292,,,,,,,,,,,,,,,,, -29100,0.10891644,0.03170363,,,,,,,,,,,,,,,,, -29200,0.08767354,0.032148227,,,,,,,,,,,,,,,,, -29300,0.08421777,0.034102604,,,,,,,,,,,,,,,,, -29316,,,0.9914113283157348,0.028087593615055,0.4503752334796423,0.9868101477622986,0.044549535959959,0.2645352341758079,43793.0,0.985925316810608,0.0473748035728931,0.2550472083789822,43793.0,9387.8157954216,14659.599410057068,9387.8157954216,5269.571957588196,1.4414465427398682,0.0 -29400,0.069053315,0.030559335,,,,,,,,,,,,,,,,, -29500,0.083417006,0.03127768,,,,,,,,,,,,,,,,, -29600,0.0907621,0.032186307,,,,,,,,,,,,,,,,, -29700,0.088114716,0.031340238,,,,,,,,,,,,,,,,, -29800,0.06826825,0.031464934,,,,,,,,,,,,,,,,, -29900,0.07928663,0.033333305,,,,,,,,,,,,,,,,, -30000,0.08338325,0.034472816,,,,,,,,,,,,,,,,, -30073,,,0.9912699460983276,0.0286596063524484,0.44959692861203,0.986766278743744,0.0447003468871116,0.2619020218166186,43793.0,0.9858600497245787,0.0476008988916873,0.2547086969015181,43793.0,9627.81943321228,15028.82996916771,9627.81943321228,5398.746550559998,1.4734342098236084,0.0 -30100,0.088410825,0.033589106,,,,,,,,,,,,,,,,, -30200,0.07193491,0.031762507,,,,,,,,,,,,,,,,, -30300,0.07591001,0.03041226,,,,,,,,,,,,,,,,, -30400,0.10712662,0.032625772,,,,,,,,,,,,,,,,, -30500,0.07972556,0.034620054,,,,,,,,,,,,,,,,, -30600,0.07761669,0.029312553,,,,,,,,,,,,,,,,, -30700,0.07133061,0.030042484,,,,,,,,,,,,,,,,, -30800,0.10411361,0.033180714,,,,,,,,,,,,,,,,, -30829,,,0.9910553097724916,0.0290402844548225,0.4345576765374209,0.9868341088294984,0.044944878667593,0.2651539653275671,43793.0,0.985971212387085,0.0476885735988616,0.2549692084955743,43793.0,9867.785235404968,15397.919102430344,9867.785235404968,5527.817716121674,1.5052387714385986,0.0 -30900,0.117618725,0.033805955,,,,,,,,,,,,,,,,, -31000,0.095855266,0.032465767,,,,,,,,,,,,,,,,, -31100,0.06541868,0.030813774,,,,,,,,,,,,,,,,, -31200,0.07975252,0.03347507,,,,,,,,,,,,,,,,, -31300,0.07298814,0.031595517,,,,,,,,,,,,,,,,, -31400,0.07907132,0.029912822,,,,,,,,,,,,,,,,, -31500,0.06719668,0.02811461,,,,,,,,,,,,,,,,, -31588,,,0.9913302063941956,0.0285799913108348,0.4591087979868709,0.9867050051689148,0.044590026140213,0.2681575482452986,43793.0,0.985917329788208,0.0471377708017826,0.2580299895531414,43793.0,10107.885499715803,15765.455152750015,10107.885499715803,5655.200493574143,1.5384142398834229,0.0 -31600,0.07443855,0.03424905,,,,,,,,,,,,,,,,, -31700,0.079303734,0.03188568,,,,,,,,,,,,,,,,, -31800,0.09083607,0.0341688,,,,,,,,,,,,,,,,, -31900,0.07993641,0.033293072,,,,,,,,,,,,,,,,, -32000,0.08775882,0.033431686,,,,,,,,,,,,,,,,, -32100,0.08147239,0.028928382,,,,,,,,,,,,,,,,, -32200,0.10597686,0.034592595,,,,,,,,,,,,,,,,, -32300,0.09032177,0.031185111,,,,,,,,,,,,,,,,, -32346,,,0.9914727210998536,0.0278371162712574,0.4717979107615732,0.9869274497032166,0.043929535895586,0.274115570411685,43793.0,0.9860752820968628,0.0467753335833549,0.26618305739415,43793.0,10347.881727457048,16134.079976081848,10347.881727457048,5783.775832414627,1.572129487991333,0.0 -32400,0.07445735,0.03170336,,,,,,,,,,,,,,,,, -32500,0.0745582,0.029739529,,,,,,,,,,,,,,,,, -32600,0.08016501,0.032918386,,,,,,,,,,,,,,,,, -32700,0.08455324,0.03288667,,,,,,,,,,,,,,,,, -32800,0.07975114,0.03389966,,,,,,,,,,,,,,,,, -32900,0.10822178,0.035658874,,,,,,,,,,,,,,,,, -33000,0.08412143,0.03175649,,,,,,,,,,,,,,,,, -33100,0.08662465,0.033611175,,,,,,,,,,,,,,,,, -33102,,,0.9913511276245116,0.0281847678124904,0.4569296601938585,0.9868227243423462,0.0446917377412319,0.2654920394579966,43793.0,0.985958993434906,0.0473562479019165,0.2575073350230395,43793.0,10587.90379667282,16501.14848446846,10587.90379667282,5910.768916606903,1.6054506301879885,0.0 -33200,0.08039142,0.030588701,,,,,,,,,,,,,,,,, -33300,0.07526128,0.03271895,,,,,,,,,,,,,,,,, -33400,0.11270269,0.03473075,,,,,,,,,,,,,,,,, -33500,0.07779003,0.033518568,,,,,,,,,,,,,,,,, -33600,0.07237551,0.030128516,,,,,,,,,,,,,,,,, -33700,0.107612915,0.033804893,,,,,,,,,,,,,,,,, -33800,0.07227572,0.027987424,,,,,,,,,,,,,,,,, -33860,,,0.9915713667869568,0.0275131501257419,0.4805088503117457,0.986916482448578,0.0443529598414897,0.2733849468277443,43793.0,0.9860659837722778,0.0471661426126956,0.2590065779058145,43793.0,10828.046881198885,16867.5029835701,10828.046881198885,6036.926429271698,1.6387813091278076,0.0 -33900,0.08511027,0.03164882,,,,,,,,,,,,,,,,, -34000,0.09399857,0.032913,,,,,,,,,,,,,,,,, -34100,0.08227077,0.030646725,,,,,,,,,,,,,,,,, -34200,0.122610554,0.029062403,,,,,,,,,,,,,,,,, -34300,0.10019471,0.032213878,,,,,,,,,,,,,,,,, -34400,0.106841505,0.035421472,,,,,,,,,,,,,,,,, -34500,0.09535569,0.032245565,,,,,,,,,,,,,,,,, -34600,0.08089996,0.03302924,,,,,,,,,,,,,,,,, -34617,,,0.9917721152305604,0.0267431810498237,0.4938607820037304,0.9869108200073242,0.0445934534072876,0.2638629801469858,43793.0,0.9859358668327332,0.0476084761321544,0.2522764802677962,43793.0,11068.133890151978,17232.824068784714,11068.133890151978,6162.10580778122,1.6727240085601809,0.0 -34700,0.09463804,0.03307128,,,,,,,,,,,,,,,,, -34800,0.095193446,0.03354656,,,,,,,,,,,,,,,,, -34900,0.08349893,0.03228534,,,,,,,,,,,,,,,,, -35000,0.09430778,0.03215031,,,,,,,,,,,,,,,,, -35100,0.08636995,0.03371399,,,,,,,,,,,,,,,,, -35200,0.09649223,0.03024603,,,,,,,,,,,,,,,,, -35300,0.074983895,0.031424258,,,,,,,,,,,,,,,,, -35373,,,0.9918096661567688,0.0265304185450077,0.5029054677780674,0.9868081212043762,0.0446687713265419,0.2637214181806488,43793.0,0.9859430193901062,0.0474424511194229,0.2559087693871861,43793.0,11308.152801513672,17598.52809739113,11308.152801513672,6287.737751483917,1.7055885791778564,0.0 -35400,0.080417044,0.03016624,,,,,,,,,,,,,,,,, -35500,0.10874491,0.0327883,,,,,,,,,,,,,,,,, -35600,0.07193249,0.029597221,,,,,,,,,,,,,,,,, -35700,0.076396555,0.031721894,,,,,,,,,,,,,,,,, -35800,0.09599995,0.035092816,,,,,,,,,,,,,,,,, -35900,0.116773665,0.032603633,,,,,,,,,,,,,,,,, -36000,0.10525485,0.03271771,,,,,,,,,,,,,,,,, -36100,0.09257216,0.030959373,,,,,,,,,,,,,,,,, -36130,,,0.9921783208847046,0.0254643373191356,0.5405584099723635,0.9869144558906556,0.0447491630911827,0.2696709121978386,43793.0,0.9859346151351928,0.0476196333765983,0.254952374996752,43793.0,11548.278539657593,17964.49371266365,11548.278539657593,6413.520535945892,1.7420799732208252,0.0 -36200,0.08445351,0.032227755,,,,,,,,,,,,,,,,, -36300,0.07872666,0.030828767,,,,,,,,,,,,,,,,, -36400,0.089253485,0.032023873,,,,,,,,,,,,,,,,, -36500,0.09096959,0.03357632,,,,,,,,,,,,,,,,, -36600,0.096699536,0.03151698,,,,,,,,,,,,,,,,, -36700,0.10083207,0.031114435,,,,,,,,,,,,,,,,, -36800,0.09159986,0.032440674,,,,,,,,,,,,,,,,, -36881,,,0.9921891689300536,0.0256269536912441,0.5054668073603061,0.9867618083953856,0.0448433235287666,0.2666542209544567,43793.0,0.985889971256256,0.047700397670269,0.2532243725665771,43793.0,11788.357815027235,18332.315428972244,11788.357815027235,6541.208086490631,1.776219606399536,0.0 -36900,0.09031687,0.034298,,,,,,,,,,,,,,,,, -37000,0.095123366,0.030724365,,,,,,,,,,,,,,,,, -37100,0.09276752,0.030553516,,,,,,,,,,,,,,,,, -37200,0.08336512,0.028607305,,,,,,,,,,,,,,,,, -37300,0.09488112,0.033889093,,,,,,,,,,,,,,,,, -37400,0.09788999,0.032576405,,,,,,,,,,,,,,,,, -37500,0.09594563,0.029845804,,,,,,,,,,,,,,,,, -37600,0.09227175,0.034058813,,,,,,,,,,,,,,,,, -37631,,,0.9919636249542236,0.0262818839401006,0.5000979311796215,0.9867650866508484,0.04451534897089,0.269276343714745,43793.0,0.9858503341674804,0.0473716296255588,0.2548070281969129,43793.0,12028.320405721664,18700.193468809128,12028.320405721664,6669.069313287735,1.80993127822876,0.0 -37700,0.08852345,0.03136857,,,,,,,,,,,,,,,,, -37800,0.08682548,0.03224601,,,,,,,,,,,,,,,,, -37900,0.08746931,0.031891596,,,,,,,,,,,,,,,,, -38000,0.08283442,0.030077,,,,,,,,,,,,,,,,, -38100,0.09643421,0.032898262,,,,,,,,,,,,,,,,, -38200,0.084942766,0.030841196,,,,,,,,,,,,,,,,, -38300,0.1210168,0.032317623,,,,,,,,,,,,,,,,, -38379,,,0.9917044639587402,0.0267996452748775,0.4938836258819183,0.986860454082489,0.0450091734528541,0.2717148167414828,43793.0,0.9860424399375916,0.0478706955909729,0.2579735518736352,43793.0,12268.40572667122,19063.826691389084,12268.40572667122,6792.561163425446,1.8454234600067136,0.0 -38400,0.08967451,0.031563446,,,,,,,,,,,,,,,,, -38500,0.11261774,0.033893734,,,,,,,,,,,,,,,,, -38600,0.09173951,0.032878276,,,,,,,,,,,,,,,,, -38700,0.08611218,0.02846017,,,,,,,,,,,,,,,,, -38800,0.09662739,0.030735984,,,,,,,,,,,,,,,,, -38900,0.10106093,0.029335888,,,,,,,,,,,,,,,,, -39000,0.1057683,0.032386012,,,,,,,,,,,,,,,,, -39100,0.08984081,0.029872335,,,,,,,,,,,,,,,,, -39127,,,0.9917911887168884,0.0266721807420253,0.4895440123990278,0.9869790077209472,0.044631291180849,0.2678226921038907,43793.0,0.986038625240326,0.0473871938884258,0.2648304309454832,43793.0,12508.509371519089,19430.22696685791,12508.509371519089,6918.80362701416,1.879314661026001,0.0 -39200,0.10726153,0.03047174,,,,,,,,,,,,,,,,, -39300,0.092400916,0.030106178,,,,,,,,,,,,,,,,, -39400,0.09390861,0.03263896,,,,,,,,,,,,,,,,, -39500,0.10437052,0.03255598,,,,,,,,,,,,,,,,, -39600,0.082125835,0.029981352,,,,,,,,,,,,,,,,, -39700,0.070648484,0.028917218,,,,,,,,,,,,,,,,, -39800,0.09057234,0.032945607,,,,,,,,,,,,,,,,, -39885,,,0.9919468760490416,0.0260848235338926,0.5204706545550591,0.987012267112732,0.0442746318876743,0.2786177764963892,43793.0,0.9861629009246826,0.0473428405821323,0.2622531743337971,43793.0,12748.566566467283,19791.197275161743,12748.566566467283,7039.660442113876,1.91489839553833,0.0 -39900,0.09485344,0.029958643,,,,,,,,,,,,,,,,, -40000,0.11570868,0.035664313,,,,,,,,,,,,,,,,, -40100,0.08460533,0.029539926,,,,,,,,,,,,,,,,, -40200,0.09999034,0.030670632,,,,,,,,,,,,,,,,, -40300,0.09511584,0.029736962,,,,,,,,,,,,,,,,, -40400,0.09902125,0.030394047,,,,,,,,,,,,,,,,, -40500,0.09636814,0.029586574,,,,,,,,,,,,,,,,, -40600,0.09303227,0.028720265,,,,,,,,,,,,,,,,, -40639,,,0.991911232471466,0.0260745268315076,0.5069046204867362,0.986780881881714,0.0451743900775909,0.2680662528745569,43793.0,0.985961139202118,0.0480013452470302,0.2567544314291233,43793.0,12988.536489725111,20159.67905855179,12988.536489725111,7168.117209196091,1.949426889419556,0.0 -40700,0.09783454,0.029476972,,,,,,,,,,,,,,,,, -40800,0.12614919,0.030581877,,,,,,,,,,,,,,,,, -40900,0.08784877,0.030710125,,,,,,,,,,,,,,,,, -41000,0.0978445,0.031924065,,,,,,,,,,,,,,,,, -41100,0.09972532,0.03024241,,,,,,,,,,,,,,,,, -41200,0.08756406,0.029258955,,,,,,,,,,,,,,,,, -41300,0.10254097,0.027926292,,,,,,,,,,,,,,,,, -41398,,,0.9920253753662108,0.0256763193756341,0.508227948274832,0.98688805103302,0.0451797060668468,0.2732676072132827,43793.0,0.985961139202118,0.0479624792933464,0.2656677761975639,43793.0,13228.646829366684,20523.62076807022,13228.646829366684,7291.893787145615,1.9835777282714844,0.0 -41400,0.0931523,0.028328259,,,,,,,,,,,,,,,,, -41500,0.10825098,0.034212552,,,,,,,,,,,,,,,,, -41600,0.10538146,0.031499345,,,,,,,,,,,,,,,,, -41700,0.082444504,0.02843204,,,,,,,,,,,,,,,,, -41800,0.08741234,0.029350873,,,,,,,,,,,,,,,,, -41900,0.09100535,0.028078033,,,,,,,,,,,,,,,,, -42000,0.07933398,0.029013386,,,,,,,,,,,,,,,,, -42100,0.105021566,0.028547728,,,,,,,,,,,,,,,,, -42155,,,0.992261290550232,0.0248192362487316,0.5357011778114732,0.986861288547516,0.0453098826110363,0.2641243823559999,43793.0,0.98598974943161,0.0482239499688148,0.2574946687375729,43793.0,13468.731812000276,20885.98131942749,13468.731812000276,7414.11568903923,2.017220973968506,0.0 -42200,0.1269519,0.030228505,,,,,,,,,,,,,,,,, -42300,0.118835375,0.0317761,,,,,,,,,,,,,,,,, -42400,0.09781918,0.030287791,,,,,,,,,,,,,,,,, -42500,0.11196132,0.030200329,,,,,,,,,,,,,,,,, -42600,0.11049021,0.030738762,,,,,,,,,,,,,,,,, -42700,0.107638,0.030164713,,,,,,,,,,,,,,,,, -42800,0.102986604,0.02938879,,,,,,,,,,,,,,,,, -42900,0.09890954,0.028041936,,,,,,,,,,,,,,,,, -42908,,,0.9926353096961976,0.0238346364349126,0.5570144195376461,0.986894965171814,0.045154895633459,0.2684881619213733,43793.0,0.9860158562660216,0.0480639971792697,0.2653323038381081,43793.0,13708.764877796171,21250.51918125153,13708.764877796171,7538.565611362457,2.051811456680298,0.0 -43000,0.117894225,0.028441573,,,,,,,,,,,,,,,,, -43100,0.10186048,0.03167741,,,,,,,,,,,,,,,,, -43200,0.114334755,0.029901857,,,,,,,,,,,,,,,,, -43300,0.107933976,0.030420251,,,,,,,,,,,,,,,,, -43400,0.10693309,0.032075208,,,,,,,,,,,,,,,,, -43500,0.09412414,0.02971939,,,,,,,,,,,,,,,,, -43600,0.10663265,0.031191742,,,,,,,,,,,,,,,,, -43665,,,0.9925816655158995,0.0239026378840208,0.569847598106749,0.9868986010551452,0.0452833510935306,0.2717307550178367,43793.0,0.98598051071167,0.0484649688005447,0.2619368512501648,43793.0,13948.78694844246,21616.52692937851,13948.78694844246,7664.496058940887,2.0864064693450928,0.0 -43700,0.10259215,0.031144224,,,,,,,,,,,,,,,,, -43800,0.11965284,0.030288583,,,,,,,,,,,,,,,,, -43900,0.09572613,0.02894036,,,,,,,,,,,,,,,,, -44000,0.103839695,0.028820079,,,,,,,,,,,,,,,,, -44100,0.123212844,0.028929742,,,,,,,,,,,,,,,,, -44200,0.10507713,0.030402057,,,,,,,,,,,,,,,,, -44300,0.1222242,0.03234564,,,,,,,,,,,,,,,,, -44400,0.12410361,0.02682947,,,,,,,,,,,,,,,,, -44417,,,0.9929485321044922,0.022788044065237,0.5749294484192354,0.9868471026420592,0.0452699847519397,0.2726284919209182,43793.0,0.9859729409217834,0.0484565906226635,0.2569524214786334,43793.0,14188.775550365448,21985.23408293724,14188.775550365448,7793.158420085907,2.121481418609619,0.0 -44500,0.093358494,0.028818106,,,,,,,,,,,,,,,,, -44600,0.108533844,0.031563733,,,,,,,,,,,,,,,,, -44700,0.10190971,0.027554598,,,,,,,,,,,,,,,,, -44800,0.11454327,0.029757846,,,,,,,,,,,,,,,,, -44900,0.14329845,0.030771919,,,,,,,,,,,,,,,,, -45000,0.123567484,0.028536612,,,,,,,,,,,,,,,,, -45100,0.12866288,0.029647624,,,,,,,,,,,,,,,,, -45164,,,0.9925280213356018,0.0241652112454175,0.5439482996360956,0.9867951273918152,0.0455055013298988,0.2725618970001102,43793.0,0.985913097858429,0.0485909059643745,0.2562790789794792,43793.0,14428.960973501204,22349.58697938919,14428.960973501204,7917.267426967621,2.1592814922332764,0.0 -45200,0.12317775,0.029521603,,,,,,,,,,,,,,,,, -45300,0.11308669,0.027217824,,,,,,,,,,,,,,,,, -45400,0.10220863,0.027743772,,,,,,,,,,,,,,,,, -45500,0.10118784,0.027610723,,,,,,,,,,,,,,,,, -45600,0.10691255,0.02903831,,,,,,,,,,,,,,,,, -45700,0.12650248,0.029453808,,,,,,,,,,,,,,,,, -45800,0.11653476,0.027632523,,,,,,,,,,,,,,,,, -45900,0.10512281,0.028776648,,,,,,,,,,,,,,,,, -45910,,,0.9924951195716858,0.024184413254261,0.5322754362561134,0.9868247509002686,0.04539680108428,0.2694207075858307,43793.0,0.9858949780464172,0.0484018251299858,0.2579783262777751,43793.0,14669.197919368744,22715.47890305519,14669.197919368744,8042.86420583725,2.19640588760376,0.0 -46000,0.12623481,0.030972863,,,,,,,,,,,,,,,,, -46100,0.14347313,0.028046036,,,,,,,,,,,,,,,,, -46200,0.12406198,0.027746128,,,,,,,,,,,,,,,,, -46300,0.11261514,0.029012855,,,,,,,,,,,,,,,,, -46400,0.13066074,0.027800903,,,,,,,,,,,,,,,,, -46500,0.11080495,0.028868485,,,,,,,,,,,,,,,,, -46600,0.14377072,0.029408718,,,,,,,,,,,,,,,,, -46658,,,0.9923391938209534,0.0245573688298463,0.5505022282160861,0.9869205355644226,0.0457893349230289,0.2665850458991095,43793.0,0.9860398769378662,0.0489005632698535,0.2569565723669573,43793.0,14909.381882667542,23080.968989133835,14909.381882667542,8168.112809181213,2.2335081100463867,0.0 -46700,0.121194385,0.032028962,,,,,,,,,,,,,,,,, -46800,0.16352242,0.029585283,,,,,,,,,,,,,,,,, -46900,0.13167354,0.03171676,,,,,,,,,,,,,,,,, -47000,0.10394198,0.026682759,,,,,,,,,,,,,,,,, -47100,0.12036214,0.028783493,,,,,,,,,,,,,,,,, -47200,0.12509906,0.02997297,,,,,,,,,,,,,,,,, -47300,0.12815893,0.030969616,,,,,,,,,,,,,,,,, -47400,0.11713304,0.027627144,,,,,,,,,,,,,,,,, -47410,,,0.9923987984657288,0.0243546906858682,0.547945267187708,0.9869290590286256,0.0455531850457191,0.2730923089403295,43793.0,0.986047863960266,0.0485950969159603,0.2584206174290291,43793.0,15149.46195435524,23447.16524910927,15149.46195435524,8294.17154455185,2.270673513412476,0.0 -47500,0.10864319,0.02853135,,,,,,,,,,,,,,,,, -47600,0.12509598,0.027682355,,,,,,,,,,,,,,,,, -47700,0.12684803,0.029499777,,,,,,,,,,,,,,,,, -47800,0.152909,0.029093137,,,,,,,,,,,,,,,,, -47900,0.1336799,0.028117985,,,,,,,,,,,,,,,,, -48000,0.13860232,0.026316397,,,,,,,,,,,,,,,,, -48100,0.10917069,0.02819792,,,,,,,,,,,,,,,,, -48164,,,0.9926859736442566,0.0235126633197069,0.5531412566917546,0.9868718385696412,0.0453873015940189,0.2729452931736282,43793.0,0.9860175848007202,0.0482630915939807,0.2613342734660124,43793.0,15389.69013428688,23813.83040785789,15389.69013428688,8420.551263570786,2.3075530529022217,0.0 -48200,0.11457331,0.026731722,,,,,,,,,,,,,,,,, -48300,0.114832684,0.025079112,,,,,,,,,,,,,,,,, -48400,0.12641448,0.02797928,,,,,,,,,,,,,,,,, -48500,0.12307426,0.031394523,,,,,,,,,,,,,,,,, -48600,0.12660931,0.027703475,,,,,,,,,,,,,,,,, -48700,0.12933026,0.028222438,,,,,,,,,,,,,,,,, -48800,0.13588107,0.029250138,,,,,,,,,,,,,,,,, -48900,0.12437773,0.026791284,,,,,,,,,,,,,,,,, -48923,,,0.9927862286567688,0.0231304336339235,0.5690608496377537,0.9868023991584778,0.0456211678683757,0.2696094281540254,43793.0,0.9859282970428468,0.0485388562083244,0.2617990529820947,43793.0,15629.627718687056,24176.42839407921,15629.627718687056,8543.148166894913,2.350667953491211,0.0 -49000,0.118483774,0.028152056,,,,,,,,,,,,,,,,, -49100,0.13100451,0.026346559,,,,,,,,,,,,,,,,, -49200,0.12763396,0.030065874,,,,,,,,,,,,,,,,, -49300,0.13059337,0.027708242,,,,,,,,,,,,,,,,, -49400,0.12052511,0.0277514,,,,,,,,,,,,,,,,, -49500,0.13599573,0.026807673,,,,,,,,,,,,,,,,, -49600,0.11720882,0.027987681,,,,,,,,,,,,,,,,, -49677,,,0.9929873943328856,0.0223092995584011,0.5790065918357813,0.9868316650390624,0.0461039021611213,0.2703625890701598,43793.0,0.9859762787818908,0.0492651611566543,0.2634007317030291,43793.0,15869.585559368134,24541.95281338692,15869.585559368134,8668.656472444534,2.388645648956299,0.0 -49700,0.14322746,0.030527545,,,,,,,,,,,,,,,,, -49800,0.12827206,0.028889555,,,,,,,,,,,,,,,,, -49900,0.12878396,0.027273221,,,,,,,,,,,,,,,,, -50000,0.15228704,0.027398288,,,,,,,,,,,,,,,,, -50100,0.14454201,0.028470634,,,,,,,,,,,,,,,,, -50200,0.1268862,0.025501743,,,,,,,,,,,,,,,,, -50300,0.13173269,0.025879875,,,,,,,,,,,,,,,,, -50400,0.15081826,0.02616136,,,,,,,,,,,,,,,,, -50435,,,0.993471086025238,0.0211111195385456,0.6146497697624851,0.986735463142395,0.0460940673947334,0.2744567062110209,43793.0,0.9858596324920654,0.0491060465574264,0.2610190774407657,43793.0,16109.806084394457,24906.823295116425,16109.806084394457,8793.248821020126,2.4257941246032715,0.0 -50500,0.12392141,0.02704194,,,,,,,,,,,,,,,,, -50600,0.1349633,0.027842738,,,,,,,,,,,,,,,,, -50700,0.13462529,0.027160836,,,,,,,,,,,,,,,,, -50800,0.13533548,0.028376795,,,,,,,,,,,,,,,,, -50900,0.14096704,0.028896831,,,,,,,,,,,,,,,,, -51000,0.17986253,0.027131466,,,,,,,,,,,,,,,,, -51100,0.15059805,0.028204989,,,,,,,,,,,,,,,,, -51187,,,0.9937585592269896,0.0203729905188083,0.6229511751791663,0.9867293238639832,0.0467021130025386,0.2689869825711359,43793.0,0.985773265361786,0.0498418360948562,0.2565580586568688,43793.0,16349.852685928345,25272.11195421219,16349.852685928345,8918.434713840485,2.461792469024658,0.0 -51200,0.1330179,0.028059892,,,,,,,,,,,,,,,,, -51300,0.11818609,0.026387284,,,,,,,,,,,,,,,,, -51400,0.1322981,0.026348908,,,,,,,,,,,,,,,,, -51500,0.13810685,0.0271302,,,,,,,,,,,,,,,,, -51600,0.14266288,0.029427065,,,,,,,,,,,,,,,,, -51700,0.13249859,0.026802562,,,,,,,,,,,,,,,,, -51800,0.14148311,0.025686046,,,,,,,,,,,,,,,,, -51900,0.14683901,0.029188043,,,,,,,,,,,,,,,,, -51940,,,0.9935181140899658,0.0207281652837991,0.6268248896791947,0.9868150353431702,0.0465990602970123,0.2722384098665055,43793.0,0.9859463572502136,0.0497677139937877,0.2659970159599488,43793.0,16589.82354283333,25637.23035120964,16589.82354283333,9043.52482008934,2.4987967014312744,0.0 -52000,0.13115408,0.026754651,,,,,,,,,,,,,,,,, -52100,0.15439245,0.02815099,,,,,,,,,,,,,,,,, -52200,0.12748604,0.025686352,,,,,,,,,,,,,,,,, -52300,0.13862689,0.026479576,,,,,,,,,,,,,,,,, -52400,0.16701774,0.026949983,,,,,,,,,,,,,,,,, -52500,0.13497365,0.029472724,,,,,,,,,,,,,,,,, -52600,0.14271073,0.027859084,,,,,,,,,,,,,,,,, -52690,,,0.9936492443084716,0.0206983480602502,0.6196390393205258,0.9866599440574646,0.0467642769217491,0.2696261098279769,43793.0,0.9856839776039124,0.0501977689564228,0.2549874512181881,43793.0,16829.831008911133,26003.949209213257,16829.831008911133,9170.178788661957,2.5351483821868896,0.0 -52700,0.16075563,0.025686057,,,,,,,,,,,,,,,,, -52800,0.13740978,0.026387239,,,,,,,,,,,,,,,,, -52900,0.16243301,0.02973413,,,,,,,,,,,,,,,,, -53000,0.13878906,0.026149811,,,,,,,,,,,,,,,,, -53100,0.15112223,0.027567297,,,,,,,,,,,,,,,,, -53200,0.14577219,0.025658151,,,,,,,,,,,,,,,,, -53300,0.1573195,0.0317174,,,,,,,,,,,,,,,,, -53400,0.1567435,0.02571575,,,,,,,,,,,,,,,,, -53446,,,0.9932865500450134,0.0214157178997993,0.6066971485592483,0.986796736717224,0.0472288504242897,0.2679517868571796,43793.0,0.9859278798103333,0.0503723919391632,0.2595123515077037,43793.0,17069.828882217407,26368.558528900143,17069.828882217407,9294.73175382614,2.572896957397461,0.0 -53500,0.15917179,0.026511509,,,,,,,,,,,,,,,,, -53600,0.1642216,0.027022084,,,,,,,,,,,,,,,,, -53700,0.14852098,0.026185885,,,,,,,,,,,,,,,,, -53800,0.14854904,0.025889628,,,,,,,,,,,,,,,,, -53900,0.13866973,0.025642516,,,,,,,,,,,,,,,,, -54000,0.14715049,0.027316727,,,,,,,,,,,,,,,,, -54100,0.15716295,0.024396608,,,,,,,,,,,,,,,,, -54200,0.16785507,0.025879722,,,,,,,,,,,,,,,,, -54202,,,0.9932235479354858,0.0213823653757572,0.6007852406767729,0.9868023991584778,0.047565758228302,0.2681301990027719,43793.0,0.9859842658042908,0.0508865006268024,0.2572087545398565,43793.0,17310.08274435997,26731.15495181084,17310.08274435997,9417.017122030258,2.609468460083008,0.0 -54300,0.1478455,0.026487764,,,,,,,,,,,,,,,,, -54400,0.13538708,0.025375465,,,,,,,,,,,,,,,,, -54500,0.14247355,0.025733564,,,,,,,,,,,,,,,,, -54600,0.15439765,0.024408946,,,,,,,,,,,,,,,,, -54700,0.17566459,0.027563713,,,,,,,,,,,,,,,,, -54800,0.14549813,0.024370627,,,,,,,,,,,,,,,,, -54900,0.13920452,0.025760662,,,,,,,,,,,,,,,,, -54959,,,0.9933767318725586,0.0210478398948907,0.6104609656727764,0.9867898225784302,0.0478362925350666,0.2712276426357609,43793.0,0.9857816696166992,0.0509912185370922,0.2596567648926789,43793.0,17550.31499028206,27095.101389169693,17550.31499028206,9540.67410349846,2.645934820175171,0.0 -55000,0.17584234,0.024914538,,,,,,,,,,,,,,,,, -55100,0.1578873,0.026412541,,,,,,,,,,,,,,,,, -55200,0.18957835,0.024783725,,,,,,,,,,,,,,,,, -55300,0.18214622,0.026582187,,,,,,,,,,,,,,,,, -55400,0.1617485,0.026982596,,,,,,,,,,,,,,,,, -55500,0.14270835,0.024108758,,,,,,,,,,,,,,,,, -55600,0.15493849,0.025047883,,,,,,,,,,,,,,,,, -55700,0.15381747,0.025229324,,,,,,,,,,,,,,,,, -55710,,,0.993337333202362,0.0210698246955871,0.6063672334387846,0.9867419600486756,0.0481171533465385,0.2701161407026355,43793.0,0.9857547283172609,0.0514043755829334,0.254248360382065,43793.0,17790.44306921959,27456.76895523072,17790.44306921959,9662.156247615814,2.6831068992614746,0.0 -55800,0.16562757,0.024850015,,,,,,,,,,,,,,,,, -55900,0.15095636,0.024649449,,,,,,,,,,,,,,,,, -56000,0.14531659,0.022942279,,,,,,,,,,,,,,,,, -56100,0.1732054,0.026089553,,,,,,,,,,,,,,,,, -56200,0.1589717,0.02650332,,,,,,,,,,,,,,,,, -56300,0.16475266,0.025078464,,,,,,,,,,,,,,,,, -56400,0.15710953,0.02512601,,,,,,,,,,,,,,,,, -56463,,,0.9936086535453796,0.0201977919787168,0.6252101015049698,0.9866887331008912,0.0482036210596561,0.2638393455795906,43793.0,0.9857159852981568,0.0516699701547622,0.2506636334811378,43793.0,18030.64174723625,27817.927065372467,18030.64174723625,9783.057143211365,2.721312522888184,0.0 -56500,0.1699922,0.025961697,,,,,,,,,,,,,,,,, -56600,0.142862,0.023571054,,,,,,,,,,,,,,,,, -56700,0.14980246,0.02427512,,,,,,,,,,,,,,,,, -56800,0.1814827,0.025851078,,,,,,,,,,,,,,,,, -56900,0.15436612,0.023265079,,,,,,,,,,,,,,,,, -57000,0.16778398,0.026984757,,,,,,,,,,,,,,,,, -57100,0.18175414,0.025134955,,,,,,,,,,,,,,,,, -57200,0.17338803,0.024988545,,,,,,,,,,,,,,,,, -57216,,,0.9939561486244202,0.0191822126507759,0.6443242592758944,0.9866737127304076,0.0482679903507232,0.266641069144078,43793.0,0.9858503341674804,0.0515465550124645,0.2606294747945047,43793.0,18270.87906050682,28179.78066945076,18270.87906050682,9904.612380743029,2.7617759704589844,0.0 -57300,0.18042533,0.025204295,,,,,,,,,,,,,,,,, -57400,0.16809419,0.02465515,,,,,,,,,,,,,,,,, -57500,0.21182747,0.025429765,,,,,,,,,,,,,,,,, -57600,0.17339386,0.024611458,,,,,,,,,,,,,,,,, -57700,0.22125477,0.026803114,,,,,,,,,,,,,,,,, -57800,0.14812241,0.022411447,,,,,,,,,,,,,,,,, -57870,,,,,,,,,,,,,,18477.214519262314,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/eval_measurements.csv deleted file mode 100644 index add12a8d2..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -125.95763325691225,0.0,12.21111297607422,1,0,12.21111297607422,0.3947480618953705,0.7956756353378296,0.0263787062049418,43793,138.16879105567932,0.388690173625946,0.7994009852409363,0.0242480071660675,0.3926452994346618,0.7974664568901062,0.0248610234149521,43793 -246.8352580070496,0.0201737880706787,252.2515525817871,746,0,252.2515525817871,0.983142077922821,0.07968270778656,0.0409027168983369,43793,499.1276910305023,0.9867796897888184,0.0680287182331085,0.0402479627806167,0.9841179251670836,0.0767398923635482,0.0399961615512908,43793 -371.7513871192932,0.0476319789886474,492.32368326187134,1496,0,492.32368326187134,0.9833930730819702,0.0623854212462902,0.0889474205675733,43793,864.1640558242798,0.9871765375137328,0.0492448322474956,0.0841241205174674,0.98441344499588,0.059025228023529,0.0906389025904376,43793 -492.4274632930756,0.0747287273406982,732.4389727115631,2245,0,732.4389727115631,0.9838993549346924,0.0564872920513153,0.1300595740504896,43793,1225.00315618515,0.9875968098640442,0.044729109853506,0.1386323224885631,0.9848965406417848,0.0536001808941364,0.1376027091955169,43793 -614.4260520935059,0.1023535728454589,972.578722000122,2995,0,972.578722000122,0.984199285507202,0.0547094978392124,0.1508030045133156,43793,1587.189527511597,0.9878544807434082,0.0425002016127109,0.1700224423253142,0.9851725697517396,0.0518642179667949,0.154618689098535,43793 -734.4927840232849,0.1301419734954834,1212.7844922542572,3740,0,1212.7844922542572,0.984386682510376,0.0537297464907169,0.1661442551062359,43793,1947.5102033615112,0.9881300330162048,0.0417385660111904,0.1906998618648367,0.9853069186210632,0.0512183792889118,0.1698402414675214,43793 -856.147173166275,0.1576616764068603,1452.9492535591123,4502,0,1452.9492535591123,0.9844212532043456,0.053231094032526,0.1764556192425607,43793,2309.3776705265045,0.9881775975227356,0.0406310521066188,0.2115285711345851,0.985330045223236,0.0504153743386268,0.1774918128223911,43793 -980.266857624054,0.186072826385498,1693.020597934723,5250,0,1693.020597934723,0.9847843050956726,0.0512623116374015,0.1949999632181195,43793,2673.617773771286,0.9885719418525696,0.0391697064042091,0.2206978605301385,0.98569256067276,0.0487014725804328,0.1930046659606876,43793 -1101.0847754478457,0.2145509719848632,1933.2276899814608,6005,0,1933.2276899814608,0.984965443611145,0.0509249605238437,0.2020465953537728,43793,3034.691485643387,0.98882657289505,0.0381029769778251,0.2496363131292405,0.9858545660972596,0.0482324659824371,0.2024298052989203,43793 -1223.1207497119904,0.2429695129394531,2173.2930114269257,6763,0,2173.2930114269257,0.9848862290382384,0.0515451729297637,0.2159434795261385,43793,3396.8416588306427,0.988649606704712,0.0384971499443054,0.2566270749427176,0.985815167427063,0.0485941171646118,0.2162778550527938,43793 -1336.374439239502,0.2714407444000244,2413.523932695389,7517,0,2413.523932695389,0.9851423501968384,0.0500729568302631,0.2281813738583669,43793,3750.3751661777496,0.9889222979545592,0.0373423583805561,0.2822505777691139,0.9860039353370668,0.0474244654178619,0.2307869997388974,43793 -1454.5519473552704,0.3003604412078857,2653.601585626602,8263,0,2653.601585626602,0.9852981567382812,0.0496879629790782,0.2319208823917943,43793,4108.679584980011,0.989389181137085,0.0359759628772735,0.3080768857675089,0.9861922860145568,0.0469828322529792,0.2293059993856671,43793 -1574.195317029953,0.3297526836395263,2893.6669194698334,9016,0,2893.6669194698334,0.9853802919387816,0.0493492819368839,0.2334611983065395,43793,4468.438313007355,0.989482879638672,0.0353130511939525,0.3330498346157439,0.9862101674079896,0.0466694124042987,0.2334677768297713,43793 -1698.0499968528748,0.3593053817749023,3133.7673287391663,9774,0,3133.7673287391663,0.98552268743515,0.0488676764070987,0.2379734807150239,43793,4832.443814992905,0.9897928833961488,0.0340859591960907,0.3401076340438795,0.9863518476486206,0.0461979731917381,0.2403073041110711,43793 -1823.354031085968,0.3890912532806396,3373.8351554870605,10527,0,3373.8351554870605,0.9856115579605104,0.0480131767690181,0.2437551593420303,43793,5197.865916728973,0.9902619123458862,0.0326508581638336,0.3747174656981753,0.9864894151687622,0.045411080121994,0.2439541939419379,43793 -1944.7356841564176,0.4193212985992431,3614.053644180298,11279,0,3614.053644180298,0.985668420791626,0.0481474585831165,0.2456158085649711,43793,5559.517339468002,0.9901916980743408,0.0325437784194946,0.3962896425949149,0.9865381717681884,0.0455406345427036,0.2480548745005217,43793 -2069.12337732315,0.4490396976470947,3854.257782936096,12032,0,3854.257782936096,0.9857859015464784,0.0478194020688533,0.2491431849524868,43793,5924.159775972366,0.9903891086578368,0.0321890227496624,0.3812690287674594,0.986594557762146,0.0452002100646495,0.2515214649864786,43793 -2189.3605513572693,0.4801578521728515,4094.240201473236,12773,0,4094.240201473236,0.9857690334320068,0.0479464195668697,0.2452384303975356,43793,6284.43124961853,0.9904727935791016,0.0318740420043468,0.3875386029297761,0.9866684675216676,0.045241005718708,0.2541977274558192,43793 -2315.6966729164124,0.509023904800415,4334.438627481461,13514,0,4334.438627481461,0.9857829809188844,0.0476573593914508,0.2466627038263133,43793,6651.01534485817,0.9906947016716005,0.0309747941792011,0.4179961264701088,0.986711084842682,0.04488967359066,0.2612030305011974,43793 -2436.6722359657288,0.5402529239654541,4574.635347366333,14263,0,4574.635347366333,0.9858015179634094,0.0478576831519603,0.2506972280003634,43793,7012.2398047447205,0.9907548427581788,0.0304514113813638,0.4381895808015915,0.9867098927497864,0.0452342554926872,0.2643233205755981,43793 -2554.1713659763336,0.5711688995361328,4814.710758686066,15014,0,4814.710758686066,0.9858587980270386,0.0478467419743537,0.2430560819462322,43793,7369.865949630737,0.9909887313842772,0.029821703210473,0.4313697479562325,0.9867955446243286,0.0449834764003753,0.2626834619716003,43793 -2670.8944869041443,0.6007318496704102,5054.779027462006,15769,0,5054.779027462006,0.9858301281929016,0.0481723323464393,0.2489279901466689,43793,7726.707957744598,0.9910573363304138,0.0291592068970203,0.4634965915873742,0.9866985082626344,0.0453937016427516,0.2624384497477145,43793 -2790.1232488155365,0.6307599544525146,5294.750148057938,16521,0,5294.750148057938,0.9859126806259156,0.0487215369939804,0.255597488873233,43793,8085.958600759506,0.9913237690925598,0.0280455593019723,0.4919684364377577,0.9867537021636964,0.0459627620875835,0.2659343145663652,43793 -2911.728835821152,0.6605877876281738,5534.754784345627,17272,0,5534.754784345627,0.9858415126800536,0.048270720988512,0.250724034154077,43793,8447.618882656097,0.9916371703147888,0.0272269304841756,0.5029472729700839,0.9867812991142272,0.0454972200095653,0.2610559170688771,43793 -3029.6452848911285,0.6906979084014893,5774.820902109146,18023,0,5774.820902109146,0.985922396183014,0.048704195767641,0.249591743371493,43793,8805.652330636978,0.9916689991950988,0.0266898479312658,0.5299052978509329,0.98673015832901,0.0459031797945499,0.2613781244846068,43793 -3148.767208337784,0.7213609218597412,6014.902328491211,18783,0,6014.902328491211,0.985990583896637,0.0483993627130985,0.2556490203754977,43793,9164.90726852417,0.9919673204421996,0.0264405272901058,0.5123992699519384,0.98688805103302,0.0455853156745433,0.2624033645561057,43793 -3268.615795850754,0.7516143321990967,6255.073797464371,19536,0,6255.073797464371,0.985903024673462,0.0487295389175415,0.2492389463327525,43793,9524.977820158005,0.9915534853935242,0.0274011809378862,0.4995214715325733,0.9868194460868835,0.0458450391888618,0.2619863024626112,43793 -3384.753395795822,0.7820303440093994,6495.213454723358,20286,0,6495.213454723358,0.9859623908996582,0.0483418703079223,0.2521318678421627,43793,9881.306076049805,0.9918997883796692,0.0265837498009204,0.5098556117578865,0.9867545366287231,0.0458379797637462,0.2618027021554986,43793 -3507.4175040721893,0.8135173320770264,6735.299637794495,21038,0,6735.299637794495,0.9859055280685424,0.0488892011344432,0.2506835965684331,43793,10244.107994318008,0.9918867945671082,0.0263824984431266,0.5210673442668449,0.9867756366729736,0.0461118780076503,0.2637069541466819,43793 -3627.493724346161,0.8439273834228516,6975.372453927994,21776,0,6975.372453927994,0.985910177230835,0.0492215119302272,0.2534857262928888,43793,10604.308366298676,0.9919347763061525,0.0261033531278371,0.5254924327822671,0.9867740273475648,0.0462831333279609,0.2619066646066753,43793 -3744.21767282486,0.8761856555938721,7215.407013654709,22521,0,7215.407013654709,0.9858823418617249,0.0494879372417926,0.2469861820981355,43793,10961.120171308516,0.992203176021576,0.0252135433256626,0.5488126686446302,0.9866769909858704,0.0468030013144016,0.256033160416531,43793 -3856.554924964905,0.9077761173248292,7455.372872114181,23271,0,7455.372872114181,0.9858642220497132,0.0495327115058898,0.2471890951017505,43793,11313.475859165192,0.9925545454025269,0.0241808146238327,0.5611472685572075,0.9866765737533568,0.0466638915240764,0.2592590032349791,43793 -3975.702848911285,0.9397668838500975,7695.610192298889,24029,0,7695.610192298889,0.9858760237693788,0.0497966334223747,0.2444182896173111,43793,11672.913708925247,0.9927038550376892,0.0236411523073911,0.5831575393716177,0.9867139458656312,0.0469455383718013,0.2529076497241188,43793 -4095.730885744095,0.9724068641662598,7935.776249885559,24784,0,7935.776249885559,0.985854148864746,0.0500700362026691,0.2452221928092572,43793,12033.161201000214,0.9929652214050292,0.0226630270481109,0.6082374097411949,0.9866116046905518,0.0472517907619476,0.2516948713212647,43793 -4213.223459482193,1.0048103332519531,8175.77491402626,25532,0,8175.77491402626,0.9857863187789916,0.0506494082510471,0.2503859213316042,43793,12390.704937458038,0.9932979941368104,0.021950002759695,0.6178963070088231,0.9866976737976074,0.0474837683141231,0.2625155957085833,43793 -4330.8537838459015,1.036971092224121,8415.784977912903,26287,0,8415.784977912903,0.9856705069541932,0.0509564951062202,0.2435300717592869,43793,12748.3978600502,0.9928412437438964,0.0233154073357582,0.5911147703334756,0.986622989177704,0.0478868074715137,0.2472492395604049,43793 -4449.962881565094,1.069892644882202,8655.887106895447,27045,0,8655.887106895447,0.9858288764953612,0.0509972684085369,0.2449214140991655,43793,13107.662237644196,0.9928683638572692,0.0230578985065221,0.5763528542791916,0.9865012168884276,0.0482120066881179,0.2510674249777193,43793 -4571.956423997879,1.1025707721710205,8895.934210538864,27803,0,8895.934210538864,0.9856839776039124,0.0508299171924591,0.246662901412047,43793,13469.756130218506,0.9929357767105104,0.0228564292192459,0.5939315356893492,0.9865418076515198,0.047837596386671,0.2569972428572599,43793 -4684.665537118912,1.1340579986572266,9135.94087100029,28557,0,9135.94087100029,0.9858267903327942,0.0512571483850479,0.2444432918203075,43793,13822.52390408516,0.992776334285736,0.0231408923864364,0.5826014335653862,0.9866157174110411,0.0484429933130741,0.2535272435693478,43793 -4804.87996172905,1.1669306755065918,9375.945397377014,29310,0,9375.945397377014,0.9858486652374268,0.0519274547696113,0.2451963959188004,43793,14182.795937299728,0.992730975151062,0.0230717379599809,0.5945630507535318,0.9866514205932616,0.0488613173365592,0.2543643487363246,43793 -4919.782430887222,1.1999526023864746,9615.973546028135,30060,0,9615.973546028135,0.9857067465782166,0.0519684143364429,0.2419138446436762,43793,14537.78070282936,0.9930763244628906,0.0221608541905879,0.6124207352776663,0.9865061044692992,0.0489749684929847,0.2477602052753223,43793 -5037.405160903931,1.233625888824463,9856.123964548113,30814,0,9856.123964548113,0.9857239723205566,0.0521627850830554,0.243831158342494,43793,14895.607670545578,0.9933850169181824,0.0210515223443508,0.6370004880408724,0.986537754535675,0.0492011420428752,0.2494871203691623,43793 -5154.061851263046,1.2679109573364258,10096.213876724243,31564,0,10096.213876724243,0.9856313467025756,0.0524902753531932,0.2379547686815906,43793,15252.409235239027,0.99375718832016,0.0201460476964712,0.6480886876305693,0.9864618182182312,0.0492116846144199,0.2465337061384347,43793 -5272.460424661636,1.301285743713379,10336.165597438812,32313,0,10336.165597438812,0.9856852293014526,0.0522603243589401,0.2391222709634299,43793,15610.812860250471,0.9943658113479614,0.018616784363985,0.6874966935541733,0.9864301681518556,0.0492644384503364,0.2495978738414457,43793 -5389.545041322708,1.3359241485595703,10576.188651800156,33071,0,10576.188651800156,0.9856359362602234,0.0532675720751285,0.2360733337637231,43793,15967.97516155243,0.9944509863853456,0.0182839594781398,0.6934071978381392,0.9864094853401184,0.0501183457672596,0.2427620945351066,43793 -5506.360715389252,1.3688089847564695,10816.311703920364,33823,0,10816.311703920364,0.9855618476867676,0.0527226999402046,0.2433782019584294,43793,16324.967136383057,0.994499444961548,0.0184213556349277,0.6916567183095863,0.9863566756248474,0.0498206280171871,0.2441304028718217,43793 -5623.496901512146,1.4026639461517334,11056.350311517715,34579,0,11056.350311517715,0.9853659868240356,0.0530076585710048,0.234444546466829,43793,16682.196016073227,0.9933765530586244,0.0213746968656778,0.616448395606391,0.9862044453620912,0.0500890426337718,0.2375528526500311,43793 -5737.836226701736,1.4372141361236572,11296.30923128128,35330,0,11296.30923128128,0.9857400059700012,0.0538732260465621,0.2414957014130568,43793,17036.55020093918,0.9937258958816528,0.0200071949511766,0.6484402300425647,0.986441135406494,0.0510506741702556,0.2465732751416903,43793 -5856.3613159656525,1.4723279476165771,11536.434869289398,36077,0,11536.434869289398,0.9854881167411804,0.0542867630720138,0.233509910277653,43793,17395.25648856163,0.993654489517212,0.0201186686754226,0.6498716669258316,0.9863563179969788,0.0509757995605468,0.2449991985495474,43793 -5974.8999898433685,1.880678653717041,11776.005351305008,36820,0,11776.005351305008,0.9855584502220154,0.0547553859651088,0.234687398547803,43793,17753.795357704163,0.9935859441757202,0.0201924126595258,0.6461987033015026,0.9864293336868286,0.0515260584652423,0.2412587934181796,43793 -6091.20179772377,1.9143686294555664,12016.200606584547,37569,0,12016.200606584547,0.9854522943496704,0.0544375777244567,0.2326816912556495,43793,18110.347497224808,0.9937519431114196,0.0198342092335224,0.6552595150078337,0.9864228963851928,0.0513176284730434,0.2467318139167215,43793 -6207.4986119270325,1.9481298923492432,12256.40864610672,38325,0,12256.40864610672,0.9854717254638672,0.0555345155298709,0.2283262606970071,43793,18466.906888246536,0.9941584467887878,0.0183299854397773,0.6959283207416935,0.9863651990890504,0.0523091927170753,0.2425323032977104,43793 -6324.049305677414,1.984095573425293,12496.516707897186,39084,0,12496.516707897186,0.9854000806808472,0.0549664758145809,0.2311936963063237,43793,18823.623073339462,0.9948763251304626,0.0168443229049444,0.7108119769907197,0.9862645268440248,0.0518356934189796,0.242916454343494,43793 -6439.895386219025,2.017606496810913,12736.719133377075,39838,0,12736.719133377075,0.9853402972221376,0.0557583570480346,0.2323686769713601,43793,19179.72560429573,0.9953045845031738,0.0157356765121221,0.743408396758674,0.986170768737793,0.0523720234632492,0.2415788185318563,43793 -6552.548879623413,2.0531373023986816,12976.663056850432,40588,0,12976.663056850432,0.9853579998016356,0.0561539120972156,0.2277253445759712,43793,19532.37944626808,0.995507836341858,0.0151507463306188,0.7610764629387446,0.9862247705459596,0.0528991483151912,0.2365607502634769,43793 -6664.61195230484,2.0880980491638184,13216.849076271055,41337,0,13216.849076271055,0.985393762588501,0.056329395622015,0.2293886822357372,43793,19884.684267520905,0.9952184557914734,0.0158216003328561,0.7453891223328923,0.9862101674079896,0.0529556274414062,0.2418717596601929,43793 -6778.221654415131,2.122004747390747,13457.108339548113,42091,0,13457.108339548113,0.985342800617218,0.0567243658006191,0.231405241881548,43793,20238.608663082123,0.99482524394989,0.0167483184486627,0.7169117412420599,0.9861918687820436,0.0537334345281124,0.2353393975103577,43793 -6891.040193080902,2.158005714416504,13697.295625209808,42845,0,13697.295625209808,0.9853861927986144,0.0574819557368755,0.227422304637655,43793,20591.67139029503,0.993688941001892,0.0195393487811088,0.6601647292398976,0.9862625002861024,0.0540768094360828,0.2355572185518675,43793 -7005.932266712189,2.1925864219665527,13937.247032403946,43601,0,13937.247032403946,0.9854089617729188,0.0578960292041301,0.2298433674016478,43793,20946.570340633392,0.9939095377922058,0.0189428441226482,0.6758613495698387,0.986154556274414,0.0547237023711204,0.2323547004349432,43793 -7116.383370637894,2.228288173675537,14177.268211841583,44353,0,14177.268211841583,0.9852712154388428,0.058071594685316,0.225607776266737,43793,21297.09924530983,0.994534432888031,0.0170852746814489,0.7098776166423071,0.9860822558403016,0.0548556223511695,0.2360641167700365,43793 -7230.485638141632,2.2639756202697754,14417.23194217682,45098,0,14417.23194217682,0.9852467775344848,0.0584373772144317,0.2243313486320527,43793,21651.221431732178,0.9942643046379088,0.0177052374929189,0.7059773931270805,0.986059546470642,0.0551608502864837,0.2314232975069496,43793 -7343.484463214874,2.3002424240112305,14657.17865562439,45843,0,14657.17865562439,0.9852017164230348,0.0595104843378067,0.2203197279736102,43793,22004.224217414856,0.9943100214004515,0.0175719261169433,0.711232518020157,0.9860436916351318,0.0557162389159202,0.2310569725032088,43793 -7458.133021593094,2.3357396125793457,14897.215474367142,46594,0,14897.215474367142,0.9852076172828674,0.0586089491844177,0.2262182866703456,43793,22358.96573448181,0.9962734580039978,0.0129803111776709,0.8045735996318846,0.9860225915908812,0.0556060001254081,0.2332473908672151,43793 -7568.040026426315,2.371712207794189,15137.289443016052,47350,0,15137.289443016052,0.9852636456489564,0.0596487857401371,0.2282927555369014,43793,22709.00346231461,0.9964053630828856,0.0125846806913614,0.8052383302908812,0.9860960841178894,0.056346520781517,0.2354068295352333,43793 -7679.870763301849,2.407210350036621,15377.269327640532,48094,0,15377.269327640532,0.9851579070091248,0.0601330213248729,0.2245630058696907,43793,23060.869948148727,0.9965007305145264,0.0123935807496309,0.8195799484038679,0.98596453666687,0.0567309893667697,0.2269859560014178,43793 -7791.991349935532,2.443833827972412,15617.50436925888,48846,0,15617.50436925888,0.9852501749992372,0.0608636960387229,0.2234635566048425,43793,23413.28269124031,0.9960833191871644,0.0130016067996621,0.8006489196358728,0.9861443638801576,0.0572860650718212,0.2332671815231344,43793 -7902.767471790314,2.482178211212158,15857.626742601396,49595,0,15857.626742601396,0.9852871894836426,0.0605197958648204,0.2236626675084078,43793,23764.23997664452,0.9956724047660828,0.0138410003855824,0.797884302093334,0.9861240983009338,0.0572965294122695,0.2315614853424323,43793 -8008.843861818314,2.518423318862915,16097.812799215317,50347,0,16097.812799215317,0.9851604104042052,0.0614901892840862,0.2223522837272326,43793,24110.558915138245,0.9952304363250732,0.0149086937308311,0.7654407735773269,0.9860441088676452,0.0581214651465415,0.231803685346562,43793 -8119.3956298828125,2.554534435272217,16337.775570392609,51099,0,16337.775570392609,0.9852526783943176,0.062040738761425,0.2210060273314866,43793,24461.12946128845,0.9944838285446168,0.0166831631213426,0.735212748790653,0.9861963391304016,0.0584785975515842,0.230818436316012,43793 -8227.75286102295,2.590883016586304,16577.728811979294,51855,0,16577.728811979294,0.9852792024612428,0.0623346120119094,0.2210380329015674,43793,24809.496871471405,0.9948172569274902,0.0157752875238657,0.7602446975067121,0.9861037731170654,0.0588659830391407,0.231489586881311,43793 -8337.247876405716,2.629739284515381,16817.936812877655,52612,0,16817.936812877655,0.9851730465888976,0.0626277402043342,0.2172134167158539,43793,25159.259149074554,0.9952720999717712,0.0145624438300728,0.7853455614625474,0.9860274791717528,0.0588804148137569,0.2315921501409306,43793 -8449.776660203934,2.665452480316162,17057.958690166473,53365,0,17057.958690166473,0.985157072544098,0.0631423890590667,0.2154677620242295,43793,25511.86580467224,0.9949951767921448,0.0151583664119243,0.7691710098606686,0.9860047698020936,0.0594560466706752,0.229525978638166,43793 -8559.149673700333,2.703575134277344,17298.226668834686,54120,0,17298.226668834686,0.9852569103240968,0.0641731023788452,0.2164415117984676,43793,25861.56529688835,0.9957131147384644,0.0136777451261878,0.805361075076067,0.9861322045326232,0.0605087429285049,0.2302780896123628,43793 -8670.035662651062,2.741891622543335,17538.3272023201,54869,0,17538.3272023201,0.9850584864616394,0.0636362731456756,0.2129353179469481,43793,26212.610743045807,0.9979147911071776,0.009291942231357,0.8778368365167077,0.9860007166862488,0.0598200224339962,0.2295835965197492,43793 -8782.897083044052,2.778715848922729,17778.506385326385,55620,0,17778.506385326385,0.9850918054580688,0.0635327771306037,0.2129513821616357,43793,26565.708988189697,0.9977280497550964,0.0094415247440338,0.886361785591973,0.9860274791717528,0.0599618293344974,0.2274789746190675,43793 -8898.630964279175,2.816796064376831,18018.62752890587,56372,0,18018.62752890587,0.9851629734039308,0.064345933496952,0.210049108776192,43793,26921.62234067917,0.9974233508110046,0.0099036889150738,0.8731948836777758,0.9860575199127196,0.0605407804250717,0.2240677321307488,43793 -9009.587788581848,2.856903314590454,18258.657961845398,57128,0,18258.657961845398,0.9850812554359436,0.06488525122404099,0.21352743763209386,43793,27272.670634508133,0.9969244599342346,0.010655700229108334,0.8580153795451726,0.986025869846344,0.06091301143169403,0.22415679755386675,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/measurements.csv deleted file mode 100644 index c858ab0fa..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/measurements.csv +++ /dev/null @@ -1,658 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.1922061,0.7994188,,,,,,,,,,,,,,,,, -1,,,0.388690173625946,0.7994009852409363,0.0242480071660675,0.3926452994346618,0.7974664568901062,0.0248610234149521,43793.0,0.3947480618953705,0.7956756353378296,0.0263787062049418,43793.0,12.21111297607422,138.16879105567932,12.21111297607422,125.95763325691225,0.0,0.0 -100,0.5441595,0.44529498,,,,,,,,,,,,,,,,, -200,0.37456548,0.3250346,,,,,,,,,,,,,,,,, -300,0.26810423,0.22160494,,,,,,,,,,,,,,,,, -400,0.16979486,0.14550261,,,,,,,,,,,,,,,,, -500,0.10604899,0.10376826,,,,,,,,,,,,,,,,, -600,0.07218719,0.078646384,,,,,,,,,,,,,,,,, -700,0.06821112,0.0769736,,,,,,,,,,,,,,,,, -746,,,0.9867796897888184,0.0680287182331085,0.0402479627806167,0.9841179251670836,0.0767398923635482,0.0399961615512908,43793.0,0.983142077922821,0.07968270778656,0.0409027168983369,43793.0,252.2515525817871,499.1276910305023,252.2515525817871,246.8352580070496,0.0201737880706787,0.0 -800,0.07519407,0.057619628,,,,,,,,,,,,,,,,, -900,0.034765318,0.059465516,,,,,,,,,,,,,,,,, -1000,0.06931045,0.05814928,,,,,,,,,,,,,,,,, -1100,0.040943425,0.052637823,,,,,,,,,,,,,,,,, -1200,0.05578661,0.04724662,,,,,,,,,,,,,,,,, -1300,0.07405004,0.046464212,,,,,,,,,,,,,,,,, -1400,0.038086753,0.0453712,,,,,,,,,,,,,,,,, -1496,,,0.9871765375137328,0.0492448322474956,0.0841241205174674,0.98441344499588,0.059025228023529,0.0906389025904376,43793.0,0.9833930730819702,0.0623854212462902,0.0889474205675733,43793.0,492.32368326187134,864.1640558242798,492.32368326187134,371.7513871192932,0.0476319789886474,0.0 -1500,0.0596777,0.05464226,,,,,,,,,,,,,,,,, -1600,0.11015059,0.05393122,,,,,,,,,,,,,,,,, -1700,0.09178949,0.051345274,,,,,,,,,,,,,,,,, -1800,0.028172249,0.04855531,,,,,,,,,,,,,,,,, -1900,0.03887776,0.04723259,,,,,,,,,,,,,,,,, -2000,0.04501613,0.0517789,,,,,,,,,,,,,,,,, -2100,0.07974578,0.045801304,,,,,,,,,,,,,,,,, -2200,0.062176038,0.046927724,,,,,,,,,,,,,,,,, -2245,,,0.9875968098640442,0.044729109853506,0.1386323224885631,0.9848965406417848,0.0536001808941364,0.1376027091955169,43793.0,0.9838993549346924,0.0564872920513153,0.1300595740504896,43793.0,732.4389727115631,1225.00315618515,732.4389727115631,492.4274632930756,0.0747287273406982,0.0 -2300,0.07203466,0.04944961,,,,,,,,,,,,,,,,, -2400,0.046279304,0.04860178,,,,,,,,,,,,,,,,, -2500,0.072354876,0.046270933,,,,,,,,,,,,,,,,, -2600,0.04495153,0.0497994,,,,,,,,,,,,,,,,, -2700,0.055432204,0.04932293,,,,,,,,,,,,,,,,, -2800,0.025476642,0.044786666,,,,,,,,,,,,,,,,, -2900,0.021362117,0.04673139,,,,,,,,,,,,,,,,, -2995,,,0.9878544807434082,0.0425002016127109,0.1700224423253142,0.9851725697517396,0.0518642179667949,0.154618689098535,43793.0,0.984199285507202,0.0547094978392124,0.1508030045133156,43793.0,972.578722000122,1587.189527511597,972.578722000122,614.4260520935059,0.1023535728454589,0.0 -3000,0.05053711,0.04836601,,,,,,,,,,,,,,,,, -3100,0.04106118,0.044801958,,,,,,,,,,,,,,,,, -3200,0.04415555,0.046622217,,,,,,,,,,,,,,,,, -3300,0.044112474,0.041662022,,,,,,,,,,,,,,,,, -3400,0.028024284,0.043735947,,,,,,,,,,,,,,,,, -3500,0.041842736,0.046703164,,,,,,,,,,,,,,,,, -3600,0.020912921,0.04446267,,,,,,,,,,,,,,,,, -3700,0.036397178,0.04256724,,,,,,,,,,,,,,,,, -3740,,,0.9881300330162048,0.0417385660111904,0.1906998618648367,0.9853069186210632,0.0512183792889118,0.1698402414675214,43793.0,0.984386682510376,0.0537297464907169,0.1661442551062359,43793.0,1212.7844922542572,1947.5102033615112,1212.7844922542572,734.4927840232849,0.1301419734954834,0.0 -3800,0.021781672,0.04659333,,,,,,,,,,,,,,,,, -3900,0.024998194,0.04287018,,,,,,,,,,,,,,,,, -4000,0.0264137,0.04320994,,,,,,,,,,,,,,,,, -4100,0.032613605,0.047880795,,,,,,,,,,,,,,,,, -4200,0.029383078,0.049713235,,,,,,,,,,,,,,,,, -4300,0.021208744,0.04145205,,,,,,,,,,,,,,,,, -4400,0.023209125,0.039230425,,,,,,,,,,,,,,,,, -4500,0.015416955,0.04348787,,,,,,,,,,,,,,,,, -4502,,,0.9881775975227356,0.0406310521066188,0.2115285711345851,0.985330045223236,0.0504153743386268,0.1774918128223911,43793.0,0.9844212532043456,0.053231094032526,0.1764556192425607,43793.0,1452.9492535591123,2309.3776705265045,1452.9492535591123,856.147173166275,0.1576616764068603,0.0 -4600,0.027546419,0.044296637,,,,,,,,,,,,,,,,, -4700,0.031397834,0.04507167,,,,,,,,,,,,,,,,, -4800,0.024592208,0.04199115,,,,,,,,,,,,,,,,, -4900,0.013736869,0.045973092,,,,,,,,,,,,,,,,, -5000,0.021976823,0.045558374,,,,,,,,,,,,,,,,, -5100,0.021471072,0.041476943,,,,,,,,,,,,,,,,, -5200,0.019006032,0.0417459,,,,,,,,,,,,,,,,, -5250,,,0.9885719418525696,0.0391697064042091,0.2206978605301385,0.98569256067276,0.0487014725804328,0.1930046659606876,43793.0,0.9847843050956726,0.0512623116374015,0.1949999632181195,43793.0,1693.020597934723,2673.617773771286,1693.020597934723,980.266857624054,0.186072826385498,0.0 -5300,0.022263432,0.045071594,,,,,,,,,,,,,,,,, -5400,0.02263671,0.03913112,,,,,,,,,,,,,,,,, -5500,0.023084477,0.04676274,,,,,,,,,,,,,,,,, -5600,0.016988622,0.040994406,,,,,,,,,,,,,,,,, -5700,0.023811903,0.042672876,,,,,,,,,,,,,,,,, -5800,0.029170837,0.041467946,,,,,,,,,,,,,,,,, -5900,0.0313905,0.040266663,,,,,,,,,,,,,,,,, -6000,0.016190032,0.043044243,,,,,,,,,,,,,,,,, -6005,,,0.98882657289505,0.0381029769778251,0.2496363131292405,0.9858545660972596,0.0482324659824371,0.2024298052989203,43793.0,0.984965443611145,0.0509249605238437,0.2020465953537728,43793.0,1933.2276899814608,3034.691485643387,1933.2276899814608,1101.0847754478457,0.2145509719848632,0.0 -6100,0.016035458,0.04370355,,,,,,,,,,,,,,,,, -6200,0.022015385,0.040200107,,,,,,,,,,,,,,,,, -6300,0.019398175,0.04101315,,,,,,,,,,,,,,,,, -6400,0.015995627,0.04641711,,,,,,,,,,,,,,,,, -6500,0.015086429,0.04339779,,,,,,,,,,,,,,,,, -6600,0.015541316,0.040498827,,,,,,,,,,,,,,,,, -6700,0.014260353,0.042158227,,,,,,,,,,,,,,,,, -6763,,,0.988649606704712,0.0384971499443054,0.2566270749427176,0.985815167427063,0.0485941171646118,0.2162778550527938,43793.0,0.9848862290382384,0.0515451729297637,0.2159434795261385,43793.0,2173.2930114269257,3396.8416588306427,2173.2930114269257,1223.1207497119904,0.2429695129394531,0.0 -6800,0.015108726,0.040342562,,,,,,,,,,,,,,,,, -6900,0.025606073,0.043810748,,,,,,,,,,,,,,,,, -7000,0.019333279,0.043113507,,,,,,,,,,,,,,,,, -7100,0.015280983,0.03920671,,,,,,,,,,,,,,,,, -7200,0.023914328,0.040130794,,,,,,,,,,,,,,,,, -7300,0.016917877,0.04337341,,,,,,,,,,,,,,,,, -7400,0.021900456,0.042164486,,,,,,,,,,,,,,,,, -7500,0.014861268,0.04179569,,,,,,,,,,,,,,,,, -7517,,,0.9889222979545592,0.0373423583805561,0.2822505777691139,0.9860039353370668,0.0474244654178619,0.2307869997388974,43793.0,0.9851423501968384,0.0500729568302631,0.2281813738583669,43793.0,2413.523932695389,3750.3751661777496,2413.523932695389,1336.374439239502,0.2714407444000244,0.0 -7600,0.016420696,0.041399956,,,,,,,,,,,,,,,,, -7700,0.020238096,0.041844107,,,,,,,,,,,,,,,,, -7800,0.025563749,0.03948439,,,,,,,,,,,,,,,,, -7900,0.015258615,0.040462624,,,,,,,,,,,,,,,,, -8000,0.022316795,0.04123968,,,,,,,,,,,,,,,,, -8100,0.018157206,0.044545148,,,,,,,,,,,,,,,,, -8200,0.015709372,0.038139123,,,,,,,,,,,,,,,,, -8263,,,0.989389181137085,0.0359759628772735,0.3080768857675089,0.9861922860145568,0.0469828322529792,0.2293059993856671,43793.0,0.9852981567382812,0.0496879629790782,0.2319208823917943,43793.0,2653.601585626602,4108.679584980011,2653.601585626602,1454.5519473552704,0.3003604412078857,0.0 -8300,0.016043857,0.03867453,,,,,,,,,,,,,,,,, -8400,0.01923471,0.04096817,,,,,,,,,,,,,,,,, -8500,0.032775234,0.041977603,,,,,,,,,,,,,,,,, -8600,0.020938495,0.041268896,,,,,,,,,,,,,,,,, -8700,0.017302813,0.041659806,,,,,,,,,,,,,,,,, -8800,0.01684666,0.041122,,,,,,,,,,,,,,,,, -8900,0.016744563,0.040590692,,,,,,,,,,,,,,,,, -9000,0.016196523,0.043624047,,,,,,,,,,,,,,,,, -9016,,,0.989482879638672,0.0353130511939525,0.3330498346157439,0.9862101674079896,0.0466694124042987,0.2334677768297713,43793.0,0.9853802919387816,0.0493492819368839,0.2334611983065395,43793.0,2893.6669194698334,4468.438313007355,2893.6669194698334,1574.195317029953,0.3297526836395263,0.0 -9100,0.017069511,0.042489775,,,,,,,,,,,,,,,,, -9200,0.019303927,0.039524637,,,,,,,,,,,,,,,,, -9300,0.018160092,0.04321225,,,,,,,,,,,,,,,,, -9400,0.021268383,0.042109795,,,,,,,,,,,,,,,,, -9500,0.019029096,0.037469525,,,,,,,,,,,,,,,,, -9600,0.020607462,0.04197068,,,,,,,,,,,,,,,,, -9700,0.020420713,0.043063074,,,,,,,,,,,,,,,,, -9774,,,0.9897928833961488,0.0340859591960907,0.3401076340438795,0.9863518476486206,0.0461979731917381,0.2403073041110711,43793.0,0.98552268743515,0.0488676764070987,0.2379734807150239,43793.0,3133.7673287391663,4832.443814992905,3133.7673287391663,1698.0499968528748,0.3593053817749023,0.0 -9800,0.019867461,0.03804237,,,,,,,,,,,,,,,,, -9900,0.017196855,0.039701827,,,,,,,,,,,,,,,,, -10000,0.014383595,0.03917426,,,,,,,,,,,,,,,,, -10100,0.01485471,0.03852236,,,,,,,,,,,,,,,,, -10200,0.018469218,0.04142895,,,,,,,,,,,,,,,,, -10300,0.013717654,0.03749875,,,,,,,,,,,,,,,,, -10400,0.019094246,0.040721957,,,,,,,,,,,,,,,,, -10500,0.01770751,0.038784247,,,,,,,,,,,,,,,,, -10527,,,0.9902619123458862,0.0326508581638336,0.3747174656981753,0.9864894151687622,0.045411080121994,0.2439541939419379,43793.0,0.9856115579605104,0.0480131767690181,0.2437551593420303,43793.0,3373.8351554870605,5197.865916728973,3373.8351554870605,1823.354031085968,0.3890912532806396,0.0 -10600,0.021079103,0.03861311,,,,,,,,,,,,,,,,, -10700,0.023223098,0.037666377,,,,,,,,,,,,,,,,, -10800,0.015113804,0.038623184,,,,,,,,,,,,,,,,, -10900,0.014202921,0.03955519,,,,,,,,,,,,,,,,, -11000,0.01724854,0.0389536,,,,,,,,,,,,,,,,, -11100,0.01655346,0.03907058,,,,,,,,,,,,,,,,, -11200,0.015010716,0.036207117,,,,,,,,,,,,,,,,, -11279,,,0.9901916980743408,0.0325437784194946,0.3962896425949149,0.9865381717681884,0.0455406345427036,0.2480548745005217,43793.0,0.985668420791626,0.0481474585831165,0.2456158085649711,43793.0,3614.053644180298,5559.517339468002,3614.053644180298,1944.7356841564176,0.4193212985992431,0.0 -11300,0.01629127,0.03969346,,,,,,,,,,,,,,,,, -11400,0.016472591,0.038744446,,,,,,,,,,,,,,,,, -11500,0.017022913,0.03954712,,,,,,,,,,,,,,,,, -11600,0.015956676,0.03934904,,,,,,,,,,,,,,,,, -11700,0.017549783,0.04026198,,,,,,,,,,,,,,,,, -11800,0.016188726,0.035971265,,,,,,,,,,,,,,,,, -11900,0.02228166,0.040026646,,,,,,,,,,,,,,,,, -12000,0.014328134,0.03851815,,,,,,,,,,,,,,,,, -12032,,,0.9903891086578368,0.0321890227496624,0.3812690287674594,0.986594557762146,0.0452002100646495,0.2515214649864786,43793.0,0.9857859015464784,0.0478194020688533,0.2491431849524868,43793.0,3854.257782936096,5924.159775972366,3854.257782936096,2069.12337732315,0.4490396976470947,0.0 -12100,0.021624418,0.037643004,,,,,,,,,,,,,,,,, -12200,0.01615661,0.03749042,,,,,,,,,,,,,,,,, -12300,0.017718013,0.03892627,,,,,,,,,,,,,,,,, -12400,0.015571869,0.036741134,,,,,,,,,,,,,,,,, -12500,0.017132662,0.038685814,,,,,,,,,,,,,,,,, -12600,0.020523258,0.03633302,,,,,,,,,,,,,,,,, -12700,0.021187365,0.041590817,,,,,,,,,,,,,,,,, -12773,,,0.9904727935791016,0.0318740420043468,0.3875386029297761,0.9866684675216676,0.045241005718708,0.2541977274558192,43793.0,0.9857690334320068,0.0479464195668697,0.2452384303975356,43793.0,4094.240201473236,6284.43124961853,4094.240201473236,2189.3605513572693,0.4801578521728515,0.0 -12800,0.017425379,0.036912903,,,,,,,,,,,,,,,,, -12900,0.018276457,0.037012402,,,,,,,,,,,,,,,,, -13000,0.016449964,0.037695866,,,,,,,,,,,,,,,,, -13100,0.022786664,0.034566335,,,,,,,,,,,,,,,,, -13200,0.02061739,0.036562305,,,,,,,,,,,,,,,,, -13300,0.019500703,0.038267378,,,,,,,,,,,,,,,,, -13400,0.018664751,0.037819363,,,,,,,,,,,,,,,,, -13500,0.024156952,0.03840742,,,,,,,,,,,,,,,,, -13514,,,0.9906947016716005,0.0309747941792011,0.4179961264701088,0.986711084842682,0.04488967359066,0.2612030305011974,43793.0,0.9857829809188844,0.0476573593914508,0.2466627038263133,43793.0,4334.438627481461,6651.01534485817,4334.438627481461,2315.6966729164124,0.509023904800415,0.0 -13600,0.016378652,0.03919016,,,,,,,,,,,,,,,,, -13700,0.016569909,0.0352451,,,,,,,,,,,,,,,,, -13800,0.019519102,0.035864536,,,,,,,,,,,,,,,,, -13900,0.019651676,0.040633094,,,,,,,,,,,,,,,,, -14000,0.020348163,0.035531677,,,,,,,,,,,,,,,,, -14100,0.019324973,0.037038844,,,,,,,,,,,,,,,,, -14200,0.024472833,0.042180277,,,,,,,,,,,,,,,,, -14263,,,0.9907548427581788,0.0304514113813638,0.4381895808015915,0.9867098927497864,0.0452342554926872,0.2643233205755981,43793.0,0.9858015179634094,0.0478576831519603,0.2506972280003634,43793.0,4574.635347366333,7012.2398047447205,4574.635347366333,2436.6722359657288,0.5402529239654541,0.0 -14300,0.022452742,0.036371127,,,,,,,,,,,,,,,,, -14400,0.02657711,0.041220218,,,,,,,,,,,,,,,,, -14500,0.021809313,0.03661196,,,,,,,,,,,,,,,,, -14600,0.018349232,0.037616774,,,,,,,,,,,,,,,,, -14700,0.016337775,0.037405238,,,,,,,,,,,,,,,,, -14800,0.018694216,0.03361194,,,,,,,,,,,,,,,,, -14900,0.025069395,0.038400576,,,,,,,,,,,,,,,,, -15000,0.021960992,0.03709795,,,,,,,,,,,,,,,,, -15014,,,0.9909887313842772,0.029821703210473,0.4313697479562325,0.9867955446243286,0.0449834764003753,0.2626834619716003,43793.0,0.9858587980270386,0.0478467419743537,0.2430560819462322,43793.0,4814.710758686066,7369.865949630737,4814.710758686066,2554.1713659763336,0.5711688995361328,0.0 -15100,0.020597357,0.03440588,,,,,,,,,,,,,,,,, -15200,0.020615513,0.03639748,,,,,,,,,,,,,,,,, -15300,0.020708133,0.039487176,,,,,,,,,,,,,,,,, -15400,0.018015597,0.035461593,,,,,,,,,,,,,,,,, -15500,0.0254598,0.0369301,,,,,,,,,,,,,,,,, -15600,0.02037867,0.03747807,,,,,,,,,,,,,,,,, -15700,0.023755994,0.038346946,,,,,,,,,,,,,,,,, -15769,,,0.9910573363304138,0.0291592068970203,0.4634965915873742,0.9866985082626344,0.0453937016427516,0.2624384497477145,43793.0,0.9858301281929016,0.0481723323464393,0.2489279901466689,43793.0,5054.779027462006,7726.707957744598,5054.779027462006,2670.8944869041443,0.6007318496704102,0.0 -15800,0.020963335,0.035785507,,,,,,,,,,,,,,,,, -15900,0.022291372,0.035020664,,,,,,,,,,,,,,,,, -16000,0.018666256,0.032749716,,,,,,,,,,,,,,,,, -16100,0.026406823,0.038733732,,,,,,,,,,,,,,,,, -16200,0.025438488,0.03549186,,,,,,,,,,,,,,,,, -16300,0.023075882,0.03318051,,,,,,,,,,,,,,,,, -16400,0.025970066,0.03831976,,,,,,,,,,,,,,,,, -16500,0.021748533,0.032315668,,,,,,,,,,,,,,,,, -16521,,,0.9913237690925598,0.0280455593019723,0.4919684364377577,0.9867537021636964,0.0459627620875835,0.2659343145663652,43793.0,0.9859126806259156,0.0487215369939804,0.255597488873233,43793.0,5294.750148057938,8085.958600759506,5294.750148057938,2790.1232488155365,0.6307599544525146,0.0 -16600,0.026553156,0.0350002,,,,,,,,,,,,,,,,, -16700,0.027117305,0.036998376,,,,,,,,,,,,,,,,, -16800,0.022094851,0.035737745,,,,,,,,,,,,,,,,, -16900,0.023088733,0.03547673,,,,,,,,,,,,,,,,, -17000,0.022194501,0.036198277,,,,,,,,,,,,,,,,, -17100,0.023829881,0.03537052,,,,,,,,,,,,,,,,, -17200,0.022145147,0.03433244,,,,,,,,,,,,,,,,, -17272,,,0.9916371703147888,0.0272269304841756,0.5029472729700839,0.9867812991142272,0.0454972200095653,0.2610559170688771,43793.0,0.9858415126800536,0.048270720988512,0.250724034154077,43793.0,5534.754784345627,8447.618882656097,5534.754784345627,2911.728835821152,0.6605877876281738,0.0 -17300,0.022792583,0.03534958,,,,,,,,,,,,,,,,, -17400,0.02766155,0.034304388,,,,,,,,,,,,,,,,, -17500,0.026699165,0.034651194,,,,,,,,,,,,,,,,, -17600,0.025591027,0.0352104,,,,,,,,,,,,,,,,, -17700,0.022521606,0.031905696,,,,,,,,,,,,,,,,, -17800,0.02461243,0.033331223,,,,,,,,,,,,,,,,, -17900,0.021658665,0.03526763,,,,,,,,,,,,,,,,, -18000,0.024188451,0.03598379,,,,,,,,,,,,,,,,, -18023,,,0.9916689991950988,0.0266898479312658,0.5299052978509329,0.98673015832901,0.0459031797945499,0.2613781244846068,43793.0,0.985922396183014,0.048704195767641,0.249591743371493,43793.0,5774.820902109146,8805.652330636978,5774.820902109146,3029.6452848911285,0.6906979084014893,0.0 -18100,0.031742316,0.034720756,,,,,,,,,,,,,,,,, -18200,0.024233095,0.03668072,,,,,,,,,,,,,,,,, -18300,0.030578054,0.035697456,,,,,,,,,,,,,,,,, -18400,0.022842875,0.033517387,,,,,,,,,,,,,,,,, -18500,0.023383822,0.03328839,,,,,,,,,,,,,,,,, -18600,0.024191964,0.03583899,,,,,,,,,,,,,,,,, -18700,0.024849523,0.032245368,,,,,,,,,,,,,,,,, -18783,,,0.9919673204421996,0.0264405272901058,0.5123992699519384,0.98688805103302,0.0455853156745433,0.2624033645561057,43793.0,0.985990583896637,0.0483993627130985,0.2556490203754977,43793.0,6014.902328491211,9164.90726852417,6014.902328491211,3148.767208337784,0.7213609218597412,0.0 -18800,0.026344374,0.034603458,,,,,,,,,,,,,,,,, -18900,0.02471025,0.034334973,,,,,,,,,,,,,,,,, -19000,0.027950741,0.034405064,,,,,,,,,,,,,,,,, -19100,0.022510689,0.03580899,,,,,,,,,,,,,,,,, -19200,0.025120385,0.036236368,,,,,,,,,,,,,,,,, -19300,0.023973953,0.035358187,,,,,,,,,,,,,,,,, -19400,0.022738658,0.035109144,,,,,,,,,,,,,,,,, -19500,0.025897622,0.03306145,,,,,,,,,,,,,,,,, -19536,,,0.9915534853935242,0.0274011809378862,0.4995214715325733,0.9868194460868835,0.0458450391888618,0.2619863024626112,43793.0,0.985903024673462,0.0487295389175415,0.2492389463327525,43793.0,6255.073797464371,9524.977820158005,6255.073797464371,3268.615795850754,0.7516143321990967,0.0 -19600,0.024463769,0.03379128,,,,,,,,,,,,,,,,, -19700,0.024578027,0.03189247,,,,,,,,,,,,,,,,, -19800,0.024158673,0.032540843,,,,,,,,,,,,,,,,, -19900,0.024289148,0.033468198,,,,,,,,,,,,,,,,, -20000,0.022796152,0.0321448,,,,,,,,,,,,,,,,, -20100,0.031388234,0.03318251,,,,,,,,,,,,,,,,, -20200,0.028189585,0.03066277,,,,,,,,,,,,,,,,, -20286,,,0.9918997883796692,0.0265837498009204,0.5098556117578865,0.9867545366287231,0.0458379797637462,0.2618027021554986,43793.0,0.9859623908996582,0.0483418703079223,0.2521318678421627,43793.0,6495.213454723358,9881.306076049805,6495.213454723358,3384.753395795822,0.7820303440093994,0.0 -20300,0.030224606,0.033636983,,,,,,,,,,,,,,,,, -20400,0.023714941,0.033640638,,,,,,,,,,,,,,,,, -20500,0.032122042,0.035195693,,,,,,,,,,,,,,,,, -20600,0.021620933,0.028878551,,,,,,,,,,,,,,,,, -20700,0.024045715,0.034928735,,,,,,,,,,,,,,,,, -20800,0.026528496,0.032436248,,,,,,,,,,,,,,,,, -20900,0.02902534,0.033232935,,,,,,,,,,,,,,,,, -21000,0.0389976,0.034106225,,,,,,,,,,,,,,,,, -21038,,,0.9918867945671082,0.0263824984431266,0.5210673442668449,0.9867756366729736,0.0461118780076503,0.2637069541466819,43793.0,0.9859055280685424,0.0488892011344432,0.2506835965684331,43793.0,6735.299637794495,10244.107994318008,6735.299637794495,3507.4175040721893,0.8135173320770264,0.0 -21100,0.026643673,0.03484566,,,,,,,,,,,,,,,,, -21200,0.02759609,0.03287039,,,,,,,,,,,,,,,,, -21300,0.025617687,0.031989813,,,,,,,,,,,,,,,,, -21400,0.024732074,0.033164483,,,,,,,,,,,,,,,,, -21500,0.03687034,0.03610732,,,,,,,,,,,,,,,,, -21600,0.027879635,0.03398483,,,,,,,,,,,,,,,,, -21700,0.029592734,0.030393457,,,,,,,,,,,,,,,,, -21776,,,0.9919347763061525,0.0261033531278371,0.5254924327822671,0.9867740273475648,0.0462831333279609,0.2619066646066753,43793.0,0.985910177230835,0.0492215119302272,0.2534857262928888,43793.0,6975.372453927994,10604.308366298676,6975.372453927994,3627.493724346161,0.8439273834228516,0.0 -21800,0.030806528,0.033710834,,,,,,,,,,,,,,,,, -21900,0.029158123,0.035533443,,,,,,,,,,,,,,,,, -22000,0.034427375,0.03459317,,,,,,,,,,,,,,,,, -22100,0.050217275,0.035532303,,,,,,,,,,,,,,,,, -22200,0.03208507,0.033866204,,,,,,,,,,,,,,,,, -22300,0.026689412,0.03218332,,,,,,,,,,,,,,,,, -22400,0.035336252,0.03411513,,,,,,,,,,,,,,,,, -22500,0.02810055,0.032909136,,,,,,,,,,,,,,,,, -22521,,,0.992203176021576,0.0252135433256626,0.5488126686446302,0.9866769909858704,0.0468030013144016,0.256033160416531,43793.0,0.9858823418617249,0.0494879372417926,0.2469861820981355,43793.0,7215.407013654709,10961.120171308516,7215.407013654709,3744.21767282486,0.8761856555938721,0.0 -22600,0.029504007,0.03470409,,,,,,,,,,,,,,,,, -22700,0.02884866,0.03113829,,,,,,,,,,,,,,,,, -22800,0.03480878,0.034379266,,,,,,,,,,,,,,,,, -22900,0.029713552,0.031861912,,,,,,,,,,,,,,,,, -23000,0.032393895,0.030540127,,,,,,,,,,,,,,,,, -23100,0.033326965,0.03209101,,,,,,,,,,,,,,,,, -23200,0.033997197,0.03311995,,,,,,,,,,,,,,,,, -23271,,,0.9925545454025269,0.0241808146238327,0.5611472685572075,0.9866765737533568,0.0466638915240764,0.2592590032349791,43793.0,0.9858642220497132,0.0495327115058898,0.2471890951017505,43793.0,7455.372872114181,11313.475859165192,7455.372872114181,3856.554924964905,0.9077761173248292,0.0 -23300,0.03270588,0.033189338,,,,,,,,,,,,,,,,, -23400,0.02864021,0.030014247,,,,,,,,,,,,,,,,, -23500,0.032303564,0.034730863,,,,,,,,,,,,,,,,, -23600,0.030874277,0.030832026,,,,,,,,,,,,,,,,, -23700,0.035401847,0.03392321,,,,,,,,,,,,,,,,, -23800,0.033948682,0.033192333,,,,,,,,,,,,,,,,, -23900,0.04033536,0.032985426,,,,,,,,,,,,,,,,, -24000,0.03165744,0.03326486,,,,,,,,,,,,,,,,, -24029,,,0.9927038550376892,0.0236411523073911,0.5831575393716177,0.9867139458656312,0.0469455383718013,0.2529076497241188,43793.0,0.9858760237693788,0.0497966334223747,0.2444182896173111,43793.0,7695.610192298889,11672.913708925247,7695.610192298889,3975.702848911285,0.9397668838500975,0.0 -24100,0.033021152,0.032600854,,,,,,,,,,,,,,,,, -24200,0.040314965,0.034110844,,,,,,,,,,,,,,,,, -24300,0.0333415,0.028951623,,,,,,,,,,,,,,,,, -24400,0.029916316,0.029926108,,,,,,,,,,,,,,,,, -24500,0.03125113,0.03181173,,,,,,,,,,,,,,,,, -24600,0.034549415,0.03205708,,,,,,,,,,,,,,,,, -24700,0.036262847,0.035686098,,,,,,,,,,,,,,,,, -24784,,,0.9929652214050292,0.0226630270481109,0.6082374097411949,0.9866116046905518,0.0472517907619476,0.2516948713212647,43793.0,0.985854148864746,0.0500700362026691,0.2452221928092572,43793.0,7935.776249885559,12033.161201000214,7935.776249885559,4095.730885744095,0.9724068641662598,0.0 -24800,0.031054566,0.028448258,,,,,,,,,,,,,,,,, -24900,0.032333992,0.032715484,,,,,,,,,,,,,,,,, -25000,0.03281111,0.031692702,,,,,,,,,,,,,,,,, -25100,0.030267645,0.030250901,,,,,,,,,,,,,,,,, -25200,0.030765364,0.032127578,,,,,,,,,,,,,,,,, -25300,0.035081025,0.031597555,,,,,,,,,,,,,,,,, -25400,0.03335375,0.031014545,,,,,,,,,,,,,,,,, -25500,0.035145942,0.030428836,,,,,,,,,,,,,,,,, -25532,,,0.9932979941368104,0.021950002759695,0.6178963070088231,0.9866976737976074,0.0474837683141231,0.2625155957085833,43793.0,0.9857863187789916,0.0506494082510471,0.2503859213316042,43793.0,8175.77491402626,12390.704937458038,8175.77491402626,4213.223459482193,1.0048103332519531,0.0 -25600,0.035349153,0.03177192,,,,,,,,,,,,,,,,, -25700,0.035119887,0.03342251,,,,,,,,,,,,,,,,, -25800,0.03852156,0.031943325,,,,,,,,,,,,,,,,, -25900,0.039394405,0.034385834,,,,,,,,,,,,,,,,, -26000,0.044323843,0.032159116,,,,,,,,,,,,,,,,, -26100,0.036743343,0.03320145,,,,,,,,,,,,,,,,, -26200,0.047062777,0.03242605,,,,,,,,,,,,,,,,, -26287,,,0.9928412437438964,0.0233154073357582,0.5911147703334756,0.986622989177704,0.0478868074715137,0.2472492395604049,43793.0,0.9856705069541932,0.0509564951062202,0.2435300717592869,43793.0,8415.784977912903,12748.3978600502,8415.784977912903,4330.8537838459015,1.036971092224121,0.0 -26300,0.04443999,0.030415747,,,,,,,,,,,,,,,,, -26400,0.040135633,0.031541537,,,,,,,,,,,,,,,,, -26500,0.040586066,0.030346194,,,,,,,,,,,,,,,,, -26600,0.040315636,0.03179676,,,,,,,,,,,,,,,,, -26700,0.03818584,0.032244198,,,,,,,,,,,,,,,,, -26800,0.039059386,0.032440793,,,,,,,,,,,,,,,,, -26900,0.056139395,0.035590284,,,,,,,,,,,,,,,,, -27000,0.044310406,0.032218914,,,,,,,,,,,,,,,,, -27045,,,0.9928683638572692,0.0230578985065221,0.5763528542791916,0.9865012168884276,0.0482120066881179,0.2510674249777193,43793.0,0.9858288764953612,0.0509972684085369,0.2449214140991655,43793.0,8655.887106895447,13107.662237644196,8655.887106895447,4449.962881565094,1.069892644882202,0.0 -27100,0.040222052,0.032402944,,,,,,,,,,,,,,,,, -27200,0.040366296,0.031205649,,,,,,,,,,,,,,,,, -27300,0.04449009,0.030567719,,,,,,,,,,,,,,,,, -27400,0.042523384,0.033152103,,,,,,,,,,,,,,,,, -27500,0.043833673,0.033688042,,,,,,,,,,,,,,,,, -27600,0.04102629,0.031030491,,,,,,,,,,,,,,,,, -27700,0.041809987,0.031535257,,,,,,,,,,,,,,,,, -27800,0.041200824,0.030895617,,,,,,,,,,,,,,,,, -27803,,,0.9929357767105104,0.0228564292192459,0.5939315356893492,0.9865418076515198,0.047837596386671,0.2569972428572599,43793.0,0.9856839776039124,0.0508299171924591,0.246662901412047,43793.0,8895.934210538864,13469.756130218506,8895.934210538864,4571.956423997879,1.1025707721710205,0.0 -27900,0.04187109,0.032272164,,,,,,,,,,,,,,,,, -28000,0.044497427,0.032227296,,,,,,,,,,,,,,,,, -28100,0.04038538,0.032194376,,,,,,,,,,,,,,,,, -28200,0.041848462,0.0330435,,,,,,,,,,,,,,,,, -28300,0.047971092,0.030986637,,,,,,,,,,,,,,,,, -28400,0.052351322,0.03069764,,,,,,,,,,,,,,,,, -28500,0.044988744,0.030978346,,,,,,,,,,,,,,,,, -28557,,,0.992776334285736,0.0231408923864364,0.5826014335653862,0.9866157174110411,0.0484429933130741,0.2535272435693478,43793.0,0.9858267903327942,0.0512571483850479,0.2444432918203075,43793.0,9135.94087100029,13822.52390408516,9135.94087100029,4684.665537118912,1.1340579986572266,0.0 -28600,0.0472734,0.030420732,,,,,,,,,,,,,,,,, -28700,0.03861588,0.03100108,,,,,,,,,,,,,,,,, -28800,0.047390845,0.030371947,,,,,,,,,,,,,,,,, -28900,0.04223051,0.02972027,,,,,,,,,,,,,,,,, -29000,0.03858356,0.028437449,,,,,,,,,,,,,,,,, -29100,0.039827257,0.029334381,,,,,,,,,,,,,,,,, -29200,0.04145348,0.029565506,,,,,,,,,,,,,,,,, -29300,0.042686686,0.031393535,,,,,,,,,,,,,,,,, -29310,,,0.992730975151062,0.0230717379599809,0.5945630507535318,0.9866514205932616,0.0488613173365592,0.2543643487363246,43793.0,0.9858486652374268,0.0519274547696113,0.2451963959188004,43793.0,9375.945397377014,14182.795937299728,9375.945397377014,4804.87996172905,1.1669306755065918,0.0 -29400,0.046897963,0.028737715,,,,,,,,,,,,,,,,, -29500,0.054804526,0.029461911,,,,,,,,,,,,,,,,, -29600,0.04714113,0.030113418,,,,,,,,,,,,,,,,, -29700,0.048136726,0.029959293,,,,,,,,,,,,,,,,, -29800,0.04318194,0.028991671,,,,,,,,,,,,,,,,, -29900,0.05005567,0.03218858,,,,,,,,,,,,,,,,, -30000,0.05065964,0.031523105,,,,,,,,,,,,,,,,, -30060,,,0.9930763244628906,0.0221608541905879,0.6124207352776663,0.9865061044692992,0.0489749684929847,0.2477602052753223,43793.0,0.9857067465782166,0.0519684143364429,0.2419138446436762,43793.0,9615.973546028135,14537.78070282936,9615.973546028135,4919.782430887222,1.1999526023864746,0.0 -30100,0.041829813,0.031271644,,,,,,,,,,,,,,,,, -30200,0.041665822,0.029804174,,,,,,,,,,,,,,,,, -30300,0.05718102,0.028774569,,,,,,,,,,,,,,,,, -30400,0.045893203,0.02989669,,,,,,,,,,,,,,,,, -30500,0.04476334,0.031167928,,,,,,,,,,,,,,,,, -30600,0.044442177,0.028313048,,,,,,,,,,,,,,,,, -30700,0.042446837,0.02821415,,,,,,,,,,,,,,,,, -30800,0.0553468,0.03147413,,,,,,,,,,,,,,,,, -30814,,,0.9933850169181824,0.0210515223443508,0.6370004880408724,0.986537754535675,0.0492011420428752,0.2494871203691623,43793.0,0.9857239723205566,0.0521627850830554,0.243831158342494,43793.0,9856.123964548113,14895.607670545578,9856.123964548113,5037.405160903931,1.233625888824463,0.0 -30900,0.057880808,0.030578207,,,,,,,,,,,,,,,,, -31000,0.046312135,0.029917704,,,,,,,,,,,,,,,,, -31100,0.047084626,0.029232573,,,,,,,,,,,,,,,,, -31200,0.050959952,0.03063741,,,,,,,,,,,,,,,,, -31300,0.04643387,0.02992704,,,,,,,,,,,,,,,,, -31400,0.05385332,0.027786916,,,,,,,,,,,,,,,,, -31500,0.04523963,0.027037328,,,,,,,,,,,,,,,,, -31564,,,0.99375718832016,0.0201460476964712,0.6480886876305693,0.9864618182182312,0.0492116846144199,0.2465337061384347,43793.0,0.9856313467025756,0.0524902753531932,0.2379547686815906,43793.0,10096.213876724243,15252.409235239027,10096.213876724243,5154.061851263046,1.2679109573364258,0.0 -31600,0.048892647,0.030627156,,,,,,,,,,,,,,,,, -31700,0.04503823,0.02975753,,,,,,,,,,,,,,,,, -31800,0.051342543,0.030244814,,,,,,,,,,,,,,,,, -31900,0.05540103,0.030603444,,,,,,,,,,,,,,,,, -32000,0.053453166,0.030160949,,,,,,,,,,,,,,,,, -32100,0.04705421,0.027524961,,,,,,,,,,,,,,,,, -32200,0.06978684,0.03136772,,,,,,,,,,,,,,,,, -32300,0.050006393,0.028999096,,,,,,,,,,,,,,,,, -32313,,,0.9943658113479614,0.018616784363985,0.6874966935541733,0.9864301681518556,0.0492644384503364,0.2495978738414457,43793.0,0.9856852293014526,0.0522603243589401,0.2391222709634299,43793.0,10336.165597438812,15610.812860250471,10336.165597438812,5272.460424661636,1.301285743713379,0.0 -32400,0.051810894,0.029493215,,,,,,,,,,,,,,,,, -32500,0.04854439,0.028104026,,,,,,,,,,,,,,,,, -32600,0.057380352,0.029411849,,,,,,,,,,,,,,,,, -32700,0.057197493,0.031363077,,,,,,,,,,,,,,,,, -32800,0.06288756,0.03146241,,,,,,,,,,,,,,,,, -32900,0.056415986,0.03174608,,,,,,,,,,,,,,,,, -33000,0.060323175,0.028618364,,,,,,,,,,,,,,,,, -33071,,,0.9944509863853456,0.0182839594781398,0.6934071978381392,0.9864094853401184,0.0501183457672596,0.2427620945351066,43793.0,0.9856359362602234,0.0532675720751285,0.2360733337637231,43793.0,10576.188651800156,15967.97516155243,10576.188651800156,5389.545041322708,1.3359241485595703,0.0 -33100,0.057054475,0.030453572,,,,,,,,,,,,,,,,, -33200,0.05208462,0.028072897,,,,,,,,,,,,,,,,, -33300,0.052862708,0.030960785,,,,,,,,,,,,,,,,, -33400,0.065731674,0.03182047,,,,,,,,,,,,,,,,, -33500,0.06282933,0.03027743,,,,,,,,,,,,,,,,, -33600,0.050389193,0.028225964,,,,,,,,,,,,,,,,, -33700,0.05417537,0.030790016,,,,,,,,,,,,,,,,, -33800,0.049824167,0.026856655,,,,,,,,,,,,,,,,, -33823,,,0.994499444961548,0.0184213556349277,0.6916567183095863,0.9863566756248474,0.0498206280171871,0.2441304028718217,43793.0,0.9855618476867676,0.0527226999402046,0.2433782019584294,43793.0,10816.311703920364,16324.967136383057,10816.311703920364,5506.360715389252,1.3688089847564695,0.0 -33900,0.056531038,0.029213179,,,,,,,,,,,,,,,,, -34000,0.062443856,0.029893326,,,,,,,,,,,,,,,,, -34100,0.050115246,0.028339105,,,,,,,,,,,,,,,,, -34200,0.048905235,0.027182806,,,,,,,,,,,,,,,,, -34300,0.05234681,0.028814353,,,,,,,,,,,,,,,,, -34400,0.05869744,0.030816087,,,,,,,,,,,,,,,,, -34500,0.06195291,0.029726045,,,,,,,,,,,,,,,,, -34579,,,0.9933765530586244,0.0213746968656778,0.616448395606391,0.9862044453620912,0.0500890426337718,0.2375528526500311,43793.0,0.9853659868240356,0.0530076585710048,0.234444546466829,43793.0,11056.350311517715,16682.196016073227,11056.350311517715,5623.496901512146,1.4026639461517334,0.0 -34600,0.062369596,0.029841544,,,,,,,,,,,,,,,,, -34700,0.0633669,0.030267611,,,,,,,,,,,,,,,,, -34800,0.05875145,0.029944064,,,,,,,,,,,,,,,,, -34900,0.05799521,0.02901047,,,,,,,,,,,,,,,,, -35000,0.063360736,0.029302077,,,,,,,,,,,,,,,,, -35100,0.062630445,0.030027358,,,,,,,,,,,,,,,,, -35200,0.055279173,0.027814562,,,,,,,,,,,,,,,,, -35300,0.063213475,0.029560175,,,,,,,,,,,,,,,,, -35330,,,0.9937258958816528,0.0200071949511766,0.6484402300425647,0.986441135406494,0.0510506741702556,0.2465732751416903,43793.0,0.9857400059700012,0.0538732260465621,0.2414957014130568,43793.0,11296.30923128128,17036.55020093918,11296.30923128128,5737.836226701736,1.4372141361236572,0.0 -35400,0.061329775,0.0287785,,,,,,,,,,,,,,,,, -35500,0.058836084,0.030253544,,,,,,,,,,,,,,,,, -35600,0.05767288,0.028032249,,,,,,,,,,,,,,,,, -35700,0.060090505,0.028083831,,,,,,,,,,,,,,,,, -35800,0.06940593,0.031546753,,,,,,,,,,,,,,,,, -35900,0.059941642,0.029077983,,,,,,,,,,,,,,,,, -36000,0.08372617,0.029333591,,,,,,,,,,,,,,,,, -36077,,,0.993654489517212,0.0201186686754226,0.6498716669258316,0.9863563179969788,0.0509757995605468,0.2449991985495474,43793.0,0.9854881167411804,0.0542867630720138,0.233509910277653,43793.0,11536.434869289398,17395.25648856163,11536.434869289398,5856.3613159656525,1.4723279476165771,0.0 -36100,0.07200634,0.029027537,,,,,,,,,,,,,,,,, -36200,0.05871804,0.028410196,,,,,,,,,,,,,,,,, -36300,0.06340632,0.02874576,,,,,,,,,,,,,,,,, -36400,0.059700098,0.028990578,,,,,,,,,,,,,,,,, -36500,0.0623374,0.030596687,,,,,,,,,,,,,,,,, -36600,0.05994691,0.028116046,,,,,,,,,,,,,,,,, -36700,0.057492122,0.02805793,,,,,,,,,,,,,,,,, -36800,0.06757592,0.029450709,,,,,,,,,,,,,,,,, -36820,,,0.9935859441757202,0.0201924126595258,0.6461987033015026,0.9864293336868286,0.0515260584652423,0.2412587934181796,43793.0,0.9855584502220154,0.0547553859651088,0.234687398547803,43793.0,11776.005351305008,17753.795357704163,11776.005351305008,5974.8999898433685,1.880678653717041,0.0 -36900,0.060884908,0.03041711,,,,,,,,,,,,,,,,, -37000,0.065796025,0.028189091,,,,,,,,,,,,,,,,, -37100,0.058091365,0.027739016,,,,,,,,,,,,,,,,, -37200,0.061918706,0.026238522,,,,,,,,,,,,,,,,, -37300,0.061503235,0.029687697,,,,,,,,,,,,,,,,, -37400,0.079392284,0.029558793,,,,,,,,,,,,,,,,, -37500,0.06339431,0.027388932,,,,,,,,,,,,,,,,, -37569,,,0.9937519431114196,0.0198342092335224,0.6552595150078337,0.9864228963851928,0.0513176284730434,0.2467318139167215,43793.0,0.9854522943496704,0.0544375777244567,0.2326816912556495,43793.0,12016.200606584547,18110.347497224808,12016.200606584547,6091.20179772377,1.9143686294555664,0.0 -37600,0.06577043,0.029969618,,,,,,,,,,,,,,,,, -37700,0.06713559,0.027577922,,,,,,,,,,,,,,,,, -37800,0.06524223,0.028453412,,,,,,,,,,,,,,,,, -37900,0.06272375,0.02952434,,,,,,,,,,,,,,,,, -38000,0.060761657,0.026903676,,,,,,,,,,,,,,,,, -38100,0.06556626,0.029304894,,,,,,,,,,,,,,,,, -38200,0.064548664,0.027191974,,,,,,,,,,,,,,,,, -38300,0.07114408,0.028938156,,,,,,,,,,,,,,,,, -38325,,,0.9941584467887878,0.0183299854397773,0.6959283207416935,0.9863651990890504,0.0523091927170753,0.2425323032977104,43793.0,0.9854717254638672,0.0555345155298709,0.2283262606970071,43793.0,12256.40864610672,18466.906888246536,12256.40864610672,6207.4986119270325,1.9481298923492432,0.0 -38400,0.07199738,0.02804737,,,,,,,,,,,,,,,,, -38500,0.07080935,0.029083777,,,,,,,,,,,,,,,,, -38600,0.074567616,0.029015012,,,,,,,,,,,,,,,,, -38700,0.07151044,0.026688075,,,,,,,,,,,,,,,,, -38800,0.0666374,0.026902813,,,,,,,,,,,,,,,,, -38900,0.07458376,0.02703301,,,,,,,,,,,,,,,,, -39000,0.07049574,0.030101242,,,,,,,,,,,,,,,,, -39084,,,0.9948763251304626,0.0168443229049444,0.7108119769907197,0.9862645268440248,0.0518356934189796,0.242916454343494,43793.0,0.9854000806808472,0.0549664758145809,0.2311936963063237,43793.0,12496.516707897186,18823.623073339462,12496.516707897186,6324.049305677414,1.984095573425293,0.0 -39100,0.06359324,0.027332501,,,,,,,,,,,,,,,,, -39200,0.06790229,0.02743856,,,,,,,,,,,,,,,,, -39300,0.0694758,0.027034145,,,,,,,,,,,,,,,,, -39400,0.07020866,0.028907266,,,,,,,,,,,,,,,,, -39500,0.061649118,0.028189939,,,,,,,,,,,,,,,,, -39600,0.07261454,0.027547738,,,,,,,,,,,,,,,,, -39700,0.06829754,0.026649572,,,,,,,,,,,,,,,,, -39800,0.08969187,0.030175293,,,,,,,,,,,,,,,,, -39838,,,0.9953045845031738,0.0157356765121221,0.743408396758674,0.986170768737793,0.0523720234632492,0.2415788185318563,43793.0,0.9853402972221376,0.0557583570480346,0.2323686769713601,43793.0,12736.719133377075,19179.72560429573,12736.719133377075,6439.895386219025,2.017606496810913,0.0 -39900,0.071765386,0.028102532,,,,,,,,,,,,,,,,, -40000,0.085828744,0.029846193,,,,,,,,,,,,,,,,, -40100,0.06366054,0.026848031,,,,,,,,,,,,,,,,, -40200,0.06744374,0.026610918,,,,,,,,,,,,,,,,, -40300,0.06692043,0.026432944,,,,,,,,,,,,,,,,, -40400,0.07771948,0.027622472,,,,,,,,,,,,,,,,, -40500,0.072979726,0.026623923,,,,,,,,,,,,,,,,, -40588,,,0.995507836341858,0.0151507463306188,0.7610764629387446,0.9862247705459596,0.0528991483151912,0.2365607502634769,43793.0,0.9853579998016356,0.0561539120972156,0.2277253445759712,43793.0,12976.663056850432,19532.37944626808,12976.663056850432,6552.548879623413,2.0531373023986816,0.0 -40600,0.06491581,0.025902996,,,,,,,,,,,,,,,,, -40700,0.08326058,0.02734756,,,,,,,,,,,,,,,,, -40800,0.07206325,0.026182115,,,,,,,,,,,,,,,,, -40900,0.074245155,0.027509915,,,,,,,,,,,,,,,,, -41000,0.078688666,0.028105453,,,,,,,,,,,,,,,,, -41100,0.072719105,0.027057068,,,,,,,,,,,,,,,,, -41200,0.07899589,0.02709287,,,,,,,,,,,,,,,,, -41300,0.068271294,0.025754597,,,,,,,,,,,,,,,,, -41337,,,0.9952184557914734,0.0158216003328561,0.7453891223328923,0.9862101674079896,0.0529556274414062,0.2418717596601929,43793.0,0.985393762588501,0.056329395622015,0.2293886822357372,43793.0,13216.849076271055,19884.684267520905,13216.849076271055,6664.61195230484,2.0880980491638184,0.0 -41400,0.07094576,0.025611637,,,,,,,,,,,,,,,,, -41500,0.08336233,0.029815324,,,,,,,,,,,,,,,,, -41600,0.08196915,0.027678398,,,,,,,,,,,,,,,,, -41700,0.07497471,0.026085505,,,,,,,,,,,,,,,,, -41800,0.07040791,0.026340812,,,,,,,,,,,,,,,,, -41900,0.07756412,0.026051905,,,,,,,,,,,,,,,,, -42000,0.07507182,0.026700024,,,,,,,,,,,,,,,,, -42091,,,0.99482524394989,0.0167483184486627,0.7169117412420599,0.9861918687820436,0.0537334345281124,0.2353393975103577,43793.0,0.985342800617218,0.0567243658006191,0.231405241881548,43793.0,13457.108339548113,20238.608663082123,13457.108339548113,6778.221654415131,2.122004747390747,0.0 -42100,0.06709084,0.025134563,,,,,,,,,,,,,,,,, -42200,0.07810965,0.026598595,,,,,,,,,,,,,,,,, -42300,0.07683037,0.028498406,,,,,,,,,,,,,,,,, -42400,0.07918554,0.027391488,,,,,,,,,,,,,,,,, -42500,0.06491657,0.025651114,,,,,,,,,,,,,,,,, -42600,0.076771155,0.026719319,,,,,,,,,,,,,,,,, -42700,0.075086735,0.026575051,,,,,,,,,,,,,,,,, -42800,0.08600445,0.027261965,,,,,,,,,,,,,,,,, -42845,,,0.993688941001892,0.0195393487811088,0.6601647292398976,0.9862625002861024,0.0540768094360828,0.2355572185518675,43793.0,0.9853861927986144,0.0574819557368755,0.227422304637655,43793.0,13697.295625209808,20591.67139029503,13697.295625209808,6891.040193080902,2.158005714416504,0.0 -42900,0.074460156,0.02494267,,,,,,,,,,,,,,,,, -43000,0.0741788,0.025374388,,,,,,,,,,,,,,,,, -43100,0.07725272,0.027055731,,,,,,,,,,,,,,,,, -43200,0.0748864,0.02632945,,,,,,,,,,,,,,,,, -43300,0.0757925,0.026824588,,,,,,,,,,,,,,,,, -43400,0.08588723,0.028369466,,,,,,,,,,,,,,,,, -43500,0.09567371,0.026585897,,,,,,,,,,,,,,,,, -43600,0.08019046,0.027218403,,,,,,,,,,,,,,,,, -43601,,,0.9939095377922058,0.0189428441226482,0.6758613495698387,0.986154556274414,0.0547237023711204,0.2323547004349432,43793.0,0.9854089617729188,0.0578960292041301,0.2298433674016478,43793.0,13937.247032403946,20946.570340633392,13937.247032403946,7005.932266712189,2.1925864219665527,0.0 -43700,0.085581526,0.027507495,,,,,,,,,,,,,,,,, -43800,0.07075411,0.024913236,,,,,,,,,,,,,,,,, -43900,0.07975878,0.026353633,,,,,,,,,,,,,,,,, -44000,0.08401298,0.025605775,,,,,,,,,,,,,,,,, -44100,0.08265468,0.026015822,,,,,,,,,,,,,,,,, -44200,0.08163446,0.025870038,,,,,,,,,,,,,,,,, -44300,0.10268317,0.028154235,,,,,,,,,,,,,,,,, -44353,,,0.994534432888031,0.0170852746814489,0.7098776166423071,0.9860822558403016,0.0548556223511695,0.2360641167700365,43793.0,0.9852712154388428,0.058071594685316,0.225607776266737,43793.0,14177.268211841583,21297.09924530983,14177.268211841583,7116.383370637894,2.228288173675537,0.0 -44400,0.07483181,0.023807853,,,,,,,,,,,,,,,,, -44500,0.075631835,0.02537609,,,,,,,,,,,,,,,,, -44600,0.09055327,0.02704191,,,,,,,,,,,,,,,,, -44700,0.0837208,0.024858404,,,,,,,,,,,,,,,,, -44800,0.103006415,0.026506793,,,,,,,,,,,,,,,,, -44900,0.083725385,0.027054705,,,,,,,,,,,,,,,,, -45000,0.077696465,0.025381386,,,,,,,,,,,,,,,,, -45098,,,0.9942643046379088,0.0177052374929189,0.7059773931270805,0.986059546470642,0.0551608502864837,0.2314232975069496,43793.0,0.9852467775344848,0.0584373772144317,0.2243313486320527,43793.0,14417.23194217682,21651.221431732178,14417.23194217682,7230.485638141632,2.2639756202697754,0.0 -45100,0.079044595,0.025996624,,,,,,,,,,,,,,,,, -45200,0.07550698,0.025934482,,,,,,,,,,,,,,,,, -45300,0.075329416,0.024426216,,,,,,,,,,,,,,,,, -45400,0.083504245,0.02473464,,,,,,,,,,,,,,,,, -45500,0.086384356,0.025876684,,,,,,,,,,,,,,,,, -45600,0.09700825,0.025566192,,,,,,,,,,,,,,,,, -45700,0.08966132,0.025788315,,,,,,,,,,,,,,,,, -45800,0.08433766,0.024759982,,,,,,,,,,,,,,,,, -45843,,,0.9943100214004515,0.0175719261169433,0.711232518020157,0.9860436916351318,0.0557162389159202,0.2310569725032088,43793.0,0.9852017164230348,0.0595104843378067,0.2203197279736102,43793.0,14657.17865562439,22004.224217414856,14657.17865562439,7343.484463214874,2.3002424240112305,0.0 -45900,0.09271649,0.025311017,,,,,,,,,,,,,,,,, -46000,0.09367594,0.026465658,,,,,,,,,,,,,,,,, -46100,0.07527294,0.024363048,,,,,,,,,,,,,,,,, -46200,0.088081524,0.024500532,,,,,,,,,,,,,,,,, -46300,0.08942863,0.025804939,,,,,,,,,,,,,,,,, -46400,0.094499305,0.02446098,,,,,,,,,,,,,,,,, -46500,0.07845964,0.026358688,,,,,,,,,,,,,,,,, -46594,,,0.9962734580039978,0.0129803111776709,0.8045735996318846,0.9860225915908812,0.0556060001254081,0.2332473908672151,43793.0,0.9852076172828674,0.0586089491844177,0.2262182866703456,43793.0,14897.215474367142,22358.96573448181,14897.215474367142,7458.133021593094,2.3357396125793457,0.0 -46600,0.09377461,0.025890853,,,,,,,,,,,,,,,,, -46700,0.099809974,0.026971592,,,,,,,,,,,,,,,,, -46800,0.08693416,0.02542859,,,,,,,,,,,,,,,,, -46900,0.085471176,0.026433527,,,,,,,,,,,,,,,,, -47000,0.07579168,0.023775915,,,,,,,,,,,,,,,,, -47100,0.085673556,0.025634022,,,,,,,,,,,,,,,,, -47200,0.09150782,0.025371317,,,,,,,,,,,,,,,,, -47300,0.0962195,0.026989209,,,,,,,,,,,,,,,,, -47350,,,0.9964053630828856,0.0125846806913614,0.8052383302908812,0.9860960841178894,0.056346520781517,0.2354068295352333,43793.0,0.9852636456489564,0.0596487857401371,0.2282927555369014,43793.0,15137.289443016052,22709.00346231461,15137.289443016052,7568.040026426315,2.371712207794189,0.0 -47400,0.07993744,0.023923082,,,,,,,,,,,,,,,,, -47500,0.0919345,0.024928333,,,,,,,,,,,,,,,,, -47600,0.095127046,0.02483287,,,,,,,,,,,,,,,,, -47700,0.07988507,0.024695406,,,,,,,,,,,,,,,,, -47800,0.08504211,0.025036762,,,,,,,,,,,,,,,,, -47900,0.09647974,0.025008932,,,,,,,,,,,,,,,,, -48000,0.09099619,0.023590833,,,,,,,,,,,,,,,,, -48094,,,0.9965007305145264,0.0123935807496309,0.8195799484038679,0.98596453666687,0.0567309893667697,0.2269859560014178,43793.0,0.9851579070091248,0.0601330213248729,0.2245630058696907,43793.0,15377.269327640532,23060.869948148727,15377.269327640532,7679.870763301849,2.407210350036621,0.0 -48100,0.077533856,0.023874069,,,,,,,,,,,,,,,,, -48200,0.089310735,0.023902683,,,,,,,,,,,,,,,,, -48300,0.07416212,0.02273475,,,,,,,,,,,,,,,,, -48400,0.080436096,0.025114192,,,,,,,,,,,,,,,,, -48500,0.10091297,0.02652699,,,,,,,,,,,,,,,,, -48600,0.07853329,0.024135733,,,,,,,,,,,,,,,,, -48700,0.09481102,0.024442952,,,,,,,,,,,,,,,,, -48800,0.10014164,0.025571845,,,,,,,,,,,,,,,,, -48846,,,0.9960833191871644,0.0130016067996621,0.8006489196358728,0.9861443638801576,0.0572860650718212,0.2332671815231344,43793.0,0.9852501749992372,0.0608636960387229,0.2234635566048425,43793.0,15617.50436925888,23413.28269124031,15617.50436925888,7791.991349935532,2.443833827972412,0.0 -48900,0.09082024,0.023524893,,,,,,,,,,,,,,,,, -49000,0.104400694,0.024919888,,,,,,,,,,,,,,,,, -49100,0.0814435,0.0229936,,,,,,,,,,,,,,,,, -49200,0.08809446,0.025476065,,,,,,,,,,,,,,,,, -49300,0.09607958,0.023837069,,,,,,,,,,,,,,,,, -49400,0.096696325,0.023955999,,,,,,,,,,,,,,,,, -49500,0.09427276,0.023759838,,,,,,,,,,,,,,,,, -49595,,,0.9956724047660828,0.0138410003855824,0.797884302093334,0.9861240983009338,0.0572965294122695,0.2315614853424323,43793.0,0.9852871894836426,0.0605197958648204,0.2236626675084078,43793.0,15857.626742601396,23764.23997664452,15857.626742601396,7902.767471790314,2.482178211212158,0.0 -49600,0.09118014,0.024196135,,,,,,,,,,,,,,,,, -49700,0.10755081,0.026295649,,,,,,,,,,,,,,,,, -49800,0.08542167,0.02495909,,,,,,,,,,,,,,,,, -49900,0.10258734,0.024248106,,,,,,,,,,,,,,,,, -50000,0.10307307,0.023678528,,,,,,,,,,,,,,,,, -50100,0.09015167,0.023982605,,,,,,,,,,,,,,,,, -50200,0.08551275,0.022324461,,,,,,,,,,,,,,,,, -50300,0.096700534,0.024116853,,,,,,,,,,,,,,,,, -50347,,,0.9952304363250732,0.0149086937308311,0.7654407735773269,0.9860441088676452,0.0581214651465415,0.231803685346562,43793.0,0.9851604104042052,0.0614901892840862,0.2223522837272326,43793.0,16097.812799215317,24110.558915138245,16097.812799215317,8008.843861818314,2.518423318862915,0.0 -50400,0.08440688,0.02295816,,,,,,,,,,,,,,,,, -50500,0.08881361,0.023979373,,,,,,,,,,,,,,,,, -50600,0.09105808,0.024207896,,,,,,,,,,,,,,,,, -50700,0.09265617,0.024570843,,,,,,,,,,,,,,,,, -50800,0.08992469,0.024765663,,,,,,,,,,,,,,,,, -50900,0.10119565,0.025339099,,,,,,,,,,,,,,,,, -51000,0.10129445,0.024123942,,,,,,,,,,,,,,,,, -51099,,,0.9944838285446168,0.0166831631213426,0.735212748790653,0.9861963391304016,0.0584785975515842,0.230818436316012,43793.0,0.9852526783943176,0.062040738761425,0.2210060273314866,43793.0,16337.775570392609,24461.12946128845,16337.775570392609,8119.3956298828125,2.554534435272217,0.0 -51100,0.09930264,0.023742463,,,,,,,,,,,,,,,,, -51200,0.09680889,0.02389547,,,,,,,,,,,,,,,,, -51300,0.07435206,0.023126429,,,,,,,,,,,,,,,,, -51400,0.09976913,0.02346017,,,,,,,,,,,,,,,,, -51500,0.09134117,0.023294816,,,,,,,,,,,,,,,,, -51600,0.104730695,0.024575016,,,,,,,,,,,,,,,,, -51700,0.07915842,0.023239044,,,,,,,,,,,,,,,,, -51800,0.09603318,0.02421003,,,,,,,,,,,,,,,,, -51855,,,0.9948172569274902,0.0157752875238657,0.7602446975067121,0.9861037731170654,0.0588659830391407,0.231489586881311,43793.0,0.9852792024612428,0.0623346120119094,0.2210380329015674,43793.0,16577.728811979294,24809.496871471405,16577.728811979294,8227.75286102295,2.590883016586304,0.0 -51900,0.10090698,0.02544246,,,,,,,,,,,,,,,,, -52000,0.083599135,0.024016762,,,,,,,,,,,,,,,,, -52100,0.0874752,0.024643201,,,,,,,,,,,,,,,,, -52200,0.10101675,0.023384439,,,,,,,,,,,,,,,,, -52300,0.087314814,0.023547659,,,,,,,,,,,,,,,,, -52400,0.109753214,0.023790276,,,,,,,,,,,,,,,,, -52500,0.09911168,0.025183633,,,,,,,,,,,,,,,,, -52600,0.14482239,0.02699485,,,,,,,,,,,,,,,,, -52612,,,0.9952720999717712,0.0145624438300728,0.7853455614625474,0.9860274791717528,0.0588804148137569,0.2315921501409306,43793.0,0.9851730465888976,0.0626277402043342,0.2172134167158539,43793.0,16817.936812877655,25159.259149074554,16817.936812877655,8337.247876405716,2.629739284515381,0.0 -52700,0.080949806,0.02238277,,,,,,,,,,,,,,,,, -52800,0.09022975,0.023623688,,,,,,,,,,,,,,,,, -52900,0.09623349,0.024831709,,,,,,,,,,,,,,,,, -53000,0.08609832,0.023060745,,,,,,,,,,,,,,,,, -53100,0.08790736,0.024413144,,,,,,,,,,,,,,,,, -53200,0.101405196,0.023316797,,,,,,,,,,,,,,,,, -53300,0.14266035,0.027309816,,,,,,,,,,,,,,,,, -53365,,,0.9949951767921448,0.0151583664119243,0.7691710098606686,0.9860047698020936,0.0594560466706752,0.229525978638166,43793.0,0.985157072544098,0.0631423890590667,0.2154677620242295,43793.0,17057.958690166473,25511.86580467224,17057.958690166473,8449.776660203934,2.665452480316162,0.0 -53400,0.09959425,0.02330208,,,,,,,,,,,,,,,,, -53500,0.08705601,0.02307375,,,,,,,,,,,,,,,,, -53600,0.08876185,0.023066007,,,,,,,,,,,,,,,,, -53700,0.09317399,0.023120357,,,,,,,,,,,,,,,,, -53800,0.09205514,0.023237815,,,,,,,,,,,,,,,,, -53900,0.099432416,0.02337186,,,,,,,,,,,,,,,,, -54000,0.09648861,0.023862971,,,,,,,,,,,,,,,,, -54100,0.08441891,0.022262467,,,,,,,,,,,,,,,,, -54120,,,0.9957131147384644,0.0136777451261878,0.805361075076067,0.9861322045326232,0.0605087429285049,0.2302780896123628,43793.0,0.9852569103240968,0.0641731023788452,0.2164415117984676,43793.0,17298.226668834686,25861.56529688835,17298.226668834686,8559.149673700333,2.703575134277344,0.0 -54200,0.09469388,0.022972027,,,,,,,,,,,,,,,,, -54300,0.09639499,0.023060048,,,,,,,,,,,,,,,,, -54400,0.09084543,0.022222206,,,,,,,,,,,,,,,,, -54500,0.0891746,0.023143355,,,,,,,,,,,,,,,,, -54600,0.08474285,0.022537034,,,,,,,,,,,,,,,,, -54700,0.08957765,0.024188481,,,,,,,,,,,,,,,,, -54800,0.07344625,0.021825971,,,,,,,,,,,,,,,,, -54869,,,0.9979147911071776,0.009291942231357,0.8778368365167077,0.9860007166862488,0.0598200224339962,0.2295835965197492,43793.0,0.9850584864616394,0.0636362731456756,0.2129353179469481,43793.0,17538.3272023201,26212.610743045807,17538.3272023201,8670.035662651062,2.741891622543335,0.0 -54900,0.1108055,0.02312094,,,,,,,,,,,,,,,,, -55000,0.08491816,0.02236018,,,,,,,,,,,,,,,,, -55100,0.08800136,0.023075871,,,,,,,,,,,,,,,,, -55200,0.09589562,0.022330929,,,,,,,,,,,,,,,,, -55300,0.08371707,0.022376932,,,,,,,,,,,,,,,,, -55400,0.10156616,0.023602229,,,,,,,,,,,,,,,,, -55500,0.08847153,0.022171887,,,,,,,,,,,,,,,,, -55600,0.08489247,0.022639168,,,,,,,,,,,,,,,,, -55620,,,0.9977280497550964,0.0094415247440338,0.886361785591973,0.9860274791717528,0.0599618293344974,0.2274789746190675,43793.0,0.9850918054580688,0.0635327771306037,0.2129513821616357,43793.0,17778.506385326385,26565.708988189697,17778.506385326385,8782.897083044052,2.778715848922729,0.0 -55700,0.08138003,0.02214327,,,,,,,,,,,,,,,,, -55800,0.08411133,0.02256832,,,,,,,,,,,,,,,,, -55900,0.083255015,0.022508739,,,,,,,,,,,,,,,,, -56000,0.076263584,0.02168837,,,,,,,,,,,,,,,,, -56100,0.08825501,0.02283138,,,,,,,,,,,,,,,,, -56200,0.07980229,0.022851622,,,,,,,,,,,,,,,,, -56300,0.09355178,0.02221301,,,,,,,,,,,,,,,,, -56372,,,0.9974233508110046,0.0099036889150738,0.8731948836777758,0.9860575199127196,0.0605407804250717,0.2240677321307488,43793.0,0.9851629734039308,0.064345933496952,0.210049108776192,43793.0,18018.62752890587,26921.62234067917,18018.62752890587,8898.630964279175,2.816796064376831,0.0 -56400,0.08595392,0.022085264,,,,,,,,,,,,,,,,, -56500,0.09087035,0.022926537,,,,,,,,,,,,,,,,, -56600,0.10599432,0.021551438,,,,,,,,,,,,,,,,, -56700,0.09124279,0.022029243,,,,,,,,,,,,,,,,, -56800,0.08594089,0.022553477,,,,,,,,,,,,,,,,, -56900,0.082991585,0.021773046,,,,,,,,,,,,,,,,, -57000,0.08286864,0.023208736,,,,,,,,,,,,,,,,, -57100,0.09310004,0.022300873,,,,,,,,,,,,,,,,, -57128,,,0.9969244599342346,0.0106557002291083,0.8580153795451726,0.986025869846344,0.060913011431694,0.2241567975538667,43793.0,0.9850812554359436,0.0648852512240409,0.2135274376320938,43793.0,18258.6579618454,27272.670634508133,18258.6579618454,9009.587788581848,2.856903314590454,0.0 -57200,0.07691746,0.02166561,,,,,,,,,,,,,,,,, -57300,0.0932962,0.021754265,,,,,,,,,,,,,,,,, -57400,0.08578686,0.022610517,,,,,,,,,,,,,,,,, -57500,0.08124464,0.021694787,,,,,,,,,,,,,,,,, -57600,0.094319336,0.022741506,,,,,,,,,,,,,,,,, -57700,0.10259048,0.022981297,,,,,,,,,,,,,,,,, -57800,0.07072325,0.021754729,,,,,,,,,,,,,,,,, -57807,,,,,,,,,,,,,,18477.262572288513,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/eval_measurements.csv deleted file mode 100644 index c28ac0160..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -112.40478038787842,0.0,11.536062002182009,1,0,11.536062002182009,0.3947484791278839,0.7956756353378296,0.0263854595591664,43793,123.94088912010191,0.3884694278240204,0.7994099855422974,0.0218700491084387,0.3926452994346618,0.7974664568901062,0.0248372634449589,43793 -224.29861760139465,0.024303913116455,251.66132426261905,752,0,251.66132426261905,0.983142077922821,0.0816964283585548,0.0397184025301055,43793,476.0043268203736,0.986751914024353,0.071010872721672,0.0352122054391893,0.9841179251670836,0.0789842307567596,0.0382685972028179,43793 -339.94337153434753,0.0521657466888427,491.7242970466614,1500,0,491.7242970466614,0.9836398959159852,0.0608634054660797,0.0885365779737749,43793,831.7605443000793,0.9871714115142822,0.0486308261752128,0.0883432783390799,0.9846265912055968,0.0575493611395359,0.090297576243713,43793 -450.96800208091736,0.079129934310913,731.9238619804382,2248,0,731.9238619804382,0.9841011166572572,0.0555038340389728,0.1326923946081284,43793,1183.0322844982147,0.9878830909729004,0.0436494052410125,0.1356913549852444,0.9850901365280152,0.052616462111473,0.1338274383866124,43793 -559.0434498786926,0.1082615852355957,972.1416659355164,2995,0,972.1416659355164,0.9844043850898744,0.0541264489293098,0.1550618006766275,43793,1531.3750817775726,0.988243281841278,0.0413642711937427,0.1645979659390785,0.9853106141090392,0.0512285642325878,0.1607643725562286,43793 -669.6453518867493,0.1363131999969482,1212.3301203250885,3746,0,1212.3301203250885,0.984574556350708,0.0529424399137496,0.1661460566473754,43793,1882.2137567996976,0.9883965849876404,0.0402144491672515,0.1835936500380219,0.985468089580536,0.050165306776762,0.1661972704844983,43793 -783.5995593070984,0.1643040180206298,1452.4844810962677,4495,0,1452.4844810962677,0.9847337603569032,0.0519829653203487,0.1837031279984529,43793,2236.3710923194885,0.9884520173072816,0.0397511273622512,0.2036303545138438,0.9856345057487488,0.0494197458028793,0.1831442118251721,43793 -898.0833630561829,0.1930367946624755,1692.6457846164703,5238,0,1692.6457846164703,0.9850395321846008,0.0501854233443737,0.1966521375545678,43793,2591.0650255680084,0.9888370633125304,0.0381620749831199,0.2267099895853587,0.9859211444854736,0.0475562997162342,0.1990976778900939,43793 -1013.100502729416,0.22141695022583,1932.7427368164065,5982,0,1932.7427368164065,0.9850606322288512,0.0498638488352298,0.2016183392127283,43793,2946.22796201706,0.9888089299201964,0.037779688835144,0.2349900976866167,0.98591947555542,0.047309760004282,0.2019110328460734,43793 -1124.8570773601532,0.2525761127471924,2172.955377101898,6722,0,2172.955377101898,0.9852674007415771,0.0491311475634574,0.2183474994479768,43793,3298.24889087677,0.9892280697822572,0.0368048213422298,0.273699303325344,0.9860668778419496,0.0467587560415267,0.2127030918890502,43793 -1239.154440164566,0.2822961807250976,2413.086009979248,7469,0,2413.086009979248,0.9851044416427612,0.0503110401332378,0.2082329961052818,43793,3652.726801872253,0.9891047477722168,0.0365598015487194,0.2647126914659687,0.9860149025917052,0.0473848022520542,0.2152321516486376,43793 -1350.2171156406405,0.3105733394622803,2653.050799131393,8209,0,2653.050799131393,0.9853959083557128,0.0490268431603908,0.2199700243822788,43793,4003.8033468723297,0.9893330335617064,0.0359494984149932,0.2700580912185201,0.9862012267112732,0.0464534796774387,0.2239962696330975,43793 -1465.2301201820374,0.3409047126770019,2893.292719125748,8959,0,2893.292719125748,0.9855567812919616,0.0479598231613636,0.2320619693990359,43793,4359.108901500702,0.989347517490387,0.0358376614749431,0.2764391488246033,0.9863911867141724,0.0454035401344299,0.2295555249097951,43793 -1577.3784453868866,0.3689651489257812,3133.352974653244,9715,0,3133.352974653244,0.9855496287345886,0.048551145941019,0.2278495689342686,43793,4711.366160392761,0.9895155429840088,0.0352685004472732,0.2899146334264587,0.9863510131835938,0.0459966510534286,0.2286938271593316,43793 -1692.609769821167,0.397641658782959,3373.53779745102,10451,0,3373.53779745102,0.9855100512504578,0.0483393445611,0.2322145265859602,43793,5066.831092119217,0.9897177219390868,0.0344866290688514,0.3137916619433596,0.9864187836647034,0.0456844270229339,0.2386852578627625,43793 -1801.215446472168,0.4285142421722412,3613.533330202103,11196,0,3613.533330202103,0.9856237769126892,0.0480704680085182,0.244828373988002,43793,5415.48432302475,0.9898854494094848,0.033528134226799,0.3287181087338847,0.9863871335983276,0.0453821495175361,0.2417079907701178,43793 -1914.5718655586245,0.4614851474761963,3853.5349090099335,11934,0,3853.5349090099335,0.9855976104736328,0.0480018183588981,0.232307527215203,43793,5768.89560508728,0.9900333285331726,0.033205009996891,0.3330144534538547,0.9865024089813232,0.0452807024121284,0.2396226332273137,43793 -2022.8337025642395,0.4917325973510742,4093.756762266159,12685,0,4093.756762266159,0.9856582880020142,0.0479451939463615,0.2415683202863619,43793,6117.430259227753,0.9902761578559875,0.0320247150957584,0.3787546138420707,0.9864386916160583,0.0451642982661724,0.2510761793891903,43793 -2134.716328382492,0.5204606056213379,4333.850111484528,13433,0,4333.850111484528,0.9858389496803284,0.0473495759069919,0.2462638447057491,43793,6469.4553780555725,0.9904384613037108,0.0314989201724529,0.3717057181802991,0.9866668581962584,0.0446271151304245,0.2472126803363637,43793 -2245.7474250793457,0.549689769744873,4573.89826631546,14178,0,4573.89826631546,0.9856254458427428,0.0479994229972362,0.2504320785663486,43793,6820.58492398262,0.9904738664627076,0.0314284339547157,0.3734624472438089,0.9864821434020996,0.0451357886195182,0.2507288477370744,43793 -2355.645773649216,0.5813260078430176,4813.926500082016,14921,0,4813.926500082016,0.9858343601226808,0.0468188151717186,0.2576561377273499,43793,7170.563362598419,0.9905796051025392,0.0311626549810171,0.3785872396635164,0.9866968989372252,0.0440506637096405,0.2621364522054176,43793 -2463.657342433929,0.611168384552002,5053.890505075455,15673,0,5053.890505075455,0.9859552383422852,0.0467160493135452,0.2549108661049648,43793,7518.589793205261,0.990501582622528,0.0315143540501594,0.3703973020116818,0.986763060092926,0.0441457703709602,0.2611574499676861,43793 -2574.221416711808,0.6401748657226562,5293.863883733749,16426,0,5293.863883733749,0.9858486652374268,0.0467076227068901,0.2628098371482614,43793,7869.176714420319,0.9906003475189208,0.0310765262693166,0.3902105435847471,0.9866453409194946,0.0442377552390098,0.2586473980508823,43793 -2687.1701986789703,0.6697847843170166,5533.909998893738,17178,0,5533.909998893738,0.9858617186546326,0.0468544140458107,0.255345915528305,43793,8222.222283363342,0.9905943274497986,0.0308886766433715,0.3930183406771667,0.9867305755615234,0.0441368259489536,0.2651857368659311,43793 -2798.502618312836,0.699364423751831,5773.913890838623,17930,0,5773.913890838623,0.9858819246292114,0.046987097710371,0.2571229912207292,43793,8573.608848333359,0.9906186461448668,0.0308266170322895,0.3925631480101053,0.9867184162139891,0.0444091372191906,0.2627921188302755,43793 -2903.7837584018707,0.73091721534729,6013.996770858765,18681,0,6013.996770858765,0.9858953952789308,0.0472092404961586,0.2572136924916275,43793,8919.025218486786,0.9908760190010072,0.0300082452595233,0.4102506283303613,0.986647367477417,0.0445027761161327,0.2655605888096022,43793 -3019.294206380844,0.7621853351593018,6254.042366743088,19432,0,6254.042366743088,0.985889494419098,0.0474565178155899,0.2520375820148166,43793,9274.63326048851,0.9908602237701416,0.0298089552670717,0.4269604107545199,0.9866859316825868,0.0446731783449649,0.2662054523909674,43793 -3128.8182978630066,0.7937026023864746,6494.278495073319,20175,0,6494.278495073319,0.9858208894729614,0.0472513772547245,0.2518622757509492,43793,9624.445745706558,0.9910188317298888,0.0289773251861333,0.4231962341096236,0.9867512583732604,0.0445153228938579,0.2637522492694075,43793 -3240.6934225559235,0.8236665725708008,6734.232186079025,20920,0,6734.232186079025,0.985737442970276,0.0471720993518829,0.2560707943968566,43793,9976.325129032137,0.9911974668502808,0.0287505108863115,0.4392576285815208,0.9865494966506958,0.0446839854121208,0.2564115413675414,43793 -3348.290340423584,0.854386568069458,6974.358366250992,21667,0,6974.358366250992,0.9859662055969238,0.0467855930328369,0.2591121619136606,43793,10324.099648237228,0.9909592270851136,0.0293954890221357,0.4405538692550107,0.9868373274803162,0.0441310405731201,0.2741253771401372,43793 -3455.4900765419006,0.8864257335662842,7214.402206897736,22414,0,7214.402206897736,0.9860116839408876,0.0476753339171409,0.255349596446049,43793,10671.39572572708,0.9908499121665956,0.0297807045280933,0.4160635084609168,0.9868572354316713,0.0448806248605251,0.2624296437397281,43793 -3569.4941403865814,0.9170982837677002,7454.508299827576,23167,0,7454.508299827576,0.9858132600784302,0.0473533831536769,0.2534655795057701,43793,11025.55727314949,0.9909346103668212,0.0296036303043365,0.4199905114907307,0.9866583347320556,0.0445701442658901,0.2730043897985921,43793 -3678.099539041519,0.9474050998687744,7694.484705686569,23921,0,7694.484705686569,0.9859514236450196,0.0475130304694175,0.2569773068587559,43793,11374.190093517303,0.9909778833389282,0.0295232720673084,0.4255365867503188,0.9867942929267884,0.0446453690528869,0.2707242531647493,43793 -3784.543060064316,0.9804179668426514,7934.573546886444,24666,0,7934.573546886444,0.9860731363296508,0.0475754104554653,0.2620029472242,43793,11720.776325702667,0.9910315871238708,0.0290667042136192,0.4393065140763119,0.9868904948234558,0.0446149222552776,0.2751389474181981,43793 -3898.072008609772,1.0132358074188232,8174.771550655365,25413,0,8174.771550655365,0.9860171675682068,0.047292198985815,0.2669060818825145,43793,12074.557059288025,0.9910230040550232,0.0289991591125726,0.4279734510210868,0.986810564994812,0.0445956252515316,0.2719177220954323,43793 -4007.445623874664,1.0445480346679688,8414.809683322906,26165,0,8414.809683322906,0.9857838153839112,0.0471026189625263,0.258419016985112,43793,12424.020931243896,0.9912290573120116,0.0283614490181207,0.4525934288065413,0.9866153001785278,0.0446678176522254,0.2719697859873213,43793 -4116.604898691177,1.077523708343506,8654.877965211868,26920,0,8654.877965211868,0.985987663269043,0.0471849218010902,0.2646887047154832,43793,12773.30180120468,0.991524875164032,0.0273654703050851,0.4807769372003801,0.9869144558906556,0.0444242358207702,0.2759751459481721,43793 -4223.421566724777,1.1091325283050537,8895.08589887619,27672,0,8895.08589887619,0.9859206676483154,0.0474046319723129,0.2627906059002099,43793,13120.378950834274,0.9917396306991576,0.0267806220799684,0.487505205329753,0.9867861866950988,0.0446733795106411,0.2727244827693121,43793 -4327.597607374191,1.1407244205474854,9135.200062274933,28419,0,9135.200062274933,0.9859859943389891,0.0476726926863193,0.2642261563246968,43793,13464.72099852562,0.9916260838508606,0.0269398819655179,0.4857912710610084,0.9868592619895936,0.0448070652782917,0.2716923961335559,43793 -4430.912664890289,1.174621820449829,9375.395035743712,29173,0,9375.395035743712,0.985869288444519,0.0475153736770153,0.2566216951670911,43793,13808.285351991652,0.991509735584259,0.0274159666150808,0.4712679039979643,0.9867748022079468,0.0447444543242454,0.2703691700265554,43793 -4536.631418704987,1.2075426578521729,9615.55433011055,29924,0,9615.55433011055,0.985859215259552,0.0479270778596401,0.260950514393316,43793,14154.216762781143,0.9912999868392944,0.0282875876873731,0.4548449958025577,0.9867565631866456,0.0449810102581977,0.2710515862463466,43793 -4643.205877542496,1.239299774169922,9855.597105264664,30677,0,9855.597105264664,0.9858810901641846,0.0478551611304283,0.2620535742729729,43793,14500.886621952057,0.9913409948349,0.0280736107379198,0.4493626684822014,0.9867821335792542,0.044753324240446,0.2723217587699308,43793 -4750.179671287537,1.2727985382080078,10095.629689216614,31434,0,10095.629689216614,0.985932469367981,0.0474705062806606,0.2614179156552181,43793,14847.947264671326,0.9913233518600464,0.0280889272689819,0.4580684687707109,0.9867622256278992,0.0449602045118808,0.2680808067314692,43793 -4859.741172552109,1.325857400894165,10335.560795545578,32188,0,10335.560795545578,0.9858309626579284,0.0481604635715484,0.261548027310026,43793,15197.513407230375,0.9914279580116272,0.0274613872170448,0.4688236204403521,0.986764669418335,0.0450935922563076,0.2805715467170202,43793 -4967.086313724518,1.359889268875122,10575.767451763151,32936,0,10575.767451763151,0.9859328866004944,0.0476415865123271,0.2580936874518775,43793,15545.120171785356,0.9914365410804749,0.0274828933179378,0.4736294815094482,0.986638844013214,0.0448877029120922,0.2706954798470259,43793 -5078.3935606479645,1.392927169799805,10815.805988311768,33684,0,10815.805988311768,0.9859544038772584,0.0479418747127056,0.2566208483574557,43793,15896.519999742508,0.991641879081726,0.0267317239195108,0.4813694209161578,0.9868279695510864,0.0449778325855731,0.2782887354386499,43793 -5183.58104300499,1.426476001739502,11055.845739364624,34436,0,11055.845739364624,0.985964059829712,0.0477706789970397,0.267297587830477,43793,16241.80126285553,0.9919087290763856,0.0259376429021358,0.505223532274534,0.9868279695510864,0.0449526906013488,0.2817446194608765,43793 -5296.110832691193,1.461381196975708,11296.034619808195,35189,0,11296.034619808195,0.9858478307724,0.0475787110626697,0.268802852899132,43793,16594.574902057648,0.9923018217086792,0.0247634388506412,0.5400334925387037,0.9866660237312316,0.0450561232864856,0.2804774655869647,43793 -5409.861083984375,1.496187686920166,11536.095036268234,35940,0,11536.095036268234,0.9859143495559692,0.0483407154679298,0.2643073077441462,43793,16948.44060611725,0.9923170804977416,0.0245520714670419,0.5424872716795479,0.9867898225784302,0.0454942397773265,0.2835340130049796,43793 -5518.131883144379,1.530383825302124,11776.166133642197,36692,0,11776.166133642197,0.9858882427215576,0.0482158660888671,0.2715438217701766,43793,17296.837022542953,0.9918184876441956,0.0261344145983457,0.5122857286942415,0.9867236614227296,0.0451662205159664,0.2817615748781584,43793 -5627.343741178513,1.5651028156280518,12016.210379838943,37454,0,12016.210379838943,0.9857577085494996,0.0483944676816463,0.2608674631916036,43793,17646.148535490036,0.9918424487113952,0.0261771418154239,0.4943024589468207,0.9865970015525818,0.0453442595899105,0.2697746331850882,43793 -5733.825747728348,1.5990040302276611,12256.275754451752,38209,0,12256.275754451752,0.985753059387207,0.048491571098566,0.2581400577281293,43793,17992.750111341476,0.991769313812256,0.0264698714017868,0.4968405851355063,0.9866034984588624,0.0454076938331127,0.2799366756800923,43793 -5839.652360200882,1.6322979927062988,12496.309900045397,38963,0,12496.309900045397,0.9859619736671448,0.0483974255621433,0.2601541778417228,43793,18338.66453528404,0.9920052289962769,0.0257331114262342,0.4961674082046797,0.9868503212928772,0.0452651977539062,0.2786514717710018,43793 -5947.527049303055,1.6655912399291992,12736.541811704636,39723,0,12736.541811704636,0.9859986305236816,0.048444353044033,0.2660761208561459,43793,18686.82477331161,0.9920634031295776,0.0253645367920398,0.5136206193556406,0.9867849946022034,0.0453127808868885,0.2773486024374958,43793 -6054.862420797348,1.6991665363311768,12976.76999258995,40479,0,12976.76999258995,0.9859328866004944,0.0487171821296215,0.2655723199105637,43793,19034.442306518555,0.9920461177825928,0.025316696614027,0.527780740298631,0.986867368221283,0.0455422848463058,0.2802686791438417,43793 -6163.554557561874,1.7342724800109863,13216.979594230652,41240,0,13216.979594230652,0.985854148864746,0.0483081564307212,0.2637783218790555,43793,19383.399483442307,0.9922798275947572,0.0245455391705036,0.524614840843699,0.9867496490478516,0.0452902019023895,0.283905575901221,43793 -6265.874538183212,1.769547462463379,13457.06114768982,41996,0,13457.06114768982,0.9858916401863098,0.0483588948845863,0.2650313659736344,43793,19725.856940746307,0.9924442172050476,0.023939685896039,0.5495143535314464,0.9866639971733092,0.0453347228467464,0.2868209618737811,43793 -6376.554070949554,1.804668664932251,13697.11407327652,42753,0,13697.11407327652,0.9859809279441832,0.0491667203605175,0.2701141743328484,43793,20076.64497256279,0.992691457271576,0.023066472262144,0.5733023445192826,0.9868162274360656,0.0458563193678855,0.2811874228785678,43793 -6487.12228512764,1.838754415512085,13937.296043395996,43508,0,13937.296043395996,0.9859379529953004,0.0492123737931251,0.2669640700896823,43793,20427.44962954521,0.9930054545402528,0.0222231112420558,0.5976667659929961,0.9868007898330688,0.0461998917162418,0.28058899499968,43793 -6594.789473056793,1.875447988510132,14177.254234313965,44262,0,14177.254234313965,0.985765278339386,0.0493208542466163,0.2665610301084667,43793,20775.13205766678,0.9929794073104858,0.0224257409572601,0.5909267678181348,0.9865661859512328,0.0464385598897933,0.2845100129284141,43793 -6702.926654815674,1.910888671875,14417.521105766296,45013,0,14417.521105766296,0.9858819246292114,0.0496662817895412,0.2710117089890043,43793,21123.5917699337,0.9926031827926636,0.0235844925045967,0.553874841307491,0.9867179989814758,0.0464133955538272,0.2796055275831616,43793 -6812.935254096985,1.94553542137146,14657.61928486824,45768,0,14657.61928486824,0.9857400059700012,0.0496977865695953,0.2654268209684074,43793,21473.75328922272,0.9925385117530824,0.0236756782978773,0.5445353539343192,0.9865588545799256,0.0466780699789524,0.2796525563782617,43793 -6915.363301515579,1.981616258621216,14897.71324658394,46528,0,14897.71324658394,0.9858840703964232,0.0497120022773742,0.2653429253234877,43793,21816.33208823204,0.9925999045372008,0.0235732309520244,0.5610417006179397,0.9867228865623474,0.0465218871831893,0.2785722170901397,43793 -7024.628471374512,2.017472982406616,15137.88663125038,47283,0,15137.88663125038,0.9857581257820128,0.0503697209060192,0.2608815513531671,43793,22165.82682275772,0.992514193058014,0.0235684197396039,0.5593789092608922,0.9866579174995422,0.0471155829727649,0.2736626211202902,43793 -7127.328198194504,2.0556468963623047,15377.85105419159,48037,0,15377.85105419159,0.9855588674545288,0.0503176636993885,0.2622498398743667,43793,22508.54941940308,0.992720901966095,0.022959679365158,0.5675811443116955,0.9865494966506958,0.0468745082616806,0.2781233637894734,43793 -7229.995413780212,2.0907273292541504,15618.037787914276,48792,0,15618.037787914276,0.985990583896637,0.0508939661085605,0.2647443093865677,43793,22851.45850634575,0.9929220080375672,0.0219717137515544,0.5988421261644996,0.986821472644806,0.0474879816174507,0.2777354293557749,43793 -7335.499788284302,2.4467310905456543,15857.938402414322,49545,0,15857.938402414322,0.9858722686767578,0.0508918836712837,0.2633965535200139,43793,23197.240000247955,0.9931709170341492,0.0213091485202312,0.6088010839997859,0.986845850944519,0.0474283695220947,0.2821814441667886,43793 -7442.966633319855,2.483210563659668,16098.165282726288,50299,0,16098.165282726288,0.9857922196388244,0.0507895275950431,0.2697486751280569,43793,23544.99045753479,0.99346262216568,0.0204917322844266,0.6304601455926362,0.9867045879364014,0.0474251732230186,0.2834991495572792,43793 -7546.960394859314,2.519238471984864,16338.194887399672,51055,0,16338.194887399672,0.9858272075653076,0.0514804720878601,0.2618129902030606,43793,23889.070935964584,0.9939581155776978,0.0190504807978868,0.6652305154089658,0.9866855144500732,0.0482732728123664,0.2792698083358352,43793 -7652.790205001831,2.555806398391724,16578.26107263565,51814,0,16578.26107263565,0.9856536388397216,0.0521537996828556,0.2611147893670439,43793,24235.02452325821,0.9940596222877502,0.0188466794788837,0.6622694237726308,0.9865750670433044,0.0486783199012279,0.2745682719653072,43793 -7761.37190580368,2.5926008224487305,16818.50644493103,52572,0,16818.50644493103,0.9855180382728576,0.0523321330547332,0.260130659568339,43793,24583.90949845314,0.9937030076980592,0.0197627544403076,0.6384974196668561,0.9864805340766908,0.0489682592451572,0.276152484739644,43793 -7869.87908744812,2.6299984455108643,17058.707488775253,53324,0,17058.707488775253,0.9857269525527954,0.0524364113807678,0.2673185343252118,43793,24932.676094055176,0.993492603302002,0.0203177742660045,0.6140696882882921,0.9865548014640808,0.0489306151866912,0.2848737177843266,43793 -7976.645930767059,2.6677396297454834,17298.7577586174,54079,0,17298.7577586174,0.985632598400116,0.0533222369849681,0.2601269322421878,43793,25279.552406549454,0.9934467077255248,0.0203902274370193,0.6370068629099701,0.9865905046463012,0.0495647341012954,0.2744399979478792,43793 -8083.9382219314575,2.706973075866699,17538.758221387863,54833,0,17538.758221387863,0.9855727553367616,0.0537689179182052,0.260429199675821,43793,25626.90543746948,0.9934847354888916,0.0201917085796594,0.6408067088189375,0.9866063594818116,0.0500474199652671,0.2746520933144069,43793 -8191.948867797852,2.743889093399048,17778.90761089325,55585,0,17778.90761089325,0.9853209257125854,0.05393723025918,0.257986693851216,43793,25975.122877836227,0.9934579133987428,0.0202217083424329,0.6336687988840867,0.986279547214508,0.0502303689718246,0.2801965927562672,43793 -8297.136292934418,2.7808399200439453,18019.107362270355,56345,0,18019.107362270355,0.9856279492378236,0.0548713281750679,0.2624462802706482,43793,26320.5678255558,0.9935694336891174,0.0196368563920259,0.6439595382453188,0.9865819811820984,0.0509174615144729,0.2785519321633012,43793 -8407.226685523987,2.8178892135620117,18259.057409524918,57096,0,18259.057409524918,0.985390841960907,0.05514898896217346,0.25607921974810816,43793,26670.66596722603,0.9941542148590088,0.017990222200751305,0.6795010598582831,0.9863534569740295,0.051240529865026474,0.2821614791012339,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/measurements.csv deleted file mode 100644 index b701026f9..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/measurements.csv +++ /dev/null @@ -1,657 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.1811905,0.7991228,,,,,,,,,,,,,,,,, -1,,,0.3884694278240204,0.7994099855422974,0.0218700491084387,0.3926452994346618,0.7974664568901062,0.0248372634449589,43793.0,0.3947484791278839,0.7956756353378296,0.0263854595591664,43793.0,11.536062002182009,123.94088912010191,11.536062002182009,112.40478038787842,0.0,0.0 -100,0.7057593,0.5042978,,,,,,,,,,,,,,,,, -200,0.41106924,0.36387348,,,,,,,,,,,,,,,,, -300,0.30048,0.2609474,,,,,,,,,,,,,,,,, -400,0.20156635,0.17559423,,,,,,,,,,,,,,,,, -500,0.12518756,0.12303498,,,,,,,,,,,,,,,,, -600,0.08223224,0.090225615,,,,,,,,,,,,,,,,, -700,0.09540457,0.086242095,,,,,,,,,,,,,,,,, -752,,,0.986751914024353,0.071010872721672,0.0352122054391893,0.9841179251670836,0.0789842307567596,0.0382685972028179,43793.0,0.983142077922821,0.0816964283585548,0.0397184025301055,43793.0,251.66132426261905,476.0043268203736,251.66132426261905,224.29861760139465,0.024303913116455,0.0 -800,0.1434681,0.061537087,,,,,,,,,,,,,,,,, -900,0.035710543,0.062647775,,,,,,,,,,,,,,,,, -1000,0.107833736,0.061865002,,,,,,,,,,,,,,,,, -1100,0.06522166,0.054914877,,,,,,,,,,,,,,,,, -1200,0.094500706,0.048176788,,,,,,,,,,,,,,,,, -1300,0.07995064,0.045516375,,,,,,,,,,,,,,,,, -1400,0.13089514,0.0436236,,,,,,,,,,,,,,,,, -1500,,,0.9871714115142822,0.0486308261752128,0.0883432783390799,0.9846265912055968,0.0575493611395359,0.090297576243713,43793.0,0.9836398959159852,0.0608634054660797,0.0885365779737749,43793.0,491.7242970466614,831.7605443000793,491.7242970466614,339.94337153434753,0.0521657466888427,0.0 -1500,0.14762637,0.05370939,,,,,,,,,,,,,,,,, -1600,0.13772489,0.05157917,,,,,,,,,,,,,,,,, -1700,0.2320506,0.04844257,,,,,,,,,,,,,,,,, -1800,0.14592801,0.0463063,,,,,,,,,,,,,,,,, -1900,0.0645061,0.043713413,,,,,,,,,,,,,,,,, -2000,0.216361,0.04844293,,,,,,,,,,,,,,,,, -2100,0.14563128,0.0419148,,,,,,,,,,,,,,,,, -2200,0.22893466,0.04375143,,,,,,,,,,,,,,,,, -2248,,,0.9878830909729004,0.0436494052410125,0.1356913549852444,0.9850901365280152,0.052616462111473,0.1338274383866124,43793.0,0.9841011166572572,0.0555038340389728,0.1326923946081284,43793.0,731.9238619804382,1183.0322844982147,731.9238619804382,450.96800208091736,0.079129934310913,0.0 -2300,0.23404182,0.046754815,,,,,,,,,,,,,,,,, -2400,0.15276542,0.04521217,,,,,,,,,,,,,,,,, -2500,0.27293447,0.042996004,,,,,,,,,,,,,,,,, -2600,0.078587696,0.04505481,,,,,,,,,,,,,,,,, -2700,0.114036076,0.04512944,,,,,,,,,,,,,,,,, -2800,0.080361046,0.040200666,,,,,,,,,,,,,,,,, -2900,0.11365886,0.043803066,,,,,,,,,,,,,,,,, -2995,,,0.988243281841278,0.0413642711937427,0.1645979659390785,0.9853106141090392,0.0512285642325878,0.1607643725562286,43793.0,0.9844043850898744,0.0541264489293098,0.1550618006766275,43793.0,972.1416659355164,1531.3750817775726,972.1416659355164,559.0434498786926,0.1082615852355957,0.0 -3000,0.12604824,0.044116322,,,,,,,,,,,,,,,,, -3100,0.058770336,0.03904752,,,,,,,,,,,,,,,,, -3200,0.05175565,0.04119438,,,,,,,,,,,,,,,,, -3300,0.1096466,0.036647256,,,,,,,,,,,,,,,,, -3400,0.0983598,0.039434172,,,,,,,,,,,,,,,,, -3500,0.052620348,0.039065808,,,,,,,,,,,,,,,,, -3600,0.061288983,0.039773714,,,,,,,,,,,,,,,,, -3700,0.091826506,0.037564207,,,,,,,,,,,,,,,,, -3746,,,0.9883965849876404,0.0402144491672515,0.1835936500380219,0.985468089580536,0.050165306776762,0.1661972704844983,43793.0,0.984574556350708,0.0529424399137496,0.1661460566473754,43793.0,1212.3301203250885,1882.2137567996976,1212.3301203250885,669.6453518867493,0.1363131999969482,0.0 -3800,0.085494176,0.042595845,,,,,,,,,,,,,,,,, -3900,0.057988845,0.03780406,,,,,,,,,,,,,,,,, -4000,0.07737998,0.037195243,,,,,,,,,,,,,,,,, -4100,0.09317471,0.04409948,,,,,,,,,,,,,,,,, -4200,0.092340514,0.04300261,,,,,,,,,,,,,,,,, -4300,0.080712706,0.036959156,,,,,,,,,,,,,,,,, -4400,0.08807638,0.033998083,,,,,,,,,,,,,,,,, -4495,,,0.9884520173072816,0.0397511273622512,0.2036303545138438,0.9856345057487488,0.0494197458028793,0.1831442118251721,43793.0,0.9847337603569032,0.0519829653203487,0.1837031279984529,43793.0,1452.4844810962677,2236.3710923194885,1452.4844810962677,783.5995593070984,0.1643040180206298,0.0 -4500,0.04749077,0.03918071,,,,,,,,,,,,,,,,, -4600,0.116685785,0.040283114,,,,,,,,,,,,,,,,, -4700,0.08741747,0.03960416,,,,,,,,,,,,,,,,, -4800,0.0474846,0.037833333,,,,,,,,,,,,,,,,, -4900,0.04686567,0.041523468,,,,,,,,,,,,,,,,, -5000,0.08243224,0.042110607,,,,,,,,,,,,,,,,, -5100,0.059452645,0.036951855,,,,,,,,,,,,,,,,, -5200,0.06714585,0.037238266,,,,,,,,,,,,,,,,, -5238,,,0.9888370633125304,0.0381620749831199,0.2267099895853587,0.9859211444854736,0.0475562997162342,0.1990976778900939,43793.0,0.9850395321846008,0.0501854233443737,0.1966521375545678,43793.0,1692.6457846164703,2591.0650255680084,1692.6457846164703,898.0833630561829,0.1930367946624755,0.0 -5300,0.069339775,0.041418962,,,,,,,,,,,,,,,,, -5400,0.058363218,0.03336064,,,,,,,,,,,,,,,,, -5500,0.051600087,0.043777164,,,,,,,,,,,,,,,,, -5600,0.063140355,0.03639585,,,,,,,,,,,,,,,,, -5700,0.0532385,0.03783676,,,,,,,,,,,,,,,,, -5800,0.041125976,0.03599831,,,,,,,,,,,,,,,,, -5900,0.09778523,0.035234824,,,,,,,,,,,,,,,,, -5982,,,0.9888089299201964,0.037779688835144,0.2349900976866167,0.98591947555542,0.047309760004282,0.2019110328460734,43793.0,0.9850606322288512,0.0498638488352298,0.2016183392127283,43793.0,1932.7427368164065,2946.22796201706,1932.7427368164065,1013.100502729416,0.22141695022583,0.0 -6000,0.04035511,0.038758304,,,,,,,,,,,,,,,,, -6100,0.04869773,0.039822258,,,,,,,,,,,,,,,,, -6200,0.049050376,0.03468845,,,,,,,,,,,,,,,,, -6300,0.044863906,0.035893265,,,,,,,,,,,,,,,,, -6400,0.034950495,0.04267629,,,,,,,,,,,,,,,,, -6500,0.043623403,0.03916313,,,,,,,,,,,,,,,,, -6600,0.04995483,0.03588019,,,,,,,,,,,,,,,,, -6700,0.035456084,0.036991697,,,,,,,,,,,,,,,,, -6722,,,0.9892280697822572,0.0368048213422298,0.273699303325344,0.9860668778419496,0.0467587560415267,0.2127030918890502,43793.0,0.9852674007415771,0.0491311475634574,0.2183474994479768,43793.0,2172.955377101898,3298.24889087677,2172.955377101898,1124.8570773601532,0.2525761127471924,0.0 -6800,0.037432604,0.034746896,,,,,,,,,,,,,,,,, -6900,0.06845246,0.039204728,,,,,,,,,,,,,,,,, -7000,0.03269692,0.037739933,,,,,,,,,,,,,,,,, -7100,0.03240307,0.033569384,,,,,,,,,,,,,,,,, -7200,0.047594365,0.03467738,,,,,,,,,,,,,,,,, -7300,0.040782724,0.039403405,,,,,,,,,,,,,,,,, -7400,0.03617245,0.037453357,,,,,,,,,,,,,,,,, -7469,,,0.9891047477722168,0.0365598015487194,0.2647126914659687,0.9860149025917052,0.0473848022520542,0.2152321516486376,43793.0,0.9851044416427612,0.0503110401332378,0.2082329961052818,43793.0,2413.086009979248,3652.726801872253,2413.086009979248,1239.154440164566,0.2822961807250976,0.0 -7500,0.034662616,0.037521005,,,,,,,,,,,,,,,,, -7600,0.04134075,0.037363153,,,,,,,,,,,,,,,,, -7700,0.03237778,0.036016252,,,,,,,,,,,,,,,,, -7800,0.04317868,0.033759177,,,,,,,,,,,,,,,,, -7900,0.029410634,0.03620121,,,,,,,,,,,,,,,,, -8000,0.032380223,0.03657479,,,,,,,,,,,,,,,,, -8100,0.044887014,0.041484997,,,,,,,,,,,,,,,,, -8200,0.038316082,0.03378678,,,,,,,,,,,,,,,,, -8209,,,0.9893330335617064,0.0359494984149932,0.2700580912185201,0.9862012267112732,0.0464534796774387,0.2239962696330975,43793.0,0.9853959083557128,0.0490268431603908,0.2199700243822788,43793.0,2653.050799131393,4003.8033468723297,2653.050799131393,1350.2171156406405,0.3105733394622803,0.0 -8300,0.029094871,0.034226697,,,,,,,,,,,,,,,,, -8400,0.037420034,0.035946492,,,,,,,,,,,,,,,,, -8500,0.042061754,0.03614368,,,,,,,,,,,,,,,,, -8600,0.032749794,0.035560053,,,,,,,,,,,,,,,,, -8700,0.02629581,0.03688923,,,,,,,,,,,,,,,,, -8800,0.034676462,0.037725266,,,,,,,,,,,,,,,,, -8900,0.030584162,0.03648229,,,,,,,,,,,,,,,,, -8959,,,0.989347517490387,0.0358376614749431,0.2764391488246033,0.9863911867141724,0.0454035401344299,0.2295555249097951,43793.0,0.9855567812919616,0.0479598231613636,0.2320619693990359,43793.0,2893.292719125748,4359.108901500702,2893.292719125748,1465.2301201820374,0.3409047126770019,0.0 -9000,0.04738119,0.03965415,,,,,,,,,,,,,,,,, -9100,0.034828205,0.03871112,,,,,,,,,,,,,,,,, -9200,0.03525715,0.033686407,,,,,,,,,,,,,,,,, -9300,0.034008488,0.03876976,,,,,,,,,,,,,,,,, -9400,0.041282978,0.037728537,,,,,,,,,,,,,,,,, -9500,0.030456658,0.03328562,,,,,,,,,,,,,,,,, -9600,0.03751879,0.03792925,,,,,,,,,,,,,,,,, -9700,0.033269893,0.038866404,,,,,,,,,,,,,,,,, -9715,,,0.9895155429840088,0.0352685004472732,0.2899146334264587,0.9863510131835938,0.0459966510534286,0.2286938271593316,43793.0,0.9855496287345886,0.048551145941019,0.2278495689342686,43793.0,3133.352974653244,4711.366160392761,3133.352974653244,1577.3784453868866,0.3689651489257812,0.0 -9800,0.0286453,0.032845,,,,,,,,,,,,,,,,, -9900,0.038124062,0.0352141,,,,,,,,,,,,,,,,, -10000,0.023816857,0.034196455,,,,,,,,,,,,,,,,, -10100,0.023330478,0.031704716,,,,,,,,,,,,,,,,, -10200,0.0364113,0.037041433,,,,,,,,,,,,,,,,, -10300,0.027191289,0.031535573,,,,,,,,,,,,,,,,, -10400,0.028736144,0.036351588,,,,,,,,,,,,,,,,, -10451,,,0.9897177219390868,0.0344866290688514,0.3137916619433596,0.9864187836647034,0.0456844270229339,0.2386852578627625,43793.0,0.9855100512504578,0.0483393445611,0.2322145265859602,43793.0,3373.53779745102,5066.831092119217,3373.53779745102,1692.609769821167,0.397641658782959,0.0 -10500,0.029269556,0.033510275,,,,,,,,,,,,,,,,, -10600,0.024365766,0.031973768,,,,,,,,,,,,,,,,, -10700,0.050757706,0.03170419,,,,,,,,,,,,,,,,, -10800,0.035137985,0.034485284,,,,,,,,,,,,,,,,, -10900,0.03272405,0.03433876,,,,,,,,,,,,,,,,, -11000,0.046463676,0.034938198,,,,,,,,,,,,,,,,, -11100,0.036751196,0.034846716,,,,,,,,,,,,,,,,, -11196,,,0.9898854494094848,0.033528134226799,0.3287181087338847,0.9863871335983276,0.0453821495175361,0.2417079907701178,43793.0,0.9856237769126892,0.0480704680085182,0.244828373988002,43793.0,3613.533330202103,5415.48432302475,3613.533330202103,1801.215446472168,0.4285142421722412,0.0 -11200,0.031483937,0.031081283,,,,,,,,,,,,,,,,, -11300,0.041487183,0.035127766,,,,,,,,,,,,,,,,, -11400,0.03153587,0.03430224,,,,,,,,,,,,,,,,, -11500,0.049172267,0.03513458,,,,,,,,,,,,,,,,, -11600,0.029444711,0.033952557,,,,,,,,,,,,,,,,, -11700,0.04760876,0.036093548,,,,,,,,,,,,,,,,, -11800,0.034057476,0.030656112,,,,,,,,,,,,,,,,, -11900,0.040500432,0.034750026,,,,,,,,,,,,,,,,, -11934,,,0.9900333285331726,0.033205009996891,0.3330144534538547,0.9865024089813232,0.0452807024121284,0.2396226332273137,43793.0,0.9855976104736328,0.0480018183588981,0.232307527215203,43793.0,3853.5349090099335,5768.89560508728,3853.5349090099335,1914.5718655586245,0.4614851474761963,0.0 -12000,0.04244617,0.035087533,,,,,,,,,,,,,,,,, -12100,0.066412754,0.0323646,,,,,,,,,,,,,,,,, -12200,0.035250615,0.033122763,,,,,,,,,,,,,,,,, -12300,0.03917625,0.03357372,,,,,,,,,,,,,,,,, -12400,0.027452175,0.03144184,,,,,,,,,,,,,,,,, -12500,0.05389898,0.03385537,,,,,,,,,,,,,,,,, -12600,0.041655596,0.03087946,,,,,,,,,,,,,,,,, -12685,,,0.9902761578559875,0.0320247150957584,0.3787546138420707,0.9864386916160583,0.0451642982661724,0.2510761793891903,43793.0,0.9856582880020142,0.0479451939463615,0.2415683202863619,43793.0,4093.756762266159,6117.430259227753,4093.756762266159,2022.8337025642395,0.4917325973510742,0.0 -12700,0.04173459,0.03624654,,,,,,,,,,,,,,,,, -12800,0.035315398,0.031120185,,,,,,,,,,,,,,,,, -12900,0.029267382,0.030816304,,,,,,,,,,,,,,,,, -13000,0.042557422,0.032599624,,,,,,,,,,,,,,,,, -13100,0.083294034,0.029755313,,,,,,,,,,,,,,,,, -13200,0.049558863,0.030424902,,,,,,,,,,,,,,,,, -13300,0.044243384,0.034436725,,,,,,,,,,,,,,,,, -13400,0.042811662,0.032312777,,,,,,,,,,,,,,,,, -13433,,,0.9904384613037108,0.0314989201724529,0.3717057181802991,0.9866668581962584,0.0446271151304245,0.2472126803363637,43793.0,0.9858389496803284,0.0473495759069919,0.2462638447057491,43793.0,4333.850111484528,6469.4553780555725,4333.850111484528,2134.716328382492,0.5204606056213379,0.0 -13500,0.05946672,0.034072775,,,,,,,,,,,,,,,,, -13600,0.05511935,0.037191812,,,,,,,,,,,,,,,,, -13700,0.039414514,0.028772075,,,,,,,,,,,,,,,,, -13800,0.08962145,0.030920593,,,,,,,,,,,,,,,,, -13900,0.07577498,0.03812842,,,,,,,,,,,,,,,,, -14000,0.08083256,0.032135807,,,,,,,,,,,,,,,,, -14100,0.048741728,0.032377787,,,,,,,,,,,,,,,,, -14178,,,0.9904738664627076,0.0314284339547157,0.3734624472438089,0.9864821434020996,0.0451357886195182,0.2507288477370744,43793.0,0.9856254458427428,0.0479994229972362,0.2504320785663486,43793.0,4573.89826631546,6820.58492398262,4573.89826631546,2245.7474250793457,0.549689769744873,0.0 -14200,0.049719762,0.03775465,,,,,,,,,,,,,,,,, -14300,0.062508926,0.031991515,,,,,,,,,,,,,,,,, -14400,0.070793256,0.03751506,,,,,,,,,,,,,,,,, -14500,0.046812583,0.030918563,,,,,,,,,,,,,,,,, -14600,0.07120082,0.033405837,,,,,,,,,,,,,,,,, -14700,0.05640352,0.03358313,,,,,,,,,,,,,,,,, -14800,0.042095073,0.02771428,,,,,,,,,,,,,,,,, -14900,0.08744816,0.03431247,,,,,,,,,,,,,,,,, -14921,,,0.9905796051025392,0.0311626549810171,0.3785872396635164,0.9866968989372252,0.0440506637096405,0.2621364522054176,43793.0,0.9858343601226808,0.0468188151717186,0.2576561377273499,43793.0,4813.926500082016,7170.563362598419,4813.926500082016,2355.645773649216,0.5813260078430176,0.0 -15000,0.08076745,0.03323097,,,,,,,,,,,,,,,,, -15100,0.05883313,0.029403025,,,,,,,,,,,,,,,,, -15200,0.05386222,0.03171165,,,,,,,,,,,,,,,,, -15300,0.06125777,0.036541764,,,,,,,,,,,,,,,,, -15400,0.046213143,0.03031242,,,,,,,,,,,,,,,,, -15500,0.0788385,0.032774482,,,,,,,,,,,,,,,,, -15600,0.04510215,0.03294447,,,,,,,,,,,,,,,,, -15673,,,0.990501582622528,0.0315143540501594,0.3703973020116818,0.986763060092926,0.0441457703709602,0.2611574499676861,43793.0,0.9859552383422852,0.0467160493135452,0.2549108661049648,43793.0,5053.890505075455,7518.589793205261,5053.890505075455,2463.657342433929,0.611168384552002,0.0 -15700,0.07302901,0.03520187,,,,,,,,,,,,,,,,, -15800,0.065181665,0.031107808,,,,,,,,,,,,,,,,, -15900,0.06749478,0.029841207,,,,,,,,,,,,,,,,, -16000,0.05146862,0.027344896,,,,,,,,,,,,,,,,, -16100,0.06687942,0.035189383,,,,,,,,,,,,,,,,, -16200,0.06682732,0.030019876,,,,,,,,,,,,,,,,, -16300,0.088682935,0.028696673,,,,,,,,,,,,,,,,, -16400,0.048107263,0.033746794,,,,,,,,,,,,,,,,, -16426,,,0.9906003475189208,0.0310765262693166,0.3902105435847471,0.9866453409194946,0.0442377552390098,0.2586473980508823,43793.0,0.9858486652374268,0.0467076227068901,0.2628098371482614,43793.0,5293.863883733749,7869.176714420319,5293.863883733749,2574.221416711808,0.6401748657226562,0.0 -16500,0.05788681,0.026719717,,,,,,,,,,,,,,,,, -16600,0.091977805,0.029972572,,,,,,,,,,,,,,,,, -16700,0.08042412,0.033148676,,,,,,,,,,,,,,,,, -16800,0.054546423,0.030671414,,,,,,,,,,,,,,,,, -16900,0.13444477,0.032926552,,,,,,,,,,,,,,,,, -17000,0.0638009,0.034186788,,,,,,,,,,,,,,,,, -17100,0.06587236,0.030406544,,,,,,,,,,,,,,,,, -17178,,,0.9905943274497986,0.0308886766433715,0.3930183406771667,0.9867305755615234,0.0441368259489536,0.2651857368659311,43793.0,0.9858617186546326,0.0468544140458107,0.255345915528305,43793.0,5533.909998893738,8222.222283363342,5533.909998893738,2687.1701986789703,0.6697847843170166,0.0 -17200,0.06342472,0.029855102,,,,,,,,,,,,,,,,, -17300,0.071285754,0.03163365,,,,,,,,,,,,,,,,, -17400,0.062912956,0.030033898,,,,,,,,,,,,,,,,, -17500,0.09235132,0.030060915,,,,,,,,,,,,,,,,, -17600,0.07564278,0.030578831,,,,,,,,,,,,,,,,, -17700,0.07514809,0.026339231,,,,,,,,,,,,,,,,, -17800,0.06425344,0.02941859,,,,,,,,,,,,,,,,, -17900,0.07792626,0.031061336,,,,,,,,,,,,,,,,, -17930,,,0.9906186461448668,0.0308266170322895,0.3925631480101053,0.9867184162139891,0.0444091372191906,0.2627921188302755,43793.0,0.9858819246292114,0.046987097710371,0.2571229912207292,43793.0,5773.913890838623,8573.608848333359,5773.913890838623,2798.502618312836,0.699364423751831,0.0 -18000,0.102921784,0.032195434,,,,,,,,,,,,,,,,, -18100,0.13868941,0.03033594,,,,,,,,,,,,,,,,, -18200,0.08096807,0.03448382,,,,,,,,,,,,,,,,, -18300,0.08142799,0.032387957,,,,,,,,,,,,,,,,, -18400,0.06067557,0.028802093,,,,,,,,,,,,,,,,, -18500,0.07673976,0.028483573,,,,,,,,,,,,,,,,, -18600,0.08085685,0.03230864,,,,,,,,,,,,,,,,, -18681,,,0.9908760190010072,0.0300082452595233,0.4102506283303613,0.986647367477417,0.0445027761161327,0.2655605888096022,43793.0,0.9858953952789308,0.0472092404961586,0.2572136924916275,43793.0,6013.996770858765,8919.025218486786,6013.996770858765,2903.7837584018707,0.73091721534729,0.0 -18700,0.0677221,0.027104672,,,,,,,,,,,,,,,,, -18800,0.1054245,0.030982899,,,,,,,,,,,,,,,,, -18900,0.09822012,0.031681906,,,,,,,,,,,,,,,,, -19000,0.064879805,0.02949683,,,,,,,,,,,,,,,,, -19100,0.07252491,0.032376446,,,,,,,,,,,,,,,,, -19200,0.096059516,0.033575866,,,,,,,,,,,,,,,,, -19300,0.106911704,0.032206304,,,,,,,,,,,,,,,,, -19400,0.07989868,0.032239646,,,,,,,,,,,,,,,,, -19432,,,0.9908602237701416,0.0298089552670717,0.4269604107545199,0.9866859316825868,0.0446731783449649,0.2662054523909674,43793.0,0.985889494419098,0.0474565178155899,0.2520375820148166,43793.0,6254.042366743088,9274.63326048851,6254.042366743088,3019.294206380844,0.7621853351593018,0.0 -19500,0.1337475,0.028343597,,,,,,,,,,,,,,,,, -19600,0.076588966,0.030859815,,,,,,,,,,,,,,,,, -19700,0.09731674,0.027471593,,,,,,,,,,,,,,,,, -19800,0.08680582,0.028967848,,,,,,,,,,,,,,,,, -19900,0.069691636,0.029947955,,,,,,,,,,,,,,,,, -20000,0.09934369,0.028299239,,,,,,,,,,,,,,,,, -20100,0.07068726,0.027671615,,,,,,,,,,,,,,,,, -20175,,,0.9910188317298888,0.0289773251861333,0.4231962341096236,0.9867512583732604,0.0445153228938579,0.2637522492694075,43793.0,0.9858208894729614,0.0472513772547245,0.2518622757509492,43793.0,6494.278495073319,9624.445745706558,6494.278495073319,3128.8182978630066,0.7937026023864746,0.0 -20200,0.09542258,0.02570095,,,,,,,,,,,,,,,,, -20300,0.07460451,0.030452246,,,,,,,,,,,,,,,,, -20400,0.070332445,0.029644657,,,,,,,,,,,,,,,,, -20500,0.093015425,0.032545604,,,,,,,,,,,,,,,,, -20600,0.0872536,0.02330589,,,,,,,,,,,,,,,,, -20700,0.08611452,0.031020043,,,,,,,,,,,,,,,,, -20800,0.07869405,0.027736166,,,,,,,,,,,,,,,,, -20900,0.10086179,0.030265618,,,,,,,,,,,,,,,,, -20920,,,0.9911974668502808,0.0287505108863115,0.4392576285815208,0.9865494966506958,0.0446839854121208,0.2564115413675414,43793.0,0.985737442970276,0.0471720993518829,0.2560707943968566,43793.0,6734.232186079025,9976.325129032137,6734.232186079025,3240.6934225559235,0.8236665725708008,0.0 -21000,0.09074177,0.031531636,,,,,,,,,,,,,,,,, -21100,0.073893115,0.030700853,,,,,,,,,,,,,,,,, -21200,0.11329582,0.029683864,,,,,,,,,,,,,,,,, -21300,0.087183,0.028189974,,,,,,,,,,,,,,,,, -21400,0.067819804,0.029111404,,,,,,,,,,,,,,,,, -21500,0.16510409,0.03311664,,,,,,,,,,,,,,,,, -21600,0.09057626,0.029143527,,,,,,,,,,,,,,,,, -21667,,,0.9909592270851136,0.0293954890221357,0.4405538692550107,0.9868373274803162,0.0441310405731201,0.2741253771401372,43793.0,0.9859662055969238,0.0467855930328369,0.2591121619136606,43793.0,6974.358366250992,10324.099648237228,6974.358366250992,3348.290340423584,0.854386568069458,0.0 -21700,0.08649963,0.02472326,,,,,,,,,,,,,,,,, -21800,0.09559773,0.030403974,,,,,,,,,,,,,,,,, -21900,0.068669885,0.033178627,,,,,,,,,,,,,,,,, -22000,0.11236954,0.03140688,,,,,,,,,,,,,,,,, -22100,0.09225621,0.032277774,,,,,,,,,,,,,,,,, -22200,0.08999974,0.030836467,,,,,,,,,,,,,,,,, -22300,0.0786183,0.028113576,,,,,,,,,,,,,,,,, -22400,0.09116974,0.030742548,,,,,,,,,,,,,,,,, -22414,,,0.9908499121665956,0.0297807045280933,0.4160635084609168,0.9868572354316713,0.0448806248605251,0.2624296437397281,43793.0,0.9860116839408876,0.0476753339171409,0.255349596446049,43793.0,7214.402206897736,10671.39572572708,7214.402206897736,3455.4900765419006,0.8864257335662842,0.0 -22500,0.1679181,0.031640023,,,,,,,,,,,,,,,,, -22600,0.14535108,0.031612407,,,,,,,,,,,,,,,,, -22700,0.08830501,0.027037822,,,,,,,,,,,,,,,,, -22800,0.1272011,0.031094229,,,,,,,,,,,,,,,,, -22900,0.0732736,0.027316473,,,,,,,,,,,,,,,,, -23000,0.11092788,0.02812777,,,,,,,,,,,,,,,,, -23100,0.08331299,0.029003575,,,,,,,,,,,,,,,,, -23167,,,0.9909346103668212,0.0296036303043365,0.4199905114907307,0.9866583347320556,0.0445701442658901,0.2730043897985921,43793.0,0.9858132600784302,0.0473533831536769,0.2534655795057701,43793.0,7454.508299827576,11025.55727314949,7454.508299827576,3569.4941403865814,0.9170982837677002,0.0 -23200,0.08304079,0.02980928,,,,,,,,,,,,,,,,, -23300,0.0846135,0.031190937,,,,,,,,,,,,,,,,, -23400,0.071469866,0.025852636,,,,,,,,,,,,,,,,, -23500,0.0934242,0.032182705,,,,,,,,,,,,,,,,, -23600,0.10295962,0.025912166,,,,,,,,,,,,,,,,, -23700,0.112368815,0.0307281,,,,,,,,,,,,,,,,, -23800,0.09727676,0.031410806,,,,,,,,,,,,,,,,, -23900,0.1005231,0.030133085,,,,,,,,,,,,,,,,, -23921,,,0.9909778833389282,0.0295232720673084,0.4255365867503188,0.9867942929267884,0.0446453690528869,0.2707242531647493,43793.0,0.9859514236450196,0.0475130304694175,0.2569773068587559,43793.0,7694.484705686569,11374.190093517303,7694.484705686569,3678.099539041519,0.9474050998687744,0.0 -24000,0.09950609,0.030799164,,,,,,,,,,,,,,,,, -24100,0.09293519,0.028157879,,,,,,,,,,,,,,,,, -24200,0.09744178,0.029801557,,,,,,,,,,,,,,,,, -24300,0.096944265,0.026251558,,,,,,,,,,,,,,,,, -24400,0.07620737,0.025343418,,,,,,,,,,,,,,,,, -24500,0.12423154,0.03058563,,,,,,,,,,,,,,,,, -24600,0.15407783,0.029239561,,,,,,,,,,,,,,,,, -24666,,,0.9910315871238708,0.0290667042136192,0.4393065140763119,0.9868904948234558,0.0446149222552776,0.2751389474181981,43793.0,0.9860731363296508,0.0475754104554653,0.2620029472242,43793.0,7934.573546886444,11720.776325702667,7934.573546886444,3784.543060064316,0.9804179668426514,0.0 -24700,0.07620036,0.033079654,,,,,,,,,,,,,,,,, -24800,0.08441842,0.023320774,,,,,,,,,,,,,,,,, -24900,0.10017102,0.029317135,,,,,,,,,,,,,,,,, -25000,0.13832471,0.028126415,,,,,,,,,,,,,,,,, -25100,0.17101224,0.02679582,,,,,,,,,,,,,,,,, -25200,0.08159357,0.029465297,,,,,,,,,,,,,,,,, -25300,0.084269784,0.027718939,,,,,,,,,,,,,,,,, -25400,0.09319445,0.027594747,,,,,,,,,,,,,,,,, -25413,,,0.9910230040550232,0.0289991591125726,0.4279734510210868,0.986810564994812,0.0445956252515316,0.2719177220954323,43793.0,0.9860171675682068,0.047292198985815,0.2669060818825145,43793.0,8174.771550655365,12074.557059288025,8174.771550655365,3898.072008609772,1.0132358074188232,0.0 -25500,0.09512209,0.02535053,,,,,,,,,,,,,,,,, -25600,0.07474205,0.028905053,,,,,,,,,,,,,,,,, -25700,0.10895071,0.031758763,,,,,,,,,,,,,,,,, -25800,0.105629295,0.030262176,,,,,,,,,,,,,,,,, -25900,0.11316522,0.03269132,,,,,,,,,,,,,,,,, -26000,0.09344402,0.027354252,,,,,,,,,,,,,,,,, -26100,0.09271824,0.030837728,,,,,,,,,,,,,,,,, -26165,,,0.9912290573120116,0.0283614490181207,0.4525934288065413,0.9866153001785278,0.0446678176522254,0.2719697859873213,43793.0,0.9857838153839112,0.0471026189625263,0.258419016985112,43793.0,8414.809683322906,12424.020931243896,8414.809683322906,4007.445623874664,1.0445480346679688,0.0 -26200,0.12765649,0.0328278,,,,,,,,,,,,,,,,, -26300,0.11585446,0.027292676,,,,,,,,,,,,,,,,, -26400,0.11101242,0.028334463,,,,,,,,,,,,,,,,, -26500,0.09350464,0.028102366,,,,,,,,,,,,,,,,, -26600,0.08862032,0.029449284,,,,,,,,,,,,,,,,, -26700,0.13496424,0.030812578,,,,,,,,,,,,,,,,, -26800,0.0980226,0.03049798,,,,,,,,,,,,,,,,, -26900,0.11322423,0.03375342,,,,,,,,,,,,,,,,, -26920,,,0.991524875164032,0.0273654703050851,0.4807769372003801,0.9869144558906556,0.0444242358207702,0.2759751459481721,43793.0,0.985987663269043,0.0471849218010902,0.2646887047154832,43793.0,8654.877965211868,12773.30180120468,8654.877965211868,4116.604898691177,1.077523708343506,0.0 -27000,0.09959238,0.02925551,,,,,,,,,,,,,,,,, -27100,0.09709441,0.029741507,,,,,,,,,,,,,,,,, -27200,0.07650339,0.027533432,,,,,,,,,,,,,,,,, -27300,0.08569273,0.028576968,,,,,,,,,,,,,,,,, -27400,0.10026153,0.030728346,,,,,,,,,,,,,,,,, -27500,0.100055784,0.030179404,,,,,,,,,,,,,,,,, -27600,0.08561185,0.028143382,,,,,,,,,,,,,,,,, -27672,,,0.9917396306991576,0.0267806220799684,0.487505205329753,0.9867861866950988,0.0446733795106411,0.2727244827693121,43793.0,0.9859206676483154,0.0474046319723129,0.2627906059002099,43793.0,8895.08589887619,13120.378950834274,8895.08589887619,4223.421566724777,1.1091325283050537,0.0 -27700,0.106040746,0.028746193,,,,,,,,,,,,,,,,, -27800,0.10803318,0.030769376,,,,,,,,,,,,,,,,, -27900,0.10820472,0.030390214,,,,,,,,,,,,,,,,, -28000,0.08821924,0.030364433,,,,,,,,,,,,,,,,, -28100,0.10293744,0.03178551,,,,,,,,,,,,,,,,, -28200,0.08649277,0.03046767,,,,,,,,,,,,,,,,, -28300,0.08477515,0.027817667,,,,,,,,,,,,,,,,, -28400,0.114231594,0.02788254,,,,,,,,,,,,,,,,, -28419,,,0.9916260838508606,0.0269398819655179,0.4857912710610084,0.9868592619895936,0.0448070652782917,0.2716923961335559,43793.0,0.9859859943389891,0.0476726926863193,0.2642261563246968,43793.0,9135.200062274933,13464.72099852562,9135.200062274933,4327.597607374191,1.1407244205474854,0.0 -28500,0.08118294,0.02764289,,,,,,,,,,,,,,,,, -28600,0.110488005,0.028246807,,,,,,,,,,,,,,,,, -28700,0.097483575,0.028573662,,,,,,,,,,,,,,,,, -28800,0.09223068,0.027807858,,,,,,,,,,,,,,,,, -28900,0.105182685,0.025808549,,,,,,,,,,,,,,,,, -29000,0.07489929,0.024962856,,,,,,,,,,,,,,,,, -29100,0.12457996,0.02786164,,,,,,,,,,,,,,,,, -29173,,,0.991509735584259,0.0274159666150808,0.4712679039979643,0.9867748022079468,0.0447444543242454,0.2703691700265554,43793.0,0.985869288444519,0.0475153736770153,0.2566216951670911,43793.0,9375.395035743712,13808.285351991652,9375.395035743712,4430.912664890289,1.174621820449829,0.0 -29200,0.09610643,0.02778572,,,,,,,,,,,,,,,,, -29300,0.09947094,0.029529365,,,,,,,,,,,,,,,,, -29400,0.086306885,0.02595517,,,,,,,,,,,,,,,,, -29500,0.11336335,0.026361067,,,,,,,,,,,,,,,,, -29600,0.10812339,0.026615635,,,,,,,,,,,,,,,,, -29700,0.103822574,0.027323704,,,,,,,,,,,,,,,,, -29800,0.09435423,0.025936073,,,,,,,,,,,,,,,,, -29900,0.1064076,0.030473003,,,,,,,,,,,,,,,,, -29924,,,0.9912999868392944,0.0282875876873731,0.4548449958025577,0.9867565631866456,0.0449810102581977,0.2710515862463466,43793.0,0.985859215259552,0.0479270778596401,0.260950514393316,43793.0,9615.55433011055,14154.216762781143,9615.55433011055,4536.631418704987,1.2075426578521729,0.0 -30000,0.13208686,0.030837726,,,,,,,,,,,,,,,,, -30100,0.08565216,0.028637731,,,,,,,,,,,,,,,,, -30200,0.09642157,0.02763103,,,,,,,,,,,,,,,,, -30300,0.11320578,0.027283713,,,,,,,,,,,,,,,,, -30400,0.16169004,0.027421372,,,,,,,,,,,,,,,,, -30500,0.09638908,0.031113941,,,,,,,,,,,,,,,,, -30600,0.09016423,0.02551765,,,,,,,,,,,,,,,,, -30677,,,0.9913409948349,0.0280736107379198,0.4493626684822014,0.9867821335792542,0.044753324240446,0.2723217587699308,43793.0,0.9858810901641846,0.0478551611304283,0.2620535742729729,43793.0,9855.597105264664,14500.886621952057,9855.597105264664,4643.205877542496,1.239299774169922,0.0 -30700,0.09406374,0.024791816,,,,,,,,,,,,,,,,, -30800,0.12404365,0.02840516,,,,,,,,,,,,,,,,, -30900,0.12262208,0.029765017,,,,,,,,,,,,,,,,, -31000,0.10530895,0.026812892,,,,,,,,,,,,,,,,, -31100,0.09678743,0.02699618,,,,,,,,,,,,,,,,, -31200,0.13726598,0.030316202,,,,,,,,,,,,,,,,, -31300,0.09976413,0.02672901,,,,,,,,,,,,,,,,, -31400,0.11467312,0.025832217,,,,,,,,,,,,,,,,, -31434,,,0.9913233518600464,0.0280889272689819,0.4580684687707109,0.9867622256278992,0.0449602045118808,0.2680808067314692,43793.0,0.985932469367981,0.0474705062806606,0.2614179156552181,43793.0,10095.629689216614,14847.947264671326,10095.629689216614,4750.179671287537,1.2727985382080078,0.0 -31500,0.078009516,0.023199033,,,,,,,,,,,,,,,,, -31600,0.087578036,0.029329496,,,,,,,,,,,,,,,,, -31700,0.08719047,0.027154686,,,,,,,,,,,,,,,,, -31800,0.13016887,0.029728662,,,,,,,,,,,,,,,,, -31900,0.12614216,0.030074576,,,,,,,,,,,,,,,,, -32000,0.10409077,0.028598143,,,,,,,,,,,,,,,,, -32100,0.109742604,0.024469173,,,,,,,,,,,,,,,,, -32188,,,0.9914279580116272,0.0274613872170448,0.4688236204403521,0.986764669418335,0.0450935922563076,0.2805715467170202,43793.0,0.9858309626579284,0.0481604635715484,0.261548027310026,43793.0,10335.560795545578,15197.513407230375,10335.560795545578,4859.741172552109,1.325857400894165,0.0 -32200,0.13059792,0.031309083,,,,,,,,,,,,,,,,, -32300,0.09017203,0.02622286,,,,,,,,,,,,,,,,, -32400,0.110187024,0.028504366,,,,,,,,,,,,,,,,, -32500,0.090587564,0.02539816,,,,,,,,,,,,,,,,, -32600,0.1033806,0.027909199,,,,,,,,,,,,,,,,, -32700,0.10338889,0.028240457,,,,,,,,,,,,,,,,, -32800,0.108099006,0.029330563,,,,,,,,,,,,,,,,, -32900,0.1268205,0.031256557,,,,,,,,,,,,,,,,, -32936,,,0.9914365410804749,0.0274828933179378,0.4736294815094482,0.986638844013214,0.0448877029120922,0.2706954798470259,43793.0,0.9859328866004944,0.0476415865123271,0.2580936874518775,43793.0,10575.767451763151,15545.120171785356,10575.767451763151,4967.086313724518,1.359889268875122,0.0 -33000,0.09003498,0.025937269,,,,,,,,,,,,,,,,, -33100,0.10631412,0.029375704,,,,,,,,,,,,,,,,, -33200,0.09637109,0.02518792,,,,,,,,,,,,,,,,, -33300,0.102609955,0.02844688,,,,,,,,,,,,,,,,, -33400,0.15059446,0.030744959,,,,,,,,,,,,,,,,, -33500,0.09644662,0.028964357,,,,,,,,,,,,,,,,, -33600,0.12034761,0.026574101,,,,,,,,,,,,,,,,, -33684,,,0.991641879081726,0.0267317239195108,0.4813694209161578,0.9868279695510864,0.0449778325855731,0.2782887354386499,43793.0,0.9859544038772584,0.0479418747127056,0.2566208483574557,43793.0,10815.805988311768,15896.519999742508,10815.805988311768,5078.3935606479645,1.392927169799805,0.0 -33700,0.088320404,0.028421672,,,,,,,,,,,,,,,,, -33800,0.10329301,0.023418767,,,,,,,,,,,,,,,,, -33900,0.1222385,0.028202387,,,,,,,,,,,,,,,,, -34000,0.09928198,0.028460994,,,,,,,,,,,,,,,,, -34100,0.10265724,0.026166493,,,,,,,,,,,,,,,,, -34200,0.12807243,0.023936344,,,,,,,,,,,,,,,,, -34300,0.13036513,0.027546449,,,,,,,,,,,,,,,,, -34400,0.11608485,0.030173752,,,,,,,,,,,,,,,,, -34436,,,0.9919087290763856,0.0259376429021358,0.505223532274534,0.9868279695510864,0.0449526906013488,0.2817446194608765,43793.0,0.985964059829712,0.0477706789970397,0.267297587830477,43793.0,11055.845739364624,16241.80126285553,11055.845739364624,5183.58104300499,1.426476001739502,0.0 -34500,0.114883974,0.027379965,,,,,,,,,,,,,,,,, -34600,0.10203687,0.028920958,,,,,,,,,,,,,,,,, -34700,0.11709224,0.027855616,,,,,,,,,,,,,,,,, -34800,0.110454984,0.028820336,,,,,,,,,,,,,,,,, -34900,0.104477994,0.027425239,,,,,,,,,,,,,,,,, -35000,0.108754635,0.027277881,,,,,,,,,,,,,,,,, -35100,0.09554021,0.027714029,,,,,,,,,,,,,,,,, -35189,,,0.9923018217086792,0.0247634388506412,0.5400334925387037,0.9866660237312316,0.0450561232864856,0.2804774655869647,43793.0,0.9858478307724,0.0475787110626697,0.268802852899132,43793.0,11296.034619808195,16594.574902057648,11296.034619808195,5296.110832691193,1.461381196975708,0.0 -35200,0.09738057,0.024766026,,,,,,,,,,,,,,,,, -35300,0.10741805,0.026071245,,,,,,,,,,,,,,,,, -35400,0.08850343,0.024369806,,,,,,,,,,,,,,,,, -35500,0.11900225,0.02878682,,,,,,,,,,,,,,,,, -35600,0.10463239,0.025157627,,,,,,,,,,,,,,,,, -35700,0.10096384,0.027127115,,,,,,,,,,,,,,,,, -35800,0.1106928,0.030989578,,,,,,,,,,,,,,,,, -35900,0.13004929,0.027193397,,,,,,,,,,,,,,,,, -35940,,,0.9923170804977416,0.0245520714670419,0.5424872716795479,0.9867898225784302,0.0454942397773265,0.2835340130049796,43793.0,0.9859143495559692,0.0483407154679298,0.2643073077441462,43793.0,11536.095036268234,16948.44060611725,11536.095036268234,5409.861083984375,1.496187686920166,0.0 -36000,0.104658976,0.028079076,,,,,,,,,,,,,,,,, -36100,0.11069538,0.026723275,,,,,,,,,,,,,,,,, -36200,0.1095237,0.027290665,,,,,,,,,,,,,,,,, -36300,0.09486662,0.025764843,,,,,,,,,,,,,,,,, -36400,0.10679077,0.02702573,,,,,,,,,,,,,,,,, -36500,0.10211688,0.027921744,,,,,,,,,,,,,,,,, -36600,0.12605076,0.02651468,,,,,,,,,,,,,,,,, -36692,,,0.9918184876441956,0.0261344145983457,0.5122857286942415,0.9867236614227296,0.0451662205159664,0.2817615748781584,43793.0,0.9858882427215576,0.0482158660888671,0.2715438217701766,43793.0,11776.166133642197,17296.837022542953,11776.166133642197,5518.131883144379,1.530383825302124,0.0 -36700,0.12404067,0.025114805,,,,,,,,,,,,,,,,, -36800,0.11287263,0.028017916,,,,,,,,,,,,,,,,, -36900,0.113063306,0.029450892,,,,,,,,,,,,,,,,, -37000,0.115040876,0.025711752,,,,,,,,,,,,,,,,, -37100,0.106727004,0.02609683,,,,,,,,,,,,,,,,, -37200,0.10579425,0.023569722,,,,,,,,,,,,,,,,, -37300,0.11074049,0.028270116,,,,,,,,,,,,,,,,, -37400,0.11644443,0.028360037,,,,,,,,,,,,,,,,, -37454,,,0.9918424487113952,0.0261771418154239,0.4943024589468207,0.9865970015525818,0.0453442595899105,0.2697746331850882,43793.0,0.9857577085494996,0.0483944676816463,0.2608674631916036,43793.0,12016.210379838943,17646.148535490036,12016.210379838943,5627.343741178513,1.5651028156280518,0.0 -37500,0.11033833,0.023486616,,,,,,,,,,,,,,,,, -37600,0.13855396,0.029592348,,,,,,,,,,,,,,,,, -37700,0.11651566,0.026733013,,,,,,,,,,,,,,,,, -37800,0.10778134,0.026953863,,,,,,,,,,,,,,,,, -37900,0.12150134,0.027010534,,,,,,,,,,,,,,,,, -38000,0.111250125,0.024512721,,,,,,,,,,,,,,,,, -38100,0.11771052,0.028216716,,,,,,,,,,,,,,,,, -38200,0.108789444,0.026243703,,,,,,,,,,,,,,,,, -38209,,,0.991769313812256,0.0264698714017868,0.4968405851355063,0.9866034984588624,0.0454076938331127,0.2799366756800923,43793.0,0.985753059387207,0.048491571098566,0.2581400577281293,43793.0,12256.275754451752,17992.750111341476,12256.275754451752,5733.825747728348,1.5990040302276611,0.0 -38300,0.11088818,0.02630021,,,,,,,,,,,,,,,,, -38400,0.097487,0.02623868,,,,,,,,,,,,,,,,, -38500,0.13507241,0.028759923,,,,,,,,,,,,,,,,, -38600,0.12429386,0.028377771,,,,,,,,,,,,,,,,, -38700,0.11664924,0.022960242,,,,,,,,,,,,,,,,, -38800,0.1342531,0.02570733,,,,,,,,,,,,,,,,, -38900,0.12472219,0.024059867,,,,,,,,,,,,,,,,, -38963,,,0.9920052289962769,0.0257331114262342,0.4961674082046797,0.9868503212928772,0.0452651977539062,0.2786514717710018,43793.0,0.9859619736671448,0.0483974255621433,0.2601541778417228,43793.0,12496.309900045397,18338.66453528404,12496.309900045397,5839.652360200882,1.6322979927062988,0.0 -39000,0.1295753,0.028134428,,,,,,,,,,,,,,,,, -39100,0.13197121,0.025779087,,,,,,,,,,,,,,,,, -39200,0.1212269,0.023939583,,,,,,,,,,,,,,,,, -39300,0.09653284,0.024382439,,,,,,,,,,,,,,,,, -39400,0.10562256,0.026063884,,,,,,,,,,,,,,,,, -39500,0.14404733,0.027851535,,,,,,,,,,,,,,,,, -39600,0.12633242,0.025987413,,,,,,,,,,,,,,,,, -39700,0.1099984,0.023992464,,,,,,,,,,,,,,,,, -39723,,,0.9920634031295776,0.0253645367920398,0.5136206193556406,0.9867849946022034,0.0453127808868885,0.2773486024374958,43793.0,0.9859986305236816,0.048444353044033,0.2660761208561459,43793.0,12736.541811704636,18686.82477331161,12736.541811704636,5947.527049303055,1.6655912399291992,0.0 -39800,0.13386087,0.0284281,,,,,,,,,,,,,,,,, -39900,0.10211527,0.024325183,,,,,,,,,,,,,,,,, -40000,0.14111629,0.02983688,,,,,,,,,,,,,,,,, -40100,0.10688256,0.024119478,,,,,,,,,,,,,,,,, -40200,0.10597854,0.025323458,,,,,,,,,,,,,,,,, -40300,0.101340584,0.023941483,,,,,,,,,,,,,,,,, -40400,0.1119085,0.025350193,,,,,,,,,,,,,,,,, -40479,,,0.9920461177825928,0.025316696614027,0.527780740298631,0.986867368221283,0.0455422848463058,0.2802686791438417,43793.0,0.9859328866004944,0.0487171821296215,0.2655723199105637,43793.0,12976.76999258995,19034.442306518555,12976.76999258995,6054.862420797348,1.6991665363311768,0.0 -40500,0.13558558,0.024021175,,,,,,,,,,,,,,,,, -40600,0.1183881,0.02353126,,,,,,,,,,,,,,,,, -40700,0.1143569,0.023779683,,,,,,,,,,,,,,,,, -40800,0.14338423,0.024835799,,,,,,,,,,,,,,,,, -40900,0.13725306,0.026685987,,,,,,,,,,,,,,,,, -41000,0.13369544,0.025704151,,,,,,,,,,,,,,,,, -41100,0.12548085,0.025833331,,,,,,,,,,,,,,,,, -41200,0.13508876,0.024796523,,,,,,,,,,,,,,,,, -41240,,,0.9922798275947572,0.0245455391705036,0.524614840843699,0.9867496490478516,0.0452902019023895,0.283905575901221,43793.0,0.985854148864746,0.0483081564307212,0.2637783218790555,43793.0,13216.979594230652,19383.399483442307,13216.979594230652,6163.554557561874,1.7342724800109863,0.0 -41300,0.12500325,0.022859609,,,,,,,,,,,,,,,,, -41400,0.11642859,0.023240974,,,,,,,,,,,,,,,,, -41500,0.13618585,0.02915909,,,,,,,,,,,,,,,,, -41600,0.124493755,0.025569223,,,,,,,,,,,,,,,,, -41700,0.12242176,0.024278961,,,,,,,,,,,,,,,,, -41800,0.12485071,0.02406328,,,,,,,,,,,,,,,,, -41900,0.115886904,0.023129214,,,,,,,,,,,,,,,,, -41996,,,0.9924442172050476,0.023939685896039,0.5495143535314464,0.9866639971733092,0.0453347228467464,0.2868209618737811,43793.0,0.9858916401863098,0.0483588948845863,0.2650313659736344,43793.0,13457.06114768982,19725.856940746307,13457.06114768982,6265.874538183212,1.769547462463379,0.0 -42000,0.13140363,0.02406728,,,,,,,,,,,,,,,,, -42100,0.11442114,0.02307933,,,,,,,,,,,,,,,,, -42200,0.14430262,0.025347058,,,,,,,,,,,,,,,,, -42300,0.14492431,0.026674004,,,,,,,,,,,,,,,,, -42400,0.12389042,0.024783066,,,,,,,,,,,,,,,,, -42500,0.13002875,0.024353582,,,,,,,,,,,,,,,,, -42600,0.1512148,0.02607217,,,,,,,,,,,,,,,,, -42700,0.12358296,0.023889147,,,,,,,,,,,,,,,,, -42753,,,0.992691457271576,0.023066472262144,0.5733023445192826,0.9868162274360656,0.0458563193678855,0.2811874228785678,43793.0,0.9859809279441832,0.0491667203605175,0.2701141743328484,43793.0,13697.11407327652,20076.64497256279,13697.11407327652,6376.554070949554,1.804668664932251,0.0 -42800,0.13272904,0.024059298,,,,,,,,,,,,,,,,, -42900,0.1353601,0.023195555,,,,,,,,,,,,,,,,, -43000,0.11197306,0.022775605,,,,,,,,,,,,,,,,, -43100,0.117693186,0.027911332,,,,,,,,,,,,,,,,, -43200,0.12649706,0.024536917,,,,,,,,,,,,,,,,, -43300,0.12481297,0.025209326,,,,,,,,,,,,,,,,, -43400,0.1480584,0.026448116,,,,,,,,,,,,,,,,, -43500,0.13984643,0.024833297,,,,,,,,,,,,,,,,, -43508,,,0.9930054545402528,0.0222231112420558,0.5976667659929961,0.9868007898330688,0.0461998917162418,0.28058899499968,43793.0,0.9859379529953004,0.0492123737931251,0.2669640700896823,43793.0,13937.296043395996,20427.44962954521,13937.296043395996,6487.12228512764,1.838754415512085,0.0 -43600,0.15542081,0.027103474,,,,,,,,,,,,,,,,, -43700,0.12701176,0.025728287,,,,,,,,,,,,,,,,, -43800,0.1416995,0.024309108,,,,,,,,,,,,,,,,, -43900,0.12633975,0.023743434,,,,,,,,,,,,,,,,, -44000,0.14386019,0.023796987,,,,,,,,,,,,,,,,, -44100,0.14341721,0.024076087,,,,,,,,,,,,,,,,, -44200,0.12603812,0.024544613,,,,,,,,,,,,,,,,, -44262,,,0.9929794073104858,0.0224257409572601,0.5909267678181348,0.9865661859512328,0.0464385598897933,0.2845100129284141,43793.0,0.985765278339386,0.0493208542466163,0.2665610301084667,43793.0,14177.254234313965,20775.13205766678,14177.254234313965,6594.789473056793,1.875447988510132,0.0 -44300,0.17365,0.026924385,,,,,,,,,,,,,,,,, -44400,0.17893527,0.021565039,,,,,,,,,,,,,,,,, -44500,0.12542911,0.022631805,,,,,,,,,,,,,,,,, -44600,0.12826903,0.027355976,,,,,,,,,,,,,,,,, -44700,0.1381464,0.022724543,,,,,,,,,,,,,,,,, -44800,0.13579763,0.025011836,,,,,,,,,,,,,,,,, -44900,0.16054167,0.025730645,,,,,,,,,,,,,,,,, -45000,0.14113265,0.022995127,,,,,,,,,,,,,,,,, -45013,,,0.9926031827926636,0.0235844925045967,0.553874841307491,0.9867179989814758,0.0464133955538272,0.2796055275831616,43793.0,0.9858819246292114,0.0496662817895412,0.2710117089890043,43793.0,14417.521105766296,21123.5917699337,14417.521105766296,6702.926654815674,1.910888671875,0.0 -45100,0.16105007,0.025283327,,,,,,,,,,,,,,,,, -45200,0.1358655,0.02439125,,,,,,,,,,,,,,,,, -45300,0.11866348,0.021755766,,,,,,,,,,,,,,,,, -45400,0.11324339,0.022289189,,,,,,,,,,,,,,,,, -45500,0.13974708,0.022488004,,,,,,,,,,,,,,,,, -45600,0.15340038,0.024622014,,,,,,,,,,,,,,,,, -45700,0.15771492,0.024606952,,,,,,,,,,,,,,,,, -45768,,,0.9925385117530824,0.0236756782978773,0.5445353539343192,0.9865588545799256,0.0466780699789524,0.2796525563782617,43793.0,0.9857400059700012,0.0496977865695953,0.2654268209684074,43793.0,14657.61928486824,21473.75328922272,14657.61928486824,6812.935254096985,1.94553542137146,0.0 -45800,0.16730407,0.02133388,,,,,,,,,,,,,,,,, -45900,0.13555364,0.023292718,,,,,,,,,,,,,,,,, -46000,0.14697342,0.025416452,,,,,,,,,,,,,,,,, -46100,0.12938485,0.021783324,,,,,,,,,,,,,,,,, -46200,0.15563937,0.022110058,,,,,,,,,,,,,,,,, -46300,0.13512127,0.023628464,,,,,,,,,,,,,,,,, -46400,0.18197936,0.02246185,,,,,,,,,,,,,,,,, -46500,0.13498199,0.021862095,,,,,,,,,,,,,,,,, -46528,,,0.9925999045372008,0.0235732309520244,0.5610417006179397,0.9867228865623474,0.0465218871831893,0.2785722170901397,43793.0,0.9858840703964232,0.0497120022773742,0.2653429253234877,43793.0,14897.71324658394,21816.33208823204,14897.71324658394,6915.363301515579,1.981616258621216,0.0 -46600,0.17777683,0.023785178,,,,,,,,,,,,,,,,, -46700,0.1524715,0.026409106,,,,,,,,,,,,,,,,, -46800,0.1585337,0.023832183,,,,,,,,,,,,,,,,, -46900,0.15219554,0.025596447,,,,,,,,,,,,,,,,, -47000,0.13550606,0.021008816,,,,,,,,,,,,,,,,, -47100,0.12900144,0.02329574,,,,,,,,,,,,,,,,, -47200,0.15171231,0.023885101,,,,,,,,,,,,,,,,, -47283,,,0.992514193058014,0.0235684197396039,0.5593789092608922,0.9866579174995422,0.0471155829727649,0.2736626211202902,43793.0,0.9857581257820128,0.0503697209060192,0.2608815513531671,43793.0,15137.88663125038,22165.82682275772,15137.88663125038,7024.628471374512,2.017472982406616,0.0 -47300,0.17854829,0.025735589,,,,,,,,,,,,,,,,, -47400,0.1485699,0.022536444,,,,,,,,,,,,,,,,, -47500,0.14530925,0.022960111,,,,,,,,,,,,,,,,, -47600,0.17270267,0.023107022,,,,,,,,,,,,,,,,, -47700,0.18713433,0.023168752,,,,,,,,,,,,,,,,, -47800,0.1555621,0.022365658,,,,,,,,,,,,,,,,, -47900,0.15588427,0.02263848,,,,,,,,,,,,,,,,, -48000,0.15798762,0.02050135,,,,,,,,,,,,,,,,, -48037,,,0.992720901966095,0.022959679365158,0.5675811443116955,0.9865494966506958,0.0468745082616806,0.2781233637894734,43793.0,0.9855588674545288,0.0503176636993885,0.2622498398743667,43793.0,15377.85105419159,22508.54941940308,15377.85105419159,7127.328198194504,2.0556468963623047,0.0 -48100,0.1447433,0.021637963,,,,,,,,,,,,,,,,, -48200,0.15293168,0.021303358,,,,,,,,,,,,,,,,, -48300,0.12911433,0.019402772,,,,,,,,,,,,,,,,, -48400,0.13646686,0.022028634,,,,,,,,,,,,,,,,, -48500,0.15309776,0.024459979,,,,,,,,,,,,,,,,, -48600,0.13941948,0.021618413,,,,,,,,,,,,,,,,, -48700,0.17246507,0.022755112,,,,,,,,,,,,,,,,, -48792,,,0.9929220080375672,0.0219717137515544,0.5988421261644996,0.986821472644806,0.0474879816174507,0.2777354293557749,43793.0,0.985990583896637,0.0508939661085605,0.2647443093865677,43793.0,15618.037787914276,22851.45850634575,15618.037787914276,7229.995413780212,2.0907273292541504,0.0 -48800,0.16362233,0.023846004,,,,,,,,,,,,,,,,, -48900,0.16487417,0.021610742,,,,,,,,,,,,,,,,, -49000,0.15841629,0.020780582,,,,,,,,,,,,,,,,, -49100,0.1763972,0.018981006,,,,,,,,,,,,,,,,, -49200,0.19086753,0.024684995,,,,,,,,,,,,,,,,, -49300,0.15176144,0.020683141,,,,,,,,,,,,,,,,, -49400,0.15278497,0.022260606,,,,,,,,,,,,,,,,, -49500,0.14765511,0.020907143,,,,,,,,,,,,,,,,, -49545,,,0.9931709170341492,0.0213091485202312,0.6088010839997859,0.986845850944519,0.0474283695220947,0.2821814441667886,43793.0,0.9858722686767578,0.0508918836712837,0.2633965535200139,43793.0,15857.938402414322,23197.240000247955,15857.938402414322,7335.499788284302,2.4467310905456543,0.0 -49600,0.1500767,0.022977713,,,,,,,,,,,,,,,,, -49700,0.1568389,0.023643909,,,,,,,,,,,,,,,,, -49800,0.1537544,0.022072842,,,,,,,,,,,,,,,,, -49900,0.18068871,0.022353739,,,,,,,,,,,,,,,,, -50000,0.18269514,0.020325623,,,,,,,,,,,,,,,,, -50100,0.16152933,0.022775684,,,,,,,,,,,,,,,,, -50200,0.16311733,0.020047016,,,,,,,,,,,,,,,,, -50299,,,0.99346262216568,0.0204917322844266,0.6304601455926362,0.9867045879364014,0.0474251732230186,0.2834991495572792,43793.0,0.9857922196388244,0.0507895275950431,0.2697486751280569,43793.0,16098.165282726288,23544.99045753479,16098.165282726288,7442.966633319855,2.483210563659668,0.0 -50300,0.18235911,0.019658316,,,,,,,,,,,,,,,,, -50400,0.14669958,0.01968868,,,,,,,,,,,,,,,,, -50500,0.14905506,0.020904347,,,,,,,,,,,,,,,,, -50600,0.18642448,0.022596389,,,,,,,,,,,,,,,,, -50700,0.15924332,0.020372258,,,,,,,,,,,,,,,,, -50800,0.17022589,0.022130625,,,,,,,,,,,,,,,,, -50900,0.17483047,0.022818167,,,,,,,,,,,,,,,,, -51000,0.15735109,0.020856282,,,,,,,,,,,,,,,,, -51055,,,0.9939581155776978,0.0190504807978868,0.6652305154089658,0.9866855144500732,0.0482732728123664,0.2792698083358352,43793.0,0.9858272075653076,0.0514804720878601,0.2618129902030606,43793.0,16338.194887399672,23889.070935964584,16338.194887399672,7546.960394859314,2.519238471984864,0.0 -51100,0.16324665,0.021712564,,,,,,,,,,,,,,,,, -51200,0.17218377,0.022380736,,,,,,,,,,,,,,,,, -51300,0.18231472,0.021911489,,,,,,,,,,,,,,,,, -51400,0.15361747,0.018736582,,,,,,,,,,,,,,,,, -51500,0.1650467,0.02038097,,,,,,,,,,,,,,,,, -51600,0.18007277,0.023572344,,,,,,,,,,,,,,,,, -51700,0.17305377,0.020888966,,,,,,,,,,,,,,,,, -51800,0.20673376,0.020363458,,,,,,,,,,,,,,,,, -51814,,,0.9940596222877502,0.0188466794788837,0.6622694237726308,0.9865750670433044,0.0486783199012279,0.2745682719653072,43793.0,0.9856536388397216,0.0521537996828556,0.2611147893670439,43793.0,16578.26107263565,24235.02452325821,16578.26107263565,7652.790205001831,2.555806398391724,0.0 -51900,0.18098997,0.02253756,,,,,,,,,,,,,,,,, -52000,0.16122702,0.020633807,,,,,,,,,,,,,,,,, -52100,0.18055458,0.022888465,,,,,,,,,,,,,,,,, -52200,0.15902865,0.02022348,,,,,,,,,,,,,,,,, -52300,0.15640828,0.019900644,,,,,,,,,,,,,,,,, -52400,0.1966242,0.020939613,,,,,,,,,,,,,,,,, -52500,0.18143754,0.023823287,,,,,,,,,,,,,,,,, -52572,,,0.9937030076980592,0.0197627544403076,0.6384974196668561,0.9864805340766908,0.0489682592451572,0.276152484739644,43793.0,0.9855180382728576,0.0523321330547332,0.260130659568339,43793.0,16818.50644493103,24583.90949845314,16818.50644493103,7761.37190580368,2.5926008224487305,0.0 -52600,0.22058998,0.022902608,,,,,,,,,,,,,,,,, -52700,0.18668303,0.018968217,,,,,,,,,,,,,,,,, -52800,0.1720474,0.019489389,,,,,,,,,,,,,,,,, -52900,0.20409922,0.022447016,,,,,,,,,,,,,,,,, -53000,0.1835241,0.019821793,,,,,,,,,,,,,,,,, -53100,0.18659712,0.020571683,,,,,,,,,,,,,,,,, -53200,0.18194471,0.019319523,,,,,,,,,,,,,,,,, -53300,0.25658262,0.025085997,,,,,,,,,,,,,,,,, -53324,,,0.993492603302002,0.0203177742660045,0.6140696882882921,0.9865548014640808,0.0489306151866912,0.2848737177843266,43793.0,0.9857269525527954,0.0524364113807678,0.2673185343252118,43793.0,17058.707488775253,24932.676094055176,17058.707488775253,7869.87908744812,2.6299984455108643,0.0 -53400,0.17803016,0.019008694,,,,,,,,,,,,,,,,, -53500,0.20382269,0.019269513,,,,,,,,,,,,,,,,, -53600,0.20606613,0.020211583,,,,,,,,,,,,,,,,, -53700,0.19453575,0.02056214,,,,,,,,,,,,,,,,, -53800,0.17609021,0.018785423,,,,,,,,,,,,,,,,, -53900,0.19446594,0.019976653,,,,,,,,,,,,,,,,, -54000,0.20090297,0.021649808,,,,,,,,,,,,,,,,, -54079,,,0.9934467077255248,0.0203902274370193,0.6370068629099701,0.9865905046463012,0.0495647341012954,0.2744399979478792,43793.0,0.985632598400116,0.0533222369849681,0.2601269322421878,43793.0,17298.7577586174,25279.552406549454,17298.7577586174,7976.645930767059,2.6677396297454834,0.0 -54100,0.19401404,0.019671379,,,,,,,,,,,,,,,,, -54200,0.20368387,0.01993841,,,,,,,,,,,,,,,,, -54300,0.21555051,0.02043125,,,,,,,,,,,,,,,,, -54400,0.19378243,0.01948389,,,,,,,,,,,,,,,,, -54500,0.17929104,0.019278528,,,,,,,,,,,,,,,,, -54600,0.19973308,0.018142633,,,,,,,,,,,,,,,,, -54700,0.25223905,0.021703478,,,,,,,,,,,,,,,,, -54800,0.19299453,0.01823105,,,,,,,,,,,,,,,,, -54833,,,0.9934847354888916,0.0201917085796594,0.6408067088189375,0.9866063594818116,0.0500474199652671,0.2746520933144069,43793.0,0.9855727553367616,0.0537689179182052,0.260429199675821,43793.0,17538.758221387863,25626.90543746948,17538.758221387863,8083.9382219314575,2.706973075866699,0.0 -54900,0.22145048,0.02022146,,,,,,,,,,,,,,,,, -55000,0.20820549,0.018343594,,,,,,,,,,,,,,,,, -55100,0.19771394,0.019904934,,,,,,,,,,,,,,,,, -55200,0.18342878,0.018013457,,,,,,,,,,,,,,,,, -55300,0.22689593,0.020297846,,,,,,,,,,,,,,,,, -55400,0.22260362,0.020667134,,,,,,,,,,,,,,,,, -55500,0.18032706,0.01772224,,,,,,,,,,,,,,,,, -55585,,,0.9934579133987428,0.0202217083424329,0.6336687988840867,0.986279547214508,0.0502303689718246,0.2801965927562672,43793.0,0.9853209257125854,0.05393723025918,0.257986693851216,43793.0,17778.90761089325,25975.122877836227,17778.90761089325,8191.948867797852,2.743889093399048,0.0 -55600,0.220307,0.018257579,,,,,,,,,,,,,,,,, -55700,0.19874874,0.018124506,,,,,,,,,,,,,,,,, -55800,0.21508546,0.018753774,,,,,,,,,,,,,,,,, -55900,0.19191024,0.017902821,,,,,,,,,,,,,,,,, -56000,0.17254527,0.016060492,,,,,,,,,,,,,,,,, -56100,0.22015385,0.019667542,,,,,,,,,,,,,,,,, -56200,0.21396314,0.019441685,,,,,,,,,,,,,,,,, -56300,0.23158816,0.019145152,,,,,,,,,,,,,,,,, -56345,,,0.9935694336891174,0.0196368563920259,0.6439595382453188,0.9865819811820984,0.0509174615144729,0.2785519321633012,43793.0,0.9856279492378236,0.0548713281750679,0.2624462802706482,43793.0,18019.107362270355,26320.5678255558,18019.107362270355,8297.136292934418,2.7808399200439453,0.0 -56400,0.21062699,0.018449979,,,,,,,,,,,,,,,,, -56500,0.22793974,0.018485632,,,,,,,,,,,,,,,,, -56600,0.18544306,0.017294744,,,,,,,,,,,,,,,,, -56700,0.21010266,0.018243253,,,,,,,,,,,,,,,,, -56800,0.22858049,0.018585151,,,,,,,,,,,,,,,,, -56900,0.20755263,0.016250303,,,,,,,,,,,,,,,,, -57000,0.21133356,0.020728081,,,,,,,,,,,,,,,,, -57096,,,0.9941542148590088,0.0179902222007513,0.6795010598582831,0.9863534569740297,0.0512405298650264,0.2821614791012339,43793.0,0.985390841960907,0.0551489889621734,0.2560792197481081,43793.0,18259.05740952492,26670.66596722603,18259.05740952492,8407.226685523987,2.817889213562012,0.0 -57100,0.2325887,0.019246196,,,,,,,,,,,,,,,,, -57200,0.2223508,0.019482804,,,,,,,,,,,,,,,,, -57300,0.21657865,0.017885773,,,,,,,,,,,,,,,,, -57400,0.19985834,0.017766245,,,,,,,,,,,,,,,,, -57500,0.23624818,0.017845789,,,,,,,,,,,,,,,,, -57600,0.24636623,0.019021705,,,,,,,,,,,,,,,,, -57700,0.24615563,0.020089442,,,,,,,,,,,,,,,,, -57780,,,,,,,,,,,,,,18477.00848555565,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 239fa8917..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -104.78602743148804,0.0,11.89487600326538,1,0,11.89487600326538,0.3947480618953705,0.7956756353378296,0.0263663607361259,43793,116.68095016479492,0.3887049853801727,0.7994356751441956,0.0230993680346546,0.3926452994346618,0.7974664568901062,0.0248431496041781,43793 -211.86702752113345,0.023287296295166,251.93400812149048,748,0,251.93400812149048,0.9831724166870116,0.0675854012370109,0.0414079131255289,43793,463.8443031311035,0.9867178201675416,0.0540571250021457,0.0413288035889094,0.984153687953949,0.0642426088452339,0.0405507054465399,43793 -315.04234194755554,0.0506999492645263,491.9445328712464,1500,0,491.9445328712464,0.9832056760787964,0.0670676082372665,0.041944288818654,43793,807.0779266357422,0.9868547916412354,0.0534534752368927,0.0404074919278545,0.984169065952301,0.0638072565197944,0.0401475249861516,43793 -422.1489651203156,0.07857346534729,731.971120595932,2255,0,731.971120595932,0.983265459537506,0.0653808563947677,0.05076436693773,43793,1154.2588243484497,0.9868224263191224,0.0518390424549579,0.0526378552763878,0.9842417240142822,0.0621321313083171,0.0498430073024989,43793 -528.547721862793,0.1052675247192382,972.108246088028,3012,0,972.108246088028,0.9833332896232604,0.0633403435349464,0.0663088386050019,43793,1500.8414223194122,0.9869569540023804,0.0495302826166152,0.0698423122171888,0.98429936170578,0.0596925765275955,0.0653821551267546,43793 -636.8478765487671,0.1316099166870117,1212.2337291240692,3768,0,1212.2337291240692,0.9835565090179444,0.0613648705184459,0.0921200282411233,43793,1849.313452243805,0.9871721863746644,0.0480428338050842,0.0928386623424184,0.9844650030136108,0.0578740499913692,0.0906684969170909,43793 -741.5066666603088,0.1595327854156494,1452.3454518318176,4522,0,1452.3454518318176,0.9838467240333556,0.0577796697616577,0.1173304671659813,43793,2194.1322169303894,0.9874393939971924,0.0453436709940433,0.1245404069794609,0.9847800135612488,0.0548060648143291,0.1111730464072495,43793 -847.9944367408752,0.1878836154937744,1692.4349780082705,5269,0,1692.4349780082705,0.9838109016418456,0.0564559586346149,0.128097338821601,43793,2540.757847547531,0.9877853393554688,0.0436364114284515,0.1349541273898319,0.9847633838653564,0.0536157079041004,0.1228188394742861,43793 -957.4483652114868,0.2167634963989257,1932.6885635852807,6017,0,1932.6885635852807,0.9840855598449708,0.0552940890192985,0.1426776114746375,43793,2890.5146567821503,0.9877686500549316,0.0430554561316967,0.1524773079411523,0.9850386381149292,0.0525088608264923,0.1389896901816074,43793 -1066.894276857376,0.2442078590393066,2172.851979732513,6771,0,2172.851979732513,0.984137773513794,0.0547958463430404,0.137266701689431,43793,3240.1714539527893,0.988117814064026,0.0419328287243843,0.1471271580563905,0.9850714802742004,0.0519345328211784,0.1369009566762386,43793 -1173.5531170368197,0.2712841033935547,2412.9996078014374,7522,0,2412.9996078014374,0.984243094921112,0.05458440259099,0.1388095702206179,43793,3587.0254402160645,0.9878849387168884,0.0425020903348922,0.1590378641956764,0.9852147698402404,0.0517855808138847,0.1396837630676205,43793 -1282.5451092720032,0.2990598678588867,2653.148269176483,8278,0,2653.148269176483,0.9843167662620544,0.0536992959678173,0.1497537212633665,43793,3936.214086532593,0.9880563616752625,0.0417524687945842,0.1610822629742694,0.9852578043937684,0.0509598068892955,0.1488274819384134,43793 -1388.0890011787417,0.3277096748352051,2893.190643548965,9026,0,2893.190643548965,0.9842881560325624,0.053785428404808,0.1467478836927074,43793,4281.849505901337,0.9882237315177916,0.0422638691961765,0.1666925139943225,0.9851611852645874,0.0513968542218208,0.1470027461109916,43793 -1491.438364028931,0.3579344749450683,3133.340534448624,9778,0,3133.340534448624,0.9843904972076416,0.0536238476634025,0.1550489886621871,43793,4625.398817539215,0.988153874874115,0.0412655286490917,0.1776809481114947,0.9853041172027588,0.0507806502282619,0.1524563100207836,43793 -1599.2725040912628,0.3884594440460205,3373.36945104599,10526,0,3373.36945104599,0.9843748807907104,0.0532482601702213,0.1629483331394156,43793,4973.312300443649,0.988370418548584,0.0405784808099269,0.1764253451645966,0.9853061437606812,0.0504616163671016,0.1536274094312301,43793 -1703.5707485675812,0.4186489582061767,3613.459751367569,11280,0,3613.459751367569,0.98440819978714,0.0535706207156181,0.1535308390425825,43793,5317.750963449478,0.9882692694664,0.0407231338322162,0.1806463704697511,0.9853795766830444,0.0506321750581264,0.1516324973639082,43793 -1809.784565925598,0.4482748508453369,3853.7374968528734,12034,0,3853.7374968528734,0.9845050573349,0.0542274564504623,0.1533889459399407,43793,5664.29166841507,0.9883570671081544,0.0406102873384952,0.187897305803684,0.9853954315185548,0.0513092726469039,0.1469485466278225,43793 -1913.699834108353,0.4794890880584717,4093.757108926773,12793,0,4093.757108926773,0.9844494462013244,0.0533479265868663,0.160429288795423,43793,6008.2779705524445,0.9884112477302552,0.0403369925916194,0.1881470002900639,0.9853706955909728,0.0505053550004959,0.1582621068851679,43793 -2018.3670988082888,0.5114123821258545,4334.002388477325,13551,0,4334.002388477325,0.9844717979431152,0.0538974851369857,0.1589984335660068,43793,6353.242366313934,0.9883461594581604,0.040721870958805,0.1881293046142475,0.9853450655937196,0.0510617010295391,0.1625691607394891,43793 -2127.1761713027954,0.542165994644165,4574.218851804733,14309,0,4574.218851804733,0.9845977425575256,0.0522046722471714,0.1706820676001708,43793,6702.318521976471,0.9886543154716492,0.0395914278924465,0.1886557689800052,0.9855687618255616,0.0496367439627647,0.1672704252272741,43793 -2236.548712730408,0.5745348930358887,4814.288403272629,15065,0,4814.288403272629,0.984438955783844,0.0533509105443954,0.1598961203343125,43793,7051.812611579895,0.9882400631904602,0.0407610535621643,0.1768932244139741,0.98537677526474,0.0505119189620018,0.1621677928424066,43793 -2341.950671672821,0.6038038730621338,5054.5061230659485,15819,0,5054.5061230659485,0.9844840168952942,0.0530835837125778,0.1610488174560943,43793,7397.481730699539,0.9884235262870787,0.0403160750865936,0.1800596144394523,0.9854745864868164,0.0501674003899097,0.1589919840343427,43793 -2449.4044041633606,0.6335453987121582,5294.63217496872,16575,0,5294.63217496872,0.9845366477966307,0.0522386245429515,0.1585987983113044,43793,7745.110958576202,0.9884944558143616,0.0400652587413787,0.1946530005520942,0.9855245351791382,0.0493463389575481,0.1648574669725729,43793 -2552.754077911377,0.6638550758361816,5534.787349462509,17338,0,5534.787349462509,0.9845648407936096,0.0527358315885067,0.1628610260560505,43793,8088.665885448456,0.988466501235962,0.0401264503598213,0.1891231957882106,0.9855257272720336,0.049877855926752,0.1615751031520323,43793 -2657.8833651542664,0.6942009925842285,5774.893808841705,18102,0,5774.893808841705,0.9846120476722716,0.0526713393628597,0.1660621285801552,43793,8433.951842308044,0.9886404275894164,0.0396051965653896,0.2016534385517859,0.985596776008606,0.0498393774032592,0.1660370317504795,43793 -2764.619640350342,0.7254533767700195,6014.983092546463,18858,0,6014.983092546463,0.9844696521759032,0.0534532889723777,0.1569761173961862,43793,8780.828848361969,0.9884719252586364,0.0400543436408042,0.1856856220138234,0.9854221940040588,0.0504087135195732,0.1566981421152463,43793 -2866.300521612168,0.7550556659698486,6255.108852148056,19617,0,6255.108852148056,0.9845610857009888,0.0529608465731143,0.1662060151744309,43793,9122.684712648392,0.9885842204093932,0.039523422718048,0.1949042477951265,0.9855241179466248,0.0500289462506771,0.1680505834962828,43793 -2975.4721944332123,0.787273645401001,6495.332993984222,20370,0,6495.332993984222,0.984565258026123,0.0525690093636512,0.1609021073804863,43793,9472.132626056671,0.9885866641998292,0.0403407551348209,0.190552363333073,0.985431969165802,0.0497220121324062,0.1618298861649432,43793 -3077.280902385712,0.8171370029449463,6735.384921073914,21128,0,6735.384921073914,0.984618365764618,0.0525290071964263,0.1673807543089017,43793,9814.043020009996,0.9888412952423096,0.0390158295631408,0.1984423527563512,0.9855196475982666,0.0496362783014774,0.1704775937419216,43793 -3185.17812037468,0.8489246368408203,6975.345520019531,21888,0,6975.345520019531,0.9847084879875184,0.052076943218708,0.1687030523777823,43793,10161.952837228777,0.988599419593811,0.0396040715277195,0.1955595368230457,0.9856069087982178,0.0493299253284931,0.1709071697455449,43793 -3291.6833214759827,0.8813366889953613,7215.383780956268,22640,0,7215.383780956268,0.9846465587615968,0.0523969121277332,0.1618779776619127,43793,10508.548904895782,0.9885961413383484,0.0399170108139514,0.19312791976085,0.9854997396469116,0.0498074442148208,0.1646969130909154,43793 -3397.201486110688,0.9188354015350342,7455.464512586594,23395,0,7455.464512586594,0.9845758080482484,0.052251573652029,0.1668913197902987,43793,10854.205847978592,0.9885261058807372,0.0398454070091247,0.2045054501347917,0.9855204820632936,0.0496206805109977,0.1696779981251334,43793 -3500.685555458069,0.9494316577911376,7695.480072259903,24150,0,7695.480072259903,0.9847046732902528,0.0524015016853809,0.1677331479726042,43793,11197.75633430481,0.9885877370834352,0.0395864844322204,0.198987881596415,0.9856178760528564,0.0494555197656154,0.1701272548379744,43793 -3607.390516757965,0.980963945388794,7935.451113462448,24895,0,7935.451113462448,0.9846591949462892,0.0524363219738006,0.1622380428695896,43793,11544.484077215197,0.9885725378990172,0.0396206192672252,0.1884983398828092,0.985593557357788,0.0495513565838336,0.1638078991555382,43793 -3713.474276781082,1.012589931488037,8175.517338037491,25646,0,8175.517338037491,0.9842435121536256,0.0553393326699733,0.14969740247395,43793,11890.685946941376,0.9884662628173828,0.0406963936984539,0.1889124579895067,0.9852277636528016,0.0517746023833751,0.1620684550428439,43793 -3818.364675998688,1.0460069179534912,8415.550181627274,26396,0,8415.550181627274,0.9846705794334412,0.0522284172475338,0.1665207318270333,43793,12235.662456035614,0.9886945486068726,0.0391205176711082,0.2076740135078882,0.9855862259864808,0.0494204312562942,0.1699061284238353,43793 -3924.174050092697,1.0775668621063232,8655.730852127075,27156,0,8655.730852127075,0.9845792055130004,0.0527769736945629,0.1750600617505908,43793,12581.70412158966,0.9885122776031494,0.0393948778510093,0.2045708442516365,0.985469937324524,0.0498745180666446,0.1768885285711521,43793 -4030.839526414871,1.110039234161377,8895.984867811203,27913,0,8895.984867811203,0.9831374287605286,0.0604283660650253,0.1038184794279353,43793,12928.67640185356,0.9870010018348694,0.0472155846655368,0.1103796881982437,0.9840448498725892,0.0574222430586814,0.0985057329111871,43793 -4136.737054347992,1.1446597576141355,9135.996017217636,28671,0,9135.996017217636,0.9846815466880798,0.052851814776659,0.1664202583307426,43793,13274.639439105988,0.98862624168396,0.0392830222845077,0.2086000109865182,0.9856077432632446,0.0498518608510494,0.1624308171851189,43793 -4240.01061463356,1.1763203144073486,9376.22216939926,29423,0,9376.22216939926,0.9847674369812012,0.0517728812992572,0.1757862534867406,43793,13618.191462039948,0.9888563752174376,0.0386092141270637,0.1952220106581476,0.9856755137443542,0.0489604957401752,0.1736647414079732,43793 -4344.998346328735,1.209303617477417,9616.297856807709,30177,0,9616.297856807709,0.9847522974014282,0.0520043931901454,0.1742552044543923,43793,13963.307752609251,0.9887091517448424,0.039166934788227,0.2039721153390068,0.9856755137443542,0.0492084473371505,0.1739264363776995,43793 -4446.108120918274,1.2410881519317627,9856.373237371445,30930,0,9856.373237371445,0.98477041721344,0.0525437444448471,0.1694027176352054,43793,14304.544981956482,0.9886908531188964,0.0390931963920593,0.1994932528086023,0.9856739044189452,0.0495341904461383,0.1778154397527345,43793 -4550.519570350647,1.2745091915130615,10096.422625780106,31693,0,10096.422625780106,0.9847164750099182,0.0521720796823501,0.1737591675951564,43793,14649.05891919136,0.9885149598121644,0.0396221540868282,0.194215934757525,0.9857165217399596,0.0492618419229984,0.1747485124816603,43793 -4657.76614189148,1.3084368705749512,10336.437881469728,32452,0,10336.437881469728,0.9845787882804872,0.0522355996072292,0.1658119645684935,43793,14996.374430418016,0.988632082939148,0.0392586700618267,0.204369646453474,0.9855963587760924,0.0492731444537639,0.175708370852377,43793 -4760.532192707062,1.342067003250122,10576.67140340805,33217,0,10576.67140340805,0.9848264455795288,0.0522915534675121,0.1770464495206037,43793,15339.427662611008,0.9888612031936646,0.0386213921010494,0.2074459373809987,0.9858188033103944,0.0493051446974277,0.1833785917632685,43793 -4869.175209760666,1.3774352073669434,10816.757288694382,33968,0,10816.757288694382,0.9848175644874572,0.0521162152290344,0.1782565607649952,43793,15688.211569309236,0.988823652267456,0.0387829691171646,0.2097071365273531,0.9857664704322816,0.0491368472576141,0.1783703510066827,43793 -4976.963281154633,1.411080837249756,11056.986756563188,34725,0,11056.986756563188,0.9847207069396972,0.0518995597958564,0.1696407534624433,43793,16036.28269791603,0.9889429807662964,0.0380373783409595,0.2179248873481545,0.9857100248336792,0.0489446595311164,0.171346110199196,43793 -5083.223834276199,1.4434118270874023,11297.105067253113,35479,0,11297.105067253113,0.9847089052200316,0.0522429086267948,0.1731621813997392,43793,16382.713775157928,0.9887934327125548,0.0383532419800758,0.2175011480776211,0.985688328742981,0.0491277314722538,0.1822192971931874,43793 -5188.498993873596,1.47762131690979,11537.17705130577,36232,0,11537.17705130577,0.984761118888855,0.0518080927431583,0.1681800602295691,43793,16728.115339756012,0.9889160990715028,0.038255613297224,0.2174486785456183,0.9857396483421326,0.0487405806779861,0.1753565414250365,43793 -5292.71022772789,1.513559103012085,11777.208164453506,36979,0,11777.208164453506,0.9847059845924376,0.0520849637687206,0.1709051116411258,43793,17072.413522958755,0.9888588786125184,0.0385875515639781,0.2114442024080533,0.9856597185134888,0.0490989945828914,0.173963467929486,43793 -5399.546529531479,1.54679536819458,12017.201822519302,37734,0,12017.201822519302,0.9846655130386353,0.0519244857132434,0.168482360769061,43793,17419.296885967255,0.988965630531311,0.0381307713687419,0.2163211936202405,0.9856706857681274,0.0491302870213985,0.1755820478866468,43793 -5502.337423086166,1.5815489292144775,12257.161703586578,38494,0,12257.161703586578,0.9848403334617616,0.0511401407420635,0.1729110865622203,43793,17762.10279250145,0.9889072179794312,0.0382791757583618,0.2058152897816714,0.9858139753341676,0.0483922623097896,0.1763234303719536,43793 -5607.583341121674,1.6165733337402344,12497.155459403992,39250,0,12497.155459403992,0.9843534231185912,0.0526152662932872,0.1698303266528774,43793,18107.39728331566,0.9886656999588012,0.0391177199780941,0.209000849892865,0.9853398203849792,0.0495943687856197,0.1764142580300723,43793 -5712.008977174759,1.6503448486328125,12737.230195999146,40008,0,12737.230195999146,0.9848032593727112,0.0515053868293762,0.176415089457208,43793,18451.95160317421,0.988823652267456,0.0383989922702312,0.2118176751457389,0.9857344031333924,0.048629205673933,0.1850538860129511,43793 -5814.817209243774,1.6850457191467283,12977.429183483124,40765,0,12977.429183483124,0.98490309715271,0.0514543503522872,0.1773463244063912,43793,18795.01301598549,0.988957405090332,0.0381148569285869,0.2250403524375973,0.9858992099761964,0.0484270341694355,0.1902167060453705,43793 -5918.8470821380615,1.7193021774291992,13217.376702070236,41525,0,13217.376702070236,0.9849595427513124,0.0513292327523231,0.1781156339557913,43793,19139.04503273964,0.9889556169509888,0.0379303097724914,0.2247792695009524,0.9858963489532472,0.0484644919633865,0.1854287582158917,43793 -6024.671246051788,1.7549612522125244,13457.534606933594,42284,0,13457.534606933594,0.9849300384521484,0.0514775849878788,0.1728550651727835,43793,19485.08216929436,0.9891242384910583,0.037462193518877,0.2190261638573074,0.9858837723731996,0.0484282746911048,0.1882398772699011,43793 -6133.084717750549,1.789576292037964,13697.556046247482,43039,0,13697.556046247482,0.9848761558532716,0.0512370876967906,0.1755742926960362,43793,19833.57168912888,0.9891573190689088,0.0371311753988266,0.2342620274547162,0.9858505129814148,0.0484101846814155,0.1830095767989059,43793 -6242.9166431427,1.824268341064453,13937.573343992231,43794,0,13937.573343992231,0.984913170337677,0.0510397516191005,0.1783023438325755,43793,20183.47554087639,0.989040195941925,0.0374348014593124,0.2364169550507474,0.9858590364456176,0.0483596324920654,0.1810546301375452,43793 -6345.184635639191,1.8598694801330569,14177.521633148192,44546,0,14177.521633148192,0.9848727583885192,0.0515077896416187,0.1787041508472491,43793,20525.7474205494,0.9892123341560364,0.0371120907366275,0.233629171166757,0.9858176112174988,0.0486885346472263,0.1865813241001744,43793 -6448.654628753662,1.8947200775146484,14417.702919960022,45299,0,14417.702919960022,0.9849607944488524,0.0515817105770111,0.1814172696895895,43793,20869.454118013386,0.989071786403656,0.0375232920050621,0.2258647965845565,0.9858951568603516,0.0485557839274406,0.1847851065721733,43793 -6558.222561836243,1.9303202629089355,14657.908961057665,46053,0,14657.908961057665,0.9849578142166138,0.0514113120734691,0.1808884317825607,43793,21219.284123420715,0.989142119884491,0.0371794253587722,0.227694452117626,0.9859349131584167,0.0484732091426849,0.1875919367502804,43793 -6659.538984060288,1.972861051559448,14898.099229335783,46810,0,14898.099229335783,0.9849451780319214,0.0510099902749061,0.1808509299838068,43793,21560.853466033936,0.989086389541626,0.0373251177370548,0.2304754300024778,0.9859126210212708,0.0482357218861579,0.1895051070078998,43793 -6767.612953901291,2.0079030990600586,15138.353684902191,47566,0,15138.353684902191,0.9849637150764464,0.0508946739137172,0.1829725058936036,43793,21909.236936807632,0.9891598224639891,0.0371924303472042,0.2341049216560896,0.9859057068824768,0.0480745211243629,0.1944715354910958,43793 -6869.479328393936,2.0448648929595947,15378.298068523409,48326,0,15378.298068523409,0.9850214123725892,0.0514427870512008,0.1861461973192388,43793,22251.10497307777,0.9889408946037292,0.037719901651144,0.2345589341680993,0.985910177230835,0.0486140400171279,0.1934921789183694,43793 -6977.155750751495,2.0811920166015625,15618.53837943077,49084,0,15618.53837943077,0.984955072402954,0.0509009025990963,0.1842454033049275,43793,22599.077934503555,0.9892590641975404,0.0370510518550872,0.2275835327942321,0.9858817458152772,0.0481711849570274,0.1869669971379753,43793 -7082.8976101875305,2.118276596069336,15858.68499994278,49835,0,15858.68499994278,0.9850075244903564,0.0509440377354621,0.1868286006328181,43793,22945.02338528633,0.9892882704734802,0.0367698408663272,0.2388334799873369,0.9859365224838256,0.0483722016215324,0.1890876817019972,43793 -7189.770831346512,2.1541895866394043,16098.916516304016,50593,0,16098.916516304016,0.9849485754966736,0.0503535121679306,0.1815041684128693,43793,23292.18442082405,0.989309787750244,0.0366719216108322,0.2393315454742929,0.9859462976455688,0.0478029549121856,0.184098849675496,43793 -7292.035793304443,2.192917823791504,16339.007791280746,51352,0,16339.007791280746,0.9850766062736512,0.0506192557513713,0.1871973443523216,43793,23634.599209308624,0.9894444346427916,0.0360986664891243,0.2547790329618258,0.9859942197799684,0.0479114614427089,0.1914060906010451,43793 -7394.589677810669,2.2324140071868896,16579.094081163406,52115,0,16579.094081163406,0.9850197434425354,0.0503366254270076,0.189377131888189,43793,23977.299023389816,0.9893057346343994,0.0364396683871746,0.2481928758179206,0.9859434366226196,0.0476787872612476,0.1925508469856677,43793 -7495.518446922302,2.2704098224639893,16819.26360821724,52871,0,16819.26360821724,0.9850972294807434,0.0508808642625808,0.1879493142080678,43793,24318.454954862595,0.9894201755523682,0.0361618362367153,0.2529269382896945,0.985986053943634,0.0480906032025814,0.1971760256055278,43793 -7598.746030807495,2.308499336242676,17059.2628159523,53632,0,17059.2628159523,0.9850951433181764,0.0502948872745037,0.1894639877631705,43793,24661.73969578743,0.9894657731056212,0.035980150103569,0.2576308237827217,0.9860092401504515,0.0478395707905292,0.1956039990538769,43793 -7702.997226715088,2.3462722301483154,17299.21511077881,54393,0,17299.21511077881,0.9847282767295836,0.0514081604778766,0.1770618964335692,43793,25006.00048089028,0.9890407919883728,0.0375233180820941,0.2252914477340575,0.9857092499732972,0.0487162582576274,0.1846760456002189,43793 -7805.111110925674,2.3839142322540283,17539.427534341812,55154,0,17539.427534341812,0.9851566553115844,0.0504313670098781,0.1851077487487467,43793,25348.384285211563,0.9894230365753174,0.0361375026404857,0.2417186113024705,0.9860656261444092,0.0478094816207885,0.1957881328081809,43793 -7908.455640792847,2.4211137294769287,17779.39391064644,55910,0,17779.39391064644,0.985159158706665,0.0499532595276832,0.1888697926384954,43793,25691.751713991165,0.9894699454307556,0.0362928993999958,0.2515200120653674,0.9860270619392396,0.0474446602165699,0.1968637869987704,43793 -8011.882115125656,2.458162784576416,18019.52805519104,56674,0,18019.52805519104,0.9852097034454346,0.0502450168132782,0.1947143161858327,43793,26035.3692150116,0.9894645810127258,0.035822756588459,0.2573926845896402,0.9861395359039308,0.0474042557179927,0.2015896895037112,43793 -8110.851234436035,2.49564528465271,18259.709180355072,57432,0,18259.709180355072,0.985026478767395,0.050025396049022675,0.190945310755514,43793,26374.576673030853,0.9895554780960083,0.03577268868684769,0.2601344773012853,0.9859702587127686,0.04734781011939049,0.19831942626624036,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/measurements.csv deleted file mode 100644 index 84f689b1e..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/measurements.csv +++ /dev/null @@ -1,661 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.181188,0.79912275,,,,,,,,,,,,,,,,, -1,,,0.3887049853801727,0.7994356751441956,0.0230993680346546,0.3926452994346618,0.7974664568901062,0.0248431496041781,43793.0,0.3947480618953705,0.7956756353378296,0.0263663607361259,43793.0,11.89487600326538,116.68095016479492,11.89487600326538,104.78602743148804,0.0,0.0 -100,0.11703073,0.12093872,,,,,,,,,,,,,,,,, -200,0.008400563,0.056903604,,,,,,,,,,,,,,,,, -300,0.007058307,0.05111576,,,,,,,,,,,,,,,,, -400,0.009147398,0.048861403,,,,,,,,,,,,,,,,, -500,0.024820365,0.054282445,,,,,,,,,,,,,,,,, -600,0.015920186,0.049871292,,,,,,,,,,,,,,,,, -700,0.024270125,0.06576804,,,,,,,,,,,,,,,,, -748,,,0.9867178201675416,0.0540571250021457,0.0413288035889094,0.984153687953949,0.0642426088452339,0.0405507054465399,43793.0,0.9831724166870116,0.0675854012370109,0.0414079131255289,43793.0,251.93400812149048,463.8443031311035,251.93400812149048,211.86702752113345,0.023287296295166,0.0 -800,0.0077818,0.046091232,,,,,,,,,,,,,,,,, -900,0.0047898726,0.054992292,,,,,,,,,,,,,,,,, -1000,0.0055767787,0.05686925,,,,,,,,,,,,,,,,, -1100,0.009790724,0.052105032,,,,,,,,,,,,,,,,, -1200,0.0132822,0.047238614,,,,,,,,,,,,,,,,, -1300,0.0086457,0.04663629,,,,,,,,,,,,,,,,, -1400,0.007955853,0.047648836,,,,,,,,,,,,,,,,, -1500,,,0.9868547916412354,0.0534534752368927,0.0404074919278545,0.984169065952301,0.0638072565197944,0.0401475249861516,43793.0,0.9832056760787964,0.0670676082372665,0.041944288818654,43793.0,491.9445328712464,807.0779266357422,491.9445328712464,315.04234194755554,0.0506999492645263,0.0 -1500,0.014615547,0.061812956,,,,,,,,,,,,,,,,, -1600,0.009512017,0.057269324,,,,,,,,,,,,,,,,, -1700,0.008002428,0.053574752,,,,,,,,,,,,,,,,, -1800,0.018469367,0.056763317,,,,,,,,,,,,,,,,, -1900,0.006810995,0.05034717,,,,,,,,,,,,,,,,, -2000,0.010300093,0.058726814,,,,,,,,,,,,,,,,, -2100,0.008668623,0.050912216,,,,,,,,,,,,,,,,, -2200,0.0130805755,0.049892657,,,,,,,,,,,,,,,,, -2255,,,0.9868224263191224,0.0518390424549579,0.0526378552763878,0.9842417240142822,0.0621321313083171,0.0498430073024989,43793.0,0.983265459537506,0.0653808563947677,0.05076436693773,43793.0,731.971120595932,1154.2588243484497,731.971120595932,422.1489651203156,0.07857346534729,0.0 -2300,0.01005689,0.05697137,,,,,,,,,,,,,,,,, -2400,0.008230871,0.05046413,,,,,,,,,,,,,,,,, -2500,0.007200115,0.04840285,,,,,,,,,,,,,,,,, -2600,0.008549281,0.05511635,,,,,,,,,,,,,,,,, -2700,0.01340149,0.053353474,,,,,,,,,,,,,,,,, -2800,0.004164778,0.047755703,,,,,,,,,,,,,,,,, -2900,0.01574226,0.052425135,,,,,,,,,,,,,,,,, -3000,0.0071828007,0.052065935,,,,,,,,,,,,,,,,, -3012,,,0.9869569540023804,0.0495302826166152,0.0698423122171888,0.98429936170578,0.0596925765275955,0.0653821551267546,43793.0,0.9833332896232604,0.0633403435349464,0.0663088386050019,43793.0,972.108246088028,1500.8414223194122,972.108246088028,528.547721862793,0.1052675247192382,0.0 -3100,0.010308714,0.04816851,,,,,,,,,,,,,,,,, -3200,0.009694669,0.04928027,,,,,,,,,,,,,,,,, -3300,0.020849679,0.041689638,,,,,,,,,,,,,,,,, -3400,0.011961268,0.045448743,,,,,,,,,,,,,,,,, -3500,0.013705427,0.046822477,,,,,,,,,,,,,,,,, -3600,0.013958787,0.046837997,,,,,,,,,,,,,,,,, -3700,0.02369793,0.043525256,,,,,,,,,,,,,,,,, -3768,,,0.9871721863746644,0.0480428338050842,0.0928386623424184,0.9844650030136108,0.0578740499913692,0.0906684969170909,43793.0,0.9835565090179444,0.0613648705184459,0.0921200282411233,43793.0,1212.2337291240692,1849.313452243805,1212.2337291240692,636.8478765487671,0.1316099166870117,0.0 -3800,0.0069086775,0.04873315,,,,,,,,,,,,,,,,, -3900,0.007932282,0.043742903,,,,,,,,,,,,,,,,, -4000,0.057884295,0.04644651,,,,,,,,,,,,,,,,, -4100,0.026205543,0.05141551,,,,,,,,,,,,,,,,, -4200,0.031497225,0.05207849,,,,,,,,,,,,,,,,, -4300,0.01638793,0.042233612,,,,,,,,,,,,,,,,, -4400,0.0303966,0.03855563,,,,,,,,,,,,,,,,, -4500,0.015532244,0.04518826,,,,,,,,,,,,,,,,, -4522,,,0.9874393939971924,0.0453436709940433,0.1245404069794609,0.9847800135612488,0.0548060648143291,0.1111730464072495,43793.0,0.9838467240333556,0.0577796697616577,0.1173304671659813,43793.0,1452.3454518318176,2194.1322169303894,1452.3454518318176,741.5066666603088,0.1595327854156494,0.0 -4600,0.032473717,0.04637864,,,,,,,,,,,,,,,,, -4700,0.02309703,0.044220973,,,,,,,,,,,,,,,,, -4800,0.01709275,0.042405616,,,,,,,,,,,,,,,,, -4900,0.011588146,0.04750985,,,,,,,,,,,,,,,,, -5000,0.056276076,0.04757439,,,,,,,,,,,,,,,,, -5100,0.02645505,0.041980915,,,,,,,,,,,,,,,,, -5200,0.020261763,0.04277195,,,,,,,,,,,,,,,,, -5269,,,0.9877853393554688,0.0436364114284515,0.1349541273898319,0.9847633838653564,0.0536157079041004,0.1228188394742861,43793.0,0.9838109016418456,0.0564559586346149,0.128097338821601,43793.0,1692.4349780082705,2540.757847547531,1692.4349780082705,847.9944367408752,0.1878836154937744,0.0 -5300,0.028921455,0.046095252,,,,,,,,,,,,,,,,, -5400,0.053672485,0.038601372,,,,,,,,,,,,,,,,, -5500,0.04613568,0.048646115,,,,,,,,,,,,,,,,, -5600,0.021845054,0.041501798,,,,,,,,,,,,,,,,, -5700,0.032970656,0.042591807,,,,,,,,,,,,,,,,, -5800,0.030567495,0.040604703,,,,,,,,,,,,,,,,, -5900,0.029096752,0.037926562,,,,,,,,,,,,,,,,, -6000,0.014444614,0.044370376,,,,,,,,,,,,,,,,, -6017,,,0.9877686500549316,0.0430554561316967,0.1524773079411523,0.9850386381149292,0.0525088608264923,0.1389896901816074,43793.0,0.9840855598449708,0.0552940890192985,0.1426776114746375,43793.0,1932.6885635852807,2890.5146567821503,1932.6885635852807,957.4483652114868,0.2167634963989257,0.0 -6100,0.03553502,0.045044474,,,,,,,,,,,,,,,,, -6200,0.028335316,0.03927748,,,,,,,,,,,,,,,,, -6300,0.030868407,0.041787006,,,,,,,,,,,,,,,,, -6400,0.025489463,0.04915225,,,,,,,,,,,,,,,,, -6500,0.021848654,0.0445282,,,,,,,,,,,,,,,,, -6600,0.051704917,0.04119904,,,,,,,,,,,,,,,,, -6700,0.041828234,0.04121943,,,,,,,,,,,,,,,,, -6771,,,0.988117814064026,0.0419328287243843,0.1471271580563905,0.9850714802742004,0.0519345328211784,0.1369009566762386,43793.0,0.984137773513794,0.0547958463430404,0.137266701689431,43793.0,2172.851979732513,3240.1714539527893,2172.851979732513,1066.894276857376,0.2442078590393066,0.0 -6800,0.029019589,0.0405311,,,,,,,,,,,,,,,,, -6900,0.041147545,0.043413855,,,,,,,,,,,,,,,,, -7000,0.022471758,0.044233967,,,,,,,,,,,,,,,,, -7100,0.044702258,0.03836928,,,,,,,,,,,,,,,,, -7200,0.03581504,0.038564514,,,,,,,,,,,,,,,,, -7300,0.06837282,0.04473391,,,,,,,,,,,,,,,,, -7400,0.050240938,0.044092283,,,,,,,,,,,,,,,,, -7500,0.027805066,0.042204577,,,,,,,,,,,,,,,,, -7522,,,0.9878849387168884,0.0425020903348922,0.1590378641956764,0.9852147698402404,0.0517855808138847,0.1396837630676205,43793.0,0.984243094921112,0.05458440259099,0.1388095702206179,43793.0,2412.9996078014374,3587.0254402160645,2412.9996078014374,1173.5531170368197,0.2712841033935547,0.0 -7600,0.029783163,0.042140048,,,,,,,,,,,,,,,,, -7700,0.03308024,0.044533934,,,,,,,,,,,,,,,,, -7800,0.062452313,0.03903328,,,,,,,,,,,,,,,,, -7900,0.022281986,0.040279392,,,,,,,,,,,,,,,,, -8000,0.025801552,0.041263185,,,,,,,,,,,,,,,,, -8100,0.036176603,0.04651108,,,,,,,,,,,,,,,,, -8200,0.029281247,0.037617307,,,,,,,,,,,,,,,,, -8278,,,0.9880563616752625,0.0417524687945842,0.1610822629742694,0.9852578043937684,0.0509598068892955,0.1488274819384134,43793.0,0.9843167662620544,0.0536992959678173,0.1497537212633665,43793.0,2653.148269176483,3936.214086532593,2653.148269176483,1282.5451092720032,0.2990598678588867,0.0 -8300,0.08908327,0.038277853,,,,,,,,,,,,,,,,, -8400,0.040671133,0.03985018,,,,,,,,,,,,,,,,, -8500,0.026590073,0.041341055,,,,,,,,,,,,,,,,, -8600,0.073031954,0.04324941,,,,,,,,,,,,,,,,, -8700,0.0490887,0.04246467,,,,,,,,,,,,,,,,, -8800,0.034988668,0.04498991,,,,,,,,,,,,,,,,, -8900,0.083986424,0.04252296,,,,,,,,,,,,,,,,, -9000,0.04180062,0.044616427,,,,,,,,,,,,,,,,, -9026,,,0.9882237315177916,0.0422638691961765,0.1666925139943225,0.9851611852645874,0.0513968542218208,0.1470027461109916,43793.0,0.9842881560325624,0.053785428404808,0.1467478836927074,43793.0,2893.190643548965,4281.849505901337,2893.190643548965,1388.0890011787417,0.3277096748352051,0.0 -9100,0.041080892,0.04385686,,,,,,,,,,,,,,,,, -9200,0.031456202,0.039208647,,,,,,,,,,,,,,,,, -9300,0.036290254,0.045348573,,,,,,,,,,,,,,,,, -9400,0.05595151,0.0446357,,,,,,,,,,,,,,,,, -9500,0.056044407,0.038116045,,,,,,,,,,,,,,,,, -9600,0.044872653,0.044171162,,,,,,,,,,,,,,,,, -9700,0.05120895,0.043789186,,,,,,,,,,,,,,,,, -9778,,,0.988153874874115,0.0412655286490917,0.1776809481114947,0.9853041172027588,0.0507806502282619,0.1524563100207836,43793.0,0.9843904972076416,0.0536238476634025,0.1550489886621871,43793.0,3133.340534448624,4625.398817539215,3133.340534448624,1491.438364028931,0.3579344749450683,0.0 -9800,0.041496158,0.038168408,,,,,,,,,,,,,,,,, -9900,0.034806028,0.040213775,,,,,,,,,,,,,,,,, -10000,0.06853823,0.04100928,,,,,,,,,,,,,,,,, -10100,0.04662127,0.040416513,,,,,,,,,,,,,,,,, -10200,0.031886965,0.044545624,,,,,,,,,,,,,,,,, -10300,0.046712473,0.038849205,,,,,,,,,,,,,,,,, -10400,0.060666747,0.04364281,,,,,,,,,,,,,,,,, -10500,0.027288806,0.039331652,,,,,,,,,,,,,,,,, -10526,,,0.988370418548584,0.0405784808099269,0.1764253451645966,0.9853061437606812,0.0504616163671016,0.1536274094312301,43793.0,0.9843748807907104,0.0532482601702213,0.1629483331394156,43793.0,3373.36945104599,4973.312300443649,3373.36945104599,1599.2725040912628,0.3884594440460205,0.0 -10600,0.032557916,0.038782597,,,,,,,,,,,,,,,,, -10700,0.12085679,0.039579872,,,,,,,,,,,,,,,,, -10800,0.058605015,0.040776778,,,,,,,,,,,,,,,,, -10900,0.038825814,0.041867536,,,,,,,,,,,,,,,,, -11000,0.08746978,0.043299105,,,,,,,,,,,,,,,,, -11100,0.052248597,0.04487,,,,,,,,,,,,,,,,, -11200,0.0332627,0.03952092,,,,,,,,,,,,,,,,, -11280,,,0.9882692694664,0.0407231338322162,0.1806463704697511,0.9853795766830444,0.0506321750581264,0.1516324973639082,43793.0,0.98440819978714,0.0535706207156181,0.1535308390425825,43793.0,3613.459751367569,5317.750963449478,3613.459751367569,1703.5707485675812,0.4186489582061767,0.0 -11300,0.029733034,0.040656682,,,,,,,,,,,,,,,,, -11400,0.045917258,0.040561456,,,,,,,,,,,,,,,,, -11500,0.043296766,0.040496744,,,,,,,,,,,,,,,,, -11600,0.027271776,0.04104187,,,,,,,,,,,,,,,,, -11700,0.03300098,0.044252984,,,,,,,,,,,,,,,,, -11800,0.04022972,0.036556553,,,,,,,,,,,,,,,,, -11900,0.060981873,0.045459133,,,,,,,,,,,,,,,,, -12000,0.03360352,0.039555795,,,,,,,,,,,,,,,,, -12034,,,0.9883570671081544,0.0406102873384952,0.187897305803684,0.9853954315185548,0.0513092726469039,0.1469485466278225,43793.0,0.9845050573349,0.0542274564504623,0.1533889459399407,43793.0,3853.7374968528734,5664.29166841507,3853.7374968528734,1809.784565925598,0.4482748508453369,0.0 -12100,0.029022543,0.03978595,,,,,,,,,,,,,,,,, -12200,0.02305079,0.038631402,,,,,,,,,,,,,,,,, -12300,0.06270307,0.042699642,,,,,,,,,,,,,,,,, -12400,0.047122136,0.039054707,,,,,,,,,,,,,,,,, -12500,0.029173702,0.041207694,,,,,,,,,,,,,,,,, -12600,0.033207238,0.039256226,,,,,,,,,,,,,,,,, -12700,0.061357114,0.04350246,,,,,,,,,,,,,,,,, -12793,,,0.9884112477302552,0.0403369925916194,0.1881470002900639,0.9853706955909728,0.0505053550004959,0.1582621068851679,43793.0,0.9844494462013244,0.0533479265868663,0.160429288795423,43793.0,4093.757108926773,6008.2779705524445,4093.757108926773,1913.699834108353,0.4794890880584717,0.0 -12800,0.09451413,0.040817574,,,,,,,,,,,,,,,,, -12900,0.022314448,0.038867455,,,,,,,,,,,,,,,,, -13000,0.02392749,0.039197244,,,,,,,,,,,,,,,,, -13100,0.045290336,0.035553254,,,,,,,,,,,,,,,,, -13200,0.040849026,0.038249303,,,,,,,,,,,,,,,,, -13300,0.06960114,0.04147885,,,,,,,,,,,,,,,,, -13400,0.038070194,0.039747268,,,,,,,,,,,,,,,,, -13500,0.041720957,0.04306306,,,,,,,,,,,,,,,,, -13551,,,0.9883461594581604,0.040721870958805,0.1881293046142475,0.9853450655937196,0.0510617010295391,0.1625691607394891,43793.0,0.9844717979431152,0.0538974851369857,0.1589984335660068,43793.0,4334.002388477325,6353.242366313934,4334.002388477325,2018.3670988082888,0.5114123821258545,0.0 -13600,0.030522855,0.043168243,,,,,,,,,,,,,,,,, -13700,0.07089634,0.03691132,,,,,,,,,,,,,,,,, -13800,0.034982312,0.0406771,,,,,,,,,,,,,,,,, -13900,0.083685175,0.048307706,,,,,,,,,,,,,,,,, -14000,0.10016646,0.03996741,,,,,,,,,,,,,,,,, -14100,0.049439907,0.041720778,,,,,,,,,,,,,,,,, -14200,0.056843348,0.04623571,,,,,,,,,,,,,,,,, -14300,0.037486658,0.04204954,,,,,,,,,,,,,,,,, -14309,,,0.9886543154716492,0.0395914278924465,0.1886557689800052,0.9855687618255616,0.0496367439627647,0.1672704252272741,43793.0,0.9845977425575256,0.0522046722471714,0.1706820676001708,43793.0,4574.218851804733,6702.318521976471,4574.218851804733,2127.1761713027954,0.542165994644165,0.0 -14400,0.041319054,0.04471642,,,,,,,,,,,,,,,,, -14500,0.055857945,0.0383445,,,,,,,,,,,,,,,,, -14600,0.036160186,0.042926352,,,,,,,,,,,,,,,,, -14700,0.058831174,0.04299041,,,,,,,,,,,,,,,,, -14800,0.043793194,0.036978427,,,,,,,,,,,,,,,,, -14900,0.060424134,0.04169609,,,,,,,,,,,,,,,,, -15000,0.04798126,0.04070208,,,,,,,,,,,,,,,,, -15065,,,0.9882400631904602,0.0407610535621643,0.1768932244139741,0.98537677526474,0.0505119189620018,0.1621677928424066,43793.0,0.984438955783844,0.0533509105443954,0.1598961203343125,43793.0,4814.288403272629,7051.812611579895,4814.288403272629,2236.548712730408,0.5745348930358887,0.0 -15100,0.02990904,0.03760619,,,,,,,,,,,,,,,,, -15200,0.052875184,0.04095115,,,,,,,,,,,,,,,,, -15300,0.02640373,0.044104442,,,,,,,,,,,,,,,,, -15400,0.033335242,0.03896044,,,,,,,,,,,,,,,,, -15500,0.05872879,0.040257704,,,,,,,,,,,,,,,,, -15600,0.029539939,0.041738275,,,,,,,,,,,,,,,,, -15700,0.053765457,0.043025035,,,,,,,,,,,,,,,,, -15800,0.03935892,0.04103006,,,,,,,,,,,,,,,,, -15819,,,0.9884235262870787,0.0403160750865936,0.1800596144394523,0.9854745864868164,0.0501674003899097,0.1589919840343427,43793.0,0.9844840168952942,0.0530835837125778,0.1610488174560943,43793.0,5054.5061230659485,7397.481730699539,5054.5061230659485,2341.950671672821,0.6038038730621338,0.0 -15900,0.02856349,0.038382646,,,,,,,,,,,,,,,,, -16000,0.030290376,0.0349425,,,,,,,,,,,,,,,,, -16100,0.0674792,0.044828337,,,,,,,,,,,,,,,,, -16200,0.117535494,0.040523734,,,,,,,,,,,,,,,,, -16300,0.05219392,0.037426375,,,,,,,,,,,,,,,,, -16400,0.029212536,0.04371221,,,,,,,,,,,,,,,,, -16500,0.056736615,0.035482273,,,,,,,,,,,,,,,,, -16575,,,0.9884944558143616,0.0400652587413787,0.1946530005520942,0.9855245351791382,0.0493463389575481,0.1648574669725729,43793.0,0.9845366477966307,0.0522386245429515,0.1585987983113044,43793.0,5294.63217496872,7745.110958576202,5294.63217496872,2449.4044041633606,0.6335453987121582,0.0 -16600,0.050069023,0.04097883,,,,,,,,,,,,,,,,, -16700,0.054604933,0.04531509,,,,,,,,,,,,,,,,, -16800,0.046125095,0.03785271,,,,,,,,,,,,,,,,, -16900,0.05212537,0.04119502,,,,,,,,,,,,,,,,, -17000,0.068972394,0.043087874,,,,,,,,,,,,,,,,, -17100,0.041108396,0.03927431,,,,,,,,,,,,,,,,, -17200,0.055347808,0.038465466,,,,,,,,,,,,,,,,, -17300,0.03091043,0.040265054,,,,,,,,,,,,,,,,, -17338,,,0.988466501235962,0.0401264503598213,0.1891231957882106,0.9855257272720336,0.049877855926752,0.1615751031520323,43793.0,0.9845648407936096,0.0527358315885067,0.1628610260560505,43793.0,5534.787349462509,8088.665885448456,5534.787349462509,2552.754077911377,0.6638550758361816,0.0 -17400,0.046137094,0.038009275,,,,,,,,,,,,,,,,, -17500,0.103060655,0.039427366,,,,,,,,,,,,,,,,, -17600,0.022250947,0.038531397,,,,,,,,,,,,,,,,, -17700,0.037642505,0.035016205,,,,,,,,,,,,,,,,, -17800,0.044331152,0.036842868,,,,,,,,,,,,,,,,, -17900,0.054433133,0.038489107,,,,,,,,,,,,,,,,, -18000,0.10183462,0.046227396,,,,,,,,,,,,,,,,, -18100,0.080129586,0.039220963,,,,,,,,,,,,,,,,, -18102,,,0.9886404275894164,0.0396051965653896,0.2016534385517859,0.985596776008606,0.0498393774032592,0.1660370317504795,43793.0,0.9846120476722716,0.0526713393628597,0.1660621285801552,43793.0,5774.893808841705,8433.951842308044,5774.893808841705,2657.8833651542664,0.6942009925842285,0.0 -18200,0.035690475,0.041773066,,,,,,,,,,,,,,,,, -18300,0.05457016,0.04050163,,,,,,,,,,,,,,,,, -18400,0.037163496,0.037930652,,,,,,,,,,,,,,,,, -18500,0.0282006,0.036861278,,,,,,,,,,,,,,,,, -18600,0.059437655,0.04100105,,,,,,,,,,,,,,,,, -18700,0.038889505,0.036497757,,,,,,,,,,,,,,,,, -18800,0.057795156,0.037622172,,,,,,,,,,,,,,,,, -18858,,,0.9884719252586364,0.0400543436408042,0.1856856220138234,0.9854221940040588,0.0504087135195732,0.1566981421152463,43793.0,0.9844696521759032,0.0534532889723777,0.1569761173961862,43793.0,6014.983092546463,8780.828848361969,6014.983092546463,2764.619640350342,0.7254533767700195,0.0 -18900,0.068137355,0.04062159,,,,,,,,,,,,,,,,, -19000,0.04239826,0.040084884,,,,,,,,,,,,,,,,, -19100,0.045895852,0.040861707,,,,,,,,,,,,,,,,, -19200,0.065126374,0.044394966,,,,,,,,,,,,,,,,, -19300,0.025654886,0.0434852,,,,,,,,,,,,,,,,, -19400,0.031419285,0.03899859,,,,,,,,,,,,,,,,, -19500,0.035256173,0.040029764,,,,,,,,,,,,,,,,, -19600,0.05701829,0.04095657,,,,,,,,,,,,,,,,, -19617,,,0.9885842204093932,0.039523422718048,0.1949042477951265,0.9855241179466248,0.0500289462506771,0.1680505834962828,43793.0,0.9845610857009888,0.0529608465731143,0.1662060151744309,43793.0,6255.108852148056,9122.684712648392,6255.108852148056,2866.300521612168,0.7550556659698486,0.0 -19700,0.06391672,0.036653582,,,,,,,,,,,,,,,,, -19800,0.0296462,0.037826505,,,,,,,,,,,,,,,,, -19900,0.025551755,0.039331272,,,,,,,,,,,,,,,,, -20000,0.025012849,0.035805874,,,,,,,,,,,,,,,,, -20100,0.04362743,0.035940725,,,,,,,,,,,,,,,,, -20200,0.045095447,0.03454575,,,,,,,,,,,,,,,,, -20300,0.035285905,0.039336022,,,,,,,,,,,,,,,,, -20370,,,0.9885866641998292,0.0403407551348209,0.190552363333073,0.985431969165802,0.0497220121324062,0.1618298861649432,43793.0,0.984565258026123,0.0525690093636512,0.1609021073804863,43793.0,6495.332993984222,9472.132626056671,6495.332993984222,2975.4721944332123,0.787273645401001,0.0 -20400,0.09305106,0.04089195,,,,,,,,,,,,,,,,, -20500,0.026319934,0.042933203,,,,,,,,,,,,,,,,, -20600,0.102135405,0.03262291,,,,,,,,,,,,,,,,, -20700,0.0354608,0.04271416,,,,,,,,,,,,,,,,, -20800,0.043929096,0.038486194,,,,,,,,,,,,,,,,, -20900,0.12807621,0.03913547,,,,,,,,,,,,,,,,, -21000,0.054280147,0.04136372,,,,,,,,,,,,,,,,, -21100,0.035538115,0.039675128,,,,,,,,,,,,,,,,, -21128,,,0.9888412952423096,0.0390158295631408,0.1984423527563512,0.9855196475982666,0.0496362783014774,0.1704775937419216,43793.0,0.984618365764618,0.0525290071964263,0.1673807543089017,43793.0,6735.384921073914,9814.043020009996,6735.384921073914,3077.280902385712,0.8171370029449463,0.0 -21200,0.054993354,0.039250873,,,,,,,,,,,,,,,,, -21300,0.024489127,0.039003413,,,,,,,,,,,,,,,,, -21400,0.051249042,0.03923173,,,,,,,,,,,,,,,,, -21500,0.10221082,0.04408939,,,,,,,,,,,,,,,,, -21600,0.061161056,0.04058823,,,,,,,,,,,,,,,,, -21700,0.07934924,0.03555709,,,,,,,,,,,,,,,,, -21800,0.06851134,0.041661367,,,,,,,,,,,,,,,,, -21888,,,0.988599419593811,0.0396040715277195,0.1955595368230457,0.9856069087982178,0.0493299253284931,0.1709071697455449,43793.0,0.9847084879875184,0.052076943218708,0.1687030523777823,43793.0,6975.345520019531,10161.952837228777,6975.345520019531,3185.17812037468,0.8489246368408203,0.0 -21900,0.07820616,0.04300594,,,,,,,,,,,,,,,,, -22000,0.08259749,0.04042268,,,,,,,,,,,,,,,,, -22100,0.116805986,0.04240753,,,,,,,,,,,,,,,,, -22200,0.035567045,0.04141854,,,,,,,,,,,,,,,,, -22300,0.044656016,0.041500505,,,,,,,,,,,,,,,,, -22400,0.10166704,0.042098638,,,,,,,,,,,,,,,,, -22500,0.039108448,0.040253002,,,,,,,,,,,,,,,,, -22600,0.03601566,0.040666115,,,,,,,,,,,,,,,,, -22640,,,0.9885961413383484,0.0399170108139514,0.19312791976085,0.9854997396469116,0.0498074442148208,0.1646969130909154,43793.0,0.9846465587615968,0.0523969121277332,0.1618779776619127,43793.0,7215.383780956268,10508.548904895782,7215.383780956268,3291.6833214759827,0.8813366889953613,0.0 -22700,0.045995045,0.036598206,,,,,,,,,,,,,,,,, -22800,0.055868484,0.04036023,,,,,,,,,,,,,,,,, -22900,0.040327925,0.039332923,,,,,,,,,,,,,,,,, -23000,0.09162715,0.038607977,,,,,,,,,,,,,,,,, -23100,0.047371574,0.038910206,,,,,,,,,,,,,,,,, -23200,0.06071579,0.039582837,,,,,,,,,,,,,,,,, -23300,0.034244347,0.040918656,,,,,,,,,,,,,,,,, -23395,,,0.9885261058807372,0.0398454070091247,0.2045054501347917,0.9855204820632936,0.0496206805109977,0.1696779981251334,43793.0,0.9845758080482484,0.052251573652029,0.1668913197902987,43793.0,7455.464512586594,10854.205847978592,7455.464512586594,3397.201486110688,0.9188354015350342,0.0 -23400,0.07604526,0.03581498,,,,,,,,,,,,,,,,, -23500,0.053623777,0.044666275,,,,,,,,,,,,,,,,, -23600,0.053122886,0.03554492,,,,,,,,,,,,,,,,, -23700,0.046750624,0.041346528,,,,,,,,,,,,,,,,, -23800,0.04286584,0.040012754,,,,,,,,,,,,,,,,, -23900,0.06558863,0.039877273,,,,,,,,,,,,,,,,, -24000,0.048435234,0.04071402,,,,,,,,,,,,,,,,, -24100,0.06699298,0.041417908,,,,,,,,,,,,,,,,, -24150,,,0.9885877370834352,0.0395864844322204,0.198987881596415,0.9856178760528564,0.0494555197656154,0.1701272548379744,43793.0,0.9847046732902528,0.0524015016853809,0.1677331479726042,43793.0,7695.480072259903,11197.75633430481,7695.480072259903,3500.685555458069,0.9494316577911376,0.0 -24200,0.04464273,0.039795727,,,,,,,,,,,,,,,,, -24300,0.057955507,0.034196436,,,,,,,,,,,,,,,,, -24400,0.074622095,0.033240404,,,,,,,,,,,,,,,,, -24500,0.055578243,0.038032897,,,,,,,,,,,,,,,,, -24600,0.08114985,0.038096722,,,,,,,,,,,,,,,,, -24700,0.058287222,0.044495534,,,,,,,,,,,,,,,,, -24800,0.048428863,0.033692867,,,,,,,,,,,,,,,,, -24895,,,0.9885725378990172,0.0396206192672252,0.1884983398828092,0.985593557357788,0.0495513565838336,0.1638078991555382,43793.0,0.9846591949462892,0.0524363219738006,0.1622380428695896,43793.0,7935.451113462448,11544.484077215197,7935.451113462448,3607.390516757965,0.980963945388794,0.0 -24900,0.05963468,0.0420189,,,,,,,,,,,,,,,,, -25000,0.05453759,0.03813148,,,,,,,,,,,,,,,,, -25100,0.04131412,0.037194695,,,,,,,,,,,,,,,,, -25200,0.061264098,0.040742874,,,,,,,,,,,,,,,,, -25300,0.036273584,0.03743152,,,,,,,,,,,,,,,,, -25400,0.05397074,0.037899073,,,,,,,,,,,,,,,,, -25500,0.03887081,0.034827687,,,,,,,,,,,,,,,,, -25600,0.03980176,0.039196886,,,,,,,,,,,,,,,,, -25646,,,0.9884662628173828,0.0406963936984539,0.1889124579895067,0.9852277636528016,0.0517746023833751,0.1620684550428439,43793.0,0.9842435121536256,0.0553393326699733,0.14969740247395,43793.0,8175.517338037491,11890.685946941376,8175.517338037491,3713.474276781082,1.012589931488037,0.0 -25700,0.06848494,0.04406819,,,,,,,,,,,,,,,,, -25800,0.0645594,0.041165188,,,,,,,,,,,,,,,,, -25900,0.049572144,0.042834915,,,,,,,,,,,,,,,,, -26000,0.04429525,0.03834771,,,,,,,,,,,,,,,,, -26100,0.05620775,0.044592354,,,,,,,,,,,,,,,,, -26200,0.033965647,0.040521976,,,,,,,,,,,,,,,,, -26300,0.034031384,0.03732816,,,,,,,,,,,,,,,,, -26396,,,0.9886945486068726,0.0391205176711082,0.2076740135078882,0.9855862259864808,0.0494204312562942,0.1699061284238353,43793.0,0.9846705794334412,0.0522284172475338,0.1665207318270333,43793.0,8415.550181627274,12235.662456035614,8415.550181627274,3818.364675998688,1.0460069179534912,0.0 -26400,0.06898129,0.04067807,,,,,,,,,,,,,,,,, -26500,0.043417864,0.037726775,,,,,,,,,,,,,,,,, -26600,0.03815891,0.037853684,,,,,,,,,,,,,,,,, -26700,0.0807708,0.041834675,,,,,,,,,,,,,,,,, -26800,0.040659238,0.042023405,,,,,,,,,,,,,,,,, -26900,0.08843373,0.045593437,,,,,,,,,,,,,,,,, -27000,0.053341813,0.038475174,,,,,,,,,,,,,,,,, -27100,0.057815343,0.041868627,,,,,,,,,,,,,,,,, -27156,,,0.9885122776031494,0.0393948778510093,0.2045708442516365,0.985469937324524,0.0498745180666446,0.1768885285711521,43793.0,0.9845792055130004,0.0527769736945629,0.1750600617505908,43793.0,8655.730852127075,12581.70412158966,8655.730852127075,3924.174050092697,1.0775668621063232,0.0 -27200,0.048329107,0.037262928,,,,,,,,,,,,,,,,, -27300,0.03304073,0.035434566,,,,,,,,,,,,,,,,, -27400,0.043322995,0.041908875,,,,,,,,,,,,,,,,, -27500,0.06414057,0.04370503,,,,,,,,,,,,,,,,, -27600,0.120246574,0.03768408,,,,,,,,,,,,,,,,, -27700,0.046758052,0.04072787,,,,,,,,,,,,,,,,, -27800,0.047443673,0.039688934,,,,,,,,,,,,,,,,, -27900,0.31409508,0.06294956,,,,,,,,,,,,,,,,, -27913,,,0.9870010018348694,0.0472155846655368,0.1103796881982437,0.9840448498725892,0.0574222430586814,0.0985057329111871,43793.0,0.9831374287605286,0.0604283660650253,0.1038184794279353,43793.0,8895.984867811203,12928.67640185356,8895.984867811203,4030.839526414871,1.110039234161377,0.0 -28000,0.06616247,0.040505134,,,,,,,,,,,,,,,,, -28100,0.08419733,0.042615652,,,,,,,,,,,,,,,,, -28200,0.045755323,0.043243635,,,,,,,,,,,,,,,,, -28300,0.05144089,0.040762622,,,,,,,,,,,,,,,,, -28400,0.0415426,0.03756537,,,,,,,,,,,,,,,,, -28500,0.03458117,0.040674802,,,,,,,,,,,,,,,,, -28600,0.038907934,0.03706369,,,,,,,,,,,,,,,,, -28671,,,0.98862624168396,0.0392830222845077,0.2086000109865182,0.9856077432632446,0.0498518608510494,0.1624308171851189,43793.0,0.9846815466880798,0.052851814776659,0.1664202583307426,43793.0,9135.996017217636,13274.639439105988,9135.996017217636,4136.737054347992,1.1446597576141355,0.0 -28700,0.06682213,0.040091936,,,,,,,,,,,,,,,,, -28800,0.067805074,0.039428066,,,,,,,,,,,,,,,,, -28900,0.076561645,0.0347241,,,,,,,,,,,,,,,,, -29000,0.039944656,0.03533345,,,,,,,,,,,,,,,,, -29100,0.049924683,0.0372737,,,,,,,,,,,,,,,,, -29200,0.03897362,0.03557658,,,,,,,,,,,,,,,,, -29300,0.05060842,0.04272787,,,,,,,,,,,,,,,,, -29400,0.03253785,0.035124935,,,,,,,,,,,,,,,,, -29423,,,0.9888563752174376,0.0386092141270637,0.1952220106581476,0.9856755137443542,0.0489604957401752,0.1736647414079732,43793.0,0.9847674369812012,0.0517728812992572,0.1757862534867406,43793.0,9376.22216939926,13618.191462039948,9376.22216939926,4240.01061463356,1.1763203144073486,0.0 -29500,0.09125552,0.038205903,,,,,,,,,,,,,,,,, -29600,0.073692255,0.03851399,,,,,,,,,,,,,,,,, -29700,0.03688989,0.03595013,,,,,,,,,,,,,,,,, -29800,0.061657775,0.035614166,,,,,,,,,,,,,,,,, -29900,0.049871456,0.042207737,,,,,,,,,,,,,,,,, -30000,0.042765163,0.041588936,,,,,,,,,,,,,,,,, -30100,0.076253645,0.041524485,,,,,,,,,,,,,,,,, -30177,,,0.9887091517448424,0.039166934788227,0.2039721153390068,0.9856755137443542,0.0492084473371505,0.1739264363776995,43793.0,0.9847522974014282,0.0520043931901454,0.1742552044543923,43793.0,9616.297856807709,13963.307752609251,9616.297856807709,4344.998346328735,1.209303617477417,0.0 -30200,0.030261714,0.03764597,,,,,,,,,,,,,,,,, -30300,0.09219082,0.037977636,,,,,,,,,,,,,,,,, -30400,0.09005538,0.039119516,,,,,,,,,,,,,,,,, -30500,0.056810338,0.04429206,,,,,,,,,,,,,,,,, -30600,0.058699388,0.034496065,,,,,,,,,,,,,,,,, -30700,0.038731817,0.033831567,,,,,,,,,,,,,,,,, -30800,0.052409276,0.039652757,,,,,,,,,,,,,,,,, -30900,0.075235605,0.04178394,,,,,,,,,,,,,,,,, -30930,,,0.9886908531188964,0.0390931963920593,0.1994932528086023,0.9856739044189452,0.0495341904461383,0.1778154397527345,43793.0,0.98477041721344,0.0525437444448471,0.1694027176352054,43793.0,9856.373237371445,14304.544981956482,9856.373237371445,4446.108120918274,1.2410881519317627,0.0 -31000,0.060562335,0.038921632,,,,,,,,,,,,,,,,, -31100,0.051469397,0.036407027,,,,,,,,,,,,,,,,, -31200,0.057042915,0.04077182,,,,,,,,,,,,,,,,, -31300,0.031845875,0.038340308,,,,,,,,,,,,,,,,, -31400,0.04363073,0.03510689,,,,,,,,,,,,,,,,, -31500,0.045257907,0.03203849,,,,,,,,,,,,,,,,, -31600,0.043237712,0.04051863,,,,,,,,,,,,,,,,, -31693,,,0.9885149598121644,0.0396221540868282,0.194215934757525,0.9857165217399596,0.0492618419229984,0.1747485124816603,43793.0,0.9847164750099182,0.0521720796823501,0.1737591675951564,43793.0,10096.422625780106,14649.05891919136,10096.422625780106,4550.519570350647,1.2745091915130615,0.0 -31700,0.049328513,0.036722347,,,,,,,,,,,,,,,,, -31800,0.06487567,0.044282697,,,,,,,,,,,,,,,,, -31900,0.06692086,0.0406805,,,,,,,,,,,,,,,,, -32000,0.08715615,0.043868244,,,,,,,,,,,,,,,,, -32100,0.057084676,0.033942614,,,,,,,,,,,,,,,,, -32200,0.05576125,0.041195188,,,,,,,,,,,,,,,,, -32300,0.056865133,0.035672616,,,,,,,,,,,,,,,,, -32400,0.07153117,0.03895532,,,,,,,,,,,,,,,,, -32452,,,0.988632082939148,0.0392586700618267,0.204369646453474,0.9855963587760924,0.0492731444537639,0.175708370852377,43793.0,0.9845787882804872,0.0522355996072292,0.1658119645684935,43793.0,10336.437881469728,14996.374430418016,10336.437881469728,4657.76614189148,1.3084368705749512,0.0 -32500,0.06608441,0.03520349,,,,,,,,,,,,,,,,, -32600,0.04855947,0.038118307,,,,,,,,,,,,,,,,, -32700,0.038487386,0.040658597,,,,,,,,,,,,,,,,, -32800,0.05046722,0.04054818,,,,,,,,,,,,,,,,, -32900,0.046322677,0.04385421,,,,,,,,,,,,,,,,, -33000,0.07934231,0.038463715,,,,,,,,,,,,,,,,, -33100,0.087140895,0.041335385,,,,,,,,,,,,,,,,, -33200,0.04257717,0.03581385,,,,,,,,,,,,,,,,, -33217,,,0.9888612031936646,0.0386213921010494,0.2074459373809987,0.9858188033103944,0.0493051446974277,0.1833785917632685,43793.0,0.9848264455795288,0.0522915534675121,0.1770464495206037,43793.0,10576.67140340805,15339.427662611008,10576.67140340805,4760.532192707062,1.342067003250122,0.0 -33300,0.04688944,0.04136968,,,,,,,,,,,,,,,,, -33400,0.066133484,0.04495192,,,,,,,,,,,,,,,,, -33500,0.05721685,0.040432002,,,,,,,,,,,,,,,,, -33600,0.07295847,0.03665052,,,,,,,,,,,,,,,,, -33700,0.05096222,0.041204438,,,,,,,,,,,,,,,,, -33800,0.057523422,0.03266647,,,,,,,,,,,,,,,,, -33900,0.123347566,0.037318386,,,,,,,,,,,,,,,,, -33968,,,0.988823652267456,0.0387829691171646,0.2097071365273531,0.9857664704322816,0.0491368472576141,0.1783703510066827,43793.0,0.9848175644874572,0.0521162152290344,0.1782565607649952,43793.0,10816.757288694382,15688.211569309236,10816.757288694382,4869.175209760666,1.3774352073669434,0.0 -34000,0.064796716,0.042403996,,,,,,,,,,,,,,,,, -34100,0.056608193,0.036287095,,,,,,,,,,,,,,,,, -34200,0.10959189,0.035295974,,,,,,,,,,,,,,,,, -34300,0.0675878,0.037117265,,,,,,,,,,,,,,,,, -34400,0.09761157,0.04511398,,,,,,,,,,,,,,,,, -34500,0.06421666,0.039224528,,,,,,,,,,,,,,,,, -34600,0.04748691,0.039991282,,,,,,,,,,,,,,,,, -34700,0.06254627,0.04289279,,,,,,,,,,,,,,,,, -34725,,,0.9889429807662964,0.0380373783409595,0.2179248873481545,0.9857100248336792,0.0489446595311164,0.171346110199196,43793.0,0.9847207069396972,0.0518995597958564,0.1696407534624433,43793.0,11056.986756563188,16036.28269791603,11056.986756563188,4976.963281154633,1.411080837249756,0.0 -34800,0.043713365,0.04084126,,,,,,,,,,,,,,,,, -34900,0.042401496,0.039820824,,,,,,,,,,,,,,,,, -35000,0.064418785,0.038489416,,,,,,,,,,,,,,,,, -35100,0.047412354,0.042029332,,,,,,,,,,,,,,,,, -35200,0.1021271,0.036432493,,,,,,,,,,,,,,,,, -35300,0.059039354,0.038489077,,,,,,,,,,,,,,,,, -35400,0.04699734,0.037177164,,,,,,,,,,,,,,,,, -35479,,,0.9887934327125548,0.0383532419800758,0.2175011480776211,0.985688328742981,0.0491277314722538,0.1822192971931874,43793.0,0.9847089052200316,0.0522429086267948,0.1731621813997392,43793.0,11297.105067253113,16382.713775157928,11297.105067253113,5083.223834276199,1.4434118270874023,0.0 -35500,0.042791434,0.039906945,,,,,,,,,,,,,,,,, -35600,0.07311728,0.036165703,,,,,,,,,,,,,,,,, -35700,0.0704559,0.036620237,,,,,,,,,,,,,,,,, -35800,0.13432449,0.043021854,,,,,,,,,,,,,,,,, -35900,0.055408906,0.04140858,,,,,,,,,,,,,,,,, -36000,0.11130322,0.038939234,,,,,,,,,,,,,,,,, -36100,0.066051036,0.03614865,,,,,,,,,,,,,,,,, -36200,0.060860574,0.04001861,,,,,,,,,,,,,,,,, -36232,,,0.9889160990715028,0.038255613297224,0.2174486785456183,0.9857396483421326,0.0487405806779861,0.1753565414250365,43793.0,0.984761118888855,0.0518080927431583,0.1681800602295691,43793.0,11537.17705130577,16728.115339756012,11537.17705130577,5188.498993873596,1.47762131690979,0.0 -36300,0.06460517,0.039481312,,,,,,,,,,,,,,,,, -36400,0.045544278,0.039218042,,,,,,,,,,,,,,,,, -36500,0.064692184,0.04074554,,,,,,,,,,,,,,,,, -36600,0.06964855,0.038002387,,,,,,,,,,,,,,,,, -36700,0.06256195,0.037209988,,,,,,,,,,,,,,,,, -36800,0.044288173,0.04124161,,,,,,,,,,,,,,,,, -36900,0.045673963,0.04214897,,,,,,,,,,,,,,,,, -36979,,,0.9888588786125184,0.0385875515639781,0.2114442024080533,0.9856597185134888,0.0490989945828914,0.173963467929486,43793.0,0.9847059845924376,0.0520849637687206,0.1709051116411258,43793.0,11777.208164453506,17072.413522958755,11777.208164453506,5292.71022772789,1.513559103012085,0.0 -37000,0.06735428,0.03875807,,,,,,,,,,,,,,,,, -37100,0.040439803,0.037872896,,,,,,,,,,,,,,,,, -37200,0.054867055,0.032429963,,,,,,,,,,,,,,,,, -37300,0.04761156,0.04069595,,,,,,,,,,,,,,,,, -37400,0.08536161,0.040788285,,,,,,,,,,,,,,,,, -37500,0.053586155,0.035420496,,,,,,,,,,,,,,,,, -37600,0.051695615,0.043282166,,,,,,,,,,,,,,,,, -37700,0.079915166,0.038574,,,,,,,,,,,,,,,,, -37734,,,0.988965630531311,0.0381307713687419,0.2163211936202405,0.9856706857681274,0.0491302870213985,0.1755820478866468,43793.0,0.9846655130386353,0.0519244857132434,0.168482360769061,43793.0,12017.201822519302,17419.296885967255,12017.201822519302,5399.546529531479,1.54679536819458,0.0 -37800,0.043777548,0.03945302,,,,,,,,,,,,,,,,, -37900,0.055122253,0.043368224,,,,,,,,,,,,,,,,, -38000,0.05605896,0.03660482,,,,,,,,,,,,,,,,, -38100,0.08173208,0.0432627,,,,,,,,,,,,,,,,, -38200,0.050745122,0.038277376,,,,,,,,,,,,,,,,, -38300,0.069504045,0.040206615,,,,,,,,,,,,,,,,, -38400,0.051966224,0.03844922,,,,,,,,,,,,,,,,, -38494,,,0.9889072179794312,0.0382791757583618,0.2058152897816714,0.9858139753341676,0.0483922623097896,0.1763234303719536,43793.0,0.9848403334617616,0.0511401407420635,0.1729110865622203,43793.0,12257.161703586578,17762.10279250145,12257.161703586578,5502.337423086166,1.5815489292144775,0.0 -38500,0.049427357,0.043384865,,,,,,,,,,,,,,,,, -38600,0.0478232,0.041710626,,,,,,,,,,,,,,,,, -38700,0.05518548,0.03760816,,,,,,,,,,,,,,,,, -38800,0.06231375,0.037473198,,,,,,,,,,,,,,,,, -38900,0.056103118,0.03630571,,,,,,,,,,,,,,,,, -39000,0.07311313,0.042928766,,,,,,,,,,,,,,,,, -39100,0.047178086,0.0363479,,,,,,,,,,,,,,,,, -39200,0.045192406,0.036422174,,,,,,,,,,,,,,,,, -39250,,,0.9886656999588012,0.0391177199780941,0.209000849892865,0.9853398203849792,0.0495943687856197,0.1764142580300723,43793.0,0.9843534231185912,0.0526152662932872,0.1698303266528774,43793.0,12497.155459403992,18107.39728331566,12497.155459403992,5607.583341121674,1.6165733337402344,0.0 -39300,0.06540109,0.036385406,,,,,,,,,,,,,,,,, -39400,0.047122274,0.04106301,,,,,,,,,,,,,,,,, -39500,0.07171911,0.040592063,,,,,,,,,,,,,,,,, -39600,0.07025954,0.038922817,,,,,,,,,,,,,,,,, -39700,0.049477167,0.035072062,,,,,,,,,,,,,,,,, -39800,0.16318864,0.045028593,,,,,,,,,,,,,,,,, -39900,0.046442013,0.03851647,,,,,,,,,,,,,,,,, -40000,0.0745759,0.04577438,,,,,,,,,,,,,,,,, -40008,,,0.988823652267456,0.0383989922702312,0.2118176751457389,0.9857344031333924,0.048629205673933,0.1850538860129511,43793.0,0.9848032593727112,0.0515053868293762,0.176415089457208,43793.0,12737.230195999146,18451.95160317421,12737.230195999146,5712.008977174759,1.6503448486328125,0.0 -40100,0.095958814,0.03640744,,,,,,,,,,,,,,,,, -40200,0.05177801,0.03771259,,,,,,,,,,,,,,,,, -40300,0.059390318,0.036592502,,,,,,,,,,,,,,,,, -40400,0.07105537,0.04006803,,,,,,,,,,,,,,,,, -40500,0.05232524,0.036469344,,,,,,,,,,,,,,,,, -40600,0.10484053,0.034298737,,,,,,,,,,,,,,,,, -40700,0.057976395,0.037224952,,,,,,,,,,,,,,,,, -40765,,,0.988957405090332,0.0381148569285869,0.2250403524375973,0.9858992099761964,0.0484270341694355,0.1902167060453705,43793.0,0.98490309715271,0.0514543503522872,0.1773463244063912,43793.0,12977.429183483124,18795.01301598549,12977.429183483124,5814.817209243774,1.6850457191467283,0.0 -40800,0.075324856,0.03617942,,,,,,,,,,,,,,,,, -40900,0.05296647,0.040249802,,,,,,,,,,,,,,,,, -41000,0.056467734,0.038797513,,,,,,,,,,,,,,,,, -41100,0.0662855,0.038864188,,,,,,,,,,,,,,,,, -41200,0.058369134,0.03682999,,,,,,,,,,,,,,,,, -41300,0.03997094,0.03401049,,,,,,,,,,,,,,,,, -41400,0.06655545,0.037216663,,,,,,,,,,,,,,,,, -41500,0.0664067,0.044461157,,,,,,,,,,,,,,,,, -41525,,,0.9889556169509888,0.0379303097724914,0.2247792695009524,0.9858963489532472,0.0484644919633865,0.1854287582158917,43793.0,0.9849595427513124,0.0513292327523231,0.1781156339557913,43793.0,13217.376702070236,19139.04503273964,13217.376702070236,5918.8470821380615,1.7193021774291992,0.0 -41600,0.037655264,0.03893754,,,,,,,,,,,,,,,,, -41700,0.12052763,0.037611324,,,,,,,,,,,,,,,,, -41800,0.1054071,0.03454786,,,,,,,,,,,,,,,,, -41900,0.06257117,0.03470516,,,,,,,,,,,,,,,,, -42000,0.052627478,0.037152898,,,,,,,,,,,,,,,,, -42100,0.055538833,0.035377838,,,,,,,,,,,,,,,,, -42200,0.06078399,0.039022695,,,,,,,,,,,,,,,,, -42284,,,0.9891242384910583,0.037462193518877,0.2190261638573074,0.9858837723731996,0.0484282746911048,0.1882398772699011,43793.0,0.9849300384521484,0.0514775849878788,0.1728550651727835,43793.0,13457.534606933594,19485.08216929436,13457.534606933594,6024.671246051788,1.7549612522125244,0.0 -42300,0.07745964,0.041029625,,,,,,,,,,,,,,,,, -42400,0.09731044,0.041164923,,,,,,,,,,,,,,,,, -42500,0.053610552,0.036494557,,,,,,,,,,,,,,,,, -42600,0.06804194,0.038114008,,,,,,,,,,,,,,,,, -42700,0.06826506,0.03730451,,,,,,,,,,,,,,,,, -42800,0.08269555,0.03958811,,,,,,,,,,,,,,,,, -42900,0.06561653,0.036135074,,,,,,,,,,,,,,,,, -43000,0.11625347,0.03588023,,,,,,,,,,,,,,,,, -43039,,,0.9891573190689088,0.0371311753988266,0.2342620274547162,0.9858505129814148,0.0484101846814155,0.1830095767989059,43793.0,0.9848761558532716,0.0512370876967906,0.1755742926960362,43793.0,13697.556046247482,19833.57168912888,13697.556046247482,6133.084717750549,1.789576292037964,0.0 -43100,0.15200809,0.04154525,,,,,,,,,,,,,,,,, -43200,0.17986287,0.040061906,,,,,,,,,,,,,,,,, -43300,0.06696193,0.037099164,,,,,,,,,,,,,,,,, -43400,0.057937514,0.040213566,,,,,,,,,,,,,,,,, -43500,0.07406241,0.03721264,,,,,,,,,,,,,,,,, -43600,0.08042214,0.041519478,,,,,,,,,,,,,,,,, -43700,0.13733636,0.040770665,,,,,,,,,,,,,,,,, -43794,,,0.989040195941925,0.0374348014593124,0.2364169550507474,0.9858590364456176,0.0483596324920654,0.1810546301375452,43793.0,0.984913170337677,0.0510397516191005,0.1783023438325755,43793.0,13937.573343992231,20183.47554087639,13937.573343992231,6242.9166431427,1.824268341064453,0.0 -43800,0.09674201,0.036735166,,,,,,,,,,,,,,,,, -43900,0.08767714,0.03696667,,,,,,,,,,,,,,,,, -44000,0.23112184,0.03894147,,,,,,,,,,,,,,,,, -44100,0.06104265,0.038186993,,,,,,,,,,,,,,,,, -44200,0.113742895,0.03822902,,,,,,,,,,,,,,,,, -44300,0.13894235,0.039531708,,,,,,,,,,,,,,,,, -44400,0.119574696,0.03460856,,,,,,,,,,,,,,,,, -44500,0.10948115,0.035270438,,,,,,,,,,,,,,,,, -44546,,,0.9892123341560364,0.0371120907366275,0.233629171166757,0.9858176112174988,0.0486885346472263,0.1865813241001744,43793.0,0.9848727583885192,0.0515077896416187,0.1787041508472491,43793.0,14177.521633148192,20525.7474205494,14177.521633148192,6345.184635639191,1.8598694801330569,0.0 -44600,0.066353016,0.04299562,,,,,,,,,,,,,,,,, -44700,0.08496302,0.034649484,,,,,,,,,,,,,,,,, -44800,0.07218529,0.03726144,,,,,,,,,,,,,,,,, -44900,0.055184085,0.039386965,,,,,,,,,,,,,,,,, -45000,0.056509867,0.03469642,,,,,,,,,,,,,,,,, -45100,0.072352946,0.03893326,,,,,,,,,,,,,,,,, -45200,0.094079666,0.036553353,,,,,,,,,,,,,,,,, -45299,,,0.989071786403656,0.0375232920050621,0.2258647965845565,0.9858951568603516,0.0485557839274406,0.1847851065721733,43793.0,0.9849607944488524,0.0515817105770111,0.1814172696895895,43793.0,14417.702919960022,20869.454118013386,14417.702919960022,6448.654628753662,1.8947200775146484,0.0 -45300,0.11544393,0.035656124,,,,,,,,,,,,,,,,, -45400,0.06348733,0.036627628,,,,,,,,,,,,,,,,, -45500,0.079438545,0.03659054,,,,,,,,,,,,,,,,, -45600,0.0972612,0.036794614,,,,,,,,,,,,,,,,, -45700,0.091879524,0.039137788,,,,,,,,,,,,,,,,, -45800,0.08850249,0.034703597,,,,,,,,,,,,,,,,, -45900,0.087163985,0.037739493,,,,,,,,,,,,,,,,, -46000,0.047444157,0.039486162,,,,,,,,,,,,,,,,, -46053,,,0.989142119884491,0.0371794253587722,0.227694452117626,0.9859349131584167,0.0484732091426849,0.1875919367502804,43793.0,0.9849578142166138,0.0514113120734691,0.1808884317825607,43793.0,14657.908961057665,21219.284123420715,14657.908961057665,6558.222561836243,1.9303202629089355,0.0 -46100,0.05913173,0.034380287,,,,,,,,,,,,,,,,, -46200,0.05957967,0.03647527,,,,,,,,,,,,,,,,, -46300,0.08622191,0.038339503,,,,,,,,,,,,,,,,, -46400,0.11615519,0.032995254,,,,,,,,,,,,,,,,, -46500,0.15661368,0.0398799,,,,,,,,,,,,,,,,, -46600,0.08322495,0.03827539,,,,,,,,,,,,,,,,, -46700,0.080869585,0.04226951,,,,,,,,,,,,,,,,, -46800,0.07139504,0.040078882,,,,,,,,,,,,,,,,, -46810,,,0.989086389541626,0.0373251177370548,0.2304754300024778,0.9859126210212708,0.0482357218861579,0.1895051070078998,43793.0,0.9849451780319214,0.0510099902749061,0.1808509299838068,43793.0,14898.099229335783,21560.853466033936,14898.099229335783,6659.538984060288,1.972861051559448,0.0 -46900,0.077639334,0.041613385,,,,,,,,,,,,,,,,, -47000,0.084141724,0.03383318,,,,,,,,,,,,,,,,, -47100,0.1438701,0.038067054,,,,,,,,,,,,,,,,, -47200,0.1393525,0.038684744,,,,,,,,,,,,,,,,, -47300,0.08210447,0.04111315,,,,,,,,,,,,,,,,, -47400,0.07066001,0.034509033,,,,,,,,,,,,,,,,, -47500,0.12275029,0.03802591,,,,,,,,,,,,,,,,, -47566,,,0.9891598224639891,0.0371924303472042,0.2341049216560896,0.9859057068824768,0.0480745211243629,0.1944715354910958,43793.0,0.9849637150764464,0.0508946739137172,0.1829725058936036,43793.0,15138.353684902191,21909.236936807632,15138.353684902191,6767.612953901291,2.0079030990600586,0.0 -47600,0.05781228,0.03897435,,,,,,,,,,,,,,,,, -47700,0.07295887,0.040230215,,,,,,,,,,,,,,,,, -47800,0.12374906,0.039914943,,,,,,,,,,,,,,,,, -47900,0.07362344,0.038318463,,,,,,,,,,,,,,,,, -48000,0.07774845,0.034201812,,,,,,,,,,,,,,,,, -48100,0.08599611,0.034506384,,,,,,,,,,,,,,,,, -48200,0.09054938,0.03340847,,,,,,,,,,,,,,,,, -48300,0.14012262,0.033856567,,,,,,,,,,,,,,,,, -48326,,,0.9889408946037292,0.037719901651144,0.2345589341680993,0.985910177230835,0.0486140400171279,0.1934921789183694,43793.0,0.9850214123725892,0.0514427870512008,0.1861461973192388,43793.0,15378.298068523409,22251.10497307777,15378.298068523409,6869.479328393936,2.0448648929595947,0.0 -48400,0.06393995,0.038290374,,,,,,,,,,,,,,,,, -48500,0.10410659,0.039626136,,,,,,,,,,,,,,,,, -48600,0.08070042,0.036992606,,,,,,,,,,,,,,,,, -48700,0.071088485,0.03573489,,,,,,,,,,,,,,,,, -48800,0.12131289,0.04041788,,,,,,,,,,,,,,,,, -48900,0.060316067,0.033509422,,,,,,,,,,,,,,,,, -49000,0.1009564,0.03677513,,,,,,,,,,,,,,,,, -49084,,,0.9892590641975404,0.0370510518550872,0.2275835327942321,0.9858817458152772,0.0481711849570274,0.1869669971379753,43793.0,0.984955072402954,0.0509009025990963,0.1842454033049275,43793.0,15618.53837943077,22599.077934503555,15618.53837943077,6977.155750751495,2.0811920166015625,0.0 -49100,0.06166997,0.0338906,,,,,,,,,,,,,,,,, -49200,0.10530415,0.04110198,,,,,,,,,,,,,,,,, -49300,0.057920348,0.037626557,,,,,,,,,,,,,,,,, -49400,0.06374754,0.037427314,,,,,,,,,,,,,,,,, -49500,0.09462104,0.035905667,,,,,,,,,,,,,,,,, -49600,0.070002496,0.0382088,,,,,,,,,,,,,,,,, -49700,0.15530552,0.04244934,,,,,,,,,,,,,,,,, -49800,0.080833875,0.038750473,,,,,,,,,,,,,,,,, -49835,,,0.9892882704734802,0.0367698408663272,0.2388334799873369,0.9859365224838256,0.0483722016215324,0.1890876817019972,43793.0,0.9850075244903564,0.0509440377354621,0.1868286006328181,43793.0,15858.68499994278,22945.02338528633,15858.68499994278,7082.8976101875305,2.118276596069336,0.0 -49900,0.14852835,0.035764173,,,,,,,,,,,,,,,,, -50000,0.06282359,0.034729734,,,,,,,,,,,,,,,,, -50100,0.11393108,0.03798485,,,,,,,,,,,,,,,,, -50200,0.062359545,0.03339396,,,,,,,,,,,,,,,,, -50300,0.09664855,0.0362438,,,,,,,,,,,,,,,,, -50400,0.06883264,0.035397887,,,,,,,,,,,,,,,,, -50500,0.08464004,0.03563324,,,,,,,,,,,,,,,,, -50593,,,0.989309787750244,0.0366719216108322,0.2393315454742929,0.9859462976455688,0.0478029549121856,0.184098849675496,43793.0,0.9849485754966736,0.0503535121679306,0.1815041684128693,43793.0,16098.916516304016,23292.18442082405,16098.916516304016,7189.770831346512,2.1541895866394043,0.0 -50600,0.090832666,0.037776597,,,,,,,,,,,,,,,,, -50700,0.14931287,0.038431134,,,,,,,,,,,,,,,,, -50800,0.08716349,0.038842294,,,,,,,,,,,,,,,,, -50900,0.08609693,0.038503047,,,,,,,,,,,,,,,,, -51000,0.05202643,0.035666447,,,,,,,,,,,,,,,,, -51100,0.072633326,0.0373282,,,,,,,,,,,,,,,,, -51200,0.08452511,0.037428346,,,,,,,,,,,,,,,,, -51300,0.07470846,0.03727572,,,,,,,,,,,,,,,,, -51352,,,0.9894444346427916,0.0360986664891243,0.2547790329618258,0.9859942197799684,0.0479114614427089,0.1914060906010451,43793.0,0.9850766062736512,0.0506192557513713,0.1871973443523216,43793.0,16339.007791280746,23634.599209308624,16339.007791280746,7292.035793304443,2.192917823791504,0.0 -51400,0.07281806,0.03393299,,,,,,,,,,,,,,,,, -51500,0.057168987,0.03786992,,,,,,,,,,,,,,,,, -51600,0.06819698,0.037311293,,,,,,,,,,,,,,,,, -51700,0.094687596,0.035120804,,,,,,,,,,,,,,,,, -51800,0.15927209,0.034967434,,,,,,,,,,,,,,,,, -51900,0.10114619,0.03913641,,,,,,,,,,,,,,,,, -52000,0.073732965,0.03886255,,,,,,,,,,,,,,,,, -52100,0.15924251,0.039832205,,,,,,,,,,,,,,,,, -52115,,,0.9893057346343994,0.0364396683871746,0.2481928758179206,0.9859434366226196,0.0476787872612476,0.1925508469856677,43793.0,0.9850197434425354,0.0503366254270076,0.189377131888189,43793.0,16579.094081163406,23977.299023389816,16579.094081163406,7394.589677810669,2.2324140071868896,0.0 -52200,0.13814357,0.034593645,,,,,,,,,,,,,,,,, -52300,0.080929264,0.036043257,,,,,,,,,,,,,,,,, -52400,0.123121046,0.036154937,,,,,,,,,,,,,,,,, -52500,0.07252123,0.040032152,,,,,,,,,,,,,,,,, -52600,0.10360471,0.037230242,,,,,,,,,,,,,,,,, -52700,0.1302075,0.03594606,,,,,,,,,,,,,,,,, -52800,0.14360578,0.035803452,,,,,,,,,,,,,,,,, -52871,,,0.9894201755523682,0.0361618362367153,0.2529269382896945,0.985986053943634,0.0480906032025814,0.1971760256055278,43793.0,0.9850972294807434,0.0508808642625808,0.1879493142080678,43793.0,16819.26360821724,24318.454954862595,16819.26360821724,7495.518446922302,2.2704098224639893,0.0 -52900,0.1293116,0.042605616,,,,,,,,,,,,,,,,, -53000,0.09154661,0.036440987,,,,,,,,,,,,,,,,, -53100,0.1725326,0.037017927,,,,,,,,,,,,,,,,, -53200,0.09942765,0.0381425,,,,,,,,,,,,,,,,, -53300,0.12188124,0.044787787,,,,,,,,,,,,,,,,, -53400,0.083840474,0.034964446,,,,,,,,,,,,,,,,, -53500,0.12563868,0.0363191,,,,,,,,,,,,,,,,, -53600,0.12121073,0.03757208,,,,,,,,,,,,,,,,, -53632,,,0.9894657731056212,0.035980150103569,0.2576308237827217,0.9860092401504515,0.0478395707905292,0.1956039990538769,43793.0,0.9850951433181764,0.0502948872745037,0.1894639877631705,43793.0,17059.2628159523,24661.73969578743,17059.2628159523,7598.746030807495,2.308499336242676,0.0 -53700,0.21709739,0.03801772,,,,,,,,,,,,,,,,, -53800,0.13581292,0.036691222,,,,,,,,,,,,,,,,, -53900,0.13847333,0.0359949,,,,,,,,,,,,,,,,, -54000,0.09088436,0.03822893,,,,,,,,,,,,,,,,, -54100,0.10943844,0.036137763,,,,,,,,,,,,,,,,, -54200,0.06747403,0.03649277,,,,,,,,,,,,,,,,, -54300,0.09567677,0.035700325,,,,,,,,,,,,,,,,, -54393,,,0.9890407919883728,0.0375233180820941,0.2252914477340575,0.9857092499732972,0.0487162582576274,0.1846760456002189,43793.0,0.9847282767295836,0.0514081604778766,0.1770618964335692,43793.0,17299.21511077881,25006.00048089028,17299.21511077881,7702.997226715088,2.3462722301483154,0.0 -54400,0.117079325,0.036931287,,,,,,,,,,,,,,,,, -54500,0.101508394,0.036002144,,,,,,,,,,,,,,,,, -54600,0.1105307,0.037291545,,,,,,,,,,,,,,,,, -54700,0.10756579,0.039155677,,,,,,,,,,,,,,,,, -54800,0.081107065,0.035550468,,,,,,,,,,,,,,,,, -54900,0.10958632,0.03661165,,,,,,,,,,,,,,,,, -55000,0.0861155,0.034811933,,,,,,,,,,,,,,,,, -55100,0.1183039,0.03694345,,,,,,,,,,,,,,,,, -55154,,,0.9894230365753174,0.0361375026404857,0.2417186113024705,0.9860656261444092,0.0478094816207885,0.1957881328081809,43793.0,0.9851566553115844,0.0504313670098781,0.1851077487487467,43793.0,17539.427534341812,25348.384285211563,17539.427534341812,7805.111110925674,2.3839142322540283,0.0 -55200,0.09721071,0.034920428,,,,,,,,,,,,,,,,, -55300,0.11496783,0.035774656,,,,,,,,,,,,,,,,, -55400,0.106783226,0.037434626,,,,,,,,,,,,,,,,, -55500,0.078784116,0.033444326,,,,,,,,,,,,,,,,, -55600,0.14808135,0.035168603,,,,,,,,,,,,,,,,, -55700,0.10916709,0.036861483,,,,,,,,,,,,,,,,, -55800,0.083489,0.03534178,,,,,,,,,,,,,,,,, -55900,0.09657458,0.034870777,,,,,,,,,,,,,,,,, -55910,,,0.9894699454307556,0.0362928993999958,0.2515200120653674,0.9860270619392396,0.0474446602165699,0.1968637869987704,43793.0,0.985159158706665,0.0499532595276832,0.1888697926384954,43793.0,17779.39391064644,25691.751713991165,17779.39391064644,7908.455640792847,2.4211137294769287,0.0 -56000,0.07707204,0.032609493,,,,,,,,,,,,,,,,, -56100,0.083523154,0.039691944,,,,,,,,,,,,,,,,, -56200,0.086301185,0.03809122,,,,,,,,,,,,,,,,, -56300,0.08826629,0.03635545,,,,,,,,,,,,,,,,, -56400,0.14586784,0.034606718,,,,,,,,,,,,,,,,, -56500,0.08706496,0.03817263,,,,,,,,,,,,,,,,, -56600,0.15097332,0.031881016,,,,,,,,,,,,,,,,, -56674,,,0.9894645810127258,0.035822756588459,0.2573926845896402,0.9861395359039308,0.0474042557179927,0.2015896895037112,43793.0,0.9852097034454346,0.0502450168132782,0.1947143161858327,43793.0,18019.52805519104,26035.3692150116,18019.52805519104,8011.882115125656,2.458162784576416,0.0 -56700,0.08355874,0.035674475,,,,,,,,,,,,,,,,, -56800,0.13446526,0.037721302,,,,,,,,,,,,,,,,, -56900,0.12490987,0.033020157,,,,,,,,,,,,,,,,, -57000,0.105766304,0.037390184,,,,,,,,,,,,,,,,, -57100,0.17854206,0.035462394,,,,,,,,,,,,,,,,, -57200,0.08757709,0.037407137,,,,,,,,,,,,,,,,, -57300,0.14861135,0.035067335,,,,,,,,,,,,,,,,, -57400,0.15805668,0.036424678,,,,,,,,,,,,,,,,, -57432,,,0.9895554780960084,0.0357726886868476,0.2601344773012853,0.9859702587127686,0.0473478101193904,0.1983194262662403,43793.0,0.985026478767395,0.0500253960490226,0.190945310755514,43793.0,18259.709180355072,26374.576673030853,18259.709180355072,8110.851234436035,2.49564528465271,0.0 -57500,0.08274938,0.035564028,,,,,,,,,,,,,,,,, -57600,0.0781089,0.036018685,,,,,,,,,,,,,,,,, -57700,0.14327052,0.04038669,,,,,,,,,,,,,,,,, -57800,0.08737076,0.033491753,,,,,,,,,,,,,,,,, -57900,0.1160375,0.031708475,,,,,,,,,,,,,,,,, -58000,0.20722319,0.0351513,,,,,,,,,,,,,,,,, -58100,0.14265268,0.034435317,,,,,,,,,,,,,,,,, -58119,,,,,,,,,,,,,,18477.048362255096,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 4ca862bb5..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,78 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -107.96615147590636,0.0,13.266764402389526,1,0,13.266764402389526,0.3947482705116272,0.7956756353378296,0.0263795780531647,43793,121.23296904563904,0.3886732161045074,0.7993758320808411,0.0241892252979955,0.3926452994346618,0.7974664568901062,0.0248380825643216,43793 -207.39866089820865,0.019460916519165,253.24348163604736,760,0,253.24348163604736,0.9831896424293518,0.0638464763760566,0.0554174900674349,43793,460.68167448043823,0.9868772029876708,0.0513061732053756,0.0554290952326774,0.9841986894607544,0.0606113150715827,0.0533419380277251,43793 -313.2266294956207,0.0457274913787841,493.3325746059418,1513,0,493.3325746059418,0.9836639165878296,0.0584327653050422,0.1038389597988644,43793,806.6452507972717,0.9872434139251708,0.0468114838004112,0.1083326564721908,0.9846939444541932,0.0553979873657226,0.1049047631685208,43793 -416.6431694030762,0.0753901004791259,733.5270512104034,2267,0,733.5270512104034,0.984085977077484,0.0557627975940704,0.1352206329712963,43793,1150.3065140247345,0.9878523945808412,0.0432136245071888,0.1485056292047553,0.9850824475288392,0.052856158465147,0.1401113333246038,43793 -521.5821809768677,0.1021766662597656,973.7290751934052,3027,0,973.7290751934052,0.984433889389038,0.0533240772783756,0.1593613143548954,43793,1495.4947953224182,0.9883006811141968,0.0409547798335552,0.1729483531246934,0.9853901267051696,0.0504433996975421,0.1634846763662925,43793 -626.0516443252563,0.1292841434478759,1213.7164585590365,3783,0,1213.7164585590365,0.984683632850647,0.0518145784735679,0.1757199431759489,43793,1839.9989259243007,0.98840993642807,0.0401165075600147,0.1947053002717981,0.9855772852897644,0.0491400696337223,0.1763400668451602,43793 -727.9304811954498,0.1579248905181884,1453.7154524326324,4538,0,1453.7154524326324,0.9848495721817015,0.0511390194296836,0.1969125162851597,43793,2181.925208091736,0.9886505007743835,0.0387487187981605,0.212599821809528,0.9857733845710754,0.048427578061819,0.1953902653325974,43793 -836.9415402412415,0.1859288215637207,1693.9603996276855,5291,0,1693.9603996276855,0.9849818348884584,0.0499075688421726,0.2096289671249377,43793,2531.229151964188,0.9889928698539734,0.0374399162828922,0.2475354839814961,0.985901653766632,0.0472876690328121,0.2130693687389712,43793 -945.0108182430268,0.2134826183319091,1933.9973559379573,6049,0,1933.9973559379573,0.9852286577224731,0.0498530082404613,0.2094072852101087,43793,2879.383031368256,0.9890961050987244,0.0369906276464462,0.2587219490424934,0.9860299229621888,0.0470738820731639,0.214085916316131,43793 -1048.5638234615326,0.2411527633666992,2174.019069910049,6808,0,2174.019069910049,0.9852871894836426,0.0484927594661712,0.2190290022927095,43793,3223.005018234253,0.9892982840538024,0.0359948612749576,0.2853854649351488,0.9862214922904968,0.0460150241851806,0.2256096644268872,43793 -1151.3277099132538,0.2674412727355957,2414.260739564896,7566,0,2414.260739564896,0.985509157180786,0.0481332577764987,0.2383437345658389,43793,3566.0569252967834,0.9897069931030272,0.0347049161791801,0.3021627549946922,0.986340880393982,0.045676652342081,0.2318702043338875,43793 -1257.461489200592,0.2958061695098877,2654.523429632187,8324,0,2654.523429632187,0.9855584502220154,0.0483413860201835,0.2369385751055225,43793,3912.501790523529,0.989915370941162,0.0336551368236541,0.3202873999988332,0.9864200353622437,0.0456519722938537,0.2381516566782394,43793 -1358.925484418869,0.3236792087554931,2894.683405160904,9083,0,2894.683405160904,0.985533595085144,0.0476725287735462,0.2434450341870531,43793,4254.173964262009,0.9900816082954408,0.0332052633166313,0.3405124811740702,0.9864683151245116,0.0451069064438343,0.2389902367890833,43793 -1461.5836379528046,0.3566343784332275,3134.751321554184,9838,0,3134.751321554184,0.9856464862823486,0.0476103201508522,0.2401196061095888,43793,4596.953059434891,0.9902423024177552,0.0324068292975425,0.3589596793839536,0.9864655137062072,0.0449420176446437,0.2456156073082285,43793 -1564.5984835624697,0.3857431411743164,3374.937530517578,10591,0,3374.937530517578,0.9858199954032898,0.0472011156380176,0.2493607192724526,43793,4940.203378915787,0.9902772903442384,0.0322590358555316,0.3610310912047775,0.9866834878921508,0.0445441342890262,0.2554394149345458,43793 -1666.2861070632937,0.4149281978607178,3615.076661109924,11347,0,3615.076661109924,0.9857147336006165,0.0472510457038879,0.2535599677668376,43793,5282.0795221328735,0.9904393553733826,0.0319894962012767,0.3553863849024475,0.9865807890892028,0.0445267893373966,0.2493957465244841,43793 -1772.8010022640228,0.4449071884155273,3855.316538333893,12099,0,3855.316538333893,0.9858659505844116,0.0469923354685306,0.2538715301046184,43793,5628.884557723999,0.99038964509964,0.0319667942821979,0.362734710936444,0.9866493940353394,0.0443710908293724,0.2522052206234418,43793 -1875.968015909195,0.4730713367462158,4095.371396303177,12859,0,4095.371396303177,0.985922396183014,0.0469017066061496,0.2577900264790435,43793,5972.155205726624,0.9905506372451782,0.0315115861594677,0.3814257408336702,0.9867147207260132,0.0444487147033214,0.2606091547759409,43793 -1983.082782745361,0.5029690265655518,4335.552411794663,13612,0,4335.552411794663,0.9859269857406616,0.0467573627829551,0.2620023400737273,43793,6319.501272678375,0.9904999732971193,0.0312000997364521,0.3974700571924796,0.986764669418335,0.0441124215722084,0.2609337751096847,43793 -2086.727002859116,0.5339128971099854,4575.777185678482,14362,0,4575.777185678482,0.985931634902954,0.0466216616332531,0.2593869756363092,43793,6663.42175078392,0.9905785322189332,0.0309624504297971,0.3904840151507778,0.9867537021636964,0.0440865717828273,0.2646759096238124,43793 -2188.0321526527405,0.5664629936218262,4816.017210483551,15115,0,4816.017210483551,0.9859648942947388,0.0465397238731384,0.2624386163105034,43793,7005.01956653595,0.9908715486526488,0.0299752708524465,0.4078574356020376,0.9867532849311828,0.0438701473176479,0.2704771739922652,43793 -2292.2479150295258,0.595801830291748,5056.221124887466,15868,0,5056.221124887466,0.9860209226608276,0.047208122909069,0.255676770999339,43793,7349.488345146179,0.990837574005127,0.0299731157720088,0.4219413948334143,0.9869092106819152,0.0443227887153625,0.2641223882343574,43793 -2394.0149455070496,0.6254878044128418,5296.290515899658,16631,0,5296.290515899658,0.9859219789505004,0.0466020852327346,0.2708772382010023,43793,7691.375034809113,0.9911161065101624,0.029060611501336,0.4442643550924829,0.9867061972618104,0.0439366847276687,0.2679191421744507,43793 -2496.5304474830627,0.655855655670166,5536.512343883514,17393,0,5536.512343883514,0.9859758615493774,0.0470364689826965,0.2687661791415418,43793,8034.163238525391,0.9911894798278807,0.0288370326161384,0.4573207657369297,0.9868093132972716,0.0442428737878799,0.268753044165262,43793 -2599.5078728199005,0.6851916313171387,5776.531605005264,18153,0,5776.531605005264,0.9860365390777588,0.0470629520714283,0.2600734428996599,43793,8377.209501743317,0.9911827445030212,0.0288665257394313,0.4415226656518888,0.9869375824928284,0.0440139845013618,0.2757121702309885,43793 -2702.5033643245697,0.7152974605560303,6016.586377620697,18905,0,6016.586377620697,0.9859219789505004,0.0465562902390956,0.2624453804195361,43793,8720.30987906456,0.9912071824073792,0.0289520751684904,0.4259516293447483,0.9868218898773192,0.0438161194324493,0.2703069049199251,43793 -2805.7474246025085,0.7460854053497314,6256.5426633358,19659,0,6256.5426633358,0.9860196709632874,0.0470933727920055,0.2598258010263444,43793,9063.561178445816,0.9910314083099364,0.0292637553066015,0.4337188049806543,0.9868438243865968,0.0443489290773868,0.2657752752914989,43793 -2907.2453095912933,0.7763607501983643,6496.541145086288,20406,0,6496.541145086288,0.9860963225364684,0.0464952029287815,0.2645376394413712,43793,9405.10793018341,0.9910750985145568,0.0291318371891975,0.4227976599470897,0.9867857694625854,0.0440329238772392,0.2679220228668975,43793 -3010.938796758652,0.8066954612731934,6736.561691999435,21153,0,6736.561691999435,0.9861363172531128,0.0466601848602294,0.2698105443588578,43793,9748.872096776962,0.9912204146385192,0.028793504461646,0.4551327226887302,0.9869741201400756,0.0436850786209106,0.2746462865459141,43793 -3113.4867935180664,0.838770866394043,6976.753581285477,21908,0,6976.753581285477,0.986116111278534,0.046595923602581,0.2598133222072068,43793,10091.66427564621,0.9911985397338868,0.0285947620868682,0.4477774106470991,0.9868621230125428,0.0438196025788784,0.2761807715640725,43793 -3216.602859258652,0.8691599369049072,7216.855234861374,22662,0,7216.855234861374,0.9861489534378052,0.0468988865613937,0.2668017828894048,43793,10434.932267189026,0.9913204908370972,0.0282438620924949,0.4488545594683917,0.9869660139083862,0.0441065616905689,0.2770595068690319,43793 -3321.61783194542,0.9006702899932861,7456.835754871368,23409,0,7456.835754871368,0.9861843585968018,0.0464783944189548,0.2654733454371574,43793,10779.978934049606,0.991339385509491,0.0280072912573814,0.4768710540761837,0.9870260953903198,0.0435880720615386,0.2754455617613686,43793 -3425.3005759716034,0.932265043258667,7696.949980020523,24169,0,7696.949980020523,0.9862883687019348,0.0466070286929607,0.2689184754578295,43793,11123.827285289764,0.9917247891426086,0.0268406271934509,0.4892830506636969,0.9871340990066528,0.0434612222015857,0.2824920145696338,43793 -3533.4301381111145,0.9634485244750975,7937.204426765442,24918,0,7937.204426765442,0.9860310554504396,0.0468678697943687,0.2691691968252736,43793,11472.261877059937,0.991833746433258,0.0265441630035638,0.5018239814929434,0.9869444966316224,0.0440145805478096,0.2853757787085315,43793 -3636.113703250885,0.9951300621032716,8177.371206998825,25672,0,8177.371206998825,0.9861308932304382,0.0467511713504791,0.2679375446158011,43793,11815.163880109789,0.9916606545448304,0.0270655062049627,0.4931265444328698,0.9869298934936525,0.0440454222261905,0.2736783751099223,43793 -3744.255452156067,1.0263991355895996,8417.410165786743,26430,0,8417.410165786743,0.9861363172531128,0.0463011413812637,0.2708964711546916,43793,12163.396028280258,0.9915685653686525,0.027531573548913,0.4764078593676328,0.9869379997253418,0.0434568747878074,0.280425913360974,43793 -3847.5303523540497,1.059520959854126,8657.560030698776,27188,0,8657.560030698776,0.986103057861328,0.0470838136970996,0.267270030296919,43793,12506.874527454376,0.9914869666099548,0.02746900357306,0.469126213426749,0.9869745373725892,0.0441538840532302,0.2801483549488656,43793 -3948.3207392692566,1.093000411987305,8897.627020835876,27948,0,8897.627020835876,0.9861097931861876,0.047021172940731,0.2677642311999752,43793,12847.785507917404,0.9915177226066588,0.0276230983436107,0.4635797663326154,0.9868791699409484,0.0440908521413803,0.2768368741050048,43793 -4052.942448377609,1.1257007122039795,9137.69330883026,28703,0,9137.69330883026,0.98611319065094,0.0470520183444023,0.269928501441591,43793,13192.525599956512,0.9916020035743712,0.0273597296327352,0.4797551808692469,0.9869339466094972,0.0440394915640354,0.2776483019824962,43793 -4154.9796550273895,1.158497333526611,9377.844159126282,29457,0,9377.844159126282,0.9862100481987,0.0468228124082088,0.2762927577236896,43793,13534.766516923904,0.9915141463279724,0.0272193383425474,0.481494203847978,0.9870768189430236,0.043897371739149,0.2825347877673383,43793 -4259.123318433762,1.1916203498840332,9617.910662651062,30214,0,9617.910662651062,0.9861536026000975,0.0466206222772598,0.262752059541407,43793,13879.02964568138,0.9916245937347412,0.0269749481230974,0.4861085832401384,0.986906349658966,0.0439552627503871,0.2740178464048872,43793 -4362.0056364536285,1.2251360416412354,9858.047510623932,30970,0,9858.047510623932,0.9861034750938416,0.04676054418087,0.2675303731162233,43793,14222.101789474487,0.9916879534721376,0.0268805008381605,0.4831309493913412,0.9869319200515748,0.0439032725989818,0.2780499286315356,43793 -4466.693303823471,1.2588024139404297,10098.214116334915,31733,0,10098.214116334915,0.9861843585968018,0.0466894954442977,0.2690698646759086,43793,14567.009452581406,0.9920236468315125,0.025705374777317,0.5056896381051835,0.9871158003807068,0.0436328127980232,0.279755651589898,43793 -4568.222731590271,1.2922701835632324,10338.35930466652,32490,0,10338.35930466652,0.9861034750938416,0.0477668866515159,0.2708168201569263,43793,14908.73725104332,0.9920137524604796,0.0257115643471479,0.5341408908884115,0.9870135188102722,0.0446216352283954,0.2860193035317245,43793 -4674.432077407837,1.3265063762664795,10578.31761264801,33250,0,10578.31761264801,0.9861953258514404,0.0467899516224861,0.2757306207445348,43793,15254.959168434145,0.992345094680786,0.0247074533253908,0.542667942251531,0.9870890378952026,0.0436905212700367,0.2875374195248217,43793 -4778.975328207016,1.3605234622955322,10818.476083278656,34005,0,10818.476083278656,0.9861144423484802,0.0468964874744415,0.2706710636528843,43793,15599.71495604515,0.9921961426734924,0.0252094622701406,0.5336781380529731,0.9870187640190125,0.043862909078598,0.2830934933826157,43793 -4881.47107052803,1.395246505737305,11058.613788604736,34753,0,11058.613788604736,0.9861447811126708,0.0475664623081684,0.268727189597655,43793,15942.403191328049,0.9918820858001708,0.0260654781013727,0.4910209539627918,0.9870484471321106,0.0443894192576408,0.28277370317285,43793 -4987.312251329422,1.4282660484313965,11298.787636518478,35506,0,11298.787636518478,0.9860979914665222,0.0468775518238544,0.2738745127717128,43793,16288.471473455427,0.991861879825592,0.0261897947639226,0.5024073601537873,0.9869396090507508,0.04402307420969,0.2845682260674991,43793 -5091.926290988922,1.4618382453918457,11538.803719043732,36263,0,11538.803719043732,0.9861472845077516,0.0472131557762622,0.2717603088414615,43793,16633.155037403107,0.9919314980506896,0.0258906707167625,0.5145215114867434,0.9870853424072266,0.0440011210739612,0.2819405539208721,43793 -5194.814932107925,1.4971117973327637,11779.05526304245,37020,0,11779.05526304245,0.9861915111541748,0.0470620207488536,0.2745681328480498,43793,16976.350893974304,0.9920294880867004,0.0256625600159168,0.525321302551629,0.9870873689651488,0.0440217033028602,0.2841475760462093,43793 -5300.418438196182,1.9214563369750977,12018.9160490036,37780,0,12018.9160490036,0.986250936985016,0.0473027415573596,0.2781049767731259,43793,17322.259452581406,0.992198646068573,0.0250362455844879,0.5411420089665825,0.9870870113372804,0.0441103279590606,0.2871141885306458,43793 -5404.342231273651,1.955871105194092,12259.118973731996,38539,0,12259.118973731996,0.9861801266670228,0.0478105284273624,0.2695032759656423,43793,17666.440479516983,0.9921022057533264,0.025213586166501,0.5186278352257488,0.9870630502700806,0.044580589979887,0.2844634555788871,43793 -5505.22732257843,1.9891891479492188,12499.194529294968,39293,0,12499.194529294968,0.9863056540489196,0.0475598722696304,0.2716499638558299,43793,18007.454338550568,0.9923055171966552,0.024532938376069,0.5386582770107501,0.9871178269386292,0.0443666987121105,0.2873161617452175,43793 -5607.085157632828,2.023126602172852,12739.169730901718,40044,0,12739.169730901718,0.9863005876541138,0.0477026961743831,0.275908618284249,43793,18349.34113383293,0.9925379157066344,0.0237658787518739,0.5654487645215553,0.9872254133224488,0.044390469789505,0.2868401826016203,43793 -5708.435259819031,2.0572030544281006,12979.314190387726,40804,0,12979.314190387726,0.9861422181129456,0.047527328133583,0.276028532270375,43793,18690.889854192734,0.9927852749824524,0.023151222616434,0.5775893719350957,0.9870390892028807,0.0445139482617378,0.2907892343854625,43793 -5811.207458734512,2.09119725227356,13219.500958919523,41566,0,13219.500958919523,0.986311972141266,0.0479333437979221,0.2786746810670965,43793,19033.90281820297,0.9926801323890686,0.0232359189540147,0.5827686459930178,0.9871527552604676,0.0449665486812591,0.2888831560468628,43793 -5916.065999746323,2.125895738601685,13459.468119859695,42326,0,13459.468119859695,0.9861717224121094,0.0477462261915206,0.2741958646714575,43793,19378.783309936523,0.9926807880401612,0.0234455466270446,0.5660779500427957,0.9870723485946656,0.0446420200169086,0.2907271142531775,43793 -6016.614222764969,2.186445474624634,13699.443043470385,43076,0,13699.443043470385,0.9863384962081908,0.0477807931602001,0.2780795705772881,43793,19719.386711359024,0.992499589920044,0.0239414982497692,0.5477740813822252,0.9870922565460204,0.0447935499250888,0.2892499494045117,43793 -6120.364076852799,2.2230827808380127,13939.41558265686,43826,0,13939.41558265686,0.9862349033355712,0.0479205027222633,0.2761384758753186,43793,20063.16539645195,0.9925187826156616,0.0238534715026617,0.5623201160024388,0.9870967268943788,0.0447729937732219,0.290604844012177,43793 -6223.816519021988,2.2597038745880127,14179.501737833025,44573,0,14179.501737833025,0.986177384853363,0.0478943772614002,0.2796849265893141,43793,20406.76056933403,0.9926663041114808,0.0234936866909265,0.5641044492073488,0.9869808554649352,0.0448011122643947,0.2861025098978526,43793 -6330.305959939957,2.296353340148926,14419.762070655825,45329,0,14419.762070655825,0.9862361550331116,0.0477320775389671,0.2748628905795507,43793,20753.566682338715,0.992615520954132,0.0235337913036346,0.5735633061987906,0.987092673778534,0.0445826053619384,0.2924033260479522,43793 -6435.036421775818,2.332881212234497,14659.82816028595,46086,0,14659.82816028595,0.9862534403800964,0.0483238622546196,0.2770070234684541,43793,21098.419695854187,0.9928174614906312,0.0227997712790966,0.5843304260900297,0.9871162176132202,0.0450176522135734,0.2908162408747274,43793 -6538.548140287399,2.3681788444519043,14899.894470214844,46840,0,14899.894470214844,0.9863351583480836,0.04808259755373,0.2740103134893802,43793,21442.05270147324,0.9929105043411256,0.0224435105919837,0.583947243919866,0.9870622158050536,0.0450861379504203,0.2884127098754348,43793 -6645.693230390549,2.4038455486297607,15139.885927677156,47594,0,15139.885927677156,0.9861367344856262,0.0480062998831272,0.2781953903528778,43793,21789.244199752808,0.99306058883667,0.0221116058528423,0.6052337144115114,0.9869980812072754,0.0450661145150661,0.2931761841290524,43793 -6747.254262685776,2.439406394958496,15380.088567256927,48354,0,15380.088567256927,0.9861510992050172,0.0481312908232212,0.2733146559294196,43793,22131.06327676773,0.9933203458786012,0.0213257987052202,0.6125192019839938,0.9870272874832152,0.0449936538934707,0.2895411568214211,43793 -6848.968739748001,2.4766972064971924,15620.312088727953,49113,0,15620.312088727953,0.9861595034599304,0.0488424561917781,0.273296710563419,43793,22473.058209180832,0.9934682250022888,0.02071463316679,0.6426092112092957,0.9870119094848632,0.045634426176548,0.2912810678611326,43793 -6949.345129728317,2.5182931423187256,15860.463735103607,49870,0,15860.463735103607,0.98635071516037,0.0493396520614624,0.2741738490622731,43793,22813.64751195908,0.993198812007904,0.0213176608085632,0.6175284713675935,0.9872095584869384,0.0462044775485992,0.2915582478001256,43793 -7051.691777467728,2.554926872253418,16100.710295438766,50630,0,16100.710295438766,0.9861738085746764,0.0487218536436557,0.2741141156405496,43793,23156.29674863816,0.9932072758674622,0.0215307734906673,0.6065806372072025,0.9869790077209472,0.0456237904727458,0.2889365308084413,43793 -7152.396923303604,2.5921742916107178,16340.700924158096,51395,0,16340.700924158096,0.9862711429595948,0.0488883107900619,0.2745948013890646,43793,23497.049400806427,0.9932363629341124,0.0213737972080707,0.6189733847194379,0.9871328473091124,0.0456616654992103,0.2887289210986936,43793 -7253.871799230576,2.6272754669189453,16580.71767473221,52159,0,16580.71767473221,0.9862285852432252,0.0490499958395957,0.2820580224329422,43793,23838.59583926201,0.9932331442832948,0.0213511120527982,0.6257244482281129,0.9869733452796936,0.0459792278707027,0.2871613198633184,43793 -7355.121006250381,2.663091897964477,16820.859786748886,52918,0,16820.859786748886,0.9862694144248962,0.0490830689668655,0.2786305604471064,43793,24180.04267501831,0.9933499693870544,0.0209690537303686,0.6198161926220507,0.9870723485946656,0.04600640386343,0.2898691017552575,43793 -7457.915698289871,2.698948621749878,17061.08062529564,53674,0,17061.08062529564,0.9862861037254332,0.0495411418378353,0.2794267027148863,43793,24523.11379837989,0.9933534264564514,0.020808607339859,0.629301270151079,0.9871153831481934,0.0463089346885681,0.2890898983273863,43793 -7556.839626789093,2.735714912414551,17301.122370243073,54432,0,17301.122370243073,0.986311972141266,0.0494890883564949,0.276395624635202,43793,24862.13622522354,0.9934797883033752,0.0204735491424798,0.62993081766115,0.9871438145637512,0.0462908633053302,0.2877251097946155,43793 -7661.45597743988,2.774885654449463,17541.313578367233,55184,0,17541.313578367233,0.986243724822998,0.0493969358503818,0.2795634855130245,43793,25207.003221273422,0.9935911297798156,0.020000223070383,0.6547790093724063,0.9871222972869872,0.0461614653468132,0.2900952606281319,43793 -7760.551167964935,2.8143293857574463,17781.294590473175,55944,0,17781.294590473175,0.9862310886383056,0.0500529184937477,0.2778119291983832,43793,25546.138780355453,0.9937967658042908,0.0194714777171611,0.6527825708898463,0.987092673778534,0.0467407219111919,0.2911061819818859,43793 -7865.168210506439,2.853222131729126,18021.25959467888,56701,0,18021.25959467888,0.9862062335014344,0.0498196221888065,0.282412806094966,43793,25890.77930259705,0.9940793514251708,0.0186166800558567,0.6898328240343279,0.987098753452301,0.0466601215302944,0.2899552937529251,43793 -7965.316137075424,2.891075849533081,18261.490270614624,57456,0,18261.490270614624,0.9862176179885864,0.05016188323497772,0.2802577179914768,43793,26231.215923547745,0.9942800998687744,0.018083401024341583,0.6976413701881876,0.9871283769607544,0.04675629362463951,0.2962188183781732,43793 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/measurements.csv deleted file mode 100644 index 5008648df..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/measurements.csv +++ /dev/null @@ -1,661 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.0070398,0.8003827,,,,,,,,,,,,,,,,, -1,,,0.3886732161045074,0.7993758320808411,0.0241892252979955,0.3926452994346618,0.7974664568901062,0.0248380825643216,43793.0,0.3947482705116272,0.7956756353378296,0.0263795780531647,43793.0,13.266764402389526,121.23296904563904,13.266764402389526,107.96615147590636,0.0,0.0 -100,0.32185128,0.28955022,,,,,,,,,,,,,,,,, -200,0.10585849,0.114757605,,,,,,,,,,,,,,,,, -300,0.03583317,0.06658442,,,,,,,,,,,,,,,,, -400,0.017815065,0.053930927,,,,,,,,,,,,,,,,, -500,0.016821546,0.055248573,,,,,,,,,,,,,,,,, -600,0.034691684,0.05145672,,,,,,,,,,,,,,,,, -700,0.040598463,0.06432848,,,,,,,,,,,,,,,,, -760,,,0.9868772029876708,0.0513061732053756,0.0554290952326774,0.9841986894607544,0.0606113150715827,0.0533419380277251,43793.0,0.9831896424293518,0.0638464763760566,0.0554174900674349,43793.0,253.24348163604736,460.68167448043823,253.24348163604736,207.39866089820865,0.019460916519165,0.0 -800,0.025725761,0.045381885,,,,,,,,,,,,,,,,, -900,0.025329227,0.053107996,,,,,,,,,,,,,,,,, -1000,0.018476725,0.051536664,,,,,,,,,,,,,,,,, -1100,0.01988059,0.048706878,,,,,,,,,,,,,,,,, -1200,0.024689008,0.04300809,,,,,,,,,,,,,,,,, -1300,0.019410932,0.04177025,,,,,,,,,,,,,,,,, -1400,0.028400918,0.04213781,,,,,,,,,,,,,,,,, -1500,0.021620717,0.05336936,,,,,,,,,,,,,,,,, -1513,,,0.9872434139251708,0.0468114838004112,0.1083326564721908,0.9846939444541932,0.0553979873657226,0.1049047631685208,43793.0,0.9836639165878296,0.0584327653050422,0.1038389597988644,43793.0,493.3325746059418,806.6452507972717,493.3325746059418,313.2266294956207,0.0457274913787841,0.0 -1600,0.020448616,0.051085502,,,,,,,,,,,,,,,,, -1700,0.027160497,0.047336597,,,,,,,,,,,,,,,,, -1800,0.036273517,0.046879817,,,,,,,,,,,,,,,,, -1900,0.010626146,0.043376714,,,,,,,,,,,,,,,,, -2000,0.016731223,0.04833337,,,,,,,,,,,,,,,,, -2100,0.014410788,0.042188376,,,,,,,,,,,,,,,,, -2200,0.023790358,0.044188634,,,,,,,,,,,,,,,,, -2267,,,0.9878523945808412,0.0432136245071888,0.1485056292047553,0.9850824475288392,0.052856158465147,0.1401113333246038,43793.0,0.984085977077484,0.0557627975940704,0.1352206329712963,43793.0,733.5270512104034,1150.3065140247345,733.5270512104034,416.6431694030762,0.0753901004791259,0.0 -2300,0.011218041,0.045868956,,,,,,,,,,,,,,,,, -2400,0.014700553,0.045635704,,,,,,,,,,,,,,,,, -2500,0.023203423,0.042597968,,,,,,,,,,,,,,,,, -2600,0.012210422,0.045819614,,,,,,,,,,,,,,,,, -2700,0.018722652,0.04695901,,,,,,,,,,,,,,,,, -2800,0.014566555,0.041574523,,,,,,,,,,,,,,,,, -2900,0.020207094,0.044207044,,,,,,,,,,,,,,,,, -3000,0.014617329,0.045325086,,,,,,,,,,,,,,,,, -3027,,,0.9883006811141968,0.0409547798335552,0.1729483531246934,0.9853901267051696,0.0504433996975421,0.1634846763662925,43793.0,0.984433889389038,0.0533240772783756,0.1593613143548954,43793.0,973.7290751934052,1495.4947953224182,973.7290751934052,521.5821809768677,0.1021766662597656,0.0 -3100,0.008953773,0.039693758,,,,,,,,,,,,,,,,, -3200,0.0105147865,0.043016765,,,,,,,,,,,,,,,,, -3300,0.019950027,0.036495384,,,,,,,,,,,,,,,,, -3400,0.013300649,0.040168762,,,,,,,,,,,,,,,,, -3500,0.009737795,0.040130854,,,,,,,,,,,,,,,,, -3600,0.013877414,0.039967034,,,,,,,,,,,,,,,,, -3700,0.013605918,0.037647393,,,,,,,,,,,,,,,,, -3783,,,0.98840993642807,0.0401165075600147,0.1947053002717981,0.9855772852897644,0.0491400696337223,0.1763400668451602,43793.0,0.984683632850647,0.0518145784735679,0.1757199431759489,43793.0,1213.7164585590365,1839.9989259243007,1213.7164585590365,626.0516443252563,0.1292841434478759,0.0 -3800,0.010895645,0.043252468,,,,,,,,,,,,,,,,, -3900,0.012252887,0.0380712,,,,,,,,,,,,,,,,, -4000,0.012788391,0.038225956,,,,,,,,,,,,,,,,, -4100,0.01759816,0.045537207,,,,,,,,,,,,,,,,, -4200,0.013772093,0.04352885,,,,,,,,,,,,,,,,, -4300,0.013019488,0.035894558,,,,,,,,,,,,,,,,, -4400,0.011109368,0.033729993,,,,,,,,,,,,,,,,, -4500,0.011329235,0.039856806,,,,,,,,,,,,,,,,, -4538,,,0.9886505007743835,0.0387487187981605,0.212599821809528,0.9857733845710754,0.048427578061819,0.1953902653325974,43793.0,0.9848495721817015,0.0511390194296836,0.1969125162851597,43793.0,1453.7154524326324,2181.925208091736,1453.7154524326324,727.9304811954498,0.1579248905181884,0.0 -4600,0.017149318,0.040593892,,,,,,,,,,,,,,,,, -4700,0.012351568,0.038293805,,,,,,,,,,,,,,,,, -4800,0.010360706,0.037946004,,,,,,,,,,,,,,,,, -4900,0.010559828,0.041764356,,,,,,,,,,,,,,,,, -5000,0.016063716,0.041368593,,,,,,,,,,,,,,,,, -5100,0.012523894,0.036187142,,,,,,,,,,,,,,,,, -5200,0.015755074,0.036455587,,,,,,,,,,,,,,,,, -5291,,,0.9889928698539734,0.0374399162828922,0.2475354839814961,0.985901653766632,0.0472876690328121,0.2130693687389712,43793.0,0.9849818348884584,0.0499075688421726,0.2096289671249377,43793.0,1693.9603996276855,2531.229151964188,1693.9603996276855,836.9415402412415,0.1859288215637207,0.0 -5300,0.01416571,0.041312445,,,,,,,,,,,,,,,,, -5400,0.010288359,0.03374867,,,,,,,,,,,,,,,,, -5500,0.014312146,0.042386014,,,,,,,,,,,,,,,,, -5600,0.013637287,0.03719154,,,,,,,,,,,,,,,,, -5700,0.0124036055,0.03765102,,,,,,,,,,,,,,,,, -5800,0.012111167,0.035673395,,,,,,,,,,,,,,,,, -5900,0.01835997,0.034772873,,,,,,,,,,,,,,,,, -6000,0.012052056,0.039301585,,,,,,,,,,,,,,,,, -6049,,,0.9890961050987244,0.0369906276464462,0.2587219490424934,0.9860299229621888,0.0470738820731639,0.214085916316131,43793.0,0.9852286577224731,0.0498530082404613,0.2094072852101087,43793.0,1933.9973559379573,2879.383031368256,1933.9973559379573,945.0108182430268,0.2134826183319091,0.0 -6100,0.011754401,0.0386655,,,,,,,,,,,,,,,,, -6200,0.013710401,0.034512103,,,,,,,,,,,,,,,,, -6300,0.013644936,0.035368305,,,,,,,,,,,,,,,,, -6400,0.014333498,0.042238228,,,,,,,,,,,,,,,,, -6500,0.0123275,0.039197296,,,,,,,,,,,,,,,,, -6600,0.012442291,0.03483074,,,,,,,,,,,,,,,,, -6700,0.012027254,0.036723826,,,,,,,,,,,,,,,,, -6800,0.014247279,0.035258032,,,,,,,,,,,,,,,,, -6808,,,0.9892982840538024,0.0359948612749576,0.2853854649351488,0.9862214922904968,0.0460150241851806,0.2256096644268872,43793.0,0.9852871894836426,0.0484927594661712,0.2190290022927095,43793.0,2174.019069910049,3223.005018234253,2174.019069910049,1048.5638234615326,0.2411527633666992,0.0 -6900,0.012989588,0.0387479,,,,,,,,,,,,,,,,, -7000,0.014511299,0.03700159,,,,,,,,,,,,,,,,, -7100,0.0137069365,0.03366936,,,,,,,,,,,,,,,,, -7200,0.013737445,0.034191716,,,,,,,,,,,,,,,,, -7300,0.022691062,0.039541192,,,,,,,,,,,,,,,,, -7400,0.013509766,0.035822287,,,,,,,,,,,,,,,,, -7500,0.01677727,0.036971953,,,,,,,,,,,,,,,,, -7566,,,0.9897069931030272,0.0347049161791801,0.3021627549946922,0.986340880393982,0.045676652342081,0.2318702043338875,43793.0,0.985509157180786,0.0481332577764987,0.2383437345658389,43793.0,2414.260739564896,3566.0569252967834,2414.260739564896,1151.3277099132538,0.2674412727355957,0.0 -7600,0.013820797,0.035680875,,,,,,,,,,,,,,,,, -7700,0.013824449,0.035898935,,,,,,,,,,,,,,,,, -7800,0.017139107,0.03283047,,,,,,,,,,,,,,,,, -7900,0.017312521,0.0344465,,,,,,,,,,,,,,,,, -8000,0.0141577,0.035145324,,,,,,,,,,,,,,,,, -8100,0.01841916,0.040396444,,,,,,,,,,,,,,,,, -8200,0.013783677,0.032139227,,,,,,,,,,,,,,,,, -8300,0.012726171,0.033043,,,,,,,,,,,,,,,,, -8324,,,0.989915370941162,0.0336551368236541,0.3202873999988332,0.9864200353622437,0.0456519722938537,0.2381516566782394,43793.0,0.9855584502220154,0.0483413860201835,0.2369385751055225,43793.0,2654.523429632187,3912.501790523529,2654.523429632187,1257.461489200592,0.2958061695098877,0.0 -8400,0.02011069,0.03490498,,,,,,,,,,,,,,,,, -8500,0.016771458,0.03562021,,,,,,,,,,,,,,,,, -8600,0.024257537,0.036022916,,,,,,,,,,,,,,,,, -8700,0.016539855,0.035484184,,,,,,,,,,,,,,,,, -8800,0.015312231,0.036134012,,,,,,,,,,,,,,,,, -8900,0.0238619,0.036719035,,,,,,,,,,,,,,,,, -9000,0.02924225,0.038756393,,,,,,,,,,,,,,,,, -9083,,,0.9900816082954408,0.0332052633166313,0.3405124811740702,0.9864683151245116,0.0451069064438343,0.2389902367890833,43793.0,0.985533595085144,0.0476725287735462,0.2434450341870531,43793.0,2894.683405160904,4254.173964262009,2894.683405160904,1358.925484418869,0.3236792087554931,0.0 -9100,0.016849753,0.03756,,,,,,,,,,,,,,,,, -9200,0.018382583,0.033330515,,,,,,,,,,,,,,,,, -9300,0.018933283,0.03762369,,,,,,,,,,,,,,,,, -9400,0.020353183,0.036705215,,,,,,,,,,,,,,,,, -9500,0.021352181,0.0321075,,,,,,,,,,,,,,,,, -9600,0.021765683,0.036866244,,,,,,,,,,,,,,,,, -9700,0.025593853,0.03739149,,,,,,,,,,,,,,,,, -9800,0.020615213,0.032614626,,,,,,,,,,,,,,,,, -9838,,,0.9902423024177552,0.0324068292975425,0.3589596793839536,0.9864655137062072,0.0449420176446437,0.2456156073082285,43793.0,0.9856464862823486,0.0476103201508522,0.2401196061095888,43793.0,3134.751321554184,4596.953059434891,3134.751321554184,1461.5836379528046,0.3566343784332275,0.0 -9900,0.025194898,0.034299914,,,,,,,,,,,,,,,,, -10000,0.018383887,0.0342898,,,,,,,,,,,,,,,,, -10100,0.020702712,0.03325596,,,,,,,,,,,,,,,,, -10200,0.023587717,0.03631531,,,,,,,,,,,,,,,,, -10300,0.018870821,0.032221295,,,,,,,,,,,,,,,,, -10400,0.020755902,0.03641096,,,,,,,,,,,,,,,,, -10500,0.022899656,0.03280661,,,,,,,,,,,,,,,,, -10591,,,0.9902772903442384,0.0322590358555316,0.3610310912047775,0.9866834878921508,0.0445441342890262,0.2554394149345458,43793.0,0.9858199954032898,0.0472011156380176,0.2493607192724526,43793.0,3374.937530517578,4940.203378915787,3374.937530517578,1564.5984835624697,0.3857431411743164,0.0 -10600,0.022045113,0.032409135,,,,,,,,,,,,,,,,, -10700,0.038792666,0.03176922,,,,,,,,,,,,,,,,, -10800,0.020963298,0.03357166,,,,,,,,,,,,,,,,, -10900,0.022332322,0.034090202,,,,,,,,,,,,,,,,, -11000,0.026463667,0.033328693,,,,,,,,,,,,,,,,, -11100,0.02772784,0.03484841,,,,,,,,,,,,,,,,, -11200,0.019040205,0.03092867,,,,,,,,,,,,,,,,, -11300,0.023130117,0.034685563,,,,,,,,,,,,,,,,, -11347,,,0.9904393553733826,0.0319894962012767,0.3553863849024475,0.9865807890892028,0.0445267893373966,0.2493957465244841,43793.0,0.9857147336006165,0.0472510457038879,0.2535599677668376,43793.0,3615.076661109924,5282.0795221328735,3615.076661109924,1666.2861070632937,0.4149281978607178,0.0 -11400,0.023497198,0.033986058,,,,,,,,,,,,,,,,, -11500,0.024077758,0.03489173,,,,,,,,,,,,,,,,, -11600,0.020526933,0.034002304,,,,,,,,,,,,,,,,, -11700,0.030695599,0.035670456,,,,,,,,,,,,,,,,, -11800,0.027651887,0.030951913,,,,,,,,,,,,,,,,, -11900,0.029106954,0.03506944,,,,,,,,,,,,,,,,, -12000,0.026983764,0.034860753,,,,,,,,,,,,,,,,, -12099,,,0.99038964509964,0.0319667942821979,0.362734710936444,0.9866493940353394,0.0443710908293724,0.2522052206234418,43793.0,0.9858659505844116,0.0469923354685306,0.2538715301046184,43793.0,3855.316538333893,5628.884557723999,3855.316538333893,1772.8010022640228,0.4449071884155273,0.0 -12100,0.03838894,0.032454442,,,,,,,,,,,,,,,,, -12200,0.026032314,0.03403655,,,,,,,,,,,,,,,,, -12300,0.0265557,0.033490866,,,,,,,,,,,,,,,,, -12400,0.023563774,0.030890774,,,,,,,,,,,,,,,,, -12500,0.027637416,0.032413237,,,,,,,,,,,,,,,,, -12600,0.026982987,0.031572092,,,,,,,,,,,,,,,,, -12700,0.03536851,0.036753703,,,,,,,,,,,,,,,,, -12800,0.022718502,0.031483866,,,,,,,,,,,,,,,,, -12859,,,0.9905506372451782,0.0315115861594677,0.3814257408336702,0.9867147207260132,0.0444487147033214,0.2606091547759409,43793.0,0.985922396183014,0.0469017066061496,0.2577900264790435,43793.0,4095.371396303177,5972.155205726624,4095.371396303177,1875.968015909195,0.4730713367462158,0.0 -12900,0.024320098,0.03095778,,,,,,,,,,,,,,,,, -13000,0.03280234,0.03274766,,,,,,,,,,,,,,,,, -13100,0.030311674,0.02970514,,,,,,,,,,,,,,,,, -13200,0.03644712,0.030555323,,,,,,,,,,,,,,,,, -13300,0.03519226,0.034624793,,,,,,,,,,,,,,,,, -13400,0.034327235,0.03280861,,,,,,,,,,,,,,,,, -13500,0.03234599,0.033564825,,,,,,,,,,,,,,,,, -13600,0.030399501,0.035893742,,,,,,,,,,,,,,,,, -13612,,,0.9904999732971193,0.0312000997364521,0.3974700571924796,0.986764669418335,0.0441124215722084,0.2609337751096847,43793.0,0.9859269857406616,0.0467573627829551,0.2620023400737273,43793.0,4335.552411794663,6319.501272678375,4335.552411794663,1983.082782745361,0.5029690265655518,0.0 -13700,0.031626876,0.029040731,,,,,,,,,,,,,,,,, -13800,0.028427968,0.031189973,,,,,,,,,,,,,,,,, -13900,0.033847522,0.037597008,,,,,,,,,,,,,,,,, -14000,0.03140434,0.030611753,,,,,,,,,,,,,,,,, -14100,0.035759512,0.033491924,,,,,,,,,,,,,,,,, -14200,0.03687934,0.038260706,,,,,,,,,,,,,,,,, -14300,0.04042376,0.03244015,,,,,,,,,,,,,,,,, -14362,,,0.9905785322189332,0.0309624504297971,0.3904840151507778,0.9867537021636964,0.0440865717828273,0.2646759096238124,43793.0,0.985931634902954,0.0466216616332531,0.2593869756363092,43793.0,4575.777185678482,6663.42175078392,4575.777185678482,2086.727002859116,0.5339128971099854,0.0 -14400,0.043815013,0.038985617,,,,,,,,,,,,,,,,, -14500,0.034378655,0.03160812,,,,,,,,,,,,,,,,, -14600,0.037771616,0.033703793,,,,,,,,,,,,,,,,, -14700,0.031322334,0.0334028,,,,,,,,,,,,,,,,, -14800,0.032286424,0.028051542,,,,,,,,,,,,,,,,, -14900,0.04220555,0.03549859,,,,,,,,,,,,,,,,, -15000,0.044324543,0.033975635,,,,,,,,,,,,,,,,, -15100,0.03421147,0.030111317,,,,,,,,,,,,,,,,, -15115,,,0.9908715486526488,0.0299752708524465,0.4078574356020376,0.9867532849311828,0.0438701473176479,0.2704771739922652,43793.0,0.9859648942947388,0.0465397238731384,0.2624386163105034,43793.0,4816.017210483551,7005.01956653595,4816.017210483551,2188.0321526527405,0.5664629936218262,0.0 -15200,0.032034226,0.03245965,,,,,,,,,,,,,,,,, -15300,0.03678946,0.03621458,,,,,,,,,,,,,,,,, -15400,0.032604236,0.030253397,,,,,,,,,,,,,,,,, -15500,0.045169,0.03236748,,,,,,,,,,,,,,,,, -15600,0.03397616,0.03369416,,,,,,,,,,,,,,,,, -15700,0.041682422,0.03614257,,,,,,,,,,,,,,,,, -15800,0.03664046,0.031657968,,,,,,,,,,,,,,,,, -15868,,,0.990837574005127,0.0299731157720088,0.4219413948334143,0.9869092106819152,0.0443227887153625,0.2641223882343574,43793.0,0.9860209226608276,0.047208122909069,0.255676770999339,43793.0,5056.221124887466,7349.488345146179,5056.221124887466,2292.2479150295258,0.595801830291748,0.0 -15900,0.036790732,0.030097336,,,,,,,,,,,,,,,,, -16000,0.03624048,0.02817605,,,,,,,,,,,,,,,,, -16100,0.04059016,0.03590135,,,,,,,,,,,,,,,,, -16200,0.04222178,0.029935297,,,,,,,,,,,,,,,,, -16300,0.033469886,0.028673774,,,,,,,,,,,,,,,,, -16400,0.0314084,0.03383673,,,,,,,,,,,,,,,,, -16500,0.036359258,0.029125797,,,,,,,,,,,,,,,,, -16600,0.051013473,0.031414215,,,,,,,,,,,,,,,,, -16631,,,0.9911161065101624,0.029060611501336,0.4442643550924829,0.9867061972618104,0.0439366847276687,0.2679191421744507,43793.0,0.9859219789505004,0.0466020852327346,0.2708772382010023,43793.0,5296.290515899658,7691.375034809113,5296.290515899658,2394.0149455070496,0.6254878044128418,0.0 -16700,0.037937567,0.033046912,,,,,,,,,,,,,,,,, -16800,0.03202712,0.030743597,,,,,,,,,,,,,,,,, -16900,0.037257176,0.0314969,,,,,,,,,,,,,,,,, -17000,0.04162668,0.03456881,,,,,,,,,,,,,,,,, -17100,0.03726111,0.03017491,,,,,,,,,,,,,,,,, -17200,0.03139661,0.029077759,,,,,,,,,,,,,,,,, -17300,0.048250537,0.03287034,,,,,,,,,,,,,,,,, -17393,,,0.9911894798278807,0.0288370326161384,0.4573207657369297,0.9868093132972716,0.0442428737878799,0.268753044165262,43793.0,0.9859758615493774,0.0470364689826965,0.2687661791415418,43793.0,5536.512343883514,8034.163238525391,5536.512343883514,2496.5304474830627,0.655855655670166,0.0 -17400,0.032886095,0.02897498,,,,,,,,,,,,,,,,, -17500,0.039126344,0.02959019,,,,,,,,,,,,,,,,, -17600,0.051938727,0.030901235,,,,,,,,,,,,,,,,, -17700,0.039318535,0.026271375,,,,,,,,,,,,,,,,, -17800,0.03468989,0.029115835,,,,,,,,,,,,,,,,, -17900,0.04346025,0.03054348,,,,,,,,,,,,,,,,, -18000,0.08766651,0.03409441,,,,,,,,,,,,,,,,, -18100,0.056139674,0.031409726,,,,,,,,,,,,,,,,, -18153,,,0.9911827445030212,0.0288665257394313,0.4415226656518888,0.9869375824928284,0.0440139845013618,0.2757121702309885,43793.0,0.9860365390777588,0.0470629520714283,0.2600734428996599,43793.0,5776.531605005264,8377.209501743317,5776.531605005264,2599.5078728199005,0.6851916313171387,0.0 -18200,0.040182233,0.033029474,,,,,,,,,,,,,,,,, -18300,0.038793854,0.03192582,,,,,,,,,,,,,,,,, -18400,0.043040108,0.029924987,,,,,,,,,,,,,,,,, -18500,0.03711928,0.027836826,,,,,,,,,,,,,,,,, -18600,0.04339247,0.03290504,,,,,,,,,,,,,,,,, -18700,0.031186445,0.027366186,,,,,,,,,,,,,,,,, -18800,0.04131307,0.029847156,,,,,,,,,,,,,,,,, -18900,0.04156123,0.03165081,,,,,,,,,,,,,,,,, -18905,,,0.9912071824073792,0.0289520751684904,0.4259516293447483,0.9868218898773192,0.0438161194324493,0.2703069049199251,43793.0,0.9859219789505004,0.0465562902390956,0.2624453804195361,43793.0,6016.586377620697,8720.30987906456,6016.586377620697,2702.5033643245697,0.7152974605560303,0.0 -19000,0.045083802,0.031411838,,,,,,,,,,,,,,,,, -19100,0.03686078,0.032944966,,,,,,,,,,,,,,,,, -19200,0.043181982,0.03444446,,,,,,,,,,,,,,,,, -19300,0.050937235,0.032739505,,,,,,,,,,,,,,,,, -19400,0.037983272,0.031094098,,,,,,,,,,,,,,,,, -19500,0.05695517,0.030159311,,,,,,,,,,,,,,,,, -19600,0.05139087,0.03118261,,,,,,,,,,,,,,,,, -19659,,,0.9910314083099364,0.0292637553066015,0.4337188049806543,0.9868438243865968,0.0443489290773868,0.2657752752914989,43793.0,0.9860196709632874,0.0470933727920055,0.2598258010263444,43793.0,6256.5426633358,9063.561178445816,6256.5426633358,2805.7474246025085,0.7460854053497314,0.0 -19700,0.03869724,0.027476663,,,,,,,,,,,,,,,,, -19800,0.04539036,0.02920253,,,,,,,,,,,,,,,,, -19900,0.044632975,0.031469043,,,,,,,,,,,,,,,,, -20000,0.044786118,0.028411943,,,,,,,,,,,,,,,,, -20100,0.04936405,0.02973259,,,,,,,,,,,,,,,,, -20200,0.047050375,0.02610189,,,,,,,,,,,,,,,,, -20300,0.04016884,0.03164056,,,,,,,,,,,,,,,,, -20400,0.041952297,0.030150993,,,,,,,,,,,,,,,,, -20406,,,0.9910750985145568,0.0291318371891975,0.4227976599470897,0.9867857694625854,0.0440329238772392,0.2679220228668975,43793.0,0.9860963225364684,0.0464952029287815,0.2645376394413712,43793.0,6496.541145086288,9405.10793018341,6496.541145086288,2907.2453095912933,0.7763607501983643,0.0 -20500,0.05140037,0.034190577,,,,,,,,,,,,,,,,, -20600,0.044877272,0.02298366,,,,,,,,,,,,,,,,, -20700,0.041650925,0.031610887,,,,,,,,,,,,,,,,, -20800,0.05055671,0.029293513,,,,,,,,,,,,,,,,, -20900,0.07467784,0.030556023,,,,,,,,,,,,,,,,, -21000,0.04821849,0.03161227,,,,,,,,,,,,,,,,, -21100,0.0514353,0.032033153,,,,,,,,,,,,,,,,, -21153,,,0.9912204146385192,0.028793504461646,0.4551327226887302,0.9869741201400756,0.0436850786209106,0.2746462865459141,43793.0,0.9861363172531128,0.0466601848602294,0.2698105443588578,43793.0,6736.561691999435,9748.872096776962,6736.561691999435,3010.938796758652,0.8066954612731934,0.0 -21200,0.061621547,0.029898278,,,,,,,,,,,,,,,,, -21300,0.04280477,0.028828159,,,,,,,,,,,,,,,,, -21400,0.04944517,0.03010772,,,,,,,,,,,,,,,,, -21500,0.06670043,0.03522566,,,,,,,,,,,,,,,,, -21600,0.06991282,0.031358507,,,,,,,,,,,,,,,,, -21700,0.05796257,0.025013698,,,,,,,,,,,,,,,,, -21800,0.054168828,0.03096473,,,,,,,,,,,,,,,,, -21900,0.048937496,0.03345179,,,,,,,,,,,,,,,,, -21908,,,0.9911985397338868,0.0285947620868682,0.4477774106470991,0.9868621230125428,0.0438196025788784,0.2761807715640725,43793.0,0.986116111278534,0.046595923602581,0.2598133222072068,43793.0,6976.753581285477,10091.66427564621,6976.753581285477,3113.4867935180664,0.838770866394043,0.0 -22000,0.060946397,0.031745423,,,,,,,,,,,,,,,,, -22100,0.062967405,0.032880254,,,,,,,,,,,,,,,,, -22200,0.051092405,0.031452224,,,,,,,,,,,,,,,,, -22300,0.047159616,0.028637072,,,,,,,,,,,,,,,,, -22400,0.049936686,0.031523846,,,,,,,,,,,,,,,,, -22500,0.06437871,0.031673007,,,,,,,,,,,,,,,,, -22600,0.048838638,0.032221463,,,,,,,,,,,,,,,,, -22662,,,0.9913204908370972,0.0282438620924949,0.4488545594683917,0.9869660139083862,0.0441065616905689,0.2770595068690319,43793.0,0.9861489534378052,0.0468988865613937,0.2668017828894048,43793.0,7216.855234861374,10434.932267189026,7216.855234861374,3216.602859258652,0.8691599369049072,0.0 -22700,0.05236299,0.028059121,,,,,,,,,,,,,,,,, -22800,0.05648179,0.032154698,,,,,,,,,,,,,,,,, -22900,0.0549235,0.029050505,,,,,,,,,,,,,,,,, -23000,0.06831002,0.029492768,,,,,,,,,,,,,,,,, -23100,0.05049764,0.028295137,,,,,,,,,,,,,,,,, -23200,0.048477054,0.029264146,,,,,,,,,,,,,,,,, -23300,0.04743813,0.031385917,,,,,,,,,,,,,,,,, -23400,0.04274771,0.02626013,,,,,,,,,,,,,,,,, -23409,,,0.991339385509491,0.0280072912573814,0.4768710540761837,0.9870260953903198,0.0435880720615386,0.2754455617613686,43793.0,0.9861843585968018,0.0464783944189548,0.2654733454371574,43793.0,7456.835754871368,10779.978934049606,7456.835754871368,3321.61783194542,0.9006702899932861,0.0 -23500,0.047176804,0.032858778,,,,,,,,,,,,,,,,, -23600,0.05605964,0.026608158,,,,,,,,,,,,,,,,, -23700,0.060278766,0.030939745,,,,,,,,,,,,,,,,, -23800,0.046966538,0.031639736,,,,,,,,,,,,,,,,, -23900,0.054377746,0.031169465,,,,,,,,,,,,,,,,, -24000,0.046103735,0.030879078,,,,,,,,,,,,,,,,, -24100,0.054528505,0.030460728,,,,,,,,,,,,,,,,, -24169,,,0.9917247891426086,0.0268406271934509,0.4892830506636969,0.9871340990066528,0.0434612222015857,0.2824920145696338,43793.0,0.9862883687019348,0.0466070286929607,0.2689184754578295,43793.0,7696.949980020523,11123.827285289764,7696.949980020523,3425.3005759716034,0.932265043258667,0.0 -24200,0.05475011,0.030819995,,,,,,,,,,,,,,,,, -24300,0.053607106,0.025631271,,,,,,,,,,,,,,,,, -24400,0.054608762,0.025928315,,,,,,,,,,,,,,,,, -24500,0.052036613,0.030224545,,,,,,,,,,,,,,,,, -24600,0.070121825,0.02974945,,,,,,,,,,,,,,,,, -24700,0.055219255,0.034050282,,,,,,,,,,,,,,,,, -24800,0.043205295,0.024617376,,,,,,,,,,,,,,,,, -24900,0.061195966,0.030322716,,,,,,,,,,,,,,,,, -24918,,,0.991833746433258,0.0265441630035638,0.5018239814929434,0.9869444966316224,0.0440145805478096,0.2853757787085315,43793.0,0.9860310554504396,0.0468678697943687,0.2691691968252736,43793.0,7937.204426765442,11472.261877059937,7937.204426765442,3533.4301381111145,0.9634485244750975,0.0 -25000,0.053935703,0.028662117,,,,,,,,,,,,,,,,, -25100,0.0569725,0.02718896,,,,,,,,,,,,,,,,, -25200,0.04563529,0.029816847,,,,,,,,,,,,,,,,, -25300,0.052451268,0.029281488,,,,,,,,,,,,,,,,, -25400,0.05300572,0.028591303,,,,,,,,,,,,,,,,, -25500,0.045719773,0.026037175,,,,,,,,,,,,,,,,, -25600,0.05067313,0.029995358,,,,,,,,,,,,,,,,, -25672,,,0.9916606545448304,0.0270655062049627,0.4931265444328698,0.9869298934936525,0.0440454222261905,0.2736783751099223,43793.0,0.9861308932304382,0.0467511713504791,0.2679375446158011,43793.0,8177.371206998825,11815.163880109789,8177.371206998825,3636.113703250885,0.9951300621032716,0.0 -25700,0.0571868,0.032143857,,,,,,,,,,,,,,,,, -25800,0.056169305,0.030486282,,,,,,,,,,,,,,,,, -25900,0.056753766,0.034157168,,,,,,,,,,,,,,,,, -26000,0.047606844,0.02834301,,,,,,,,,,,,,,,,, -26100,0.06011094,0.0326666,,,,,,,,,,,,,,,,, -26200,0.05855453,0.032510206,,,,,,,,,,,,,,,,, -26300,0.047931366,0.027597243,,,,,,,,,,,,,,,,, -26400,0.058125444,0.029812848,,,,,,,,,,,,,,,,, -26430,,,0.9915685653686525,0.027531573548913,0.4764078593676328,0.9869379997253418,0.0434568747878074,0.280425913360974,43793.0,0.9861363172531128,0.0463011413812637,0.2708964711546916,43793.0,8417.410165786743,12163.396028280258,8417.410165786743,3744.255452156067,1.0263991355895996,0.0 -26500,0.054304004,0.028979039,,,,,,,,,,,,,,,,, -26600,0.05357748,0.03037404,,,,,,,,,,,,,,,,, -26700,0.066430844,0.031564876,,,,,,,,,,,,,,,,, -26800,0.054657597,0.031233953,,,,,,,,,,,,,,,,, -26900,0.054725543,0.032812476,,,,,,,,,,,,,,,,, -27000,0.0576656,0.03031598,,,,,,,,,,,,,,,,, -27100,0.051558763,0.03108923,,,,,,,,,,,,,,,,, -27188,,,0.9914869666099548,0.02746900357306,0.469126213426749,0.9869745373725892,0.0441538840532302,0.2801483549488656,43793.0,0.986103057861328,0.0470838136970996,0.267270030296919,43793.0,8657.560030698776,12506.874527454376,8657.560030698776,3847.5303523540497,1.059520959854126,0.0 -27200,0.046982475,0.02842046,,,,,,,,,,,,,,,,, -27300,0.053971224,0.028281482,,,,,,,,,,,,,,,,, -27400,0.053970836,0.030575868,,,,,,,,,,,,,,,,, -27500,0.056483187,0.03124488,,,,,,,,,,,,,,,,, -27600,0.05513594,0.028962718,,,,,,,,,,,,,,,,, -27700,0.056857266,0.03024676,,,,,,,,,,,,,,,,, -27800,0.059578057,0.029648516,,,,,,,,,,,,,,,,, -27900,0.059498966,0.029355397,,,,,,,,,,,,,,,,, -27948,,,0.9915177226066588,0.0276230983436107,0.4635797663326154,0.9868791699409484,0.0440908521413803,0.2768368741050048,43793.0,0.9861097931861876,0.047021172940731,0.2677642311999752,43793.0,8897.627020835876,12847.785507917404,8897.627020835876,3948.3207392692566,1.093000411987305,0.0 -28000,0.057264727,0.030419564,,,,,,,,,,,,,,,,, -28100,0.05177421,0.029934337,,,,,,,,,,,,,,,,, -28200,0.050157417,0.03111284,,,,,,,,,,,,,,,,, -28300,0.066453174,0.028283304,,,,,,,,,,,,,,,,, -28400,0.061924726,0.027897239,,,,,,,,,,,,,,,,, -28500,0.049864925,0.029363567,,,,,,,,,,,,,,,,, -28600,0.05358822,0.028290754,,,,,,,,,,,,,,,,, -28700,0.06583123,0.029722026,,,,,,,,,,,,,,,,, -28703,,,0.9916020035743712,0.0273597296327352,0.4797551808692469,0.9869339466094972,0.0440394915640354,0.2776483019824962,43793.0,0.98611319065094,0.0470520183444023,0.269928501441591,43793.0,9137.69330883026,13192.525599956512,9137.69330883026,4052.942448377609,1.1257007122039795,0.0 -28800,0.053473398,0.029223299,,,,,,,,,,,,,,,,, -28900,0.062063444,0.027701339,,,,,,,,,,,,,,,,, -29000,0.048762623,0.026030052,,,,,,,,,,,,,,,,, -29100,0.057078984,0.02767335,,,,,,,,,,,,,,,,, -29200,0.048501972,0.027557382,,,,,,,,,,,,,,,,, -29300,0.057374634,0.030797869,,,,,,,,,,,,,,,,, -29400,0.05773923,0.027100328,,,,,,,,,,,,,,,,, -29457,,,0.9915141463279724,0.0272193383425474,0.481494203847978,0.9870768189430236,0.043897371739149,0.2825347877673383,43793.0,0.9862100481987,0.0468228124082088,0.2762927577236896,43793.0,9377.844159126282,13534.766516923904,9377.844159126282,4154.9796550273895,1.158497333526611,0.0 -29500,0.06230366,0.026944472,,,,,,,,,,,,,,,,, -29600,0.05280197,0.027864417,,,,,,,,,,,,,,,,, -29700,0.059911337,0.028820526,,,,,,,,,,,,,,,,, -29800,0.049618658,0.02712925,,,,,,,,,,,,,,,,, -29900,0.051198155,0.031074757,,,,,,,,,,,,,,,,, -30000,0.06821477,0.03229658,,,,,,,,,,,,,,,,, -30100,0.06337825,0.03036318,,,,,,,,,,,,,,,,, -30200,0.05169976,0.028021038,,,,,,,,,,,,,,,,, -30214,,,0.9916245937347412,0.0269749481230974,0.4861085832401384,0.986906349658966,0.0439552627503871,0.2740178464048872,43793.0,0.9861536026000975,0.0466206222772598,0.262752059541407,43793.0,9617.910662651062,13879.02964568138,9617.910662651062,4259.123318433762,1.1916203498840332,0.0 -30300,0.056900293,0.026931873,,,,,,,,,,,,,,,,, -30400,0.05604828,0.026911054,,,,,,,,,,,,,,,,, -30500,0.07407475,0.031945374,,,,,,,,,,,,,,,,, -30600,0.05604934,0.026287574,,,,,,,,,,,,,,,,, -30700,0.05994944,0.025947804,,,,,,,,,,,,,,,,, -30800,0.052754696,0.02895841,,,,,,,,,,,,,,,,, -30900,0.07987711,0.029903531,,,,,,,,,,,,,,,,, -30970,,,0.9916879534721376,0.0268805008381605,0.4831309493913412,0.9869319200515748,0.0439032725989818,0.2780499286315356,43793.0,0.9861034750938416,0.04676054418087,0.2675303731162233,43793.0,9858.047510623932,14222.101789474487,9858.047510623932,4362.0056364536285,1.2251360416412354,0.0 -31000,0.058902834,0.029076623,,,,,,,,,,,,,,,,, -31100,0.06577425,0.02741181,,,,,,,,,,,,,,,,, -31200,0.06651479,0.029235449,,,,,,,,,,,,,,,,, -31300,0.053775884,0.02751682,,,,,,,,,,,,,,,,, -31400,0.054924726,0.026220653,,,,,,,,,,,,,,,,, -31500,0.046505112,0.023474852,,,,,,,,,,,,,,,,, -31600,0.0634183,0.032426,,,,,,,,,,,,,,,,, -31700,0.056159474,0.027299603,,,,,,,,,,,,,,,,, -31733,,,0.9920236468315125,0.025705374777317,0.5056896381051835,0.9871158003807068,0.0436328127980232,0.279755651589898,43793.0,0.9861843585968018,0.0466894954442977,0.2690698646759086,43793.0,10098.214116334915,14567.009452581406,10098.214116334915,4466.693303823471,1.2588024139404297,0.0 -31800,0.06759822,0.030382035,,,,,,,,,,,,,,,,, -31900,0.056278452,0.029515114,,,,,,,,,,,,,,,,, -32000,0.07460212,0.030072102,,,,,,,,,,,,,,,,, -32100,0.059449773,0.025460815,,,,,,,,,,,,,,,,, -32200,0.0648031,0.031278502,,,,,,,,,,,,,,,,, -32300,0.04683311,0.0271023,,,,,,,,,,,,,,,,, -32400,0.065677,0.028612072,,,,,,,,,,,,,,,,, -32490,,,0.9920137524604796,0.0257115643471479,0.5341408908884115,0.9870135188102722,0.0446216352283954,0.2860193035317245,43793.0,0.9861034750938416,0.0477668866515159,0.2708168201569263,43793.0,10338.35930466652,14908.73725104332,10338.35930466652,4568.222731590271,1.2922701835632324,0.0 -32500,0.06395547,0.025257554,,,,,,,,,,,,,,,,, -32600,0.06402881,0.02884761,,,,,,,,,,,,,,,,, -32700,0.06073823,0.028953655,,,,,,,,,,,,,,,,, -32800,0.056491148,0.03032784,,,,,,,,,,,,,,,,, -32900,0.054669354,0.031611603,,,,,,,,,,,,,,,,, -33000,0.05370171,0.026752802,,,,,,,,,,,,,,,,, -33100,0.057101175,0.028900897,,,,,,,,,,,,,,,,, -33200,0.058590405,0.026082467,,,,,,,,,,,,,,,,, -33250,,,0.992345094680786,0.0247074533253908,0.542667942251531,0.9870890378952026,0.0436905212700367,0.2875374195248217,43793.0,0.9861953258514404,0.0467899516224861,0.2757306207445348,43793.0,10578.31761264801,15254.959168434145,10578.31761264801,4674.432077407837,1.3265063762664795,0.0 -33300,0.057383046,0.02979469,,,,,,,,,,,,,,,,, -33400,0.08027797,0.031808037,,,,,,,,,,,,,,,,, -33500,0.0590648,0.029589707,,,,,,,,,,,,,,,,, -33600,0.05582413,0.02619591,,,,,,,,,,,,,,,,, -33700,0.06242222,0.030442193,,,,,,,,,,,,,,,,, -33800,0.059038244,0.023162236,,,,,,,,,,,,,,,,, -33900,0.071074985,0.029026506,,,,,,,,,,,,,,,,, -34000,0.07568531,0.030011924,,,,,,,,,,,,,,,,, -34005,,,0.9921961426734924,0.0252094622701406,0.5336781380529731,0.9870187640190125,0.043862909078598,0.2830934933826157,43793.0,0.9861144423484802,0.0468964874744415,0.2706710636528843,43793.0,10818.476083278656,15599.71495604515,10818.476083278656,4778.975328207016,1.3605234622955322,0.0 -34100,0.06442112,0.028075496,,,,,,,,,,,,,,,,, -34200,0.065526605,0.024932979,,,,,,,,,,,,,,,,, -34300,0.067839295,0.028596591,,,,,,,,,,,,,,,,, -34400,0.07064793,0.031281013,,,,,,,,,,,,,,,,, -34500,0.063971505,0.028909877,,,,,,,,,,,,,,,,, -34600,0.05990135,0.030688751,,,,,,,,,,,,,,,,, -34700,0.05967394,0.029069968,,,,,,,,,,,,,,,,, -34753,,,0.9918820858001708,0.0260654781013727,0.4910209539627918,0.9870484471321106,0.0443894192576408,0.28277370317285,43793.0,0.9861447811126708,0.0475664623081684,0.268727189597655,43793.0,11058.613788604736,15942.403191328049,11058.613788604736,4881.47107052803,1.395246505737305,0.0 -34800,0.06261248,0.028687159,,,,,,,,,,,,,,,,, -34900,0.05636945,0.027966278,,,,,,,,,,,,,,,,, -35000,0.060999617,0.028455354,,,,,,,,,,,,,,,,, -35100,0.06308751,0.030328458,,,,,,,,,,,,,,,,, -35200,0.058577817,0.026776098,,,,,,,,,,,,,,,,, -35300,0.055889264,0.026826505,,,,,,,,,,,,,,,,, -35400,0.06337357,0.027043061,,,,,,,,,,,,,,,,, -35500,0.062014528,0.028480897,,,,,,,,,,,,,,,,, -35506,,,0.991861879825592,0.0261897947639226,0.5024073601537873,0.9869396090507508,0.04402307420969,0.2845682260674991,43793.0,0.9860979914665222,0.0468775518238544,0.2738745127717128,43793.0,11298.787636518478,16288.471473455427,11298.787636518478,4987.312251329422,1.4282660484313965,0.0 -35600,0.05529018,0.026543388,,,,,,,,,,,,,,,,, -35700,0.06598728,0.028936932,,,,,,,,,,,,,,,,, -35800,0.06331518,0.030606186,,,,,,,,,,,,,,,,, -35900,0.08056742,0.029077688,,,,,,,,,,,,,,,,, -36000,0.06727331,0.02982352,,,,,,,,,,,,,,,,, -36100,0.062868655,0.027041975,,,,,,,,,,,,,,,,, -36200,0.07010061,0.030216303,,,,,,,,,,,,,,,,, -36263,,,0.9919314980506896,0.0258906707167625,0.5145215114867434,0.9870853424072266,0.0440011210739612,0.2819405539208721,43793.0,0.9861472845077516,0.0472131557762622,0.2717603088414615,43793.0,11538.803719043732,16633.155037403107,11538.803719043732,5091.926290988922,1.4618382453918457,0.0 -36300,0.055817273,0.027071511,,,,,,,,,,,,,,,,, -36400,0.07615253,0.02894161,,,,,,,,,,,,,,,,, -36500,0.060111344,0.029603105,,,,,,,,,,,,,,,,, -36600,0.071118526,0.02801379,,,,,,,,,,,,,,,,, -36700,0.07059386,0.026304897,,,,,,,,,,,,,,,,, -36800,0.07344301,0.030469745,,,,,,,,,,,,,,,,, -36900,0.06077494,0.029646268,,,,,,,,,,,,,,,,, -37000,0.058889855,0.027437398,,,,,,,,,,,,,,,,, -37020,,,0.9920294880867004,0.0256625600159168,0.525321302551629,0.9870873689651488,0.0440217033028602,0.2841475760462093,43793.0,0.9861915111541748,0.0470620207488536,0.2745681328480498,43793.0,11779.05526304245,16976.350893974304,11779.05526304245,5194.814932107925,1.4971117973327637,0.0 -37100,0.0669552,0.026400393,,,,,,,,,,,,,,,,, -37200,0.057127886,0.025760362,,,,,,,,,,,,,,,,, -37300,0.0634416,0.030032855,,,,,,,,,,,,,,,,, -37400,0.06162749,0.029228235,,,,,,,,,,,,,,,,, -37500,0.056394253,0.024506671,,,,,,,,,,,,,,,,, -37600,0.066356435,0.030295875,,,,,,,,,,,,,,,,, -37700,0.07695384,0.027803581,,,,,,,,,,,,,,,,, -37780,,,0.992198646068573,0.0250362455844879,0.5411420089665825,0.9870870113372804,0.0441103279590606,0.2871141885306458,43793.0,0.986250936985016,0.0473027415573596,0.2781049767731259,43793.0,12018.9160490036,17322.259452581406,12018.9160490036,5300.418438196182,1.9214563369750977,0.0 -37800,0.06261536,0.02839384,,,,,,,,,,,,,,,,, -37900,0.07519338,0.028732061,,,,,,,,,,,,,,,,, -38000,0.06617181,0.02671643,,,,,,,,,,,,,,,,, -38100,0.065255255,0.029039841,,,,,,,,,,,,,,,,, -38200,0.06635869,0.027393095,,,,,,,,,,,,,,,,, -38300,0.07035973,0.028199421,,,,,,,,,,,,,,,,, -38400,0.07125832,0.027747083,,,,,,,,,,,,,,,,, -38500,0.07230716,0.02975372,,,,,,,,,,,,,,,,, -38539,,,0.9921022057533264,0.025213586166501,0.5186278352257488,0.9870630502700806,0.044580589979887,0.2844634555788871,43793.0,0.9861801266670228,0.0478105284273624,0.2695032759656423,43793.0,12259.118973731996,17666.440479516983,12259.118973731996,5404.342231273651,1.955871105194092,0.0 -38600,0.06567787,0.029282589,,,,,,,,,,,,,,,,, -38700,0.074043244,0.025935866,,,,,,,,,,,,,,,,, -38800,0.076866634,0.027782395,,,,,,,,,,,,,,,,, -38900,0.06394879,0.026083563,,,,,,,,,,,,,,,,, -39000,0.06315172,0.028338902,,,,,,,,,,,,,,,,, -39100,0.064646184,0.026191741,,,,,,,,,,,,,,,,, -39200,0.06166722,0.026113674,,,,,,,,,,,,,,,,, -39293,,,0.9923055171966552,0.024532938376069,0.5386582770107501,0.9871178269386292,0.0443666987121105,0.2873161617452175,43793.0,0.9863056540489196,0.0475598722696304,0.2716499638558299,43793.0,12499.194529294968,18007.454338550568,12499.194529294968,5505.22732257843,1.9891891479492188,0.0 -39300,0.060057297,0.025907123,,,,,,,,,,,,,,,,, -39400,0.06450438,0.029311396,,,,,,,,,,,,,,,,, -39500,0.071948275,0.028544594,,,,,,,,,,,,,,,,, -39600,0.065934114,0.026599841,,,,,,,,,,,,,,,,, -39700,0.06338986,0.025858862,,,,,,,,,,,,,,,,, -39800,0.086403415,0.030751852,,,,,,,,,,,,,,,,, -39900,0.075949095,0.026556438,,,,,,,,,,,,,,,,, -40000,0.07496462,0.030708948,,,,,,,,,,,,,,,,, -40044,,,0.9925379157066344,0.0237658787518739,0.5654487645215553,0.9872254133224488,0.044390469789505,0.2868401826016203,43793.0,0.9863005876541138,0.0477026961743831,0.275908618284249,43793.0,12739.169730901718,18349.34113383293,12739.169730901718,5607.085157632828,2.023126602172852,0.0 -40100,0.055569496,0.025636302,,,,,,,,,,,,,,,,, -40200,0.066769,0.02746622,,,,,,,,,,,,,,,,, -40300,0.059925083,0.025466366,,,,,,,,,,,,,,,,, -40400,0.062461067,0.026507042,,,,,,,,,,,,,,,,, -40500,0.06936464,0.02523117,,,,,,,,,,,,,,,,, -40600,0.079450734,0.025134997,,,,,,,,,,,,,,,,, -40700,0.07495677,0.025915,,,,,,,,,,,,,,,,, -40800,0.08214653,0.025865542,,,,,,,,,,,,,,,,, -40804,,,0.9927852749824524,0.023151222616434,0.5775893719350957,0.9870390892028807,0.0445139482617378,0.2907892343854625,43793.0,0.9861422181129456,0.047527328133583,0.276028532270375,43793.0,12979.314190387726,18690.889854192734,12979.314190387726,5708.435259819031,2.0572030544281006,0.0 -40900,0.13037993,0.027642163,,,,,,,,,,,,,,,,, -41000,0.08015654,0.028803827,,,,,,,,,,,,,,,,, -41100,0.08000657,0.028929971,,,,,,,,,,,,,,,,, -41200,0.076629415,0.026360746,,,,,,,,,,,,,,,,, -41300,0.059665523,0.024362834,,,,,,,,,,,,,,,,, -41400,0.06702768,0.02477953,,,,,,,,,,,,,,,,, -41500,0.0789363,0.031027623,,,,,,,,,,,,,,,,, -41566,,,0.9926801323890686,0.0232359189540147,0.5827686459930178,0.9871527552604676,0.0449665486812591,0.2888831560468628,43793.0,0.986311972141266,0.0479333437979221,0.2786746810670965,43793.0,13219.500958919523,19033.90281820297,13219.500958919523,5811.207458734512,2.09119725227356,0.0 -41600,0.0757598,0.0281049,,,,,,,,,,,,,,,,, -41700,0.0666274,0.025410218,,,,,,,,,,,,,,,,, -41800,0.078097634,0.02484096,,,,,,,,,,,,,,,,, -41900,0.0679941,0.02425735,,,,,,,,,,,,,,,,, -42000,0.061662234,0.025242748,,,,,,,,,,,,,,,,, -42100,0.07257701,0.023734275,,,,,,,,,,,,,,,,, -42200,0.07786505,0.027130187,,,,,,,,,,,,,,,,, -42300,0.0896163,0.028129168,,,,,,,,,,,,,,,,, -42326,,,0.9926807880401612,0.0234455466270446,0.5660779500427957,0.9870723485946656,0.0446420200169086,0.2907271142531775,43793.0,0.9861717224121094,0.0477462261915206,0.2741958646714575,43793.0,13459.468119859695,19378.783309936523,13459.468119859695,5916.065999746323,2.125895738601685,0.0 -42400,0.0823107,0.027209446,,,,,,,,,,,,,,,,, -42500,0.07209789,0.025419418,,,,,,,,,,,,,,,,, -42600,0.07512908,0.026680486,,,,,,,,,,,,,,,,, -42700,0.06942878,0.02609246,,,,,,,,,,,,,,,,, -42800,0.07113483,0.02658251,,,,,,,,,,,,,,,,, -42900,0.084927596,0.026258532,,,,,,,,,,,,,,,,, -43000,0.07361956,0.02527555,,,,,,,,,,,,,,,,, -43076,,,0.992499589920044,0.0239414982497692,0.5477740813822252,0.9870922565460204,0.0447935499250888,0.2892499494045117,43793.0,0.9863384962081908,0.0477807931602001,0.2780795705772881,43793.0,13699.443043470385,19719.386711359024,13699.443043470385,6016.614222764969,2.186445474624634,0.0 -43100,0.07977934,0.03011698,,,,,,,,,,,,,,,,, -43200,0.075219505,0.02629191,,,,,,,,,,,,,,,,, -43300,0.0646824,0.025829827,,,,,,,,,,,,,,,,, -43400,0.078105025,0.028222114,,,,,,,,,,,,,,,,, -43500,0.0731333,0.02584035,,,,,,,,,,,,,,,,, -43600,0.08904543,0.026988937,,,,,,,,,,,,,,,,, -43700,0.0665172,0.026911141,,,,,,,,,,,,,,,,, -43800,0.083058074,0.025751483,,,,,,,,,,,,,,,,, -43826,,,0.9925187826156616,0.0238534715026617,0.5623201160024388,0.9870967268943788,0.0447729937732219,0.290604844012177,43793.0,0.9862349033355712,0.0479205027222633,0.2761384758753186,43793.0,13939.41558265686,20063.16539645195,13939.41558265686,6120.364076852799,2.2230827808380127,0.0 -43900,0.06565802,0.025001764,,,,,,,,,,,,,,,,, -44000,0.07152749,0.025053814,,,,,,,,,,,,,,,,, -44100,0.08757192,0.026471103,,,,,,,,,,,,,,,,, -44200,0.07985068,0.02751519,,,,,,,,,,,,,,,,, -44300,0.0903382,0.027705988,,,,,,,,,,,,,,,,, -44400,0.07765311,0.022163108,,,,,,,,,,,,,,,,, -44500,0.07868834,0.02532936,,,,,,,,,,,,,,,,, -44573,,,0.9926663041114808,0.0234936866909265,0.5641044492073488,0.9869808554649352,0.0448011122643947,0.2861025098978526,43793.0,0.986177384853363,0.0478943772614002,0.2796849265893141,43793.0,14179.501737833025,20406.76056933403,14179.501737833025,6223.816519021988,2.2597038745880127,0.0 -44600,0.083567746,0.029559728,,,,,,,,,,,,,,,,, -44700,0.09108621,0.02446513,,,,,,,,,,,,,,,,, -44800,0.07490741,0.02629754,,,,,,,,,,,,,,,,, -44900,0.0819295,0.027251473,,,,,,,,,,,,,,,,, -45000,0.077768765,0.024733955,,,,,,,,,,,,,,,,, -45100,0.091464184,0.027291575,,,,,,,,,,,,,,,,, -45200,0.07223809,0.02592121,,,,,,,,,,,,,,,,, -45300,0.07086152,0.023523718,,,,,,,,,,,,,,,,, -45329,,,0.992615520954132,0.0235337913036346,0.5735633061987906,0.987092673778534,0.0445826053619384,0.2924033260479522,43793.0,0.9862361550331116,0.0477320775389671,0.2748628905795507,43793.0,14419.762070655825,20753.566682338715,14419.762070655825,6330.305959939957,2.296353340148926,0.0 -45400,0.07172712,0.024109248,,,,,,,,,,,,,,,,, -45500,0.08904166,0.02531318,,,,,,,,,,,,,,,,, -45600,0.090109006,0.027100153,,,,,,,,,,,,,,,,, -45700,0.10110574,0.027000604,,,,,,,,,,,,,,,,, -45800,0.070198044,0.023953462,,,,,,,,,,,,,,,,, -45900,0.08360091,0.02547152,,,,,,,,,,,,,,,,, -46000,0.07826702,0.027340427,,,,,,,,,,,,,,,,, -46086,,,0.9928174614906312,0.0227997712790966,0.5843304260900297,0.9871162176132202,0.0450176522135734,0.2908162408747274,43793.0,0.9862534403800964,0.0483238622546196,0.2770070234684541,43793.0,14659.82816028595,21098.419695854187,14659.82816028595,6435.036421775818,2.332881212234497,0.0 -46100,0.07515926,0.023935925,,,,,,,,,,,,,,,,, -46200,0.07900253,0.023827579,,,,,,,,,,,,,,,,, -46300,0.07247596,0.026051728,,,,,,,,,,,,,,,,, -46400,0.0799658,0.023743676,,,,,,,,,,,,,,,,, -46500,0.09056982,0.025396338,,,,,,,,,,,,,,,,, -46600,0.089813136,0.02666249,,,,,,,,,,,,,,,,, -46700,0.08785889,0.028898573,,,,,,,,,,,,,,,,, -46800,0.08515427,0.025613923,,,,,,,,,,,,,,,,, -46840,,,0.9929105043411256,0.0224435105919837,0.583947243919866,0.9870622158050536,0.0450861379504203,0.2884127098754348,43793.0,0.9863351583480836,0.04808259755373,0.2740103134893802,43793.0,14899.894470214844,21442.05270147324,14899.894470214844,6538.548140287399,2.3681788444519043,0.0 -46900,0.10274233,0.028067874,,,,,,,,,,,,,,,,, -47000,0.080112875,0.024483383,,,,,,,,,,,,,,,,, -47100,0.07921959,0.025532968,,,,,,,,,,,,,,,,, -47200,0.10611947,0.026016762,,,,,,,,,,,,,,,,, -47300,0.08584018,0.027717343,,,,,,,,,,,,,,,,, -47400,0.07616165,0.024578648,,,,,,,,,,,,,,,,, -47500,0.09483766,0.025892463,,,,,,,,,,,,,,,,, -47594,,,0.99306058883667,0.0221116058528423,0.6052337144115114,0.9869980812072754,0.0450661145150661,0.2931761841290524,43793.0,0.9861367344856262,0.0480062998831272,0.2781953903528778,43793.0,15139.885927677156,21789.244199752808,15139.885927677156,6645.693230390549,2.4038455486297607,0.0 -47600,0.08129989,0.0252456,,,,,,,,,,,,,,,,, -47700,0.08224773,0.026214285,,,,,,,,,,,,,,,,, -47800,0.09852437,0.026127841,,,,,,,,,,,,,,,,, -47900,0.088856824,0.02607576,,,,,,,,,,,,,,,,, -48000,0.07670254,0.023125073,,,,,,,,,,,,,,,,, -48100,0.07555811,0.024339749,,,,,,,,,,,,,,,,, -48200,0.07931499,0.0237766,,,,,,,,,,,,,,,,, -48300,0.07952816,0.021392355,,,,,,,,,,,,,,,,, -48354,,,0.9933203458786012,0.0213257987052202,0.6125192019839938,0.9870272874832152,0.0449936538934707,0.2895411568214211,43793.0,0.9861510992050172,0.0481312908232212,0.2733146559294196,43793.0,15380.088567256927,22131.06327676773,15380.088567256927,6747.254262685776,2.439406394958496,0.0 -48400,0.08750572,0.024295855,,,,,,,,,,,,,,,,, -48500,0.10084423,0.027888123,,,,,,,,,,,,,,,,, -48600,0.08366789,0.024951804,,,,,,,,,,,,,,,,, -48700,0.09000765,0.02549371,,,,,,,,,,,,,,,,, -48800,0.09296532,0.02679725,,,,,,,,,,,,,,,,, -48900,0.08312482,0.023110427,,,,,,,,,,,,,,,,, -49000,0.08488233,0.024455512,,,,,,,,,,,,,,,,, -49100,0.08955977,0.021919414,,,,,,,,,,,,,,,,, -49113,,,0.9934682250022888,0.02071463316679,0.6426092112092957,0.9870119094848632,0.045634426176548,0.2912810678611326,43793.0,0.9861595034599304,0.0488424561917781,0.273296710563419,43793.0,15620.312088727953,22473.058209180832,15620.312088727953,6848.968739748001,2.4766972064971924,0.0 -49200,0.09761613,0.027538972,,,,,,,,,,,,,,,,, -49300,0.09021009,0.024865264,,,,,,,,,,,,,,,,, -49400,0.09735661,0.025410714,,,,,,,,,,,,,,,,, -49500,0.07976919,0.022598805,,,,,,,,,,,,,,,,, -49600,0.09345383,0.025349328,,,,,,,,,,,,,,,,, -49700,0.094323166,0.027476933,,,,,,,,,,,,,,,,, -49800,0.089483395,0.02578621,,,,,,,,,,,,,,,,, -49870,,,0.993198812007904,0.0213176608085632,0.6175284713675935,0.9872095584869384,0.0462044775485992,0.2915582478001256,43793.0,0.98635071516037,0.0493396520614624,0.2741738490622731,43793.0,15860.463735103607,22813.64751195908,15860.463735103607,6949.345129728317,2.5182931423187256,0.0 -49900,0.08839648,0.022849288,,,,,,,,,,,,,,,,, -50000,0.10039382,0.02395129,,,,,,,,,,,,,,,,, -50100,0.09398212,0.025908394,,,,,,,,,,,,,,,,, -50200,0.07833172,0.021518342,,,,,,,,,,,,,,,,, -50300,0.088446416,0.022697518,,,,,,,,,,,,,,,,, -50400,0.10647886,0.024069224,,,,,,,,,,,,,,,,, -50500,0.085024096,0.024249068,,,,,,,,,,,,,,,,, -50600,0.09725208,0.0242876,,,,,,,,,,,,,,,,, -50630,,,0.9932072758674622,0.0215307734906673,0.6065806372072025,0.9869790077209472,0.0456237904727458,0.2889365308084413,43793.0,0.9861738085746764,0.0487218536436557,0.2741141156405496,43793.0,16100.710295438766,23156.29674863816,16100.710295438766,7051.691777467728,2.554926872253418,0.0 -50700,0.09684337,0.024116425,,,,,,,,,,,,,,,,, -50800,0.0878536,0.0248509,,,,,,,,,,,,,,,,, -50900,0.08267153,0.024376608,,,,,,,,,,,,,,,,, -51000,0.083126694,0.024618335,,,,,,,,,,,,,,,,, -51100,0.11465116,0.025792371,,,,,,,,,,,,,,,,, -51200,0.088590905,0.025115425,,,,,,,,,,,,,,,,, -51300,0.09378345,0.02347063,,,,,,,,,,,,,,,,, -51395,,,0.9932363629341124,0.0213737972080707,0.6189733847194379,0.9871328473091124,0.0456616654992103,0.2887289210986936,43793.0,0.9862711429595948,0.0488883107900619,0.2745948013890646,43793.0,16340.700924158096,23497.049400806427,16340.700924158096,7152.396923303604,2.5921742916107178,0.0 -51400,0.074171476,0.02210599,,,,,,,,,,,,,,,,, -51500,0.09837264,0.023531659,,,,,,,,,,,,,,,,, -51600,0.10275896,0.026203038,,,,,,,,,,,,,,,,, -51700,0.080893405,0.022684146,,,,,,,,,,,,,,,,, -51800,0.102321655,0.023430152,,,,,,,,,,,,,,,,, -51900,0.10206988,0.026211057,,,,,,,,,,,,,,,,, -52000,0.09768881,0.024263108,,,,,,,,,,,,,,,,, -52100,0.095633104,0.026313215,,,,,,,,,,,,,,,,, -52159,,,0.9932331442832948,0.0213511120527982,0.6257244482281129,0.9869733452796936,0.0459792278707027,0.2871613198633184,43793.0,0.9862285852432252,0.0490499958395957,0.2820580224329422,43793.0,16580.71767473221,23838.59583926201,16580.71767473221,7253.871799230576,2.6272754669189453,0.0 -52200,0.08645384,0.022944529,,,,,,,,,,,,,,,,, -52300,0.087455645,0.022429077,,,,,,,,,,,,,,,,, -52400,0.09459072,0.02414929,,,,,,,,,,,,,,,,, -52500,0.08875874,0.02721901,,,,,,,,,,,,,,,,, -52600,0.11248671,0.02544354,,,,,,,,,,,,,,,,, -52700,0.105657026,0.021873794,,,,,,,,,,,,,,,,, -52800,0.10035535,0.024584223,,,,,,,,,,,,,,,,, -52900,0.101445906,0.02626336,,,,,,,,,,,,,,,,, -52918,,,0.9933499693870544,0.0209690537303686,0.6198161926220507,0.9870723485946656,0.04600640386343,0.2898691017552575,43793.0,0.9862694144248962,0.0490830689668655,0.2786305604471064,43793.0,16820.859786748886,24180.04267501831,16820.859786748886,7355.121006250381,2.663091897964477,0.0 -53000,0.10610035,0.024802787,,,,,,,,,,,,,,,,, -53100,0.10738879,0.024901947,,,,,,,,,,,,,,,,, -53200,0.09616361,0.022731459,,,,,,,,,,,,,,,,, -53300,0.11057663,0.028341966,,,,,,,,,,,,,,,,, -53400,0.086306944,0.022257542,,,,,,,,,,,,,,,,, -53500,0.097895645,0.024086643,,,,,,,,,,,,,,,,, -53600,0.09897125,0.02392455,,,,,,,,,,,,,,,,, -53674,,,0.9933534264564514,0.020808607339859,0.629301270151079,0.9871153831481934,0.0463089346885681,0.2890898983273863,43793.0,0.9862861037254332,0.0495411418378353,0.2794267027148863,43793.0,17061.08062529564,24523.11379837989,17061.08062529564,7457.915698289871,2.698948621749878,0.0 -53700,0.11104071,0.023816168,,,,,,,,,,,,,,,,, -53800,0.094566286,0.022318935,,,,,,,,,,,,,,,,, -53900,0.08929176,0.022853812,,,,,,,,,,,,,,,,, -54000,0.104897015,0.025041586,,,,,,,,,,,,,,,,, -54100,0.09904857,0.02213507,,,,,,,,,,,,,,,,, -54200,0.114689045,0.024098707,,,,,,,,,,,,,,,,, -54300,0.10132828,0.023660786,,,,,,,,,,,,,,,,, -54400,0.12659095,0.023139222,,,,,,,,,,,,,,,,, -54432,,,0.9934797883033752,0.0204735491424798,0.62993081766115,0.9871438145637512,0.0462908633053302,0.2877251097946155,43793.0,0.986311972141266,0.0494890883564949,0.276395624635202,43793.0,17301.122370243073,24862.13622522354,17301.122370243073,7556.839626789093,2.735714912414551,0.0 -54500,0.09684514,0.023954885,,,,,,,,,,,,,,,,, -54600,0.095387906,0.021233214,,,,,,,,,,,,,,,,, -54700,0.11293123,0.02501589,,,,,,,,,,,,,,,,, -54800,0.10385654,0.02170539,,,,,,,,,,,,,,,,, -54900,0.100466065,0.02352211,,,,,,,,,,,,,,,,, -55000,0.09615887,0.020898852,,,,,,,,,,,,,,,,, -55100,0.10223402,0.024477195,,,,,,,,,,,,,,,,, -55184,,,0.9935911297798156,0.020000223070383,0.6547790093724063,0.9871222972869872,0.0461614653468132,0.2900952606281319,43793.0,0.986243724822998,0.0493969358503818,0.2795634855130245,43793.0,17541.313578367233,25207.003221273422,17541.313578367233,7661.45597743988,2.774885654449463,0.0 -55200,0.11018653,0.022187715,,,,,,,,,,,,,,,,, -55300,0.1020423,0.023580657,,,,,,,,,,,,,,,,, -55400,0.103537045,0.024450202,,,,,,,,,,,,,,,,, -55500,0.09787756,0.0214605,,,,,,,,,,,,,,,,, -55600,0.108663194,0.02295037,,,,,,,,,,,,,,,,, -55700,0.10085635,0.021405173,,,,,,,,,,,,,,,,, -55800,0.09766624,0.021874156,,,,,,,,,,,,,,,,, -55900,0.09712295,0.02089955,,,,,,,,,,,,,,,,, -55944,,,0.9937967658042908,0.0194714777171611,0.6527825708898463,0.987092673778534,0.0467407219111919,0.2911061819818859,43793.0,0.9862310886383056,0.0500529184937477,0.2778119291983832,43793.0,17781.294590473175,25546.138780355453,17781.294590473175,7760.551167964935,2.8143293857574463,0.0 -56000,0.09704724,0.020178527,,,,,,,,,,,,,,,,, -56100,0.10626347,0.023137294,,,,,,,,,,,,,,,,, -56200,0.100626096,0.023490729,,,,,,,,,,,,,,,,, -56300,0.10983499,0.022134887,,,,,,,,,,,,,,,,, -56400,0.111367606,0.022196911,,,,,,,,,,,,,,,,, -56500,0.11275341,0.024606217,,,,,,,,,,,,,,,,, -56600,0.097592495,0.021769566,,,,,,,,,,,,,,,,, -56700,0.11516929,0.022974608,,,,,,,,,,,,,,,,, -56701,,,0.9940793514251708,0.0186166800558567,0.6898328240343279,0.987098753452301,0.0466601215302944,0.2899552937529251,43793.0,0.9862062335014344,0.0498196221888065,0.282412806094966,43793.0,18021.25959467888,25890.77930259705,18021.25959467888,7865.168210506439,2.853222131729126,0.0 -56800,0.098506704,0.021994837,,,,,,,,,,,,,,,,, -56900,0.10381998,0.020360464,,,,,,,,,,,,,,,,, -57000,0.09975175,0.024521213,,,,,,,,,,,,,,,,, -57100,0.10583009,0.022818735,,,,,,,,,,,,,,,,, -57200,0.1267304,0.023893187,,,,,,,,,,,,,,,,, -57300,0.12116043,0.023203252,,,,,,,,,,,,,,,,, -57400,0.10581498,0.022278547,,,,,,,,,,,,,,,,, -57456,,,0.9942800998687744,0.0180834010243415,0.6976413701881876,0.9871283769607544,0.0467562936246395,0.2962188183781732,43793.0,0.9862176179885864,0.0501618832349777,0.2802577179914768,43793.0,18261.490270614624,26231.215923547745,18261.490270614624,7965.316137075424,2.891075849533081,0.0 -57500,0.10700099,0.021514041,,,,,,,,,,,,,,,,, -57600,0.113459446,0.023134716,,,,,,,,,,,,,,,,, -57700,0.11505164,0.025557088,,,,,,,,,,,,,,,,, -57800,0.10034347,0.019657273,,,,,,,,,,,,,,,,, -57900,0.113618806,0.019872693,,,,,,,,,,,,,,,,, -58000,0.10889889,0.022283172,,,,,,,,,,,,,,,,, -58100,0.101442516,0.023295645,,,,,,,,,,,,,,,,, -58138,,,,,,,,,,,,,,18477.238425970078,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 6861fc288..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,57 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -879.4209234714508,0.0,43.30344867706299,1,0,43.30344867706299,0.0007088489946909,0.0,11.036273956298828,3003,922.7244355678558,0.0006987399538047,0.0,11.025596618652344,0.0004835649742744,0.0,11.047277450561523,3000 -1431.4115691184998,0.0320186614990234,883.4170672893524,2430,0,883.4170672893524,0.3871826231479645,8.292150497981243,4.267494201660156,3003,2314.9411220550537,0.4160743057727813,14.572851959066698,3.9471724033355713,0.3996106684207916,9.879816759300748,4.088059425354004,3000 -1917.9620339870453,0.062225341796875,1723.4967126846311,4858,0,1723.4967126846311,0.5444541573524475,19.09181470174716,2.736154556274414,3003,3641.679195165634,0.5455783009529114,24.37134457556304,2.718634366989136,0.5439361929893494,20.40709098587339,2.697274923324585,3000 -2385.4128901958466,0.0925908088684082,2563.5470848083496,7287,0,2563.5470848083496,0.5892278552055359,22.25110580174832,2.3010103702545166,3003,4949.287957668304,0.582105815410614,27.688640848896345,2.346024751663208,0.5865395069122314,23.46770972919153,2.3033597469329834,3000 -2825.075345516205,0.1242415904998779,3403.686936378479,9718,0,3403.686936378479,0.6132822036743164,23.29025859698584,2.0930325984954834,3003,6229.200881242752,0.59144127368927,27.976897526423063,2.244596481323242,0.6069732308387756,24.509640670153523,2.119935989379883,3000 -3456.4474980831146,0.1512916088104248,4243.90465593338,12150,0,4243.90465593338,0.6300621628761292,24.743552012654888,1.9602621793746948,3003,7700.897699356079,0.6025044322013855,28.804964846373508,2.141352653503418,0.6188392043113708,25.295877114636344,2.0071840286254883,3000 -3898.492534637451,0.1806375980377197,5084.021989107132,14582,0,5084.021989107132,0.6397652626037598,25.7206136569018,1.8675507307052608,3003,8983.16919708252,0.6175777316093445,29.49418838670804,2.0234029293060303,0.6317466497421265,26.446092461664627,1.9158974885940552,3000 -4382.143723726273,0.3407480716705322,5923.943639755249,17013,0,5923.943639755249,0.6494567394256592,26.089327875831792,1.7908827066421509,3003,10306.980433225632,0.6190584897994995,30.250872839385124,2.003643274307251,0.6397812962532043,26.92715774325129,1.84742283821106,3000 -4849.296671628952,0.3672385215759277,6764.000878095627,19445,0,6764.000878095627,0.6539306640625,26.60863093886362,1.760025978088379,3003,11614.296314954758,0.6344646215438843,30.712155809537048,1.891523718833924,0.6446169018745422,27.5092599023563,1.809007167816162,3000 -5349.620749235153,0.3979494571685791,7604.201534986496,21878,0,7604.201534986496,0.6575562357902527,26.950158208181367,1.7227576971054075,3003,12954.930808782578,0.6279143691062927,30.117385200900088,1.938279628753662,0.6497625708580017,27.39135952713508,1.7786868810653689,3000 -6040.07980966568,0.4274477958679199,8444.363789081573,24311,0,8444.363789081573,0.6625530123710632,27.38874092003688,1.7036962509155271,3003,14485.660428762436,0.6267665028572083,30.495337151691416,1.955460786819458,0.6528747081756592,27.740078969040887,1.7616363763809204,3000 -6516.639975786209,0.4582266807556152,9284.414820194244,26744,0,9284.414820194244,0.6628900170326233,27.04296441504016,1.6931463479995728,3003,15802.381581544876,0.6343207955360413,31.050486129125893,1.8967632055282595,0.6534822583198547,28.09813306279128,1.7492674589157104,3000 -6996.507694721222,0.4879882335662842,10124.47589802742,29177,0,10124.47589802742,0.665864884853363,27.865501179623493,1.6830744743347168,3003,17122.42004466057,0.6336384415626526,30.684595268987056,1.8958899974823,0.6565200686454773,28.03879571932596,1.7404468059539795,3000 -7457.794209480286,0.5176031589508057,10964.547756910324,31610,0,10964.547756910324,0.6660043001174927,27.773628041650408,1.6788660287857056,3003,18423.88781070709,0.6526271104812622,32.05716202678791,1.768532156944275,0.654350221157074,28.03466477961463,1.736448884010315,3000 -7946.703389883041,0.5480039119720459,11804.70926952362,34044,0,11804.70926952362,0.6665620803833008,27.55349728228144,1.669747233390808,3003,19753.06885743141,0.6363418102264404,31.210927634435524,1.877790093421936,0.65668123960495,28.09483148598837,1.7198882102966309,3000 -8539.038867473602,0.5769636631011963,12644.77906513214,36478,0,12644.77906513214,0.6687583923339844,27.75639606929041,1.6572718620300293,3003,21185.581884860992,0.6364160180091858,31.076741563059336,1.8722435235977173,0.6580327749252319,28.03697177435267,1.717186689376831,3000 -9024.285813570024,0.6064877510070801,13484.840069770811,38912,0,13484.840069770811,0.6716634631156921,27.945611634666346,1.6385266780853271,3003,22510.99568128585,0.6482433080673218,31.43274199855436,1.7970041036605835,0.6609093546867371,28.180620886314376,1.703112006187439,3000 -9498.646374464037,0.6380147933959961,14324.897027015686,41346,0,14324.897027015686,0.6728255152702332,28.00362928869641,1.63887357711792,3003,23825.52012705803,0.6373471021652222,31.01315039572357,1.8696467876434328,0.6601405739784241,28.38573685693433,1.7056689262390137,3000 -9989.341492176056,0.668973445892334,15164.826464653015,43780,0,15164.826464653015,0.6713381409645081,27.69613416655745,1.6407963037490845,3003,25156.252128362656,0.6370298266410828,31.06694249891112,1.8723466396331787,0.6600537896156311,28.26108878119694,1.697940468788147,3000 -10469.45325231552,0.7002456188201904,16004.92531490326,46214,0,16004.92531490326,0.6738249063491821,28.05211040280357,1.6290220022201538,3003,26476.57047533989,0.6392388343811035,31.98645200096553,1.8472505807876587,0.6632279753684998,28.6184400125844,1.6884485483169556,3000 -10951.921907424929,0.731576681137085,16845.00081062317,48648,0,16845.00081062317,0.6732555031776428,28.00563946047731,1.6297379732131958,3003,27799.22328066826,0.6418137550354004,31.435799419526703,1.8364354372024536,0.662161648273468,28.773273751970383,1.6864773035049438,3000 -11405.755249738691,0.7633512020111084,17684.95172381401,51082,0,17684.95172381401,0.6755679845809937,28.39498522119112,1.6110336780548096,3003,29093.11604142189,0.6506011486053467,32.20130396977178,1.7793463468551636,0.6655961871147156,28.732523018514083,1.6724234819412231,3000 -11877.134405851364,0.7947971820831299,18525.0443277359,53516,0,18525.0443277359,0.6770902276039124,28.199818862678665,1.6028481721878052,3003,30404.6956615448,0.6455166935920715,31.56522640495758,1.8138216733932493,0.6665757298469543,28.52078753000625,1.6703606843948364,3000 -12500.575400590897,0.8282985687255859,19365.25709867477,55951,0,19365.25709867477,0.680553138256073,28.54534410724413,1.5927239656448364,3003,31868.45972037316,0.6445412635803223,32.05229143406017,1.82591450214386,0.6659185886383057,28.706732866328736,1.664124608039856,3000 -13047.804039001465,0.8610539436340332,20205.149728775024,58385,0,20205.149728775024,0.6782174110412598,28.35522166002915,1.5912344455718994,3003,33255.68974637985,0.6512637138366699,31.62150612676097,1.7649765014648438,0.6669105291366577,28.5506335590646,1.6533396244049072,3000 -13514.856243610382,0.8937373161315918,21045.22128152848,60820,0,21045.22128152848,0.6813201308250427,28.69376951309905,1.5846421718597412,3003,34562.92338228226,0.6503833532333374,31.359232455254705,1.7853611707687378,0.6674684882164001,28.79202635866657,1.655722975730896,3000 -13980.560697555542,0.9289267063140868,21885.19492340088,63255,0,21885.19492340088,0.6831212639808655,28.78299893054562,1.5691425800323486,3003,35868.71333575249,0.6639273166656494,32.86530003174533,1.6947698593139648,0.6687827706336975,29.130301380967907,1.639392375946045,3000 -14488.048207998276,0.964684009552002,22725.28811383248,65690,0,22725.28811383248,0.6841439008712769,28.850066271740804,1.5664318799972534,3003,37216.40605187416,0.6493473649024963,32.37563384288451,1.7821170091629028,0.6715725660324097,29.14645826795015,1.6345800161361694,3000 -14976.696114301682,0.99855375289917,23565.320879220963,68125,0,23565.320879220963,0.6857939958572388,29.1256183217246,1.5533937215805054,3003,38545.19595956802,0.6522968411445618,32.26716561149852,1.7712243795394895,0.6721181273460388,29.2856531151115,1.6278233528137207,3000 -15442.105145931244,1.085925817489624,24405.44509220124,70560,0,24405.44509220124,0.6852827072143555,28.89788171712957,1.5494496822357178,3003,39850.8931684494,0.6611246466636658,32.33671503844988,1.709555745124817,0.6735812425613403,29.11200803725564,1.61667001247406,3000 -15908.294090032578,1.1230123043060305,25245.6302895546,72995,0,25245.6302895546,0.6878392100334167,29.30915630269909,1.5402271747589111,3003,41157.38154554367,0.6553438901901245,32.79811780395135,1.7596919536590576,0.6746723651885986,29.66018880236669,1.609011173248291,3000 -16438.57457280159,1.1592369079589844,26085.57997250557,75429,0,26085.57997250557,0.689012885093689,29.0940858153008,1.5321506261825562,3003,42527.72531962395,0.6763533353805542,33.937369926587024,1.624076247215271,0.6763462424278259,29.693910913014506,1.6011841297149658,3000 -16923.432988643646,1.1936464309692385,26925.578989982605,77864,0,26925.578989982605,0.6887107491493225,29.20757973008861,1.527599573135376,3003,43852.69458556175,0.6609399318695068,32.66516373016686,1.7068854570388794,0.6770405769348145,29.53969753202616,1.597065806388855,3000 -17435.547990322113,1.230426788330078,27765.792411088943,80299,0,27765.792411088943,0.6894195675849915,29.1376581392568,1.5198761224746704,3003,45205.13687705994,0.6609601974487305,32.26415699149423,1.7184412479400637,0.6776357293128967,29.55108559881877,1.5888279676437378,3000 -17891.961584091187,1.264892816543579,28605.77160620689,82733,0,28605.77160620689,0.6937772631645203,29.74323986290128,1.509999394416809,3003,46501.64113926888,0.6679980754852295,33.5625012966058,1.673125147819519,0.6785408854484558,29.696808866105183,1.587047100067139,3000 -18397.664662599564,1.301595687866211,29445.863482236862,85167,0,29445.863482236862,0.6938701868057251,29.60024070059604,1.5030529499053955,3003,47847.549050569534,0.6636251211166382,32.7424648321825,1.700329303741455,0.6798551678657532,29.852257792933663,1.578163981437683,3000 -18879.404970645905,1.3394997119903564,30286.26316356659,87601,0,30286.26316356659,0.6945558190345764,29.48370587895101,1.498718023300171,3003,49169.80411338806,0.7056272029876709,35.86558291442635,1.4632694721221924,0.6798055768013,29.87314936635056,1.5734652280807495,3000 -19360.830275535583,1.3778259754180908,31126.28993988037,90035,0,31126.28993988037,0.6961826682090759,29.9603429821149,1.489499807357788,3003,50491.37156367302,0.6675459742546082,33.33141123987844,1.6747876405715942,0.6829301714897156,29.92197180523269,1.564455509185791,3000 -19840.359940052032,1.4165542125701904,31966.21331453324,92469,0,31966.21331453324,0.6978908777236938,30.035932898851115,1.4811248779296875,3003,51810.94109606743,0.6658973097801208,33.660730787562834,1.6872364282608032,0.6835625171661377,30.26361366275628,1.5563279390335083,3000 -20324.09602546692,1.453615665435791,32806.30698490143,94903,0,32806.30698490143,0.7008076310157776,29.92625119054191,1.4715487957000732,3003,53134.88615298271,0.6817435026168823,33.982398172121265,1.5854727029800415,0.6839964985847473,30.10177069772297,1.5531708002090454,3000 -20785.443420887,1.492318868637085,33646.52457332611,97337,0,33646.52457332611,0.7019813060760498,30.23504562070128,1.4636123180389404,3003,54436.56728100777,0.6754931807518005,33.42584230300187,1.6232917308807373,0.6868730783462524,30.324988297327963,1.5403165817260742,3000 -21264.37947511673,1.5324418544769287,34486.669437885284,99771,0,34486.669437885284,0.7032363414764404,30.214177626963963,1.4567337036132812,3003,55755.76510024071,0.6739282011985779,33.990486539139006,1.63223135471344,0.6875302195549011,30.240875818928536,1.536251187324524,3000 -21716.28870844841,1.5729172229766846,35326.57743191719,102204,0,35326.57743191719,0.7022369503974915,30.21770641738397,1.456139326095581,3003,57047.69973301888,0.6823273301124573,34.581900656888564,1.5801806449890137,0.6878277659416199,30.43020782921961,1.5308395624160769,3000 -22211.2339220047,1.612227439880371,36166.62100124359,104638,0,36166.62100124359,0.7040613889694214,30.680197257595136,1.444055676460266,3003,58382.8054254055,0.6786729693412781,34.14329271426784,1.605922818183899,0.688373327255249,30.576941500675023,1.5237150192260742,3000 -22677.718291044235,1.6531658172607422,37006.69671726227,107072,0,37006.69671726227,0.7058277130126953,30.48795900641476,1.44207763671875,3003,59689.48287606239,0.6972777843475342,35.75124263193658,1.501353740692139,0.6902828216552734,30.50713657492696,1.5222079753875732,3000 -23166.557002544403,1.693382978439331,37846.90900039673,109506,0,37846.90900039673,0.7075126767158508,30.78264729883304,1.4303762912750244,3003,61018.65080690384,0.6877254843711853,35.05868995744967,1.5565861463546753,0.6905431747436523,30.72632314292713,1.5144131183624268,3000 -23638.802623033524,1.736116647720337,38686.852848529816,111939,0,38686.852848529816,0.707129180431366,30.68408721481186,1.427580952644348,3003,62330.95977449417,0.6905927658081055,34.9471526156188,1.540831446647644,0.6911631226539612,30.61271867582907,1.5094802379608154,3000 -24104.972608327866,1.7767961025238037,39526.98830747605,114372,0,39526.98830747605,0.708628237247467,30.82562319625663,1.4222077131271362,3003,63637.38416719437,0.7004918456077576,36.07395854939219,1.4847887754440308,0.6930230259895325,30.78058822928789,1.5063766241073608,3000 -24577.193524599075,1.8184754848480225,40366.93438744545,116805,0,40366.93438744545,0.709511399269104,30.89250024411417,1.4227913618087769,3003,64949.66950559616,0.6977561116218567,35.92062263456214,1.5053915977478027,0.6923906803131104,30.750445632725896,1.5051144361495972,3000 -25054.274214982983,1.8607072830200195,41207.07846236229,119239,0,41207.07846236229,0.7099180817604065,30.699474406255312,1.4181479215621948,3003,66267.013463974,0.7087553143501282,36.90472472820269,1.4513678550720217,0.6937917470932007,31.007339831048924,1.5004736185073853,3000 -25522.81412148476,1.9024908542633057,42047.01487851143,121672,0,42047.01487851143,0.7114868760108948,30.87680663083024,1.4147039651870728,3003,67575.60860204697,0.7050728797912598,36.167227486174745,1.468165159225464,0.6942753195762634,30.98735750993361,1.498553991317749,3000 -25989.104485034943,1.9998183250427248,42886.95123958588,124104,0,42886.95123958588,0.7113357782363892,30.965012996729225,1.4104185104370115,3003,68882.01002573967,0.7041686773300171,36.28302984249575,1.464862823486328,0.6946101188659668,30.891498149474604,1.495320200920105,3000 -26465.41655921936,2.0420727729797363,43726.96578860283,126537,0,43726.96578860283,0.7120446562767029,31.058713552731536,1.4122170209884644,3003,70198.45528745651,0.7092657685279846,36.99707375984359,1.449809193611145,0.693990170955658,30.97352039029193,1.496898889541626,3000 -26926.0817000866,2.082228422164917,44566.927010297775,128970,0,44566.927010297775,0.7122886776924133,30.86724253804981,1.4110324382781982,3003,71499.1986413002,0.7105950117111206,37.02532880316543,1.435489535331726,0.6939777731895447,31.032989289148937,1.4965522289276123,3000 -27392.42360830307,2.123875141143799,45407.02690410614,131403,0,45407.02690410614,0.7123119235038757,30.963210714876272,1.4104419946670532,3003,72805.75824856758,0.705418586730957,36.74133977241949,1.4653115272521973,0.6943249106407166,30.95072385384553,1.496314287185669,3000 -27869.805367946625,2.1672544479370117,46073.19484710693,133333,0,46073.19484710693,0.7121259570121765,30.99405716021917,1.4103668928146362,3003,73949.41149926186,0.7103955149650574,36.525395183981935,1.443301796913147,0.6940893530845642,30.98135216301977,1.4960108995437622,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 3c4baa1fb..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1392 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.1712017,11.0257225,,,,,,,,,,,,,,,,, -1,,,0.0006987399538047,11.025596618652344,0.0,0.0004835649742744,11.047277450561523,0.0,3000.0,0.0007088489946909,11.036273956298828,0.0,3003.0,43.30344867706299,922.7244355678558,43.30344867706299,879.4209234714508,0.0,0.0 -100,0.3851019,8.9368305,,,,,,,,,,,,,,,,, -200,0.1641899,8.642636,,,,,,,,,,,,,,,,, -300,0.16898745,8.394832,,,,,,,,,,,,,,,,, -400,0.257133,8.007311,,,,,,,,,,,,,,,,, -500,0.26825276,7.6315174,,,,,,,,,,,,,,,,, -600,0.4216326,7.3458133,,,,,,,,,,,,,,,,, -700,0.57517725,7.205698,,,,,,,,,,,,,,,,, -800,0.71955097,6.9412518,,,,,,,,,,,,,,,,, -900,0.60490173,6.7433176,,,,,,,,,,,,,,,,, -1000,0.44264117,6.566967,,,,,,,,,,,,,,,,, -1100,0.53146935,6.432264,,,,,,,,,,,,,,,,, -1200,0.5997828,6.166048,,,,,,,,,,,,,,,,, -1300,0.5674819,6.101433,,,,,,,,,,,,,,,,, -1400,0.6836199,5.9339643,,,,,,,,,,,,,,,,, -1500,0.934188,5.829527,,,,,,,,,,,,,,,,, -1600,0.7273587,5.708962,,,,,,,,,,,,,,,,, -1700,0.9930723,5.661512,,,,,,,,,,,,,,,,, -1800,0.6728308,5.422971,,,,,,,,,,,,,,,,, -1900,0.97025,5.404276,,,,,,,,,,,,,,,,, -2000,0.9927054,5.319634,,,,,,,,,,,,,,,,, -2100,0.9574005,5.0800104,,,,,,,,,,,,,,,,, -2200,0.86829203,5.101583,,,,,,,,,,,,,,,,, -2300,0.95582,4.937293,,,,,,,,,,,,,,,,, -2400,0.98412985,4.8351955,,,,,,,,,,,,,,,,, -2430,,,0.4160743057727813,3.9471724033355713,14.572851959066698,0.3996106684207916,4.088059425354004,9.879816759300748,3000.0,0.3871826231479645,4.267494201660156,8.292150497981243,3003.0,883.4170672893524,2314.9411220550537,883.4170672893524,1431.4115691184998,0.0320186614990234,0.0 -2500,0.8142932,4.79581,,,,,,,,,,,,,,,,, -2600,0.72622967,4.650498,,,,,,,,,,,,,,,,, -2700,0.71196073,4.5882936,,,,,,,,,,,,,,,,, -2800,0.75972694,4.5735636,,,,,,,,,,,,,,,,, -2900,0.99238825,4.420064,,,,,,,,,,,,,,,,, -3000,0.6328083,4.4355416,,,,,,,,,,,,,,,,, -3100,0.7147423,4.3574796,,,,,,,,,,,,,,,,, -3200,0.7797041,4.306712,,,,,,,,,,,,,,,,, -3300,0.6220578,4.235191,,,,,,,,,,,,,,,,, -3400,0.6400887,4.212076,,,,,,,,,,,,,,,,, -3500,0.6570379,4.082402,,,,,,,,,,,,,,,,, -3600,0.85077256,4.1599646,,,,,,,,,,,,,,,,, -3700,0.65788966,4.049694,,,,,,,,,,,,,,,,, -3800,0.7003765,4.0311637,,,,,,,,,,,,,,,,, -3900,0.615092,3.9730077,,,,,,,,,,,,,,,,, -4000,0.58707434,4.0189333,,,,,,,,,,,,,,,,, -4100,0.61248827,3.8960257,,,,,,,,,,,,,,,,, -4200,0.6493089,3.9684746,,,,,,,,,,,,,,,,, -4300,0.6292191,3.8649256,,,,,,,,,,,,,,,,, -4400,0.56004757,3.8033233,,,,,,,,,,,,,,,,, -4500,0.52885294,3.7611926,,,,,,,,,,,,,,,,, -4600,0.6535723,3.8091874,,,,,,,,,,,,,,,,, -4700,0.7550335,3.7470505,,,,,,,,,,,,,,,,, -4800,0.6004969,3.8519838,,,,,,,,,,,,,,,,, -4858,,,0.5455783009529114,2.718634366989136,24.37134457556304,0.5439361929893494,2.697274923324585,20.40709098587339,3000.0,0.5444541573524475,2.736154556274414,19.09181470174716,3003.0,1723.4967126846311,3641.679195165634,1723.4967126846311,1917.9620339870453,0.062225341796875,0.0 -4900,0.7737281,3.7521155,,,,,,,,,,,,,,,,, -5000,0.55329084,3.6989856,,,,,,,,,,,,,,,,, -5100,0.5570775,3.6933494,,,,,,,,,,,,,,,,, -5200,0.50587684,3.7463343,,,,,,,,,,,,,,,,, -5300,0.51153564,3.6608565,,,,,,,,,,,,,,,,, -5400,0.55925393,3.6621423,,,,,,,,,,,,,,,,, -5500,0.5408886,3.6810148,,,,,,,,,,,,,,,,, -5600,0.4565034,3.6279325,,,,,,,,,,,,,,,,, -5700,0.47904062,3.6168637,,,,,,,,,,,,,,,,, -5800,0.4904022,3.5882428,,,,,,,,,,,,,,,,, -5900,0.46978527,3.5569727,,,,,,,,,,,,,,,,, -6000,0.4687871,3.60821,,,,,,,,,,,,,,,,, -6100,0.5816827,3.639455,,,,,,,,,,,,,,,,, -6200,0.6184545,3.641589,,,,,,,,,,,,,,,,, -6300,0.44208613,3.6108103,,,,,,,,,,,,,,,,, -6400,0.59079564,3.549563,,,,,,,,,,,,,,,,, -6500,0.5215443,3.524483,,,,,,,,,,,,,,,,, -6600,0.445443,3.4663038,,,,,,,,,,,,,,,,, -6700,0.48295745,3.4303844,,,,,,,,,,,,,,,,, -6800,0.4339562,3.496317,,,,,,,,,,,,,,,,, -6900,0.458095,3.4893773,,,,,,,,,,,,,,,,, -7000,0.5815599,3.598828,,,,,,,,,,,,,,,,, -7100,0.45408657,3.4312618,,,,,,,,,,,,,,,,, -7200,0.38237104,3.4049737,,,,,,,,,,,,,,,,, -7287,,,0.582105815410614,2.346024751663208,27.688640848896345,0.5865395069122314,2.3033597469329834,23.46770972919153,3000.0,0.5892278552055359,2.3010103702545166,22.25110580174832,3003.0,2563.5470848083496,4949.287957668304,2563.5470848083496,2385.4128901958466,0.0925908088684082,0.0 -7300,0.381182,3.3686392,,,,,,,,,,,,,,,,, -7400,0.43344927,3.4607358,,,,,,,,,,,,,,,,, -7500,0.4336014,3.4162145,,,,,,,,,,,,,,,,, -7600,0.43483078,3.4563737,,,,,,,,,,,,,,,,, -7700,0.37907562,3.4520528,,,,,,,,,,,,,,,,, -7800,0.37093613,3.3106256,,,,,,,,,,,,,,,,, -7900,0.41310206,3.4451034,,,,,,,,,,,,,,,,, -8000,0.38150573,3.433023,,,,,,,,,,,,,,,,, -8100,0.35237506,3.4446735,,,,,,,,,,,,,,,,, -8200,0.33952716,3.3808851,,,,,,,,,,,,,,,,, -8300,0.41522095,3.371549,,,,,,,,,,,,,,,,, -8400,0.37341973,3.4604623,,,,,,,,,,,,,,,,, -8500,0.37238416,3.4126904,,,,,,,,,,,,,,,,, -8600,0.33254337,3.3248508,,,,,,,,,,,,,,,,, -8700,0.32618785,3.3000166,,,,,,,,,,,,,,,,, -8800,0.4145163,3.4043927,,,,,,,,,,,,,,,,, -8900,0.3578754,3.33627,,,,,,,,,,,,,,,,, -9000,0.34279498,3.4395185,,,,,,,,,,,,,,,,, -9100,0.31083676,3.3370872,,,,,,,,,,,,,,,,, -9200,0.33976305,3.4482431,,,,,,,,,,,,,,,,, -9300,0.29680517,3.329344,,,,,,,,,,,,,,,,, -9400,0.36264497,3.395278,,,,,,,,,,,,,,,,, -9500,0.29302478,3.310645,,,,,,,,,,,,,,,,, -9600,0.35725766,3.3330336,,,,,,,,,,,,,,,,, -9700,0.304271,3.2735195,,,,,,,,,,,,,,,,, -9718,,,0.59144127368927,2.244596481323242,27.976897526423063,0.6069732308387756,2.119935989379883,24.509640670153523,3000.0,0.6132822036743164,2.0930325984954834,23.29025859698584,3003.0,3403.686936378479,6229.200881242752,3403.686936378479,2825.075345516205,0.1242415904998779,0.0 -9800,0.2867816,3.3017752,,,,,,,,,,,,,,,,, -9900,0.28720528,3.2299335,,,,,,,,,,,,,,,,, -10000,0.27980936,3.3170786,,,,,,,,,,,,,,,,, -10100,0.26189148,3.2556992,,,,,,,,,,,,,,,,, -10200,0.37970003,3.3166256,,,,,,,,,,,,,,,,, -10300,0.2770932,3.2282405,,,,,,,,,,,,,,,,, -10400,0.29271013,3.243567,,,,,,,,,,,,,,,,, -10500,0.27698407,3.314072,,,,,,,,,,,,,,,,, -10600,0.3195848,3.2099576,,,,,,,,,,,,,,,,, -10700,0.28591675,3.2660756,,,,,,,,,,,,,,,,, -10800,0.2557221,3.2943316,,,,,,,,,,,,,,,,, -10900,0.26945284,3.2941368,,,,,,,,,,,,,,,,, -11000,0.25582057,3.343699,,,,,,,,,,,,,,,,, -11100,0.26753014,3.2049398,,,,,,,,,,,,,,,,, -11200,0.27561042,3.2743483,,,,,,,,,,,,,,,,, -11300,0.26955467,3.2409325,,,,,,,,,,,,,,,,, -11400,0.28594542,3.194338,,,,,,,,,,,,,,,,, -11500,0.25171742,3.3079603,,,,,,,,,,,,,,,,, -11600,0.27873358,3.3072712,,,,,,,,,,,,,,,,, -11700,0.25464168,3.1806085,,,,,,,,,,,,,,,,, -11800,0.25084648,3.2797773,,,,,,,,,,,,,,,,, -11900,0.25307587,3.218046,,,,,,,,,,,,,,,,, -12000,0.26811692,3.2869346,,,,,,,,,,,,,,,,, -12100,0.27603683,3.1849656,,,,,,,,,,,,,,,,, -12150,,,0.6025044322013855,2.141352653503418,28.804964846373508,0.6188392043113708,2.0071840286254883,25.295877114636344,3000.0,0.6300621628761292,1.9602621793746948,24.743552012654888,3003.0,4243.90465593338,7700.897699356079,4243.90465593338,3456.4474980831146,0.1512916088104248,0.0 -12200,0.24495853,3.2372665,,,,,,,,,,,,,,,,, -12300,0.25761378,3.2696655,,,,,,,,,,,,,,,,, -12400,0.240599,3.1970985,,,,,,,,,,,,,,,,, -12500,0.2915406,3.2113252,,,,,,,,,,,,,,,,, -12600,0.25040796,3.1393912,,,,,,,,,,,,,,,,, -12700,0.25245506,3.1241448,,,,,,,,,,,,,,,,, -12800,0.22774994,3.2149374,,,,,,,,,,,,,,,,, -12900,0.26934025,3.1906786,,,,,,,,,,,,,,,,, -13000,0.2737624,3.1672273,,,,,,,,,,,,,,,,, -13100,0.2378514,3.2016518,,,,,,,,,,,,,,,,, -13200,0.23781396,3.201273,,,,,,,,,,,,,,,,, -13300,0.2540831,3.2014232,,,,,,,,,,,,,,,,, -13400,0.232701,3.2234154,,,,,,,,,,,,,,,,, -13500,0.23692839,3.141309,,,,,,,,,,,,,,,,, -13600,0.24773884,3.1700296,,,,,,,,,,,,,,,,, -13700,0.24207138,3.1519377,,,,,,,,,,,,,,,,, -13800,0.23671032,3.190748,,,,,,,,,,,,,,,,, -13900,0.26493958,3.2415335,,,,,,,,,,,,,,,,, -14000,0.23949876,3.0848422,,,,,,,,,,,,,,,,, -14100,0.23066385,3.1467433,,,,,,,,,,,,,,,,, -14200,0.2517174,3.1129923,,,,,,,,,,,,,,,,, -14300,0.21937731,3.1337054,,,,,,,,,,,,,,,,, -14400,0.23810203,3.1767132,,,,,,,,,,,,,,,,, -14500,0.30181333,3.1590848,,,,,,,,,,,,,,,,, -14582,,,0.6175777316093445,2.0234029293060303,29.49418838670804,0.6317466497421265,1.9158974885940552,26.446092461664627,3000.0,0.6397652626037598,1.8675507307052608,25.7206136569018,3003.0,5084.021989107132,8983.16919708252,5084.021989107132,3898.492534637451,0.1806375980377197,0.0 -14600,0.26428092,3.2252333,,,,,,,,,,,,,,,,, -14700,0.25053525,3.1262898,,,,,,,,,,,,,,,,, -14800,0.29223573,3.163513,,,,,,,,,,,,,,,,, -14900,0.2889306,3.103704,,,,,,,,,,,,,,,,, -15000,0.32726353,3.0552804,,,,,,,,,,,,,,,,, -15100,0.24491455,3.1772325,,,,,,,,,,,,,,,,, -15200,0.24765864,3.1544204,,,,,,,,,,,,,,,,, -15300,0.2541877,3.1259768,,,,,,,,,,,,,,,,, -15400,0.23295389,3.0402963,,,,,,,,,,,,,,,,, -15500,0.2539882,3.1559846,,,,,,,,,,,,,,,,, -15600,0.24787907,3.1203933,,,,,,,,,,,,,,,,, -15700,0.24901712,3.1051152,,,,,,,,,,,,,,,,, -15800,0.24576916,3.0433187,,,,,,,,,,,,,,,,, -15900,0.26279223,3.136595,,,,,,,,,,,,,,,,, -16000,0.2640742,3.0995135,,,,,,,,,,,,,,,,, -16100,0.26434073,3.185093,,,,,,,,,,,,,,,,, -16200,0.26794356,3.0754955,,,,,,,,,,,,,,,,, -16300,0.25899974,3.0615304,,,,,,,,,,,,,,,,, -16400,0.37872523,3.0167298,,,,,,,,,,,,,,,,, -16500,0.29292658,3.081177,,,,,,,,,,,,,,,,, -16600,0.27489418,3.0839708,,,,,,,,,,,,,,,,, -16700,0.33617252,3.0707138,,,,,,,,,,,,,,,,, -16800,0.2975119,3.1092947,,,,,,,,,,,,,,,,, -16900,0.30002388,3.0880306,,,,,,,,,,,,,,,,, -17000,0.2737495,3.0873108,,,,,,,,,,,,,,,,, -17013,,,0.6190584897994995,2.003643274307251,30.250872839385124,0.6397812962532043,1.84742283821106,26.92715774325129,3000.0,0.6494567394256592,1.7908827066421509,26.089327875831792,3003.0,5923.943639755249,10306.980433225632,5923.943639755249,4382.143723726273,0.3407480716705322,0.0 -17100,0.29774535,3.1142774,,,,,,,,,,,,,,,,, -17200,0.25445375,3.0458388,,,,,,,,,,,,,,,,, -17300,0.2755354,3.1110659,,,,,,,,,,,,,,,,, -17400,0.30045414,3.0590887,,,,,,,,,,,,,,,,, -17500,0.31948295,3.0865448,,,,,,,,,,,,,,,,, -17600,0.29226175,3.122108,,,,,,,,,,,,,,,,, -17700,0.3343601,3.0672398,,,,,,,,,,,,,,,,, -17800,0.28505722,3.058982,,,,,,,,,,,,,,,,, -17900,0.30494663,3.0910053,,,,,,,,,,,,,,,,, -18000,0.33457476,3.0640674,,,,,,,,,,,,,,,,, -18100,0.361145,3.113535,,,,,,,,,,,,,,,,, -18200,0.36137843,3.0756156,,,,,,,,,,,,,,,,, -18300,0.33990178,3.1200533,,,,,,,,,,,,,,,,, -18400,0.28482687,3.0920928,,,,,,,,,,,,,,,,, -18500,0.2856334,3.097022,,,,,,,,,,,,,,,,, -18600,0.32341227,3.0451725,,,,,,,,,,,,,,,,, -18700,0.28054044,3.0416214,,,,,,,,,,,,,,,,, -18800,0.37298298,3.1304684,,,,,,,,,,,,,,,,, -18900,0.296957,2.964656,,,,,,,,,,,,,,,,, -19000,0.41019216,3.0267622,,,,,,,,,,,,,,,,, -19100,0.29319093,3.0725656,,,,,,,,,,,,,,,,, -19200,0.3215903,3.0467162,,,,,,,,,,,,,,,,, -19300,0.32969984,3.030423,,,,,,,,,,,,,,,,, -19400,0.3232109,3.0681005,,,,,,,,,,,,,,,,, -19445,,,0.6344646215438843,1.891523718833924,30.712155809537048,0.6446169018745422,1.809007167816162,27.5092599023563,3000.0,0.6539306640625,1.760025978088379,26.60863093886362,3003.0,6764.000878095627,11614.296314954758,6764.000878095627,4849.296671628952,0.3672385215759277,0.0 -19500,0.32656047,3.0507536,,,,,,,,,,,,,,,,, -19600,0.3697987,3.0246816,,,,,,,,,,,,,,,,, -19700,0.3114059,3.070331,,,,,,,,,,,,,,,,, -19800,0.35591024,3.1442196,,,,,,,,,,,,,,,,, -19900,0.30727974,3.058143,,,,,,,,,,,,,,,,, -20000,0.3165507,3.060579,,,,,,,,,,,,,,,,, -20100,0.398243,3.0136669,,,,,,,,,,,,,,,,, -20200,0.34397742,3.0157201,,,,,,,,,,,,,,,,, -20300,0.31134334,3.0554307,,,,,,,,,,,,,,,,, -20400,0.32500315,3.0962658,,,,,,,,,,,,,,,,, -20500,0.43170053,3.0540445,,,,,,,,,,,,,,,,, -20600,0.32952988,3.0799587,,,,,,,,,,,,,,,,, -20700,0.38006735,3.0447993,,,,,,,,,,,,,,,,, -20800,0.3558271,3.0131042,,,,,,,,,,,,,,,,, -20900,0.30311826,2.9737537,,,,,,,,,,,,,,,,, -21000,0.32175153,3.0009518,,,,,,,,,,,,,,,,, -21100,0.4225156,2.9900198,,,,,,,,,,,,,,,,, -21200,0.39885956,3.0283015,,,,,,,,,,,,,,,,, -21300,0.4050013,3.0314658,,,,,,,,,,,,,,,,, -21400,0.31350264,3.040162,,,,,,,,,,,,,,,,, -21500,0.31632978,3.0510182,,,,,,,,,,,,,,,,, -21600,0.33028117,3.0150602,,,,,,,,,,,,,,,,, -21700,0.34845224,2.99552,,,,,,,,,,,,,,,,, -21800,0.40093148,3.0232608,,,,,,,,,,,,,,,,, -21878,,,0.6279143691062927,1.938279628753662,30.117385200900088,0.6497625708580017,1.7786868810653689,27.39135952713508,3000.0,0.6575562357902527,1.7227576971054075,26.950158208181367,3003.0,7604.201534986496,12954.930808782578,7604.201534986496,5349.620749235153,0.3979494571685791,0.0 -21900,0.39610526,3.0989454,,,,,,,,,,,,,,,,, -22000,0.36828062,3.0726395,,,,,,,,,,,,,,,,, -22100,0.8854342,3.058156,,,,,,,,,,,,,,,,, -22200,0.41969052,3.1172523,,,,,,,,,,,,,,,,, -22300,0.33024457,2.9458296,,,,,,,,,,,,,,,,, -22400,0.3856897,3.0801234,,,,,,,,,,,,,,,,, -22500,0.3665072,3.0582132,,,,,,,,,,,,,,,,, -22600,0.3234683,3.0019991,,,,,,,,,,,,,,,,, -22700,0.30085382,3.057558,,,,,,,,,,,,,,,,, -22800,0.32499388,3.007506,,,,,,,,,,,,,,,,, -22900,0.35044453,3.0289621,,,,,,,,,,,,,,,,, -23000,0.36287895,3.1065311,,,,,,,,,,,,,,,,, -23100,0.33753115,3.04642,,,,,,,,,,,,,,,,, -23200,0.30571377,2.996339,,,,,,,,,,,,,,,,, -23300,0.3631164,3.022717,,,,,,,,,,,,,,,,, -23400,0.31436402,3.0215037,,,,,,,,,,,,,,,,, -23500,0.39132467,2.99788,,,,,,,,,,,,,,,,, -23600,0.32979834,3.011233,,,,,,,,,,,,,,,,, -23700,0.35775816,2.9744515,,,,,,,,,,,,,,,,, -23800,0.31863615,3.0335553,,,,,,,,,,,,,,,,, -23900,0.3574672,3.1053386,,,,,,,,,,,,,,,,, -24000,0.37463886,2.9805677,,,,,,,,,,,,,,,,, -24100,0.32095414,3.0500882,,,,,,,,,,,,,,,,, -24200,0.38029793,3.0500546,,,,,,,,,,,,,,,,, -24300,0.40141168,3.0344768,,,,,,,,,,,,,,,,, -24311,,,0.6267665028572083,1.955460786819458,30.495337151691416,0.6528747081756592,1.7616363763809204,27.740078969040887,3000.0,0.6625530123710632,1.7036962509155271,27.38874092003688,3003.0,8444.363789081573,14485.660428762436,8444.363789081573,6040.07980966568,0.4274477958679199,0.0 -24400,0.33942077,3.0289016,,,,,,,,,,,,,,,,, -24500,0.3333199,3.0191476,,,,,,,,,,,,,,,,, -24600,0.33936137,3.035697,,,,,,,,,,,,,,,,, -24700,0.36732653,3.012133,,,,,,,,,,,,,,,,, -24800,0.37625507,3.0302043,,,,,,,,,,,,,,,,, -24900,0.35036355,2.9867587,,,,,,,,,,,,,,,,, -25000,0.39649564,2.9732754,,,,,,,,,,,,,,,,, -25100,0.3671107,3.0009775,,,,,,,,,,,,,,,,, -25200,0.38800827,3.0150056,,,,,,,,,,,,,,,,, -25300,0.40503058,3.1098762,,,,,,,,,,,,,,,,, -25400,0.47513658,2.989754,,,,,,,,,,,,,,,,, -25500,0.33287588,2.920869,,,,,,,,,,,,,,,,, -25600,0.36461166,2.9470637,,,,,,,,,,,,,,,,, -25700,0.35469472,3.0341427,,,,,,,,,,,,,,,,, -25800,0.3584485,3.1335387,,,,,,,,,,,,,,,,, -25900,0.34480685,2.9378626,,,,,,,,,,,,,,,,, -26000,0.35523462,2.943679,,,,,,,,,,,,,,,,, -26100,0.32188165,2.96185,,,,,,,,,,,,,,,,, -26200,0.44603977,3.0472744,,,,,,,,,,,,,,,,, -26300,0.31515858,2.9860234,,,,,,,,,,,,,,,,, -26400,0.40134314,3.0397415,,,,,,,,,,,,,,,,, -26500,0.35749596,3.0046425,,,,,,,,,,,,,,,,, -26600,0.42174262,3.0625885,,,,,,,,,,,,,,,,, -26700,0.3798892,3.0080554,,,,,,,,,,,,,,,,, -26744,,,0.6343207955360413,1.8967632055282595,31.050486129125893,0.6534822583198547,1.7492674589157104,28.09813306279128,3000.0,0.6628900170326233,1.6931463479995728,27.04296441504016,3003.0,9284.414820194244,15802.381581544876,9284.414820194244,6516.639975786209,0.4582266807556152,0.0 -26800,0.45647728,3.0106318,,,,,,,,,,,,,,,,, -26900,0.32476458,3.0596824,,,,,,,,,,,,,,,,, -27000,0.34300825,2.9427507,,,,,,,,,,,,,,,,, -27100,0.36830246,2.9626114,,,,,,,,,,,,,,,,, -27200,0.36932576,2.9610653,,,,,,,,,,,,,,,,, -27300,0.38395086,3.0258017,,,,,,,,,,,,,,,,, -27400,0.39763388,2.9192293,,,,,,,,,,,,,,,,, -27500,0.5328803,3.0313635,,,,,,,,,,,,,,,,, -27600,0.34346658,2.9978762,,,,,,,,,,,,,,,,, -27700,0.32544303,2.9895945,,,,,,,,,,,,,,,,, -27800,0.37739277,3.023379,,,,,,,,,,,,,,,,, -27900,0.339514,2.9872007,,,,,,,,,,,,,,,,, -28000,0.41107145,3.0083625,,,,,,,,,,,,,,,,, -28100,0.35774848,2.9896002,,,,,,,,,,,,,,,,, -28200,0.3774197,2.9920032,,,,,,,,,,,,,,,,, -28300,0.364314,2.959883,,,,,,,,,,,,,,,,, -28400,0.33738422,3.0354507,,,,,,,,,,,,,,,,, -28500,0.35827976,3.0259392,,,,,,,,,,,,,,,,, -28600,0.33496332,3.0249448,,,,,,,,,,,,,,,,, -28700,0.35992408,3.027673,,,,,,,,,,,,,,,,, -28800,0.44185233,3.0144763,,,,,,,,,,,,,,,,, -28900,0.34129614,3.0069313,,,,,,,,,,,,,,,,, -29000,0.3399121,3.1225297,,,,,,,,,,,,,,,,, -29100,0.39161795,2.972681,,,,,,,,,,,,,,,,, -29177,,,0.6336384415626526,1.8958899974823,30.684595268987056,0.6565200686454773,1.7404468059539795,28.03879571932596,3000.0,0.665864884853363,1.6830744743347168,27.865501179623493,3003.0,10124.47589802742,17122.42004466057,10124.47589802742,6996.507694721222,0.4879882335662842,0.0 -29200,0.4271574,2.9674373,,,,,,,,,,,,,,,,, -29300,0.35839918,3.006251,,,,,,,,,,,,,,,,, -29400,0.71460843,3.232433,,,,,,,,,,,,,,,,, -29500,0.45927396,3.0175612,,,,,,,,,,,,,,,,, -29600,0.41402018,2.9761364,,,,,,,,,,,,,,,,, -29700,0.32809812,3.0392702,,,,,,,,,,,,,,,,, -29800,0.33681202,2.9831004,,,,,,,,,,,,,,,,, -29900,0.32086593,2.9342253,,,,,,,,,,,,,,,,, -30000,0.3174326,2.9493484,,,,,,,,,,,,,,,,, -30100,0.35360137,2.9178565,,,,,,,,,,,,,,,,, -30200,0.34185663,3.0335155,,,,,,,,,,,,,,,,, -30300,0.30062184,2.9516091,,,,,,,,,,,,,,,,, -30400,0.33305645,2.9403865,,,,,,,,,,,,,,,,, -30500,0.34386352,3.012525,,,,,,,,,,,,,,,,, -30600,0.31752637,2.9709795,,,,,,,,,,,,,,,,, -30700,0.3775364,2.9514604,,,,,,,,,,,,,,,,, -30800,0.3375172,2.9924335,,,,,,,,,,,,,,,,, -30900,0.3226729,2.8747203,,,,,,,,,,,,,,,,, -31000,0.30701485,2.9591134,,,,,,,,,,,,,,,,, -31100,0.31422108,2.9390583,,,,,,,,,,,,,,,,, -31200,0.30469543,2.9419942,,,,,,,,,,,,,,,,, -31300,0.44573808,2.979024,,,,,,,,,,,,,,,,, -31400,0.3763352,2.9091384,,,,,,,,,,,,,,,,, -31500,0.38669473,2.9649003,,,,,,,,,,,,,,,,, -31600,0.33769888,3.028815,,,,,,,,,,,,,,,,, -31610,,,0.6526271104812622,1.768532156944275,32.05716202678791,0.654350221157074,1.736448884010315,28.03466477961463,3000.0,0.6660043001174927,1.6788660287857056,27.773628041650408,3003.0,10964.547756910324,18423.88781070709,10964.547756910324,7457.794209480286,0.5176031589508057,0.0 -31700,0.35281122,2.9881835,,,,,,,,,,,,,,,,, -31800,0.3992646,3.0173113,,,,,,,,,,,,,,,,, -31900,0.3867289,3.0122058,,,,,,,,,,,,,,,,, -32000,0.31878033,2.9758348,,,,,,,,,,,,,,,,, -32100,0.35571402,3.0441682,,,,,,,,,,,,,,,,, -32200,0.330525,2.9540186,,,,,,,,,,,,,,,,, -32300,0.3378366,2.9318917,,,,,,,,,,,,,,,,, -32400,0.36279827,2.9637892,,,,,,,,,,,,,,,,, -32500,0.3935372,2.9449177,,,,,,,,,,,,,,,,, -32600,0.44860393,2.9012344,,,,,,,,,,,,,,,,, -32700,0.3429123,3.0268779,,,,,,,,,,,,,,,,, -32800,0.37707487,3.023736,,,,,,,,,,,,,,,,, -32900,0.3452883,3.0195966,,,,,,,,,,,,,,,,, -33000,0.3400487,2.982784,,,,,,,,,,,,,,,,, -33100,0.34679523,2.99175,,,,,,,,,,,,,,,,, -33200,0.334901,3.0016184,,,,,,,,,,,,,,,,, -33300,0.43758908,2.9783392,,,,,,,,,,,,,,,,, -33400,0.3203152,3.0426176,,,,,,,,,,,,,,,,, -33500,0.3754501,3.0030408,,,,,,,,,,,,,,,,, -33600,0.3498437,2.9948745,,,,,,,,,,,,,,,,, -33700,0.34215155,2.983966,,,,,,,,,,,,,,,,, -33800,0.3609381,2.9985247,,,,,,,,,,,,,,,,, -33900,0.3282053,2.9078715,,,,,,,,,,,,,,,,, -34000,0.33786067,3.000822,,,,,,,,,,,,,,,,, -34044,,,0.6363418102264404,1.877790093421936,31.210927634435524,0.65668123960495,1.7198882102966309,28.09483148598837,3000.0,0.6665620803833008,1.669747233390808,27.55349728228144,3003.0,11804.70926952362,19753.06885743141,11804.70926952362,7946.703389883041,0.5480039119720459,0.0 -34100,0.36705947,3.0096536,,,,,,,,,,,,,,,,, -34200,0.34124437,3.0199735,,,,,,,,,,,,,,,,, -34300,0.34266135,2.9288666,,,,,,,,,,,,,,,,, -34400,0.3698144,2.9160542,,,,,,,,,,,,,,,,, -34500,0.32200998,2.9668562,,,,,,,,,,,,,,,,, -34600,0.37358782,2.949086,,,,,,,,,,,,,,,,, -34700,0.30127698,2.9819608,,,,,,,,,,,,,,,,, -34800,0.3840873,2.945708,,,,,,,,,,,,,,,,, -34900,0.38548625,2.9963794,,,,,,,,,,,,,,,,, -35000,0.36267695,2.929625,,,,,,,,,,,,,,,,, -35100,0.37840095,3.0702522,,,,,,,,,,,,,,,,, -35200,0.3290111,2.946527,,,,,,,,,,,,,,,,, -35300,0.38351095,3.0463264,,,,,,,,,,,,,,,,, -35400,0.3767718,2.9821508,,,,,,,,,,,,,,,,, -35500,0.355536,2.974838,,,,,,,,,,,,,,,,, -35600,0.33217737,2.9281335,,,,,,,,,,,,,,,,, -35700,0.35956258,2.9777288,,,,,,,,,,,,,,,,, -35800,0.3486337,2.9943473,,,,,,,,,,,,,,,,, -35900,0.31717274,2.9518712,,,,,,,,,,,,,,,,, -36000,0.37901568,3.017398,,,,,,,,,,,,,,,,, -36100,0.44720137,3.0244427,,,,,,,,,,,,,,,,, -36200,0.3290731,2.9656065,,,,,,,,,,,,,,,,, -36300,0.41021097,2.956577,,,,,,,,,,,,,,,,, -36400,0.33285776,3.036662,,,,,,,,,,,,,,,,, -36478,,,0.6364160180091858,1.8722435235977173,31.076741563059336,0.6580327749252319,1.717186689376831,28.03697177435267,3000.0,0.6687583923339844,1.6572718620300293,27.75639606929041,3003.0,12644.77906513214,21185.581884860992,12644.77906513214,8539.038867473602,0.5769636631011963,0.0 -36500,0.33201024,2.909361,,,,,,,,,,,,,,,,, -36600,0.34328505,3.0108087,,,,,,,,,,,,,,,,, -36700,0.32457203,2.9261355,,,,,,,,,,,,,,,,, -36800,0.34820792,3.0048318,,,,,,,,,,,,,,,,, -36900,0.3688931,3.0000603,,,,,,,,,,,,,,,,, -37000,0.4580484,2.9462745,,,,,,,,,,,,,,,,, -37100,0.39180234,2.935589,,,,,,,,,,,,,,,,, -37200,0.40332103,2.9977243,,,,,,,,,,,,,,,,, -37300,0.40757617,2.9498453,,,,,,,,,,,,,,,,, -37400,0.43425876,2.9695396,,,,,,,,,,,,,,,,, -37500,0.34108043,2.9843976,,,,,,,,,,,,,,,,, -37600,0.86450255,2.9863508,,,,,,,,,,,,,,,,, -37700,0.4232283,3.012269,,,,,,,,,,,,,,,,, -37800,0.4191566,2.9626105,,,,,,,,,,,,,,,,, -37900,0.39540255,3.033767,,,,,,,,,,,,,,,,, -38000,0.34961748,2.9404066,,,,,,,,,,,,,,,,, -38100,0.34503073,2.9200978,,,,,,,,,,,,,,,,, -38200,0.34239703,2.959038,,,,,,,,,,,,,,,,, -38300,0.35597885,2.9492252,,,,,,,,,,,,,,,,, -38400,0.42146033,2.913509,,,,,,,,,,,,,,,,, -38500,0.37701777,2.9377267,,,,,,,,,,,,,,,,, -38600,0.35306144,3.0043397,,,,,,,,,,,,,,,,, -38700,0.3141523,2.9061213,,,,,,,,,,,,,,,,, -38800,0.37629318,2.9278193,,,,,,,,,,,,,,,,, -38900,0.3248708,2.9990506,,,,,,,,,,,,,,,,, -38912,,,0.6482433080673218,1.7970041036605835,31.43274199855436,0.6609093546867371,1.703112006187439,28.180620886314376,3000.0,0.6716634631156921,1.6385266780853271,27.945611634666346,3003.0,13484.840069770811,22510.99568128585,13484.840069770811,9024.285813570024,0.6064877510070801,0.0 -39000,0.31525275,2.886722,,,,,,,,,,,,,,,,, -39100,0.34738055,2.954534,,,,,,,,,,,,,,,,, -39200,0.34513414,2.982342,,,,,,,,,,,,,,,,, -39300,0.3403583,3.0563111,,,,,,,,,,,,,,,,, -39400,0.35297495,2.954301,,,,,,,,,,,,,,,,, -39500,0.31305444,2.923185,,,,,,,,,,,,,,,,, -39600,0.32689545,2.9467015,,,,,,,,,,,,,,,,, -39700,0.3253616,2.948823,,,,,,,,,,,,,,,,, -39800,0.42620587,2.9387772,,,,,,,,,,,,,,,,, -39900,0.4232228,2.9018152,,,,,,,,,,,,,,,,, -40000,0.35258073,2.9261854,,,,,,,,,,,,,,,,, -40100,0.35143313,2.904769,,,,,,,,,,,,,,,,, -40200,0.32869452,2.9393973,,,,,,,,,,,,,,,,, -40300,0.39151186,3.0058563,,,,,,,,,,,,,,,,, -40400,0.38817218,3.0124402,,,,,,,,,,,,,,,,, -40500,0.34772998,2.9492276,,,,,,,,,,,,,,,,, -40600,0.33135337,2.9004943,,,,,,,,,,,,,,,,, -40700,0.3511369,2.9599924,,,,,,,,,,,,,,,,, -40800,0.35731354,2.9924595,,,,,,,,,,,,,,,,, -40900,0.3372435,3.0127127,,,,,,,,,,,,,,,,, -41000,0.32287693,2.9007237,,,,,,,,,,,,,,,,, -41100,0.3637548,2.997751,,,,,,,,,,,,,,,,, -41200,0.33375248,2.8859253,,,,,,,,,,,,,,,,, -41300,0.34065232,2.9600656,,,,,,,,,,,,,,,,, -41346,,,0.6373471021652222,1.8696467876434328,31.01315039572357,0.6601405739784241,1.7056689262390137,28.38573685693433,3000.0,0.6728255152702332,1.63887357711792,28.00362928869641,3003.0,14324.897027015686,23825.52012705803,14324.897027015686,9498.646374464037,0.6380147933959961,0.0 -41400,0.3617677,2.9597614,,,,,,,,,,,,,,,,, -41500,0.32328233,2.835898,,,,,,,,,,,,,,,,, -41600,0.40405962,3.0005002,,,,,,,,,,,,,,,,, -41700,0.3614713,2.8914847,,,,,,,,,,,,,,,,, -41800,0.34367788,3.0052261,,,,,,,,,,,,,,,,, -41900,0.32742214,3.0411434,,,,,,,,,,,,,,,,, -42000,0.33277225,2.9406164,,,,,,,,,,,,,,,,, -42100,0.3548932,3.0483577,,,,,,,,,,,,,,,,, -42200,0.35373542,2.966661,,,,,,,,,,,,,,,,, -42300,0.35837662,3.0182037,,,,,,,,,,,,,,,,, -42400,0.36650053,2.9174583,,,,,,,,,,,,,,,,, -42500,0.37575006,2.9403467,,,,,,,,,,,,,,,,, -42600,0.34681413,2.9773433,,,,,,,,,,,,,,,,, -42700,0.3400503,2.9121776,,,,,,,,,,,,,,,,, -42800,0.38759813,2.9692051,,,,,,,,,,,,,,,,, -42900,0.33177558,2.9444206,,,,,,,,,,,,,,,,, -43000,0.33986345,2.9741428,,,,,,,,,,,,,,,,, -43100,0.33065027,2.9502022,,,,,,,,,,,,,,,,, -43200,0.34091976,2.9050424,,,,,,,,,,,,,,,,, -43300,0.31889388,2.9341328,,,,,,,,,,,,,,,,, -43400,0.3548196,2.9360886,,,,,,,,,,,,,,,,, -43500,0.325884,2.9128528,,,,,,,,,,,,,,,,, -43600,0.5393248,2.941387,,,,,,,,,,,,,,,,, -43700,0.40393052,2.9293106,,,,,,,,,,,,,,,,, -43780,,,0.6370298266410828,1.8723466396331787,31.06694249891112,0.6600537896156311,1.697940468788147,28.26108878119694,3000.0,0.6713381409645081,1.6407963037490845,27.69613416655745,3003.0,15164.826464653015,25156.252128362656,15164.826464653015,9989.341492176056,0.668973445892334,0.0 -43800,0.33816263,2.9959521,,,,,,,,,,,,,,,,, -43900,0.3741976,2.8966866,,,,,,,,,,,,,,,,, -44000,0.3392364,2.928427,,,,,,,,,,,,,,,,, -44100,0.36470065,2.9993513,,,,,,,,,,,,,,,,, -44200,0.35893843,2.8633108,,,,,,,,,,,,,,,,, -44300,0.3992713,2.986108,,,,,,,,,,,,,,,,, -44400,0.33541903,2.9242952,,,,,,,,,,,,,,,,, -44500,0.3283644,2.9610322,,,,,,,,,,,,,,,,, -44600,0.35069108,2.948788,,,,,,,,,,,,,,,,, -44700,0.33229586,2.9778671,,,,,,,,,,,,,,,,, -44800,0.3321212,2.916878,,,,,,,,,,,,,,,,, -44900,0.37897635,2.8898082,,,,,,,,,,,,,,,,, -45000,0.35087755,2.911936,,,,,,,,,,,,,,,,, -45100,0.36930498,2.970912,,,,,,,,,,,,,,,,, -45200,0.3478599,2.9352343,,,,,,,,,,,,,,,,, -45300,0.34505564,2.9161885,,,,,,,,,,,,,,,,, -45400,0.3211208,2.9689178,,,,,,,,,,,,,,,,, -45500,0.37320656,2.933133,,,,,,,,,,,,,,,,, -45600,0.36348927,2.932602,,,,,,,,,,,,,,,,, -45700,0.36803946,2.9770603,,,,,,,,,,,,,,,,, -45800,0.3449733,2.9413943,,,,,,,,,,,,,,,,, -45900,0.3507509,2.9667737,,,,,,,,,,,,,,,,, -46000,0.37541392,2.940384,,,,,,,,,,,,,,,,, -46100,0.38938564,2.9949474,,,,,,,,,,,,,,,,, -46200,0.40151566,2.9472573,,,,,,,,,,,,,,,,, -46214,,,0.6392388343811035,1.8472505807876587,31.98645200096553,0.6632279753684998,1.6884485483169556,28.6184400125844,3000.0,0.6738249063491821,1.6290220022201538,28.05211040280357,3003.0,16004.92531490326,26476.57047533989,16004.92531490326,10469.45325231552,0.7002456188201904,0.0 -46300,0.40916964,2.8865724,,,,,,,,,,,,,,,,, -46400,0.34052938,2.8676064,,,,,,,,,,,,,,,,, -46500,0.40466556,3.059718,,,,,,,,,,,,,,,,, -46600,0.36657014,2.9514737,,,,,,,,,,,,,,,,, -46700,0.44683924,2.957353,,,,,,,,,,,,,,,,, -46800,0.35975847,2.9268804,,,,,,,,,,,,,,,,, -46900,0.34041053,2.8924081,,,,,,,,,,,,,,,,, -47000,0.37123471,2.860225,,,,,,,,,,,,,,,,, -47100,0.33818606,2.9433863,,,,,,,,,,,,,,,,, -47200,0.33882105,2.9772174,,,,,,,,,,,,,,,,, -47300,0.31807175,2.8802454,,,,,,,,,,,,,,,,, -47400,0.36414194,2.981306,,,,,,,,,,,,,,,,, -47500,0.39268094,2.9310062,,,,,,,,,,,,,,,,, -47600,0.33157483,2.8699787,,,,,,,,,,,,,,,,, -47700,0.38190627,2.912533,,,,,,,,,,,,,,,,, -47800,0.34076318,2.9430602,,,,,,,,,,,,,,,,, -47900,0.39146677,2.9442377,,,,,,,,,,,,,,,,, -48000,0.4387413,2.9267616,,,,,,,,,,,,,,,,, -48100,0.36757925,2.952939,,,,,,,,,,,,,,,,, -48200,0.3156392,2.986202,,,,,,,,,,,,,,,,, -48300,0.35513857,2.8813663,,,,,,,,,,,,,,,,, -48400,0.34852022,2.8864746,,,,,,,,,,,,,,,,, -48500,0.35581422,2.9456935,,,,,,,,,,,,,,,,, -48600,0.34555367,2.971519,,,,,,,,,,,,,,,,, -48648,,,0.6418137550354004,1.8364354372024536,31.435799419526703,0.662161648273468,1.6864773035049438,28.773273751970383,3000.0,0.6732555031776428,1.6297379732131958,28.00563946047731,3003.0,16845.00081062317,27799.22328066826,16845.00081062317,10951.921907424929,0.731576681137085,0.0 -48700,0.31958526,2.9374917,,,,,,,,,,,,,,,,, -48800,0.34893292,2.8919065,,,,,,,,,,,,,,,,, -48900,0.3314189,2.9234657,,,,,,,,,,,,,,,,, -49000,0.3382172,2.867888,,,,,,,,,,,,,,,,, -49100,0.31970614,2.962653,,,,,,,,,,,,,,,,, -49200,0.32203215,2.8424876,,,,,,,,,,,,,,,,, -49300,0.38682947,2.945643,,,,,,,,,,,,,,,,, -49400,0.32486403,2.9061549,,,,,,,,,,,,,,,,, -49500,0.3268834,2.9581425,,,,,,,,,,,,,,,,, -49600,0.37421557,2.921217,,,,,,,,,,,,,,,,, -49700,0.3543721,2.9358428,,,,,,,,,,,,,,,,, -49800,0.34613958,2.9374516,,,,,,,,,,,,,,,,, -49900,0.35101232,2.988049,,,,,,,,,,,,,,,,, -50000,0.35617056,2.9113991,,,,,,,,,,,,,,,,, -50100,0.3539573,2.9367723,,,,,,,,,,,,,,,,, -50200,0.3221352,2.8779404,,,,,,,,,,,,,,,,, -50300,0.39453727,2.8875835,,,,,,,,,,,,,,,,, -50400,0.32991463,2.9546366,,,,,,,,,,,,,,,,, -50500,0.38436186,2.898927,,,,,,,,,,,,,,,,, -50600,0.34867913,2.8576155,,,,,,,,,,,,,,,,, -50700,0.39129525,2.9656403,,,,,,,,,,,,,,,,, -50800,0.40725523,2.8996303,,,,,,,,,,,,,,,,, -50900,0.34582716,2.9899108,,,,,,,,,,,,,,,,, -51000,0.36208287,2.8931599,,,,,,,,,,,,,,,,, -51082,,,0.6506011486053467,1.7793463468551636,32.20130396977178,0.6655961871147156,1.6724234819412231,28.732523018514083,3000.0,0.6755679845809937,1.6110336780548096,28.39498522119112,3003.0,17684.95172381401,29093.11604142189,17684.95172381401,11405.755249738691,0.7633512020111084,0.0 -51100,0.35711342,2.9425485,,,,,,,,,,,,,,,,, -51200,0.3756432,2.9818177,,,,,,,,,,,,,,,,, -51300,0.3867473,2.9260082,,,,,,,,,,,,,,,,, -51400,0.33802217,2.9039767,,,,,,,,,,,,,,,,, -51500,0.3718742,2.8942063,,,,,,,,,,,,,,,,, -51600,0.33173612,2.9257598,,,,,,,,,,,,,,,,, -51700,0.39009807,2.8890429,,,,,,,,,,,,,,,,, -51800,0.36379573,3.0039215,,,,,,,,,,,,,,,,, -51900,0.3653172,2.8731067,,,,,,,,,,,,,,,,, -52000,0.33422074,2.8422546,,,,,,,,,,,,,,,,, -52100,0.34991208,2.9331138,,,,,,,,,,,,,,,,, -52200,0.3450978,2.9402392,,,,,,,,,,,,,,,,, -52300,0.3196083,2.892317,,,,,,,,,,,,,,,,, -52400,0.3523471,2.9375796,,,,,,,,,,,,,,,,, -52500,0.32823357,2.8976421,,,,,,,,,,,,,,,,, -52600,0.3645753,2.8806114,,,,,,,,,,,,,,,,, -52700,0.34815103,2.90877,,,,,,,,,,,,,,,,, -52800,0.34510404,2.961332,,,,,,,,,,,,,,,,, -52900,0.37158778,2.9603302,,,,,,,,,,,,,,,,, -53000,0.3797868,2.9277034,,,,,,,,,,,,,,,,, -53100,0.34631503,2.9107435,,,,,,,,,,,,,,,,, -53200,0.35378563,2.8955982,,,,,,,,,,,,,,,,, -53300,0.3543929,2.9050887,,,,,,,,,,,,,,,,, -53400,0.33146426,2.8969765,,,,,,,,,,,,,,,,, -53500,0.37312496,2.937176,,,,,,,,,,,,,,,,, -53516,,,0.6455166935920715,1.8138216733932493,31.56522640495758,0.6665757298469543,1.6703606843948364,28.52078753000625,3000.0,0.6770902276039124,1.6028481721878052,28.199818862678665,3003.0,18525.0443277359,30404.6956615448,18525.0443277359,11877.134405851364,0.7947971820831299,0.0 -53600,0.35965234,2.9664507,,,,,,,,,,,,,,,,, -53700,0.38101467,2.9774654,,,,,,,,,,,,,,,,, -53800,0.3368752,2.9108968,,,,,,,,,,,,,,,,, -53900,0.36280504,2.8742535,,,,,,,,,,,,,,,,, -54000,0.36239302,2.9121215,,,,,,,,,,,,,,,,, -54100,0.3482951,2.842432,,,,,,,,,,,,,,,,, -54200,0.38101658,2.9092095,,,,,,,,,,,,,,,,, -54300,0.35152873,2.89571,,,,,,,,,,,,,,,,, -54400,0.36374325,2.881546,,,,,,,,,,,,,,,,, -54500,0.38936532,2.875037,,,,,,,,,,,,,,,,, -54600,0.35521558,2.9449275,,,,,,,,,,,,,,,,, -54700,0.33051792,2.909532,,,,,,,,,,,,,,,,, -54800,0.3984378,2.9068103,,,,,,,,,,,,,,,,, -54900,0.3562217,2.9599116,,,,,,,,,,,,,,,,, -55000,0.32834488,2.8938007,,,,,,,,,,,,,,,,, -55100,0.35749194,2.9748163,,,,,,,,,,,,,,,,, -55200,0.38757968,2.917029,,,,,,,,,,,,,,,,, -55300,0.34656554,2.9142704,,,,,,,,,,,,,,,,, -55400,0.32724786,2.8810618,,,,,,,,,,,,,,,,, -55500,0.42205486,2.890149,,,,,,,,,,,,,,,,, -55600,0.35564923,2.923806,,,,,,,,,,,,,,,,, -55700,0.34283295,2.8932714,,,,,,,,,,,,,,,,, -55800,0.39931569,2.961138,,,,,,,,,,,,,,,,, -55900,0.3603746,2.9159176,,,,,,,,,,,,,,,,, -55951,,,0.6445412635803223,1.82591450214386,32.05229143406017,0.6659185886383057,1.664124608039856,28.706732866328736,3000.0,0.680553138256073,1.5927239656448364,28.54534410724413,3003.0,19365.25709867477,31868.45972037316,19365.25709867477,12500.575400590897,0.8282985687255859,0.0 -56000,0.3594253,2.9384906,,,,,,,,,,,,,,,,, -56100,0.34726346,2.802471,,,,,,,,,,,,,,,,, -56200,0.43342814,2.8769886,,,,,,,,,,,,,,,,, -56300,0.32544237,2.9567451,,,,,,,,,,,,,,,,, -56400,0.34748277,2.8925614,,,,,,,,,,,,,,,,, -56500,0.3791045,2.9199023,,,,,,,,,,,,,,,,, -56600,0.3541468,2.8381152,,,,,,,,,,,,,,,,, -56700,0.43788058,2.9498906,,,,,,,,,,,,,,,,, -56800,0.35596249,2.8965209,,,,,,,,,,,,,,,,, -56900,0.34334883,2.903291,,,,,,,,,,,,,,,,, -57000,0.36080223,2.8832974,,,,,,,,,,,,,,,,, -57100,0.36760738,2.9804792,,,,,,,,,,,,,,,,, -57200,0.3715733,2.8448293,,,,,,,,,,,,,,,,, -57300,0.31864694,2.9127262,,,,,,,,,,,,,,,,, -57400,0.3465733,2.898925,,,,,,,,,,,,,,,,, -57500,0.37213054,2.9413085,,,,,,,,,,,,,,,,, -57600,0.37894535,2.9073203,,,,,,,,,,,,,,,,, -57700,0.35254288,2.860124,,,,,,,,,,,,,,,,, -57800,0.43232307,2.864606,,,,,,,,,,,,,,,,, -57900,0.35745245,2.8851438,,,,,,,,,,,,,,,,, -58000,0.39233872,2.8472466,,,,,,,,,,,,,,,,, -58100,0.3305097,2.9106185,,,,,,,,,,,,,,,,, -58200,0.36006412,2.8704758,,,,,,,,,,,,,,,,, -58300,0.35372704,2.8800218,,,,,,,,,,,,,,,,, -58385,,,0.6512637138366699,1.7649765014648438,31.62150612676097,0.6669105291366577,1.6533396244049072,28.5506335590646,3000.0,0.6782174110412598,1.5912344455718994,28.35522166002915,3003.0,20205.149728775024,33255.68974637985,20205.149728775024,13047.804039001465,0.8610539436340332,0.0 -58400,0.34508395,2.8812115,,,,,,,,,,,,,,,,, -58500,0.34845102,2.9314234,,,,,,,,,,,,,,,,, -58600,0.3422383,2.9099972,,,,,,,,,,,,,,,,, -58700,0.36091307,2.9286556,,,,,,,,,,,,,,,,, -58800,0.3353339,2.936548,,,,,,,,,,,,,,,,, -58900,0.4099879,2.88707,,,,,,,,,,,,,,,,, -59000,0.33547008,2.8412163,,,,,,,,,,,,,,,,, -59100,0.3369396,2.917368,,,,,,,,,,,,,,,,, -59200,0.407971,2.9018764,,,,,,,,,,,,,,,,, -59300,0.36918774,2.9956732,,,,,,,,,,,,,,,,, -59400,0.35418788,2.895566,,,,,,,,,,,,,,,,, -59500,0.31091633,2.8886874,,,,,,,,,,,,,,,,, -59600,0.34134465,2.8379664,,,,,,,,,,,,,,,,, -59700,0.37321833,2.917819,,,,,,,,,,,,,,,,, -59800,0.3599405,2.874531,,,,,,,,,,,,,,,,, -59900,0.34272516,2.8542933,,,,,,,,,,,,,,,,, -60000,0.34316152,2.8796258,,,,,,,,,,,,,,,,, -60100,0.42343357,2.9203084,,,,,,,,,,,,,,,,, -60200,0.35622537,2.8797421,,,,,,,,,,,,,,,,, -60300,0.34455898,2.9247582,,,,,,,,,,,,,,,,, -60400,0.36956707,2.937772,,,,,,,,,,,,,,,,, -60500,0.3647311,2.9177065,,,,,,,,,,,,,,,,, -60600,0.3159147,2.8984194,,,,,,,,,,,,,,,,, -60700,0.35897413,2.8466048,,,,,,,,,,,,,,,,, -60800,0.4101123,2.9911811,,,,,,,,,,,,,,,,, -60820,,,0.6503833532333374,1.7853611707687378,31.359232455254705,0.6674684882164001,1.655722975730896,28.79202635866657,3000.0,0.6813201308250427,1.5846421718597412,28.69376951309905,3003.0,21045.22128152848,34562.92338228226,21045.22128152848,13514.856243610382,0.8937373161315918,0.0 -60900,0.3391039,2.9212828,,,,,,,,,,,,,,,,, -61000,0.35266778,2.9057097,,,,,,,,,,,,,,,,, -61100,0.34508565,2.8805802,,,,,,,,,,,,,,,,, -61200,0.35057878,2.886646,,,,,,,,,,,,,,,,, -61300,0.34724286,2.9094205,,,,,,,,,,,,,,,,, -61400,0.3438647,2.8433084,,,,,,,,,,,,,,,,, -61500,0.3581337,2.840123,,,,,,,,,,,,,,,,, -61600,0.35905707,2.897513,,,,,,,,,,,,,,,,, -61700,0.39539036,2.8682501,,,,,,,,,,,,,,,,, -61800,0.33574596,2.9241385,,,,,,,,,,,,,,,,, -61900,0.34948257,2.9258707,,,,,,,,,,,,,,,,, -62000,0.39625868,2.931777,,,,,,,,,,,,,,,,, -62100,0.33433825,2.8506036,,,,,,,,,,,,,,,,, -62200,0.34610194,2.9195588,,,,,,,,,,,,,,,,, -62300,0.34922996,2.915227,,,,,,,,,,,,,,,,, -62400,0.34751114,2.867332,,,,,,,,,,,,,,,,, -62500,0.3361182,2.870315,,,,,,,,,,,,,,,,, -62600,0.34592384,2.8398342,,,,,,,,,,,,,,,,, -62700,0.35847887,2.8840148,,,,,,,,,,,,,,,,, -62800,0.35264575,2.9683304,,,,,,,,,,,,,,,,, -62900,0.34836963,2.9070506,,,,,,,,,,,,,,,,, -63000,0.364368,2.9191093,,,,,,,,,,,,,,,,, -63100,0.36317468,2.8709476,,,,,,,,,,,,,,,,, -63200,0.35594282,2.8379636,,,,,,,,,,,,,,,,, -63255,,,0.6639273166656494,1.6947698593139648,32.86530003174533,0.6687827706336975,1.639392375946045,29.130301380967907,3000.0,0.6831212639808655,1.5691425800323486,28.78299893054562,3003.0,21885.19492340088,35868.71333575249,21885.19492340088,13980.560697555542,0.9289267063140868,0.0 -63300,0.37948328,2.8475847,,,,,,,,,,,,,,,,, -63400,0.386408,2.9185793,,,,,,,,,,,,,,,,, -63500,0.35166788,2.9125283,,,,,,,,,,,,,,,,, -63600,0.38842914,2.9075952,,,,,,,,,,,,,,,,, -63700,0.34611976,2.8406026,,,,,,,,,,,,,,,,, -63800,0.36242825,2.8914976,,,,,,,,,,,,,,,,, -63900,0.34575486,2.8860343,,,,,,,,,,,,,,,,, -64000,0.35162333,2.9070072,,,,,,,,,,,,,,,,, -64100,0.3779158,2.883286,,,,,,,,,,,,,,,,, -64200,0.33713654,2.9126897,,,,,,,,,,,,,,,,, -64300,0.35162532,2.9068134,,,,,,,,,,,,,,,,, -64400,0.3722058,2.9164388,,,,,,,,,,,,,,,,, -64500,0.33646667,2.9310985,,,,,,,,,,,,,,,,, -64600,0.38865545,2.90933,,,,,,,,,,,,,,,,, -64700,0.40910655,2.9131868,,,,,,,,,,,,,,,,, -64800,0.33713958,2.8801441,,,,,,,,,,,,,,,,, -64900,0.3926707,2.8674893,,,,,,,,,,,,,,,,, -65000,0.3525869,2.8730059,,,,,,,,,,,,,,,,, -65100,0.324154,2.8562517,,,,,,,,,,,,,,,,, -65200,0.3642934,2.9014778,,,,,,,,,,,,,,,,, -65300,0.32823125,2.893238,,,,,,,,,,,,,,,,, -65400,0.35393462,2.921225,,,,,,,,,,,,,,,,, -65500,0.35159513,2.8983386,,,,,,,,,,,,,,,,, -65600,0.35629252,2.8938336,,,,,,,,,,,,,,,,, -65690,,,0.6493473649024963,1.7821170091629028,32.37563384288451,0.6715725660324097,1.6345800161361694,29.14645826795015,3000.0,0.6841439008712769,1.5664318799972534,28.850066271740804,3003.0,22725.28811383248,37216.40605187416,22725.28811383248,14488.048207998276,0.964684009552002,0.0 -65700,0.33311388,2.8285534,,,,,,,,,,,,,,,,, -65800,0.385225,2.917838,,,,,,,,,,,,,,,,, -65900,0.36719802,2.8723912,,,,,,,,,,,,,,,,, -66000,0.38324907,2.8902137,,,,,,,,,,,,,,,,, -66100,0.38857332,2.839616,,,,,,,,,,,,,,,,, -66200,0.3538888,2.8431907,,,,,,,,,,,,,,,,, -66300,0.33960637,2.8164413,,,,,,,,,,,,,,,,, -66400,0.34965557,2.798232,,,,,,,,,,,,,,,,, -66500,0.39730823,2.8718913,,,,,,,,,,,,,,,,, -66600,0.36104637,2.948207,,,,,,,,,,,,,,,,, -66700,0.39416155,2.8387778,,,,,,,,,,,,,,,,, -66800,0.3746964,2.8801818,,,,,,,,,,,,,,,,, -66900,0.33694345,2.8340347,,,,,,,,,,,,,,,,, -67000,0.349952,2.9266717,,,,,,,,,,,,,,,,, -67100,0.337301,2.895922,,,,,,,,,,,,,,,,, -67200,0.3863748,2.8510096,,,,,,,,,,,,,,,,, -67300,0.36235175,2.8071835,,,,,,,,,,,,,,,,, -67400,0.35935837,2.9050355,,,,,,,,,,,,,,,,, -67500,0.33600542,2.8000965,,,,,,,,,,,,,,,,, -67600,0.38356817,2.766079,,,,,,,,,,,,,,,,, -67700,0.3582778,2.8592253,,,,,,,,,,,,,,,,, -67800,0.34733504,2.8801513,,,,,,,,,,,,,,,,, -67900,0.41649017,2.917199,,,,,,,,,,,,,,,,, -68000,0.36984769,2.8199213,,,,,,,,,,,,,,,,, -68100,0.3858246,2.8019426,,,,,,,,,,,,,,,,, -68125,,,0.6522968411445618,1.7712243795394895,32.26716561149852,0.6721181273460388,1.6278233528137207,29.2856531151115,3000.0,0.6857939958572388,1.5533937215805054,29.1256183217246,3003.0,23565.320879220963,38545.19595956802,23565.320879220963,14976.696114301682,0.99855375289917,0.0 -68200,0.38510168,2.9256177,,,,,,,,,,,,,,,,, -68300,0.32529294,2.9192193,,,,,,,,,,,,,,,,, -68400,0.34770396,2.8897874,,,,,,,,,,,,,,,,, -68500,0.35355917,2.8707378,,,,,,,,,,,,,,,,, -68600,0.36338347,2.9100926,,,,,,,,,,,,,,,,, -68700,0.3667795,2.8694103,,,,,,,,,,,,,,,,, -68800,0.3785396,2.9368966,,,,,,,,,,,,,,,,, -68900,0.34426454,2.8558202,,,,,,,,,,,,,,,,, -69000,0.3612728,2.8747935,,,,,,,,,,,,,,,,, -69100,0.3598729,2.901184,,,,,,,,,,,,,,,,, -69200,0.36036763,2.8941505,,,,,,,,,,,,,,,,, -69300,0.36729342,2.8067677,,,,,,,,,,,,,,,,, -69400,0.3305444,2.8298163,,,,,,,,,,,,,,,,, -69500,0.3513388,2.8270428,,,,,,,,,,,,,,,,, -69600,0.35219863,2.8547485,,,,,,,,,,,,,,,,, -69700,0.34536147,2.8774998,,,,,,,,,,,,,,,,, -69800,0.36080062,2.7823663,,,,,,,,,,,,,,,,, -69900,0.37925547,2.8450372,,,,,,,,,,,,,,,,, -70000,0.35165668,2.8080044,,,,,,,,,,,,,,,,, -70100,0.35412425,2.8460524,,,,,,,,,,,,,,,,, -70200,0.33869722,2.8074527,,,,,,,,,,,,,,,,, -70300,0.3628773,2.7902262,,,,,,,,,,,,,,,,, -70400,0.4092732,2.8657365,,,,,,,,,,,,,,,,, -70500,0.34018034,2.853689,,,,,,,,,,,,,,,,, -70560,,,0.6611246466636658,1.709555745124817,32.33671503844988,0.6735812425613403,1.61667001247406,29.11200803725564,3000.0,0.6852827072143555,1.5494496822357178,28.89788171712957,3003.0,24405.44509220124,39850.8931684494,24405.44509220124,15442.105145931244,1.085925817489624,0.0 -70600,0.35747132,2.895919,,,,,,,,,,,,,,,,, -70700,0.33832622,2.8112183,,,,,,,,,,,,,,,,, -70800,0.35063785,2.9101007,,,,,,,,,,,,,,,,, -70900,0.35050488,2.8459883,,,,,,,,,,,,,,,,, -71000,0.3806098,2.9265902,,,,,,,,,,,,,,,,, -71100,0.37020966,2.884432,,,,,,,,,,,,,,,,, -71200,0.3431848,2.8535318,,,,,,,,,,,,,,,,, -71300,0.3445167,2.8897016,,,,,,,,,,,,,,,,, -71400,0.4075488,2.8776586,,,,,,,,,,,,,,,,, -71500,0.40515256,2.9010901,,,,,,,,,,,,,,,,, -71600,0.36054468,2.8252642,,,,,,,,,,,,,,,,, -71700,0.3628796,2.897875,,,,,,,,,,,,,,,,, -71800,0.36157212,2.8836162,,,,,,,,,,,,,,,,, -71900,0.35870782,2.8298595,,,,,,,,,,,,,,,,, -72000,0.3536941,2.8672082,,,,,,,,,,,,,,,,, -72100,0.38453117,2.8060672,,,,,,,,,,,,,,,,, -72200,0.37380236,2.8910964,,,,,,,,,,,,,,,,, -72300,0.3400857,2.830429,,,,,,,,,,,,,,,,, -72400,0.35684565,2.8304257,,,,,,,,,,,,,,,,, -72500,0.35007438,2.8157592,,,,,,,,,,,,,,,,, -72600,0.34230876,2.8034737,,,,,,,,,,,,,,,,, -72700,0.3495328,2.7642703,,,,,,,,,,,,,,,,, -72800,0.35917646,2.8630433,,,,,,,,,,,,,,,,, -72900,0.3774998,2.766456,,,,,,,,,,,,,,,,, -72995,,,0.6553438901901245,1.7596919536590576,32.79811780395135,0.6746723651885986,1.609011173248291,29.66018880236669,3000.0,0.6878392100334167,1.5402271747589111,29.30915630269909,3003.0,25245.6302895546,41157.38154554367,25245.6302895546,15908.294090032578,1.1230123043060305,0.0 -73000,0.3692859,2.8903863,,,,,,,,,,,,,,,,, -73100,0.39462826,2.7918172,,,,,,,,,,,,,,,,, -73200,0.3886864,2.8331225,,,,,,,,,,,,,,,,, -73300,0.34681633,2.8074749,,,,,,,,,,,,,,,,, -73400,0.37998733,2.8429492,,,,,,,,,,,,,,,,, -73500,0.35816962,2.7981386,,,,,,,,,,,,,,,,, -73600,0.39866945,2.836995,,,,,,,,,,,,,,,,, -73700,0.37711135,2.8609128,,,,,,,,,,,,,,,,, -73800,0.37281838,2.8584776,,,,,,,,,,,,,,,,, -73900,0.4295126,2.8067684,,,,,,,,,,,,,,,,, -74000,0.46367303,2.8444376,,,,,,,,,,,,,,,,, -74100,0.36839855,2.805006,,,,,,,,,,,,,,,,, -74200,0.40113485,2.8788145,,,,,,,,,,,,,,,,, -74300,0.37106022,2.8024235,,,,,,,,,,,,,,,,, -74400,0.38142425,2.8488708,,,,,,,,,,,,,,,,, -74500,0.3801339,2.8917751,,,,,,,,,,,,,,,,, -74600,0.3869457,2.8151462,,,,,,,,,,,,,,,,, -74700,0.354592,2.8077617,,,,,,,,,,,,,,,,, -74800,0.38178372,2.8054788,,,,,,,,,,,,,,,,, -74900,0.365642,2.8278198,,,,,,,,,,,,,,,,, -75000,0.35319048,2.8465545,,,,,,,,,,,,,,,,, -75100,0.39037824,2.8462791,,,,,,,,,,,,,,,,, -75200,0.35594195,2.8863325,,,,,,,,,,,,,,,,, -75300,0.39911753,2.7521436,,,,,,,,,,,,,,,,, -75400,0.34561577,2.7879953,,,,,,,,,,,,,,,,, -75429,,,0.6763533353805542,1.624076247215271,33.937369926587024,0.6763462424278259,1.6011841297149658,29.693910913014506,3000.0,0.689012885093689,1.5321506261825562,29.0940858153008,3003.0,26085.57997250557,42527.72531962395,26085.57997250557,16438.57457280159,1.1592369079589844,0.0 -75500,0.346204,2.7685475,,,,,,,,,,,,,,,,, -75600,0.34641913,2.8045778,,,,,,,,,,,,,,,,, -75700,0.3797508,2.861543,,,,,,,,,,,,,,,,, -75800,0.38356695,2.9069226,,,,,,,,,,,,,,,,, -75900,0.36306942,2.879778,,,,,,,,,,,,,,,,, -76000,0.33127835,2.8537912,,,,,,,,,,,,,,,,, -76100,0.3744418,2.8281767,,,,,,,,,,,,,,,,, -76200,0.3608464,2.8416798,,,,,,,,,,,,,,,,, -76300,0.37248793,2.8270502,,,,,,,,,,,,,,,,, -76400,0.37407553,2.847594,,,,,,,,,,,,,,,,, -76500,0.34173644,2.8366828,,,,,,,,,,,,,,,,, -76600,0.36982697,2.7509778,,,,,,,,,,,,,,,,, -76700,0.3934886,2.9178286,,,,,,,,,,,,,,,,, -76800,0.38574877,2.9334645,,,,,,,,,,,,,,,,, -76900,0.331351,2.7863822,,,,,,,,,,,,,,,,, -77000,0.35530078,2.7609403,,,,,,,,,,,,,,,,, -77100,0.37962854,2.8511808,,,,,,,,,,,,,,,,, -77200,0.35480666,2.8066492,,,,,,,,,,,,,,,,, -77300,0.37378436,2.8666248,,,,,,,,,,,,,,,,, -77400,0.3719321,2.817469,,,,,,,,,,,,,,,,, -77500,0.36836204,2.8098366,,,,,,,,,,,,,,,,, -77600,0.3629925,2.824547,,,,,,,,,,,,,,,,, -77700,0.36806664,2.758399,,,,,,,,,,,,,,,,, -77800,0.3599362,2.8206496,,,,,,,,,,,,,,,,, -77864,,,0.6609399318695068,1.7068854570388794,32.66516373016686,0.6770405769348145,1.597065806388855,29.53969753202616,3000.0,0.6887107491493225,1.527599573135376,29.20757973008861,3003.0,26925.578989982605,43852.69458556175,26925.578989982605,16923.432988643646,1.1936464309692385,0.0 -77900,0.3863221,2.8522544,,,,,,,,,,,,,,,,, -78000,0.3919584,2.8955736,,,,,,,,,,,,,,,,, -78100,0.4187831,2.8647308,,,,,,,,,,,,,,,,, -78200,0.3980327,2.921264,,,,,,,,,,,,,,,,, -78300,0.37182137,2.8102343,,,,,,,,,,,,,,,,, -78400,0.37681952,2.72687,,,,,,,,,,,,,,,,, -78500,0.35856023,2.7814386,,,,,,,,,,,,,,,,, -78600,0.36117065,2.8392134,,,,,,,,,,,,,,,,, -78700,0.3810944,2.875113,,,,,,,,,,,,,,,,, -78800,0.37435302,2.7989757,,,,,,,,,,,,,,,,, -78900,0.3338066,2.7789447,,,,,,,,,,,,,,,,, -79000,0.35637623,2.8100424,,,,,,,,,,,,,,,,, -79100,0.34769607,2.8476822,,,,,,,,,,,,,,,,, -79200,0.35593,2.7957737,,,,,,,,,,,,,,,,, -79300,0.35671,2.8114526,,,,,,,,,,,,,,,,, -79400,0.43067327,2.8902047,,,,,,,,,,,,,,,,, -79500,0.37910402,2.822723,,,,,,,,,,,,,,,,, -79600,0.34975973,2.8253472,,,,,,,,,,,,,,,,, -79700,0.3696266,2.863313,,,,,,,,,,,,,,,,, -79800,0.37963918,2.804483,,,,,,,,,,,,,,,,, -79900,0.36429304,2.7255774,,,,,,,,,,,,,,,,, -80000,0.3916567,2.8560998,,,,,,,,,,,,,,,,, -80100,0.38595742,2.8460882,,,,,,,,,,,,,,,,, -80200,0.35855317,2.7478907,,,,,,,,,,,,,,,,, -80299,,,0.6609601974487305,1.7184412479400637,32.26415699149423,0.6776357293128967,1.5888279676437378,29.55108559881877,3000.0,0.6894195675849915,1.5198761224746704,29.1376581392568,3003.0,27765.792411088943,45205.13687705994,27765.792411088943,17435.547990322113,1.230426788330078,0.0 -80300,0.3854335,2.8389227,,,,,,,,,,,,,,,,, -80400,0.3603259,2.798495,,,,,,,,,,,,,,,,, -80500,0.3586007,2.8109481,,,,,,,,,,,,,,,,, -80600,0.3919053,2.793174,,,,,,,,,,,,,,,,, -80700,0.39370915,2.827441,,,,,,,,,,,,,,,,, -80800,0.37092382,2.9042034,,,,,,,,,,,,,,,,, -80900,0.38712782,2.8135893,,,,,,,,,,,,,,,,, -81000,0.3986868,2.800721,,,,,,,,,,,,,,,,, -81100,0.40391955,2.8570354,,,,,,,,,,,,,,,,, -81200,0.37792137,2.7943091,,,,,,,,,,,,,,,,, -81300,0.3700341,2.8000503,,,,,,,,,,,,,,,,, -81400,0.36925852,2.8271081,,,,,,,,,,,,,,,,, -81500,0.36957186,2.8466284,,,,,,,,,,,,,,,,, -81600,0.39904568,2.867309,,,,,,,,,,,,,,,,, -81700,0.361885,2.8091962,,,,,,,,,,,,,,,,, -81800,0.43730453,2.8340795,,,,,,,,,,,,,,,,, -81900,0.3873677,2.8523026,,,,,,,,,,,,,,,,, -82000,0.37980598,2.8628752,,,,,,,,,,,,,,,,, -82100,0.4044101,2.8393745,,,,,,,,,,,,,,,,, -82200,0.39563015,2.8959992,,,,,,,,,,,,,,,,, -82300,0.41462395,2.8502274,,,,,,,,,,,,,,,,, -82400,0.37693933,2.823025,,,,,,,,,,,,,,,,, -82500,0.40254733,2.779641,,,,,,,,,,,,,,,,, -82600,0.41693482,2.8567605,,,,,,,,,,,,,,,,, -82700,0.39687622,2.857513,,,,,,,,,,,,,,,,, -82733,,,0.6679980754852295,1.673125147819519,33.5625012966058,0.6785408854484558,1.587047100067139,29.696808866105183,3000.0,0.6937772631645203,1.509999394416809,29.74323986290128,3003.0,28605.77160620689,46501.64113926888,28605.77160620689,17891.961584091187,1.264892816543579,0.0 -82800,0.3660778,2.8412943,,,,,,,,,,,,,,,,, -82900,0.383931,2.8023279,,,,,,,,,,,,,,,,, -83000,0.37446138,2.7451253,,,,,,,,,,,,,,,,, -83100,0.38504124,2.8600378,,,,,,,,,,,,,,,,, -83200,0.38327459,2.862922,,,,,,,,,,,,,,,,, -83300,0.3657644,2.8040257,,,,,,,,,,,,,,,,, -83400,0.4016064,2.8576558,,,,,,,,,,,,,,,,, -83500,0.39111596,2.849615,,,,,,,,,,,,,,,,, -83600,0.3903927,2.7913332,,,,,,,,,,,,,,,,, -83700,0.39036733,2.7898986,,,,,,,,,,,,,,,,, -83800,0.39151305,2.7561753,,,,,,,,,,,,,,,,, -83900,0.38375023,2.8324947,,,,,,,,,,,,,,,,, -84000,0.4196913,2.789019,,,,,,,,,,,,,,,,, -84100,0.38895634,2.8219259,,,,,,,,,,,,,,,,, -84200,0.3842014,2.839518,,,,,,,,,,,,,,,,, -84300,0.40026975,2.8256733,,,,,,,,,,,,,,,,, -84400,0.4083423,2.8793364,,,,,,,,,,,,,,,,, -84500,0.4169987,2.8141418,,,,,,,,,,,,,,,,, -84600,0.38310283,2.789577,,,,,,,,,,,,,,,,, -84700,0.36490208,2.75463,,,,,,,,,,,,,,,,, -84800,0.38871235,2.7868454,,,,,,,,,,,,,,,,, -84900,0.41926205,2.7819912,,,,,,,,,,,,,,,,, -85000,0.40877807,2.8042076,,,,,,,,,,,,,,,,, -85100,0.39234212,2.8685007,,,,,,,,,,,,,,,,, -85167,,,0.6636251211166382,1.700329303741455,32.7424648321825,0.6798551678657532,1.578163981437683,29.852257792933663,3000.0,0.6938701868057251,1.5030529499053955,29.60024070059604,3003.0,29445.863482236862,47847.549050569534,29445.863482236862,18397.664662599564,1.301595687866211,0.0 -85200,0.3997066,2.7532964,,,,,,,,,,,,,,,,, -85300,0.4077866,2.8072402,,,,,,,,,,,,,,,,, -85400,0.39948967,2.7508218,,,,,,,,,,,,,,,,, -85500,0.43928444,2.7917356,,,,,,,,,,,,,,,,, -85600,0.39761934,2.7609491,,,,,,,,,,,,,,,,, -85700,0.4000173,2.8283472,,,,,,,,,,,,,,,,, -85800,0.41435987,2.8857195,,,,,,,,,,,,,,,,, -85900,0.39746374,2.8089113,,,,,,,,,,,,,,,,, -86000,0.51702577,2.8116882,,,,,,,,,,,,,,,,, -86100,0.41274163,2.753381,,,,,,,,,,,,,,,,, -86200,0.3951449,2.8168392,,,,,,,,,,,,,,,,, -86300,0.39824036,2.7992387,,,,,,,,,,,,,,,,, -86400,0.38998124,2.8456607,,,,,,,,,,,,,,,,, -86500,0.3933551,2.821566,,,,,,,,,,,,,,,,, -86600,0.40295035,2.8473485,,,,,,,,,,,,,,,,, -86700,0.3978913,2.7963233,,,,,,,,,,,,,,,,, -86800,0.4152337,2.814467,,,,,,,,,,,,,,,,, -86900,0.39676824,2.7231667,,,,,,,,,,,,,,,,, -87000,0.4180643,2.8688178,,,,,,,,,,,,,,,,, -87100,0.41170198,2.794441,,,,,,,,,,,,,,,,, -87200,0.40865654,2.811601,,,,,,,,,,,,,,,,, -87300,0.38286343,2.8027203,,,,,,,,,,,,,,,,, -87400,0.3734897,2.7586315,,,,,,,,,,,,,,,,, -87500,0.40077135,2.793964,,,,,,,,,,,,,,,,, -87600,0.39772403,2.7985744,,,,,,,,,,,,,,,,, -87601,,,0.7056272029876709,1.4632694721221924,35.86558291442635,0.6798055768013,1.5734652280807495,29.87314936635056,3000.0,0.6945558190345764,1.498718023300171,29.48370587895101,3003.0,30286.26316356659,49169.80411338806,30286.26316356659,18879.404970645905,1.3394997119903564,0.0 -87700,0.38917688,2.7626278,,,,,,,,,,,,,,,,, -87800,0.38903856,2.7641168,,,,,,,,,,,,,,,,, -87900,0.397098,2.7559323,,,,,,,,,,,,,,,,, -88000,0.41133508,2.7885075,,,,,,,,,,,,,,,,, -88100,0.38920522,2.7483387,,,,,,,,,,,,,,,,, -88200,0.37052742,2.8373265,,,,,,,,,,,,,,,,, -88300,0.41851428,2.7999976,,,,,,,,,,,,,,,,, -88400,0.38969308,2.8004687,,,,,,,,,,,,,,,,, -88500,0.3999699,2.7921023,,,,,,,,,,,,,,,,, -88600,0.38934124,2.7730799,,,,,,,,,,,,,,,,, -88700,0.3973031,2.7867365,,,,,,,,,,,,,,,,, -88800,0.4310371,2.8919678,,,,,,,,,,,,,,,,, -88900,0.42075676,2.762229,,,,,,,,,,,,,,,,, -89000,0.38015583,2.748861,,,,,,,,,,,,,,,,, -89100,0.38056242,2.831008,,,,,,,,,,,,,,,,, -89200,0.41338184,2.8488908,,,,,,,,,,,,,,,,, -89300,0.421441,2.8193004,,,,,,,,,,,,,,,,, -89400,0.42122418,2.756263,,,,,,,,,,,,,,,,, -89500,0.39110246,2.7491033,,,,,,,,,,,,,,,,, -89600,0.42949003,2.7452023,,,,,,,,,,,,,,,,, -89700,0.41381702,2.7680514,,,,,,,,,,,,,,,,, -89800,0.40886337,2.780121,,,,,,,,,,,,,,,,, -89900,0.40067858,2.812733,,,,,,,,,,,,,,,,, -90000,0.3889222,2.8035986,,,,,,,,,,,,,,,,, -90035,,,0.6675459742546082,1.6747876405715942,33.33141123987844,0.6829301714897156,1.564455509185791,29.92197180523269,3000.0,0.6961826682090759,1.489499807357788,29.9603429821149,3003.0,31126.28993988037,50491.37156367302,31126.28993988037,19360.830275535583,1.3778259754180908,0.0 -90100,0.41422626,2.7963486,,,,,,,,,,,,,,,,, -90200,0.39934754,2.691349,,,,,,,,,,,,,,,,, -90300,0.43756044,2.8080995,,,,,,,,,,,,,,,,, -90400,0.4330603,2.778881,,,,,,,,,,,,,,,,, -90500,0.3895903,2.671828,,,,,,,,,,,,,,,,, -90600,0.40877196,2.7943816,,,,,,,,,,,,,,,,, -90700,0.4236087,2.736386,,,,,,,,,,,,,,,,, -90800,0.41813517,2.74125,,,,,,,,,,,,,,,,, -90900,0.435955,2.8194077,,,,,,,,,,,,,,,,, -91000,0.3932387,2.7744765,,,,,,,,,,,,,,,,, -91100,0.40291923,2.7693436,,,,,,,,,,,,,,,,, -91200,0.40480807,2.8614883,,,,,,,,,,,,,,,,, -91300,0.4121788,2.7369065,,,,,,,,,,,,,,,,, -91400,0.40069383,2.7519069,,,,,,,,,,,,,,,,, -91500,0.4273519,2.822082,,,,,,,,,,,,,,,,, -91600,0.44144836,2.797181,,,,,,,,,,,,,,,,, -91700,0.46168524,2.7987401,,,,,,,,,,,,,,,,, -91800,0.42405316,2.8661356,,,,,,,,,,,,,,,,, -91900,0.40542176,2.7375643,,,,,,,,,,,,,,,,, -92000,0.43101013,2.7684767,,,,,,,,,,,,,,,,, -92100,0.43928492,2.7097158,,,,,,,,,,,,,,,,, -92200,0.44499713,2.7696157,,,,,,,,,,,,,,,,, -92300,0.4582066,2.752253,,,,,,,,,,,,,,,,, -92400,0.45158893,2.8297787,,,,,,,,,,,,,,,,, -92469,,,0.6658973097801208,1.6872364282608032,33.660730787562834,0.6835625171661377,1.5563279390335083,30.26361366275628,3000.0,0.6978908777236938,1.4811248779296875,30.035932898851115,3003.0,31966.21331453324,51810.94109606743,31966.21331453324,19840.359940052032,1.4165542125701904,0.0 -92500,0.43750438,2.8092532,,,,,,,,,,,,,,,,, -92600,0.40557745,2.7957234,,,,,,,,,,,,,,,,, -92700,0.41732645,2.7690876,,,,,,,,,,,,,,,,, -92800,0.4075086,2.705478,,,,,,,,,,,,,,,,, -92900,0.4370218,2.7192223,,,,,,,,,,,,,,,,, -93000,0.4142984,2.7695322,,,,,,,,,,,,,,,,, -93100,0.41493204,2.7344213,,,,,,,,,,,,,,,,, -93200,0.4408437,2.6973403,,,,,,,,,,,,,,,,, -93300,0.42445305,2.7128503,,,,,,,,,,,,,,,,, -93400,0.45090058,2.7622094,,,,,,,,,,,,,,,,, -93500,0.4418728,2.810429,,,,,,,,,,,,,,,,, -93600,0.43805173,2.7621713,,,,,,,,,,,,,,,,, -93700,0.45136482,2.7603097,,,,,,,,,,,,,,,,, -93800,0.44353032,2.7553496,,,,,,,,,,,,,,,,, -93900,0.4378443,2.8100152,,,,,,,,,,,,,,,,, -94000,0.46038967,2.726849,,,,,,,,,,,,,,,,, -94100,0.43571705,2.769535,,,,,,,,,,,,,,,,, -94200,0.44601744,2.7438767,,,,,,,,,,,,,,,,, -94300,0.43888387,2.8016474,,,,,,,,,,,,,,,,, -94400,0.43920293,2.7566273,,,,,,,,,,,,,,,,, -94500,0.42376697,2.7906697,,,,,,,,,,,,,,,,, -94600,0.4205604,2.7571077,,,,,,,,,,,,,,,,, -94700,0.437258,2.7510178,,,,,,,,,,,,,,,,, -94800,0.45379764,2.6975207,,,,,,,,,,,,,,,,, -94900,0.4260164,2.7242577,,,,,,,,,,,,,,,,, -94903,,,0.6817435026168823,1.5854727029800415,33.982398172121265,0.6839964985847473,1.5531708002090454,30.10177069772297,3000.0,0.7008076310157776,1.4715487957000732,29.92625119054191,3003.0,32806.30698490143,53134.88615298271,32806.30698490143,20324.09602546692,1.453615665435791,0.0 -95000,0.45562112,2.7577536,,,,,,,,,,,,,,,,, -95100,0.44691408,2.7336416,,,,,,,,,,,,,,,,, -95200,0.4400149,2.783988,,,,,,,,,,,,,,,,, -95300,0.47284013,2.7782319,,,,,,,,,,,,,,,,, -95400,0.44675407,2.6646593,,,,,,,,,,,,,,,,, -95500,0.4456523,2.8061428,,,,,,,,,,,,,,,,, -95600,0.4426156,2.7523828,,,,,,,,,,,,,,,,, -95700,0.42609325,2.7108707,,,,,,,,,,,,,,,,, -95800,0.41711769,2.706761,,,,,,,,,,,,,,,,, -95900,0.44812834,2.7196224,,,,,,,,,,,,,,,,, -96000,0.46148345,2.7654583,,,,,,,,,,,,,,,,, -96100,0.48023513,2.7617922,,,,,,,,,,,,,,,,, -96200,0.43188712,2.7246726,,,,,,,,,,,,,,,,, -96300,0.44984448,2.8127704,,,,,,,,,,,,,,,,, -96400,0.6351434,2.7402844,,,,,,,,,,,,,,,,, -96500,0.44296113,2.775553,,,,,,,,,,,,,,,,, -96600,0.4701511,2.791177,,,,,,,,,,,,,,,,, -96700,0.45576185,2.7023866,,,,,,,,,,,,,,,,, -96800,0.4597129,2.8114011,,,,,,,,,,,,,,,,, -96900,0.46325168,2.7390628,,,,,,,,,,,,,,,,, -97000,0.4669587,2.7877328,,,,,,,,,,,,,,,,, -97100,0.48030677,2.6804218,,,,,,,,,,,,,,,,, -97200,0.44024336,2.7473905,,,,,,,,,,,,,,,,, -97300,0.45449764,2.7975667,,,,,,,,,,,,,,,,, -97337,,,0.6754931807518005,1.6232917308807373,33.42584230300187,0.6868730783462524,1.5403165817260742,30.324988297327963,3000.0,0.7019813060760498,1.4636123180389404,30.23504562070128,3003.0,33646.52457332611,54436.56728100777,33646.52457332611,20785.443420887,1.492318868637085,0.0 -97400,0.47762173,2.7131698,,,,,,,,,,,,,,,,, -97500,0.48397335,2.8127291,,,,,,,,,,,,,,,,, -97600,0.45864624,2.727857,,,,,,,,,,,,,,,,, -97700,0.4548108,2.7366738,,,,,,,,,,,,,,,,, -97800,0.4429896,2.7013552,,,,,,,,,,,,,,,,, -97900,0.4681747,2.7727156,,,,,,,,,,,,,,,,, -98000,0.47684205,2.7835078,,,,,,,,,,,,,,,,, -98100,0.4713831,2.71481,,,,,,,,,,,,,,,,, -98200,0.47609687,2.7756665,,,,,,,,,,,,,,,,, -98300,0.47657803,2.7509756,,,,,,,,,,,,,,,,, -98400,0.50770974,2.7349474,,,,,,,,,,,,,,,,, -98500,0.4720359,2.7099485,,,,,,,,,,,,,,,,, -98600,0.5125375,2.7263165,,,,,,,,,,,,,,,,, -98700,0.45512348,2.6708307,,,,,,,,,,,,,,,,, -98800,0.50902283,2.8156972,,,,,,,,,,,,,,,,, -98900,0.44339854,2.7020025,,,,,,,,,,,,,,,,, -99000,0.47009623,2.7222798,,,,,,,,,,,,,,,,, -99100,0.47928903,2.779903,,,,,,,,,,,,,,,,, -99200,0.43753982,2.683856,,,,,,,,,,,,,,,,, -99300,0.487309,2.7433147,,,,,,,,,,,,,,,,, -99400,0.44657087,2.7186995,,,,,,,,,,,,,,,,, -99500,0.46843854,2.7487438,,,,,,,,,,,,,,,,, -99600,0.4875963,2.691799,,,,,,,,,,,,,,,,, -99700,0.45651472,2.7287676,,,,,,,,,,,,,,,,, -99771,,,0.6739282011985779,1.63223135471344,33.990486539139006,0.6875302195549011,1.536251187324524,30.240875818928536,3000.0,0.7032363414764404,1.4567337036132812,30.214177626963963,3003.0,34486.669437885284,55755.76510024071,34486.669437885284,21264.37947511673,1.5324418544769287,0.0 -99800,0.45085528,2.7196884,,,,,,,,,,,,,,,,, -99900,0.49691728,2.7700226,,,,,,,,,,,,,,,,, -100000,0.4706662,2.7770014,,,,,,,,,,,,,,,,, -100100,0.47086146,2.7272277,,,,,,,,,,,,,,,,, -100200,0.5027704,2.7145374,,,,,,,,,,,,,,,,, -100300,0.44731048,2.7213733,,,,,,,,,,,,,,,,, -100400,0.4702882,2.7383862,,,,,,,,,,,,,,,,, -100500,0.4991207,2.7712915,,,,,,,,,,,,,,,,, -100600,0.4897987,2.7157588,,,,,,,,,,,,,,,,, -100700,0.48638982,2.7625606,,,,,,,,,,,,,,,,, -100800,0.51568866,2.6802866,,,,,,,,,,,,,,,,, -100900,0.49055842,2.6979504,,,,,,,,,,,,,,,,, -101000,0.48564908,2.7308292,,,,,,,,,,,,,,,,, -101100,0.501328,2.6840594,,,,,,,,,,,,,,,,, -101200,0.515797,2.7478645,,,,,,,,,,,,,,,,, -101300,0.48742002,2.7901528,,,,,,,,,,,,,,,,, -101400,0.4701153,2.721751,,,,,,,,,,,,,,,,, -101500,0.5307113,2.7644327,,,,,,,,,,,,,,,,, -101600,0.52015704,2.808343,,,,,,,,,,,,,,,,, -101700,0.4937845,2.7264493,,,,,,,,,,,,,,,,, -101800,0.49326152,2.7397974,,,,,,,,,,,,,,,,, -101900,0.52615315,2.7507682,,,,,,,,,,,,,,,,, -102000,0.46712604,2.7042532,,,,,,,,,,,,,,,,, -102100,0.4789574,2.651358,,,,,,,,,,,,,,,,, -102200,0.4797905,2.7217717,,,,,,,,,,,,,,,,, -102204,,,0.6823273301124573,1.5801806449890137,34.581900656888564,0.6878277659416199,1.5308395624160769,30.43020782921961,3000.0,0.7022369503974915,1.456139326095581,30.21770641738397,3003.0,35326.57743191719,57047.69973301888,35326.57743191719,21716.28870844841,1.5729172229766846,0.0 -102300,0.533074,2.7090228,,,,,,,,,,,,,,,,, -102400,0.46516612,2.6584067,,,,,,,,,,,,,,,,, -102500,0.4957742,2.7493277,,,,,,,,,,,,,,,,, -102600,0.49657756,2.7168984,,,,,,,,,,,,,,,,, -102700,0.50060415,2.8169127,,,,,,,,,,,,,,,,, -102800,0.4959047,2.7683342,,,,,,,,,,,,,,,,, -102900,0.5013596,2.691773,,,,,,,,,,,,,,,,, -103000,0.48589617,2.7504673,,,,,,,,,,,,,,,,, -103100,0.5306182,2.7501194,,,,,,,,,,,,,,,,, -103200,0.5120052,2.7069912,,,,,,,,,,,,,,,,, -103300,0.5350913,2.7429333,,,,,,,,,,,,,,,,, -103400,0.50769967,2.7045655,,,,,,,,,,,,,,,,, -103500,0.5198885,2.6984494,,,,,,,,,,,,,,,,, -103600,0.52736205,2.7224145,,,,,,,,,,,,,,,,, -103700,0.5126938,2.718476,,,,,,,,,,,,,,,,, -103800,0.5056605,2.6500337,,,,,,,,,,,,,,,,, -103900,0.531613,2.7331338,,,,,,,,,,,,,,,,, -104000,0.5172449,2.6915472,,,,,,,,,,,,,,,,, -104100,0.52041286,2.7422793,,,,,,,,,,,,,,,,, -104200,0.52624035,2.7120779,,,,,,,,,,,,,,,,, -104300,0.51757777,2.7253196,,,,,,,,,,,,,,,,, -104400,0.5047919,2.724412,,,,,,,,,,,,,,,,, -104500,0.51568013,2.7675476,,,,,,,,,,,,,,,,, -104600,0.49837148,2.6740396,,,,,,,,,,,,,,,,, -104638,,,0.6786729693412781,1.605922818183899,34.14329271426784,0.688373327255249,1.5237150192260742,30.576941500675023,3000.0,0.7040613889694214,1.444055676460266,30.680197257595136,3003.0,36166.62100124359,58382.8054254055,36166.62100124359,22211.2339220047,1.612227439880371,0.0 -104700,0.5218705,2.719554,,,,,,,,,,,,,,,,, -104800,0.5240168,2.714625,,,,,,,,,,,,,,,,, -104900,0.545658,2.7008762,,,,,,,,,,,,,,,,, -105000,0.5190566,2.7043614,,,,,,,,,,,,,,,,, -105100,0.51475364,2.6703055,,,,,,,,,,,,,,,,, -105200,0.5332725,2.744594,,,,,,,,,,,,,,,,, -105300,0.5154016,2.71533,,,,,,,,,,,,,,,,, -105400,0.5206993,2.6754653,,,,,,,,,,,,,,,,, -105500,0.5255925,2.6926847,,,,,,,,,,,,,,,,, -105600,0.5135904,2.7022185,,,,,,,,,,,,,,,,, -105700,0.52610487,2.6863606,,,,,,,,,,,,,,,,, -105800,0.54099405,2.712115,,,,,,,,,,,,,,,,, -105900,0.5559063,2.668672,,,,,,,,,,,,,,,,, -106000,0.516307,2.6290872,,,,,,,,,,,,,,,,, -106100,0.5222603,2.6632116,,,,,,,,,,,,,,,,, -106200,0.5559447,2.7377455,,,,,,,,,,,,,,,,, -106300,0.548705,2.7414253,,,,,,,,,,,,,,,,, -106400,0.54450774,2.7071955,,,,,,,,,,,,,,,,, -106500,0.5214258,2.6808496,,,,,,,,,,,,,,,,, -106600,0.5196952,2.7261283,,,,,,,,,,,,,,,,, -106700,0.54795563,2.7326803,,,,,,,,,,,,,,,,, -106800,0.5400943,2.7040715,,,,,,,,,,,,,,,,, -106900,0.53436184,2.6594486,,,,,,,,,,,,,,,,, -107000,0.54318786,2.7481382,,,,,,,,,,,,,,,,, -107072,,,0.6972777843475342,1.501353740692139,35.75124263193658,0.6902828216552734,1.5222079753875732,30.50713657492696,3000.0,0.7058277130126953,1.44207763671875,30.48795900641476,3003.0,37006.69671726227,59689.48287606239,37006.69671726227,22677.718291044235,1.6531658172607422,0.0 -107100,0.5773163,2.7146535,,,,,,,,,,,,,,,,, -107200,0.53912824,2.7148297,,,,,,,,,,,,,,,,, -107300,0.5469886,2.6813567,,,,,,,,,,,,,,,,, -107400,0.5252877,2.6992326,,,,,,,,,,,,,,,,, -107500,0.56934506,2.6248412,,,,,,,,,,,,,,,,, -107600,0.5501127,2.65376,,,,,,,,,,,,,,,,, -107700,0.542097,2.6466613,,,,,,,,,,,,,,,,, -107800,0.55467767,2.678414,,,,,,,,,,,,,,,,, -107900,0.5628548,2.6295037,,,,,,,,,,,,,,,,, -108000,0.5655197,2.6892505,,,,,,,,,,,,,,,,, -108100,0.5648713,2.6884181,,,,,,,,,,,,,,,,, -108200,0.58002925,2.685999,,,,,,,,,,,,,,,,, -108300,0.56025654,2.7233257,,,,,,,,,,,,,,,,, -108400,0.52768844,2.6608124,,,,,,,,,,,,,,,,, -108500,0.56569856,2.6750047,,,,,,,,,,,,,,,,, -108600,0.5600093,2.7117653,,,,,,,,,,,,,,,,, -108700,0.5853475,2.7160854,,,,,,,,,,,,,,,,, -108800,0.5669766,2.6554673,,,,,,,,,,,,,,,,, -108900,0.5677695,2.6412876,,,,,,,,,,,,,,,,, -109000,0.58807874,2.7007856,,,,,,,,,,,,,,,,, -109100,0.5581748,2.6287508,,,,,,,,,,,,,,,,, -109200,0.5950594,2.708316,,,,,,,,,,,,,,,,, -109300,0.59280103,2.6524644,,,,,,,,,,,,,,,,, -109400,0.5738068,2.6538436,,,,,,,,,,,,,,,,, -109500,0.54288805,2.7297692,,,,,,,,,,,,,,,,, -109506,,,0.6877254843711853,1.5565861463546753,35.05868995744967,0.6905431747436523,1.5144131183624268,30.72632314292713,3000.0,0.7075126767158508,1.4303762912750244,30.78264729883304,3003.0,37846.90900039673,61018.65080690384,37846.90900039673,23166.557002544403,1.693382978439331,0.0 -109600,0.58618706,2.6263776,,,,,,,,,,,,,,,,, -109700,0.58137953,2.69551,,,,,,,,,,,,,,,,, -109800,0.5663245,2.608821,,,,,,,,,,,,,,,,, -109900,0.59304917,2.7064471,,,,,,,,,,,,,,,,, -110000,0.55977076,2.6723506,,,,,,,,,,,,,,,,, -110100,0.5869387,2.689675,,,,,,,,,,,,,,,,, -110200,0.56134355,2.7095292,,,,,,,,,,,,,,,,, -110300,0.5805781,2.6630838,,,,,,,,,,,,,,,,, -110400,0.6062577,2.7542582,,,,,,,,,,,,,,,,, -110500,0.58442205,2.6580439,,,,,,,,,,,,,,,,, -110600,0.5693973,2.6161985,,,,,,,,,,,,,,,,, -110700,0.61160004,2.7218468,,,,,,,,,,,,,,,,, -110800,0.5643775,2.709857,,,,,,,,,,,,,,,,, -110900,0.5880643,2.7352579,,,,,,,,,,,,,,,,, -111000,0.6013828,2.67127,,,,,,,,,,,,,,,,, -111100,0.6350969,2.6946728,,,,,,,,,,,,,,,,, -111200,0.5852474,2.6286113,,,,,,,,,,,,,,,,, -111300,0.5823316,2.6876347,,,,,,,,,,,,,,,,, -111400,0.5978616,2.7073128,,,,,,,,,,,,,,,,, -111500,0.6024085,2.7068417,,,,,,,,,,,,,,,,, -111600,0.6144939,2.677921,,,,,,,,,,,,,,,,, -111700,0.5940711,2.6583915,,,,,,,,,,,,,,,,, -111800,0.60645854,2.6526327,,,,,,,,,,,,,,,,, -111900,0.61503184,2.7143703,,,,,,,,,,,,,,,,, -111939,,,0.6905927658081055,1.540831446647644,34.9471526156188,0.6911631226539612,1.5094802379608154,30.61271867582907,3000.0,0.707129180431366,1.427580952644348,30.68408721481186,3003.0,38686.852848529816,62330.95977449417,38686.852848529816,23638.802623033524,1.736116647720337,0.0 -112000,0.63232297,2.7024138,,,,,,,,,,,,,,,,, -112100,0.6377853,2.6255841,,,,,,,,,,,,,,,,, -112200,0.6182472,2.631025,,,,,,,,,,,,,,,,, -112300,0.59860206,2.6392992,,,,,,,,,,,,,,,,, -112400,0.59016937,2.6312778,,,,,,,,,,,,,,,,, -112500,0.6129797,2.6618161,,,,,,,,,,,,,,,,, -112600,0.6145235,2.6751049,,,,,,,,,,,,,,,,, -112700,0.6114168,2.604345,,,,,,,,,,,,,,,,, -112800,0.5963383,2.6422324,,,,,,,,,,,,,,,,, -112900,0.61089057,2.6381016,,,,,,,,,,,,,,,,, -113000,0.59085476,2.579448,,,,,,,,,,,,,,,,, -113100,0.6190312,2.5871463,,,,,,,,,,,,,,,,, -113200,0.5795373,2.584764,,,,,,,,,,,,,,,,, -113300,0.5962075,2.6688657,,,,,,,,,,,,,,,,, -113400,0.64109737,2.6485038,,,,,,,,,,,,,,,,, -113500,0.63573396,2.6524913,,,,,,,,,,,,,,,,, -113600,0.60789514,2.634188,,,,,,,,,,,,,,,,, -113700,0.627123,2.6465614,,,,,,,,,,,,,,,,, -113800,0.6599619,2.6272035,,,,,,,,,,,,,,,,, -113900,0.620943,2.6107154,,,,,,,,,,,,,,,,, -114000,0.6127349,2.6357005,,,,,,,,,,,,,,,,, -114100,0.62212104,2.6368191,,,,,,,,,,,,,,,,, -114200,0.6244847,2.6679027,,,,,,,,,,,,,,,,, -114300,0.6476968,2.663907,,,,,,,,,,,,,,,,, -114372,,,0.7004918456077576,1.4847887754440308,36.07395854939219,0.6930230259895325,1.5063766241073608,30.78058822928789,3000.0,0.708628237247467,1.4222077131271362,30.82562319625663,3003.0,39526.98830747605,63637.38416719437,39526.98830747605,24104.972608327866,1.7767961025238037,0.0 -114400,0.6576761,2.6478207,,,,,,,,,,,,,,,,, -114500,0.6305356,2.7186642,,,,,,,,,,,,,,,,, -114600,0.63579774,2.6341226,,,,,,,,,,,,,,,,, -114700,0.64385456,2.597977,,,,,,,,,,,,,,,,, -114800,0.6312261,2.6222684,,,,,,,,,,,,,,,,, -114900,0.6382873,2.6393259,,,,,,,,,,,,,,,,, -115000,0.6236563,2.6118546,,,,,,,,,,,,,,,,, -115100,0.63321066,2.6026256,,,,,,,,,,,,,,,,, -115200,0.62712497,2.6586056,,,,,,,,,,,,,,,,, -115300,0.63444674,2.6701138,,,,,,,,,,,,,,,,, -115400,0.6726684,2.6624184,,,,,,,,,,,,,,,,, -115500,0.63628066,2.6874468,,,,,,,,,,,,,,,,, -115600,0.63936716,2.6600115,,,,,,,,,,,,,,,,, -115700,0.6490977,2.6723895,,,,,,,,,,,,,,,,, -115800,0.6491664,2.6363413,,,,,,,,,,,,,,,,, -115900,0.62944114,2.6464438,,,,,,,,,,,,,,,,, -116000,0.6775685,2.6517954,,,,,,,,,,,,,,,,, -116100,0.6395488,2.6122532,,,,,,,,,,,,,,,,, -116200,0.6549245,2.61446,,,,,,,,,,,,,,,,, -116300,0.6388351,2.6346886,,,,,,,,,,,,,,,,, -116400,0.6758113,2.559812,,,,,,,,,,,,,,,,, -116500,0.63430345,2.6303484,,,,,,,,,,,,,,,,, -116600,0.6790726,2.6529758,,,,,,,,,,,,,,,,, -116700,0.6358241,2.5764174,,,,,,,,,,,,,,,,, -116800,0.6597469,2.6250117,,,,,,,,,,,,,,,,, -116805,,,0.6977561116218567,1.5053915977478027,35.92062263456214,0.6923906803131104,1.5051144361495972,30.750445632725896,3000.0,0.709511399269104,1.4227913618087769,30.89250024411417,3003.0,40366.93438744545,64949.66950559616,40366.93438744545,24577.193524599075,1.8184754848480225,0.0 -116900,0.6529588,2.6188495,,,,,,,,,,,,,,,,, -117000,0.6909079,2.660515,,,,,,,,,,,,,,,,, -117100,0.6727483,2.6386688,,,,,,,,,,,,,,,,, -117200,0.6991691,2.6309273,,,,,,,,,,,,,,,,, -117300,0.6461209,2.6599483,,,,,,,,,,,,,,,,, -117400,0.664458,2.665253,,,,,,,,,,,,,,,,, -117500,0.6708516,2.6890397,,,,,,,,,,,,,,,,, -117600,0.65255636,2.5792825,,,,,,,,,,,,,,,,, -117700,0.645276,2.588713,,,,,,,,,,,,,,,,, -117800,0.6377008,2.6128378,,,,,,,,,,,,,,,,, -117900,0.67259383,2.6169345,,,,,,,,,,,,,,,,, -118000,0.6412981,2.5451734,,,,,,,,,,,,,,,,, -118100,0.6911122,2.638119,,,,,,,,,,,,,,,,, -118200,0.6820326,2.5949023,,,,,,,,,,,,,,,,, -118300,0.6475033,2.6487513,,,,,,,,,,,,,,,,, -118400,0.68643105,2.6266458,,,,,,,,,,,,,,,,, -118500,0.6457895,2.6354492,,,,,,,,,,,,,,,,, -118600,0.6448856,2.60132,,,,,,,,,,,,,,,,, -118700,0.649504,2.6744826,,,,,,,,,,,,,,,,, -118800,0.6747824,2.6011808,,,,,,,,,,,,,,,,, -118900,0.67611057,2.5798752,,,,,,,,,,,,,,,,, -119000,0.6886612,2.674837,,,,,,,,,,,,,,,,, -119100,0.6723477,2.6079276,,,,,,,,,,,,,,,,, -119200,0.69500524,2.7010884,,,,,,,,,,,,,,,,, -119239,,,0.7087553143501282,1.4513678550720217,36.90472472820269,0.6937917470932007,1.5004736185073853,31.007339831048924,3000.0,0.7099180817604065,1.4181479215621948,30.699474406255312,3003.0,41207.07846236229,66267.013463974,41207.07846236229,25054.274214982983,1.8607072830200195,0.0 -119300,0.67399657,2.6035762,,,,,,,,,,,,,,,,, -119400,0.6880154,2.6386294,,,,,,,,,,,,,,,,, -119500,0.6708542,2.6291442,,,,,,,,,,,,,,,,, -119600,0.6602882,2.6304953,,,,,,,,,,,,,,,,, -119700,0.65643984,2.6376536,,,,,,,,,,,,,,,,, -119800,0.6869085,2.5624075,,,,,,,,,,,,,,,,, -119900,0.67142504,2.6247344,,,,,,,,,,,,,,,,, -120000,0.68443376,2.602614,,,,,,,,,,,,,,,,, -120100,0.6860633,2.639381,,,,,,,,,,,,,,,,, -120200,0.6920472,2.5941067,,,,,,,,,,,,,,,,, -120300,0.70421356,2.6610467,,,,,,,,,,,,,,,,, -120400,0.6860248,2.6596541,,,,,,,,,,,,,,,,, -120500,0.67236406,2.6240573,,,,,,,,,,,,,,,,, -120600,0.68296015,2.6787224,,,,,,,,,,,,,,,,, -120700,0.6844714,2.599471,,,,,,,,,,,,,,,,, -120800,0.70444804,2.6059096,,,,,,,,,,,,,,,,, -120900,0.72006726,2.6297526,,,,,,,,,,,,,,,,, -121000,0.70244265,2.562088,,,,,,,,,,,,,,,,, -121100,0.708877,2.5896778,,,,,,,,,,,,,,,,, -121200,0.68586063,2.6247315,,,,,,,,,,,,,,,,, -121300,0.6849692,2.544944,,,,,,,,,,,,,,,,, -121400,0.71046275,2.5976703,,,,,,,,,,,,,,,,, -121500,0.69974816,2.6090944,,,,,,,,,,,,,,,,, -121600,0.68123883,2.522913,,,,,,,,,,,,,,,,, -121672,,,0.7050728797912598,1.468165159225464,36.167227486174745,0.6942753195762634,1.498553991317749,30.98735750993361,3000.0,0.7114868760108948,1.4147039651870728,30.87680663083024,3003.0,42047.01487851143,67575.60860204697,42047.01487851143,25522.81412148476,1.9024908542633057,0.0 -121700,0.69731224,2.6310387,,,,,,,,,,,,,,,,, -121800,0.7313239,2.673307,,,,,,,,,,,,,,,,, -121900,0.68840873,2.6332958,,,,,,,,,,,,,,,,, -122000,0.6905536,2.5795958,,,,,,,,,,,,,,,,, -122100,0.73059845,2.538184,,,,,,,,,,,,,,,,, -122200,0.68625,2.6333754,,,,,,,,,,,,,,,,, -122300,0.7069116,2.5793512,,,,,,,,,,,,,,,,, -122400,0.73277813,2.563972,,,,,,,,,,,,,,,,, -122500,0.6726527,2.5717392,,,,,,,,,,,,,,,,, -122600,0.68648255,2.603793,,,,,,,,,,,,,,,,, -122700,0.70279443,2.6139991,,,,,,,,,,,,,,,,, -122800,0.70210636,2.639304,,,,,,,,,,,,,,,,, -122900,0.72713447,2.5847156,,,,,,,,,,,,,,,,, -123000,0.733029,2.6566143,,,,,,,,,,,,,,,,, -123100,0.7083381,2.6110053,,,,,,,,,,,,,,,,, -123200,0.7091617,2.552158,,,,,,,,,,,,,,,,, -123300,0.71986943,2.6075404,,,,,,,,,,,,,,,,, -123400,0.7233879,2.5996735,,,,,,,,,,,,,,,,, -123500,0.7500421,2.5721672,,,,,,,,,,,,,,,,, -123600,0.7129775,2.649002,,,,,,,,,,,,,,,,, -123700,0.7187927,2.5993278,,,,,,,,,,,,,,,,, -123800,0.68056583,2.57526,,,,,,,,,,,,,,,,, -123900,0.71936053,2.5951831,,,,,,,,,,,,,,,,, -124000,0.6976693,2.540912,,,,,,,,,,,,,,,,, -124100,0.75658065,2.5638943,,,,,,,,,,,,,,,,, -124104,,,0.7041686773300171,1.464862823486328,36.28302984249575,0.6946101188659668,1.495320200920105,30.891498149474604,3000.0,0.7113357782363892,1.4104185104370115,30.965012996729225,3003.0,42886.95123958588,68882.01002573967,42886.95123958588,25989.104485034943,1.9998183250427248,0.0 -124200,0.7009514,2.5893023,,,,,,,,,,,,,,,,, -124300,0.7407254,2.554493,,,,,,,,,,,,,,,,, -124400,0.74154896,2.623568,,,,,,,,,,,,,,,,, -124500,0.70944226,2.626973,,,,,,,,,,,,,,,,, -124600,0.6983223,2.615912,,,,,,,,,,,,,,,,, -124700,0.72372156,2.59397,,,,,,,,,,,,,,,,, -124800,0.74584794,2.6254833,,,,,,,,,,,,,,,,, -124900,0.69770026,2.6089694,,,,,,,,,,,,,,,,, -125000,0.68515,2.606928,,,,,,,,,,,,,,,,, -125100,0.73377246,2.5847375,,,,,,,,,,,,,,,,, -125200,0.7160908,2.5490892,,,,,,,,,,,,,,,,, -125300,0.69986206,2.5817058,,,,,,,,,,,,,,,,, -125400,0.7236822,2.6649945,,,,,,,,,,,,,,,,, -125500,0.7244582,2.6062317,,,,,,,,,,,,,,,,, -125600,0.71635026,2.629124,,,,,,,,,,,,,,,,, -125700,0.7170126,2.5750093,,,,,,,,,,,,,,,,, -125800,0.70392084,2.6168244,,,,,,,,,,,,,,,,, -125900,0.72510713,2.5927143,,,,,,,,,,,,,,,,, -126000,0.7204409,2.5968204,,,,,,,,,,,,,,,,, -126100,0.6910527,2.6141176,,,,,,,,,,,,,,,,, -126200,0.7366061,2.5588093,,,,,,,,,,,,,,,,, -126300,0.7461549,2.6873481,,,,,,,,,,,,,,,,, -126400,0.7339336,2.6066172,,,,,,,,,,,,,,,,, -126500,0.70435387,2.5871363,,,,,,,,,,,,,,,,, -126537,,,0.7092657685279846,1.449809193611145,36.99707375984359,0.693990170955658,1.496898889541626,30.97352039029193,3000.0,0.7120446562767029,1.4122170209884644,31.058713552731536,3003.0,43726.96578860283,70198.45528745651,43726.96578860283,26465.41655921936,2.0420727729797363,0.0 -126600,0.7044248,2.529949,,,,,,,,,,,,,,,,, -126700,0.71759635,2.5230103,,,,,,,,,,,,,,,,, -126800,0.75095326,2.5818524,,,,,,,,,,,,,,,,, -126900,0.7008444,2.5416665,,,,,,,,,,,,,,,,, -127000,0.7251836,2.5347316,,,,,,,,,,,,,,,,, -127100,0.7263775,2.560517,,,,,,,,,,,,,,,,, -127200,0.6980979,2.5460553,,,,,,,,,,,,,,,,, -127300,0.7139923,2.6079435,,,,,,,,,,,,,,,,, -127400,0.75464344,2.5350246,,,,,,,,,,,,,,,,, -127500,0.7249083,2.6349523,,,,,,,,,,,,,,,,, -127600,0.72494966,2.575395,,,,,,,,,,,,,,,,, -127700,0.7267903,2.602969,,,,,,,,,,,,,,,,, -127800,0.7361113,2.5747466,,,,,,,,,,,,,,,,, -127900,0.71859205,2.5611563,,,,,,,,,,,,,,,,, -128000,0.72039616,2.597434,,,,,,,,,,,,,,,,, -128100,0.74304956,2.6214936,,,,,,,,,,,,,,,,, -128200,0.69686484,2.6096883,,,,,,,,,,,,,,,,, -128300,0.71301186,2.6266963,,,,,,,,,,,,,,,,, -128400,0.72141033,2.5690963,,,,,,,,,,,,,,,,, -128500,0.7235224,2.5830834,,,,,,,,,,,,,,,,, -128600,0.7256461,2.6088192,,,,,,,,,,,,,,,,, -128700,0.72937053,2.6008036,,,,,,,,,,,,,,,,, -128800,0.7286293,2.6106236,,,,,,,,,,,,,,,,, -128900,0.69256026,2.5920064,,,,,,,,,,,,,,,,, -128970,,,0.7105950117111206,1.435489535331726,37.02532880316543,0.6939777731895447,1.4965522289276123,31.032989289148937,3000.0,0.7122886776924133,1.4110324382781982,30.86724253804981,3003.0,44566.927010297775,71499.1986413002,44566.927010297775,26926.0817000866,2.082228422164917,0.0 -129000,0.7043309,2.596786,,,,,,,,,,,,,,,,, -129100,0.72272134,2.548776,,,,,,,,,,,,,,,,, -129200,0.71112853,2.590703,,,,,,,,,,,,,,,,, -129300,0.70578146,2.5836203,,,,,,,,,,,,,,,,, -129400,0.72953403,2.561236,,,,,,,,,,,,,,,,, -129500,0.7015856,2.5513976,,,,,,,,,,,,,,,,, -129600,0.7011418,2.562884,,,,,,,,,,,,,,,,, -129700,0.71791834,2.6004837,,,,,,,,,,,,,,,,, -129800,0.7195572,2.5511389,,,,,,,,,,,,,,,,, -129900,0.71342605,2.6261072,,,,,,,,,,,,,,,,, -130000,0.740406,2.5305326,,,,,,,,,,,,,,,,, -130100,0.7244907,2.5997267,,,,,,,,,,,,,,,,, -130200,0.7316239,2.6086168,,,,,,,,,,,,,,,,, -130300,0.7044088,2.6391714,,,,,,,,,,,,,,,,, -130400,0.7065279,2.593871,,,,,,,,,,,,,,,,, -130500,0.7441197,2.594134,,,,,,,,,,,,,,,,, -130600,0.71032476,2.587895,,,,,,,,,,,,,,,,, -130700,0.7243642,2.5740767,,,,,,,,,,,,,,,,, -130800,0.71986127,2.6283162,,,,,,,,,,,,,,,,, -130900,0.69488704,2.5509658,,,,,,,,,,,,,,,,, -131000,0.7256914,2.542745,,,,,,,,,,,,,,,,, -131100,0.71179265,2.6133773,,,,,,,,,,,,,,,,, -131200,0.72651327,2.6186805,,,,,,,,,,,,,,,,, -131300,0.75471085,2.6463852,,,,,,,,,,,,,,,,, -131400,0.7130412,2.5711915,,,,,,,,,,,,,,,,, -131403,,,0.705418586730957,1.4653115272521973,36.74133977241949,0.6943249106407166,1.496314287185669,30.95072385384553,3000.0,0.7123119235038757,1.4104419946670532,30.963210714876272,3003.0,45407.02690410614,72805.75824856758,45407.02690410614,27392.42360830307,2.123875141143799,0.0 -131500,0.7207009,2.5752392,,,,,,,,,,,,,,,,, -131600,0.7319162,2.5869524,,,,,,,,,,,,,,,,, -131700,0.7350638,2.604468,,,,,,,,,,,,,,,,, -131800,0.71539116,2.6029284,,,,,,,,,,,,,,,,, -131900,0.7140505,2.5317962,,,,,,,,,,,,,,,,, -132000,0.72976995,2.6157246,,,,,,,,,,,,,,,,, -132100,0.70664674,2.5245128,,,,,,,,,,,,,,,,, -132200,0.71071917,2.5941384,,,,,,,,,,,,,,,,, -132300,0.73240477,2.5891454,,,,,,,,,,,,,,,,, -132400,0.7631851,2.6316643,,,,,,,,,,,,,,,,, -132500,0.7014834,2.5905836,,,,,,,,,,,,,,,,, -132600,0.72699875,2.5648546,,,,,,,,,,,,,,,,, -132700,0.7327249,2.6126697,,,,,,,,,,,,,,,,, -132800,0.73910034,2.6331663,,,,,,,,,,,,,,,,, -132900,0.71032965,2.6506166,,,,,,,,,,,,,,,,, -133000,0.6980224,2.6170251,,,,,,,,,,,,,,,,, -133100,0.71815425,2.559532,,,,,,,,,,,,,,,,, -133200,0.7102926,2.5731895,,,,,,,,,,,,,,,,, -133300,0.7179195,2.6467326,,,,,,,,,,,,,,,,, -133333,,,0.7103955149650574,1.443301796913147,36.525395183981935,0.6940893530845642,1.4960108995437622,30.98135216301977,3000.0,0.7121259570121765,1.4103668928146362,30.99405716021917,3003.0,46073.19484710693,73949.41149926186,46073.19484710693,27869.805367946625,2.167254447937012,0.0 -133333,,,,,,,,,,,,,,46073.19484710693,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/eval_measurements.csv deleted file mode 100644 index 605f59226..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/eval_measurements.csv +++ /dev/null @@ -1,57 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -856.4819233417511,0.0,26.1668484210968,1,0,26.1668484210968,0.0007088489946909,0.0,11.036273956298828,3003,882.6488099098206,0.0005942517309449,0.0,11.024877548217772,0.0004835649742744,0.0,11.047277450561523,3000 -1400.5178980827332,0.0198602676391601,866.3570308685303,2430,0,866.3570308685303,0.4070304036140442,9.262918884334642,4.187845706939697,3003,2266.9725909233093,0.4335348904132843,15.486008673345257,3.881221294403076,0.4170809984207153,10.90630801345595,4.0259270668029785,3000 -1882.8273282051089,0.0443954467773437,1706.3713157176971,4858,0,1706.3713157176971,0.5491255521774292,18.97626936557872,2.8148353099823,3003,3589.398926973343,0.5460978150367737,23.83795839443816,2.815052032470703,0.5495033860206604,20.46226812080449,2.779167413711548,3000 -2339.8702476024628,0.0700273513793945,2546.6103324890137,7289,0,2546.6103324890137,0.5970367789268494,22.367772014175003,2.35249662399292,3003,4886.784422636032,0.583653450012207,27.028060189117536,2.4412691593170166,0.5920695066452026,23.554423470160216,2.3615870475769043,3000 -2809.7746393680573,0.0957303047180175,3386.586151123047,9720,0,3386.586151123047,0.6241241097450256,24.38662733001904,2.124663829803467,3003,6196.767947673798,0.6012474894523621,29.023278515494034,2.2867743968963623,0.6191863417625427,25.267400374292624,2.1540048122406006,3000 -3383.849452972412,0.1213636398315429,4226.775480031967,12152,0,4226.775480031967,0.6393585801124573,25.511380252988697,2.005385637283325,3003,7611.135657072067,0.6117793917655945,29.49963286982012,2.2029600143432617,0.6326642036437988,26.453320841038845,2.0467686653137207,3000 -3971.3980989456177,0.1472148895263672,5066.86608338356,14583,0,5066.86608338356,0.6523270010948181,26.47813059844154,1.8937307596206665,3003,9038.87869143486,0.6235041618347168,30.151688486651715,2.090024471282959,0.642856240272522,26.98718339076088,1.947042465209961,3000 -4430.000137805939,0.1750760078430175,5906.852504491806,17014,0,5906.852504491806,0.6583580374717712,26.80384903284937,1.854627013206482,3003,10337.573017835615,0.6282114386558533,30.50936542867486,2.054261445999145,0.6478034853935242,27.620231439534933,1.9196758270263672,3000 -4911.174456119537,0.201505422592163,6746.784756422043,19445,0,6746.784756422043,0.6621463298797607,27.046071671892715,1.81953227519989,3003,11658.784047603607,0.6460630893707275,31.32886972123233,1.9484455585479736,0.652415931224823,27.88024937862478,1.8769932985305784,3000 -5370.643972635269,0.2301142215728759,7586.71210360527,21876,0,7586.71210360527,0.6679449081420898,27.51437928438289,1.7722675800323486,3003,12958.286858320236,0.6394797563552856,31.291298722145083,1.9793699979782104,0.6581319570541382,28.01187089919093,1.8350783586502075,3000 -6068.142309427261,0.2572338581085205,8426.640635728836,24307,0,8426.640635728836,0.672105073928833,27.47958904209218,1.7438435554504397,3003,14495.81839632988,0.6397490501403809,30.85258301009002,1.9547765254974363,0.6600165963172913,28.20852868398465,1.805528998374939,3000 -6762.721871137619,0.286259651184082,9266.780019283296,26738,0,9266.780019283296,0.6745221018791199,28.200531420440008,1.7301757335662842,3003,16030.644038438795,0.6513376832008362,31.97804613904664,1.8955870866775515,0.6653606295585632,28.451939799484045,1.7924036979675293,3000 -7236.388736963272,0.3159620761871338,10107.006126642227,29170,0,10107.006126642227,0.6792516708374023,28.4535721097046,1.7129313945770264,3003,17344.64367866516,0.6469125747680664,32.02554933303594,1.9222787618637085,0.6672824621200562,28.697036519419203,1.7752902507781982,3000 -7705.922129631042,0.3449394702911377,10947.211053609848,31601,0,10947.211053609848,0.679449200630188,28.46140712600896,1.6971747875213623,3003,18654.48917913437,0.6717249155044556,33.27960037029534,1.7598832845687866,0.6671584844589233,29.136006467228604,1.7688673734664917,3000 -8219.765256166458,0.3742711544036865,11787.294679403303,34032,0,11787.294679403303,0.6810644268989563,28.599915373846635,1.6823828220367432,3003,20008.52248167992,0.6559919714927673,32.682312093461405,1.8558120727539065,0.6694275140762329,29.291129779033984,1.746963381767273,3000 -8834.486620664597,0.4038436412811279,12627.530505895616,36463,0,12627.530505895616,0.682819128036499,29.09519909837025,1.669180154800415,3003,21463.58862900734,0.6533215045928955,32.04000187532934,1.8728480339050293,0.6712626218795776,28.41640521392595,1.741132140159607,3000 -9306.22245168686,0.4338409900665283,13467.513338565826,38894,0,13467.513338565826,0.6835280060768127,28.82156068715825,1.6568653583526611,3003,22775.415242671967,0.6603764891624451,32.47630884958098,1.8058222532272337,0.6701342463493347,29.079701147488368,1.730035662651062,3000 -9812.306602954865,0.464618444442749,14307.620192050934,41325,0,14307.620192050934,0.6861310005187988,29.236897732583365,1.6359333992004397,3003,24121.71455836296,0.6599708795547485,32.381106417776536,1.8201743364334104,0.6728744506835938,29.184639902057,1.712181806564331,3000 -10383.679992437364,0.4959902763366699,15147.648758888245,43756,0,15147.648758888245,0.6876997351646423,29.037590233653013,1.6349014043807983,3003,25533.22567415237,0.6520457863807678,32.20933811417118,1.86770761013031,0.6751186847686768,29.25031655424665,1.7147032022476196,3000 -10862.922046422958,0.5261423587799072,15987.78390431404,46189,0,15987.78390431404,0.6856428980827332,29.08447388429547,1.6498336791992188,3003,26852.710567712784,0.6639376878738403,32.48354802392069,1.786540985107422,0.6739283800125122,29.52103293408988,1.713200569152832,3000 -11336.640148878098,0.5586118698120117,16827.680357694626,48620,0,16827.680357694626,0.6883621215820312,29.228574269408423,1.642948031425476,3003,28166.43555521965,0.6577057242393494,32.11022349497067,1.8336910009384155,0.6760362386703491,29.52968287144964,1.7174301147460938,3000 -11802.670159101486,0.5905261039733887,17667.6850566864,51051,0,17667.6850566864,0.6865725517272949,28.930613904916136,1.620548129081726,3003,29472.580087661743,0.6678268909454346,33.25293797178625,1.75491201877594,0.6759990453720093,29.552817707166724,1.6939091682434082,3000 -12332.954099178314,0.6230416297912598,18507.86905694008,53483,0,18507.86905694008,0.693265974521637,29.551221567046127,1.6090433597564695,3003,30843.158007144928,0.6628263592720032,32.78826064342799,1.7955244779586792,0.6780325174331665,29.68220960991648,1.688191056251526,3000 -12796.398730516434,0.6548542976379395,19347.87039089203,55914,0,19347.87039089203,0.6930218935012817,29.3509889737326,1.5961912870407104,3003,32146.7134320736,0.6604369282722473,33.13701132138776,1.80417549610138,0.6777101159095764,29.61541145915985,1.676916241645813,3000 -13610.639861106873,0.6897470951080322,20187.9033575058,58348,0,20187.9033575058,0.3172041177749634,0.2760901163922393,4.3111090660095215,3003,33801.09925675392,0.3733042478561401,1.7790266662847811,3.6769845485687256,0.3269147276878357,0.3167417668719349,4.14604377746582,3000 -14126.851303100586,0.7242088317871094,21027.7988653183,60779,0,21027.7988653183,0.6925687193870544,29.86454461957114,1.6012877225875854,3003,35157.31735706329,0.6671022176742554,32.896878208625374,1.770507574081421,0.6774373650550842,29.561452728455883,1.6847100257873535,3000 -14592.508635282516,0.7560958862304688,21867.882091760635,63211,0,21867.882091760635,0.6939283013343811,29.44425505315059,1.5856915712356567,3003,36463.16767692566,0.6855831146240234,34.1527935497728,1.6496340036392212,0.679582417011261,29.659998515776348,1.6690008640289309,3000 -15051.599786758425,0.7890956401824951,22707.863388299946,65642,0,22707.863388299946,0.6969264149665833,29.88272037242085,1.5698133707046509,3003,37762.35045266152,0.6722832322120667,33.70603561239361,1.7180349826812744,0.681913435459137,30.087363587995423,1.6529948711395264,3000 -15502.837298870088,0.822166919708252,23547.92662167549,68073,0,23547.92662167549,0.6972169280052185,29.769703618118545,1.5671659708023071,3003,39053.7624464035,0.6699897646903992,33.207613914396966,1.7524583339691162,0.6819630265235901,30.112537396175984,1.6551176309585571,3000 -15983.3171210289,0.857762336730957,24387.89455485344,70504,0,24387.89455485344,0.6974377036094666,30.00090817367382,1.5637001991271973,3003,40374.32408332825,0.6783646941184998,34.29033621037187,1.6977791786193848,0.6820374131202698,30.249611344394367,1.6467190980911257,3000 -16439.81956934929,0.8949284553527832,25227.994698286057,72935,0,25227.994698286057,0.6994480490684509,29.76888711725281,1.5574828386306765,3003,41671.04165291786,0.6749582886695862,33.6660576168447,1.7154207229614258,0.6837732791900635,29.911717014627985,1.642994403839111,3000 -16958.04394555092,0.9335498809814452,26068.055867910385,75366,0,26068.055867910385,0.6980187296867371,29.991422765336544,1.5505318641662598,3003,43029.4435377121,0.701172411441803,35.91140519112401,1.5690522193908691,0.6847156286239624,30.37538583212888,1.6373581886291504,3000 -17410.868065595627,0.9702820777893066,26908.02235507965,77797,0,26908.02235507965,0.701714038848877,30.3965153832504,1.5495461225509644,3003,44322.34855294228,0.6788302659988403,34.17489567231058,1.6935021877288818,0.6846908330917358,30.014238840241926,1.637537956237793,3000 -17856.037237882614,1.008216142654419,27748.175061941147,80228,0,27748.175061941147,0.7003777027130127,30.21403086923635,1.542347073554993,3003,45607.78610897064,0.6816733479499817,34.12771578283215,1.680078387260437,0.6865506768226624,30.30910404470047,1.6328672170639038,3000 -18366.711642980576,1.0443508625030518,28588.085990428925,82659,0,28588.085990428925,0.7024344801902771,30.29900160524727,1.531815767288208,3003,46958.48557043076,0.6934168934822083,34.86383322748257,1.601096272468567,0.6877533793449402,30.61256790425198,1.6192811727523804,3000 -18848.596970558167,1.0820410251617432,29428.2906563282,85091,0,29428.2906563282,0.7035732865333557,30.34697045957415,1.5303460359573364,3003,48280.69108605385,0.6842333674430847,34.231418554266256,1.6667418479919434,0.6870714426040649,30.22605562863886,1.6234965324401855,3000 -19316.526229143143,1.1193876266479492,30268.27372980117,87522,0,30268.27372980117,0.7043286561965942,30.191210875242326,1.5281420946121216,3003,49588.71892952919,0.6859403252601624,34.449619726829866,1.6498194932937622,0.6880509853363037,30.61483554539125,1.620361566543579,3000 -19782.92421078682,1.1572742462158203,31108.289889335632,89953,0,31108.289889335632,0.7040846347808838,30.366536884259617,1.5222641229629517,3003,50895.24898195267,0.6895501613616943,35.17824123678239,1.6185948848724363,0.6874062418937683,30.320178766417413,1.6165112257003784,3000 -20263.464618206024,1.1962149143218994,31948.42696595192,92384,0,31948.42696595192,0.7052234411239624,30.455149942129623,1.5193902254104614,3003,52216.04299545288,0.690657913684845,35.047649936271675,1.6202352046966553,0.6881749629974365,30.26145651209316,1.61602783203125,3000 -20743.521106243134,1.2359349727630615,32788.33744096756,94815,0,32788.33744096756,0.7068154215812683,30.62494146013703,1.5066797733306885,3003,53536.12700033188,0.7034842371940613,35.83842989432134,1.538518309593201,0.6895264983177185,30.30097926569033,1.606481432914734,3000 -21208.848113775253,1.274146318435669,33628.37856912613,97246,0,33628.37856912613,0.706466794013977,30.46507838062227,1.511273980140686,3003,54841.610292196274,0.6954914927482605,35.719509026518935,1.5836246013641355,0.6894396543502808,30.44411971751765,1.608036756515503,3000 -21675.62931656837,1.3121390342712402,34468.591000556946,99677,0,34468.591000556946,0.7069316506385803,30.37346591024028,1.5114892721176147,3003,56148.71933174133,0.6979652643203735,35.63073779555025,1.5806125402450562,0.6885221600532532,30.35029837043829,1.6104274988174438,3000 -22138.22982096672,1.3539230823516846,35308.638811826706,102108,0,35308.638811826706,0.7062808871269226,30.5054449175256,1.5039721727371216,3003,57451.48646616936,0.7036283016204834,36.0726587683943,1.5418643951416016,0.689315676689148,30.482721601595376,1.6059999465942385,3000 -22614.202502965927,1.394146203994751,36148.625277519226,104539,0,36148.625277519226,0.7068619132041931,30.359611628316312,1.5041028261184692,3003,58767.56322598457,0.6998471617698669,35.92477397313134,1.5669548511505127,0.6904687881469727,30.45298341634368,1.6009769439697266,3000 -23122.549444437027,1.4340860843658447,36988.803647995,106971,0,36988.803647995,0.7066411375999451,30.1928181417836,1.503878474235535,3003,60116.20536804199,0.7189033031463623,37.323135864973146,1.4635096788406372,0.6899852156639099,30.498854117375075,1.603049635887146,3000 -23569.930787324905,1.4757776260375977,37828.95445275307,109401,0,37828.95445275307,0.7091511487960815,30.79524511078754,1.4960204362869265,3003,61403.85651016235,0.7105783224105835,36.803288837604285,1.5062320232391355,0.6902704238891602,30.68162179003621,1.5982686281204224,3000 -24032.82445716858,1.5178844928741455,38669.03541469574,111832,0,38669.03541469574,0.7090930342674255,30.644621518964943,1.5006964206695557,3003,62706.950540065765,0.7118147611618042,36.46637331018224,1.5033549070358276,0.6907911896705627,30.68753003150424,1.6010395288467407,3000 -24482.726551771164,1.5615882873535156,39509.11229848862,114263,0,39509.11229848862,0.708512008190155,30.55669261861972,1.4953467845916748,3003,63997.05086636543,0.7177315354347229,37.12487177143348,1.4715121984481812,0.6904315948486328,30.83881345974083,1.597692847251892,3000 -24945.666995048523,1.6028475761413574,40349.132508039474,116694,0,40349.132508039474,0.7089884281158447,30.72359865724317,1.5000479221343994,3003,65300.12974977493,0.7157125473022461,36.77640951803491,1.4814000129699707,0.6911507248878479,30.63074329345861,1.601112961769104,3000 -25411.930948019028,1.6457676887512207,41189.27968287468,119125,0,41189.27968287468,0.7088141441345215,30.61001220758412,1.4978607892990112,3003,66606.66089344025,0.7265663146972656,37.28005875910896,1.425704002380371,0.6910267472267151,30.580480746826456,1.5974948406219482,3000 -25877.113832235336,1.688605785369873,42029.3257484436,121555,0,42029.3257484436,0.7092208862304688,30.70075004627944,1.4971085786819458,3003,67912.01016974449,0.722062885761261,37.89198746749157,1.4466272592544556,0.6907168030738831,30.64146443986613,1.5995148420333862,3000 -26331.379861593246,1.732133388519287,42869.3330552578,123986,0,42869.3330552578,0.7090930342674255,30.69339181447463,1.4954748153686523,3003,69206.40446782112,0.7214351892471313,37.47774055464697,1.4518346786499023,0.6916838884353638,30.64448680287795,1.5992367267608645,3000 -26785.094694375992,1.7746250629425049,43709.52652788162,126417,0,43709.52652788162,0.7093719244003296,30.82672499037724,1.4938480854034424,3003,70500.4326248169,0.7267507910728455,38.03936471702318,1.4239526987075806,0.6922666430473328,30.60053445338147,1.5971962213516235,3000 -27242.02351140976,1.816523790359497,44549.70166611672,128848,0,44549.70166611672,0.709511399269104,30.74373078206694,1.4947388172149658,3003,71797.655200243,0.7229095101356506,37.66247762215224,1.4474419355392456,0.691745936870575,30.603091854338725,1.5984151363372805,3000 -27706.21432876587,1.859803915023804,45389.855187654495,131279,0,45389.855187654495,0.7093486785888672,30.861089142384305,1.4943861961364746,3003,73102.1201004982,0.7239293456077576,37.89795779107517,1.431621551513672,0.6918451189994812,30.610609263383143,1.5985901355743408,3000 -28152.985813856125,1.9054770469665527,46099.627485990524,133333,0,46099.627485990524,0.7095927000045776,30.79392196122609,1.4941442012786865,3003,74258.77504205704,0.7237716913223267,37.902471942390434,1.4409384727478027,0.691882312297821,30.55257145746482,1.59796142578125,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/measurements.csv deleted file mode 100644 index 0d342ea3f..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/measurements.csv +++ /dev/null @@ -1,1392 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.656282,11.028264,,,,,,,,,,,,,,,,, -1,,,0.0005942517309449,11.024877548217772,0.0,0.0004835649742744,11.047277450561523,0.0,3000.0,0.0007088489946909,11.036273956298828,0.0,3003.0,26.1668484210968,882.6488099098206,26.1668484210968,856.4819233417511,0.0,0.0 -100,0.24970856,9.060127,,,,,,,,,,,,,,,,, -200,0.25675812,8.737283,,,,,,,,,,,,,,,,, -300,0.4785685,8.342386,,,,,,,,,,,,,,,,, -400,0.5125925,8.048563,,,,,,,,,,,,,,,,, -500,0.93003225,7.8140526,,,,,,,,,,,,,,,,, -600,0.8229386,7.594801,,,,,,,,,,,,,,,,, -700,0.6472436,7.4801617,,,,,,,,,,,,,,,,, -800,0.75284845,7.26868,,,,,,,,,,,,,,,,, -900,0.5833507,7.114068,,,,,,,,,,,,,,,,, -1000,0.6050288,6.9823,,,,,,,,,,,,,,,,, -1100,0.6585421,6.890654,,,,,,,,,,,,,,,,, -1200,0.5036805,6.6712728,,,,,,,,,,,,,,,,, -1300,0.62357837,6.6476755,,,,,,,,,,,,,,,,, -1400,0.64420974,6.51519,,,,,,,,,,,,,,,,, -1500,0.6198375,6.4246235,,,,,,,,,,,,,,,,, -1600,0.5220594,6.33146,,,,,,,,,,,,,,,,, -1700,0.62855375,6.3015404,,,,,,,,,,,,,,,,, -1800,0.63655394,6.0978265,,,,,,,,,,,,,,,,, -1900,0.5143372,6.0576086,,,,,,,,,,,,,,,,, -2000,0.71852314,6.013935,,,,,,,,,,,,,,,,, -2100,0.6577823,5.7830005,,,,,,,,,,,,,,,,, -2200,0.61547065,5.789316,,,,,,,,,,,,,,,,, -2300,0.55488753,5.636215,,,,,,,,,,,,,,,,, -2400,0.63817644,5.5819387,,,,,,,,,,,,,,,,, -2430,,,0.4335348904132843,3.881221294403076,15.486008673345257,0.4170809984207153,4.0259270668029785,10.90630801345595,3000.0,0.4070304036140442,4.187845706939697,9.262918884334642,3003.0,866.3570308685303,2266.9725909233093,866.3570308685303,1400.5178980827332,0.0198602676391601,0.0 -2500,0.49179155,5.518532,,,,,,,,,,,,,,,,, -2600,0.6080176,5.4033227,,,,,,,,,,,,,,,,, -2700,0.74202436,5.398664,,,,,,,,,,,,,,,,, -2800,0.7009761,5.358186,,,,,,,,,,,,,,,,, -2900,0.5825669,5.2138004,,,,,,,,,,,,,,,,, -3000,0.5221297,5.229901,,,,,,,,,,,,,,,,, -3100,0.672671,5.2311273,,,,,,,,,,,,,,,,, -3200,0.5249581,5.1405916,,,,,,,,,,,,,,,,, -3300,0.5225136,5.077334,,,,,,,,,,,,,,,,, -3400,0.5296397,5.078788,,,,,,,,,,,,,,,,, -3500,0.48407632,4.952955,,,,,,,,,,,,,,,,, -3600,0.48212674,5.0151744,,,,,,,,,,,,,,,,, -3700,0.42927384,4.919651,,,,,,,,,,,,,,,,, -3800,0.5885478,4.9424515,,,,,,,,,,,,,,,,, -3900,0.43089685,4.8747106,,,,,,,,,,,,,,,,, -4000,0.4439744,4.918001,,,,,,,,,,,,,,,,, -4100,0.55777156,4.836022,,,,,,,,,,,,,,,,, -4200,0.42313033,4.864248,,,,,,,,,,,,,,,,, -4300,0.37321055,4.7806,,,,,,,,,,,,,,,,, -4400,0.42991903,4.7477756,,,,,,,,,,,,,,,,, -4500,0.3828291,4.6951327,,,,,,,,,,,,,,,,, -4600,0.43577486,4.752388,,,,,,,,,,,,,,,,, -4700,0.36099043,4.6822715,,,,,,,,,,,,,,,,, -4800,0.40406695,4.777835,,,,,,,,,,,,,,,,, -4858,,,0.5460978150367737,2.815052032470703,23.83795839443816,0.5495033860206604,2.779167413711548,20.46226812080449,3000.0,0.5491255521774292,2.8148353099823,18.97626936557872,3003.0,1706.3713157176971,3589.398926973343,1706.3713157176971,1882.8273282051089,0.0443954467773437,0.0 -4900,0.4385088,4.6959367,,,,,,,,,,,,,,,,, -5000,0.35837,4.6541524,,,,,,,,,,,,,,,,, -5100,0.37256655,4.6547885,,,,,,,,,,,,,,,,, -5200,0.38871658,4.7000904,,,,,,,,,,,,,,,,, -5300,0.3449119,4.6216426,,,,,,,,,,,,,,,,, -5400,0.32819194,4.6060333,,,,,,,,,,,,,,,,, -5500,0.37002355,4.6396713,,,,,,,,,,,,,,,,, -5600,0.3780744,4.5937195,,,,,,,,,,,,,,,,, -5700,0.36134058,4.5863957,,,,,,,,,,,,,,,,, -5800,0.33724973,4.557417,,,,,,,,,,,,,,,,, -5900,0.33655638,4.5410695,,,,,,,,,,,,,,,,, -6000,0.31709638,4.5775957,,,,,,,,,,,,,,,,, -6100,0.33898535,4.609502,,,,,,,,,,,,,,,,, -6200,0.31921685,4.5978274,,,,,,,,,,,,,,,,, -6300,0.28950605,4.572294,,,,,,,,,,,,,,,,, -6400,0.30981928,4.5223465,,,,,,,,,,,,,,,,, -6500,0.29485187,4.496591,,,,,,,,,,,,,,,,, -6600,0.29737028,4.4610777,,,,,,,,,,,,,,,,, -6700,0.27773675,4.4207945,,,,,,,,,,,,,,,,, -6800,0.27275217,4.4713063,,,,,,,,,,,,,,,,, -6900,0.25504676,4.458575,,,,,,,,,,,,,,,,, -7000,0.33406526,4.5501227,,,,,,,,,,,,,,,,, -7100,0.2569602,4.4085574,,,,,,,,,,,,,,,,, -7200,0.26265523,4.3848486,,,,,,,,,,,,,,,,, -7289,,,0.583653450012207,2.4412691593170166,27.028060189117536,0.5920695066452026,2.3615870475769043,23.554423470160216,3000.0,0.5970367789268494,2.35249662399292,22.367772014175003,3003.0,2546.6103324890137,4886.784422636032,2546.6103324890137,2339.8702476024628,0.0700273513793945,0.0 -7300,0.251458,4.35499,,,,,,,,,,,,,,,,, -7400,0.23995318,4.4287953,,,,,,,,,,,,,,,,, -7500,0.2299421,4.387822,,,,,,,,,,,,,,,,, -7600,0.23130456,4.4136596,,,,,,,,,,,,,,,,, -7700,0.22411247,4.4057045,,,,,,,,,,,,,,,,, -7800,0.21740465,4.288665,,,,,,,,,,,,,,,,, -7900,0.22643517,4.400508,,,,,,,,,,,,,,,,, -8000,0.21689557,4.399694,,,,,,,,,,,,,,,,, -8100,0.20273648,4.399587,,,,,,,,,,,,,,,,, -8200,0.21883182,4.3483706,,,,,,,,,,,,,,,,, -8300,0.21390258,4.332186,,,,,,,,,,,,,,,,, -8400,0.20977163,4.41306,,,,,,,,,,,,,,,,, -8500,0.20173243,4.3619666,,,,,,,,,,,,,,,,, -8600,0.24417591,4.306021,,,,,,,,,,,,,,,,, -8700,0.20623906,4.272008,,,,,,,,,,,,,,,,, -8800,0.23524359,4.3623314,,,,,,,,,,,,,,,,, -8900,0.21520326,4.2933326,,,,,,,,,,,,,,,,, -9000,0.21641186,4.3977523,,,,,,,,,,,,,,,,, -9100,0.2140721,4.306321,,,,,,,,,,,,,,,,, -9200,0.23645599,4.40175,,,,,,,,,,,,,,,,, -9300,0.2093663,4.2977276,,,,,,,,,,,,,,,,, -9400,0.2208344,4.350485,,,,,,,,,,,,,,,,, -9500,0.18286997,4.273016,,,,,,,,,,,,,,,,, -9600,0.20184384,4.2957,,,,,,,,,,,,,,,,, -9700,0.17896527,4.2407126,,,,,,,,,,,,,,,,, -9720,,,0.6012474894523621,2.2867743968963623,29.023278515494034,0.6191863417625427,2.1540048122406006,25.267400374292624,3000.0,0.6241241097450256,2.124663829803467,24.38662733001904,3003.0,3386.586151123047,6196.767947673798,3386.586151123047,2809.7746393680573,0.0957303047180175,0.0 -9800,0.21417764,4.270072,,,,,,,,,,,,,,,,, -9900,0.17594631,4.202403,,,,,,,,,,,,,,,,, -10000,0.1766806,4.274723,,,,,,,,,,,,,,,,, -10100,0.1894864,4.2261567,,,,,,,,,,,,,,,,, -10200,0.21253924,4.280235,,,,,,,,,,,,,,,,, -10300,0.1849691,4.1938148,,,,,,,,,,,,,,,,, -10400,0.17299335,4.2105446,,,,,,,,,,,,,,,,, -10500,0.17295864,4.2779574,,,,,,,,,,,,,,,,, -10600,0.17536801,4.181812,,,,,,,,,,,,,,,,, -10700,0.17627874,4.229937,,,,,,,,,,,,,,,,, -10800,0.18459095,4.26055,,,,,,,,,,,,,,,,, -10900,0.18001473,4.2544436,,,,,,,,,,,,,,,,, -11000,0.17849201,4.292294,,,,,,,,,,,,,,,,, -11100,0.20011128,4.1831126,,,,,,,,,,,,,,,,, -11200,0.1883405,4.237778,,,,,,,,,,,,,,,,, -11300,0.24342701,4.2109528,,,,,,,,,,,,,,,,, -11400,0.1673832,4.1639504,,,,,,,,,,,,,,,,, -11500,0.16630238,4.263939,,,,,,,,,,,,,,,,, -11600,0.1728169,4.2713146,,,,,,,,,,,,,,,,, -11700,0.16455585,4.1544027,,,,,,,,,,,,,,,,, -11800,0.18624689,4.243899,,,,,,,,,,,,,,,,, -11900,0.1875267,4.1910486,,,,,,,,,,,,,,,,, -12000,0.16727243,4.2512293,,,,,,,,,,,,,,,,, -12100,0.16999733,4.1555004,,,,,,,,,,,,,,,,, -12152,,,0.6117793917655945,2.2029600143432617,29.49963286982012,0.6326642036437988,2.0467686653137207,26.453320841038845,3000.0,0.6393585801124573,2.005385637283325,25.511380252988697,3003.0,4226.775480031967,7611.135657072067,4226.775480031967,3383.849452972412,0.1213636398315429,0.0 -12200,0.21005332,4.207713,,,,,,,,,,,,,,,,, -12300,0.17951657,4.2291536,,,,,,,,,,,,,,,,, -12400,0.17202514,4.1733828,,,,,,,,,,,,,,,,, -12500,0.18931405,4.1810913,,,,,,,,,,,,,,,,, -12600,0.17104864,4.119153,,,,,,,,,,,,,,,,, -12700,0.16051687,4.1047254,,,,,,,,,,,,,,,,, -12800,0.16361412,4.1872797,,,,,,,,,,,,,,,,, -12900,0.17719954,4.1647573,,,,,,,,,,,,,,,,, -13000,0.22839645,4.144459,,,,,,,,,,,,,,,,, -13100,0.17565618,4.1707563,,,,,,,,,,,,,,,,, -13200,0.17626207,4.1657887,,,,,,,,,,,,,,,,, -13300,0.15766591,4.164898,,,,,,,,,,,,,,,,, -13400,0.15796216,4.1916223,,,,,,,,,,,,,,,,, -13500,0.15746333,4.1230793,,,,,,,,,,,,,,,,, -13600,0.1949547,4.146139,,,,,,,,,,,,,,,,, -13700,0.18312311,4.1344447,,,,,,,,,,,,,,,,, -13800,0.17996444,4.158989,,,,,,,,,,,,,,,,, -13900,0.17600772,4.2015524,,,,,,,,,,,,,,,,, -14000,0.17153619,4.071762,,,,,,,,,,,,,,,,, -14100,0.16580199,4.129162,,,,,,,,,,,,,,,,, -14200,0.1963608,4.1020904,,,,,,,,,,,,,,,,, -14300,0.2048936,4.114494,,,,,,,,,,,,,,,,, -14400,0.1745205,4.155298,,,,,,,,,,,,,,,,, -14500,0.16918217,4.1372294,,,,,,,,,,,,,,,,, -14583,,,0.6235041618347168,2.090024471282959,30.151688486651715,0.642856240272522,1.947042465209961,26.98718339076088,3000.0,0.6523270010948181,1.8937307596206665,26.47813059844154,3003.0,5066.86608338356,9038.87869143486,5066.86608338356,3971.3980989456177,0.1472148895263672,0.0 -14600,0.16011553,4.2046375,,,,,,,,,,,,,,,,, -14700,0.15501858,4.109436,,,,,,,,,,,,,,,,, -14800,0.16599402,4.1364307,,,,,,,,,,,,,,,,, -14900,0.16730542,4.091846,,,,,,,,,,,,,,,,, -15000,0.18438856,4.0491076,,,,,,,,,,,,,,,,, -15100,0.18838039,4.161539,,,,,,,,,,,,,,,,, -15200,0.16953261,4.135013,,,,,,,,,,,,,,,,, -15300,0.20341612,4.1168647,,,,,,,,,,,,,,,,, -15400,0.15195197,4.0471606,,,,,,,,,,,,,,,,, -15500,0.16516212,4.1364994,,,,,,,,,,,,,,,,, -15600,0.1747058,4.1166124,,,,,,,,,,,,,,,,, -15700,0.17914405,4.0960913,,,,,,,,,,,,,,,,, -15800,0.1644029,4.0451117,,,,,,,,,,,,,,,,, -15900,0.15753882,4.12002,,,,,,,,,,,,,,,,, -16000,0.16569768,4.0885215,,,,,,,,,,,,,,,,, -16100,0.16983774,4.158625,,,,,,,,,,,,,,,,, -16200,0.16392876,4.0608587,,,,,,,,,,,,,,,,, -16300,0.23734646,4.063591,,,,,,,,,,,,,,,,, -16400,0.16386147,4.018138,,,,,,,,,,,,,,,,, -16500,0.16814956,4.067452,,,,,,,,,,,,,,,,, -16600,0.18735825,4.0699954,,,,,,,,,,,,,,,,, -16700,0.17039023,4.0698614,,,,,,,,,,,,,,,,, -16800,0.18261552,4.095457,,,,,,,,,,,,,,,,, -16900,0.15989794,4.072461,,,,,,,,,,,,,,,,, -17000,0.17133856,4.0736594,,,,,,,,,,,,,,,,, -17014,,,0.6282114386558533,2.054261445999145,30.50936542867486,0.6478034853935242,1.9196758270263672,27.620231439534933,3000.0,0.6583580374717712,1.854627013206482,26.80384903284937,3003.0,5906.852504491806,10337.573017835615,5906.852504491806,4430.000137805939,0.1750760078430175,0.0 -17100,0.22397046,4.106982,,,,,,,,,,,,,,,,, -17200,0.17839296,4.0430017,,,,,,,,,,,,,,,,, -17300,0.17969383,4.1035852,,,,,,,,,,,,,,,,, -17400,0.16661946,4.0539837,,,,,,,,,,,,,,,,, -17500,0.16288038,4.069442,,,,,,,,,,,,,,,,, -17600,0.15677524,4.1036716,,,,,,,,,,,,,,,,, -17700,0.19505228,4.0578275,,,,,,,,,,,,,,,,, -17800,0.14796223,4.0442877,,,,,,,,,,,,,,,,, -17900,0.17276518,4.078544,,,,,,,,,,,,,,,,, -18000,0.18592007,4.059007,,,,,,,,,,,,,,,,, -18100,0.2188565,4.094707,,,,,,,,,,,,,,,,, -18200,0.19737786,4.0608187,,,,,,,,,,,,,,,,, -18300,0.28009853,4.097874,,,,,,,,,,,,,,,,, -18400,0.16060059,4.0810857,,,,,,,,,,,,,,,,, -18500,0.17793378,4.0848475,,,,,,,,,,,,,,,,, -18600,0.16498175,4.0384755,,,,,,,,,,,,,,,,, -18700,0.1534776,4.024473,,,,,,,,,,,,,,,,, -18800,0.20284268,4.118767,,,,,,,,,,,,,,,,, -18900,0.22456998,3.9655457,,,,,,,,,,,,,,,,, -19000,0.15381739,4.0170245,,,,,,,,,,,,,,,,, -19100,0.15414624,4.0520663,,,,,,,,,,,,,,,,, -19200,0.26634815,4.040752,,,,,,,,,,,,,,,,, -19300,0.17494135,4.029645,,,,,,,,,,,,,,,,, -19400,0.17273721,4.0612783,,,,,,,,,,,,,,,,, -19445,,,0.6460630893707275,1.9484455585479736,31.32886972123233,0.652415931224823,1.8769932985305784,27.88024937862478,3000.0,0.6621463298797607,1.81953227519989,27.046071671892715,3003.0,6746.784756422043,11658.784047603607,6746.784756422043,4911.174456119537,0.201505422592163,0.0 -19500,0.16173108,4.0412507,,,,,,,,,,,,,,,,, -19600,0.1775547,4.017233,,,,,,,,,,,,,,,,, -19700,0.1590311,4.053435,,,,,,,,,,,,,,,,, -19800,0.18357739,4.120497,,,,,,,,,,,,,,,,, -19900,0.21766149,4.048386,,,,,,,,,,,,,,,,, -20000,0.19097242,4.049173,,,,,,,,,,,,,,,,, -20100,0.26440367,4.0044985,,,,,,,,,,,,,,,,, -20200,0.16152573,4.014699,,,,,,,,,,,,,,,,, -20300,0.1984797,4.0371757,,,,,,,,,,,,,,,,, -20400,0.15613677,4.0782127,,,,,,,,,,,,,,,,, -20500,0.2266611,4.037351,,,,,,,,,,,,,,,,, -20600,0.16723602,4.066018,,,,,,,,,,,,,,,,, -20700,0.1629171,4.03049,,,,,,,,,,,,,,,,, -20800,0.2110301,4.0079474,,,,,,,,,,,,,,,,, -20900,0.16629468,3.9711437,,,,,,,,,,,,,,,,, -21000,0.21106564,3.9863682,,,,,,,,,,,,,,,,, -21100,0.18690342,3.9871566,,,,,,,,,,,,,,,,, -21200,0.1674266,4.00262,,,,,,,,,,,,,,,,, -21300,0.25887522,4.0193,,,,,,,,,,,,,,,,, -21400,0.15951854,4.023148,,,,,,,,,,,,,,,,, -21500,0.16978285,4.0288515,,,,,,,,,,,,,,,,, -21600,0.19187488,4.0005875,,,,,,,,,,,,,,,,, -21700,0.1788812,3.9743862,,,,,,,,,,,,,,,,, -21800,0.223,4.0114045,,,,,,,,,,,,,,,,, -21876,,,0.6394797563552856,1.9793699979782104,31.291298722145083,0.6581319570541382,1.8350783586502075,28.01187089919093,3000.0,0.6679449081420898,1.7722675800323486,27.51437928438289,3003.0,7586.71210360527,12958.286858320236,7586.71210360527,5370.643972635269,0.2301142215728759,0.0 -21900,0.21864478,4.075962,,,,,,,,,,,,,,,,, -22000,0.17627187,4.0491076,,,,,,,,,,,,,,,,, -22100,0.17378405,4.042076,,,,,,,,,,,,,,,,, -22200,0.2875796,4.0845404,,,,,,,,,,,,,,,,, -22300,0.17595407,3.9377294,,,,,,,,,,,,,,,,, -22400,0.18519011,4.0563354,,,,,,,,,,,,,,,,, -22500,0.15978678,4.036182,,,,,,,,,,,,,,,,, -22600,0.17236903,3.9852104,,,,,,,,,,,,,,,,, -22700,0.22103433,4.035254,,,,,,,,,,,,,,,,, -22800,0.17548445,3.987738,,,,,,,,,,,,,,,,, -22900,0.21858485,4.008193,,,,,,,,,,,,,,,,, -23000,0.18406892,4.0698504,,,,,,,,,,,,,,,,, -23100,0.16984168,4.022645,,,,,,,,,,,,,,,,, -23200,0.16837691,3.9796262,,,,,,,,,,,,,,,,, -23300,0.18242075,3.9994905,,,,,,,,,,,,,,,,, -23400,0.20624758,4.0021844,,,,,,,,,,,,,,,,, -23500,0.16886692,3.9833488,,,,,,,,,,,,,,,,, -23600,0.23516046,3.9941921,,,,,,,,,,,,,,,,, -23700,0.17489274,3.9532404,,,,,,,,,,,,,,,,, -23800,0.17544031,4.009988,,,,,,,,,,,,,,,,, -23900,0.18476997,4.069125,,,,,,,,,,,,,,,,, -24000,0.19071813,3.961437,,,,,,,,,,,,,,,,, -24100,0.29238242,4.024887,,,,,,,,,,,,,,,,, -24200,0.2981303,4.029534,,,,,,,,,,,,,,,,, -24300,0.20162284,4.0132847,,,,,,,,,,,,,,,,, -24307,,,0.6397490501403809,1.9547765254974363,30.85258301009002,0.6600165963172913,1.805528998374939,28.20852868398465,3000.0,0.672105073928833,1.7438435554504397,27.47958904209218,3003.0,8426.640635728836,14495.81839632988,8426.640635728836,6068.142309427261,0.2572338581085205,0.0 -24400,0.20736061,3.9945493,,,,,,,,,,,,,,,,, -24500,0.1938085,4.00246,,,,,,,,,,,,,,,,, -24600,0.20785256,4.030732,,,,,,,,,,,,,,,,, -24700,0.18445498,3.9950798,,,,,,,,,,,,,,,,, -24800,0.1793054,4.003248,,,,,,,,,,,,,,,,, -24900,0.2043523,3.9562354,,,,,,,,,,,,,,,,, -25000,0.18533427,3.94858,,,,,,,,,,,,,,,,, -25100,0.18465137,3.9796326,,,,,,,,,,,,,,,,, -25200,0.20894103,3.9932477,,,,,,,,,,,,,,,,, -25300,0.2547198,4.067405,,,,,,,,,,,,,,,,, -25400,0.17790636,3.9645529,,,,,,,,,,,,,,,,, -25500,0.2358421,3.9026232,,,,,,,,,,,,,,,,, -25600,0.20024422,3.9284565,,,,,,,,,,,,,,,,, -25700,0.26882324,4.007994,,,,,,,,,,,,,,,,, -25800,0.17953885,4.08826,,,,,,,,,,,,,,,,, -25900,0.20887046,3.92022,,,,,,,,,,,,,,,,, -26000,0.44421676,3.9406543,,,,,,,,,,,,,,,,, -26100,0.17671277,3.9393086,,,,,,,,,,,,,,,,, -26200,0.17829128,4.003883,,,,,,,,,,,,,,,,, -26300,0.17857139,3.9631383,,,,,,,,,,,,,,,,, -26400,0.18987833,4.0042367,,,,,,,,,,,,,,,,, -26500,0.17567907,3.9681368,,,,,,,,,,,,,,,,, -26600,0.19777878,4.0207043,,,,,,,,,,,,,,,,, -26700,0.19685996,3.9745183,,,,,,,,,,,,,,,,, -26738,,,0.6513376832008362,1.8955870866775515,31.97804613904664,0.6653606295585632,1.7924036979675293,28.451939799484045,3000.0,0.6745221018791199,1.7301757335662842,28.200531420440008,3003.0,9266.780019283296,16030.644038438795,9266.780019283296,6762.721871137619,0.286259651184082,0.0 -26800,0.2502606,3.977285,,,,,,,,,,,,,,,,, -26900,0.2854922,4.023392,,,,,,,,,,,,,,,,, -27000,0.1725611,3.914306,,,,,,,,,,,,,,,,, -27100,0.20475091,3.9323585,,,,,,,,,,,,,,,,, -27200,0.18646543,3.936383,,,,,,,,,,,,,,,,, -27300,0.24931769,3.9866676,,,,,,,,,,,,,,,,, -27400,0.19226436,3.89414,,,,,,,,,,,,,,,,, -27500,0.23510815,3.9994917,,,,,,,,,,,,,,,,, -27600,0.24171285,3.966964,,,,,,,,,,,,,,,,, -27700,0.18084365,3.9515626,,,,,,,,,,,,,,,,, -27800,0.2156832,3.9845762,,,,,,,,,,,,,,,,, -27900,0.18905124,3.9589884,,,,,,,,,,,,,,,,, -28000,0.30541632,3.9806678,,,,,,,,,,,,,,,,, -28100,0.19518775,3.95503,,,,,,,,,,,,,,,,, -28200,0.20671883,3.9588912,,,,,,,,,,,,,,,,, -28300,0.20162664,3.9336112,,,,,,,,,,,,,,,,, -28400,0.2347062,3.9918644,,,,,,,,,,,,,,,,, -28500,0.26475623,3.9947731,,,,,,,,,,,,,,,,, -28600,0.20243825,3.9868395,,,,,,,,,,,,,,,,, -28700,0.24811515,3.994183,,,,,,,,,,,,,,,,, -28800,0.19688733,3.9757028,,,,,,,,,,,,,,,,, -28900,0.18985401,3.9770463,,,,,,,,,,,,,,,,, -29000,0.20327063,4.0707808,,,,,,,,,,,,,,,,, -29100,0.20834683,3.943301,,,,,,,,,,,,,,,,, -29170,,,0.6469125747680664,1.9222787618637085,32.02554933303594,0.6672824621200562,1.7752902507781982,28.697036519419203,3000.0,0.6792516708374023,1.7129313945770264,28.4535721097046,3003.0,10107.006126642227,17344.64367866516,10107.006126642227,7236.388736963272,0.3159620761871338,0.0 -29200,0.42896864,3.939378,,,,,,,,,,,,,,,,, -29300,0.22600828,3.9667704,,,,,,,,,,,,,,,,, -29400,0.19185853,3.9569209,,,,,,,,,,,,,,,,, -29500,0.19597597,3.924145,,,,,,,,,,,,,,,,, -29600,0.19518793,3.9311228,,,,,,,,,,,,,,,,, -29700,0.20565653,3.9913971,,,,,,,,,,,,,,,,, -29800,0.21729602,3.9563901,,,,,,,,,,,,,,,,, -29900,0.23865283,3.913205,,,,,,,,,,,,,,,,, -30000,0.22397487,3.9321747,,,,,,,,,,,,,,,,, -30100,0.19587606,3.8952136,,,,,,,,,,,,,,,,, -30200,0.30584705,3.998295,,,,,,,,,,,,,,,,, -30300,0.39941868,3.925077,,,,,,,,,,,,,,,,, -30400,0.19055034,3.9131498,,,,,,,,,,,,,,,,, -30500,0.27521208,3.9736528,,,,,,,,,,,,,,,,, -30600,0.21347524,3.940785,,,,,,,,,,,,,,,,, -30700,0.20686361,3.9177403,,,,,,,,,,,,,,,,, -30800,0.2901822,3.9573982,,,,,,,,,,,,,,,,, -30900,0.2351549,3.852226,,,,,,,,,,,,,,,,, -31000,0.23509045,3.921664,,,,,,,,,,,,,,,,, -31100,0.3025125,3.9113975,,,,,,,,,,,,,,,,, -31200,0.2353992,3.904643,,,,,,,,,,,,,,,,, -31300,0.32959053,3.9442394,,,,,,,,,,,,,,,,, -31400,0.1915766,3.8759873,,,,,,,,,,,,,,,,, -31500,0.39481008,3.924739,,,,,,,,,,,,,,,,, -31600,0.21837474,3.9711015,,,,,,,,,,,,,,,,, -31601,,,0.6717249155044556,1.7598832845687866,33.27960037029534,0.6671584844589233,1.7688673734664917,29.136006467228604,3000.0,0.679449200630188,1.6971747875213623,28.46140712600896,3003.0,10947.211053609848,18654.48917913437,10947.211053609848,7705.922129631042,0.3449394702911377,0.0 -31700,0.20572269,3.9380896,,,,,,,,,,,,,,,,, -31800,0.20864716,3.9614925,,,,,,,,,,,,,,,,, -31900,0.26489756,3.9645126,,,,,,,,,,,,,,,,, -32000,0.21515721,3.9376721,,,,,,,,,,,,,,,,, -32100,0.22068639,3.9807048,,,,,,,,,,,,,,,,, -32200,0.20752013,3.916014,,,,,,,,,,,,,,,,, -32300,0.21947415,3.8904848,,,,,,,,,,,,,,,,, -32400,0.2811066,3.9155855,,,,,,,,,,,,,,,,, -32500,0.20965843,3.899996,,,,,,,,,,,,,,,,, -32600,0.22299737,3.8732963,,,,,,,,,,,,,,,,, -32700,0.26605186,3.966334,,,,,,,,,,,,,,,,, -32800,0.21103598,3.9683888,,,,,,,,,,,,,,,,, -32900,0.21723638,3.959103,,,,,,,,,,,,,,,,, -33000,0.21078588,3.9365065,,,,,,,,,,,,,,,,, -33100,0.22429173,3.945,,,,,,,,,,,,,,,,, -33200,0.20750548,3.9428043,,,,,,,,,,,,,,,,, -33300,0.3114483,3.937591,,,,,,,,,,,,,,,,, -33400,0.23731312,3.9863675,,,,,,,,,,,,,,,,, -33500,0.21046647,3.947239,,,,,,,,,,,,,,,,, -33600,0.25179774,3.9279904,,,,,,,,,,,,,,,,, -33700,0.21978673,3.9274752,,,,,,,,,,,,,,,,, -33800,0.2091216,3.9466958,,,,,,,,,,,,,,,,, -33900,0.22741845,3.8666408,,,,,,,,,,,,,,,,, -34000,0.29187676,3.94688,,,,,,,,,,,,,,,,, -34032,,,0.6559919714927673,1.8558120727539065,32.682312093461405,0.6694275140762329,1.746963381767273,29.291129779033984,3000.0,0.6810644268989563,1.6823828220367432,28.599915373846635,3003.0,11787.294679403303,20008.52248167992,11787.294679403303,8219.765256166458,0.3742711544036865,0.0 -34100,0.2309864,3.9463942,,,,,,,,,,,,,,,,, -34200,0.26453823,3.9597464,,,,,,,,,,,,,,,,, -34300,0.23162307,3.8898318,,,,,,,,,,,,,,,,, -34400,0.20178454,3.8732471,,,,,,,,,,,,,,,,, -34500,0.2342846,3.9235651,,,,,,,,,,,,,,,,, -34600,0.12786525,7.8551197,,,,,,,,,,,,,,,,, -34700,0.37547186,5.90557,,,,,,,,,,,,,,,,, -34800,0.34252825,5.6041136,,,,,,,,,,,,,,,,, -34900,0.40060127,5.5239487,,,,,,,,,,,,,,,,, -35000,3.1821861,4.7891054,,,,,,,,,,,,,,,,, -35100,0.2266117,4.063758,,,,,,,,,,,,,,,,, -35200,0.19646071,3.9284103,,,,,,,,,,,,,,,,, -35300,0.21893488,3.9980624,,,,,,,,,,,,,,,,, -35400,0.21299605,3.9360015,,,,,,,,,,,,,,,,, -35500,0.21017382,3.925595,,,,,,,,,,,,,,,,, -35600,0.21561958,3.8820496,,,,,,,,,,,,,,,,, -35700,0.20425856,3.922149,,,,,,,,,,,,,,,,, -35800,0.22173297,3.9334714,,,,,,,,,,,,,,,,, -35900,0.21545957,3.8922865,,,,,,,,,,,,,,,,, -36000,0.19853634,3.9522743,,,,,,,,,,,,,,,,, -36100,0.2311498,3.9628284,,,,,,,,,,,,,,,,, -36200,0.24766468,3.9075801,,,,,,,,,,,,,,,,, -36300,0.22713701,3.9069262,,,,,,,,,,,,,,,,, -36400,0.21492599,3.9742513,,,,,,,,,,,,,,,,, -36463,,,0.6533215045928955,1.8728480339050293,32.04000187532934,0.6712626218795776,1.741132140159607,28.41640521392595,3000.0,0.682819128036499,1.669180154800415,29.09519909837025,3003.0,12627.530505895616,21463.58862900734,12627.530505895616,8834.486620664597,0.4038436412811279,0.0 -36500,0.21019246,3.8612099,,,,,,,,,,,,,,,,, -36600,0.22599156,3.9492178,,,,,,,,,,,,,,,,, -36700,0.24321888,3.885894,,,,,,,,,,,,,,,,, -36800,0.2203184,3.942472,,,,,,,,,,,,,,,,, -36900,0.23158656,3.9464579,,,,,,,,,,,,,,,,, -37000,0.27995694,3.8983288,,,,,,,,,,,,,,,,, -37100,0.21955152,3.892959,,,,,,,,,,,,,,,,, -37200,0.27889946,3.9430711,,,,,,,,,,,,,,,,, -37300,0.29897028,3.9037142,,,,,,,,,,,,,,,,, -37400,0.28899777,3.9288216,,,,,,,,,,,,,,,,, -37500,0.22380154,3.937736,,,,,,,,,,,,,,,,, -37600,0.27390707,3.9115987,,,,,,,,,,,,,,,,, -37700,0.2278797,3.9492612,,,,,,,,,,,,,,,,, -37800,0.19990368,3.8762956,,,,,,,,,,,,,,,,, -37900,0.2591418,3.9598324,,,,,,,,,,,,,,,,, -38000,0.21594948,3.906,,,,,,,,,,,,,,,,, -38100,0.21215482,3.8789144,,,,,,,,,,,,,,,,, -38200,0.23411036,3.9188442,,,,,,,,,,,,,,,,, -38300,0.21368514,3.9095144,,,,,,,,,,,,,,,,, -38400,0.27446267,3.8811922,,,,,,,,,,,,,,,,, -38500,0.2322334,3.8939366,,,,,,,,,,,,,,,,, -38600,0.21129113,3.9550333,,,,,,,,,,,,,,,,, -38700,0.2257584,3.8742754,,,,,,,,,,,,,,,,, -38800,0.24507302,3.8890572,,,,,,,,,,,,,,,,, -38894,,,0.6603764891624451,1.8058222532272337,32.47630884958098,0.6701342463493347,1.730035662651062,29.079701147488368,3000.0,0.6835280060768127,1.6568653583526611,28.82156068715825,3003.0,13467.513338565826,22775.415242671967,13467.513338565826,9306.22245168686,0.4338409900665283,0.0 -38900,0.2174341,3.9448986,,,,,,,,,,,,,,,,, -39000,0.24498907,3.855453,,,,,,,,,,,,,,,,, -39100,0.21253589,3.9098048,,,,,,,,,,,,,,,,, -39200,0.23734163,3.9354436,,,,,,,,,,,,,,,,, -39300,0.31010684,3.990172,,,,,,,,,,,,,,,,, -39400,0.24153675,3.904541,,,,,,,,,,,,,,,,, -39500,0.22835201,3.8855066,,,,,,,,,,,,,,,,, -39600,0.24072875,3.8923447,,,,,,,,,,,,,,,,, -39700,0.22945121,3.9028668,,,,,,,,,,,,,,,,, -39800,0.280219,3.8946805,,,,,,,,,,,,,,,,, -39900,0.24222688,3.8514106,,,,,,,,,,,,,,,,, -40000,0.24094988,3.8776956,,,,,,,,,,,,,,,,, -40100,0.2530475,3.864243,,,,,,,,,,,,,,,,, -40200,0.22334579,3.8861794,,,,,,,,,,,,,,,,, -40300,0.23683643,3.9438832,,,,,,,,,,,,,,,,, -40400,0.2928067,3.9546266,,,,,,,,,,,,,,,,, -40500,0.31553093,3.9006667,,,,,,,,,,,,,,,,, -40600,0.25142422,3.8564804,,,,,,,,,,,,,,,,, -40700,0.25997567,3.9020333,,,,,,,,,,,,,,,,, -40800,0.23609754,3.9258173,,,,,,,,,,,,,,,,, -40900,0.24284942,3.978713,,,,,,,,,,,,,,,,, -41000,0.29099137,3.870062,,,,,,,,,,,,,,,,, -41100,0.23899908,3.9560003,,,,,,,,,,,,,,,,, -41200,0.20746745,3.8602045,,,,,,,,,,,,,,,,, -41300,0.2402448,3.9239094,,,,,,,,,,,,,,,,, -41325,,,0.6599708795547485,1.8201743364334104,32.381106417776536,0.6728744506835938,1.712181806564331,29.184639902057,3000.0,0.6861310005187988,1.6359333992004397,29.236897732583365,3003.0,14307.620192050934,24121.71455836296,14307.620192050934,9812.306602954865,0.464618444442749,0.0 -41400,0.28746602,3.916014,,,,,,,,,,,,,,,,, -41500,0.21645007,3.803945,,,,,,,,,,,,,,,,, -41600,0.2476837,3.950689,,,,,,,,,,,,,,,,, -41700,0.23305817,3.8551204,,,,,,,,,,,,,,,,, -41800,0.29994154,3.957037,,,,,,,,,,,,,,,,, -41900,0.23310001,3.9793026,,,,,,,,,,,,,,,,, -42000,0.23890051,3.889383,,,,,,,,,,,,,,,,, -42100,0.25366,3.981354,,,,,,,,,,,,,,,,, -42200,0.22734256,3.9199958,,,,,,,,,,,,,,,,, -42300,0.24198543,3.9591408,,,,,,,,,,,,,,,,, -42400,0.26468113,3.8793027,,,,,,,,,,,,,,,,, -42500,0.27644783,3.889747,,,,,,,,,,,,,,,,, -42600,0.2553274,3.9066281,,,,,,,,,,,,,,,,, -42700,0.24622421,3.863234,,,,,,,,,,,,,,,,, -42800,0.2520887,3.912785,,,,,,,,,,,,,,,,, -42900,0.23693797,3.8949876,,,,,,,,,,,,,,,,, -43000,0.23889108,3.9248857,,,,,,,,,,,,,,,,, -43100,0.3065046,3.89571,,,,,,,,,,,,,,,,, -43200,0.23289266,3.8552833,,,,,,,,,,,,,,,,, -43300,0.22213429,3.8832366,,,,,,,,,,,,,,,,, -43400,0.26498106,3.8712666,,,,,,,,,,,,,,,,, -43500,0.299514,3.8627458,,,,,,,,,,,,,,,,, -43600,0.2915491,3.8815491,,,,,,,,,,,,,,,,, -43700,0.24870807,3.8756692,,,,,,,,,,,,,,,,, -43756,,,0.6520457863807678,1.86770761013031,32.20933811417118,0.6751186847686768,1.7147032022476196,29.25031655424665,3000.0,0.6876997351646423,1.6349014043807983,29.037590233653013,3003.0,15147.648758888245,25533.22567415237,15147.648758888245,10383.679992437364,0.4959902763366699,0.0 -43800,0.28245455,3.9377544,,,,,,,,,,,,,,,,, -43900,0.03697965,7.9848847,,,,,,,,,,,,,,,,, -44000,0.3696148,7.2997684,,,,,,,,,,,,,,,,, -44100,0.8531296,5.980448,,,,,,,,,,,,,,,,, -44200,0.53080326,5.545892,,,,,,,,,,,,,,,,, -44300,0.42867544,5.54358,,,,,,,,,,,,,,,,, -44400,0.3994902,5.477238,,,,,,,,,,,,,,,,, -44500,0.7397384,5.4460363,,,,,,,,,,,,,,,,, -44600,1.5603817,5.447281,,,,,,,,,,,,,,,,, -44700,0.41557726,5.6352224,,,,,,,,,,,,,,,,, -44800,0.66567653,5.470024,,,,,,,,,,,,,,,,, -44900,0.94829226,5.405342,,,,,,,,,,,,,,,,, -45000,0.92754644,5.1328096,,,,,,,,,,,,,,,,, -45100,0.32247192,4.0082145,,,,,,,,,,,,,,,,, -45200,0.22982807,3.9385202,,,,,,,,,,,,,,,,, -45300,0.2291847,3.8996332,,,,,,,,,,,,,,,,, -45400,0.23513491,3.9327137,,,,,,,,,,,,,,,,, -45500,0.24408981,3.9002526,,,,,,,,,,,,,,,,, -45600,0.21874042,3.8872123,,,,,,,,,,,,,,,,, -45700,0.23368362,3.9292011,,,,,,,,,,,,,,,,, -45800,0.2964335,3.895436,,,,,,,,,,,,,,,,, -45900,0.24073794,3.9107468,,,,,,,,,,,,,,,,, -46000,0.24357495,3.8898475,,,,,,,,,,,,,,,,, -46100,0.22835572,3.9346714,,,,,,,,,,,,,,,,, -46189,,,0.6639376878738403,1.786540985107422,32.48354802392069,0.6739283800125122,1.713200569152832,29.52103293408988,3000.0,0.6856428980827332,1.6498336791992188,29.08447388429547,3003.0,15987.78390431404,26852.710567712784,15987.78390431404,10862.922046422958,0.5261423587799072,0.0 -46200,0.28408194,3.9015248,,,,,,,,,,,,,,,,, -46300,0.3311431,3.8409882,,,,,,,,,,,,,,,,, -46400,0.2486086,3.8321202,,,,,,,,,,,,,,,,, -46500,0.2410142,3.9881613,,,,,,,,,,,,,,,,, -46600,0.28824162,3.8981388,,,,,,,,,,,,,,,,, -46700,0.27087238,3.8997593,,,,,,,,,,,,,,,,, -46800,0.2412581,3.8866787,,,,,,,,,,,,,,,,, -46900,0.27371365,3.8532789,,,,,,,,,,,,,,,,, -47000,0.23494896,3.832459,,,,,,,,,,,,,,,,, -47100,0.22370437,3.9001205,,,,,,,,,,,,,,,,, -47200,0.2751121,3.9210682,,,,,,,,,,,,,,,,, -47300,0.23880441,3.8441143,,,,,,,,,,,,,,,,, -47400,0.23347221,3.9324973,,,,,,,,,,,,,,,,, -47500,0.2700099,3.8846498,,,,,,,,,,,,,,,,, -47600,0.27069935,3.8347437,,,,,,,,,,,,,,,,, -47700,0.31139278,3.8674905,,,,,,,,,,,,,,,,, -47800,0.34706515,3.8974192,,,,,,,,,,,,,,,,, -47900,0.29393348,3.888967,,,,,,,,,,,,,,,,, -48000,0.2623786,3.8759518,,,,,,,,,,,,,,,,, -48100,0.2435002,3.8970358,,,,,,,,,,,,,,,,, -48200,0.24857827,3.932195,,,,,,,,,,,,,,,,, -48300,0.23567829,3.836247,,,,,,,,,,,,,,,,, -48400,0.30865332,3.8436425,,,,,,,,,,,,,,,,, -48500,0.23965943,3.8866885,,,,,,,,,,,,,,,,, -48600,0.30426303,3.922462,,,,,,,,,,,,,,,,, -48620,,,0.6577057242393494,1.8336910009384155,32.11022349497067,0.6760362386703491,1.7174301147460938,29.52968287144964,3000.0,0.6883621215820312,1.642948031425476,29.228574269408423,3003.0,16827.680357694626,28166.43555521965,16827.680357694626,11336.640148878098,0.5586118698120117,0.0 -48700,0.28992745,3.8848634,,,,,,,,,,,,,,,,, -48800,0.24876858,3.8488564,,,,,,,,,,,,,,,,, -48900,0.23433763,3.8653612,,,,,,,,,,,,,,,,, -49000,0.3414132,3.8308134,,,,,,,,,,,,,,,,, -49100,0.24560654,3.912461,,,,,,,,,,,,,,,,, -49200,0.259513,3.8054175,,,,,,,,,,,,,,,,, -49300,0.27789396,3.8926668,,,,,,,,,,,,,,,,, -49400,0.24661222,3.8612974,,,,,,,,,,,,,,,,, -49500,0.23107147,3.89774,,,,,,,,,,,,,,,,, -49600,0.2617864,3.871775,,,,,,,,,,,,,,,,, -49700,0.2848986,3.8758142,,,,,,,,,,,,,,,,, -49800,0.26758757,3.8876214,,,,,,,,,,,,,,,,, -49900,0.25757554,3.9251463,,,,,,,,,,,,,,,,, -50000,0.25139183,3.8546567,,,,,,,,,,,,,,,,, -50100,0.3054822,3.8803186,,,,,,,,,,,,,,,,, -50200,0.26210815,3.851182,,,,,,,,,,,,,,,,, -50300,0.23847547,3.8599856,,,,,,,,,,,,,,,,, -50400,0.24549367,3.9171982,,,,,,,,,,,,,,,,, -50500,0.23150454,3.8663077,,,,,,,,,,,,,,,,, -50600,0.23230231,3.8290179,,,,,,,,,,,,,,,,, -50700,0.23112793,3.9285238,,,,,,,,,,,,,,,,, -50800,0.24977805,3.8661106,,,,,,,,,,,,,,,,, -50900,0.3039418,3.939455,,,,,,,,,,,,,,,,, -51000,0.30750296,3.8631406,,,,,,,,,,,,,,,,, -51051,,,0.6678268909454346,1.75491201877594,33.25293797178625,0.6759990453720093,1.6939091682434082,29.552817707166724,3000.0,0.6865725517272949,1.620548129081726,28.930613904916136,3003.0,17667.6850566864,29472.580087661743,17667.6850566864,11802.670159101486,0.5905261039733887,0.0 -51100,0.23039356,3.9019084,,,,,,,,,,,,,,,,, -51200,0.3187501,3.9247417,,,,,,,,,,,,,,,,, -51300,0.25153273,3.8875177,,,,,,,,,,,,,,,,, -51400,0.29653752,3.8633897,,,,,,,,,,,,,,,,, -51500,0.26428822,3.852942,,,,,,,,,,,,,,,,, -51600,0.33752656,3.8763902,,,,,,,,,,,,,,,,, -51700,0.26894853,3.8506062,,,,,,,,,,,,,,,,, -51800,0.2882455,3.9431257,,,,,,,,,,,,,,,,, -51900,0.22878921,3.8251436,,,,,,,,,,,,,,,,, -52000,0.26469612,3.8092556,,,,,,,,,,,,,,,,, -52100,0.2788928,3.8834143,,,,,,,,,,,,,,,,, -52200,0.25808173,3.876122,,,,,,,,,,,,,,,,, -52300,0.29464626,3.8447351,,,,,,,,,,,,,,,,, -52400,0.24412382,3.8816085,,,,,,,,,,,,,,,,, -52500,0.24483919,3.8533456,,,,,,,,,,,,,,,,, -52600,0.33894956,3.821558,,,,,,,,,,,,,,,,, -52700,0.25931793,3.848975,,,,,,,,,,,,,,,,, -52800,0.2570355,3.8975556,,,,,,,,,,,,,,,,, -52900,0.26892126,3.9006379,,,,,,,,,,,,,,,,, -53000,0.25602013,3.871201,,,,,,,,,,,,,,,,, -53100,0.23304246,3.8562412,,,,,,,,,,,,,,,,, -53200,0.22680613,3.8461778,,,,,,,,,,,,,,,,, -53300,0.28336826,3.852583,,,,,,,,,,,,,,,,, -53400,0.2474982,3.848419,,,,,,,,,,,,,,,,, -53483,,,0.6628263592720032,1.7955244779586792,32.78826064342799,0.6780325174331665,1.688191056251526,29.68220960991648,3000.0,0.693265974521637,1.6090433597564695,29.551221567046127,3003.0,18507.86905694008,30843.158007144928,18507.86905694008,12332.954099178314,0.6230416297912598,0.0 -53500,0.255718,3.8873966,,,,,,,,,,,,,,,,, -53600,0.3588582,3.905414,,,,,,,,,,,,,,,,, -53700,0.37343216,3.9133692,,,,,,,,,,,,,,,,, -53800,0.26381177,3.861986,,,,,,,,,,,,,,,,, -53900,0.2290356,3.8249865,,,,,,,,,,,,,,,,, -54000,0.2693464,3.8576624,,,,,,,,,,,,,,,,, -54100,0.26402578,3.8062224,,,,,,,,,,,,,,,,, -54200,0.27434838,3.8523583,,,,,,,,,,,,,,,,, -54300,0.24764794,3.850646,,,,,,,,,,,,,,,,, -54400,0.29702458,3.838414,,,,,,,,,,,,,,,,, -54500,0.25271228,3.8240647,,,,,,,,,,,,,,,,, -54600,0.25522813,3.8771405,,,,,,,,,,,,,,,,, -54700,0.24570474,3.8550954,,,,,,,,,,,,,,,,, -54800,0.2822752,3.85463,,,,,,,,,,,,,,,,, -54900,0.27296987,3.8964672,,,,,,,,,,,,,,,,, -55000,0.23260516,3.8328025,,,,,,,,,,,,,,,,, -55100,0.25121418,3.9018965,,,,,,,,,,,,,,,,, -55200,0.26157972,3.863937,,,,,,,,,,,,,,,,, -55300,0.25545743,3.8528879,,,,,,,,,,,,,,,,, -55400,0.30176944,3.8346088,,,,,,,,,,,,,,,,, -55500,0.29449737,3.8391447,,,,,,,,,,,,,,,,, -55600,0.28331667,3.8625312,,,,,,,,,,,,,,,,, -55700,0.28391564,3.8382573,,,,,,,,,,,,,,,,, -55800,0.26528278,3.8925872,,,,,,,,,,,,,,,,, -55900,0.251223,3.8606865,,,,,,,,,,,,,,,,, -55914,,,0.6604369282722473,1.80417549610138,33.13701132138776,0.6777101159095764,1.676916241645813,29.61541145915985,3000.0,0.6930218935012817,1.5961912870407104,29.3509889737326,3003.0,19347.87039089203,32146.7134320736,19347.87039089203,12796.398730516434,0.6548542976379395,0.0 -56000,0.27824324,3.8768904,,,,,,,,,,,,,,,,, -56100,0.24127075,3.7628036,,,,,,,,,,,,,,,,, -56200,0.24669066,3.8269696,,,,,,,,,,,,,,,,, -56300,0.26728082,3.8855581,,,,,,,,,,,,,,,,, -56400,0.34044856,3.8588817,,,,,,,,,,,,,,,,, -56500,0.26974642,3.87489,,,,,,,,,,,,,,,,, -56600,0.23853505,3.803246,,,,,,,,,,,,,,,,, -56700,0.23586997,3.8948724,,,,,,,,,,,,,,,,, -56800,0.25172725,3.85723,,,,,,,,,,,,,,,,, -56900,0.30760407,3.8589387,,,,,,,,,,,,,,,,, -57000,0.23738727,3.8343742,,,,,,,,,,,,,,,,, -57100,1.000504,6.0067916,,,,,,,,,,,,,,,,, -57200,0.39733118,5.4908714,,,,,,,,,,,,,,,,, -57300,0.47814658,5.4667835,,,,,,,,,,,,,,,,, -57400,0.31668773,5.448461,,,,,,,,,,,,,,,,, -57500,0.34830284,5.428177,,,,,,,,,,,,,,,,, -57600,0.54247516,5.391367,,,,,,,,,,,,,,,,, -57700,0.40031695,5.391412,,,,,,,,,,,,,,,,, -57800,0.41851184,5.379458,,,,,,,,,,,,,,,,, -57900,1.1971883,5.3450994,,,,,,,,,,,,,,,,, -58000,1.1504794,5.3721347,,,,,,,,,,,,,,,,, -58100,0.79986686,5.332372,,,,,,,,,,,,,,,,, -58200,0.7657688,5.3042283,,,,,,,,,,,,,,,,, -58300,1.2589587,5.3072786,,,,,,,,,,,,,,,,, -58348,,,0.3733042478561401,3.6769845485687256,1.7790266662847811,0.3269147276878357,4.14604377746582,0.3167417668719349,3000.0,0.3172041177749634,4.3111090660095215,0.2760901163922393,3003.0,20187.9033575058,33801.09925675392,20187.9033575058,13610.639861106873,0.6897470951080322,0.0 -58400,1.3828009,5.268264,,,,,,,,,,,,,,,,, -58500,2.3988132,5.15802,,,,,,,,,,,,,,,,, -58600,0.2665735,3.9455125,,,,,,,,,,,,,,,,, -58700,0.2844873,3.9067812,,,,,,,,,,,,,,,,, -58800,0.2709341,3.9058638,,,,,,,,,,,,,,,,, -58900,0.24271491,3.8535998,,,,,,,,,,,,,,,,, -59000,0.2618744,3.801081,,,,,,,,,,,,,,,,, -59100,0.23761958,3.8706198,,,,,,,,,,,,,,,,, -59200,0.2634926,3.851739,,,,,,,,,,,,,,,,, -59300,0.2900863,3.9287183,,,,,,,,,,,,,,,,, -59400,0.25276473,3.8404148,,,,,,,,,,,,,,,,, -59500,0.2652356,3.8337486,,,,,,,,,,,,,,,,, -59600,0.25332862,3.7947695,,,,,,,,,,,,,,,,, -59700,0.25778136,3.8686116,,,,,,,,,,,,,,,,, -59800,0.24498256,3.8295605,,,,,,,,,,,,,,,,, -59900,0.25101775,3.810486,,,,,,,,,,,,,,,,, -60000,0.34755123,3.8387318,,,,,,,,,,,,,,,,, -60100,0.25223964,3.8680036,,,,,,,,,,,,,,,,, -60200,0.25331664,3.8359118,,,,,,,,,,,,,,,,, -60300,0.25713035,3.8714092,,,,,,,,,,,,,,,,, -60400,0.30829388,3.8816504,,,,,,,,,,,,,,,,, -60500,0.24636434,3.8599098,,,,,,,,,,,,,,,,, -60600,0.24792688,3.8451526,,,,,,,,,,,,,,,,, -60700,0.32623166,3.7905338,,,,,,,,,,,,,,,,, -60779,,,0.6671022176742554,1.770507574081421,32.896878208625374,0.6774373650550842,1.6847100257873535,29.561452728455883,3000.0,0.6925687193870544,1.6012877225875854,29.86454461957114,3003.0,21027.7988653183,35157.31735706329,21027.7988653183,14126.851303100586,0.7242088317871094,0.0 -60800,0.27149063,3.9138205,,,,,,,,,,,,,,,,, -60900,0.24337274,3.8705451,,,,,,,,,,,,,,,,, -61000,0.27357286,3.852247,,,,,,,,,,,,,,,,, -61100,0.2724446,3.8337288,,,,,,,,,,,,,,,,, -61200,0.27165154,3.8320904,,,,,,,,,,,,,,,,, -61300,0.26390395,3.8557444,,,,,,,,,,,,,,,,, -61400,0.26563755,3.7992988,,,,,,,,,,,,,,,,, -61500,0.25288194,3.7943163,,,,,,,,,,,,,,,,, -61600,0.3071015,3.8374617,,,,,,,,,,,,,,,,, -61700,0.25361094,3.8190856,,,,,,,,,,,,,,,,, -61800,0.2948244,3.8636265,,,,,,,,,,,,,,,,, -61900,0.25430274,3.866412,,,,,,,,,,,,,,,,, -62000,0.3245986,3.8662934,,,,,,,,,,,,,,,,, -62100,0.25354904,3.7997384,,,,,,,,,,,,,,,,, -62200,0.2655752,3.863189,,,,,,,,,,,,,,,,, -62300,0.2701924,3.851819,,,,,,,,,,,,,,,,, -62400,0.28028974,3.8066728,,,,,,,,,,,,,,,,, -62500,0.27271402,3.823958,,,,,,,,,,,,,,,,, -62600,0.26165387,3.7942193,,,,,,,,,,,,,,,,, -62700,0.27242514,3.8395565,,,,,,,,,,,,,,,,, -62800,0.28545585,3.903385,,,,,,,,,,,,,,,,, -62900,0.24836597,3.8578942,,,,,,,,,,,,,,,,, -63000,0.26253787,3.865089,,,,,,,,,,,,,,,,, -63100,0.2660427,3.8176925,,,,,,,,,,,,,,,,, -63200,0.26063567,3.7981362,,,,,,,,,,,,,,,,, -63211,,,0.6855831146240234,1.6496340036392212,34.1527935497728,0.679582417011261,1.6690008640289309,29.659998515776348,3000.0,0.6939283013343811,1.5856915712356567,29.44425505315059,3003.0,21867.882091760635,36463.16767692566,21867.882091760635,14592.508635282516,0.7560958862304688,0.0 -63300,0.30069408,3.8047051,,,,,,,,,,,,,,,,, -63400,0.27411273,3.8909419,,,,,,,,,,,,,,,,, -63500,0.2523466,3.8803596,,,,,,,,,,,,,,,,, -63600,0.26744288,3.875837,,,,,,,,,,,,,,,,, -63700,0.25095153,3.8136923,,,,,,,,,,,,,,,,, -63800,0.26188996,3.8617663,,,,,,,,,,,,,,,,, -63900,0.2702102,3.8536139,,,,,,,,,,,,,,,,, -64000,0.27079788,3.8690536,,,,,,,,,,,,,,,,, -64100,0.28131574,3.8438294,,,,,,,,,,,,,,,,, -64200,0.27460498,3.8743865,,,,,,,,,,,,,,,,, -64300,0.23813848,3.8674169,,,,,,,,,,,,,,,,, -64400,0.2603325,3.8841763,,,,,,,,,,,,,,,,, -64500,0.26517528,3.8920243,,,,,,,,,,,,,,,,, -64600,0.24339184,3.8631968,,,,,,,,,,,,,,,,, -64700,0.27064255,3.8683143,,,,,,,,,,,,,,,,, -64800,0.246125,3.839677,,,,,,,,,,,,,,,,, -64900,0.25985456,3.8188162,,,,,,,,,,,,,,,,, -65000,0.25656182,3.822448,,,,,,,,,,,,,,,,, -65100,0.25810364,3.8078947,,,,,,,,,,,,,,,,, -65200,0.25636834,3.8477476,,,,,,,,,,,,,,,,, -65300,0.25386894,3.8396747,,,,,,,,,,,,,,,,, -65400,0.2726756,3.8651588,,,,,,,,,,,,,,,,, -65500,0.2602204,3.8428402,,,,,,,,,,,,,,,,, -65600,0.2566704,3.841327,,,,,,,,,,,,,,,,, -65642,,,0.6722832322120667,1.7180349826812744,33.70603561239361,0.681913435459137,1.6529948711395264,30.087363587995423,3000.0,0.6969264149665833,1.5698133707046509,29.88272037242085,3003.0,22707.863388299946,37762.35045266152,22707.863388299946,15051.599786758425,0.7890956401824951,0.0 -65700,0.24223448,3.7911768,,,,,,,,,,,,,,,,, -65800,0.32175085,3.861134,,,,,,,,,,,,,,,,, -65900,0.2508993,3.819955,,,,,,,,,,,,,,,,, -66000,0.36121696,3.8390465,,,,,,,,,,,,,,,,, -66100,0.28990608,3.7909112,,,,,,,,,,,,,,,,, -66200,0.25213093,3.7936604,,,,,,,,,,,,,,,,, -66300,0.26251638,3.767585,,,,,,,,,,,,,,,,, -66400,0.2549546,3.7594857,,,,,,,,,,,,,,,,, -66500,0.28226745,3.8217359,,,,,,,,,,,,,,,,, -66600,0.2694402,3.8905618,,,,,,,,,,,,,,,,, -66700,0.26908734,3.7920017,,,,,,,,,,,,,,,,, -66800,0.29178125,3.820964,,,,,,,,,,,,,,,,, -66900,0.27546692,3.789419,,,,,,,,,,,,,,,,, -67000,0.2933413,3.861275,,,,,,,,,,,,,,,,, -67100,0.26366538,3.8361812,,,,,,,,,,,,,,,,, -67200,0.25142455,3.79981,,,,,,,,,,,,,,,,, -67300,0.25735876,3.762613,,,,,,,,,,,,,,,,, -67400,0.27143568,3.8433588,,,,,,,,,,,,,,,,, -67500,0.25224233,3.7525225,,,,,,,,,,,,,,,,, -67600,0.26991618,3.726893,,,,,,,,,,,,,,,,, -67700,0.25884768,3.8132505,,,,,,,,,,,,,,,,, -67800,0.27615905,3.823396,,,,,,,,,,,,,,,,, -67900,0.31957397,3.8445187,,,,,,,,,,,,,,,,, -68000,0.25303867,3.7693295,,,,,,,,,,,,,,,,, -68073,,,0.6699897646903992,1.7524583339691162,33.207613914396966,0.6819630265235901,1.6551176309585571,30.112537396175984,3000.0,0.6972169280052185,1.5671659708023071,29.769703618118545,3003.0,23547.92662167549,39053.7624464035,23547.92662167549,15502.837298870088,0.822166919708252,0.0 -68100,0.35290208,3.7576835,,,,,,,,,,,,,,,,, -68200,0.30201292,3.8602598,,,,,,,,,,,,,,,,, -68300,0.27506232,3.8519914,,,,,,,,,,,,,,,,, -68400,0.27119726,3.835701,,,,,,,,,,,,,,,,, -68500,0.27521527,3.8147104,,,,,,,,,,,,,,,,, -68600,0.29554048,3.8434355,,,,,,,,,,,,,,,,, -68700,0.27912277,3.8121943,,,,,,,,,,,,,,,,, -68800,0.30702633,3.870947,,,,,,,,,,,,,,,,, -68900,0.24526985,3.8051133,,,,,,,,,,,,,,,,, -69000,0.30443126,3.8236504,,,,,,,,,,,,,,,,, -69100,0.2776346,3.8504355,,,,,,,,,,,,,,,,, -69200,0.2815869,3.8403068,,,,,,,,,,,,,,,,, -69300,0.2514621,3.7633958,,,,,,,,,,,,,,,,, -69400,0.25379503,3.782497,,,,,,,,,,,,,,,,, -69500,0.26683784,3.7780921,,,,,,,,,,,,,,,,, -69600,0.28307524,3.8178365,,,,,,,,,,,,,,,,, -69700,0.2574552,3.8396232,,,,,,,,,,,,,,,,, -69800,0.26230168,3.7593122,,,,,,,,,,,,,,,,, -69900,0.26326555,3.8042974,,,,,,,,,,,,,,,,, -70000,0.26970628,3.7773554,,,,,,,,,,,,,,,,, -70100,0.29301935,3.806632,,,,,,,,,,,,,,,,, -70200,0.26801938,3.7725768,,,,,,,,,,,,,,,,, -70300,0.2754293,3.7576892,,,,,,,,,,,,,,,,, -70400,0.3068652,3.819703,,,,,,,,,,,,,,,,, -70500,0.25867325,3.811612,,,,,,,,,,,,,,,,, -70504,,,0.6783646941184998,1.6977791786193848,34.29033621037187,0.6820374131202698,1.6467190980911257,30.249611344394367,3000.0,0.6974377036094666,1.5637001991271973,30.00090817367382,3003.0,24387.89455485344,40374.32408332825,24387.89455485344,15983.3171210289,0.857762336730957,0.0 -70600,0.26016244,3.8486705,,,,,,,,,,,,,,,,, -70700,0.27728114,3.7738657,,,,,,,,,,,,,,,,, -70800,0.327577,3.862101,,,,,,,,,,,,,,,,, -70900,0.2611342,3.8026905,,,,,,,,,,,,,,,,, -71000,0.2621092,3.8722363,,,,,,,,,,,,,,,,, -71100,0.26396865,3.8256526,,,,,,,,,,,,,,,,, -71200,0.25936076,3.8054197,,,,,,,,,,,,,,,,, -71300,0.30463898,3.8298516,,,,,,,,,,,,,,,,, -71400,0.2720246,3.8236542,,,,,,,,,,,,,,,,, -71500,0.34143108,3.8382833,,,,,,,,,,,,,,,,, -71600,0.2667616,3.7858937,,,,,,,,,,,,,,,,, -71700,0.27434698,3.828504,,,,,,,,,,,,,,,,, -71800,0.27693474,3.8232715,,,,,,,,,,,,,,,,, -71900,0.27937067,3.7802963,,,,,,,,,,,,,,,,, -72000,0.26116747,3.8119466,,,,,,,,,,,,,,,,, -72100,0.30376446,3.7563753,,,,,,,,,,,,,,,,, -72200,0.2892431,3.832332,,,,,,,,,,,,,,,,, -72300,0.31454048,3.7783027,,,,,,,,,,,,,,,,, -72400,0.32455128,3.782164,,,,,,,,,,,,,,,,, -72500,0.26051998,3.772284,,,,,,,,,,,,,,,,, -72600,0.26815844,3.7562833,,,,,,,,,,,,,,,,, -72700,0.2703252,3.7208416,,,,,,,,,,,,,,,,, -72800,0.27606896,3.8058274,,,,,,,,,,,,,,,,, -72900,0.26493472,3.7151318,,,,,,,,,,,,,,,,, -72935,,,0.6749582886695862,1.7154207229614258,33.6660576168447,0.6837732791900635,1.642994403839111,29.911717014627985,3000.0,0.6994480490684509,1.5574828386306765,29.76888711725281,3003.0,25227.994698286057,41671.04165291786,25227.994698286057,16439.81956934929,0.8949284553527832,0.0 -73000,0.27458858,3.824102,,,,,,,,,,,,,,,,, -73100,0.34678856,3.7459316,,,,,,,,,,,,,,,,, -73200,0.29720667,3.7751591,,,,,,,,,,,,,,,,, -73300,0.28941086,3.7580686,,,,,,,,,,,,,,,,, -73400,0.28134674,3.7825584,,,,,,,,,,,,,,,,, -73500,0.31697842,3.7451937,,,,,,,,,,,,,,,,, -73600,0.29168397,3.7777436,,,,,,,,,,,,,,,,, -73700,0.28388894,3.79989,,,,,,,,,,,,,,,,, -73800,0.26926947,3.7972944,,,,,,,,,,,,,,,,, -73900,0.30172458,3.7557738,,,,,,,,,,,,,,,,, -74000,0.4251098,3.8072739,,,,,,,,,,,,,,,,, -74100,0.29441324,3.7578952,,,,,,,,,,,,,,,,, -74200,0.29367185,3.8165522,,,,,,,,,,,,,,,,, -74300,0.28058258,3.7508447,,,,,,,,,,,,,,,,, -74400,0.283784,3.7900295,,,,,,,,,,,,,,,,, -74500,0.28889105,3.8249297,,,,,,,,,,,,,,,,, -74600,0.2808966,3.7610488,,,,,,,,,,,,,,,,, -74700,0.2634034,3.7546532,,,,,,,,,,,,,,,,, -74800,0.2683674,3.760335,,,,,,,,,,,,,,,,, -74900,0.27249783,3.7648764,,,,,,,,,,,,,,,,, -75000,0.2822455,3.7781088,,,,,,,,,,,,,,,,, -75100,0.2883091,3.7852666,,,,,,,,,,,,,,,,, -75200,0.27446234,3.8276362,,,,,,,,,,,,,,,,, -75300,0.27592093,3.7205966,,,,,,,,,,,,,,,,, -75366,,,0.701172411441803,1.5690522193908691,35.91140519112401,0.6847156286239624,1.6373581886291504,30.37538583212888,3000.0,0.6980187296867371,1.5505318641662598,29.991422765336544,3003.0,26068.055867910385,43029.4435377121,26068.055867910385,16958.04394555092,0.9335498809814452,0.0 -75400,0.28446695,3.741814,,,,,,,,,,,,,,,,, -75500,0.26999593,3.730642,,,,,,,,,,,,,,,,, -75600,0.25520676,3.7584503,,,,,,,,,,,,,,,,, -75700,0.29553849,3.802745,,,,,,,,,,,,,,,,, -75800,0.28755924,3.8425589,,,,,,,,,,,,,,,,, -75900,0.28101632,3.8300633,,,,,,,,,,,,,,,,, -76000,0.2862062,3.8084953,,,,,,,,,,,,,,,,, -76100,0.27442577,3.7871997,,,,,,,,,,,,,,,,, -76200,0.27887318,3.8017395,,,,,,,,,,,,,,,,, -76300,0.29838693,3.7870114,,,,,,,,,,,,,,,,, -76400,0.2604842,3.7910402,,,,,,,,,,,,,,,,, -76500,0.2742122,3.7948735,,,,,,,,,,,,,,,,, -76600,0.27975592,3.716699,,,,,,,,,,,,,,,,, -76700,0.3403001,3.8598523,,,,,,,,,,,,,,,,, -76800,0.28535104,3.865795,,,,,,,,,,,,,,,,, -76900,0.28973067,3.743599,,,,,,,,,,,,,,,,, -77000,0.26755467,3.727043,,,,,,,,,,,,,,,,, -77100,0.2910181,3.7995691,,,,,,,,,,,,,,,,, -77200,0.2627808,3.7633858,,,,,,,,,,,,,,,,, -77300,0.28467187,3.8137767,,,,,,,,,,,,,,,,, -77400,0.2706904,3.766339,,,,,,,,,,,,,,,,, -77500,0.28957358,3.7589517,,,,,,,,,,,,,,,,, -77600,0.27824655,3.7725704,,,,,,,,,,,,,,,,, -77700,0.29954177,3.7142775,,,,,,,,,,,,,,,,, -77797,,,0.6788302659988403,1.6935021877288818,34.17489567231058,0.6846908330917358,1.637537956237793,30.014238840241926,3000.0,0.701714038848877,1.5495461225509644,30.3965153832504,3003.0,26908.02235507965,44322.34855294228,26908.02235507965,17410.868065595627,0.9702820777893066,0.0 -77800,0.269521,3.764953,,,,,,,,,,,,,,,,, -77900,0.2838412,3.797432,,,,,,,,,,,,,,,,, -78000,0.29893893,3.8284528,,,,,,,,,,,,,,,,, -78100,0.3023875,3.8046148,,,,,,,,,,,,,,,,, -78200,0.30208248,3.8554294,,,,,,,,,,,,,,,,, -78300,0.31876197,3.7581146,,,,,,,,,,,,,,,,, -78400,0.26653814,3.6877918,,,,,,,,,,,,,,,,, -78500,0.28857797,3.735472,,,,,,,,,,,,,,,,, -78600,0.29157138,3.7831175,,,,,,,,,,,,,,,,, -78700,0.29038683,3.809387,,,,,,,,,,,,,,,,, -78800,0.27666312,3.756177,,,,,,,,,,,,,,,,, -78900,0.26434097,3.738491,,,,,,,,,,,,,,,,, -79000,0.2963629,3.7601776,,,,,,,,,,,,,,,,, -79100,0.2776035,3.7950563,,,,,,,,,,,,,,,,, -79200,0.25880435,3.7476466,,,,,,,,,,,,,,,,, -79300,0.26713946,3.7623723,,,,,,,,,,,,,,,,, -79400,0.3337549,3.8177962,,,,,,,,,,,,,,,,, -79500,0.2832084,3.771048,,,,,,,,,,,,,,,,, -79600,0.2957193,3.7748663,,,,,,,,,,,,,,,,, -79700,0.2926531,3.807306,,,,,,,,,,,,,,,,, -79800,0.30843562,3.7528865,,,,,,,,,,,,,,,,, -79900,0.29212883,3.693246,,,,,,,,,,,,,,,,, -80000,0.28128946,3.792115,,,,,,,,,,,,,,,,, -80100,0.2832637,3.7829468,,,,,,,,,,,,,,,,, -80200,0.313989,3.709531,,,,,,,,,,,,,,,,, -80228,,,0.6816733479499817,1.680078387260437,34.12771578283215,0.6865506768226624,1.6328672170639038,30.30910404470047,3000.0,0.7003777027130127,1.542347073554993,30.21403086923635,3003.0,27748.175061941147,45607.78610897064,27748.175061941147,17856.037237882614,1.008216142654419,0.0 -80300,0.28709033,3.7852998,,,,,,,,,,,,,,,,, -80400,0.3017351,3.7559586,,,,,,,,,,,,,,,,, -80500,0.2579172,3.759698,,,,,,,,,,,,,,,,, -80600,0.34743512,3.746446,,,,,,,,,,,,,,,,, -80700,0.29162785,3.7696698,,,,,,,,,,,,,,,,, -80800,0.26867983,3.8245223,,,,,,,,,,,,,,,,, -80900,0.2661407,3.757922,,,,,,,,,,,,,,,,, -81000,0.29690397,3.7498918,,,,,,,,,,,,,,,,, -81100,0.3275214,3.7927895,,,,,,,,,,,,,,,,, -81200,0.28219658,3.7418442,,,,,,,,,,,,,,,,, -81300,0.29586917,3.7464101,,,,,,,,,,,,,,,,, -81400,0.27380508,3.7792714,,,,,,,,,,,,,,,,, -81500,0.28978544,3.7965221,,,,,,,,,,,,,,,,, -81600,0.28526944,3.811778,,,,,,,,,,,,,,,,, -81700,0.29295358,3.7610598,,,,,,,,,,,,,,,,, -81800,0.3138164,3.775589,,,,,,,,,,,,,,,,, -81900,0.284405,3.7888155,,,,,,,,,,,,,,,,, -82000,0.30162546,3.7977407,,,,,,,,,,,,,,,,, -82100,0.34129214,3.785895,,,,,,,,,,,,,,,,, -82200,0.31130743,3.8423939,,,,,,,,,,,,,,,,, -82300,0.28193778,3.7995856,,,,,,,,,,,,,,,,, -82400,0.29537988,3.7771838,,,,,,,,,,,,,,,,, -82500,0.32463518,3.7401865,,,,,,,,,,,,,,,,, -82600,0.31462154,3.803655,,,,,,,,,,,,,,,,, -82659,,,0.6934168934822083,1.601096272468567,34.86383322748257,0.6877533793449402,1.6192811727523804,30.61256790425198,3000.0,0.7024344801902771,1.531815767288208,30.29900160524727,3003.0,28588.085990428925,46958.48557043076,28588.085990428925,18366.711642980576,1.0443508625030518,0.0 -82700,0.29936874,3.807118,,,,,,,,,,,,,,,,, -82800,0.28553447,3.7901604,,,,,,,,,,,,,,,,, -82900,0.3096531,3.7554855,,,,,,,,,,,,,,,,, -83000,0.27687222,3.7063208,,,,,,,,,,,,,,,,, -83100,0.2939378,3.8084278,,,,,,,,,,,,,,,,, -83200,0.28435487,3.8074398,,,,,,,,,,,,,,,,, -83300,0.29121885,3.76286,,,,,,,,,,,,,,,,, -83400,0.33781222,3.8047159,,,,,,,,,,,,,,,,, -83500,0.2864055,3.7902813,,,,,,,,,,,,,,,,, -83600,0.28759548,3.740042,,,,,,,,,,,,,,,,, -83700,0.28713962,3.7426066,,,,,,,,,,,,,,,,, -83800,0.28134838,3.7118213,,,,,,,,,,,,,,,,, -83900,0.29786468,3.777373,,,,,,,,,,,,,,,,, -84000,0.30200124,3.73936,,,,,,,,,,,,,,,,, -84100,0.303765,3.7652283,,,,,,,,,,,,,,,,, -84200,0.2782526,3.7753825,,,,,,,,,,,,,,,,, -84300,0.28108943,3.7743812,,,,,,,,,,,,,,,,, -84400,0.3143699,3.8121626,,,,,,,,,,,,,,,,, -84500,0.30991215,3.756332,,,,,,,,,,,,,,,,, -84600,0.28093353,3.7409914,,,,,,,,,,,,,,,,, -84700,0.2961866,3.710986,,,,,,,,,,,,,,,,, -84800,0.30022308,3.73159,,,,,,,,,,,,,,,,, -84900,0.32788965,3.7309182,,,,,,,,,,,,,,,,, -85000,0.30313808,3.7458014,,,,,,,,,,,,,,,,, -85091,,,0.6842333674430847,1.6667418479919434,34.231418554266256,0.6870714426040649,1.6234965324401855,30.22605562863886,3000.0,0.7035732865333557,1.5303460359573364,30.34697045957415,3003.0,29428.2906563282,48280.69108605385,29428.2906563282,18848.596970558167,1.0820410251617432,0.0 -85100,0.2962749,3.80494,,,,,,,,,,,,,,,,, -85200,0.30367997,3.7091718,,,,,,,,,,,,,,,,, -85300,0.3246774,3.7544532,,,,,,,,,,,,,,,,, -85400,0.29384324,3.694162,,,,,,,,,,,,,,,,, -85500,0.29514748,3.7381241,,,,,,,,,,,,,,,,, -85600,0.28825212,3.7140098,,,,,,,,,,,,,,,,, -85700,0.29366186,3.7735639,,,,,,,,,,,,,,,,, -85800,0.30349323,3.8212008,,,,,,,,,,,,,,,,, -85900,0.2969151,3.75568,,,,,,,,,,,,,,,,, -86000,0.30002287,3.7572532,,,,,,,,,,,,,,,,, -86100,0.32606518,3.7077618,,,,,,,,,,,,,,,,, -86200,0.3026062,3.7664359,,,,,,,,,,,,,,,,, -86300,0.30108595,3.7441368,,,,,,,,,,,,,,,,, -86400,0.29951775,3.7782135,,,,,,,,,,,,,,,,, -86500,0.3092722,3.767549,,,,,,,,,,,,,,,,, -86600,0.31610197,3.7819154,,,,,,,,,,,,,,,,, -86700,0.3106381,3.7416723,,,,,,,,,,,,,,,,, -86800,0.30374533,3.7561243,,,,,,,,,,,,,,,,, -86900,0.2883497,3.680769,,,,,,,,,,,,,,,,, -87000,0.31328636,3.8022218,,,,,,,,,,,,,,,,, -87100,0.39453122,3.7367575,,,,,,,,,,,,,,,,, -87200,0.31318182,3.7471802,,,,,,,,,,,,,,,,, -87300,0.30651832,3.7447405,,,,,,,,,,,,,,,,, -87400,0.30320722,3.7107894,,,,,,,,,,,,,,,,, -87500,0.28966475,3.7359667,,,,,,,,,,,,,,,,, -87522,,,0.6859403252601624,1.6498194932937622,34.449619726829866,0.6880509853363037,1.620361566543579,30.61483554539125,3000.0,0.7043286561965942,1.5281420946121216,30.191210875242326,3003.0,30268.27372980117,49588.71892952919,30268.27372980117,19316.526229143143,1.1193876266479492,0.0 -87600,0.28934276,3.7471464,,,,,,,,,,,,,,,,, -87700,0.29838347,3.7227046,,,,,,,,,,,,,,,,, -87800,0.3197896,3.723797,,,,,,,,,,,,,,,,, -87900,0.30957294,3.7117267,,,,,,,,,,,,,,,,, -88000,0.30267382,3.739151,,,,,,,,,,,,,,,,, -88100,0.30483755,3.7112186,,,,,,,,,,,,,,,,, -88200,0.31488582,3.7775292,,,,,,,,,,,,,,,,, -88300,0.3042793,3.7458074,,,,,,,,,,,,,,,,, -88400,0.30742812,3.753329,,,,,,,,,,,,,,,,, -88500,0.30556193,3.7531924,,,,,,,,,,,,,,,,, -88600,0.30999196,3.7346137,,,,,,,,,,,,,,,,, -88700,0.30001503,3.7460644,,,,,,,,,,,,,,,,, -88800,0.32020137,3.8243353,,,,,,,,,,,,,,,,, -88900,0.30970854,3.7164786,,,,,,,,,,,,,,,,, -89000,0.3081119,3.7142093,,,,,,,,,,,,,,,,, -89100,0.33743432,3.779042,,,,,,,,,,,,,,,,, -89200,0.32632482,3.795112,,,,,,,,,,,,,,,,, -89300,0.31399873,3.776314,,,,,,,,,,,,,,,,, -89400,0.30028397,3.7148664,,,,,,,,,,,,,,,,, -89500,0.2960291,3.7062743,,,,,,,,,,,,,,,,, -89600,0.3316579,3.7068233,,,,,,,,,,,,,,,,, -89700,0.30386066,3.7185857,,,,,,,,,,,,,,,,, -89800,0.2966518,3.7338383,,,,,,,,,,,,,,,,, -89900,0.32444555,3.76174,,,,,,,,,,,,,,,,, -89953,,,0.6895501613616943,1.6185948848724363,35.17824123678239,0.6874062418937683,1.6165112257003784,30.320178766417413,3000.0,0.7040846347808838,1.5222641229629517,30.366536884259617,3003.0,31108.289889335632,50895.24898195267,31108.289889335632,19782.92421078682,1.1572742462158203,0.0 -90000,0.31996167,3.7493486,,,,,,,,,,,,,,,,, -90100,0.3116934,3.7415178,,,,,,,,,,,,,,,,, -90200,0.31232116,3.6591125,,,,,,,,,,,,,,,,, -90300,0.320582,3.7529075,,,,,,,,,,,,,,,,, -90400,0.3039051,3.7275708,,,,,,,,,,,,,,,,, -90500,0.29795784,3.644917,,,,,,,,,,,,,,,,, -90600,0.3022357,3.7398126,,,,,,,,,,,,,,,,, -90700,0.30015293,3.6881588,,,,,,,,,,,,,,,,, -90800,0.3186641,3.7080135,,,,,,,,,,,,,,,,, -90900,0.31365588,3.7674367,,,,,,,,,,,,,,,,, -91000,0.30739498,3.7274756,,,,,,,,,,,,,,,,, -91100,0.34559318,3.7296145,,,,,,,,,,,,,,,,, -91200,0.3074692,3.7986448,,,,,,,,,,,,,,,,, -91300,0.30664653,3.6983132,,,,,,,,,,,,,,,,, -91400,0.30977204,3.702571,,,,,,,,,,,,,,,,, -91500,0.31178516,3.7679424,,,,,,,,,,,,,,,,, -91600,0.3029775,3.7392788,,,,,,,,,,,,,,,,, -91700,0.30235392,3.7387872,,,,,,,,,,,,,,,,, -91800,0.312723,3.7980316,,,,,,,,,,,,,,,,, -91900,0.2986438,3.6968873,,,,,,,,,,,,,,,,, -92000,0.3014035,3.7215383,,,,,,,,,,,,,,,,, -92100,0.31360584,3.6719844,,,,,,,,,,,,,,,,, -92200,0.35653427,3.7188344,,,,,,,,,,,,,,,,, -92300,0.3332208,3.7027035,,,,,,,,,,,,,,,,, -92384,,,0.690657913684845,1.6202352046966553,35.047649936271675,0.6881749629974365,1.61602783203125,30.26145651209316,3000.0,0.7052234411239624,1.5193902254104614,30.455149942129623,3003.0,31948.42696595192,52216.04299545288,31948.42696595192,20263.464618206024,1.1962149143218994,0.0 -92400,0.31126758,3.765258,,,,,,,,,,,,,,,,, -92500,0.32282546,3.7489126,,,,,,,,,,,,,,,,, -92600,0.30945906,3.7437975,,,,,,,,,,,,,,,,, -92700,0.30744466,3.7259119,,,,,,,,,,,,,,,,, -92800,0.2975896,3.6679616,,,,,,,,,,,,,,,,, -92900,0.30338946,3.6833472,,,,,,,,,,,,,,,,, -93000,0.30520976,3.717661,,,,,,,,,,,,,,,,, -93100,0.30647093,3.6872478,,,,,,,,,,,,,,,,, -93200,0.30768463,3.6587844,,,,,,,,,,,,,,,,, -93300,0.3010382,3.6820805,,,,,,,,,,,,,,,,, -93400,0.31437215,3.7133088,,,,,,,,,,,,,,,,, -93500,0.34762526,3.7579393,,,,,,,,,,,,,,,,, -93600,0.32333305,3.7101345,,,,,,,,,,,,,,,,, -93700,0.33321774,3.7116044,,,,,,,,,,,,,,,,, -93800,0.31477,3.7069523,,,,,,,,,,,,,,,,, -93900,0.3206806,3.7584531,,,,,,,,,,,,,,,,, -94000,0.31123164,3.6877751,,,,,,,,,,,,,,,,, -94100,0.3046161,3.7221143,,,,,,,,,,,,,,,,, -94200,0.3123454,3.6962054,,,,,,,,,,,,,,,,, -94300,0.31886086,3.746396,,,,,,,,,,,,,,,,, -94400,0.33659178,3.713657,,,,,,,,,,,,,,,,, -94500,0.3100142,3.74487,,,,,,,,,,,,,,,,, -94600,0.31107917,3.71607,,,,,,,,,,,,,,,,, -94700,0.31416288,3.7198126,,,,,,,,,,,,,,,,, -94800,0.32737324,3.6711285,,,,,,,,,,,,,,,,, -94815,,,0.7034842371940613,1.538518309593201,35.83842989432134,0.6895264983177185,1.606481432914734,30.30097926569033,3000.0,0.7068154215812683,1.5066797733306885,30.62494146013703,3003.0,32788.33744096756,53536.12700033188,32788.33744096756,20743.521106243134,1.2359349727630615,0.0 -94900,0.30064377,3.6948814,,,,,,,,,,,,,,,,, -95000,0.31551254,3.7188003,,,,,,,,,,,,,,,,, -95100,0.31076908,3.6976852,,,,,,,,,,,,,,,,, -95200,0.3267456,3.746215,,,,,,,,,,,,,,,,, -95300,0.31150016,3.7382398,,,,,,,,,,,,,,,,, -95400,0.3089498,3.6348138,,,,,,,,,,,,,,,,, -95500,0.31974724,3.756123,,,,,,,,,,,,,,,,, -95600,0.32315564,3.7163563,,,,,,,,,,,,,,,,, -95700,0.310083,3.6822696,,,,,,,,,,,,,,,,, -95800,0.30565083,3.6724486,,,,,,,,,,,,,,,,, -95900,0.3293212,3.6828554,,,,,,,,,,,,,,,,, -96000,0.32334095,3.725687,,,,,,,,,,,,,,,,, -96100,0.31151146,3.7162893,,,,,,,,,,,,,,,,, -96200,0.3165405,3.6897914,,,,,,,,,,,,,,,,, -96300,0.316498,3.757792,,,,,,,,,,,,,,,,, -96400,0.3515465,3.696914,,,,,,,,,,,,,,,,, -96500,0.32437882,3.728198,,,,,,,,,,,,,,,,, -96600,0.31328294,3.734308,,,,,,,,,,,,,,,,, -96700,0.32203412,3.6709666,,,,,,,,,,,,,,,,, -96800,0.33705768,3.7540898,,,,,,,,,,,,,,,,, -96900,0.33433405,3.6959443,,,,,,,,,,,,,,,,, -97000,0.3383533,3.7366998,,,,,,,,,,,,,,,,, -97100,0.34499758,3.6491787,,,,,,,,,,,,,,,,, -97200,0.30867592,3.7034924,,,,,,,,,,,,,,,,, -97246,,,0.6954914927482605,1.5836246013641355,35.719509026518935,0.6894396543502808,1.608036756515503,30.44411971751765,3000.0,0.706466794013977,1.511273980140686,30.46507838062227,3003.0,33628.37856912613,54841.610292196274,33628.37856912613,21208.848113775253,1.274146318435669,0.0 -97300,0.32413062,3.7459931,,,,,,,,,,,,,,,,, -97400,0.32242966,3.6797557,,,,,,,,,,,,,,,,, -97500,0.3517036,3.7659876,,,,,,,,,,,,,,,,, -97600,0.32394063,3.692962,,,,,,,,,,,,,,,,, -97700,0.30677247,3.694042,,,,,,,,,,,,,,,,, -97800,0.32942617,3.6663013,,,,,,,,,,,,,,,,, -97900,0.32069942,3.7249477,,,,,,,,,,,,,,,,, -98000,0.33600155,3.728711,,,,,,,,,,,,,,,,, -98100,0.32845905,3.6756606,,,,,,,,,,,,,,,,, -98200,0.33359572,3.7251606,,,,,,,,,,,,,,,,, -98300,0.33351636,3.7046733,,,,,,,,,,,,,,,,, -98400,0.3209294,3.6905258,,,,,,,,,,,,,,,,, -98500,0.31743318,3.665119,,,,,,,,,,,,,,,,, -98600,0.33945528,3.6805463,,,,,,,,,,,,,,,,, -98700,0.32295984,3.6413414,,,,,,,,,,,,,,,,, -98800,0.37340966,3.7509067,,,,,,,,,,,,,,,,, -98900,0.31863245,3.6705263,,,,,,,,,,,,,,,,, -99000,0.33408955,3.6855993,,,,,,,,,,,,,,,,, -99100,0.32499614,3.7329342,,,,,,,,,,,,,,,,, -99200,0.31970596,3.652099,,,,,,,,,,,,,,,,, -99300,0.33021608,3.7015593,,,,,,,,,,,,,,,,, -99400,0.32906526,3.6812685,,,,,,,,,,,,,,,,, -99500,0.33930975,3.69895,,,,,,,,,,,,,,,,, -99600,0.3422304,3.6571407,,,,,,,,,,,,,,,,, -99677,,,0.6979652643203735,1.5806125402450562,35.63073779555025,0.6885221600532532,1.6104274988174438,30.35029837043829,3000.0,0.7069316506385803,1.5114892721176147,30.37346591024028,3003.0,34468.591000556946,56148.71933174133,34468.591000556946,21675.62931656837,1.3121390342712402,0.0 -99700,0.32447314,3.6883411,,,,,,,,,,,,,,,,, -99800,0.32518527,3.6818373,,,,,,,,,,,,,,,,, -99900,0.3452013,3.7225528,,,,,,,,,,,,,,,,, -100000,0.35124114,3.7306118,,,,,,,,,,,,,,,,, -100100,0.3401001,3.686676,,,,,,,,,,,,,,,,, -100200,0.34545636,3.6736536,,,,,,,,,,,,,,,,, -100300,0.31961384,3.6854157,,,,,,,,,,,,,,,,, -100400,0.3328508,3.6842449,,,,,,,,,,,,,,,,, -100500,0.33654258,3.725154,,,,,,,,,,,,,,,,, -100600,0.33468407,3.6866784,,,,,,,,,,,,,,,,, -100700,0.3533006,3.7106338,,,,,,,,,,,,,,,,, -100800,0.331942,3.6467936,,,,,,,,,,,,,,,,, -100900,0.3402769,3.6712487,,,,,,,,,,,,,,,,, -101000,0.34187505,3.7022147,,,,,,,,,,,,,,,,, -101100,0.3539148,3.6580138,,,,,,,,,,,,,,,,, -101200,0.3382102,3.712348,,,,,,,,,,,,,,,,, -101300,0.3302108,3.7534695,,,,,,,,,,,,,,,,, -101400,0.34662396,3.6888595,,,,,,,,,,,,,,,,, -101500,0.33866593,3.7241418,,,,,,,,,,,,,,,,, -101600,0.36993843,3.7648013,,,,,,,,,,,,,,,,, -101700,0.32453033,3.6911414,,,,,,,,,,,,,,,,, -101800,0.36207655,3.7046452,,,,,,,,,,,,,,,,, -101900,0.3525241,3.7133236,,,,,,,,,,,,,,,,, -102000,0.3076253,3.6700885,,,,,,,,,,,,,,,,, -102100,0.33125797,3.630585,,,,,,,,,,,,,,,,, -102108,,,0.7036283016204834,1.5418643951416016,36.0726587683943,0.689315676689148,1.6059999465942385,30.482721601595376,3000.0,0.7062808871269226,1.5039721727371216,30.5054449175256,3003.0,35308.638811826706,57451.48646616936,35308.638811826706,22138.22982096672,1.3539230823516846,0.0 -102200,0.335704,3.685853,,,,,,,,,,,,,,,,, -102300,0.3417981,3.6770246,,,,,,,,,,,,,,,,, -102400,0.33046255,3.6402109,,,,,,,,,,,,,,,,, -102500,0.3328169,3.7147024,,,,,,,,,,,,,,,,, -102600,0.33556038,3.6786354,,,,,,,,,,,,,,,,, -102700,0.35358974,3.7640367,,,,,,,,,,,,,,,,, -102800,0.33775172,3.7199585,,,,,,,,,,,,,,,,, -102900,0.36235997,3.6584396,,,,,,,,,,,,,,,,, -103000,0.34755632,3.710492,,,,,,,,,,,,,,,,, -103100,0.37232944,3.6959395,,,,,,,,,,,,,,,,, -103200,0.3352865,3.6653178,,,,,,,,,,,,,,,,, -103300,0.35289595,3.700762,,,,,,,,,,,,,,,,, -103400,0.33830348,3.674842,,,,,,,,,,,,,,,,, -103500,0.33789566,3.6685026,,,,,,,,,,,,,,,,, -103600,0.3649174,3.6843684,,,,,,,,,,,,,,,,, -103700,0.33333504,3.682088,,,,,,,,,,,,,,,,, -103800,0.33508202,3.6264596,,,,,,,,,,,,,,,,, -103900,0.35328358,3.6942818,,,,,,,,,,,,,,,,, -104000,0.3457567,3.659106,,,,,,,,,,,,,,,,, -104100,0.32744133,3.69867,,,,,,,,,,,,,,,,, -104200,0.33532676,3.6783032,,,,,,,,,,,,,,,,, -104300,0.3592745,3.6929975,,,,,,,,,,,,,,,,, -104400,0.334744,3.6875398,,,,,,,,,,,,,,,,, -104500,0.36751056,3.7332063,,,,,,,,,,,,,,,,, -104539,,,0.6998471617698669,1.5669548511505127,35.92477397313134,0.6904687881469727,1.6009769439697266,30.45298341634368,3000.0,0.7068619132041931,1.5041028261184692,30.359611628316312,3003.0,36148.625277519226,58767.56322598457,36148.625277519226,22614.202502965927,1.394146203994751,0.0 -104600,0.3323766,3.6459107,,,,,,,,,,,,,,,,, -104700,0.34199104,3.6815019,,,,,,,,,,,,,,,,, -104800,0.35924658,3.6756976,,,,,,,,,,,,,,,,, -104900,0.37465495,3.6703105,,,,,,,,,,,,,,,,, -105000,0.35783443,3.6713722,,,,,,,,,,,,,,,,, -105100,0.34536058,3.6422825,,,,,,,,,,,,,,,,, -105200,0.36418793,3.7092695,,,,,,,,,,,,,,,,, -105300,0.33085066,3.682849,,,,,,,,,,,,,,,,, -105400,0.36211187,3.6476038,,,,,,,,,,,,,,,,, -105500,0.36421964,3.657091,,,,,,,,,,,,,,,,, -105600,0.35566324,3.6715393,,,,,,,,,,,,,,,,, -105700,0.3363869,3.6595318,,,,,,,,,,,,,,,,, -105800,0.34951937,3.6715517,,,,,,,,,,,,,,,,, -105900,0.35401636,3.643564,,,,,,,,,,,,,,,,, -106000,0.3532736,3.6049886,,,,,,,,,,,,,,,,, -106100,0.35488725,3.6336582,,,,,,,,,,,,,,,,, -106200,0.34595862,3.694765,,,,,,,,,,,,,,,,, -106300,0.34793174,3.6996508,,,,,,,,,,,,,,,,, -106400,0.35019338,3.6722245,,,,,,,,,,,,,,,,, -106500,0.3478715,3.6483226,,,,,,,,,,,,,,,,, -106600,0.35368428,3.6905127,,,,,,,,,,,,,,,,, -106700,0.3795308,3.697914,,,,,,,,,,,,,,,,, -106800,0.3640311,3.6721869,,,,,,,,,,,,,,,,, -106900,0.34670463,3.6390135,,,,,,,,,,,,,,,,, -106971,,,0.7189033031463623,1.4635096788406372,37.323135864973146,0.6899852156639099,1.603049635887146,30.498854117375075,3000.0,0.7066411375999451,1.503878474235535,30.1928181417836,3003.0,36988.803647995,60116.20536804199,36988.803647995,23122.549444437027,1.4340860843658447,0.0 -107000,0.36458257,3.7078416,,,,,,,,,,,,,,,,, -107100,0.36706692,3.68078,,,,,,,,,,,,,,,,, -107200,0.3609865,3.6919322,,,,,,,,,,,,,,,,, -107300,0.35215858,3.659715,,,,,,,,,,,,,,,,, -107400,0.35868692,3.6767294,,,,,,,,,,,,,,,,, -107500,0.3490496,3.617481,,,,,,,,,,,,,,,,, -107600,0.34233242,3.634687,,,,,,,,,,,,,,,,, -107700,0.33812502,3.632317,,,,,,,,,,,,,,,,, -107800,0.34874663,3.6586704,,,,,,,,,,,,,,,,, -107900,0.48495176,3.6134486,,,,,,,,,,,,,,,,, -108000,0.360994,3.6711082,,,,,,,,,,,,,,,,, -108100,0.3511833,3.6628876,,,,,,,,,,,,,,,,, -108200,0.35835096,3.6657836,,,,,,,,,,,,,,,,, -108300,0.39259633,3.6957092,,,,,,,,,,,,,,,,, -108400,0.3558148,3.6425045,,,,,,,,,,,,,,,,, -108500,0.3548135,3.6492352,,,,,,,,,,,,,,,,, -108600,0.3641781,3.6844873,,,,,,,,,,,,,,,,, -108700,0.37235248,3.6771863,,,,,,,,,,,,,,,,, -108800,0.34498617,3.635211,,,,,,,,,,,,,,,,, -108900,0.36776236,3.622153,,,,,,,,,,,,,,,,, -109000,0.3508075,3.6717005,,,,,,,,,,,,,,,,, -109100,0.35691023,3.6162596,,,,,,,,,,,,,,,,, -109200,0.36695,3.6788833,,,,,,,,,,,,,,,,, -109300,0.37198657,3.6324296,,,,,,,,,,,,,,,,, -109400,0.35037208,3.6356592,,,,,,,,,,,,,,,,, -109401,,,0.7105783224105835,1.5062320232391355,36.803288837604285,0.6902704238891602,1.5982686281204224,30.68162179003621,3000.0,0.7091511487960815,1.4960204362869265,30.79524511078754,3003.0,37828.95445275307,61403.85651016235,37828.95445275307,23569.930787324905,1.4757776260375977,0.0 -109500,0.38717008,3.701113,,,,,,,,,,,,,,,,, -109600,0.3819264,3.6036491,,,,,,,,,,,,,,,,, -109700,0.34881935,3.6689045,,,,,,,,,,,,,,,,, -109800,0.3741804,3.5962403,,,,,,,,,,,,,,,,, -109900,0.3904566,3.6755302,,,,,,,,,,,,,,,,, -110000,0.3504654,3.6500762,,,,,,,,,,,,,,,,, -110100,0.3772677,3.6674297,,,,,,,,,,,,,,,,, -110200,0.35590938,3.6793416,,,,,,,,,,,,,,,,, -110300,0.35137904,3.63783,,,,,,,,,,,,,,,,, -110400,0.37096173,3.7209737,,,,,,,,,,,,,,,,, -110500,0.3757301,3.6346416,,,,,,,,,,,,,,,,, -110600,0.37203994,3.6037056,,,,,,,,,,,,,,,,, -110700,0.37709203,3.6925774,,,,,,,,,,,,,,,,, -110800,0.36697894,3.6821501,,,,,,,,,,,,,,,,, -110900,0.34165478,3.7000248,,,,,,,,,,,,,,,,, -111000,0.36451167,3.6516063,,,,,,,,,,,,,,,,, -111100,0.3597545,3.6662865,,,,,,,,,,,,,,,,, -111200,0.3661916,3.611594,,,,,,,,,,,,,,,,, -111300,0.37409362,3.6568675,,,,,,,,,,,,,,,,, -111400,0.36310726,3.67937,,,,,,,,,,,,,,,,, -111500,0.3676191,3.6784036,,,,,,,,,,,,,,,,, -111600,0.37864885,3.6522932,,,,,,,,,,,,,,,,, -111700,0.3633692,3.6388872,,,,,,,,,,,,,,,,, -111800,0.361307,3.6265645,,,,,,,,,,,,,,,,, -111832,,,0.7118147611618042,1.5033549070358276,36.46637331018224,0.6907911896705627,1.6010395288467407,30.68753003150424,3000.0,0.7090930342674255,1.5006964206695557,30.644621518964943,3003.0,38669.03541469574,62706.950540065765,38669.03541469574,24032.82445716858,1.5178844928741455,0.0 -111900,0.37387395,3.6875443,,,,,,,,,,,,,,,,, -112000,0.37438098,3.6760602,,,,,,,,,,,,,,,,, -112100,0.3625453,3.6113002,,,,,,,,,,,,,,,,, -112200,0.40542346,3.6113608,,,,,,,,,,,,,,,,, -112300,0.35471684,3.6221108,,,,,,,,,,,,,,,,, -112400,0.38005593,3.611669,,,,,,,,,,,,,,,,, -112500,0.3878924,3.6455283,,,,,,,,,,,,,,,,, -112600,0.38721824,3.6551023,,,,,,,,,,,,,,,,, -112700,0.3757948,3.5888376,,,,,,,,,,,,,,,,, -112800,0.36608446,3.629974,,,,,,,,,,,,,,,,, -112900,0.3579273,3.6249511,,,,,,,,,,,,,,,,, -113000,0.34512764,3.5732093,,,,,,,,,,,,,,,,, -113100,0.3731253,3.5833614,,,,,,,,,,,,,,,,, -113200,0.34623173,3.5811234,,,,,,,,,,,,,,,,, -113300,0.36788246,3.6567872,,,,,,,,,,,,,,,,, -113400,0.37851155,3.631304,,,,,,,,,,,,,,,,, -113500,0.36882573,3.6448767,,,,,,,,,,,,,,,,, -113600,0.3843146,3.6304812,,,,,,,,,,,,,,,,, -113700,0.38006154,3.6443303,,,,,,,,,,,,,,,,, -113800,0.37287268,3.618339,,,,,,,,,,,,,,,,, -113900,0.37496445,3.610067,,,,,,,,,,,,,,,,, -114000,0.3678785,3.6339676,,,,,,,,,,,,,,,,, -114100,0.35669127,3.6304607,,,,,,,,,,,,,,,,, -114200,0.37912717,3.657519,,,,,,,,,,,,,,,,, -114263,,,0.7177315354347229,1.4715121984481812,37.12487177143348,0.6904315948486328,1.597692847251892,30.83881345974083,3000.0,0.708512008190155,1.4953467845916748,30.55669261861972,3003.0,39509.11229848862,63997.05086636543,39509.11229848862,24482.726551771164,1.5615882873535156,0.0 -114300,0.3801956,3.6495736,,,,,,,,,,,,,,,,, -114400,0.3728872,3.6316872,,,,,,,,,,,,,,,,, -114500,0.39523724,3.699105,,,,,,,,,,,,,,,,, -114600,0.39406466,3.6272264,,,,,,,,,,,,,,,,, -114700,0.3851451,3.5955477,,,,,,,,,,,,,,,,, -114800,0.36908725,3.6172423,,,,,,,,,,,,,,,,, -114900,0.36154252,3.6286507,,,,,,,,,,,,,,,,, -115000,0.3840047,3.6073227,,,,,,,,,,,,,,,,, -115100,0.36410797,3.5963516,,,,,,,,,,,,,,,,, -115200,0.37955385,3.6476674,,,,,,,,,,,,,,,,, -115300,0.37455976,3.6485984,,,,,,,,,,,,,,,,, -115400,0.37588343,3.6528986,,,,,,,,,,,,,,,,, -115500,0.3797652,3.670413,,,,,,,,,,,,,,,,, -115600,0.37119836,3.648059,,,,,,,,,,,,,,,,, -115700,0.37921134,3.6599612,,,,,,,,,,,,,,,,, -115800,0.37151662,3.63088,,,,,,,,,,,,,,,,, -115900,0.3837746,3.6372142,,,,,,,,,,,,,,,,, -116000,0.3806761,3.6342032,,,,,,,,,,,,,,,,, -116100,0.37445596,3.6061032,,,,,,,,,,,,,,,,, -116200,0.36936745,3.610546,,,,,,,,,,,,,,,,, -116300,0.3801614,3.6261423,,,,,,,,,,,,,,,,, -116400,0.36228308,3.5690186,,,,,,,,,,,,,,,,, -116500,0.37995338,3.6279736,,,,,,,,,,,,,,,,, -116600,0.37818685,3.6364107,,,,,,,,,,,,,,,,, -116694,,,0.7157125473022461,1.4814000129699707,36.77640951803491,0.6911507248878479,1.601112961769104,30.63074329345861,3000.0,0.7089884281158447,1.5000479221343994,30.72359865724317,3003.0,40349.132508039474,65300.12974977493,40349.132508039474,24945.666995048523,1.6028475761413574,0.0 -116700,0.36038405,3.581273,,,,,,,,,,,,,,,,, -116800,0.3698786,3.6188645,,,,,,,,,,,,,,,,, -116900,0.38440228,3.6159348,,,,,,,,,,,,,,,,, -117000,0.39226186,3.6455152,,,,,,,,,,,,,,,,, -117100,0.37657458,3.633417,,,,,,,,,,,,,,,,, -117200,0.38751113,3.6189024,,,,,,,,,,,,,,,,, -117300,0.36973992,3.6476047,,,,,,,,,,,,,,,,, -117400,0.38457978,3.6516507,,,,,,,,,,,,,,,,, -117500,0.38256073,3.6644995,,,,,,,,,,,,,,,,, -117600,0.37333864,3.5847747,,,,,,,,,,,,,,,,, -117700,0.37494695,3.5920398,,,,,,,,,,,,,,,,, -117800,0.40299332,3.6122942,,,,,,,,,,,,,,,,, -117900,0.36363086,3.6161127,,,,,,,,,,,,,,,,, -118000,0.37572294,3.5594966,,,,,,,,,,,,,,,,, -118100,0.37027273,3.6274335,,,,,,,,,,,,,,,,, -118200,0.37263894,3.5873106,,,,,,,,,,,,,,,,, -118300,0.37254784,3.6455495,,,,,,,,,,,,,,,,, -118400,0.37589937,3.6199224,,,,,,,,,,,,,,,,, -118500,0.3747837,3.6229446,,,,,,,,,,,,,,,,, -118600,0.37307233,3.5990012,,,,,,,,,,,,,,,,, -118700,0.37366158,3.667134,,,,,,,,,,,,,,,,, -118800,0.38292882,3.5971963,,,,,,,,,,,,,,,,, -118900,0.38112098,3.5866673,,,,,,,,,,,,,,,,, -119000,0.41005754,3.6687193,,,,,,,,,,,,,,,,, -119100,0.37072757,3.608891,,,,,,,,,,,,,,,,, -119125,,,0.7265663146972656,1.425704002380371,37.28005875910896,0.6910267472267151,1.5974948406219482,30.580480746826456,3000.0,0.7088141441345215,1.4978607892990112,30.61001220758412,3003.0,41189.27968287468,66606.66089344025,41189.27968287468,25411.930948019028,1.6457676887512207,0.0 -119200,0.39040095,3.6925652,,,,,,,,,,,,,,,,, -119300,0.38218698,3.6071799,,,,,,,,,,,,,,,,, -119400,0.39214063,3.6330953,,,,,,,,,,,,,,,,, -119500,0.38956597,3.63371,,,,,,,,,,,,,,,,, -119600,0.38718593,3.6303258,,,,,,,,,,,,,,,,, -119700,0.38917813,3.647605,,,,,,,,,,,,,,,,, -119800,0.38041714,3.5763726,,,,,,,,,,,,,,,,, -119900,0.37151104,3.6321058,,,,,,,,,,,,,,,,, -120000,0.37839478,3.6156178,,,,,,,,,,,,,,,,, -120100,0.39038026,3.639746,,,,,,,,,,,,,,,,, -120200,0.3969586,3.6040037,,,,,,,,,,,,,,,,, -120300,0.39574692,3.6588972,,,,,,,,,,,,,,,,, -120400,0.3966397,3.6680448,,,,,,,,,,,,,,,,, -120500,0.3770046,3.6341386,,,,,,,,,,,,,,,,, -120600,0.39158866,3.67403,,,,,,,,,,,,,,,,, -120700,0.38513523,3.6094103,,,,,,,,,,,,,,,,, -120800,0.39716408,3.608875,,,,,,,,,,,,,,,,, -120900,0.39058122,3.6306603,,,,,,,,,,,,,,,,, -121000,0.37797803,3.5823586,,,,,,,,,,,,,,,,, -121100,0.37776786,3.594919,,,,,,,,,,,,,,,,, -121200,0.37424028,3.627607,,,,,,,,,,,,,,,,, -121300,0.37857807,3.5580451,,,,,,,,,,,,,,,,, -121400,0.39586562,3.5981877,,,,,,,,,,,,,,,,, -121500,0.3799434,3.6085024,,,,,,,,,,,,,,,,, -121555,,,0.722062885761261,1.4466272592544556,37.89198746749157,0.6907168030738831,1.5995148420333862,30.64146443986613,3000.0,0.7092208862304688,1.4971085786819458,30.70075004627944,3003.0,42029.3257484436,67912.01016974449,42029.3257484436,25877.113832235336,1.688605785369873,0.0 -121600,0.37169746,3.540628,,,,,,,,,,,,,,,,, -121700,0.3910981,3.6302054,,,,,,,,,,,,,,,,, -121800,0.39325547,3.6638405,,,,,,,,,,,,,,,,, -121900,0.38176885,3.6360116,,,,,,,,,,,,,,,,, -122000,0.39185035,3.5888598,,,,,,,,,,,,,,,,, -122100,0.39035943,3.5559494,,,,,,,,,,,,,,,,, -122200,0.39716637,3.6366625,,,,,,,,,,,,,,,,, -122300,0.39316016,3.5920913,,,,,,,,,,,,,,,,, -122400,0.37585428,3.5784132,,,,,,,,,,,,,,,,, -122500,0.37861517,3.5844204,,,,,,,,,,,,,,,,, -122600,0.3929986,3.6083512,,,,,,,,,,,,,,,,, -122700,0.4061761,3.6340144,,,,,,,,,,,,,,,,, -122800,0.38848323,3.64155,,,,,,,,,,,,,,,,, -122900,0.4012139,3.599442,,,,,,,,,,,,,,,,, -123000,0.38702148,3.653253,,,,,,,,,,,,,,,,, -123100,0.39993954,3.6199381,,,,,,,,,,,,,,,,, -123200,0.36976233,3.5663247,,,,,,,,,,,,,,,,, -123300,0.37334988,3.6173105,,,,,,,,,,,,,,,,, -123400,0.3903119,3.6081226,,,,,,,,,,,,,,,,, -123500,0.40615532,3.580269,,,,,,,,,,,,,,,,, -123600,0.39331478,3.6465127,,,,,,,,,,,,,,,,, -123700,0.4121762,3.6036787,,,,,,,,,,,,,,,,, -123800,0.36573786,3.5883594,,,,,,,,,,,,,,,,, -123900,0.3917992,3.6027482,,,,,,,,,,,,,,,,, -123986,,,0.7214351892471313,1.4518346786499023,37.47774055464697,0.6916838884353638,1.5992367267608645,30.64448680287795,3000.0,0.7090930342674255,1.4954748153686523,30.69339181447463,3003.0,42869.3330552578,69206.40446782112,42869.3330552578,26331.379861593246,1.732133388519287,0.0 -124000,0.37707788,3.560778,,,,,,,,,,,,,,,,, -124100,0.39041287,3.581733,,,,,,,,,,,,,,,,, -124200,0.4037502,3.597618,,,,,,,,,,,,,,,,, -124300,0.39637777,3.5681574,,,,,,,,,,,,,,,,, -124400,0.3951702,3.615988,,,,,,,,,,,,,,,,, -124500,0.38213906,3.6352813,,,,,,,,,,,,,,,,, -124600,0.4039037,3.6208954,,,,,,,,,,,,,,,,, -124700,0.39687932,3.602206,,,,,,,,,,,,,,,,, -124800,0.39613616,3.6320853,,,,,,,,,,,,,,,,, -124900,0.3773153,3.6179924,,,,,,,,,,,,,,,,, -125000,0.38528356,3.6175318,,,,,,,,,,,,,,,,, -125100,0.41250923,3.6008818,,,,,,,,,,,,,,,,, -125200,0.371736,3.5722523,,,,,,,,,,,,,,,,, -125300,0.39288816,3.5954747,,,,,,,,,,,,,,,,, -125400,0.39151996,3.6632102,,,,,,,,,,,,,,,,, -125500,0.3994006,3.6162972,,,,,,,,,,,,,,,,, -125600,0.3994453,3.6342392,,,,,,,,,,,,,,,,, -125700,0.39592648,3.59442,,,,,,,,,,,,,,,,, -125800,0.37069994,3.622578,,,,,,,,,,,,,,,,, -125900,0.39758042,3.615271,,,,,,,,,,,,,,,,, -126000,0.38721398,3.613475,,,,,,,,,,,,,,,,, -126100,0.39442286,3.6316876,,,,,,,,,,,,,,,,, -126200,0.38191906,3.5878484,,,,,,,,,,,,,,,,, -126300,0.39300385,3.6982617,,,,,,,,,,,,,,,,, -126400,0.3958308,3.6355019,,,,,,,,,,,,,,,,, -126417,,,0.7267507910728455,1.4239526987075806,38.03936471702318,0.6922666430473328,1.5971962213516235,30.60053445338147,3000.0,0.7093719244003296,1.4938480854034424,30.82672499037724,3003.0,43709.52652788162,70500.4326248169,43709.52652788162,26785.094694375992,1.7746250629425049,0.0 -126500,0.37519896,3.6058035,,,,,,,,,,,,,,,,, -126600,0.3822817,3.5603447,,,,,,,,,,,,,,,,, -126700,0.37493673,3.553308,,,,,,,,,,,,,,,,, -126800,0.40782204,3.6090887,,,,,,,,,,,,,,,,, -126900,0.37559292,3.5758772,,,,,,,,,,,,,,,,, -127000,0.38680783,3.5642216,,,,,,,,,,,,,,,,, -127100,0.38679698,3.5896807,,,,,,,,,,,,,,,,, -127200,0.38477525,3.5786412,,,,,,,,,,,,,,,,, -127300,0.38411048,3.624871,,,,,,,,,,,,,,,,, -127400,0.3902146,3.5591755,,,,,,,,,,,,,,,,, -127500,0.38017803,3.6455894,,,,,,,,,,,,,,,,, -127600,0.39040273,3.5956814,,,,,,,,,,,,,,,,, -127700,0.393434,3.6151423,,,,,,,,,,,,,,,,, -127800,0.37995663,3.589612,,,,,,,,,,,,,,,,, -127900,0.38586125,3.5859542,,,,,,,,,,,,,,,,, -128000,0.3974409,3.603264,,,,,,,,,,,,,,,,, -128100,0.3932119,3.6348343,,,,,,,,,,,,,,,,, -128200,0.37911344,3.6212783,,,,,,,,,,,,,,,,, -128300,0.40931585,3.6380372,,,,,,,,,,,,,,,,, -128400,0.39018834,3.5922482,,,,,,,,,,,,,,,,, -128500,0.39100996,3.6042478,,,,,,,,,,,,,,,,, -128600,0.393314,3.627145,,,,,,,,,,,,,,,,, -128700,0.40179428,3.6180975,,,,,,,,,,,,,,,,, -128800,0.38689697,3.6257296,,,,,,,,,,,,,,,,, -128848,,,0.7229095101356506,1.4474419355392456,37.66247762215224,0.691745936870575,1.5984151363372805,30.603091854338725,3000.0,0.709511399269104,1.4947388172149658,30.74373078206694,3003.0,44549.70166611672,71797.655200243,44549.70166611672,27242.02351140976,1.816523790359497,0.0 -128900,0.37824595,3.609405,,,,,,,,,,,,,,,,, -129000,0.4113338,3.6131978,,,,,,,,,,,,,,,,, -129100,0.4027914,3.5731833,,,,,,,,,,,,,,,,, -129200,0.3913298,3.6125455,,,,,,,,,,,,,,,,, -129300,0.3788995,3.6041324,,,,,,,,,,,,,,,,, -129400,0.37337586,3.582239,,,,,,,,,,,,,,,,, -129500,0.37952182,3.5746837,,,,,,,,,,,,,,,,, -129600,0.371234,3.5831013,,,,,,,,,,,,,,,,, -129700,0.39928913,3.6158845,,,,,,,,,,,,,,,,, -129800,0.38312706,3.5749824,,,,,,,,,,,,,,,,, -129900,0.37480703,3.6390512,,,,,,,,,,,,,,,,, -130000,0.3771476,3.5609193,,,,,,,,,,,,,,,,, -130100,0.4040164,3.6123645,,,,,,,,,,,,,,,,, -130200,0.3987982,3.6253436,,,,,,,,,,,,,,,,, -130300,0.3988548,3.6549587,,,,,,,,,,,,,,,,, -130400,0.38858262,3.6167448,,,,,,,,,,,,,,,,, -130500,0.39817128,3.6057112,,,,,,,,,,,,,,,,, -130600,0.37625617,3.6083968,,,,,,,,,,,,,,,,, -130700,0.38741785,3.5915453,,,,,,,,,,,,,,,,, -130800,0.3890797,3.6370807,,,,,,,,,,,,,,,,, -130900,0.37617648,3.57924,,,,,,,,,,,,,,,,, -131000,0.3788181,3.572876,,,,,,,,,,,,,,,,, -131100,0.39235863,3.6320007,,,,,,,,,,,,,,,,, -131200,0.4031974,3.6310725,,,,,,,,,,,,,,,,, -131279,,,0.7239293456077576,1.431621551513672,37.89795779107517,0.6918451189994812,1.5985901355743408,30.610609263383143,3000.0,0.7093486785888672,1.4943861961364746,30.861089142384305,3003.0,45389.855187654495,73102.1201004982,45389.855187654495,27706.21432876587,1.859803915023804,0.0 -131300,0.4004829,3.658348,,,,,,,,,,,,,,,,, -131400,0.3835635,3.590221,,,,,,,,,,,,,,,,, -131500,0.3897221,3.6015534,,,,,,,,,,,,,,,,, -131600,0.38250205,3.6039438,,,,,,,,,,,,,,,,, -131700,0.38061723,3.6237903,,,,,,,,,,,,,,,,, -131800,0.40327117,3.6183858,,,,,,,,,,,,,,,,, -131900,0.37584007,3.555269,,,,,,,,,,,,,,,,, -132000,0.3960683,3.631314,,,,,,,,,,,,,,,,, -132100,0.39835817,3.553849,,,,,,,,,,,,,,,,, -132200,0.4138956,3.6193857,,,,,,,,,,,,,,,,, -132300,0.3928324,3.6133978,,,,,,,,,,,,,,,,, -132400,0.39511806,3.6505096,,,,,,,,,,,,,,,,, -132500,0.39116755,3.6309757,,,,,,,,,,,,,,,,, -132600,0.39242828,3.596045,,,,,,,,,,,,,,,,, -132700,0.3897926,3.631445,,,,,,,,,,,,,,,,, -132800,0.3889784,3.652993,,,,,,,,,,,,,,,,, -132900,0.4049117,3.6679752,,,,,,,,,,,,,,,,, -133000,0.3874271,3.6409817,,,,,,,,,,,,,,,,, -133100,0.37114137,3.5883934,,,,,,,,,,,,,,,,, -133200,0.4045326,3.6038606,,,,,,,,,,,,,,,,, -133300,0.40777907,3.6618035,,,,,,,,,,,,,,,,, -133333,,,0.7237716913223267,1.4409384727478027,37.902471942390434,0.691882312297821,1.59796142578125,30.55257145746482,3000.0,0.7095927000045776,1.4941442012786863,30.79392196122609,3003.0,46099.62748599053,74258.77504205704,46099.62748599053,28152.985813856125,1.9054770469665527,0.0 -133333,,,,,,,,,,,,,,46099.627485990524,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/eval_measurements.csv deleted file mode 100644 index b21562f88..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/eval_measurements.csv +++ /dev/null @@ -1,57 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -856.2799527645111,0.0,26.78457856178284,1,0,26.78457856178284,0.0007088489946909,0.0,11.036273956298828,3003,883.0645732879639,0.000615433731582,0.0,11.025583267211914,0.0004835649742744,0.0,11.047277450561523,3000 -1393.9082560539246,0.019158124923706,866.7384278774261,2429,0,866.7384278774261,0.3875080049037933,8.298912293063825,4.218213081359863,3003,2260.7427303791046,0.4166429340839386,14.596829131524013,3.9009313583374015,0.4004042148590088,9.782566022877925,4.036395072937012,3000 -1849.9122297763824,0.0464570522308349,1706.9111173152924,4857,0,1706.9111173152924,0.5460693836212158,19.060328733347863,2.64404845237732,3003,3557.023582458496,0.5418379306793213,24.88816508460163,2.6729602813720703,0.5444941520690918,20.57576643195957,2.6168112754821777,3000 -2303.706892490387,0.0725109577178955,2547.001212835312,7286,0,2547.001212835312,0.5889140963554382,21.840320736237747,2.2054107189178467,3003,4851.010681629181,0.5796564817428589,27.28085169810225,2.280468225479126,0.5861427783966064,23.4061450554374,2.215071201324463,3000 -2734.530050754547,0.0988461971282959,3387.198296308517,9716,0,3387.198296308517,0.6107489466667175,23.53198212123661,2.0003881454467773,3003,6122.1344566345215,0.5948212146759033,28.156699199500373,2.13262677192688,0.606229305267334,24.546534607890653,2.028346538543701,3000 -3398.5560586452484,0.1277399063110351,4227.132491111755,12147,0,4227.132491111755,0.6287490725517273,24.73216670379908,1.8641327619552608,3003,7626.200456619263,0.6030737161636353,29.21319992486941,2.049139976501465,0.6187896132469177,24.31398881379957,1.911129593849182,3000 -3879.895225048065,0.1547043323516845,5067.12614607811,14577,0,5067.12614607811,0.6393702030181885,25.362275219367604,1.772122144699097,3003,8947.63675236702,0.6107993721961975,29.97132197779252,1.964206337928772,0.6309283375740051,26.43848120140468,1.819167971611023,3000 -4321.288250684738,0.1813514232635498,5907.25464463234,17009,0,5907.25464463234,0.6478647589683533,26.263295660080857,1.698652267456055,3003,10229.260805130005,0.6224886775016785,30.31643486549434,1.8965948820114136,0.6383429765701294,26.82364824455908,1.756270408630371,3000 -4787.744311332703,0.2097384929656982,6747.339457988739,19441,0,6747.339457988739,0.6533263921737671,26.58671307572002,1.6645915508270264,3003,11535.906452178957,0.6325414180755615,30.4674639846037,1.800859808921814,0.6434885859489441,27.554292710136,1.7217135429382324,3000 -5273.040806770325,0.2392292022705078,7587.353776454925,21873,0,7587.353776454925,0.6594852209091187,27.107368915175847,1.6314548254013062,3003,12861.3235578537,0.6281017065048218,30.706210179092377,1.841070532798767,0.6485598087310791,27.49852401103104,1.6945736408233645,3000 -5754.607835054398,0.2663815021514892,8427.538682699203,24306,0,8427.538682699203,0.6609842777252197,27.086848862449703,1.61290442943573,3003,14183.1789124012,0.6283778548240662,30.231787337608484,1.834109783172608,0.6494525671005249,27.867396431428062,1.675093650817871,3000 -6209.833786487579,0.2946088314056396,9267.65031027794,26739,0,9267.65031027794,0.661495566368103,27.351616233989287,1.5966767072677612,3003,15478.620971679688,0.6350250840187073,31.10056987152565,1.7813962697982788,0.6521431803703308,27.86187178986665,1.6590827703475952,3000 -6689.792566776276,0.323559045791626,10107.838258981705,29173,0,10107.838258981705,0.6651443839073181,27.611150169070616,1.5824037790298462,3003,16798.87334752083,0.6313155889511108,31.167266657108826,1.8113305568695068,0.6551065444946289,28.01659926041504,1.6476982831954956,3000 -7208.857345581055,0.3532938957214355,10948.034049272535,31607,0,10948.034049272535,0.6655162572860718,27.55012963104404,1.5741702318191528,3003,18158.23977446556,0.6526026129722595,32.13943278256692,1.652001142501831,0.657090425491333,28.228548539202578,1.638078689575195,3000 -7730.468078613281,0.3834218978881836,11787.991804122925,34040,0,11787.991804122925,0.6660507917404175,27.39591630850173,1.565197229385376,3003,19519.91458582878,0.6393951773643494,31.008950206319334,1.768338680267334,0.6554041504859924,28.08600000210978,1.6242969036102295,3000 -8367.384491205215,0.4141757488250732,12628.08743071556,36474,0,12628.08743071556,0.667793869972229,27.544555632291114,1.5579041242599487,3003,20997.03369545937,0.6336399912834167,31.23075734922129,1.79706072807312,0.6571400165557861,28.072348613153583,1.6196331977844238,3000 -8889.13286614418,0.4446568489074707,13468.063895463943,38908,0,13468.063895463943,0.6705943942070007,27.78423681173496,1.5488643646240234,3003,22358.86438035965,0.6437124013900757,31.870111946059616,1.718425989151001,0.6600537896156311,28.30275218783093,1.6126872301101685,3000 -9489.347652196884,0.4772262573242187,14308.293565273283,41342,0,14308.293565273283,0.6712451577186584,27.958816788714334,1.538042068481445,3003,23799.416815519333,0.6403557658195496,31.64339967441196,1.7567963600158691,0.659694254398346,27.765434272920903,1.6073123216629028,3000 -10002.591368198397,0.5125997066497803,15148.395472049711,43776,0,15148.395472049711,0.6745337247848511,27.9198714817124,1.532659888267517,3003,25152.87335944176,0.6402798295021057,31.42158195384738,1.7504782676696775,0.6611325144767761,28.68003646635293,1.5975406169891355,3000 -10524.269515752792,0.5430536270141602,15988.560994148254,46210,0,15988.560994148254,0.6745337247848511,28.18982572482699,1.5249536037445068,3003,26514.822845697403,0.6428989768028259,31.507355253150934,1.7245334386825562,0.6620624661445618,28.495653401907823,1.5930095911026,3000 -11038.159530639648,0.5735483169555664,16828.613729715347,48644,0,16828.613729715347,0.6765324473381042,28.04292709290652,1.513628602027893,3003,27868.87157773972,0.6434541940689087,31.33666911989976,1.7303590774536133,0.6629676222801208,28.31539252944437,1.5866459608078003,3000 -11526.525929927826,0.6058268547058105,17668.511252641678,51077,0,17668.511252641678,0.6761838793754578,28.28753882239032,1.5135340690612793,3003,29197.242948055267,0.6509331464767456,32.02152703215946,1.667337417602539,0.6633147597312927,28.656745396147187,1.5808254480361938,3000 -12091.620062828064,0.637232780456543,18508.527975320816,53511,0,18508.527975320816,0.6800999641418457,28.593444733008027,1.4995527267456057,3003,30602.46071600914,0.6447697877883911,31.75867017336076,1.7171337604522705,0.6656085848808289,28.81990449034723,1.571361422538757,3000 -12760.62174320221,0.6705992221832275,19348.71917390824,55945,0,19348.71917390824,0.678345263004303,28.48742903827465,1.496812343597412,3003,32111.7632894516,0.6402867436408997,31.603690698245167,1.7420943975448608,0.6657822132110596,28.31823448140098,1.5687559843063354,3000 -13260.881618976591,0.7030997276306152,20188.83740305901,58379,0,20188.83740305901,0.6788333058357239,28.461223409180462,1.493443727493286,3003,33452.249553442,0.651101291179657,31.862773994581925,1.6768395900726318,0.6660177707672119,28.684740364153395,1.565895915031433,3000 -13749.946232795715,0.7364275455474854,21028.95820069313,60813,0,21028.95820069313,0.6794724464416504,28.543103089635263,1.485634207725525,3003,34781.54425239563,0.6452680826187134,31.57398436310127,1.7102110385894775,0.6664021611213684,28.68595025894429,1.5667551755905151,3000 -14239.12595629692,0.7697396278381348,21869.133969783783,63247,0,21869.133969783783,0.6820870637893677,28.98144871314173,1.4776993989944458,3003,36111.008915662766,0.6609110236167908,32.737365321346005,1.613291621208191,0.6695019006729126,29.214455853774165,1.551543354988098,3000 -14849.57773900032,0.8056454658508301,22709.169924020767,65681,0,22709.169924020767,0.6832723617553711,28.77781389125732,1.465322732925415,3003,37561.60833859444,0.6530311703681946,32.311872612330355,1.671697735786438,0.6706550121307373,28.82110298512525,1.538474202156067,3000 -15419.075818538666,0.8422815799713135,23549.336899280548,68115,0,23549.336899280548,0.6844460368156433,28.78482850239207,1.4604904651641846,3003,38971.38629126549,0.6517167091369629,32.40043379300699,1.6799516677856443,0.670059859752655,28.84134266910904,1.534948229789734,3000 -15875.903427124023,0.8780360221862793,24389.565573453903,70549,0,24389.565573453903,0.6852594614028931,28.99412723224184,1.4523561000823977,3003,40268.55391001701,0.6591607332229614,32.18171612106351,1.6216074228286743,0.6709278225898743,28.9270734590302,1.5275872945785522,3000 -16340.762417078018,0.9139454364776612,25229.58171772957,72983,0,25229.58171772957,0.6866422891616821,28.961118007307395,1.4460179805755615,3003,41573.54033780098,0.6545958518981934,32.03680376481512,1.6556434631347656,0.6736928224563599,29.4972412252912,1.5203146934509275,3000 -16852.63940834999,0.9512898921966552,26069.81137084961,75417,0,26069.81137084961,0.6866190433502197,29.221898187490247,1.439462661743164,3003,42925.76065564156,0.6752561330795288,33.8632993216732,1.5263938903808594,0.6756518483161926,29.306960178991066,1.5147626399993896,3000 -17378.230769634247,0.9887261390686036,26909.84682536125,77851,0,26909.84682536125,0.6893149614334106,29.199608789390265,1.4282652139663696,3003,44291.5014474392,0.6602402329444885,32.94359332980648,1.6101831197738647,0.6758378744125366,29.88321465163858,1.5022680759429932,3000 -17877.399688482285,1.0267176628112793,27750.07419347763,80285,0,27750.07419347763,0.6895241737365723,29.169380704529548,1.4248106479644775,3003,45631.01221609116,0.6605494022369385,32.74820463146153,1.6196547746658323,0.6773629784584045,29.766035010112983,1.497270107269287,3000 -18438.65266013145,1.062680959701538,28590.17192029953,82719,0,28590.17192029953,0.6920806765556335,29.278406234177744,1.4164384603500366,3003,47032.475782871246,0.6689773201942444,33.41044851396017,1.5642541646957395,0.6781564950942993,29.560270364388217,1.491786003112793,3000 -19057.13529086113,1.1003947257995603,29430.228848457336,85153,0,29430.228848457336,0.6939166784286499,29.56086936865955,1.4057952165603638,3003,48491.12983894348,0.6641597151756287,32.891950379823754,1.5821653604507446,0.679582417011261,29.699337732733976,1.482561111450195,3000 -19621.019548416138,1.1373250484466553,30270.3239107132,87587,0,30270.3239107132,0.6963105201721191,29.990684621811305,1.3956819772720337,3003,49895.22182369232,0.6998754143714905,36.14299179613099,1.4042586088180542,0.6811694502830505,30.155293354746227,1.4776188135147097,3000 -20117.80855345726,1.1739370822906494,31110.369894504547,90021,0,31110.369894504547,0.6959386467933655,29.75166453868224,1.3930952548980713,3003,51232.16949701309,0.6689482927322388,33.551267722025884,1.5577067136764526,0.6799047589302063,29.56897782735325,1.47508442401886,3000 -20629.78096461296,1.2131965160369873,31950.513469696045,92455,0,31950.513469696045,0.6989483833312988,29.880037934469392,1.3823401927947998,3003,52584.40080785751,0.6670649647712708,33.667583219883554,1.56660258769989,0.6828681230545044,30.19248417028501,1.465275764465332,3000 -21136.12323880196,1.2523298263549805,32790.68526005745,94889,0,32790.68526005745,0.6998547315597534,30.19016397136998,1.3732168674468994,3003,53931.0303311348,0.6854655146598816,34.65779037824217,1.4662498235702517,0.6841452717781067,30.1208920740266,1.458268165588379,3000 -21704.13444662094,1.2920780181884766,33630.83154082298,97323,0,33630.83154082298,0.7023648023605347,30.336006678823065,1.3640131950378418,3003,55339.30372548103,0.6764490008354187,34.23678263585544,1.5147420167922974,0.686042308807373,30.389317993276933,1.448639750480652,3000 -22318.83654236793,1.3307271003723145,34470.79104375839,99756,0,34470.79104375839,0.7009587287902832,30.23364767228356,1.360594391822815,3003,56794.07966709137,0.6764749884605408,34.04751952929381,1.5274385213851929,0.6861662864685059,30.30977062684815,1.449126362800598,3000 -22848.193954706192,1.3705039024353027,35310.75461125374,102189,0,35310.75461125374,0.7030736207962036,30.37917219885976,1.3559722900390625,3003,58163.51621007919,0.6875606179237366,34.74172992301118,1.455111384391785,0.6875798106193542,30.38678618421871,1.437204360961914,3000 -23436.27487707138,1.4109792709350586,36150.90387272835,104623,0,36150.90387272835,0.704317033290863,30.614908680396404,1.349063277244568,3003,59591.86266922951,0.6816208958625793,34.384006160637064,1.488276720046997,0.6881005764007568,30.368678371068093,1.4366995096206665,3000 -23986.36045742035,1.4505560398101809,36990.97637438774,107056,0,36990.97637438774,0.7048166990280151,30.54643101468684,1.348414421081543,3003,60982.136293411255,0.6984134316444397,35.94551384243376,1.395855188369751,0.6892660856246948,30.78100347405398,1.4338706731796265,3000 -24518.378514528275,1.4953322410583496,37831.17963075638,109490,0,37831.17963075638,0.7041543126106262,30.61245571439726,1.3388774394989014,3003,62354.47929620743,0.6916165947914124,35.22134886001856,1.4298664331436155,0.6891669034957886,30.790310256380376,1.4241664409637451,3000 -25058.54103994369,1.5352284908294678,38671.1463496685,111923,0,38671.1463496685,0.7065016627311707,30.75464808617365,1.336126446723938,3003,63734.7251303196,0.6884708404541016,35.23403588662617,1.4604493379592896,0.6906672120094299,30.63906031522196,1.4245483875274658,3000 -25613.393503665924,1.575636625289917,39511.106392622,114356,0,39511.106392622,0.7080007195472717,30.880743489149467,1.3302934169769287,3003,65129.65449762344,0.7008587121963501,36.201180058925885,1.3816499710083008,0.6914979219436646,30.521495408385693,1.4168438911437988,3000 -26140.66422200203,1.618021011352539,40351.31870722771,116790,0,40351.31870722771,0.7080704569816589,30.986274615235534,1.329336166381836,3003,66497.25703215599,0.7000380158424377,35.78679107524839,1.39041268825531,0.6912747621536255,30.574397190601577,1.4197064638137815,3000 -26650.726126909256,1.6613097190856934,41191.29754805565,119223,0,41191.29754805565,0.7090930342674255,31.189875956554665,1.3264704942703247,3003,67847.41802787781,0.7113429307937622,36.80936519873941,1.3300620317459106,0.6935933828353882,30.94324378288797,1.412529468536377,3000 -27231.141949653625,1.7042453289031982,42031.31493568421,121656,0,42031.31493568421,0.7095810770988464,30.90680463875628,1.3253158330917358,3003,69267.97033762932,0.707312285900116,36.62262186936308,1.3521544933319092,0.6932833790779114,30.834937831447107,1.41163969039917,3000 -27790.95612001419,1.74674654006958,42871.35903739929,124089,0,42871.35903739929,0.7108128666877747,31.108249855042047,1.3213181495666504,3003,70667.94770431519,0.707473874092102,36.84931955921861,1.3492178916931152,0.6941513419151306,30.96556062109168,1.4105970859527588,3000 -28348.25263595581,1.7892396450042725,43711.43947052956,126522,0,43711.43947052956,0.7102434635162354,31.24458192290439,1.3205872774124146,3003,72065.4435763359,0.7100922465324402,37.03802375742551,1.3384053707122805,0.6936057806015015,30.960405634997983,1.4109114408493042,3000 -28918.01958155632,1.8332068920135496,44551.3946352005,128955,0,44551.3946352005,0.7098715901374817,31.098821084681767,1.3178081512451172,3003,73475.28617930412,0.7126824855804443,36.65160501353665,1.3262940645217896,0.6942629218101501,31.02216745186169,1.4080395698547363,3000 -29487.417108535767,1.877014636993408,45391.56788253784,131389,0,45391.56788253784,0.7102085947990417,31.13512948906848,1.3181660175323486,3003,74884.97759580612,0.712949275970459,36.297809192303134,1.319365382194519,0.6945109367370605,30.947350711661613,1.408849596977234,3000 -30065.29164481163,1.9203734397888184,46062.535388469696,133333,0,46062.535388469696,0.7103132009506226,31.18299976012298,1.3183356523513794,3003,76133.92356038094,0.7099116444587708,37.02394184889104,1.3356733322143555,0.6946597099304199,31.011342101754487,1.4091264009475708,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/measurements.csv deleted file mode 100644 index 9a3d559de..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/measurements.csv +++ /dev/null @@ -1,1392 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.7015142,11.023183,,,,,,,,,,,,,,,,, -1,,,0.000615433731582,11.025583267211914,0.0,0.0004835649742744,11.047277450561523,0.0,3000.0,0.0007088489946909,11.036273956298828,0.0,3003.0,26.78457856178284,883.0645732879639,26.78457856178284,856.2799527645111,0.0,0.0 -100,0.42692253,8.692481,,,,,,,,,,,,,,,,, -200,0.18607579,8.369736,,,,,,,,,,,,,,,,, -300,0.19438364,8.08527,,,,,,,,,,,,,,,,, -400,0.29111674,7.6453896,,,,,,,,,,,,,,,,, -500,0.30855697,7.222339,,,,,,,,,,,,,,,,, -600,0.5080569,6.8993073,,,,,,,,,,,,,,,,, -700,0.6625405,6.7384143,,,,,,,,,,,,,,,,, -800,0.8119606,6.4339743,,,,,,,,,,,,,,,,, -900,0.69327456,6.205782,,,,,,,,,,,,,,,,, -1000,0.5173838,6.000431,,,,,,,,,,,,,,,,, -1100,0.62627447,5.841181,,,,,,,,,,,,,,,,, -1200,0.6351033,5.5299134,,,,,,,,,,,,,,,,, -1300,0.6263112,5.447632,,,,,,,,,,,,,,,,, -1400,0.6886002,5.248096,,,,,,,,,,,,,,,,, -1500,0.9422509,5.1175885,,,,,,,,,,,,,,,,, -1600,1.3387505,4.9875793,,,,,,,,,,,,,,,,, -1700,0.9858228,4.9102535,,,,,,,,,,,,,,,,, -1800,0.94727933,4.6384454,,,,,,,,,,,,,,,,, -1900,0.8536355,4.603952,,,,,,,,,,,,,,,,, -2000,1.6323333,4.5192695,,,,,,,,,,,,,,,,, -2100,1.5345101,4.2285786,,,,,,,,,,,,,,,,, -2200,0.97368217,4.242387,,,,,,,,,,,,,,,,, -2300,0.8115448,4.0468082,,,,,,,,,,,,,,,,, -2400,1.4015605,3.9373755,,,,,,,,,,,,,,,,, -2429,,,0.4166429340839386,3.9009313583374015,14.596829131524013,0.4004042148590088,4.036395072937012,9.782566022877925,3000.0,0.3875080049037933,4.218213081359863,8.298912293063825,3003.0,866.7384278774261,2260.7427303791046,866.7384278774261,1393.9082560539246,0.019158124923706,0.0 -2500,0.7784716,3.8770535,,,,,,,,,,,,,,,,, -2600,0.8921749,3.704095,,,,,,,,,,,,,,,,, -2700,0.7892859,3.6354177,,,,,,,,,,,,,,,,, -2800,0.88596195,3.615497,,,,,,,,,,,,,,,,, -2900,1.0507729,3.4323027,,,,,,,,,,,,,,,,, -3000,0.75206757,3.4537387,,,,,,,,,,,,,,,,, -3100,0.78696305,3.361725,,,,,,,,,,,,,,,,, -3200,0.9403066,3.2946289,,,,,,,,,,,,,,,,, -3300,0.8655783,3.2149477,,,,,,,,,,,,,,,,, -3400,1.1603237,3.2078366,,,,,,,,,,,,,,,,, -3500,0.7514391,3.0338488,,,,,,,,,,,,,,,,, -3600,0.8891195,3.1168432,,,,,,,,,,,,,,,,, -3700,0.745298,2.9967072,,,,,,,,,,,,,,,,, -3800,0.8231009,2.9686878,,,,,,,,,,,,,,,,, -3900,0.83548146,2.9068327,,,,,,,,,,,,,,,,, -4000,0.6356099,2.9499693,,,,,,,,,,,,,,,,, -4100,0.67891884,2.813062,,,,,,,,,,,,,,,,, -4200,0.66937333,2.8950317,,,,,,,,,,,,,,,,, -4300,0.9782236,2.787497,,,,,,,,,,,,,,,,, -4400,0.73520786,2.7141726,,,,,,,,,,,,,,,,, -4500,0.6519702,2.6529982,,,,,,,,,,,,,,,,, -4600,0.62317765,2.7048507,,,,,,,,,,,,,,,,, -4700,0.69741386,2.6351008,,,,,,,,,,,,,,,,, -4800,0.6735407,2.7587805,,,,,,,,,,,,,,,,, -4857,,,0.5418379306793213,2.6729602813720703,24.88816508460163,0.5444941520690918,2.6168112754821777,20.57576643195957,3000.0,0.5460693836212158,2.64404845237732,19.060328733347863,3003.0,1706.9111173152924,3557.023582458496,1706.9111173152924,1849.9122297763824,0.0464570522308349,0.0 -4900,0.635848,2.628738,,,,,,,,,,,,,,,,, -5000,0.64076424,2.587431,,,,,,,,,,,,,,,,, -5100,0.5680269,2.5791938,,,,,,,,,,,,,,,,, -5200,0.6600023,2.6381695,,,,,,,,,,,,,,,,, -5300,0.8055273,2.552442,,,,,,,,,,,,,,,,, -5400,0.7524958,2.5330431,,,,,,,,,,,,,,,,, -5500,0.63772774,2.5549746,,,,,,,,,,,,,,,,, -5600,0.56061655,2.5016768,,,,,,,,,,,,,,,,, -5700,0.6062932,2.4861956,,,,,,,,,,,,,,,,, -5800,0.5828378,2.4555755,,,,,,,,,,,,,,,,, -5900,0.5617139,2.423847,,,,,,,,,,,,,,,,, -6000,0.5623844,2.482595,,,,,,,,,,,,,,,,, -6100,0.7698267,2.526961,,,,,,,,,,,,,,,,, -6200,0.6487868,2.5089865,,,,,,,,,,,,,,,,, -6300,0.50783026,2.4818559,,,,,,,,,,,,,,,,, -6400,0.58369714,2.4089487,,,,,,,,,,,,,,,,, -6500,0.51976657,2.3748896,,,,,,,,,,,,,,,,, -6600,0.4819901,2.3152518,,,,,,,,,,,,,,,,, -6700,0.5018483,2.2752538,,,,,,,,,,,,,,,,, -6800,0.47548997,2.34293,,,,,,,,,,,,,,,,, -6900,0.5828902,2.3361986,,,,,,,,,,,,,,,,, -7000,0.59775996,2.4546862,,,,,,,,,,,,,,,,, -7100,0.48935992,2.277,,,,,,,,,,,,,,,,, -7200,0.45831943,2.2442555,,,,,,,,,,,,,,,,, -7286,,,0.5796564817428589,2.280468225479126,27.28085169810225,0.5861427783966064,2.215071201324463,23.4061450554374,3000.0,0.5889140963554382,2.2054107189178467,21.840320736237747,3003.0,2547.001212835312,4851.010681629181,2547.001212835312,2303.706892490387,0.0725109577178955,0.0 -7300,0.45073968,2.212172,,,,,,,,,,,,,,,,, -7400,0.5061418,2.319963,,,,,,,,,,,,,,,,, -7500,0.47407803,2.2633793,,,,,,,,,,,,,,,,, -7600,0.47782716,2.304555,,,,,,,,,,,,,,,,, -7700,0.43384516,2.2983727,,,,,,,,,,,,,,,,, -7800,0.42240557,2.1388512,,,,,,,,,,,,,,,,, -7900,0.4421249,2.2899134,,,,,,,,,,,,,,,,, -8000,0.43001938,2.287013,,,,,,,,,,,,,,,,, -8100,0.43938494,2.2918212,,,,,,,,,,,,,,,,, -8200,0.43364292,2.2254634,,,,,,,,,,,,,,,,, -8300,0.43777907,2.2078736,,,,,,,,,,,,,,,,, -8400,0.4368944,2.3141956,,,,,,,,,,,,,,,,, -8500,0.4199513,2.2520168,,,,,,,,,,,,,,,,, -8600,0.3778361,2.1630366,,,,,,,,,,,,,,,,, -8700,0.38343552,2.1339664,,,,,,,,,,,,,,,,, -8800,0.4246322,2.2441614,,,,,,,,,,,,,,,,, -8900,0.38519454,2.1720533,,,,,,,,,,,,,,,,, -9000,0.36464024,2.292979,,,,,,,,,,,,,,,,, -9100,0.3522099,2.1744528,,,,,,,,,,,,,,,,, -9200,0.4540086,2.3004858,,,,,,,,,,,,,,,,, -9300,0.35696915,2.166578,,,,,,,,,,,,,,,,, -9400,0.4159555,2.244906,,,,,,,,,,,,,,,,, -9500,0.36463317,2.1438673,,,,,,,,,,,,,,,,, -9600,0.41350493,2.1686747,,,,,,,,,,,,,,,,, -9700,0.33105302,2.1014216,,,,,,,,,,,,,,,,, -9716,,,0.5948212146759033,2.13262677192688,28.156699199500373,0.606229305267334,2.028346538543701,24.546534607890653,3000.0,0.6107489466667175,2.0003881454467773,23.53198212123661,3003.0,3387.198296308517,6122.1344566345215,3387.198296308517,2734.530050754547,0.0988461971282959,0.0 -9800,0.36159542,2.1452897,,,,,,,,,,,,,,,,, -9900,0.42000327,2.0560408,,,,,,,,,,,,,,,,, -10000,0.35068262,2.1580997,,,,,,,,,,,,,,,,, -10100,0.34595552,2.0819254,,,,,,,,,,,,,,,,, -10200,0.41333374,2.159239,,,,,,,,,,,,,,,,, -10300,0.31468493,2.0656695,,,,,,,,,,,,,,,,, -10400,0.32049018,2.0682352,,,,,,,,,,,,,,,,, -10500,0.30347922,2.1531258,,,,,,,,,,,,,,,,, -10600,0.37868357,2.0456038,,,,,,,,,,,,,,,,, -10700,0.37242466,2.1001394,,,,,,,,,,,,,,,,, -10800,0.31550238,2.1387966,,,,,,,,,,,,,,,,, -10900,0.33503842,2.134248,,,,,,,,,,,,,,,,, -11000,0.29794046,2.1865091,,,,,,,,,,,,,,,,, -11100,0.3260898,2.0323567,,,,,,,,,,,,,,,,, -11200,0.3122974,2.103404,,,,,,,,,,,,,,,,, -11300,0.323886,2.0620158,,,,,,,,,,,,,,,,, -11400,0.28325966,2.022004,,,,,,,,,,,,,,,,, -11500,0.2773462,2.1491234,,,,,,,,,,,,,,,,, -11600,0.33591077,2.1487882,,,,,,,,,,,,,,,,, -11700,0.28919327,2.0075479,,,,,,,,,,,,,,,,, -11800,0.27613878,2.1187341,,,,,,,,,,,,,,,,, -11900,0.27543432,2.0480804,,,,,,,,,,,,,,,,, -12000,0.3000447,2.128397,,,,,,,,,,,,,,,,, -12100,0.30036312,2.0158756,,,,,,,,,,,,,,,,, -12147,,,0.6030737161636353,2.049139976501465,29.21319992486941,0.6187896132469177,1.911129593849182,24.31398881379957,3000.0,0.6287490725517273,1.8641327619552608,24.73216670379908,3003.0,4227.132491111755,7626.200456619263,4227.132491111755,3398.5560586452484,0.1277399063110351,0.0 -12200,0.27839378,2.0635676,,,,,,,,,,,,,,,,, -12300,0.28893715,2.1094363,,,,,,,,,,,,,,,,, -12400,0.27438644,2.0286932,,,,,,,,,,,,,,,,, -12500,0.32345247,2.0404932,,,,,,,,,,,,,,,,, -12600,0.27349365,1.9590448,,,,,,,,,,,,,,,,, -12700,0.27600744,1.9464551,,,,,,,,,,,,,,,,, -12800,0.2713624,2.055991,,,,,,,,,,,,,,,,, -12900,0.26206028,2.0311959,,,,,,,,,,,,,,,,, -13000,0.28733373,1.9944144,,,,,,,,,,,,,,,,, -13100,0.30127856,2.037164,,,,,,,,,,,,,,,,, -13200,0.271014,2.0320687,,,,,,,,,,,,,,,,, -13300,0.29710123,2.0352902,,,,,,,,,,,,,,,,, -13400,0.2640284,2.0632234,,,,,,,,,,,,,,,,, -13500,0.26148963,1.9709587,,,,,,,,,,,,,,,,, -13600,0.35972273,2.004726,,,,,,,,,,,,,,,,, -13700,0.25697133,1.9826914,,,,,,,,,,,,,,,,, -13800,0.26141286,2.0145252,,,,,,,,,,,,,,,,, -13900,0.32201135,2.0886328,,,,,,,,,,,,,,,,, -14000,0.28135675,1.90323,,,,,,,,,,,,,,,,, -14100,0.2652372,1.9775013,,,,,,,,,,,,,,,,, -14200,0.277573,1.9395493,,,,,,,,,,,,,,,,, -14300,0.2747504,1.9627155,,,,,,,,,,,,,,,,, -14400,0.30034298,2.010576,,,,,,,,,,,,,,,,, -14500,0.36536816,1.98604,,,,,,,,,,,,,,,,, -14577,,,0.6107993721961975,1.964206337928772,29.97132197779252,0.6309283375740051,1.819167971611023,26.43848120140468,3000.0,0.6393702030181885,1.772122144699097,25.362275219367604,3003.0,5067.12614607811,8947.63675236702,5067.12614607811,3879.895225048065,0.1547043323516845,0.0 -14600,0.31817663,2.0747092,,,,,,,,,,,,,,,,, -14700,0.26078483,1.946856,,,,,,,,,,,,,,,,, -14800,0.2828422,1.9845266,,,,,,,,,,,,,,,,, -14900,0.30040547,1.9225361,,,,,,,,,,,,,,,,, -15000,0.34539574,1.8687011,,,,,,,,,,,,,,,,, -15100,0.37440193,2.0154805,,,,,,,,,,,,,,,,, -15200,0.28160623,1.9879327,,,,,,,,,,,,,,,,, -15300,0.30031165,1.9517035,,,,,,,,,,,,,,,,, -15400,0.2651583,1.8651973,,,,,,,,,,,,,,,,, -15500,0.28859374,1.9867579,,,,,,,,,,,,,,,,, -15600,0.28163552,1.9547491,,,,,,,,,,,,,,,,, -15700,0.295284,1.9218662,,,,,,,,,,,,,,,,, -15800,0.27480882,1.8617295,,,,,,,,,,,,,,,,, -15900,0.31918624,1.9709146,,,,,,,,,,,,,,,,, -16000,0.3244585,1.9317752,,,,,,,,,,,,,,,,, -16100,0.32192567,2.0174859,,,,,,,,,,,,,,,,, -16200,0.3274606,1.8963122,,,,,,,,,,,,,,,,, -16300,0.31215936,1.8881512,,,,,,,,,,,,,,,,, -16400,0.33451623,1.8343753,,,,,,,,,,,,,,,,, -16500,0.33887842,1.9036691,,,,,,,,,,,,,,,,, -16600,0.39059484,1.910025,,,,,,,,,,,,,,,,, -16700,0.29931152,1.8931541,,,,,,,,,,,,,,,,, -16800,0.3086097,1.9341519,,,,,,,,,,,,,,,,, -16900,0.39607197,1.9149657,,,,,,,,,,,,,,,,, -17000,0.2862238,1.9152881,,,,,,,,,,,,,,,,, -17009,,,0.6224886775016785,1.8965948820114136,30.31643486549434,0.6383429765701294,1.756270408630371,26.82364824455908,3000.0,0.6478647589683533,1.698652267456055,26.263295660080857,3003.0,5907.25464463234,10229.260805130005,5907.25464463234,4321.288250684738,0.1813514232635498,0.0 -17100,0.3263044,1.939114,,,,,,,,,,,,,,,,, -17200,0.28770962,1.8734915,,,,,,,,,,,,,,,,, -17300,0.30855793,1.9405348,,,,,,,,,,,,,,,,, -17400,0.33564502,1.8786021,,,,,,,,,,,,,,,,, -17500,0.32777938,1.9104378,,,,,,,,,,,,,,,,, -17600,0.33643392,1.9573613,,,,,,,,,,,,,,,,, -17700,0.35848036,1.8889025,,,,,,,,,,,,,,,,, -17800,0.30393097,1.8752978,,,,,,,,,,,,,,,,, -17900,0.30013555,1.9208,,,,,,,,,,,,,,,,, -18000,0.3528991,1.8954772,,,,,,,,,,,,,,,,, -18100,0.6221987,1.9872416,,,,,,,,,,,,,,,,, -18200,0.46055216,1.9901092,,,,,,,,,,,,,,,,, -18300,0.35141882,1.9753728,,,,,,,,,,,,,,,,, -18400,0.33859226,1.935141,,,,,,,,,,,,,,,,, -18500,0.3189734,1.9274582,,,,,,,,,,,,,,,,, -18600,0.31860486,1.8545583,,,,,,,,,,,,,,,,, -18700,0.33246076,1.8528558,,,,,,,,,,,,,,,,, -18800,0.32170197,1.9532484,,,,,,,,,,,,,,,,, -18900,0.3032117,1.7739955,,,,,,,,,,,,,,,,, -19000,0.32762623,1.8446469,,,,,,,,,,,,,,,,, -19100,0.33503735,1.8888397,,,,,,,,,,,,,,,,, -19200,0.33744845,1.8695297,,,,,,,,,,,,,,,,, -19300,0.3579152,1.8523552,,,,,,,,,,,,,,,,, -19400,0.5067622,1.90507,,,,,,,,,,,,,,,,, -19441,,,0.6325414180755615,1.800859808921814,30.4674639846037,0.6434885859489441,1.7217135429382324,27.554292710136,3000.0,0.6533263921737671,1.6645915508270264,26.58671307572002,3003.0,6747.339457988739,11535.906452178957,6747.339457988739,4787.744311332703,0.2097384929656982,0.0 -19500,0.34777305,1.873581,,,,,,,,,,,,,,,,, -19600,0.37133503,1.845381,,,,,,,,,,,,,,,,, -19700,0.34726116,1.8983967,,,,,,,,,,,,,,,,, -19800,0.33711314,1.9716662,,,,,,,,,,,,,,,,, -19900,0.36074924,1.8850617,,,,,,,,,,,,,,,,, -20000,0.3583095,1.8883915,,,,,,,,,,,,,,,,, -20100,0.36961538,1.8265404,,,,,,,,,,,,,,,,, -20200,0.3372086,1.8382267,,,,,,,,,,,,,,,,, -20300,0.35417905,1.8748183,,,,,,,,,,,,,,,,, -20400,0.33154765,1.9260119,,,,,,,,,,,,,,,,, -20500,0.36361614,1.8847044,,,,,,,,,,,,,,,,, -20600,0.34436187,1.914485,,,,,,,,,,,,,,,,, -20700,0.37265432,1.8629951,,,,,,,,,,,,,,,,, -20800,0.44619069,1.8373575,,,,,,,,,,,,,,,,, -20900,0.36940363,1.789417,,,,,,,,,,,,,,,,, -21000,0.4387067,1.8206336,,,,,,,,,,,,,,,,, -21100,0.41688153,1.8109045,,,,,,,,,,,,,,,,, -21200,0.4242643,1.8507069,,,,,,,,,,,,,,,,, -21300,0.53508925,1.8563818,,,,,,,,,,,,,,,,, -21400,0.35807022,1.8624495,,,,,,,,,,,,,,,,, -21500,0.35041544,1.8780318,,,,,,,,,,,,,,,,, -21600,0.38700268,1.8385799,,,,,,,,,,,,,,,,, -21700,0.42074844,1.8084894,,,,,,,,,,,,,,,,, -21800,0.35404083,1.8437409,,,,,,,,,,,,,,,,, -21873,,,0.6281017065048218,1.841070532798767,30.706210179092377,0.6485598087310791,1.6945736408233645,27.49852401103104,3000.0,0.6594852209091187,1.6314548254013062,27.107368915175847,3003.0,7587.353776454925,12861.3235578537,7587.353776454925,5273.040806770325,0.2392292022705078,0.0 -21900,0.3906864,1.931081,,,,,,,,,,,,,,,,, -22000,0.34784842,1.9019667,,,,,,,,,,,,,,,,, -22100,0.5152899,1.8952776,,,,,,,,,,,,,,,,, -22200,0.42324916,1.9424174,,,,,,,,,,,,,,,,, -22300,0.4005329,1.754122,,,,,,,,,,,,,,,,, -22400,0.4070527,1.9192075,,,,,,,,,,,,,,,,, -22500,0.37050232,1.8896105,,,,,,,,,,,,,,,,, -22600,0.34357563,1.8101082,,,,,,,,,,,,,,,,, -22700,0.37664995,1.879534,,,,,,,,,,,,,,,,, -22800,0.33814228,1.830496,,,,,,,,,,,,,,,,, -22900,0.37825996,1.8481679,,,,,,,,,,,,,,,,, -23000,0.43674046,1.938857,,,,,,,,,,,,,,,,, -23100,0.3590109,1.8746668,,,,,,,,,,,,,,,,, -23200,0.40441048,1.8177568,,,,,,,,,,,,,,,,, -23300,0.3647515,1.8424037,,,,,,,,,,,,,,,,, -23400,0.34213492,1.8397263,,,,,,,,,,,,,,,,, -23500,0.51178986,1.817394,,,,,,,,,,,,,,,,, -23600,0.40719143,1.8269076,,,,,,,,,,,,,,,,, -23700,0.3814802,1.7866325,,,,,,,,,,,,,,,,, -23800,0.53029174,1.8606688,,,,,,,,,,,,,,,,, -23900,0.3882449,1.9409317,,,,,,,,,,,,,,,,, -24000,0.4064712,1.7984389,,,,,,,,,,,,,,,,, -24100,0.39314264,1.8790857,,,,,,,,,,,,,,,,, -24200,0.370475,1.883836,,,,,,,,,,,,,,,,, -24300,0.38070157,1.8652772,,,,,,,,,,,,,,,,, -24306,,,0.6283778548240662,1.834109783172608,30.231787337608484,0.6494525671005249,1.675093650817871,27.867396431428062,3000.0,0.6609842777252197,1.61290442943573,27.086848862449703,3003.0,8427.538682699203,14183.1789124012,8427.538682699203,5754.607835054398,0.2663815021514892,0.0 -24400,0.40038842,1.857336,,,,,,,,,,,,,,,,, -24500,0.38800877,1.8501917,,,,,,,,,,,,,,,,, -24600,0.41761765,1.8660622,,,,,,,,,,,,,,,,, -24700,0.37129542,1.8420522,,,,,,,,,,,,,,,,, -24800,0.46601674,1.8573704,,,,,,,,,,,,,,,,, -24900,0.48516643,1.800571,,,,,,,,,,,,,,,,, -25000,0.46189567,1.797083,,,,,,,,,,,,,,,,, -25100,0.34157524,1.8160626,,,,,,,,,,,,,,,,, -25200,0.43581158,1.8442007,,,,,,,,,,,,,,,,, -25300,0.4475184,1.9397982,,,,,,,,,,,,,,,,, -25400,0.39675567,1.7999974,,,,,,,,,,,,,,,,, -25500,0.40833005,1.7297152,,,,,,,,,,,,,,,,, -25600,0.3812148,1.7588797,,,,,,,,,,,,,,,,, -25700,0.44831005,1.8642422,,,,,,,,,,,,,,,,, -25800,0.38050863,1.9731007,,,,,,,,,,,,,,,,, -25900,0.36111113,1.7536958,,,,,,,,,,,,,,,,, -26000,0.36714846,1.7458818,,,,,,,,,,,,,,,,, -26100,0.34650564,1.7797726,,,,,,,,,,,,,,,,, -26200,0.39252988,1.8697344,,,,,,,,,,,,,,,,, -26300,0.3786053,1.8061944,,,,,,,,,,,,,,,,, -26400,0.36492887,1.8614223,,,,,,,,,,,,,,,,, -26500,0.38078842,1.8183573,,,,,,,,,,,,,,,,, -26600,0.48163003,1.887845,,,,,,,,,,,,,,,,, -26700,0.40678635,1.8237703,,,,,,,,,,,,,,,,, -26739,,,0.6350250840187073,1.7813962697982788,31.10056987152565,0.6521431803703308,1.6590827703475952,27.86187178986665,3000.0,0.661495566368103,1.5966767072677612,27.351616233989287,3003.0,9267.65031027794,15478.620971679688,9267.65031027794,6209.833786487579,0.2946088314056396,0.0 -26800,0.44882956,1.8319151,,,,,,,,,,,,,,,,, -26900,0.41958058,1.8860464,,,,,,,,,,,,,,,,, -27000,0.40340626,1.7604271,,,,,,,,,,,,,,,,, -27100,0.37534803,1.7754877,,,,,,,,,,,,,,,,, -27200,0.3934247,1.7745041,,,,,,,,,,,,,,,,, -27300,0.46494606,1.8475047,,,,,,,,,,,,,,,,, -27400,0.4487008,1.7301922,,,,,,,,,,,,,,,,, -27500,0.5572955,1.8536336,,,,,,,,,,,,,,,,, -27600,0.39666927,1.824733,,,,,,,,,,,,,,,,, -27700,0.3968517,1.8090485,,,,,,,,,,,,,,,,, -27800,0.44501188,1.8427657,,,,,,,,,,,,,,,,, -27900,0.38882652,1.8067404,,,,,,,,,,,,,,,,, -28000,0.41696143,1.8282082,,,,,,,,,,,,,,,,, -28100,0.39046693,1.808292,,,,,,,,,,,,,,,,, -28200,0.47119316,1.8017027,,,,,,,,,,,,,,,,, -28300,0.40187362,1.7807084,,,,,,,,,,,,,,,,, -28400,0.4425934,1.8619077,,,,,,,,,,,,,,,,, -28500,0.41929388,1.8561118,,,,,,,,,,,,,,,,, -28600,0.38011858,1.8417053,,,,,,,,,,,,,,,,, -28700,0.42578307,1.8487359,,,,,,,,,,,,,,,,, -28800,0.42319024,1.8382146,,,,,,,,,,,,,,,,, -28900,0.36893883,1.8253468,,,,,,,,,,,,,,,,, -29000,0.41875094,1.9565369,,,,,,,,,,,,,,,,, -29100,0.40392616,1.7920642,,,,,,,,,,,,,,,,, -29173,,,0.6313155889511108,1.8113305568695068,31.167266657108826,0.6551065444946289,1.6476982831954956,28.01659926041504,3000.0,0.6651443839073181,1.5824037790298462,27.611150169070616,3003.0,10107.838258981705,16798.87334752083,10107.838258981705,6689.792566776276,0.323559045791626,0.0 -29200,0.41089097,1.7845434,,,,,,,,,,,,,,,,, -29300,0.36248022,1.818374,,,,,,,,,,,,,,,,, -29400,0.3976974,1.810963,,,,,,,,,,,,,,,,, -29500,0.37953517,1.7689064,,,,,,,,,,,,,,,,, -29600,0.4635414,1.7768948,,,,,,,,,,,,,,,,, -29700,0.42295852,1.8538686,,,,,,,,,,,,,,,,, -29800,0.4042689,1.8097799,,,,,,,,,,,,,,,,, -29900,0.38771403,1.7630945,,,,,,,,,,,,,,,,, -30000,0.40493017,1.7780038,,,,,,,,,,,,,,,,, -30100,0.4174881,1.7404794,,,,,,,,,,,,,,,,, -30200,0.42789516,1.8666652,,,,,,,,,,,,,,,,, -30300,0.346873,1.7756311,,,,,,,,,,,,,,,,, -30400,0.41642818,1.7699986,,,,,,,,,,,,,,,,, -30500,0.418456,1.8355223,,,,,,,,,,,,,,,,, -30600,0.39770627,1.8030431,,,,,,,,,,,,,,,,, -30700,0.41152787,1.7729667,,,,,,,,,,,,,,,,, -30800,0.4046717,1.8098047,,,,,,,,,,,,,,,,, -30900,0.38138145,1.6763866,,,,,,,,,,,,,,,,, -31000,0.3961048,1.7797738,,,,,,,,,,,,,,,,, -31100,0.36447823,1.753037,,,,,,,,,,,,,,,,, -31200,0.4207157,1.7518721,,,,,,,,,,,,,,,,, -31300,0.45378575,1.7948337,,,,,,,,,,,,,,,,, -31400,0.4374432,1.7091821,,,,,,,,,,,,,,,,, -31500,0.47717866,1.7784858,,,,,,,,,,,,,,,,, -31600,0.41605043,1.8327178,,,,,,,,,,,,,,,,, -31607,,,0.6526026129722595,1.652001142501831,32.13943278256692,0.657090425491333,1.638078689575195,28.228548539202578,3000.0,0.6655162572860718,1.5741702318191528,27.55012963104404,3003.0,10948.034049272535,18158.23977446556,10948.034049272535,7208.857345581055,0.3532938957214355,0.0 -31700,0.4642934,1.8057576,,,,,,,,,,,,,,,,, -31800,0.41513163,1.831262,,,,,,,,,,,,,,,,, -31900,0.39205155,1.8253258,,,,,,,,,,,,,,,,, -32000,0.37256774,1.7919044,,,,,,,,,,,,,,,,, -32100,0.40875632,1.8566098,,,,,,,,,,,,,,,,, -32200,0.3797019,1.7690573,,,,,,,,,,,,,,,,, -32300,0.3762417,1.7282429,,,,,,,,,,,,,,,,, -32400,0.3963827,1.7740501,,,,,,,,,,,,,,,,, -32500,0.38383558,1.7513679,,,,,,,,,,,,,,,,, -32600,0.46083212,1.7123483,,,,,,,,,,,,,,,,, -32700,0.3630146,1.8332051,,,,,,,,,,,,,,,,, -32800,0.41446304,1.8377379,,,,,,,,,,,,,,,,, -32900,0.3830851,1.8324332,,,,,,,,,,,,,,,,, -33000,0.39916056,1.795532,,,,,,,,,,,,,,,,, -33100,0.42738032,1.8104059,,,,,,,,,,,,,,,,, -33200,0.35857883,1.8142397,,,,,,,,,,,,,,,,, -33300,0.37797552,1.7933751,,,,,,,,,,,,,,,,, -33400,0.3548475,1.8655404,,,,,,,,,,,,,,,,, -33500,0.42643812,1.8163393,,,,,,,,,,,,,,,,, -33600,0.45389336,1.8017919,,,,,,,,,,,,,,,,, -33700,0.41250923,1.7926027,,,,,,,,,,,,,,,,, -33800,0.39709947,1.8086122,,,,,,,,,,,,,,,,, -33900,0.34409568,1.7122567,,,,,,,,,,,,,,,,, -34000,0.36768788,1.809744,,,,,,,,,,,,,,,,, -34040,,,0.6393951773643494,1.768338680267334,31.008950206319334,0.6554041504859924,1.6242969036102295,28.08600000210978,3000.0,0.6660507917404175,1.565197229385376,27.39591630850173,3003.0,11787.991804122925,19519.91458582878,11787.991804122925,7730.468078613281,0.3834218978881836,0.0 -34100,0.38896284,1.816611,,,,,,,,,,,,,,,,, -34200,0.4055213,1.828856,,,,,,,,,,,,,,,,, -34300,0.41088057,1.7374773,,,,,,,,,,,,,,,,, -34400,0.37242645,1.716247,,,,,,,,,,,,,,,,, -34500,0.36854663,1.7820551,,,,,,,,,,,,,,,,, -34600,0.41147313,1.7569444,,,,,,,,,,,,,,,,, -34700,0.38627365,1.7964605,,,,,,,,,,,,,,,,, -34800,0.44429964,1.7574081,,,,,,,,,,,,,,,,, -34900,0.38987157,1.8148487,,,,,,,,,,,,,,,,, -35000,0.3659068,1.7305878,,,,,,,,,,,,,,,,, -35100,0.40943205,1.8924682,,,,,,,,,,,,,,,,, -35200,0.36941725,1.7577486,,,,,,,,,,,,,,,,, -35300,0.44456244,1.8671385,,,,,,,,,,,,,,,,, -35400,0.42970544,1.7987419,,,,,,,,,,,,,,,,, -35500,0.40387926,1.7905753,,,,,,,,,,,,,,,,, -35600,0.4186882,1.728857,,,,,,,,,,,,,,,,, -35700,0.6583237,1.794102,,,,,,,,,,,,,,,,, -35800,0.3836568,1.8104113,,,,,,,,,,,,,,,,, -35900,0.41246307,1.7608322,,,,,,,,,,,,,,,,, -36000,0.4007893,1.8309728,,,,,,,,,,,,,,,,, -36100,0.41320863,1.842893,,,,,,,,,,,,,,,,, -36200,0.4703106,1.77573,,,,,,,,,,,,,,,,, -36300,0.40283617,1.7670666,,,,,,,,,,,,,,,,, -36400,0.39761403,1.8561587,,,,,,,,,,,,,,,,, -36474,,,0.6336399912834167,1.79706072807312,31.23075734922129,0.6571400165557861,1.6196331977844238,28.072348613153583,3000.0,0.667793869972229,1.5579041242599487,27.544555632291114,3003.0,12628.08743071556,20997.03369545937,12628.08743071556,8367.384491205215,0.4141757488250732,0.0 -36500,0.42337123,1.720162,,,,,,,,,,,,,,,,, -36600,0.433915,1.8300108,,,,,,,,,,,,,,,,, -36700,0.39197448,1.7386254,,,,,,,,,,,,,,,,, -36800,0.3970784,1.8186992,,,,,,,,,,,,,,,,, -36900,0.39377186,1.8166546,,,,,,,,,,,,,,,,, -37000,0.41918382,1.7517049,,,,,,,,,,,,,,,,, -37100,0.41293088,1.7465794,,,,,,,,,,,,,,,,, -37200,0.4157354,1.8180017,,,,,,,,,,,,,,,,, -37300,0.46863097,1.7579117,,,,,,,,,,,,,,,,, -37400,0.44017094,1.7850721,,,,,,,,,,,,,,,,, -37500,0.42296952,1.7992399,,,,,,,,,,,,,,,,, -37600,0.45229456,1.7708752,,,,,,,,,,,,,,,,, -37700,0.50368845,1.8274697,,,,,,,,,,,,,,,,, -37800,0.36465725,1.7235966,,,,,,,,,,,,,,,,, -37900,0.39552048,1.8367022,,,,,,,,,,,,,,,,, -38000,0.43356323,1.7564979,,,,,,,,,,,,,,,,, -38100,0.43533647,1.7359143,,,,,,,,,,,,,,,,, -38200,0.36274996,1.7789279,,,,,,,,,,,,,,,,, -38300,0.40537274,1.7577764,,,,,,,,,,,,,,,,, -38400,0.42615092,1.7243077,,,,,,,,,,,,,,,,, -38500,0.3895997,1.7653935,,,,,,,,,,,,,,,,, -38600,0.4404665,1.8202482,,,,,,,,,,,,,,,,, -38700,0.35450223,1.7205415,,,,,,,,,,,,,,,,, -38800,0.4541498,1.7479837,,,,,,,,,,,,,,,,, -38900,0.40258548,1.8276064,,,,,,,,,,,,,,,,, -38908,,,0.6437124013900757,1.718425989151001,31.870111946059616,0.6600537896156311,1.6126872301101685,28.30275218783093,3000.0,0.6705943942070007,1.5488643646240234,27.78423681173496,3003.0,13468.063895463943,22358.86438035965,13468.063895463943,8889.13286614418,0.4446568489074707,0.0 -39000,0.38850448,1.6986485,,,,,,,,,,,,,,,,, -39100,0.4089366,1.7679777,,,,,,,,,,,,,,,,, -39200,0.37422147,1.8048074,,,,,,,,,,,,,,,,, -39300,0.3621404,1.8866442,,,,,,,,,,,,,,,,, -39400,0.41318336,1.76189,,,,,,,,,,,,,,,,, -39500,0.3639372,1.7339455,,,,,,,,,,,,,,,,, -39600,0.4122514,1.7533288,,,,,,,,,,,,,,,,, -39700,0.4030923,1.7602725,,,,,,,,,,,,,,,,, -39800,0.48023957,1.7432959,,,,,,,,,,,,,,,,, -39900,0.41291356,1.699996,,,,,,,,,,,,,,,,, -40000,0.42722598,1.7343199,,,,,,,,,,,,,,,,, -40100,0.37702382,1.7081003,,,,,,,,,,,,,,,,, -40200,0.38600639,1.7429731,,,,,,,,,,,,,,,,, -40300,0.46428004,1.8206621,,,,,,,,,,,,,,,,, -40400,0.49618036,1.8347079,,,,,,,,,,,,,,,,, -40500,0.44176397,1.7633632,,,,,,,,,,,,,,,,, -40600,0.35365808,1.7018464,,,,,,,,,,,,,,,,, -40700,0.39527816,1.774567,,,,,,,,,,,,,,,,, -40800,0.40927306,1.8003715,,,,,,,,,,,,,,,,, -40900,0.4135377,1.8411645,,,,,,,,,,,,,,,,, -41000,0.38261145,1.7041112,,,,,,,,,,,,,,,,, -41100,0.38840497,1.8132191,,,,,,,,,,,,,,,,, -41200,0.37484062,1.6885272,,,,,,,,,,,,,,,,, -41300,0.41936943,1.7742447,,,,,,,,,,,,,,,,, -41342,,,0.6403557658195496,1.7567963600158691,31.64339967441196,0.659694254398346,1.6073123216629028,27.765434272920903,3000.0,0.6712451577186584,1.538042068481445,27.958816788714334,3003.0,14308.293565273283,23799.416815519333,14308.293565273283,9489.347652196884,0.4772262573242187,0.0 -41400,0.41498587,1.7718426,,,,,,,,,,,,,,,,, -41500,0.37777972,1.63608,,,,,,,,,,,,,,,,, -41600,0.40310746,1.8174309,,,,,,,,,,,,,,,,, -41700,0.38069463,1.6970909,,,,,,,,,,,,,,,,, -41800,0.3938572,1.8220052,,,,,,,,,,,,,,,,, -41900,0.41292578,1.8552133,,,,,,,,,,,,,,,,, -42000,0.38883117,1.7449832,,,,,,,,,,,,,,,,, -42100,0.43841335,1.8720306,,,,,,,,,,,,,,,,, -42200,0.44004843,1.7854779,,,,,,,,,,,,,,,,, -42300,0.40451923,1.8364407,,,,,,,,,,,,,,,,, -42400,0.39811525,1.7223492,,,,,,,,,,,,,,,,, -42500,0.43193188,1.7559674,,,,,,,,,,,,,,,,, -42600,0.35983112,1.7783506,,,,,,,,,,,,,,,,, -42700,0.3636092,1.7205169,,,,,,,,,,,,,,,,, -42800,0.50190485,1.7822828,,,,,,,,,,,,,,,,, -42900,0.40436628,1.7572557,,,,,,,,,,,,,,,,, -43000,0.38755035,1.7957566,,,,,,,,,,,,,,,,, -43100,0.36130613,1.7631633,,,,,,,,,,,,,,,,, -43200,0.5151977,1.7025715,,,,,,,,,,,,,,,,, -43300,0.3771321,1.7536724,,,,,,,,,,,,,,,,, -43400,0.3867217,1.7330205,,,,,,,,,,,,,,,,, -43500,0.39171538,1.7110906,,,,,,,,,,,,,,,,, -43600,0.48908103,1.7431635,,,,,,,,,,,,,,,,, -43700,0.38519433,1.7390665,,,,,,,,,,,,,,,,, -43776,,,0.6402798295021057,1.7504782676696775,31.42158195384738,0.6611325144767761,1.5975406169891355,28.68003646635293,3000.0,0.6745337247848511,1.532659888267517,27.9198714817124,3003.0,15148.395472049711,25152.87335944176,15148.395472049711,10002.591368198397,0.5125997066497803,0.0 -43800,0.4188458,1.8098038,,,,,,,,,,,,,,,,, -43900,0.4127915,1.7040925,,,,,,,,,,,,,,,,, -44000,0.39462298,1.7410535,,,,,,,,,,,,,,,,, -44100,0.46067858,1.808975,,,,,,,,,,,,,,,,, -44200,0.37905875,1.6628009,,,,,,,,,,,,,,,,, -44300,0.4265575,1.7996949,,,,,,,,,,,,,,,,, -44400,0.4078535,1.7448775,,,,,,,,,,,,,,,,, -44500,0.40093383,1.7724276,,,,,,,,,,,,,,,,, -44600,0.39309412,1.7472492,,,,,,,,,,,,,,,,, -44700,0.36168838,1.7912201,,,,,,,,,,,,,,,,, -44800,0.3815268,1.7348044,,,,,,,,,,,,,,,,, -44900,0.37479708,1.6932899,,,,,,,,,,,,,,,,, -45000,0.39229313,1.7165959,,,,,,,,,,,,,,,,, -45100,0.38734967,1.7779088,,,,,,,,,,,,,,,,, -45200,0.4474685,1.7530159,,,,,,,,,,,,,,,,, -45300,0.42418638,1.7277614,,,,,,,,,,,,,,,,, -45400,0.38726804,1.7835724,,,,,,,,,,,,,,,,, -45500,0.4682547,1.7506642,,,,,,,,,,,,,,,,, -45600,0.43353894,1.7400277,,,,,,,,,,,,,,,,, -45700,0.38910455,1.7991301,,,,,,,,,,,,,,,,, -45800,0.39975345,1.7649114,,,,,,,,,,,,,,,,, -45900,0.40072075,1.7870933,,,,,,,,,,,,,,,,, -46000,0.41568002,1.744997,,,,,,,,,,,,,,,,, -46100,0.38079444,1.8028251,,,,,,,,,,,,,,,,, -46200,0.4472697,1.7629006,,,,,,,,,,,,,,,,, -46210,,,0.6428989768028259,1.7245334386825562,31.507355253150934,0.6620624661445618,1.5930095911026,28.495653401907823,3000.0,0.6745337247848511,1.5249536037445068,28.18982572482699,3003.0,15988.560994148254,26514.822845697403,15988.560994148254,10524.269515752792,0.5430536270141602,0.0 -46300,0.4377879,1.6864696,,,,,,,,,,,,,,,,, -46400,0.40055317,1.6692456,,,,,,,,,,,,,,,,, -46500,0.41941497,1.8861808,,,,,,,,,,,,,,,,, -46600,0.38205683,1.7674695,,,,,,,,,,,,,,,,, -46700,0.554304,1.7682348,,,,,,,,,,,,,,,,, -46800,0.4006396,1.7396909,,,,,,,,,,,,,,,,, -46900,0.369981,1.7027181,,,,,,,,,,,,,,,,, -47000,0.3664,1.6675798,,,,,,,,,,,,,,,,, -47100,11.590187,1.7628803,,,,,,,,,,,,,,,,, -47200,0.4382237,1.8014419,,,,,,,,,,,,,,,,, -47300,0.4115659,1.6908021,,,,,,,,,,,,,,,,, -47400,0.42369825,1.8042103,,,,,,,,,,,,,,,,, -47500,0.37533236,1.7338408,,,,,,,,,,,,,,,,, -47600,0.38022012,1.6713172,,,,,,,,,,,,,,,,, -47700,0.42292577,1.7210826,,,,,,,,,,,,,,,,, -47800,0.40094957,1.7580543,,,,,,,,,,,,,,,,, -47900,0.395077,1.7543738,,,,,,,,,,,,,,,,, -48000,0.4036824,1.7264642,,,,,,,,,,,,,,,,, -48100,0.39300224,1.7632853,,,,,,,,,,,,,,,,, -48200,0.3707302,1.8046016,,,,,,,,,,,,,,,,, -48300,0.39671245,1.6819618,,,,,,,,,,,,,,,,, -48400,0.36615896,1.6948814,,,,,,,,,,,,,,,,, -48500,0.38151982,1.7510448,,,,,,,,,,,,,,,,, -48600,0.39737234,1.788886,,,,,,,,,,,,,,,,, -48644,,,0.6434541940689087,1.7303590774536133,31.33666911989976,0.6629676222801208,1.5866459608078003,28.31539252944437,3000.0,0.6765324473381042,1.513628602027893,28.04292709290652,3003.0,16828.613729715347,27868.87157773972,16828.613729715347,11038.159530639648,0.5735483169555664,0.0 -48700,0.40744025,1.7476566,,,,,,,,,,,,,,,,, -48800,0.3603228,1.6976019,,,,,,,,,,,,,,,,, -48900,0.38035,1.7334437,,,,,,,,,,,,,,,,, -49000,0.4001422,1.6716663,,,,,,,,,,,,,,,,, -49100,0.36063784,1.7821399,,,,,,,,,,,,,,,,, -49200,0.3515836,1.6344945,,,,,,,,,,,,,,,,, -49300,0.37939999,1.7595361,,,,,,,,,,,,,,,,, -49400,0.38005823,1.7193811,,,,,,,,,,,,,,,,, -49500,0.37985286,1.7637861,,,,,,,,,,,,,,,,, -49600,0.41904664,1.7259482,,,,,,,,,,,,,,,,, -49700,0.4384808,1.7505041,,,,,,,,,,,,,,,,, -49800,1.6021376,1.7592988,,,,,,,,,,,,,,,,, -49900,0.44201088,1.8102478,,,,,,,,,,,,,,,,, -50000,0.37555206,1.7213832,,,,,,,,,,,,,,,,, -50100,0.44957942,1.7565823,,,,,,,,,,,,,,,,, -50200,0.4049526,1.6794643,,,,,,,,,,,,,,,,, -50300,0.4288305,1.7211449,,,,,,,,,,,,,,,,, -50400,0.43505567,1.7715597,,,,,,,,,,,,,,,,, -50500,0.40547857,1.711814,,,,,,,,,,,,,,,,, -50600,0.367296,1.6547,,,,,,,,,,,,,,,,, -50700,0.40036777,1.7872602,,,,,,,,,,,,,,,,, -50800,0.46971408,1.7061205,,,,,,,,,,,,,,,,, -50900,0.38008666,1.8031524,,,,,,,,,,,,,,,,, -51000,0.4019461,1.7079563,,,,,,,,,,,,,,,,, -51077,,,0.6509331464767456,1.667337417602539,32.02152703215946,0.6633147597312927,1.5808254480361938,28.656745396147187,3000.0,0.6761838793754578,1.5135340690612793,28.28753882239032,3003.0,17668.511252641678,29197.242948055267,17668.511252641678,11526.525929927826,0.6058268547058105,0.0 -51100,0.3706374,1.7451814,,,,,,,,,,,,,,,,, -51200,0.44600335,1.7969366,,,,,,,,,,,,,,,,, -51300,0.40095535,1.7437823,,,,,,,,,,,,,,,,, -51400,0.3914302,1.7099627,,,,,,,,,,,,,,,,, -51500,0.42875114,1.7001868,,,,,,,,,,,,,,,,, -51600,0.3598704,1.7370684,,,,,,,,,,,,,,,,, -51700,0.40847278,1.7003956,,,,,,,,,,,,,,,,, -51800,0.3904507,1.8171115,,,,,,,,,,,,,,,,, -51900,0.398642,1.6711975,,,,,,,,,,,,,,,,, -52000,0.35666072,1.6408006,,,,,,,,,,,,,,,,, -52100,0.35280663,1.743177,,,,,,,,,,,,,,,,, -52200,0.36842358,1.7388105,,,,,,,,,,,,,,,,, -52300,0.37272218,1.7010722,,,,,,,,,,,,,,,,, -52400,0.3905913,1.7524515,,,,,,,,,,,,,,,,, -52500,0.36854932,1.7069738,,,,,,,,,,,,,,,,, -52600,0.39492556,1.6834604,,,,,,,,,,,,,,,,, -52700,0.39706543,1.714865,,,,,,,,,,,,,,,,, -52800,0.39060763,1.7768787,,,,,,,,,,,,,,,,, -52900,0.40285635,1.776583,,,,,,,,,,,,,,,,, -53000,0.3887927,1.7439739,,,,,,,,,,,,,,,,, -53100,0.35172832,1.7212435,,,,,,,,,,,,,,,,, -53200,0.3708695,1.7006117,,,,,,,,,,,,,,,,, -53300,0.39438593,1.7128206,,,,,,,,,,,,,,,,, -53400,0.4063218,1.7034056,,,,,,,,,,,,,,,,, -53500,0.39071777,1.7554151,,,,,,,,,,,,,,,,, -53511,,,0.6447697877883911,1.7171337604522705,31.75867017336076,0.6656085848808289,1.571361422538757,28.81990449034723,3000.0,0.6800999641418457,1.4995527267456057,28.593444733008027,3003.0,18508.527975320816,30602.46071600914,18508.527975320816,12091.620062828064,0.637232780456543,0.0 -53600,0.3781841,1.782105,,,,,,,,,,,,,,,,, -53700,0.4991616,1.8011218,,,,,,,,,,,,,,,,, -53800,0.49327078,1.7308898,,,,,,,,,,,,,,,,, -53900,0.41025525,1.6756241,,,,,,,,,,,,,,,,, -54000,0.4214799,1.7202741,,,,,,,,,,,,,,,,, -54100,0.39784175,1.6460555,,,,,,,,,,,,,,,,, -54200,0.3940253,1.7182552,,,,,,,,,,,,,,,,, -54300,0.47036123,1.7075162,,,,,,,,,,,,,,,,, -54400,0.39126801,1.6841466,,,,,,,,,,,,,,,,, -54500,0.38430578,1.6724238,,,,,,,,,,,,,,,,, -54600,0.40147328,1.7577285,,,,,,,,,,,,,,,,, -54700,0.36701325,1.7185456,,,,,,,,,,,,,,,,, -54800,0.3987891,1.7155346,,,,,,,,,,,,,,,,, -54900,0.42755747,1.7711478,,,,,,,,,,,,,,,,, -55000,0.3901603,1.6927389,,,,,,,,,,,,,,,,, -55100,0.39997074,1.7898213,,,,,,,,,,,,,,,,, -55200,0.4108059,1.733154,,,,,,,,,,,,,,,,, -55300,0.376105,1.7204747,,,,,,,,,,,,,,,,, -55400,0.3790905,1.6835147,,,,,,,,,,,,,,,,, -55500,0.3991814,1.6956087,,,,,,,,,,,,,,,,, -55600,0.4106324,1.7377715,,,,,,,,,,,,,,,,, -55700,0.39290985,1.6982684,,,,,,,,,,,,,,,,, -55800,0.39121732,1.7694505,,,,,,,,,,,,,,,,, -55900,0.40724576,1.7190341,,,,,,,,,,,,,,,,, -55945,,,0.6402867436408997,1.7420943975448608,31.603690698245167,0.6657822132110596,1.5687559843063354,28.31823448140098,3000.0,0.678345263004303,1.496812343597412,28.48742903827465,3003.0,19348.71917390824,32111.7632894516,19348.71917390824,12760.62174320221,0.6705992221832275,0.0 -56000,0.4123341,1.7509624,,,,,,,,,,,,,,,,, -56100,0.38102722,1.598187,,,,,,,,,,,,,,,,, -56200,0.42390382,1.6891868,,,,,,,,,,,,,,,,, -56300,0.36714116,1.7665925,,,,,,,,,,,,,,,,, -56400,0.39505482,1.7029487,,,,,,,,,,,,,,,,, -56500,0.36778915,1.7276903,,,,,,,,,,,,,,,,, -56600,0.35852638,1.639568,,,,,,,,,,,,,,,,, -56700,0.4049922,1.7544603,,,,,,,,,,,,,,,,, -56800,0.39890382,1.7169745,,,,,,,,,,,,,,,,, -56900,0.4015001,1.7171788,,,,,,,,,,,,,,,,, -57000,0.36894602,1.6829631,,,,,,,,,,,,,,,,, -57100,0.39391184,1.8034421,,,,,,,,,,,,,,,,, -57200,0.40780067,1.643903,,,,,,,,,,,,,,,,, -57300,0.39373386,1.7274666,,,,,,,,,,,,,,,,, -57400,0.37288576,1.7039257,,,,,,,,,,,,,,,,, -57500,0.38590693,1.752413,,,,,,,,,,,,,,,,, -57600,0.4175256,1.7209318,,,,,,,,,,,,,,,,, -57700,0.3616211,1.6619362,,,,,,,,,,,,,,,,, -57800,0.40567616,1.6641985,,,,,,,,,,,,,,,,, -57900,0.38250113,1.6900712,,,,,,,,,,,,,,,,, -58000,0.42542878,1.6469994,,,,,,,,,,,,,,,,, -58100,0.37166786,1.723297,,,,,,,,,,,,,,,,, -58200,0.38832846,1.6730974,,,,,,,,,,,,,,,,, -58300,0.37860405,1.6820014,,,,,,,,,,,,,,,,, -58379,,,0.651101291179657,1.6768395900726318,31.862773994581925,0.6660177707672119,1.565895915031433,28.684740364153395,3000.0,0.6788333058357239,1.493443727493286,28.461223409180462,3003.0,20188.83740305901,33452.249553442,20188.83740305901,13260.881618976591,0.7030997276306152,0.0 -58400,0.38458076,1.6902725,,,,,,,,,,,,,,,,, -58500,0.39078987,1.7459495,,,,,,,,,,,,,,,,, -58600,0.3852883,1.7192842,,,,,,,,,,,,,,,,, -58700,0.39105377,1.7330893,,,,,,,,,,,,,,,,, -58800,0.39734295,1.7452403,,,,,,,,,,,,,,,,, -58900,0.39360043,1.694999,,,,,,,,,,,,,,,,, -59000,0.3986133,1.6313654,,,,,,,,,,,,,,,,, -59100,0.37520927,1.733624,,,,,,,,,,,,,,,,, -59200,0.39916322,1.7168338,,,,,,,,,,,,,,,,, -59300,0.37416765,1.8186018,,,,,,,,,,,,,,,,, -59400,0.3786426,1.7033886,,,,,,,,,,,,,,,,, -59500,0.38466483,1.6980675,,,,,,,,,,,,,,,,, -59600,0.3591663,1.6389905,,,,,,,,,,,,,,,,, -59700,0.43654767,1.7284486,,,,,,,,,,,,,,,,, -59800,0.39585012,1.6756418,,,,,,,,,,,,,,,,, -59900,0.48577586,1.6500217,,,,,,,,,,,,,,,,, -60000,0.42025796,1.6878495,,,,,,,,,,,,,,,,, -60100,0.39803275,1.7289193,,,,,,,,,,,,,,,,, -60200,0.40096268,1.6848918,,,,,,,,,,,,,,,,, -60300,0.37865886,1.741339,,,,,,,,,,,,,,,,, -60400,0.37833756,1.7467674,,,,,,,,,,,,,,,,, -60500,0.39918458,1.721507,,,,,,,,,,,,,,,,, -60600,0.37480578,1.7063017,,,,,,,,,,,,,,,,, -60700,0.40317586,1.6484321,,,,,,,,,,,,,,,,, -60800,0.40683743,1.7975448,,,,,,,,,,,,,,,,, -60813,,,0.6452680826187134,1.7102110385894775,31.57398436310127,0.6664021611213684,1.5667551755905151,28.68595025894429,3000.0,0.6794724464416504,1.485634207725525,28.543103089635263,3003.0,21028.95820069313,34781.54425239563,21028.95820069313,13749.946232795715,0.7364275455474854,0.0 -60900,0.37525728,1.7279897,,,,,,,,,,,,,,,,, -61000,0.3784242,1.7207165,,,,,,,,,,,,,,,,, -61100,0.4112502,1.6837429,,,,,,,,,,,,,,,,, -61200,0.38444415,1.6943904,,,,,,,,,,,,,,,,, -61300,0.38159946,1.7229615,,,,,,,,,,,,,,,,, -61400,0.39434087,1.635994,,,,,,,,,,,,,,,,, -61500,0.44386742,1.6408107,,,,,,,,,,,,,,,,, -61600,0.38597855,1.6988112,,,,,,,,,,,,,,,,, -61700,0.3993548,1.6688381,,,,,,,,,,,,,,,,, -61800,0.35229832,1.7303208,,,,,,,,,,,,,,,,, -61900,0.4132421,1.7356488,,,,,,,,,,,,,,,,, -62000,0.46083283,1.7458769,,,,,,,,,,,,,,,,, -62100,1.0952648,1.6468197,,,,,,,,,,,,,,,,, -62200,0.4160068,1.730098,,,,,,,,,,,,,,,,, -62300,0.37497222,1.714314,,,,,,,,,,,,,,,,, -62400,0.37546358,1.6638066,,,,,,,,,,,,,,,,, -62500,0.37590337,1.677932,,,,,,,,,,,,,,,,, -62600,0.41593477,1.6440237,,,,,,,,,,,,,,,,, -62700,0.37508044,1.6887405,,,,,,,,,,,,,,,,, -62800,0.39426872,1.7755805,,,,,,,,,,,,,,,,, -62900,0.39406532,1.7226484,,,,,,,,,,,,,,,,, -63000,0.39396435,1.7288221,,,,,,,,,,,,,,,,, -63100,0.36920652,1.6763477,,,,,,,,,,,,,,,,, -63200,0.40314808,1.637759,,,,,,,,,,,,,,,,, -63247,,,0.6609110236167908,1.613291621208191,32.737365321346005,0.6695019006729126,1.551543354988098,29.214455853774165,3000.0,0.6820870637893677,1.4776993989944458,28.98144871314173,3003.0,21869.133969783783,36111.008915662766,21869.133969783783,14239.12595629692,0.7697396278381348,0.0 -63300,0.41588435,1.6514826,,,,,,,,,,,,,,,,, -63400,0.3849111,1.7296597,,,,,,,,,,,,,,,,, -63500,0.37488532,1.726137,,,,,,,,,,,,,,,,, -63600,0.40710935,1.711683,,,,,,,,,,,,,,,,, -63700,0.37280628,1.6447076,,,,,,,,,,,,,,,,, -63800,0.38431334,1.6946867,,,,,,,,,,,,,,,,, -63900,0.36625853,1.6940639,,,,,,,,,,,,,,,,, -64000,0.37099794,1.7166842,,,,,,,,,,,,,,,,, -64100,0.42226592,1.6856731,,,,,,,,,,,,,,,,, -64200,0.3865102,1.726569,,,,,,,,,,,,,,,,, -64300,0.3799259,1.7195698,,,,,,,,,,,,,,,,, -64400,0.3997903,1.728999,,,,,,,,,,,,,,,,, -64500,0.38985324,1.7474029,,,,,,,,,,,,,,,,, -64600,0.40587434,1.7101762,,,,,,,,,,,,,,,,, -64700,0.46212196,1.7267073,,,,,,,,,,,,,,,,, -64800,0.38195726,1.6890218,,,,,,,,,,,,,,,,, -64900,0.40765804,1.6694334,,,,,,,,,,,,,,,,, -65000,0.37777415,1.678452,,,,,,,,,,,,,,,,, -65100,0.36852816,1.6567091,,,,,,,,,,,,,,,,, -65200,0.36041093,1.7080542,,,,,,,,,,,,,,,,, -65300,0.3867853,1.701924,,,,,,,,,,,,,,,,, -65400,0.38936886,1.7356538,,,,,,,,,,,,,,,,, -65500,0.41089305,1.7067041,,,,,,,,,,,,,,,,, -65600,0.37738696,1.7038904,,,,,,,,,,,,,,,,, -65681,,,0.6530311703681946,1.671697735786438,32.311872612330355,0.6706550121307373,1.538474202156067,28.82110298512525,3000.0,0.6832723617553711,1.465322732925415,28.77781389125732,3003.0,22709.169924020767,37561.60833859444,22709.169924020767,14849.57773900032,0.8056454658508301,0.0 -65700,0.3852907,1.6331675,,,,,,,,,,,,,,,,, -65800,0.37475598,1.7226958,,,,,,,,,,,,,,,,, -65900,0.38201842,1.6819595,,,,,,,,,,,,,,,,, -66000,0.40139472,1.6944858,,,,,,,,,,,,,,,,, -66100,0.3691486,1.6385839,,,,,,,,,,,,,,,,, -66200,0.40619278,1.6405427,,,,,,,,,,,,,,,,, -66300,0.387386,1.6137556,,,,,,,,,,,,,,,,, -66400,0.3549425,1.5946441,,,,,,,,,,,,,,,,, -66500,0.47303143,1.687872,,,,,,,,,,,,,,,,, -66600,0.41511294,1.7619303,,,,,,,,,,,,,,,,, -66700,0.4297664,1.6532352,,,,,,,,,,,,,,,,, -66800,0.3998038,1.6857512,,,,,,,,,,,,,,,,, -66900,0.37864667,1.6344999,,,,,,,,,,,,,,,,, -67000,0.4162525,1.741899,,,,,,,,,,,,,,,,, -67100,0.36261964,1.7085967,,,,,,,,,,,,,,,,, -67200,0.39863628,1.6477395,,,,,,,,,,,,,,,,, -67300,0.38629082,1.5942212,,,,,,,,,,,,,,,,, -67400,0.39463606,1.7147862,,,,,,,,,,,,,,,,, -67500,0.4056663,1.5919381,,,,,,,,,,,,,,,,, -67600,0.38272095,1.5594544,,,,,,,,,,,,,,,,, -67700,0.4022256,1.6645194,,,,,,,,,,,,,,,,, -67800,0.39108127,1.6894972,,,,,,,,,,,,,,,,, -67900,0.46102569,1.7308928,,,,,,,,,,,,,,,,, -68000,0.39263365,1.6235254,,,,,,,,,,,,,,,,, -68100,0.4245804,1.5936484,,,,,,,,,,,,,,,,, -68115,,,0.6517167091369629,1.6799516677856443,32.40043379300699,0.670059859752655,1.534948229789734,28.84134266910904,3000.0,0.6844460368156433,1.4604904651641846,28.78482850239207,3003.0,23549.336899280548,38971.38629126549,23549.336899280548,15419.075818538666,0.8422815799713135,0.0 -68200,0.40661,1.7417997,,,,,,,,,,,,,,,,, -68300,0.37721473,1.729585,,,,,,,,,,,,,,,,, -68400,0.36152056,1.6995531,,,,,,,,,,,,,,,,, -68500,0.411503,1.6837168,,,,,,,,,,,,,,,,, -68600,0.42306086,1.723961,,,,,,,,,,,,,,,,, -68700,0.40525848,1.6755836,,,,,,,,,,,,,,,,, -68800,0.38237903,1.7424284,,,,,,,,,,,,,,,,, -68900,0.37011626,1.6563143,,,,,,,,,,,,,,,,, -69000,0.41064072,1.676917,,,,,,,,,,,,,,,,, -69100,0.38664907,1.7101444,,,,,,,,,,,,,,,,, -69200,0.4004267,1.6989716,,,,,,,,,,,,,,,,, -69300,0.37928373,1.6032991,,,,,,,,,,,,,,,,, -69400,0.36140528,1.6251339,,,,,,,,,,,,,,,,, -69500,0.45171008,1.6260501,,,,,,,,,,,,,,,,, -69600,0.37939006,1.6623625,,,,,,,,,,,,,,,,, -69700,0.38788694,1.6851327,,,,,,,,,,,,,,,,, -69800,0.41261628,1.5761278,,,,,,,,,,,,,,,,, -69900,0.37534004,1.6380695,,,,,,,,,,,,,,,,, -70000,0.39567694,1.606237,,,,,,,,,,,,,,,,, -70100,0.41800815,1.6423624,,,,,,,,,,,,,,,,, -70200,0.40812075,1.6069015,,,,,,,,,,,,,,,,, -70300,0.40103164,1.5834259,,,,,,,,,,,,,,,,, -70400,0.4263112,1.6676052,,,,,,,,,,,,,,,,, -70500,0.40420187,1.6537981,,,,,,,,,,,,,,,,, -70549,,,0.6591607332229614,1.6216074228286743,32.18171612106351,0.6709278225898743,1.5275872945785522,28.9270734590302,3000.0,0.6852594614028931,1.4523561000823977,28.99412723224184,3003.0,24389.565573453903,40268.55391001701,24389.565573453903,15875.903427124023,0.8780360221862793,0.0 -70600,0.4097609,1.7080721,,,,,,,,,,,,,,,,, -70700,0.38260114,1.6066445,,,,,,,,,,,,,,,,, -70800,0.4245493,1.7259636,,,,,,,,,,,,,,,,, -70900,0.3804749,1.6424098,,,,,,,,,,,,,,,,, -71000,0.4194334,1.7378012,,,,,,,,,,,,,,,,, -71100,0.3932472,1.6810071,,,,,,,,,,,,,,,,, -71200,0.387367,1.6601094,,,,,,,,,,,,,,,,, -71300,0.3941676,1.7004232,,,,,,,,,,,,,,,,, -71400,0.41237912,1.6897786,,,,,,,,,,,,,,,,, -71500,0.41574255,1.7141373,,,,,,,,,,,,,,,,, -71600,0.42077518,1.6314609,,,,,,,,,,,,,,,,, -71700,0.47080553,1.7023008,,,,,,,,,,,,,,,,, -71800,0.40852466,1.6879579,,,,,,,,,,,,,,,,, -71900,0.37034804,1.6289895,,,,,,,,,,,,,,,,, -72000,0.38608614,1.6759187,,,,,,,,,,,,,,,,, -72100,0.42815286,1.6041982,,,,,,,,,,,,,,,,, -72200,0.387178,1.697138,,,,,,,,,,,,,,,,, -72300,0.37254715,1.6299937,,,,,,,,,,,,,,,,, -72400,0.40623495,1.6286408,,,,,,,,,,,,,,,,, -72500,0.38317397,1.6161895,,,,,,,,,,,,,,,,, -72600,0.39788812,1.6022084,,,,,,,,,,,,,,,,, -72700,0.37420553,1.554483,,,,,,,,,,,,,,,,, -72800,0.40196708,1.666391,,,,,,,,,,,,,,,,, -72900,0.3704825,1.5494335,,,,,,,,,,,,,,,,, -72983,,,0.6545958518981934,1.6556434631347656,32.03680376481512,0.6736928224563599,1.5203146934509275,29.4972412252912,3000.0,0.6866422891616821,1.4460179805755615,28.961118007307395,3003.0,25229.58171772957,41573.54033780098,25229.58171772957,16340.762417078018,0.9139454364776612,0.0 -73000,0.36748284,1.6907959,,,,,,,,,,,,,,,,, -73100,0.39178762,1.5854568,,,,,,,,,,,,,,,,, -73200,0.42007324,1.6245823,,,,,,,,,,,,,,,,, -73300,0.45327264,1.6074816,,,,,,,,,,,,,,,,, -73400,0.39908645,1.6501178,,,,,,,,,,,,,,,,, -73500,0.40045828,1.5897745,,,,,,,,,,,,,,,,, -73600,0.41249287,1.6318362,,,,,,,,,,,,,,,,, -73700,0.3547732,1.6598384,,,,,,,,,,,,,,,,, -73800,0.39775142,1.6613616,,,,,,,,,,,,,,,,, -73900,0.43403724,1.60628,,,,,,,,,,,,,,,,, -74000,0.407354,1.6421545,,,,,,,,,,,,,,,,, -74100,0.45658362,1.6064882,,,,,,,,,,,,,,,,, -74200,0.41995955,1.6849355,,,,,,,,,,,,,,,,, -74300,0.38883406,1.5969222,,,,,,,,,,,,,,,,, -74400,0.42108285,1.6591551,,,,,,,,,,,,,,,,, -74500,0.39290693,1.6955302,,,,,,,,,,,,,,,,, -74600,0.38944244,1.6113818,,,,,,,,,,,,,,,,, -74700,0.4064302,1.6072755,,,,,,,,,,,,,,,,, -74800,0.39292902,1.6053928,,,,,,,,,,,,,,,,, -74900,0.38096148,1.6172341,,,,,,,,,,,,,,,,, -75000,0.40993625,1.6426517,,,,,,,,,,,,,,,,, -75100,0.37995014,1.645519,,,,,,,,,,,,,,,,, -75200,0.40055835,1.689716,,,,,,,,,,,,,,,,, -75300,0.41909847,1.5497478,,,,,,,,,,,,,,,,, -75400,0.38023105,1.5845221,,,,,,,,,,,,,,,,, -75417,,,0.6752561330795288,1.5263938903808594,33.8632993216732,0.6756518483161926,1.5147626399993896,29.306960178991066,3000.0,0.6866190433502197,1.439462661743164,29.221898187490247,3003.0,26069.81137084961,42925.76065564156,26069.81137084961,16852.63940834999,0.9512898921966552,0.0 -75500,0.39606735,1.559966,,,,,,,,,,,,,,,,, -75600,0.39333338,1.6021968,,,,,,,,,,,,,,,,, -75700,0.40306804,1.6585035,,,,,,,,,,,,,,,,, -75800,0.43842694,1.7202182,,,,,,,,,,,,,,,,, -75900,0.37118,1.6837462,,,,,,,,,,,,,,,,, -76000,0.36089465,1.6599518,,,,,,,,,,,,,,,,, -76100,0.390866,1.6314967,,,,,,,,,,,,,,,,, -76200,0.41867226,1.6477181,,,,,,,,,,,,,,,,, -76300,0.3983061,1.6321638,,,,,,,,,,,,,,,,, -76400,0.389741,1.649687,,,,,,,,,,,,,,,,, -76500,0.3842313,1.6490896,,,,,,,,,,,,,,,,, -76600,0.38441414,1.5436997,,,,,,,,,,,,,,,,, -76700,0.51143384,1.730865,,,,,,,,,,,,,,,,, -76800,0.3915238,1.7416137,,,,,,,,,,,,,,,,, -76900,0.3905868,1.5848304,,,,,,,,,,,,,,,,, -77000,0.44205222,1.5561769,,,,,,,,,,,,,,,,, -77100,0.43108064,1.6549028,,,,,,,,,,,,,,,,, -77200,0.37928426,1.607235,,,,,,,,,,,,,,,,, -77300,0.43754697,1.6741141,,,,,,,,,,,,,,,,, -77400,0.43796858,1.6153655,,,,,,,,,,,,,,,,, -77500,0.41231695,1.6100067,,,,,,,,,,,,,,,,, -77600,0.40579808,1.6306779,,,,,,,,,,,,,,,,, -77700,0.3781797,1.5453,,,,,,,,,,,,,,,,, -77800,0.39819875,1.6181663,,,,,,,,,,,,,,,,, -77851,,,0.6602402329444885,1.6101831197738647,32.94359332980648,0.6758378744125366,1.5022680759429932,29.88321465163858,3000.0,0.6893149614334106,1.4282652139663696,29.199608789390265,3003.0,26909.84682536125,44291.5014474392,26909.84682536125,17378.230769634247,0.9887261390686036,0.0 -77900,0.42667633,1.6477122,,,,,,,,,,,,,,,,, -78000,0.43736878,1.7076772,,,,,,,,,,,,,,,,, -78100,0.39172417,1.6736622,,,,,,,,,,,,,,,,, -78200,0.4687791,1.7424151,,,,,,,,,,,,,,,,, -78300,0.39323834,1.60768,,,,,,,,,,,,,,,,, -78400,0.40747395,1.5116602,,,,,,,,,,,,,,,,, -78500,0.4661801,1.5756613,,,,,,,,,,,,,,,,, -78600,0.39592776,1.6393347,,,,,,,,,,,,,,,,, -78700,0.42499453,1.6780189,,,,,,,,,,,,,,,,, -78800,0.38540086,1.6017642,,,,,,,,,,,,,,,,, -78900,0.3957849,1.5720947,,,,,,,,,,,,,,,,, -79000,0.41955712,1.6115305,,,,,,,,,,,,,,,,, -79100,0.399333,1.6467717,,,,,,,,,,,,,,,,, -79200,0.40032566,1.5926503,,,,,,,,,,,,,,,,, -79300,0.43515936,1.6135658,,,,,,,,,,,,,,,,, -79400,0.41399342,1.6908416,,,,,,,,,,,,,,,,, -79500,0.42317453,1.6179254,,,,,,,,,,,,,,,,, -79600,0.39852303,1.6283149,,,,,,,,,,,,,,,,, -79700,0.4445373,1.6739186,,,,,,,,,,,,,,,,, -79800,0.41410163,1.6052918,,,,,,,,,,,,,,,,, -79900,0.39804402,1.5217733,,,,,,,,,,,,,,,,, -80000,0.3976569,1.6602029,,,,,,,,,,,,,,,,, -80100,0.41026443,1.6408867,,,,,,,,,,,,,,,,, -80200,0.39040568,1.5451652,,,,,,,,,,,,,,,,, -80285,,,0.6605494022369385,1.6196547746658323,32.74820463146153,0.6773629784584045,1.497270107269287,29.766035010112983,3000.0,0.6895241737365723,1.4248106479644775,29.169380704529548,3003.0,27750.07419347763,45631.01221609116,27750.07419347763,17877.399688482285,1.0267176628112793,0.0 -80300,0.45202512,1.64508,,,,,,,,,,,,,,,,, -80400,0.384247,1.5961887,,,,,,,,,,,,,,,,, -80500,0.38790786,1.6162783,,,,,,,,,,,,,,,,, -80600,0.41093013,1.5935383,,,,,,,,,,,,,,,,, -80700,0.39765888,1.6296034,,,,,,,,,,,,,,,,, -80800,0.4079996,1.7068499,,,,,,,,,,,,,,,,, -80900,0.42321154,1.6107907,,,,,,,,,,,,,,,,, -81000,0.38051975,1.5981302,,,,,,,,,,,,,,,,, -81100,0.41119382,1.6640923,,,,,,,,,,,,,,,,, -81200,0.40557227,1.5907345,,,,,,,,,,,,,,,,, -81300,0.40798938,1.5949379,,,,,,,,,,,,,,,,, -81400,0.42045376,1.6330571,,,,,,,,,,,,,,,,, -81500,0.379573,1.6529328,,,,,,,,,,,,,,,,, -81600,0.40863767,1.6738075,,,,,,,,,,,,,,,,, -81700,0.41288358,1.6106323,,,,,,,,,,,,,,,,, -81800,0.42767859,1.6323721,,,,,,,,,,,,,,,,, -81900,0.4546678,1.6554708,,,,,,,,,,,,,,,,, -82000,0.45912778,1.669347,,,,,,,,,,,,,,,,, -82100,0.4332148,1.6364301,,,,,,,,,,,,,,,,, -82200,0.46347982,1.7153991,,,,,,,,,,,,,,,,, -82300,0.4433994,1.6553361,,,,,,,,,,,,,,,,, -82400,0.4339982,1.6210922,,,,,,,,,,,,,,,,, -82500,0.39362028,1.570417,,,,,,,,,,,,,,,,, -82600,0.42586824,1.6598805,,,,,,,,,,,,,,,,, -82700,0.40756196,1.6630993,,,,,,,,,,,,,,,,, -82719,,,0.6689773201942444,1.5642541646957395,33.41044851396017,0.6781564950942993,1.491786003112793,29.560270364388217,3000.0,0.6920806765556335,1.4164384603500366,29.278406234177744,3003.0,28590.17192029953,47032.475782871246,28590.17192029953,18438.65266013145,1.062680959701538,0.0 -82800,0.42981592,1.6458726,,,,,,,,,,,,,,,,, -82900,0.42544067,1.5983063,,,,,,,,,,,,,,,,, -83000,0.40087035,1.5345967,,,,,,,,,,,,,,,,, -83100,0.40962282,1.6728886,,,,,,,,,,,,,,,,, -83200,0.4054267,1.6612333,,,,,,,,,,,,,,,,, -83300,0.4136477,1.6000237,,,,,,,,,,,,,,,,, -83400,0.44166666,1.6669506,,,,,,,,,,,,,,,,, -83500,0.44472042,1.6502302,,,,,,,,,,,,,,,,, -83600,0.42034066,1.5901265,,,,,,,,,,,,,,,,, -83700,0.39390302,1.5892754,,,,,,,,,,,,,,,,, -83800,0.4180863,1.5492013,,,,,,,,,,,,,,,,, -83900,0.4228217,1.6400006,,,,,,,,,,,,,,,,, -84000,0.43127644,1.5854726,,,,,,,,,,,,,,,,, -84100,0.42000315,1.6207219,,,,,,,,,,,,,,,,, -84200,0.41894564,1.6472793,,,,,,,,,,,,,,,,, -84300,0.4284922,1.6261835,,,,,,,,,,,,,,,,, -84400,0.433019,1.6875962,,,,,,,,,,,,,,,,, -84500,0.4652398,1.6136048,,,,,,,,,,,,,,,,, -84600,0.43695438,1.5895112,,,,,,,,,,,,,,,,, -84700,0.41934875,1.5439233,,,,,,,,,,,,,,,,, -84800,0.41575766,1.5776558,,,,,,,,,,,,,,,,, -84900,0.42895624,1.578058,,,,,,,,,,,,,,,,, -85000,0.46287867,1.5962447,,,,,,,,,,,,,,,,, -85100,0.41377252,1.6721108,,,,,,,,,,,,,,,,, -85153,,,0.6641597151756287,1.5821653604507446,32.891950379823754,0.679582417011261,1.482561111450195,29.699337732733976,3000.0,0.6939166784286499,1.4057952165603638,29.56086936865955,3003.0,29430.228848457336,48491.12983894348,29430.228848457336,19057.13529086113,1.1003947257995603,0.0 -85200,0.43098104,1.5464873,,,,,,,,,,,,,,,,, -85300,0.4221927,1.6047206,,,,,,,,,,,,,,,,, -85400,0.4334141,1.5337325,,,,,,,,,,,,,,,,, -85500,0.4422238,1.5863733,,,,,,,,,,,,,,,,, -85600,0.43058145,1.5529846,,,,,,,,,,,,,,,,, -85700,0.41958913,1.6310915,,,,,,,,,,,,,,,,, -85800,0.4304676,1.6928471,,,,,,,,,,,,,,,,, -85900,0.44786635,1.6081609,,,,,,,,,,,,,,,,, -86000,0.44929454,1.6147978,,,,,,,,,,,,,,,,, -86100,0.46804953,1.5427171,,,,,,,,,,,,,,,,, -86200,0.4214375,1.6230733,,,,,,,,,,,,,,,,, -86300,0.4465775,1.5925629,,,,,,,,,,,,,,,,, -86400,0.43241596,1.6538234,,,,,,,,,,,,,,,,, -86500,0.45700416,1.6244465,,,,,,,,,,,,,,,,, -86600,0.44334283,1.6565257,,,,,,,,,,,,,,,,, -86700,0.4384025,1.5964952,,,,,,,,,,,,,,,,, -86800,0.42981437,1.6190277,,,,,,,,,,,,,,,,, -86900,0.41436717,1.5072933,,,,,,,,,,,,,,,,, -87000,0.4415192,1.6789293,,,,,,,,,,,,,,,,, -87100,0.45462397,1.5962459,,,,,,,,,,,,,,,,, -87200,0.4330094,1.6097686,,,,,,,,,,,,,,,,, -87300,0.4357777,1.5997659,,,,,,,,,,,,,,,,, -87400,0.42333063,1.5489334,,,,,,,,,,,,,,,,, -87500,0.4402989,1.5949563,,,,,,,,,,,,,,,,, -87587,,,0.6998754143714905,1.4042586088180542,36.14299179613099,0.6811694502830505,1.4776188135147097,30.155293354746227,3000.0,0.6963105201721191,1.3956819772720337,29.990684621811305,3003.0,30270.3239107132,49895.22182369232,30270.3239107132,19621.019548416138,1.1373250484466553,0.0 -87600,0.42319202,1.5935858,,,,,,,,,,,,,,,,, -87700,0.44307083,1.5650028,,,,,,,,,,,,,,,,, -87800,0.4249039,1.5585008,,,,,,,,,,,,,,,,, -87900,0.41859728,1.5497372,,,,,,,,,,,,,,,,, -88000,0.45288533,1.5835868,,,,,,,,,,,,,,,,, -88100,0.43626845,1.548742,,,,,,,,,,,,,,,,, -88200,0.43476677,1.6384699,,,,,,,,,,,,,,,,, -88300,0.43825054,1.5957749,,,,,,,,,,,,,,,,, -88400,0.41201907,1.6017799,,,,,,,,,,,,,,,,, -88500,0.4541775,1.5911821,,,,,,,,,,,,,,,,, -88600,0.4086729,1.5645714,,,,,,,,,,,,,,,,, -88700,0.4409346,1.589663,,,,,,,,,,,,,,,,, -88800,0.43740132,1.7022079,,,,,,,,,,,,,,,,, -88900,0.44726032,1.5498046,,,,,,,,,,,,,,,,, -89000,0.44556892,1.5447341,,,,,,,,,,,,,,,,, -89100,0.4693376,1.6332283,,,,,,,,,,,,,,,,, -89200,0.4592011,1.6563458,,,,,,,,,,,,,,,,, -89300,0.45067304,1.6225777,,,,,,,,,,,,,,,,, -89400,0.4531598,1.5523307,,,,,,,,,,,,,,,,, -89500,0.42833668,1.5419947,,,,,,,,,,,,,,,,, -89600,0.48207492,1.5382115,,,,,,,,,,,,,,,,, -89700,0.45836407,1.5513998,,,,,,,,,,,,,,,,, -89800,0.4551309,1.5661874,,,,,,,,,,,,,,,,, -89900,0.438792,1.6126771,,,,,,,,,,,,,,,,, -90000,0.44058585,1.5995878,,,,,,,,,,,,,,,,, -90021,,,0.6689482927322388,1.5577067136764526,33.551267722025884,0.6799047589302063,1.47508442401886,29.56897782735325,3000.0,0.6959386467933655,1.3930952548980713,29.75166453868224,3003.0,31110.369894504547,51232.16949701309,31110.369894504547,20117.80855345726,1.1739370822906494,0.0 -90100,0.44199854,1.5880977,,,,,,,,,,,,,,,,, -90200,0.45013204,1.4786382,,,,,,,,,,,,,,,,, -90300,0.47884437,1.6105155,,,,,,,,,,,,,,,,, -90400,0.4414108,1.5703267,,,,,,,,,,,,,,,,, -90500,0.43957427,1.4538417,,,,,,,,,,,,,,,,, -90600,0.4338258,1.5893617,,,,,,,,,,,,,,,,, -90700,0.450005,1.5242198,,,,,,,,,,,,,,,,, -90800,0.44650367,1.5341197,,,,,,,,,,,,,,,,, -90900,0.4502239,1.6230711,,,,,,,,,,,,,,,,, -91000,0.44148365,1.5668912,,,,,,,,,,,,,,,,, -91100,0.6433533,1.5671422,,,,,,,,,,,,,,,,, -91200,0.46478745,1.6687398,,,,,,,,,,,,,,,,, -91300,0.46295056,1.5327017,,,,,,,,,,,,,,,,, -91400,0.45455027,1.5400923,,,,,,,,,,,,,,,,, -91500,0.4512467,1.6267684,,,,,,,,,,,,,,,,, -91600,0.44410992,1.5941296,,,,,,,,,,,,,,,,, -91700,0.4674659,1.5879799,,,,,,,,,,,,,,,,, -91800,0.48432598,1.6770942,,,,,,,,,,,,,,,,, -91900,0.4414748,1.5300795,,,,,,,,,,,,,,,,, -92000,0.4488643,1.5725708,,,,,,,,,,,,,,,,, -92100,0.47122964,1.4999475,,,,,,,,,,,,,,,,, -92200,0.45625713,1.56954,,,,,,,,,,,,,,,,, -92300,0.50051785,1.5443897,,,,,,,,,,,,,,,,, -92400,0.4756778,1.6292763,,,,,,,,,,,,,,,,, -92455,,,0.6670649647712708,1.56660258769989,33.667583219883554,0.6828681230545044,1.465275764465332,30.19248417028501,3000.0,0.6989483833312988,1.3823401927947998,29.880037934469392,3003.0,31950.513469696045,52584.40080785751,31950.513469696045,20629.78096461296,1.2131965160369873,0.0 -92500,0.4749802,1.6065159,,,,,,,,,,,,,,,,, -92600,0.45909935,1.5951021,,,,,,,,,,,,,,,,, -92700,0.47773358,1.5647159,,,,,,,,,,,,,,,,, -92800,0.44440052,1.4970871,,,,,,,,,,,,,,,,, -92900,0.45831254,1.5138912,,,,,,,,,,,,,,,,, -93000,0.47392544,1.5627538,,,,,,,,,,,,,,,,, -93100,0.42444304,1.5267645,,,,,,,,,,,,,,,,, -93200,0.44081378,1.4843752,,,,,,,,,,,,,,,,, -93300,0.42832044,1.4967426,,,,,,,,,,,,,,,,, -93400,0.4467752,1.556437,,,,,,,,,,,,,,,,, -93500,0.49139613,1.6150588,,,,,,,,,,,,,,,,, -93600,0.4601546,1.5550787,,,,,,,,,,,,,,,,, -93700,0.4759106,1.5516186,,,,,,,,,,,,,,,,, -93800,0.46978343,1.5452904,,,,,,,,,,,,,,,,, -93900,0.4694845,1.6118718,,,,,,,,,,,,,,,,, -94000,0.4793868,1.5184298,,,,,,,,,,,,,,,,, -94100,0.45911407,1.5572218,,,,,,,,,,,,,,,,, -94200,0.4804784,1.5339113,,,,,,,,,,,,,,,,, -94300,0.46575397,1.6002016,,,,,,,,,,,,,,,,, -94400,0.48414233,1.5507431,,,,,,,,,,,,,,,,, -94500,0.47882706,1.5877482,,,,,,,,,,,,,,,,, -94600,0.4640837,1.5470889,,,,,,,,,,,,,,,,, -94700,0.4859204,1.5408944,,,,,,,,,,,,,,,,, -94800,0.49826705,1.4814118,,,,,,,,,,,,,,,,, -94889,,,0.6854655146598816,1.4662498235702517,34.65779037824217,0.6841452717781067,1.458268165588379,30.1208920740266,3000.0,0.6998547315597534,1.3732168674468994,30.19016397136998,3003.0,32790.68526005745,53931.0303311348,32790.68526005745,21136.12323880196,1.2523298263549805,0.0 -94900,0.47405615,1.5159314,,,,,,,,,,,,,,,,, -95000,0.5371097,1.5498446,,,,,,,,,,,,,,,,, -95100,0.50337344,1.5280143,,,,,,,,,,,,,,,,, -95200,0.46603334,1.5862867,,,,,,,,,,,,,,,,, -95300,0.47073573,1.5721973,,,,,,,,,,,,,,,,, -95400,0.48409072,1.4435544,,,,,,,,,,,,,,,,, -95500,0.49771735,1.6040386,,,,,,,,,,,,,,,,, -95600,0.4895134,1.5452623,,,,,,,,,,,,,,,,, -95700,0.47005275,1.4948338,,,,,,,,,,,,,,,,, -95800,0.4724412,1.4989154,,,,,,,,,,,,,,,,, -95900,0.48495856,1.506531,,,,,,,,,,,,,,,,, -96000,0.51703894,1.5585052,,,,,,,,,,,,,,,,, -96100,0.4700578,1.5511522,,,,,,,,,,,,,,,,, -96200,0.471242,1.5119348,,,,,,,,,,,,,,,,, -96300,0.4847472,1.6095626,,,,,,,,,,,,,,,,, -96400,0.5084159,1.5324359,,,,,,,,,,,,,,,,, -96500,0.48765883,1.5708448,,,,,,,,,,,,,,,,, -96600,0.9354192,1.5974679,,,,,,,,,,,,,,,,, -96700,0.4819554,1.4859433,,,,,,,,,,,,,,,,, -96800,0.5060641,1.608279,,,,,,,,,,,,,,,,, -96900,0.47330543,1.5314945,,,,,,,,,,,,,,,,, -97000,0.47799364,1.5830384,,,,,,,,,,,,,,,,, -97100,0.49008945,1.4614109,,,,,,,,,,,,,,,,, -97200,0.4818047,1.5461885,,,,,,,,,,,,,,,,, -97300,0.5051359,1.5978713,,,,,,,,,,,,,,,,, -97323,,,0.6764490008354187,1.5147420167922974,34.23678263585544,0.686042308807373,1.448639750480652,30.389317993276933,3000.0,0.7023648023605347,1.3640131950378418,30.336006678823065,3003.0,33630.83154082298,55339.30372548103,33630.83154082298,21704.13444662094,1.2920780181884766,0.0 -97400,0.49238446,1.5071757,,,,,,,,,,,,,,,,, -97500,0.51060253,1.6111917,,,,,,,,,,,,,,,,, -97600,0.47451538,1.518425,,,,,,,,,,,,,,,,, -97700,0.4989859,1.519322,,,,,,,,,,,,,,,,, -97800,0.49491206,1.4918569,,,,,,,,,,,,,,,,, -97900,0.47107965,1.5657686,,,,,,,,,,,,,,,,, -98000,0.5014084,1.5799655,,,,,,,,,,,,,,,,, -98100,0.49718863,1.5015432,,,,,,,,,,,,,,,,, -98200,0.5021967,1.5706224,,,,,,,,,,,,,,,,, -98300,0.4957419,1.5467516,,,,,,,,,,,,,,,,, -98400,0.51302123,1.5223398,,,,,,,,,,,,,,,,, -98500,0.5008836,1.4960575,,,,,,,,,,,,,,,,, -98600,0.5089519,1.5152192,,,,,,,,,,,,,,,,, -98700,0.5366642,1.4588995,,,,,,,,,,,,,,,,, -98800,0.5273013,1.6113603,,,,,,,,,,,,,,,,, -98900,0.5405223,1.4942335,,,,,,,,,,,,,,,,, -99000,0.5005887,1.5092348,,,,,,,,,,,,,,,,, -99100,0.51576877,1.5815332,,,,,,,,,,,,,,,,, -99200,0.47985545,1.4659476,,,,,,,,,,,,,,,,, -99300,0.5334954,1.5399517,,,,,,,,,,,,,,,,, -99400,0.5144965,1.5106454,,,,,,,,,,,,,,,,, -99500,0.54317933,1.5395764,,,,,,,,,,,,,,,,, -99600,0.52257985,1.4763477,,,,,,,,,,,,,,,,, -99700,0.5265461,1.521582,,,,,,,,,,,,,,,,, -99756,,,0.6764749884605408,1.5274385213851929,34.04751952929381,0.6861662864685059,1.449126362800598,30.30977062684815,3000.0,0.7009587287902832,1.360594391822815,30.23364767228356,3003.0,34470.79104375839,56794.07966709137,34470.79104375839,22318.83654236793,1.3307271003723145,0.0 -99800,0.48616356,1.5107118,,,,,,,,,,,,,,,,, -99900,0.54563165,1.5688437,,,,,,,,,,,,,,,,, -100000,0.5210905,1.5697716,,,,,,,,,,,,,,,,, -100100,0.52429044,1.5133811,,,,,,,,,,,,,,,,, -100200,0.5493466,1.5024227,,,,,,,,,,,,,,,,, -100300,0.5037006,1.5059944,,,,,,,,,,,,,,,,, -100400,0.5176078,1.527795,,,,,,,,,,,,,,,,, -100500,0.5467876,1.5682743,,,,,,,,,,,,,,,,, -100600,0.53724504,1.5053971,,,,,,,,,,,,,,,,, -100700,0.5503686,1.5615232,,,,,,,,,,,,,,,,, -100800,0.53932244,1.468762,,,,,,,,,,,,,,,,, -100900,0.52942103,1.4865624,,,,,,,,,,,,,,,,, -101000,0.5360155,1.5218955,,,,,,,,,,,,,,,,, -101100,0.5427353,1.4673154,,,,,,,,,,,,,,,,, -101200,0.53268933,1.5366163,,,,,,,,,,,,,,,,, -101300,0.5389102,1.5878022,,,,,,,,,,,,,,,,, -101400,0.5182128,1.5121446,,,,,,,,,,,,,,,,, -101500,0.5519739,1.5593113,,,,,,,,,,,,,,,,, -101600,0.55663687,1.611114,,,,,,,,,,,,,,,,, -101700,0.52565265,1.5148888,,,,,,,,,,,,,,,,, -101800,0.52034074,1.5321715,,,,,,,,,,,,,,,,, -101900,0.5641928,1.5520308,,,,,,,,,,,,,,,,, -102000,0.49347678,1.490326,,,,,,,,,,,,,,,,, -102100,0.5025257,1.4251237,,,,,,,,,,,,,,,,, -102189,,,0.6875606179237366,1.455111384391785,34.74172992301118,0.6875798106193542,1.437204360961914,30.38678618421871,3000.0,0.7030736207962036,1.3559722900390625,30.37917219885976,3003.0,35310.75461125374,58163.51621007919,35310.75461125374,22848.193954706192,1.3705039024353027,0.0 -102200,0.55379975,1.514541,,,,,,,,,,,,,,,,, -102300,0.5208849,1.4934185,,,,,,,,,,,,,,,,, -102400,0.511811,1.440255,,,,,,,,,,,,,,,,, -102500,0.5444983,1.5436416,,,,,,,,,,,,,,,,, -102600,0.5348397,1.5109433,,,,,,,,,,,,,,,,, -102700,0.56698006,1.6174423,,,,,,,,,,,,,,,,, -102800,0.5290247,1.5626076,,,,,,,,,,,,,,,,, -102900,0.5563187,1.475166,,,,,,,,,,,,,,,,, -103000,0.5286243,1.5534699,,,,,,,,,,,,,,,,, -103100,0.56667227,1.5371336,,,,,,,,,,,,,,,,, -103200,0.5437465,1.4954339,,,,,,,,,,,,,,,,, -103300,0.5387628,1.5416396,,,,,,,,,,,,,,,,, -103400,0.53348464,1.4988445,,,,,,,,,,,,,,,,, -103500,0.5539379,1.488808,,,,,,,,,,,,,,,,, -103600,0.56122047,1.5177265,,,,,,,,,,,,,,,,, -103700,0.5329111,1.5065787,,,,,,,,,,,,,,,,, -103800,0.54882175,1.4337503,,,,,,,,,,,,,,,,, -103900,0.56688017,1.5280448,,,,,,,,,,,,,,,,, -104000,0.5614288,1.4724768,,,,,,,,,,,,,,,,, -104100,0.56458515,1.5274919,,,,,,,,,,,,,,,,, -104200,0.58169353,1.4992981,,,,,,,,,,,,,,,,, -104300,0.5522368,1.523377,,,,,,,,,,,,,,,,, -104400,0.5654191,1.5165659,,,,,,,,,,,,,,,,, -104500,0.56665796,1.5641943,,,,,,,,,,,,,,,,, -104600,0.5467291,1.460899,,,,,,,,,,,,,,,,, -104623,,,0.6816208958625793,1.488276720046997,34.384006160637064,0.6881005764007568,1.4366995096206665,30.368678371068093,3000.0,0.704317033290863,1.349063277244568,30.614908680396404,3003.0,36150.90387272835,59591.86266922951,36150.90387272835,23436.27487707138,1.4109792709350586,0.0 -104700,0.54546505,1.5148993,,,,,,,,,,,,,,,,, -104800,0.5759701,1.4989882,,,,,,,,,,,,,,,,, -104900,0.5951403,1.4872108,,,,,,,,,,,,,,,,, -105000,0.5759742,1.4932826,,,,,,,,,,,,,,,,, -105100,0.576011,1.4603609,,,,,,,,,,,,,,,,, -105200,0.59016883,1.5363458,,,,,,,,,,,,,,,,, -105300,0.5728857,1.5052928,,,,,,,,,,,,,,,,, -105400,0.5645956,1.4581428,,,,,,,,,,,,,,,,, -105500,0.5514652,1.4694637,,,,,,,,,,,,,,,,, -105600,0.5506875,1.4841878,,,,,,,,,,,,,,,,, -105700,0.5758134,1.4720031,,,,,,,,,,,,,,,,, -105800,0.5763623,1.499143,,,,,,,,,,,,,,,,, -105900,0.60377467,1.4521456,,,,,,,,,,,,,,,,, -106000,0.56214935,1.4105823,,,,,,,,,,,,,,,,, -106100,0.59794694,1.4422203,,,,,,,,,,,,,,,,, -106200,0.55785394,1.528998,,,,,,,,,,,,,,,,, -106300,0.60044736,1.5315123,,,,,,,,,,,,,,,,, -106400,0.56923425,1.4902388,,,,,,,,,,,,,,,,, -106500,0.6146497,1.4676532,,,,,,,,,,,,,,,,, -106600,0.5951118,1.5110667,,,,,,,,,,,,,,,,, -106700,0.5835151,1.517517,,,,,,,,,,,,,,,,, -106800,0.5779365,1.4947844,,,,,,,,,,,,,,,,, -106900,0.5761258,1.4388647,,,,,,,,,,,,,,,,, -107000,0.5923493,1.5383124,,,,,,,,,,,,,,,,, -107056,,,0.6984134316444397,1.395855188369751,35.94551384243376,0.6892660856246948,1.4338706731796265,30.78100347405398,3000.0,0.7048166990280151,1.348414421081543,30.54643101468684,3003.0,36990.97637438774,60982.136293411255,36990.97637438774,23986.36045742035,1.4505560398101809,0.0 -107100,0.6015556,1.5014545,,,,,,,,,,,,,,,,, -107200,0.6214591,1.501044,,,,,,,,,,,,,,,,, -107300,0.5943485,1.4694769,,,,,,,,,,,,,,,,, -107400,0.577398,1.4826486,,,,,,,,,,,,,,,,, -107500,0.6078439,1.4064032,,,,,,,,,,,,,,,,, -107600,0.60270417,1.4408993,,,,,,,,,,,,,,,,, -107700,0.5725832,1.4233441,,,,,,,,,,,,,,,,, -107800,0.5907042,1.4650824,,,,,,,,,,,,,,,,, -107900,0.59388244,1.4072176,,,,,,,,,,,,,,,,, -108000,0.6016448,1.4749241,,,,,,,,,,,,,,,,, -108100,0.5944082,1.47403,,,,,,,,,,,,,,,,, -108200,0.657996,1.4739743,,,,,,,,,,,,,,,,, -108300,0.6145592,1.5170325,,,,,,,,,,,,,,,,, -108400,0.5969302,1.4446149,,,,,,,,,,,,,,,,, -108500,0.6198081,1.4576652,,,,,,,,,,,,,,,,, -108600,0.600595,1.4952189,,,,,,,,,,,,,,,,, -108700,0.62990737,1.5052735,,,,,,,,,,,,,,,,, -108800,0.601079,1.4378302,,,,,,,,,,,,,,,,, -108900,0.6282938,1.4249134,,,,,,,,,,,,,,,,, -109000,0.64183426,1.4897765,,,,,,,,,,,,,,,,, -109100,0.5840228,1.4012172,,,,,,,,,,,,,,,,, -109200,0.63126206,1.4993442,,,,,,,,,,,,,,,,, -109300,0.6382842,1.4314058,,,,,,,,,,,,,,,,, -109400,0.63465583,1.4324001,,,,,,,,,,,,,,,,, -109490,,,0.6916165947914124,1.4298664331436155,35.22134886001856,0.6891669034957886,1.4241664409637451,30.790310256380376,3000.0,0.7041543126106262,1.3388774394989014,30.61245571439726,3003.0,37831.17963075638,62354.47929620743,37831.17963075638,24518.378514528275,1.4953322410583496,0.0 -109500,0.6126713,1.5191386,,,,,,,,,,,,,,,,, -109600,0.61934865,1.3982248,,,,,,,,,,,,,,,,, -109700,0.61927617,1.4860255,,,,,,,,,,,,,,,,, -109800,0.6225015,1.3862898,,,,,,,,,,,,,,,,, -109900,0.6355399,1.489513,,,,,,,,,,,,,,,,, -110000,0.61420244,1.45679,,,,,,,,,,,,,,,,, -110100,0.6336134,1.4785582,,,,,,,,,,,,,,,,, -110200,0.6421359,1.495205,,,,,,,,,,,,,,,,, -110300,0.65209067,1.4404503,,,,,,,,,,,,,,,,, -110400,0.6350359,1.5454422,,,,,,,,,,,,,,,,, -110500,0.65518016,1.4364417,,,,,,,,,,,,,,,,, -110600,0.63539624,1.3932291,,,,,,,,,,,,,,,,, -110700,0.6413488,1.5145448,,,,,,,,,,,,,,,,, -110800,0.62659556,1.49479,,,,,,,,,,,,,,,,, -110900,0.63676316,1.5294547,,,,,,,,,,,,,,,,, -111000,0.63225716,1.4589761,,,,,,,,,,,,,,,,, -111100,0.63546467,1.4742788,,,,,,,,,,,,,,,,, -111200,0.62263453,1.4022554,,,,,,,,,,,,,,,,, -111300,0.6418641,1.4692781,,,,,,,,,,,,,,,,, -111400,0.64776814,1.4931462,,,,,,,,,,,,,,,,, -111500,0.64705414,1.4959605,,,,,,,,,,,,,,,,, -111600,0.66438913,1.4600868,,,,,,,,,,,,,,,,, -111700,0.6539847,1.4393818,,,,,,,,,,,,,,,,, -111800,0.658966,1.4305946,,,,,,,,,,,,,,,,, -111900,0.6505793,1.5058178,,,,,,,,,,,,,,,,, -111923,,,0.6884708404541016,1.4604493379592896,35.23403588662617,0.6906672120094299,1.4245483875274658,30.63906031522196,3000.0,0.7065016627311707,1.336126446723938,30.75464808617365,3003.0,38671.1463496685,63734.7251303196,38671.1463496685,25058.54103994369,1.5352284908294678,0.0 -112000,0.6996191,1.489044,,,,,,,,,,,,,,,,, -112100,0.647185,1.4054441,,,,,,,,,,,,,,,,, -112200,0.658151,1.4078462,,,,,,,,,,,,,,,,, -112300,0.65833604,1.4175758,,,,,,,,,,,,,,,,, -112400,0.65495974,1.4034567,,,,,,,,,,,,,,,,, -112500,0.69160527,1.4443954,,,,,,,,,,,,,,,,, -112600,0.6722154,1.4531084,,,,,,,,,,,,,,,,, -112700,0.67253673,1.3780756,,,,,,,,,,,,,,,,, -112800,0.6545162,1.4244944,,,,,,,,,,,,,,,,, -112900,0.6622443,1.4171903,,,,,,,,,,,,,,,,, -113000,0.65238124,1.3551422,,,,,,,,,,,,,,,,, -113100,0.65718853,1.3584236,,,,,,,,,,,,,,,,, -113200,0.6467583,1.3640065,,,,,,,,,,,,,,,,, -113300,0.6632455,1.454846,,,,,,,,,,,,,,,,, -113400,0.68520373,1.4235795,,,,,,,,,,,,,,,,, -113500,0.66082996,1.438221,,,,,,,,,,,,,,,,, -113600,0.6705801,1.4202304,,,,,,,,,,,,,,,,, -113700,0.70380443,1.4312887,,,,,,,,,,,,,,,,, -113800,0.69749224,1.4061569,,,,,,,,,,,,,,,,, -113900,0.661494,1.3861552,,,,,,,,,,,,,,,,, -114000,0.65678316,1.4147568,,,,,,,,,,,,,,,,, -114100,0.6637522,1.4175019,,,,,,,,,,,,,,,,, -114200,0.6695309,1.4519786,,,,,,,,,,,,,,,,, -114300,0.67262226,1.4518267,,,,,,,,,,,,,,,,, -114356,,,0.7008587121963501,1.3816499710083008,36.201180058925885,0.6914979219436646,1.4168438911437988,30.521495408385693,3000.0,0.7080007195472717,1.3302934169769287,30.880743489149467,3003.0,39511.106392622,65129.65449762344,39511.106392622,25613.393503665924,1.575636625289917,0.0 -114400,0.70391136,1.4245954,,,,,,,,,,,,,,,,, -114500,0.70572037,1.5051919,,,,,,,,,,,,,,,,, -114600,0.7123669,1.4109122,,,,,,,,,,,,,,,,, -114700,0.6979916,1.3753947,,,,,,,,,,,,,,,,, -114800,0.68520397,1.40049,,,,,,,,,,,,,,,,, -114900,0.6935783,1.4215658,,,,,,,,,,,,,,,,, -115000,0.66576356,1.3840593,,,,,,,,,,,,,,,,, -115100,0.676703,1.3748962,,,,,,,,,,,,,,,,, -115200,0.74678886,1.4389969,,,,,,,,,,,,,,,,, -115300,0.7289901,1.4497994,,,,,,,,,,,,,,,,, -115400,0.7180165,1.4453174,,,,,,,,,,,,,,,,, -115500,0.717476,1.4724845,,,,,,,,,,,,,,,,, -115600,0.67955154,1.4353204,,,,,,,,,,,,,,,,, -115700,0.7079203,1.4551709,,,,,,,,,,,,,,,,, -115800,0.7260271,1.4140494,,,,,,,,,,,,,,,,, -115900,0.64854485,1.4220434,,,,,,,,,,,,,,,,, -116000,0.7245371,1.4356735,,,,,,,,,,,,,,,,, -116100,0.7036901,1.3877968,,,,,,,,,,,,,,,,, -116200,0.72723174,1.396678,,,,,,,,,,,,,,,,, -116300,0.6773259,1.4114008,,,,,,,,,,,,,,,,, -116400,0.7145967,1.3293922,,,,,,,,,,,,,,,,, -116500,0.68245304,1.4097275,,,,,,,,,,,,,,,,, -116600,0.6906183,1.4275055,,,,,,,,,,,,,,,,, -116700,0.7141227,1.3521283,,,,,,,,,,,,,,,,, -116790,,,0.7000380158424377,1.39041268825531,35.78679107524839,0.6912747621536255,1.4197064638137815,30.574397190601577,3000.0,0.7080704569816589,1.329336166381836,30.986274615235534,3003.0,40351.31870722771,66497.25703215599,40351.31870722771,26140.66422200203,1.618021011352539,0.0 -116800,0.72708315,1.4017978,,,,,,,,,,,,,,,,, -116900,0.7349251,1.398933,,,,,,,,,,,,,,,,, -117000,0.7268908,1.4445717,,,,,,,,,,,,,,,,, -117100,0.735946,1.4210436,,,,,,,,,,,,,,,,, -117200,0.7549469,1.4064811,,,,,,,,,,,,,,,,, -117300,0.717164,1.4399724,,,,,,,,,,,,,,,,, -117400,0.7109751,1.4490156,,,,,,,,,,,,,,,,, -117500,0.7405001,1.4717687,,,,,,,,,,,,,,,,, -117600,0.71260834,1.3524147,,,,,,,,,,,,,,,,, -117700,0.69847083,1.3634114,,,,,,,,,,,,,,,,, -117800,0.70654124,1.3906523,,,,,,,,,,,,,,,,, -117900,0.7628297,1.3938473,,,,,,,,,,,,,,,,, -118000,0.70106286,1.3166395,,,,,,,,,,,,,,,,, -118100,0.7181934,1.4122169,,,,,,,,,,,,,,,,, -118200,0.73812175,1.3608986,,,,,,,,,,,,,,,,, -118300,0.72600716,1.432836,,,,,,,,,,,,,,,,, -118400,0.7105854,1.4035425,,,,,,,,,,,,,,,,, -118500,0.7367743,1.4076629,,,,,,,,,,,,,,,,, -118600,0.71991855,1.3768018,,,,,,,,,,,,,,,,, -118700,0.72914755,1.4635983,,,,,,,,,,,,,,,,, -118800,0.72013247,1.3792199,,,,,,,,,,,,,,,,, -118900,0.74079704,1.3505346,,,,,,,,,,,,,,,,, -119000,0.75231117,1.460871,,,,,,,,,,,,,,,,, -119100,0.7348429,1.3801172,,,,,,,,,,,,,,,,, -119200,0.7500892,1.4928011,,,,,,,,,,,,,,,,, -119223,,,0.7113429307937622,1.3300620317459106,36.80936519873941,0.6935933828353882,1.412529468536377,30.94324378288797,3000.0,0.7090930342674255,1.3264704942703247,31.189875956554665,3003.0,41191.29754805565,67847.41802787781,41191.29754805565,26650.726126909256,1.6613097190856934,0.0 -119300,0.72220254,1.3766801,,,,,,,,,,,,,,,,, -119400,0.76508015,1.4236009,,,,,,,,,,,,,,,,, -119500,0.7468234,1.4089704,,,,,,,,,,,,,,,,, -119600,0.7272625,1.4062521,,,,,,,,,,,,,,,,, -119700,0.7400393,1.4233745,,,,,,,,,,,,,,,,, -119800,0.7236235,1.3342541,,,,,,,,,,,,,,,,, -119900,0.73035586,1.4048924,,,,,,,,,,,,,,,,, -120000,0.74205655,1.3802918,,,,,,,,,,,,,,,,, -120100,0.75505006,1.411388,,,,,,,,,,,,,,,,, -120200,0.75166446,1.368275,,,,,,,,,,,,,,,,, -120300,0.77437365,1.4442909,,,,,,,,,,,,,,,,, -120400,0.7550732,1.4468375,,,,,,,,,,,,,,,,, -120500,0.73274386,1.406294,,,,,,,,,,,,,,,,, -120600,0.74380213,1.4663141,,,,,,,,,,,,,,,,, -120700,0.75802183,1.3819983,,,,,,,,,,,,,,,,, -120800,0.7575241,1.3794954,,,,,,,,,,,,,,,,, -120900,0.7968368,1.4025593,,,,,,,,,,,,,,,,, -121000,0.7656063,1.3346965,,,,,,,,,,,,,,,,, -121100,0.75268936,1.3601054,,,,,,,,,,,,,,,,, -121200,0.77071834,1.4049665,,,,,,,,,,,,,,,,, -121300,0.7467809,1.3090005,,,,,,,,,,,,,,,,, -121400,0.79585713,1.3694255,,,,,,,,,,,,,,,,, -121500,0.73749083,1.3860606,,,,,,,,,,,,,,,,, -121600,0.71401864,1.2896416,,,,,,,,,,,,,,,,, -121656,,,0.707312285900116,1.3521544933319092,36.62262186936308,0.6932833790779114,1.41163969039917,30.834937831447107,3000.0,0.7095810770988464,1.3253158330917358,30.90680463875628,3003.0,42031.31493568421,69267.97033762932,42031.31493568421,27231.141949653625,1.7042453289031982,0.0 -121700,0.74304664,1.4053357,,,,,,,,,,,,,,,,, -121800,0.7820234,1.4545884,,,,,,,,,,,,,,,,, -121900,0.74956256,1.4072173,,,,,,,,,,,,,,,,, -122000,0.7782376,1.351329,,,,,,,,,,,,,,,,, -122100,0.7713306,1.3018951,,,,,,,,,,,,,,,,, -122200,0.76039135,1.4112731,,,,,,,,,,,,,,,,, -122300,0.76107657,1.3550657,,,,,,,,,,,,,,,,, -122400,0.78988916,1.3279034,,,,,,,,,,,,,,,,, -122500,0.73650223,1.3381896,,,,,,,,,,,,,,,,, -122600,0.7447689,1.3741562,,,,,,,,,,,,,,,,, -122700,0.7716305,1.3921105,,,,,,,,,,,,,,,,, -122800,0.75413287,1.419559,,,,,,,,,,,,,,,,, -122900,0.76955795,1.3517134,,,,,,,,,,,,,,,,, -123000,0.77371156,1.4377536,,,,,,,,,,,,,,,,, -123100,0.7814466,1.3896536,,,,,,,,,,,,,,,,, -123200,0.77261543,1.31878,,,,,,,,,,,,,,,,, -123300,0.7829885,1.3862834,,,,,,,,,,,,,,,,, -123400,0.7755984,1.3817443,,,,,,,,,,,,,,,,, -123500,0.82417387,1.3421046,,,,,,,,,,,,,,,,, -123600,0.79071724,1.4274391,,,,,,,,,,,,,,,,, -123700,0.76395506,1.3694023,,,,,,,,,,,,,,,,, -123800,0.7665459,1.3460609,,,,,,,,,,,,,,,,, -123900,0.7819688,1.3668813,,,,,,,,,,,,,,,,, -124000,0.7618885,1.3061774,,,,,,,,,,,,,,,,, -124089,,,0.707473874092102,1.3492178916931152,36.84931955921861,0.6941513419151306,1.4105970859527588,30.96556062109168,3000.0,0.7108128666877747,1.3213181495666504,31.108249855042047,3003.0,42871.35903739929,70667.94770431519,42871.35903739929,27790.95612001419,1.74674654006958,0.0 -124100,0.77860945,1.3290573,,,,,,,,,,,,,,,,, -124200,0.7528982,1.3597314,,,,,,,,,,,,,,,,, -124300,0.7867734,1.3229373,,,,,,,,,,,,,,,,, -124400,0.78263754,1.3960044,,,,,,,,,,,,,,,,, -124500,0.76600933,1.4022229,,,,,,,,,,,,,,,,, -124600,0.7738967,1.3895655,,,,,,,,,,,,,,,,, -124700,0.80894727,1.369769,,,,,,,,,,,,,,,,, -124800,0.8065855,1.4000854,,,,,,,,,,,,,,,,, -124900,0.7575017,1.3866975,,,,,,,,,,,,,,,,, -125000,0.774752,1.3790083,,,,,,,,,,,,,,,,, -125100,0.79491687,1.3572285,,,,,,,,,,,,,,,,, -125200,0.7703295,1.3110732,,,,,,,,,,,,,,,,, -125300,0.75889736,1.3573301,,,,,,,,,,,,,,,,, -125400,0.7752876,1.4433813,,,,,,,,,,,,,,,,, -125500,0.7993141,1.3861037,,,,,,,,,,,,,,,,, -125600,0.7792761,1.4076571,,,,,,,,,,,,,,,,, -125700,0.7687322,1.3497807,,,,,,,,,,,,,,,,, -125800,0.77783024,1.3903399,,,,,,,,,,,,,,,,, -125900,0.77962655,1.3716109,,,,,,,,,,,,,,,,, -126000,0.7946975,1.3687953,,,,,,,,,,,,,,,,, -126100,0.76133937,1.3880109,,,,,,,,,,,,,,,,, -126200,0.7768338,1.3306822,,,,,,,,,,,,,,,,, -126300,0.78693545,1.4796946,,,,,,,,,,,,,,,,, -126400,0.78017795,1.3815368,,,,,,,,,,,,,,,,, -126500,0.7493745,1.3594286,,,,,,,,,,,,,,,,, -126522,,,0.7100922465324402,1.3384053707122805,37.03802375742551,0.6936057806015015,1.4109114408493042,30.960405634997983,3000.0,0.7102434635162354,1.3205872774124146,31.24458192290439,3003.0,43711.43947052956,72065.4435763359,43711.43947052956,28348.25263595581,1.7892396450042725,0.0 -126600,0.76101416,1.2907358,,,,,,,,,,,,,,,,, -126700,0.7653846,1.2851567,,,,,,,,,,,,,,,,, -126800,0.8270454,1.355536,,,,,,,,,,,,,,,,, -126900,0.76457155,1.3068538,,,,,,,,,,,,,,,,, -127000,0.7763685,1.3053145,,,,,,,,,,,,,,,,, -127100,0.7709277,1.3288943,,,,,,,,,,,,,,,,, -127200,0.7482085,1.3213636,,,,,,,,,,,,,,,,, -127300,0.7637961,1.3805083,,,,,,,,,,,,,,,,, -127400,0.772419,1.3034099,,,,,,,,,,,,,,,,, -127500,0.7918982,1.4089962,,,,,,,,,,,,,,,,, -127600,0.7701341,1.3568368,,,,,,,,,,,,,,,,, -127700,0.781157,1.3695096,,,,,,,,,,,,,,,,, -127800,0.7944655,1.3373679,,,,,,,,,,,,,,,,, -127900,0.77488613,1.3257159,,,,,,,,,,,,,,,,, -128000,0.7852339,1.3676895,,,,,,,,,,,,,,,,, -128100,0.77486,1.4017545,,,,,,,,,,,,,,,,, -128200,0.76405865,1.3875636,,,,,,,,,,,,,,,,, -128300,0.7905666,1.4025707,,,,,,,,,,,,,,,,, -128400,0.79479283,1.3382286,,,,,,,,,,,,,,,,, -128500,0.76637447,1.3583016,,,,,,,,,,,,,,,,, -128600,0.7759486,1.3943862,,,,,,,,,,,,,,,,, -128700,0.81028485,1.3722844,,,,,,,,,,,,,,,,, -128800,0.77751434,1.3872436,,,,,,,,,,,,,,,,, -128900,0.76863164,1.3646301,,,,,,,,,,,,,,,,, -128955,,,0.7126824855804443,1.3262940645217896,36.65160501353665,0.6942629218101501,1.4080395698547363,31.02216745186169,3000.0,0.7098715901374817,1.3178081512451172,31.098821084681767,3003.0,44551.3946352005,73475.28617930412,44551.3946352005,28918.01958155632,1.8332068920135496,0.0 -129000,0.79781324,1.3763785,,,,,,,,,,,,,,,,, -129100,0.8364594,1.3128244,,,,,,,,,,,,,,,,, -129200,0.8077424,1.3683293,,,,,,,,,,,,,,,,, -129300,0.786902,1.3555657,,,,,,,,,,,,,,,,, -129400,0.76001805,1.3199704,,,,,,,,,,,,,,,,, -129500,0.7702387,1.320538,,,,,,,,,,,,,,,,, -129600,0.7543605,1.3244925,,,,,,,,,,,,,,,,, -129700,0.79615986,1.371974,,,,,,,,,,,,,,,,, -129800,0.7795013,1.3136736,,,,,,,,,,,,,,,,, -129900,0.7717003,1.4056039,,,,,,,,,,,,,,,,, -130000,0.8025444,1.298372,,,,,,,,,,,,,,,,, -130100,0.77663374,1.3687011,,,,,,,,,,,,,,,,, -130200,0.77739125,1.3926336,,,,,,,,,,,,,,,,, -130300,0.7788351,1.424436,,,,,,,,,,,,,,,,, -130400,0.7894734,1.3631905,,,,,,,,,,,,,,,,, -130500,0.7989533,1.372098,,,,,,,,,,,,,,,,, -130600,0.7674482,1.3583901,,,,,,,,,,,,,,,,, -130700,0.7904552,1.3483679,,,,,,,,,,,,,,,,, -130800,0.7904304,1.4087442,,,,,,,,,,,,,,,,, -130900,0.7701541,1.3276004,,,,,,,,,,,,,,,,, -131000,0.75930566,1.3054951,,,,,,,,,,,,,,,,, -131100,0.7689815,1.3860377,,,,,,,,,,,,,,,,, -131200,0.78275055,1.3945334,,,,,,,,,,,,,,,,, -131300,0.79805744,1.428284,,,,,,,,,,,,,,,,, -131389,,,0.712949275970459,1.319365382194519,36.297809192303134,0.6945109367370605,1.408849596977234,30.947350711661613,3000.0,0.7102085947990417,1.3181660175323486,31.13512948906848,3003.0,45391.56788253784,74884.97759580612,45391.56788253784,29487.417108535767,1.877014636993408,0.0 -131400,0.78481287,1.3441764,,,,,,,,,,,,,,,,, -131500,0.79932106,1.3539852,,,,,,,,,,,,,,,,, -131600,0.79758865,1.3606266,,,,,,,,,,,,,,,,, -131700,0.80779725,1.3790811,,,,,,,,,,,,,,,,, -131800,0.7888184,1.3794757,,,,,,,,,,,,,,,,, -131900,0.77362114,1.2969183,,,,,,,,,,,,,,,,, -132000,0.7952101,1.3892677,,,,,,,,,,,,,,,,, -132100,0.78486216,1.2871821,,,,,,,,,,,,,,,,, -132200,0.768173,1.3630207,,,,,,,,,,,,,,,,, -132300,0.78003526,1.3578666,,,,,,,,,,,,,,,,, -132400,0.806174,1.4027734,,,,,,,,,,,,,,,,, -132500,0.7715653,1.3752556,,,,,,,,,,,,,,,,, -132600,0.7765017,1.3301101,,,,,,,,,,,,,,,,, -132700,0.78352445,1.3814825,,,,,,,,,,,,,,,,, -132800,0.7816184,1.4152424,,,,,,,,,,,,,,,,, -132900,0.7617485,1.4289442,,,,,,,,,,,,,,,,, -133000,0.7714739,1.3942658,,,,,,,,,,,,,,,,, -133100,0.7772716,1.3298506,,,,,,,,,,,,,,,,, -133200,0.77754885,1.3444813,,,,,,,,,,,,,,,,, -133300,0.76424474,1.4284052,,,,,,,,,,,,,,,,, -133333,,,0.7099116444587708,1.3356733322143557,37.02394184889104,0.6946597099304199,1.4091264009475708,31.011342101754487,3000.0,0.7103132009506226,1.3183356523513794,31.18299976012298,3003.0,46062.535388469696,76133.92356038094,46062.535388469696,30065.29164481163,1.9203734397888184,0.0 -133333,,,,,,,,,,,,,,46062.535388469696,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/eval_measurements.csv deleted file mode 100644 index 77dfd5e4b..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/eval_measurements.csv +++ /dev/null @@ -1,57 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -857.3925604820251,0.0,27.777300596237183,1,0,27.777300596237183,0.0007088489946909,0.0,11.036273956298828,3003,885.1699132919312,0.0006736774812452,0.0,11.026397705078123,0.0004835649742744,0.0,11.047277450561523,3000 -1324.3492550849917,0.0198400020599365,867.7713329792023,2432,0,867.7713329792023,0.5400848388671875,18.86934210352786,2.5434446334838867,3003,2192.2171177864075,0.5334533452987671,24.246267876536,2.6014294624328613,0.5396957397460938,20.273622232300472,2.53312087059021,3000 -1828.375424146652,0.0461373329162597,1707.9139828681946,4865,0,1707.9139828681946,0.6017082333564758,22.691728876503863,2.0296666622161865,3003,3536.4875156879425,0.5815564393997192,27.39705149030112,2.21299695968628,0.5966076254844666,23.915241366267704,2.080467700958252,3000 -2322.1379055976868,0.0710291862487793,2547.882658958435,7297,0,2547.882658958435,0.6117367148399353,23.37913525773548,1.962164282798767,3003,4870.319628000259,0.5896901488304138,27.75413258659114,2.140878677368164,0.6057829260826111,24.422510760879085,2.008117437362671,3000 -2821.1698200702667,0.098750352859497,3388.040452480316,9730,0,3388.040452480316,0.6183835864067078,23.38670207942713,1.9371329545974727,3003,6209.612332820892,0.5907862186431885,27.758313068089063,2.14593243598938,0.6103581786155701,24.93721700930189,1.9752626419067385,3000 -3356.9464728832245,0.1251380443572998,4228.165338039398,12163,0,4228.165338039398,0.6207425594329834,23.72557107950406,1.8987210988998413,3003,7585.615716218948,0.5988882184028625,27.654304678780385,2.0783145427703857,0.6146730780601501,24.84778671859338,1.9561139345169067,3000 -3921.0114846229553,0.152226448059082,5068.237158060074,14596,0,5068.237158060074,0.6213933229446411,23.30884255123911,1.8927571773529053,3003,8989.855357408524,0.5957963466644287,27.566099296535413,2.0913431644439697,0.6168429255485535,24.633238759292187,1.918594479560852,3000 -4409.076854467392,0.1800706386566162,5908.252819776535,17029,0,5908.252819776535,0.628063440322876,24.449780344884086,1.8578368425369265,3003,10318.040254831314,0.5992476940155029,27.728670866665368,2.066373348236084,0.6194343566894531,25.00878222369035,1.9151461124420168,3000 -4927.977090597153,0.2080614566802978,6748.244628667831,19461,0,6748.244628667831,0.6275056600570679,24.24569041878532,1.8592458963394165,3003,11677.037281274796,0.6093730926513672,28.350201916222264,1.9898121356964111,0.6198559403419495,25.218665959343333,1.9141809940338133,3000 -5595.293299913406,0.2365128993988037,7588.188913345337,21893,0,7588.188913345337,0.6247748732566833,23.55798125524869,1.851194977760315,3003,13184.40267777443,0.6032451391220093,27.34081747482811,2.035078763961792,0.6205874681472778,25.087785595498843,1.8935205936431885,3000 -6137.813019990921,0.2643194198608398,8428.134506702423,24326,0,8428.134506702423,0.6287490725517273,24.262781323959945,1.8355058431625368,3003,14566.971681118011,0.6028209328651428,28.100459720929965,2.0420238971710205,0.6226333379745483,24.9253727665481,1.877811908721924,3000 -6695.690530538559,0.2926228046417236,9268.36467552185,26759,0,9268.36467552185,0.6294578909873962,24.53579871136742,1.8357408046722408,3003,15965.183842658997,0.6044977307319641,28.284201263564533,2.021979808807373,0.624183177947998,25.36061576960488,1.8740408420562744,3000 -7188.456845998764,0.3211045265197754,10108.279440879822,29192,0,10108.279440879822,0.6307710409164429,24.321099015768755,1.8219873905181885,3003,17297.96948671341,0.6068159937858582,28.95645104575732,2.0194811820983887,0.623960018157959,25.25651891429717,1.871357560157776,3000 -7705.389865159988,0.3495962619781494,10948.29575920105,31625,0,10948.29575920105,0.6341409683227539,25.021743492191217,1.80830192565918,3003,18655.023404359818,0.6236925721168518,30.04438775326001,1.8840527534484863,0.6268489956855774,25.799430214237688,1.861114740371704,3000 -8303.139657497406,0.3803756237030029,11788.428247213364,34058,0,11788.428247213364,0.6344314813613892,24.71210323790936,1.792935132980347,3003,20093.013048648834,0.6058012843132019,28.87051163861552,2.0277810096740723,0.6283245086669922,25.78896790070437,1.8390883207321167,3000 -8834.580487966537,0.4115097522735595,12628.50937962532,36491,0,12628.50937962532,0.6373133659362793,25.120160265789195,1.7856651544570925,3003,21464.642553329468,0.608106255531311,28.904658555369576,2.017210006713867,0.6289692521095276,26.038305212224035,1.84528398513794,3000 -9329.693039894104,0.4427659511566162,13468.485274791718,38923,0,13468.485274791718,0.6345128417015076,24.61682099042204,1.7959977388381958,3003,22799.8377776146,0.6152269244194031,28.735497170260977,1.9483067989349363,0.6286717057228088,25.679533221535827,1.8417868614196773,3000 -9775.80866074562,0.4757249355316162,14308.5158598423,41356,0,14308.5158598423,0.6367555856704712,24.162457512803574,1.777856707572937,3003,24086.092776298523,0.6089126467704773,28.389003946961708,1.994400978088379,0.6299487948417664,25.0824188802018,1.834946632385254,3000 -10531.183442354202,0.5067262649536133,15148.529332399368,43789,0,15148.529332399368,0.6381616592407227,25.0090002959881,1.7659757137298584,3003,25681.5877430439,0.6921001672744751,35.19845928097932,1.4866576194763184,0.6321682333946228,25.890601540900896,1.8150848150253296,3000 -11052.160829782486,0.5364077091217041,15988.571270942688,46222,0,15988.571270942688,0.6401255130767822,24.83304160183564,1.7569904327392578,3003,27042.712433576584,0.613900899887085,28.75841070693544,1.9534741640090945,0.6313250660896301,25.706606452819305,1.8070567846298216,3000 -11559.00778746605,0.5670428276062012,16828.72014260292,48655,0,16828.72014260292,0.6380454301834106,24.83728357032393,1.7637287378311155,3003,28389.81461524964,0.6083666682243347,29.04350547005292,1.9990887641906736,0.6298372149467468,25.873173393998723,1.8104116916656487,3000 -12068.680082798004,0.6013104915618896,17668.943604707718,51088,0,17668.943604707718,0.6439370512962341,25.48859377960028,1.7313753366470337,3003,29739.8209412098,0.618570864200592,29.38637507894043,1.924091100692749,0.6350200176239014,26.589383299087874,1.7849268913269043,3000 -12595.083347082138,0.6347942352294922,18509.00431752205,53521,0,18509.00431752205,0.6456336379051208,25.94893984098375,1.7239888906478882,3003,31106.393884658813,0.6205177903175354,29.460057986444184,1.918626427650452,0.6394093036651611,26.36329864950495,1.774772047996521,3000 -13107.286288261414,0.6692543029785156,19348.94182872772,55954,0,19348.94182872772,0.650665283203125,26.0390079997861,1.6983579397201538,3003,32458.6449136734,0.616606593132019,29.716919854833712,1.9429373741149905,0.6380329728126526,26.87983601191332,1.7670636177062988,3000 -13734.115124940872,0.7012717723846436,20189.042788743973,58387,0,20189.042788743973,0.6511068940162659,25.839829990207257,1.6911041736602783,3003,33925.682745695114,0.6255083680152893,29.25000924980348,1.883889317512512,0.63966965675354,26.65371403234823,1.7545181512832642,3000 -14196.984060525894,0.7340254783630371,21029.19118499756,60820,0,21029.19118499756,0.6530939936637878,25.896172575343115,1.672269582748413,3003,35228.80893397331,0.6229991912841797,29.73174634459232,1.90169644355774,0.6420503258705139,26.703019452981483,1.7479721307754517,3000 -14683.410673379898,0.7670059204101562,21869.32551908493,63253,0,21869.32551908493,0.6543838381767273,26.206675737115404,1.6732841730117798,3003,36555.478934049606,0.6290479302406311,30.297907849320147,1.827910304069519,0.6419262886047363,26.966711675602447,1.7374597787857056,3000 -15256.493627786636,0.8021972179412842,22709.30819892884,65685,0,22709.30819892884,0.6523967385292053,26.01282070100662,1.6633965969085691,3003,37968.656623363495,0.6229196786880493,29.82231045733329,1.90165114402771,0.6441333293914795,27.3569961752832,1.7247092723846436,3000 -15775.118801116943,0.8397126197814941,23549.419113636017,68118,0,23549.419113636017,0.6561850309371948,26.68360440541797,1.6488022804260254,3003,39327.50707483292,0.6257489323616028,29.903225654359392,1.8845466375350952,0.6469975709915161,27.369239895807823,1.7029153108596802,3000 -16256.143748998642,0.8760302066802979,24389.49199271202,70552,0,24389.49199271202,0.6587531566619873,26.84871294586024,1.639639139175415,3003,40648.717832803726,0.6302703022956848,30.171382342838893,1.8330473899841309,0.649502158164978,27.07700891501412,1.6973282098770142,3000 -16764.72614622116,0.911757230758667,25229.58637213707,72985,0,25229.58637213707,0.6605426669120789,26.85367362042821,1.6174166202545166,3003,41997.50688147545,0.6285916566848755,30.52954279286608,1.8592113256454468,0.6510024666786194,27.65884298530821,1.6859381198883057,3000 -17357.010063409805,0.9460341930389404,26069.70525288582,75418,0,26069.70525288582,0.6633432507514954,27.128032352674563,1.6029932498931885,3003,43430.02218937874,0.6453965306282043,31.558854951636462,1.716774582862854,0.6527011394500732,27.563748681111136,1.681450605392456,3000 -17892.767714738846,0.9847004413604736,26909.717646598816,77852,0,26909.717646598816,0.667956531047821,27.65037334871951,1.5868098735809326,3003,44805.90866136551,0.6361210942268372,30.33095928486013,1.8001832962036133,0.6541270613670349,27.47509347000726,1.6627322435379028,3000 -18396.198426246643,1.0218725204467771,27749.741545915604,80285,0,27749.741545915604,0.6648771166801453,27.20680830071649,1.5862025022506714,3003,46149.47729349136,0.6344415545463562,30.737479092771107,1.814192771911621,0.6556149125099182,27.724983358387743,1.6477035284042358,3000 -18964.878935098648,1.0599958896636963,28589.92642354965,82719,0,28589.92642354965,0.6703736186027527,27.516618205330463,1.5644491910934448,3003,47558.457666397095,0.6453226804733276,31.61777031006529,1.7232457399368286,0.6575492024421692,28.11244713518492,1.6341495513916016,3000 -19489.93882274628,1.096651315689087,29430.129603147507,85153,0,29430.129603147507,0.672267735004425,27.951467340137658,1.5401736497879028,3003,48923.83412575722,0.6420574188232422,31.85978635812725,1.7634731531143188,0.6593594551086426,28.13691576680429,1.616284966468811,3000 -20041.71503210068,1.135914325714111,30270.221952676773,87586,0,30270.221952676773,0.6754401326179504,27.99981196715836,1.5277425050735474,3003,50315.81853723526,0.706721842288971,36.76419684639784,1.3691809177398682,0.664182722568512,28.74037977809781,1.5927585363388062,3000 -20610.94938087464,1.1731517314910889,31110.187237501144,90019,0,31110.187237501144,0.6758701205253601,27.95895404149604,1.5183173418045044,3003,51725.13202857971,0.6472726464271545,31.716836297151044,1.7228424549102783,0.6628807783126831,28.02380074296613,1.5885562896728516,3000 -21215.34196662903,1.2124512195587158,31950.0975458622,92452,0,31950.0975458622,0.6786009073257446,28.248444795106696,1.4941082000732422,3003,53169.55118060112,0.6476435661315918,31.69203925629525,1.7178657054901123,0.6662781834602356,28.56802874801828,1.5724620819091797,3000 -21692.634751558304,1.252263069152832,32790.26389718056,94886,0,32790.26389718056,0.6810412406921387,28.72655015719533,1.4820163249969482,3003,54487.12789297104,0.6604101657867432,32.403113391945894,1.6313503980636597,0.6679768562316895,28.760747178515107,1.559630036354065,3000 -22176.063058376312,1.2936749458312988,33630.172709703445,97319,0,33630.172709703445,0.6848992109298706,28.58278102882752,1.4625848531723022,3003,55810.58344745636,0.6549776792526245,32.00376825833275,1.674398422241211,0.6725149154663086,28.80237879604785,1.5401153564453125,3000 -22696.30715584755,1.3333721160888672,34470.056736946106,99752,0,34470.056736946106,0.6886526346206665,28.744171233574672,1.4499832391738892,3003,57170.82868671417,0.6542562246322632,32.45218523737319,1.6796025037765503,0.6749699115753174,29.279505165766786,1.5269558429718018,3000 -23320.24372291565,1.3716270923614502,35310.28398871422,102186,0,35310.28398871422,0.6881877779960632,28.851760085084145,1.4374244213104248,3003,58635.10763931274,0.6620293259620667,32.64789488614793,1.618472933769226,0.6765817999839783,29.198657105886326,1.5140327215194702,3000 -23875.32908177376,1.4112050533294678,36150.45073056221,104621,0,36150.45073056221,0.693591296672821,29.587910915299563,1.4133522510528564,3003,60030.47633481026,0.6610704660415649,32.60958248876339,1.626672625541687,0.6787516474723816,29.66641032835177,1.5002140998840332,3000 -24597.542265176773,1.4509341716766355,36990.67479014397,107055,0,36990.67479014397,0.6934751272201538,29.51269855331246,1.4053597450256348,3003,61593.02966165543,0.6778357028961182,33.98085596494906,1.5199074745178225,0.6803759336471558,29.69537641347098,1.4900200366973877,3000 -25250.408385038376,1.4924664497375488,37830.68860411644,109489,0,37830.68860411644,0.6989367604255676,29.94549352250294,1.3856124877929688,3003,63086.02798628807,0.6716055870056152,33.4808430765383,1.5597528219223022,0.6831905245780945,30.04091275834891,1.4681103229522705,3000 -25792.438461780548,1.5327885150909424,38670.9139418602,111923,0,38670.9139418602,0.6991691589355469,29.85416362140656,1.375052571296692,3003,64468.39996099472,0.6700042486190796,33.42104699643389,1.5730342864990234,0.6848644018173218,30.067353708829728,1.4581775665283203,3000 -26359.942928552628,1.5786361694335938,39510.85387992859,114357,0,39510.85387992859,0.702213704586029,30.16390519547844,1.3610974550247192,3003,65875.96654629707,0.6836004257202148,34.42451429782527,1.4913737773895264,0.6876541972160339,30.17884093970701,1.4455455541610718,3000 -26933.835524082184,1.6199851036071775,40350.74137663841,116791,0,40350.74137663841,0.7040497660636902,30.197523386417203,1.344609022140503,3003,67289.86536455154,0.6783912181854248,34.31626903846658,1.516280174255371,0.6888321042060852,30.4859611177528,1.4344900846481323,3000 -27627.37395954132,1.6613929271697998,41190.94413161278,119225,0,41190.94413161278,0.704317033290863,30.53495047135109,1.340147614479065,3003,68823.72482323647,0.6942856907844543,35.552780823300324,1.4276137351989746,0.6898612380027771,30.59610711383455,1.4264612197875977,3000 -28212.71403479576,1.703831911087036,42030.92356848717,121659,0,42030.92356848717,0.708349347114563,30.80779430639281,1.3267370462417605,3003,70249.16306734085,0.6907327771186829,34.881001369806654,1.4434564113616943,0.6906051635742188,30.869333107168814,1.420662522315979,3000 -28829.68808722496,1.747204303741455,42871.01779818535,124093,0,42871.01779818535,0.708453893661499,30.912976176032,1.3230172395706177,3003,71706.35120844841,0.6904439330101013,35.23399677830485,1.4435969591140747,0.6916218996047974,30.74350009523277,1.41691792011261,3000 -29428.44505548477,1.790480613708496,43711.224172115326,126527,0,43711.224172115326,0.7094765305519104,30.907800321136087,1.3161181211471558,3003,73145.43577122688,0.6974536180496216,35.86837829412484,1.413323998451233,0.693134605884552,30.930916769026155,1.4092453718185425,3000 -30052.566106319427,1.8353376388549805,44551.23852276802,128961,0,44551.23852276802,0.7107547521591187,31.213135886632305,1.3119847774505615,3003,74609.69246411324,0.6983780860900879,35.716242509446154,1.4064826965332031,0.692638635635376,31.06077647495805,1.40710711479187,3000 -30635.33205962181,1.8811628818511963,45391.26353955269,131395,0,45391.26353955269,0.711498498916626,31.106689500794637,1.3109595775604248,3003,76032.60543179512,0.6976325511932373,35.93356190724933,1.408244490623474,0.6933081746101379,30.89430551114946,1.4070932865142822,3000 -31230.08441901207,1.925365686416626,46060.09185528755,133333,0,46060.09185528755,0.7112892866134644,31.039245815704213,1.3105939626693726,3003,77296.29063916206,0.6962321400642395,35.18021952446278,1.4215978384017944,0.6931717991828918,30.9194758232504,1.4065862894058228,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/measurements.csv deleted file mode 100644 index 3e7943cee..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/measurements.csv +++ /dev/null @@ -1,1392 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.7015142,11.023183,,,,,,,,,,,,,,,,, -1,,,0.0006736774812452,11.026397705078123,0.0,0.0004835649742744,11.047277450561523,0.0,3000.0,0.0007088489946909,11.036273956298828,0.0,3003.0,27.777300596237183,885.1699132919312,27.777300596237183,857.3925604820251,0.0,0.0 -100,0.7236198,7.5908523,,,,,,,,,,,,,,,,, -200,0.49577656,6.6550875,,,,,,,,,,,,,,,,, -300,0.46937773,5.877396,,,,,,,,,,,,,,,,, -400,0.6034638,5.4385247,,,,,,,,,,,,,,,,, -500,0.5935078,5.072424,,,,,,,,,,,,,,,,, -600,0.5109654,4.7100377,,,,,,,,,,,,,,,,, -700,0.59780854,4.519241,,,,,,,,,,,,,,,,, -800,0.5068849,4.207457,,,,,,,,,,,,,,,,, -900,0.48865145,4.0253754,,,,,,,,,,,,,,,,, -1000,0.49260408,3.8581696,,,,,,,,,,,,,,,,, -1100,0.4748168,3.732478,,,,,,,,,,,,,,,,, -1200,0.4123902,3.5146642,,,,,,,,,,,,,,,,, -1300,0.3919647,3.536646,,,,,,,,,,,,,,,,, -1400,0.319177,3.3566096,,,,,,,,,,,,,,,,, -1500,0.28331584,3.1859963,,,,,,,,,,,,,,,,, -1600,0.28794387,3.1678243,,,,,,,,,,,,,,,,, -1700,0.27733174,3.1077132,,,,,,,,,,,,,,,,, -1800,0.34164575,2.8980412,,,,,,,,,,,,,,,,, -1900,0.22663194,2.8826196,,,,,,,,,,,,,,,,, -2000,0.28778556,2.8468094,,,,,,,,,,,,,,,,, -2100,0.19944194,2.6527624,,,,,,,,,,,,,,,,, -2200,0.23008999,2.7333267,,,,,,,,,,,,,,,,, -2300,0.18532096,2.6223009,,,,,,,,,,,,,,,,, -2400,0.1704208,2.5913694,,,,,,,,,,,,,,,,, -2432,,,0.5334533452987671,2.6014294624328613,24.246267876536,0.5396957397460938,2.53312087059021,20.273622232300472,3000.0,0.5400848388671875,2.5434446334838867,18.86934210352786,3003.0,867.7713329792023,2192.2171177864075,867.7713329792023,1324.3492550849917,0.0198400020599365,0.0 -2500,0.18613705,2.6048496,,,,,,,,,,,,,,,,, -2600,0.26949447,2.480439,,,,,,,,,,,,,,,,, -2700,0.17555119,2.4506319,,,,,,,,,,,,,,,,, -2800,0.23116297,2.4826698,,,,,,,,,,,,,,,,, -2900,0.18724872,2.3711886,,,,,,,,,,,,,,,,, -3000,0.19500624,2.4406037,,,,,,,,,,,,,,,,, -3100,0.39525998,2.4328847,,,,,,,,,,,,,,,,, -3200,0.17980145,2.3488643,,,,,,,,,,,,,,,,, -3300,0.29423034,2.3208938,,,,,,,,,,,,,,,,, -3400,0.4565339,2.374265,,,,,,,,,,,,,,,,, -3500,0.26408377,2.2654128,,,,,,,,,,,,,,,,, -3600,0.39006895,2.3537433,,,,,,,,,,,,,,,,, -3700,0.45069787,2.2726011,,,,,,,,,,,,,,,,, -3800,0.27867627,2.2758439,,,,,,,,,,,,,,,,, -3900,0.4069121,2.2538776,,,,,,,,,,,,,,,,, -4000,0.6701394,2.304652,,,,,,,,,,,,,,,,, -4100,0.24824458,2.1842058,,,,,,,,,,,,,,,,, -4200,0.2453306,2.2791495,,,,,,,,,,,,,,,,, -4300,0.30309346,2.20204,,,,,,,,,,,,,,,,, -4400,0.28992215,2.1736815,,,,,,,,,,,,,,,,, -4500,0.37408623,2.138055,,,,,,,,,,,,,,,,, -4600,0.33052063,2.1987484,,,,,,,,,,,,,,,,, -4700,0.3222294,2.1569815,,,,,,,,,,,,,,,,, -4800,0.5465014,2.305927,,,,,,,,,,,,,,,,, -4865,,,0.5815564393997192,2.21299695968628,27.39705149030112,0.5966076254844666,2.080467700958252,23.915241366267704,3000.0,0.6017082333564758,2.0296666622161865,22.691728876503863,3003.0,1707.9139828681946,3536.4875156879425,1707.9139828681946,1828.375424146652,0.0461373329162597,0.0 -4900,0.3116285,2.1742363,,,,,,,,,,,,,,,,, -5000,0.47189292,2.1823397,,,,,,,,,,,,,,,,, -5100,0.28573725,2.1678064,,,,,,,,,,,,,,,,, -5200,0.40765515,2.2244234,,,,,,,,,,,,,,,,, -5300,0.27913216,2.1580954,,,,,,,,,,,,,,,,, -5400,0.27483308,2.1668224,,,,,,,,,,,,,,,,, -5500,0.42815453,2.210702,,,,,,,,,,,,,,,,, -5600,0.31906524,2.1443734,,,,,,,,,,,,,,,,, -5700,0.35687816,2.15439,,,,,,,,,,,,,,,,, -5800,0.4005029,2.138711,,,,,,,,,,,,,,,,, -5900,0.60754323,2.1346056,,,,,,,,,,,,,,,,, -6000,0.27076116,2.1853993,,,,,,,,,,,,,,,,, -6100,0.5276182,2.2374358,,,,,,,,,,,,,,,,, -6200,0.6666398,2.254815,,,,,,,,,,,,,,,,, -6300,0.5987547,2.2306516,,,,,,,,,,,,,,,,, -6400,0.3211906,2.149051,,,,,,,,,,,,,,,,, -6500,0.4482263,2.1476843,,,,,,,,,,,,,,,,, -6600,0.49597406,2.0925207,,,,,,,,,,,,,,,,, -6700,0.59542376,2.072597,,,,,,,,,,,,,,,,, -6800,0.4441239,2.1379635,,,,,,,,,,,,,,,,, -6900,0.33241132,2.1444268,,,,,,,,,,,,,,,,, -7000,0.30868173,2.2461114,,,,,,,,,,,,,,,,, -7100,0.4137293,2.113689,,,,,,,,,,,,,,,,, -7200,0.3391607,2.0802417,,,,,,,,,,,,,,,,, -7297,,,0.5896901488304138,2.140878677368164,27.75413258659114,0.6057829260826111,2.008117437362671,24.422510760879085,3000.0,0.6117367148399353,1.962164282798767,23.37913525773548,3003.0,2547.882658958435,4870.319628000259,2547.882658958435,2322.1379055976868,0.0710291862487793,0.0 -7300,0.45769668,2.0605543,,,,,,,,,,,,,,,,, -7400,0.42904767,2.1472108,,,,,,,,,,,,,,,,, -7500,0.50550735,2.105152,,,,,,,,,,,,,,,,, -7600,0.35790962,2.1530864,,,,,,,,,,,,,,,,, -7700,0.60333157,2.1610103,,,,,,,,,,,,,,,,, -7800,0.31766802,2.0234168,,,,,,,,,,,,,,,,, -7900,0.3203239,2.1719227,,,,,,,,,,,,,,,,, -8000,0.4762782,2.1590333,,,,,,,,,,,,,,,,, -8100,0.2810874,2.1598892,,,,,,,,,,,,,,,,, -8200,0.47374168,2.125043,,,,,,,,,,,,,,,,, -8300,0.6396384,2.114261,,,,,,,,,,,,,,,,, -8400,0.34894684,2.2119045,,,,,,,,,,,,,,,,, -8500,0.28953132,2.1401043,,,,,,,,,,,,,,,,, -8600,0.5196546,2.0753613,,,,,,,,,,,,,,,,, -8700,0.28177923,2.0559733,,,,,,,,,,,,,,,,, -8800,0.27762288,2.1767364,,,,,,,,,,,,,,,,, -8900,0.3321702,2.0968938,,,,,,,,,,,,,,,,, -9000,0.30691546,2.214387,,,,,,,,,,,,,,,,, -9100,0.33201283,2.1224256,,,,,,,,,,,,,,,,, -9200,0.5141501,2.2470531,,,,,,,,,,,,,,,,, -9300,0.40818498,2.1181884,,,,,,,,,,,,,,,,, -9400,0.3887075,2.1882114,,,,,,,,,,,,,,,,, -9500,0.29302925,2.0929575,,,,,,,,,,,,,,,,, -9600,0.37605882,2.1278086,,,,,,,,,,,,,,,,, -9700,0.31887344,2.069969,,,,,,,,,,,,,,,,, -9730,,,0.5907862186431885,2.14593243598938,27.758313068089063,0.6103581786155701,1.9752626419067385,24.93721700930189,3000.0,0.6183835864067078,1.9371329545974727,23.38670207942713,3003.0,3388.040452480316,6209.612332820892,3388.040452480316,2821.1698200702667,0.098750352859497,0.0 -9800,0.29032275,2.1165617,,,,,,,,,,,,,,,,, -9900,0.29218552,2.0324402,,,,,,,,,,,,,,,,, -10000,0.3413577,2.1364183,,,,,,,,,,,,,,,,, -10100,0.25895151,2.069128,,,,,,,,,,,,,,,,, -10200,0.3582552,2.1394157,,,,,,,,,,,,,,,,, -10300,0.47766477,2.0502892,,,,,,,,,,,,,,,,, -10400,0.31512392,2.0661936,,,,,,,,,,,,,,,,, -10500,0.26095876,2.1502793,,,,,,,,,,,,,,,,, -10600,0.3195559,2.0327868,,,,,,,,,,,,,,,,, -10700,0.32306483,2.1061592,,,,,,,,,,,,,,,,, -10800,0.26811665,2.127785,,,,,,,,,,,,,,,,, -10900,0.43871796,2.1387002,,,,,,,,,,,,,,,,, -11000,0.5082281,2.1888154,,,,,,,,,,,,,,,,, -11100,0.2896982,2.0513775,,,,,,,,,,,,,,,,, -11200,0.39481664,2.1184494,,,,,,,,,,,,,,,,, -11300,0.6521601,2.104844,,,,,,,,,,,,,,,,, -11400,0.7419836,2.0352378,,,,,,,,,,,,,,,,, -11500,0.38486844,2.170659,,,,,,,,,,,,,,,,, -11600,0.41702425,2.1984468,,,,,,,,,,,,,,,,, -11700,0.4356481,2.0567708,,,,,,,,,,,,,,,,, -11800,0.29802635,2.1439333,,,,,,,,,,,,,,,,, -11900,0.29271293,2.0871854,,,,,,,,,,,,,,,,, -12000,0.29775152,2.1633182,,,,,,,,,,,,,,,,, -12100,0.38432366,2.052955,,,,,,,,,,,,,,,,, -12163,,,0.5988882184028625,2.0783145427703857,27.654304678780385,0.6146730780601501,1.9561139345169067,24.84778671859338,3000.0,0.6207425594329834,1.8987210988998413,23.72557107950406,3003.0,4228.165338039398,7585.615716218948,4228.165338039398,3356.9464728832245,0.1251380443572998,0.0 -12200,0.29210848,2.1200705,,,,,,,,,,,,,,,,, -12300,0.4158158,2.1552718,,,,,,,,,,,,,,,,, -12400,0.3222976,2.0807943,,,,,,,,,,,,,,,,, -12500,0.36297786,2.1068635,,,,,,,,,,,,,,,,, -12600,0.44199947,2.0229344,,,,,,,,,,,,,,,,, -12700,0.30779082,2.0049314,,,,,,,,,,,,,,,,, -12800,0.63644457,2.1072056,,,,,,,,,,,,,,,,, -12900,0.5924818,2.104544,,,,,,,,,,,,,,,,, -13000,0.5654439,2.0667644,,,,,,,,,,,,,,,,, -13100,0.32983726,2.096481,,,,,,,,,,,,,,,,, -13200,0.26529786,2.1111417,,,,,,,,,,,,,,,,, -13300,0.3650585,2.114522,,,,,,,,,,,,,,,,, -13400,0.6200291,2.140625,,,,,,,,,,,,,,,,, -13500,0.75692075,2.0531042,,,,,,,,,,,,,,,,, -13600,0.5045957,2.0891263,,,,,,,,,,,,,,,,, -13700,0.27763566,2.075704,,,,,,,,,,,,,,,,, -13800,0.5721345,2.1076572,,,,,,,,,,,,,,,,, -13900,0.5803384,2.1742508,,,,,,,,,,,,,,,,, -14000,0.49511445,1.9971682,,,,,,,,,,,,,,,,, -14100,0.45776114,2.0737512,,,,,,,,,,,,,,,,, -14200,0.42284417,2.0425453,,,,,,,,,,,,,,,,, -14300,0.28886467,2.0525856,,,,,,,,,,,,,,,,, -14400,0.28533605,2.108049,,,,,,,,,,,,,,,,, -14500,0.27918288,2.0932858,,,,,,,,,,,,,,,,, -14596,,,0.5957963466644287,2.0913431644439697,27.566099296535413,0.6168429255485535,1.918594479560852,24.633238759292187,3000.0,0.6213933229446411,1.8927571773529053,23.30884255123911,3003.0,5068.237158060074,8989.855357408524,5068.237158060074,3921.0114846229553,0.152226448059082,0.0 -14600,0.28360295,2.187444,,,,,,,,,,,,,,,,, -14700,0.4180869,2.0624967,,,,,,,,,,,,,,,,, -14800,0.48233733,2.1123106,,,,,,,,,,,,,,,,, -14900,0.39468384,2.0343199,,,,,,,,,,,,,,,,, -15000,0.7252099,1.9957473,,,,,,,,,,,,,,,,, -15100,0.27094045,2.1294873,,,,,,,,,,,,,,,,, -15200,0.37289715,2.1242225,,,,,,,,,,,,,,,,, -15300,0.5925225,2.0739748,,,,,,,,,,,,,,,,, -15400,0.28906024,1.9887198,,,,,,,,,,,,,,,,, -15500,0.6240957,2.1248918,,,,,,,,,,,,,,,,, -15600,0.4620897,2.0822718,,,,,,,,,,,,,,,,, -15700,0.38765466,2.0563486,,,,,,,,,,,,,,,,, -15800,0.25567764,1.9914045,,,,,,,,,,,,,,,,, -15900,0.29505253,2.104096,,,,,,,,,,,,,,,,, -16000,0.35007405,2.0646374,,,,,,,,,,,,,,,,, -16100,0.35693622,2.1683981,,,,,,,,,,,,,,,,, -16200,0.6732438,2.050274,,,,,,,,,,,,,,,,, -16300,0.33093068,2.033677,,,,,,,,,,,,,,,,, -16400,0.4718503,1.981276,,,,,,,,,,,,,,,,, -16500,0.30010125,2.0473354,,,,,,,,,,,,,,,,, -16600,0.30572364,2.0589216,,,,,,,,,,,,,,,,, -16700,0.54286647,2.0490987,,,,,,,,,,,,,,,,, -16800,0.35958034,2.0746915,,,,,,,,,,,,,,,,, -16900,0.36889037,2.0614607,,,,,,,,,,,,,,,,, -17000,0.23680213,2.0631094,,,,,,,,,,,,,,,,, -17029,,,0.5992476940155029,2.066373348236084,27.728670866665368,0.6194343566894531,1.9151461124420168,25.00878222369035,3000.0,0.628063440322876,1.8578368425369265,24.449780344884086,3003.0,5908.252819776535,10318.040254831314,5908.252819776535,4409.076854467392,0.1800706386566162,0.0 -17100,0.4216575,2.0987635,,,,,,,,,,,,,,,,, -17200,0.23709741,2.0224388,,,,,,,,,,,,,,,,, -17300,0.3972676,2.1121264,,,,,,,,,,,,,,,,, -17400,0.26594803,2.032599,,,,,,,,,,,,,,,,, -17500,0.50670147,2.0694416,,,,,,,,,,,,,,,,, -17600,0.5753195,2.1311512,,,,,,,,,,,,,,,,, -17700,0.24691916,2.046579,,,,,,,,,,,,,,,,, -17800,0.4622686,2.0466256,,,,,,,,,,,,,,,,, -17900,0.47653243,2.0807433,,,,,,,,,,,,,,,,, -18000,0.31834865,2.066474,,,,,,,,,,,,,,,,, -18100,0.32406753,2.1133301,,,,,,,,,,,,,,,,, -18200,0.34168333,2.0698097,,,,,,,,,,,,,,,,, -18300,0.6773526,2.1237726,,,,,,,,,,,,,,,,, -18400,0.42740446,2.093223,,,,,,,,,,,,,,,,, -18500,0.4797156,2.1038868,,,,,,,,,,,,,,,,, -18600,0.50714535,2.043917,,,,,,,,,,,,,,,,, -18700,0.28254282,2.014256,,,,,,,,,,,,,,,,, -18800,0.26617533,2.1395051,,,,,,,,,,,,,,,,, -18900,0.51068014,1.9557561,,,,,,,,,,,,,,,,, -19000,0.28749523,2.023874,,,,,,,,,,,,,,,,, -19100,0.3469994,2.0747838,,,,,,,,,,,,,,,,, -19200,0.270283,2.041724,,,,,,,,,,,,,,,,, -19300,0.3511262,2.0357966,,,,,,,,,,,,,,,,, -19400,0.2604584,2.076518,,,,,,,,,,,,,,,,, -19461,,,0.6093730926513672,1.9898121356964111,28.350201916222264,0.6198559403419495,1.9141809940338133,25.218665959343333,3000.0,0.6275056600570679,1.8592458963394165,24.24569041878532,3003.0,6748.244628667831,11677.037281274796,6748.244628667831,4927.977090597153,0.2080614566802978,0.0 -19500,0.32573077,2.0546105,,,,,,,,,,,,,,,,, -19600,0.46835575,2.0201955,,,,,,,,,,,,,,,,, -19700,0.86288774,2.0956147,,,,,,,,,,,,,,,,, -19800,0.2833863,2.1618745,,,,,,,,,,,,,,,,, -19900,0.55322105,2.0625212,,,,,,,,,,,,,,,,, -20000,0.3649214,2.0717623,,,,,,,,,,,,,,,,, -20100,0.5463958,2.0160134,,,,,,,,,,,,,,,,, -20200,0.3224262,2.0329013,,,,,,,,,,,,,,,,, -20300,0.59369636,2.0630436,,,,,,,,,,,,,,,,, -20400,0.29054093,2.104654,,,,,,,,,,,,,,,,, -20500,0.68925023,2.069234,,,,,,,,,,,,,,,,, -20600,0.33756045,2.0904799,,,,,,,,,,,,,,,,, -20700,0.5359761,2.0500762,,,,,,,,,,,,,,,,, -20800,0.26964092,2.021354,,,,,,,,,,,,,,,,, -20900,0.3262209,1.9837795,,,,,,,,,,,,,,,,, -21000,0.2701926,2.007582,,,,,,,,,,,,,,,,, -21100,0.26065993,2.0016694,,,,,,,,,,,,,,,,, -21200,0.59233105,2.0579774,,,,,,,,,,,,,,,,, -21300,0.48516124,2.0550933,,,,,,,,,,,,,,,,, -21400,0.30480543,2.0489514,,,,,,,,,,,,,,,,, -21500,0.35240304,2.0673494,,,,,,,,,,,,,,,,, -21600,0.5887174,2.040662,,,,,,,,,,,,,,,,, -21700,0.51803184,2.0152662,,,,,,,,,,,,,,,,, -21800,0.5455785,2.0518987,,,,,,,,,,,,,,,,, -21893,,,0.6032451391220093,2.035078763961792,27.34081747482811,0.6205874681472778,1.8935205936431885,25.087785595498843,3000.0,0.6247748732566833,1.851194977760315,23.55798125524869,3003.0,7588.188913345337,13184.40267777443,7588.188913345337,5595.293299913406,0.2365128993988037,0.0 -21900,0.5864451,2.13199,,,,,,,,,,,,,,,,, -22000,0.27537045,2.0955627,,,,,,,,,,,,,,,,, -22100,0.2649148,2.0901253,,,,,,,,,,,,,,,,, -22200,0.35189068,2.1407008,,,,,,,,,,,,,,,,, -22300,0.40448877,1.955123,,,,,,,,,,,,,,,,, -22400,0.28402516,2.1204288,,,,,,,,,,,,,,,,, -22500,0.5501675,2.0941715,,,,,,,,,,,,,,,,, -22600,0.56568867,2.0371356,,,,,,,,,,,,,,,,, -22700,0.47809058,2.0900445,,,,,,,,,,,,,,,,, -22800,0.5971707,2.0382032,,,,,,,,,,,,,,,,, -22900,0.36215985,2.0567913,,,,,,,,,,,,,,,,, -23000,0.27690902,2.1340208,,,,,,,,,,,,,,,,, -23100,0.40070358,2.0623071,,,,,,,,,,,,,,,,, -23200,0.36106417,2.0113268,,,,,,,,,,,,,,,,, -23300,0.3674095,2.052259,,,,,,,,,,,,,,,,, -23400,0.40174347,2.0461476,,,,,,,,,,,,,,,,, -23500,0.33783093,2.0081089,,,,,,,,,,,,,,,,, -23600,0.29988873,2.0301619,,,,,,,,,,,,,,,,, -23700,0.25516433,1.969561,,,,,,,,,,,,,,,,, -23800,0.5720809,2.0904179,,,,,,,,,,,,,,,,, -23900,0.57483184,2.156872,,,,,,,,,,,,,,,,, -24000,0.58964133,2.0125868,,,,,,,,,,,,,,,,, -24100,0.4198149,2.0771859,,,,,,,,,,,,,,,,, -24200,0.5128368,2.0946307,,,,,,,,,,,,,,,,, -24300,0.3231646,2.0676131,,,,,,,,,,,,,,,,, -24326,,,0.6028209328651428,2.0420238971710205,28.100459720929965,0.6226333379745483,1.877811908721924,24.9253727665481,3000.0,0.6287490725517273,1.8355058431625368,24.262781323959945,3003.0,8428.134506702423,14566.971681118011,8428.134506702423,6137.813019990921,0.2643194198608398,0.0 -24400,0.27840558,2.0588787,,,,,,,,,,,,,,,,, -24500,0.29990974,2.0433052,,,,,,,,,,,,,,,,, -24600,0.3308278,2.0635417,,,,,,,,,,,,,,,,, -24700,0.3621463,2.0472653,,,,,,,,,,,,,,,,, -24800,0.3123711,2.071343,,,,,,,,,,,,,,,,, -24900,0.41180786,2.0193086,,,,,,,,,,,,,,,,, -25000,0.31322303,1.9980389,,,,,,,,,,,,,,,,, -25100,0.278029,2.0265577,,,,,,,,,,,,,,,,, -25200,0.348767,2.04641,,,,,,,,,,,,,,,,, -25300,0.44922945,2.1518924,,,,,,,,,,,,,,,,, -25400,0.25226653,2.0265484,,,,,,,,,,,,,,,,, -25500,0.26660788,1.9293737,,,,,,,,,,,,,,,,, -25600,0.55386704,1.9654216,,,,,,,,,,,,,,,,, -25700,0.34681588,2.0739605,,,,,,,,,,,,,,,,, -25800,0.28547716,2.1853383,,,,,,,,,,,,,,,,, -25900,0.34045973,1.9479346,,,,,,,,,,,,,,,,, -26000,0.56292427,1.9687961,,,,,,,,,,,,,,,,, -26100,0.49461508,1.9938898,,,,,,,,,,,,,,,,, -26200,0.39069733,2.081949,,,,,,,,,,,,,,,,, -26300,0.44185388,2.007391,,,,,,,,,,,,,,,,, -26400,0.30939358,2.079219,,,,,,,,,,,,,,,,, -26500,0.34621918,2.0373948,,,,,,,,,,,,,,,,, -26600,0.43922418,2.092666,,,,,,,,,,,,,,,,, -26700,0.38519,2.037184,,,,,,,,,,,,,,,,, -26759,,,0.6044977307319641,2.021979808807373,28.284201263564533,0.624183177947998,1.8740408420562744,25.36061576960488,3000.0,0.6294578909873962,1.8357408046722408,24.53579871136742,3003.0,9268.36467552185,15965.183842658997,9268.36467552185,6695.690530538559,0.2926228046417236,0.0 -26800,0.36845386,2.0448542,,,,,,,,,,,,,,,,, -26900,0.4263238,2.1061463,,,,,,,,,,,,,,,,, -27000,0.510069,1.9709777,,,,,,,,,,,,,,,,, -27100,0.4491857,2.008099,,,,,,,,,,,,,,,,, -27200,0.3084763,1.9899763,,,,,,,,,,,,,,,,, -27300,0.36925274,2.0610836,,,,,,,,,,,,,,,,, -27400,0.2671453,1.9304088,,,,,,,,,,,,,,,,, -27500,0.46608752,2.0761719,,,,,,,,,,,,,,,,, -27600,0.4126139,2.0309563,,,,,,,,,,,,,,,,, -27700,0.42072147,2.0262508,,,,,,,,,,,,,,,,, -27800,0.45209235,2.07256,,,,,,,,,,,,,,,,, -27900,0.28570384,2.0129058,,,,,,,,,,,,,,,,, -28000,0.4297959,2.080699,,,,,,,,,,,,,,,,, -28100,0.27280402,2.0186837,,,,,,,,,,,,,,,,, -28200,0.32275477,2.0254407,,,,,,,,,,,,,,,,, -28300,0.67776924,2.0086117,,,,,,,,,,,,,,,,, -28400,0.31586665,2.071786,,,,,,,,,,,,,,,,, -28500,0.33803734,2.0836275,,,,,,,,,,,,,,,,, -28600,0.28189483,2.0674825,,,,,,,,,,,,,,,,, -28700,0.25719842,2.0655115,,,,,,,,,,,,,,,,, -28800,0.34035993,2.0643075,,,,,,,,,,,,,,,,, -28900,0.5749557,2.0495427,,,,,,,,,,,,,,,,, -29000,0.40419313,2.1848311,,,,,,,,,,,,,,,,, -29100,0.4237206,2.0119126,,,,,,,,,,,,,,,,, -29192,,,0.6068159937858582,2.0194811820983887,28.95645104575732,0.623960018157959,1.871357560157776,25.25651891429717,3000.0,0.6307710409164429,1.8219873905181885,24.321099015768755,3003.0,10108.279440879822,17297.96948671341,10108.279440879822,7188.456845998764,0.3211045265197754,0.0 -29200,0.5422027,2.0008738,,,,,,,,,,,,,,,,, -29300,0.28768623,2.0430257,,,,,,,,,,,,,,,,, -29400,0.2600243,2.0263522,,,,,,,,,,,,,,,,, -29500,0.42114422,1.988309,,,,,,,,,,,,,,,,, -29600,0.3604462,1.9889287,,,,,,,,,,,,,,,,, -29700,0.38232526,2.0834208,,,,,,,,,,,,,,,,, -29800,0.5503672,2.0338223,,,,,,,,,,,,,,,,, -29900,0.32156363,1.9704198,,,,,,,,,,,,,,,,, -30000,0.30834052,1.9960543,,,,,,,,,,,,,,,,, -30100,0.44696116,1.9612675,,,,,,,,,,,,,,,,, -30200,0.34322962,2.0981424,,,,,,,,,,,,,,,,, -30300,0.35393298,1.987926,,,,,,,,,,,,,,,,, -30400,0.39916044,1.9766527,,,,,,,,,,,,,,,,, -30500,0.4079175,2.0727375,,,,,,,,,,,,,,,,, -30600,0.23671196,2.0081558,,,,,,,,,,,,,,,,, -30700,0.35717392,1.9992114,,,,,,,,,,,,,,,,, -30800,0.29175037,2.036232,,,,,,,,,,,,,,,,, -30900,0.2682685,1.8919683,,,,,,,,,,,,,,,,, -31000,0.26516736,1.9835801,,,,,,,,,,,,,,,,, -31100,0.24393104,1.9616615,,,,,,,,,,,,,,,,, -31200,0.27673867,1.9658198,,,,,,,,,,,,,,,,, -31300,0.46966067,2.0267122,,,,,,,,,,,,,,,,, -31400,0.28748164,1.9265788,,,,,,,,,,,,,,,,, -31500,0.37690678,2.0025265,,,,,,,,,,,,,,,,, -31600,0.30745158,2.0586405,,,,,,,,,,,,,,,,, -31625,,,0.6236925721168518,1.8840527534484863,30.04438775326001,0.6268489956855774,1.861114740371704,25.799430214237688,3000.0,0.6341409683227539,1.80830192565918,25.021743492191217,3003.0,10948.29575920105,18655.023404359818,10948.29575920105,7705.389865159988,0.3495962619781494,0.0 -31700,0.40744233,2.0067563,,,,,,,,,,,,,,,,, -31800,0.21735497,2.0427608,,,,,,,,,,,,,,,,, -31900,0.35405752,2.059201,,,,,,,,,,,,,,,,, -32000,0.38310975,2.0202322,,,,,,,,,,,,,,,,, -32100,0.3752584,2.0858066,,,,,,,,,,,,,,,,, -32200,0.3278618,1.9854273,,,,,,,,,,,,,,,,, -32300,0.2770057,1.9452902,,,,,,,,,,,,,,,,, -32400,0.242888,1.9922563,,,,,,,,,,,,,,,,, -32500,0.30374554,1.9668568,,,,,,,,,,,,,,,,, -32600,0.2806351,1.9246129,,,,,,,,,,,,,,,,, -32700,0.43272918,2.0721862,,,,,,,,,,,,,,,,, -32800,0.41753218,2.074716,,,,,,,,,,,,,,,,, -32900,0.3311292,2.0518272,,,,,,,,,,,,,,,,, -33000,0.30171576,2.0130653,,,,,,,,,,,,,,,,, -33100,0.29536432,2.0205617,,,,,,,,,,,,,,,,, -33200,0.2597846,2.041332,,,,,,,,,,,,,,,,, -33300,0.26788223,2.0224075,,,,,,,,,,,,,,,,, -33400,0.295085,2.0917406,,,,,,,,,,,,,,,,, -33500,0.3009411,2.0447874,,,,,,,,,,,,,,,,, -33600,0.28210422,2.0076587,,,,,,,,,,,,,,,,, -33700,0.3123886,2.009889,,,,,,,,,,,,,,,,, -33800,0.3105622,2.0440955,,,,,,,,,,,,,,,,, -33900,0.2536754,1.9169247,,,,,,,,,,,,,,,,, -34000,0.27466246,2.0279176,,,,,,,,,,,,,,,,, -34058,,,0.6058012843132019,2.0277810096740723,28.87051163861552,0.6283245086669922,1.8390883207321167,25.78896790070437,3000.0,0.6344314813613892,1.792935132980347,24.71210323790936,3003.0,11788.428247213364,20093.013048648834,11788.428247213364,8303.139657497406,0.3803756237030029,0.0 -34100,0.30763307,2.0430818,,,,,,,,,,,,,,,,, -34200,0.26967114,2.0564408,,,,,,,,,,,,,,,,, -34300,0.44098243,1.9569441,,,,,,,,,,,,,,,,, -34400,0.28921348,1.9451771,,,,,,,,,,,,,,,,, -34500,0.26846972,2.0004048,,,,,,,,,,,,,,,,, -34600,0.29646486,1.9847169,,,,,,,,,,,,,,,,, -34700,0.43251923,2.0231607,,,,,,,,,,,,,,,,, -34800,0.3274333,1.9609311,,,,,,,,,,,,,,,,, -34900,0.3735041,2.0438764,,,,,,,,,,,,,,,,, -35000,0.45133325,1.9598432,,,,,,,,,,,,,,,,, -35100,0.3816488,2.110873,,,,,,,,,,,,,,,,, -35200,0.32142058,1.9790062,,,,,,,,,,,,,,,,, -35300,0.23201184,2.0794578,,,,,,,,,,,,,,,,, -35400,0.36528832,2.0145314,,,,,,,,,,,,,,,,, -35500,0.38837793,2.0219352,,,,,,,,,,,,,,,,, -35600,0.39453396,1.9681766,,,,,,,,,,,,,,,,, -35700,0.34117264,2.0229363,,,,,,,,,,,,,,,,, -35800,0.3226111,2.0341742,,,,,,,,,,,,,,,,, -35900,0.37667432,1.9681109,,,,,,,,,,,,,,,,, -36000,0.35351592,2.0499759,,,,,,,,,,,,,,,,, -36100,0.2501763,2.0775867,,,,,,,,,,,,,,,,, -36200,0.3365315,1.9996212,,,,,,,,,,,,,,,,, -36300,0.26675436,1.9856205,,,,,,,,,,,,,,,,, -36400,0.3876257,2.080804,,,,,,,,,,,,,,,,, -36491,,,0.608106255531311,2.017210006713867,28.904658555369576,0.6289692521095276,1.84528398513794,26.038305212224035,3000.0,0.6373133659362793,1.7856651544570925,25.120160265789195,3003.0,12628.50937962532,21464.642553329468,12628.50937962532,8834.580487966537,0.4115097522735595,0.0 -36500,0.49584898,1.9451748,,,,,,,,,,,,,,,,, -36600,0.3139603,2.0446084,,,,,,,,,,,,,,,,, -36700,0.2511361,1.9567181,,,,,,,,,,,,,,,,, -36800,0.29693058,2.0412242,,,,,,,,,,,,,,,,, -36900,0.26342195,2.0402625,,,,,,,,,,,,,,,,, -37000,0.34995663,1.9778578,,,,,,,,,,,,,,,,, -37100,0.42766243,1.9843874,,,,,,,,,,,,,,,,, -37200,0.48586133,2.0592697,,,,,,,,,,,,,,,,, -37300,0.43262932,1.9921403,,,,,,,,,,,,,,,,, -37400,0.31560004,2.007811,,,,,,,,,,,,,,,,, -37500,0.41460666,2.0289986,,,,,,,,,,,,,,,,, -37600,0.67456746,2.0100124,,,,,,,,,,,,,,,,, -37700,0.43060115,2.0276384,,,,,,,,,,,,,,,,, -37800,0.26301068,1.940678,,,,,,,,,,,,,,,,, -37900,0.3443465,2.077701,,,,,,,,,,,,,,,,, -38000,0.3211633,1.97726,,,,,,,,,,,,,,,,, -38100,0.2909729,1.9413832,,,,,,,,,,,,,,,,, -38200,0.3218336,2.007082,,,,,,,,,,,,,,,,, -38300,0.29031458,1.9908488,,,,,,,,,,,,,,,,, -38400,0.4143678,1.9564542,,,,,,,,,,,,,,,,, -38500,0.39459476,1.9923979,,,,,,,,,,,,,,,,, -38600,0.4209067,2.0471976,,,,,,,,,,,,,,,,, -38700,0.46470705,1.9579632,,,,,,,,,,,,,,,,, -38800,0.35030752,1.9741099,,,,,,,,,,,,,,,,, -38900,0.36157167,2.0539849,,,,,,,,,,,,,,,,, -38923,,,0.6152269244194031,1.9483067989349363,28.735497170260977,0.6286717057228088,1.8417868614196773,25.679533221535827,3000.0,0.6345128417015076,1.7959977388381958,24.61682099042204,3003.0,13468.485274791718,22799.8377776146,13468.485274791718,9329.693039894104,0.4427659511566162,0.0 -39000,0.30601138,1.9049592,,,,,,,,,,,,,,,,, -39100,0.27784202,1.9981236,,,,,,,,,,,,,,,,, -39200,0.3421887,2.0494542,,,,,,,,,,,,,,,,, -39300,0.39543813,2.1254246,,,,,,,,,,,,,,,,, -39400,0.40699738,1.9893453,,,,,,,,,,,,,,,,, -39500,0.2796988,1.9512717,,,,,,,,,,,,,,,,, -39600,0.39058587,1.9746636,,,,,,,,,,,,,,,,, -39700,0.32588798,1.9950839,,,,,,,,,,,,,,,,, -39800,0.4081024,1.994474,,,,,,,,,,,,,,,,, -39900,0.32101154,1.9198066,,,,,,,,,,,,,,,,, -40000,0.34728435,1.9552294,,,,,,,,,,,,,,,,, -40100,0.2848559,1.9223597,,,,,,,,,,,,,,,,, -40200,0.32334775,1.9668134,,,,,,,,,,,,,,,,, -40300,0.31405252,2.0436776,,,,,,,,,,,,,,,,, -40400,0.32132283,2.0565135,,,,,,,,,,,,,,,,, -40500,0.44409344,1.9819052,,,,,,,,,,,,,,,,, -40600,0.36886725,1.9268403,,,,,,,,,,,,,,,,, -40700,0.33946246,1.9867367,,,,,,,,,,,,,,,,, -40800,0.32561162,2.0275211,,,,,,,,,,,,,,,,, -40900,0.26831707,2.058202,,,,,,,,,,,,,,,,, -41000,0.38578126,1.9173377,,,,,,,,,,,,,,,,, -41100,0.5456178,2.0534356,,,,,,,,,,,,,,,,, -41200,0.26164815,1.8958356,,,,,,,,,,,,,,,,, -41300,0.28705832,1.9862996,,,,,,,,,,,,,,,,, -41356,,,0.6089126467704773,1.994400978088379,28.389003946961708,0.6299487948417664,1.834946632385254,25.0824188802018,3000.0,0.6367555856704712,1.777856707572937,24.162457512803574,3003.0,14308.5158598423,24086.092776298523,14308.5158598423,9775.80866074562,0.4757249355316162,0.0 -41400,0.29480717,2.0028346,,,,,,,,,,,,,,,,, -41500,0.3030179,1.8349773,,,,,,,,,,,,,,,,, -41600,0.36209744,2.0372455,,,,,,,,,,,,,,,,, -41700,0.30413908,1.9259403,,,,,,,,,,,,,,,,, -41800,0.44570762,2.0598874,,,,,,,,,,,,,,,,, -41900,0.3095338,2.084915,,,,,,,,,,,,,,,,, -42000,0.24042505,1.9614935,,,,,,,,,,,,,,,,, -42100,0.4015272,2.1083212,,,,,,,,,,,,,,,,, -42200,0.42189172,2.0035868,,,,,,,,,,,,,,,,, -42300,0.33425283,2.0665426,,,,,,,,,,,,,,,,, -42400,0.2771999,1.9459003,,,,,,,,,,,,,,,,, -42500,0.4358299,1.9694164,,,,,,,,,,,,,,,,, -42600,0.3160565,2.0066626,,,,,,,,,,,,,,,,, -42700,0.311073,1.9298152,,,,,,,,,,,,,,,,, -42800,0.35597378,2.0188892,,,,,,,,,,,,,,,,, -42900,0.39537084,1.9962565,,,,,,,,,,,,,,,,, -43000,0.29179987,2.0161986,,,,,,,,,,,,,,,,, -43100,0.33870643,1.9839267,,,,,,,,,,,,,,,,, -43200,0.29695106,1.9531987,,,,,,,,,,,,,,,,, -43300,0.3174024,1.9598613,,,,,,,,,,,,,,,,, -43400,0.26443496,1.9456159,,,,,,,,,,,,,,,,, -43500,0.3511946,1.9407527,,,,,,,,,,,,,,,,, -43600,0.29627547,1.9754491,,,,,,,,,,,,,,,,, -43700,0.39535415,1.9547632,,,,,,,,,,,,,,,,, -43789,,,0.6921001672744751,1.4866576194763184,35.19845928097932,0.6321682333946228,1.8150848150253296,25.890601540900896,3000.0,0.6381616592407227,1.7659757137298584,25.0090002959881,3003.0,15148.529332399368,25681.5877430439,15148.529332399368,10531.183442354202,0.5067262649536133,0.0 -43800,0.29593998,2.0479364,,,,,,,,,,,,,,,,, -43900,0.4192265,1.9287,,,,,,,,,,,,,,,,, -44000,0.2873365,1.959114,,,,,,,,,,,,,,,,, -44100,0.3226774,2.040354,,,,,,,,,,,,,,,,, -44200,0.26110747,1.8704722,,,,,,,,,,,,,,,,, -44300,0.2569834,2.0302665,,,,,,,,,,,,,,,,, -44400,0.28854644,1.944231,,,,,,,,,,,,,,,,, -44500,0.37239358,2.0014951,,,,,,,,,,,,,,,,, -44600,0.29741448,1.9776651,,,,,,,,,,,,,,,,, -44700,0.27681652,2.0196304,,,,,,,,,,,,,,,,, -44800,0.3462927,1.9425557,,,,,,,,,,,,,,,,, -44900,0.31668422,1.9159958,,,,,,,,,,,,,,,,, -45000,0.28998905,1.9435399,,,,,,,,,,,,,,,,, -45100,0.2655835,1.9929382,,,,,,,,,,,,,,,,, -45200,0.22835691,1.967032,,,,,,,,,,,,,,,,, -45300,0.37240013,1.9450166,,,,,,,,,,,,,,,,, -45400,0.32956898,2.001425,,,,,,,,,,,,,,,,, -45500,0.2620174,1.9631351,,,,,,,,,,,,,,,,, -45600,0.37274653,1.9493514,,,,,,,,,,,,,,,,, -45700,0.26167887,2.0165794,,,,,,,,,,,,,,,,, -45800,0.2652392,1.9700841,,,,,,,,,,,,,,,,, -45900,0.56340617,2.0126436,,,,,,,,,,,,,,,,, -46000,0.3495801,1.9668255,,,,,,,,,,,,,,,,, -46100,0.4569341,2.0298455,,,,,,,,,,,,,,,,, -46200,0.29712352,1.9840009,,,,,,,,,,,,,,,,, -46222,,,0.613900899887085,1.9534741640090945,28.75841070693544,0.6313250660896301,1.8070567846298216,25.706606452819305,3000.0,0.6401255130767822,1.7569904327392578,24.83304160183564,3003.0,15988.571270942688,27042.712433576584,15988.571270942688,11052.160829782486,0.5364077091217041,0.0 -46300,0.2260682,1.9006104,,,,,,,,,,,,,,,,, -46400,0.32396165,1.8898861,,,,,,,,,,,,,,,,, -46500,0.49668986,2.1106868,,,,,,,,,,,,,,,,, -46600,0.280671,1.9867129,,,,,,,,,,,,,,,,, -46700,0.33032468,1.9989734,,,,,,,,,,,,,,,,, -46800,0.28079978,1.9693568,,,,,,,,,,,,,,,,, -46900,0.30830598,1.9255705,,,,,,,,,,,,,,,,, -47000,0.35835838,1.8908727,,,,,,,,,,,,,,,,, -47100,0.29385862,1.9768896,,,,,,,,,,,,,,,,, -47200,0.3450413,2.0086493,,,,,,,,,,,,,,,,, -47300,0.2961706,1.8970234,,,,,,,,,,,,,,,,, -47400,0.280407,2.023044,,,,,,,,,,,,,,,,, -47500,0.25940487,1.962538,,,,,,,,,,,,,,,,, -47600,0.2993465,1.8932092,,,,,,,,,,,,,,,,, -47700,0.38355917,1.9334307,,,,,,,,,,,,,,,,, -47800,0.29377058,1.9759086,,,,,,,,,,,,,,,,, -47900,0.38435516,1.9758865,,,,,,,,,,,,,,,,, -48000,0.28464413,1.9579256,,,,,,,,,,,,,,,,, -48100,0.35555625,1.9831194,,,,,,,,,,,,,,,,, -48200,0.2637605,2.0316858,,,,,,,,,,,,,,,,, -48300,0.26816154,1.8929383,,,,,,,,,,,,,,,,, -48400,0.43533042,1.9273382,,,,,,,,,,,,,,,,, -48500,0.32032138,1.9809004,,,,,,,,,,,,,,,,, -48600,0.29297435,2.009754,,,,,,,,,,,,,,,,, -48655,,,0.6083666682243347,1.9990887641906736,29.04350547005292,0.6298372149467468,1.8104116916656487,25.873173393998723,3000.0,0.6380454301834106,1.7637287378311155,24.83728357032393,3003.0,16828.72014260292,28389.81461524964,16828.72014260292,11559.00778746605,0.5670428276062012,0.0 -48700,0.27983946,1.9655722,,,,,,,,,,,,,,,,, -48800,0.38276672,1.91678,,,,,,,,,,,,,,,,, -48900,0.26762065,1.936341,,,,,,,,,,,,,,,,, -49000,0.33095506,1.8913791,,,,,,,,,,,,,,,,, -49100,0.2894932,1.9998481,,,,,,,,,,,,,,,,, -49200,0.26358798,1.8486664,,,,,,,,,,,,,,,,, -49300,0.31423974,1.969026,,,,,,,,,,,,,,,,, -49400,0.3172932,1.9415804,,,,,,,,,,,,,,,,, -49500,0.3328148,1.9933568,,,,,,,,,,,,,,,,, -49600,0.35294074,1.9474057,,,,,,,,,,,,,,,,, -49700,0.3684979,1.9833974,,,,,,,,,,,,,,,,, -49800,0.3559149,1.9691167,,,,,,,,,,,,,,,,, -49900,0.30444184,2.0355744,,,,,,,,,,,,,,,,, -50000,0.2906542,1.9311237,,,,,,,,,,,,,,,,, -50100,0.26629186,1.9661156,,,,,,,,,,,,,,,,, -50200,0.3251893,1.898395,,,,,,,,,,,,,,,,, -50300,0.2977787,1.9028577,,,,,,,,,,,,,,,,, -50400,0.28326625,1.9818729,,,,,,,,,,,,,,,,, -50500,0.3027916,1.9153312,,,,,,,,,,,,,,,,, -50600,0.28016222,1.8614914,,,,,,,,,,,,,,,,, -50700,0.3790399,1.9984987,,,,,,,,,,,,,,,,, -50800,0.34677798,1.9266149,,,,,,,,,,,,,,,,, -50900,0.3290485,2.0250707,,,,,,,,,,,,,,,,, -51000,0.3004425,1.9109247,,,,,,,,,,,,,,,,, -51088,,,0.618570864200592,1.924091100692749,29.38637507894043,0.6350200176239014,1.7849268913269043,26.589383299087874,3000.0,0.6439370512962341,1.7313753366470337,25.48859377960028,3003.0,17668.943604707718,29739.8209412098,17668.943604707718,12068.680082798004,0.6013104915618896,0.0 -51100,0.3245447,1.9703345,,,,,,,,,,,,,,,,, -51200,0.3254905,2.0295749,,,,,,,,,,,,,,,,, -51300,0.26231354,1.9447399,,,,,,,,,,,,,,,,, -51400,0.2599462,1.9278014,,,,,,,,,,,,,,,,, -51500,0.36048183,1.9171635,,,,,,,,,,,,,,,,, -51600,0.3054271,1.9420686,,,,,,,,,,,,,,,,, -51700,0.33262727,1.9213192,,,,,,,,,,,,,,,,, -51800,0.34843537,2.0484524,,,,,,,,,,,,,,,,, -51900,0.34399384,1.8901702,,,,,,,,,,,,,,,,, -52000,0.373068,1.850424,,,,,,,,,,,,,,,,, -52100,0.34473404,1.9662235,,,,,,,,,,,,,,,,, -52200,0.3069242,1.9714873,,,,,,,,,,,,,,,,, -52300,0.3494641,1.9126132,,,,,,,,,,,,,,,,, -52400,0.39992663,1.9626049,,,,,,,,,,,,,,,,, -52500,0.26728323,1.9238997,,,,,,,,,,,,,,,,, -52600,0.44498858,1.9030569,,,,,,,,,,,,,,,,, -52700,0.4405771,1.93929,,,,,,,,,,,,,,,,, -52800,0.2687194,1.9926351,,,,,,,,,,,,,,,,, -52900,0.44242716,2.0007782,,,,,,,,,,,,,,,,, -53000,0.38547662,1.9604161,,,,,,,,,,,,,,,,, -53100,0.29286677,1.93818,,,,,,,,,,,,,,,,, -53200,0.25257003,1.9116702,,,,,,,,,,,,,,,,, -53300,0.3477801,1.9189577,,,,,,,,,,,,,,,,, -53400,0.33518803,1.918499,,,,,,,,,,,,,,,,, -53500,0.4521445,1.966686,,,,,,,,,,,,,,,,, -53521,,,0.6205177903175354,1.918626427650452,29.460057986444184,0.6394093036651611,1.774772047996521,26.36329864950495,3000.0,0.6456336379051208,1.7239888906478882,25.94893984098375,3003.0,18509.00431752205,31106.393884658813,18509.00431752205,12595.083347082138,0.6347942352294922,0.0 -53600,0.31081733,1.986052,,,,,,,,,,,,,,,,, -53700,0.36200103,2.026863,,,,,,,,,,,,,,,,, -53800,0.9659014,1.931494,,,,,,,,,,,,,,,,, -53900,0.29341283,1.8839396,,,,,,,,,,,,,,,,, -54000,0.26303488,1.9388151,,,,,,,,,,,,,,,,, -54100,0.2777782,1.8645759,,,,,,,,,,,,,,,,, -54200,0.3397997,1.9362698,,,,,,,,,,,,,,,,, -54300,0.28920746,1.9040138,,,,,,,,,,,,,,,,, -54400,0.28210664,1.9021753,,,,,,,,,,,,,,,,, -54500,0.39757758,1.8960927,,,,,,,,,,,,,,,,, -54600,0.29681093,1.977729,,,,,,,,,,,,,,,,, -54700,0.2718311,1.9231614,,,,,,,,,,,,,,,,, -54800,0.4600409,1.9411665,,,,,,,,,,,,,,,,, -54900,0.35423455,1.9856613,,,,,,,,,,,,,,,,, -55000,0.3200749,1.9185541,,,,,,,,,,,,,,,,, -55100,0.34088245,2.0190039,,,,,,,,,,,,,,,,, -55200,0.2862529,1.9407266,,,,,,,,,,,,,,,,, -55300,0.39179295,1.9380363,,,,,,,,,,,,,,,,, -55400,0.35691985,1.9071064,,,,,,,,,,,,,,,,, -55500,0.28676817,1.9054377,,,,,,,,,,,,,,,,, -55600,0.32279155,1.9374646,,,,,,,,,,,,,,,,, -55700,0.39155188,1.9108422,,,,,,,,,,,,,,,,, -55800,0.26877463,1.9864123,,,,,,,,,,,,,,,,, -55900,0.32618323,1.9330834,,,,,,,,,,,,,,,,, -55954,,,0.616606593132019,1.9429373741149905,29.716919854833712,0.6380329728126526,1.7670636177062988,26.87983601191332,3000.0,0.650665283203125,1.6983579397201538,26.0390079997861,3003.0,19348.94182872772,32458.6449136734,19348.94182872772,13107.286288261414,0.6692543029785156,0.0 -56000,0.3050793,1.9610094,,,,,,,,,,,,,,,,, -56100,0.3458819,1.8012096,,,,,,,,,,,,,,,,, -56200,0.4961379,1.9115074,,,,,,,,,,,,,,,,, -56300,0.34543157,1.9976972,,,,,,,,,,,,,,,,, -56400,0.3407566,1.9135247,,,,,,,,,,,,,,,,, -56500,0.34637186,1.9418817,,,,,,,,,,,,,,,,, -56600,0.2675726,1.8385382,,,,,,,,,,,,,,,,, -56700,0.41812894,1.9783719,,,,,,,,,,,,,,,,, -56800,0.28918642,1.9152828,,,,,,,,,,,,,,,,, -56900,0.2926208,1.9266175,,,,,,,,,,,,,,,,, -57000,0.30007944,1.8957431,,,,,,,,,,,,,,,,, -57100,0.3206146,2.0141041,,,,,,,,,,,,,,,,, -57200,0.27589357,1.8516593,,,,,,,,,,,,,,,,, -57300,0.31340602,1.9369906,,,,,,,,,,,,,,,,, -57400,0.36677083,1.9166049,,,,,,,,,,,,,,,,, -57500,0.27486593,1.9725155,,,,,,,,,,,,,,,,, -57600,0.29495722,1.9196959,,,,,,,,,,,,,,,,, -57700,0.27831584,1.8762696,,,,,,,,,,,,,,,,, -57800,0.29444018,1.8709135,,,,,,,,,,,,,,,,, -57900,0.29665574,1.9106199,,,,,,,,,,,,,,,,, -58000,0.3449548,1.8596268,,,,,,,,,,,,,,,,, -58100,0.28406462,1.925123,,,,,,,,,,,,,,,,, -58200,0.32587203,1.8782245,,,,,,,,,,,,,,,,, -58300,0.3891398,1.8938081,,,,,,,,,,,,,,,,, -58387,,,0.6255083680152893,1.883889317512512,29.25000924980348,0.63966965675354,1.7545181512832642,26.65371403234823,3000.0,0.6511068940162659,1.6911041736602783,25.839829990207257,3003.0,20189.042788743973,33925.682745695114,20189.042788743973,13734.115124940872,0.7012717723846436,0.0 -58400,0.4703333,1.8971019,,,,,,,,,,,,,,,,, -58500,0.3514452,1.9558847,,,,,,,,,,,,,,,,, -58600,0.33639723,1.9284178,,,,,,,,,,,,,,,,, -58700,0.36180237,1.9576356,,,,,,,,,,,,,,,,, -58800,0.3563623,1.9576179,,,,,,,,,,,,,,,,, -58900,0.43542063,1.894289,,,,,,,,,,,,,,,,, -59000,0.33516917,1.8496081,,,,,,,,,,,,,,,,, -59100,0.33087757,1.946625,,,,,,,,,,,,,,,,, -59200,0.31854096,1.9178324,,,,,,,,,,,,,,,,, -59300,0.33765596,2.0301192,,,,,,,,,,,,,,,,, -59400,0.259614,1.91048,,,,,,,,,,,,,,,,, -59500,0.33583608,1.9135494,,,,,,,,,,,,,,,,, -59600,0.27450812,1.8360265,,,,,,,,,,,,,,,,, -59700,0.32959306,1.9421535,,,,,,,,,,,,,,,,, -59800,0.34994832,1.8794081,,,,,,,,,,,,,,,,, -59900,0.3247838,1.8548161,,,,,,,,,,,,,,,,, -60000,0.31900015,1.8878932,,,,,,,,,,,,,,,,, -60100,0.30750725,1.9314445,,,,,,,,,,,,,,,,, -60200,0.26099497,1.885621,,,,,,,,,,,,,,,,, -60300,0.24581863,1.9436853,,,,,,,,,,,,,,,,, -60400,0.33967105,1.9602468,,,,,,,,,,,,,,,,, -60500,0.29696897,1.91758,,,,,,,,,,,,,,,,, -60600,0.26825404,1.9218396,,,,,,,,,,,,,,,,, -60700,0.3704283,1.8513213,,,,,,,,,,,,,,,,, -60800,0.29692975,1.9922178,,,,,,,,,,,,,,,,, -60820,,,0.6229991912841797,1.90169644355774,29.73174634459232,0.6420503258705139,1.7479721307754517,26.703019452981483,3000.0,0.6530939936637878,1.672269582748413,25.896172575343115,3003.0,21029.19118499756,35228.80893397331,21029.19118499756,14196.984060525894,0.7340254783630371,0.0 -60900,0.26653466,1.9376667,,,,,,,,,,,,,,,,, -61000,0.4244732,1.9324732,,,,,,,,,,,,,,,,, -61100,0.27918342,1.9017063,,,,,,,,,,,,,,,,, -61200,0.2788848,1.8926857,,,,,,,,,,,,,,,,, -61300,0.2579258,1.9275838,,,,,,,,,,,,,,,,, -61400,0.26734284,1.8368835,,,,,,,,,,,,,,,,, -61500,0.37292275,1.8421508,,,,,,,,,,,,,,,,, -61600,0.3448948,1.9101771,,,,,,,,,,,,,,,,, -61700,0.37970594,1.8774875,,,,,,,,,,,,,,,,, -61800,0.40070975,1.9509317,,,,,,,,,,,,,,,,, -61900,0.2675145,1.9481349,,,,,,,,,,,,,,,,, -62000,0.2750508,1.9376637,,,,,,,,,,,,,,,,, -62100,0.32696545,1.8533778,,,,,,,,,,,,,,,,, -62200,0.34022024,1.9437741,,,,,,,,,,,,,,,,, -62300,0.2862433,1.9290392,,,,,,,,,,,,,,,,, -62400,0.34152874,1.8740127,,,,,,,,,,,,,,,,, -62500,0.2465561,1.8698292,,,,,,,,,,,,,,,,, -62600,0.2978365,1.8475732,,,,,,,,,,,,,,,,, -62700,0.3827471,1.9059105,,,,,,,,,,,,,,,,, -62800,0.2801118,1.9832578,,,,,,,,,,,,,,,,, -62900,0.32400775,1.9294941,,,,,,,,,,,,,,,,, -63000,0.3362954,1.9331924,,,,,,,,,,,,,,,,, -63100,0.3279858,1.878188,,,,,,,,,,,,,,,,, -63200,0.3306921,1.8477936,,,,,,,,,,,,,,,,, -63253,,,0.6290479302406311,1.827910304069519,30.297907849320147,0.6419262886047363,1.7374597787857056,26.966711675602447,3000.0,0.6543838381767273,1.6732841730117798,26.206675737115404,3003.0,21869.32551908493,36555.478934049606,21869.32551908493,14683.410673379898,0.7670059204101562,0.0 -63300,0.27623072,1.8432947,,,,,,,,,,,,,,,,, -63400,0.35089046,1.9358176,,,,,,,,,,,,,,,,, -63500,0.3091313,1.9353415,,,,,,,,,,,,,,,,, -63600,0.2913046,1.9291673,,,,,,,,,,,,,,,,, -63700,0.29747415,1.841015,,,,,,,,,,,,,,,,, -63800,0.36917984,1.8977901,,,,,,,,,,,,,,,,, -63900,0.30753815,1.8898705,,,,,,,,,,,,,,,,, -64000,0.26375377,1.929587,,,,,,,,,,,,,,,,, -64100,0.29318023,1.876884,,,,,,,,,,,,,,,,, -64200,0.29015046,1.9327847,,,,,,,,,,,,,,,,, -64300,0.32602593,1.9182693,,,,,,,,,,,,,,,,, -64400,0.3101028,1.9330136,,,,,,,,,,,,,,,,, -64500,0.2649889,1.9447796,,,,,,,,,,,,,,,,, -64600,0.26889858,1.9217052,,,,,,,,,,,,,,,,, -64700,0.36641943,1.9467462,,,,,,,,,,,,,,,,, -64800,0.31802914,1.8908579,,,,,,,,,,,,,,,,, -64900,0.31187,1.8676538,,,,,,,,,,,,,,,,, -65000,0.30904153,1.8769708,,,,,,,,,,,,,,,,, -65100,0.26497185,1.8572055,,,,,,,,,,,,,,,,, -65200,0.3007043,1.9076489,,,,,,,,,,,,,,,,, -65300,0.30337697,1.9062606,,,,,,,,,,,,,,,,, -65400,0.34358096,1.9349788,,,,,,,,,,,,,,,,, -65500,0.2913941,1.9057941,,,,,,,,,,,,,,,,, -65600,0.29441872,1.9052907,,,,,,,,,,,,,,,,, -65685,,,0.6229196786880493,1.90165114402771,29.82231045733329,0.6441333293914795,1.7247092723846436,27.3569961752832,3000.0,0.6523967385292053,1.6633965969085691,26.01282070100662,3003.0,22709.30819892884,37968.656623363495,22709.30819892884,15256.493627786636,0.8021972179412842,0.0 -65700,0.26408595,1.8340571,,,,,,,,,,,,,,,,, -65800,0.32037297,1.9263066,,,,,,,,,,,,,,,,, -65900,0.2756362,1.8773865,,,,,,,,,,,,,,,,, -66000,0.3592289,1.9025444,,,,,,,,,,,,,,,,, -66100,0.32980397,1.8322269,,,,,,,,,,,,,,,,, -66200,0.32112578,1.8429103,,,,,,,,,,,,,,,,, -66300,0.30182338,1.8057199,,,,,,,,,,,,,,,,, -66400,0.29820052,1.7817928,,,,,,,,,,,,,,,,, -66500,0.3229589,1.8683953,,,,,,,,,,,,,,,,, -66600,0.37006935,1.9779963,,,,,,,,,,,,,,,,, -66700,0.31246084,1.8536444,,,,,,,,,,,,,,,,, -66800,0.29145017,1.8774745,,,,,,,,,,,,,,,,, -66900,0.2697189,1.8332486,,,,,,,,,,,,,,,,, -67000,0.31069145,1.9327573,,,,,,,,,,,,,,,,, -67100,0.31160387,1.9061284,,,,,,,,,,,,,,,,, -67200,0.28949162,1.8506231,,,,,,,,,,,,,,,,, -67300,0.28562462,1.795955,,,,,,,,,,,,,,,,, -67400,0.29373172,1.9225035,,,,,,,,,,,,,,,,, -67500,0.35431862,1.7912989,,,,,,,,,,,,,,,,, -67600,0.27908307,1.7534313,,,,,,,,,,,,,,,,, -67700,0.25905454,1.8664132,,,,,,,,,,,,,,,,, -67800,0.2501393,1.8781843,,,,,,,,,,,,,,,,, -67900,0.288425,1.9321026,,,,,,,,,,,,,,,,, -68000,0.30677292,1.8242071,,,,,,,,,,,,,,,,, -68100,0.33894134,1.7935244,,,,,,,,,,,,,,,,, -68118,,,0.6257489323616028,1.8845466375350952,29.903225654359392,0.6469975709915161,1.7029153108596802,27.369239895807823,3000.0,0.6561850309371948,1.6488022804260254,26.68360440541797,3003.0,23549.419113636017,39327.50707483292,23549.419113636017,15775.118801116943,0.8397126197814941,0.0 -68200,0.27334303,1.9423896,,,,,,,,,,,,,,,,, -68300,0.32241493,1.9417764,,,,,,,,,,,,,,,,, -68400,0.28975943,1.9032456,,,,,,,,,,,,,,,,, -68500,0.27980068,1.8738698,,,,,,,,,,,,,,,,, -68600,0.25625148,1.9063611,,,,,,,,,,,,,,,,, -68700,0.29177237,1.860818,,,,,,,,,,,,,,,,, -68800,0.33799735,1.9554543,,,,,,,,,,,,,,,,, -68900,0.49597612,1.851419,,,,,,,,,,,,,,,,, -69000,0.28053594,1.8716989,,,,,,,,,,,,,,,,, -69100,0.3296509,1.9123497,,,,,,,,,,,,,,,,, -69200,0.28430697,1.8910685,,,,,,,,,,,,,,,,, -69300,0.27105388,1.7876867,,,,,,,,,,,,,,,,, -69400,0.25942212,1.8106346,,,,,,,,,,,,,,,,, -69500,0.28732383,1.8139937,,,,,,,,,,,,,,,,, -69600,0.26436788,1.8604339,,,,,,,,,,,,,,,,, -69700,0.3022626,1.8907948,,,,,,,,,,,,,,,,, -69800,0.3478822,1.7754784,,,,,,,,,,,,,,,,, -69900,0.3405459,1.8384435,,,,,,,,,,,,,,,,, -70000,0.27122897,1.7952217,,,,,,,,,,,,,,,,, -70100,0.29514933,1.8495828,,,,,,,,,,,,,,,,, -70200,0.31158242,1.7874385,,,,,,,,,,,,,,,,, -70300,0.27193043,1.7845254,,,,,,,,,,,,,,,,, -70400,0.28750372,1.8519857,,,,,,,,,,,,,,,,, -70500,0.2941131,1.8637896,,,,,,,,,,,,,,,,, -70552,,,0.6302703022956848,1.8330473899841309,30.171382342838893,0.649502158164978,1.6973282098770142,27.07700891501412,3000.0,0.6587531566619873,1.639639139175415,26.84871294586024,3003.0,24389.49199271202,40648.717832803726,24389.49199271202,16256.143748998642,0.8760302066802979,0.0 -70600,0.2731443,1.9076773,,,,,,,,,,,,,,,,, -70700,0.32245243,1.7924923,,,,,,,,,,,,,,,,, -70800,0.31713486,1.9158189,,,,,,,,,,,,,,,,, -70900,0.30756956,1.8479862,,,,,,,,,,,,,,,,, -71000,0.32297075,1.9317763,,,,,,,,,,,,,,,,, -71100,0.2788839,1.8772082,,,,,,,,,,,,,,,,, -71200,0.26806492,1.849682,,,,,,,,,,,,,,,,, -71300,0.29536244,1.8888688,,,,,,,,,,,,,,,,, -71400,0.28766212,1.8630011,,,,,,,,,,,,,,,,, -71500,0.26249212,1.9033903,,,,,,,,,,,,,,,,, -71600,0.299084,1.8196279,,,,,,,,,,,,,,,,, -71700,0.27243534,1.8910408,,,,,,,,,,,,,,,,, -71800,0.2710586,1.8815904,,,,,,,,,,,,,,,,, -71900,0.29398113,1.8180317,,,,,,,,,,,,,,,,, -72000,0.27193308,1.8555187,,,,,,,,,,,,,,,,, -72100,0.2977999,1.7972474,,,,,,,,,,,,,,,,, -72200,0.3054279,1.8900834,,,,,,,,,,,,,,,,, -72300,0.3033961,1.8280286,,,,,,,,,,,,,,,,, -72400,0.30037585,1.8111372,,,,,,,,,,,,,,,,, -72500,0.2683455,1.7977263,,,,,,,,,,,,,,,,, -72600,0.30280098,1.7902017,,,,,,,,,,,,,,,,, -72700,0.29825163,1.7400016,,,,,,,,,,,,,,,,, -72800,0.2865537,1.855282,,,,,,,,,,,,,,,,, -72900,0.3702135,1.7451494,,,,,,,,,,,,,,,,, -72985,,,0.6285916566848755,1.8592113256454468,30.52954279286608,0.6510024666786194,1.6859381198883057,27.65884298530821,3000.0,0.6605426669120789,1.6174166202545166,26.85367362042821,3003.0,25229.58637213707,41997.50688147545,25229.58637213707,16764.72614622116,0.911757230758667,0.0 -73000,0.30169606,1.8795817,,,,,,,,,,,,,,,,, -73100,0.28770837,1.7656469,,,,,,,,,,,,,,,,, -73200,0.31606254,1.8164378,,,,,,,,,,,,,,,,, -73300,0.29989144,1.8086402,,,,,,,,,,,,,,,,, -73400,0.30541486,1.8450154,,,,,,,,,,,,,,,,, -73500,0.2758782,1.7793311,,,,,,,,,,,,,,,,, -73600,0.38382024,1.8239216,,,,,,,,,,,,,,,,, -73700,0.2863328,1.843963,,,,,,,,,,,,,,,,, -73800,0.2852299,1.8503764,,,,,,,,,,,,,,,,, -73900,0.292726,1.7918131,,,,,,,,,,,,,,,,, -74000,0.31039593,1.8348088,,,,,,,,,,,,,,,,, -74100,0.3043777,1.7923146,,,,,,,,,,,,,,,,, -74200,0.3525504,1.8776735,,,,,,,,,,,,,,,,, -74300,0.28009167,1.786496,,,,,,,,,,,,,,,,, -74400,0.27169618,1.8358139,,,,,,,,,,,,,,,,, -74500,0.31976986,1.8901219,,,,,,,,,,,,,,,,, -74600,0.30100387,1.8069657,,,,,,,,,,,,,,,,, -74700,0.2765997,1.796088,,,,,,,,,,,,,,,,, -74800,0.24967296,1.7947782,,,,,,,,,,,,,,,,, -74900,0.2803483,1.8133773,,,,,,,,,,,,,,,,, -75000,0.30762511,1.8252838,,,,,,,,,,,,,,,,, -75100,0.28677693,1.8306239,,,,,,,,,,,,,,,,, -75200,0.32867572,1.8777984,,,,,,,,,,,,,,,,, -75300,0.25160748,1.7149646,,,,,,,,,,,,,,,,, -75400,0.26806846,1.766374,,,,,,,,,,,,,,,,, -75418,,,0.6453965306282043,1.716774582862854,31.558854951636462,0.6527011394500732,1.681450605392456,27.563748681111136,3000.0,0.6633432507514954,1.6029932498931885,27.128032352674563,3003.0,26069.70525288582,43430.02218937874,26069.70525288582,17357.010063409805,0.9460341930389404,0.0 -75500,0.26039824,1.7384458,,,,,,,,,,,,,,,,, -75600,0.29114968,1.7951416,,,,,,,,,,,,,,,,, -75700,0.3190677,1.8439325,,,,,,,,,,,,,,,,, -75800,0.2842405,1.901601,,,,,,,,,,,,,,,,, -75900,0.32201818,1.8820561,,,,,,,,,,,,,,,,, -76000,0.25605193,1.8365682,,,,,,,,,,,,,,,,, -76100,0.2987417,1.8090873,,,,,,,,,,,,,,,,, -76200,0.29948473,1.8223827,,,,,,,,,,,,,,,,, -76300,0.2891412,1.8168094,,,,,,,,,,,,,,,,, -76400,0.28750324,1.8365886,,,,,,,,,,,,,,,,, -76500,0.29565364,1.8294665,,,,,,,,,,,,,,,,, -76600,0.31748742,1.7285157,,,,,,,,,,,,,,,,, -76700,0.32304773,1.9176096,,,,,,,,,,,,,,,,, -76800,0.2846633,1.9367799,,,,,,,,,,,,,,,,, -76900,0.31128332,1.7494285,,,,,,,,,,,,,,,,, -77000,0.2620761,1.7345495,,,,,,,,,,,,,,,,, -77100,0.3288934,1.8406886,,,,,,,,,,,,,,,,, -77200,0.27281055,1.7824472,,,,,,,,,,,,,,,,, -77300,0.60218954,1.8929951,,,,,,,,,,,,,,,,, -77400,0.40631866,1.7985373,,,,,,,,,,,,,,,,, -77500,0.28625128,1.8036642,,,,,,,,,,,,,,,,, -77600,0.28659698,1.8062459,,,,,,,,,,,,,,,,, -77700,0.2811152,1.7254633,,,,,,,,,,,,,,,,, -77800,0.335078,1.8039242,,,,,,,,,,,,,,,,, -77852,,,0.6361210942268372,1.8001832962036133,30.33095928486013,0.6541270613670349,1.6627322435379028,27.47509347000726,3000.0,0.667956531047821,1.5868098735809326,27.65037334871951,3003.0,26909.717646598816,44805.90866136551,26909.717646598816,17892.767714738846,0.9847004413604736,0.0 -77900,0.31093222,1.840815,,,,,,,,,,,,,,,,, -78000,0.318296,1.8944281,,,,,,,,,,,,,,,,, -78100,0.30546504,1.861939,,,,,,,,,,,,,,,,, -78200,0.31609514,1.9177032,,,,,,,,,,,,,,,,, -78300,0.34745526,1.7947451,,,,,,,,,,,,,,,,, -78400,0.32970485,1.6847879,,,,,,,,,,,,,,,,, -78500,0.29882985,1.7626455,,,,,,,,,,,,,,,,, -78600,0.28765026,1.8249637,,,,,,,,,,,,,,,,, -78700,0.29338428,1.8691958,,,,,,,,,,,,,,,,, -78800,0.30282038,1.7940751,,,,,,,,,,,,,,,,, -78900,0.30175576,1.7503681,,,,,,,,,,,,,,,,, -79000,0.2937072,1.7822328,,,,,,,,,,,,,,,,, -79100,0.29706648,1.8342814,,,,,,,,,,,,,,,,, -79200,0.30188718,1.7652696,,,,,,,,,,,,,,,,, -79300,0.3432157,1.791047,,,,,,,,,,,,,,,,, -79400,0.3030192,1.8788177,,,,,,,,,,,,,,,,, -79500,0.3114434,1.7986065,,,,,,,,,,,,,,,,, -79600,0.31684804,1.8017601,,,,,,,,,,,,,,,,, -79700,0.2748894,1.8518081,,,,,,,,,,,,,,,,, -79800,0.32304215,1.7858015,,,,,,,,,,,,,,,,, -79900,0.32558447,1.6949724,,,,,,,,,,,,,,,,, -80000,0.36993778,1.8402265,,,,,,,,,,,,,,,,, -80100,0.40135998,1.8409293,,,,,,,,,,,,,,,,, -80200,0.32164657,1.7088244,,,,,,,,,,,,,,,,, -80285,,,0.6344415545463562,1.814192771911621,30.737479092771107,0.6556149125099182,1.6477035284042358,27.724983358387743,3000.0,0.6648771166801453,1.5862025022506714,27.20680830071649,3003.0,27749.741545915604,46149.47729349136,27749.741545915604,18396.198426246643,1.0218725204467771,0.0 -80300,0.28575265,1.8231161,,,,,,,,,,,,,,,,, -80400,0.32285163,1.7775443,,,,,,,,,,,,,,,,, -80500,0.26266006,1.7882378,,,,,,,,,,,,,,,,, -80600,0.28520694,1.7557505,,,,,,,,,,,,,,,,, -80700,0.27775428,1.8084747,,,,,,,,,,,,,,,,, -80800,0.3074603,1.8835082,,,,,,,,,,,,,,,,, -80900,0.2641396,1.7912436,,,,,,,,,,,,,,,,, -81000,0.30255425,1.7810287,,,,,,,,,,,,,,,,, -81100,0.33460593,1.8458942,,,,,,,,,,,,,,,,, -81200,0.28381214,1.7545483,,,,,,,,,,,,,,,,, -81300,0.27787432,1.7768139,,,,,,,,,,,,,,,,, -81400,0.31198636,1.8053479,,,,,,,,,,,,,,,,, -81500,0.3396008,1.8320394,,,,,,,,,,,,,,,,, -81600,0.319418,1.8554635,,,,,,,,,,,,,,,,, -81700,0.30710018,1.779902,,,,,,,,,,,,,,,,, -81800,0.35830504,1.8176892,,,,,,,,,,,,,,,,, -81900,0.31810758,1.8370011,,,,,,,,,,,,,,,,, -82000,0.32238394,1.8438437,,,,,,,,,,,,,,,,, -82100,0.32393557,1.8182878,,,,,,,,,,,,,,,,, -82200,0.3147452,1.8983029,,,,,,,,,,,,,,,,, -82300,0.3504769,1.8353779,,,,,,,,,,,,,,,,, -82400,0.29329747,1.7916787,,,,,,,,,,,,,,,,, -82500,0.3232259,1.7414653,,,,,,,,,,,,,,,,, -82600,0.31746024,1.8441861,,,,,,,,,,,,,,,,, -82700,0.27773106,1.8299928,,,,,,,,,,,,,,,,, -82719,,,0.6453226804733276,1.7232457399368286,31.61777031006529,0.6575492024421692,1.6341495513916016,28.11244713518492,3000.0,0.6703736186027527,1.5644491910934448,27.516618205330463,3003.0,28589.92642354965,47558.457666397095,28589.92642354965,18964.878935098648,1.0599958896636963,0.0 -82800,0.31048286,1.8318071,,,,,,,,,,,,,,,,, -82900,0.28738186,1.7695763,,,,,,,,,,,,,,,,, -83000,0.31671914,1.7138764,,,,,,,,,,,,,,,,, -83100,0.32007623,1.8443696,,,,,,,,,,,,,,,,, -83200,0.30129436,1.831828,,,,,,,,,,,,,,,,, -83300,0.27076373,1.7780845,,,,,,,,,,,,,,,,, -83400,0.30769208,1.8374741,,,,,,,,,,,,,,,,, -83500,0.2865047,1.8168125,,,,,,,,,,,,,,,,, -83600,0.28044248,1.7636981,,,,,,,,,,,,,,,,, -83700,0.27405858,1.751035,,,,,,,,,,,,,,,,, -83800,0.29314375,1.7200549,,,,,,,,,,,,,,,,, -83900,0.34746933,1.813775,,,,,,,,,,,,,,,,, -84000,0.33159763,1.7658225,,,,,,,,,,,,,,,,, -84100,0.26072216,1.7912713,,,,,,,,,,,,,,,,, -84200,0.30101812,1.8176438,,,,,,,,,,,,,,,,, -84300,0.32634845,1.8060964,,,,,,,,,,,,,,,,, -84400,0.3154391,1.8557509,,,,,,,,,,,,,,,,, -84500,0.29593885,1.78869,,,,,,,,,,,,,,,,, -84600,0.30305177,1.7707642,,,,,,,,,,,,,,,,, -84700,0.30741742,1.7182238,,,,,,,,,,,,,,,,, -84800,0.27964568,1.7508466,,,,,,,,,,,,,,,,, -84900,0.27793977,1.7314901,,,,,,,,,,,,,,,,, -85000,0.28619954,1.7667894,,,,,,,,,,,,,,,,, -85100,0.28527364,1.852766,,,,,,,,,,,,,,,,, -85153,,,0.6420574188232422,1.7634731531143188,31.85978635812725,0.6593594551086426,1.616284966468811,28.13691576680429,3000.0,0.672267735004425,1.5401736497879028,27.951467340137658,3003.0,29430.129603147507,48923.83412575722,29430.129603147507,19489.93882274628,1.096651315689087,0.0 -85200,0.2903377,1.7043426,,,,,,,,,,,,,,,,, -85300,0.3256051,1.768683,,,,,,,,,,,,,,,,, -85400,0.36271006,1.7097801,,,,,,,,,,,,,,,,, -85500,0.2907339,1.7450161,,,,,,,,,,,,,,,,, -85600,0.3012487,1.7194811,,,,,,,,,,,,,,,,, -85700,0.2697726,1.7981685,,,,,,,,,,,,,,,,, -85800,0.36046228,1.8711619,,,,,,,,,,,,,,,,, -85900,0.295696,1.7694553,,,,,,,,,,,,,,,,, -86000,0.33361745,1.789342,,,,,,,,,,,,,,,,, -86100,0.32804167,1.7095677,,,,,,,,,,,,,,,,, -86200,0.28638002,1.7857059,,,,,,,,,,,,,,,,, -86300,0.30195647,1.7644354,,,,,,,,,,,,,,,,, -86400,0.299619,1.8137778,,,,,,,,,,,,,,,,, -86500,0.31338522,1.7921921,,,,,,,,,,,,,,,,, -86600,0.35521257,1.830571,,,,,,,,,,,,,,,,, -86700,0.30976486,1.7689253,,,,,,,,,,,,,,,,, -86800,0.3043774,1.7873414,,,,,,,,,,,,,,,,, -86900,0.27716735,1.6771021,,,,,,,,,,,,,,,,, -87000,0.374335,1.8549361,,,,,,,,,,,,,,,,, -87100,0.30924058,1.7533169,,,,,,,,,,,,,,,,, -87200,0.28732872,1.780382,,,,,,,,,,,,,,,,, -87300,0.28071365,1.7697936,,,,,,,,,,,,,,,,, -87400,0.28483146,1.7121302,,,,,,,,,,,,,,,,, -87500,0.31425306,1.7619915,,,,,,,,,,,,,,,,, -87586,,,0.706721842288971,1.3691809177398682,36.76419684639784,0.664182722568512,1.5927585363388062,28.74037977809781,3000.0,0.6754401326179504,1.5277425050735474,27.99981196715836,3003.0,30270.221952676773,50315.81853723526,30270.221952676773,20041.71503210068,1.135914325714111,0.0 -87600,0.30362678,1.7665008,,,,,,,,,,,,,,,,, -87700,0.2935612,1.7213625,,,,,,,,,,,,,,,,, -87800,0.30071008,1.7197539,,,,,,,,,,,,,,,,, -87900,0.29117072,1.7118155,,,,,,,,,,,,,,,,, -88000,0.32522342,1.7546824,,,,,,,,,,,,,,,,, -88100,0.30430114,1.7170713,,,,,,,,,,,,,,,,, -88200,0.30615416,1.8040065,,,,,,,,,,,,,,,,, -88300,0.28470236,1.7561915,,,,,,,,,,,,,,,,, -88400,0.3119858,1.7771089,,,,,,,,,,,,,,,,, -88500,0.30736524,1.7516495,,,,,,,,,,,,,,,,, -88600,0.2697674,1.7134267,,,,,,,,,,,,,,,,, -88700,0.32542646,1.7503335,,,,,,,,,,,,,,,,, -88800,0.3223668,1.8726379,,,,,,,,,,,,,,,,, -88900,0.320511,1.7091486,,,,,,,,,,,,,,,,, -89000,0.30910832,1.7033964,,,,,,,,,,,,,,,,, -89100,0.35617414,1.797943,,,,,,,,,,,,,,,,, -89200,0.35098666,1.8228514,,,,,,,,,,,,,,,,, -89300,0.31104267,1.7825887,,,,,,,,,,,,,,,,, -89400,0.31635472,1.7080907,,,,,,,,,,,,,,,,, -89500,0.2790961,1.696497,,,,,,,,,,,,,,,,, -89600,0.28951713,1.6939436,,,,,,,,,,,,,,,,, -89700,0.29158303,1.7218039,,,,,,,,,,,,,,,,, -89800,0.29765692,1.7359806,,,,,,,,,,,,,,,,, -89900,0.30048808,1.7759825,,,,,,,,,,,,,,,,, -90000,0.3115856,1.7672763,,,,,,,,,,,,,,,,, -90019,,,0.6472726464271545,1.7228424549102783,31.716836297151044,0.6628807783126831,1.5885562896728516,28.02380074296613,3000.0,0.6758701205253601,1.5183173418045044,27.95895404149604,3003.0,31110.187237501144,51725.13202857971,31110.187237501144,20610.94938087464,1.1731517314910889,0.0 -90100,0.2941452,1.7502807,,,,,,,,,,,,,,,,, -90200,0.2852946,1.6361986,,,,,,,,,,,,,,,,, -90300,0.3144399,1.7674344,,,,,,,,,,,,,,,,, -90400,0.30047268,1.7336496,,,,,,,,,,,,,,,,, -90500,0.27599132,1.6050845,,,,,,,,,,,,,,,,, -90600,0.28549403,1.7593492,,,,,,,,,,,,,,,,, -90700,0.29984918,1.685128,,,,,,,,,,,,,,,,, -90800,0.26414874,1.696509,,,,,,,,,,,,,,,,, -90900,0.33227423,1.7878584,,,,,,,,,,,,,,,,, -91000,0.302202,1.7363112,,,,,,,,,,,,,,,,, -91100,0.29163975,1.7196468,,,,,,,,,,,,,,,,, -91200,0.3396889,1.8337526,,,,,,,,,,,,,,,,, -91300,0.28640205,1.6843597,,,,,,,,,,,,,,,,, -91400,0.31940776,1.7027502,,,,,,,,,,,,,,,,, -91500,0.31634843,1.7942731,,,,,,,,,,,,,,,,, -91600,0.27846485,1.7458814,,,,,,,,,,,,,,,,, -91700,0.2945103,1.7493075,,,,,,,,,,,,,,,,, -91800,0.30778134,1.8371457,,,,,,,,,,,,,,,,, -91900,0.295677,1.687361,,,,,,,,,,,,,,,,, -92000,0.27351624,1.7268655,,,,,,,,,,,,,,,,, -92100,0.30655557,1.6528693,,,,,,,,,,,,,,,,, -92200,0.29198655,1.7284398,,,,,,,,,,,,,,,,, -92300,0.33703488,1.6988609,,,,,,,,,,,,,,,,, -92400,0.3069432,1.792056,,,,,,,,,,,,,,,,, -92452,,,0.6476435661315918,1.7178657054901123,31.69203925629525,0.6662781834602356,1.5724620819091797,28.56802874801828,3000.0,0.6786009073257446,1.4941082000732422,28.248444795106696,3003.0,31950.0975458622,53169.55118060112,31950.0975458622,21215.34196662903,1.2124512195587158,0.0 -92500,0.29930636,1.7598214,,,,,,,,,,,,,,,,, -92600,0.3029849,1.7572265,,,,,,,,,,,,,,,,, -92700,0.3192835,1.726217,,,,,,,,,,,,,,,,, -92800,0.267901,1.6487207,,,,,,,,,,,,,,,,, -92900,0.3131309,1.6793667,,,,,,,,,,,,,,,,, -93000,0.30426824,1.7186953,,,,,,,,,,,,,,,,, -93100,0.30475315,1.6810484,,,,,,,,,,,,,,,,, -93200,0.32298616,1.6363857,,,,,,,,,,,,,,,,, -93300,0.27200767,1.6498678,,,,,,,,,,,,,,,,, -93400,0.30052817,1.7006259,,,,,,,,,,,,,,,,, -93500,0.3632731,1.7726392,,,,,,,,,,,,,,,,, -93600,0.28942645,1.7068117,,,,,,,,,,,,,,,,, -93700,0.32833716,1.7148072,,,,,,,,,,,,,,,,, -93800,0.28961012,1.702598,,,,,,,,,,,,,,,,, -93900,0.28738886,1.7668273,,,,,,,,,,,,,,,,, -94000,0.27834553,1.6633008,,,,,,,,,,,,,,,,, -94100,0.3002611,1.7240584,,,,,,,,,,,,,,,,, -94200,0.2868788,1.6911793,,,,,,,,,,,,,,,,, -94300,0.28570303,1.7607421,,,,,,,,,,,,,,,,, -94400,0.43832454,1.7142088,,,,,,,,,,,,,,,,, -94500,0.3144578,1.7467444,,,,,,,,,,,,,,,,, -94600,0.28129858,1.7074767,,,,,,,,,,,,,,,,, -94700,0.3015799,1.6937916,,,,,,,,,,,,,,,,, -94800,0.31266794,1.6436272,,,,,,,,,,,,,,,,, -94886,,,0.6604101657867432,1.6313503980636597,32.403113391945894,0.6679768562316895,1.559630036354065,28.760747178515107,3000.0,0.6810412406921387,1.4820163249969482,28.72655015719533,3003.0,32790.26389718056,54487.12789297104,32790.26389718056,21692.634751558304,1.252263069152832,0.0 -94900,0.32566547,1.6682264,,,,,,,,,,,,,,,,, -95000,0.33007145,1.7079548,,,,,,,,,,,,,,,,, -95100,0.28007895,1.6636423,,,,,,,,,,,,,,,,, -95200,0.3366287,1.7476221,,,,,,,,,,,,,,,,, -95300,0.31329536,1.7146983,,,,,,,,,,,,,,,,, -95400,0.3093551,1.5949074,,,,,,,,,,,,,,,,, -95500,0.31067905,1.7615948,,,,,,,,,,,,,,,,, -95600,0.30222905,1.7070745,,,,,,,,,,,,,,,,, -95700,0.3129607,1.6525191,,,,,,,,,,,,,,,,, -95800,0.30379018,1.6495004,,,,,,,,,,,,,,,,, -95900,0.28726,1.6565309,,,,,,,,,,,,,,,,, -96000,0.30526337,1.7161412,,,,,,,,,,,,,,,,, -96100,0.3569584,1.7054121,,,,,,,,,,,,,,,,, -96200,0.2904295,1.6560646,,,,,,,,,,,,,,,,, -96300,0.3029546,1.765401,,,,,,,,,,,,,,,,, -96400,0.31116375,1.6830262,,,,,,,,,,,,,,,,, -96500,0.3269926,1.7256963,,,,,,,,,,,,,,,,, -96600,0.35704812,1.7463112,,,,,,,,,,,,,,,,, -96700,0.3121893,1.6425351,,,,,,,,,,,,,,,,, -96800,0.32820618,1.7744209,,,,,,,,,,,,,,,,, -96900,0.3112623,1.68321,,,,,,,,,,,,,,,,, -97000,0.3338182,1.7415489,,,,,,,,,,,,,,,,, -97100,0.32344058,1.6122239,,,,,,,,,,,,,,,,, -97200,0.28691313,1.6943799,,,,,,,,,,,,,,,,, -97300,0.316593,1.7530959,,,,,,,,,,,,,,,,, -97319,,,0.6549776792526245,1.674398422241211,32.00376825833275,0.6725149154663086,1.5401153564453125,28.80237879604785,3000.0,0.6848992109298706,1.4625848531723022,28.58278102882752,3003.0,33630.172709703445,55810.58344745636,33630.172709703445,22176.063058376312,1.2936749458312988,0.0 -97400,0.31888884,1.6556119,,,,,,,,,,,,,,,,, -97500,0.31873742,1.7685658,,,,,,,,,,,,,,,,, -97600,0.29456508,1.6640937,,,,,,,,,,,,,,,,, -97700,0.30415913,1.6721915,,,,,,,,,,,,,,,,, -97800,0.30305594,1.632034,,,,,,,,,,,,,,,,, -97900,0.30664545,1.7078034,,,,,,,,,,,,,,,,, -98000,0.29293957,1.7274493,,,,,,,,,,,,,,,,, -98100,0.31133467,1.6461362,,,,,,,,,,,,,,,,, -98200,0.31020916,1.7244219,,,,,,,,,,,,,,,,, -98300,0.3040332,1.685031,,,,,,,,,,,,,,,,, -98400,0.31215456,1.6709578,,,,,,,,,,,,,,,,, -98500,0.3095053,1.6447432,,,,,,,,,,,,,,,,, -98600,0.32420492,1.6555163,,,,,,,,,,,,,,,,, -98700,0.29426208,1.58805,,,,,,,,,,,,,,,,, -98800,0.3159577,1.7644595,,,,,,,,,,,,,,,,, -98900,0.29085466,1.6369185,,,,,,,,,,,,,,,,, -99000,0.3080059,1.6567109,,,,,,,,,,,,,,,,, -99100,0.3033531,1.7230625,,,,,,,,,,,,,,,,, -99200,0.32316715,1.6066873,,,,,,,,,,,,,,,,, -99300,0.30483958,1.6799241,,,,,,,,,,,,,,,,, -99400,0.27862668,1.652653,,,,,,,,,,,,,,,,, -99500,0.3302853,1.681836,,,,,,,,,,,,,,,,, -99600,0.288068,1.6226432,,,,,,,,,,,,,,,,, -99700,0.32839522,1.6562779,,,,,,,,,,,,,,,,, -99752,,,0.6542562246322632,1.6796025037765503,32.45218523737319,0.6749699115753174,1.5269558429718018,29.279505165766786,3000.0,0.6886526346206665,1.4499832391738892,28.744171233574672,3003.0,34470.056736946106,57170.82868671417,34470.056736946106,22696.30715584755,1.3333721160888672,0.0 -99800,0.30664593,1.650603,,,,,,,,,,,,,,,,, -99900,0.3033976,1.719894,,,,,,,,,,,,,,,,, -100000,0.32624578,1.7256236,,,,,,,,,,,,,,,,, -100100,0.29700866,1.6602452,,,,,,,,,,,,,,,,, -100200,0.30396515,1.6525719,,,,,,,,,,,,,,,,, -100300,0.31123894,1.6483123,,,,,,,,,,,,,,,,, -100400,0.28622296,1.6610451,,,,,,,,,,,,,,,,, -100500,0.31240228,1.7120507,,,,,,,,,,,,,,,,, -100600,0.2999183,1.651676,,,,,,,,,,,,,,,,, -100700,0.30456343,1.7062731,,,,,,,,,,,,,,,,, -100800,0.30795228,1.6008544,,,,,,,,,,,,,,,,, -100900,0.33961967,1.6289883,,,,,,,,,,,,,,,,, -101000,0.32097346,1.6697385,,,,,,,,,,,,,,,,, -101100,0.31406787,1.6039045,,,,,,,,,,,,,,,,, -101200,0.31342262,1.6956016,,,,,,,,,,,,,,,,, -101300,0.29459646,1.7376084,,,,,,,,,,,,,,,,, -101400,0.32766974,1.6474181,,,,,,,,,,,,,,,,, -101500,0.3285678,1.7040503,,,,,,,,,,,,,,,,, -101600,0.33727792,1.7591313,,,,,,,,,,,,,,,,, -101700,0.2970527,1.6585202,,,,,,,,,,,,,,,,, -101800,0.30131978,1.6745178,,,,,,,,,,,,,,,,, -101900,0.29837346,1.6812415,,,,,,,,,,,,,,,,, -102000,0.30456147,1.6247604,,,,,,,,,,,,,,,,, -102100,0.29931644,1.5690979,,,,,,,,,,,,,,,,, -102186,,,0.6620293259620667,1.618472933769226,32.64789488614793,0.6765817999839783,1.5140327215194702,29.198657105886326,3000.0,0.6881877779960632,1.4374244213104248,28.851760085084145,3003.0,35310.28398871422,58635.10763931274,35310.28398871422,23320.24372291565,1.3716270923614502,0.0 -102200,0.3195338,1.6424135,,,,,,,,,,,,,,,,, -102300,0.29823512,1.6346343,,,,,,,,,,,,,,,,, -102400,0.2832093,1.5750198,,,,,,,,,,,,,,,,, -102500,0.31062433,1.6897796,,,,,,,,,,,,,,,,, -102600,0.33594468,1.6446007,,,,,,,,,,,,,,,,, -102700,0.328313,1.7547336,,,,,,,,,,,,,,,,, -102800,0.31581047,1.7086883,,,,,,,,,,,,,,,,, -102900,0.31448117,1.6159734,,,,,,,,,,,,,,,,, -103000,0.2877771,1.6769083,,,,,,,,,,,,,,,,, -103100,0.33665502,1.6838647,,,,,,,,,,,,,,,,, -103200,0.31112283,1.6234884,,,,,,,,,,,,,,,,, -103300,0.33106673,1.681838,,,,,,,,,,,,,,,,, -103400,0.31377682,1.6345537,,,,,,,,,,,,,,,,, -103500,0.31330684,1.6249391,,,,,,,,,,,,,,,,, -103600,0.31820145,1.6480652,,,,,,,,,,,,,,,,, -103700,0.36780262,1.6518524,,,,,,,,,,,,,,,,, -103800,0.33652595,1.564192,,,,,,,,,,,,,,,,, -103900,0.34855837,1.6620584,,,,,,,,,,,,,,,,, -104000,0.30190548,1.6191012,,,,,,,,,,,,,,,,, -104100,0.28637668,1.6641061,,,,,,,,,,,,,,,,, -104200,0.32458353,1.6370138,,,,,,,,,,,,,,,,, -104300,0.30183798,1.6540065,,,,,,,,,,,,,,,,, -104400,0.31111473,1.6492177,,,,,,,,,,,,,,,,, -104500,0.31081396,1.7033103,,,,,,,,,,,,,,,,, -104600,0.31279746,1.5868359,,,,,,,,,,,,,,,,, -104621,,,0.6610704660415649,1.626672625541687,32.60958248876339,0.6787516474723816,1.5002140998840332,29.66641032835177,3000.0,0.693591296672821,1.4133522510528564,29.587910915299563,3003.0,36150.45073056221,60030.47633481026,36150.45073056221,23875.32908177376,1.4112050533294678,0.0 -104700,0.2973713,1.6373535,,,,,,,,,,,,,,,,, -104800,0.3176089,1.6288086,,,,,,,,,,,,,,,,, -104900,0.3187573,1.6288211,,,,,,,,,,,,,,,,, -105000,0.30738947,1.6189597,,,,,,,,,,,,,,,,, -105100,0.32125202,1.5892421,,,,,,,,,,,,,,,,, -105200,0.3136619,1.6706809,,,,,,,,,,,,,,,,, -105300,0.3093521,1.6360049,,,,,,,,,,,,,,,,, -105400,0.3258859,1.580506,,,,,,,,,,,,,,,,, -105500,0.32468343,1.6018524,,,,,,,,,,,,,,,,, -105600,0.32209972,1.6196913,,,,,,,,,,,,,,,,, -105700,0.3217741,1.5963104,,,,,,,,,,,,,,,,, -105800,0.31356433,1.619669,,,,,,,,,,,,,,,,, -105900,0.33516234,1.5873567,,,,,,,,,,,,,,,,, -106000,0.29817227,1.5360392,,,,,,,,,,,,,,,,, -106100,0.3084061,1.5685925,,,,,,,,,,,,,,,,, -106200,0.32451528,1.6594529,,,,,,,,,,,,,,,,, -106300,0.3267302,1.6689517,,,,,,,,,,,,,,,,, -106400,0.33387548,1.6266453,,,,,,,,,,,,,,,,, -106500,0.31749365,1.5923812,,,,,,,,,,,,,,,,, -106600,0.33862117,1.6446425,,,,,,,,,,,,,,,,, -106700,0.31713635,1.6501409,,,,,,,,,,,,,,,,, -106800,0.31431547,1.6191169,,,,,,,,,,,,,,,,, -106900,0.33286119,1.5638953,,,,,,,,,,,,,,,,, -107000,0.31272438,1.6657442,,,,,,,,,,,,,,,,, -107055,,,0.6778357028961182,1.5199074745178225,33.98085596494906,0.6803759336471558,1.4900200366973877,29.69537641347098,3000.0,0.6934751272201538,1.4053597450256348,29.51269855331246,3003.0,36990.67479014397,61593.02966165543,36990.67479014397,24597.542265176773,1.4509341716766355,0.0 -107100,0.31546715,1.6294951,,,,,,,,,,,,,,,,, -107200,0.3542338,1.6269996,,,,,,,,,,,,,,,,, -107300,0.33746037,1.5965569,,,,,,,,,,,,,,,,, -107400,0.32761052,1.6153299,,,,,,,,,,,,,,,,, -107500,0.31801805,1.5244354,,,,,,,,,,,,,,,,, -107600,0.30731106,1.5536606,,,,,,,,,,,,,,,,, -107700,0.32809764,1.5557088,,,,,,,,,,,,,,,,, -107800,0.31357408,1.5841312,,,,,,,,,,,,,,,,, -107900,0.31986648,1.5138538,,,,,,,,,,,,,,,,, -108000,0.3000275,1.5955858,,,,,,,,,,,,,,,,, -108100,0.34239832,1.6012776,,,,,,,,,,,,,,,,, -108200,0.3330439,1.6027566,,,,,,,,,,,,,,,,, -108300,0.3275809,1.6487852,,,,,,,,,,,,,,,,, -108400,0.3383573,1.5740415,,,,,,,,,,,,,,,,, -108500,0.30735552,1.5856565,,,,,,,,,,,,,,,,, -108600,0.33331773,1.6130145,,,,,,,,,,,,,,,,, -108700,0.34442672,1.6263323,,,,,,,,,,,,,,,,, -108800,0.3126976,1.5576991,,,,,,,,,,,,,,,,, -108900,0.32171112,1.5378057,,,,,,,,,,,,,,,,, -109000,0.3257885,1.6103605,,,,,,,,,,,,,,,,, -109100,0.31379217,1.5343697,,,,,,,,,,,,,,,,, -109200,0.31399179,1.620482,,,,,,,,,,,,,,,,, -109300,0.31267077,1.5492002,,,,,,,,,,,,,,,,, -109400,0.32896626,1.5575289,,,,,,,,,,,,,,,,, -109489,,,0.6716055870056152,1.5597528219223022,33.4808430765383,0.6831905245780945,1.4681103229522705,30.04091275834891,3000.0,0.6989367604255676,1.3856124877929688,29.94549352250294,3003.0,37830.68860411644,63086.02798628807,37830.68860411644,25250.408385038376,1.4924664497375488,0.0 -109500,0.31905195,1.6408542,,,,,,,,,,,,,,,,, -109600,0.33841366,1.5147066,,,,,,,,,,,,,,,,, -109700,0.3421037,1.6069633,,,,,,,,,,,,,,,,, -109800,0.33312756,1.4957547,,,,,,,,,,,,,,,,, -109900,0.35496548,1.6172919,,,,,,,,,,,,,,,,, -110000,0.32760996,1.5765092,,,,,,,,,,,,,,,,, -110100,0.33079866,1.5887223,,,,,,,,,,,,,,,,, -110200,0.32550904,1.6154938,,,,,,,,,,,,,,,,, -110300,0.3274528,1.5673661,,,,,,,,,,,,,,,,, -110400,0.3363499,1.6647713,,,,,,,,,,,,,,,,, -110500,0.33857706,1.5662012,,,,,,,,,,,,,,,,, -110600,0.31108722,1.5142127,,,,,,,,,,,,,,,,, -110700,0.33128205,1.6327547,,,,,,,,,,,,,,,,, -110800,0.33143523,1.6240438,,,,,,,,,,,,,,,,, -110900,0.31952924,1.6429778,,,,,,,,,,,,,,,,, -111000,0.31765044,1.575833,,,,,,,,,,,,,,,,, -111100,0.35114452,1.5991567,,,,,,,,,,,,,,,,, -111200,0.34927204,1.5115954,,,,,,,,,,,,,,,,, -111300,0.34237707,1.5817031,,,,,,,,,,,,,,,,, -111400,0.34092996,1.6116999,,,,,,,,,,,,,,,,, -111500,0.33502316,1.6112307,,,,,,,,,,,,,,,,, -111600,0.34757283,1.5856096,,,,,,,,,,,,,,,,, -111700,0.3246192,1.5533284,,,,,,,,,,,,,,,,, -111800,0.33485374,1.5450596,,,,,,,,,,,,,,,,, -111900,0.33763257,1.6238368,,,,,,,,,,,,,,,,, -111923,,,0.6700042486190796,1.5730342864990234,33.42104699643389,0.6848644018173218,1.4581775665283203,30.067353708829728,3000.0,0.6991691589355469,1.375052571296692,29.85416362140656,3003.0,38670.9139418602,64468.39996099472,38670.9139418602,25792.438461780548,1.5327885150909424,0.0 -112000,0.3438541,1.6038582,,,,,,,,,,,,,,,,, -112100,0.3391806,1.5063285,,,,,,,,,,,,,,,,, -112200,0.44846115,1.5185797,,,,,,,,,,,,,,,,, -112300,0.3246316,1.5379565,,,,,,,,,,,,,,,,, -112400,0.329525,1.5207603,,,,,,,,,,,,,,,,, -112500,0.33560258,1.5623362,,,,,,,,,,,,,,,,, -112600,0.33277917,1.5801142,,,,,,,,,,,,,,,,, -112700,0.32254705,1.4909385,,,,,,,,,,,,,,,,, -112800,0.32924214,1.5337291,,,,,,,,,,,,,,,,, -112900,0.32792908,1.5338607,,,,,,,,,,,,,,,,, -113000,0.3220864,1.452432,,,,,,,,,,,,,,,,, -113100,0.34114385,1.4668963,,,,,,,,,,,,,,,,, -113200,0.32348117,1.4617397,,,,,,,,,,,,,,,,, -113300,0.33817205,1.5725895,,,,,,,,,,,,,,,,, -113400,0.34815457,1.5306804,,,,,,,,,,,,,,,,, -113500,0.33632174,1.5395544,,,,,,,,,,,,,,,,, -113600,0.33828768,1.527063,,,,,,,,,,,,,,,,, -113700,0.34299824,1.5435722,,,,,,,,,,,,,,,,, -113800,0.33152524,1.5096718,,,,,,,,,,,,,,,,, -113900,0.35602683,1.4953104,,,,,,,,,,,,,,,,, -114000,0.33202162,1.518651,,,,,,,,,,,,,,,,, -114100,0.34296337,1.5205697,,,,,,,,,,,,,,,,, -114200,0.32194543,1.5516043,,,,,,,,,,,,,,,,, -114300,0.35086104,1.5516337,,,,,,,,,,,,,,,,, -114357,,,0.6836004257202148,1.4913737773895264,34.42451429782527,0.6876541972160339,1.4455455541610718,30.17884093970701,3000.0,0.702213704586029,1.3610974550247192,30.16390519547844,3003.0,39510.85387992859,65875.96654629707,39510.85387992859,26359.942928552628,1.5786361694335938,0.0 -114400,0.35065886,1.5363636,,,,,,,,,,,,,,,,, -114500,0.3426566,1.6197225,,,,,,,,,,,,,,,,, -114600,0.35488963,1.5122743,,,,,,,,,,,,,,,,, -114700,0.3392234,1.4783044,,,,,,,,,,,,,,,,, -114800,0.3403047,1.5090562,,,,,,,,,,,,,,,,, -114900,0.3393271,1.5285861,,,,,,,,,,,,,,,,, -115000,0.32411313,1.4832854,,,,,,,,,,,,,,,,, -115100,0.3280618,1.4826422,,,,,,,,,,,,,,,,, -115200,0.36999717,1.5477561,,,,,,,,,,,,,,,,, -115300,0.3567649,1.5512116,,,,,,,,,,,,,,,,, -115400,0.33000994,1.55499,,,,,,,,,,,,,,,,, -115500,0.34967393,1.5802437,,,,,,,,,,,,,,,,, -115600,0.36058173,1.5519528,,,,,,,,,,,,,,,,, -115700,0.33515418,1.5639697,,,,,,,,,,,,,,,,, -115800,0.35403606,1.5202037,,,,,,,,,,,,,,,,, -115900,0.35560906,1.5312136,,,,,,,,,,,,,,,,, -116000,0.3589034,1.5362697,,,,,,,,,,,,,,,,, -116100,0.3347963,1.4876413,,,,,,,,,,,,,,,,, -116200,0.3454185,1.5053413,,,,,,,,,,,,,,,,, -116300,0.3592408,1.5124986,,,,,,,,,,,,,,,,, -116400,0.33533093,1.4243795,,,,,,,,,,,,,,,,, -116500,0.36745855,1.5118307,,,,,,,,,,,,,,,,, -116600,0.35569727,1.5297422,,,,,,,,,,,,,,,,, -116700,0.34454992,1.4560914,,,,,,,,,,,,,,,,, -116791,,,0.6783912181854248,1.516280174255371,34.31626903846658,0.6888321042060852,1.4344900846481323,30.4859611177528,3000.0,0.7040497660636902,1.344609022140503,30.197523386417203,3003.0,40350.74137663841,67289.86536455154,40350.74137663841,26933.835524082184,1.6199851036071775,0.0 -116800,0.32853338,1.5009712,,,,,,,,,,,,,,,,, -116900,0.36338606,1.5035952,,,,,,,,,,,,,,,,, -117000,0.35262856,1.5385573,,,,,,,,,,,,,,,,, -117100,0.38499758,1.5258272,,,,,,,,,,,,,,,,, -117200,0.35666427,1.5137061,,,,,,,,,,,,,,,,, -117300,0.34431827,1.5390545,,,,,,,,,,,,,,,,, -117400,0.33934498,1.5578475,,,,,,,,,,,,,,,,, -117500,0.36189917,1.5659559,,,,,,,,,,,,,,,,, -117600,0.33493876,1.4470264,,,,,,,,,,,,,,,,, -117700,0.3474056,1.4556966,,,,,,,,,,,,,,,,, -117800,0.34261903,1.4880775,,,,,,,,,,,,,,,,, -117900,0.3542556,1.4923211,,,,,,,,,,,,,,,,, -118000,0.33291188,1.4096084,,,,,,,,,,,,,,,,, -118100,0.36417586,1.5141044,,,,,,,,,,,,,,,,, -118200,0.36051613,1.4617972,,,,,,,,,,,,,,,,, -118300,0.35026902,1.5273713,,,,,,,,,,,,,,,,, -118400,0.3601692,1.4953688,,,,,,,,,,,,,,,,, -118500,0.35217822,1.503692,,,,,,,,,,,,,,,,, -118600,0.36172092,1.4613659,,,,,,,,,,,,,,,,, -118700,0.35076246,1.5610069,,,,,,,,,,,,,,,,, -118800,0.37157953,1.4668773,,,,,,,,,,,,,,,,, -118900,0.36583155,1.4428879,,,,,,,,,,,,,,,,, -119000,0.3651397,1.5607919,,,,,,,,,,,,,,,,, -119100,0.35790333,1.4673439,,,,,,,,,,,,,,,,, -119200,0.35779762,1.5874959,,,,,,,,,,,,,,,,, -119225,,,0.6942856907844543,1.4276137351989746,35.552780823300324,0.6898612380027771,1.4264612197875977,30.59610711383455,3000.0,0.704317033290863,1.340147614479065,30.53495047135109,3003.0,41190.94413161278,68823.72482323647,41190.94413161278,27627.37395954132,1.6613929271697998,0.0 -119300,0.35616386,1.4687532,,,,,,,,,,,,,,,,, -119400,0.3602171,1.5197668,,,,,,,,,,,,,,,,, -119500,0.35443595,1.4998742,,,,,,,,,,,,,,,,, -119600,0.3451635,1.4968495,,,,,,,,,,,,,,,,, -119700,0.34500787,1.5096377,,,,,,,,,,,,,,,,, -119800,0.34562075,1.4256057,,,,,,,,,,,,,,,,, -119900,0.35792488,1.4966778,,,,,,,,,,,,,,,,, -120000,0.37918583,1.468054,,,,,,,,,,,,,,,,, -120100,0.34862795,1.5112008,,,,,,,,,,,,,,,,, -120200,0.34741938,1.4496175,,,,,,,,,,,,,,,,, -120300,0.34450558,1.5324429,,,,,,,,,,,,,,,,, -120400,0.35981485,1.5410007,,,,,,,,,,,,,,,,, -120500,0.36327568,1.499701,,,,,,,,,,,,,,,,, -120600,0.36779732,1.5505455,,,,,,,,,,,,,,,,, -120700,0.35505164,1.4510832,,,,,,,,,,,,,,,,, -120800,0.36987722,1.4604537,,,,,,,,,,,,,,,,, -120900,0.38046306,1.4949945,,,,,,,,,,,,,,,,, -121000,0.36035287,1.4120973,,,,,,,,,,,,,,,,, -121100,0.36170197,1.4467225,,,,,,,,,,,,,,,,, -121200,0.35824433,1.4894949,,,,,,,,,,,,,,,,, -121300,0.35182616,1.3938875,,,,,,,,,,,,,,,,, -121400,0.36311355,1.4500688,,,,,,,,,,,,,,,,, -121500,0.35110465,1.4596473,,,,,,,,,,,,,,,,, -121600,0.35457808,1.3706607,,,,,,,,,,,,,,,,, -121659,,,0.6907327771186829,1.4434564113616943,34.881001369806654,0.6906051635742188,1.420662522315979,30.869333107168814,3000.0,0.708349347114563,1.3267370462417605,30.80779430639281,3003.0,42030.92356848717,70249.16306734085,42030.92356848717,28212.71403479576,1.703831911087036,0.0 -121700,0.3752483,1.4928386,,,,,,,,,,,,,,,,, -121800,0.36847624,1.5430409,,,,,,,,,,,,,,,,, -121900,0.3726237,1.5030179,,,,,,,,,,,,,,,,, -122000,0.3582477,1.4308809,,,,,,,,,,,,,,,,, -122100,0.36112833,1.3901476,,,,,,,,,,,,,,,,, -122200,0.38141948,1.49739,,,,,,,,,,,,,,,,, -122300,0.35292724,1.4364144,,,,,,,,,,,,,,,,, -122400,0.37424538,1.4146899,,,,,,,,,,,,,,,,, -122500,0.3570445,1.4168369,,,,,,,,,,,,,,,,, -122600,0.3837724,1.4578493,,,,,,,,,,,,,,,,, -122700,0.3717963,1.4853318,,,,,,,,,,,,,,,,, -122800,0.37369707,1.503533,,,,,,,,,,,,,,,,, -122900,0.3885496,1.4365079,,,,,,,,,,,,,,,,, -123000,0.3703243,1.5221174,,,,,,,,,,,,,,,,, -123100,0.36581257,1.4714005,,,,,,,,,,,,,,,,, -123200,0.37360808,1.399883,,,,,,,,,,,,,,,,, -123300,0.38216814,1.4676516,,,,,,,,,,,,,,,,, -123400,0.3689776,1.4633085,,,,,,,,,,,,,,,,, -123500,0.3808303,1.4160442,,,,,,,,,,,,,,,,, -123600,0.39365685,1.5089623,,,,,,,,,,,,,,,,, -123700,0.3706038,1.4525826,,,,,,,,,,,,,,,,, -123800,0.361944,1.4238482,,,,,,,,,,,,,,,,, -123900,0.3938767,1.4448011,,,,,,,,,,,,,,,,, -124000,0.36830285,1.385569,,,,,,,,,,,,,,,,, -124093,,,0.6904439330101013,1.4435969591140747,35.23399677830485,0.6916218996047974,1.41691792011261,30.74350009523277,3000.0,0.708453893661499,1.3230172395706177,30.912976176032,3003.0,42871.01779818535,71706.35120844841,42871.01779818535,28829.68808722496,1.747204303741455,0.0 -124100,0.35925138,1.4044813,,,,,,,,,,,,,,,,, -124200,0.36607188,1.434491,,,,,,,,,,,,,,,,, -124300,0.39395759,1.394781,,,,,,,,,,,,,,,,, -124400,0.38297573,1.4711707,,,,,,,,,,,,,,,,, -124500,0.38556305,1.4710697,,,,,,,,,,,,,,,,, -124600,0.37192106,1.4643438,,,,,,,,,,,,,,,,, -124700,0.38489586,1.45031,,,,,,,,,,,,,,,,, -124800,0.3783404,1.4784715,,,,,,,,,,,,,,,,, -124900,0.38003024,1.4628713,,,,,,,,,,,,,,,,, -125000,0.36345053,1.4569486,,,,,,,,,,,,,,,,, -125100,0.39147648,1.4320241,,,,,,,,,,,,,,,,, -125200,0.3670932,1.3911515,,,,,,,,,,,,,,,,, -125300,0.37859672,1.4333771,,,,,,,,,,,,,,,,, -125400,0.3882933,1.5227429,,,,,,,,,,,,,,,,, -125500,0.37925667,1.4612161,,,,,,,,,,,,,,,,, -125600,0.37504303,1.4863616,,,,,,,,,,,,,,,,, -125700,0.36040646,1.4174463,,,,,,,,,,,,,,,,, -125800,0.37231833,1.4668423,,,,,,,,,,,,,,,,, -125900,0.3637713,1.4425445,,,,,,,,,,,,,,,,, -126000,0.39573035,1.4510983,,,,,,,,,,,,,,,,, -126100,0.37481168,1.4636991,,,,,,,,,,,,,,,,, -126200,0.36099157,1.4025893,,,,,,,,,,,,,,,,, -126300,0.38260418,1.5431584,,,,,,,,,,,,,,,,, -126400,0.36104095,1.4610466,,,,,,,,,,,,,,,,, -126500,0.36793175,1.432544,,,,,,,,,,,,,,,,, -126527,,,0.6974536180496216,1.413323998451233,35.86837829412484,0.693134605884552,1.4092453718185425,30.930916769026155,3000.0,0.7094765305519104,1.3161181211471558,30.907800321136087,3003.0,43711.224172115326,73145.43577122688,43711.224172115326,29428.44505548477,1.790480613708496,0.0 -126600,0.3642857,1.3617064,,,,,,,,,,,,,,,,, -126700,0.37624508,1.3570074,,,,,,,,,,,,,,,,, -126800,0.37053898,1.4294044,,,,,,,,,,,,,,,,, -126900,0.36578193,1.3813686,,,,,,,,,,,,,,,,, -127000,0.35961086,1.3717256,,,,,,,,,,,,,,,,, -127100,0.37293378,1.3947148,,,,,,,,,,,,,,,,, -127200,0.38158756,1.3911829,,,,,,,,,,,,,,,,, -127300,0.36739397,1.4527404,,,,,,,,,,,,,,,,, -127400,0.3907455,1.3686109,,,,,,,,,,,,,,,,, -127500,0.37025005,1.4869335,,,,,,,,,,,,,,,,, -127600,0.37729567,1.4347026,,,,,,,,,,,,,,,,, -127700,0.3849678,1.444875,,,,,,,,,,,,,,,,, -127800,0.3646299,1.4071518,,,,,,,,,,,,,,,,, -127900,0.36622474,1.3968525,,,,,,,,,,,,,,,,, -128000,0.3752206,1.4424083,,,,,,,,,,,,,,,,, -128100,0.38288915,1.4751108,,,,,,,,,,,,,,,,, -128200,0.3814795,1.4500169,,,,,,,,,,,,,,,,, -128300,0.39891675,1.4756883,,,,,,,,,,,,,,,,, -128400,0.3692788,1.4067652,,,,,,,,,,,,,,,,, -128500,0.37705,1.4273492,,,,,,,,,,,,,,,,, -128600,0.39857823,1.460304,,,,,,,,,,,,,,,,, -128700,0.37786934,1.4454868,,,,,,,,,,,,,,,,, -128800,0.38405117,1.4544557,,,,,,,,,,,,,,,,, -128900,0.36906365,1.4285319,,,,,,,,,,,,,,,,, -128961,,,0.6983780860900879,1.4064826965332031,35.716242509446154,0.692638635635376,1.40710711479187,31.06077647495805,3000.0,0.7107547521591187,1.3119847774505615,31.213135886632305,3003.0,44551.23852276802,74609.69246411324,44551.23852276802,30052.566106319427,1.8353376388549805,0.0 -129000,0.38338786,1.4387467,,,,,,,,,,,,,,,,, -129100,0.3746604,1.3786427,,,,,,,,,,,,,,,,, -129200,0.37879607,1.430099,,,,,,,,,,,,,,,,, -129300,0.3775934,1.4239465,,,,,,,,,,,,,,,,, -129400,0.3707711,1.3945549,,,,,,,,,,,,,,,,, -129500,0.38058284,1.3811531,,,,,,,,,,,,,,,,, -129600,0.3602741,1.3917439,,,,,,,,,,,,,,,,, -129700,0.38405374,1.4423287,,,,,,,,,,,,,,,,, -129800,0.3794149,1.3801275,,,,,,,,,,,,,,,,, -129900,0.3715486,1.4766512,,,,,,,,,,,,,,,,, -130000,0.37252337,1.3606944,,,,,,,,,,,,,,,,, -130100,0.38206562,1.442172,,,,,,,,,,,,,,,,, -130200,0.37486234,1.4582169,,,,,,,,,,,,,,,,, -130300,0.39337763,1.4876862,,,,,,,,,,,,,,,,, -130400,0.384025,1.4340059,,,,,,,,,,,,,,,,, -130500,0.398318,1.4297473,,,,,,,,,,,,,,,,, -130600,0.3756472,1.4170682,,,,,,,,,,,,,,,,, -130700,0.37130728,1.4145578,,,,,,,,,,,,,,,,, -130800,0.37219512,1.4742243,,,,,,,,,,,,,,,,, -130900,0.3606837,1.3885607,,,,,,,,,,,,,,,,, -131000,0.36652303,1.3781677,,,,,,,,,,,,,,,,, -131100,0.36836037,1.4620498,,,,,,,,,,,,,,,,, -131200,0.3820307,1.4686522,,,,,,,,,,,,,,,,, -131300,0.39276108,1.4918243,,,,,,,,,,,,,,,,, -131395,,,0.6976325511932373,1.408244490623474,35.93356190724933,0.6933081746101379,1.4070932865142822,30.89430551114946,3000.0,0.711498498916626,1.3109595775604248,31.10668950079464,3003.0,45391.26353955269,76032.60543179512,45391.26353955269,30635.33205962181,1.8811628818511963,0.0 -131400,0.36366016,1.4092313,,,,,,,,,,,,,,,,, -131500,0.37619805,1.4069451,,,,,,,,,,,,,,,,, -131600,0.3817842,1.4262797,,,,,,,,,,,,,,,,, -131700,0.3758608,1.4409667,,,,,,,,,,,,,,,,, -131800,0.38105223,1.4447213,,,,,,,,,,,,,,,,, -131900,0.3674524,1.3634609,,,,,,,,,,,,,,,,, -132000,0.38423747,1.4542495,,,,,,,,,,,,,,,,, -132100,0.3783411,1.3526627,,,,,,,,,,,,,,,,, -132200,0.36160654,1.4286213,,,,,,,,,,,,,,,,, -132300,0.3809633,1.4284133,,,,,,,,,,,,,,,,, -132400,0.38758188,1.4693725,,,,,,,,,,,,,,,,, -132500,0.38739407,1.4397211,,,,,,,,,,,,,,,,, -132600,0.37146786,1.3926191,,,,,,,,,,,,,,,,, -132700,0.3818008,1.4447565,,,,,,,,,,,,,,,,, -132800,0.3804525,1.4804387,,,,,,,,,,,,,,,,, -132900,0.40897277,1.504366,,,,,,,,,,,,,,,,, -133000,0.36801955,1.4524455,,,,,,,,,,,,,,,,, -133100,0.37276253,1.3911067,,,,,,,,,,,,,,,,, -133200,0.36985132,1.4089353,,,,,,,,,,,,,,,,, -133300,0.36738998,1.4891884,,,,,,,,,,,,,,,,, -133333,,,0.6962321400642395,1.4215978384017944,35.18021952446278,0.6931717991828918,1.4065862894058228,30.9194758232504,3000.0,0.7112892866134644,1.3105939626693726,31.039245815704213,3003.0,46060.09185528755,77296.29063916206,46060.09185528755,31230.08441901207,1.925365686416626,0.0 -133333,,,,,,,,,,,,,,46060.09185528755,,,,,0.0 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/eval_measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/eval_measurements.csv deleted file mode 100644 index 6bdb75373..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/eval_measurements.csv +++ /dev/null @@ -1,59 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -856.2280375957489,0.0,30.910383462905884,1,0,30.910383462905884,0.0007088489946909,0.0,11.036273956298828,3003,887.1384572982788,0.0005945574957877,0.0,11.024234771728516,0.0004835649742744,0.0,11.047277450561523,3000 -1357.2976994514463,0.0189738273620605,870.9926633834839,2373,0,870.9926633834839,0.5161815285682678,17.2187217021128,2.86415958404541,3003,2228.3837530612946,0.5159180164337158,23.051630204159792,2.8420560359954834,0.5170301795005798,18.71438425767787,2.82480525970459,3000 -1960.2570941448207,0.0446903705596923,1710.9601809978485,4747,0,1710.9601809978485,0.5934344530105591,22.19616409204832,2.123737335205078,3003,3671.411406755448,0.5766600966453552,26.82964297454152,2.2446744441986084,0.5912140011787415,23.472348577549567,2.1382405757904053,3000 -2380.467398405075,0.0704712867736816,2551.112315654754,7122,0,2551.112315654754,0.6211028099060059,24.172361442868134,1.8899255990982056,3003,4931.874216079712,0.604808509349823,29.179246592431355,2.0004196166992188,0.6162725687026978,25.08246901209177,1.923214077949524,3000 -2852.2883291244507,0.0963246822357177,3391.2784502506256,9497,0,3391.2784502506256,0.6391029357910156,25.708341384430515,1.7555793523788452,3003,6243.961631298065,0.6149968504905701,29.185273045229128,1.941268801689148,0.6318706274032593,26.343849434747742,1.8118021488189693,3000 -3292.704610824585,0.1238934993743896,4231.514441490173,11872,0,4231.514441490173,0.6488640904426575,25.705053088847816,1.6925368309020996,3003,7524.716588258743,0.6193379759788513,29.845714220736607,1.89149820804596,0.6386777758598328,26.896069609013946,1.7426763772964478,3000 -3784.979299783706,0.1510317325592041,5071.745926856995,14247,0,5071.745926856995,0.6537911891937256,26.303994280825968,1.640441656112671,3003,8857.325303792953,0.6294394135475159,30.450921466562807,1.81944739818573,0.6437985897064209,27.1843550621261,1.7072787284851074,3000 -4590.6215217113495,0.1793782711029052,5911.688598394394,16621,0,5911.688598394394,0.6579513549804688,26.439084023361502,1.6053942441940308,3003,10503.013573646544,0.6266187429428101,30.356420384125588,1.8285259008407595,0.6463403701782227,27.10367908047491,1.6791563034057615,3000 -5385.383625268936,0.2077019214630127,6751.716594457626,18996,0,6751.716594457626,0.6589855551719666,26.80562274957389,1.592655897140503,3003,12137.90736103058,0.6516609787940979,31.447995353884966,1.649769902229309,0.6500477194786072,26.475614046524264,1.6590960025787354,3000 -6065.267931461334,0.2369616031646728,7591.700333595276,21370,0,7591.700333595276,0.6615420579910278,26.81812092467978,1.5799243450164795,3003,13657.880095720291,0.6328192353248596,30.559165503155825,1.773112177848816,0.6511512398719788,27.1359312860208,1.648661971092224,3000 -6615.000554323196,0.2675762176513672,8431.671695709229,23744,0,8431.671695709229,0.6650747060775757,27.07595711213196,1.5590084791183472,3003,15047.690311908722,0.6332682967185974,31.001136692097017,1.783609390258789,0.6562844514846802,27.96646650715337,1.6264350414276123,3000 -7063.780653953552,0.2965688705444336,9271.90131354332,26119,0,9271.90131354332,0.6670966148376465,27.10684203663389,1.5462366342544556,3003,16336.804920434952,0.6439311504364014,31.45628583204521,1.6942683458328247,0.6556025147438049,27.50174142438485,1.624097228050232,3000 -7684.471136808395,0.3272538185119629,10112.068130731584,28493,0,10112.068130731584,0.6716286540031433,27.86535354891519,1.5281535387039185,3003,17797.767949342728,0.6390884518623352,31.07782237514248,1.7399277687072754,0.6567432284355164,27.80369912167485,1.6084965467453003,3000 -8169.727640390396,0.3598606586456299,10952.097707033156,30867,0,10952.097707033156,0.6713497638702393,27.451590240782444,1.5222526788711548,3003,19123.16144561768,0.6361994743347168,30.96582954608702,1.7502000331878662,0.6586031317710876,27.939569403470813,1.5999562740325928,3000 -8682.395342111588,0.3906929492950439,11792.311067819595,33242,0,11792.311067819595,0.671884298324585,27.42189268731702,1.5101147890090942,3003,20476.148945569992,0.6436529755592346,31.702394127417445,1.703749179840088,0.6596198081970215,28.28841359912601,1.5879870653152466,3000 -9139.705032587051,0.4213588237762451,12632.403628587725,35616,0,12632.403628587725,0.6744756698608398,27.731952103228185,1.5015976428985596,3003,21773.65692186356,0.641491711139679,31.370909674220155,1.723802089691162,0.6618516445159912,28.14534491270295,1.5850014686584473,3000 -9627.55661559105,0.4520916938781738,13472.459535121918,37990,0,13472.459535121918,0.6746964454650879,27.57462241443689,1.4958630800247192,3003,23101.66996979713,0.6566494703292847,32.589987116770835,1.6067266464233398,0.6636743545532227,28.37420393676452,1.575439691543579,3000 -10291.879784822464,0.4850766658782959,14312.40743112564,40364,0,14312.40743112564,0.6744988560676575,27.850430361041195,1.485736846923828,3003,24606.04882979393,0.6478399634361267,31.48029070214986,1.676770567893982,0.6622112393379211,28.142196735251872,1.57160747051239,3000 -10824.458010911942,0.5174057483673096,15152.398067235948,42738,0,15152.398067235948,0.6777177453041077,28.13747402749904,1.4755793809890747,3003,25978.72548937797,0.6446605920791626,31.174335951599705,1.6969213485717771,0.6645174622535706,28.429532815753745,1.5625895261764526,3000 -11379.908358573914,0.5488555431365967,15992.489179372787,45112,0,15992.489179372787,0.6789379119873047,28.37878554052869,1.4708216190338137,3003,27374.3735370636,0.6525201201438904,32.448058247862534,1.646553874015808,0.6664517521858215,28.67135209872841,1.550062656402588,3000 -11873.347818136215,0.5799081325531006,16832.455917835236,47487,0,16832.455917835236,0.6797629594802856,28.22441199348882,1.469280242919922,3003,28707.885778665543,0.6500367522239685,32.292675817863426,1.6660417318344116,0.6654598116874695,28.594837733209264,1.5579570531845093,3000 -12428.909444093704,0.6142618656158447,17672.611881494522,49862,0,17672.611881494522,0.6824705004692078,28.232030328801063,1.4507604837417605,3003,30103.71274662018,0.6516985297203064,31.26118472682359,1.6592146158218384,0.6688819527626038,29.092173467752627,1.5371198654174805,3000 -13037.063804626465,0.6480474472045898,18512.59501695633,52237,0,18512.59501695633,0.6822032332420349,28.419362421745703,1.4496662616729736,3003,31551.95904159546,0.651114284992218,32.15875856369312,1.653276443481445,0.670059859752655,29.08356031013785,1.5286074876785278,3000 -13580.379266500471,0.6796233654022217,19352.575419425964,54612,0,19352.575419425964,0.6830283403396606,28.666378067688967,1.4475644826889038,3003,32935.361577510834,0.6498084664344788,31.615262520847377,1.666841983795166,0.6697251200675964,28.79401454291601,1.5325440168380735,3000 -14104.548913478851,0.7127649784088135,20192.615739822388,56987,0,20192.615739822388,0.6824240684509277,28.52835011043656,1.442415714263916,3003,34299.67973899841,0.663896381855011,32.592228093350805,1.5628291368484497,0.6697623133659363,28.77605624550842,1.524965763092041,3000 -14885.285350084305,0.7453715801239014,21032.84838962555,59362,0,21032.84838962555,0.6865028142929077,27.1299610500768,1.425528049468994,3003,35920.757283210754,0.6547658443450928,32.32660768039863,1.639319896697998,0.6719445586204529,28.431021789399217,1.517433762550354,3000 -15399.552985191343,0.7798454761505127,21872.760452747345,61736,0,21872.760452747345,0.6850618720054626,28.778837217794965,1.4241111278533936,3003,37275.04721617699,0.6519191861152649,32.12458710119513,1.6469298601150513,0.673072874546051,29.256527875615028,1.5103808641433716,3000 -16075.27707862854,0.812446117401123,22712.73940300941,64110,0,22712.73940300941,0.6873162388801575,28.827913304890146,1.4133727550506592,3003,38790.858786821365,0.6602867245674133,32.69674771030805,1.5844924449920654,0.6743127703666687,29.04694521750138,1.5003536939620972,3000 -16765.60342478752,0.8476648330688477,23552.78756380081,66485,0,23552.78756380081,0.691046416759491,29.41162720824399,1.3991174697875977,3003,40321.3447868824,0.6578103303909302,32.83270106756466,1.6158467531204224,0.6757386922836304,29.64027073611295,1.4923642873764038,3000 -17284.506219387054,0.8824207782745361,24392.84364748001,68860,0,24392.84364748001,0.6896519660949707,29.14519622574132,1.3980993032455444,3003,41680.414498806,0.6925715208053589,35.164799300747305,1.3943202495574951,0.675986647605896,29.37226423439722,1.4907090663909912,3000 -17989.61869287491,0.9181523323059082,25233.028274297714,71235,0,25233.028274297714,0.6922782063484192,29.602084105542623,1.3883384466171265,3003,43225.82240843773,0.663612425327301,33.04904270814784,1.5702091455459597,0.6764578223228455,29.29003691272471,1.4874941110610962,3000 -18606.38701105117,0.956063747406006,26073.24313759804,73610,0,26073.24313759804,0.6927198171615601,29.329502689351862,1.3803156614303589,3003,44682.91942191124,0.6637750864028931,32.387814931952754,1.5700886249542236,0.6769413948059082,29.41112830125952,1.4775288105010986,3000 -19192.265601158146,0.9935698509216307,26913.46995139122,75985,0,26913.46995139122,0.6943234205245972,29.62003987298333,1.3734365701675415,3003,46109.13782072067,0.6715718507766724,33.4930723722878,1.5168931484222412,0.6797311902046204,29.64138507890422,1.4644559621810913,3000 -19693.95895600319,1.03102707862854,27753.6347835064,78359,0,27753.6347835064,0.6962175369262695,29.716893139832724,1.3679414987564087,3003,47451.10957407951,0.6638250946998596,32.52902754344788,1.5705655813217163,0.6803511381149292,29.507351771603417,1.4611443281173706,3000 -20216.756098032,1.0682952404022217,28593.58155298233,80733,0,28593.58155298233,0.6961246132850647,29.882335007945937,1.3614065647125244,3003,48813.9661796093,0.6645106077194214,33.2519705355767,1.572019338607788,0.6811570525169373,30.089135107250858,1.4543503522872925,3000 -20768.853723526,1.1046974658966064,29433.65891289711,83108,0,29433.65891289711,0.6982976198196411,30.016736440843538,1.349581003189087,3003,50206.2528424263,0.6732177138328552,33.89015855149095,1.5030677318572998,0.6820250153541565,29.56513048449141,1.4465928077697754,3000 -21316.877049922943,1.142566204071045,30273.729954242703,85483,0,30273.729954242703,0.698448657989502,29.82206997277747,1.3438206911087036,3003,51594.460906744,0.6708388328552246,33.3104744951593,1.520539402961731,0.6837112903594971,30.142586435709703,1.4417012929916382,3000 -21922.486659526825,1.18042254447937,31113.74466848373,87858,0,31113.74466848373,0.698506772518158,29.97895945307208,1.3392562866210938,3003,53040.19746589661,0.6888561248779297,34.55736403714124,1.4054269790649414,0.6858067512512207,29.992162601588777,1.4331176280975342,3000 -22616.300694704056,1.2178633213043213,31953.942955493927,90234,0,31953.942955493927,0.7016443014144897,29.90019334184747,1.3300366401672363,3003,54574.32289242744,0.6757356524467468,33.67554884277297,1.4885623455047607,0.68563312292099,29.696829275004344,1.4297630786895752,3000 -23154.02640128136,1.2572076320648191,32794.13788151741,92610,0,32794.13788151741,0.7028412222862244,30.292554115791944,1.321674346923828,3003,55952.35843801498,0.6737697720527649,33.71377058019944,1.50662362575531,0.6864514946937561,30.020258720494507,1.4230599403381348,3000 -23737.11946773529,1.2959558963775637,33634.26413846016,94985,0,33634.26413846016,0.7035384774208069,30.436514797787588,1.316325306892395,3003,57375.692895412445,0.6820225715637207,34.13055095861968,1.4462394714355469,0.6877161860466003,30.236845114728684,1.4159722328186035,3000 -24358.108916282654,1.3352117538452148,34474.2428150177,97360,0,34474.2428150177,0.7068619132041931,30.64499714464644,1.306337833404541,3003,58836.77590274811,0.6778927445411682,33.92635535254963,1.4817765951156616,0.6891545057296753,30.186419436898863,1.410452127456665,3000 -24884.105031967163,1.375361442565918,35314.15435934067,99735,0,35314.15435934067,0.7069200277328491,30.6400977790549,1.2959305047988892,3003,60202.79884767532,0.6772194504737854,33.94047766861321,1.4831637144088743,0.6899604201316833,30.348874012448068,1.4045439958572388,3000 -25444.19984698296,1.414492130279541,36154.0476911068,102110,0,36154.0476911068,0.7071989178657532,30.771823174514257,1.2938815355300903,3003,61602.90172743797,0.6823187470436096,34.765234065491775,1.446167230606079,0.6899108290672302,30.40090160189428,1.3995331525802612,3000 -26074.819528579712,1.4544637203216553,36994.13355565071,104486,0,36994.13355565071,0.70707106590271,30.53443950681672,1.2888002395629885,3003,63073.723007678986,0.6826587319374084,34.46536235091862,1.4512836933135986,0.6910763382911682,30.427187187595266,1.3928955793380735,3000 -26778.89575409889,1.4950628280639648,37834.16900777817,106862,0,37834.16900777817,0.7095810770988464,30.553318604965003,1.2826364040374756,3003,64617.951558828354,0.6929945349693298,35.6936682202158,1.3875455856323242,0.6926262378692627,30.321860371372377,1.3891054391860962,3000 -27396.39172673225,1.5361199378967283,38674.22744607925,109238,0,38674.22744607925,0.7090697884559631,30.7142415724196,1.2797530889511108,3003,66075.62375807762,0.6881809234619141,34.792418856647195,1.4110008478164673,0.6915723085403442,30.53146184710506,1.3885592222213743,3000 -27934.165376663208,1.5760877132415771,39514.40298914909,111614,0,39514.40298914909,0.7106850147247314,30.786169768048005,1.2727055549621582,3003,67453.68859291077,0.6850574612617493,34.64420152808263,1.43190598487854,0.6938413381576538,30.713396510455496,1.3793723583221436,3000 -28590.395575761795,1.6191487312316897,40354.49650359154,113990,0,40354.49650359154,0.711661159992218,31.040662341172386,1.2692705392837524,3003,68950.13110136986,0.6922797560691833,35.342772251239296,1.3891639709472656,0.6943125128746033,30.757281576648342,1.3756449222564695,3000 -29159.84848332405,1.660917043685913,41194.70844531059,116366,0,41194.70844531059,0.7128929495811462,31.160097992334794,1.264074683189392,3003,70359.9132270813,0.6906039714813232,34.79750781778418,1.3968476057052612,0.6947464942932129,30.66973494550261,1.373445987701416,3000 -29691.103921175003,1.7029705047607422,42034.721598386765,118741,0,42034.721598386765,0.7127767205238342,30.7693335450808,1.262351393699646,3003,71731.30040001869,0.6908093690872192,34.89803540192062,1.4008818864822388,0.6953044533729553,30.95589799049646,1.3693207502365112,3000 -30304.17706489563,1.744511365890503,42874.60907793045,121116,0,42874.60907793045,0.7138457894325256,31.12364558456364,1.2584335803985596,3003,73184.37866711617,0.6932882070541382,35.52039251433238,1.388006567955017,0.6962592005729675,30.874046124594287,1.367598533630371,3000 -30837.32000207901,1.7871100902557373,43714.75734376907,123491,0,43714.75734376907,0.7139503955841064,31.129922784276456,1.255860447883606,3003,74557.78880643845,0.6912440061569214,35.35323909995653,1.395629644393921,0.6952672600746155,30.84424959213709,1.367558479309082,3000 -31382.90145254135,1.8318068981170648,44554.68297743797,125866,0,44554.68297743797,0.7137179970741272,30.93043914065589,1.25685453414917,3003,75943.4162247181,0.697185218334198,35.78074525428276,1.3619694709777832,0.695651650428772,30.911010119255785,1.3671724796295166,3000 -31950.483147382736,1.8745992183685305,45394.836584329605,128241,0,45394.836584329605,0.7137528657913208,31.12628625505224,1.2547988891601562,3003,77351.26991295815,0.6922347545623779,35.30423628524667,1.3835301399230957,0.6954284310340881,30.88262735410851,1.3652565479278564,3000 -32513.490881443024,1.928779125213623,46234.92179393768,130616,0,46234.92179393768,0.7141479253768921,31.103261886181247,1.2541664838790894,3003,78754.49300599098,0.6963204741477966,35.57011387339867,1.366992473602295,0.6954284310340881,30.946973319770205,1.3648079633712769,3000 -33078.891793727875,1.9719855785369875,47074.85367107391,132991,0,47074.85367107391,0.7140085101127625,31.14580943345989,1.2541321516036987,3003,80159.94454312325,0.69627445936203,35.58865724493143,1.36464262008667,0.6954532265663147,31.040501645749234,1.3649661540985107,3000 -33632.739612579346,2.015644073486328,47195.332033634186,133333,0,47195.332033634186,0.7140317559242249,31.13090825160175,1.2541394233703613,3003,80834.32544207573,0.6945627331733704,35.631271648808266,1.3744174242019653,0.6954532265663147,31.040501645749234,1.3649758100509644,3000 diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/measurements.csv b/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/measurements.csv deleted file mode 100644 index 31a24ff22..000000000 --- a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/measurements.csv +++ /dev/null @@ -1,1394 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.789248,11.021477,,,,,,,,,,,,,,,,, -1,,,0.0005945574957877,11.024234771728516,0.0,0.0004835649742744,11.047277450561523,0.0,3000.0,0.0007088489946909,11.036273956298828,0.0,3003.0,30.910383462905884,887.1384572982788,30.910383462905884,856.2280375957489,0.0,0.0 -100,0.14961265,8.2489805,,,,,,,,,,,,,,,,, -200,0.49750254,7.4857616,,,,,,,,,,,,,,,,, -300,0.37088782,6.8687816,,,,,,,,,,,,,,,,, -400,0.5543487,6.3305583,,,,,,,,,,,,,,,,, -500,0.44564283,5.8620815,,,,,,,,,,,,,,,,, -600,0.45984927,5.4946704,,,,,,,,,,,,,,,,, -700,0.41917625,5.2905293,,,,,,,,,,,,,,,,, -800,0.37636402,5.005205,,,,,,,,,,,,,,,,, -900,0.44796506,4.816192,,,,,,,,,,,,,,,,, -1000,0.50835365,4.592582,,,,,,,,,,,,,,,,, -1100,0.5439562,4.348014,,,,,,,,,,,,,,,,, -1200,0.49261355,4.0208535,,,,,,,,,,,,,,,,, -1300,0.48109898,4.002508,,,,,,,,,,,,,,,,, -1400,0.5133812,3.7716956,,,,,,,,,,,,,,,,, -1500,0.7034692,3.6612282,,,,,,,,,,,,,,,,, -1600,0.4468564,3.5535014,,,,,,,,,,,,,,,,, -1700,0.43921065,3.5051796,,,,,,,,,,,,,,,,, -1800,0.4126761,3.2649646,,,,,,,,,,,,,,,,, -1900,0.6610724,3.3004315,,,,,,,,,,,,,,,,, -2000,0.44426006,3.2287047,,,,,,,,,,,,,,,,, -2100,0.37322998,3.0162947,,,,,,,,,,,,,,,,, -2200,0.36876976,3.0847933,,,,,,,,,,,,,,,,, -2300,0.41831237,2.966548,,,,,,,,,,,,,,,,, -2373,,,0.5159180164337158,2.8420560359954834,23.051630204159792,0.5170301795005798,2.82480525970459,18.71438425767787,3000.0,0.5161815285682678,2.86415958404541,17.2187217021128,3003.0,870.9926633834839,2228.3837530612946,870.9926633834839,1357.2976994514463,0.0189738273620605,0.0 -2400,0.31477544,2.9174755,,,,,,,,,,,,,,,,, -2500,0.33088294,2.922921,,,,,,,,,,,,,,,,, -2600,0.36288112,2.7858486,,,,,,,,,,,,,,,,, -2700,0.2940709,2.7733533,,,,,,,,,,,,,,,,, -2800,0.3403575,2.7842765,,,,,,,,,,,,,,,,, -2900,0.29913858,2.651689,,,,,,,,,,,,,,,,, -3000,0.28800786,2.705136,,,,,,,,,,,,,,,,, -3100,0.43049085,2.6926663,,,,,,,,,,,,,,,,, -3200,0.23656705,2.6127796,,,,,,,,,,,,,,,,, -3300,0.22881928,2.5495656,,,,,,,,,,,,,,,,, -3400,0.25294003,2.5916333,,,,,,,,,,,,,,,,, -3500,0.2623698,2.4712927,,,,,,,,,,,,,,,,, -3600,0.22242647,2.5604162,,,,,,,,,,,,,,,,, -3700,0.1950279,2.459054,,,,,,,,,,,,,,,,, -3800,0.22752015,2.4616568,,,,,,,,,,,,,,,,, -3900,0.21709628,2.4427526,,,,,,,,,,,,,,,,, -4000,0.19905081,2.4844427,,,,,,,,,,,,,,,,, -4100,0.18598545,2.3602378,,,,,,,,,,,,,,,,, -4200,0.21281183,2.452756,,,,,,,,,,,,,,,,, -4300,0.21385121,2.3561966,,,,,,,,,,,,,,,,, -4400,0.17305836,2.3046374,,,,,,,,,,,,,,,,, -4500,0.18857358,2.2713342,,,,,,,,,,,,,,,,, -4600,0.21143731,2.3261926,,,,,,,,,,,,,,,,, -4700,0.15674394,2.267055,,,,,,,,,,,,,,,,, -4747,,,0.5766600966453552,2.2446744441986084,26.82964297454152,0.5912140011787415,2.1382405757904053,23.472348577549567,3000.0,0.5934344530105591,2.123737335205078,22.19616409204832,3003.0,1710.9601809978485,3671.411406755448,1710.9601809978485,1960.2570941448207,0.0446903705596923,0.0 -4800,0.18502556,2.4093492,,,,,,,,,,,,,,,,, -4900,0.1840897,2.2838626,,,,,,,,,,,,,,,,, -5000,0.16584378,2.259663,,,,,,,,,,,,,,,,, -5100,0.1654749,2.248595,,,,,,,,,,,,,,,,, -5200,0.15216047,2.2985816,,,,,,,,,,,,,,,,, -5300,0.17906602,2.240504,,,,,,,,,,,,,,,,, -5400,0.18422033,2.2472074,,,,,,,,,,,,,,,,, -5500,0.16699788,2.2718418,,,,,,,,,,,,,,,,, -5600,0.17937298,2.2175317,,,,,,,,,,,,,,,,, -5700,0.15680915,2.2091453,,,,,,,,,,,,,,,,, -5800,0.1735439,2.1903887,,,,,,,,,,,,,,,,, -5900,0.1785709,2.1723738,,,,,,,,,,,,,,,,, -6000,0.17097437,2.2260342,,,,,,,,,,,,,,,,, -6100,0.17922443,2.2684467,,,,,,,,,,,,,,,,, -6200,0.16670538,2.2620368,,,,,,,,,,,,,,,,, -6300,0.16550042,2.2504928,,,,,,,,,,,,,,,,, -6400,0.16303636,2.1735241,,,,,,,,,,,,,,,,, -6500,0.1590777,2.1539953,,,,,,,,,,,,,,,,, -6600,0.14954276,2.109645,,,,,,,,,,,,,,,,, -6700,0.28742653,2.0747626,,,,,,,,,,,,,,,,, -6800,0.19393384,2.1401832,,,,,,,,,,,,,,,,, -6900,0.1654235,2.1246245,,,,,,,,,,,,,,,,, -7000,0.18360789,2.2380733,,,,,,,,,,,,,,,,, -7100,0.1645137,2.0854483,,,,,,,,,,,,,,,,, -7122,,,0.604808509349823,2.0004196166992188,29.179246592431355,0.6162725687026978,1.923214077949524,25.08246901209177,3000.0,0.6211028099060059,1.8899255990982056,24.172361442868134,3003.0,2551.112315654754,4931.874216079712,2551.112315654754,2380.467398405075,0.0704712867736816,0.0 -7200,0.16514836,2.0674431,,,,,,,,,,,,,,,,, -7300,0.15799822,2.0267937,,,,,,,,,,,,,,,,, -7400,0.1511762,2.1242342,,,,,,,,,,,,,,,,, -7500,0.15119208,2.0810223,,,,,,,,,,,,,,,,, -7600,0.16214038,2.117786,,,,,,,,,,,,,,,,, -7700,0.187014,2.132271,,,,,,,,,,,,,,,,, -7800,0.15769865,1.9718187,,,,,,,,,,,,,,,,, -7900,0.20558146,2.1299791,,,,,,,,,,,,,,,,, -8000,0.20916577,2.1165824,,,,,,,,,,,,,,,,, -8100,0.16368024,2.1164227,,,,,,,,,,,,,,,,, -8200,0.16679746,2.06037,,,,,,,,,,,,,,,,, -8300,0.20072113,2.0509102,,,,,,,,,,,,,,,,, -8400,0.1720548,2.140534,,,,,,,,,,,,,,,,, -8500,0.20518512,2.096986,,,,,,,,,,,,,,,,, -8600,0.17382963,2.0092225,,,,,,,,,,,,,,,,, -8700,0.16309695,1.9876567,,,,,,,,,,,,,,,,, -8800,0.2614856,2.087563,,,,,,,,,,,,,,,,, -8900,0.22488357,2.0279593,,,,,,,,,,,,,,,,, -9000,0.15836121,2.1428163,,,,,,,,,,,,,,,,, -9100,0.17458647,2.0366497,,,,,,,,,,,,,,,,, -9200,0.18500656,2.1646814,,,,,,,,,,,,,,,,, -9300,0.1861779,2.0380404,,,,,,,,,,,,,,,,, -9400,0.2431181,2.1116016,,,,,,,,,,,,,,,,, -9497,,,0.6149968504905701,1.941268801689148,29.185273045229128,0.6318706274032593,1.8118021488189693,26.343849434747742,3000.0,0.6391029357910156,1.7555793523788452,25.708341384430515,3003.0,3391.2784502506256,6243.961631298065,3391.2784502506256,2852.2883291244507,0.0963246822357177,0.0 -9500,0.17756617,2.0143025,,,,,,,,,,,,,,,,, -9600,0.1900951,2.0276513,,,,,,,,,,,,,,,,, -9700,0.18527867,1.968916,,,,,,,,,,,,,,,,, -9800,0.20585753,2.00304,,,,,,,,,,,,,,,,, -9900,0.2213966,1.94209,,,,,,,,,,,,,,,,, -10000,0.19302124,2.0384567,,,,,,,,,,,,,,,,, -10100,0.1837676,1.9664214,,,,,,,,,,,,,,,,, -10200,0.24505116,2.0380595,,,,,,,,,,,,,,,,, -10300,0.15961564,1.9425073,,,,,,,,,,,,,,,,, -10400,0.17708902,1.9657127,,,,,,,,,,,,,,,,, -10500,0.17204642,2.0408506,,,,,,,,,,,,,,,,, -10600,0.16877966,1.9242203,,,,,,,,,,,,,,,,, -10700,0.19986898,1.9877654,,,,,,,,,,,,,,,,, -10800,0.19205841,2.0285687,,,,,,,,,,,,,,,,, -10900,0.20117208,2.0305467,,,,,,,,,,,,,,,,, -11000,0.17895518,2.0736337,,,,,,,,,,,,,,,,, -11100,0.18811576,1.9286697,,,,,,,,,,,,,,,,, -11200,0.20313278,2.002949,,,,,,,,,,,,,,,,, -11300,0.19593415,1.9704175,,,,,,,,,,,,,,,,, -11400,0.17860356,1.9232596,,,,,,,,,,,,,,,,, -11500,0.18646282,2.044755,,,,,,,,,,,,,,,,, -11600,0.20085919,2.0520194,,,,,,,,,,,,,,,,, -11700,0.29331222,1.913032,,,,,,,,,,,,,,,,, -11800,0.21708469,2.0282235,,,,,,,,,,,,,,,,, -11872,,,0.6193379759788513,1.89149820804596,29.845714220736607,0.6386777758598328,1.7426763772964478,26.896069609013946,3000.0,0.6488640904426575,1.6925368309020996,25.705053088847816,3003.0,4231.514441490173,7524.716588258743,4231.514441490173,3292.704610824585,0.1238934993743896,0.0 -11900,0.23911804,1.9495941,,,,,,,,,,,,,,,,, -12000,0.1949064,2.0380774,,,,,,,,,,,,,,,,, -12100,0.27447575,1.9193214,,,,,,,,,,,,,,,,, -12200,0.18060258,1.9819467,,,,,,,,,,,,,,,,, -12300,0.258412,2.0236373,,,,,,,,,,,,,,,,, -12400,0.18338649,1.9436365,,,,,,,,,,,,,,,,, -12500,0.2516966,1.9630642,,,,,,,,,,,,,,,,, -12600,0.19437692,1.8925228,,,,,,,,,,,,,,,,, -12700,0.19708855,1.87601,,,,,,,,,,,,,,,,, -12800,0.2007995,1.9733857,,,,,,,,,,,,,,,,, -12900,0.18276294,1.9399232,,,,,,,,,,,,,,,,, -13000,0.23044349,1.9152753,,,,,,,,,,,,,,,,, -13100,0.17860268,1.9591428,,,,,,,,,,,,,,,,, -13200,0.2555529,1.9587977,,,,,,,,,,,,,,,,, -13300,0.2089765,1.9604841,,,,,,,,,,,,,,,,, -13400,0.17950818,1.99238,,,,,,,,,,,,,,,,, -13500,0.2578282,1.9049062,,,,,,,,,,,,,,,,, -13600,1.8883461,2.178062,,,,,,,,,,,,,,,,, -13700,0.20074905,1.9301404,,,,,,,,,,,,,,,,, -13800,0.20259339,1.9582466,,,,,,,,,,,,,,,,, -13900,0.24975757,2.0239925,,,,,,,,,,,,,,,,, -14000,0.2274547,1.8591266,,,,,,,,,,,,,,,,, -14100,0.17408043,1.9222183,,,,,,,,,,,,,,,,, -14200,0.20513886,1.8875568,,,,,,,,,,,,,,,,, -14247,,,0.6294394135475159,1.81944739818573,30.450921466562807,0.6437985897064209,1.7072787284851074,27.1843550621261,3000.0,0.6537911891937256,1.640441656112671,26.303994280825968,3003.0,5071.745926856995,8857.325303792953,5071.745926856995,3784.979299783706,0.1510317325592041,0.0 -14300,0.16481468,1.9086802,,,,,,,,,,,,,,,,, -14400,0.17857587,1.9553167,,,,,,,,,,,,,,,,, -14500,0.21263568,1.9320364,,,,,,,,,,,,,,,,, -14600,0.23951365,2.0143087,,,,,,,,,,,,,,,,, -14700,0.19371498,1.9014686,,,,,,,,,,,,,,,,, -14800,0.20846286,1.9372805,,,,,,,,,,,,,,,,, -14900,0.27416623,1.8823726,,,,,,,,,,,,,,,,, -15000,0.32678828,1.8211492,,,,,,,,,,,,,,,,, -15100,0.20821586,1.9749563,,,,,,,,,,,,,,,,, -15200,0.18522897,1.938214,,,,,,,,,,,,,,,,, -15300,0.19015567,1.9248098,,,,,,,,,,,,,,,,, -15400,1.1835932,1.8450587,,,,,,,,,,,,,,,,, -15500,5.871781,1.9443237,,,,,,,,,,,,,,,,, -15600,0.21096347,1.9198539,,,,,,,,,,,,,,,,, -15700,0.20247096,1.8939879,,,,,,,,,,,,,,,,, -15800,0.18161616,1.8257328,,,,,,,,,,,,,,,,, -15900,0.20276147,1.9310663,,,,,,,,,,,,,,,,, -16000,0.35384834,1.8993627,,,,,,,,,,,,,,,,, -16100,0.22017612,1.9908853,,,,,,,,,,,,,,,,, -16200,0.3126064,1.8781426,,,,,,,,,,,,,,,,, -16300,0.19114083,1.8641557,,,,,,,,,,,,,,,,, -16400,0.26243278,1.812652,,,,,,,,,,,,,,,,, -16500,0.20759553,1.8761971,,,,,,,,,,,,,,,,, -16600,0.22499849,1.882858,,,,,,,,,,,,,,,,, -16621,,,0.6266187429428101,1.8285259008407595,30.356420384125588,0.6463403701782227,1.6791563034057615,27.10367908047491,3000.0,0.6579513549804688,1.6053942441940308,26.439084023361502,3003.0,5911.688598394394,10503.013573646544,5911.688598394394,4590.6215217113495,0.1793782711029052,0.0 -16700,0.21456571,1.8730488,,,,,,,,,,,,,,,,, -16800,0.2549689,1.9073212,,,,,,,,,,,,,,,,, -16900,0.23451488,1.8846382,,,,,,,,,,,,,,,,, -17000,0.1947665,1.8885746,,,,,,,,,,,,,,,,, -17100,0.19514124,1.921965,,,,,,,,,,,,,,,,, -17200,0.21556665,1.8477159,,,,,,,,,,,,,,,,, -17300,0.21433078,1.9220008,,,,,,,,,,,,,,,,, -17400,0.25825956,1.8737861,,,,,,,,,,,,,,,,, -17500,0.2311553,1.8942378,,,,,,,,,,,,,,,,, -17600,0.25373682,1.9369383,,,,,,,,,,,,,,,,, -17700,0.21936525,1.8724593,,,,,,,,,,,,,,,,, -17800,0.20866892,1.8546133,,,,,,,,,,,,,,,,, -17900,0.18698533,1.9012085,,,,,,,,,,,,,,,,, -18000,0.20681748,1.881174,,,,,,,,,,,,,,,,, -18100,0.29718143,1.9233395,,,,,,,,,,,,,,,,, -18200,0.20021473,1.8808429,,,,,,,,,,,,,,,,, -18300,0.22068016,1.9373043,,,,,,,,,,,,,,,,, -18400,0.21142563,1.9145814,,,,,,,,,,,,,,,,, -18500,0.2184753,1.9165158,,,,,,,,,,,,,,,,, -18600,0.22382313,1.8525124,,,,,,,,,,,,,,,,, -18700,0.20495939,1.8466718,,,,,,,,,,,,,,,,, -18800,0.25936893,1.9622891,,,,,,,,,,,,,,,,, -18900,0.19860895,1.7725132,,,,,,,,,,,,,,,,, -18996,,,0.6516609787940979,1.649769902229309,31.447995353884966,0.6500477194786072,1.6590960025787354,26.475614046524264,3000.0,0.6589855551719666,1.592655897140503,26.80562274957389,3003.0,6751.716594457626,12137.90736103058,6751.716594457626,5385.383625268936,0.2077019214630127,0.0 -19000,0.2242668,1.8461517,,,,,,,,,,,,,,,,, -19100,0.29742852,1.9011891,,,,,,,,,,,,,,,,, -19200,0.19443874,1.8630104,,,,,,,,,,,,,,,,, -19300,0.21442817,1.845603,,,,,,,,,,,,,,,,, -19400,0.20407163,1.8940126,,,,,,,,,,,,,,,,, -19500,0.21469216,1.8625402,,,,,,,,,,,,,,,,, -19600,0.22793402,1.8352857,,,,,,,,,,,,,,,,, -19700,0.28353885,1.8976716,,,,,,,,,,,,,,,,, -19800,0.1981632,1.9722832,,,,,,,,,,,,,,,,, -19900,0.18773128,1.8844305,,,,,,,,,,,,,,,,, -20000,0.22946805,1.8882579,,,,,,,,,,,,,,,,, -20100,0.26532206,1.8294047,,,,,,,,,,,,,,,,, -20200,1.2498246,2.172716,,,,,,,,,,,,,,,,, -20300,0.17536132,1.8794485,,,,,,,,,,,,,,,,, -20400,0.18859154,1.9222277,,,,,,,,,,,,,,,,, -20500,0.21469688,1.8686923,,,,,,,,,,,,,,,,, -20600,0.23351716,1.9047897,,,,,,,,,,,,,,,,, -20700,0.18038768,1.860484,,,,,,,,,,,,,,,,, -20800,0.27811006,1.8395953,,,,,,,,,,,,,,,,, -20900,0.18878028,1.780382,,,,,,,,,,,,,,,,, -21000,0.21924937,1.8092949,,,,,,,,,,,,,,,,, -21100,0.21100302,1.8180635,,,,,,,,,,,,,,,,, -21200,0.2874798,1.8523215,,,,,,,,,,,,,,,,, -21300,0.2228853,1.858904,,,,,,,,,,,,,,,,, -21370,,,0.6328192353248596,1.773112177848816,30.559165503155825,0.6511512398719788,1.648661971092224,27.1359312860208,3000.0,0.6615420579910278,1.5799243450164795,26.81812092467978,3003.0,7591.700333595276,13657.880095720291,7591.700333595276,6065.267931461334,0.2369616031646728,0.0 -21400,0.22740848,1.8702017,,,,,,,,,,,,,,,,, -21500,0.20291558,1.8722634,,,,,,,,,,,,,,,,, -21600,0.1983783,1.8394719,,,,,,,,,,,,,,,,, -21700,0.20055293,1.7966944,,,,,,,,,,,,,,,,, -21800,0.19621374,1.8517035,,,,,,,,,,,,,,,,, -21900,0.2047592,1.9307797,,,,,,,,,,,,,,,,, -22000,0.20218396,1.9009573,,,,,,,,,,,,,,,,, -22100,0.2558958,1.8973315,,,,,,,,,,,,,,,,, -22200,0.28072643,1.9432793,,,,,,,,,,,,,,,,, -22300,0.23903601,1.7590557,,,,,,,,,,,,,,,,, -22400,0.22390646,1.9041107,,,,,,,,,,,,,,,,, -22500,0.23986104,1.8933607,,,,,,,,,,,,,,,,, -22600,0.2080774,1.8098636,,,,,,,,,,,,,,,,, -22700,0.18856643,1.885862,,,,,,,,,,,,,,,,, -22800,0.1867493,1.82828,,,,,,,,,,,,,,,,, -22900,0.19472267,1.8492676,,,,,,,,,,,,,,,,, -23000,0.17982401,1.9392788,,,,,,,,,,,,,,,,, -23100,0.20266947,1.8720514,,,,,,,,,,,,,,,,, -23200,0.2067645,1.818778,,,,,,,,,,,,,,,,, -23300,0.30888456,1.8455893,,,,,,,,,,,,,,,,, -23400,0.18481341,1.8494729,,,,,,,,,,,,,,,,, -23500,0.27542755,1.8417394,,,,,,,,,,,,,,,,, -23600,0.2048139,1.8353662,,,,,,,,,,,,,,,,, -23700,0.21295273,1.7887092,,,,,,,,,,,,,,,,, -23744,,,0.6332682967185974,1.783609390258789,31.001136692097017,0.6562844514846802,1.6264350414276123,27.96646650715337,3000.0,0.6650747060775757,1.5590084791183472,27.07595711213196,3003.0,8431.671695709229,15047.690311908722,8431.671695709229,6615.000554323196,0.2675762176513672,0.0 -23800,0.21968232,1.8655238,,,,,,,,,,,,,,,,, -23900,0.20509209,1.9396483,,,,,,,,,,,,,,,,, -24000,0.19819432,1.8015609,,,,,,,,,,,,,,,,, -24100,0.20665269,1.8797309,,,,,,,,,,,,,,,,, -24200,0.24047427,1.8777021,,,,,,,,,,,,,,,,, -24300,0.20430727,1.87371,,,,,,,,,,,,,,,,, -24400,0.25596747,1.8597445,,,,,,,,,,,,,,,,, -24500,0.23653315,1.8508157,,,,,,,,,,,,,,,,, -24600,0.19060993,1.8646436,,,,,,,,,,,,,,,,, -24700,0.20976603,1.8448894,,,,,,,,,,,,,,,,, -24800,0.21218032,1.8629777,,,,,,,,,,,,,,,,, -24900,0.2212543,1.8153104,,,,,,,,,,,,,,,,, -25000,0.22208795,1.7943751,,,,,,,,,,,,,,,,, -25100,0.22621647,1.8227948,,,,,,,,,,,,,,,,, -25200,0.24241582,1.8445005,,,,,,,,,,,,,,,,, -25300,0.22095981,1.9309037,,,,,,,,,,,,,,,,, -25400,0.21027093,1.8197248,,,,,,,,,,,,,,,,, -25500,0.20898515,1.7389332,,,,,,,,,,,,,,,,, -25600,0.20497395,1.7741549,,,,,,,,,,,,,,,,, -25700,0.21076994,1.8712354,,,,,,,,,,,,,,,,, -25800,0.2186018,1.9756242,,,,,,,,,,,,,,,,, -25900,0.23321788,1.7525246,,,,,,,,,,,,,,,,, -26000,0.18924156,1.753227,,,,,,,,,,,,,,,,, -26100,0.20785786,1.7809583,,,,,,,,,,,,,,,,, -26119,,,0.6439311504364014,1.6942683458328247,31.45628583204521,0.6556025147438049,1.624097228050232,27.50174142438485,3000.0,0.6670966148376465,1.5462366342544556,27.10684203663389,3003.0,9271.90131354332,16336.804920434952,9271.90131354332,7063.780653953552,0.2965688705444336,0.0 -26200,0.17998385,1.8691852,,,,,,,,,,,,,,,,, -26300,0.20336725,1.8070164,,,,,,,,,,,,,,,,, -26400,0.19203922,1.8689227,,,,,,,,,,,,,,,,, -26500,0.21692325,1.8316356,,,,,,,,,,,,,,,,, -26600,0.20756102,1.8883498,,,,,,,,,,,,,,,,, -26700,0.2083464,1.830601,,,,,,,,,,,,,,,,, -26800,0.21990773,1.8491927,,,,,,,,,,,,,,,,, -26900,0.21918486,1.8976055,,,,,,,,,,,,,,,,, -27000,0.22806029,1.7569746,,,,,,,,,,,,,,,,, -27100,0.83477026,1.7868294,,,,,,,,,,,,,,,,, -27200,0.19053178,1.7834231,,,,,,,,,,,,,,,,, -27300,0.29408118,1.8812984,,,,,,,,,,,,,,,,, -27400,0.2466785,1.7363558,,,,,,,,,,,,,,,,, -27500,0.21191815,1.8658942,,,,,,,,,,,,,,,,, -27600,0.18150683,1.8320627,,,,,,,,,,,,,,,,, -27700,0.21109451,1.8083168,,,,,,,,,,,,,,,,, -27800,0.2185209,1.849748,,,,,,,,,,,,,,,,, -27900,0.19956039,1.8102503,,,,,,,,,,,,,,,,, -28000,0.23654138,1.8415654,,,,,,,,,,,,,,,,, -28100,0.22297008,1.8171554,,,,,,,,,,,,,,,,, -28200,0.25709844,1.8222405,,,,,,,,,,,,,,,,, -28300,0.21632883,1.7972478,,,,,,,,,,,,,,,,, -28400,2.2485163,1.9057356,,,,,,,,,,,,,,,,, -28493,,,0.6390884518623352,1.7399277687072754,31.07782237514248,0.6567432284355164,1.6084965467453003,27.80369912167485,3000.0,0.6716286540031433,1.5281535387039185,27.86535354891519,3003.0,10112.068130731584,17797.767949342728,10112.068130731584,7684.471136808395,0.3272538185119629,0.0 -28500,0.18364374,1.8672329,,,,,,,,,,,,,,,,, -28600,0.19147114,1.8428569,,,,,,,,,,,,,,,,, -28700,0.20789778,1.8566133,,,,,,,,,,,,,,,,, -28800,0.1993187,1.8490287,,,,,,,,,,,,,,,,, -28900,0.19339113,1.8377906,,,,,,,,,,,,,,,,, -29000,0.22699553,1.9748673,,,,,,,,,,,,,,,,, -29100,0.2501487,1.797816,,,,,,,,,,,,,,,,, -29200,0.23953132,1.7983799,,,,,,,,,,,,,,,,, -29300,0.20952094,1.8258,,,,,,,,,,,,,,,,, -29400,0.21437885,1.8252064,,,,,,,,,,,,,,,,, -29500,0.19239244,1.7785465,,,,,,,,,,,,,,,,, -29600,0.19404304,1.7866571,,,,,,,,,,,,,,,,, -29700,0.21001074,1.8628383,,,,,,,,,,,,,,,,, -29800,0.3099407,1.8259739,,,,,,,,,,,,,,,,, -29900,0.21155263,1.7731551,,,,,,,,,,,,,,,,, -30000,0.20704879,1.787321,,,,,,,,,,,,,,,,, -30100,0.19316207,1.7426101,,,,,,,,,,,,,,,,, -30200,0.24835823,1.8840194,,,,,,,,,,,,,,,,, -30300,0.2833877,1.77991,,,,,,,,,,,,,,,,, -30400,0.19797894,1.7656871,,,,,,,,,,,,,,,,, -30500,0.22778891,1.8446951,,,,,,,,,,,,,,,,, -30600,0.23481846,1.8116906,,,,,,,,,,,,,,,,, -30700,0.22164294,1.7776217,,,,,,,,,,,,,,,,, -30800,0.24081989,1.8233149,,,,,,,,,,,,,,,,, -30867,,,0.6361994743347168,1.7502000331878662,30.96582954608702,0.6586031317710876,1.5999562740325928,27.939569403470813,3000.0,0.6713497638702393,1.5222526788711548,27.451590240782444,3003.0,10952.097707033156,19123.16144561768,10952.097707033156,8169.727640390396,0.3598606586456299,0.0 -30900,0.19905247,1.6815281,,,,,,,,,,,,,,,,, -31000,0.5853167,1.7845646,,,,,,,,,,,,,,,,, -31100,0.2237701,1.7655495,,,,,,,,,,,,,,,,, -31200,5.3653393,1.7542428,,,,,,,,,,,,,,,,, -31300,0.31367353,1.8239496,,,,,,,,,,,,,,,,, -31400,0.20051415,1.7205691,,,,,,,,,,,,,,,,, -31500,0.2931256,1.7832441,,,,,,,,,,,,,,,,, -31600,0.21597269,1.8367621,,,,,,,,,,,,,,,,, -31700,0.1991923,1.7970607,,,,,,,,,,,,,,,,, -31800,0.21247372,1.8369102,,,,,,,,,,,,,,,,, -31900,0.2157662,1.8331298,,,,,,,,,,,,,,,,, -32000,0.2027128,1.8100411,,,,,,,,,,,,,,,,, -32100,0.24223553,1.8727157,,,,,,,,,,,,,,,,, -32200,0.18562542,1.7707814,,,,,,,,,,,,,,,,, -32300,0.19631213,1.727989,,,,,,,,,,,,,,,,, -32400,0.21995287,1.779993,,,,,,,,,,,,,,,,, -32500,0.20691805,1.7568138,,,,,,,,,,,,,,,,, -32600,0.26264957,1.720482,,,,,,,,,,,,,,,,, -32700,0.2110662,1.8491892,,,,,,,,,,,,,,,,, -32800,0.21074688,1.8490958,,,,,,,,,,,,,,,,, -32900,0.189984,1.838014,,,,,,,,,,,,,,,,, -33000,0.2018593,1.8027244,,,,,,,,,,,,,,,,, -33100,0.24408865,1.8067203,,,,,,,,,,,,,,,,, -33200,0.19071521,1.8272264,,,,,,,,,,,,,,,,, -33242,,,0.6436529755592346,1.703749179840088,31.702394127417445,0.6596198081970215,1.5879870653152466,28.28841359912601,3000.0,0.671884298324585,1.5101147890090942,27.42189268731702,3003.0,11792.311067819595,20476.148945569992,11792.311067819595,8682.395342111588,0.3906929492950439,0.0 -33300,0.20863104,1.8046998,,,,,,,,,,,,,,,,, -33400,0.20936923,1.8764856,,,,,,,,,,,,,,,,, -33500,0.19177632,1.8264068,,,,,,,,,,,,,,,,, -33600,0.22938395,1.7893913,,,,,,,,,,,,,,,,, -33700,0.20177634,1.7906845,,,,,,,,,,,,,,,,, -33800,0.22148445,1.821316,,,,,,,,,,,,,,,,, -33900,0.19492845,1.7219716,,,,,,,,,,,,,,,,, -34000,0.2284512,1.8235292,,,,,,,,,,,,,,,,, -34100,0.28139052,1.8279954,,,,,,,,,,,,,,,,, -34200,0.22070158,1.8448272,,,,,,,,,,,,,,,,, -34300,0.2171639,1.7412664,,,,,,,,,,,,,,,,, -34400,0.19072756,1.7209166,,,,,,,,,,,,,,,,, -34500,0.20071906,1.789972,,,,,,,,,,,,,,,,, -34600,0.21894437,1.766323,,,,,,,,,,,,,,,,, -34700,0.19256988,1.8068035,,,,,,,,,,,,,,,,, -34800,0.35360146,1.7509196,,,,,,,,,,,,,,,,, -34900,0.42398146,1.8119658,,,,,,,,,,,,,,,,, -35000,0.20386599,1.7431291,,,,,,,,,,,,,,,,, -35100,0.2101701,1.8894594,,,,,,,,,,,,,,,,, -35200,0.29175624,1.7661237,,,,,,,,,,,,,,,,, -35300,0.22598977,1.8774471,,,,,,,,,,,,,,,,, -35400,0.20296215,1.7996091,,,,,,,,,,,,,,,,, -35500,0.2087022,1.7993276,,,,,,,,,,,,,,,,, -35600,0.19670574,1.7394902,,,,,,,,,,,,,,,,, -35616,,,0.641491711139679,1.723802089691162,31.370909674220155,0.6618516445159912,1.5850014686584473,28.14534491270295,3000.0,0.6744756698608398,1.5015976428985596,27.731952103228185,3003.0,12632.403628587725,21773.65692186356,12632.403628587725,9139.705032587051,0.4213588237762451,0.0 -35700,0.23224625,1.7940634,,,,,,,,,,,,,,,,, -35800,0.23106594,1.8155626,,,,,,,,,,,,,,,,, -35900,0.19019908,1.7636471,,,,,,,,,,,,,,,,, -36000,0.20548783,1.843235,,,,,,,,,,,,,,,,, -36100,0.22346637,1.856326,,,,,,,,,,,,,,,,, -36200,0.20616046,1.7906238,,,,,,,,,,,,,,,,, -36300,0.20702702,1.7710457,,,,,,,,,,,,,,,,, -36400,0.20725167,1.8569272,,,,,,,,,,,,,,,,, -36500,0.2038578,1.7216464,,,,,,,,,,,,,,,,, -36600,0.19885668,1.8269886,,,,,,,,,,,,,,,,, -36700,0.22382659,1.7440051,,,,,,,,,,,,,,,,, -36800,0.20004748,1.8342435,,,,,,,,,,,,,,,,, -36900,0.21900441,1.8215241,,,,,,,,,,,,,,,,, -37000,0.24445541,1.760591,,,,,,,,,,,,,,,,, -37100,0.2523526,1.7549757,,,,,,,,,,,,,,,,, -37200,0.23184451,1.8287342,,,,,,,,,,,,,,,,, -37300,0.2508037,1.7577053,,,,,,,,,,,,,,,,, -37400,0.30185342,1.7978723,,,,,,,,,,,,,,,,, -37500,0.19521871,1.816036,,,,,,,,,,,,,,,,, -37600,0.2325899,1.7729012,,,,,,,,,,,,,,,,, -37700,0.20104218,1.8211284,,,,,,,,,,,,,,,,, -37800,0.21358304,1.7370168,,,,,,,,,,,,,,,,, -37900,0.20401014,1.8451684,,,,,,,,,,,,,,,,, -37990,,,0.6566494703292847,1.6067266464233398,32.589987116770835,0.6636743545532227,1.575439691543579,28.37420393676452,3000.0,0.6746964454650879,1.4958630800247192,27.57462241443689,3003.0,13472.459535121918,23101.66996979713,13472.459535121918,9627.55661559105,0.4520916938781738,0.0 -38000,0.19491898,1.7618215,,,,,,,,,,,,,,,,, -38100,0.18813114,1.7446486,,,,,,,,,,,,,,,,, -38200,0.2058162,1.7953874,,,,,,,,,,,,,,,,, -38300,0.20526116,1.771652,,,,,,,,,,,,,,,,, -38400,0.2285583,1.736385,,,,,,,,,,,,,,,,, -38500,0.17439547,1.7626995,,,,,,,,,,,,,,,,, -38600,0.19649531,1.8266708,,,,,,,,,,,,,,,,, -38700,0.19557269,1.7270877,,,,,,,,,,,,,,,,, -38800,0.25532162,1.7615136,,,,,,,,,,,,,,,,, -38900,0.25414053,1.8364705,,,,,,,,,,,,,,,,, -39000,0.19503805,1.7001077,,,,,,,,,,,,,,,,, -39100,0.22731976,1.7860354,,,,,,,,,,,,,,,,, -39200,0.21621257,1.813513,,,,,,,,,,,,,,,,, -39300,0.27122682,1.9074067,,,,,,,,,,,,,,,,, -39400,0.25605854,1.7777425,,,,,,,,,,,,,,,,, -39500,0.21348879,1.7504542,,,,,,,,,,,,,,,,, -39600,0.17833999,1.7609289,,,,,,,,,,,,,,,,, -39700,0.21330649,1.7742624,,,,,,,,,,,,,,,,, -39800,0.2052633,1.7523767,,,,,,,,,,,,,,,,, -39900,0.18628418,1.7054963,,,,,,,,,,,,,,,,, -40000,0.18992843,1.7396554,,,,,,,,,,,,,,,,, -40100,0.21679628,1.7213721,,,,,,,,,,,,,,,,, -40200,0.22569397,1.744194,,,,,,,,,,,,,,,,, -40300,0.20462464,1.8242044,,,,,,,,,,,,,,,,, -40364,,,0.6478399634361267,1.676770567893982,31.48029070214986,0.6622112393379211,1.57160747051239,28.142196735251872,3000.0,0.6744988560676575,1.485736846923828,27.850430361041195,3003.0,14312.40743112564,24606.04882979393,14312.40743112564,10291.879784822464,0.4850766658782959,0.0 -40400,0.27628028,1.8536195,,,,,,,,,,,,,,,,, -40500,0.21666989,1.7701792,,,,,,,,,,,,,,,,, -40600,0.20127623,1.7128494,,,,,,,,,,,,,,,,, -40700,0.20591968,1.767967,,,,,,,,,,,,,,,,, -40800,0.21466655,1.8025262,,,,,,,,,,,,,,,,, -40900,0.20406505,1.832781,,,,,,,,,,,,,,,,, -41000,0.20912965,1.7057282,,,,,,,,,,,,,,,,, -41100,0.20566466,1.8146155,,,,,,,,,,,,,,,,, -41200,0.20381272,1.7002826,,,,,,,,,,,,,,,,, -41300,0.20903148,1.7811608,,,,,,,,,,,,,,,,, -41400,0.21971661,1.7830211,,,,,,,,,,,,,,,,, -41500,0.19622792,1.6320415,,,,,,,,,,,,,,,,, -41600,0.2401068,1.8237858,,,,,,,,,,,,,,,,, -41700,0.19494615,1.7065966,,,,,,,,,,,,,,,,, -41800,0.19951978,1.8354359,,,,,,,,,,,,,,,,, -41900,0.18785322,1.8706446,,,,,,,,,,,,,,,,, -42000,0.21019134,1.7460307,,,,,,,,,,,,,,,,, -42100,0.21773566,1.8712717,,,,,,,,,,,,,,,,, -42200,0.21136783,1.7850128,,,,,,,,,,,,,,,,, -42300,0.21553025,1.8420739,,,,,,,,,,,,,,,,, -42400,0.19079733,1.7215953,,,,,,,,,,,,,,,,, -42500,0.23762564,1.7550453,,,,,,,,,,,,,,,,, -42600,0.20683023,1.7819188,,,,,,,,,,,,,,,,, -42700,0.20916572,1.7297767,,,,,,,,,,,,,,,,, -42738,,,0.6446605920791626,1.6969213485717771,31.174335951599705,0.6645174622535706,1.5625895261764526,28.429532815753745,3000.0,0.6777177453041077,1.4755793809890747,28.13747402749904,3003.0,15152.398067235948,25978.72548937797,15152.398067235948,10824.458010911942,0.5174057483673096,0.0 -42800,0.19309376,1.7804921,,,,,,,,,,,,,,,,, -42900,0.20602901,1.7804791,,,,,,,,,,,,,,,,, -43000,0.219704,1.8007048,,,,,,,,,,,,,,,,, -43100,0.21520826,1.7685237,,,,,,,,,,,,,,,,, -43200,0.20090999,1.7117361,,,,,,,,,,,,,,,,, -43300,0.71567863,1.7485545,,,,,,,,,,,,,,,,, -43400,0.20924853,1.7296247,,,,,,,,,,,,,,,,, -43500,0.18925937,1.7251183,,,,,,,,,,,,,,,,, -43600,0.23809576,1.751908,,,,,,,,,,,,,,,,, -43700,0.21639305,1.7456585,,,,,,,,,,,,,,,,, -43800,0.19966854,1.817845,,,,,,,,,,,,,,,,, -43900,0.19616957,1.7083368,,,,,,,,,,,,,,,,, -44000,0.19513483,1.7424979,,,,,,,,,,,,,,,,, -44100,0.19531712,1.818376,,,,,,,,,,,,,,,,, -44200,0.20373367,1.6603945,,,,,,,,,,,,,,,,, -44300,0.24368852,1.8156964,,,,,,,,,,,,,,,,, -44400,0.20593435,1.7292114,,,,,,,,,,,,,,,,, -44500,0.1800585,1.7773458,,,,,,,,,,,,,,,,, -44600,0.19545944,1.7579621,,,,,,,,,,,,,,,,, -44700,0.2175535,1.8092036,,,,,,,,,,,,,,,,, -44800,0.20805386,1.7322696,,,,,,,,,,,,,,,,, -44900,0.22034912,1.7129796,,,,,,,,,,,,,,,,, -45000,0.22696131,1.7226361,,,,,,,,,,,,,,,,, -45100,0.19253749,1.790042,,,,,,,,,,,,,,,,, -45112,,,0.6525201201438904,1.646553874015808,32.448058247862534,0.6664517521858215,1.550062656402588,28.67135209872841,3000.0,0.6789379119873047,1.4708216190338137,28.37878554052869,3003.0,15992.489179372787,27374.3735370636,15992.489179372787,11379.908358573914,0.5488555431365967,0.0 -45200,0.18244258,1.7500327,,,,,,,,,,,,,,,,, -45300,0.19456159,1.734532,,,,,,,,,,,,,,,,, -45400,0.2169384,1.7867248,,,,,,,,,,,,,,,,, -45500,0.22473443,1.7570686,,,,,,,,,,,,,,,,, -45600,0.19496222,1.7430516,,,,,,,,,,,,,,,,, -45700,0.22748834,1.7971268,,,,,,,,,,,,,,,,, -45800,0.18633828,1.7590289,,,,,,,,,,,,,,,,, -45900,0.2014972,1.7837336,,,,,,,,,,,,,,,,, -46000,0.21870972,1.7490598,,,,,,,,,,,,,,,,, -46100,0.2202361,1.8131111,,,,,,,,,,,,,,,,, -46200,0.6125642,1.7784177,,,,,,,,,,,,,,,,, -46300,0.21338029,1.6955935,,,,,,,,,,,,,,,,, -46400,0.21113306,1.6811595,,,,,,,,,,,,,,,,, -46500,0.19340909,1.892583,,,,,,,,,,,,,,,,, -46600,0.21945153,1.7690467,,,,,,,,,,,,,,,,, -46700,0.2531389,1.7790216,,,,,,,,,,,,,,,,, -46800,0.22048828,1.7494189,,,,,,,,,,,,,,,,, -46900,0.20394175,1.7105906,,,,,,,,,,,,,,,,, -47000,0.22927186,1.6795409,,,,,,,,,,,,,,,,, -47100,0.1863914,1.7667723,,,,,,,,,,,,,,,,, -47200,0.21600227,1.7962224,,,,,,,,,,,,,,,,, -47300,0.28924227,1.7068708,,,,,,,,,,,,,,,,, -47400,0.20869088,1.8067248,,,,,,,,,,,,,,,,, -47487,,,0.6500367522239685,1.6660417318344116,32.292675817863426,0.6654598116874695,1.5579570531845093,28.594837733209264,3000.0,0.6797629594802856,1.469280242919922,28.22441199348882,3003.0,16832.455917835236,28707.885778665543,16832.455917835236,11873.347818136215,0.5799081325531006,0.0 -47500,0.185047,1.7455013,,,,,,,,,,,,,,,,, -47600,0.24149746,1.6922966,,,,,,,,,,,,,,,,, -47700,0.21621607,1.7270926,,,,,,,,,,,,,,,,, -47800,0.2079395,1.7567108,,,,,,,,,,,,,,,,, -47900,0.22656202,1.7594789,,,,,,,,,,,,,,,,, -48000,0.20040102,1.7400694,,,,,,,,,,,,,,,,, -48100,0.21950293,1.7903855,,,,,,,,,,,,,,,,, -48200,0.20489757,1.8164171,,,,,,,,,,,,,,,,, -48300,0.19108567,1.6813033,,,,,,,,,,,,,,,,, -48400,0.19569367,1.7031896,,,,,,,,,,,,,,,,, -48500,0.28557047,1.7515131,,,,,,,,,,,,,,,,, -48600,0.21102495,1.8028045,,,,,,,,,,,,,,,,, -48700,0.2674913,1.7611922,,,,,,,,,,,,,,,,, -48800,0.20457144,1.7081918,,,,,,,,,,,,,,,,, -48900,0.18619676,1.7319046,,,,,,,,,,,,,,,,, -49000,0.21540582,1.6739274,,,,,,,,,,,,,,,,, -49100,0.18857457,1.7820572,,,,,,,,,,,,,,,,, -49200,0.19935393,1.6527328,,,,,,,,,,,,,,,,, -49300,0.36802623,1.7904091,,,,,,,,,,,,,,,,, -49400,0.24001594,1.7342118,,,,,,,,,,,,,,,,, -49500,0.19228792,1.7776453,,,,,,,,,,,,,,,,, -49600,0.22215062,1.7487693,,,,,,,,,,,,,,,,, -49700,0.19978838,1.7505823,,,,,,,,,,,,,,,,, -49800,0.20809202,1.7599419,,,,,,,,,,,,,,,,, -49862,,,0.6516985297203064,1.6592146158218384,31.26118472682359,0.6688819527626038,1.5371198654174805,29.092173467752627,3000.0,0.6824705004692078,1.4507604837417605,28.232030328801063,3003.0,17672.611881494522,30103.71274662018,17672.611881494522,12428.909444093704,0.6142618656158447,0.0 -49900,0.2230205,1.8227284,,,,,,,,,,,,,,,,, -50000,0.2033648,1.7226971,,,,,,,,,,,,,,,,, -50100,0.26451024,1.7667716,,,,,,,,,,,,,,,,, -50200,0.20369658,1.6855851,,,,,,,,,,,,,,,,, -50300,0.23873185,1.6940026,,,,,,,,,,,,,,,,, -50400,0.21157618,1.7872137,,,,,,,,,,,,,,,,, -50500,0.2171009,1.7145873,,,,,,,,,,,,,,,,, -50600,0.1998731,1.657715,,,,,,,,,,,,,,,,, -50700,0.21898527,1.7940378,,,,,,,,,,,,,,,,, -50800,0.21786922,1.7061691,,,,,,,,,,,,,,,,, -50900,0.26006147,1.8085562,,,,,,,,,,,,,,,,, -51000,0.2163729,1.7056632,,,,,,,,,,,,,,,,, -51100,0.23251773,1.7564738,,,,,,,,,,,,,,,,, -51200,0.23296061,1.7950833,,,,,,,,,,,,,,,,, -51300,0.21257249,1.7665495,,,,,,,,,,,,,,,,, -51400,0.19264752,1.7171621,,,,,,,,,,,,,,,,, -51500,0.2002291,1.7111133,,,,,,,,,,,,,,,,, -51600,0.21233852,1.7315999,,,,,,,,,,,,,,,,, -51700,0.20776407,1.7084677,,,,,,,,,,,,,,,,, -51800,0.23207541,1.8332527,,,,,,,,,,,,,,,,, -51900,0.21097624,1.6729891,,,,,,,,,,,,,,,,, -52000,0.23494361,1.6525284,,,,,,,,,,,,,,,,, -52100,0.21778685,1.7550944,,,,,,,,,,,,,,,,, -52200,0.2017439,1.74372,,,,,,,,,,,,,,,,, -52237,,,0.651114284992218,1.653276443481445,32.15875856369312,0.670059859752655,1.5286074876785278,29.08356031013785,3000.0,0.6822032332420349,1.4496662616729736,28.419362421745703,3003.0,18512.59501695633,31551.95904159546,18512.59501695633,13037.063804626465,0.6480474472045898,0.0 -52300,0.21743642,1.7072675,,,,,,,,,,,,,,,,, -52400,0.1967561,1.7514049,,,,,,,,,,,,,,,,, -52500,1.0842593,1.7625527,,,,,,,,,,,,,,,,, -52600,0.668306,2.3662622,,,,,,,,,,,,,,,,, -52700,0.20529008,1.7241987,,,,,,,,,,,,,,,,, -52800,0.1941947,1.7879891,,,,,,,,,,,,,,,,, -52900,0.20338352,1.7872094,,,,,,,,,,,,,,,,, -53000,0.19381121,1.7503448,,,,,,,,,,,,,,,,, -53100,0.22328539,1.7311047,,,,,,,,,,,,,,,,, -53200,0.19590704,1.7004336,,,,,,,,,,,,,,,,, -53300,0.19881278,1.7079747,,,,,,,,,,,,,,,,, -53400,0.19517392,1.7004231,,,,,,,,,,,,,,,,, -53500,0.21778911,1.7565147,,,,,,,,,,,,,,,,, -53600,0.21238153,1.7762169,,,,,,,,,,,,,,,,, -53700,0.2936883,1.8002492,,,,,,,,,,,,,,,,, -53800,0.19137564,1.7326313,,,,,,,,,,,,,,,,, -53900,0.19074853,1.6670536,,,,,,,,,,,,,,,,, -54000,0.20631696,1.7256398,,,,,,,,,,,,,,,,, -54100,0.21197814,1.6618826,,,,,,,,,,,,,,,,, -54200,0.27301311,3.822359,,,,,,,,,,,,,,,,, -54300,3.1851752,3.4998703,,,,,,,,,,,,,,,,, -54400,0.1876404,1.7235665,,,,,,,,,,,,,,,,, -54500,0.19049904,1.6955408,,,,,,,,,,,,,,,,, -54600,0.19800729,1.7631444,,,,,,,,,,,,,,,,, -54612,,,0.6498084664344788,1.666841983795166,31.615262520847377,0.6697251200675964,1.5325440168380735,28.79401454291601,3000.0,0.6830283403396606,1.4475644826889038,28.666378067688967,3003.0,19352.575419425964,32935.361577510834,19352.575419425964,13580.379266500471,0.6796233654022217,0.0 -54700,0.18984157,1.7214631,,,,,,,,,,,,,,,,, -54800,0.21288978,1.7179134,,,,,,,,,,,,,,,,, -54900,0.2123039,1.774335,,,,,,,,,,,,,,,,, -55000,0.19195776,1.70667,,,,,,,,,,,,,,,,, -55100,0.24253888,1.7881551,,,,,,,,,,,,,,,,, -55200,0.24766506,1.7500335,,,,,,,,,,,,,,,,, -55300,0.19509521,1.7249273,,,,,,,,,,,,,,,,, -55400,0.18746918,1.6949719,,,,,,,,,,,,,,,,, -55500,0.20529832,1.6997625,,,,,,,,,,,,,,,,, -55600,0.21321939,1.7344573,,,,,,,,,,,,,,,,, -55700,0.20662451,1.7073941,,,,,,,,,,,,,,,,, -55800,0.196587,1.7855319,,,,,,,,,,,,,,,,, -55900,0.20907712,1.7302628,,,,,,,,,,,,,,,,, -56000,0.20013024,1.7502948,,,,,,,,,,,,,,,,, -56100,0.19597314,1.6078825,,,,,,,,,,,,,,,,, -56200,0.2207223,1.7117436,,,,,,,,,,,,,,,,, -56300,0.26927742,1.768239,,,,,,,,,,,,,,,,, -56400,0.1967985,1.7122656,,,,,,,,,,,,,,,,, -56500,0.21081026,1.7426383,,,,,,,,,,,,,,,,, -56600,0.18093525,1.6397626,,,,,,,,,,,,,,,,, -56700,0.22405984,1.7721843,,,,,,,,,,,,,,,,, -56800,0.2144405,1.7108535,,,,,,,,,,,,,,,,, -56900,0.1927725,1.7237178,,,,,,,,,,,,,,,,, -56987,,,0.663896381855011,1.5628291368484497,32.592228093350805,0.6697623133659363,1.524965763092041,28.77605624550842,3000.0,0.6824240684509277,1.442415714263916,28.52835011043656,3003.0,20192.615739822388,34299.67973899841,20192.615739822388,14104.548913478851,0.7127649784088135,0.0 -57000,0.19901834,1.6908407,,,,,,,,,,,,,,,,, -57100,0.21372616,1.8119211,,,,,,,,,,,,,,,,, -57200,0.21279937,1.6493871,,,,,,,,,,,,,,,,, -57300,0.20126957,1.7346668,,,,,,,,,,,,,,,,, -57400,0.19700202,1.7038386,,,,,,,,,,,,,,,,, -57500,0.19672489,1.7593048,,,,,,,,,,,,,,,,, -57600,0.18353085,1.7163502,,,,,,,,,,,,,,,,, -57700,0.19199814,1.6668026,,,,,,,,,,,,,,,,, -57800,0.19432645,1.6711285,,,,,,,,,,,,,,,,, -57900,0.1938583,1.7043765,,,,,,,,,,,,,,,,, -58000,0.20069526,1.6630201,,,,,,,,,,,,,,,,, -58100,0.18836185,1.7305882,,,,,,,,,,,,,,,,, -58200,0.22181375,1.683516,,,,,,,,,,,,,,,,, -58300,0.20417425,1.6906736,,,,,,,,,,,,,,,,, -58400,0.18529621,1.6907893,,,,,,,,,,,,,,,,, -58500,0.20767868,1.7410616,,,,,,,,,,,,,,,,, -58600,0.21089242,1.7203774,,,,,,,,,,,,,,,,, -58700,0.20698546,1.7532682,,,,,,,,,,,,,,,,, -58800,0.2090078,1.763665,,,,,,,,,,,,,,,,, -58900,0.17485218,1.6883025,,,,,,,,,,,,,,,,, -59000,0.23212455,1.6429033,,,,,,,,,,,,,,,,, -59100,0.21677129,1.742899,,,,,,,,,,,,,,,,, -59200,0.19838688,1.7146211,,,,,,,,,,,,,,,,, -59300,0.19958778,1.818522,,,,,,,,,,,,,,,,, -59362,,,0.6547658443450928,1.639319896697998,32.32660768039863,0.6719445586204529,1.517433762550354,28.431021789399217,3000.0,0.6865028142929077,1.425528049468994,27.1299610500768,3003.0,21032.84838962555,35920.757283210754,21032.84838962555,14885.285350084305,0.7453715801239014,0.0 -59400,0.19506536,1.7002678,,,,,,,,,,,,,,,,, -59500,0.19623862,1.704052,,,,,,,,,,,,,,,,, -59600,0.19786893,1.6407158,,,,,,,,,,,,,,,,, -59700,0.21140145,1.7463558,,,,,,,,,,,,,,,,, -59800,0.19001275,1.686074,,,,,,,,,,,,,,,,, -59900,0.19507961,1.6677147,,,,,,,,,,,,,,,,, -60000,0.21472485,1.6987138,,,,,,,,,,,,,,,,, -60100,0.20585553,1.739916,,,,,,,,,,,,,,,,, -60200,0.18442008,1.7022679,,,,,,,,,,,,,,,,, -60300,0.18925834,1.7467642,,,,,,,,,,,,,,,,, -60400,0.2163933,1.7631992,,,,,,,,,,,,,,,,, -60500,0.20551424,1.7434059,,,,,,,,,,,,,,,,, -60600,0.18080527,1.7194599,,,,,,,,,,,,,,,,, -60700,0.22863498,1.6530704,,,,,,,,,,,,,,,,, -60800,0.20924644,1.7997154,,,,,,,,,,,,,,,,, -60900,0.19599825,1.7414148,,,,,,,,,,,,,,,,, -61000,0.19205195,1.7276775,,,,,,,,,,,,,,,,, -61100,0.23331796,1.694085,,,,,,,,,,,,,,,,, -61200,0.20649049,1.6973325,,,,,,,,,,,,,,,,, -61300,0.19484329,1.7296014,,,,,,,,,,,,,,,,, -61400,0.2069285,1.6519308,,,,,,,,,,,,,,,,, -61500,0.19286111,1.6486324,,,,,,,,,,,,,,,,, -61600,0.21500838,1.7098796,,,,,,,,,,,,,,,,, -61700,0.1940083,1.6893573,,,,,,,,,,,,,,,,, -61736,,,0.6519191861152649,1.6469298601150513,32.12458710119513,0.673072874546051,1.5103808641433716,29.256527875615028,3000.0,0.6850618720054626,1.4241111278533936,28.778837217794965,3003.0,21872.760452747345,37275.04721617699,21872.760452747345,15399.552985191343,0.7798454761505127,0.0 -61800,0.21451785,1.7395018,,,,,,,,,,,,,,,,, -61900,0.19930503,1.7397945,,,,,,,,,,,,,,,,, -62000,0.24088988,1.7533953,,,,,,,,,,,,,,,,, -62100,0.18772374,1.6570952,,,,,,,,,,,,,,,,, -62200,0.20548597,1.7443715,,,,,,,,,,,,,,,,, -62300,0.18826604,1.7209182,,,,,,,,,,,,,,,,, -62400,0.22242261,1.679303,,,,,,,,,,,,,,,,, -62500,0.21875894,1.6902195,,,,,,,,,,,,,,,,, -62600,0.19802634,1.6493527,,,,,,,,,,,,,,,,, -62700,0.2036722,1.6953148,,,,,,,,,,,,,,,,, -62800,0.18991014,1.7681152,,,,,,,,,,,,,,,,, -62900,0.21510994,1.7247158,,,,,,,,,,,,,,,,, -63000,0.19430076,1.7308239,,,,,,,,,,,,,,,,, -63100,0.18665479,1.6749942,,,,,,,,,,,,,,,,, -63200,0.20046115,1.6492034,,,,,,,,,,,,,,,,, -63300,0.218744,1.6551017,,,,,,,,,,,,,,,,, -63400,0.21942051,1.7378763,,,,,,,,,,,,,,,,, -63500,0.1964489,1.7215052,,,,,,,,,,,,,,,,, -63600,0.20957097,1.7213722,,,,,,,,,,,,,,,,, -63700,0.20484059,1.6512992,,,,,,,,,,,,,,,,, -63800,0.20926131,1.7055902,,,,,,,,,,,,,,,,, -63900,0.20343652,1.7018136,,,,,,,,,,,,,,,,, -64000,0.19282396,1.7177353,,,,,,,,,,,,,,,,, -64100,0.21491106,1.6831923,,,,,,,,,,,,,,,,, -64110,,,0.6602867245674133,1.5844924449920654,32.69674771030805,0.6743127703666687,1.5003536939620972,29.04694521750138,3000.0,0.6873162388801575,1.4133727550506592,28.827913304890146,3003.0,22712.73940300941,38790.858786821365,22712.73940300941,16075.27707862854,0.812446117401123,0.0 -64200,0.19478944,1.7302254,,,,,,,,,,,,,,,,, -64300,0.20101106,1.7152544,,,,,,,,,,,,,,,,, -64400,0.21198164,1.7352138,,,,,,,,,,,,,,,,, -64500,0.20804456,1.7535955,,,,,,,,,,,,,,,,, -64600,0.194538,1.7258859,,,,,,,,,,,,,,,,, -64700,0.20829394,1.7233269,,,,,,,,,,,,,,,,, -64800,0.1962836,1.7007593,,,,,,,,,,,,,,,,, -64900,0.20783825,1.6714109,,,,,,,,,,,,,,,,, -65000,0.19410178,1.689121,,,,,,,,,,,,,,,,, -65100,0.19869916,1.6603674,,,,,,,,,,,,,,,,, -65200,0.19766471,1.7097882,,,,,,,,,,,,,,,,, -65300,0.1939895,1.7009319,,,,,,,,,,,,,,,,, -65400,0.20972517,1.7369281,,,,,,,,,,,,,,,,, -65500,0.20676892,1.7240235,,,,,,,,,,,,,,,,, -65600,0.22298104,1.7124763,,,,,,,,,,,,,,,,, -65700,0.19261082,1.6393985,,,,,,,,,,,,,,,,, -65800,0.21365511,1.7283528,,,,,,,,,,,,,,,,, -65900,0.20053734,1.6793634,,,,,,,,,,,,,,,,, -66000,0.2036271,1.7071183,,,,,,,,,,,,,,,,, -66100,0.21492866,1.642627,,,,,,,,,,,,,,,,, -66200,0.19447196,1.6520609,,,,,,,,,,,,,,,,, -66300,0.23513988,1.6213391,,,,,,,,,,,,,,,,, -66400,0.18352185,1.5994697,,,,,,,,,,,,,,,,, -66485,,,0.6578103303909302,1.6158467531204224,32.83270106756466,0.6757386922836304,1.4923642873764038,29.64027073611295,3000.0,0.691046416759491,1.3991174697875977,29.41162720824399,3003.0,23552.78756380081,40321.3447868824,23552.78756380081,16765.60342478752,0.8476648330688477,0.0 -66500,0.21354502,1.6790211,,,,,,,,,,,,,,,,, -66600,0.21784818,1.7802376,,,,,,,,,,,,,,,,, -66700,0.22713447,1.6583879,,,,,,,,,,,,,,,,, -66800,0.22266687,1.6983265,,,,,,,,,,,,,,,,, -66900,0.23156387,1.6540493,,,,,,,,,,,,,,,,, -67000,0.22162795,1.7341062,,,,,,,,,,,,,,,,, -67100,0.23360518,1.7145419,,,,,,,,,,,,,,,,, -67200,0.19685839,1.6581638,,,,,,,,,,,,,,,,, -67300,0.20758319,1.5993056,,,,,,,,,,,,,,,,, -67400,0.1911941,1.7222477,,,,,,,,,,,,,,,,, -67500,0.19324602,1.5959035,,,,,,,,,,,,,,,,, -67600,0.20893736,1.5881556,,,,,,,,,,,,,,,,, -67700,0.18668617,1.6696452,,,,,,,,,,,,,,,,, -67800,0.19983386,1.6971176,,,,,,,,,,,,,,,,, -67900,0.21824922,1.728283,,,,,,,,,,,,,,,,, -68000,0.19585632,1.6280831,,,,,,,,,,,,,,,,, -68100,0.24654,1.5979238,,,,,,,,,,,,,,,,, -68200,0.2114547,1.7431958,,,,,,,,,,,,,,,,, -68300,0.19943039,1.738208,,,,,,,,,,,,,,,,, -68400,0.2003548,1.7106893,,,,,,,,,,,,,,,,, -68500,0.19809286,1.6843702,,,,,,,,,,,,,,,,, -68600,0.20448695,1.7207084,,,,,,,,,,,,,,,,, -68700,0.21105008,1.6811943,,,,,,,,,,,,,,,,, -68800,0.20952775,1.7426269,,,,,,,,,,,,,,,,, -68860,,,0.6925715208053589,1.3943202495574951,35.164799300747305,0.675986647605896,1.4907090663909912,29.37226423439722,3000.0,0.6896519660949707,1.3980993032455444,29.14519622574132,3003.0,24392.84364748001,41680.414498806,24392.84364748001,17284.506219387054,0.8824207782745361,0.0 -68900,0.21261407,1.6626401,,,,,,,,,,,,,,,,, -69000,0.21162106,1.6786809,,,,,,,,,,,,,,,,, -69100,0.20107588,1.7143581,,,,,,,,,,,,,,,,, -69200,0.24122673,1.703675,,,,,,,,,,,,,,,,, -69300,0.2066565,1.5974028,,,,,,,,,,,,,,,,, -69400,0.18327096,1.6374066,,,,,,,,,,,,,,,,, -69500,0.1988725,1.6267476,,,,,,,,,,,,,,,,, -69600,0.19077884,1.6540804,,,,,,,,,,,,,,,,, -69700,0.18358235,1.6964384,,,,,,,,,,,,,,,,, -69800,0.18824153,1.585722,,,,,,,,,,,,,,,,, -69900,0.19741324,1.6509761,,,,,,,,,,,,,,,,, -70000,0.19096369,1.617265,,,,,,,,,,,,,,,,, -70100,0.20470624,1.6486685,,,,,,,,,,,,,,,,, -70200,0.4540943,1.6101793,,,,,,,,,,,,,,,,, -70300,0.19438352,1.5930065,,,,,,,,,,,,,,,,, -70400,0.19607943,1.6720535,,,,,,,,,,,,,,,,, -70500,0.19633473,1.6660237,,,,,,,,,,,,,,,,, -70600,0.23602742,1.7185731,,,,,,,,,,,,,,,,, -70700,0.21835636,1.6136183,,,,,,,,,,,,,,,,, -70800,0.20902635,1.7316157,,,,,,,,,,,,,,,,, -70900,0.20573041,1.6500598,,,,,,,,,,,,,,,,, -71000,0.20256823,1.7458472,,,,,,,,,,,,,,,,, -71100,0.19910537,1.6849543,,,,,,,,,,,,,,,,, -71200,0.20748784,1.6570859,,,,,,,,,,,,,,,,, -71235,,,0.663612425327301,1.5702091455459597,33.04904270814784,0.6764578223228455,1.4874941110610962,29.29003691272471,3000.0,0.6922782063484192,1.3883384466171265,29.602084105542623,3003.0,25233.028274297714,43225.82240843773,25233.028274297714,17989.61869287491,0.9181523323059082,0.0 -71300,0.219889,1.7026279,,,,,,,,,,,,,,,,, -71400,0.24056304,1.6921053,,,,,,,,,,,,,,,,, -71500,0.2241902,1.7165314,,,,,,,,,,,,,,,,, -71600,0.1929142,1.6419152,,,,,,,,,,,,,,,,, -71700,0.20899707,1.7010424,,,,,,,,,,,,,,,,, -71800,0.2067856,1.6892227,,,,,,,,,,,,,,,,, -71900,0.19689593,1.635982,,,,,,,,,,,,,,,,, -72000,0.2013337,1.6806796,,,,,,,,,,,,,,,,, -72100,0.19742323,1.6052582,,,,,,,,,,,,,,,,, -72200,0.20369,1.7025982,,,,,,,,,,,,,,,,, -72300,0.20412907,1.6400892,,,,,,,,,,,,,,,,, -72400,0.19613048,1.6389642,,,,,,,,,,,,,,,,, -72500,0.18026495,1.6220975,,,,,,,,,,,,,,,,, -72600,0.20065732,1.6012204,,,,,,,,,,,,,,,,, -72700,0.18991919,1.5662413,,,,,,,,,,,,,,,,, -72800,0.25257924,1.6808264,,,,,,,,,,,,,,,,, -72900,0.22611986,1.5590717,,,,,,,,,,,,,,,,, -73000,0.2055168,1.7036393,,,,,,,,,,,,,,,,, -73100,0.2012097,1.5952733,,,,,,,,,,,,,,,,, -73200,0.20647517,1.6372359,,,,,,,,,,,,,,,,, -73300,0.20588434,1.6082069,,,,,,,,,,,,,,,,, -73400,0.20862521,1.6487796,,,,,,,,,,,,,,,,, -73500,0.1859884,1.5933602,,,,,,,,,,,,,,,,, -73600,0.20524229,1.6388075,,,,,,,,,,,,,,,,, -73610,,,0.6637750864028931,1.5700886249542236,32.387814931952754,0.6769413948059082,1.4775288105010986,29.41112830125952,3000.0,0.6927198171615601,1.3803156614303589,29.329502689351862,3003.0,26073.24313759804,44682.91942191124,26073.24313759804,18606.38701105117,0.956063747406006,0.0 -73700,0.20041966,1.6541727,,,,,,,,,,,,,,,,, -73800,0.1919924,1.6671764,,,,,,,,,,,,,,,,, -73900,0.22317104,1.6032525,,,,,,,,,,,,,,,,, -74000,0.2603967,1.6552428,,,,,,,,,,,,,,,,, -74100,0.20674491,1.6095818,,,,,,,,,,,,,,,,, -74200,0.21060458,1.6891887,,,,,,,,,,,,,,,,, -74300,0.20266739,1.6039757,,,,,,,,,,,,,,,,, -74400,0.21563639,1.6609181,,,,,,,,,,,,,,,,, -74500,0.20351964,1.7046734,,,,,,,,,,,,,,,,, -74600,0.20313843,1.6175985,,,,,,,,,,,,,,,,, -74700,0.2016366,1.6071494,,,,,,,,,,,,,,,,, -74800,0.18962437,1.6179994,,,,,,,,,,,,,,,,, -74900,0.19046472,1.6308328,,,,,,,,,,,,,,,,, -75000,0.19193955,1.6444376,,,,,,,,,,,,,,,,, -75100,0.2159518,1.6530262,,,,,,,,,,,,,,,,, -75200,0.33060613,1.6872351,,,,,,,,,,,,,,,,, -75300,0.21283245,1.5532677,,,,,,,,,,,,,,,,, -75400,0.20198914,1.5901028,,,,,,,,,,,,,,,,, -75500,0.18377292,1.570766,,,,,,,,,,,,,,,,, -75600,0.20593888,1.6100914,,,,,,,,,,,,,,,,, -75700,0.2101887,1.6731095,,,,,,,,,,,,,,,,, -75800,0.20758738,1.7183113,,,,,,,,,,,,,,,,, -75900,0.20674823,1.6941196,,,,,,,,,,,,,,,,, -75985,,,0.6715718507766724,1.5168931484222412,33.4930723722878,0.6797311902046204,1.4644559621810913,29.64138507890422,3000.0,0.6943234205245972,1.3734365701675415,29.62003987298333,3003.0,26913.46995139122,46109.13782072067,26913.46995139122,19192.265601158146,0.9935698509216307,0.0 -76000,0.19512494,1.6574794,,,,,,,,,,,,,,,,, -76100,0.21152376,1.6382921,,,,,,,,,,,,,,,,, -76200,0.21379699,1.6543987,,,,,,,,,,,,,,,,, -76300,0.20350738,1.6475056,,,,,,,,,,,,,,,,, -76400,0.20061822,1.6591887,,,,,,,,,,,,,,,,, -76500,0.20690504,1.6552471,,,,,,,,,,,,,,,,, -76600,0.2308663,1.5605422,,,,,,,,,,,,,,,,, -76700,0.25880563,1.7315128,,,,,,,,,,,,,,,,, -76800,0.19353244,1.7501434,,,,,,,,,,,,,,,,, -76900,0.18058647,1.578728,,,,,,,,,,,,,,,,, -77000,0.20168737,1.5611708,,,,,,,,,,,,,,,,, -77100,0.21278168,1.65304,,,,,,,,,,,,,,,,, -77200,0.19091396,1.6049415,,,,,,,,,,,,,,,,, -77300,0.2107138,1.6844053,,,,,,,,,,,,,,,,, -77400,0.19298375,1.623598,,,,,,,,,,,,,,,,, -77500,0.20606019,1.6134266,,,,,,,,,,,,,,,,, -77600,0.21626206,1.6256279,,,,,,,,,,,,,,,,, -77700,0.19524279,1.555915,,,,,,,,,,,,,,,,, -77800,0.20884101,1.6169282,,,,,,,,,,,,,,,,, -77900,0.26452047,1.6534517,,,,,,,,,,,,,,,,, -78000,0.23362339,1.7083211,,,,,,,,,,,,,,,,, -78100,0.22003825,1.677902,,,,,,,,,,,,,,,,, -78200,0.20669761,1.73613,,,,,,,,,,,,,,,,, -78300,0.20780586,1.6227322,,,,,,,,,,,,,,,,, -78359,,,0.6638250946998596,1.5705655813217163,32.52902754344788,0.6803511381149292,1.4611443281173706,29.507351771603417,3000.0,0.6962175369262695,1.3679414987564087,29.716893139832724,3003.0,27753.6347835064,47451.10957407951,27753.6347835064,19693.95895600319,1.03102707862854,0.0 -78400,0.22512256,1.5353532,,,,,,,,,,,,,,,,, -78500,0.202666,1.5790476,,,,,,,,,,,,,,,,, -78600,0.2409971,1.6525464,,,,,,,,,,,,,,,,, -78700,0.20676559,1.6888502,,,,,,,,,,,,,,,,, -78800,0.19137362,1.6093833,,,,,,,,,,,,,,,,, -78900,0.19576982,1.5788255,,,,,,,,,,,,,,,,, -79000,0.21246594,1.6161839,,,,,,,,,,,,,,,,, -79100,0.19904995,1.6582315,,,,,,,,,,,,,,,,, -79200,0.18987761,1.5976318,,,,,,,,,,,,,,,,, -79300,0.2021432,1.6201164,,,,,,,,,,,,,,,,, -79400,0.22664246,1.7033943,,,,,,,,,,,,,,,,, -79500,0.20627481,1.6330577,,,,,,,,,,,,,,,,, -79600,0.21133457,1.6432955,,,,,,,,,,,,,,,,, -79700,0.20228875,1.6778514,,,,,,,,,,,,,,,,, -79800,0.32178894,1.6067475,,,,,,,,,,,,,,,,, -79900,0.21521522,1.5257922,,,,,,,,,,,,,,,,, -80000,0.19457845,1.6622528,,,,,,,,,,,,,,,,, -80100,0.20696758,1.6584787,,,,,,,,,,,,,,,,, -80200,0.21959895,1.550523,,,,,,,,,,,,,,,,, -80300,0.20306063,1.6531463,,,,,,,,,,,,,,,,, -80400,0.20324862,1.6076318,,,,,,,,,,,,,,,,, -80500,0.19680622,1.6177206,,,,,,,,,,,,,,,,, -80600,0.21679384,1.5979987,,,,,,,,,,,,,,,,, -80700,0.20419662,1.6409125,,,,,,,,,,,,,,,,, -80733,,,0.6645106077194214,1.572019338607788,33.2519705355767,0.6811570525169373,1.4543503522872925,30.089135107250858,3000.0,0.6961246132850647,1.3614065647125244,29.882335007945937,3003.0,28593.58155298233,48813.9661796093,28593.58155298233,20216.756098032,1.0682952404022217,0.0 -80800,2.0005639,1.7076427,,,,,,,,,,,,,,,,, -80900,0.19732803,1.6264077,,,,,,,,,,,,,,,,, -81000,0.19466029,1.6066813,,,,,,,,,,,,,,,,, -81100,0.21275325,1.657484,,,,,,,,,,,,,,,,, -81200,0.20742907,1.6010284,,,,,,,,,,,,,,,,, -81300,0.19736424,1.6011038,,,,,,,,,,,,,,,,, -81400,0.19915794,1.636992,,,,,,,,,,,,,,,,, -81500,0.19053797,1.6554027,,,,,,,,,,,,,,,,, -81600,0.21070077,1.6846915,,,,,,,,,,,,,,,,, -81700,0.21755815,1.6142243,,,,,,,,,,,,,,,,, -81800,0.21472888,1.6368415,,,,,,,,,,,,,,,,, -81900,0.22763069,1.665119,,,,,,,,,,,,,,,,, -82000,0.21408962,1.6728094,,,,,,,,,,,,,,,,, -82100,0.21057953,1.6402928,,,,,,,,,,,,,,,,, -82200,0.22706486,1.7158997,,,,,,,,,,,,,,,,, -82300,0.21939082,1.6663796,,,,,,,,,,,,,,,,, -82400,0.20098683,1.6264932,,,,,,,,,,,,,,,,, -82500,0.20547204,1.575119,,,,,,,,,,,,,,,,, -82600,0.21947339,1.6688315,,,,,,,,,,,,,,,,, -82700,0.21273571,1.6682006,,,,,,,,,,,,,,,,, -82800,0.2134109,1.6532314,,,,,,,,,,,,,,,,, -82900,0.19454023,1.6048652,,,,,,,,,,,,,,,,, -83000,0.21993238,1.5433065,,,,,,,,,,,,,,,,, -83100,0.29447925,1.6698676,,,,,,,,,,,,,,,,, -83108,,,0.6732177138328552,1.5030677318572998,33.89015855149095,0.6820250153541565,1.4465928077697754,29.56513048449141,3000.0,0.6982976198196411,1.349581003189087,30.016736440843538,3003.0,29433.65891289711,50206.2528424263,29433.65891289711,20768.853723526,1.1046974658966064,0.0 -83200,0.20312272,1.6724051,,,,,,,,,,,,,,,,, -83300,0.20055911,1.6141812,,,,,,,,,,,,,,,,, -83400,0.21633516,1.6694162,,,,,,,,,,,,,,,,, -83500,0.33528253,1.6530628,,,,,,,,,,,,,,,,, -83600,0.19589512,1.5902846,,,,,,,,,,,,,,,,, -83700,0.19273692,1.5973396,,,,,,,,,,,,,,,,, -83800,0.20638336,1.5506305,,,,,,,,,,,,,,,,, -83900,0.20650245,1.647085,,,,,,,,,,,,,,,,, -84000,0.2075817,1.5988786,,,,,,,,,,,,,,,,, -84100,0.22332945,1.6314143,,,,,,,,,,,,,,,,, -84200,0.1993285,1.6458136,,,,,,,,,,,,,,,,, -84300,0.20810843,1.6381536,,,,,,,,,,,,,,,,, -84400,0.21636753,1.6887963,,,,,,,,,,,,,,,,, -84500,0.20925064,1.6206923,,,,,,,,,,,,,,,,, -84600,0.20053081,1.609768,,,,,,,,,,,,,,,,, -84700,0.20359725,1.5530707,,,,,,,,,,,,,,,,, -84800,0.20809732,1.5899504,,,,,,,,,,,,,,,,, -84900,0.23132786,1.5867866,,,,,,,,,,,,,,,,, -85000,0.23268221,1.6121204,,,,,,,,,,,,,,,,, -85100,0.20836279,1.681456,,,,,,,,,,,,,,,,, -85200,0.22385758,1.5567153,,,,,,,,,,,,,,,,, -85300,0.21963333,1.6096861,,,,,,,,,,,,,,,,, -85400,0.20242709,1.5438931,,,,,,,,,,,,,,,,, -85483,,,0.6708388328552246,1.520539402961731,33.3104744951593,0.6837112903594971,1.4417012929916382,30.142586435709703,3000.0,0.698448657989502,1.3438206911087036,29.82206997277747,3003.0,30273.729954242703,51594.460906744,30273.729954242703,21316.877049922943,1.142566204071045,0.0 -85500,0.20277879,1.5984002,,,,,,,,,,,,,,,,, -85600,0.24701537,1.5720059,,,,,,,,,,,,,,,,, -85700,0.20139363,1.6423837,,,,,,,,,,,,,,,,, -85800,0.20873079,1.6966355,,,,,,,,,,,,,,,,, -85900,0.23307198,1.6220855,,,,,,,,,,,,,,,,, -86000,0.2153773,1.6288042,,,,,,,,,,,,,,,,, -86100,0.22865318,1.5517049,,,,,,,,,,,,,,,,, -86200,0.22694546,1.624883,,,,,,,,,,,,,,,,, -86300,0.20861195,1.5980765,,,,,,,,,,,,,,,,, -86400,0.21259387,1.6571012,,,,,,,,,,,,,,,,, -86500,0.21373075,1.637685,,,,,,,,,,,,,,,,, -86600,0.21666281,1.663776,,,,,,,,,,,,,,,,, -86700,0.2112098,1.6017573,,,,,,,,,,,,,,,,, -86800,0.20658064,1.6365042,,,,,,,,,,,,,,,,, -86900,0.20391373,1.5261142,,,,,,,,,,,,,,,,, -87000,0.21473378,1.6875359,,,,,,,,,,,,,,,,, -87100,0.2116365,1.5998948,,,,,,,,,,,,,,,,, -87200,0.21658123,1.6136338,,,,,,,,,,,,,,,,, -87300,0.22649986,1.6051098,,,,,,,,,,,,,,,,, -87400,0.20374042,1.5613629,,,,,,,,,,,,,,,,, -87500,0.20765837,1.5980467,,,,,,,,,,,,,,,,, -87600,0.22073871,1.6136247,,,,,,,,,,,,,,,,, -87700,0.209956,1.5711375,,,,,,,,,,,,,,,,, -87800,0.21431999,1.5709138,,,,,,,,,,,,,,,,, -87858,,,0.6888561248779297,1.4054269790649414,34.55736403714124,0.6858067512512207,1.4331176280975342,29.992162601588777,3000.0,0.698506772518158,1.3392562866210938,29.97895945307208,3003.0,31113.74466848373,53040.19746589661,31113.74466848373,21922.486659526825,1.18042254447937,0.0 -87900,0.20333835,1.5593479,,,,,,,,,,,,,,,,, -88000,0.21903934,1.6055992,,,,,,,,,,,,,,,,, -88100,0.19912688,1.5578563,,,,,,,,,,,,,,,,, -88200,0.19060549,1.6481801,,,,,,,,,,,,,,,,, -88300,0.20656697,1.6046696,,,,,,,,,,,,,,,,, -88400,0.19352676,1.6088524,,,,,,,,,,,,,,,,, -88500,0.19776084,1.6015823,,,,,,,,,,,,,,,,, -88600,0.2089731,1.5687809,,,,,,,,,,,,,,,,, -88700,0.21654943,1.6019126,,,,,,,,,,,,,,,,, -88800,0.21033882,1.7082503,,,,,,,,,,,,,,,,, -88900,0.20055397,1.5577176,,,,,,,,,,,,,,,,, -89000,0.1963946,1.5535245,,,,,,,,,,,,,,,,, -89100,0.22747748,1.6378906,,,,,,,,,,,,,,,,, -89200,0.22019926,1.6725546,,,,,,,,,,,,,,,,, -89300,0.20234673,1.6293302,,,,,,,,,,,,,,,,, -89400,0.22750454,1.5663329,,,,,,,,,,,,,,,,, -89500,0.19872662,1.5507278,,,,,,,,,,,,,,,,, -89600,0.23566686,1.5497155,,,,,,,,,,,,,,,,, -89700,0.21387011,1.5657699,,,,,,,,,,,,,,,,, -89800,0.20448703,1.5844767,,,,,,,,,,,,,,,,, -89900,0.21110843,1.6206622,,,,,,,,,,,,,,,,, -90000,0.20535515,1.6157923,,,,,,,,,,,,,,,,, -90100,0.22556552,1.5996796,,,,,,,,,,,,,,,,, -90200,0.2039571,1.4924722,,,,,,,,,,,,,,,,, -90234,,,0.6757356524467468,1.4885623455047607,33.67554884277297,0.68563312292099,1.4297630786895752,29.696829275004344,3000.0,0.7016443014144897,1.3300366401672363,29.90019334184747,3003.0,31953.942955493927,54574.32289242744,31953.942955493927,22616.300694704056,1.2178633213043213,0.0 -90300,0.2230807,1.6156151,,,,,,,,,,,,,,,,, -90400,0.21491551,1.5857261,,,,,,,,,,,,,,,,, -90500,0.20907664,1.4697794,,,,,,,,,,,,,,,,, -90600,0.1947293,1.6065726,,,,,,,,,,,,,,,,, -90700,0.2018236,1.5408959,,,,,,,,,,,,,,,,, -90800,0.2125872,1.5517673,,,,,,,,,,,,,,,,, -90900,0.21435288,1.6324023,,,,,,,,,,,,,,,,, -91000,0.20799556,1.5852039,,,,,,,,,,,,,,,,, -91100,0.21076709,1.5789249,,,,,,,,,,,,,,,,, -91200,0.24857168,1.6845257,,,,,,,,,,,,,,,,, -91300,0.21540248,1.5411007,,,,,,,,,,,,,,,,, -91400,0.21369165,1.5497853,,,,,,,,,,,,,,,,, -91500,0.2015487,1.640359,,,,,,,,,,,,,,,,, -91600,0.20939673,1.5999002,,,,,,,,,,,,,,,,, -91700,0.2195089,1.6107816,,,,,,,,,,,,,,,,, -91800,0.22519521,1.6870655,,,,,,,,,,,,,,,,, -91900,0.22117792,1.5419159,,,,,,,,,,,,,,,,, -92000,0.20814374,1.5867096,,,,,,,,,,,,,,,,, -92100,0.2355255,1.5194157,,,,,,,,,,,,,,,,, -92200,0.2168652,1.5831074,,,,,,,,,,,,,,,,, -92300,0.21466494,1.5553514,,,,,,,,,,,,,,,,, -92400,0.2139483,1.640132,,,,,,,,,,,,,,,,, -92500,0.21336421,1.6195232,,,,,,,,,,,,,,,,, -92600,0.2007921,1.6123672,,,,,,,,,,,,,,,,, -92610,,,0.6737697720527649,1.50662362575531,33.71377058019944,0.6864514946937561,1.4230599403381348,30.020258720494507,3000.0,0.7028412222862244,1.321674346923828,30.292554115791944,3003.0,32794.13788151741,55952.35843801498,32794.13788151741,23154.02640128136,1.2572076320648191,0.0 -92700,0.21148354,1.5807167,,,,,,,,,,,,,,,,, -92800,0.20992632,1.5155723,,,,,,,,,,,,,,,,, -92900,0.2122527,1.532328,,,,,,,,,,,,,,,,, -93000,0.24666521,1.5735178,,,,,,,,,,,,,,,,, -93100,0.23630822,1.5410973,,,,,,,,,,,,,,,,, -93200,0.20121457,1.504776,,,,,,,,,,,,,,,,, -93300,0.20555446,1.5159962,,,,,,,,,,,,,,,,, -93400,0.22130199,1.5663643,,,,,,,,,,,,,,,,, -93500,0.23812029,1.6345929,,,,,,,,,,,,,,,,, -93600,0.21294364,1.5729804,,,,,,,,,,,,,,,,, -93700,0.23253661,1.571861,,,,,,,,,,,,,,,,, -93800,0.21324474,1.5657414,,,,,,,,,,,,,,,,, -93900,0.20202029,1.6269722,,,,,,,,,,,,,,,,, -94000,0.21377176,1.528669,,,,,,,,,,,,,,,,, -94100,0.21123806,1.581787,,,,,,,,,,,,,,,,, -94200,0.20138244,1.5467004,,,,,,,,,,,,,,,,, -94300,0.2072152,1.6161848,,,,,,,,,,,,,,,,, -94400,0.21791884,1.5721866,,,,,,,,,,,,,,,,, -94500,0.21303748,1.6136532,,,,,,,,,,,,,,,,, -94600,0.21227594,1.5730714,,,,,,,,,,,,,,,,, -94700,0.21549511,1.5608174,,,,,,,,,,,,,,,,, -94800,0.21391557,1.5017903,,,,,,,,,,,,,,,,, -94900,0.2090679,1.5336573,,,,,,,,,,,,,,,,, -94985,,,0.6820225715637207,1.4462394714355469,34.13055095861968,0.6877161860466003,1.4159722328186035,30.236845114728684,3000.0,0.7035384774208069,1.316325306892395,30.436514797787588,3003.0,33634.26413846016,57375.692895412445,33634.26413846016,23737.11946773529,1.2959558963775637,0.0 -95000,0.20912871,1.5611303,,,,,,,,,,,,,,,,, -95100,0.2102826,1.5323886,,,,,,,,,,,,,,,,, -95200,0.2233149,1.6038781,,,,,,,,,,,,,,,,, -95300,0.21143068,1.5856154,,,,,,,,,,,,,,,,, -95400,0.22239079,1.4530631,,,,,,,,,,,,,,,,, -95500,0.21539629,1.6212147,,,,,,,,,,,,,,,,, -95600,0.22580965,1.5679774,,,,,,,,,,,,,,,,, -95700,0.20979561,1.5179604,,,,,,,,,,,,,,,,, -95800,0.21001989,1.5120436,,,,,,,,,,,,,,,,, -95900,0.21598706,1.5284449,,,,,,,,,,,,,,,,, -96000,0.21323292,1.577555,,,,,,,,,,,,,,,,, -96100,0.20911412,1.5648891,,,,,,,,,,,,,,,,, -96200,0.20202048,1.533723,,,,,,,,,,,,,,,,, -96300,0.21889089,1.636498,,,,,,,,,,,,,,,,, -96400,0.23605487,1.5554566,,,,,,,,,,,,,,,,, -96500,0.21343821,1.6002045,,,,,,,,,,,,,,,,, -96600,0.22498712,1.6000956,,,,,,,,,,,,,,,,, -96700,0.21197292,1.5179571,,,,,,,,,,,,,,,,, -96800,0.22522451,1.6379756,,,,,,,,,,,,,,,,, -96900,0.2079726,1.5521221,,,,,,,,,,,,,,,,, -97000,0.22286399,1.6043704,,,,,,,,,,,,,,,,, -97100,0.22093545,1.4887853,,,,,,,,,,,,,,,,, -97200,0.2083535,1.5570866,,,,,,,,,,,,,,,,, -97300,0.22424644,1.6158764,,,,,,,,,,,,,,,,, -97360,,,0.6778927445411682,1.4817765951156616,33.92635535254963,0.6891545057296753,1.410452127456665,30.186419436898863,3000.0,0.7068619132041931,1.306337833404541,30.64499714464644,3003.0,34474.2428150177,58836.77590274811,34474.2428150177,24358.108916282654,1.3352117538452148,0.0 -97400,0.21879539,1.5272276,,,,,,,,,,,,,,,,, -97500,0.21999548,1.6309928,,,,,,,,,,,,,,,,, -97600,0.20565972,1.5421184,,,,,,,,,,,,,,,,, -97700,0.205242,1.5488155,,,,,,,,,,,,,,,,, -97800,0.21491079,1.5076785,,,,,,,,,,,,,,,,, -97900,0.20358132,1.5882028,,,,,,,,,,,,,,,,, -98000,0.2339558,1.6001902,,,,,,,,,,,,,,,,, -98100,0.20665793,1.5190386,,,,,,,,,,,,,,,,, -98200,0.22575502,1.5880035,,,,,,,,,,,,,,,,, -98300,0.21365963,1.5663086,,,,,,,,,,,,,,,,, -98400,0.22110365,1.5466104,,,,,,,,,,,,,,,,, -98500,0.21055865,1.5136302,,,,,,,,,,,,,,,,, -98600,0.21123888,1.5412352,,,,,,,,,,,,,,,,, -98700,0.22117756,1.4749552,,,,,,,,,,,,,,,,, -98800,0.22150911,1.6368711,,,,,,,,,,,,,,,,, -98900,0.20449378,1.5152402,,,,,,,,,,,,,,,,, -99000,0.21927844,1.5368431,,,,,,,,,,,,,,,,, -99100,0.21494214,1.6030723,,,,,,,,,,,,,,,,, -99200,0.20204504,1.4905499,,,,,,,,,,,,,,,,, -99300,0.22106858,1.5606734,,,,,,,,,,,,,,,,, -99400,0.20978513,1.5281541,,,,,,,,,,,,,,,,, -99500,0.23567207,1.5569927,,,,,,,,,,,,,,,,, -99600,0.21067426,1.5014356,,,,,,,,,,,,,,,,, -99700,0.20708413,1.5385245,,,,,,,,,,,,,,,,, -99735,,,0.6772194504737854,1.4831637144088743,33.94047766861321,0.6899604201316833,1.4045439958572388,30.348874012448068,3000.0,0.7069200277328491,1.2959305047988892,30.6400977790549,3003.0,35314.15435934067,60202.79884767532,35314.15435934067,24884.105031967163,1.375361442565918,0.0 -99800,0.21485938,1.5337746,,,,,,,,,,,,,,,,, -99900,0.22014503,1.5974896,,,,,,,,,,,,,,,,, -100000,0.21802403,1.6037368,,,,,,,,,,,,,,,,, -100100,0.20477118,1.5384231,,,,,,,,,,,,,,,,, -100200,0.2290707,1.5292233,,,,,,,,,,,,,,,,, -100300,0.20393753,1.5255439,,,,,,,,,,,,,,,,, -100400,0.20834942,1.54757,,,,,,,,,,,,,,,,, -100500,0.20949955,1.5808895,,,,,,,,,,,,,,,,, -100600,0.22696933,1.5370125,,,,,,,,,,,,,,,,, -100700,0.23424076,1.5784725,,,,,,,,,,,,,,,,, -100800,0.2363013,1.4892198,,,,,,,,,,,,,,,,, -100900,0.21340272,1.503071,,,,,,,,,,,,,,,,, -101000,0.2211179,1.5480351,,,,,,,,,,,,,,,,, -101100,0.24396612,1.4902815,,,,,,,,,,,,,,,,, -101200,0.22241943,1.5711448,,,,,,,,,,,,,,,,, -101300,0.22882256,1.6251245,,,,,,,,,,,,,,,,, -101400,0.21215029,1.5309899,,,,,,,,,,,,,,,,, -101500,0.24884357,1.595518,,,,,,,,,,,,,,,,, -101600,0.23729764,1.6430923,,,,,,,,,,,,,,,,, -101700,0.20997511,1.5447831,,,,,,,,,,,,,,,,, -101800,0.21024267,1.564588,,,,,,,,,,,,,,,,, -101900,0.2242123,1.5646758,,,,,,,,,,,,,,,,, -102000,0.22325127,1.5187793,,,,,,,,,,,,,,,,, -102100,0.21802354,1.4560721,,,,,,,,,,,,,,,,, -102110,,,0.6823187470436096,1.446167230606079,34.765234065491775,0.6899108290672302,1.3995331525802612,30.40090160189428,3000.0,0.7071989178657532,1.2938815355300903,30.771823174514257,3003.0,36154.0476911068,61602.90172743797,36154.0476911068,25444.19984698296,1.414492130279541,0.0 -102200,0.2248848,1.532809,,,,,,,,,,,,,,,,, -102300,0.22415885,1.5229851,,,,,,,,,,,,,,,,, -102400,0.220156,1.465581,,,,,,,,,,,,,,,,, -102500,0.21363232,1.5751736,,,,,,,,,,,,,,,,, -102600,0.21365714,1.5393955,,,,,,,,,,,,,,,,, -102700,0.227584,1.6438726,,,,,,,,,,,,,,,,, -102800,0.23030595,1.5939729,,,,,,,,,,,,,,,,, -102900,0.20909189,1.5037811,,,,,,,,,,,,,,,,, -103000,0.22092704,1.5719839,,,,,,,,,,,,,,,,, -103100,0.2489921,1.5803962,,,,,,,,,,,,,,,,, -103200,0.2187405,1.5186183,,,,,,,,,,,,,,,,, -103300,0.22720253,1.5681765,,,,,,,,,,,,,,,,, -103400,0.22468537,1.5337297,,,,,,,,,,,,,,,,, -103500,0.2343507,1.5146357,,,,,,,,,,,,,,,,, -103600,0.2226302,1.5477296,,,,,,,,,,,,,,,,, -103700,0.23433875,1.5373044,,,,,,,,,,,,,,,,, -103800,0.20720214,1.4702877,,,,,,,,,,,,,,,,, -103900,0.2067389,1.5603954,,,,,,,,,,,,,,,,, -104000,0.21422042,1.5109396,,,,,,,,,,,,,,,,, -104100,0.22630747,1.562703,,,,,,,,,,,,,,,,, -104200,0.22051312,1.5340763,,,,,,,,,,,,,,,,, -104300,0.22041121,1.5563545,,,,,,,,,,,,,,,,, -104400,0.22453453,1.5509416,,,,,,,,,,,,,,,,, -104486,,,0.6826587319374084,1.4512836933135986,34.46536235091862,0.6910763382911682,1.3928955793380735,30.427187187595266,3000.0,0.70707106590271,1.2888002395629885,30.53443950681672,3003.0,36994.13355565071,63073.723007678986,36994.13355565071,26074.819528579712,1.4544637203216553,0.0 -104500,0.22085723,1.6031489,,,,,,,,,,,,,,,,, -104600,0.21527077,1.4886822,,,,,,,,,,,,,,,,, -104700,0.21025635,1.5445764,,,,,,,,,,,,,,,,, -104800,0.22375186,1.5329198,,,,,,,,,,,,,,,,, -104900,0.225677,1.5253141,,,,,,,,,,,,,,,,, -105000,0.22073898,1.5329825,,,,,,,,,,,,,,,,, -105100,0.22747856,1.5011908,,,,,,,,,,,,,,,,, -105200,0.22843426,1.570435,,,,,,,,,,,,,,,,, -105300,0.21243253,1.5429802,,,,,,,,,,,,,,,,, -105400,0.21011302,1.4892704,,,,,,,,,,,,,,,,, -105500,0.22496922,1.5085702,,,,,,,,,,,,,,,,, -105600,0.21975636,1.5238053,,,,,,,,,,,,,,,,, -105700,0.21525688,1.5089208,,,,,,,,,,,,,,,,, -105800,0.21611352,1.5216432,,,,,,,,,,,,,,,,, -105900,0.22296503,1.4875335,,,,,,,,,,,,,,,,, -106000,0.21247748,1.4432484,,,,,,,,,,,,,,,,, -106100,0.22612108,1.478536,,,,,,,,,,,,,,,,, -106200,0.23668818,1.5627568,,,,,,,,,,,,,,,,, -106300,0.21914005,1.5688882,,,,,,,,,,,,,,,,, -106400,0.23297262,1.5364461,,,,,,,,,,,,,,,,, -106500,0.21735452,1.496486,,,,,,,,,,,,,,,,, -106600,0.22879183,1.5483912,,,,,,,,,,,,,,,,, -106700,0.22594856,1.5625421,,,,,,,,,,,,,,,,, -106800,0.22626278,1.5231615,,,,,,,,,,,,,,,,, -106862,,,0.6929945349693298,1.3875455856323242,35.6936682202158,0.6926262378692627,1.3891054391860962,30.321860371372377,3000.0,0.7095810770988464,1.2826364040374756,30.553318604965003,3003.0,37834.16900777817,64617.951558828354,37834.16900777817,26778.89575409889,1.4950628280639648,0.0 -106900,0.21270502,1.4772094,,,,,,,,,,,,,,,,, -107000,0.22280723,1.5799221,,,,,,,,,,,,,,,,, -107100,0.23055279,1.5429451,,,,,,,,,,,,,,,,, -107200,0.23643176,1.5511351,,,,,,,,,,,,,,,,, -107300,0.23372348,1.5081596,,,,,,,,,,,,,,,,, -107400,0.21706215,1.5207584,,,,,,,,,,,,,,,,, -107500,0.24319048,1.4385192,,,,,,,,,,,,,,,,, -107600,0.23742345,1.4684579,,,,,,,,,,,,,,,,, -107700,0.23897049,1.4695628,,,,,,,,,,,,,,,,, -107800,0.2240567,1.5059211,,,,,,,,,,,,,,,,, -107900,0.22800533,1.4371156,,,,,,,,,,,,,,,,, -108000,0.21872014,1.5182657,,,,,,,,,,,,,,,,, -108100,0.22884123,1.5052173,,,,,,,,,,,,,,,,, -108200,0.24533947,1.5112294,,,,,,,,,,,,,,,,, -108300,0.23622715,1.5624533,,,,,,,,,,,,,,,,, -108400,0.21868657,1.4797589,,,,,,,,,,,,,,,,, -108500,0.2352046,1.4978644,,,,,,,,,,,,,,,,, -108600,0.22789949,1.5370225,,,,,,,,,,,,,,,,, -108700,0.23613112,1.5397956,,,,,,,,,,,,,,,,, -108800,0.22411314,1.4760917,,,,,,,,,,,,,,,,, -108900,0.23081148,1.4621717,,,,,,,,,,,,,,,,, -109000,0.22886574,1.5259354,,,,,,,,,,,,,,,,, -109100,0.22602783,1.463088,,,,,,,,,,,,,,,,, -109200,0.22463836,1.5353688,,,,,,,,,,,,,,,,, -109238,,,0.6881809234619141,1.4110008478164673,34.792418856647195,0.6915723085403442,1.3885592222213743,30.53146184710506,3000.0,0.7090697884559631,1.2797530889511108,30.7142415724196,3003.0,38674.22744607925,66075.62375807762,38674.22744607925,27396.39172673225,1.5361199378967283,0.0 -109300,0.2368181,1.4758295,,,,,,,,,,,,,,,,, -109400,0.24121143,1.4782829,,,,,,,,,,,,,,,,, -109500,0.22420138,1.5694478,,,,,,,,,,,,,,,,, -109600,0.24015272,1.4368696,,,,,,,,,,,,,,,,, -109700,0.21583484,1.5209315,,,,,,,,,,,,,,,,, -109800,0.2248352,1.4224749,,,,,,,,,,,,,,,,, -109900,0.22985357,1.5395634,,,,,,,,,,,,,,,,, -110000,0.22871946,1.5019983,,,,,,,,,,,,,,,,, -110100,0.21700893,1.5136226,,,,,,,,,,,,,,,,, -110200,0.22614577,1.537656,,,,,,,,,,,,,,,,, -110300,0.23142692,1.4903634,,,,,,,,,,,,,,,,, -110400,0.23291512,1.6016141,,,,,,,,,,,,,,,,, -110500,0.22315541,1.4906435,,,,,,,,,,,,,,,,, -110600,0.220659,1.4437604,,,,,,,,,,,,,,,,, -110700,0.23724653,1.5572761,,,,,,,,,,,,,,,,, -110800,0.2219271,1.5550491,,,,,,,,,,,,,,,,, -110900,0.22527215,1.5751088,,,,,,,,,,,,,,,,, -111000,0.2292452,1.502732,,,,,,,,,,,,,,,,, -111100,0.22498533,1.5269064,,,,,,,,,,,,,,,,, -111200,0.2190849,1.4518466,,,,,,,,,,,,,,,,, -111300,0.22653608,1.5125943,,,,,,,,,,,,,,,,, -111400,0.24552868,1.5441856,,,,,,,,,,,,,,,,, -111500,0.2270817,1.540389,,,,,,,,,,,,,,,,, -111600,0.2345127,1.5147777,,,,,,,,,,,,,,,,, -111614,,,0.6850574612617493,1.43190598487854,34.64420152808263,0.6938413381576538,1.3793723583221436,30.713396510455496,3000.0,0.7106850147247314,1.2727055549621582,30.786169768048005,3003.0,39514.40298914909,67453.68859291077,39514.40298914909,27934.165376663208,1.5760877132415771,0.0 -111700,0.23454387,1.4879718,,,,,,,,,,,,,,,,, -111800,0.24056111,1.479489,,,,,,,,,,,,,,,,, -111900,0.23984505,1.5563744,,,,,,,,,,,,,,,,, -112000,0.24234095,1.5428455,,,,,,,,,,,,,,,,, -112100,0.23524648,1.4577266,,,,,,,,,,,,,,,,, -112200,0.22645852,1.4622524,,,,,,,,,,,,,,,,, -112300,0.2318876,1.4712329,,,,,,,,,,,,,,,,, -112400,0.2389289,1.4655012,,,,,,,,,,,,,,,,, -112500,0.22907944,1.4982636,,,,,,,,,,,,,,,,, -112600,0.24337588,1.5129879,,,,,,,,,,,,,,,,, -112700,0.2350726,1.433693,,,,,,,,,,,,,,,,, -112800,0.23317161,1.4775375,,,,,,,,,,,,,,,,, -112900,0.22549078,1.4720277,,,,,,,,,,,,,,,,, -113000,0.21173309,1.3968658,,,,,,,,,,,,,,,,, -113100,0.22267935,1.4107472,,,,,,,,,,,,,,,,, -113200,0.21912941,1.4030317,,,,,,,,,,,,,,,,, -113300,0.21907501,1.5097171,,,,,,,,,,,,,,,,, -113400,0.25396067,1.4716207,,,,,,,,,,,,,,,,, -113500,0.23425134,1.4831678,,,,,,,,,,,,,,,,, -113600,0.24180783,1.4678628,,,,,,,,,,,,,,,,, -113700,0.22946908,1.4779541,,,,,,,,,,,,,,,,, -113800,0.22687238,1.4512924,,,,,,,,,,,,,,,,, -113900,0.23009266,1.4357687,,,,,,,,,,,,,,,,, -113990,,,0.6922797560691833,1.3891639709472656,35.342772251239296,0.6943125128746033,1.3756449222564695,30.757281576648342,3000.0,0.711661159992218,1.2692705392837524,31.040662341172386,3003.0,40354.49650359154,68950.13110136986,40354.49650359154,28590.395575761795,1.6191487312316897,0.0 -114000,0.23072548,1.4737753,,,,,,,,,,,,,,,,, -114100,0.23478986,1.4724963,,,,,,,,,,,,,,,,, -114200,0.22874495,1.5043801,,,,,,,,,,,,,,,,, -114300,0.23260197,1.4945892,,,,,,,,,,,,,,,,, -114400,0.24063691,1.4826508,,,,,,,,,,,,,,,,, -114500,0.23817267,1.5641836,,,,,,,,,,,,,,,,, -114600,0.2594626,1.4653126,,,,,,,,,,,,,,,,, -114700,0.23787247,1.428979,,,,,,,,,,,,,,,,, -114800,0.22849356,1.4559004,,,,,,,,,,,,,,,,, -114900,0.22444879,1.4722581,,,,,,,,,,,,,,,,, -115000,0.22746581,1.4437288,,,,,,,,,,,,,,,,, -115100,0.2203342,1.4341446,,,,,,,,,,,,,,,,, -115200,0.23104624,1.5037037,,,,,,,,,,,,,,,,, -115300,0.2354961,1.4998858,,,,,,,,,,,,,,,,, -115400,0.23137887,1.5106622,,,,,,,,,,,,,,,,, -115500,0.23236753,1.5408328,,,,,,,,,,,,,,,,, -115600,0.22182924,1.5033373,,,,,,,,,,,,,,,,, -115700,0.24791135,1.5160499,,,,,,,,,,,,,,,,, -115800,0.24372135,1.4790183,,,,,,,,,,,,,,,,, -115900,0.21623676,1.4801937,,,,,,,,,,,,,,,,, -116000,0.24856679,1.49042,,,,,,,,,,,,,,,,, -116100,0.22648607,1.4529628,,,,,,,,,,,,,,,,, -116200,0.22860017,1.4628382,,,,,,,,,,,,,,,,, -116300,0.22460346,1.4655668,,,,,,,,,,,,,,,,, -116366,,,0.6906039714813232,1.3968476057052612,34.79750781778418,0.6947464942932129,1.373445987701416,30.66973494550261,3000.0,0.7128929495811462,1.264074683189392,31.160097992334794,3003.0,41194.70844531059,70359.9132270813,41194.70844531059,29159.84848332405,1.660917043685913,0.0 -116400,0.22928952,1.3889053,,,,,,,,,,,,,,,,, -116500,0.23701337,1.4782747,,,,,,,,,,,,,,,,, -116600,0.23830086,1.4925733,,,,,,,,,,,,,,,,, -116700,0.21794777,1.4148412,,,,,,,,,,,,,,,,, -116800,0.22594926,1.472526,,,,,,,,,,,,,,,,, -116900,0.24255538,1.4685444,,,,,,,,,,,,,,,,, -117000,0.23530027,1.5156893,,,,,,,,,,,,,,,,, -117100,0.24237671,1.4920303,,,,,,,,,,,,,,,,, -117200,0.23951812,1.4796351,,,,,,,,,,,,,,,,, -117300,0.22998208,1.5086532,,,,,,,,,,,,,,,,, -117400,0.232416,1.5174235,,,,,,,,,,,,,,,,, -117500,0.23206979,1.5386992,,,,,,,,,,,,,,,,, -117600,0.22969232,1.4121423,,,,,,,,,,,,,,,,, -117700,0.22344726,1.4285104,,,,,,,,,,,,,,,,, -117800,0.23771381,1.4603449,,,,,,,,,,,,,,,,, -117900,0.22597924,1.4585851,,,,,,,,,,,,,,,,, -118000,0.23197687,1.3812469,,,,,,,,,,,,,,,,, -118100,0.23653898,1.484843,,,,,,,,,,,,,,,,, -118200,0.22654662,1.4343821,,,,,,,,,,,,,,,,, -118300,0.23169473,1.4960365,,,,,,,,,,,,,,,,, -118400,0.23559678,1.4708759,,,,,,,,,,,,,,,,, -118500,0.23274688,1.4827685,,,,,,,,,,,,,,,,, -118600,0.2355088,1.4376442,,,,,,,,,,,,,,,,, -118700,0.23790753,1.5368991,,,,,,,,,,,,,,,,, -118741,,,0.6908093690872192,1.4008818864822388,34.89803540192062,0.6953044533729553,1.3693207502365112,30.95589799049646,3000.0,0.7127767205238342,1.262351393699646,30.7693335450808,3003.0,42034.721598386765,71731.30040001869,42034.721598386765,29691.103921175003,1.7029705047607422,0.0 -118800,0.23280537,1.4388298,,,,,,,,,,,,,,,,, -118900,0.22779551,1.4209529,,,,,,,,,,,,,,,,, -119000,0.25089246,1.5381119,,,,,,,,,,,,,,,,, -119100,0.23475374,1.4427477,,,,,,,,,,,,,,,,, -119200,0.23534425,1.5710305,,,,,,,,,,,,,,,,, -119300,0.22439468,1.444782,,,,,,,,,,,,,,,,, -119400,0.23369838,1.5005741,,,,,,,,,,,,,,,,, -119500,0.23981956,1.4854192,,,,,,,,,,,,,,,,, -119600,0.22761057,1.4793994,,,,,,,,,,,,,,,,, -119700,0.24420561,1.4889317,,,,,,,,,,,,,,,,, -119800,0.21672864,1.4136206,,,,,,,,,,,,,,,,, -119900,0.24498737,1.482881,,,,,,,,,,,,,,,,, -120000,0.24266961,1.4505883,,,,,,,,,,,,,,,,, -120100,0.23963355,1.4903926,,,,,,,,,,,,,,,,, -120200,0.22958063,1.4390981,,,,,,,,,,,,,,,,, -120300,0.23953176,1.5194613,,,,,,,,,,,,,,,,, -120400,0.24120946,1.515943,,,,,,,,,,,,,,,,, -120500,0.23030603,1.4784597,,,,,,,,,,,,,,,,, -120600,0.22958013,1.5373911,,,,,,,,,,,,,,,,, -120700,0.22692768,1.4381952,,,,,,,,,,,,,,,,, -120800,0.23669794,1.4544367,,,,,,,,,,,,,,,,, -120900,0.23090348,1.4798973,,,,,,,,,,,,,,,,, -121000,0.23006299,1.408007,,,,,,,,,,,,,,,,, -121100,0.23294328,1.437326,,,,,,,,,,,,,,,,, -121116,,,0.6932882070541382,1.388006567955017,35.52039251433238,0.6962592005729675,1.367598533630371,30.874046124594287,3000.0,0.7138457894325256,1.2584335803985596,31.12364558456364,3003.0,42874.60907793045,73184.37866711617,42874.60907793045,30304.17706489563,1.744511365890503,0.0 -121200,0.23403686,1.4772843,,,,,,,,,,,,,,,,, -121300,0.22963221,1.3859414,,,,,,,,,,,,,,,,, -121400,0.23969375,1.4452497,,,,,,,,,,,,,,,,, -121500,0.24381658,1.455787,,,,,,,,,,,,,,,,, -121600,0.24243654,1.3681773,,,,,,,,,,,,,,,,, -121700,0.23279051,1.4822878,,,,,,,,,,,,,,,,, -121800,0.23478614,1.5367042,,,,,,,,,,,,,,,,, -121900,0.23833871,1.5013423,,,,,,,,,,,,,,,,, -122000,0.22894964,1.428442,,,,,,,,,,,,,,,,, -122100,0.25131503,1.3818283,,,,,,,,,,,,,,,,, -122200,0.2283712,1.4903326,,,,,,,,,,,,,,,,, -122300,0.23256995,1.4277169,,,,,,,,,,,,,,,,, -122400,0.23322825,1.4172146,,,,,,,,,,,,,,,,, -122500,0.22315747,1.4103619,,,,,,,,,,,,,,,,, -122600,0.23091947,1.4491501,,,,,,,,,,,,,,,,, -122700,0.24076734,1.4751886,,,,,,,,,,,,,,,,, -122800,0.23421146,1.4990363,,,,,,,,,,,,,,,,, -122900,0.23418397,1.4341915,,,,,,,,,,,,,,,,, -123000,0.25089988,1.5152457,,,,,,,,,,,,,,,,, -123100,0.22602823,1.4643435,,,,,,,,,,,,,,,,, -123200,0.2340857,1.4018788,,,,,,,,,,,,,,,,, -123300,0.2380161,1.4639369,,,,,,,,,,,,,,,,, -123400,0.24120778,1.463317,,,,,,,,,,,,,,,,, -123491,,,0.6912440061569214,1.395629644393921,35.35323909995653,0.6952672600746155,1.367558479309082,30.84424959213709,3000.0,0.7139503955841064,1.255860447883606,31.129922784276456,3003.0,43714.75734376907,74557.78880643845,43714.75734376907,30837.32000207901,1.7871100902557373,0.0 -123500,0.23635733,1.4239914,,,,,,,,,,,,,,,,, -123600,0.23993199,1.5167584,,,,,,,,,,,,,,,,, -123700,0.22909969,1.4529717,,,,,,,,,,,,,,,,, -123800,0.23331732,1.4336741,,,,,,,,,,,,,,,,, -123900,0.23370172,1.4568788,,,,,,,,,,,,,,,,, -124000,0.23361446,1.389798,,,,,,,,,,,,,,,,, -124100,0.24035525,1.4153373,,,,,,,,,,,,,,,,, -124200,0.22871095,1.4481094,,,,,,,,,,,,,,,,, -124300,0.24783316,1.4071523,,,,,,,,,,,,,,,,, -124400,0.23956199,1.472849,,,,,,,,,,,,,,,,, -124500,0.2248767,1.4916364,,,,,,,,,,,,,,,,, -124600,0.233402,1.4708025,,,,,,,,,,,,,,,,, -124700,0.2337604,1.4564989,,,,,,,,,,,,,,,,, -124800,0.2316995,1.4822106,,,,,,,,,,,,,,,,, -124900,0.24186896,1.47486,,,,,,,,,,,,,,,,, -125000,0.22958502,1.4681343,,,,,,,,,,,,,,,,, -125100,0.22944002,1.4459856,,,,,,,,,,,,,,,,, -125200,0.2297099,1.4087336,,,,,,,,,,,,,,,,, -125300,0.2320502,1.4451786,,,,,,,,,,,,,,,,, -125400,0.23931164,1.5317903,,,,,,,,,,,,,,,,, -125500,0.23679954,1.4652517,,,,,,,,,,,,,,,,, -125600,0.24227278,1.4995995,,,,,,,,,,,,,,,,, -125700,0.22849923,1.4313321,,,,,,,,,,,,,,,,, -125800,0.23384948,1.4867324,,,,,,,,,,,,,,,,, -125866,,,0.697185218334198,1.3619694709777832,35.78074525428276,0.695651650428772,1.3671724796295166,30.911010119255785,3000.0,0.7137179970741272,1.25685453414917,30.93043914065589,3003.0,44554.68297743797,75943.4162247181,44554.68297743797,31382.90145254135,1.8318068981170648,0.0 -125900,0.2355878,1.4625193,,,,,,,,,,,,,,,,, -126000,0.22801276,1.4635519,,,,,,,,,,,,,,,,, -126100,0.2310119,1.488176,,,,,,,,,,,,,,,,, -126200,0.22850046,1.4115204,,,,,,,,,,,,,,,,, -126300,0.23836926,1.5530021,,,,,,,,,,,,,,,,, -126400,0.2296587,1.4778157,,,,,,,,,,,,,,,,, -126500,0.22964862,1.445697,,,,,,,,,,,,,,,,, -126600,0.23677453,1.3901068,,,,,,,,,,,,,,,,, -126700,0.23163255,1.3680316,,,,,,,,,,,,,,,,, -126800,0.24034908,1.449802,,,,,,,,,,,,,,,,, -126900,0.22309226,1.3948997,,,,,,,,,,,,,,,,, -127000,0.23015736,1.3894391,,,,,,,,,,,,,,,,, -127100,0.3460975,1.4203942,,,,,,,,,,,,,,,,, -127200,0.22981542,1.4105501,,,,,,,,,,,,,,,,, -127300,0.22919463,1.4670628,,,,,,,,,,,,,,,,, -127400,0.23780234,1.3855218,,,,,,,,,,,,,,,,, -127500,0.22447994,1.5039126,,,,,,,,,,,,,,,,, -127600,0.22452517,1.4550685,,,,,,,,,,,,,,,,, -127700,0.2528277,1.4636194,,,,,,,,,,,,,,,,, -127800,0.23341288,1.4286729,,,,,,,,,,,,,,,,, -127900,0.21914053,1.4166261,,,,,,,,,,,,,,,,, -128000,0.23174778,1.4642892,,,,,,,,,,,,,,,,, -128100,0.23524967,1.4950726,,,,,,,,,,,,,,,,, -128200,0.2339735,1.4763905,,,,,,,,,,,,,,,,, -128241,,,0.6922347545623779,1.3835301399230957,35.30423628524667,0.6954284310340881,1.3652565479278564,30.88262735410851,3000.0,0.7137528657913208,1.2547988891601562,31.12628625505224,3003.0,45394.836584329605,77351.26991295815,45394.836584329605,31950.483147382736,1.8745992183685305,0.0 -128300,0.23335552,1.5177398,,,,,,,,,,,,,,,,, -128400,0.22370462,1.4356182,,,,,,,,,,,,,,,,, -128500,0.23285869,1.4542222,,,,,,,,,,,,,,,,, -128600,0.23168124,1.487445,,,,,,,,,,,,,,,,, -128700,0.24177942,1.4679247,,,,,,,,,,,,,,,,, -128800,0.24795754,1.4867035,,,,,,,,,,,,,,,,, -128900,0.2308358,1.465697,,,,,,,,,,,,,,,,, -129000,0.2510211,1.473422,,,,,,,,,,,,,,,,, -129100,0.22742431,1.4035261,,,,,,,,,,,,,,,,, -129200,0.2421465,1.4664553,,,,,,,,,,,,,,,,, -129300,0.23283486,1.4523795,,,,,,,,,,,,,,,,, -129400,0.2245712,1.4248419,,,,,,,,,,,,,,,,, -129500,0.23460078,1.4125055,,,,,,,,,,,,,,,,, -129600,0.22977403,1.4226009,,,,,,,,,,,,,,,,, -129700,0.22976688,1.4686148,,,,,,,,,,,,,,,,, -129800,0.23528059,1.4086123,,,,,,,,,,,,,,,,, -129900,0.23183554,1.507329,,,,,,,,,,,,,,,,, -130000,0.23362906,1.3950436,,,,,,,,,,,,,,,,, -130100,0.23375292,1.4749475,,,,,,,,,,,,,,,,, -130200,0.23205328,1.481111,,,,,,,,,,,,,,,,, -130300,0.24580078,1.5186862,,,,,,,,,,,,,,,,, -130400,0.23040174,1.4712805,,,,,,,,,,,,,,,,, -130500,0.25009522,1.4644487,,,,,,,,,,,,,,,,, -130600,0.23969628,1.4556571,,,,,,,,,,,,,,,,, -130616,,,0.6963204741477966,1.366992473602295,35.57011387339867,0.6954284310340881,1.3648079633712769,30.946973319770205,3000.0,0.7141479253768921,1.2541664838790894,31.103261886181247,3003.0,46234.92179393768,78754.49300599098,46234.92179393768,32513.490881443024,1.928779125213623,0.0 -130700,0.23246095,1.4404415,,,,,,,,,,,,,,,,, -130800,0.23409665,1.4994613,,,,,,,,,,,,,,,,, -130900,0.22678407,1.4215543,,,,,,,,,,,,,,,,, -131000,0.22742328,1.4047604,,,,,,,,,,,,,,,,, -131100,0.23566228,1.4914384,,,,,,,,,,,,,,,,, -131200,0.24230632,1.5053831,,,,,,,,,,,,,,,,, -131300,0.23255044,1.5218927,,,,,,,,,,,,,,,,, -131400,0.22710426,1.4457147,,,,,,,,,,,,,,,,, -131500,0.22621295,1.4426041,,,,,,,,,,,,,,,,, -131600,0.22484125,1.4517953,,,,,,,,,,,,,,,,, -131700,0.24254905,1.4661686,,,,,,,,,,,,,,,,, -131800,0.23381387,1.4755819,,,,,,,,,,,,,,,,, -131900,0.22802633,1.4022062,,,,,,,,,,,,,,,,, -132000,0.23582302,1.4961649,,,,,,,,,,,,,,,,, -132100,0.23321167,1.3860598,,,,,,,,,,,,,,,,, -132200,0.22735272,1.4623649,,,,,,,,,,,,,,,,, -132300,0.23870906,1.4578326,,,,,,,,,,,,,,,,, -132400,0.24156363,1.5072181,,,,,,,,,,,,,,,,, -132500,0.2198874,1.4673991,,,,,,,,,,,,,,,,, -132600,0.23394045,1.4333926,,,,,,,,,,,,,,,,, -132700,0.22914927,1.4704363,,,,,,,,,,,,,,,,, -132800,0.22864145,1.5156432,,,,,,,,,,,,,,,,, -132900,0.23740979,1.535919,,,,,,,,,,,,,,,,, -132991,,,0.69627445936203,1.36464262008667,35.58865724493143,0.6954532265663147,1.3649661540985107,31.040501645749234,3000.0,0.7140085101127625,1.2541321516036987,31.14580943345989,3003.0,47074.85367107391,80159.94454312325,47074.85367107391,33078.891793727875,1.9719855785369875,0.0 -133000,0.22903764,1.4892503,,,,,,,,,,,,,,,,, -133100,0.230999,1.4192283,,,,,,,,,,,,,,,,, -133200,0.23541845,1.4475508,,,,,,,,,,,,,,,,, -133300,0.23296061,1.5276902,,,,,,,,,,,,,,,,, -133333,,,0.6945627331733704,1.374417424201965,35.631271648808266,0.6954532265663147,1.3649758100509644,31.040501645749234,3000.0,0.7140317559242249,1.254139423370361,31.13090825160175,3003.0,47195.332033634186,80834.32544207573,47195.332033634186,33632.739612579346,2.015644073486328,0.0 -133333,,,,,,,,,,,,,,47195.332033634186,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 9b50eacdb..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,53 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -789.0446937084198,0.0,18.817258834838867,1,0,18.817258834838867,1.2045833081414474,95000000,807.8619921207428,1.2087496365391233,1.2014791517681729,83274637 -1451.193811416626,0.027919054031372,139.51362490653992,183,0,139.51362490653992,0.1314888969366776,95000000,1590.7417414188385,0.1278515188433464,0.1287038725557954,83274637 -2076.362182378769,0.0564866065979003,260.11498379707336,357,0,260.11498379707336,0.1297014794202302,95000000,2336.545876264572,0.1252776052240493,0.1271161402895473,83274637 -2692.7577061653137,0.0769233703613281,380.83020520210266,538,0,380.83020520210266,0.1291569885999177,95000000,3073.683103084564,0.1250187005824263,0.1265432106493676,83274637 -3304.3442330360413,0.0988883972167968,501.5063283443451,715,0,501.5063283443451,0.1287863612253289,95000000,3805.973586559296,0.1250134759096417,0.1261492000826922,83274637 -3894.105229139328,0.1215641498565673,621.5448927879333,888,0,621.5448927879333,0.1285003619757401,95000000,4515.801623344421,0.1256469891599889,0.1259796613389358,83274637 -4475.561370134354,0.1439273357391357,741.6070051193237,1068,0,741.6070051193237,0.128525582863898,95000000,5217.348087310791,0.1234735286320155,0.1260174835319587,83274637 -5027.531148433685,0.1646909713745117,861.6804373264313,1249,0,861.6804373264313,0.1281114284642269,95000000,5889.418105363846,0.1227490286102647,0.1257213830206151,83274637 -5580.225698709488,0.185643196105957,982.4170339107512,1426,0,982.4170339107512,0.1279454642475329,95000000,6562.875961303711,0.1246953570289806,0.1255922103046157,83274637 -6070.445714712143,0.2074365615844726,1102.5696003437042,1604,0,1102.5696003437042,0.1279644877569901,95000000,7173.276115179062,0.1235765063982902,0.1255172006136679,83274637 -6414.903634786606,0.2290723323822021,1222.6450836658478,1776,0,1222.6450836658478,0.1282788197060033,95000000,7637.837061882019,0.1246517481247209,0.1257625124532684,83274637 -6907.844863176346,0.2506551742553711,1343.1871309280396,1954,0,1343.1871309280396,0.127675164113898,95000000,8251.347687005997,0.124695456428629,0.1251498707596344,83274637 -7426.209614753723,0.2737016677856445,1463.1581037044523,2144,0,1463.1581037044523,0.127443936091694,95000000,8889.712707042694,0.1235761223653764,0.1250831759643958,83274637 -7964.530359506607,0.2965288162231445,1583.281406402588,2331,0,1583.281406402588,0.1275051309313322,95000000,9548.18579506874,0.125252609439897,0.1250847223641586,83274637 -8481.885344982147,0.3175611495971679,1703.615050792694,2507,0,1703.615050792694,0.1274699079975328,95000000,10185.901446580889,0.1267426652048931,0.1250378364413401,83274637 -8990.940734148026,0.3392865657806396,1824.0671820640564,2684,0,1824.0671820640564,0.1276610174856085,95000000,10815.436591625214,0.1232916171422357,0.1251517855223715,83274637 -9534.269513607023,0.3602786064147949,1944.574628829956,2863,0,1944.574628829956,0.1272024709703947,95000000,11479.299566030502,0.1233500298932663,0.124826699454252,83274637 -10075.8374915123,0.3829615116119385,2064.736346960068,3045,0,2064.736346960068,0.1274849592619243,95000000,12141.058000802994,0.1245352885094266,0.1251230068300519,83274637 -10621.36034655571,0.4073977470397949,2184.739155292511,3220,0,2184.739155292511,0.1271892849712171,95000000,12806.61373114586,0.1227000882532799,0.124808083846143,83274637 -11172.7266330719,0.4294459819793701,2304.8250658512115,3393,0,2304.8250658512115,0.1271991233758223,95000000,13478.09382390976,0.1252406428937079,0.1247541848312672,83274637 -11700.181691408156,0.4539427757263183,2424.872097969055,3571,0,2424.872097969055,0.1272389699424342,95000000,14125.62613105774,0.124942989855911,0.1247091829053751,83274637 -12238.537390470505,0.4765989780426025,2545.603095293045,3746,0,2545.603095293045,0.1271622003392269,95000000,14784.741124391556,0.1245335034627772,0.1246319910406814,83274637 -12788.208154916763,0.5039494037628174,2666.4074409008026,3923,0,2666.4074409008026,0.1272743532483552,95000000,15455.249332904816,0.1217524022238809,0.1248240568018683,83274637 -13287.225280284882,0.5254969596862793,2787.113216876984,4099,0,2787.113216876984,0.1271502514083059,95000000,16075.00016617775,0.1231201131529403,0.1247593495002828,83274637 -13843.353071928024,0.547893762588501,2907.6059803962708,4274,0,2907.6059803962708,0.1273904122635691,95000000,16751.648733615875,0.1223117412659544,0.1249397528848954,83274637 -14414.588943958282,0.5692830085754395,3027.923202037812,4455,0,3027.923202037812,0.1267874390008223,95000000,17443.229073524475,0.1225913947084977,0.124421076110391,83274637 -14971.964445590973,0.594536304473877,3148.1890013217926,4628,0,3148.1890013217926,0.1268318977693256,95000000,18120.90138840676,0.1233379368629283,0.1244669025799461,83274637 -15508.046690702438,0.6166715621948242,3268.8093614578247,4814,0,3268.8093614578247,0.1267612940275493,95000000,18777.632177591324,0.1234617902188556,0.1243810264070015,83274637 -16057.181163549423,0.6385774612426758,3388.91603922844,4990,0,3388.91603922844,0.1265567877672697,95000000,19446.900912046432,0.1227653210385228,0.1242355577147866,83274637 -16596.44086575508,0.6602156162261963,3509.4399082660675,5168,0,3509.4399082660675,0.1266145437088816,95000000,20106.71199607849,0.1241660562677765,0.1242902807104288,83274637 -17153.042880296707,0.6854746341705322,3630.0084569454193,5344,0,3630.0084569454193,0.1266809336759868,95000000,20783.91361641884,0.1231663083107029,0.1243867278388549,83274637 -17695.571431159973,0.7104918956756592,3750.658250808716,5517,0,3750.658250808716,0.1268313451377467,95000000,21447.12274193764,0.1235510464523386,0.1244935922285696,83274637 -18240.172213554382,0.7323856353759766,3871.1222972869873,5699,0,3871.1222972869873,0.126768357185444,95000000,22112.21555185318,0.1242261264566917,0.1243767421960324,83274637 -18793.809163808823,0.7616417407989502,3991.6780228614807,5877,0,3991.6780228614807,0.1265490914987664,95000000,22786.44338607788,0.1226479431977436,0.1241523478241509,83274637 -19327.80193209648,0.784590482711792,4111.936642169952,6056,0,4111.936642169952,0.1266930142269736,95000000,23440.723692655563,0.1225242608185834,0.1243487263222586,83274637 -19883.262872695923,0.8101849555969238,4232.155610084534,6232,0,4232.155610084534,0.1266402905016447,95000000,24116.43520641327,0.1234508143538761,0.1242802672980467,83274637 -20429.887636899948,0.8363173007965088,4352.459222078323,6409,0,4352.459222078323,0.126642444634046,95000000,24783.39586544037,0.1205027495516733,0.1242960469867839,83274637 -20973.60512423516,0.8614339828491211,4472.683204650879,6590,0,4472.683204650879,0.1264559795435855,95000000,25447.36835289001,0.1231092098491184,0.1241006255065073,83274637 -21508.55733346939,0.8837997913360596,4593.207630872726,6770,0,4593.207630872726,0.126623035207648,95000000,26102.873340845108,0.120198328223712,0.1242077644196624,83274637 -22050.916337251663,0.9108951091766356,4713.524817228317,6946,0,4713.524817228317,0.1264218494140625,95000000,26765.58244419098,0.1179989202873511,0.1241508252712629,83274637 -22591.14922332764,0.933558702468872,4833.66826915741,7123,0,4833.66826915741,0.1263023667865954,95000000,27425.98725819588,0.1220455420607665,0.1240180520055227,83274637 -23138.42849516869,0.9559965133666992,4953.663403511047,7306,0,4953.663403511047,0.1264483639391447,95000000,28093.290339946747,0.1221627628419009,0.1240735888791122,83274637 -23669.392629384995,0.9820210933685304,5073.9851770401,7485,0,5073.9851770401,0.1263452935752467,95000000,28744.60824251175,0.1203739370260411,0.1240158122580213,83274637 -24213.64492177964,1.0047781467437744,5194.398851156235,7663,0,5194.398851156235,0.1261958093647203,95000000,29409.30269098282,0.1207796007793092,0.1238437841866733,83274637 -24753.443234682083,1.0272042751312256,5314.715612649918,7843,0,5314.715612649918,0.1261964286800986,95000000,30069.44606542588,0.1195970896936062,0.1238456539939176,83274637 -25303.39065337181,1.049952745437622,5435.273231506348,8018,0,5435.273231506348,0.1261577527138158,95000000,30739.97963070869,0.1206216625869274,0.1238476197632847,83274637 -25840.4235098362,1.0788283348083496,5555.939074039459,8198,0,5555.939074039459,0.1262986234477796,95000000,31397.71311402321,0.1217111458404446,0.123903496778799,83274637 -26393.805539608,1.1015920639038086,5676.031098842621,8374,0,5676.031098842621,0.1261533757915296,95000000,32071.21576356888,0.1197104043087119,0.1237947691900101,83274637 -26935.695637226105,1.125080108642578,5796.273926258087,8553,0,5796.273926258087,0.1261754389905427,95000000,32733.37809896469,0.1218349276426828,0.1238067552613776,83274637 -27473.209585905075,1.1481850147247314,5916.893688201904,8733,0,5916.893688201904,0.126067133645148,95000000,33391.54086971283,0.1219008069063695,0.1237535159318583,83274637 -28010.46356701851,1.1713495254516602,6036.8774428367615,8912,0,6036.8774428367615,0.1260618208778783,95000000,34048.80773925781,0.121345523768251,0.1237732839121945,83274637 -28552.538821220398,1.1950955390930176,6157.476182222366,9084,0,6157.476182222366,0.12601664622738487,95000000,34711.51103806496,0.12280092236099753,0.12373237628420816,83274637 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 572f773b7..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,145 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,11.462805,1.205889,,,,,,,,,,, -1,,,1.2087496365391233,1.2014791517681729,83274637.0,1.2045833081414474,95000000.0,18.817258834838867,807.8619921207428,18.817258834838867,789.0446937084198,0.0,0.0 -100,0.071425185,0.12019351,,,,,,,,,,, -183,,,0.1278515188433464,0.1287038725557954,83274637.0,0.1314888969366776,95000000.0,139.51362490653992,1590.7417414188385,139.51362490653992,1451.193811416626,0.027919054031372,0.0 -200,0.05924954,0.13903196,,,,,,,,,,, -300,0.06533368,0.12525702,,,,,,,,,,, -357,,,0.1252776052240493,0.1271161402895473,83274637.0,0.1297014794202302,95000000.0,260.11498379707336,2336.545876264572,260.11498379707336,2076.362182378769,0.0564866065979003,0.0 -400,0.033377655,0.12183362,,,,,,,,,,, -500,0.039142035,0.12344751,,,,,,,,,,, -538,,,0.1250187005824263,0.1265432106493676,83274637.0,0.1291569885999177,95000000.0,380.83020520210266,3073.683103084564,380.83020520210266,2692.7577061653137,0.0769233703613281,0.0 -600,0.01344431,0.12288745,,,,,,,,,,, -700,0.020722218,0.123577885,,,,,,,,,,, -715,,,0.1250134759096417,0.1261492000826922,83274637.0,0.1287863612253289,95000000.0,501.5063283443451,3805.973586559296,501.5063283443451,3304.3442330360413,0.0988883972167968,0.0 -800,0.052377407,0.12584583,,,,,,,,,,, -888,,,0.1256469891599889,0.1259796613389358,83274637.0,0.1285003619757401,95000000.0,621.5448927879333,4515.801623344421,621.5448927879333,3894.105229139328,0.1215641498565673,0.0 -900,0.019178953,0.12763691,,,,,,,,,,, -1000,0.04938807,0.13550264,,,,,,,,,,, -1068,,,0.1234735286320155,0.1260174835319587,83274637.0,0.128525582863898,95000000.0,741.6070051193237,5217.348087310791,741.6070051193237,4475.561370134354,0.1439273357391357,0.0 -1100,0.00658263,0.14204456,,,,,,,,,,, -1200,0.04677612,0.12935942,,,,,,,,,,, -1249,,,0.1227490286102647,0.1257213830206151,83274637.0,0.1281114284642269,95000000.0,861.6804373264313,5889.418105363846,861.6804373264313,5027.531148433685,0.1646909713745117,0.0 -1300,0.018106228,0.12241062,,,,,,,,,,, -1400,0.07539979,0.12941392,,,,,,,,,,, -1426,,,0.1246953570289806,0.1255922103046157,83274637.0,0.1279454642475329,95000000.0,982.4170339107512,6562.875961303711,982.4170339107512,5580.225698709488,0.185643196105957,0.0 -1500,0.004922091,0.1337175,,,,,,,,,,, -1600,0.02009989,0.13911735,,,,,,,,,,, -1604,,,0.1235765063982902,0.1255172006136679,83274637.0,0.1279644877569901,95000000.0,1102.5696003437042,7173.276115179062,1102.5696003437042,6070.445714712143,0.2074365615844726,0.0 -1700,0.088498645,0.1306015,,,,,,,,,,, -1776,,,0.1246517481247209,0.1257625124532684,83274637.0,0.1282788197060033,95000000.0,1222.6450836658478,7637.837061882019,1222.6450836658478,6414.903634786606,0.2290723323822021,0.0 -1800,0.01763718,0.117220104,,,,,,,,,,, -1900,0.031226432,0.122089274,,,,,,,,,,, -1954,,,0.124695456428629,0.1251498707596344,83274637.0,0.127675164113898,95000000.0,1343.1871309280396,8251.347687005997,1343.1871309280396,6907.844863176346,0.2506551742553711,0.0 -2000,0.013000768,0.12141036,,,,,,,,,,, -2100,0.013495149,0.12043808,,,,,,,,,,, -2144,,,0.1235761223653764,0.1250831759643958,83274637.0,0.127443936091694,95000000.0,1463.1581037044523,8889.712707042694,1463.1581037044523,7426.209614753723,0.2737016677856445,0.0 -2200,0.005287248,0.12703927,,,,,,,,,,, -2300,0.009499348,0.11720042,,,,,,,,,,, -2331,,,0.125252609439897,0.1250847223641586,83274637.0,0.1275051309313322,95000000.0,1583.281406402588,9548.18579506874,1583.281406402588,7964.530359506607,0.2965288162231445,0.0 -2400,0.0110210925,0.12535265,,,,,,,,,,, -2500,0.01284422,0.12337703,,,,,,,,,,, -2507,,,0.1267426652048931,0.1250378364413401,83274637.0,0.1274699079975328,95000000.0,1703.615050792694,10185.901446580889,1703.615050792694,8481.885344982147,0.3175611495971679,0.0 -2600,0.0075356937,0.12415421,,,,,,,,,,, -2684,,,0.1232916171422357,0.1251517855223715,83274637.0,0.1276610174856085,95000000.0,1824.0671820640564,10815.436591625214,1824.0671820640564,8990.940734148026,0.3392865657806396,0.0 -2700,0.011808872,0.119768225,,,,,,,,,,, -2800,0.008373472,0.12152786,,,,,,,,,,, -2863,,,0.1233500298932663,0.124826699454252,83274637.0,0.1272024709703947,95000000.0,1944.574628829956,11479.299566030502,1944.574628829956,9534.269513607023,0.3602786064147949,0.0 -2900,0.022984825,0.120489456,,,,,,,,,,, -3000,0.0331235,0.12072735,,,,,,,,,,, -3045,,,0.1245352885094266,0.1251230068300519,83274637.0,0.1274849592619243,95000000.0,2064.736346960068,12141.058000802994,2064.736346960068,10075.8374915123,0.3829615116119385,0.0 -3100,0.0058887578,0.1251144,,,,,,,,,,, -3200,0.016445346,0.11473449,,,,,,,,,,, -3220,,,0.1227000882532799,0.124808083846143,83274637.0,0.1271892849712171,95000000.0,2184.739155292511,12806.61373114586,2184.739155292511,10621.36034655571,0.4073977470397949,0.0 -3300,0.006767278,0.12065528,,,,,,,,,,, -3393,,,0.1252406428937079,0.1247541848312672,83274637.0,0.1271991233758223,95000000.0,2304.8250658512115,13478.09382390976,2304.8250658512115,11172.7266330719,0.4294459819793701,0.0 -3400,0.027817262,0.12773971,,,,,,,,,,, -3500,0.009975643,0.12727672,,,,,,,,,,, -3571,,,0.124942989855911,0.1247091829053751,83274637.0,0.1272389699424342,95000000.0,2424.872097969055,14125.62613105774,2424.872097969055,11700.181691408156,0.4539427757263183,0.0 -3600,0.019023668,0.1208598,,,,,,,,,,, -3700,0.008672261,0.12713227,,,,,,,,,,, -3746,,,0.1245335034627772,0.1246319910406814,83274637.0,0.1271622003392269,95000000.0,2545.603095293045,14784.741124391556,2545.603095293045,12238.537390470505,0.4765989780426025,0.0 -3800,0.019418927,0.12083593,,,,,,,,,,, -3900,0.0076151276,0.1349982,,,,,,,,,,, -3923,,,0.1217524022238809,0.1248240568018683,83274637.0,0.1272743532483552,95000000.0,2666.4074409008026,15455.249332904816,2666.4074409008026,12788.208154916763,0.5039494037628174,0.0 -4000,0.007945969,0.13214885,,,,,,,,,,, -4099,,,0.1231201131529403,0.1247593495002828,83274637.0,0.1271502514083059,95000000.0,2787.113216876984,16075.00016617775,2787.113216876984,13287.225280284882,0.5254969596862793,0.0 -4100,0.008236185,0.13203499,,,,,,,,,,, -4200,0.0054630893,0.12670656,,,,,,,,,,, -4274,,,0.1223117412659544,0.1249397528848954,83274637.0,0.1273904122635691,95000000.0,2907.6059803962708,16751.648733615875,2907.6059803962708,13843.353071928024,0.547893762588501,0.0 -4300,0.009874275,0.11865268,,,,,,,,,,, -4400,0.0069060577,0.13006435,,,,,,,,,,, -4455,,,0.1225913947084977,0.124421076110391,83274637.0,0.1267874390008223,95000000.0,3027.923202037812,17443.229073524475,3027.923202037812,14414.588943958282,0.5692830085754395,0.0 -4500,0.0064945845,0.12645456,,,,,,,,,,, -4600,0.020139277,0.1220995,,,,,,,,,,, -4628,,,0.1233379368629283,0.1244669025799461,83274637.0,0.1268318977693256,95000000.0,3148.1890013217926,18120.90138840676,3148.1890013217926,14971.964445590973,0.594536304473877,0.0 -4700,0.00955177,0.12218832,,,,,,,,,,, -4800,0.01400395,0.12096425,,,,,,,,,,, -4814,,,0.1234617902188556,0.1243810264070015,83274637.0,0.1267612940275493,95000000.0,3268.8093614578247,18777.632177591324,3268.8093614578247,15508.046690702438,0.6166715621948242,0.0 -4900,0.0072806827,0.12518404,,,,,,,,,,, -4990,,,0.1227653210385228,0.1242355577147866,83274637.0,0.1265567877672697,95000000.0,3388.91603922844,19446.900912046432,3388.91603922844,16057.181163549423,0.6385774612426758,0.0 -5000,0.0050934553,0.11679775,,,,,,,,,,, -5100,0.040235464,0.11776762,,,,,,,,,,, -5168,,,0.1241660562677765,0.1242902807104288,83274637.0,0.1266145437088816,95000000.0,3509.4399082660675,20106.71199607849,3509.4399082660675,16596.44086575508,0.6602156162261963,0.0 -5200,0.005291338,0.12857696,,,,,,,,,,, -5300,0.010960764,0.13335,,,,,,,,,,, -5344,,,0.1231663083107029,0.1243867278388549,83274637.0,0.1266809336759868,95000000.0,3630.0084569454193,20783.91361641884,3630.0084569454193,17153.042880296707,0.6854746341705322,0.0 -5400,0.006199705,0.12580274,,,,,,,,,,, -5500,0.008503447,0.11941835,,,,,,,,,,, -5517,,,0.1235510464523386,0.1244935922285696,83274637.0,0.1268313451377467,95000000.0,3750.658250808716,21447.12274193764,3750.658250808716,17695.571431159973,0.7104918956756592,0.0 -5600,0.005898555,0.13242878,,,,,,,,,,, -5699,,,0.1242261264566917,0.1243767421960324,83274637.0,0.126768357185444,95000000.0,3871.1222972869873,22112.21555185318,3871.1222972869873,18240.172213554382,0.7323856353759766,0.0 -5700,0.0105206845,0.11286346,,,,,,,,,,, -5800,0.00453692,0.12473185,,,,,,,,,,, -5877,,,0.1226479431977436,0.1241523478241509,83274637.0,0.1265490914987664,95000000.0,3991.6780228614807,22786.44338607788,3991.6780228614807,18793.809163808823,0.7616417407989502,0.0 -5900,0.006412353,0.1211244,,,,,,,,,,, -6000,0.010110289,0.1210707,,,,,,,,,,, -6056,,,0.1225242608185834,0.1243487263222586,83274637.0,0.1266930142269736,95000000.0,4111.936642169952,23440.723692655563,4111.936642169952,19327.80193209648,0.784590482711792,0.0 -6100,0.0072887857,0.12463021,,,,,,,,,,, -6200,0.013546782,0.118647546,,,,,,,,,,, -6232,,,0.1234508143538761,0.1242802672980467,83274637.0,0.1266402905016447,95000000.0,4232.155610084534,24116.43520641327,4232.155610084534,19883.262872695923,0.8101849555969238,0.0 -6300,0.0068567884,0.122523114,,,,,,,,,,, -6400,0.010439426,0.12558915,,,,,,,,,,, -6409,,,0.1205027495516733,0.1242960469867839,83274637.0,0.126642444634046,95000000.0,4352.459222078323,24783.39586544037,4352.459222078323,20429.887636899948,0.8363173007965088,0.0 -6500,0.009170041,0.124075375,,,,,,,,,,, -6590,,,0.1231092098491184,0.1241006255065073,83274637.0,0.1264559795435855,95000000.0,4472.683204650879,25447.36835289001,4472.683204650879,20973.60512423516,0.8614339828491211,0.0 -6600,0.018652527,0.12477882,,,,,,,,,,, -6700,0.0075468095,0.12340215,,,,,,,,,,, -6770,,,0.120198328223712,0.1242077644196624,83274637.0,0.126623035207648,95000000.0,4593.207630872726,26102.873340845108,4593.207630872726,21508.55733346939,0.8837997913360596,0.0 -6800,0.007258374,0.12598108,,,,,,,,,,, -6900,0.005702285,0.13080636,,,,,,,,,,, -6946,,,0.1179989202873511,0.1241508252712629,83274637.0,0.1264218494140625,95000000.0,4713.524817228317,26765.58244419098,4713.524817228317,22050.916337251663,0.9108951091766356,0.0 -7000,0.010179522,0.13009855,,,,,,,,,,, -7100,0.011751409,0.12513566,,,,,,,,,,, -7123,,,0.1220455420607665,0.1240180520055227,83274637.0,0.1263023667865954,95000000.0,4833.66826915741,27425.98725819588,4833.66826915741,22591.14922332764,0.933558702468872,0.0 -7200,0.008777374,0.1183174,,,,,,,,,,, -7300,0.00626218,0.1305419,,,,,,,,,,, -7306,,,0.1221627628419009,0.1240735888791122,83274637.0,0.1264483639391447,95000000.0,4953.663403511047,28093.290339946747,4953.663403511047,23138.42849516869,0.9559965133666992,0.0 -7400,0.011616847,0.12169292,,,,,,,,,,, -7485,,,0.1203739370260411,0.1240158122580213,83274637.0,0.1263452935752467,95000000.0,5073.9851770401,28744.60824251175,5073.9851770401,23669.392629384995,0.9820210933685304,0.0 -7500,0.0048046787,0.11606923,,,,,,,,,,, -7600,0.013483932,0.120348215,,,,,,,,,,, -7663,,,0.1207796007793092,0.1238437841866733,83274637.0,0.1261958093647203,95000000.0,5194.398851156235,29409.30269098282,5194.398851156235,24213.64492177964,1.0047781467437744,0.0 -7700,0.009994014,0.120052084,,,,,,,,,,, -7800,0.009775802,0.13166305,,,,,,,,,,, -7843,,,0.1195970896936062,0.1238456539939176,83274637.0,0.1261964286800986,95000000.0,5314.715612649918,30069.44606542588,5314.715612649918,24753.443234682083,1.0272042751312256,0.0 -7900,0.010409527,0.12421111,,,,,,,,,,, -8000,0.011801486,0.1203686,,,,,,,,,,, -8018,,,0.1206216625869274,0.1238476197632847,83274637.0,0.1261577527138158,95000000.0,5435.273231506348,30739.97963070869,5435.273231506348,25303.39065337181,1.049952745437622,0.0 -8100,0.0076918104,0.12085098,,,,,,,,,,, -8198,,,0.1217111458404446,0.123903496778799,83274637.0,0.1262986234477796,95000000.0,5555.939074039459,31397.71311402321,5555.939074039459,25840.4235098362,1.0788283348083496,0.0 -8200,0.021286935,0.11774483,,,,,,,,,,, -8300,0.0065134177,0.12650844,,,,,,,,,,, -8374,,,0.1197104043087119,0.1237947691900101,83274637.0,0.1261533757915296,95000000.0,5676.031098842621,32071.21576356888,5676.031098842621,26393.805539608,1.1015920639038086,0.0 -8400,0.0057804207,0.11970234,,,,,,,,,,, -8500,0.008060225,0.12869728,,,,,,,,,,, -8553,,,0.1218349276426828,0.1238067552613776,83274637.0,0.1261754389905427,95000000.0,5796.273926258087,32733.37809896469,5796.273926258087,26935.695637226105,1.125080108642578,0.0 -8600,0.009370948,0.12444587,,,,,,,,,,, -8700,0.005969734,0.13464676,,,,,,,,,,, -8733,,,0.1219008069063695,0.1237535159318583,83274637.0,0.126067133645148,95000000.0,5916.893688201904,33391.54086971283,5916.893688201904,27473.209585905075,1.1481850147247314,0.0 -8800,0.0068824594,0.112984896,,,,,,,,,,, -8900,0.0073679914,0.13360634,,,,,,,,,,, -8912,,,0.121345523768251,0.1237732839121945,83274637.0,0.1260618208778783,95000000.0,6036.8774428367615,34048.80773925781,6036.8774428367615,28010.46356701851,1.1713495254516602,0.0 -9000,0.010680152,0.11689377,,,,,,,,,,, -9084,,,0.1228009223609975,0.1237323762842081,83274637.0,0.1260166462273848,95000000.0,6157.476182222366,34711.51103806496,6157.476182222366,28552.5388212204,1.1950955390930176,0.0 -9084,,,,,,,,6157.476182222366,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index f04303ef8..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,26 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -201.36471915245056,0.0,54.686328411102295,1,0,54.686328411102295,0.7713690200886624,3581,0.2424260842270839,256.05149149894714,0.7593945094517299,0.2261475665228707,0.7711199501222215,3554,0.2178304368482256 -205.8366770744324,0.0289773941040039,134.86516547203064,339,0,134.86516547203064,0.3133805697605417,3581,0.7145979731351229,340.744056224823,0.2890044961656843,0.716214793069022,0.3112095718512591,3554,0.6970787756752954 -209.91926383972168,0.0662980079650878,214.8670153617859,582,0,214.8670153617859,0.3054640662633517,3581,0.7219551212126502,424.8745911121368,0.2816087688718523,0.72406005859375,0.3035313698033905,3554,0.7046305109559651 -213.99327325820923,0.1038587093353271,294.9222095012665,827,0,294.9222095012665,0.3004246519717781,3581,0.7280137084176906,509.0501329898834,0.2767578193119594,0.7303955214364188,0.2985322916849852,3554,0.7108409153858328 -218.0668997764588,0.1392419338226318,374.9872243404389,1111,0,374.9872243404389,0.2976607701126605,3581,0.7302057244397515,593.234457731247,0.274313143321446,0.7323812076023647,0.2958636780300366,3554,0.7131611443531936 -222.1412708759308,0.1635708808898925,455.0248363018036,1458,0,455.0248363018036,0.29979234752426,3581,0.7264880510855907,677.3835711479187,0.2768131153924124,0.728395802634103,0.2979012630464793,3554,0.7095637452957935 -226.21669507026672,0.1881003379821777,535.0989918708801,1804,0,535.0989918708801,0.2936607772030333,3581,0.7347816737861281,761.5707142353058,0.2701465232031686,0.7374175616673061,0.2921151164950584,3554,0.7174316131471581 -230.30079555511475,0.2125391960144043,615.1334404945374,2149,0,615.1334404945374,0.2932166403413851,3581,0.7366957335590617,845.7267036437988,0.2696566752025059,0.7391834940229144,0.2915768599201603,3554,0.7197018324423186 -234.3766646385193,0.2379350662231445,695.336752653122,2497,0,695.336752653122,0.3009507030922752,3581,0.7240060797219702,930.0442731380464,0.278454269681658,0.7255415916442871,0.2993109106816263,3554,0.7071849198306486 -238.4527871608734,0.2625031471252441,775.4102036952972,2844,0,775.4102036952972,0.2932444564192963,3581,0.7338836508002304,1014.231377363205,0.270389403615679,0.7352662767682757,0.2916783905458638,3554,0.7170392295652785 -242.53347277641296,0.2881107330322265,855.4152998924255,3188,0,855.4152998924255,0.291109538351281,3581,0.7376729778256772,1098.3554491996765,0.2676169191087995,0.7407202039446149,0.2895971501081528,3554,0.7203461878341305 -246.6138072013855,0.3145184516906738,935.459174633026,3535,0,935.459174633026,0.2913536448879503,3581,0.7371134519643605,1182.519168138504,0.2674918004444667,0.7404157093593052,0.2897985626890476,3554,0.7198601735060847 -250.6893527507782,0.3417632579803467,1015.5872983932496,3883,0,1015.5872983932496,0.2916627237830738,3581,0.7358972484990226,1266.763106584549,0.2685514858790806,0.738018308367048,0.290056373540377,3554,0.7190554162123663 -254.77047514915463,0.3691582679748535,1095.6436443328855,4227,0,1095.6436443328855,0.2900163596717048,3581,0.7396287617154775,1350.940851688385,0.2665797131402151,0.7423955372401646,0.2886444589832759,3554,0.7223245237540448 -258.84808468818665,0.3941807746887207,1175.8080520629885,4571,0,1175.8080520629885,0.2897774686496439,3581,0.7400900450075049,1435.2206432819366,0.2659957238606044,0.7432968275887626,0.2883094697435988,3554,0.7228588990837789 -262.92627787590027,0.4198975563049316,1255.8872196674347,4915,0,1255.8872196674347,0.2895973799981674,3581,0.7406827047263335,1519.4165687561035,0.2659684930528913,0.7437765257699149,0.288180873443655,3554,0.7234973467395892 -267.0048804283142,0.4453067779541015,1336.0409696102142,5262,0,1336.0409696102142,0.2892613031450712,3581,0.7389619257888858,1603.6873452663422,0.2657807043620518,0.7419512612479073,0.2879154203175559,3554,0.7215963609445343 -271.0842912197113,0.470118761062622,1416.0627336502075,5605,0,1416.0627336502075,0.2892330439188599,3581,0.7384023317509075,1687.8261096477509,0.2656069993972778,0.7413966315133231,0.2878381560609524,3554,0.721058070022334 -275.16144132614136,0.4962143898010254,1496.1338346004486,5951,0,1496.1338346004486,0.2888655376247905,3581,0.7396513281904495,1772.013376712799,0.2650624173028128,0.7431563649858747,0.2875061893838808,3554,0.7223671144089406 -279.24346709251404,0.5249452590942383,1576.1826355457306,6294,0,1576.1826355457306,0.2892078526424183,3581,0.7398609032480452,1856.1862041950224,0.2655024017606462,0.743187495640346,0.2878637791484946,3554,0.7226340616426913 -283.3258457183838,0.5521426200866699,1656.2280633449554,6639,0,1656.2280633449554,0.2886234763879677,3581,0.7390587366482826,1940.3542692661283,0.2648814746311733,0.7422693116324288,0.2872580301245076,3554,0.7217326510402012 -287.40594720840454,0.5793905258178711,1736.193373918533,6985,0,1736.193373918533,0.2886633597349727,3581,0.7385842270839151,2024.4396481513977,0.2650238616125924,0.7416989462716239,0.2872385208567811,3554,0.7212246544386607 -291.48550271987915,0.6093320846557617,1816.3327956199648,7331,0,1816.3327956199648,0.2882811954586707,3581,0.7398701070973541,2108.7012643814087,0.2646456786564418,0.7430637904575893,0.2869049398564997,3554,0.7226036299328221 -295.5657305717468,0.6355040073394775,1896.345843076706,7676,0,1896.345843076706,0.2889652119039723,3581,0.7402495102188634,2192.833665370941,0.2649330581937517,0.7440517289297921,0.2876525088918296,3554,0.7231099091692459 -299.64626455307007,0.6633594036102295,1976.387057542801,8022,0,1976.387057542801,0.28833481640297753,3581,0.7410538584717956,2276.996250152588,0.2646443673542568,0.7439615385872977,0.286903943784732,3554,0.723860054252251 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 12d02c6c0..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.1362333,0.6871885,,,,,,,,,,,,,, -1,,,0.2261475665228707,0.7593945094517299,0.2178304368482256,0.7711199501222215,3554.0,0.2424260842270839,0.7713690200886624,3581.0,54.686328411102295,256.05149149894714,54.686328411102295,201.36471915245056,0.0,0.0 -100,0.18867734,0.25560138,,,,,,,,,,,,,, -200,0.13936792,0.36383906,,,,,,,,,,,,,, -300,0.71918374,0.3167506,,,,,,,,,,,,,, -339,,,0.716214793069022,0.2890044961656843,0.6970787756752954,0.3112095718512591,3554.0,0.7145979731351229,0.3133805697605417,3581.0,134.86516547203064,340.744056224823,134.86516547203064,205.8366770744324,0.0289773941040039,0.0 -400,0.34530437,0.29572603,,,,,,,,,,,,,, -500,0.17815068,0.2167565,,,,,,,,,,,,,, -582,,,0.72406005859375,0.2816087688718523,0.7046305109559651,0.3035313698033905,3554.0,0.7219551212126502,0.3054640662633517,3581.0,214.8670153617859,424.8745911121368,214.8670153617859,209.91926383972168,0.0662980079650878,0.0 -600,0.3996742,0.35493925,,,,,,,,,,,,,, -700,0.17787038,0.22358857,,,,,,,,,,,,,, -800,0.26794443,0.27551624,,,,,,,,,,,,,, -827,,,0.7303955214364188,0.2767578193119594,0.7108409153858328,0.2985322916849852,3554.0,0.7280137084176906,0.3004246519717781,3581.0,294.9222095012665,509.0501329898834,294.9222095012665,213.99327325820923,0.1038587093353271,0.0 -900,0.09775496,0.34897754,,,,,,,,,,,,,, -1000,0.12532401,0.27639228,,,,,,,,,,,,,, -1100,0.66096723,0.2349462,,,,,,,,,,,,,, -1111,,,0.7323812076023647,0.274313143321446,0.7131611443531936,0.2958636780300366,3554.0,0.7302057244397515,0.2976607701126605,3581.0,374.9872243404389,593.234457731247,374.9872243404389,218.0668997764588,0.1392419338226318,0.0 -1200,0.1215177,0.26651984,,,,,,,,,,,,,, -1300,0.18117695,0.29726294,,,,,,,,,,,,,, -1400,0.08233435,0.2719007,,,,,,,,,,,,,, -1458,,,0.728395802634103,0.2768131153924124,0.7095637452957935,0.2979012630464793,3554.0,0.7264880510855907,0.29979234752426,3581.0,455.0248363018036,677.3835711479187,455.0248363018036,222.1412708759308,0.1635708808898925,0.0 -1500,0.27193,0.27510378,,,,,,,,,,,,,, -1600,0.25289968,0.20034413,,,,,,,,,,,,,, -1700,0.12687688,0.29427257,,,,,,,,,,,,,, -1800,0.065033644,0.27312732,,,,,,,,,,,,,, -1804,,,0.7374175616673061,0.2701465232031686,0.7174316131471581,0.2921151164950584,3554.0,0.7347816737861281,0.2936607772030333,3581.0,535.0989918708801,761.5707142353058,535.0989918708801,226.21669507026672,0.1881003379821777,0.0 -1900,0.13386966,0.35460165,,,,,,,,,,,,,, -2000,0.10274796,0.36636758,,,,,,,,,,,,,, -2100,0.088999465,0.2739057,,,,,,,,,,,,,, -2149,,,0.7391834940229144,0.2696566752025059,0.7197018324423186,0.2915768599201603,3554.0,0.7366957335590617,0.2932166403413851,3581.0,615.1334404945374,845.7267036437988,615.1334404945374,230.30079555511475,0.2125391960144043,0.0 -2200,0.1586352,0.2106151,,,,,,,,,,,,,, -2300,0.08152793,0.2908054,,,,,,,,,,,,,, -2400,0.1252261,0.2806339,,,,,,,,,,,,,, -2497,,,0.7255415916442871,0.278454269681658,0.7071849198306486,0.2993109106816263,3554.0,0.7240060797219702,0.3009507030922752,3581.0,695.336752653122,930.0442731380464,695.336752653122,234.3766646385193,0.2379350662231445,0.0 -2500,0.11423849,0.25861305,,,,,,,,,,,,,, -2600,0.075898044,0.22241968,,,,,,,,,,,,,, -2700,0.11698237,0.27981696,,,,,,,,,,,,,, -2800,0.21528886,0.31048924,,,,,,,,,,,,,, -2844,,,0.7352662767682757,0.270389403615679,0.7170392295652785,0.2916783905458638,3554.0,0.7338836508002304,0.2932444564192963,3581.0,775.4102036952972,1014.231377363205,775.4102036952972,238.4527871608734,0.2625031471252441,0.0 -2900,0.14399615,0.27264944,,,,,,,,,,,,,, -3000,0.055957202,0.32092097,,,,,,,,,,,,,, -3100,0.17403677,0.2411504,,,,,,,,,,,,,, -3188,,,0.7407202039446149,0.2676169191087995,0.7203461878341305,0.2895971501081528,3554.0,0.7376729778256772,0.291109538351281,3581.0,855.4152998924255,1098.3554491996765,855.4152998924255,242.53347277641296,0.2881107330322265,0.0 -3200,0.23296437,0.25141522,,,,,,,,,,,,,, -3300,0.19181006,0.32140362,,,,,,,,,,,,,, -3400,0.07361997,0.35843396,,,,,,,,,,,,,, -3500,0.10419288,0.23305872,,,,,,,,,,,,,, -3535,,,0.7404157093593052,0.2674918004444667,0.7198601735060847,0.2897985626890476,3554.0,0.7371134519643605,0.2913536448879503,3581.0,935.459174633026,1182.519168138504,935.459174633026,246.6138072013855,0.3145184516906738,0.0 -3600,0.061744448,0.3161912,,,,,,,,,,,,,, -3700,0.15843846,0.31731635,,,,,,,,,,,,,, -3800,0.08700883,0.28914648,,,,,,,,,,,,,, -3883,,,0.738018308367048,0.2685514858790806,0.7190554162123663,0.290056373540377,3554.0,0.7358972484990226,0.2916627237830738,3581.0,1015.5872983932496,1266.763106584549,1015.5872983932496,250.6893527507782,0.3417632579803467,0.0 -3900,0.13392554,0.29748714,,,,,,,,,,,,,, -4000,0.14078137,0.2527919,,,,,,,,,,,,,, -4100,0.09767304,0.25113565,,,,,,,,,,,,,, -4200,0.15204743,0.27659684,,,,,,,,,,,,,, -4227,,,0.7423955372401646,0.2665797131402151,0.7223245237540448,0.2886444589832759,3554.0,0.7396287617154775,0.2900163596717048,3581.0,1095.6436443328855,1350.940851688385,1095.6436443328855,254.77047514915463,0.3691582679748535,0.0 -4300,0.2328859,0.37302572,,,,,,,,,,,,,, -4400,0.05640749,0.28350803,,,,,,,,,,,,,, -4500,0.0991314,0.2926692,,,,,,,,,,,,,, -4571,,,0.7432968275887626,0.2659957238606044,0.7228588990837789,0.2883094697435988,3554.0,0.7400900450075049,0.2897774686496439,3581.0,1175.8080520629885,1435.2206432819366,1175.8080520629885,258.84808468818665,0.3941807746887207,0.0 -4600,0.04445569,0.3098628,,,,,,,,,,,,,, -4700,0.19781102,0.20798874,,,,,,,,,,,,,, -4800,0.33237627,0.2461478,,,,,,,,,,,,,, -4900,0.07761949,0.26157877,,,,,,,,,,,,,, -4915,,,0.7437765257699149,0.2659684930528913,0.7234973467395892,0.288180873443655,3554.0,0.7406827047263335,0.2895973799981674,3581.0,1255.8872196674347,1519.4165687561035,1255.8872196674347,262.92627787590027,0.4198975563049316,0.0 -5000,0.18235882,0.24055909,,,,,,,,,,,,,, -5100,0.062968984,0.2958387,,,,,,,,,,,,,, -5200,0.059635047,0.3080678,,,,,,,,,,,,,, -5262,,,0.7419512612479073,0.2657807043620518,0.7215963609445343,0.2879154203175559,3554.0,0.7389619257888858,0.2892613031450712,3581.0,1336.0409696102142,1603.6873452663422,1336.0409696102142,267.0048804283142,0.4453067779541015,0.0 -5300,0.15675758,0.20586959,,,,,,,,,,,,,, -5400,0.20995831,0.3441805,,,,,,,,,,,,,, -5500,0.076509826,0.2537634,,,,,,,,,,,,,, -5600,0.051993236,0.23338604,,,,,,,,,,,,,, -5605,,,0.7413966315133231,0.2656069993972778,0.721058070022334,0.2878381560609524,3554.0,0.7384023317509075,0.2892330439188599,3581.0,1416.0627336502075,1687.8261096477509,1416.0627336502075,271.0842912197113,0.470118761062622,0.0 -5700,0.1455133,0.23797318,,,,,,,,,,,,,, -5800,0.07751339,0.30312604,,,,,,,,,,,,,, -5900,0.10402378,0.2581895,,,,,,,,,,,,,, -5951,,,0.7431563649858747,0.2650624173028128,0.7223671144089406,0.2875061893838808,3554.0,0.7396513281904495,0.2888655376247905,3581.0,1496.1338346004486,1772.013376712799,1496.1338346004486,275.16144132614136,0.4962143898010254,0.0 -6000,0.15605047,0.2578016,,,,,,,,,,,,,, -6100,0.16521783,0.336954,,,,,,,,,,,,,, -6200,0.057411604,0.25239784,,,,,,,,,,,,,, -6294,,,0.743187495640346,0.2655024017606462,0.7226340616426913,0.2878637791484946,3554.0,0.7398609032480452,0.2892078526424183,3581.0,1576.1826355457306,1856.1862041950224,1576.1826355457306,279.24346709251404,0.5249452590942383,0.0 -6300,0.22834006,0.22516625,,,,,,,,,,,,,, -6400,0.11513728,0.2860917,,,,,,,,,,,,,, -6500,0.115646146,0.23372775,,,,,,,,,,,,,, -6600,0.08623608,0.28001586,,,,,,,,,,,,,, -6639,,,0.7422693116324288,0.2648814746311733,0.7217326510402012,0.2872580301245076,3554.0,0.7390587366482826,0.2886234763879677,3581.0,1656.2280633449554,1940.3542692661283,1656.2280633449554,283.3258457183838,0.5521426200866699,0.0 -6700,0.09045506,0.20667937,,,,,,,,,,,,,, -6800,0.067035235,0.25470054,,,,,,,,,,,,,, -6900,0.06514288,0.32365653,,,,,,,,,,,,,, -6985,,,0.7416989462716239,0.2650238616125924,0.7212246544386607,0.2872385208567811,3554.0,0.7385842270839151,0.2886633597349727,3581.0,1736.193373918533,2024.4396481513977,1736.193373918533,287.40594720840454,0.5793905258178711,0.0 -7000,0.19603159,0.21868013,,,,,,,,,,,,,, -7100,0.08028179,0.26402196,,,,,,,,,,,,,, -7200,0.099218875,0.29664338,,,,,,,,,,,,,, -7300,0.13075283,0.30743134,,,,,,,,,,,,,, -7331,,,0.7430637904575893,0.2646456786564418,0.7226036299328221,0.2869049398564997,3554.0,0.7398701070973541,0.2882811954586707,3581.0,1816.3327956199648,2108.7012643814087,1816.3327956199648,291.48550271987915,0.6093320846557617,0.0 -7400,0.04530741,0.25625318,,,,,,,,,,,,,, -7500,0.11061286,0.17940588,,,,,,,,,,,,,, -7600,0.12354575,0.26766393,,,,,,,,,,,,,, -7676,,,0.7440517289297921,0.2649330581937517,0.7231099091692459,0.2876525088918296,3554.0,0.7402495102188634,0.2889652119039723,3581.0,1896.345843076706,2192.833665370941,1896.345843076706,295.5657305717468,0.6355040073394775,0.0 -7700,0.147539,0.20948106,,,,,,,,,,,,,, -7800,0.043592654,0.28634912,,,,,,,,,,,,,, -7900,0.101965934,0.31336993,,,,,,,,,,,,,, -8000,0.072190434,0.2711475,,,,,,,,,,,,,, -8022,,,0.7439615385872977,0.2646443673542568,0.723860054252251,0.286903943784732,3554.0,0.7410538584717956,0.2883348164029775,3581.0,1976.387057542801,2276.996250152588,1976.387057542801,299.64626455307007,0.6633594036102295,0.0 -8022,,,,,,,,,,,1976.387057542801,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index bee9646b9..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,372 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -37.31908655166626,0.0,52.27670836448669,1,0,52.27670836448669,0.0008000000379979,6.914037704467773,10000,89.5958993434906,0.0011160713620483,6.914276599884033,0.0011199999134987,6.913983345031738,50000 -55.3390634059906,0.0274107456207275,562.4275517463684,1499,0,562.4275517463684,0.1118000075221061,4.884735107421875,10000,617.8459360599518,0.1658561825752258,4.275528907775879,0.1472399979829788,4.434837818145752,50000 -73.20708441734314,0.0605521202087402,1072.6695590019226,2996,0,1072.6695590019226,0.236500009894371,3.801160335540772,10000,1146.0415749549866,0.3450255095958709,3.01550817489624,0.3194399774074554,3.1867895126342773,50000 -91.00414776802064,0.0871176719665527,1582.7634320259094,4494,0,1582.7634320259094,0.3276000022888183,3.2471020221710205,10000,1674.0095942020416,0.4690887928009033,2.31603479385376,0.4360799789428711,2.511484146118164,50000 -108.64397525787354,0.1147418022155761,2092.8671317100525,5994,0,2092.8671317100525,0.3756000101566314,2.999966144561768,10000,2201.830789089203,0.5284797549247742,2.0127384662628174,0.4960199892520904,2.2188453674316406,50000 -126.56366062164308,0.141902208328247,2602.7957921028137,7494,0,2602.7957921028137,0.4164000153541565,2.72767186164856,10000,2729.759635448456,0.5736607313156128,1.7896974086761477,0.5335999727249146,1.997921347618103,50000 -144.63809752464294,0.1693737506866455,3112.8875806331635,8995,0,3112.8875806331635,0.4229000210762024,2.710902690887451,10000,3258.0060436725616,0.5866350531578064,1.7268341779708862,0.5454999804496765,1.93905246257782,50000 -162.88873028755188,0.1964619159698486,3622.910650968552,10495,0,3622.910650968552,0.45210000872612,2.53738784790039,10000,3786.3579564094543,0.6527024507522583,1.4016082286834717,0.5786399841308594,1.790817141532898,50000 -181.1612477302552,0.2311422824859619,4133.089784383774,11997,0,4133.089784383774,0.4599000215530395,2.4923598766326904,10000,4314.894702672958,0.6460459232330322,1.433350682258606,0.5874199867248535,1.7437560558319092,50000 -199.754013299942,0.2653450965881347,4643.165565252304,13498,0,4643.165565252304,0.4731000363826751,2.4181039333343506,10000,4843.648952245712,0.6513671875,1.4058401584625244,0.5913599729537964,1.706515908241272,50000 -217.90440320968628,0.3007020950317383,5153.095174074173,15000,0,5153.095174074173,0.4726000130176544,2.420872211456299,10000,5371.81556725502,0.6477000713348389,1.419167160987854,0.592960000038147,1.7010241746902466,50000 -236.13387608528137,0.3381190299987793,5663.19630074501,16501,0,5663.19630074501,0.4708000123500824,2.4155619144439697,10000,5900.233821630478,0.6524234414100647,1.3966569900512695,0.5991599559783936,1.6813833713531494,50000 -255.26897883415225,0.3690087795257568,6173.271780967712,18003,0,6173.271780967712,0.4759000241756439,2.3799636363983154,10000,6429.526218175888,0.6489955186843872,1.4116501808166504,0.5984399914741516,1.6733524799346924,50000 -274.1825485229492,0.4079773426055908,6683.211872577667,19505,0,6683.211872577667,0.4868000149726867,2.348141431808472,10000,6958.470961809158,0.7088249325752258,1.1469347476959229,0.6089000105857849,1.630070447921753,50000 -292.59999918937683,0.4410958290100097,7193.288778066635,21007,0,7193.288778066635,0.4828000366687774,2.377382516860962,10000,7487.048682928085,0.6806440949440002,1.2608718872070312,0.6103000044822693,1.6341451406478882,50000 -311.04740691185,0.4867265224456787,7703.217472076416,22508,0,7703.217472076416,0.4828000366687774,2.365996122360229,10000,8015.52258348465,0.6704998016357422,1.3082401752471924,0.6072799563407898,1.6411656141281128,50000 -332.26763224601746,0.5245239734649658,8213.273429870605,24010,0,8213.273429870605,0.4818000197410583,2.3678853511810303,10000,8546.886992692947,0.6690449714660645,1.3177435398101809,0.608460009098053,1.6242786645889282,50000 -355.46244287490845,0.5544643402099609,8723.448077917099,25512,0,8723.448077917099,0.499500036239624,2.262160301208496,10000,9080.337821245192,0.6869817972183228,1.2383699417114258,0.627079963684082,1.5450199842453003,50000 -379.0861880779266,0.5806448459625244,9233.63310432434,27015,0,9233.63310432434,0.4883000254631042,2.3350415229797363,10000,9614.2236931324,0.6685666441917419,1.31665301322937,0.6162599921226501,1.593363881111145,50000 -402.9537811279297,0.8449652194976807,9743.41288614273,28516,0,9743.41288614273,0.4898000359535217,2.3473498821258545,10000,10148.186551094055,0.7271006107330322,1.0419692993164062,0.6188399791717529,1.5885690450668335,50000 -427.8536124229431,0.8739020824432373,10253.381639242172,30018,0,10253.381639242172,0.5054000020027161,2.2726728916168213,10000,10683.135558843613,0.70804762840271,1.1357526779174805,0.6301800012588501,1.532379984855652,50000 -452.8360719680786,0.9039633274078368,10763.521374940872,31520,0,10763.521374940872,0.4951000213623047,2.2808399200439453,10000,11218.338215351105,0.6942163705825806,1.2090507745742798,0.6239399909973145,1.5508848428726196,50000 -477.199057340622,0.9372315406799316,11273.513464927672,33022,0,11273.513464927672,0.4944000244140625,2.32073712348938,10000,11752.77766919136,0.6808235049247742,1.2375199794769287,0.6197400093078613,1.5743621587753296,50000 -503.7779715061188,0.9684596061706544,11783.88451075554,34524,0,11783.88451075554,0.496500015258789,2.297478675842285,10000,12289.809584379196,0.6873405575752258,1.2315131425857544,0.626479983329773,1.5449154376983645,50000 -528.6820592880249,1.0011744499206543,12294.029720783234,36027,0,12294.029720783234,0.5012000203132629,2.299901008605957,10000,12824.942950963974,0.6824377775192261,1.2468916177749634,0.6263799667358398,1.54625141620636,50000 -551.5062041282654,1.030848264694214,12804.038206338882,37529,0,12804.038206338882,0.5037000179290771,2.2564914226531982,10000,13357.855704545977,0.7158800959587097,1.1090245246887207,0.6350600123405457,1.5164011716842651,50000 -574.3761565685272,1.0700502395629885,13314.092184782028,39032,0,13314.092184782028,0.5091000199317932,2.224526882171631,10000,13890.870227098463,0.71683669090271,1.1103540658950806,0.62909996509552,1.5288817882537842,50000 -597.3547255992889,1.10364031791687,13824.17870426178,40534,0,13824.17870426178,0.5076000094413757,2.232386827468872,10000,14424.020713090897,0.7096619606018066,1.1346745491027832,0.6363999843597412,1.4892843961715698,50000 -622.8276555538177,1.1378414630889893,14334.231875896454,42036,0,14334.231875896454,0.5042000412940979,2.279670000076294,10000,14959.633465051653,0.6986008882522583,1.179208517074585,0.6325599551200867,1.5215106010437012,50000 -645.9152612686157,1.1770331859588623,14844.221347093582,43538,0,14844.221347093582,0.4919000267982483,2.3413572311401367,10000,15492.801275491714,0.6833944320678711,1.2552924156188965,0.6233400106430054,1.5577534437179563,50000 -669.1851181983948,1.2121222019195557,15354.339327812197,45041,0,15354.339327812197,0.4878000319004059,2.3068480491638184,10000,16026.276261806488,0.6790497303009033,1.2766070365905762,0.6209200024604797,1.573642611503601,50000 -692.1000168323517,1.25059175491333,15864.367563009262,46544,0,15864.367563009262,0.5092000365257263,2.218715190887451,10000,16559.309529066086,0.7046595811843872,1.145161271095276,0.6420800089836121,1.48219633102417,50000 -714.6889681816101,1.2839748859405518,16374.504691123962,48048,0,16374.504691123962,0.5217000246047974,2.187633752822876,10000,17092.12025952339,0.7367067933082581,0.9979020357131958,0.6473199725151062,1.4513506889343262,50000 -737.3301610946655,1.3194658756256104,16884.7590508461,49551,0,16884.7590508461,0.5031000375747681,2.310176134109497,10000,17625.10194325447,0.7074099183082581,1.1301013231277466,0.6358799934387207,1.5119726657867432,50000 -758.3226172924042,1.3547601699829102,17394.90104007721,51054,0,17394.90104007721,0.527400016784668,2.135467767715454,10000,18156.323054790497,0.7250877022743225,1.0787639617919922,0.6547200083732605,1.407345175743103,50000 -780.4596221446991,1.389142990112305,17905.1085562706,52557,0,17905.1085562706,0.5105000138282776,2.2743141651153564,10000,18688.75293493271,0.7002750039100647,1.1716082096099854,0.6354599595069885,1.5122188329696655,50000 -801.5913171768188,1.4218056201934814,18415.254390001297,54061,0,18415.254390001297,0.5192000269889832,2.1851463317871094,10000,19220.11591911316,0.704500138759613,1.1456981897354126,0.6450799703598022,1.4624508619308472,50000 -821.6257679462433,1.4563281536102295,18925.44615507126,55564,0,18925.44615507126,0.5225000381469727,2.1526975631713867,10000,19750.42909526825,0.7084661722183228,1.1322740316390991,0.649679958820343,1.445713996887207,50000 -838.8901033401489,1.4961392879486084,19435.642174959183,57068,0,19435.642174959183,0.5080000162124634,2.2163994312286377,10000,20277.98100042343,0.7385801672935486,1.0057047605514526,0.6440799832344055,1.4641975164413452,50000 -856.131756067276,1.5372190475463867,19945.802860736847,58571,0,19945.802860736847,0.522100031375885,2.178805589675904,10000,20805.475059747696,0.7267019748687744,1.0357704162597656,0.646619975566864,1.45211923122406,50000 -873.3655483722687,1.5809953212738037,20455.73877811432,60074,0,20455.73877811432,0.5263000130653381,2.112003803253174,10000,21332.740632534027,0.7245694994926453,1.055361032485962,0.6547600030899048,1.4124293327331543,50000 -890.5964822769165,1.6265153884887695,20965.771684646606,61577,0,20965.771684646606,0.526900053024292,2.1299495697021484,10000,21860.102259159088,0.7218191623687744,1.0774519443511963,0.6505599617958069,1.4308249950408936,50000 -907.5062689781188,1.6687884330749512,21475.95622348785,63081,0,21475.95622348785,0.5213000178337097,2.185504913330078,10000,22387.29090833664,0.7154416441917419,1.1013087034225464,0.6513800024986267,1.4529109001159668,50000 -924.2931768894196,1.7055230140686035,21985.94143342972,64584,0,21985.94143342972,0.5286000370979309,2.160560846328736,10000,22914.152262449265,0.7234733700752258,1.0683307647705078,0.6577000021934509,1.4033039808273315,50000 -941.1070840358734,1.74554705619812,22495.97627878189,66087,0,22495.97627878189,0.5059000253677368,2.2493345737457275,10000,23441.093533992767,0.7430046200752258,0.9785423874855042,0.6419000029563904,1.4796708822250366,50000 -957.8929336071014,1.7923862934112549,23006.104821681976,67590,0,23006.104821681976,0.5074000358581543,2.2431344985961914,10000,23968.10825037956,0.7223373651504517,1.0661697387695312,0.6440399885177612,1.467403531074524,50000 -974.7140092849731,1.8431601524353027,23516.008523225784,69093,0,23516.008523225784,0.5078000426292419,2.2381787300109863,10000,24494.93649435044,0.7189094424247742,1.087269306182861,0.6390399932861328,1.490094780921936,50000 -991.4904737472534,1.882537841796875,24026.23086190224,70596,0,24026.23086190224,0.5357000231742859,2.1112637519836426,10000,25022.027183532715,0.7314253449440002,1.0208046436309814,0.6607199907302856,1.391520619392395,50000 -1008.195422887802,1.924994707107544,24536.351742506027,72099,0,24536.351742506027,0.5326000452041626,2.116264581680298,10000,25548.94845175743,0.7379224896430969,0.9979417324066162,0.667639970779419,1.3583799600601196,50000 -1024.909749507904,1.9686222076416016,25046.33527326584,73602,0,25046.33527326584,0.5344000458717346,2.101027011871338,10000,26075.744012355804,0.7327207922935486,1.0299476385116575,0.6582599878311157,1.3956197500228882,50000 -1042.2852289676666,2.011807918548584,25556.429697752,75106,0,25556.429697752,0.5297999978065491,2.1002209186553955,10000,26603.31052732468,0.7759885191917419,0.8355568647384644,0.6620799899101257,1.3705966472625732,50000 -1059.3335857391355,2.053749084472656,26066.617539405823,76609,0,26066.617539405823,0.5212000012397766,2.1413140296936035,10000,27130.641062498093,0.7373644709587097,0.9887303709983826,0.6548199653625488,1.4147918224334717,50000 -1076.1535539627075,2.098949909210205,26576.7337744236,78112,0,26576.7337744236,0.5396000146865845,2.074337959289551,10000,27657.67480564117,0.7537468075752258,0.915976583957672,0.6726399660110474,1.3231079578399658,50000 -1092.703783750534,2.143179416656494,27086.811772346497,79615,0,27086.811772346497,0.5450000166893005,2.059394836425781,10000,28184.400028944016,0.74906325340271,0.9586185216903688,0.6702799797058105,1.339442491531372,50000 -1109.4012954235077,2.186863899230957,27596.813438415527,81118,0,27596.813438415527,0.5406000018119812,2.044244766235352,10000,28711.19625043869,0.7498006820678711,0.9481151700019836,0.6739799976348877,1.3297719955444336,50000 -1125.9471654891968,2.227048873901367,28106.82225990296,82621,0,28106.82225990296,0.5247000455856323,2.1522934436798096,10000,29237.84588885308,0.7277582883834839,1.0413271188735962,0.6600399613380432,1.386773705482483,50000 -1142.674451828003,2.274683952331543,28616.90827870369,84125,0,28616.90827870369,0.5400000214576721,2.088098526000977,10000,29764.76045846939,0.7925701141357422,0.7795647382736206,0.6730599999427795,1.3342093229293823,50000 -1159.42480802536,2.3172881603240967,29127.042145967484,85628,0,29127.042145967484,0.5486000180244446,2.0616767406463623,10000,30291.73993468285,0.7692721486091614,0.8613221049308777,0.6786800026893616,1.3224866390228271,50000 -1176.0026772022247,2.3611364364624023,29636.98632788658,87131,0,29636.98632788658,0.5521000027656555,2.015770673751831,10000,30818.359090328217,0.7654455900192261,0.869311511516571,0.6791200041770935,1.3054094314575195,50000 -1192.485106468201,2.4045522212982178,30147.10192131996,88634,0,30147.10192131996,0.5489000082015991,2.0037930011749268,10000,31345.05516934395,0.7592872977256775,0.9058119654655457,0.6819199919700623,1.291403889656067,50000 -1209.234266757965,2.4512953758239746,30657.10430932045,90137,0,30657.10430932045,0.5453000068664551,2.0540900230407715,10000,31871.90614414215,0.7499800324440002,0.9388789534568788,0.6706199645996094,1.3405901193618774,50000 -1225.9227805137634,2.495468854904175,31167.13133573532,91640,0,31167.13133573532,0.5464000105857849,2.101987361907959,10000,32398.71913027764,0.7515146732330322,0.9365392327308656,0.6793999671936035,1.3209984302520752,50000 -1242.5521783828735,2.5414443016052246,31677.139016866684,93144,0,31677.139016866684,0.5480000376701355,2.06464958190918,10000,32925.45507359505,0.7772839665412903,0.8345317840576172,0.6767199635505676,1.3265360593795776,50000 -1259.1848402023315,2.5860958099365234,32187.304002285004,94647,0,32187.304002285004,0.5538000464439392,2.0050389766693115,10000,33452.35047388077,0.7825653553009033,0.8143873810768127,0.6804599761962891,1.2941557168960571,50000 -1275.8200645446775,2.6573374271392822,32697.30992746353,96150,0,32697.30992746353,0.5613000392913818,2.0027637481689453,10000,33979.11499476433,0.7772839665412903,0.8267792463302612,0.6888599991798401,1.2830692529678345,50000 -1292.5487146377563,2.7052574157714844,33207.35705137253,97653,0,33207.35705137253,0.5325000286102295,2.141735076904297,10000,34505.99077963829,0.7453164458274841,0.9852451086044312,0.6612799763679504,1.3965210914611816,50000 -1309.2410578727722,2.7550249099731445,33717.318086624146,99156,0,33717.318086624146,0.5527999997138977,2.025407791137696,10000,35032.746633291245,0.7657644748687744,0.8797435760498047,0.6818999648094177,1.3018945455551147,50000 -1326.0208716392517,2.801799774169922,34227.27606058121,100659,0,34227.27606058121,0.5570999979972839,2.000819206237793,10000,35559.58508205414,0.7707070708274841,0.8599854111671448,0.6876400113105774,1.271263599395752,50000 -1342.7114791870115,2.847575664520264,34737.225987672806,102162,0,34737.225987672806,0.5587000250816345,2.031827449798584,10000,36086.32460570336,0.7767059803009033,0.8202096223831177,0.6881600022315979,1.2708592414855957,50000 -1359.4346933364868,2.912957191467285,35247.42342591286,103666,0,35247.42342591286,0.5575000047683716,2.007575273513794,10000,36613.36411499977,0.7948023080825806,0.742831289768219,0.6871799826622009,1.2739155292510986,50000 -1376.1155910491943,2.963505744934082,35757.409264564514,105169,0,35757.409264564514,0.5600000023841858,1.9892528057098389,10000,37140.13447451592,0.7901785373687744,0.779336154460907,0.6885799765586853,1.2575336694717407,50000 -1392.704957962036,3.009145498275757,36267.38300538063,106672,0,36267.38300538063,0.5663000345230103,1.970118284225464,10000,37666.79642248154,0.7915935516357422,0.7726999521255493,0.6963399648666382,1.2351707220077517,50000 -1409.3859219551086,3.06013560295105,36777.43883371353,108176,0,36777.43883371353,0.5533000230789185,2.026097297668457,10000,38193.63676953316,0.7795758843421936,0.8351767063140869,0.6855999827384949,1.2827911376953125,50000 -1425.993721485138,3.102454423904419,37287.4537332058,109679,0,37287.4537332058,0.567300021648407,1.9748430252075195,10000,38720.35473489761,0.7897002100944519,0.7726017832756042,0.6967399716377258,1.228294849395752,50000 -1442.764579296112,3.153353452682495,37797.4540207386,111183,0,37797.4540207386,0.5749000310897827,1.9123485088348389,10000,39247.22995352745,0.7920718789100647,0.7628602385520935,0.7012999653816223,1.20783793926239,50000 -1459.3377630710602,3.2031338214874268,38307.6800494194,112687,0,38307.6800494194,0.5703999996185303,1.9807078838348389,10000,39774.13205599785,0.816824734210968,0.6701944470405579,0.6967799663543701,1.2372926473617554,50000 -1475.9748423099518,3.2496707439422607,38817.644235134125,114191,0,38817.644235134125,0.5706000328063965,1.964166522026062,10000,40300.83305287361,0.807039201259613,0.6915223002433777,0.6977799534797668,1.2234848737716677,50000 -1492.686810255051,3.298171281814575,39327.79904174805,115694,0,39327.79904174805,0.5722000002861023,1.946460485458374,10000,40827.80290222168,0.8061423897743225,0.6987681984901428,0.7060799598693848,1.1993192434310913,50000 -1509.241043329239,3.349168300628662,39837.94040846825,117198,0,39837.94040846825,0.5745000243186951,1.972318410873413,10000,41354.602694273,0.7992864847183228,0.7257764339447021,0.701259970664978,1.2185938358306885,50000 -1525.8609673976898,3.40002703666687,40348.18306827545,118701,0,40348.18306827545,0.5827000141143799,1.896999478340149,10000,41881.56796193123,0.8131775856018066,0.685913622379303,0.7095999717712402,1.1763111352920532,50000 -1542.5479755401611,3.452460289001465,40858.13369560242,120204,0,40858.13369560242,0.5776000022888184,1.9270565509796145,10000,42408.31049633026,0.8085139989852905,0.6942028999328613,0.7087599635124207,1.1854325532913208,50000 -1559.2989346981049,3.502571582794189,41368.25840616226,121707,0,41368.25840616226,0.579200029373169,1.906941294670105,10000,42935.28839254379,0.8465999364852905,0.56064772605896,0.7129799723625183,1.1660146713256836,50000 -1575.9256281852722,3.5578701496124268,41878.31901669502,123211,0,41878.31901669502,0.5875000357627869,1.8746788501739504,10000,43462.08464598656,0.8325493931770325,0.5962198376655579,0.7136200070381165,1.1775128841400146,50000 -1592.7352993488312,3.611999750137329,42388.27456307411,124714,0,42388.27456307411,0.5853000283241272,1.8980510234832764,10000,43988.95673203468,0.8289819955825806,0.6051926016807556,0.7158199548721313,1.1632517576217651,50000 -1609.340696811676,3.663426637649536,42898.41350364685,126218,0,42898.41350364685,0.5856000185012817,1.908242106437683,10000,44515.80540394783,0.8241987824440002,0.6316837072372437,0.7112999558448792,1.1798099279403689,50000 -1625.9791305065155,3.715073585510254,43408.54190301895,127721,0,43408.54190301895,0.5924000144004822,1.8633819818496704,10000,45042.67664170265,0.8290417790412903,0.6082040667533875,0.7193999886512756,1.1457984447479248,50000 -1642.595087766647,3.7680022716522217,43918.64049005509,129225,0,43918.64049005509,0.5823000073432922,1.93605637550354,10000,45569.49729180336,0.817781388759613,0.6508827209472656,0.7068399786949158,1.1904618740081787,50000 -1659.1601164340973,3.817939519882202,44428.71792554855,130728,0,44428.71792554855,0.5943000316619873,1.8778047561645508,10000,46096.24298453331,0.8703563213348389,0.4560602903366089,0.7164199948310852,1.169049859046936,50000 -1675.7581803798676,3.868992805480957,44938.95955300331,132232,0,44938.95955300331,0.5918000340461731,1.8805676698684688,10000,46623.186673641205,0.8533163070678711,0.5247998833656311,0.7186799645423889,1.150518774986267,50000 -1692.2820625305176,3.927493095397949,45448.97965598106,133735,0,45448.97965598106,0.5940999984741211,1.8898875713348389,10000,47149.84192085266,0.8556680083274841,0.5154725313186646,0.7227599620819092,1.1428040266036987,50000 -1708.829159975052,3.9842641353607178,45959.17137622833,135239,0,45959.17137622833,0.5942000150680542,1.886696934700012,10000,47676.69100522995,0.8517617583274841,0.5228502750396729,0.7201799750328064,1.1425071954727173,50000 -1725.3450455665588,4.035556793212891,46469.35333895683,136742,0,46469.35333895683,0.5951000452041626,1.8549350500106807,10000,48203.49363040924,0.8520607352256775,0.5236783027648926,0.7238399982452393,1.1254109144210815,50000 -1742.0426452159882,4.085598707199097,46979.564589738846,138246,0,46979.564589738846,0.6016000509262085,1.8454172611236568,10000,48730.506182432175,0.8571826815605164,0.4904449582099914,0.7291399836540222,1.1260720491409302,50000 -1758.654646873474,4.139785051345825,47489.7115893364,139749,0,47489.7115893364,0.6012000441551208,1.864063143730164,10000,49257.37302136421,0.8938735723495483,0.3696881830692291,0.7278199791908264,1.112565040588379,50000 -1775.2069325447085,4.194054841995239,47999.79533600807,141253,0,47999.79533600807,0.6060000061988831,1.822378635406494,10000,49784.1164534092,0.88578200340271,0.3992698490619659,0.7300199866294861,1.1068578958511353,50000 -1791.8828384876251,4.247902631759644,48509.867886304855,142756,0,48509.867886304855,0.6037000417709351,1.838681936264038,10000,50310.97115421295,0.8810786008834839,0.4117041528224945,0.7327799797058105,1.0963129997253418,50000 -1808.416305065155,4.30181097984314,49019.87905693054,144260,0,49019.87905693054,0.5979000329971313,1.8827852010726929,10000,50837.62298154831,0.8727877736091614,0.44132199883461,0.7282800078392029,1.1258344650268557,50000 -1825.05018401146,4.359186172485352,49530.08825039864,145763,0,49530.08825039864,0.6073000431060791,1.843134880065918,10000,51364.57683515549,0.8825733065605164,0.4018238186836242,0.7343399524688721,1.1096502542495728,50000 -1841.761024713516,4.412938833236694,50040.01444983482,147266,0,50040.01444983482,0.6115000247955322,1.845555543899536,10000,51891.32067155838,0.88671875,0.3884380757808685,0.7356199622154236,1.0984673500061035,50000 -1858.532964706421,4.470274448394775,50550.195132255554,148770,0,50550.195132255554,0.6085000038146973,1.8452770709991453,10000,52418.38363933563,0.9123883843421936,0.3148034214973449,0.7336199879646301,1.105415105819702,50000 -1875.268774271012,4.525710821151733,51060.157440423965,150274,0,51060.157440423965,0.6093000173568726,1.8343807458877563,10000,52945.19025826454,0.913305163383484,0.2986312806606293,0.7386999726295471,1.0863254070281982,50000 -1892.6738231182096,4.581248283386231,51570.23234939575,151777,0,51570.23234939575,0.6095000505447388,1.837144374847412,10000,53472.77893829346,0.9090202450752258,0.3202269375324249,0.7390799522399902,1.08871328830719,50000 -1909.446493148804,4.638153791427612,52080.39623951912,153281,0,52080.39623951912,0.6140000224113464,1.825111627578736,10000,53999.82557368279,0.9080038070678712,0.3098303377628326,0.7390999794006348,1.0933698415756226,50000 -1926.1085708141327,4.691672086715698,52590.39731740952,154785,0,52590.39731740952,0.6220000386238098,1.8184280395507808,10000,54526.59492826462,0.9130859375,0.2970570921897888,0.7426199913024902,1.0836670398712158,50000 -1942.9485993385315,4.743105173110962,53100.54725623131,156289,0,53100.54725623131,0.6170000433921814,1.8322269916534424,10000,55053.689846515656,0.9122289419174194,0.3007284998893738,0.7432799935340881,1.0831379890441897,50000 -1959.4862580299373,4.789821863174439,53610.67733025551,157792,0,53610.67733025551,0.6177000403404236,1.8258272409439087,10000,55580.45754265785,0.9258809089660645,0.2602937519550323,0.7439599633216858,1.0839968919754028,50000 -1976.401884317398,4.849618911743164,54120.89049816132,159296,0,54120.89049816132,0.6195000410079956,1.8430743217468264,10000,56107.6992123127,0.9353076815605164,0.2272559702396392,0.7450799942016602,1.0768632888793943,50000 -1993.062605857849,4.90923810005188,54631.05508708954,160799,0,54631.05508708954,0.6162000298500061,1.8339414596557613,10000,56634.638201236725,0.9308434128761292,0.2407589256763458,0.7454400062561035,1.0781235694885254,50000 -2009.718491077423,4.966087818145752,55141.21358561516,162303,0,55141.21358561516,0.6193000078201294,1.8274227380752563,10000,57161.56299781799,0.9325773119926452,0.2340480387210846,0.7475000023841858,1.0697925090789795,50000 -2026.4165422916408,5.023638486862183,55651.39206695557,163806,0,55651.39206695557,0.6229000091552734,1.8333741426467896,10000,57688.54973888397,0.935327649116516,0.2216317504644394,0.7468799948692322,1.0788114070892334,50000 -2042.9757788181305,5.078388690948486,56161.42722654343,165309,0,56161.42722654343,0.6246000528335571,1.8346995115280151,10000,58215.25165247917,0.9376594424247742,0.2162195891141891,0.7461599707603455,1.07829487323761,50000 -2059.741940021515,5.1376566886901855,56671.53956079483,166812,0,56671.53956079483,0.6241000294685364,1.825334906578064,10000,58742.24252414704,0.9435387253761292,0.202550783753395,0.7520999908447266,1.061427116394043,50000 -2076.389808654785,5.199169874191284,57181.61806106568,168315,0,57181.61806106568,0.6273000240325928,1.837319254875183,10000,59269.084530353546,0.95121169090271,0.173487052321434,0.7513599991798401,1.0714797973632812,50000 -2093.059808731079,5.259644985198975,57691.601484537125,169818,0,57691.601484537125,0.6255000233650208,1.8304463624954224,10000,59795.85251188278,0.951969027519226,0.1739388257265091,0.7509399652481079,1.0662760734558103,50000 -2109.7203953266144,5.319650173187256,58201.5602388382,171321,0,58201.5602388382,0.6276000142097473,1.834411859512329,10000,60322.585090875626,0.9516900181770324,0.1759648621082306,0.7511999607086182,1.066061019897461,50000 -2126.3594963550568,5.380169868469238,58711.51387619972,172823,0,58711.51387619972,0.6281000375747681,1.826152801513672,10000,60849.290974617004,0.9527662396430968,0.1673062890768051,0.7521599531173706,1.0650585889816284,50000 -2142.955652713776,5.441788911819458,59221.47832560539,174326,0,59221.47832560539,0.6239000558853149,1.8231186866760247,10000,61375.96544504166,0.9549385905265808,0.1686282455921173,0.7537800073623657,1.060943841934204,50000 -2159.5833282470703,5.498790264129639,59731.51045560837,175828,0,59731.51045560837,0.6284000277519226,1.8286449909210205,10000,61902.73510289192,0.954858899116516,0.1653094291687011,0.7553600072860718,1.0575079917907717,50000 -2176.286799430847,5.566740036010742,60241.45378923416,177330,0,60241.45378923416,0.6285000443458557,1.8258801698684688,10000,62429.50436306,0.9606584906578064,0.1450536251068115,0.7543799877166748,1.057165503501892,50000 -2192.8202583789825,5.632821083068848,60751.63723874092,178833,0,60751.63723874092,0.6291000247001648,1.8220672607421875,10000,62956.34075641632,0.9592633843421936,0.1504260599613189,0.7558799982070923,1.055589199066162,50000 -2209.394678592682,5.691540241241455,61261.6109521389,180336,0,61261.6109521389,0.6288000345230103,1.822423934936524,10000,63483.00061035156,0.9604392051696776,0.1449994593858719,0.7549799680709839,1.0576988458633425,50000 -2226.068264245987,5.75496768951416,61771.61212205887,181839,0,61771.61212205887,0.6292000412940979,1.8193235397338867,10000,64009.79198360443,0.9586654901504515,0.1504426002502441,0.7546799778938293,1.0552231073379517,50000 -2242.700765132904,5.826450347900391,62281.60619163513,183342,0,62281.60619163513,0.6309000253677368,1.8188600540161133,10000,64536.54286932945,0.9608178734779358,0.1480124741792678,0.7549600005149841,1.0551658868789673,50000 -2259.293065071106,5.892799139022827,62791.7273118496,184845,0,62791.7273118496,0.6299000382423401,1.818251609802246,10000,65063.375341653824,0.960379421710968,0.1480392068624496,0.7543999552726746,1.0552953481674194,50000 -2275.9461719989777,5.953683853149414,63301.88521718979,186348,0,63301.88521718979,0.629800021648407,1.8192886114120483,10000,65590.30063033104,0.961316168308258,0.1417189538478851,0.7547999620437622,1.0548179149627686,50000 -2292.624106407165,6.019184112548828,63812.025695085526,187851,0,63812.025695085526,0.6304000020027161,1.818678379058838,10000,66117.23640704155,0.9618940949440002,0.1434378921985626,0.7547199726104736,1.056044578552246,50000 -2309.2071928977966,6.081565856933594,64321.92574834824,189354,0,64321.92574834824,0.6302000284194946,1.817767858505249,10000,66643.83481454849,0.9598413109779358,0.1478601396083831,0.7548399567604065,1.0551180839538574,50000 -2326.568407535553,6.143704175949097,64831.91565823555,190857,0,64831.91565823555,0.6309000253677368,1.8206629753112795,10000,67171.30288815498,0.9608976244926452,0.1452582329511642,0.7548399567604065,1.0554226636886597,50000 -2343.153601646424,6.2114434242248535,65341.9563536644,192360,0,65341.9563536644,0.6299000382423401,1.818163514137268,10000,67698.0499484539,0.9600207209587096,0.1482952535152435,0.7549399733543396,1.055193305015564,50000 -2359.8171005249023,6.272164344787598,65851.87741112709,193862,0,65851.87741112709,0.6300000548362732,1.818334460258484,10000,68224.74874973297,0.9616549611091614,0.1444314867258072,0.7547399997711182,1.0553990602493286,50000 -2376.5136275291443,6.332875728607178,66361.8992049694,195366,0,66361.8992049694,0.6297000050544739,1.819195985794068,10000,68751.58087086678,0.9608577489852904,0.1464710235595703,0.7546199560165405,1.05693256855011,50000 -2393.079930305481,6.3962483406066895,66872.12175607681,196869,0,66872.12175607681,0.6299000382423401,1.8178199529647827,10000,69278.48593592644,0.961136758327484,0.1438667327165603,0.7547799944877625,1.055078148841858,50000 -2409.745950460434,6.455162048339844,67382.08436250687,198372,0,67382.08436250687,0.6314000487327576,1.8186239004135127,10000,69805.22742986679,0.9610371589660645,0.1463333070278167,0.7547599673271179,1.0546435117721558,50000 -2426.7303347587585,6.520244121551514,67891.98360681534,199874,0,67891.98360681534,0.6300000548362732,1.8194706439971924,10000,70332.2281923294,0.96097731590271,0.1450323164463043,0.7546600103378296,1.0563466548919678,50000 -2443.469132423401,6.584799528121948,68402.14937877655,201378,0,68402.14937877655,0.6297000050544739,1.8182106018066408,10000,70859.24960279465,0.9603993892669678,0.14662966132164,0.7549799680709839,1.0545313358306885,50000 -2460.0737595558167,6.650354385375977,68912.09048843384,202880,0,68912.09048843384,0.6297000050544739,1.819766640663147,10000,71385.91350674629,0.960758090019226,0.1468870043754577,0.7549200057983398,1.0562556982040403,50000 -2476.78812289238,6.723001718521118,69422.0167837143,204383,0,69422.0167837143,0.6309000253677368,1.8184700012207031,10000,71912.68066692352,0.9612563848495485,0.1475114524364471,0.7545999884605408,1.056083917617798,50000 -2493.426098585129,6.797011137008667,69932.16599369049,205886,0,69932.16599369049,0.6307000517845154,1.818460464477539,10000,72439.5957725048,0.958984375,0.1498239785432815,0.7548999786376953,1.0548142194747925,50000 -2510.198974609375,6.865321636199951,70442.23829221725,207389,0,70442.23829221725,0.6303000450134277,1.8208805322647093,10000,72966.56232953072,0.9612165093421936,0.1428539901971817,0.7551599740982056,1.0566010475158691,50000 -2526.792496442795,6.931980609893799,70952.31413340569,208892,0,70952.31413340569,0.6303000450134277,1.818329811096192,10000,73493.35128498077,0.960359513759613,0.1474143266677856,0.7548999786376953,1.055539608001709,50000 -2543.6113333702087,6.999336004257202,71462.24423503876,210395,0,71462.24423503876,0.6307000517845154,1.817723512649536,10000,74020.21932840347,0.9596021771430968,0.1500817537307739,0.7551800012588501,1.0549089908599854,50000 -2560.212978601456,7.059044599533081,71972.2426841259,211898,0,71972.2426841259,0.6303000450134277,1.817959427833557,10000,74546.93216991425,0.9599011540412904,0.1484949290752411,0.7547799944877625,1.0549051761627195,50000 -2576.818477153778,7.12625527381897,72482.38831710815,213401,0,72482.38831710815,0.6299000382423401,1.8180344104766848,10000,75073.80241346359,0.9609375,0.1480892598628997,0.7550999522209167,1.0552589893341064,50000 -2593.551813840866,7.196750164031982,72992.4620001316,214904,0,72992.4620001316,0.6300000548362732,1.8187536001205444,10000,75600.73349714279,0.9608976244926452,0.146467387676239,0.7546399831771851,1.054847002029419,50000 -2610.1900346279144,7.270415544509888,73502.42083287239,216408,0,73502.42083287239,0.6304000020027161,1.8173750638961792,10000,76127.45818305016,0.9601004123687744,0.1451017260551452,0.7546799778938293,1.0557056665420532,50000 -2626.9975488185883,7.337841749191284,74012.63173341751,217911,0,74012.63173341751,0.6309000253677368,1.8198227882385247,10000,76654.59640598297,0.9603196382522584,0.1469211876392364,0.7546600103378296,1.0558857917785645,50000 -2643.6508531570435,7.39803147315979,74522.71133613586,219415,0,74522.71133613586,0.6302000284194946,1.819336175918579,10000,77181.44177699089,0.9604790806770324,0.1473052203655243,0.7550599575042725,1.0550016164779663,50000 -2660.755885362625,7.464325189590454,75032.91843128204,220918,0,75032.91843128204,0.6307000517845154,1.8186966180801392,10000,77708.87258601189,0.960718274116516,0.1482438445091247,0.7550399899482727,1.0545982122421265,50000 -2677.481550216675,7.535091638565063,75543.07302880287,222422,0,75543.07302880287,0.6304000020027161,1.819899082183838,10000,78235.87724661827,0.960558831691742,0.1466655135154724,0.754859983921051,1.0557494163513184,50000 -2694.1498594284058,7.601906299591064,76053.1585996151,223925,0,76053.1585996151,0.629300057888031,1.8206843137741089,10000,78762.75082278252,0.9612165093421936,0.1431528031826019,0.7547399997711182,1.0563209056854248,50000 -2710.828951358795,7.669835329055786,76563.16371178627,225428,0,76563.16371178627,0.631100058555603,1.8175678253173828,10000,79289.55604577065,0.9616350531578064,0.1421045213937759,0.7547999620437622,1.0561600923538208,50000 -2727.477970123291,7.7380051612854,77073.34233403206,226931,0,77073.34233403206,0.6301000118255615,1.819921970367432,10000,79816.50515580177,0.9616150856018066,0.1458224952220916,0.7543599605560303,1.0568692684173584,50000 -2744.826174020767,8.672579765319824,77582.4801542759,228431,0,77582.4801542759,0.6300000548362732,1.8176454305648804,10000,80343.97794318199,0.9591238498687744,0.1471176445484161,0.7551800012588501,1.0547051429748535,50000 -2761.564717531204,8.739872694015503,78092.46252202988,229934,0,78092.46252202988,0.6292000412940979,1.818885087966919,10000,80870.81914234161,0.9608976244926452,0.1465269774198532,0.7546399831771851,1.0559604167938232,50000 -2778.42622590065,8.809640645980835,78602.5905714035,231438,0,78602.5905714035,0.6313000321388245,1.8184458017349243,10000,81397.9317638874,0.9616549611091614,0.1448116451501846,0.7546799778938293,1.0556319952011108,50000 -2795.0093553066254,8.879449844360352,79112.5441596508,232941,0,79112.5441596508,0.6312000155448914,1.81929874420166,10000,81924.59214305878,0.9609175324440002,0.1463182866573333,0.7546600103378296,1.056311011314392,50000 -2811.688470363617,8.947274923324585,79622.69375491142,234445,0,79622.69375491142,0.6300000548362732,1.819419264793396,10000,82451.54144382477,0.961316168308258,0.1448947787284851,0.7548799514770508,1.055671215057373,50000 -2828.3450248241425,9.016587018966677,80132.81121706963,235948,0,80132.81121706963,0.6304000020027161,1.8185477256774905,10000,82978.43870472908,0.9607979655265808,0.1469217389822006,0.754859983921051,1.054784059524536,50000 -2845.0322086811066,9.085235595703123,80643.00874638557,237452,0,80643.00874638557,0.6301000118255615,1.819985270500183,10000,83505.44635510445,0.9615154266357422,0.1444291174411773,0.7550999522209167,1.0559808015823364,50000 -2861.686569929123,9.158106327056885,81153.17875504494,238955,0,81153.17875504494,0.631100058555603,1.8194414377212524,10000,84032.39768075943,0.9600805044174194,0.1469574570655822,0.7549200057983398,1.055910587310791,50000 -2878.328460454941,9.227638959884644,81663.18556928635,240459,0,81663.18556928635,0.6307000517845154,1.8203246593475344,10000,84559.16914653778,0.9600406289100648,0.1476673632860183,0.7547399997711182,1.0556063652038574,50000 -2895.289322376251,9.2988703250885,82173.29926466942,241961,0,82173.29926466942,0.6307000517845154,1.8199002742767327,10000,85086.36740994453,0.961316168308258,0.1453227698802948,0.7550599575042725,1.0559238195419312,50000 -2911.817486524582,9.36837124824524,82683.24376130104,243464,0,82683.24376130104,0.629800021648407,1.8192369937896729,10000,85612.96273779869,0.9612762928009032,0.1468772292137146,0.7552199959754944,1.054276943206787,50000 -2928.6007051467896,9.440913200378418,83193.1289358139,244966,0,83193.1289358139,0.6302000284194946,1.8175748586654663,10000,86139.756534338,0.9592832922935486,0.1474329233169555,0.7551599740982056,1.0548502206802368,50000 -2945.365699291229,9.517107009887695,83703.21215891838,246469,0,83703.21215891838,0.6319000124931335,1.820106506347656,10000,86666.73375701904,0.9610969424247742,0.1457830071449279,0.7546199560165405,1.0571213960647583,50000 -2961.9213037490845,9.590791940689089,84213.20818805695,247972,0,84213.20818805695,0.6305000185966492,1.8187205791473389,10000,87193.41122412682,0.9592633843421936,0.1491094678640365,0.7548399567604065,1.0558416843414309,50000 -2978.6110417842865,9.664908409118652,84723.22751140594,249475,0,84723.22751140594,0.6299000382423401,1.820226669311524,10000,87720.24756479263,0.960379421710968,0.1473689973354339,0.7548799514770508,1.05515718460083,50000 -2995.421051502228,9.74666666984558,85233.19104790688,250978,0,85233.19104790688,0.6300000548362732,1.8185988664627075,10000,88247.156447649,0.9597018361091614,0.1512549221515655,0.7553199529647827,1.055238127708435,50000 -3012.068596839905,9.817442417144775,85743.23414087296,252481,0,85743.23414087296,0.6306000351905823,1.818914532661438,10000,88773.97137570381,0.9615154266357422,0.1466091573238372,0.7547799944877625,1.0560367107391355,50000 -3028.657395362854,9.926298141479492,86253.38255238533,253984,0,86253.38255238533,0.6312000155448914,1.8176671266555784,10000,89300.87192893028,0.9611168503761292,0.1454275101423263,0.7546799778938293,1.0555272102355957,50000 -3045.324955224991,9.99882435798645,86763.52758550644,255488,0,86763.52758550644,0.6297000050544739,1.818804144859314,10000,89827.80944180489,0.9581273794174194,0.1509248316287994,0.7550199627876282,1.0555461645126345,50000 -3061.9612522125244,10.076414108276367,87273.6397869587,256991,0,87273.6397869587,0.6306000351905823,1.8188364505767824,10000,90354.68846488,0.9624322056770324,0.1416583657264709,0.7550199627876282,1.0557940006256104,50000 -3078.59850025177,10.152153968811035,87783.56503915787,258494,0,87783.56503915787,0.6300000548362732,1.818835020065308,10000,90881.38066411018,0.9602399468421936,0.1474353671073913,0.7547399997711182,1.0550732612609863,50000 -3095.220307826996,10.222743272781372,88293.56649446487,259997,0,88293.56649446487,0.6309000253677368,1.81663715839386,10000,91408.12760949136,0.961316168308258,0.146479919552803,0.7549399733543396,1.0550395250320437,50000 -3111.8932209014893,10.29693603515625,88803.53046774864,261500,0,88803.53046774864,0.6307000517845154,1.820395946502685,10000,91934.89223361015,0.9601402878761292,0.1482132524251938,0.7545199990272522,1.0562045574188232,50000 -3128.666751384735,10.373113632202148,89313.42373251915,263002,0,89313.42373251915,0.6306000351905823,1.81957483291626,10000,92461.68709921835,0.9615154266357422,0.1408833563327789,0.7550599575042725,1.0564838647842407,50000 -3145.334497213364,10.445182085037231,89823.2990591526,264505,0,89823.2990591526,0.6299000382423401,1.818839192390442,10000,92988.3552968502,0.9611168503761292,0.146037608385086,0.7548799514770508,1.054930329322815,50000 -3162.2875208854675,10.524413585662842,90333.40451908112,266008,0,90333.40451908112,0.6300000548362732,1.820417881011963,10000,93515.5467464924,0.9613759517669678,0.1456584334373474,0.7547999620437622,1.0558500289916992,50000 -3179.641728401184,10.585381507873535,90843.54014015198,267511,0,90843.54014015198,0.6303000450134277,1.820171356201172,10000,94043.1495695114,0.9602399468421936,0.1465692073106765,0.7547199726104736,1.0563976764678955,50000 -3196.283055782318,10.662968635559082,91353.7362806797,269014,0,91353.7362806797,0.6302000284194946,1.819125056266785,10000,94570.11760783195,0.9600805044174194,0.14800925552845,0.7546399831771851,1.0555812120437622,50000 -3212.8778393268585,10.743399143218994,91863.84095406532,270517,0,91863.84095406532,0.6313000321388245,1.819807291030884,10000,95096.95014238358,0.9622727632522584,0.1424338072538375,0.7545199990272522,1.0561974048614502,50000 -3229.5373075008392,10.815797090530396,92374.03491449356,272020,0,92374.03491449356,0.6302000284194946,1.819312572479248,10000,95623.9277703762,0.9606385231018066,0.1462656706571579,0.7546399831771851,1.0556144714355469,50000 -3246.166423082352,10.89329195022583,92884.12191414832,273523,0,92884.12191414832,0.6306000351905823,1.820548176765442,10000,96150.77471852304,0.9612563848495485,0.1452231556177139,0.7550399899482727,1.0557005405426023,50000 -3262.778742313385,10.976376295089722,93394.05655407906,275025,0,93394.05655407906,0.6304000020027161,1.818693280220032,10000,96677.45746946336,0.9610570669174194,0.1458848714828491,0.7547199726104736,1.056044578552246,50000 -3279.488739728928,11.052030563354492,93904.15710663795,276528,0,93904.15710663795,0.6305000185966492,1.8192275762557983,10000,97204.39679145812,0.9612165093421936,0.1463463455438614,0.7547599673271179,1.055619716644287,50000 -3296.2211039066315,11.124534606933594,94414.21587204932,278029,0,94414.21587204932,0.6306000351905823,1.8184598684310915,10000,97731.3130865097,0.9600406289100648,0.1457443684339523,0.7548799514770508,1.055456519126892,50000 -3312.801731109619,11.201616048812866,94924.25241136552,279532,0,94924.25241136552,0.6300000548362732,1.8213742971420288,10000,98258.06078863144,0.9602997303009032,0.1474865823984146,0.7548799514770508,1.0558282136917114,50000 -3329.51867890358,11.280039310455322,95434.31954813004,281034,0,95434.31954813004,0.6305000185966492,1.8184127807617188,10000,98784.97705054285,0.9610969424247742,0.1467566490173339,0.7548199892044067,1.0551109313964844,50000 -3346.3417830467224,11.358751773834229,95944.22466540337,282537,0,95944.22466540337,0.6302000284194946,1.8191053867340088,10000,99311.83710670473,0.959382951259613,0.148421362042427,0.7549799680709839,1.0551875829696655,50000 -3363.020151615143,11.438347816467283,96454.19933009148,284040,0,96454.19933009148,0.6299000382423401,1.8197104930877688,10000,99838.62381887436,0.9605388641357422,0.1470494121313095,0.7549200057983398,1.0552425384521484,50000 -3379.6935136318207,11.514539003372192,96964.37281227112,285543,0,96964.37281227112,0.6309000253677368,1.817240715026856,10000,100365.60154628754,0.9606186151504515,0.1456643491983413,0.754859983921051,1.0539931058883667,50000 -3396.415134191513,11.59593391418457,97474.5063648224,287046,0,97474.5063648224,0.6299000382423401,1.8189139366149905,10000,100892.59136629105,0.9601203799247742,0.1489818543195724,0.7545999884605408,1.0550166368484497,50000 -3413.089588880539,11.67100715637207,97984.6793088913,288550,0,97984.6793088913,0.6303000450134277,1.819604754447937,10000,101419.5659172535,0.9608577489852904,0.146798700094223,0.7548799514770508,1.0558570623397827,50000 -3429.8191237449646,11.751428842544556,98494.60548329352,290052,0,98494.60548329352,0.6304000020027161,1.818774700164795,10000,101946.35471343994,0.9595623016357422,0.1505014598369598,0.7546799778938293,1.0554848909378052,50000 -3446.859624147415,11.829417705535889,99004.58482909204,291555,0,99004.58482909204,0.6304000020027161,1.817091703414917,10000,102473.50423121452,0.9608976244926452,0.1472378075122833,0.7546199560165405,1.054668664932251,50000 -3463.4972772598267,11.915416479110718,99514.50890374184,293057,0,99514.50890374184,0.6300000548362732,1.8181191682815552,10000,103000.20554113388,0.9612762928009032,0.1453863382339477,0.7545599937438965,1.055548071861267,50000 -3480.2958142757416,11.99556565284729,100024.59973978996,294560,0,100024.59973978996,0.6296000480651855,1.8201440572738647,10000,103527.22813105585,0.959203600883484,0.1480921506881714,0.754539966583252,1.0552234649658203,50000 -3496.9346759319305,12.068589448928831,100534.60390925407,296062,0,100534.60390925407,0.629800021648407,1.8213645219802856,10000,104053.99726223946,0.9609175324440002,0.1463489681482315,0.7550599575042725,1.0558689832687378,50000 -3513.5654361248016,12.149300336837769,101044.60762357712,297565,0,101044.60762357712,0.6313000321388245,1.8192485570907595,10000,104580.7647485733,0.9612762928009032,0.1450014561414718,0.7545599937438965,1.055658221244812,50000 -3530.203478574753,12.22605037689209,101554.6838247776,299068,0,101554.6838247776,0.6301000118255615,1.818241834640503,10000,105107.60811972618,0.9602000713348388,0.1488392055034637,0.7549200057983398,1.056084394454956,50000 -3546.855792999268,12.306427717208862,102064.62591266632,300571,0,102064.62591266632,0.6300000548362732,1.818524479866028,10000,105634.33630681038,0.961136758327484,0.14404296875,0.7550999522209167,1.0556014776229858,50000 -3563.481164932251,12.430617094039915,102574.78446245192,302074,0,102574.78446245192,0.6300000548362732,1.8204947710037231,10000,106161.29858207704,0.9614756107330322,0.1412848830223083,0.7550599575042725,1.056552171707153,50000 -3580.18265748024,12.513375520706177,103084.86869740486,303577,0,103084.86869740486,0.6292000412940979,1.8195948600769043,10000,106688.22069692612,0.961316168308258,0.1461625099182129,0.7547799944877625,1.05518901348114,50000 -3597.605567932129,12.594209432601929,103594.85424780846,305080,0,103594.85424780846,0.6306000351905823,1.817049264907837,10000,107215.7625875473,0.9605388641357422,0.1448585391044616,0.7550199627876282,1.054583191871643,50000 -3614.303423643112,12.67265248298645,104104.78870105743,306583,0,104104.78870105743,0.6307000517845154,1.818765878677368,10000,107742.52642440796,0.959980845451355,0.1489111185073852,0.7552199959754944,1.0559635162353516,50000 -3630.9186856746674,12.757277011871338,104614.98310089111,308086,0,104614.98310089111,0.6294000148773193,1.8203606605529783,10000,108269.4746274948,0.9614157676696776,0.1438297778367996,0.7548399567604065,1.0567333698272705,50000 -3647.511518955231,12.83810329437256,105125.10056447984,309589,0,105125.10056447984,0.6310000419616699,1.819544672966004,10000,108796.3183927536,0.960558831691742,0.147440105676651,0.7550399899482727,1.0556894540786743,50000 -3664.230366706848,12.919663667678831,105635.1733725071,311092,0,105635.1733725071,0.6300000548362732,1.8196454048156736,10000,109323.24497246742,0.9615154266357422,0.1449520289897918,0.7554399967193604,1.0552821159362793,50000 -3680.8267703056335,13.000634670257568,106145.30757308006,312595,0,106145.30757308006,0.6308000087738037,1.8177525997161863,10000,109850.11073994637,0.9604990482330322,0.145500361919403,0.7548199892044067,1.055170655250549,50000 -3697.522133350372,13.088707447052002,106655.35900592804,314097,0,106655.35900592804,0.6306000351905823,1.818840742111206,10000,110376.99791812895,0.9604591727256776,0.1466144472360611,0.7550399899482727,1.0555564165115356,50000 -3714.099026441574,13.175846815109251,107165.2802259922,315600,0,107165.2802259922,0.631100058555603,1.820388913154602,10000,110903.63587498663,0.9605388641357422,0.1460705995559692,0.7547000050544739,1.0562903881072998,50000 -3730.758029222488,13.25473165512085,107675.1609697342,317102,0,107675.1609697342,0.6310000419616699,1.8175864219665527,10000,111430.3076775074,0.9609972834587096,0.1439758241176605,0.7548399567604065,1.0555888414382937,50000 -3747.4862003326416,13.34488844871521,108185.10398316383,318605,0,108185.10398316383,0.6300000548362732,1.822026610374451,10000,111957.12183737756,0.9608378410339355,0.1478148102760315,0.7548799514770508,1.056528925895691,50000 -3764.187334537506,13.422353982925417,108695.18180251122,320107,0,108695.18180251122,0.6300000548362732,1.8197250366210933,10000,112484.0316901207,0.9611168503761292,0.1467342525720596,0.7544999718666077,1.055534839630127,50000 -3780.785356760025,13.505805730819702,109205.0882525444,321610,0,109205.0882525444,0.6301000118255615,1.8181936740875244,10000,113010.67347025871,0.9595623016357422,0.14896921813488,0.7551999688148499,1.0555588006973269,50000 -3797.404670953751,13.588683843612673,109715.07874393465,323112,0,109715.07874393465,0.6301000118255615,1.8190803527832031,10000,113537.41925096512,0.9606584906578064,0.1447600871324539,0.7548999786376953,1.0555732250213623,50000 -3814.077670812607,13.667658567428589,110224.97239756584,324614,0,110224.97239756584,0.6305000185966492,1.81911849975586,10000,114064.11824631692,0.9600805044174194,0.1485663652420044,0.7547599673271179,1.055357575416565,50000 -3830.655671596527,13.755608558654783,110734.83861660956,326116,0,110734.83861660956,0.6304000020027161,1.819116711616516,10000,114590.7036960125,0.959582269191742,0.1480408757925033,0.754859983921051,1.055404782295227,50000 -3847.3105340003967,13.836986780166626,111244.93863582613,327618,0,111244.93863582613,0.6300000548362732,1.8186113834381104,10000,115117.5930685997,0.961355984210968,0.147781953215599,0.7549799680709839,1.0543818473815918,50000 -3863.8564035892487,13.919602155685425,111755.11136102676,329121,0,111755.11136102676,0.6299000382423401,1.8170280456542969,10000,115644.44719862938,0.959363043308258,0.1508469581604004,0.7549799680709839,1.055037498474121,50000 -3880.439862012863,14.007933139801024,112265.23484659196,330624,0,112265.23484659196,0.6305000185966492,1.8186101913452148,10000,116171.29470396042,0.961336076259613,0.1432124525308609,0.7545599937438965,1.0553230047225952,50000 -3897.117249965668,14.09683918952942,112775.25321817398,332127,0,112775.25321817398,0.6305000185966492,1.8171299695968628,10000,116698.1319694519,0.959183633327484,0.1507417559623718,0.7548399567604065,1.0550819635391235,50000 -3913.714976072312,14.186609506607056,113285.4351055622,333631,0,113285.4351055622,0.6306000351905823,1.819736361503601,10000,117225.0557732582,0.9612563848495485,0.1435246765613556,0.7547599673271179,1.0559492111206057,50000 -3930.369282960892,14.27349853515625,113795.54326581956,335134,0,113795.54326581956,0.6310000419616699,1.8180768489837649,10000,117751.9583992958,0.9598612785339355,0.1471763551235199,0.7553199529647827,1.055428147315979,50000 -3947.1287870407104,14.36074447631836,114305.43227887154,336637,0,114305.43227887154,0.6302000284194946,1.8188215494155884,10000,118278.74791812895,0.961136758327484,0.147893875837326,0.7549200057983398,1.0557163953781128,50000 -3963.851808786392,14.4454824924469,114815.45590782166,338139,0,114815.45590782166,0.6310000419616699,1.8178565502166748,10000,118805.63232278824,0.9604790806770324,0.1472291797399521,0.7548399567604065,1.055741786956787,50000 -3980.400431394577,14.531226396560667,115325.57354831696,339643,0,115325.57354831696,0.6303000450134277,1.8199061155319207,10000,119332.43783807756,0.9617944359779358,0.1409309804439544,0.7549600005149841,1.0552440881729126,50000 -3997.29501080513,14.61514163017273,115835.75323462486,341147,0,115835.75323462486,0.6305000185966492,1.820884227752685,10000,119859.64888954164,0.9620137214660645,0.1430947631597519,0.7548999786376953,1.0565391778945925,50000 -4013.947923898697,14.70347285270691,116345.7600734234,342650,0,116345.7600734234,0.6305000185966492,1.8172508478164675,10000,120386.4504327774,0.9606783986091614,0.1467174738645553,0.7546600103378296,1.0551903247833252,50000 -4031.3891978263855,14.792497396469116,116855.8805770874,344153,0,116855.8805770874,0.6302000284194946,1.8171762228012085,10000,120914.15520620346,0.9591438174247742,0.1492006033658981,0.7547199726104736,1.0547455549240112,50000 -4048.061734199524,14.883959531784058,117365.90293121338,345656,0,117365.90293121338,0.6303000450134277,1.8182249069213867,10000,121440.9950222969,0.9612762928009032,0.1464512199163437,0.7548999786376953,1.0546934604644775,50000 -4064.625182628632,14.9880850315094,117875.79716300964,347158,0,117875.79716300964,0.6308000087738037,1.817484259605408,10000,121967.6102962494,0.9612165093421936,0.1453184336423874,0.7544599771499634,1.0555799007415771,50000 -4081.1895790100098,15.07457399368286,118385.72451281548,348661,0,118385.72451281548,0.6313000321388245,1.8207018375396729,10000,122494.24131369592,0.9612962007522584,0.1442330181598663,0.7548999786376953,1.055746078491211,50000 -4097.764590501785,15.211965799331663,118895.66305184364,350163,0,118895.66305184364,0.6304000020027161,1.819610834121704,10000,123020.94657683372,0.9616350531578064,0.1450804620981216,0.7546199560165405,1.0556223392486572,50000 -4114.402928113937,15.298362493515016,119405.83976817132,351667,0,119405.83976817132,0.6300000548362732,1.8202815055847168,10000,123547.9008114338,0.9608577489852904,0.1461744904518127,0.754859983921051,1.0564955472946167,50000 -4131.206185817719,15.391158103942873,119916.00923371316,353169,0,119916.00923371316,0.6300000548362732,1.819525122642517,10000,124075.02059555054,0.9614157676696776,0.1444696485996246,0.7545199990272522,1.0556888580322266,50000 -4147.925310850143,15.483352661132812,120425.9566578865,354672,0,120425.9566578865,0.6304000020027161,1.817675828933716,10000,124601.83222198486,0.9606186151504515,0.1470703184604644,0.7550999522209167,1.0547600984573364,50000 -4164.55343580246,15.569127082824709,120936.03289294244,356175,0,120936.03289294244,0.629800021648407,1.819735646247864,10000,125128.67629480362,0.9604591727256776,0.1472638845443725,0.7554199695587158,1.056190848350525,50000 -4181.1410665512085,15.66112756729126,121446.13084578514,357678,0,121446.13084578514,0.6301000118255615,1.8201425075531008,10000,125655.50641322136,0.9614756107330322,0.146551638841629,0.7547399997711182,1.0553967952728271,50000 -4197.899563074112,15.761876106262209,121956.0479941368,359180,0,121956.0479941368,0.6306000351905823,1.819602012634277,10000,126182.3363289833,0.9594826102256776,0.1474414020776748,0.7551800012588501,1.055631160736084,50000 -4214.466181278229,15.852198362350464,122466.16347789764,360684,0,122466.16347789764,0.6294000148773193,1.8188576698303225,10000,126709.16116809844,0.9600605964660645,0.1481878459453582,0.7550599575042725,1.0551674365997314,50000 -4231.663587808609,15.942219734191896,122976.25839185716,362187,0,122976.25839185716,0.6307000517845154,1.818588137626648,10000,127236.59667944908,0.9610371589660645,0.1449005752801895,0.7542600035667419,1.055108666419983,50000 -4248.319851875305,16.017550230026245,123486.3828933239,363690,0,123486.3828933239,0.6306000351905823,1.8202086687088013,10000,127763.50550031662,0.9592434167861938,0.1494711190462112,0.7552199959754944,1.0559954643249512,50000 -4264.995944738388,16.10771369934082,123996.3074054718,365192,0,123996.3074054718,0.6299000382423401,1.8223227262496948,10000,128290.24997830392,0.9596819281578064,0.1469049155712127,0.7549399733543396,1.0571703910827637,50000 -4281.69885134697,16.204543352127075,124506.4336566925,366695,0,124506.4336566925,0.6302000284194946,1.8210707902908323,10000,128817.22952365877,0.961156725883484,0.1489018052816391,0.7549999952316284,1.055810809135437,50000 -4298.291492462158,16.294630765914917,125016.474984169,368197,0,125016.474984169,0.6299000382423401,1.818937420845032,10000,129344.00648355484,0.9597018361091614,0.1484669595956802,0.7549999952316284,1.0560575723648071,50000 -4314.997852563858,16.387300729751587,125526.76858377457,369701,0,125526.76858377457,0.6312000155448914,1.8172831535339355,10000,129871.15259885788,0.9613958597183228,0.1455031484365463,0.7548799514770508,1.054102063179016,50000 -4331.719168186188,16.48409938812256,126036.9502067566,371204,0,126036.9502067566,0.6304000020027161,1.81916081905365,10000,130398.2058544159,0.959382951259613,0.1471273601055145,0.7553399801254272,1.0547292232513428,50000 -4348.296098232269,16.576444387435913,126546.84605765344,372707,0,126546.84605765344,0.6306000351905823,1.8193237781524656,10000,130924.82411670683,0.9607780575752258,0.1458850502967834,0.7550399899482727,1.0556190013885498,50000 -4364.946834564209,16.668809413909912,127056.72100305556,374209,0,127056.72100305556,0.6301000118255615,1.81897234916687,10000,131451.49540233612,0.9608178734779358,0.1466042101383209,0.7549600005149841,1.054998755455017,50000 -4381.365758657455,16.76178002357483,127566.63928079604,375711,0,127566.63928079604,0.6310000419616699,1.8201284408569336,10000,131977.97884607315,0.9608178734779358,0.1465323269367218,0.754859983921051,1.0567834377288818,50000 -4397.861366271973,16.862693548202515,128076.54186677931,377213,0,128076.54186677931,0.6302000284194946,1.820207476615905,10000,132504.5306007862,0.9602598547935486,0.146878108382225,0.7548799514770508,1.0555857419967651,50000 -4414.532790660858,17.786174058914185,128585.89695334436,378714,0,128585.89695334436,0.6297000050544739,1.818804383277893,10000,133031.53467178345,0.961694836616516,0.1412075906991958,0.7550399899482727,1.055570125579834,50000 -4431.161715507507,17.883363485336304,129095.92270565032,380217,0,129095.92270565032,0.6300000548362732,1.8188520669937127,10000,133558.33870458603,0.961316168308258,0.1460568010807037,0.7547399997711182,1.05453622341156,50000 -4448.361759185791,17.9791898727417,129605.80332398416,381720,0,129605.80332398416,0.6305000185966492,1.8185116052627563,10000,134085.56870532036,0.9607381820678712,0.1449822038412094,0.7547000050544739,1.0562328100204468,50000 -4465.047143936157,18.068235158920288,130115.81632757188,383223,0,130115.81632757188,0.6303000450134277,1.8197861909866333,10000,134612.40925621986,0.9597018361091614,0.1486869752407074,0.7551800012588501,1.0554879903793335,50000 -4481.549957990646,18.162591457366943,130625.8651521206,384726,0,130625.8651521206,0.629800021648407,1.818486213684082,10000,135139.1086113453,0.9608976244926452,0.145926147699356,0.7549200057983398,1.055842638015747,50000 -4498.146886110306,18.27062964439392,131135.7101225853,386228,0,131135.7101225853,0.6301000118255615,1.817418098449707,10000,135665.71124458313,0.9614357352256776,0.1451815962791443,0.7549999952316284,1.0552767515182495,50000 -4514.734850406647,18.367199420928955,131645.83796977997,387731,0,131645.83796977997,0.6303000450134277,1.8203297853469849,10000,136192.5762116909,0.9614157676696776,0.1451276540756225,0.7549200057983398,1.0561199188232422,50000 -4531.350663900375,18.472204446792603,132155.74128174782,389233,0,132155.74128174782,0.6301000118255615,1.8190268278121948,10000,136719.2530815601,0.961156725883484,0.1456144601106643,0.7547399997711182,1.055548906326294,50000 -4548.0277144908905,18.570212364196777,132665.60343956947,390736,0,132665.60343956947,0.6299000382423401,1.8202205896377563,10000,137245.94326090813,0.9600805044174194,0.1483380496501922,0.7553600072860718,1.0564292669296265,50000 -4564.457427501679,18.6664834022522,133175.61843442917,392239,0,133175.61843442917,0.6301000118255615,1.8189407587051392,10000,137772.5371003151,0.9613958597183228,0.143719732761383,0.7547599673271179,1.055400013923645,50000 -4580.938060998917,18.762927532196045,133685.6428039074,393743,0,133685.6428039074,0.6302000284194946,1.8172541856765747,10000,138299.19182229042,0.9607780575752258,0.144915223121643,0.7547199726104736,1.0558115243911743,50000 -4597.527137756348,18.85831356048584,134195.6141886711,395246,0,134195.6141886711,0.6305000185966492,1.8179938793182373,10000,138825.90058135986,0.960180163383484,0.1485379189252853,0.7544999718666077,1.0541654825210571,50000 -4614.233896970749,18.956045627594,134705.46764349937,396749,0,134705.46764349937,0.6306000351905823,1.8182374238967896,10000,139352.6130232811,0.9606186151504515,0.1472035646438598,0.7551400065422058,1.054998517036438,50000 -4630.798380851746,19.05357837677002,135215.65347242355,398252,0,135215.65347242355,0.6303000450134277,1.8208487033844,10000,139879.5144290924,0.9606983065605164,0.1478620916604995,0.754859983921051,1.055959701538086,50000 -4647.372666358948,19.147751808166504,135725.60052633286,399755,0,135725.60052633286,0.6297000050544739,1.8177335262298584,10000,140406.1835153103,0.9601004123687744,0.1449165344238281,0.7548799514770508,1.054701566696167,50000 -4663.987517356873,19.24118137359619,136235.73774003985,401258,0,136235.73774003985,0.6304000020027161,1.8172955513000488,10000,140933.08319306374,0.9598413109779358,0.1499073654413223,0.7547599673271179,1.05543315410614,50000 -4680.529687643051,19.362505674362183,136745.5668039322,402760,0,136745.5668039322,0.6304000020027161,1.8198953866958616,10000,141459.62825584412,0.9602598547935486,0.1475264430046081,0.7549600005149841,1.056088924407959,50000 -4697.173872232437,19.46660351753235,137255.6495103836,404263,0,137255.6495103836,0.6304000020027161,1.8172743320465088,10000,141986.512663126,0.9608378410339355,0.1480744481086731,0.7546600103378296,1.0558913946151731,50000 -4713.778460979462,19.561405658721924,137765.65123844147,405766,0,137765.65123844147,0.6300000548362732,1.818644642829895,10000,142513.2674689293,0.959781527519226,0.1490054130554199,0.754539966583252,1.0548299551010132,50000 -4730.435454368591,19.65113377571106,138275.66917204857,407269,0,138275.66917204857,0.6303000450134277,1.81747841835022,10000,143040.0849058628,0.9610570669174194,0.1444486230611801,0.7546199560165405,1.055034637451172,50000 -4747.116379022598,19.74397087097168,138785.59047579765,408772,0,138785.59047579765,0.6304000020027161,1.8169559240341189,10000,143566.8327577114,0.9592633843421936,0.1505450159311294,0.7552399635314941,1.0552719831466677,50000 -4763.699547767639,19.84065055847168,139295.51828718185,410274,0,139295.51828718185,0.6297000050544739,1.8201110363006592,10000,144093.49433231354,0.9607381820678712,0.1451594531536102,0.7550999522209167,1.0560662746429443,50000 -4780.302748441696,19.94161033630371,139805.68612885475,411778,0,139805.68612885475,0.6297000050544739,1.8194156885147093,10000,144620.41824531555,0.9607381820678712,0.1458351612091064,0.7547399997711182,1.0559715032577517,50000 -4796.926322221756,20.04148244857788,140315.84532666206,413281,0,140315.84532666206,0.6307000517845154,1.818820595741272,10000,145147.35459661484,0.9610171914100648,0.1474115848541259,0.7545599937438965,1.05522882938385,50000 -4813.552445888519,20.14136648178101,140825.76637721062,414784,0,140825.76637721062,0.6302000284194946,1.818278431892395,10000,145674.05481624603,0.9598612785339355,0.1473547369241714,0.7548799514770508,1.0550575256347656,50000 -4830.139992237091,20.239493131637573,141335.79673457146,416287,0,141335.79673457146,0.629800021648407,1.8195641040802,10000,146200.8248922825,0.9617745280265808,0.1414497345685959,0.7548199892044067,1.056381106376648,50000 -4846.724861383438,20.34183478355408,141845.92551875114,417791,0,141845.92551875114,0.6301000118255615,1.818116307258606,10000,146727.69525814056,0.9614556431770324,0.1459291577339172,0.7548199892044067,1.0547126531600952,50000 -4863.333392381668,20.443277597427368,142356.03491735458,419294,0,142356.03491735458,0.6306000351905823,1.8181205987930296,10000,147254.5678577423,0.9610570669174194,0.1447632610797882,0.7547599673271179,1.0562463998794556,50000 -4880.598002910614,20.54567670822144,142865.97271060944,420797,0,142865.97271060944,0.6304000020027161,1.8204907178878784,10000,147781.92643356323,0.960180163383484,0.1462839394807815,0.7548999786376953,1.056235671043396,50000 -4897.206827640533,20.6411075592041,143375.83768200874,422299,0,143375.83768200874,0.6300000548362732,1.817649483680725,10000,148308.54959082603,0.9609175324440002,0.146629050374031,0.7549200057983398,1.054916501045227,50000 -4913.787359714508,20.7419707775116,143885.76449489594,423801,0,143885.76449489594,0.6299000382423401,1.819965362548828,10000,148835.2104690075,0.9602997303009032,0.1462405920028686,0.7544599771499634,1.0567693710327148,50000 -4930.458664178848,20.84931969642639,144395.63716816902,425303,0,144395.63716816902,0.631100058555603,1.8174110651016235,10000,149361.91491818428,0.9620934128761292,0.1448357999324798,0.7547000050544739,1.054429292678833,50000 -4947.074217557907,20.96125626564026,144905.5118675232,426806,0,144905.5118675232,0.6296000480651855,1.8200547695159912,10000,149888.57080936432,0.9610570669174194,0.1460616737604141,0.7549600005149841,1.0565731525421145,50000 -4963.67450594902,21.06189441680908,145415.37442803383,428308,0,145415.37442803383,0.629800021648407,1.8188358545303345,10000,150415.18838214874,0.9610570669174194,0.1456134170293808,0.7544999718666077,1.0552159547805786,50000 -4980.192862272263,21.162484407424927,145925.31875014305,429811,0,145925.31875014305,0.6301000118255615,1.8180526494979856,10000,150941.80598330498,0.9610570669174194,0.146034225821495,0.7547599673271179,1.054571509361267,50000 -4996.765533208847,21.2678291797638,146435.1837067604,431313,0,146435.1837067604,0.6306000351905823,1.818410396575928,10000,151468.40133213997,0.9602598547935486,0.1460363268852234,0.7546799778938293,1.0550318956375122,50000 -5013.30003285408,21.38243818283081,146945.2528398037,432816,0,146945.2528398037,0.6296000480651855,1.8184632062911987,10000,151995.17208194733,0.9611766338348388,0.1457973569631576,0.7546600103378296,1.0549418926239014,50000 -5029.848837137222,21.50903558731079,147455.2933781147,434318,0,147455.2933781147,0.6299000382423401,1.8182473182678225,10000,152521.94089126587,0.9610171914100648,0.1467979103326797,0.7549999952316284,1.0560564994812012,50000 -5046.399070739746,21.609692573547363,147965.1800661087,435821,0,147965.1800661087,0.6305000185966492,1.819121599197388,10000,153048.53157043457,0.9590640664100648,0.1513096988201141,0.7549799680709839,1.0555171966552734,50000 -5062.943537712097,21.714293003082275,148475.27742362022,437324,0,148475.27742362022,0.6303000450134277,1.817741870880127,10000,153575.33093428612,0.9607979655265808,0.1454037278890609,0.7547799944877625,1.0546343326568604,50000 -5079.535908937454,21.82430601119995,148985.24048876762,438827,0,148985.24048876762,0.629300057888031,1.8190243244171145,10000,154102.04985761642,0.9602997303009032,0.1466187983751297,0.7544199824333191,1.0557987689971924,50000 -5096.1784517765045,21.92604899406433,149495.2818260193,440330,0,149495.2818260193,0.629800021648407,1.818731546401977,10000,154628.88868403435,0.9606783986091614,0.1473367661237716,0.7546199560165405,1.0556747913360596,50000 -5112.650840759277,22.02223467826844,150005.31732654572,441833,0,150005.31732654572,0.6296000480651855,1.8205649852752688,10000,155155.54529619217,0.9596819281578064,0.1484495252370834,0.7549600005149841,1.0566083192825315,50000 -5129.289839744568,22.123254776000977,150515.3161327839,443336,0,150515.3161327839,0.6296000480651855,1.819937229156494,10000,155682.33803749084,0.9596021771430968,0.1501669436693191,0.754539966583252,1.0561749935150146,50000 -5145.732047557831,22.230733633041385,151025.23944425583,444839,0,151025.23944425583,0.6306000351905823,1.818387985229492,10000,156208.86432909966,0.9607381820678712,0.1462684571743011,0.7548999786376953,1.0548075437545776,50000 -5162.222529888153,22.332300186157227,151535.28600239754,446342,0,151535.28600239754,0.6305000185966492,1.81734037399292,10000,156735.5563902855,0.9601402878761292,0.1477520167827606,0.7553399801254272,1.055067777633667,50000 -5178.921859025955,22.474713563919067,152045.3865249157,447846,0,152045.3865249157,0.6297000050544739,1.8188778162002563,10000,157262.55180740356,0.9604790806770324,0.1445417702198028,0.7549200057983398,1.0557647943496704,50000 -5195.382237672806,22.58767342567444,152555.23027157784,449348,0,152555.23027157784,0.6299000382423401,1.8194371461868288,10000,157789.02242159843,0.9610171914100648,0.1449035555124282,0.7546799778938293,1.055970549583435,50000 -5211.9875292778015,22.69133400917053,153065.41314435005,450852,0,153065.41314435005,0.6301000118255615,1.8175337314605715,10000,158315.9674050808,0.9602000713348388,0.1482229530811309,0.7547199726104736,1.054293155670166,50000 -5228.598965406418,22.787771224975582,153575.34843325615,452354,0,153575.34843325615,0.6310000419616699,1.8206629753112795,10000,158842.66458582878,0.9606385231018066,0.1475521922111511,0.7548799514770508,1.055489182472229,50000 -5245.10581445694,22.89153957366944,154085.36639356613,453857,0,154085.36639356613,0.6300000548362732,1.8203938007354736,10000,159369.3465027809,0.9617147445678712,0.1444277018308639,0.7550199627876282,1.0558409690856934,50000 -5261.763719320297,22.99702095985413,154595.30282139778,455359,0,154595.30282139778,0.6294000148773193,1.819143295288086,10000,159896.09922385216,0.9610371589660645,0.142574280500412,0.7546799778938293,1.0553386211395264,50000 -5278.295104265213,23.10880136489868,155105.47327399254,456863,0,155105.47327399254,0.629800021648407,1.819894790649414,10000,160422.96509361267,0.961156725883484,0.145716980099678,0.754859983921051,1.0554897785186768,50000 -5295.399443626404,23.212135314941406,155615.38880991936,458366,0,155615.38880991936,0.6299000382423401,1.8191497325897217,10000,160950.1413064003,0.9598014950752258,0.1472740620374679,0.7547799944877625,1.0557348728179932,50000 -5312.113205909729,23.317948579788208,156125.31659150124,459869,0,156125.31659150124,0.6297000050544739,1.8184266090393064,10000,161476.9416666031,0.960598647594452,0.1467158049345016,0.7549399733543396,1.0555691719055176,50000 -5328.743317604065,23.42711114883423,156635.17824530602,461371,0,156635.17824530602,0.6301000118255615,1.8204811811447144,10000,162003.59568810463,0.9602798223495485,0.1468383967876434,0.7546399831771851,1.056372046470642,50000 -5345.371515035629,23.53261113166809,157145.0807697773,462874,0,157145.0807697773,0.629800021648407,1.8176995515823364,10000,162530.2850341797,0.9627510905265808,0.1442515403032302,0.7547399997711182,1.0549752712249756,50000 -5362.519360303879,23.63612127304077,157655.2273361683,464377,0,157655.2273361683,0.629800021648407,1.8174927234649656,10000,163057.73574876785,0.9610171914100648,0.1462255269289016,0.7545199990272522,1.0557774305343628,50000 -5379.328293085098,23.72211003303528,158165.1404414177,465880,0,158165.1404414177,0.629300057888031,1.82122802734375,10000,163584.59530687332,0.9604392051696776,0.1437030434608459,0.7551199793815613,1.0570862293243408,50000 -5396.06454372406,23.83945727348328,158675.24805808067,467383,0,158675.24805808067,0.6310000419616699,1.818740963935852,10000,164111.61119008064,0.9608777165412904,0.1466167420148849,0.7547199726104736,1.0551986694335938,50000 -5412.717334508896,23.949141263961792,159185.16336750984,468886,0,159185.16336750984,0.6300000548362732,1.8203456401824951,10000,164638.34228634834,0.9606584906578064,0.1459176987409591,0.7545199990272522,1.0558143854141235,50000 -5429.2265625,24.054170608520508,159695.21085381508,470389,0,159695.21085381508,0.6301000118255615,1.819016456604004,10000,165165.05762171745,0.9602798223495485,0.1463384926319122,0.7549799680709839,1.055783748626709,50000 -5445.694318056107,24.15822815895081,160205.1214659214,471892,0,160205.1214659214,0.6300000548362732,1.8196965456008911,10000,165691.59308218956,0.961694836616516,0.1453837901353836,0.7548799514770508,1.0557386875152588,50000 -5462.238760948181,24.271573543548584,160715.06251072884,473394,0,160715.06251072884,0.6307000517845154,1.8173385858535769,10000,166218.24579143524,0.9602598547935486,0.1482415646314621,0.7550599575042725,1.055686593055725,50000 -5478.855576515198,24.37463998794556,161224.98149490356,474897,0,161224.98149490356,0.631100058555603,1.818025708198548,10000,166744.93678593636,0.9599609375,0.1462009400129318,0.7547399997711182,1.0547475814819336,50000 -5495.393129825592,24.47951650619507,161734.98222899437,476399,0,161734.98222899437,0.6299000382423401,1.8198281526565552,10000,167271.63280415535,0.9605189561843872,0.1475342661142349,0.7547599673271179,1.056140422821045,50000 -5512.12969326973,24.58859372138977,162244.99106502533,477901,0,162244.99106502533,0.6309000253677368,1.815615177154541,10000,167798.54024600983,0.9598014950752258,0.1489960998296737,0.7548799514770508,1.0541421175003052,50000 -5528.671658039093,24.69663667678833,162755.0349202156,479404,0,162755.0349202156,0.629800021648407,1.8202873468399048,10000,168325.28706121445,0.9606983065605164,0.1470500826835632,0.7549999952316284,1.056383728981018,50000 -5545.251647233963,24.805399894714355,163264.94740009308,480907,0,163264.94740009308,0.6295000314712524,1.8191709518432613,10000,168851.94211268425,0.959622085094452,0.151059940457344,0.7543599605560303,1.0561161041259766,50000 -5561.8279457092285,24.915863752365112,163775.10062289238,482410,0,163775.10062289238,0.6310000419616699,1.8195371627807613,10000,169378.83524656296,0.9606186151504515,0.146817535161972,0.7549200057983398,1.0554178953170776,50000 -5578.402358531952,25.023176670074463,164285.20070242882,483913,0,164285.20070242882,0.6296000480651855,1.8194700479507449,10000,169905.66928815842,0.9611965417861938,0.1441107988357544,0.7546600103378296,1.0558921098709106,50000 -5594.983488559723,25.13081169128418,164795.32609844208,485417,0,164795.32609844208,0.6309000253677368,1.8180348873138428,10000,170432.5374853611,0.9585259556770324,0.1497927606105804,0.7547199726104736,1.055144429206848,50000 -5611.545911312103,25.24029874801636,165305.46842598915,486921,0,165305.46842598915,0.631100058555603,1.819509029388428,10000,170959.40493369102,0.9615553021430968,0.1433934420347213,0.7547399997711182,1.0567063093185425,50000 -5628.148483276367,25.35098004341125,165815.3304874897,488423,0,165815.3304874897,0.6297000050544739,1.8203799724578853,10000,171486.03428840637,0.9604790806770324,0.1482679396867752,0.7546199560165405,1.0550260543823242,50000 -5644.838176488876,25.461592197418213,166325.27960824966,489926,0,166325.27960824966,0.6305000185966492,1.8186701536178589,10000,172012.83657836914,0.9613958597183228,0.1450057327747345,0.7552199959754944,1.0549418926239014,50000 -5661.418847084045,25.57334852218628,166835.23479795456,491428,0,166835.23479795456,0.6299000382423401,1.818878173828125,10000,172539.53677415848,0.960339605808258,0.1478165686130523,0.7548399567604065,1.055351972579956,50000 -5677.953634738922,25.689466953277588,167345.2049202919,492931,0,167345.2049202919,0.6303000450134277,1.8192788362503047,10000,173066.20989394188,0.9618343114852904,0.1427801996469497,0.7546600103378296,1.0560132265090942,50000 -5694.541827440262,25.811455249786377,167855.2519850731,494434,0,167855.2519850731,0.6296000480651855,1.820483684539795,10000,173593.01951646805,0.960379421710968,0.1443609893321991,0.7543999552726746,1.0561788082122805,50000 -5711.08086681366,25.91838574409485,168365.25228381157,495937,0,168365.25228381157,0.6294000148773193,1.8213483095169067,10000,174119.71955490112,0.961535394191742,0.1451693475246429,0.7547399997711182,1.056144118309021,50000 -5728.474135398865,26.03202795982361,168875.17960882187,497440,0,168875.17960882187,0.6305000185966492,1.8195267915725708,10000,174647.20702934265,0.96000075340271,0.1484004706144333,0.7550199627876282,1.0548858642578125,50000 -5745.026813030243,26.14190459251404,169385.31666755676,498943,0,169385.31666755676,0.6310000419616699,1.817232728004456,10000,175174.05980086327,0.9606783986091614,0.1457280069589615,0.7548999786376953,1.054988145828247,50000 -5761.515566825867,26.25598978996277,169895.3501522541,500446,0,169895.3501522541,0.6303000450134277,1.8181790113449097,10000,175700.7484512329,0.9609972834587096,0.1472698599100113,0.7549399733543396,1.0556048154830933,50000 -5778.214676856995,26.37480115890503,170405.28198862076,501949,0,170405.28198862076,0.6306000351905823,1.8182381391525269,10000,176227.55142569542,0.962292730808258,0.1432228684425354,0.7550199627876282,1.055835485458374,50000 -5794.882792234421,26.4967200756073,170915.20975255966,503452,0,170915.20975255966,0.6301000118255615,1.8219772577285769,10000,176754.32218694687,0.961575210094452,0.1445035338401794,0.7547799944877625,1.0563068389892578,50000 -5811.456128358841,26.61021900177002,171425.38033485413,504956,0,171425.38033485413,0.6304000020027161,1.818542718887329,10000,177281.23242735863,0.9600605964660645,0.1470568627119064,0.7549799680709839,1.0555027723312378,50000 -5828.022411584854,26.72679662704468,171935.21598100662,506458,0,171935.21598100662,0.6303000450134277,1.81696355342865,10000,177807.80382156372,0.9614556431770324,0.1461593806743621,0.7551400065422058,1.0542012453079224,50000 -5844.531116485596,26.84491229057312,172445.2868115902,507962,0,172445.2868115902,0.6307000517845154,1.8180532455444336,10000,178334.55427193642,0.9599609375,0.1463172882795334,0.7548799514770508,1.0554317235946655,50000 -5861.175798654556,26.954604864120483,172955.43111920357,509465,0,172955.43111920357,0.631600022315979,1.81832218170166,10000,178861.50686740875,0.9617546200752258,0.1442474871873855,0.7548799514770508,1.0556154251098633,50000 -5877.799740791321,27.069482803344727,173465.5064251423,510969,0,173465.5064251423,0.6302000284194946,1.8184680938720703,10000,179388.3738567829,0.961156725883484,0.1486804485321045,0.7549200057983398,1.055416226387024,50000 -5894.746170282364,27.18165707588196,173975.3585202694,512471,0,173975.3585202694,0.6303000450134277,1.817718863487244,10000,179915.33734679222,0.958804965019226,0.1498315632343292,0.7547799944877625,1.055154800415039,50000 -5911.234726428986,27.304914474487305,174485.4700343609,513975,0,174485.4700343609,0.6304000020027161,1.817402482032776,10000,180442.1136064529,0.9610969424247742,0.146408274769783,0.7549600005149841,1.0548641681671145,50000 -5927.747765779495,27.412547826766968,174995.56325149536,515478,0,174995.56325149536,0.6305000185966492,1.817795157432556,10000,180968.88140940663,0.9604192972183228,0.1447800695896148,0.7547199726104736,1.055174469947815,50000 -5944.373743772507,27.52469778060913,175505.54929184914,516981,0,175505.54929184914,0.6308000087738037,1.818442702293396,10000,181495.6579201221,0.9598811864852904,0.1495180130004882,0.7548799514770508,1.055079460144043,50000 -5960.828776597977,27.63849425315857,176015.523532629,518483,0,176015.523532629,0.6297000050544739,1.8194637298583984,10000,182022.25287151337,0.9592633843421936,0.1484085023403167,0.7550199627876282,1.0548003911972046,50000 -5977.352998971939,27.755635499954224,176525.6594619751,519986,0,176525.6594619751,0.631100058555603,1.8197120428085327,10000,182549.0837368965,0.9608777165412904,0.1482365876436233,0.7551400065422058,1.0555709600448608,50000 -5993.85312795639,27.870992183685303,177035.51152539253,521488,0,177035.51152539253,0.6308000087738037,1.818725824356079,10000,183075.60394525528,0.9598413109779358,0.1506748795509338,0.7550599575042725,1.055654764175415,50000 -6010.478732824326,27.994094133377075,177545.4000184536,522990,0,177545.4000184536,0.6309000253677368,1.819977045059204,10000,183602.2939076424,0.9611168503761292,0.1443382948637008,0.7545599937438965,1.0562598705291748,50000 -6026.929802417755,28.10876989364624,178055.3327858448,524492,0,178055.3327858448,0.6304000020027161,1.8197730779647827,10000,184128.84509038925,0.9600406289100648,0.1463096588850021,0.7547599673271179,1.0564887523651123,50000 -6043.616086244583,28.223960638046265,178565.2026667595,525994,0,178565.2026667595,0.6310000419616699,1.8185182809829712,10000,184655.56948399544,0.9606385231018066,0.1450351178646087,0.7550599575042725,1.0553193092346191,50000 -6060.282440185547,28.40710473060608,179075.05401158333,527495,0,179075.05401158333,0.6303000450134277,1.819907665252685,10000,185182.32403969765,0.9610371589660645,0.1455485224723816,0.7551400065422058,1.0551986694335938,50000 -6077.02671456337,28.523298025131226,179584.9684138298,528998,0,179584.9684138298,0.6308000087738037,1.8188236951828003,10000,185709.15247774124,0.961136758327484,0.1468872874975204,0.7548399567604065,1.0558708906173706,50000 -6093.907920360565,28.641568899154663,180095.15888261795,530501,0,180095.15888261795,0.6310000419616699,1.819552183151245,10000,186236.39528346065,0.9602798223495485,0.1466607749462127,0.7544599771499634,1.0554156303405762,50000 -6110.511300086975,28.757806062698364,180605.15130925176,532004,0,180605.15130925176,0.6317000389099121,1.8177498579025269,10000,186763.15968847275,0.9609972834587096,0.1430176347494125,0.7551400065422058,1.0556042194366455,50000 -6127.112663269043,28.87390160560608,181115.2619752884,533507,0,181115.2619752884,0.6303000450134277,1.8190414905548096,10000,187290.04085493088,0.9610969424247742,0.1459327936172485,0.7546799778938293,1.0548255443572998,50000 -6144.01261806488,28.99262237548828,181625.2687883377,535010,0,181625.2687883377,0.6301000118255615,1.820230960845948,10000,187817.1193869114,0.9614556431770324,0.1439944505691528,0.7545799612998962,1.0565119981765747,50000 -6160.654649019241,29.10392904281616,182135.4176516533,536513,0,182135.4176516533,0.6297000050544739,1.8187036514282229,10000,188344.07443284988,0.9588249325752258,0.1505401283502578,0.7545199990272522,1.0554897785186768,50000 -6177.164265871048,29.22213888168335,182645.48416352272,538016,0,182645.48416352272,0.6300000548362732,1.8173961639404297,10000,188870.8207633496,0.9614357352256776,0.1456693708896637,0.754539966583252,1.0547114610671997,50000 -6194.077189683914,29.346582889556885,183155.6240270137,539519,0,183155.6240270137,0.6303000450134277,1.819313764572144,10000,189398.0522556305,0.9617546200752258,0.1445423066616058,0.7547599673271179,1.055425047874451,50000 -6210.7504341602325,29.442301750183105,183665.75982761383,541023,0,183665.75982761383,0.6304000020027161,1.8197301626205444,10000,189925.0090081692,0.961535394191742,0.1437825411558151,0.7549399733543396,1.0555442571640017,50000 -6227.353454113007,29.56045937538147,184175.8125114441,542525,0,184175.8125114441,0.6294000148773193,1.819315910339356,10000,190451.8363242149,0.960957407951355,0.146311342716217,0.7544199824333191,1.055719256401062,50000 -6243.999848365784,29.67587733268737,184685.7376544476,544028,0,184685.7376544476,0.6305000185966492,1.817673921585083,10000,190978.5760681629,0.9606584906578064,0.1456425189971923,0.7549799680709839,1.0556423664093018,50000 -6260.603581190109,29.792681217193604,185195.77567243576,545531,0,185195.77567243576,0.6301000118255615,1.818894028663636,10000,191505.38779973984,0.9602399468421936,0.1466413885354995,0.7547599673271179,1.0558973550796509,50000 -6277.120337963104,29.91552424430847,185705.8674867153,547034,0,185705.8674867153,0.6303000450134277,1.819093585014344,10000,192032.17276000977,0.9611965417861938,0.1448834687471389,0.7548399567604065,1.0554053783416748,50000 -6293.597116947174,30.034253358840942,186216.0143876076,548537,0,186216.0143876076,0.6305000185966492,1.818883299827576,10000,192558.9681181908,0.9606385231018066,0.1475456207990646,0.7551599740982056,1.054714322090149,50000 -6310.110383272171,30.16148805618286,186725.95565152168,550040,0,186725.95565152168,0.6306000351905823,1.8195624351501465,10000,193085.6027336121,0.9608976244926452,0.1470227688550949,0.7550999522209167,1.056078553199768,50000 -6326.669639825821,30.279273509979248,187235.9205963612,551542,0,187235.9205963612,0.6299000382423401,1.8204457759857176,10000,193612.2978289128,0.9594626426696776,0.1485990285873413,0.7549399733543396,1.0570076704025269,50000 -6343.1246864795685,30.39686608314514,187745.98685359955,553045,0,187745.98685359955,0.6305000185966492,1.8178973197937007,10000,194138.9894325733,0.9600207209587096,0.1465989500284195,0.7548999786376953,1.0555554628372192,50000 -6359.778885602951,30.522332429885864,188256.05421233177,554548,0,188256.05421233177,0.6312000155448914,1.8189270496368408,10000,194665.88861370087,0.960339605808258,0.1473934650421142,0.7547000050544739,1.0555559396743774,50000 -6376.293786287308,30.646925687789917,188766.14553546906,556051,0,188766.14553546906,0.6299000382423401,1.8204028606414795,10000,195192.6730442047,0.9599210619926453,0.1484605371952057,0.7547599673271179,1.0562946796417236,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index fd70fb78d..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5942 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6710036,6.930171,,,,,,,,,,,,,, -1,,,0.0011160713620483,6.914276599884033,0.0011199999134987,6.913983345031738,50000.0,0.0008000000379979,6.914037704467773,10000.0,52.27670836448669,89.5958993434906,52.27670836448669,37.31908655166626,0.0,0.0 -100,0.67644167,6.8286443,,,,,,,,,,,,,, -200,0.8360885,6.554777,,,,,,,,,,,,,, -300,1.0310134,6.2718296,,,,,,,,,,,,,, -400,1.6244663,6.020854,,,,,,,,,,,,,, -500,2.2105303,5.792985,,,,,,,,,,,,,, -600,2.3131444,5.549655,,,,,,,,,,,,,, -700,4.975005,5.4613485,,,,,,,,,,,,,, -800,2.8844128,5.3596196,,,,,,,,,,,,,, -900,3.9193993,5.203694,,,,,,,,,,,,,, -1000,3.774704,5.108587,,,,,,,,,,,,,, -1100,3.7163298,4.882288,,,,,,,,,,,,,, -1200,3.2530575,4.8120537,,,,,,,,,,,,,, -1300,6.976915,4.7632756,,,,,,,,,,,,,, -1400,3.4447446,4.607274,,,,,,,,,,,,,, -1499,,,0.1658561825752258,4.275528907775879,0.1472399979829788,4.434837818145752,50000.0,0.1118000075221061,4.884735107421875,10000.0,562.4275517463684,617.8459360599518,562.4275517463684,55.3390634059906,0.0274107456207275,0.0 -1500,7.5603237,4.4201536,,,,,,,,,,,,,, -1600,8.246814,4.4176908,,,,,,,,,,,,,, -1700,7.938628,4.307312,,,,,,,,,,,,,, -1800,6.419836,4.2016897,,,,,,,,,,,,,, -1900,5.55512,4.0936613,,,,,,,,,,,,,, -2000,6.125289,4.0468335,,,,,,,,,,,,,, -2100,5.6903205,4.0179777,,,,,,,,,,,,,, -2200,7.0413656,3.834391,,,,,,,,,,,,,, -2300,5.058214,3.8307033,,,,,,,,,,,,,, -2400,5.781705,3.7999737,,,,,,,,,,,,,, -2500,5.2071433,3.6511607,,,,,,,,,,,,,, -2600,5.560331,3.5462291,,,,,,,,,,,,,, -2700,3.6768687,3.623102,,,,,,,,,,,,,, -2800,6.3272977,3.4743905,,,,,,,,,,,,,, -2900,4.1529093,3.4806428,,,,,,,,,,,,,, -2996,,,0.3450255095958709,3.01550817489624,0.3194399774074554,3.1867895126342773,50000.0,0.236500009894371,3.801160335540772,10000.0,1072.6695590019226,1146.0415749549866,1072.6695590019226,73.20708441734314,0.0605521202087402,0.0 -3000,3.1625264,3.4470212,,,,,,,,,,,,,, -3100,4.1265097,3.3337774,,,,,,,,,,,,,, -3200,4.496218,3.3042336,,,,,,,,,,,,,, -3300,4.1967053,3.2536774,,,,,,,,,,,,,, -3400,4.307152,3.104828,,,,,,,,,,,,,, -3500,3.133103,3.1372786,,,,,,,,,,,,,, -3600,3.4328976,3.145905,,,,,,,,,,,,,, -3700,2.93546,3.2430673,,,,,,,,,,,,,, -3800,3.0134125,3.0870776,,,,,,,,,,,,,, -3900,2.512362,3.1005902,,,,,,,,,,,,,, -4000,3.6439521,3.010333,,,,,,,,,,,,,, -4100,2.4590533,2.8753486,,,,,,,,,,,,,, -4200,2.080834,2.9349103,,,,,,,,,,,,,, -4300,2.5590656,2.7316484,,,,,,,,,,,,,, -4400,1.8354563,2.7363636,,,,,,,,,,,,,, -4494,,,0.4690887928009033,2.31603479385376,0.4360799789428711,2.511484146118164,50000.0,0.3276000022888183,3.2471020221710205,10000.0,1582.7634320259094,1674.0095942020416,1582.7634320259094,91.00414776802064,0.0871176719665527,0.0 -4500,2.763701,2.8375478,,,,,,,,,,,,,, -4600,2.8811536,2.7359352,,,,,,,,,,,,,, -4700,2.0960267,2.7635498,,,,,,,,,,,,,, -4800,2.3201587,2.6836412,,,,,,,,,,,,,, -4900,2.5054355,2.657334,,,,,,,,,,,,,, -5000,2.0093448,2.6669278,,,,,,,,,,,,,, -5100,2.1995492,2.7433975,,,,,,,,,,,,,, -5200,2.3713317,2.6591954,,,,,,,,,,,,,, -5300,2.0012434,2.6149933,,,,,,,,,,,,,, -5400,1.5798068,2.6207094,,,,,,,,,,,,,, -5500,1.8507228,2.5981612,,,,,,,,,,,,,, -5600,1.7759755,2.4073095,,,,,,,,,,,,,, -5700,1.6387597,2.5247245,,,,,,,,,,,,,, -5800,1.9887096,2.5496855,,,,,,,,,,,,,, -5900,2.5664566,2.4832904,,,,,,,,,,,,,, -5994,,,0.5284797549247742,2.0127384662628174,0.4960199892520904,2.2188453674316406,50000.0,0.3756000101566314,2.999966144561768,10000.0,2092.8671317100525,2201.830789089203,2092.8671317100525,108.64397525787354,0.1147418022155761,0.0 -6000,2.0086184,2.5733025,,,,,,,,,,,,,, -6100,1.9996674,2.3206062,,,,,,,,,,,,,, -6200,2.1164708,2.4045897,,,,,,,,,,,,,, -6300,1.6206205,2.5174103,,,,,,,,,,,,,, -6400,2.040332,2.360837,,,,,,,,,,,,,, -6500,2.7380755,2.4248092,,,,,,,,,,,,,, -6600,1.8755744,2.3368912,,,,,,,,,,,,,, -6700,2.5400288,2.3071241,,,,,,,,,,,,,, -6800,2.155767,2.3088422,,,,,,,,,,,,,, -6900,2.255525,2.3443773,,,,,,,,,,,,,, -7000,2.0228982,2.3466692,,,,,,,,,,,,,, -7100,1.7203285,2.2715838,,,,,,,,,,,,,, -7200,2.108683,2.3484952,,,,,,,,,,,,,, -7300,1.977613,2.2812116,,,,,,,,,,,,,, -7400,1.7826667,2.24798,,,,,,,,,,,,,, -7494,,,0.5736607313156128,1.7896974086761477,0.5335999727249146,1.997921347618103,50000.0,0.4164000153541565,2.72767186164856,10000.0,2602.7957921028137,2729.759635448456,2602.7957921028137,126.56366062164308,0.141902208328247,0.0 -7500,1.5158135,2.2072718,,,,,,,,,,,,,, -7600,2.0285816,2.2496321,,,,,,,,,,,,,, -7700,1.8624163,2.214254,,,,,,,,,,,,,, -7800,1.90704,2.2761092,,,,,,,,,,,,,, -7900,1.7192795,2.1478508,,,,,,,,,,,,,, -8000,2.4919038,2.0963635,,,,,,,,,,,,,, -8100,1.6032308,2.3691256,,,,,,,,,,,,,, -8200,2.3171422,2.302179,,,,,,,,,,,,,, -8300,1.445272,2.2211106,,,,,,,,,,,,,, -8400,2.0868368,2.1467025,,,,,,,,,,,,,, -8500,2.2511132,2.1391675,,,,,,,,,,,,,, -8600,1.714288,2.074387,,,,,,,,,,,,,, -8700,2.2456598,2.187064,,,,,,,,,,,,,, -8800,2.2021484,2.1406605,,,,,,,,,,,,,, -8900,1.2345771,2.2432218,,,,,,,,,,,,,, -8995,,,0.5866350531578064,1.7268341779708862,0.5454999804496765,1.93905246257782,50000.0,0.4229000210762024,2.710902690887451,10000.0,3112.8875806331635,3258.0060436725616,3112.8875806331635,144.63809752464294,0.1693737506866455,0.0 -9000,1.4442722,2.0494282,,,,,,,,,,,,,, -9100,1.8072648,2.0492134,,,,,,,,,,,,,, -9200,1.6473936,2.1464884,,,,,,,,,,,,,, -9300,1.4925363,2.2464018,,,,,,,,,,,,,, -9400,1.2671288,2.14905,,,,,,,,,,,,,, -9500,1.4114573,2.0745978,,,,,,,,,,,,,, -9600,2.0534475,2.1842124,,,,,,,,,,,,,, -9700,1.9526078,1.9424423,,,,,,,,,,,,,, -9800,2.0158992,2.2201626,,,,,,,,,,,,,, -9900,1.6045822,2.1530972,,,,,,,,,,,,,, -10000,1.6148958,2.2771544,,,,,,,,,,,,,, -10100,1.628717,2.0918546,,,,,,,,,,,,,, -10200,1.983192,2.0542066,,,,,,,,,,,,,, -10300,1.4386194,2.1148522,,,,,,,,,,,,,, -10400,2.4300478,2.0018656,,,,,,,,,,,,,, -10495,,,0.6527024507522583,1.4016082286834717,0.5786399841308594,1.790817141532898,50000.0,0.45210000872612,2.53738784790039,10000.0,3622.910650968552,3786.3579564094543,3622.910650968552,162.88873028755188,0.1964619159698486,0.0 -10500,1.9180351,2.1176112,,,,,,,,,,,,,, -10600,1.5707693,2.1085439,,,,,,,,,,,,,, -10700,1.4998252,2.0500867,,,,,,,,,,,,,, -10800,1.5561243,2.0681875,,,,,,,,,,,,,, -10900,1.4456828,2.0766017,,,,,,,,,,,,,, -11000,1.875666,2.2492914,,,,,,,,,,,,,, -11100,1.6404157,2.0733402,,,,,,,,,,,,,, -11200,1.736764,2.147711,,,,,,,,,,,,,, -11300,2.068664,1.9557253,,,,,,,,,,,,,, -11400,1.4833705,1.9965848,,,,,,,,,,,,,, -11500,1.61475,2.096788,,,,,,,,,,,,,, -11600,1.6809225,1.9655069,,,,,,,,,,,,,, -11700,1.9720566,2.0062099,,,,,,,,,,,,,, -11800,1.6482211,1.9764873,,,,,,,,,,,,,, -11900,1.4268817,1.9322805,,,,,,,,,,,,,, -11997,,,0.6460459232330322,1.433350682258606,0.5874199867248535,1.7437560558319092,50000.0,0.4599000215530395,2.4923598766326904,10000.0,4133.089784383774,4314.894702672958,4133.089784383774,181.1612477302552,0.2311422824859619,0.0 -12000,1.3593879,1.8669451,,,,,,,,,,,,,, -12100,2.2633023,1.9184941,,,,,,,,,,,,,, -12200,1.6688658,1.973829,,,,,,,,,,,,,, -12300,1.7152315,2.0112774,,,,,,,,,,,,,, -12400,1.4362686,2.0019374,,,,,,,,,,,,,, -12500,1.3948138,1.9980807,,,,,,,,,,,,,, -12600,1.501384,1.950233,,,,,,,,,,,,,, -12700,1.4691765,2.0144558,,,,,,,,,,,,,, -12800,1.682305,2.0535905,,,,,,,,,,,,,, -12900,1.5905306,1.9483063,,,,,,,,,,,,,, -13000,1.3499609,1.9603424,,,,,,,,,,,,,, -13100,1.3961931,1.9254305,,,,,,,,,,,,,, -13200,1.4957926,1.8434092,,,,,,,,,,,,,, -13300,1.6418617,1.990985,,,,,,,,,,,,,, -13400,1.7756082,1.964788,,,,,,,,,,,,,, -13498,,,0.6513671875,1.4058401584625244,0.5913599729537964,1.706515908241272,50000.0,0.4731000363826751,2.4181039333343506,10000.0,4643.165565252304,4843.648952245712,4643.165565252304,199.754013299942,0.2653450965881347,0.0 -13500,1.7034003,2.0226648,,,,,,,,,,,,,, -13600,1.6529033,1.9459887,,,,,,,,,,,,,, -13700,1.5826254,1.946177,,,,,,,,,,,,,, -13800,1.5581375,1.9829017,,,,,,,,,,,,,, -13900,1.9470983,1.9160005,,,,,,,,,,,,,, -14000,1.7112718,1.947986,,,,,,,,,,,,,, -14100,1.9796741,1.843074,,,,,,,,,,,,,, -14200,1.3934848,1.8854185,,,,,,,,,,,,,, -14300,1.5575954,1.9323668,,,,,,,,,,,,,, -14400,1.6513685,1.9560663,,,,,,,,,,,,,, -14500,1.6252024,1.8313217,,,,,,,,,,,,,, -14600,1.5362622,1.9049637,,,,,,,,,,,,,, -14700,1.6094902,1.9351041,,,,,,,,,,,,,, -14800,1.4981056,2.078381,,,,,,,,,,,,,, -14900,1.4550631,1.9141667,,,,,,,,,,,,,, -15000,,,0.6477000713348389,1.419167160987854,0.592960000038147,1.7010241746902466,50000.0,0.4726000130176544,2.420872211456299,10000.0,5153.095174074173,5371.81556725502,5153.095174074173,217.90440320968628,0.3007020950317383,0.0 -15000,1.8988016,1.9437008,,,,,,,,,,,,,, -15100,1.4118152,1.9442497,,,,,,,,,,,,,, -15200,1.6049157,1.9282653,,,,,,,,,,,,,, -15300,1.461891,1.8343904,,,,,,,,,,,,,, -15400,1.5595495,1.786086,,,,,,,,,,,,,, -15500,1.3618371,1.8265706,,,,,,,,,,,,,, -15600,1.7615582,1.964176,,,,,,,,,,,,,, -15700,1.776977,1.8961763,,,,,,,,,,,,,, -15800,1.6886648,1.7869881,,,,,,,,,,,,,, -15900,1.3768839,1.9501854,,,,,,,,,,,,,, -16000,1.424844,1.9016913,,,,,,,,,,,,,, -16100,1.51634,1.9036512,,,,,,,,,,,,,, -16200,1.4988024,1.8679472,,,,,,,,,,,,,, -16300,1.4901474,1.927766,,,,,,,,,,,,,, -16400,1.4042449,1.9326047,,,,,,,,,,,,,, -16500,1.4323671,1.8363134,,,,,,,,,,,,,, -16501,,,0.6524234414100647,1.3966569900512695,0.5991599559783936,1.6813833713531494,50000.0,0.4708000123500824,2.4155619144439697,10000.0,5663.19630074501,5900.233821630478,5663.19630074501,236.13387608528137,0.3381190299987793,0.0 -16600,1.6638821,1.8838662,,,,,,,,,,,,,, -16700,1.7685356,1.9490943,,,,,,,,,,,,,, -16800,1.4264275,1.9571488,,,,,,,,,,,,,, -16900,1.6885948,1.9087601,,,,,,,,,,,,,, -17000,1.5366002,2.0151427,,,,,,,,,,,,,, -17100,1.6583899,1.8750023,,,,,,,,,,,,,, -17200,1.4909728,1.8299847,,,,,,,,,,,,,, -17300,1.5654029,1.6998491,,,,,,,,,,,,,, -17400,1.65467,2.047633,,,,,,,,,,,,,, -17500,1.5324568,1.8254585,,,,,,,,,,,,,, -17600,1.4954832,1.8965704,,,,,,,,,,,,,, -17700,1.8866857,1.8857709,,,,,,,,,,,,,, -17800,1.5487188,1.9036059,,,,,,,,,,,,,, -17900,1.5805182,1.8658457,,,,,,,,,,,,,, -18000,1.4425824,1.8349862,,,,,,,,,,,,,, -18003,,,0.6489955186843872,1.4116501808166504,0.5984399914741516,1.6733524799346924,50000.0,0.4759000241756439,2.3799636363983154,10000.0,6173.271780967712,6429.526218175888,6173.271780967712,255.26897883415225,0.3690087795257568,0.0 -18100,1.4370993,1.7810433,,,,,,,,,,,,,, -18200,1.64665,1.8042172,,,,,,,,,,,,,, -18300,1.5389451,1.8649054,,,,,,,,,,,,,, -18400,1.5719562,1.8792422,,,,,,,,,,,,,, -18500,1.7506267,1.8283086,,,,,,,,,,,,,, -18600,1.6335483,1.8487958,,,,,,,,,,,,,, -18700,1.5680119,1.7843716,,,,,,,,,,,,,, -18800,1.5956186,1.7547983,,,,,,,,,,,,,, -18900,1.6145389,1.7288804,,,,,,,,,,,,,, -19000,1.5174071,1.7726806,,,,,,,,,,,,,, -19100,1.533383,1.7729921,,,,,,,,,,,,,, -19200,2.008262,1.7084968,,,,,,,,,,,,,, -19300,2.0898297,1.9279187,,,,,,,,,,,,,, -19400,1.4845773,1.7969646,,,,,,,,,,,,,, -19500,1.5417094,1.924643,,,,,,,,,,,,,, -19505,,,0.7088249325752258,1.1469347476959229,0.6089000105857849,1.630070447921753,50000.0,0.4868000149726867,2.348141431808472,10000.0,6683.211872577667,6958.470961809158,6683.211872577667,274.1825485229492,0.4079773426055908,0.0 -19600,1.6018361,1.763621,,,,,,,,,,,,,, -19700,1.6849513,1.8635939,,,,,,,,,,,,,, -19800,1.6513934,1.7722774,,,,,,,,,,,,,, -19900,1.5868788,1.698308,,,,,,,,,,,,,, -20000,1.747375,1.8107419,,,,,,,,,,,,,, -20100,1.6442184,1.773984,,,,,,,,,,,,,, -20200,1.8330524,1.8820274,,,,,,,,,,,,,, -20300,1.6248204,1.7682494,,,,,,,,,,,,,, -20400,1.6894261,1.7374297,,,,,,,,,,,,,, -20500,1.6264168,1.7317488,,,,,,,,,,,,,, -20600,1.3832989,1.688285,,,,,,,,,,,,,, -20700,1.9769881,1.7544768,,,,,,,,,,,,,, -20800,1.7152858,1.7140867,,,,,,,,,,,,,, -20900,1.9706963,1.9265145,,,,,,,,,,,,,, -21000,1.6901402,1.8519266,,,,,,,,,,,,,, -21007,,,0.6806440949440002,1.2608718872070312,0.6103000044822693,1.6341451406478882,50000.0,0.4828000366687774,2.377382516860962,10000.0,7193.288778066635,7487.048682928085,7193.288778066635,292.59999918937683,0.4410958290100097,0.0 -21100,1.5844711,1.8852074,,,,,,,,,,,,,, -21200,1.6019809,1.8370211,,,,,,,,,,,,,, -21300,1.8199731,1.7870392,,,,,,,,,,,,,, -21400,1.6803983,1.860614,,,,,,,,,,,,,, -21500,1.7245562,1.8297485,,,,,,,,,,,,,, -21600,1.6838048,1.752159,,,,,,,,,,,,,, -21700,1.9063905,1.9966127,,,,,,,,,,,,,, -21800,2.0459652,1.8382891,,,,,,,,,,,,,, -21900,1.8661377,1.7656466,,,,,,,,,,,,,, -22000,1.7688568,1.930666,,,,,,,,,,,,,, -22100,1.7612721,1.91704,,,,,,,,,,,,,, -22200,1.7182856,1.7865388,,,,,,,,,,,,,, -22300,1.8514308,1.8183535,,,,,,,,,,,,,, -22400,1.4205488,1.7409028,,,,,,,,,,,,,, -22500,1.8912913,1.8251214,,,,,,,,,,,,,, -22508,,,0.6704998016357422,1.3082401752471924,0.6072799563407898,1.6411656141281128,50000.0,0.4828000366687774,2.365996122360229,10000.0,7703.217472076416,8015.52258348465,7703.217472076416,311.04740691185,0.4867265224456787,0.0 -22600,1.5721459,1.8417652,,,,,,,,,,,,,, -22700,1.8778905,1.8581858,,,,,,,,,,,,,, -22800,1.8247905,1.9425205,,,,,,,,,,,,,, -22900,1.5632607,1.8247535,,,,,,,,,,,,,, -23000,1.4442124,1.7042475,,,,,,,,,,,,,, -23100,1.7430978,1.7742184,,,,,,,,,,,,,, -23200,1.8083701,1.7196877,,,,,,,,,,,,,, -23300,1.7318805,1.844585,,,,,,,,,,,,,, -23400,1.4794189,1.752095,,,,,,,,,,,,,, -23500,1.6791508,1.746647,,,,,,,,,,,,,, -23600,1.9425582,1.7705181,,,,,,,,,,,,,, -23700,1.6738235,1.8750038,,,,,,,,,,,,,, -23800,1.8530923,1.7644873,,,,,,,,,,,,,, -23900,1.7678392,1.835706,,,,,,,,,,,,,, -24000,1.6371051,1.7173371,,,,,,,,,,,,,, -24010,,,0.6690449714660645,1.3177435398101809,0.608460009098053,1.6242786645889282,50000.0,0.4818000197410583,2.3678853511810303,10000.0,8213.273429870605,8546.886992692947,8213.273429870605,332.26763224601746,0.5245239734649658,0.0 -24100,1.5900517,1.7572104,,,,,,,,,,,,,, -24200,1.838547,1.8208734,,,,,,,,,,,,,, -24300,1.7630057,1.708241,,,,,,,,,,,,,, -24400,1.6433771,1.8927873,,,,,,,,,,,,,, -24500,1.718227,1.7793039,,,,,,,,,,,,,, -24600,1.897516,1.6928452,,,,,,,,,,,,,, -24700,1.7952828,1.7699428,,,,,,,,,,,,,, -24800,2.237456,1.8358533,,,,,,,,,,,,,, -24900,1.7834824,1.800996,,,,,,,,,,,,,, -25000,1.741786,1.8572423,,,,,,,,,,,,,, -25100,1.8555974,1.8216798,,,,,,,,,,,,,, -25200,1.5951144,1.7617955,,,,,,,,,,,,,, -25300,1.8637965,1.7089515,,,,,,,,,,,,,, -25400,1.863099,1.7020756,,,,,,,,,,,,,, -25500,1.4356619,1.5601898,,,,,,,,,,,,,, -25512,,,0.6869817972183228,1.2383699417114258,0.627079963684082,1.5450199842453003,50000.0,0.499500036239624,2.262160301208496,10000.0,8723.448077917099,9080.337821245192,8723.448077917099,355.46244287490845,0.5544643402099609,0.0 -25600,1.6623275,1.68114,,,,,,,,,,,,,, -25700,1.9031292,1.8187848,,,,,,,,,,,,,, -25800,1.5461522,1.7972643,,,,,,,,,,,,,, -25900,1.7505286,1.721389,,,,,,,,,,,,,, -26000,1.6619495,1.7904598,,,,,,,,,,,,,, -26100,1.6310549,1.8514959,,,,,,,,,,,,,, -26200,1.6686409,1.7423759,,,,,,,,,,,,,, -26300,1.668834,1.6961588,,,,,,,,,,,,,, -26400,1.6630989,1.7440352,,,,,,,,,,,,,, -26500,1.7188742,1.7568872,,,,,,,,,,,,,, -26600,1.8234676,1.7412574,,,,,,,,,,,,,, -26700,1.5329682,1.7081399,,,,,,,,,,,,,, -26800,1.9746989,1.7497504,,,,,,,,,,,,,, -26900,1.6752372,1.7094933,,,,,,,,,,,,,, -27000,1.7879319,1.8131516,,,,,,,,,,,,,, -27015,,,0.6685666441917419,1.31665301322937,0.6162599921226501,1.593363881111145,50000.0,0.4883000254631042,2.3350415229797363,10000.0,9233.63310432434,9614.2236931324,9233.63310432434,379.0861880779266,0.5806448459625244,0.0 -27100,1.8245443,1.8012863,,,,,,,,,,,,,, -27200,1.481426,1.63111,,,,,,,,,,,,,, -27300,1.6824006,1.745304,,,,,,,,,,,,,, -27400,1.665388,1.6529888,,,,,,,,,,,,,, -27500,1.8559852,1.7203203,,,,,,,,,,,,,, -27600,1.6556586,1.7682965,,,,,,,,,,,,,, -27700,1.8798487,1.8018856,,,,,,,,,,,,,, -27800,1.7117336,1.7978153,,,,,,,,,,,,,, -27900,1.9803226,1.7244084,,,,,,,,,,,,,, -28000,1.6921923,1.8042518,,,,,,,,,,,,,, -28100,1.799466,1.8271452,,,,,,,,,,,,,, -28200,1.8278601,1.8606591,,,,,,,,,,,,,, -28300,1.8399904,1.7880862,,,,,,,,,,,,,, -28400,1.669327,1.6888566,,,,,,,,,,,,,, -28500,2.1676702,1.8042731,,,,,,,,,,,,,, -28516,,,0.7271006107330322,1.0419692993164062,0.6188399791717529,1.5885690450668335,50000.0,0.4898000359535217,2.3473498821258545,10000.0,9743.41288614273,10148.186551094055,9743.41288614273,402.9537811279297,0.8449652194976807,0.0 -28600,1.7777555,1.7616119,,,,,,,,,,,,,, -28700,1.686601,1.8221729,,,,,,,,,,,,,, -28800,1.6669357,1.7028263,,,,,,,,,,,,,, -28900,2.1454327,1.6932075,,,,,,,,,,,,,, -29000,1.6966995,1.7132996,,,,,,,,,,,,,, -29100,1.7695968,1.699234,,,,,,,,,,,,,, -29200,1.6894284,1.8255172,,,,,,,,,,,,,, -29300,1.8887485,1.6465697,,,,,,,,,,,,,, -29400,1.7968924,1.7377084,,,,,,,,,,,,,, -29500,1.9369558,1.7329698,,,,,,,,,,,,,, -29600,1.7017279,1.7273117,,,,,,,,,,,,,, -29700,1.7185715,1.6569033,,,,,,,,,,,,,, -29800,1.707469,1.7719327,,,,,,,,,,,,,, -29900,1.7648679,1.7305124,,,,,,,,,,,,,, -30000,1.7014362,1.7042832,,,,,,,,,,,,,, -30018,,,0.70804762840271,1.1357526779174805,0.6301800012588501,1.532379984855652,50000.0,0.5054000020027161,2.2726728916168213,10000.0,10253.381639242172,10683.135558843613,10253.381639242172,427.8536124229431,0.8739020824432373,0.0 -30100,1.7047082,1.6652905,,,,,,,,,,,,,, -30200,1.6359202,1.7066551,,,,,,,,,,,,,, -30300,1.5728794,1.7039411,,,,,,,,,,,,,, -30400,1.7381223,1.8341022,,,,,,,,,,,,,, -30500,1.921454,1.6950251,,,,,,,,,,,,,, -30600,1.7359376,1.5747962,,,,,,,,,,,,,, -30700,1.5123037,1.8234752,,,,,,,,,,,,,, -30800,1.8928535,1.7325889,,,,,,,,,,,,,, -30900,1.919498,1.6472535,,,,,,,,,,,,,, -31000,1.7286597,1.8332751,,,,,,,,,,,,,, -31100,1.5259329,1.5705084,,,,,,,,,,,,,, -31200,1.7874781,1.676569,,,,,,,,,,,,,, -31300,1.8833027,1.7094443,,,,,,,,,,,,,, -31400,1.9494804,1.779984,,,,,,,,,,,,,, -31500,1.7312285,1.6409009,,,,,,,,,,,,,, -31520,,,0.6942163705825806,1.2090507745742798,0.6239399909973145,1.5508848428726196,50000.0,0.4951000213623047,2.2808399200439453,10000.0,10763.521374940872,11218.338215351105,10763.521374940872,452.8360719680786,0.9039633274078368,0.0 -31600,1.7111003,1.7188998,,,,,,,,,,,,,, -31700,1.7021145,1.6778954,,,,,,,,,,,,,, -31800,2.03106,1.7320594,,,,,,,,,,,,,, -31900,1.7717358,1.6304388,,,,,,,,,,,,,, -32000,1.8967695,1.6824696,,,,,,,,,,,,,, -32100,1.7098625,1.6363802,,,,,,,,,,,,,, -32200,1.6801497,1.6961262,,,,,,,,,,,,,, -32300,2.0145016,1.6775639,,,,,,,,,,,,,, -32400,2.0151336,1.7354983,,,,,,,,,,,,,, -32500,1.7164469,1.7274659,,,,,,,,,,,,,, -32600,1.8055053,1.6579063,,,,,,,,,,,,,, -32700,1.6404548,1.5311866,,,,,,,,,,,,,, -32800,1.772483,1.8130977,,,,,,,,,,,,,, -32900,1.9195495,1.7688488,,,,,,,,,,,,,, -33000,1.7204443,1.7090542,,,,,,,,,,,,,, -33022,,,0.6808235049247742,1.2375199794769287,0.6197400093078613,1.5743621587753296,50000.0,0.4944000244140625,2.32073712348938,10000.0,11273.513464927672,11752.77766919136,11273.513464927672,477.199057340622,0.9372315406799316,0.0 -33100,1.6680963,1.6004832,,,,,,,,,,,,,, -33200,1.7889718,1.6188854,,,,,,,,,,,,,, -33300,1.8867979,1.842778,,,,,,,,,,,,,, -33400,1.7708756,1.7623916,,,,,,,,,,,,,, -33500,1.6374629,1.7151455,,,,,,,,,,,,,, -33600,1.5884598,1.6508219,,,,,,,,,,,,,, -33700,1.688761,1.7234693,,,,,,,,,,,,,, -33800,1.7374419,1.6691015,,,,,,,,,,,,,, -33900,1.8989435,1.7359209,,,,,,,,,,,,,, -34000,2.0003507,1.7388574,,,,,,,,,,,,,, -34100,1.6081108,1.5902567,,,,,,,,,,,,,, -34200,1.7180355,1.6770549,,,,,,,,,,,,,, -34300,1.8176675,1.6113863,,,,,,,,,,,,,, -34400,1.7378343,1.6415077,,,,,,,,,,,,,, -34500,1.9960446,1.5870339,,,,,,,,,,,,,, -34524,,,0.6873405575752258,1.2315131425857544,0.626479983329773,1.5449154376983645,50000.0,0.496500015258789,2.297478675842285,10000.0,11783.88451075554,12289.809584379196,11783.88451075554,503.7779715061188,0.9684596061706544,0.0 -34600,1.7420969,1.6437743,,,,,,,,,,,,,, -34700,1.6776419,1.5231805,,,,,,,,,,,,,, -34800,1.9095076,1.5654938,,,,,,,,,,,,,, -34900,1.5854008,1.6813518,,,,,,,,,,,,,, -35000,1.913533,1.6868513,,,,,,,,,,,,,, -35100,2.036107,1.6425434,,,,,,,,,,,,,, -35200,1.8946791,1.6026864,,,,,,,,,,,,,, -35300,2.2063074,1.7476388,,,,,,,,,,,,,, -35400,1.902966,1.7390729,,,,,,,,,,,,,, -35500,1.911872,1.7361351,,,,,,,,,,,,,, -35600,1.9377207,1.718051,,,,,,,,,,,,,, -35700,1.9755217,1.7783281,,,,,,,,,,,,,, -35800,1.8900589,1.703315,,,,,,,,,,,,,, -35900,1.8588712,1.7005658,,,,,,,,,,,,,, -36000,1.6935549,1.6752703,,,,,,,,,,,,,, -36027,,,0.6824377775192261,1.2468916177749634,0.6263799667358398,1.54625141620636,50000.0,0.5012000203132629,2.299901008605957,10000.0,12294.029720783234,12824.942950963974,12294.029720783234,528.6820592880249,1.0011744499206543,0.0 -36100,1.8782593,1.7305936,,,,,,,,,,,,,, -36200,1.7632359,1.6709749,,,,,,,,,,,,,, -36300,1.8195509,1.7154143,,,,,,,,,,,,,, -36400,2.1789432,1.7832803,,,,,,,,,,,,,, -36500,2.1026285,1.736577,,,,,,,,,,,,,, -36600,1.7106619,1.6484866,,,,,,,,,,,,,, -36700,1.9790016,1.71327,,,,,,,,,,,,,, -36800,1.7867162,1.7677991,,,,,,,,,,,,,, -36900,1.7903483,1.7250648,,,,,,,,,,,,,, -37000,1.8157176,1.6596928,,,,,,,,,,,,,, -37100,1.6984185,1.6426349,,,,,,,,,,,,,, -37200,2.0212448,1.6277424,,,,,,,,,,,,,, -37300,1.6971856,1.6350974,,,,,,,,,,,,,, -37400,2.0497317,1.6328442,,,,,,,,,,,,,, -37500,2.2595158,1.8021029,,,,,,,,,,,,,, -37529,,,0.7158800959587097,1.1090245246887207,0.6350600123405457,1.5164011716842651,50000.0,0.5037000179290771,2.2564914226531982,10000.0,12804.038206338882,13357.855704545977,12804.038206338882,551.5062041282654,1.030848264694214,0.0 -37600,1.8170549,1.5528108,,,,,,,,,,,,,, -37700,1.8570287,1.6277784,,,,,,,,,,,,,, -37800,2.1292365,1.640546,,,,,,,,,,,,,, -37900,1.79902,1.757664,,,,,,,,,,,,,, -38000,2.0562334,1.6285535,,,,,,,,,,,,,, -38100,1.9698077,1.7436682,,,,,,,,,,,,,, -38200,1.9387538,1.7846782,,,,,,,,,,,,,, -38300,2.0766602,1.6609279,,,,,,,,,,,,,, -38400,1.8814226,1.6305662,,,,,,,,,,,,,, -38500,1.833814,1.7298362,,,,,,,,,,,,,, -38600,1.9555304,1.6635343,,,,,,,,,,,,,, -38700,2.14,1.7597293,,,,,,,,,,,,,, -38800,1.9173236,1.6530685,,,,,,,,,,,,,, -38900,1.8238912,1.6634685,,,,,,,,,,,,,, -39000,1.7800148,1.6169889,,,,,,,,,,,,,, -39032,,,0.71683669090271,1.1103540658950806,0.62909996509552,1.5288817882537842,50000.0,0.5091000199317932,2.224526882171631,10000.0,13314.092184782028,13890.870227098463,13314.092184782028,574.3761565685272,1.0700502395629885,0.0 -39100,1.8399067,1.62795,,,,,,,,,,,,,, -39200,1.951971,1.7195332,,,,,,,,,,,,,, -39300,1.8671223,1.6652955,,,,,,,,,,,,,, -39400,1.8721483,1.7696813,,,,,,,,,,,,,, -39500,1.8110417,1.5238714,,,,,,,,,,,,,, -39600,1.7776285,1.6185321,,,,,,,,,,,,,, -39700,1.8482248,1.6480803,,,,,,,,,,,,,, -39800,1.7441382,1.6463356,,,,,,,,,,,,,, -39900,1.8665107,1.7116524,,,,,,,,,,,,,, -40000,1.9853979,1.7638911,,,,,,,,,,,,,, -40100,1.8739812,1.5896599,,,,,,,,,,,,,, -40200,1.9773738,1.6796424,,,,,,,,,,,,,, -40300,2.0746987,1.7608285,,,,,,,,,,,,,, -40400,1.7719989,1.6628225,,,,,,,,,,,,,, -40500,1.8583387,1.6334579,,,,,,,,,,,,,, -40534,,,0.7096619606018066,1.1346745491027832,0.6363999843597412,1.4892843961715698,50000.0,0.5076000094413757,2.232386827468872,10000.0,13824.17870426178,14424.020713090897,13824.17870426178,597.3547255992889,1.10364031791687,0.0 -40600,1.7749852,1.729005,,,,,,,,,,,,,, -40700,1.807311,1.6510396,,,,,,,,,,,,,, -40800,1.6947107,1.5407404,,,,,,,,,,,,,, -40900,1.7877318,1.6621845,,,,,,,,,,,,,, -41000,1.9945266,1.6705979,,,,,,,,,,,,,, -41100,1.7807399,1.8171017,,,,,,,,,,,,,, -41200,1.8560553,1.512864,,,,,,,,,,,,,, -41300,1.8857782,1.7111441,,,,,,,,,,,,,, -41400,1.7198541,1.6033009,,,,,,,,,,,,,, -41500,1.6916678,1.6407789,,,,,,,,,,,,,, -41600,1.8894023,1.7437907,,,,,,,,,,,,,, -41700,2.1859028,1.7433715,,,,,,,,,,,,,, -41800,1.8113413,1.5904772,,,,,,,,,,,,,, -41900,2.03706,1.6153315,,,,,,,,,,,,,, -42000,1.997277,1.7760414,,,,,,,,,,,,,, -42036,,,0.6986008882522583,1.179208517074585,0.6325599551200867,1.5215106010437012,50000.0,0.5042000412940979,2.279670000076294,10000.0,14334.231875896454,14959.633465051653,14334.231875896454,622.8276555538177,1.1378414630889893,0.0 -42100,1.7204995,1.4555175,,,,,,,,,,,,,, -42200,1.998504,1.6599294,,,,,,,,,,,,,, -42300,1.7146208,1.6045309,,,,,,,,,,,,,, -42400,1.6397493,1.5633274,,,,,,,,,,,,,, -42500,1.8011154,1.4963301,,,,,,,,,,,,,, -42600,1.7717131,1.6905321,,,,,,,,,,,,,, -42700,1.9273877,1.6592329,,,,,,,,,,,,,, -42800,1.7286714,1.6191598,,,,,,,,,,,,,, -42900,1.9155579,1.7539811,,,,,,,,,,,,,, -43000,1.8044649,1.6656628,,,,,,,,,,,,,, -43100,1.7332692,1.5167959,,,,,,,,,,,,,, -43200,1.7560897,1.6711658,,,,,,,,,,,,,, -43300,2.0301893,1.7431369,,,,,,,,,,,,,, -43400,1.9230722,1.5491335,,,,,,,,,,,,,, -43500,1.9405143,1.4933169,,,,,,,,,,,,,, -43538,,,0.6833944320678711,1.2552924156188965,0.6233400106430054,1.5577534437179563,50000.0,0.4919000267982483,2.3413572311401367,10000.0,14844.221347093582,15492.801275491714,14844.221347093582,645.9152612686157,1.1770331859588623,0.0 -43600,1.9806118,1.6199317,,,,,,,,,,,,,, -43700,1.9240535,1.6199448,,,,,,,,,,,,,, -43800,1.887176,1.5843465,,,,,,,,,,,,,, -43900,1.665673,1.5143229,,,,,,,,,,,,,, -44000,1.8237427,1.5803401,,,,,,,,,,,,,, -44100,2.01984,1.6168191,,,,,,,,,,,,,, -44200,1.571351,1.639307,,,,,,,,,,,,,, -44300,1.9535571,1.5046219,,,,,,,,,,,,,, -44400,1.642902,1.687351,,,,,,,,,,,,,, -44500,1.8416684,1.744468,,,,,,,,,,,,,, -44600,1.8619994,1.636529,,,,,,,,,,,,,, -44700,2.0307937,1.6695614,,,,,,,,,,,,,, -44800,1.9816953,1.8264482,,,,,,,,,,,,,, -44900,1.9304799,1.5810795,,,,,,,,,,,,,, -45000,1.8074684,1.6516182,,,,,,,,,,,,,, -45041,,,0.6790497303009033,1.2766070365905762,0.6209200024604797,1.573642611503601,50000.0,0.4878000319004059,2.3068480491638184,10000.0,15354.339327812197,16026.276261806488,15354.339327812197,669.1851181983948,1.2121222019195557,0.0 -45100,1.9061244,1.7250574,,,,,,,,,,,,,, -45200,2.0427132,1.6401291,,,,,,,,,,,,,, -45300,1.8868686,1.6702957,,,,,,,,,,,,,, -45400,1.7334348,1.5488901,,,,,,,,,,,,,, -45500,1.7822553,1.7501034,,,,,,,,,,,,,, -45600,1.8592879,1.6802833,,,,,,,,,,,,,, -45700,1.8831972,1.540067,,,,,,,,,,,,,, -45800,1.8982031,1.7474384,,,,,,,,,,,,,, -45900,1.7194949,1.6615634,,,,,,,,,,,,,, -46000,1.7533259,1.5534577,,,,,,,,,,,,,, -46100,1.9426749,1.6754698,,,,,,,,,,,,,, -46200,2.1817203,1.5998888,,,,,,,,,,,,,, -46300,2.1166599,1.64965,,,,,,,,,,,,,, -46400,1.8170896,1.5847629,,,,,,,,,,,,,, -46500,2.0542367,1.6426349,,,,,,,,,,,,,, -46544,,,0.7046595811843872,1.145161271095276,0.6420800089836121,1.48219633102417,50000.0,0.5092000365257263,2.218715190887451,10000.0,15864.367563009262,16559.309529066086,15864.367563009262,692.1000168323517,1.25059175491333,0.0 -46600,1.7243621,1.6209333,,,,,,,,,,,,,, -46700,1.7550493,1.6115494,,,,,,,,,,,,,, -46800,2.191279,1.66057,,,,,,,,,,,,,, -46900,1.8389155,1.6764404,,,,,,,,,,,,,, -47000,1.7927011,1.5317763,,,,,,,,,,,,,, -47100,1.6869578,1.6488461,,,,,,,,,,,,,, -47200,1.893705,1.4774325,,,,,,,,,,,,,, -47300,2.0164938,1.5126202,,,,,,,,,,,,,, -47400,2.0025582,1.6679783,,,,,,,,,,,,,, -47500,1.8520734,1.6285284,,,,,,,,,,,,,, -47600,1.8727517,1.7152109,,,,,,,,,,,,,, -47700,1.7333429,1.5617113,,,,,,,,,,,,,, -47800,1.7085773,1.5860335,,,,,,,,,,,,,, -47900,1.7538189,1.5369765,,,,,,,,,,,,,, -48000,1.879964,1.5237136,,,,,,,,,,,,,, -48048,,,0.7367067933082581,0.9979020357131958,0.6473199725151062,1.4513506889343262,50000.0,0.5217000246047974,2.187633752822876,10000.0,16374.504691123962,17092.12025952339,16374.504691123962,714.6889681816101,1.2839748859405518,0.0 -48100,2.0339205,1.5578264,,,,,,,,,,,,,, -48200,1.7543072,1.6710582,,,,,,,,,,,,,, -48300,1.904897,1.5980672,,,,,,,,,,,,,, -48400,1.8360454,1.6922054,,,,,,,,,,,,,, -48500,2.052018,1.7467866,,,,,,,,,,,,,, -48600,1.8480874,1.4778651,,,,,,,,,,,,,, -48700,2.0477903,1.6022754,,,,,,,,,,,,,, -48800,1.7025274,1.5714471,,,,,,,,,,,,,, -48900,1.743695,1.6670942,,,,,,,,,,,,,, -49000,1.811725,1.6798847,,,,,,,,,,,,,, -49100,2.000352,1.6282189,,,,,,,,,,,,,, -49200,1.8829001,1.6377721,,,,,,,,,,,,,, -49300,1.7868541,1.577385,,,,,,,,,,,,,, -49400,1.7714492,1.5608988,,,,,,,,,,,,,, -49500,1.8601283,1.5925766,,,,,,,,,,,,,, -49551,,,0.7074099183082581,1.1301013231277466,0.6358799934387207,1.5119726657867432,50000.0,0.5031000375747681,2.310176134109497,10000.0,16884.7590508461,17625.10194325447,16884.7590508461,737.3301610946655,1.3194658756256104,0.0 -49600,1.9434798,1.510042,,,,,,,,,,,,,, -49700,1.9783256,1.6413441,,,,,,,,,,,,,, -49800,1.8354892,1.6624744,,,,,,,,,,,,,, -49900,1.9134468,1.5716724,,,,,,,,,,,,,, -50000,1.848107,1.621802,,,,,,,,,,,,,, -50100,2.1589131,1.7803347,,,,,,,,,,,,,, -50200,1.8932192,1.6041456,,,,,,,,,,,,,, -50300,1.9951057,1.6409061,,,,,,,,,,,,,, -50400,1.8058146,1.5061415,,,,,,,,,,,,,, -50500,1.9804476,1.7523353,,,,,,,,,,,,,, -50600,1.884448,1.6367177,,,,,,,,,,,,,, -50700,1.704583,1.5543748,,,,,,,,,,,,,, -50800,1.7144753,1.626245,,,,,,,,,,,,,, -50900,2.0199032,1.7834558,,,,,,,,,,,,,, -51000,1.9010425,1.713003,,,,,,,,,,,,,, -51054,,,0.7250877022743225,1.0787639617919922,0.6547200083732605,1.407345175743103,50000.0,0.527400016784668,2.135467767715454,10000.0,17394.90104007721,18156.323054790497,17394.90104007721,758.3226172924042,1.3547601699829102,0.0 -51100,2.000558,1.6292979,,,,,,,,,,,,,, -51200,1.9001344,1.6084675,,,,,,,,,,,,,, -51300,1.7739395,1.6206331,,,,,,,,,,,,,, -51400,1.8420547,1.5099175,,,,,,,,,,,,,, -51500,1.8028623,1.6308331,,,,,,,,,,,,,, -51600,2.0753968,1.6891184,,,,,,,,,,,,,, -51700,1.7640095,1.5725532,,,,,,,,,,,,,, -51800,1.8664515,1.5093063,,,,,,,,,,,,,, -51900,1.8740568,1.6133522,,,,,,,,,,,,,, -52000,1.9155005,1.6660668,,,,,,,,,,,,,, -52100,1.8854375,1.4870468,,,,,,,,,,,,,, -52200,1.901473,1.5789256,,,,,,,,,,,,,, -52300,1.8417829,1.566479,,,,,,,,,,,,,, -52400,1.9368339,1.5039856,,,,,,,,,,,,,, -52500,1.8719858,1.6126188,,,,,,,,,,,,,, -52557,,,0.7002750039100647,1.1716082096099854,0.6354599595069885,1.5122188329696655,50000.0,0.5105000138282776,2.2743141651153564,10000.0,17905.1085562706,18688.75293493271,17905.1085562706,780.4596221446991,1.389142990112305,0.0 -52600,1.9215516,1.5799383,,,,,,,,,,,,,, -52700,1.9962724,1.6198162,,,,,,,,,,,,,, -52800,1.6866243,1.5161924,,,,,,,,,,,,,, -52900,1.873752,1.5956085,,,,,,,,,,,,,, -53000,1.9090341,1.6203444,,,,,,,,,,,,,, -53100,1.7916361,1.486796,,,,,,,,,,,,,, -53200,1.7261777,1.5212462,,,,,,,,,,,,,, -53300,2.1416545,1.743257,,,,,,,,,,,,,, -53400,1.7384491,1.6169255,,,,,,,,,,,,,, -53500,2.0236802,1.6138998,,,,,,,,,,,,,, -53600,1.8896046,1.6110942,,,,,,,,,,,,,, -53700,1.7974843,1.6397132,,,,,,,,,,,,,, -53800,1.6437818,1.4593903,,,,,,,,,,,,,, -53900,2.0169683,1.6547098,,,,,,,,,,,,,, -54000,1.8694457,1.4665236,,,,,,,,,,,,,, -54061,,,0.704500138759613,1.1456981897354126,0.6450799703598022,1.4624508619308472,50000.0,0.5192000269889832,2.1851463317871094,10000.0,18415.254390001297,19220.11591911316,18415.254390001297,801.5913171768188,1.4218056201934814,0.0 -54100,2.173751,1.6591594,,,,,,,,,,,,,, -54200,1.8972067,1.515559,,,,,,,,,,,,,, -54300,2.0466845,1.5830503,,,,,,,,,,,,,, -54400,1.7431281,1.529876,,,,,,,,,,,,,, -54500,1.9524012,1.5377126,,,,,,,,,,,,,, -54600,1.8995918,1.5762532,,,,,,,,,,,,,, -54700,2.1202068,1.6337,,,,,,,,,,,,,, -54800,1.9961224,1.6394442,,,,,,,,,,,,,, -54900,1.9265139,1.6461718,,,,,,,,,,,,,, -55000,1.8770655,1.6818347,,,,,,,,,,,,,, -55100,1.9119438,1.6108543,,,,,,,,,,,,,, -55200,1.7806189,1.4998326,,,,,,,,,,,,,, -55300,2.0535357,1.6272434,,,,,,,,,,,,,, -55400,1.8566489,1.5437177,,,,,,,,,,,,,, -55500,1.939401,1.5955894,,,,,,,,,,,,,, -55564,,,0.7084661722183228,1.1322740316390991,0.649679958820343,1.445713996887207,50000.0,0.5225000381469727,2.1526975631713867,10000.0,18925.44615507126,19750.42909526825,18925.44615507126,821.6257679462433,1.4563281536102295,0.0 -55600,2.1572967,1.6286519,,,,,,,,,,,,,, -55700,1.8714263,1.6100667,,,,,,,,,,,,,, -55800,1.6850876,1.5713279,,,,,,,,,,,,,, -55900,1.8148608,1.5837957,,,,,,,,,,,,,, -56000,2.0211267,1.6130649,,,,,,,,,,,,,, -56100,2.179025,1.6527824,,,,,,,,,,,,,, -56200,1.85425,1.4672025,,,,,,,,,,,,,, -56300,1.9098427,1.5360423,,,,,,,,,,,,,, -56400,1.8105975,1.5349317,,,,,,,,,,,,,, -56500,1.8786223,1.6150963,,,,,,,,,,,,,, -56600,1.9992584,1.4796599,,,,,,,,,,,,,, -56700,1.9696608,1.7151967,,,,,,,,,,,,,, -56800,1.9462391,1.6688094,,,,,,,,,,,,,, -56900,2.3454037,1.4829692,,,,,,,,,,,,,, -57000,2.1343617,1.6419804,,,,,,,,,,,,,, -57068,,,0.7385801672935486,1.0057047605514526,0.6440799832344055,1.4641975164413452,50000.0,0.5080000162124634,2.2163994312286377,10000.0,19435.642174959183,20277.98100042343,19435.642174959183,838.8901033401489,1.4961392879486084,0.0 -57100,1.8451146,1.5851183,,,,,,,,,,,,,, -57200,1.8469412,1.6942071,,,,,,,,,,,,,, -57300,1.9019189,1.6458781,,,,,,,,,,,,,, -57400,2.050605,1.6075906,,,,,,,,,,,,,, -57500,1.9897465,1.6490788,,,,,,,,,,,,,, -57600,1.9069104,1.547394,,,,,,,,,,,,,, -57700,1.9745313,1.6484239,,,,,,,,,,,,,, -57800,1.8890439,1.5919826,,,,,,,,,,,,,, -57900,1.9673239,1.5196643,,,,,,,,,,,,,, -58000,2.0689168,1.5934088,,,,,,,,,,,,,, -58100,1.75698,1.499263,,,,,,,,,,,,,, -58200,2.2283823,1.5777589,,,,,,,,,,,,,, -58300,1.9059283,1.5904787,,,,,,,,,,,,,, -58400,2.0311177,1.6360177,,,,,,,,,,,,,, -58500,1.9294944,1.5191478,,,,,,,,,,,,,, -58571,,,0.7267019748687744,1.0357704162597656,0.646619975566864,1.45211923122406,50000.0,0.522100031375885,2.178805589675904,10000.0,19945.802860736847,20805.475059747696,19945.802860736847,856.131756067276,1.5372190475463867,0.0 -58600,1.9784307,1.509825,,,,,,,,,,,,,, -58700,2.0514162,1.5315024,,,,,,,,,,,,,, -58800,2.012776,1.6238286,,,,,,,,,,,,,, -58900,1.9400403,1.6072569,,,,,,,,,,,,,, -59000,1.8581434,1.517549,,,,,,,,,,,,,, -59100,1.8714364,1.657052,,,,,,,,,,,,,, -59200,2.0856967,1.5051782,,,,,,,,,,,,,, -59300,1.8839798,1.60244,,,,,,,,,,,,,, -59400,1.9811782,1.5757005,,,,,,,,,,,,,, -59500,2.233932,1.5320411,,,,,,,,,,,,,, -59600,2.0068004,1.5126346,,,,,,,,,,,,,, -59700,2.0828648,1.5609566,,,,,,,,,,,,,, -59800,2.0272114,1.4949951,,,,,,,,,,,,,, -59900,1.7844592,1.5717989,,,,,,,,,,,,,, -60000,1.9447585,1.4644784,,,,,,,,,,,,,, -60074,,,0.7245694994926453,1.055361032485962,0.6547600030899048,1.4124293327331543,50000.0,0.5263000130653381,2.112003803253174,10000.0,20455.73877811432,21332.740632534027,20455.73877811432,873.3655483722687,1.5809953212738037,0.0 -60100,1.9056084,1.5301758,,,,,,,,,,,,,, -60200,2.0328228,1.5804561,,,,,,,,,,,,,, -60300,1.8215789,1.568964,,,,,,,,,,,,,, -60400,2.010699,1.474129,,,,,,,,,,,,,, -60500,1.8099993,1.6403774,,,,,,,,,,,,,, -60600,2.0556185,1.4607166,,,,,,,,,,,,,, -60700,2.3890939,1.6154358,,,,,,,,,,,,,, -60800,1.9396774,1.5300479,,,,,,,,,,,,,, -60900,2.0378041,1.6769407,,,,,,,,,,,,,, -61000,1.8215563,1.580186,,,,,,,,,,,,,, -61100,2.144946,1.5388025,,,,,,,,,,,,,, -61200,2.0354385,1.6555706,,,,,,,,,,,,,, -61300,2.1503816,1.5899539,,,,,,,,,,,,,, -61400,1.8533741,1.5037022,,,,,,,,,,,,,, -61500,2.0881555,1.7029349,,,,,,,,,,,,,, -61577,,,0.7218191623687744,1.0774519443511963,0.6505599617958069,1.4308249950408936,50000.0,0.526900053024292,2.1299495697021484,10000.0,20965.771684646606,21860.102259159088,20965.771684646606,890.5964822769165,1.6265153884887695,0.0 -61600,1.9097263,1.5264565,,,,,,,,,,,,,, -61700,1.9114475,1.5287678,,,,,,,,,,,,,, -61800,1.9871773,1.5841786,,,,,,,,,,,,,, -61900,2.0177376,1.582061,,,,,,,,,,,,,, -62000,1.8043277,1.4629072,,,,,,,,,,,,,, -62100,2.147392,1.5594943,,,,,,,,,,,,,, -62200,2.1486695,1.4576716,,,,,,,,,,,,,, -62300,1.9235965,1.4634812,,,,,,,,,,,,,, -62400,2.1510756,1.6190462,,,,,,,,,,,,,, -62500,1.9646863,1.4925917,,,,,,,,,,,,,, -62600,2.0896058,1.6871912,,,,,,,,,,,,,, -62700,1.7207662,1.5619152,,,,,,,,,,,,,, -62800,1.9640386,1.5492585,,,,,,,,,,,,,, -62900,1.9744858,1.5058957,,,,,,,,,,,,,, -63000,1.9587138,1.573728,,,,,,,,,,,,,, -63081,,,0.7154416441917419,1.1013087034225464,0.6513800024986267,1.4529109001159668,50000.0,0.5213000178337097,2.185504913330078,10000.0,21475.95622348785,22387.29090833664,21475.95622348785,907.5062689781188,1.6687884330749512,0.0 -63100,1.8713398,1.5469908,,,,,,,,,,,,,, -63200,2.0720618,1.5814164,,,,,,,,,,,,,, -63300,2.4218912,1.5896782,,,,,,,,,,,,,, -63400,2.0257351,1.4842174,,,,,,,,,,,,,, -63500,2.0013027,1.5216641,,,,,,,,,,,,,, -63600,2.002295,1.5134549,,,,,,,,,,,,,, -63700,1.9130116,1.559278,,,,,,,,,,,,,, -63800,2.1107128,1.5425942,,,,,,,,,,,,,, -63900,1.9774592,1.5989249,,,,,,,,,,,,,, -64000,1.9850763,1.5255685,,,,,,,,,,,,,, -64100,1.9194199,1.4523124,,,,,,,,,,,,,, -64200,1.8660077,1.5243001,,,,,,,,,,,,,, -64300,1.9971396,1.4850352,,,,,,,,,,,,,, -64400,1.8903539,1.5065688,,,,,,,,,,,,,, -64500,2.0655103,1.5989048,,,,,,,,,,,,,, -64584,,,0.7234733700752258,1.0683307647705078,0.6577000021934509,1.4033039808273315,50000.0,0.5286000370979309,2.160560846328736,10000.0,21985.94143342972,22914.152262449265,21985.94143342972,924.2931768894196,1.7055230140686035,0.0 -64600,2.1775131,1.5256964,,,,,,,,,,,,,, -64700,1.9890082,1.5192823,,,,,,,,,,,,,, -64800,1.8892399,1.5591567,,,,,,,,,,,,,, -64900,2.2229557,1.585629,,,,,,,,,,,,,, -65000,1.9607872,1.5528669,,,,,,,,,,,,,, -65100,2.1448321,1.5456324,,,,,,,,,,,,,, -65200,1.8933682,1.6114037,,,,,,,,,,,,,, -65300,1.9039377,1.4430677,,,,,,,,,,,,,, -65400,1.9014218,1.4560759,,,,,,,,,,,,,, -65500,2.1409297,1.5144584,,,,,,,,,,,,,, -65600,2.237124,1.6688719,,,,,,,,,,,,,, -65700,2.1475043,1.447062,,,,,,,,,,,,,, -65800,2.1384625,1.5845324,,,,,,,,,,,,,, -65900,2.0400858,1.4867027,,,,,,,,,,,,,, -66000,2.2146585,1.5581303,,,,,,,,,,,,,, -66087,,,0.7430046200752258,0.9785423874855042,0.6419000029563904,1.4796708822250366,50000.0,0.5059000253677368,2.2493345737457275,10000.0,22495.97627878189,23441.093533992767,22495.97627878189,941.1070840358734,1.74554705619812,0.0 -66100,1.9857359,1.6178482,,,,,,,,,,,,,, -66200,2.010378,1.4453266,,,,,,,,,,,,,, -66300,2.2725186,1.6427305,,,,,,,,,,,,,, -66400,2.13762,1.4079286,,,,,,,,,,,,,, -66500,2.0446117,1.5245805,,,,,,,,,,,,,, -66600,1.9832453,1.6162773,,,,,,,,,,,,,, -66700,1.8065555,1.4961214,,,,,,,,,,,,,, -66800,2.0481122,1.5147288,,,,,,,,,,,,,, -66900,2.2141385,1.4909685,,,,,,,,,,,,,, -67000,1.9373511,1.4313519,,,,,,,,,,,,,, -67100,1.9092189,1.4850041,,,,,,,,,,,,,, -67200,2.0852628,1.6144736,,,,,,,,,,,,,, -67300,2.0614903,1.4655648,,,,,,,,,,,,,, -67400,2.2419083,1.5167953,,,,,,,,,,,,,, -67500,1.9903101,1.5290772,,,,,,,,,,,,,, -67590,,,0.7223373651504517,1.0661697387695312,0.6440399885177612,1.467403531074524,50000.0,0.5074000358581543,2.2431344985961914,10000.0,23006.104821681976,23968.10825037956,23006.104821681976,957.8929336071014,1.7923862934112549,0.0 -67600,1.9235135,1.5557188,,,,,,,,,,,,,, -67700,2.3199956,1.5099881,,,,,,,,,,,,,, -67800,2.1040776,1.5643839,,,,,,,,,,,,,, -67900,1.9080629,1.4475074,,,,,,,,,,,,,, -68000,1.992297,1.5061902,,,,,,,,,,,,,, -68100,2.3629491,1.5623267,,,,,,,,,,,,,, -68200,1.8563988,1.6000212,,,,,,,,,,,,,, -68300,2.0756595,1.5488075,,,,,,,,,,,,,, -68400,2.043175,1.5731531,,,,,,,,,,,,,, -68500,2.11926,1.5199282,,,,,,,,,,,,,, -68600,1.982428,1.5261482,,,,,,,,,,,,,, -68700,2.360673,1.6600581,,,,,,,,,,,,,, -68800,2.1915855,1.475829,,,,,,,,,,,,,, -68900,2.089179,1.524526,,,,,,,,,,,,,, -69000,1.968949,1.4734523,,,,,,,,,,,,,, -69093,,,0.7189094424247742,1.087269306182861,0.6390399932861328,1.490094780921936,50000.0,0.5078000426292419,2.2381787300109863,10000.0,23516.008523225784,24494.93649435044,23516.008523225784,974.7140092849731,1.8431601524353027,0.0 -69100,2.0747285,1.535055,,,,,,,,,,,,,, -69200,2.044183,1.5081139,,,,,,,,,,,,,, -69300,2.0918715,1.531328,,,,,,,,,,,,,, -69400,1.9899683,1.5054812,,,,,,,,,,,,,, -69500,2.0159118,1.4793166,,,,,,,,,,,,,, -69600,2.0721,1.5438972,,,,,,,,,,,,,, -69700,2.1711962,1.5947472,,,,,,,,,,,,,, -69800,2.2402809,1.5124114,,,,,,,,,,,,,, -69900,2.1430743,1.5922686,,,,,,,,,,,,,, -70000,2.0766075,1.5030911,,,,,,,,,,,,,, -70100,2.0801246,1.466069,,,,,,,,,,,,,, -70200,2.1453042,1.4162033,,,,,,,,,,,,,, -70300,2.264327,1.6026453,,,,,,,,,,,,,, -70400,1.8871572,1.5179706,,,,,,,,,,,,,, -70500,1.8596855,1.4355972,,,,,,,,,,,,,, -70596,,,0.7314253449440002,1.0208046436309814,0.6607199907302856,1.391520619392395,50000.0,0.5357000231742859,2.1112637519836426,10000.0,24026.23086190224,25022.027183532715,24026.23086190224,991.4904737472534,1.882537841796875,0.0 -70600,2.0942323,1.5144159,,,,,,,,,,,,,, -70700,1.9221091,1.4564474,,,,,,,,,,,,,, -70800,1.9197317,1.4907012,,,,,,,,,,,,,, -70900,2.2234871,1.6489695,,,,,,,,,,,,,, -71000,2.2080178,1.5073745,,,,,,,,,,,,,, -71100,1.893375,1.5280706,,,,,,,,,,,,,, -71200,2.0054448,1.5417389,,,,,,,,,,,,,, -71300,2.5075681,1.5098212,,,,,,,,,,,,,, -71400,2.0799263,1.4430796,,,,,,,,,,,,,, -71500,1.9545541,1.5084465,,,,,,,,,,,,,, -71600,1.8232145,1.4351577,,,,,,,,,,,,,, -71700,1.8963624,1.5029418,,,,,,,,,,,,,, -71800,1.9925033,1.5023999,,,,,,,,,,,,,, -71900,2.063702,1.5288386,,,,,,,,,,,,,, -72000,2.2042694,1.6262978,,,,,,,,,,,,,, -72099,,,0.7379224896430969,0.9979417324066162,0.667639970779419,1.3583799600601196,50000.0,0.5326000452041626,2.116264581680298,10000.0,24536.351742506027,25548.94845175743,24536.351742506027,1008.195422887802,1.924994707107544,0.0 -72100,2.204516,1.6129811,,,,,,,,,,,,,, -72200,2.241167,1.5921884,,,,,,,,,,,,,, -72300,2.0136328,1.5339115,,,,,,,,,,,,,, -72400,2.161487,1.5217849,,,,,,,,,,,,,, -72500,2.2020566,1.5158533,,,,,,,,,,,,,, -72600,1.9176441,1.5638278,,,,,,,,,,,,,, -72700,2.1293383,1.4853201,,,,,,,,,,,,,, -72800,2.191641,1.5336068,,,,,,,,,,,,,, -72900,2.0072975,1.3577137,,,,,,,,,,,,,, -73000,2.1013548,1.5555911,,,,,,,,,,,,,, -73100,1.9729518,1.4996643,,,,,,,,,,,,,, -73200,2.0128336,1.4917319,,,,,,,,,,,,,, -73300,2.0809305,1.5033094,,,,,,,,,,,,,, -73400,2.0398426,1.5383303,,,,,,,,,,,,,, -73500,2.2553144,1.6088918,,,,,,,,,,,,,, -73600,2.0716553,1.5260024,,,,,,,,,,,,,, -73602,,,0.7327207922935486,1.0299476385116575,0.6582599878311157,1.3956197500228882,50000.0,0.5344000458717346,2.101027011871338,10000.0,25046.33527326584,26075.744012355804,25046.33527326584,1024.909749507904,1.9686222076416016,0.0 -73700,1.8820084,1.4519932,,,,,,,,,,,,,, -73800,1.9209763,1.4382466,,,,,,,,,,,,,, -73900,1.9581686,1.3986562,,,,,,,,,,,,,, -74000,1.9599377,1.4653955,,,,,,,,,,,,,, -74100,2.2471764,1.5711731,,,,,,,,,,,,,, -74200,2.203209,1.471088,,,,,,,,,,,,,, -74300,2.0417297,1.4505137,,,,,,,,,,,,,, -74400,2.167354,1.5486763,,,,,,,,,,,,,, -74500,2.266138,1.4534013,,,,,,,,,,,,,, -74600,1.9595114,1.3723073,,,,,,,,,,,,,, -74700,2.2195456,1.5110836,,,,,,,,,,,,,, -74800,2.0875475,1.453258,,,,,,,,,,,,,, -74900,2.1763976,1.4835087,,,,,,,,,,,,,, -75000,2.1028242,1.517145,,,,,,,,,,,,,, -75100,2.2705426,1.5036917,,,,,,,,,,,,,, -75106,,,0.7759885191917419,0.8355568647384644,0.6620799899101257,1.3705966472625732,50000.0,0.5297999978065491,2.1002209186553955,10000.0,25556.429697752,26603.31052732468,25556.429697752,1042.2852289676666,2.011807918548584,0.0 -75200,2.3358827,1.5445899,,,,,,,,,,,,,, -75300,2.1685376,1.5241818,,,,,,,,,,,,,, -75400,2.0141876,1.4801544,,,,,,,,,,,,,, -75500,2.275154,1.488152,,,,,,,,,,,,,, -75600,2.276398,1.4796785,,,,,,,,,,,,,, -75700,2.5769308,1.6264719,,,,,,,,,,,,,, -75800,1.9906373,1.4936647,,,,,,,,,,,,,, -75900,2.170757,1.4874309,,,,,,,,,,,,,, -76000,2.0800886,1.547692,,,,,,,,,,,,,, -76100,2.3198636,1.4058285,,,,,,,,,,,,,, -76200,2.291215,1.5315189,,,,,,,,,,,,,, -76300,1.9704964,1.4807444,,,,,,,,,,,,,, -76400,1.9975194,1.4683373,,,,,,,,,,,,,, -76500,2.2471452,1.4591665,,,,,,,,,,,,,, -76600,2.1685376,1.5085875,,,,,,,,,,,,,, -76609,,,0.7373644709587097,0.9887303709983826,0.6548199653625488,1.4147918224334717,50000.0,0.5212000012397766,2.1413140296936035,10000.0,26066.617539405823,27130.641062498093,26066.617539405823,1059.3335857391355,2.053749084472656,0.0 -76700,2.6801028,1.5221854,,,,,,,,,,,,,, -76800,1.9531717,1.4950858,,,,,,,,,,,,,, -76900,1.9562722,1.572909,,,,,,,,,,,,,, -77000,2.1895137,1.5404482,,,,,,,,,,,,,, -77100,2.3679135,1.4889566,,,,,,,,,,,,,, -77200,2.218311,1.5237923,,,,,,,,,,,,,, -77300,1.9515505,1.4282932,,,,,,,,,,,,,, -77400,2.1285086,1.4927653,,,,,,,,,,,,,, -77500,2.3495061,1.4802947,,,,,,,,,,,,,, -77600,2.2079625,1.4580925,,,,,,,,,,,,,, -77700,1.9942595,1.4409673,,,,,,,,,,,,,, -77800,2.083237,1.471853,,,,,,,,,,,,,, -77900,2.3119974,1.40125,,,,,,,,,,,,,, -78000,2.2038896,1.524293,,,,,,,,,,,,,, -78100,2.2109299,1.4395244,,,,,,,,,,,,,, -78112,,,0.7537468075752258,0.915976583957672,0.6726399660110474,1.3231079578399658,50000.0,0.5396000146865845,2.074337959289551,10000.0,26576.7337744236,27657.67480564117,26576.7337744236,1076.1535539627075,2.098949909210205,0.0 -78200,2.064962,1.366818,,,,,,,,,,,,,, -78300,2.052625,1.4512942,,,,,,,,,,,,,, -78400,2.2438745,1.3225914,,,,,,,,,,,,,, -78500,2.093929,1.5390011,,,,,,,,,,,,,, -78600,1.9945911,1.3623362,,,,,,,,,,,,,, -78700,2.0032032,1.4194468,,,,,,,,,,,,,, -78800,1.8260096,1.4306641,,,,,,,,,,,,,, -78900,1.9398947,1.3548745,,,,,,,,,,,,,, -79000,2.1625633,1.3753848,,,,,,,,,,,,,, -79100,2.2263713,1.3709203,,,,,,,,,,,,,, -79200,2.1011899,1.2392187,,,,,,,,,,,,,, -79300,2.3371756,1.48574,,,,,,,,,,,,,, -79400,2.34771,1.4825168,,,,,,,,,,,,,, -79500,2.4194918,1.6073701,,,,,,,,,,,,,, -79600,2.2996118,1.5364778,,,,,,,,,,,,,, -79615,,,0.74906325340271,0.9586185216903688,0.6702799797058105,1.339442491531372,50000.0,0.5450000166893005,2.059394836425781,10000.0,27086.811772346497,28184.400028944016,27086.811772346497,1092.703783750534,2.143179416656494,0.0 -79700,2.2782311,1.5128796,,,,,,,,,,,,,, -79800,2.0779812,1.5334071,,,,,,,,,,,,,, -79900,2.1321318,1.5680774,,,,,,,,,,,,,, -80000,2.2918057,1.3957624,,,,,,,,,,,,,, -80100,1.8920169,1.4347916,,,,,,,,,,,,,, -80200,1.9979298,1.3642055,,,,,,,,,,,,,, -80300,2.006098,1.4853334,,,,,,,,,,,,,, -80400,1.9939924,1.3873672,,,,,,,,,,,,,, -80500,2.2136407,1.4253621,,,,,,,,,,,,,, -80600,1.9820211,1.3609006,,,,,,,,,,,,,, -80700,2.1708224,1.3853776,,,,,,,,,,,,,, -80800,2.3269331,1.5851136,,,,,,,,,,,,,, -80900,2.3001304,1.5179286,,,,,,,,,,,,,, -81000,2.2387466,1.3579888,,,,,,,,,,,,,, -81100,2.2272444,1.4413285,,,,,,,,,,,,,, -81118,,,0.7498006820678711,0.9481151700019836,0.6739799976348877,1.3297719955444336,50000.0,0.5406000018119812,2.044244766235352,10000.0,27596.813438415527,28711.19625043869,27596.813438415527,1109.4012954235077,2.186863899230957,0.0 -81200,2.1726782,1.4956892,,,,,,,,,,,,,, -81300,2.384824,1.5649933,,,,,,,,,,,,,, -81400,1.9678551,1.3797362,,,,,,,,,,,,,, -81500,2.296667,1.3801638,,,,,,,,,,,,,, -81600,2.0929353,1.42169,,,,,,,,,,,,,, -81700,2.1708925,1.4815017,,,,,,,,,,,,,, -81800,2.2456203,1.4677762,,,,,,,,,,,,,, -81900,2.2867107,1.5705726,,,,,,,,,,,,,, -82000,2.5017786,1.5182426,,,,,,,,,,,,,, -82100,2.2257447,1.4573152,,,,,,,,,,,,,, -82200,2.2043884,1.4425272,,,,,,,,,,,,,, -82300,2.3031294,1.4644052,,,,,,,,,,,,,, -82400,2.121719,1.490991,,,,,,,,,,,,,, -82500,2.2915545,1.4100296,,,,,,,,,,,,,, -82600,2.2560644,1.5150607,,,,,,,,,,,,,, -82621,,,0.7277582883834839,1.0413271188735962,0.6600399613380432,1.386773705482483,50000.0,0.5247000455856323,2.1522934436798096,10000.0,28106.82225990296,29237.84588885308,28106.82225990296,1125.9471654891968,2.227048873901367,0.0 -82700,2.0835102,1.4964294,,,,,,,,,,,,,, -82800,2.044553,1.3924885,,,,,,,,,,,,,, -82900,2.3934538,1.4825914,,,,,,,,,,,,,, -83000,2.1882968,1.4059662,,,,,,,,,,,,,, -83100,2.2640262,1.4195873,,,,,,,,,,,,,, -83200,2.100747,1.4043491,,,,,,,,,,,,,, -83300,2.2828956,1.3342807,,,,,,,,,,,,,, -83400,2.5838223,1.3741238,,,,,,,,,,,,,, -83500,2.7960727,1.4674184,,,,,,,,,,,,,, -83600,2.3843234,1.5622991,,,,,,,,,,,,,, -83700,2.1532435,1.3967052,,,,,,,,,,,,,, -83800,2.3338292,1.4914631,,,,,,,,,,,,,, -83900,2.2332606,1.4328939,,,,,,,,,,,,,, -84000,2.239306,1.5271436,,,,,,,,,,,,,, -84100,2.3089018,1.5087761,,,,,,,,,,,,,, -84125,,,0.7925701141357422,0.7795647382736206,0.6730599999427795,1.3342093229293823,50000.0,0.5400000214576721,2.088098526000977,10000.0,28616.90827870369,29764.76045846939,28616.90827870369,1142.674451828003,2.274683952331543,0.0 -84200,2.4637382,1.3948952,,,,,,,,,,,,,, -84300,2.2917442,1.4155104,,,,,,,,,,,,,, -84400,2.37461,1.4422077,,,,,,,,,,,,,, -84500,2.1243238,1.4027438,,,,,,,,,,,,,, -84600,2.2894359,1.4141461,,,,,,,,,,,,,, -84700,2.1269898,1.433174,,,,,,,,,,,,,, -84800,2.2628243,1.5442313,,,,,,,,,,,,,, -84900,2.4705718,1.5937296,,,,,,,,,,,,,, -85000,2.3974416,1.4107846,,,,,,,,,,,,,, -85100,2.436235,1.5270991,,,,,,,,,,,,,, -85200,2.6568844,1.5779915,,,,,,,,,,,,,, -85300,2.0837662,1.4383975,,,,,,,,,,,,,, -85400,2.1286778,1.3583872,,,,,,,,,,,,,, -85500,2.1890483,1.5102229,,,,,,,,,,,,,, -85600,2.1747484,1.3596613,,,,,,,,,,,,,, -85628,,,0.7692721486091614,0.8613221049308777,0.6786800026893616,1.3224866390228271,50000.0,0.5486000180244446,2.0616767406463623,10000.0,29127.042145967484,30291.73993468285,29127.042145967484,1159.42480802536,2.3172881603240967,0.0 -85700,2.238961,1.4875181,,,,,,,,,,,,,, -85800,2.0925293,1.3770076,,,,,,,,,,,,,, -85900,2.3820252,1.4600964,,,,,,,,,,,,,, -86000,2.2356405,1.4665866,,,,,,,,,,,,,, -86100,2.2733254,1.355617,,,,,,,,,,,,,, -86200,2.318832,1.3870331,,,,,,,,,,,,,, -86300,2.0671904,1.4322383,,,,,,,,,,,,,, -86400,2.2048156,1.442028,,,,,,,,,,,,,, -86500,2.3336577,1.5489488,,,,,,,,,,,,,, -86600,2.3002362,1.3411386,,,,,,,,,,,,,, -86700,2.2046008,1.4227309,,,,,,,,,,,,,, -86800,2.1644623,1.27614,,,,,,,,,,,,,, -86900,2.2620165,1.4063973,,,,,,,,,,,,,, -87000,2.1765401,1.4236481,,,,,,,,,,,,,, -87100,2.077401,1.365811,,,,,,,,,,,,,, -87131,,,0.7654455900192261,0.869311511516571,0.6791200041770935,1.3054094314575195,50000.0,0.5521000027656555,2.015770673751831,10000.0,29636.98632788658,30818.359090328217,29636.98632788658,1176.0026772022247,2.3611364364624023,0.0 -87200,2.2404706,1.4256225,,,,,,,,,,,,,, -87300,2.2625132,1.3958759,,,,,,,,,,,,,, -87400,2.3896585,1.4635361,,,,,,,,,,,,,, -87500,2.2923777,1.4023668,,,,,,,,,,,,,, -87600,2.2479656,1.4051785,,,,,,,,,,,,,, -87700,2.4763465,1.3459617,,,,,,,,,,,,,, -87800,2.2209213,1.5162739,,,,,,,,,,,,,, -87900,2.1324708,1.3582997,,,,,,,,,,,,,, -88000,2.1255846,1.3522102,,,,,,,,,,,,,, -88100,2.2041929,1.4573954,,,,,,,,,,,,,, -88200,2.1509118,1.4407152,,,,,,,,,,,,,, -88300,2.1595247,1.378598,,,,,,,,,,,,,, -88400,2.2579758,1.3100868,,,,,,,,,,,,,, -88500,2.4077907,1.4206516,,,,,,,,,,,,,, -88600,2.2682116,1.3626302,,,,,,,,,,,,,, -88634,,,0.7592872977256775,0.9058119654655457,0.6819199919700623,1.291403889656067,50000.0,0.5489000082015991,2.0037930011749268,10000.0,30147.10192131996,31345.05516934395,30147.10192131996,1192.485106468201,2.4045522212982178,0.0 -88700,2.5144079,1.3227801,,,,,,,,,,,,,, -88800,2.3786957,1.4160049,,,,,,,,,,,,,, -88900,2.2965002,1.3845694,,,,,,,,,,,,,, -89000,2.4940345,1.4412043,,,,,,,,,,,,,, -89100,2.3197997,1.4189703,,,,,,,,,,,,,, -89200,2.176917,1.4027925,,,,,,,,,,,,,, -89300,2.2905958,1.4451395,,,,,,,,,,,,,, -89400,2.697452,1.4330823,,,,,,,,,,,,,, -89500,2.2728822,1.5168967,,,,,,,,,,,,,, -89600,2.595626,1.3683118,,,,,,,,,,,,,, -89700,2.0458272,1.2980852,,,,,,,,,,,,,, -89800,2.3409524,1.394734,,,,,,,,,,,,,, -89900,2.3101602,1.450877,,,,,,,,,,,,,, -90000,2.5011265,1.498205,,,,,,,,,,,,,, -90100,2.6286337,1.432536,,,,,,,,,,,,,, -90137,,,0.7499800324440002,0.9388789534568788,0.6706199645996094,1.3405901193618774,50000.0,0.5453000068664551,2.0540900230407715,10000.0,30657.10430932045,31871.90614414215,30657.10430932045,1209.234266757965,2.4512953758239746,0.0 -90200,2.3558114,1.4397936,,,,,,,,,,,,,, -90300,2.3035188,1.390522,,,,,,,,,,,,,, -90400,2.3006094,1.4079592,,,,,,,,,,,,,, -90500,2.3752296,1.4323943,,,,,,,,,,,,,, -90600,2.6674707,1.4734746,,,,,,,,,,,,,, -90700,2.2447145,1.3719431,,,,,,,,,,,,,, -90800,2.4565344,1.4504639,,,,,,,,,,,,,, -90900,2.2664354,1.4516866,,,,,,,,,,,,,, -91000,2.4045181,1.3724294,,,,,,,,,,,,,, -91100,2.3403463,1.3690355,,,,,,,,,,,,,, -91200,2.4699984,1.3757062,,,,,,,,,,,,,, -91300,2.2354686,1.3353374,,,,,,,,,,,,,, -91400,2.7163243,1.4359444,,,,,,,,,,,,,, -91500,2.2013621,1.3705157,,,,,,,,,,,,,, -91600,2.2741396,1.4336587,,,,,,,,,,,,,, -91640,,,0.7515146732330322,0.9365392327308656,0.6793999671936035,1.3209984302520752,50000.0,0.5464000105857849,2.101987361907959,10000.0,31167.13133573532,32398.71913027764,31167.13133573532,1225.9227805137634,2.495468854904175,0.0 -91700,2.39065,1.3969114,,,,,,,,,,,,,, -91800,2.3799436,1.4010216,,,,,,,,,,,,,, -91900,2.305775,1.3236448,,,,,,,,,,,,,, -92000,2.249775,1.3586568,,,,,,,,,,,,,, -92100,2.4163723,1.4632603,,,,,,,,,,,,,, -92200,2.7366402,1.5085635,,,,,,,,,,,,,, -92300,2.0758083,1.3150091,,,,,,,,,,,,,, -92400,2.3498406,1.4009199,,,,,,,,,,,,,, -92500,2.4822423,1.3737658,,,,,,,,,,,,,, -92600,2.396711,1.415969,,,,,,,,,,,,,, -92700,2.2736845,1.4724922,,,,,,,,,,,,,, -92800,2.4010503,1.3648498,,,,,,,,,,,,,, -92900,2.2264981,1.3169364,,,,,,,,,,,,,, -93000,2.3787832,1.3111843,,,,,,,,,,,,,, -93100,2.2813091,1.3755479,,,,,,,,,,,,,, -93144,,,0.7772839665412903,0.8345317840576172,0.6767199635505676,1.3265360593795776,50000.0,0.5480000376701355,2.06464958190918,10000.0,31677.139016866684,32925.45507359505,31677.139016866684,1242.5521783828735,2.5414443016052246,0.0 -93200,2.2824595,1.5481713,,,,,,,,,,,,,, -93300,2.5004656,1.4544673,,,,,,,,,,,,,, -93400,2.7942617,1.3604362,,,,,,,,,,,,,, -93500,2.3945348,1.368025,,,,,,,,,,,,,, -93600,2.3037243,1.3990546,,,,,,,,,,,,,, -93700,2.4177816,1.444639,,,,,,,,,,,,,, -93800,2.3575637,1.3481871,,,,,,,,,,,,,, -93900,2.4841318,1.4049424,,,,,,,,,,,,,, -94000,2.4524274,1.2625432,,,,,,,,,,,,,, -94100,2.258864,1.3321248,,,,,,,,,,,,,, -94200,2.5602002,1.4270272,,,,,,,,,,,,,, -94300,2.338908,1.4180816,,,,,,,,,,,,,, -94400,2.360023,1.4315349,,,,,,,,,,,,,, -94500,2.5292785,1.4327757,,,,,,,,,,,,,, -94600,2.5305412,1.3775698,,,,,,,,,,,,,, -94647,,,0.7825653553009033,0.8143873810768127,0.6804599761962891,1.2941557168960571,50000.0,0.5538000464439392,2.0050389766693115,10000.0,32187.304002285004,33452.35047388077,32187.304002285004,1259.1848402023315,2.5860958099365234,0.0 -94700,2.2847345,1.3438439,,,,,,,,,,,,,, -94800,2.4891493,1.3724813,,,,,,,,,,,,,, -94900,2.3703985,1.329946,,,,,,,,,,,,,, -95000,2.539996,1.3517537,,,,,,,,,,,,,, -95100,2.3418136,1.4604459,,,,,,,,,,,,,, -95200,2.461667,1.3104094,,,,,,,,,,,,,, -95300,2.5320413,1.4387705,,,,,,,,,,,,,, -95400,2.3801904,1.3392508,,,,,,,,,,,,,, -95500,2.2190027,1.2903478,,,,,,,,,,,,,, -95600,2.369477,1.329319,,,,,,,,,,,,,, -95700,2.3908815,1.398097,,,,,,,,,,,,,, -95800,2.36984,1.4958853,,,,,,,,,,,,,, -95900,2.4060018,1.4100567,,,,,,,,,,,,,, -96000,2.3830616,1.390744,,,,,,,,,,,,,, -96100,2.3264923,1.4215021,,,,,,,,,,,,,, -96150,,,0.7772839665412903,0.8267792463302612,0.6888599991798401,1.2830692529678345,50000.0,0.5613000392913818,2.0027637481689453,10000.0,32697.30992746353,33979.11499476433,32697.30992746353,1275.8200645446775,2.6573374271392822,0.0 -96200,2.4924073,1.3241248,,,,,,,,,,,,,, -96300,2.277615,1.2811819,,,,,,,,,,,,,, -96400,2.3978574,1.2812128,,,,,,,,,,,,,, -96500,2.505331,1.3168414,,,,,,,,,,,,,, -96600,2.659832,1.3687402,,,,,,,,,,,,,, -96700,2.4745317,1.3113984,,,,,,,,,,,,,, -96800,2.4668162,1.4810879,,,,,,,,,,,,,, -96900,2.7848804,1.373188,,,,,,,,,,,,,, -97000,2.7930675,1.3597265,,,,,,,,,,,,,, -97100,2.1853745,1.1984267,,,,,,,,,,,,,, -97200,2.3013418,1.4180905,,,,,,,,,,,,,, -97300,2.234485,1.316926,,,,,,,,,,,,,, -97400,2.3626308,1.3361591,,,,,,,,,,,,,, -97500,2.7773082,1.3831962,,,,,,,,,,,,,, -97600,2.5054996,1.3782831,,,,,,,,,,,,,, -97653,,,0.7453164458274841,0.9852451086044312,0.6612799763679504,1.3965210914611816,50000.0,0.5325000286102295,2.141735076904297,10000.0,33207.35705137253,34505.99077963829,33207.35705137253,1292.5487146377563,2.7052574157714844,0.0 -97700,2.4564478,1.3478479,,,,,,,,,,,,,, -97800,2.4247656,1.39831,,,,,,,,,,,,,, -97900,2.368001,1.4197085,,,,,,,,,,,,,, -98000,2.4577239,1.3342026,,,,,,,,,,,,,, -98100,2.4268522,1.378777,,,,,,,,,,,,,, -98200,2.6041336,1.3202121,,,,,,,,,,,,,, -98300,2.3640056,1.3442829,,,,,,,,,,,,,, -98400,2.4214694,1.3115952,,,,,,,,,,,,,, -98500,2.4798493,1.3030455,,,,,,,,,,,,,, -98600,2.2879791,1.377863,,,,,,,,,,,,,, -98700,2.3081534,1.2923523,,,,,,,,,,,,,, -98800,2.4689302,1.3984103,,,,,,,,,,,,,, -98900,2.4174557,1.3367457,,,,,,,,,,,,,, -99000,2.3755696,1.401481,,,,,,,,,,,,,, -99100,2.741873,1.4031646,,,,,,,,,,,,,, -99156,,,0.7657644748687744,0.8797435760498047,0.6818999648094177,1.3018945455551147,50000.0,0.5527999997138977,2.025407791137696,10000.0,33717.318086624146,35032.746633291245,33717.318086624146,1309.2410578727722,2.7550249099731445,0.0 -99200,2.6858845,1.3404119,,,,,,,,,,,,,, -99300,2.4672048,1.2446556,,,,,,,,,,,,,, -99400,2.5443528,1.4082009,,,,,,,,,,,,,, -99500,2.3072162,1.3765746,,,,,,,,,,,,,, -99600,2.5792587,1.400195,,,,,,,,,,,,,, -99700,2.3803642,1.4378911,,,,,,,,,,,,,, -99800,2.6123002,1.3988749,,,,,,,,,,,,,, -99900,2.4843194,1.3432385,,,,,,,,,,,,,, -100000,2.5276206,1.3442688,,,,,,,,,,,,,, -100100,2.5854974,1.3402833,,,,,,,,,,,,,, -100200,2.318648,1.2373657,,,,,,,,,,,,,, -100300,2.5444665,1.3830638,,,,,,,,,,,,,, -100400,2.430508,1.391176,,,,,,,,,,,,,, -100500,2.4201488,1.3473643,,,,,,,,,,,,,, -100600,2.4441495,1.2016842,,,,,,,,,,,,,, -100659,,,0.7707070708274841,0.8599854111671448,0.6876400113105774,1.271263599395752,50000.0,0.5570999979972839,2.000819206237793,10000.0,34227.27606058121,35559.58508205414,34227.27606058121,1326.0208716392517,2.801799774169922,0.0 -100700,2.5471275,1.4138077,,,,,,,,,,,,,, -100800,2.58179,1.2959216,,,,,,,,,,,,,, -100900,2.3424401,1.2668389,,,,,,,,,,,,,, -101000,2.5246503,1.3261873,,,,,,,,,,,,,, -101100,2.7043965,1.4122134,,,,,,,,,,,,,, -101200,2.5323048,1.3449003,,,,,,,,,,,,,, -101300,2.5638838,1.3417116,,,,,,,,,,,,,, -101400,2.3714259,1.3263382,,,,,,,,,,,,,, -101500,2.6539578,1.3994718,,,,,,,,,,,,,, -101600,2.6851573,1.3972521,,,,,,,,,,,,,, -101700,2.479631,1.3309306,,,,,,,,,,,,,, -101800,2.748782,1.3505248,,,,,,,,,,,,,, -101900,2.3226123,1.2691363,,,,,,,,,,,,,, -102000,2.4240847,1.3298112,,,,,,,,,,,,,, -102100,2.5637636,1.2400699,,,,,,,,,,,,,, -102162,,,0.7767059803009033,0.8202096223831177,0.6881600022315979,1.2708592414855957,50000.0,0.5587000250816345,2.031827449798584,10000.0,34737.225987672806,36086.32460570336,34737.225987672806,1342.7114791870115,2.847575664520264,0.0 -102200,2.536721,1.3839265,,,,,,,,,,,,,, -102300,2.754876,1.4188408,,,,,,,,,,,,,, -102400,2.593395,1.3208877,,,,,,,,,,,,,, -102500,2.3956013,1.2142304,,,,,,,,,,,,,, -102600,2.3074896,1.2587392,,,,,,,,,,,,,, -102700,2.41583,1.3060167,,,,,,,,,,,,,, -102800,2.654197,1.3233935,,,,,,,,,,,,,, -102900,2.5858123,1.32429,,,,,,,,,,,,,, -103000,2.6436985,1.352209,,,,,,,,,,,,,, -103100,2.65022,1.3386248,,,,,,,,,,,,,, -103200,2.6711996,1.3876139,,,,,,,,,,,,,, -103300,2.8890474,1.3582356,,,,,,,,,,,,,, -103400,2.4131982,1.3540823,,,,,,,,,,,,,, -103500,2.4785624,1.347066,,,,,,,,,,,,,, -103600,2.6008573,1.2506223,,,,,,,,,,,,,, -103666,,,0.7948023080825806,0.742831289768219,0.6871799826622009,1.2739155292510986,50000.0,0.5575000047683716,2.007575273513794,10000.0,35247.42342591286,36613.36411499977,35247.42342591286,1359.4346933364868,2.912957191467285,0.0 -103700,2.4763253,1.3265154,,,,,,,,,,,,,, -103800,3.074837,1.3548154,,,,,,,,,,,,,, -103900,2.7861285,1.3647271,,,,,,,,,,,,,, -104000,2.5411758,1.3704251,,,,,,,,,,,,,, -104100,2.3361955,1.2899998,,,,,,,,,,,,,, -104200,2.5815485,1.2652377,,,,,,,,,,,,,, -104300,2.6475382,1.3014729,,,,,,,,,,,,,, -104400,2.7818615,1.2913226,,,,,,,,,,,,,, -104500,2.49926,1.289768,,,,,,,,,,,,,, -104600,2.5451705,1.3452144,,,,,,,,,,,,,, -104700,2.6786034,1.3016815,,,,,,,,,,,,,, -104800,2.5959044,1.2718184,,,,,,,,,,,,,, -104900,2.6886294,1.316098,,,,,,,,,,,,,, -105000,2.7064161,1.3677553,,,,,,,,,,,,,, -105100,2.821836,1.2775837,,,,,,,,,,,,,, -105169,,,0.7901785373687744,0.779336154460907,0.6885799765586853,1.2575336694717407,50000.0,0.5600000023841858,1.9892528057098389,10000.0,35757.409264564514,37140.13447451592,35757.409264564514,1376.1155910491943,2.963505744934082,0.0 -105200,2.6139972,1.3768353,,,,,,,,,,,,,, -105300,2.392103,1.2001767,,,,,,,,,,,,,, -105400,2.580659,1.3114545,,,,,,,,,,,,,, -105500,2.6451719,1.2713001,,,,,,,,,,,,,, -105600,2.6426232,1.3113179,,,,,,,,,,,,,, -105700,2.8589492,1.3499237,,,,,,,,,,,,,, -105800,3.075364,1.4156206,,,,,,,,,,,,,, -105900,2.675682,1.3329847,,,,,,,,,,,,,, -106000,2.4846923,1.3260801,,,,,,,,,,,,,, -106100,2.9559388,1.3902545,,,,,,,,,,,,,, -106200,2.7878654,1.3355219,,,,,,,,,,,,,, -106300,2.7299874,1.3018792,,,,,,,,,,,,,, -106400,2.5940971,1.3044142,,,,,,,,,,,,,, -106500,2.5724633,1.3712342,,,,,,,,,,,,,, -106600,2.8874662,1.3521075,,,,,,,,,,,,,, -106672,,,0.7915935516357422,0.7726999521255493,0.6963399648666382,1.2351707220077517,50000.0,0.5663000345230103,1.970118284225464,10000.0,36267.38300538063,37666.79642248154,36267.38300538063,1392.704957962036,3.009145498275757,0.0 -106700,2.7513053,1.3677936,,,,,,,,,,,,,, -106800,2.679261,1.4041169,,,,,,,,,,,,,, -106900,2.8636193,1.3304548,,,,,,,,,,,,,, -107000,2.6854677,1.4056743,,,,,,,,,,,,,, -107100,2.8841782,1.3189942,,,,,,,,,,,,,, -107200,2.388,1.3416158,,,,,,,,,,,,,, -107300,2.4768815,1.211077,,,,,,,,,,,,,, -107400,2.9381201,1.3813161,,,,,,,,,,,,,, -107500,2.6128652,1.3716912,,,,,,,,,,,,,, -107600,2.713237,1.2897208,,,,,,,,,,,,,, -107700,2.8850071,1.2429612,,,,,,,,,,,,,, -107800,3.030114,1.3178476,,,,,,,,,,,,,, -107900,2.595405,1.2258909,,,,,,,,,,,,,, -108000,2.953993,1.3260658,,,,,,,,,,,,,, -108100,2.5704901,1.1847765,,,,,,,,,,,,,, -108176,,,0.7795758843421936,0.8351767063140869,0.6855999827384949,1.2827911376953125,50000.0,0.5533000230789185,2.026097297668457,10000.0,36777.43883371353,38193.63676953316,36777.43883371353,1409.3859219551086,3.06013560295105,0.0 -108200,2.6017413,1.2872759,,,,,,,,,,,,,, -108300,2.5039072,1.2742107,,,,,,,,,,,,,, -108400,2.8060756,1.2600484,,,,,,,,,,,,,, -108500,2.8823845,1.2910297,,,,,,,,,,,,,, -108600,3.3184655,1.3567843,,,,,,,,,,,,,, -108700,2.6549225,1.2648329,,,,,,,,,,,,,, -108800,2.5800836,1.3089211,,,,,,,,,,,,,, -108900,2.5970302,1.3041546,,,,,,,,,,,,,, -109000,2.915644,1.2467241,,,,,,,,,,,,,, -109100,2.8759716,1.2889957,,,,,,,,,,,,,, -109200,2.5772762,1.328736,,,,,,,,,,,,,, -109300,2.817018,1.2258368,,,,,,,,,,,,,, -109400,2.5714738,1.2636862,,,,,,,,,,,,,, -109500,2.9203396,1.3032718,,,,,,,,,,,,,, -109600,2.6968029,1.3659879,,,,,,,,,,,,,, -109679,,,0.7897002100944519,0.7726017832756042,0.6967399716377258,1.228294849395752,50000.0,0.567300021648407,1.9748430252075195,10000.0,37287.4537332058,38720.35473489761,37287.4537332058,1425.993721485138,3.102454423904419,0.0 -109700,2.8237429,1.3800322,,,,,,,,,,,,,, -109800,2.7560666,1.175175,,,,,,,,,,,,,, -109900,2.6039467,1.207385,,,,,,,,,,,,,, -110000,2.8478627,1.2668378,,,,,,,,,,,,,, -110100,2.455491,1.1424905,,,,,,,,,,,,,, -110200,2.5427554,1.2591504,,,,,,,,,,,,,, -110300,2.5000906,1.2813586,,,,,,,,,,,,,, -110400,2.8100986,1.2565705,,,,,,,,,,,,,, -110500,2.7153864,1.2682166,,,,,,,,,,,,,, -110600,2.7434528,1.2778958,,,,,,,,,,,,,, -110700,2.8725655,1.3713673,,,,,,,,,,,,,, -110800,2.837682,1.2061541,,,,,,,,,,,,,, -110900,2.7995815,1.2944443,,,,,,,,,,,,,, -111000,2.8671334,1.2994412,,,,,,,,,,,,,, -111100,2.8270407,1.281943,,,,,,,,,,,,,, -111183,,,0.7920718789100647,0.7628602385520935,0.7012999653816223,1.20783793926239,50000.0,0.5749000310897827,1.9123485088348389,10000.0,37797.4540207386,39247.22995352745,37797.4540207386,1442.764579296112,3.153353452682495,0.0 -111200,2.7509234,1.231756,,,,,,,,,,,,,, -111300,3.3264875,1.2986877,,,,,,,,,,,,,, -111400,2.722203,1.2584585,,,,,,,,,,,,,, -111500,2.790836,1.245566,,,,,,,,,,,,,, -111600,2.9493892,1.2680454,,,,,,,,,,,,,, -111700,2.845623,1.3265986,,,,,,,,,,,,,, -111800,2.5947075,1.3130983,,,,,,,,,,,,,, -111900,2.7016387,1.1811978,,,,,,,,,,,,,, -112000,2.9383616,1.2909033,,,,,,,,,,,,,, -112100,2.7069514,1.1297693,,,,,,,,,,,,,, -112200,3.0174885,1.4053319,,,,,,,,,,,,,, -112300,3.0598705,1.2224209,,,,,,,,,,,,,, -112400,2.7737656,1.3107378,,,,,,,,,,,,,, -112500,2.850177,1.2166759,,,,,,,,,,,,,, -112600,2.759157,1.2466608,,,,,,,,,,,,,, -112687,,,0.816824734210968,0.6701944470405579,0.6967799663543701,1.2372926473617554,50000.0,0.5703999996185303,1.9807078838348389,10000.0,38307.6800494194,39774.13205599785,38307.6800494194,1459.3377630710602,3.2031338214874268,0.0 -112700,2.9081979,1.3221177,,,,,,,,,,,,,, -112800,2.717768,1.3457451,,,,,,,,,,,,,, -112900,2.56154,1.1964447,,,,,,,,,,,,,, -113000,2.786245,1.2321421,,,,,,,,,,,,,, -113100,3.0168242,1.2648885,,,,,,,,,,,,,, -113200,2.6602414,1.3806894,,,,,,,,,,,,,, -113300,2.559911,1.1741664,,,,,,,,,,,,,, -113400,2.8747685,1.2318864,,,,,,,,,,,,,, -113500,2.7551258,1.2923169,,,,,,,,,,,,,, -113600,2.9885895,1.2341976,,,,,,,,,,,,,, -113700,2.7821755,1.1870844,,,,,,,,,,,,,, -113800,2.530954,1.124297,,,,,,,,,,,,,, -113900,2.7915983,1.2253668,,,,,,,,,,,,,, -114000,2.7559216,1.2908062,,,,,,,,,,,,,, -114100,2.5673864,1.2093384,,,,,,,,,,,,,, -114191,,,0.807039201259613,0.6915223002433777,0.6977799534797668,1.2234848737716677,50000.0,0.5706000328063965,1.964166522026062,10000.0,38817.644235134125,40300.83305287361,38817.644235134125,1475.9748423099518,3.2496707439422607,0.0 -114200,2.6774807,1.2284008,,,,,,,,,,,,,, -114300,2.86436,1.2548022,,,,,,,,,,,,,, -114400,2.8915982,1.3037708,,,,,,,,,,,,,, -114500,3.0142248,1.3728328,,,,,,,,,,,,,, -114600,2.726836,1.2900069,,,,,,,,,,,,,, -114700,2.650337,1.2617577,,,,,,,,,,,,,, -114800,2.757914,1.2405705,,,,,,,,,,,,,, -114900,3.1938162,1.218719,,,,,,,,,,,,,, -115000,2.9854887,1.145474,,,,,,,,,,,,,, -115100,2.7290096,1.1430277,,,,,,,,,,,,,, -115200,3.2989666,1.3065307,,,,,,,,,,,,,, -115300,2.6809928,1.1835922,,,,,,,,,,,,,, -115400,3.0004354,1.2855097,,,,,,,,,,,,,, -115500,2.8313634,1.3057181,,,,,,,,,,,,,, -115600,2.8670645,1.3136283,,,,,,,,,,,,,, -115694,,,0.8061423897743225,0.6987681984901428,0.7060799598693848,1.1993192434310913,50000.0,0.5722000002861023,1.946460485458374,10000.0,39327.79904174805,40827.80290222168,39327.79904174805,1492.686810255051,3.298171281814575,0.0 -115700,3.0282505,1.2472644,,,,,,,,,,,,,, -115800,3.0431395,1.2617071,,,,,,,,,,,,,, -115900,2.9022486,1.2094338,,,,,,,,,,,,,, -116000,2.797208,1.2263477,,,,,,,,,,,,,, -116100,2.8502152,1.2635612,,,,,,,,,,,,,, -116200,2.7259223,1.2426366,,,,,,,,,,,,,, -116300,2.8873427,1.2408864,,,,,,,,,,,,,, -116400,2.8044758,1.2734721,,,,,,,,,,,,,, -116500,2.8900454,1.2474155,,,,,,,,,,,,,, -116600,2.6596496,1.2211162,,,,,,,,,,,,,, -116700,2.7334914,1.2741741,,,,,,,,,,,,,, -116800,2.9524505,1.2600017,,,,,,,,,,,,,, -116900,3.0000343,1.2497745,,,,,,,,,,,,,, -117000,2.827614,1.2287574,,,,,,,,,,,,,, -117100,2.7394207,1.3296123,,,,,,,,,,,,,, -117198,,,0.7992864847183228,0.7257764339447021,0.701259970664978,1.2185938358306885,50000.0,0.5745000243186951,1.972318410873413,10000.0,39837.94040846825,41354.602694273,39837.94040846825,1509.241043329239,3.349168300628662,0.0 -117200,2.873066,1.3306797,,,,,,,,,,,,,, -117300,2.6840074,1.2148724,,,,,,,,,,,,,, -117400,3.112776,1.1617064,,,,,,,,,,,,,, -117500,2.859057,1.1262043,,,,,,,,,,,,,, -117600,2.8159087,1.1901308,,,,,,,,,,,,,, -117700,2.7730634,1.1976869,,,,,,,,,,,,,, -117800,3.276781,1.2632365,,,,,,,,,,,,,, -117900,3.0392666,1.1870104,,,,,,,,,,,,,, -118000,2.7594547,1.1828512,,,,,,,,,,,,,, -118100,2.783299,1.2527969,,,,,,,,,,,,,, -118200,2.851099,1.1965381,,,,,,,,,,,,,, -118300,2.802965,1.1429031,,,,,,,,,,,,,, -118400,2.8207703,1.2504368,,,,,,,,,,,,,, -118500,3.2979383,1.2011461,,,,,,,,,,,,,, -118600,3.355078,1.336976,,,,,,,,,,,,,, -118700,2.94056,1.1660458,,,,,,,,,,,,,, -118701,,,0.8131775856018066,0.685913622379303,0.7095999717712402,1.1763111352920532,50000.0,0.5827000141143799,1.896999478340149,10000.0,40348.18306827545,41881.56796193123,40348.18306827545,1525.8609673976898,3.40002703666687,0.0 -118800,2.7812927,1.1683347,,,,,,,,,,,,,, -118900,2.8305273,1.2672863,,,,,,,,,,,,,, -119000,2.7350452,1.084816,,,,,,,,,,,,,, -119100,3.3875616,1.3616631,,,,,,,,,,,,,, -119200,2.9502742,1.1769792,,,,,,,,,,,,,, -119300,3.0200486,1.1834853,,,,,,,,,,,,,, -119400,2.8102674,1.1820953,,,,,,,,,,,,,, -119500,3.239461,1.2279407,,,,,,,,,,,,,, -119600,3.0112891,1.2149221,,,,,,,,,,,,,, -119700,2.7696283,1.1946754,,,,,,,,,,,,,, -119800,3.0895538,1.1377178,,,,,,,,,,,,,, -119900,3.175634,1.1942731,,,,,,,,,,,,,, -120000,2.7359266,1.1773254,,,,,,,,,,,,,, -120100,2.9179766,1.2142369,,,,,,,,,,,,,, -120200,3.3454149,1.1727824,,,,,,,,,,,,,, -120204,,,0.8085139989852905,0.6942028999328613,0.7087599635124207,1.1854325532913208,50000.0,0.5776000022888184,1.9270565509796145,10000.0,40858.13369560242,42408.31049633026,40858.13369560242,1542.5479755401611,3.452460289001465,0.0 -120300,3.0522363,1.2492899,,,,,,,,,,,,,, -120400,3.155913,1.2361884,,,,,,,,,,,,,, -120500,2.896512,1.2282237,,,,,,,,,,,,,, -120600,2.8669238,1.2505765,,,,,,,,,,,,,, -120700,2.9975495,1.1657217,,,,,,,,,,,,,, -120800,3.169521,1.3172463,,,,,,,,,,,,,, -120900,3.4426517,1.2329185,,,,,,,,,,,,,, -121000,3.1312892,1.2464272,,,,,,,,,,,,,, -121100,2.9162996,1.1704756,,,,,,,,,,,,,, -121200,3.099614,1.1579334,,,,,,,,,,,,,, -121300,3.1381633,1.1768751,,,,,,,,,,,,,, -121400,3.0949507,1.2721078,,,,,,,,,,,,,, -121500,2.907086,1.1385318,,,,,,,,,,,,,, -121600,3.0165374,1.1641452,,,,,,,,,,,,,, -121700,3.1856244,1.1823101,,,,,,,,,,,,,, -121707,,,0.8465999364852905,0.56064772605896,0.7129799723625183,1.1660146713256836,50000.0,0.579200029373169,1.906941294670105,10000.0,41368.25840616226,42935.28839254379,41368.25840616226,1559.2989346981049,3.502571582794189,0.0 -121800,3.1351504,1.2586265,,,,,,,,,,,,,, -121900,3.0546298,1.2276878,,,,,,,,,,,,,, -122000,3.162784,1.2434788,,,,,,,,,,,,,, -122100,3.1764061,1.1490854,,,,,,,,,,,,,, -122200,3.1383107,1.3037971,,,,,,,,,,,,,, -122300,3.0102708,1.2148986,,,,,,,,,,,,,, -122400,3.087371,1.1387513,,,,,,,,,,,,,, -122500,2.932546,1.1981094,,,,,,,,,,,,,, -122600,3.1507807,1.1745452,,,,,,,,,,,,,, -122700,2.889904,1.1987463,,,,,,,,,,,,,, -122800,3.0113885,1.1442512,,,,,,,,,,,,,, -122900,3.0118644,1.1997992,,,,,,,,,,,,,, -123000,2.920369,1.171345,,,,,,,,,,,,,, -123100,3.1209867,1.1317981,,,,,,,,,,,,,, -123200,2.9905827,1.1715866,,,,,,,,,,,,,, -123211,,,0.8325493931770325,0.5962198376655579,0.7136200070381165,1.1775128841400146,50000.0,0.5875000357627869,1.8746788501739504,10000.0,41878.31901669502,43462.08464598656,41878.31901669502,1575.9256281852722,3.5578701496124268,0.0 -123300,3.2242787,1.167841,,,,,,,,,,,,,, -123400,3.1713848,1.2392957,,,,,,,,,,,,,, -123500,3.1177266,1.1934153,,,,,,,,,,,,,, -123600,3.0270698,1.1131989,,,,,,,,,,,,,, -123700,2.718517,1.0470213,,,,,,,,,,,,,, -123800,2.9870589,1.1397071,,,,,,,,,,,,,, -123900,3.1823866,1.2001748,,,,,,,,,,,,,, -124000,3.149165,1.1280544,,,,,,,,,,,,,, -124100,3.0604796,1.2013619,,,,,,,,,,,,,, -124200,2.9851956,1.0785365,,,,,,,,,,,,,, -124300,3.1042168,1.1186106,,,,,,,,,,,,,, -124400,2.877032,1.1495159,,,,,,,,,,,,,, -124500,3.1877284,1.1968973,,,,,,,,,,,,,, -124600,2.891674,1.1298498,,,,,,,,,,,,,, -124700,3.1788492,1.1400152,,,,,,,,,,,,,, -124714,,,0.8289819955825806,0.6051926016807556,0.7158199548721313,1.1632517576217651,50000.0,0.5853000283241272,1.8980510234832764,10000.0,42388.27456307411,43988.95673203468,42388.27456307411,1592.7352993488312,3.611999750137329,0.0 -124800,3.0213323,1.0762857,,,,,,,,,,,,,, -124900,3.2674894,1.1238112,,,,,,,,,,,,,, -125000,3.1802597,1.0797362,,,,,,,,,,,,,, -125100,3.032765,1.1752565,,,,,,,,,,,,,, -125200,3.208299,1.1343212,,,,,,,,,,,,,, -125300,3.053131,1.1370249,,,,,,,,,,,,,, -125400,3.1790802,1.1392739,,,,,,,,,,,,,, -125500,3.0817938,1.0523555,,,,,,,,,,,,,, -125600,3.2385337,1.1550137,,,,,,,,,,,,,, -125700,2.992542,1.1716393,,,,,,,,,,,,,, -125800,3.282744,1.1682065,,,,,,,,,,,,,, -125900,3.159017,1.1116006,,,,,,,,,,,,,, -126000,3.172036,1.0772084,,,,,,,,,,,,,, -126100,3.447621,1.1509976,,,,,,,,,,,,,, -126200,3.137516,1.1455982,,,,,,,,,,,,,, -126218,,,0.8241987824440002,0.6316837072372437,0.7112999558448792,1.1798099279403689,50000.0,0.5856000185012817,1.908242106437683,10000.0,42898.41350364685,44515.80540394783,42898.41350364685,1609.340696811676,3.663426637649536,0.0 -126300,3.1746824,1.0925822,,,,,,,,,,,,,, -126400,2.8264294,1.1386048,,,,,,,,,,,,,, -126500,3.0755801,1.1756337,,,,,,,,,,,,,, -126600,3.8493268,1.2038281,,,,,,,,,,,,,, -126700,3.2519226,1.1068654,,,,,,,,,,,,,, -126800,3.1034937,1.1046934,,,,,,,,,,,,,, -126900,3.4304924,1.1514252,,,,,,,,,,,,,, -127000,3.1616766,1.1141627,,,,,,,,,,,,,, -127100,3.1619549,1.1414454,,,,,,,,,,,,,, -127200,2.9432135,1.0104482,,,,,,,,,,,,,, -127300,3.6546755,1.1768022,,,,,,,,,,,,,, -127400,3.0914395,1.1194426,,,,,,,,,,,,,, -127500,3.168761,1.0858655,,,,,,,,,,,,,, -127600,3.2008674,1.1307247,,,,,,,,,,,,,, -127700,3.127759,1.126365,,,,,,,,,,,,,, -127721,,,0.8290417790412903,0.6082040667533875,0.7193999886512756,1.1457984447479248,50000.0,0.5924000144004822,1.8633819818496704,10000.0,43408.54190301895,45042.67664170265,43408.54190301895,1625.9791305065155,3.715073585510254,0.0 -127800,3.0254917,1.1155192,,,,,,,,,,,,,, -127900,3.1875362,1.142724,,,,,,,,,,,,,, -128000,3.2324092,0.98970807,,,,,,,,,,,,,, -128100,3.3788657,1.0406461,,,,,,,,,,,,,, -128200,3.2405381,1.0421759,,,,,,,,,,,,,, -128300,3.2876954,1.165146,,,,,,,,,,,,,, -128400,3.463474,1.0981696,,,,,,,,,,,,,, -128500,3.4321141,1.0901421,,,,,,,,,,,,,, -128600,3.0816598,1.1493523,,,,,,,,,,,,,, -128700,3.3530796,1.1525227,,,,,,,,,,,,,, -128800,3.0409386,1.1093636,,,,,,,,,,,,,, -128900,3.2839875,1.1183426,,,,,,,,,,,,,, -129000,3.0993638,1.0673814,,,,,,,,,,,,,, -129100,4.104765,1.1328738,,,,,,,,,,,,,, -129200,3.1283636,1.0726273,,,,,,,,,,,,,, -129225,,,0.817781388759613,0.6508827209472656,0.7068399786949158,1.1904618740081787,50000.0,0.5823000073432922,1.93605637550354,10000.0,43918.64049005509,45569.49729180336,43918.64049005509,1642.595087766647,3.7680022716522217,0.0 -129300,3.1231806,1.0247307,,,,,,,,,,,,,, -129400,3.2515247,1.0285734,,,,,,,,,,,,,, -129500,3.1911614,1.107709,,,,,,,,,,,,,, -129600,3.0650046,1.0404979,,,,,,,,,,,,,, -129700,3.3054276,1.1399195,,,,,,,,,,,,,, -129800,3.186343,1.0950038,,,,,,,,,,,,,, -129900,3.2137818,1.0912075,,,,,,,,,,,,,, -130000,3.3117864,1.1203483,,,,,,,,,,,,,, -130100,3.0126078,1.034441,,,,,,,,,,,,,, -130200,3.5715847,1.1390258,,,,,,,,,,,,,, -130300,3.1264887,1.0789015,,,,,,,,,,,,,, -130400,3.8463821,1.2184229,,,,,,,,,,,,,, -130500,3.3751166,1.1387541,,,,,,,,,,,,,, -130600,3.50603,1.0978634,,,,,,,,,,,,,, -130700,3.1216679,1.0368156,,,,,,,,,,,,,, -130728,,,0.8703563213348389,0.4560602903366089,0.7164199948310852,1.169049859046936,50000.0,0.5943000316619873,1.8778047561645508,10000.0,44428.71792554855,46096.24298453331,44428.71792554855,1659.1601164340973,3.817939519882202,0.0 -130800,3.4483275,1.0985944,,,,,,,,,,,,,, -130900,3.5643647,1.1083214,,,,,,,,,,,,,, -131000,3.4295423,1.0969288,,,,,,,,,,,,,, -131100,3.3545537,0.99894583,,,,,,,,,,,,,, -131200,3.2681231,1.0793613,,,,,,,,,,,,,, -131300,3.3791227,1.128434,,,,,,,,,,,,,, -131400,3.3076694,1.1264881,,,,,,,,,,,,,, -131500,3.4228058,1.0822229,,,,,,,,,,,,,, -131600,3.4615736,1.0792692,,,,,,,,,,,,,, -131700,3.34331,1.1159012,,,,,,,,,,,,,, -131800,3.3372486,1.0515453,,,,,,,,,,,,,, -131900,3.2872357,1.0643401,,,,,,,,,,,,,, -132000,3.496054,1.1206074,,,,,,,,,,,,,, -132100,3.3705575,1.1451074,,,,,,,,,,,,,, -132200,3.2956173,0.99965423,,,,,,,,,,,,,, -132232,,,0.8533163070678711,0.5247998833656311,0.7186799645423889,1.150518774986267,50000.0,0.5918000340461731,1.8805676698684688,10000.0,44938.95955300331,46623.186673641205,44938.95955300331,1675.7581803798676,3.868992805480957,0.0 -132300,3.3423285,1.0598557,,,,,,,,,,,,,, -132400,3.1121628,0.98470074,,,,,,,,,,,,,, -132500,3.385045,1.0844039,,,,,,,,,,,,,, -132600,3.744755,1.0216022,,,,,,,,,,,,,, -132700,3.503061,1.1509802,,,,,,,,,,,,,, -132800,3.3308468,1.0962926,,,,,,,,,,,,,, -132900,3.4749146,1.0946635,,,,,,,,,,,,,, -133000,3.3616478,1.1017617,,,,,,,,,,,,,, -133100,3.0443938,1.0156912,,,,,,,,,,,,,, -133200,3.584263,1.1416821,,,,,,,,,,,,,, -133300,3.3873377,1.0868163,,,,,,,,,,,,,, -133400,3.3500783,1.0708524,,,,,,,,,,,,,, -133500,3.4089396,1.0282551,,,,,,,,,,,,,, -133600,3.4230165,1.0876322,,,,,,,,,,,,,, -133700,3.396916,1.0258721,,,,,,,,,,,,,, -133735,,,0.8556680083274841,0.5154725313186646,0.7227599620819092,1.1428040266036987,50000.0,0.5940999984741211,1.8898875713348389,10000.0,45448.97965598106,47149.84192085266,45448.97965598106,1692.2820625305176,3.927493095397949,0.0 -133800,3.3854535,0.982554,,,,,,,,,,,,,, -133900,3.4597476,1.075077,,,,,,,,,,,,,, -134000,3.2951314,1.043669,,,,,,,,,,,,,, -134100,3.6504114,0.99120027,,,,,,,,,,,,,, -134200,3.6120615,1.0545468,,,,,,,,,,,,,, -134300,3.3184135,0.98190975,,,,,,,,,,,,,, -134400,3.3348663,1.083926,,,,,,,,,,,,,, -134500,3.461633,1.0669297,,,,,,,,,,,,,, -134600,3.373992,1.1268846,,,,,,,,,,,,,, -134700,3.3989413,1.0620489,,,,,,,,,,,,,, -134800,3.4922454,1.0174286,,,,,,,,,,,,,, -134900,3.4639204,1.0445857,,,,,,,,,,,,,, -135000,3.2908483,1.0026851,,,,,,,,,,,,,, -135100,3.5910478,1.0204985,,,,,,,,,,,,,, -135200,3.6049006,0.9789883,,,,,,,,,,,,,, -135239,,,0.8517617583274841,0.5228502750396729,0.7201799750328064,1.1425071954727173,50000.0,0.5942000150680542,1.886696934700012,10000.0,45959.17137622833,47676.69100522995,45959.17137622833,1708.829159975052,3.9842641353607178,0.0 -135300,3.5615137,0.95531857,,,,,,,,,,,,,, -135400,3.4154644,0.9890677,,,,,,,,,,,,,, -135500,3.5833635,1.1017456,,,,,,,,,,,,,, -135600,3.399013,0.99176455,,,,,,,,,,,,,, -135700,3.4643433,1.0089962,,,,,,,,,,,,,, -135800,3.180283,1.0149004,,,,,,,,,,,,,, -135900,3.5908241,0.9797832,,,,,,,,,,,,,, -136000,3.4864607,1.026721,,,,,,,,,,,,,, -136100,3.503311,1.0741166,,,,,,,,,,,,,, -136200,3.4167132,0.9737241,,,,,,,,,,,,,, -136300,3.7570677,1.1446347,,,,,,,,,,,,,, -136400,3.746548,1.0632257,,,,,,,,,,,,,, -136500,3.4405053,1.0142447,,,,,,,,,,,,,, -136600,3.7089188,1.0451406,,,,,,,,,,,,,, -136700,3.3778493,0.97830117,,,,,,,,,,,,,, -136742,,,0.8520607352256775,0.5236783027648926,0.7238399982452393,1.1254109144210815,50000.0,0.5951000452041626,1.8549350500106807,10000.0,46469.35333895683,48203.49363040924,46469.35333895683,1725.3450455665588,4.035556793212891,0.0 -136800,3.2935376,0.98958576,,,,,,,,,,,,,, -136900,3.3301256,0.98576415,,,,,,,,,,,,,, -137000,3.3274307,0.95428663,,,,,,,,,,,,,, -137100,3.5051103,0.9734166,,,,,,,,,,,,,, -137200,3.4338439,0.9733459,,,,,,,,,,,,,, -137300,3.654917,1.0173514,,,,,,,,,,,,,, -137400,3.3053875,0.94463795,,,,,,,,,,,,,, -137500,3.8290055,1.0313728,,,,,,,,,,,,,, -137600,3.6655848,1.0671685,,,,,,,,,,,,,, -137700,3.5690598,0.97675484,,,,,,,,,,,,,, -137800,3.4856424,1.0958567,,,,,,,,,,,,,, -137900,3.5335102,1.0224538,,,,,,,,,,,,,, -138000,3.6960585,0.9237281,,,,,,,,,,,,,, -138100,3.6413646,1.0338618,,,,,,,,,,,,,, -138200,3.7095742,1.0898597,,,,,,,,,,,,,, -138246,,,0.8571826815605164,0.4904449582099914,0.7291399836540222,1.1260720491409302,50000.0,0.6016000509262085,1.8454172611236568,10000.0,46979.564589738846,48730.506182432175,46979.564589738846,1742.0426452159882,4.085598707199097,0.0 -138300,3.6742992,1.0109876,,,,,,,,,,,,,, -138400,3.5342817,0.9529579,,,,,,,,,,,,,, -138500,3.4015446,0.9853613,,,,,,,,,,,,,, -138600,4.0375514,1.1021268,,,,,,,,,,,,,, -138700,3.4406836,1.045138,,,,,,,,,,,,,, -138800,3.5674334,1.0455337,,,,,,,,,,,,,, -138900,3.4805245,0.97209054,,,,,,,,,,,,,, -139000,3.380865,0.93460816,,,,,,,,,,,,,, -139100,3.5897844,1.0466063,,,,,,,,,,,,,, -139200,3.6172073,1.0265181,,,,,,,,,,,,,, -139300,3.4373147,0.9100929,,,,,,,,,,,,,, -139400,3.5076528,0.95109713,,,,,,,,,,,,,, -139500,3.911839,1.0122201,,,,,,,,,,,,,, -139600,3.4325612,0.9396031,,,,,,,,,,,,,, -139700,3.5951216,0.8411063,,,,,,,,,,,,,, -139749,,,0.8938735723495483,0.3696881830692291,0.7278199791908264,1.112565040588379,50000.0,0.6012000441551208,1.864063143730164,10000.0,47489.7115893364,49257.37302136421,47489.7115893364,1758.654646873474,4.139785051345825,0.0 -139800,3.8357255,1.0748047,,,,,,,,,,,,,, -139900,3.6597285,0.9070725,,,,,,,,,,,,,, -140000,3.508704,0.93696,,,,,,,,,,,,,, -140100,3.9300785,0.9443106,,,,,,,,,,,,,, -140200,3.7632365,1.0564425,,,,,,,,,,,,,, -140300,3.7808976,1.0104357,,,,,,,,,,,,,, -140400,3.9006233,1.0272747,,,,,,,,,,,,,, -140500,3.7540781,1.013069,,,,,,,,,,,,,, -140600,3.874579,0.9137955,,,,,,,,,,,,,, -140700,3.7897172,1.0346075,,,,,,,,,,,,,, -140800,3.9611967,1.092711,,,,,,,,,,,,,, -140900,3.5431893,0.8788127,,,,,,,,,,,,,, -141000,3.7429078,1.0545771,,,,,,,,,,,,,, -141100,3.4457467,0.98481566,,,,,,,,,,,,,, -141200,4.086365,1.091547,,,,,,,,,,,,,, -141253,,,0.88578200340271,0.3992698490619659,0.7300199866294861,1.1068578958511353,50000.0,0.6060000061988831,1.822378635406494,10000.0,47999.79533600807,49784.1164534092,47999.79533600807,1775.2069325447085,4.194054841995239,0.0 -141300,3.949481,0.993441,,,,,,,,,,,,,, -141400,3.4691136,0.9337023,,,,,,,,,,,,,, -141500,3.7232535,1.0182443,,,,,,,,,,,,,, -141600,4.0618505,1.0178336,,,,,,,,,,,,,, -141700,3.6698642,0.98316985,,,,,,,,,,,,,, -141800,3.8541684,1.0442545,,,,,,,,,,,,,, -141900,3.638292,0.93608,,,,,,,,,,,,,, -142000,3.6985786,0.99394023,,,,,,,,,,,,,, -142100,4.139991,1.0496272,,,,,,,,,,,,,, -142200,3.4790306,0.905156,,,,,,,,,,,,,, -142300,4.039945,0.9826826,,,,,,,,,,,,,, -142400,3.5112605,0.8990346,,,,,,,,,,,,,, -142500,3.4121666,0.85361683,,,,,,,,,,,,,, -142600,3.676904,0.904608,,,,,,,,,,,,,, -142700,3.6081734,0.9964444,,,,,,,,,,,,,, -142756,,,0.8810786008834839,0.4117041528224945,0.7327799797058105,1.0963129997253418,50000.0,0.6037000417709351,1.838681936264038,10000.0,48509.867886304855,50310.97115421295,48509.867886304855,1791.8828384876251,4.247902631759644,0.0 -142800,3.9894738,1.033361,,,,,,,,,,,,,, -142900,3.6508176,0.93322885,,,,,,,,,,,,,, -143000,3.6568656,0.93062747,,,,,,,,,,,,,, -143100,4.4996867,1.0427506,,,,,,,,,,,,,, -143200,3.8394632,0.97101235,,,,,,,,,,,,,, -143300,4.099977,0.9823928,,,,,,,,,,,,,, -143400,3.8100126,0.99637103,,,,,,,,,,,,,, -143500,3.4653008,0.9139747,,,,,,,,,,,,,, -143600,3.7794201,1.0492245,,,,,,,,,,,,,, -143700,3.9452217,0.92452663,,,,,,,,,,,,,, -143800,3.7466156,0.9612571,,,,,,,,,,,,,, -143900,3.9470832,0.91529363,,,,,,,,,,,,,, -144000,3.6982505,1.0062875,,,,,,,,,,,,,, -144100,3.907727,0.9740248,,,,,,,,,,,,,, -144200,3.644554,0.9745261,,,,,,,,,,,,,, -144260,,,0.8727877736091614,0.44132199883461,0.7282800078392029,1.1258344650268557,50000.0,0.5979000329971313,1.8827852010726929,10000.0,49019.87905693054,50837.62298154831,49019.87905693054,1808.416305065155,4.30181097984314,0.0 -144300,3.788421,0.8638268,,,,,,,,,,,,,, -144400,3.8033662,0.9386945,,,,,,,,,,,,,, -144500,3.7412102,0.83391464,,,,,,,,,,,,,, -144600,4.2223206,0.9375226,,,,,,,,,,,,,, -144700,3.791878,1.0001016,,,,,,,,,,,,,, -144800,4.2818484,0.9470916,,,,,,,,,,,,,, -144900,3.7877626,0.89722836,,,,,,,,,,,,,, -145000,3.6415281,0.86763865,,,,,,,,,,,,,, -145100,4.0465817,0.9056485,,,,,,,,,,,,,, -145200,3.5380967,0.921818,,,,,,,,,,,,,, -145300,3.7721741,0.8594279,,,,,,,,,,,,,, -145400,4.173232,0.9324262,,,,,,,,,,,,,, -145500,3.7788842,0.91196567,,,,,,,,,,,,,, -145600,3.8382404,0.9379471,,,,,,,,,,,,,, -145700,3.709012,0.955539,,,,,,,,,,,,,, -145763,,,0.8825733065605164,0.4018238186836242,0.7343399524688721,1.1096502542495728,50000.0,0.6073000431060791,1.843134880065918,10000.0,49530.08825039864,51364.57683515549,49530.08825039864,1825.05018401146,4.359186172485352,0.0 -145800,4.351055,0.9574891,,,,,,,,,,,,,, -145900,4.001443,0.9467641,,,,,,,,,,,,,, -146000,3.6376858,0.7709277,,,,,,,,,,,,,, -146100,4.045342,0.92439675,,,,,,,,,,,,,, -146200,3.5989428,0.96370894,,,,,,,,,,,,,, -146300,3.6716013,0.8450326,,,,,,,,,,,,,, -146400,4.0081277,0.890758,,,,,,,,,,,,,, -146500,3.560692,0.870525,,,,,,,,,,,,,, -146600,3.7608411,0.8521062,,,,,,,,,,,,,, -146700,3.8301804,0.9060859,,,,,,,,,,,,,, -146800,3.989981,0.9568711,,,,,,,,,,,,,, -146900,3.7551641,0.9324926,,,,,,,,,,,,,, -147000,3.7634819,0.887727,,,,,,,,,,,,,, -147100,3.9781537,0.87589914,,,,,,,,,,,,,, -147200,4.127257,0.917794,,,,,,,,,,,,,, -147266,,,0.88671875,0.3884380757808685,0.7356199622154236,1.0984673500061035,50000.0,0.6115000247955322,1.845555543899536,10000.0,50040.01444983482,51891.32067155838,50040.01444983482,1841.761024713516,4.412938833236694,0.0 -147300,4.225789,0.91871345,,,,,,,,,,,,,, -147400,3.9220264,0.9011258,,,,,,,,,,,,,, -147500,4.1187577,0.85248435,,,,,,,,,,,,,, -147600,3.797937,0.9143201,,,,,,,,,,,,,, -147700,3.9746165,0.9520372,,,,,,,,,,,,,, -147800,3.6925275,0.8911271,,,,,,,,,,,,,, -147900,3.8504481,0.9223284,,,,,,,,,,,,,, -148000,4.2067327,0.9503232,,,,,,,,,,,,,, -148100,3.9705644,0.8542362,,,,,,,,,,,,,, -148200,4.018376,0.9215766,,,,,,,,,,,,,, -148300,4.2519646,1.0027723,,,,,,,,,,,,,, -148400,4.2161937,0.92534643,,,,,,,,,,,,,, -148500,4.4990954,1.0034714,,,,,,,,,,,,,, -148600,4.164833,0.8928433,,,,,,,,,,,,,, -148700,4.0988026,0.86299986,,,,,,,,,,,,,, -148770,,,0.9123883843421936,0.3148034214973449,0.7336199879646301,1.105415105819702,50000.0,0.6085000038146973,1.8452770709991453,10000.0,50550.195132255554,52418.38363933563,50550.195132255554,1858.532964706421,4.470274448394775,0.0 -148800,3.994137,0.82865256,,,,,,,,,,,,,, -148900,3.5856826,0.80446184,,,,,,,,,,,,,, -149000,3.651125,0.8148676,,,,,,,,,,,,,, -149100,3.8866045,0.8426918,,,,,,,,,,,,,, -149200,4.007617,0.88358545,,,,,,,,,,,,,, -149300,3.8414614,0.91877663,,,,,,,,,,,,,, -149400,3.9954627,0.88589865,,,,,,,,,,,,,, -149500,4.503806,0.87966704,,,,,,,,,,,,,, -149600,3.912987,0.8592236,,,,,,,,,,,,,, -149700,3.8758307,0.8919998,,,,,,,,,,,,,, -149800,3.900406,0.81631196,,,,,,,,,,,,,, -149900,4.0615206,0.92040485,,,,,,,,,,,,,, -150000,4.3976026,0.8790374,,,,,,,,,,,,,, -150100,3.8583875,0.8115389,,,,,,,,,,,,,, -150200,3.9696648,0.8985231,,,,,,,,,,,,,, -150274,,,0.913305163383484,0.2986312806606293,0.7386999726295471,1.0863254070281982,50000.0,0.6093000173568726,1.8343807458877563,10000.0,51060.157440423965,52945.19025826454,51060.157440423965,1875.268774271012,4.525710821151733,0.0 -150300,4.264492,0.9032929,,,,,,,,,,,,,, -150400,4.1224494,0.81676733,,,,,,,,,,,,,, -150500,3.9332294,0.8680644,,,,,,,,,,,,,, -150600,3.808293,0.85171187,,,,,,,,,,,,,, -150700,4.1016526,0.9542525,,,,,,,,,,,,,, -150800,4.0514107,0.86991805,,,,,,,,,,,,,, -150900,3.938345,0.89053994,,,,,,,,,,,,,, -151000,3.9640408,0.8619808,,,,,,,,,,,,,, -151100,3.8461058,0.87205184,,,,,,,,,,,,,, -151200,4.1106505,0.91384655,,,,,,,,,,,,,, -151300,3.8438044,0.7867296,,,,,,,,,,,,,, -151400,4.0590467,0.8527657,,,,,,,,,,,,,, -151500,4.056373,0.8182263,,,,,,,,,,,,,, -151600,4.020625,0.8758757,,,,,,,,,,,,,, -151700,4.402885,0.91774905,,,,,,,,,,,,,, -151777,,,0.9090202450752258,0.3202269375324249,0.7390799522399902,1.08871328830719,50000.0,0.6095000505447388,1.837144374847412,10000.0,51570.23234939575,53472.77893829346,51570.23234939575,1892.6738231182096,4.581248283386231,0.0 -151800,3.947674,0.8395514,,,,,,,,,,,,,, -151900,4.361172,0.9094162,,,,,,,,,,,,,, -152000,4.3036017,0.898693,,,,,,,,,,,,,, -152100,4.0127006,0.8460245,,,,,,,,,,,,,, -152200,4.511347,0.85035723,,,,,,,,,,,,,, -152300,4.3163834,0.8802031,,,,,,,,,,,,,, -152400,4.0295916,0.7907652,,,,,,,,,,,,,, -152500,4.1244187,0.84247124,,,,,,,,,,,,,, -152600,4.352835,0.7967892,,,,,,,,,,,,,, -152700,4.223338,0.87756574,,,,,,,,,,,,,, -152800,4.4571424,0.8671171,,,,,,,,,,,,,, -152900,4.689693,0.8957247,,,,,,,,,,,,,, -153000,3.9758854,0.81350136,,,,,,,,,,,,,, -153100,4.152593,0.851726,,,,,,,,,,,,,, -153200,4.4879704,0.95061094,,,,,,,,,,,,,, -153281,,,0.9080038070678712,0.3098303377628326,0.7390999794006348,1.0933698415756226,50000.0,0.6140000224113464,1.825111627578736,10000.0,52080.39623951912,53999.82557368279,52080.39623951912,1909.446493148804,4.638153791427612,0.0 -153300,4.0796328,0.82862085,,,,,,,,,,,,,, -153400,4.27601,0.7887939,,,,,,,,,,,,,, -153500,3.848566,0.7584916,,,,,,,,,,,,,, -153600,3.8983457,0.81058294,,,,,,,,,,,,,, -153700,4.146115,0.9293486,,,,,,,,,,,,,, -153800,4.030537,0.9015322,,,,,,,,,,,,,, -153900,4.439581,0.88291556,,,,,,,,,,,,,, -154000,4.0070777,0.7620097,,,,,,,,,,,,,, -154100,4.1803894,0.84854484,,,,,,,,,,,,,, -154200,4.008245,0.8283742,,,,,,,,,,,,,, -154300,4.062664,0.7384594,,,,,,,,,,,,,, -154400,4.690441,0.9102403,,,,,,,,,,,,,, -154500,4.1965733,0.81647587,,,,,,,,,,,,,, -154600,4.072263,0.8029928,,,,,,,,,,,,,, -154700,4.1274486,0.81280804,,,,,,,,,,,,,, -154785,,,0.9130859375,0.2970570921897888,0.7426199913024902,1.0836670398712158,50000.0,0.6220000386238098,1.8184280395507808,10000.0,52590.39731740952,54526.59492826462,52590.39731740952,1926.1085708141327,4.691672086715698,0.0 -154800,4.9158463,0.8439186,,,,,,,,,,,,,, -154900,4.0437574,0.79600334,,,,,,,,,,,,,, -155000,4.295517,0.7829103,,,,,,,,,,,,,, -155100,4.0743237,0.8054811,,,,,,,,,,,,,, -155200,4.125278,0.7837333,,,,,,,,,,,,,, -155300,4.078988,0.8750863,,,,,,,,,,,,,, -155400,4.7245426,0.8385893,,,,,,,,,,,,,, -155500,4.255236,0.88269037,,,,,,,,,,,,,, -155600,3.8942378,0.72874814,,,,,,,,,,,,,, -155700,4.291485,0.8010607,,,,,,,,,,,,,, -155800,3.9843156,0.7947867,,,,,,,,,,,,,, -155900,4.3197775,0.81487525,,,,,,,,,,,,,, -156000,4.242866,0.80033267,,,,,,,,,,,,,, -156100,4.061205,0.82206887,,,,,,,,,,,,,, -156200,4.7639284,0.80035007,,,,,,,,,,,,,, -156289,,,0.9122289419174194,0.3007284998893738,0.7432799935340881,1.0831379890441897,50000.0,0.6170000433921814,1.8322269916534424,10000.0,53100.54725623131,55053.689846515656,53100.54725623131,1942.9485993385315,4.743105173110962,0.0 -156300,4.4115686,0.8139876,,,,,,,,,,,,,, -156400,4.110632,0.78857917,,,,,,,,,,,,,, -156500,3.832324,0.75927216,,,,,,,,,,,,,, -156600,4.4817863,0.78987765,,,,,,,,,,,,,, -156700,4.146386,0.7596086,,,,,,,,,,,,,, -156800,4.6358795,0.8345526,,,,,,,,,,,,,, -156900,4.268369,0.74395114,,,,,,,,,,,,,, -157000,4.044331,0.80429995,,,,,,,,,,,,,, -157100,3.9157965,0.7532613,,,,,,,,,,,,,, -157200,4.173934,0.81731725,,,,,,,,,,,,,, -157300,4.214865,0.7859831,,,,,,,,,,,,,, -157400,4.5196886,0.8195273,,,,,,,,,,,,,, -157500,3.9760349,0.748797,,,,,,,,,,,,,, -157600,4.852836,0.74693286,,,,,,,,,,,,,, -157700,4.7431874,0.8407354,,,,,,,,,,,,,, -157792,,,0.9258809089660645,0.2602937519550323,0.7439599633216858,1.0839968919754028,50000.0,0.6177000403404236,1.8258272409439087,10000.0,53610.67733025551,55580.45754265785,53610.67733025551,1959.4862580299373,4.789821863174439,0.0 -157800,4.671836,0.7613278,,,,,,,,,,,,,, -157900,4.1240697,0.83507353,,,,,,,,,,,,,, -158000,4.321442,0.7826952,,,,,,,,,,,,,, -158100,4.3361998,0.7992326,,,,,,,,,,,,,, -158200,4.028484,0.7392943,,,,,,,,,,,,,, -158300,4.635071,0.7889716,,,,,,,,,,,,,, -158400,4.181439,0.762129,,,,,,,,,,,,,, -158500,4.268212,0.73580754,,,,,,,,,,,,,, -158600,4.5070767,0.82981616,,,,,,,,,,,,,, -158700,4.436163,0.7601256,,,,,,,,,,,,,, -158800,4.528302,0.70013,,,,,,,,,,,,,, -158900,4.329645,0.75759745,,,,,,,,,,,,,, -159000,4.488339,0.7284255,,,,,,,,,,,,,, -159100,4.4319077,0.8402345,,,,,,,,,,,,,, -159200,4.533658,0.83895767,,,,,,,,,,,,,, -159296,,,0.9353076815605164,0.2272559702396392,0.7450799942016602,1.0768632888793943,50000.0,0.6195000410079956,1.8430743217468264,10000.0,54120.89049816132,56107.6992123127,54120.89049816132,1976.401884317398,4.849618911743164,0.0 -159300,3.9397147,0.8074092,,,,,,,,,,,,,, -159400,4.414805,0.74240994,,,,,,,,,,,,,, -159500,4.327464,0.74495137,,,,,,,,,,,,,, -159600,4.148322,0.68395495,,,,,,,,,,,,,, -159700,4.3914647,0.79603016,,,,,,,,,,,,,, -159800,4.513506,0.7609571,,,,,,,,,,,,,, -159900,4.289285,0.79390943,,,,,,,,,,,,,, -160000,4.565674,0.7469954,,,,,,,,,,,,,, -160100,4.435073,0.8038087,,,,,,,,,,,,,, -160200,4.0628533,0.6786945,,,,,,,,,,,,,, -160300,4.198809,0.7615655,,,,,,,,,,,,,, -160400,4.484155,0.7910188,,,,,,,,,,,,,, -160500,4.4486656,0.7831785,,,,,,,,,,,,,, -160600,4.1750045,0.7553875,,,,,,,,,,,,,, -160700,4.31937,0.7793546,,,,,,,,,,,,,, -160799,,,0.9308434128761292,0.2407589256763458,0.7454400062561035,1.0781235694885254,50000.0,0.6162000298500061,1.8339414596557613,10000.0,54631.05508708954,56634.638201236725,54631.05508708954,1993.062605857849,4.90923810005188,0.0 -160800,4.432509,0.7759725,,,,,,,,,,,,,, -160900,4.2439723,0.7427143,,,,,,,,,,,,,, -161000,4.4719353,0.77451986,,,,,,,,,,,,,, -161100,4.2885957,0.7921667,,,,,,,,,,,,,, -161200,4.6102533,0.74927145,,,,,,,,,,,,,, -161300,4.123923,0.69721836,,,,,,,,,,,,,, -161400,4.405743,0.7743292,,,,,,,,,,,,,, -161500,4.388391,0.77731824,,,,,,,,,,,,,, -161600,3.9988232,0.6942461,,,,,,,,,,,,,, -161700,4.2819467,0.71787524,,,,,,,,,,,,,, -161800,4.073989,0.738488,,,,,,,,,,,,,, -161900,4.140622,0.7392317,,,,,,,,,,,,,, -162000,4.5013247,0.7767137,,,,,,,,,,,,,, -162100,3.8131096,0.7275447,,,,,,,,,,,,,, -162200,4.424634,0.72031915,,,,,,,,,,,,,, -162300,4.1434712,0.70069957,,,,,,,,,,,,,, -162303,,,0.9325773119926452,0.2340480387210846,0.7475000023841858,1.0697925090789795,50000.0,0.6193000078201294,1.8274227380752563,10000.0,55141.21358561516,57161.56299781799,55141.21358561516,2009.718491077423,4.966087818145752,0.0 -162400,4.9262643,0.73152035,,,,,,,,,,,,,, -162500,5.381612,0.8564984,,,,,,,,,,,,,, -162600,4.1708603,0.64988434,,,,,,,,,,,,,, -162700,4.5290823,0.7540439,,,,,,,,,,,,,, -162800,4.1238165,0.76452225,,,,,,,,,,,,,, -162900,4.21996,0.78340644,,,,,,,,,,,,,, -163000,4.255203,0.66402686,,,,,,,,,,,,,, -163100,4.3907237,0.64789414,,,,,,,,,,,,,, -163200,4.8588448,0.74926317,,,,,,,,,,,,,, -163300,4.4287243,0.7083878,,,,,,,,,,,,,, -163400,4.754788,0.7501652,,,,,,,,,,,,,, -163500,4.2940054,0.71753573,,,,,,,,,,,,,, -163600,4.4529324,0.7274811,,,,,,,,,,,,,, -163700,4.3067393,0.71846986,,,,,,,,,,,,,, -163800,4.0437846,0.70830953,,,,,,,,,,,,,, -163806,,,0.935327649116516,0.2216317504644394,0.7468799948692322,1.0788114070892334,50000.0,0.6229000091552734,1.8333741426467896,10000.0,55651.39206695557,57688.54973888397,55651.39206695557,2026.4165422916408,5.023638486862183,0.0 -163900,4.1317596,0.7008355,,,,,,,,,,,,,, -164000,3.7957118,0.6209839,,,,,,,,,,,,,, -164100,4.5993543,0.75294155,,,,,,,,,,,,,, -164200,4.2709107,0.701132,,,,,,,,,,,,,, -164300,4.490182,0.71501815,,,,,,,,,,,,,, -164400,3.811043,0.6504332,,,,,,,,,,,,,, -164500,4.0397825,0.62288064,,,,,,,,,,,,,, -164600,4.472105,0.70710427,,,,,,,,,,,,,, -164700,4.4151444,0.71869874,,,,,,,,,,,,,, -164800,4.506406,0.72006935,,,,,,,,,,,,,, -164900,4.1230087,0.6503599,,,,,,,,,,,,,, -165000,4.3414145,0.6726849,,,,,,,,,,,,,, -165100,3.9862494,0.67726785,,,,,,,,,,,,,, -165200,4.080937,0.7150222,,,,,,,,,,,,,, -165300,5.198391,0.73793817,,,,,,,,,,,,,, -165309,,,0.9376594424247742,0.2162195891141891,0.7461599707603455,1.07829487323761,50000.0,0.6246000528335571,1.8346995115280151,10000.0,56161.42722654343,58215.25165247917,56161.42722654343,2042.9757788181305,5.078388690948486,0.0 -165400,4.5437603,0.75818837,,,,,,,,,,,,,, -165500,5.336226,0.7382669,,,,,,,,,,,,,, -165600,4.9266615,0.7420831,,,,,,,,,,,,,, -165700,4.3637195,0.7025369,,,,,,,,,,,,,, -165800,5.068698,0.7307519,,,,,,,,,,,,,, -165900,4.8520885,0.7533509,,,,,,,,,,,,,, -166000,4.4327145,0.7072247,,,,,,,,,,,,,, -166100,4.516637,0.69523424,,,,,,,,,,,,,, -166200,4.4359593,0.7266788,,,,,,,,,,,,,, -166300,4.305624,0.69044805,,,,,,,,,,,,,, -166400,4.2138305,0.65979916,,,,,,,,,,,,,, -166500,4.485586,0.6559361,,,,,,,,,,,,,, -166600,4.635884,0.68492,,,,,,,,,,,,,, -166700,4.942111,0.72190344,,,,,,,,,,,,,, -166800,4.6000714,0.69221747,,,,,,,,,,,,,, -166812,,,0.9435387253761292,0.202550783753395,0.7520999908447266,1.061427116394043,50000.0,0.6241000294685364,1.825334906578064,10000.0,56671.53956079483,58742.24252414704,56671.53956079483,2059.741940021515,5.1376566886901855,0.0 -166900,4.334384,0.63511246,,,,,,,,,,,,,, -167000,4.4542656,0.7040321,,,,,,,,,,,,,, -167100,4.72811,0.6642285,,,,,,,,,,,,,, -167200,4.3732934,0.67817473,,,,,,,,,,,,,, -167300,4.7550983,0.7657738,,,,,,,,,,,,,, -167400,5.2327976,0.78498137,,,,,,,,,,,,,, -167500,4.7961516,0.6878571,,,,,,,,,,,,,, -167600,4.2113423,0.6692564,,,,,,,,,,,,,, -167700,4.4267335,0.7055707,,,,,,,,,,,,,, -167800,4.7095723,0.7514301,,,,,,,,,,,,,, -167900,4.6689463,0.7227254,,,,,,,,,,,,,, -168000,4.8748364,0.78184044,,,,,,,,,,,,,, -168100,4.6557403,0.66637826,,,,,,,,,,,,,, -168200,4.669334,0.69652563,,,,,,,,,,,,,, -168300,4.1443996,0.66352946,,,,,,,,,,,,,, -168315,,,0.95121169090271,0.173487052321434,0.7513599991798401,1.0714797973632812,50000.0,0.6273000240325928,1.837319254875183,10000.0,57181.61806106568,59269.084530353546,57181.61806106568,2076.389808654785,5.199169874191284,0.0 -168400,4.20574,0.6677282,,,,,,,,,,,,,, -168500,4.1715274,0.6784317,,,,,,,,,,,,,, -168600,4.242543,0.65475726,,,,,,,,,,,,,, -168700,4.2302203,0.6723533,,,,,,,,,,,,,, -168800,4.7182484,0.7233496,,,,,,,,,,,,,, -168900,4.778258,0.66574013,,,,,,,,,,,,,, -169000,4.554126,0.7212132,,,,,,,,,,,,,, -169100,4.814477,0.7172644,,,,,,,,,,,,,, -169200,4.68372,0.73017675,,,,,,,,,,,,,, -169300,4.4000297,0.65450084,,,,,,,,,,,,,, -169400,4.754869,0.718551,,,,,,,,,,,,,, -169500,4.446289,0.5613286,,,,,,,,,,,,,, -169600,4.8425584,0.6857818,,,,,,,,,,,,,, -169700,4.4559484,0.6577601,,,,,,,,,,,,,, -169800,4.6635976,0.6531504,,,,,,,,,,,,,, -169818,,,0.951969027519226,0.1739388257265091,0.7509399652481079,1.0662760734558103,50000.0,0.6255000233650208,1.8304463624954224,10000.0,57691.601484537125,59795.85251188278,57691.601484537125,2093.059808731079,5.259644985198975,0.0 -169900,4.88604,0.7146909,,,,,,,,,,,,,, -170000,4.347683,0.64501965,,,,,,,,,,,,,, -170100,4.4624767,0.6642924,,,,,,,,,,,,,, -170200,4.4216037,0.6673153,,,,,,,,,,,,,, -170300,4.4640756,0.6943649,,,,,,,,,,,,,, -170400,4.17406,0.62381554,,,,,,,,,,,,,, -170500,4.4536343,0.65324897,,,,,,,,,,,,,, -170600,4.270028,0.65853465,,,,,,,,,,,,,, -170700,4.6168113,0.699436,,,,,,,,,,,,,, -170800,4.611413,0.6733986,,,,,,,,,,,,,, -170900,4.720756,0.7278185,,,,,,,,,,,,,, -171000,4.436224,0.64161587,,,,,,,,,,,,,, -171100,4.3622513,0.681857,,,,,,,,,,,,,, -171200,4.8285275,0.6515007,,,,,,,,,,,,,, -171300,5.141845,0.7361392,,,,,,,,,,,,,, -171321,,,0.9516900181770324,0.1759648621082306,0.7511999607086182,1.066061019897461,50000.0,0.6276000142097473,1.834411859512329,10000.0,58201.5602388382,60322.585090875626,58201.5602388382,2109.7203953266144,5.319650173187256,0.0 -171400,4.5573907,0.6873593,,,,,,,,,,,,,, -171500,4.7736845,0.6951722,,,,,,,,,,,,,, -171600,4.7996807,0.72282934,,,,,,,,,,,,,, -171700,4.497292,0.7427629,,,,,,,,,,,,,, -171800,4.45411,0.6920898,,,,,,,,,,,,,, -171900,4.48787,0.67876786,,,,,,,,,,,,,, -172000,4.20414,0.62121344,,,,,,,,,,,,,, -172100,4.6627994,0.74244225,,,,,,,,,,,,,, -172200,4.360721,0.6197945,,,,,,,,,,,,,, -172300,4.95652,0.759892,,,,,,,,,,,,,, -172400,4.393011,0.6417143,,,,,,,,,,,,,, -172500,4.5441933,0.57997864,,,,,,,,,,,,,, -172600,4.754577,0.68792856,,,,,,,,,,,,,, -172700,4.847614,0.77295303,,,,,,,,,,,,,, -172800,4.2084236,0.6205125,,,,,,,,,,,,,, -172823,,,0.9527662396430968,0.1673062890768051,0.7521599531173706,1.0650585889816284,50000.0,0.6281000375747681,1.826152801513672,10000.0,58711.51387619972,60849.290974617004,58711.51387619972,2126.3594963550568,5.380169868469238,0.0 -172900,4.4698067,0.66473067,,,,,,,,,,,,,, -173000,4.6180725,0.71296537,,,,,,,,,,,,,, -173100,4.747243,0.5819074,,,,,,,,,,,,,, -173200,4.3566427,0.6054695,,,,,,,,,,,,,, -173300,5.217586,0.7337255,,,,,,,,,,,,,, -173400,5.012625,0.76412845,,,,,,,,,,,,,, -173500,4.217951,0.63653886,,,,,,,,,,,,,, -173600,4.730823,0.6753083,,,,,,,,,,,,,, -173700,4.9936314,0.6062936,,,,,,,,,,,,,, -173800,5.048521,0.6558293,,,,,,,,,,,,,, -173900,4.7038355,0.7017281,,,,,,,,,,,,,, -174000,4.605978,0.6126373,,,,,,,,,,,,,, -174100,4.4043846,0.6058486,,,,,,,,,,,,,, -174200,5.06601,0.6789844,,,,,,,,,,,,,, -174300,4.3682265,0.6173816,,,,,,,,,,,,,, -174326,,,0.9549385905265808,0.1686282455921173,0.7537800073623657,1.060943841934204,50000.0,0.6239000558853149,1.8231186866760247,10000.0,59221.47832560539,61375.96544504166,59221.47832560539,2142.955652713776,5.441788911819458,0.0 -174400,4.5192823,0.6423968,,,,,,,,,,,,,, -174500,4.334304,0.6320712,,,,,,,,,,,,,, -174600,4.3224053,0.6484999,,,,,,,,,,,,,, -174700,4.15606,0.6018517,,,,,,,,,,,,,, -174800,4.8509803,0.67149466,,,,,,,,,,,,,, -174900,4.1648207,0.5462134,,,,,,,,,,,,,, -175000,4.25001,0.6441619,,,,,,,,,,,,,, -175100,4.3967323,0.58633745,,,,,,,,,,,,,, -175200,4.5523715,0.64079505,,,,,,,,,,,,,, -175300,4.642414,0.60833323,,,,,,,,,,,,,, -175400,4.7500753,0.6324994,,,,,,,,,,,,,, -175500,4.7465796,0.6535889,,,,,,,,,,,,,, -175600,4.6034513,0.66864514,,,,,,,,,,,,,, -175700,5.34955,0.7059027,,,,,,,,,,,,,, -175800,4.3618555,0.6714545,,,,,,,,,,,,,, -175828,,,0.954858899116516,0.1653094291687011,0.7553600072860718,1.0575079917907717,50000.0,0.6284000277519226,1.8286449909210205,10000.0,59731.51045560837,61902.73510289192,59731.51045560837,2159.5833282470703,5.498790264129639,0.0 -175900,4.4649277,0.6225319,,,,,,,,,,,,,, -176000,4.772297,0.6917483,,,,,,,,,,,,,, -176100,4.0924706,0.6113527,,,,,,,,,,,,,, -176200,4.618281,0.6820131,,,,,,,,,,,,,, -176300,4.4178724,0.64431965,,,,,,,,,,,,,, -176400,4.5673885,0.60017514,,,,,,,,,,,,,, -176500,4.6688266,0.71233,,,,,,,,,,,,,, -176600,4.295347,0.5782438,,,,,,,,,,,,,, -176700,4.594872,0.63622206,,,,,,,,,,,,,, -176800,4.985579,0.69265634,,,,,,,,,,,,,, -176900,4.686993,0.63721186,,,,,,,,,,,,,, -177000,5.379355,0.59139985,,,,,,,,,,,,,, -177100,4.3397465,0.6346355,,,,,,,,,,,,,, -177200,4.7468443,0.5837856,,,,,,,,,,,,,, -177300,4.560556,0.7012298,,,,,,,,,,,,,, -177330,,,0.9606584906578064,0.1450536251068115,0.7543799877166748,1.057165503501892,50000.0,0.6285000443458557,1.8258801698684688,10000.0,60241.45378923416,62429.50436306,60241.45378923416,2176.286799430847,5.566740036010742,0.0 -177400,4.6812367,0.65504116,,,,,,,,,,,,,, -177500,4.214511,0.67143494,,,,,,,,,,,,,, -177600,4.577908,0.6841116,,,,,,,,,,,,,, -177700,4.4259553,0.5808498,,,,,,,,,,,,,, -177800,4.420992,0.6553903,,,,,,,,,,,,,, -177900,4.7341437,0.63454497,,,,,,,,,,,,,, -178000,4.3104186,0.6551181,,,,,,,,,,,,,, -178100,4.3270655,0.5034319,,,,,,,,,,,,,, -178200,4.4490867,0.60004133,,,,,,,,,,,,,, -178300,4.881739,0.6234791,,,,,,,,,,,,,, -178400,4.4934726,0.58523965,,,,,,,,,,,,,, -178500,4.4521036,0.65410274,,,,,,,,,,,,,, -178600,4.2362046,0.611465,,,,,,,,,,,,,, -178700,4.275678,0.5375026,,,,,,,,,,,,,, -178800,5.1916547,0.67523086,,,,,,,,,,,,,, -178833,,,0.9592633843421936,0.1504260599613189,0.7558799982070923,1.055589199066162,50000.0,0.6291000247001648,1.8220672607421875,10000.0,60751.63723874092,62956.34075641632,60751.63723874092,2192.8202583789825,5.632821083068848,0.0 -178900,4.3755517,0.62299323,,,,,,,,,,,,,, -179000,4.710382,0.64736474,,,,,,,,,,,,,, -179100,4.4100556,0.6136696,,,,,,,,,,,,,, -179200,4.5511484,0.7397487,,,,,,,,,,,,,, -179300,4.467661,0.6272414,,,,,,,,,,,,,, -179400,4.1717315,0.63887334,,,,,,,,,,,,,, -179500,4.9829674,0.6559559,,,,,,,,,,,,,, -179600,4.388622,0.67545706,,,,,,,,,,,,,, -179700,4.516138,0.5670007,,,,,,,,,,,,,, -179800,4.814214,0.66361433,,,,,,,,,,,,,, -179900,4.7127404,0.7719606,,,,,,,,,,,,,, -180000,4.3383408,0.5864077,,,,,,,,,,,,,, -180100,4.5462523,0.57909095,,,,,,,,,,,,,, -180200,4.6354713,0.6205202,,,,,,,,,,,,,, -180300,5.081221,0.66169286,,,,,,,,,,,,,, -180336,,,0.9604392051696776,0.1449994593858719,0.7549799680709839,1.0576988458633425,50000.0,0.6288000345230103,1.822423934936524,10000.0,61261.6109521389,63483.00061035156,61261.6109521389,2209.394678592682,5.691540241241455,0.0 -180400,4.4875317,0.6110374,,,,,,,,,,,,,, -180500,4.553495,0.6525601,,,,,,,,,,,,,, -180600,4.300698,0.62931097,,,,,,,,,,,,,, -180700,4.08272,0.5426609,,,,,,,,,,,,,, -180800,4.991302,0.60475296,,,,,,,,,,,,,, -180900,4.7338057,0.6603809,,,,,,,,,,,,,, -181000,4.348856,0.53848237,,,,,,,,,,,,,, -181100,4.8738685,0.6646155,,,,,,,,,,,,,, -181200,4.3354735,0.63943696,,,,,,,,,,,,,, -181300,4.4774356,0.63449186,,,,,,,,,,,,,, -181400,4.366469,0.6312546,,,,,,,,,,,,,, -181500,4.258011,0.6629175,,,,,,,,,,,,,, -181600,4.4038196,0.6153275,,,,,,,,,,,,,, -181700,4.7690763,0.6402781,,,,,,,,,,,,,, -181800,4.621947,0.73844725,,,,,,,,,,,,,, -181839,,,0.9586654901504515,0.1504426002502441,0.7546799778938293,1.0552231073379517,50000.0,0.6292000412940979,1.8193235397338867,10000.0,61771.61212205887,64009.79198360443,61771.61212205887,2226.068264245987,5.75496768951416,0.0 -181900,4.7671633,0.6323729,,,,,,,,,,,,,, -182000,4.6401024,0.6566731,,,,,,,,,,,,,, -182100,4.9470205,0.595034,,,,,,,,,,,,,, -182200,4.324022,0.5763313,,,,,,,,,,,,,, -182300,4.6548047,0.6324051,,,,,,,,,,,,,, -182400,4.5579605,0.59709644,,,,,,,,,,,,,, -182500,4.061407,0.5620955,,,,,,,,,,,,,, -182600,4.7473598,0.55037045,,,,,,,,,,,,,, -182700,4.521383,0.60526764,,,,,,,,,,,,,, -182800,4.591207,0.6724327,,,,,,,,,,,,,, -182900,4.512076,0.6488371,,,,,,,,,,,,,, -183000,5.160224,0.63616854,,,,,,,,,,,,,, -183100,4.90284,0.62168485,,,,,,,,,,,,,, -183200,4.686584,0.5562751,,,,,,,,,,,,,, -183300,4.653592,0.64613134,,,,,,,,,,,,,, -183342,,,0.9608178734779358,0.1480124741792678,0.7549600005149841,1.0551658868789673,50000.0,0.6309000253677368,1.8188600540161133,10000.0,62281.60619163513,64536.54286932945,62281.60619163513,2242.700765132904,5.826450347900391,0.0 -183400,5.1109037,0.63800764,,,,,,,,,,,,,, -183500,4.432575,0.58716154,,,,,,,,,,,,,, -183600,4.9157023,0.6808961,,,,,,,,,,,,,, -183700,4.078161,0.5240378,,,,,,,,,,,,,, -183800,4.613052,0.6302706,,,,,,,,,,,,,, -183900,4.7934437,0.6452532,,,,,,,,,,,,,, -184000,4.8641825,0.62369573,,,,,,,,,,,,,, -184100,4.427989,0.62768096,,,,,,,,,,,,,, -184200,4.6373653,0.63100845,,,,,,,,,,,,,, -184300,4.6365733,0.5465292,,,,,,,,,,,,,, -184400,4.553296,0.6897598,,,,,,,,,,,,,, -184500,4.8784676,0.6571046,,,,,,,,,,,,,, -184600,4.567739,0.620932,,,,,,,,,,,,,, -184700,5.383213,0.6338779,,,,,,,,,,,,,, -184800,5.0080757,0.6137479,,,,,,,,,,,,,, -184845,,,0.960379421710968,0.1480392068624496,0.7543999552726746,1.0552953481674194,50000.0,0.6299000382423401,1.818251609802246,10000.0,62791.7273118496,65063.375341653824,62791.7273118496,2259.293065071106,5.892799139022827,0.0 -184900,4.526795,0.60582554,,,,,,,,,,,,,, -185000,5.008316,0.6779238,,,,,,,,,,,,,, -185100,4.6913347,0.65213317,,,,,,,,,,,,,, -185200,4.580558,0.62933797,,,,,,,,,,,,,, -185300,4.539513,0.62808067,,,,,,,,,,,,,, -185400,4.4435205,0.68173724,,,,,,,,,,,,,, -185500,4.42433,0.5769428,,,,,,,,,,,,,, -185600,4.9893,0.6277588,,,,,,,,,,,,,, -185700,4.517164,0.6388081,,,,,,,,,,,,,, -185800,4.2580824,0.6384242,,,,,,,,,,,,,, -185900,4.605824,0.59588945,,,,,,,,,,,,,, -186000,4.6503177,0.6213629,,,,,,,,,,,,,, -186100,4.9469485,0.6826271,,,,,,,,,,,,,, -186200,4.704724,0.6735736,,,,,,,,,,,,,, -186300,4.8533573,0.6662823,,,,,,,,,,,,,, -186348,,,0.961316168308258,0.1417189538478851,0.7547999620437622,1.0548179149627686,50000.0,0.629800021648407,1.8192886114120483,10000.0,63301.88521718979,65590.30063033104,63301.88521718979,2275.9461719989777,5.953683853149414,0.0 -186400,5.1417527,0.60598725,,,,,,,,,,,,,, -186500,4.8862586,0.65465945,,,,,,,,,,,,,, -186600,4.610168,0.6104903,,,,,,,,,,,,,, -186700,4.544054,0.6325598,,,,,,,,,,,,,, -186800,4.7880125,0.5923464,,,,,,,,,,,,,, -186900,4.488902,0.648478,,,,,,,,,,,,,, -187000,4.4825015,0.58674073,,,,,,,,,,,,,, -187100,4.9426823,0.6423218,,,,,,,,,,,,,, -187200,4.626722,0.70470387,,,,,,,,,,,,,, -187300,4.6105638,0.686792,,,,,,,,,,,,,, -187400,4.614561,0.6177014,,,,,,,,,,,,,, -187500,4.468789,0.6912787,,,,,,,,,,,,,, -187600,4.8737326,0.68595684,,,,,,,,,,,,,, -187700,4.7775054,0.6438257,,,,,,,,,,,,,, -187800,4.574542,0.66052973,,,,,,,,,,,,,, -187851,,,0.9618940949440002,0.1434378921985626,0.7547199726104736,1.056044578552246,50000.0,0.6304000020027161,1.818678379058838,10000.0,63812.025695085526,66117.23640704155,63812.025695085526,2292.624106407165,6.019184112548828,0.0 -187900,4.9481764,0.68897694,,,,,,,,,,,,,, -188000,5.3781705,0.6275103,,,,,,,,,,,,,, -188100,3.9356627,0.57118744,,,,,,,,,,,,,, -188200,4.352857,0.5668704,,,,,,,,,,,,,, -188300,4.394617,0.5603566,,,,,,,,,,,,,, -188400,4.4041963,0.58441633,,,,,,,,,,,,,, -188500,4.611003,0.66901475,,,,,,,,,,,,,, -188600,5.031939,0.7105364,,,,,,,,,,,,,, -188700,4.5661545,0.57234037,,,,,,,,,,,,,, -188800,4.2609425,0.5586287,,,,,,,,,,,,,, -188900,4.631012,0.58809924,,,,,,,,,,,,,, -189000,4.143345,0.5356656,,,,,,,,,,,,,, -189100,4.6090403,0.640607,,,,,,,,,,,,,, -189200,4.445058,0.5596858,,,,,,,,,,,,,, -189300,4.598607,0.73511386,,,,,,,,,,,,,, -189354,,,0.9598413109779358,0.1478601396083831,0.7548399567604065,1.0551180839538574,50000.0,0.6302000284194946,1.817767858505249,10000.0,64321.92574834824,66643.83481454849,64321.92574834824,2309.2071928977966,6.081565856933594,0.0 -189400,4.636979,0.63841385,,,,,,,,,,,,,, -189500,4.291908,0.60705185,,,,,,,,,,,,,, -189600,4.839509,0.64460605,,,,,,,,,,,,,, -189700,4.688819,0.5930764,,,,,,,,,,,,,, -189800,4.3457565,0.5816332,,,,,,,,,,,,,, -189900,4.494881,0.65933526,,,,,,,,,,,,,, -190000,4.8774395,0.57071155,,,,,,,,,,,,,, -190100,4.503337,0.58001703,,,,,,,,,,,,,, -190200,4.4789195,0.63998824,,,,,,,,,,,,,, -190300,4.4093895,0.6095002,,,,,,,,,,,,,, -190400,4.167553,0.56146854,,,,,,,,,,,,,, -190500,4.1642313,0.58005124,,,,,,,,,,,,,, -190600,4.546361,0.6032562,,,,,,,,,,,,,, -190700,4.5467577,0.6343884,,,,,,,,,,,,,, -190800,4.8101053,0.6333047,,,,,,,,,,,,,, -190857,,,0.9608976244926452,0.1452582329511642,0.7548399567604065,1.0554226636886597,50000.0,0.6309000253677368,1.8206629753112795,10000.0,64831.91565823555,67171.30288815498,64831.91565823555,2326.568407535553,6.143704175949097,0.0 -190900,5.125055,0.6353626,,,,,,,,,,,,,, -191000,4.487861,0.61177355,,,,,,,,,,,,,, -191100,4.7667365,0.6364689,,,,,,,,,,,,,, -191200,4.237287,0.6116393,,,,,,,,,,,,,, -191300,4.150356,0.5625326,,,,,,,,,,,,,, -191400,4.62712,0.6491022,,,,,,,,,,,,,, -191500,4.309606,0.58743,,,,,,,,,,,,,, -191600,4.4902015,0.6299104,,,,,,,,,,,,,, -191700,4.1680183,0.6178493,,,,,,,,,,,,,, -191800,4.171288,0.5724349,,,,,,,,,,,,,, -191900,5.0878534,0.6770922,,,,,,,,,,,,,, -192000,4.8692,0.6692296,,,,,,,,,,,,,, -192100,4.9348373,0.6095462,,,,,,,,,,,,,, -192200,4.6776466,0.66578025,,,,,,,,,,,,,, -192300,4.203807,0.5360294,,,,,,,,,,,,,, -192360,,,0.9600207209587096,0.1482952535152435,0.7549399733543396,1.055193305015564,50000.0,0.6299000382423401,1.818163514137268,10000.0,65341.9563536644,67698.0499484539,65341.9563536644,2343.153601646424,6.2114434242248535,0.0 -192400,5.129816,0.687251,,,,,,,,,,,,,, -192500,4.162949,0.58312297,,,,,,,,,,,,,, -192600,4.3270645,0.59690857,,,,,,,,,,,,,, -192700,4.904703,0.6274248,,,,,,,,,,,,,, -192800,4.260665,0.5463633,,,,,,,,,,,,,, -192900,4.3538713,0.6744175,,,,,,,,,,,,,, -193000,4.6014132,0.62050295,,,,,,,,,,,,,, -193100,4.762205,0.64565825,,,,,,,,,,,,,, -193200,4.2409678,0.57098067,,,,,,,,,,,,,, -193300,4.194614,0.5857936,,,,,,,,,,,,,, -193400,4.2775893,0.5460131,,,,,,,,,,,,,, -193500,4.5608716,0.59343004,,,,,,,,,,,,,, -193600,4.5418506,0.6798177,,,,,,,,,,,,,, -193700,4.9228506,0.59381855,,,,,,,,,,,,,, -193800,4.725084,0.60116136,,,,,,,,,,,,,, -193862,,,0.9616549611091614,0.1444314867258072,0.7547399997711182,1.0553990602493286,50000.0,0.6300000548362732,1.818334460258484,10000.0,65851.87741112709,68224.74874973297,65851.87741112709,2359.8171005249023,6.272164344787598,0.0 -193900,4.3345923,0.5360868,,,,,,,,,,,,,, -194000,4.368846,0.638605,,,,,,,,,,,,,, -194100,4.419131,0.6177735,,,,,,,,,,,,,, -194200,4.688049,0.5976004,,,,,,,,,,,,,, -194300,4.5610294,0.68328685,,,,,,,,,,,,,, -194400,4.7301836,0.6911053,,,,,,,,,,,,,, -194500,4.87635,0.61842966,,,,,,,,,,,,,, -194600,4.582612,0.6014584,,,,,,,,,,,,,, -194700,4.9062285,0.60618675,,,,,,,,,,,,,, -194800,4.379206,0.66237754,,,,,,,,,,,,,, -194900,4.709978,0.67502403,,,,,,,,,,,,,, -195000,4.3299065,0.6686623,,,,,,,,,,,,,, -195100,4.82737,0.6380525,,,,,,,,,,,,,, -195200,4.4899564,0.59997714,,,,,,,,,,,,,, -195300,4.531425,0.6324216,,,,,,,,,,,,,, -195366,,,0.9608577489852904,0.1464710235595703,0.7546199560165405,1.05693256855011,50000.0,0.6297000050544739,1.819195985794068,10000.0,66361.8992049694,68751.58087086678,66361.8992049694,2376.5136275291443,6.332875728607178,0.0 -195400,4.544072,0.65064716,,,,,,,,,,,,,, -195500,4.3705955,0.63564855,,,,,,,,,,,,,, -195600,4.400526,0.60421944,,,,,,,,,,,,,, -195700,4.155817,0.54957277,,,,,,,,,,,,,, -195800,4.2930055,0.64195955,,,,,,,,,,,,,, -195900,4.2061734,0.5687947,,,,,,,,,,,,,, -196000,5.056619,0.6751983,,,,,,,,,,,,,, -196100,4.882359,0.73793983,,,,,,,,,,,,,, -196200,4.6193705,0.5915282,,,,,,,,,,,,,, -196300,4.4920497,0.65284014,,,,,,,,,,,,,, -196400,4.598594,0.5650957,,,,,,,,,,,,,, -196500,4.1223106,0.5036499,,,,,,,,,,,,,, -196600,4.6564755,0.5657278,,,,,,,,,,,,,, -196700,4.3734713,0.59240025,,,,,,,,,,,,,, -196800,4.1866345,0.63262534,,,,,,,,,,,,,, -196869,,,0.961136758327484,0.1438667327165603,0.7547799944877625,1.055078148841858,50000.0,0.6299000382423401,1.8178199529647827,10000.0,66872.12175607681,69278.48593592644,66872.12175607681,2393.079930305481,6.3962483406066895,0.0 -196900,4.6027184,0.66050565,,,,,,,,,,,,,, -197000,4.310226,0.6212318,,,,,,,,,,,,,, -197100,4.5912156,0.63961,,,,,,,,,,,,,, -197200,4.317368,0.61901975,,,,,,,,,,,,,, -197300,4.603697,0.6401897,,,,,,,,,,,,,, -197400,4.2722907,0.56321776,,,,,,,,,,,,,, -197500,4.6396303,0.5882234,,,,,,,,,,,,,, -197600,4.6431847,0.5936466,,,,,,,,,,,,,, -197700,4.220898,0.5740565,,,,,,,,,,,,,, -197800,4.302345,0.58264387,,,,,,,,,,,,,, -197900,4.649182,0.6114782,,,,,,,,,,,,,, -198000,5.131324,0.6272964,,,,,,,,,,,,,, -198100,4.4500113,0.64388174,,,,,,,,,,,,,, -198200,4.4151893,0.64972156,,,,,,,,,,,,,, -198300,4.664418,0.6344942,,,,,,,,,,,,,, -198372,,,0.9610371589660645,0.1463333070278167,0.7547599673271179,1.0546435117721558,50000.0,0.6314000487327576,1.8186239004135127,10000.0,67382.08436250687,69805.22742986679,67382.08436250687,2409.745950460434,6.455162048339844,0.0 -198400,5.041673,0.5937652,,,,,,,,,,,,,, -198500,4.084846,0.53695595,,,,,,,,,,,,,, -198600,4.6940207,0.59393764,,,,,,,,,,,,,, -198700,4.7541323,0.6089307,,,,,,,,,,,,,, -198800,4.278289,0.5691409,,,,,,,,,,,,,, -198900,4.6634455,0.6748983,,,,,,,,,,,,,, -199000,4.6690636,0.5594989,,,,,,,,,,,,,, -199100,4.3608847,0.57777816,,,,,,,,,,,,,, -199200,4.6765585,0.58087486,,,,,,,,,,,,,, -199300,4.755581,0.5864955,,,,,,,,,,,,,, -199400,4.4908857,0.6354643,,,,,,,,,,,,,, -199500,4.3267593,0.56170857,,,,,,,,,,,,,, -199600,4.508458,0.64658177,,,,,,,,,,,,,, -199700,4.215927,0.6133836,,,,,,,,,,,,,, -199800,4.658959,0.62962663,,,,,,,,,,,,,, -199874,,,0.96097731590271,0.1450323164463043,0.7546600103378296,1.0563466548919678,50000.0,0.6300000548362732,1.8194706439971924,10000.0,67891.98360681534,70332.2281923294,67891.98360681534,2426.7303347587585,6.520244121551514,0.0 -199900,4.3459873,0.57150435,,,,,,,,,,,,,, -200000,4.3861403,0.57675,,,,,,,,,,,,,, -200100,4.611813,0.6767509,,,,,,,,,,,,,, -200200,4.992525,0.6816195,,,,,,,,,,,,,, -200300,4.7548127,0.60944545,,,,,,,,,,,,,, -200400,5.028342,0.6341446,,,,,,,,,,,,,, -200500,4.3348894,0.62571824,,,,,,,,,,,,,, -200600,4.7315784,0.6323053,,,,,,,,,,,,,, -200700,4.1650867,0.5347894,,,,,,,,,,,,,, -200800,4.4995775,0.63305867,,,,,,,,,,,,,, -200900,4.656082,0.629481,,,,,,,,,,,,,, -201000,4.3246408,0.6029366,,,,,,,,,,,,,, -201100,4.4697356,0.56526756,,,,,,,,,,,,,, -201200,5.505691,0.6188172,,,,,,,,,,,,,, -201300,4.156332,0.54209316,,,,,,,,,,,,,, -201378,,,0.9603993892669678,0.14662966132164,0.7549799680709839,1.0545313358306885,50000.0,0.6297000050544739,1.8182106018066408,10000.0,68402.14937877655,70859.24960279465,68402.14937877655,2443.469132423401,6.584799528121948,0.0 -201400,4.533996,0.6396402,,,,,,,,,,,,,, -201500,4.3042965,0.62432694,,,,,,,,,,,,,, -201600,4.187157,0.5653847,,,,,,,,,,,,,, -201700,4.3382964,0.56483305,,,,,,,,,,,,,, -201800,4.4323564,0.64784026,,,,,,,,,,,,,, -201900,4.3448296,0.58421654,,,,,,,,,,,,,, -202000,4.6698823,0.64018255,,,,,,,,,,,,,, -202100,4.431425,0.65572023,,,,,,,,,,,,,, -202200,4.063715,0.5276502,,,,,,,,,,,,,, -202300,4.630384,0.6548616,,,,,,,,,,,,,, -202400,4.390069,0.5914697,,,,,,,,,,,,,, -202500,4.4476185,0.6121151,,,,,,,,,,,,,, -202600,4.0937614,0.5286918,,,,,,,,,,,,,, -202700,4.389631,0.65737146,,,,,,,,,,,,,, -202800,4.836144,0.62519675,,,,,,,,,,,,,, -202880,,,0.960758090019226,0.1468870043754577,0.7549200057983398,1.0562556982040403,50000.0,0.6297000050544739,1.819766640663147,10000.0,68912.09048843384,71385.91350674629,68912.09048843384,2460.0737595558167,6.650354385375977,0.0 -202900,5.010505,0.66432047,,,,,,,,,,,,,, -203000,4.322499,0.63226616,,,,,,,,,,,,,, -203100,4.4218345,0.6049023,,,,,,,,,,,,,, -203200,4.6557126,0.58392334,,,,,,,,,,,,,, -203300,4.728853,0.62042165,,,,,,,,,,,,,, -203400,4.605395,0.5939349,,,,,,,,,,,,,, -203500,4.416017,0.5662616,,,,,,,,,,,,,, -203600,4.6468425,0.68283445,,,,,,,,,,,,,, -203700,4.4898586,0.59892184,,,,,,,,,,,,,, -203800,4.853452,0.67965454,,,,,,,,,,,,,, -203900,4.5454264,0.6393912,,,,,,,,,,,,,, -204000,4.95574,0.6830124,,,,,,,,,,,,,, -204100,4.3144317,0.60363126,,,,,,,,,,,,,, -204200,4.813521,0.6813643,,,,,,,,,,,,,, -204300,4.192524,0.57817745,,,,,,,,,,,,,, -204383,,,0.9612563848495485,0.1475114524364471,0.7545999884605408,1.056083917617798,50000.0,0.6309000253677368,1.8184700012207031,10000.0,69422.0167837143,71912.68066692352,69422.0167837143,2476.78812289238,6.723001718521118,0.0 -204400,4.695456,0.6327949,,,,,,,,,,,,,, -204500,4.6664004,0.6247626,,,,,,,,,,,,,, -204600,5.011876,0.79132926,,,,,,,,,,,,,, -204700,4.22959,0.6102676,,,,,,,,,,,,,, -204800,4.979033,0.62975943,,,,,,,,,,,,,, -204900,4.8860793,0.65023875,,,,,,,,,,,,,, -205000,5.0187,0.6620498,,,,,,,,,,,,,, -205100,4.737475,0.6220966,,,,,,,,,,,,,, -205200,4.434011,0.6036663,,,,,,,,,,,,,, -205300,4.2390394,0.57938457,,,,,,,,,,,,,, -205400,4.2775025,0.61107427,,,,,,,,,,,,,, -205500,4.7687893,0.68818116,,,,,,,,,,,,,, -205600,4.705533,0.6040555,,,,,,,,,,,,,, -205700,4.197648,0.5710392,,,,,,,,,,,,,, -205800,4.383714,0.5969743,,,,,,,,,,,,,, -205886,,,0.958984375,0.1498239785432815,0.7548999786376953,1.0548142194747925,50000.0,0.6307000517845154,1.818460464477539,10000.0,69932.16599369049,72439.5957725048,69932.16599369049,2493.426098585129,6.797011137008667,0.0 -205900,4.391524,0.5363457,,,,,,,,,,,,,, -206000,5.0531406,0.7335253,,,,,,,,,,,,,, -206100,4.7627864,0.65825164,,,,,,,,,,,,,, -206200,4.6119695,0.6584228,,,,,,,,,,,,,, -206300,4.2120996,0.52719045,,,,,,,,,,,,,, -206400,4.690535,0.61203754,,,,,,,,,,,,,, -206500,4.4932213,0.67649764,,,,,,,,,,,,,, -206600,4.634245,0.6200344,,,,,,,,,,,,,, -206700,4.492921,0.6435827,,,,,,,,,,,,,, -206800,4.6972885,0.63636756,,,,,,,,,,,,,, -206900,4.652372,0.65548855,,,,,,,,,,,,,, -207000,4.6821494,0.6216313,,,,,,,,,,,,,, -207100,4.6584787,0.6769974,,,,,,,,,,,,,, -207200,4.541084,0.5864094,,,,,,,,,,,,,, -207300,4.255621,0.61726415,,,,,,,,,,,,,, -207389,,,0.9612165093421936,0.1428539901971817,0.7551599740982056,1.0566010475158691,50000.0,0.6303000450134277,1.8208805322647093,10000.0,70442.23829221725,72966.56232953072,70442.23829221725,2510.198974609375,6.865321636199951,0.0 -207400,4.7789435,0.67313516,,,,,,,,,,,,,, -207500,4.350134,0.6128657,,,,,,,,,,,,,, -207600,4.7622886,0.6376029,,,,,,,,,,,,,, -207700,4.4009566,0.6600475,,,,,,,,,,,,,, -207800,4.5230846,0.66701716,,,,,,,,,,,,,, -207900,4.4852567,0.6485734,,,,,,,,,,,,,, -208000,4.531042,0.6620137,,,,,,,,,,,,,, -208100,4.3223786,0.611975,,,,,,,,,,,,,, -208200,5.1214437,0.738813,,,,,,,,,,,,,, -208300,4.7059674,0.6501342,,,,,,,,,,,,,, -208400,4.242735,0.6213066,,,,,,,,,,,,,, -208500,4.395548,0.6652739,,,,,,,,,,,,,, -208600,4.821499,0.6026339,,,,,,,,,,,,,, -208700,4.8482485,0.62574923,,,,,,,,,,,,,, -208800,4.5594783,0.58326817,,,,,,,,,,,,,, -208892,,,0.960359513759613,0.1474143266677856,0.7548999786376953,1.055539608001709,50000.0,0.6303000450134277,1.818329811096192,10000.0,70952.31413340569,73493.35128498077,70952.31413340569,2526.792496442795,6.931980609893799,0.0 -208900,4.846645,0.6696366,,,,,,,,,,,,,, -209000,4.432084,0.5705925,,,,,,,,,,,,,, -209100,4.3774605,0.5997058,,,,,,,,,,,,,, -209200,4.5396523,0.61724836,,,,,,,,,,,,,, -209300,4.8510404,0.6231663,,,,,,,,,,,,,, -209400,5.256972,0.62835,,,,,,,,,,,,,, -209500,4.884725,0.6650569,,,,,,,,,,,,,, -209600,4.4741216,0.6273569,,,,,,,,,,,,,, -209700,4.983092,0.69238436,,,,,,,,,,,,,, -209800,4.1929793,0.6310233,,,,,,,,,,,,,, -209900,4.590361,0.6078646,,,,,,,,,,,,,, -210000,4.496586,0.6770459,,,,,,,,,,,,,, -210100,4.4619894,0.6374011,,,,,,,,,,,,,, -210200,4.416948,0.5996606,,,,,,,,,,,,,, -210300,4.5616236,0.60667735,,,,,,,,,,,,,, -210395,,,0.9596021771430968,0.1500817537307739,0.7551800012588501,1.0549089908599854,50000.0,0.6307000517845154,1.817723512649536,10000.0,71462.24423503876,74020.21932840347,71462.24423503876,2543.6113333702087,6.999336004257202,0.0 -210400,5.06838,0.6244587,,,,,,,,,,,,,, -210500,4.3920646,0.66543835,,,,,,,,,,,,,, -210600,4.697453,0.6429683,,,,,,,,,,,,,, -210700,4.360239,0.593108,,,,,,,,,,,,,, -210800,4.536836,0.61464304,,,,,,,,,,,,,, -210900,4.8581157,0.62786186,,,,,,,,,,,,,, -211000,4.8418,0.6689351,,,,,,,,,,,,,, -211100,4.490026,0.5870135,,,,,,,,,,,,,, -211200,4.8552003,0.6128012,,,,,,,,,,,,,, -211300,4.438215,0.6175437,,,,,,,,,,,,,, -211400,4.5642524,0.59552306,,,,,,,,,,,,,, -211500,4.3229384,0.6434891,,,,,,,,,,,,,, -211600,4.6134524,0.6543479,,,,,,,,,,,,,, -211700,4.692052,0.622923,,,,,,,,,,,,,, -211800,4.4421453,0.55322695,,,,,,,,,,,,,, -211898,,,0.9599011540412904,0.1484949290752411,0.7547799944877625,1.0549051761627195,50000.0,0.6303000450134277,1.817959427833557,10000.0,71972.2426841259,74546.93216991425,71972.2426841259,2560.212978601456,7.059044599533081,0.0 -211900,4.9495792,0.6240808,,,,,,,,,,,,,, -212000,4.4666123,0.6364896,,,,,,,,,,,,,, -212100,4.495202,0.62768364,,,,,,,,,,,,,, -212200,4.09429,0.5627135,,,,,,,,,,,,,, -212300,4.3639092,0.578588,,,,,,,,,,,,,, -212400,4.6272025,0.64594907,,,,,,,,,,,,,, -212500,4.314957,0.62472177,,,,,,,,,,,,,, -212600,4.6852527,0.66745836,,,,,,,,,,,,,, -212700,4.3212504,0.6411018,,,,,,,,,,,,,, -212800,4.853324,0.5982765,,,,,,,,,,,,,, -212900,4.418807,0.62767917,,,,,,,,,,,,,, -213000,4.434025,0.5808513,,,,,,,,,,,,,, -213100,4.447817,0.6353282,,,,,,,,,,,,,, -213200,4.948569,0.64687777,,,,,,,,,,,,,, -213300,4.0899997,0.537225,,,,,,,,,,,,,, -213400,4.4551425,0.5709901,,,,,,,,,,,,,, -213401,,,0.9609375,0.1480892598628997,0.7550999522209167,1.0552589893341064,50000.0,0.6299000382423401,1.8180344104766848,10000.0,72482.38831710815,75073.80241346359,72482.38831710815,2576.818477153778,7.12625527381897,0.0 -213500,4.622672,0.58354264,,,,,,,,,,,,,, -213600,4.520694,0.65138793,,,,,,,,,,,,,, -213700,4.4450936,0.53014946,,,,,,,,,,,,,, -213800,4.2801676,0.61752737,,,,,,,,,,,,,, -213900,4.636726,0.6271734,,,,,,,,,,,,,, -214000,4.5023155,0.60792565,,,,,,,,,,,,,, -214100,4.5746174,0.6437145,,,,,,,,,,,,,, -214200,4.4886346,0.6328895,,,,,,,,,,,,,, -214300,4.216073,0.5220087,,,,,,,,,,,,,, -214400,5.0164557,0.6175795,,,,,,,,,,,,,, -214500,4.750494,0.66119486,,,,,,,,,,,,,, -214600,4.678512,0.68718547,,,,,,,,,,,,,, -214700,4.634603,0.67696846,,,,,,,,,,,,,, -214800,4.0780344,0.55859464,,,,,,,,,,,,,, -214900,4.3603473,0.6332704,,,,,,,,,,,,,, -214904,,,0.9608976244926452,0.146467387676239,0.7546399831771851,1.054847002029419,50000.0,0.6300000548362732,1.8187536001205444,10000.0,72992.4620001316,75600.73349714279,72992.4620001316,2593.551813840866,7.196750164031982,0.0 -215000,4.9182196,0.6164415,,,,,,,,,,,,,, -215100,4.570638,0.6730424,,,,,,,,,,,,,, -215200,4.2397447,0.58289087,,,,,,,,,,,,,, -215300,4.574961,0.6304481,,,,,,,,,,,,,, -215400,4.842038,0.62871337,,,,,,,,,,,,,, -215500,4.8792577,0.70627695,,,,,,,,,,,,,, -215600,4.550683,0.5978552,,,,,,,,,,,,,, -215700,4.1549563,0.57731795,,,,,,,,,,,,,, -215800,4.51464,0.67220056,,,,,,,,,,,,,, -215900,4.6235304,0.6498543,,,,,,,,,,,,,, -216000,4.6517315,0.6429768,,,,,,,,,,,,,, -216100,4.482144,0.64406484,,,,,,,,,,,,,, -216200,4.803357,0.6697016,,,,,,,,,,,,,, -216300,4.336063,0.5911597,,,,,,,,,,,,,, -216400,4.115187,0.5575509,,,,,,,,,,,,,, -216408,,,0.9601004123687744,0.1451017260551452,0.7546799778938293,1.0557056665420532,50000.0,0.6304000020027161,1.8173750638961792,10000.0,73502.42083287239,76127.45818305016,73502.42083287239,2610.1900346279144,7.270415544509888,0.0 -216500,4.0470786,0.5485359,,,,,,,,,,,,,, -216600,5.12824,0.5915456,,,,,,,,,,,,,, -216700,4.6576858,0.6205931,,,,,,,,,,,,,, -216800,4.4744515,0.5709636,,,,,,,,,,,,,, -216900,4.627324,0.6912632,,,,,,,,,,,,,, -217000,4.358677,0.62238026,,,,,,,,,,,,,, -217100,4.01576,0.57679737,,,,,,,,,,,,,, -217200,5.050167,0.65270567,,,,,,,,,,,,,, -217300,4.6610384,0.7002867,,,,,,,,,,,,,, -217400,4.477974,0.56966287,,,,,,,,,,,,,, -217500,4.090937,0.595402,,,,,,,,,,,,,, -217600,4.2801533,0.59916675,,,,,,,,,,,,,, -217700,4.7975903,0.650661,,,,,,,,,,,,,, -217800,4.2265754,0.5778909,,,,,,,,,,,,,, -217900,4.3578153,0.6111232,,,,,,,,,,,,,, -217911,,,0.9603196382522584,0.1469211876392364,0.7546600103378296,1.0558857917785645,50000.0,0.6309000253677368,1.8198227882385247,10000.0,74012.63173341751,76654.59640598297,74012.63173341751,2626.9975488185883,7.337841749191284,0.0 -218000,4.8292136,0.7332443,,,,,,,,,,,,,, -218100,4.656845,0.65394485,,,,,,,,,,,,,, -218200,4.6798873,0.637953,,,,,,,,,,,,,, -218300,4.696273,0.60840875,,,,,,,,,,,,,, -218400,4.8497868,0.71154296,,,,,,,,,,,,,, -218500,4.8129354,0.6938681,,,,,,,,,,,,,, -218600,4.7914248,0.59262943,,,,,,,,,,,,,, -218700,4.5184007,0.6193211,,,,,,,,,,,,,, -218800,4.6220107,0.6233322,,,,,,,,,,,,,, -218900,4.865517,0.6473813,,,,,,,,,,,,,, -219000,4.8603196,0.689486,,,,,,,,,,,,,, -219100,4.9263377,0.6320424,,,,,,,,,,,,,, -219200,4.1375155,0.6265382,,,,,,,,,,,,,, -219300,4.3251014,0.5509549,,,,,,,,,,,,,, -219400,4.397197,0.57702196,,,,,,,,,,,,,, -219415,,,0.9604790806770324,0.1473052203655243,0.7550599575042725,1.0550016164779663,50000.0,0.6302000284194946,1.819336175918579,10000.0,74522.71133613586,77181.44177699089,74522.71133613586,2643.6508531570435,7.39803147315979,0.0 -219500,4.5920634,0.6488285,,,,,,,,,,,,,, -219600,4.502434,0.5728385,,,,,,,,,,,,,, -219700,4.584465,0.5973208,,,,,,,,,,,,,, -219800,4.845953,0.6691724,,,,,,,,,,,,,, -219900,4.5295324,0.6573667,,,,,,,,,,,,,, -220000,4.546956,0.6292046,,,,,,,,,,,,,, -220100,4.8590217,0.634138,,,,,,,,,,,,,, -220200,4.4261346,0.5899193,,,,,,,,,,,,,, -220300,4.491733,0.57135516,,,,,,,,,,,,,, -220400,4.512108,0.62894815,,,,,,,,,,,,,, -220500,4.3352013,0.5925597,,,,,,,,,,,,,, -220600,5.2590446,0.5530175,,,,,,,,,,,,,, -220700,4.4745436,0.55946654,,,,,,,,,,,,,, -220800,4.507504,0.59704363,,,,,,,,,,,,,, -220900,4.3183937,0.6271514,,,,,,,,,,,,,, -220918,,,0.960718274116516,0.1482438445091247,0.7550399899482727,1.0545982122421265,50000.0,0.6307000517845154,1.8186966180801392,10000.0,75032.91843128204,77708.87258601189,75032.91843128204,2660.755885362625,7.464325189590454,0.0 -221000,4.444401,0.6183752,,,,,,,,,,,,,, -221100,4.3097496,0.56235707,,,,,,,,,,,,,, -221200,4.3018894,0.60979843,,,,,,,,,,,,,, -221300,4.466582,0.63811547,,,,,,,,,,,,,, -221400,4.6519456,0.6102911,,,,,,,,,,,,,, -221500,4.6653743,0.61287653,,,,,,,,,,,,,, -221600,4.4722366,0.5850148,,,,,,,,,,,,,, -221700,4.5210233,0.61356705,,,,,,,,,,,,,, -221800,4.3457704,0.5510717,,,,,,,,,,,,,, -221900,4.3292274,0.6173575,,,,,,,,,,,,,, -222000,4.3505087,0.59768873,,,,,,,,,,,,,, -222100,4.8428373,0.6111455,,,,,,,,,,,,,, -222200,4.9453154,0.7309585,,,,,,,,,,,,,, -222300,4.2714143,0.6567895,,,,,,,,,,,,,, -222400,4.6426454,0.6750305,,,,,,,,,,,,,, -222422,,,0.960558831691742,0.1466655135154724,0.754859983921051,1.0557494163513184,50000.0,0.6304000020027161,1.819899082183838,10000.0,75543.07302880287,78235.87724661827,75543.07302880287,2677.481550216675,7.535091638565063,0.0 -222500,4.270459,0.6631026,,,,,,,,,,,,,, -222600,4.886746,0.62112427,,,,,,,,,,,,,, -222700,4.0809007,0.59640056,,,,,,,,,,,,,, -222800,4.2486134,0.59410155,,,,,,,,,,,,,, -222900,4.5921316,0.61882186,,,,,,,,,,,,,, -223000,4.3358335,0.61743754,,,,,,,,,,,,,, -223100,4.3023596,0.60340965,,,,,,,,,,,,,, -223200,4.712934,0.6732503,,,,,,,,,,,,,, -223300,4.44086,0.6356478,,,,,,,,,,,,,, -223400,4.5539556,0.5897373,,,,,,,,,,,,,, -223500,4.3621826,0.60582995,,,,,,,,,,,,,, -223600,4.485094,0.6520913,,,,,,,,,,,,,, -223700,4.448267,0.65965474,,,,,,,,,,,,,, -223800,4.5809855,0.68266207,,,,,,,,,,,,,, -223900,4.4747176,0.6355927,,,,,,,,,,,,,, -223925,,,0.9612165093421936,0.1431528031826019,0.7547399997711182,1.0563209056854248,50000.0,0.629300057888031,1.8206843137741089,10000.0,76053.1585996151,78762.75082278252,76053.1585996151,2694.1498594284058,7.601906299591064,0.0 -224000,4.5136476,0.54639375,,,,,,,,,,,,,, -224100,4.8004956,0.6503603,,,,,,,,,,,,,, -224200,4.3448396,0.62797403,,,,,,,,,,,,,, -224300,4.6120377,0.6293047,,,,,,,,,,,,,, -224400,4.8634076,0.7264889,,,,,,,,,,,,,, -224500,4.5646687,0.5864496,,,,,,,,,,,,,, -224600,4.6136594,0.5794081,,,,,,,,,,,,,, -224700,4.07061,0.57297164,,,,,,,,,,,,,, -224800,4.3933964,0.57092535,,,,,,,,,,,,,, -224900,4.3527703,0.6113236,,,,,,,,,,,,,, -225000,4.9916987,0.65034175,,,,,,,,,,,,,, -225100,5.426241,0.75779486,,,,,,,,,,,,,, -225200,4.6938887,0.63252586,,,,,,,,,,,,,, -225300,4.8139696,0.604224,,,,,,,,,,,,,, -225400,4.2276487,0.6345874,,,,,,,,,,,,,, -225428,,,0.9616350531578064,0.1421045213937759,0.7547999620437622,1.0561600923538208,50000.0,0.631100058555603,1.8175678253173828,10000.0,76563.16371178627,79289.55604577065,76563.16371178627,2710.828951358795,7.669835329055786,0.0 -225500,4.55834,0.6394553,,,,,,,,,,,,,, -225600,4.546347,0.6004008,,,,,,,,,,,,,, -225700,4.811129,0.65096533,,,,,,,,,,,,,, -225800,4.7095504,0.7205734,,,,,,,,,,,,,, -225900,4.416846,0.69256747,,,,,,,,,,,,,, -226000,4.7064533,0.6222113,,,,,,,,,,,,,, -226100,4.8016796,0.6125415,,,,,,,,,,,,,, -226200,4.6321464,0.62825006,,,,,,,,,,,,,, -226300,4.7107277,0.64794475,,,,,,,,,,,,,, -226400,4.832635,0.6649705,,,,,,,,,,,,,, -226500,4.3703294,0.63097715,,,,,,,,,,,,,, -226600,4.745461,0.6385497,,,,,,,,,,,,,, -226700,4.352709,0.65609187,,,,,,,,,,,,,, -226800,4.52838,0.583742,,,,,,,,,,,,,, -226900,4.2487526,0.6047943,,,,,,,,,,,,,, -226931,,,0.9616150856018066,0.1458224952220916,0.7543599605560303,1.0568692684173584,50000.0,0.6301000118255615,1.819921970367432,10000.0,77073.34233403206,79816.50515580177,77073.34233403206,2727.477970123291,7.7380051612854,0.0 -227000,4.598678,0.6223426,,,,,,,,,,,,,, -227100,5.110322,0.6620531,,,,,,,,,,,,,, -227200,5.0559077,0.5938262,,,,,,,,,,,,,, -227300,4.5218334,0.7071202,,,,,,,,,,,,,, -227400,4.038864,0.556783,,,,,,,,,,,,,, -227500,4.5689044,0.6339556,,,,,,,,,,,,,, -227600,4.5505238,0.69090486,,,,,,,,,,,,,, -227700,5.2548833,0.6373999,,,,,,,,,,,,,, -227800,4.740312,0.65633774,,,,,,,,,,,,,, -227900,4.636908,0.6359452,,,,,,,,,,,,,, -228000,4.561231,0.6036891,,,,,,,,,,,,,, -228100,5.2167635,0.6222138,,,,,,,,,,,,,, -228200,4.5362754,0.5970606,,,,,,,,,,,,,, -228300,4.8793793,0.6849732,,,,,,,,,,,,,, -228400,4.447444,0.54943174,,,,,,,,,,,,,, -228431,,,0.9591238498687744,0.1471176445484161,0.7551800012588501,1.0547051429748535,50000.0,0.6300000548362732,1.8176454305648804,10000.0,77582.4801542759,80343.97794318199,77582.4801542759,2744.826174020767,8.672579765319824,0.0 -228500,4.524983,0.5922112,,,,,,,,,,,,,, -228600,4.4181366,0.6501584,,,,,,,,,,,,,, -228700,4.4946847,0.6039585,,,,,,,,,,,,,, -228800,4.590641,0.6404988,,,,,,,,,,,,,, -228900,4.3633037,0.57566005,,,,,,,,,,,,,, -229000,5.034459,0.6627809,,,,,,,,,,,,,, -229100,4.9385705,0.68852377,,,,,,,,,,,,,, -229200,5.2057657,0.6699907,,,,,,,,,,,,,, -229300,4.690002,0.61874354,,,,,,,,,,,,,, -229400,4.792167,0.69470274,,,,,,,,,,,,,, -229500,4.5130644,0.6943891,,,,,,,,,,,,,, -229600,4.3396573,0.61377233,,,,,,,,,,,,,, -229700,4.7601986,0.64659,,,,,,,,,,,,,, -229800,4.9818397,0.64529085,,,,,,,,,,,,,, -229900,4.513069,0.6513221,,,,,,,,,,,,,, -229934,,,0.9608976244926452,0.1465269774198532,0.7546399831771851,1.0559604167938232,50000.0,0.6292000412940979,1.818885087966919,10000.0,78092.46252202988,80870.81914234161,78092.46252202988,2761.564717531204,8.739872694015503,0.0 -230000,4.556103,0.6380092,,,,,,,,,,,,,, -230100,4.556564,0.65064985,,,,,,,,,,,,,, -230200,4.5576715,0.6123285,,,,,,,,,,,,,, -230300,4.620896,0.5954629,,,,,,,,,,,,,, -230400,4.934392,0.6308777,,,,,,,,,,,,,, -230500,4.334075,0.6262457,,,,,,,,,,,,,, -230600,4.46843,0.61152244,,,,,,,,,,,,,, -230700,4.540753,0.59588945,,,,,,,,,,,,,, -230800,4.418082,0.5562329,,,,,,,,,,,,,, -230900,5.0123143,0.6197923,,,,,,,,,,,,,, -231000,4.9925637,0.67488354,,,,,,,,,,,,,, -231100,4.7905846,0.69763833,,,,,,,,,,,,,, -231200,4.428906,0.6333864,,,,,,,,,,,,,, -231300,4.1561313,0.59734434,,,,,,,,,,,,,, -231400,4.2161884,0.60455734,,,,,,,,,,,,,, -231438,,,0.9616549611091614,0.1448116451501846,0.7546799778938293,1.0556319952011108,50000.0,0.6313000321388245,1.8184458017349243,10000.0,78602.5905714035,81397.9317638874,78602.5905714035,2778.42622590065,8.809640645980835,0.0 -231500,4.4862075,0.55550987,,,,,,,,,,,,,, -231600,4.3023267,0.63176316,,,,,,,,,,,,,, -231700,4.578469,0.58469206,,,,,,,,,,,,,, -231800,5.3986444,0.59646547,,,,,,,,,,,,,, -231900,4.437118,0.61743504,,,,,,,,,,,,,, -232000,4.6417875,0.6036718,,,,,,,,,,,,,, -232100,4.607439,0.6329048,,,,,,,,,,,,,, -232200,4.1698556,0.5636798,,,,,,,,,,,,,, -232300,4.2932477,0.5262508,,,,,,,,,,,,,, -232400,4.2055016,0.5591827,,,,,,,,,,,,,, -232500,4.8596406,0.69016236,,,,,,,,,,,,,, -232600,4.5904584,0.6911671,,,,,,,,,,,,,, -232700,4.117414,0.5641653,,,,,,,,,,,,,, -232800,4.662381,0.667459,,,,,,,,,,,,,, -232900,4.8615427,0.6445029,,,,,,,,,,,,,, -232941,,,0.9609175324440002,0.1463182866573333,0.7546600103378296,1.056311011314392,50000.0,0.6312000155448914,1.81929874420166,10000.0,79112.5441596508,81924.59214305878,79112.5441596508,2795.0093553066254,8.879449844360352,0.0 -233000,4.7776093,0.61904573,,,,,,,,,,,,,, -233100,4.512397,0.6187302,,,,,,,,,,,,,, -233200,4.6557345,0.5967113,,,,,,,,,,,,,, -233300,4.5007744,0.5792643,,,,,,,,,,,,,, -233400,4.4777184,0.62405884,,,,,,,,,,,,,, -233500,4.9674,0.63423103,,,,,,,,,,,,,, -233600,4.6784906,0.64609224,,,,,,,,,,,,,, -233700,4.4355197,0.61235166,,,,,,,,,,,,,, -233800,4.1631556,0.58183795,,,,,,,,,,,,,, -233900,4.368229,0.62240744,,,,,,,,,,,,,, -234000,4.645495,0.6076487,,,,,,,,,,,,,, -234100,4.5660076,0.69227934,,,,,,,,,,,,,, -234200,4.88423,0.70148766,,,,,,,,,,,,,, -234300,5.0129743,0.65479285,,,,,,,,,,,,,, -234400,4.356862,0.6839307,,,,,,,,,,,,,, -234445,,,0.961316168308258,0.1448947787284851,0.7548799514770508,1.055671215057373,50000.0,0.6300000548362732,1.819419264793396,10000.0,79622.69375491142,82451.54144382477,79622.69375491142,2811.688470363617,8.947274923324585,0.0 -234500,4.455293,0.59576917,,,,,,,,,,,,,, -234600,4.64099,0.6130111,,,,,,,,,,,,,, -234700,4.5336714,0.6172856,,,,,,,,,,,,,, -234800,4.678274,0.65718424,,,,,,,,,,,,,, -234900,4.185665,0.5385774,,,,,,,,,,,,,, -235000,4.629076,0.59017396,,,,,,,,,,,,,, -235100,4.4402995,0.5582868,,,,,,,,,,,,,, -235200,4.5968714,0.643117,,,,,,,,,,,,,, -235300,4.723831,0.6755792,,,,,,,,,,,,,, -235400,4.088648,0.6377004,,,,,,,,,,,,,, -235500,4.288643,0.6204878,,,,,,,,,,,,,, -235600,4.7609253,0.67133176,,,,,,,,,,,,,, -235700,4.5362015,0.598839,,,,,,,,,,,,,, -235800,4.574755,0.6277254,,,,,,,,,,,,,, -235900,4.4927335,0.61781263,,,,,,,,,,,,,, -235948,,,0.9607979655265808,0.1469217389822006,0.754859983921051,1.054784059524536,50000.0,0.6304000020027161,1.8185477256774905,10000.0,80132.81121706963,82978.43870472908,80132.81121706963,2828.3450248241425,9.016587018966677,0.0 -236000,4.533451,0.631655,,,,,,,,,,,,,, -236100,4.3425455,0.5668766,,,,,,,,,,,,,, -236200,4.540193,0.65475047,,,,,,,,,,,,,, -236300,4.0584073,0.6410712,,,,,,,,,,,,,, -236400,4.2453876,0.58377856,,,,,,,,,,,,,, -236500,4.631583,0.64821815,,,,,,,,,,,,,, -236600,4.6518993,0.5948051,,,,,,,,,,,,,, -236700,4.7347116,0.6628589,,,,,,,,,,,,,, -236800,4.0802994,0.56414735,,,,,,,,,,,,,, -236900,4.3675156,0.66160786,,,,,,,,,,,,,, -237000,3.9512327,0.5322435,,,,,,,,,,,,,, -237100,4.87599,0.65844387,,,,,,,,,,,,,, -237200,4.9351926,0.6947006,,,,,,,,,,,,,, -237300,6.130764,0.67235917,,,,,,,,,,,,,, -237400,4.7033434,0.7173109,,,,,,,,,,,,,, -237452,,,0.9615154266357422,0.1444291174411773,0.7550999522209167,1.0559808015823364,50000.0,0.6301000118255615,1.819985270500183,10000.0,80643.00874638557,83505.44635510445,80643.00874638557,2845.0322086811066,9.085235595703123,0.0 -237500,4.824873,0.7126275,,,,,,,,,,,,,, -237600,4.635235,0.6854992,,,,,,,,,,,,,, -237700,4.618545,0.5895499,,,,,,,,,,,,,, -237800,4.348556,0.5493577,,,,,,,,,,,,,, -237900,4.53261,0.6120558,,,,,,,,,,,,,, -238000,4.2361774,0.58253103,,,,,,,,,,,,,, -238100,4.50185,0.60084605,,,,,,,,,,,,,, -238200,4.2445445,0.6010359,,,,,,,,,,,,,, -238300,4.568321,0.70402265,,,,,,,,,,,,,, -238400,4.3418436,0.6127731,,,,,,,,,,,,,, -238500,4.754669,0.68207955,,,,,,,,,,,,,, -238600,4.330005,0.5710279,,,,,,,,,,,,,, -238700,4.495814,0.6223423,,,,,,,,,,,,,, -238800,4.2180085,0.5654904,,,,,,,,,,,,,, -238900,4.699311,0.64260495,,,,,,,,,,,,,, -238955,,,0.9600805044174194,0.1469574570655822,0.7549200057983398,1.055910587310791,50000.0,0.631100058555603,1.8194414377212524,10000.0,81153.17875504494,84032.39768075943,81153.17875504494,2861.686569929123,9.158106327056885,0.0 -239000,4.2107186,0.58899856,,,,,,,,,,,,,, -239100,4.377853,0.57161564,,,,,,,,,,,,,, -239200,4.7813635,0.6734291,,,,,,,,,,,,,, -239300,4.744276,0.6616032,,,,,,,,,,,,,, -239400,4.8454585,0.6350786,,,,,,,,,,,,,, -239500,4.9157276,0.58859146,,,,,,,,,,,,,, -239600,5.034118,0.6378436,,,,,,,,,,,,,, -239700,4.4364305,0.6266659,,,,,,,,,,,,,, -239800,4.957208,0.62527984,,,,,,,,,,,,,, -239900,4.3892407,0.6377628,,,,,,,,,,,,,, -240000,4.084123,0.5994213,,,,,,,,,,,,,, -240100,4.8714423,0.609586,,,,,,,,,,,,,, -240200,4.174089,0.57078874,,,,,,,,,,,,,, -240300,4.421714,0.5921739,,,,,,,,,,,,,, -240400,4.2782583,0.5638226,,,,,,,,,,,,,, -240459,,,0.9600406289100648,0.1476673632860183,0.7547399997711182,1.0556063652038574,50000.0,0.6307000517845154,1.8203246593475344,10000.0,81663.18556928635,84559.16914653778,81663.18556928635,2878.328460454941,9.227638959884644,0.0 -240500,4.5549893,0.59792525,,,,,,,,,,,,,, -240600,4.824353,0.6397881,,,,,,,,,,,,,, -240700,4.510099,0.6372234,,,,,,,,,,,,,, -240800,4.8087606,0.5990955,,,,,,,,,,,,,, -240900,4.647966,0.6508632,,,,,,,,,,,,,, -241000,4.222625,0.5280026,,,,,,,,,,,,,, -241100,4.5452085,0.68835795,,,,,,,,,,,,,, -241200,4.3751917,0.5956895,,,,,,,,,,,,,, -241300,4.9604735,0.65908986,,,,,,,,,,,,,, -241400,4.992564,0.696278,,,,,,,,,,,,,, -241500,4.555334,0.5538093,,,,,,,,,,,,,, -241600,4.5297523,0.62857294,,,,,,,,,,,,,, -241700,4.165694,0.5249791,,,,,,,,,,,,,, -241800,4.6461854,0.67282164,,,,,,,,,,,,,, -241900,4.444906,0.60100853,,,,,,,,,,,,,, -241961,,,0.961316168308258,0.1453227698802948,0.7550599575042725,1.0559238195419312,50000.0,0.6307000517845154,1.8199002742767327,10000.0,82173.29926466942,85086.36740994453,82173.29926466942,2895.289322376251,9.2988703250885,0.0 -242000,4.7046347,0.6217312,,,,,,,,,,,,,, -242100,4.2977543,0.5972431,,,,,,,,,,,,,, -242200,4.399974,0.6917204,,,,,,,,,,,,,, -242300,4.6304555,0.66689765,,,,,,,,,,,,,, -242400,4.577929,0.62641454,,,,,,,,,,,,,, -242500,4.712551,0.6465136,,,,,,,,,,,,,, -242600,4.6662197,0.6694325,,,,,,,,,,,,,, -242700,4.313624,0.52168167,,,,,,,,,,,,,, -242800,4.873144,0.65587354,,,,,,,,,,,,,, -242900,4.296283,0.63852733,,,,,,,,,,,,,, -243000,4.7511697,0.64066267,,,,,,,,,,,,,, -243100,4.5770035,0.62523437,,,,,,,,,,,,,, -243200,4.4101944,0.5807067,,,,,,,,,,,,,, -243300,4.431615,0.590026,,,,,,,,,,,,,, -243400,4.9348297,0.6313247,,,,,,,,,,,,,, -243464,,,0.9612762928009032,0.1468772292137146,0.7552199959754944,1.054276943206787,50000.0,0.629800021648407,1.8192369937896729,10000.0,82683.24376130104,85612.96273779869,82683.24376130104,2911.817486524582,9.36837124824524,0.0 -243500,4.962791,0.66950196,,,,,,,,,,,,,, -243600,5.099065,0.6543493,,,,,,,,,,,,,, -243700,4.993127,0.6229811,,,,,,,,,,,,,, -243800,4.485688,0.6282597,,,,,,,,,,,,,, -243900,4.371481,0.5986489,,,,,,,,,,,,,, -244000,4.5464725,0.6023066,,,,,,,,,,,,,, -244100,4.538159,0.62984324,,,,,,,,,,,,,, -244200,4.5607963,0.63338184,,,,,,,,,,,,,, -244300,4.168958,0.6041671,,,,,,,,,,,,,, -244400,4.921389,0.66780543,,,,,,,,,,,,,, -244500,4.770162,0.689029,,,,,,,,,,,,,, -244600,4.1943126,0.5700027,,,,,,,,,,,,,, -244700,5.438496,0.7084332,,,,,,,,,,,,,, -244800,4.0891204,0.5349739,,,,,,,,,,,,,, -244900,4.2548447,0.6678517,,,,,,,,,,,,,, -244966,,,0.9592832922935486,0.1474329233169555,0.7551599740982056,1.0548502206802368,50000.0,0.6302000284194946,1.8175748586654663,10000.0,83193.1289358139,86139.756534338,83193.1289358139,2928.6007051467896,9.440913200378418,0.0 -245000,4.1772513,0.5549262,,,,,,,,,,,,,, -245100,4.5379586,0.56235313,,,,,,,,,,,,,, -245200,4.6633964,0.56128603,,,,,,,,,,,,,, -245300,4.5685844,0.6276852,,,,,,,,,,,,,, -245400,5.0264363,0.6481079,,,,,,,,,,,,,, -245500,4.5313945,0.6678859,,,,,,,,,,,,,, -245600,4.243262,0.58502287,,,,,,,,,,,,,, -245700,4.54677,0.6248001,,,,,,,,,,,,,, -245800,4.810343,0.59721166,,,,,,,,,,,,,, -245900,4.1998134,0.57153505,,,,,,,,,,,,,, -246000,4.567997,0.7025907,,,,,,,,,,,,,, -246100,4.567004,0.641078,,,,,,,,,,,,,, -246200,4.4523673,0.565323,,,,,,,,,,,,,, -246300,4.8786864,0.6370191,,,,,,,,,,,,,, -246400,4.7592864,0.6290132,,,,,,,,,,,,,, -246469,,,0.9610969424247742,0.1457830071449279,0.7546199560165405,1.0571213960647583,50000.0,0.6319000124931335,1.820106506347656,10000.0,83703.21215891838,86666.73375701904,83703.21215891838,2945.365699291229,9.517107009887695,0.0 -246500,4.570635,0.695432,,,,,,,,,,,,,, -246600,5.0060124,0.67091036,,,,,,,,,,,,,, -246700,4.9332356,0.60797715,,,,,,,,,,,,,, -246800,4.8042226,0.62769336,,,,,,,,,,,,,, -246900,4.5828876,0.5671116,,,,,,,,,,,,,, -247000,4.2981563,0.6236333,,,,,,,,,,,,,, -247100,4.811504,0.66214097,,,,,,,,,,,,,, -247200,4.5007286,0.6222176,,,,,,,,,,,,,, -247300,5.0283113,0.6634019,,,,,,,,,,,,,, -247400,4.841413,0.6527837,,,,,,,,,,,,,, -247500,4.624152,0.6940411,,,,,,,,,,,,,, -247600,4.748119,0.65321475,,,,,,,,,,,,,, -247700,4.544793,0.60305876,,,,,,,,,,,,,, -247800,4.2260623,0.5991323,,,,,,,,,,,,,, -247900,4.8385673,0.63736564,,,,,,,,,,,,,, -247972,,,0.9592633843421936,0.1491094678640365,0.7548399567604065,1.0558416843414309,50000.0,0.6305000185966492,1.8187205791473389,10000.0,84213.20818805695,87193.41122412682,84213.20818805695,2961.9213037490845,9.590791940689089,0.0 -248000,4.425165,0.5758693,,,,,,,,,,,,,, -248100,4.683919,0.6246067,,,,,,,,,,,,,, -248200,4.27887,0.5740732,,,,,,,,,,,,,, -248300,4.5911326,0.67249805,,,,,,,,,,,,,, -248400,4.3034954,0.63034266,,,,,,,,,,,,,, -248500,4.2967696,0.5373712,,,,,,,,,,,,,, -248600,4.045967,0.49614102,,,,,,,,,,,,,, -248700,4.686867,0.6207304,,,,,,,,,,,,,, -248800,4.1268663,0.5614235,,,,,,,,,,,,,, -248900,4.557385,0.58252615,,,,,,,,,,,,,, -249000,4.38448,0.63666725,,,,,,,,,,,,,, -249100,4.4748607,0.6019818,,,,,,,,,,,,,, -249200,4.8577814,0.72737014,,,,,,,,,,,,,, -249300,4.55907,0.5838998,,,,,,,,,,,,,, -249400,5.175416,0.64716554,,,,,,,,,,,,,, -249475,,,0.960379421710968,0.1473689973354339,0.7548799514770508,1.05515718460083,50000.0,0.6299000382423401,1.820226669311524,10000.0,84723.22751140594,87720.24756479263,84723.22751140594,2978.6110417842865,9.664908409118652,0.0 -249500,4.309726,0.6253177,,,,,,,,,,,,,, -249600,4.48873,0.6297668,,,,,,,,,,,,,, -249700,4.278827,0.55122006,,,,,,,,,,,,,, -249800,4.229406,0.62207174,,,,,,,,,,,,,, -249900,4.6157045,0.5983608,,,,,,,,,,,,,, -250000,3.9199963,0.50499505,,,,,,,,,,,,,, -250100,4.7232203,0.6220696,,,,,,,,,,,,,, -250200,4.2751813,0.5532416,,,,,,,,,,,,,, -250300,4.8457217,0.66707623,,,,,,,,,,,,,, -250400,3.9855533,0.5530562,,,,,,,,,,,,,, -250500,4.5471854,0.6677196,,,,,,,,,,,,,, -250600,4.4620624,0.6075408,,,,,,,,,,,,,, -250700,4.101652,0.5734722,,,,,,,,,,,,,, -250800,4.4396024,0.6188859,,,,,,,,,,,,,, -250900,4.727575,0.58812046,,,,,,,,,,,,,, -250978,,,0.9597018361091614,0.1512549221515655,0.7553199529647827,1.055238127708435,50000.0,0.6300000548362732,1.8185988664627075,10000.0,85233.19104790688,88247.156447649,85233.19104790688,2995.421051502228,9.74666666984558,0.0 -251000,4.366495,0.6303606,,,,,,,,,,,,,, -251100,4.2787275,0.60909665,,,,,,,,,,,,,, -251200,4.5656524,0.65761095,,,,,,,,,,,,,, -251300,4.6968694,0.55980414,,,,,,,,,,,,,, -251400,4.572525,0.6346041,,,,,,,,,,,,,, -251500,4.3566394,0.6709713,,,,,,,,,,,,,, -251600,4.157314,0.54916525,,,,,,,,,,,,,, -251700,4.112456,0.6422365,,,,,,,,,,,,,, -251800,4.801655,0.653569,,,,,,,,,,,,,, -251900,4.446861,0.61121196,,,,,,,,,,,,,, -252000,4.562573,0.58951527,,,,,,,,,,,,,, -252100,4.646165,0.6043131,,,,,,,,,,,,,, -252200,4.773222,0.68013626,,,,,,,,,,,,,, -252300,4.85258,0.6168351,,,,,,,,,,,,,, -252400,4.8248076,0.6680263,,,,,,,,,,,,,, -252481,,,0.9615154266357422,0.1466091573238372,0.7547799944877625,1.0560367107391355,50000.0,0.6306000351905823,1.818914532661438,10000.0,85743.23414087296,88773.97137570381,85743.23414087296,3012.068596839905,9.817442417144775,0.0 -252500,4.487349,0.6373032,,,,,,,,,,,,,, -252600,4.848591,0.5852717,,,,,,,,,,,,,, -252700,4.3529963,0.60872257,,,,,,,,,,,,,, -252800,4.025001,0.606058,,,,,,,,,,,,,, -252900,5.267694,0.6881919,,,,,,,,,,,,,, -253000,4.5250626,0.64525396,,,,,,,,,,,,,, -253100,4.207703,0.59770274,,,,,,,,,,,,,, -253200,4.430555,0.5982965,,,,,,,,,,,,,, -253300,4.148341,0.5885165,,,,,,,,,,,,,, -253400,4.66776,0.60758305,,,,,,,,,,,,,, -253500,4.85708,0.6087775,,,,,,,,,,,,,, -253600,4.6931176,0.6154963,,,,,,,,,,,,,, -253700,4.727656,0.6298989,,,,,,,,,,,,,, -253800,4.5573306,0.63048613,,,,,,,,,,,,,, -253900,4.5640874,0.6208744,,,,,,,,,,,,,, -253984,,,0.9611168503761292,0.1454275101423263,0.7546799778938293,1.0555272102355957,50000.0,0.6312000155448914,1.8176671266555784,10000.0,86253.38255238533,89300.87192893028,86253.38255238533,3028.657395362854,9.926298141479492,0.0 -254000,4.7134295,0.5982732,,,,,,,,,,,,,, -254100,4.6531973,0.5510918,,,,,,,,,,,,,, -254200,4.137948,0.607226,,,,,,,,,,,,,, -254300,4.916045,0.62505454,,,,,,,,,,,,,, -254400,4.8098607,0.6286521,,,,,,,,,,,,,, -254500,4.4842873,0.6435931,,,,,,,,,,,,,, -254600,4.686002,0.63054764,,,,,,,,,,,,,, -254700,4.137619,0.5463696,,,,,,,,,,,,,, -254800,4.553942,0.6378953,,,,,,,,,,,,,, -254900,4.84432,0.64082885,,,,,,,,,,,,,, -255000,4.740616,0.59168714,,,,,,,,,,,,,, -255100,4.378804,0.5796092,,,,,,,,,,,,,, -255200,4.0887337,0.5476762,,,,,,,,,,,,,, -255300,4.3082833,0.5860859,,,,,,,,,,,,,, -255400,4.444745,0.6305635,,,,,,,,,,,,,, -255488,,,0.9581273794174194,0.1509248316287994,0.7550199627876282,1.0555461645126345,50000.0,0.6297000050544739,1.818804144859314,10000.0,86763.52758550644,89827.80944180489,86763.52758550644,3045.324955224991,9.99882435798645,0.0 -255500,4.2650213,0.5441462,,,,,,,,,,,,,, -255600,5.2165737,0.6274351,,,,,,,,,,,,,, -255700,4.4373617,0.61507624,,,,,,,,,,,,,, -255800,4.6592093,0.63869,,,,,,,,,,,,,, -255900,4.353531,0.6380309,,,,,,,,,,,,,, -256000,4.339452,0.5831909,,,,,,,,,,,,,, -256100,4.508395,0.652649,,,,,,,,,,,,,, -256200,4.5347204,0.7163084,,,,,,,,,,,,,, -256300,4.69999,0.70264757,,,,,,,,,,,,,, -256400,4.4652143,0.6185999,,,,,,,,,,,,,, -256500,4.6680346,0.69336164,,,,,,,,,,,,,, -256600,4.485945,0.61183715,,,,,,,,,,,,,, -256700,4.9987965,0.62976635,,,,,,,,,,,,,, -256800,4.291041,0.6264666,,,,,,,,,,,,,, -256900,4.3639774,0.59098,,,,,,,,,,,,,, -256991,,,0.9624322056770324,0.1416583657264709,0.7550199627876282,1.0557940006256104,50000.0,0.6306000351905823,1.8188364505767824,10000.0,87273.6397869587,90354.68846488,87273.6397869587,3061.9612522125244,10.076414108276367,0.0 -257000,4.2968254,0.58669436,,,,,,,,,,,,,, -257100,4.483092,0.70010304,,,,,,,,,,,,,, -257200,4.396215,0.54204303,,,,,,,,,,,,,, -257300,4.6359224,0.6437843,,,,,,,,,,,,,, -257400,4.3698673,0.6256032,,,,,,,,,,,,,, -257500,4.6304297,0.6212368,,,,,,,,,,,,,, -257600,4.8578396,0.6380627,,,,,,,,,,,,,, -257700,4.945951,0.63382906,,,,,,,,,,,,,, -257800,5.50204,0.64132446,,,,,,,,,,,,,, -257900,4.738562,0.64509594,,,,,,,,,,,,,, -258000,4.381515,0.6144603,,,,,,,,,,,,,, -258100,4.5863175,0.63569623,,,,,,,,,,,,,, -258200,4.65543,0.6887486,,,,,,,,,,,,,, -258300,4.672913,0.55218035,,,,,,,,,,,,,, -258400,4.393299,0.56594145,,,,,,,,,,,,,, -258494,,,0.9602399468421936,0.1474353671073913,0.7547399997711182,1.0550732612609863,50000.0,0.6300000548362732,1.818835020065308,10000.0,87783.56503915787,90881.38066411018,87783.56503915787,3078.59850025177,10.152153968811035,0.0 -258500,4.27596,0.6069577,,,,,,,,,,,,,, -258600,4.4499855,0.61611927,,,,,,,,,,,,,, -258700,3.9297187,0.5152648,,,,,,,,,,,,,, -258800,4.71099,0.6740055,,,,,,,,,,,,,, -258900,4.674829,0.61696917,,,,,,,,,,,,,, -259000,4.5972786,0.6002136,,,,,,,,,,,,,, -259100,4.3482614,0.6343058,,,,,,,,,,,,,, -259200,4.5858493,0.5882729,,,,,,,,,,,,,, -259300,4.326679,0.5873234,,,,,,,,,,,,,, -259400,4.455361,0.5662225,,,,,,,,,,,,,, -259500,4.8243837,0.6345751,,,,,,,,,,,,,, -259600,5.607529,0.6357538,,,,,,,,,,,,,, -259700,4.146115,0.5857607,,,,,,,,,,,,,, -259800,5.1086392,0.63049287,,,,,,,,,,,,,, -259900,4.301036,0.5676484,,,,,,,,,,,,,, -259997,,,0.961316168308258,0.146479919552803,0.7549399733543396,1.0550395250320437,50000.0,0.6309000253677368,1.81663715839386,10000.0,88293.56649446487,91408.12760949136,88293.56649446487,3095.220307826996,10.222743272781372,0.0 -260000,4.7818174,0.67263585,,,,,,,,,,,,,, -260100,4.111416,0.5978156,,,,,,,,,,,,,, -260200,4.1446524,0.58551615,,,,,,,,,,,,,, -260300,4.4253025,0.6187333,,,,,,,,,,,,,, -260400,4.642808,0.6154227,,,,,,,,,,,,,, -260500,4.781334,0.667137,,,,,,,,,,,,,, -260600,4.957957,0.6626243,,,,,,,,,,,,,, -260700,4.460141,0.60046583,,,,,,,,,,,,,, -260800,4.736581,0.57704,,,,,,,,,,,,,, -260900,4.5082,0.58751374,,,,,,,,,,,,,, -261000,4.391472,0.59209144,,,,,,,,,,,,,, -261100,4.1329155,0.5797938,,,,,,,,,,,,,, -261200,4.6430736,0.61456144,,,,,,,,,,,,,, -261300,4.452021,0.6669093,,,,,,,,,,,,,, -261400,4.8207316,0.67387354,,,,,,,,,,,,,, -261500,,,0.9601402878761292,0.1482132524251938,0.7545199990272522,1.0562045574188232,50000.0,0.6307000517845154,1.820395946502685,10000.0,88803.53046774864,91934.89223361015,88803.53046774864,3111.8932209014893,10.29693603515625,0.0 -261500,4.489765,0.5456593,,,,,,,,,,,,,, -261600,4.1207056,0.65320045,,,,,,,,,,,,,, -261700,3.9943905,0.6005456,,,,,,,,,,,,,, -261800,4.216407,0.6325641,,,,,,,,,,,,,, -261900,4.610468,0.5905355,,,,,,,,,,,,,, -262000,4.4578524,0.53226376,,,,,,,,,,,,,, -262100,4.669999,0.6340691,,,,,,,,,,,,,, -262200,4.2528453,0.56295663,,,,,,,,,,,,,, -262300,4.6429358,0.5950012,,,,,,,,,,,,,, -262400,4.7883887,0.63674825,,,,,,,,,,,,,, -262500,4.688604,0.69660896,,,,,,,,,,,,,, -262600,4.354911,0.6222793,,,,,,,,,,,,,, -262700,4.7148356,0.6580938,,,,,,,,,,,,,, -262800,4.4551926,0.5875049,,,,,,,,,,,,,, -262900,4.905035,0.6193621,,,,,,,,,,,,,, -263000,4.389555,0.63233775,,,,,,,,,,,,,, -263002,,,0.9615154266357422,0.1408833563327789,0.7550599575042725,1.0564838647842407,50000.0,0.6306000351905823,1.81957483291626,10000.0,89313.42373251915,92461.68709921835,89313.42373251915,3128.666751384735,10.373113632202148,0.0 -263100,5.2717977,0.63177,,,,,,,,,,,,,, -263200,4.4620824,0.6500348,,,,,,,,,,,,,, -263300,4.519573,0.6514629,,,,,,,,,,,,,, -263400,5.0871863,0.75626373,,,,,,,,,,,,,, -263500,4.7698417,0.71204776,,,,,,,,,,,,,, -263600,4.867262,0.6957654,,,,,,,,,,,,,, -263700,4.2214255,0.56655014,,,,,,,,,,,,,, -263800,4.6633964,0.6600037,,,,,,,,,,,,,, -263900,4.3868675,0.6293961,,,,,,,,,,,,,, -264000,4.4084806,0.63602805,,,,,,,,,,,,,, -264100,5.024837,0.6176484,,,,,,,,,,,,,, -264200,4.4863234,0.6367358,,,,,,,,,,,,,, -264300,4.8887906,0.5811771,,,,,,,,,,,,,, -264400,4.5514364,0.5853383,,,,,,,,,,,,,, -264500,4.575133,0.67105746,,,,,,,,,,,,,, -264505,,,0.9611168503761292,0.146037608385086,0.7548799514770508,1.054930329322815,50000.0,0.6299000382423401,1.818839192390442,10000.0,89823.2990591526,92988.3552968502,89823.2990591526,3145.334497213364,10.445182085037231,0.0 -264600,4.773853,0.6287998,,,,,,,,,,,,,, -264700,4.5920935,0.6044292,,,,,,,,,,,,,, -264800,4.714787,0.65379333,,,,,,,,,,,,,, -264900,4.4664655,0.60887283,,,,,,,,,,,,,, -265000,4.5328712,0.6138262,,,,,,,,,,,,,, -265100,4.6081614,0.5377252,,,,,,,,,,,,,, -265200,4.2204657,0.56652635,,,,,,,,,,,,,, -265300,4.2619705,0.5881666,,,,,,,,,,,,,, -265400,4.7931247,0.6404138,,,,,,,,,,,,,, -265500,4.6022654,0.6394851,,,,,,,,,,,,,, -265600,4.6882405,0.63272345,,,,,,,,,,,,,, -265700,4.2940917,0.5667136,,,,,,,,,,,,,, -265800,4.5547557,0.6040341,,,,,,,,,,,,,, -265900,4.4533443,0.65650296,,,,,,,,,,,,,, -266000,4.6556973,0.6676836,,,,,,,,,,,,,, -266008,,,0.9613759517669678,0.1456584334373474,0.7547999620437622,1.0558500289916992,50000.0,0.6300000548362732,1.820417881011963,10000.0,90333.40451908112,93515.5467464924,90333.40451908112,3162.2875208854675,10.524413585662842,0.0 -266100,4.9227195,0.7523911,,,,,,,,,,,,,, -266200,4.53326,0.5764684,,,,,,,,,,,,,, -266300,4.68842,0.6643225,,,,,,,,,,,,,, -266400,4.418803,0.59397984,,,,,,,,,,,,,, -266500,4.1880503,0.65299004,,,,,,,,,,,,,, -266600,4.825826,0.66862994,,,,,,,,,,,,,, -266700,4.5413904,0.66356117,,,,,,,,,,,,,, -266800,4.412821,0.6543404,,,,,,,,,,,,,, -266900,4.803373,0.69639707,,,,,,,,,,,,,, -267000,4.8881483,0.6674641,,,,,,,,,,,,,, -267100,4.5909305,0.6154903,,,,,,,,,,,,,, -267200,4.5808945,0.6314786,,,,,,,,,,,,,, -267300,4.7202296,0.77821314,,,,,,,,,,,,,, -267400,4.3329062,0.57039654,,,,,,,,,,,,,, -267500,4.298265,0.6422409,,,,,,,,,,,,,, -267511,,,0.9602399468421936,0.1465692073106765,0.7547199726104736,1.0563976764678955,50000.0,0.6303000450134277,1.820171356201172,10000.0,90843.54014015198,94043.1495695114,90843.54014015198,3179.641728401184,10.585381507873535,0.0 -267600,5.0421925,0.7018424,,,,,,,,,,,,,, -267700,4.618347,0.62652934,,,,,,,,,,,,,, -267800,4.658087,0.6485109,,,,,,,,,,,,,, -267900,4.316448,0.57625085,,,,,,,,,,,,,, -268000,4.9037747,0.64922947,,,,,,,,,,,,,, -268100,4.5421233,0.656981,,,,,,,,,,,,,, -268200,4.4779897,0.6814383,,,,,,,,,,,,,, -268300,4.727154,0.68533444,,,,,,,,,,,,,, -268400,4.7020116,0.6168737,,,,,,,,,,,,,, -268500,4.4387126,0.646118,,,,,,,,,,,,,, -268600,4.402172,0.54988277,,,,,,,,,,,,,, -268700,4.983817,0.64980376,,,,,,,,,,,,,, -268800,4.353926,0.5506218,,,,,,,,,,,,,, -268900,4.3484993,0.59002125,,,,,,,,,,,,,, -269000,4.5139236,0.6545752,,,,,,,,,,,,,, -269014,,,0.9600805044174194,0.14800925552845,0.7546399831771851,1.0555812120437622,50000.0,0.6302000284194946,1.819125056266785,10000.0,91353.7362806797,94570.11760783195,91353.7362806797,3196.283055782318,10.662968635559082,0.0 -269100,4.2737412,0.6043134,,,,,,,,,,,,,, -269200,4.338587,0.6708481,,,,,,,,,,,,,, -269300,4.685599,0.63636976,,,,,,,,,,,,,, -269400,4.7048707,0.58639425,,,,,,,,,,,,,, -269500,4.3747063,0.63509625,,,,,,,,,,,,,, -269600,4.4212627,0.56298006,,,,,,,,,,,,,, -269700,4.500646,0.6475111,,,,,,,,,,,,,, -269800,4.8884454,0.5616783,,,,,,,,,,,,,, -269900,4.559683,0.6239815,,,,,,,,,,,,,, -270000,4.549536,0.6409617,,,,,,,,,,,,,, -270100,5.058126,0.6816349,,,,,,,,,,,,,, -270200,4.580528,0.62813693,,,,,,,,,,,,,, -270300,4.685869,0.60033023,,,,,,,,,,,,,, -270400,4.433741,0.60685027,,,,,,,,,,,,,, -270500,4.42368,0.61604047,,,,,,,,,,,,,, -270517,,,0.9622727632522584,0.1424338072538375,0.7545199990272522,1.0561974048614502,50000.0,0.6313000321388245,1.819807291030884,10000.0,91863.84095406532,95096.95014238358,91863.84095406532,3212.8778393268585,10.743399143218994,0.0 -270600,4.746498,0.6415423,,,,,,,,,,,,,, -270700,5.120275,0.64430594,,,,,,,,,,,,,, -270800,4.5732527,0.6202967,,,,,,,,,,,,,, -270900,4.558112,0.6814425,,,,,,,,,,,,,, -271000,4.589642,0.6023822,,,,,,,,,,,,,, -271100,4.7452197,0.6901278,,,,,,,,,,,,,, -271200,4.759987,0.6156459,,,,,,,,,,,,,, -271300,4.817011,0.7097686,,,,,,,,,,,,,, -271400,4.5394497,0.62684476,,,,,,,,,,,,,, -271500,4.9276433,0.6495752,,,,,,,,,,,,,, -271600,4.642505,0.5942661,,,,,,,,,,,,,, -271700,4.4474926,0.6599251,,,,,,,,,,,,,, -271800,4.778496,0.592882,,,,,,,,,,,,,, -271900,4.6179814,0.6606293,,,,,,,,,,,,,, -272000,4.283931,0.5854684,,,,,,,,,,,,,, -272020,,,0.9606385231018066,0.1462656706571579,0.7546399831771851,1.0556144714355469,50000.0,0.6302000284194946,1.819312572479248,10000.0,92374.03491449356,95623.9277703762,92374.03491449356,3229.5373075008392,10.815797090530396,0.0 -272100,4.861931,0.58915246,,,,,,,,,,,,,, -272200,4.794425,0.6478132,,,,,,,,,,,,,, -272300,4.861375,0.7013953,,,,,,,,,,,,,, -272400,4.630774,0.61706376,,,,,,,,,,,,,, -272500,4.6914954,0.6002885,,,,,,,,,,,,,, -272600,4.5818777,0.64854825,,,,,,,,,,,,,, -272700,5.446156,0.5994028,,,,,,,,,,,,,, -272800,4.7716713,0.6649444,,,,,,,,,,,,,, -272900,5.288537,0.6782929,,,,,,,,,,,,,, -273000,4.983479,0.7333224,,,,,,,,,,,,,, -273100,4.473095,0.6301889,,,,,,,,,,,,,, -273200,4.609169,0.6630771,,,,,,,,,,,,,, -273300,4.6818314,0.6330085,,,,,,,,,,,,,, -273400,4.3803554,0.65665156,,,,,,,,,,,,,, -273500,4.8690906,0.64586216,,,,,,,,,,,,,, -273523,,,0.9612563848495485,0.1452231556177139,0.7550399899482727,1.0557005405426023,50000.0,0.6306000351905823,1.820548176765442,10000.0,92884.12191414832,96150.77471852304,92884.12191414832,3246.166423082352,10.89329195022583,0.0 -273600,4.338965,0.6039186,,,,,,,,,,,,,, -273700,4.4700885,0.64934456,,,,,,,,,,,,,, -273800,4.329053,0.5886374,,,,,,,,,,,,,, -273900,4.5748744,0.6255877,,,,,,,,,,,,,, -274000,4.9371934,0.6453034,,,,,,,,,,,,,, -274100,4.462229,0.5919673,,,,,,,,,,,,,, -274200,4.439924,0.6495559,,,,,,,,,,,,,, -274300,4.636018,0.6747284,,,,,,,,,,,,,, -274400,4.644836,0.587685,,,,,,,,,,,,,, -274500,4.751684,0.67459667,,,,,,,,,,,,,, -274600,4.5324383,0.64152145,,,,,,,,,,,,,, -274700,4.753876,0.5998986,,,,,,,,,,,,,, -274800,4.3527474,0.625671,,,,,,,,,,,,,, -274900,4.7130737,0.5748214,,,,,,,,,,,,,, -275000,4.6525044,0.6648988,,,,,,,,,,,,,, -275025,,,0.9610570669174194,0.1458848714828491,0.7547199726104736,1.056044578552246,50000.0,0.6304000020027161,1.818693280220032,10000.0,93394.05655407906,96677.45746946336,93394.05655407906,3262.778742313385,10.976376295089722,0.0 -275100,4.5945787,0.5594378,,,,,,,,,,,,,, -275200,4.4633203,0.6707656,,,,,,,,,,,,,, -275300,4.6670995,0.5949265,,,,,,,,,,,,,, -275400,4.820571,0.72636056,,,,,,,,,,,,,, -275500,4.764896,0.7069633,,,,,,,,,,,,,, -275600,4.602473,0.6855823,,,,,,,,,,,,,, -275700,4.6112757,0.69847167,,,,,,,,,,,,,, -275800,4.0242243,0.5273171,,,,,,,,,,,,,, -275900,5.5292754,0.6427468,,,,,,,,,,,,,, -276000,4.931378,0.5821325,,,,,,,,,,,,,, -276100,4.6988926,0.56845546,,,,,,,,,,,,,, -276200,4.5455594,0.63182735,,,,,,,,,,,,,, -276300,4.659305,0.6157177,,,,,,,,,,,,,, -276400,4.9817214,0.72719204,,,,,,,,,,,,,, -276500,4.510345,0.6445132,,,,,,,,,,,,,, -276528,,,0.9612165093421936,0.1463463455438614,0.7547599673271179,1.055619716644287,50000.0,0.6305000185966492,1.8192275762557983,10000.0,93904.15710663795,97204.39679145812,93904.15710663795,3279.488739728928,11.052030563354492,0.0 -276600,4.421888,0.6143739,,,,,,,,,,,,,, -276700,4.429824,0.6142318,,,,,,,,,,,,,, -276800,4.2506804,0.5882395,,,,,,,,,,,,,, -276900,4.6812806,0.63644886,,,,,,,,,,,,,, -277000,4.1613803,0.5729244,,,,,,,,,,,,,, -277100,4.8195996,0.6890631,,,,,,,,,,,,,, -277200,4.6620345,0.54729176,,,,,,,,,,,,,, -277300,5.361896,0.66268593,,,,,,,,,,,,,, -277400,4.6700473,0.6828041,,,,,,,,,,,,,, -277500,4.2612033,0.5572015,,,,,,,,,,,,,, -277600,4.5192122,0.6197014,,,,,,,,,,,,,, -277700,4.1106224,0.5800426,,,,,,,,,,,,,, -277800,4.3975463,0.5957376,,,,,,,,,,,,,, -277900,4.437072,0.5900725,,,,,,,,,,,,,, -278000,4.634617,0.6033585,,,,,,,,,,,,,, -278029,,,0.9600406289100648,0.1457443684339523,0.7548799514770508,1.055456519126892,50000.0,0.6306000351905823,1.8184598684310915,10000.0,94414.21587204932,97731.3130865097,94414.21587204932,3296.2211039066315,11.124534606933594,0.0 -278100,4.3055186,0.547592,,,,,,,,,,,,,, -278200,4.5376277,0.6063932,,,,,,,,,,,,,, -278300,4.801665,0.6403645,,,,,,,,,,,,,, -278400,4.8273034,0.60992676,,,,,,,,,,,,,, -278500,4.6490035,0.6226598,,,,,,,,,,,,,, -278600,4.055168,0.6106328,,,,,,,,,,,,,, -278700,4.6279244,0.7034103,,,,,,,,,,,,,, -278800,4.7987757,0.58833873,,,,,,,,,,,,,, -278900,4.119259,0.61599183,,,,,,,,,,,,,, -279000,4.4652295,0.6357791,,,,,,,,,,,,,, -279100,4.242984,0.5755027,,,,,,,,,,,,,, -279200,4.58686,0.705457,,,,,,,,,,,,,, -279300,4.4722877,0.5986859,,,,,,,,,,,,,, -279400,4.404055,0.5830453,,,,,,,,,,,,,, -279500,4.980191,0.7599167,,,,,,,,,,,,,, -279532,,,0.9602997303009032,0.1474865823984146,0.7548799514770508,1.0558282136917114,50000.0,0.6300000548362732,1.8213742971420288,10000.0,94924.25241136552,98258.06078863144,94924.25241136552,3312.801731109619,11.201616048812866,0.0 -279600,4.529359,0.55827343,,,,,,,,,,,,,, -279700,4.769356,0.6187707,,,,,,,,,,,,,, -279800,4.0913353,0.50919163,,,,,,,,,,,,,, -279900,4.6090403,0.6015448,,,,,,,,,,,,,, -280000,4.3910313,0.56308335,,,,,,,,,,,,,, -280100,4.7053986,0.6200204,,,,,,,,,,,,,, -280200,4.3346953,0.6703815,,,,,,,,,,,,,, -280300,4.718349,0.6004582,,,,,,,,,,,,,, -280400,4.4242396,0.604835,,,,,,,,,,,,,, -280500,4.8478007,0.7141032,,,,,,,,,,,,,, -280600,4.4412913,0.5863434,,,,,,,,,,,,,, -280700,4.6188335,0.64353245,,,,,,,,,,,,,, -280800,4.789019,0.6896087,,,,,,,,,,,,,, -280900,4.2955184,0.5814702,,,,,,,,,,,,,, -281000,4.7593665,0.6527792,,,,,,,,,,,,,, -281034,,,0.9610969424247742,0.1467566490173339,0.7548199892044067,1.0551109313964844,50000.0,0.6305000185966492,1.8184127807617188,10000.0,95434.31954813004,98784.97705054285,95434.31954813004,3329.51867890358,11.280039310455322,0.0 -281100,4.599737,0.62356716,,,,,,,,,,,,,, -281200,4.8013535,0.6401925,,,,,,,,,,,,,, -281300,4.3947616,0.5729127,,,,,,,,,,,,,, -281400,4.81474,0.6419712,,,,,,,,,,,,,, -281500,4.356835,0.5729215,,,,,,,,,,,,,, -281600,4.152677,0.55968976,,,,,,,,,,,,,, -281700,5.1466584,0.6800351,,,,,,,,,,,,,, -281800,4.574681,0.68360376,,,,,,,,,,,,,, -281900,4.900395,0.64800036,,,,,,,,,,,,,, -282000,4.8016257,0.68361586,,,,,,,,,,,,,, -282100,5.188896,0.56753635,,,,,,,,,,,,,, -282200,4.8584137,0.61804247,,,,,,,,,,,,,, -282300,4.45643,0.5898153,,,,,,,,,,,,,, -282400,4.204184,0.5884502,,,,,,,,,,,,,, -282500,4.184653,0.57147884,,,,,,,,,,,,,, -282537,,,0.959382951259613,0.148421362042427,0.7549799680709839,1.0551875829696655,50000.0,0.6302000284194946,1.8191053867340088,10000.0,95944.22466540337,99311.83710670473,95944.22466540337,3346.3417830467224,11.358751773834229,0.0 -282600,4.4453063,0.6292786,,,,,,,,,,,,,, -282700,4.9132276,0.6308487,,,,,,,,,,,,,, -282800,4.395464,0.61651284,,,,,,,,,,,,,, -282900,4.4370284,0.6320371,,,,,,,,,,,,,, -283000,5.0992126,0.6365968,,,,,,,,,,,,,, -283100,4.5632524,0.59792894,,,,,,,,,,,,,, -283200,4.1544995,0.56828946,,,,,,,,,,,,,, -283300,4.3784785,0.6444925,,,,,,,,,,,,,, -283400,4.6178646,0.5570462,,,,,,,,,,,,,, -283500,4.5834336,0.58712363,,,,,,,,,,,,,, -283600,4.81646,0.66893685,,,,,,,,,,,,,, -283700,4.3878393,0.6462327,,,,,,,,,,,,,, -283800,5.175592,0.6678134,,,,,,,,,,,,,, -283900,4.631604,0.6445614,,,,,,,,,,,,,, -284000,4.326201,0.5791,,,,,,,,,,,,,, -284040,,,0.9605388641357422,0.1470494121313095,0.7549200057983398,1.0552425384521484,50000.0,0.6299000382423401,1.8197104930877688,10000.0,96454.19933009148,99838.62381887436,96454.19933009148,3363.020151615143,11.438347816467283,0.0 -284100,4.7484975,0.6303601,,,,,,,,,,,,,, -284200,4.6339626,0.68015766,,,,,,,,,,,,,, -284300,4.4826994,0.6353291,,,,,,,,,,,,,, -284400,4.690808,0.64645517,,,,,,,,,,,,,, -284500,4.540896,0.6152373,,,,,,,,,,,,,, -284600,4.593664,0.6542851,,,,,,,,,,,,,, -284700,4.3039207,0.5178061,,,,,,,,,,,,,, -284800,4.4138737,0.678918,,,,,,,,,,,,,, -284900,4.6284914,0.59257996,,,,,,,,,,,,,, -285000,4.9324207,0.6122671,,,,,,,,,,,,,, -285100,4.6278253,0.62640446,,,,,,,,,,,,,, -285200,4.717775,0.6455191,,,,,,,,,,,,,, -285300,4.8235555,0.6878135,,,,,,,,,,,,,, -285400,4.1247554,0.5325418,,,,,,,,,,,,,, -285500,5.2073936,0.6137754,,,,,,,,,,,,,, -285543,,,0.9606186151504515,0.1456643491983413,0.754859983921051,1.0539931058883667,50000.0,0.6309000253677368,1.817240715026856,10000.0,96964.37281227112,100365.60154628754,96964.37281227112,3379.6935136318207,11.514539003372192,0.0 -285600,4.5806203,0.6528705,,,,,,,,,,,,,, -285700,4.207347,0.579228,,,,,,,,,,,,,, -285800,4.763081,0.6609418,,,,,,,,,,,,,, -285900,4.6266828,0.62129813,,,,,,,,,,,,,, -286000,4.6518645,0.6118758,,,,,,,,,,,,,, -286100,5.1680975,0.6757223,,,,,,,,,,,,,, -286200,4.307512,0.61574984,,,,,,,,,,,,,, -286300,4.262274,0.59119993,,,,,,,,,,,,,, -286400,5.124196,0.6670414,,,,,,,,,,,,,, -286500,4.8410444,0.63805014,,,,,,,,,,,,,, -286600,4.594257,0.53883886,,,,,,,,,,,,,, -286700,5.8600106,0.6103842,,,,,,,,,,,,,, -286800,4.552155,0.61796206,,,,,,,,,,,,,, -286900,4.5698676,0.61141515,,,,,,,,,,,,,, -287000,5.308346,0.7366519,,,,,,,,,,,,,, -287046,,,0.9601203799247742,0.1489818543195724,0.7545999884605408,1.0550166368484497,50000.0,0.6299000382423401,1.8189139366149905,10000.0,97474.5063648224,100892.59136629105,97474.5063648224,3396.415134191513,11.59593391418457,0.0 -287100,4.310672,0.6454386,,,,,,,,,,,,,, -287200,4.859355,0.60070133,,,,,,,,,,,,,, -287300,4.8028293,0.6384755,,,,,,,,,,,,,, -287400,4.404942,0.65091264,,,,,,,,,,,,,, -287500,4.5373874,0.58446854,,,,,,,,,,,,,, -287600,4.911636,0.71171033,,,,,,,,,,,,,, -287700,4.151877,0.5488565,,,,,,,,,,,,,, -287800,4.199895,0.59052324,,,,,,,,,,,,,, -287900,4.6938963,0.63410413,,,,,,,,,,,,,, -288000,4.331462,0.6100468,,,,,,,,,,,,,, -288100,4.3773017,0.5181545,,,,,,,,,,,,,, -288200,4.301124,0.6113462,,,,,,,,,,,,,, -288300,4.5739,0.66308236,,,,,,,,,,,,,, -288400,4.8275867,0.70946854,,,,,,,,,,,,,, -288500,4.3513937,0.5790198,,,,,,,,,,,,,, -288550,,,0.9608577489852904,0.146798700094223,0.7548799514770508,1.0558570623397827,50000.0,0.6303000450134277,1.819604754447937,10000.0,97984.6793088913,101419.5659172535,97984.6793088913,3413.089588880539,11.67100715637207,0.0 -288600,4.924242,0.6958778,,,,,,,,,,,,,, -288700,4.5081778,0.5901863,,,,,,,,,,,,,, -288800,4.642493,0.62844557,,,,,,,,,,,,,, -288900,4.4607472,0.6196591,,,,,,,,,,,,,, -289000,4.4877977,0.609331,,,,,,,,,,,,,, -289100,4.5883675,0.6138403,,,,,,,,,,,,,, -289200,4.3112803,0.58809453,,,,,,,,,,,,,, -289300,4.8370914,0.70731014,,,,,,,,,,,,,, -289400,4.5558224,0.7005888,,,,,,,,,,,,,, -289500,4.4752364,0.6057751,,,,,,,,,,,,,, -289600,4.9813,0.61706114,,,,,,,,,,,,,, -289700,4.4243913,0.61895514,,,,,,,,,,,,,, -289800,4.8087335,0.59212816,,,,,,,,,,,,,, -289900,4.686927,0.5747972,,,,,,,,,,,,,, -290000,5.0793405,0.62469935,,,,,,,,,,,,,, -290052,,,0.9595623016357422,0.1505014598369598,0.7546799778938293,1.0554848909378052,50000.0,0.6304000020027161,1.818774700164795,10000.0,98494.60548329352,101946.35471343994,98494.60548329352,3429.8191237449646,11.751428842544556,0.0 -290100,4.6235337,0.69240075,,,,,,,,,,,,,, -290200,4.2202687,0.57711047,,,,,,,,,,,,,, -290300,4.5853395,0.63130045,,,,,,,,,,,,,, -290400,5.059264,0.6307212,,,,,,,,,,,,,, -290500,4.6508646,0.6510538,,,,,,,,,,,,,, -290600,4.51781,0.576353,,,,,,,,,,,,,, -290700,4.4478097,0.6147789,,,,,,,,,,,,,, -290800,4.578832,0.64489174,,,,,,,,,,,,,, -290900,5.0156436,0.6272936,,,,,,,,,,,,,, -291000,4.6215205,0.68809557,,,,,,,,,,,,,, -291100,4.540779,0.64287835,,,,,,,,,,,,,, -291200,4.4158587,0.63867456,,,,,,,,,,,,,, -291300,4.261142,0.5857671,,,,,,,,,,,,,, -291400,4.4309297,0.5734376,,,,,,,,,,,,,, -291500,4.990004,0.5966079,,,,,,,,,,,,,, -291555,,,0.9608976244926452,0.1472378075122833,0.7546199560165405,1.054668664932251,50000.0,0.6304000020027161,1.817091703414917,10000.0,99004.58482909204,102473.50423121452,99004.58482909204,3446.859624147415,11.829417705535889,0.0 -291600,4.834881,0.7008502,,,,,,,,,,,,,, -291700,4.305786,0.60542667,,,,,,,,,,,,,, -291800,4.645522,0.64024705,,,,,,,,,,,,,, -291900,4.69351,0.6325965,,,,,,,,,,,,,, -292000,4.578701,0.5802006,,,,,,,,,,,,,, -292100,4.0986505,0.54043454,,,,,,,,,,,,,, -292200,4.380583,0.6535015,,,,,,,,,,,,,, -292300,4.5307875,0.7264797,,,,,,,,,,,,,, -292400,4.2894354,0.6693568,,,,,,,,,,,,,, -292500,4.34425,0.61390424,,,,,,,,,,,,,, -292600,4.5537124,0.54804367,,,,,,,,,,,,,, -292700,4.835086,0.6755517,,,,,,,,,,,,,, -292800,4.081231,0.58174753,,,,,,,,,,,,,, -292900,4.394537,0.6044143,,,,,,,,,,,,,, -293000,4.515456,0.5703125,,,,,,,,,,,,,, -293057,,,0.9612762928009032,0.1453863382339477,0.7545599937438965,1.055548071861267,50000.0,0.6300000548362732,1.8181191682815552,10000.0,99514.50890374184,103000.20554113388,99514.50890374184,3463.4972772598267,11.915416479110718,0.0 -293100,4.7674565,0.66094184,,,,,,,,,,,,,, -293200,4.954801,0.63697326,,,,,,,,,,,,,, -293300,4.439202,0.5730469,,,,,,,,,,,,,, -293400,4.4954624,0.65033257,,,,,,,,,,,,,, -293500,4.473231,0.6932951,,,,,,,,,,,,,, -293600,4.8169594,0.6846943,,,,,,,,,,,,,, -293700,4.5483456,0.6636536,,,,,,,,,,,,,, -293800,4.462738,0.6488267,,,,,,,,,,,,,, -293900,4.415861,0.6389278,,,,,,,,,,,,,, -294000,4.5176196,0.61095417,,,,,,,,,,,,,, -294100,5.2207975,0.7093958,,,,,,,,,,,,,, -294200,4.9800515,0.6331257,,,,,,,,,,,,,, -294300,4.544942,0.6317503,,,,,,,,,,,,,, -294400,4.537516,0.63089764,,,,,,,,,,,,,, -294500,4.4244204,0.59172475,,,,,,,,,,,,,, -294560,,,0.959203600883484,0.1480921506881714,0.754539966583252,1.0552234649658203,50000.0,0.6296000480651855,1.8201440572738647,10000.0,100024.59973978996,103527.22813105585,100024.59973978996,3480.2958142757416,11.99556565284729,0.0 -294600,5.0330005,0.63449687,,,,,,,,,,,,,, -294700,4.752355,0.6602835,,,,,,,,,,,,,, -294800,4.676475,0.6203959,,,,,,,,,,,,,, -294900,4.8616276,0.62864006,,,,,,,,,,,,,, -295000,4.379411,0.544437,,,,,,,,,,,,,, -295100,4.6737833,0.65100014,,,,,,,,,,,,,, -295200,4.7418127,0.5734152,,,,,,,,,,,,,, -295300,4.954729,0.64433396,,,,,,,,,,,,,, -295400,4.5551796,0.6395265,,,,,,,,,,,,,, -295500,4.319152,0.5882555,,,,,,,,,,,,,, -295600,4.211169,0.57298446,,,,,,,,,,,,,, -295700,4.1169567,0.56631076,,,,,,,,,,,,,, -295800,4.614612,0.657099,,,,,,,,,,,,,, -295900,4.4476924,0.63071597,,,,,,,,,,,,,, -296000,4.463244,0.66272813,,,,,,,,,,,,,, -296062,,,0.9609175324440002,0.1463489681482315,0.7550599575042725,1.0558689832687378,50000.0,0.629800021648407,1.8213645219802856,10000.0,100534.60390925407,104053.99726223946,100534.60390925407,3496.9346759319305,12.068589448928831,0.0 -296100,4.4115148,0.59315217,,,,,,,,,,,,,, -296200,4.4659004,0.6707099,,,,,,,,,,,,,, -296300,4.3493643,0.5861215,,,,,,,,,,,,,, -296400,4.514224,0.6444782,,,,,,,,,,,,,, -296500,5.163384,0.64767873,,,,,,,,,,,,,, -296600,4.577059,0.57711905,,,,,,,,,,,,,, -296700,4.361802,0.5908722,,,,,,,,,,,,,, -296800,4.516821,0.6741304,,,,,,,,,,,,,, -296900,4.408258,0.64993626,,,,,,,,,,,,,, -297000,4.8254857,0.646459,,,,,,,,,,,,,, -297100,4.5285883,0.63347876,,,,,,,,,,,,,, -297200,4.885363,0.6433825,,,,,,,,,,,,,, -297300,4.7826447,0.6062623,,,,,,,,,,,,,, -297400,4.6463275,0.69580877,,,,,,,,,,,,,, -297500,4.686313,0.6627773,,,,,,,,,,,,,, -297565,,,0.9612762928009032,0.1450014561414718,0.7545599937438965,1.055658221244812,50000.0,0.6313000321388245,1.8192485570907595,10000.0,101044.60762357712,104580.7647485733,101044.60762357712,3513.5654361248016,12.149300336837769,0.0 -297600,4.5823803,0.569301,,,,,,,,,,,,,, -297700,4.9481945,0.6246394,,,,,,,,,,,,,, -297800,4.8084407,0.6713545,,,,,,,,,,,,,, -297900,4.4989414,0.56703424,,,,,,,,,,,,,, -298000,4.3504443,0.6389785,,,,,,,,,,,,,, -298100,4.165214,0.5881107,,,,,,,,,,,,,, -298200,4.585806,0.67261666,,,,,,,,,,,,,, -298300,4.6345263,0.6404677,,,,,,,,,,,,,, -298400,4.5384555,0.6269854,,,,,,,,,,,,,, -298500,4.746157,0.6184163,,,,,,,,,,,,,, -298600,4.6009574,0.57096887,,,,,,,,,,,,,, -298700,4.5986013,0.5856343,,,,,,,,,,,,,, -298800,4.343307,0.581892,,,,,,,,,,,,,, -298900,4.640532,0.6425647,,,,,,,,,,,,,, -299000,4.6663055,0.6545465,,,,,,,,,,,,,, -299068,,,0.9602000713348388,0.1488392055034637,0.7549200057983398,1.056084394454956,50000.0,0.6301000118255615,1.818241834640503,10000.0,101554.6838247776,105107.60811972618,101554.6838247776,3530.203478574753,12.22605037689209,0.0 -299100,4.7677145,0.6722087,,,,,,,,,,,,,, -299200,4.6793003,0.6917512,,,,,,,,,,,,,, -299300,4.3745494,0.61460567,,,,,,,,,,,,,, -299400,4.8201637,0.622903,,,,,,,,,,,,,, -299500,4.725716,0.6426162,,,,,,,,,,,,,, -299600,4.40159,0.5914351,,,,,,,,,,,,,, -299700,4.6614785,0.61564475,,,,,,,,,,,,,, -299800,4.505108,0.6170019,,,,,,,,,,,,,, -299900,4.3212447,0.61161697,,,,,,,,,,,,,, -300000,4.361839,0.62252134,,,,,,,,,,,,,, -300100,4.7774143,0.6323066,,,,,,,,,,,,,, -300200,4.3744636,0.6156418,,,,,,,,,,,,,, -300300,4.4413114,0.6488408,,,,,,,,,,,,,, -300400,4.3752913,0.62071717,,,,,,,,,,,,,, -300500,4.464934,0.63595283,,,,,,,,,,,,,, -300571,,,0.961136758327484,0.14404296875,0.7550999522209167,1.0556014776229858,50000.0,0.6300000548362732,1.818524479866028,10000.0,102064.62591266632,105634.33630681038,102064.62591266632,3546.855792999268,12.306427717208862,0.0 -300600,4.460084,0.58456105,,,,,,,,,,,,,, -300700,4.784212,0.61846167,,,,,,,,,,,,,, -300800,4.8516254,0.6225022,,,,,,,,,,,,,, -300900,5.025808,0.69078326,,,,,,,,,,,,,, -301000,4.273773,0.6123898,,,,,,,,,,,,,, -301100,4.041908,0.5655383,,,,,,,,,,,,,, -301200,4.369908,0.62656367,,,,,,,,,,,,,, -301300,4.458452,0.5761131,,,,,,,,,,,,,, -301400,4.34142,0.5872271,,,,,,,,,,,,,, -301500,4.7275147,0.599985,,,,,,,,,,,,,, -301600,4.5891323,0.65523624,,,,,,,,,,,,,, -301700,4.5453434,0.5840846,,,,,,,,,,,,,, -301800,4.55876,0.63029826,,,,,,,,,,,,,, -301900,4.6108365,0.59754646,,,,,,,,,,,,,, -302000,4.9436407,0.7098438,,,,,,,,,,,,,, -302074,,,0.9614756107330322,0.1412848830223083,0.7550599575042725,1.056552171707153,50000.0,0.6300000548362732,1.8204947710037231,10000.0,102574.78446245192,106161.29858207704,102574.78446245192,3563.481164932251,12.430617094039915,0.0 -302100,4.081705,0.57043535,,,,,,,,,,,,,, -302200,4.895319,0.6855616,,,,,,,,,,,,,, -302300,4.2732954,0.5989055,,,,,,,,,,,,,, -302400,4.205895,0.53092146,,,,,,,,,,,,,, -302500,4.262375,0.5758198,,,,,,,,,,,,,, -302600,4.4131784,0.6074095,,,,,,,,,,,,,, -302700,4.598707,0.6279576,,,,,,,,,,,,,, -302800,4.5463247,0.6171328,,,,,,,,,,,,,, -302900,4.5484037,0.66558725,,,,,,,,,,,,,, -303000,4.294808,0.6274145,,,,,,,,,,,,,, -303100,4.905478,0.65818626,,,,,,,,,,,,,, -303200,4.745431,0.6041422,,,,,,,,,,,,,, -303300,4.6884866,0.6539693,,,,,,,,,,,,,, -303400,4.3225555,0.60816336,,,,,,,,,,,,,, -303500,4.449762,0.56870383,,,,,,,,,,,,,, -303577,,,0.961316168308258,0.1461625099182129,0.7547799944877625,1.05518901348114,50000.0,0.6292000412940979,1.8195948600769043,10000.0,103084.86869740486,106688.22069692612,103084.86869740486,3580.18265748024,12.513375520706177,0.0 -303600,4.6594596,0.6241113,,,,,,,,,,,,,, -303700,4.613417,0.61651754,,,,,,,,,,,,,, -303800,4.568648,0.6039665,,,,,,,,,,,,,, -303900,4.52085,0.60508716,,,,,,,,,,,,,, -304000,4.7974515,0.68519187,,,,,,,,,,,,,, -304100,4.2382145,0.55703175,,,,,,,,,,,,,, -304200,4.9539304,0.71053445,,,,,,,,,,,,,, -304300,4.493449,0.66237336,,,,,,,,,,,,,, -304400,4.3757143,0.5609744,,,,,,,,,,,,,, -304500,4.551898,0.62699753,,,,,,,,,,,,,, -304600,4.0657578,0.52238524,,,,,,,,,,,,,, -304700,4.6094103,0.67820466,,,,,,,,,,,,,, -304800,4.6434836,0.6579922,,,,,,,,,,,,,, -304900,4.3645782,0.69146395,,,,,,,,,,,,,, -305000,4.8523784,0.67683154,,,,,,,,,,,,,, -305080,,,0.9605388641357422,0.1448585391044616,0.7550199627876282,1.054583191871643,50000.0,0.6306000351905823,1.817049264907837,10000.0,103594.85424780846,107215.7625875473,103594.85424780846,3597.605567932129,12.594209432601929,0.0 -305100,4.4916134,0.6134049,,,,,,,,,,,,,, -305200,4.3525577,0.58285105,,,,,,,,,,,,,, -305300,4.4472556,0.5761627,,,,,,,,,,,,,, -305400,4.9064903,0.6402114,,,,,,,,,,,,,, -305500,4.8220587,0.6371434,,,,,,,,,,,,,, -305600,4.8278604,0.6511696,,,,,,,,,,,,,, -305700,4.4470515,0.62191373,,,,,,,,,,,,,, -305800,4.4090557,0.6039387,,,,,,,,,,,,,, -305900,4.4394317,0.6437198,,,,,,,,,,,,,, -306000,4.750629,0.61447054,,,,,,,,,,,,,, -306100,4.780336,0.6446459,,,,,,,,,,,,,, -306200,4.7244043,0.65755683,,,,,,,,,,,,,, -306300,4.5932393,0.61357224,,,,,,,,,,,,,, -306400,5.2439327,0.6321778,,,,,,,,,,,,,, -306500,4.685783,0.695778,,,,,,,,,,,,,, -306583,,,0.959980845451355,0.1489111185073852,0.7552199959754944,1.0559635162353516,50000.0,0.6307000517845154,1.818765878677368,10000.0,104104.78870105743,107742.52642440796,104104.78870105743,3614.303423643112,12.67265248298645,0.0 -306600,4.9947386,0.6461935,,,,,,,,,,,,,, -306700,4.1619596,0.55336046,,,,,,,,,,,,,, -306800,4.6491776,0.604293,,,,,,,,,,,,,, -306900,4.4347267,0.63305825,,,,,,,,,,,,,, -307000,4.451821,0.57901824,,,,,,,,,,,,,, -307100,4.3970084,0.56007355,,,,,,,,,,,,,, -307200,4.5307307,0.58342874,,,,,,,,,,,,,, -307300,4.2532306,0.6101792,,,,,,,,,,,,,, -307400,4.859138,0.6339982,,,,,,,,,,,,,, -307500,4.455912,0.63608897,,,,,,,,,,,,,, -307600,4.3501782,0.5870643,,,,,,,,,,,,,, -307700,4.6069813,0.66918814,,,,,,,,,,,,,, -307800,4.2275863,0.66066325,,,,,,,,,,,,,, -307900,4.1842966,0.63031965,,,,,,,,,,,,,, -308000,4.146613,0.62424374,,,,,,,,,,,,,, -308086,,,0.9614157676696776,0.1438297778367996,0.7548399567604065,1.0567333698272705,50000.0,0.6294000148773193,1.8203606605529783,10000.0,104614.98310089111,108269.4746274948,104614.98310089111,3630.9186856746674,12.757277011871338,0.0 -308100,4.8625703,0.7053971,,,,,,,,,,,,,, -308200,5.070724,0.69146925,,,,,,,,,,,,,, -308300,4.9529886,0.62594527,,,,,,,,,,,,,, -308400,4.736149,0.68054307,,,,,,,,,,,,,, -308500,4.612628,0.61339027,,,,,,,,,,,,,, -308600,4.750889,0.6760348,,,,,,,,,,,,,, -308700,4.533962,0.59440273,,,,,,,,,,,,,, -308800,3.966674,0.5424541,,,,,,,,,,,,,, -308900,4.268986,0.66853255,,,,,,,,,,,,,, -309000,4.5999365,0.66214025,,,,,,,,,,,,,, -309100,4.5572143,0.5764824,,,,,,,,,,,,,, -309200,4.6194887,0.6712316,,,,,,,,,,,,,, -309300,4.32584,0.57784814,,,,,,,,,,,,,, -309400,4.1911945,0.60863173,,,,,,,,,,,,,, -309500,4.8114614,0.6430781,,,,,,,,,,,,,, -309589,,,0.960558831691742,0.147440105676651,0.7550399899482727,1.0556894540786743,50000.0,0.6310000419616699,1.819544672966004,10000.0,105125.10056447984,108796.3183927536,105125.10056447984,3647.511518955231,12.83810329437256,0.0 -309600,4.127656,0.563533,,,,,,,,,,,,,, -309700,4.8972325,0.567065,,,,,,,,,,,,,, -309800,4.347842,0.63661444,,,,,,,,,,,,,, -309900,4.684251,0.6492468,,,,,,,,,,,,,, -310000,4.9172826,0.60199165,,,,,,,,,,,,,, -310100,4.6907525,0.7173934,,,,,,,,,,,,,, -310200,4.3406587,0.5730916,,,,,,,,,,,,,, -310300,4.584388,0.5856803,,,,,,,,,,,,,, -310400,4.2150507,0.53305185,,,,,,,,,,,,,, -310500,4.202101,0.58331656,,,,,,,,,,,,,, -310600,4.3994255,0.56334925,,,,,,,,,,,,,, -310700,4.489243,0.671687,,,,,,,,,,,,,, -310800,4.284073,0.6098593,,,,,,,,,,,,,, -310900,5.219472,0.70107865,,,,,,,,,,,,,, -311000,4.721805,0.56473255,,,,,,,,,,,,,, -311092,,,0.9615154266357422,0.1449520289897918,0.7554399967193604,1.0552821159362793,50000.0,0.6300000548362732,1.8196454048156736,10000.0,105635.1733725071,109323.24497246742,105635.1733725071,3664.230366706848,12.919663667678831,0.0 -311100,4.3658557,0.63276786,,,,,,,,,,,,,, -311200,4.2087617,0.5797764,,,,,,,,,,,,,, -311300,4.499244,0.5655734,,,,,,,,,,,,,, -311400,5.224898,0.61795294,,,,,,,,,,,,,, -311500,4.518593,0.65032816,,,,,,,,,,,,,, -311600,4.6685534,0.57450485,,,,,,,,,,,,,, -311700,4.48762,0.6433386,,,,,,,,,,,,,, -311800,4.7198534,0.6965489,,,,,,,,,,,,,, -311900,4.659567,0.6893951,,,,,,,,,,,,,, -312000,4.4206724,0.5635514,,,,,,,,,,,,,, -312100,4.598171,0.71011853,,,,,,,,,,,,,, -312200,4.709574,0.63995236,,,,,,,,,,,,,, -312300,4.667969,0.6263931,,,,,,,,,,,,,, -312400,4.4590354,0.6087454,,,,,,,,,,,,,, -312500,5.0418005,0.72616756,,,,,,,,,,,,,, -312595,,,0.9604990482330322,0.145500361919403,0.7548199892044067,1.055170655250549,50000.0,0.6308000087738037,1.8177525997161863,10000.0,106145.30757308006,109850.11073994637,106145.30757308006,3680.8267703056335,13.000634670257568,0.0 -312600,4.8859105,0.6428027,,,,,,,,,,,,,, -312700,4.034131,0.54952776,,,,,,,,,,,,,, -312800,4.688289,0.6612494,,,,,,,,,,,,,, -312900,4.5298724,0.5921084,,,,,,,,,,,,,, -313000,4.430997,0.5910899,,,,,,,,,,,,,, -313100,4.611147,0.6305728,,,,,,,,,,,,,, -313200,4.2770667,0.59390426,,,,,,,,,,,,,, -313300,4.6830115,0.56363714,,,,,,,,,,,,,, -313400,4.461545,0.67302966,,,,,,,,,,,,,, -313500,5.0938506,0.6412481,,,,,,,,,,,,,, -313600,4.788566,0.6033768,,,,,,,,,,,,,, -313700,5.0815883,0.66645527,,,,,,,,,,,,,, -313800,4.455204,0.6392036,,,,,,,,,,,,,, -313900,4.516899,0.6125687,,,,,,,,,,,,,, -314000,4.319029,0.596132,,,,,,,,,,,,,, -314097,,,0.9604591727256776,0.1466144472360611,0.7550399899482727,1.0555564165115356,50000.0,0.6306000351905823,1.818840742111206,10000.0,106655.35900592804,110376.99791812895,106655.35900592804,3697.522133350372,13.088707447052002,0.0 -314100,4.3590517,0.6608096,,,,,,,,,,,,,, -314200,4.655082,0.5749517,,,,,,,,,,,,,, -314300,4.3213634,0.5998119,,,,,,,,,,,,,, -314400,4.2724915,0.5981132,,,,,,,,,,,,,, -314500,4.7915206,0.59443843,,,,,,,,,,,,,, -314600,4.2603183,0.6109198,,,,,,,,,,,,,, -314700,4.14023,0.6109174,,,,,,,,,,,,,, -314800,4.2998867,0.5861746,,,,,,,,,,,,,, -314900,4.4693627,0.62319964,,,,,,,,,,,,,, -315000,4.145863,0.5468047,,,,,,,,,,,,,, -315100,4.2685666,0.6174005,,,,,,,,,,,,,, -315200,4.433666,0.6196886,,,,,,,,,,,,,, -315300,4.549915,0.6560563,,,,,,,,,,,,,, -315400,4.572461,0.59511834,,,,,,,,,,,,,, -315500,4.7166405,0.61285394,,,,,,,,,,,,,, -315600,,,0.9605388641357422,0.1460705995559692,0.7547000050544739,1.0562903881072998,50000.0,0.631100058555603,1.820388913154602,10000.0,107165.2802259922,110903.63587498663,107165.2802259922,3714.099026441574,13.175846815109251,0.0 -315600,4.6421237,0.6485972,,,,,,,,,,,,,, -315700,4.4280143,0.6550857,,,,,,,,,,,,,, -315800,4.0889506,0.56894225,,,,,,,,,,,,,, -315900,4.446913,0.65995044,,,,,,,,,,,,,, -316000,4.734028,0.6260381,,,,,,,,,,,,,, -316100,4.566153,0.61924213,,,,,,,,,,,,,, -316200,4.684786,0.68008673,,,,,,,,,,,,,, -316300,3.9940634,0.52142465,,,,,,,,,,,,,, -316400,4.1767063,0.591767,,,,,,,,,,,,,, -316500,4.1975393,0.5286062,,,,,,,,,,,,,, -316600,5.3309045,0.668523,,,,,,,,,,,,,, -316700,4.7614627,0.6028088,,,,,,,,,,,,,, -316800,4.202161,0.61543965,,,,,,,,,,,,,, -316900,4.511777,0.6411251,,,,,,,,,,,,,, -317000,4.6373734,0.60413504,,,,,,,,,,,,,, -317100,4.076055,0.6199142,,,,,,,,,,,,,, -317102,,,0.9609972834587096,0.1439758241176605,0.7548399567604065,1.0555888414382937,50000.0,0.6310000419616699,1.8175864219665527,10000.0,107675.1609697342,111430.3076775074,107675.1609697342,3730.758029222488,13.25473165512085,0.0 -317200,4.8921633,0.66434586,,,,,,,,,,,,,, -317300,4.145042,0.6509393,,,,,,,,,,,,,, -317400,4.442381,0.5552836,,,,,,,,,,,,,, -317500,4.5483565,0.631101,,,,,,,,,,,,,, -317600,4.8792076,0.6700791,,,,,,,,,,,,,, -317700,4.612505,0.67352766,,,,,,,,,,,,,, -317800,4.497043,0.6179058,,,,,,,,,,,,,, -317900,4.3859067,0.5866869,,,,,,,,,,,,,, -318000,4.9075565,0.7383553,,,,,,,,,,,,,, -318100,5.0682774,0.6772171,,,,,,,,,,,,,, -318200,4.5038853,0.5900583,,,,,,,,,,,,,, -318300,5.115861,0.66239655,,,,,,,,,,,,,, -318400,4.619522,0.6384907,,,,,,,,,,,,,, -318500,4.398582,0.5843786,,,,,,,,,,,,,, -318600,4.5822835,0.7158266,,,,,,,,,,,,,, -318605,,,0.9608378410339355,0.1478148102760315,0.7548799514770508,1.056528925895691,50000.0,0.6300000548362732,1.822026610374451,10000.0,108185.10398316383,111957.12183737756,108185.10398316383,3747.4862003326416,13.34488844871521,0.0 -318700,4.554542,0.6728678,,,,,,,,,,,,,, -318800,4.382539,0.572691,,,,,,,,,,,,,, -318900,4.8050756,0.71692896,,,,,,,,,,,,,, -319000,4.329331,0.6196047,,,,,,,,,,,,,, -319100,4.5803337,0.580771,,,,,,,,,,,,,, -319200,4.4212894,0.669744,,,,,,,,,,,,,, -319300,4.4830675,0.62069124,,,,,,,,,,,,,, -319400,5.126795,0.62783885,,,,,,,,,,,,,, -319500,4.681538,0.6454673,,,,,,,,,,,,,, -319600,4.9736176,0.67611897,,,,,,,,,,,,,, -319700,4.855915,0.6683546,,,,,,,,,,,,,, -319800,4.6225595,0.64376223,,,,,,,,,,,,,, -319900,4.5000176,0.6442986,,,,,,,,,,,,,, -320000,4.6408453,0.62495244,,,,,,,,,,,,,, -320100,4.223119,0.6133301,,,,,,,,,,,,,, -320107,,,0.9611168503761292,0.1467342525720596,0.7544999718666077,1.055534839630127,50000.0,0.6300000548362732,1.8197250366210933,10000.0,108695.18180251122,112484.0316901207,108695.18180251122,3764.187334537506,13.422353982925417,0.0 -320200,4.8584476,0.7136049,,,,,,,,,,,,,, -320300,4.390007,0.5865673,,,,,,,,,,,,,, -320400,4.700721,0.6066725,,,,,,,,,,,,,, -320500,4.4240737,0.6615808,,,,,,,,,,,,,, -320600,4.4189835,0.61693364,,,,,,,,,,,,,, -320700,5.0075197,0.62935215,,,,,,,,,,,,,, -320800,4.494737,0.62828064,,,,,,,,,,,,,, -320900,4.455682,0.61932963,,,,,,,,,,,,,, -321000,4.469503,0.6164136,,,,,,,,,,,,,, -321100,4.6006155,0.7384668,,,,,,,,,,,,,, -321200,4.5455384,0.6496897,,,,,,,,,,,,,, -321300,4.8697395,0.65768725,,,,,,,,,,,,,, -321400,4.5807076,0.61333895,,,,,,,,,,,,,, -321500,4.2705684,0.6253185,,,,,,,,,,,,,, -321600,4.708189,0.66753507,,,,,,,,,,,,,, -321610,,,0.9595623016357422,0.14896921813488,0.7551999688148499,1.0555588006973269,50000.0,0.6301000118255615,1.8181936740875244,10000.0,109205.0882525444,113010.67347025871,109205.0882525444,3780.785356760025,13.505805730819702,0.0 -321700,4.522761,0.5814783,,,,,,,,,,,,,, -321800,4.4388204,0.60038316,,,,,,,,,,,,,, -321900,4.227535,0.5549988,,,,,,,,,,,,,, -322000,4.581529,0.6718259,,,,,,,,,,,,,, -322100,5.007957,0.6025592,,,,,,,,,,,,,, -322200,4.4678683,0.62956464,,,,,,,,,,,,,, -322300,4.4342213,0.59117055,,,,,,,,,,,,,, -322400,4.5214705,0.6987067,,,,,,,,,,,,,, -322500,4.8924713,0.6547758,,,,,,,,,,,,,, -322600,4.691054,0.68210584,,,,,,,,,,,,,, -322700,4.260876,0.6045084,,,,,,,,,,,,,, -322800,4.2760572,0.53630805,,,,,,,,,,,,,, -322900,4.969973,0.63068634,,,,,,,,,,,,,, -323000,4.961698,0.61934936,,,,,,,,,,,,,, -323100,4.8295236,0.634565,,,,,,,,,,,,,, -323112,,,0.9606584906578064,0.1447600871324539,0.7548999786376953,1.0555732250213623,50000.0,0.6301000118255615,1.8190803527832031,10000.0,109715.07874393465,113537.41925096512,109715.07874393465,3797.404670953751,13.588683843612673,0.0 -323200,4.3618283,0.5890151,,,,,,,,,,,,,, -323300,4.515459,0.6243224,,,,,,,,,,,,,, -323400,4.2148757,0.6327394,,,,,,,,,,,,,, -323500,4.0951705,0.58407634,,,,,,,,,,,,,, -323600,4.385571,0.6434443,,,,,,,,,,,,,, -323700,4.2936134,0.6098436,,,,,,,,,,,,,, -323800,4.6657753,0.6685463,,,,,,,,,,,,,, -323900,4.676975,0.52552533,,,,,,,,,,,,,, -324000,4.884424,0.65059954,,,,,,,,,,,,,, -324100,4.4096932,0.65104586,,,,,,,,,,,,,, -324200,4.3672338,0.58474535,,,,,,,,,,,,,, -324300,4.8097672,0.6585036,,,,,,,,,,,,,, -324400,4.902721,0.6801182,,,,,,,,,,,,,, -324500,4.29526,0.6427483,,,,,,,,,,,,,, -324600,4.7785907,0.72368306,,,,,,,,,,,,,, -324614,,,0.9600805044174194,0.1485663652420044,0.7547599673271179,1.055357575416565,50000.0,0.6305000185966492,1.81911849975586,10000.0,110224.97239756584,114064.11824631692,110224.97239756584,3814.077670812607,13.667658567428589,0.0 -324700,4.4893074,0.60814667,,,,,,,,,,,,,, -324800,4.6806555,0.6510447,,,,,,,,,,,,,, -324900,5.072727,0.6357806,,,,,,,,,,,,,, -325000,4.316452,0.60721225,,,,,,,,,,,,,, -325100,4.519814,0.6099593,,,,,,,,,,,,,, -325200,5.0739264,0.6549485,,,,,,,,,,,,,, -325300,4.3222833,0.5306864,,,,,,,,,,,,,, -325400,4.4750633,0.69571733,,,,,,,,,,,,,, -325500,4.617098,0.6494638,,,,,,,,,,,,,, -325600,4.3206034,0.5752745,,,,,,,,,,,,,, -325700,4.419355,0.68302953,,,,,,,,,,,,,, -325800,4.652739,0.61477953,,,,,,,,,,,,,, -325900,4.5905805,0.62818766,,,,,,,,,,,,,, -326000,4.66274,0.65783304,,,,,,,,,,,,,, -326100,4.5057964,0.5989297,,,,,,,,,,,,,, -326116,,,0.959582269191742,0.1480408757925033,0.754859983921051,1.055404782295227,50000.0,0.6304000020027161,1.819116711616516,10000.0,110734.83861660956,114590.7036960125,110734.83861660956,3830.655671596527,13.755608558654783,0.0 -326200,5.242365,0.6872277,,,,,,,,,,,,,, -326300,4.5000596,0.6257136,,,,,,,,,,,,,, -326400,4.3880496,0.55811214,,,,,,,,,,,,,, -326500,4.2851353,0.5621544,,,,,,,,,,,,,, -326600,4.276601,0.4853488,,,,,,,,,,,,,, -326700,4.098322,0.51368093,,,,,,,,,,,,,, -326800,4.427433,0.57796293,,,,,,,,,,,,,, -326900,4.2588887,0.55166185,,,,,,,,,,,,,, -327000,4.488791,0.6437306,,,,,,,,,,,,,, -327100,4.5002537,0.5811488,,,,,,,,,,,,,, -327200,4.471536,0.574382,,,,,,,,,,,,,, -327300,4.3653564,0.60495776,,,,,,,,,,,,,, -327400,4.672766,0.59390026,,,,,,,,,,,,,, -327500,4.373144,0.6139822,,,,,,,,,,,,,, -327600,4.2870398,0.5631225,,,,,,,,,,,,,, -327618,,,0.961355984210968,0.147781953215599,0.7549799680709839,1.0543818473815918,50000.0,0.6300000548362732,1.8186113834381104,10000.0,111244.93863582613,115117.5930685997,111244.93863582613,3847.3105340003967,13.836986780166626,0.0 -327700,4.457476,0.6719172,,,,,,,,,,,,,, -327800,4.281635,0.5575704,,,,,,,,,,,,,, -327900,5.0188527,0.67645615,,,,,,,,,,,,,, -328000,4.529715,0.6598358,,,,,,,,,,,,,, -328100,4.9497147,0.6126417,,,,,,,,,,,,,, -328200,4.5920053,0.61221397,,,,,,,,,,,,,, -328300,4.5506115,0.64781445,,,,,,,,,,,,,, -328400,4.3540134,0.64507526,,,,,,,,,,,,,, -328500,4.3050876,0.59752035,,,,,,,,,,,,,, -328600,4.5233803,0.6491208,,,,,,,,,,,,,, -328700,4.667986,0.6570117,,,,,,,,,,,,,, -328800,4.662255,0.64274347,,,,,,,,,,,,,, -328900,4.644317,0.62634236,,,,,,,,,,,,,, -329000,4.491535,0.5595482,,,,,,,,,,,,,, -329100,5.03407,0.7240001,,,,,,,,,,,,,, -329121,,,0.959363043308258,0.1508469581604004,0.7549799680709839,1.055037498474121,50000.0,0.6299000382423401,1.8170280456542969,10000.0,111755.11136102676,115644.44719862938,111755.11136102676,3863.8564035892487,13.919602155685425,0.0 -329200,4.4217052,0.64698756,,,,,,,,,,,,,, -329300,4.2941093,0.6745328,,,,,,,,,,,,,, -329400,4.457993,0.6290841,,,,,,,,,,,,,, -329500,4.6682267,0.6709279,,,,,,,,,,,,,, -329600,4.797456,0.6028699,,,,,,,,,,,,,, -329700,4.4735684,0.599958,,,,,,,,,,,,,, -329800,4.6026707,0.6916444,,,,,,,,,,,,,, -329900,4.451038,0.627757,,,,,,,,,,,,,, -330000,4.2042994,0.5627342,,,,,,,,,,,,,, -330100,4.320463,0.6087973,,,,,,,,,,,,,, -330200,4.1918073,0.57329535,,,,,,,,,,,,,, -330300,4.1471124,0.5747552,,,,,,,,,,,,,, -330400,4.1291614,0.50114834,,,,,,,,,,,,,, -330500,4.8878117,0.6421849,,,,,,,,,,,,,, -330600,5.8510857,0.68556935,,,,,,,,,,,,,, -330624,,,0.961336076259613,0.1432124525308609,0.7545599937438965,1.0553230047225952,50000.0,0.6305000185966492,1.8186101913452148,10000.0,112265.23484659196,116171.29470396042,112265.23484659196,3880.439862012863,14.007933139801024,0.0 -330700,4.441887,0.566596,,,,,,,,,,,,,, -330800,4.4454517,0.6630131,,,,,,,,,,,,,, -330900,5.129437,0.74193084,,,,,,,,,,,,,, -331000,4.740832,0.6995488,,,,,,,,,,,,,, -331100,4.149785,0.6096255,,,,,,,,,,,,,, -331200,4.6706676,0.5856329,,,,,,,,,,,,,, -331300,4.023397,0.57992756,,,,,,,,,,,,,, -331400,4.3539476,0.6441066,,,,,,,,,,,,,, -331500,4.2541547,0.5393471,,,,,,,,,,,,,, -331600,4.946591,0.58979046,,,,,,,,,,,,,, -331700,4.1020045,0.57913667,,,,,,,,,,,,,, -331800,4.4309483,0.57511574,,,,,,,,,,,,,, -331900,4.3243384,0.59406507,,,,,,,,,,,,,, -332000,4.7581744,0.6421606,,,,,,,,,,,,,, -332100,4.6321135,0.6492534,,,,,,,,,,,,,, -332127,,,0.959183633327484,0.1507417559623718,0.7548399567604065,1.0550819635391235,50000.0,0.6305000185966492,1.8171299695968628,10000.0,112775.25321817398,116698.1319694519,112775.25321817398,3897.117249965668,14.09683918952942,0.0 -332200,4.1872582,0.56937706,,,,,,,,,,,,,, -332300,4.673283,0.6333775,,,,,,,,,,,,,, -332400,4.8539248,0.66786224,,,,,,,,,,,,,, -332500,4.5879617,0.6958403,,,,,,,,,,,,,, -332600,4.3044724,0.543924,,,,,,,,,,,,,, -332700,4.345586,0.69939876,,,,,,,,,,,,,, -332800,4.406568,0.59364116,,,,,,,,,,,,,, -332900,4.0853033,0.5784398,,,,,,,,,,,,,, -333000,4.3517256,0.61680055,,,,,,,,,,,,,, -333100,4.895384,0.6378641,,,,,,,,,,,,,, -333200,4.5949383,0.5774846,,,,,,,,,,,,,, -333300,4.4353333,0.6269225,,,,,,,,,,,,,, -333400,4.800835,0.6648929,,,,,,,,,,,,,, -333500,4.258482,0.6201127,,,,,,,,,,,,,, -333600,4.4868593,0.60199404,,,,,,,,,,,,,, -333631,,,0.9612563848495485,0.1435246765613556,0.7547599673271179,1.0559492111206057,50000.0,0.6306000351905823,1.819736361503601,10000.0,113285.4351055622,117225.0557732582,113285.4351055622,3913.714976072312,14.186609506607056,0.0 -333700,5.059517,0.6428604,,,,,,,,,,,,,, -333800,4.782356,0.58185965,,,,,,,,,,,,,, -333900,4.826761,0.67804426,,,,,,,,,,,,,, -334000,4.0750065,0.5875008,,,,,,,,,,,,,, -334100,4.8202205,0.65068537,,,,,,,,,,,,,, -334200,4.727242,0.6213224,,,,,,,,,,,,,, -334300,4.314309,0.5809319,,,,,,,,,,,,,, -334400,4.377478,0.5503746,,,,,,,,,,,,,, -334500,4.408259,0.59111166,,,,,,,,,,,,,, -334600,4.640266,0.6998857,,,,,,,,,,,,,, -334700,4.47922,0.60462606,,,,,,,,,,,,,, -334800,4.801911,0.58659714,,,,,,,,,,,,,, -334900,4.447976,0.6275383,,,,,,,,,,,,,, -335000,5.116127,0.6087797,,,,,,,,,,,,,, -335100,4.440822,0.6355954,,,,,,,,,,,,,, -335134,,,0.9598612785339355,0.1471763551235199,0.7553199529647827,1.055428147315979,50000.0,0.6310000419616699,1.8180768489837649,10000.0,113795.54326581956,117751.9583992958,113795.54326581956,3930.369282960892,14.27349853515625,0.0 -335200,4.618179,0.61771387,,,,,,,,,,,,,, -335300,4.57057,0.56455016,,,,,,,,,,,,,, -335400,4.651554,0.6655446,,,,,,,,,,,,,, -335500,4.3329945,0.6093943,,,,,,,,,,,,,, -335600,4.813784,0.6874056,,,,,,,,,,,,,, -335700,4.3041472,0.53234684,,,,,,,,,,,,,, -335800,4.472891,0.59825194,,,,,,,,,,,,,, -335900,4.8999186,0.6220366,,,,,,,,,,,,,, -336000,4.170463,0.59221554,,,,,,,,,,,,,, -336100,4.7550282,0.6544292,,,,,,,,,,,,,, -336200,4.546857,0.582117,,,,,,,,,,,,,, -336300,4.97086,0.6090042,,,,,,,,,,,,,, -336400,4.6647635,0.5901786,,,,,,,,,,,,,, -336500,4.5122914,0.5935404,,,,,,,,,,,,,, -336600,4.3569775,0.6116705,,,,,,,,,,,,,, -336637,,,0.961136758327484,0.147893875837326,0.7549200057983398,1.0557163953781128,50000.0,0.6302000284194946,1.8188215494155884,10000.0,114305.43227887154,118278.74791812895,114305.43227887154,3947.1287870407104,14.36074447631836,0.0 -336700,4.5532784,0.6239457,,,,,,,,,,,,,, -336800,4.4102,0.6551038,,,,,,,,,,,,,, -336900,3.9511468,0.6196784,,,,,,,,,,,,,, -337000,4.760381,0.62701285,,,,,,,,,,,,,, -337100,4.740754,0.6430383,,,,,,,,,,,,,, -337200,4.7743893,0.629913,,,,,,,,,,,,,, -337300,4.5980215,0.6489676,,,,,,,,,,,,,, -337400,4.7635145,0.59523714,,,,,,,,,,,,,, -337500,4.195912,0.6427495,,,,,,,,,,,,,, -337600,4.553155,0.65234315,,,,,,,,,,,,,, -337700,4.5120363,0.5909631,,,,,,,,,,,,,, -337800,4.621059,0.67295474,,,,,,,,,,,,,, -337900,4.6280622,0.5748946,,,,,,,,,,,,,, -338000,4.58306,0.69004893,,,,,,,,,,,,,, -338100,4.4681673,0.6455144,,,,,,,,,,,,,, -338139,,,0.9604790806770324,0.1472291797399521,0.7548399567604065,1.055741786956787,50000.0,0.6310000419616699,1.8178565502166748,10000.0,114815.45590782166,118805.63232278824,114815.45590782166,3963.851808786392,14.4454824924469,0.0 -338200,4.5285945,0.6435713,,,,,,,,,,,,,, -338300,4.4431252,0.62666553,,,,,,,,,,,,,, -338400,4.7484913,0.69400907,,,,,,,,,,,,,, -338500,4.4579315,0.63862693,,,,,,,,,,,,,, -338600,4.8253856,0.63181996,,,,,,,,,,,,,, -338700,4.1784506,0.59041995,,,,,,,,,,,,,, -338800,4.248763,0.6177679,,,,,,,,,,,,,, -338900,4.6742177,0.5733839,,,,,,,,,,,,,, -339000,4.5903993,0.64047104,,,,,,,,,,,,,, -339100,4.4477224,0.5486289,,,,,,,,,,,,,, -339200,4.661248,0.6066473,,,,,,,,,,,,,, -339300,4.475865,0.5875928,,,,,,,,,,,,,, -339400,4.5776687,0.6858231,,,,,,,,,,,,,, -339500,4.544762,0.6276111,,,,,,,,,,,,,, -339600,4.8249164,0.7220892,,,,,,,,,,,,,, -339643,,,0.9617944359779358,0.1409309804439544,0.7549600005149841,1.0552440881729126,50000.0,0.6303000450134277,1.8199061155319207,10000.0,115325.57354831696,119332.43783807756,115325.57354831696,3980.400431394577,14.531226396560667,0.0 -339700,4.3655496,0.62008035,,,,,,,,,,,,,, -339800,4.513542,0.65852654,,,,,,,,,,,,,, -339900,4.4106507,0.5987098,,,,,,,,,,,,,, -340000,4.873323,0.6222078,,,,,,,,,,,,,, -340100,4.5223613,0.6005267,,,,,,,,,,,,,, -340200,4.3836884,0.63565016,,,,,,,,,,,,,, -340300,4.2763815,0.5756212,,,,,,,,,,,,,, -340400,4.7342405,0.6281896,,,,,,,,,,,,,, -340500,4.181526,0.6033042,,,,,,,,,,,,,, -340600,4.337679,0.6314785,,,,,,,,,,,,,, -340700,4.4728847,0.6609471,,,,,,,,,,,,,, -340800,4.766816,0.6169734,,,,,,,,,,,,,, -340900,4.1962667,0.5591377,,,,,,,,,,,,,, -341000,4.5042715,0.65643644,,,,,,,,,,,,,, -341100,4.4608545,0.66147333,,,,,,,,,,,,,, -341147,,,0.9620137214660645,0.1430947631597519,0.7548999786376953,1.0565391778945925,50000.0,0.6305000185966492,1.820884227752685,10000.0,115835.75323462486,119859.64888954164,115835.75323462486,3997.29501080513,14.61514163017273,0.0 -341200,4.4912176,0.6330322,,,,,,,,,,,,,, -341300,4.45919,0.6372284,,,,,,,,,,,,,, -341400,4.670734,0.57343173,,,,,,,,,,,,,, -341500,4.0466666,0.5927837,,,,,,,,,,,,,, -341600,4.2201824,0.60441333,,,,,,,,,,,,,, -341700,4.785768,0.6544172,,,,,,,,,,,,,, -341800,4.444482,0.57557094,,,,,,,,,,,,,, -341900,4.2226315,0.628339,,,,,,,,,,,,,, -342000,4.9418263,0.6387061,,,,,,,,,,,,,, -342100,4.2379255,0.5960324,,,,,,,,,,,,,, -342200,4.8559437,0.64978147,,,,,,,,,,,,,, -342300,4.313659,0.60550195,,,,,,,,,,,,,, -342400,4.310543,0.5660569,,,,,,,,,,,,,, -342500,4.838654,0.59943974,,,,,,,,,,,,,, -342600,5.418806,0.60755277,,,,,,,,,,,,,, -342650,,,0.9606783986091614,0.1467174738645553,0.7546600103378296,1.0551903247833252,50000.0,0.6305000185966492,1.8172508478164675,10000.0,116345.7600734234,120386.4504327774,116345.7600734234,4013.947923898697,14.70347285270691,0.0 -342700,4.6614013,0.6156491,,,,,,,,,,,,,, -342800,4.7551584,0.7000366,,,,,,,,,,,,,, -342900,4.6008687,0.60245526,,,,,,,,,,,,,, -343000,4.6634073,0.61384904,,,,,,,,,,,,,, -343100,4.966863,0.67901087,,,,,,,,,,,,,, -343200,4.946245,0.7243622,,,,,,,,,,,,,, -343300,5.020684,0.65489864,,,,,,,,,,,,,, -343400,4.6063395,0.63848513,,,,,,,,,,,,,, -343500,4.4984455,0.6430084,,,,,,,,,,,,,, -343600,4.601219,0.63171256,,,,,,,,,,,,,, -343700,4.7815886,0.6758374,,,,,,,,,,,,,, -343800,4.503327,0.6072694,,,,,,,,,,,,,, -343900,4.6576037,0.64530843,,,,,,,,,,,,,, -344000,4.5401726,0.6918345,,,,,,,,,,,,,, -344100,4.2850075,0.6236906,,,,,,,,,,,,,, -344153,,,0.9591438174247742,0.1492006033658981,0.7547199726104736,1.0547455549240112,50000.0,0.6302000284194946,1.8171762228012085,10000.0,116855.8805770874,120914.15520620346,116855.8805770874,4031.3891978263855,14.792497396469116,0.0 -344200,4.7879543,0.6366283,,,,,,,,,,,,,, -344300,4.505719,0.59525883,,,,,,,,,,,,,, -344400,4.9910517,0.6972769,,,,,,,,,,,,,, -344500,4.2950597,0.6442978,,,,,,,,,,,,,, -344600,5.0349326,0.5786329,,,,,,,,,,,,,, -344700,4.527635,0.6642848,,,,,,,,,,,,,, -344800,4.747159,0.67291874,,,,,,,,,,,,,, -344900,4.3959503,0.60131645,,,,,,,,,,,,,, -345000,4.1830754,0.56042176,,,,,,,,,,,,,, -345100,4.5436144,0.580788,,,,,,,,,,,,,, -345200,4.6874228,0.71582377,,,,,,,,,,,,,, -345300,4.594463,0.60185164,,,,,,,,,,,,,, -345400,4.6464534,0.67888546,,,,,,,,,,,,,, -345500,4.4942827,0.5942767,,,,,,,,,,,,,, -345600,5.784597,0.74058574,,,,,,,,,,,,,, -345656,,,0.9612762928009032,0.1464512199163437,0.7548999786376953,1.0546934604644775,50000.0,0.6303000450134277,1.8182249069213867,10000.0,117365.90293121338,121440.9950222969,117365.90293121338,4048.061734199524,14.883959531784058,0.0 -345700,4.6555514,0.64361656,,,,,,,,,,,,,, -345800,4.4712896,0.65990984,,,,,,,,,,,,,, -345900,4.216454,0.64590555,,,,,,,,,,,,,, -346000,4.4900584,0.6104983,,,,,,,,,,,,,, -346100,4.9855905,0.6662599,,,,,,,,,,,,,, -346200,4.638574,0.63496816,,,,,,,,,,,,,, -346300,4.3807354,0.58719414,,,,,,,,,,,,,, -346400,4.588401,0.635086,,,,,,,,,,,,,, -346500,5.2145534,0.65294826,,,,,,,,,,,,,, -346600,4.1794014,0.5774829,,,,,,,,,,,,,, -346700,4.344134,0.5813108,,,,,,,,,,,,,, -346800,4.42317,0.6752751,,,,,,,,,,,,,, -346900,4.148391,0.5641873,,,,,,,,,,,,,, -347000,4.1556797,0.60399973,,,,,,,,,,,,,, -347100,4.477062,0.6314528,,,,,,,,,,,,,, -347158,,,0.9612165093421936,0.1453184336423874,0.7544599771499634,1.0555799007415771,50000.0,0.6308000087738037,1.817484259605408,10000.0,117875.79716300964,121967.6102962494,117875.79716300964,4064.625182628632,14.9880850315094,0.0 -347200,4.4361835,0.6235028,,,,,,,,,,,,,, -347300,4.685627,0.70961195,,,,,,,,,,,,,, -347400,4.7569065,0.59745014,,,,,,,,,,,,,, -347500,4.9276013,0.6876088,,,,,,,,,,,,,, -347600,4.679993,0.6051993,,,,,,,,,,,,,, -347700,5.0030165,0.63020176,,,,,,,,,,,,,, -347800,4.381632,0.6100111,,,,,,,,,,,,,, -347900,4.7147856,0.63803005,,,,,,,,,,,,,, -348000,4.7948136,0.5492575,,,,,,,,,,,,,, -348100,4.569196,0.5994177,,,,,,,,,,,,,, -348200,4.192208,0.67703635,,,,,,,,,,,,,, -348300,5.0697937,0.65705127,,,,,,,,,,,,,, -348400,4.6087155,0.58620465,,,,,,,,,,,,,, -348500,5.3519983,0.63058335,,,,,,,,,,,,,, -348600,4.5400133,0.566047,,,,,,,,,,,,,, -348661,,,0.9612962007522584,0.1442330181598663,0.7548999786376953,1.055746078491211,50000.0,0.6313000321388245,1.8207018375396729,10000.0,118385.72451281548,122494.24131369592,118385.72451281548,4081.1895790100098,15.07457399368286,0.0 -348700,4.643257,0.5872741,,,,,,,,,,,,,, -348800,4.5178165,0.6007672,,,,,,,,,,,,,, -348900,4.4825206,0.5410457,,,,,,,,,,,,,, -349000,4.654554,0.6295834,,,,,,,,,,,,,, -349100,4.253131,0.58687395,,,,,,,,,,,,,, -349200,4.653882,0.57876956,,,,,,,,,,,,,, -349300,4.8284793,0.6951608,,,,,,,,,,,,,, -349400,4.290907,0.54369986,,,,,,,,,,,,,, -349500,4.4296575,0.5955,,,,,,,,,,,,,, -349600,4.250488,0.6219311,,,,,,,,,,,,,, -349700,4.498273,0.6451689,,,,,,,,,,,,,, -349800,4.5332294,0.6381841,,,,,,,,,,,,,, -349900,4.411858,0.6386125,,,,,,,,,,,,,, -350000,4.8055387,0.6177998,,,,,,,,,,,,,, -350100,4.3362327,0.62651503,,,,,,,,,,,,,, -350163,,,0.9616350531578064,0.1450804620981216,0.7546199560165405,1.0556223392486572,50000.0,0.6304000020027161,1.819610834121704,10000.0,118895.66305184364,123020.94657683372,118895.66305184364,4097.764590501785,15.211965799331663,0.0 -350200,4.2917733,0.6117655,,,,,,,,,,,,,, -350300,5.2191596,0.6930599,,,,,,,,,,,,,, -350400,4.2919993,0.70219505,,,,,,,,,,,,,, -350500,4.8034816,0.6321287,,,,,,,,,,,,,, -350600,4.4060345,0.6522461,,,,,,,,,,,,,, -350700,4.522531,0.5723843,,,,,,,,,,,,,, -350800,4.421379,0.64666533,,,,,,,,,,,,,, -350900,4.560811,0.5895139,,,,,,,,,,,,,, -351000,4.71713,0.5956145,,,,,,,,,,,,,, -351100,5.316677,0.6451446,,,,,,,,,,,,,, -351200,4.6672373,0.6560818,,,,,,,,,,,,,, -351300,4.551947,0.64157504,,,,,,,,,,,,,, -351400,5.1196933,0.58352906,,,,,,,,,,,,,, -351500,4.620759,0.60838765,,,,,,,,,,,,,, -351600,4.143842,0.50214905,,,,,,,,,,,,,, -351667,,,0.9608577489852904,0.1461744904518127,0.754859983921051,1.0564955472946167,50000.0,0.6300000548362732,1.8202815055847168,10000.0,119405.83976817132,123547.9008114338,119405.83976817132,4114.402928113937,15.298362493515016,0.0 -351700,4.2996716,0.5556328,,,,,,,,,,,,,, -351800,4.2989078,0.6086073,,,,,,,,,,,,,, -351900,4.5698624,0.6767576,,,,,,,,,,,,,, -352000,4.219383,0.638599,,,,,,,,,,,,,, -352100,4.6520915,0.6237277,,,,,,,,,,,,,, -352200,4.611706,0.63786054,,,,,,,,,,,,,, -352300,4.3403387,0.6010226,,,,,,,,,,,,,, -352400,4.316784,0.6366068,,,,,,,,,,,,,, -352500,4.8357463,0.5569735,,,,,,,,,,,,,, -352600,4.7339983,0.6839164,,,,,,,,,,,,,, -352700,4.589917,0.6393219,,,,,,,,,,,,,, -352800,4.4224877,0.5735319,,,,,,,,,,,,,, -352900,4.6000543,0.6190342,,,,,,,,,,,,,, -353000,4.7765093,0.6577302,,,,,,,,,,,,,, -353100,4.9526153,0.61729115,,,,,,,,,,,,,, -353169,,,0.9614157676696776,0.1444696485996246,0.7545199990272522,1.0556888580322266,50000.0,0.6300000548362732,1.819525122642517,10000.0,119916.00923371316,124075.02059555054,119916.00923371316,4131.206185817719,15.391158103942873,0.0 -353200,4.5867043,0.6031124,,,,,,,,,,,,,, -353300,4.4261417,0.6141552,,,,,,,,,,,,,, -353400,4.2530713,0.56569695,,,,,,,,,,,,,, -353500,4.552643,0.61924154,,,,,,,,,,,,,, -353600,4.7925854,0.6794592,,,,,,,,,,,,,, -353700,4.4050655,0.5803034,,,,,,,,,,,,,, -353800,4.378095,0.5856308,,,,,,,,,,,,,, -353900,5.0027227,0.7034953,,,,,,,,,,,,,, -354000,4.517964,0.5998295,,,,,,,,,,,,,, -354100,4.704229,0.6210556,,,,,,,,,,,,,, -354200,4.53181,0.68733287,,,,,,,,,,,,,, -354300,4.936064,0.65493673,,,,,,,,,,,,,, -354400,4.818939,0.5987645,,,,,,,,,,,,,, -354500,4.5826616,0.6066969,,,,,,,,,,,,,, -354600,4.5266023,0.62166756,,,,,,,,,,,,,, -354672,,,0.9606186151504515,0.1470703184604644,0.7550999522209167,1.0547600984573364,50000.0,0.6304000020027161,1.817675828933716,10000.0,120425.9566578865,124601.83222198486,120425.9566578865,4147.925310850143,15.483352661132812,0.0 -354700,4.823781,0.6544213,,,,,,,,,,,,,, -354800,4.3985505,0.6495458,,,,,,,,,,,,,, -354900,4.5921173,0.5906449,,,,,,,,,,,,,, -355000,4.773826,0.61185163,,,,,,,,,,,,,, -355100,4.3953805,0.53805923,,,,,,,,,,,,,, -355200,3.988493,0.5461437,,,,,,,,,,,,,, -355300,4.6232085,0.70922077,,,,,,,,,,,,,, -355400,4.24769,0.6101371,,,,,,,,,,,,,, -355500,4.4909244,0.6295193,,,,,,,,,,,,,, -355600,4.4235244,0.6048137,,,,,,,,,,,,,, -355700,4.517351,0.5875057,,,,,,,,,,,,,, -355800,4.8557367,0.6721755,,,,,,,,,,,,,, -355900,4.6157827,0.5884408,,,,,,,,,,,,,, -356000,4.7392445,0.68170893,,,,,,,,,,,,,, -356100,4.7234974,0.627115,,,,,,,,,,,,,, -356175,,,0.9604591727256776,0.1472638845443725,0.7554199695587158,1.056190848350525,50000.0,0.629800021648407,1.819735646247864,10000.0,120936.03289294244,125128.67629480362,120936.03289294244,4164.55343580246,15.569127082824709,0.0 -356200,4.512811,0.6190815,,,,,,,,,,,,,, -356300,4.826377,0.7152945,,,,,,,,,,,,,, -356400,4.481049,0.5676018,,,,,,,,,,,,,, -356500,4.373455,0.5670357,,,,,,,,,,,,,, -356600,4.78316,0.67261755,,,,,,,,,,,,,, -356700,4.6424007,0.5718487,,,,,,,,,,,,,, -356800,4.7099814,0.6135857,,,,,,,,,,,,,, -356900,4.5553637,0.64479905,,,,,,,,,,,,,, -357000,4.878528,0.678249,,,,,,,,,,,,,, -357100,4.2985263,0.6437117,,,,,,,,,,,,,, -357200,4.805004,0.6735035,,,,,,,,,,,,,, -357300,4.818524,0.6715336,,,,,,,,,,,,,, -357400,4.675189,0.7269581,,,,,,,,,,,,,, -357500,4.6459994,0.6228238,,,,,,,,,,,,,, -357600,4.3564596,0.56627417,,,,,,,,,,,,,, -357678,,,0.9614756107330322,0.146551638841629,0.7547399997711182,1.0553967952728271,50000.0,0.6301000118255615,1.8201425075531008,10000.0,121446.13084578514,125655.50641322136,121446.13084578514,4181.1410665512085,15.66112756729126,0.0 -357700,4.6427045,0.58544415,,,,,,,,,,,,,, -357800,4.8470488,0.56754106,,,,,,,,,,,,,, -357900,5.26247,0.7036045,,,,,,,,,,,,,, -358000,4.5552926,0.64694655,,,,,,,,,,,,,, -358100,4.588152,0.5817654,,,,,,,,,,,,,, -358200,4.436253,0.6081218,,,,,,,,,,,,,, -358300,4.5447907,0.5865189,,,,,,,,,,,,,, -358400,4.522937,0.6530047,,,,,,,,,,,,,, -358500,4.545693,0.5410258,,,,,,,,,,,,,, -358600,4.745327,0.6148635,,,,,,,,,,,,,, -358700,4.628909,0.6115061,,,,,,,,,,,,,, -358800,4.4253983,0.5962583,,,,,,,,,,,,,, -358900,4.2384124,0.6319094,,,,,,,,,,,,,, -359000,4.454974,0.60467523,,,,,,,,,,,,,, -359100,4.518299,0.55998725,,,,,,,,,,,,,, -359180,,,0.9594826102256776,0.1474414020776748,0.7551800012588501,1.055631160736084,50000.0,0.6306000351905823,1.819602012634277,10000.0,121956.0479941368,126182.3363289833,121956.0479941368,4197.899563074112,15.761876106262209,0.0 -359200,4.810805,0.5983824,,,,,,,,,,,,,, -359300,4.8054223,0.622993,,,,,,,,,,,,,, -359400,4.057391,0.57764345,,,,,,,,,,,,,, -359500,4.43857,0.5476163,,,,,,,,,,,,,, -359600,4.7689037,0.5785804,,,,,,,,,,,,,, -359700,4.449633,0.6836734,,,,,,,,,,,,,, -359800,5.259256,0.6898557,,,,,,,,,,,,,, -359900,4.58142,0.6538256,,,,,,,,,,,,,, -360000,4.4959874,0.6930423,,,,,,,,,,,,,, -360100,5.097063,0.6810641,,,,,,,,,,,,,, -360200,4.5270724,0.61640763,,,,,,,,,,,,,, -360300,4.577745,0.60295033,,,,,,,,,,,,,, -360400,4.4802322,0.57801485,,,,,,,,,,,,,, -360500,4.5012875,0.6509614,,,,,,,,,,,,,, -360600,4.9651637,0.62156916,,,,,,,,,,,,,, -360684,,,0.9600605964660645,0.1481878459453582,0.7550599575042725,1.0551674365997314,50000.0,0.6294000148773193,1.8188576698303225,10000.0,122466.16347789764,126709.16116809844,122466.16347789764,4214.466181278229,15.852198362350464,0.0 -360700,4.363109,0.61640656,,,,,,,,,,,,,, -360800,4.686256,0.5622087,,,,,,,,,,,,,, -360900,4.733987,0.6833312,,,,,,,,,,,,,, -361000,4.4988375,0.58644,,,,,,,,,,,,,, -361100,4.491716,0.5790851,,,,,,,,,,,,,, -361200,4.346336,0.6593241,,,,,,,,,,,,,, -361300,4.9663463,0.66859555,,,,,,,,,,,,,, -361400,4.484512,0.6429175,,,,,,,,,,,,,, -361500,4.087868,0.5145403,,,,,,,,,,,,,, -361600,5.212382,0.696049,,,,,,,,,,,,,, -361700,4.299857,0.634678,,,,,,,,,,,,,, -361800,4.6786823,0.66063887,,,,,,,,,,,,,, -361900,4.8121505,0.65306604,,,,,,,,,,,,,, -362000,4.8884344,0.56620216,,,,,,,,,,,,,, -362100,4.62525,0.66685665,,,,,,,,,,,,,, -362187,,,0.9610371589660645,0.1449005752801895,0.7542600035667419,1.055108666419983,50000.0,0.6307000517845154,1.818588137626648,10000.0,122976.25839185716,127236.59667944908,122976.25839185716,4231.663587808609,15.942219734191896,0.0 -362200,4.633202,0.67204636,,,,,,,,,,,,,, -362300,4.425606,0.6566342,,,,,,,,,,,,,, -362400,4.329495,0.58476317,,,,,,,,,,,,,, -362500,4.564738,0.5953779,,,,,,,,,,,,,, -362600,4.8922005,0.6588594,,,,,,,,,,,,,, -362700,4.792595,0.5812323,,,,,,,,,,,,,, -362800,4.6846614,0.63428175,,,,,,,,,,,,,, -362900,4.709565,0.67443866,,,,,,,,,,,,,, -363000,4.200916,0.61381644,,,,,,,,,,,,,, -363100,4.5397067,0.5950528,,,,,,,,,,,,,, -363200,4.4220953,0.6445337,,,,,,,,,,,,,, -363300,4.541995,0.5715703,,,,,,,,,,,,,, -363400,4.855818,0.5969581,,,,,,,,,,,,,, -363500,4.562251,0.6630314,,,,,,,,,,,,,, -363600,4.727641,0.65510875,,,,,,,,,,,,,, -363690,,,0.9592434167861938,0.1494711190462112,0.7552199959754944,1.0559954643249512,50000.0,0.6306000351905823,1.8202086687088013,10000.0,123486.3828933239,127763.50550031662,123486.3828933239,4248.319851875305,16.017550230026245,0.0 -363700,4.4620724,0.61007434,,,,,,,,,,,,,, -363800,4.547294,0.6858971,,,,,,,,,,,,,, -363900,4.6854186,0.6067611,,,,,,,,,,,,,, -364000,4.667873,0.6420559,,,,,,,,,,,,,, -364100,4.7336793,0.70745957,,,,,,,,,,,,,, -364200,4.5565686,0.6305166,,,,,,,,,,,,,, -364300,4.127305,0.6104499,,,,,,,,,,,,,, -364400,4.7064595,0.62395775,,,,,,,,,,,,,, -364500,4.628088,0.5928633,,,,,,,,,,,,,, -364600,5.112932,0.64984137,,,,,,,,,,,,,, -364700,4.9604716,0.6962603,,,,,,,,,,,,,, -364800,4.4424524,0.6510547,,,,,,,,,,,,,, -364900,4.363899,0.6393282,,,,,,,,,,,,,, -365000,3.7959497,0.5200188,,,,,,,,,,,,,, -365100,4.4069376,0.5958366,,,,,,,,,,,,,, -365192,,,0.9596819281578064,0.1469049155712127,0.7549399733543396,1.0571703910827637,50000.0,0.6299000382423401,1.8223227262496948,10000.0,123996.3074054718,128290.24997830392,123996.3074054718,4264.995944738388,16.10771369934082,0.0 -365200,4.1638474,0.55872506,,,,,,,,,,,,,, -365300,4.428802,0.6066615,,,,,,,,,,,,,, -365400,4.4059095,0.63156205,,,,,,,,,,,,,, -365500,4.5739927,0.6279378,,,,,,,,,,,,,, -365600,4.6272097,0.6689496,,,,,,,,,,,,,, -365700,4.6162667,0.62460923,,,,,,,,,,,,,, -365800,4.6950045,0.6306506,,,,,,,,,,,,,, -365900,4.6419697,0.61379313,,,,,,,,,,,,,, -366000,4.733806,0.64695084,,,,,,,,,,,,,, -366100,4.404717,0.5732491,,,,,,,,,,,,,, -366200,4.8394303,0.6052296,,,,,,,,,,,,,, -366300,4.2889853,0.5793106,,,,,,,,,,,,,, -366400,4.9529567,0.672264,,,,,,,,,,,,,, -366500,4.4267893,0.5807779,,,,,,,,,,,,,, -366600,5.117855,0.66548735,,,,,,,,,,,,,, -366695,,,0.961156725883484,0.1489018052816391,0.7549999952316284,1.055810809135437,50000.0,0.6302000284194946,1.8210707902908323,10000.0,124506.4336566925,128817.22952365877,124506.4336566925,4281.69885134697,16.204543352127075,0.0 -366700,4.5609574,0.64014804,,,,,,,,,,,,,, -366800,4.4471354,0.62498647,,,,,,,,,,,,,, -366900,4.552097,0.6524821,,,,,,,,,,,,,, -367000,4.8642817,0.6773733,,,,,,,,,,,,,, -367100,4.5038185,0.5789217,,,,,,,,,,,,,, -367200,4.3395495,0.6313646,,,,,,,,,,,,,, -367300,4.5160027,0.62690043,,,,,,,,,,,,,, -367400,4.575046,0.58515257,,,,,,,,,,,,,, -367500,4.4498672,0.5883498,,,,,,,,,,,,,, -367600,4.5506973,0.56667835,,,,,,,,,,,,,, -367700,4.6177745,0.6436709,,,,,,,,,,,,,, -367800,4.2341027,0.59424627,,,,,,,,,,,,,, -367900,4.6136622,0.6231982,,,,,,,,,,,,,, -368000,4.2736692,0.5835589,,,,,,,,,,,,,, -368100,4.5216537,0.63758063,,,,,,,,,,,,,, -368197,,,0.9597018361091614,0.1484669595956802,0.7549999952316284,1.0560575723648071,50000.0,0.6299000382423401,1.818937420845032,10000.0,125016.474984169,129344.00648355484,125016.474984169,4298.291492462158,16.294630765914917,0.0 -368200,4.53312,0.62312645,,,,,,,,,,,,,, -368300,4.529741,0.60096604,,,,,,,,,,,,,, -368400,4.4888673,0.6016563,,,,,,,,,,,,,, -368500,4.1376357,0.58307374,,,,,,,,,,,,,, -368600,5.008194,0.58653456,,,,,,,,,,,,,, -368700,4.5954266,0.7076293,,,,,,,,,,,,,, -368800,4.8152103,0.6133246,,,,,,,,,,,,,, -368900,4.4126587,0.6240714,,,,,,,,,,,,,, -369000,4.3387403,0.6423052,,,,,,,,,,,,,, -369100,4.506056,0.66102254,,,,,,,,,,,,,, -369200,4.7146673,0.5947284,,,,,,,,,,,,,, -369300,4.51674,0.65892196,,,,,,,,,,,,,, -369400,4.8427334,0.71711814,,,,,,,,,,,,,, -369500,4.3636384,0.6010212,,,,,,,,,,,,,, -369600,4.4477777,0.6025498,,,,,,,,,,,,,, -369700,4.684122,0.65963423,,,,,,,,,,,,,, -369701,,,0.9613958597183228,0.1455031484365463,0.7548799514770508,1.054102063179016,50000.0,0.6312000155448914,1.8172831535339355,10000.0,125526.76858377457,129871.15259885788,125526.76858377457,4314.997852563858,16.387300729751587,0.0 -369800,4.460103,0.605142,,,,,,,,,,,,,, -369900,4.2074647,0.6324999,,,,,,,,,,,,,, -370000,4.7292814,0.68845206,,,,,,,,,,,,,, -370100,4.449036,0.59594595,,,,,,,,,,,,,, -370200,4.2850027,0.5647155,,,,,,,,,,,,,, -370300,4.274007,0.54675806,,,,,,,,,,,,,, -370400,4.2207355,0.55048496,,,,,,,,,,,,,, -370500,4.510064,0.6791266,,,,,,,,,,,,,, -370600,4.8590603,0.6553058,,,,,,,,,,,,,, -370700,4.437414,0.64763075,,,,,,,,,,,,,, -370800,4.7300835,0.57118875,,,,,,,,,,,,,, -370900,4.796701,0.6632315,,,,,,,,,,,,,, -371000,4.502115,0.67473614,,,,,,,,,,,,,, -371100,4.2822137,0.5879767,,,,,,,,,,,,,, -371200,4.1526165,0.57665473,,,,,,,,,,,,,, -371204,,,0.959382951259613,0.1471273601055145,0.7553399801254272,1.0547292232513428,50000.0,0.6304000020027161,1.81916081905365,10000.0,126036.9502067566,130398.2058544159,126036.9502067566,4331.719168186188,16.48409938812256,0.0 -371300,4.4212656,0.67587125,,,,,,,,,,,,,, -371400,4.466702,0.58131856,,,,,,,,,,,,,, -371500,4.921531,0.6009435,,,,,,,,,,,,,, -371600,4.5870852,0.64789045,,,,,,,,,,,,,, -371700,4.429031,0.6124222,,,,,,,,,,,,,, -371800,4.447407,0.6244933,,,,,,,,,,,,,, -371900,5.1201262,0.7146453,,,,,,,,,,,,,, -372000,4.497021,0.6469692,,,,,,,,,,,,,, -372100,4.4820213,0.6241421,,,,,,,,,,,,,, -372200,5.066617,0.7212448,,,,,,,,,,,,,, -372300,4.898769,0.61685574,,,,,,,,,,,,,, -372400,4.2013216,0.5897526,,,,,,,,,,,,,, -372500,4.5778866,0.6598913,,,,,,,,,,,,,, -372600,4.409183,0.62289745,,,,,,,,,,,,,, -372700,4.255635,0.6027622,,,,,,,,,,,,,, -372707,,,0.9607780575752258,0.1458850502967834,0.7550399899482727,1.0556190013885498,50000.0,0.6306000351905823,1.8193237781524656,10000.0,126546.84605765344,130924.82411670683,126546.84605765344,4348.296098232269,16.576444387435913,0.0 -372800,4.5419555,0.61718655,,,,,,,,,,,,,, -372900,4.8693666,0.6443012,,,,,,,,,,,,,, -373000,4.7931027,0.69672585,,,,,,,,,,,,,, -373100,4.480054,0.636175,,,,,,,,,,,,,, -373200,4.305132,0.5819454,,,,,,,,,,,,,, -373300,4.428554,0.6452395,,,,,,,,,,,,,, -373400,4.503958,0.5475762,,,,,,,,,,,,,, -373500,4.4696393,0.58231443,,,,,,,,,,,,,, -373600,4.613126,0.6455519,,,,,,,,,,,,,, -373700,4.6438537,0.65990216,,,,,,,,,,,,,, -373800,4.6900783,0.628693,,,,,,,,,,,,,, -373900,4.8629556,0.6799786,,,,,,,,,,,,,, -374000,4.5171266,0.6917694,,,,,,,,,,,,,, -374100,4.680982,0.64889705,,,,,,,,,,,,,, -374200,4.4377728,0.6367931,,,,,,,,,,,,,, -374209,,,0.9608178734779358,0.1466042101383209,0.7549600005149841,1.054998755455017,50000.0,0.6301000118255615,1.81897234916687,10000.0,127056.72100305556,131451.49540233612,127056.72100305556,4364.946834564209,16.668809413909912,0.0 -374300,4.55185,0.61150074,,,,,,,,,,,,,, -374400,4.6464367,0.6464106,,,,,,,,,,,,,, -374500,4.3570046,0.59237474,,,,,,,,,,,,,, -374600,4.7991643,0.6442835,,,,,,,,,,,,,, -374700,4.3114915,0.5795072,,,,,,,,,,,,,, -374800,4.261528,0.59839064,,,,,,,,,,,,,, -374900,4.753312,0.6033986,,,,,,,,,,,,,, -375000,4.1033063,0.59215266,,,,,,,,,,,,,, -375100,4.386291,0.62371004,,,,,,,,,,,,,, -375200,4.792462,0.6161144,,,,,,,,,,,,,, -375300,4.1977324,0.61324126,,,,,,,,,,,,,, -375400,4.4369106,0.5860282,,,,,,,,,,,,,, -375500,4.5074496,0.57232386,,,,,,,,,,,,,, -375600,4.4425206,0.54685354,,,,,,,,,,,,,, -375700,5.131714,0.73949176,,,,,,,,,,,,,, -375711,,,0.9608178734779358,0.1465323269367218,0.754859983921051,1.0567834377288818,50000.0,0.6310000419616699,1.8201284408569336,10000.0,127566.63928079604,131977.97884607315,127566.63928079604,4381.365758657455,16.76178002357483,0.0 -375800,4.9014573,0.6933858,,,,,,,,,,,,,, -375900,4.6295185,0.6812882,,,,,,,,,,,,,, -376000,4.5389924,0.6792282,,,,,,,,,,,,,, -376100,4.6418886,0.6826813,,,,,,,,,,,,,, -376200,4.5421386,0.6372071,,,,,,,,,,,,,, -376300,5.1664667,0.61838245,,,,,,,,,,,,,, -376400,4.84036,0.6414507,,,,,,,,,,,,,, -376500,4.7764506,0.72193205,,,,,,,,,,,,,, -376600,4.9349513,0.68636084,,,,,,,,,,,,,, -376700,4.7539473,0.5880394,,,,,,,,,,,,,, -376800,4.5638514,0.64669156,,,,,,,,,,,,,, -376900,5.0245514,0.669864,,,,,,,,,,,,,, -377000,4.621695,0.5971981,,,,,,,,,,,,,, -377100,4.850654,0.6242392,,,,,,,,,,,,,, -377200,4.4534044,0.62222886,,,,,,,,,,,,,, -377213,,,0.9602598547935486,0.146878108382225,0.7548799514770508,1.0555857419967651,50000.0,0.6302000284194946,1.820207476615905,10000.0,128076.54186677931,132504.5306007862,128076.54186677931,4397.861366271973,16.862693548202515,0.0 -377300,4.831271,0.74202955,,,,,,,,,,,,,, -377400,4.41874,0.64893717,,,,,,,,,,,,,, -377500,3.8946972,0.5588509,,,,,,,,,,,,,, -377600,4.466658,0.66597384,,,,,,,,,,,,,, -377700,4.425112,0.6577681,,,,,,,,,,,,,, -377800,4.8489027,0.6518923,,,,,,,,,,,,,, -377900,4.421111,0.60784507,,,,,,,,,,,,,, -378000,4.8437843,0.5857076,,,,,,,,,,,,,, -378100,4.426536,0.57648736,,,,,,,,,,,,,, -378200,4.9678836,0.63660944,,,,,,,,,,,,,, -378300,4.69603,0.6572854,,,,,,,,,,,,,, -378400,4.5665345,0.6422206,,,,,,,,,,,,,, -378500,4.4387565,0.60671306,,,,,,,,,,,,,, -378600,4.50889,0.62003374,,,,,,,,,,,,,, -378700,4.609317,0.6215254,,,,,,,,,,,,,, -378714,,,0.961694836616516,0.1412075906991958,0.7550399899482727,1.055570125579834,50000.0,0.6297000050544739,1.818804383277893,10000.0,128585.89695334436,133031.53467178345,128585.89695334436,4414.532790660858,17.786174058914185,0.0 -378800,4.365296,0.5914196,,,,,,,,,,,,,, -378900,4.9127483,0.662222,,,,,,,,,,,,,, -379000,4.8576274,0.6693946,,,,,,,,,,,,,, -379100,5.0925703,0.58794725,,,,,,,,,,,,,, -379200,4.7210393,0.63976204,,,,,,,,,,,,,, -379300,4.705067,0.6515155,,,,,,,,,,,,,, -379400,4.20828,0.5976681,,,,,,,,,,,,,, -379500,4.102468,0.6067878,,,,,,,,,,,,,, -379600,4.5121326,0.604021,,,,,,,,,,,,,, -379700,4.836202,0.6675964,,,,,,,,,,,,,, -379800,4.6602025,0.6156497,,,,,,,,,,,,,, -379900,4.505194,0.5975546,,,,,,,,,,,,,, -380000,4.486745,0.60753894,,,,,,,,,,,,,, -380100,4.505708,0.6037258,,,,,,,,,,,,,, -380200,4.959756,0.6437274,,,,,,,,,,,,,, -380217,,,0.961316168308258,0.1460568010807037,0.7547399997711182,1.05453622341156,50000.0,0.6300000548362732,1.8188520669937127,10000.0,129095.92270565032,133558.33870458603,129095.92270565032,4431.161715507507,17.883363485336304,0.0 -380300,5.08298,0.70178556,,,,,,,,,,,,,, -380400,4.4572678,0.60079473,,,,,,,,,,,,,, -380500,4.5889287,0.5950181,,,,,,,,,,,,,, -380600,4.567939,0.6416439,,,,,,,,,,,,,, -380700,4.7661643,0.6304554,,,,,,,,,,,,,, -380800,4.740108,0.6349661,,,,,,,,,,,,,, -380900,4.7781205,0.6627704,,,,,,,,,,,,,, -381000,4.434342,0.6656627,,,,,,,,,,,,,, -381100,4.177665,0.59474534,,,,,,,,,,,,,, -381200,4.055746,0.5937839,,,,,,,,,,,,,, -381300,4.437899,0.59918034,,,,,,,,,,,,,, -381400,4.769244,0.65609026,,,,,,,,,,,,,, -381500,4.5139003,0.66276276,,,,,,,,,,,,,, -381600,4.5085397,0.6219191,,,,,,,,,,,,,, -381700,4.5556293,0.6634572,,,,,,,,,,,,,, -381720,,,0.9607381820678712,0.1449822038412094,0.7547000050544739,1.0562328100204468,50000.0,0.6305000185966492,1.8185116052627563,10000.0,129605.80332398416,134085.56870532036,129605.80332398416,4448.361759185791,17.9791898727417,0.0 -381800,4.737783,0.64078295,,,,,,,,,,,,,, -381900,4.5398684,0.5884094,,,,,,,,,,,,,, -382000,4.6938353,0.702668,,,,,,,,,,,,,, -382100,4.8257833,0.64556575,,,,,,,,,,,,,, -382200,4.1108007,0.57245106,,,,,,,,,,,,,, -382300,4.2837234,0.58131707,,,,,,,,,,,,,, -382400,4.4697914,0.6368153,,,,,,,,,,,,,, -382500,4.5719805,0.60455453,,,,,,,,,,,,,, -382600,4.4149404,0.59067047,,,,,,,,,,,,,, -382700,4.22867,0.5575515,,,,,,,,,,,,,, -382800,4.3843,0.6482727,,,,,,,,,,,,,, -382900,4.8005214,0.6527709,,,,,,,,,,,,,, -383000,5.5686574,0.73169446,,,,,,,,,,,,,, -383100,4.3862944,0.6156779,,,,,,,,,,,,,, -383200,4.420342,0.590725,,,,,,,,,,,,,, -383223,,,0.9597018361091614,0.1486869752407074,0.7551800012588501,1.0554879903793335,50000.0,0.6303000450134277,1.8197861909866333,10000.0,130115.81632757188,134612.40925621986,130115.81632757188,4465.047143936157,18.068235158920288,0.0 -383300,4.523346,0.64499867,,,,,,,,,,,,,, -383400,4.484549,0.61410016,,,,,,,,,,,,,, -383500,4.599277,0.60389173,,,,,,,,,,,,,, -383600,4.2644634,0.6196516,,,,,,,,,,,,,, -383700,4.55786,0.6413416,,,,,,,,,,,,,, -383800,5.1432104,0.58466506,,,,,,,,,,,,,, -383900,4.1752067,0.5587534,,,,,,,,,,,,,, -384000,4.4154835,0.63271314,,,,,,,,,,,,,, -384100,4.586824,0.6192377,,,,,,,,,,,,,, -384200,4.1807227,0.58179617,,,,,,,,,,,,,, -384300,5.3359346,0.68773687,,,,,,,,,,,,,, -384400,4.8422036,0.62619,,,,,,,,,,,,,, -384500,4.287075,0.6251665,,,,,,,,,,,,,, -384600,4.9070415,0.66562927,,,,,,,,,,,,,, -384700,4.180336,0.5913099,,,,,,,,,,,,,, -384726,,,0.9608976244926452,0.145926147699356,0.7549200057983398,1.055842638015747,50000.0,0.629800021648407,1.818486213684082,10000.0,130625.8651521206,135139.1086113453,130625.8651521206,4481.549957990646,18.162591457366943,0.0 -384800,4.6740427,0.7066288,,,,,,,,,,,,,, -384900,4.3099017,0.5863502,,,,,,,,,,,,,, -385000,4.5091553,0.60580564,,,,,,,,,,,,,, -385100,4.112395,0.54770124,,,,,,,,,,,,,, -385200,4.3693023,0.627542,,,,,,,,,,,,,, -385300,4.7350955,0.5497842,,,,,,,,,,,,,, -385400,3.8475146,0.49766067,,,,,,,,,,,,,, -385500,5.3419194,0.6916251,,,,,,,,,,,,,, -385600,4.6232285,0.66161376,,,,,,,,,,,,,, -385700,5.581785,0.66281617,,,,,,,,,,,,,, -385800,4.607107,0.6322326,,,,,,,,,,,,,, -385900,4.375868,0.5916978,,,,,,,,,,,,,, -386000,4.861161,0.6829039,,,,,,,,,,,,,, -386100,4.258177,0.57560855,,,,,,,,,,,,,, -386200,4.4409137,0.5926011,,,,,,,,,,,,,, -386228,,,0.9614357352256776,0.1451815962791443,0.7549999952316284,1.0552767515182495,50000.0,0.6301000118255615,1.817418098449707,10000.0,131135.7101225853,135665.71124458313,131135.7101225853,4498.146886110306,18.27062964439392,0.0 -386300,4.3393035,0.56987494,,,,,,,,,,,,,, -386400,4.917611,0.63265616,,,,,,,,,,,,,, -386500,4.2598,0.5543632,,,,,,,,,,,,,, -386600,4.15117,0.56984985,,,,,,,,,,,,,, -386700,5.158969,0.6248279,,,,,,,,,,,,,, -386800,4.280203,0.5514373,,,,,,,,,,,,,, -386900,5.2535315,0.6589161,,,,,,,,,,,,,, -387000,4.555609,0.65437156,,,,,,,,,,,,,, -387100,4.651664,0.6091403,,,,,,,,,,,,,, -387200,4.3106523,0.5458353,,,,,,,,,,,,,, -387300,4.4577527,0.57461226,,,,,,,,,,,,,, -387400,4.311168,0.66820747,,,,,,,,,,,,,, -387500,4.7010903,0.6276954,,,,,,,,,,,,,, -387600,4.2851467,0.6267825,,,,,,,,,,,,,, -387700,4.283141,0.6430389,,,,,,,,,,,,,, -387731,,,0.9614157676696776,0.1451276540756225,0.7549200057983398,1.0561199188232422,50000.0,0.6303000450134277,1.8203297853469849,10000.0,131645.83796977997,136192.5762116909,131645.83796977997,4514.734850406647,18.367199420928955,0.0 -387800,5.1705556,0.66207325,,,,,,,,,,,,,, -387900,5.0331845,0.6377428,,,,,,,,,,,,,, -388000,4.6414185,0.61849755,,,,,,,,,,,,,, -388100,4.2957754,0.55124694,,,,,,,,,,,,,, -388200,4.775702,0.6860881,,,,,,,,,,,,,, -388300,4.3323417,0.6049087,,,,,,,,,,,,,, -388400,5.0875864,0.635926,,,,,,,,,,,,,, -388500,4.2207546,0.54763824,,,,,,,,,,,,,, -388600,4.3549695,0.5409434,,,,,,,,,,,,,, -388700,4.1842813,0.56300807,,,,,,,,,,,,,, -388800,4.587977,0.6064553,,,,,,,,,,,,,, -388900,4.483296,0.6424121,,,,,,,,,,,,,, -389000,4.4624176,0.6011796,,,,,,,,,,,,,, -389100,4.1723204,0.5432225,,,,,,,,,,,,,, -389200,4.638873,0.7042184,,,,,,,,,,,,,, -389233,,,0.961156725883484,0.1456144601106643,0.7547399997711182,1.055548906326294,50000.0,0.6301000118255615,1.8190268278121948,10000.0,132155.74128174782,136719.2530815601,132155.74128174782,4531.350663900375,18.472204446792603,0.0 -389300,4.4787903,0.6086121,,,,,,,,,,,,,, -389400,4.1890936,0.58723545,,,,,,,,,,,,,, -389500,4.291104,0.5778358,,,,,,,,,,,,,, -389600,4.424328,0.63490725,,,,,,,,,,,,,, -389700,4.966231,0.6506282,,,,,,,,,,,,,, -389800,4.389244,0.6122891,,,,,,,,,,,,,, -389900,4.7104907,0.66422796,,,,,,,,,,,,,, -390000,4.672346,0.68067145,,,,,,,,,,,,,, -390100,4.8623657,0.683524,,,,,,,,,,,,,, -390200,4.495903,0.6916104,,,,,,,,,,,,,, -390300,4.773118,0.65786326,,,,,,,,,,,,,, -390400,4.5951705,0.65079176,,,,,,,,,,,,,, -390500,4.2976203,0.60539556,,,,,,,,,,,,,, -390600,4.24065,0.60920566,,,,,,,,,,,,,, -390700,4.459295,0.6653208,,,,,,,,,,,,,, -390736,,,0.9600805044174194,0.1483380496501922,0.7553600072860718,1.0564292669296265,50000.0,0.6299000382423401,1.8202205896377563,10000.0,132665.60343956947,137245.94326090813,132665.60343956947,4548.0277144908905,18.570212364196777,0.0 -390800,4.4918733,0.64394355,,,,,,,,,,,,,, -390900,4.9840846,0.6600123,,,,,,,,,,,,,, -391000,4.4176893,0.6359532,,,,,,,,,,,,,, -391100,4.551447,0.6209936,,,,,,,,,,,,,, -391200,4.8149395,0.6515449,,,,,,,,,,,,,, -391300,4.466135,0.62919635,,,,,,,,,,,,,, -391400,4.70072,0.6747707,,,,,,,,,,,,,, -391500,5.122285,0.6751299,,,,,,,,,,,,,, -391600,4.3869205,0.60274243,,,,,,,,,,,,,, -391700,4.805906,0.653245,,,,,,,,,,,,,, -391800,4.3301435,0.54817885,,,,,,,,,,,,,, -391900,4.573307,0.6218726,,,,,,,,,,,,,, -392000,4.76236,0.6749143,,,,,,,,,,,,,, -392100,4.704995,0.7175303,,,,,,,,,,,,,, -392200,4.379872,0.6232027,,,,,,,,,,,,,, -392239,,,0.9613958597183228,0.143719732761383,0.7547599673271179,1.055400013923645,50000.0,0.6301000118255615,1.8189407587051392,10000.0,133175.61843442917,137772.5371003151,133175.61843442917,4564.457427501679,18.6664834022522,0.0 -392300,4.439477,0.61953354,,,,,,,,,,,,,, -392400,4.3418384,0.6093719,,,,,,,,,,,,,, -392500,4.2749286,0.58923227,,,,,,,,,,,,,, -392600,4.629276,0.5756738,,,,,,,,,,,,,, -392700,4.798662,0.6062875,,,,,,,,,,,,,, -392800,4.9147325,0.6544983,,,,,,,,,,,,,, -392900,4.7517486,0.6392478,,,,,,,,,,,,,, -393000,4.5487647,0.6295245,,,,,,,,,,,,,, -393100,4.863619,0.633765,,,,,,,,,,,,,, -393200,4.517271,0.64092636,,,,,,,,,,,,,, -393300,4.40744,0.60821044,,,,,,,,,,,,,, -393400,4.6064186,0.6236992,,,,,,,,,,,,,, -393500,4.3331165,0.6557775,,,,,,,,,,,,,, -393600,4.957789,0.62102336,,,,,,,,,,,,,, -393700,4.73982,0.65416175,,,,,,,,,,,,,, -393743,,,0.9607780575752258,0.144915223121643,0.7547199726104736,1.0558115243911743,50000.0,0.6302000284194946,1.8172541856765747,10000.0,133685.6428039074,138299.19182229042,133685.6428039074,4580.938060998917,18.762927532196045,0.0 -393800,4.5070605,0.62909615,,,,,,,,,,,,,, -393900,4.536408,0.5727446,,,,,,,,,,,,,, -394000,4.9439626,0.6639098,,,,,,,,,,,,,, -394100,4.4258046,0.6476611,,,,,,,,,,,,,, -394200,4.554972,0.5441808,,,,,,,,,,,,,, -394300,4.3523884,0.5164172,,,,,,,,,,,,,, -394400,4.562162,0.6546354,,,,,,,,,,,,,, -394500,5.02595,0.72339547,,,,,,,,,,,,,, -394600,4.5049157,0.61135674,,,,,,,,,,,,,, -394700,4.707104,0.6518083,,,,,,,,,,,,,, -394800,4.194503,0.59257,,,,,,,,,,,,,, -394900,4.363624,0.5807573,,,,,,,,,,,,,, -395000,4.577681,0.6071147,,,,,,,,,,,,,, -395100,4.470434,0.6171196,,,,,,,,,,,,,, -395200,4.110758,0.56284857,,,,,,,,,,,,,, -395246,,,0.960180163383484,0.1485379189252853,0.7544999718666077,1.0541654825210571,50000.0,0.6305000185966492,1.8179938793182373,10000.0,134195.6141886711,138825.90058135986,134195.6141886711,4597.527137756348,18.85831356048584,0.0 -395300,4.328654,0.61491984,,,,,,,,,,,,,, -395400,4.504678,0.6297559,,,,,,,,,,,,,, -395500,4.834357,0.6103907,,,,,,,,,,,,,, -395600,4.1992993,0.5503858,,,,,,,,,,,,,, -395700,4.1018248,0.59670115,,,,,,,,,,,,,, -395800,4.905062,0.71137816,,,,,,,,,,,,,, -395900,4.8468504,0.62677014,,,,,,,,,,,,,, -396000,4.1326513,0.6105446,,,,,,,,,,,,,, -396100,5.003365,0.6989718,,,,,,,,,,,,,, -396200,4.547434,0.64155304,,,,,,,,,,,,,, -396300,5.0409865,0.6757462,,,,,,,,,,,,,, -396400,4.5919495,0.6206619,,,,,,,,,,,,,, -396500,4.3579617,0.59147245,,,,,,,,,,,,,, -396600,4.5965776,0.62285376,,,,,,,,,,,,,, -396700,4.650771,0.63114214,,,,,,,,,,,,,, -396749,,,0.9606186151504515,0.1472035646438598,0.7551400065422058,1.054998517036438,50000.0,0.6306000351905823,1.8182374238967896,10000.0,134705.46764349937,139352.6130232811,134705.46764349937,4614.233896970749,18.956045627594,0.0 -396800,4.8769717,0.62968653,,,,,,,,,,,,,, -396900,4.782062,0.6599225,,,,,,,,,,,,,, -397000,4.603403,0.6432664,,,,,,,,,,,,,, -397100,4.608016,0.5835849,,,,,,,,,,,,,, -397200,4.662546,0.6698968,,,,,,,,,,,,,, -397300,4.435478,0.5689547,,,,,,,,,,,,,, -397400,4.185402,0.6413139,,,,,,,,,,,,,, -397500,4.1198215,0.5880487,,,,,,,,,,,,,, -397600,4.5388002,0.61626047,,,,,,,,,,,,,, -397700,4.523953,0.68875355,,,,,,,,,,,,,, -397800,4.323985,0.52087843,,,,,,,,,,,,,, -397900,4.9426675,0.617274,,,,,,,,,,,,,, -398000,4.479665,0.58700466,,,,,,,,,,,,,, -398100,4.916225,0.70440733,,,,,,,,,,,,,, -398200,4.3302355,0.61869985,,,,,,,,,,,,,, -398252,,,0.9606983065605164,0.1478620916604995,0.754859983921051,1.055959701538086,50000.0,0.6303000450134277,1.8208487033844,10000.0,135215.65347242355,139879.5144290924,135215.65347242355,4630.798380851746,19.05357837677002,0.0 -398300,4.791167,0.6495583,,,,,,,,,,,,,, -398400,4.6364403,0.6151981,,,,,,,,,,,,,, -398500,4.7057967,0.64344215,,,,,,,,,,,,,, -398600,4.2778378,0.54976153,,,,,,,,,,,,,, -398700,5.3609324,0.638547,,,,,,,,,,,,,, -398800,4.182344,0.60093886,,,,,,,,,,,,,, -398900,4.8094683,0.6607163,,,,,,,,,,,,,, -399000,4.5192304,0.55379164,,,,,,,,,,,,,, -399100,4.4150376,0.5600635,,,,,,,,,,,,,, -399200,4.330627,0.63585687,,,,,,,,,,,,,, -399300,4.5054226,0.65848696,,,,,,,,,,,,,, -399400,4.59756,0.70562756,,,,,,,,,,,,,, -399500,4.5895567,0.63579106,,,,,,,,,,,,,, -399600,4.2639213,0.59893394,,,,,,,,,,,,,, -399700,4.395482,0.6197188,,,,,,,,,,,,,, -399755,,,0.9601004123687744,0.1449165344238281,0.7548799514770508,1.054701566696167,50000.0,0.6297000050544739,1.8177335262298584,10000.0,135725.60052633286,140406.1835153103,135725.60052633286,4647.372666358948,19.147751808166504,0.0 -399800,4.7647862,0.6486101,,,,,,,,,,,,,, -399900,4.6746774,0.6294525,,,,,,,,,,,,,, -400000,4.668047,0.5946984,,,,,,,,,,,,,, -400100,4.260595,0.62216794,,,,,,,,,,,,,, -400200,4.6450677,0.608816,,,,,,,,,,,,,, -400300,4.6890054,0.6192323,,,,,,,,,,,,,, -400400,4.5705385,0.58818895,,,,,,,,,,,,,, -400500,4.305155,0.5784359,,,,,,,,,,,,,, -400600,4.7920256,0.59203595,,,,,,,,,,,,,, -400700,4.6657624,0.6947073,,,,,,,,,,,,,, -400800,4.8999434,0.6386578,,,,,,,,,,,,,, -400900,4.905946,0.68752545,,,,,,,,,,,,,, -401000,4.42799,0.62107724,,,,,,,,,,,,,, -401100,4.589073,0.6189111,,,,,,,,,,,,,, -401200,5.084807,0.58774215,,,,,,,,,,,,,, -401258,,,0.9598413109779358,0.1499073654413223,0.7547599673271179,1.05543315410614,50000.0,0.6304000020027161,1.8172955513000488,10000.0,136235.73774003985,140933.08319306374,136235.73774003985,4663.987517356873,19.24118137359619,0.0 -401300,4.440237,0.60029656,,,,,,,,,,,,,, -401400,4.8240056,0.6701077,,,,,,,,,,,,,, -401500,4.284062,0.6094727,,,,,,,,,,,,,, -401600,4.409288,0.6026721,,,,,,,,,,,,,, -401700,4.459065,0.65561587,,,,,,,,,,,,,, -401800,4.5802736,0.60584664,,,,,,,,,,,,,, -401900,4.5037413,0.64499986,,,,,,,,,,,,,, -402000,4.721362,0.6260439,,,,,,,,,,,,,, -402100,4.3999963,0.5900296,,,,,,,,,,,,,, -402200,4.426968,0.5910436,,,,,,,,,,,,,, -402300,4.4180584,0.64611393,,,,,,,,,,,,,, -402400,4.7498937,0.62638295,,,,,,,,,,,,,, -402500,4.196654,0.6149591,,,,,,,,,,,,,, -402600,4.6127205,0.6434136,,,,,,,,,,,,,, -402700,4.738899,0.6127243,,,,,,,,,,,,,, -402760,,,0.9602598547935486,0.1475264430046081,0.7549600005149841,1.056088924407959,50000.0,0.6304000020027161,1.8198953866958616,10000.0,136745.5668039322,141459.62825584412,136745.5668039322,4680.529687643051,19.362505674362183,0.0 -402800,5.1291466,0.72826433,,,,,,,,,,,,,, -402900,4.40396,0.59153867,,,,,,,,,,,,,, -403000,4.5114717,0.5788421,,,,,,,,,,,,,, -403100,4.507661,0.6546891,,,,,,,,,,,,,, -403200,5.0831103,0.7027173,,,,,,,,,,,,,, -403300,4.8149543,0.6398427,,,,,,,,,,,,,, -403400,5.134887,0.6315458,,,,,,,,,,,,,, -403500,4.669174,0.64529085,,,,,,,,,,,,,, -403600,4.7746506,0.7158358,,,,,,,,,,,,,, -403700,4.4686294,0.647108,,,,,,,,,,,,,, -403800,4.7559366,0.65343416,,,,,,,,,,,,,, -403900,4.1269574,0.5770006,,,,,,,,,,,,,, -404000,4.591041,0.6132567,,,,,,,,,,,,,, -404100,4.83619,0.63315135,,,,,,,,,,,,,, -404200,4.5318723,0.62510455,,,,,,,,,,,,,, -404263,,,0.9608378410339355,0.1480744481086731,0.7546600103378296,1.0558913946151731,50000.0,0.6304000020027161,1.8172743320465088,10000.0,137255.6495103836,141986.512663126,137255.6495103836,4697.173872232437,19.46660351753235,0.0 -404300,4.743274,0.6559377,,,,,,,,,,,,,, -404400,4.8898745,0.56931764,,,,,,,,,,,,,, -404500,4.508694,0.5976193,,,,,,,,,,,,,, -404600,4.4487605,0.56758267,,,,,,,,,,,,,, -404700,4.6111865,0.6440641,,,,,,,,,,,,,, -404800,4.294623,0.5579978,,,,,,,,,,,,,, -404900,4.771697,0.66864467,,,,,,,,,,,,,, -405000,4.884231,0.6560734,,,,,,,,,,,,,, -405100,4.542093,0.6374031,,,,,,,,,,,,,, -405200,4.2259197,0.5671317,,,,,,,,,,,,,, -405300,4.7322454,0.6357765,,,,,,,,,,,,,, -405400,4.4207478,0.6647358,,,,,,,,,,,,,, -405500,4.557047,0.67814493,,,,,,,,,,,,,, -405600,4.23245,0.6484048,,,,,,,,,,,,,, -405700,4.624493,0.6103757,,,,,,,,,,,,,, -405766,,,0.959781527519226,0.1490054130554199,0.754539966583252,1.0548299551010132,50000.0,0.6300000548362732,1.818644642829895,10000.0,137765.65123844147,142513.2674689293,137765.65123844147,4713.778460979462,19.561405658721924,0.0 -405800,4.542659,0.6398065,,,,,,,,,,,,,, -405900,4.3354535,0.6070208,,,,,,,,,,,,,, -406000,5.292228,0.7055322,,,,,,,,,,,,,, -406100,4.504239,0.6549109,,,,,,,,,,,,,, -406200,4.490591,0.6059793,,,,,,,,,,,,,, -406300,4.362647,0.64148504,,,,,,,,,,,,,, -406400,5.1886926,0.69038033,,,,,,,,,,,,,, -406500,4.3012576,0.55845207,,,,,,,,,,,,,, -406600,4.704711,0.6148082,,,,,,,,,,,,,, -406700,4.4556017,0.6067584,,,,,,,,,,,,,, -406800,4.4860764,0.6066858,,,,,,,,,,,,,, -406900,4.1933427,0.5665724,,,,,,,,,,,,,, -407000,4.4028707,0.59595406,,,,,,,,,,,,,, -407100,4.452267,0.6155657,,,,,,,,,,,,,, -407200,4.475986,0.5859759,,,,,,,,,,,,,, -407269,,,0.9610570669174194,0.1444486230611801,0.7546199560165405,1.055034637451172,50000.0,0.6303000450134277,1.81747841835022,10000.0,138275.66917204857,143040.0849058628,138275.66917204857,4730.435454368591,19.65113377571106,0.0 -407300,4.4668875,0.6234943,,,,,,,,,,,,,, -407400,4.912968,0.6298994,,,,,,,,,,,,,, -407500,4.527435,0.58006597,,,,,,,,,,,,,, -407600,4.4372525,0.66076374,,,,,,,,,,,,,, -407700,4.014693,0.54908144,,,,,,,,,,,,,, -407800,4.6556115,0.63686115,,,,,,,,,,,,,, -407900,4.359073,0.6223359,,,,,,,,,,,,,, -408000,4.607272,0.70591605,,,,,,,,,,,,,, -408100,4.5894623,0.58864176,,,,,,,,,,,,,, -408200,4.165447,0.5897653,,,,,,,,,,,,,, -408300,4.5299926,0.5986781,,,,,,,,,,,,,, -408400,4.678434,0.6786042,,,,,,,,,,,,,, -408500,4.28209,0.67057335,,,,,,,,,,,,,, -408600,4.598856,0.6444057,,,,,,,,,,,,,, -408700,4.568553,0.55886114,,,,,,,,,,,,,, -408772,,,0.9592633843421936,0.1505450159311294,0.7552399635314941,1.0552719831466677,50000.0,0.6304000020027161,1.8169559240341189,10000.0,138785.59047579765,143566.8327577114,138785.59047579765,4747.116379022598,19.74397087097168,0.0 -408800,4.7738533,0.71752036,,,,,,,,,,,,,, -408900,4.823192,0.6552329,,,,,,,,,,,,,, -409000,4.386639,0.5960173,,,,,,,,,,,,,, -409100,4.852063,0.63074577,,,,,,,,,,,,,, -409200,4.420924,0.5986854,,,,,,,,,,,,,, -409300,4.142236,0.6578056,,,,,,,,,,,,,, -409400,4.380234,0.65836495,,,,,,,,,,,,,, -409500,4.6245494,0.6543461,,,,,,,,,,,,,, -409600,4.447149,0.6505248,,,,,,,,,,,,,, -409700,4.1002007,0.5366409,,,,,,,,,,,,,, -409800,4.530748,0.5602963,,,,,,,,,,,,,, -409900,5.02154,0.63334954,,,,,,,,,,,,,, -410000,4.570986,0.57453316,,,,,,,,,,,,,, -410100,4.1423364,0.6220153,,,,,,,,,,,,,, -410200,4.8728757,0.6824395,,,,,,,,,,,,,, -410274,,,0.9607381820678712,0.1451594531536102,0.7550999522209167,1.0560662746429443,50000.0,0.6297000050544739,1.8201110363006592,10000.0,139295.51828718185,144093.49433231354,139295.51828718185,4763.699547767639,19.84065055847168,0.0 -410300,4.214966,0.56626284,,,,,,,,,,,,,, -410400,4.668247,0.5766805,,,,,,,,,,,,,, -410500,4.4014196,0.5381469,,,,,,,,,,,,,, -410600,4.2857895,0.5766774,,,,,,,,,,,,,, -410700,4.269411,0.6159927,,,,,,,,,,,,,, -410800,4.65785,0.5811993,,,,,,,,,,,,,, -410900,4.318173,0.62184507,,,,,,,,,,,,,, -411000,4.493006,0.62259746,,,,,,,,,,,,,, -411100,4.4249187,0.6120088,,,,,,,,,,,,,, -411200,4.8076034,0.6891853,,,,,,,,,,,,,, -411300,4.513813,0.64508843,,,,,,,,,,,,,, -411400,4.6890097,0.6616309,,,,,,,,,,,,,, -411500,4.4215484,0.6023978,,,,,,,,,,,,,, -411600,4.4765344,0.58583885,,,,,,,,,,,,,, -411700,4.765865,0.5998842,,,,,,,,,,,,,, -411778,,,0.9607381820678712,0.1458351612091064,0.7547399997711182,1.0559715032577517,50000.0,0.6297000050544739,1.8194156885147093,10000.0,139805.68612885475,144620.41824531555,139805.68612885475,4780.302748441696,19.94161033630371,0.0 -411800,4.395965,0.5964378,,,,,,,,,,,,,, -411900,5.0669127,0.6380829,,,,,,,,,,,,,, -412000,4.5505486,0.68265,,,,,,,,,,,,,, -412100,4.3700447,0.65900403,,,,,,,,,,,,,, -412200,4.2818904,0.51627976,,,,,,,,,,,,,, -412300,4.315255,0.6765652,,,,,,,,,,,,,, -412400,4.5506186,0.6221835,,,,,,,,,,,,,, -412500,5.195231,0.6977742,,,,,,,,,,,,,, -412600,4.598078,0.644757,,,,,,,,,,,,,, -412700,4.4643,0.55669284,,,,,,,,,,,,,, -412800,4.606499,0.58806515,,,,,,,,,,,,,, -412900,4.2969646,0.6613922,,,,,,,,,,,,,, -413000,4.0550413,0.6002703,,,,,,,,,,,,,, -413100,4.6147556,0.65833324,,,,,,,,,,,,,, -413200,4.624172,0.7116013,,,,,,,,,,,,,, -413281,,,0.9610171914100648,0.1474115848541259,0.7545599937438965,1.05522882938385,50000.0,0.6307000517845154,1.818820595741272,10000.0,140315.84532666206,145147.35459661484,140315.84532666206,4796.926322221756,20.04148244857788,0.0 -413300,4.272009,0.5552224,,,,,,,,,,,,,, -413400,4.405102,0.59311956,,,,,,,,,,,,,, -413500,4.5879607,0.6063328,,,,,,,,,,,,,, -413600,4.695225,0.59195256,,,,,,,,,,,,,, -413700,4.880111,0.67017865,,,,,,,,,,,,,, -413800,4.411768,0.62237537,,,,,,,,,,,,,, -413900,4.8974695,0.6911162,,,,,,,,,,,,,, -414000,4.27149,0.60420847,,,,,,,,,,,,,, -414100,4.51053,0.65515465,,,,,,,,,,,,,, -414200,5.0366406,0.6740938,,,,,,,,,,,,,, -414300,4.783891,0.68367606,,,,,,,,,,,,,, -414400,4.2142987,0.61966354,,,,,,,,,,,,,, -414500,4.8624744,0.649958,,,,,,,,,,,,,, -414600,4.366042,0.6093838,,,,,,,,,,,,,, -414700,4.42771,0.6363431,,,,,,,,,,,,,, -414784,,,0.9598612785339355,0.1473547369241714,0.7548799514770508,1.0550575256347656,50000.0,0.6302000284194946,1.818278431892395,10000.0,140825.76637721062,145674.05481624603,140825.76637721062,4813.552445888519,20.14136648178101,0.0 -414800,4.505484,0.58499795,,,,,,,,,,,,,, -414900,5.3077865,0.6933477,,,,,,,,,,,,,, -415000,4.258074,0.6529496,,,,,,,,,,,,,, -415100,4.5923934,0.59866166,,,,,,,,,,,,,, -415200,4.540939,0.6041219,,,,,,,,,,,,,, -415300,4.3401093,0.65019435,,,,,,,,,,,,,, -415400,4.52368,0.6024397,,,,,,,,,,,,,, -415500,4.3033123,0.58225846,,,,,,,,,,,,,, -415600,4.6901746,0.6482622,,,,,,,,,,,,,, -415700,4.320749,0.5930275,,,,,,,,,,,,,, -415800,4.6753345,0.63961834,,,,,,,,,,,,,, -415900,4.3937035,0.5572767,,,,,,,,,,,,,, -416000,4.4431486,0.6086598,,,,,,,,,,,,,, -416100,5.333604,0.624289,,,,,,,,,,,,,, -416200,4.4062304,0.6463248,,,,,,,,,,,,,, -416287,,,0.9617745280265808,0.1414497345685959,0.7548199892044067,1.056381106376648,50000.0,0.629800021648407,1.8195641040802,10000.0,141335.79673457146,146200.8248922825,141335.79673457146,4830.139992237091,20.239493131637573,0.0 -416300,4.681917,0.58918864,,,,,,,,,,,,,, -416400,4.6635127,0.70438737,,,,,,,,,,,,,, -416500,4.4654374,0.5645413,,,,,,,,,,,,,, -416600,4.682072,0.59410983,,,,,,,,,,,,,, -416700,4.3125286,0.6170028,,,,,,,,,,,,,, -416800,5.1179385,0.64909023,,,,,,,,,,,,,, -416900,4.6926284,0.6225113,,,,,,,,,,,,,, -417000,4.6560845,0.57204443,,,,,,,,,,,,,, -417100,4.72649,0.62875235,,,,,,,,,,,,,, -417200,5.0298767,0.64530766,,,,,,,,,,,,,, -417300,4.9880247,0.5775571,,,,,,,,,,,,,, -417400,4.332436,0.59421426,,,,,,,,,,,,,, -417500,4.798089,0.64838064,,,,,,,,,,,,,, -417600,4.6013575,0.666564,,,,,,,,,,,,,, -417700,5.086561,0.69673043,,,,,,,,,,,,,, -417791,,,0.9614556431770324,0.1459291577339172,0.7548199892044067,1.0547126531600952,50000.0,0.6301000118255615,1.818116307258606,10000.0,141845.92551875114,146727.69525814056,141845.92551875114,4846.724861383438,20.34183478355408,0.0 -417800,4.791294,0.5680728,,,,,,,,,,,,,, -417900,4.302075,0.60176456,,,,,,,,,,,,,, -418000,4.289588,0.6055653,,,,,,,,,,,,,, -418100,4.9400496,0.67420894,,,,,,,,,,,,,, -418200,4.4973087,0.56374377,,,,,,,,,,,,,, -418300,4.797788,0.6850858,,,,,,,,,,,,,, -418400,4.724472,0.5833502,,,,,,,,,,,,,, -418500,4.575893,0.65243703,,,,,,,,,,,,,, -418600,4.6962514,0.5993037,,,,,,,,,,,,,, -418700,4.8163466,0.6162849,,,,,,,,,,,,,, -418800,4.4040246,0.6323582,,,,,,,,,,,,,, -418900,5.091773,0.63686454,,,,,,,,,,,,,, -419000,4.312394,0.57772064,,,,,,,,,,,,,, -419100,4.917534,0.6542861,,,,,,,,,,,,,, -419200,5.3525243,0.7001068,,,,,,,,,,,,,, -419294,,,0.9610570669174194,0.1447632610797882,0.7547599673271179,1.0562463998794556,50000.0,0.6306000351905823,1.8181205987930296,10000.0,142356.03491735458,147254.5678577423,142356.03491735458,4863.333392381668,20.443277597427368,0.0 -419300,4.822106,0.62911886,,,,,,,,,,,,,, -419400,5.1342764,0.6609164,,,,,,,,,,,,,, -419500,4.6645184,0.6228076,,,,,,,,,,,,,, -419600,4.375908,0.64649993,,,,,,,,,,,,,, -419700,4.573884,0.6110643,,,,,,,,,,,,,, -419800,4.737335,0.62782574,,,,,,,,,,,,,, -419900,4.3250813,0.5926982,,,,,,,,,,,,,, -420000,4.7052464,0.6118391,,,,,,,,,,,,,, -420100,4.629211,0.6255113,,,,,,,,,,,,,, -420200,4.3727136,0.691662,,,,,,,,,,,,,, -420300,4.8476987,0.521826,,,,,,,,,,,,,, -420400,4.6177225,0.638364,,,,,,,,,,,,,, -420500,5.033563,0.66095006,,,,,,,,,,,,,, -420600,4.5601087,0.62107086,,,,,,,,,,,,,, -420700,4.633074,0.60418415,,,,,,,,,,,,,, -420797,,,0.960180163383484,0.1462839394807815,0.7548999786376953,1.056235671043396,50000.0,0.6304000020027161,1.8204907178878784,10000.0,142865.97271060944,147781.92643356323,142865.97271060944,4880.598002910614,20.54567670822144,0.0 -420800,4.57499,0.6092192,,,,,,,,,,,,,, -420900,5.291666,0.6140315,,,,,,,,,,,,,, -421000,4.300329,0.6366812,,,,,,,,,,,,,, -421100,4.563143,0.60067356,,,,,,,,,,,,,, -421200,4.372463,0.60821974,,,,,,,,,,,,,, -421300,4.2852235,0.54498255,,,,,,,,,,,,,, -421400,4.5392256,0.596602,,,,,,,,,,,,,, -421500,4.6121645,0.62709737,,,,,,,,,,,,,, -421600,4.4851427,0.6216319,,,,,,,,,,,,,, -421700,4.7140126,0.60622764,,,,,,,,,,,,,, -421800,4.6049848,0.61267424,,,,,,,,,,,,,, -421900,4.6843224,0.6672504,,,,,,,,,,,,,, -422000,4.510712,0.7157436,,,,,,,,,,,,,, -422100,4.8615823,0.6607206,,,,,,,,,,,,,, -422200,5.0232286,0.55774236,,,,,,,,,,,,,, -422299,,,0.9609175324440002,0.146629050374031,0.7549200057983398,1.054916501045227,50000.0,0.6300000548362732,1.817649483680725,10000.0,143375.83768200874,148308.54959082603,143375.83768200874,4897.206827640533,20.6411075592041,0.0 -422300,4.851453,0.6694138,,,,,,,,,,,,,, -422400,5.106676,0.70620394,,,,,,,,,,,,,, -422500,4.3812423,0.658754,,,,,,,,,,,,,, -422600,4.767599,0.69037986,,,,,,,,,,,,,, -422700,4.609969,0.6301775,,,,,,,,,,,,,, -422800,5.0147533,0.6147268,,,,,,,,,,,,,, -422900,4.4004807,0.6025856,,,,,,,,,,,,,, -423000,5.4254346,0.63215,,,,,,,,,,,,,, -423100,5.078042,0.66868347,,,,,,,,,,,,,, -423200,4.2604823,0.5932317,,,,,,,,,,,,,, -423300,4.2586627,0.598313,,,,,,,,,,,,,, -423400,4.8950086,0.6727645,,,,,,,,,,,,,, -423500,4.478713,0.70667356,,,,,,,,,,,,,, -423600,4.796433,0.66583425,,,,,,,,,,,,,, -423700,4.5644245,0.6309091,,,,,,,,,,,,,, -423800,5.047667,0.6749724,,,,,,,,,,,,,, -423801,,,0.9602997303009032,0.1462405920028686,0.7544599771499634,1.0567693710327148,50000.0,0.6299000382423401,1.819965362548828,10000.0,143885.76449489594,148835.2104690075,143885.76449489594,4913.787359714508,20.7419707775116,0.0 -423900,4.160586,0.5682366,,,,,,,,,,,,,, -424000,4.327621,0.5934616,,,,,,,,,,,,,, -424100,4.534139,0.62570953,,,,,,,,,,,,,, -424200,4.5100274,0.63651025,,,,,,,,,,,,,, -424300,5.1288753,0.6113233,,,,,,,,,,,,,, -424400,4.6004076,0.58774257,,,,,,,,,,,,,, -424500,4.4885473,0.599622,,,,,,,,,,,,,, -424600,4.4019837,0.662965,,,,,,,,,,,,,, -424700,4.7391725,0.6687939,,,,,,,,,,,,,, -424800,4.534512,0.65396607,,,,,,,,,,,,,, -424900,4.3849735,0.61588246,,,,,,,,,,,,,, -425000,4.1279373,0.50600904,,,,,,,,,,,,,, -425100,4.191042,0.58534855,,,,,,,,,,,,,, -425200,4.0813613,0.582466,,,,,,,,,,,,,, -425300,4.243623,0.6112813,,,,,,,,,,,,,, -425303,,,0.9620934128761292,0.1448357999324798,0.7547000050544739,1.054429292678833,50000.0,0.631100058555603,1.8174110651016235,10000.0,144395.63716816902,149361.91491818428,144395.63716816902,4930.458664178848,20.84931969642639,0.0 -425400,4.3517766,0.5943647,,,,,,,,,,,,,, -425500,4.2290606,0.5819374,,,,,,,,,,,,,, -425600,4.5846777,0.62546206,,,,,,,,,,,,,, -425700,5.1338096,0.6191919,,,,,,,,,,,,,, -425800,4.1939287,0.5787472,,,,,,,,,,,,,, -425900,4.824487,0.58970237,,,,,,,,,,,,,, -426000,4.4947786,0.6552876,,,,,,,,,,,,,, -426100,4.5987744,0.5716115,,,,,,,,,,,,,, -426200,4.496586,0.64743114,,,,,,,,,,,,,, -426300,4.880254,0.7244374,,,,,,,,,,,,,, -426400,4.1256,0.6162796,,,,,,,,,,,,,, -426500,4.6580076,0.64548415,,,,,,,,,,,,,, -426600,4.9633627,0.65083915,,,,,,,,,,,,,, -426700,4.799276,0.6346073,,,,,,,,,,,,,, -426800,4.7912273,0.60198367,,,,,,,,,,,,,, -426806,,,0.9610570669174194,0.1460616737604141,0.7549600005149841,1.0565731525421145,50000.0,0.6296000480651855,1.8200547695159912,10000.0,144905.5118675232,149888.57080936432,144905.5118675232,4947.074217557907,20.96125626564026,0.0 -426900,4.594958,0.5354053,,,,,,,,,,,,,, -427000,4.393679,0.60511714,,,,,,,,,,,,,, -427100,4.727767,0.65506715,,,,,,,,,,,,,, -427200,4.478633,0.5648587,,,,,,,,,,,,,, -427300,4.2903705,0.5378311,,,,,,,,,,,,,, -427400,4.5212297,0.6083919,,,,,,,,,,,,,, -427500,4.610951,0.60305935,,,,,,,,,,,,,, -427600,4.329818,0.5664872,,,,,,,,,,,,,, -427700,4.4089994,0.5974712,,,,,,,,,,,,,, -427800,4.4637117,0.6105904,,,,,,,,,,,,,, -427900,4.5330877,0.5974959,,,,,,,,,,,,,, -428000,4.787342,0.6391379,,,,,,,,,,,,,, -428100,4.8931317,0.6133363,,,,,,,,,,,,,, -428200,4.7844195,0.621751,,,,,,,,,,,,,, -428300,5.1545415,0.67217195,,,,,,,,,,,,,, -428308,,,0.9610570669174194,0.1456134170293808,0.7544999718666077,1.0552159547805786,50000.0,0.629800021648407,1.8188358545303345,10000.0,145415.37442803383,150415.18838214874,145415.37442803383,4963.67450594902,21.06189441680908,0.0 -428400,4.5071287,0.5957287,,,,,,,,,,,,,, -428500,4.6873116,0.6617113,,,,,,,,,,,,,, -428600,4.707422,0.6452279,,,,,,,,,,,,,, -428700,4.482963,0.58547574,,,,,,,,,,,,,, -428800,4.7954965,0.6836005,,,,,,,,,,,,,, -428900,4.723218,0.64014053,,,,,,,,,,,,,, -429000,4.5226145,0.6155423,,,,,,,,,,,,,, -429100,4.7024374,0.57001495,,,,,,,,,,,,,, -429200,4.46095,0.6202663,,,,,,,,,,,,,, -429300,5.073363,0.6460463,,,,,,,,,,,,,, -429400,4.526003,0.5364038,,,,,,,,,,,,,, -429500,4.6255283,0.6552201,,,,,,,,,,,,,, -429600,3.9289074,0.55895674,,,,,,,,,,,,,, -429700,4.6072483,0.68759274,,,,,,,,,,,,,, -429800,4.3294606,0.6285906,,,,,,,,,,,,,, -429811,,,0.9610570669174194,0.146034225821495,0.7547599673271179,1.054571509361267,50000.0,0.6301000118255615,1.8180526494979856,10000.0,145925.31875014305,150941.80598330498,145925.31875014305,4980.192862272263,21.162484407424927,0.0 -429900,4.726433,0.6705216,,,,,,,,,,,,,, -430000,4.6690636,0.6849786,,,,,,,,,,,,,, -430100,4.5283146,0.62811744,,,,,,,,,,,,,, -430200,4.1857767,0.56413525,,,,,,,,,,,,,, -430300,4.341467,0.5558767,,,,,,,,,,,,,, -430400,4.520668,0.69131845,,,,,,,,,,,,,, -430500,5.246302,0.6863329,,,,,,,,,,,,,, -430600,5.360831,0.6672297,,,,,,,,,,,,,, -430700,4.1412625,0.6277316,,,,,,,,,,,,,, -430800,4.837698,0.6754396,,,,,,,,,,,,,, -430900,4.465784,0.6286477,,,,,,,,,,,,,, -431000,4.13895,0.56350684,,,,,,,,,,,,,, -431100,4.8023486,0.686344,,,,,,,,,,,,,, -431200,4.6847796,0.5700716,,,,,,,,,,,,,, -431300,4.4592443,0.5934726,,,,,,,,,,,,,, -431313,,,0.9602598547935486,0.1460363268852234,0.7546799778938293,1.0550318956375122,50000.0,0.6306000351905823,1.818410396575928,10000.0,146435.1837067604,151468.40133213997,146435.1837067604,4996.765533208847,21.2678291797638,0.0 -431400,4.5053535,0.63419336,,,,,,,,,,,,,, -431500,4.099433,0.5943339,,,,,,,,,,,,,, -431600,4.338674,0.6176876,,,,,,,,,,,,,, -431700,4.4757676,0.60807097,,,,,,,,,,,,,, -431800,4.3012304,0.58458734,,,,,,,,,,,,,, -431900,4.576783,0.65512985,,,,,,,,,,,,,, -432000,4.8755703,0.67684835,,,,,,,,,,,,,, -432100,4.524021,0.60584843,,,,,,,,,,,,,, -432200,4.448992,0.5466422,,,,,,,,,,,,,, -432300,4.470452,0.618827,,,,,,,,,,,,,, -432400,4.709573,0.696954,,,,,,,,,,,,,, -432500,4.0945463,0.55601513,,,,,,,,,,,,,, -432600,5.0342803,0.5969807,,,,,,,,,,,,,, -432700,4.44865,0.62210625,,,,,,,,,,,,,, -432800,4.962776,0.6563207,,,,,,,,,,,,,, -432816,,,0.9611766338348388,0.1457973569631576,0.7546600103378296,1.0549418926239014,50000.0,0.6296000480651855,1.8184632062911987,10000.0,146945.2528398037,151995.17208194733,146945.2528398037,5013.30003285408,21.38243818283081,0.0 -432900,4.9698176,0.6764606,,,,,,,,,,,,,, -433000,4.6500335,0.6036052,,,,,,,,,,,,,, -433100,4.6050053,0.65663517,,,,,,,,,,,,,, -433200,4.352051,0.61518574,,,,,,,,,,,,,, -433300,4.4541764,0.65704185,,,,,,,,,,,,,, -433400,4.7275467,0.6587058,,,,,,,,,,,,,, -433500,4.4323387,0.5953671,,,,,,,,,,,,,, -433600,4.8223085,0.65001094,,,,,,,,,,,,,, -433700,4.9911046,0.5932148,,,,,,,,,,,,,, -433800,4.9816504,0.68788713,,,,,,,,,,,,,, -433900,4.776163,0.7011697,,,,,,,,,,,,,, -434000,4.4862123,0.6557306,,,,,,,,,,,,,, -434100,4.492611,0.5410333,,,,,,,,,,,,,, -434200,4.7662067,0.6428843,,,,,,,,,,,,,, -434300,4.7311697,0.6590391,,,,,,,,,,,,,, -434318,,,0.9610171914100648,0.1467979103326797,0.7549999952316284,1.0560564994812012,50000.0,0.6299000382423401,1.8182473182678225,10000.0,147455.2933781147,152521.94089126587,147455.2933781147,5029.848837137222,21.50903558731079,0.0 -434400,4.232536,0.60872716,,,,,,,,,,,,,, -434500,4.656028,0.58601385,,,,,,,,,,,,,, -434600,4.384393,0.681687,,,,,,,,,,,,,, -434700,4.5350804,0.57048404,,,,,,,,,,,,,, -434800,4.5445824,0.65612185,,,,,,,,,,,,,, -434900,5.6895576,0.69484806,,,,,,,,,,,,,, -435000,4.7027035,0.63688254,,,,,,,,,,,,,, -435100,4.2516055,0.5544906,,,,,,,,,,,,,, -435200,5.328228,0.65635,,,,,,,,,,,,,, -435300,4.5974736,0.6536658,,,,,,,,,,,,,, -435400,4.192618,0.6343149,,,,,,,,,,,,,, -435500,3.8113964,0.5597193,,,,,,,,,,,,,, -435600,4.909619,0.6669521,,,,,,,,,,,,,, -435700,4.9422483,0.55848086,,,,,,,,,,,,,, -435800,5.2844577,0.73045045,,,,,,,,,,,,,, -435821,,,0.9590640664100648,0.1513096988201141,0.7549799680709839,1.0555171966552734,50000.0,0.6305000185966492,1.819121599197388,10000.0,147965.1800661087,153048.53157043457,147965.1800661087,5046.399070739746,21.609692573547363,0.0 -435900,4.376331,0.61842334,,,,,,,,,,,,,, -436000,4.3790708,0.6670122,,,,,,,,,,,,,, -436100,4.856941,0.6534958,,,,,,,,,,,,,, -436200,4.713266,0.66534793,,,,,,,,,,,,,, -436300,4.6172056,0.6465684,,,,,,,,,,,,,, -436400,4.589523,0.6359788,,,,,,,,,,,,,, -436500,4.56366,0.6324156,,,,,,,,,,,,,, -436600,4.9459963,0.6830524,,,,,,,,,,,,,, -436700,4.285344,0.60950184,,,,,,,,,,,,,, -436800,4.4854264,0.5689182,,,,,,,,,,,,,, -436900,4.41923,0.62612057,,,,,,,,,,,,,, -437000,4.833224,0.6547521,,,,,,,,,,,,,, -437100,4.387972,0.60836166,,,,,,,,,,,,,, -437200,4.407674,0.5950969,,,,,,,,,,,,,, -437300,5.0945616,0.72745925,,,,,,,,,,,,,, -437324,,,0.9607979655265808,0.1454037278890609,0.7547799944877625,1.0546343326568604,50000.0,0.6303000450134277,1.817741870880127,10000.0,148475.27742362022,153575.33093428612,148475.27742362022,5062.943537712097,21.714293003082275,0.0 -437400,4.249219,0.6044761,,,,,,,,,,,,,, -437500,4.318391,0.60016155,,,,,,,,,,,,,, -437600,4.5877123,0.6099976,,,,,,,,,,,,,, -437700,4.7493815,0.6157402,,,,,,,,,,,,,, -437800,4.6543565,0.6502621,,,,,,,,,,,,,, -437900,4.4273915,0.6315404,,,,,,,,,,,,,, -438000,4.698531,0.6378716,,,,,,,,,,,,,, -438100,4.4413247,0.652217,,,,,,,,,,,,,, -438200,4.9897738,0.6521729,,,,,,,,,,,,,, -438300,4.6377196,0.66278356,,,,,,,,,,,,,, -438400,4.8561153,0.6061612,,,,,,,,,,,,,, -438500,4.735695,0.59074205,,,,,,,,,,,,,, -438600,4.7876034,0.5904734,,,,,,,,,,,,,, -438700,5.18206,0.66318566,,,,,,,,,,,,,, -438800,5.0197062,0.67862254,,,,,,,,,,,,,, -438827,,,0.9602997303009032,0.1466187983751297,0.7544199824333191,1.0557987689971924,50000.0,0.629300057888031,1.8190243244171145,10000.0,148985.24048876762,154102.04985761642,148985.24048876762,5079.535908937454,21.82430601119995,0.0 -438900,4.827386,0.66490954,,,,,,,,,,,,,, -439000,4.402607,0.60698444,,,,,,,,,,,,,, -439100,4.621312,0.63033634,,,,,,,,,,,,,, -439200,4.492908,0.6100113,,,,,,,,,,,,,, -439300,5.1442957,0.6563408,,,,,,,,,,,,,, -439400,4.78148,0.6695472,,,,,,,,,,,,,, -439500,4.6587496,0.676859,,,,,,,,,,,,,, -439600,4.570791,0.52747613,,,,,,,,,,,,,, -439700,5.0300364,0.7109073,,,,,,,,,,,,,, -439800,4.6793222,0.614007,,,,,,,,,,,,,, -439900,4.7340307,0.6448274,,,,,,,,,,,,,, -440000,4.5156054,0.59685,,,,,,,,,,,,,, -440100,4.429744,0.65722764,,,,,,,,,,,,,, -440200,4.234335,0.56472987,,,,,,,,,,,,,, -440300,4.552007,0.74413127,,,,,,,,,,,,,, -440330,,,0.9606783986091614,0.1473367661237716,0.7546199560165405,1.0556747913360596,50000.0,0.629800021648407,1.818731546401977,10000.0,149495.2818260193,154628.88868403435,149495.2818260193,5096.1784517765045,21.92604899406433,0.0 -440400,5.625495,0.6481719,,,,,,,,,,,,,, -440500,4.942294,0.63224673,,,,,,,,,,,,,, -440600,4.4900527,0.62470174,,,,,,,,,,,,,, -440700,4.96064,0.70020056,,,,,,,,,,,,,, -440800,4.663464,0.55624366,,,,,,,,,,,,,, -440900,4.513773,0.5936308,,,,,,,,,,,,,, -441000,4.5661473,0.62389684,,,,,,,,,,,,,, -441100,4.5894313,0.65372723,,,,,,,,,,,,,, -441200,4.5338397,0.58755964,,,,,,,,,,,,,, -441300,4.6018057,0.6900469,,,,,,,,,,,,,, -441400,4.814211,0.6147148,,,,,,,,,,,,,, -441500,4.0609775,0.5388174,,,,,,,,,,,,,, -441600,4.363614,0.6466912,,,,,,,,,,,,,, -441700,4.203586,0.6095101,,,,,,,,,,,,,, -441800,4.625627,0.68778837,,,,,,,,,,,,,, -441833,,,0.9596819281578064,0.1484495252370834,0.7549600005149841,1.0566083192825315,50000.0,0.6296000480651855,1.8205649852752688,10000.0,150005.31732654572,155155.54529619217,150005.31732654572,5112.650840759277,22.02223467826844,0.0 -441900,5.074806,0.67089593,,,,,,,,,,,,,, -442000,4.6532254,0.60719407,,,,,,,,,,,,,, -442100,4.344447,0.6594855,,,,,,,,,,,,,, -442200,4.8093247,0.61254734,,,,,,,,,,,,,, -442300,4.9937224,0.7127719,,,,,,,,,,,,,, -442400,4.8264213,0.606567,,,,,,,,,,,,,, -442500,4.4254317,0.6249028,,,,,,,,,,,,,, -442600,4.095369,0.6235316,,,,,,,,,,,,,, -442700,5.05229,0.6886678,,,,,,,,,,,,,, -442800,4.569059,0.5645977,,,,,,,,,,,,,, -442900,4.4459677,0.5917954,,,,,,,,,,,,,, -443000,4.4798503,0.6417537,,,,,,,,,,,,,, -443100,4.6521473,0.68383646,,,,,,,,,,,,,, -443200,4.365182,0.58666825,,,,,,,,,,,,,, -443300,4.448697,0.64652896,,,,,,,,,,,,,, -443336,,,0.9596021771430968,0.1501669436693191,0.754539966583252,1.0561749935150146,50000.0,0.6296000480651855,1.819937229156494,10000.0,150515.3161327839,155682.33803749084,150515.3161327839,5129.289839744568,22.123254776000977,0.0 -443400,4.7387657,0.65935564,,,,,,,,,,,,,, -443500,4.6045423,0.6434544,,,,,,,,,,,,,, -443600,4.7176795,0.649642,,,,,,,,,,,,,, -443700,4.145215,0.64898914,,,,,,,,,,,,,, -443800,4.726841,0.63201356,,,,,,,,,,,,,, -443900,4.4658155,0.64383894,,,,,,,,,,,,,, -444000,5.0073166,0.64202785,,,,,,,,,,,,,, -444100,4.5519466,0.60162127,,,,,,,,,,,,,, -444200,4.437527,0.6059981,,,,,,,,,,,,,, -444300,4.6426573,0.6591921,,,,,,,,,,,,,, -444400,4.5423064,0.59773225,,,,,,,,,,,,,, -444500,3.856423,0.54757404,,,,,,,,,,,,,, -444600,4.3715515,0.60405684,,,,,,,,,,,,,, -444700,4.736368,0.65265983,,,,,,,,,,,,,, -444800,4.5315757,0.6342783,,,,,,,,,,,,,, -444839,,,0.9607381820678712,0.1462684571743011,0.7548999786376953,1.0548075437545776,50000.0,0.6306000351905823,1.818387985229492,10000.0,151025.23944425583,156208.86432909966,151025.23944425583,5145.732047557831,22.230733633041385,0.0 -444900,4.273255,0.5725367,,,,,,,,,,,,,, -445000,4.2234178,0.59590393,,,,,,,,,,,,,, -445100,4.8943195,0.6261401,,,,,,,,,,,,,, -445200,4.3646274,0.58488226,,,,,,,,,,,,,, -445300,4.286052,0.58538735,,,,,,,,,,,,,, -445400,4.6748843,0.6463243,,,,,,,,,,,,,, -445500,4.2871847,0.615688,,,,,,,,,,,,,, -445600,4.953124,0.66291225,,,,,,,,,,,,,, -445700,4.587627,0.602972,,,,,,,,,,,,,, -445800,4.616665,0.6267388,,,,,,,,,,,,,, -445900,4.688551,0.6512451,,,,,,,,,,,,,, -446000,5.0202184,0.64912593,,,,,,,,,,,,,, -446100,4.023861,0.5468146,,,,,,,,,,,,,, -446200,4.63572,0.6394559,,,,,,,,,,,,,, -446300,4.7482843,0.65438384,,,,,,,,,,,,,, -446342,,,0.9601402878761292,0.1477520167827606,0.7553399801254272,1.055067777633667,50000.0,0.6305000185966492,1.81734037399292,10000.0,151535.28600239754,156735.5563902855,151535.28600239754,5162.222529888153,22.332300186157227,0.0 -446400,4.7636075,0.6012958,,,,,,,,,,,,,, -446500,4.7116647,0.64292467,,,,,,,,,,,,,, -446600,4.4337006,0.64905846,,,,,,,,,,,,,, -446700,4.433863,0.6274892,,,,,,,,,,,,,, -446800,4.536269,0.6260838,,,,,,,,,,,,,, -446900,4.8679304,0.6459126,,,,,,,,,,,,,, -447000,4.554263,0.620768,,,,,,,,,,,,,, -447100,4.7974486,0.5979012,,,,,,,,,,,,,, -447200,4.717763,0.67945844,,,,,,,,,,,,,, -447300,4.7158604,0.677492,,,,,,,,,,,,,, -447400,4.476365,0.606671,,,,,,,,,,,,,, -447500,4.7465115,0.5730149,,,,,,,,,,,,,, -447600,4.185765,0.63721675,,,,,,,,,,,,,, -447700,4.395339,0.56223065,,,,,,,,,,,,,, -447800,4.882133,0.68410957,,,,,,,,,,,,,, -447846,,,0.9604790806770324,0.1445417702198028,0.7549200057983398,1.0557647943496704,50000.0,0.6297000050544739,1.8188778162002563,10000.0,152045.3865249157,157262.55180740356,152045.3865249157,5178.921859025955,22.474713563919067,0.0 -447900,4.521477,0.6706769,,,,,,,,,,,,,, -448000,4.619439,0.61376256,,,,,,,,,,,,,, -448100,4.4601154,0.6808005,,,,,,,,,,,,,, -448200,4.980027,0.6581284,,,,,,,,,,,,,, -448300,4.7952046,0.5671223,,,,,,,,,,,,,, -448400,4.612355,0.6101694,,,,,,,,,,,,,, -448500,4.6574,0.6302292,,,,,,,,,,,,,, -448600,4.6415935,0.5899873,,,,,,,,,,,,,, -448700,4.7464833,0.67588377,,,,,,,,,,,,,, -448800,4.61528,0.5439888,,,,,,,,,,,,,, -448900,4.613501,0.63393813,,,,,,,,,,,,,, -449000,4.150635,0.5508463,,,,,,,,,,,,,, -449100,4.3346343,0.6068232,,,,,,,,,,,,,, -449200,4.151512,0.5868351,,,,,,,,,,,,,, -449300,4.6896276,0.7315097,,,,,,,,,,,,,, -449348,,,0.9610171914100648,0.1449035555124282,0.7546799778938293,1.055970549583435,50000.0,0.6299000382423401,1.8194371461868288,10000.0,152555.23027157784,157789.02242159843,152555.23027157784,5195.382237672806,22.58767342567444,0.0 -449400,4.630793,0.6515305,,,,,,,,,,,,,, -449500,4.564714,0.6464795,,,,,,,,,,,,,, -449600,4.1799765,0.55363125,,,,,,,,,,,,,, -449700,4.8502326,0.6460955,,,,,,,,,,,,,, -449800,4.5391855,0.6446504,,,,,,,,,,,,,, -449900,4.4009013,0.61083925,,,,,,,,,,,,,, -450000,5.034256,0.6852282,,,,,,,,,,,,,, -450100,4.7785335,0.6757284,,,,,,,,,,,,,, -450200,3.954234,0.52731407,,,,,,,,,,,,,, -450300,4.54833,0.58985925,,,,,,,,,,,,,, -450400,4.8806314,0.6449221,,,,,,,,,,,,,, -450500,4.704551,0.61782634,,,,,,,,,,,,,, -450600,4.5839806,0.672508,,,,,,,,,,,,,, -450700,4.7140436,0.69680643,,,,,,,,,,,,,, -450800,4.1571355,0.6041492,,,,,,,,,,,,,, -450852,,,0.9602000713348388,0.1482229530811309,0.7547199726104736,1.054293155670166,50000.0,0.6301000118255615,1.8175337314605715,10000.0,153065.41314435005,158315.9674050808,153065.41314435005,5211.9875292778015,22.69133400917053,0.0 -450900,4.4411306,0.59058446,,,,,,,,,,,,,, -451000,4.483917,0.6172317,,,,,,,,,,,,,, -451100,4.6949434,0.6466567,,,,,,,,,,,,,, -451200,4.6643248,0.6241957,,,,,,,,,,,,,, -451300,4.687014,0.6507665,,,,,,,,,,,,,, -451400,4.4273973,0.5851184,,,,,,,,,,,,,, -451500,4.5924716,0.62694067,,,,,,,,,,,,,, -451600,4.786028,0.626387,,,,,,,,,,,,,, -451700,3.9966772,0.5444903,,,,,,,,,,,,,, -451800,4.931989,0.70972747,,,,,,,,,,,,,, -451900,4.566289,0.5912903,,,,,,,,,,,,,, -452000,4.4268107,0.6453165,,,,,,,,,,,,,, -452100,5.1192503,0.6340033,,,,,,,,,,,,,, -452200,4.2711763,0.6215758,,,,,,,,,,,,,, -452300,4.588553,0.5720519,,,,,,,,,,,,,, -452354,,,0.9606385231018066,0.1475521922111511,0.7548799514770508,1.055489182472229,50000.0,0.6310000419616699,1.8206629753112795,10000.0,153575.34843325615,158842.66458582878,153575.34843325615,5228.598965406418,22.787771224975582,0.0 -452400,4.3340273,0.612095,,,,,,,,,,,,,, -452500,4.587306,0.6312413,,,,,,,,,,,,,, -452600,4.3453794,0.6148936,,,,,,,,,,,,,, -452700,4.6377907,0.5824716,,,,,,,,,,,,,, -452800,4.2848315,0.5969525,,,,,,,,,,,,,, -452900,4.5856824,0.68649054,,,,,,,,,,,,,, -453000,4.186076,0.6098508,,,,,,,,,,,,,, -453100,4.4831014,0.6029146,,,,,,,,,,,,,, -453200,4.916963,0.64996785,,,,,,,,,,,,,, -453300,4.4012294,0.5605059,,,,,,,,,,,,,, -453400,4.8608565,0.6837126,,,,,,,,,,,,,, -453500,4.4296894,0.60656995,,,,,,,,,,,,,, -453600,4.880332,0.6409963,,,,,,,,,,,,,, -453700,4.5784326,0.5999498,,,,,,,,,,,,,, -453800,4.3494887,0.6621908,,,,,,,,,,,,,, -453857,,,0.9617147445678712,0.1444277018308639,0.7550199627876282,1.0558409690856934,50000.0,0.6300000548362732,1.8203938007354736,10000.0,154085.36639356613,159369.3465027809,154085.36639356613,5245.10581445694,22.89153957366944,0.0 -453900,5.085463,0.6057951,,,,,,,,,,,,,, -454000,4.3658266,0.5719026,,,,,,,,,,,,,, -454100,4.3495436,0.6102554,,,,,,,,,,,,,, -454200,5.2359495,0.6338425,,,,,,,,,,,,,, -454300,4.811553,0.6387763,,,,,,,,,,,,,, -454400,4.359439,0.54679936,,,,,,,,,,,,,, -454500,4.656164,0.62511724,,,,,,,,,,,,,, -454600,4.748199,0.6583822,,,,,,,,,,,,,, -454700,4.2562933,0.59468347,,,,,,,,,,,,,, -454800,4.749784,0.63742214,,,,,,,,,,,,,, -454900,4.3655224,0.6051392,,,,,,,,,,,,,, -455000,4.14578,0.6004366,,,,,,,,,,,,,, -455100,4.6334953,0.64403105,,,,,,,,,,,,,, -455200,4.433235,0.6205628,,,,,,,,,,,,,, -455300,4.3862276,0.6178159,,,,,,,,,,,,,, -455359,,,0.9610371589660645,0.142574280500412,0.7546799778938293,1.0553386211395264,50000.0,0.6294000148773193,1.819143295288086,10000.0,154595.30282139778,159896.09922385216,154595.30282139778,5261.763719320297,22.99702095985413,0.0 -455400,4.38534,0.61858463,,,,,,,,,,,,,, -455500,4.9558372,0.6612972,,,,,,,,,,,,,, -455600,4.6967406,0.62575805,,,,,,,,,,,,,, -455700,4.561757,0.6122385,,,,,,,,,,,,,, -455800,4.4443326,0.6213283,,,,,,,,,,,,,, -455900,4.5553226,0.6138042,,,,,,,,,,,,,, -456000,4.7448783,0.6318096,,,,,,,,,,,,,, -456100,4.9303737,0.7109233,,,,,,,,,,,,,, -456200,4.170379,0.6556218,,,,,,,,,,,,,, -456300,4.770197,0.6261886,,,,,,,,,,,,,, -456400,4.9894314,0.6052434,,,,,,,,,,,,,, -456500,4.37334,0.61572033,,,,,,,,,,,,,, -456600,4.361764,0.62727606,,,,,,,,,,,,,, -456700,5.1074657,0.63974524,,,,,,,,,,,,,, -456800,4.6066074,0.5459293,,,,,,,,,,,,,, -456863,,,0.961156725883484,0.145716980099678,0.754859983921051,1.0554897785186768,50000.0,0.629800021648407,1.819894790649414,10000.0,155105.47327399254,160422.96509361267,155105.47327399254,5278.295104265213,23.10880136489868,0.0 -456900,4.350853,0.6246056,,,,,,,,,,,,,, -457000,4.521661,0.61599135,,,,,,,,,,,,,, -457100,4.6331596,0.6216119,,,,,,,,,,,,,, -457200,4.188625,0.6163749,,,,,,,,,,,,,, -457300,4.0925035,0.58322215,,,,,,,,,,,,,, -457400,4.7916145,0.64046514,,,,,,,,,,,,,, -457500,5.0302978,0.6869245,,,,,,,,,,,,,, -457600,4.6848454,0.59555775,,,,,,,,,,,,,, -457700,4.9449625,0.67552125,,,,,,,,,,,,,, -457800,4.792212,0.6687548,,,,,,,,,,,,,, -457900,4.362329,0.70227605,,,,,,,,,,,,,, -458000,4.4650145,0.544459,,,,,,,,,,,,,, -458100,4.4670267,0.6230268,,,,,,,,,,,,,, -458200,5.1411424,0.5798167,,,,,,,,,,,,,, -458300,4.5174155,0.6353766,,,,,,,,,,,,,, -458366,,,0.9598014950752258,0.1472740620374679,0.7547799944877625,1.0557348728179932,50000.0,0.6299000382423401,1.8191497325897217,10000.0,155615.38880991936,160950.1413064003,155615.38880991936,5295.399443626404,23.212135314941406,0.0 -458400,4.461589,0.6272137,,,,,,,,,,,,,, -458500,4.334534,0.61354274,,,,,,,,,,,,,, -458600,4.1712093,0.576233,,,,,,,,,,,,,, -458700,4.455705,0.5712522,,,,,,,,,,,,,, -458800,4.607597,0.57847893,,,,,,,,,,,,,, -458900,4.2617245,0.67327774,,,,,,,,,,,,,, -459000,4.6237006,0.65715855,,,,,,,,,,,,,, -459100,4.420873,0.630785,,,,,,,,,,,,,, -459200,4.5827036,0.6436958,,,,,,,,,,,,,, -459300,4.1892242,0.6021538,,,,,,,,,,,,,, -459400,4.6792464,0.68496776,,,,,,,,,,,,,, -459500,4.7452626,0.6974389,,,,,,,,,,,,,, -459600,4.612054,0.6841119,,,,,,,,,,,,,, -459700,4.906314,0.662259,,,,,,,,,,,,,, -459800,4.5079384,0.6532268,,,,,,,,,,,,,, -459869,,,0.960598647594452,0.1467158049345016,0.7549399733543396,1.0555691719055176,50000.0,0.6297000050544739,1.8184266090393064,10000.0,156125.31659150124,161476.9416666031,156125.31659150124,5312.113205909729,23.317948579788208,0.0 -459900,4.9241033,0.73388886,,,,,,,,,,,,,, -460000,4.609463,0.6215493,,,,,,,,,,,,,, -460100,4.4541163,0.6266755,,,,,,,,,,,,,, -460200,4.7640553,0.5925873,,,,,,,,,,,,,, -460300,4.2429457,0.6640699,,,,,,,,,,,,,, -460400,4.4674997,0.619161,,,,,,,,,,,,,, -460500,4.424693,0.61217076,,,,,,,,,,,,,, -460600,4.5478907,0.60552675,,,,,,,,,,,,,, -460700,4.410738,0.61244875,,,,,,,,,,,,,, -460800,4.9684725,0.65489876,,,,,,,,,,,,,, -460900,5.031133,0.6139527,,,,,,,,,,,,,, -461000,4.7746177,0.6625508,,,,,,,,,,,,,, -461100,4.7343817,0.62792885,,,,,,,,,,,,,, -461200,4.8213897,0.6152235,,,,,,,,,,,,,, -461300,4.5613017,0.5702061,,,,,,,,,,,,,, -461371,,,0.9602798223495485,0.1468383967876434,0.7546399831771851,1.056372046470642,50000.0,0.6301000118255615,1.8204811811447144,10000.0,156635.17824530602,162003.59568810463,156635.17824530602,5328.743317604065,23.42711114883423,0.0 -461400,4.203384,0.546149,,,,,,,,,,,,,, -461500,4.5965786,0.59405416,,,,,,,,,,,,,, -461600,4.0192633,0.63409364,,,,,,,,,,,,,, -461700,4.382056,0.557238,,,,,,,,,,,,,, -461800,4.9513474,0.6375633,,,,,,,,,,,,,, -461900,4.7329473,0.6398257,,,,,,,,,,,,,, -462000,4.726665,0.6536871,,,,,,,,,,,,,, -462100,4.3195353,0.63647425,,,,,,,,,,,,,, -462200,4.9475884,0.6645918,,,,,,,,,,,,,, -462300,5.0596237,0.6732419,,,,,,,,,,,,,, -462400,4.5833764,0.62488085,,,,,,,,,,,,,, -462500,4.414162,0.613572,,,,,,,,,,,,,, -462600,4.446282,0.6128959,,,,,,,,,,,,,, -462700,4.5433917,0.65053207,,,,,,,,,,,,,, -462800,4.427189,0.6319301,,,,,,,,,,,,,, -462874,,,0.9627510905265808,0.1442515403032302,0.7547399997711182,1.0549752712249756,50000.0,0.629800021648407,1.8176995515823364,10000.0,157145.0807697773,162530.2850341797,157145.0807697773,5345.371515035629,23.53261113166809,0.0 -462900,4.39746,0.6506315,,,,,,,,,,,,,, -463000,4.5226283,0.60777235,,,,,,,,,,,,,, -463100,4.8988395,0.6195806,,,,,,,,,,,,,, -463200,4.9934235,0.63408494,,,,,,,,,,,,,, -463300,4.051303,0.53023314,,,,,,,,,,,,,, -463400,4.2222123,0.5598433,,,,,,,,,,,,,, -463500,4.162883,0.5953872,,,,,,,,,,,,,, -463600,4.581837,0.5952357,,,,,,,,,,,,,, -463700,5.0462904,0.6350155,,,,,,,,,,,,,, -463800,4.5180893,0.6483244,,,,,,,,,,,,,, -463900,4.671899,0.61713386,,,,,,,,,,,,,, -464000,4.6647964,0.5796149,,,,,,,,,,,,,, -464100,5.0965557,0.6330297,,,,,,,,,,,,,, -464200,4.3112803,0.5652512,,,,,,,,,,,,,, -464300,4.6267395,0.6563559,,,,,,,,,,,,,, -464377,,,0.9610171914100648,0.1462255269289016,0.7545199990272522,1.0557774305343628,50000.0,0.629800021648407,1.8174927234649656,10000.0,157655.2273361683,163057.73574876785,157655.2273361683,5362.519360303879,23.63612127304077,0.0 -464400,4.6057,0.63438207,,,,,,,,,,,,,, -464500,4.3058586,0.5935414,,,,,,,,,,,,,, -464600,4.2500854,0.5907198,,,,,,,,,,,,,, -464700,4.574508,0.560627,,,,,,,,,,,,,, -464800,4.8294716,0.7295733,,,,,,,,,,,,,, -464900,4.464467,0.60396355,,,,,,,,,,,,,, -465000,4.56643,0.6396538,,,,,,,,,,,,,, -465100,4.384734,0.68979025,,,,,,,,,,,,,, -465200,4.2751327,0.5889567,,,,,,,,,,,,,, -465300,4.4543457,0.56573266,,,,,,,,,,,,,, -465400,4.4811687,0.5642389,,,,,,,,,,,,,, -465500,4.3086643,0.6041148,,,,,,,,,,,,,, -465600,4.4236946,0.6475977,,,,,,,,,,,,,, -465700,4.8070555,0.67179227,,,,,,,,,,,,,, -465800,4.7982554,0.6080698,,,,,,,,,,,,,, -465880,,,0.9604392051696776,0.1437030434608459,0.7551199793815613,1.0570862293243408,50000.0,0.629300057888031,1.82122802734375,10000.0,158165.1404414177,163584.59530687332,158165.1404414177,5379.328293085098,23.72211003303528,0.0 -465900,5.960527,0.65446544,,,,,,,,,,,,,, -466000,4.4251747,0.62317294,,,,,,,,,,,,,, -466100,5.0602427,0.6312605,,,,,,,,,,,,,, -466200,4.6223435,0.69780904,,,,,,,,,,,,,, -466300,4.595224,0.6237115,,,,,,,,,,,,,, -466400,4.3595047,0.5066985,,,,,,,,,,,,,, -466500,4.4911103,0.6264076,,,,,,,,,,,,,, -466600,4.424755,0.6060827,,,,,,,,,,,,,, -466700,4.6647825,0.64178294,,,,,,,,,,,,,, -466800,4.70114,0.64964265,,,,,,,,,,,,,, -466900,4.433782,0.6144936,,,,,,,,,,,,,, -467000,4.697723,0.61270213,,,,,,,,,,,,,, -467100,4.5014386,0.64249206,,,,,,,,,,,,,, -467200,4.490715,0.576246,,,,,,,,,,,,,, -467300,4.755552,0.64721954,,,,,,,,,,,,,, -467383,,,0.9608777165412904,0.1466167420148849,0.7547199726104736,1.0551986694335938,50000.0,0.6310000419616699,1.818740963935852,10000.0,158675.24805808067,164111.61119008064,158675.24805808067,5396.06454372406,23.83945727348328,0.0 -467400,4.568938,0.65443224,,,,,,,,,,,,,, -467500,4.9065285,0.63000095,,,,,,,,,,,,,, -467600,4.8391147,0.6581404,,,,,,,,,,,,,, -467700,4.592395,0.6556248,,,,,,,,,,,,,, -467800,4.5175886,0.64515895,,,,,,,,,,,,,, -467900,4.455343,0.616724,,,,,,,,,,,,,, -468000,4.6524115,0.6538342,,,,,,,,,,,,,, -468100,4.593379,0.6258122,,,,,,,,,,,,,, -468200,4.2648726,0.5801979,,,,,,,,,,,,,, -468300,4.641261,0.5924958,,,,,,,,,,,,,, -468400,4.387359,0.5957159,,,,,,,,,,,,,, -468500,4.521227,0.62002206,,,,,,,,,,,,,, -468600,4.6128054,0.63948244,,,,,,,,,,,,,, -468700,4.5692735,0.6428325,,,,,,,,,,,,,, -468800,4.1913977,0.621114,,,,,,,,,,,,,, -468886,,,0.9606584906578064,0.1459176987409591,0.7545199990272522,1.0558143854141235,50000.0,0.6300000548362732,1.8203456401824951,10000.0,159185.16336750984,164638.34228634834,159185.16336750984,5412.717334508896,23.949141263961792,0.0 -468900,4.6135125,0.6377984,,,,,,,,,,,,,, -469000,4.808514,0.64391035,,,,,,,,,,,,,, -469100,4.362888,0.6240061,,,,,,,,,,,,,, -469200,4.502004,0.5837524,,,,,,,,,,,,,, -469300,4.1456227,0.5813462,,,,,,,,,,,,,, -469400,4.9105196,0.6525627,,,,,,,,,,,,,, -469500,4.405112,0.6022868,,,,,,,,,,,,,, -469600,4.7195334,0.65981287,,,,,,,,,,,,,, -469700,5.0137267,0.634422,,,,,,,,,,,,,, -469800,4.7202587,0.6551416,,,,,,,,,,,,,, -469900,4.6163664,0.64422184,,,,,,,,,,,,,, -470000,4.4230704,0.61721677,,,,,,,,,,,,,, -470100,4.564639,0.6440878,,,,,,,,,,,,,, -470200,4.733329,0.67052823,,,,,,,,,,,,,, -470300,4.4672136,0.58606523,,,,,,,,,,,,,, -470389,,,0.9602798223495485,0.1463384926319122,0.7549799680709839,1.055783748626709,50000.0,0.6301000118255615,1.819016456604004,10000.0,159695.21085381508,165165.05762171745,159695.21085381508,5429.2265625,24.054170608520508,0.0 -470400,4.933287,0.6402314,,,,,,,,,,,,,, -470500,4.5361867,0.63905424,,,,,,,,,,,,,, -470600,4.7399335,0.5366062,,,,,,,,,,,,,, -470700,4.780167,0.6672347,,,,,,,,,,,,,, -470800,4.372767,0.61770374,,,,,,,,,,,,,, -470900,4.694869,0.5749213,,,,,,,,,,,,,, -471000,4.3362203,0.6155096,,,,,,,,,,,,,, -471100,4.4443235,0.6475501,,,,,,,,,,,,,, -471200,4.6651,0.62945133,,,,,,,,,,,,,, -471300,4.051563,0.555255,,,,,,,,,,,,,, -471400,4.629976,0.6865822,,,,,,,,,,,,,, -471500,4.7088947,0.63915014,,,,,,,,,,,,,, -471600,4.965397,0.71875656,,,,,,,,,,,,,, -471700,4.9124165,0.67461836,,,,,,,,,,,,,, -471800,4.616433,0.6125717,,,,,,,,,,,,,, -471892,,,0.961694836616516,0.1453837901353836,0.7548799514770508,1.0557386875152588,50000.0,0.6300000548362732,1.8196965456008911,10000.0,160205.1214659214,165691.59308218956,160205.1214659214,5445.694318056107,24.15822815895081,0.0 -471900,4.8325095,0.61848253,,,,,,,,,,,,,, -472000,4.8079114,0.6154746,,,,,,,,,,,,,, -472100,4.6159663,0.6779489,,,,,,,,,,,,,, -472200,4.505924,0.60964894,,,,,,,,,,,,,, -472300,5.07125,0.7152226,,,,,,,,,,,,,, -472400,4.75277,0.6300383,,,,,,,,,,,,,, -472500,4.2291822,0.5225917,,,,,,,,,,,,,, -472600,3.9747634,0.54859453,,,,,,,,,,,,,, -472700,4.3296156,0.6319494,,,,,,,,,,,,,, -472800,4.2000723,0.5522373,,,,,,,,,,,,,, -472900,4.987362,0.63013995,,,,,,,,,,,,,, -473000,4.7025046,0.5725621,,,,,,,,,,,,,, -473100,4.5103903,0.6022882,,,,,,,,,,,,,, -473200,4.305401,0.6430753,,,,,,,,,,,,,, -473300,4.4275303,0.58030677,,,,,,,,,,,,,, -473394,,,0.9602598547935486,0.1482415646314621,0.7550599575042725,1.055686593055725,50000.0,0.6307000517845154,1.8173385858535769,10000.0,160715.06251072884,166218.24579143524,160715.06251072884,5462.238760948181,24.271573543548584,0.0 -473400,4.7765546,0.68090093,,,,,,,,,,,,,, -473500,4.776927,0.68115234,,,,,,,,,,,,,, -473600,4.3269134,0.58777606,,,,,,,,,,,,,, -473700,4.4733257,0.61718553,,,,,,,,,,,,,, -473800,4.9964523,0.7041149,,,,,,,,,,,,,, -473900,4.372721,0.58440244,,,,,,,,,,,,,, -474000,4.2834926,0.6410124,,,,,,,,,,,,,, -474100,4.279616,0.62542737,,,,,,,,,,,,,, -474200,4.263339,0.59253734,,,,,,,,,,,,,, -474300,4.741441,0.6029352,,,,,,,,,,,,,, -474400,4.561341,0.6497257,,,,,,,,,,,,,, -474500,4.566822,0.6457804,,,,,,,,,,,,,, -474600,4.681765,0.6697858,,,,,,,,,,,,,, -474700,4.230013,0.61126864,,,,,,,,,,,,,, -474800,4.5983152,0.67875946,,,,,,,,,,,,,, -474897,,,0.9599609375,0.1462009400129318,0.7547399997711182,1.0547475814819336,50000.0,0.631100058555603,1.818025708198548,10000.0,161224.98149490356,166744.93678593636,161224.98149490356,5478.855576515198,24.37463998794556,0.0 -474900,3.9641669,0.58137,,,,,,,,,,,,,, -475000,4.9587283,0.6238588,,,,,,,,,,,,,, -475100,4.41767,0.5841708,,,,,,,,,,,,,, -475200,4.474331,0.631943,,,,,,,,,,,,,, -475300,5.179582,0.69664276,,,,,,,,,,,,,, -475400,4.5798974,0.6318764,,,,,,,,,,,,,, -475500,4.5835214,0.71803343,,,,,,,,,,,,,, -475600,4.689385,0.6093568,,,,,,,,,,,,,, -475700,4.783316,0.56834614,,,,,,,,,,,,,, -475800,4.3692203,0.68532723,,,,,,,,,,,,,, -475900,4.7657914,0.6325078,,,,,,,,,,,,,, -476000,4.3424535,0.62835515,,,,,,,,,,,,,, -476100,3.9877644,0.5216707,,,,,,,,,,,,,, -476200,4.592536,0.62867296,,,,,,,,,,,,,, -476300,4.29409,0.5999381,,,,,,,,,,,,,, -476399,,,0.9605189561843872,0.1475342661142349,0.7547599673271179,1.056140422821045,50000.0,0.6299000382423401,1.8198281526565552,10000.0,161734.98222899437,167271.63280415535,161734.98222899437,5495.393129825592,24.47951650619507,0.0 -476400,4.092147,0.6127933,,,,,,,,,,,,,, -476500,4.5698795,0.6448313,,,,,,,,,,,,,, -476600,4.770415,0.6530282,,,,,,,,,,,,,, -476700,4.536249,0.66561997,,,,,,,,,,,,,, -476800,4.0768437,0.58723456,,,,,,,,,,,,,, -476900,4.9169345,0.6489004,,,,,,,,,,,,,, -477000,4.3636136,0.5951896,,,,,,,,,,,,,, -477100,4.0617228,0.54786056,,,,,,,,,,,,,, -477200,4.453902,0.679752,,,,,,,,,,,,,, -477300,4.580224,0.6164315,,,,,,,,,,,,,, -477400,4.787124,0.5780277,,,,,,,,,,,,,, -477500,4.5609546,0.6415978,,,,,,,,,,,,,, -477600,4.7120485,0.6546725,,,,,,,,,,,,,, -477700,4.328528,0.6612786,,,,,,,,,,,,,, -477800,4.4208045,0.59318525,,,,,,,,,,,,,, -477900,4.78721,0.64771247,,,,,,,,,,,,,, -477901,,,0.9598014950752258,0.1489960998296737,0.7548799514770508,1.0541421175003052,50000.0,0.6309000253677368,1.815615177154541,10000.0,162244.99106502533,167798.54024600983,162244.99106502533,5512.12969326973,24.58859372138977,0.0 -478000,3.988713,0.49851233,,,,,,,,,,,,,, -478100,4.5535097,0.5804613,,,,,,,,,,,,,, -478200,4.395311,0.573648,,,,,,,,,,,,,, -478300,4.508459,0.5452585,,,,,,,,,,,,,, -478400,4.597637,0.6764573,,,,,,,,,,,,,, -478500,4.229554,0.572001,,,,,,,,,,,,,, -478600,4.6393337,0.65142775,,,,,,,,,,,,,, -478700,4.5334935,0.6557683,,,,,,,,,,,,,, -478800,4.2067037,0.613213,,,,,,,,,,,,,, -478900,4.549968,0.6746437,,,,,,,,,,,,,, -479000,4.388152,0.6617928,,,,,,,,,,,,,, -479100,4.795893,0.61635655,,,,,,,,,,,,,, -479200,4.2659206,0.59565336,,,,,,,,,,,,,, -479300,4.5325546,0.6292547,,,,,,,,,,,,,, -479400,4.646999,0.601318,,,,,,,,,,,,,, -479404,,,0.9606983065605164,0.1470500826835632,0.7549999952316284,1.056383728981018,50000.0,0.629800021648407,1.8202873468399048,10000.0,162755.0349202156,168325.28706121445,162755.0349202156,5528.671658039093,24.69663667678833,0.0 -479500,4.44468,0.57791364,,,,,,,,,,,,,, -479600,4.612619,0.69074774,,,,,,,,,,,,,, -479700,4.291275,0.60131705,,,,,,,,,,,,,, -479800,4.39899,0.6563635,,,,,,,,,,,,,, -479900,4.5104513,0.5999349,,,,,,,,,,,,,, -480000,4.3319077,0.5987432,,,,,,,,,,,,,, -480100,5.0170565,0.60060704,,,,,,,,,,,,,, -480200,4.7162256,0.66950613,,,,,,,,,,,,,, -480300,4.6024566,0.66304755,,,,,,,,,,,,,, -480400,4.2431507,0.65227425,,,,,,,,,,,,,, -480500,4.7862015,0.65672624,,,,,,,,,,,,,, -480600,4.677649,0.5910847,,,,,,,,,,,,,, -480700,4.4378996,0.5846745,,,,,,,,,,,,,, -480800,4.8597984,0.6380724,,,,,,,,,,,,,, -480900,4.736668,0.7042428,,,,,,,,,,,,,, -480907,,,0.959622085094452,0.151059940457344,0.7543599605560303,1.0561161041259766,50000.0,0.6295000314712524,1.8191709518432613,10000.0,163264.94740009308,168851.94211268425,163264.94740009308,5545.251647233963,24.805399894714355,0.0 -481000,4.2196813,0.6095712,,,,,,,,,,,,,, -481100,5.0881133,0.7307231,,,,,,,,,,,,,, -481200,4.7661366,0.6240671,,,,,,,,,,,,,, -481300,4.877439,0.6665626,,,,,,,,,,,,,, -481400,4.6631575,0.6699566,,,,,,,,,,,,,, -481500,4.249387,0.5791686,,,,,,,,,,,,,, -481600,4.541,0.665053,,,,,,,,,,,,,, -481700,4.6632414,0.7195669,,,,,,,,,,,,,, -481800,4.754528,0.59584236,,,,,,,,,,,,,, -481900,4.746,0.6110571,,,,,,,,,,,,,, -482000,4.713071,0.6283562,,,,,,,,,,,,,, -482100,4.734914,0.6058925,,,,,,,,,,,,,, -482200,4.050816,0.6065673,,,,,,,,,,,,,, -482300,4.5096645,0.5889516,,,,,,,,,,,,,, -482400,4.225365,0.6144806,,,,,,,,,,,,,, -482410,,,0.9606186151504515,0.146817535161972,0.7549200057983398,1.0554178953170776,50000.0,0.6310000419616699,1.8195371627807613,10000.0,163775.10062289238,169378.83524656296,163775.10062289238,5561.8279457092285,24.915863752365112,0.0 -482500,4.626242,0.667802,,,,,,,,,,,,,, -482600,4.738931,0.61026907,,,,,,,,,,,,,, -482700,4.356786,0.6230439,,,,,,,,,,,,,, -482800,4.597538,0.6088927,,,,,,,,,,,,,, -482900,4.2615466,0.61028767,,,,,,,,,,,,,, -483000,4.4787235,0.60764426,,,,,,,,,,,,,, -483100,4.7403336,0.6176163,,,,,,,,,,,,,, -483200,4.3506103,0.63729775,,,,,,,,,,,,,, -483300,4.3383036,0.58329093,,,,,,,,,,,,,, -483400,4.46811,0.6060156,,,,,,,,,,,,,, -483500,4.5908465,0.6851256,,,,,,,,,,,,,, -483600,4.241175,0.6003454,,,,,,,,,,,,,, -483700,4.6672444,0.6129828,,,,,,,,,,,,,, -483800,4.2927313,0.6404742,,,,,,,,,,,,,, -483900,4.445613,0.60920227,,,,,,,,,,,,,, -483913,,,0.9611965417861938,0.1441107988357544,0.7546600103378296,1.0558921098709106,50000.0,0.6296000480651855,1.8194700479507449,10000.0,164285.20070242882,169905.66928815842,164285.20070242882,5578.402358531952,25.023176670074463,0.0 -484000,4.3763356,0.6226743,,,,,,,,,,,,,, -484100,4.918314,0.6434743,,,,,,,,,,,,,, -484200,4.3015122,0.59001446,,,,,,,,,,,,,, -484300,4.4153023,0.5043274,,,,,,,,,,,,,, -484400,4.439107,0.66179866,,,,,,,,,,,,,, -484500,4.6685863,0.63947177,,,,,,,,,,,,,, -484600,4.134842,0.49684864,,,,,,,,,,,,,, -484700,4.539277,0.6393745,,,,,,,,,,,,,, -484800,4.780458,0.5588226,,,,,,,,,,,,,, -484900,4.3750367,0.5696508,,,,,,,,,,,,,, -485000,4.497659,0.7362635,,,,,,,,,,,,,, -485100,4.4803977,0.6849148,,,,,,,,,,,,,, -485200,4.2400527,0.5896677,,,,,,,,,,,,,, -485300,4.618767,0.6625521,,,,,,,,,,,,,, -485400,4.1155806,0.5552559,,,,,,,,,,,,,, -485417,,,0.9585259556770324,0.1497927606105804,0.7547199726104736,1.055144429206848,50000.0,0.6309000253677368,1.8180348873138428,10000.0,164795.32609844208,170432.5374853611,164795.32609844208,5594.983488559723,25.13081169128418,0.0 -485500,4.4408975,0.5845699,,,,,,,,,,,,,, -485600,4.6237416,0.58258235,,,,,,,,,,,,,, -485700,4.527265,0.6279522,,,,,,,,,,,,,, -485800,4.629925,0.6598221,,,,,,,,,,,,,, -485900,4.3254623,0.60234046,,,,,,,,,,,,,, -486000,5.097944,0.64969903,,,,,,,,,,,,,, -486100,4.7154536,0.65231276,,,,,,,,,,,,,, -486200,4.783865,0.5840791,,,,,,,,,,,,,, -486300,4.5949283,0.65248525,,,,,,,,,,,,,, -486400,5.0476084,0.6428613,,,,,,,,,,,,,, -486500,4.407922,0.60161877,,,,,,,,,,,,,, -486600,4.196036,0.57764375,,,,,,,,,,,,,, -486700,4.3613505,0.54239964,,,,,,,,,,,,,, -486800,4.9391303,0.67643267,,,,,,,,,,,,,, -486900,4.5384145,0.57557636,,,,,,,,,,,,,, -486921,,,0.9615553021430968,0.1433934420347213,0.7547399997711182,1.0567063093185425,50000.0,0.631100058555603,1.819509029388428,10000.0,165305.46842598915,170959.40493369102,165305.46842598915,5611.545911312103,25.24029874801636,0.0 -487000,4.2538304,0.5596092,,,,,,,,,,,,,, -487100,5.0772066,0.6465477,,,,,,,,,,,,,, -487200,4.7498927,0.67234844,,,,,,,,,,,,,, -487300,4.2025127,0.61773777,,,,,,,,,,,,,, -487400,4.427301,0.5900532,,,,,,,,,,,,,, -487500,4.9972744,0.69720423,,,,,,,,,,,,,, -487600,4.466723,0.5974191,,,,,,,,,,,,,, -487700,4.918332,0.6784462,,,,,,,,,,,,,, -487800,4.113088,0.598852,,,,,,,,,,,,,, -487900,4.5251403,0.5583123,,,,,,,,,,,,,, -488000,4.373253,0.6769945,,,,,,,,,,,,,, -488100,4.3489423,0.6222705,,,,,,,,,,,,,, -488200,4.894864,0.6875074,,,,,,,,,,,,,, -488300,4.5327415,0.5601792,,,,,,,,,,,,,, -488400,4.4719334,0.62437105,,,,,,,,,,,,,, -488423,,,0.9604790806770324,0.1482679396867752,0.7546199560165405,1.0550260543823242,50000.0,0.6297000050544739,1.8203799724578853,10000.0,165815.3304874897,171486.03428840637,165815.3304874897,5628.148483276367,25.35098004341125,0.0 -488500,4.788991,0.69051456,,,,,,,,,,,,,, -488600,4.9817133,0.6280658,,,,,,,,,,,,,, -488700,4.3728027,0.59176177,,,,,,,,,,,,,, -488800,4.785153,0.61201775,,,,,,,,,,,,,, -488900,4.5669456,0.5879925,,,,,,,,,,,,,, -489000,4.370154,0.57958627,,,,,,,,,,,,,, -489100,4.318578,0.65189356,,,,,,,,,,,,,, -489200,4.5340533,0.5469488,,,,,,,,,,,,,, -489300,4.120803,0.62066346,,,,,,,,,,,,,, -489400,4.8548894,0.693049,,,,,,,,,,,,,, -489500,4.222372,0.5911488,,,,,,,,,,,,,, -489600,4.5983453,0.6092264,,,,,,,,,,,,,, -489700,4.602274,0.5985175,,,,,,,,,,,,,, -489800,4.616477,0.6580823,,,,,,,,,,,,,, -489900,4.230661,0.6112281,,,,,,,,,,,,,, -489926,,,0.9613958597183228,0.1450057327747345,0.7552199959754944,1.0549418926239014,50000.0,0.6305000185966492,1.8186701536178589,10000.0,166325.27960824966,172012.83657836914,166325.27960824966,5644.838176488876,25.461592197418213,0.0 -490000,4.791657,0.6587552,,,,,,,,,,,,,, -490100,4.3332763,0.59854054,,,,,,,,,,,,,, -490200,4.4737525,0.621829,,,,,,,,,,,,,, -490300,4.236987,0.5589223,,,,,,,,,,,,,, -490400,4.663139,0.6107266,,,,,,,,,,,,,, -490500,4.205098,0.58233637,,,,,,,,,,,,,, -490600,4.426661,0.5520922,,,,,,,,,,,,,, -490700,4.86051,0.6183604,,,,,,,,,,,,,, -490800,4.2120104,0.623996,,,,,,,,,,,,,, -490900,4.5340257,0.6280686,,,,,,,,,,,,,, -491000,4.4871182,0.6442891,,,,,,,,,,,,,, -491100,4.576571,0.6576505,,,,,,,,,,,,,, -491200,4.524945,0.6442122,,,,,,,,,,,,,, -491300,4.629594,0.62498987,,,,,,,,,,,,,, -491400,4.689288,0.6368967,,,,,,,,,,,,,, -491428,,,0.960339605808258,0.1478165686130523,0.7548399567604065,1.055351972579956,50000.0,0.6299000382423401,1.818878173828125,10000.0,166835.23479795456,172539.53677415848,166835.23479795456,5661.418847084045,25.57334852218628,0.0 -491500,4.7804585,0.64581794,,,,,,,,,,,,,, -491600,4.4959774,0.56682485,,,,,,,,,,,,,, -491700,4.9842396,0.63432115,,,,,,,,,,,,,, -491800,4.545476,0.60460407,,,,,,,,,,,,,, -491900,4.4119806,0.5897784,,,,,,,,,,,,,, -492000,4.586621,0.5991491,,,,,,,,,,,,,, -492100,4.712201,0.6103229,,,,,,,,,,,,,, -492200,4.546975,0.5888413,,,,,,,,,,,,,, -492300,4.2792397,0.6536274,,,,,,,,,,,,,, -492400,4.847093,0.67474777,,,,,,,,,,,,,, -492500,4.5604753,0.664811,,,,,,,,,,,,,, -492600,4.623891,0.632188,,,,,,,,,,,,,, -492700,4.727696,0.6370917,,,,,,,,,,,,,, -492800,4.7042274,0.63405347,,,,,,,,,,,,,, -492900,4.3926177,0.66542935,,,,,,,,,,,,,, -492931,,,0.9618343114852904,0.1427801996469497,0.7546600103378296,1.0560132265090942,50000.0,0.6303000450134277,1.8192788362503047,10000.0,167345.2049202919,173066.20989394188,167345.2049202919,5677.953634738922,25.689466953277588,0.0 -493000,4.285786,0.627566,,,,,,,,,,,,,, -493100,4.568182,0.6212998,,,,,,,,,,,,,, -493200,4.426967,0.60896844,,,,,,,,,,,,,, -493300,5.4841037,0.67789257,,,,,,,,,,,,,, -493400,4.386789,0.6725954,,,,,,,,,,,,,, -493500,4.653014,0.6622527,,,,,,,,,,,,,, -493600,4.595886,0.63124824,,,,,,,,,,,,,, -493700,4.538805,0.6827738,,,,,,,,,,,,,, -493800,4.440253,0.5859044,,,,,,,,,,,,,, -493900,4.357003,0.65870094,,,,,,,,,,,,,, -494000,4.9712725,0.6545199,,,,,,,,,,,,,, -494100,4.3993397,0.6305616,,,,,,,,,,,,,, -494200,4.542269,0.64545256,,,,,,,,,,,,,, -494300,4.5714855,0.6488764,,,,,,,,,,,,,, -494400,4.529895,0.64659184,,,,,,,,,,,,,, -494434,,,0.960379421710968,0.1443609893321991,0.7543999552726746,1.0561788082122805,50000.0,0.6296000480651855,1.820483684539795,10000.0,167855.2519850731,173593.01951646805,167855.2519850731,5694.541827440262,25.811455249786377,0.0 -494500,4.918063,0.5867153,,,,,,,,,,,,,, -494600,5.0740795,0.6495277,,,,,,,,,,,,,, -494700,4.646928,0.70121205,,,,,,,,,,,,,, -494800,4.42891,0.5845207,,,,,,,,,,,,,, -494900,4.5369153,0.57591957,,,,,,,,,,,,,, -495000,4.34541,0.61173457,,,,,,,,,,,,,, -495100,5.1049933,0.65221655,,,,,,,,,,,,,, -495200,4.7077947,0.6191566,,,,,,,,,,,,,, -495300,4.561287,0.5965928,,,,,,,,,,,,,, -495400,4.7889986,0.6447002,,,,,,,,,,,,,, -495500,4.4617643,0.5971555,,,,,,,,,,,,,, -495600,4.41868,0.5888816,,,,,,,,,,,,,, -495700,4.463954,0.6119903,,,,,,,,,,,,,, -495800,4.4069386,0.6862403,,,,,,,,,,,,,, -495900,4.407434,0.6713236,,,,,,,,,,,,,, -495937,,,0.961535394191742,0.1451693475246429,0.7547399997711182,1.056144118309021,50000.0,0.6294000148773193,1.8213483095169067,10000.0,168365.25228381157,174119.71955490112,168365.25228381157,5711.08086681366,25.91838574409485,0.0 -496000,4.281672,0.55819154,,,,,,,,,,,,,, -496100,4.4026303,0.6717513,,,,,,,,,,,,,, -496200,4.775894,0.5697281,,,,,,,,,,,,,, -496300,4.948331,0.6847446,,,,,,,,,,,,,, -496400,4.284386,0.5963989,,,,,,,,,,,,,, -496500,4.3463473,0.6418591,,,,,,,,,,,,,, -496600,4.453673,0.6630611,,,,,,,,,,,,,, -496700,4.651514,0.57297254,,,,,,,,,,,,,, -496800,4.6992025,0.6548373,,,,,,,,,,,,,, -496900,4.6150374,0.70647717,,,,,,,,,,,,,, -497000,5.0610175,0.616265,,,,,,,,,,,,,, -497100,4.54604,0.64868504,,,,,,,,,,,,,, -497200,4.450809,0.6697492,,,,,,,,,,,,,, -497300,4.60051,0.6187819,,,,,,,,,,,,,, -497400,4.4468718,0.5935993,,,,,,,,,,,,,, -497440,,,0.96000075340271,0.1484004706144333,0.7550199627876282,1.0548858642578125,50000.0,0.6305000185966492,1.8195267915725708,10000.0,168875.17960882187,174647.20702934265,168875.17960882187,5728.474135398865,26.03202795982361,0.0 -497500,4.5414033,0.62093836,,,,,,,,,,,,,, -497600,4.3726315,0.6069551,,,,,,,,,,,,,, -497700,4.536017,0.60256994,,,,,,,,,,,,,, -497800,4.6399226,0.61994463,,,,,,,,,,,,,, -497900,4.488332,0.68042487,,,,,,,,,,,,,, -498000,4.1595216,0.6065312,,,,,,,,,,,,,, -498100,4.5718055,0.58922625,,,,,,,,,,,,,, -498200,4.4354973,0.58255905,,,,,,,,,,,,,, -498300,4.268978,0.62048984,,,,,,,,,,,,,, -498400,4.765245,0.6433409,,,,,,,,,,,,,, -498500,4.767629,0.6207414,,,,,,,,,,,,,, -498600,4.2567344,0.5544246,,,,,,,,,,,,,, -498700,5.0512075,0.64977044,,,,,,,,,,,,,, -498800,4.8806853,0.6408604,,,,,,,,,,,,,, -498900,5.5122323,0.6889948,,,,,,,,,,,,,, -498943,,,0.9606783986091614,0.1457280069589615,0.7548999786376953,1.054988145828247,50000.0,0.6310000419616699,1.817232728004456,10000.0,169385.31666755676,175174.05980086327,169385.31666755676,5745.026813030243,26.14190459251404,0.0 -499000,4.6790724,0.6926723,,,,,,,,,,,,,, -499100,4.399652,0.57976586,,,,,,,,,,,,,, -499200,4.683993,0.64279944,,,,,,,,,,,,,, -499300,4.6237254,0.6814795,,,,,,,,,,,,,, -499400,4.949005,0.684716,,,,,,,,,,,,,, -499500,4.494584,0.62122357,,,,,,,,,,,,,, -499600,4.5042176,0.6685412,,,,,,,,,,,,,, -499700,4.336155,0.6102066,,,,,,,,,,,,,, -499800,4.261512,0.5826004,,,,,,,,,,,,,, -499900,4.6636605,0.68782085,,,,,,,,,,,,,, -500000,4.596057,0.64535356,,,,,,,,,,,,,, -500100,4.2576203,0.59268713,,,,,,,,,,,,,, -500200,4.090488,0.5696677,,,,,,,,,,,,,, -500300,4.1324134,0.54954135,,,,,,,,,,,,,, -500400,4.2536335,0.5723554,,,,,,,,,,,,,, -500446,,,0.9609972834587096,0.1472698599100113,0.7549399733543396,1.0556048154830933,50000.0,0.6303000450134277,1.8181790113449097,10000.0,169895.3501522541,175700.7484512329,169895.3501522541,5761.515566825867,26.25598978996277,0.0 -500500,4.4964967,0.60942924,,,,,,,,,,,,,, -500600,5.2828083,0.59632504,,,,,,,,,,,,,, -500700,4.9045167,0.6668265,,,,,,,,,,,,,, -500800,4.887487,0.74421585,,,,,,,,,,,,,, -500900,4.484681,0.6231587,,,,,,,,,,,,,, -501000,4.4675736,0.61183566,,,,,,,,,,,,,, -501100,4.680454,0.61143756,,,,,,,,,,,,,, -501200,4.5568385,0.57130545,,,,,,,,,,,,,, -501300,4.6121573,0.6617633,,,,,,,,,,,,,, -501400,4.4659753,0.64530116,,,,,,,,,,,,,, -501500,4.573818,0.6663013,,,,,,,,,,,,,, -501600,4.676294,0.57970226,,,,,,,,,,,,,, -501700,4.380263,0.588398,,,,,,,,,,,,,, -501800,4.7853084,0.67071694,,,,,,,,,,,,,, -501900,4.4336624,0.65194315,,,,,,,,,,,,,, -501949,,,0.962292730808258,0.1432228684425354,0.7550199627876282,1.055835485458374,50000.0,0.6306000351905823,1.8182381391525269,10000.0,170405.28198862076,176227.55142569542,170405.28198862076,5778.214676856995,26.37480115890503,0.0 -502000,4.1441493,0.59370637,,,,,,,,,,,,,, -502100,4.7215652,0.5987952,,,,,,,,,,,,,, -502200,5.056082,0.6950995,,,,,,,,,,,,,, -502300,4.6117473,0.6642615,,,,,,,,,,,,,, -502400,4.7897606,0.6745109,,,,,,,,,,,,,, -502500,4.4533987,0.5778052,,,,,,,,,,,,,, -502600,4.199607,0.63682777,,,,,,,,,,,,,, -502700,4.8605275,0.59874064,,,,,,,,,,,,,, -502800,4.478031,0.6167528,,,,,,,,,,,,,, -502900,4.426761,0.64971215,,,,,,,,,,,,,, -503000,4.4393663,0.60757226,,,,,,,,,,,,,, -503100,4.7021484,0.61049694,,,,,,,,,,,,,, -503200,4.644081,0.65888554,,,,,,,,,,,,,, -503300,4.2260184,0.5979519,,,,,,,,,,,,,, -503400,4.5761456,0.6086774,,,,,,,,,,,,,, -503452,,,0.961575210094452,0.1445035338401794,0.7547799944877625,1.0563068389892578,50000.0,0.6301000118255615,1.8219772577285769,10000.0,170915.20975255966,176754.32218694687,170915.20975255966,5794.882792234421,26.4967200756073,0.0 -503500,4.5777135,0.6140385,,,,,,,,,,,,,, -503600,4.6631575,0.5615638,,,,,,,,,,,,,, -503700,4.556829,0.6377931,,,,,,,,,,,,,, -503800,4.7617273,0.6197357,,,,,,,,,,,,,, -503900,4.518821,0.6392263,,,,,,,,,,,,,, -504000,4.4092484,0.60110193,,,,,,,,,,,,,, -504100,5.1487665,0.67750895,,,,,,,,,,,,,, -504200,4.708799,0.65662366,,,,,,,,,,,,,, -504300,4.5654593,0.6850042,,,,,,,,,,,,,, -504400,4.740235,0.5791964,,,,,,,,,,,,,, -504500,4.499914,0.6055557,,,,,,,,,,,,,, -504600,4.5220933,0.68652344,,,,,,,,,,,,,, -504700,4.292994,0.5757204,,,,,,,,,,,,,, -504800,4.3693914,0.6224869,,,,,,,,,,,,,, -504900,4.5375323,0.65301776,,,,,,,,,,,,,, -504956,,,0.9600605964660645,0.1470568627119064,0.7549799680709839,1.0555027723312378,50000.0,0.6304000020027161,1.818542718887329,10000.0,171425.38033485413,177281.23242735863,171425.38033485413,5811.456128358841,26.61021900177002,0.0 -505000,4.827129,0.70211357,,,,,,,,,,,,,, -505100,4.75814,0.7580578,,,,,,,,,,,,,, -505200,4.2285852,0.6348033,,,,,,,,,,,,,, -505300,4.0361767,0.5599041,,,,,,,,,,,,,, -505400,4.7271113,0.62385017,,,,,,,,,,,,,, -505500,4.4754453,0.596217,,,,,,,,,,,,,, -505600,4.747056,0.6788207,,,,,,,,,,,,,, -505700,4.330714,0.599573,,,,,,,,,,,,,, -505800,4.94492,0.6620356,,,,,,,,,,,,,, -505900,4.5878086,0.63667256,,,,,,,,,,,,,, -506000,4.4002676,0.66306657,,,,,,,,,,,,,, -506100,4.575571,0.5847426,,,,,,,,,,,,,, -506200,4.882519,0.5645524,,,,,,,,,,,,,, -506300,4.407087,0.5658456,,,,,,,,,,,,,, -506400,4.666516,0.614201,,,,,,,,,,,,,, -506458,,,0.9614556431770324,0.1461593806743621,0.7551400065422058,1.0542012453079224,50000.0,0.6303000450134277,1.81696355342865,10000.0,171935.21598100662,177807.80382156372,171935.21598100662,5828.022411584854,26.72679662704468,0.0 -506500,4.54095,0.6426971,,,,,,,,,,,,,, -506600,4.53815,0.5930059,,,,,,,,,,,,,, -506700,5.048117,0.6298811,,,,,,,,,,,,,, -506800,4.1250134,0.5519873,,,,,,,,,,,,,, -506900,4.680603,0.56133246,,,,,,,,,,,,,, -507000,5.0062222,0.6446349,,,,,,,,,,,,,, -507100,4.188407,0.59389335,,,,,,,,,,,,,, -507200,4.5830407,0.66883045,,,,,,,,,,,,,, -507300,4.2980204,0.6098767,,,,,,,,,,,,,, -507400,4.2123065,0.6690136,,,,,,,,,,,,,, -507500,4.1827846,0.6282303,,,,,,,,,,,,,, -507600,4.665342,0.62780285,,,,,,,,,,,,,, -507700,5.2500215,0.66224724,,,,,,,,,,,,,, -507800,4.5268483,0.65295684,,,,,,,,,,,,,, -507900,4.2461014,0.5920252,,,,,,,,,,,,,, -507962,,,0.9599609375,0.1463172882795334,0.7548799514770508,1.0554317235946655,50000.0,0.6307000517845154,1.8180532455444336,10000.0,172445.2868115902,178334.55427193642,172445.2868115902,5844.531116485596,26.84491229057312,0.0 -508000,4.8497972,0.63377,,,,,,,,,,,,,, -508100,4.320603,0.60705465,,,,,,,,,,,,,, -508200,5.1705794,0.6607329,,,,,,,,,,,,,, -508300,4.6917195,0.6409694,,,,,,,,,,,,,, -508400,4.2906137,0.6088624,,,,,,,,,,,,,, -508500,4.5825653,0.65163684,,,,,,,,,,,,,, -508600,4.661929,0.58617985,,,,,,,,,,,,,, -508700,5.052545,0.57559425,,,,,,,,,,,,,, -508800,4.398309,0.65414035,,,,,,,,,,,,,, -508900,4.5690265,0.6449594,,,,,,,,,,,,,, -509000,4.209662,0.5364579,,,,,,,,,,,,,, -509100,4.4461436,0.6509854,,,,,,,,,,,,,, -509200,4.7192087,0.6668014,,,,,,,,,,,,,, -509300,4.42708,0.5740965,,,,,,,,,,,,,, -509400,4.4500666,0.5698984,,,,,,,,,,,,,, -509465,,,0.9617546200752258,0.1442474871873855,0.7548799514770508,1.0556154251098633,50000.0,0.631600022315979,1.81832218170166,10000.0,172955.43111920357,178861.50686740875,172955.43111920357,5861.175798654556,26.954604864120483,0.0 -509500,4.3111095,0.6392863,,,,,,,,,,,,,, -509600,4.0687637,0.56119764,,,,,,,,,,,,,, -509700,4.7237396,0.73294705,,,,,,,,,,,,,, -509800,4.7860193,0.6661227,,,,,,,,,,,,,, -509900,4.548778,0.5609312,,,,,,,,,,,,,, -510000,4.5384912,0.65533376,,,,,,,,,,,,,, -510100,4.755947,0.6377466,,,,,,,,,,,,,, -510200,4.2667317,0.6068632,,,,,,,,,,,,,, -510300,4.1166863,0.54676086,,,,,,,,,,,,,, -510400,4.5005217,0.76465374,,,,,,,,,,,,,, -510500,4.625408,0.64831334,,,,,,,,,,,,,, -510600,4.328781,0.5491654,,,,,,,,,,,,,, -510700,4.4197283,0.53098136,,,,,,,,,,,,,, -510800,4.7572713,0.6190144,,,,,,,,,,,,,, -510900,4.673861,0.6477586,,,,,,,,,,,,,, -510969,,,0.961156725883484,0.1486804485321045,0.7549200057983398,1.055416226387024,50000.0,0.6302000284194946,1.8184680938720703,10000.0,173465.5064251423,179388.3738567829,173465.5064251423,5877.799740791321,27.069482803344727,0.0 -511000,4.315184,0.590395,,,,,,,,,,,,,, -511100,4.4450827,0.61576295,,,,,,,,,,,,,, -511200,4.190933,0.5725084,,,,,,,,,,,,,, -511300,4.5363774,0.6201043,,,,,,,,,,,,,, -511400,4.6583815,0.63462853,,,,,,,,,,,,,, -511500,4.4731493,0.63868254,,,,,,,,,,,,,, -511600,4.3611107,0.6404533,,,,,,,,,,,,,, -511700,4.4519105,0.7008449,,,,,,,,,,,,,, -511800,4.750198,0.6313172,,,,,,,,,,,,,, -511900,4.7218146,0.62229556,,,,,,,,,,,,,, -512000,4.4616346,0.6320555,,,,,,,,,,,,,, -512100,4.2374234,0.5974312,,,,,,,,,,,,,, -512200,4.1423764,0.56368953,,,,,,,,,,,,,, -512300,4.830909,0.61448467,,,,,,,,,,,,,, -512400,4.503808,0.60794944,,,,,,,,,,,,,, -512471,,,0.958804965019226,0.1498315632343292,0.7547799944877625,1.055154800415039,50000.0,0.6303000450134277,1.817718863487244,10000.0,173975.3585202694,179915.33734679222,173975.3585202694,5894.746170282364,27.18165707588196,0.0 -512500,4.5998535,0.66203606,,,,,,,,,,,,,, -512600,4.6174016,0.64339054,,,,,,,,,,,,,, -512700,4.447752,0.5701342,,,,,,,,,,,,,, -512800,4.5535965,0.66073835,,,,,,,,,,,,,, -512900,4.8980546,0.69573414,,,,,,,,,,,,,, -513000,4.9509597,0.69297445,,,,,,,,,,,,,, -513100,4.806123,0.669719,,,,,,,,,,,,,, -513200,4.683614,0.65631807,,,,,,,,,,,,,, -513300,4.176218,0.5251192,,,,,,,,,,,,,, -513400,4.3769374,0.5658623,,,,,,,,,,,,,, -513500,4.727535,0.6483044,,,,,,,,,,,,,, -513600,4.445899,0.65498924,,,,,,,,,,,,,, -513700,5.004423,0.6667225,,,,,,,,,,,,,, -513800,4.633975,0.652795,,,,,,,,,,,,,, -513900,4.4588723,0.6531866,,,,,,,,,,,,,, -513975,,,0.9610969424247742,0.146408274769783,0.7549600005149841,1.0548641681671145,50000.0,0.6304000020027161,1.817402482032776,10000.0,174485.4700343609,180442.1136064529,174485.4700343609,5911.234726428986,27.304914474487305,0.0 -514000,4.8156,0.71501833,,,,,,,,,,,,,, -514100,4.1830187,0.60206234,,,,,,,,,,,,,, -514200,4.5526614,0.59833455,,,,,,,,,,,,,, -514300,4.357539,0.601491,,,,,,,,,,,,,, -514400,4.132051,0.5427356,,,,,,,,,,,,,, -514500,3.9734588,0.57788414,,,,,,,,,,,,,, -514600,4.4121585,0.6334554,,,,,,,,,,,,,, -514700,4.854242,0.68771225,,,,,,,,,,,,,, -514800,4.288714,0.68628764,,,,,,,,,,,,,, -514900,4.470762,0.55265915,,,,,,,,,,,,,, -515000,4.653054,0.7016209,,,,,,,,,,,,,, -515100,4.467099,0.61635315,,,,,,,,,,,,,, -515200,4.6026325,0.6310332,,,,,,,,,,,,,, -515300,4.5337143,0.6610602,,,,,,,,,,,,,, -515400,4.832489,0.64325553,,,,,,,,,,,,,, -515478,,,0.9604192972183228,0.1447800695896148,0.7547199726104736,1.055174469947815,50000.0,0.6305000185966492,1.817795157432556,10000.0,174995.56325149536,180968.88140940663,174995.56325149536,5927.747765779495,27.412547826766968,0.0 -515500,4.543212,0.62976474,,,,,,,,,,,,,, -515600,5.2371926,0.65445375,,,,,,,,,,,,,, -515700,4.975055,0.74869287,,,,,,,,,,,,,, -515800,4.5241594,0.620679,,,,,,,,,,,,,, -515900,4.339131,0.6571033,,,,,,,,,,,,,, -516000,4.488761,0.66568995,,,,,,,,,,,,,, -516100,4.5843945,0.5910343,,,,,,,,,,,,,, -516200,4.9184937,0.6155127,,,,,,,,,,,,,, -516300,4.3711557,0.6228957,,,,,,,,,,,,,, -516400,4.43622,0.5738883,,,,,,,,,,,,,, -516500,4.1320796,0.6121176,,,,,,,,,,,,,, -516600,5.054607,0.732251,,,,,,,,,,,,,, -516700,4.7730727,0.6361935,,,,,,,,,,,,,, -516800,4.4047656,0.6843181,,,,,,,,,,,,,, -516900,4.9265227,0.58260244,,,,,,,,,,,,,, -516981,,,0.9598811864852904,0.1495180130004882,0.7548799514770508,1.055079460144043,50000.0,0.6308000087738037,1.818442702293396,10000.0,175505.54929184914,181495.6579201221,175505.54929184914,5944.373743772507,27.52469778060913,0.0 -517000,4.3715568,0.6000541,,,,,,,,,,,,,, -517100,4.5197334,0.6170653,,,,,,,,,,,,,, -517200,4.5204706,0.6640902,,,,,,,,,,,,,, -517300,4.7003064,0.6627454,,,,,,,,,,,,,, -517400,4.682664,0.60798466,,,,,,,,,,,,,, -517500,4.5662923,0.64409393,,,,,,,,,,,,,, -517600,4.376966,0.6049581,,,,,,,,,,,,,, -517700,4.7210174,0.65085626,,,,,,,,,,,,,, -517800,4.791057,0.5963627,,,,,,,,,,,,,, -517900,5.0891786,0.6442401,,,,,,,,,,,,,, -518000,4.344508,0.57760704,,,,,,,,,,,,,, -518100,5.4015017,0.6650357,,,,,,,,,,,,,, -518200,4.9162107,0.59608984,,,,,,,,,,,,,, -518300,4.4038477,0.6359215,,,,,,,,,,,,,, -518400,4.388828,0.654631,,,,,,,,,,,,,, -518483,,,0.9592633843421936,0.1484085023403167,0.7550199627876282,1.0548003911972046,50000.0,0.6297000050544739,1.8194637298583984,10000.0,176015.523532629,182022.25287151337,176015.523532629,5960.828776597977,27.63849425315857,0.0 -518500,4.1532083,0.56491506,,,,,,,,,,,,,, -518600,4.50421,0.532237,,,,,,,,,,,,,, -518700,4.707314,0.5970143,,,,,,,,,,,,,, -518800,4.4606805,0.6110891,,,,,,,,,,,,,, -518900,4.579941,0.6859555,,,,,,,,,,,,,, -519000,4.287019,0.61088157,,,,,,,,,,,,,, -519100,4.9058228,0.7109442,,,,,,,,,,,,,, -519200,4.0129986,0.56045306,,,,,,,,,,,,,, -519300,4.38779,0.6053627,,,,,,,,,,,,,, -519400,4.3843307,0.52692974,,,,,,,,,,,,,, -519500,4.449896,0.5654069,,,,,,,,,,,,,, -519600,4.3739033,0.6000919,,,,,,,,,,,,,, -519700,4.8851132,0.6797745,,,,,,,,,,,,,, -519800,4.6295643,0.639726,,,,,,,,,,,,,, -519900,4.574018,0.66651994,,,,,,,,,,,,,, -519986,,,0.9608777165412904,0.1482365876436233,0.7551400065422058,1.0555709600448608,50000.0,0.631100058555603,1.8197120428085327,10000.0,176525.6594619751,182549.0837368965,176525.6594619751,5977.352998971939,27.755635499954224,0.0 -520000,4.4503956,0.57898855,,,,,,,,,,,,,, -520100,4.264205,0.6154176,,,,,,,,,,,,,, -520200,4.2872553,0.6281428,,,,,,,,,,,,,, -520300,4.8317027,0.6257049,,,,,,,,,,,,,, -520400,4.762196,0.6631291,,,,,,,,,,,,,, -520500,4.74777,0.6196022,,,,,,,,,,,,,, -520600,4.7789435,0.68924195,,,,,,,,,,,,,, -520700,4.655876,0.63057977,,,,,,,,,,,,,, -520800,4.699562,0.6304444,,,,,,,,,,,,,, -520900,4.9256897,0.6275538,,,,,,,,,,,,,, -521000,4.916604,0.6570017,,,,,,,,,,,,,, -521100,5.092897,0.6900054,,,,,,,,,,,,,, -521200,4.3462563,0.59787446,,,,,,,,,,,,,, -521300,4.3877935,0.61424,,,,,,,,,,,,,, -521400,4.451271,0.6459522,,,,,,,,,,,,,, -521488,,,0.9598413109779358,0.1506748795509338,0.7550599575042725,1.055654764175415,50000.0,0.6308000087738037,1.818725824356079,10000.0,177035.51152539253,183075.60394525528,177035.51152539253,5993.85312795639,27.870992183685303,0.0 -521500,5.010957,0.65498966,,,,,,,,,,,,,, -521600,4.635567,0.6297869,,,,,,,,,,,,,, -521700,4.486823,0.5776698,,,,,,,,,,,,,, -521800,4.2223535,0.56062555,,,,,,,,,,,,,, -521900,4.7336354,0.67612773,,,,,,,,,,,,,, -522000,4.9837465,0.7267977,,,,,,,,,,,,,, -522100,4.703833,0.60756415,,,,,,,,,,,,,, -522200,4.6432714,0.55981576,,,,,,,,,,,,,, -522300,4.452628,0.61495304,,,,,,,,,,,,,, -522400,4.2916336,0.67340285,,,,,,,,,,,,,, -522500,4.6681294,0.7045045,,,,,,,,,,,,,, -522600,4.445432,0.6814765,,,,,,,,,,,,,, -522700,4.4781113,0.6389231,,,,,,,,,,,,,, -522800,4.508353,0.58237684,,,,,,,,,,,,,, -522900,4.664664,0.6532373,,,,,,,,,,,,,, -522990,,,0.9611168503761292,0.1443382948637008,0.7545599937438965,1.0562598705291748,50000.0,0.6309000253677368,1.819977045059204,10000.0,177545.4000184536,183602.2939076424,177545.4000184536,6010.478732824326,27.994094133377075,0.0 -523000,4.8677483,0.6880502,,,,,,,,,,,,,, -523100,4.96808,0.69201857,,,,,,,,,,,,,, -523200,4.7045507,0.6277213,,,,,,,,,,,,,, -523300,4.9007974,0.63771826,,,,,,,,,,,,,, -523400,4.2657332,0.6258332,,,,,,,,,,,,,, -523500,4.4182076,0.60362697,,,,,,,,,,,,,, -523600,4.3626204,0.60798395,,,,,,,,,,,,,, -523700,4.760715,0.66215,,,,,,,,,,,,,, -523800,4.5288134,0.63348544,,,,,,,,,,,,,, -523900,4.437444,0.5776557,,,,,,,,,,,,,, -524000,4.817259,0.64971703,,,,,,,,,,,,,, -524100,4.3346925,0.59025204,,,,,,,,,,,,,, -524200,4.179399,0.63782567,,,,,,,,,,,,,, -524300,4.4143324,0.59544927,,,,,,,,,,,,,, -524400,4.260862,0.5639889,,,,,,,,,,,,,, -524492,,,0.9600406289100648,0.1463096588850021,0.7547599673271179,1.0564887523651123,50000.0,0.6304000020027161,1.8197730779647827,10000.0,178055.3327858448,184128.84509038925,178055.3327858448,6026.929802417755,28.10876989364624,0.0 -524500,4.3383284,0.64077425,,,,,,,,,,,,,, -524600,4.9248953,0.70330226,,,,,,,,,,,,,, -524700,4.415283,0.5889543,,,,,,,,,,,,,, -524800,4.798078,0.66135436,,,,,,,,,,,,,, -524900,4.5313005,0.619183,,,,,,,,,,,,,, -525000,4.867356,0.60264784,,,,,,,,,,,,,, -525100,4.2640567,0.5850197,,,,,,,,,,,,,, -525200,4.6925488,0.61554056,,,,,,,,,,,,,, -525300,4.779125,0.60701895,,,,,,,,,,,,,, -525400,5.507847,0.5967275,,,,,,,,,,,,,, -525500,4.5597916,0.6491977,,,,,,,,,,,,,, -525600,4.306492,0.5522065,,,,,,,,,,,,,, -525700,4.2068095,0.6067659,,,,,,,,,,,,,, -525800,4.5581956,0.582379,,,,,,,,,,,,,, -525900,4.589929,0.6617362,,,,,,,,,,,,,, -525994,,,0.9606385231018066,0.1450351178646087,0.7550599575042725,1.0553193092346191,50000.0,0.6310000419616699,1.8185182809829712,10000.0,178565.2026667595,184655.56948399544,178565.2026667595,6043.616086244583,28.223960638046265,0.0 -526000,4.6082873,0.6613143,,,,,,,,,,,,,, -526100,5.1476307,0.6876237,,,,,,,,,,,,,, -526200,4.306858,0.59989506,,,,,,,,,,,,,, -526300,4.768186,0.63166666,,,,,,,,,,,,,, -526400,4.361756,0.6610451,,,,,,,,,,,,,, -526500,4.9217534,0.67741203,,,,,,,,,,,,,, -526600,4.383229,0.63334274,,,,,,,,,,,,,, -526700,4.3471203,0.5459984,,,,,,,,,,,,,, -526800,4.4698863,0.61180395,,,,,,,,,,,,,, -526900,4.5287976,0.61471164,,,,,,,,,,,,,, -527000,4.4632125,0.68475187,,,,,,,,,,,,,, -527100,4.530502,0.5976051,,,,,,,,,,,,,, -527200,4.5375714,0.5945673,,,,,,,,,,,,,, -527300,4.8175073,0.61147726,,,,,,,,,,,,,, -527400,4.5719576,0.6088058,,,,,,,,,,,,,, -527495,,,0.9610371589660645,0.1455485224723816,0.7551400065422058,1.0551986694335938,50000.0,0.6303000450134277,1.819907665252685,10000.0,179075.05401158333,185182.32403969765,179075.05401158333,6060.282440185547,28.40710473060608,0.0 -527500,4.2836847,0.6328106,,,,,,,,,,,,,, -527600,4.860678,0.67428166,,,,,,,,,,,,,, -527700,4.7419,0.5974235,,,,,,,,,,,,,, -527800,4.881026,0.7180585,,,,,,,,,,,,,, -527900,4.5949426,0.6280964,,,,,,,,,,,,,, -528000,4.803759,0.65081364,,,,,,,,,,,,,, -528100,4.363042,0.61494994,,,,,,,,,,,,,, -528200,5.674734,0.76210874,,,,,,,,,,,,,, -528300,4.275011,0.57727677,,,,,,,,,,,,,, -528400,4.540348,0.6392735,,,,,,,,,,,,,, -528500,4.3366785,0.6998389,,,,,,,,,,,,,, -528600,4.417611,0.6048517,,,,,,,,,,,,,, -528700,4.516427,0.67348456,,,,,,,,,,,,,, -528800,4.500467,0.65487504,,,,,,,,,,,,,, -528900,4.9772816,0.6541681,,,,,,,,,,,,,, -528998,,,0.961136758327484,0.1468872874975204,0.7548399567604065,1.0558708906173706,50000.0,0.6308000087738037,1.8188236951828003,10000.0,179584.9684138298,185709.15247774124,179584.9684138298,6077.02671456337,28.523298025131226,0.0 -529000,4.6664023,0.66061866,,,,,,,,,,,,,, -529100,5.1529007,0.64350927,,,,,,,,,,,,,, -529200,4.8522453,0.6106502,,,,,,,,,,,,,, -529300,4.639767,0.6639272,,,,,,,,,,,,,, -529400,4.922439,0.61730385,,,,,,,,,,,,,, -529500,4.7696834,0.6310331,,,,,,,,,,,,,, -529600,4.8618307,0.7275864,,,,,,,,,,,,,, -529700,4.631624,0.63931525,,,,,,,,,,,,,, -529800,4.765165,0.63453394,,,,,,,,,,,,,, -529900,4.6783085,0.6649608,,,,,,,,,,,,,, -530000,4.2018414,0.57035536,,,,,,,,,,,,,, -530100,4.423701,0.6252802,,,,,,,,,,,,,, -530200,4.933915,0.5958886,,,,,,,,,,,,,, -530300,5.0248833,0.66322744,,,,,,,,,,,,,, -530400,4.491268,0.67935133,,,,,,,,,,,,,, -530500,4.5171914,0.72569054,,,,,,,,,,,,,, -530501,,,0.9602798223495485,0.1466607749462127,0.7544599771499634,1.0554156303405762,50000.0,0.6310000419616699,1.819552183151245,10000.0,180095.15888261795,186236.39528346065,180095.15888261795,6093.907920360565,28.641568899154663,0.0 -530600,4.6803083,0.6797192,,,,,,,,,,,,,, -530700,5.276914,0.6266775,,,,,,,,,,,,,, -530800,4.577754,0.6592208,,,,,,,,,,,,,, -530900,4.629185,0.67330784,,,,,,,,,,,,,, -531000,4.668012,0.63534933,,,,,,,,,,,,,, -531100,4.5628366,0.5990725,,,,,,,,,,,,,, -531200,4.7377915,0.5995815,,,,,,,,,,,,,, -531300,4.446206,0.622021,,,,,,,,,,,,,, -531400,4.724215,0.5911235,,,,,,,,,,,,,, -531500,4.859039,0.66193414,,,,,,,,,,,,,, -531600,4.7770896,0.59679765,,,,,,,,,,,,,, -531700,4.2977796,0.6204104,,,,,,,,,,,,,, -531800,4.814718,0.6251645,,,,,,,,,,,,,, -531900,4.1990027,0.5731678,,,,,,,,,,,,,, -532000,4.460879,0.6139379,,,,,,,,,,,,,, -532004,,,0.9609972834587096,0.1430176347494125,0.7551400065422058,1.0556042194366455,50000.0,0.6317000389099121,1.8177498579025269,10000.0,180605.15130925176,186763.15968847275,180605.15130925176,6110.511300086975,28.757806062698364,0.0 -532100,4.211201,0.6066969,,,,,,,,,,,,,, -532200,4.3604546,0.6697958,,,,,,,,,,,,,, -532300,4.086059,0.5283469,,,,,,,,,,,,,, -532400,4.8079534,0.660307,,,,,,,,,,,,,, -532500,4.8505545,0.59051204,,,,,,,,,,,,,, -532600,4.3522196,0.5847813,,,,,,,,,,,,,, -532700,4.6490936,0.674768,,,,,,,,,,,,,, -532800,4.5931845,0.65621316,,,,,,,,,,,,,, -532900,4.315092,0.6176702,,,,,,,,,,,,,, -533000,4.4277396,0.6308176,,,,,,,,,,,,,, -533100,4.308998,0.6248818,,,,,,,,,,,,,, -533200,4.797354,0.6194431,,,,,,,,,,,,,, -533300,4.7474375,0.61701393,,,,,,,,,,,,,, -533400,4.3910346,0.68449587,,,,,,,,,,,,,, -533500,4.5719953,0.6171647,,,,,,,,,,,,,, -533507,,,0.9610969424247742,0.1459327936172485,0.7546799778938293,1.0548255443572998,50000.0,0.6303000450134277,1.8190414905548096,10000.0,181115.2619752884,187290.04085493088,181115.2619752884,6127.112663269043,28.87390160560608,0.0 -533600,4.2631755,0.5795686,,,,,,,,,,,,,, -533700,4.804306,0.648358,,,,,,,,,,,,,, -533800,4.7845044,0.67741555,,,,,,,,,,,,,, -533900,4.6388674,0.65913737,,,,,,,,,,,,,, -534000,4.4816017,0.62195057,,,,,,,,,,,,,, -534100,4.493187,0.6382552,,,,,,,,,,,,,, -534200,4.561385,0.66424763,,,,,,,,,,,,,, -534300,4.2575054,0.58874357,,,,,,,,,,,,,, -534400,4.2591615,0.5389225,,,,,,,,,,,,,, -534500,4.7252483,0.65764076,,,,,,,,,,,,,, -534600,4.5834227,0.5533763,,,,,,,,,,,,,, -534700,4.4697824,0.58639824,,,,,,,,,,,,,, -534800,4.500234,0.5818183,,,,,,,,,,,,,, -534900,4.544119,0.647708,,,,,,,,,,,,,, -535000,5.284553,0.661695,,,,,,,,,,,,,, -535010,,,0.9614556431770324,0.1439944505691528,0.7545799612998962,1.0565119981765747,50000.0,0.6301000118255615,1.820230960845948,10000.0,181625.2687883377,187817.1193869114,181625.2687883377,6144.01261806488,28.99262237548828,0.0 -535100,4.8511515,0.64951456,,,,,,,,,,,,,, -535200,4.540268,0.61147857,,,,,,,,,,,,,, -535300,4.308428,0.61541545,,,,,,,,,,,,,, -535400,4.614857,0.5774359,,,,,,,,,,,,,, -535500,4.885736,0.66091645,,,,,,,,,,,,,, -535600,4.4000483,0.6123465,,,,,,,,,,,,,, -535700,4.259109,0.5661628,,,,,,,,,,,,,, -535800,4.088526,0.543506,,,,,,,,,,,,,, -535900,4.6000085,0.54422235,,,,,,,,,,,,,, -536000,4.6370125,0.6352881,,,,,,,,,,,,,, -536100,4.8465447,0.6375629,,,,,,,,,,,,,, -536200,5.0268903,0.7071171,,,,,,,,,,,,,, -536300,4.267108,0.57852364,,,,,,,,,,,,,, -536400,4.5175095,0.6519693,,,,,,,,,,,,,, -536500,4.9497967,0.70280296,,,,,,,,,,,,,, -536513,,,0.9588249325752258,0.1505401283502578,0.7545199990272522,1.0554897785186768,50000.0,0.6297000050544739,1.8187036514282229,10000.0,182135.4176516533,188344.07443284988,182135.4176516533,6160.654649019241,29.10392904281616,0.0 -536600,4.878097,0.6124823,,,,,,,,,,,,,, -536700,4.0577927,0.5500787,,,,,,,,,,,,,, -536800,4.6141205,0.64303565,,,,,,,,,,,,,, -536900,4.401906,0.5849081,,,,,,,,,,,,,, -537000,4.6592417,0.6694923,,,,,,,,,,,,,, -537100,4.845806,0.6190513,,,,,,,,,,,,,, -537200,4.3168473,0.59219056,,,,,,,,,,,,,, -537300,4.369489,0.6159914,,,,,,,,,,,,,, -537400,4.8375697,0.70053035,,,,,,,,,,,,,, -537500,5.0018053,0.6608301,,,,,,,,,,,,,, -537600,4.287075,0.65488374,,,,,,,,,,,,,, -537700,4.3884544,0.5862178,,,,,,,,,,,,,, -537800,4.8268733,0.6368869,,,,,,,,,,,,,, -537900,4.731173,0.6056694,,,,,,,,,,,,,, -538000,4.996905,0.6519614,,,,,,,,,,,,,, -538016,,,0.9614357352256776,0.1456693708896637,0.754539966583252,1.0547114610671997,50000.0,0.6300000548362732,1.8173961639404297,10000.0,182645.48416352272,188870.8207633496,182645.48416352272,6177.164265871048,29.22213888168335,0.0 -538100,4.1232567,0.55176663,,,,,,,,,,,,,, -538200,4.5619392,0.6140081,,,,,,,,,,,,,, -538300,4.6433454,0.6319561,,,,,,,,,,,,,, -538400,4.4392457,0.62280464,,,,,,,,,,,,,, -538500,4.644229,0.5948713,,,,,,,,,,,,,, -538600,4.153245,0.5739328,,,,,,,,,,,,,, -538700,4.1970425,0.57329214,,,,,,,,,,,,,, -538800,4.2561398,0.57034194,,,,,,,,,,,,,, -538900,4.1822195,0.58850884,,,,,,,,,,,,,, -539000,4.6554956,0.6511596,,,,,,,,,,,,,, -539100,5.1391206,0.58926874,,,,,,,,,,,,,, -539200,4.236468,0.59957683,,,,,,,,,,,,,, -539300,4.6244636,0.6083717,,,,,,,,,,,,,, -539400,4.258046,0.57392097,,,,,,,,,,,,,, -539500,4.540414,0.5600626,,,,,,,,,,,,,, -539519,,,0.9617546200752258,0.1445423066616058,0.7547599673271179,1.055425047874451,50000.0,0.6303000450134277,1.819313764572144,10000.0,183155.6240270137,189398.0522556305,183155.6240270137,6194.077189683914,29.346582889556885,0.0 -539600,4.4986153,0.6532831,,,,,,,,,,,,,, -539700,4.4650497,0.62631315,,,,,,,,,,,,,, -539800,4.462378,0.616624,,,,,,,,,,,,,, -539900,4.4003334,0.5971899,,,,,,,,,,,,,, -540000,4.3829875,0.63162535,,,,,,,,,,,,,, -540100,4.2017865,0.6407765,,,,,,,,,,,,,, -540200,4.611418,0.6146001,,,,,,,,,,,,,, -540300,4.407629,0.62054646,,,,,,,,,,,,,, -540400,4.2388635,0.58505875,,,,,,,,,,,,,, -540500,5.0582085,0.6666151,,,,,,,,,,,,,, -540600,4.479547,0.6111245,,,,,,,,,,,,,, -540700,4.758965,0.6142264,,,,,,,,,,,,,, -540800,4.5662336,0.6651482,,,,,,,,,,,,,, -540900,4.537379,0.5871961,,,,,,,,,,,,,, -541000,5.188475,0.67705476,,,,,,,,,,,,,, -541023,,,0.961535394191742,0.1437825411558151,0.7549399733543396,1.0555442571640017,50000.0,0.6304000020027161,1.8197301626205444,10000.0,183665.75982761383,189925.0090081692,183665.75982761383,6210.7504341602325,29.442301750183105,0.0 -541100,4.911186,0.6056456,,,,,,,,,,,,,, -541200,4.6665688,0.65918434,,,,,,,,,,,,,, -541300,4.9624395,0.64355856,,,,,,,,,,,,,, -541400,4.294127,0.5314223,,,,,,,,,,,,,, -541500,4.609609,0.6363442,,,,,,,,,,,,,, -541600,5.268932,0.67316246,,,,,,,,,,,,,, -541700,4.652616,0.5700268,,,,,,,,,,,,,, -541800,4.7032413,0.63184494,,,,,,,,,,,,,, -541900,4.3628697,0.6408864,,,,,,,,,,,,,, -542000,5.2024484,0.6769693,,,,,,,,,,,,,, -542100,4.2614427,0.5273611,,,,,,,,,,,,,, -542200,4.648862,0.65199417,,,,,,,,,,,,,, -542300,4.716884,0.71392995,,,,,,,,,,,,,, -542400,4.1948314,0.55353785,,,,,,,,,,,,,, -542500,4.712032,0.5554248,,,,,,,,,,,,,, -542525,,,0.960957407951355,0.146311342716217,0.7544199824333191,1.055719256401062,50000.0,0.6294000148773193,1.819315910339356,10000.0,184175.8125114441,190451.8363242149,184175.8125114441,6227.353454113007,29.56045937538147,0.0 -542600,4.7402105,0.68146783,,,,,,,,,,,,,, -542700,4.480473,0.57285905,,,,,,,,,,,,,, -542800,5.1000347,0.69324684,,,,,,,,,,,,,, -542900,4.824053,0.6030555,,,,,,,,,,,,,, -543000,4.4023557,0.5771675,,,,,,,,,,,,,, -543100,4.564217,0.6583625,,,,,,,,,,,,,, -543200,4.5339546,0.6465708,,,,,,,,,,,,,, -543300,4.520083,0.57138956,,,,,,,,,,,,,, -543400,4.442671,0.6368463,,,,,,,,,,,,,, -543500,4.3187103,0.6276887,,,,,,,,,,,,,, -543600,4.0563445,0.5789907,,,,,,,,,,,,,, -543700,4.700508,0.71620315,,,,,,,,,,,,,, -543800,4.203603,0.5670108,,,,,,,,,,,,,, -543900,4.909785,0.6196097,,,,,,,,,,,,,, -544000,4.7240233,0.5956503,,,,,,,,,,,,,, -544028,,,0.9606584906578064,0.1456425189971923,0.7549799680709839,1.0556423664093018,50000.0,0.6305000185966492,1.817673921585083,10000.0,184685.7376544476,190978.5760681629,184685.7376544476,6243.999848365784,29.67587733268737,0.0 -544100,5.501075,0.6523803,,,,,,,,,,,,,, -544200,4.7961526,0.6253257,,,,,,,,,,,,,, -544300,4.377124,0.5984012,,,,,,,,,,,,,, -544400,5.1889586,0.74724954,,,,,,,,,,,,,, -544500,4.5189514,0.6205362,,,,,,,,,,,,,, -544600,4.8066974,0.6329498,,,,,,,,,,,,,, -544700,4.599152,0.6053519,,,,,,,,,,,,,, -544800,4.6344967,0.60168767,,,,,,,,,,,,,, -544900,4.458311,0.6264128,,,,,,,,,,,,,, -545000,4.5948844,0.6623603,,,,,,,,,,,,,, -545100,4.2545853,0.604174,,,,,,,,,,,,,, -545200,4.6742945,0.6257402,,,,,,,,,,,,,, -545300,4.753695,0.6326378,,,,,,,,,,,,,, -545400,4.424812,0.5398006,,,,,,,,,,,,,, -545500,4.732206,0.6854745,,,,,,,,,,,,,, -545531,,,0.9602399468421936,0.1466413885354995,0.7547599673271179,1.0558973550796509,50000.0,0.6301000118255615,1.818894028663636,10000.0,185195.77567243576,191505.38779973984,185195.77567243576,6260.603581190109,29.792681217193604,0.0 -545600,4.562023,0.6174666,,,,,,,,,,,,,, -545700,4.690965,0.62562984,,,,,,,,,,,,,, -545800,4.4502864,0.5783733,,,,,,,,,,,,,, -545900,4.471319,0.6249308,,,,,,,,,,,,,, -546000,4.5708356,0.6140064,,,,,,,,,,,,,, -546100,4.504136,0.60070825,,,,,,,,,,,,,, -546200,4.688202,0.5193545,,,,,,,,,,,,,, -546300,4.441827,0.5992525,,,,,,,,,,,,,, -546400,4.291339,0.573733,,,,,,,,,,,,,, -546500,4.596846,0.66368604,,,,,,,,,,,,,, -546600,4.164474,0.58782023,,,,,,,,,,,,,, -546700,4.8604097,0.5930958,,,,,,,,,,,,,, -546800,4.9413047,0.64026654,,,,,,,,,,,,,, -546900,4.454998,0.6412055,,,,,,,,,,,,,, -547000,4.2352066,0.6099461,,,,,,,,,,,,,, -547034,,,0.9611965417861938,0.1448834687471389,0.7548399567604065,1.0554053783416748,50000.0,0.6303000450134277,1.819093585014344,10000.0,185705.8674867153,192032.17276000977,185705.8674867153,6277.120337963104,29.91552424430847,0.0 -547100,4.8330812,0.65806425,,,,,,,,,,,,,, -547200,4.49576,0.6100534,,,,,,,,,,,,,, -547300,4.578561,0.6247236,,,,,,,,,,,,,, -547400,4.4147305,0.62387955,,,,,,,,,,,,,, -547500,4.51133,0.6260363,,,,,,,,,,,,,, -547600,4.9084435,0.6496868,,,,,,,,,,,,,, -547700,4.9592047,0.60996765,,,,,,,,,,,,,, -547800,4.567938,0.5926888,,,,,,,,,,,,,, -547900,4.408443,0.5966587,,,,,,,,,,,,,, -548000,4.7859945,0.57454896,,,,,,,,,,,,,, -548100,4.018507,0.46226838,,,,,,,,,,,,,, -548200,4.1317353,0.6351558,,,,,,,,,,,,,, -548300,4.6755986,0.628252,,,,,,,,,,,,,, -548400,4.5221453,0.566878,,,,,,,,,,,,,, -548500,4.3383713,0.64708525,,,,,,,,,,,,,, -548537,,,0.9606385231018066,0.1475456207990646,0.7551599740982056,1.054714322090149,50000.0,0.6305000185966492,1.818883299827576,10000.0,186216.0143876076,192558.9681181908,186216.0143876076,6293.597116947174,30.034253358840942,0.0 -548600,4.570659,0.6078354,,,,,,,,,,,,,, -548700,4.716041,0.6818444,,,,,,,,,,,,,, -548800,4.5509405,0.66261363,,,,,,,,,,,,,, -548900,4.5876493,0.55990875,,,,,,,,,,,,,, -549000,4.31672,0.58847666,,,,,,,,,,,,,, -549100,4.7455683,0.6768292,,,,,,,,,,,,,, -549200,4.2753673,0.5732441,,,,,,,,,,,,,, -549300,4.728214,0.57005364,,,,,,,,,,,,,, -549400,4.4642735,0.6257021,,,,,,,,,,,,,, -549500,4.591401,0.5789798,,,,,,,,,,,,,, -549600,5.240436,0.66382056,,,,,,,,,,,,,, -549700,5.0268855,0.6716682,,,,,,,,,,,,,, -549800,4.864651,0.698528,,,,,,,,,,,,,, -549900,4.3816447,0.59829456,,,,,,,,,,,,,, -550000,5.014962,0.64362156,,,,,,,,,,,,,, -550040,,,0.9608976244926452,0.1470227688550949,0.7550999522209167,1.056078553199768,50000.0,0.6306000351905823,1.8195624351501465,10000.0,186725.95565152168,193085.6027336121,186725.95565152168,6310.110383272171,30.16148805618286,0.0 -550100,4.6739216,0.65765357,,,,,,,,,,,,,, -550200,4.327538,0.56048894,,,,,,,,,,,,,, -550300,4.74298,0.57146585,,,,,,,,,,,,,, -550400,4.63101,0.6244359,,,,,,,,,,,,,, -550500,4.569412,0.60828775,,,,,,,,,,,,,, -550600,4.66448,0.62292236,,,,,,,,,,,,,, -550700,4.486903,0.60375625,,,,,,,,,,,,,, -550800,4.557655,0.5828641,,,,,,,,,,,,,, -550900,4.3147316,0.60021853,,,,,,,,,,,,,, -551000,4.735624,0.6056789,,,,,,,,,,,,,, -551100,4.8464103,0.67279345,,,,,,,,,,,,,, -551200,4.454305,0.60863507,,,,,,,,,,,,,, -551300,4.7156596,0.62967825,,,,,,,,,,,,,, -551400,4.587638,0.6867759,,,,,,,,,,,,,, -551500,4.3592253,0.6092557,,,,,,,,,,,,,, -551542,,,0.9594626426696776,0.1485990285873413,0.7549399733543396,1.0570076704025269,50000.0,0.6299000382423401,1.8204457759857176,10000.0,187235.9205963612,193612.2978289128,187235.9205963612,6326.669639825821,30.279273509979248,0.0 -551600,4.3319426,0.5980253,,,,,,,,,,,,,, -551700,4.166227,0.5500395,,,,,,,,,,,,,, -551800,4.3641043,0.6157571,,,,,,,,,,,,,, -551900,4.3997765,0.6588958,,,,,,,,,,,,,, -552000,5.045413,0.604797,,,,,,,,,,,,,, -552100,4.0594034,0.59024024,,,,,,,,,,,,,, -552200,4.621666,0.6759583,,,,,,,,,,,,,, -552300,4.1559963,0.53538007,,,,,,,,,,,,,, -552400,4.4104815,0.6647201,,,,,,,,,,,,,, -552500,4.5043435,0.60413855,,,,,,,,,,,,,, -552600,4.4537587,0.6522227,,,,,,,,,,,,,, -552700,4.355546,0.6767269,,,,,,,,,,,,,, -552800,4.681774,0.60613483,,,,,,,,,,,,,, -552900,4.518557,0.6010902,,,,,,,,,,,,,, -553000,4.333343,0.59156334,,,,,,,,,,,,,, -553045,,,0.9600207209587096,0.1465989500284195,0.7548999786376953,1.0555554628372192,50000.0,0.6305000185966492,1.8178973197937007,10000.0,187745.98685359955,194138.9894325733,187745.98685359955,6343.1246864795685,30.39686608314514,0.0 -553100,4.7355676,0.6240541,,,,,,,,,,,,,, -553200,4.667708,0.63353765,,,,,,,,,,,,,, -553300,4.485146,0.6066456,,,,,,,,,,,,,, -553400,4.5249376,0.63232696,,,,,,,,,,,,,, -553500,3.9383788,0.4912924,,,,,,,,,,,,,, -553600,4.7000203,0.67159957,,,,,,,,,,,,,, -553700,4.9244733,0.64118487,,,,,,,,,,,,,, -553800,4.5545387,0.6510819,,,,,,,,,,,,,, -553900,4.535014,0.63621753,,,,,,,,,,,,,, -554000,4.441811,0.61576563,,,,,,,,,,,,,, -554100,4.2971973,0.62756824,,,,,,,,,,,,,, -554200,4.5750885,0.62516886,,,,,,,,,,,,,, -554300,4.5497613,0.61610043,,,,,,,,,,,,,, -554400,4.6848345,0.6564337,,,,,,,,,,,,,, -554500,4.5469,0.6077963,,,,,,,,,,,,,, -554548,,,0.960339605808258,0.1473934650421142,0.7547000050544739,1.0555559396743774,50000.0,0.6312000155448914,1.8189270496368408,10000.0,188256.05421233177,194665.88861370087,188256.05421233177,6359.778885602951,30.522332429885864,0.0 -554600,4.391655,0.65161085,,,,,,,,,,,,,, -554700,4.7475734,0.63766336,,,,,,,,,,,,,, -554800,4.4460554,0.6663499,,,,,,,,,,,,,, -554900,4.2951183,0.59593135,,,,,,,,,,,,,, -555000,4.4343767,0.6303102,,,,,,,,,,,,,, -555100,3.9856749,0.5458873,,,,,,,,,,,,,, -555200,4.541197,0.58790517,,,,,,,,,,,,,, -555300,4.484066,0.6582726,,,,,,,,,,,,,, -555400,4.625047,0.6135712,,,,,,,,,,,,,, -555500,4.665305,0.66645366,,,,,,,,,,,,,, -555600,4.552159,0.63130736,,,,,,,,,,,,,, -555700,4.373166,0.64848495,,,,,,,,,,,,,, -555800,4.802426,0.6659204,,,,,,,,,,,,,, -555900,4.6542535,0.67787987,,,,,,,,,,,,,, -556000,4.006634,0.63183016,,,,,,,,,,,,,, -556051,,,0.9599210619926452,0.1484605371952057,0.7547599673271179,1.0562946796417236,50000.0,0.6299000382423401,1.8204028606414795,10000.0,188766.14553546903,195192.6730442047,188766.14553546903,6376.293786287308,30.64692568778992,0.0 -556100,4.5543265,0.60292107,,,,,,,,,,,,,, -556200,3.9862177,0.5787694,,,,,,,,,,,,,, -556300,4.5226603,0.61146927,,,,,,,,,,,,,, -556400,4.6970615,0.68811476,,,,,,,,,,,,,, -556500,4.532406,0.65102756,,,,,,,,,,,,,, -556600,4.3230057,0.5942986,,,,,,,,,,,,,, -556700,4.3557367,0.55823135,,,,,,,,,,,,,, -556800,4.705089,0.66469115,,,,,,,,,,,,,, -556812,,,,,,,,,,,189024.25990390778,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index cd1cae69e..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,555 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.57194924354553,0.0,42.89926743507385,1,0,42.89926743507385,0.0010000000474974,6.907756805419922,10000,83.47133111953735,0.0009374999790452,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -62.162476539611816,0.0280063152313232,462.8738512992859,911,0,462.8738512992859,0.0298000015318393,5.954464435577393,10000,525.1148769855499,0.0390429683029651,5.780252933502197,0.0384799987077713,5.81451940536499,50000 -83.55760931968689,0.0585842132568359,882.9174020290375,1874,0,882.9174020290375,0.0593000017106533,5.495050430297852,10000,966.6380949020386,0.0801562517881393,5.216738700866699,0.0748599991202354,5.251400947570801,50000 -105.08153319358826,0.0892825126647949,1303.1719748973846,2839,0,1303.1719748973846,0.1096000075340271,4.916996479034424,10000,1408.4984166622162,0.1545898467302322,4.480022430419922,0.1401599943637848,4.580586910247803,50000 -126.71310567855836,0.1167037487030029,1723.439739942551,3804,0,1723.439739942551,0.1480000019073486,4.542527198791504,10000,1850.476708889008,0.2083788961172104,4.025781154632568,0.1943800002336502,4.114321231842041,50000 -148.57643723487854,0.1443471908569336,2143.524181365967,4761,0,2143.524181365967,0.1910000145435333,4.191851139068604,10000,2292.5040712356567,0.2685156166553497,3.618990182876587,0.2489599883556366,3.72690749168396,50000 -170.15734696388245,0.1746456623077392,2563.565025806427,5717,0,2563.565025806427,0.2207000106573104,3.9930953979492174,10000,2734.2075912952423,0.3191015422344208,3.2768993377685547,0.2922999858856201,3.4312615394592285,50000 -191.6556527614593,0.2017805576324463,2983.585864305496,6672,0,2983.585864305496,0.2574000060558319,3.691765785217285,10000,3175.8048005104065,0.3853124976158142,2.87861442565918,0.3353399932384491,3.149968147277832,50000 -219.53815126419067,0.2309470176696777,3403.8257796764374,7626,0,3403.8257796764374,0.2787000238895416,3.58750057220459,10000,3624.00797700882,0.3844531178474426,2.8820815086364746,0.3578200042247772,3.024298667907715,50000 -244.53706979751587,0.270564317703247,3823.89849114418,8585,0,3823.89849114418,0.2975000143051147,3.4198501110076904,10000,4069.170470237732,0.4216406047344208,2.65896987915039,0.3878999948501587,2.8341987133026123,50000 -271.1634876728058,0.3059976100921631,4243.820187568665,9541,0,4243.820187568665,0.3157000243663788,3.2867302894592285,10000,4515.805619239807,0.4546484351158142,2.4515960216522217,0.4140599966049194,2.663148641586304,50000 -301.6569554805756,0.3401913642883301,4663.750853538513,10498,0,4663.750853538513,0.3304000198841095,3.208634853363037,10000,4966.315071105957,0.4586913883686065,2.4536895751953125,0.4279399812221527,2.6150896549224854,50000 -335.0212616920471,0.377582311630249,5083.722537994385,11454,0,5083.722537994385,0.3394000232219696,3.15163803100586,10000,5419.739944219589,0.477832019329071,2.340218067169189,0.4437799751758575,2.513360977172852,50000 -363.1043710708618,0.4115440845489502,5504.124020576477,12412,0,5504.124020576477,0.356000006198883,3.023467540740967,10000,5868.309863567352,0.5044335722923279,2.160196781158448,0.4608799815177917,2.383046865463257,50000 -388.9872608184815,0.4456574916839599,5924.25744843483,13365,0,5924.25744843483,0.3673000037670135,2.974306106567383,10000,6314.410578966141,0.52685546875,2.067809820175171,0.4763000011444092,2.321436882019043,50000 -419.7616608142853,0.4812161922454834,6344.441004276276,14317,0,6344.441004276276,0.3791000247001648,2.91217041015625,10000,6765.454939126968,0.5179492235183716,2.1016087532043457,0.4814999997615814,2.2885682582855225,50000 -452.32911682128906,0.5112817287445068,6764.550142526627,15266,0,6764.550142526627,0.3884000182151794,2.861824989318848,10000,7218.211961269379,0.5394921898841858,1.976670503616333,0.4970199763774872,2.2036519050598145,50000 -487.1478357315064,0.5434060096740723,7184.7117347717285,16217,0,7184.7117347717285,0.3890000283718109,2.8712475299835205,10000,7673.275277376175,0.5404882431030273,2.0090866088867188,0.4944399893283844,2.2409517765045166,50000 -522.9998071193695,0.5738611221313477,7604.762713670731,17168,0,7604.762713670731,0.38960000872612,2.8298895359039307,10000,8129.259591817856,0.5452538728713989,1.9710559844970703,0.5032199621200562,2.1825075149536133,50000 -558.4737074375153,0.6011612415313721,8025.002263069153,18118,0,8025.002263069153,0.4029000103473663,2.759299993515014,10000,8585.051396369934,0.5589648485183716,1.9002271890640257,0.5151199698448181,2.097113847732544,50000 -596.0462672710419,0.6375689506530762,8445.054428339005,19064,0,8445.054428339005,0.4075000286102295,2.712984800338745,10000,9042.762500047684,0.5706835985183716,1.8196309804916384,0.5245599746704102,2.0479772090911865,50000 -631.0662536621094,0.6704673767089844,8865.098269224167,20015,0,8865.098269224167,0.4067000150680542,2.7463629245758057,10000,9497.909968614578,0.5750390291213989,1.813962697982788,0.5193799734115601,2.0880651473999023,50000 -668.1161289215088,0.7022011280059814,9285.359723567964,20959,0,9285.359723567964,0.4192000329494476,2.6817190647125244,10000,9955.302630662918,0.5743163824081421,1.8025161027908323,0.5355199575424194,2.0042288303375244,50000 -702.4113335609436,0.7362079620361328,9705.488340377808,21905,0,9705.488340377808,0.4274000227451324,2.6510984897613525,10000,10409.810875177383,0.5809569954872131,1.789517521858215,0.5389999747276306,1.9918287992477417,50000 -738.0171258449554,0.7833480834960938,10125.69636964798,22848,0,10125.69636964798,0.4195000231266022,2.6842269897460938,10000,10865.72094798088,0.5851757526397705,1.7660349607467651,0.5362200140953064,2.0083796977996826,50000 -773.5448224544525,0.8156635761260986,10545.699042081833,23799,0,10545.699042081833,0.4353000223636627,2.58562970161438,10000,11321.33425951004,0.6068359017372131,1.6197729110717771,0.5496999621391296,1.906225562095642,50000 -808.2378516197205,0.850771427154541,10965.678258657455,24751,0,10965.678258657455,0.4314000308513641,2.6117215156555176,10000,11776.092476606367,0.5919336080551147,1.7640339136123655,0.5448200106620789,1.9703313112258911,50000 -843.9393339157104,0.8788700103759766,11385.718565702438,25696,0,11385.718565702438,0.4390000104904175,2.579681873321533,10000,12231.912591457369,0.6055663824081421,1.681242823600769,0.560539960861206,1.8968688249588013,50000 -878.3297283649445,1.248054265975952,11805.5884308815,26633,0,11805.5884308815,0.4310000240802765,2.5757434368133545,10000,12686.59158539772,0.6142968535423279,1.5852680206298828,0.5594599843025208,1.871098756790161,50000 -912.7611672878264,1.279726266860962,12225.917182445526,27578,0,12225.917182445526,0.442300021648407,2.531644105911255,10000,13141.4329662323,0.6026562452316284,1.6675916910171509,0.5617600083351135,1.8550418615341189,50000 -947.8815426826476,1.3121697902679443,12645.979301929474,28520,0,12645.979301929474,0.4536000192165375,2.510180950164795,10000,13596.697838544846,0.6154882907867432,1.6122013330459597,0.5725399851799011,1.828546643257141,50000 -983.126780986786,1.3409326076507568,13066.364196300508,29466,0,13066.364196300508,0.4462000131607055,2.499765634536743,10000,14052.40638923645,0.6235937476158142,1.5766786336898804,0.5673199892044067,1.8261879682540887,50000 -1017.985541820526,1.370903491973877,13486.317937612534,30410,0,13486.317937612534,0.4428000152111053,2.5430333614349365,10000,14507.29845237732,0.6274023056030273,1.577084183692932,0.5665199756622314,1.86709463596344,50000 -1052.953050851822,1.4022631645202637,13906.673711299896,31356,0,13906.673711299896,0.4594000279903412,2.448460102081299,10000,14962.702766418455,0.6219531297683716,1.5759414434432983,0.581279993057251,1.7813587188720703,50000 -1088.0273683071136,1.436601638793945,14327.127075195312,32301,0,14327.127075195312,0.4629000127315521,2.435957193374634,10000,15418.314130306244,0.6274804472923279,1.536577582359314,0.5806599855422974,1.76521897315979,50000 -1121.0258762836456,1.4687433242797852,14747.1489007473,33248,0,14747.1489007473,0.4598000347614288,2.4462242126464844,10000,15871.41714978218,0.6382226347923279,1.4830291271209717,0.5817999839782715,1.751964449882507,50000 -1155.8789296150208,1.502777338027954,15167.075635671616,34193,0,15167.075635671616,0.4635000228881836,2.436673402786255,10000,16326.279378652573,0.6274804472923279,1.5200376510620115,0.5874999761581421,1.736883521080017,50000 -1190.3300416469574,1.5448389053344729,15587.07857298851,35139,0,15587.07857298851,0.4652000367641449,2.4026684761047363,10000,16780.82415175438,0.6370312571525574,1.5092811584472656,0.5901399850845337,1.7288663387298584,50000 -1225.792776346207,1.5797712802886963,16007.0850918293,36086,0,16007.0850918293,0.4682000279426574,2.412063837051392,10000,17236.377789497375,0.6399999856948853,1.4712377786636353,0.5904200077056885,1.7305066585540771,50000 -1260.8337228298187,1.616624116897583,16427.407952070236,37030,0,16427.407952070236,0.4718000292778015,2.398625612258911,10000,17691.82711672783,0.6620116829872131,1.3794057369232178,0.592519998550415,1.7107194662094116,50000 -1295.9946694374084,1.6496033668518066,16847.74187517166,37975,0,16847.74187517166,0.4763000309467315,2.375365257263184,10000,18147.40429544449,0.6415038704872131,1.4643170833587646,0.5976799726486206,1.683479905128479,50000 -1330.1508178710938,1.68623948097229,17267.9571518898,38922,0,17267.9571518898,0.4737000167369842,2.377290725708008,10000,18601.86097741127,0.6445702910423279,1.4891875982284546,0.5958200097084045,1.7240655422210691,50000 -1364.786494731903,1.719895601272583,17688.084655046463,39869,0,17688.084655046463,0.4751000106334686,2.372420072555542,10000,19056.70727801323,0.6550585627555847,1.4371448755264282,0.5964199900627136,1.7216540575027466,50000 -1398.0626657009125,1.75140643119812,18108.195221424103,40814,0,18108.195221424103,0.4708000123500824,2.375142335891724,10000,19510.174768209457,0.6373828053474426,1.492069959640503,0.5970399975776672,1.6841896772384644,50000 -1432.521831035614,1.785491704940796,18528.375126600266,41758,0,18528.375126600266,0.480400025844574,2.345698833465576,10000,19964.89772510529,0.6490429639816284,1.4463257789611816,0.5996599793434143,1.6779924631118774,50000 -1467.7309651374817,1.8197968006134035,18948.30942368508,42704,0,18948.30942368508,0.4749000370502472,2.391242742538452,10000,20420.125157356262,0.6508007645606995,1.487926721572876,0.5974400043487549,1.734683871269226,50000 -1503.6146020889282,1.8551933765411377,19368.389864444733,43650,0,19368.389864444733,0.4786000251770019,2.350553274154663,10000,20876.173667669296,0.684277355670929,1.2981736660003662,0.6021400094032288,1.6690380573272705,50000 -1538.0971751213074,1.8865134716033936,19788.714443922043,44598,0,19788.714443922043,0.4823000133037567,2.3217082023620605,10000,21331.062341213223,0.6547460556030273,1.4286702871322632,0.6058399677276611,1.6572929620742798,50000 -1572.1255660057068,1.9221677780151367,20208.8225941658,45540,0,20208.8225941658,0.4834000170230865,2.2972004413604736,10000,21785.28338742256,0.660351574420929,1.3766939640045166,0.6067799925804138,1.6368731260299685,50000 -1606.8116953372955,1.9543015956878664,20628.96104216576,46483,0,20628.96104216576,0.4816000163555145,2.329615592956543,10000,22240.189120054245,0.6699609160423279,1.3838269710540771,0.6090999841690063,1.6606667041778564,50000 -1642.6166734695437,1.9951858520507808,21049.05564045906,47427,0,21049.05564045906,0.4983000159263611,2.286940097808838,10000,22696.17867732048,0.6600781083106995,1.3950390815734863,0.6157999634742737,1.612318754196167,50000 -1679.050225496292,2.0333759784698486,21469.06161904335,48373,0,21469.06161904335,0.4912000298500061,2.276652336120605,10000,23152.705780267715,0.6630077958106995,1.3899929523468018,0.6116600036621094,1.618957281112671,50000 -1714.1676445007324,2.0756874084472656,21889.25936102867,49317,0,21889.25936102867,0.4905000329017639,2.2634685039520264,10000,23608.112749814987,0.6714062094688416,1.346708059310913,0.617680013179779,1.5923179388046265,50000 -1749.5697557926178,2.108650684356689,22309.52768969536,50263,0,22309.52768969536,0.4958000183105469,2.265782594680786,10000,24063.864980220795,0.696582019329071,1.2489502429962158,0.6180999875068665,1.5910358428955078,50000 -1785.455048084259,2.143319845199585,22729.79816842079,51208,0,22729.79816842079,0.4981000125408172,2.2476613521575928,10000,24520.10476160049,0.6667382717132568,1.3661245107650757,0.6186599731445312,1.5887244939804075,50000 -1821.455227851868,2.177870750427246,23150.10472869873,52154,0,23150.10472869873,0.4946000277996063,2.257022857666016,10000,24976.49471473694,0.6708202958106995,1.3567956686019895,0.6210799813270569,1.588613986968994,50000 -1857.0161018371584,2.210293292999268,23570.43446731568,53100,0,23570.43446731568,0.4946000277996063,2.253964424133301,10000,25432.46706700325,0.6807812452316284,1.2781988382339478,0.6184599995613098,1.5739054679870603,50000 -1892.2611465454104,2.2538437843322754,23990.75499391556,54046,0,23990.75499391556,0.4976000189781189,2.259164571762085,10000,25888.12592768669,0.6664257645606995,1.3784408569335938,0.621679961681366,1.5963414907455444,50000 -1928.172496318817,2.290111303329468,24410.766496419907,54991,0,24410.766496419907,0.5001000165939331,2.2374267578125,10000,26344.1343998909,0.6720898151397705,1.3343063592910769,0.6214399933815002,1.5615304708480835,50000 -1963.9228575229645,2.3235397338867188,24830.783933877945,55936,0,24830.783933877945,0.5035000443458557,2.206250190734864,10000,26799.984984874725,0.6838671565055847,1.2733802795410156,0.6282599568367004,1.539625644683838,50000 -1998.3244211673737,2.3617398738861084,25250.83206653595,56882,0,25250.83206653595,0.5056000351905823,2.223431348800659,10000,27254.522886514664,0.7085351347923279,1.1974025964736938,0.6280999779701233,1.562067627906799,50000 -2033.3489353656769,2.401756525039673,25671.16929149628,57826,0,25671.16929149628,0.5074000358581543,2.1904778480529785,10000,27709.973189353943,0.6792578101158142,1.2947694063186646,0.6309199929237366,1.5316017866134644,50000 -2068.3087170124054,2.434241533279419,26091.416990995407,58771,0,26091.416990995407,0.5051000118255615,2.211732625961304,10000,28165.26349496841,0.6832422018051147,1.3048343658447266,0.6307799816131592,1.5494695901870728,50000 -2103.1707775592804,2.470738172531128,26511.82786488533,59717,0,26511.82786488533,0.501300036907196,2.2092301845550537,10000,28620.62424898148,0.6934570074081421,1.2318331003189087,0.6297599673271179,1.5253384113311768,50000 -2138.6504290103912,2.5045390129089355,26931.74744296074,60661,0,26931.74744296074,0.5042000412940979,2.225181818008423,10000,29076.107455968857,0.6733202934265137,1.314584732055664,0.6304999589920044,1.5325448513031006,50000 -2174.3722081184387,2.5386993885040283,27352.115279197693,61605,0,27352.115279197693,0.5072000026702881,2.201844930648804,10000,29532.281440734863,0.6815820336341858,1.2969701290130615,0.6304599642753601,1.544524908065796,50000 -2209.677117586136,2.578078031539917,27772.08389544487,62550,0,27772.08389544487,0.5128999948501587,2.191729784011841,10000,29987.64568257332,0.694531261920929,1.2573353052139282,0.6335799694061279,1.5339800119400024,50000 -2245.259070396424,2.6166160106658936,28192.20817756653,63495,0,28192.20817756653,0.515500009059906,2.1697821617126465,10000,30443.441133499146,0.719531238079071,1.1414695978164673,0.6376999616622925,1.5069578886032104,50000 -2280.39914727211,2.660373449325561,28612.125376462936,64440,0,28612.125376462936,0.5054000020027161,2.1913552284240723,10000,30898.59340786934,0.6832422018051147,1.289946436882019,0.6335200071334839,1.5233887434005735,50000 -2316.05020904541,2.7098233699798584,29032.208420038223,65386,0,29032.208420038223,0.5078999996185303,2.162203073501587,10000,31354.42751383781,0.694042980670929,1.2231028079986572,0.6383799910545349,1.484677791595459,50000 -2352.170639514923,2.747178792953491,29452.30614376068,66332,0,29452.30614376068,0.5143000483512878,2.157905578613281,10000,31810.7338078022,0.707226574420929,1.169092893600464,0.6391199827194214,1.4746456146240234,50000 -2387.653300523758,2.787407875061035,29872.2384326458,67275,0,29872.2384326458,0.5159000158309937,2.1513500213623047,10000,32266.239609479904,0.6863867044448853,1.2691597938537598,0.6340999603271484,1.5110490322113037,50000 -2421.625636100769,2.82780122756958,30292.351333141327,68217,0,30292.351333141327,0.5285000205039978,2.097337245941162,10000,32720.415112257004,0.7003710865974426,1.183992624282837,0.6473000049591064,1.441631317138672,50000 -2455.2891149520874,2.863965511322021,30712.3498916626,69160,0,30712.3498916626,0.517300009727478,2.136876106262207,10000,33174.16375398636,0.70570307970047,1.190212607383728,0.644819974899292,1.461923122406006,50000 -2489.757776260376,3.5462582111358643,31131.891610860825,70099,0,31131.891610860825,0.5184000134468079,2.132868766784668,10000,33628.90654087067,0.7290429472923279,1.0740240812301636,0.6439999938011169,1.463181734085083,50000 -2524.795145511627,3.592628002166748,31552.1533575058,71045,0,31552.1533575058,0.5236000418663025,2.1252052783966064,10000,34084.3028922081,0.6943163871765137,1.2107747793197632,0.6439999938011169,1.4579524993896484,50000 -2558.4052596092224,3.631521224975586,31972.31904554367,71989,0,31972.31904554367,0.5245000123977661,2.0866498947143555,10000,34538.16773843765,0.7100585699081421,1.1660125255584717,0.6522799730300903,1.4295166730880735,50000 -2593.2753612995148,3.672046422958374,32392.24674224853,72930,0,32392.24674224853,0.5224000215530396,2.14300537109375,10000,34993.05684757233,0.7144140601158142,1.1511657238006592,0.6480000019073486,1.455902934074402,50000 -2628.373400449753,3.709576606750488,32812.47220945358,73876,0,32812.47220945358,0.5225000381469727,2.137120246887207,10000,35448.46819233894,0.6960546970367432,1.236786723136902,0.64656001329422,1.4624232053756714,50000 -2664.336992740631,3.747407913208008,33232.751855134964,74819,0,33232.751855134964,0.5272000432014465,2.114329099655152,10000,35904.79923796654,0.7019921541213989,1.2093700170516968,0.65447998046875,1.444311022758484,50000 -2698.1623861789703,3.784353256225586,33652.85684943199,75760,0,33652.85684943199,0.5243000388145447,2.0840115547180176,10000,36358.81673932076,0.710156261920929,1.1536434888839722,0.6520400047302246,1.417355179786682,50000 -2731.18881893158,3.8260109424591056,34073.014280080795,76700,0,34073.014280080795,0.5314000248908997,2.086050271987915,10000,36812.09155678749,0.7394140362739563,1.0564640760421753,0.6574400067329407,1.422579288482666,50000 -2767.2862479686737,3.867687463760376,34493.03785729408,77644,0,34493.03785729408,0.5337000489234924,2.083815813064575,10000,37268.30460429192,0.7078710794448853,1.1789730787277222,0.6548399925231934,1.4233994483947754,50000 -2804.3315374851227,3.908069133758545,34913.28986310959,78587,0,34913.28986310959,0.5367000102996826,2.07688570022583,10000,37725.69277334213,0.7059569954872131,1.1735212802886963,0.6548999547958374,1.41805100440979,50000 -2839.677448511124,3.9546329975128174,35333.480610609055,79532,0,35333.480610609055,0.5297000408172607,2.0830442905426025,10000,38181.32603669167,0.72802734375,1.105783462524414,0.6602999567985535,1.4155524969100952,50000 -2875.596264362335,3.994290590286255,35753.704422950745,80477,0,35753.704422950745,0.5308000445365906,2.055467367172241,10000,38637.558312892914,0.71240234375,1.1545332670211792,0.6611999869346619,1.3902668952941897,50000 -2909.6886727809906,4.0421226024627686,36174.04013347626,81421,0,36174.04013347626,0.5351000428199768,2.045471668243408,10000,39092.08462238312,0.7125195264816284,1.1445448398590088,0.6576600074768066,1.3969004154205322,50000 -2945.35439991951,4.082128286361694,36594.36502742767,82367,0,36594.36502742767,0.5442000031471252,2.011723279953003,10000,39548.16616392136,0.7277148365974426,1.075430154800415,0.6658799648284912,1.359333872795105,50000 -2980.6113533973694,4.123176574707031,37014.328892469406,83313,0,37014.328892469406,0.5373000502586365,2.03928279876709,10000,40003.478034973145,0.7518945336341858,0.9783536195755004,0.6640399694442749,1.3719205856323242,50000 -3016.158703804016,4.166561603546143,37434.48281121254,84258,0,37434.48281121254,0.5437999963760376,2.0086333751678467,10000,40459.272706747055,0.7216015458106995,1.1113862991333008,0.6678599715232849,1.3596880435943604,50000 -3051.331242084503,4.223360538482666,37854.96192789078,85201,0,37854.96192789078,0.5369000434875488,2.038984775543213,10000,40915.03071784973,0.7228124737739563,1.1117151975631714,0.6657599806785583,1.3800417184829712,50000 -3086.22958111763,4.271142959594727,38275.27783441544,86145,0,38275.27783441544,0.5494000315666199,1.995234131813049,10000,41370.34265089035,0.7410351634025574,1.0317023992538452,0.6699399948120117,1.3505462408065796,50000 -3122.0959231853485,4.3140411376953125,38695.271587610245,87089,0,38695.271587610245,0.5448000431060791,1.9972219467163088,10000,41826.296382427216,0.722949206829071,1.0942314863204956,0.6686999797821045,1.3394672870635986,50000 -3157.0737657547,4.3602800369262695,39115.2553396225,88032,0,39115.2553396225,0.5497000217437744,1.986189603805542,10000,42281.354194402695,0.7299609184265137,1.0719990730285645,0.6747599840164185,1.3282451629638672,50000 -3193.2224068641663,4.400323152542114,39535.3887925148,88977,0,39535.3887925148,0.5471000075340271,2.0001771450042725,10000,42737.726367235184,0.7388085722923279,1.0372790098190308,0.6735000014305115,1.3296064138412476,50000 -3228.858241558075,4.443213939666748,39955.61276316643,89923,0,39955.61276316643,0.5466000437736511,1.9942545890808103,10000,43193.67957997322,0.7599999904632568,0.9553410410881042,0.66975998878479,1.3442224264144895,50000 -3263.637674331665,4.48250937461853,40375.80407691002,90868,0,40375.80407691002,0.5463000535964966,2.0084035396575928,10000,43648.7395863533,0.7295898199081421,1.097119688987732,0.6711399555206299,1.3498778343200684,50000 -3298.9227085113525,4.531588315963745,40795.88285851479,91812,0,40795.88285851479,0.5585000514984131,1.9671977758407595,10000,44104.20302581787,0.7383593320846558,1.04035747051239,0.6772399544715881,1.3202310800552368,50000 -3332.5399737358093,4.574221849441528,41215.80678701401,92755,0,41215.80678701401,0.5506000518798828,2.000419855117798,10000,44557.83666324616,0.7471874952316284,1.026834487915039,0.6764000058174133,1.334289312362671,50000 -3368.114696741104,4.624174118041992,41635.83515667915,93698,0,41635.83515667915,0.5543000102043152,1.944562554359436,10000,45013.53951334953,0.7383398413658142,1.0327916145324707,0.678119957447052,1.306376576423645,50000 -3404.7003223896027,4.667530536651611,42055.82461929321,94642,0,42055.82461929321,0.5541000366210938,1.9380820989608765,10000,45470.20898675919,0.7393554449081421,1.033813714981079,0.6789999604225159,1.2958242893218994,50000 -3440.3258113861084,4.711685180664063,42475.95677399635,95589,0,42475.95677399635,0.5523000359535217,1.9421786069869995,10000,45926.06120347977,0.749804675579071,0.988872528076172,0.6826800107955933,1.2907994985580444,50000 -3476.817836523056,4.763655424118042,42895.94270777702,96532,0,42895.94270777702,0.5608000159263611,1.949744820594788,10000,46382.64171719551,0.768261730670929,0.9298219680786132,0.6815999746322632,1.3072997331619265,50000 -3512.03467297554,4.808404445648193,43316.22964668274,97475,0,43316.22964668274,0.5592000484466553,1.9167892932891848,10000,46838.23907756805,0.7375390529632568,1.017408013343811,0.6825000047683716,1.2762969732284546,50000 -3548.5029113292694,4.856558799743652,43736.36538696289,98416,0,43736.36538696289,0.5608000159263611,1.9235416650772093,10000,47294.940781116486,0.7465234398841858,1.0131640434265137,0.6841599941253662,1.29089093208313,50000 -3584.3355853557587,4.8991899490356445,44156.54634976387,99357,0,44156.54634976387,0.5591000318527222,1.913759469985962,10000,47751.04694914818,0.755664050579071,0.9562978148460388,0.6837199926376343,1.2733529806137085,50000 -3620.267826318741,4.93875527381897,44576.57260489464,100302,0,44576.57260489464,0.566100001335144,1.9041144847869875,10000,48207.096264600754,0.7462109327316284,0.9947492480278016,0.6892799735069275,1.256700873374939,50000 -3655.877145767212,4.987702131271362,44996.47789621353,101247,0,44996.47789621353,0.5685000419616699,1.881093978881836,10000,48662.70986151695,0.7510156035423279,0.9671342372894288,0.6897599697113037,1.2461247444152832,50000 -3692.2329466342926,5.0316221714019775,45416.58115434647,102191,0,45416.58115434647,0.5634000301361084,1.915746808052063,10000,49119.26220941544,0.7568163871765137,0.9741609692573548,0.6883999705314636,1.270899772644043,50000 -3728.261288166046,5.086937427520752,45836.55875372887,103135,0,45836.55875372887,0.5682000517845154,1.8965471982955933,10000,49575.37422776222,0.77490234375,0.9026257395744324,0.6901199817657471,1.262291669845581,50000 -3764.905757427216,5.132697582244873,46256.78787302971,104079,0,46256.78787302971,0.5656000375747681,1.8726284503936768,10000,50032.344173669815,0.7558984160423279,0.9612823128700256,0.6957399845123291,1.2298921346664429,50000 -3799.874258995056,5.182331800460815,46677.106143713,105023,0,46677.106143713,0.5689000487327576,1.8730270862579343,10000,50487.73134660721,0.760546863079071,0.9362866282463074,0.6957399845123291,1.2292063236236572,50000 -3835.612533569336,5.22328519821167,47097.24747681618,105968,0,47097.24747681618,0.5648000240325928,1.8811651468276973,10000,50943.70225930214,0.7692187428474426,0.9136697053909302,0.6956999897956848,1.2356456518173218,50000 -3871.7235753536224,5.266687631607056,47517.2873609066,106914,0,47517.2873609066,0.5676000118255615,1.8738008737564087,10000,51399.94680428505,0.7616991996765137,0.9309342503547668,0.6985599994659424,1.2136071920394895,50000 -3907.050561189652,5.311384916305542,47937.33025097847,107855,0,47937.33025097847,0.5760000348091125,1.8427163362503047,10000,51855.41088676453,0.7624804377555847,0.9120814800262452,0.7012799978256226,1.1950334310531616,50000 -3942.584057331085,5.36137318611145,48357.28145503998,108799,0,48357.28145503998,0.5685000419616699,1.8710874319076536,10000,52310.99687170982,0.7684375047683716,0.915904700756073,0.6987599730491638,1.2196636199951172,50000 -3979.134026050568,5.406089067459106,48777.6155333519,109745,0,48777.6155333519,0.5756000280380249,1.829600691795349,10000,52767.975719213486,0.7912695407867432,0.8051213026046753,0.7003399729728699,1.1956088542938232,50000 -4015.9227290153503,5.449268817901611,49197.70868873596,110688,0,49197.70868873596,0.5751000046730042,1.837497353553772,10000,53224.95006370544,0.7667773365974426,0.90343976020813,0.7044399976730347,1.1917033195495603,50000 -4051.419378995896,5.502639532089233,49617.80343198776,111633,0,49617.80343198776,0.5767000317573547,1.8268989324569704,10000,53680.64485859871,0.7742773294448853,0.8982342481613159,0.7044399976730347,1.1926296949386597,50000 -4086.712213039398,5.552619695663452,50038.0046517849,112579,0,50038.0046517849,0.58160001039505,1.812219262123108,10000,54136.23913860321,0.7817187309265137,0.8364842534065247,0.7057200074195862,1.1751327514648438,50000 -4123.13805603981,5.594868421554565,50458.02049589157,113519,0,50458.02049589157,0.5766000151634216,1.8333693742752075,10000,54592.7734439373,0.7722070217132568,0.889786422252655,0.7026799917221069,1.1861135959625244,50000 -4158.8170692920685,5.638230085372925,50878.1044754982,114465,0,50878.1044754982,0.588100016117096,1.804617047309876,10000,55048.6302447319,0.7795702815055847,0.8615785837173462,0.7099199891090393,1.1732913255691528,50000 -4193.726380109787,5.686317682266235,51298.08326268196,115411,0,51298.08326268196,0.5830000042915344,1.7958065271377563,10000,55503.61720633507,0.7862890362739563,0.82613605260849,0.709879994392395,1.1572134494781494,50000 -4227.141417503357,5.730888605117798,51718.13684058189,116356,0,51718.13684058189,0.5822000503540039,1.81033718585968,10000,55957.18047308922,0.8011523485183716,0.7779342532157898,0.7084999680519104,1.161314606666565,50000 -4264.137006044388,5.787650108337402,52138.45469069481,117299,0,52138.45469069481,0.5878000259399414,1.808908462524414,10000,56414.60079431534,0.7773827910423279,0.8753180503845215,0.7099599838256836,1.174333095550537,50000 -4299.884291887283,5.847143650054932,52558.41911363602,118247,0,52558.41911363602,0.5895000100135803,1.791690707206726,10000,56870.423597335815,0.7845507860183716,0.8391237258911133,0.7131199836730957,1.1501730680465698,50000 -4336.783348560333,5.890512466430664,52978.646104097366,119194,0,52978.646104097366,0.5840000510215759,1.8003008365631104,10000,57327.64228081703,0.7907031178474426,0.8129280805587769,0.712619960308075,1.1554131507873535,50000 -4372.142210245132,5.944162607192993,53398.970262527466,120140,0,53398.970262527466,0.5950000286102295,1.7669881582260132,10000,57783.42899656296,0.78236323595047,0.8431088328361511,0.715939998626709,1.1370404958724976,50000 -4408.703050613403,5.99152398109436,53819.214713573456,121084,0,53819.214713573456,0.59170001745224,1.766446590423584,10000,58240.33206629753,0.7857617139816284,0.8171523809432983,0.716759979724884,1.126121163368225,50000 -4444.461498260498,6.041364669799805,54239.28956913948,122027,0,54239.28956913948,0.6005000472068787,1.7421303987503052,10000,58696.26500344277,0.7969726324081421,0.7908524870872498,0.7226200103759766,1.1162563562393188,50000 -4480.420238494873,6.089730262756348,54659.25565576553,122969,0,54659.25565576553,0.5875000357627869,1.7857192754745483,10000,59152.28856778145,0.8081640601158142,0.7565469741821289,0.7193799614906311,1.1386165618896484,50000 -4515.870098590851,6.138895750045776,55079.27008724213,123914,0,55079.27008724213,0.5992000102996826,1.731168270111084,10000,59607.85239100456,0.7901562452316284,0.8068869113922119,0.7235400080680847,1.10583233833313,50000 -4551.8229367733,6.189615488052368,55499.21904158592,124859,0,55499.21904158592,0.6008000373840332,1.7349574565887451,10000,60063.85525274277,0.7969726324081421,0.7828193306922913,0.721839964389801,1.1123453378677368,50000 -4587.547564506531,6.235995531082153,55919.18010044098,125802,0,55919.18010044098,0.5984000563621521,1.7472631931304932,10000,60519.63783144951,0.8039257526397705,0.764805257320404,0.7251600027084351,1.1057837009429932,50000 -4622.255147457123,6.281415939331055,56339.31581926346,126746,0,56339.31581926346,0.6011000275611877,1.7334434986114502,10000,60974.57660126686,0.7957617044448853,0.7915138006210327,0.7269200086593628,1.0992271900177002,50000 -4657.953867673874,6.33309531211853,56759.32902884483,127689,0,56759.32902884483,0.6038000583648682,1.7361700534820557,10000,61430.3923459053,0.7975195050239563,0.7745173573493958,0.725600004196167,1.094840168952942,50000 -4694.850904941559,6.385717153549194,57179.24024510384,128633,0,57179.24024510384,0.5994000434875488,1.735648274421692,10000,61887.30322051048,0.8062695264816284,0.7585539221763611,0.7260199785232544,1.0970200300216677,50000 -4731.710379600525,6.433982849121094,57599.32408952713,129577,0,57599.32408952713,0.600600004196167,1.7161928415298462,10000,62344.34480929375,0.8240624666213989,0.6876732707023621,0.7297599911689758,1.080593466758728,50000 -4767.513851881027,6.48798131942749,58019.40853381157,130521,0,58019.40853381157,0.6066000461578369,1.7314767837524414,10000,62800.33711600304,0.8061913847923279,0.74146568775177,0.731660008430481,1.078520894050598,50000 -4804.62633895874,6.535409212112427,58439.49217700958,131467,0,58439.49217700958,0.605400025844574,1.702867865562439,10000,63257.6309440136,0.8073632717132568,0.7516621947288513,0.7322799563407898,1.0729314088821411,50000 -4840.436786174774,6.581597805023193,58859.64246606827,132413,0,58859.64246606827,0.6101000308990479,1.7026625871658323,10000,63713.68802905083,0.8190429210662842,0.686676025390625,0.7339000105857849,1.062727451324463,50000 -4875.248173952103,6.630462884902954,59279.68286252022,133358,0,59279.68286252022,0.6114000082015991,1.697309494018555,10000,64168.63878440857,0.809863269329071,0.7376311421394348,0.735539972782135,1.060116171836853,50000 -4911.07240653038,6.679866790771484,59699.82854485512,134300,0,59699.82854485512,0.6104000210762024,1.6855266094207764,10000,64624.70807790756,0.8107226490974426,0.7090687155723572,0.735539972782135,1.045581340789795,50000 -4947.333732366562,6.733776807785034,60120.14081978798,135246,0,60120.14081978798,0.619100034236908,1.657473921775818,10000,65081.38606548309,0.8220117092132568,0.6820235252380371,0.7381199598312378,1.040558695793152,50000 -4983.961835622788,6.781572341918945,60540.23037648201,136190,0,60540.23037648201,0.6166000366210938,1.6594387292861938,10000,65538.20174622536,0.833789050579071,0.6432842016220093,0.7403799891471863,1.0398958921432495,50000 -5019.351192474365,6.82807993888855,60960.5283882618,137134,0,60960.5283882618,0.6152000427246094,1.6574362516403198,10000,65993.98559451103,0.8225390315055847,0.6828349232673645,0.7406799793243408,1.026926040649414,50000 -5055.546671152115,6.876504421234131,61380.73367190361,138078,0,61380.73367190361,0.6131000518798828,1.6730371713638306,10000,66450.48570799828,0.8241210579872131,0.6726440787315369,0.7392399907112122,1.0355180501937866,50000 -5090.26197385788,6.923521041870117,61800.80941319466,139021,0,61800.80941319466,0.6114000082015991,1.6480200290679932,10000,66905.37312984467,0.8309569954872131,0.6413344144821167,0.7416599988937378,1.0248628854751587,50000 -5126.392836332321,6.976944208145142,62220.787368536,139965,0,62220.787368536,0.6248000264167786,1.6311440467834473,10000,67361.58550977707,0.8257421851158142,0.6530791521072388,0.7451399564743042,1.004848599433899,50000 -5162.872810840607,7.027889490127564,62640.990287303925,140909,0,62640.990287303925,0.6196000576019287,1.6410871744155884,10000,67818.36901831627,0.8275976181030273,0.6520302295684814,0.7456799745559692,1.002097725868225,50000 -5198.357325792313,7.077587604522705,63061.08670520783,141855,0,63061.08670520783,0.6238000392913818,1.6353424787521362,10000,68274.05000305176,0.8324023485183716,0.6519719362258911,0.7467199563980103,1.0059432983398438,50000 -5234.109792232513,7.129896640777588,63481.00023508072,142802,0,63481.00023508072,0.6241000294685364,1.6294469833374023,10000,68729.8181681633,0.8463085889816284,0.5910487174987793,0.7475399971008301,1.0001438856124878,50000 -5269.614250659943,7.175962209701538,63901.20878171921,143747,0,63901.20878171921,0.6236000061035156,1.624912142753601,10000,69185.62810969353,0.8325781226158142,0.6336491703987122,0.7473799586296082,0.996917188167572,50000 -5305.843818902969,7.224336624145508,64321.56186580658,144694,0,64321.56186580658,0.6264000535011292,1.6167546510696411,10000,69642.30963206291,0.83607417345047,0.6233413219451904,0.7495200037956238,0.9957394003868104,50000 -5341.5623569488525,7.281407833099365,64741.75490403175,145639,0,64741.75490403175,0.6269000172615051,1.6104060411453247,10000,70098.32852602005,0.8413866758346558,0.6079884767532349,0.7496599555015564,0.9920935034751892,50000 -5377.9620950222015,7.32740592956543,65161.94474768639,146583,0,65161.94474768639,0.6305000185966492,1.6054846048355105,10000,70555.01319146156,0.8368554711341858,0.6317271590232849,0.7525999546051025,0.9833926558494568,50000 -5413.293456554413,7.379019498825073,65582.24419879913,147531,0,65582.24419879913,0.6319000124931335,1.593847155570984,10000,71010.74550557137,0.8406640291213989,0.6047345399856567,0.7533400058746338,0.9802722930908204,50000 -5449.63863158226,7.4307496547698975,66002.26580381393,148474,0,66002.26580381393,0.6288000345230103,1.5970369577407837,10000,71467.2146782875,0.8462694883346558,0.5779913067817688,0.7558599710464478,0.9657291173934937,50000 -5485.423399209976,7.478578805923462,66422.62491846085,149419,0,66422.62491846085,0.6330000162124634,1.5846474170684814,10000,71923.45594525337,0.8524804711341858,0.5546852946281433,0.7554199695587158,0.9629727602005004,50000 -5520.838660478592,7.52655816078186,66842.6845741272,150364,0,66842.6845741272,0.6341000199317932,1.5807474851608276,10000,72379.02922987938,0.845996081829071,0.5830208659172058,0.7575799822807312,0.9540197253227234,50000 -5557.04504776001,7.584197282791138,67262.9191057682,151309,0,67262.9191057682,0.6341000199317932,1.5633233785629272,10000,72835.57788276672,0.8492968678474426,0.5629798173904419,0.7577599883079529,0.9516875743865968,50000 -5592.5248901844025,7.643503665924072,67683.15299010277,152255,0,67683.15299010277,0.636400043964386,1.5599946975708008,10000,73291.40171384811,0.8557031154632568,0.5337830781936646,0.7605400085449219,0.942303478717804,50000 -5628.214492797852,7.696725130081177,68103.0815103054,153200,0,68103.0815103054,0.6391000151634216,1.5622645616531372,10000,73747.12293791771,0.8529492020606995,0.5705944895744324,0.7596399784088135,0.9485294222831726,50000 -5664.248873949051,7.753176212310791,68523.02074098587,154142,0,68523.02074098587,0.6353000402450562,1.5707690715789795,10000,74203.20285248756,0.8538476228713989,0.5499022006988525,0.7615799903869629,0.9385902881622314,50000 -5701.631700754166,7.805613040924072,68943.2842502594,155085,0,68943.2842502594,0.6333000063896179,1.566225528717041,10000,74660.95222091675,0.8555077910423279,0.5564358234405518,0.759880006313324,0.946969985961914,50000 -5737.818987846375,7.859495878219604,69363.431671381,156031,0,69363.431671381,0.6373000144958496,1.5529048442840576,10000,75117.39098858833,0.8609960675239563,0.5209859609603882,0.7628200054168701,0.931879997253418,50000 -5772.52631187439,7.919899225234985,69783.68109440804,156974,0,69783.68109440804,0.6371000409126282,1.5472129583358765,10000,75572.45741295815,0.85804682970047,0.5311560034751892,0.764959990978241,0.92471182346344,50000 -5808.363347053528,7.96866250038147,70203.67983937263,157919,0,70203.67983937263,0.6377000212669373,1.5564005374908447,10000,76028.39196753502,0.8621679544448853,0.5378159880638123,0.7645599842071533,0.9374600052833556,50000 -5845.419222831726,8.01696515083313,70623.86025309563,158865,0,70623.86025309563,0.6474000215530396,1.5338735580444336,10000,76485.72631263733,0.8666406273841858,0.5006186962127686,0.7650799751281738,0.9160201549530028,50000 -5882.281366825104,8.068997144699097,71044.06207823753,159809,0,71044.06207823753,0.6447000503540039,1.5281864404678345,10000,76942.89191150665,0.862597644329071,0.5141065120697021,0.7669199705123901,0.9152435064315796,50000 -5919.461312532425,8.124317407608032,71464.29527115822,160754,0,71464.29527115822,0.6433000564575195,1.548334002494812,10000,77400.4092001915,0.8669140338897705,0.5135356187820435,0.7686799764633179,0.9221277236938475,50000 -5956.193717479706,8.185054779052734,71884.56301403046,161699,0,71884.56301403046,0.648900032043457,1.5213313102722168,10000,77857.51973557472,0.8700976371765137,0.5003441572189331,0.768619954586029,0.916907787322998,50000 -5992.814433336258,8.240005731582642,72304.83881092072,162642,0,72304.83881092072,0.6458000540733337,1.5222994089126587,10000,78314.5209043026,0.8746874928474426,0.4664015173912048,0.7695800065994263,0.9045939445495604,50000 -6029.993057489395,8.293435096740723,72724.85019230843,163585,0,72724.85019230843,0.6484000086784363,1.5181078910827637,10000,78771.81423521042,0.8719531297683716,0.4847078919410705,0.7728599905967712,0.8979256749153137,50000 -6066.192003488541,8.345094919204712,73144.75268936157,164529,0,73144.75268936157,0.6499000191688538,1.5122696161270142,10000,79228.01746106148,0.8699023127555847,0.48884978890419,0.7724399566650391,0.8968047499656677,50000 -6101.721790552139,8.40129041671753,73564.731341362,165471,0,73564.731341362,0.650700032711029,1.5046427249908447,10000,79683.63198709488,0.8733788728713989,0.4675703942775726,0.7732999920845032,0.886921226978302,50000 -6137.511723995209,8.457269191741943,73984.96862435341,166413,0,73984.96862435341,0.6540000438690186,1.4882982969284058,10000,80139.76555800438,0.8737499713897705,0.464886873960495,0.7749800086021423,0.881984293460846,50000 -6174.019504547119,8.509846687316895,74405.22176742554,167356,0,74405.22176742554,0.6548000574111938,1.4894704818725586,10000,80596.62842297554,0.8771874904632568,0.4618673622608185,0.7760599851608276,0.8795783519744873,50000 -6209.835191726685,8.572486162185669,74825.44824123383,168303,0,74825.44824123383,0.651900053024292,1.4991596937179563,10000,81052.783826828,0.8756640553474426,0.4663519561290741,0.7746599912643433,0.8834454417228699,50000 -6245.338510274887,8.62363052368164,75245.51100564003,169247,0,75245.51100564003,0.6554000377655029,1.493224024772644,10000,81508.45189023018,0.8783007860183716,0.4513605535030365,0.7752000093460083,0.8814749121665955,50000 -6280.529743909836,8.678744554519653,75665.74991869926,170191,0,75665.74991869926,0.65420001745224,1.4848601818084717,10000,81963.98689365387,0.8774804472923279,0.4559576213359833,0.7778799533843994,0.8701184391975403,50000 -6317.746684074402,8.737242698669434,76086.01024103165,171135,0,76086.01024103165,0.6574000120162964,1.4882923364639282,10000,82421.57224082947,0.8802539110183716,0.4446616172790527,0.7779600024223328,0.8712742924690247,50000 -6353.816187143326,8.7944974899292,76506.09647703171,172081,0,76506.09647703171,0.6577000021934509,1.4779142141342163,10000,82877.83606362343,0.8818163871765137,0.4375278651714325,0.7779200077056885,0.8668556213378906,50000 -6389.507263660431,8.855574369430542,76926.27028298378,173026,0,76926.27028298378,0.6589000225067139,1.478680968284607,10000,83333.81209087372,0.8795507550239563,0.4471313059329986,0.7780799865722656,0.8690222501754761,50000 -6424.645256996155,8.914592504501343,77346.20090389252,173974,0,77346.20090389252,0.6617000102996826,1.4729156494140625,10000,83788.990837574,0.8820312023162842,0.438140720129013,0.7789199948310852,0.8620977997779846,50000 -6461.650452852249,8.96901798248291,77766.21065545082,174916,0,77766.21065545082,0.6585000157356262,1.4759646654129028,10000,84246.11060118675,0.8839452862739563,0.4305086731910705,0.7795400023460388,0.8621959090232849,50000 -6497.523934841156,9.027230262756348,78186.27035021782,175861,0,78186.27035021782,0.6604000329971313,1.4743776321411133,10000,84702.1522629261,0.8856444954872131,0.4301820993423462,0.7799199819564819,0.8656541705131531,50000 -6532.788912773132,9.082460403442385,78606.58297896385,176806,0,78606.58297896385,0.6619000434875488,1.4712177515029907,10000,85157.835013628,0.88330078125,0.4406342208385467,0.7821799516677856,0.86235511302948,50000 -6568.6220116615295,9.137590885162354,79026.84996557236,177751,0,79026.84996557236,0.6593000292778015,1.473184585571289,10000,85614.04051232338,0.88427734375,0.4303996860980987,0.780959963798523,0.8608471155166626,50000 -6604.234881401062,9.192800283432009,79446.80137014389,178697,0,79446.80137014389,0.6617000102996826,1.4699724912643433,10000,86069.71082782745,0.8859374523162842,0.4259621798992157,0.7819199562072754,0.8566552400588989,50000 -6641.677387714386,9.251446962356567,79866.9059574604,179645,0,79866.9059574604,0.6598000526428223,1.470707654953003,10000,86527.36637544632,0.8874218463897705,0.4268486797809601,0.7813599705696106,0.8589540123939514,50000 -6675.996264457703,9.310755252838137,80287.17362117767,180589,0,80287.17362117767,0.6617000102996826,1.4652323722839355,10000,86982.06272292137,0.8857421875,0.4221574366092682,0.7830399870872498,0.8535976409912109,50000 -6711.070337772369,9.363951683044434,80707.40728902817,181530,0,80707.40728902817,0.6627000570297241,1.4637311697006226,10000,87437.47354197502,0.8855078220367432,0.4229528307914734,0.782539963722229,0.8535990118980408,50000 -6746.992145776749,9.44046401977539,81127.42118930817,182476,0,81127.42118930817,0.6619000434875488,1.4650492668151855,10000,87893.53619122505,0.8879101276397705,0.4179525971412658,0.7822999954223633,0.8548969030380249,50000 -6781.985077857971,9.50280237197876,81547.52734351158,183422,0,81547.52734351158,0.6612000465393066,1.4615788459777832,10000,88348.74827170372,0.8875195384025574,0.4133367538452148,0.7827999591827393,0.8514824509620667,50000 -6819.577335119247,9.556721925735474,81967.57227182388,184369,0,81967.57227182388,0.6614000201225281,1.4628207683563232,10000,88806.49009799957,0.8892773389816284,0.4150594472885132,0.7831400036811829,0.8529243469238281,50000 -6855.812837123871,9.610158681869509,82387.84855127335,185318,0,82387.84855127335,0.6621000170707703,1.4614572525024414,10000,89263.10540890694,0.8887499570846558,0.4133415520191192,0.7833399772644043,0.8519114851951599,50000 -6892.20347237587,9.669806718826294,82808.1642510891,186263,0,82808.1642510891,0.6614000201225281,1.4626858234405518,10000,89719.92125058174,0.88880854845047,0.4150834083557129,0.7832399606704712,0.8526034355163574,50000 -6928.669387102127,9.730186939239502,83228.10613918304,187209,0,83228.10613918304,0.6615000367164612,1.4626646041870115,10000,90176.43953990936,0.8882226347923279,0.4158450961112976,0.7831799983978271,0.8526236414909363,50000 -6964.671956300736,9.78453016281128,83648.2685251236,188153,0,83648.2685251236,0.6615000367164612,1.4626646041870115,10000,90632.70875310898,0.8866796493530273,0.4240669012069702,0.7831799983978271,0.8526236414909363,50000 -7001.998011350632,9.83876132965088,84068.47924923897,189099,0,84068.47924923897,0.6615000367164612,1.4626646041870115,10000,91090.34984946252,0.8877343535423279,0.4180521368980407,0.7831799983978271,0.8526236414909363,50000 -7036.250737190247,9.89790678024292,84488.68370914459,190046,0,84488.68370914459,0.6615000367164612,1.4626646041870115,10000,91544.91623663902,0.8861523270606995,0.4211257696151733,0.7831799983978271,0.8526236414909363,50000 -7071.985706090927,9.954696655273438,84909.03533935547,190988,0,84909.03533935547,0.6615000367164612,1.4626646041870115,10000,92001.10895776749,0.8909765481948853,0.413134753704071,0.7831799983978271,0.8526236414909363,50000 -7107.835638523102,10.013721942901611,85329.32217812538,191933,0,85329.32217812538,0.6615000367164612,1.4626646041870115,10000,92457.35536766052,0.8879101276397705,0.4145359992980957,0.7831799983978271,0.8526236414909363,50000 -7141.772822141647,10.06668186187744,85749.28629517555,192879,0,85749.28629517555,0.6615000367164612,1.4626646041870115,10000,92911.35882663728,0.8848242163658142,0.4299986958503723,0.7831799983978271,0.8526236414909363,50000 -7176.942252874374,10.127224683761597,86169.27835011482,193825,0,86169.27835011482,0.6615000367164612,1.4626646041870115,10000,93366.63113760948,0.8856444954872131,0.4275977611541748,0.7831799983978271,0.8526236414909363,50000 -7213.204705715179,10.184736251831056,86589.32567048073,194769,0,86589.32567048073,0.6615000367164612,1.4626646041870115,10000,93823.04841327669,0.8881444931030273,0.4168613851070404,0.7831799983978271,0.8526236414909363,50000 -7248.850476980209,10.240220785140991,87009.27286624908,195714,0,87009.27286624908,0.6615000367164612,1.4626646041870115,10000,94278.746717453,0.88720703125,0.4198476374149322,0.7831799983978271,0.8526236414909363,50000 -7285.001372337341,10.306553363800049,87429.595079422,196662,0,87429.595079422,0.6615000367164612,1.4626646041870115,10000,94735.3364572525,0.8884570002555847,0.4169119894504547,0.7831799983978271,0.8526236414909363,50000 -7319.191005468368,10.364394903182983,87849.86155200005,197609,0,87849.86155200005,0.6615000367164612,1.4626646041870115,10000,95189.90076732635,0.8848632574081421,0.426829069852829,0.7831799983978271,0.8526236414909363,50000 -7355.1042511463165,10.42365837097168,88270.01970601082,198555,0,88270.01970601082,0.6615000367164612,1.4626646041870115,10000,95646.08177280426,0.8874609470367432,0.4123820066452026,0.7831799983978271,0.8526236414909363,50000 -7390.73459148407,10.484740495681764,88690.64194583893,199501,0,88690.64194583893,0.6615000367164612,1.4626646041870115,10000,96102.44577240944,0.8875781297683716,0.4194918870925903,0.7831799983978271,0.8526236414909363,50000 -7427.6580538749695,10.54875898361206,89110.59979486465,200445,0,89110.59979486465,0.6615000367164612,1.4626646041870115,10000,96559.44136738776,0.8860741853713989,0.4245986938476562,0.7831799983978271,0.8526236414909363,50000 -7462.471394062042,10.607412338256836,89530.50890040398,201391,0,89530.50890040398,0.6615000367164612,1.4626646041870115,10000,97014.27209377287,0.8870702981948853,0.4246867001056671,0.7831799983978271,0.8526236414909363,50000 -7499.333806037903,10.667636156082152,89950.66359424591,202334,0,89950.66359424591,0.6615000367164612,1.4626646041870115,10000,97471.39986658096,0.8855664134025574,0.427189439535141,0.7831799983978271,0.8526236414909363,50000 -7534.678661584854,10.743087768554688,90370.86471438408,203278,0,90370.86471438408,0.6615000367164612,1.4626646041870115,10000,97927.07135868073,0.8870702981948853,0.4194826483726501,0.7831799983978271,0.8526236414909363,50000 -7570.669605970383,10.802512645721436,90790.99326586723,204221,0,90790.99326586723,0.6615000367164612,1.4626646041870115,10000,98383.30153632164,0.8884961009025574,0.4168415665626526,0.7831799983978271,0.8526236414909363,50000 -7607.05656671524,10.85942006111145,91211.18811535837,205166,0,91211.18811535837,0.6615000367164612,1.4626646041870115,10000,98839.98974704742,0.8862499594688416,0.4210224449634552,0.7831799983978271,0.8526236414909363,50000 -7644.656408786774,10.91633415222168,91631.25634765624,206112,0,91631.25634765624,0.6615000367164612,1.4626646041870115,10000,99297.76484441756,0.8875976204872131,0.4213772118091583,0.7831799983978271,0.8526236414909363,50000 -7678.900659561157,10.983742237091064,92051.31888103484,207056,0,92051.31888103484,0.6615000367164612,1.4626646041870115,10000,99752.18914794922,0.8889257907867432,0.4079552888870239,0.7831799983978271,0.8526236414909363,50000 -7714.484362125397,11.040274143218994,92471.98997926712,208001,0,92471.98997926712,0.6615000367164612,1.4626646041870115,10000,100208.54979085922,0.8864452838897705,0.425664484500885,0.7831799983978271,0.8526236414909363,50000 -7750.031948566437,11.0984628200531,92892.02937984468,208947,0,92892.02937984468,0.6615000367164612,1.4626646041870115,10000,100664.24516439438,0.8887109160423279,0.4154536426067352,0.7831799983978271,0.8526236414909363,50000 -7786.632411718368,11.160149812698364,93312.06176161766,209894,0,93312.06176161766,0.6615000367164612,1.4626646041870115,10000,101120.9896683693,0.8918554782867432,0.4047467410564422,0.7831799983978271,0.8526236414909363,50000 -7823.209892272949,11.221330165863035,93732.22330355644,210839,0,93732.22330355644,0.6615000367164612,1.4626646041870115,10000,101577.8395473957,0.8872460722923279,0.4177339971065521,0.7831799983978271,0.8526236414909363,50000 -7858.10614824295,11.293134450912476,94152.4714114666,211784,0,94152.4714114666,0.6615000367164612,1.4626646041870115,10000,102033.10613369942,0.8871093392372131,0.4195844233036041,0.7831799983978271,0.8526236414909363,50000 -7893.675740480423,11.354552745819092,94572.54891109468,212729,0,94572.54891109468,0.6615000367164612,1.4626646041870115,10000,102488.86503720284,0.8864257335662842,0.4255499839782715,0.7831799983978271,0.8526236414909363,50000 -7930.459398984909,11.415436506271362,94992.61476254465,213672,0,94992.61476254465,0.6615000367164612,1.4626646041870115,10000,102945.8255429268,0.8885546922683716,0.4147121608257293,0.7831799983978271,0.8526236414909363,50000 -7966.65424156189,11.472744464874268,95412.57948756218,214616,0,95412.57948756218,0.6615000367164612,1.4626646041870115,10000,103402.09250640868,0.8885155916213989,0.419483482837677,0.7831799983978271,0.8526236414909363,50000 -8002.676563978195,11.544119596481323,95832.54670524596,215562,0,95832.54670524596,0.6615000367164612,1.4626646041870115,10000,103858.20341420174,0.8874413967132568,0.417631447315216,0.7831799983978271,0.8526236414909363,50000 -8030.116549730301,11.604514122009276,96252.98335146904,216482,0,96252.98335146904,0.6615000367164612,1.4626646041870115,10000,104306.19088602066,0.8871679306030273,0.4172095358371734,0.7831799983978271,0.8526236414909363,50000 -8069.562195062637,11.770461082458496,96673.04321050644,217388,0,96673.04321050644,0.6615000367164612,1.4626646041870115,10000,104765.90989851952,0.8858984112739563,0.4288398623466491,0.7831799983978271,0.8526236414909363,50000 -8106.868925571442,11.844651222229004,97093.0073914528,218333,0,97093.0073914528,0.6615000367164612,1.4626646041870115,10000,105223.30492758752,0.8865429759025574,0.4225581288337707,0.7831799983978271,0.8526236414909363,50000 -8145.163213729858,11.905775547027588,97513.31821966173,219280,0,97513.31821966173,0.6615000367164612,1.4626646041870115,10000,105682.02164626122,0.8845507502555847,0.4251365661621094,0.7831799983978271,0.8526236414909363,50000 -8178.272148132324,11.980994701385498,97933.19587111472,220224,0,97933.19587111472,0.6615000367164612,1.4626646041870115,10000,106135.13363027573,0.8892187476158142,0.4172480404376983,0.7831799983978271,0.8526236414909363,50000 -8214.198472261429,12.051104068756104,98353.35867023468,221168,0,98353.35867023468,0.6615000367164612,1.4626646041870115,10000,106591.34285092354,0.8871093392372131,0.4200824499130249,0.7831799983978271,0.8526236414909363,50000 -8250.0584192276,12.109280109405518,98773.69676375388,222114,0,98773.69676375388,0.6615000367164612,1.4626646041870115,10000,107047.64982414246,0.8859765529632568,0.421042799949646,0.7831799983978271,0.8526236414909363,50000 -8285.833389759064,12.175068616867064,99193.9378619194,223059,0,99193.9378619194,0.6615000367164612,1.4626646041870115,10000,107503.78232359886,0.8874609470367432,0.416455864906311,0.7831799983978271,0.8526236414909363,50000 -8321.404640674591,12.237873315811155,99614.28044009209,224006,0,99614.28044009209,0.6615000367164612,1.4626646041870115,10000,107959.80958008766,0.8888476490974426,0.4210689067840576,0.7831799983978271,0.8526236414909363,50000 -8357.846234321594,12.309117078781128,100034.37481594086,224950,0,100034.37481594086,0.6615000367164612,1.4626646041870115,10000,108416.46704983713,0.8860937356948853,0.4242303669452667,0.7831799983978271,0.8526236414909363,50000 -8393.42155122757,12.37180781364441,100454.41598153114,225894,0,100454.41598153114,0.6615000367164612,1.4626646041870115,10000,108872.19716739656,0.8856640458106995,0.4249935150146484,0.7831799983978271,0.8526236414909363,50000 -8430.878259420395,12.438353300094604,100874.35620498656,226838,0,100874.35620498656,0.6615000367164612,1.4626646041870115,10000,109329.71062541008,0.8877148032188416,0.4187945425510406,0.7831799983978271,0.8526236414909363,50000 -8468.879625320435,12.502143383026125,101294.32449269296,227782,0,101294.32449269296,0.6615000367164612,1.4626646041870115,10000,109787.79367494585,0.8872265219688416,0.4204618036746979,0.7831799983978271,0.8526236414909363,50000 -8507.651833295822,12.566620111465454,101714.59797477722,228723,0,101714.59797477722,0.6615000367164612,1.4626646041870115,10000,110246.95320534706,0.8875585794448853,0.4156016409397125,0.7831799983978271,0.8526236414909363,50000 -8545.424136638641,12.626487016677856,102134.6535089016,229661,0,102134.6535089016,0.6615000367164612,1.4626646041870115,10000,110704.89049220084,0.8867773413658142,0.4198257625102997,0.7831799983978271,0.8526236414909363,50000 -8584.81552362442,12.685871601104736,102554.6896674633,230604,0,102554.6896674633,0.6615000367164612,1.4626646041870115,10000,111164.42679929732,0.8856054544448853,0.4232231080532074,0.7831799983978271,0.8526236414909363,50000 -8620.59669971466,12.754069089889526,102974.65440821648,231547,0,102974.65440821648,0.6615000367164612,1.4626646041870115,10000,111620.29118657112,0.8893945217132568,0.4144158065319061,0.7831799983978271,0.8526236414909363,50000 -8658.708618879318,12.814483880996704,103394.6379442215,232490,0,103394.6379442215,0.6615000367164612,1.4626646041870115,10000,112078.49687552452,0.8886327743530273,0.41773721575737,0.7831799983978271,0.8526236414909363,50000 -8692.565080881119,12.884013175964355,103814.5766415596,233432,0,103814.5766415596,0.6615000367164612,1.4626646041870115,10000,112532.41099596024,0.8906640410423279,0.4075176119804382,0.7831799983978271,0.8526236414909363,50000 -8728.871542215347,12.946634531021118,104234.8699195385,234375,0,104234.8699195385,0.6615000367164612,1.4626646041870115,10000,112989.12427139282,0.8890429735183716,0.4111610949039459,0.7831799983978271,0.8526236414909363,50000 -8766.155029296875,13.006984949111938,104655.03619599342,235319,0,104655.03619599342,0.6615000367164612,1.4626646041870115,10000,113446.68429636957,0.8880859017372131,0.4207883477210998,0.7831799983978271,0.8526236414909363,50000 -8802.531280994415,13.073152780532835,105075.3451757431,236262,0,105075.3451757431,0.6615000367164612,1.4626646041870115,10000,113903.48621582983,0.88671875,0.4218221306800842,0.7831799983978271,0.8526236414909363,50000 -8838.775473117828,13.146722555160522,105495.2727303505,237207,0,105495.2727303505,0.6615000367164612,1.4626646041870115,10000,114359.78117632866,0.8869531154632568,0.4194927811622619,0.7831799983978271,0.8526236414909363,50000 -8873.482074022293,13.212722063064575,105915.4233739376,238150,0,105915.4233739376,0.6615000367164612,1.4626646041870115,10000,114814.7551767826,0.8891406059265137,0.4173803925514221,0.7831799983978271,0.8526236414909363,50000 -8909.238534212112,13.272995471954346,106335.55676198006,239094,0,106335.55676198006,0.6615000367164612,1.4626646041870115,10000,115270.7553293705,0.8861132860183716,0.4208251237869262,0.7831799983978271,0.8526236414909363,50000 -8945.806883335114,13.338412046432495,106755.64887499808,240041,0,106755.64887499808,0.6615000367164612,1.4626646041870115,10000,115727.53117990494,0.8859374523162842,0.4255052208900451,0.7831799983978271,0.8526236414909363,50000 -8982.179713249207,13.405362844467165,107175.83440184592,240984,0,107175.83440184592,0.6615000367164612,1.4626646041870115,10000,116184.20617771149,0.88685542345047,0.4231062233448028,0.7831799983978271,0.8526236414909363,50000 -9018.915987968445,13.470390558242798,107595.84573626518,241927,0,107595.84573626518,0.6615000367164612,1.4626646041870115,10000,116641.06905651093,0.8875195384025574,0.4177833497524261,0.7831799983978271,0.8526236414909363,50000 -9054.184888839722,13.544833660125732,108016.01560282709,242873,0,108016.01560282709,0.6615000367164612,1.4626646041870115,10000,117096.63333463667,0.8873632550239563,0.4211326837539673,0.7831799983978271,0.8526236414909363,50000 -9092.124347448347,13.611082077026367,108436.21713781355,243816,0,108436.21713781355,0.6615000367164612,1.4626646041870115,10000,117554.89075517654,0.8863866925239563,0.4213360548019409,0.7831799983978271,0.8526236414909363,50000 -9130.800177574158,13.678319692611694,108856.1207022667,244760,0,108856.1207022667,0.6615000367164612,1.4626646041870115,10000,118013.58737802504,0.8868163824081421,0.4208961427211761,0.7831799983978271,0.8526236414909363,50000 -9166.68828845024,13.749892473220823,109276.05662918092,245704,0,109276.05662918092,0.6615000367164612,1.4626646041870115,10000,118469.53283405304,0.8855273127555847,0.4221659004688263,0.7831799983978271,0.8526236414909363,50000 -9201.703897237778,13.811155080795288,109696.36248373984,246648,0,109696.36248373984,0.6615000367164612,1.4626646041870115,10000,118924.96544003488,0.8881054520606995,0.4146759212017059,0.7831799983978271,0.8526236414909363,50000 -9238.741772651672,13.876649856567385,110116.45049238204,247594,0,110116.45049238204,0.6615000367164612,1.4626646041870115,10000,119382.20756483078,0.8883398175239563,0.4226095378398895,0.7831799983978271,0.8526236414909363,50000 -9276.15583205223,13.948917627334597,110536.77914857864,248539,0,110536.77914857864,0.6615000367164612,1.4626646041870115,10000,119840.07347083092,0.8856640458106995,0.4264017343521118,0.7831799983978271,0.8526236414909363,50000 -9315.853289604189,14.01637315750122,110956.83072423936,249482,0,110956.83072423936,0.6615000367164612,1.4626646041870115,10000,120299.9406888485,0.8876757621765137,0.4183971285820007,0.7831799983978271,0.8526236414909363,50000 -9351.783564567566,14.083440780639648,111376.93041610718,250426,0,111376.93041610718,0.6615000367164612,1.4626646041870115,10000,120756.08883309364,0.8869531154632568,0.4247783422470093,0.7831799983978271,0.8526236414909363,50000 -9389.48131465912,14.150177240371704,111796.849973917,251370,0,111796.849973917,0.6615000367164612,1.4626646041870115,10000,121213.8227274418,0.8869140148162842,0.4152418375015259,0.7831799983978271,0.8526236414909363,50000 -9427.63720202446,14.223804235458374,112217.03732776642,252315,0,112217.03732776642,0.6615000367164612,1.4626646041870115,10000,121672.28964018822,0.8876757621765137,0.4186638295650482,0.7831799983978271,0.8526236414909363,50000 -9462.538872718813,14.292896032333374,112637.292222023,253259,0,112637.292222023,0.6615000367164612,1.4626646041870115,10000,122127.56577944756,0.88623046875,0.4246796667575836,0.7831799983978271,0.8526236414909363,50000 -9500.291576862335,14.35829257965088,113057.61286640169,254205,0,113057.61286640169,0.6615000367164612,1.4626646041870115,10000,122585.75490498544,0.8882226347923279,0.414854496717453,0.7831799983978271,0.8526236414909363,50000 -9538.537852048874,14.424855947494509,113477.76876091956,255152,0,113477.76876091956,0.6615000367164612,1.4626646041870115,10000,123044.27444005013,0.8879101276397705,0.4163959324359894,0.7831799983978271,0.8526236414909363,50000 -9575.771821022034,14.490861177444458,113897.86142706872,256096,0,113897.86142706872,0.6615000367164612,1.4626646041870115,10000,123501.7170226574,0.8895312547683716,0.4140346050262451,0.7831799983978271,0.8526236414909363,50000 -9612.654847860336,14.572320699691772,114317.87077498436,257040,0,114317.87077498436,0.6615000367164612,1.4626646041870115,10000,123958.7415716648,0.8866796493530273,0.4149068593978882,0.7831799983978271,0.8526236414909363,50000 -9648.057507038116,14.651172876358032,114738.04656410216,257985,0,114738.04656410216,0.6615000367164612,1.4626646041870115,10000,124414.4491698742,0.8883788585662842,0.4164501130580902,0.7831799983978271,0.8526236414909363,50000 -9686.863124847412,14.71579933166504,115158.09628129004,258931,0,115158.09628129004,0.6615000367164612,1.4626646041870115,10000,124873.41975784302,0.8895312547683716,0.4190482199192047,0.7831799983978271,0.8526236414909363,50000 -9723.382348060608,14.781482458114624,115578.23216414452,259877,0,115578.23216414452,0.6615000367164612,1.4626646041870115,10000,125330.19090342522,0.8865820169448853,0.4202475249767303,0.7831799983978271,0.8526236414909363,50000 -9759.420825719832,14.854071378707886,115998.11412405968,260822,0,115998.11412405968,0.6615000367164612,1.4626646041870115,10000,125786.23392367364,0.8883007764816284,0.418276309967041,0.7831799983978271,0.8526236414909363,50000 -9797.168113470078,14.92531681060791,116418.11951112749,261767,0,116418.11951112749,0.6615000367164612,1.4626646041870115,10000,126244.10726952551,0.8883788585662842,0.4182990491390228,0.7831799983978271,0.8526236414909363,50000 -9834.422104597092,14.993987560272217,116838.13629627228,262712,0,116838.13629627228,0.6615000367164612,1.4626646041870115,10000,126701.49708580972,0.8886132836341858,0.4173754751682281,0.7831799983978271,0.8526236414909363,50000 -9873.68555045128,15.062115669250488,117258.14576864244,263657,0,117258.14576864244,0.6615000367164612,1.4626646041870115,10000,127160.88830947876,0.8840234279632568,0.4240234792232513,0.7831799983978271,0.8526236414909363,50000 -9911.598610162737,15.131271123886108,117678.73268508913,264601,0,117678.73268508913,0.6615000367164612,1.4626646041870115,10000,127619.50759148598,0.8874804377555847,0.4231193959712982,0.7831799983978271,0.8526236414909363,50000 -9944.990770578384,15.20097303390503,118099.01088500024,265547,0,118099.01088500024,0.6615000367164612,1.4626646041870115,10000,128073.29778432846,0.8855078220367432,0.4264200925827026,0.7831799983978271,0.8526236414909363,50000 -9981.062950611116,15.270276546478271,118518.96572709084,266490,0,118518.96572709084,0.6615000367164612,1.4626646041870115,10000,128529.44418168068,0.8890234231948853,0.4168081283569336,0.7831799983978271,0.8526236414909363,50000 -10018.03050661087,15.34412407875061,118938.91685509682,267435,0,118938.91685509682,0.6615000367164612,1.4626646041870115,10000,128986.48639082909,0.8855664134025574,0.4233571588993072,0.7831799983978271,0.8526236414909363,50000 -10054.78897356987,15.412813425064089,119359.17529034616,268381,0,119359.17529034616,0.6615000367164612,1.4626646041870115,10000,129443.6222999096,0.8865038752555847,0.4210705161094665,0.7831799983978271,0.8526236414909363,50000 -10092.251964330671,15.483604669570925,119779.28468871117,269326,0,119779.28468871117,0.6615000367164612,1.4626646041870115,10000,129901.316116333,0.8862109184265137,0.4198490381240845,0.7831799983978271,0.8526236414909363,50000 -10128.442929267883,15.552424669265749,120199.18902301788,270269,0,120199.18902301788,0.6615000367164612,1.4626646041870115,10000,130357.5304300785,0.8884961009025574,0.4142680764198303,0.7831799983978271,0.8526236414909363,50000 -10167.75212931633,15.631080865859984,120619.08997631072,271214,0,120619.08997631072,0.6615000367164612,1.4626646041870115,10000,130816.87016510963,0.8880664110183716,0.4235153794288635,0.7831799983978271,0.8526236414909363,50000 -10205.575210809708,15.698053359985352,121039.39338994026,272159,0,121039.39338994026,0.6615000367164612,1.4626646041870115,10000,131275.11368703842,0.887011706829071,0.4199972152709961,0.7831799983978271,0.8526236414909363,50000 -10240.997619867325,15.770262479782104,121459.28456163406,273101,0,121459.28456163406,0.6615000367164612,1.4626646041870115,10000,131730.5489742756,0.8873242139816284,0.4234861135482788,0.7831799983978271,0.8526236414909363,50000 -10277.346205711365,15.83849048614502,121879.23887968063,274042,0,121879.23887968063,0.6615000367164612,1.4626646041870115,10000,132186.96950387955,0.8865429759025574,0.4222384989261627,0.7831799983978271,0.8526236414909363,50000 -10315.125161886215,15.915863037109377,122299.38060212135,274987,0,122299.38060212135,0.6615000367164612,1.4626646041870115,10000,132645.01781749725,0.8861327767372131,0.4210696220397949,0.7831799983978271,0.8526236414909363,50000 -10351.85710811615,15.985002279281616,122719.44203877448,275933,0,122719.44203877448,0.6615000367164612,1.4626646041870115,10000,133101.93027758598,0.8864843845367432,0.4235370457172394,0.7831799983978271,0.8526236414909363,50000 -10390.764615297318,16.054569482803345,123139.42353534698,276878,0,123139.42353534698,0.6615000367164612,1.4626646041870115,10000,133560.93928790092,0.8878905773162842,0.4141671657562256,0.7831799983978271,0.8526236414909363,50000 -10426.705643177032,16.121344804763794,123559.36771154404,277822,0,123559.36771154404,0.6615000367164612,1.4626646041870115,10000,134016.9411432743,0.8886327743530273,0.4123508036136627,0.7831799983978271,0.8526236414909363,50000 -10462.081463813782,16.189631462097168,123980.14164686204,278767,0,123980.14164686204,0.6615000367164612,1.4626646041870115,10000,134473.20946788788,0.8878515362739563,0.4221333265304565,0.7831799983978271,0.8526236414909363,50000 -10498.109572172163,16.257033348083496,124400.19126391412,279710,0,124400.19126391412,0.6615000367164612,1.4626646041870115,10000,134929.40495729446,0.8892382383346558,0.4132129549980163,0.7831799983978271,0.8526236414909363,50000 -10536.577868461609,16.33836817741394,124820.24197268486,280653,0,124820.24197268486,0.6615000367164612,1.4626646041870115,10000,135388.0551865101,0.8885741829872131,0.4136310815811157,0.7831799983978271,0.8526236414909363,50000 -10573.71326136589,16.405822038650513,125240.18860721588,281595,0,125240.18860721588,0.6615000367164612,1.4626646041870115,10000,135845.2536330223,0.8893163800239563,0.4133542776107788,0.7831799983978271,0.8526236414909363,50000 -10612.752341747284,16.47682547569275,125660.50177502632,282542,0,125660.50177502632,0.6615000367164612,1.4626646041870115,10000,136304.7267510891,0.8870312571525574,0.4205176830291748,0.7831799983978271,0.8526236414909363,50000 -10650.505030870438,16.544782161712646,126080.76248979568,283487,0,126080.76248979568,0.6615000367164612,1.4626646041870115,10000,136762.85723400116,0.8861523270606995,0.4218646585941314,0.7831799983978271,0.8526236414909363,50000 -10686.965690135956,16.61560034751892,126500.92922019958,284431,0,126500.92922019958,0.6615000367164612,1.4626646041870115,10000,137219.60481786728,0.8878515362739563,0.419499933719635,0.7831799983978271,0.8526236414909363,50000 -10721.06535768509,16.69960618019104,126920.93151402472,285377,0,126920.93151402472,0.6615000367164612,1.4626646041870115,10000,137673.84118700027,0.8895312547683716,0.4159430265426636,0.7831799983978271,0.8526236414909363,50000 -10760.106906414032,16.767632246017456,127340.82446813583,286321,0,127340.82446813583,0.6615000367164612,1.4626646041870115,10000,138132.89316129684,0.88818359375,0.4206604659557342,0.7831799983978271,0.8526236414909363,50000 -10797.100801467896,16.835665464401245,127761.01388812064,287265,0,127761.01388812064,0.6615000367164612,1.4626646041870115,10000,138590.194170475,0.8854101300239563,0.4235857725143432,0.7831799983978271,0.8526236414909363,50000 -10835.234018087389,17.169645071029663,128181.00546360016,288208,0,128181.00546360016,0.6615000367164612,1.4626646041870115,10000,139048.703029871,0.8864648342132568,0.4208704531192779,0.7831799983978271,0.8526236414909363,50000 -10874.292521953585,17.23930835723877,128601.05171775818,289147,0,128601.05171775818,0.6615000367164612,1.4626646041870115,10000,139507.9273519516,0.8867773413658142,0.4219215214252472,0.7831799983978271,0.8526236414909363,50000 -10910.316421985626,17.317003965377808,129021.26099348068,290090,0,129021.26099348068,0.6615000367164612,1.4626646041870115,10000,139964.28707146645,0.88734370470047,0.412392109632492,0.7831799983978271,0.8526236414909363,50000 -10946.443460464478,17.38748025894165,129441.24899697304,291033,0,129441.24899697304,0.6615000367164612,1.4626646041870115,10000,140420.52204823494,0.8869140148162842,0.4278862178325653,0.7831799983978271,0.8526236414909363,50000 -10983.784052610396,17.47106671333313,129861.1802175045,291975,0,129861.1802175045,0.6615000367164612,1.4626646041870115,10000,140877.92724633217,0.885058581829071,0.4274966716766357,0.7831799983978271,0.8526236414909363,50000 -11018.251610517502,17.548232555389404,130281.11374497414,292918,0,130281.11374497414,0.6615000367164612,1.4626646041870115,10000,141332.45461821556,0.8882812261581421,0.414492130279541,0.7831799983978271,0.8526236414909363,50000 -11055.229821681976,17.617226600646973,130701.35450816154,293862,0,130701.35450816154,0.6615000367164612,1.4626646041870115,10000,141789.79231905937,0.8871679306030273,0.4163171350955963,0.7831799983978271,0.8526236414909363,50000 -11091.266440868378,17.695829153060913,131121.62486243248,294807,0,131121.62486243248,0.6615000367164612,1.4626646041870115,10000,142246.22771835327,0.8866991996765137,0.425510048866272,0.7831799983978271,0.8526236414909363,50000 -11126.980338573456,17.765344858169556,131541.72653746605,295748,0,131541.72653746605,0.6615000367164612,1.4626646041870115,10000,142702.16307020187,0.8882812261581421,0.4199662804603576,0.7831799983978271,0.8526236414909363,50000 -11165.005978107452,17.840155839920044,131961.94080877304,296690,0,131961.94080877304,0.6615000367164612,1.4626646041870115,10000,143160.52719926834,0.8873828053474426,0.4187077581882477,0.7831799983978271,0.8526236414909363,50000 -11199.285538434982,17.914023637771606,132382.2102355957,297636,0,132382.2102355957,0.6615000367164612,1.4626646041870115,10000,143615.20072078705,0.8857226371765137,0.4250858128070831,0.7831799983978271,0.8526236414909363,50000 -11238.506650447844,17.98520803451538,132802.52157902718,298580,0,132802.52157902718,0.6615000367164612,1.4626646041870115,10000,144074.85414624214,0.8873046636581421,0.4215691983699798,0.7831799983978271,0.8526236414909363,50000 -11273.791742801666,18.054905891418457,133222.74796295166,299522,0,133222.74796295166,0.6615000367164612,1.4626646041870115,10000,144530.4835705757,0.8869335651397705,0.4194314777851105,0.7831799983978271,0.8526236414909363,50000 -11310.786287784576,18.12456536293029,133642.82978200912,300467,0,133642.82978200912,0.6615000367164612,1.4626646041870115,10000,144987.6792693138,0.8865820169448853,0.4196713864803314,0.7831799983978271,0.8526236414909363,50000 -11346.413735866548,18.19742178916931,134063.15568971634,301412,0,134063.15568971634,0.6615000367164612,1.4626646041870115,10000,145443.7553141117,0.8876953125,0.4157899022102356,0.7831799983978271,0.8526236414909363,50000 -11383.800921678543,18.267784357070923,134483.15005874634,302358,0,134483.15005874634,0.6615000367164612,1.4626646041870115,10000,145901.2564651966,0.8873242139816284,0.4200023412704468,0.7831799983978271,0.8526236414909363,50000 -11419.779727935793,18.338664531707764,134903.26118922234,303303,0,134903.26118922234,0.6615000367164612,1.4626646041870115,10000,146357.46699857712,0.8898828029632568,0.4121365249156952,0.7831799983978271,0.8526236414909363,50000 -11458.216473817823,18.40903115272522,135323.39578986168,304248,0,135323.39578986168,0.6615000367164612,1.4626646041870115,10000,146816.15788459778,0.8899218440055847,0.4108898341655731,0.7831799983978271,0.8526236414909363,50000 -11492.684744119644,18.48048448562622,135743.32669734955,305193,0,135743.32669734955,0.6615000367164612,1.4626646041870115,10000,147270.67830872536,0.8875976204872131,0.416908711194992,0.7831799983978271,0.8526236414909363,50000 -11529.90760731697,18.560980796813965,136163.254529953,306138,0,136163.254529953,0.6615000367164612,1.4626646041870115,10000,147727.958391428,0.8876953125,0.4212630093097687,0.7831799983978271,0.8526236414909363,50000 -11565.154380083084,18.63020730018616,136583.568918705,307081,0,136583.568918705,0.6615000367164612,1.4626646041870115,10000,148183.63747787476,0.8867382407188416,0.4187721908092499,0.7831799983978271,0.8526236414909363,50000 -11602.423706293106,18.709379196166992,137003.82122135162,308026,0,137003.82122135162,0.6615000367164612,1.4626646041870115,10000,148641.28811311722,0.8872265219688416,0.4194600582122803,0.7831799983978271,0.8526236414909363,50000 -11637.83688402176,18.780176162719727,137423.9591574669,308970,0,137423.9591574669,0.6615000367164612,1.4626646041870115,10000,149096.95966887474,0.8910546898841858,0.4130688905715942,0.7831799983978271,0.8526236414909363,50000 -11673.974525928495,18.853289127349854,137844.01475334167,309911,0,137844.01475334167,0.6615000367164612,1.4626646041870115,10000,149553.27523708344,0.8875585794448853,0.4212709665298462,0.7831799983978271,0.8526236414909363,50000 -11709.985116958618,18.925705671310425,138264.2561571598,310856,0,138264.2561571598,0.6615000367164612,1.4626646041870115,10000,150009.64900398254,0.8844921588897705,0.429676204919815,0.7831799983978271,0.8526236414909363,50000 -11748.723598957062,19.001494646072388,138684.43855118752,311800,0,138684.43855118752,0.6615000367164612,1.4626646041870115,10000,150468.6958978176,0.88685542345047,0.4198791682720184,0.7831799983978271,0.8526236414909363,50000 -11786.684258937836,19.078715562820435,139104.31449127197,312743,0,139104.31449127197,0.6615000367164612,1.4626646041870115,10000,150926.6590127945,0.8872265219688416,0.4195569157600403,0.7831799983978271,0.8526236414909363,50000 -11824.821381092072,19.15205764770508,139524.47371315956,313689,0,139524.47371315956,0.6615000367164612,1.4626646041870115,10000,151385.07851743698,0.8866406083106995,0.4205119907855987,0.7831799983978271,0.8526236414909363,50000 -11861.724688053131,19.22867512702942,139944.40990495682,314633,0,139944.40990495682,0.6615000367164612,1.4626646041870115,10000,151842.04384183884,0.8891210556030273,0.4166163206100464,0.7831799983978271,0.8526236414909363,50000 -11897.955362796783,19.30998063087464,140364.45055365562,315578,0,140364.45055365562,0.6615000367164612,1.4626646041870115,10000,152298.446028471,0.8840429782867432,0.4275960028171539,0.7831799983978271,0.8526236414909363,50000 -11936.230577230452,19.38427710533142,140784.50740408897,316522,0,140784.50740408897,0.6615000367164612,1.4626646041870115,10000,152756.90191817284,0.8870507478713989,0.4170957505702972,0.7831799983978271,0.8526236414909363,50000 -11971.0127120018,19.45900011062622,141204.74587631226,317468,0,141204.74587631226,0.6615000367164612,1.4626646041870115,10000,153212.0466427803,0.8878124952316284,0.4161619842052459,0.7831799983978271,0.8526236414909363,50000 -12007.916976690292,19.53209972381592,141624.99427366257,318414,0,141624.99427366257,0.6615000367164612,1.4626646041870115,10000,153669.32224822044,0.8857226371765137,0.4262997210025787,0.7831799983978271,0.8526236414909363,50000 -12046.129980802536,19.61472868919373,142045.31981515884,319359,0,142045.31981515884,0.6615000367164612,1.4626646041870115,10000,154127.99576354027,0.8882616758346558,0.4221481680870056,0.7831799983978271,0.8526236414909363,50000 -12081.67114663124,19.68934178352356,142465.31846928596,320301,0,142465.31846928596,0.6615000367164612,1.4626646041870115,10000,154583.66032719612,0.8875781297683716,0.4201326370239258,0.7831799983978271,0.8526236414909363,50000 -12117.643364191055,19.76137471199036,142885.55553865433,321247,0,142885.55553865433,0.6615000367164612,1.4626646041870115,10000,155039.99154281616,0.8852929472923279,0.4266625940799713,0.7831799983978271,0.8526236414909363,50000 -12153.602460861206,19.83621001243592,143305.7799217701,322195,0,143305.7799217701,0.6615000367164612,1.4626646041870115,10000,155496.300719738,0.8888476490974426,0.4134816527366638,0.7831799983978271,0.8526236414909363,50000 -12192.668316364288,19.912522077560425,143725.90061926842,323138,0,143725.90061926842,0.6615000367164612,1.4626646041870115,10000,155955.61283946037,0.8867773413658142,0.4213497340679168,0.7831799983978271,0.8526236414909363,50000 -12229.671454191208,19.987329483032227,144145.90799617767,324081,0,144145.90799617767,0.6615000367164612,1.4626646041870115,10000,156412.74842858317,0.8880664110183716,0.4195075631141662,0.7831799983978271,0.8526236414909363,50000 -12266.286154031754,20.06376004219055,144566.0894215107,325027,0,144566.0894215107,0.6615000367164612,1.4626646041870115,10000,156869.671346426,0.8875976204872131,0.4145772755146026,0.7831799983978271,0.8526236414909363,50000 -12304.35082411766,20.1387312412262,144986.1606106758,325972,0,144986.1606106758,0.6615000367164612,1.4626646041870115,10000,157327.93233895302,0.8864648342132568,0.4213025867938995,0.7831799983978271,0.8526236414909363,50000 -12343.091826438904,20.214343786239624,145406.35031175613,326914,0,145406.35031175613,0.6615000367164612,1.4626646041870115,10000,157786.98872613907,0.8908789157867432,0.4091778099536896,0.7831799983978271,0.8526236414909363,50000 -12382.091963529589,20.28935670852661,145826.32364201546,327857,0,145826.32364201546,0.6615000367164612,1.4626646041870115,10000,158246.08665442467,0.8892382383346558,0.4134601950645447,0.7831799983978271,0.8526236414909363,50000 -12418.419880151749,20.377084732055664,146246.33985328674,328801,0,146246.33985328674,0.6615000367164612,1.4626646041870115,10000,158702.5681695938,0.887499988079071,0.4138509631156921,0.7831799983978271,0.8526236414909363,50000 -12455.2824716568,20.45138263702393,146666.76261496544,329746,0,146666.76261496544,0.6615000367164612,1.4626646041870115,10000,159159.97850561142,0.8868359327316284,0.4221839010715484,0.7831799983978271,0.8526236414909363,50000 -12491.49240732193,20.52607488632202,147086.9284837246,330690,0,147086.9284837246,0.6615000367164612,1.4626646041870115,10000,159616.47861504555,0.8869531154632568,0.4258280098438263,0.7831799983978271,0.8526236414909363,50000 -12529.59518647194,20.882078886032104,147506.69037270546,331633,0,147506.69037270546,0.6615000367164612,1.4626646041870115,10000,160074.74869942665,0.8866601586341858,0.4206277728080749,0.7831799983978271,0.8526236414909363,50000 -12565.919513225555,20.957940578460693,147926.8549695015,332573,0,147926.8549695015,0.6615000367164612,1.4626646041870115,10000,160531.36344838142,0.8890624642372131,0.4169419109821319,0.7831799983978271,0.8526236414909363,50000 -12603.616605520248,21.03495407104492,148347.11112332344,333516,0,148347.11112332344,0.6615000367164612,1.4626646041870115,10000,160989.44289064407,0.8870702981948853,0.4185446500778198,0.7831799983978271,0.8526236414909363,50000 -12639.42637515068,21.11026430130005,148767.2456228733,334459,0,148767.2456228733,0.6615000367164612,1.4626646041870115,10000,161445.51134824753,0.8873242139816284,0.4183368682861328,0.7831799983978271,0.8526236414909363,50000 -12676.100577354431,21.18437361717224,149187.25118470192,335404,0,149187.25118470192,0.6615000367164612,1.4626646041870115,10000,161902.31423664093,0.8857226371765137,0.4270011782646179,0.7831799983978271,0.8526236414909363,50000 -12714.661440134048,21.27260971069336,149607.38150835037,336343,0,149607.38150835037,0.6615000367164612,1.4626646041870115,10000,162361.1449034214,0.8882421851158142,0.4184745550155639,0.7831799983978271,0.8526236414909363,50000 -12750.583944559095,21.355368614196777,150027.41157388687,337283,0,150027.41157388687,0.6615000367164612,1.4626646041870115,10000,162817.23032855988,0.8860155940055847,0.4216941297054291,0.7831799983978271,0.8526236414909363,50000 -12789.311537981032,21.43650913238525,150447.57310414314,338224,0,150447.57310414314,0.6615000367164612,1.4626646041870115,10000,163276.24970722198,0.88636714220047,0.4208191335201263,0.7831799983978271,0.8526236414909363,50000 -12826.396708726885,21.5212631225586,150867.62800478935,339168,0,150867.62800478935,0.6615000367164612,1.4626646041870115,10000,163733.52498984337,0.8886523246765137,0.4195660650730133,0.7831799983978271,0.8526236414909363,50000 -12862.9636387825,21.607585906982425,151287.56206703186,340113,0,151287.56206703186,0.6615000367164612,1.4626646041870115,10000,164190.16235351562,0.8841210603713989,0.4273837208747864,0.7831799983978271,0.8526236414909363,50000 -12900.13408112526,21.6829161643982,151707.53594970703,341059,0,151707.53594970703,0.6615000367164612,1.4626646041870115,10000,164647.4318766594,0.8888671398162842,0.4130294620990753,0.7831799983978271,0.8526236414909363,50000 -12937.41590332985,21.75795602798462,152127.7273261547,342003,0,152127.7273261547,0.6615000367164612,1.4626646041870115,10000,165105.02981305122,0.8862890601158142,0.4243661165237427,0.7831799983978271,0.8526236414909363,50000 -12974.693664312364,21.836976051330566,152547.99410772324,342947,0,152547.99410772324,0.6615000367164612,1.4626646041870115,10000,165562.70320606232,0.8872265219688416,0.4221601486206054,0.7831799983978271,0.8526236414909363,50000 -13012.170394659042,21.91489100456237,152968.00593590736,343891,0,152968.00593590736,0.6615000367164612,1.4626646041870115,10000,166020.32029771805,0.8874413967132568,0.4182123839855194,0.7831799983978271,0.8526236414909363,50000 -13049.210852384567,22.003353595733643,153388.0217819214,344835,0,153388.0217819214,0.6615000367164612,1.4626646041870115,10000,166477.5153517723,0.8876562118530273,0.4235063791275024,0.7831799983978271,0.8526236414909363,50000 -13087.304275989532,22.083884239196777,153808.1084370613,345777,0,153808.1084370613,0.6615000367164612,1.4626646041870115,10000,166935.82588744164,0.8870507478713989,0.4185346066951751,0.7831799983978271,0.8526236414909363,50000 -13125.770390748978,22.161634922027588,154228.45319128036,346721,0,154228.45319128036,0.6615000367164612,1.4626646041870115,10000,167394.76393675804,0.8866601586341858,0.4186669886112213,0.7831799983978271,0.8526236414909363,50000 -13164.504692316055,22.23713183403015,154648.5946586132,347663,0,154648.5946586132,0.6615000367164612,1.4626646041870115,10000,167853.7642493248,0.887011706829071,0.4225160777568817,0.7831799983978271,0.8526236414909363,50000 -13203.645637273788,22.316974401474,155068.51655745506,348606,0,155068.51655745506,0.6615000367164612,1.4626646041870115,10000,168312.95635533333,0.8873828053474426,0.4159050583839416,0.7831799983978271,0.8526236414909363,50000 -13244.434262275696,22.393721342086792,155488.57597899437,349551,0,155488.57597899437,0.6615000367164612,1.4626646041870115,10000,168773.9312326908,0.8894140720367432,0.4155358672142029,0.7831799983978271,0.8526236414909363,50000 -13283.12947845459,22.47111248970032,155908.78848218918,350497,0,155908.78848218918,0.6615000367164612,1.4626646041870115,10000,169232.9659898281,0.8883593678474426,0.4154931306838989,0.7831799983978271,0.8526236414909363,50000 -13321.576848506927,22.55529522895813,156328.82681155205,351441,0,156328.82681155205,0.6615000367164612,1.4626646041870115,10000,169691.5860159397,0.8911327719688416,0.4061800241470337,0.7831799983978271,0.8526236414909363,50000 -13360.859747171402,22.63696026802063,156748.8176636696,352387,0,156748.8176636696,0.6615000367164612,1.4626646041870115,10000,170150.9909837246,0.8876367211341858,0.4141596555709839,0.7831799983978271,0.8526236414909363,50000 -13397.505673408508,22.719033241271973,157168.728931427,353328,0,157168.728931427,0.6615000367164612,1.4626646041870115,10000,170607.67958974838,0.8882421851158142,0.4224283099174499,0.7831799983978271,0.8526236414909363,50000 -13436.36325597763,22.79829692840576,157588.73220658302,354273,0,157588.73220658302,0.6615000367164612,1.4626646041870115,10000,171066.66933846474,0.8841992020606995,0.4266878366470337,0.7831799983978271,0.8526236414909363,50000 -13473.909487962725,22.889212369918823,158008.9703757763,355217,0,158008.9703757763,0.6615000367164612,1.4626646041870115,10000,171524.59533452988,0.8879101276397705,0.4218570291996002,0.7831799983978271,0.8526236414909363,50000 -13512.55196905136,22.982855796813965,158429.252024889,356161,0,158429.252024889,0.6615000367164612,1.4626646041870115,10000,171983.66263103485,0.8890234231948853,0.4148597717285156,0.7831799983978271,0.8526236414909363,50000 -13553.5134370327,23.065964221954346,158849.1560485363,357101,0,158849.1560485363,0.6615000367164612,1.4626646041870115,10000,172444.66031646729,0.8884570002555847,0.4165292978286743,0.7831799983978271,0.8526236414909363,50000 -13589.560196638107,23.15531325340271,159269.13534212112,358045,0,159269.13534212112,0.6615000367164612,1.4626646041870115,10000,172900.82520484924,0.8863476514816284,0.4233624041080475,0.7831799983978271,0.8526236414909363,50000 -13626.16271162033,23.23595213890076,159689.05473470688,358988,0,159689.05473470688,0.6615000367164612,1.4626646041870115,10000,173357.47729110718,0.8848632574081421,0.427352637052536,0.7831799983978271,0.8526236414909363,50000 -13663.025684833528,23.352552890777588,160109.2667543888,359933,0,160109.2667543888,0.6615000367164612,1.4626646041870115,10000,173814.71801424026,0.8875781297683716,0.4193789064884186,0.7831799983978271,0.8526236414909363,50000 -13699.04360795021,23.43653154373169,160529.21381664276,360877,0,160529.21381664276,0.6615000367164612,1.4626646041870115,10000,174270.8171491623,0.8887499570846558,0.4147139191627502,0.7831799983978271,0.8526236414909363,50000 -13737.636802196505,23.527692556381226,160949.31827545166,361821,0,160949.31827545166,0.6615000367164612,1.4626646041870115,10000,174729.65630316734,0.8864648342132568,0.4235952198505401,0.7831799983978271,0.8526236414909363,50000 -13776.793885946274,23.607590436935425,161369.6048526764,362768,0,161369.6048526764,0.6615000367164612,1.4626646041870115,10000,175189.22968149185,0.8859570026397705,0.4194192886352539,0.7831799983978271,0.8526236414909363,50000 -13813.758406877518,23.688162088394165,161789.86298680303,363712,0,161789.86298680303,0.6615000367164612,1.4626646041870115,10000,175646.58432078362,0.8870702981948853,0.4185977876186371,0.7831799983978271,0.8526236414909363,50000 -13850.107394695282,23.76880979537964,162210.0809226036,364655,0,162210.0809226036,0.6615000367164612,1.4626646041870115,10000,176103.28262138367,0.8883007764816284,0.4166864156723022,0.7831799983978271,0.8526236414909363,50000 -13887.037593126295,23.85041332244873,162630.12792491913,365599,0,162630.12792491913,0.6615000367164612,1.4626646041870115,10000,176560.39166498184,0.8848828077316284,0.4277326166629791,0.7831799983978271,0.8526236414909363,50000 -13925.670383930206,23.932437658309937,163050.34325742722,366543,0,163050.34325742722,0.6615000367164612,1.4626646041870115,10000,177019.37155103683,0.8892773389816284,0.4170738458633423,0.7831799983978271,0.8526236414909363,50000 -13963.74238228798,24.01252174377441,163470.2967107296,367486,0,163470.2967107296,0.6615000367164612,1.4626646041870115,10000,177477.52720284462,0.8848046660423279,0.4290616512298584,0.7831799983978271,0.8526236414909363,50000 -14001.34293437004,24.096700191497803,163890.20581889153,368429,0,163890.20581889153,0.6615000367164612,1.4626646041870115,10000,177935.1705136299,0.8879687190055847,0.420084685087204,0.7831799983978271,0.8526236414909363,50000 -14041.759779453278,24.177711248397827,164310.3669514656,369372,0,164310.3669514656,0.6615000367164612,1.4626646041870115,10000,178395.87911200523,0.8879687190055847,0.420869767665863,0.7831799983978271,0.8526236414909363,50000 -14080.945871591568,24.259997367858887,164730.44117760658,370318,0,164730.44117760658,0.6615000367164612,1.4626646041870115,10000,178855.2717988491,0.88587886095047,0.4209321737289428,0.7831799983978271,0.8526236414909363,50000 -14119.509327411652,24.341216564178467,165150.38989567757,371263,0,165150.38989567757,0.6615000367164612,1.4626646041870115,10000,179313.91506695747,0.8868359327316284,0.419025719165802,0.7831799983978271,0.8526236414909363,50000 -14157.917885541916,24.4231505393982,165570.40331053734,372206,0,165570.40331053734,0.6615000367164612,1.4626646041870115,10000,179772.46859931946,0.8878905773162842,0.4138197898864746,0.7831799983978271,0.8526236414909363,50000 -14196.896178245544,24.50393438339233,165990.54666543007,373151,0,165990.54666543007,0.6615000367164612,1.4626646041870115,10000,180231.72131824493,0.8874022960662842,0.4218630492687225,0.7831799983978271,0.8526236414909363,50000 -14232.955313920977,24.584126710891724,166410.48005771637,374094,0,166410.48005771637,0.6615000367164612,1.4626646041870115,10000,180687.84383511543,0.8924023509025574,0.4078778326511383,0.7831799983978271,0.8526236414909363,50000 -14270.859405994415,24.667348384857178,166830.50103092194,375037,0,166830.50103092194,0.6615000367164612,1.4626646041870115,10000,181145.9012551308,0.8862890601158142,0.4168686270713806,0.7831799983978271,0.8526236414909363,50000 -14310.904839277267,24.76238989830017,167250.72504115105,375981,0,167250.72504115105,0.6615000367164612,1.4626646041870115,10000,181606.3165895939,0.8896484375,0.4130140542984009,0.7831799983978271,0.8526236414909363,50000 -14350.03876209259,24.86554718017578,167670.94829511642,376924,0,167670.94829511642,0.6615000367164612,1.4626646041870115,10000,182065.8271315097,0.8883007764816284,0.4164344668388366,0.7831799983978271,0.8526236414909363,50000 -14387.900235891342,24.95109248161316,168090.943703413,377867,0,168090.943703413,0.6615000367164612,1.4626646041870115,10000,182523.81906175613,0.8856835961341858,0.4257451891899109,0.7831799983978271,0.8526236414909363,50000 -14427.597375631332,25.032990217208862,168510.95400238037,378809,0,168510.95400238037,0.6615000367164612,1.4626646041870115,10000,182983.6586351395,0.8865624666213989,0.422280341386795,0.7831799983978271,0.8526236414909363,50000 -14467.989336252213,25.11535716056824,168930.87727046013,379752,0,168930.87727046013,0.6615000367164612,1.4626646041870115,10000,183444.1056733132,0.89013671875,0.4157166182994842,0.7831799983978271,0.8526236414909363,50000 -14505.456683397291,25.19770884513855,169350.8524608612,380695,0,169350.8524608612,0.6615000367164612,1.4626646041870115,10000,183901.6815738678,0.8870312571525574,0.4199641346931457,0.7831799983978271,0.8526236414909363,50000 -14541.1564347744,25.292958974838257,169771.07170915604,381639,0,169771.07170915604,0.6615000367164612,1.4626646041870115,10000,184357.7455892563,0.8875781297683716,0.4152979850769043,0.7831799983978271,0.8526236414909363,50000 -14582.919032096865,25.37623953819275,170191.14010548592,382582,0,170191.14010548592,0.6615000367164612,1.4626646041870115,10000,184819.70988583565,0.8846093416213989,0.4293917417526245,0.7831799983978271,0.8526236414909363,50000 -14622.43280339241,25.47563648223877,170611.00712823868,383525,0,170611.00712823868,0.6615000367164612,1.4626646041870115,10000,185279.24055552483,0.8859961032867432,0.4245673716068268,0.7831799983978271,0.8526236414909363,50000 -14656.259887695312,25.55725598335266,171031.26492524147,384469,0,171031.26492524147,0.6615000367164612,1.4626646041870115,10000,185733.45716404915,0.88720703125,0.4146858751773834,0.7831799983978271,0.8526236414909363,50000 -14697.488318920135,25.653950452804565,171451.29096341133,385413,0,171451.29096341133,0.6615000367164612,1.4626646041870115,10000,186194.85792946813,0.8879492282867432,0.4205496907234192,0.7831799983978271,0.8526236414909363,50000 -14736.033299446106,25.737998723983765,171871.21469521525,386356,0,171871.21469521525,0.6615000367164612,1.4626646041870115,10000,186653.4598946572,0.887011706829071,0.4235352277755737,0.7831799983978271,0.8526236414909363,50000 -14775.26454615593,25.82158470153809,172291.44046711922,387300,0,172291.44046711922,0.6615000367164612,1.4626646041870115,10000,187113.0497689247,0.8868163824081421,0.417324811220169,0.7831799983978271,0.8526236414909363,50000 -14814.46245265007,25.90627861022949,172711.33534526825,388244,0,172711.33534526825,0.6615000367164612,1.4626646041870115,10000,187572.2773804665,0.8858398199081421,0.4219137728214264,0.7831799983978271,0.8526236414909363,50000 -14853.030284166336,25.99392938613892,173131.36773204803,389184,0,173131.36773204803,0.6615000367164612,1.4626646041870115,10000,188031.0147128105,0.8878515362739563,0.4197618365287781,0.7831799983978271,0.8526236414909363,50000 -14889.88103055954,26.092432737350464,173551.6240181923,390130,0,173551.6240181923,0.6615000367164612,1.4626646041870115,10000,188488.27039361,0.8890624642372131,0.4212661981582641,0.7831799983978271,0.8526236414909363,50000 -14927.226495981216,26.17941808700561,173971.50109505653,391073,0,173971.50109505653,0.6615000367164612,1.4626646041870115,10000,188945.62973308563,0.88525390625,0.4267793893814087,0.7831799983978271,0.8526236414909363,50000 -14966.27654528618,26.275421142578125,174391.5544939041,392017,0,174391.5544939041,0.6615000367164612,1.4626646041870115,10000,189404.8793940544,0.88720703125,0.4180474281311035,0.7831799983978271,0.8526236414909363,50000 -15008.291206359863,26.361035346984863,174811.7327439785,392962,0,174811.7327439785,0.6615000367164612,1.4626646041870115,10000,189867.20791006088,0.8881640434265137,0.4164069592952728,0.7831799983978271,0.8526236414909363,50000 -15049.005800247192,26.46606206893921,175231.97394490242,393906,0,175231.97394490242,0.6615000367164612,1.4626646041870115,10000,190328.31834244728,0.88623046875,0.4230261445045471,0.7831799983978271,0.8526236414909363,50000 -15089.174458265305,26.56468725204468,175652.15966057777,394851,0,175652.15966057777,0.6615000367164612,1.4626646041870115,10000,190788.822149992,0.8873632550239563,0.4190618693828583,0.7831799983978271,0.8526236414909363,50000 -15129.109971046448,26.64938521385193,176072.37421774864,395794,0,176072.37421774864,0.6615000367164612,1.4626646041870115,10000,191249.1066968441,0.8883788585662842,0.4160697758197784,0.7831799983978271,0.8526236414909363,50000 -15167.873075723648,26.746363401412964,176492.64163136482,396738,0,176492.64163136482,0.6615000367164612,1.4626646041870115,10000,191708.2846038341,0.8874413967132568,0.4168276190757751,0.7831799983978271,0.8526236414909363,50000 -15207.94590306282,26.832626581192017,176912.88610959053,397683,0,176912.88610959053,0.6615000367164612,1.4626646041870115,10000,192168.7384550572,0.8902539014816284,0.4140291213989258,0.7831799983978271,0.8526236414909363,50000 -15248.890265226364,26.943554639816284,177332.93433642387,398628,0,177332.93433642387,0.6615000367164612,1.4626646041870115,10000,192629.8915224076,0.88818359375,0.4120750427246094,0.7831799983978271,0.8526236414909363,50000 -15288.936572790146,27.03213572502136,177752.96119642258,399570,0,177752.96119642258,0.6615000367164612,1.4626646041870115,10000,193090.10246396065,0.8872656226158142,0.4196255803108215,0.7831799983978271,0.8526236414909363,50000 -15327.57566690445,27.120450258255005,178173.17875623703,400512,0,178173.17875623703,0.6615000367164612,1.4626646041870115,10000,193549.0971210003,0.8886913657188416,0.4138020575046539,0.7831799983978271,0.8526236414909363,50000 -15365.542269706726,27.2052583694458,178593.23889565468,401457,0,178593.23889565468,0.6615000367164612,1.4626646041870115,10000,194007.2590417862,0.8862695097923279,0.4230132102966308,0.7831799983978271,0.8526236414909363,50000 -15401.641217947006,27.29204750061035,179013.23162341118,402400,0,179013.23162341118,0.6615000367164612,1.4626646041870115,10000,194463.48729109764,0.88685542345047,0.4238695502281189,0.7831799983978271,0.8526236414909363,50000 -15440.896358013151,27.38421607017517,179433.4183971882,403345,0,179433.4183971882,0.6615000367164612,1.4626646041870115,10000,194923.0720796585,0.8892382383346558,0.4164697527885437,0.7831799983978271,0.8526236414909363,50000 -15479.857174158096,27.4800283908844,179853.55819368362,404288,0,179853.55819368362,0.6615000367164612,1.4626646041870115,10000,195382.3181631565,0.88623046875,0.4211839139461517,0.7831799983978271,0.8526236414909363,50000 -15520.02494072914,27.570369720458984,180273.6412432193,405232,0,180273.6412432193,0.6615000367164612,1.4626646041870115,10000,195842.7099413872,0.8868359327316284,0.4213452041149139,0.7831799983978271,0.8526236414909363,50000 -15559.92268037796,27.672037363052368,180693.8637099266,406177,0,180693.8637099266,0.6615000367164612,1.4626646041870115,10000,196302.9815299511,0.8876171708106995,0.4252983629703522,0.7831799983978271,0.8526236414909363,50000 -15599.586994171144,27.76237177848816,181114.0154056549,407121,0,181114.0154056549,0.6615000367164612,1.4626646041870115,10000,196762.9376182556,0.8858984112739563,0.4219390451908111,0.7831799983978271,0.8526236414909363,50000 -15640.886532783508,27.84921193122864,181533.90558075905,408066,0,181533.90558075905,0.6615000367164612,1.4626646041870115,10000,197224.2642486096,0.8882030844688416,0.4163236916065216,0.7831799983978271,0.8526236414909363,50000 -15681.871098041534,27.936203956604004,181953.90784978867,409009,0,181953.90784978867,0.6615000367164612,1.4626646041870115,10000,197685.3879377842,0.88720703125,0.4183976352214813,0.7831799983978271,0.8526236414909363,50000 -15721.50890302658,28.027443408966064,182374.1227302552,409953,0,182374.1227302552,0.6615000367164612,1.4626646041870115,10000,198145.38085389137,0.8857616782188416,0.4226754307746887,0.7831799983978271,0.8526236414909363,50000 -15759.182371139526,28.114766120910645,182794.09159350395,410896,0,182794.09159350395,0.6615000367164612,1.4626646041870115,10000,198603.1596791744,0.8866015672683716,0.4220174551010132,0.7831799983978271,0.8526236414909363,50000 -15798.120640039444,28.22186017036438,183214.29858469963,411840,0,183214.29858469963,0.6615000367164612,1.4626646041870115,10000,199062.46124649048,0.8879101276397705,0.4160162210464477,0.7831799983978271,0.8526236414909363,50000 -15835.116010665894,28.31593894958496,183634.53731536865,412784,0,183634.53731536865,0.6615000367164612,1.4626646041870115,10000,199519.83964657784,0.8880078196525574,0.4205534160137176,0.7831799983978271,0.8526236414909363,50000 -15873.348242282867,28.40547013282776,184054.52427721024,413728,0,184054.52427721024,0.6615000367164612,1.4626646041870115,10000,199978.1976265908,0.8871679306030273,0.4199144542217254,0.7831799983978271,0.8526236414909363,50000 -15911.04614019394,28.493316650390625,184474.57068252563,414672,0,184474.57068252563,0.6615000367164612,1.4626646041870115,10000,200436.07974553108,0.88720703125,0.4232935905456543,0.7831799983978271,0.8526236414909363,50000 -15952.918644666672,28.580711126327515,184894.88948082924,415617,0,184894.88948082924,0.6615000367164612,1.4626646041870115,10000,200898.40726947784,0.8864062428474426,0.4224046170711517,0.7831799983978271,0.8526236414909363,50000 -15991.00540137291,28.669321537017822,185314.9847421646,416562,0,185314.9847421646,0.6615000367164612,1.4626646041870115,10000,201356.72749471664,0.8877148032188416,0.4189739227294922,0.7831799983978271,0.8526236414909363,50000 -16030.271182060242,28.7737877368927,185735.00822615623,417507,0,185735.00822615623,0.6615000367164612,1.4626646041870115,10000,201816.1709141732,0.8869140148162842,0.4190338551998138,0.7831799983978271,0.8526236414909363,50000 -16071.009669065475,28.864509344100952,186155.03788423527,418450,0,186155.03788423527,0.6615000367164612,1.4626646041870115,10000,202277.07932949063,0.8861523270606995,0.425559937953949,0.7831799983978271,0.8526236414909363,50000 -16111.09059047699,28.99151039123535,186575.0562889576,419394,0,186575.0562889576,0.6615000367164612,1.4626646041870115,10000,202737.35576033592,0.8866015672683716,0.4160337746143341,0.7831799983978271,0.8526236414909363,50000 -16149.30091905594,29.085197925567627,186995.1407394409,420339,0,186995.1407394409,0.6615000367164612,1.4626646041870115,10000,203195.7936155796,0.8887109160423279,0.4114374816417694,0.7831799983978271,0.8526236414909363,50000 -16186.994090795515,29.1780104637146,187415.330839634,421284,0,187415.330839634,0.6615000367164612,1.4626646041870115,10000,203653.81957530967,0.88978511095047,0.4178528487682342,0.7831799983978271,0.8526236414909363,50000 -16223.821682929993,29.28029179573059,187835.38063955307,422228,0,187835.38063955307,0.6615000367164612,1.4626646041870115,10000,204110.84920215607,0.8898242115974426,0.412438154220581,0.7831799983978271,0.8526236414909363,50000 -16264.61352443695,29.36870121955872,188255.40531373024,423172,0,188255.40531373024,0.6615000367164612,1.4626646041870115,10000,204571.80442857745,0.8860546946525574,0.4179321825504303,0.7831799983978271,0.8526236414909363,50000 -16306.104352474213,29.47391629219055,188675.36987829208,424115,0,188675.36987829208,0.6615000367164612,1.4626646041870115,10000,205033.4146347046,0.8898437023162842,0.4148611426353454,0.7831799983978271,0.8526236414909363,50000 -16346.529787540436,29.56756401062012,189095.49055552483,425058,0,189095.49055552483,0.6615000367164612,1.4626646041870115,10000,205494.103537798,0.8862499594688416,0.4217503368854522,0.7831799983978271,0.8526236414909363,50000 -16385.24165081978,29.66967749595642,189515.35760688785,426000,0,189515.35760688785,0.6615000367164612,1.4626646041870115,10000,205952.8334183693,0.88818359375,0.4202206432819366,0.7831799983978271,0.8526236414909363,50000 -16422.631457805634,29.771843910217285,189935.4038639069,426944,0,189935.4038639069,0.6615000367164612,1.4626646041870115,10000,206410.42110538483,0.8882812261581421,0.4174784421920776,0.7831799983978271,0.8526236414909363,50000 -16462.080998659134,29.86714720726013,190355.65117049217,427886,0,190355.65117049217,0.6615000367164612,1.4626646041870115,10000,206870.26315903664,0.8866015672683716,0.4225281476974487,0.7831799983978271,0.8526236414909363,50000 -16502.711376667023,29.968742847442627,190775.9188687801,428830,0,190775.9188687801,0.6615000367164612,1.4626646041870115,10000,207331.3117365837,0.8874022960662842,0.4188739955425262,0.7831799983978271,0.8526236414909363,50000 -16542.15592098236,30.05966424942017,191195.83547329903,429774,0,191195.83547329903,0.6615000367164612,1.4626646041870115,10000,207790.8137485981,0.8849608898162842,0.4287216365337372,0.7831799983978271,0.8526236414909363,50000 -16581.287677764893,30.16299033164978,191615.99197554588,430718,0,191615.99197554588,0.6615000367164612,1.4626646041870115,10000,208250.25497484207,0.8863085508346558,0.4251331090927124,0.7831799983978271,0.8526236414909363,50000 -16623.598644971848,30.271098613739014,192036.2289819717,431661,0,192036.2289819717,0.6615000367164612,1.4626646041870115,10000,208712.9606354237,0.8882616758346558,0.4132554531097412,0.7831799983978271,0.8526236414909363,50000 -16664.247750520706,30.36311721801757,192456.2843079567,432603,0,192456.2843079567,0.6615000367164612,1.4626646041870115,10000,209173.80654668808,0.8885155916213989,0.4181333482265472,0.7831799983978271,0.8526236414909363,50000 -16703.37880063057,30.455291032791138,192876.3270497322,433548,0,192876.3270497322,0.6615000367164612,1.4626646041870115,10000,209633.1222922802,0.8854687213897705,0.4237488210201263,0.7831799983978271,0.8526236414909363,50000 -16743.577922344208,30.5573947429657,193296.5520606041,434493,0,193296.5520606041,0.6615000367164612,1.4626646041870115,10000,210093.69893980023,0.8869531154632568,0.4174193441867828,0.7831799983978271,0.8526236414909363,50000 -16784.6428296566,30.652455806732178,193716.42161488533,435435,0,193716.42161488533,0.6615000367164612,1.4626646041870115,10000,210554.7790727616,0.8883007764816284,0.4161953330039978,0.7831799983978271,0.8526236414909363,50000 -16829.308379650116,30.753374576568604,194136.64938545227,436378,0,194136.64938545227,0.6615000367164612,1.4626646041870115,10000,211019.8227727413,0.8877539038658142,0.4231440126895904,0.7831799983978271,0.8526236414909363,50000 -16865.685774326324,30.85575246810913,194556.91301631927,437322,0,194556.91301631927,0.6615000367164612,1.4626646041870115,10000,211476.6163315773,0.8863866925239563,0.4232368469238281,0.7831799983978271,0.8526236414909363,50000 -16904.828684091568,30.9488582611084,194977.04292821884,438268,0,194977.04292821884,0.6615000367164612,1.4626646041870115,10000,211936.03138518333,0.8855273127555847,0.42439004778862,0.7831799983978271,0.8526236414909363,50000 -16944.33552479744,31.040956497192383,195397.16938996315,439212,0,195397.16938996315,0.6615000367164612,1.4626646041870115,10000,212395.8071053028,0.8867773413658142,0.4227740466594696,0.7831799983978271,0.8526236414909363,50000 -16985.492069721222,31.14467477798462,195817.39040994644,440157,0,195817.39040994644,0.6615000367164612,1.4626646041870115,10000,212857.3381161689,0.8872460722923279,0.4181328415870666,0.7831799983978271,0.8526236414909363,50000 -17027.075122833252,31.24666905403137,196237.36933875084,441099,0,196237.36933875084,0.6615000367164612,1.4626646041870115,10000,213319.05184865,0.8878124952316284,0.4193216562271118,0.7831799983978271,0.8526236414909363,50000 -17068.658164978027,31.343939542770386,196657.33681559563,442043,0,196657.33681559563,0.6615000367164612,1.4626646041870115,10000,213780.749254942,0.8869140148162842,0.4169472157955169,0.7831799983978271,0.8526236414909363,50000 -17104.009438753128,31.43527889251709,197077.5290656089,442989,0,197077.5290656089,0.6615000367164612,1.4626646041870115,10000,214236.43349838257,0.8885155916213989,0.4157116115093231,0.7831799983978271,0.8526236414909363,50000 -17147.90717315674,31.5306293964386,197497.7634282112,443934,0,197497.7634282112,0.6615000367164612,1.4626646041870115,10000,214700.71033906937,0.8871288895606995,0.4241305589675903,0.7831799983978271,0.8526236414909363,50000 -17187.872669696808,31.63857674598694,197917.61884331703,444870,0,197917.61884331703,0.6615000367164612,1.4626646041870115,10000,215160.68771243087,0.8898242115974426,0.411442756652832,0.7831799983978271,0.8526236414909363,50000 -17222.858201503754,31.731759071350098,198337.8225800991,445811,0,198337.8225800991,0.6615000367164612,1.4626646041870115,10000,215616.0204000473,0.8890820145606995,0.4095396399497986,0.7831799983978271,0.8526236414909363,50000 -17264.571006774902,31.82603144645691,198758.01270985603,446752,0,198758.01270985603,0.6615000367164612,1.4626646041870115,10000,216078.06741833687,0.8875976204872131,0.4188170135021209,0.7831799983978271,0.8526236414909363,50000 -17303.306176424026,31.93401432037353,199178.1304731369,447695,0,199178.1304731369,0.6615000367164612,1.4626646041870115,10000,216537.0774514675,0.8876367211341858,0.4201096594333648,0.7831799983978271,0.8526236414909363,50000 -17342.620814323425,32.0307891368866,199598.13326835632,448636,0,199598.13326835632,0.6615000367164612,1.4626646041870115,10000,216996.54115891457,0.8866015672683716,0.421945184469223,0.7831799983978271,0.8526236414909363,50000 -17381.971882104874,32.12383246421814,200018.26824116707,449577,0,200018.26824116707,0.6615000367164612,1.4626646041870115,10000,217456.16999864567,0.8864648342132568,0.4194687902927398,0.7831799983978271,0.8526236414909363,50000 -17423.806594848633,32.23769760131836,200438.24313998225,450518,0,200438.24313998225,0.6615000367164612,1.4626646041870115,10000,217918.1436984539,0.89013671875,0.4164175689220428,0.7831799983978271,0.8526236414909363,50000 -17465.161774635315,32.33136487007141,200858.3983180523,451461,0,200858.3983180523,0.6615000367164612,1.4626646041870115,10000,218379.80024075508,0.8877539038658142,0.4214079082012176,0.7831799983978271,0.8526236414909363,50000 -17506.869884729385,32.42599129676819,201278.27744483948,452405,0,201278.27744483948,0.6615000367164612,1.4626646041870115,10000,218841.53184318545,0.8856835961341858,0.4210818409919739,0.7831799983978271,0.8526236414909363,50000 -17547.012185573578,32.526304721832275,201698.52309155464,453348,0,201698.52309155464,0.6615000367164612,1.4626646041870115,10000,219302.0693335533,0.88720703125,0.4244422912597656,0.7831799983978271,0.8526236414909363,50000 -17583.095573425293,32.6318883895874,202118.48056221008,454290,0,202118.48056221008,0.6615000367164612,1.4626646041870115,10000,219758.2660264969,0.8869140148162842,0.421396404504776,0.7831799983978271,0.8526236414909363,50000 -17624.13773369789,32.72527503967285,202538.690376997,455235,0,202538.690376997,0.6615000367164612,1.4626646041870115,10000,220219.6612966061,0.8871093392372131,0.4167193472385406,0.7831799983978271,0.8526236414909363,50000 -17665.55302143097,32.81936073303223,202958.80642414093,456178,0,202958.80642414093,0.6615000367164612,1.4626646041870115,10000,220681.3356461525,0.8875585794448853,0.4189115762710571,0.7831799983978271,0.8526236414909363,50000 -17700.11089038849,32.914966344833374,203379.1025440693,457122,0,203379.1025440693,0.6615000367164612,1.4626646041870115,10000,221136.3347158432,0.8851562142372131,0.426841527223587,0.7831799983978271,0.8526236414909363,50000 -17743.58947658539,33.01036882400513,203799.38678312305,458065,0,203799.38678312305,0.6615000367164612,1.4626646041870115,10000,221600.2426137924,0.8870898485183716,0.4176160097122192,0.7831799983978271,0.8526236414909363,50000 -17779.73451113701,33.1145339012146,204219.39708042145,459006,0,204219.39708042145,0.6615000367164612,1.4626646041870115,10000,222056.5509133339,0.8886132836341858,0.4135399162769317,0.7831799983978271,0.8526236414909363,50000 -17820.108170747757,33.21669864654541,204639.460521698,459950,0,204639.460521698,0.6615000367164612,1.4626646041870115,10000,222517.1409971714,0.8851171731948853,0.4251580834388733,0.7831799983978271,0.8526236414909363,50000 -17855.99463057518,33.313735246658325,205060.09689879417,460892,0,205060.09689879417,0.6615000367164612,1.4626646041870115,10000,222973.8101565838,0.8870507478713989,0.4238626360893249,0.7831799983978271,0.8526236414909363,50000 -17896.213749408722,33.410961866378784,205480.0987341404,461831,0,205480.0987341404,0.6615000367164612,1.4626646041870115,10000,223434.17816066745,0.8887695074081421,0.4198735058307647,0.7831799983978271,0.8526236414909363,50000 -17935.342923879623,33.50675344467163,205900.09218215945,462776,0,205900.09218215945,0.6615000367164612,1.4626646041870115,10000,223893.4460091591,0.8857616782188416,0.4240389466285705,0.7831799983978271,0.8526236414909363,50000 -17970.10102057457,33.603896617889404,206320.3325078488,463722,0,206320.3325078488,0.6615000367164612,1.4626646041870115,10000,224348.5914018154,0.8868945240974426,0.4198920726776123,0.7831799983978271,0.8526236414909363,50000 -18006.612882614136,33.70057702064514,206740.5910184384,464664,0,206740.5910184384,0.6615000367164612,1.4626646041870115,10000,224805.5076739788,0.8875585794448853,0.4207744896411896,0.7831799983978271,0.8526236414909363,50000 -18043.933857917786,33.79648423194885,207160.64344906807,465607,0,207160.64344906807,0.6615000367164612,1.4626646041870115,10000,225263.02623271945,0.8863866925239563,0.4194602966308594,0.7831799983978271,0.8526236414909363,50000 -18081.914219379425,33.89208626747131,207580.8476507664,466553,0,207580.8476507664,0.6615000367164612,1.4626646041870115,10000,225721.3563463688,0.88783198595047,0.4138371646404266,0.7831799983978271,0.8526236414909363,50000 -18118.31734418869,33.98908615112305,208000.78075790405,467498,0,208000.78075790405,0.6615000367164612,1.4626646041870115,10000,226177.8390073776,0.8881640434265137,0.4183202385902405,0.7831799983978271,0.8526236414909363,50000 -18159.481626987457,34.088141679763794,208420.9566919804,468444,0,208420.9566919804,0.6615000367164612,1.4626646041870115,10000,226639.3285050392,0.8898242115974426,0.41509810090065,0.7831799983978271,0.8526236414909363,50000 -18194.91436839104,34.18496608734131,208841.1808238029,469389,0,208841.1808238029,0.6615000367164612,1.4626646041870115,10000,227095.1313097477,0.8898828029632568,0.4089874625205993,0.7831799983978271,0.8526236414909363,50000 -18231.963979959488,34.282028675079346,209261.11147785187,470336,0,209261.11147785187,0.6615000367164612,1.4626646041870115,10000,227552.2586643696,0.8884375095367432,0.4139959514141083,0.7831799983978271,0.8526236414909363,50000 -18269.8248910904,34.37890291213989,209681.18304800987,471281,0,209681.18304800987,0.6615000367164612,1.4626646041870115,10000,228010.3366763592,0.8876171708106995,0.4210407435894012,0.7831799983978271,0.8526236414909363,50000 -18308.068465471268,34.47755169868469,210101.27172207832,472225,0,210101.27172207832,0.6615000367164612,1.4626646041870115,10000,228468.8173315525,0.8863866925239563,0.4219321608543396,0.7831799983978271,0.8526236414909363,50000 -18346.783236026764,34.57780885696411,210521.1846988201,473169,0,210521.1846988201,0.6615000367164612,1.4626646041870115,10000,228927.59491848943,0.8877539038658142,0.4184952080249786,0.7831799983978271,0.8526236414909363,50000 -18389.272876262665,34.690468072891235,210941.2804641724,474114,0,210941.2804641724,0.6615000367164612,1.4626646041870115,10000,229390.34396600723,0.8903319835662842,0.4145033657550812,0.7831799983978271,0.8526236414909363,50000 -18426.54616785049,34.78611469268799,211361.19549322128,475059,0,211361.19549322128,0.6615000367164612,1.4626646041870115,10000,229847.67683005333,0.8855859041213989,0.4250911474227905,0.7831799983978271,0.8526236414909363,50000 -18466.10234069824,34.90059804916382,211781.18303704265,476001,0,211781.18303704265,0.6615000367164612,1.4626646041870115,10000,230307.3840615749,0.8862109184265137,0.4204814434051513,0.7831799983978271,0.8526236414909363,50000 -18504.46773743629,34.99848484992981,212201.07793593407,476944,0,212201.07793593407,0.6615000367164612,1.4626646041870115,10000,230765.7916574478,0.8863866925239563,0.4251146614551544,0.7831799983978271,0.8526236414909363,50000 -18544.29231786728,35.0954225063324,212621.0614106655,477892,0,212621.0614106655,0.6615000367164612,1.4626646041870115,10000,231225.7468166352,0.88671875,0.4219008982181549,0.7831799983978271,0.8526236414909363,50000 -18583.47327947617,35.19288158416748,213041.1797640324,478836,0,213041.1797640324,0.6615000367164612,1.4626646041870115,10000,231685.1926748753,0.8879687190055847,0.4177588820457458,0.7831799983978271,0.8526236414909363,50000 -18623.00409412384,35.30354833602905,213461.09774923325,479780,0,213461.09774923325,0.6615000367164612,1.4626646041870115,10000,232144.8058159352,0.8868749737739563,0.4182649850845337,0.7831799983978271,0.8526236414909363,50000 -18657.94558501244,35.4053909778595,213881.37053871155,480722,0,213881.37053871155,0.6615000367164612,1.4626646041870115,10000,232600.1719095707,0.8855273127555847,0.4250691831111908,0.7831799983978271,0.8526236414909363,50000 -18696.6129488945,35.506397008895874,214301.31644773483,481665,0,214301.31644773483,0.6615000367164612,1.4626646041870115,10000,233058.9361166954,0.8876171708106995,0.4181188941001892,0.7831799983978271,0.8526236414909363,50000 -18736.40962123871,35.604432344436646,214721.59282803533,482609,0,214721.59282803533,0.6615000367164612,1.4626646041870115,10000,233519.15626096723,0.8867382407188416,0.4196585118770599,0.7831799983978271,0.8526236414909363,50000 -18773.288115262985,35.703930616378784,215141.62749814987,483552,0,215141.62749814987,0.6615000367164612,1.4626646041870115,10000,233976.2240343094,0.8869335651397705,0.4238941967487335,0.7831799983978271,0.8526236414909363,50000 -18810.65714740753,35.80401062965393,215561.4826450348,484496,0,215561.4826450348,0.6615000367164612,1.4626646041870115,10000,234433.5977020264,0.8873242139816284,0.4227310717105865,0.7831799983978271,0.8526236414909363,50000 -18849.29625797272,35.90559959411621,215981.5551698208,485439,0,215981.5551698208,0.6615000367164612,1.4626646041870115,10000,234892.46096420288,0.8871093392372131,0.421786367893219,0.7831799983978271,0.8526236414909363,50000 -18893.18506455421,36.01810646057129,216401.4870171547,486383,0,216401.4870171547,0.6615000367164612,1.4626646041870115,10000,235356.4436588288,0.8866601586341858,0.4253378212451935,0.7831799983978271,0.8526236414909363,50000 -18932.574870586395,36.117307901382446,216821.3461008072,487328,0,216821.3461008072,0.6615000367164612,1.4626646041870115,10000,235815.8413271904,0.8872265219688416,0.4159683585166931,0.7831799983978271,0.8526236414909363,50000 -18970.62588739395,36.22060370445252,217241.2795341015,488273,0,217241.2795341015,0.6615000367164612,1.4626646041870115,10000,236273.9787230492,0.8877733945846558,0.4194367825984955,0.7831799983978271,0.8526236414909363,50000 -19010.070831775665,36.32327747344971,217661.31353735924,489217,0,217661.31353735924,0.6615000367164612,1.4626646041870115,10000,236733.6101295948,0.8881054520606995,0.4167794585227966,0.7831799983978271,0.8526236414909363,50000 -19052.525792360306,36.43513631820679,218081.38036298752,490163,0,218081.38036298752,0.6615000367164612,1.4626646041870115,10000,237196.2936933041,0.8856640458106995,0.4225226938724518,0.7831799983978271,0.8526236414909363,50000 -19093.32017993927,36.5455060005188,218501.2313103676,491107,0,218501.2313103676,0.6615000367164612,1.4626646041870115,10000,237657.09906315804,0.8886913657188416,0.413221538066864,0.7831799983978271,0.8526236414909363,50000 -19129.19732618332,36.66964244842529,218921.18591451645,492050,0,218921.18591451645,0.6615000367164612,1.4626646041870115,10000,238113.10402703285,0.8886913657188416,0.4157118201255798,0.7831799983978271,0.8526236414909363,50000 -19166.96463418007,36.76979398727417,219341.1279244423,492992,0,219341.1279244423,0.6615000367164612,1.4626646041870115,10000,238570.96243953705,0.8909375071525574,0.4079668819904327,0.7831799983978271,0.8526236414909363,50000 -19207.34047794342,36.87393951416016,219761.21338629723,493934,0,219761.21338629723,0.6615000367164612,1.4626646041870115,10000,239031.57810759544,0.8881054520606995,0.4184964597225189,0.7831799983978271,0.8526236414909363,50000 -19245.554092168808,36.97444653511048,220181.35465550423,494878,0,220181.35465550423,0.6615000367164612,1.4626646041870115,10000,239490.0839271545,0.8865429759025574,0.4192816019058227,0.7831799983978271,0.8526236414909363,50000 -19286.792535066605,37.07832598686218,220601.31715726847,495819,0,220601.31715726847,0.6615000367164612,1.4626646041870115,10000,239951.438188076,0.8858202695846558,0.4241812825202942,0.7831799983978271,0.8526236414909363,50000 -19324.87941670417,37.19647455215454,221021.30834054947,496762,0,221021.30834054947,0.6615000367164612,1.4626646041870115,10000,240409.68428444865,0.8873046636581421,0.4224570691585541,0.7831799983978271,0.8526236414909363,50000 -19360.00890374184,37.29786229133606,221441.4514014721,497706,0,221441.4514014721,0.6615000367164612,1.4626646041870115,10000,240865.10836744308,0.8900781273841858,0.4112922847270965,0.7831799983978271,0.8526236414909363,50000 -19402.17110681533,37.41727375984192,221861.6649634838,498649,0,221861.6649634838,0.6615000367164612,1.4626646041870115,10000,241327.65272402763,0.8861132860183716,0.424767792224884,0.7831799983978271,0.8526236414909363,50000 -19445.47658014297,37.51878428459168,222281.58246564865,499590,0,222281.58246564865,0.6615000367164612,1.4626646041870115,10000,241791.0271379948,0.888476550579071,0.4164844751358032,0.7831799983978271,0.8526236414909363,50000 -19480.782245635983,37.63924813270569,222701.7090072632,500531,0,222701.7090072632,0.6615000367164612,1.4626646041870115,10000,242246.6289396286,0.8866991996765137,0.4256736040115356,0.7831799983978271,0.8526236414909363,50000 -19523.047776937485,37.74262857437134,223121.9326927662,501475,0,223121.9326927662,0.6615000367164612,1.4626646041870115,10000,242709.27169013023,0.8859765529632568,0.4227696061134338,0.7831799983978271,0.8526236414909363,50000 -19563.318788290024,37.84835863113403,223542.1022367477,502421,0,223542.1022367477,0.6615000367164612,1.4626646041870115,10000,243169.86743497849,0.8878905773162842,0.4137028455734253,0.7831799983978271,0.8526236414909363,50000 -19602.17483854294,37.95388746261597,223962.2054350376,503361,0,223962.2054350376,0.6615000367164612,1.4626646041870115,10000,243628.98177289963,0.8865038752555847,0.4238264262676239,0.7831799983978271,0.8526236414909363,50000 -19645.54245853424,38.05898094177246,224382.14086842537,504305,0,224382.14086842537,0.6615000367164612,1.4626646041870115,10000,244092.43917560577,0.8867577910423279,0.4207553267478943,0.7831799983978271,0.8526236414909363,50000 -19682.870091676712,38.161217212677,224802.12478852272,505247,0,224802.12478852272,0.6615000367164612,1.4626646041870115,10000,244549.902169466,0.8862499594688416,0.4233685731887817,0.7831799983978271,0.8526236414909363,50000 -19723.331513643265,38.26738238334656,225222.24004030228,506190,0,225222.24004030228,0.6615000367164612,1.4626646041870115,10000,245010.63395023343,0.8874218463897705,0.4169977903366089,0.7831799983978271,0.8526236414909363,50000 -19761.377603769302,38.3882851600647,225642.23197722435,507136,0,225642.23197722435,0.6615000367164612,1.4626646041870115,10000,245468.84431529045,0.8875390291213989,0.418641984462738,0.7831799983978271,0.8526236414909363,50000 -19801.3218562603,38.48969912528992,226062.2147741317,508079,0,226062.2147741317,0.6615000367164612,1.4626646041870115,10000,245928.9227323532,0.8869531154632568,0.4258208274841308,0.7831799983978271,0.8526236414909363,50000 -19845.604430675507,38.59902667999268,226482.067596674,509024,0,226482.067596674,0.6615000367164612,1.4626646041870115,10000,246393.21773839,0.8847460746765137,0.4278537333011627,0.7831799983978271,0.8526236414909363,50000 -19885.14374423027,38.70250916481018,226902.1107466221,509965,0,226902.1107466221,0.6615000367164612,1.4626646041870115,10000,246852.9529211521,0.8874022960662842,0.4216749966144562,0.7831799983978271,0.8526236414909363,50000 -19925.925420999527,38.81987309455872,227322.3340280056,510909,0,227322.3340280056,0.6615000367164612,1.4626646041870115,10000,247314.1253759861,0.8890234231948853,0.4123329520225525,0.7831799983978271,0.8526236414909363,50000 -19969.54983282089,38.92194437980652,227742.40439462665,511853,0,227742.40439462665,0.6615000367164612,1.4626646041870115,10000,247777.97219228745,0.8871679306030273,0.4232731163501739,0.7831799983978271,0.8526236414909363,50000 -20008.613343715668,39.02891778945923,228162.349899292,512796,0,228162.349899292,0.6615000367164612,1.4626646041870115,10000,248237.13788104057,0.8870507478713989,0.4156704545021057,0.7831799983978271,0.8526236414909363,50000 -20048.68375325203,39.13458871841431,228582.2359097004,513739,0,228582.2359097004,0.6615000367164612,1.4626646041870115,10000,248697.24965715408,0.8882030844688416,0.4158047139644623,0.7831799983978271,0.8526236414909363,50000 -20089.395189762115,39.24177360534668,229002.4842107296,514682,0,229002.4842107296,0.6615000367164612,1.4626646041870115,10000,249158.3658988476,0.8873828053474426,0.4168609082698822,0.7831799983978271,0.8526236414909363,50000 -20128.963701486588,39.34479546546936,229422.5807933808,515623,0,229422.5807933808,0.6615000367164612,1.4626646041870115,10000,249618.18248152733,0.8875976204872131,0.4215762913227081,0.7831799983978271,0.8526236414909363,50000 -20166.67224431038,39.44740009307861,229842.44540429115,516566,0,229842.44540429115,0.6615000367164612,1.4626646041870115,10000,250075.90839600563,0.8914648294448853,0.4060322046279907,0.7831799983978271,0.8526236414909363,50000 -20202.819187641144,39.57495212554932,230262.38619685173,517510,0,230262.38619685173,0.6615000367164612,1.4626646041870115,10000,250532.1736881733,0.8875195384025574,0.4187245368957519,0.7831799983978271,0.8526236414909363,50000 -20241.813415050507,39.69140291213989,230682.52629995343,518455,0,230682.52629995343,0.6615000367164612,1.4626646041870115,10000,250991.47455906868,0.8878515362739563,0.417012482881546,0.7831799983978271,0.8526236414909363,50000 -20277.67854499817,39.79632830619812,231102.6332271099,519401,0,231102.6332271099,0.6615000367164612,1.4626646041870115,10000,251447.6019744873,0.8869335651397705,0.4184837341308594,0.7831799983978271,0.8526236414909363,50000 -20316.76153421402,39.9029598236084,231522.629506588,520346,0,231522.629506588,0.6615000367164612,1.4626646041870115,10000,251906.8376162052,0.88685542345047,0.4225926995277405,0.7831799983978271,0.8526236414909363,50000 -20358.879606485367,40.01094198226929,231942.6417376995,521288,0,231942.6417376995,0.6615000367164612,1.4626646041870115,10000,252369.12612867355,0.8891991972923279,0.4170268177986145,0.7831799983978271,0.8526236414909363,50000 -20396.625748872757,40.11666750907898,232362.6910688877,522232,0,232362.6910688877,0.6615000367164612,1.4626646041870117,10000,252827.0764052868,0.8890429735183716,0.4175160527229309,0.7831799983978271,0.8526236414909363,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index a8d622f19..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5783 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.36564496,6.907756,,,,,,,,,,,,,, -1,,,0.0009374999790452,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.89926743507385,83.47133111953735,42.89926743507385,40.57194924354553,0.0,0.0 -100,0.44641295,6.8850937,,,,,,,,,,,,,, -200,0.59429634,6.7706566,,,,,,,,,,,,,, -300,0.75012225,6.6653485,,,,,,,,,,,,,, -400,1.4276396,6.5549603,,,,,,,,,,,,,, -500,0.9807071,6.4558372,,,,,,,,,,,,,, -600,0.8911658,6.5403676,,,,,,,,,,,,,, -700,1.0438048,6.286006,,,,,,,,,,,,,, -800,1.1104338,6.31271,,,,,,,,,,,,,, -900,1.3335994,6.166594,,,,,,,,,,,,,, -911,,,0.0390429683029651,5.780252933502197,0.0384799987077713,5.81451940536499,50000.0,0.0298000015318393,5.954464435577393,10000.0,462.8738512992859,525.1148769855499,462.8738512992859,62.162476539611816,0.0280063152313232,0.0 -1000,1.1211135,6.072832,,,,,,,,,,,,,, -1100,0.9412233,6.629085,,,,,,,,,,,,,, -1200,1.1078756,6.5785003,,,,,,,,,,,,,, -1300,0.91457,6.5507784,,,,,,,,,,,,,, -1400,2.7578623,5.8807583,,,,,,,,,,,,,, -1500,1.1480947,5.8107066,,,,,,,,,,,,,, -1600,1.082917,5.7179203,,,,,,,,,,,,,, -1700,1.0758239,5.707069,,,,,,,,,,,,,, -1800,1.0715761,5.7306595,,,,,,,,,,,,,, -1874,,,0.0801562517881393,5.216738700866699,0.0748599991202354,5.251400947570801,50000.0,0.0593000017106533,5.495050430297852,10000.0,882.9174020290375,966.6380949020386,882.9174020290375,83.55760931968689,0.0585842132568359,0.0 -1900,0.9890038,5.69713,,,,,,,,,,,,,, -2000,1.0384347,5.5976048,,,,,,,,,,,,,, -2100,1.056057,6.59394,,,,,,,,,,,,,, -2200,0.9771564,5.432266,,,,,,,,,,,,,, -2300,1.1425567,5.4911246,,,,,,,,,,,,,, -2400,1.0846064,5.4807777,,,,,,,,,,,,,, -2500,1.2235746,6.2118783,,,,,,,,,,,,,, -2600,1.5482128,6.1991286,,,,,,,,,,,,,, -2700,0.96458423,6.3630476,,,,,,,,,,,,,, -2800,0.88063145,5.4842362,,,,,,,,,,,,,, -2839,,,0.1545898467302322,4.480022430419922,0.1401599943637848,4.580586910247803,50000.0,0.1096000075340271,4.916996479034424,10000.0,1303.1719748973846,1408.4984166622162,1303.1719748973846,105.08153319358826,0.0892825126647949,0.0 -2900,1.1373489,5.1845274,,,,,,,,,,,,,, -3000,1.0531332,6.4263635,,,,,,,,,,,,,, -3100,0.8735631,5.4591713,,,,,,,,,,,,,, -3200,1.0570148,4.970723,,,,,,,,,,,,,, -3300,1.1487818,5.049449,,,,,,,,,,,,,, -3400,0.8650463,4.8407307,,,,,,,,,,,,,, -3500,1.057589,5.8989143,,,,,,,,,,,,,, -3600,1.1604441,4.856368,,,,,,,,,,,,,, -3700,0.795493,5.416158,,,,,,,,,,,,,, -3800,1.1482832,4.9043226,,,,,,,,,,,,,, -3804,,,0.2083788961172104,4.025781154632568,0.1943800002336502,4.114321231842041,50000.0,0.1480000019073486,4.542527198791504,10000.0,1723.439739942551,1850.476708889008,1723.439739942551,126.71310567855836,0.1167037487030029,0.0 -3900,1.2882366,6.39699,,,,,,,,,,,,,, -4000,0.8550987,4.906581,,,,,,,,,,,,,, -4100,1.1589506,5.3128457,,,,,,,,,,,,,, -4200,0.7907415,4.8465104,,,,,,,,,,,,,, -4300,0.81374586,5.1624703,,,,,,,,,,,,,, -4400,0.9549644,4.634536,,,,,,,,,,,,,, -4500,0.94776285,4.7797503,,,,,,,,,,,,,, -4600,0.8755977,4.385535,,,,,,,,,,,,,, -4700,1.1810313,4.5235667,,,,,,,,,,,,,, -4761,,,0.2685156166553497,3.618990182876587,0.2489599883556366,3.72690749168396,50000.0,0.1910000145435333,4.191851139068604,10000.0,2143.524181365967,2292.5040712356567,2143.524181365967,148.57643723487854,0.1443471908569336,0.0 -4800,0.7555922,6.210031,,,,,,,,,,,,,, -4900,0.51259637,6.1485133,,,,,,,,,,,,,, -5000,1.3275591,4.5425906,,,,,,,,,,,,,, -5100,0.82116985,4.8307185,,,,,,,,,,,,,, -5200,0.79534,4.604519,,,,,,,,,,,,,, -5300,0.925296,4.3744354,,,,,,,,,,,,,, -5400,1.0141841,4.3238583,,,,,,,,,,,,,, -5500,0.8674109,4.7655926,,,,,,,,,,,,,, -5600,0.75172895,5.357958,,,,,,,,,,,,,, -5700,0.93313617,4.1601343,,,,,,,,,,,,,, -5717,,,0.3191015422344208,3.2768993377685547,0.2922999858856201,3.4312615394592285,50000.0,0.2207000106573104,3.9930953979492174,10000.0,2563.565025806427,2734.2075912952423,2563.565025806427,170.15734696388245,0.1746456623077392,0.0 -5800,0.9091072,4.1399255,,,,,,,,,,,,,, -5900,1.1782764,4.7506127,,,,,,,,,,,,,, -6000,0.8452624,3.9723396,,,,,,,,,,,,,, -6100,0.7538952,4.638942,,,,,,,,,,,,,, -6200,0.9794911,4.595807,,,,,,,,,,,,,, -6300,0.7052724,5.211636,,,,,,,,,,,,,, -6400,0.8142692,4.2366776,,,,,,,,,,,,,, -6500,1.0536757,4.012181,,,,,,,,,,,,,, -6600,0.7363752,6.0313373,,,,,,,,,,,,,, -6672,,,0.3853124976158142,2.87861442565918,0.3353399932384491,3.149968147277832,50000.0,0.2574000060558319,3.691765785217285,10000.0,2983.585864305496,3175.8048005104065,2983.585864305496,191.6556527614593,0.2017805576324463,0.0 -6700,0.8748788,3.7791395,,,,,,,,,,,,,, -6800,0.9384577,4.309936,,,,,,,,,,,,,, -6900,0.95882267,3.7745802,,,,,,,,,,,,,, -7000,0.79011583,4.676941,,,,,,,,,,,,,, -7100,0.8630303,4.655184,,,,,,,,,,,,,, -7200,0.85169387,3.8111646,,,,,,,,,,,,,, -7300,0.9004422,3.8119957,,,,,,,,,,,,,, -7400,0.7874619,5.7993345,,,,,,,,,,,,,, -7500,0.9681681,3.819984,,,,,,,,,,,,,, -7600,0.57614213,5.9456563,,,,,,,,,,,,,, -7626,,,0.3844531178474426,2.8820815086364746,0.3578200042247772,3.024298667907715,50000.0,0.2787000238895416,3.58750057220459,10000.0,3403.8257796764374,3624.00797700882,3403.8257796764374,219.53815126419067,0.2309470176696777,0.0 -7700,0.57592475,5.8446383,,,,,,,,,,,,,, -7800,0.7973208,5.2695007,,,,,,,,,,,,,, -7900,0.9121862,3.8101196,,,,,,,,,,,,,, -8000,0.8136797,4.0398536,,,,,,,,,,,,,, -8100,0.96404517,3.7255335,,,,,,,,,,,,,, -8200,0.76312834,5.716934,,,,,,,,,,,,,, -8300,0.93470705,4.022759,,,,,,,,,,,,,, -8400,1.0274012,3.489875,,,,,,,,,,,,,, -8500,0.8569691,4.3267136,,,,,,,,,,,,,, -8585,,,0.4216406047344208,2.65896987915039,0.3878999948501587,2.8341987133026123,50000.0,0.2975000143051147,3.4198501110076904,10000.0,3823.89849114418,4069.170470237732,3823.89849114418,244.53706979751587,0.270564317703247,0.0 -8600,0.8064131,3.7271125,,,,,,,,,,,,,, -8700,0.995458,3.5761123,,,,,,,,,,,,,, -8800,0.98274595,3.780645,,,,,,,,,,,,,, -8900,0.9998303,3.752965,,,,,,,,,,,,,, -9000,0.78518355,4.854452,,,,,,,,,,,,,, -9100,0.8616872,4.42936,,,,,,,,,,,,,, -9200,0.6807359,4.8665676,,,,,,,,,,,,,, -9300,0.7052721,4.815254,,,,,,,,,,,,,, -9400,0.7305721,5.7212973,,,,,,,,,,,,,, -9500,0.8332075,4.3794684,,,,,,,,,,,,,, -9541,,,0.4546484351158142,2.4515960216522217,0.4140599966049194,2.663148641586304,50000.0,0.3157000243663788,3.2867302894592285,10000.0,4243.820187568665,4515.805619239807,4243.820187568665,271.1634876728058,0.3059976100921631,0.0 -9600,0.8432981,5.8013525,,,,,,,,,,,,,, -9700,0.9591341,3.8859591,,,,,,,,,,,,,, -9800,1.1717112,3.512985,,,,,,,,,,,,,, -9900,0.9900577,3.4053884,,,,,,,,,,,,,, -10000,0.7924251,5.7037253,,,,,,,,,,,,,, -10100,0.95003027,3.3624442,,,,,,,,,,,,,, -10200,0.89158344,4.34758,,,,,,,,,,,,,, -10300,0.7709022,5.3341813,,,,,,,,,,,,,, -10400,1.1453725,5.765778,,,,,,,,,,,,,, -10498,,,0.4586913883686065,2.4536895751953125,0.4279399812221527,2.6150896549224854,50000.0,0.3304000198841095,3.208634853363037,10000.0,4663.750853538513,4966.315071105957,4663.750853538513,301.6569554805756,0.3401913642883301,0.0 -10500,0.91386116,3.470016,,,,,,,,,,,,,, -10600,0.741304,4.872937,,,,,,,,,,,,,, -10700,0.91458875,3.778633,,,,,,,,,,,,,, -10800,1.0281149,3.3368149,,,,,,,,,,,,,, -10900,0.7299126,5.3443003,,,,,,,,,,,,,, -11000,0.97422355,5.8159385,,,,,,,,,,,,,, -11100,0.97584885,3.7188087,,,,,,,,,,,,,, -11200,0.94616395,3.683005,,,,,,,,,,,,,, -11300,1.0713484,3.3565197,,,,,,,,,,,,,, -11400,0.8324618,5.74574,,,,,,,,,,,,,, -11454,,,0.477832019329071,2.340218067169189,0.4437799751758575,2.513360977172852,50000.0,0.3394000232219696,3.15163803100586,10000.0,5083.722537994385,5419.739944219589,5083.722537994385,335.0212616920471,0.377582311630249,0.0 -11500,0.8099215,4.356706,,,,,,,,,,,,,, -11600,1.0084094,3.3295064,,,,,,,,,,,,,, -11700,1.0578827,3.3654165,,,,,,,,,,,,,, -11800,1.1621356,3.2436347,,,,,,,,,,,,,, -11900,1.1851233,3.6395044,,,,,,,,,,,,,, -12000,0.98234254,3.2828352,,,,,,,,,,,,,, -12100,0.79830503,4.5908213,,,,,,,,,,,,,, -12200,1.1119407,3.5558877,,,,,,,,,,,,,, -12300,0.90970707,3.4298813,,,,,,,,,,,,,, -12400,1.0787983,3.225039,,,,,,,,,,,,,, -12412,,,0.5044335722923279,2.160196781158448,0.4608799815177917,2.383046865463257,50000.0,0.356000006198883,3.023467540740967,10000.0,5504.124020576477,5868.309863567352,5504.124020576477,363.1043710708618,0.4115440845489502,0.0 -12500,0.84396964,4.1626267,,,,,,,,,,,,,, -12600,0.728126,5.261274,,,,,,,,,,,,,, -12700,1.1536405,3.23044,,,,,,,,,,,,,, -12800,1.0027905,5.209627,,,,,,,,,,,,,, -12900,1.1483519,3.3347645,,,,,,,,,,,,,, -13000,1.1557142,3.3146806,,,,,,,,,,,,,, -13100,0.7436491,5.7309136,,,,,,,,,,,,,, -13200,0.967636,4.616511,,,,,,,,,,,,,, -13300,0.830729,5.0040307,,,,,,,,,,,,,, -13365,,,0.52685546875,2.067809820175171,0.4763000011444092,2.321436882019043,50000.0,0.3673000037670135,2.974306106567383,10000.0,5924.25744843483,6314.410578966141,5924.25744843483,388.9872608184815,0.4456574916839599,0.0 -13400,0.9627873,3.5144484,,,,,,,,,,,,,, -13500,1.026738,3.149277,,,,,,,,,,,,,, -13600,1.3194011,3.1976192,,,,,,,,,,,,,, -13700,1.0734736,3.2211978,,,,,,,,,,,,,, -13800,0.9847028,5.500372,,,,,,,,,,,,,, -13900,0.812286,5.7002716,,,,,,,,,,,,,, -14000,0.87040514,5.4903526,,,,,,,,,,,,,, -14100,0.76097333,4.9962997,,,,,,,,,,,,,, -14200,0.7598653,4.605926,,,,,,,,,,,,,, -14300,1.0378227,2.9976306,,,,,,,,,,,,,, -14317,,,0.5179492235183716,2.1016087532043457,0.4814999997615814,2.2885682582855225,50000.0,0.3791000247001648,2.91217041015625,10000.0,6344.441004276276,6765.454939126968,6344.441004276276,419.7616608142853,0.4812161922454834,0.0 -14400,1.1686523,3.1205177,,,,,,,,,,,,,, -14500,0.952601,3.4360356,,,,,,,,,,,,,, -14600,1.0830716,2.9478881,,,,,,,,,,,,,, -14700,1.072916,3.0268483,,,,,,,,,,,,,, -14800,0.99389064,3.170178,,,,,,,,,,,,,, -14900,1.0098019,3.0468714,,,,,,,,,,,,,, -15000,0.826662,5.257898,,,,,,,,,,,,,, -15100,0.9954067,3.076493,,,,,,,,,,,,,, -15200,1.1041021,2.9771798,,,,,,,,,,,,,, -15266,,,0.5394921898841858,1.976670503616333,0.4970199763774872,2.2036519050598145,50000.0,0.3884000182151794,2.861824989318848,10000.0,6764.550142526627,7218.211961269379,6764.550142526627,452.32911682128906,0.5112817287445068,0.0 -15300,1.0615152,3.0776896,,,,,,,,,,,,,, -15400,1.2154467,3.1987746,,,,,,,,,,,,,, -15500,1.070837,3.0814698,,,,,,,,,,,,,, -15600,1.040626,3.2408638,,,,,,,,,,,,,, -15700,0.909335,3.8365593,,,,,,,,,,,,,, -15800,0.92616236,4.308678,,,,,,,,,,,,,, -15900,1.306347,2.9304383,,,,,,,,,,,,,, -16000,0.96792144,3.3316224,,,,,,,,,,,,,, -16100,1.1955878,3.059236,,,,,,,,,,,,,, -16200,1.0449415,2.9930668,,,,,,,,,,,,,, -16217,,,0.5404882431030273,2.0090866088867188,0.4944399893283844,2.2409517765045166,50000.0,0.3890000283718109,2.8712475299835205,10000.0,7184.7117347717285,7673.275277376175,7184.7117347717285,487.1478357315064,0.5434060096740723,0.0 -16300,1.0773095,2.9110794,,,,,,,,,,,,,, -16400,1.0656755,4.5183907,,,,,,,,,,,,,, -16500,0.81484914,5.323902,,,,,,,,,,,,,, -16600,0.83435476,3.8471727,,,,,,,,,,,,,, -16700,1.0457768,3.053572,,,,,,,,,,,,,, -16800,0.9283581,3.5386124,,,,,,,,,,,,,, -16900,1.0564126,3.004608,,,,,,,,,,,,,, -17000,1.1082401,3.1359987,,,,,,,,,,,,,, -17100,1.0829132,3.0144763,,,,,,,,,,,,,, -17168,,,0.5452538728713989,1.9710559844970703,0.5032199621200562,2.1825075149536133,50000.0,0.38960000872612,2.8298895359039307,10000.0,7604.762713670731,8129.259591817856,7604.762713670731,522.9998071193695,0.5738611221313477,0.0 -17200,1.1098695,2.9627597,,,,,,,,,,,,,, -17300,1.0740194,2.9138856,,,,,,,,,,,,,, -17400,0.919535,3.8201241,,,,,,,,,,,,,, -17500,1.0093304,3.462487,,,,,,,,,,,,,, -17600,0.8558601,5.580918,,,,,,,,,,,,,, -17700,0.98842025,3.542357,,,,,,,,,,,,,, -17800,1.3025749,3.0217357,,,,,,,,,,,,,, -17900,1.0491264,2.819655,,,,,,,,,,,,,, -18000,1.1256219,2.8582811,,,,,,,,,,,,,, -18100,1.0154645,5.0001154,,,,,,,,,,,,,, -18118,,,0.5589648485183716,1.9002271890640257,0.5151199698448181,2.097113847732544,50000.0,0.4029000103473663,2.759299993515014,10000.0,8025.002263069153,8585.051396369934,8025.002263069153,558.4737074375153,0.6011612415313721,0.0 -18200,1.0957444,2.9311342,,,,,,,,,,,,,, -18300,1.1295631,2.9205616,,,,,,,,,,,,,, -18400,1.0037019,4.7933784,,,,,,,,,,,,,, -18500,1.1373631,3.2081215,,,,,,,,,,,,,, -18600,0.87380946,5.225572,,,,,,,,,,,,,, -18700,0.84925765,4.908679,,,,,,,,,,,,,, -18800,0.8737152,4.394935,,,,,,,,,,,,,, -18900,1.2545123,2.8492136,,,,,,,,,,,,,, -19000,1.1802607,2.9952111,,,,,,,,,,,,,, -19064,,,0.5706835985183716,1.8196309804916384,0.5245599746704102,2.0479772090911865,50000.0,0.4075000286102295,2.712984800338745,10000.0,8445.054428339005,9042.762500047684,8445.054428339005,596.0462672710419,0.6375689506530762,0.0 -19100,1.1190083,2.9950116,,,,,,,,,,,,,, -19200,1.1094483,3.1992354,,,,,,,,,,,,,, -19300,1.2771611,3.4456599,,,,,,,,,,,,,, -19400,1.0075309,5.381778,,,,,,,,,,,,,, -19500,0.9033009,5.423044,,,,,,,,,,,,,, -19600,0.86361283,4.4362373,,,,,,,,,,,,,, -19700,1.349393,2.9496324,,,,,,,,,,,,,, -19800,1.1100781,2.9041998,,,,,,,,,,,,,, -19900,0.82164437,5.185606,,,,,,,,,,,,,, -20000,0.93239915,3.9971712,,,,,,,,,,,,,, -20015,,,0.5750390291213989,1.813962697982788,0.5193799734115601,2.0880651473999023,50000.0,0.4067000150680542,2.7463629245758057,10000.0,8865.098269224167,9497.909968614578,8865.098269224167,631.0662536621094,0.6704673767089844,0.0 -20100,0.92651683,4.6189737,,,,,,,,,,,,,, -20200,0.9988021,5.504145,,,,,,,,,,,,,, -20300,0.9391288,3.7023485,,,,,,,,,,,,,, -20400,1.0751597,2.9851604,,,,,,,,,,,,,, -20500,0.90226495,5.3384175,,,,,,,,,,,,,, -20600,1.1451899,2.8263326,,,,,,,,,,,,,, -20700,0.9839917,5.053564,,,,,,,,,,,,,, -20800,0.9469698,4.4400086,,,,,,,,,,,,,, -20900,1.1383582,2.8538973,,,,,,,,,,,,,, -20959,,,0.5743163824081421,1.8025161027908323,0.5355199575424194,2.0042288303375244,50000.0,0.4192000329494476,2.6817190647125244,10000.0,9285.359723567964,9955.302630662918,9285.359723567964,668.1161289215088,0.7022011280059814,0.0 -21000,1.1485456,2.8813477,,,,,,,,,,,,,, -21100,1.176533,2.9412093,,,,,,,,,,,,,, -21200,1.1144516,2.8269315,,,,,,,,,,,,,, -21300,1.2183397,2.9798641,,,,,,,,,,,,,, -21400,0.9927148,3.6772053,,,,,,,,,,,,,, -21500,1.1346314,2.9321272,,,,,,,,,,,,,, -21600,0.9130317,4.9380655,,,,,,,,,,,,,, -21700,1.2042564,2.728802,,,,,,,,,,,,,, -21800,0.853802,5.2082667,,,,,,,,,,,,,, -21900,1.2969528,2.7546763,,,,,,,,,,,,,, -21905,,,0.5809569954872131,1.789517521858215,0.5389999747276306,1.9918287992477417,50000.0,0.4274000227451324,2.6510984897613525,10000.0,9705.488340377808,10409.810875177383,9705.488340377808,702.4113335609436,0.7362079620361328,0.0 -22000,0.9747676,4.094474,,,,,,,,,,,,,, -22100,1.0523632,3.402071,,,,,,,,,,,,,, -22200,1.1646888,2.9801784,,,,,,,,,,,,,, -22300,1.195805,2.9186358,,,,,,,,,,,,,, -22400,1.0439563,2.6909754,,,,,,,,,,,,,, -22500,1.2085917,2.7769177,,,,,,,,,,,,,, -22600,1.169906,2.6614428,,,,,,,,,,,,,, -22700,0.8304616,5.205865,,,,,,,,,,,,,, -22800,1.0287968,4.3011055,,,,,,,,,,,,,, -22848,,,0.5851757526397705,1.7660349607467651,0.5362200140953064,2.0083796977996826,50000.0,0.4195000231266022,2.6842269897460938,10000.0,10125.69636964798,10865.72094798088,10125.69636964798,738.0171258449554,0.7833480834960938,0.0 -22900,0.8284901,5.346827,,,,,,,,,,,,,, -23000,0.9003316,4.0301585,,,,,,,,,,,,,, -23100,1.1126939,2.855842,,,,,,,,,,,,,, -23200,1.1580652,2.86785,,,,,,,,,,,,,, -23300,1.1492636,2.9499536,,,,,,,,,,,,,, -23400,1.0267804,2.9825895,,,,,,,,,,,,,, -23500,1.0949496,2.779293,,,,,,,,,,,,,, -23600,1.1368784,2.8774602,,,,,,,,,,,,,, -23700,1.1651165,2.8757553,,,,,,,,,,,,,, -23799,,,0.6068359017372131,1.6197729110717771,0.5496999621391296,1.906225562095642,50000.0,0.4353000223636627,2.58562970161438,10000.0,10545.699042081833,11321.33425951004,10545.699042081833,773.5448224544525,0.8156635761260986,0.0 -23800,1.1692426,5.064231,,,,,,,,,,,,,, -23900,1.0264993,5.308258,,,,,,,,,,,,,, -24000,0.98034585,4.6629224,,,,,,,,,,,,,, -24100,1.1027516,5.2092366,,,,,,,,,,,,,, -24200,0.9890532,3.581955,,,,,,,,,,,,,, -24300,0.98723227,5.425141,,,,,,,,,,,,,, -24400,0.8707542,4.709911,,,,,,,,,,,,,, -24500,1.1115896,3.28398,,,,,,,,,,,,,, -24600,1.0176456,4.4672117,,,,,,,,,,,,,, -24700,1.2213782,2.831305,,,,,,,,,,,,,, -24751,,,0.5919336080551147,1.7640339136123655,0.5448200106620789,1.9703313112258911,50000.0,0.4314000308513641,2.6117215156555176,10000.0,10965.678258657455,11776.092476606367,10965.678258657455,808.2378516197205,0.850771427154541,0.0 -24800,1.1539719,2.7291925,,,,,,,,,,,,,, -24900,1.2879319,2.5812964,,,,,,,,,,,,,, -25000,1.0081451,3.1070518,,,,,,,,,,,,,, -25100,1.0138053,3.7035024,,,,,,,,,,,,,, -25200,1.2074387,2.7475796,,,,,,,,,,,,,, -25300,1.3384494,2.7647781,,,,,,,,,,,,,, -25400,1.1882435,2.6959257,,,,,,,,,,,,,, -25500,1.0554063,4.2969046,,,,,,,,,,,,,, -25600,1.0830889,2.713995,,,,,,,,,,,,,, -25696,,,0.6055663824081421,1.681242823600769,0.560539960861206,1.8968688249588013,50000.0,0.4390000104904175,2.579681873321533,10000.0,11385.718565702438,12231.912591457369,11385.718565702438,843.9393339157104,0.8788700103759766,0.0 -25700,1.1717474,2.7460246,,,,,,,,,,,,,, -25800,1.0246893,3.7627053,,,,,,,,,,,,,, -25900,1.0378082,3.8598228,,,,,,,,,,,,,, -26000,1.2990141,2.766554,,,,,,,,,,,,,, -26100,0.95166427,3.289452,,,,,,,,,,,,,, -26200,1.0831043,4.4445777,,,,,,,,,,,,,, -26300,1.2646904,2.6900864,,,,,,,,,,,,,, -26400,0.93513024,4.1286373,,,,,,,,,,,,,, -26500,1.2170974,2.6996362,,,,,,,,,,,,,, -26600,0.9672953,4.6912084,,,,,,,,,,,,,, -26633,,,0.6142968535423279,1.5852680206298828,0.5594599843025208,1.871098756790161,50000.0,0.4310000240802765,2.5757434368133545,10000.0,11805.5884308815,12686.59158539772,11805.5884308815,878.3297283649445,1.248054265975952,0.0 -26700,0.93272847,4.063575,,,,,,,,,,,,,, -26800,1.1430478,5.1981883,,,,,,,,,,,,,, -26900,0.8879929,4.3335795,,,,,,,,,,,,,, -27000,1.2153503,3.140564,,,,,,,,,,,,,, -27100,0.976222,3.4674888,,,,,,,,,,,,,, -27200,1.0258497,4.206315,,,,,,,,,,,,,, -27300,0.97785056,3.6398492,,,,,,,,,,,,,, -27400,1.0980967,4.4707727,,,,,,,,,,,,,, -27500,1.0500875,4.383661,,,,,,,,,,,,,, -27578,,,0.6026562452316284,1.6675916910171509,0.5617600083351135,1.8550418615341189,50000.0,0.442300021648407,2.531644105911255,10000.0,12225.917182445526,13141.4329662323,12225.917182445526,912.7611672878264,1.279726266860962,0.0 -27600,1.1996074,5.387237,,,,,,,,,,,,,, -27700,1.2048625,2.9690154,,,,,,,,,,,,,, -27800,1.1004432,2.724043,,,,,,,,,,,,,, -27900,1.1663715,3.3114057,,,,,,,,,,,,,, -28000,1.1655135,2.8441858,,,,,,,,,,,,,, -28100,1.3038025,2.6330085,,,,,,,,,,,,,, -28200,1.2535076,2.722537,,,,,,,,,,,,,, -28300,1.1157653,2.5800219,,,,,,,,,,,,,, -28400,1.215606,2.6019812,,,,,,,,,,,,,, -28500,1.2119447,2.6951358,,,,,,,,,,,,,, -28520,,,0.6154882907867432,1.6122013330459597,0.5725399851799011,1.828546643257141,50000.0,0.4536000192165375,2.510180950164795,10000.0,12645.979301929474,13596.697838544846,12645.979301929474,947.8815426826476,1.3121697902679443,0.0 -28600,1.2087928,2.7488143,,,,,,,,,,,,,, -28700,1.1809031,2.7344244,,,,,,,,,,,,,, -28800,1.2333127,2.6318226,,,,,,,,,,,,,, -28900,1.1140182,2.6875813,,,,,,,,,,,,,, -29000,1.205487,2.8019795,,,,,,,,,,,,,, -29100,1.110916,2.8678792,,,,,,,,,,,,,, -29200,0.95721185,3.8906493,,,,,,,,,,,,,, -29300,0.9970534,3.4100938,,,,,,,,,,,,,, -29400,0.9651463,4.6181836,,,,,,,,,,,,,, -29466,,,0.6235937476158142,1.5766786336898804,0.5673199892044067,1.8261879682540887,50000.0,0.4462000131607055,2.499765634536743,10000.0,13066.364196300508,14052.40638923645,13066.364196300508,983.126780986786,1.3409326076507568,0.0 -29500,1.173242,2.7551184,,,,,,,,,,,,,, -29600,1.0471139,3.7761786,,,,,,,,,,,,,, -29700,1.2699002,2.6417818,,,,,,,,,,,,,, -29800,1.143106,2.634532,,,,,,,,,,,,,, -29900,1.1287999,2.5970988,,,,,,,,,,,,,, -30000,1.0606915,3.7682161,,,,,,,,,,,,,, -30100,0.92067933,4.1795125,,,,,,,,,,,,,, -30200,1.043401,4.777623,,,,,,,,,,,,,, -30300,1.1442255,2.6804552,,,,,,,,,,,,,, -30400,1.2134315,2.492616,,,,,,,,,,,,,, -30410,,,0.6274023056030273,1.577084183692932,0.5665199756622314,1.86709463596344,50000.0,0.4428000152111053,2.5430333614349365,10000.0,13486.317937612534,14507.29845237732,13486.317937612534,1017.985541820526,1.370903491973877,0.0 -30500,1.2227534,2.578116,,,,,,,,,,,,,, -30600,0.9843423,4.710114,,,,,,,,,,,,,, -30700,1.0783491,2.6018045,,,,,,,,,,,,,, -30800,1.1064752,2.6067264,,,,,,,,,,,,,, -30900,0.9362551,4.925211,,,,,,,,,,,,,, -31000,1.1584302,2.5746145,,,,,,,,,,,,,, -31100,1.0465581,3.0778258,,,,,,,,,,,,,, -31200,0.99387664,5.11162,,,,,,,,,,,,,, -31300,0.9469478,3.7561808,,,,,,,,,,,,,, -31356,,,0.6219531297683716,1.5759414434432983,0.581279993057251,1.7813587188720703,50000.0,0.4594000279903412,2.448460102081299,10000.0,13906.673711299896,14962.702766418455,13906.673711299896,1052.953050851822,1.4022631645202637,0.0 -31400,1.107769,3.223869,,,,,,,,,,,,,, -31500,0.97451836,5.002713,,,,,,,,,,,,,, -31600,1.1284819,5.2701316,,,,,,,,,,,,,, -31700,1.1607068,2.490797,,,,,,,,,,,,,, -31800,0.98663765,4.1487875,,,,,,,,,,,,,, -31900,1.2243937,2.5679235,,,,,,,,,,,,,, -32000,1.1120294,2.885366,,,,,,,,,,,,,, -32100,1.0291951,5.0092783,,,,,,,,,,,,,, -32200,1.028215,4.3228283,,,,,,,,,,,,,, -32300,1.1643714,2.9376373,,,,,,,,,,,,,, -32301,,,0.6274804472923279,1.536577582359314,0.5806599855422974,1.76521897315979,50000.0,0.4629000127315521,2.435957193374634,10000.0,14327.127075195312,15418.314130306244,14327.127075195312,1088.0273683071136,1.436601638793945,0.0 -32400,1.0127714,4.067987,,,,,,,,,,,,,, -32500,1.1966802,2.7518125,,,,,,,,,,,,,, -32600,1.3497163,2.5337782,,,,,,,,,,,,,, -32700,1.1375376,2.5498543,,,,,,,,,,,,,, -32800,1.1550466,2.6897798,,,,,,,,,,,,,, -32900,1.1640841,2.6090188,,,,,,,,,,,,,, -33000,1.124273,3.5402975,,,,,,,,,,,,,, -33100,1.0105333,4.609523,,,,,,,,,,,,,, -33200,1.313343,2.6492603,,,,,,,,,,,,,, -33248,,,0.6382226347923279,1.4830291271209717,0.5817999839782715,1.751964449882507,50000.0,0.4598000347614288,2.4462242126464844,10000.0,14747.1489007473,15871.41714978218,14747.1489007473,1121.0258762836456,1.4687433242797852,0.0 -33300,1.1314348,2.5604925,,,,,,,,,,,,,, -33400,1.058029,3.221877,,,,,,,,,,,,,, -33500,1.1758696,2.6279626,,,,,,,,,,,,,, -33600,1.0763655,4.9296,,,,,,,,,,,,,, -33700,1.2253872,2.6703744,,,,,,,,,,,,,, -33800,1.1306496,2.545767,,,,,,,,,,,,,, -33900,1.2468762,2.5705285,,,,,,,,,,,,,, -34000,1.1202002,2.9879353,,,,,,,,,,,,,, -34100,1.2254157,2.5606937,,,,,,,,,,,,,, -34193,,,0.6274804472923279,1.5200376510620115,0.5874999761581421,1.736883521080017,50000.0,0.4635000228881836,2.436673402786255,10000.0,15167.075635671616,16326.279378652573,15167.075635671616,1155.8789296150208,1.502777338027954,0.0 -34200,1.1666813,2.4958854,,,,,,,,,,,,,, -34300,1.1916382,2.676729,,,,,,,,,,,,,, -34400,1.2513338,2.7171226,,,,,,,,,,,,,, -34500,1.1907945,2.563634,,,,,,,,,,,,,, -34600,1.3269029,4.84118,,,,,,,,,,,,,, -34700,1.2091471,2.6767082,,,,,,,,,,,,,, -34800,1.2481854,2.3846793,,,,,,,,,,,,,, -34900,1.278452,2.6589675,,,,,,,,,,,,,, -35000,1.127184,2.8154922,,,,,,,,,,,,,, -35100,1.100719,2.5409226,,,,,,,,,,,,,, -35139,,,0.6370312571525574,1.5092811584472656,0.5901399850845337,1.7288663387298584,50000.0,0.4652000367641449,2.4026684761047363,10000.0,15587.07857298851,16780.82415175438,15587.07857298851,1190.3300416469574,1.5448389053344729,0.0 -35200,1.1092387,3.023422,,,,,,,,,,,,,, -35300,1.0983182,3.3191319,,,,,,,,,,,,,, -35400,1.1826049,2.5367699,,,,,,,,,,,,,, -35500,1.0120949,3.3337379,,,,,,,,,,,,,, -35600,1.1279292,2.543155,,,,,,,,,,,,,, -35700,1.1318369,4.96092,,,,,,,,,,,,,, -35800,1.1697959,3.1015775,,,,,,,,,,,,,, -35900,1.273455,2.5008793,,,,,,,,,,,,,, -36000,1.3824543,2.6117456,,,,,,,,,,,,,, -36086,,,0.6399999856948853,1.4712377786636353,0.5904200077056885,1.7305066585540771,50000.0,0.4682000279426574,2.412063837051392,10000.0,16007.0850918293,17236.377789497375,16007.0850918293,1225.792776346207,1.5797712802886963,0.0 -36100,0.99453276,4.0781746,,,,,,,,,,,,,, -36200,1.1864842,2.5826573,,,,,,,,,,,,,, -36300,1.1444932,2.4265475,,,,,,,,,,,,,, -36400,1.0106572,5.147317,,,,,,,,,,,,,, -36500,1.17804,2.4708135,,,,,,,,,,,,,, -36600,1.22017,2.6863596,,,,,,,,,,,,,, -36700,1.0900972,2.9219368,,,,,,,,,,,,,, -36800,1.160164,4.012141,,,,,,,,,,,,,, -36900,1.297041,2.591496,,,,,,,,,,,,,, -37000,1.0876386,4.0060544,,,,,,,,,,,,,, -37030,,,0.6620116829872131,1.3794057369232178,0.592519998550415,1.7107194662094116,50000.0,0.4718000292778015,2.398625612258911,10000.0,16427.407952070236,17691.82711672783,16427.407952070236,1260.8337228298187,1.616624116897583,0.0 -37100,1.1151642,2.599812,,,,,,,,,,,,,, -37200,1.20812,4.7822104,,,,,,,,,,,,,, -37300,1.0370002,2.8122258,,,,,,,,,,,,,, -37400,0.95902514,4.1017447,,,,,,,,,,,,,, -37500,1.1966752,2.8011303,,,,,,,,,,,,,, -37600,1.2338338,2.665105,,,,,,,,,,,,,, -37700,1.0814588,5.113159,,,,,,,,,,,,,, -37800,1.2805444,2.5906398,,,,,,,,,,,,,, -37900,1.1804501,3.1523554,,,,,,,,,,,,,, -37975,,,0.6415038704872131,1.4643170833587646,0.5976799726486206,1.683479905128479,50000.0,0.4763000309467315,2.375365257263184,10000.0,16847.74187517166,18147.40429544449,16847.74187517166,1295.9946694374084,1.6496033668518066,0.0 -38000,1.0326716,3.0454335,,,,,,,,,,,,,, -38100,1.2194754,2.5777426,,,,,,,,,,,,,, -38200,1.230752,2.6348126,,,,,,,,,,,,,, -38300,1.1450179,2.5230813,,,,,,,,,,,,,, -38400,1.1725556,2.5320816,,,,,,,,,,,,,, -38500,1.0296061,3.9415145,,,,,,,,,,,,,, -38600,1.3146527,2.5381482,,,,,,,,,,,,,, -38700,1.266895,2.6400309,,,,,,,,,,,,,, -38800,1.0919614,5.299731,,,,,,,,,,,,,, -38900,1.2112705,2.5850096,,,,,,,,,,,,,, -38922,,,0.6445702910423279,1.4891875982284546,0.5958200097084045,1.7240655422210691,50000.0,0.4737000167369842,2.377290725708008,10000.0,17267.9571518898,18601.86097741127,17267.9571518898,1330.1508178710938,1.68623948097229,0.0 -39000,1.3427856,2.5177329,,,,,,,,,,,,,, -39100,1.0457023,3.5520036,,,,,,,,,,,,,, -39200,1.0044054,3.6856186,,,,,,,,,,,,,, -39300,1.3048046,2.5182197,,,,,,,,,,,,,, -39400,1.2089934,2.4571257,,,,,,,,,,,,,, -39500,0.9835787,3.8829074,,,,,,,,,,,,,, -39600,1.0898389,3.479741,,,,,,,,,,,,,, -39700,1.3156639,3.0004125,,,,,,,,,,,,,, -39800,1.2007627,2.3792508,,,,,,,,,,,,,, -39869,,,0.6550585627555847,1.4371448755264282,0.5964199900627136,1.7216540575027466,50000.0,0.4751000106334686,2.372420072555542,10000.0,17688.084655046463,19056.70727801323,17688.084655046463,1364.786494731903,1.719895601272583,0.0 -39900,1.2824271,2.547844,,,,,,,,,,,,,, -40000,1.2197319,2.4750001,,,,,,,,,,,,,, -40100,1.2674892,2.6249092,,,,,,,,,,,,,, -40200,1.1555836,2.8023636,,,,,,,,,,,,,, -40300,1.265906,2.5395036,,,,,,,,,,,,,, -40400,1.0337006,4.2775354,,,,,,,,,,,,,, -40500,1.2394334,2.455035,,,,,,,,,,,,,, -40600,1.198485,2.5001616,,,,,,,,,,,,,, -40700,1.1383971,2.4021983,,,,,,,,,,,,,, -40800,1.0095878,3.6620429,,,,,,,,,,,,,, -40814,,,0.6373828053474426,1.492069959640503,0.5970399975776672,1.6841896772384644,50000.0,0.4708000123500824,2.375142335891724,10000.0,18108.195221424103,19510.174768209457,18108.195221424103,1398.0626657009125,1.75140643119812,0.0 -40900,1.1044613,3.3095894,,,,,,,,,,,,,, -41000,1.1546869,2.9371536,,,,,,,,,,,,,, -41100,1.1509374,4.65441,,,,,,,,,,,,,, -41200,1.1980945,2.3960514,,,,,,,,,,,,,, -41300,0.99592084,3.949919,,,,,,,,,,,,,, -41400,1.1935765,2.5192246,,,,,,,,,,,,,, -41500,1.023718,4.2863646,,,,,,,,,,,,,, -41600,1.187492,2.5485075,,,,,,,,,,,,,, -41700,1.1858627,3.0692637,,,,,,,,,,,,,, -41758,,,0.6490429639816284,1.4463257789611816,0.5996599793434143,1.6779924631118774,50000.0,0.480400025844574,2.345698833465576,10000.0,18528.375126600266,19964.89772510529,18528.375126600266,1432.521831035614,1.785491704940796,0.0 -41800,1.0048257,5.073668,,,,,,,,,,,,,, -41900,1.0179521,4.9469705,,,,,,,,,,,,,, -42000,1.2505324,2.506269,,,,,,,,,,,,,, -42100,1.0241601,3.6777697,,,,,,,,,,,,,, -42200,1.2116493,2.4534435,,,,,,,,,,,,,, -42300,1.2588037,3.8861165,,,,,,,,,,,,,, -42400,1.0383238,5.0734434,,,,,,,,,,,,,, -42500,1.2499639,2.532209,,,,,,,,,,,,,, -42600,1.2224222,2.8047075,,,,,,,,,,,,,, -42700,1.2551479,2.4338415,,,,,,,,,,,,,, -42704,,,0.6508007645606995,1.487926721572876,0.5974400043487549,1.734683871269226,50000.0,0.4749000370502472,2.391242742538452,10000.0,18948.30942368508,20420.125157356262,18948.30942368508,1467.7309651374817,1.8197968006134035,0.0 -42800,1.57637,5.279996,,,,,,,,,,,,,, -42900,1.2082658,2.3607273,,,,,,,,,,,,,, -43000,0.9591841,4.713305,,,,,,,,,,,,,, -43100,1.1942364,2.352707,,,,,,,,,,,,,, -43200,1.288153,2.4317522,,,,,,,,,,,,,, -43300,1.4293162,2.5008864,,,,,,,,,,,,,, -43400,1.3250049,2.5397487,,,,,,,,,,,,,, -43500,1.3180739,2.6978111,,,,,,,,,,,,,, -43600,1.2245243,2.4993386,,,,,,,,,,,,,, -43650,,,0.684277355670929,1.2981736660003662,0.6021400094032288,1.6690380573272705,50000.0,0.4786000251770019,2.350553274154663,10000.0,19368.389864444733,20876.173667669296,19368.389864444733,1503.6146020889282,1.8551933765411377,0.0 -43700,1.1106586,2.3971338,,,,,,,,,,,,,, -43800,1.1954167,2.430955,,,,,,,,,,,,,, -43900,1.3460304,2.805135,,,,,,,,,,,,,, -44000,1.2412328,2.4466543,,,,,,,,,,,,,, -44100,1.1569681,3.9905653,,,,,,,,,,,,,, -44200,1.3819703,2.4895973,,,,,,,,,,,,,, -44300,1.0685633,4.9115195,,,,,,,,,,,,,, -44400,1.1751838,3.0016,,,,,,,,,,,,,, -44500,1.7045406,5.2621946,,,,,,,,,,,,,, -44598,,,0.6547460556030273,1.4286702871322632,0.6058399677276611,1.6572929620742798,50000.0,0.4823000133037567,2.3217082023620605,10000.0,19788.714443922043,21331.062341213223,19788.714443922043,1538.0971751213074,1.8865134716033936,0.0 -44600,1.285448,2.4128206,,,,,,,,,,,,,, -44700,1.2644719,4.8309526,,,,,,,,,,,,,, -44800,1.3102405,2.7144248,,,,,,,,,,,,,, -44900,1.2217313,2.5583827,,,,,,,,,,,,,, -45000,1.2561014,2.4853516,,,,,,,,,,,,,, -45100,1.4044193,2.4187465,,,,,,,,,,,,,, -45200,1.2157171,2.4627478,,,,,,,,,,,,,, -45300,1.2119786,2.6113176,,,,,,,,,,,,,, -45400,1.263931,2.4410462,,,,,,,,,,,,,, -45500,1.1135315,3.3401375,,,,,,,,,,,,,, -45540,,,0.660351574420929,1.3766939640045166,0.6067799925804138,1.6368731260299685,50000.0,0.4834000170230865,2.2972004413604736,10000.0,20208.8225941658,21785.28338742256,20208.8225941658,1572.1255660057068,1.9221677780151367,0.0 -45600,1.1385725,2.2905898,,,,,,,,,,,,,, -45700,1.0016575,4.943992,,,,,,,,,,,,,, -45800,1.2956101,2.4106953,,,,,,,,,,,,,, -45900,1.0603933,3.607759,,,,,,,,,,,,,, -46000,1.1277891,2.7348843,,,,,,,,,,,,,, -46100,1.229444,2.3299074,,,,,,,,,,,,,, -46200,1.2574618,2.4354324,,,,,,,,,,,,,, -46300,1.2121252,2.483388,,,,,,,,,,,,,, -46400,1.046183,4.6927223,,,,,,,,,,,,,, -46483,,,0.6699609160423279,1.3838269710540771,0.6090999841690063,1.6606667041778564,50000.0,0.4816000163555145,2.329615592956543,10000.0,20628.96104216576,22240.189120054245,20628.96104216576,1606.8116953372955,1.9543015956878664,0.0 -46500,1.2690717,2.381744,,,,,,,,,,,,,, -46600,1.1556662,2.6309035,,,,,,,,,,,,,, -46700,1.3073223,2.4341135,,,,,,,,,,,,,, -46800,1.1523336,2.8250957,,,,,,,,,,,,,, -46900,1.341119,2.3936872,,,,,,,,,,,,,, -47000,1.4729269,5.043162,,,,,,,,,,,,,, -47100,1.3246627,2.4440722,,,,,,,,,,,,,, -47200,1.253425,2.504995,,,,,,,,,,,,,, -47300,1.5105829,5.0960364,,,,,,,,,,,,,, -47400,1.2432115,2.3523889,,,,,,,,,,,,,, -47427,,,0.6600781083106995,1.3950390815734863,0.6157999634742737,1.612318754196167,50000.0,0.4983000159263611,2.286940097808838,10000.0,21049.05564045906,22696.17867732048,21049.05564045906,1642.6166734695437,1.9951858520507808,0.0 -47500,1.0752927,3.2190316,,,,,,,,,,,,,, -47600,1.0694995,3.703523,,,,,,,,,,,,,, -47700,1.2110935,2.8250282,,,,,,,,,,,,,, -47800,1.1444892,2.5160155,,,,,,,,,,,,,, -47900,1.335871,2.4461248,,,,,,,,,,,,,, -48000,1.0037881,4.7100587,,,,,,,,,,,,,, -48100,1.0735785,3.4067307,,,,,,,,,,,,,, -48200,1.0489411,3.7467475,,,,,,,,,,,,,, -48300,1.0181801,3.6498015,,,,,,,,,,,,,, -48373,,,0.6630077958106995,1.3899929523468018,0.6116600036621094,1.618957281112671,50000.0,0.4912000298500061,2.276652336120605,10000.0,21469.06161904335,23152.705780267715,21469.06161904335,1679.050225496292,2.0333759784698486,0.0 -48400,1.1810868,2.5131092,,,,,,,,,,,,,, -48500,1.3251204,2.220574,,,,,,,,,,,,,, -48600,1.2093983,2.4902499,,,,,,,,,,,,,, -48700,1.286828,2.3324096,,,,,,,,,,,,,, -48800,0.99539214,4.176358,,,,,,,,,,,,,, -48900,1.0445399,4.218764,,,,,,,,,,,,,, -49000,1.2162267,2.689359,,,,,,,,,,,,,, -49100,1.2550758,2.3826747,,,,,,,,,,,,,, -49200,1.1989534,2.8957682,,,,,,,,,,,,,, -49300,1.4034771,2.4259496,,,,,,,,,,,,,, -49317,,,0.6714062094688416,1.346708059310913,0.617680013179779,1.5923179388046265,50000.0,0.4905000329017639,2.2634685039520264,10000.0,21889.25936102867,23608.112749814987,21889.25936102867,1714.1676445007324,2.0756874084472656,0.0 -49400,0.9794516,3.7782576,,,,,,,,,,,,,, -49500,1.188405,2.4948905,,,,,,,,,,,,,, -49600,1.0286293,3.8763142,,,,,,,,,,,,,, -49700,1.0723702,3.678457,,,,,,,,,,,,,, -49800,1.193076,3.5825014,,,,,,,,,,,,,, -49900,1.2086849,2.6545029,,,,,,,,,,,,,, -50000,1.1760799,4.402931,,,,,,,,,,,,,, -50100,1.0431122,3.4686475,,,,,,,,,,,,,, -50200,1.154687,3.5757902,,,,,,,,,,,,,, -50263,,,0.696582019329071,1.2489502429962158,0.6180999875068665,1.5910358428955078,50000.0,0.4958000183105469,2.265782594680786,10000.0,22309.52768969536,24063.864980220795,22309.52768969536,1749.5697557926178,2.108650684356689,0.0 -50300,1.0495954,3.0815687,,,,,,,,,,,,,, -50400,1.1660169,2.4772055,,,,,,,,,,,,,, -50500,1.2326932,2.4475403,,,,,,,,,,,,,, -50600,1.2791818,2.4830642,,,,,,,,,,,,,, -50700,1.2646825,2.2768273,,,,,,,,,,,,,, -50800,1.3687489,2.2959552,,,,,,,,,,,,,, -50900,1.3146697,2.5754535,,,,,,,,,,,,,, -51000,1.103184,3.2818906,,,,,,,,,,,,,, -51100,1.3156292,2.3214896,,,,,,,,,,,,,, -51200,0.972365,4.325509,,,,,,,,,,,,,, -51208,,,0.6667382717132568,1.3661245107650757,0.6186599731445312,1.5887244939804075,50000.0,0.4981000125408172,2.2476613521575928,10000.0,22729.79816842079,24520.10476160049,22729.79816842079,1785.455048084259,2.143319845199585,0.0 -51300,1.2458324,2.7613657,,,,,,,,,,,,,, -51400,1.1698039,3.4156113,,,,,,,,,,,,,, -51500,1.2595271,2.4836822,,,,,,,,,,,,,, -51600,1.3154393,2.8043199,,,,,,,,,,,,,, -51700,1.2021135,2.4450712,,,,,,,,,,,,,, -51800,1.1923869,3.1755774,,,,,,,,,,,,,, -51900,1.0262476,4.6870503,,,,,,,,,,,,,, -52000,1.1217189,4.5130577,,,,,,,,,,,,,, -52100,1.2539567,2.7111351,,,,,,,,,,,,,, -52154,,,0.6708202958106995,1.3567956686019895,0.6210799813270569,1.588613986968994,50000.0,0.4946000277996063,2.257022857666016,10000.0,23150.10472869873,24976.49471473694,23150.10472869873,1821.455227851868,2.177870750427246,0.0 -52200,1.225659,2.3436208,,,,,,,,,,,,,, -52300,1.3463444,2.3430421,,,,,,,,,,,,,, -52400,1.0457697,3.2857332,,,,,,,,,,,,,, -52500,1.1200763,2.859001,,,,,,,,,,,,,, -52600,1.249573,2.3070428,,,,,,,,,,,,,, -52700,1.0739338,4.8199286,,,,,,,,,,,,,, -52800,1.3622166,4.7537017,,,,,,,,,,,,,, -52900,1.1451818,3.1937964,,,,,,,,,,,,,, -53000,1.2755463,2.3841076,,,,,,,,,,,,,, -53100,,,0.6807812452316284,1.2781988382339478,0.6184599995613098,1.5739054679870603,50000.0,0.4946000277996063,2.253964424133301,10000.0,23570.43446731568,25432.46706700325,23570.43446731568,1857.0161018371584,2.210293292999268,0.0 -53100,1.0756607,3.5330684,,,,,,,,,,,,,, -53200,1.2334946,3.9289386,,,,,,,,,,,,,, -53300,1.1797622,2.6645262,,,,,,,,,,,,,, -53400,1.2022269,2.4801843,,,,,,,,,,,,,, -53500,1.3374263,2.345292,,,,,,,,,,,,,, -53600,1.0829118,4.284131,,,,,,,,,,,,,, -53700,1.2724444,2.3006124,,,,,,,,,,,,,, -53800,1.1962907,4.3607435,,,,,,,,,,,,,, -53900,1.3155662,2.2922058,,,,,,,,,,,,,, -54000,1.1977859,2.8546014,,,,,,,,,,,,,, -54046,,,0.6664257645606995,1.3784408569335938,0.621679961681366,1.5963414907455444,50000.0,0.4976000189781189,2.259164571762085,10000.0,23990.75499391556,25888.12592768669,23990.75499391556,1892.2611465454104,2.2538437843322754,0.0 -54100,1.1663216,4.628264,,,,,,,,,,,,,, -54200,1.3057573,3.6611516,,,,,,,,,,,,,, -54300,1.3083719,2.36644,,,,,,,,,,,,,, -54400,1.114098,4.7440658,,,,,,,,,,,,,, -54500,1.2776103,4.6815214,,,,,,,,,,,,,, -54600,1.3071446,2.2238855,,,,,,,,,,,,,, -54700,1.2424852,2.3614879,,,,,,,,,,,,,, -54800,1.1310732,2.4259837,,,,,,,,,,,,,, -54900,1.2210773,2.682622,,,,,,,,,,,,,, -54991,,,0.6720898151397705,1.3343063592910769,0.6214399933815002,1.5615304708480835,50000.0,0.5001000165939331,2.2374267578125,10000.0,24410.766496419907,26344.1343998909,24410.766496419907,1928.172496318817,2.290111303329468,0.0 -55000,1.1514586,2.992862,,,,,,,,,,,,,, -55100,1.2871163,4.906395,,,,,,,,,,,,,, -55200,1.3135362,2.2907884,,,,,,,,,,,,,, -55300,1.375356,2.6499338,,,,,,,,,,,,,, -55400,1.202193,2.5697672,,,,,,,,,,,,,, -55500,1.2399689,2.2850227,,,,,,,,,,,,,, -55600,1.2151346,3.252194,,,,,,,,,,,,,, -55700,1.1375732,3.3117545,,,,,,,,,,,,,, -55800,1.1810095,4.8364854,,,,,,,,,,,,,, -55900,1.2424011,2.3337283,,,,,,,,,,,,,, -55936,,,0.6838671565055847,1.2733802795410156,0.6282599568367004,1.539625644683838,50000.0,0.5035000443458557,2.206250190734864,10000.0,24830.783933877945,26799.984984874725,24830.783933877945,1963.9228575229645,2.3235397338867188,0.0 -56000,1.1989585,2.3476696,,,,,,,,,,,,,, -56100,1.3889896,2.368789,,,,,,,,,,,,,, -56200,1.0833176,3.4181404,,,,,,,,,,,,,, -56300,1.4506294,2.2811909,,,,,,,,,,,,,, -56400,1.2557791,4.444822,,,,,,,,,,,,,, -56500,1.1352428,2.8918715,,,,,,,,,,,,,, -56600,1.3776867,2.4496326,,,,,,,,,,,,,, -56700,1.2527059,2.2997687,,,,,,,,,,,,,, -56800,1.1814852,3.1703959,,,,,,,,,,,,,, -56882,,,0.7085351347923279,1.1974025964736938,0.6280999779701233,1.562067627906799,50000.0,0.5056000351905823,2.223431348800659,10000.0,25250.83206653595,27254.522886514664,25250.83206653595,1998.3244211673737,2.3617398738861084,0.0 -56900,1.2645644,2.4385,,,,,,,,,,,,,, -57000,1.1028141,2.4485917,,,,,,,,,,,,,, -57100,1.3170148,2.3215733,,,,,,,,,,,,,, -57200,1.188847,3.0159974,,,,,,,,,,,,,, -57300,1.0370576,3.987545,,,,,,,,,,,,,, -57400,1.1476125,4.8666363,,,,,,,,,,,,,, -57500,1.1936704,2.378848,,,,,,,,,,,,,, -57600,1.1259891,4.157237,,,,,,,,,,,,,, -57700,1.1449192,3.385326,,,,,,,,,,,,,, -57800,1.2700418,3.5562313,,,,,,,,,,,,,, -57826,,,0.6792578101158142,1.2947694063186646,0.6309199929237366,1.5316017866134644,50000.0,0.5074000358581543,2.1904778480529785,10000.0,25671.16929149628,27709.973189353943,25671.16929149628,2033.3489353656769,2.401756525039673,0.0 -57900,1.1671015,4.857052,,,,,,,,,,,,,, -58000,1.2932266,2.244881,,,,,,,,,,,,,, -58100,1.3694398,2.3838634,,,,,,,,,,,,,, -58200,1.2589321,2.4506333,,,,,,,,,,,,,, -58300,1.3572506,4.9595222,,,,,,,,,,,,,, -58400,1.2339422,2.2091131,,,,,,,,,,,,,, -58500,1.277249,4.578297,,,,,,,,,,,,,, -58600,1.3190503,2.2829318,,,,,,,,,,,,,, -58700,1.2167369,3.2747169,,,,,,,,,,,,,, -58771,,,0.6832422018051147,1.3048343658447266,0.6307799816131592,1.5494695901870728,50000.0,0.5051000118255615,2.211732625961304,10000.0,26091.416990995407,28165.26349496841,26091.416990995407,2068.3087170124054,2.434241533279419,0.0 -58800,1.054162,4.8777013,,,,,,,,,,,,,, -58900,1.1392179,3.6609712,,,,,,,,,,,,,, -59000,1.3470308,2.6784465,,,,,,,,,,,,,, -59100,1.2229611,4.5777416,,,,,,,,,,,,,, -59200,1.1043873,4.171634,,,,,,,,,,,,,, -59300,1.2248547,2.3402479,,,,,,,,,,,,,, -59400,1.1578285,2.5837204,,,,,,,,,,,,,, -59500,1.1717042,2.6005828,,,,,,,,,,,,,, -59600,1.2986156,2.3472261,,,,,,,,,,,,,, -59700,1.2758651,2.376272,,,,,,,,,,,,,, -59717,,,0.6934570074081421,1.2318331003189087,0.6297599673271179,1.5253384113311768,50000.0,0.501300036907196,2.2092301845550537,10000.0,26511.82786488533,28620.62424898148,26511.82786488533,2103.1707775592804,2.470738172531128,0.0 -59800,1.2161132,2.2513237,,,,,,,,,,,,,, -59900,1.2431778,4.7212744,,,,,,,,,,,,,, -60000,1.1737669,2.9020345,,,,,,,,,,,,,, -60100,1.350238,2.2971334,,,,,,,,,,,,,, -60200,1.222517,4.703114,,,,,,,,,,,,,, -60300,1.2748181,2.4412003,,,,,,,,,,,,,, -60400,1.245625,2.3157935,,,,,,,,,,,,,, -60500,1.1539084,3.9858885,,,,,,,,,,,,,, -60600,1.1591467,4.893887,,,,,,,,,,,,,, -60661,,,0.6733202934265137,1.314584732055664,0.6304999589920044,1.5325448513031006,50000.0,0.5042000412940979,2.225181818008423,10000.0,26931.74744296074,29076.107455968857,26931.74744296074,2138.6504290103912,2.5045390129089355,0.0 -60700,1.1933299,2.235868,,,,,,,,,,,,,, -60800,1.2244864,3.3263116,,,,,,,,,,,,,, -60900,1.2511735,4.6590905,,,,,,,,,,,,,, -61000,1.441431,2.3186846,,,,,,,,,,,,,, -61100,1.2113897,3.8354206,,,,,,,,,,,,,, -61200,1.3394649,2.2577894,,,,,,,,,,,,,, -61300,1.3853465,2.4746485,,,,,,,,,,,,,, -61400,1.3867168,2.243767,,,,,,,,,,,,,, -61500,1.3136343,2.4121134,,,,,,,,,,,,,, -61600,1.0809553,4.905401,,,,,,,,,,,,,, -61605,,,0.6815820336341858,1.2969701290130615,0.6304599642753601,1.544524908065796,50000.0,0.5072000026702881,2.201844930648804,10000.0,27352.115279197693,29532.281440734863,27352.115279197693,2174.3722081184387,2.5386993885040283,0.0 -61700,1.2167054,2.4538531,,,,,,,,,,,,,, -61800,1.420204,2.4459276,,,,,,,,,,,,,, -61900,1.3458655,4.813558,,,,,,,,,,,,,, -62000,1.2868156,2.4285655,,,,,,,,,,,,,, -62100,1.2490717,3.393619,,,,,,,,,,,,,, -62200,1.2631886,2.2550864,,,,,,,,,,,,,, -62300,1.3270079,2.2368731,,,,,,,,,,,,,, -62400,1.3101344,2.2839234,,,,,,,,,,,,,, -62500,1.1566906,3.2080424,,,,,,,,,,,,,, -62550,,,0.694531261920929,1.2573353052139282,0.6335799694061279,1.5339800119400024,50000.0,0.5128999948501587,2.191729784011841,10000.0,27772.08389544487,29987.64568257332,27772.08389544487,2209.677117586136,2.578078031539917,0.0 -62600,1.1467814,2.7970402,,,,,,,,,,,,,, -62700,1.3508344,2.3563433,,,,,,,,,,,,,, -62800,1.3172591,2.2950835,,,,,,,,,,,,,, -62900,1.3284595,2.2818143,,,,,,,,,,,,,, -63000,1.3404685,2.1881251,,,,,,,,,,,,,, -63100,1.3191417,4.4694834,,,,,,,,,,,,,, -63200,1.3137655,2.528334,,,,,,,,,,,,,, -63300,1.2739912,2.154939,,,,,,,,,,,,,, -63400,1.3540733,2.377632,,,,,,,,,,,,,, -63495,,,0.719531238079071,1.1414695978164673,0.6376999616622925,1.5069578886032104,50000.0,0.515500009059906,2.1697821617126465,10000.0,28192.20817756653,30443.441133499146,28192.20817756653,2245.259070396424,2.6166160106658936,0.0 -63500,0.9803232,4.532763,,,,,,,,,,,,,, -63600,1.3841363,2.2862086,,,,,,,,,,,,,, -63700,1.2099315,2.3808596,,,,,,,,,,,,,, -63800,1.211951,3.9288025,,,,,,,,,,,,,, -63900,1.3160568,2.395529,,,,,,,,,,,,,, -64000,1.2321666,2.8450074,,,,,,,,,,,,,, -64100,1.2225958,3.06132,,,,,,,,,,,,,, -64200,1.016952,3.6245248,,,,,,,,,,,,,, -64300,1.1713951,3.9821193,,,,,,,,,,,,,, -64400,1.296382,2.2918577,,,,,,,,,,,,,, -64440,,,0.6832422018051147,1.289946436882019,0.6335200071334839,1.5233887434005735,50000.0,0.5054000020027161,2.1913552284240723,10000.0,28612.125376462936,30898.59340786934,28612.125376462936,2280.39914727211,2.660373449325561,0.0 -64500,1.1984292,3.1144733,,,,,,,,,,,,,, -64600,1.1485237,4.39045,,,,,,,,,,,,,, -64700,1.1312387,4.2997484,,,,,,,,,,,,,, -64800,1.3122715,2.1937268,,,,,,,,,,,,,, -64900,1.2868041,2.2620876,,,,,,,,,,,,,, -65000,1.2020475,3.9918118,,,,,,,,,,,,,, -65100,1.3364685,2.3334932,,,,,,,,,,,,,, -65200,1.1905528,3.6135592,,,,,,,,,,,,,, -65300,1.1215545,4.5082765,,,,,,,,,,,,,, -65386,,,0.694042980670929,1.2231028079986572,0.6383799910545349,1.484677791595459,50000.0,0.5078999996185303,2.162203073501587,10000.0,29032.208420038223,31354.42751383781,29032.208420038223,2316.05020904541,2.7098233699798584,0.0 -65400,1.3685045,3.6500592,,,,,,,,,,,,,, -65500,1.3615432,2.2471116,,,,,,,,,,,,,, -65600,1.1063262,4.755645,,,,,,,,,,,,,, -65700,1.3184894,2.243375,,,,,,,,,,,,,, -65800,1.2738465,2.2438362,,,,,,,,,,,,,, -65900,1.2207383,4.7565217,,,,,,,,,,,,,, -66000,1.2341542,2.23405,,,,,,,,,,,,,, -66100,1.1678938,3.2999377,,,,,,,,,,,,,, -66200,1.4187552,2.2747073,,,,,,,,,,,,,, -66300,1.3120039,4.9165754,,,,,,,,,,,,,, -66332,,,0.707226574420929,1.169092893600464,0.6391199827194214,1.4746456146240234,50000.0,0.5143000483512878,2.157905578613281,10000.0,29452.30614376068,31810.7338078022,29452.30614376068,2352.170639514923,2.747178792953491,0.0 -66400,1.3697618,2.1582115,,,,,,,,,,,,,, -66500,1.3578341,4.540818,,,,,,,,,,,,,, -66600,1.3393345,2.1887124,,,,,,,,,,,,,, -66700,1.3864164,2.2599082,,,,,,,,,,,,,, -66800,1.2499442,2.251029,,,,,,,,,,,,,, -66900,1.0973552,4.6849484,,,,,,,,,,,,,, -67000,1.2110738,2.7995129,,,,,,,,,,,,,, -67100,1.4284557,2.0371947,,,,,,,,,,,,,, -67200,1.4836972,2.3936944,,,,,,,,,,,,,, -67275,,,0.6863867044448853,1.2691597938537598,0.6340999603271484,1.5110490322113037,50000.0,0.5159000158309937,2.1513500213623047,10000.0,29872.2384326458,32266.239609479904,29872.2384326458,2387.653300523758,2.787407875061035,0.0 -67300,1.1805828,3.2199736,,,,,,,,,,,,,, -67400,1.1774364,3.2508621,,,,,,,,,,,,,, -67500,1.2778652,2.6146114,,,,,,,,,,,,,, -67600,1.3638561,2.2684584,,,,,,,,,,,,,, -67700,1.6812394,2.7048235,,,,,,,,,,,,,, -67800,1.1758149,3.4891474,,,,,,,,,,,,,, -67900,1.1692743,4.688966,,,,,,,,,,,,,, -68000,1.1777848,3.348898,,,,,,,,,,,,,, -68100,1.4017214,2.4501457,,,,,,,,,,,,,, -68200,1.2221726,2.3272946,,,,,,,,,,,,,, -68217,,,0.7003710865974426,1.183992624282837,0.6473000049591064,1.441631317138672,50000.0,0.5285000205039978,2.097337245941162,10000.0,30292.351333141327,32720.415112257004,30292.351333141327,2421.625636100769,2.82780122756958,0.0 -68300,1.3399388,2.1003299,,,,,,,,,,,,,, -68400,1.2839032,2.2258878,,,,,,,,,,,,,, -68500,1.3947445,2.412805,,,,,,,,,,,,,, -68600,1.1687654,2.6409588,,,,,,,,,,,,,, -68700,1.1087016,2.967247,,,,,,,,,,,,,, -68800,1.331187,2.1905785,,,,,,,,,,,,,, -68900,1.425862,2.265018,,,,,,,,,,,,,, -69000,1.3131473,2.43494,,,,,,,,,,,,,, -69100,1.3529093,2.1946032,,,,,,,,,,,,,, -69160,,,0.70570307970047,1.190212607383728,0.644819974899292,1.461923122406006,50000.0,0.517300009727478,2.136876106262207,10000.0,30712.3498916626,33174.16375398636,30712.3498916626,2455.2891149520874,2.863965511322021,0.0 -69200,1.2903434,2.3037946,,,,,,,,,,,,,, -69300,1.2098064,2.2673838,,,,,,,,,,,,,, -69400,1.3759888,2.3117604,,,,,,,,,,,,,, -69500,1.366628,2.2735934,,,,,,,,,,,,,, -69600,1.351006,4.7812414,,,,,,,,,,,,,, -69700,1.5095978,2.236403,,,,,,,,,,,,,, -69800,1.4992324,2.1916673,,,,,,,,,,,,,, -69900,1.1577312,3.3957357,,,,,,,,,,,,,, -70000,1.2149631,3.237231,,,,,,,,,,,,,, -70099,,,0.7290429472923279,1.0740240812301636,0.6439999938011169,1.463181734085083,50000.0,0.5184000134468079,2.132868766784668,10000.0,31131.891610860825,33628.90654087067,31131.891610860825,2489.757776260376,3.5462582111358643,0.0 -70100,1.3002081,2.2918198,,,,,,,,,,,,,, -70200,1.1555293,4.1418447,,,,,,,,,,,,,, -70300,1.3002505,2.192843,,,,,,,,,,,,,, -70400,1.3243111,2.2231636,,,,,,,,,,,,,, -70500,1.3061944,2.293176,,,,,,,,,,,,,, -70600,1.363395,2.395399,,,,,,,,,,,,,, -70700,1.1671867,3.9291363,,,,,,,,,,,,,, -70800,1.139416,4.1929526,,,,,,,,,,,,,, -70900,1.3007087,2.2500815,,,,,,,,,,,,,, -71000,1.3208604,2.3006892,,,,,,,,,,,,,, -71045,,,0.6943163871765137,1.2107747793197632,0.6439999938011169,1.4579524993896484,50000.0,0.5236000418663025,2.1252052783966064,10000.0,31552.1533575058,34084.3028922081,31552.1533575058,2524.795145511627,3.592628002166748,0.0 -71100,1.2558928,4.643901,,,,,,,,,,,,,, -71200,1.1007733,3.0202665,,,,,,,,,,,,,, -71300,1.284798,4.6936617,,,,,,,,,,,,,, -71400,1.3357363,2.2302327,,,,,,,,,,,,,, -71500,1.2598302,2.6488597,,,,,,,,,,,,,, -71600,1.115102,4.0312333,,,,,,,,,,,,,, -71700,1.1256866,3.2806168,,,,,,,,,,,,,, -71800,1.2265866,2.034457,,,,,,,,,,,,,, -71900,1.291183,2.3222175,,,,,,,,,,,,,, -71989,,,0.7100585699081421,1.1660125255584717,0.6522799730300903,1.4295166730880735,50000.0,0.5245000123977661,2.0866498947143555,10000.0,31972.31904554367,34538.16773843765,31972.31904554367,2558.4052596092224,3.631521224975586,0.0 -72000,1.4052569,2.3712327,,,,,,,,,,,,,, -72100,1.310664,2.1427293,,,,,,,,,,,,,, -72200,1.2932469,2.6695042,,,,,,,,,,,,,, -72300,1.342757,4.5042562,,,,,,,,,,,,,, -72400,1.4212332,2.3079863,,,,,,,,,,,,,, -72500,1.2845435,3.6086812,,,,,,,,,,,,,, -72600,1.1695148,3.2147706,,,,,,,,,,,,,, -72700,1.4321111,2.514233,,,,,,,,,,,,,, -72800,1.182979,2.7693317,,,,,,,,,,,,,, -72900,1.1373807,3.7073088,,,,,,,,,,,,,, -72930,,,0.7144140601158142,1.1511657238006592,0.6480000019073486,1.455902934074402,50000.0,0.5224000215530396,2.14300537109375,10000.0,32392.24674224853,34993.05684757233,32392.24674224853,2593.2753612995148,3.672046422958374,0.0 -73000,1.6301697,2.6022024,,,,,,,,,,,,,, -73100,1.4367408,2.076101,,,,,,,,,,,,,, -73200,1.2918211,3.9352107,,,,,,,,,,,,,, -73300,1.3433137,2.227786,,,,,,,,,,,,,, -73400,1.4064105,2.1675627,,,,,,,,,,,,,, -73500,1.1650598,4.508522,,,,,,,,,,,,,, -73600,1.3654445,2.1919277,,,,,,,,,,,,,, -73700,1.3565035,2.2975912,,,,,,,,,,,,,, -73800,1.4525515,2.2008219,,,,,,,,,,,,,, -73876,,,0.6960546970367432,1.236786723136902,0.64656001329422,1.4624232053756714,50000.0,0.5225000381469727,2.137120246887207,10000.0,32812.47220945358,35448.46819233894,32812.47220945358,2628.373400449753,3.709576606750488,0.0 -73900,1.222573,3.7827005,,,,,,,,,,,,,, -74000,1.3889471,2.0456123,,,,,,,,,,,,,, -74100,1.3873682,2.327807,,,,,,,,,,,,,, -74200,1.1997417,3.7194948,,,,,,,,,,,,,, -74300,1.2447021,3.7045918,,,,,,,,,,,,,, -74400,1.2793586,3.9641616,,,,,,,,,,,,,, -74500,1.3993738,2.2127936,,,,,,,,,,,,,, -74600,1.1680332,3.3960443,,,,,,,,,,,,,, -74700,1.1917135,2.9589808,,,,,,,,,,,,,, -74800,1.2169136,2.7890563,,,,,,,,,,,,,, -74819,,,0.7019921541213989,1.2093700170516968,0.65447998046875,1.444311022758484,50000.0,0.5272000432014465,2.114329099655152,10000.0,33232.751855134964,35904.79923796654,33232.751855134964,2664.336992740631,3.747407913208008,0.0 -74900,1.3086662,2.4902337,,,,,,,,,,,,,, -75000,1.2891397,2.1924458,,,,,,,,,,,,,, -75100,1.4289932,2.1275291,,,,,,,,,,,,,, -75200,1.491817,2.2660623,,,,,,,,,,,,,, -75300,1.1673026,4.4377074,,,,,,,,,,,,,, -75400,1.4381366,2.259884,,,,,,,,,,,,,, -75500,1.242074,3.0023708,,,,,,,,,,,,,, -75600,1.2785398,2.5359848,,,,,,,,,,,,,, -75700,1.3176181,2.405903,,,,,,,,,,,,,, -75760,,,0.710156261920929,1.1536434888839722,0.6520400047302246,1.417355179786682,50000.0,0.5243000388145447,2.0840115547180176,10000.0,33652.85684943199,36358.81673932076,33652.85684943199,2698.1623861789703,3.784353256225586,0.0 -75800,1.3348994,2.2319942,,,,,,,,,,,,,, -75900,1.3936882,2.0751863,,,,,,,,,,,,,, -76000,1.3532573,3.1542065,,,,,,,,,,,,,, -76100,1.370379,2.1820273,,,,,,,,,,,,,, -76200,1.3645779,2.4252732,,,,,,,,,,,,,, -76300,1.2475652,4.5076957,,,,,,,,,,,,,, -76400,1.1738005,4.088547,,,,,,,,,,,,,, -76500,1.3110285,2.146877,,,,,,,,,,,,,, -76600,1.4379679,2.1230555,,,,,,,,,,,,,, -76700,,,0.7394140362739563,1.0564640760421753,0.6574400067329407,1.422579288482666,50000.0,0.5314000248908997,2.086050271987915,10000.0,34073.014280080795,36812.09155678749,34073.014280080795,2731.18881893158,3.8260109424591056,0.0 -76700,1.2250727,2.5123055,,,,,,,,,,,,,, -76800,1.5595751,4.6553507,,,,,,,,,,,,,, -76900,1.1933813,3.5628896,,,,,,,,,,,,,, -77000,1.446982,2.1114843,,,,,,,,,,,,,, -77100,1.2848351,2.1719906,,,,,,,,,,,,,, -77200,1.214862,4.678985,,,,,,,,,,,,,, -77300,1.392488,2.3241463,,,,,,,,,,,,,, -77400,1.4607296,2.2827299,,,,,,,,,,,,,, -77500,1.3619771,2.2084734,,,,,,,,,,,,,, -77600,1.4859512,2.3131588,,,,,,,,,,,,,, -77644,,,0.7078710794448853,1.1789730787277222,0.6548399925231934,1.4233994483947754,50000.0,0.5337000489234924,2.083815813064575,10000.0,34493.03785729408,37268.30460429192,34493.03785729408,2767.2862479686737,3.867687463760376,0.0 -77700,1.3025068,2.251677,,,,,,,,,,,,,, -77800,1.536701,2.1555262,,,,,,,,,,,,,, -77900,1.3851521,4.671851,,,,,,,,,,,,,, -78000,1.4592186,2.2292156,,,,,,,,,,,,,, -78100,1.3132031,2.1368043,,,,,,,,,,,,,, -78200,1.3953578,2.2031589,,,,,,,,,,,,,, -78300,1.1300889,3.7306535,,,,,,,,,,,,,, -78400,1.496352,2.5013905,,,,,,,,,,,,,, -78500,1.3442137,2.6426134,,,,,,,,,,,,,, -78587,,,0.7059569954872131,1.1735212802886963,0.6548999547958374,1.41805100440979,50000.0,0.5367000102996826,2.07688570022583,10000.0,34913.28986310959,37725.69277334213,34913.28986310959,2804.3315374851227,3.908069133758545,0.0 -78600,1.3366169,2.311198,,,,,,,,,,,,,, -78700,1.3396332,2.1062021,,,,,,,,,,,,,, -78800,1.4365346,2.321193,,,,,,,,,,,,,, -78900,1.2333261,2.5798461,,,,,,,,,,,,,, -79000,1.345166,2.2598126,,,,,,,,,,,,,, -79100,1.3016587,3.1877756,,,,,,,,,,,,,, -79200,1.2948346,2.5988905,,,,,,,,,,,,,, -79300,1.610979,2.1867447,,,,,,,,,,,,,, -79400,1.2040689,3.5968814,,,,,,,,,,,,,, -79500,1.3724098,2.162994,,,,,,,,,,,,,, -79532,,,0.72802734375,1.105783462524414,0.6602999567985535,1.4155524969100952,50000.0,0.5297000408172607,2.0830442905426025,10000.0,35333.480610609055,38181.32603669167,35333.480610609055,2839.677448511124,3.9546329975128174,0.0 -79600,1.4924011,2.1935534,,,,,,,,,,,,,, -79700,1.3632729,2.160912,,,,,,,,,,,,,, -79800,1.2243325,2.9617693,,,,,,,,,,,,,, -79900,1.3577235,2.0936418,,,,,,,,,,,,,, -80000,1.3080577,2.1800303,,,,,,,,,,,,,, -80100,1.3544921,2.1147776,,,,,,,,,,,,,, -80200,1.3525949,2.5528505,,,,,,,,,,,,,, -80300,1.2332709,2.0957987,,,,,,,,,,,,,, -80400,1.4617045,2.302427,,,,,,,,,,,,,, -80477,,,0.71240234375,1.1545332670211792,0.6611999869346619,1.3902668952941897,50000.0,0.5308000445365906,2.055467367172241,10000.0,35753.704422950745,38637.558312892914,35753.704422950745,2875.596264362335,3.994290590286255,0.0 -80500,1.3659875,2.3360162,,,,,,,,,,,,,, -80600,1.4485886,2.2119951,,,,,,,,,,,,,, -80700,1.3111405,4.5735693,,,,,,,,,,,,,, -80800,1.3372192,2.6025941,,,,,,,,,,,,,, -80900,1.3177501,2.4319787,,,,,,,,,,,,,, -81000,1.2287903,2.9764924,,,,,,,,,,,,,, -81100,1.3763413,4.0270267,,,,,,,,,,,,,, -81200,1.3394603,2.1586907,,,,,,,,,,,,,, -81300,1.2204031,3.1058776,,,,,,,,,,,,,, -81400,1.5640152,2.2056985,,,,,,,,,,,,,, -81421,,,0.7125195264816284,1.1445448398590088,0.6576600074768066,1.3969004154205322,50000.0,0.5351000428199768,2.045471668243408,10000.0,36174.04013347626,39092.08462238312,36174.04013347626,2909.6886727809906,4.0421226024627686,0.0 -81500,1.2755387,3.2083292,,,,,,,,,,,,,, -81600,1.2460929,3.3234367,,,,,,,,,,,,,, -81700,1.293569,4.1800714,,,,,,,,,,,,,, -81800,1.1644739,2.9297547,,,,,,,,,,,,,, -81900,1.333271,2.0899742,,,,,,,,,,,,,, -82000,1.3102438,2.4600368,,,,,,,,,,,,,, -82100,1.4312506,2.0407305,,,,,,,,,,,,,, -82200,1.3688027,2.0918012,,,,,,,,,,,,,, -82300,1.388427,2.1032882,,,,,,,,,,,,,, -82367,,,0.7277148365974426,1.075430154800415,0.6658799648284912,1.359333872795105,50000.0,0.5442000031471252,2.011723279953003,10000.0,36594.36502742767,39548.16616392136,36594.36502742767,2945.35439991951,4.082128286361694,0.0 -82400,1.175987,2.8110743,,,,,,,,,,,,,, -82500,1.3795378,2.037957,,,,,,,,,,,,,, -82600,1.3049189,2.1912339,,,,,,,,,,,,,, -82700,1.4212209,2.3051198,,,,,,,,,,,,,, -82800,1.4390172,2.0734594,,,,,,,,,,,,,, -82900,1.2052926,4.0582776,,,,,,,,,,,,,, -83000,1.4674034,2.0158002,,,,,,,,,,,,,, -83100,1.2670062,4.5900407,,,,,,,,,,,,,, -83200,1.2568907,3.6124775,,,,,,,,,,,,,, -83300,1.5229743,2.187728,,,,,,,,,,,,,, -83313,,,0.7518945336341858,0.9783536195755004,0.6640399694442749,1.3719205856323242,50000.0,0.5373000502586365,2.03928279876709,10000.0,37014.328892469406,40003.478034973145,37014.328892469406,2980.6113533973694,4.123176574707031,0.0 -83400,1.4508029,2.2038958,,,,,,,,,,,,,, -83500,1.2398489,3.859984,,,,,,,,,,,,,, -83600,1.397077,1.9466245,,,,,,,,,,,,,, -83700,1.3081754,4.3048096,,,,,,,,,,,,,, -83800,1.349661,2.06146,,,,,,,,,,,,,, -83900,1.432986,2.1226294,,,,,,,,,,,,,, -84000,1.3050469,2.728543,,,,,,,,,,,,,, -84100,1.3462523,2.0049615,,,,,,,,,,,,,, -84200,1.4892702,2.1221294,,,,,,,,,,,,,, -84258,,,0.7216015458106995,1.1113862991333008,0.6678599715232849,1.3596880435943604,50000.0,0.5437999963760376,2.0086333751678467,10000.0,37434.48281121254,40459.272706747055,37434.48281121254,3016.158703804016,4.166561603546143,0.0 -84300,1.188499,3.4071934,,,,,,,,,,,,,, -84400,1.3847636,3.0111024,,,,,,,,,,,,,, -84500,1.5167454,2.1580536,,,,,,,,,,,,,, -84600,1.410181,4.476997,,,,,,,,,,,,,, -84700,1.4678414,2.6104317,,,,,,,,,,,,,, -84800,1.3963405,1.9880939,,,,,,,,,,,,,, -84900,1.4289681,2.1074991,,,,,,,,,,,,,, -85000,1.2565607,4.19913,,,,,,,,,,,,,, -85100,1.2790031,2.2555017,,,,,,,,,,,,,, -85200,1.3537035,2.259621,,,,,,,,,,,,,, -85201,,,0.7228124737739563,1.1117151975631714,0.6657599806785583,1.3800417184829712,50000.0,0.5369000434875488,2.038984775543213,10000.0,37854.96192789078,40915.03071784973,37854.96192789078,3051.331242084503,4.223360538482666,0.0 -85300,1.260049,4.404504,,,,,,,,,,,,,, -85400,1.5046643,1.9658488,,,,,,,,,,,,,, -85500,1.3848392,4.3764844,,,,,,,,,,,,,, -85600,1.4840604,2.1014931,,,,,,,,,,,,,, -85700,1.4509251,2.0943887,,,,,,,,,,,,,, -85800,1.2884523,2.1897151,,,,,,,,,,,,,, -85900,1.2860382,4.4783945,,,,,,,,,,,,,, -86000,1.3560203,2.7465935,,,,,,,,,,,,,, -86100,1.3363059,2.3857372,,,,,,,,,,,,,, -86145,,,0.7410351634025574,1.0317023992538452,0.6699399948120117,1.3505462408065796,50000.0,0.5494000315666199,1.995234131813049,10000.0,38275.27783441544,41370.34265089035,38275.27783441544,3086.22958111763,4.271142959594727,0.0 -86200,1.3367124,1.937652,,,,,,,,,,,,,, -86300,1.54832,2.158775,,,,,,,,,,,,,, -86400,1.4529619,2.0471027,,,,,,,,,,,,,, -86500,1.2238585,3.0135105,,,,,,,,,,,,,, -86600,1.2555948,3.2108305,,,,,,,,,,,,,, -86700,1.3665729,4.1055107,,,,,,,,,,,,,, -86800,1.2863079,3.0488498,,,,,,,,,,,,,, -86900,1.3614504,2.5760605,,,,,,,,,,,,,, -87000,1.2791374,2.5488362,,,,,,,,,,,,,, -87089,,,0.722949206829071,1.0942314863204956,0.6686999797821045,1.3394672870635986,50000.0,0.5448000431060791,1.9972219467163088,10000.0,38695.271587610245,41826.296382427216,38695.271587610245,3122.0959231853485,4.3140411376953125,0.0 -87100,1.457033,2.1704824,,,,,,,,,,,,,, -87200,1.3785975,4.628935,,,,,,,,,,,,,, -87300,1.4103024,2.178977,,,,,,,,,,,,,, -87400,1.3580788,2.1277325,,,,,,,,,,,,,, -87500,1.2664329,3.9316604,,,,,,,,,,,,,, -87600,1.1937197,3.696999,,,,,,,,,,,,,, -87700,1.2906321,2.0897586,,,,,,,,,,,,,, -87800,1.3451934,3.1459389,,,,,,,,,,,,,, -87900,1.408839,2.0329301,,,,,,,,,,,,,, -88000,1.2378869,3.3488357,,,,,,,,,,,,,, -88032,,,0.7299609184265137,1.0719990730285645,0.6747599840164185,1.3282451629638672,50000.0,0.5497000217437744,1.986189603805542,10000.0,39115.2553396225,42281.354194402695,39115.2553396225,3157.0737657547,4.3602800369262695,0.0 -88100,1.2860867,3.2610116,,,,,,,,,,,,,, -88200,1.4215633,1.9845335,,,,,,,,,,,,,, -88300,1.5338054,2.0416865,,,,,,,,,,,,,, -88400,1.6259651,1.8927888,,,,,,,,,,,,,, -88500,1.4932175,2.0217361,,,,,,,,,,,,,, -88600,1.4281405,2.0291023,,,,,,,,,,,,,, -88700,1.3409173,2.4051056,,,,,,,,,,,,,, -88800,1.341768,1.988639,,,,,,,,,,,,,, -88900,1.3023938,2.6200476,,,,,,,,,,,,,, -88977,,,0.7388085722923279,1.0372790098190308,0.6735000014305115,1.3296064138412476,50000.0,0.5471000075340271,2.0001771450042725,10000.0,39535.3887925148,42737.726367235184,39535.3887925148,3193.2224068641663,4.400323152542114,0.0 -89000,1.3803726,2.3859296,,,,,,,,,,,,,, -89100,1.3512784,4.268475,,,,,,,,,,,,,, -89200,1.3464001,3.919764,,,,,,,,,,,,,, -89300,1.3472549,2.847239,,,,,,,,,,,,,, -89400,1.173117,3.636723,,,,,,,,,,,,,, -89500,1.5204254,2.0468554,,,,,,,,,,,,,, -89600,1.3377067,4.1507063,,,,,,,,,,,,,, -89700,1.4793392,2.029357,,,,,,,,,,,,,, -89800,1.3736632,2.2625494,,,,,,,,,,,,,, -89900,1.4143791,2.056652,,,,,,,,,,,,,, -89923,,,0.7599999904632568,0.9553410410881042,0.66975998878479,1.3442224264144895,50000.0,0.5466000437736511,1.9942545890808103,10000.0,39955.61276316643,43193.67957997322,39955.61276316643,3228.858241558075,4.443213939666748,0.0 -90000,1.3274417,2.0419896,,,,,,,,,,,,,, -90100,1.4030231,2.053009,,,,,,,,,,,,,, -90200,1.2418042,2.985647,,,,,,,,,,,,,, -90300,1.4669566,2.2819276,,,,,,,,,,,,,, -90400,1.5357895,2.129461,,,,,,,,,,,,,, -90500,1.2936428,4.610142,,,,,,,,,,,,,, -90600,1.3758026,2.0369506,,,,,,,,,,,,,, -90700,1.383461,1.9362884,,,,,,,,,,,,,, -90800,1.3887129,1.9847499,,,,,,,,,,,,,, -90868,,,0.7295898199081421,1.097119688987732,0.6711399555206299,1.3498778343200684,50000.0,0.5463000535964966,2.0084035396575928,10000.0,40375.80407691002,43648.7395863533,40375.80407691002,3263.637674331665,4.48250937461853,0.0 -90900,1.3876648,2.8157232,,,,,,,,,,,,,, -91000,1.569596,2.1429365,,,,,,,,,,,,,, -91100,1.5637357,1.9806378,,,,,,,,,,,,,, -91200,1.325281,4.501181,,,,,,,,,,,,,, -91300,1.383407,3.7152867,,,,,,,,,,,,,, -91400,1.5486803,1.9926503,,,,,,,,,,,,,, -91500,1.4025269,4.381402,,,,,,,,,,,,,, -91600,1.4590837,2.0901787,,,,,,,,,,,,,, -91700,1.4275216,4.5430727,,,,,,,,,,,,,, -91800,1.3743896,4.5460896,,,,,,,,,,,,,, -91812,,,0.7383593320846558,1.04035747051239,0.6772399544715881,1.3202310800552368,50000.0,0.5585000514984131,1.9671977758407595,10000.0,40795.88285851479,44104.20302581787,40795.88285851479,3298.9227085113525,4.531588315963745,0.0 -91900,1.3278981,2.5667286,,,,,,,,,,,,,, -92000,1.4561726,3.9226403,,,,,,,,,,,,,, -92100,1.5380076,2.1172185,,,,,,,,,,,,,, -92200,1.362506,4.219236,,,,,,,,,,,,,, -92300,1.4859146,2.0166588,,,,,,,,,,,,,, -92400,1.3629556,2.2479632,,,,,,,,,,,,,, -92500,1.2913669,2.8616526,,,,,,,,,,,,,, -92600,1.4619046,2.0770953,,,,,,,,,,,,,, -92700,1.562412,2.188215,,,,,,,,,,,,,, -92755,,,0.7471874952316284,1.026834487915039,0.6764000058174133,1.334289312362671,50000.0,0.5506000518798828,2.000419855117798,10000.0,41215.80678701401,44557.83666324616,41215.80678701401,3332.5399737358093,4.574221849441528,0.0 -92800,1.3941616,2.8981502,,,,,,,,,,,,,, -92900,1.5408193,2.101827,,,,,,,,,,,,,, -93000,1.4890952,4.324504,,,,,,,,,,,,,, -93100,1.5843654,2.0488474,,,,,,,,,,,,,, -93200,1.4131662,2.0639167,,,,,,,,,,,,,, -93300,1.4168197,2.1530805,,,,,,,,,,,,,, -93400,1.539589,2.123906,,,,,,,,,,,,,, -93500,1.3853656,2.484423,,,,,,,,,,,,,, -93600,1.2443693,3.1207023,,,,,,,,,,,,,, -93698,,,0.7383398413658142,1.0327916145324707,0.678119957447052,1.306376576423645,50000.0,0.5543000102043152,1.944562554359436,10000.0,41635.83515667915,45013.53951334953,41635.83515667915,3368.114696741104,4.624174118041992,0.0 -93700,1.3450211,3.5445185,,,,,,,,,,,,,, -93800,1.5655664,2.021553,,,,,,,,,,,,,, -93900,1.3471435,2.1224122,,,,,,,,,,,,,, -94000,1.3627933,2.6650898,,,,,,,,,,,,,, -94100,1.3900957,4.025813,,,,,,,,,,,,,, -94200,1.4816501,2.4057057,,,,,,,,,,,,,, -94300,1.3934371,1.975718,,,,,,,,,,,,,, -94400,1.4373418,2.0306036,,,,,,,,,,,,,, -94500,1.6213431,1.9845905,,,,,,,,,,,,,, -94600,1.4663434,3.9414916,,,,,,,,,,,,,, -94642,,,0.7393554449081421,1.033813714981079,0.6789999604225159,1.2958242893218994,50000.0,0.5541000366210938,1.9380820989608765,10000.0,42055.82461929321,45470.20898675919,42055.82461929321,3404.7003223896027,4.667530536651611,0.0 -94700,1.5249349,4.5182734,,,,,,,,,,,,,, -94800,1.4767011,2.0268426,,,,,,,,,,,,,, -94900,1.3939532,3.5857387,,,,,,,,,,,,,, -95000,1.3811716,2.9804935,,,,,,,,,,,,,, -95100,1.4596745,3.8483794,,,,,,,,,,,,,, -95200,1.3621194,1.9696372,,,,,,,,,,,,,, -95300,1.4325889,2.0656216,,,,,,,,,,,,,, -95400,1.5142497,2.0400693,,,,,,,,,,,,,, -95500,1.6642023,1.9317628,,,,,,,,,,,,,, -95589,,,0.749804675579071,0.988872528076172,0.6826800107955933,1.2907994985580444,50000.0,0.5523000359535217,1.9421786069869995,10000.0,42475.95677399635,45926.06120347977,42475.95677399635,3440.3258113861084,4.711685180664063,0.0 -95600,1.4957308,2.495409,,,,,,,,,,,,,, -95700,1.394734,2.6010814,,,,,,,,,,,,,, -95800,1.4575169,2.1914566,,,,,,,,,,,,,, -95900,1.5550435,2.0111012,,,,,,,,,,,,,, -96000,1.3168375,2.2880113,,,,,,,,,,,,,, -96100,1.4055309,2.0355387,,,,,,,,,,,,,, -96200,1.3889794,3.1214192,,,,,,,,,,,,,, -96300,1.4217614,2.3429523,,,,,,,,,,,,,, -96400,1.450126,4.4144073,,,,,,,,,,,,,, -96500,1.6135088,4.4781747,,,,,,,,,,,,,, -96532,,,0.768261730670929,0.9298219680786132,0.6815999746322632,1.3072997331619265,50000.0,0.5608000159263611,1.949744820594788,10000.0,42895.94270777702,46382.64171719551,42895.94270777702,3476.817836523056,4.763655424118042,0.0 -96600,1.4056095,3.7581873,,,,,,,,,,,,,, -96700,1.3442122,2.8510423,,,,,,,,,,,,,, -96800,1.4825114,2.0021071,,,,,,,,,,,,,, -96900,1.3392028,4.064243,,,,,,,,,,,,,, -97000,1.4256892,1.9698765,,,,,,,,,,,,,, -97100,1.4006517,2.1085594,,,,,,,,,,,,,, -97200,1.5306484,2.218359,,,,,,,,,,,,,, -97300,1.4854394,2.6428711,,,,,,,,,,,,,, -97400,1.3699919,3.092705,,,,,,,,,,,,,, -97475,,,0.7375390529632568,1.017408013343811,0.6825000047683716,1.2762969732284546,50000.0,0.5592000484466553,1.9167892932891848,10000.0,43316.22964668274,46838.23907756805,43316.22964668274,3512.03467297554,4.808404445648193,0.0 -97500,1.5331959,4.4359226,,,,,,,,,,,,,, -97600,1.5121981,1.9501292,,,,,,,,,,,,,, -97700,1.4221271,1.9837967,,,,,,,,,,,,,, -97800,1.3695761,4.133259,,,,,,,,,,,,,, -97900,1.4088768,2.0769644,,,,,,,,,,,,,, -98000,1.5186518,2.118608,,,,,,,,,,,,,, -98100,1.4826791,2.181351,,,,,,,,,,,,,, -98200,1.4125723,3.8723865,,,,,,,,,,,,,, -98300,1.4368931,2.7104092,,,,,,,,,,,,,, -98400,1.3256502,3.4154825,,,,,,,,,,,,,, -98416,,,0.7465234398841858,1.0131640434265137,0.6841599941253662,1.29089093208313,50000.0,0.5608000159263611,1.9235416650772093,10000.0,43736.36538696289,47294.940781116486,43736.36538696289,3548.5029113292694,4.856558799743652,0.0 -98500,1.4875505,2.0775144,,,,,,,,,,,,,, -98600,1.5553098,2.074495,,,,,,,,,,,,,, -98700,1.3996539,2.7555761,,,,,,,,,,,,,, -98800,1.4047508,2.2282271,,,,,,,,,,,,,, -98900,1.6434556,1.8510333,,,,,,,,,,,,,, -99000,1.4779423,1.9058803,,,,,,,,,,,,,, -99100,1.303081,3.6181664,,,,,,,,,,,,,, -99200,1.4200267,2.5278776,,,,,,,,,,,,,, -99300,1.6900213,3.4877744,,,,,,,,,,,,,, -99357,,,0.755664050579071,0.9562978148460388,0.6837199926376343,1.2733529806137085,50000.0,0.5591000318527222,1.913759469985962,10000.0,44156.54634976387,47751.04694914818,44156.54634976387,3584.3355853557587,4.8991899490356445,0.0 -99400,1.529045,2.0120885,,,,,,,,,,,,,, -99500,1.3073801,2.2210326,,,,,,,,,,,,,, -99600,1.4456322,2.3390648,,,,,,,,,,,,,, -99700,1.4844922,3.2416155,,,,,,,,,,,,,, -99800,1.4237394,1.8800555,,,,,,,,,,,,,, -99900,1.3352196,3.7902381,,,,,,,,,,,,,, -100000,1.435097,1.9202597,,,,,,,,,,,,,, -100100,1.4352064,2.3092759,,,,,,,,,,,,,, -100200,1.4103657,4.283719,,,,,,,,,,,,,, -100300,1.4419857,4.0428643,,,,,,,,,,,,,, -100302,,,0.7462109327316284,0.9947492480278016,0.6892799735069275,1.256700873374939,50000.0,0.566100001335144,1.9041144847869875,10000.0,44576.57260489464,48207.096264600754,44576.57260489464,3620.267826318741,4.93875527381897,0.0 -100400,1.4547473,2.7436652,,,,,,,,,,,,,, -100500,1.276648,3.3809662,,,,,,,,,,,,,, -100600,1.4891822,2.8108406,,,,,,,,,,,,,, -100700,1.5148782,2.0024807,,,,,,,,,,,,,, -100800,1.4812391,4.3302865,,,,,,,,,,,,,, -100900,1.3621676,2.9910147,,,,,,,,,,,,,, -101000,1.763959,2.0242453,,,,,,,,,,,,,, -101100,1.4189535,3.0984478,,,,,,,,,,,,,, -101200,1.4500074,2.2100909,,,,,,,,,,,,,, -101247,,,0.7510156035423279,0.9671342372894288,0.6897599697113037,1.2461247444152832,50000.0,0.5685000419616699,1.881093978881836,10000.0,44996.47789621353,48662.70986151695,44996.47789621353,3655.877145767212,4.987702131271362,0.0 -101300,1.5094067,2.2103097,,,,,,,,,,,,,, -101400,1.4506016,3.2964425,,,,,,,,,,,,,, -101500,1.7210952,1.8431981,,,,,,,,,,,,,, -101600,1.7122638,2.9570873,,,,,,,,,,,,,, -101700,1.5262799,1.9567662,,,,,,,,,,,,,, -101800,1.468572,2.495299,,,,,,,,,,,,,, -101900,1.5686985,1.9749784,,,,,,,,,,,,,, -102000,1.5237578,1.9974462,,,,,,,,,,,,,, -102100,1.6919539,1.848265,,,,,,,,,,,,,, -102191,,,0.7568163871765137,0.9741609692573548,0.6883999705314636,1.270899772644043,50000.0,0.5634000301361084,1.915746808052063,10000.0,45416.58115434647,49119.26220941544,45416.58115434647,3692.2329466342926,5.0316221714019775,0.0 -102200,1.4911466,4.3290567,,,,,,,,,,,,,, -102300,1.6325839,1.9868537,,,,,,,,,,,,,, -102400,1.3594431,3.6668754,,,,,,,,,,,,,, -102500,1.566538,1.8641715,,,,,,,,,,,,,, -102600,1.6191403,4.317015,,,,,,,,,,,,,, -102700,1.7884585,3.9777,,,,,,,,,,,,,, -102800,1.591838,2.0175462,,,,,,,,,,,,,, -102900,1.5440418,4.1862426,,,,,,,,,,,,,, -103000,1.5524727,2.4003434,,,,,,,,,,,,,, -103100,1.4954566,2.2037094,,,,,,,,,,,,,, -103135,,,0.77490234375,0.9026257395744324,0.6901199817657471,1.262291669845581,50000.0,0.5682000517845154,1.8965471982955933,10000.0,45836.55875372887,49575.37422776222,45836.55875372887,3728.261288166046,5.086937427520752,0.0 -103200,1.4860687,1.8073933,,,,,,,,,,,,,, -103300,1.600799,1.9030391,,,,,,,,,,,,,, -103400,1.5941019,1.8635725,,,,,,,,,,,,,, -103500,1.4667976,2.730161,,,,,,,,,,,,,, -103600,1.455423,1.8075261,,,,,,,,,,,,,, -103700,1.652496,2.0066242,,,,,,,,,,,,,, -103800,1.594468,3.8793695,,,,,,,,,,,,,, -103900,1.4092077,3.0517054,,,,,,,,,,,,,, -104000,1.6797177,2.0884113,,,,,,,,,,,,,, -104079,,,0.7558984160423279,0.9612823128700256,0.6957399845123291,1.2298921346664429,50000.0,0.5656000375747681,1.8726284503936768,10000.0,46256.78787302971,50032.344173669815,46256.78787302971,3764.905757427216,5.132697582244873,0.0 -104100,1.35157,3.6695762,,,,,,,,,,,,,, -104200,1.4234954,2.6367333,,,,,,,,,,,,,, -104300,1.4302762,2.9456139,,,,,,,,,,,,,, -104400,1.5844657,1.9516218,,,,,,,,,,,,,, -104500,1.6144862,1.9766076,,,,,,,,,,,,,, -104600,1.6850123,2.3153882,,,,,,,,,,,,,, -104700,1.5172253,2.8814018,,,,,,,,,,,,,, -104800,1.9541103,1.8218795,,,,,,,,,,,,,, -104900,1.51251,2.113003,,,,,,,,,,,,,, -105000,1.5529157,4.3820324,,,,,,,,,,,,,, -105023,,,0.760546863079071,0.9362866282463074,0.6957399845123291,1.2292063236236572,50000.0,0.5689000487327576,1.8730270862579343,10000.0,46677.106143713,50487.73134660721,46677.106143713,3799.874258995056,5.182331800460815,0.0 -105100,1.4472755,4.0883207,,,,,,,,,,,,,, -105200,1.5671626,1.8164752,,,,,,,,,,,,,, -105300,1.5880418,1.9915047,,,,,,,,,,,,,, -105400,1.4518968,3.8665962,,,,,,,,,,,,,, -105500,1.5894371,1.8723595,,,,,,,,,,,,,, -105600,1.433085,3.1667337,,,,,,,,,,,,,, -105700,1.5880758,2.4685748,,,,,,,,,,,,,, -105800,1.5499661,2.0657437,,,,,,,,,,,,,, -105900,1.6074519,4.350768,,,,,,,,,,,,,, -105968,,,0.7692187428474426,0.9136697053909302,0.6956999897956848,1.2356456518173218,50000.0,0.5648000240325928,1.8811651468276973,10000.0,47097.24747681618,50943.70225930214,47097.24747681618,3835.612533569336,5.22328519821167,0.0 -106000,1.8071518,1.9385047,,,,,,,,,,,,,, -106100,1.5588968,1.8548125,,,,,,,,,,,,,, -106200,1.452506,2.9189942,,,,,,,,,,,,,, -106300,1.6795384,1.8379385,,,,,,,,,,,,,, -106400,1.6036128,1.9614196,,,,,,,,,,,,,, -106500,1.6354433,2.686692,,,,,,,,,,,,,, -106600,1.5134244,3.870699,,,,,,,,,,,,,, -106700,1.3857781,3.3763647,,,,,,,,,,,,,, -106800,1.4718288,4.317409,,,,,,,,,,,,,, -106900,1.6878786,4.010612,,,,,,,,,,,,,, -106914,,,0.7616991996765137,0.9309342503547668,0.6985599994659424,1.2136071920394895,50000.0,0.5676000118255615,1.8738008737564087,10000.0,47517.2873609066,51399.94680428505,47517.2873609066,3871.7235753536224,5.266687631607056,0.0 -107000,1.4613286,3.094849,,,,,,,,,,,,,, -107100,1.4530512,2.890358,,,,,,,,,,,,,, -107200,1.6592454,2.2038603,,,,,,,,,,,,,, -107300,1.8146852,1.9048399,,,,,,,,,,,,,, -107400,1.3970013,3.046385,,,,,,,,,,,,,, -107500,1.5690598,1.8794987,,,,,,,,,,,,,, -107600,1.571855,1.9998977,,,,,,,,,,,,,, -107700,1.3860177,2.5818615,,,,,,,,,,,,,, -107800,1.6928416,1.9081811,,,,,,,,,,,,,, -107855,,,0.7624804377555847,0.9120814800262452,0.7012799978256226,1.1950334310531616,50000.0,0.5760000348091125,1.8427163362503047,10000.0,47937.33025097847,51855.41088676453,47937.33025097847,3907.050561189652,5.311384916305542,0.0 -107900,1.4184914,2.7965622,,,,,,,,,,,,,, -108000,1.439377,2.6756651,,,,,,,,,,,,,, -108100,1.5418957,1.9821657,,,,,,,,,,,,,, -108200,1.44836,3.0671554,,,,,,,,,,,,,, -108300,1.5840783,2.1917572,,,,,,,,,,,,,, -108400,1.6918834,4.2478642,,,,,,,,,,,,,, -108500,1.7155359,3.7818928,,,,,,,,,,,,,, -108600,1.5186983,2.4073424,,,,,,,,,,,,,, -108700,1.4139323,2.6007006,,,,,,,,,,,,,, -108799,,,0.7684375047683716,0.915904700756073,0.6987599730491638,1.2196636199951172,50000.0,0.5685000419616699,1.8710874319076536,10000.0,48357.28145503998,52310.99687170982,48357.28145503998,3942.584057331085,5.36137318611145,0.0 -108800,1.6402241,4.0644135,,,,,,,,,,,,,, -108900,1.6261228,2.290145,,,,,,,,,,,,,, -109000,1.7715763,4.2643404,,,,,,,,,,,,,, -109100,1.5270246,2.1596737,,,,,,,,,,,,,, -109200,1.3744224,2.601781,,,,,,,,,,,,,, -109300,1.6964529,1.9903425,,,,,,,,,,,,,, -109400,1.5092307,2.6248205,,,,,,,,,,,,,, -109500,1.4431716,3.3695638,,,,,,,,,,,,,, -109600,1.6784942,1.8781004,,,,,,,,,,,,,, -109700,1.5989985,2.0315561,,,,,,,,,,,,,, -109745,,,0.7912695407867432,0.8051213026046753,0.7003399729728699,1.1956088542938232,50000.0,0.5756000280380249,1.829600691795349,10000.0,48777.6155333519,52767.975719213486,48777.6155333519,3979.134026050568,5.406089067459106,0.0 -109800,1.6104614,2.8364658,,,,,,,,,,,,,, -109900,1.7252271,2.12722,,,,,,,,,,,,,, -110000,1.6309841,1.823148,,,,,,,,,,,,,, -110100,1.4413521,2.4040601,,,,,,,,,,,,,, -110200,1.6036285,2.2162206,,,,,,,,,,,,,, -110300,1.5456078,3.5859435,,,,,,,,,,,,,, -110400,1.6019663,1.8762293,,,,,,,,,,,,,, -110500,1.4480428,2.0307858,,,,,,,,,,,,,, -110600,1.5133489,1.8087683,,,,,,,,,,,,,, -110688,,,0.7667773365974426,0.90343976020813,0.7044399976730347,1.1917033195495603,50000.0,0.5751000046730042,1.837497353553772,10000.0,49197.70868873596,53224.95006370544,49197.70868873596,4015.9227290153503,5.449268817901611,0.0 -110700,1.4726543,2.1909292,,,,,,,,,,,,,, -110800,1.6694988,1.7670581,,,,,,,,,,,,,, -110900,1.745907,1.9197674,,,,,,,,,,,,,, -111000,1.6897739,1.6677328,,,,,,,,,,,,,, -111100,1.4818487,3.0266953,,,,,,,,,,,,,, -111200,1.572985,2.075632,,,,,,,,,,,,,, -111300,1.4126126,2.7883072,,,,,,,,,,,,,, -111400,1.6806922,1.8425248,,,,,,,,,,,,,, -111500,1.6332448,1.7857784,,,,,,,,,,,,,, -111600,1.6799549,1.974345,,,,,,,,,,,,,, -111633,,,0.7742773294448853,0.8982342481613159,0.7044399976730347,1.1926296949386597,50000.0,0.5767000317573547,1.8268989324569704,10000.0,49617.80343198776,53680.64485859871,49617.80343198776,4051.419378995896,5.502639532089233,0.0 -111700,1.5697153,1.8347726,,,,,,,,,,,,,, -111800,1.8655396,1.9002655,,,,,,,,,,,,,, -111900,1.6065358,3.8556886,,,,,,,,,,,,,, -112000,1.6809605,2.0346007,,,,,,,,,,,,,, -112100,1.516299,2.6697006,,,,,,,,,,,,,, -112200,1.6363366,2.021992,,,,,,,,,,,,,, -112300,1.4811617,2.8891795,,,,,,,,,,,,,, -112400,1.7901601,3.8280373,,,,,,,,,,,,,, -112500,1.5744973,1.7637707,,,,,,,,,,,,,, -112579,,,0.7817187309265137,0.8364842534065247,0.7057200074195862,1.1751327514648438,50000.0,0.58160001039505,1.812219262123108,10000.0,50038.0046517849,54136.23913860321,50038.0046517849,4086.712213039398,5.552619695663452,0.0 -112600,1.5970217,3.8164518,,,,,,,,,,,,,, -112700,1.7029077,1.9748262,,,,,,,,,,,,,, -112800,1.5176833,2.1986833,,,,,,,,,,,,,, -112900,1.3964293,2.7270412,,,,,,,,,,,,,, -113000,1.6051242,4.1032357,,,,,,,,,,,,,, -113100,1.675274,4.1710496,,,,,,,,,,,,,, -113200,1.6314118,2.1525664,,,,,,,,,,,,,, -113300,1.6673621,4.3961086,,,,,,,,,,,,,, -113400,1.4354421,3.397428,,,,,,,,,,,,,, -113500,1.6808548,1.7470797,,,,,,,,,,,,,, -113519,,,0.7722070217132568,0.889786422252655,0.7026799917221069,1.1861135959625244,50000.0,0.5766000151634216,1.8333693742752075,10000.0,50458.02049589157,54592.7734439373,50458.02049589157,4123.13805603981,5.594868421554565,0.0 -113600,1.5680134,1.742554,,,,,,,,,,,,,, -113700,1.6946944,1.6913922,,,,,,,,,,,,,, -113800,1.7376078,1.8110292,,,,,,,,,,,,,, -113900,1.5268719,2.6214204,,,,,,,,,,,,,, -114000,1.8363543,1.8384626,,,,,,,,,,,,,, -114100,1.5597615,1.7028557,,,,,,,,,,,,,, -114200,1.5392083,1.9363751,,,,,,,,,,,,,, -114300,1.6271863,1.8064777,,,,,,,,,,,,,, -114400,1.4335612,2.578392,,,,,,,,,,,,,, -114465,,,0.7795702815055847,0.8615785837173462,0.7099199891090393,1.1732913255691528,50000.0,0.588100016117096,1.804617047309876,10000.0,50878.1044754982,55048.6302447319,50878.1044754982,4158.8170692920685,5.638230085372925,0.0 -114500,1.460174,2.2862372,,,,,,,,,,,,,, -114600,1.6021613,2.5770626,,,,,,,,,,,,,, -114700,1.6738951,2.0013149,,,,,,,,,,,,,, -114800,1.8214828,3.9930174,,,,,,,,,,,,,, -114900,1.5757558,2.2287235,,,,,,,,,,,,,, -115000,1.6826556,1.9166813,,,,,,,,,,,,,, -115100,1.7604038,1.7667563,,,,,,,,,,,,,, -115200,1.7208434,3.9846315,,,,,,,,,,,,,, -115300,1.7023728,1.7178594,,,,,,,,,,,,,, -115400,1.642894,3.218582,,,,,,,,,,,,,, -115411,,,0.7862890362739563,0.82613605260849,0.709879994392395,1.1572134494781494,50000.0,0.5830000042915344,1.7958065271377563,10000.0,51298.08326268196,55503.61720633507,51298.08326268196,4193.726380109787,5.686317682266235,0.0 -115500,1.6384481,2.1118422,,,,,,,,,,,,,, -115600,1.7522532,1.8462322,,,,,,,,,,,,,, -115700,1.7255968,2.193803,,,,,,,,,,,,,, -115800,1.7275344,1.8065546,,,,,,,,,,,,,, -115900,1.7651852,1.8386831,,,,,,,,,,,,,, -116000,1.6156791,2.3486605,,,,,,,,,,,,,, -116100,1.7819604,2.6281276,,,,,,,,,,,,,, -116200,1.6038444,1.6940747,,,,,,,,,,,,,, -116300,1.6013116,2.066105,,,,,,,,,,,,,, -116356,,,0.8011523485183716,0.7779342532157898,0.7084999680519104,1.161314606666565,50000.0,0.5822000503540039,1.81033718585968,10000.0,51718.13684058189,55957.18047308922,51718.13684058189,4227.141417503357,5.730888605117798,0.0 -116400,1.837903,4.130406,,,,,,,,,,,,,, -116500,1.6727196,1.8305976,,,,,,,,,,,,,, -116600,1.6826705,1.6915338,,,,,,,,,,,,,, -116700,1.6418741,2.3037937,,,,,,,,,,,,,, -116800,1.7593795,1.8555224,,,,,,,,,,,,,, -116900,1.6220933,2.518055,,,,,,,,,,,,,, -117000,1.7673888,1.7868506,,,,,,,,,,,,,, -117100,1.4466686,3.0535111,,,,,,,,,,,,,, -117200,1.6463331,2.1224966,,,,,,,,,,,,,, -117299,,,0.7773827910423279,0.8753180503845215,0.7099599838256836,1.174333095550537,50000.0,0.5878000259399414,1.808908462524414,10000.0,52138.45469069481,56414.60079431534,52138.45469069481,4264.137006044388,5.787650108337402,0.0 -117300,1.8010812,4.232707,,,,,,,,,,,,,, -117400,1.5393524,3.1693344,,,,,,,,,,,,,, -117500,1.992388,4.060126,,,,,,,,,,,,,, -117600,1.6514287,1.7957125,,,,,,,,,,,,,, -117700,1.7196151,4.214136,,,,,,,,,,,,,, -117800,1.7653985,1.8507311,,,,,,,,,,,,,, -117900,1.797578,1.8288441,,,,,,,,,,,,,, -118000,1.7106736,1.907625,,,,,,,,,,,,,, -118100,1.7088023,1.7702928,,,,,,,,,,,,,, -118200,1.6614175,2.3086166,,,,,,,,,,,,,, -118247,,,0.7845507860183716,0.8391237258911133,0.7131199836730957,1.1501730680465698,50000.0,0.5895000100135803,1.791690707206726,10000.0,52558.41911363602,56870.423597335815,52558.41911363602,4299.884291887283,5.847143650054932,0.0 -118300,1.7453562,1.841537,,,,,,,,,,,,,, -118400,1.7340381,1.795834,,,,,,,,,,,,,, -118500,1.6295575,3.39264,,,,,,,,,,,,,, -118600,1.487666,2.9395003,,,,,,,,,,,,,, -118700,1.5567878,2.4149072,,,,,,,,,,,,,, -118800,1.9352231,1.8047282,,,,,,,,,,,,,, -118900,1.5588257,2.2667167,,,,,,,,,,,,,, -119000,1.7152268,1.9219186,,,,,,,,,,,,,, -119100,1.7591642,1.802895,,,,,,,,,,,,,, -119194,,,0.7907031178474426,0.8129280805587769,0.712619960308075,1.1554131507873535,50000.0,0.5840000510215759,1.8003008365631104,10000.0,52978.646104097366,57327.64228081703,52978.646104097366,4336.783348560333,5.890512466430664,0.0 -119200,1.7116327,1.7497321,,,,,,,,,,,,,, -119300,2.1014824,3.860038,,,,,,,,,,,,,, -119400,1.7956245,3.3865957,,,,,,,,,,,,,, -119500,1.6996582,1.9686999,,,,,,,,,,,,,, -119600,1.8896779,1.887819,,,,,,,,,,,,,, -119700,2.005405,4.127957,,,,,,,,,,,,,, -119800,1.8159586,1.897667,,,,,,,,,,,,,, -119900,1.7555553,2.085034,,,,,,,,,,,,,, -120000,1.6755941,3.7656133,,,,,,,,,,,,,, -120100,1.8658627,4.2028584,,,,,,,,,,,,,, -120140,,,0.78236323595047,0.8431088328361511,0.715939998626709,1.1370404958724976,50000.0,0.5950000286102295,1.7669881582260132,10000.0,53398.970262527466,57783.42899656296,53398.970262527466,4372.142210245132,5.944162607192993,0.0 -120200,1.7905918,3.4223394,,,,,,,,,,,,,, -120300,1.7164879,2.8766778,,,,,,,,,,,,,, -120400,1.7071273,4.195669,,,,,,,,,,,,,, -120500,1.608293,3.7714047,,,,,,,,,,,,,, -120600,1.7611134,1.7626499,,,,,,,,,,,,,, -120700,1.847958,3.778767,,,,,,,,,,,,,, -120800,1.929605,1.709149,,,,,,,,,,,,,, -120900,1.6872591,1.7207752,,,,,,,,,,,,,, -121000,1.7391546,4.213193,,,,,,,,,,,,,, -121084,,,0.7857617139816284,0.8171523809432983,0.716759979724884,1.126121163368225,50000.0,0.59170001745224,1.766446590423584,10000.0,53819.214713573456,58240.33206629753,53819.214713573456,4408.703050613403,5.99152398109436,0.0 -121100,1.7476307,3.5803688,,,,,,,,,,,,,, -121200,1.7246162,1.7383022,,,,,,,,,,,,,, -121300,2.1919248,4.1674137,,,,,,,,,,,,,, -121400,1.8183227,1.7940086,,,,,,,,,,,,,, -121500,1.7745318,1.8039429,,,,,,,,,,,,,, -121600,1.8343495,1.7776148,,,,,,,,,,,,,, -121700,1.7042232,2.994444,,,,,,,,,,,,,, -121800,1.7505469,2.4806747,,,,,,,,,,,,,, -121900,1.8997158,1.6901281,,,,,,,,,,,,,, -122000,1.7776208,3.7843685,,,,,,,,,,,,,, -122027,,,0.7969726324081421,0.7908524870872498,0.7226200103759766,1.1162563562393188,50000.0,0.6005000472068787,1.7421303987503052,10000.0,54239.28956913948,58696.26500344277,54239.28956913948,4444.461498260498,6.041364669799805,0.0 -122100,1.7839209,1.760427,,,,,,,,,,,,,, -122200,1.9428403,4.2298245,,,,,,,,,,,,,, -122300,1.679244,2.862367,,,,,,,,,,,,,, -122400,1.6784326,3.5213335,,,,,,,,,,,,,, -122500,1.7931017,1.7652766,,,,,,,,,,,,,, -122600,1.7303518,2.0771852,,,,,,,,,,,,,, -122700,1.9087623,1.675478,,,,,,,,,,,,,, -122800,1.8649129,1.766659,,,,,,,,,,,,,, -122900,1.717146,1.6799827,,,,,,,,,,,,,, -122969,,,0.8081640601158142,0.7565469741821289,0.7193799614906311,1.1386165618896484,50000.0,0.5875000357627869,1.7857192754745483,10000.0,54659.25565576553,59152.28856778145,54659.25565576553,4480.420238494873,6.089730262756348,0.0 -123000,1.5625943,2.831322,,,,,,,,,,,,,, -123100,1.8204485,1.7346053,,,,,,,,,,,,,, -123200,1.826266,3.9269607,,,,,,,,,,,,,, -123300,1.8354568,1.7479901,,,,,,,,,,,,,, -123400,1.6651173,2.359847,,,,,,,,,,,,,, -123500,1.7442948,3.3388107,,,,,,,,,,,,,, -123600,1.8659937,1.7462612,,,,,,,,,,,,,, -123700,1.8431102,3.8696265,,,,,,,,,,,,,, -123800,2.0390246,1.6914016,,,,,,,,,,,,,, -123900,1.8248329,3.3100882,,,,,,,,,,,,,, -123914,,,0.7901562452316284,0.8068869113922119,0.7235400080680847,1.10583233833313,50000.0,0.5992000102996826,1.731168270111084,10000.0,55079.27008724213,59607.85239100456,55079.27008724213,4515.870098590851,6.138895750045776,0.0 -124000,1.9264209,1.719639,,,,,,,,,,,,,, -124100,1.8248992,1.7610114,,,,,,,,,,,,,, -124200,1.7790241,1.7084353,,,,,,,,,,,,,, -124300,1.6838953,3.245305,,,,,,,,,,,,,, -124400,1.7729518,1.8020103,,,,,,,,,,,,,, -124500,1.6703568,2.9599936,,,,,,,,,,,,,, -124600,1.6959512,1.9976572,,,,,,,,,,,,,, -124700,1.9178184,1.7317908,,,,,,,,,,,,,, -124800,1.9843934,1.7068744,,,,,,,,,,,,,, -124859,,,0.7969726324081421,0.7828193306922913,0.721839964389801,1.1123453378677368,50000.0,0.6008000373840332,1.7349574565887451,10000.0,55499.21904158592,60063.85525274277,55499.21904158592,4551.8229367733,6.189615488052368,0.0 -124900,1.7600547,1.7725902,,,,,,,,,,,,,, -125000,1.8571174,3.3019981,,,,,,,,,,,,,, -125100,1.798598,2.8206182,,,,,,,,,,,,,, -125200,1.7412444,3.3329458,,,,,,,,,,,,,, -125300,1.747091,2.230425,,,,,,,,,,,,,, -125400,1.8990735,1.6183187,,,,,,,,,,,,,, -125500,2.1726148,4.0510154,,,,,,,,,,,,,, -125600,1.7839259,1.7184155,,,,,,,,,,,,,, -125700,2.019058,4.0795445,,,,,,,,,,,,,, -125800,1.8751373,3.950275,,,,,,,,,,,,,, -125802,,,0.8039257526397705,0.764805257320404,0.7251600027084351,1.1057837009429932,50000.0,0.5984000563621521,1.7472631931304932,10000.0,55919.18010044098,60519.63783144951,55919.18010044098,4587.547564506531,6.235995531082153,0.0 -125900,1.8424392,1.7445589,,,,,,,,,,,,,, -126000,1.7424746,1.4920567,,,,,,,,,,,,,, -126100,1.9893063,1.7780738,,,,,,,,,,,,,, -126200,1.885493,1.6163995,,,,,,,,,,,,,, -126300,2.0453866,1.6689018,,,,,,,,,,,,,, -126400,1.920609,1.610176,,,,,,,,,,,,,, -126500,1.6608635,3.3119514,,,,,,,,,,,,,, -126600,1.6794657,2.9848704,,,,,,,,,,,,,, -126700,1.8581451,1.5850141,,,,,,,,,,,,,, -126746,,,0.7957617044448853,0.7915138006210327,0.7269200086593628,1.0992271900177002,50000.0,0.6011000275611877,1.7334434986114502,10000.0,56339.31581926346,60974.57660126686,56339.31581926346,4622.255147457123,6.281415939331055,0.0 -126800,1.816254,1.771347,,,,,,,,,,,,,, -126900,1.8080982,1.7271922,,,,,,,,,,,,,, -127000,1.7659762,3.3487508,,,,,,,,,,,,,, -127100,1.7349145,2.037185,,,,,,,,,,,,,, -127200,2.0715585,1.6568608,,,,,,,,,,,,,, -127300,2.0496924,1.6797905,,,,,,,,,,,,,, -127400,1.895598,2.033867,,,,,,,,,,,,,, -127500,1.9400816,1.5560266,,,,,,,,,,,,,, -127600,2.118879,1.6537294,,,,,,,,,,,,,, -127689,,,0.7975195050239563,0.7745173573493958,0.725600004196167,1.094840168952942,50000.0,0.6038000583648682,1.7361700534820557,10000.0,56759.32902884483,61430.3923459053,56759.32902884483,4657.953867673874,6.33309531211853,0.0 -127700,1.8988467,1.6509743,,,,,,,,,,,,,, -127800,1.7930245,1.5967578,,,,,,,,,,,,,, -127900,1.7644914,2.192861,,,,,,,,,,,,,, -128000,2.1085632,3.9781725,,,,,,,,,,,,,, -128100,1.8161103,1.5464637,,,,,,,,,,,,,, -128200,1.9226557,2.6169665,,,,,,,,,,,,,, -128300,1.9512811,3.4364383,,,,,,,,,,,,,, -128400,1.9716872,1.7335502,,,,,,,,,,,,,, -128500,1.914747,1.7322123,,,,,,,,,,,,,, -128600,1.7692425,2.24952,,,,,,,,,,,,,, -128633,,,0.8062695264816284,0.7585539221763611,0.7260199785232544,1.0970200300216677,50000.0,0.5994000434875488,1.735648274421692,10000.0,57179.24024510384,61887.30322051048,57179.24024510384,4694.850904941559,6.385717153549194,0.0 -128700,2.178332,1.7592832,,,,,,,,,,,,,, -128800,1.8492801,1.5992136,,,,,,,,,,,,,, -128900,2.12312,3.8006043,,,,,,,,,,,,,, -129000,1.697892,2.4414222,,,,,,,,,,,,,, -129100,1.8733872,1.8368225,,,,,,,,,,,,,, -129200,1.8430915,2.6432102,,,,,,,,,,,,,, -129300,1.9335964,2.311329,,,,,,,,,,,,,, -129400,1.920482,3.3630993,,,,,,,,,,,,,, -129500,1.9377005,1.798218,,,,,,,,,,,,,, -129577,,,0.8240624666213989,0.6876732707023621,0.7297599911689758,1.080593466758728,50000.0,0.600600004196167,1.7161928415298462,10000.0,57599.32408952713,62344.34480929375,57599.32408952713,4731.710379600525,6.433982849121094,0.0 -129600,1.7341686,2.4558294,,,,,,,,,,,,,, -129700,1.9228334,2.2789724,,,,,,,,,,,,,, -129800,1.8971096,3.0568664,,,,,,,,,,,,,, -129900,1.9377942,3.8117793,,,,,,,,,,,,,, -130000,1.7203946,2.4351158,,,,,,,,,,,,,, -130100,1.9355645,3.0766551,,,,,,,,,,,,,, -130200,1.907685,3.4731402,,,,,,,,,,,,,, -130300,1.8285179,2.867272,,,,,,,,,,,,,, -130400,1.9414728,2.1087565,,,,,,,,,,,,,, -130500,1.875043,1.6247349,,,,,,,,,,,,,, -130521,,,0.8061913847923279,0.74146568775177,0.731660008430481,1.078520894050598,50000.0,0.6066000461578369,1.7314767837524414,10000.0,58019.40853381157,62800.33711600304,58019.40853381157,4767.513851881027,6.48798131942749,0.0 -130600,1.9881493,1.7120376,,,,,,,,,,,,,, -130700,1.8220927,1.9826891,,,,,,,,,,,,,, -130800,1.88738,1.7015709,,,,,,,,,,,,,, -130900,2.0687442,3.5472367,,,,,,,,,,,,,, -131000,1.9263706,1.5660102,,,,,,,,,,,,,, -131100,1.9773369,1.5990853,,,,,,,,,,,,,, -131200,1.7899367,3.0082247,,,,,,,,,,,,,, -131300,1.8490113,1.6179991,,,,,,,,,,,,,, -131400,2.0633786,3.7406354,,,,,,,,,,,,,, -131467,,,0.8073632717132568,0.7516621947288513,0.7322799563407898,1.0729314088821411,50000.0,0.605400025844574,1.702867865562439,10000.0,58439.49217700958,63257.6309440136,58439.49217700958,4804.62633895874,6.535409212112427,0.0 -131500,1.7897645,2.5490937,,,,,,,,,,,,,, -131600,2.7171633,1.6109297,,,,,,,,,,,,,, -131700,1.8914156,1.615975,,,,,,,,,,,,,, -131800,1.995795,1.5244046,,,,,,,,,,,,,, -131900,1.9509972,2.152063,,,,,,,,,,,,,, -132000,1.8977218,2.6727483,,,,,,,,,,,,,, -132100,1.7595949,1.6565752,,,,,,,,,,,,,, -132200,2.0189264,1.7169322,,,,,,,,,,,,,, -132300,2.5941467,3.544193,,,,,,,,,,,,,, -132400,1.9309924,1.4560382,,,,,,,,,,,,,, -132413,,,0.8190429210662842,0.686676025390625,0.7339000105857849,1.062727451324463,50000.0,0.6101000308990479,1.7026625871658323,10000.0,58859.64246606827,63713.68802905083,58859.64246606827,4840.436786174774,6.581597805023193,0.0 -132500,2.2643461,3.9045868,,,,,,,,,,,,,, -132600,2.0327723,1.7554384,,,,,,,,,,,,,, -132700,1.9671811,2.9359193,,,,,,,,,,,,,, -132800,1.8900781,1.6294564,,,,,,,,,,,,,, -132900,2.012271,1.9431416,,,,,,,,,,,,,, -133000,1.9454695,1.5668571,,,,,,,,,,,,,, -133100,2.024852,1.6640671,,,,,,,,,,,,,, -133200,1.8261017,1.671699,,,,,,,,,,,,,, -133300,1.9738854,1.7259293,,,,,,,,,,,,,, -133358,,,0.809863269329071,0.7376311421394348,0.735539972782135,1.060116171836853,50000.0,0.6114000082015991,1.697309494018555,10000.0,59279.68286252022,64168.63878440857,59279.68286252022,4875.248173952103,6.630462884902954,0.0 -133400,2.0148044,1.6769493,,,,,,,,,,,,,, -133500,2.0173247,1.640878,,,,,,,,,,,,,, -133600,1.7741307,2.1834257,,,,,,,,,,,,,, -133700,1.9988289,1.5423107,,,,,,,,,,,,,, -133800,1.9027754,2.4655929,,,,,,,,,,,,,, -133900,1.9598082,1.6236976,,,,,,,,,,,,,, -134000,2.1100192,1.6276689,,,,,,,,,,,,,, -134100,1.8807054,1.8755441,,,,,,,,,,,,,, -134200,2.0104008,1.6750824,,,,,,,,,,,,,, -134300,,,0.8107226490974426,0.7090687155723572,0.735539972782135,1.045581340789795,50000.0,0.6104000210762024,1.6855266094207764,10000.0,59699.82854485512,64624.70807790756,59699.82854485512,4911.07240653038,6.679866790771484,0.0 -134300,1.9555093,2.2047727,,,,,,,,,,,,,, -134400,1.842761,2.978613,,,,,,,,,,,,,, -134500,1.9926153,3.6033666,,,,,,,,,,,,,, -134600,1.9537581,1.566077,,,,,,,,,,,,,, -134700,1.9846385,2.0417557,,,,,,,,,,,,,, -134800,2.0029752,1.4458092,,,,,,,,,,,,,, -134900,1.9100022,1.8676672,,,,,,,,,,,,,, -135000,2.2025146,3.737884,,,,,,,,,,,,,, -135100,2.0144506,1.5592606,,,,,,,,,,,,,, -135200,2.0882895,1.5735899,,,,,,,,,,,,,, -135246,,,0.8220117092132568,0.6820235252380371,0.7381199598312378,1.040558695793152,50000.0,0.619100034236908,1.657473921775818,10000.0,60120.14081978798,65081.38606548309,60120.14081978798,4947.333732366562,6.733776807785034,0.0 -135300,1.9653094,1.6487961,,,,,,,,,,,,,, -135400,2.0437982,1.4980801,,,,,,,,,,,,,, -135500,2.0717592,1.5333164,,,,,,,,,,,,,, -135600,2.1317616,2.8953726,,,,,,,,,,,,,, -135700,1.9845899,1.6830785,,,,,,,,,,,,,, -135800,2.1966207,1.6262912,,,,,,,,,,,,,, -135900,1.9861693,1.5941396,,,,,,,,,,,,,, -136000,1.968189,1.5079675,,,,,,,,,,,,,, -136100,2.081825,1.6230384,,,,,,,,,,,,,, -136190,,,0.833789050579071,0.6432842016220093,0.7403799891471863,1.0398958921432495,50000.0,0.6166000366210938,1.6594387292861938,10000.0,60540.23037648201,65538.20174622536,60540.23037648201,4983.961835622788,6.781572341918945,0.0 -136200,2.2964466,3.972701,,,,,,,,,,,,,, -136300,2.094008,1.6751552,,,,,,,,,,,,,, -136400,2.0942328,1.7086018,,,,,,,,,,,,,, -136500,2.268269,3.7750864,,,,,,,,,,,,,, -136600,1.971666,3.000993,,,,,,,,,,,,,, -136700,1.8905387,1.5858167,,,,,,,,,,,,,, -136800,1.9389584,3.0672007,,,,,,,,,,,,,, -136900,1.8484036,2.6337948,,,,,,,,,,,,,, -137000,2.0009518,1.6232699,,,,,,,,,,,,,, -137100,2.0900126,3.7126327,,,,,,,,,,,,,, -137134,,,0.8225390315055847,0.6828349232673645,0.7406799793243408,1.026926040649414,50000.0,0.6152000427246094,1.6574362516403198,10000.0,60960.5283882618,65993.98559451103,60960.5283882618,5019.351192474365,6.82807993888855,0.0 -137200,1.97908,3.51017,,,,,,,,,,,,,, -137300,1.9452505,2.3874362,,,,,,,,,,,,,, -137400,1.9252007,1.5094392,,,,,,,,,,,,,, -137500,2.222733,3.1790874,,,,,,,,,,,,,, -137600,2.1787589,1.5832719,,,,,,,,,,,,,, -137700,1.9569415,1.6237341,,,,,,,,,,,,,, -137800,1.9147606,1.9477173,,,,,,,,,,,,,, -137900,2.2943783,3.9242723,,,,,,,,,,,,,, -138000,2.1198099,1.7223294,,,,,,,,,,,,,, -138078,,,0.8241210579872131,0.6726440787315369,0.7392399907112122,1.0355180501937866,50000.0,0.6131000518798828,1.6730371713638306,10000.0,61380.73367190361,66450.48570799828,61380.73367190361,5055.546671152115,6.876504421234131,0.0 -138100,2.0238044,1.4530877,,,,,,,,,,,,,, -138200,2.0133224,2.0165107,,,,,,,,,,,,,, -138300,2.117763,1.526584,,,,,,,,,,,,,, -138400,2.0168526,1.4914912,,,,,,,,,,,,,, -138500,2.182915,1.5651265,,,,,,,,,,,,,, -138600,1.9355394,2.8610172,,,,,,,,,,,,,, -138700,2.1416602,1.5552349,,,,,,,,,,,,,, -138800,1.9626673,1.8052304,,,,,,,,,,,,,, -138900,1.9988884,3.24957,,,,,,,,,,,,,, -139000,2.1108825,2.8694382,,,,,,,,,,,,,, -139021,,,0.8309569954872131,0.6413344144821167,0.7416599988937378,1.0248628854751587,50000.0,0.6114000082015991,1.6480200290679932,10000.0,61800.80941319466,66905.37312984467,61800.80941319466,5090.26197385788,6.923521041870117,0.0 -139100,2.0580363,1.5192535,,,,,,,,,,,,,, -139200,2.0802145,1.5219573,,,,,,,,,,,,,, -139300,2.0818567,1.5384827,,,,,,,,,,,,,, -139400,2.1836464,1.4910569,,,,,,,,,,,,,, -139500,2.2163987,3.2696226,,,,,,,,,,,,,, -139600,1.9375954,2.9142551,,,,,,,,,,,,,, -139700,2.7492826,1.4726554,,,,,,,,,,,,,, -139800,2.362682,3.8335073,,,,,,,,,,,,,, -139900,2.221098,1.4860568,,,,,,,,,,,,,, -139965,,,0.8257421851158142,0.6530791521072388,0.7451399564743042,1.004848599433899,50000.0,0.6248000264167786,1.6311440467834473,10000.0,62220.787368536,67361.58550977707,62220.787368536,5126.392836332321,6.976944208145142,0.0 -140000,2.2962387,3.6264803,,,,,,,,,,,,,, -140100,2.3750308,3.7640982,,,,,,,,,,,,,, -140200,2.0635934,2.6269662,,,,,,,,,,,,,, -140300,2.1024632,1.6439509,,,,,,,,,,,,,, -140400,2.6119905,3.4583755,,,,,,,,,,,,,, -140500,2.0358267,2.868783,,,,,,,,,,,,,, -140600,1.9475657,1.9257073,,,,,,,,,,,,,, -140700,2.116745,2.8066435,,,,,,,,,,,,,, -140800,2.137781,1.5282582,,,,,,,,,,,,,, -140900,2.2125854,2.1641574,,,,,,,,,,,,,, -140909,,,0.8275976181030273,0.6520302295684814,0.7456799745559692,1.002097725868225,50000.0,0.6196000576019287,1.6410871744155884,10000.0,62640.990287303925,67818.36901831627,62640.990287303925,5162.872810840607,7.027889490127564,0.0 -141000,2.0801027,1.520193,,,,,,,,,,,,,, -141100,2.1004121,1.5436884,,,,,,,,,,,,,, -141200,2.0047119,1.4620919,,,,,,,,,,,,,, -141300,2.013474,2.6378021,,,,,,,,,,,,,, -141400,1.8460443,2.6365743,,,,,,,,,,,,,, -141500,2.17937,2.3004124,,,,,,,,,,,,,, -141600,1.997135,2.3085613,,,,,,,,,,,,,, -141700,2.1070328,1.4531848,,,,,,,,,,,,,, -141800,2.3360817,1.5618811,,,,,,,,,,,,,, -141855,,,0.8324023485183716,0.6519719362258911,0.7467199563980103,1.0059432983398438,50000.0,0.6238000392913818,1.6353424787521362,10000.0,63061.08670520783,68274.05000305176,63061.08670520783,5198.357325792313,7.077587604522705,0.0 -141900,2.3640814,1.521873,,,,,,,,,,,,,, -142000,2.1469154,3.2904854,,,,,,,,,,,,,, -142100,2.1474528,1.452671,,,,,,,,,,,,,, -142200,2.4497392,3.69662,,,,,,,,,,,,,, -142300,2.2001452,1.5516912,,,,,,,,,,,,,, -142400,2.0921133,1.9292021,,,,,,,,,,,,,, -142500,2.699691,1.3749456,,,,,,,,,,,,,, -142600,2.1349316,1.377732,,,,,,,,,,,,,, -142700,2.2670624,1.4963108,,,,,,,,,,,,,, -142800,1.9218292,2.1619797,,,,,,,,,,,,,, -142802,,,0.8463085889816284,0.5910487174987793,0.7475399971008301,1.0001438856124878,50000.0,0.6241000294685364,1.6294469833374023,10000.0,63481.00023508072,68729.8181681633,63481.00023508072,5234.109792232513,7.129896640777588,0.0 -142900,2.2249448,2.5390954,,,,,,,,,,,,,, -143000,3.1577175,3.8546472,,,,,,,,,,,,,, -143100,2.3230083,3.049143,,,,,,,,,,,,,, -143200,2.3010192,1.4657364,,,,,,,,,,,,,, -143300,2.0064256,1.6421387,,,,,,,,,,,,,, -143400,2.2067904,1.9692383,,,,,,,,,,,,,, -143500,2.0312803,1.5999753,,,,,,,,,,,,,, -143600,2.1955936,1.8420002,,,,,,,,,,,,,, -143700,2.1253078,1.734436,,,,,,,,,,,,,, -143747,,,0.8325781226158142,0.6336491703987122,0.7473799586296082,0.996917188167572,50000.0,0.6236000061035156,1.624912142753601,10000.0,63901.20878171921,69185.62810969353,63901.20878171921,5269.614250659943,7.175962209701538,0.0 -143800,2.3416624,1.6527426,,,,,,,,,,,,,, -143900,2.2926934,3.3254728,,,,,,,,,,,,,, -144000,2.374981,1.8453367,,,,,,,,,,,,,, -144100,2.466719,3.0535817,,,,,,,,,,,,,, -144200,2.0231495,1.5389823,,,,,,,,,,,,,, -144300,2.2617903,1.4767593,,,,,,,,,,,,,, -144400,2.1143486,1.5788798,,,,,,,,,,,,,, -144500,2.2789633,1.4601226,,,,,,,,,,,,,, -144600,2.3091612,1.5680751,,,,,,,,,,,,,, -144694,,,0.83607417345047,0.6233413219451904,0.7495200037956238,0.9957394003868104,50000.0,0.6264000535011292,1.6167546510696411,10000.0,64321.56186580658,69642.30963206291,64321.56186580658,5305.843818902969,7.224336624145508,0.0 -144700,1.9545113,1.6632028,,,,,,,,,,,,,, -144800,2.140779,1.3016789,,,,,,,,,,,,,, -144900,2.374321,3.4459898,,,,,,,,,,,,,, -145000,2.4044347,1.5086093,,,,,,,,,,,,,, -145100,2.2542098,2.6052442,,,,,,,,,,,,,, -145200,2.1547222,1.443052,,,,,,,,,,,,,, -145300,2.227652,1.4480941,,,,,,,,,,,,,, -145400,2.3726373,1.4388821,,,,,,,,,,,,,, -145500,2.351253,2.745993,,,,,,,,,,,,,, -145600,2.2086213,2.2234735,,,,,,,,,,,,,, -145639,,,0.8413866758346558,0.6079884767532349,0.7496599555015564,0.9920935034751892,50000.0,0.6269000172615051,1.6104060411453247,10000.0,64741.75490403175,70098.32852602005,64741.75490403175,5341.5623569488525,7.281407833099365,0.0 -145700,2.132644,3.055277,,,,,,,,,,,,,, -145800,2.5614548,1.5200301,,,,,,,,,,,,,, -145900,2.5276518,1.8134953,,,,,,,,,,,,,, -146000,2.1915104,1.7538862,,,,,,,,,,,,,, -146100,2.2017045,1.4653403,,,,,,,,,,,,,, -146200,2.4274957,1.5203121,,,,,,,,,,,,,, -146300,2.1724584,1.4982097,,,,,,,,,,,,,, -146400,2.314819,3.0979931,,,,,,,,,,,,,, -146500,2.2621727,1.3389062,,,,,,,,,,,,,, -146583,,,0.8368554711341858,0.6317271590232849,0.7525999546051025,0.9833926558494568,50000.0,0.6305000185966492,1.6054846048355105,10000.0,65161.94474768639,70555.01319146156,65161.94474768639,5377.9620950222015,7.32740592956543,0.0 -146600,2.1151745,1.8138003,,,,,,,,,,,,,, -146700,2.18559,1.484993,,,,,,,,,,,,,, -146800,2.2618287,3.3296735,,,,,,,,,,,,,, -146900,2.6331713,3.6049922,,,,,,,,,,,,,, -147000,2.322277,1.3379788,,,,,,,,,,,,,, -147100,2.2827806,3.1745937,,,,,,,,,,,,,, -147200,2.2399325,1.7255421,,,,,,,,,,,,,, -147300,2.3030477,2.9025335,,,,,,,,,,,,,, -147400,2.1453357,1.5630554,,,,,,,,,,,,,, -147500,2.559254,1.4190171,,,,,,,,,,,,,, -147531,,,0.8406640291213989,0.6047345399856567,0.7533400058746338,0.9802722930908204,50000.0,0.6319000124931335,1.593847155570984,10000.0,65582.24419879913,71010.74550557137,65582.24419879913,5413.293456554413,7.379019498825073,0.0 -147600,2.334828,1.3008122,,,,,,,,,,,,,, -147700,2.5074356,1.4295748,,,,,,,,,,,,,, -147800,2.1640453,2.198681,,,,,,,,,,,,,, -147900,2.2338386,1.3435053,,,,,,,,,,,,,, -148000,2.3439593,1.5599506,,,,,,,,,,,,,, -148100,2.2769084,1.9337593,,,,,,,,,,,,,, -148200,2.4860878,2.4780526,,,,,,,,,,,,,, -148300,2.2431185,1.6809211,,,,,,,,,,,,,, -148400,2.5731733,3.5959988,,,,,,,,,,,,,, -148474,,,0.8462694883346558,0.5779913067817688,0.7558599710464478,0.9657291173934937,50000.0,0.6288000345230103,1.5970369577407837,10000.0,66002.26580381393,71467.2146782875,66002.26580381393,5449.63863158226,7.4307496547698975,0.0 -148500,2.358483,1.4049774,,,,,,,,,,,,,, -148600,2.6416001,3.8163571,,,,,,,,,,,,,, -148700,2.4441612,1.4125265,,,,,,,,,,,,,, -148800,2.9208786,1.5942957,,,,,,,,,,,,,, -148900,2.257345,1.7285086,,,,,,,,,,,,,, -149000,2.675724,3.5880806,,,,,,,,,,,,,, -149100,2.5107784,1.4660614,,,,,,,,,,,,,, -149200,2.4292183,1.5038371,,,,,,,,,,,,,, -149300,2.7531571,3.753333,,,,,,,,,,,,,, -149400,2.3067555,1.4466813,,,,,,,,,,,,,, -149419,,,0.8524804711341858,0.5546852946281433,0.7554199695587158,0.9629727602005004,50000.0,0.6330000162124634,1.5846474170684814,10000.0,66422.62491846085,71923.45594525337,66422.62491846085,5485.423399209976,7.478578805923462,0.0 -149500,2.1929297,1.7241805,,,,,,,,,,,,,, -149600,2.5115657,1.6010702,,,,,,,,,,,,,, -149700,2.4375808,1.4454654,,,,,,,,,,,,,, -149800,2.3007736,1.3906853,,,,,,,,,,,,,, -149900,2.6724272,1.4817879,,,,,,,,,,,,,, -150000,2.397657,2.3881567,,,,,,,,,,,,,, -150100,2.5336761,1.3869728,,,,,,,,,,,,,, -150200,2.3167582,2.9554594,,,,,,,,,,,,,, -150300,2.3954692,1.3427362,,,,,,,,,,,,,, -150364,,,0.845996081829071,0.5830208659172058,0.7575799822807312,0.9540197253227234,50000.0,0.6341000199317932,1.5807474851608276,10000.0,66842.6845741272,72379.02922987938,66842.6845741272,5520.838660478592,7.52655816078186,0.0 -150400,2.1479993,1.9209008,,,,,,,,,,,,,, -150500,2.308676,2.4710712,,,,,,,,,,,,,, -150600,2.4594295,2.1928263,,,,,,,,,,,,,, -150700,2.241488,1.5919715,,,,,,,,,,,,,, -150800,2.5435176,1.5103598,,,,,,,,,,,,,, -150900,2.3930478,2.4285448,,,,,,,,,,,,,, -151000,2.544567,1.3805687,,,,,,,,,,,,,, -151100,2.4069245,1.2708043,,,,,,,,,,,,,, -151200,2.362386,1.8820398,,,,,,,,,,,,,, -151300,2.3167212,1.4301254,,,,,,,,,,,,,, -151309,,,0.8492968678474426,0.5629798173904419,0.7577599883079529,0.9516875743865968,50000.0,0.6341000199317932,1.5633233785629272,10000.0,67262.9191057682,72835.57788276672,67262.9191057682,5557.04504776001,7.584197282791138,0.0 -151400,2.3417428,2.8885615,,,,,,,,,,,,,, -151500,2.368609,1.3122579,,,,,,,,,,,,,, -151600,2.8344553,3.4699998,,,,,,,,,,,,,, -151700,2.5963209,2.4259279,,,,,,,,,,,,,, -151800,2.4239132,2.3490014,,,,,,,,,,,,,, -151900,2.3634248,2.6044,,,,,,,,,,,,,, -152000,2.4330595,1.3378973,,,,,,,,,,,,,, -152100,2.381919,1.312565,,,,,,,,,,,,,, -152200,2.2490518,1.7508569,,,,,,,,,,,,,, -152255,,,0.8557031154632568,0.5337830781936646,0.7605400085449219,0.942303478717804,50000.0,0.636400043964386,1.5599946975708008,10000.0,67683.15299010277,73291.40171384811,67683.15299010277,5592.5248901844025,7.643503665924072,0.0 -152300,2.5774975,1.2792184,,,,,,,,,,,,,, -152400,2.4248981,1.322004,,,,,,,,,,,,,, -152500,2.591634,1.344152,,,,,,,,,,,,,, -152600,2.1602418,1.8494649,,,,,,,,,,,,,, -152700,2.8479393,3.4370472,,,,,,,,,,,,,, -152800,2.503475,1.6634907,,,,,,,,,,,,,, -152900,2.694766,1.3971497,,,,,,,,,,,,,, -153000,2.4702485,2.9634402,,,,,,,,,,,,,, -153100,2.5187747,1.4864435,,,,,,,,,,,,,, -153200,,,0.8529492020606995,0.5705944895744324,0.7596399784088135,0.9485294222831726,50000.0,0.6391000151634216,1.5622645616531372,10000.0,68103.0815103054,73747.12293791771,68103.0815103054,5628.214492797852,7.696725130081177,0.0 -153200,2.3621266,1.2274727,,,,,,,,,,,,,, -153300,2.7316952,2.8313751,,,,,,,,,,,,,, -153400,2.5909133,1.2788903,,,,,,,,,,,,,, -153500,2.5248334,1.4899033,,,,,,,,,,,,,, -153600,2.5620437,1.3952072,,,,,,,,,,,,,, -153700,2.2946649,2.1930714,,,,,,,,,,,,,, -153800,2.3678007,1.4554832,,,,,,,,,,,,,, -153900,2.895956,3.6064298,,,,,,,,,,,,,, -154000,2.559083,1.6358178,,,,,,,,,,,,,, -154100,3.0851433,3.4700305,,,,,,,,,,,,,, -154142,,,0.8538476228713989,0.5499022006988525,0.7615799903869629,0.9385902881622314,50000.0,0.6353000402450562,1.5707690715789795,10000.0,68523.02074098587,74203.20285248756,68523.02074098587,5664.248873949051,7.753176212310791,0.0 -154200,2.5880167,1.3700824,,,,,,,,,,,,,, -154300,2.5172343,1.4441315,,,,,,,,,,,,,, -154400,2.5989664,1.4203404,,,,,,,,,,,,,, -154500,2.3254004,1.3226671,,,,,,,,,,,,,, -154600,3.2733486,3.568195,,,,,,,,,,,,,, -154700,2.393219,1.7565817,,,,,,,,,,,,,, -154800,2.5691488,1.4314125,,,,,,,,,,,,,, -154900,2.3335366,1.4538109,,,,,,,,,,,,,, -155000,2.467741,1.2451288,,,,,,,,,,,,,, -155085,,,0.8555077910423279,0.5564358234405518,0.759880006313324,0.946969985961914,50000.0,0.6333000063896179,1.566225528717041,10000.0,68943.2842502594,74660.95222091675,68943.2842502594,5701.631700754166,7.805613040924072,0.0 -155100,2.3902988,1.3307264,,,,,,,,,,,,,, -155200,2.5339873,1.5757754,,,,,,,,,,,,,, -155300,2.625083,1.406998,,,,,,,,,,,,,, -155400,2.7625434,3.0035052,,,,,,,,,,,,,, -155500,2.6855114,1.4214928,,,,,,,,,,,,,, -155600,2.2707338,2.6487436,,,,,,,,,,,,,, -155700,2.4027596,2.5238748,,,,,,,,,,,,,, -155800,2.3672786,1.3619307,,,,,,,,,,,,,, -155900,2.655106,1.3757938,,,,,,,,,,,,,, -156000,2.9060035,1.4865736,,,,,,,,,,,,,, -156031,,,0.8609960675239563,0.5209859609603882,0.7628200054168701,0.931879997253418,50000.0,0.6373000144958496,1.5529048442840576,10000.0,69363.431671381,75117.39098858833,69363.431671381,5737.818987846375,7.859495878219604,0.0 -156100,2.6986203,3.0389884,,,,,,,,,,,,,, -156200,2.5955586,1.5344334,,,,,,,,,,,,,, -156300,2.6394978,1.4040918,,,,,,,,,,,,,, -156400,3.0410073,3.5901222,,,,,,,,,,,,,, -156500,2.9657953,1.2052402,,,,,,,,,,,,,, -156600,2.7786386,3.3779335,,,,,,,,,,,,,, -156700,2.3696423,2.1685784,,,,,,,,,,,,,, -156800,2.4899693,1.3719349,,,,,,,,,,,,,, -156900,3.400572,2.938109,,,,,,,,,,,,,, -156974,,,0.85804682970047,0.5311560034751892,0.764959990978241,0.92471182346344,50000.0,0.6371000409126282,1.5472129583358765,10000.0,69783.68109440804,75572.45741295815,69783.68109440804,5772.52631187439,7.919899225234985,0.0 -157000,2.5526197,1.4892658,,,,,,,,,,,,,, -157100,2.677821,1.3281845,,,,,,,,,,,,,, -157200,2.5266159,1.7687337,,,,,,,,,,,,,, -157300,2.613945,1.4911146,,,,,,,,,,,,,, -157400,2.8842037,3.1765666,,,,,,,,,,,,,, -157500,2.501798,2.1107662,,,,,,,,,,,,,, -157600,2.865002,3.3693085,,,,,,,,,,,,,, -157700,2.5860581,1.350479,,,,,,,,,,,,,, -157800,2.6793115,1.327963,,,,,,,,,,,,,, -157900,2.6438048,1.3171102,,,,,,,,,,,,,, -157919,,,0.8621679544448853,0.5378159880638123,0.7645599842071533,0.9374600052833556,50000.0,0.6377000212669373,1.5564005374908447,10000.0,70203.67983937263,76028.39196753502,70203.67983937263,5808.363347053528,7.96866250038147,0.0 -158000,2.4590065,1.6572125,,,,,,,,,,,,,, -158100,2.3797834,1.2607448,,,,,,,,,,,,,, -158200,2.8368332,3.1372306,,,,,,,,,,,,,, -158300,2.5206003,1.4909003,,,,,,,,,,,,,, -158400,2.699775,2.552109,,,,,,,,,,,,,, -158500,2.6054318,1.1939389,,,,,,,,,,,,,, -158600,2.565254,1.3929243,,,,,,,,,,,,,, -158700,2.6463325,1.2399708,,,,,,,,,,,,,, -158800,2.5521452,1.2836447,,,,,,,,,,,,,, -158865,,,0.8666406273841858,0.5006186962127686,0.7650799751281738,0.9160201549530028,50000.0,0.6474000215530396,1.5338735580444336,10000.0,70623.86025309563,76485.72631263733,70623.86025309563,5845.419222831726,8.01696515083313,0.0 -158900,2.7333148,1.4510852,,,,,,,,,,,,,, -159000,2.77675,1.3054719,,,,,,,,,,,,,, -159100,2.5345268,1.4316337,,,,,,,,,,,,,, -159200,2.762196,1.4067944,,,,,,,,,,,,,, -159300,2.5413837,2.018901,,,,,,,,,,,,,, -159400,2.9866111,3.5480516,,,,,,,,,,,,,, -159500,2.6395998,1.3029703,,,,,,,,,,,,,, -159600,2.797623,1.3041701,,,,,,,,,,,,,, -159700,2.426411,1.2409196,,,,,,,,,,,,,, -159800,2.6917384,1.3250755,,,,,,,,,,,,,, -159809,,,0.862597644329071,0.5141065120697021,0.7669199705123901,0.9152435064315796,50000.0,0.6447000503540039,1.5281864404678345,10000.0,71044.06207823753,76942.89191150665,71044.06207823753,5882.281366825104,8.068997144699097,0.0 -159900,3.1465821,3.351359,,,,,,,,,,,,,, -160000,2.631366,1.2840176,,,,,,,,,,,,,, -160100,2.896896,1.324826,,,,,,,,,,,,,, -160200,2.769714,1.1748655,,,,,,,,,,,,,, -160300,3.517735,3.390052,,,,,,,,,,,,,, -160400,2.8104935,2.8695138,,,,,,,,,,,,,, -160500,2.580194,1.3751206,,,,,,,,,,,,,, -160600,2.8424,2.765398,,,,,,,,,,,,,, -160700,2.7505798,1.9063247,,,,,,,,,,,,,, -160754,,,0.8669140338897705,0.5135356187820435,0.7686799764633179,0.9221277236938475,50000.0,0.6433000564575195,1.548334002494812,10000.0,71464.29527115822,77400.4092001915,71464.29527115822,5919.461312532425,8.124317407608032,0.0 -160800,2.5511065,1.3087641,,,,,,,,,,,,,, -160900,3.5240443,3.3413491,,,,,,,,,,,,,, -161000,2.559978,2.0901885,,,,,,,,,,,,,, -161100,2.7969604,1.2687895,,,,,,,,,,,,,, -161200,2.5404925,1.403789,,,,,,,,,,,,,, -161300,2.5422194,1.1660464,,,,,,,,,,,,,, -161400,2.744381,2.0584366,,,,,,,,,,,,,, -161500,2.6809795,1.8371321,,,,,,,,,,,,,, -161600,2.6451683,2.322835,,,,,,,,,,,,,, -161699,,,0.8700976371765137,0.5003441572189331,0.768619954586029,0.916907787322998,50000.0,0.648900032043457,1.5213313102722168,10000.0,71884.56301403046,77857.51973557472,71884.56301403046,5956.193717479706,8.185054779052734,0.0 -161700,2.8180854,1.3824859,,,,,,,,,,,,,, -161800,2.9731798,1.3470671,,,,,,,,,,,,,, -161900,2.647408,1.375661,,,,,,,,,,,,,, -162000,3.0487857,2.1365814,,,,,,,,,,,,,, -162100,3.5735533,2.326109,,,,,,,,,,,,,, -162200,2.7836156,2.8059905,,,,,,,,,,,,,, -162300,3.7900982,3.0705194,,,,,,,,,,,,,, -162400,2.544361,1.3577377,,,,,,,,,,,,,, -162500,3.2560577,1.2655488,,,,,,,,,,,,,, -162600,2.6282723,1.2702818,,,,,,,,,,,,,, -162642,,,0.8746874928474426,0.4664015173912048,0.7695800065994263,0.9045939445495604,50000.0,0.6458000540733337,1.5222994089126587,10000.0,72304.83881092072,78314.5209043026,72304.83881092072,5992.814433336258,8.240005731582642,0.0 -162700,2.9768798,3.2088988,,,,,,,,,,,,,, -162800,2.7798707,1.2211303,,,,,,,,,,,,,, -162900,2.6288135,1.3661424,,,,,,,,,,,,,, -163000,3.4385707,1.8418608,,,,,,,,,,,,,, -163100,2.8527222,1.2612013,,,,,,,,,,,,,, -163200,3.7136881,3.5199726,,,,,,,,,,,,,, -163300,2.5096524,2.1405973,,,,,,,,,,,,,, -163400,2.7281647,2.0519307,,,,,,,,,,,,,, -163500,3.20223,2.044943,,,,,,,,,,,,,, -163585,,,0.8719531297683716,0.4847078919410705,0.7728599905967712,0.8979256749153137,50000.0,0.6484000086784363,1.5181078910827637,10000.0,72724.85019230843,78771.81423521042,72724.85019230843,6029.993057489395,8.293435096740723,0.0 -163600,2.6596532,1.2032079,,,,,,,,,,,,,, -163700,2.6910625,2.4390998,,,,,,,,,,,,,, -163800,2.6337168,1.36172,,,,,,,,,,,,,, -163900,2.4922347,1.1270635,,,,,,,,,,,,,, -164000,2.6825454,1.4714732,,,,,,,,,,,,,, -164100,2.6418798,2.3988512,,,,,,,,,,,,,, -164200,2.9250453,2.6225066,,,,,,,,,,,,,, -164300,2.903306,2.215333,,,,,,,,,,,,,, -164400,2.9044125,1.3706914,,,,,,,,,,,,,, -164500,3.4050422,1.3155508,,,,,,,,,,,,,, -164529,,,0.8699023127555847,0.48884978890419,0.7724399566650391,0.8968047499656677,50000.0,0.6499000191688538,1.5122696161270142,10000.0,73144.75268936157,79228.01746106148,73144.75268936157,6066.192003488541,8.345094919204712,0.0 -164600,2.8072948,1.2049625,,,,,,,,,,,,,, -164700,2.7972302,2.5039763,,,,,,,,,,,,,, -164800,3.0499747,1.9334347,,,,,,,,,,,,,, -164900,3.3839874,3.406487,,,,,,,,,,,,,, -165000,3.1524823,1.320235,,,,,,,,,,,,,, -165100,2.7714655,1.2827463,,,,,,,,,,,,,, -165200,2.6444705,1.5387487,,,,,,,,,,,,,, -165300,2.8469505,1.2015779,,,,,,,,,,,,,, -165400,2.8835924,1.1628695,,,,,,,,,,,,,, -165471,,,0.8733788728713989,0.4675703942775726,0.7732999920845032,0.886921226978302,50000.0,0.650700032711029,1.5046427249908447,10000.0,73564.731341362,79683.63198709488,73564.731341362,6101.721790552139,8.40129041671753,0.0 -165500,2.8368,1.2294266,,,,,,,,,,,,,, -165600,3.0398002,3.177856,,,,,,,,,,,,,, -165700,3.0393574,1.1520702,,,,,,,,,,,,,, -165800,2.756674,1.2860551,,,,,,,,,,,,,, -165900,2.779164,1.2341914,,,,,,,,,,,,,, -166000,2.6990213,2.6434112,,,,,,,,,,,,,, -166100,2.6713169,2.594029,,,,,,,,,,,,,, -166200,2.8193517,2.2374594,,,,,,,,,,,,,, -166300,3.2298946,3.0437212,,,,,,,,,,,,,, -166400,2.6808336,1.1465907,,,,,,,,,,,,,, -166413,,,0.8737499713897705,0.464886873960495,0.7749800086021423,0.881984293460846,50000.0,0.6540000438690186,1.4882982969284058,10000.0,73984.96862435341,80139.76555800438,73984.96862435341,6137.511723995209,8.457269191741943,0.0 -166500,3.3360069,1.3008742,,,,,,,,,,,,,, -166600,3.1389377,3.1542087,,,,,,,,,,,,,, -166700,3.3862588,3.3319716,,,,,,,,,,,,,, -166800,2.662151,1.5180497,,,,,,,,,,,,,, -166900,2.6555696,1.9150171,,,,,,,,,,,,,, -167000,2.6160173,1.2292182,,,,,,,,,,,,,, -167100,2.9050891,1.1886468,,,,,,,,,,,,,, -167200,2.907363,1.35483,,,,,,,,,,,,,, -167300,2.8214176,1.1535507,,,,,,,,,,,,,, -167356,,,0.8771874904632568,0.4618673622608185,0.7760599851608276,0.8795783519744873,50000.0,0.6548000574111938,1.4894704818725586,10000.0,74405.22176742554,80596.62842297554,74405.22176742554,6174.019504547119,8.509846687316895,0.0 -167400,3.132004,1.2451977,,,,,,,,,,,,,, -167500,2.8676527,1.1920289,,,,,,,,,,,,,, -167600,2.8690207,1.2512336,,,,,,,,,,,,,, -167700,2.8683689,1.1023397,,,,,,,,,,,,,, -167800,3.161012,3.0057836,,,,,,,,,,,,,, -167900,2.867747,2.088646,,,,,,,,,,,,,, -168000,2.9640005,1.1653615,,,,,,,,,,,,,, -168100,2.8013027,1.1388763,,,,,,,,,,,,,, -168200,3.0593433,1.1675576,,,,,,,,,,,,,, -168300,4.1702843,3.025137,,,,,,,,,,,,,, -168303,,,0.8756640553474426,0.4663519561290741,0.7746599912643433,0.8834454417228699,50000.0,0.651900053024292,1.4991596937179563,10000.0,74825.44824123383,81052.783826828,74825.44824123383,6209.835191726685,8.572486162185669,0.0 -168400,3.204409,1.2098845,,,,,,,,,,,,,, -168500,2.9328635,2.5364864,,,,,,,,,,,,,, -168600,2.6389544,1.1133747,,,,,,,,,,,,,, -168700,2.8694289,1.144403,,,,,,,,,,,,,, -168800,2.9861155,1.2329918,,,,,,,,,,,,,, -168900,3.4520342,3.0293872,,,,,,,,,,,,,, -169000,3.0753198,1.2053031,,,,,,,,,,,,,, -169100,3.2452137,1.2230638,,,,,,,,,,,,,, -169200,3.189899,1.2457943,,,,,,,,,,,,,, -169247,,,0.8783007860183716,0.4513605535030365,0.7752000093460083,0.8814749121665955,50000.0,0.6554000377655029,1.493224024772644,10000.0,75245.51100564003,81508.45189023018,75245.51100564003,6245.338510274887,8.62363052368164,0.0 -169300,2.9564486,1.1380574,,,,,,,,,,,,,, -169400,3.6605062,3.2655604,,,,,,,,,,,,,, -169500,3.6695457,3.1290693,,,,,,,,,,,,,, -169600,3.0549748,1.2366312,,,,,,,,,,,,,, -169700,3.1164,1.1918963,,,,,,,,,,,,,, -169800,2.8300695,1.1776693,,,,,,,,,,,,,, -169900,2.861055,1.1781362,,,,,,,,,,,,,, -170000,3.0353184,1.140572,,,,,,,,,,,,,, -170100,2.7758102,1.7790796,,,,,,,,,,,,,, -170191,,,0.8774804472923279,0.4559576213359833,0.7778799533843994,0.8701184391975403,50000.0,0.65420001745224,1.4848601818084717,10000.0,75665.74991869926,81963.98689365387,75665.74991869926,6280.529743909836,8.678744554519653,0.0 -170200,3.3684256,2.8396845,,,,,,,,,,,,,, -170300,2.7312207,1.6641902,,,,,,,,,,,,,, -170400,3.333553,1.5660989,,,,,,,,,,,,,, -170500,2.7087424,1.5756652,,,,,,,,,,,,,, -170600,3.1238966,2.8642485,,,,,,,,,,,,,, -170700,3.0008965,1.1860652,,,,,,,,,,,,,, -170800,2.9890063,1.1541697,,,,,,,,,,,,,, -170900,3.261537,1.7618155,,,,,,,,,,,,,, -171000,3.2742841,1.1970779,,,,,,,,,,,,,, -171100,3.0245366,1.2764333,,,,,,,,,,,,,, -171135,,,0.8802539110183716,0.4446616172790527,0.7779600024223328,0.8712742924690247,50000.0,0.6574000120162964,1.4882923364639282,10000.0,76086.01024103165,82421.57224082947,76086.01024103165,6317.746684074402,8.737242698669434,0.0 -171200,3.4201808,3.31242,,,,,,,,,,,,,, -171300,3.119603,1.3231391,,,,,,,,,,,,,, -171400,3.0727098,1.2022381,,,,,,,,,,,,,, -171500,4.7298727,1.9848382,,,,,,,,,,,,,, -171600,3.0112424,1.1277053,,,,,,,,,,,,,, -171700,3.6146464,3.346425,,,,,,,,,,,,,, -171800,2.9791439,1.8991754,,,,,,,,,,,,,, -171900,2.786956,2.6850464,,,,,,,,,,,,,, -172000,3.1757984,2.0001493,,,,,,,,,,,,,, -172081,,,0.8818163871765137,0.4375278651714325,0.7779200077056885,0.8668556213378906,50000.0,0.6577000021934509,1.4779142141342163,10000.0,76506.09647703171,82877.83606362343,76506.09647703171,6353.816187143326,8.7944974899292,0.0 -172100,2.9858687,1.2841046,,,,,,,,,,,,,, -172200,4.742935,2.4788852,,,,,,,,,,,,,, -172300,3.737486,3.1795754,,,,,,,,,,,,,, -172400,3.8493028,3.2263827,,,,,,,,,,,,,, -172500,2.7313144,2.3304436,,,,,,,,,,,,,, -172600,3.056984,1.1830839,,,,,,,,,,,,,, -172700,2.9685502,1.2490524,,,,,,,,,,,,,, -172800,2.7304106,1.1663613,,,,,,,,,,,,,, -172900,3.3409946,2.612884,,,,,,,,,,,,,, -173000,3.095898,1.1547691,,,,,,,,,,,,,, -173026,,,0.8795507550239563,0.4471313059329986,0.7780799865722656,0.8690222501754761,50000.0,0.6589000225067139,1.478680968284607,10000.0,76926.27028298378,83333.81209087372,76926.27028298378,6389.507263660431,8.855574369430542,0.0 -173100,2.5934846,1.8086506,,,,,,,,,,,,,, -173200,4.101094,3.0771348,,,,,,,,,,,,,, -173300,3.4965055,2.8888168,,,,,,,,,,,,,, -173400,3.5180657,3.1403463,,,,,,,,,,,,,, -173500,8.877783,3.0221243,,,,,,,,,,,,,, -173600,3.0032444,1.2902919,,,,,,,,,,,,,, -173700,3.7735245,2.6044962,,,,,,,,,,,,,, -173800,3.973214,3.3062904,,,,,,,,,,,,,, -173900,2.9226058,1.2860569,,,,,,,,,,,,,, -173974,,,0.8820312023162842,0.438140720129013,0.7789199948310852,0.8620977997779846,50000.0,0.6617000102996826,1.4729156494140625,10000.0,77346.20090389252,83788.990837574,77346.20090389252,6424.645256996155,8.914592504501343,0.0 -174000,4.018449,3.295887,,,,,,,,,,,,,, -174100,2.972684,1.4515061,,,,,,,,,,,,,, -174200,3.8179963,3.0487804,,,,,,,,,,,,,, -174300,2.9628153,1.8551546,,,,,,,,,,,,,, -174400,3.4326117,3.0805662,,,,,,,,,,,,,, -174500,3.2990253,2.558228,,,,,,,,,,,,,, -174600,3.0542364,1.1452993,,,,,,,,,,,,,, -174700,2.8884194,1.6609025,,,,,,,,,,,,,, -174800,4.805639,3.2094395,,,,,,,,,,,,,, -174900,3.2088602,1.2373769,,,,,,,,,,,,,, -174916,,,0.8839452862739563,0.4305086731910705,0.7795400023460388,0.8621959090232849,50000.0,0.6585000157356262,1.4759646654129028,10000.0,77766.21065545082,84246.11060118675,77766.21065545082,6461.650452852249,8.96901798248291,0.0 -175000,3.9059255,3.3041248,,,,,,,,,,,,,, -175100,3.2366204,1.1475352,,,,,,,,,,,,,, -175200,2.6286612,1.6647767,,,,,,,,,,,,,, -175300,3.6863399,3.195947,,,,,,,,,,,,,, -175400,3.3473423,2.6525662,,,,,,,,,,,,,, -175500,2.9346821,1.1511773,,,,,,,,,,,,,, -175600,3.399934,1.0770379,,,,,,,,,,,,,, -175700,3.1675863,3.0855927,,,,,,,,,,,,,, -175800,3.093664,2.5310795,,,,,,,,,,,,,, -175861,,,0.8856444954872131,0.4301820993423462,0.7799199819564819,0.8656541705131531,50000.0,0.6604000329971313,1.4743776321411133,10000.0,78186.27035021782,84702.1522629261,78186.27035021782,6497.523934841156,9.027230262756348,0.0 -175900,3.2291632,1.4763191,,,,,,,,,,,,,, -176000,3.29302,1.1922971,,,,,,,,,,,,,, -176100,3.8402464,2.927539,,,,,,,,,,,,,, -176200,3.4446225,2.1003296,,,,,,,,,,,,,, -176300,3.4968126,3.183668,,,,,,,,,,,,,, -176400,3.098999,1.1973462,,,,,,,,,,,,,, -176500,2.913023,1.345978,,,,,,,,,,,,,, -176600,3.2981336,2.7900968,,,,,,,,,,,,,, -176700,3.1658227,1.582994,,,,,,,,,,,,,, -176800,3.084729,1.1849492,,,,,,,,,,,,,, -176806,,,0.88330078125,0.4406342208385467,0.7821799516677856,0.86235511302948,50000.0,0.6619000434875488,1.4712177515029907,10000.0,78606.58297896385,85157.835013628,78606.58297896385,6532.788912773132,9.082460403442385,0.0 -176900,3.5070865,3.236935,,,,,,,,,,,,,, -177000,3.0788248,1.1602801,,,,,,,,,,,,,, -177100,2.8453808,1.5010248,,,,,,,,,,,,,, -177200,3.6698833,3.2534604,,,,,,,,,,,,,, -177300,3.3196008,2.8965132,,,,,,,,,,,,,, -177400,2.9391649,1.4265712,,,,,,,,,,,,,, -177500,3.8655074,3.0766695,,,,,,,,,,,,,, -177600,3.5544624,1.3375916,,,,,,,,,,,,,, -177700,2.916879,1.6513445,,,,,,,,,,,,,, -177751,,,0.88427734375,0.4303996860980987,0.780959963798523,0.8608471155166626,50000.0,0.6593000292778015,1.473184585571289,10000.0,79026.84996557236,85614.04051232338,79026.84996557236,6568.6220116615295,9.137590885162354,0.0 -177800,3.223006,1.111901,,,,,,,,,,,,,, -177900,3.557153,3.0112576,,,,,,,,,,,,,, -178000,3.044874,2.4128575,,,,,,,,,,,,,, -178100,3.1751301,1.2034665,,,,,,,,,,,,,, -178200,3.3652909,1.1934476,,,,,,,,,,,,,, -178300,3.0657816,1.0194515,,,,,,,,,,,,,, -178400,5.5303125,2.852246,,,,,,,,,,,,,, -178500,3.2317994,1.2938519,,,,,,,,,,,,,, -178600,3.4787831,2.9373665,,,,,,,,,,,,,, -178697,,,0.8859374523162842,0.4259621798992157,0.7819199562072754,0.8566552400588989,50000.0,0.6617000102996826,1.4699724912643433,10000.0,79446.80137014389,86069.71082782745,79446.80137014389,6604.234881401062,9.192800283432009,0.0 -178700,3.2282228,1.109058,,,,,,,,,,,,,, -178800,3.2294297,2.4089499,,,,,,,,,,,,,, -178900,3.693927,3.0867276,,,,,,,,,,,,,, -179000,3.1932638,1.4454467,,,,,,,,,,,,,, -179100,3.7406476,2.9263937,,,,,,,,,,,,,, -179200,2.8708894,1.6339481,,,,,,,,,,,,,, -179300,3.1016278,2.6517267,,,,,,,,,,,,,, -179400,5.3928037,3.1644185,,,,,,,,,,,,,, -179500,4.3967247,3.2905564,,,,,,,,,,,,,, -179600,3.247681,1.1977992,,,,,,,,,,,,,, -179645,,,0.8874218463897705,0.4268486797809601,0.7813599705696106,0.8589540123939514,50000.0,0.6598000526428223,1.470707654953003,10000.0,79866.9059574604,86527.36637544632,79866.9059574604,6641.677387714386,9.251446962356567,0.0 -179700,2.9506254,1.1747817,,,,,,,,,,,,,, -179800,3.0068047,1.1393344,,,,,,,,,,,,,, -179900,3.148516,1.53762,,,,,,,,,,,,,, -180000,2.9424777,1.1887141,,,,,,,,,,,,,, -180100,3.3078122,2.6989646,,,,,,,,,,,,,, -180200,5.1740146,2.632908,,,,,,,,,,,,,, -180300,2.9957848,1.293106,,,,,,,,,,,,,, -180400,4.7129197,2.7573857,,,,,,,,,,,,,, -180500,3.663248,2.758292,,,,,,,,,,,,,, -180589,,,0.8857421875,0.4221574366092682,0.7830399870872498,0.8535976409912109,50000.0,0.6617000102996826,1.4652323722839355,10000.0,80287.17362117767,86982.06272292137,80287.17362117767,6675.996264457703,9.310755252838137,0.0 -180600,3.7322533,3.2953975,,,,,,,,,,,,,, -180700,3.6548076,3.130915,,,,,,,,,,,,,, -180800,3.836746,1.1456254,,,,,,,,,,,,,, -180900,2.9095874,1.1003623,,,,,,,,,,,,,, -181000,3.152706,1.6332641,,,,,,,,,,,,,, -181100,3.1424403,1.2997662,,,,,,,,,,,,,, -181200,2.9643211,1.2228587,,,,,,,,,,,,,, -181300,3.2105517,1.3724642,,,,,,,,,,,,,, -181400,3.210093,2.7583117,,,,,,,,,,,,,, -181500,2.9816358,1.8140138,,,,,,,,,,,,,, -181530,,,0.8855078220367432,0.4229528307914734,0.782539963722229,0.8535990118980408,50000.0,0.6627000570297241,1.4637311697006226,10000.0,80707.40728902817,87437.47354197502,80707.40728902817,6711.070337772369,9.363951683044434,0.0 -181600,3.154494,1.2696749,,,,,,,,,,,,,, -181700,3.289249,1.8248048,,,,,,,,,,,,,, -181800,3.6938868,3.1752791,,,,,,,,,,,,,, -181900,3.1663404,1.1348532,,,,,,,,,,,,,, -182000,3.317447,1.2846314,,,,,,,,,,,,,, -182100,3.023099,1.1654055,,,,,,,,,,,,,, -182200,4.117461,2.2272558,,,,,,,,,,,,,, -182300,3.231836,1.0984514,,,,,,,,,,,,,, -182400,3.2233272,1.2027959,,,,,,,,,,,,,, -182476,,,0.8879101276397705,0.4179525971412658,0.7822999954223633,0.8548969030380249,50000.0,0.6619000434875488,1.4650492668151855,10000.0,81127.42118930817,87893.53619122505,81127.42118930817,6746.992145776749,9.44046401977539,0.0 -182500,3.0531073,1.0544353,,,,,,,,,,,,,, -182600,3.0692737,1.7209508,,,,,,,,,,,,,, -182700,3.0009074,1.5144122,,,,,,,,,,,,,, -182800,2.9774544,1.1091328,,,,,,,,,,,,,, -182900,3.5200367,2.797183,,,,,,,,,,,,,, -183000,3.1394327,2.2562275,,,,,,,,,,,,,, -183100,3.4263833,2.9552097,,,,,,,,,,,,,, -183200,3.335808,2.9266,,,,,,,,,,,,,, -183300,3.0246325,1.2093074,,,,,,,,,,,,,, -183400,3.99841,2.6913137,,,,,,,,,,,,,, -183422,,,0.8875195384025574,0.4133367538452148,0.7827999591827393,0.8514824509620667,50000.0,0.6612000465393066,1.4615788459777832,10000.0,81547.52734351158,88348.74827170372,81547.52734351158,6781.985077857971,9.50280237197876,0.0 -183500,3.5751643,1.702612,,,,,,,,,,,,,, -183600,3.3259308,1.294525,,,,,,,,,,,,,, -183700,2.9242206,1.121933,,,,,,,,,,,,,, -183800,3.3623972,1.7513163,,,,,,,,,,,,,, -183900,3.1596882,1.1580365,,,,,,,,,,,,,, -184000,4.8933177,2.8707676,,,,,,,,,,,,,, -184100,3.3841028,2.137616,,,,,,,,,,,,,, -184200,3.1055892,2.4028769,,,,,,,,,,,,,, -184300,3.260339,1.9594587,,,,,,,,,,,,,, -184369,,,0.8892773389816284,0.4150594472885132,0.7831400036811829,0.8529243469238281,50000.0,0.6614000201225281,1.4628207683563232,10000.0,81967.57227182388,88806.49009799957,81967.57227182388,6819.577335119247,9.556721925735474,0.0 -184400,2.9773085,1.9923605,,,,,,,,,,,,,, -184500,3.028996,1.078429,,,,,,,,,,,,,, -184600,3.6772897,3.031733,,,,,,,,,,,,,, -184700,3.070188,1.1810707,,,,,,,,,,,,,, -184800,3.854145,3.1897442,,,,,,,,,,,,,, -184900,3.50079,1.3014573,,,,,,,,,,,,,, -185000,3.1425962,1.1396648,,,,,,,,,,,,,, -185100,3.3321676,2.579003,,,,,,,,,,,,,, -185200,3.4481142,2.804445,,,,,,,,,,,,,, -185300,3.307064,1.0183156,,,,,,,,,,,,,, -185318,,,0.8887499570846558,0.4133415520191192,0.7833399772644043,0.8519114851951599,50000.0,0.6621000170707703,1.4614572525024414,10000.0,82387.84855127335,89263.10540890694,82387.84855127335,6855.812837123871,9.610158681869509,0.0 -185400,3.2526662,1.3871127,,,,,,,,,,,,,, -185500,3.5616612,2.4164205,,,,,,,,,,,,,, -185600,3.1078732,1.1653008,,,,,,,,,,,,,, -185700,3.0409303,1.1084292,,,,,,,,,,,,,, -185800,2.8968167,2.5282333,,,,,,,,,,,,,, -185900,3.1594467,1.730839,,,,,,,,,,,,,, -186000,2.8252795,1.0769224,,,,,,,,,,,,,, -186100,3.6522167,1.1464934,,,,,,,,,,,,,, -186200,3.1061282,1.06866,,,,,,,,,,,,,, -186263,,,0.88880854845047,0.4150834083557129,0.7832399606704712,0.8526034355163574,50000.0,0.6614000201225281,1.4626858234405518,10000.0,82808.1642510891,89719.92125058174,82808.1642510891,6892.20347237587,9.669806718826294,0.0 -186300,3.8343177,3.106732,,,,,,,,,,,,,, -186400,3.1517246,1.09948,,,,,,,,,,,,,, -186500,3.381491,3.06934,,,,,,,,,,,,,, -186600,3.4413657,3.0320597,,,,,,,,,,,,,, -186700,4.3008924,3.1645436,,,,,,,,,,,,,, -186800,3.2682166,2.8051627,,,,,,,,,,,,,, -186900,3.1141667,1.1229899,,,,,,,,,,,,,, -187000,3.697331,3.1739495,,,,,,,,,,,,,, -187100,3.6748626,3.1612754,,,,,,,,,,,,,, -187200,3.3476613,2.4788318,,,,,,,,,,,,,, -187209,,,0.8882226347923279,0.4158450961112976,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,83228.10613918304,90176.43953990936,83228.10613918304,6928.669387102127,9.730186939239502,0.0 -187300,2.767775,1.3695625,,,,,,,,,,,,,, -187400,3.2421145,1.0585251,,,,,,,,,,,,,, -187500,3.258183,2.3143063,,,,,,,,,,,,,, -187600,3.289059,1.3703367,,,,,,,,,,,,,, -187700,4.301563,3.2638664,,,,,,,,,,,,,, -187800,3.0786016,2.3648622,,,,,,,,,,,,,, -187900,3.953461,3.0718806,,,,,,,,,,,,,, -188000,2.955712,1.3146822,,,,,,,,,,,,,, -188100,3.2186754,2.92591,,,,,,,,,,,,,, -188153,,,0.8866796493530273,0.4240669012069702,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,83648.2685251236,90632.70875310898,83648.2685251236,6964.671956300736,9.78453016281128,0.0 -188200,3.420374,2.431161,,,,,,,,,,,,,, -188300,3.9746144,3.2129054,,,,,,,,,,,,,, -188400,3.1028385,1.0292586,,,,,,,,,,,,,, -188500,4.230981,3.2654493,,,,,,,,,,,,,, -188600,3.179162,1.3795946,,,,,,,,,,,,,, -188700,2.9871159,1.6863629,,,,,,,,,,,,,, -188800,3.0812516,1.2057968,,,,,,,,,,,,,, -188900,3.3141053,2.5148935,,,,,,,,,,,,,, -189000,3.1230335,1.1554561,,,,,,,,,,,,,, -189099,,,0.8877343535423279,0.4180521368980407,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,84068.47924923897,91090.34984946252,84068.47924923897,7001.998011350632,9.83876132965088,0.0 -189100,3.315009,2.1154456,,,,,,,,,,,,,, -189200,4.2608533,3.258126,,,,,,,,,,,,,, -189300,3.0547688,1.0819355,,,,,,,,,,,,,, -189400,2.9556987,1.1847968,,,,,,,,,,,,,, -189500,4.8663373,3.2243984,,,,,,,,,,,,,, -189600,3.1457713,1.1109096,,,,,,,,,,,,,, -189700,2.87101,1.1531942,,,,,,,,,,,,,, -189800,2.8216588,1.0770875,,,,,,,,,,,,,, -189900,3.1506891,1.1471655,,,,,,,,,,,,,, -190000,3.0865862,1.0809189,,,,,,,,,,,,,, -190046,,,0.8861523270606995,0.4211257696151733,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,84488.68370914459,91544.91623663902,84488.68370914459,7036.250737190247,9.89790678024292,0.0 -190100,3.097555,1.737731,,,,,,,,,,,,,, -190200,3.8106449,3.2991223,,,,,,,,,,,,,, -190300,3.7964582,1.496434,,,,,,,,,,,,,, -190400,2.931941,2.0649545,,,,,,,,,,,,,, -190500,2.8355536,1.0685225,,,,,,,,,,,,,, -190600,3.175853,1.1108913,,,,,,,,,,,,,, -190700,3.0432367,1.1774735,,,,,,,,,,,,,, -190800,3.2256293,1.335712,,,,,,,,,,,,,, -190900,2.967745,1.9050614,,,,,,,,,,,,,, -190988,,,0.8909765481948853,0.413134753704071,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,84909.03533935547,92001.10895776749,84909.03533935547,7071.985706090927,9.954696655273438,0.0 -191000,3.0726058,1.1433748,,,,,,,,,,,,,, -191100,2.9164186,1.5110105,,,,,,,,,,,,,, -191200,3.2015123,1.9398273,,,,,,,,,,,,,, -191300,3.1409683,1.818298,,,,,,,,,,,,,, -191400,3.1598992,2.2013586,,,,,,,,,,,,,, -191500,3.3057165,1.6600881,,,,,,,,,,,,,, -191600,3.0644627,1.1298912,,,,,,,,,,,,,, -191700,3.4460616,2.7263846,,,,,,,,,,,,,, -191800,2.9433682,1.1165106,,,,,,,,,,,,,, -191900,3.1761537,1.2863333,,,,,,,,,,,,,, -191933,,,0.8879101276397705,0.4145359992980957,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,85329.32217812538,92457.35536766052,85329.32217812538,7107.835638523102,10.013721942901611,0.0 -192000,3.1778836,1.2259861,,,,,,,,,,,,,, -192100,3.050196,2.1453435,,,,,,,,,,,,,, -192200,3.1485837,1.127696,,,,,,,,,,,,,, -192300,3.4540899,1.1906158,,,,,,,,,,,,,, -192400,3.373222,1.1794045,,,,,,,,,,,,,, -192500,4.4370937,3.2172422,,,,,,,,,,,,,, -192600,3.4772763,1.1880026,,,,,,,,,,,,,, -192700,3.2139852,2.6357367,,,,,,,,,,,,,, -192800,3.2749555,2.030466,,,,,,,,,,,,,, -192879,,,0.8848242163658142,0.4299986958503723,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,85749.28629517555,92911.35882663728,85749.28629517555,7141.772822141647,10.06668186187744,0.0 -192900,3.4476583,1.2248089,,,,,,,,,,,,,, -193000,3.6129904,1.023806,,,,,,,,,,,,,, -193100,3.461767,2.892646,,,,,,,,,,,,,, -193200,3.1412501,1.0624034,,,,,,,,,,,,,, -193300,3.6068337,1.4109765,,,,,,,,,,,,,, -193400,2.9228146,1.6929775,,,,,,,,,,,,,, -193500,3.1070762,1.560667,,,,,,,,,,,,,, -193600,3.4814172,3.1572382,,,,,,,,,,,,,, -193700,2.8962753,1.1207652,,,,,,,,,,,,,, -193800,2.9346821,1.142456,,,,,,,,,,,,,, -193825,,,0.8856444954872131,0.4275977611541748,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,86169.27835011482,93366.63113760948,86169.27835011482,7176.942252874374,10.127224683761597,0.0 -193900,2.97257,1.4169935,,,,,,,,,,,,,, -194000,2.9593446,1.4393706,,,,,,,,,,,,,, -194100,3.2485673,2.4935188,,,,,,,,,,,,,, -194200,3.0811527,1.1062875,,,,,,,,,,,,,, -194300,3.5000327,1.1336596,,,,,,,,,,,,,, -194400,2.9254918,2.060622,,,,,,,,,,,,,, -194500,3.6637826,2.8710103,,,,,,,,,,,,,, -194600,3.94688,3.13779,,,,,,,,,,,,,, -194700,2.9150503,1.1416259,,,,,,,,,,,,,, -194769,,,0.8881444931030273,0.4168613851070404,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,86589.32567048073,93823.04841327669,86589.32567048073,7213.204705715179,10.184736251831056,0.0 -194800,3.2138867,1.127633,,,,,,,,,,,,,, -194900,3.3187804,1.9621788,,,,,,,,,,,,,, -195000,3.5557375,3.0872552,,,,,,,,,,,,,, -195100,3.084283,1.0673407,,,,,,,,,,,,,, -195200,3.1601145,1.2016134,,,,,,,,,,,,,, -195300,3.464091,2.6367853,,,,,,,,,,,,,, -195400,3.2621515,2.6947773,,,,,,,,,,,,,, -195500,3.768977,3.2003698,,,,,,,,,,,,,, -195600,3.9837744,3.0357645,,,,,,,,,,,,,, -195700,3.2273762,1.2836273,,,,,,,,,,,,,, -195714,,,0.88720703125,0.4198476374149322,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,87009.27286624908,94278.746717453,87009.27286624908,7248.850476980209,10.240220785140991,0.0 -195800,3.1271672,1.1062243,,,,,,,,,,,,,, -195900,3.384088,2.868746,,,,,,,,,,,,,, -196000,3.0778704,1.151685,,,,,,,,,,,,,, -196100,3.5880027,3.056409,,,,,,,,,,,,,, -196200,3.5436323,1.1413943,,,,,,,,,,,,,, -196300,5.821463,3.1247876,,,,,,,,,,,,,, -196400,3.0057437,1.2323964,,,,,,,,,,,,,, -196500,2.8875957,0.98569506,,,,,,,,,,,,,, -196600,3.1575623,2.5979478,,,,,,,,,,,,,, -196662,,,0.8884570002555847,0.4169119894504547,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,87429.595079422,94735.3364572525,87429.595079422,7285.001372337341,10.306553363800049,0.0 -196700,3.0759335,1.4748942,,,,,,,,,,,,,, -196800,3.1550047,1.2670115,,,,,,,,,,,,,, -196900,3.1014676,1.3606404,,,,,,,,,,,,,, -197000,3.289711,1.080053,,,,,,,,,,,,,, -197100,3.1887782,1.0761737,,,,,,,,,,,,,, -197200,3.0532057,1.392247,,,,,,,,,,,,,, -197300,4.5718007,2.3828986,,,,,,,,,,,,,, -197400,3.3205974,1.4972272,,,,,,,,,,,,,, -197500,5.1237206,2.5606966,,,,,,,,,,,,,, -197600,3.1869326,2.3171697,,,,,,,,,,,,,, -197609,,,0.8848632574081421,0.426829069852829,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,87849.86155200005,95189.90076732635,87849.86155200005,7319.191005468368,10.364394903182983,0.0 -197700,3.163869,1.2407943,,,,,,,,,,,,,, -197800,3.1592133,1.712285,,,,,,,,,,,,,, -197900,3.2363958,1.1777712,,,,,,,,,,,,,, -198000,4.046633,3.1572769,,,,,,,,,,,,,, -198100,3.0366983,1.3127333,,,,,,,,,,,,,, -198200,4.1897936,3.2650592,,,,,,,,,,,,,, -198300,3.223914,1.1421571,,,,,,,,,,,,,, -198400,3.6904843,3.142028,,,,,,,,,,,,,, -198500,4.1641498,3.2245681,,,,,,,,,,,,,, -198555,,,0.8874609470367432,0.4123820066452026,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,88270.01970601082,95646.08177280426,88270.01970601082,7355.1042511463165,10.42365837097168,0.0 -198600,3.140869,1.088561,,,,,,,,,,,,,, -198700,3.1856487,1.1649561,,,,,,,,,,,,,, -198800,3.1711087,1.190021,,,,,,,,,,,,,, -198900,3.002004,2.2259893,,,,,,,,,,,,,, -199000,3.236164,1.0196961,,,,,,,,,,,,,, -199100,3.0638661,1.1161444,,,,,,,,,,,,,, -199200,3.2392724,2.6719031,,,,,,,,,,,,,, -199300,3.304954,1.9281242,,,,,,,,,,,,,, -199400,3.2522895,1.45718,,,,,,,,,,,,,, -199500,4.107801,3.2708595,,,,,,,,,,,,,, -199501,,,0.8875781297683716,0.4194918870925903,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,88690.64194583893,96102.44577240944,88690.64194583893,7390.73459148407,10.484740495681764,0.0 -199600,3.861541,3.0186775,,,,,,,,,,,,,, -199700,3.2946072,1.1393659,,,,,,,,,,,,,, -199800,2.9457393,1.173884,,,,,,,,,,,,,, -199900,3.7293239,2.4006627,,,,,,,,,,,,,, -200000,2.9260678,1.9591168,,,,,,,,,,,,,, -200100,3.2459867,1.1418378,,,,,,,,,,,,,, -200200,3.1579876,1.6086872,,,,,,,,,,,,,, -200300,2.9127128,1.3330601,,,,,,,,,,,,,, -200400,4.084962,1.3018388,,,,,,,,,,,,,, -200445,,,0.8860741853713989,0.4245986938476562,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,89110.59979486465,96559.44136738776,89110.59979486465,7427.6580538749695,10.54875898361206,0.0 -200500,2.9067724,1.0805038,,,,,,,,,,,,,, -200600,2.9863102,1.0846492,,,,,,,,,,,,,, -200700,2.9792109,1.1663513,,,,,,,,,,,,,, -200800,3.2106009,1.1385891,,,,,,,,,,,,,, -200900,2.9783049,1.8710403,,,,,,,,,,,,,, -201000,3.1028752,1.1204453,,,,,,,,,,,,,, -201100,4.2992,3.192224,,,,,,,,,,,,,, -201200,3.0898187,1.1218511,,,,,,,,,,,,,, -201300,2.8094885,1.7591512,,,,,,,,,,,,,, -201391,,,0.8870702981948853,0.4246867001056671,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,89530.50890040398,97014.27209377287,89530.50890040398,7462.471394062042,10.607412338256836,0.0 -201400,3.1089664,2.5928092,,,,,,,,,,,,,, -201500,3.2468562,1.1056043,,,,,,,,,,,,,, -201600,3.2010803,1.399528,,,,,,,,,,,,,, -201700,2.8107302,1.3041116,,,,,,,,,,,,,, -201800,3.007803,1.2560856,,,,,,,,,,,,,, -201900,3.178119,1.4731902,,,,,,,,,,,,,, -202000,4.977898,3.0250006,,,,,,,,,,,,,, -202100,2.9163244,1.8278183,,,,,,,,,,,,,, -202200,3.7742183,1.0655935,,,,,,,,,,,,,, -202300,3.088548,1.166843,,,,,,,,,,,,,, -202334,,,0.8855664134025574,0.427189439535141,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,89950.66359424591,97471.39986658096,89950.66359424591,7499.333806037903,10.667636156082152,0.0 -202400,3.006178,1.611476,,,,,,,,,,,,,, -202500,3.1087017,1.373085,,,,,,,,,,,,,, -202600,2.9870615,1.0405247,,,,,,,,,,,,,, -202700,3.3331969,1.0449665,,,,,,,,,,,,,, -202800,3.0243413,1.1684723,,,,,,,,,,,,,, -202900,3.15687,1.583043,,,,,,,,,,,,,, -203000,3.6556911,2.948007,,,,,,,,,,,,,, -203100,3.1648421,1.1954483,,,,,,,,,,,,,, -203200,5.14321,3.1448004,,,,,,,,,,,,,, -203278,,,0.8870702981948853,0.4194826483726501,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,90370.86471438408,97927.07135868073,90370.86471438408,7534.678661584854,10.743087768554688,0.0 -203300,3.0898428,2.454603,,,,,,,,,,,,,, -203400,4.6080265,3.255874,,,,,,,,,,,,,, -203500,3.625738,2.1680887,,,,,,,,,,,,,, -203600,3.1636043,1.1013834,,,,,,,,,,,,,, -203700,3.3256114,1.6322484,,,,,,,,,,,,,, -203800,3.1428988,1.2205743,,,,,,,,,,,,,, -203900,3.3012102,1.0711659,,,,,,,,,,,,,, -204000,3.3384466,1.1624793,,,,,,,,,,,,,, -204100,2.9622304,1.2531668,,,,,,,,,,,,,, -204200,3.1799428,1.6402684,,,,,,,,,,,,,, -204221,,,0.8884961009025574,0.4168415665626526,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,90790.99326586723,98383.30153632164,90790.99326586723,7570.669605970383,10.802512645721436,0.0 -204300,3.5770762,2.0923538,,,,,,,,,,,,,, -204400,3.0898252,1.1146249,,,,,,,,,,,,,, -204500,2.757451,1.182162,,,,,,,,,,,,,, -204600,3.203219,1.6104654,,,,,,,,,,,,,, -204700,3.5960116,2.7401564,,,,,,,,,,,,,, -204800,3.5241492,1.1993222,,,,,,,,,,,,,, -204900,3.5834985,1.1924698,,,,,,,,,,,,,, -205000,3.6981463,3.2713313,,,,,,,,,,,,,, -205100,3.2405834,1.0871235,,,,,,,,,,,,,, -205166,,,0.8862499594688416,0.4210224449634552,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,91211.18811535837,98839.98974704742,91211.18811535837,7607.05656671524,10.85942006111145,0.0 -205200,3.3014195,1.995902,,,,,,,,,,,,,, -205300,3.0563953,1.1723324,,,,,,,,,,,,,, -205400,3.308705,1.2611817,,,,,,,,,,,,,, -205500,3.0255153,1.2390498,,,,,,,,,,,,,, -205600,3.0015333,1.2995498,,,,,,,,,,,,,, -205700,4.1647077,3.1410604,,,,,,,,,,,,,, -205800,2.8771842,1.872602,,,,,,,,,,,,,, -205900,3.2993395,1.2011924,,,,,,,,,,,,,, -206000,3.2053032,2.071231,,,,,,,,,,,,,, -206100,3.141181,2.447608,,,,,,,,,,,,,, -206112,,,0.8875976204872131,0.4213772118091583,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,91631.25634765624,99297.76484441756,91631.25634765624,7644.656408786774,10.91633415222168,0.0 -206200,4.183941,3.2636313,,,,,,,,,,,,,, -206300,2.9705253,1.0463104,,,,,,,,,,,,,, -206400,3.7414105,3.1010532,,,,,,,,,,,,,, -206500,3.1253877,2.0333962,,,,,,,,,,,,,, -206600,5.0705914,3.293325,,,,,,,,,,,,,, -206700,3.1265736,1.4346025,,,,,,,,,,,,,, -206800,3.1418328,1.1184974,,,,,,,,,,,,,, -206900,3.151843,1.1591945,,,,,,,,,,,,,, -207000,3.1077247,1.5858687,,,,,,,,,,,,,, -207056,,,0.8889257907867432,0.4079552888870239,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,92051.31888103484,99752.18914794922,92051.31888103484,7678.900659561157,10.983742237091064,0.0 -207100,3.0978706,1.300824,,,,,,,,,,,,,, -207200,3.8286445,3.127244,,,,,,,,,,,,,, -207300,3.3084967,1.6010785,,,,,,,,,,,,,, -207400,3.0173848,1.0483252,,,,,,,,,,,,,, -207500,3.0456405,1.2473923,,,,,,,,,,,,,, -207600,3.061361,1.3004284,,,,,,,,,,,,,, -207700,3.599472,2.615418,,,,,,,,,,,,,, -207800,5.0968213,2.2000136,,,,,,,,,,,,,, -207900,3.5492072,2.0043433,,,,,,,,,,,,,, -208000,3.2136462,1.1449726,,,,,,,,,,,,,, -208001,,,0.8864452838897705,0.425664484500885,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,92471.98997926712,100208.54979085922,92471.98997926712,7714.484362125397,11.040274143218994,0.0 -208100,2.9213212,1.708254,,,,,,,,,,,,,, -208200,3.384987,1.7265514,,,,,,,,,,,,,, -208300,4.184891,2.7722898,,,,,,,,,,,,,, -208400,3.5890467,1.1710935,,,,,,,,,,,,,, -208500,3.2037008,2.205194,,,,,,,,,,,,,, -208600,2.977153,1.1234945,,,,,,,,,,,,,, -208700,3.0304518,1.2029378,,,,,,,,,,,,,, -208800,3.2809029,2.4545188,,,,,,,,,,,,,, -208900,3.2445621,2.371015,,,,,,,,,,,,,, -208947,,,0.8887109160423279,0.4154536426067352,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,92892.02937984468,100664.24516439438,92892.02937984468,7750.031948566437,11.0984628200531,0.0 -209000,3.037879,1.1655269,,,,,,,,,,,,,, -209100,3.3298607,1.0908903,,,,,,,,,,,,,, -209200,2.9819062,1.3737,,,,,,,,,,,,,, -209300,7.439752,2.1585386,,,,,,,,,,,,,, -209400,3.2398152,1.1737657,,,,,,,,,,,,,, -209500,3.4081445,3.0648994,,,,,,,,,,,,,, -209600,10.075975,1.2653344,,,,,,,,,,,,,, -209700,3.1034224,1.0834541,,,,,,,,,,,,,, -209800,3.2470083,2.699938,,,,,,,,,,,,,, -209894,,,0.8918554782867432,0.4047467410564422,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,93312.06176161766,101120.9896683693,93312.06176161766,7786.632411718368,11.160149812698364,0.0 -209900,3.1190317,1.2727991,,,,,,,,,,,,,, -210000,2.8333511,1.0729383,,,,,,,,,,,,,, -210100,3.2820678,1.0991541,,,,,,,,,,,,,, -210200,3.7179794,1.5053926,,,,,,,,,,,,,, -210300,3.1388521,1.0514739,,,,,,,,,,,,,, -210400,4.3929467,2.8508615,,,,,,,,,,,,,, -210500,3.4484906,2.7371194,,,,,,,,,,,,,, -210600,3.2631845,2.55463,,,,,,,,,,,,,, -210700,2.9670527,1.123212,,,,,,,,,,,,,, -210800,4.0374975,3.2026825,,,,,,,,,,,,,, -210839,,,0.8872460722923279,0.4177339971065521,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,93732.22330355644,101577.8395473957,93732.22330355644,7823.209892272949,11.221330165863035,0.0 -210900,3.114298,1.935615,,,,,,,,,,,,,, -211000,3.099549,1.0962794,,,,,,,,,,,,,, -211100,3.2834454,2.4840446,,,,,,,,,,,,,, -211200,3.2895372,2.5164685,,,,,,,,,,,,,, -211300,3.5792022,1.8805066,,,,,,,,,,,,,, -211400,3.0387132,1.7765665,,,,,,,,,,,,,, -211500,3.1753402,2.2557745,,,,,,,,,,,,,, -211600,3.129869,1.1726614,,,,,,,,,,,,,, -211700,3.3048058,1.3888404,,,,,,,,,,,,,, -211784,,,0.8871093392372131,0.4195844233036041,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,94152.4714114666,102033.10613369942,94152.4714114666,7858.10614824295,11.293134450912476,0.0 -211800,2.998829,1.5385991,,,,,,,,,,,,,, -211900,4.6945744,3.218944,,,,,,,,,,,,,, -212000,3.3605843,1.1457675,,,,,,,,,,,,,, -212100,3.0342867,1.1096628,,,,,,,,,,,,,, -212200,3.825742,3.0435958,,,,,,,,,,,,,, -212300,3.5310767,1.2104361,,,,,,,,,,,,,, -212400,3.0346906,1.2494677,,,,,,,,,,,,,, -212500,3.0907366,1.1427717,,,,,,,,,,,,,, -212600,3.02338,1.2514874,,,,,,,,,,,,,, -212700,3.7839355,2.8331878,,,,,,,,,,,,,, -212729,,,0.8864257335662842,0.4255499839782715,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,94572.54891109468,102488.86503720284,94572.54891109468,7893.675740480423,11.354552745819092,0.0 -212800,3.3182478,2.804149,,,,,,,,,,,,,, -212900,3.1638002,1.1938586,,,,,,,,,,,,,, -213000,2.9796972,1.7682886,,,,,,,,,,,,,, -213100,3.1921625,1.213883,,,,,,,,,,,,,, -213200,3.1506655,2.6285765,,,,,,,,,,,,,, -213300,3.3708143,2.9377813,,,,,,,,,,,,,, -213400,3.2199078,1.4985479,,,,,,,,,,,,,, -213500,4.5889735,3.1419735,,,,,,,,,,,,,, -213600,3.8900366,3.3156247,,,,,,,,,,,,,, -213672,,,0.8885546922683716,0.4147121608257293,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,94992.61476254465,102945.8255429268,94992.61476254465,7930.459398984909,11.415436506271362,0.0 -213700,2.8245265,1.4409225,,,,,,,,,,,,,, -213800,2.9909515,1.187535,,,,,,,,,,,,,, -213900,4.8242292,1.1597192,,,,,,,,,,,,,, -214000,3.0960245,1.4955043,,,,,,,,,,,,,, -214100,3.081223,2.4230058,,,,,,,,,,,,,, -214200,3.0198703,1.5546756,,,,,,,,,,,,,, -214300,3.7180805,2.1147225,,,,,,,,,,,,,, -214400,3.3950322,2.0590758,,,,,,,,,,,,,, -214500,3.184654,1.1015762,,,,,,,,,,,,,, -214600,3.6607451,3.1129277,,,,,,,,,,,,,, -214616,,,0.8885155916213989,0.419483482837677,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,95412.57948756218,103402.09250640868,95412.57948756218,7966.65424156189,11.472744464874268,0.0 -214700,3.2383482,1.0646907,,,,,,,,,,,,,, -214800,2.8405356,1.0676056,,,,,,,,,,,,,, -214900,3.0719697,1.198907,,,,,,,,,,,,,, -215000,2.8904846,1.0604805,,,,,,,,,,,,,, -215100,3.5160482,2.0113115,,,,,,,,,,,,,, -215200,4.162991,1.113056,,,,,,,,,,,,,, -215300,3.099248,1.1529354,,,,,,,,,,,,,, -215400,3.1129253,1.0766644,,,,,,,,,,,,,, -215500,3.367438,1.1277503,,,,,,,,,,,,,, -215562,,,0.8874413967132568,0.417631447315216,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,95832.54670524596,103858.20341420174,95832.54670524596,8002.676563978195,11.544119596481323,0.0 -215600,3.0286467,1.227873,,,,,,,,,,,,,, -215700,3.1895354,1.0923479,,,,,,,,,,,,,, -215800,3.119265,1.3869729,,,,,,,,,,,,,, -215900,2.881513,2.201712,,,,,,,,,,,,,, -216000,3.0077968,1.053343,,,,,,,,,,,,,, -216100,2.9556088,1.4399443,,,,,,,,,,,,,, -216200,3.16696,1.8899562,,,,,,,,,,,,,, -216300,3.0576644,1.1239073,,,,,,,,,,,,,, -216400,3.264332,1.0802381,,,,,,,,,,,,,, -216482,,,0.8871679306030273,0.4172095358371734,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,96252.98335146904,104306.19088602066,96252.98335146904,8030.116549730301,11.604514122009276,0.0 -216500,3.4946125,1.8256255,,,,,,,,,,,,,, -216600,4.016077,2.668093,,,,,,,,,,,,,, -216700,3.5816274,2.875002,,,,,,,,,,,,,, -216800,3.5656586,3.0635264,,,,,,,,,,,,,, -216900,3.0932412,1.1965036,,,,,,,,,,,,,, -217000,2.9969475,2.3271072,,,,,,,,,,,,,, -217100,3.3002117,2.5912845,,,,,,,,,,,,,, -217200,3.4353037,2.6284096,,,,,,,,,,,,,, -217300,3.2144322,2.3237548,,,,,,,,,,,,,, -217388,,,0.8858984112739563,0.4288398623466491,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,96673.04321050644,104765.90989851952,96673.04321050644,8069.562195062637,11.770461082458496,0.0 -217400,3.2692602,1.7784089,,,,,,,,,,,,,, -217500,2.947728,1.5110512,,,,,,,,,,,,,, -217600,3.369973,2.811332,,,,,,,,,,,,,, -217700,3.964466,1.2548674,,,,,,,,,,,,,, -217800,3.2583692,1.1696162,,,,,,,,,,,,,, -217900,3.2126422,1.240837,,,,,,,,,,,,,, -218000,2.992008,1.1811653,,,,,,,,,,,,,, -218100,3.268243,1.4935359,,,,,,,,,,,,,, -218200,3.1316454,1.251322,,,,,,,,,,,,,, -218300,3.256014,2.5430999,,,,,,,,,,,,,, -218333,,,0.8865429759025574,0.4225581288337707,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,97093.0073914528,105223.30492758752,97093.0073914528,8106.868925571442,11.844651222229004,0.0 -218400,2.9650278,1.7288406,,,,,,,,,,,,,, -218500,2.998544,1.0890415,,,,,,,,,,,,,, -218600,3.5255985,1.1886839,,,,,,,,,,,,,, -218700,3.7857213,3.098993,,,,,,,,,,,,,, -218800,2.8928957,1.0086246,,,,,,,,,,,,,, -218900,3.9340544,3.0624144,,,,,,,,,,,,,, -219000,2.8299139,1.801553,,,,,,,,,,,,,, -219100,3.3722415,1.1009322,,,,,,,,,,,,,, -219200,3.72511,1.2270145,,,,,,,,,,,,,, -219280,,,0.8845507502555847,0.4251365661621094,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,97513.31821966173,105682.02164626122,97513.31821966173,8145.163213729858,11.905775547027588,0.0 -219300,3.3713868,1.4680536,,,,,,,,,,,,,, -219400,4.013843,3.1649032,,,,,,,,,,,,,, -219500,3.100386,1.4555902,,,,,,,,,,,,,, -219600,3.6184788,3.0918212,,,,,,,,,,,,,, -219700,2.8602767,0.94353855,,,,,,,,,,,,,, -219800,2.9494858,2.289478,,,,,,,,,,,,,, -219900,3.478712,2.2054205,,,,,,,,,,,,,, -220000,3.820939,2.9618762,,,,,,,,,,,,,, -220100,3.686116,2.7987719,,,,,,,,,,,,,, -220200,3.0953686,1.1337662,,,,,,,,,,,,,, -220224,,,0.8892187476158142,0.4172480404376983,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,97933.19587111472,106135.13363027573,97933.19587111472,8178.272148132324,11.980994701385498,0.0 -220300,3.8453155,1.177504,,,,,,,,,,,,,, -220400,2.9963708,1.1887839,,,,,,,,,,,,,, -220500,3.0298488,1.9851234,,,,,,,,,,,,,, -220600,3.6983123,3.0802634,,,,,,,,,,,,,, -220700,3.7732441,2.588336,,,,,,,,,,,,,, -220800,3.1917374,1.1891773,,,,,,,,,,,,,, -220900,3.567122,3.0415506,,,,,,,,,,,,,, -221000,3.4651895,2.9896245,,,,,,,,,,,,,, -221100,2.8968244,1.1379526,,,,,,,,,,,,,, -221168,,,0.8871093392372131,0.4200824499130249,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,98353.35867023468,106591.34285092354,98353.35867023468,8214.198472261429,12.051104068756104,0.0 -221200,3.264208,1.2066419,,,,,,,,,,,,,, -221300,3.0555124,1.1079408,,,,,,,,,,,,,, -221400,2.9670992,2.2148883,,,,,,,,,,,,,, -221500,3.3176246,1.8653637,,,,,,,,,,,,,, -221600,3.4150894,1.1292485,,,,,,,,,,,,,, -221700,2.9484,1.3997325,,,,,,,,,,,,,, -221800,2.847567,1.1637297,,,,,,,,,,,,,, -221900,3.1804972,1.1628603,,,,,,,,,,,,,, -222000,2.979065,1.1370127,,,,,,,,,,,,,, -222100,4.4766645,3.2780724,,,,,,,,,,,,,, -222114,,,0.8859765529632568,0.421042799949646,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,98773.69676375388,107047.64982414246,98773.69676375388,8250.0584192276,12.109280109405518,0.0 -222200,3.40135,1.2965274,,,,,,,,,,,,,, -222300,2.9959028,1.6475991,,,,,,,,,,,,,, -222400,4.6248612,2.8977041,,,,,,,,,,,,,, -222500,3.1358728,2.6617775,,,,,,,,,,,,,, -222600,3.2694998,1.0844959,,,,,,,,,,,,,, -222700,3.2221684,1.1354399,,,,,,,,,,,,,, -222800,4.422539,3.1413429,,,,,,,,,,,,,, -222900,3.1704879,2.538056,,,,,,,,,,,,,, -223000,3.0128376,1.1972446,,,,,,,,,,,,,, -223059,,,0.8874609470367432,0.416455864906311,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,99193.9378619194,107503.78232359886,99193.9378619194,8285.833389759064,12.175068616867064,0.0 -223100,3.3307908,1.1474516,,,,,,,,,,,,,, -223200,2.9511173,1.6784935,,,,,,,,,,,,,, -223300,3.0428393,1.2526791,,,,,,,,,,,,,, -223400,3.5652878,1.1050309,,,,,,,,,,,,,, -223500,3.6778026,3.1698275,,,,,,,,,,,,,, -223600,3.06224,1.0177622,,,,,,,,,,,,,, -223700,2.9644232,1.5404472,,,,,,,,,,,,,, -223800,3.1413589,1.1881769,,,,,,,,,,,,,, -223900,3.1938994,1.0949378,,,,,,,,,,,,,, -224000,3.410598,2.6864493,,,,,,,,,,,,,, -224006,,,0.8888476490974426,0.4210689067840576,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,99614.28044009209,107959.80958008766,99614.28044009209,8321.404640674591,12.237873315811155,0.0 -224100,3.2158883,1.9081603,,,,,,,,,,,,,, -224200,3.9775927,3.2249994,,,,,,,,,,,,,, -224300,3.0363624,2.3962376,,,,,,,,,,,,,, -224400,5.6960983,1.29502,,,,,,,,,,,,,, -224500,2.9907575,1.0594646,,,,,,,,,,,,,, -224600,3.5819054,1.1420765,,,,,,,,,,,,,, -224700,3.527298,1.1172266,,,,,,,,,,,,,, -224800,3.2538369,2.6233082,,,,,,,,,,,,,, -224900,3.3861406,1.6003766,,,,,,,,,,,,,, -224950,,,0.8860937356948853,0.4242303669452667,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,100034.37481594086,108416.46704983713,100034.37481594086,8357.846234321594,12.309117078781128,0.0 -225000,3.2686992,1.1016681,,,,,,,,,,,,,, -225100,3.3993301,1.0940194,,,,,,,,,,,,,, -225200,3.1823435,1.1942891,,,,,,,,,,,,,, -225300,3.6112008,3.1318707,,,,,,,,,,,,,, -225400,2.7759743,1.8507468,,,,,,,,,,,,,, -225500,3.6461558,1.0668356,,,,,,,,,,,,,, -225600,2.9947884,1.1244156,,,,,,,,,,,,,, -225700,3.9312928,1.0847015,,,,,,,,,,,,,, -225800,3.4767249,2.330267,,,,,,,,,,,,,, -225894,,,0.8856640458106995,0.4249935150146484,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,100454.41598153114,108872.19716739656,100454.41598153114,8393.42155122757,12.37180781364441,0.0 -225900,3.4094996,1.0771912,,,,,,,,,,,,,, -226000,3.0784595,1.0674245,,,,,,,,,,,,,, -226100,2.8493936,1.0511973,,,,,,,,,,,,,, -226200,2.949196,1.654916,,,,,,,,,,,,,, -226300,3.1379163,1.1403352,,,,,,,,,,,,,, -226400,3.5219738,1.1589086,,,,,,,,,,,,,, -226500,2.930398,2.100003,,,,,,,,,,,,,, -226600,3.1970186,1.5040236,,,,,,,,,,,,,, -226700,3.039046,1.85431,,,,,,,,,,,,,, -226800,2.8485343,1.148371,,,,,,,,,,,,,, -226838,,,0.8877148032188416,0.4187945425510406,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,100874.35620498656,109329.71062541008,100874.35620498656,8430.878259420395,12.438353300094604,0.0 -226900,3.461547,2.2111974,,,,,,,,,,,,,, -227000,3.2808108,1.2182276,,,,,,,,,,,,,, -227100,3.0859451,2.4005477,,,,,,,,,,,,,, -227200,3.224567,1.1130247,,,,,,,,,,,,,, -227300,3.2477772,1.1411859,,,,,,,,,,,,,, -227400,3.1877937,2.5928226,,,,,,,,,,,,,, -227500,3.429491,1.152379,,,,,,,,,,,,,, -227600,2.9174478,1.0322995,,,,,,,,,,,,,, -227700,3.0556173,1.4447569,,,,,,,,,,,,,, -227782,,,0.8872265219688416,0.4204618036746979,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,101294.32449269296,109787.79367494585,101294.32449269296,8468.879625320435,12.502143383026125,0.0 -227800,4.474106,3.2356355,,,,,,,,,,,,,, -227900,3.8177168,2.5738492,,,,,,,,,,,,,, -228000,3.34715,2.1549242,,,,,,,,,,,,,, -228100,3.445431,2.8990982,,,,,,,,,,,,,, -228200,3.128008,1.2912198,,,,,,,,,,,,,, -228300,3.2349598,2.6901312,,,,,,,,,,,,,, -228400,2.9759688,1.1192391,,,,,,,,,,,,,, -228500,3.2195272,1.9411607,,,,,,,,,,,,,, -228600,3.2037377,1.0868508,,,,,,,,,,,,,, -228700,3.1747646,2.5358996,,,,,,,,,,,,,, -228723,,,0.8875585794448853,0.4156016409397125,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,101714.59797477722,110246.95320534706,101714.59797477722,8507.651833295822,12.566620111465454,0.0 -228800,2.9237902,1.0979538,,,,,,,,,,,,,, -228900,3.415372,2.0021732,,,,,,,,,,,,,, -229000,3.991684,3.1926575,,,,,,,,,,,,,, -229100,3.8031337,3.1122503,,,,,,,,,,,,,, -229200,3.0112538,1.0788938,,,,,,,,,,,,,, -229300,3.1030762,2.1909337,,,,,,,,,,,,,, -229400,3.1928527,1.2960646,,,,,,,,,,,,,, -229500,2.9654396,1.1136966,,,,,,,,,,,,,, -229600,2.9281688,1.6832938,,,,,,,,,,,,,, -229661,,,0.8867773413658142,0.4198257625102997,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,102134.6535089016,110704.89049220084,102134.6535089016,8545.424136638641,12.626487016677856,0.0 -229700,3.0228474,1.1366524,,,,,,,,,,,,,, -229800,2.9758613,1.0848151,,,,,,,,,,,,,, -229900,3.1645703,1.3576839,,,,,,,,,,,,,, -230000,3.5512795,1.169383,,,,,,,,,,,,,, -230100,3.0751894,1.0704517,,,,,,,,,,,,,, -230200,3.5345037,1.1361986,,,,,,,,,,,,,, -230300,3.3191595,3.0346355,,,,,,,,,,,,,, -230400,3.544018,1.79162,,,,,,,,,,,,,, -230500,6.2643437,1.0534035,,,,,,,,,,,,,, -230600,3.3784878,1.1399362,,,,,,,,,,,,,, -230604,,,0.8856054544448853,0.4232231080532074,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,102554.6896674633,111164.42679929732,102554.6896674633,8584.81552362442,12.685871601104736,0.0 -230700,3.2587183,1.4536024,,,,,,,,,,,,,, -230800,3.479265,1.2020919,,,,,,,,,,,,,, -230900,3.3563836,1.1256664,,,,,,,,,,,,,, -231000,3.463922,1.5112389,,,,,,,,,,,,,, -231100,3.619763,3.0058782,,,,,,,,,,,,,, -231200,4.228956,3.2253594,,,,,,,,,,,,,, -231300,3.2476912,1.2063012,,,,,,,,,,,,,, -231400,2.9368336,1.0634983,,,,,,,,,,,,,, -231500,3.4416616,1.0571885,,,,,,,,,,,,,, -231547,,,0.8893945217132568,0.4144158065319061,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,102974.65440821648,111620.29118657112,102974.65440821648,8620.59669971466,12.754069089889526,0.0 -231600,4.0009184,1.7824751,,,,,,,,,,,,,, -231700,3.4040053,2.0324366,,,,,,,,,,,,,, -231800,3.1856298,1.6138113,,,,,,,,,,,,,, -231900,3.027092,1.0496619,,,,,,,,,,,,,, -232000,3.0568912,1.1323591,,,,,,,,,,,,,, -232100,3.294798,1.1683023,,,,,,,,,,,,,, -232200,3.40266,1.0974364,,,,,,,,,,,,,, -232300,3.4055817,1.2155753,,,,,,,,,,,,,, -232400,3.5775263,1.2732289,,,,,,,,,,,,,, -232490,,,0.8886327743530273,0.41773721575737,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,103394.6379442215,112078.49687552452,103394.6379442215,8658.708618879318,12.814483880996704,0.0 -232500,3.300677,2.491578,,,,,,,,,,,,,, -232600,2.9338183,1.0848038,,,,,,,,,,,,,, -232700,3.63947,2.051918,,,,,,,,,,,,,, -232800,4.867812,3.2041192,,,,,,,,,,,,,, -232900,3.0166357,2.2292135,,,,,,,,,,,,,, -233000,3.105981,1.1076087,,,,,,,,,,,,,, -233100,3.4629133,1.2331123,,,,,,,,,,,,,, -233200,3.2819207,1.1805937,,,,,,,,,,,,,, -233300,3.383056,1.4119967,,,,,,,,,,,,,, -233400,3.4165492,2.7982197,,,,,,,,,,,,,, -233432,,,0.8906640410423279,0.4075176119804382,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,103814.5766415596,112532.41099596024,103814.5766415596,8692.565080881119,12.884013175964355,0.0 -233500,3.025724,1.1047653,,,,,,,,,,,,,, -233600,3.3322847,1.177273,,,,,,,,,,,,,, -233700,3.4290867,1.1676946,,,,,,,,,,,,,, -233800,3.739308,3.157258,,,,,,,,,,,,,, -233900,3.0309184,2.0378954,,,,,,,,,,,,,, -234000,2.9339004,1.1967098,,,,,,,,,,,,,, -234100,3.0802774,1.0672158,,,,,,,,,,,,,, -234200,3.1690805,1.2861493,,,,,,,,,,,,,, -234300,3.1956546,2.2196012,,,,,,,,,,,,,, -234375,,,0.8890429735183716,0.4111610949039459,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,104234.8699195385,112989.12427139282,104234.8699195385,8728.871542215347,12.946634531021118,0.0 -234400,3.1948311,2.3970523,,,,,,,,,,,,,, -234500,4.0243397,3.2167985,,,,,,,,,,,,,, -234600,3.5856495,1.1542908,,,,,,,,,,,,,, -234700,3.1464162,1.0787591,,,,,,,,,,,,,, -234800,2.927743,1.2249719,,,,,,,,,,,,,, -234900,3.3983781,1.0797032,,,,,,,,,,,,,, -235000,3.254846,1.2021749,,,,,,,,,,,,,, -235100,3.8305764,3.131151,,,,,,,,,,,,,, -235200,3.4842353,1.9259315,,,,,,,,,,,,,, -235300,3.025107,1.5499816,,,,,,,,,,,,,, -235319,,,0.8880859017372131,0.4207883477210998,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,104655.03619599342,113446.68429636957,104655.03619599342,8766.155029296875,13.006984949111938,0.0 -235400,3.091006,1.887309,,,,,,,,,,,,,, -235500,2.887633,1.4555479,,,,,,,,,,,,,, -235600,3.2497616,1.8201263,,,,,,,,,,,,,, -235700,3.133691,1.2559271,,,,,,,,,,,,,, -235800,3.373019,1.4017595,,,,,,,,,,,,,, -235900,3.2954524,1.173044,,,,,,,,,,,,,, -236000,3.1866927,1.9019291,,,,,,,,,,,,,, -236100,3.0582387,1.0921828,,,,,,,,,,,,,, -236200,3.1612558,1.3495369,,,,,,,,,,,,,, -236262,,,0.88671875,0.4218221306800842,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,105075.3451757431,113903.48621582983,105075.3451757431,8802.531280994415,13.073152780532835,0.0 -236300,4.257244,2.855882,,,,,,,,,,,,,, -236400,2.94181,1.1311934,,,,,,,,,,,,,, -236500,4.2725763,3.1073542,,,,,,,,,,,,,, -236600,3.2969065,2.7190678,,,,,,,,,,,,,, -236700,4.1590085,3.1271973,,,,,,,,,,,,,, -236800,3.0602598,1.1505482,,,,,,,,,,,,,, -236900,3.216458,1.1584255,,,,,,,,,,,,,, -237000,3.1156962,1.4822767,,,,,,,,,,,,,, -237100,2.9214892,1.0364952,,,,,,,,,,,,,, -237200,3.0839472,1.071727,,,,,,,,,,,,,, -237207,,,0.8869531154632568,0.4194927811622619,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,105495.2727303505,114359.78117632866,105495.2727303505,8838.775473117828,13.146722555160522,0.0 -237300,3.6155958,2.7277846,,,,,,,,,,,,,, -237400,3.0355656,2.005348,,,,,,,,,,,,,, -237500,3.337969,2.7443523,,,,,,,,,,,,,, -237600,3.3435671,1.4019312,,,,,,,,,,,,,, -237700,3.275816,1.1060549,,,,,,,,,,,,,, -237800,3.0432703,1.1399121,,,,,,,,,,,,,, -237900,3.0157554,1.3726759,,,,,,,,,,,,,, -238000,3.0365834,1.0349725,,,,,,,,,,,,,, -238100,2.8961577,1.5925057,,,,,,,,,,,,,, -238150,,,0.8891406059265137,0.4173803925514221,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,105915.4233739376,114814.7551767826,105915.4233739376,8873.482074022293,13.212722063064575,0.0 -238200,2.9019718,1.2722931,,,,,,,,,,,,,, -238300,5.76223,2.9828212,,,,,,,,,,,,,, -238400,3.303227,2.158837,,,,,,,,,,,,,, -238500,3.110987,1.6748914,,,,,,,,,,,,,, -238600,3.2559862,1.7141246,,,,,,,,,,,,,, -238700,2.9266043,2.4755933,,,,,,,,,,,,,, -238800,3.246,1.1067487,,,,,,,,,,,,,, -238900,5.0360446,3.1942527,,,,,,,,,,,,,, -239000,3.2781656,2.7025163,,,,,,,,,,,,,, -239094,,,0.8861132860183716,0.4208251237869262,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,106335.55676198006,115270.7553293705,106335.55676198006,8909.238534212112,13.272995471954346,0.0 -239100,3.4683144,2.745588,,,,,,,,,,,,,, -239200,3.392688,2.95548,,,,,,,,,,,,,, -239300,3.1417778,2.481199,,,,,,,,,,,,,, -239400,3.4471757,1.1972699,,,,,,,,,,,,,, -239500,3.154833,1.3992893,,,,,,,,,,,,,, -239600,2.929245,1.5294673,,,,,,,,,,,,,, -239700,3.0367744,1.176915,,,,,,,,,,,,,, -239800,3.2556956,1.1962234,,,,,,,,,,,,,, -239900,3.3443851,1.7305206,,,,,,,,,,,,,, -240000,3.1957366,1.1308967,,,,,,,,,,,,,, -240041,,,0.8859374523162842,0.4255052208900451,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,106755.64887499808,115727.53117990494,106755.64887499808,8945.806883335114,13.338412046432495,0.0 -240100,3.3798141,1.1502129,,,,,,,,,,,,,, -240200,3.3325233,2.4576025,,,,,,,,,,,,,, -240300,3.2654495,1.0828779,,,,,,,,,,,,,, -240400,3.3835497,1.1704302,,,,,,,,,,,,,, -240500,4.360848,2.9868243,,,,,,,,,,,,,, -240600,3.0942192,1.106072,,,,,,,,,,,,,, -240700,3.8224492,3.194496,,,,,,,,,,,,,, -240800,3.2862763,1.1389474,,,,,,,,,,,,,, -240900,3.0537047,1.6799517,,,,,,,,,,,,,, -240984,,,0.88685542345047,0.4231062233448028,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,107175.83440184592,116184.20617771149,107175.83440184592,8982.179713249207,13.405362844467165,0.0 -241000,3.0309136,1.4083848,,,,,,,,,,,,,, -241100,3.9778032,3.026561,,,,,,,,,,,,,, -241200,3.311798,1.4649638,,,,,,,,,,,,,, -241300,3.0161858,1.1082189,,,,,,,,,,,,,, -241400,2.8822453,1.2468371,,,,,,,,,,,,,, -241500,3.0743098,1.1260612,,,,,,,,,,,,,, -241600,3.2342145,2.7394795,,,,,,,,,,,,,, -241700,3.303512,1.3106028,,,,,,,,,,,,,, -241800,3.0992415,1.1758938,,,,,,,,,,,,,, -241900,3.364284,1.4624312,,,,,,,,,,,,,, -241927,,,0.8875195384025574,0.4177833497524261,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,107595.84573626518,116641.06905651093,107595.84573626518,9018.915987968445,13.470390558242798,0.0 -242000,2.8669991,1.2223496,,,,,,,,,,,,,, -242100,3.465573,1.243634,,,,,,,,,,,,,, -242200,3.2077737,2.6204646,,,,,,,,,,,,,, -242300,2.984034,2.1731234,,,,,,,,,,,,,, -242400,3.0263476,1.8084073,,,,,,,,,,,,,, -242500,3.02728,1.1792701,,,,,,,,,,,,,, -242600,3.0530558,2.0788565,,,,,,,,,,,,,, -242700,3.1186514,1.8362094,,,,,,,,,,,,,, -242800,3.2741063,1.1660435,,,,,,,,,,,,,, -242873,,,0.8873632550239563,0.4211326837539673,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,108016.01560282709,117096.63333463667,108016.01560282709,9054.184888839722,13.544833660125732,0.0 -242900,3.0583851,2.300169,,,,,,,,,,,,,, -243000,3.2644482,2.0453045,,,,,,,,,,,,,, -243100,3.2616537,1.7585623,,,,,,,,,,,,,, -243200,2.9430695,1.1931853,,,,,,,,,,,,,, -243300,3.2577531,1.0672169,,,,,,,,,,,,,, -243400,3.073047,1.0223318,,,,,,,,,,,,,, -243500,3.5907152,1.3958488,,,,,,,,,,,,,, -243600,3.3299613,1.7615423,,,,,,,,,,,,,, -243700,3.1649537,1.1838076,,,,,,,,,,,,,, -243800,2.961985,1.5621004,,,,,,,,,,,,,, -243816,,,0.8863866925239563,0.4213360548019409,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,108436.21713781355,117554.89075517654,108436.21713781355,9092.124347448347,13.611082077026367,0.0 -243900,3.4668417,2.1603575,,,,,,,,,,,,,, -244000,3.1128123,1.6284088,,,,,,,,,,,,,, -244100,3.031724,1.2005582,,,,,,,,,,,,,, -244200,3.1564858,1.1790267,,,,,,,,,,,,,, -244300,3.0699139,1.3016943,,,,,,,,,,,,,, -244400,3.4281013,1.2982081,,,,,,,,,,,,,, -244500,3.3214798,2.4661393,,,,,,,,,,,,,, -244600,2.9669502,1.4918296,,,,,,,,,,,,,, -244700,3.2598345,2.6489594,,,,,,,,,,,,,, -244760,,,0.8868163824081421,0.4208961427211761,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,108856.1207022667,118013.58737802504,108856.1207022667,9130.800177574158,13.678319692611694,0.0 -244800,2.9591038,1.1375704,,,,,,,,,,,,,, -244900,3.2387567,1.2373909,,,,,,,,,,,,,, -245000,3.2025476,1.1689254,,,,,,,,,,,,,, -245100,3.8327208,3.1493688,,,,,,,,,,,,,, -245200,3.5215595,2.9419749,,,,,,,,,,,,,, -245300,3.5418978,2.9001718,,,,,,,,,,,,,, -245400,3.0761483,1.0561311,,,,,,,,,,,,,, -245500,2.965471,1.5725135,,,,,,,,,,,,,, -245600,3.060816,1.0962546,,,,,,,,,,,,,, -245700,4.1557837,3.0737357,,,,,,,,,,,,,, -245704,,,0.8855273127555847,0.4221659004688263,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,109276.05662918092,118469.53283405304,109276.05662918092,9166.68828845024,13.749892473220823,0.0 -245800,2.8957453,1.7173166,,,,,,,,,,,,,, -245900,3.125499,2.0988443,,,,,,,,,,,,,, -246000,3.5621748,2.9376245,,,,,,,,,,,,,, -246100,3.423329,2.9776938,,,,,,,,,,,,,, -246200,2.9953072,2.2718017,,,,,,,,,,,,,, -246300,3.3325512,1.751694,,,,,,,,,,,,,, -246400,3.085574,1.095578,,,,,,,,,,,,,, -246500,4.8394403,3.199609,,,,,,,,,,,,,, -246600,3.4025576,1.357565,,,,,,,,,,,,,, -246648,,,0.8881054520606995,0.4146759212017059,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,109696.36248373984,118924.96544003488,109696.36248373984,9201.703897237778,13.811155080795288,0.0 -246700,3.4321225,2.7423544,,,,,,,,,,,,,, -246800,2.9203634,1.3266628,,,,,,,,,,,,,, -246900,3.0812342,1.6138606,,,,,,,,,,,,,, -247000,3.7398694,2.1768446,,,,,,,,,,,,,, -247100,3.1711187,2.400631,,,,,,,,,,,,,, -247200,3.3830268,1.2482449,,,,,,,,,,,,,, -247300,3.3945823,1.549413,,,,,,,,,,,,,, -247400,3.246862,2.2334466,,,,,,,,,,,,,, -247500,3.1964426,1.5692848,,,,,,,,,,,,,, -247594,,,0.8883398175239563,0.4226095378398895,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,110116.45049238204,119382.20756483078,110116.45049238204,9238.741772651672,13.876649856567385,0.0 -247600,2.9312654,1.1835685,,,,,,,,,,,,,, -247700,2.977583,1.4176692,,,,,,,,,,,,,, -247800,3.4109871,2.85235,,,,,,,,,,,,,, -247900,3.0056465,1.0902561,,,,,,,,,,,,,, -248000,3.4459505,1.1528544,,,,,,,,,,,,,, -248100,2.9622517,1.2795148,,,,,,,,,,,,,, -248200,3.0324092,1.0517707,,,,,,,,,,,,,, -248300,3.248692,1.4467864,,,,,,,,,,,,,, -248400,3.1077192,1.0949395,,,,,,,,,,,,,, -248500,4.1352687,1.1489552,,,,,,,,,,,,,, -248539,,,0.8856640458106995,0.4264017343521118,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,110536.77914857864,119840.07347083092,110536.77914857864,9276.15583205223,13.948917627334597,0.0 -248600,3.4291544,1.3068914,,,,,,,,,,,,,, -248700,3.0161922,1.3297033,,,,,,,,,,,,,, -248800,3.7147171,2.9538438,,,,,,,,,,,,,, -248900,3.1879013,1.1615455,,,,,,,,,,,,,, -249000,3.6756153,1.1715958,,,,,,,,,,,,,, -249100,3.8941905,3.0948179,,,,,,,,,,,,,, -249200,3.2392254,1.2151821,,,,,,,,,,,,,, -249300,3.822515,1.4436473,,,,,,,,,,,,,, -249400,3.5056736,1.1804652,,,,,,,,,,,,,, -249482,,,0.8876757621765137,0.4183971285820007,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,110956.83072423936,120299.9406888485,110956.83072423936,9315.853289604189,14.01637315750122,0.0 -249500,2.9166074,1.8940703,,,,,,,,,,,,,, -249600,3.6852098,2.8996758,,,,,,,,,,,,,, -249700,4.383242,3.244354,,,,,,,,,,,,,, -249800,3.2671907,1.1775236,,,,,,,,,,,,,, -249900,3.2582543,2.5490515,,,,,,,,,,,,,, -250000,3.2972445,2.788394,,,,,,,,,,,,,, -250100,2.9433777,1.1922736,,,,,,,,,,,,,, -250200,3.0767963,1.239346,,,,,,,,,,,,,, -250300,3.0295498,1.467454,,,,,,,,,,,,,, -250400,3.2333627,2.1520607,,,,,,,,,,,,,, -250426,,,0.8869531154632568,0.4247783422470093,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,111376.93041610718,120756.08883309364,111376.93041610718,9351.783564567566,14.083440780639648,0.0 -250500,2.9518206,1.2215317,,,,,,,,,,,,,, -250600,3.1105485,2.6743824,,,,,,,,,,,,,, -250700,3.1332355,1.156807,,,,,,,,,,,,,, -250800,2.8615823,1.6150788,,,,,,,,,,,,,, -250900,3.2745025,1.2495794,,,,,,,,,,,,,, -251000,2.862695,1.6953461,,,,,,,,,,,,,, -251100,3.0529058,1.7655759,,,,,,,,,,,,,, -251200,3.4634254,2.9200954,,,,,,,,,,,,,, -251300,3.2682986,1.1166594,,,,,,,,,,,,,, -251370,,,0.8869140148162842,0.4152418375015259,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,111796.849973917,121213.8227274418,111796.849973917,9389.48131465912,14.150177240371704,0.0 -251400,3.0751636,1.1015533,,,,,,,,,,,,,, -251500,4.1746273,2.87852,,,,,,,,,,,,,, -251600,3.2789586,1.5024563,,,,,,,,,,,,,, -251700,2.7873716,1.8908677,,,,,,,,,,,,,, -251800,4.0941925,3.1851594,,,,,,,,,,,,,, -251900,2.9689693,1.836293,,,,,,,,,,,,,, -252000,3.9425175,1.3662786,,,,,,,,,,,,,, -252100,2.930015,1.1793969,,,,,,,,,,,,,, -252200,3.0892994,1.1115187,,,,,,,,,,,,,, -252300,3.6527572,1.9456633,,,,,,,,,,,,,, -252315,,,0.8876757621765137,0.4186638295650482,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,112217.03732776642,121672.28964018822,112217.03732776642,9427.63720202446,14.223804235458374,0.0 -252400,3.6866877,2.704683,,,,,,,,,,,,,, -252500,3.5821218,2.6569963,,,,,,,,,,,,,, -252600,3.3803623,2.283504,,,,,,,,,,,,,, -252700,3.176906,1.866617,,,,,,,,,,,,,, -252800,3.4214513,2.8653398,,,,,,,,,,,,,, -252900,3.2230284,1.7798893,,,,,,,,,,,,,, -253000,3.5647943,2.9226599,,,,,,,,,,,,,, -253100,4.685166,3.2369819,,,,,,,,,,,,,, -253200,3.23825,1.1452844,,,,,,,,,,,,,, -253259,,,0.88623046875,0.4246796667575836,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,112637.292222023,122127.56577944756,112637.292222023,9462.538872718813,14.292896032333374,0.0 -253300,3.289566,1.1440182,,,,,,,,,,,,,, -253400,3.3021116,1.1215509,,,,,,,,,,,,,, -253500,3.229081,2.5429664,,,,,,,,,,,,,, -253600,2.9531155,1.8682915,,,,,,,,,,,,,, -253700,3.2609828,2.1329436,,,,,,,,,,,,,, -253800,3.5410264,1.0724021,,,,,,,,,,,,,, -253900,3.7191036,1.0518837,,,,,,,,,,,,,, -254000,3.1391356,1.2102668,,,,,,,,,,,,,, -254100,4.0823326,1.9762883,,,,,,,,,,,,,, -254200,3.3473897,1.101156,,,,,,,,,,,,,, -254205,,,0.8882226347923279,0.414854496717453,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,113057.61286640169,122585.75490498544,113057.61286640169,9500.291576862335,14.35829257965088,0.0 -254300,3.0672617,1.2334414,,,,,,,,,,,,,, -254400,3.2497375,1.8695174,,,,,,,,,,,,,, -254500,2.820581,1.0485528,,,,,,,,,,,,,, -254600,3.2768793,1.1409509,,,,,,,,,,,,,, -254700,3.408566,2.5854473,,,,,,,,,,,,,, -254800,3.2738194,1.1451719,,,,,,,,,,,,,, -254900,3.5248866,3.0527403,,,,,,,,,,,,,, -255000,3.054847,1.0291624,,,,,,,,,,,,,, -255100,3.1395297,2.3452182,,,,,,,,,,,,,, -255152,,,0.8879101276397705,0.4163959324359894,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,113477.76876091956,123044.27444005013,113477.76876091956,9538.537852048874,14.424855947494509,0.0 -255200,4.0134025,3.3079383,,,,,,,,,,,,,, -255300,3.043717,1.0784642,,,,,,,,,,,,,, -255400,3.098097,1.379495,,,,,,,,,,,,,, -255500,3.6151998,2.1014807,,,,,,,,,,,,,, -255600,4.3278847,3.1318016,,,,,,,,,,,,,, -255700,3.647208,2.8617988,,,,,,,,,,,,,, -255800,3.1471326,1.0885253,,,,,,,,,,,,,, -255900,3.6320746,2.8571234,,,,,,,,,,,,,, -256000,2.870894,1.906208,,,,,,,,,,,,,, -256096,,,0.8895312547683716,0.4140346050262451,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,113897.86142706872,123501.7170226574,113897.86142706872,9575.771821022034,14.490861177444458,0.0 -256100,3.622816,2.7045684,,,,,,,,,,,,,, -256200,3.3675778,1.169585,,,,,,,,,,,,,, -256300,3.3620152,1.1031473,,,,,,,,,,,,,, -256400,3.1997056,2.2743878,,,,,,,,,,,,,, -256500,3.458016,1.1128755,,,,,,,,,,,,,, -256600,2.9776313,1.0923295,,,,,,,,,,,,,, -256700,3.8826442,3.21202,,,,,,,,,,,,,, -256800,4.6098046,2.9433038,,,,,,,,,,,,,, -256900,3.2088468,1.4042956,,,,,,,,,,,,,, -257000,3.0625916,1.2437259,,,,,,,,,,,,,, -257040,,,0.8866796493530273,0.4149068593978882,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,114317.87077498436,123958.7415716648,114317.87077498436,9612.654847860336,14.572320699691772,0.0 -257100,3.0156858,1.2922592,,,,,,,,,,,,,, -257200,2.9209292,1.1516348,,,,,,,,,,,,,, -257300,3.1303692,1.0737861,,,,,,,,,,,,,, -257400,3.1180046,2.8587124,,,,,,,,,,,,,, -257500,3.0602963,1.1483438,,,,,,,,,,,,,, -257600,2.9307256,1.138017,,,,,,,,,,,,,, -257700,3.267875,2.7635596,,,,,,,,,,,,,, -257800,3.346864,1.6453155,,,,,,,,,,,,,, -257900,2.9481893,1.8831238,,,,,,,,,,,,,, -257985,,,0.8883788585662842,0.4164501130580902,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,114738.04656410216,124414.4491698742,114738.04656410216,9648.057507038116,14.651172876358032,0.0 -258000,3.8751354,3.0932257,,,,,,,,,,,,,, -258100,3.0271132,1.1067142,,,,,,,,,,,,,, -258200,3.1890488,1.133419,,,,,,,,,,,,,, -258300,3.0225728,1.3944374,,,,,,,,,,,,,, -258400,3.7146096,3.053034,,,,,,,,,,,,,, -258500,3.107276,1.9913914,,,,,,,,,,,,,, -258600,3.1309679,1.2341866,,,,,,,,,,,,,, -258700,3.589778,2.297455,,,,,,,,,,,,,, -258800,3.3352563,2.6203234,,,,,,,,,,,,,, -258900,3.9736607,3.252502,,,,,,,,,,,,,, -258931,,,0.8895312547683716,0.4190482199192047,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,115158.09628129004,124873.41975784302,115158.09628129004,9686.863124847412,14.71579933166504,0.0 -259000,2.957822,2.1190882,,,,,,,,,,,,,, -259100,3.1549907,1.1844918,,,,,,,,,,,,,, -259200,3.9016578,3.2420099,,,,,,,,,,,,,, -259300,4.1047826,3.11793,,,,,,,,,,,,,, -259400,3.0820377,1.1634529,,,,,,,,,,,,,, -259500,2.989553,1.3955973,,,,,,,,,,,,,, -259600,3.4050152,1.4685758,,,,,,,,,,,,,, -259700,2.9500039,1.1107235,,,,,,,,,,,,,, -259800,2.8911743,1.0225372,,,,,,,,,,,,,, -259877,,,0.8865820169448853,0.4202475249767303,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,115578.23216414452,125330.19090342522,115578.23216414452,9723.382348060608,14.781482458114624,0.0 -259900,2.8614848,1.9955679,,,,,,,,,,,,,, -260000,2.8272288,1.0205307,,,,,,,,,,,,,, -260100,3.203447,1.0167843,,,,,,,,,,,,,, -260200,3.0346773,2.2899954,,,,,,,,,,,,,, -260300,3.1795623,2.1025124,,,,,,,,,,,,,, -260400,3.0249476,1.649008,,,,,,,,,,,,,, -260500,3.1976175,2.6124666,,,,,,,,,,,,,, -260600,3.100164,1.2283573,,,,,,,,,,,,,, -260700,3.4589553,1.3019919,,,,,,,,,,,,,, -260800,3.113668,1.2078695,,,,,,,,,,,,,, -260822,,,0.8883007764816284,0.418276309967041,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,115998.11412405968,125786.23392367364,115998.11412405968,9759.420825719832,14.854071378707886,0.0 -260900,3.0908606,2.5847633,,,,,,,,,,,,,, -261000,3.0269365,1.1596648,,,,,,,,,,,,,, -261100,3.2865188,1.251384,,,,,,,,,,,,,, -261200,3.3061314,1.1799283,,,,,,,,,,,,,, -261300,3.1919322,1.3744069,,,,,,,,,,,,,, -261400,2.9919293,1.1512215,,,,,,,,,,,,,, -261500,3.0267277,1.5763049,,,,,,,,,,,,,, -261600,3.115564,1.6964684,,,,,,,,,,,,,, -261700,3.442755,2.7054968,,,,,,,,,,,,,, -261767,,,0.8883788585662842,0.4182990491390228,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,116418.11951112749,126244.10726952551,116418.11951112749,9797.168113470078,14.92531681060791,0.0 -261800,2.9679024,1.1244586,,,,,,,,,,,,,, -261900,3.151598,1.1997876,,,,,,,,,,,,,, -262000,3.375447,1.321788,,,,,,,,,,,,,, -262100,3.0990627,1.2493463,,,,,,,,,,,,,, -262200,3.2960913,2.8097405,,,,,,,,,,,,,, -262300,3.6645226,1.8900056,,,,,,,,,,,,,, -262400,3.8269439,3.1591654,,,,,,,,,,,,,, -262500,2.99252,1.0658951,,,,,,,,,,,,,, -262600,3.2748888,1.1283934,,,,,,,,,,,,,, -262700,3.051146,1.1016616,,,,,,,,,,,,,, -262712,,,0.8886132836341858,0.4173754751682281,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,116838.13629627228,126701.49708580972,116838.13629627228,9834.422104597092,14.993987560272217,0.0 -262800,3.0001035,1.099312,,,,,,,,,,,,,, -262900,2.9242754,1.7955114,,,,,,,,,,,,,, -263000,3.0905757,1.7748646,,,,,,,,,,,,,, -263100,2.9791665,1.209825,,,,,,,,,,,,,, -263200,3.173294,1.0852224,,,,,,,,,,,,,, -263300,3.0759623,1.3837109,,,,,,,,,,,,,, -263400,2.8240914,1.9315919,,,,,,,,,,,,,, -263500,3.2042735,1.4099545,,,,,,,,,,,,,, -263600,4.073111,3.251448,,,,,,,,,,,,,, -263657,,,0.8840234279632568,0.4240234792232513,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,117258.14576864244,127160.88830947876,117258.14576864244,9873.68555045128,15.062115669250488,0.0 -263700,3.845172,3.0848775,,,,,,,,,,,,,, -263800,3.168835,1.0813997,,,,,,,,,,,,,, -263900,3.3932157,2.613001,,,,,,,,,,,,,, -264000,3.1674514,1.1358162,,,,,,,,,,,,,, -264100,3.122478,1.2519878,,,,,,,,,,,,,, -264200,3.3795323,2.713503,,,,,,,,,,,,,, -264300,3.3606944,2.9601545,,,,,,,,,,,,,, -264400,4.003349,2.9236922,,,,,,,,,,,,,, -264500,3.1365917,1.2971069,,,,,,,,,,,,,, -264600,3.2280884,1.1571187,,,,,,,,,,,,,, -264601,,,0.8874804377555847,0.4231193959712982,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,117678.73268508913,127619.50759148598,117678.73268508913,9911.598610162737,15.131271123886108,0.0 -264700,2.9172606,1.156688,,,,,,,,,,,,,, -264800,3.3074336,1.1250343,,,,,,,,,,,,,, -264900,3.2258775,1.1854603,,,,,,,,,,,,,, -265000,3.1818428,2.0027323,,,,,,,,,,,,,, -265100,3.2449143,1.804352,,,,,,,,,,,,,, -265200,2.785296,1.2801151,,,,,,,,,,,,,, -265300,3.3132746,1.1811786,,,,,,,,,,,,,, -265400,3.050475,1.1813726,,,,,,,,,,,,,, -265500,3.1496422,1.1805058,,,,,,,,,,,,,, -265547,,,0.8855078220367432,0.4264200925827026,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,118099.01088500024,128073.29778432846,118099.01088500024,9944.990770578384,15.20097303390503,0.0 -265600,3.2006693,1.2106673,,,,,,,,,,,,,, -265700,3.047142,1.1876851,,,,,,,,,,,,,, -265800,3.0474107,1.1495253,,,,,,,,,,,,,, -265900,4.284726,3.275183,,,,,,,,,,,,,, -266000,3.805112,3.0118797,,,,,,,,,,,,,, -266100,3.2948987,1.5937655,,,,,,,,,,,,,, -266200,3.8909578,3.0766227,,,,,,,,,,,,,, -266300,3.02146,1.2346025,,,,,,,,,,,,,, -266400,3.2951515,1.1393971,,,,,,,,,,,,,, -266490,,,0.8890234231948853,0.4168081283569336,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,118518.96572709084,128529.44418168068,118518.96572709084,9981.062950611116,15.270276546478271,0.0 -266500,2.9716995,1.042335,,,,,,,,,,,,,, -266600,3.522557,3.0315714,,,,,,,,,,,,,, -266700,3.1850476,1.4305658,,,,,,,,,,,,,, -266800,3.2758255,2.5735755,,,,,,,,,,,,,, -266900,3.2552435,1.1329207,,,,,,,,,,,,,, -267000,2.7668617,1.9620847,,,,,,,,,,,,,, -267100,3.7616324,3.1485684,,,,,,,,,,,,,, -267200,3.5769205,1.6763924,,,,,,,,,,,,,, -267300,3.006909,1.2750927,,,,,,,,,,,,,, -267400,2.8737488,1.2052369,,,,,,,,,,,,,, -267435,,,0.8855664134025574,0.4233571588993072,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,118938.91685509682,128986.48639082909,118938.91685509682,10018.03050661087,15.34412407875061,0.0 -267500,3.6641111,3.0307481,,,,,,,,,,,,,, -267600,3.0589662,1.0573792,,,,,,,,,,,,,, -267700,3.0229352,1.0947592,,,,,,,,,,,,,, -267800,3.2163558,1.5480033,,,,,,,,,,,,,, -267900,3.2998083,1.1817217,,,,,,,,,,,,,, -268000,15.057154,1.2342439,,,,,,,,,,,,,, -268100,3.3014162,1.1753141,,,,,,,,,,,,,, -268200,3.2033129,1.7976886,,,,,,,,,,,,,, -268300,3.3443205,1.2176162,,,,,,,,,,,,,, -268381,,,0.8865038752555847,0.4210705161094665,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,119359.17529034616,129443.6222999096,119359.17529034616,10054.78897356987,15.412813425064089,0.0 -268400,2.9962568,1.0863372,,,,,,,,,,,,,, -268500,3.0510552,1.1344521,,,,,,,,,,,,,, -268600,3.157621,1.6009998,,,,,,,,,,,,,, -268700,3.4664505,2.6849387,,,,,,,,,,,,,, -268800,3.2320197,1.2268702,,,,,,,,,,,,,, -268900,3.2403607,1.1449304,,,,,,,,,,,,,, -269000,3.0738826,1.0962088,,,,,,,,,,,,,, -269100,3.0358033,1.1727204,,,,,,,,,,,,,, -269200,3.2374268,1.11908,,,,,,,,,,,,,, -269300,3.276847,2.7908316,,,,,,,,,,,,,, -269326,,,0.8862109184265137,0.4198490381240845,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,119779.28468871117,129901.316116333,119779.28468871117,10092.251964330671,15.483604669570925,0.0 -269400,4.407957,2.9747484,,,,,,,,,,,,,, -269500,3.2197776,1.2092538,,,,,,,,,,,,,, -269600,4.363441,3.2201948,,,,,,,,,,,,,, -269700,2.9895291,1.610721,,,,,,,,,,,,,, -269800,2.8770807,1.1933821,,,,,,,,,,,,,, -269900,3.1785886,1.4673519,,,,,,,,,,,,,, -270000,3.970744,3.0223413,,,,,,,,,,,,,, -270100,3.1077335,1.1301892,,,,,,,,,,,,,, -270200,3.4168332,2.261722,,,,,,,,,,,,,, -270269,,,0.8884961009025574,0.4142680764198303,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,120199.18902301788,130357.5304300785,120199.18902301788,10128.442929267883,15.552424669265749,0.0 -270300,3.126802,1.1704654,,,,,,,,,,,,,, -270400,3.0458288,1.1383712,,,,,,,,,,,,,, -270500,4.0079694,2.9799967,,,,,,,,,,,,,, -270600,3.155953,1.0965168,,,,,,,,,,,,,, -270700,3.0238056,2.3108618,,,,,,,,,,,,,, -270800,3.256761,1.0543278,,,,,,,,,,,,,, -270900,2.9106736,1.059908,,,,,,,,,,,,,, -271000,2.819062,1.0977241,,,,,,,,,,,,,, -271100,3.6147742,1.4177957,,,,,,,,,,,,,, -271200,3.4441054,1.1424639,,,,,,,,,,,,,, -271214,,,0.8880664110183716,0.4235153794288635,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,120619.08997631072,130816.87016510963,120619.08997631072,10167.75212931633,15.631080865859984,0.0 -271300,3.0794904,1.2356129,,,,,,,,,,,,,, -271400,3.0795333,1.9087539,,,,,,,,,,,,,, -271500,9.28506,3.0839264,,,,,,,,,,,,,, -271600,2.9174695,1.4413922,,,,,,,,,,,,,, -271700,2.9696555,1.0518095,,,,,,,,,,,,,, -271800,3.6483119,1.1444693,,,,,,,,,,,,,, -271900,2.9349868,1.6511813,,,,,,,,,,,,,, -272000,3.3163145,2.4911575,,,,,,,,,,,,,, -272100,4.0920877,3.181259,,,,,,,,,,,,,, -272159,,,0.887011706829071,0.4199972152709961,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,121039.39338994026,131275.11368703842,121039.39338994026,10205.575210809708,15.698053359985352,0.0 -272200,3.02902,1.0425372,,,,,,,,,,,,,, -272300,3.423554,1.0880095,,,,,,,,,,,,,, -272400,3.2009537,1.2641444,,,,,,,,,,,,,, -272500,3.0641253,1.6344843,,,,,,,,,,,,,, -272600,3.2441332,2.963729,,,,,,,,,,,,,, -272700,4.768427,3.2404344,,,,,,,,,,,,,, -272800,3.2640555,1.1625544,,,,,,,,,,,,,, -272900,2.9023812,2.1367846,,,,,,,,,,,,,, -273000,2.950628,1.1613283,,,,,,,,,,,,,, -273100,4.647459,3.2936213,,,,,,,,,,,,,, -273101,,,0.8873242139816284,0.4234861135482788,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,121459.28456163406,131730.5489742756,121459.28456163406,10240.997619867325,15.770262479782104,0.0 -273200,4.497883,3.1981626,,,,,,,,,,,,,, -273300,4.038237,3.2380176,,,,,,,,,,,,,, -273400,3.1609898,1.2295156,,,,,,,,,,,,,, -273500,2.9000735,1.0338489,,,,,,,,,,,,,, -273600,3.2404456,2.6487138,,,,,,,,,,,,,, -273700,3.3340435,1.2244219,,,,,,,,,,,,,, -273800,3.3180544,1.1539224,,,,,,,,,,,,,, -273900,3.6372986,3.1355584,,,,,,,,,,,,,, -274000,3.982458,3.286373,,,,,,,,,,,,,, -274042,,,0.8865429759025574,0.4222384989261627,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,121879.23887968063,132186.96950387955,121879.23887968063,10277.346205711365,15.83849048614502,0.0 -274100,3.8260977,3.2143,,,,,,,,,,,,,, -274200,3.387302,1.2915694,,,,,,,,,,,,,, -274300,3.4084575,2.746495,,,,,,,,,,,,,, -274400,3.1097348,1.058351,,,,,,,,,,,,,, -274500,3.0050726,1.1504095,,,,,,,,,,,,,, -274600,3.1273334,1.2844205,,,,,,,,,,,,,, -274700,3.277161,2.7074451,,,,,,,,,,,,,, -274800,3.414935,1.4416093,,,,,,,,,,,,,, -274900,6.3092585,2.2157843,,,,,,,,,,,,,, -274987,,,0.8861327767372131,0.4210696220397949,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,122299.38060212135,132645.01781749725,122299.38060212135,10315.125161886215,15.915863037109377,0.0 -275000,3.1212664,1.8255146,,,,,,,,,,,,,, -275100,3.0361655,1.1373241,,,,,,,,,,,,,, -275200,3.2115,1.7895266,,,,,,,,,,,,,, -275300,3.0421903,1.8996994,,,,,,,,,,,,,, -275400,3.3473709,1.6021254,,,,,,,,,,,,,, -275500,7.556025,1.5215805,,,,,,,,,,,,,, -275600,3.4600658,1.2005861,,,,,,,,,,,,,, -275700,3.1652536,1.2184063,,,,,,,,,,,,,, -275800,3.0832593,1.1366556,,,,,,,,,,,,,, -275900,3.6014814,1.8898954,,,,,,,,,,,,,, -275933,,,0.8864843845367432,0.4235370457172394,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,122719.44203877448,133101.93027758598,122719.44203877448,10351.85710811615,15.985002279281616,0.0 -276000,4.347562,3.1364708,,,,,,,,,,,,,, -276100,3.3358855,1.1691952,,,,,,,,,,,,,, -276200,3.2991233,1.1254919,,,,,,,,,,,,,, -276300,3.1736627,1.0731279,,,,,,,,,,,,,, -276400,3.6633606,3.263882,,,,,,,,,,,,,, -276500,3.3522615,1.1980194,,,,,,,,,,,,,, -276600,3.1983056,1.4389294,,,,,,,,,,,,,, -276700,3.206962,1.7152233,,,,,,,,,,,,,, -276800,3.3371673,1.1663053,,,,,,,,,,,,,, -276878,,,0.8878905773162842,0.4141671657562256,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,123139.42353534698,133560.93928790092,123139.42353534698,10390.764615297318,16.054569482803345,0.0 -276900,3.1359737,1.9523592,,,,,,,,,,,,,, -277000,2.889838,1.1788752,,,,,,,,,,,,,, -277100,4.5125027,3.1545198,,,,,,,,,,,,,, -277200,3.288967,1.0976243,,,,,,,,,,,,,, -277300,3.1481984,1.1172589,,,,,,,,,,,,,, -277400,3.3640015,3.0805135,,,,,,,,,,,,,, -277500,2.9289808,1.025034,,,,,,,,,,,,,, -277600,4.350589,3.394132,,,,,,,,,,,,,, -277700,3.1391842,1.0780002,,,,,,,,,,,,,, -277800,4.1536875,3.041471,,,,,,,,,,,,,, -277822,,,0.8886327743530273,0.4123508036136627,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,123559.36771154404,134016.9411432743,123559.36771154404,10426.705643177032,16.121344804763794,0.0 -277900,2.911854,1.3751985,,,,,,,,,,,,,, -278000,2.989889,1.251075,,,,,,,,,,,,,, -278100,3.142235,1.0948424,,,,,,,,,,,,,, -278200,3.1262023,1.2087142,,,,,,,,,,,,,, -278300,3.4538019,1.124767,,,,,,,,,,,,,, -278400,3.0665543,2.6447418,,,,,,,,,,,,,, -278500,3.159196,1.0833949,,,,,,,,,,,,,, -278600,3.2816262,1.1199871,,,,,,,,,,,,,, -278700,3.1219876,1.0845001,,,,,,,,,,,,,, -278767,,,0.8878515362739563,0.4221333265304565,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,123980.14164686204,134473.20946788788,123980.14164686204,10462.081463813782,16.189631462097168,0.0 -278800,3.1698236,1.1730396,,,,,,,,,,,,,, -278900,3.3907597,1.1636392,,,,,,,,,,,,,, -279000,3.411445,2.282794,,,,,,,,,,,,,, -279100,3.1800134,1.3271382,,,,,,,,,,,,,, -279200,3.1349692,1.5168705,,,,,,,,,,,,,, -279300,3.335175,2.448238,,,,,,,,,,,,,, -279400,3.1652474,1.156376,,,,,,,,,,,,,, -279500,3.23986,1.1789992,,,,,,,,,,,,,, -279600,3.741166,1.1357613,,,,,,,,,,,,,, -279700,3.3128955,1.1833961,,,,,,,,,,,,,, -279710,,,0.8892382383346558,0.4132129549980163,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,124400.19126391412,134929.40495729446,124400.19126391412,10498.109572172163,16.257033348083496,0.0 -279800,3.074061,2.714659,,,,,,,,,,,,,, -279900,4.02207,1.5299999,,,,,,,,,,,,,, -280000,3.5781302,3.1814594,,,,,,,,,,,,,, -280100,4.593884,3.0571494,,,,,,,,,,,,,, -280200,3.283141,1.1180911,,,,,,,,,,,,,, -280300,3.1856418,1.1245595,,,,,,,,,,,,,, -280400,4.435735,3.1597903,,,,,,,,,,,,,, -280500,3.16958,1.0252005,,,,,,,,,,,,,, -280600,3.092987,1.3607486,,,,,,,,,,,,,, -280653,,,0.8885741829872131,0.4136310815811157,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,124820.24197268486,135388.0551865101,124820.24197268486,10536.577868461609,16.33836817741394,0.0 -280700,2.9036484,1.0260563,,,,,,,,,,,,,, -280800,2.887172,1.1030036,,,,,,,,,,,,,, -280900,3.3656013,1.6078389,,,,,,,,,,,,,, -281000,3.1829154,1.332576,,,,,,,,,,,,,, -281100,3.2608461,1.1519405,,,,,,,,,,,,,, -281200,3.2319994,1.1977522,,,,,,,,,,,,,, -281300,2.9022112,1.2713931,,,,,,,,,,,,,, -281400,3.1445613,1.1544182,,,,,,,,,,,,,, -281500,3.2119567,1.0661852,,,,,,,,,,,,,, -281595,,,0.8893163800239563,0.4133542776107788,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,125240.18860721588,135845.2536330223,125240.18860721588,10573.71326136589,16.405822038650513,0.0 -281600,4.197893,3.1042075,,,,,,,,,,,,,, -281700,3.6883023,2.999973,,,,,,,,,,,,,, -281800,3.3765576,1.0936344,,,,,,,,,,,,,, -281900,3.1664116,1.4225771,,,,,,,,,,,,,, -282000,5.2780633,3.3221488,,,,,,,,,,,,,, -282100,6.1645923,2.9487836,,,,,,,,,,,,,, -282200,3.0341315,1.7629449,,,,,,,,,,,,,, -282300,3.004741,1.4881732,,,,,,,,,,,,,, -282400,3.0191953,1.1744852,,,,,,,,,,,,,, -282500,3.1861496,1.3151729,,,,,,,,,,,,,, -282542,,,0.8870312571525574,0.4205176830291748,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,125660.50177502632,136304.7267510891,125660.50177502632,10612.752341747284,16.47682547569275,0.0 -282600,2.9978156,1.0739505,,,,,,,,,,,,,, -282700,2.955722,1.8379178,,,,,,,,,,,,,, -282800,3.369957,1.2504407,,,,,,,,,,,,,, -282900,3.3135698,2.0433407,,,,,,,,,,,,,, -283000,3.1490276,1.0778426,,,,,,,,,,,,,, -283100,3.1238601,2.380292,,,,,,,,,,,,,, -283200,3.5074205,1.1127636,,,,,,,,,,,,,, -283300,3.1235573,1.8364139,,,,,,,,,,,,,, -283400,3.3416154,1.3829952,,,,,,,,,,,,,, -283487,,,0.8861523270606995,0.4218646585941314,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,126080.76248979568,136762.85723400116,126080.76248979568,10650.505030870438,16.544782161712646,0.0 -283500,2.8673637,1.4978805,,,,,,,,,,,,,, -283600,4.5880365,2.9666138,,,,,,,,,,,,,, -283700,3.4128428,3.1062243,,,,,,,,,,,,,, -283800,2.7576606,1.5814683,,,,,,,,,,,,,, -283900,3.0541816,1.4956154,,,,,,,,,,,,,, -284000,3.6278868,1.1894628,,,,,,,,,,,,,, -284100,3.1418395,1.4241788,,,,,,,,,,,,,, -284200,4.1163354,3.1712558,,,,,,,,,,,,,, -284300,3.1644585,1.3950728,,,,,,,,,,,,,, -284400,3.1613588,1.206281,,,,,,,,,,,,,, -284431,,,0.8878515362739563,0.419499933719635,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,126500.92922019958,137219.60481786728,126500.92922019958,10686.965690135956,16.61560034751892,0.0 -284500,3.2533193,1.1707089,,,,,,,,,,,,,, -284600,4.256232,3.2842088,,,,,,,,,,,,,, -284700,3.1123066,1.3266022,,,,,,,,,,,,,, -284800,3.2191765,2.6920576,,,,,,,,,,,,,, -284900,3.0956633,1.1078615,,,,,,,,,,,,,, -285000,3.2465346,2.4745274,,,,,,,,,,,,,, -285100,3.4153326,2.363243,,,,,,,,,,,,,, -285200,3.1269014,1.3075066,,,,,,,,,,,,,, -285300,3.1895442,2.4463758,,,,,,,,,,,,,, -285377,,,0.8895312547683716,0.4159430265426636,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,126920.93151402472,137673.84118700027,126920.93151402472,10721.06535768509,16.69960618019104,0.0 -285400,3.3338876,1.7751542,,,,,,,,,,,,,, -285500,3.425998,1.1456208,,,,,,,,,,,,,, -285600,3.0097816,1.3303318,,,,,,,,,,,,,, -285700,2.9202979,1.1581635,,,,,,,,,,,,,, -285800,2.9221628,1.6391516,,,,,,,,,,,,,, -285900,3.0854783,2.0752525,,,,,,,,,,,,,, -286000,3.5564504,1.3730106,,,,,,,,,,,,,, -286100,3.1616025,1.1058587,,,,,,,,,,,,,, -286200,3.18867,1.1776983,,,,,,,,,,,,,, -286300,3.536762,1.089099,,,,,,,,,,,,,, -286321,,,0.88818359375,0.4206604659557342,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,127340.82446813583,138132.89316129684,127340.82446813583,10760.106906414032,16.767632246017456,0.0 -286400,3.3913717,2.7294774,,,,,,,,,,,,,, -286500,3.0889287,1.3278418,,,,,,,,,,,,,, -286600,2.9376023,1.3508235,,,,,,,,,,,,,, -286700,2.8739178,1.1945704,,,,,,,,,,,,,, -286800,3.1287894,2.2285597,,,,,,,,,,,,,, -286900,3.1653967,1.0314064,,,,,,,,,,,,,, -287000,3.1998298,1.4495311,,,,,,,,,,,,,, -287100,3.6036139,2.9505436,,,,,,,,,,,,,, -287200,6.1066394,3.056198,,,,,,,,,,,,,, -287265,,,0.8854101300239563,0.4235857725143432,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,127761.01388812064,138590.194170475,127761.01388812064,10797.100801467896,16.835665464401245,0.0 -287300,3.1923008,1.1573733,,,,,,,,,,,,,, -287400,3.0842097,1.1026233,,,,,,,,,,,,,, -287500,2.9968991,1.0883988,,,,,,,,,,,,,, -287600,2.9118853,0.9997897,,,,,,,,,,,,,, -287700,3.1231313,1.1022961,,,,,,,,,,,,,, -287800,3.0430064,1.1376036,,,,,,,,,,,,,, -287900,3.4989126,1.0004054,,,,,,,,,,,,,, -288000,2.9430177,1.3007071,,,,,,,,,,,,,, -288100,3.3652887,1.2071421,,,,,,,,,,,,,, -288200,3.5959313,2.9365625,,,,,,,,,,,,,, -288208,,,0.8864648342132568,0.4208704531192779,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,128181.00546360016,139048.703029871,128181.00546360016,10835.234018087389,17.169645071029663,0.0 -288300,3.3649182,1.5469408,,,,,,,,,,,,,, -288400,2.9631839,1.039422,,,,,,,,,,,,,, -288500,3.153049,1.8258854,,,,,,,,,,,,,, -288600,2.9342027,1.6658957,,,,,,,,,,,,,, -288700,3.360071,1.3579254,,,,,,,,,,,,,, -288800,3.3246663,1.2458307,,,,,,,,,,,,,, -288900,2.8870213,1.1681609,,,,,,,,,,,,,, -289000,3.3556201,1.3576288,,,,,,,,,,,,,, -289100,3.0311906,2.366041,,,,,,,,,,,,,, -289147,,,0.8867773413658142,0.4219215214252472,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,128601.05171775818,139507.9273519516,128601.05171775818,10874.292521953585,17.23930835723877,0.0 -289200,3.2801886,1.1398313,,,,,,,,,,,,,, -289300,3.3393197,1.1905766,,,,,,,,,,,,,, -289400,3.6478167,3.0480628,,,,,,,,,,,,,, -289500,3.2272048,2.6953716,,,,,,,,,,,,,, -289600,3.2785392,2.5922618,,,,,,,,,,,,,, -289700,3.2088838,2.9213765,,,,,,,,,,,,,, -289800,3.9072306,2.7798033,,,,,,,,,,,,,, -289900,2.9990833,1.8910956,,,,,,,,,,,,,, -290000,3.5606527,2.9082162,,,,,,,,,,,,,, -290090,,,0.88734370470047,0.412392109632492,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,129021.26099348068,139964.28707146645,129021.26099348068,10910.316421985626,17.317003965377808,0.0 -290100,3.8602254,3.236624,,,,,,,,,,,,,, -290200,4.0853724,3.066223,,,,,,,,,,,,,, -290300,3.178275,1.3687042,,,,,,,,,,,,,, -290400,3.1486788,1.7630684,,,,,,,,,,,,,, -290500,2.9287493,1.4998131,,,,,,,,,,,,,, -290600,3.2464974,1.064208,,,,,,,,,,,,,, -290700,3.0929027,1.0648282,,,,,,,,,,,,,, -290800,3.4744542,2.7885761,,,,,,,,,,,,,, -290900,3.1409352,1.227916,,,,,,,,,,,,,, -291000,3.2056074,2.8663316,,,,,,,,,,,,,, -291033,,,0.8869140148162842,0.4278862178325653,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,129441.24899697304,140420.52204823494,129441.24899697304,10946.443460464478,17.38748025894165,0.0 -291100,3.0162778,1.1717689,,,,,,,,,,,,,, -291200,3.658322,1.2307339,,,,,,,,,,,,,, -291300,3.167921,1.1868976,,,,,,,,,,,,,, -291400,4.0354633,3.1799064,,,,,,,,,,,,,, -291500,3.8012884,3.103116,,,,,,,,,,,,,, -291600,3.0464032,1.188974,,,,,,,,,,,,,, -291700,3.1146786,1.8884951,,,,,,,,,,,,,, -291800,3.106854,1.72568,,,,,,,,,,,,,, -291900,2.972577,2.2366056,,,,,,,,,,,,,, -291975,,,0.885058581829071,0.4274966716766357,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,129861.1802175045,140877.92724633217,129861.1802175045,10983.784052610396,17.47106671333313,0.0 -292000,3.1594038,1.779259,,,,,,,,,,,,,, -292100,3.069552,1.8051875,,,,,,,,,,,,,, -292200,3.1340435,1.2710395,,,,,,,,,,,,,, -292300,2.9626842,1.1093447,,,,,,,,,,,,,, -292400,2.9200366,1.5393076,,,,,,,,,,,,,, -292500,3.0881772,1.8136351,,,,,,,,,,,,,, -292600,3.054713,2.4234738,,,,,,,,,,,,,, -292700,3.1370025,1.1093125,,,,,,,,,,,,,, -292800,4.340344,3.186328,,,,,,,,,,,,,, -292900,3.1977208,1.0869598,,,,,,,,,,,,,, -292918,,,0.8882812261581421,0.414492130279541,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,130281.11374497414,141332.45461821556,130281.11374497414,11018.251610517502,17.548232555389404,0.0 -293000,3.185811,2.4327583,,,,,,,,,,,,,, -293100,3.2166636,1.0086484,,,,,,,,,,,,,, -293200,3.017212,1.1400177,,,,,,,,,,,,,, -293300,3.8606825,3.3870685,,,,,,,,,,,,,, -293400,2.9774342,1.2120799,,,,,,,,,,,,,, -293500,3.260354,2.8007264,,,,,,,,,,,,,, -293600,3.8880172,3.2882025,,,,,,,,,,,,,, -293700,2.995281,1.2868481,,,,,,,,,,,,,, -293800,3.109929,1.0084261,,,,,,,,,,,,,, -293862,,,0.8871679306030273,0.4163171350955963,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,130701.35450816154,141789.79231905937,130701.35450816154,11055.229821681976,17.617226600646973,0.0 -293900,3.2721987,2.6949613,,,,,,,,,,,,,, -294000,3.4932475,3.128251,,,,,,,,,,,,,, -294100,3.035153,1.1130803,,,,,,,,,,,,,, -294200,3.113662,1.148754,,,,,,,,,,,,,, -294300,3.8627155,3.2372127,,,,,,,,,,,,,, -294400,3.4444158,1.1562713,,,,,,,,,,,,,, -294500,3.1216066,1.1879113,,,,,,,,,,,,,, -294600,2.8378336,1.1545209,,,,,,,,,,,,,, -294700,3.2060428,1.1729476,,,,,,,,,,,,,, -294800,2.9807792,1.9341276,,,,,,,,,,,,,, -294807,,,0.8866991996765137,0.425510048866272,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,131121.62486243248,142246.22771835327,131121.62486243248,11091.266440868378,17.695829153060913,0.0 -294900,3.3787026,2.647873,,,,,,,,,,,,,, -295000,3.0648155,1.6319821,,,,,,,,,,,,,, -295100,3.0778382,1.610341,,,,,,,,,,,,,, -295200,3.05753,1.2427917,,,,,,,,,,,,,, -295300,3.3661325,1.3612503,,,,,,,,,,,,,, -295400,3.053803,1.2368206,,,,,,,,,,,,,, -295500,3.1047888,1.1246938,,,,,,,,,,,,,, -295600,3.5834093,1.3442646,,,,,,,,,,,,,, -295700,3.047438,1.6307734,,,,,,,,,,,,,, -295748,,,0.8882812261581421,0.4199662804603576,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,131541.72653746605,142702.16307020187,131541.72653746605,11126.980338573456,17.765344858169556,0.0 -295800,3.3725762,1.587095,,,,,,,,,,,,,, -295900,3.5223544,2.7611985,,,,,,,,,,,,,, -296000,3.1790323,1.312984,,,,,,,,,,,,,, -296100,3.0041127,1.5380278,,,,,,,,,,,,,, -296200,3.1059709,1.1084471,,,,,,,,,,,,,, -296300,2.860492,1.0945637,,,,,,,,,,,,,, -296400,3.573648,3.0350747,,,,,,,,,,,,,, -296500,3.3045602,2.0257382,,,,,,,,,,,,,, -296600,3.0649958,1.5392717,,,,,,,,,,,,,, -296690,,,0.8873828053474426,0.4187077581882477,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,131961.94080877304,143160.52719926834,131961.94080877304,11165.005978107452,17.840155839920044,0.0 -296700,4.031336,3.1567352,,,,,,,,,,,,,, -296800,3.9027576,3.236793,,,,,,,,,,,,,, -296900,3.9677534,3.1327195,,,,,,,,,,,,,, -297000,2.9397454,1.1097732,,,,,,,,,,,,,, -297100,2.9745884,1.446747,,,,,,,,,,,,,, -297200,3.2109993,1.8947555,,,,,,,,,,,,,, -297300,3.0014377,1.1813501,,,,,,,,,,,,,, -297400,3.4459538,2.6478581,,,,,,,,,,,,,, -297500,3.143827,2.335251,,,,,,,,,,,,,, -297600,3.1694674,1.1648573,,,,,,,,,,,,,, -297636,,,0.8857226371765137,0.4250858128070831,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,132382.2102355957,143615.20072078705,132382.2102355957,11199.285538434982,17.914023637771606,0.0 -297700,3.0673459,1.1253408,,,,,,,,,,,,,, -297800,4.334131,1.1372998,,,,,,,,,,,,,, -297900,3.5566425,2.5925512,,,,,,,,,,,,,, -298000,3.070563,1.2095277,,,,,,,,,,,,,, -298100,3.1599772,1.6706822,,,,,,,,,,,,,, -298200,3.3149986,1.0802572,,,,,,,,,,,,,, -298300,3.1048908,1.1777012,,,,,,,,,,,,,, -298400,3.3350341,1.1655204,,,,,,,,,,,,,, -298500,3.2536955,1.3089217,,,,,,,,,,,,,, -298580,,,0.8873046636581421,0.4215691983699798,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,132802.52157902718,144074.85414624214,132802.52157902718,11238.506650447844,17.98520803451538,0.0 -298600,3.1767616,1.5199888,,,,,,,,,,,,,, -298700,3.1363428,1.090608,,,,,,,,,,,,,, -298800,4.4602985,3.2159183,,,,,,,,,,,,,, -298900,3.0172691,1.0851955,,,,,,,,,,,,,, -299000,2.9940376,1.0394137,,,,,,,,,,,,,, -299100,3.608508,3.1373646,,,,,,,,,,,,,, -299200,3.610658,2.7624753,,,,,,,,,,,,,, -299300,3.4795842,3.0429418,,,,,,,,,,,,,, -299400,2.789141,2.2017038,,,,,,,,,,,,,, -299500,4.463271,3.1458502,,,,,,,,,,,,,, -299522,,,0.8869335651397705,0.4194314777851105,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,133222.74796295166,144530.4835705757,133222.74796295166,11273.791742801666,18.054905891418457,0.0 -299600,3.2647853,1.8969132,,,,,,,,,,,,,, -299700,2.8171432,1.3783416,,,,,,,,,,,,,, -299800,3.0919254,1.9092999,,,,,,,,,,,,,, -299900,3.5143511,1.2304347,,,,,,,,,,,,,, -300000,2.876549,1.1046833,,,,,,,,,,,,,, -300100,2.9553828,1.1083814,,,,,,,,,,,,,, -300200,3.8501778,3.110172,,,,,,,,,,,,,, -300300,3.1412845,2.2647324,,,,,,,,,,,,,, -300400,2.860561,1.0669007,,,,,,,,,,,,,, -300467,,,0.8865820169448853,0.4196713864803314,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,133642.82978200912,144987.6792693138,133642.82978200912,11310.786287784576,18.12456536293029,0.0 -300500,3.094143,1.433828,,,,,,,,,,,,,, -300600,3.3575716,2.868274,,,,,,,,,,,,,, -300700,3.6468418,2.0833542,,,,,,,,,,,,,, -300800,5.0639286,3.207646,,,,,,,,,,,,,, -300900,3.1018353,2.086527,,,,,,,,,,,,,, -301000,3.0463514,1.3891436,,,,,,,,,,,,,, -301100,3.0463493,2.6494112,,,,,,,,,,,,,, -301200,3.8581598,3.1089277,,,,,,,,,,,,,, -301300,3.8255196,2.9475918,,,,,,,,,,,,,, -301400,3.303901,1.1354741,,,,,,,,,,,,,, -301412,,,0.8876953125,0.4157899022102356,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,134063.15568971634,145443.7553141117,134063.15568971634,11346.413735866548,18.19742178916931,0.0 -301500,2.9740694,0.9990355,,,,,,,,,,,,,, -301600,3.229054,1.2471663,,,,,,,,,,,,,, -301700,3.329131,1.1561066,,,,,,,,,,,,,, -301800,3.4276407,2.7426956,,,,,,,,,,,,,, -301900,3.4478388,2.9075162,,,,,,,,,,,,,, -302000,4.0251575,3.177255,,,,,,,,,,,,,, -302100,3.300521,1.3437059,,,,,,,,,,,,,, -302200,3.0472317,1.1193007,,,,,,,,,,,,,, -302300,3.2596457,1.5133805,,,,,,,,,,,,,, -302358,,,0.8873242139816284,0.4200023412704468,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,134483.15005874634,145901.2564651966,134483.15005874634,11383.800921678543,18.267784357070923,0.0 -302400,3.0432923,1.0144207,,,,,,,,,,,,,, -302500,2.9561136,1.4029453,,,,,,,,,,,,,, -302600,2.785232,1.2280111,,,,,,,,,,,,,, -302700,3.1351964,1.0566887,,,,,,,,,,,,,, -302800,3.1619947,1.8860649,,,,,,,,,,,,,, -302900,2.974323,1.0921729,,,,,,,,,,,,,, -303000,3.0082922,1.952157,,,,,,,,,,,,,, -303100,2.9663575,1.1273266,,,,,,,,,,,,,, -303200,3.2302318,1.5158125,,,,,,,,,,,,,, -303300,3.4549892,2.8917205,,,,,,,,,,,,,, -303303,,,0.8898828029632568,0.4121365249156952,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,134903.26118922234,146357.46699857712,134903.26118922234,11419.779727935793,18.338664531707764,0.0 -303400,2.8365893,1.210068,,,,,,,,,,,,,, -303500,3.16079,2.7519221,,,,,,,,,,,,,, -303600,2.7705116,1.244107,,,,,,,,,,,,,, -303700,3.8167152,1.1417619,,,,,,,,,,,,,, -303800,2.8606846,1.2184167,,,,,,,,,,,,,, -303900,3.3731177,1.0811074,,,,,,,,,,,,,, -304000,3.077493,1.1266441,,,,,,,,,,,,,, -304100,3.1462872,2.773396,,,,,,,,,,,,,, -304200,4.9266515,3.3673966,,,,,,,,,,,,,, -304248,,,0.8899218440055847,0.4108898341655731,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,135323.39578986168,146816.15788459778,135323.39578986168,11458.216473817823,18.40903115272522,0.0 -304300,3.2930021,1.175738,,,,,,,,,,,,,, -304400,3.1873388,1.3033669,,,,,,,,,,,,,, -304500,3.2724614,1.3058851,,,,,,,,,,,,,, -304600,4.705662,2.63384,,,,,,,,,,,,,, -304700,3.2641754,1.1194319,,,,,,,,,,,,,, -304800,3.212752,1.2598191,,,,,,,,,,,,,, -304900,2.9898145,2.3284218,,,,,,,,,,,,,, -305000,3.2645795,1.1959063,,,,,,,,,,,,,, -305100,3.2913275,1.0501453,,,,,,,,,,,,,, -305193,,,0.8875976204872131,0.416908711194992,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,135743.32669734955,147270.67830872536,135743.32669734955,11492.684744119644,18.48048448562622,0.0 -305200,3.1421878,1.4470893,,,,,,,,,,,,,, -305300,2.995133,1.9932137,,,,,,,,,,,,,, -305400,3.50895,2.9908926,,,,,,,,,,,,,, -305500,3.0923333,2.4426641,,,,,,,,,,,,,, -305600,3.1995418,1.0652487,,,,,,,,,,,,,, -305700,6.818309,3.2394087,,,,,,,,,,,,,, -305800,3.1208265,1.5714834,,,,,,,,,,,,,, -305900,3.6182492,1.0105791,,,,,,,,,,,,,, -306000,2.9534757,1.0522038,,,,,,,,,,,,,, -306100,3.0740285,1.1521671,,,,,,,,,,,,,, -306138,,,0.8876953125,0.4212630093097687,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,136163.254529953,147727.958391428,136163.254529953,11529.90760731697,18.560980796813965,0.0 -306200,3.2634315,1.1738906,,,,,,,,,,,,,, -306300,3.1761823,1.5860406,,,,,,,,,,,,,, -306400,2.9149692,2.3588119,,,,,,,,,,,,,, -306500,2.9559267,0.99554414,,,,,,,,,,,,,, -306600,3.1102085,1.8887558,,,,,,,,,,,,,, -306700,3.43346,1.8106807,,,,,,,,,,,,,, -306800,3.2174954,1.1228336,,,,,,,,,,,,,, -306900,2.9769623,1.0150064,,,,,,,,,,,,,, -307000,3.88474,3.1899884,,,,,,,,,,,,,, -307081,,,0.8867382407188416,0.4187721908092499,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,136583.568918705,148183.63747787476,136583.568918705,11565.154380083084,18.63020730018616,0.0 -307100,3.6302178,3.2782779,,,,,,,,,,,,,, -307200,3.5509305,1.0532683,,,,,,,,,,,,,, -307300,3.1619766,1.1427791,,,,,,,,,,,,,, -307400,2.9932816,1.1208547,,,,,,,,,,,,,, -307500,3.149191,1.9112507,,,,,,,,,,,,,, -307600,3.2606983,1.2146662,,,,,,,,,,,,,, -307700,3.3734403,2.8383727,,,,,,,,,,,,,, -307800,2.9183612,1.0728512,,,,,,,,,,,,,, -307900,3.1353586,1.1578715,,,,,,,,,,,,,, -308000,3.686894,2.194691,,,,,,,,,,,,,, -308026,,,0.8872265219688416,0.4194600582122803,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,137003.82122135162,148641.28811311722,137003.82122135162,11602.423706293106,18.709379196166992,0.0 -308100,3.0127897,1.4254534,,,,,,,,,,,,,, -308200,3.6128473,1.0835786,,,,,,,,,,,,,, -308300,2.8731341,1.0863943,,,,,,,,,,,,,, -308400,3.1007378,1.1728827,,,,,,,,,,,,,, -308500,2.9285998,1.0342557,,,,,,,,,,,,,, -308600,3.0925593,1.1920794,,,,,,,,,,,,,, -308700,3.299487,1.0847915,,,,,,,,,,,,,, -308800,3.5485744,1.1393943,,,,,,,,,,,,,, -308900,4.0395594,2.7366621,,,,,,,,,,,,,, -308970,,,0.8910546898841858,0.4130688905715942,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,137423.9591574669,149096.95966887474,137423.9591574669,11637.83688402176,18.780176162719727,0.0 -309000,3.2809713,1.2274511,,,,,,,,,,,,,, -309100,3.3778713,1.1726601,,,,,,,,,,,,,, -309200,3.5904312,1.0931283,,,,,,,,,,,,,, -309300,3.1538925,1.0855644,,,,,,,,,,,,,, -309400,2.860646,1.4394977,,,,,,,,,,,,,, -309500,3.357459,0.98562163,,,,,,,,,,,,,, -309600,2.823986,1.3197217,,,,,,,,,,,,,, -309700,3.1020036,1.1257657,,,,,,,,,,,,,, -309800,3.255395,1.1501018,,,,,,,,,,,,,, -309900,2.856331,2.0211306,,,,,,,,,,,,,, -309911,,,0.8875585794448853,0.4212709665298462,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,137844.01475334167,149553.27523708344,137844.01475334167,11673.974525928495,18.853289127349854,0.0 -310000,3.0992818,1.1244823,,,,,,,,,,,,,, -310100,3.4184003,1.6770183,,,,,,,,,,,,,, -310200,3.4056845,1.5645987,,,,,,,,,,,,,, -310300,2.989513,1.1486481,,,,,,,,,,,,,, -310400,3.345209,1.1280277,,,,,,,,,,,,,, -310500,2.9419725,1.0821297,,,,,,,,,,,,,, -310600,3.0888417,1.3782477,,,,,,,,,,,,,, -310700,3.135766,1.0839642,,,,,,,,,,,,,, -310800,2.9469082,1.048965,,,,,,,,,,,,,, -310856,,,0.8844921588897705,0.429676204919815,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,138264.2561571598,150009.64900398254,138264.2561571598,11709.985116958618,18.925705671310425,0.0 -310900,3.0205395,1.2430006,,,,,,,,,,,,,, -311000,3.2396443,1.2420537,,,,,,,,,,,,,, -311100,2.8978891,1.0713775,,,,,,,,,,,,,, -311200,3.2812583,1.2292634,,,,,,,,,,,,,, -311300,3.16364,1.3307924,,,,,,,,,,,,,, -311400,3.1295488,1.140305,,,,,,,,,,,,,, -311500,3.2717633,1.5572821,,,,,,,,,,,,,, -311600,3.0566142,1.6602961,,,,,,,,,,,,,, -311700,3.372802,1.2058654,,,,,,,,,,,,,, -311800,,,0.88685542345047,0.4198791682720184,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,138684.43855118752,150468.6958978176,138684.43855118752,11748.723598957062,19.001494646072388,0.0 -311800,4.2607675,3.2579916,,,,,,,,,,,,,, -311900,3.249944,1.0891457,,,,,,,,,,,,,, -312000,4.3734255,3.2064753,,,,,,,,,,,,,, -312100,3.1990745,1.0977232,,,,,,,,,,,,,, -312200,3.4801924,2.2514343,,,,,,,,,,,,,, -312300,3.212464,1.23779,,,,,,,,,,,,,, -312400,4.4272656,1.1569217,,,,,,,,,,,,,, -312500,3.1552455,1.1177324,,,,,,,,,,,,,, -312600,2.9954627,1.2758878,,,,,,,,,,,,,, -312700,3.1722803,1.1547734,,,,,,,,,,,,,, -312743,,,0.8872265219688416,0.4195569157600403,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,139104.31449127197,150926.6590127945,139104.31449127197,11786.684258937836,19.078715562820435,0.0 -312800,2.9915833,1.104856,,,,,,,,,,,,,, -312900,3.1163535,0.9887322,,,,,,,,,,,,,, -313000,3.1068459,1.7368348,,,,,,,,,,,,,, -313100,3.2367754,1.7193086,,,,,,,,,,,,,, -313200,3.3907766,1.3996066,,,,,,,,,,,,,, -313300,3.1645954,1.1966492,,,,,,,,,,,,,, -313400,3.146489,1.7897557,,,,,,,,,,,,,, -313500,4.330998,3.3280795,,,,,,,,,,,,,, -313600,3.144003,2.476811,,,,,,,,,,,,,, -313689,,,0.8866406083106995,0.4205119907855987,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,139524.47371315956,151385.07851743698,139524.47371315956,11824.821381092072,19.15205764770508,0.0 -313700,2.9094138,1.0732093,,,,,,,,,,,,,, -313800,3.0989082,2.0060031,,,,,,,,,,,,,, -313900,3.0627186,1.132733,,,,,,,,,,,,,, -314000,3.307283,2.6840985,,,,,,,,,,,,,, -314100,3.3884187,1.0673748,,,,,,,,,,,,,, -314200,3.0757256,2.351813,,,,,,,,,,,,,, -314300,3.4342835,1.0921093,,,,,,,,,,,,,, -314400,2.9155743,1.5861483,,,,,,,,,,,,,, -314500,3.0614855,1.1942863,,,,,,,,,,,,,, -314600,3.3641653,2.8131287,,,,,,,,,,,,,, -314633,,,0.8891210556030273,0.4166163206100464,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,139944.40990495682,151842.04384183884,139944.40990495682,11861.724688053131,19.22867512702942,0.0 -314700,3.1460693,1.1128436,,,,,,,,,,,,,, -314800,2.891816,1.2187705,,,,,,,,,,,,,, -314900,3.046408,1.5650499,,,,,,,,,,,,,, -315000,3.0758088,1.5733413,,,,,,,,,,,,,, -315100,2.9713438,1.7872599,,,,,,,,,,,,,, -315200,4.0578475,1.9156436,,,,,,,,,,,,,, -315300,3.1965315,2.3186378,,,,,,,,,,,,,, -315400,3.4753284,2.7204192,,,,,,,,,,,,,, -315500,3.689447,3.2373524,,,,,,,,,,,,,, -315578,,,0.8840429782867432,0.4275960028171539,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,140364.45055365562,152298.446028471,140364.45055365562,11897.955362796783,19.30998063087464,0.0 -315600,3.126359,0.9871572,,,,,,,,,,,,,, -315700,2.9941902,1.593129,,,,,,,,,,,,,, -315800,3.0423765,1.5813704,,,,,,,,,,,,,, -315900,3.0529544,1.162721,,,,,,,,,,,,,, -316000,3.0155299,1.4833871,,,,,,,,,,,,,, -316100,3.0335355,1.1493466,,,,,,,,,,,,,, -316200,3.41086,1.0507325,,,,,,,,,,,,,, -316300,3.3398767,1.1555002,,,,,,,,,,,,,, -316400,3.3062565,1.147174,,,,,,,,,,,,,, -316500,3.019374,1.1893929,,,,,,,,,,,,,, -316522,,,0.8870507478713989,0.4170957505702972,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,140784.50740408897,152756.90191817284,140784.50740408897,11936.230577230452,19.38427710533142,0.0 -316600,3.3302863,1.36446,,,,,,,,,,,,,, -316700,3.5708752,1.1545943,,,,,,,,,,,,,, -316800,3.008889,0.98207784,,,,,,,,,,,,,, -316900,3.0376604,1.0395489,,,,,,,,,,,,,, -317000,3.1665187,1.068839,,,,,,,,,,,,,, -317100,3.1221728,1.129601,,,,,,,,,,,,,, -317200,3.0868716,1.0235611,,,,,,,,,,,,,, -317300,3.45385,2.4550571,,,,,,,,,,,,,, -317400,3.174571,1.6041794,,,,,,,,,,,,,, -317468,,,0.8878124952316284,0.4161619842052459,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,141204.74587631226,153212.0466427803,141204.74587631226,11971.0127120018,19.45900011062622,0.0 -317500,3.6428676,2.2892318,,,,,,,,,,,,,, -317600,3.874263,1.171318,,,,,,,,,,,,,, -317700,2.9712462,1.081099,,,,,,,,,,,,,, -317800,3.2976604,2.374495,,,,,,,,,,,,,, -317900,2.937365,1.1360848,,,,,,,,,,,,,, -318000,3.071491,1.1268653,,,,,,,,,,,,,, -318100,3.149355,1.3099048,,,,,,,,,,,,,, -318200,4.2299333,3.321069,,,,,,,,,,,,,, -318300,2.9495025,1.4906938,,,,,,,,,,,,,, -318400,2.9576166,1.7232068,,,,,,,,,,,,,, -318414,,,0.8857226371765137,0.4262997210025787,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,141624.99427366257,153669.32224822044,141624.99427366257,12007.916976690292,19.53209972381592,0.0 -318500,3.2046463,1.1345261,,,,,,,,,,,,,, -318600,3.9018772,2.72113,,,,,,,,,,,,,, -318700,3.1680033,2.2490718,,,,,,,,,,,,,, -318800,3.3798711,1.3437657,,,,,,,,,,,,,, -318900,3.4813588,1.2906227,,,,,,,,,,,,,, -319000,3.036802,1.6750731,,,,,,,,,,,,,, -319100,3.228801,1.39682,,,,,,,,,,,,,, -319200,3.1107051,1.215749,,,,,,,,,,,,,, -319300,5.2730327,3.1874518,,,,,,,,,,,,,, -319359,,,0.8882616758346558,0.4221481680870056,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,142045.31981515884,154127.99576354027,142045.31981515884,12046.129980802536,19.61472868919373,0.0 -319400,3.0431023,1.7992924,,,,,,,,,,,,,, -319500,3.4632413,2.079428,,,,,,,,,,,,,, -319600,3.0942898,1.0426282,,,,,,,,,,,,,, -319700,3.1595695,2.122028,,,,,,,,,,,,,, -319800,2.9160383,1.5136092,,,,,,,,,,,,,, -319900,3.0225327,1.3025041,,,,,,,,,,,,,, -320000,2.8898222,1.1238003,,,,,,,,,,,,,, -320100,2.905463,1.1175417,,,,,,,,,,,,,, -320200,3.1922996,2.4265692,,,,,,,,,,,,,, -320300,2.922557,2.4714265,,,,,,,,,,,,,, -320301,,,0.8875781297683716,0.4201326370239258,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,142465.31846928596,154583.66032719612,142465.31846928596,12081.67114663124,19.68934178352356,0.0 -320400,3.2125416,1.1758025,,,,,,,,,,,,,, -320500,3.0320854,1.4069769,,,,,,,,,,,,,, -320600,3.3472862,1.3327936,,,,,,,,,,,,,, -320700,3.952887,1.1414392,,,,,,,,,,,,,, -320800,3.3765635,1.1912928,,,,,,,,,,,,,, -320900,3.066971,1.6290963,,,,,,,,,,,,,, -321000,3.3559945,1.1210382,,,,,,,,,,,,,, -321100,3.06214,1.7863392,,,,,,,,,,,,,, -321200,3.9272072,1.7634591,,,,,,,,,,,,,, -321247,,,0.8852929472923279,0.4266625940799713,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,142885.55553865433,155039.99154281616,142885.55553865433,12117.643364191055,19.76137471199036,0.0 -321300,3.3421152,2.1867006,,,,,,,,,,,,,, -321400,3.053651,1.0922383,,,,,,,,,,,,,, -321500,3.175999,1.489067,,,,,,,,,,,,,, -321600,3.0971043,1.0435548,,,,,,,,,,,,,, -321700,3.038055,1.0560013,,,,,,,,,,,,,, -321800,3.6673014,3.205272,,,,,,,,,,,,,, -321900,2.9469235,1.0955081,,,,,,,,,,,,,, -322000,3.7155864,3.1435592,,,,,,,,,,,,,, -322100,3.2307775,1.6891427,,,,,,,,,,,,,, -322195,,,0.8888476490974426,0.4134816527366638,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,143305.7799217701,155496.300719738,143305.7799217701,12153.602460861206,19.83621001243592,0.0 -322200,3.3680506,1.1354463,,,,,,,,,,,,,, -322300,4.114134,2.6918848,,,,,,,,,,,,,, -322400,4.44933,1.8862426,,,,,,,,,,,,,, -322500,4.222453,3.1667247,,,,,,,,,,,,,, -322600,3.0004535,1.1711049,,,,,,,,,,,,,, -322700,3.1431754,1.2004958,,,,,,,,,,,,,, -322800,3.1402237,1.1475961,,,,,,,,,,,,,, -322900,2.9841998,1.1490566,,,,,,,,,,,,,, -323000,5.5690303,1.2832183,,,,,,,,,,,,,, -323100,3.0155056,1.2653279,,,,,,,,,,,,,, -323138,,,0.8867773413658142,0.4213497340679168,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,143725.90061926842,155955.61283946037,143725.90061926842,12192.668316364288,19.912522077560425,0.0 -323200,10.012027,3.1255736,,,,,,,,,,,,,, -323300,2.9976535,1.1571289,,,,,,,,,,,,,, -323400,3.1854587,1.6030542,,,,,,,,,,,,,, -323500,3.1544569,1.1758717,,,,,,,,,,,,,, -323600,2.8867874,1.4154854,,,,,,,,,,,,,, -323700,3.2200997,2.419021,,,,,,,,,,,,,, -323800,3.2874658,2.432376,,,,,,,,,,,,,, -323900,2.7709036,1.5050659,,,,,,,,,,,,,, -324000,3.373591,1.1013666,,,,,,,,,,,,,, -324081,,,0.8880664110183716,0.4195075631141662,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,144145.90799617767,156412.74842858317,144145.90799617767,12229.671454191208,19.987329483032227,0.0 -324100,3.062807,1.1154003,,,,,,,,,,,,,, -324200,4.2385063,2.6301117,,,,,,,,,,,,,, -324300,3.4547718,2.6747856,,,,,,,,,,,,,, -324400,2.924909,1.0499427,,,,,,,,,,,,,, -324500,3.4800751,2.8601873,,,,,,,,,,,,,, -324600,3.5045428,2.1532707,,,,,,,,,,,,,, -324700,2.9601648,2.1728928,,,,,,,,,,,,,, -324800,2.862882,1.7247578,,,,,,,,,,,,,, -324900,2.9562762,1.2454822,,,,,,,,,,,,,, -325000,3.0172064,1.533119,,,,,,,,,,,,,, -325027,,,0.8875976204872131,0.4145772755146026,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,144566.0894215107,156869.671346426,144566.0894215107,12266.286154031754,20.06376004219055,0.0 -325100,3.1581333,1.4116329,,,,,,,,,,,,,, -325200,4.3030396,3.3566341,,,,,,,,,,,,,, -325300,3.2820415,1.3166472,,,,,,,,,,,,,, -325400,3.0489352,1.1200128,,,,,,,,,,,,,, -325500,3.2489035,1.2652974,,,,,,,,,,,,,, -325600,3.42479,1.1137419,,,,,,,,,,,,,, -325700,3.1147492,1.1408515,,,,,,,,,,,,,, -325800,3.1057842,1.705596,,,,,,,,,,,,,, -325900,3.001419,1.0901401,,,,,,,,,,,,,, -325972,,,0.8864648342132568,0.4213025867938995,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,144986.1606106758,157327.93233895302,144986.1606106758,12304.35082411766,20.1387312412262,0.0 -326000,3.0257208,1.2130883,,,,,,,,,,,,,, -326100,3.178048,1.5504326,,,,,,,,,,,,,, -326200,3.4813216,1.2333215,,,,,,,,,,,,,, -326300,3.5274503,2.6121469,,,,,,,,,,,,,, -326400,3.5424614,2.382727,,,,,,,,,,,,,, -326500,3.4841492,2.9279525,,,,,,,,,,,,,, -326600,2.8552084,1.1138648,,,,,,,,,,,,,, -326700,3.3815112,2.7597182,,,,,,,,,,,,,, -326800,3.192667,1.0593716,,,,,,,,,,,,,, -326900,3.2819078,1.2240578,,,,,,,,,,,,,, -326914,,,0.8908789157867432,0.4091778099536896,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,145406.35031175613,157786.98872613907,145406.35031175613,12343.091826438904,20.214343786239624,0.0 -327000,3.088183,1.2979472,,,,,,,,,,,,,, -327100,3.3086317,1.1512824,,,,,,,,,,,,,, -327200,7.6700096,3.215658,,,,,,,,,,,,,, -327300,3.2186599,1.7854975,,,,,,,,,,,,,, -327400,3.2394192,1.1309351,,,,,,,,,,,,,, -327500,2.8598433,1.2447363,,,,,,,,,,,,,, -327600,3.1586723,1.1833808,,,,,,,,,,,,,, -327700,2.9638696,1.7521923,,,,,,,,,,,,,, -327800,3.1666744,1.1246471,,,,,,,,,,,,,, -327857,,,0.8892382383346558,0.4134601950645447,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,145826.32364201546,158246.08665442467,145826.32364201546,12382.091963529589,20.28935670852661,0.0 -327900,5.0291104,1.6335843,,,,,,,,,,,,,, -328000,3.069555,1.0765308,,,,,,,,,,,,,, -328100,3.4382684,2.4462833,,,,,,,,,,,,,, -328200,2.9961674,1.3986913,,,,,,,,,,,,,, -328300,3.2204828,1.1695077,,,,,,,,,,,,,, -328400,2.9993489,1.9665211,,,,,,,,,,,,,, -328500,6.3312893,3.3330348,,,,,,,,,,,,,, -328600,3.2510629,1.2981179,,,,,,,,,,,,,, -328700,3.2761407,1.4891633,,,,,,,,,,,,,, -328800,3.2965286,1.0662279,,,,,,,,,,,,,, -328801,,,0.887499988079071,0.4138509631156921,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,146246.33985328674,158702.5681695938,146246.33985328674,12418.419880151749,20.377084732055664,0.0 -328900,3.2695298,1.0485675,,,,,,,,,,,,,, -329000,3.2260842,1.8273888,,,,,,,,,,,,,, -329100,3.3156245,1.1313313,,,,,,,,,,,,,, -329200,3.0403674,1.0391169,,,,,,,,,,,,,, -329300,3.0517688,2.5567675,,,,,,,,,,,,,, -329400,3.2649486,1.1120015,,,,,,,,,,,,,, -329500,3.3533585,1.190009,,,,,,,,,,,,,, -329600,4.194572,1.4740502,,,,,,,,,,,,,, -329700,3.4218607,1.2563313,,,,,,,,,,,,,, -329746,,,0.8868359327316284,0.4221839010715484,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,146666.76261496544,159159.97850561142,146666.76261496544,12455.2824716568,20.45138263702393,0.0 -329800,3.498819,2.6670825,,,,,,,,,,,,,, -329900,3.3041472,2.0674036,,,,,,,,,,,,,, -330000,3.0542452,1.0530764,,,,,,,,,,,,,, -330100,3.3901184,1.2252625,,,,,,,,,,,,,, -330200,2.9539647,1.6607744,,,,,,,,,,,,,, -330300,3.6604843,1.3801428,,,,,,,,,,,,,, -330400,3.0032673,1.6251097,,,,,,,,,,,,,, -330500,3.2807562,2.660686,,,,,,,,,,,,,, -330600,3.1358414,1.5183642,,,,,,,,,,,,,, -330690,,,0.8869531154632568,0.4258280098438263,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,147086.9284837246,159616.47861504555,147086.9284837246,12491.49240732193,20.52607488632202,0.0 -330700,3.2444212,1.2051678,,,,,,,,,,,,,, -330800,3.3665354,1.1012007,,,,,,,,,,,,,, -330900,3.4683697,1.9520317,,,,,,,,,,,,,, -331000,3.7092547,2.8159745,,,,,,,,,,,,,, -331100,5.0922194,1.2032888,,,,,,,,,,,,,, -331200,2.887637,1.130169,,,,,,,,,,,,,, -331300,3.4343796,1.1308566,,,,,,,,,,,,,, -331400,3.3483539,1.1151187,,,,,,,,,,,,,, -331500,3.298055,2.4871194,,,,,,,,,,,,,, -331600,3.927473,3.0995069,,,,,,,,,,,,,, -331633,,,0.8866601586341858,0.4206277728080749,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,147506.69037270546,160074.74869942665,147506.69037270546,12529.59518647194,20.882078886032104,0.0 -331700,3.0784628,2.0308397,,,,,,,,,,,,,, -331800,4.039121,1.1606368,,,,,,,,,,,,,, -331900,3.238691,1.5120152,,,,,,,,,,,,,, -332000,3.1179516,1.2303448,,,,,,,,,,,,,, -332100,3.1640348,1.793493,,,,,,,,,,,,,, -332200,3.0291042,2.0980418,,,,,,,,,,,,,, -332300,3.216171,1.8874115,,,,,,,,,,,,,, -332400,3.0798,1.3159083,,,,,,,,,,,,,, -332500,3.413399,1.118681,,,,,,,,,,,,,, -332573,,,0.8890624642372131,0.4169419109821319,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,147926.8549695015,160531.36344838142,147926.8549695015,12565.919513225555,20.957940578460693,0.0 -332600,3.044092,1.0662731,,,,,,,,,,,,,, -332700,3.799104,3.281613,,,,,,,,,,,,,, -332800,3.8630607,3.2224157,,,,,,,,,,,,,, -332900,3.2668097,1.3018659,,,,,,,,,,,,,, -333000,3.2642202,1.4808257,,,,,,,,,,,,,, -333100,3.0152695,1.0243266,,,,,,,,,,,,,, -333200,3.08354,1.1431075,,,,,,,,,,,,,, -333300,3.7813942,1.3214257,,,,,,,,,,,,,, -333400,3.1008496,1.1751732,,,,,,,,,,,,,, -333500,3.123319,1.0750506,,,,,,,,,,,,,, -333516,,,0.8870702981948853,0.4185446500778198,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,148347.11112332344,160989.44289064407,148347.11112332344,12603.616605520248,21.03495407104492,0.0 -333600,3.7444823,3.1824703,,,,,,,,,,,,,, -333700,4.1351275,2.9111438,,,,,,,,,,,,,, -333800,3.0058026,2.5027378,,,,,,,,,,,,,, -333900,3.0985067,1.024661,,,,,,,,,,,,,, -334000,3.1530848,1.0703626,,,,,,,,,,,,,, -334100,3.0412085,1.1931571,,,,,,,,,,,,,, -334200,4.0746436,3.15064,,,,,,,,,,,,,, -334300,3.3697078,1.14765,,,,,,,,,,,,,, -334400,3.4960759,2.8031049,,,,,,,,,,,,,, -334459,,,0.8873242139816284,0.4183368682861328,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,148767.2456228733,161445.51134824753,148767.2456228733,12639.42637515068,21.11026430130005,0.0 -334500,3.4723208,2.8972423,,,,,,,,,,,,,, -334600,3.1374533,1.1326174,,,,,,,,,,,,,, -334700,3.1487458,1.1997596,,,,,,,,,,,,,, -334800,3.0762086,1.0995687,,,,,,,,,,,,,, -334900,3.0485563,1.656955,,,,,,,,,,,,,, -335000,3.1417704,2.1325393,,,,,,,,,,,,,, -335100,3.003072,1.7611825,,,,,,,,,,,,,, -335200,3.122459,1.7027985,,,,,,,,,,,,,, -335300,3.0035233,1.956966,,,,,,,,,,,,,, -335400,3.1434066,1.1506336,,,,,,,,,,,,,, -335404,,,0.8857226371765137,0.4270011782646179,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,149187.25118470192,161902.31423664093,149187.25118470192,12676.100577354431,21.18437361717224,0.0 -335500,3.7290184,3.2824676,,,,,,,,,,,,,, -335600,3.4765058,1.1646936,,,,,,,,,,,,,, -335700,3.2353852,1.4138626,,,,,,,,,,,,,, -335800,2.927163,1.5746316,,,,,,,,,,,,,, -335900,3.346112,1.088462,,,,,,,,,,,,,, -336000,3.245712,1.339531,,,,,,,,,,,,,, -336100,2.9334006,0.9849615,,,,,,,,,,,,,, -336200,3.5809505,3.1159792,,,,,,,,,,,,,, -336300,3.1455636,1.201192,,,,,,,,,,,,,, -336343,,,0.8882421851158142,0.4184745550155639,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,149607.38150835037,162361.1449034214,149607.38150835037,12714.661440134048,21.27260971069336,0.0 -336400,3.485902,1.2562041,,,,,,,,,,,,,, -336500,2.9720793,2.0496535,,,,,,,,,,,,,, -336600,3.0801096,1.2189968,,,,,,,,,,,,,, -336700,2.914579,1.125662,,,,,,,,,,,,,, -336800,3.3189507,1.250162,,,,,,,,,,,,,, -336900,3.1337612,1.1196135,,,,,,,,,,,,,, -337000,3.5460618,1.1620778,,,,,,,,,,,,,, -337100,3.306965,1.0471921,,,,,,,,,,,,,, -337200,4.067587,1.6322424,,,,,,,,,,,,,, -337283,,,0.8860155940055847,0.4216941297054291,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,150027.41157388687,162817.23032855988,150027.41157388687,12750.583944559095,21.355368614196777,0.0 -337300,3.0462823,2.3962698,,,,,,,,,,,,,, -337400,3.333441,2.6636066,,,,,,,,,,,,,, -337500,2.9799163,1.1618253,,,,,,,,,,,,,, -337600,2.7988904,1.1011416,,,,,,,,,,,,,, -337700,3.6873949,2.7570584,,,,,,,,,,,,,, -337800,3.2174945,1.5154486,,,,,,,,,,,,,, -337900,3.542709,3.1087022,,,,,,,,,,,,,, -338000,3.3616369,1.2324647,,,,,,,,,,,,,, -338100,3.5010467,2.3804712,,,,,,,,,,,,,, -338200,4.682827,3.3369193,,,,,,,,,,,,,, -338224,,,0.88636714220047,0.4208191335201263,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,150447.57310414314,163276.24970722198,150447.57310414314,12789.311537981032,21.43650913238525,0.0 -338300,3.337296,1.2012855,,,,,,,,,,,,,, -338400,3.0303302,1.1247406,,,,,,,,,,,,,, -338500,3.1348984,1.1596293,,,,,,,,,,,,,, -338600,3.1603222,2.5622709,,,,,,,,,,,,,, -338700,6.091657,3.2622745,,,,,,,,,,,,,, -338800,3.2356644,1.9762548,,,,,,,,,,,,,, -338900,3.817104,2.032248,,,,,,,,,,,,,, -339000,3.2213492,1.4150052,,,,,,,,,,,,,, -339100,3.7021298,3.1328058,,,,,,,,,,,,,, -339168,,,0.8886523246765137,0.4195660650730133,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,150867.62800478935,163733.52498984337,150867.62800478935,12826.396708726885,21.5212631225586,0.0 -339200,3.3939133,2.9291246,,,,,,,,,,,,,, -339300,2.7822988,1.0082,,,,,,,,,,,,,, -339400,3.8756633,1.0855353,,,,,,,,,,,,,, -339500,3.057497,1.2284437,,,,,,,,,,,,,, -339600,2.8905215,1.0274018,,,,,,,,,,,,,, -339700,3.1712983,1.2953717,,,,,,,,,,,,,, -339800,3.2194574,1.1738524,,,,,,,,,,,,,, -339900,3.3821776,1.2433325,,,,,,,,,,,,,, -340000,3.8927743,2.8888369,,,,,,,,,,,,,, -340100,3.3789315,1.1900463,,,,,,,,,,,,,, -340113,,,0.8841210603713989,0.4273837208747864,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,151287.56206703186,164190.16235351562,151287.56206703186,12862.9636387825,21.607585906982425,0.0 -340200,4.5886655,1.1745602,,,,,,,,,,,,,, -340300,3.277553,1.1681399,,,,,,,,,,,,,, -340400,3.2673173,1.2042907,,,,,,,,,,,,,, -340500,3.1072905,1.3710853,,,,,,,,,,,,,, -340600,3.4063106,2.9882767,,,,,,,,,,,,,, -340700,3.41766,2.7612133,,,,,,,,,,,,,, -340800,3.572794,3.0117571,,,,,,,,,,,,,, -340900,3.3070304,1.1003618,,,,,,,,,,,,,, -341000,3.574679,1.4943048,,,,,,,,,,,,,, -341059,,,0.8888671398162842,0.4130294620990753,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,151707.53594970703,164647.4318766594,151707.53594970703,12900.13408112526,21.6829161643982,0.0 -341100,4.0921907,3.174159,,,,,,,,,,,,,, -341200,2.8782194,1.0682188,,,,,,,,,,,,,, -341300,3.3728235,2.4062223,,,,,,,,,,,,,, -341400,3.9783528,3.2081246,,,,,,,,,,,,,, -341500,3.0229542,1.0851991,,,,,,,,,,,,,, -341600,3.0981336,1.1052547,,,,,,,,,,,,,, -341700,2.8089085,2.2496796,,,,,,,,,,,,,, -341800,2.9745388,1.1566386,,,,,,,,,,,,,, -341900,2.9239297,1.9940681,,,,,,,,,,,,,, -342000,4.089554,3.3417356,,,,,,,,,,,,,, -342003,,,0.8862890601158142,0.4243661165237427,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,152127.7273261547,165105.02981305122,152127.7273261547,12937.41590332985,21.75795602798462,0.0 -342100,3.1139302,1.1895165,,,,,,,,,,,,,, -342200,3.8154595,2.8607395,,,,,,,,,,,,,, -342300,2.933763,1.0231814,,,,,,,,,,,,,, -342400,3.537922,2.9357612,,,,,,,,,,,,,, -342500,3.0119789,2.1214683,,,,,,,,,,,,,, -342600,3.387466,1.1384703,,,,,,,,,,,,,, -342700,3.7050667,3.2556734,,,,,,,,,,,,,, -342800,3.137278,1.0997782,,,,,,,,,,,,,, -342900,2.9862673,1.2457047,,,,,,,,,,,,,, -342947,,,0.8872265219688416,0.4221601486206054,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,152547.99410772324,165562.70320606232,152547.99410772324,12974.693664312364,21.836976051330566,0.0 -343000,2.9665904,2.1576424,,,,,,,,,,,,,, -343100,3.0055857,1.1734971,,,,,,,,,,,,,, -343200,3.1359813,1.367852,,,,,,,,,,,,,, -343300,3.0258706,1.2114555,,,,,,,,,,,,,, -343400,3.265625,2.8026059,,,,,,,,,,,,,, -343500,3.7327597,2.1510296,,,,,,,,,,,,,, -343600,3.0238755,1.0855995,,,,,,,,,,,,,, -343700,4.2472763,3.2936208,,,,,,,,,,,,,, -343800,2.944901,1.1183292,,,,,,,,,,,,,, -343891,,,0.8874413967132568,0.4182123839855194,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,152968.00593590736,166020.32029771805,152968.00593590736,13012.170394659042,21.91489100456237,0.0 -343900,2.9624724,1.1028445,,,,,,,,,,,,,, -344000,4.0966125,3.067452,,,,,,,,,,,,,, -344100,3.94006,2.9569685,,,,,,,,,,,,,, -344200,3.1273334,1.1183817,,,,,,,,,,,,,, -344300,3.7234392,2.8848777,,,,,,,,,,,,,, -344400,3.1541696,1.4340266,,,,,,,,,,,,,, -344500,3.279332,2.757358,,,,,,,,,,,,,, -344600,2.908198,1.3346328,,,,,,,,,,,,,, -344700,3.0670323,1.3773092,,,,,,,,,,,,,, -344800,3.2132556,1.1483659,,,,,,,,,,,,,, -344835,,,0.8876562118530273,0.4235063791275024,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,153388.0217819214,166477.5153517723,153388.0217819214,13049.210852384567,22.003353595733643,0.0 -344900,3.7082663,2.8764627,,,,,,,,,,,,,, -345000,2.8006577,1.8856386,,,,,,,,,,,,,, -345100,2.9109669,1.1926221,,,,,,,,,,,,,, -345200,2.8287685,1.1271583,,,,,,,,,,,,,, -345300,2.9221861,1.0328927,,,,,,,,,,,,,, -345400,3.4431756,2.872999,,,,,,,,,,,,,, -345500,3.5136256,1.080028,,,,,,,,,,,,,, -345600,3.1778257,2.3715644,,,,,,,,,,,,,, -345700,3.0593653,2.306053,,,,,,,,,,,,,, -345777,,,0.8870507478713989,0.4185346066951751,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,153808.1084370613,166935.82588744164,153808.1084370613,13087.304275989532,22.083884239196777,0.0 -345800,3.1986065,1.309109,,,,,,,,,,,,,, -345900,3.2692444,1.5761945,,,,,,,,,,,,,, -346000,3.5880017,3.1010585,,,,,,,,,,,,,, -346100,2.9951234,1.2277844,,,,,,,,,,,,,, -346200,2.9890556,1.3141701,,,,,,,,,,,,,, -346300,3.1931326,1.1664283,,,,,,,,,,,,,, -346400,3.4028444,1.1315111,,,,,,,,,,,,,, -346500,2.8916264,1.635475,,,,,,,,,,,,,, -346600,3.370895,2.8220408,,,,,,,,,,,,,, -346700,3.0116913,1.5790474,,,,,,,,,,,,,, -346721,,,0.8866601586341858,0.4186669886112213,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,154228.45319128036,167394.76393675804,154228.45319128036,13125.770390748978,22.161634922027588,0.0 -346800,3.4676635,2.0938,,,,,,,,,,,,,, -346900,3.383816,1.1455089,,,,,,,,,,,,,, -347000,3.153597,2.4939463,,,,,,,,,,,,,, -347100,3.4242675,2.918507,,,,,,,,,,,,,, -347200,3.0174234,1.2246552,,,,,,,,,,,,,, -347300,3.354658,1.3691142,,,,,,,,,,,,,, -347400,3.302118,1.6526686,,,,,,,,,,,,,, -347500,2.8797698,1.1369977,,,,,,,,,,,,,, -347600,3.8905132,3.1063364,,,,,,,,,,,,,, -347663,,,0.887011706829071,0.4225160777568817,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,154648.5946586132,167853.7642493248,154648.5946586132,13164.504692316055,22.23713183403015,0.0 -347700,3.3359108,2.303576,,,,,,,,,,,,,, -347800,3.1569524,2.6718047,,,,,,,,,,,,,, -347900,3.393322,2.7689269,,,,,,,,,,,,,, -348000,3.676087,3.211575,,,,,,,,,,,,,, -348100,3.3872528,2.736675,,,,,,,,,,,,,, -348200,3.763989,3.0275097,,,,,,,,,,,,,, -348300,2.8317156,1.2585787,,,,,,,,,,,,,, -348400,3.5041776,1.209486,,,,,,,,,,,,,, -348500,3.1783502,2.5921292,,,,,,,,,,,,,, -348600,3.0235796,1.1235801,,,,,,,,,,,,,, -348606,,,0.8873828053474426,0.4159050583839416,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,155068.51655745506,168312.95635533333,155068.51655745506,13203.645637273788,22.316974401474,0.0 -348700,3.4149675,1.3394651,,,,,,,,,,,,,, -348800,2.8785584,1.2711872,,,,,,,,,,,,,, -348900,3.4966762,1.4299321,,,,,,,,,,,,,, -349000,3.0892823,1.2051772,,,,,,,,,,,,,, -349100,3.4427106,1.1652969,,,,,,,,,,,,,, -349200,3.3126285,1.3066123,,,,,,,,,,,,,, -349300,3.7638803,3.1125257,,,,,,,,,,,,,, -349400,2.87654,1.0754789,,,,,,,,,,,,,, -349500,3.0646887,1.1265916,,,,,,,,,,,,,, -349551,,,0.8894140720367432,0.4155358672142029,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,155488.57597899437,168773.9312326908,155488.57597899437,13244.434262275696,22.393721342086792,0.0 -349600,3.2498024,1.2198493,,,,,,,,,,,,,, -349700,2.954117,2.061009,,,,,,,,,,,,,, -349800,3.9495835,3.1487186,,,,,,,,,,,,,, -349900,3.191983,2.0794418,,,,,,,,,,,,,, -350000,3.3279865,1.1673493,,,,,,,,,,,,,, -350100,2.956461,1.4609418,,,,,,,,,,,,,, -350200,3.4347198,1.2366387,,,,,,,,,,,,,, -350300,2.985326,1.113091,,,,,,,,,,,,,, -350400,3.538434,1.3675799,,,,,,,,,,,,,, -350497,,,0.8883593678474426,0.4154931306838989,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,155908.78848218918,169232.9659898281,155908.78848218918,13283.12947845459,22.47111248970032,0.0 -350500,4.1147404,3.3107204,,,,,,,,,,,,,, -350600,3.5407615,1.7733713,,,,,,,,,,,,,, -350700,2.9510007,1.1897162,,,,,,,,,,,,,, -350800,3.0318162,2.6148741,,,,,,,,,,,,,, -350900,3.1597307,1.1816566,,,,,,,,,,,,,, -351000,2.8244812,1.2756,,,,,,,,,,,,,, -351100,2.7968042,1.7965726,,,,,,,,,,,,,, -351200,3.0989373,2.550271,,,,,,,,,,,,,, -351300,2.8846893,1.5487131,,,,,,,,,,,,,, -351400,3.1824422,1.1671467,,,,,,,,,,,,,, -351441,,,0.8911327719688416,0.4061800241470337,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,156328.82681155205,169691.5860159397,156328.82681155205,13321.576848506927,22.55529522895813,0.0 -351500,3.542325,2.9865818,,,,,,,,,,,,,, -351600,3.1462226,2.624098,,,,,,,,,,,,,, -351700,3.155032,1.0502334,,,,,,,,,,,,,, -351800,3.682239,2.4166214,,,,,,,,,,,,,, -351900,3.0699768,2.3018918,,,,,,,,,,,,,, -352000,3.3967435,1.5721427,,,,,,,,,,,,,, -352100,3.2300324,1.1511886,,,,,,,,,,,,,, -352200,3.1552005,1.0967096,,,,,,,,,,,,,, -352300,3.03172,1.0875823,,,,,,,,,,,,,, -352387,,,0.8876367211341858,0.4141596555709839,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,156748.8176636696,170150.9909837246,156748.8176636696,13360.859747171402,22.63696026802063,0.0 -352400,3.1897821,2.389757,,,,,,,,,,,,,, -352500,3.0459285,1.0827556,,,,,,,,,,,,,, -352600,2.9048917,1.1437103,,,,,,,,,,,,,, -352700,3.3566332,1.1916213,,,,,,,,,,,,,, -352800,3.8141131,2.1954684,,,,,,,,,,,,,, -352900,3.0044186,1.1433427,,,,,,,,,,,,,, -353000,3.3159263,2.8249948,,,,,,,,,,,,,, -353100,3.2800095,2.8145893,,,,,,,,,,,,,, -353200,3.0957384,1.0858445,,,,,,,,,,,,,, -353300,3.0903478,1.1223823,,,,,,,,,,,,,, -353328,,,0.8882421851158142,0.4224283099174499,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,157168.728931427,170607.67958974838,157168.728931427,13397.505673408508,22.719033241271973,0.0 -353400,3.1488643,1.1076403,,,,,,,,,,,,,, -353500,4.6722693,2.9869666,,,,,,,,,,,,,, -353600,2.9939606,1.2241185,,,,,,,,,,,,,, -353700,3.1516278,1.1258394,,,,,,,,,,,,,, -353800,3.172194,1.2969106,,,,,,,,,,,,,, -353900,3.2602806,2.0218332,,,,,,,,,,,,,, -354000,3.254807,1.1482201,,,,,,,,,,,,,, -354100,3.1826897,1.1436433,,,,,,,,,,,,,, -354200,4.7977486,3.2587867,,,,,,,,,,,,,, -354273,,,0.8841992020606995,0.4266878366470337,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,157588.73220658302,171066.66933846474,157588.73220658302,13436.36325597763,22.79829692840576,0.0 -354300,2.8629708,1.3066258,,,,,,,,,,,,,, -354400,3.051632,1.1572875,,,,,,,,,,,,,, -354500,3.1348193,2.4846244,,,,,,,,,,,,,, -354600,3.0227418,1.0838543,,,,,,,,,,,,,, -354700,3.030355,1.2095252,,,,,,,,,,,,,, -354800,3.0503392,1.0655369,,,,,,,,,,,,,, -354900,3.6348448,2.536083,,,,,,,,,,,,,, -355000,3.062086,2.617341,,,,,,,,,,,,,, -355100,3.2962685,1.4920262,,,,,,,,,,,,,, -355200,6.4185834,1.6996089,,,,,,,,,,,,,, -355217,,,0.8879101276397705,0.4218570291996002,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,158008.9703757763,171524.59533452988,158008.9703757763,13473.909487962725,22.889212369918823,0.0 -355300,3.160347,1.1762819,,,,,,,,,,,,,, -355400,3.551432,2.3355525,,,,,,,,,,,,,, -355500,2.9348257,1.6791031,,,,,,,,,,,,,, -355600,2.9947493,1.1824591,,,,,,,,,,,,,, -355700,3.0694854,1.219272,,,,,,,,,,,,,, -355800,3.1074882,1.1559598,,,,,,,,,,,,,, -355900,3.5277426,2.5934122,,,,,,,,,,,,,, -356000,3.086785,2.3160934,,,,,,,,,,,,,, -356100,3.0306575,1.0240986,,,,,,,,,,,,,, -356161,,,0.8890234231948853,0.4148597717285156,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,158429.252024889,171983.66263103485,158429.252024889,13512.55196905136,22.982855796813965,0.0 -356200,3.9260616,3.2049239,,,,,,,,,,,,,, -356300,3.3141606,1.6744741,,,,,,,,,,,,,, -356400,3.874286,3.1870546,,,,,,,,,,,,,, -356500,3.746803,2.751649,,,,,,,,,,,,,, -356600,3.0286877,1.037718,,,,,,,,,,,,,, -356700,3.0906186,1.6488833,,,,,,,,,,,,,, -356800,3.1645908,1.5933552,,,,,,,,,,,,,, -356900,3.0635583,1.8135464,,,,,,,,,,,,,, -357000,3.4975743,1.579447,,,,,,,,,,,,,, -357100,3.7288888,3.089424,,,,,,,,,,,,,, -357101,,,0.8884570002555847,0.4165292978286743,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,158849.1560485363,172444.66031646729,158849.1560485363,13553.5134370327,23.065964221954346,0.0 -357200,3.900853,3.1037927,,,,,,,,,,,,,, -357300,5.126272,3.2302518,,,,,,,,,,,,,, -357400,3.609861,1.1215725,,,,,,,,,,,,,, -357500,3.0729146,1.5709552,,,,,,,,,,,,,, -357600,3.2690125,1.5801105,,,,,,,,,,,,,, -357700,3.0687745,1.9362617,,,,,,,,,,,,,, -357800,3.5421906,1.4312778,,,,,,,,,,,,,, -357900,6.690582,1.1761556,,,,,,,,,,,,,, -358000,3.208886,2.4812427,,,,,,,,,,,,,, -358045,,,0.8863476514816284,0.4233624041080475,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,159269.13534212112,172900.82520484924,159269.13534212112,13589.560196638107,23.15531325340271,0.0 -358100,3.517511,1.4238013,,,,,,,,,,,,,, -358200,3.140418,1.7292736,,,,,,,,,,,,,, -358300,3.3133929,1.1996719,,,,,,,,,,,,,, -358400,3.3887076,1.1141142,,,,,,,,,,,,,, -358500,3.0276988,1.2328424,,,,,,,,,,,,,, -358600,2.940049,1.900999,,,,,,,,,,,,,, -358700,3.3930063,1.5429791,,,,,,,,,,,,,, -358800,3.2249544,1.6812897,,,,,,,,,,,,,, -358900,3.3881493,1.1749456,,,,,,,,,,,,,, -358988,,,0.8848632574081421,0.427352637052536,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,159689.05473470688,173357.47729110718,159689.05473470688,13626.16271162033,23.23595213890076,0.0 -359000,3.2474968,1.059117,,,,,,,,,,,,,, -359100,3.3520126,2.7103708,,,,,,,,,,,,,, -359200,3.1076674,1.8359927,,,,,,,,,,,,,, -359300,2.6978092,1.338165,,,,,,,,,,,,,, -359400,2.91104,1.1439954,,,,,,,,,,,,,, -359500,3.1434603,1.090575,,,,,,,,,,,,,, -359600,3.140025,1.1459498,,,,,,,,,,,,,, -359700,3.1046307,1.2425323,,,,,,,,,,,,,, -359800,3.4124105,2.8642414,,,,,,,,,,,,,, -359900,3.0645223,1.1806784,,,,,,,,,,,,,, -359933,,,0.8875781297683716,0.4193789064884186,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,160109.2667543888,173814.71801424026,160109.2667543888,13663.025684833528,23.352552890777588,0.0 -360000,3.2823045,1.4017773,,,,,,,,,,,,,, -360100,3.1229732,1.5599811,,,,,,,,,,,,,, -360200,3.2629063,2.365099,,,,,,,,,,,,,, -360300,2.9762266,1.1156204,,,,,,,,,,,,,, -360400,3.4339144,1.1605861,,,,,,,,,,,,,, -360500,3.5617151,1.2817292,,,,,,,,,,,,,, -360600,3.0484562,1.274872,,,,,,,,,,,,,, -360700,2.9734,1.819058,,,,,,,,,,,,,, -360800,3.2837129,2.834791,,,,,,,,,,,,,, -360877,,,0.8887499570846558,0.4147139191627502,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,160529.21381664276,174270.8171491623,160529.21381664276,13699.04360795021,23.43653154373169,0.0 -360900,4.0718956,1.09882,,,,,,,,,,,,,, -361000,3.700325,2.7017736,,,,,,,,,,,,,, -361100,3.080753,1.2054458,,,,,,,,,,,,,, -361200,3.4266503,1.1716647,,,,,,,,,,,,,, -361300,3.3602505,1.1150478,,,,,,,,,,,,,, -361400,3.3338182,1.152937,,,,,,,,,,,,,, -361500,3.3676531,2.0040262,,,,,,,,,,,,,, -361600,3.3547196,1.8209696,,,,,,,,,,,,,, -361700,3.3271,1.1552192,,,,,,,,,,,,,, -361800,3.6815984,3.0173903,,,,,,,,,,,,,, -361821,,,0.8864648342132568,0.4235952198505401,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,160949.31827545166,174729.65630316734,160949.31827545166,13737.636802196505,23.527692556381226,0.0 -361900,4.4229755,2.6282563,,,,,,,,,,,,,, -362000,3.1462302,1.2051957,,,,,,,,,,,,,, -362100,3.2808743,2.6152816,,,,,,,,,,,,,, -362200,2.9545195,1.3690648,,,,,,,,,,,,,, -362300,3.8499293,2.9868784,,,,,,,,,,,,,, -362400,3.3304105,1.0818939,,,,,,,,,,,,,, -362500,3.4990013,1.2420217,,,,,,,,,,,,,, -362600,3.189007,2.4045277,,,,,,,,,,,,,, -362700,3.2736187,1.2394134,,,,,,,,,,,,,, -362768,,,0.8859570026397705,0.4194192886352539,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,161369.6048526764,175189.22968149185,161369.6048526764,13776.793885946274,23.607590436935425,0.0 -362800,4.4183617,1.1005727,,,,,,,,,,,,,, -362900,2.9666703,1.1779112,,,,,,,,,,,,,, -363000,3.1518304,1.1452601,,,,,,,,,,,,,, -363100,3.2031698,1.2210038,,,,,,,,,,,,,, -363200,3.5453775,1.8387086,,,,,,,,,,,,,, -363300,3.4667692,1.0341136,,,,,,,,,,,,,, -363400,3.1465409,1.1138678,,,,,,,,,,,,,, -363500,3.019399,1.2945704,,,,,,,,,,,,,, -363600,3.0499952,1.242629,,,,,,,,,,,,,, -363700,2.8969672,1.0413026,,,,,,,,,,,,,, -363712,,,0.8870702981948853,0.4185977876186371,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,161789.86298680303,175646.58432078362,161789.86298680303,13813.758406877518,23.688162088394165,0.0 -363800,2.9575284,1.1101329,,,,,,,,,,,,,, -363900,3.3994792,2.0853796,,,,,,,,,,,,,, -364000,3.2290423,2.198036,,,,,,,,,,,,,, -364100,2.9917548,1.113702,,,,,,,,,,,,,, -364200,3.5207078,2.8621478,,,,,,,,,,,,,, -364300,3.1577399,1.2731428,,,,,,,,,,,,,, -364400,3.0809355,1.0326735,,,,,,,,,,,,,, -364500,3.3697054,2.9641235,,,,,,,,,,,,,, -364600,3.2845545,1.0649335,,,,,,,,,,,,,, -364655,,,0.8883007764816284,0.4166864156723022,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,162210.0809226036,176103.28262138367,162210.0809226036,13850.107394695282,23.76880979537964,0.0 -364700,3.1363766,1.1553864,,,,,,,,,,,,,, -364800,3.1820278,1.2202435,,,,,,,,,,,,,, -364900,2.9562702,1.266274,,,,,,,,,,,,,, -365000,2.9266145,1.1813552,,,,,,,,,,,,,, -365100,3.1614993,1.8274891,,,,,,,,,,,,,, -365200,3.0622208,1.144874,,,,,,,,,,,,,, -365300,3.2697387,1.1440861,,,,,,,,,,,,,, -365400,3.103096,2.8604062,,,,,,,,,,,,,, -365500,3.1336775,1.1535261,,,,,,,,,,,,,, -365599,,,0.8848828077316284,0.4277326166629791,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,162630.12792491913,176560.39166498184,162630.12792491913,13887.037593126295,23.85041332244873,0.0 -365600,3.990125,2.998782,,,,,,,,,,,,,, -365700,3.2509267,1.4458191,,,,,,,,,,,,,, -365800,4.2403984,3.1484778,,,,,,,,,,,,,, -365900,3.3082457,1.0830153,,,,,,,,,,,,,, -366000,3.0579185,1.0490863,,,,,,,,,,,,,, -366100,3.0513954,1.1784973,,,,,,,,,,,,,, -366200,3.2753344,1.2685266,,,,,,,,,,,,,, -366300,3.1282554,1.4902266,,,,,,,,,,,,,, -366400,3.9163444,3.2650275,,,,,,,,,,,,,, -366500,3.111572,1.0628619,,,,,,,,,,,,,, -366543,,,0.8892773389816284,0.4170738458633423,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,163050.34325742722,177019.37155103683,163050.34325742722,13925.670383930206,23.932437658309937,0.0 -366600,4.4839582,3.3362474,,,,,,,,,,,,,, -366700,3.9624176,3.198739,,,,,,,,,,,,,, -366800,3.4871073,1.0936676,,,,,,,,,,,,,, -366900,2.9415593,1.4651134,,,,,,,,,,,,,, -367000,3.5934174,2.812297,,,,,,,,,,,,,, -367100,3.5920932,1.3597462,,,,,,,,,,,,,, -367200,2.9411569,1.2434824,,,,,,,,,,,,,, -367300,3.2736824,1.9830567,,,,,,,,,,,,,, -367400,3.0642161,1.1191317,,,,,,,,,,,,,, -367486,,,0.8848046660423279,0.4290616512298584,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,163470.2967107296,177477.52720284462,163470.2967107296,13963.74238228798,24.01252174377441,0.0 -367500,3.3587003,2.677792,,,,,,,,,,,,,, -367600,3.1049597,1.0834265,,,,,,,,,,,,,, -367700,2.8570848,1.9484993,,,,,,,,,,,,,, -367800,3.0720093,1.1388972,,,,,,,,,,,,,, -367900,4.074785,3.1856477,,,,,,,,,,,,,, -368000,4.2183533,3.3045275,,,,,,,,,,,,,, -368100,2.9987292,1.1175339,,,,,,,,,,,,,, -368200,3.1395824,1.7161994,,,,,,,,,,,,,, -368300,4.265966,3.313117,,,,,,,,,,,,,, -368400,3.3513227,1.1374182,,,,,,,,,,,,,, -368429,,,0.8879687190055847,0.420084685087204,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,163890.20581889153,177935.1705136299,163890.20581889153,14001.34293437004,24.096700191497803,0.0 -368500,3.0140297,1.1540269,,,,,,,,,,,,,, -368600,2.8639033,1.8658314,,,,,,,,,,,,,, -368700,3.0481644,1.1136838,,,,,,,,,,,,,, -368800,3.7898047,2.9778419,,,,,,,,,,,,,, -368900,4.3916206,3.2090702,,,,,,,,,,,,,, -369000,3.9433982,3.1756928,,,,,,,,,,,,,, -369100,3.2016394,2.5766268,,,,,,,,,,,,,, -369200,3.1667833,1.1313944,,,,,,,,,,,,,, -369300,3.2404072,1.1511506,,,,,,,,,,,,,, -369372,,,0.8879687190055847,0.420869767665863,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,164310.3669514656,178395.87911200523,164310.3669514656,14041.759779453278,24.177711248397827,0.0 -369400,2.9036565,2.081452,,,,,,,,,,,,,, -369500,3.4758632,1.1200821,,,,,,,,,,,,,, -369600,3.2427292,2.236051,,,,,,,,,,,,,, -369700,2.9951944,1.1364074,,,,,,,,,,,,,, -369800,3.1657481,2.136509,,,,,,,,,,,,,, -369900,3.0008457,1.1803672,,,,,,,,,,,,,, -370000,3.2706301,1.2810814,,,,,,,,,,,,,, -370100,3.2614024,1.5629514,,,,,,,,,,,,,, -370200,3.0613148,1.3956271,,,,,,,,,,,,,, -370300,3.0777338,1.139011,,,,,,,,,,,,,, -370318,,,0.88587886095047,0.4209321737289428,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,164730.44117760658,178855.2717988491,164730.44117760658,14080.945871591568,24.259997367858887,0.0 -370400,4.05285,3.0307689,,,,,,,,,,,,,, -370500,2.8068721,1.8174231,,,,,,,,,,,,,, -370600,3.421421,1.1508788,,,,,,,,,,,,,, -370700,3.1745558,1.1157097,,,,,,,,,,,,,, -370800,2.9434962,1.6624063,,,,,,,,,,,,,, -370900,3.1437624,1.1725976,,,,,,,,,,,,,, -371000,3.7949693,1.1315362,,,,,,,,,,,,,, -371100,3.5880773,3.1743422,,,,,,,,,,,,,, -371200,3.028789,1.3385193,,,,,,,,,,,,,, -371263,,,0.8868359327316284,0.419025719165802,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,165150.38989567757,179313.91506695747,165150.38989567757,14119.509327411652,24.341216564178467,0.0 -371300,3.4261658,1.0578358,,,,,,,,,,,,,, -371400,2.9627175,1.4906838,,,,,,,,,,,,,, -371500,2.7888155,1.3197192,,,,,,,,,,,,,, -371600,3.1228414,1.1077118,,,,,,,,,,,,,, -371700,3.4669676,2.6293023,,,,,,,,,,,,,, -371800,2.8909101,1.1641423,,,,,,,,,,,,,, -371900,3.1714346,1.9961163,,,,,,,,,,,,,, -372000,3.1914935,1.2329193,,,,,,,,,,,,,, -372100,3.0689728,1.1376199,,,,,,,,,,,,,, -372200,3.3301191,1.1794322,,,,,,,,,,,,,, -372206,,,0.8878905773162842,0.4138197898864746,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,165570.40331053734,179772.46859931946,165570.40331053734,14157.917885541916,24.4231505393982,0.0 -372300,3.106111,1.0751784,,,,,,,,,,,,,, -372400,3.6187072,2.5681295,,,,,,,,,,,,,, -372500,3.2920828,2.4264247,,,,,,,,,,,,,, -372600,3.2027032,1.2895216,,,,,,,,,,,,,, -372700,3.696774,2.7840197,,,,,,,,,,,,,, -372800,3.3752387,1.7284299,,,,,,,,,,,,,, -372900,4.2017007,3.2000303,,,,,,,,,,,,,, -373000,3.1686954,1.1902584,,,,,,,,,,,,,, -373100,2.9558651,1.987299,,,,,,,,,,,,,, -373151,,,0.8874022960662842,0.4218630492687225,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,165990.54666543007,180231.72131824493,165990.54666543007,14196.896178245544,24.50393438339233,0.0 -373200,3.498225,1.23435,,,,,,,,,,,,,, -373300,3.0596812,1.3019105,,,,,,,,,,,,,, -373400,3.348424,1.1150346,,,,,,,,,,,,,, -373500,2.9827008,0.99593914,,,,,,,,,,,,,, -373600,3.9399817,1.3263142,,,,,,,,,,,,,, -373700,3.5417361,1.3642318,,,,,,,,,,,,,, -373800,3.120786,1.9098337,,,,,,,,,,,,,, -373900,2.8428416,1.1950555,,,,,,,,,,,,,, -374000,3.1655278,2.1762977,,,,,,,,,,,,,, -374094,,,0.8924023509025574,0.4078778326511383,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,166410.48005771637,180687.84383511543,166410.48005771637,14232.955313920977,24.584126710891724,0.0 -374100,3.1260266,1.1440065,,,,,,,,,,,,,, -374200,2.8474252,1.0620759,,,,,,,,,,,,,, -374300,3.0526335,1.5981089,,,,,,,,,,,,,, -374400,3.4356518,1.1687744,,,,,,,,,,,,,, -374500,3.3257785,2.8923528,,,,,,,,,,,,,, -374600,3.4033406,1.0525621,,,,,,,,,,,,,, -374700,3.1487403,1.1512796,,,,,,,,,,,,,, -374800,3.228016,2.6374164,,,,,,,,,,,,,, -374900,4.0512996,3.1128948,,,,,,,,,,,,,, -375000,3.3179805,1.9586675,,,,,,,,,,,,,, -375037,,,0.8862890601158142,0.4168686270713806,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,166830.50103092194,181145.9012551308,166830.50103092194,14270.859405994415,24.667348384857178,0.0 -375100,3.689529,2.8675013,,,,,,,,,,,,,, -375200,3.6918058,3.212721,,,,,,,,,,,,,, -375300,3.8594906,3.304641,,,,,,,,,,,,,, -375400,4.0853186,3.2666793,,,,,,,,,,,,,, -375500,3.083251,2.5541673,,,,,,,,,,,,,, -375600,3.5172505,2.9198287,,,,,,,,,,,,,, -375700,2.9526763,1.1913415,,,,,,,,,,,,,, -375800,3.9804428,3.107942,,,,,,,,,,,,,, -375900,3.4486916,2.6665616,,,,,,,,,,,,,, -375981,,,0.8896484375,0.4130140542984009,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,167250.72504115105,181606.3165895939,167250.72504115105,14310.904839277267,24.76238989830017,0.0 -376000,3.0836108,2.5332913,,,,,,,,,,,,,, -376100,4.488944,3.258433,,,,,,,,,,,,,, -376200,3.077656,1.2068435,,,,,,,,,,,,,, -376300,3.2919,1.0635555,,,,,,,,,,,,,, -376400,3.2185519,2.0577865,,,,,,,,,,,,,, -376500,3.0181992,1.214221,,,,,,,,,,,,,, -376600,3.1004982,1.1701595,,,,,,,,,,,,,, -376700,3.2204185,2.696188,,,,,,,,,,,,,, -376800,2.7411199,1.5735468,,,,,,,,,,,,,, -376900,3.2323718,1.3862599,,,,,,,,,,,,,, -376924,,,0.8883007764816284,0.4164344668388366,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,167670.94829511642,182065.8271315097,167670.94829511642,14350.03876209259,24.86554718017578,0.0 -377000,3.3866668,1.1756202,,,,,,,,,,,,,, -377100,3.047422,1.312154,,,,,,,,,,,,,, -377200,3.0299077,1.294163,,,,,,,,,,,,,, -377300,3.1268342,1.034492,,,,,,,,,,,,,, -377400,2.9123964,1.0849006,,,,,,,,,,,,,, -377500,2.958959,1.0354785,,,,,,,,,,,,,, -377600,3.0605063,1.1280243,,,,,,,,,,,,,, -377700,3.8499541,3.119635,,,,,,,,,,,,,, -377800,3.328403,1.091795,,,,,,,,,,,,,, -377867,,,0.8856835961341858,0.4257451891899109,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,168090.943703413,182523.81906175613,168090.943703413,14387.900235891342,24.95109248161316,0.0 -377900,3.3547785,1.1015129,,,,,,,,,,,,,, -378000,4.061463,3.2202992,,,,,,,,,,,,,, -378100,3.5927181,3.1118355,,,,,,,,,,,,,, -378200,3.329588,1.0979149,,,,,,,,,,,,,, -378300,2.887407,1.0550289,,,,,,,,,,,,,, -378400,3.3329327,2.3868706,,,,,,,,,,,,,, -378500,4.0747623,1.1341131,,,,,,,,,,,,,, -378600,2.8204935,1.4318389,,,,,,,,,,,,,, -378700,3.0771203,1.1278822,,,,,,,,,,,,,, -378800,3.1962118,1.0028254,,,,,,,,,,,,,, -378809,,,0.8865624666213989,0.422280341386795,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,168510.95400238037,182983.6586351395,168510.95400238037,14427.597375631332,25.032990217208862,0.0 -378900,3.0471036,1.9039749,,,,,,,,,,,,,, -379000,3.3569076,2.795394,,,,,,,,,,,,,, -379100,3.421336,1.2071071,,,,,,,,,,,,,, -379200,3.2202997,1.172622,,,,,,,,,,,,,, -379300,3.526154,3.0594895,,,,,,,,,,,,,, -379400,3.4158254,1.0437982,,,,,,,,,,,,,, -379500,3.3246713,1.120835,,,,,,,,,,,,,, -379600,3.2498937,1.6927319,,,,,,,,,,,,,, -379700,3.1853087,1.0840573,,,,,,,,,,,,,, -379752,,,0.89013671875,0.4157166182994842,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,168930.87727046013,183444.1056733132,168930.87727046013,14467.989336252213,25.11535716056824,0.0 -379800,2.998321,1.9559174,,,,,,,,,,,,,, -379900,2.9756846,1.0286429,,,,,,,,,,,,,, -380000,3.2803304,1.9299309,,,,,,,,,,,,,, -380100,2.867572,1.1003988,,,,,,,,,,,,,, -380200,3.6267362,2.5577328,,,,,,,,,,,,,, -380300,2.8589635,1.1448935,,,,,,,,,,,,,, -380400,3.319119,1.1321806,,,,,,,,,,,,,, -380500,3.728716,3.1742349,,,,,,,,,,,,,, -380600,3.4220877,2.4936695,,,,,,,,,,,,,, -380695,,,0.8870312571525574,0.4199641346931457,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,169350.8524608612,183901.6815738678,169350.8524608612,14505.456683397291,25.19770884513855,0.0 -380700,3.2528872,1.1084044,,,,,,,,,,,,,, -380800,3.3188174,2.5537531,,,,,,,,,,,,,, -380900,3.0696971,1.0470097,,,,,,,,,,,,,, -381000,3.3010473,1.9688987,,,,,,,,,,,,,, -381100,3.308804,2.6672554,,,,,,,,,,,,,, -381200,3.2255564,1.0506853,,,,,,,,,,,,,, -381300,2.9822903,1.0934113,,,,,,,,,,,,,, -381400,3.2211702,1.340358,,,,,,,,,,,,,, -381500,3.239011,1.1517323,,,,,,,,,,,,,, -381600,3.2433457,2.1218772,,,,,,,,,,,,,, -381639,,,0.8875781297683716,0.4152979850769043,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,169771.07170915604,184357.7455892563,169771.07170915604,14541.1564347744,25.292958974838257,0.0 -381700,3.2085197,1.1335341,,,,,,,,,,,,,, -381800,3.8830528,3.141843,,,,,,,,,,,,,, -381900,2.9764261,1.1339991,,,,,,,,,,,,,, -382000,3.5078545,2.0158238,,,,,,,,,,,,,, -382100,3.0797775,1.0585423,,,,,,,,,,,,,, -382200,3.6738877,2.2000587,,,,,,,,,,,,,, -382300,3.059298,2.6143997,,,,,,,,,,,,,, -382400,3.7661624,1.0153403,,,,,,,,,,,,,, -382500,2.861927,2.1383471,,,,,,,,,,,,,, -382582,,,0.8846093416213989,0.4293917417526245,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,170191.14010548592,184819.70988583565,170191.14010548592,14582.919032096865,25.37623953819275,0.0 -382600,3.2185137,1.1962688,,,,,,,,,,,,,, -382700,2.8743951,2.0538564,,,,,,,,,,,,,, -382800,3.227875,1.2383219,,,,,,,,,,,,,, -382900,3.231126,1.0437638,,,,,,,,,,,,,, -383000,3.6925454,1.1045101,,,,,,,,,,,,,, -383100,3.4312906,1.4578103,,,,,,,,,,,,,, -383200,2.8888357,1.863696,,,,,,,,,,,,,, -383300,2.961643,1.2984732,,,,,,,,,,,,,, -383400,3.0241299,1.4318002,,,,,,,,,,,,,, -383500,3.2044318,2.3229222,,,,,,,,,,,,,, -383525,,,0.8859961032867432,0.4245673716068268,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,170611.00712823868,185279.24055552483,170611.00712823868,14622.43280339241,25.47563648223877,0.0 -383600,3.1933794,2.1899338,,,,,,,,,,,,,, -383700,2.8357282,1.4692295,,,,,,,,,,,,,, -383800,3.0461805,1.9760801,,,,,,,,,,,,,, -383900,3.2412393,1.0856271,,,,,,,,,,,,,, -384000,2.8996146,1.1113842,,,,,,,,,,,,,, -384100,3.0094168,1.5459082,,,,,,,,,,,,,, -384200,3.767289,2.88247,,,,,,,,,,,,,, -384300,3.108879,1.1434395,,,,,,,,,,,,,, -384400,3.3685358,1.1001564,,,,,,,,,,,,,, -384469,,,0.88720703125,0.4146858751773834,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,171031.26492524147,185733.45716404915,171031.26492524147,14656.259887695312,25.55725598335266,0.0 -384500,3.143294,2.458455,,,,,,,,,,,,,, -384600,3.2366383,2.6217756,,,,,,,,,,,,,, -384700,3.3302338,1.1807274,,,,,,,,,,,,,, -384800,3.1130831,2.6475317,,,,,,,,,,,,,, -384900,3.1157753,1.0547191,,,,,,,,,,,,,, -385000,3.1003819,2.0767963,,,,,,,,,,,,,, -385100,3.0906937,1.0961567,,,,,,,,,,,,,, -385200,3.4971511,2.0964248,,,,,,,,,,,,,, -385300,3.2658145,1.155509,,,,,,,,,,,,,, -385400,3.3110101,1.3678006,,,,,,,,,,,,,, -385413,,,0.8879492282867432,0.4205496907234192,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,171451.29096341133,186194.85792946813,171451.29096341133,14697.488318920135,25.653950452804565,0.0 -385500,3.0540361,1.6405602,,,,,,,,,,,,,, -385600,4.610433,2.9178226,,,,,,,,,,,,,, -385700,3.1627095,1.2080704,,,,,,,,,,,,,, -385800,3.2682521,1.1229159,,,,,,,,,,,,,, -385900,2.9359746,1.1609839,,,,,,,,,,,,,, -386000,3.1847281,2.1253877,,,,,,,,,,,,,, -386100,2.9244752,1.1761537,,,,,,,,,,,,,, -386200,3.148756,1.060675,,,,,,,,,,,,,, -386300,2.9507246,1.7702048,,,,,,,,,,,,,, -386356,,,0.887011706829071,0.4235352277755737,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,171871.21469521525,186653.4598946572,171871.21469521525,14736.033299446106,25.737998723983765,0.0 -386400,3.2545044,1.1800051,,,,,,,,,,,,,, -386500,3.2979295,1.5224653,,,,,,,,,,,,,, -386600,3.2232854,1.3032416,,,,,,,,,,,,,, -386700,3.3870463,1.4935898,,,,,,,,,,,,,, -386800,2.9245203,1.313704,,,,,,,,,,,,,, -386900,2.9825726,1.0576688,,,,,,,,,,,,,, -387000,3.2166116,1.1434942,,,,,,,,,,,,,, -387100,3.0662227,1.0937647,,,,,,,,,,,,,, -387200,3.1417687,1.0895532,,,,,,,,,,,,,, -387300,,,0.8868163824081421,0.417324811220169,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,172291.44046711922,187113.0497689247,172291.44046711922,14775.26454615593,25.82158470153809,0.0 -387300,3.7045038,2.9058626,,,,,,,,,,,,,, -387400,4.140686,3.0665278,,,,,,,,,,,,,, -387500,3.2396219,2.3369536,,,,,,,,,,,,,, -387600,3.035801,1.4795117,,,,,,,,,,,,,, -387700,3.7814229,3.2692647,,,,,,,,,,,,,, -387800,3.1013043,1.9147222,,,,,,,,,,,,,, -387900,3.3402014,2.866356,,,,,,,,,,,,,, -388000,4.663411,1.1254976,,,,,,,,,,,,,, -388100,3.597287,1.1694553,,,,,,,,,,,,,, -388200,3.0076272,1.627827,,,,,,,,,,,,,, -388244,,,0.8858398199081421,0.4219137728214264,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,172711.33534526825,187572.2773804665,172711.33534526825,14814.46245265007,25.90627861022949,0.0 -388300,3.2507513,1.4460629,,,,,,,,,,,,,, -388400,3.3097024,2.3360434,,,,,,,,,,,,,, -388500,5.1815596,3.1503434,,,,,,,,,,,,,, -388600,3.0442152,1.2256601,,,,,,,,,,,,,, -388700,3.1959023,1.1798848,,,,,,,,,,,,,, -388800,3.057824,1.3161482,,,,,,,,,,,,,, -388900,3.1423917,1.4303218,,,,,,,,,,,,,, -389000,4.1704545,2.9921086,,,,,,,,,,,,,, -389100,4.4869847,3.3212707,,,,,,,,,,,,,, -389184,,,0.8878515362739563,0.4197618365287781,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,173131.36773204803,188031.0147128105,173131.36773204803,14853.030284166336,25.99392938613892,0.0 -389200,3.2323818,0.9630849,,,,,,,,,,,,,, -389300,3.4021423,2.159028,,,,,,,,,,,,,, -389400,3.190893,1.3281378,,,,,,,,,,,,,, -389500,2.9863698,2.0532994,,,,,,,,,,,,,, -389600,3.055863,1.1449335,,,,,,,,,,,,,, -389700,3.0901732,1.1318758,,,,,,,,,,,,,, -389800,2.8861291,1.423955,,,,,,,,,,,,,, -389900,3.470605,1.1479954,,,,,,,,,,,,,, -390000,3.8878129,3.2010071,,,,,,,,,,,,,, -390100,3.0324554,1.4024097,,,,,,,,,,,,,, -390130,,,0.8890624642372131,0.4212661981582641,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,173551.6240181923,188488.27039361,173551.6240181923,14889.88103055954,26.092432737350464,0.0 -390200,3.9046903,2.1619956,,,,,,,,,,,,,, -390300,2.9800987,2.0715716,,,,,,,,,,,,,, -390400,3.2127616,1.5992532,,,,,,,,,,,,,, -390500,2.9429135,1.6344372,,,,,,,,,,,,,, -390600,3.638741,2.414527,,,,,,,,,,,,,, -390700,4.0921187,3.1927466,,,,,,,,,,,,,, -390800,3.1767101,1.5674307,,,,,,,,,,,,,, -390900,3.0682309,1.2510864,,,,,,,,,,,,,, -391000,3.382625,2.7063859,,,,,,,,,,,,,, -391073,,,0.88525390625,0.4267793893814087,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,173971.50109505653,188945.62973308563,173971.50109505653,14927.226495981216,26.17941808700561,0.0 -391100,3.0402799,2.1360438,,,,,,,,,,,,,, -391200,3.0702116,1.4918101,,,,,,,,,,,,,, -391300,2.9472728,1.3234166,,,,,,,,,,,,,, -391400,2.7748728,2.0603285,,,,,,,,,,,,,, -391500,2.9401896,1.6195236,,,,,,,,,,,,,, -391600,3.0774891,1.7796502,,,,,,,,,,,,,, -391700,3.7471955,1.1417409,,,,,,,,,,,,,, -391800,3.8427663,1.220938,,,,,,,,,,,,,, -391900,3.177904,1.3343801,,,,,,,,,,,,,, -392000,2.976259,1.2769818,,,,,,,,,,,,,, -392017,,,0.88720703125,0.4180474281311035,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,174391.5544939041,189404.8793940544,174391.5544939041,14966.27654528618,26.275421142578125,0.0 -392100,3.1849837,1.1872342,,,,,,,,,,,,,, -392200,3.2115576,1.1497247,,,,,,,,,,,,,, -392300,3.185112,1.6947359,,,,,,,,,,,,,, -392400,3.6303525,3.0983794,,,,,,,,,,,,,, -392500,3.3136036,1.2681217,,,,,,,,,,,,,, -392600,3.2152917,1.1263454,,,,,,,,,,,,,, -392700,3.0357704,1.7809999,,,,,,,,,,,,,, -392800,3.8346813,3.0583696,,,,,,,,,,,,,, -392900,2.9708002,1.2790987,,,,,,,,,,,,,, -392962,,,0.8881640434265137,0.4164069592952728,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,174811.7327439785,189867.20791006088,174811.7327439785,15008.291206359863,26.361035346984863,0.0 -393000,4.072859,3.1568336,,,,,,,,,,,,,, -393100,2.7895927,1.1689268,,,,,,,,,,,,,, -393200,3.0610497,2.0868828,,,,,,,,,,,,,, -393300,3.1499393,1.2609028,,,,,,,,,,,,,, -393400,3.9527624,3.2439377,,,,,,,,,,,,,, -393500,3.0790186,1.655227,,,,,,,,,,,,,, -393600,3.2523623,2.588359,,,,,,,,,,,,,, -393700,3.1120462,1.2472726,,,,,,,,,,,,,, -393800,3.2947319,2.797931,,,,,,,,,,,,,, -393900,3.0729873,1.205819,,,,,,,,,,,,,, -393906,,,0.88623046875,0.4230261445045471,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,175231.97394490242,190328.31834244728,175231.97394490242,15049.005800247192,26.46606206893921,0.0 -394000,4.276261,3.1734362,,,,,,,,,,,,,, -394100,3.056044,1.5168117,,,,,,,,,,,,,, -394200,3.155014,1.0942421,,,,,,,,,,,,,, -394300,3.1025565,1.1269095,,,,,,,,,,,,,, -394400,2.965989,1.1988049,,,,,,,,,,,,,, -394500,2.9425244,2.2632933,,,,,,,,,,,,,, -394600,3.7313738,3.00798,,,,,,,,,,,,,, -394700,3.1891348,2.3928852,,,,,,,,,,,,,, -394800,3.1034796,1.1881499,,,,,,,,,,,,,, -394851,,,0.8873632550239563,0.4190618693828583,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,175652.15966057777,190788.822149992,175652.15966057777,15089.174458265305,26.56468725204468,0.0 -394900,3.497009,2.1898377,,,,,,,,,,,,,, -395000,3.0669098,1.7530878,,,,,,,,,,,,,, -395100,3.2281635,1.2083594,,,,,,,,,,,,,, -395200,2.9923882,1.0626041,,,,,,,,,,,,,, -395300,4.3680058,1.4135854,,,,,,,,,,,,,, -395400,3.0585034,1.1542939,,,,,,,,,,,,,, -395500,3.1749482,1.076069,,,,,,,,,,,,,, -395600,3.2949963,2.951561,,,,,,,,,,,,,, -395700,3.7049506,1.2500436,,,,,,,,,,,,,, -395794,,,0.8883788585662842,0.4160697758197784,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,176072.37421774864,191249.1066968441,176072.37421774864,15129.109971046448,26.64938521385193,0.0 -395800,3.1013827,1.0476345,,,,,,,,,,,,,, -395900,3.0820649,1.8275214,,,,,,,,,,,,,, -396000,2.9032528,1.367978,,,,,,,,,,,,,, -396100,3.3425806,2.53445,,,,,,,,,,,,,, -396200,3.0523014,2.1812644,,,,,,,,,,,,,, -396300,2.9235322,1.989483,,,,,,,,,,,,,, -396400,2.799356,1.3961504,,,,,,,,,,,,,, -396500,3.208469,1.1540813,,,,,,,,,,,,,, -396600,2.982143,1.566066,,,,,,,,,,,,,, -396700,3.9921424,3.3777914,,,,,,,,,,,,,, -396738,,,0.8874413967132568,0.4168276190757751,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,176492.64163136482,191708.2846038341,176492.64163136482,15167.873075723648,26.746363401412964,0.0 -396800,3.8870258,3.0265307,,,,,,,,,,,,,, -396900,3.1230168,2.002095,,,,,,,,,,,,,, -397000,2.9614635,1.1120644,,,,,,,,,,,,,, -397100,3.0069392,1.0431408,,,,,,,,,,,,,, -397200,3.056352,2.1543834,,,,,,,,,,,,,, -397300,3.0563512,1.0960643,,,,,,,,,,,,,, -397400,3.7470949,3.0082517,,,,,,,,,,,,,, -397500,3.022346,1.2047753,,,,,,,,,,,,,, -397600,2.9248967,1.137298,,,,,,,,,,,,,, -397683,,,0.8902539014816284,0.4140291213989258,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,176912.88610959053,192168.7384550572,176912.88610959053,15207.94590306282,26.832626581192017,0.0 -397700,3.1792612,2.0839796,,,,,,,,,,,,,, -397800,3.078852,1.1979203,,,,,,,,,,,,,, -397900,2.9475708,1.0499296,,,,,,,,,,,,,, -398000,3.2517748,1.0807817,,,,,,,,,,,,,, -398100,3.4629183,1.2379677,,,,,,,,,,,,,, -398200,3.1630797,1.1486303,,,,,,,,,,,,,, -398300,3.0853016,1.160426,,,,,,,,,,,,,, -398400,2.9856203,1.1994847,,,,,,,,,,,,,, -398500,3.044743,1.0910501,,,,,,,,,,,,,, -398600,3.3745496,1.7373247,,,,,,,,,,,,,, -398628,,,0.88818359375,0.4120750427246094,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,177332.93433642387,192629.8915224076,177332.93433642387,15248.890265226364,26.943554639816284,0.0 -398700,3.2637398,1.24468,,,,,,,,,,,,,, -398800,3.9744737,3.373744,,,,,,,,,,,,,, -398900,3.2494168,2.6645377,,,,,,,,,,,,,, -399000,2.953039,1.137728,,,,,,,,,,,,,, -399100,3.114548,2.4909365,,,,,,,,,,,,,, -399200,3.3569927,2.2480626,,,,,,,,,,,,,, -399300,3.6083012,3.2713938,,,,,,,,,,,,,, -399400,3.7854702,2.5823953,,,,,,,,,,,,,, -399500,3.713381,2.98368,,,,,,,,,,,,,, -399570,,,0.8872656226158142,0.4196255803108215,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,177752.96119642258,193090.10246396065,177752.96119642258,15288.936572790146,27.03213572502136,0.0 -399600,3.1820114,2.5580842,,,,,,,,,,,,,, -399700,3.5996475,2.8872333,,,,,,,,,,,,,, -399800,3.0590794,1.9192348,,,,,,,,,,,,,, -399900,4.887521,3.200275,,,,,,,,,,,,,, -400000,3.030797,1.1243855,,,,,,,,,,,,,, -400100,3.3484824,1.1319461,,,,,,,,,,,,,, -400200,2.9622982,1.1899626,,,,,,,,,,,,,, -400300,3.856515,3.2014692,,,,,,,,,,,,,, -400400,4.043444,1.1329832,,,,,,,,,,,,,, -400500,3.2516518,2.680413,,,,,,,,,,,,,, -400512,,,0.8886913657188416,0.4138020575046539,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,178173.17875623703,193549.0971210003,178173.17875623703,15327.57566690445,27.120450258255005,0.0 -400600,2.8200762,1.5799468,,,,,,,,,,,,,, -400700,3.7686298,3.1614861,,,,,,,,,,,,,, -400800,2.9375088,1.2270931,,,,,,,,,,,,,, -400900,3.0019462,1.9471047,,,,,,,,,,,,,, -401000,2.9868467,1.788565,,,,,,,,,,,,,, -401100,3.6604145,1.2493038,,,,,,,,,,,,,, -401200,4.3298435,3.3529449,,,,,,,,,,,,,, -401300,3.228906,1.1243705,,,,,,,,,,,,,, -401400,3.075291,1.2481406,,,,,,,,,,,,,, -401457,,,0.8862695097923279,0.4230132102966308,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,178593.23889565468,194007.2590417862,178593.23889565468,15365.542269706726,27.2052583694458,0.0 -401500,3.0759747,2.4899592,,,,,,,,,,,,,, -401600,3.045914,1.1261959,,,,,,,,,,,,,, -401700,3.2842572,1.1540394,,,,,,,,,,,,,, -401800,3.3814998,2.7744415,,,,,,,,,,,,,, -401900,3.1527183,1.6949146,,,,,,,,,,,,,, -402000,3.4261189,1.1239287,,,,,,,,,,,,,, -402100,3.9016757,2.8484511,,,,,,,,,,,,,, -402200,3.0984342,2.0838706,,,,,,,,,,,,,, -402300,3.074087,1.1157979,,,,,,,,,,,,,, -402400,,,0.88685542345047,0.4238695502281189,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,179013.23162341118,194463.48729109764,179013.23162341118,15401.641217947006,27.29204750061035,0.0 -402400,3.1637297,1.0939057,,,,,,,,,,,,,, -402500,3.7070842,1.7843764,,,,,,,,,,,,,, -402600,2.8859088,1.0465698,,,,,,,,,,,,,, -402700,2.9971921,1.159954,,,,,,,,,,,,,, -402800,3.4466214,2.694443,,,,,,,,,,,,,, -402900,2.94874,1.9501601,,,,,,,,,,,,,, -403000,3.3006568,1.0667763,,,,,,,,,,,,,, -403100,2.968887,1.1395099,,,,,,,,,,,,,, -403200,4.687535,2.724706,,,,,,,,,,,,,, -403300,4.027048,3.1235907,,,,,,,,,,,,,, -403345,,,0.8892382383346558,0.4164697527885437,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,179433.4183971882,194923.0720796585,179433.4183971882,15440.896358013151,27.38421607017517,0.0 -403400,3.220862,1.2477946,,,,,,,,,,,,,, -403500,3.7883976,3.2183037,,,,,,,,,,,,,, -403600,2.9322994,1.6154428,,,,,,,,,,,,,, -403700,2.9611542,1.3058064,,,,,,,,,,,,,, -403800,3.775426,2.690879,,,,,,,,,,,,,, -403900,3.0957985,1.5038157,,,,,,,,,,,,,, -404000,3.2243617,1.788374,,,,,,,,,,,,,, -404100,3.028605,1.2604734,,,,,,,,,,,,,, -404200,3.3233016,2.5908945,,,,,,,,,,,,,, -404288,,,0.88623046875,0.4211839139461517,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,179853.55819368362,195382.3181631565,179853.55819368362,15479.857174158096,27.4800283908844,0.0 -404300,4.0217443,3.1454134,,,,,,,,,,,,,, -404400,3.2013905,2.2990856,,,,,,,,,,,,,, -404500,2.891816,1.0415245,,,,,,,,,,,,,, -404600,3.0781376,2.3959193,,,,,,,,,,,,,, -404700,3.1211555,1.5899523,,,,,,,,,,,,,, -404800,3.4686737,2.8097103,,,,,,,,,,,,,, -404900,3.1881304,1.2168717,,,,,,,,,,,,,, -405000,3.0288095,1.2130697,,,,,,,,,,,,,, -405100,3.1982608,1.081352,,,,,,,,,,,,,, -405200,2.9575677,1.0183934,,,,,,,,,,,,,, -405232,,,0.8868359327316284,0.4213452041149139,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,180273.6412432193,195842.7099413872,180273.6412432193,15520.02494072914,27.570369720458984,0.0 -405300,3.2468643,1.1993784,,,,,,,,,,,,,, -405400,3.3623428,1.1316079,,,,,,,,,,,,,, -405500,3.223558,1.0384434,,,,,,,,,,,,,, -405600,3.1218994,1.1837826,,,,,,,,,,,,,, -405700,3.0439997,2.1280088,,,,,,,,,,,,,, -405800,3.7874563,3.181071,,,,,,,,,,,,,, -405900,2.9936118,1.4090432,,,,,,,,,,,,,, -406000,2.9243546,1.7311696,,,,,,,,,,,,,, -406100,3.0199559,1.1194797,,,,,,,,,,,,,, -406177,,,0.8876171708106995,0.4252983629703522,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,180693.8637099266,196302.9815299511,180693.8637099266,15559.92268037796,27.672037363052368,0.0 -406200,3.1233785,1.1604019,,,,,,,,,,,,,, -406300,3.082321,1.0741231,,,,,,,,,,,,,, -406400,3.2939584,1.2482535,,,,,,,,,,,,,, -406500,3.1059928,1.1003922,,,,,,,,,,,,,, -406600,3.4881332,1.1264132,,,,,,,,,,,,,, -406700,2.9901934,1.8133885,,,,,,,,,,,,,, -406800,3.3104765,1.143943,,,,,,,,,,,,,, -406900,4.7670903,3.170428,,,,,,,,,,,,,, -407000,3.8887737,3.1333232,,,,,,,,,,,,,, -407100,2.9901845,2.0639129,,,,,,,,,,,,,, -407121,,,0.8858984112739563,0.4219390451908111,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,181114.0154056549,196762.9376182556,181114.0154056549,15599.586994171144,27.76237177848816,0.0 -407200,3.2443366,1.1791791,,,,,,,,,,,,,, -407300,3.6286802,2.9946742,,,,,,,,,,,,,, -407400,3.4151917,1.5778812,,,,,,,,,,,,,, -407500,3.2067552,2.7949767,,,,,,,,,,,,,, -407600,3.1321235,2.628806,,,,,,,,,,,,,, -407700,3.2053044,1.6139672,,,,,,,,,,,,,, -407800,3.0138824,1.8361688,,,,,,,,,,,,,, -407900,3.005809,1.054301,,,,,,,,,,,,,, -408000,3.070621,1.1307243,,,,,,,,,,,,,, -408066,,,0.8882030844688416,0.4163236916065216,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,181533.90558075905,197224.2642486096,181533.90558075905,15640.886532783508,27.84921193122864,0.0 -408100,2.9787621,1.075632,,,,,,,,,,,,,, -408200,3.1971734,2.084194,,,,,,,,,,,,,, -408300,3.235384,1.0637786,,,,,,,,,,,,,, -408400,4.1970387,3.2327754,,,,,,,,,,,,,, -408500,3.620179,2.4911263,,,,,,,,,,,,,, -408600,3.2643876,1.0944055,,,,,,,,,,,,,, -408700,4.594274,3.2352839,,,,,,,,,,,,,, -408800,2.9635115,1.5930153,,,,,,,,,,,,,, -408900,3.049184,1.1502295,,,,,,,,,,,,,, -409000,3.2216332,1.0824149,,,,,,,,,,,,,, -409009,,,0.88720703125,0.4183976352214813,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,181953.90784978867,197685.3879377842,181953.90784978867,15681.871098041534,27.936203956604004,0.0 -409100,3.6331637,3.0987594,,,,,,,,,,,,,, -409200,3.4414363,2.1928723,,,,,,,,,,,,,, -409300,3.1624916,2.7473752,,,,,,,,,,,,,, -409400,3.0151834,2.5013494,,,,,,,,,,,,,, -409500,3.1480033,1.213678,,,,,,,,,,,,,, -409600,12.14149,1.7249341,,,,,,,,,,,,,, -409700,4.7690926,3.1211078,,,,,,,,,,,,,, -409800,3.1739404,1.4009882,,,,,,,,,,,,,, -409900,3.192741,1.189643,,,,,,,,,,,,,, -409953,,,0.8857616782188416,0.4226754307746887,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,182374.1227302552,198145.38085389137,182374.1227302552,15721.50890302658,28.027443408966064,0.0 -410000,3.2754447,1.3564916,,,,,,,,,,,,,, -410100,3.7238038,3.1716945,,,,,,,,,,,,,, -410200,2.9816742,1.1953185,,,,,,,,,,,,,, -410300,3.9080813,3.116863,,,,,,,,,,,,,, -410400,3.2011728,2.441459,,,,,,,,,,,,,, -410500,3.0323002,2.364225,,,,,,,,,,,,,, -410600,3.2413423,1.1508919,,,,,,,,,,,,,, -410700,3.0343933,1.8492782,,,,,,,,,,,,,, -410800,2.927722,1.5404276,,,,,,,,,,,,,, -410896,,,0.8866015672683716,0.4220174551010132,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,182794.09159350395,198603.1596791744,182794.09159350395,15759.182371139526,28.114766120910645,0.0 -410900,3.5853899,3.0829341,,,,,,,,,,,,,, -411000,3.6339848,1.2236397,,,,,,,,,,,,,, -411100,3.209217,1.7348064,,,,,,,,,,,,,, -411200,3.2172332,1.2818134,,,,,,,,,,,,,, -411300,3.0323412,1.340328,,,,,,,,,,,,,, -411400,3.1186168,1.1412513,,,,,,,,,,,,,, -411500,3.2214968,2.3233452,,,,,,,,,,,,,, -411600,3.1044216,1.0924196,,,,,,,,,,,,,, -411700,3.019458,1.0588043,,,,,,,,,,,,,, -411800,3.086794,1.2303059,,,,,,,,,,,,,, -411840,,,0.8879101276397705,0.4160162210464477,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,183214.29858469963,199062.46124649048,183214.29858469963,15798.120640039444,28.22186017036438,0.0 -411900,3.0860772,1.1381264,,,,,,,,,,,,,, -412000,3.3478465,1.2729843,,,,,,,,,,,,,, -412100,2.987961,1.0495745,,,,,,,,,,,,,, -412200,3.1702983,1.1762135,,,,,,,,,,,,,, -412300,3.3313446,1.1976886,,,,,,,,,,,,,, -412400,2.9877818,1.9594197,,,,,,,,,,,,,, -412500,2.8861494,1.2394017,,,,,,,,,,,,,, -412600,3.5391603,1.3020668,,,,,,,,,,,,,, -412700,3.6369374,2.6070626,,,,,,,,,,,,,, -412784,,,0.8880078196525574,0.4205534160137176,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,183634.53731536865,199519.83964657784,183634.53731536865,15835.116010665894,28.31593894958496,0.0 -412800,2.9960842,1.1264657,,,,,,,,,,,,,, -412900,3.466998,2.6056128,,,,,,,,,,,,,, -413000,4.0868535,2.9572577,,,,,,,,,,,,,, -413100,3.121167,1.0898676,,,,,,,,,,,,,, -413200,3.6807685,1.9866194,,,,,,,,,,,,,, -413300,2.9536288,1.6319549,,,,,,,,,,,,,, -413400,3.2104144,2.3616338,,,,,,,,,,,,,, -413500,3.7637942,3.2784877,,,,,,,,,,,,,, -413600,3.1139283,1.2049714,,,,,,,,,,,,,, -413700,3.046925,1.1632841,,,,,,,,,,,,,, -413728,,,0.8871679306030273,0.4199144542217254,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,184054.52427721024,199978.1976265908,184054.52427721024,15873.348242282867,28.40547013282776,0.0 -413800,3.1327982,1.8541319,,,,,,,,,,,,,, -413900,3.228433,2.5595284,,,,,,,,,,,,,, -414000,3.161647,1.1269337,,,,,,,,,,,,,, -414100,3.368512,1.1871771,,,,,,,,,,,,,, -414200,3.2768314,1.2261724,,,,,,,,,,,,,, -414300,5.003776,3.2307725,,,,,,,,,,,,,, -414400,2.964302,1.0831051,,,,,,,,,,,,,, -414500,3.1568747,2.627192,,,,,,,,,,,,,, -414600,4.739358,3.2497096,,,,,,,,,,,,,, -414672,,,0.88720703125,0.4232935905456543,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,184474.57068252563,200436.07974553108,184474.57068252563,15911.04614019394,28.493316650390625,0.0 -414700,3.2100823,1.2013398,,,,,,,,,,,,,, -414800,3.6670125,2.9348876,,,,,,,,,,,,,, -414900,3.1402316,1.1398627,,,,,,,,,,,,,, -415000,3.0960004,1.0971463,,,,,,,,,,,,,, -415100,2.9767451,1.2252859,,,,,,,,,,,,,, -415200,3.1323726,2.403644,,,,,,,,,,,,,, -415300,3.7016091,3.0604565,,,,,,,,,,,,,, -415400,3.0723708,1.9508603,,,,,,,,,,,,,, -415500,2.9223433,1.1996067,,,,,,,,,,,,,, -415600,3.935661,2.8518348,,,,,,,,,,,,,, -415617,,,0.8864062428474426,0.4224046170711517,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,184894.88948082924,200898.40726947784,184894.88948082924,15952.918644666672,28.580711126327515,0.0 -415700,2.8525991,1.9173396,,,,,,,,,,,,,, -415800,3.0065966,1.0094254,,,,,,,,,,,,,, -415900,3.5003195,2.6141894,,,,,,,,,,,,,, -416000,3.525895,1.2021184,,,,,,,,,,,,,, -416100,3.6820538,3.0488777,,,,,,,,,,,,,, -416200,3.0901966,1.2323246,,,,,,,,,,,,,, -416300,2.917432,1.2334254,,,,,,,,,,,,,, -416400,3.1333547,1.2010787,,,,,,,,,,,,,, -416500,3.4681334,1.3549252,,,,,,,,,,,,,, -416562,,,0.8877148032188416,0.4189739227294922,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,185314.9847421646,201356.72749471664,185314.9847421646,15991.00540137291,28.669321537017822,0.0 -416600,3.4879365,2.9276047,,,,,,,,,,,,,, -416700,3.9785714,3.1491435,,,,,,,,,,,,,, -416800,3.025387,2.6259592,,,,,,,,,,,,,, -416900,3.308645,1.096395,,,,,,,,,,,,,, -417000,3.0489554,1.0901978,,,,,,,,,,,,,, -417100,3.4395268,2.9549077,,,,,,,,,,,,,, -417200,2.9565191,1.8706168,,,,,,,,,,,,,, -417300,2.9121642,1.9993519,,,,,,,,,,,,,, -417400,4.3945336,3.065287,,,,,,,,,,,,,, -417500,3.1682594,1.1847874,,,,,,,,,,,,,, -417507,,,0.8869140148162842,0.4190338551998138,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,185735.00822615623,201816.1709141732,185735.00822615623,16030.271182060242,28.7737877368927,0.0 -417600,3.3384893,2.3331304,,,,,,,,,,,,,, -417700,3.0246909,1.4191476,,,,,,,,,,,,,, -417800,3.0988884,1.1904877,,,,,,,,,,,,,, -417900,3.1644804,1.1172518,,,,,,,,,,,,,, -418000,3.291429,1.1440783,,,,,,,,,,,,,, -418100,3.0977519,1.1261725,,,,,,,,,,,,,, -418200,3.5828447,1.1113546,,,,,,,,,,,,,, -418300,3.2521052,1.2797707,,,,,,,,,,,,,, -418400,2.9227908,1.1503556,,,,,,,,,,,,,, -418450,,,0.8861523270606995,0.425559937953949,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,186155.03788423527,202277.07932949063,186155.03788423527,16071.009669065475,28.864509344100952,0.0 -418500,3.2456124,1.0229447,,,,,,,,,,,,,, -418600,2.9661264,1.4738022,,,,,,,,,,,,,, -418700,3.7257843,3.1648974,,,,,,,,,,,,,, -418800,3.4946625,1.1271857,,,,,,,,,,,,,, -418900,4.4052367,3.2846074,,,,,,,,,,,,,, -419000,3.3065543,2.2297316,,,,,,,,,,,,,, -419100,2.94266,1.2226026,,,,,,,,,,,,,, -419200,3.16223,1.1679175,,,,,,,,,,,,,, -419300,3.0738435,1.1271526,,,,,,,,,,,,,, -419394,,,0.8866015672683716,0.4160337746143341,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,186575.0562889576,202737.35576033592,186575.0562889576,16111.09059047699,28.99151039123535,0.0 -419400,3.487556,2.7443564,,,,,,,,,,,,,, -419500,2.8038933,1.3615505,,,,,,,,,,,,,, -419600,3.33207,2.3331466,,,,,,,,,,,,,, -419700,3.0675647,1.1686172,,,,,,,,,,,,,, -419800,3.5385659,1.0602133,,,,,,,,,,,,,, -419900,3.2803078,1.1844466,,,,,,,,,,,,,, -420000,3.1790414,1.124449,,,,,,,,,,,,,, -420100,3.0671747,1.3269988,,,,,,,,,,,,,, -420200,3.1545768,1.0079119,,,,,,,,,,,,,, -420300,3.5376167,3.0311313,,,,,,,,,,,,,, -420339,,,0.8887109160423279,0.4114374816417694,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,186995.1407394409,203195.7936155796,186995.1407394409,16149.30091905594,29.085197925567627,0.0 -420400,2.8565872,1.0788078,,,,,,,,,,,,,, -420500,3.4464471,1.1334928,,,,,,,,,,,,,, -420600,3.17013,2.7317305,,,,,,,,,,,,,, -420700,3.5018804,2.956967,,,,,,,,,,,,,, -420800,3.291991,1.3135248,,,,,,,,,,,,,, -420900,3.109605,2.4253051,,,,,,,,,,,,,, -421000,3.7940543,3.042857,,,,,,,,,,,,,, -421100,3.3193583,2.3260956,,,,,,,,,,,,,, -421200,2.955635,2.4456427,,,,,,,,,,,,,, -421284,,,0.88978511095047,0.4178528487682342,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,187415.330839634,203653.81957530967,187415.330839634,16186.994090795515,29.1780104637146,0.0 -421300,3.2532427,1.0894797,,,,,,,,,,,,,, -421400,3.081373,1.5227797,,,,,,,,,,,,,, -421500,2.812285,1.0827975,,,,,,,,,,,,,, -421600,3.172572,1.86584,,,,,,,,,,,,,, -421700,3.0706792,1.1323994,,,,,,,,,,,,,, -421800,3.0742004,1.509646,,,,,,,,,,,,,, -421900,3.2421553,1.2220124,,,,,,,,,,,,,, -422000,3.44871,1.425501,,,,,,,,,,,,,, -422100,3.4434445,1.17051,,,,,,,,,,,,,, -422200,3.0572865,1.1825136,,,,,,,,,,,,,, -422228,,,0.8898242115974426,0.412438154220581,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,187835.38063955307,204110.84920215607,187835.38063955307,16223.821682929993,29.28029179573059,0.0 -422300,3.427662,3.018331,,,,,,,,,,,,,, -422400,3.5967915,2.5907247,,,,,,,,,,,,,, -422500,2.7650433,0.9744288,,,,,,,,,,,,,, -422600,3.2818258,1.2193865,,,,,,,,,,,,,, -422700,4.797598,3.1183124,,,,,,,,,,,,,, -422800,3.3115227,1.0875112,,,,,,,,,,,,,, -422900,3.3492103,2.2057614,,,,,,,,,,,,,, -423000,3.4948242,1.2223175,,,,,,,,,,,,,, -423100,4.1326995,3.2684956,,,,,,,,,,,,,, -423172,,,0.8860546946525574,0.4179321825504303,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,188255.40531373024,204571.80442857745,188255.40531373024,16264.61352443695,29.36870121955872,0.0 -423200,4.2563686,3.3478265,,,,,,,,,,,,,, -423300,3.1895137,2.1245842,,,,,,,,,,,,,, -423400,2.8552735,1.3619225,,,,,,,,,,,,,, -423500,3.2337205,1.1792735,,,,,,,,,,,,,, -423600,4.3059683,1.175262,,,,,,,,,,,,,, -423700,3.7370286,1.2236692,,,,,,,,,,,,,, -423800,3.1748965,1.6919937,,,,,,,,,,,,,, -423900,3.5006664,1.9917295,,,,,,,,,,,,,, -424000,3.2065592,2.902673,,,,,,,,,,,,,, -424100,3.8335402,2.70484,,,,,,,,,,,,,, -424115,,,0.8898437023162842,0.4148611426353454,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,188675.36987829208,205033.4146347046,188675.36987829208,16306.104352474213,29.47391629219055,0.0 -424200,3.7646625,3.230331,,,,,,,,,,,,,, -424300,3.1764286,1.1667818,,,,,,,,,,,,,, -424400,2.864112,1.0990863,,,,,,,,,,,,,, -424500,3.5507708,3.0598302,,,,,,,,,,,,,, -424600,3.1773288,1.5359039,,,,,,,,,,,,,, -424700,3.045935,1.1336744,,,,,,,,,,,,,, -424800,4.235486,3.0178094,,,,,,,,,,,,,, -424900,3.270548,1.0567715,,,,,,,,,,,,,, -425000,3.9080977,3.030805,,,,,,,,,,,,,, -425058,,,0.8862499594688416,0.4217503368854522,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,189095.49055552483,205494.103537798,189095.49055552483,16346.529787540436,29.56756401062012,0.0 -425100,4.0065136,3.360274,,,,,,,,,,,,,, -425200,3.176439,1.1730604,,,,,,,,,,,,,, -425300,2.9904256,1.0105592,,,,,,,,,,,,,, -425400,2.917611,1.1148232,,,,,,,,,,,,,, -425500,3.268929,2.763158,,,,,,,,,,,,,, -425600,3.237555,1.248824,,,,,,,,,,,,,, -425700,3.0967157,1.1676178,,,,,,,,,,,,,, -425800,3.0804446,2.300616,,,,,,,,,,,,,, -425900,3.3996675,1.1574128,,,,,,,,,,,,,, -426000,,,0.88818359375,0.4202206432819366,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,189515.35760688785,205952.8334183693,189515.35760688785,16385.24165081978,29.66967749595642,0.0 -426000,3.444737,1.2499781,,,,,,,,,,,,,, -426100,3.5270765,2.2663999,,,,,,,,,,,,,, -426200,3.1616616,1.599279,,,,,,,,,,,,,, -426300,2.9613745,1.52934,,,,,,,,,,,,,, -426400,2.9660826,1.1448632,,,,,,,,,,,,,, -426500,2.9050229,1.6495612,,,,,,,,,,,,,, -426600,3.3000247,1.2050779,,,,,,,,,,,,,, -426700,3.1969426,1.0790918,,,,,,,,,,,,,, -426800,2.8367102,1.1206715,,,,,,,,,,,,,, -426900,3.3524518,1.2043337,,,,,,,,,,,,,, -426944,,,0.8882812261581421,0.4174784421920776,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,189935.4038639069,206410.42110538483,189935.4038639069,16422.631457805634,29.771843910217285,0.0 -427000,3.110662,1.1788663,,,,,,,,,,,,,, -427100,3.4226937,2.7073393,,,,,,,,,,,,,, -427200,3.1383135,2.5874615,,,,,,,,,,,,,, -427300,2.9715245,1.0813096,,,,,,,,,,,,,, -427400,3.0854633,1.0657524,,,,,,,,,,,,,, -427500,3.587711,1.656046,,,,,,,,,,,,,, -427600,3.2580938,1.0983136,,,,,,,,,,,,,, -427700,3.2574816,1.1399244,,,,,,,,,,,,,, -427800,3.2101195,1.1426947,,,,,,,,,,,,,, -427886,,,0.8866015672683716,0.4225281476974487,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,190355.65117049217,206870.26315903664,190355.65117049217,16462.080998659134,29.86714720726013,0.0 -427900,3.3107803,1.085448,,,,,,,,,,,,,, -428000,2.8181443,1.5742652,,,,,,,,,,,,,, -428100,3.2581975,1.2076254,,,,,,,,,,,,,, -428200,3.0041716,1.1734065,,,,,,,,,,,,,, -428300,2.9933915,1.0815572,,,,,,,,,,,,,, -428400,2.9383621,1.1075453,,,,,,,,,,,,,, -428500,3.2922895,2.7467475,,,,,,,,,,,,,, -428600,3.0977235,1.1797308,,,,,,,,,,,,,, -428700,3.0216055,1.9881203,,,,,,,,,,,,,, -428800,4.265049,3.1353614,,,,,,,,,,,,,, -428830,,,0.8874022960662842,0.4188739955425262,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,190775.9188687801,207331.3117365837,190775.9188687801,16502.711376667023,29.968742847442627,0.0 -428900,3.0733204,1.0460682,,,,,,,,,,,,,, -429000,3.2428164,1.9813954,,,,,,,,,,,,,, -429100,3.4105098,1.1843389,,,,,,,,,,,,,, -429200,3.9818447,2.7374394,,,,,,,,,,,,,, -429300,3.3704517,2.6512518,,,,,,,,,,,,,, -429400,3.3172135,1.1255041,,,,,,,,,,,,,, -429500,3.0690064,1.6473991,,,,,,,,,,,,,, -429600,2.9498541,1.2174051,,,,,,,,,,,,,, -429700,2.8659852,1.6795851,,,,,,,,,,,,,, -429774,,,0.8849608898162842,0.4287216365337372,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,191195.83547329903,207790.8137485981,191195.83547329903,16542.15592098236,30.05966424942017,0.0 -429800,3.1446266,1.0699393,,,,,,,,,,,,,, -429900,2.9522183,1.9420528,,,,,,,,,,,,,, -430000,3.2040327,1.1815567,,,,,,,,,,,,,, -430100,2.861394,0.99122643,,,,,,,,,,,,,, -430200,3.4725404,3.0331,,,,,,,,,,,,,, -430300,2.9985652,2.0528665,,,,,,,,,,,,,, -430400,3.668629,2.9572747,,,,,,,,,,,,,, -430500,3.1903076,1.0786357,,,,,,,,,,,,,, -430600,3.0209105,1.2690866,,,,,,,,,,,,,, -430700,3.1756406,1.1306338,,,,,,,,,,,,,, -430718,,,0.8863085508346558,0.4251331090927124,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,191615.99197554588,208250.25497484207,191615.99197554588,16581.287677764893,30.16299033164978,0.0 -430800,3.2683437,1.1478336,,,,,,,,,,,,,, -430900,4.978825,1.6665764,,,,,,,,,,,,,, -431000,3.0167081,1.4452082,,,,,,,,,,,,,, -431100,2.9483576,1.045182,,,,,,,,,,,,,, -431200,3.0726798,1.1528236,,,,,,,,,,,,,, -431300,3.2030208,2.5288868,,,,,,,,,,,,,, -431400,2.865277,2.4514644,,,,,,,,,,,,,, -431500,4.044808,3.316748,,,,,,,,,,,,,, -431600,3.0193605,1.294493,,,,,,,,,,,,,, -431661,,,0.8882616758346558,0.4132554531097412,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,192036.2289819717,208712.9606354237,192036.2289819717,16623.598644971848,30.271098613739014,0.0 -431700,3.5529635,2.9897606,,,,,,,,,,,,,, -431800,3.072617,2.7387023,,,,,,,,,,,,,, -431900,3.4957297,1.7116979,,,,,,,,,,,,,, -432000,3.5277288,2.7680259,,,,,,,,,,,,,, -432100,3.1335049,1.136485,,,,,,,,,,,,,, -432200,3.913435,3.1497943,,,,,,,,,,,,,, -432300,3.885833,1.7210478,,,,,,,,,,,,,, -432400,3.3399932,2.084056,,,,,,,,,,,,,, -432500,2.9967608,1.6382737,,,,,,,,,,,,,, -432600,2.9092197,1.3334668,,,,,,,,,,,,,, -432603,,,0.8885155916213989,0.4181333482265472,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,192456.2843079567,209173.80654668808,192456.2843079567,16664.247750520706,30.36311721801757,0.0 -432700,3.0411768,1.2033271,,,,,,,,,,,,,, -432800,2.9087005,1.2355624,,,,,,,,,,,,,, -432900,3.1199992,1.8821727,,,,,,,,,,,,,, -433000,3.912892,3.2254736,,,,,,,,,,,,,, -433100,3.3343382,1.4501584,,,,,,,,,,,,,, -433200,2.8521845,1.5120459,,,,,,,,,,,,,, -433300,3.2418187,1.0778114,,,,,,,,,,,,,, -433400,3.3578913,1.1623082,,,,,,,,,,,,,, -433500,4.051286,1.1048192,,,,,,,,,,,,,, -433548,,,0.8854687213897705,0.4237488210201263,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,192876.3270497322,209633.1222922802,192876.3270497322,16703.37880063057,30.455291032791138,0.0 -433600,4.034084,3.263236,,,,,,,,,,,,,, -433700,2.851738,2.225443,,,,,,,,,,,,,, -433800,3.5015736,1.6978271,,,,,,,,,,,,,, -433900,3.3009367,1.3901911,,,,,,,,,,,,,, -434000,3.0689874,1.4839231,,,,,,,,,,,,,, -434100,2.837987,1.0284561,,,,,,,,,,,,,, -434200,3.1828046,1.0350432,,,,,,,,,,,,,, -434300,3.068608,1.1843393,,,,,,,,,,,,,, -434400,3.309285,3.0365696,,,,,,,,,,,,,, -434493,,,0.8869531154632568,0.4174193441867828,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,193296.5520606041,210093.69893980023,193296.5520606041,16743.577922344208,30.5573947429657,0.0 -434500,2.9488444,1.0255442,,,,,,,,,,,,,, -434600,3.3347497,1.7394749,,,,,,,,,,,,,, -434700,2.938275,1.9684044,,,,,,,,,,,,,, -434800,3.1116385,1.0585301,,,,,,,,,,,,,, -434900,3.209213,1.1827697,,,,,,,,,,,,,, -435000,3.5545053,2.954904,,,,,,,,,,,,,, -435100,3.503204,2.6410666,,,,,,,,,,,,,, -435200,3.6556058,2.9556365,,,,,,,,,,,,,, -435300,3.1083481,2.4348514,,,,,,,,,,,,,, -435400,3.4179354,1.0534949,,,,,,,,,,,,,, -435435,,,0.8883007764816284,0.4161953330039978,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,193716.42161488533,210554.7790727616,193716.42161488533,16784.6428296566,30.652455806732178,0.0 -435500,4.02334,2.5938852,,,,,,,,,,,,,, -435600,4.138628,3.0072768,,,,,,,,,,,,,, -435700,3.3073335,1.226692,,,,,,,,,,,,,, -435800,4.732928,3.341864,,,,,,,,,,,,,, -435900,3.2129898,1.1445392,,,,,,,,,,,,,, -436000,3.2670534,1.2483412,,,,,,,,,,,,,, -436100,3.1940236,2.7009332,,,,,,,,,,,,,, -436200,2.897296,1.1336062,,,,,,,,,,,,,, -436300,3.128818,1.0807792,,,,,,,,,,,,,, -436378,,,0.8877539038658142,0.4231440126895904,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,194136.64938545227,211019.8227727413,194136.64938545227,16829.308379650116,30.753374576568604,0.0 -436400,3.181722,2.6490963,,,,,,,,,,,,,, -436500,3.710665,3.1280298,,,,,,,,,,,,,, -436600,3.2619681,1.1417842,,,,,,,,,,,,,, -436700,3.8364053,3.244672,,,,,,,,,,,,,, -436800,3.1018846,1.0927924,,,,,,,,,,,,,, -436900,2.9853857,1.1444057,,,,,,,,,,,,,, -437000,3.592429,2.0376763,,,,,,,,,,,,,, -437100,3.0047054,2.081293,,,,,,,,,,,,,, -437200,4.237233,3.174072,,,,,,,,,,,,,, -437300,2.8426554,1.3176733,,,,,,,,,,,,,, -437322,,,0.8863866925239563,0.4232368469238281,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,194556.91301631927,211476.6163315773,194556.91301631927,16865.685774326324,30.85575246810913,0.0 -437400,3.1771214,1.807518,,,,,,,,,,,,,, -437500,3.2264328,1.3505557,,,,,,,,,,,,,, -437600,3.5694547,2.8154473,,,,,,,,,,,,,, -437700,2.9601688,1.0731341,,,,,,,,,,,,,, -437800,3.6452591,3.3183954,,,,,,,,,,,,,, -437900,3.7343824,2.7643266,,,,,,,,,,,,,, -438000,3.417772,2.7801416,,,,,,,,,,,,,, -438100,3.3390841,1.1682167,,,,,,,,,,,,,, -438200,3.110332,1.1128399,,,,,,,,,,,,,, -438268,,,0.8855273127555847,0.42439004778862,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,194977.04292821884,211936.03138518333,194977.04292821884,16904.828684091568,30.9488582611084,0.0 -438300,3.111045,1.8106258,,,,,,,,,,,,,, -438400,3.3648264,2.2165458,,,,,,,,,,,,,, -438500,3.5654037,2.079371,,,,,,,,,,,,,, -438600,2.9184864,1.073409,,,,,,,,,,,,,, -438700,3.4035566,1.1558409,,,,,,,,,,,,,, -438800,3.8838289,3.1035974,,,,,,,,,,,,,, -438900,3.4825766,2.6857467,,,,,,,,,,,,,, -439000,4.0360928,1.4958422,,,,,,,,,,,,,, -439100,3.3725615,2.0150893,,,,,,,,,,,,,, -439200,3.0908716,2.5713973,,,,,,,,,,,,,, -439212,,,0.8867773413658142,0.4227740466594696,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,195397.16938996315,212395.8071053028,195397.16938996315,16944.33552479744,31.040956497192383,0.0 -439300,3.082624,1.4599321,,,,,,,,,,,,,, -439400,8.598316,2.0244167,,,,,,,,,,,,,, -439500,4.412457,3.0913255,,,,,,,,,,,,,, -439600,3.0115275,1.0703171,,,,,,,,,,,,,, -439700,3.44501,2.6066763,,,,,,,,,,,,,, -439800,2.9732866,1.6442866,,,,,,,,,,,,,, -439900,3.2209077,1.1934142,,,,,,,,,,,,,, -440000,3.3434618,2.63418,,,,,,,,,,,,,, -440100,3.3139627,1.2859617,,,,,,,,,,,,,, -440157,,,0.8872460722923279,0.4181328415870666,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,195817.39040994644,212857.3381161689,195817.39040994644,16985.492069721222,31.14467477798462,0.0 -440200,2.7392495,1.3153512,,,,,,,,,,,,,, -440300,3.392276,1.119281,,,,,,,,,,,,,, -440400,3.3953772,2.9816873,,,,,,,,,,,,,, -440500,3.3432348,1.3572296,,,,,,,,,,,,,, -440600,4.7286468,3.0789664,,,,,,,,,,,,,, -440700,2.9519906,1.3615365,,,,,,,,,,,,,, -440800,3.001326,2.0395513,,,,,,,,,,,,,, -440900,3.1906152,1.1382954,,,,,,,,,,,,,, -441000,3.2305527,2.9101098,,,,,,,,,,,,,, -441099,,,0.8878124952316284,0.4193216562271118,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,196237.36933875084,213319.05184865,196237.36933875084,17027.075122833252,31.24666905403137,0.0 -441100,3.6392064,2.8614109,,,,,,,,,,,,,, -441200,3.2727888,2.5749478,,,,,,,,,,,,,, -441300,3.0950942,1.9053462,,,,,,,,,,,,,, -441400,2.8836265,1.1622645,,,,,,,,,,,,,, -441500,2.9722106,1.0330408,,,,,,,,,,,,,, -441600,3.068875,1.4737589,,,,,,,,,,,,,, -441700,4.100543,3.2700896,,,,,,,,,,,,,, -441800,3.1709363,1.4173343,,,,,,,,,,,,,, -441900,4.901283,3.312314,,,,,,,,,,,,,, -442000,4.070342,3.2653232,,,,,,,,,,,,,, -442043,,,0.8869140148162842,0.4169472157955169,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,196657.33681559563,213780.749254942,196657.33681559563,17068.658164978027,31.343939542770386,0.0 -442100,2.7488632,1.4462755,,,,,,,,,,,,,, -442200,3.6718762,3.012942,,,,,,,,,,,,,, -442300,3.1927178,1.5892987,,,,,,,,,,,,,, -442400,3.280996,1.1692731,,,,,,,,,,,,,, -442500,4.3481035,2.675478,,,,,,,,,,,,,, -442600,4.7524548,2.9645023,,,,,,,,,,,,,, -442700,4.2536707,3.0927262,,,,,,,,,,,,,, -442800,3.245615,1.3146526,,,,,,,,,,,,,, -442900,3.0230036,1.7972158,,,,,,,,,,,,,, -442989,,,0.8885155916213989,0.4157116115093231,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,197077.5290656089,214236.43349838257,197077.5290656089,17104.009438753128,31.43527889251709,0.0 -443000,3.020325,1.2001467,,,,,,,,,,,,,, -443100,3.0995243,1.1483197,,,,,,,,,,,,,, -443200,3.287846,1.3235861,,,,,,,,,,,,,, -443300,3.3700616,1.0855948,,,,,,,,,,,,,, -443400,2.9026291,1.0505247,,,,,,,,,,,,,, -443500,3.0621734,1.1158127,,,,,,,,,,,,,, -443600,3.3209188,1.2044988,,,,,,,,,,,,,, -443700,3.3579164,1.1680303,,,,,,,,,,,,,, -443800,3.1655178,1.1671157,,,,,,,,,,,,,, -443900,3.4326804,2.2947152,,,,,,,,,,,,,, -443934,,,0.8871288895606995,0.4241305589675903,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,197497.7634282112,214700.71033906937,197497.7634282112,17147.90717315674,31.5306293964386,0.0 -444000,3.1768303,1.8236641,,,,,,,,,,,,,, -444100,3.480071,1.1203222,,,,,,,,,,,,,, -444200,3.6932366,1.1774895,,,,,,,,,,,,,, -444300,3.4336936,2.3743103,,,,,,,,,,,,,, -444400,3.4256637,1.0314466,,,,,,,,,,,,,, -444500,3.1090698,1.0316978,,,,,,,,,,,,,, -444600,2.938248,1.3209727,,,,,,,,,,,,,, -444700,3.151771,1.3222667,,,,,,,,,,,,,, -444800,3.0181468,1.0870116,,,,,,,,,,,,,, -444870,,,0.8898242115974426,0.411442756652832,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,197917.61884331703,215160.68771243087,197917.61884331703,17187.872669696808,31.63857674598694,0.0 -444900,2.953042,1.1926788,,,,,,,,,,,,,, -445000,3.091371,1.1213014,,,,,,,,,,,,,, -445100,3.3326156,2.9822433,,,,,,,,,,,,,, -445200,4.1718354,3.1927252,,,,,,,,,,,,,, -445300,3.7416098,1.1894976,,,,,,,,,,,,,, -445400,3.3367915,2.0705073,,,,,,,,,,,,,, -445500,3.1425571,1.1773024,,,,,,,,,,,,,, -445600,3.0475705,1.0951905,,,,,,,,,,,,,, -445700,3.1897519,1.0473838,,,,,,,,,,,,,, -445800,3.1055512,1.1626093,,,,,,,,,,,,,, -445811,,,0.8890820145606995,0.4095396399497986,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,198337.8225800991,215616.0204000473,198337.8225800991,17222.858201503754,31.731759071350098,0.0 -445900,3.8689997,3.3315654,,,,,,,,,,,,,, -446000,3.3019671,1.0287924,,,,,,,,,,,,,, -446100,3.1749601,1.2007785,,,,,,,,,,,,,, -446200,3.7193651,3.0350444,,,,,,,,,,,,,, -446300,3.2639234,1.3262523,,,,,,,,,,,,,, -446400,3.0296135,2.8217783,,,,,,,,,,,,,, -446500,3.2212336,1.6744698,,,,,,,,,,,,,, -446600,2.9537554,1.2183119,,,,,,,,,,,,,, -446700,3.2305892,1.0961225,,,,,,,,,,,,,, -446752,,,0.8875976204872131,0.4188170135021209,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,198758.01270985603,216078.06741833687,198758.01270985603,17264.571006774902,31.82603144645691,0.0 -446800,3.2725368,1.2102739,,,,,,,,,,,,,, -446900,3.4420729,2.7638435,,,,,,,,,,,,,, -447000,2.937728,1.1480863,,,,,,,,,,,,,, -447100,4.019305,3.3270216,,,,,,,,,,,,,, -447200,3.470448,1.1582291,,,,,,,,,,,,,, -447300,3.4478319,3.0585163,,,,,,,,,,,,,, -447400,3.4125154,1.2374767,,,,,,,,,,,,,, -447500,3.1530187,1.1818136,,,,,,,,,,,,,, -447600,3.1312983,1.3599145,,,,,,,,,,,,,, -447695,,,0.8876367211341858,0.4201096594333648,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,199178.1304731369,216537.0774514675,199178.1304731369,17303.306176424026,31.93401432037353,0.0 -447700,3.8501213,2.8661296,,,,,,,,,,,,,, -447800,3.145072,1.1773612,,,,,,,,,,,,,, -447900,3.7469685,2.9498863,,,,,,,,,,,,,, -448000,2.942177,1.2994964,,,,,,,,,,,,,, -448100,3.0008974,1.0815381,,,,,,,,,,,,,, -448200,3.3200572,2.780374,,,,,,,,,,,,,, -448300,3.4859183,3.1593075,,,,,,,,,,,,,, -448400,2.999038,1.063364,,,,,,,,,,,,,, -448500,3.4180822,1.2701832,,,,,,,,,,,,,, -448600,2.8702726,1.8071808,,,,,,,,,,,,,, -448636,,,0.8866015672683716,0.421945184469223,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,199598.13326835632,216996.54115891457,199598.13326835632,17342.620814323425,32.0307891368866,0.0 -448700,3.6027796,1.0044847,,,,,,,,,,,,,, -448800,3.0315287,2.1199667,,,,,,,,,,,,,, -448900,3.3056214,1.9856372,,,,,,,,,,,,,, -449000,3.2307081,1.1860505,,,,,,,,,,,,,, -449100,3.214601,1.224386,,,,,,,,,,,,,, -449200,3.187268,1.0953441,,,,,,,,,,,,,, -449300,3.165491,1.0738003,,,,,,,,,,,,,, -449400,3.0891924,1.1244588,,,,,,,,,,,,,, -449500,3.2754693,1.2775105,,,,,,,,,,,,,, -449577,,,0.8864648342132568,0.4194687902927398,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,200018.26824116707,217456.16999864567,200018.26824116707,17381.971882104874,32.12383246421814,0.0 -449600,3.3280706,2.2766051,,,,,,,,,,,,,, -449700,3.0763438,2.2930605,,,,,,,,,,,,,, -449800,5.5678525,3.2762055,,,,,,,,,,,,,, -449900,3.0644214,1.4015932,,,,,,,,,,,,,, -450000,4.2646737,2.712857,,,,,,,,,,,,,, -450100,3.4283223,1.3757247,,,,,,,,,,,,,, -450200,3.5874598,1.0624964,,,,,,,,,,,,,, -450300,2.8607097,1.5123194,,,,,,,,,,,,,, -450400,3.3262727,2.5394807,,,,,,,,,,,,,, -450500,4.2789407,3.1918187,,,,,,,,,,,,,, -450518,,,0.89013671875,0.4164175689220428,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,200438.24313998225,217918.1436984539,200438.24313998225,17423.806594848633,32.23769760131836,0.0 -450600,3.2133582,1.2443433,,,,,,,,,,,,,, -450700,3.0773737,1.5359764,,,,,,,,,,,,,, -450800,2.9770398,1.1103916,,,,,,,,,,,,,, -450900,3.177651,1.1687795,,,,,,,,,,,,,, -451000,2.9738252,2.2318883,,,,,,,,,,,,,, -451100,2.9194224,1.1775548,,,,,,,,,,,,,, -451200,3.252178,1.149864,,,,,,,,,,,,,, -451300,3.4433374,2.6899529,,,,,,,,,,,,,, -451400,3.094037,1.1397215,,,,,,,,,,,,,, -451461,,,0.8877539038658142,0.4214079082012176,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,200858.3983180523,218379.80024075508,200858.3983180523,17465.161774635315,32.33136487007141,0.0 -451500,3.2792118,2.3644047,,,,,,,,,,,,,, -451600,2.888766,1.0094283,,,,,,,,,,,,,, -451700,2.9824686,1.0980852,,,,,,,,,,,,,, -451800,3.279838,1.1190017,,,,,,,,,,,,,, -451900,3.1598318,1.6315587,,,,,,,,,,,,,, -452000,3.2136896,1.0202252,,,,,,,,,,,,,, -452100,3.1290638,1.154885,,,,,,,,,,,,,, -452200,3.0355463,1.6741636,,,,,,,,,,,,,, -452300,3.4354136,2.5465138,,,,,,,,,,,,,, -452400,3.071377,1.153166,,,,,,,,,,,,,, -452405,,,0.8856835961341858,0.4210818409919739,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,201278.27744483948,218841.53184318545,201278.27744483948,17506.869884729385,32.42599129676819,0.0 -452500,4.0647736,1.3028742,,,,,,,,,,,,,, -452600,3.8732777,3.0590985,,,,,,,,,,,,,, -452700,3.4488618,2.7842853,,,,,,,,,,,,,, -452800,3.12849,1.4933249,,,,,,,,,,,,,, -452900,3.123797,1.1401842,,,,,,,,,,,,,, -453000,3.2451987,1.2022651,,,,,,,,,,,,,, -453100,3.313949,1.1089183,,,,,,,,,,,,,, -453200,3.7287672,3.0463574,,,,,,,,,,,,,, -453300,3.1554172,2.314773,,,,,,,,,,,,,, -453348,,,0.88720703125,0.4244422912597656,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,201698.52309155464,219302.0693335533,201698.52309155464,17547.012185573578,32.526304721832275,0.0 -453400,3.2169461,2.7538087,,,,,,,,,,,,,, -453500,3.742528,2.772138,,,,,,,,,,,,,, -453600,3.4029603,2.290143,,,,,,,,,,,,,, -453700,3.135972,1.1181309,,,,,,,,,,,,,, -453800,3.3714032,1.247557,,,,,,,,,,,,,, -453900,2.92055,1.236494,,,,,,,,,,,,,, -454000,3.1309772,1.2383667,,,,,,,,,,,,,, -454100,3.1771894,0.98987746,,,,,,,,,,,,,, -454200,3.361863,1.5022311,,,,,,,,,,,,,, -454290,,,0.8869140148162842,0.421396404504776,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,202118.48056221008,219758.2660264969,202118.48056221008,17583.095573425293,32.6318883895874,0.0 -454300,3.113417,2.1205459,,,,,,,,,,,,,, -454400,3.3338103,2.3334439,,,,,,,,,,,,,, -454500,3.2369735,1.8476774,,,,,,,,,,,,,, -454600,3.2988756,1.1454037,,,,,,,,,,,,,, -454700,2.9961033,1.2050363,,,,,,,,,,,,,, -454800,3.7577262,3.1797578,,,,,,,,,,,,,, -454900,3.3355532,2.7053165,,,,,,,,,,,,,, -455000,2.9773457,1.6473355,,,,,,,,,,,,,, -455100,3.9000235,3.1130486,,,,,,,,,,,,,, -455200,2.9045095,1.2269245,,,,,,,,,,,,,, -455235,,,0.8871093392372131,0.4167193472385406,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,202538.690376997,220219.6612966061,202538.690376997,17624.13773369789,32.72527503967285,0.0 -455300,3.1346107,2.4433143,,,,,,,,,,,,,, -455400,3.1684692,1.3024901,,,,,,,,,,,,,, -455500,3.1688616,2.0657656,,,,,,,,,,,,,, -455600,3.0489144,1.1850294,,,,,,,,,,,,,, -455700,3.1141715,1.1160681,,,,,,,,,,,,,, -455800,3.2490642,1.2684007,,,,,,,,,,,,,, -455900,2.935302,1.0765647,,,,,,,,,,,,,, -456000,3.0663788,1.1314328,,,,,,,,,,,,,, -456100,3.1906154,1.4233245,,,,,,,,,,,,,, -456178,,,0.8875585794448853,0.4189115762710571,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,202958.80642414093,220681.3356461525,202958.80642414093,17665.55302143097,32.81936073303223,0.0 -456200,3.6829197,3.0795786,,,,,,,,,,,,,, -456300,3.4627287,2.3410807,,,,,,,,,,,,,, -456400,3.4522243,1.1362658,,,,,,,,,,,,,, -456500,3.3877766,1.5040654,,,,,,,,,,,,,, -456600,3.0199914,1.1188374,,,,,,,,,,,,,, -456700,2.747053,1.6163235,,,,,,,,,,,,,, -456800,3.4033284,1.1168504,,,,,,,,,,,,,, -456900,3.3079236,1.1613955,,,,,,,,,,,,,, -457000,3.2937431,2.8316295,,,,,,,,,,,,,, -457100,3.3484805,2.6324604,,,,,,,,,,,,,, -457122,,,0.8851562142372131,0.426841527223587,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,203379.1025440693,221136.3347158432,203379.1025440693,17700.11089038849,32.914966344833374,0.0 -457200,3.313021,1.1448349,,,,,,,,,,,,,, -457300,2.8014917,1.1152554,,,,,,,,,,,,,, -457400,3.0268025,1.1571249,,,,,,,,,,,,,, -457500,3.5284243,2.9276648,,,,,,,,,,,,,, -457600,3.1733294,1.0824724,,,,,,,,,,,,,, -457700,2.9622343,1.7176424,,,,,,,,,,,,,, -457800,3.090155,1.0775278,,,,,,,,,,,,,, -457900,4.17594,1.1982632,,,,,,,,,,,,,, -458000,3.5440323,2.5701346,,,,,,,,,,,,,, -458065,,,0.8870898485183716,0.4176160097122192,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,203799.38678312305,221600.2426137924,203799.38678312305,17743.58947658539,33.01036882400513,0.0 -458100,3.4952075,1.6574304,,,,,,,,,,,,,, -458200,3.1584451,1.1457505,,,,,,,,,,,,,, -458300,3.1661255,1.9726071,,,,,,,,,,,,,, -458400,3.1531985,1.6368864,,,,,,,,,,,,,, -458500,3.0109603,1.2110033,,,,,,,,,,,,,, -458600,3.1628807,1.2274777,,,,,,,,,,,,,, -458700,3.5041084,1.2567805,,,,,,,,,,,,,, -458800,3.224198,2.6060672,,,,,,,,,,,,,, -458900,3.2910757,1.5636383,,,,,,,,,,,,,, -459000,3.4899127,1.6476942,,,,,,,,,,,,,, -459006,,,0.8886132836341858,0.4135399162769317,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,204219.39708042145,222056.5509133339,204219.39708042145,17779.73451113701,33.1145339012146,0.0 -459100,3.4494386,2.4697676,,,,,,,,,,,,,, -459200,3.1867669,2.564142,,,,,,,,,,,,,, -459300,3.2297626,1.6253076,,,,,,,,,,,,,, -459400,3.242054,1.1229773,,,,,,,,,,,,,, -459500,3.2563834,1.7636837,,,,,,,,,,,,,, -459600,3.3591537,1.1852789,,,,,,,,,,,,,, -459700,4.0088773,3.3077095,,,,,,,,,,,,,, -459800,3.0191002,2.0546365,,,,,,,,,,,,,, -459900,3.2944374,2.3096404,,,,,,,,,,,,,, -459950,,,0.8851171731948853,0.4251580834388733,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,204639.460521698,222517.1409971714,204639.460521698,17820.108170747757,33.21669864654541,0.0 -460000,3.1245055,1.0659348,,,,,,,,,,,,,, -460100,3.8233907,3.207649,,,,,,,,,,,,,, -460200,3.1911426,1.1715986,,,,,,,,,,,,,, -460300,3.0110588,1.8130642,,,,,,,,,,,,,, -460400,3.0162692,1.1880263,,,,,,,,,,,,,, -460500,2.9597905,1.1477413,,,,,,,,,,,,,, -460600,2.9975123,1.0778304,,,,,,,,,,,,,, -460700,3.353394,2.0760221,,,,,,,,,,,,,, -460800,6.341416,3.0388472,,,,,,,,,,,,,, -460892,,,0.8870507478713989,0.4238626360893249,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,205060.09689879417,222973.8101565838,205060.09689879417,17855.99463057518,33.313735246658325,0.0 -460900,3.111992,2.06976,,,,,,,,,,,,,, -461000,3.1942966,1.2324498,,,,,,,,,,,,,, -461100,3.1055474,1.0542057,,,,,,,,,,,,,, -461200,2.949736,1.1230111,,,,,,,,,,,,,, -461300,3.9910681,1.6561382,,,,,,,,,,,,,, -461400,4.4192786,3.147977,,,,,,,,,,,,,, -461500,3.5775306,1.1239691,,,,,,,,,,,,,, -461600,3.5020003,2.2835333,,,,,,,,,,,,,, -461700,4.1148863,2.3225598,,,,,,,,,,,,,, -461800,3.0092683,1.1538061,,,,,,,,,,,,,, -461831,,,0.8887695074081421,0.4198735058307647,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,205480.0987341404,223434.17816066745,205480.0987341404,17896.213749408722,33.410961866378784,0.0 -461900,4.079675,2.850577,,,,,,,,,,,,,, -462000,3.1132832,1.1279457,,,,,,,,,,,,,, -462100,3.1057289,1.3238902,,,,,,,,,,,,,, -462200,3.664199,3.1594076,,,,,,,,,,,,,, -462300,3.2324626,1.8060274,,,,,,,,,,,,,, -462400,3.2292528,1.208457,,,,,,,,,,,,,, -462500,3.2329865,0.99664146,,,,,,,,,,,,,, -462600,3.1540282,1.2513919,,,,,,,,,,,,,, -462700,3.2705555,1.0960021,,,,,,,,,,,,,, -462776,,,0.8857616782188416,0.4240389466285705,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,205900.09218215945,223893.4460091591,205900.09218215945,17935.342923879623,33.50675344467163,0.0 -462800,3.0780776,1.1059809,,,,,,,,,,,,,, -462900,3.1721144,2.2566295,,,,,,,,,,,,,, -463000,3.9213257,3.1465383,,,,,,,,,,,,,, -463100,3.4743185,1.2440984,,,,,,,,,,,,,, -463200,3.0477803,1.1746821,,,,,,,,,,,,,, -463300,4.312489,3.0265787,,,,,,,,,,,,,, -463400,2.875763,1.0420432,,,,,,,,,,,,,, -463500,3.5513043,1.182971,,,,,,,,,,,,,, -463600,2.9317803,1.1895692,,,,,,,,,,,,,, -463700,3.2871368,1.1769826,,,,,,,,,,,,,, -463722,,,0.8868945240974426,0.4198920726776123,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,206320.3325078488,224348.5914018154,206320.3325078488,17970.10102057457,33.603896617889404,0.0 -463800,2.9994276,1.2447621,,,,,,,,,,,,,, -463900,3.1101284,1.0470723,,,,,,,,,,,,,, -464000,3.3290396,1.4381852,,,,,,,,,,,,,, -464100,3.089249,1.1731026,,,,,,,,,,,,,, -464200,2.9003425,1.1081444,,,,,,,,,,,,,, -464300,3.9018052,3.0608585,,,,,,,,,,,,,, -464400,3.063878,1.0832208,,,,,,,,,,,,,, -464500,3.2230244,1.0222974,,,,,,,,,,,,,, -464600,3.5524557,2.7939801,,,,,,,,,,,,,, -464664,,,0.8875585794448853,0.4207744896411896,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,206740.5910184384,224805.5076739788,206740.5910184384,18006.612882614136,33.70057702064514,0.0 -464700,3.079654,2.213248,,,,,,,,,,,,,, -464800,3.2538853,2.6648047,,,,,,,,,,,,,, -464900,3.9320323,3.177896,,,,,,,,,,,,,, -465000,3.2257993,1.069666,,,,,,,,,,,,,, -465100,3.3393855,2.6826386,,,,,,,,,,,,,, -465200,3.0124624,1.1627197,,,,,,,,,,,,,, -465300,6.0044575,2.5918775,,,,,,,,,,,,,, -465400,4.2246327,3.0901694,,,,,,,,,,,,,, -465500,3.1277626,1.3770708,,,,,,,,,,,,,, -465600,3.1818218,1.870362,,,,,,,,,,,,,, -465607,,,0.8863866925239563,0.4194602966308594,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,207160.64344906807,225263.02623271945,207160.64344906807,18043.933857917786,33.79648423194885,0.0 -465700,3.052227,2.5495977,,,,,,,,,,,,,, -465800,3.196932,1.5033314,,,,,,,,,,,,,, -465900,3.1906862,1.173809,,,,,,,,,,,,,, -466000,3.1886978,2.341005,,,,,,,,,,,,,, -466100,3.1733,1.0886289,,,,,,,,,,,,,, -466200,3.3768628,2.7395985,,,,,,,,,,,,,, -466300,3.2685723,2.8474913,,,,,,,,,,,,,, -466400,3.320424,1.3621953,,,,,,,,,,,,,, -466500,2.9858618,1.1319584,,,,,,,,,,,,,, -466553,,,0.88783198595047,0.4138371646404266,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,207580.8476507664,225721.3563463688,207580.8476507664,18081.914219379425,33.89208626747131,0.0 -466600,3.1381474,1.3241018,,,,,,,,,,,,,, -466700,3.1447515,1.5735276,,,,,,,,,,,,,, -466800,3.4951901,1.9723289,,,,,,,,,,,,,, -466900,3.3009925,1.2639437,,,,,,,,,,,,,, -467000,3.1615357,1.1028868,,,,,,,,,,,,,, -467100,3.1444829,1.1498836,,,,,,,,,,,,,, -467200,3.3683908,1.3135121,,,,,,,,,,,,,, -467300,2.9338045,1.1664125,,,,,,,,,,,,,, -467400,3.061531,2.3753133,,,,,,,,,,,,,, -467498,,,0.8881640434265137,0.4183202385902405,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,208000.78075790405,226177.8390073776,208000.78075790405,18118.31734418869,33.98908615112305,0.0 -467500,3.1125743,2.4859858,,,,,,,,,,,,,, -467600,3.2460215,1.1776164,,,,,,,,,,,,,, -467700,2.9714706,1.176748,,,,,,,,,,,,,, -467800,3.0003126,1.1387012,,,,,,,,,,,,,, -467900,3.11605,1.1602336,,,,,,,,,,,,,, -468000,3.3917656,1.1893455,,,,,,,,,,,,,, -468100,3.431954,1.3323696,,,,,,,,,,,,,, -468200,2.836771,1.8078629,,,,,,,,,,,,,, -468300,3.260566,1.1125239,,,,,,,,,,,,,, -468400,3.2419791,1.8121618,,,,,,,,,,,,,, -468444,,,0.8898242115974426,0.41509810090065,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,208420.9566919804,226639.3285050392,208420.9566919804,18159.481626987457,34.088141679763794,0.0 -468500,3.2548273,2.5609045,,,,,,,,,,,,,, -468600,3.4828393,1.1296561,,,,,,,,,,,,,, -468700,3.2227206,2.7965763,,,,,,,,,,,,,, -468800,3.481596,2.8245254,,,,,,,,,,,,,, -468900,3.331968,3.061104,,,,,,,,,,,,,, -469000,3.052286,1.0906522,,,,,,,,,,,,,, -469100,2.9330955,0.9791313,,,,,,,,,,,,,, -469200,3.6238618,3.0500884,,,,,,,,,,,,,, -469300,3.076444,1.1053773,,,,,,,,,,,,,, -469389,,,0.8898828029632568,0.4089874625205993,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,208841.1808238029,227095.1313097477,208841.1808238029,18194.91436839104,34.18496608734131,0.0 -469400,3.1383095,1.1389552,,,,,,,,,,,,,, -469500,2.888887,1.1378007,,,,,,,,,,,,,, -469600,3.4280853,2.5155892,,,,,,,,,,,,,, -469700,3.0254095,1.9540198,,,,,,,,,,,,,, -469800,3.0069032,1.7807543,,,,,,,,,,,,,, -469900,3.1415281,1.1080688,,,,,,,,,,,,,, -470000,3.9907691,2.5862758,,,,,,,,,,,,,, -470100,2.9285529,1.8704145,,,,,,,,,,,,,, -470200,3.1632802,1.17221,,,,,,,,,,,,,, -470300,3.6879787,1.6849818,,,,,,,,,,,,,, -470336,,,0.8884375095367432,0.4139959514141083,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,209261.11147785187,227552.2586643696,209261.11147785187,18231.963979959488,34.282028675079346,0.0 -470400,4.287681,2.6631405,,,,,,,,,,,,,, -470500,3.0437877,1.566124,,,,,,,,,,,,,, -470600,3.1750793,1.1015116,,,,,,,,,,,,,, -470700,3.5532818,1.1858052,,,,,,,,,,,,,, -470800,2.7727017,1.4304551,,,,,,,,,,,,,, -470900,3.634712,3.1403246,,,,,,,,,,,,,, -471000,3.0800867,1.0713544,,,,,,,,,,,,,, -471100,3.528446,1.2224952,,,,,,,,,,,,,, -471200,3.1618235,1.221723,,,,,,,,,,,,,, -471281,,,0.8876171708106995,0.4210407435894012,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,209681.18304800987,228010.3366763592,209681.18304800987,18269.8248910904,34.37890291213989,0.0 -471300,3.0576088,1.728784,,,,,,,,,,,,,, -471400,3.2769566,1.1262158,,,,,,,,,,,,,, -471500,3.066836,1.170908,,,,,,,,,,,,,, -471600,3.0361087,1.445099,,,,,,,,,,,,,, -471700,3.2373219,1.0936798,,,,,,,,,,,,,, -471800,2.9999943,1.7073196,,,,,,,,,,,,,, -471900,3.2835526,1.2877554,,,,,,,,,,,,,, -472000,3.073352,1.7985765,,,,,,,,,,,,,, -472100,3.013225,1.3344125,,,,,,,,,,,,,, -472200,3.2296436,1.1535841,,,,,,,,,,,,,, -472225,,,0.8863866925239563,0.4219321608543396,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,210101.27172207832,228468.8173315525,210101.27172207832,18308.068465471268,34.47755169868469,0.0 -472300,3.1064537,1.9298966,,,,,,,,,,,,,, -472400,3.2802699,2.103428,,,,,,,,,,,,,, -472500,3.3450398,1.180397,,,,,,,,,,,,,, -472600,3.4156964,1.1503646,,,,,,,,,,,,,, -472700,2.9483054,1.1152551,,,,,,,,,,,,,, -472800,3.0542738,2.2981086,,,,,,,,,,,,,, -472900,3.1837711,1.2230889,,,,,,,,,,,,,, -473000,4.1210704,3.215671,,,,,,,,,,,,,, -473100,2.9180727,1.0892535,,,,,,,,,,,,,, -473169,,,0.8877539038658142,0.4184952080249786,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,210521.1846988201,228927.59491848943,210521.1846988201,18346.783236026764,34.57780885696411,0.0 -473200,3.301652,1.2383435,,,,,,,,,,,,,, -473300,3.0266736,1.3596834,,,,,,,,,,,,,, -473400,4.020299,2.2470114,,,,,,,,,,,,,, -473500,3.287571,1.556633,,,,,,,,,,,,,, -473600,3.2402723,1.1213924,,,,,,,,,,,,,, -473700,3.3428109,1.2345519,,,,,,,,,,,,,, -473800,2.950674,1.4692606,,,,,,,,,,,,,, -473900,3.0771961,1.1001527,,,,,,,,,,,,,, -474000,3.2893486,1.0503863,,,,,,,,,,,,,, -474100,3.5229547,2.8135993,,,,,,,,,,,,,, -474114,,,0.8903319835662842,0.4145033657550812,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,210941.2804641724,229390.34396600723,210941.2804641724,18389.272876262665,34.690468072891235,0.0 -474200,3.4154756,1.0786585,,,,,,,,,,,,,, -474300,3.0108979,1.3016782,,,,,,,,,,,,,, -474400,3.2395694,1.3565618,,,,,,,,,,,,,, -474500,2.9393013,1.181033,,,,,,,,,,,,,, -474600,3.2269714,2.2830307,,,,,,,,,,,,,, -474700,3.2448084,1.3606821,,,,,,,,,,,,,, -474800,3.0073686,1.1729424,,,,,,,,,,,,,, -474900,3.3313184,1.0761552,,,,,,,,,,,,,, -475000,3.0705247,1.2107781,,,,,,,,,,,,,, -475059,,,0.8855859041213989,0.4250911474227905,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,211361.19549322128,229847.67683005333,211361.19549322128,18426.54616785049,34.78611469268799,0.0 -475100,3.1901872,1.1273261,,,,,,,,,,,,,, -475200,3.380735,1.6660818,,,,,,,,,,,,,, -475300,3.2836797,1.1687647,,,,,,,,,,,,,, -475400,3.483332,2.7606766,,,,,,,,,,,,,, -475500,3.115679,1.0995954,,,,,,,,,,,,,, -475600,3.2108147,1.4551849,,,,,,,,,,,,,, -475700,2.8979166,1.6002519,,,,,,,,,,,,,, -475800,3.2129815,1.1107537,,,,,,,,,,,,,, -475900,2.9026282,1.8614249,,,,,,,,,,,,,, -476000,3.1918943,1.1590496,,,,,,,,,,,,,, -476001,,,0.8862109184265137,0.4204814434051513,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,211781.18303704265,230307.3840615749,211781.18303704265,18466.10234069824,34.90059804916382,0.0 -476100,3.4793038,2.5116422,,,,,,,,,,,,,, -476200,3.150868,1.5807303,,,,,,,,,,,,,, -476300,3.1040387,1.5057588,,,,,,,,,,,,,, -476400,3.118223,1.966782,,,,,,,,,,,,,, -476500,3.281928,1.1379868,,,,,,,,,,,,,, -476600,3.0849886,1.3106391,,,,,,,,,,,,,, -476700,3.3003113,1.0867643,,,,,,,,,,,,,, -476800,2.9526317,1.9249129,,,,,,,,,,,,,, -476900,3.4248106,2.679234,,,,,,,,,,,,,, -476944,,,0.8863866925239563,0.4251146614551544,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,212201.07793593407,230765.7916574478,212201.07793593407,18504.46773743629,34.99848484992981,0.0 -477000,3.0987735,1.6260622,,,,,,,,,,,,,, -477100,3.543564,1.017941,,,,,,,,,,,,,, -477200,3.1200457,1.2796559,,,,,,,,,,,,,, -477300,3.161694,2.2568731,,,,,,,,,,,,,, -477400,3.1883302,1.1749138,,,,,,,,,,,,,, -477500,2.9324598,1.083184,,,,,,,,,,,,,, -477600,2.9326527,1.2133056,,,,,,,,,,,,,, -477700,3.5539975,2.8686266,,,,,,,,,,,,,, -477800,4.5613914,3.2413387,,,,,,,,,,,,,, -477892,,,0.88671875,0.4219008982181549,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,212621.0614106655,231225.7468166352,212621.0614106655,18544.29231786728,35.0954225063324,0.0 -477900,3.0245543,1.166045,,,,,,,,,,,,,, -478000,3.398312,1.2335757,,,,,,,,,,,,,, -478100,3.1439075,1.0896887,,,,,,,,,,,,,, -478200,3.3307116,2.7390895,,,,,,,,,,,,,, -478300,2.7673814,1.6548365,,,,,,,,,,,,,, -478400,3.7278438,3.121405,,,,,,,,,,,,,, -478500,4.048079,3.089724,,,,,,,,,,,,,, -478600,4.458643,3.241965,,,,,,,,,,,,,, -478700,3.169248,1.135592,,,,,,,,,,,,,, -478800,3.4786303,2.499978,,,,,,,,,,,,,, -478836,,,0.8879687190055847,0.4177588820457458,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,213041.1797640324,231685.1926748753,213041.1797640324,18583.47327947617,35.19288158416748,0.0 -478900,3.1718695,1.1274742,,,,,,,,,,,,,, -479000,3.07606,2.0322075,,,,,,,,,,,,,, -479100,2.9696915,1.132361,,,,,,,,,,,,,, -479200,2.9677193,2.3118386,,,,,,,,,,,,,, -479300,2.9033494,1.2514192,,,,,,,,,,,,,, -479400,3.1332386,1.1360291,,,,,,,,,,,,,, -479500,2.9838333,1.0639216,,,,,,,,,,,,,, -479600,3.038921,1.7226502,,,,,,,,,,,,,, -479700,4.2478976,3.2425194,,,,,,,,,,,,,, -479780,,,0.8868749737739563,0.4182649850845337,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,213461.09774923325,232144.8058159352,213461.09774923325,18623.00409412384,35.30354833602905,0.0 -479800,2.8376098,1.030898,,,,,,,,,,,,,, -479900,3.0283103,1.6723485,,,,,,,,,,,,,, -480000,3.3535051,2.6973808,,,,,,,,,,,,,, -480100,3.329302,1.109342,,,,,,,,,,,,,, -480200,2.9612396,1.1457175,,,,,,,,,,,,,, -480300,3.2930295,1.7942648,,,,,,,,,,,,,, -480400,2.9744823,2.3380182,,,,,,,,,,,,,, -480500,2.94586,1.0408541,,,,,,,,,,,,,, -480600,3.0436072,1.1523142,,,,,,,,,,,,,, -480700,3.6421783,1.7525735,,,,,,,,,,,,,, -480722,,,0.8855273127555847,0.4250691831111908,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,213881.37053871155,232600.1719095707,213881.37053871155,18657.94558501244,35.4053909778595,0.0 -480800,3.2911294,2.3638566,,,,,,,,,,,,,, -480900,2.8297064,1.348653,,,,,,,,,,,,,, -481000,3.1337643,1.185635,,,,,,,,,,,,,, -481100,3.1006196,1.3645396,,,,,,,,,,,,,, -481200,3.0015032,1.0881407,,,,,,,,,,,,,, -481300,3.952814,3.2763171,,,,,,,,,,,,,, -481400,3.1677935,2.021523,,,,,,,,,,,,,, -481500,3.5798142,1.1217222,,,,,,,,,,,,,, -481600,3.905931,3.3453593,,,,,,,,,,,,,, -481665,,,0.8876171708106995,0.4181188941001892,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,214301.31644773483,233058.9361166954,214301.31644773483,18696.6129488945,35.506397008895874,0.0 -481700,3.022173,1.7576665,,,,,,,,,,,,,, -481800,4.8992047,3.2387922,,,,,,,,,,,,,, -481900,3.2271805,1.0537119,,,,,,,,,,,,,, -482000,3.0682065,1.4045671,,,,,,,,,,,,,, -482100,3.051167,1.0235288,,,,,,,,,,,,,, -482200,2.9824772,1.0955784,,,,,,,,,,,,,, -482300,3.8307464,1.3882351,,,,,,,,,,,,,, -482400,3.4394891,2.883721,,,,,,,,,,,,,, -482500,3.116614,1.1364008,,,,,,,,,,,,,, -482600,3.9994361,2.7758012,,,,,,,,,,,,,, -482609,,,0.8867382407188416,0.4196585118770599,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,214721.59282803533,233519.15626096723,214721.59282803533,18736.40962123871,35.604432344436646,0.0 -482700,2.766227,0.96967965,,,,,,,,,,,,,, -482800,2.9831185,1.0031556,,,,,,,,,,,,,, -482900,2.9290285,1.1698195,,,,,,,,,,,,,, -483000,3.335248,2.110335,,,,,,,,,,,,,, -483100,3.2005649,1.7650285,,,,,,,,,,,,,, -483200,3.523933,1.8981614,,,,,,,,,,,,,, -483300,3.349883,2.6277878,,,,,,,,,,,,,, -483400,3.3427637,2.668256,,,,,,,,,,,,,, -483500,3.1412518,1.3121834,,,,,,,,,,,,,, -483552,,,0.8869335651397705,0.4238941967487335,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,215141.62749814987,233976.2240343094,215141.62749814987,18773.288115262985,35.703930616378784,0.0 -483600,7.2298627,3.250823,,,,,,,,,,,,,, -483700,3.7380974,1.8644303,,,,,,,,,,,,,, -483800,4.3505583,3.2112756,,,,,,,,,,,,,, -483900,3.071628,1.1366143,,,,,,,,,,,,,, -484000,3.0110564,1.8976645,,,,,,,,,,,,,, -484100,3.2004209,1.6330237,,,,,,,,,,,,,, -484200,3.4165425,1.2074997,,,,,,,,,,,,,, -484300,3.4070423,2.8719885,,,,,,,,,,,,,, -484400,4.174971,3.3070931,,,,,,,,,,,,,, -484496,,,0.8873242139816284,0.4227310717105865,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,215561.4826450348,234433.5977020264,215561.4826450348,18810.65714740753,35.80401062965393,0.0 -484500,2.8831863,1.1886623,,,,,,,,,,,,,, -484600,2.9477363,1.3382022,,,,,,,,,,,,,, -484700,3.3866518,1.1841606,,,,,,,,,,,,,, -484800,3.1166422,1.3068756,,,,,,,,,,,,,, -484900,3.0428402,1.5131391,,,,,,,,,,,,,, -485000,3.3929338,2.7591896,,,,,,,,,,,,,, -485100,3.1200037,1.84483,,,,,,,,,,,,,, -485200,3.2911549,1.3473339,,,,,,,,,,,,,, -485300,2.9535756,1.0461997,,,,,,,,,,,,,, -485400,3.1774762,1.0293834,,,,,,,,,,,,,, -485439,,,0.8871093392372131,0.421786367893219,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,215981.5551698208,234892.46096420288,215981.5551698208,18849.29625797272,35.90559959411621,0.0 -485500,2.922887,1.0779736,,,,,,,,,,,,,, -485600,3.3262737,1.6438922,,,,,,,,,,,,,, -485700,3.616829,3.054729,,,,,,,,,,,,,, -485800,3.3205223,1.5581326,,,,,,,,,,,,,, -485900,3.222095,1.1842843,,,,,,,,,,,,,, -486000,3.2728302,1.0579343,,,,,,,,,,,,,, -486100,3.319428,1.6973104,,,,,,,,,,,,,, -486200,4.0022616,1.1703255,,,,,,,,,,,,,, -486300,3.502825,1.1844786,,,,,,,,,,,,,, -486383,,,0.8866601586341858,0.4253378212451935,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,216401.4870171547,235356.4436588288,216401.4870171547,18893.18506455421,36.01810646057129,0.0 -486400,3.5141118,3.1329663,,,,,,,,,,,,,, -486500,3.287387,1.0538453,,,,,,,,,,,,,, -486600,3.4717696,2.8300943,,,,,,,,,,,,,, -486700,3.0011432,1.0936433,,,,,,,,,,,,,, -486800,3.362519,2.3827124,,,,,,,,,,,,,, -486900,3.1409905,1.3264514,,,,,,,,,,,,,, -487000,3.5274916,2.3199363,,,,,,,,,,,,,, -487100,3.887664,3.3137097,,,,,,,,,,,,,, -487200,3.1137059,1.7607927,,,,,,,,,,,,,, -487300,3.1140292,1.342699,,,,,,,,,,,,,, -487328,,,0.8872265219688416,0.4159683585166931,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,216821.3461008072,235815.8413271904,216821.3461008072,18932.574870586395,36.117307901382446,0.0 -487400,3.128386,1.1113977,,,,,,,,,,,,,, -487500,3.1720383,1.0989858,,,,,,,,,,,,,, -487600,3.1532125,1.1805869,,,,,,,,,,,,,, -487700,3.1707919,1.1827983,,,,,,,,,,,,,, -487800,3.0161252,2.303619,,,,,,,,,,,,,, -487900,4.053056,1.1832161,,,,,,,,,,,,,, -488000,2.9923086,2.1580653,,,,,,,,,,,,,, -488100,3.468632,1.8214047,,,,,,,,,,,,,, -488200,3.792469,3.2549617,,,,,,,,,,,,,, -488273,,,0.8877733945846558,0.4194367825984955,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,217241.2795341015,236273.9787230492,217241.2795341015,18970.62588739395,36.22060370445252,0.0 -488300,3.207199,2.1416042,,,,,,,,,,,,,, -488400,3.0038798,2.3301544,,,,,,,,,,,,,, -488500,3.0061328,2.5778978,,,,,,,,,,,,,, -488600,3.2265675,1.1302762,,,,,,,,,,,,,, -488700,3.5403702,2.6379933,,,,,,,,,,,,,, -488800,2.9250324,1.1292751,,,,,,,,,,,,,, -488900,3.3259254,2.8880017,,,,,,,,,,,,,, -489000,3.0826707,1.1365452,,,,,,,,,,,,,, -489100,3.049048,1.6429306,,,,,,,,,,,,,, -489200,3.8458664,2.99149,,,,,,,,,,,,,, -489217,,,0.8881054520606995,0.4167794585227966,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,217661.31353735924,236733.6101295948,217661.31353735924,19010.070831775665,36.32327747344971,0.0 -489300,3.325058,2.6528623,,,,,,,,,,,,,, -489400,3.5620692,2.930805,,,,,,,,,,,,,, -489500,3.2108605,2.3756073,,,,,,,,,,,,,, -489600,3.8075318,3.1739402,,,,,,,,,,,,,, -489700,3.3673856,1.2217817,,,,,,,,,,,,,, -489800,3.4086323,1.1067829,,,,,,,,,,,,,, -489900,4.1808615,3.1995077,,,,,,,,,,,,,, -490000,2.965121,1.5844015,,,,,,,,,,,,,, -490100,4.189996,1.2040809,,,,,,,,,,,,,, -490163,,,0.8856640458106995,0.4225226938724518,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,218081.38036298752,237196.2936933041,218081.38036298752,19052.525792360306,36.43513631820679,0.0 -490200,3.323005,1.0675979,,,,,,,,,,,,,, -490300,3.116733,1.8139774,,,,,,,,,,,,,, -490400,3.2567432,1.4459205,,,,,,,,,,,,,, -490500,3.2051928,1.1356262,,,,,,,,,,,,,, -490600,3.7193892,2.3013515,,,,,,,,,,,,,, -490700,3.4505644,2.6552773,,,,,,,,,,,,,, -490800,2.9870498,1.1171141,,,,,,,,,,,,,, -490900,3.2650712,1.6152756,,,,,,,,,,,,,, -491000,3.02384,1.0712434,,,,,,,,,,,,,, -491100,3.002999,1.5939707,,,,,,,,,,,,,, -491107,,,0.8886913657188416,0.413221538066864,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,218501.2313103676,237657.09906315804,218501.2313103676,19093.32017993927,36.5455060005188,0.0 -491200,3.848647,2.9502149,,,,,,,,,,,,,, -491300,3.293746,1.1316899,,,,,,,,,,,,,, -491400,3.0006945,2.1278765,,,,,,,,,,,,,, -491500,3.2131197,1.5599693,,,,,,,,,,,,,, -491600,3.0052269,1.0812275,,,,,,,,,,,,,, -491700,3.1429434,1.085988,,,,,,,,,,,,,, -491800,3.8764439,3.1491432,,,,,,,,,,,,,, -491900,3.0949285,2.385631,,,,,,,,,,,,,, -492000,2.7998517,1.0924698,,,,,,,,,,,,,, -492050,,,0.8886913657188416,0.4157118201255798,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,218921.18591451645,238113.10402703285,218921.18591451645,19129.19732618332,36.66964244842529,0.0 -492100,3.1086984,1.2436168,,,,,,,,,,,,,, -492200,3.193413,2.6818452,,,,,,,,,,,,,, -492300,2.9302015,1.1456022,,,,,,,,,,,,,, -492400,3.4146044,2.766948,,,,,,,,,,,,,, -492500,2.9674296,1.6000813,,,,,,,,,,,,,, -492600,3.2452545,1.1006262,,,,,,,,,,,,,, -492700,3.5003567,2.9952242,,,,,,,,,,,,,, -492800,3.806054,1.1233748,,,,,,,,,,,,,, -492900,3.3528526,2.5078156,,,,,,,,,,,,,, -492992,,,0.8909375071525574,0.4079668819904327,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,219341.1279244423,238570.96243953705,219341.1279244423,19166.96463418007,36.76979398727417,0.0 -493000,3.7443047,1.1279895,,,,,,,,,,,,,, -493100,3.1641433,1.219766,,,,,,,,,,,,,, -493200,3.142413,1.6868681,,,,,,,,,,,,,, -493300,3.3016887,1.0942872,,,,,,,,,,,,,, -493400,3.6575713,2.196759,,,,,,,,,,,,,, -493500,3.1351545,1.5033892,,,,,,,,,,,,,, -493600,3.0795145,1.1650705,,,,,,,,,,,,,, -493700,3.1205578,1.1630062,,,,,,,,,,,,,, -493800,3.160325,1.1475703,,,,,,,,,,,,,, -493900,3.2871354,2.4863434,,,,,,,,,,,,,, -493934,,,0.8881054520606995,0.4184964597225189,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,219761.21338629723,239031.57810759544,219761.21338629723,19207.34047794342,36.87393951416016,0.0 -494000,3.4035347,1.1427125,,,,,,,,,,,,,, -494100,3.1235602,1.5003302,,,,,,,,,,,,,, -494200,3.5865068,2.105729,,,,,,,,,,,,,, -494300,4.0156054,3.3165145,,,,,,,,,,,,,, -494400,3.2766404,1.2179182,,,,,,,,,,,,,, -494500,3.8346114,1.1294246,,,,,,,,,,,,,, -494600,3.2558944,2.7046316,,,,,,,,,,,,,, -494700,4.9551983,2.9201822,,,,,,,,,,,,,, -494800,3.7569308,3.1365914,,,,,,,,,,,,,, -494878,,,0.8865429759025574,0.4192816019058227,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,220181.35465550423,239490.0839271545,220181.35465550423,19245.554092168808,36.97444653511048,0.0 -494900,3.52633,2.9793859,,,,,,,,,,,,,, -495000,3.596425,2.869593,,,,,,,,,,,,,, -495100,6.608511,3.266092,,,,,,,,,,,,,, -495200,3.1032987,1.2233733,,,,,,,,,,,,,, -495300,2.971079,1.8987625,,,,,,,,,,,,,, -495400,3.581979,1.1066175,,,,,,,,,,,,,, -495500,3.0588233,1.1858082,,,,,,,,,,,,,, -495600,3.1580715,1.5970355,,,,,,,,,,,,,, -495700,4.337163,3.1074293,,,,,,,,,,,,,, -495800,3.6744103,2.0211077,,,,,,,,,,,,,, -495819,,,0.8858202695846558,0.4241812825202942,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,220601.31715726847,239951.438188076,220601.31715726847,19286.792535066605,37.07832598686218,0.0 -495900,3.3257656,1.350343,,,,,,,,,,,,,, -496000,3.4038584,2.8255372,,,,,,,,,,,,,, -496100,2.9912648,1.1845179,,,,,,,,,,,,,, -496200,4.030908,2.7525918,,,,,,,,,,,,,, -496300,3.497469,1.8950219,,,,,,,,,,,,,, -496400,2.9738305,1.099805,,,,,,,,,,,,,, -496500,3.1157007,1.3507826,,,,,,,,,,,,,, -496600,3.2209196,1.8754333,,,,,,,,,,,,,, -496700,3.5716286,3.0437782,,,,,,,,,,,,,, -496762,,,0.8873046636581421,0.4224570691585541,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,221021.30834054947,240409.68428444865,221021.30834054947,19324.87941670417,37.19647455215454,0.0 -496800,3.2571206,1.0386066,,,,,,,,,,,,,, -496900,2.962145,1.2748692,,,,,,,,,,,,,, -497000,6.352119,1.1394062,,,,,,,,,,,,,, -497100,2.8718202,1.1768802,,,,,,,,,,,,,, -497200,2.9588008,1.2854573,,,,,,,,,,,,,, -497300,3.0820878,1.5193802,,,,,,,,,,,,,, -497400,3.6554296,3.0552323,,,,,,,,,,,,,, -497500,3.2608337,1.183973,,,,,,,,,,,,,, -497600,3.0591083,1.049198,,,,,,,,,,,,,, -497700,3.0127454,1.1169192,,,,,,,,,,,,,, -497706,,,0.8900781273841858,0.4112922847270965,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,221441.4514014721,240865.10836744308,221441.4514014721,19360.00890374184,37.29786229133606,0.0 -497800,3.0492046,2.0839405,,,,,,,,,,,,,, -497900,3.352279,2.693238,,,,,,,,,,,,,, -498000,3.301437,2.432475,,,,,,,,,,,,,, -498100,3.0138574,1.1530082,,,,,,,,,,,,,, -498200,3.1316037,1.0766522,,,,,,,,,,,,,, -498300,2.8669937,1.0868292,,,,,,,,,,,,,, -498400,3.2466612,1.3640379,,,,,,,,,,,,,, -498500,3.663216,1.1785285,,,,,,,,,,,,,, -498600,2.9198089,1.3266914,,,,,,,,,,,,,, -498649,,,0.8861132860183716,0.424767792224884,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,221861.6649634838,241327.65272402763,221861.6649634838,19402.17110681533,37.41727375984192,0.0 -498700,4.0797386,1.9717578,,,,,,,,,,,,,, -498800,3.0329251,1.1811364,,,,,,,,,,,,,, -498900,3.5989127,1.5016968,,,,,,,,,,,,,, -499000,3.29916,1.1959971,,,,,,,,,,,,,, -499100,2.9898546,2.3105798,,,,,,,,,,,,,, -499200,3.3720953,1.4470109,,,,,,,,,,,,,, -499300,3.0437882,1.1253238,,,,,,,,,,,,,, -499400,3.3092444,1.1462684,,,,,,,,,,,,,, -499500,3.6739411,1.1382154,,,,,,,,,,,,,, -499590,,,0.888476550579071,0.4164844751358032,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,222281.58246564865,241791.0271379948,222281.58246564865,19445.47658014297,37.51878428459168,0.0 -499600,3.1961849,1.1290218,,,,,,,,,,,,,, -499700,3.3721886,1.3775228,,,,,,,,,,,,,, -499800,3.1626241,1.1779623,,,,,,,,,,,,,, -499900,3.2373805,1.202251,,,,,,,,,,,,,, -500000,3.0437822,1.1525643,,,,,,,,,,,,,, -500100,2.914378,1.3471005,,,,,,,,,,,,,, -500200,2.6757505,1.8658178,,,,,,,,,,,,,, -500300,3.109289,2.2589839,,,,,,,,,,,,,, -500400,3.9101717,1.6657404,,,,,,,,,,,,,, -500500,3.040535,1.1331525,,,,,,,,,,,,,, -500531,,,0.8866991996765137,0.4256736040115356,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,222701.7090072632,242246.6289396286,222701.7090072632,19480.782245635983,37.63924813270569,0.0 -500600,3.1140282,2.0248928,,,,,,,,,,,,,, -500700,3.113596,2.0582175,,,,,,,,,,,,,, -500800,3.789667,3.0098913,,,,,,,,,,,,,, -500900,4.7585506,1.4333897,,,,,,,,,,,,,, -501000,3.838081,3.0509865,,,,,,,,,,,,,, -501100,3.0805178,1.2304808,,,,,,,,,,,,,, -501200,3.1618028,1.2607546,,,,,,,,,,,,,, -501300,3.2047036,1.1464834,,,,,,,,,,,,,, -501400,2.9771926,1.8342707,,,,,,,,,,,,,, -501475,,,0.8859765529632568,0.4227696061134338,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,223121.9326927662,242709.27169013023,223121.9326927662,19523.047776937485,37.74262857437134,0.0 -501500,3.5590763,1.163946,,,,,,,,,,,,,, -501600,3.3504345,1.3071008,,,,,,,,,,,,,, -501700,3.177817,1.094167,,,,,,,,,,,,,, -501800,3.3488326,1.139283,,,,,,,,,,,,,, -501900,2.8321357,1.0523053,,,,,,,,,,,,,, -502000,2.9520495,2.0551877,,,,,,,,,,,,,, -502100,2.805944,1.2551521,,,,,,,,,,,,,, -502200,3.1172807,1.1263266,,,,,,,,,,,,,, -502300,3.0950615,1.5948571,,,,,,,,,,,,,, -502400,3.1650739,1.5541782,,,,,,,,,,,,,, -502421,,,0.8878905773162842,0.4137028455734253,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,223542.1022367477,243169.86743497849,223542.1022367477,19563.318788290024,37.84835863113403,0.0 -502500,3.1996748,2.664003,,,,,,,,,,,,,, -502600,3.336589,1.4310588,,,,,,,,,,,,,, -502700,3.5190063,2.800264,,,,,,,,,,,,,, -502800,3.247787,2.241977,,,,,,,,,,,,,, -502900,2.937489,1.0520643,,,,,,,,,,,,,, -503000,3.028032,2.1564476,,,,,,,,,,,,,, -503100,3.1580222,2.30902,,,,,,,,,,,,,, -503200,3.7863712,1.1892778,,,,,,,,,,,,,, -503300,3.3577282,1.1309623,,,,,,,,,,,,,, -503361,,,0.8865038752555847,0.4238264262676239,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,223962.2054350376,243628.98177289963,223962.2054350376,19602.17483854294,37.95388746261597,0.0 -503400,3.281489,2.2815006,,,,,,,,,,,,,, -503500,3.055733,1.6403702,,,,,,,,,,,,,, -503600,3.490825,1.4062632,,,,,,,,,,,,,, -503700,3.1351461,1.0810294,,,,,,,,,,,,,, -503800,3.2262676,1.6592976,,,,,,,,,,,,,, -503900,3.190556,1.2574731,,,,,,,,,,,,,, -504000,2.8598897,1.1402667,,,,,,,,,,,,,, -504100,2.9573808,1.1182266,,,,,,,,,,,,,, -504200,3.1120286,1.0268338,,,,,,,,,,,,,, -504300,3.1289966,1.8039787,,,,,,,,,,,,,, -504305,,,0.8867577910423279,0.4207553267478943,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,224382.14086842537,244092.43917560577,224382.14086842537,19645.54245853424,38.05898094177246,0.0 -504400,3.5151534,1.1403178,,,,,,,,,,,,,, -504500,3.363162,2.4086382,,,,,,,,,,,,,, -504600,3.0528762,1.9698077,,,,,,,,,,,,,, -504700,3.2811654,2.500636,,,,,,,,,,,,,, -504800,3.0399218,1.0517962,,,,,,,,,,,,,, -504900,3.120104,1.0658399,,,,,,,,,,,,,, -505000,3.8017602,2.9575405,,,,,,,,,,,,,, -505100,3.1950612,2.2608995,,,,,,,,,,,,,, -505200,4.5564885,3.2808938,,,,,,,,,,,,,, -505247,,,0.8862499594688416,0.4233685731887817,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,224802.12478852272,244549.902169466,224802.12478852272,19682.870091676712,38.161217212677,0.0 -505300,3.437044,1.100668,,,,,,,,,,,,,, -505400,3.0132372,1.1960064,,,,,,,,,,,,,, -505500,3.6808717,1.0685337,,,,,,,,,,,,,, -505600,4.2481995,1.5608265,,,,,,,,,,,,,, -505700,3.270428,1.4806978,,,,,,,,,,,,,, -505800,3.0354168,1.5005165,,,,,,,,,,,,,, -505900,5.714747,1.0580281,,,,,,,,,,,,,, -506000,3.168449,1.1404831,,,,,,,,,,,,,, -506100,3.0583668,1.171859,,,,,,,,,,,,,, -506190,,,0.8874218463897705,0.4169977903366089,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,225222.24004030228,245010.63395023343,225222.24004030228,19723.331513643265,38.26738238334656,0.0 -506200,2.8769875,1.5553311,,,,,,,,,,,,,, -506300,3.1806715,1.6237042,,,,,,,,,,,,,, -506400,3.079331,1.124044,,,,,,,,,,,,,, -506500,3.2385626,1.1028725,,,,,,,,,,,,,, -506600,3.257831,1.261121,,,,,,,,,,,,,, -506700,3.2210126,1.2518132,,,,,,,,,,,,,, -506800,3.2322865,1.403899,,,,,,,,,,,,,, -506900,3.1258004,1.4409308,,,,,,,,,,,,,, -507000,5.208733,3.1947274,,,,,,,,,,,,,, -507100,3.1023114,1.1690452,,,,,,,,,,,,,, -507136,,,0.8875390291213989,0.418641984462738,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,225642.23197722435,245468.84431529045,225642.23197722435,19761.377603769302,38.3882851600647,0.0 -507200,3.453369,2.9867713,,,,,,,,,,,,,, -507300,3.2288332,1.0458301,,,,,,,,,,,,,, -507400,3.3048496,2.7092369,,,,,,,,,,,,,, -507500,2.9860163,1.1855643,,,,,,,,,,,,,, -507600,3.5366998,1.4779264,,,,,,,,,,,,,, -507700,3.3436685,2.7059643,,,,,,,,,,,,,, -507800,3.023924,1.2082584,,,,,,,,,,,,,, -507900,2.9963987,1.1786121,,,,,,,,,,,,,, -508000,3.2086143,2.206012,,,,,,,,,,,,,, -508079,,,0.8869531154632568,0.4258208274841308,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,226062.2147741317,245928.9227323532,226062.2147741317,19801.3218562603,38.48969912528992,0.0 -508100,2.9532478,1.0324097,,,,,,,,,,,,,, -508200,3.1987078,1.4412591,,,,,,,,,,,,,, -508300,2.9880917,1.6148951,,,,,,,,,,,,,, -508400,2.9874792,1.0893564,,,,,,,,,,,,,, -508500,4.138395,2.6437862,,,,,,,,,,,,,, -508600,2.9506502,1.0572804,,,,,,,,,,,,,, -508700,2.8510053,1.2119548,,,,,,,,,,,,,, -508800,3.1338546,1.1594541,,,,,,,,,,,,,, -508900,3.532981,1.1463348,,,,,,,,,,,,,, -509000,3.441385,1.1483601,,,,,,,,,,,,,, -509024,,,0.8847460746765137,0.4278537333011627,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,226482.067596674,246393.21773839,226482.067596674,19845.604430675507,38.59902667999268,0.0 -509100,3.0674093,1.1407039,,,,,,,,,,,,,, -509200,2.9890318,2.2862961,,,,,,,,,,,,,, -509300,3.0512946,1.2523758,,,,,,,,,,,,,, -509400,3.1462502,1.4183089,,,,,,,,,,,,,, -509500,2.8486495,2.1333585,,,,,,,,,,,,,, -509600,3.1007278,2.8806167,,,,,,,,,,,,,, -509700,4.9785748,3.3016763,,,,,,,,,,,,,, -509800,3.5890021,3.090338,,,,,,,,,,,,,, -509900,3.4809625,2.899212,,,,,,,,,,,,,, -509965,,,0.8874022960662842,0.4216749966144562,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,226902.1107466221,246852.9529211521,226902.1107466221,19885.14374423027,38.70250916481018,0.0 -510000,3.2872498,2.5641103,,,,,,,,,,,,,, -510100,3.3786964,2.3413134,,,,,,,,,,,,,, -510200,3.475655,1.1363243,,,,,,,,,,,,,, -510300,4.327781,3.1569335,,,,,,,,,,,,,, -510400,3.1654394,1.1147163,,,,,,,,,,,,,, -510500,3.367341,1.2078751,,,,,,,,,,,,,, -510600,4.170389,3.236892,,,,,,,,,,,,,, -510700,3.0438228,1.1514913,,,,,,,,,,,,,, -510800,3.1007802,1.2919321,,,,,,,,,,,,,, -510900,3.6265857,2.7463784,,,,,,,,,,,,,, -510909,,,0.8890234231948853,0.4123329520225525,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,227322.3340280056,247314.1253759861,227322.3340280056,19925.925420999527,38.81987309455872,0.0 -511000,3.1910298,1.1206578,,,,,,,,,,,,,, -511100,3.012886,1.043187,,,,,,,,,,,,,, -511200,4.4238296,3.295969,,,,,,,,,,,,,, -511300,2.9776893,1.143783,,,,,,,,,,,,,, -511400,3.1422675,1.4391072,,,,,,,,,,,,,, -511500,2.9682164,1.3924797,,,,,,,,,,,,,, -511600,3.0912347,1.1163428,,,,,,,,,,,,,, -511700,3.0576737,1.5767745,,,,,,,,,,,,,, -511800,3.185507,1.2764485,,,,,,,,,,,,,, -511853,,,0.8871679306030273,0.4232731163501739,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,227742.40439462665,247777.97219228745,227742.40439462665,19969.54983282089,38.92194437980652,0.0 -511900,3.1515162,1.7279147,,,,,,,,,,,,,, -512000,3.4692097,3.0248008,,,,,,,,,,,,,, -512100,3.0955842,2.1338978,,,,,,,,,,,,,, -512200,3.7776592,3.3373866,,,,,,,,,,,,,, -512300,3.0787482,1.2067859,,,,,,,,,,,,,, -512400,3.1167662,1.2099179,,,,,,,,,,,,,, -512500,5.255036,1.16944,,,,,,,,,,,,,, -512600,3.643725,1.7317088,,,,,,,,,,,,,, -512700,4.015668,1.1744319,,,,,,,,,,,,,, -512796,,,0.8870507478713989,0.4156704545021057,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,228162.349899292,248237.13788104057,228162.349899292,20008.613343715668,39.02891778945923,0.0 -512800,3.0446782,1.1351719,,,,,,,,,,,,,, -512900,3.028266,1.1731675,,,,,,,,,,,,,, -513000,2.9691231,1.5664924,,,,,,,,,,,,,, -513100,3.389622,2.2668185,,,,,,,,,,,,,, -513200,3.1532247,1.9041247,,,,,,,,,,,,,, -513300,3.609895,1.2908442,,,,,,,,,,,,,, -513400,2.9966822,1.2487651,,,,,,,,,,,,,, -513500,3.0701792,2.7268376,,,,,,,,,,,,,, -513600,3.1446843,1.1446271,,,,,,,,,,,,,, -513700,3.243755,1.6181858,,,,,,,,,,,,,, -513739,,,0.8882030844688416,0.4158047139644623,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,228582.2359097004,248697.24965715408,228582.2359097004,20048.68375325203,39.13458871841431,0.0 -513800,3.0610874,1.1923876,,,,,,,,,,,,,, -513900,3.0293076,1.7314352,,,,,,,,,,,,,, -514000,2.8882477,1.0694137,,,,,,,,,,,,,, -514100,8.075239,1.0263168,,,,,,,,,,,,,, -514200,4.0292187,3.1959782,,,,,,,,,,,,,, -514300,3.7231493,3.1814618,,,,,,,,,,,,,, -514400,5.051924,2.8487928,,,,,,,,,,,,,, -514500,2.970162,1.4813263,,,,,,,,,,,,,, -514600,4.532558,2.9889286,,,,,,,,,,,,,, -514682,,,0.8873828053474426,0.4168609082698822,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,229002.4842107296,249158.3658988476,229002.4842107296,20089.395189762115,39.24177360534668,0.0 -514700,3.194577,1.1021916,,,,,,,,,,,,,, -514800,3.5126562,2.7829766,,,,,,,,,,,,,, -514900,3.910885,3.0841522,,,,,,,,,,,,,, -515000,3.0801566,1.0911577,,,,,,,,,,,,,, -515100,5.1286893,2.360455,,,,,,,,,,,,,, -515200,3.0377777,1.1105316,,,,,,,,,,,,,, -515300,3.2796073,1.8904922,,,,,,,,,,,,,, -515400,3.151,1.1508118,,,,,,,,,,,,,, -515500,3.1200342,1.4299927,,,,,,,,,,,,,, -515600,4.9653263,3.1111307,,,,,,,,,,,,,, -515623,,,0.8875976204872131,0.4215762913227081,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,229422.5807933808,249618.18248152733,229422.5807933808,20128.963701486588,39.34479546546936,0.0 -515700,3.0213597,1.0525336,,,,,,,,,,,,,, -515800,3.027966,1.1984731,,,,,,,,,,,,,, -515900,3.6929474,3.16036,,,,,,,,,,,,,, -516000,3.2507555,1.8615289,,,,,,,,,,,,,, -516100,3.2696066,1.2318158,,,,,,,,,,,,,, -516200,3.6792479,3.0846062,,,,,,,,,,,,,, -516300,4.4649715,3.1289427,,,,,,,,,,,,,, -516400,3.0098639,1.1541724,,,,,,,,,,,,,, -516500,2.8599284,1.0324385,,,,,,,,,,,,,, -516566,,,0.8914648294448853,0.4060322046279907,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,229842.44540429115,250075.90839600563,229842.44540429115,20166.67224431038,39.44740009307861,0.0 -516600,3.010987,1.0268652,,,,,,,,,,,,,, -516700,3.0223477,1.2028527,,,,,,,,,,,,,, -516800,3.5829682,2.9212177,,,,,,,,,,,,,, -516900,3.6270254,2.9302437,,,,,,,,,,,,,, -517000,3.3300934,1.5680299,,,,,,,,,,,,,, -517100,3.20856,1.0636009,,,,,,,,,,,,,, -517200,4.122162,3.2192512,,,,,,,,,,,,,, -517300,2.875158,2.0403695,,,,,,,,,,,,,, -517400,3.0838811,1.1782416,,,,,,,,,,,,,, -517500,3.0920725,1.2376308,,,,,,,,,,,,,, -517510,,,0.8875195384025574,0.4187245368957519,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,230262.38619685173,250532.1736881733,230262.38619685173,20202.819187641144,39.57495212554932,0.0 -517600,3.1641822,2.671298,,,,,,,,,,,,,, -517700,3.1531398,2.3421743,,,,,,,,,,,,,, -517800,3.3354478,2.6685207,,,,,,,,,,,,,, -517900,3.498947,1.0901517,,,,,,,,,,,,,, -518000,3.6331563,3.2663255,,,,,,,,,,,,,, -518100,3.500433,1.4868478,,,,,,,,,,,,,, -518200,3.440899,2.1089609,,,,,,,,,,,,,, -518300,3.3521094,2.7055547,,,,,,,,,,,,,, -518400,3.219777,1.1361586,,,,,,,,,,,,,, -518455,,,0.8878515362739563,0.417012482881546,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,230682.52629995343,250991.47455906868,230682.52629995343,20241.813415050507,39.69140291213989,0.0 -518500,2.9996471,1.0739413,,,,,,,,,,,,,, -518600,3.0960855,1.0341125,,,,,,,,,,,,,, -518700,2.9870894,1.2580565,,,,,,,,,,,,,, -518800,2.9695113,1.134727,,,,,,,,,,,,,, -518900,2.945423,1.3518159,,,,,,,,,,,,,, -519000,3.897103,2.9806144,,,,,,,,,,,,,, -519100,3.027496,1.5380529,,,,,,,,,,,,,, -519200,3.209558,1.171426,,,,,,,,,,,,,, -519300,3.417871,3.0042331,,,,,,,,,,,,,, -519400,3.0032105,2.5880413,,,,,,,,,,,,,, -519401,,,0.8869335651397705,0.4184837341308594,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,231102.6332271099,251447.6019744873,231102.6332271099,20277.67854499817,39.79632830619812,0.0 -519500,3.4111392,2.9973533,,,,,,,,,,,,,, -519600,4.0738997,1.7689173,,,,,,,,,,,,,, -519700,3.0775683,1.0924643,,,,,,,,,,,,,, -519800,2.7701893,1.0035629,,,,,,,,,,,,,, -519900,3.2841065,2.7548504,,,,,,,,,,,,,, -520000,4.5441895,3.2477903,,,,,,,,,,,,,, -520100,3.8470461,3.1854534,,,,,,,,,,,,,, -520200,3.1848247,1.471086,,,,,,,,,,,,,, -520300,3.6824102,1.1513681,,,,,,,,,,,,,, -520346,,,0.88685542345047,0.4225926995277405,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,231522.629506588,251906.8376162052,231522.629506588,20316.76153421402,39.9029598236084,0.0 -520400,3.2360566,1.1311865,,,,,,,,,,,,,, -520500,3.3530006,1.3757647,,,,,,,,,,,,,, -520600,3.2115076,1.1476922,,,,,,,,,,,,,, -520700,3.1789243,1.7807279,,,,,,,,,,,,,, -520800,3.1111336,1.3929749,,,,,,,,,,,,,, -520900,4.6275115,1.0740364,,,,,,,,,,,,,, -521000,3.2721472,2.691935,,,,,,,,,,,,,, -521100,3.149918,2.0149753,,,,,,,,,,,,,, -521200,2.9826088,1.1113204,,,,,,,,,,,,,, -521288,,,0.8891991972923279,0.4170268177986145,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,231942.6417376995,252369.12612867355,231942.6417376995,20358.879606485367,40.01094198226929,0.0 -521300,4.418088,3.1827073,,,,,,,,,,,,,, -521400,3.2952738,1.0869632,,,,,,,,,,,,,, -521500,2.9070306,1.0106514,,,,,,,,,,,,,, -521600,4.0368676,3.1629074,,,,,,,,,,,,,, -521700,3.4271438,1.1012093,,,,,,,,,,,,,, -521800,3.03289,1.102735,,,,,,,,,,,,,, -521900,3.4056525,1.1304739,,,,,,,,,,,,,, -522000,4.025221,2.8685322,,,,,,,,,,,,,, -522100,3.2511473,1.1875769,,,,,,,,,,,,,, -522200,3.2230482,1.0481077,,,,,,,,,,,,,, -522232,,,0.8890429735183716,0.4175160527229309,0.7831799983978271,0.8526236414909363,50000.0,0.6615000367164612,1.4626646041870115,10000.0,232362.6910688877,252827.0764052868,232362.6910688877,20396.62574887276,40.11666750907898,0.0 -522300,3.0484076,1.0835557,,,,,,,,,,,,,, -522400,3.2113504,1.051054,,,,,,,,,,,,,, -522500,4.3836083,2.1253676,,,,,,,,,,,,,, -522600,3.3288255,2.8051586,,,,,,,,,,,,,, -522683,,,,,,,,,,,232560.30878305435,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index fc16f6bc1..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,43 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -177.47227215766907,0.0,61.46844673156738,1,0,61.46844673156738,30.464462,2472,1.2198322263522434,238.94078707695007,31.430166,1.2202850717364595,30.370512,5348,1.215762186585825 -308.289675951004,0.0418143272399902,1501.4747319221497,1723,0,1501.4747319221497,3.4976146,2472,0.6335587918672435,1809.871680021286,3.439656,0.6469434203208997,3.7791116,5348,0.6606292902864536 -443.4888708591461,0.0941410064697265,2941.9052169322968,3481,0,2941.9052169322968,0.6647084,2472,0.217780756809457,3385.6249141693115,0.62516457,0.2151026822576778,0.961843,5348,0.2835957789856821 -579.3772525787354,0.1505255699157714,4382.4255402088165,5191,0,4382.4255402088165,0.48897383,2472,0.1633051002376455,4962.160263299942,0.41936246,0.152373873945106,0.7425709,5348,0.2224238971972542 -715.6960117816925,0.2050743103027343,5822.786217927933,6923,0,5822.786217927933,0.43173426,2472,0.147238640749091,6538.965226650238,0.37841558,0.1352564169694938,0.6750195,5348,0.205248269403439 -852.4245226383209,0.2643783092498779,7263.397251367569,8668,0,7263.397251367569,0.39091432,2472,0.130928442304958,8116.435704231262,0.33041155,0.1210724303637826,0.6191281,5348,0.1872423414464601 -987.8912315368652,0.3169138431549072,8703.343909740448,10374,0,8703.343909740448,0.37396675,2472,0.1265614526841752,9691.972328186035,0.34174573,0.121349045883991,0.5995971,5348,0.1801654807534491 -1122.5591580867767,0.3622667789459228,10143.505054235458,12071,0,10143.505054235458,0.34399128,2472,0.1172790607925578,11266.91709280014,0.2921563,0.105678216663476,0.5633191,5348,0.1713604371626905 -1259.3874979019165,0.4119458198547363,11583.59387087822,13799,0,11583.59387087822,0.32829466,2472,0.1116324416549875,12843.95545077324,0.2486947,0.0929145323591357,0.5417403,5348,0.162806414551493 -1395.4358882904053,0.4626312255859375,13023.565511226654,15494,0,13023.565511226654,0.32373837,2472,0.109540348952938,14420.097017765043,0.2355237,0.0882645542616069,0.52823937,5348,0.1592921208376377 -1529.8941369056702,0.5137264728546143,14463.611997127531,17202,0,14463.611997127531,0.30617744,2472,0.1036499908597891,15994.723123788834,0.22558315,0.0871789646425372,0.50909185,5348,0.1533834731648918 -1665.189467906952,0.5614349842071533,15903.78661584854,18927,0,15903.78661584854,0.29994157,2472,0.1016391444762659,17570.31193447113,0.24126676,0.0893836431423358,0.5059711,5348,0.1527269567568089 -1800.3386902809143,0.6096725463867188,17343.6914768219,20605,0,17343.6914768219,0.29017657,2472,0.0997501675705319,19145.483505010605,0.24273874,0.0900097607302081,0.50120276,5348,0.1509601552468212 -1935.5979552268984,0.6652815341949463,18783.908142089844,22325,0,18783.908142089844,0.28587866,2472,0.0981252412000081,20721.08587527275,0.2258042,0.0840981582467332,0.48603284,5348,0.1474651708390859 -2070.4026024341583,0.7172105312347412,20223.88219809532,24032,0,20223.88219809532,0.28025538,2472,0.094489468445961,22295.987997055054,0.20639215,0.078781512605042,0.48175213,5348,0.144308099288452 -2206.801432132721,0.7604649066925049,21664.468277931213,25720,0,21664.468277931213,0.27270424,2472,0.0918489630938598,23873.08303618431,0.19494656,0.0738412582487806,0.46175957,5348,0.1385153074524266 -2343.496416091919,0.811537504196167,23104.85160088539,27407,0,23104.85160088539,0.2689558,2472,0.0907318262141246,25450.283534526825,0.18990134,0.0708441153307217,0.45618972,5348,0.137685007289263 -2476.7793169021606,0.8699901103973389,24544.772507190704,29120,0,24544.772507190704,0.2627904,2472,0.0880710092823919,27023.616567134857,0.1999816,0.0741462991853007,0.45578963,5348,0.1337362541877057 -2610.891562461853,0.9234447479248048,25985.5729637146,30815,0,25985.5729637146,0.2543592,2472,0.0854101923506591,28598.65405917168,0.18715546,0.068063310164461,0.43771347,5348,0.1296619905963679 -2747.5177092552185,0.9884016513824464,27425.56795692444,32516,0,27425.56795692444,0.25108758,2472,0.0833993459671358,30175.411805152893,0.19618995,0.0722970786309068,0.43466973,5348,0.1297199185147281 -2883.643948793412,1.0415289402008057,28865.83015680313,34220,0,28865.83015680313,0.24840562,2472,0.0835415270245567,31751.92447423935,0.19564296,0.0697113079062629,0.4341091,5348,0.1286772159842436 -3020.135945081711,1.094953536987305,30307.033792495728,35924,0,30307.033792495728,0.23915388,2472,0.0792354721426685,33329.742975473404,0.14429401,0.0552002295952359,0.41627276,5348,0.1244581325970051 -3158.3678278923035,1.1490800380706787,31746.93484067917,37625,0,31746.93484067917,0.23604308,2472,0.077773038409197,34908.00041890144,0.16264132,0.0605425439800637,0.41757557,5348,0.1228747694951581 -3293.553910017013,1.211296558380127,33187.41482400894,39331,0,33187.41482400894,0.23083746,2472,0.0759449962423577,36483.80149149895,0.20321023,0.074391282415669,0.4084176,5348,0.1200459561485658 -3426.381377220154,1.2644286155700684,34627.85082554817,41030,0,34627.85082554817,0.2242229,2472,0.0741575772347815,38057.18810248375,0.2045451,0.0760643288164008,0.40269017,5348,0.1192639292507023 -3559.942678451538,1.317253351211548,36068.26052713394,42744,0,36068.26052713394,0.22219394,2472,0.0729998171957833,39631.28194499016,0.23569493,0.0872179747413001,0.39436883,5348,0.1162323681898491 -3693.0100190639496,1.3747758865356443,37508.558656692505,44447,0,37508.558656692505,0.2164622,2472,0.0709483476529969,41204.77492642403,0.20926268,0.0754552573167931,0.3915589,5348,0.1143304015370207 -3827.943774223328,1.4282572269439695,38948.61928009987,46143,0,38948.61928009987,0.2097056,2472,0.0708671013344707,42779.89405918121,0.18335475,0.0694053199425473,0.37230888,5348,0.1101982100273226 -3962.282675266266,1.4953930377960205,40388.79512381554,47857,0,40388.79512381554,0.20599574,2472,0.068429711778685,44354.54707813263,0.15611507,0.0589402821129104,0.37141645,5348,0.1087403574152562 -4097.3878445625305,1.5552036762237549,41829.10494494438,49556,0,41829.10494494438,0.20355535,2472,0.0662969959173725,45930.09180688858,0.1717872,0.0652484426509963,0.36111674,5348,0.105728105660523 -4231.793177843094,1.6100687980651855,43269.49426746368,51252,0,43269.49426746368,0.19794679,2472,0.0652001706172689,47505.0118470192,0.15362392,0.0589192149115272,0.35629016,5348,0.1045116193749577 -4367.57383275032,1.7452752590179443,44709.73230814934,52977,0,44709.73230814934,0.19058819,2472,0.063758048463429,49081.2375099659,0.14720261,0.0560105880949023,0.35049742,5348,0.1014704036610444 -4503.586663246155,1.8032231330871584,46150.22907114029,54683,0,46150.22907114029,0.18795525,2472,0.061869071557695,50657.874841213226,0.13729951,0.0526597653554175,0.34367317,5348,0.0999449684775577 -4637.395668745041,1.857304096221924,47590.79759192467,56392,0,47590.79759192467,0.183492,2472,0.0607316230983283,52232.37819981575,0.1411343,0.0551892677678193,0.33604154,5348,0.0983905693348909 -4771.6484117507935,1.911901473999024,49030.792356967926,58122,0,49030.792356967926,0.18152194,2472,0.0592082546259622,53806.75104427338,0.14084783,0.0525272547076313,0.33468187,5348,0.096131380518841 -4905.197862625122,1.971836805343628,50470.96611952782,59821,0,50470.96611952782,0.1730187,2472,0.0560599597830723,55380.60426187515,0.11546845,0.0449368933127087,0.32863495,5348,0.0947411104781949 -5041.062778234482,2.031010150909424,51910.99819993973,61529,0,51910.99819993973,0.17211151,2472,0.0557959092478622,56956.63300848007,0.11642376,0.0449945232987996,0.32271284,5348,0.0927232879886461 -5175.682463884354,2.0892953872680664,53351.327491045,63247,0,53351.327491045,0.17125273,2472,0.0544756565718115,58531.71190810204,0.10640869,0.0406809355889836,0.31391063,5348,0.0896820722747328 -5308.787171840668,2.1523008346557617,54791.25001358986,64925,0,54791.25001358986,0.1674624,2472,0.0541303597180752,60104.87237620354,0.10517093,0.0403502678820521,0.31392825,5348,0.0894503606012917 -5443.867123842239,2.211320400238037,56231.39729690552,66640,0,56231.39729690552,0.16359144,2472,0.0522413828123413,61680.23007559776,0.09685134,0.0373594060955703,0.30721417,5348,0.086544310030219 -5580.501723051071,2.2747132778167725,57671.69725656509,68358,0,57671.69725656509,0.1620917,2472,0.0517742164808157,63257.30109000206,0.09348633,0.0353541502938799,0.30337724,5348,0.0861291599486372 -5714.869594812393,2.3328914642333984,59111.74648475647,70048,0,59111.74648475647,0.15958287,2472,0.05079926065850141,64831.84681892395,0.07539534,0.029094321914594087,0.29719102,5348,0.08392789905094761 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index 3c36e9b37..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,745 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,55.24324,32.435127,,,,,,,,,,,,,, -1,,,31.430166,1.2202850717364595,30.370512,1.215762186585825,5348.0,30.464462,1.2198322263522434,2472.0,61.46844673156738,238.94078707695007,61.46844673156738,177.47227215766907,0.0,0.0 -100,0.5205828,5.9781127,,,,,,,,,,,,,, -200,0.29299626,5.8282275,,,,,,,,,,,,,, -300,4.533542,5.8366523,,,,,,,,,,,,,, -400,2.4960258,5.8124657,,,,,,,,,,,,,, -500,0.9890033,5.805365,,,,,,,,,,,,,, -600,0.30069774,5.771849,,,,,,,,,,,,,, -700,2.1153061,5.634913,,,,,,,,,,,,,, -800,1.1599513,5.46081,,,,,,,,,,,,,, -900,3.6537538,4.910568,,,,,,,,,,,,,, -1000,0.9079363,3.8561537,,,,,,,,,,,,,, -1100,1.2057168,3.3915038,,,,,,,,,,,,,, -1200,0.9927767,3.0827353,,,,,,,,,,,,,, -1300,0.6917195,2.91227,,,,,,,,,,,,,, -1400,0.8823533,2.7778893,,,,,,,,,,,,,, -1500,1.2413036,2.6774526,,,,,,,,,,,,,, -1600,1.0412946,2.5240717,,,,,,,,,,,,,, -1700,1.1456126,2.4612346,,,,,,,,,,,,,, -1723,,,3.439656,0.6469434203208997,3.7791116,0.6606292902864536,5348.0,3.4976146,0.6335587918672435,2472.0,1501.4747319221497,1809.871680021286,1501.4747319221497,308.289675951004,0.0418143272399902,0.0 -1800,0.73195565,2.4317563,,,,,,,,,,,,,, -1900,0.65181595,2.2931085,,,,,,,,,,,,,, -2000,1.3942864,2.2737675,,,,,,,,,,,,,, -2100,0.60553515,2.2447734,,,,,,,,,,,,,, -2200,0.5950309,2.2026424,,,,,,,,,,,,,, -2300,0.7912952,2.1156886,,,,,,,,,,,,,, -2400,0.6834817,2.0532653,,,,,,,,,,,,,, -2500,0.68360317,2.1520476,,,,,,,,,,,,,, -2600,0.6772255,2.0170133,,,,,,,,,,,,,, -2700,0.770127,1.9686544,,,,,,,,,,,,,, -2800,0.6232496,2.0095048,,,,,,,,,,,,,, -2900,0.5080708,1.9565905,,,,,,,,,,,,,, -3000,0.7316554,1.9193004,,,,,,,,,,,,,, -3100,0.7681935,1.893942,,,,,,,,,,,,,, -3200,0.5390393,1.9556593,,,,,,,,,,,,,, -3300,0.56439567,1.9229703,,,,,,,,,,,,,, -3400,0.6846572,1.8610982,,,,,,,,,,,,,, -3481,,,0.62516457,0.2151026822576778,0.961843,0.2835957789856821,5348.0,0.6647084,0.217780756809457,2472.0,2941.9052169322968,3385.6249141693115,2941.9052169322968,443.4888708591461,0.0941410064697265,0.0 -3500,0.64287096,1.8527465,,,,,,,,,,,,,, -3600,0.53825855,1.8327525,,,,,,,,,,,,,, -3700,0.5000394,1.852986,,,,,,,,,,,,,, -3800,0.67602926,1.8024915,,,,,,,,,,,,,, -3900,0.49725404,1.803203,,,,,,,,,,,,,, -4000,0.6720408,1.809422,,,,,,,,,,,,,, -4100,0.5200553,1.7647809,,,,,,,,,,,,,, -4200,0.6245259,1.7757008,,,,,,,,,,,,,, -4300,0.7015008,1.7869422,,,,,,,,,,,,,, -4400,0.572705,1.7392808,,,,,,,,,,,,,, -4500,0.64552,1.7526082,,,,,,,,,,,,,, -4600,0.47160015,1.7286979,,,,,,,,,,,,,, -4700,0.47650638,1.7037779,,,,,,,,,,,,,, -4800,0.4982121,1.749202,,,,,,,,,,,,,, -4900,0.560835,1.7959819,,,,,,,,,,,,,, -5000,0.50627935,1.7019924,,,,,,,,,,,,,, -5100,0.6533555,1.7057587,,,,,,,,,,,,,, -5191,,,0.41936246,0.152373873945106,0.7425709,0.2224238971972542,5348.0,0.48897383,0.1633051002376455,2472.0,4382.4255402088165,4962.160263299942,4382.4255402088165,579.3772525787354,0.1505255699157714,0.0 -5200,0.73298126,1.7252644,,,,,,,,,,,,,, -5300,0.575727,1.6622019,,,,,,,,,,,,,, -5400,0.6644683,1.7220964,,,,,,,,,,,,,, -5500,0.62283206,1.6931471,,,,,,,,,,,,,, -5600,0.49914494,1.6266779,,,,,,,,,,,,,, -5700,0.6054794,1.6823659,,,,,,,,,,,,,, -5800,0.5634541,1.6465752,,,,,,,,,,,,,, -5900,0.55698466,1.623068,,,,,,,,,,,,,, -6000,0.5848995,1.7002889,,,,,,,,,,,,,, -6100,0.46341705,1.627045,,,,,,,,,,,,,, -6200,0.4469633,1.6498424,,,,,,,,,,,,,, -6300,0.61016726,1.6465418,,,,,,,,,,,,,, -6400,0.48316863,1.649938,,,,,,,,,,,,,, -6500,0.68752784,1.6662093,,,,,,,,,,,,,, -6600,0.5773564,1.5770297,,,,,,,,,,,,,, -6700,0.61641645,1.5829742,,,,,,,,,,,,,, -6800,0.46545178,1.654662,,,,,,,,,,,,,, -6900,0.42521223,1.6163938,,,,,,,,,,,,,, -6923,,,0.37841558,0.1352564169694938,0.6750195,0.205248269403439,5348.0,0.43173426,0.147238640749091,2472.0,5822.786217927933,6538.965226650238,5822.786217927933,715.6960117816925,0.2050743103027343,0.0 -7000,0.6333179,1.648083,,,,,,,,,,,,,, -7100,0.45173335,1.5815196,,,,,,,,,,,,,, -7200,0.47587898,1.6167974,,,,,,,,,,,,,, -7300,0.4821544,1.6414328,,,,,,,,,,,,,, -7400,0.50026554,1.614399,,,,,,,,,,,,,, -7500,0.7212194,1.6093003,,,,,,,,,,,,,, -7600,0.585587,1.5814744,,,,,,,,,,,,,, -7700,0.6080905,1.5444463,,,,,,,,,,,,,, -7800,0.5673077,1.646847,,,,,,,,,,,,,, -7900,0.6240256,1.5969825,,,,,,,,,,,,,, -8000,0.41152346,1.5998834,,,,,,,,,,,,,, -8100,0.5187255,1.5592502,,,,,,,,,,,,,, -8200,0.43199354,1.4822937,,,,,,,,,,,,,, -8300,0.5750828,1.5389963,,,,,,,,,,,,,, -8400,0.4454088,1.5169764,,,,,,,,,,,,,, -8500,0.43535897,1.5491661,,,,,,,,,,,,,, -8600,0.51266026,1.52606,,,,,,,,,,,,,, -8668,,,0.33041155,0.1210724303637826,0.6191281,0.1872423414464601,5348.0,0.39091432,0.130928442304958,2472.0,7263.397251367569,8116.435704231262,7263.397251367569,852.4245226383209,0.2643783092498779,0.0 -8700,0.38703194,1.5182552,,,,,,,,,,,,,, -8800,0.4405096,1.532412,,,,,,,,,,,,,, -8900,0.49718127,1.533942,,,,,,,,,,,,,, -9000,0.49425137,1.5525855,,,,,,,,,,,,,, -9100,0.5628626,1.5398575,,,,,,,,,,,,,, -9200,0.62618124,1.5028665,,,,,,,,,,,,,, -9300,0.61713684,1.5211588,,,,,,,,,,,,,, -9400,0.51490957,1.5252907,,,,,,,,,,,,,, -9500,0.3683127,1.4531682,,,,,,,,,,,,,, -9600,0.55216956,1.5204194,,,,,,,,,,,,,, -9700,0.54018164,1.4650971,,,,,,,,,,,,,, -9800,0.70979655,1.5140502,,,,,,,,,,,,,, -9900,0.6823505,1.5760804,,,,,,,,,,,,,, -10000,0.5092966,1.5079799,,,,,,,,,,,,,, -10100,0.50207114,1.5392848,,,,,,,,,,,,,, -10200,0.48001397,1.4688684,,,,,,,,,,,,,, -10300,0.464262,1.4577456,,,,,,,,,,,,,, -10374,,,0.34174573,0.121349045883991,0.5995971,0.1801654807534491,5348.0,0.37396675,0.1265614526841752,2472.0,8703.343909740448,9691.972328186035,8703.343909740448,987.8912315368652,0.3169138431549072,0.0 -10400,0.48161134,1.5024222,,,,,,,,,,,,,, -10500,0.5174107,1.4890467,,,,,,,,,,,,,, -10600,0.44159552,1.455687,,,,,,,,,,,,,, -10700,0.4955456,1.4588802,,,,,,,,,,,,,, -10800,0.52801615,1.4776682,,,,,,,,,,,,,, -10900,0.49552917,1.5322471,,,,,,,,,,,,,, -11000,0.5244426,1.4409232,,,,,,,,,,,,,, -11100,0.44531867,1.5150772,,,,,,,,,,,,,, -11200,0.66866094,1.4859136,,,,,,,,,,,,,, -11300,0.49154988,1.4542072,,,,,,,,,,,,,, -11400,0.53586376,1.4880604,,,,,,,,,,,,,, -11500,0.55875474,1.4666785,,,,,,,,,,,,,, -11600,0.48405036,1.4696113,,,,,,,,,,,,,, -11700,0.3935278,1.4369192,,,,,,,,,,,,,, -11800,0.4703764,1.4391135,,,,,,,,,,,,,, -11900,0.53302187,1.4486823,,,,,,,,,,,,,, -12000,0.56265956,1.4990735,,,,,,,,,,,,,, -12071,,,0.2921563,0.105678216663476,0.5633191,0.1713604371626905,5348.0,0.34399128,0.1172790607925578,2472.0,10143.505054235458,11266.91709280014,10143.505054235458,1122.5591580867767,0.3622667789459228,0.0 -12100,0.46401417,1.452801,,,,,,,,,,,,,, -12200,0.55385995,1.443494,,,,,,,,,,,,,, -12300,0.4893869,1.4277781,,,,,,,,,,,,,, -12400,0.5119362,1.4507476,,,,,,,,,,,,,, -12500,0.47213453,1.4155736,,,,,,,,,,,,,, -12600,0.5526002,1.4598353,,,,,,,,,,,,,, -12700,0.5998077,1.4454532,,,,,,,,,,,,,, -12800,0.4547834,1.425394,,,,,,,,,,,,,, -12900,0.48087347,1.3977798,,,,,,,,,,,,,, -13000,0.45255443,1.4376115,,,,,,,,,,,,,, -13100,0.5531607,1.4131361,,,,,,,,,,,,,, -13200,0.5133132,1.4791743,,,,,,,,,,,,,, -13300,0.45480362,1.4255652,,,,,,,,,,,,,, -13400,0.54968923,1.3968214,,,,,,,,,,,,,, -13500,0.70492256,1.4521812,,,,,,,,,,,,,, -13600,0.5254945,1.4338285,,,,,,,,,,,,,, -13700,0.43680644,1.4268396,,,,,,,,,,,,,, -13799,,,0.2486947,0.0929145323591357,0.5417403,0.162806414551493,5348.0,0.32829466,0.1116324416549875,2472.0,11583.59387087822,12843.95545077324,11583.59387087822,1259.3874979019165,0.4119458198547363,0.0 -13800,0.6256083,1.4199244,,,,,,,,,,,,,, -13900,0.48208204,1.438453,,,,,,,,,,,,,, -14000,0.46578753,1.3773599,,,,,,,,,,,,,, -14100,0.53316283,1.4123863,,,,,,,,,,,,,, -14200,0.48992696,1.4607139,,,,,,,,,,,,,, -14300,0.46953213,1.4742616,,,,,,,,,,,,,, -14400,0.4912882,1.4522262,,,,,,,,,,,,,, -14500,0.4033603,1.421071,,,,,,,,,,,,,, -14600,0.44934192,1.4122084,,,,,,,,,,,,,, -14700,0.5189495,1.3855337,,,,,,,,,,,,,, -14800,0.58732426,1.4479458,,,,,,,,,,,,,, -14900,0.5017404,1.3720925,,,,,,,,,,,,,, -15000,0.6220562,1.390394,,,,,,,,,,,,,, -15100,0.5674843,1.3897176,,,,,,,,,,,,,, -15200,0.49314603,1.3959126,,,,,,,,,,,,,, -15300,0.5306485,1.3981283,,,,,,,,,,,,,, -15400,0.49613068,1.3940803,,,,,,,,,,,,,, -15494,,,0.2355237,0.0882645542616069,0.52823937,0.1592921208376377,5348.0,0.32373837,0.109540348952938,2472.0,13023.565511226654,14420.097017765043,13023.565511226654,1395.4358882904053,0.4626312255859375,0.0 -15500,0.40137097,1.397349,,,,,,,,,,,,,, -15600,0.60635805,1.3760265,,,,,,,,,,,,,, -15700,0.4916099,1.4176598,,,,,,,,,,,,,, -15800,0.54666984,1.3896081,,,,,,,,,,,,,, -15900,0.5398081,1.4188032,,,,,,,,,,,,,, -16000,0.59742284,1.357783,,,,,,,,,,,,,, -16100,0.46371433,1.3840997,,,,,,,,,,,,,, -16200,0.4763077,1.4060534,,,,,,,,,,,,,, -16300,0.6112152,1.4149531,,,,,,,,,,,,,, -16400,0.4633853,1.4053615,,,,,,,,,,,,,, -16500,0.5309142,1.3155156,,,,,,,,,,,,,, -16600,0.49798912,1.371269,,,,,,,,,,,,,, -16700,0.46792614,1.3082587,,,,,,,,,,,,,, -16800,0.59757924,1.3940665,,,,,,,,,,,,,, -16900,0.5248569,1.4537736,,,,,,,,,,,,,, -17000,0.52020764,1.3582901,,,,,,,,,,,,,, -17100,0.44064218,1.4121698,,,,,,,,,,,,,, -17200,0.47721648,1.4102254,,,,,,,,,,,,,, -17202,,,0.22558315,0.0871789646425372,0.50909185,0.1533834731648918,5348.0,0.30617744,0.1036499908597891,2472.0,14463.611997127531,15994.723123788834,14463.611997127531,1529.8941369056702,0.5137264728546143,0.0 -17300,0.5763906,1.3827018,,,,,,,,,,,,,, -17400,0.49194717,1.3586531,,,,,,,,,,,,,, -17500,0.5179711,1.4481252,,,,,,,,,,,,,, -17600,0.4784034,1.345219,,,,,,,,,,,,,, -17700,0.49476373,1.3561027,,,,,,,,,,,,,, -17800,0.49928588,1.429757,,,,,,,,,,,,,, -17900,0.53267324,1.403509,,,,,,,,,,,,,, -18000,0.631214,1.3818954,,,,,,,,,,,,,, -18100,0.51355445,1.4207126,,,,,,,,,,,,,, -18200,0.561684,1.3101354,,,,,,,,,,,,,, -18300,0.4964022,1.3838413,,,,,,,,,,,,,, -18400,0.45390823,1.3390664,,,,,,,,,,,,,, -18500,0.5414732,1.388434,,,,,,,,,,,,,, -18600,0.44265908,1.2348033,,,,,,,,,,,,,, -18700,0.6910544,1.3977938,,,,,,,,,,,,,, -18800,0.5116342,1.3363568,,,,,,,,,,,,,, -18900,0.52590775,1.3941041,,,,,,,,,,,,,, -18927,,,0.24126676,0.0893836431423358,0.5059711,0.1527269567568089,5348.0,0.29994157,0.1016391444762659,2472.0,15903.78661584854,17570.31193447113,15903.78661584854,1665.189467906952,0.5614349842071533,0.0 -19000,0.46065825,1.4004846,,,,,,,,,,,,,, -19100,0.5144198,1.3785508,,,,,,,,,,,,,, -19200,0.4957272,1.4111948,,,,,,,,,,,,,, -19300,0.6088981,1.4259683,,,,,,,,,,,,,, -19400,0.5671759,1.3204403,,,,,,,,,,,,,, -19500,0.44771037,1.3877854,,,,,,,,,,,,,, -19600,0.43395936,1.3589414,,,,,,,,,,,,,, -19700,0.47723058,1.3553607,,,,,,,,,,,,,, -19800,0.49107528,1.3719229,,,,,,,,,,,,,, -19900,0.4831145,1.3502713,,,,,,,,,,,,,, -20000,0.59779274,1.3484771,,,,,,,,,,,,,, -20100,0.44615984,1.3142571,,,,,,,,,,,,,, -20200,0.41465145,1.3092558,,,,,,,,,,,,,, -20300,0.5371235,1.314283,,,,,,,,,,,,,, -20400,0.5277575,1.3594791,,,,,,,,,,,,,, -20500,0.48991662,1.383411,,,,,,,,,,,,,, -20600,0.6092056,1.3596762,,,,,,,,,,,,,, -20605,,,0.24273874,0.0900097607302081,0.50120276,0.1509601552468212,5348.0,0.29017657,0.0997501675705319,2472.0,17343.6914768219,19145.483505010605,17343.6914768219,1800.3386902809143,0.6096725463867188,0.0 -20700,0.45102698,1.3901877,,,,,,,,,,,,,, -20800,0.46928433,1.3601661,,,,,,,,,,,,,, -20900,0.411791,1.332802,,,,,,,,,,,,,, -21000,0.5311229,1.3943598,,,,,,,,,,,,,, -21100,0.46444187,1.3653964,,,,,,,,,,,,,, -21200,0.51203173,1.3282274,,,,,,,,,,,,,, -21300,0.46213973,1.3461244,,,,,,,,,,,,,, -21400,0.5138596,1.3622295,,,,,,,,,,,,,, -21500,0.52741796,1.3629451,,,,,,,,,,,,,, -21600,0.50250447,1.4161772,,,,,,,,,,,,,, -21700,0.48111823,1.3270007,,,,,,,,,,,,,, -21800,0.48694333,1.3099462,,,,,,,,,,,,,, -21900,0.51028484,1.346506,,,,,,,,,,,,,, -22000,0.43613273,1.2999426,,,,,,,,,,,,,, -22100,0.46452755,1.3610643,,,,,,,,,,,,,, -22200,0.51367307,1.3162212,,,,,,,,,,,,,, -22300,0.5826027,1.3223734,,,,,,,,,,,,,, -22325,,,0.2258042,0.0840981582467332,0.48603284,0.1474651708390859,5348.0,0.28587866,0.0981252412000081,2472.0,18783.908142089844,20721.08587527275,18783.908142089844,1935.5979552268984,0.6652815341949463,0.0 -22400,0.6154244,1.3571624,,,,,,,,,,,,,, -22500,0.5419562,1.3457122,,,,,,,,,,,,,, -22600,0.47266972,1.3304704,,,,,,,,,,,,,, -22700,0.47445947,1.3164222,,,,,,,,,,,,,, -22800,0.4586193,1.3069608,,,,,,,,,,,,,, -22900,0.4596002,1.279214,,,,,,,,,,,,,, -23000,0.5126881,1.2897716,,,,,,,,,,,,,, -23100,0.4482884,1.2841014,,,,,,,,,,,,,, -23200,0.5812362,1.348482,,,,,,,,,,,,,, -23300,0.5021294,1.3019296,,,,,,,,,,,,,, -23400,0.56488484,1.3694371,,,,,,,,,,,,,, -23500,0.522482,1.2860907,,,,,,,,,,,,,, -23600,0.52842563,1.358485,,,,,,,,,,,,,, -23700,0.586441,1.2798408,,,,,,,,,,,,,, -23800,0.52551866,1.3088347,,,,,,,,,,,,,, -23900,0.49511948,1.2815148,,,,,,,,,,,,,, -24000,0.42605996,1.346642,,,,,,,,,,,,,, -24032,,,0.20639215,0.078781512605042,0.48175213,0.144308099288452,5348.0,0.28025538,0.094489468445961,2472.0,20223.88219809532,22295.987997055054,20223.88219809532,2070.4026024341583,0.7172105312347412,0.0 -24100,0.57339174,1.316748,,,,,,,,,,,,,, -24200,0.5443087,1.2954679,,,,,,,,,,,,,, -24300,0.5102812,1.2661258,,,,,,,,,,,,,, -24400,0.5167706,1.315532,,,,,,,,,,,,,, -24500,0.5049063,1.3142464,,,,,,,,,,,,,, -24600,0.4954869,1.3084493,,,,,,,,,,,,,, -24700,0.48147845,1.3190751,,,,,,,,,,,,,, -24800,0.6787148,1.3783742,,,,,,,,,,,,,, -24900,0.60145015,1.3308518,,,,,,,,,,,,,, -25000,0.5250175,1.3396565,,,,,,,,,,,,,, -25100,0.51404727,1.3240496,,,,,,,,,,,,,, -25200,0.72294205,1.3331795,,,,,,,,,,,,,, -25300,0.53478295,1.3226615,,,,,,,,,,,,,, -25400,0.56892896,1.2872978,,,,,,,,,,,,,, -25500,0.5363397,1.271742,,,,,,,,,,,,,, -25600,0.5078547,1.3067415,,,,,,,,,,,,,, -25700,0.4701313,1.2498289,,,,,,,,,,,,,, -25720,,,0.19494656,0.0738412582487806,0.46175957,0.1385153074524266,5348.0,0.27270424,0.0918489630938598,2472.0,21664.468277931213,23873.08303618431,21664.468277931213,2206.801432132721,0.7604649066925049,0.0 -25800,0.4782766,1.2884835,,,,,,,,,,,,,, -25900,0.64967,1.3146886,,,,,,,,,,,,,, -26000,0.496737,1.3223401,,,,,,,,,,,,,, -26100,0.5495429,1.3233815,,,,,,,,,,,,,, -26200,0.56923306,1.2830055,,,,,,,,,,,,,, -26300,0.562425,1.2621747,,,,,,,,,,,,,, -26400,0.5589081,1.377578,,,,,,,,,,,,,, -26500,0.46744743,1.2981257,,,,,,,,,,,,,, -26600,0.5363991,1.3788123,,,,,,,,,,,,,, -26700,0.5020377,1.3191351,,,,,,,,,,,,,, -26800,0.5447202,1.3490522,,,,,,,,,,,,,, -26900,0.54248327,1.3052409,,,,,,,,,,,,,, -27000,0.54364,1.2761426,,,,,,,,,,,,,, -27100,0.6032768,1.3271489,,,,,,,,,,,,,, -27200,0.4841137,1.2537593,,,,,,,,,,,,,, -27300,0.5093756,1.2685217,,,,,,,,,,,,,, -27400,0.45417553,1.2472142,,,,,,,,,,,,,, -27407,,,0.18990134,0.0708441153307217,0.45618972,0.137685007289263,5348.0,0.2689558,0.0907318262141246,2472.0,23104.85160088539,25450.283534526825,23104.85160088539,2343.496416091919,0.811537504196167,0.0 -27500,0.43667698,1.2798742,,,,,,,,,,,,,, -27600,0.52333343,1.2717265,,,,,,,,,,,,,, -27700,0.5460489,1.2328851,,,,,,,,,,,,,, -27800,0.593983,1.2509698,,,,,,,,,,,,,, -27900,0.46955913,1.2843686,,,,,,,,,,,,,, -28000,0.49415728,1.2818633,,,,,,,,,,,,,, -28100,0.46226993,1.2888365,,,,,,,,,,,,,, -28200,0.48302665,1.2073009,,,,,,,,,,,,,, -28300,0.5449948,1.3085436,,,,,,,,,,,,,, -28400,0.6189625,1.3121684,,,,,,,,,,,,,, -28500,0.63186914,1.2633206,,,,,,,,,,,,,, -28600,0.46228155,1.2371961,,,,,,,,,,,,,, -28700,0.46718374,1.2237505,,,,,,,,,,,,,, -28800,0.4978567,1.3180861,,,,,,,,,,,,,, -28900,0.5525051,1.3103644,,,,,,,,,,,,,, -29000,0.51527673,1.2944881,,,,,,,,,,,,,, -29100,0.5118164,1.2489619,,,,,,,,,,,,,, -29120,,,0.1999816,0.0741462991853007,0.45578963,0.1337362541877057,5348.0,0.2627904,0.0880710092823919,2472.0,24544.772507190704,27023.616567134857,24544.772507190704,2476.7793169021606,0.8699901103973389,0.0 -29200,0.4652416,1.3081727,,,,,,,,,,,,,, -29300,0.46887493,1.2636273,,,,,,,,,,,,,, -29400,0.4901474,1.261074,,,,,,,,,,,,,, -29500,0.50749713,1.2681557,,,,,,,,,,,,,, -29600,0.54392093,1.238251,,,,,,,,,,,,,, -29700,0.40558,1.2124836,,,,,,,,,,,,,, -29800,0.61841214,1.2366811,,,,,,,,,,,,,, -29900,0.49954808,1.2623788,,,,,,,,,,,,,, -30000,0.53263307,1.2672062,,,,,,,,,,,,,, -30100,0.46250784,1.2534267,,,,,,,,,,,,,, -30200,0.68934923,1.1967264,,,,,,,,,,,,,, -30300,0.5287578,1.2865217,,,,,,,,,,,,,, -30400,0.44191462,1.2344167,,,,,,,,,,,,,, -30500,0.5311167,1.2466341,,,,,,,,,,,,,, -30600,0.56960285,1.2554965,,,,,,,,,,,,,, -30700,0.51334083,1.2993639,,,,,,,,,,,,,, -30800,0.6038016,1.2871039,,,,,,,,,,,,,, -30815,,,0.18715546,0.068063310164461,0.43771347,0.1296619905963679,5348.0,0.2543592,0.0854101923506591,2472.0,25985.5729637146,28598.65405917168,25985.5729637146,2610.891562461853,0.9234447479248048,0.0 -30900,0.57502943,1.2424359,,,,,,,,,,,,,, -31000,0.5271104,1.2646819,,,,,,,,,,,,,, -31100,0.5075055,1.1896114,,,,,,,,,,,,,, -31200,0.5232008,1.2408665,,,,,,,,,,,,,, -31300,0.4815898,1.3145168,,,,,,,,,,,,,, -31400,0.49293935,1.2318866,,,,,,,,,,,,,, -31500,0.45884317,1.2097178,,,,,,,,,,,,,, -31600,0.5067639,1.26633,,,,,,,,,,,,,, -31700,0.50567234,1.2846389,,,,,,,,,,,,,, -31800,0.58174574,1.1919601,,,,,,,,,,,,,, -31900,0.6300122,1.2290281,,,,,,,,,,,,,, -32000,0.510256,1.2324951,,,,,,,,,,,,,, -32100,0.607268,1.2094082,,,,,,,,,,,,,, -32200,0.46644408,1.2417578,,,,,,,,,,,,,, -32300,0.6218079,1.2179499,,,,,,,,,,,,,, -32400,0.5293378,1.2808594,,,,,,,,,,,,,, -32500,0.6401974,1.2297634,,,,,,,,,,,,,, -32516,,,0.19618995,0.0722970786309068,0.43466973,0.1297199185147281,5348.0,0.25108758,0.0833993459671358,2472.0,27425.56795692444,30175.411805152893,27425.56795692444,2747.5177092552185,0.9884016513824464,0.0 -32600,0.51460826,1.2457707,,,,,,,,,,,,,, -32700,0.5364307,1.2295732,,,,,,,,,,,,,, -32800,0.49098554,1.2201279,,,,,,,,,,,,,, -32900,0.58996814,1.2011094,,,,,,,,,,,,,, -33000,0.5233165,1.2356272,,,,,,,,,,,,,, -33100,0.505827,1.1745828,,,,,,,,,,,,,, -33200,0.54519653,1.2536088,,,,,,,,,,,,,, -33300,0.5013179,1.1828432,,,,,,,,,,,,,, -33400,0.5259928,1.2375996,,,,,,,,,,,,,, -33500,0.48590165,1.2446257,,,,,,,,,,,,,, -33600,0.52161837,1.2535626,,,,,,,,,,,,,, -33700,0.559391,1.2261981,,,,,,,,,,,,,, -33800,0.5264334,1.2472161,,,,,,,,,,,,,, -33900,0.53242713,1.1994097,,,,,,,,,,,,,, -34000,0.51935345,1.206088,,,,,,,,,,,,,, -34100,0.64458936,1.2085344,,,,,,,,,,,,,, -34200,0.53801143,1.2732078,,,,,,,,,,,,,, -34220,,,0.19564296,0.0697113079062629,0.4341091,0.1286772159842436,5348.0,0.24840562,0.0835415270245567,2472.0,28865.83015680313,31751.92447423935,28865.83015680313,2883.643948793412,1.0415289402008057,0.0 -34300,0.5674896,1.2031741,,,,,,,,,,,,,, -34400,0.76914525,1.2194145,,,,,,,,,,,,,, -34500,0.518389,1.1931338,,,,,,,,,,,,,, -34600,0.4963167,1.2190715,,,,,,,,,,,,,, -34700,0.50373465,1.2276242,,,,,,,,,,,,,, -34800,0.5902496,1.2004633,,,,,,,,,,,,,, -34900,0.436058,1.2264359,,,,,,,,,,,,,, -35000,0.5345366,1.2412363,,,,,,,,,,,,,, -35100,0.55265474,1.2046047,,,,,,,,,,,,,, -35200,0.42072263,1.2123485,,,,,,,,,,,,,, -35300,0.52122223,1.2460366,,,,,,,,,,,,,, -35400,0.6172326,1.281753,,,,,,,,,,,,,, -35500,0.5591066,1.2183204,,,,,,,,,,,,,, -35600,0.5551681,1.2320733,,,,,,,,,,,,,, -35700,0.51297224,1.1946586,,,,,,,,,,,,,, -35800,0.59186715,1.2539347,,,,,,,,,,,,,, -35900,0.6179983,1.2017376,,,,,,,,,,,,,, -35924,,,0.14429401,0.0552002295952359,0.41627276,0.1244581325970051,5348.0,0.23915388,0.0792354721426685,2472.0,30307.033792495728,33329.742975473404,30307.033792495728,3020.135945081711,1.094953536987305,0.0 -36000,0.46553776,1.1895177,,,,,,,,,,,,,, -36100,0.48198968,1.2712427,,,,,,,,,,,,,, -36200,0.53133756,1.2303442,,,,,,,,,,,,,, -36300,0.47802946,1.1738467,,,,,,,,,,,,,, -36400,0.50439525,1.1824896,,,,,,,,,,,,,, -36500,0.62164414,1.2275023,,,,,,,,,,,,,, -36600,0.5750661,1.2338797,,,,,,,,,,,,,, -36700,0.4933276,1.2225987,,,,,,,,,,,,,, -36800,0.5785034,1.2792366,,,,,,,,,,,,,, -36900,0.51336455,1.2205342,,,,,,,,,,,,,, -37000,0.57223284,1.2546334,,,,,,,,,,,,,, -37100,0.55923676,1.1408414,,,,,,,,,,,,,, -37200,0.48287323,1.2520602,,,,,,,,,,,,,, -37300,0.56718594,1.1932808,,,,,,,,,,,,,, -37400,0.6154391,1.1452222,,,,,,,,,,,,,, -37500,0.6228452,1.2045269,,,,,,,,,,,,,, -37600,0.5163459,1.2029557,,,,,,,,,,,,,, -37625,,,0.16264132,0.0605425439800637,0.41757557,0.1228747694951581,5348.0,0.23604308,0.077773038409197,2472.0,31746.93484067917,34908.00041890144,31746.93484067917,3158.3678278923035,1.1490800380706787,0.0 -37700,0.7238531,1.1832095,,,,,,,,,,,,,, -37800,0.5025674,1.1877352,,,,,,,,,,,,,, -37900,0.5167019,1.1745563,,,,,,,,,,,,,, -38000,0.51100135,1.1149626,,,,,,,,,,,,,, -38100,0.6187992,1.1868975,,,,,,,,,,,,,, -38200,0.58648384,1.1971912,,,,,,,,,,,,,, -38300,0.82464814,1.1942554,,,,,,,,,,,,,, -38400,0.44917443,1.1767455,,,,,,,,,,,,,, -38500,0.67604333,1.186472,,,,,,,,,,,,,, -38600,0.50308526,1.174135,,,,,,,,,,,,,, -38700,0.51811725,1.2162638,,,,,,,,,,,,,, -38800,0.5331167,1.2033045,,,,,,,,,,,,,, -38900,0.57873094,1.1609297,,,,,,,,,,,,,, -39000,0.5687324,1.2271204,,,,,,,,,,,,,, -39100,0.59228224,1.1641694,,,,,,,,,,,,,, -39200,0.7543305,1.1934781,,,,,,,,,,,,,, -39300,0.47240776,1.2091581,,,,,,,,,,,,,, -39331,,,0.20321023,0.074391282415669,0.4084176,0.1200459561485658,5348.0,0.23083746,0.0759449962423577,2472.0,33187.41482400894,36483.80149149895,33187.41482400894,3293.553910017013,1.211296558380127,0.0 -39400,0.4829186,1.198878,,,,,,,,,,,,,, -39500,0.4962561,1.1543344,,,,,,,,,,,,,, -39600,0.6645627,1.2449253,,,,,,,,,,,,,, -39700,0.5068571,1.1912524,,,,,,,,,,,,,, -39800,0.52025217,1.1448377,,,,,,,,,,,,,, -39900,0.46993047,1.1997368,,,,,,,,,,,,,, -40000,0.63099974,1.1807983,,,,,,,,,,,,,, -40100,0.50122917,1.138653,,,,,,,,,,,,,, -40200,0.6286188,1.1366119,,,,,,,,,,,,,, -40300,0.52267903,1.1276782,,,,,,,,,,,,,, -40400,0.55227023,1.1392533,,,,,,,,,,,,,, -40500,0.4403243,1.1488905,,,,,,,,,,,,,, -40600,0.5506066,1.1869482,,,,,,,,,,,,,, -40700,0.56981623,1.1888579,,,,,,,,,,,,,, -40800,0.54528886,1.1701441,,,,,,,,,,,,,, -40900,0.48433763,1.1584896,,,,,,,,,,,,,, -41000,0.5803991,1.1762491,,,,,,,,,,,,,, -41030,,,0.2045451,0.0760643288164008,0.40269017,0.1192639292507023,5348.0,0.2242229,0.0741575772347815,2472.0,34627.85082554817,38057.18810248375,34627.85082554817,3426.381377220154,1.2644286155700684,0.0 -41100,0.5527987,1.1503174,,,,,,,,,,,,,, -41200,0.590877,1.1853215,,,,,,,,,,,,,, -41300,0.5409393,1.1289164,,,,,,,,,,,,,, -41400,0.428841,1.1259828,,,,,,,,,,,,,, -41500,0.4954423,1.0855745,,,,,,,,,,,,,, -41600,0.5501447,1.1378224,,,,,,,,,,,,,, -41700,0.50307286,1.1595232,,,,,,,,,,,,,, -41800,0.66320354,1.1801659,,,,,,,,,,,,,, -41900,0.6695818,1.1762488,,,,,,,,,,,,,, -42000,0.5613547,1.156035,,,,,,,,,,,,,, -42100,0.56734043,1.2045991,,,,,,,,,,,,,, -42200,0.5564548,1.1412145,,,,,,,,,,,,,, -42300,0.5902517,1.1543244,,,,,,,,,,,,,, -42400,0.5951358,1.1420717,,,,,,,,,,,,,, -42500,0.5208353,1.1813116,,,,,,,,,,,,,, -42600,0.54159737,1.1240848,,,,,,,,,,,,,, -42700,0.5062276,1.111936,,,,,,,,,,,,,, -42744,,,0.23569493,0.0872179747413001,0.39436883,0.1162323681898491,5348.0,0.22219394,0.0729998171957833,2472.0,36068.26052713394,39631.28194499016,36068.26052713394,3559.942678451538,1.317253351211548,0.0 -42800,0.5448907,1.1331464,,,,,,,,,,,,,, -42900,0.59461933,1.1789252,,,,,,,,,,,,,, -43000,0.5644271,1.1564441,,,,,,,,,,,,,, -43100,0.5333609,1.1462455,,,,,,,,,,,,,, -43200,0.48968968,1.1317984,,,,,,,,,,,,,, -43300,0.6042083,1.122153,,,,,,,,,,,,,, -43400,0.57470685,1.1760776,,,,,,,,,,,,,, -43500,0.49976692,1.0957264,,,,,,,,,,,,,, -43600,0.6086421,1.1401612,,,,,,,,,,,,,, -43700,0.4849125,1.1788392,,,,,,,,,,,,,, -43800,0.5442171,1.1309984,,,,,,,,,,,,,, -43900,0.5509754,1.101449,,,,,,,,,,,,,, -44000,0.56566054,1.151917,,,,,,,,,,,,,, -44100,0.5174999,1.1497067,,,,,,,,,,,,,, -44200,0.58694726,1.134881,,,,,,,,,,,,,, -44300,0.69814605,1.1651208,,,,,,,,,,,,,, -44400,0.4824122,1.1249392,,,,,,,,,,,,,, -44447,,,0.20926268,0.0754552573167931,0.3915589,0.1143304015370207,5348.0,0.2164622,0.0709483476529969,2472.0,37508.558656692505,41204.77492642403,37508.558656692505,3693.0100190639496,1.3747758865356443,0.0 -44500,0.57467544,1.116429,,,,,,,,,,,,,, -44600,0.46867412,1.1096474,,,,,,,,,,,,,, -44700,0.6238302,1.1455901,,,,,,,,,,,,,, -44800,0.5945919,1.148514,,,,,,,,,,,,,, -44900,0.8315188,1.1395079,,,,,,,,,,,,,, -45000,0.62107897,1.1614705,,,,,,,,,,,,,, -45100,0.5630704,1.1436669,,,,,,,,,,,,,, -45200,0.54794383,1.1417011,,,,,,,,,,,,,, -45300,0.56534207,1.1148067,,,,,,,,,,,,,, -45400,0.50789016,1.1623167,,,,,,,,,,,,,, -45500,0.48096466,1.0587554,,,,,,,,,,,,,, -45600,0.5325938,1.0807276,,,,,,,,,,,,,, -45700,0.59929633,1.1433058,,,,,,,,,,,,,, -45800,0.5788215,1.1169089,,,,,,,,,,,,,, -45900,0.5011375,1.1728357,,,,,,,,,,,,,, -46000,0.54117054,1.061901,,,,,,,,,,,,,, -46100,0.5866891,1.1075587,,,,,,,,,,,,,, -46143,,,0.18335475,0.0694053199425473,0.37230888,0.1101982100273226,5348.0,0.2097056,0.0708671013344707,2472.0,38948.61928009987,42779.89405918121,38948.61928009987,3827.943774223328,1.4282572269439695,0.0 -46200,0.5054727,1.1021166,,,,,,,,,,,,,, -46300,0.62686205,1.1067313,,,,,,,,,,,,,, -46400,0.55591077,1.118545,,,,,,,,,,,,,, -46500,0.5359194,1.0704361,,,,,,,,,,,,,, -46600,0.5711483,1.1268798,,,,,,,,,,,,,, -46700,0.5256616,1.1135281,,,,,,,,,,,,,, -46800,0.59579784,1.1042295,,,,,,,,,,,,,, -46900,0.54366755,1.0949842,,,,,,,,,,,,,, -47000,0.7245978,1.1571243,,,,,,,,,,,,,, -47100,0.5390581,1.079757,,,,,,,,,,,,,, -47200,0.7842848,1.0951035,,,,,,,,,,,,,, -47300,0.64281505,1.1625206,,,,,,,,,,,,,, -47400,0.57366747,1.0962268,,,,,,,,,,,,,, -47500,0.506592,1.1052461,,,,,,,,,,,,,, -47600,0.7626832,1.1511829,,,,,,,,,,,,,, -47700,0.59526306,1.1015784,,,,,,,,,,,,,, -47800,0.49222374,1.0927991,,,,,,,,,,,,,, -47857,,,0.15611507,0.0589402821129104,0.37141645,0.1087403574152562,5348.0,0.20599574,0.068429711778685,2472.0,40388.79512381554,44354.54707813263,40388.79512381554,3962.282675266266,1.4953930377960205,0.0 -47900,0.55322814,1.0961174,,,,,,,,,,,,,, -48000,0.8607046,1.0998702,,,,,,,,,,,,,, -48100,0.6215248,1.1451746,,,,,,,,,,,,,, -48200,1.1941438,1.0972748,,,,,,,,,,,,,, -48300,0.60540164,1.1367903,,,,,,,,,,,,,, -48400,0.62899417,1.0973315,,,,,,,,,,,,,, -48500,0.53965235,1.09766,,,,,,,,,,,,,, -48600,0.65611064,1.0985538,,,,,,,,,,,,,, -48700,0.8008194,1.0987122,,,,,,,,,,,,,, -48800,0.5401238,1.0818576,,,,,,,,,,,,,, -48900,0.51671284,1.0706433,,,,,,,,,,,,,, -49000,0.7195842,1.0982159,,,,,,,,,,,,,, -49100,0.6688149,1.0756428,,,,,,,,,,,,,, -49200,0.5110664,1.0466799,,,,,,,,,,,,,, -49300,0.5058543,1.0912976,,,,,,,,,,,,,, -49400,0.56902415,1.0920901,,,,,,,,,,,,,, -49500,0.6008482,1.0939395,,,,,,,,,,,,,, -49556,,,0.1717872,0.0652484426509963,0.36111674,0.105728105660523,5348.0,0.20355535,0.0662969959173725,2472.0,41829.10494494438,45930.09180688858,41829.10494494438,4097.3878445625305,1.5552036762237549,0.0 -49600,0.59745777,1.1223617,,,,,,,,,,,,,, -49700,0.6062723,1.0729201,,,,,,,,,,,,,, -49800,0.8263997,1.0352017,,,,,,,,,,,,,, -49900,0.57278633,1.1009635,,,,,,,,,,,,,, -50000,0.52887416,1.1087222,,,,,,,,,,,,,, -50100,0.53519034,1.1042714,,,,,,,,,,,,,, -50200,0.5457821,1.0695872,,,,,,,,,,,,,, -50300,0.576365,1.081145,,,,,,,,,,,,,, -50400,0.59995776,1.0850177,,,,,,,,,,,,,, -50500,0.6311331,1.049565,,,,,,,,,,,,,, -50600,0.5994973,1.083729,,,,,,,,,,,,,, -50700,0.62031734,1.0704249,,,,,,,,,,,,,, -50800,0.5373812,1.0690988,,,,,,,,,,,,,, -50900,0.73245424,1.0612155,,,,,,,,,,,,,, -51000,0.65438604,1.0338676,,,,,,,,,,,,,, -51100,0.53860277,1.0889431,,,,,,,,,,,,,, -51200,0.5720718,1.0899925,,,,,,,,,,,,,, -51252,,,0.15362392,0.0589192149115272,0.35629016,0.1045116193749577,5348.0,0.19794679,0.0652001706172689,2472.0,43269.49426746368,47505.0118470192,43269.49426746368,4231.793177843094,1.6100687980651855,0.0 -51300,0.55059206,1.0094041,,,,,,,,,,,,,, -51400,0.67572993,1.1275946,,,,,,,,,,,,,, -51500,0.5978368,1.093505,,,,,,,,,,,,,, -51600,0.6760782,1.0813732,,,,,,,,,,,,,, -51700,0.63046646,1.0782621,,,,,,,,,,,,,, -51800,0.55466,1.0683639,,,,,,,,,,,,,, -51900,0.6505552,1.1494064,,,,,,,,,,,,,, -52000,0.5747212,1.081234,,,,,,,,,,,,,, -52100,0.7235015,1.0622317,,,,,,,,,,,,,, -52200,0.56461376,1.0697395,,,,,,,,,,,,,, -52300,0.5791858,1.0518334,,,,,,,,,,,,,, -52400,0.60297614,1.0884725,,,,,,,,,,,,,, -52500,0.74863297,1.0677401,,,,,,,,,,,,,, -52600,0.55248797,1.0293304,,,,,,,,,,,,,, -52700,0.59833544,1.0991108,,,,,,,,,,,,,, -52800,0.7603177,0.9893423,,,,,,,,,,,,,, -52900,0.5563417,1.0567579,,,,,,,,,,,,,, -52977,,,0.14720261,0.0560105880949023,0.35049742,0.1014704036610444,5348.0,0.19058819,0.063758048463429,2472.0,44709.73230814934,49081.2375099659,44709.73230814934,4367.57383275032,1.7452752590179443,0.0 -53000,0.6944481,1.1233406,,,,,,,,,,,,,, -53100,0.79628366,1.0771253,,,,,,,,,,,,,, -53200,0.63908726,1.0454665,,,,,,,,,,,,,, -53300,0.72827244,1.0347335,,,,,,,,,,,,,, -53400,0.6133597,1.0958395,,,,,,,,,,,,,, -53500,0.56776154,1.1163017,,,,,,,,,,,,,, -53600,0.6169492,1.0594856,,,,,,,,,,,,,, -53700,0.57019085,0.99276656,,,,,,,,,,,,,, -53800,0.64046276,1.0646275,,,,,,,,,,,,,, -53900,0.5832153,1.0335432,,,,,,,,,,,,,, -54000,0.62189394,1.094321,,,,,,,,,,,,,, -54100,0.53565603,0.99413383,,,,,,,,,,,,,, -54200,0.68365973,1.0400732,,,,,,,,,,,,,, -54300,0.779343,1.0421959,,,,,,,,,,,,,, -54400,0.64165705,1.0619906,,,,,,,,,,,,,, -54500,0.560686,1.0422367,,,,,,,,,,,,,, -54600,0.65059304,1.0163556,,,,,,,,,,,,,, -54683,,,0.13729951,0.0526597653554175,0.34367317,0.0999449684775577,5348.0,0.18795525,0.061869071557695,2472.0,46150.22907114029,50657.874841213226,46150.22907114029,4503.586663246155,1.8032231330871584,0.0 -54700,0.6681912,1.0367413,,,,,,,,,,,,,, -54800,0.6559353,1.0386122,,,,,,,,,,,,,, -54900,0.5803303,1.0385145,,,,,,,,,,,,,, -55000,0.612561,1.0199565,,,,,,,,,,,,,, -55100,0.70010996,1.0119369,,,,,,,,,,,,,, -55200,0.54757077,1.0141035,,,,,,,,,,,,,, -55300,0.7448363,1.0759625,,,,,,,,,,,,,, -55400,0.61010355,0.9985501,,,,,,,,,,,,,, -55500,0.65282243,1.0368562,,,,,,,,,,,,,, -55600,0.66976357,1.054236,,,,,,,,,,,,,, -55700,0.7688889,1.0086778,,,,,,,,,,,,,, -55800,0.8134853,1.0142688,,,,,,,,,,,,,, -55900,0.8733585,1.0744485,,,,,,,,,,,,,, -56000,0.6196073,1.0372198,,,,,,,,,,,,,, -56100,0.90139365,1.0399172,,,,,,,,,,,,,, -56200,1.3435783,1.0215296,,,,,,,,,,,,,, -56300,0.56128323,1.0199134,,,,,,,,,,,,,, -56392,,,0.1411343,0.0551892677678193,0.33604154,0.0983905693348909,5348.0,0.183492,0.0607316230983283,2472.0,47590.79759192467,52232.37819981575,47590.79759192467,4637.395668745041,1.857304096221924,0.0 -56400,0.6314394,0.9920193,,,,,,,,,,,,,, -56500,0.63126403,1.0181003,,,,,,,,,,,,,, -56600,0.6223298,1.0376225,,,,,,,,,,,,,, -56700,0.7107539,1.008718,,,,,,,,,,,,,, -56800,0.5967325,0.97964644,,,,,,,,,,,,,, -56900,0.5831157,0.9920256,,,,,,,,,,,,,, -57000,0.75225776,1.0289686,,,,,,,,,,,,,, -57100,0.6411861,1.0232695,,,,,,,,,,,,,, -57200,0.7000084,1.01247,,,,,,,,,,,,,, -57300,0.635508,1.026757,,,,,,,,,,,,,, -57400,0.60136086,1.065973,,,,,,,,,,,,,, -57500,0.6303777,1.0022519,,,,,,,,,,,,,, -57600,0.7335915,1.053035,,,,,,,,,,,,,, -57700,0.6067291,1.0104104,,,,,,,,,,,,,, -57800,0.6602871,0.9613992,,,,,,,,,,,,,, -57900,0.6379449,1.0045319,,,,,,,,,,,,,, -58000,0.63794905,0.98163855,,,,,,,,,,,,,, -58100,0.6789861,0.9516508,,,,,,,,,,,,,, -58122,,,0.14084783,0.0525272547076313,0.33468187,0.096131380518841,5348.0,0.18152194,0.0592082546259622,2472.0,49030.792356967926,53806.75104427338,49030.792356967926,4771.6484117507935,1.911901473999024,0.0 -58200,0.7397711,0.9617887,,,,,,,,,,,,,, -58300,0.7008934,1.0443771,,,,,,,,,,,,,, -58400,0.6932297,1.0365194,,,,,,,,,,,,,, -58500,0.5526505,0.98076093,,,,,,,,,,,,,, -58600,0.7274084,0.96336174,,,,,,,,,,,,,, -58700,0.65086234,1.0623112,,,,,,,,,,,,,, -58800,0.70929956,0.98341966,,,,,,,,,,,,,, -58900,0.68114156,0.95911,,,,,,,,,,,,,, -59000,0.58765674,1.0150174,,,,,,,,,,,,,, -59100,0.7289384,1.031154,,,,,,,,,,,,,, -59200,0.5122751,0.999536,,,,,,,,,,,,,, -59300,0.58718264,0.98900086,,,,,,,,,,,,,, -59400,0.735547,0.996539,,,,,,,,,,,,,, -59500,0.67335737,1.0725584,,,,,,,,,,,,,, -59600,0.6439682,0.963225,,,,,,,,,,,,,, -59700,0.6925219,0.9821723,,,,,,,,,,,,,, -59800,0.7146865,1.0241194,,,,,,,,,,,,,, -59821,,,0.11546845,0.0449368933127087,0.32863495,0.0947411104781949,5348.0,0.1730187,0.0560599597830723,2472.0,50470.96611952782,55380.60426187515,50470.96611952782,4905.197862625122,1.971836805343628,0.0 -59900,0.6139311,1.0163689,,,,,,,,,,,,,, -60000,0.6178588,0.97576094,,,,,,,,,,,,,, -60100,0.61212313,1.0042851,,,,,,,,,,,,,, -60200,0.6373853,0.9778135,,,,,,,,,,,,,, -60300,0.68702644,1.0161594,,,,,,,,,,,,,, -60400,0.6625593,0.98734534,,,,,,,,,,,,,, -60500,0.70206714,0.9831215,,,,,,,,,,,,,, -60600,0.5864677,0.9343977,,,,,,,,,,,,,, -60700,0.8760033,0.9768546,,,,,,,,,,,,,, -60800,0.7338986,1.0426474,,,,,,,,,,,,,, -60900,0.6009136,0.9245907,,,,,,,,,,,,,, -61000,0.645703,0.98905456,,,,,,,,,,,,,, -61100,0.708048,0.9714027,,,,,,,,,,,,,, -61200,0.60669017,0.99748945,,,,,,,,,,,,,, -61300,0.80595344,0.97405654,,,,,,,,,,,,,, -61400,0.6354075,0.9779363,,,,,,,,,,,,,, -61500,0.62933433,0.94639766,,,,,,,,,,,,,, -61529,,,0.11642376,0.0449945232987996,0.32271284,0.0927232879886461,5348.0,0.17211151,0.0557959092478622,2472.0,51910.99819993973,56956.63300848007,51910.99819993973,5041.062778234482,2.031010150909424,0.0 -61600,0.5769339,0.9500325,,,,,,,,,,,,,, -61700,0.7071683,1.0072175,,,,,,,,,,,,,, -61800,0.584846,0.964104,,,,,,,,,,,,,, -61900,0.7375923,0.9596453,,,,,,,,,,,,,, -62000,0.67029977,0.91655266,,,,,,,,,,,,,, -62100,0.5864159,0.9534074,,,,,,,,,,,,,, -62200,0.5647091,0.9581043,,,,,,,,,,,,,, -62300,0.67588097,0.9550067,,,,,,,,,,,,,, -62400,0.6847522,0.9310921,,,,,,,,,,,,,, -62500,0.61473113,0.9358105,,,,,,,,,,,,,, -62600,0.72425175,0.9308104,,,,,,,,,,,,,, -62700,0.59736675,0.9772905,,,,,,,,,,,,,, -62800,0.61590266,0.94877803,,,,,,,,,,,,,, -62900,0.75094485,0.9460888,,,,,,,,,,,,,, -63000,0.91010934,0.9700502,,,,,,,,,,,,,, -63100,0.7403481,0.94956017,,,,,,,,,,,,,, -63200,0.6517305,0.9514689,,,,,,,,,,,,,, -63247,,,0.10640869,0.0406809355889836,0.31391063,0.0896820722747328,5348.0,0.17125273,0.0544756565718115,2472.0,53351.327491045,58531.71190810204,53351.327491045,5175.682463884354,2.0892953872680664,0.0 -63300,0.75534326,0.95742595,,,,,,,,,,,,,, -63400,0.6111964,0.9748041,,,,,,,,,,,,,, -63500,1.4577426,0.9808365,,,,,,,,,,,,,, -63600,1.1022938,0.9304377,,,,,,,,,,,,,, -63700,0.6374295,0.94543296,,,,,,,,,,,,,, -63800,0.76282126,0.90588623,,,,,,,,,,,,,, -63900,0.68279135,0.8843329,,,,,,,,,,,,,, -64000,0.6003123,0.94943607,,,,,,,,,,,,,, -64100,0.59916645,0.947142,,,,,,,,,,,,,, -64200,0.58347344,0.91920495,,,,,,,,,,,,,, -64300,0.73040134,0.9443205,,,,,,,,,,,,,, -64400,0.67881,0.984356,,,,,,,,,,,,,, -64500,0.6068892,0.9497826,,,,,,,,,,,,,, -64600,0.6284969,0.974481,,,,,,,,,,,,,, -64700,1.106601,0.9435974,,,,,,,,,,,,,, -64800,0.7092499,0.95817715,,,,,,,,,,,,,, -64900,0.8988729,0.9160624,,,,,,,,,,,,,, -64925,,,0.10517093,0.0403502678820521,0.31392825,0.0894503606012917,5348.0,0.1674624,0.0541303597180752,2472.0,54791.25001358986,60104.87237620354,54791.25001358986,5308.787171840668,2.1523008346557617,0.0 -65000,0.593349,0.91138107,,,,,,,,,,,,,, -65100,0.7637152,0.9388177,,,,,,,,,,,,,, -65200,0.88080466,0.9666469,,,,,,,,,,,,,, -65300,0.837578,0.91178924,,,,,,,,,,,,,, -65400,0.6486508,1.0109334,,,,,,,,,,,,,, -65500,0.69848466,0.9558227,,,,,,,,,,,,,, -65600,0.8785984,0.979307,,,,,,,,,,,,,, -65700,0.7313632,1.0057601,,,,,,,,,,,,,, -65800,0.7499255,0.95168954,,,,,,,,,,,,,, -65900,0.75566006,0.96413845,,,,,,,,,,,,,, -66000,0.82424617,0.9377213,,,,,,,,,,,,,, -66100,0.7242827,0.93509895,,,,,,,,,,,,,, -66200,0.95227647,0.94529074,,,,,,,,,,,,,, -66300,0.79753906,0.94008875,,,,,,,,,,,,,, -66400,0.7374776,0.8933683,,,,,,,,,,,,,, -66500,0.6576136,0.96200526,,,,,,,,,,,,,, -66600,0.6549731,0.9091043,,,,,,,,,,,,,, -66640,,,0.09685134,0.0373594060955703,0.30721417,0.086544310030219,5348.0,0.16359144,0.0522413828123413,2472.0,56231.39729690552,61680.23007559776,56231.39729690552,5443.867123842239,2.211320400238037,0.0 -66700,0.78389066,0.94371104,,,,,,,,,,,,,, -66800,0.7010138,0.96330816,,,,,,,,,,,,,, -66900,0.66956997,0.949526,,,,,,,,,,,,,, -67000,0.6915197,0.9612502,,,,,,,,,,,,,, -67100,0.6973814,0.9167675,,,,,,,,,,,,,, -67200,0.60402906,0.9174715,,,,,,,,,,,,,, -67300,0.6298336,0.9513805,,,,,,,,,,,,,, -67400,0.70940125,0.911856,,,,,,,,,,,,,, -67500,0.9986484,0.9397338,,,,,,,,,,,,,, -67600,0.66711384,0.90326095,,,,,,,,,,,,,, -67700,0.6418539,0.9528086,,,,,,,,,,,,,, -67800,0.6892298,0.92587984,,,,,,,,,,,,,, -67900,0.6285238,0.9158673,,,,,,,,,,,,,, -68000,0.8062157,0.9325538,,,,,,,,,,,,,, -68100,0.68752176,0.89254546,,,,,,,,,,,,,, -68200,0.7083718,0.9227814,,,,,,,,,,,,,, -68300,0.789455,0.9137091,,,,,,,,,,,,,, -68358,,,0.09348633,0.0353541502938799,0.30337724,0.0861291599486372,5348.0,0.1620917,0.0517742164808157,2472.0,57671.69725656509,63257.30109000206,57671.69725656509,5580.501723051071,2.2747132778167725,0.0 -68400,0.7718257,0.9197634,,,,,,,,,,,,,, -68500,0.60012394,0.8828389,,,,,,,,,,,,,, -68600,0.67273784,0.9088723,,,,,,,,,,,,,, -68700,0.71656096,0.94151634,,,,,,,,,,,,,, -68800,0.60232973,0.8912815,,,,,,,,,,,,,, -68900,0.630893,0.8821797,,,,,,,,,,,,,, -69000,0.8816849,0.9178651,,,,,,,,,,,,,, -69100,1.1049896,0.9394253,,,,,,,,,,,,,, -69200,0.63981533,0.94651675,,,,,,,,,,,,,, -69300,0.68659,0.93856144,,,,,,,,,,,,,, -69400,0.8331527,0.9602542,,,,,,,,,,,,,, -69500,0.75990987,0.88114226,,,,,,,,,,,,,, -69600,0.6453503,0.9285572,,,,,,,,,,,,,, -69700,0.8577634,0.9745834,,,,,,,,,,,,,, -69800,0.6994017,0.88534236,,,,,,,,,,,,,, -69900,0.8677264,0.9037722,,,,,,,,,,,,,, -70000,0.68896496,0.86849827,,,,,,,,,,,,,, -70048,,,0.07539534,0.029094321914594,0.29719102,0.0839278990509476,5348.0,0.15958287,0.0507992606585014,2472.0,59111.74648475647,64831.84681892395,59111.74648475647,5714.869594812393,2.3328914642333984,0.0 -70048,,,,,,,,,,,59111.74648475647,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 4f7357a45..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,27 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -181.3702912330628,0.0,41.58487391471863,1,0,41.58487391471863,30.129211,2472,1.6001665549529789,222.955227136612,30.794819,1.6620837931824337,29.931355,5348,1.4281742085598157 -312.3856499195099,0.0425958633422851,1481.7585852146149,1753,0,1481.7585852146149,1.6972979,2472,0.418763837263624,1794.2577757835388,1.7393318,0.4204444810611522,2.1726947,5348,0.4895681473686243 -448.3193964958191,0.0848848819732666,2921.7802641391754,3512,0,2921.7802641391754,0.60292125,2472,0.1893039221660268,3370.328623056412,0.57093966,0.1871507338252137,0.9262279,5348,0.2626355271923303 -584.1421117782593,0.1240792274475097,4362.251379489899,5222,0,4362.251379489899,0.49402905,2472,0.1591209148335466,4946.733110427856,0.42450523,0.1461907849307631,0.7988043,5348,0.2279270494414783 -721.6570925712585,0.1635785102844238,5802.464230775833,6949,0,5802.464230775833,0.46259597,2472,0.1485995165844047,6524.573171615601,0.4194605,0.1420006492894469,0.76416,5348,0.2194599187078212 -855.529622554779,0.2090013027191162,7242.510432720184,8699,0,7242.510432720184,0.41620168,2472,0.1332439623829545,8098.611032485962,0.36787793,0.1260410098066929,0.6959745,5348,0.1992237658939726 -988.4021236896516,0.2537970542907715,8682.545643806458,10428,0,8682.545643806458,0.39532664,2472,0.1270286190157008,9671.63682961464,0.37247258,0.1249606563848685,0.6662017,5348,0.1914131515683984 -1124.0344214439392,0.2944934368133545,10123.035349369047,12141,0,10123.035349369047,0.38549662,2472,0.1246521641988097,11247.871819019318,0.33491787,0.1148139876204454,0.6525444,5348,0.188130569527984 -1259.844975233078,0.3318409919738769,11563.117947101591,13894,0,11563.117947101591,0.37729406,2472,0.1218288546300245,12823.876756429672,0.2947174,0.1042824826438545,0.65284616,5348,0.1880436776504436 -1400.8908877372742,0.3732874393463135,13003.560472249985,15623,0,13003.560472249985,0.36129552,2472,0.1158166270590863,14405.480276107788,0.27732593,0.0969840771616596,0.61508054,5348,0.176709114957954 -1535.9185712337494,0.4116241931915283,14444.06426525116,17340,0,14444.06426525116,0.34505987,2472,0.1124042816809863,15981.121856212616,0.28093824,0.0972719033394649,0.61341965,5348,0.1754636647132085 -1672.7844877243042,0.4560194015502929,15884.91688466072,19066,0,15884.91688466072,0.34121943,2472,0.1097434647492535,17558.958316087723,0.29000315,0.0995108278907427,0.59839016,5348,0.1725286501829556 -1817.4322745800016,0.4951100349426269,17325.00393819809,20790,0,17325.00393819809,0.32885116,2472,0.1038734182357361,19143.80579972267,0.28017858,0.09633576911916,0.5747481,5348,0.1661276152041476 -1953.54097032547,0.5389456748962402,18765.608085393906,22510,0,18765.608085393906,0.3162405,2472,0.1023703613430016,20720.63533425331,0.26128995,0.0898394896594299,0.5584282,5348,0.1600934570416212 -2088.428907632828,0.5796234607696533,20206.615512132645,24240,0,20206.615512132645,0.30818713,2472,0.0985924075315337,22296.645008563995,0.2412733,0.084118069338288,0.5459013,5348,0.1572067157766685 -2224.599086999893,0.6414904594421387,21647.060393571854,25970,0,21647.060393571854,0.29683977,2472,0.095748786383117,23873.397852659225,0.22022744,0.0780615039515897,0.52577573,5348,0.1526979927976288 -2361.101853847504,0.6938507556915283,23087.36500597,27677,0,23087.36500597,0.284106,2472,0.0905083988381776,25450.333413124084,0.20966426,0.0734011107588139,0.510851,5348,0.146586597410622 -2495.944045782089,0.7462542057037354,24527.617062807083,29394,0,24527.617062807083,0.27293038,2472,0.0875632197916032,27025.553820371628,0.21916138,0.0770064135898769,0.49210644,5348,0.1443563725537522 -2631.939340829849,0.8093528747558594,25968.251851797104,31101,0,25968.251851797104,0.27050504,2472,0.0872585460971299,28602.32335114479,0.20257692,0.0690345579737274,0.48863128,5348,0.1420971837377024 -2768.881613969803,0.8643975257873535,27408.37361598015,32790,0,27408.37361598015,0.25773567,2472,0.0819369122336644,30179.51735305786,0.20776458,0.0702371718147214,0.47163725,5348,0.1355223650038136 -2903.467637300492,0.9291818141937256,28848.777802228928,34502,0,28848.777802228928,0.25175428,2472,0.0811650722076655,31754.64643263817,0.1990523,0.0663810650383111,0.46344855,5348,0.1336010890448651 -3042.599319934845,0.9865810871124268,30289.52933859825,36191,0,30289.52933859825,0.24260797,2472,0.0783214510592488,33334.662573099136,0.14591317,0.0517768006313869,0.44477853,5348,0.1280689728414609 -3178.628196001053,1.0424203872680664,31730.143264770508,37893,0,31730.143264770508,0.23578015,2472,0.0751934677959905,34911.43565702438,0.16582121,0.0575542227196891,0.4391291,5348,0.1269779970456761 -3311.199209690094,1.094820261001587,33170.094004154205,39608,0,33170.094004154205,0.23024255,2472,0.0734669835273089,36484.08493804932,0.20992558,0.0719612937520001,0.42906445,5348,0.1237726522297421 -3446.5949206352234,1.1482131481170654,34610.70112776756,41293,0,34610.70112776756,0.22649801,2472,0.0714358255641541,38060.21517109871,0.21701375,0.0753537585282312,0.42279768,5348,0.1214169168830918 -3582.435993909836,1.1999573707580566,36051.13100504875,42989,0,36051.13100504875,0.22220674,2472,0.07056242763999757,39636.61383128166,0.24711965,0.08698398956426111,0.41582108,5348,0.11951495023026347 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index 405adab2a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,458 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.763968,33.378967,,,,,,,,,,,,,, -1,,,30.794819,1.6620837931824337,29.931355,1.4281742085598157,5348.0,30.129211,1.6001665549529789,2472.0,41.58487391471863,222.955227136612,41.58487391471863,181.3702912330628,0.0,0.0 -100,0.67033994,6.121869,,,,,,,,,,,,,, -200,0.3740788,5.839669,,,,,,,,,,,,,, -300,0.8324561,5.6766768,,,,,,,,,,,,,, -400,0.58280444,5.162966,,,,,,,,,,,,,, -500,1.5972025,4.3179746,,,,,,,,,,,,,, -600,1.7649777,3.6992612,,,,,,,,,,,,,, -700,2.2306197,3.421361,,,,,,,,,,,,,, -800,2.7735322,3.131822,,,,,,,,,,,,,, -900,2.0546234,2.9361598,,,,,,,,,,,,,, -1000,2.2665453,2.8014238,,,,,,,,,,,,,, -1100,1.9623945,2.6943603,,,,,,,,,,,,,, -1200,1.8854458,2.5311759,,,,,,,,,,,,,, -1300,2.0509534,2.5597763,,,,,,,,,,,,,, -1400,2.823341,2.5098798,,,,,,,,,,,,,, -1500,1.649399,2.3210092,,,,,,,,,,,,,, -1600,2.170584,2.3141942,,,,,,,,,,,,,, -1700,2.226357,2.2729833,,,,,,,,,,,,,, -1753,,,1.7393318,0.4204444810611522,2.1726947,0.4895681473686243,5348.0,1.6972979,0.418763837263624,2472.0,1481.7585852146149,1794.2577757835388,1481.7585852146149,312.3856499195099,0.0425958633422851,0.0 -1800,2.161251,2.1271567,,,,,,,,,,,,,, -1900,2.352257,2.1761355,,,,,,,,,,,,,, -2000,2.028306,2.1935034,,,,,,,,,,,,,, -2100,3.3576756,2.0766423,,,,,,,,,,,,,, -2200,3.9327888,2.081422,,,,,,,,,,,,,, -2300,2.402216,1.9836248,,,,,,,,,,,,,, -2400,2.2566812,1.9992378,,,,,,,,,,,,,, -2500,2.212346,2.0314343,,,,,,,,,,,,,, -2600,2.7127788,2.048924,,,,,,,,,,,,,, -2700,2.6491237,2.0563967,,,,,,,,,,,,,, -2800,2.4701977,1.9745324,,,,,,,,,,,,,, -2900,2.735281,1.9543754,,,,,,,,,,,,,, -3000,2.69668,1.9302442,,,,,,,,,,,,,, -3100,2.3972354,1.9228039,,,,,,,,,,,,,, -3200,3.4134936,1.8641174,,,,,,,,,,,,,, -3300,2.5135942,1.9062359,,,,,,,,,,,,,, -3400,2.5291266,1.8423,,,,,,,,,,,,,, -3500,4.3212,1.8767371,,,,,,,,,,,,,, -3512,,,0.57093966,0.1871507338252137,0.9262279,0.2626355271923303,5348.0,0.60292125,0.1893039221660268,2472.0,2921.7802641391754,3370.328623056412,2921.7802641391754,448.3193964958191,0.0848848819732666,0.0 -3600,2.9836526,1.7981024,,,,,,,,,,,,,, -3700,3.0627468,1.893896,,,,,,,,,,,,,, -3800,3.1085596,1.8669096,,,,,,,,,,,,,, -3900,2.8641407,1.824852,,,,,,,,,,,,,, -4000,2.709447,1.8856064,,,,,,,,,,,,,, -4100,2.4768183,1.8240156,,,,,,,,,,,,,, -4200,2.1714547,1.7664051,,,,,,,,,,,,,, -4300,2.519824,1.8771327,,,,,,,,,,,,,, -4400,1.9360901,1.8322237,,,,,,,,,,,,,, -4500,2.4882934,1.7502164,,,,,,,,,,,,,, -4600,2.6977851,1.7606276,,,,,,,,,,,,,, -4700,2.7610445,1.7837493,,,,,,,,,,,,,, -4800,1.8954933,1.8103749,,,,,,,,,,,,,, -4900,3.5412457,1.7481048,,,,,,,,,,,,,, -5000,3.3210137,1.7797782,,,,,,,,,,,,,, -5100,2.4595337,1.7570484,,,,,,,,,,,,,, -5200,2.2442427,1.7159246,,,,,,,,,,,,,, -5222,,,0.42450523,0.1461907849307631,0.7988043,0.2279270494414783,5348.0,0.49402905,0.1591209148335466,2472.0,4362.251379489899,4946.733110427856,4362.251379489899,584.1421117782593,0.1240792274475097,0.0 -5300,1.7397594,1.7128743,,,,,,,,,,,,,, -5400,4.2929044,1.6788131,,,,,,,,,,,,,, -5500,3.7701373,1.7536496,,,,,,,,,,,,,, -5600,2.043203,1.7519673,,,,,,,,,,,,,, -5700,2.3378246,1.6454443,,,,,,,,,,,,,, -5800,3.5443285,1.690722,,,,,,,,,,,,,, -5900,2.8005466,1.7580483,,,,,,,,,,,,,, -6000,2.1133249,1.6699637,,,,,,,,,,,,,, -6100,2.633987,1.6873329,,,,,,,,,,,,,, -6200,2.1281655,1.6776954,,,,,,,,,,,,,, -6300,2.2804363,1.6962079,,,,,,,,,,,,,, -6400,2.8454633,1.6550497,,,,,,,,,,,,,, -6500,3.0664284,1.6727202,,,,,,,,,,,,,, -6600,2.4478083,1.7200928,,,,,,,,,,,,,, -6700,2.2409654,1.6759835,,,,,,,,,,,,,, -6800,2.1300135,1.6025853,,,,,,,,,,,,,, -6900,3.5161424,1.6654456,,,,,,,,,,,,,, -6949,,,0.4194605,0.1420006492894469,0.76416,0.2194599187078212,5348.0,0.46259597,0.1485995165844047,2472.0,5802.464230775833,6524.573171615601,5802.464230775833,721.6570925712585,0.1635785102844238,0.0 -7000,2.9302738,1.7244482,,,,,,,,,,,,,, -7100,4.5024304,1.7049217,,,,,,,,,,,,,, -7200,3.1775727,1.6760471,,,,,,,,,,,,,, -7300,2.62326,1.6381075,,,,,,,,,,,,,, -7400,2.8911684,1.6720837,,,,,,,,,,,,,, -7500,3.1219583,1.6966406,,,,,,,,,,,,,, -7600,1.8845689,1.623237,,,,,,,,,,,,,, -7700,2.431739,1.6093949,,,,,,,,,,,,,, -7800,2.4741585,1.6549925,,,,,,,,,,,,,, -7900,3.7116613,1.6247113,,,,,,,,,,,,,, -8000,2.8342998,1.6497244,,,,,,,,,,,,,, -8100,3.4157639,1.6778455,,,,,,,,,,,,,, -8200,1.9290459,1.6708783,,,,,,,,,,,,,, -8300,2.1571758,1.6038014,,,,,,,,,,,,,, -8400,3.035846,1.6249644,,,,,,,,,,,,,, -8500,3.2032506,1.6502712,,,,,,,,,,,,,, -8600,3.6481936,1.6422994,,,,,,,,,,,,,, -8699,,,0.36787793,0.1260410098066929,0.6959745,0.1992237658939726,5348.0,0.41620168,0.1332439623829545,2472.0,7242.510432720184,8098.611032485962,7242.510432720184,855.529622554779,0.2090013027191162,0.0 -8700,2.4993742,1.591404,,,,,,,,,,,,,, -8800,2.71596,1.6154804,,,,,,,,,,,,,, -8900,2.8591485,1.590311,,,,,,,,,,,,,, -9000,2.634088,1.5728562,,,,,,,,,,,,,, -9100,6.234758,1.6599302,,,,,,,,,,,,,, -9200,1.6973343,1.6971631,,,,,,,,,,,,,, -9300,2.5407417,1.5887297,,,,,,,,,,,,,, -9400,2.8837578,1.6707525,,,,,,,,,,,,,, -9500,4.471839,1.6176486,,,,,,,,,,,,,, -9600,2.0276291,1.6078192,,,,,,,,,,,,,, -9700,2.0790951,1.6167896,,,,,,,,,,,,,, -9800,3.30242,1.6196376,,,,,,,,,,,,,, -9900,2.7081106,1.6418127,,,,,,,,,,,,,, -10000,2.2973702,1.5559278,,,,,,,,,,,,,, -10100,2.1046846,1.6325527,,,,,,,,,,,,,, -10200,1.9933059,1.6320188,,,,,,,,,,,,,, -10300,3.7776682,1.5970148,,,,,,,,,,,,,, -10400,2.2410152,1.5724539,,,,,,,,,,,,,, -10428,,,0.37247258,0.1249606563848685,0.6662017,0.1914131515683984,5348.0,0.39532664,0.1270286190157008,2472.0,8682.545643806458,9671.63682961464,8682.545643806458,988.4021236896516,0.2537970542907715,0.0 -10500,5.215192,1.6912819,,,,,,,,,,,,,, -10600,3.2839992,1.568965,,,,,,,,,,,,,, -10700,2.6225755,1.6617241,,,,,,,,,,,,,, -10800,3.0325217,1.5727069,,,,,,,,,,,,,, -10900,3.795647,1.6221522,,,,,,,,,,,,,, -11000,2.2539318,1.6141479,,,,,,,,,,,,,, -11100,4.3254185,1.5597512,,,,,,,,,,,,,, -11200,2.0167584,1.5404514,,,,,,,,,,,,,, -11300,2.6233962,1.5540416,,,,,,,,,,,,,, -11400,6.0213866,1.5867933,,,,,,,,,,,,,, -11500,3.321226,1.5943056,,,,,,,,,,,,,, -11600,2.9636197,1.583612,,,,,,,,,,,,,, -11700,3.487636,1.5801148,,,,,,,,,,,,,, -11800,2.3553178,1.4971383,,,,,,,,,,,,,, -11900,4.401296,1.6185564,,,,,,,,,,,,,, -12000,3.3224738,1.5784546,,,,,,,,,,,,,, -12100,2.0955434,1.5523964,,,,,,,,,,,,,, -12141,,,0.33491787,0.1148139876204454,0.6525444,0.188130569527984,5348.0,0.38549662,0.1246521641988097,2472.0,10123.035349369047,11247.871819019318,10123.035349369047,1124.0344214439392,0.2944934368133545,0.0 -12200,2.6500027,1.5731331,,,,,,,,,,,,,, -12300,2.5298991,1.6117122,,,,,,,,,,,,,, -12400,2.8526254,1.5465125,,,,,,,,,,,,,, -12500,2.8112955,1.556061,,,,,,,,,,,,,, -12600,2.5077813,1.5301216,,,,,,,,,,,,,, -12700,4.19101,1.5428274,,,,,,,,,,,,,, -12800,1.7928507,1.5496843,,,,,,,,,,,,,, -12900,3.3897452,1.5232757,,,,,,,,,,,,,, -13000,3.0975056,1.5452901,,,,,,,,,,,,,, -13100,2.308428,1.5666189,,,,,,,,,,,,,, -13200,2.1545565,1.5886122,,,,,,,,,,,,,, -13300,2.5268435,1.592216,,,,,,,,,,,,,, -13400,2.9743955,1.5536233,,,,,,,,,,,,,, -13500,2.5113387,1.4857166,,,,,,,,,,,,,, -13600,1.8271732,1.5789514,,,,,,,,,,,,,, -13700,2.795372,1.4769304,,,,,,,,,,,,,, -13800,2.466001,1.5489045,,,,,,,,,,,,,, -13894,,,0.2947174,0.1042824826438545,0.65284616,0.1880436776504436,5348.0,0.37729406,0.1218288546300245,2472.0,11563.117947101591,12823.876756429672,11563.117947101591,1259.844975233078,0.3318409919738769,0.0 -13900,2.77199,1.5339551,,,,,,,,,,,,,, -14000,12.782514,1.5690103,,,,,,,,,,,,,, -14100,1.9844995,1.5592557,,,,,,,,,,,,,, -14200,4.4794927,1.5626231,,,,,,,,,,,,,, -14300,2.0434222,1.5249972,,,,,,,,,,,,,, -14400,2.0533733,1.4998876,,,,,,,,,,,,,, -14500,3.2825725,1.5350611,,,,,,,,,,,,,, -14600,1.9995878,1.487609,,,,,,,,,,,,,, -14700,2.264789,1.515497,,,,,,,,,,,,,, -14800,3.4692197,1.5213645,,,,,,,,,,,,,, -14900,3.65316,1.5577867,,,,,,,,,,,,,, -15000,2.4245462,1.5286008,,,,,,,,,,,,,, -15100,2.7174268,1.5678635,,,,,,,,,,,,,, -15200,2.344774,1.5145022,,,,,,,,,,,,,, -15300,2.7544158,1.5611687,,,,,,,,,,,,,, -15400,3.6795766,1.4971648,,,,,,,,,,,,,, -15500,2.0320525,1.4978479,,,,,,,,,,,,,, -15600,3.0975926,1.5015854,,,,,,,,,,,,,, -15623,,,0.27732593,0.0969840771616596,0.61508054,0.176709114957954,5348.0,0.36129552,0.1158166270590863,2472.0,13003.560472249985,14405.480276107788,13003.560472249985,1400.8908877372742,0.3732874393463135,0.0 -15700,2.463206,1.4329998,,,,,,,,,,,,,, -15800,2.967464,1.4690585,,,,,,,,,,,,,, -15900,3.069165,1.5021814,,,,,,,,,,,,,, -16000,3.1960788,1.5781716,,,,,,,,,,,,,, -16100,2.4908025,1.535876,,,,,,,,,,,,,, -16200,3.9665143,1.4821571,,,,,,,,,,,,,, -16300,2.5606682,1.4941967,,,,,,,,,,,,,, -16400,2.8002539,1.5359524,,,,,,,,,,,,,, -16500,1.9927392,1.4710704,,,,,,,,,,,,,, -16600,2.4134734,1.4982779,,,,,,,,,,,,,, -16700,2.0608523,1.5246882,,,,,,,,,,,,,, -16800,2.9755569,1.5082093,,,,,,,,,,,,,, -16900,2.1738596,1.4956554,,,,,,,,,,,,,, -17000,2.8054776,1.5200055,,,,,,,,,,,,,, -17100,1.5767763,1.4999303,,,,,,,,,,,,,, -17200,3.184522,1.5354768,,,,,,,,,,,,,, -17300,1.6865159,1.5031477,,,,,,,,,,,,,, -17340,,,0.28093824,0.0972719033394649,0.61341965,0.1754636647132085,5348.0,0.34505987,0.1124042816809863,2472.0,14444.06426525116,15981.121856212616,14444.06426525116,1535.9185712337494,0.4116241931915283,0.0 -17400,2.92621,1.508536,,,,,,,,,,,,,, -17500,2.316096,1.4861233,,,,,,,,,,,,,, -17600,3.8588283,1.4836183,,,,,,,,,,,,,, -17700,2.7259011,1.4641246,,,,,,,,,,,,,, -17800,3.080611,1.5000813,,,,,,,,,,,,,, -17900,3.7576947,1.4406958,,,,,,,,,,,,,, -18000,1.8336996,1.478172,,,,,,,,,,,,,, -18100,2.9121947,1.4438757,,,,,,,,,,,,,, -18200,2.5318956,1.4984921,,,,,,,,,,,,,, -18300,2.9185634,1.5471371,,,,,,,,,,,,,, -18400,1.6412928,1.4449673,,,,,,,,,,,,,, -18500,1.9682597,1.4983984,,,,,,,,,,,,,, -18600,2.9666822,1.5254924,,,,,,,,,,,,,, -18700,5.1230974,1.5033889,,,,,,,,,,,,,, -18800,3.4170406,1.425043,,,,,,,,,,,,,, -18900,2.9268959,1.4387512,,,,,,,,,,,,,, -19000,5.2932744,1.5181859,,,,,,,,,,,,,, -19066,,,0.29000315,0.0995108278907427,0.59839016,0.1725286501829556,5348.0,0.34121943,0.1097434647492535,2472.0,15884.91688466072,17558.958316087723,15884.91688466072,1672.7844877243042,0.4560194015502929,0.0 -19100,2.439368,1.5012264,,,,,,,,,,,,,, -19200,2.8582768,1.5114281,,,,,,,,,,,,,, -19300,3.035566,1.4628464,,,,,,,,,,,,,, -19400,5.3537097,1.4644976,,,,,,,,,,,,,, -19500,2.2568922,1.4632797,,,,,,,,,,,,,, -19600,2.0644886,1.4776877,,,,,,,,,,,,,, -19700,2.2136023,1.5056992,,,,,,,,,,,,,, -19800,2.7509856,1.4217664,,,,,,,,,,,,,, -19900,3.533451,1.4005989,,,,,,,,,,,,,, -20000,2.7666574,1.4661434,,,,,,,,,,,,,, -20100,2.7718484,1.4774482,,,,,,,,,,,,,, -20200,3.3108606,1.4847473,,,,,,,,,,,,,, -20300,2.8920248,1.5365931,,,,,,,,,,,,,, -20400,2.7943976,1.4445522,,,,,,,,,,,,,, -20500,2.821797,1.4307418,,,,,,,,,,,,,, -20600,1.7041728,1.4671546,,,,,,,,,,,,,, -20700,2.424081,1.4572681,,,,,,,,,,,,,, -20790,,,0.28017858,0.09633576911916,0.5747481,0.1661276152041476,5348.0,0.32885116,0.1038734182357361,2472.0,17325.00393819809,19143.80579972267,17325.00393819809,1817.4322745800016,0.4951100349426269,0.0 -20800,2.4222589,1.411122,,,,,,,,,,,,,, -20900,2.8339188,1.4337566,,,,,,,,,,,,,, -21000,2.656157,1.451371,,,,,,,,,,,,,, -21100,2.558771,1.4547899,,,,,,,,,,,,,, -21200,2.5813015,1.4520549,,,,,,,,,,,,,, -21300,2.7630014,1.4332212,,,,,,,,,,,,,, -21400,2.765566,1.4475552,,,,,,,,,,,,,, -21500,2.5068958,1.4202027,,,,,,,,,,,,,, -21600,3.1656108,1.4689212,,,,,,,,,,,,,, -21700,2.4283426,1.3770053,,,,,,,,,,,,,, -21800,2.5530343,1.45629,,,,,,,,,,,,,, -21900,1.6263437,1.4252008,,,,,,,,,,,,,, -22000,2.6481245,1.3658904,,,,,,,,,,,,,, -22100,1.8065991,1.4208479,,,,,,,,,,,,,, -22200,2.4919293,1.3963946,,,,,,,,,,,,,, -22300,1.9716743,1.3710321,,,,,,,,,,,,,, -22400,3.085999,1.461423,,,,,,,,,,,,,, -22500,2.5266407,1.4287766,,,,,,,,,,,,,, -22510,,,0.26128995,0.0898394896594299,0.5584282,0.1600934570416212,5348.0,0.3162405,0.1023703613430016,2472.0,18765.608085393906,20720.63533425331,18765.608085393906,1953.54097032547,0.5389456748962402,0.0 -22600,2.438034,1.3956425,,,,,,,,,,,,,, -22700,4.1953707,1.4291986,,,,,,,,,,,,,, -22800,2.5050194,1.3536788,,,,,,,,,,,,,, -22900,2.485957,1.408588,,,,,,,,,,,,,, -23000,3.2041326,1.4412633,,,,,,,,,,,,,, -23100,1.7548925,1.4118799,,,,,,,,,,,,,, -23200,2.4915376,1.4114931,,,,,,,,,,,,,, -23300,2.8779218,1.4643279,,,,,,,,,,,,,, -23400,1.6155381,1.3836799,,,,,,,,,,,,,, -23500,3.1054206,1.3953238,,,,,,,,,,,,,, -23600,2.7527816,1.3948908,,,,,,,,,,,,,, -23700,3.0356803,1.3850315,,,,,,,,,,,,,, -23800,3.884565,1.4091073,,,,,,,,,,,,,, -23900,3.0086884,1.3896039,,,,,,,,,,,,,, -24000,2.8277617,1.4275157,,,,,,,,,,,,,, -24100,2.5997539,1.410377,,,,,,,,,,,,,, -24200,3.2192566,1.3128766,,,,,,,,,,,,,, -24240,,,0.2412733,0.084118069338288,0.5459013,0.1572067157766685,5348.0,0.30818713,0.0985924075315337,2472.0,20206.615512132645,22296.645008563995,20206.615512132645,2088.428907632828,0.5796234607696533,0.0 -24300,2.4348145,1.4470639,,,,,,,,,,,,,, -24400,1.9817383,1.4376446,,,,,,,,,,,,,, -24500,2.2333984,1.4082378,,,,,,,,,,,,,, -24600,3.4761257,1.373927,,,,,,,,,,,,,, -24700,2.0953004,1.3882668,,,,,,,,,,,,,, -24800,3.3858235,1.3302649,,,,,,,,,,,,,, -24900,2.5582478,1.418056,,,,,,,,,,,,,, -25000,3.4650557,1.4003991,,,,,,,,,,,,,, -25100,2.8513913,1.3826623,,,,,,,,,,,,,, -25200,4.737279,1.4586997,,,,,,,,,,,,,, -25300,3.0163252,1.3637756,,,,,,,,,,,,,, -25400,1.9468119,1.364693,,,,,,,,,,,,,, -25500,3.7754838,1.3902203,,,,,,,,,,,,,, -25600,3.1783786,1.3915906,,,,,,,,,,,,,, -25700,2.1086073,1.3714895,,,,,,,,,,,,,, -25800,1.8153558,1.3509618,,,,,,,,,,,,,, -25900,3.0097086,1.3736724,,,,,,,,,,,,,, -25970,,,0.22022744,0.0780615039515897,0.52577573,0.1526979927976288,5348.0,0.29683977,0.095748786383117,2472.0,21647.060393571854,23873.397852659225,21647.060393571854,2224.599086999893,0.6414904594421387,0.0 -26000,4.344428,1.3605913,,,,,,,,,,,,,, -26100,2.687345,1.35802,,,,,,,,,,,,,, -26200,2.223216,1.344057,,,,,,,,,,,,,, -26300,1.5582879,1.3157396,,,,,,,,,,,,,, -26400,3.8760896,1.3438892,,,,,,,,,,,,,, -26500,1.8348778,1.3470604,,,,,,,,,,,,,, -26600,2.5099723,1.3320198,,,,,,,,,,,,,, -26700,2.2693532,1.3631333,,,,,,,,,,,,,, -26800,2.0827281,1.3779336,,,,,,,,,,,,,, -26900,2.6497629,1.3804771,,,,,,,,,,,,,, -27000,1.8608581,1.3599818,,,,,,,,,,,,,, -27100,3.0905805,1.3326308,,,,,,,,,,,,,, -27200,1.7297742,1.3779881,,,,,,,,,,,,,, -27300,3.99206,1.3393159,,,,,,,,,,,,,, -27400,2.122092,1.3017739,,,,,,,,,,,,,, -27500,1.9336087,1.3101919,,,,,,,,,,,,,, -27600,3.0472612,1.3774918,,,,,,,,,,,,,, -27677,,,0.20966426,0.0734011107588139,0.510851,0.146586597410622,5348.0,0.284106,0.0905083988381776,2472.0,23087.36500597,25450.333413124084,23087.36500597,2361.101853847504,0.6938507556915283,0.0 -27700,3.7102516,1.3520256,,,,,,,,,,,,,, -27800,2.869406,1.3583454,,,,,,,,,,,,,, -27900,2.1138947,1.2876912,,,,,,,,,,,,,, -28000,2.4105504,1.3771393,,,,,,,,,,,,,, -28100,2.7274709,1.3605,,,,,,,,,,,,,, -28200,2.3277915,1.3558067,,,,,,,,,,,,,, -28300,2.7332118,1.2549521,,,,,,,,,,,,,, -28400,2.2799287,1.3327088,,,,,,,,,,,,,, -28500,2.2209246,1.3166004,,,,,,,,,,,,,, -28600,3.2153997,1.2938906,,,,,,,,,,,,,, -28700,5.565454,1.3387914,,,,,,,,,,,,,, -28800,1.865554,1.3507088,,,,,,,,,,,,,, -28900,2.4457488,1.3633816,,,,,,,,,,,,,, -29000,1.786983,1.3211725,,,,,,,,,,,,,, -29100,2.8019612,1.3152415,,,,,,,,,,,,,, -29200,2.2345808,1.3620758,,,,,,,,,,,,,, -29300,2.948935,1.2944229,,,,,,,,,,,,,, -29394,,,0.21916138,0.0770064135898769,0.49210644,0.1443563725537522,5348.0,0.27293038,0.0875632197916032,2472.0,24527.617062807083,27025.553820371628,24527.617062807083,2495.944045782089,0.7462542057037354,0.0 -29400,3.371917,1.3155199,,,,,,,,,,,,,, -29500,3.1425407,1.3424928,,,,,,,,,,,,,, -29600,1.8457361,1.3271147,,,,,,,,,,,,,, -29700,2.6443298,1.3233443,,,,,,,,,,,,,, -29800,3.1873736,1.3412205,,,,,,,,,,,,,, -29900,1.8874602,1.30612,,,,,,,,,,,,,, -30000,2.3879683,1.3342214,,,,,,,,,,,,,, -30100,6.646641,1.3445027,,,,,,,,,,,,,, -30200,2.193468,1.2849559,,,,,,,,,,,,,, -30300,2.838001,1.337311,,,,,,,,,,,,,, -30400,3.299253,1.2850647,,,,,,,,,,,,,, -30500,1.8571151,1.3100427,,,,,,,,,,,,,, -30600,2.5064304,1.3157684,,,,,,,,,,,,,, -30700,1.6702485,1.2726643,,,,,,,,,,,,,, -30800,2.716786,1.2935896,,,,,,,,,,,,,, -30900,1.838769,1.2430534,,,,,,,,,,,,,, -31000,2.3048654,1.3070774,,,,,,,,,,,,,, -31100,2.4394655,1.2946657,,,,,,,,,,,,,, -31101,,,0.20257692,0.0690345579737274,0.48863128,0.1420971837377024,5348.0,0.27050504,0.0872585460971299,2472.0,25968.251851797104,28602.32335114479,25968.251851797104,2631.939340829849,0.8093528747558594,0.0 -31200,1.5518372,1.3242754,,,,,,,,,,,,,, -31300,2.5761688,1.3781248,,,,,,,,,,,,,, -31400,2.1313908,1.3038706,,,,,,,,,,,,,, -31500,2.7645829,1.2339644,,,,,,,,,,,,,, -31600,4.26535,1.3600013,,,,,,,,,,,,,, -31700,2.2258508,1.226343,,,,,,,,,,,,,, -31800,2.3182805,1.3104575,,,,,,,,,,,,,, -31900,4.520929,1.3272436,,,,,,,,,,,,,, -32000,3.0850577,1.3011514,,,,,,,,,,,,,, -32100,2.6824517,1.2925965,,,,,,,,,,,,,, -32200,6.415418,1.3423902,,,,,,,,,,,,,, -32300,3.1358511,1.2666951,,,,,,,,,,,,,, -32400,2.1702976,1.2333807,,,,,,,,,,,,,, -32500,2.1816335,1.2709353,,,,,,,,,,,,,, -32600,3.0782375,1.2732135,,,,,,,,,,,,,, -32700,3.5155795,1.2810649,,,,,,,,,,,,,, -32790,,,0.20776458,0.0702371718147214,0.47163725,0.1355223650038136,5348.0,0.25773567,0.0819369122336644,2472.0,27408.37361598015,30179.51735305786,27408.37361598015,2768.881613969803,0.8643975257873535,0.0 -32800,2.383106,1.315894,,,,,,,,,,,,,, -32900,2.1699085,1.2536585,,,,,,,,,,,,,, -33000,2.2494678,1.2390363,,,,,,,,,,,,,, -33100,2.2902958,1.3141216,,,,,,,,,,,,,, -33200,2.010631,1.2659727,,,,,,,,,,,,,, -33300,2.8016207,1.2449739,,,,,,,,,,,,,, -33400,2.7515178,1.2770327,,,,,,,,,,,,,, -33500,3.860913,1.2645628,,,,,,,,,,,,,, -33600,2.5050893,1.207625,,,,,,,,,,,,,, -33700,1.994128,1.2806991,,,,,,,,,,,,,, -33800,1.6635575,1.2205039,,,,,,,,,,,,,, -33900,1.6673063,1.3008775,,,,,,,,,,,,,, -34000,2.0649,1.2854439,,,,,,,,,,,,,, -34100,5.8963923,1.2647462,,,,,,,,,,,,,, -34200,7.324533,1.2295626,,,,,,,,,,,,,, -34300,3.2964623,1.2330301,,,,,,,,,,,,,, -34400,2.638812,1.2435827,,,,,,,,,,,,,, -34500,2.0625553,1.2606658,,,,,,,,,,,,,, -34502,,,0.1990523,0.0663810650383111,0.46344855,0.1336010890448651,5348.0,0.25175428,0.0811650722076655,2472.0,28848.777802228928,31754.64643263817,28848.777802228928,2903.467637300492,0.9291818141937256,0.0 -34600,1.836787,1.2424272,,,,,,,,,,,,,, -34700,1.8271935,1.2596562,,,,,,,,,,,,,, -34800,1.9118774,1.2100905,,,,,,,,,,,,,, -34900,2.1597419,1.2428844,,,,,,,,,,,,,, -35000,2.0987935,1.2331809,,,,,,,,,,,,,, -35100,3.0484335,1.2698796,,,,,,,,,,,,,, -35200,2.9438894,1.2507982,,,,,,,,,,,,,, -35300,2.5431354,1.2489105,,,,,,,,,,,,,, -35400,3.9336505,1.2764205,,,,,,,,,,,,,, -35500,3.4079216,1.2676075,,,,,,,,,,,,,, -35600,1.9181974,1.2383796,,,,,,,,,,,,,, -35700,3.373761,1.229113,,,,,,,,,,,,,, -35800,2.112502,1.2801386,,,,,,,,,,,,,, -35900,4.1251917,1.2697066,,,,,,,,,,,,,, -36000,2.9778297,1.2081999,,,,,,,,,,,,,, -36100,1.6808186,1.2138492,,,,,,,,,,,,,, -36191,,,0.14591317,0.0517768006313869,0.44477853,0.1280689728414609,5348.0,0.24260797,0.0783214510592488,2472.0,30289.52933859825,33334.662573099136,30289.52933859825,3042.599319934845,0.9865810871124268,0.0 -36200,3.1230714,1.2320924,,,,,,,,,,,,,, -36300,2.0514424,1.2194043,,,,,,,,,,,,,, -36400,1.8555374,1.2016671,,,,,,,,,,,,,, -36500,3.865332,1.2574245,,,,,,,,,,,,,, -36600,3.5186458,1.1538767,,,,,,,,,,,,,, -36700,2.5610545,1.2594953,,,,,,,,,,,,,, -36800,2.3660822,1.2790394,,,,,,,,,,,,,, -36900,3.9787872,1.194393,,,,,,,,,,,,,, -37000,2.8123863,1.2628648,,,,,,,,,,,,,, -37100,3.475768,1.1685194,,,,,,,,,,,,,, -37200,2.5881526,1.1763604,,,,,,,,,,,,,, -37300,4.818134,1.1905118,,,,,,,,,,,,,, -37400,4.946385,1.2322837,,,,,,,,,,,,,, -37500,2.0069802,1.2042407,,,,,,,,,,,,,, -37600,6.9353194,1.1973288,,,,,,,,,,,,,, -37700,2.374177,1.2811686,,,,,,,,,,,,,, -37800,2.6357508,1.2135484,,,,,,,,,,,,,, -37893,,,0.16582121,0.0575542227196891,0.4391291,0.1269779970456761,5348.0,0.23578015,0.0751934677959905,2472.0,31730.143264770508,34911.43565702438,31730.143264770508,3178.628196001053,1.0424203872680664,0.0 -37900,3.0240498,1.1825109,,,,,,,,,,,,,, -38000,4.6636353,1.1949272,,,,,,,,,,,,,, -38100,2.3945422,1.227862,,,,,,,,,,,,,, -38200,1.835779,1.2142025,,,,,,,,,,,,,, -38300,1.8006698,1.1653322,,,,,,,,,,,,,, -38400,2.228654,1.1763431,,,,,,,,,,,,,, -38500,3.7309027,1.1813307,,,,,,,,,,,,,, -38600,2.7036982,1.1721185,,,,,,,,,,,,,, -38700,3.0842853,1.2072815,,,,,,,,,,,,,, -38800,3.347978,1.1511247,,,,,,,,,,,,,, -38900,1.8700736,1.1933894,,,,,,,,,,,,,, -39000,2.2845275,1.2099848,,,,,,,,,,,,,, -39100,2.5444555,1.198208,,,,,,,,,,,,,, -39200,2.3453262,1.1960113,,,,,,,,,,,,,, -39300,3.536411,1.2027729,,,,,,,,,,,,,, -39400,4.5525723,1.2164432,,,,,,,,,,,,,, -39500,2.7225387,1.2171215,,,,,,,,,,,,,, -39600,2.3943057,1.2276462,,,,,,,,,,,,,, -39608,,,0.20992558,0.0719612937520001,0.42906445,0.1237726522297421,5348.0,0.23024255,0.0734669835273089,2472.0,33170.094004154205,36484.08493804932,33170.094004154205,3311.199209690094,1.094820261001587,0.0 -39700,2.3737082,1.2029129,,,,,,,,,,,,,, -39800,2.6354003,1.2357445,,,,,,,,,,,,,, -39900,4.7435503,1.1534826,,,,,,,,,,,,,, -40000,2.593291,1.1867819,,,,,,,,,,,,,, -40100,3.303032,1.2090033,,,,,,,,,,,,,, -40200,4.850717,1.2133664,,,,,,,,,,,,,, -40300,1.9539847,1.125328,,,,,,,,,,,,,, -40400,3.4297392,1.2374254,,,,,,,,,,,,,, -40500,4.2769065,1.194939,,,,,,,,,,,,,, -40600,2.6033049,1.1676761,,,,,,,,,,,,,, -40700,2.3445587,1.1541964,,,,,,,,,,,,,, -40800,2.8691268,1.20933,,,,,,,,,,,,,, -40900,1.9769083,1.2126793,,,,,,,,,,,,,, -41000,2.9102018,1.1765181,,,,,,,,,,,,,, -41100,1.7300342,1.2043849,,,,,,,,,,,,,, -41200,1.9266276,1.1404436,,,,,,,,,,,,,, -41293,,,0.21701375,0.0753537585282312,0.42279768,0.1214169168830918,5348.0,0.22649801,0.0714358255641541,2472.0,34610.70112776756,38060.21517109871,34610.70112776756,3446.5949206352234,1.1482131481170654,0.0 -41300,2.6412616,1.1219234,,,,,,,,,,,,,, -41400,2.227057,1.1932838,,,,,,,,,,,,,, -41500,2.297256,1.1598235,,,,,,,,,,,,,, -41600,3.4830294,1.1903241,,,,,,,,,,,,,, -41700,2.7021923,1.1876086,,,,,,,,,,,,,, -41800,2.0676382,1.1342157,,,,,,,,,,,,,, -41900,2.4282331,1.1209412,,,,,,,,,,,,,, -42000,5.0091887,1.1379882,,,,,,,,,,,,,, -42100,2.404438,1.1501591,,,,,,,,,,,,,, -42200,3.6330776,1.0913115,,,,,,,,,,,,,, -42300,2.9766426,1.122458,,,,,,,,,,,,,, -42400,3.6079028,1.1300182,,,,,,,,,,,,,, -42500,2.4726923,1.1088313,,,,,,,,,,,,,, -42600,3.1112037,1.1278607,,,,,,,,,,,,,, -42700,3.709283,1.1681118,,,,,,,,,,,,,, -42800,3.2102761,1.1688324,,,,,,,,,,,,,, -42900,3.427415,1.2050838,,,,,,,,,,,,,, -42989,,,0.24711965,0.0869839895642611,0.41582108,0.1195149502302634,5348.0,0.22220674,0.0705624276399975,2472.0,36051.13100504875,39636.61383128166,36051.13100504875,3582.435993909836,1.1999573707580566,0.0 -42989,,,,,,,,,,,36051.13100504875,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index cfbb73218..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,232 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -536.988322019577,0.0,20.478039264678955,1,0,20.478039264678955,0.4661603569984436,0.7539003491401672,0.0295337511283684,43793,557.4664070606232,0.4663219749927521,0.7589666843414307,0.0246376417424939,0.4651613235473633,0.755351185798645,0.0288803148676715,43793 -652.5934982299805,0.0295979976654052,260.4673655033112,762,0,260.4673655033112,0.9832456707954408,0.0639590546488761,0.0563886604872141,43793,913.1106786727904,0.9867117404937744,0.0512692332267761,0.0549762184250856,0.9842259287834167,0.0606738105416297,0.0549251718406611,43793 -774.1975507736206,0.0569667816162109,500.6432864665985,1524,0,500.6432864665985,0.983719527721405,0.0586970448493957,0.1037015853295259,43793,1274.9377574920654,0.9874638319015504,0.0462587960064411,0.1052157271216831,0.9847134351730348,0.055482342839241,0.1035546290814991,43793 -894.7939918041229,0.0827620029449462,740.6774151325226,2276,0,740.6774151325226,0.9840750098228456,0.0557515323162078,0.1282670616537205,43793,1635.6136515140531,0.9878709316253662,0.0440672747790813,0.135036552981234,0.985028862953186,0.0528953969478607,0.1320009243871645,43793 -1019.1199588775636,0.110504150390625,980.8679234981536,3038,0,980.8679234981536,0.9844068884849548,0.0539854243397712,0.154043207003535,43793,2000.177325725556,0.9881895780563354,0.0415606386959552,0.1676026993212543,0.9853706955909728,0.050984688103199,0.1571880636251882,43793 -1144.1841297149658,0.1403696537017822,1220.901849985123,3793,0,1220.901849985123,0.9846153855323792,0.0524870157241821,0.1727824752855969,43793,2365.3287811279297,0.988350510597229,0.0400646440684795,0.2019606540584417,0.985541582107544,0.0497087389230728,0.1713321946448269,43793 -1271.3422000408173,0.1712827682495117,1460.9999330043793,4545,0,1460.9999330043793,0.9848293662071228,0.0512782260775566,0.1918057026970715,43793,2732.639394760132,0.988611102104187,0.038961786776781,0.22024903128454,0.9857437014579772,0.0485648401081562,0.1875371005404125,43793 -1397.3150610923767,0.1990737915039062,1701.144421339035,5307,0,1701.144421339035,0.9850513339042664,0.0501545369625091,0.2047523244605464,43793,3098.8045933246613,0.9888646006584167,0.0378557853400707,0.2397422560122646,0.9859409928321838,0.0475250221788883,0.2057997940980671,43793 -1528.1678965091703,0.2259581089019775,1941.1350138187408,6033,0,1941.1350138187408,0.985364317893982,0.0490593016147613,0.2205065731715344,43793,3469.7046773433685,0.9893951416015624,0.0358287021517753,0.2831310489165644,0.9862219095230104,0.0465059168636798,0.2159552329651954,43793 -1655.6694447994232,0.2586400508880615,2181.31711101532,6774,0,2181.31711101532,0.9852455258369446,0.0492561794817447,0.22154515529967,43793,3837.4433250427246,0.9893807172775269,0.0356910154223442,0.2807387695694591,0.9862409830093384,0.0463772155344486,0.224570762905789,43793 -1779.4602222442627,0.2872166633605957,2421.491445541382,7536,0,2421.491445541382,0.985526442527771,0.0482594072818756,0.2345355224697904,43793,4201.456716775894,0.989676833152771,0.0346273817121982,0.2996095669035396,0.986353039741516,0.0456151477992534,0.2320336506118515,43793 -1906.420913219452,0.3165235519409179,2661.753399848938,8293,0,2661.753399848938,0.9855167865753174,0.0479991845786571,0.2388739386660848,43793,4568.728547811508,0.9896586537361144,0.0345381908118724,0.3084910291337902,0.9863985180854796,0.0453417785465717,0.2410090941250586,43793 -2033.2460660934448,0.3444001674652099,2901.8338191509247,9042,0,2901.8338191509247,0.9856544733047484,0.0478878058493137,0.2396084057737746,43793,4935.682396650314,0.9897889494895936,0.0342106819152832,0.3190344427106689,0.9864873886108398,0.045198518782854,0.2418314847993425,43793 -2162.5575363636017,0.3734691143035888,3141.838902950287,9798,0,3141.838902950287,0.9857429265975952,0.0474335961043834,0.2517587761230952,43793,5305.048637628555,0.990106165409088,0.0331576168537139,0.3252865702974921,0.98661607503891,0.0447712168097496,0.2498949257991057,43793 -2292.943807125092,0.401355504989624,3381.862954854965,10550,0,3381.862954854965,0.9857412576675416,0.0477214269340038,0.2506242295578791,43793,5675.507498025894,0.9900806546211244,0.0328888185322284,0.3546899649784792,0.9866286516189576,0.0449186265468597,0.2529471298880844,43793 -2423.986411333084,0.4300618171691894,3621.858566761017,11298,0,3621.858566761017,0.9856927990913392,0.0474970564246177,0.2519475865112507,43793,6046.595143318176,0.9903729557991028,0.0320044048130512,0.3639937596876242,0.986629068851471,0.0446187332272529,0.2582009507227083,43793 -2553.7944264411926,0.4591627120971679,3861.815475463867,12043,0,3861.815475463867,0.9858107566833496,0.0472899749875068,0.2561519833967316,43793,6416.409091234207,0.9904464483261108,0.0314045511186122,0.3860296660383807,0.986702561378479,0.0443083681166172,0.2602666292408093,43793 -2679.0158491134644,0.487302303314209,4102.04964017868,12786,0,4102.04964017868,0.9859880805015564,0.046922318637371,0.2588306255104661,43793,6781.913225889206,0.990697145462036,0.0306289941072464,0.3985212442916236,0.9868206977844238,0.0440055690705776,0.2653274057273426,43793 -2808.231062889099,0.5157938003540039,4342.18771481514,13540,0,4342.18771481514,0.9859177470207214,0.0466880537569522,0.2566011645968828,43793,7151.315158605576,0.990955412387848,0.0300523471087217,0.4159336755250958,0.9867321848869324,0.0441063083708286,0.264509819457471,43793 -2935.961008310318,0.5450277328491211,4582.151921987534,14297,0,4582.151921987534,0.9858878254890442,0.0471305660903453,0.2551575278951115,43793,7519.058586359024,0.9908493757247924,0.0302403513342142,0.4088204630160475,0.9867675304412842,0.0443227104842662,0.2648037087964535,43793 -3066.32634806633,0.5752942562103271,4822.344773769379,15047,0,4822.344773769379,0.9858798384666444,0.0468828491866588,0.2637279712481232,43793,7889.667108535767,0.9907960891723632,0.0304703917354345,0.4065808959727163,0.9867001175880432,0.0441397167742252,0.2686152810756539,43793 -3202.7751603126526,0.6057493686676025,5062.34095621109,15804,0,5062.34095621109,0.9858170747756958,0.0469556115567684,0.2542581856078846,43793,8266.162788152695,0.990820586681366,0.0303225982934236,0.4107530372449284,0.9867711663246156,0.0440304353833198,0.2680708694233623,43793 -3336.2889742851257,0.6370806694030762,5302.402543067932,16540,0,5302.402543067932,0.98592871427536,0.0474228151142597,0.2621201168497622,43793,8639.794367313385,0.990711271762848,0.0303991846740245,0.4025933595309913,0.9867861866950988,0.0445683114230632,0.2658517816125347,43793 -3465.017394065857,0.669407844543457,5542.3714718818665,17286,0,5542.3714718818665,0.9860125184059144,0.046807624399662,0.2634333099377213,43793,9008.544433832169,0.9908052682876588,0.0300758443772792,0.4174778803693812,0.9868454337120056,0.0440562255680561,0.2679873720240026,43793 -3594.087012052536,0.7005889415740967,5782.42107629776,18042,0,5782.42107629776,0.9859421849250792,0.0468078851699829,0.2667088061323157,43793,9377.715222358704,0.9909988045692444,0.0294074118137359,0.4241687010668576,0.9868454337120056,0.0440659411251544,0.270766252405497,43793 -3724.530516862869,0.7313892841339111,6022.573717832565,18793,0,6022.573717832565,0.985970377922058,0.0467896312475204,0.2599671178064437,43793,9748.362263441086,0.9910372495651244,0.0290959365665912,0.4259529531939385,0.9867752194404602,0.0439095199108123,0.2738895893633772,43793 -3852.681501150131,0.7612929344177246,6262.538156986237,19548,0,6262.538156986237,0.9860718846321106,0.0469783172011375,0.2697352700780682,43793,10116.527661561966,0.9911220073699952,0.0286788288503885,0.4524073924889998,0.9870005249977112,0.0439878851175308,0.274081041031077,43793 -3983.186435461044,0.7912535667419434,6502.570947647095,20299,0,6502.570947647095,0.9862125515937804,0.0465430319309234,0.2696994986515265,43793,10487.115420341492,0.9913285970687866,0.0280595403164625,0.4674548046421443,0.9869655966758728,0.0438056886196136,0.2766592911034783,43793 -4120.325926780701,0.8225059509277344,6742.623478651047,21043,0,6742.623478651047,0.9861034750938416,0.0469097383320331,0.2658631214409926,43793,10864.359109163284,0.9913150072097778,0.0281208455562591,0.4598708789135767,0.9869221448898317,0.0441030338406562,0.2744794112477171,43793 -4253.412977218628,0.853546142578125,6982.736036777496,21784,0,6982.736036777496,0.9860184192657472,0.046649981290102,0.2655067753076269,43793,11237.61018204689,0.9913945198059082,0.0282089468091726,0.4455943310874722,0.9868649244308472,0.0439488366246223,0.274835884000275,43793 -4384.202717542648,0.8843135833740234,7222.803408145904,22532,0,7222.803408145904,0.9860659837722778,0.0466744974255561,0.2620544147619095,43793,11608.518146514893,0.9913437366485596,0.0282543916255235,0.4483314941364055,0.9869745373725892,0.0437615998089313,0.2805751517421683,43793 -4513.108451843262,0.9187815189361572,7462.937109947205,23278,0,7462.937109947205,0.9861236810684204,0.0467276126146316,0.2661157909791088,43793,11977.61382842064,0.991249144077301,0.0284157712012529,0.4628828997940778,0.9869765639305116,0.0439945273101329,0.2760817632379464,43793 -4643.4614017009735,0.9489178657531738,7702.915818929672,24029,0,7702.915818929672,0.9860647320747375,0.0475699119269847,0.2683217825679791,43793,12347.995973825457,0.991316258907318,0.0281748175621032,0.4489988947669722,0.9868953824043274,0.0444948449730873,0.27705578804006,43793 -4770.698044300079,0.9801278114318848,7942.99335026741,24779,0,7942.99335026741,0.9860441088676452,0.0468056164681911,0.2693561350510513,43793,12715.360845327376,0.9912404417991638,0.0283038429915905,0.4516727626946165,0.9870346188545228,0.0436816960573196,0.2832783057158123,43793 -4900.004084348679,1.0125136375427246,8183.164006948471,25526,0,8183.164006948471,0.9861544370651244,0.0467205196619033,0.2696622948080067,43793,13084.890377283096,0.9914011359214784,0.0276933945715427,0.4690910450316292,0.986968457698822,0.0441174320876598,0.2768586391381986,43793 -5032.323452711105,1.043934345245361,8423.121535778046,26267,0,8423.121535778046,0.9860975742340088,0.0474251545965671,0.2647629339525147,43793,13457.219170570374,0.9915769696235656,0.0271135028451681,0.4799763807067559,0.9870049953460692,0.0444050058722496,0.2801886624062228,43793 -5168.311667442322,1.0791988372802734,8663.304875612259,26998,0,8663.304875612259,0.986042022705078,0.0467803031206131,0.2685553927331099,43793,13833.450981378555,0.9920740127563475,0.0258896984159946,0.5144971201680135,0.9868710041046144,0.0441129580140113,0.2735882717769362,43793 -5298.877843379974,1.1134259700775146,8903.47397851944,27744,0,8903.47397851944,0.9861022233963012,0.0469837635755538,0.2671069133928673,43793,14204.240990161896,0.9919695854187012,0.0261274073272943,0.5067298145105661,0.9869599342346193,0.0442815013229846,0.277661715150162,43793 -5432.393877744675,1.1442391872406006,9143.46143746376,28492,0,9143.46143746376,0.98611319065094,0.0477529726922512,0.263590646622293,43793,14577.79633116722,0.9916380047798156,0.026933841407299,0.4918883510856133,0.9869672060012816,0.0448606945574283,0.2792009679877025,43793 -5559.553866147995,1.1761319637298584,9383.617679357529,29242,0,9383.617679357529,0.9862269163131714,0.0470354966819286,0.2678343298693436,43793,14945.165321350098,0.991564154624939,0.0272049885243177,0.4691972139556656,0.9871259331703186,0.0442688465118408,0.2877655372285986,43793 -5691.682248830795,1.2084007263183594,9623.72170972824,29991,0,9623.72170972824,0.986084520816803,0.0467661544680595,0.2690176516848281,43793,15317.450065851212,0.991787314414978,0.0266905855387449,0.4814734467220524,0.9869627952575684,0.0438532941043376,0.2948247113332413,43793 -5821.52396774292,1.2435061931610107,9863.68747830391,30733,0,9863.68747830391,0.9861578345298768,0.046843446791172,0.2706482304538314,43793,15687.312734127045,0.9916547536849976,0.0268725398927927,0.505886375231136,0.987059772014618,0.0440127216279506,0.2781033442700349,43793 -5954.453230142593,1.281077861785889,10103.922180891035,31466,0,10103.922180891035,0.9861629009246826,0.0466210283339023,0.2740760282792117,43793,16060.538444519045,0.991858720779419,0.026429558172822,0.4930497997143744,0.9869838953018188,0.0440751984715461,0.2827252918313451,43793 -6086.7869300842285,1.3168962001800537,10344.138065576552,32202,0,10344.138065576552,0.9862138628959656,0.0468107126653194,0.2729542577372867,43793,16433.148668050766,0.9919987320899964,0.0257250852882862,0.5162503143904493,0.9871016144752502,0.0439109280705452,0.2890559874025201,43793 -6217.578600645065,1.3507368564605713,10584.087520360948,32951,0,10584.087520360948,0.9862251877784728,0.0470446273684501,0.2750339194968053,43793,16803.94484782219,0.9919003248214722,0.0258762538433074,0.5181524506162096,0.9869920015335084,0.0443031750619411,0.2793055822996804,43793 -6345.683080196381,1.3866889476776123,10824.240885734558,33700,0,10824.240885734558,0.9861350655555724,0.0467751398682594,0.2726212946469625,43793,17172.259604930878,0.9923912286758424,0.0245789401233196,0.5401013702391663,0.9869599342346193,0.0440076068043708,0.2849392163740508,43793 -6470.765429973602,1.423579216003418,11064.355593442917,34444,0,11064.355593442917,0.9861670732498168,0.0469842590391635,0.2729544840814009,43793,17537.51575899124,0.9924638271331788,0.0244253538548946,0.5408102506296366,0.9870119094848632,0.0440884940326213,0.2841427396357607,43793 -6598.782794237137,1.4568994045257568,11304.5543050766,35193,0,11304.5543050766,0.986255943775177,0.0475512593984603,0.2712236659701404,43793,17905.78564977646,0.992255449295044,0.0248958189040422,0.5338702720010491,0.9870837330818176,0.0445379801094532,0.2771609054072548,43793 -6728.039513587952,1.4914841651916504,11544.72012758255,35938,0,11544.72012758255,0.9862921833992004,0.0474688448011875,0.2737562533565371,43793,18275.26300549507,0.9919145107269288,0.0258436184376478,0.5134692618480501,0.9870756268501282,0.0446954704821109,0.2806026018791224,43793 -6855.695904970169,1.526174783706665,11784.769032478333,36684,0,11784.769032478333,0.9862239360809326,0.0469509214162826,0.27379143182807,43793,18643.02334499359,0.9919720888137816,0.0258467327803373,0.4990831629052673,0.9870249032974244,0.044183675199747,0.2844107591122186,43793 -6984.932626485825,1.5598630905151367,12024.889050722122,37430,0,12024.889050722122,0.9861733913421632,0.0472716465592384,0.2732995899724466,43793,19012.433899879456,0.992161750793457,0.0251628756523132,0.5274201870761099,0.9870399236679076,0.0443357676267623,0.2835930685939118,43793 -7114.086160421372,1.593503713607788,12264.882412672045,38176,0,12264.882412672045,0.986193597316742,0.047460202127695,0.2783172807574485,43793,19381.634644031525,0.9922487139701844,0.0247343424707651,0.5345395525532157,0.9869737029075624,0.0445678867399692,0.2862172278706824,43793 -7240.5595326423645,1.6269724369049072,12505.049992084503,38923,0,12505.049992084503,0.9862353205680848,0.0476949959993362,0.2720401231594703,43793,19748.3295750618,0.9923862218856812,0.0243118181824684,0.5500534552972087,0.9870415329933168,0.0446854121983051,0.2939358720634637,43793 -7369.194571733475,1.6616127490997314,12745.011418104172,39670,0,12745.011418104172,0.98629891872406,0.0473158583045005,0.2731693622950285,43793,20116.981176376343,0.9924398064613342,0.0240245051681995,0.5614326523796441,0.9870817065238952,0.0445556156337261,0.2897896197496002,43793 -7497.037292003632,1.6968517303466797,12984.965381383896,40409,0,12984.965381383896,0.9859733581542968,0.0475118793547153,0.2730102876132786,43793,20484.833154678345,0.9927424192428588,0.0233082212507724,0.574511917572232,0.9868040680885316,0.0447564125061035,0.2882639522424063,43793 -7626.816812753677,2.0183424949646,13224.677162885666,41157,0,13224.677162885666,0.986185610294342,0.0472666770219802,0.2766887350322903,43793,20854.66638636589,0.9929094910621644,0.0228238552808761,0.5796415843969315,0.9870049953460692,0.0445084460079669,0.2902851871377324,43793 -7753.115566253662,2.056342363357544,13464.927323102953,41906,0,13464.927323102953,0.9861674904823304,0.0477052852511405,0.26936929027808,43793,21221.274071455,0.9926960468292236,0.0233404859900474,0.5810469593506353,0.9869651794433594,0.0448800325393676,0.2916548765324748,43793 -7882.5806567668915,2.092425346374512,13705.163032531738,42649,0,13705.163032531738,0.9862475395202636,0.0478574857115745,0.2744560090343512,43793,21591.03133511544,0.9926057457923888,0.0236248318105936,0.5628984806159489,0.9871000051498412,0.0449630059301853,0.2910589062124619,43793 -8008.517409801483,2.128856897354126,13945.18832564354,43386,0,13945.18832564354,0.9862290024757384,0.0474529117345809,0.2745836850366369,43793,21957.04997587204,0.9926634430885316,0.0233371630311012,0.5653618151959215,0.987144649028778,0.044412650167942,0.2941698091762321,43793 -8139.154308795929,2.1666409969329834,14185.310455322266,44124,0,14185.310455322266,0.9862479567527772,0.0486087724566459,0.2739571897071809,43793,22327.866988182068,0.9925113320350648,0.0236401371657848,0.5700971758111578,0.987098753452301,0.0454526208341121,0.2921075738306386,43793 -8270.51467037201,2.20216703414917,14425.444154977798,44870,0,14425.444154977798,0.9861940741539,0.0481225363910198,0.272632916276983,43793,22699.416808366776,0.9926388263702391,0.023288769647479,0.5691611133646597,0.9871628880500792,0.0448407232761383,0.295265385933204,43793 -8396.839224815369,2.237090587615967,14665.592813968658,45615,0,14665.592813968658,0.9862799644470216,0.0486726686358451,0.2769721109562906,43793,23065.945892572403,0.9928351044654846,0.02258169837296,0.5853443199422982,0.9871637225151062,0.0455134995281696,0.2881334524893654,43793 -8527.988328695297,2.2753024101257324,14905.821466445925,46353,0,14905.821466445925,0.9861839413642884,0.0484987720847129,0.2734909871848885,43793,23437.383046627045,0.993068516254425,0.0219520926475524,0.6135675762189512,0.9870699644088744,0.0454390197992324,0.2857170722059383,43793 -8655.228150367737,2.3150885105133057,15145.786823511124,47091,0,15145.786823511124,0.98616623878479,0.0482204817235469,0.2768215463962636,43793,23804.64868044853,0.9933658838272096,0.021166056394577,0.6093570684686115,0.9869733452796936,0.0454302765429019,0.2875142031287636,43793 -8786.653314113617,2.3512327671051025,15385.791709899902,47838,0,15385.791709899902,0.9862921833992004,0.0486597009003162,0.2697960378710056,43793,24176.136302232742,0.9935070872306824,0.0205803215503692,0.6371483651392227,0.987105667591095,0.0456289239227771,0.2910954688285329,43793 -8911.482343435287,2.387014865875244,15625.98142337799,48582,0,15625.98142337799,0.986214280128479,0.0486755184829235,0.2777921732487484,43793,24541.21181440353,0.993273913860321,0.0214224010705947,0.6035744046576249,0.9871032238006592,0.0457254275679588,0.2912422849534083,43793 -9033.203171491625,2.4227333068847656,15866.042525291445,49329,0,15866.042525291445,0.986250936985016,0.0489511005580425,0.2742838969821881,43793,24903.049714803696,0.9930155873298644,0.02194794267416,0.6046991869074054,0.9870277047157288,0.0460869930684566,0.2934004400441242,43793 -9158.79671406746,2.4613735675811768,16106.172878026962,50061,0,16106.172878026962,0.9863073229789734,0.0490118525922298,0.2727283348585548,43793,25268.836081027985,0.9932344555854796,0.0213817190378904,0.6176617595712219,0.9871576428413392,0.0460498519241809,0.293558521584603,43793 -9287.683853626251,2.49810791015625,16346.374148607254,50801,0,16346.374148607254,0.9861944913864136,0.0488013066351413,0.2790345826699184,43793,25637.98314404488,0.9931604862213136,0.0216441955417394,0.6141723415553106,0.987022876739502,0.0459070391952991,0.2900934698570349,43793 -9409.34234046936,2.534095764160156,16586.50678873062,51543,0,16586.50678873062,0.9862997531890868,0.0489075742661953,0.2789417867146801,43793,25999.830745220184,0.9934256076812744,0.0206026528030633,0.6302478098406785,0.9871271848678588,0.0461113080382347,0.2967554986096878,43793 -9532.496723175049,2.5732104778289795,16826.707840442657,52287,0,16826.707840442657,0.9863343238830566,0.0495630428194999,0.2742401976552608,43793,26363.245794296265,0.9934813976287842,0.0203452669084072,0.6378640103243176,0.9871352910995485,0.0466223023831844,0.2948339264460545,43793 -9656.69024682045,2.609474897384644,17066.75954055786,53032,0,17066.75954055786,0.9861392974853516,0.0489125773310661,0.2728305192982156,43793,26727.54792380333,0.9936859011650084,0.0199180506169796,0.6607255578238254,0.9869903922080994,0.0459285601973533,0.2948226204101228,43793 -9783.560585260391,2.648601531982422,17306.817973852158,53774,0,17306.817973852158,0.986177623271942,0.0494026094675064,0.2757804694010088,43793,27094.53839492798,0.9940152168273926,0.0189543068408966,0.6742533910555928,0.9870938658714294,0.0462624207139015,0.3016247706823539,43793 -9910.918627500534,2.690823554992676,17546.784031391144,54517,0,17546.784031391144,0.98629891872406,0.0495272688567638,0.2791162088408168,43793,27461.929767370224,0.9941158294677734,0.0187518820166587,0.6797247754253466,0.98710036277771,0.0464842617511749,0.2955374417131669,43793 -10033.722482919691,2.7282874584198,17786.882937908173,55260,0,17786.882937908173,0.9862226843833924,0.0497312098741531,0.275614268267156,43793,27824.8903477192,0.9938734769821168,0.0192361325025558,0.6635183115064012,0.9870488047599792,0.0465531982481479,0.2942804772799886,43793 -10159.011183738708,2.768742561340332,18027.09535241127,56008,0,18027.09535241127,0.9861831068992616,0.0500927679240703,0.2752872680004323,43793,28190.45258665085,0.9936686158180236,0.0198822543025016,0.6484135803660225,0.9869859218597412,0.0469512417912483,0.2881964914675938,43793 -10285.594108343124,2.806877851486206,18267.287185430527,56756,0,18267.287185430527,0.9860533475875854,0.0501333326101303,0.274866405576504,43793,28557.286256313324,0.9936248660087584,0.0199340619146823,0.6411082926218209,0.9869481325149536,0.0469044186174869,0.2937136247289022,43793 -10413.653168201448,2.843985319137573,18507.45695638657,57497,0,18507.45695638657,0.9861485362052916,0.0507963001728057,0.2768490884306481,43793,28925.572428703308,0.9938535094261168,0.0192855987697839,0.6639852026299888,0.9869769811630248,0.0477448664605617,0.2913151944361254,43793 -10536.1366751194,2.8811538219451904,18747.58773970604,58239,0,18747.58773970604,0.9861140251159668,0.0504525229334831,0.2746775260475992,43793,29288.24400830269,0.9939899444580078,0.0187293533235788,0.6713579688097611,0.9869345426559448,0.0474524535238742,0.293276872576415,43793 -10661.195743560793,2.9196202754974365,18987.742092847824,58988,0,18987.742092847824,0.9863199591636658,0.0507941544055938,0.2780661761896076,43793,29653.516090154648,0.9940062165260316,0.0185742657631635,0.6838410034958001,0.9871000051498412,0.0477108918130397,0.2963570922631796,43793 -10786.069747924805,2.9572134017944336,19227.76380634308,59736,0,19227.76380634308,0.9861990809440612,0.0508352927863597,0.2754505039140115,43793,30018.46971058845,0.9945381283760072,0.017136039212346,0.7119890023519173,0.9870768189430236,0.0477268025279045,0.2921333089367812,43793 -10913.33787703514,2.99572229385376,19467.88423514366,60478,0,19467.88423514366,0.9862260818481444,0.0509336069226264,0.2769695394517674,43793,30385.91834759712,0.9946551322937012,0.0168867819011211,0.7121866182580209,0.9869863390922546,0.0479784160852432,0.2957888682327104,43793 -11033.724110364914,3.034928560256958,19707.84122633934,61226,0,19707.84122633934,0.9861772060394288,0.051029447466135,0.2759675212331038,43793,30746.321582317352,0.994626522064209,0.0167669877409935,0.7171335605289063,0.9870147109031676,0.0480419769883155,0.2928131691734593,43793 -11158.269618988035,3.074516534805298,19948.08852863312,61962,0,19948.08852863312,0.98625510931015,0.0519960150122642,0.274682914531157,43793,31111.174551725388,0.9944984316825868,0.0170986913144588,0.7176627280464143,0.9870285391807556,0.0489490553736686,0.294612297398712,43793 -11283.433761835098,3.114187479019165,20188.32045006752,62706,0,20188.32045006752,0.9861464500427246,0.0516576319932937,0.2757454189249902,43793,31476.630955457687,0.9945216178894044,0.0171154234558343,0.7133070344489623,0.9869871139526368,0.0483480617403984,0.2968437849571173,43793 -11406.333416223526,3.152745485305786,20428.48164987564,63446,0,20428.48164987564,0.9861767888069152,0.0518626347184181,0.2775576912809816,43793,31839.75074863434,0.9945242404937744,0.0169851873070001,0.7093867418780908,0.9869940280914308,0.0486472472548484,0.2948966868870458,43793 -11531.373598575592,3.19139051437378,20668.63866043091,64189,0,20668.63866043091,0.9862075448036194,0.0521550811827182,0.2738652667227593,43793,32205.007052898407,0.9944551587104796,0.0170583687722682,0.7121639624300242,0.9870054125785828,0.048753298819065,0.297605083648721,43793 -11654.752810955048,3.231707811355591,20908.890762090683,64933,0,20908.890762090683,0.9862125515937804,0.0522994175553321,0.2767880296537022,43793,32568.69940972328,0.9945725202560424,0.0167442299425601,0.7202841801572984,0.9870171546936036,0.0490759573876857,0.2939358425075371,43793 -11775.9447824955,3.270869255065918,21148.8704726696,65686,0,21148.8704726696,0.9862125515937804,0.0525530837476253,0.276474986311574,43793,32929.93105196953,0.9948503971099854,0.0159808285534381,0.7366013362986004,0.9870244860649108,0.0493682250380516,0.2944331164330976,43793 -11897.835361719131,3.3115804195404053,21389.03907442093,66427,0,21389.03907442093,0.986159086227417,0.0525974743068218,0.2758683298245158,43793,33292.05467629433,0.9951132535934448,0.0153884962201118,0.7601626246504376,0.9869201183319092,0.0494905523955822,0.2930955641549069,43793 -12018.075865268707,3.3581244945526123,21629.091300964355,67167,0,21629.091300964355,0.9861615896224976,0.052687082439661,0.2765470271540292,43793,33652.41530036926,0.9952598214149476,0.0148692950606346,0.7578312354119906,0.9869258403778076,0.0494728460907936,0.2936498573486638,43793 -12139.252855539322,3.3980062007904053,21869.05692052841,67912,0,21869.05692052841,0.9861228466033936,0.0529684200882911,0.2736527050416096,43793,34013.61874079704,0.9953340888023376,0.014705266803503,0.7636951762040912,0.9869895577430724,0.0496973432600498,0.294791959712245,43793 -12259.688798666,3.437978029251098,22109.022665262222,68664,0,22109.022665262222,0.9861595034599304,0.0530820786952972,0.2710669890759926,43793,34374.08118200302,0.995245397090912,0.014989978633821,0.754445749752948,0.9869197607040404,0.0498844981193542,0.2920263811403023,43793 -12382.230098962784,3.4789209365844727,22349.240227222443,69412,0,22349.240227222443,0.9861894249916076,0.0533779002726078,0.2722586746507549,43793,34736.90176987648,0.995061218738556,0.0152342738583683,0.7505739178886142,0.9869449138641356,0.0500610433518886,0.2925591697654719,43793 -12499.442770004272,3.5188333988189697,22589.316182613373,70165,0,22589.316182613373,0.9861894249916076,0.0535657070577144,0.2755686078935777,43793,35094.25100970268,0.9951221942901612,0.0151732191443443,0.7626114380669714,0.9870163798332214,0.0502706728875637,0.2902962107963454,43793 -12616.302125930786,3.5597028732299805,22829.283252239227,70912,0,22829.283252239227,0.98615825176239,0.053830362856388,0.2747486902004241,43793,35451.13889479637,0.994987726211548,0.0154399583116173,0.7428052908604961,0.9870102405548096,0.0504793338477611,0.2939184261839505,43793 -12730.839373111725,3.6005778312683105,23069.419085025787,71663,0,23069.419085025787,0.9861262440681458,0.0536865107715129,0.2744871260934867,43793,35805.873645067215,0.9952868819236756,0.0145252058282494,0.7632115012875428,0.9869834780693054,0.0503106638789176,0.2912756458337494,43793 -12852.929856061935,3.643061876296997,23309.60293293,72409,0,23309.60293293,0.9861717224121094,0.0539077967405319,0.2749625676471574,43793,36168.21115899086,0.995356559753418,0.0144841000437736,0.7616841179855216,0.9870256781578064,0.0505304187536239,0.2914941165453763,43793 -12966.86022424698,3.684103012084961,23549.57167220116,73150,0,23549.57167220116,0.9861780405044556,0.0538734272122383,0.2740218947411233,43793,36522.17186427117,0.9953998923301696,0.0142793525010347,0.778936084156196,0.9870293140411376,0.0505202263593673,0.2940463974337953,43793 -13085.3903298378,3.724265813827514,23789.65530896187,73901,0,23789.65530896187,0.9861649870872498,0.0538687482476234,0.2766866151141615,43793,36880.84664463997,0.9956340789794922,0.013805765658617,0.7875806631898318,0.9869623780250548,0.0504661202430725,0.2937153218722242,43793 -13204.942134141922,3.765378952026367,24029.66795539856,74646,0,24029.66795539856,0.986185610294342,0.0539115630090236,0.274894312465339,43793,37240.47275662422,0.995647132396698,0.0137284025549888,0.7831579439134501,0.9869558811187744,0.0505936332046985,0.2927131763169233,43793 -13324.7779610157,3.806945562362671,24269.63495206833,75394,0,24269.63495206833,0.9862277507781982,0.0540798865258693,0.2759888610790896,43793,37600.3379855156,0.9955528378486632,0.0138159431517124,0.782723698481396,0.9869920015335084,0.0507420748472213,0.2923844320138635,43793 -13446.011998414991,3.848209857940674,24509.7945497036,76147,0,24509.7945497036,0.986213445663452,0.0540745370090007,0.2772311098034813,43793,37961.79363822937,0.995425283908844,0.0141666233539581,0.7730561717837915,0.9869599342346193,0.0507632009685039,0.2921613507528562,43793 -13562.63976407051,3.890085697174072,24749.799326896667,76897,0,24749.799326896667,0.986204981803894,0.0540654771029949,0.2752977861531578,43793,38318.488555669785,0.9954675436019896,0.0141391856595873,0.7729857674740117,0.9869920015335084,0.0507909543812274,0.2932262344296919,43793 -13676.798060655594,3.931835889816284,24989.83911180496,77645,0,24989.83911180496,0.9862138628959656,0.0539667382836341,0.2751514090916538,43793,38672.74894404411,0.9955635666847228,0.0138484183698892,0.789354634106922,0.9869863390922546,0.0506942756474018,0.2951626304122393,43793 -13791.799069404602,3.97400426864624,25229.79161000252,78390,0,25229.79161000252,0.9862096309661864,0.0539970993995666,0.2748707703275417,43793,39027.76487231255,0.9955982565879822,0.0138391042128205,0.7907666497457357,0.9869810342788696,0.0507110469043254,0.2936789197248196,43793 -13917.542226552963,4.018395185470581,25470.0196492672,79135,0,25470.0196492672,0.9862079620361328,0.0540053136646747,0.2754306980280074,43793,39393.80100250244,0.9955662488937378,0.0138804921880364,0.7777995135357072,0.9869794249534608,0.0507164113223552,0.2937873785465716,43793 -14037.619457006454,4.064361095428467,25710.17612195015,79868,0,25710.17612195015,0.9862112998962402,0.054036695510149,0.2753682278616062,43793,39754.10463857651,0.9955735206604004,0.0137835601344704,0.7806222712920365,0.9869920015335084,0.0507424473762512,0.2936313270430152,43793 -14156.27874803543,4.105771064758301,25950.23030114174,80611,0,25950.23030114174,0.9862112998962402,0.0540365986526012,0.2754226497068658,43793,40112.879777908325,0.99559223651886,0.0137861659750342,0.7772153678768996,0.9869915843009948,0.0507423616945743,0.2936277846170593,43793 -14277.388365507126,4.147861957550049,26190.267526865005,81350,0,26190.267526865005,0.9862112998962402,0.0540365986526012,0.275526511079584,43793,40474.08954358101,0.9955688118934632,0.0138651421293616,0.7831689070136559,0.9869915843009948,0.0507423616945743,0.2936635114911271,43793 -14392.712438821793,4.19120192527771,26430.20501279831,82101,0,26430.20501279831,0.9862112998962402,0.0540365986526012,0.2754228664505747,43793,40829.4153251648,0.995573878288269,0.0138689950108528,0.786614385369315,0.9869915843009948,0.0507423616945743,0.2936718267714238,43793 -14513.572434902191,4.234838008880615,26670.177414417267,82848,0,26670.177414417267,0.9862112998962402,0.0540365986526012,0.2755169929742068,43793,41190.311574697495,0.995610535144806,0.0137790357694029,0.7796165680656176,0.9869915843009948,0.0507423616945743,0.2937034401072733,43793 -14630.670375823976,4.278192520141602,26910.234219789505,83601,0,26910.234219789505,0.9862112998962402,0.0540366023778915,0.2753272473537633,43793,41547.52993321419,0.9955638647079468,0.0138187641277909,0.7790167630847286,0.9869915843009948,0.0507423616945743,0.2936536998079178,43793 -14749.948195457458,4.320634126663208,27150.300426721573,84348,0,27150.300426721573,0.9862112998962402,0.0540365986526012,0.2754699155733398,43793,41906.93664479256,0.9955844879150392,0.0138030629605054,0.774204131037632,0.9869915843009948,0.0507423616945743,0.2938256946907282,43793 -14871.777806282043,4.376424074172974,27390.44856762886,85093,0,27390.44856762886,0.9862112998962402,0.0540365986526012,0.2754004305972816,43793,42268.99037504196,0.9954951405525208,0.0139528112486004,0.7879491584598639,0.9869915843009948,0.0507423616945743,0.2936760948804247,43793 -14992.430229187012,4.419990539550781,27630.48721194268,85844,0,27630.48721194268,0.9862112998962402,0.0540365986526012,0.2753637605802473,43793,42629.74559020996,0.9956097602844238,0.0138254640623927,0.7908804322850795,0.9869915843009948,0.0507423616945743,0.2935794272536952,43793 -15116.532540082932,4.467123508453369,27870.723264932632,86583,0,27870.723264932632,0.9862112998962402,0.0540365986526012,0.2754111736117262,43793,42994.15589976311,0.9956318140029908,0.0137186404317617,0.78391687670689,0.9869915843009948,0.0507423616945743,0.2936513145469017,43793 -15232.091541528702,4.512559175491333,28110.6790099144,87324,0,28110.6790099144,0.9862112998962402,0.0540365986526012,0.2754694072339723,43793,43349.73703336716,0.99556964635849,0.0138719119131565,0.7821974859985805,0.9869915843009948,0.0507423616945743,0.2935722567401824,43793 -15355.256055355072,4.554625511169434,28350.94520950317,88072,0,28350.94520950317,0.9862112998962402,0.0540365986526012,0.2753737133918487,43793,43713.2302980423,0.9955666065216064,0.0137403607368469,0.7812936362956535,0.9869915843009948,0.0507423616945743,0.2936613903844776,43793 -15474.09222960472,4.601079940795898,28590.884345531464,88805,0,28590.884345531464,0.9862112998962402,0.0540365986526012,0.2754791782357467,43793,44072.07675909996,0.99554181098938,0.0139920338988304,0.7816320112847811,0.9869915843009948,0.0507423616945743,0.2935806132338083,43793 -15588.70144701004,4.645176410675049,28831.03605914116,89541,0,28831.03605914116,0.9862112998962402,0.0540365986526012,0.2755014562042579,43793,44426.90284395218,0.9955811500549316,0.0138107240200042,0.7874249815139477,0.9869915843009948,0.0507423616945743,0.2937410135718478,43793 -15707.728714704514,4.690348386764526,29071.14478611946,90280,0,29071.14478611946,0.9862112998962402,0.0540366023778915,0.275349481187374,43793,44786.104506492615,0.995635151863098,0.0136811081320047,0.7850135717763382,0.9869915843009948,0.0507423616945743,0.2936347523194236,43793 -15826.435528993608,4.734649658203125,29311.08949136734,91025,0,29311.08949136734,0.9862112998962402,0.0540365986526012,0.2753792621221552,43793,45144.8211247921,0.9955487251281738,0.0139289842918515,0.7848721563012246,0.9869915843009948,0.0507423616945743,0.2937334688183136,43793 -15946.73088979721,4.778702735900879,29551.084138154984,91770,0,29551.084138154984,0.9862112998962402,0.0540365986526012,0.2754203519807138,43793,45505.17563891411,0.9956004619598388,0.0137164033949375,0.7883310032943757,0.9869915843009948,0.0507423616945743,0.2937376259139235,43793 -16064.169033050535,4.824016094207764,29791.020255804066,92515,0,29791.020255804066,0.9862112998962402,0.0540365986526012,0.2754073981223309,43793,45862.61566567421,0.995555818080902,0.0139048527926206,0.7673218996392518,0.9869915843009948,0.0507423616945743,0.2936585100039495,43793 -16181.204716920853,4.867949724197388,30031.20615911484,93256,0,30031.20615911484,0.9862112998962402,0.0540365986526012,0.2754559623634967,43793,46219.90180850029,0.9955539107322692,0.0138565571978688,0.7895852409581146,0.9869915843009948,0.0507423616945743,0.2937831941628872,43793 -16298.102165699003,4.91324520111084,30271.378150701523,93999,0,30271.378150701523,0.9862112998962402,0.0540366023778915,0.2754090103100722,43793,46577.03712940216,0.9955896139144896,0.0138924093917012,0.7888154545262414,0.9869915843009948,0.0507423616945743,0.2936715268549584,43793 -16410.420204401016,4.958989381790161,30511.55004000664,94744,0,30511.55004000664,0.9862112998962402,0.0540365986526012,0.2753408036677083,43793,46929.59366226196,0.9956102967262268,0.0137448869645595,0.7833450442984059,0.9869915843009948,0.0507423616945743,0.2937749887670415,43793 -16528.11999154091,5.0040812492370605,30751.71346235276,95490,0,30751.71346235276,0.9862112998962402,0.0540365986526012,0.2754687286827509,43793,47287.522605896,0.99561607837677,0.0136659052222967,0.7821546171064298,0.9869915843009948,0.0507423616945743,0.2937425339183844,43793 -16648.886462688446,5.048603057861328,30991.67889022827,96244,0,30991.67889022827,0.9862112998962402,0.0540365986526012,0.2754213828360592,43793,47648.32022738457,0.9955379366874696,0.0139642404392361,0.777552866971797,0.9869915843009948,0.0507423616945743,0.2936285897507651,43793 -16765.890092611313,5.093991041183472,31231.81813645363,96990,0,31231.81813645363,0.9862112998962402,0.0540365986526012,0.2754537971353992,43793,48005.52919220925,0.995514750480652,0.0139308404177427,0.7831007753855972,0.9869915843009948,0.0507423616945743,0.2936723064057466,43793 -16882.567397117615,5.139058351516724,31471.98353600502,97738,0,31471.98353600502,0.9862112998962402,0.0540365986526012,0.2754499157951282,43793,48362.4382288456,0.995638906955719,0.0136951776221394,0.7867676729768284,0.9869915843009948,0.0507423616945743,0.2937441843168288,43793 -16996.652015209198,5.186892509460449,31712.06727313996,98487,0,31712.06727313996,0.9862112998962402,0.0540365986526012,0.2754063960084381,43793,48716.67514181137,0.995592474937439,0.0138282580301165,0.7842024551917022,0.9869915843009948,0.0507423616945743,0.2937038315424408,43793 -17112.835739850998,5.234352111816406,31952.20751523972,99229,0,31952.20751523972,0.9862112998962402,0.0540365986526012,0.2754442037892525,43793,49073.06725502014,0.9955712556838988,0.0138505389913916,0.7811136130391136,0.9869915843009948,0.0507423616945743,0.2937046239389677,43793 -17223.295473337173,5.279926300048828,32192.220431804657,99975,0,32192.220431804657,0.9862112998962402,0.0540365986526012,0.2753772766361173,43793,49423.6063015461,0.9955838918685912,0.0137693621218204,0.7773563293839678,0.9869915843009948,0.0507423616945743,0.2937261884863826,43793 -17341.50161242485,5.326424598693848,32432.44850897789,100725,0,32432.44850897789,0.9862112998962402,0.0540365986526012,0.2754337862658121,43793,49782.10767865181,0.9955572485923768,0.0138950860127806,0.7741282285367919,0.9869915843009948,0.0507423616945743,0.2936773969363314,43793 -17455.88349723816,5.3753931522369385,32672.53692626953,101463,0,32672.53692626953,0.9862112998962402,0.0540365986526012,0.2754924725969873,43793,50136.65001726151,0.9955657720565796,0.0138729671016335,0.7906566108249474,0.9869915843009948,0.0507423616945743,0.2936484156831344,43793 -17571.774189710617,5.421801805496216,32912.768273591995,102209,0,32912.768273591995,0.9862112998962402,0.0540365986526012,0.2754421511920781,43793,50492.8391327858,0.9956038594245912,0.0138377929106354,0.7844097048287434,0.9869915843009948,0.0507423616945743,0.2935776915827144,43793 -17684.104935884476,5.467526435852051,33152.73183774948,102950,0,33152.73183774948,0.9862112998962402,0.0540365986526012,0.2753811304726336,43793,50845.19981837273,0.9955734014511108,0.0137864360585808,0.7839354967984052,0.9869915843009948,0.0507423616945743,0.2937865842712753,43793 -17798.916505098343,5.51433539390564,33392.72674036026,103691,0,33392.72674036026,0.9862112998962402,0.0540365986526012,0.2754707876834113,43793,51200.07393550873,0.9955546855926514,0.0138387288898229,0.7864173109327909,0.9869915843009948,0.0507423616945743,0.2938074174387261,43793 -17914.172067642212,5.561546802520752,33632.83125758171,104443,0,33632.83125758171,0.9862112998962402,0.0540365986526012,0.275457568870989,43793,51555.50200009346,0.995599091053009,0.0137803135439753,0.7693775071494818,0.9869915843009948,0.0507423616945743,0.2937300487911993,43793 -18026.642154455185,5.607269763946533,33872.89906716347,105195,0,33872.89906716347,0.9862112998962402,0.0540365986526012,0.2754226879137796,43793,51908.106872558594,0.9955223798751832,0.0139496717602014,0.7825516269931497,0.9869915843009948,0.0507423616945743,0.2937538129987925,43793 -18142.262134552,5.656062602996826,34112.948595285416,105942,0,34112.948595285416,0.9862112998962402,0.0540365986526012,0.2753237852488251,43793,52263.84589409828,0.9956201314926147,0.0137748792767524,0.7936656836107864,0.9869915843009948,0.0507423616945743,0.2937028913704122,43793 -18252.405061244965,5.704350709915161,34353.06476831436,106687,0,34353.06476831436,0.9862112998962402,0.0540365986526012,0.2753481876534774,43793,52614.17381596565,0.9955931305885316,0.0138001209124922,0.7827427341975708,0.9869915843009948,0.0507423616945743,0.2938741571133589,43793 -18370.54836511612,5.751443147659302,34593.14384675026,107435,0,34593.14384675026,0.9862112998962402,0.0540365986526012,0.2754138806616004,43793,52972.46685504913,0.9955994486808776,0.01373241096735,0.7809417040681077,0.9869915843009948,0.0507423616945743,0.2938712298130228,43793 -18480.653708934784,5.805322170257568,34833.32964348793,108172,0,34833.32964348793,0.9862112998962402,0.0540365986526012,0.2753967137243359,43793,53322.83411717415,0.9955727458000184,0.0138137601315975,0.7745374913235604,0.9869915843009948,0.0507423616945743,0.2936441576869056,43793 -18590.412984609604,5.85346269607544,35073.34981870651,108916,0,35073.34981870651,0.9862112998962402,0.0540365986526012,0.2754546663138822,43793,53672.68225312233,0.9955095052719116,0.0140068819746375,0.7848383461143658,0.9869915843009948,0.0507423616945743,0.2936524279192357,43793 -18700.22829508781,5.906094074249268,35313.440516233444,109670,0,35313.440516233444,0.9862112998962402,0.0540365986526012,0.275395700887565,43793,54022.661536455154,0.9956090450286864,0.0138024566695094,0.7855790926404516,0.9869915843009948,0.0507423616945743,0.2937035017352725,43793 -18813.123986005783,5.954303979873657,35553.47294831276,110416,0,35553.47294831276,0.9862112998962402,0.0540365986526012,0.2755301847475866,43793,54375.65848469734,0.9955815076828004,0.0137841440737247,0.7876969888879968,0.9869915843009948,0.0507423616945743,0.293807377697929,43793 -18932.277148246765,6.003270387649536,35793.5690510273,111167,0,35793.5690510273,0.9862112998962402,0.0540365986526012,0.2755673129599433,43793,54734.97816634178,0.9955974817276,0.0138169005513191,0.7778713694892079,0.9869915843009948,0.0507423616945743,0.2936051051603034,43793 -19042.534830093384,6.051767826080322,36033.5463924408,111912,0,36033.5463924408,0.9862112998962402,0.0540365986526012,0.2753573878831423,43793,55085.2830324173,0.9955654144287108,0.0137704741209745,0.7812751085022197,0.9869915843009948,0.0507423616945743,0.2936191037192044,43793 -19151.26888847351,6.1010048389434814,36273.49717617035,112651,0,36273.49717617035,0.9862112998962402,0.0540365986526012,0.275431472781563,43793,55434.03934621811,0.9955233335494996,0.0140600008890032,0.7775877571357277,0.9869915843009948,0.0507423616945743,0.2937597007290897,43793 -19264.81359577179,6.149595975875855,36513.54458999634,113393,0,36513.54458999634,0.9862112998962402,0.0540365986526012,0.2753357528733862,43793,55787.701063632965,0.9956125020980836,0.0137028321623802,0.7856105334997641,0.9869915843009948,0.0507423616945743,0.2936817305110332,43793 -19377.594454288483,6.197968244552612,36753.55989098549,114130,0,36753.55989098549,0.9862112998962402,0.0540365986526012,0.2754157196970703,43793,56140.56715130806,0.9956178069114684,0.0137725817039608,0.7901904740236783,0.9869915843009948,0.0507423616945743,0.2935913388519824,43793 -19493.173021554947,6.24558687210083,36993.52682876587,114866,0,36993.52682876587,0.9862112998962402,0.0540365986526012,0.2754269244754509,43793,56496.18143749237,0.9955707788467408,0.0138441119343042,0.7832189831767615,0.9869915843009948,0.0507423616945743,0.2935787289523329,43793 -19606.26589083672,6.298041105270386,37233.61578011513,115612,0,37233.61578011513,0.9862112998962402,0.0540365986526012,0.2754521002138382,43793,56849.43639016152,0.9955816268920898,0.0137649774551391,0.7833042739703718,0.9869915843009948,0.0507423616945743,0.293672361685288,43793 -19716.618169546127,6.347314119338989,37473.81154060364,116362,0,37473.81154060364,0.9862112998962402,0.0540365986526012,0.2754720384718899,43793,57200.054525375366,0.9955711960792542,0.0138259436935186,0.7763133112655201,0.9869915843009948,0.0507423616945743,0.2937281957738649,43793 -19829.223463773727,6.399012088775635,37713.840057611465,117102,0,37713.840057611465,0.9862112998962402,0.0540365986526012,0.275412886815387,43793,57552.76069331169,0.9955509305000304,0.0138606643304228,0.7871036951453844,0.9869915843009948,0.0507423616945743,0.2936363325723397,43793 -19940.994478702545,6.45639443397522,37954.070449113846,117842,0,37954.070449113846,0.9862112998962402,0.0540365986526012,0.2753685032524396,43793,57904.84043097496,0.9955703616142272,0.0139037342742085,0.7890036277593242,0.9869915843009948,0.0507423616945743,0.2936448231073847,43793 -20056.31689786911,6.509163856506348,38194.10259127617,118584,0,38194.10259127617,0.9862112998962402,0.0540365986526012,0.2754307721802707,43793,58260.26805996895,0.995596408843994,0.0138527592644095,0.7802696371874123,0.9869915843009948,0.0507423616945743,0.2936510843504485,43793 -20168.87475633621,6.5590479373931885,38434.34361219406,119331,0,38434.34361219406,0.9862112998962402,0.0540366023778915,0.2755224102040603,43793,58613.13779401779,0.995570421218872,0.0137610621750354,0.7811442251892875,0.9869915843009948,0.0507423616945743,0.2935784888838916,43793 -20279.42533969879,6.609145879745483,38674.53906083107,120074,0,38674.53906083107,0.9862112998962402,0.0540365986526012,0.2754766615316152,43793,58963.95445632935,0.9955947399139404,0.0137996897101402,0.7765940262969431,0.9869915843009948,0.0507423616945743,0.293746364209449,43793 -20396.31151819229,6.660247087478638,38914.5114774704,120812,0,38914.5114774704,0.9862112998962402,0.0540365986526012,0.2755423162549807,43793,59320.885125637054,0.9955591559410096,0.0138795301318168,0.7845728335870632,0.9869915843009948,0.0507423616945743,0.2937915479617342,43793 -20506.640821695328,6.712068557739258,39154.60091519356,121554,0,39154.60091519356,0.9862112998962402,0.0540365986526012,0.2755032971678063,43793,59671.376187086105,0.995585799217224,0.0137846209108829,0.7895918756519975,0.9869915843009948,0.0507423616945743,0.2937423974114145,43793 -20620.923574447632,6.762216567993164,39394.5952064991,122302,0,39394.5952064991,0.9862112998962402,0.0540365986526012,0.2754014452639356,43793,60025.72388958931,0.995614528656006,0.0137911001220345,0.7868911213978563,0.9869915843009948,0.0507423616945743,0.2936706497959538,43793 -20732.93293976784,6.812769651412964,39634.70795702934,123046,0,39634.70795702934,0.9862112998962402,0.0540365986526012,0.2755150682884237,43793,60377.917174339294,0.995576560497284,0.013819707557559,0.7784860680421057,0.9869915843009948,0.0507423616945743,0.2935860849708683,43793 -20846.8649699688,6.862906455993652,39874.68226027489,123787,0,39874.68226027489,0.9862112998962402,0.0540365986526012,0.2754583601199453,43793,60731.89415550232,0.9955735206604004,0.0138522526249289,0.7843652857748288,0.9869915843009948,0.0507423616945743,0.2936752539884849,43793 -20959.648869991302,6.913788318634033,40114.67047047615,124526,0,40114.67047047615,0.9862112998962402,0.0540365986526012,0.2754465652294123,43793,61084.737924575806,0.995505154132843,0.0140109313651919,0.7773680917507209,0.9869915843009948,0.0507423616945743,0.2937885779266836,43793 -21073.94648528099,6.963781595230103,40354.857283592224,125262,0,40354.857283592224,0.9862112998962402,0.0540365986526012,0.2753858426724248,43793,61439.29281306267,0.9956053495407104,0.0137420119717717,0.7890075856031153,0.9869915843009948,0.0507423616945743,0.2936749697469681,43793 -21187.55341076851,7.017248630523682,40594.93407726288,126001,0,40594.93407726288,0.9862112998962402,0.0540365986526012,0.2756016566885009,43793,61793.05135250092,0.9956063628196716,0.0138191375881433,0.7756586723562215,0.9869915843009948,0.0507423616945743,0.2937366296262176,43793 -21302.12158560753,7.071530342102051,40834.8824300766,126746,0,40834.8824300766,0.9862112998962402,0.0540365986526012,0.2755557447850956,43793,62147.64252591133,0.9955751299858092,0.0137756494805216,0.7885662798229311,0.9869915843009948,0.0507423616945743,0.2936974288321023,43793 -21411.546427965164,7.12211012840271,41075.1099421978,127499,0,41075.1099421978,0.9862112998962402,0.0540365986526012,0.2754435153360622,43793,62497.366092681885,0.9956125020980836,0.0137626240029931,0.7903543820948858,0.9869915843009948,0.0507423616945743,0.2935955037536901,43793 -21525.99359869957,7.175124883651733,41315.15427827835,128248,0,41315.15427827835,0.9862112998962402,0.0540365986526012,0.2754312230509766,43793,62851.930923223495,0.9954938292503356,0.0139695256948471,0.771295520144599,0.9869915843009948,0.0507423616945743,0.2937692486438858,43793 -21633.69455242157,7.226402759552002,41555.35489320755,128991,0,41555.35489320755,0.9862112998962402,0.0540365986526012,0.2753627787110927,43793,63199.904326200485,0.995583951473236,0.0138546442613005,0.7902642742906548,0.9869915843009948,0.0507423616945743,0.2937646572325736,43793 -21741.826742887497,7.277673959732056,41795.50968170166,129733,0,41795.50968170166,0.9862112998962402,0.0540366023778915,0.2755378317641422,43793,63548.26286840439,0.9956200122833252,0.0136795286089181,0.7835308102014059,0.9869915843009948,0.0507423616945743,0.2935854843682218,43793 -21851.141013383865,7.32834792137146,42035.57419133186,130478,0,42035.57419133186,0.9862112998962402,0.0540365986526012,0.2754409507264886,43793,63897.71296691895,0.995602548122406,0.0138369193300604,0.7875081898725586,0.9869915843009948,0.0507423616945743,0.2936325389844466,43793 -21960.82028913498,7.379400491714477,42275.503985881805,131227,0,42275.503985881805,0.9862112998962402,0.0540365986526012,0.2752876374633426,43793,64247.39361643791,0.9955888390541076,0.0137723758816719,0.7732472724957443,0.9869915843009948,0.0507423616945743,0.2935959974947021,43793 -22073.290736675262,7.432871103286743,42515.66183280945,131969,0,42515.66183280945,0.9862112998962402,0.0540365986526012,0.2754581241452384,43793,64600.09630036354,0.9955713152885436,0.0138152642175555,0.7818256717297449,0.9869915843009948,0.0507423616945743,0.2936461241032045,43793 -22183.3386054039,7.485567808151245,42755.7799987793,132722,0,42755.7799987793,0.9862112998962402,0.0540365986526012,0.2754239669424201,43793,64950.33537364006,0.9955384135246276,0.0139381103217601,0.7809306588978733,0.9869915843009948,0.0507423616945743,0.2938063416664886,43793 -22291.031487464905,7.536278247833252,42995.96339964867,133475,0,42995.96339964867,0.9862112998962402,0.0540365986526012,0.2754217274149027,43793,65298.283219099045,0.9956251978874208,0.0136889275163412,0.7866224635436307,0.9869915843009948,0.0507423616945743,0.2939696343293295,43793 -22398.84451198578,7.58860969543457,43236.10581612587,134223,0,43236.10581612587,0.9862112998962402,0.0540365986526012,0.275459098161773,43793,65646.3113667965,0.9955657124519348,0.0139132514595985,0.785629117844441,0.9869915843009948,0.0507423616945743,0.2937110096408846,43793 -22513.485692977905,7.652945756912231,43476.32939147949,134960,0,43476.32939147949,0.9862112998962402,0.0540365986526012,0.2754662314445237,43793,66001.26565551758,0.995591163635254,0.0138048967346549,0.7821396458418919,0.9869915843009948,0.0507423616945743,0.2935877750821879,43793 -22623.08250141144,7.711076021194458,43716.24765229225,135704,0,43716.24765229225,0.9862112998962402,0.0540365986526012,0.2755756755415497,43793,66350.86281728745,0.9955680966377258,0.0138288754969835,0.7725323100158961,0.9869915843009948,0.0507423616945743,0.2938210133629313,43793 -22734.455174207687,7.762632131576538,43956.23304724693,136453,0,43956.23304724693,0.9862112998962402,0.0540365986526012,0.2753951118399199,43793,66702.2934141159,0.9955561757087708,0.0138661824166774,0.7835748448083539,0.9869915843009948,0.0507423616945743,0.2936474347516125,43793 -22844.751775741577,7.815975904464722,44196.325745821,137202,0,44196.325745821,0.9862112998962402,0.0540365986526012,0.2754957400432071,43793,67052.75707435608,0.9955800175666808,0.0137928072363138,0.7885792245747008,0.9869915843009948,0.0507423616945743,0.2936389094944352,43793 -22954.84996652603,7.869782209396362,44436.34388566017,137945,0,44436.34388566017,0.9862112998962402,0.0540365986526012,0.2754197051375271,43793,67402.94797158241,0.9956175088882446,0.0137908682227134,0.7859297351781429,0.9869915843009948,0.0507423616945743,0.2936445535571279,43793 -23065.253385066982,7.923959970474243,44676.512374162674,138693,0,44676.512374162674,0.9862112998962402,0.0540365986526012,0.2755277708514029,43793,67753.59470915794,0.9955361485481262,0.0139456894248723,0.7821278781897788,0.9869915843009948,0.0507423616945743,0.2938877070849096,43793 -23174.9851911068,7.976691007614136,44916.49689507485,139441,0,44916.49689507485,0.9862112998962402,0.0540365986526012,0.2754072478578473,43793,68103.3846859932,0.9956095814704896,0.0136872818693518,0.7837841741459687,0.9869915843009948,0.0507423616945743,0.2937139301409759,43793 -23288.182639360428,8.0308096408844,45156.51519060135,140189,0,45156.51519060135,0.9862112998962402,0.0540365986526012,0.2753595513198698,43793,68456.67514777184,0.9955382347106934,0.0139403715729713,0.7720245902031437,0.9869915843009948,0.0507423616945743,0.2935594829764194,43793 -23399.51488018036,8.09301209449768,45396.50525188446,140920,0,45396.50525188446,0.9862112998962402,0.0540365986526012,0.275453484211465,43793,68808.08483695984,0.9955587983131408,0.0138416392728686,0.7846010270148847,0.9869915843009948,0.0507423616945743,0.2938899450060268,43793 -23509.45137310028,8.151569604873657,45636.494839668274,141660,0,45636.494839668274,0.9862112998962402,0.0540365986526012,0.2754512949348238,43793,69158.09258151054,0.9956005215644836,0.0138408299535512,0.7835544456428272,0.9869915843009948,0.0507423616945743,0.2935728189413232,43793 -23616.3204202652,8.205451011657715,45876.57214689255,142412,0,45876.57214689255,0.9862112998962402,0.0540365986526012,0.2754350355381713,43793,69505.11353802681,0.9955870509147644,0.0138025749474763,0.7837377125709595,0.9869915843009948,0.0507423616945743,0.2938052738879229,43793 -23722.715693950653,8.261404752731323,46116.80425333977,143162,0,46116.80425333977,0.9862112998962402,0.0540366023778915,0.2753971686265937,43793,69851.81775379181,0.9956026077270508,0.0137334009632468,0.7859749685301147,0.9869915843009948,0.0507423616945743,0.2939219873827208,43793 -23831.21180343628,8.31418228149414,46356.86944055557,143911,0,46356.86944055557,0.9862112998962402,0.0540365986526012,0.2754010988589716,43793,70200.4524781704,0.9955673217773438,0.0138666266575455,0.7774714317441656,0.9869915843009948,0.0507423616945743,0.2937292203462356,43793 -23943.13414955139,8.367965698242188,46596.89909791946,144659,0,46596.89909791946,0.9862112998962402,0.0540365986526012,0.2754712819334631,43793,70552.47891068459,0.995541512966156,0.0138751147314906,0.7791450319607318,0.9869915843009948,0.0507423616945743,0.2937038021491862,43793 -24049.355276346207,8.423321008682251,46837.16341853142,145404,0,46837.16341853142,0.9862112998962402,0.0540365986526012,0.2754216593004718,43793,70899.0400402546,0.9955787062644958,0.0138519490137696,0.7859469751789931,0.9869915843009948,0.0507423616945743,0.2935924268937204,43793 -24157.030386209488,8.47667145729065,47077.318110466,146157,0,47077.318110466,0.9862112998962402,0.0540365986526012,0.2754313374050625,43793,71246.9433298111,0.9956125020980836,0.0137772569432854,0.7824314669110136,0.9869915843009948,0.0507423616945743,0.2937127948568798,43793 -24264.45087337494,8.5319983959198,47317.36692094803,146908,0,47317.36692094803,0.9862112998962402,0.0540365986526012,0.2754988653620888,43793,71594.48837161064,0.9955742955207824,0.013849351555109,0.7841609706569808,0.9869915843009948,0.0507423616945743,0.2936861766004716,43793 -24376.361690044403,8.58792233467102,47557.4769847393,147650,0,47557.4769847393,0.9862112998962402,0.0540365986526012,0.2754986839728315,43793,71946.58534765244,0.9956088066101074,0.0137588586658239,0.7803652558375634,0.9869915843009948,0.0507423616945743,0.2936537150876402,43793 -24482.029361486435,8.643598794937134,47797.62417554855,148396,0,47797.62417554855,0.9862112998962402,0.0540365986526012,0.2754836519611033,43793,72292.47641682625,0.9955193400382996,0.0139368381351232,0.7758290701438132,0.9869915843009948,0.0507423616945743,0.2937131747472333,43793 -24590.29389500618,8.698626518249512,48037.61570096016,149150,0,48037.61570096016,0.9862112998962402,0.0540365986526012,0.2754176262087233,43793,72640.80835223198,0.9955896735191344,0.0137344496324658,0.7878894749972761,0.9869915843009948,0.0507423616945743,0.2935785880532318,43793 -24696.59809613228,8.752968549728394,48277.7650783062,149904,0,48277.7650783062,0.9862112998962402,0.0540366023778915,0.2754609112657465,43793,72987.3367960453,0.9956307411193848,0.0137735474854707,0.7877537249143864,0.9869915843009948,0.0507423616945743,0.2937076776102705,43793 -24808.22857093811,8.807228088378906,48517.886543273926,150660,0,48517.886543273926,0.9862112998962402,0.0540365986526012,0.275389384030019,43793,73339.16317415237,0.9955337047576904,0.0139522580429911,0.7795977649702918,0.9869915843009948,0.0507423616945743,0.2937253985372344,43793 -24915.70268535614,8.862153768539429,48757.82743740082,151405,0,48757.82743740082,0.9862112998962402,0.0540365986526012,0.2754052718551408,43793,73686.6535577774,0.995615005493164,0.0136622358113527,0.7890784196348725,0.9869915843009948,0.0507423616945743,0.2936964955645312,43793 -25026.75702905655,9.230566263198853,48997.56856369972,152164,0,48997.56856369972,0.9862112998962402,0.0540366023778915,0.2754099633235572,43793,74037.83790445328,0.995550274848938,0.0139026893302798,0.7697975941964679,0.9869915843009948,0.0507423616945743,0.2937655585787352,43793 -25137.68102812767,9.286834478378296,49237.56392478943,152916,0,49237.56392478943,0.9862112998962402,0.0540365986526012,0.2754459852629338,43793,74388.833922863,0.995542049407959,0.0139244012534618,0.7826677908654057,0.9869915843009948,0.0507423616945743,0.2936824436732723,43793 -25245.425184965134,9.341866731643677,49477.695827007294,153672,0,49477.695827007294,0.9862112998962402,0.0540365986526012,0.2753901263563581,43793,74736.78531956673,0.9956010580062866,0.0138049824163317,0.7902975906992156,0.9869915843009948,0.0507423616945743,0.2937089946570169,43793 -25348.473020792007,9.397128582000732,49717.68036675453,154432,0,49717.68036675453,0.9862112998962402,0.0540366023778915,0.2753254195522341,43793,75079.89323925972,0.9955745935440063,0.0138638485223054,0.7789048942518366,0.9869915843009948,0.0507423616945743,0.2936350856274303,43793 -25459.034336566925,9.452574968338013,49957.67692351341,155180,0,49957.67692351341,0.9862112998962402,0.0540365986526012,0.2754519692100451,43793,75430.52669262886,0.995599091053009,0.013690123334527,0.7854078741192276,0.9869915843009948,0.0507423616945743,0.2937296730739804,43793 -25571.970460414886,9.509182453155518,50197.66117596626,155931,0,50197.66117596626,0.9862112998962402,0.0540365986526012,0.2754514664735991,43793,75783.52428460121,0.9955906271934508,0.0138539336621761,0.773038709091723,0.9869915843009948,0.0507423616945743,0.2937262505472416,43793 -25678.198181152344,9.566497564315796,50437.83835029602,156677,0,50437.83835029602,0.9862112998962402,0.0540365986526012,0.2753889199464259,43793,76130.00862145424,0.9955081939697266,0.0139719014987349,0.777331023418733,0.9869915843009948,0.0507423616945743,0.2936683234079764,43793 -25783.583258152008,9.623881101608276,50677.798704624176,157433,0,50677.798704624176,0.9862112998962402,0.0540365986526012,0.2753973492261922,43793,76475.43181467056,0.9956135749816896,0.0137933911755681,0.7892423439032692,0.9869915843009948,0.0507423616945743,0.293717954748,43793 -25893.5378510952,9.682755708694458,50917.96872997284,158182,0,50917.96872997284,0.9862112998962402,0.0540365986526012,0.275407186043599,43793,76825.63581681252,0.995617926120758,0.0137408478185534,0.7845530195903894,0.9869915843009948,0.0507423616945743,0.2936882202342486,43793 -26002.302206754684,9.739495277404783,51158.21628952026,158929,0,51158.21628952026,0.9862112998962402,0.0540365986526012,0.2754564632069592,43793,77174.72514629364,0.9955463409423828,0.0138870431110262,0.7822689544969321,0.9869915843009948,0.0507423616945743,0.2936670853993187,43793 -26113.859795093536,9.795924425125122,51398.28981542587,159684,0,51398.28981542587,0.9862112998962402,0.0540365986526012,0.2754188429540947,43793,77526.43344187737,0.9955745339393616,0.0138083333149552,0.7875565170609959,0.9869915843009948,0.0507423616945743,0.2936776762052396,43793 -26218.408900260925,9.85430383682251,51638.4400715828,160424,0,51638.4400715828,0.9862112998962402,0.0540365986526012,0.2754478020806966,43793,77871.21443009377,0.9955351948738098,0.0138880098238587,0.774616761796803,0.9869915843009948,0.0507423616945743,0.2936985274524972,43793 -26330.62533640861,9.911830186843872,51878.58922076225,161174,0,51878.58922076225,0.9862112998962402,0.0540365986526012,0.2754757861197963,43793,78223.65838193893,0.9956015944480896,0.0137695474550127,0.7862721595759705,0.9869915843009948,0.0507423616945743,0.2936892479236363,43793 -26438.3941590786,9.969873428344728,52118.57286572456,161916,0,52118.57286572456,0.9862112998962402,0.0540365986526012,0.2754001764751488,43793,78571.4911108017,0.9956266283988952,0.0137533387169241,0.7849122736140868,0.9869915843009948,0.0507423616945743,0.2937739629217785,43793 -26549.352571725845,10.02830934524536,52358.634991168976,162659,0,52358.634991168976,0.9862112998962402,0.0540365986526012,0.2755013318092121,43793,78922.59042525291,0.9955690503120422,0.0138960359618067,0.7824925082203897,0.9869915843009948,0.0507423616945743,0.2936022995657795,43793 -26657.94908690453,10.08947467803955,52598.57873725891,163411,0,52598.57873725891,0.9862112998962402,0.0540365986526012,0.2753319317863094,43793,79271.2127199173,0.9955653548240662,0.0137989101931452,0.7842163827697879,0.9869915843009948,0.0507423616945743,0.2935796978506215,43793 -26762.67085957527,10.14896845817566,52838.597648859024,164163,0,52838.597648859024,0.9862112998962402,0.0540365986526012,0.2754401584745454,43793,79616.03345918655,0.9955796599388124,0.0137862637639045,0.7746348160789243,0.9869915843009948,0.0507423616945743,0.2935902334093814,43793 -26866.38114452362,10.207342863082886,53078.761633872986,164911,0,53078.761633872986,0.9862112998962402,0.0540365986526012,0.2754810781515628,43793,79959.98651099205,0.9955300688743592,0.0139853237196803,0.7824321755727822,0.9869915843009948,0.0507423616945743,0.2937889285679689,43793 -26969.93276333809,10.265093803405762,53318.756410360336,165668,0,53318.756410360336,0.9862112998962402,0.0540366023778915,0.275376862591533,43793,80303.61129760742,0.9955757260322572,0.0138637935742735,0.7839171197252315,0.9869915843009948,0.0507423616945743,0.2936526321692164,43793 -27078.699613571167,10.322208404541016,53558.69554495812,166423,0,53558.69554495812,0.9862112998962402,0.0540365986526012,0.2753630251818127,43793,80652.39430904388,0.995619535446167,0.0137594006955623,0.7878243548892319,0.9869915843009948,0.0507423616945743,0.2937584227051379,43793 -27190.793970823288,10.380990743637083,53798.87780714035,167157,0,53798.87780714035,0.9862112998962402,0.0540365986526012,0.2754087925659607,43793,81004.75309371948,0.9955827593803406,0.0137697160243988,0.7863487925670267,0.9869915843009948,0.0507423616945743,0.2939032390369472,43793 -27299.30680155754,10.439002513885498,54039.08760070801,167909,0,54039.08760070801,0.9862112998962402,0.0540365986526012,0.275392875397748,43793,81353.55399942398,0.9955384135246276,0.0138554880395531,0.7785570710333689,0.9869915843009948,0.0507423616945743,0.2936441465812117,43793 -27407.28624010086,10.49805474281311,54279.04728722572,168650,0,54279.04728722572,0.9862112998962402,0.0540365986526012,0.275552159012681,43793,81701.57476568222,0.99558025598526,0.0138541841879487,0.7801272234892287,0.9869915843009948,0.0507423616945743,0.2937228481334556,43793 -27512.24540758133,10.55790948867798,54518.96765470505,169403,0,54518.96765470505,0.9862112998962402,0.0540365986526012,0.2755014735067793,43793,82046.5347559452,0.9955908060073853,0.0137906232848763,0.7917480193232054,0.9869915843009948,0.0507423616945743,0.293798366684791,43793 -27618.15584230423,10.618480205535889,54759.18366193771,170162,0,54759.18366193771,0.9862112998962402,0.0540365986526012,0.275371892760308,43793,82392.74225926399,0.9956451058387756,0.0137569643557071,0.7842756858193123,0.9869915843009948,0.0507423616945743,0.2935977855388695,43793 -27730.01386666298,10.676766395568848,54999.39244008064,170921,0,54999.39244008064,0.9862112998962402,0.0540366023778915,0.2755011356766365,43793,82744.88785123825,0.9955541491508484,0.0138856377452611,0.7790105273746023,0.9869915843009948,0.0507423616945743,0.2936214992287673,43793 -27830.2014875412,10.736129522323608,55239.529722452164,171674,0,55239.529722452164,0.9862112998962402,0.05403659865260124,0.27539845493352894,43793,83085.2924580574,0.9955615401268005,0.013751037418842316,0.7800829913217441,0.9869915843009949,0.050742361694574356,0.2936228202985142,43793 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index 20dfdb3dc..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1956 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.9044524,0.76413035,,,,,,,,,,,,,,,,, -1,,,0.4663219749927521,0.7589666843414307,0.0246376417424939,0.4651613235473633,0.755351185798645,0.0288803148676715,43793.0,0.4661603569984436,0.7539003491401672,0.0295337511283684,43793.0,20.478039264678955,557.4664070606232,20.478039264678955,536.988322019577,0.0,0.0 -100,0.31746224,0.28501502,,,,,,,,,,,,,,,,, -200,0.10171896,0.115627535,,,,,,,,,,,,,,,,, -300,0.033503525,0.07099892,,,,,,,,,,,,,,,,, -400,0.017030556,0.062060278,,,,,,,,,,,,,,,,, -500,0.01655227,0.065490015,,,,,,,,,,,,,,,,, -600,0.011499421,0.046450816,,,,,,,,,,,,,,,,, -700,0.048181113,0.05781099,,,,,,,,,,,,,,,,, -762,,,0.9867117404937744,0.0512692332267761,0.0549762184250856,0.9842259287834167,0.0606738105416297,0.0549251718406611,43793.0,0.9832456707954408,0.0639590546488761,0.0563886604872141,43793.0,260.4673655033112,913.1106786727904,260.4673655033112,652.5934982299805,0.0295979976654052,0.0 -800,0.032954443,0.05619415,,,,,,,,,,,,,,,,, -900,0.012765743,0.05334847,,,,,,,,,,,,,,,,, -1000,0.047234662,0.052543644,,,,,,,,,,,,,,,,, -1100,0.032284837,0.038926225,,,,,,,,,,,,,,,,, -1200,0.022336695,0.045801546,,,,,,,,,,,,,,,,, -1300,0.020581571,0.04635021,,,,,,,,,,,,,,,,, -1400,0.01561556,0.05166341,,,,,,,,,,,,,,,,, -1500,0.021452336,0.049216244,,,,,,,,,,,,,,,,, -1524,,,0.9874638319015504,0.0462587960064411,0.1052157271216831,0.9847134351730348,0.055482342839241,0.1035546290814991,43793.0,0.983719527721405,0.0586970448493957,0.1037015853295259,43793.0,500.6432864665985,1274.9377574920654,500.6432864665985,774.1975507736206,0.0569667816162109,0.0 -1600,0.023426518,0.05324604,,,,,,,,,,,,,,,,, -1700,0.018855251,0.044267934,,,,,,,,,,,,,,,,, -1800,0.024631873,0.048922937,,,,,,,,,,,,,,,,, -1900,0.019481676,0.042914957,,,,,,,,,,,,,,,,, -2000,0.01546517,0.047382034,,,,,,,,,,,,,,,,, -2100,0.015742412,0.04148964,,,,,,,,,,,,,,,,, -2200,0.019170707,0.045059502,,,,,,,,,,,,,,,,, -2276,,,0.9878709316253662,0.0440672747790813,0.135036552981234,0.985028862953186,0.0528953969478607,0.1320009243871645,43793.0,0.9840750098228456,0.0557515323162078,0.1282670616537205,43793.0,740.6774151325226,1635.6136515140531,740.6774151325226,894.7939918041229,0.0827620029449462,0.0 -2300,0.01196625,0.0416666,,,,,,,,,,,,,,,,, -2400,0.010211818,0.039977983,,,,,,,,,,,,,,,,, -2500,0.011375851,0.042715937,,,,,,,,,,,,,,,,, -2600,0.014976674,0.0468786,,,,,,,,,,,,,,,,, -2700,0.015607724,0.039282285,,,,,,,,,,,,,,,,, -2800,0.010811957,0.041928116,,,,,,,,,,,,,,,,, -2900,0.011977563,0.041615672,,,,,,,,,,,,,,,,, -3000,0.017376874,0.042702578,,,,,,,,,,,,,,,,, -3038,,,0.9881895780563354,0.0415606386959552,0.1676026993212543,0.9853706955909728,0.050984688103199,0.1571880636251882,43793.0,0.9844068884849548,0.0539854243397712,0.154043207003535,43793.0,980.8679234981536,2000.177325725556,980.8679234981536,1019.1199588775636,0.110504150390625,0.0 -3100,0.018551512,0.04488015,,,,,,,,,,,,,,,,, -3200,0.014274404,0.045910276,,,,,,,,,,,,,,,,, -3300,0.017604701,0.04185045,,,,,,,,,,,,,,,,, -3400,0.0127207,0.043645963,,,,,,,,,,,,,,,,, -3500,0.016742803,0.0468149,,,,,,,,,,,,,,,,, -3600,0.011230965,0.04100452,,,,,,,,,,,,,,,,, -3700,0.011545218,0.04163957,,,,,,,,,,,,,,,,, -3793,,,0.988350510597229,0.0400646440684795,0.2019606540584417,0.985541582107544,0.0497087389230728,0.1713321946448269,43793.0,0.9846153855323792,0.0524870157241821,0.1727824752855969,43793.0,1220.901849985123,2365.3287811279297,1220.901849985123,1144.1841297149658,0.1403696537017822,0.0 -3800,0.013538933,0.042911444,,,,,,,,,,,,,,,,, -3900,0.011670928,0.04013721,,,,,,,,,,,,,,,,, -4000,0.013349576,0.041788995,,,,,,,,,,,,,,,,, -4100,0.014050352,0.040463332,,,,,,,,,,,,,,,,, -4200,0.011640541,0.04495341,,,,,,,,,,,,,,,,, -4300,0.019826768,0.042162653,,,,,,,,,,,,,,,,, -4400,0.010327539,0.039183855,,,,,,,,,,,,,,,,, -4500,0.012165812,0.039871823,,,,,,,,,,,,,,,,, -4545,,,0.988611102104187,0.038961786776781,0.22024903128454,0.9857437014579772,0.0485648401081562,0.1875371005404125,43793.0,0.9848293662071228,0.0512782260775566,0.1918057026970715,43793.0,1460.9999330043793,2732.639394760132,1460.9999330043793,1271.3422000408173,0.1712827682495117,0.0 -4600,0.009985788,0.03422067,,,,,,,,,,,,,,,,, -4700,0.014310252,0.04134432,,,,,,,,,,,,,,,,, -4800,0.016047243,0.038155563,,,,,,,,,,,,,,,,, -4900,0.013985063,0.03486722,,,,,,,,,,,,,,,,, -5000,0.010955699,0.04044941,,,,,,,,,,,,,,,,, -5100,0.009007851,0.03597597,,,,,,,,,,,,,,,,, -5200,0.014058479,0.040329132,,,,,,,,,,,,,,,,, -5300,0.013816525,0.03926368,,,,,,,,,,,,,,,,, -5307,,,0.9888646006584167,0.0378557853400707,0.2397422560122646,0.9859409928321838,0.0475250221788883,0.2057997940980671,43793.0,0.9850513339042664,0.0501545369625091,0.2047523244605464,43793.0,1701.144421339035,3098.8045933246613,1701.144421339035,1397.3150610923767,0.1990737915039062,0.0 -5400,0.013475674,0.042591996,,,,,,,,,,,,,,,,, -5500,0.0133119095,0.03900416,,,,,,,,,,,,,,,,, -5600,0.010410702,0.040053144,,,,,,,,,,,,,,,,, -5700,0.012829091,0.03810659,,,,,,,,,,,,,,,,, -5800,0.012302059,0.03460429,,,,,,,,,,,,,,,,, -5900,0.010137812,0.035260875,,,,,,,,,,,,,,,,, -6000,0.02204005,0.040904175,,,,,,,,,,,,,,,,, -6033,,,0.9893951416015624,0.0358287021517753,0.2831310489165644,0.9862219095230104,0.0465059168636798,0.2159552329651954,43793.0,0.985364317893982,0.0490593016147613,0.2205065731715344,43793.0,1941.1350138187408,3469.7046773433685,1941.1350138187408,1528.1678965091703,0.2259581089019775,0.0 -6100,0.015883967,0.033691537,,,,,,,,,,,,,,,,, -6200,0.010745753,0.03163866,,,,,,,,,,,,,,,,, -6300,0.013655818,0.040306687,,,,,,,,,,,,,,,,, -6400,0.01411253,0.03660601,,,,,,,,,,,,,,,,, -6500,0.017717369,0.03659157,,,,,,,,,,,,,,,,, -6600,0.014242331,0.03784472,,,,,,,,,,,,,,,,, -6700,0.010979046,0.037905548,,,,,,,,,,,,,,,,, -6774,,,0.9893807172775269,0.0356910154223442,0.2807387695694591,0.9862409830093384,0.0463772155344486,0.224570762905789,43793.0,0.9852455258369446,0.0492561794817447,0.22154515529967,43793.0,2181.31711101532,3837.4433250427246,2181.31711101532,1655.6694447994232,0.2586400508880615,0.0 -6800,0.014384743,0.034863196,,,,,,,,,,,,,,,,, -6900,0.014083532,0.03334373,,,,,,,,,,,,,,,,, -7000,0.018064044,0.035982244,,,,,,,,,,,,,,,,, -7100,0.013307203,0.033905245,,,,,,,,,,,,,,,,, -7200,0.01602444,0.038674977,,,,,,,,,,,,,,,,, -7300,0.01640283,0.035915017,,,,,,,,,,,,,,,,, -7400,0.013426519,0.037780903,,,,,,,,,,,,,,,,, -7500,0.014844515,0.03704058,,,,,,,,,,,,,,,,, -7536,,,0.989676833152771,0.0346273817121982,0.2996095669035396,0.986353039741516,0.0456151477992534,0.2320336506118515,43793.0,0.985526442527771,0.0482594072818756,0.2345355224697904,43793.0,2421.491445541382,4201.456716775894,2421.491445541382,1779.4602222442627,0.2872166633605957,0.0 -7600,0.014727742,0.0323298,,,,,,,,,,,,,,,,, -7700,0.01852298,0.034517728,,,,,,,,,,,,,,,,, -7800,0.017651625,0.041841675,,,,,,,,,,,,,,,,, -7900,0.012280693,0.03357591,,,,,,,,,,,,,,,,, -8000,0.012734628,0.030043611,,,,,,,,,,,,,,,,, -8100,0.013425863,0.03639733,,,,,,,,,,,,,,,,, -8200,0.029360883,0.03810041,,,,,,,,,,,,,,,,, -8293,,,0.9896586537361144,0.0345381908118724,0.3084910291337902,0.9863985180854796,0.0453417785465717,0.2410090941250586,43793.0,0.9855167865753174,0.0479991845786571,0.2388739386660848,43793.0,2661.753399848938,4568.728547811508,2661.753399848938,1906.420913219452,0.3165235519409179,0.0 -8300,0.013988819,0.03776211,,,,,,,,,,,,,,,,, -8400,0.024958752,0.03144587,,,,,,,,,,,,,,,,, -8500,0.016580222,0.03486143,,,,,,,,,,,,,,,,, -8600,0.02195626,0.03473298,,,,,,,,,,,,,,,,, -8700,0.016752766,0.03430226,,,,,,,,,,,,,,,,, -8800,0.021209512,0.03458462,,,,,,,,,,,,,,,,, -8900,0.018895488,0.037100658,,,,,,,,,,,,,,,,, -9000,0.022423916,0.034959517,,,,,,,,,,,,,,,,, -9042,,,0.9897889494895936,0.0342106819152832,0.3190344427106689,0.9864873886108398,0.045198518782854,0.2418314847993425,43793.0,0.9856544733047484,0.0478878058493137,0.2396084057737746,43793.0,2901.8338191509247,4935.682396650314,2901.8338191509247,2033.2460660934448,0.3444001674652099,0.0 -9100,0.024532204,0.029497182,,,,,,,,,,,,,,,,, -9200,0.016424907,0.031736743,,,,,,,,,,,,,,,,, -9300,0.024734547,0.03191815,,,,,,,,,,,,,,,,, -9400,0.01843953,0.035252072,,,,,,,,,,,,,,,,, -9500,0.018676782,0.037617013,,,,,,,,,,,,,,,,, -9600,0.024735225,0.035529103,,,,,,,,,,,,,,,,, -9700,0.02330586,0.033665106,,,,,,,,,,,,,,,,, -9798,,,0.990106165409088,0.0331576168537139,0.3252865702974921,0.98661607503891,0.0447712168097496,0.2498949257991057,43793.0,0.9857429265975952,0.0474335961043834,0.2517587761230952,43793.0,3141.838902950287,5305.048637628555,3141.838902950287,2162.5575363636017,0.3734691143035888,0.0 -9800,0.018685732,0.0335675,,,,,,,,,,,,,,,,, -9900,0.031140676,0.038865358,,,,,,,,,,,,,,,,, -10000,0.024374556,0.03560186,,,,,,,,,,,,,,,,, -10100,0.021254202,0.030186526,,,,,,,,,,,,,,,,, -10200,0.024913145,0.033907942,,,,,,,,,,,,,,,,, -10300,0.034049667,0.034223236,,,,,,,,,,,,,,,,, -10400,0.02092098,0.032952055,,,,,,,,,,,,,,,,, -10500,0.028918466,0.036917787,,,,,,,,,,,,,,,,, -10550,,,0.9900806546211244,0.0328888185322284,0.3546899649784792,0.9866286516189576,0.0449186265468597,0.2529471298880844,43793.0,0.9857412576675416,0.0477214269340038,0.2506242295578791,43793.0,3381.862954854965,5675.507498025894,3381.862954854965,2292.943807125092,0.401355504989624,0.0 -10600,0.024122244,0.03922316,,,,,,,,,,,,,,,,, -10700,0.020787947,0.03306173,,,,,,,,,,,,,,,,, -10800,0.019083,0.03311915,,,,,,,,,,,,,,,,, -10900,0.036931716,0.031112034,,,,,,,,,,,,,,,,, -11000,0.017637083,0.030944739,,,,,,,,,,,,,,,,, -11100,0.035382938,0.034081206,,,,,,,,,,,,,,,,, -11200,0.023732752,0.031495277,,,,,,,,,,,,,,,,, -11298,,,0.9903729557991028,0.0320044048130512,0.3639937596876242,0.986629068851471,0.0446187332272529,0.2582009507227083,43793.0,0.9856927990913392,0.0474970564246177,0.2519475865112507,43793.0,3621.858566761017,6046.595143318176,3621.858566761017,2423.986411333084,0.4300618171691894,0.0 -11300,0.03208524,0.038371023,,,,,,,,,,,,,,,,, -11400,0.03651985,0.03333555,,,,,,,,,,,,,,,,, -11500,0.024944464,0.035258777,,,,,,,,,,,,,,,,, -11600,0.023861578,0.033113316,,,,,,,,,,,,,,,,, -11700,0.023554413,0.03271177,,,,,,,,,,,,,,,,, -11800,0.043091543,0.03817315,,,,,,,,,,,,,,,,, -11900,0.024653638,0.031736683,,,,,,,,,,,,,,,,, -12000,0.034003068,0.035572145,,,,,,,,,,,,,,,,, -12043,,,0.9904464483261108,0.0314045511186122,0.3860296660383807,0.986702561378479,0.0443083681166172,0.2602666292408093,43793.0,0.9858107566833496,0.0472899749875068,0.2561519833967316,43793.0,3861.815475463867,6416.409091234207,3861.815475463867,2553.7944264411926,0.4591627120971679,0.0 -12100,0.024194289,0.03070167,,,,,,,,,,,,,,,,, -12200,0.030397873,0.032041248,,,,,,,,,,,,,,,,, -12300,0.045661107,0.036009163,,,,,,,,,,,,,,,,, -12400,0.027472787,0.033574697,,,,,,,,,,,,,,,,, -12500,0.0346692,0.034230124,,,,,,,,,,,,,,,,, -12600,0.023062387,0.033512417,,,,,,,,,,,,,,,,, -12700,0.030129952,0.03604098,,,,,,,,,,,,,,,,, -12786,,,0.990697145462036,0.0306289941072464,0.3985212442916236,0.9868206977844238,0.0440055690705776,0.2653274057273426,43793.0,0.9859880805015564,0.046922318637371,0.2588306255104661,43793.0,4102.04964017868,6781.913225889206,4102.04964017868,2679.0158491134644,0.487302303314209,0.0 -12800,0.034650337,0.034470733,,,,,,,,,,,,,,,,, -12900,0.034131687,0.03152593,,,,,,,,,,,,,,,,, -13000,0.027801227,0.03792667,,,,,,,,,,,,,,,,, -13100,0.03159762,0.030972378,,,,,,,,,,,,,,,,, -13200,0.029833004,0.03291019,,,,,,,,,,,,,,,,, -13300,0.026460057,0.03348722,,,,,,,,,,,,,,,,, -13400,0.026652645,0.029321974,,,,,,,,,,,,,,,,, -13500,0.029775707,0.034972347,,,,,,,,,,,,,,,,, -13540,,,0.990955412387848,0.0300523471087217,0.4159336755250958,0.9867321848869324,0.0441063083708286,0.264509819457471,43793.0,0.9859177470207214,0.0466880537569522,0.2566011645968828,43793.0,4342.18771481514,7151.315158605576,4342.18771481514,2808.231062889099,0.5157938003540039,0.0 -13600,0.02969272,0.031034583,,,,,,,,,,,,,,,,, -13700,0.032298323,0.030479359,,,,,,,,,,,,,,,,, -13800,0.044838388,0.032528855,,,,,,,,,,,,,,,,, -13900,0.032480754,0.034275264,,,,,,,,,,,,,,,,, -14000,0.036687054,0.034714326,,,,,,,,,,,,,,,,, -14100,0.030104311,0.03341676,,,,,,,,,,,,,,,,, -14200,0.037038878,0.032450546,,,,,,,,,,,,,,,,, -14297,,,0.9908493757247924,0.0302403513342142,0.4088204630160475,0.9867675304412842,0.0443227104842662,0.2648037087964535,43793.0,0.9858878254890442,0.0471305660903453,0.2551575278951115,43793.0,4582.151921987534,7519.058586359024,4582.151921987534,2935.961008310318,0.5450277328491211,0.0 -14300,0.03890124,0.03543277,,,,,,,,,,,,,,,,, -14400,0.046649262,0.032151937,,,,,,,,,,,,,,,,, -14500,0.028674278,0.028629348,,,,,,,,,,,,,,,,, -14600,0.02942553,0.029410072,,,,,,,,,,,,,,,,, -14700,0.033390995,0.031877056,,,,,,,,,,,,,,,,, -14800,0.035720147,0.030027471,,,,,,,,,,,,,,,,, -14900,0.037549343,0.029196419,,,,,,,,,,,,,,,,, -15000,0.03249876,0.03086191,,,,,,,,,,,,,,,,, -15047,,,0.9907960891723632,0.0304703917354345,0.4065808959727163,0.9867001175880432,0.0441397167742252,0.2686152810756539,43793.0,0.9858798384666444,0.0468828491866588,0.2637279712481232,43793.0,4822.344773769379,7889.667108535767,4822.344773769379,3066.32634806633,0.5752942562103271,0.0 -15100,0.03301181,0.02869632,,,,,,,,,,,,,,,,, -15200,0.046210837,0.03457856,,,,,,,,,,,,,,,,, -15300,0.043553017,0.0332296,,,,,,,,,,,,,,,,, -15400,0.027639076,0.027740614,,,,,,,,,,,,,,,,, -15500,0.042119455,0.034833156,,,,,,,,,,,,,,,,, -15600,0.03743113,0.032554813,,,,,,,,,,,,,,,,, -15700,0.040187903,0.036946252,,,,,,,,,,,,,,,,, -15800,0.039802097,0.03039243,,,,,,,,,,,,,,,,, -15804,,,0.990820586681366,0.0303225982934236,0.4107530372449284,0.9867711663246156,0.0440304353833198,0.2680708694233623,43793.0,0.9858170747756958,0.0469556115567684,0.2542581856078846,43793.0,5062.34095621109,8266.162788152695,5062.34095621109,3202.7751603126526,0.6057493686676025,0.0 -15900,0.050324243,0.033280674,,,,,,,,,,,,,,,,, -16000,0.037445076,0.031414818,,,,,,,,,,,,,,,,, -16100,0.046891563,0.03640969,,,,,,,,,,,,,,,,, -16200,0.030626915,0.030304192,,,,,,,,,,,,,,,,, -16300,0.04391609,0.036155958,,,,,,,,,,,,,,,,, -16400,0.040210225,0.031083865,,,,,,,,,,,,,,,,, -16500,0.032270007,0.03273653,,,,,,,,,,,,,,,,, -16540,,,0.990711271762848,0.0303991846740245,0.4025933595309913,0.9867861866950988,0.0445683114230632,0.2658517816125347,43793.0,0.98592871427536,0.0474228151142597,0.2621201168497622,43793.0,5302.402543067932,8639.794367313385,5302.402543067932,3336.2889742851257,0.6370806694030762,0.0 -16600,0.037549727,0.03009688,,,,,,,,,,,,,,,,, -16700,0.050114974,0.03248284,,,,,,,,,,,,,,,,, -16800,0.040381067,0.03217423,,,,,,,,,,,,,,,,, -16900,0.037701778,0.030486189,,,,,,,,,,,,,,,,, -17000,0.0390082,0.029544003,,,,,,,,,,,,,,,,, -17100,0.043254,0.030955922,,,,,,,,,,,,,,,,, -17200,0.03163408,0.02835405,,,,,,,,,,,,,,,,, -17286,,,0.9908052682876588,0.0300758443772792,0.4174778803693812,0.9868454337120056,0.0440562255680561,0.2679873720240026,43793.0,0.9860125184059144,0.046807624399662,0.2634333099377213,43793.0,5542.3714718818665,9008.544433832169,5542.3714718818665,3465.017394065857,0.669407844543457,0.0 -17300,0.03326453,0.030843772,,,,,,,,,,,,,,,,, -17400,0.03734737,0.02891989,,,,,,,,,,,,,,,,, -17500,0.04281395,0.03281369,,,,,,,,,,,,,,,,, -17600,0.034817167,0.03227907,,,,,,,,,,,,,,,,, -17700,0.052637797,0.036769297,,,,,,,,,,,,,,,,, -17800,0.055949196,0.035131305,,,,,,,,,,,,,,,,, -17900,0.042400476,0.02890891,,,,,,,,,,,,,,,,, -18000,0.04688213,0.030353548,,,,,,,,,,,,,,,,, -18042,,,0.9909988045692444,0.0294074118137359,0.4241687010668576,0.9868454337120056,0.0440659411251544,0.270766252405497,43793.0,0.9859421849250792,0.0468078851699829,0.2667088061323157,43793.0,5782.42107629776,9377.715222358704,5782.42107629776,3594.087012052536,0.7005889415740967,0.0 -18100,0.048558377,0.03294904,,,,,,,,,,,,,,,,, -18200,0.04364844,0.031375486,,,,,,,,,,,,,,,,, -18300,0.04498741,0.029380815,,,,,,,,,,,,,,,,, -18400,0.036054384,0.029887635,,,,,,,,,,,,,,,,, -18500,0.0432743,0.033048134,,,,,,,,,,,,,,,,, -18600,0.044168454,0.030069515,,,,,,,,,,,,,,,,, -18700,0.046259064,0.031830188,,,,,,,,,,,,,,,,, -18793,,,0.9910372495651244,0.0290959365665912,0.4259529531939385,0.9867752194404602,0.0439095199108123,0.2738895893633772,43793.0,0.985970377922058,0.0467896312475204,0.2599671178064437,43793.0,6022.573717832565,9748.362263441086,6022.573717832565,3724.530516862869,0.7313892841339111,0.0 -18800,0.0490253,0.03373998,,,,,,,,,,,,,,,,, -18900,0.035435144,0.03042415,,,,,,,,,,,,,,,,, -19000,0.058754858,0.031648554,,,,,,,,,,,,,,,,, -19100,0.039462935,0.02993087,,,,,,,,,,,,,,,,, -19200,0.03263444,0.026192755,,,,,,,,,,,,,,,,, -19300,0.053452853,0.033534218,,,,,,,,,,,,,,,,, -19400,0.058244243,0.029682009,,,,,,,,,,,,,,,,, -19500,0.042643875,0.033295944,,,,,,,,,,,,,,,,, -19548,,,0.9911220073699952,0.0286788288503885,0.4524073924889998,0.9870005249977112,0.0439878851175308,0.274081041031077,43793.0,0.9860718846321106,0.0469783172011375,0.2697352700780682,43793.0,6262.538156986237,10116.527661561966,6262.538156986237,3852.681501150131,0.7612929344177246,0.0 -19600,0.034114726,0.030818779,,,,,,,,,,,,,,,,, -19700,0.041869376,0.028848302,,,,,,,,,,,,,,,,, -19800,0.043510634,0.030453151,,,,,,,,,,,,,,,,, -19900,0.059077196,0.03233572,,,,,,,,,,,,,,,,, -20000,0.041290924,0.030988142,,,,,,,,,,,,,,,,, -20100,0.056972407,0.032505795,,,,,,,,,,,,,,,,, -20200,0.06821762,0.03394891,,,,,,,,,,,,,,,,, -20299,,,0.9913285970687866,0.0280595403164625,0.4674548046421443,0.9869655966758728,0.0438056886196136,0.2766592911034783,43793.0,0.9862125515937804,0.0465430319309234,0.2696994986515265,43793.0,6502.570947647095,10487.115420341492,6502.570947647095,3983.186435461044,0.7912535667419434,0.0 -20300,0.04590884,0.029368347,,,,,,,,,,,,,,,,, -20400,0.054943573,0.03083377,,,,,,,,,,,,,,,,, -20500,0.040484667,0.03169587,,,,,,,,,,,,,,,,, -20600,0.04067049,0.028518073,,,,,,,,,,,,,,,,, -20700,0.04172114,0.033663798,,,,,,,,,,,,,,,,, -20800,0.04004609,0.027265547,,,,,,,,,,,,,,,,, -20900,0.06334387,0.032415036,,,,,,,,,,,,,,,,, -21000,0.040246375,0.03239758,,,,,,,,,,,,,,,,, -21043,,,0.9913150072097778,0.0281208455562591,0.4598708789135767,0.9869221448898317,0.0441030338406562,0.2744794112477171,43793.0,0.9861034750938416,0.0469097383320331,0.2658631214409926,43793.0,6742.623478651047,10864.359109163284,6742.623478651047,4120.325926780701,0.8225059509277344,0.0 -21100,0.041537154,0.03148925,,,,,,,,,,,,,,,,, -21200,0.044245742,0.030281847,,,,,,,,,,,,,,,,, -21300,0.052072156,0.030767625,,,,,,,,,,,,,,,,, -21400,0.044443343,0.031823404,,,,,,,,,,,,,,,,, -21500,0.059674595,0.034353316,,,,,,,,,,,,,,,,, -21600,0.06627721,0.02839022,,,,,,,,,,,,,,,,, -21700,0.04244835,0.028516572,,,,,,,,,,,,,,,,, -21784,,,0.9913945198059082,0.0282089468091726,0.4455943310874722,0.9868649244308472,0.0439488366246223,0.274835884000275,43793.0,0.9860184192657472,0.046649981290102,0.2655067753076269,43793.0,6982.736036777496,11237.61018204689,6982.736036777496,4253.412977218628,0.853546142578125,0.0 -21800,0.04664957,0.03243163,,,,,,,,,,,,,,,,, -21900,0.05718804,0.027777417,,,,,,,,,,,,,,,,, -22000,0.06704939,0.031003712,,,,,,,,,,,,,,,,, -22100,0.040214833,0.030793972,,,,,,,,,,,,,,,,, -22200,0.04832534,0.028540093,,,,,,,,,,,,,,,,, -22300,0.053325526,0.033403184,,,,,,,,,,,,,,,,, -22400,0.055751305,0.026991468,,,,,,,,,,,,,,,,, -22500,0.047386397,0.030761296,,,,,,,,,,,,,,,,, -22532,,,0.9913437366485596,0.0282543916255235,0.4483314941364055,0.9869745373725892,0.0437615998089313,0.2805751517421683,43793.0,0.9860659837722778,0.0466744974255561,0.2620544147619095,43793.0,7222.803408145904,11608.518146514893,7222.803408145904,4384.202717542648,0.8843135833740234,0.0 -22600,0.048993275,0.02869694,,,,,,,,,,,,,,,,, -22700,0.04757943,0.028371997,,,,,,,,,,,,,,,,, -22800,0.046310738,0.030906267,,,,,,,,,,,,,,,,, -22900,0.05264046,0.03175777,,,,,,,,,,,,,,,,, -23000,0.059942476,0.030496942,,,,,,,,,,,,,,,,, -23100,0.058192197,0.031888995,,,,,,,,,,,,,,,,, -23200,0.04463402,0.028152436,,,,,,,,,,,,,,,,, -23278,,,0.991249144077301,0.0284157712012529,0.4628828997940778,0.9869765639305116,0.0439945273101329,0.2760817632379464,43793.0,0.9861236810684204,0.0467276126146316,0.2661157909791088,43793.0,7462.937109947205,11977.61382842064,7462.937109947205,4513.108451843262,0.9187815189361572,0.0 -23300,0.04568861,0.025470244,,,,,,,,,,,,,,,,, -23400,0.046969708,0.028221287,,,,,,,,,,,,,,,,, -23500,0.046737734,0.031230053,,,,,,,,,,,,,,,,, -23600,0.053871088,0.028028492,,,,,,,,,,,,,,,,, -23700,0.04831912,0.028522504,,,,,,,,,,,,,,,,, -23800,0.04883868,0.02933024,,,,,,,,,,,,,,,,, -23900,0.049114317,0.031116385,,,,,,,,,,,,,,,,, -24000,0.054330345,0.030461049,,,,,,,,,,,,,,,,, -24029,,,0.991316258907318,0.0281748175621032,0.4489988947669722,0.9868953824043274,0.0444948449730873,0.27705578804006,43793.0,0.9860647320747375,0.0475699119269847,0.2683217825679791,43793.0,7702.915818929672,12347.995973825457,7702.915818929672,4643.4614017009735,0.9489178657531738,0.0 -24100,0.04803122,0.02853003,,,,,,,,,,,,,,,,, -24200,0.043876067,0.031503957,,,,,,,,,,,,,,,,, -24300,0.047661323,0.030391544,,,,,,,,,,,,,,,,, -24400,0.060146827,0.03287881,,,,,,,,,,,,,,,,, -24500,0.05481686,0.033086963,,,,,,,,,,,,,,,,, -24600,0.04773027,0.028326632,,,,,,,,,,,,,,,,, -24700,0.048663314,0.028415905,,,,,,,,,,,,,,,,, -24779,,,0.9912404417991638,0.0283038429915905,0.4516727626946165,0.9870346188545228,0.0436816960573196,0.2832783057158123,43793.0,0.9860441088676452,0.0468056164681911,0.2693561350510513,43793.0,7942.99335026741,12715.360845327376,7942.99335026741,4770.698044300079,0.9801278114318848,0.0 -24800,0.060662504,0.030956268,,,,,,,,,,,,,,,,, -24900,0.05846228,0.030945897,,,,,,,,,,,,,,,,, -25000,0.053194873,0.029555064,,,,,,,,,,,,,,,,, -25100,0.03989167,0.027561454,,,,,,,,,,,,,,,,, -25200,0.04363033,0.029737182,,,,,,,,,,,,,,,,, -25300,0.043322425,0.029061146,,,,,,,,,,,,,,,,, -25400,0.06954922,0.030080242,,,,,,,,,,,,,,,,, -25500,0.039872635,0.026465759,,,,,,,,,,,,,,,,, -25526,,,0.9914011359214784,0.0276933945715427,0.4690910450316292,0.986968457698822,0.0441174320876598,0.2768586391381986,43793.0,0.9861544370651244,0.0467205196619033,0.2696622948080067,43793.0,8183.164006948471,13084.890377283096,8183.164006948471,4900.004084348679,1.0125136375427246,0.0 -25600,0.06438778,0.03193974,,,,,,,,,,,,,,,,, -25700,0.047033872,0.028261637,,,,,,,,,,,,,,,,, -25800,0.043719336,0.02698064,,,,,,,,,,,,,,,,, -25900,0.07536833,0.03088387,,,,,,,,,,,,,,,,, -26000,0.052027293,0.027947837,,,,,,,,,,,,,,,,, -26100,0.050837222,0.02718879,,,,,,,,,,,,,,,,, -26200,0.05381282,0.028211853,,,,,,,,,,,,,,,,, -26267,,,0.9915769696235656,0.0271135028451681,0.4799763807067559,0.9870049953460692,0.0444050058722496,0.2801886624062228,43793.0,0.9860975742340088,0.0474251545965671,0.2647629339525147,43793.0,8423.121535778046,13457.219170570374,8423.121535778046,5032.323452711105,1.043934345245361,0.0 -26300,0.06220058,0.02938352,,,,,,,,,,,,,,,,, -26400,0.052578803,0.027418617,,,,,,,,,,,,,,,,, -26500,0.06152,0.028405147,,,,,,,,,,,,,,,,, -26600,0.05442808,0.030780194,,,,,,,,,,,,,,,,, -26700,0.056995217,0.03187626,,,,,,,,,,,,,,,,, -26800,0.04442061,0.025133388,,,,,,,,,,,,,,,,, -26900,0.055536855,0.029517064,,,,,,,,,,,,,,,,, -26998,,,0.9920740127563475,0.0258896984159946,0.5144971201680135,0.9868710041046144,0.0441129580140113,0.2735882717769362,43793.0,0.986042022705078,0.0467803031206131,0.2685553927331099,43793.0,8663.304875612259,13833.450981378555,8663.304875612259,5168.311667442322,1.0791988372802734,0.0 -27000,0.056664266,0.028965164,,,,,,,,,,,,,,,,, -27100,0.054368846,0.029153742,,,,,,,,,,,,,,,,, -27200,0.044018216,0.027694343,,,,,,,,,,,,,,,,, -27300,0.06007609,0.02928293,,,,,,,,,,,,,,,,, -27400,0.053248808,0.027505545,,,,,,,,,,,,,,,,, -27500,0.050768044,0.027718203,,,,,,,,,,,,,,,,, -27600,0.06066614,0.032340094,,,,,,,,,,,,,,,,, -27700,0.058287695,0.03182678,,,,,,,,,,,,,,,,, -27744,,,0.9919695854187012,0.0261274073272943,0.5067298145105661,0.9869599342346193,0.0442815013229846,0.277661715150162,43793.0,0.9861022233963012,0.0469837635755538,0.2671069133928673,43793.0,8903.47397851944,14204.240990161896,8903.47397851944,5298.877843379974,1.1134259700775146,0.0 -27800,0.050524756,0.028768271,,,,,,,,,,,,,,,,, -27900,0.05576615,0.025026174,,,,,,,,,,,,,,,,, -28000,0.052168284,0.02900303,,,,,,,,,,,,,,,,, -28100,0.06122925,0.028985728,,,,,,,,,,,,,,,,, -28200,0.061883986,0.029909952,,,,,,,,,,,,,,,,, -28300,0.06253299,0.03046094,,,,,,,,,,,,,,,,, -28400,0.06627423,0.027522922,,,,,,,,,,,,,,,,, -28492,,,0.9916380047798156,0.026933841407299,0.4918883510856133,0.9869672060012816,0.0448606945574283,0.2792009679877025,43793.0,0.98611319065094,0.0477529726922512,0.263590646622293,43793.0,9143.46143746376,14577.79633116722,9143.46143746376,5432.393877744675,1.1442391872406006,0.0 -28500,0.062493164,0.028214017,,,,,,,,,,,,,,,,, -28600,0.05520798,0.028206777,,,,,,,,,,,,,,,,, -28700,0.052524578,0.030869296,,,,,,,,,,,,,,,,, -28800,0.07308527,0.032543756,,,,,,,,,,,,,,,,, -28900,0.06820878,0.03299033,,,,,,,,,,,,,,,,, -29000,0.063856095,0.03157261,,,,,,,,,,,,,,,,, -29100,0.052521203,0.027212897,,,,,,,,,,,,,,,,, -29200,0.062951945,0.025533916,,,,,,,,,,,,,,,,, -29242,,,0.991564154624939,0.0272049885243177,0.4691972139556656,0.9871259331703186,0.0442688465118408,0.2877655372285986,43793.0,0.9862269163131714,0.0470354966819286,0.2678343298693436,43793.0,9383.617679357529,14945.165321350098,9383.617679357529,5559.553866147995,1.1761319637298584,0.0 -29300,0.050982688,0.026501328,,,,,,,,,,,,,,,,, -29400,0.059305407,0.028529007,,,,,,,,,,,,,,,,, -29500,0.05541605,0.028369,,,,,,,,,,,,,,,,, -29600,0.06522864,0.028944448,,,,,,,,,,,,,,,,, -29700,0.06390478,0.027539393,,,,,,,,,,,,,,,,, -29800,0.05982839,0.030611364,,,,,,,,,,,,,,,,, -29900,0.057534125,0.026915072,,,,,,,,,,,,,,,,, -29991,,,0.991787314414978,0.0266905855387449,0.4814734467220524,0.9869627952575684,0.0438532941043376,0.2948247113332413,43793.0,0.986084520816803,0.0467661544680595,0.2690176516848281,43793.0,9623.72170972824,15317.450065851212,9623.72170972824,5691.682248830795,1.2084007263183594,0.0 -30000,0.06717693,0.028208096,,,,,,,,,,,,,,,,, -30100,0.056348305,0.028218696,,,,,,,,,,,,,,,,, -30200,0.055163525,0.025975093,,,,,,,,,,,,,,,,, -30300,0.065432824,0.026630072,,,,,,,,,,,,,,,,, -30400,0.055400714,0.025989408,,,,,,,,,,,,,,,,, -30500,0.05381394,0.02660901,,,,,,,,,,,,,,,,, -30600,0.05321191,0.025765143,,,,,,,,,,,,,,,,, -30700,0.060928803,0.027063461,,,,,,,,,,,,,,,,, -30733,,,0.9916547536849976,0.0268725398927927,0.505886375231136,0.987059772014618,0.0440127216279506,0.2781033442700349,43793.0,0.9861578345298768,0.046843446791172,0.2706482304538314,43793.0,9863.68747830391,15687.312734127045,9863.68747830391,5821.52396774292,1.2435061931610107,0.0 -30800,0.084751196,0.030954147,,,,,,,,,,,,,,,,, -30900,0.055886623,0.028047116,,,,,,,,,,,,,,,,, -31000,0.05546828,0.022457309,,,,,,,,,,,,,,,,, -31100,0.047595847,0.02770882,,,,,,,,,,,,,,,,, -31200,0.084774174,0.031534847,,,,,,,,,,,,,,,,, -31300,0.055306073,0.030644992,,,,,,,,,,,,,,,,, -31400,0.06558301,0.028113116,,,,,,,,,,,,,,,,, -31466,,,0.991858720779419,0.026429558172822,0.4930497997143744,0.9869838953018188,0.0440751984715461,0.2827252918313451,43793.0,0.9861629009246826,0.0466210283339023,0.2740760282792117,43793.0,10103.922180891035,16060.538444519045,10103.922180891035,5954.453230142593,1.281077861785889,0.0 -31500,0.04703946,0.025951244,,,,,,,,,,,,,,,,, -31600,0.06898276,0.028230872,,,,,,,,,,,,,,,,, -31700,0.05534609,0.027385002,,,,,,,,,,,,,,,,, -31800,0.070296854,0.026724963,,,,,,,,,,,,,,,,, -31900,0.06420366,0.025944127,,,,,,,,,,,,,,,,, -32000,0.05953076,0.028249498,,,,,,,,,,,,,,,,, -32100,0.056175638,0.030577306,,,,,,,,,,,,,,,,, -32200,0.05815674,0.02723357,,,,,,,,,,,,,,,,, -32202,,,0.9919987320899964,0.0257250852882862,0.5162503143904493,0.9871016144752502,0.0439109280705452,0.2890559874025201,43793.0,0.9862138628959656,0.0468107126653194,0.2729542577372867,43793.0,10344.138065576552,16433.148668050766,10344.138065576552,6086.7869300842285,1.3168962001800537,0.0 -32300,0.067083634,0.027141657,,,,,,,,,,,,,,,,, -32400,0.06403009,0.03135271,,,,,,,,,,,,,,,,, -32500,0.073183194,0.0288256,,,,,,,,,,,,,,,,, -32600,0.06612806,0.028148262,,,,,,,,,,,,,,,,, -32700,0.054481883,0.024294183,,,,,,,,,,,,,,,,, -32800,0.07134086,0.029412623,,,,,,,,,,,,,,,,, -32900,0.054290887,0.024875997,,,,,,,,,,,,,,,,, -32951,,,0.9919003248214722,0.0258762538433074,0.5181524506162096,0.9869920015335084,0.0443031750619411,0.2793055822996804,43793.0,0.9862251877784728,0.0470446273684501,0.2750339194968053,43793.0,10584.087520360948,16803.94484782219,10584.087520360948,6217.578600645065,1.3507368564605713,0.0 -33000,0.06079799,0.029961329,,,,,,,,,,,,,,,,, -33100,0.062215567,0.027013266,,,,,,,,,,,,,,,,, -33200,0.06800674,0.03231538,,,,,,,,,,,,,,,,, -33300,0.05825732,0.029008633,,,,,,,,,,,,,,,,, -33400,0.054764297,0.027453057,,,,,,,,,,,,,,,,, -33500,0.05999382,0.029925011,,,,,,,,,,,,,,,,, -33600,0.062150337,0.031126011,,,,,,,,,,,,,,,,, -33700,,,0.9923912286758424,0.0245789401233196,0.5401013702391663,0.9869599342346193,0.0440076068043708,0.2849392163740508,43793.0,0.9861350655555724,0.0467751398682594,0.2726212946469625,43793.0,10824.240885734558,17172.259604930878,10824.240885734558,6345.683080196381,1.3866889476776123,0.0 -33700,0.0636838,0.031194987,,,,,,,,,,,,,,,,, -33800,0.06191184,0.03060598,,,,,,,,,,,,,,,,, -33900,0.057877384,0.028330756,,,,,,,,,,,,,,,,, -34000,0.069547854,0.026379367,,,,,,,,,,,,,,,,, -34100,0.07487092,0.027526217,,,,,,,,,,,,,,,,, -34200,0.061755344,0.025470924,,,,,,,,,,,,,,,,, -34300,0.060483564,0.026885053,,,,,,,,,,,,,,,,, -34400,0.07330672,0.030127112,,,,,,,,,,,,,,,,, -34444,,,0.9924638271331788,0.0244253538548946,0.5408102506296366,0.9870119094848632,0.0440884940326213,0.2841427396357607,43793.0,0.9861670732498168,0.0469842590391635,0.2729544840814009,43793.0,11064.355593442917,17537.51575899124,11064.355593442917,6470.765429973602,1.423579216003418,0.0 -34500,0.06451952,0.026161954,,,,,,,,,,,,,,,,, -34600,0.059104867,0.025264634,,,,,,,,,,,,,,,,, -34700,0.061673917,0.025426758,,,,,,,,,,,,,,,,, -34800,0.057574,0.027259242,,,,,,,,,,,,,,,,, -34900,0.06556742,0.028714506,,,,,,,,,,,,,,,,, -35000,0.07405247,0.032062702,,,,,,,,,,,,,,,,, -35100,0.06674906,0.027266076,,,,,,,,,,,,,,,,, -35193,,,0.992255449295044,0.0248958189040422,0.5338702720010491,0.9870837330818176,0.0445379801094532,0.2771609054072548,43793.0,0.986255943775177,0.0475512593984603,0.2712236659701404,43793.0,11304.5543050766,17905.78564977646,11304.5543050766,6598.782794237137,1.4568994045257568,0.0 -35200,0.072366126,0.028083103,,,,,,,,,,,,,,,,, -35300,0.06391472,0.025822574,,,,,,,,,,,,,,,,, -35400,0.06343031,0.028322507,,,,,,,,,,,,,,,,, -35500,0.064307064,0.02894733,,,,,,,,,,,,,,,,, -35600,0.052795075,0.024288578,,,,,,,,,,,,,,,,, -35700,0.062144242,0.027901255,,,,,,,,,,,,,,,,, -35800,0.06589861,0.02433978,,,,,,,,,,,,,,,,, -35900,0.05915517,0.024557771,,,,,,,,,,,,,,,,, -35938,,,0.9919145107269288,0.0258436184376478,0.5134692618480501,0.9870756268501282,0.0446954704821109,0.2806026018791224,43793.0,0.9862921833992004,0.0474688448011875,0.2737562533565371,43793.0,11544.72012758255,18275.26300549507,11544.72012758255,6728.039513587952,1.4914841651916504,0.0 -36000,0.057175696,0.027513284,,,,,,,,,,,,,,,,, -36100,0.076263145,0.028985798,,,,,,,,,,,,,,,,, -36200,0.07062715,0.029595124,,,,,,,,,,,,,,,,, -36300,0.060540836,0.027270878,,,,,,,,,,,,,,,,, -36400,0.08177952,0.0246843,,,,,,,,,,,,,,,,, -36500,0.06758779,0.029750634,,,,,,,,,,,,,,,,, -36600,0.068925805,0.03216881,,,,,,,,,,,,,,,,, -36684,,,0.9919720888137816,0.0258467327803373,0.4990831629052673,0.9870249032974244,0.044183675199747,0.2844107591122186,43793.0,0.9862239360809326,0.0469509214162826,0.27379143182807,43793.0,11784.769032478333,18643.02334499359,11784.769032478333,6855.695904970169,1.526174783706665,0.0 -36700,0.08912022,0.028896676,,,,,,,,,,,,,,,,, -36800,0.087516494,0.029168643,,,,,,,,,,,,,,,,, -36900,0.065365076,0.02704714,,,,,,,,,,,,,,,,, -37000,0.066659786,0.026851442,,,,,,,,,,,,,,,,, -37100,0.0861224,0.030189939,,,,,,,,,,,,,,,,, -37200,0.060724273,0.024674792,,,,,,,,,,,,,,,,, -37300,0.06207997,0.027818315,,,,,,,,,,,,,,,,, -37400,0.06621416,0.026285917,,,,,,,,,,,,,,,,, -37430,,,0.992161750793457,0.0251628756523132,0.5274201870761099,0.9870399236679076,0.0443357676267623,0.2835930685939118,43793.0,0.9861733913421632,0.0472716465592384,0.2732995899724466,43793.0,12024.889050722122,19012.433899879456,12024.889050722122,6984.932626485825,1.5598630905151367,0.0 -37500,0.06855128,0.030428845,,,,,,,,,,,,,,,,, -37600,0.06606229,0.025739975,,,,,,,,,,,,,,,,, -37700,0.064128764,0.02790542,,,,,,,,,,,,,,,,, -37800,0.06361224,0.029058713,,,,,,,,,,,,,,,,, -37900,0.090686254,0.030198954,,,,,,,,,,,,,,,,, -38000,0.07108126,0.028984485,,,,,,,,,,,,,,,,, -38100,0.071391396,0.027193733,,,,,,,,,,,,,,,,, -38176,,,0.9922487139701844,0.0247343424707651,0.5345395525532157,0.9869737029075624,0.0445678867399692,0.2862172278706824,43793.0,0.986193597316742,0.047460202127695,0.2783172807574485,43793.0,12264.882412672045,19381.634644031525,12264.882412672045,7114.086160421372,1.593503713607788,0.0 -38200,0.10686492,0.027754506,,,,,,,,,,,,,,,,, -38300,0.06083233,0.025909737,,,,,,,,,,,,,,,,, -38400,0.06269553,0.023303026,,,,,,,,,,,,,,,,, -38500,0.059334744,0.025534477,,,,,,,,,,,,,,,,, -38600,0.07605912,0.027616303,,,,,,,,,,,,,,,,, -38700,0.07337237,0.024950359,,,,,,,,,,,,,,,,, -38800,0.0742473,0.027041707,,,,,,,,,,,,,,,,, -38900,0.07028293,0.02888222,,,,,,,,,,,,,,,,, -38923,,,0.9923862218856812,0.0243118181824684,0.5500534552972087,0.9870415329933168,0.0446854121983051,0.2939358720634637,43793.0,0.9862353205680848,0.0476949959993362,0.2720401231594703,43793.0,12505.049992084503,19748.3295750618,12505.049992084503,7240.5595326423645,1.6269724369049072,0.0 -39000,0.0722333,0.027404487,,,,,,,,,,,,,,,,, -39100,0.07101329,0.026676873,,,,,,,,,,,,,,,,, -39200,0.071830735,0.029468508,,,,,,,,,,,,,,,,, -39300,0.0679386,0.027210882,,,,,,,,,,,,,,,,, -39400,0.06257391,0.026399003,,,,,,,,,,,,,,,,, -39500,0.06594662,0.023736374,,,,,,,,,,,,,,,,, -39600,0.075835675,0.028781295,,,,,,,,,,,,,,,,, -39670,,,0.9924398064613342,0.0240245051681995,0.5614326523796441,0.9870817065238952,0.0445556156337261,0.2897896197496002,43793.0,0.98629891872406,0.0473158583045005,0.2731693622950285,43793.0,12745.011418104172,20116.981176376343,12745.011418104172,7369.194571733475,1.6616127490997314,0.0 -39700,0.06842414,0.026836835,,,,,,,,,,,,,,,,, -39800,0.06950439,0.024420392,,,,,,,,,,,,,,,,, -39900,0.0780538,0.027216244,,,,,,,,,,,,,,,,, -40000,0.06453645,0.027495248,,,,,,,,,,,,,,,,, -40100,0.077801555,0.027149286,,,,,,,,,,,,,,,,, -40200,0.07186388,0.025707584,,,,,,,,,,,,,,,,, -40300,0.07990399,0.029083092,,,,,,,,,,,,,,,,, -40400,0.06824754,0.025224876,,,,,,,,,,,,,,,,, -40409,,,0.9927424192428588,0.0233082212507724,0.574511917572232,0.9868040680885316,0.0447564125061035,0.2882639522424063,43793.0,0.9859733581542968,0.0475118793547153,0.2730102876132786,43793.0,12984.965381383896,20484.833154678345,12984.965381383896,7497.037292003632,1.6968517303466797,0.0 -40500,0.0826369,0.027728619,,,,,,,,,,,,,,,,, -40600,0.07795653,0.031595822,,,,,,,,,,,,,,,,, -40700,0.07058838,0.027799888,,,,,,,,,,,,,,,,, -40800,0.06955248,0.025547972,,,,,,,,,,,,,,,,, -40900,0.06557196,0.027889045,,,,,,,,,,,,,,,,, -41000,0.07412177,0.024962537,,,,,,,,,,,,,,,,, -41100,0.08914772,0.029842123,,,,,,,,,,,,,,,,, -41157,,,0.9929094910621644,0.0228238552808761,0.5796415843969315,0.9870049953460692,0.0445084460079669,0.2902851871377324,43793.0,0.986185610294342,0.0472666770219802,0.2766887350322903,43793.0,13224.677162885666,20854.66638636589,13224.677162885666,7626.816812753677,2.0183424949646,0.0 -41200,0.070004664,0.024969703,,,,,,,,,,,,,,,,, -41300,0.0658861,0.023606487,,,,,,,,,,,,,,,,, -41400,0.07877072,0.02380451,,,,,,,,,,,,,,,,, -41500,0.09013501,0.027361337,,,,,,,,,,,,,,,,, -41600,0.08111334,0.02616763,,,,,,,,,,,,,,,,, -41700,0.07723143,0.024763413,,,,,,,,,,,,,,,,, -41800,0.08408598,0.029280776,,,,,,,,,,,,,,,,, -41900,0.069230914,0.025622645,,,,,,,,,,,,,,,,, -41906,,,0.9926960468292236,0.0233404859900474,0.5810469593506353,0.9869651794433594,0.0448800325393676,0.2916548765324748,43793.0,0.9861674904823304,0.0477052852511405,0.26936929027808,43793.0,13464.927323102953,21221.274071455,13464.927323102953,7753.115566253662,2.056342363357544,0.0 -42000,0.07273372,0.0255546,,,,,,,,,,,,,,,,, -42100,0.06955353,0.027680853,,,,,,,,,,,,,,,,, -42200,0.07693371,0.02565115,,,,,,,,,,,,,,,,, -42300,0.06699577,0.029283082,,,,,,,,,,,,,,,,, -42400,0.08234736,0.02353016,,,,,,,,,,,,,,,,, -42500,0.08663683,0.028822744,,,,,,,,,,,,,,,,, -42600,0.084544815,0.027438266,,,,,,,,,,,,,,,,, -42649,,,0.9926057457923888,0.0236248318105936,0.5628984806159489,0.9871000051498412,0.0449630059301853,0.2910589062124619,43793.0,0.9862475395202636,0.0478574857115745,0.2744560090343512,43793.0,13705.163032531738,21591.03133511544,13705.163032531738,7882.5806567668915,2.092425346374512,0.0 -42700,0.08529944,0.028398499,,,,,,,,,,,,,,,,, -42800,0.0765239,0.025315432,,,,,,,,,,,,,,,,, -42900,0.08251922,0.025036052,,,,,,,,,,,,,,,,, -43000,0.082129054,0.02332033,,,,,,,,,,,,,,,,, -43100,0.075133964,0.026099186,,,,,,,,,,,,,,,,, -43200,0.07662891,0.026395915,,,,,,,,,,,,,,,,, -43300,0.074611545,0.026393013,,,,,,,,,,,,,,,,, -43386,,,0.9926634430885316,0.0233371630311012,0.5653618151959215,0.987144649028778,0.044412650167942,0.2941698091762321,43793.0,0.9862290024757384,0.0474529117345809,0.2745836850366369,43793.0,13945.18832564354,21957.04997587204,13945.18832564354,8008.517409801483,2.128856897354126,0.0 -43400,0.09623073,0.028057348,,,,,,,,,,,,,,,,, -43500,0.07393767,0.028238924,,,,,,,,,,,,,,,,, -43600,0.0796492,0.02780752,,,,,,,,,,,,,,,,, -43700,0.07158717,0.024690134,,,,,,,,,,,,,,,,, -43800,0.07717299,0.025728196,,,,,,,,,,,,,,,,, -43900,0.095373936,0.026859349,,,,,,,,,,,,,,,,, -44000,0.08688899,0.025841435,,,,,,,,,,,,,,,,, -44100,0.07901894,0.027087253,,,,,,,,,,,,,,,,, -44124,,,0.9925113320350648,0.0236401371657848,0.5700971758111578,0.987098753452301,0.0454526208341121,0.2921075738306386,43793.0,0.9862479567527772,0.0486087724566459,0.2739571897071809,43793.0,14185.310455322266,22327.866988182068,14185.310455322266,8139.154308795929,2.1666409969329834,0.0 -44200,0.08870737,0.026621051,,,,,,,,,,,,,,,,, -44300,0.093198195,0.02595224,,,,,,,,,,,,,,,,, -44400,0.083589464,0.028245848,,,,,,,,,,,,,,,,, -44500,0.07266891,0.025455484,,,,,,,,,,,,,,,,, -44600,0.080555186,0.026139313,,,,,,,,,,,,,,,,, -44700,0.08159547,0.026038235,,,,,,,,,,,,,,,,, -44800,0.07667137,0.026848732,,,,,,,,,,,,,,,,, -44870,,,0.9926388263702391,0.023288769647479,0.5691611133646597,0.9871628880500792,0.0448407232761383,0.295265385933204,43793.0,0.9861940741539,0.0481225363910198,0.272632916276983,43793.0,14425.444154977798,22699.416808366776,14425.444154977798,8270.51467037201,2.20216703414917,0.0 -44900,0.074596904,0.026569296,,,,,,,,,,,,,,,,, -45000,0.07465802,0.024963195,,,,,,,,,,,,,,,,, -45100,0.0922831,0.026831824,,,,,,,,,,,,,,,,, -45200,0.08282595,0.02426951,,,,,,,,,,,,,,,,, -45300,0.08439303,0.025533251,,,,,,,,,,,,,,,,, -45400,0.08115992,0.026685763,,,,,,,,,,,,,,,,, -45500,0.08096279,0.022946496,,,,,,,,,,,,,,,,, -45600,0.06792185,0.024859553,,,,,,,,,,,,,,,,, -45615,,,0.9928351044654846,0.02258169837296,0.5853443199422982,0.9871637225151062,0.0455134995281696,0.2881334524893654,43793.0,0.9862799644470216,0.0486726686358451,0.2769721109562906,43793.0,14665.592813968658,23065.945892572403,14665.592813968658,8396.839224815369,2.237090587615967,0.0 -45700,0.086583994,0.024787096,,,,,,,,,,,,,,,,, -45800,0.08169097,0.027176041,,,,,,,,,,,,,,,,, -45900,0.08314798,0.025596872,,,,,,,,,,,,,,,,, -46000,0.06955796,0.023767138,,,,,,,,,,,,,,,,, -46100,0.071770854,0.025170315,,,,,,,,,,,,,,,,, -46200,0.08893598,0.028488046,,,,,,,,,,,,,,,,, -46300,0.08689109,0.026694955,,,,,,,,,,,,,,,,, -46353,,,0.993068516254425,0.0219520926475524,0.6135675762189512,0.9870699644088744,0.0454390197992324,0.2857170722059383,43793.0,0.9861839413642884,0.0484987720847129,0.2734909871848885,43793.0,14905.821466445925,23437.383046627045,14905.821466445925,8527.988328695297,2.2753024101257324,0.0 -46400,0.0801029,0.020646248,,,,,,,,,,,,,,,,, -46500,0.088606104,0.02624924,,,,,,,,,,,,,,,,, -46600,0.07611486,0.025149893,,,,,,,,,,,,,,,,, -46700,0.08210799,0.025114143,,,,,,,,,,,,,,,,, -46800,0.073600225,0.022218386,,,,,,,,,,,,,,,,, -46900,0.093318865,0.025279937,,,,,,,,,,,,,,,,, -47000,0.087129675,0.024920817,,,,,,,,,,,,,,,,, -47091,,,0.9933658838272096,0.021166056394577,0.6093570684686115,0.9869733452796936,0.0454302765429019,0.2875142031287636,43793.0,0.98616623878479,0.0482204817235469,0.2768215463962636,43793.0,15145.786823511124,23804.64868044853,15145.786823511124,8655.228150367737,2.3150885105133057,0.0 -47100,0.09435835,0.025538886,,,,,,,,,,,,,,,,, -47200,0.073666364,0.024169158,,,,,,,,,,,,,,,,, -47300,0.09012209,0.02639059,,,,,,,,,,,,,,,,, -47400,0.0780112,0.022401204,,,,,,,,,,,,,,,,, -47500,0.083213836,0.023738818,,,,,,,,,,,,,,,,, -47600,0.09266555,0.021733318,,,,,,,,,,,,,,,,, -47700,0.098511554,0.024539644,,,,,,,,,,,,,,,,, -47800,0.076430485,0.022159731,,,,,,,,,,,,,,,,, -47838,,,0.9935070872306824,0.0205803215503692,0.6371483651392227,0.987105667591095,0.0456289239227771,0.2910954688285329,43793.0,0.9862921833992004,0.0486597009003162,0.2697960378710056,43793.0,15385.791709899902,24176.136302232742,15385.791709899902,8786.653314113617,2.3512327671051025,0.0 -47900,0.09872333,0.026685892,,,,,,,,,,,,,,,,, -48000,0.08091417,0.024721634,,,,,,,,,,,,,,,,, -48100,0.095614456,0.026290568,,,,,,,,,,,,,,,,, -48200,0.081773244,0.023458458,,,,,,,,,,,,,,,,, -48300,0.07748015,0.021460673,,,,,,,,,,,,,,,,, -48400,0.07900109,0.024776595,,,,,,,,,,,,,,,,, -48500,0.09893757,0.028130006,,,,,,,,,,,,,,,,, -48582,,,0.993273913860321,0.0214224010705947,0.6035744046576249,0.9871032238006592,0.0457254275679588,0.2912422849534083,43793.0,0.986214280128479,0.0486755184829235,0.2777921732487484,43793.0,15625.98142337799,24541.21181440353,15625.98142337799,8911.482343435287,2.387014865875244,0.0 -48600,0.09294232,0.025350822,,,,,,,,,,,,,,,,, -48700,0.087652266,0.024622645,,,,,,,,,,,,,,,,, -48800,0.088909864,0.02700293,,,,,,,,,,,,,,,,, -48900,0.09740871,0.02743679,,,,,,,,,,,,,,,,, -49000,0.097422086,0.02643195,,,,,,,,,,,,,,,,, -49100,0.08947923,0.024755456,,,,,,,,,,,,,,,,, -49200,0.112278715,0.024275245,,,,,,,,,,,,,,,,, -49300,0.08439115,0.02384987,,,,,,,,,,,,,,,,, -49329,,,0.9930155873298644,0.02194794267416,0.6046991869074054,0.9870277047157288,0.0460869930684566,0.2934004400441242,43793.0,0.986250936985016,0.0489511005580425,0.2742838969821881,43793.0,15866.042525291445,24903.049714803696,15866.042525291445,9033.203171491625,2.4227333068847656,0.0 -49400,0.0841018,0.023965511,,,,,,,,,,,,,,,,, -49500,0.09802627,0.025407879,,,,,,,,,,,,,,,,, -49600,0.09250523,0.022563182,,,,,,,,,,,,,,,,, -49700,0.07693263,0.022623623,,,,,,,,,,,,,,,,, -49800,0.100766316,0.026276639,,,,,,,,,,,,,,,,, -49900,0.09344973,0.02389084,,,,,,,,,,,,,,,,, -50000,0.098119386,0.025753638,,,,,,,,,,,,,,,,, -50061,,,0.9932344555854796,0.0213817190378904,0.6176617595712219,0.9871576428413392,0.0460498519241809,0.293558521584603,43793.0,0.9863073229789734,0.0490118525922298,0.2727283348585548,43793.0,16106.172878026962,25268.836081027985,16106.172878026962,9158.79671406746,2.4613735675811768,0.0 -50100,0.08674874,0.022328047,,,,,,,,,,,,,,,,, -50200,0.09024373,0.025069695,,,,,,,,,,,,,,,,, -50300,0.10045027,0.023098296,,,,,,,,,,,,,,,,, -50400,0.08561498,0.021926712,,,,,,,,,,,,,,,,, -50500,0.08412396,0.021846235,,,,,,,,,,,,,,,,, -50600,0.097555794,0.024049314,,,,,,,,,,,,,,,,, -50700,0.10251906,0.02306701,,,,,,,,,,,,,,,,, -50800,0.08880834,0.02497109,,,,,,,,,,,,,,,,, -50801,,,0.9931604862213136,0.0216441955417394,0.6141723415553106,0.987022876739502,0.0459070391952991,0.2900934698570349,43793.0,0.9861944913864136,0.0488013066351413,0.2790345826699184,43793.0,16346.374148607254,25637.98314404488,16346.374148607254,9287.683853626251,2.49810791015625,0.0 -50900,0.09036133,0.024553772,,,,,,,,,,,,,,,,, -51000,0.08944089,0.023180528,,,,,,,,,,,,,,,,, -51100,0.094685465,0.024245463,,,,,,,,,,,,,,,,, -51200,0.09485136,0.023915755,,,,,,,,,,,,,,,,, -51300,0.08576771,0.023777831,,,,,,,,,,,,,,,,, -51400,0.091231525,0.022139657,,,,,,,,,,,,,,,,, -51500,0.09163518,0.023758566,,,,,,,,,,,,,,,,, -51543,,,0.9934256076812744,0.0206026528030633,0.6302478098406785,0.9871271848678588,0.0461113080382347,0.2967554986096878,43793.0,0.9862997531890868,0.0489075742661953,0.2789417867146801,43793.0,16586.50678873062,25999.830745220184,16586.50678873062,9409.34234046936,2.534095764160156,0.0 -51600,0.092965916,0.021412924,,,,,,,,,,,,,,,,, -51700,0.098732725,0.024063766,,,,,,,,,,,,,,,,, -51800,0.09074806,0.022602014,,,,,,,,,,,,,,,,, -51900,0.10064943,0.024614861,,,,,,,,,,,,,,,,, -52000,0.106142625,0.025217585,,,,,,,,,,,,,,,,, -52100,0.086285494,0.02151814,,,,,,,,,,,,,,,,, -52200,0.10162625,0.025083568,,,,,,,,,,,,,,,,, -52287,,,0.9934813976287842,0.0203452669084072,0.6378640103243176,0.9871352910995485,0.0466223023831844,0.2948339264460545,43793.0,0.9863343238830566,0.0495630428194999,0.2742401976552608,43793.0,16826.707840442657,26363.245794296265,16826.707840442657,9532.496723175049,2.5732104778289795,0.0 -52300,0.10205488,0.026203582,,,,,,,,,,,,,,,,, -52400,0.09330864,0.0222066,,,,,,,,,,,,,,,,, -52500,0.10758398,0.023471715,,,,,,,,,,,,,,,,, -52600,0.09546933,0.024095837,,,,,,,,,,,,,,,,, -52700,0.09703113,0.023019832,,,,,,,,,,,,,,,,, -52800,0.10064024,0.023564382,,,,,,,,,,,,,,,,, -52900,0.09517564,0.021273997,,,,,,,,,,,,,,,,, -53000,0.09066067,0.022091562,,,,,,,,,,,,,,,,, -53032,,,0.9936859011650084,0.0199180506169796,0.6607255578238254,0.9869903922080994,0.0459285601973533,0.2948226204101228,43793.0,0.9861392974853516,0.0489125773310661,0.2728305192982156,43793.0,17066.75954055786,26727.54792380333,17066.75954055786,9656.69024682045,2.609474897384644,0.0 -53100,0.10093648,0.019375246,,,,,,,,,,,,,,,,, -53200,0.100351326,0.02257114,,,,,,,,,,,,,,,,, -53300,0.09416707,0.023450302,,,,,,,,,,,,,,,,, -53400,0.114279434,0.024399271,,,,,,,,,,,,,,,,, -53500,0.1011823,0.022825183,,,,,,,,,,,,,,,,, -53600,0.09532712,0.023645326,,,,,,,,,,,,,,,,, -53700,0.10159834,0.02344681,,,,,,,,,,,,,,,,, -53774,,,0.9940152168273926,0.0189543068408966,0.6742533910555928,0.9870938658714294,0.0462624207139015,0.3016247706823539,43793.0,0.986177623271942,0.0494026094675064,0.2757804694010088,43793.0,17306.817973852158,27094.53839492798,17306.817973852158,9783.560585260391,2.648601531982422,0.0 -53800,0.1130606,0.025189715,,,,,,,,,,,,,,,,, -53900,0.10940995,0.02345692,,,,,,,,,,,,,,,,, -54000,0.10430935,0.021124039,,,,,,,,,,,,,,,,, -54100,0.09584195,0.020822216,,,,,,,,,,,,,,,,, -54200,0.10024516,0.021356633,,,,,,,,,,,,,,,,, -54300,0.097076826,0.0214946,,,,,,,,,,,,,,,,, -54400,0.10554901,0.02552627,,,,,,,,,,,,,,,,, -54500,0.10447649,0.024471788,,,,,,,,,,,,,,,,, -54517,,,0.9941158294677734,0.0187518820166587,0.6797247754253466,0.98710036277771,0.0464842617511749,0.2955374417131669,43793.0,0.98629891872406,0.0495272688567638,0.2791162088408168,43793.0,17546.784031391144,27461.929767370224,17546.784031391144,9910.918627500534,2.690823554992676,0.0 -54600,0.09224114,0.02136774,,,,,,,,,,,,,,,,, -54700,0.0998375,0.019420454,,,,,,,,,,,,,,,,, -54800,0.117096685,0.024582239,,,,,,,,,,,,,,,,, -54900,0.11251938,0.021873,,,,,,,,,,,,,,,,, -55000,0.11470412,0.022298427,,,,,,,,,,,,,,,,, -55100,0.08777337,0.021238944,,,,,,,,,,,,,,,,, -55200,0.10239405,0.021521876,,,,,,,,,,,,,,,,, -55260,,,0.9938734769821168,0.0192361325025558,0.6635183115064012,0.9870488047599792,0.0465531982481479,0.2942804772799886,43793.0,0.9862226843833924,0.0497312098741531,0.275614268267156,43793.0,17786.882937908173,27824.8903477192,17786.882937908173,10033.722482919691,2.7282874584198,0.0 -55300,0.10880109,0.02349908,,,,,,,,,,,,,,,,, -55400,0.1006396,0.023960797,,,,,,,,,,,,,,,,, -55500,0.09998177,0.021908259,,,,,,,,,,,,,,,,, -55600,0.11459193,0.02336416,,,,,,,,,,,,,,,,, -55700,0.101033546,0.02097466,,,,,,,,,,,,,,,,, -55800,0.1306841,0.023832964,,,,,,,,,,,,,,,,, -55900,0.11825276,0.023853384,,,,,,,,,,,,,,,,, -56000,0.096548565,0.021700503,,,,,,,,,,,,,,,,, -56008,,,0.9936686158180236,0.0198822543025016,0.6484135803660225,0.9869859218597412,0.0469512417912483,0.2881964914675938,43793.0,0.9861831068992616,0.0500927679240703,0.2752872680004323,43793.0,18027.09535241127,28190.45258665085,18027.09535241127,10159.011183738708,2.768742561340332,0.0 -56100,0.115245596,0.02332403,,,,,,,,,,,,,,,,, -56200,0.1026492,0.021388618,,,,,,,,,,,,,,,,, -56300,0.13143174,0.025504334,,,,,,,,,,,,,,,,, -56400,0.11248813,0.02243893,,,,,,,,,,,,,,,,, -56500,0.11084453,0.024866382,,,,,,,,,,,,,,,,, -56600,0.11093718,0.022300852,,,,,,,,,,,,,,,,, -56700,0.10757629,0.023588786,,,,,,,,,,,,,,,,, -56756,,,0.9936248660087584,0.0199340619146823,0.6411082926218209,0.9869481325149536,0.0469044186174869,0.2937136247289022,43793.0,0.9860533475875854,0.0501333326101303,0.274866405576504,43793.0,18267.287185430527,28557.286256313324,18267.287185430527,10285.594108343124,2.806877851486206,0.0 -56800,0.102888696,0.020930072,,,,,,,,,,,,,,,,, -56900,0.09576112,0.019604398,,,,,,,,,,,,,,,,, -57000,0.09526491,0.019607114,,,,,,,,,,,,,,,,, -57100,0.12062452,0.023006339,,,,,,,,,,,,,,,,, -57200,0.11123075,0.021840058,,,,,,,,,,,,,,,,, -57300,0.10070441,0.021653827,,,,,,,,,,,,,,,,, -57400,0.0958113,0.023329545,,,,,,,,,,,,,,,,, -57497,,,0.9938535094261168,0.0192855987697839,0.6639852026299888,0.9869769811630248,0.0477448664605617,0.2913151944361254,43793.0,0.9861485362052916,0.0507963001728057,0.2768490884306481,43793.0,18507.45695638657,28925.572428703308,18507.45695638657,10413.653168201448,2.843985319137573,0.0 -57500,0.11663096,0.022510076,,,,,,,,,,,,,,,,, -57600,0.096449606,0.020010736,,,,,,,,,,,,,,,,, -57700,0.124631986,0.023081923,,,,,,,,,,,,,,,,, -57800,0.14523287,0.025066858,,,,,,,,,,,,,,,,, -57900,0.10411646,0.019940708,,,,,,,,,,,,,,,,, -58000,0.13353723,0.022744397,,,,,,,,,,,,,,,,, -58100,0.122323945,0.023633445,,,,,,,,,,,,,,,,, -58200,0.10269016,0.020348402,,,,,,,,,,,,,,,,, -58239,,,0.9939899444580078,0.0187293533235788,0.6713579688097611,0.9869345426559448,0.0474524535238742,0.293276872576415,43793.0,0.9861140251159668,0.0504525229334831,0.2746775260475992,43793.0,18747.58773970604,29288.24400830269,18747.58773970604,10536.1366751194,2.8811538219451904,0.0 -58300,0.09992269,0.019827316,,,,,,,,,,,,,,,,, -58400,0.11968534,0.0227458,,,,,,,,,,,,,,,,, -58500,0.09326882,0.02028872,,,,,,,,,,,,,,,,, -58600,0.119446196,0.022240495,,,,,,,,,,,,,,,,, -58700,0.11072143,0.022492992,,,,,,,,,,,,,,,,, -58800,0.106713645,0.020757696,,,,,,,,,,,,,,,,, -58900,0.10261995,0.02103893,,,,,,,,,,,,,,,,, -58988,,,0.9940062165260316,0.0185742657631635,0.6838410034958001,0.9871000051498412,0.0477108918130397,0.2963570922631796,43793.0,0.9863199591636658,0.0507941544055938,0.2780661761896076,43793.0,18987.742092847824,29653.516090154648,18987.742092847824,10661.195743560793,2.9196202754974365,0.0 -59000,0.14876981,0.023116723,,,,,,,,,,,,,,,,, -59100,0.108682334,0.021305658,,,,,,,,,,,,,,,,, -59200,0.13280463,0.022299837,,,,,,,,,,,,,,,,, -59300,0.119247265,0.021386268,,,,,,,,,,,,,,,,, -59400,0.12017111,0.022141198,,,,,,,,,,,,,,,,, -59500,0.13071059,0.02342357,,,,,,,,,,,,,,,,, -59600,0.11479117,0.021131102,,,,,,,,,,,,,,,,, -59700,0.10795259,0.02014628,,,,,,,,,,,,,,,,, -59736,,,0.9945381283760072,0.017136039212346,0.7119890023519173,0.9870768189430236,0.0477268025279045,0.2921333089367812,43793.0,0.9861990809440612,0.0508352927863597,0.2754505039140115,43793.0,19227.76380634308,30018.46971058845,19227.76380634308,10786.069747924805,2.9572134017944336,0.0 -59800,0.11743754,0.022418533,,,,,,,,,,,,,,,,, -59900,0.11602489,0.02234749,,,,,,,,,,,,,,,,, -60000,0.13228413,0.023147756,,,,,,,,,,,,,,,,, -60100,0.11296784,0.02029401,,,,,,,,,,,,,,,,, -60200,0.12664723,0.021871027,,,,,,,,,,,,,,,,, -60300,0.11863322,0.023709144,,,,,,,,,,,,,,,,, -60400,0.11819214,0.02043536,,,,,,,,,,,,,,,,, -60478,,,0.9946551322937012,0.0168867819011211,0.7121866182580209,0.9869863390922546,0.0479784160852432,0.2957888682327104,43793.0,0.9862260818481444,0.0509336069226264,0.2769695394517674,43793.0,19467.88423514366,30385.91834759712,19467.88423514366,10913.33787703514,2.99572229385376,0.0 -60500,0.106869966,0.0194703,,,,,,,,,,,,,,,,, -60600,0.121072106,0.019845711,,,,,,,,,,,,,,,,, -60700,0.10761094,0.020295385,,,,,,,,,,,,,,,,, -60800,0.12603122,0.019456446,,,,,,,,,,,,,,,,, -60900,0.11049428,0.021869974,,,,,,,,,,,,,,,,, -61000,0.12154819,0.018718045,,,,,,,,,,,,,,,,, -61100,0.110859394,0.019223714,,,,,,,,,,,,,,,,, -61200,0.12626882,0.021126604,,,,,,,,,,,,,,,,, -61226,,,0.994626522064209,0.0167669877409935,0.7171335605289063,0.9870147109031676,0.0480419769883155,0.2928131691734593,43793.0,0.9861772060394288,0.051029447466135,0.2759675212331038,43793.0,19707.84122633934,30746.321582317352,19707.84122633934,11033.724110364914,3.034928560256958,0.0 -61300,0.12639749,0.020893343,,,,,,,,,,,,,,,,, -61400,0.12741557,0.01997554,,,,,,,,,,,,,,,,, -61500,0.11791552,0.018816894,,,,,,,,,,,,,,,,, -61600,0.12386078,0.019901732,,,,,,,,,,,,,,,,, -61700,0.14312354,0.021702629,,,,,,,,,,,,,,,,, -61800,0.12039027,0.019835526,,,,,,,,,,,,,,,,, -61900,0.10956621,0.019672332,,,,,,,,,,,,,,,,, -61962,,,0.9944984316825868,0.0170986913144588,0.7176627280464143,0.9870285391807556,0.0489490553736686,0.294612297398712,43793.0,0.98625510931015,0.0519960150122642,0.274682914531157,43793.0,19948.08852863312,31111.174551725388,19948.08852863312,11158.269618988035,3.074516534805298,0.0 -62000,0.12028195,0.021487849,,,,,,,,,,,,,,,,, -62100,0.14689276,0.022414388,,,,,,,,,,,,,,,,, -62200,0.13390027,0.01942994,,,,,,,,,,,,,,,,, -62300,0.1102955,0.016417667,,,,,,,,,,,,,,,,, -62400,0.11535753,0.01893431,,,,,,,,,,,,,,,,, -62500,0.11200384,0.020696552,,,,,,,,,,,,,,,,, -62600,0.13513845,0.02232474,,,,,,,,,,,,,,,,, -62700,0.12124003,0.020708362,,,,,,,,,,,,,,,,, -62706,,,0.9945216178894044,0.0171154234558343,0.7133070344489623,0.9869871139526368,0.0483480617403984,0.2968437849571173,43793.0,0.9861464500427246,0.0516576319932937,0.2757454189249902,43793.0,20188.32045006752,31476.630955457687,20188.32045006752,11283.433761835098,3.114187479019165,0.0 -62800,0.11877267,0.019921292,,,,,,,,,,,,,,,,, -62900,0.1167576,0.018708028,,,,,,,,,,,,,,,,, -63000,0.11674932,0.018047985,,,,,,,,,,,,,,,,, -63100,0.15906179,0.02082435,,,,,,,,,,,,,,,,, -63200,0.12594849,0.019534778,,,,,,,,,,,,,,,,, -63300,0.116520666,0.01964951,,,,,,,,,,,,,,,,, -63400,0.124616615,0.020892981,,,,,,,,,,,,,,,,, -63446,,,0.9945242404937744,0.0169851873070001,0.7093867418780908,0.9869940280914308,0.0486472472548484,0.2948966868870458,43793.0,0.9861767888069152,0.0518626347184181,0.2775576912809816,43793.0,20428.48164987564,31839.75074863434,20428.48164987564,11406.333416223526,3.152745485305786,0.0 -63500,0.12810132,0.019645562,,,,,,,,,,,,,,,,, -63600,0.11836298,0.021069435,,,,,,,,,,,,,,,,, -63700,0.1317015,0.020582253,,,,,,,,,,,,,,,,, -63800,0.12982187,0.019119548,,,,,,,,,,,,,,,,, -63900,0.11802194,0.01987759,,,,,,,,,,,,,,,,, -64000,0.13945134,0.022025311,,,,,,,,,,,,,,,,, -64100,0.12474011,0.020549556,,,,,,,,,,,,,,,,, -64189,,,0.9944551587104796,0.0170583687722682,0.7121639624300242,0.9870054125785828,0.048753298819065,0.297605083648721,43793.0,0.9862075448036194,0.0521550811827182,0.2738652667227593,43793.0,20668.63866043091,32205.007052898407,20668.63866043091,11531.373598575592,3.19139051437378,0.0 -64200,0.14715375,0.021418985,,,,,,,,,,,,,,,,, -64300,0.13362852,0.02044715,,,,,,,,,,,,,,,,, -64400,0.11335778,0.017463956,,,,,,,,,,,,,,,,, -64500,0.13464455,0.019247616,,,,,,,,,,,,,,,,, -64600,0.1315618,0.021450825,,,,,,,,,,,,,,,,, -64700,0.1354865,0.018188624,,,,,,,,,,,,,,,,, -64800,0.12912536,0.018823657,,,,,,,,,,,,,,,,, -64900,0.15663417,0.02228233,,,,,,,,,,,,,,,,, -64933,,,0.9945725202560424,0.0167442299425601,0.7202841801572984,0.9870171546936036,0.0490759573876857,0.2939358425075371,43793.0,0.9862125515937804,0.0522994175553321,0.2767880296537022,43793.0,20908.890762090683,32568.69940972328,20908.890762090683,11654.752810955048,3.231707811355591,0.0 -65000,0.11619514,0.018661138,,,,,,,,,,,,,,,,, -65100,0.13474555,0.018209225,,,,,,,,,,,,,,,,, -65200,0.12163009,0.01721901,,,,,,,,,,,,,,,,, -65300,0.12880714,0.020169532,,,,,,,,,,,,,,,,, -65400,0.13478537,0.020191688,,,,,,,,,,,,,,,,, -65500,0.12542927,0.019370114,,,,,,,,,,,,,,,,, -65600,0.1263778,0.018704522,,,,,,,,,,,,,,,,, -65686,,,0.9948503971099854,0.0159808285534381,0.7366013362986004,0.9870244860649108,0.0493682250380516,0.2944331164330976,43793.0,0.9862125515937804,0.0525530837476253,0.276474986311574,43793.0,21148.8704726696,32929.93105196953,21148.8704726696,11775.9447824955,3.270869255065918,0.0 -65700,0.1452788,0.01891923,,,,,,,,,,,,,,,,, -65800,0.12394847,0.018124271,,,,,,,,,,,,,,,,, -65900,0.13341402,0.019518865,,,,,,,,,,,,,,,,, -66000,0.13489898,0.02089396,,,,,,,,,,,,,,,,, -66100,0.12010876,0.016242532,,,,,,,,,,,,,,,,, -66200,0.14508201,0.019572355,,,,,,,,,,,,,,,,, -66300,0.1275787,0.019230565,,,,,,,,,,,,,,,,, -66400,0.14724685,0.020198649,,,,,,,,,,,,,,,,, -66427,,,0.9951132535934448,0.0153884962201118,0.7601626246504376,0.9869201183319092,0.0494905523955822,0.2930955641549069,43793.0,0.986159086227417,0.0525974743068218,0.2758683298245158,43793.0,21389.03907442093,33292.05467629433,21389.03907442093,11897.835361719131,3.3115804195404053,0.0 -66500,0.12243791,0.018358584,,,,,,,,,,,,,,,,, -66600,0.1456346,0.019520013,,,,,,,,,,,,,,,,, -66700,0.14816214,0.019748492,,,,,,,,,,,,,,,,, -66800,0.13585512,0.019053899,,,,,,,,,,,,,,,,, -66900,0.14222,0.019785393,,,,,,,,,,,,,,,,, -67000,0.13063502,0.01965772,,,,,,,,,,,,,,,,, -67100,0.14693342,0.019094542,,,,,,,,,,,,,,,,, -67167,,,0.9952598214149476,0.0148692950606346,0.7578312354119906,0.9869258403778076,0.0494728460907936,0.2936498573486638,43793.0,0.9861615896224976,0.052687082439661,0.2765470271540292,43793.0,21629.091300964355,33652.41530036926,21629.091300964355,12018.075865268707,3.3581244945526123,0.0 -67200,0.11824218,0.017119044,,,,,,,,,,,,,,,,, -67300,0.15858367,0.02058924,,,,,,,,,,,,,,,,, -67400,0.12195339,0.019094063,,,,,,,,,,,,,,,,, -67500,0.1371667,0.021305762,,,,,,,,,,,,,,,,, -67600,0.13660473,0.01816765,,,,,,,,,,,,,,,,, -67700,0.14767183,0.019107437,,,,,,,,,,,,,,,,, -67800,0.13559891,0.019789563,,,,,,,,,,,,,,,,, -67900,0.13186875,0.017075056,,,,,,,,,,,,,,,,, -67912,,,0.9953340888023376,0.014705266803503,0.7636951762040912,0.9869895577430724,0.0496973432600498,0.294791959712245,43793.0,0.9861228466033936,0.0529684200882911,0.2736527050416096,43793.0,21869.05692052841,34013.61874079704,21869.05692052841,12139.252855539322,3.3980062007904053,0.0 -68000,0.13433991,0.018096145,,,,,,,,,,,,,,,,, -68100,0.12732048,0.018818047,,,,,,,,,,,,,,,,, -68200,0.14275615,0.019917399,,,,,,,,,,,,,,,,, -68300,0.13529493,0.019513141,,,,,,,,,,,,,,,,, -68400,0.151628,0.0213931,,,,,,,,,,,,,,,,, -68500,0.14931284,0.020319885,,,,,,,,,,,,,,,,, -68600,0.1457108,0.01866759,,,,,,,,,,,,,,,,, -68664,,,0.995245397090912,0.014989978633821,0.754445749752948,0.9869197607040404,0.0498844981193542,0.2920263811403023,43793.0,0.9861595034599304,0.0530820786952972,0.2710669890759926,43793.0,22109.022665262222,34374.08118200302,22109.022665262222,12259.688798666,3.437978029251098,0.0 -68700,0.1464466,0.017460253,,,,,,,,,,,,,,,,, -68800,0.15412375,0.020481905,,,,,,,,,,,,,,,,, -68900,0.14838359,0.018894503,,,,,,,,,,,,,,,,, -69000,0.12315923,0.016686732,,,,,,,,,,,,,,,,, -69100,0.1483387,0.019722607,,,,,,,,,,,,,,,,, -69200,0.122972734,0.017488627,,,,,,,,,,,,,,,,, -69300,0.15947407,0.01737808,,,,,,,,,,,,,,,,, -69400,0.13805062,0.018784229,,,,,,,,,,,,,,,,, -69412,,,0.995061218738556,0.0152342738583683,0.7505739178886142,0.9869449138641356,0.0500610433518886,0.2925591697654719,43793.0,0.9861894249916076,0.0533779002726078,0.2722586746507549,43793.0,22349.240227222443,34736.90176987648,22349.240227222443,12382.230098962784,3.4789209365844727,0.0 -69500,0.1340237,0.015950039,,,,,,,,,,,,,,,,, -69600,0.14054473,0.016815294,,,,,,,,,,,,,,,,, -69700,0.14459631,0.021370873,,,,,,,,,,,,,,,,, -69800,0.14317223,0.018816125,,,,,,,,,,,,,,,,, -69900,0.13126007,0.01867873,,,,,,,,,,,,,,,,, -70000,0.14357837,0.019724185,,,,,,,,,,,,,,,,, -70100,0.13392006,0.017066667,,,,,,,,,,,,,,,,, -70165,,,0.9951221942901612,0.0151732191443443,0.7626114380669714,0.9870163798332214,0.0502706728875637,0.2902962107963454,43793.0,0.9861894249916076,0.0535657070577144,0.2755686078935777,43793.0,22589.316182613373,35094.25100970268,22589.316182613373,12499.442770004272,3.5188333988189697,0.0 -70200,0.14542177,0.020737572,,,,,,,,,,,,,,,,, -70300,0.14076306,0.018674785,,,,,,,,,,,,,,,,, -70400,0.14063689,0.017218197,,,,,,,,,,,,,,,,, -70500,0.12111331,0.01645672,,,,,,,,,,,,,,,,, -70600,0.1693447,0.019324582,,,,,,,,,,,,,,,,, -70700,0.14507753,0.018191459,,,,,,,,,,,,,,,,, -70800,0.14011692,0.01791651,,,,,,,,,,,,,,,,, -70900,0.14947581,0.018326646,,,,,,,,,,,,,,,,, -70912,,,0.994987726211548,0.0154399583116173,0.7428052908604961,0.9870102405548096,0.0504793338477611,0.2939184261839505,43793.0,0.98615825176239,0.053830362856388,0.2747486902004241,43793.0,22829.283252239227,35451.13889479637,22829.283252239227,12616.302125930786,3.5597028732299805,0.0 -71000,0.13725351,0.01908394,,,,,,,,,,,,,,,,, -71100,0.1411908,0.017061062,,,,,,,,,,,,,,,,, -71200,0.1423791,0.01838028,,,,,,,,,,,,,,,,, -71300,0.15686248,0.01922439,,,,,,,,,,,,,,,,, -71400,0.15060507,0.016838439,,,,,,,,,,,,,,,,, -71500,0.14746253,0.018456265,,,,,,,,,,,,,,,,, -71600,0.13780206,0.01745843,,,,,,,,,,,,,,,,, -71663,,,0.9952868819236756,0.0145252058282494,0.7632115012875428,0.9869834780693054,0.0503106638789176,0.2912756458337494,43793.0,0.9861262440681458,0.0536865107715129,0.2744871260934867,43793.0,23069.419085025787,35805.873645067215,23069.419085025787,12730.839373111725,3.6005778312683105,0.0 -71700,0.1403848,0.017615655,,,,,,,,,,,,,,,,, -71800,0.14543186,0.018788109,,,,,,,,,,,,,,,,, -71900,0.15434821,0.01913658,,,,,,,,,,,,,,,,, -72000,0.12648702,0.014634303,,,,,,,,,,,,,,,,, -72100,0.13729993,0.018065877,,,,,,,,,,,,,,,,, -72200,0.1375416,0.018166563,,,,,,,,,,,,,,,,, -72300,0.13287343,0.018831216,,,,,,,,,,,,,,,,, -72400,0.17250338,0.01935542,,,,,,,,,,,,,,,,, -72409,,,0.995356559753418,0.0144841000437736,0.7616841179855216,0.9870256781578064,0.0505304187536239,0.2914941165453763,43793.0,0.9861717224121094,0.0539077967405319,0.2749625676471574,43793.0,23309.60293293,36168.21115899086,23309.60293293,12852.929856061935,3.643061876296997,0.0 -72500,0.14079013,0.020621013,,,,,,,,,,,,,,,,, -72600,0.15647158,0.018280681,,,,,,,,,,,,,,,,, -72700,0.17056994,0.01935652,,,,,,,,,,,,,,,,, -72800,0.14912213,0.01766504,,,,,,,,,,,,,,,,, -72900,0.15490414,0.019240884,,,,,,,,,,,,,,,,, -73000,0.13986179,0.016822645,,,,,,,,,,,,,,,,, -73100,0.15841125,0.019226447,,,,,,,,,,,,,,,,, -73150,,,0.9953998923301696,0.0142793525010347,0.778936084156196,0.9870293140411376,0.0505202263593673,0.2940463974337953,43793.0,0.9861780405044556,0.0538734272122383,0.2740218947411233,43793.0,23549.57167220116,36522.17186427117,23549.57167220116,12966.86022424698,3.684103012084961,0.0 -73200,0.12748003,0.018731702,,,,,,,,,,,,,,,,, -73300,0.15375555,0.018866284,,,,,,,,,,,,,,,,, -73400,0.13958219,0.017056873,,,,,,,,,,,,,,,,, -73500,0.15955038,0.020952793,,,,,,,,,,,,,,,,, -73600,0.14258046,0.018509727,,,,,,,,,,,,,,,,, -73700,0.1349315,0.017185392,,,,,,,,,,,,,,,,, -73800,0.14869818,0.018934347,,,,,,,,,,,,,,,,, -73900,0.15059389,0.017527197,,,,,,,,,,,,,,,,, -73901,,,0.9956340789794922,0.013805765658617,0.7875806631898318,0.9869623780250548,0.0504661202430725,0.2937153218722242,43793.0,0.9861649870872498,0.0538687482476234,0.2766866151141615,43793.0,23789.65530896187,36880.84664463997,23789.65530896187,13085.3903298378,3.724265813827514,0.0 -74000,0.1438123,0.017291822,,,,,,,,,,,,,,,,, -74100,0.13824901,0.018104913,,,,,,,,,,,,,,,,, -74200,0.16188417,0.01952373,,,,,,,,,,,,,,,,, -74300,0.13104826,0.01761491,,,,,,,,,,,,,,,,, -74400,0.14215565,0.018069208,,,,,,,,,,,,,,,,, -74500,0.13218968,0.017250795,,,,,,,,,,,,,,,,, -74600,0.14445205,0.016820822,,,,,,,,,,,,,,,,, -74646,,,0.995647132396698,0.0137284025549888,0.7831579439134501,0.9869558811187744,0.0505936332046985,0.2927131763169233,43793.0,0.986185610294342,0.0539115630090236,0.274894312465339,43793.0,24029.66795539856,37240.47275662422,24029.66795539856,13204.942134141922,3.765378952026367,0.0 -74700,0.14350085,0.01778135,,,,,,,,,,,,,,,,, -74800,0.14698592,0.017665016,,,,,,,,,,,,,,,,, -74900,0.15157866,0.019760339,,,,,,,,,,,,,,,,, -75000,0.14595981,0.018457677,,,,,,,,,,,,,,,,, -75100,0.1601732,0.018846776,,,,,,,,,,,,,,,,, -75200,0.14631167,0.017379725,,,,,,,,,,,,,,,,, -75300,0.13415878,0.0149072455,,,,,,,,,,,,,,,,, -75394,,,0.9955528378486632,0.0138159431517124,0.782723698481396,0.9869920015335084,0.0507420748472213,0.2923844320138635,43793.0,0.9862277507781982,0.0540798865258693,0.2759888610790896,43793.0,24269.63495206833,37600.3379855156,24269.63495206833,13324.7779610157,3.806945562362671,0.0 -75400,0.14438239,0.018019002,,,,,,,,,,,,,,,,, -75500,0.15121795,0.017230188,,,,,,,,,,,,,,,,, -75600,0.14225589,0.016305735,,,,,,,,,,,,,,,,, -75700,0.15030183,0.020034224,,,,,,,,,,,,,,,,, -75800,0.14075424,0.016296143,,,,,,,,,,,,,,,,, -75900,0.14993316,0.018313842,,,,,,,,,,,,,,,,, -76000,0.21327993,0.01703065,,,,,,,,,,,,,,,,, -76100,0.13152488,0.016037734,,,,,,,,,,,,,,,,, -76147,,,0.995425283908844,0.0141666233539581,0.7730561717837915,0.9869599342346193,0.0507632009685039,0.2921613507528562,43793.0,0.986213445663452,0.0540745370090007,0.2772311098034813,43793.0,24509.7945497036,37961.79363822937,24509.7945497036,13446.011998414991,3.848209857940674,0.0 -76200,0.14215858,0.017909104,,,,,,,,,,,,,,,,, -76300,0.14531842,0.018190326,,,,,,,,,,,,,,,,, -76400,0.14488636,0.017580865,,,,,,,,,,,,,,,,, -76500,0.1439435,0.019870352,,,,,,,,,,,,,,,,, -76600,0.13490915,0.016960885,,,,,,,,,,,,,,,,, -76700,0.14812812,0.018527914,,,,,,,,,,,,,,,,, -76800,0.14478098,0.016684666,,,,,,,,,,,,,,,,, -76897,,,0.9954675436019896,0.0141391856595873,0.7729857674740117,0.9869920015335084,0.0507909543812274,0.2932262344296919,43793.0,0.986204981803894,0.0540654771029949,0.2752977861531578,43793.0,24749.799326896667,38318.488555669785,24749.799326896667,13562.63976407051,3.890085697174072,0.0 -76900,0.15737334,0.017536804,,,,,,,,,,,,,,,,, -77000,0.12565845,0.015454288,,,,,,,,,,,,,,,,, -77100,0.12712856,0.016644126,,,,,,,,,,,,,,,,, -77200,0.1353767,0.014775389,,,,,,,,,,,,,,,,, -77300,0.16505966,0.019801516,,,,,,,,,,,,,,,,, -77400,0.14108348,0.015856676,,,,,,,,,,,,,,,,, -77500,0.14792792,0.017425679,,,,,,,,,,,,,,,,, -77600,0.13717277,0.017126054,,,,,,,,,,,,,,,,, -77645,,,0.9955635666847228,0.0138484183698892,0.789354634106922,0.9869863390922546,0.0506942756474018,0.2951626304122393,43793.0,0.9862138628959656,0.0539667382836341,0.2751514090916538,43793.0,24989.83911180496,38672.74894404411,24989.83911180496,13676.798060655594,3.931835889816284,0.0 -77700,0.14567228,0.01844057,,,,,,,,,,,,,,,,, -77800,0.14984341,0.01871018,,,,,,,,,,,,,,,,, -77900,0.14789855,0.019716872,,,,,,,,,,,,,,,,, -78000,0.16990547,0.019003341,,,,,,,,,,,,,,,,, -78100,0.14414318,0.016084708,,,,,,,,,,,,,,,,, -78200,0.14455721,0.019030482,,,,,,,,,,,,,,,,, -78300,0.15030469,0.018978767,,,,,,,,,,,,,,,,, -78390,,,0.9955982565879822,0.0138391042128205,0.7907666497457357,0.9869810342788696,0.0507110469043254,0.2936789197248196,43793.0,0.9862096309661864,0.0539970993995666,0.2748707703275417,43793.0,25229.79161000252,39027.76487231255,25229.79161000252,13791.799069404602,3.97400426864624,0.0 -78400,0.15395218,0.018684098,,,,,,,,,,,,,,,,, -78500,0.15299627,0.018056588,,,,,,,,,,,,,,,,, -78600,0.13964139,0.016460104,,,,,,,,,,,,,,,,, -78700,0.13729087,0.017660731,,,,,,,,,,,,,,,,, -78800,0.14469056,0.018398058,,,,,,,,,,,,,,,,, -78900,0.14500982,0.017653378,,,,,,,,,,,,,,,,, -79000,0.14051999,0.016123747,,,,,,,,,,,,,,,,, -79100,0.14969605,0.015507028,,,,,,,,,,,,,,,,, -79135,,,0.9955662488937378,0.0138804921880364,0.7777995135357072,0.9869794249534608,0.0507164113223552,0.2937873785465716,43793.0,0.9862079620361328,0.0540053136646747,0.2754306980280074,43793.0,25470.0196492672,39393.80100250244,25470.0196492672,13917.542226552963,4.018395185470581,0.0 -79200,0.14460815,0.019207755,,,,,,,,,,,,,,,,, -79300,0.13312194,0.017859293,,,,,,,,,,,,,,,,, -79400,0.15374802,0.019042,,,,,,,,,,,,,,,,, -79500,0.14198509,0.016955094,,,,,,,,,,,,,,,,, -79600,0.13120522,0.01727662,,,,,,,,,,,,,,,,, -79700,0.12836151,0.015795944,,,,,,,,,,,,,,,,, -79800,0.14625786,0.018692749,,,,,,,,,,,,,,,,, -79868,,,0.9955735206604004,0.0137835601344704,0.7806222712920365,0.9869920015335084,0.0507424473762512,0.2936313270430152,43793.0,0.9862112998962402,0.054036695510149,0.2753682278616062,43793.0,25710.17612195015,39754.10463857651,25710.17612195015,14037.619457006454,4.064361095428467,0.0 -79900,0.13697588,0.017269816,,,,,,,,,,,,,,,,, -80000,0.14614452,0.018035192,,,,,,,,,,,,,,,,, -80100,0.123823,0.016180877,,,,,,,,,,,,,,,,, -80200,0.13712388,0.018028539,,,,,,,,,,,,,,,,, -80300,0.142694,0.016903447,,,,,,,,,,,,,,,,, -80400,0.14943889,0.0179185,,,,,,,,,,,,,,,,, -80500,0.14577807,0.017096775,,,,,,,,,,,,,,,,, -80600,0.15821406,0.017605936,,,,,,,,,,,,,,,,, -80611,,,0.99559223651886,0.0137861659750342,0.7772153678768996,0.9869915843009948,0.0507423616945743,0.2936277846170593,43793.0,0.9862112998962402,0.0540365986526012,0.2754226497068658,43793.0,25950.23030114174,40112.879777908325,25950.23030114174,14156.27874803543,4.105771064758301,0.0 -80700,0.13844903,0.017775571,,,,,,,,,,,,,,,,, -80800,0.14736825,0.017260952,,,,,,,,,,,,,,,,, -80900,0.14022778,0.015855845,,,,,,,,,,,,,,,,, -81000,0.14483385,0.016553894,,,,,,,,,,,,,,,,, -81100,0.13907526,0.016238337,,,,,,,,,,,,,,,,, -81200,0.14899643,0.02065826,,,,,,,,,,,,,,,,, -81300,0.17055023,0.02027093,,,,,,,,,,,,,,,,, -81350,,,0.9955688118934632,0.0138651421293616,0.7831689070136559,0.9869915843009948,0.0507423616945743,0.2936635114911271,43793.0,0.9862112998962402,0.0540365986526012,0.275526511079584,43793.0,26190.267526865005,40474.08954358101,26190.267526865005,14277.388365507126,4.147861957550049,0.0 -81400,0.13376002,0.01688226,,,,,,,,,,,,,,,,, -81500,0.15536165,0.017473707,,,,,,,,,,,,,,,,, -81600,0.15410084,0.017687546,,,,,,,,,,,,,,,,, -81700,0.13907221,0.01853515,,,,,,,,,,,,,,,,, -81800,0.14130065,0.017089857,,,,,,,,,,,,,,,,, -81900,0.15507992,0.01958637,,,,,,,,,,,,,,,,, -82000,0.15513623,0.019655071,,,,,,,,,,,,,,,,, -82100,0.14859122,0.017055215,,,,,,,,,,,,,,,,, -82101,,,0.995573878288269,0.0138689950108528,0.786614385369315,0.9869915843009948,0.0507423616945743,0.2936718267714238,43793.0,0.9862112998962402,0.0540365986526012,0.2754228664505747,43793.0,26430.20501279831,40829.4153251648,26430.20501279831,14392.712438821793,4.19120192527771,0.0 -82200,0.14207068,0.017816288,,,,,,,,,,,,,,,,, -82300,0.1483777,0.017481336,,,,,,,,,,,,,,,,, -82400,0.14785051,0.018934224,,,,,,,,,,,,,,,,, -82500,0.13958083,0.017334675,,,,,,,,,,,,,,,,, -82600,0.13883354,0.017657612,,,,,,,,,,,,,,,,, -82700,0.13523091,0.016439294,,,,,,,,,,,,,,,,, -82800,0.13655053,0.015203259,,,,,,,,,,,,,,,,, -82848,,,0.995610535144806,0.0137790357694029,0.7796165680656176,0.9869915843009948,0.0507423616945743,0.2937034401072733,43793.0,0.9862112998962402,0.0540365986526012,0.2755169929742068,43793.0,26670.177414417267,41190.311574697495,26670.177414417267,14513.572434902191,4.234838008880615,0.0 -82900,0.1518281,0.017983343,,,,,,,,,,,,,,,,, -83000,0.13906221,0.018107815,,,,,,,,,,,,,,,,, -83100,0.14618932,0.016452162,,,,,,,,,,,,,,,,, -83200,0.13290776,0.016020194,,,,,,,,,,,,,,,,, -83300,0.14879394,0.016402114,,,,,,,,,,,,,,,,, -83400,0.13640422,0.017436791,,,,,,,,,,,,,,,,, -83500,0.15064877,0.015936656,,,,,,,,,,,,,,,,, -83600,0.13670705,0.016038932,,,,,,,,,,,,,,,,, -83601,,,0.9955638647079468,0.0138187641277909,0.7790167630847286,0.9869915843009948,0.0507423616945743,0.2936536998079178,43793.0,0.9862112998962402,0.0540366023778915,0.2753272473537633,43793.0,26910.234219789505,41547.52993321419,26910.234219789505,14630.670375823976,4.278192520141602,0.0 -83700,0.14053768,0.018442204,,,,,,,,,,,,,,,,, -83800,0.15586405,0.01633472,,,,,,,,,,,,,,,,, -83900,0.13710657,0.018103676,,,,,,,,,,,,,,,,, -84000,0.14438713,0.017042331,,,,,,,,,,,,,,,,, -84100,0.15855643,0.017095467,,,,,,,,,,,,,,,,, -84200,0.14430639,0.018223038,,,,,,,,,,,,,,,,, -84300,0.15716493,0.017822832,,,,,,,,,,,,,,,,, -84348,,,0.9955844879150392,0.0138030629605054,0.774204131037632,0.9869915843009948,0.0507423616945743,0.2938256946907282,43793.0,0.9862112998962402,0.0540365986526012,0.2754699155733398,43793.0,27150.300426721573,41906.93664479256,27150.300426721573,14749.948195457458,4.320634126663208,0.0 -84400,0.15051797,0.018563898,,,,,,,,,,,,,,,,, -84500,0.15039524,0.020633785,,,,,,,,,,,,,,,,, -84600,0.13890867,0.01702257,,,,,,,,,,,,,,,,, -84700,0.14806847,0.017675936,,,,,,,,,,,,,,,,, -84800,0.14201689,0.015689973,,,,,,,,,,,,,,,,, -84900,0.16769855,0.016792523,,,,,,,,,,,,,,,,, -85000,0.14060122,0.018148065,,,,,,,,,,,,,,,,, -85093,,,0.9954951405525208,0.0139528112486004,0.7879491584598639,0.9869915843009948,0.0507423616945743,0.2936760948804247,43793.0,0.9862112998962402,0.0540365986526012,0.2754004305972816,43793.0,27390.44856762886,42268.99037504196,27390.44856762886,14871.777806282043,4.376424074172974,0.0 -85100,0.14657046,0.01783173,,,,,,,,,,,,,,,,, -85200,0.1364875,0.0155436555,,,,,,,,,,,,,,,,, -85300,0.14599372,0.017982831,,,,,,,,,,,,,,,,, -85400,0.14224803,0.017597927,,,,,,,,,,,,,,,,, -85500,0.120139286,0.0155486,,,,,,,,,,,,,,,,, -85600,0.13448535,0.017361147,,,,,,,,,,,,,,,,, -85700,0.14375074,0.017608784,,,,,,,,,,,,,,,,, -85800,0.13442345,0.01745119,,,,,,,,,,,,,,,,, -85844,,,0.9956097602844238,0.0138254640623927,0.7908804322850795,0.9869915843009948,0.0507423616945743,0.2935794272536952,43793.0,0.9862112998962402,0.0540365986526012,0.2753637605802473,43793.0,27630.48721194268,42629.74559020996,27630.48721194268,14992.430229187012,4.419990539550781,0.0 -85900,0.14847948,0.019384928,,,,,,,,,,,,,,,,, -86000,0.13660525,0.016995251,,,,,,,,,,,,,,,,, -86100,0.1412938,0.018735228,,,,,,,,,,,,,,,,, -86200,0.14646865,0.015446441,,,,,,,,,,,,,,,,, -86300,0.1322206,0.017349709,,,,,,,,,,,,,,,,, -86400,0.13064335,0.016996644,,,,,,,,,,,,,,,,, -86500,0.1477239,0.016526362,,,,,,,,,,,,,,,,, -86583,,,0.9956318140029908,0.0137186404317617,0.78391687670689,0.9869915843009948,0.0507423616945743,0.2936513145469017,43793.0,0.9862112998962402,0.0540365986526012,0.2754111736117262,43793.0,27870.723264932632,42994.15589976311,27870.723264932632,15116.532540082932,4.467123508453369,0.0 -86600,0.1309509,0.017545462,,,,,,,,,,,,,,,,, -86700,0.15594564,0.018162346,,,,,,,,,,,,,,,,, -86800,0.1424194,0.016270583,,,,,,,,,,,,,,,,, -86900,0.13863313,0.018122429,,,,,,,,,,,,,,,,, -87000,0.15536486,0.017894866,,,,,,,,,,,,,,,,, -87100,0.14204088,0.01889783,,,,,,,,,,,,,,,,, -87200,0.15795772,0.019146543,,,,,,,,,,,,,,,,, -87300,0.15201254,0.01852552,,,,,,,,,,,,,,,,, -87324,,,0.99556964635849,0.0138719119131565,0.7821974859985805,0.9869915843009948,0.0507423616945743,0.2935722567401824,43793.0,0.9862112998962402,0.0540365986526012,0.2754694072339723,43793.0,28110.6790099144,43349.73703336716,28110.6790099144,15232.091541528702,4.512559175491333,0.0 -87400,0.1552273,0.018340506,,,,,,,,,,,,,,,,, -87500,0.13009124,0.017849665,,,,,,,,,,,,,,,,, -87600,0.15394108,0.019492786,,,,,,,,,,,,,,,,, -87700,0.12719415,0.014895006,,,,,,,,,,,,,,,,, -87800,0.16219074,0.019467672,,,,,,,,,,,,,,,,, -87900,0.12339732,0.017505486,,,,,,,,,,,,,,,,, -88000,0.1434182,0.018855395,,,,,,,,,,,,,,,,, -88072,,,0.9955666065216064,0.0137403607368469,0.7812936362956535,0.9869915843009948,0.0507423616945743,0.2936613903844776,43793.0,0.9862112998962402,0.0540365986526012,0.2753737133918487,43793.0,28350.94520950317,43713.2302980423,28350.94520950317,15355.256055355072,4.554625511169434,0.0 -88100,0.15734094,0.018736746,,,,,,,,,,,,,,,,, -88200,0.15682039,0.017791944,,,,,,,,,,,,,,,,, -88300,0.1467846,0.017691946,,,,,,,,,,,,,,,,, -88400,0.13973801,0.018450115,,,,,,,,,,,,,,,,, -88500,0.13863309,0.014982062,,,,,,,,,,,,,,,,, -88600,0.13469769,0.014565796,,,,,,,,,,,,,,,,, -88700,0.12961435,0.016472055,,,,,,,,,,,,,,,,, -88800,0.16709271,0.018476084,,,,,,,,,,,,,,,,, -88805,,,0.99554181098938,0.0139920338988304,0.7816320112847811,0.9869915843009948,0.0507423616945743,0.2935806132338083,43793.0,0.9862112998962402,0.0540365986526012,0.2754791782357467,43793.0,28590.884345531464,44072.07675909996,28590.884345531464,15474.09222960472,4.601079940795898,0.0 -88900,0.16108745,0.019654429,,,,,,,,,,,,,,,,, -89000,0.15981077,0.020355048,,,,,,,,,,,,,,,,, -89100,0.14652078,0.017586485,,,,,,,,,,,,,,,,, -89200,0.15808801,0.019265404,,,,,,,,,,,,,,,,, -89300,0.139997,0.01760464,,,,,,,,,,,,,,,,, -89400,0.16042975,0.019057712,,,,,,,,,,,,,,,,, -89500,0.13962029,0.01667735,,,,,,,,,,,,,,,,, -89541,,,0.9955811500549316,0.0138107240200042,0.7874249815139477,0.9869915843009948,0.0507423616945743,0.2937410135718478,43793.0,0.9862112998962402,0.0540365986526012,0.2755014562042579,43793.0,28831.03605914116,44426.90284395218,28831.03605914116,15588.70144701004,4.645176410675049,0.0 -89600,0.14612879,0.018013163,,,,,,,,,,,,,,,,, -89700,0.13898061,0.017601134,,,,,,,,,,,,,,,,, -89800,0.134884,0.015860295,,,,,,,,,,,,,,,,, -89900,0.1315179,0.016965969,,,,,,,,,,,,,,,,, -90000,0.15894693,0.017582057,,,,,,,,,,,,,,,,, -90100,0.13832258,0.017282251,,,,,,,,,,,,,,,,, -90200,0.14333695,0.0183269,,,,,,,,,,,,,,,,, -90280,,,0.995635151863098,0.0136811081320047,0.7850135717763382,0.9869915843009948,0.0507423616945743,0.2936347523194236,43793.0,0.9862112998962402,0.0540366023778915,0.275349481187374,43793.0,29071.14478611946,44786.104506492615,29071.14478611946,15707.728714704514,4.690348386764526,0.0 -90300,0.14424129,0.018496174,,,,,,,,,,,,,,,,, -90400,0.17057112,0.018949853,,,,,,,,,,,,,,,,, -90500,0.14611104,0.01852386,,,,,,,,,,,,,,,,, -90600,0.13716319,0.01928304,,,,,,,,,,,,,,,,, -90700,0.16568859,0.018868256,,,,,,,,,,,,,,,,, -90800,0.12806149,0.016364686,,,,,,,,,,,,,,,,, -90900,0.13476528,0.016624408,,,,,,,,,,,,,,,,, -91000,0.14185432,0.019934885,,,,,,,,,,,,,,,,, -91025,,,0.9955487251281738,0.0139289842918515,0.7848721563012246,0.9869915843009948,0.0507423616945743,0.2937334688183136,43793.0,0.9862112998962402,0.0540365986526012,0.2753792621221552,43793.0,29311.08949136734,45144.8211247921,29311.08949136734,15826.435528993608,4.734649658203125,0.0 -91100,0.13836981,0.017174121,,,,,,,,,,,,,,,,, -91200,0.14028233,0.017522685,,,,,,,,,,,,,,,,, -91300,0.16385578,0.020213837,,,,,,,,,,,,,,,,, -91400,0.13752045,0.017953822,,,,,,,,,,,,,,,,, -91500,0.13821432,0.017986752,,,,,,,,,,,,,,,,, -91600,0.15155837,0.016686305,,,,,,,,,,,,,,,,, -91700,0.15101811,0.018673481,,,,,,,,,,,,,,,,, -91770,,,0.9956004619598388,0.0137164033949375,0.7883310032943757,0.9869915843009948,0.0507423616945743,0.2937376259139235,43793.0,0.9862112998962402,0.0540365986526012,0.2754203519807138,43793.0,29551.084138154984,45505.17563891411,29551.084138154984,15946.73088979721,4.778702735900879,0.0 -91800,0.15159066,0.018003693,,,,,,,,,,,,,,,,, -91900,0.15182409,0.01619686,,,,,,,,,,,,,,,,, -92000,0.14874813,0.019325765,,,,,,,,,,,,,,,,, -92100,0.145957,0.017014066,,,,,,,,,,,,,,,,, -92200,0.13229826,0.01596244,,,,,,,,,,,,,,,,, -92300,0.12506147,0.013819692,,,,,,,,,,,,,,,,, -92400,0.14518684,0.018453648,,,,,,,,,,,,,,,,, -92500,0.14472647,0.017492214,,,,,,,,,,,,,,,,, -92515,,,0.995555818080902,0.0139048527926206,0.7673218996392518,0.9869915843009948,0.0507423616945743,0.2936585100039495,43793.0,0.9862112998962402,0.0540365986526012,0.2754073981223309,43793.0,29791.020255804066,45862.61566567421,29791.020255804066,16064.169033050535,4.824016094207764,0.0 -92600,0.1499028,0.020091228,,,,,,,,,,,,,,,,, -92700,0.14089748,0.017100364,,,,,,,,,,,,,,,,, -92800,0.14839889,0.01619674,,,,,,,,,,,,,,,,, -92900,0.1551657,0.018175196,,,,,,,,,,,,,,,,, -93000,0.20244439,0.019357098,,,,,,,,,,,,,,,,, -93100,0.12581548,0.016107496,,,,,,,,,,,,,,,,, -93200,0.13983062,0.017327927,,,,,,,,,,,,,,,,, -93256,,,0.9955539107322692,0.0138565571978688,0.7895852409581146,0.9869915843009948,0.0507423616945743,0.2937831941628872,43793.0,0.9862112998962402,0.0540365986526012,0.2754559623634967,43793.0,30031.20615911484,46219.90180850029,30031.20615911484,16181.204716920853,4.867949724197388,0.0 -93300,0.14494982,0.018352164,,,,,,,,,,,,,,,,, -93400,0.16484497,0.018815123,,,,,,,,,,,,,,,,, -93500,0.16100867,0.018955622,,,,,,,,,,,,,,,,, -93600,0.15061997,0.017664168,,,,,,,,,,,,,,,,, -93700,0.1271817,0.01719233,,,,,,,,,,,,,,,,, -93800,0.17063649,0.019110199,,,,,,,,,,,,,,,,, -93900,0.15050054,0.018439934,,,,,,,,,,,,,,,,, -93999,,,0.9955896139144896,0.0138924093917012,0.7888154545262414,0.9869915843009948,0.0507423616945743,0.2936715268549584,43793.0,0.9862112998962402,0.0540366023778915,0.2754090103100722,43793.0,30271.378150701523,46577.03712940216,30271.378150701523,16298.102165699003,4.91324520111084,0.0 -94000,0.14631122,0.01749095,,,,,,,,,,,,,,,,, -94100,0.1422085,0.01757559,,,,,,,,,,,,,,,,, -94200,0.12406246,0.014966079,,,,,,,,,,,,,,,,, -94300,0.15009575,0.018169757,,,,,,,,,,,,,,,,, -94400,0.12461865,0.014659042,,,,,,,,,,,,,,,,, -94500,0.14013791,0.017601157,,,,,,,,,,,,,,,,, -94600,0.14207932,0.017823312,,,,,,,,,,,,,,,,, -94700,0.15120648,0.01995831,,,,,,,,,,,,,,,,, -94744,,,0.9956102967262268,0.0137448869645595,0.7833450442984059,0.9869915843009948,0.0507423616945743,0.2937749887670415,43793.0,0.9862112998962402,0.0540365986526012,0.2753408036677083,43793.0,30511.55004000664,46929.59366226196,30511.55004000664,16410.420204401016,4.958989381790161,0.0 -94800,0.12581053,0.014909512,,,,,,,,,,,,,,,,, -94900,0.13113116,0.017260455,,,,,,,,,,,,,,,,, -95000,0.14045739,0.018263409,,,,,,,,,,,,,,,,, -95100,0.1427518,0.016251331,,,,,,,,,,,,,,,,, -95200,0.13479967,0.01810404,,,,,,,,,,,,,,,,, -95300,0.1537793,0.018118924,,,,,,,,,,,,,,,,, -95400,0.15092562,0.018719684,,,,,,,,,,,,,,,,, -95490,,,0.99561607837677,0.0136659052222967,0.7821546171064298,0.9869915843009948,0.0507423616945743,0.2937425339183844,43793.0,0.9862112998962402,0.0540365986526012,0.2754687286827509,43793.0,30751.71346235276,47287.522605896,30751.71346235276,16528.11999154091,5.0040812492370605,0.0 -95500,0.12409013,0.01663435,,,,,,,,,,,,,,,,, -95600,0.14134283,0.016881948,,,,,,,,,,,,,,,,, -95700,0.12709984,0.01566317,,,,,,,,,,,,,,,,, -95800,0.13888569,0.017627768,,,,,,,,,,,,,,,,, -95900,0.1533624,0.01846234,,,,,,,,,,,,,,,,, -96000,0.14868324,0.017747696,,,,,,,,,,,,,,,,, -96100,0.15435922,0.020104812,,,,,,,,,,,,,,,,, -96200,0.14370614,0.016753599,,,,,,,,,,,,,,,,, -96244,,,0.9955379366874696,0.0139642404392361,0.777552866971797,0.9869915843009948,0.0507423616945743,0.2936285897507651,43793.0,0.9862112998962402,0.0540365986526012,0.2754213828360592,43793.0,30991.67889022827,47648.32022738457,30991.67889022827,16648.886462688446,5.048603057861328,0.0 -96300,0.13524666,0.016401049,,,,,,,,,,,,,,,,, -96400,0.13879494,0.016983673,,,,,,,,,,,,,,,,, -96500,0.13505453,0.016739655,,,,,,,,,,,,,,,,, -96600,0.13841553,0.016433816,,,,,,,,,,,,,,,,, -96700,0.129673,0.015134746,,,,,,,,,,,,,,,,, -96800,0.18389887,0.016108667,,,,,,,,,,,,,,,,, -96900,0.15047677,0.0174126,,,,,,,,,,,,,,,,, -96990,,,0.995514750480652,0.0139308404177427,0.7831007753855972,0.9869915843009948,0.0507423616945743,0.2936723064057466,43793.0,0.9862112998962402,0.0540365986526012,0.2754537971353992,43793.0,31231.81813645363,48005.52919220925,31231.81813645363,16765.890092611313,5.093991041183472,0.0 -97000,0.14303532,0.01769074,,,,,,,,,,,,,,,,, -97100,0.14862812,0.017451191,,,,,,,,,,,,,,,,, -97200,0.1362894,0.015959345,,,,,,,,,,,,,,,,, -97300,0.15338151,0.017532185,,,,,,,,,,,,,,,,, -97400,0.15960649,0.017597307,,,,,,,,,,,,,,,,, -97500,0.14911881,0.018712766,,,,,,,,,,,,,,,,, -97600,0.13666704,0.01521099,,,,,,,,,,,,,,,,, -97700,0.13175473,0.017355537,,,,,,,,,,,,,,,,, -97738,,,0.995638906955719,0.0136951776221394,0.7867676729768284,0.9869915843009948,0.0507423616945743,0.2937441843168288,43793.0,0.9862112998962402,0.0540365986526012,0.2754499157951282,43793.0,31471.98353600502,48362.4382288456,31471.98353600502,16882.567397117615,5.139058351516724,0.0 -97800,0.14011239,0.016887087,,,,,,,,,,,,,,,,, -97900,0.13205913,0.016575566,,,,,,,,,,,,,,,,, -98000,0.13948953,0.014908012,,,,,,,,,,,,,,,,, -98100,0.14155956,0.015767816,,,,,,,,,,,,,,,,, -98200,0.12075279,0.016312521,,,,,,,,,,,,,,,,, -98300,0.1508244,0.01918992,,,,,,,,,,,,,,,,, -98400,0.15259334,0.01822873,,,,,,,,,,,,,,,,, -98487,,,0.995592474937439,0.0138282580301165,0.7842024551917022,0.9869915843009948,0.0507423616945743,0.2937038315424408,43793.0,0.9862112998962402,0.0540365986526012,0.2754063960084381,43793.0,31712.06727313996,48716.67514181137,31712.06727313996,16996.652015209198,5.186892509460449,0.0 -98500,0.143934,0.017798213,,,,,,,,,,,,,,,,, -98600,0.14535885,0.01774011,,,,,,,,,,,,,,,,, -98700,0.1611737,0.017417556,,,,,,,,,,,,,,,,, -98800,0.13855487,0.018465513,,,,,,,,,,,,,,,,, -98900,0.14413871,0.018778263,,,,,,,,,,,,,,,,, -99000,0.12955908,0.016336434,,,,,,,,,,,,,,,,, -99100,0.14367943,0.017690925,,,,,,,,,,,,,,,,, -99200,0.15429309,0.017237758,,,,,,,,,,,,,,,,, -99229,,,0.9955712556838988,0.0138505389913916,0.7811136130391136,0.9869915843009948,0.0507423616945743,0.2937046239389677,43793.0,0.9862112998962402,0.0540365986526012,0.2754442037892525,43793.0,31952.20751523972,49073.06725502014,31952.20751523972,17112.835739850998,5.234352111816406,0.0 -99300,0.13510633,0.017474992,,,,,,,,,,,,,,,,, -99400,0.15385768,0.018797006,,,,,,,,,,,,,,,,, -99500,0.14790925,0.019601397,,,,,,,,,,,,,,,,, -99600,0.1411419,0.018647319,,,,,,,,,,,,,,,,, -99700,0.14000322,0.017675,,,,,,,,,,,,,,,,, -99800,0.13387853,0.016296934,,,,,,,,,,,,,,,,, -99900,0.16140255,0.018472211,,,,,,,,,,,,,,,,, -99975,,,0.9955838918685912,0.0137693621218204,0.7773563293839678,0.9869915843009948,0.0507423616945743,0.2937261884863826,43793.0,0.9862112998962402,0.0540365986526012,0.2753772766361173,43793.0,32192.220431804657,49423.6063015461,32192.220431804657,17223.295473337173,5.279926300048828,0.0 -100000,0.1500172,0.018639138,,,,,,,,,,,,,,,,, -100100,0.14192624,0.01922813,,,,,,,,,,,,,,,,, -100200,0.1332666,0.017709125,,,,,,,,,,,,,,,,, -100300,0.1331302,0.018118452,,,,,,,,,,,,,,,,, -100400,0.14628318,0.016266847,,,,,,,,,,,,,,,,, -100500,0.14532675,0.018309461,,,,,,,,,,,,,,,,, -100600,0.13836609,0.017068014,,,,,,,,,,,,,,,,, -100700,0.1364337,0.016416244,,,,,,,,,,,,,,,,, -100725,,,0.9955572485923768,0.0138950860127806,0.7741282285367919,0.9869915843009948,0.0507423616945743,0.2936773969363314,43793.0,0.9862112998962402,0.0540365986526012,0.2754337862658121,43793.0,32432.44850897789,49782.10767865181,32432.44850897789,17341.50161242485,5.326424598693848,0.0 -100800,0.14451051,0.017792076,,,,,,,,,,,,,,,,, -100900,0.13303645,0.016815728,,,,,,,,,,,,,,,,, -101000,0.15763739,0.016326327,,,,,,,,,,,,,,,,, -101100,0.13163239,0.01617045,,,,,,,,,,,,,,,,, -101200,0.13618734,0.016992826,,,,,,,,,,,,,,,,, -101300,0.14019148,0.01804443,,,,,,,,,,,,,,,,, -101400,0.14211784,0.01627572,,,,,,,,,,,,,,,,, -101463,,,0.9955657720565796,0.0138729671016335,0.7906566108249474,0.9869915843009948,0.0507423616945743,0.2936484156831344,43793.0,0.9862112998962402,0.0540365986526012,0.2754924725969873,43793.0,32672.53692626953,50136.65001726151,32672.53692626953,17455.88349723816,5.3753931522369385,0.0 -101500,0.15847608,0.01788442,,,,,,,,,,,,,,,,, -101600,0.13128503,0.016382633,,,,,,,,,,,,,,,,, -101700,0.1453092,0.017710494,,,,,,,,,,,,,,,,, -101800,0.13980287,0.014871568,,,,,,,,,,,,,,,,, -101900,0.1387522,0.016811289,,,,,,,,,,,,,,,,, -102000,0.15060803,0.017390916,,,,,,,,,,,,,,,,, -102100,0.13380048,0.016740285,,,,,,,,,,,,,,,,, -102200,0.16628562,0.018625597,,,,,,,,,,,,,,,,, -102209,,,0.9956038594245912,0.0138377929106354,0.7844097048287434,0.9869915843009948,0.0507423616945743,0.2935776915827144,43793.0,0.9862112998962402,0.0540365986526012,0.2754421511920781,43793.0,32912.768273591995,50492.8391327858,32912.768273591995,17571.774189710617,5.421801805496216,0.0 -102300,0.15368375,0.021166924,,,,,,,,,,,,,,,,, -102400,0.13748464,0.016504861,,,,,,,,,,,,,,,,, -102500,0.14414097,0.016468579,,,,,,,,,,,,,,,,, -102600,0.15834841,0.018220903,,,,,,,,,,,,,,,,, -102700,0.13488638,0.018000402,,,,,,,,,,,,,,,,, -102800,0.1440467,0.016912937,,,,,,,,,,,,,,,,, -102900,0.1476318,0.01783721,,,,,,,,,,,,,,,,, -102950,,,0.9955734014511108,0.0137864360585808,0.7839354967984052,0.9869915843009948,0.0507423616945743,0.2937865842712753,43793.0,0.9862112998962402,0.0540365986526012,0.2753811304726336,43793.0,33152.73183774948,50845.19981837273,33152.73183774948,17684.104935884476,5.467526435852051,0.0 -103000,0.13849267,0.01769491,,,,,,,,,,,,,,,,, -103100,0.14379856,0.017752476,,,,,,,,,,,,,,,,, -103200,0.13777274,0.01773536,,,,,,,,,,,,,,,,, -103300,0.13078985,0.015515336,,,,,,,,,,,,,,,,, -103400,0.15798835,0.019663138,,,,,,,,,,,,,,,,, -103500,0.15092337,0.017820448,,,,,,,,,,,,,,,,, -103600,0.14626877,0.019465124,,,,,,,,,,,,,,,,, -103691,,,0.9955546855926514,0.0138387288898229,0.7864173109327909,0.9869915843009948,0.0507423616945743,0.2938074174387261,43793.0,0.9862112998962402,0.0540365986526012,0.2754707876834113,43793.0,33392.72674036026,51200.07393550873,33392.72674036026,17798.916505098343,5.51433539390564,0.0 -103700,0.16457698,0.017659651,,,,,,,,,,,,,,,,, -103800,0.15122786,0.01802127,,,,,,,,,,,,,,,,, -103900,0.12987733,0.01641341,,,,,,,,,,,,,,,,, -104000,0.15250754,0.018539231,,,,,,,,,,,,,,,,, -104100,0.143953,0.017395712,,,,,,,,,,,,,,,,, -104200,0.15415569,0.019451417,,,,,,,,,,,,,,,,, -104300,0.13338225,0.017039403,,,,,,,,,,,,,,,,, -104400,0.14766452,0.017204745,,,,,,,,,,,,,,,,, -104443,,,0.995599091053009,0.0137803135439753,0.7693775071494818,0.9869915843009948,0.0507423616945743,0.2937300487911993,43793.0,0.9862112998962402,0.0540365986526012,0.275457568870989,43793.0,33632.83125758171,51555.50200009346,33632.83125758171,17914.172067642212,5.561546802520752,0.0 -104500,0.15079144,0.019883357,,,,,,,,,,,,,,,,, -104600,0.14864299,0.018031158,,,,,,,,,,,,,,,,, -104700,0.15791848,0.016996313,,,,,,,,,,,,,,,,, -104800,0.13532704,0.01800288,,,,,,,,,,,,,,,,, -104900,0.17671812,0.020808749,,,,,,,,,,,,,,,,, -105000,0.13390364,0.015244552,,,,,,,,,,,,,,,,, -105100,0.14433533,0.01685857,,,,,,,,,,,,,,,,, -105195,,,0.9955223798751832,0.0139496717602014,0.7825516269931497,0.9869915843009948,0.0507423616945743,0.2937538129987925,43793.0,0.9862112998962402,0.0540365986526012,0.2754226879137796,43793.0,33872.89906716347,51908.106872558594,33872.89906716347,18026.642154455185,5.607269763946533,0.0 -105200,0.16317366,0.017721523,,,,,,,,,,,,,,,,, -105300,0.13707358,0.017519174,,,,,,,,,,,,,,,,, -105400,0.14881955,0.0178104,,,,,,,,,,,,,,,,, -105500,0.15328307,0.01765733,,,,,,,,,,,,,,,,, -105600,0.15398873,0.02161797,,,,,,,,,,,,,,,,, -105700,0.13174196,0.016580924,,,,,,,,,,,,,,,,, -105800,0.13897036,0.018020686,,,,,,,,,,,,,,,,, -105900,0.14210388,0.017345991,,,,,,,,,,,,,,,,, -105942,,,0.9956201314926147,0.0137748792767524,0.7936656836107864,0.9869915843009948,0.0507423616945743,0.2937028913704122,43793.0,0.9862112998962402,0.0540365986526012,0.2753237852488251,43793.0,34112.948595285416,52263.84589409828,34112.948595285416,18142.262134552,5.656062602996826,0.0 -106000,0.15363826,0.017829787,,,,,,,,,,,,,,,,, -106100,0.15058193,0.01737905,,,,,,,,,,,,,,,,, -106200,0.15275462,0.018834269,,,,,,,,,,,,,,,,, -106300,0.14092223,0.016676733,,,,,,,,,,,,,,,,, -106400,0.14353923,0.017608453,,,,,,,,,,,,,,,,, -106500,0.13539547,0.016427606,,,,,,,,,,,,,,,,, -106600,0.15013903,0.017380817,,,,,,,,,,,,,,,,, -106687,,,0.9955931305885316,0.0138001209124922,0.7827427341975708,0.9869915843009948,0.0507423616945743,0.2938741571133589,43793.0,0.9862112998962402,0.0540365986526012,0.2753481876534774,43793.0,34353.06476831436,52614.17381596565,34353.06476831436,18252.405061244965,5.704350709915161,0.0 -106700,0.13885403,0.015741844,,,,,,,,,,,,,,,,, -106800,0.14525773,0.018630901,,,,,,,,,,,,,,,,, -106900,0.12988038,0.016714992,,,,,,,,,,,,,,,,, -107000,0.14058986,0.017641151,,,,,,,,,,,,,,,,, -107100,0.14504293,0.01704883,,,,,,,,,,,,,,,,, -107200,0.13152091,0.016107634,,,,,,,,,,,,,,,,, -107300,0.15747064,0.019474022,,,,,,,,,,,,,,,,, -107400,0.13734053,0.017571218,,,,,,,,,,,,,,,,, -107435,,,0.9955994486808776,0.01373241096735,0.7809417040681077,0.9869915843009948,0.0507423616945743,0.2938712298130228,43793.0,0.9862112998962402,0.0540365986526012,0.2754138806616004,43793.0,34593.14384675026,52972.46685504913,34593.14384675026,18370.54836511612,5.751443147659302,0.0 -107500,0.15103848,0.017984005,,,,,,,,,,,,,,,,, -107600,0.12801331,0.017411945,,,,,,,,,,,,,,,,, -107700,0.16039087,0.017506022,,,,,,,,,,,,,,,,, -107800,0.16000508,0.017345715,,,,,,,,,,,,,,,,, -107900,0.16145569,0.017419469,,,,,,,,,,,,,,,,, -108000,0.15297677,0.018800693,,,,,,,,,,,,,,,,, -108100,0.15603937,0.017831203,,,,,,,,,,,,,,,,, -108172,,,0.9955727458000184,0.0138137601315975,0.7745374913235604,0.9869915843009948,0.0507423616945743,0.2936441576869056,43793.0,0.9862112998962402,0.0540365986526012,0.2753967137243359,43793.0,34833.32964348793,53322.83411717415,34833.32964348793,18480.653708934784,5.805322170257568,0.0 -108200,0.14144361,0.016275039,,,,,,,,,,,,,,,,, -108300,0.14557925,0.01760552,,,,,,,,,,,,,,,,, -108400,0.14984195,0.017918931,,,,,,,,,,,,,,,,, -108500,0.13957912,0.015952053,,,,,,,,,,,,,,,,, -108600,0.16745867,0.020628626,,,,,,,,,,,,,,,,, -108700,0.149345,0.018577443,,,,,,,,,,,,,,,,, -108800,0.16091803,0.01818908,,,,,,,,,,,,,,,,, -108900,0.13952333,0.017471813,,,,,,,,,,,,,,,,, -108916,,,0.9955095052719116,0.0140068819746375,0.7848383461143658,0.9869915843009948,0.0507423616945743,0.2936524279192357,43793.0,0.9862112998962402,0.0540365986526012,0.2754546663138822,43793.0,35073.34981870651,53672.68225312233,35073.34981870651,18590.412984609604,5.85346269607544,0.0 -109000,0.12320665,0.015585503,,,,,,,,,,,,,,,,, -109100,0.16873044,0.019554323,,,,,,,,,,,,,,,,, -109200,0.14452355,0.017378785,,,,,,,,,,,,,,,,, -109300,0.13970518,0.01624004,,,,,,,,,,,,,,,,, -109400,0.12761538,0.015102765,,,,,,,,,,,,,,,,, -109500,0.14427695,0.018101698,,,,,,,,,,,,,,,,, -109600,0.13993569,0.01728613,,,,,,,,,,,,,,,,, -109670,,,0.9956090450286864,0.0138024566695094,0.7855790926404516,0.9869915843009948,0.0507423616945743,0.2937035017352725,43793.0,0.9862112998962402,0.0540365986526012,0.275395700887565,43793.0,35313.440516233444,54022.661536455154,35313.440516233444,18700.22829508781,5.906094074249268,0.0 -109700,0.14507675,0.01604794,,,,,,,,,,,,,,,,, -109800,0.1635957,0.018699374,,,,,,,,,,,,,,,,, -109900,0.13713132,0.016158031,,,,,,,,,,,,,,,,, -110000,0.14999305,0.019917214,,,,,,,,,,,,,,,,, -110100,0.15741645,0.016798554,,,,,,,,,,,,,,,,, -110200,0.13891262,0.018523168,,,,,,,,,,,,,,,,, -110300,0.14839941,0.020124467,,,,,,,,,,,,,,,,, -110400,0.16502592,0.01993613,,,,,,,,,,,,,,,,, -110416,,,0.9955815076828004,0.0137841440737247,0.7876969888879968,0.9869915843009948,0.0507423616945743,0.293807377697929,43793.0,0.9862112998962402,0.0540365986526012,0.2755301847475866,43793.0,35553.47294831276,54375.65848469734,35553.47294831276,18813.123986005783,5.954303979873657,0.0 -110500,0.1263899,0.016385332,,,,,,,,,,,,,,,,, -110600,0.13532178,0.017013904,,,,,,,,,,,,,,,,, -110700,0.15355462,0.01779642,,,,,,,,,,,,,,,,, -110800,0.13729727,0.01746951,,,,,,,,,,,,,,,,, -110900,0.15221424,0.017905971,,,,,,,,,,,,,,,,, -111000,0.13468198,0.017577583,,,,,,,,,,,,,,,,, -111100,0.15174073,0.017960064,,,,,,,,,,,,,,,,, -111167,,,0.9955974817276,0.0138169005513191,0.7778713694892079,0.9869915843009948,0.0507423616945743,0.2936051051603034,43793.0,0.9862112998962402,0.0540365986526012,0.2755673129599433,43793.0,35793.5690510273,54734.97816634178,35793.5690510273,18932.277148246765,6.003270387649536,0.0 -111200,0.1441879,0.018552553,,,,,,,,,,,,,,,,, -111300,0.14274716,0.019156989,,,,,,,,,,,,,,,,, -111400,0.13760501,0.016760642,,,,,,,,,,,,,,,,, -111500,0.13422896,0.016064368,,,,,,,,,,,,,,,,, -111600,0.13515364,0.01652031,,,,,,,,,,,,,,,,, -111700,0.13909791,0.01783705,,,,,,,,,,,,,,,,, -111800,0.15676002,0.021150067,,,,,,,,,,,,,,,,, -111900,0.14515354,0.019599147,,,,,,,,,,,,,,,,, -111912,,,0.9955654144287108,0.0137704741209745,0.7812751085022197,0.9869915843009948,0.0507423616945743,0.2936191037192044,43793.0,0.9862112998962402,0.0540365986526012,0.2753573878831423,43793.0,36033.5463924408,55085.2830324173,36033.5463924408,19042.534830093384,6.051767826080322,0.0 -112000,0.14267561,0.016155671,,,,,,,,,,,,,,,,, -112100,0.14067806,0.017318271,,,,,,,,,,,,,,,,, -112200,0.1421445,0.015294947,,,,,,,,,,,,,,,,, -112300,0.14460012,0.017387683,,,,,,,,,,,,,,,,, -112400,0.14704567,0.017273238,,,,,,,,,,,,,,,,, -112500,0.14166182,0.017627966,,,,,,,,,,,,,,,,, -112600,0.13037387,0.016725788,,,,,,,,,,,,,,,,, -112651,,,0.9955233335494996,0.0140600008890032,0.7775877571357277,0.9869915843009948,0.0507423616945743,0.2937597007290897,43793.0,0.9862112998962402,0.0540365986526012,0.275431472781563,43793.0,36273.49717617035,55434.03934621811,36273.49717617035,19151.26888847351,6.1010048389434814,0.0 -112700,0.13718708,0.016005494,,,,,,,,,,,,,,,,, -112800,0.13200255,0.016687106,,,,,,,,,,,,,,,,, -112900,0.12599362,0.016328769,,,,,,,,,,,,,,,,, -113000,0.15954177,0.01880235,,,,,,,,,,,,,,,,, -113100,0.123307094,0.01725869,,,,,,,,,,,,,,,,, -113200,0.16154012,0.021001376,,,,,,,,,,,,,,,,, -113300,0.14254609,0.017613702,,,,,,,,,,,,,,,,, -113393,,,0.9956125020980836,0.0137028321623802,0.7856105334997641,0.9869915843009948,0.0507423616945743,0.2936817305110332,43793.0,0.9862112998962402,0.0540365986526012,0.2753357528733862,43793.0,36513.54458999634,55787.701063632965,36513.54458999634,19264.81359577179,6.149595975875855,0.0 -113400,0.12691632,0.016187951,,,,,,,,,,,,,,,,, -113500,0.14423028,0.016658414,,,,,,,,,,,,,,,,, -113600,0.14829354,0.018739957,,,,,,,,,,,,,,,,, -113700,0.13783324,0.016703904,,,,,,,,,,,,,,,,, -113800,0.16433899,0.01974643,,,,,,,,,,,,,,,,, -113900,0.13481459,0.017784683,,,,,,,,,,,,,,,,, -114000,0.14991547,0.016760135,,,,,,,,,,,,,,,,, -114100,0.15142287,0.018060066,,,,,,,,,,,,,,,,, -114130,,,0.9956178069114684,0.0137725817039608,0.7901904740236783,0.9869915843009948,0.0507423616945743,0.2935913388519824,43793.0,0.9862112998962402,0.0540365986526012,0.2754157196970703,43793.0,36753.55989098549,56140.56715130806,36753.55989098549,19377.594454288483,6.197968244552612,0.0 -114200,0.15139875,0.01898904,,,,,,,,,,,,,,,,, -114300,0.1453819,0.01836974,,,,,,,,,,,,,,,,, -114400,0.13023195,0.015583098,,,,,,,,,,,,,,,,, -114500,0.1393798,0.01791364,,,,,,,,,,,,,,,,, -114600,0.14154984,0.015769433,,,,,,,,,,,,,,,,, -114700,0.15208279,0.016231788,,,,,,,,,,,,,,,,, -114800,0.14047275,0.01678926,,,,,,,,,,,,,,,,, -114866,,,0.9955707788467408,0.0138441119343042,0.7832189831767615,0.9869915843009948,0.0507423616945743,0.2935787289523329,43793.0,0.9862112998962402,0.0540365986526012,0.2754269244754509,43793.0,36993.52682876587,56496.18143749237,36993.52682876587,19493.173021554947,6.24558687210083,0.0 -114900,0.14058062,0.017224558,,,,,,,,,,,,,,,,, -115000,0.14807655,0.019383514,,,,,,,,,,,,,,,,, -115100,0.1343129,0.017129404,,,,,,,,,,,,,,,,, -115200,0.14610729,0.016950304,,,,,,,,,,,,,,,,, -115300,0.16572085,0.017726362,,,,,,,,,,,,,,,,, -115400,0.14112186,0.01888909,,,,,,,,,,,,,,,,, -115500,0.15775052,0.019076955,,,,,,,,,,,,,,,,, -115600,0.13312145,0.017179985,,,,,,,,,,,,,,,,, -115612,,,0.9955816268920898,0.0137649774551391,0.7833042739703718,0.9869915843009948,0.0507423616945743,0.293672361685288,43793.0,0.9862112998962402,0.0540365986526012,0.2754521002138382,43793.0,37233.61578011513,56849.43639016152,37233.61578011513,19606.26589083672,6.298041105270386,0.0 -115700,0.14190586,0.0170342,,,,,,,,,,,,,,,,, -115800,0.15281206,0.018174948,,,,,,,,,,,,,,,,, -115900,0.14486326,0.018538825,,,,,,,,,,,,,,,,, -116000,0.14570723,0.017985439,,,,,,,,,,,,,,,,, -116100,0.1669675,0.020990366,,,,,,,,,,,,,,,,, -116200,0.14251627,0.017492676,,,,,,,,,,,,,,,,, -116300,0.14801872,0.01716597,,,,,,,,,,,,,,,,, -116362,,,0.9955711960792542,0.0138259436935186,0.7763133112655201,0.9869915843009948,0.0507423616945743,0.2937281957738649,43793.0,0.9862112998962402,0.0540365986526012,0.2754720384718899,43793.0,37473.81154060364,57200.054525375366,37473.81154060364,19716.618169546127,6.347314119338989,0.0 -116400,0.12246849,0.016826985,,,,,,,,,,,,,,,,, -116500,0.15797922,0.016202603,,,,,,,,,,,,,,,,, -116600,0.17180362,0.018880954,,,,,,,,,,,,,,,,, -116700,0.1503781,0.016311236,,,,,,,,,,,,,,,,, -116800,0.12816295,0.016476152,,,,,,,,,,,,,,,,, -116900,0.13890399,0.017974708,,,,,,,,,,,,,,,,, -117000,0.11867259,0.015782239,,,,,,,,,,,,,,,,, -117100,0.14079289,0.015107748,,,,,,,,,,,,,,,,, -117102,,,0.9955509305000304,0.0138606643304228,0.7871036951453844,0.9869915843009948,0.0507423616945743,0.2936363325723397,43793.0,0.9862112998962402,0.0540365986526012,0.275412886815387,43793.0,37713.840057611465,57552.76069331169,37713.840057611465,19829.223463773727,6.399012088775635,0.0 -117200,0.13484955,0.018274937,,,,,,,,,,,,,,,,, -117300,0.13874306,0.017784249,,,,,,,,,,,,,,,,, -117400,0.14722666,0.017995449,,,,,,,,,,,,,,,,, -117500,0.14766304,0.017991286,,,,,,,,,,,,,,,,, -117600,0.14053127,0.015145283,,,,,,,,,,,,,,,,, -117700,0.144885,0.018848099,,,,,,,,,,,,,,,,, -117800,0.14891239,0.018367503,,,,,,,,,,,,,,,,, -117842,,,0.9955703616142272,0.0139037342742085,0.7890036277593242,0.9869915843009948,0.0507423616945743,0.2936448231073847,43793.0,0.9862112998962402,0.0540365986526012,0.2753685032524396,43793.0,37954.070449113846,57904.84043097496,37954.070449113846,19940.994478702545,6.45639443397522,0.0 -117900,0.15099758,0.019286573,,,,,,,,,,,,,,,,, -118000,0.1443417,0.016688049,,,,,,,,,,,,,,,,, -118100,0.1271168,0.01503927,,,,,,,,,,,,,,,,, -118200,0.13509825,0.017347166,,,,,,,,,,,,,,,,, -118300,0.14305526,0.018137963,,,,,,,,,,,,,,,,, -118400,0.14073406,0.017509103,,,,,,,,,,,,,,,,, -118500,0.16048524,0.017192522,,,,,,,,,,,,,,,,, -118584,,,0.995596408843994,0.0138527592644095,0.7802696371874123,0.9869915843009948,0.0507423616945743,0.2936510843504485,43793.0,0.9862112998962402,0.0540365986526012,0.2754307721802707,43793.0,38194.10259127617,58260.26805996895,38194.10259127617,20056.31689786911,6.509163856506348,0.0 -118600,0.16140969,0.018228285,,,,,,,,,,,,,,,,, -118700,0.13226259,0.016999757,,,,,,,,,,,,,,,,, -118800,0.14582627,0.017522134,,,,,,,,,,,,,,,,, -118900,0.14290047,0.016075935,,,,,,,,,,,,,,,,, -119000,0.15235625,0.017318415,,,,,,,,,,,,,,,,, -119100,0.14245309,0.015762027,,,,,,,,,,,,,,,,, -119200,0.13841131,0.016916802,,,,,,,,,,,,,,,,, -119300,0.13956644,0.016588256,,,,,,,,,,,,,,,,, -119331,,,0.995570421218872,0.0137610621750354,0.7811442251892875,0.9869915843009948,0.0507423616945743,0.2935784888838916,43793.0,0.9862112998962402,0.0540366023778915,0.2755224102040603,43793.0,38434.34361219406,58613.13779401779,38434.34361219406,20168.87475633621,6.5590479373931885,0.0 -119400,0.14281482,0.016587095,,,,,,,,,,,,,,,,, -119500,0.13938613,0.016097605,,,,,,,,,,,,,,,,, -119600,0.13433915,0.017093886,,,,,,,,,,,,,,,,, -119700,0.14787656,0.017237984,,,,,,,,,,,,,,,,, -119800,0.15468325,0.018750776,,,,,,,,,,,,,,,,, -119900,0.15520148,0.019287363,,,,,,,,,,,,,,,,, -120000,0.15122496,0.018132977,,,,,,,,,,,,,,,,, -120074,,,0.9955947399139404,0.0137996897101402,0.7765940262969431,0.9869915843009948,0.0507423616945743,0.293746364209449,43793.0,0.9862112998962402,0.0540365986526012,0.2754766615316152,43793.0,38674.53906083107,58963.95445632935,38674.53906083107,20279.42533969879,6.609145879745483,0.0 -120100,0.15199362,0.018293891,,,,,,,,,,,,,,,,, -120200,0.15241551,0.01682339,,,,,,,,,,,,,,,,, -120300,0.16805856,0.01914801,,,,,,,,,,,,,,,,, -120400,0.1400482,0.01733506,,,,,,,,,,,,,,,,, -120500,0.14260758,0.017514499,,,,,,,,,,,,,,,,, -120600,0.13212399,0.015876533,,,,,,,,,,,,,,,,, -120700,0.13299154,0.016323859,,,,,,,,,,,,,,,,, -120800,0.1577938,0.018733164,,,,,,,,,,,,,,,,, -120812,,,0.9955591559410096,0.0138795301318168,0.7845728335870632,0.9869915843009948,0.0507423616945743,0.2937915479617342,43793.0,0.9862112998962402,0.0540365986526012,0.2755423162549807,43793.0,38914.5114774704,59320.885125637054,38914.5114774704,20396.31151819229,6.660247087478638,0.0 -120900,0.14984764,0.0180276,,,,,,,,,,,,,,,,, -121000,0.16561778,0.018973397,,,,,,,,,,,,,,,,, -121100,0.14099765,0.01815897,,,,,,,,,,,,,,,,, -121200,0.16055454,0.02001534,,,,,,,,,,,,,,,,, -121300,0.1469213,0.01641586,,,,,,,,,,,,,,,,, -121400,0.14844853,0.016044445,,,,,,,,,,,,,,,,, -121500,0.13645762,0.01632057,,,,,,,,,,,,,,,,, -121554,,,0.995585799217224,0.0137846209108829,0.7895918756519975,0.9869915843009948,0.0507423616945743,0.2937423974114145,43793.0,0.9862112998962402,0.0540365986526012,0.2755032971678063,43793.0,39154.60091519356,59671.376187086105,39154.60091519356,20506.640821695328,6.712068557739258,0.0 -121600,0.13596648,0.018073792,,,,,,,,,,,,,,,,, -121700,0.13921796,0.017443916,,,,,,,,,,,,,,,,, -121800,0.135884,0.01761752,,,,,,,,,,,,,,,,, -121900,0.14560364,0.019191468,,,,,,,,,,,,,,,,, -122000,0.15682232,0.017858377,,,,,,,,,,,,,,,,, -122100,0.15046746,0.01858572,,,,,,,,,,,,,,,,, -122200,0.16513325,0.0226491,,,,,,,,,,,,,,,,, -122300,0.13642296,0.016981147,,,,,,,,,,,,,,,,, -122302,,,0.995614528656006,0.0137911001220345,0.7868911213978563,0.9869915843009948,0.0507423616945743,0.2936706497959538,43793.0,0.9862112998962402,0.0540365986526012,0.2754014452639356,43793.0,39394.5952064991,60025.72388958931,39394.5952064991,20620.923574447632,6.762216567993164,0.0 -122400,0.1520758,0.018904416,,,,,,,,,,,,,,,,, -122500,0.13510334,0.017201468,,,,,,,,,,,,,,,,, -122600,0.14763667,0.018727876,,,,,,,,,,,,,,,,, -122700,0.12335697,0.015094277,,,,,,,,,,,,,,,,, -122800,0.14316644,0.01987115,,,,,,,,,,,,,,,,, -122900,0.15916687,0.018540151,,,,,,,,,,,,,,,,, -123000,0.14917023,0.017983692,,,,,,,,,,,,,,,,, -123046,,,0.995576560497284,0.013819707557559,0.7784860680421057,0.9869915843009948,0.0507423616945743,0.2935860849708683,43793.0,0.9862112998962402,0.0540365986526012,0.2755150682884237,43793.0,39634.70795702934,60377.917174339294,39634.70795702934,20732.93293976784,6.812769651412964,0.0 -123100,0.14722978,0.018595284,,,,,,,,,,,,,,,,, -123200,0.15279837,0.017962743,,,,,,,,,,,,,,,,, -123300,0.13226822,0.017308021,,,,,,,,,,,,,,,,, -123400,0.15731867,0.019190116,,,,,,,,,,,,,,,,, -123500,0.1553327,0.01958489,,,,,,,,,,,,,,,,, -123600,0.13702755,0.018475525,,,,,,,,,,,,,,,,, -123700,0.14927825,0.018335292,,,,,,,,,,,,,,,,, -123787,,,0.9955735206604004,0.0138522526249289,0.7843652857748288,0.9869915843009948,0.0507423616945743,0.2936752539884849,43793.0,0.9862112998962402,0.0540365986526012,0.2754583601199453,43793.0,39874.68226027489,60731.89415550232,39874.68226027489,20846.8649699688,6.862906455993652,0.0 -123800,0.14827311,0.02004638,,,,,,,,,,,,,,,,, -123900,0.14174993,0.015962793,,,,,,,,,,,,,,,,, -124000,0.18137921,0.017890625,,,,,,,,,,,,,,,,, -124100,0.13306816,0.016353145,,,,,,,,,,,,,,,,, -124200,0.14566533,0.01714904,,,,,,,,,,,,,,,,, -124300,0.13217884,0.01702088,,,,,,,,,,,,,,,,, -124400,0.14447716,0.017015211,,,,,,,,,,,,,,,,, -124500,0.15095869,0.019069038,,,,,,,,,,,,,,,,, -124526,,,0.995505154132843,0.0140109313651919,0.7773680917507209,0.9869915843009948,0.0507423616945743,0.2937885779266836,43793.0,0.9862112998962402,0.0540365986526012,0.2754465652294123,43793.0,40114.67047047615,61084.737924575806,40114.67047047615,20959.648869991302,6.913788318634033,0.0 -124600,0.11806776,0.016548447,,,,,,,,,,,,,,,,, -124700,0.16249897,0.018524947,,,,,,,,,,,,,,,,, -124800,0.15454683,0.020286297,,,,,,,,,,,,,,,,, -124900,0.14451009,0.016192194,,,,,,,,,,,,,,,,, -125000,0.14278096,0.019013641,,,,,,,,,,,,,,,,, -125100,0.14662953,0.016515296,,,,,,,,,,,,,,,,, -125200,0.15263395,0.01687623,,,,,,,,,,,,,,,,, -125262,,,0.9956053495407104,0.0137420119717717,0.7890075856031153,0.9869915843009948,0.0507423616945743,0.2936749697469681,43793.0,0.9862112998962402,0.0540365986526012,0.2753858426724248,43793.0,40354.857283592224,61439.29281306267,40354.857283592224,21073.94648528099,6.963781595230103,0.0 -125300,0.13828398,0.016622918,,,,,,,,,,,,,,,,, -125400,0.15135434,0.019166479,,,,,,,,,,,,,,,,, -125500,0.15964638,0.01847129,,,,,,,,,,,,,,,,, -125600,0.15017025,0.018660992,,,,,,,,,,,,,,,,, -125700,0.13939798,0.01859066,,,,,,,,,,,,,,,,, -125800,0.1342095,0.017857,,,,,,,,,,,,,,,,, -125900,0.15252487,0.017503155,,,,,,,,,,,,,,,,, -126000,0.15754299,0.014996763,,,,,,,,,,,,,,,,, -126001,,,0.9956063628196716,0.0138191375881433,0.7756586723562215,0.9869915843009948,0.0507423616945743,0.2937366296262176,43793.0,0.9862112998962402,0.0540365986526012,0.2756016566885009,43793.0,40594.93407726288,61793.05135250092,40594.93407726288,21187.55341076851,7.017248630523682,0.0 -126100,0.15818149,0.017607868,,,,,,,,,,,,,,,,, -126200,0.15862188,0.018601093,,,,,,,,,,,,,,,,, -126300,0.14285086,0.01955306,,,,,,,,,,,,,,,,, -126400,0.14820617,0.016695365,,,,,,,,,,,,,,,,, -126500,0.13273029,0.01797876,,,,,,,,,,,,,,,,, -126600,0.157471,0.019856172,,,,,,,,,,,,,,,,, -126700,0.15029654,0.017606972,,,,,,,,,,,,,,,,, -126746,,,0.9955751299858092,0.0137756494805216,0.7885662798229311,0.9869915843009948,0.0507423616945743,0.2936974288321023,43793.0,0.9862112998962402,0.0540365986526012,0.2755557447850956,43793.0,40834.8824300766,62147.64252591133,40834.8824300766,21302.12158560753,7.071530342102051,0.0 -126800,0.16570754,0.018946407,,,,,,,,,,,,,,,,, -126900,0.16203268,0.018805722,,,,,,,,,,,,,,,,, -127000,0.1482818,0.017931812,,,,,,,,,,,,,,,,, -127100,0.1568285,0.020026907,,,,,,,,,,,,,,,,, -127200,0.15589446,0.017622147,,,,,,,,,,,,,,,,, -127300,0.158843,0.017794974,,,,,,,,,,,,,,,,, -127400,0.16423945,0.018847689,,,,,,,,,,,,,,,,, -127499,,,0.9956125020980836,0.0137626240029931,0.7903543820948858,0.9869915843009948,0.0507423616945743,0.2935955037536901,43793.0,0.9862112998962402,0.0540365986526012,0.2754435153360622,43793.0,41075.1099421978,62497.366092681885,41075.1099421978,21411.546427965164,7.12211012840271,0.0 -127500,0.12831327,0.016775867,,,,,,,,,,,,,,,,, -127600,0.12980399,0.015238967,,,,,,,,,,,,,,,,, -127700,0.1490977,0.017374747,,,,,,,,,,,,,,,,, -127800,0.13920826,0.015558467,,,,,,,,,,,,,,,,, -127900,0.14797547,0.017714484,,,,,,,,,,,,,,,,, -128000,0.14302884,0.016523689,,,,,,,,,,,,,,,,, -128100,0.15015894,0.017306857,,,,,,,,,,,,,,,,, -128200,0.14806888,0.017185565,,,,,,,,,,,,,,,,, -128248,,,0.9954938292503356,0.0139695256948471,0.771295520144599,0.9869915843009948,0.0507423616945743,0.2937692486438858,43793.0,0.9862112998962402,0.0540365986526012,0.2754312230509766,43793.0,41315.15427827835,62851.930923223495,41315.15427827835,21525.99359869957,7.175124883651733,0.0 -128300,0.14221002,0.018275354,,,,,,,,,,,,,,,,, -128400,0.15351497,0.019111194,,,,,,,,,,,,,,,,, -128500,0.15106253,0.016934486,,,,,,,,,,,,,,,,, -128600,0.16326319,0.01886797,,,,,,,,,,,,,,,,, -128700,0.13769408,0.017080205,,,,,,,,,,,,,,,,, -128800,0.14120913,0.017418459,,,,,,,,,,,,,,,,, -128900,0.15552205,0.018216353,,,,,,,,,,,,,,,,, -128991,,,0.995583951473236,0.0138546442613005,0.7902642742906548,0.9869915843009948,0.0507423616945743,0.2937646572325736,43793.0,0.9862112998962402,0.0540365986526012,0.2753627787110927,43793.0,41555.35489320755,63199.904326200485,41555.35489320755,21633.69455242157,7.226402759552002,0.0 -129000,0.15688795,0.019725626,,,,,,,,,,,,,,,,, -129100,0.15159306,0.019137518,,,,,,,,,,,,,,,,, -129200,0.14576793,0.018327344,,,,,,,,,,,,,,,,, -129300,0.15485461,0.018585892,,,,,,,,,,,,,,,,, -129400,0.12667608,0.017443003,,,,,,,,,,,,,,,,, -129500,0.14145894,0.01745692,,,,,,,,,,,,,,,,, -129600,0.14740291,0.0153911,,,,,,,,,,,,,,,,, -129700,0.15159953,0.016233444,,,,,,,,,,,,,,,,, -129733,,,0.9956200122833252,0.0136795286089181,0.7835308102014059,0.9869915843009948,0.0507423616945743,0.2935854843682218,43793.0,0.9862112998962402,0.0540366023778915,0.2755378317641422,43793.0,41795.50968170166,63548.26286840439,41795.50968170166,21741.826742887497,7.277673959732056,0.0 -129800,0.12505744,0.015092184,,,,,,,,,,,,,,,,, -129900,0.13601945,0.01651288,,,,,,,,,,,,,,,,, -130000,0.14091335,0.01670262,,,,,,,,,,,,,,,,, -130100,0.1434313,0.018897565,,,,,,,,,,,,,,,,, -130200,0.16867417,0.0185477,,,,,,,,,,,,,,,,, -130300,0.13637781,0.017484527,,,,,,,,,,,,,,,,, -130400,0.12958673,0.018151384,,,,,,,,,,,,,,,,, -130478,,,0.995602548122406,0.0138369193300604,0.7875081898725586,0.9869915843009948,0.0507423616945743,0.2936325389844466,43793.0,0.9862112998962402,0.0540365986526012,0.2754409507264886,43793.0,42035.57419133186,63897.71296691895,42035.57419133186,21851.141013383865,7.32834792137146,0.0 -130500,0.12951022,0.015151386,,,,,,,,,,,,,,,,, -130600,0.13602704,0.016752085,,,,,,,,,,,,,,,,, -130700,0.1568478,0.019455757,,,,,,,,,,,,,,,,, -130800,0.14854467,0.018062854,,,,,,,,,,,,,,,,, -130900,0.1484492,0.01635925,,,,,,,,,,,,,,,,, -131000,0.15636027,0.018213702,,,,,,,,,,,,,,,,, -131100,0.13160044,0.018541938,,,,,,,,,,,,,,,,, -131200,0.15143692,0.01851951,,,,,,,,,,,,,,,,, -131227,,,0.9955888390541076,0.0137723758816719,0.7732472724957443,0.9869915843009948,0.0507423616945743,0.2935959974947021,43793.0,0.9862112998962402,0.0540365986526012,0.2752876374633426,43793.0,42275.503985881805,64247.39361643791,42275.503985881805,21960.82028913498,7.379400491714477,0.0 -131300,0.13658875,0.016183209,,,,,,,,,,,,,,,,, -131400,0.15087578,0.01731454,,,,,,,,,,,,,,,,, -131500,0.13933842,0.01872542,,,,,,,,,,,,,,,,, -131600,0.13697568,0.017662972,,,,,,,,,,,,,,,,, -131700,0.14227575,0.017079668,,,,,,,,,,,,,,,,, -131800,0.13779685,0.017035952,,,,,,,,,,,,,,,,, -131900,0.12725104,0.016005518,,,,,,,,,,,,,,,,, -131969,,,0.9955713152885436,0.0138152642175555,0.7818256717297449,0.9869915843009948,0.0507423616945743,0.2936461241032045,43793.0,0.9862112998962402,0.0540365986526012,0.2754581241452384,43793.0,42515.66183280945,64600.09630036354,42515.66183280945,22073.290736675262,7.432871103286743,0.0 -132000,0.13783,0.016856635,,,,,,,,,,,,,,,,, -132100,0.1360454,0.016402887,,,,,,,,,,,,,,,,, -132200,0.14693892,0.017048627,,,,,,,,,,,,,,,,, -132300,0.14518201,0.01731007,,,,,,,,,,,,,,,,, -132400,0.14180268,0.01841746,,,,,,,,,,,,,,,,, -132500,0.14243828,0.016329903,,,,,,,,,,,,,,,,, -132600,0.16074403,0.017998178,,,,,,,,,,,,,,,,, -132700,0.1542607,0.017353794,,,,,,,,,,,,,,,,, -132722,,,0.9955384135246276,0.0139381103217601,0.7809306588978733,0.9869915843009948,0.0507423616945743,0.2938063416664886,43793.0,0.9862112998962402,0.0540365986526012,0.2754239669424201,43793.0,42755.7799987793,64950.33537364006,42755.7799987793,22183.3386054039,7.485567808151245,0.0 -132800,0.16859889,0.017660087,,,,,,,,,,,,,,,,, -132900,0.13992207,0.018227624,,,,,,,,,,,,,,,,, -133000,0.1350212,0.01667435,,,,,,,,,,,,,,,,, -133100,0.17403656,0.019111944,,,,,,,,,,,,,,,,, -133200,0.16037323,0.01906302,,,,,,,,,,,,,,,,, -133300,0.14919198,0.018190997,,,,,,,,,,,,,,,,, -133400,0.14153737,0.01654782,,,,,,,,,,,,,,,,, -133475,,,0.9956251978874208,0.0136889275163412,0.7866224635436307,0.9869915843009948,0.0507423616945743,0.2939696343293295,43793.0,0.9862112998962402,0.0540365986526012,0.2754217274149027,43793.0,42995.96339964867,65298.283219099045,42995.96339964867,22291.031487464905,7.536278247833252,0.0 -133500,0.17709841,0.018433586,,,,,,,,,,,,,,,,, -133600,0.14056386,0.016929438,,,,,,,,,,,,,,,,, -133700,0.14274877,0.01738773,,,,,,,,,,,,,,,,, -133800,0.14433832,0.017487925,,,,,,,,,,,,,,,,, -133900,0.12901944,0.01691027,,,,,,,,,,,,,,,,, -134000,0.14362894,0.01617048,,,,,,,,,,,,,,,,, -134100,0.1557912,0.018486593,,,,,,,,,,,,,,,,, -134200,0.15128912,0.018795153,,,,,,,,,,,,,,,,, -134223,,,0.9955657124519348,0.0139132514595985,0.785629117844441,0.9869915843009948,0.0507423616945743,0.2937110096408846,43793.0,0.9862112998962402,0.0540365986526012,0.275459098161773,43793.0,43236.10581612587,65646.3113667965,43236.10581612587,22398.84451198578,7.58860969543457,0.0 -134300,0.12582274,0.01671585,,,,,,,,,,,,,,,,, -134400,0.14407668,0.017797617,,,,,,,,,,,,,,,,, -134500,0.14029817,0.017594893,,,,,,,,,,,,,,,,, -134600,0.15293135,0.019311016,,,,,,,,,,,,,,,,, -134700,0.15537985,0.020135634,,,,,,,,,,,,,,,,, -134800,0.12489704,0.016685314,,,,,,,,,,,,,,,,, -134900,0.13030972,0.01625912,,,,,,,,,,,,,,,,, -134960,,,0.995591163635254,0.0138048967346549,0.7821396458418919,0.9869915843009948,0.0507423616945743,0.2935877750821879,43793.0,0.9862112998962402,0.0540365986526012,0.2754662314445237,43793.0,43476.32939147949,66001.26565551758,43476.32939147949,22513.485692977905,7.652945756912231,0.0 -135000,0.14499438,0.019420745,,,,,,,,,,,,,,,,, -135100,0.14769943,0.018480923,,,,,,,,,,,,,,,,, -135200,0.1495177,0.018065404,,,,,,,,,,,,,,,,, -135300,0.13567437,0.017599221,,,,,,,,,,,,,,,,, -135400,0.14765042,0.01839453,,,,,,,,,,,,,,,,, -135500,0.1386683,0.016274957,,,,,,,,,,,,,,,,, -135600,0.15752968,0.01963349,,,,,,,,,,,,,,,,, -135700,0.13159513,0.018154869,,,,,,,,,,,,,,,,, -135704,,,0.9955680966377258,0.0138288754969835,0.7725323100158961,0.9869915843009948,0.0507423616945743,0.2938210133629313,43793.0,0.9862112998962402,0.0540365986526012,0.2755756755415497,43793.0,43716.24765229225,66350.86281728745,43716.24765229225,22623.08250141144,7.711076021194458,0.0 -135800,0.15404803,0.0171021,,,,,,,,,,,,,,,,, -135900,0.13426937,0.015471307,,,,,,,,,,,,,,,,, -136000,0.13445707,0.01781574,,,,,,,,,,,,,,,,, -136100,0.13475968,0.018773042,,,,,,,,,,,,,,,,, -136200,0.14823079,0.017677383,,,,,,,,,,,,,,,,, -136300,0.13553548,0.016083447,,,,,,,,,,,,,,,,, -136400,0.14900039,0.017881336,,,,,,,,,,,,,,,,, -136453,,,0.9955561757087708,0.0138661824166774,0.7835748448083539,0.9869915843009948,0.0507423616945743,0.2936474347516125,43793.0,0.9862112998962402,0.0540365986526012,0.2753951118399199,43793.0,43956.23304724693,66702.2934141159,43956.23304724693,22734.455174207687,7.762632131576538,0.0 -136500,0.14913489,0.018028049,,,,,,,,,,,,,,,,, -136600,0.14565656,0.017839655,,,,,,,,,,,,,,,,, -136700,0.1327808,0.017183937,,,,,,,,,,,,,,,,, -136800,0.14772289,0.016540943,,,,,,,,,,,,,,,,, -136900,0.1317501,0.015128461,,,,,,,,,,,,,,,,, -137000,0.15081899,0.019060485,,,,,,,,,,,,,,,,, -137100,0.1598909,0.017788835,,,,,,,,,,,,,,,,, -137200,0.17104878,0.018297119,,,,,,,,,,,,,,,,, -137202,,,0.9955800175666808,0.0137928072363138,0.7885792245747008,0.9869915843009948,0.0507423616945743,0.2936389094944352,43793.0,0.9862112998962402,0.0540365986526012,0.2754957400432071,43793.0,44196.325745821,67052.75707435608,44196.325745821,22844.751775741577,7.815975904464722,0.0 -137300,0.16377302,0.017202705,,,,,,,,,,,,,,,,, -137400,0.14090207,0.01596123,,,,,,,,,,,,,,,,, -137500,0.14766559,0.019248461,,,,,,,,,,,,,,,,, -137600,0.15950455,0.015439349,,,,,,,,,,,,,,,,, -137700,0.13437629,0.016370354,,,,,,,,,,,,,,,,, -137800,0.15159321,0.017606584,,,,,,,,,,,,,,,,, -137900,0.1446206,0.016807595,,,,,,,,,,,,,,,,, -137945,,,0.9956175088882446,0.0137908682227134,0.7859297351781429,0.9869915843009948,0.0507423616945743,0.2936445535571279,43793.0,0.9862112998962402,0.0540365986526012,0.2754197051375271,43793.0,44436.34388566017,67402.94797158241,44436.34388566017,22954.84996652603,7.869782209396362,0.0 -138000,0.16228092,0.019061638,,,,,,,,,,,,,,,,, -138100,0.15679285,0.019643119,,,,,,,,,,,,,,,,, -138200,0.1405126,0.016621552,,,,,,,,,,,,,,,,, -138300,0.16141954,0.01982395,,,,,,,,,,,,,,,,, -138400,0.15490705,0.017412152,,,,,,,,,,,,,,,,, -138500,0.14973071,0.018192101,,,,,,,,,,,,,,,,, -138600,0.15012072,0.016523879,,,,,,,,,,,,,,,,, -138693,,,0.9955361485481262,0.0139456894248723,0.7821278781897788,0.9869915843009948,0.0507423616945743,0.2938877070849096,43793.0,0.9862112998962402,0.0540365986526012,0.2755277708514029,43793.0,44676.512374162674,67753.59470915794,44676.512374162674,23065.253385066982,7.923959970474243,0.0 -138700,0.13367204,0.018233132,,,,,,,,,,,,,,,,, -138800,0.15456787,0.017022723,,,,,,,,,,,,,,,,, -138900,0.14221384,0.01765146,,,,,,,,,,,,,,,,, -139000,0.14165027,0.016624844,,,,,,,,,,,,,,,,, -139100,0.12488317,0.015931237,,,,,,,,,,,,,,,,, -139200,0.14387296,0.01893507,,,,,,,,,,,,,,,,, -139300,0.1675271,0.019752888,,,,,,,,,,,,,,,,, -139400,0.14651716,0.018315341,,,,,,,,,,,,,,,,, -139441,,,0.9956095814704896,0.0136872818693518,0.7837841741459687,0.9869915843009948,0.0507423616945743,0.2937139301409759,43793.0,0.9862112998962402,0.0540365986526012,0.2754072478578473,43793.0,44916.49689507485,68103.3846859932,44916.49689507485,23174.9851911068,7.976691007614136,0.0 -139500,0.13570783,0.016098205,,,,,,,,,,,,,,,,, -139600,0.13858287,0.016757715,,,,,,,,,,,,,,,,, -139700,0.15559867,0.017335655,,,,,,,,,,,,,,,,, -139800,0.17564616,0.018692743,,,,,,,,,,,,,,,,, -139900,0.16361573,0.019106649,,,,,,,,,,,,,,,,, -140000,0.15639752,0.018909464,,,,,,,,,,,,,,,,, -140100,0.15426634,0.02005693,,,,,,,,,,,,,,,,, -140189,,,0.9955382347106934,0.0139403715729713,0.7720245902031437,0.9869915843009948,0.0507423616945743,0.2935594829764194,43793.0,0.9862112998962402,0.0540365986526012,0.2753595513198698,43793.0,45156.51519060135,68456.67514777184,45156.51519060135,23288.182639360428,8.0308096408844,0.0 -140200,0.14149037,0.018028108,,,,,,,,,,,,,,,,, -140300,0.14662337,0.017827548,,,,,,,,,,,,,,,,, -140400,0.15292215,0.021209916,,,,,,,,,,,,,,,,, -140500,0.14317128,0.017449599,,,,,,,,,,,,,,,,, -140600,0.14508173,0.01924443,,,,,,,,,,,,,,,,, -140700,0.15079556,0.017789098,,,,,,,,,,,,,,,,, -140800,0.14583096,0.017806439,,,,,,,,,,,,,,,,, -140900,0.13763456,0.016599532,,,,,,,,,,,,,,,,, -140920,,,0.9955587983131408,0.0138416392728686,0.7846010270148847,0.9869915843009948,0.0507423616945743,0.2938899450060268,43793.0,0.9862112998962402,0.0540365986526012,0.275453484211465,43793.0,45396.50525188446,68808.08483695984,45396.50525188446,23399.51488018036,8.09301209449768,0.0 -141000,0.14504424,0.017424976,,,,,,,,,,,,,,,,, -141100,0.13822931,0.017183267,,,,,,,,,,,,,,,,, -141200,0.15013976,0.016877538,,,,,,,,,,,,,,,,, -141300,0.16268227,0.017683804,,,,,,,,,,,,,,,,, -141400,0.18048067,0.022051169,,,,,,,,,,,,,,,,, -141500,0.15972681,0.017813604,,,,,,,,,,,,,,,,, -141600,0.14549877,0.018725459,,,,,,,,,,,,,,,,, -141660,,,0.9956005215644836,0.0138408299535512,0.7835544456428272,0.9869915843009948,0.0507423616945743,0.2935728189413232,43793.0,0.9862112998962402,0.0540365986526012,0.2754512949348238,43793.0,45636.494839668274,69158.09258151054,45636.494839668274,23509.45137310028,8.151569604873657,0.0 -141700,0.14105146,0.017449351,,,,,,,,,,,,,,,,, -141800,0.13712668,0.016867038,,,,,,,,,,,,,,,,, -141900,0.14014933,0.016832283,,,,,,,,,,,,,,,,, -142000,0.15193394,0.01697139,,,,,,,,,,,,,,,,, -142100,0.15088026,0.018723419,,,,,,,,,,,,,,,,, -142200,0.15082344,0.019449998,,,,,,,,,,,,,,,,, -142300,0.13407636,0.016638065,,,,,,,,,,,,,,,,, -142400,0.17914416,0.02008057,,,,,,,,,,,,,,,,, -142412,,,0.9955870509147644,0.0138025749474763,0.7837377125709595,0.9869915843009948,0.0507423616945743,0.2938052738879229,43793.0,0.9862112998962402,0.0540365986526012,0.2754350355381713,43793.0,45876.57214689255,69505.11353802681,45876.57214689255,23616.3204202652,8.205451011657715,0.0 -142500,0.14247572,0.017373024,,,,,,,,,,,,,,,,, -142600,0.14368193,0.017524669,,,,,,,,,,,,,,,,, -142700,0.15906946,0.017874919,,,,,,,,,,,,,,,,, -142800,0.1505693,0.01893993,,,,,,,,,,,,,,,,, -142900,0.13890652,0.016623331,,,,,,,,,,,,,,,,, -143000,0.15316126,0.018504092,,,,,,,,,,,,,,,,, -143100,0.15501703,0.01907357,,,,,,,,,,,,,,,,, -143162,,,0.9956026077270508,0.0137334009632468,0.7859749685301147,0.9869915843009948,0.0507423616945743,0.2939219873827208,43793.0,0.9862112998962402,0.0540366023778915,0.2753971686265937,43793.0,46116.80425333977,69851.81775379181,46116.80425333977,23722.715693950653,8.261404752731323,0.0 -143200,0.14912854,0.018456616,,,,,,,,,,,,,,,,, -143300,0.1515271,0.017295748,,,,,,,,,,,,,,,,, -143400,0.13966608,0.016204072,,,,,,,,,,,,,,,,, -143500,0.1489136,0.018354168,,,,,,,,,,,,,,,,, -143600,0.15321271,0.018193137,,,,,,,,,,,,,,,,, -143700,0.15632564,0.020259729,,,,,,,,,,,,,,,,, -143800,0.13355672,0.017219204,,,,,,,,,,,,,,,,, -143900,0.13538463,0.016438441,,,,,,,,,,,,,,,,, -143911,,,0.9955673217773438,0.0138666266575455,0.7774714317441656,0.9869915843009948,0.0507423616945743,0.2937292203462356,43793.0,0.9862112998962402,0.0540365986526012,0.2754010988589716,43793.0,46356.86944055557,70200.4524781704,46356.86944055557,23831.21180343628,8.31418228149414,0.0 -144000,0.12830798,0.01560029,,,,,,,,,,,,,,,,, -144100,0.16063827,0.018601842,,,,,,,,,,,,,,,,, -144200,0.14364834,0.018385062,,,,,,,,,,,,,,,,, -144300,0.13854936,0.01740837,,,,,,,,,,,,,,,,, -144400,0.14567582,0.018932411,,,,,,,,,,,,,,,,, -144500,0.1531459,0.018014183,,,,,,,,,,,,,,,,, -144600,0.12561618,0.014139411,,,,,,,,,,,,,,,,, -144659,,,0.995541512966156,0.0138751147314906,0.7791450319607318,0.9869915843009948,0.0507423616945743,0.2937038021491862,43793.0,0.9862112998962402,0.0540365986526012,0.2754712819334631,43793.0,46596.89909791946,70552.47891068459,46596.89909791946,23943.13414955139,8.367965698242188,0.0 -144700,0.14951181,0.017391592,,,,,,,,,,,,,,,,, -144800,0.13097271,0.01703533,,,,,,,,,,,,,,,,, -144900,0.13642527,0.01770466,,,,,,,,,,,,,,,,, -145000,0.13622518,0.016658641,,,,,,,,,,,,,,,,, -145100,0.15736422,0.017827397,,,,,,,,,,,,,,,,, -145200,0.15351503,0.016905488,,,,,,,,,,,,,,,,, -145300,0.14941184,0.016410263,,,,,,,,,,,,,,,,, -145400,0.14803904,0.018576961,,,,,,,,,,,,,,,,, -145404,,,0.9955787062644958,0.0138519490137696,0.7859469751789931,0.9869915843009948,0.0507423616945743,0.2935924268937204,43793.0,0.9862112998962402,0.0540365986526012,0.2754216593004718,43793.0,46837.16341853142,70899.0400402546,46837.16341853142,24049.355276346207,8.423321008682251,0.0 -145500,0.17600547,0.021012938,,,,,,,,,,,,,,,,, -145600,0.16191968,0.01749993,,,,,,,,,,,,,,,,, -145700,0.13926697,0.015259419,,,,,,,,,,,,,,,,, -145800,0.13902095,0.01754735,,,,,,,,,,,,,,,,, -145900,0.14178465,0.015638106,,,,,,,,,,,,,,,,, -146000,0.13208973,0.0154956775,,,,,,,,,,,,,,,,, -146100,0.14525239,0.016774738,,,,,,,,,,,,,,,,, -146157,,,0.9956125020980836,0.0137772569432854,0.7824314669110136,0.9869915843009948,0.0507423616945743,0.2937127948568798,43793.0,0.9862112998962402,0.0540365986526012,0.2754313374050625,43793.0,47077.318110466,71246.9433298111,47077.318110466,24157.030386209488,8.47667145729065,0.0 -146200,0.12754412,0.016486311,,,,,,,,,,,,,,,,, -146300,0.16603665,0.019511724,,,,,,,,,,,,,,,,, -146400,0.15292916,0.018061312,,,,,,,,,,,,,,,,, -146500,0.1329662,0.016970202,,,,,,,,,,,,,,,,, -146600,0.13916032,0.016280277,,,,,,,,,,,,,,,,, -146700,0.1421198,0.016934697,,,,,,,,,,,,,,,,, -146800,0.14730638,0.018628504,,,,,,,,,,,,,,,,, -146900,0.15246731,0.018704787,,,,,,,,,,,,,,,,, -146908,,,0.9955742955207824,0.013849351555109,0.7841609706569808,0.9869915843009948,0.0507423616945743,0.2936861766004716,43793.0,0.9862112998962402,0.0540365986526012,0.2754988653620888,43793.0,47317.36692094803,71594.48837161064,47317.36692094803,24264.45087337494,8.5319983959198,0.0 -147000,0.15785141,0.016735198,,,,,,,,,,,,,,,,, -147100,0.1364238,0.016811853,,,,,,,,,,,,,,,,, -147200,0.15127045,0.0177988,,,,,,,,,,,,,,,,, -147300,0.14717852,0.018318877,,,,,,,,,,,,,,,,, -147400,0.13159505,0.016341051,,,,,,,,,,,,,,,,, -147500,0.14334793,0.017388007,,,,,,,,,,,,,,,,, -147600,0.14116433,0.018720448,,,,,,,,,,,,,,,,, -147650,,,0.9956088066101074,0.0137588586658239,0.7803652558375634,0.9869915843009948,0.0507423616945743,0.2936537150876402,43793.0,0.9862112998962402,0.0540365986526012,0.2754986839728315,43793.0,47557.4769847393,71946.58534765244,47557.4769847393,24376.361690044403,8.58792233467102,0.0 -147700,0.15073465,0.017237674,,,,,,,,,,,,,,,,, -147800,0.13100246,0.017802287,,,,,,,,,,,,,,,,, -147900,0.12189998,0.016054392,,,,,,,,,,,,,,,,, -148000,0.140568,0.019639116,,,,,,,,,,,,,,,,, -148100,0.1418875,0.016368356,,,,,,,,,,,,,,,,, -148200,0.16693863,0.01875167,,,,,,,,,,,,,,,,, -148300,0.14330713,0.017014949,,,,,,,,,,,,,,,,, -148396,,,0.9955193400382996,0.0139368381351232,0.7758290701438132,0.9869915843009948,0.0507423616945743,0.2937131747472333,43793.0,0.9862112998962402,0.0540365986526012,0.2754836519611033,43793.0,47797.62417554855,72292.47641682625,47797.62417554855,24482.029361486435,8.643598794937134,0.0 -148400,0.1358851,0.01705705,,,,,,,,,,,,,,,,, -148500,0.14638442,0.016961556,,,,,,,,,,,,,,,,, -148600,0.12873273,0.015659995,,,,,,,,,,,,,,,,, -148700,0.1527458,0.017536812,,,,,,,,,,,,,,,,, -148800,0.15302296,0.017716099,,,,,,,,,,,,,,,,, -148900,0.14822173,0.017798685,,,,,,,,,,,,,,,,, -149000,0.14772199,0.01697248,,,,,,,,,,,,,,,,, -149100,0.14627501,0.017452022,,,,,,,,,,,,,,,,, -149150,,,0.9955896735191344,0.0137344496324658,0.7878894749972761,0.9869915843009948,0.0507423616945743,0.2935785880532318,43793.0,0.9862112998962402,0.0540365986526012,0.2754176262087233,43793.0,48037.61570096016,72640.80835223198,48037.61570096016,24590.29389500618,8.698626518249512,0.0 -149200,0.13974993,0.016679158,,,,,,,,,,,,,,,,, -149300,0.14193302,0.01691948,,,,,,,,,,,,,,,,, -149400,0.14708444,0.017839152,,,,,,,,,,,,,,,,, -149500,0.14880472,0.018578906,,,,,,,,,,,,,,,,, -149600,0.1284476,0.015750628,,,,,,,,,,,,,,,,, -149700,0.17180501,0.019102443,,,,,,,,,,,,,,,,, -149800,0.15426832,0.019783294,,,,,,,,,,,,,,,,, -149900,0.1489691,0.017377181,,,,,,,,,,,,,,,,, -149904,,,0.9956307411193848,0.0137735474854707,0.7877537249143864,0.9869915843009948,0.0507423616945743,0.2937076776102705,43793.0,0.9862112998962402,0.0540366023778915,0.2754609112657465,43793.0,48277.7650783062,72987.3367960453,48277.7650783062,24696.59809613228,8.752968549728394,0.0 -150000,0.14621882,0.018297134,,,,,,,,,,,,,,,,, -150100,0.16746844,0.018356498,,,,,,,,,,,,,,,,, -150200,0.14874013,0.017882356,,,,,,,,,,,,,,,,, -150300,0.13569249,0.016637426,,,,,,,,,,,,,,,,, -150400,0.12688434,0.01677423,,,,,,,,,,,,,,,,, -150500,0.14238583,0.016972918,,,,,,,,,,,,,,,,, -150600,0.14475586,0.017280314,,,,,,,,,,,,,,,,, -150660,,,0.9955337047576904,0.0139522580429911,0.7795977649702918,0.9869915843009948,0.0507423616945743,0.2937253985372344,43793.0,0.9862112998962402,0.0540365986526012,0.275389384030019,43793.0,48517.886543273926,73339.16317415237,48517.886543273926,24808.22857093811,8.807228088378906,0.0 -150700,0.14568129,0.017097482,,,,,,,,,,,,,,,,, -150800,0.1485446,0.017201422,,,,,,,,,,,,,,,,, -150900,0.13788816,0.015863188,,,,,,,,,,,,,,,,, -151000,0.13696319,0.016587876,,,,,,,,,,,,,,,,, -151100,0.1409617,0.017687602,,,,,,,,,,,,,,,,, -151200,0.14758815,0.015015144,,,,,,,,,,,,,,,,, -151300,0.12447025,0.015351679,,,,,,,,,,,,,,,,, -151400,0.16513476,0.018822907,,,,,,,,,,,,,,,,, -151405,,,0.995615005493164,0.0136622358113527,0.7890784196348725,0.9869915843009948,0.0507423616945743,0.2936964955645312,43793.0,0.9862112998962402,0.0540365986526012,0.2754052718551408,43793.0,48757.82743740082,73686.6535577774,48757.82743740082,24915.70268535614,8.862153768539429,0.0 -151500,0.1340918,0.01679213,,,,,,,,,,,,,,,,, -151600,0.1452802,0.016382298,,,,,,,,,,,,,,,,, -151700,0.14074874,0.017181085,,,,,,,,,,,,,,,,, -151800,0.15215397,0.017767461,,,,,,,,,,,,,,,,, -151900,0.1457435,0.017966555,,,,,,,,,,,,,,,,, -152000,0.1379364,0.017183278,,,,,,,,,,,,,,,,, -152100,0.14439884,0.01706735,,,,,,,,,,,,,,,,, -152164,,,0.995550274848938,0.0139026893302798,0.7697975941964679,0.9869915843009948,0.0507423616945743,0.2937655585787352,43793.0,0.9862112998962402,0.0540366023778915,0.2754099633235572,43793.0,48997.56856369972,74037.83790445328,48997.56856369972,25026.75702905655,9.230566263198853,0.0 -152200,0.16483137,0.019553933,,,,,,,,,,,,,,,,, -152300,0.14131965,0.016775014,,,,,,,,,,,,,,,,, -152400,0.146534,0.01812243,,,,,,,,,,,,,,,,, -152500,0.14114858,0.019268857,,,,,,,,,,,,,,,,, -152600,0.1557741,0.019940803,,,,,,,,,,,,,,,,, -152700,0.13759872,0.017129619,,,,,,,,,,,,,,,,, -152800,0.13357085,0.017048297,,,,,,,,,,,,,,,,, -152900,0.12678374,0.016023083,,,,,,,,,,,,,,,,, -152916,,,0.995542049407959,0.0139244012534618,0.7826677908654057,0.9869915843009948,0.0507423616945743,0.2936824436732723,43793.0,0.9862112998962402,0.0540365986526012,0.2754459852629338,43793.0,49237.56392478943,74388.833922863,49237.56392478943,25137.68102812767,9.286834478378296,0.0 -153000,0.15616737,0.018189752,,,,,,,,,,,,,,,,, -153100,0.1648441,0.017888991,,,,,,,,,,,,,,,,, -153200,0.1433569,0.019114004,,,,,,,,,,,,,,,,, -153300,0.14715974,0.016566202,,,,,,,,,,,,,,,,, -153400,0.17338037,0.01965839,,,,,,,,,,,,,,,,, -153500,0.14416319,0.018651552,,,,,,,,,,,,,,,,, -153600,0.16090164,0.016772497,,,,,,,,,,,,,,,,, -153672,,,0.9956010580062866,0.0138049824163317,0.7902975906992156,0.9869915843009948,0.0507423616945743,0.2937089946570169,43793.0,0.9862112998962402,0.0540365986526012,0.2753901263563581,43793.0,49477.695827007294,74736.78531956673,49477.695827007294,25245.425184965134,9.341866731643677,0.0 -153700,0.1252511,0.015135165,,,,,,,,,,,,,,,,, -153800,0.13305157,0.017937811,,,,,,,,,,,,,,,,, -153900,0.14059499,0.0181537,,,,,,,,,,,,,,,,, -154000,0.15140636,0.018566592,,,,,,,,,,,,,,,,, -154100,0.14721559,0.01740176,,,,,,,,,,,,,,,,, -154200,0.15939376,0.020182028,,,,,,,,,,,,,,,,, -154300,0.1357182,0.015552586,,,,,,,,,,,,,,,,, -154400,0.14868604,0.018138537,,,,,,,,,,,,,,,,, -154432,,,0.9955745935440063,0.0138638485223054,0.7789048942518366,0.9869915843009948,0.0507423616945743,0.2936350856274303,43793.0,0.9862112998962402,0.0540366023778915,0.2753254195522341,43793.0,49717.68036675453,75079.89323925972,49717.68036675453,25348.473020792007,9.397128582000732,0.0 -154500,0.139193,0.015853854,,,,,,,,,,,,,,,,, -154600,0.16016379,0.017393222,,,,,,,,,,,,,,,,, -154700,0.14702217,0.018198151,,,,,,,,,,,,,,,,, -154800,0.12999809,0.015733954,,,,,,,,,,,,,,,,, -154900,0.14231576,0.018168231,,,,,,,,,,,,,,,,, -155000,0.14061874,0.016845558,,,,,,,,,,,,,,,,, -155100,0.13609038,0.017280048,,,,,,,,,,,,,,,,, -155180,,,0.995599091053009,0.013690123334527,0.7854078741192276,0.9869915843009948,0.0507423616945743,0.2937296730739804,43793.0,0.9862112998962402,0.0540365986526012,0.2754519692100451,43793.0,49957.67692351341,75430.52669262886,49957.67692351341,25459.034336566925,9.452574968338013,0.0 -155200,0.15284462,0.018989814,,,,,,,,,,,,,,,,, -155300,0.13198341,0.016296232,,,,,,,,,,,,,,,,, -155400,0.15297571,0.017716214,,,,,,,,,,,,,,,,, -155500,0.13596547,0.0151714925,,,,,,,,,,,,,,,,, -155600,0.14705722,0.017144457,,,,,,,,,,,,,,,,, -155700,0.14789863,0.018220844,,,,,,,,,,,,,,,,, -155800,0.15092649,0.01695135,,,,,,,,,,,,,,,,, -155900,0.14175463,0.016632788,,,,,,,,,,,,,,,,, -155931,,,0.9955906271934508,0.0138539336621761,0.773038709091723,0.9869915843009948,0.0507423616945743,0.2937262505472416,43793.0,0.9862112998962402,0.0540365986526012,0.2754514664735991,43793.0,50197.66117596626,75783.52428460121,50197.66117596626,25571.970460414886,9.509182453155518,0.0 -156000,0.15067367,0.019854365,,,,,,,,,,,,,,,,, -156100,0.13017145,0.01641094,,,,,,,,,,,,,,,,, -156200,0.14723465,0.017117921,,,,,,,,,,,,,,,,, -156300,0.13329913,0.015131775,,,,,,,,,,,,,,,,, -156400,0.15445055,0.01861254,,,,,,,,,,,,,,,,, -156500,0.16139299,0.018720062,,,,,,,,,,,,,,,,, -156600,0.15118401,0.018211566,,,,,,,,,,,,,,,,, -156677,,,0.9955081939697266,0.0139719014987349,0.777331023418733,0.9869915843009948,0.0507423616945743,0.2936683234079764,43793.0,0.9862112998962402,0.0540365986526012,0.2753889199464259,43793.0,50437.83835029602,76130.00862145424,50437.83835029602,25678.198181152344,9.566497564315796,0.0 -156700,0.15127823,0.019013535,,,,,,,,,,,,,,,,, -156800,0.14436588,0.018160874,,,,,,,,,,,,,,,,, -156900,0.14253458,0.018386504,,,,,,,,,,,,,,,,, -157000,0.15647256,0.018954348,,,,,,,,,,,,,,,,, -157100,0.1464619,0.017956816,,,,,,,,,,,,,,,,, -157200,0.13937822,0.018451443,,,,,,,,,,,,,,,,, -157300,0.14564085,0.018107276,,,,,,,,,,,,,,,,, -157400,0.13179351,0.016655639,,,,,,,,,,,,,,,,, -157433,,,0.9956135749816896,0.0137933911755681,0.7892423439032692,0.9869915843009948,0.0507423616945743,0.293717954748,43793.0,0.9862112998962402,0.0540365986526012,0.2753973492261922,43793.0,50677.798704624176,76475.43181467056,50677.798704624176,25783.583258152008,9.623881101608276,0.0 -157500,0.14866872,0.019681118,,,,,,,,,,,,,,,,, -157600,0.14618883,0.019647095,,,,,,,,,,,,,,,,, -157700,0.1385911,0.017218735,,,,,,,,,,,,,,,,, -157800,0.14537972,0.018845528,,,,,,,,,,,,,,,,, -157900,0.16321082,0.018963017,,,,,,,,,,,,,,,,, -158000,0.12676226,0.01764143,,,,,,,,,,,,,,,,, -158100,0.15220064,0.018656181,,,,,,,,,,,,,,,,, -158182,,,0.995617926120758,0.0137408478185534,0.7845530195903894,0.9869915843009948,0.0507423616945743,0.2936882202342486,43793.0,0.9862112998962402,0.0540365986526012,0.275407186043599,43793.0,50917.96872997284,76825.63581681252,50917.96872997284,25893.5378510952,9.682755708694458,0.0 -158200,0.14736632,0.018445566,,,,,,,,,,,,,,,,, -158300,0.14127868,0.016413521,,,,,,,,,,,,,,,,, -158400,0.13915588,0.015954465,,,,,,,,,,,,,,,,, -158500,0.13356422,0.0153936045,,,,,,,,,,,,,,,,, -158600,0.1461466,0.017685475,,,,,,,,,,,,,,,,, -158700,0.14469956,0.015443931,,,,,,,,,,,,,,,,, -158800,0.14748338,0.018414695,,,,,,,,,,,,,,,,, -158900,0.12755235,0.017896675,,,,,,,,,,,,,,,,, -158929,,,0.9955463409423828,0.0138870431110262,0.7822689544969321,0.9869915843009948,0.0507423616945743,0.2936670853993187,43793.0,0.9862112998962402,0.0540365986526012,0.2754564632069592,43793.0,51158.21628952026,77174.72514629364,51158.21628952026,26002.302206754684,9.739495277404783,0.0 -159000,0.13571122,0.01577144,,,,,,,,,,,,,,,,, -159100,0.15661849,0.018584538,,,,,,,,,,,,,,,,, -159200,0.14104205,0.0153840175,,,,,,,,,,,,,,,,, -159300,0.15038602,0.01872499,,,,,,,,,,,,,,,,, -159400,0.16289347,0.018646,,,,,,,,,,,,,,,,, -159500,0.13729186,0.017612038,,,,,,,,,,,,,,,,, -159600,0.1505216,0.018789709,,,,,,,,,,,,,,,,, -159684,,,0.9955745339393616,0.0138083333149552,0.7875565170609959,0.9869915843009948,0.0507423616945743,0.2936776762052396,43793.0,0.9862112998962402,0.0540365986526012,0.2754188429540947,43793.0,51398.28981542587,77526.43344187737,51398.28981542587,26113.859795093536,9.795924425125122,0.0 -159700,0.14703806,0.016919391,,,,,,,,,,,,,,,,, -159800,0.15937041,0.019252354,,,,,,,,,,,,,,,,, -159900,0.13666439,0.014587996,,,,,,,,,,,,,,,,, -160000,0.13847373,0.017851083,,,,,,,,,,,,,,,,, -160100,0.15752605,0.017833281,,,,,,,,,,,,,,,,, -160200,0.15120424,0.01737624,,,,,,,,,,,,,,,,, -160300,0.1257811,0.018855654,,,,,,,,,,,,,,,,, -160400,0.14366508,0.01628923,,,,,,,,,,,,,,,,, -160424,,,0.9955351948738098,0.0138880098238587,0.774616761796803,0.9869915843009948,0.0507423616945743,0.2936985274524972,43793.0,0.9862112998962402,0.0540365986526012,0.2754478020806966,43793.0,51638.4400715828,77871.21443009377,51638.4400715828,26218.408900260925,9.85430383682251,0.0 -160500,0.1369601,0.016142935,,,,,,,,,,,,,,,,, -160600,0.13510922,0.015967742,,,,,,,,,,,,,,,,, -160700,0.1366737,0.013510339,,,,,,,,,,,,,,,,, -160800,0.14997725,0.017680436,,,,,,,,,,,,,,,,, -160900,0.14165276,0.018332364,,,,,,,,,,,,,,,,, -161000,0.1390205,0.019276299,,,,,,,,,,,,,,,,, -161100,0.15468208,0.019579059,,,,,,,,,,,,,,,,, -161174,,,0.9956015944480896,0.0137695474550127,0.7862721595759705,0.9869915843009948,0.0507423616945743,0.2936892479236363,43793.0,0.9862112998962402,0.0540365986526012,0.2754757861197963,43793.0,51878.58922076225,78223.65838193893,51878.58922076225,26330.62533640861,9.911830186843872,0.0 -161200,0.15143187,0.017785756,,,,,,,,,,,,,,,,, -161300,0.15069601,0.016998531,,,,,,,,,,,,,,,,, -161400,0.153254,0.019850228,,,,,,,,,,,,,,,,, -161500,0.15249306,0.01716153,,,,,,,,,,,,,,,,, -161600,0.13386275,0.016912788,,,,,,,,,,,,,,,,, -161700,0.15323548,0.019745206,,,,,,,,,,,,,,,,, -161800,0.13622245,0.017129997,,,,,,,,,,,,,,,,, -161900,0.14303434,0.018333305,,,,,,,,,,,,,,,,, -161916,,,0.9956266283988952,0.0137533387169241,0.7849122736140868,0.9869915843009948,0.0507423616945743,0.2937739629217785,43793.0,0.9862112998962402,0.0540365986526012,0.2754001764751488,43793.0,52118.57286572456,78571.4911108017,52118.57286572456,26438.3941590786,9.969873428344728,0.0 -162000,0.13040146,0.016683944,,,,,,,,,,,,,,,,, -162100,0.15276623,0.018666534,,,,,,,,,,,,,,,,, -162200,0.14844264,0.019747976,,,,,,,,,,,,,,,,, -162300,0.1632085,0.019789385,,,,,,,,,,,,,,,,, -162400,0.14758497,0.01888692,,,,,,,,,,,,,,,,, -162500,0.14433333,0.017166022,,,,,,,,,,,,,,,,, -162600,0.14890541,0.017269278,,,,,,,,,,,,,,,,, -162659,,,0.9955690503120422,0.0138960359618067,0.7824925082203897,0.9869915843009948,0.0507423616945743,0.2936022995657795,43793.0,0.9862112998962402,0.0540365986526012,0.2755013318092121,43793.0,52358.634991168976,78922.59042525291,52358.634991168976,26549.352571725845,10.02830934524536,0.0 -162700,0.1439346,0.017186286,,,,,,,,,,,,,,,,, -162800,0.12123146,0.014849285,,,,,,,,,,,,,,,,, -162900,0.14913267,0.017675519,,,,,,,,,,,,,,,,, -163000,0.14345746,0.01766222,,,,,,,,,,,,,,,,, -163100,0.15467694,0.020251894,,,,,,,,,,,,,,,,, -163200,0.15024759,0.01832878,,,,,,,,,,,,,,,,, -163300,0.13875607,0.01819755,,,,,,,,,,,,,,,,, -163400,0.1412025,0.017434847,,,,,,,,,,,,,,,,, -163411,,,0.9955653548240662,0.0137989101931452,0.7842163827697879,0.9869915843009948,0.0507423616945743,0.2935796978506215,43793.0,0.9862112998962402,0.0540365986526012,0.2753319317863094,43793.0,52598.57873725891,79271.2127199173,52598.57873725891,26657.94908690453,10.08947467803955,0.0 -163500,0.14673918,0.016771894,,,,,,,,,,,,,,,,, -163600,0.15558164,0.018804891,,,,,,,,,,,,,,,,, -163700,0.13901176,0.01734487,,,,,,,,,,,,,,,,, -163800,0.14428642,0.017589916,,,,,,,,,,,,,,,,, -163900,0.15102868,0.01616215,,,,,,,,,,,,,,,,, -164000,0.13540572,0.017013047,,,,,,,,,,,,,,,,, -164100,0.13361232,0.0169281,,,,,,,,,,,,,,,,, -164163,,,0.9955796599388124,0.0137862637639045,0.7746348160789243,0.9869915843009948,0.0507423616945743,0.2935902334093814,43793.0,0.9862112998962402,0.0540365986526012,0.2754401584745454,43793.0,52838.597648859024,79616.03345918655,52838.597648859024,26762.67085957527,10.14896845817566,0.0 -164200,0.12540849,0.01734112,,,,,,,,,,,,,,,,, -164300,0.15671915,0.020124815,,,,,,,,,,,,,,,,, -164400,0.14106715,0.018876446,,,,,,,,,,,,,,,,, -164500,0.15213996,0.017588794,,,,,,,,,,,,,,,,, -164600,0.133922,0.016560158,,,,,,,,,,,,,,,,, -164700,0.14479207,0.01767401,,,,,,,,,,,,,,,,, -164800,0.15107948,0.02003404,,,,,,,,,,,,,,,,, -164900,0.16694719,0.018270107,,,,,,,,,,,,,,,,, -164911,,,0.9955300688743592,0.0139853237196803,0.7824321755727822,0.9869915843009948,0.0507423616945743,0.2937889285679689,43793.0,0.9862112998962402,0.0540365986526012,0.2754810781515628,43793.0,53078.761633872986,79959.98651099205,53078.761633872986,26866.38114452362,10.207342863082886,0.0 -165000,0.13873178,0.017439192,,,,,,,,,,,,,,,,, -165100,0.13865592,0.015835132,,,,,,,,,,,,,,,,, -165200,0.14539994,0.017317954,,,,,,,,,,,,,,,,, -165300,0.13537082,0.016727775,,,,,,,,,,,,,,,,, -165400,0.16118631,0.016892351,,,,,,,,,,,,,,,,, -165500,0.14170374,0.018044993,,,,,,,,,,,,,,,,, -165600,0.15022427,0.01947431,,,,,,,,,,,,,,,,, -165668,,,0.9955757260322572,0.0138637935742735,0.7839171197252315,0.9869915843009948,0.0507423616945743,0.2936526321692164,43793.0,0.9862112998962402,0.0540366023778915,0.275376862591533,43793.0,53318.756410360336,80303.61129760742,53318.756410360336,26969.93276333809,10.265093803405762,0.0 -165700,0.16468498,0.020013567,,,,,,,,,,,,,,,,, -165800,0.14389575,0.018552277,,,,,,,,,,,,,,,,, -165900,0.14280282,0.018062554,,,,,,,,,,,,,,,,, -166000,0.14859901,0.01832233,,,,,,,,,,,,,,,,, -166100,0.13343365,0.01751903,,,,,,,,,,,,,,,,, -166200,0.15440193,0.017948987,,,,,,,,,,,,,,,,, -166300,0.13185622,0.015753781,,,,,,,,,,,,,,,,, -166400,0.12567006,0.014833623,,,,,,,,,,,,,,,,, -166423,,,0.995619535446167,0.0137594006955623,0.7878243548892319,0.9869915843009948,0.0507423616945743,0.2937584227051379,43793.0,0.9862112998962402,0.0540365986526012,0.2753630251818127,43793.0,53558.69554495812,80652.39430904388,53558.69554495812,27078.699613571167,10.322208404541016,0.0 -166500,0.12993279,0.017792553,,,,,,,,,,,,,,,,, -166600,0.1475711,0.018228814,,,,,,,,,,,,,,,,, -166700,0.14593172,0.019928826,,,,,,,,,,,,,,,,, -166800,0.1311193,0.016794058,,,,,,,,,,,,,,,,, -166900,0.15383966,0.017833458,,,,,,,,,,,,,,,,, -167000,0.14835185,0.017481863,,,,,,,,,,,,,,,,, -167100,0.12878881,0.016949499,,,,,,,,,,,,,,,,, -167157,,,0.9955827593803406,0.0137697160243988,0.7863487925670267,0.9869915843009948,0.0507423616945743,0.2939032390369472,43793.0,0.9862112998962402,0.0540365986526012,0.2754087925659607,43793.0,53798.87780714035,81004.75309371948,53798.87780714035,27190.793970823288,10.380990743637083,0.0 -167200,0.13350625,0.01669349,,,,,,,,,,,,,,,,, -167300,0.1455653,0.019213025,,,,,,,,,,,,,,,,, -167400,0.14323272,0.016913764,,,,,,,,,,,,,,,,, -167500,0.13570298,0.01725392,,,,,,,,,,,,,,,,, -167600,0.15552881,0.017402548,,,,,,,,,,,,,,,,, -167700,0.15072595,0.02003849,,,,,,,,,,,,,,,,, -167800,0.15168685,0.018067919,,,,,,,,,,,,,,,,, -167900,0.14611071,0.017991263,,,,,,,,,,,,,,,,, -167909,,,0.9955384135246276,0.0138554880395531,0.7785570710333689,0.9869915843009948,0.0507423616945743,0.2936441465812117,43793.0,0.9862112998962402,0.0540365986526012,0.275392875397748,43793.0,54039.08760070801,81353.55399942398,54039.08760070801,27299.30680155754,10.439002513885498,0.0 -168000,0.14359407,0.01800229,,,,,,,,,,,,,,,,, -168100,0.14825971,0.017484583,,,,,,,,,,,,,,,,, -168200,0.13438474,0.016840102,,,,,,,,,,,,,,,,, -168300,0.15670197,0.01750975,,,,,,,,,,,,,,,,, -168400,0.1493573,0.018728217,,,,,,,,,,,,,,,,, -168500,0.12900935,0.016628291,,,,,,,,,,,,,,,,, -168600,0.14608057,0.01835404,,,,,,,,,,,,,,,,, -168650,,,0.99558025598526,0.0138541841879487,0.7801272234892287,0.9869915843009948,0.0507423616945743,0.2937228481334556,43793.0,0.9862112998962402,0.0540365986526012,0.275552159012681,43793.0,54279.04728722572,81701.57476568222,54279.04728722572,27407.28624010086,10.49805474281311,0.0 -168700,0.15168987,0.016732702,,,,,,,,,,,,,,,,, -168800,0.13757806,0.016606929,,,,,,,,,,,,,,,,, -168900,0.14245152,0.018536456,,,,,,,,,,,,,,,,, -169000,0.14304675,0.016275285,,,,,,,,,,,,,,,,, -169100,0.14584841,0.018981408,,,,,,,,,,,,,,,,, -169200,0.15812652,0.018039605,,,,,,,,,,,,,,,,, -169300,0.16450612,0.018040486,,,,,,,,,,,,,,,,, -169400,0.13381809,0.0175418,,,,,,,,,,,,,,,,, -169403,,,0.9955908060073853,0.0137906232848763,0.7917480193232054,0.9869915843009948,0.0507423616945743,0.293798366684791,43793.0,0.9862112998962402,0.0540365986526012,0.2755014735067793,43793.0,54518.96765470505,82046.5347559452,54518.96765470505,27512.24540758133,10.55790948867798,0.0 -169500,0.13819693,0.016149996,,,,,,,,,,,,,,,,, -169600,0.13336636,0.01587301,,,,,,,,,,,,,,,,, -169700,0.14312305,0.018198403,,,,,,,,,,,,,,,,, -169800,0.12991674,0.016247382,,,,,,,,,,,,,,,,, -169900,0.14947471,0.019533703,,,,,,,,,,,,,,,,, -170000,0.13813414,0.017835377,,,,,,,,,,,,,,,,, -170100,0.1550704,0.017549925,,,,,,,,,,,,,,,,, -170162,,,0.9956451058387756,0.0137569643557071,0.7842756858193123,0.9869915843009948,0.0507423616945743,0.2935977855388695,43793.0,0.9862112998962402,0.0540365986526012,0.275371892760308,43793.0,54759.18366193771,82392.74225926399,54759.18366193771,27618.15584230423,10.618480205535889,0.0 -170200,0.1812283,0.019826192,,,,,,,,,,,,,,,,, -170300,0.13427353,0.016918136,,,,,,,,,,,,,,,,, -170400,0.146679,0.018288044,,,,,,,,,,,,,,,,, -170500,0.1331835,0.017105313,,,,,,,,,,,,,,,,, -170600,0.1628993,0.018656617,,,,,,,,,,,,,,,,, -170700,0.14143567,0.016175376,,,,,,,,,,,,,,,,, -170800,0.13525172,0.016886333,,,,,,,,,,,,,,,,, -170900,0.14779782,0.017102538,,,,,,,,,,,,,,,,, -170921,,,0.9955541491508484,0.0138856377452611,0.7790105273746023,0.9869915843009948,0.0507423616945743,0.2936214992287673,43793.0,0.9862112998962402,0.0540366023778915,0.2755011356766365,43793.0,54999.39244008064,82744.88785123825,54999.39244008064,27730.01386666298,10.676766395568848,0.0 -171000,0.14974797,0.018558057,,,,,,,,,,,,,,,,, -171100,0.14969842,0.017813595,,,,,,,,,,,,,,,,, -171200,0.14759049,0.016434226,,,,,,,,,,,,,,,,, -171300,0.13219726,0.01750067,,,,,,,,,,,,,,,,, -171400,0.14054301,0.01751871,,,,,,,,,,,,,,,,, -171500,0.12604232,0.016412633,,,,,,,,,,,,,,,,, -171600,0.14618316,0.017329244,,,,,,,,,,,,,,,,, -171674,,,0.9955615401268004,0.0137510374188423,0.7800829913217441,0.9869915843009948,0.0507423616945743,0.2936228202985142,43793.0,0.9862112998962402,0.0540365986526012,0.2753984549335289,43793.0,55239.52972245216,83085.2924580574,55239.52972245216,27830.2014875412,10.736129522323608,0.0 -171700,0.14660002,0.018672166,,,,,,,,,,,,,,,,, -171800,0.14966199,0.016383357,,,,,,,,,,,,,,,,, -171900,0.1425066,0.017979082,,,,,,,,,,,,,,,,, -172000,0.14651313,0.01905302,,,,,,,,,,,,,,,,, -172100,0.18179709,0.01836619,,,,,,,,,,,,,,,,, -172200,0.14483504,0.017444326,,,,,,,,,,,,,,,,, -172274,,,,,,,,,,,,,,55431.058990478516,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 187a2aa27..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,49 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -887.8651258945465,0.0,40.496176958084106,1,0,40.496176958084106,0.0007088489946909,0.0,11.07231903076172,3003,928.3613443374634,0.0006029167561791,0.0,11.115453720092772,0.0004835649742744,0.0,11.067522048950195,3000 -1376.6985096931458,0.0274982452392578,880.5481684207916,2349,0,880.5481684207916,0.5071175694465637,16.284499320835064,2.903294801712036,3003,2257.3512320518494,0.5111194849014282,21.577162944601955,2.8539931774139404,0.5105206370353699,17.973378801733833,2.85092568397522,3000 -1813.621686458588,0.0516557693481445,1720.5627517700195,4699,0,1720.5627517700195,0.5924118161201477,22.337677996314792,2.12496280670166,3003,3534.390073299408,0.5772063732147217,27.272062123291107,2.269517183303833,0.5889449715614319,23.188244077476853,2.158039093017578,3000 -2242.718279838562,0.0812623500823974,2560.4681203365326,7048,0,2560.4681203365326,0.6206844449043274,23.973413980430863,1.8931884765625,3003,4803.5010669231415,0.6080343127250671,29.0606718918695,1.998271942138672,0.6168801188468933,25.284735116509545,1.93113386631012,3000 -2696.6207807064056,0.1075160503387451,3400.660268306732,9397,0,3400.660268306732,0.6367904543876648,24.995813239162366,1.7636173963546753,3003,6097.70339179039,0.6122574210166931,29.22745206440884,1.960111141204834,0.6298496127128601,26.10114560190848,1.819329023361206,3000 -3170.2137649059296,0.1367018222808838,4240.656886100769,11747,0,4240.656886100769,0.6459822654724121,25.88734468011442,1.6933776140213013,3003,7411.399604558945,0.6196650862693787,30.07848536027601,1.8994940519332888,0.6404136419296265,26.887197005780408,1.751357078552246,3000 -3667.5925085544586,0.1636300086975097,5080.653621196747,14097,0,5080.653621196747,0.6518040895462036,25.83268204250817,1.6483497619628906,3003,8748.880180835724,0.6269037127494812,30.409302554729337,1.8341394662857056,0.643438994884491,26.79804227860407,1.716156244277954,3000 -4155.951548576355,0.1942224502563476,5920.588991165161,16446,0,5920.588991165161,0.6574283838272095,26.619005671904056,1.6131547689437866,3003,10077.283208847046,0.6281946897506714,30.51662314858753,1.814017653465271,0.646265983581543,27.324046560476773,1.6841596364974976,3000 -4618.6915826797485,0.2228555679321289,6760.74355173111,18795,0,6760.74355173111,0.6583115458488464,26.442234509686312,1.6000570058822632,3003,11380.289041280746,0.6745564937591553,33.8394973668327,1.5200400352478027,0.6503701210021973,27.392234079788903,1.6605991125106812,3000 -5142.428127288818,0.2506749629974365,7600.783447980881,21144,0,7600.783447980881,0.6626808643341064,26.793200258777787,1.5769139528274536,3003,12744.170624256134,0.6381985545158386,31.213722577422697,1.7518174648284912,0.6526267528533936,27.370917701953303,1.642736315727234,3000 -5642.817688941956,0.2786710262298584,8440.980343818665,23493,0,8440.980343818665,0.6649003624916077,27.0147249904204,1.561911702156067,3003,14084.867583990095,0.6324992775917053,30.77239768822412,1.795293211936951,0.6544618010520935,27.83978166659492,1.6285449266433716,3000 -6141.316317081451,0.3110058307647705,9281.125701904297,25843,0,9281.125701904297,0.6663180589675903,27.17611606060272,1.546418070793152,3003,15423.621300458908,0.646230161190033,31.48282830995302,1.6848127841949463,0.6560364961624146,27.66872427113727,1.6232653856277466,3000 -6646.066769123077,0.3381881713867187,10121.028459310532,28192,0,10121.028459310532,0.6693858504295349,27.31636621568932,1.5327434539794922,3003,16768.38085126877,0.64000004529953,31.34536810411951,1.739559531211853,0.6591982841491699,28.11320010201369,1.6082539558410645,3000 -7127.2394053936005,0.3667356967926025,10961.262724161148,30543,0,10961.262724161148,0.6699784994125366,27.527019800715944,1.522635579109192,3003,18089.89426422119,0.6375179886817932,30.77511208561493,1.7588555812835691,0.658925473690033,28.296488039179028,1.6004260778427124,3000 -7614.882809877396,0.3996870517730713,11801.242480754852,32893,0,11801.242480754852,0.6708035469055176,27.49746230167418,1.5205928087234497,3003,19417.62895989418,0.6432289481163025,31.19315213816847,1.7125086784362793,0.6582435369491577,28.12177583729921,1.6062827110290527,3000 -8156.013454914093,0.4326503276824951,12641.374759197235,35244,0,12641.374759197235,0.6722328662872314,27.84060668588333,1.5019036531448364,3003,20799.00270557404,0.6422494649887085,31.382447419480027,1.720298171043396,0.6626824140548706,28.378079676535123,1.5820107460021973,3000 -8681.38389492035,0.4682419300079345,13481.296279668808,37593,0,13481.296279668808,0.6723839640617371,27.599065186221257,1.4996920824050903,3003,22164.41160917282,0.6800824999809265,34.355612800032446,1.45755136013031,0.6637487411499023,28.429244621089666,1.5757863521575928,3000 -9297.66802597046,0.5044949054718018,14321.491159915924,39943,0,14321.491159915924,0.6742083430290222,27.961575161697663,1.4953006505966189,3003,23621.00763463974,0.6489623785018921,32.00845543488021,1.68193519115448,0.6647778749465942,28.787674762573683,1.5651044845581057,3000 -9867.36664390564,0.5345263481140137,15161.52368092537,42293,0,15161.52368092537,0.6760095357894897,27.91267454128032,1.48641037940979,3003,25030.84917783737,0.6470581293106079,31.773503395043683,1.6912137269973757,0.6645422577857971,28.337935195464578,1.5649478435516355,3000 -10335.682363271711,0.5661778450012207,16001.64799952507,44644,0,16001.64799952507,0.6754517555236816,27.63704969028884,1.4745254516601562,3003,26339.4004714489,0.6550261974334717,32.423378180989296,1.621602177619934,0.6661665439605713,28.65202752908857,1.5491173267364502,3000 -10946.873561382294,0.5983190536499023,16841.717359304428,46993,0,16841.717359304428,0.6792981624603271,28.03113906907784,1.4608837366104126,3003,27790.774476528168,0.6464254856109619,31.81020358869237,1.6788039207458496,0.666340172290802,28.39240158912281,1.5431653261184692,3000 -11498.062032461166,0.6335971355438232,17681.87114572525,49344,0,17681.87114572525,0.6801580786705017,28.28024055566884,1.4569159746170044,3003,29182.23004412651,0.6485030055046082,31.90822207781761,1.683557629585266,0.6678528189659119,28.77088171329392,1.540746808052063,3000 -12084.83533358574,0.664344072341919,18521.941950798035,51695,0,18521.941950798035,0.6809133887290955,28.51707180407588,1.451818823814392,3003,30609.1839325428,0.6541595458984375,32.13664638188772,1.6322287321090698,0.6684479713439941,28.76373048098764,1.5355753898620603,3000 -12709.605189323423,0.6956663131713867,19361.91049265861,54045,0,19361.91049265861,0.6835047602653503,28.57077916909548,1.4431202411651611,3003,32074.034927845,0.6509293913841248,32.05001778447784,1.6562706232070925,0.6697251200675964,28.898027363811607,1.5251288414001465,3000 -13241.70564031601,0.7274036407470703,20202.092440128326,56396,0,20202.092440128326,0.6842600703239441,28.68166147587714,1.435634970664978,3003,33446.429708480835,0.6806057691574097,34.32327576098257,1.4595482349395752,0.6707666516304016,29.178939066208603,1.515071988105774,3000 -13867.775197029114,0.7592089176177979,21042.018231630325,58746,0,21042.018231630325,0.6848527193069458,28.62383352483516,1.4315576553344729,3003,34912.5352704525,0.653881311416626,32.14687787371736,1.633278727531433,0.6720437407493591,28.967526096442047,1.5138766765594482,3000 -14479.754987239838,0.7943167686462402,21882.15102601052,61097,0,21882.15102601052,0.6872000694274902,28.95892319870933,1.4211808443069458,3003,36364.76279473305,0.6571126580238342,32.149596664961024,1.6194921731948853,0.6735812425613403,29.221768396880584,1.5096439123153689,3000 -15109.183556556702,0.8285095691680908,22722.04600667953,63447,0,22722.04600667953,0.6873278617858887,29.13141158817472,1.4133059978485107,3003,37834.20046401024,0.6630586385726929,32.501644119858895,1.5684752464294434,0.6748087406158447,29.088602365219,1.494695782661438,3000 -15779.91496014595,0.8635752201080322,23562.01352858544,65797,0,23562.01352858544,0.687897264957428,28.904316917236,1.4136604070663452,3003,39345.014755010605,0.6596559286117554,32.72167565803844,1.6069821119308472,0.673159658908844,29.10851777105609,1.5000309944152832,3000 -16354.474965810776,0.905400276184082,24402.17420578003,68147,0,24402.17420578003,0.6913717985153198,29.46735070037606,1.398087501525879,3003,40759.86060547829,0.6588523983955383,32.49809723904921,1.6124247312545776,0.6761106252670288,29.24748990886783,1.4879592657089231,3000 -16943.696088552475,0.9412546157836914,25242.11082100868,70497,0,25242.11082100868,0.6908837556838989,29.18475442807425,1.3939534425735474,3003,42189.13254570961,0.6634678840637207,33.044339503927496,1.5741132497787476,0.6757386922836304,29.433881912093547,1.4837726354599,3000 -17492.493832349777,0.9753479957580566,26082.05291867256,72847,0,26082.05291867256,0.6912323832511902,28.924374363232204,1.3835713863372805,3003,43577.98728346825,0.6604095101356506,32.75089059937514,1.5908526182174685,0.6772513389587402,29.40842404506689,1.4792447090148926,3000 -18035.313774347305,1.0108327865600586,26922.27769780159,75197,0,26922.27769780159,0.6939283013343811,29.33074891716061,1.3758292198181152,3003,44961.1496899128,0.6884849667549133,34.55588996153888,1.4089932441711426,0.6790616512298584,29.55969900752161,1.4695286750793457,3000 -18620.25772929192,1.0464684963226318,27762.29205560684,77546,0,27762.29205560684,0.6951600909233093,29.366417158019804,1.3723912239074707,3003,46386.2266972065,0.6647745966911316,33.22219566753689,1.5595890283584597,0.6787516474723816,29.46783892268668,1.46135675907135,3000 -19239.971554994583,1.083417892456055,28602.22843670845,79896,0,28602.22843670845,0.6968799233436584,29.39433137317041,1.3566237688064575,3003,47845.993671655655,0.6653583645820618,32.828529578925284,1.56242573261261,0.6814670562744141,29.776679592171543,1.453194260597229,3000 -19924.62496495247,1.1249215602874756,29442.20606327057,82246,0,29442.20606327057,0.6970542073249817,29.67077386186465,1.3562180995941162,3003,49370.74531555176,0.6765756011009216,33.29825089454542,1.4844000339508057,0.6817150115966797,29.627096514149848,1.444311022758484,3000 -20485.63865876197,1.1610476970672607,30282.19902396202,84596,0,30282.19902396202,0.6992272734642029,29.647750123973445,1.3421218395233154,3003,50771.86794400215,0.670854926109314,33.14134605170951,1.519789695739746,0.6831161379814148,29.692176572418138,1.4388524293899536,3000 -21037.6956114769,1.203956127166748,31122.1114256382,86946,0,31122.1114256382,0.7010865211486816,30.05570302170628,1.3369659185409546,3003,52163.96068120003,0.6692231893539429,33.56647363891882,1.5366313457489014,0.6852735877037048,30.41250771457849,1.4326162338256836,3000 -21572.29588246345,1.2391314506530762,31962.09663248062,89296,0,31962.09663248062,0.7005636096000671,29.961738819264145,1.3306442499160769,3003,53538.661982774734,0.675879180431366,34.04868441106427,1.4836534261703491,0.6848767995834351,30.374787594004687,1.427587628364563,3000 -22116.24362039566,1.2771565914154053,32802.2980568409,91647,0,32802.2980568409,0.7031317353248596,30.19090552965387,1.3237190246582031,3003,54922.926729917526,0.6721989512443542,33.45237889132655,1.515886306762695,0.6875798106193542,30.52288304117627,1.4213722944259644,3000 -22676.927169799805,1.3175652027130127,33642.41419363022,93998,0,33642.41419363022,0.7028993368148804,29.985501215724845,1.315637707710266,3003,56323.84543085098,0.6907885074615479,35.22884127960909,1.397413969039917,0.6862407326698303,30.26388804807659,1.4174182415008545,3000 -23351.149356365204,1.3545808792114258,34482.45923447609,96348,0,34482.45923447609,0.7055023312568665,30.463050020198107,1.3074795007705688,3003,57838.22999000549,0.680406928062439,33.95429486501357,1.4625781774520874,0.6884601712226868,30.50475710065807,1.409622311592102,3000 -23953.17521595955,1.3932616710662842,35322.54058718681,98698,0,35322.54058718681,0.7058973908424377,30.41957079143748,1.300991177558899,3003,59280.45553016663,0.6774348616600037,34.15186940985161,1.4779555797576904,0.6887453198432922,30.495876301755462,1.4022163152694702,3000 -24525.823969364166,1.4304945468902588,36162.63892054558,101047,0,36162.63892054558,0.7059903740882874,30.442794792692226,1.296967625617981,3003,60693.32236337662,0.6855158805847168,34.851935506476764,1.433314561843872,0.6893280744552612,30.42614686854671,1.3989043235778809,3000 -25190.066334486008,1.471330642700195,37002.79528737068,103397,0,37002.79528737068,0.7083725929260254,30.640346635700297,1.2883847951889038,3003,62197.84350180626,0.6817169785499573,34.91150543285352,1.4496946334838867,0.6903820037841797,30.451516312385404,1.3946393728256226,3000 -25772.742190361023,1.5127499103546145,37842.91463184357,105747,0,37842.91463184357,0.7087560296058655,30.557939741522667,1.285719394683838,3003,63620.76156711578,0.683246910572052,34.73362591682547,1.4504034519195557,0.6902952194213867,30.559631401457903,1.389479160308838,3000 -26343.09537410736,1.5515015125274658,38682.86744952202,108097,0,38682.86744952202,0.7101156115531921,30.65263189003389,1.277177333831787,3003,65031.18496370316,0.6888919472694397,35.10855108546573,1.4076614379882812,0.6916466951370239,30.89753459757305,1.383272647857666,3000 -26985.349088191986,1.5927612781524658,39522.984773635864,110446,0,39522.984773635864,0.7107431292533875,30.877531140595956,1.2718663215637207,3003,66513.68182969093,0.6874507069587708,34.818386155596556,1.4143240451812744,0.6933826208114624,30.886477140805987,1.3791488409042358,3000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index e30270950..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1155 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.871642,11.084388,,,,,,,,,,,,,,,,, -1,,,0.0006029167561791,11.115453720092772,0.0,0.0004835649742744,11.067522048950195,0.0,3000.0,0.0007088489946909,11.07231903076172,0.0,3003.0,40.496176958084106,928.3613443374634,40.496176958084106,887.8651258945465,0.0,0.0 -100,0.17043175,8.234814,,,,,,,,,,,,,,,,, -200,0.36992356,7.476822,,,,,,,,,,,,,,,,, -300,0.59090304,6.920758,,,,,,,,,,,,,,,,, -400,0.42627937,6.3132772,,,,,,,,,,,,,,,,, -500,0.66011643,5.9086723,,,,,,,,,,,,,,,,, -600,0.69594985,5.5755,,,,,,,,,,,,,,,,, -700,0.47828713,5.3073072,,,,,,,,,,,,,,,,, -800,0.4987925,5.05413,,,,,,,,,,,,,,,,, -900,0.5246274,4.776385,,,,,,,,,,,,,,,,, -1000,0.6963559,4.5935545,,,,,,,,,,,,,,,,, -1100,0.5349801,4.30896,,,,,,,,,,,,,,,,, -1200,0.66853803,3.936335,,,,,,,,,,,,,,,,, -1300,0.54093105,3.9689264,,,,,,,,,,,,,,,,, -1400,0.76001686,3.8971117,,,,,,,,,,,,,,,,, -1500,0.6630372,3.6431491,,,,,,,,,,,,,,,,, -1600,0.43420216,3.405537,,,,,,,,,,,,,,,,, -1700,0.61271197,3.44744,,,,,,,,,,,,,,,,, -1800,0.45341545,3.259425,,,,,,,,,,,,,,,,, -1900,0.47587278,3.3005826,,,,,,,,,,,,,,,,, -2000,0.41190648,3.1808712,,,,,,,,,,,,,,,,, -2100,0.36349633,3.0636344,,,,,,,,,,,,,,,,, -2200,0.38977906,3.145315,,,,,,,,,,,,,,,,, -2300,0.4255654,2.9672146,,,,,,,,,,,,,,,,, -2349,,,0.5111194849014282,2.8539931774139404,21.577162944601955,0.5105206370353699,2.85092568397522,17.973378801733833,3000.0,0.5071175694465637,2.903294801712036,16.284499320835064,3003.0,880.5481684207916,2257.3512320518494,880.5481684207916,1376.6985096931458,0.0274982452392578,0.0 -2400,0.34114292,2.923097,,,,,,,,,,,,,,,,, -2500,0.29442298,2.879563,,,,,,,,,,,,,,,,, -2600,0.42155367,2.9021356,,,,,,,,,,,,,,,,, -2700,0.2885047,2.7706783,,,,,,,,,,,,,,,,, -2800,0.33875945,2.758845,,,,,,,,,,,,,,,,, -2900,0.24510935,2.7693524,,,,,,,,,,,,,,,,, -3000,0.26784733,2.7150292,,,,,,,,,,,,,,,,, -3100,0.4033053,2.7644715,,,,,,,,,,,,,,,,, -3200,0.2446864,2.6246727,,,,,,,,,,,,,,,,, -3300,0.25147778,2.5559967,,,,,,,,,,,,,,,,, -3400,0.27644658,2.5973716,,,,,,,,,,,,,,,,, -3500,0.21489501,2.429964,,,,,,,,,,,,,,,,, -3600,0.20848455,2.5409613,,,,,,,,,,,,,,,,, -3700,0.21754219,2.5053046,,,,,,,,,,,,,,,,, -3800,0.19472983,2.464049,,,,,,,,,,,,,,,,, -3900,0.24391489,2.5004356,,,,,,,,,,,,,,,,, -4000,0.22071047,2.5593865,,,,,,,,,,,,,,,,, -4100,0.19936624,2.395362,,,,,,,,,,,,,,,,, -4200,0.19044982,2.374429,,,,,,,,,,,,,,,,, -4300,0.17594533,2.3399336,,,,,,,,,,,,,,,,, -4400,0.19504215,2.3342206,,,,,,,,,,,,,,,,, -4500,0.16040403,2.3053215,,,,,,,,,,,,,,,,, -4600,0.16710809,2.3488908,,,,,,,,,,,,,,,,, -4699,,,0.5772063732147217,2.269517183303833,27.272062123291107,0.5889449715614319,2.158039093017578,23.188244077476853,3000.0,0.5924118161201477,2.12496280670166,22.337677996314792,3003.0,1720.5627517700195,3534.390073299408,1720.5627517700195,1813.621686458588,0.0516557693481445,0.0 -4700,0.1574008,2.2653553,,,,,,,,,,,,,,,,, -4800,0.1544824,2.4066544,,,,,,,,,,,,,,,,, -4900,0.1490069,2.2723217,,,,,,,,,,,,,,,,, -5000,0.15680172,2.2573993,,,,,,,,,,,,,,,,, -5100,0.1855811,2.3523483,,,,,,,,,,,,,,,,, -5200,0.16942193,2.3561618,,,,,,,,,,,,,,,,, -5300,0.17493236,2.2296653,,,,,,,,,,,,,,,,, -5400,0.19017005,2.3207042,,,,,,,,,,,,,,,,, -5500,0.154243,2.3607311,,,,,,,,,,,,,,,,, -5600,0.1501402,2.2730994,,,,,,,,,,,,,,,,, -5700,0.15037481,2.2998192,,,,,,,,,,,,,,,,, -5800,0.1766985,2.1871986,,,,,,,,,,,,,,,,, -5900,0.1450493,2.2096229,,,,,,,,,,,,,,,,, -6000,0.16599983,2.18329,,,,,,,,,,,,,,,,, -6100,0.14180775,2.178986,,,,,,,,,,,,,,,,, -6200,0.16075797,2.216552,,,,,,,,,,,,,,,,, -6300,0.14836878,2.1739364,,,,,,,,,,,,,,,,, -6400,0.18419732,2.2545974,,,,,,,,,,,,,,,,, -6500,0.14592515,2.14982,,,,,,,,,,,,,,,,, -6600,0.18575716,2.2022288,,,,,,,,,,,,,,,,, -6700,0.17038678,2.1613941,,,,,,,,,,,,,,,,, -6800,0.14371522,2.1425018,,,,,,,,,,,,,,,,, -6900,0.14947766,2.1992486,,,,,,,,,,,,,,,,, -7000,0.17418431,2.107533,,,,,,,,,,,,,,,,, -7048,,,0.6080343127250671,1.998271942138672,29.0606718918695,0.6168801188468933,1.93113386631012,25.284735116509545,3000.0,0.6206844449043274,1.8931884765625,23.973413980430863,3003.0,2560.4681203365326,4803.5010669231415,2560.4681203365326,2242.718279838562,0.0812623500823974,0.0 -7100,0.16471362,2.1832242,,,,,,,,,,,,,,,,, -7200,0.18529834,2.1400874,,,,,,,,,,,,,,,,, -7300,0.14453591,2.0178666,,,,,,,,,,,,,,,,, -7400,0.18043728,2.1248314,,,,,,,,,,,,,,,,, -7500,0.2224491,2.1043406,,,,,,,,,,,,,,,,, -7600,0.19439992,2.1579125,,,,,,,,,,,,,,,,, -7700,0.17834683,2.1019952,,,,,,,,,,,,,,,,, -7800,0.15339328,2.1057103,,,,,,,,,,,,,,,,, -7900,0.15744789,2.0686316,,,,,,,,,,,,,,,,, -8000,0.15323801,2.0843518,,,,,,,,,,,,,,,,, -8100,0.18623567,2.1376758,,,,,,,,,,,,,,,,, -8200,0.16376075,2.0317159,,,,,,,,,,,,,,,,, -8300,0.16002505,2.063562,,,,,,,,,,,,,,,,, -8400,0.16527633,2.0829377,,,,,,,,,,,,,,,,, -8500,0.14946903,2.0805402,,,,,,,,,,,,,,,,, -8600,0.14317967,2.025453,,,,,,,,,,,,,,,,, -8700,0.2070115,2.0968826,,,,,,,,,,,,,,,,, -8800,0.17728007,2.0550098,,,,,,,,,,,,,,,,, -8900,0.19367902,2.1178055,,,,,,,,,,,,,,,,, -9000,0.17633471,2.0903406,,,,,,,,,,,,,,,,, -9100,0.16015801,2.0609934,,,,,,,,,,,,,,,,, -9200,0.16255987,2.0383043,,,,,,,,,,,,,,,,, -9300,0.17007746,1.9290421,,,,,,,,,,,,,,,,, -9397,,,0.6122574210166931,1.960111141204834,29.22745206440884,0.6298496127128601,1.819329023361206,26.10114560190848,3000.0,0.6367904543876648,1.7636173963546753,24.995813239162366,3003.0,3400.660268306732,6097.70339179039,3400.660268306732,2696.6207807064056,0.1075160503387451,0.0 -9400,0.19805504,2.0314403,,,,,,,,,,,,,,,,, -9500,0.1796198,2.004889,,,,,,,,,,,,,,,,, -9600,0.16123846,2.0377388,,,,,,,,,,,,,,,,, -9700,0.19447052,2.0284286,,,,,,,,,,,,,,,,, -9800,0.16866726,1.9566278,,,,,,,,,,,,,,,,, -9900,0.20978358,2.0077937,,,,,,,,,,,,,,,,, -10000,0.17582199,1.9939574,,,,,,,,,,,,,,,,, -10100,0.21804006,2.082073,,,,,,,,,,,,,,,,, -10200,0.18269643,1.9818193,,,,,,,,,,,,,,,,, -10300,0.16426118,1.9609303,,,,,,,,,,,,,,,,, -10400,0.24894013,1.9358225,,,,,,,,,,,,,,,,, -10500,0.20749962,2.0571032,,,,,,,,,,,,,,,,, -10600,0.18555103,1.99227,,,,,,,,,,,,,,,,, -10700,0.15982747,2.1170635,,,,,,,,,,,,,,,,, -10800,0.20668785,2.0010705,,,,,,,,,,,,,,,,, -10900,0.18552864,2.0240726,,,,,,,,,,,,,,,,, -11000,0.19181034,2.0429552,,,,,,,,,,,,,,,,, -11100,0.17766528,1.91989,,,,,,,,,,,,,,,,, -11200,0.16145082,1.9395725,,,,,,,,,,,,,,,,, -11300,0.15589102,1.9374075,,,,,,,,,,,,,,,,, -11400,0.17148344,2.0195997,,,,,,,,,,,,,,,,, -11500,0.17373477,1.9553808,,,,,,,,,,,,,,,,, -11600,0.18166414,1.8544632,,,,,,,,,,,,,,,,, -11700,0.16273168,1.9939363,,,,,,,,,,,,,,,,, -11747,,,0.6196650862693787,1.8994940519332888,30.07848536027601,0.6404136419296265,1.751357078552246,26.887197005780408,3000.0,0.6459822654724121,1.6933776140213013,25.88734468011442,3003.0,4240.656886100769,7411.399604558945,4240.656886100769,3170.2137649059296,0.1367018222808838,0.0 -11800,0.14804392,1.8864235,,,,,,,,,,,,,,,,, -11900,0.16224575,1.9991521,,,,,,,,,,,,,,,,, -12000,0.17561647,1.9593341,,,,,,,,,,,,,,,,, -12100,0.21891668,1.9469435,,,,,,,,,,,,,,,,, -12200,0.16385484,1.9101804,,,,,,,,,,,,,,,,, -12300,0.16918206,2.0044174,,,,,,,,,,,,,,,,, -12400,0.28997704,1.8688949,,,,,,,,,,,,,,,,, -12500,0.22756046,2.0312667,,,,,,,,,,,,,,,,, -12600,0.18872827,2.0477858,,,,,,,,,,,,,,,,, -12700,0.19833611,1.9786828,,,,,,,,,,,,,,,,, -12800,0.18065631,1.9824932,,,,,,,,,,,,,,,,, -12900,0.17373067,1.8571308,,,,,,,,,,,,,,,,, -13000,0.2393536,2.053254,,,,,,,,,,,,,,,,, -13100,0.17989156,1.8967739,,,,,,,,,,,,,,,,, -13200,0.23015945,1.9956299,,,,,,,,,,,,,,,,, -13300,0.18974286,1.8898191,,,,,,,,,,,,,,,,, -13400,0.24650605,1.9106059,,,,,,,,,,,,,,,,, -13500,0.22891276,1.9952191,,,,,,,,,,,,,,,,, -13600,0.17367023,1.9497668,,,,,,,,,,,,,,,,, -13700,0.16375002,1.9618224,,,,,,,,,,,,,,,,, -13800,0.21437289,1.9830818,,,,,,,,,,,,,,,,, -13900,0.24465623,1.9736588,,,,,,,,,,,,,,,,, -14000,0.2941661,1.8696127,,,,,,,,,,,,,,,,, -14097,,,0.6269037127494812,1.8341394662857056,30.409302554729337,0.643438994884491,1.716156244277954,26.79804227860407,3000.0,0.6518040895462036,1.6483497619628906,25.83268204250817,3003.0,5080.653621196747,8748.880180835724,5080.653621196747,3667.5925085544586,0.1636300086975097,0.0 -14100,0.16061221,1.9467701,,,,,,,,,,,,,,,,, -14200,0.16164114,1.9380453,,,,,,,,,,,,,,,,, -14300,0.19776852,1.9748286,,,,,,,,,,,,,,,,, -14400,0.1731339,1.9278843,,,,,,,,,,,,,,,,, -14500,0.2145882,1.8904157,,,,,,,,,,,,,,,,, -14600,0.20369393,1.9231743,,,,,,,,,,,,,,,,, -14700,0.2933394,1.9284639,,,,,,,,,,,,,,,,, -14800,0.23200698,1.9363943,,,,,,,,,,,,,,,,, -14900,0.17827956,1.9248368,,,,,,,,,,,,,,,,, -15000,0.19668183,1.9475136,,,,,,,,,,,,,,,,, -15100,0.24704842,1.8959073,,,,,,,,,,,,,,,,, -15200,0.21007143,1.9370645,,,,,,,,,,,,,,,,, -15300,0.19607978,1.9127767,,,,,,,,,,,,,,,,, -15400,0.17946239,1.8989267,,,,,,,,,,,,,,,,, -15500,0.19695051,1.9603447,,,,,,,,,,,,,,,,, -15600,0.18641585,1.905961,,,,,,,,,,,,,,,,, -15700,0.21307038,1.8508345,,,,,,,,,,,,,,,,, -15800,0.23641469,1.8186777,,,,,,,,,,,,,,,,, -15900,0.17955367,1.964851,,,,,,,,,,,,,,,,, -16000,0.20558913,1.8637787,,,,,,,,,,,,,,,,, -16100,0.21326557,1.9061116,,,,,,,,,,,,,,,,, -16200,0.19898412,1.857801,,,,,,,,,,,,,,,,, -16300,0.18427797,1.8984959,,,,,,,,,,,,,,,,, -16400,0.2365083,1.8446268,,,,,,,,,,,,,,,,, -16446,,,0.6281946897506714,1.814017653465271,30.51662314858753,0.646265983581543,1.6841596364974976,27.324046560476773,3000.0,0.6574283838272095,1.6131547689437866,26.619005671904056,3003.0,5920.588991165161,10077.283208847046,5920.588991165161,4155.951548576355,0.1942224502563476,0.0 -16500,0.22793862,1.8651431,,,,,,,,,,,,,,,,, -16600,0.18043678,1.9096661,,,,,,,,,,,,,,,,, -16700,0.24433349,1.8775408,,,,,,,,,,,,,,,,, -16800,0.17758824,1.8646901,,,,,,,,,,,,,,,,, -16900,0.19969912,1.9304812,,,,,,,,,,,,,,,,, -17000,0.18905665,2.0133214,,,,,,,,,,,,,,,,, -17100,0.21855895,1.9611113,,,,,,,,,,,,,,,,, -17200,0.17772347,1.7851272,,,,,,,,,,,,,,,,, -17300,0.16327989,1.896132,,,,,,,,,,,,,,,,, -17400,0.27873287,1.8965989,,,,,,,,,,,,,,,,, -17500,0.18821302,1.8607323,,,,,,,,,,,,,,,,, -17600,0.17571814,1.8945465,,,,,,,,,,,,,,,,, -17700,0.17437238,1.880693,,,,,,,,,,,,,,,,, -17800,0.19506308,1.8487891,,,,,,,,,,,,,,,,, -17900,0.17837943,1.8251153,,,,,,,,,,,,,,,,, -18000,0.17717178,1.8641386,,,,,,,,,,,,,,,,, -18100,0.22376432,1.8965006,,,,,,,,,,,,,,,,, -18200,0.18410987,1.9293258,,,,,,,,,,,,,,,,, -18300,0.16663948,1.8761519,,,,,,,,,,,,,,,,, -18400,0.20740952,1.8601738,,,,,,,,,,,,,,,,, -18500,0.22039571,1.8905178,,,,,,,,,,,,,,,,, -18600,0.8634313,1.8431035,,,,,,,,,,,,,,,,, -18700,0.19499797,1.7778788,,,,,,,,,,,,,,,,, -18795,,,0.6745564937591553,1.5200400352478027,33.8394973668327,0.6503701210021973,1.6605991125106812,27.392234079788903,3000.0,0.6583115458488464,1.6000570058822632,26.442234509686312,3003.0,6760.74355173111,11380.289041280746,6760.74355173111,4618.6915826797485,0.2228555679321289,0.0 -18800,0.18092908,1.8733636,,,,,,,,,,,,,,,,, -18900,0.19170481,1.8333548,,,,,,,,,,,,,,,,, -19000,0.18074648,1.8092085,,,,,,,,,,,,,,,,, -19100,0.22174525,1.9010105,,,,,,,,,,,,,,,,, -19200,0.19994348,1.8035437,,,,,,,,,,,,,,,,, -19300,0.19372351,1.8859258,,,,,,,,,,,,,,,,, -19400,0.21309645,1.877094,,,,,,,,,,,,,,,,, -19500,0.18402116,1.8163991,,,,,,,,,,,,,,,,, -19600,0.17038201,1.909972,,,,,,,,,,,,,,,,, -19700,0.18946587,1.8987273,,,,,,,,,,,,,,,,, -19800,0.17971128,1.8839554,,,,,,,,,,,,,,,,, -19900,0.21105888,1.91939,,,,,,,,,,,,,,,,, -20000,0.23456205,1.855477,,,,,,,,,,,,,,,,, -20100,0.17622703,1.8500968,,,,,,,,,,,,,,,,, -20200,0.23322222,1.8654249,,,,,,,,,,,,,,,,, -20300,0.16595611,1.8830208,,,,,,,,,,,,,,,,, -20400,0.17252108,1.8775045,,,,,,,,,,,,,,,,, -20500,0.18745112,1.8672751,,,,,,,,,,,,,,,,, -20600,0.19532885,1.8665701,,,,,,,,,,,,,,,,, -20700,0.17496623,1.8928261,,,,,,,,,,,,,,,,, -20800,0.18868579,1.8760241,,,,,,,,,,,,,,,,, -20900,0.17598428,1.7890092,,,,,,,,,,,,,,,,, -21000,0.25297493,1.8724867,,,,,,,,,,,,,,,,, -21100,0.26180995,1.8684858,,,,,,,,,,,,,,,,, -21144,,,0.6381985545158386,1.7518174648284912,31.213722577422697,0.6526267528533936,1.642736315727234,27.370917701953303,3000.0,0.6626808643341064,1.5769139528274536,26.793200258777787,3003.0,7600.783447980881,12744.170624256134,7600.783447980881,5142.428127288818,0.2506749629974365,0.0 -21200,0.19518556,1.8752849,,,,,,,,,,,,,,,,, -21300,0.35341278,1.8890034,,,,,,,,,,,,,,,,, -21400,0.19325672,1.8605709,,,,,,,,,,,,,,,,, -21500,0.22952062,1.8472952,,,,,,,,,,,,,,,,, -21600,0.20569742,1.898761,,,,,,,,,,,,,,,,, -21700,0.24703427,1.8441595,,,,,,,,,,,,,,,,, -21800,0.18869491,1.913207,,,,,,,,,,,,,,,,, -21900,0.19530618,1.9078507,,,,,,,,,,,,,,,,, -22000,0.21687616,1.8936087,,,,,,,,,,,,,,,,, -22100,0.19803159,1.9019606,,,,,,,,,,,,,,,,, -22200,0.19381677,1.7918698,,,,,,,,,,,,,,,,, -22300,0.18415864,1.793036,,,,,,,,,,,,,,,,, -22400,0.20650344,1.9374195,,,,,,,,,,,,,,,,, -22500,0.18615896,1.875289,,,,,,,,,,,,,,,,, -22600,0.21567391,1.7764194,,,,,,,,,,,,,,,,, -22700,0.21448417,1.7885065,,,,,,,,,,,,,,,,, -22800,0.20429806,1.8401408,,,,,,,,,,,,,,,,, -22900,0.17486618,1.7868456,,,,,,,,,,,,,,,,, -23000,0.23925786,1.8838153,,,,,,,,,,,,,,,,, -23100,0.19379918,1.7641748,,,,,,,,,,,,,,,,, -23200,0.1864222,1.8831959,,,,,,,,,,,,,,,,, -23300,0.1712317,1.8389285,,,,,,,,,,,,,,,,, -23400,0.2833641,1.8143436,,,,,,,,,,,,,,,,, -23493,,,0.6324992775917053,1.795293211936951,30.77239768822412,0.6544618010520935,1.6285449266433716,27.83978166659492,3000.0,0.6649003624916077,1.561911702156067,27.0147249904204,3003.0,8440.980343818665,14084.867583990095,8440.980343818665,5642.817688941956,0.2786710262298584,0.0 -23500,0.32065037,1.8147852,,,,,,,,,,,,,,,,, -23600,0.1996843,1.8880112,,,,,,,,,,,,,,,,, -23700,0.18486342,1.8068621,,,,,,,,,,,,,,,,, -23800,0.196263,1.8497877,,,,,,,,,,,,,,,,, -23900,0.18664075,1.7845315,,,,,,,,,,,,,,,,, -24000,0.21334301,1.869414,,,,,,,,,,,,,,,,, -24100,0.2253014,1.9206322,,,,,,,,,,,,,,,,, -24200,0.1787298,1.7844453,,,,,,,,,,,,,,,,, -24300,0.25526553,1.8751997,,,,,,,,,,,,,,,,, -24400,0.2861422,1.8053077,,,,,,,,,,,,,,,,, -24500,0.21903868,1.7719872,,,,,,,,,,,,,,,,, -24600,0.18808028,1.9121141,,,,,,,,,,,,,,,,, -24700,0.16993468,1.8029485,,,,,,,,,,,,,,,,, -24800,0.18460917,1.8123521,,,,,,,,,,,,,,,,, -24900,0.37847558,1.866416,,,,,,,,,,,,,,,,, -25000,0.24714367,1.8562104,,,,,,,,,,,,,,,,, -25100,0.35090867,1.8630137,,,,,,,,,,,,,,,,, -25200,0.19155256,1.8667309,,,,,,,,,,,,,,,,, -25300,0.21586011,1.8530147,,,,,,,,,,,,,,,,, -25400,0.19552946,1.8193665,,,,,,,,,,,,,,,,, -25500,0.19385265,1.7996604,,,,,,,,,,,,,,,,, -25600,0.2165846,1.8664184,,,,,,,,,,,,,,,,, -25700,0.18862072,1.7908603,,,,,,,,,,,,,,,,, -25800,0.21868452,1.8554648,,,,,,,,,,,,,,,,, -25843,,,0.646230161190033,1.6848127841949463,31.48282830995302,0.6560364961624146,1.6232653856277466,27.66872427113727,3000.0,0.6663180589675903,1.546418070793152,27.17611606060272,3003.0,9281.125701904297,15423.621300458908,9281.125701904297,6141.316317081451,0.3110058307647705,0.0 -25900,0.21927187,1.8087225,,,,,,,,,,,,,,,,, -26000,0.20411243,1.7902948,,,,,,,,,,,,,,,,, -26100,0.19112012,1.8643266,,,,,,,,,,,,,,,,, -26200,0.209917,1.7963704,,,,,,,,,,,,,,,,, -26300,0.22073938,1.7964611,,,,,,,,,,,,,,,,, -26400,0.20511174,1.835734,,,,,,,,,,,,,,,,, -26500,0.1831175,1.8569444,,,,,,,,,,,,,,,,, -26600,0.20220871,1.843692,,,,,,,,,,,,,,,,, -26700,0.20417885,1.7995291,,,,,,,,,,,,,,,,, -26800,0.20056754,1.9437418,,,,,,,,,,,,,,,,, -26900,0.19639558,1.8215982,,,,,,,,,,,,,,,,, -27000,0.24158604,1.8541263,,,,,,,,,,,,,,,,, -27100,0.2310209,1.8020047,,,,,,,,,,,,,,,,, -27200,0.2109831,1.8454835,,,,,,,,,,,,,,,,, -27300,0.2098387,1.880642,,,,,,,,,,,,,,,,, -27400,0.21622114,1.7691938,,,,,,,,,,,,,,,,, -27500,0.1912309,1.845862,,,,,,,,,,,,,,,,, -27600,0.24413249,1.7868632,,,,,,,,,,,,,,,,, -27700,0.21281895,1.7855433,,,,,,,,,,,,,,,,, -27800,0.24883334,1.8994234,,,,,,,,,,,,,,,,, -27900,0.19947937,1.8359023,,,,,,,,,,,,,,,,, -28000,0.21672021,1.8705921,,,,,,,,,,,,,,,,, -28100,0.19016846,1.8305862,,,,,,,,,,,,,,,,, -28192,,,0.64000004529953,1.739559531211853,31.34536810411951,0.6591982841491699,1.6082539558410645,28.11320010201369,3000.0,0.6693858504295349,1.5327434539794922,27.31636621568932,3003.0,10121.028459310532,16768.38085126877,10121.028459310532,6646.066769123077,0.3381881713867187,0.0 -28200,0.30872306,1.8304553,,,,,,,,,,,,,,,,, -28300,0.20726205,1.7599629,,,,,,,,,,,,,,,,, -28400,0.18200308,1.7501822,,,,,,,,,,,,,,,,, -28500,0.1959329,1.8209511,,,,,,,,,,,,,,,,, -28600,0.1880383,1.851028,,,,,,,,,,,,,,,,, -28700,0.2319467,1.8627472,,,,,,,,,,,,,,,,, -28800,0.1987042,1.8651496,,,,,,,,,,,,,,,,, -28900,0.18153876,1.844507,,,,,,,,,,,,,,,,, -29000,0.18073025,1.8239346,,,,,,,,,,,,,,,,, -29100,0.20055504,1.8231047,,,,,,,,,,,,,,,,, -29200,0.18464105,1.8652692,,,,,,,,,,,,,,,,, -29300,0.25004926,1.884757,,,,,,,,,,,,,,,,, -29400,0.1972502,1.8200214,,,,,,,,,,,,,,,,, -29500,0.22600374,1.8508208,,,,,,,,,,,,,,,,, -29600,0.19102333,1.7481273,,,,,,,,,,,,,,,,, -29700,0.17898919,1.7824662,,,,,,,,,,,,,,,,, -29800,0.22489618,1.8612732,,,,,,,,,,,,,,,,, -29900,0.24982025,1.7258847,,,,,,,,,,,,,,,,, -30000,0.1990965,1.7637115,,,,,,,,,,,,,,,,, -30100,0.25472018,1.7697307,,,,,,,,,,,,,,,,, -30200,0.24246606,1.8841416,,,,,,,,,,,,,,,,, -30300,0.18051088,1.747889,,,,,,,,,,,,,,,,, -30400,0.23617314,1.8113089,,,,,,,,,,,,,,,,, -30500,0.18666802,1.7307687,,,,,,,,,,,,,,,,, -30543,,,0.6375179886817932,1.7588555812835691,30.77511208561493,0.658925473690033,1.6004260778427124,28.296488039179028,3000.0,0.6699784994125366,1.522635579109192,27.527019800715944,3003.0,10961.262724161148,18089.89426422119,10961.262724161148,7127.2394053936005,0.3667356967926025,0.0 -30600,0.17645188,1.8050807,,,,,,,,,,,,,,,,, -30700,0.18076386,1.7383312,,,,,,,,,,,,,,,,, -30800,0.19392994,1.7362677,,,,,,,,,,,,,,,,, -30900,0.19378491,1.7593107,,,,,,,,,,,,,,,,, -31000,0.3896599,1.7946866,,,,,,,,,,,,,,,,, -31100,0.23149078,1.8065766,,,,,,,,,,,,,,,,, -31200,0.19218604,1.7982775,,,,,,,,,,,,,,,,, -31300,0.18369444,1.8922938,,,,,,,,,,,,,,,,, -31400,0.20953335,1.8639253,,,,,,,,,,,,,,,,, -31500,0.19194052,1.8387742,,,,,,,,,,,,,,,,, -31600,0.18776739,1.7722118,,,,,,,,,,,,,,,,, -31700,0.18688278,1.8059918,,,,,,,,,,,,,,,,, -31800,0.17560725,1.7904075,,,,,,,,,,,,,,,,, -31900,0.19669382,1.8684136,,,,,,,,,,,,,,,,, -32000,0.20572238,1.7917974,,,,,,,,,,,,,,,,, -32100,0.2699829,1.7560654,,,,,,,,,,,,,,,,, -32200,0.22350028,1.8727118,,,,,,,,,,,,,,,,, -32300,0.19782375,1.7692304,,,,,,,,,,,,,,,,, -32400,0.19713071,1.8766319,,,,,,,,,,,,,,,,, -32500,0.25476342,1.71896,,,,,,,,,,,,,,,,, -32600,0.20460159,1.7384453,,,,,,,,,,,,,,,,, -32700,0.25267518,1.7547675,,,,,,,,,,,,,,,,, -32800,0.25694573,1.8915452,,,,,,,,,,,,,,,,, -32893,,,0.6432289481163025,1.7125086784362793,31.19315213816847,0.6582435369491577,1.6062827110290527,28.12177583729921,3000.0,0.6708035469055176,1.5205928087234497,27.49746230167418,3003.0,11801.242480754852,19417.62895989418,11801.242480754852,7614.882809877396,0.3996870517730713,0.0 -32900,1.0959768,1.8259829,,,,,,,,,,,,,,,,, -33000,0.21821828,1.8719854,,,,,,,,,,,,,,,,, -33100,0.19309856,1.7196147,,,,,,,,,,,,,,,,, -33200,0.23363854,1.7555313,,,,,,,,,,,,,,,,, -33300,0.22384949,1.7190781,,,,,,,,,,,,,,,,, -33400,0.20637149,1.8668593,,,,,,,,,,,,,,,,, -33500,0.1934224,1.8362079,,,,,,,,,,,,,,,,, -33600,0.20765445,1.7195132,,,,,,,,,,,,,,,,, -33700,0.20447437,1.7953801,,,,,,,,,,,,,,,,, -33800,0.23213445,1.8261939,,,,,,,,,,,,,,,,, -33900,0.20638077,1.7673211,,,,,,,,,,,,,,,,, -34000,0.19579205,1.8207611,,,,,,,,,,,,,,,,, -34100,0.21564227,1.801506,,,,,,,,,,,,,,,,, -34200,0.24815767,1.8226223,,,,,,,,,,,,,,,,, -34300,0.20946787,1.7118742,,,,,,,,,,,,,,,,, -34400,0.19300427,1.7963492,,,,,,,,,,,,,,,,, -34500,0.23396833,1.7827938,,,,,,,,,,,,,,,,, -34600,0.23798849,1.7426938,,,,,,,,,,,,,,,,, -34700,0.18525304,1.712643,,,,,,,,,,,,,,,,, -34800,0.27210602,1.7356421,,,,,,,,,,,,,,,,, -34900,0.21218027,1.7203355,,,,,,,,,,,,,,,,, -35000,0.20827928,1.82753,,,,,,,,,,,,,,,,, -35100,0.19495334,1.8066863,,,,,,,,,,,,,,,,, -35200,0.20083472,1.700144,,,,,,,,,,,,,,,,, -35244,,,0.6422494649887085,1.720298171043396,31.382447419480027,0.6626824140548706,1.5820107460021973,28.378079676535123,3000.0,0.6722328662872314,1.5019036531448364,27.84060668588333,3003.0,12641.374759197235,20799.00270557404,12641.374759197235,8156.013454914093,0.4326503276824951,0.0 -35300,0.20736316,1.7923628,,,,,,,,,,,,,,,,, -35400,0.18286592,1.8062124,,,,,,,,,,,,,,,,, -35500,0.20202364,1.8112713,,,,,,,,,,,,,,,,, -35600,0.23043399,1.674877,,,,,,,,,,,,,,,,, -35700,0.26873994,1.7713199,,,,,,,,,,,,,,,,, -35800,0.19249974,1.7398798,,,,,,,,,,,,,,,,, -35900,0.21185407,1.7756001,,,,,,,,,,,,,,,,, -36000,0.29716405,1.8511918,,,,,,,,,,,,,,,,, -36100,0.20154697,1.8228054,,,,,,,,,,,,,,,,, -36200,0.18474454,1.761267,,,,,,,,,,,,,,,,, -36300,0.19127934,1.7401928,,,,,,,,,,,,,,,,, -36400,0.19595516,1.7452857,,,,,,,,,,,,,,,,, -36500,0.24531716,1.7946209,,,,,,,,,,,,,,,,, -36600,0.21678364,1.7599568,,,,,,,,,,,,,,,,, -36700,0.20625848,1.7790124,,,,,,,,,,,,,,,,, -36800,0.20415455,1.7904346,,,,,,,,,,,,,,,,, -36900,0.20233937,1.7973894,,,,,,,,,,,,,,,,, -37000,0.2112176,1.8114429,,,,,,,,,,,,,,,,, -37100,0.23131776,1.7629801,,,,,,,,,,,,,,,,, -37200,0.19252598,1.8134755,,,,,,,,,,,,,,,,, -37300,0.2589074,1.8987377,,,,,,,,,,,,,,,,, -37400,0.22814125,1.8456553,,,,,,,,,,,,,,,,, -37500,0.2370876,1.8045233,,,,,,,,,,,,,,,,, -37593,,,0.6800824999809265,1.45755136013031,34.355612800032446,0.6637487411499023,1.5757863521575928,28.429244621089666,3000.0,0.6723839640617371,1.4996920824050903,27.599065186221257,3003.0,13481.296279668808,22164.41160917282,13481.296279668808,8681.38389492035,0.4682419300079345,0.0 -37600,0.18342763,1.7775822,,,,,,,,,,,,,,,,, -37700,0.19928065,1.7710922,,,,,,,,,,,,,,,,, -37800,0.1952005,1.77228,,,,,,,,,,,,,,,,, -37900,0.20257582,1.8122057,,,,,,,,,,,,,,,,, -38000,0.19527216,1.7224927,,,,,,,,,,,,,,,,, -38100,0.19645159,1.7314491,,,,,,,,,,,,,,,,, -38200,0.18897632,1.7897182,,,,,,,,,,,,,,,,, -38300,0.18036765,1.7815576,,,,,,,,,,,,,,,,, -38400,0.20491871,1.8232697,,,,,,,,,,,,,,,,, -38500,0.22273672,1.7395813,,,,,,,,,,,,,,,,, -38600,0.20985743,1.7847328,,,,,,,,,,,,,,,,, -38700,0.22615835,1.7219158,,,,,,,,,,,,,,,,, -38800,0.19372138,1.7195956,,,,,,,,,,,,,,,,, -38900,0.21353501,1.7651787,,,,,,,,,,,,,,,,, -39000,0.18484649,1.7849765,,,,,,,,,,,,,,,,, -39100,0.19801906,1.7702899,,,,,,,,,,,,,,,,, -39200,0.22125195,1.7878985,,,,,,,,,,,,,,,,, -39300,0.22155891,1.8224478,,,,,,,,,,,,,,,,, -39400,0.38475806,1.8305535,,,,,,,,,,,,,,,,, -39500,0.2887433,1.7863791,,,,,,,,,,,,,,,,, -39600,0.19410186,1.754752,,,,,,,,,,,,,,,,, -39700,0.20769468,1.7544132,,,,,,,,,,,,,,,,, -39800,0.19332568,1.8112266,,,,,,,,,,,,,,,,, -39900,0.19383602,1.6701169,,,,,,,,,,,,,,,,, -39943,,,0.6489623785018921,1.68193519115448,32.00845543488021,0.6647778749465942,1.5651044845581057,28.787674762573683,3000.0,0.6742083430290222,1.4953006505966189,27.961575161697663,3003.0,14321.491159915924,23621.00763463974,14321.491159915924,9297.66802597046,0.5044949054718018,0.0 -40000,0.18964688,1.6851293,,,,,,,,,,,,,,,,, -40100,0.19530348,1.7372354,,,,,,,,,,,,,,,,, -40200,0.21488799,1.8857651,,,,,,,,,,,,,,,,, -40300,0.23084661,1.8405616,,,,,,,,,,,,,,,,, -40400,0.18849784,1.7481437,,,,,,,,,,,,,,,,, -40500,0.19099745,1.8109615,,,,,,,,,,,,,,,,, -40600,0.258493,1.6291901,,,,,,,,,,,,,,,,, -40700,0.19818525,1.7818362,,,,,,,,,,,,,,,,, -40800,0.17887479,1.757433,,,,,,,,,,,,,,,,, -40900,0.19856721,1.783554,,,,,,,,,,,,,,,,, -41000,0.19438756,1.7622627,,,,,,,,,,,,,,,,, -41100,0.22889894,1.8268415,,,,,,,,,,,,,,,,, -41200,0.22723414,1.7365353,,,,,,,,,,,,,,,,, -41300,0.18921989,1.7673492,,,,,,,,,,,,,,,,, -41400,0.1849589,1.8038626,,,,,,,,,,,,,,,,, -41500,0.18603322,1.6908392,,,,,,,,,,,,,,,,, -41600,0.18979274,1.8120648,,,,,,,,,,,,,,,,, -41700,0.23123527,1.769262,,,,,,,,,,,,,,,,, -41800,0.19760734,1.7506146,,,,,,,,,,,,,,,,, -41900,0.19501947,1.7578782,,,,,,,,,,,,,,,,, -42000,0.25702575,1.7543584,,,,,,,,,,,,,,,,, -42100,0.2285223,1.8035647,,,,,,,,,,,,,,,,, -42200,0.25237226,1.8072114,,,,,,,,,,,,,,,,, -42293,,,0.6470581293106079,1.6912137269973757,31.773503395043683,0.6645422577857971,1.5649478435516355,28.337935195464578,3000.0,0.6760095357894897,1.48641037940979,27.91267454128032,3003.0,15161.52368092537,25030.84917783737,15161.52368092537,9867.36664390564,0.5345263481140137,0.0 -42300,0.20416719,1.7159024,,,,,,,,,,,,,,,,, -42400,0.17784747,1.8147246,,,,,,,,,,,,,,,,, -42500,0.20302187,1.7479256,,,,,,,,,,,,,,,,, -42600,0.2002014,1.7101939,,,,,,,,,,,,,,,,, -42700,0.18950826,1.7072448,,,,,,,,,,,,,,,,, -42800,0.20612372,1.7414016,,,,,,,,,,,,,,,,, -42900,0.19406481,1.843703,,,,,,,,,,,,,,,,, -43000,0.21461773,1.8098298,,,,,,,,,,,,,,,,, -43100,0.21069859,1.7556267,,,,,,,,,,,,,,,,, -43200,0.19332784,1.7057155,,,,,,,,,,,,,,,,, -43300,0.19430807,1.7399164,,,,,,,,,,,,,,,,, -43400,0.20086995,1.789177,,,,,,,,,,,,,,,,, -43500,0.18898183,1.7199785,,,,,,,,,,,,,,,,, -43600,0.20153259,1.7237139,,,,,,,,,,,,,,,,, -43700,0.20635888,1.7607363,,,,,,,,,,,,,,,,, -43800,0.19168228,1.6992642,,,,,,,,,,,,,,,,, -43900,0.27869618,1.7772948,,,,,,,,,,,,,,,,, -44000,0.20306538,1.770407,,,,,,,,,,,,,,,,, -44100,0.2224049,1.7640551,,,,,,,,,,,,,,,,, -44200,0.17752072,1.7274796,,,,,,,,,,,,,,,,, -44300,0.18593036,1.8084795,,,,,,,,,,,,,,,,, -44400,0.21118566,1.7179044,,,,,,,,,,,,,,,,, -44500,0.19865024,1.8607312,,,,,,,,,,,,,,,,, -44600,0.19625229,1.7096386,,,,,,,,,,,,,,,,, -44644,,,0.6550261974334717,1.621602177619934,32.423378180989296,0.6661665439605713,1.5491173267364502,28.65202752908857,3000.0,0.6754517555236816,1.4745254516601562,27.63704969028884,3003.0,16001.64799952507,26339.4004714489,16001.64799952507,10335.682363271711,0.5661778450012207,0.0 -44700,0.22143874,1.7343836,,,,,,,,,,,,,,,,, -44800,0.199167,1.7020406,,,,,,,,,,,,,,,,, -44900,0.28115997,1.8279985,,,,,,,,,,,,,,,,, -45000,0.18711054,1.78601,,,,,,,,,,,,,,,,, -45100,0.21876068,1.7645838,,,,,,,,,,,,,,,,, -45200,0.20474413,1.732986,,,,,,,,,,,,,,,,, -45300,0.23380123,1.8313365,,,,,,,,,,,,,,,,, -45400,0.21866438,1.7864201,,,,,,,,,,,,,,,,, -45500,0.20230272,1.8183466,,,,,,,,,,,,,,,,, -45600,0.4083509,1.7562597,,,,,,,,,,,,,,,,, -45700,0.20511137,1.7189068,,,,,,,,,,,,,,,,, -45800,0.20024666,1.6868515,,,,,,,,,,,,,,,,, -45900,0.22148539,1.841121,,,,,,,,,,,,,,,,, -46000,0.21211354,1.8260099,,,,,,,,,,,,,,,,, -46100,0.21681103,1.7221928,,,,,,,,,,,,,,,,, -46200,0.20693217,1.7836457,,,,,,,,,,,,,,,,, -46300,0.21106614,1.7622046,,,,,,,,,,,,,,,,, -46400,0.24344285,1.7714944,,,,,,,,,,,,,,,,, -46500,0.19466908,1.7943848,,,,,,,,,,,,,,,,, -46600,0.24222524,1.7690223,,,,,,,,,,,,,,,,, -46700,0.2020565,1.8057235,,,,,,,,,,,,,,,,, -46800,0.19841182,1.728515,,,,,,,,,,,,,,,,, -46900,0.19318794,1.7317784,,,,,,,,,,,,,,,,, -46993,,,0.6464254856109619,1.6788039207458496,31.81020358869237,0.666340172290802,1.5431653261184692,28.39240158912281,3000.0,0.6792981624603271,1.4608837366104126,28.03113906907784,3003.0,16841.717359304428,27790.774476528168,16841.717359304428,10946.873561382294,0.5983190536499023,0.0 -47000,0.20174798,1.6922014,,,,,,,,,,,,,,,,, -47100,0.18627135,1.7106249,,,,,,,,,,,,,,,,, -47200,0.21352932,1.7151979,,,,,,,,,,,,,,,,, -47300,0.17962019,1.6598686,,,,,,,,,,,,,,,,, -47400,0.20206329,1.7897426,,,,,,,,,,,,,,,,, -47500,0.18565442,1.661266,,,,,,,,,,,,,,,,, -47600,0.47085625,1.748113,,,,,,,,,,,,,,,,, -47700,0.19392377,1.7432021,,,,,,,,,,,,,,,,, -47800,0.18614912,1.7340897,,,,,,,,,,,,,,,,, -47900,0.19161816,1.7509049,,,,,,,,,,,,,,,,, -48000,0.19850902,1.7404364,,,,,,,,,,,,,,,,, -48100,0.20973718,1.7820721,,,,,,,,,,,,,,,,, -48200,0.20647103,1.7372543,,,,,,,,,,,,,,,,, -48300,0.19873452,1.7399175,,,,,,,,,,,,,,,,, -48400,0.19092873,1.7323883,,,,,,,,,,,,,,,,, -48500,0.20056424,1.7292664,,,,,,,,,,,,,,,,, -48600,0.26124492,1.7691365,,,,,,,,,,,,,,,,, -48700,0.20607612,1.8181621,,,,,,,,,,,,,,,,, -48800,0.20588557,1.6683099,,,,,,,,,,,,,,,,, -48900,0.27383024,1.735496,,,,,,,,,,,,,,,,, -49000,0.244035,1.7667123,,,,,,,,,,,,,,,,, -49100,0.19108799,1.7439406,,,,,,,,,,,,,,,,, -49200,0.19119261,1.7169185,,,,,,,,,,,,,,,,, -49300,0.22277084,1.8093451,,,,,,,,,,,,,,,,, -49344,,,0.6485030055046082,1.683557629585266,31.90822207781761,0.6678528189659119,1.540746808052063,28.77088171329392,3000.0,0.6801580786705017,1.4569159746170044,28.28024055566884,3003.0,17681.87114572525,29182.23004412651,17681.87114572525,11498.062032461166,0.6335971355438232,0.0 -49400,0.1975662,1.6994935,,,,,,,,,,,,,,,,, -49500,0.24594367,1.712791,,,,,,,,,,,,,,,,, -49600,0.22099273,1.6989055,,,,,,,,,,,,,,,,, -49700,0.19162077,1.6309414,,,,,,,,,,,,,,,,, -49800,0.2329571,1.78843,,,,,,,,,,,,,,,,, -49900,0.23360446,1.7796689,,,,,,,,,,,,,,,,, -50000,0.18651102,1.7317015,,,,,,,,,,,,,,,,, -50100,0.19917133,1.680781,,,,,,,,,,,,,,,,, -50200,0.1988084,1.80987,,,,,,,,,,,,,,,,, -50300,0.18912849,1.6925459,,,,,,,,,,,,,,,,, -50400,0.20074125,1.7877121,,,,,,,,,,,,,,,,, -50500,0.21182513,1.743618,,,,,,,,,,,,,,,,, -50600,0.20094043,1.6996648,,,,,,,,,,,,,,,,, -50700,0.1930631,1.6527164,,,,,,,,,,,,,,,,, -50800,0.25068796,1.702313,,,,,,,,,,,,,,,,, -50900,0.19725217,1.728484,,,,,,,,,,,,,,,,, -51000,0.2232338,1.7101151,,,,,,,,,,,,,,,,, -51100,0.18452641,1.7055646,,,,,,,,,,,,,,,,, -51200,0.19781335,1.7249165,,,,,,,,,,,,,,,,, -51300,0.1952119,1.7005677,,,,,,,,,,,,,,,,, -51400,0.20355389,1.7622982,,,,,,,,,,,,,,,,, -51500,0.20771733,1.6461897,,,,,,,,,,,,,,,,, -51600,0.22607297,1.7165326,,,,,,,,,,,,,,,,, -51695,,,0.6541595458984375,1.6322287321090698,32.13664638188772,0.6684479713439941,1.5355753898620603,28.76373048098764,3000.0,0.6809133887290955,1.451818823814392,28.51707180407588,3003.0,18521.941950798035,30609.1839325428,18521.941950798035,12084.83533358574,0.664344072341919,0.0 -51700,0.19998138,1.7180402,,,,,,,,,,,,,,,,, -51800,0.19427721,1.7553217,,,,,,,,,,,,,,,,, -51900,0.19073334,1.7110654,,,,,,,,,,,,,,,,, -52000,0.19817357,1.7382016,,,,,,,,,,,,,,,,, -52100,0.22745395,1.7167535,,,,,,,,,,,,,,,,, -52200,0.19196838,1.7553262,,,,,,,,,,,,,,,,, -52300,0.21035475,1.7466568,,,,,,,,,,,,,,,,, -52400,0.1948368,1.7534393,,,,,,,,,,,,,,,,, -52500,0.1905953,1.7732917,,,,,,,,,,,,,,,,, -52600,0.26915932,1.7888886,,,,,,,,,,,,,,,,, -52700,0.19950584,1.8092496,,,,,,,,,,,,,,,,, -52800,0.19270724,1.7194073,,,,,,,,,,,,,,,,, -52900,0.21385963,1.7054527,,,,,,,,,,,,,,,,, -53000,0.20062222,1.7022752,,,,,,,,,,,,,,,,, -53100,0.19243461,1.7971084,,,,,,,,,,,,,,,,, -53200,0.2059637,1.817331,,,,,,,,,,,,,,,,, -53300,0.25128675,1.7056051,,,,,,,,,,,,,,,,, -53400,0.21138836,1.6932398,,,,,,,,,,,,,,,,, -53500,0.19656532,1.7891903,,,,,,,,,,,,,,,,, -53600,0.22142524,1.6652571,,,,,,,,,,,,,,,,, -53700,0.18183365,1.6846039,,,,,,,,,,,,,,,,, -53800,0.2088416,1.7662656,,,,,,,,,,,,,,,,, -53900,0.18871212,1.727465,,,,,,,,,,,,,,,,, -54000,0.21849939,1.7727025,,,,,,,,,,,,,,,,, -54045,,,0.6509293913841248,1.6562706232070925,32.05001778447784,0.6697251200675964,1.5251288414001465,28.898027363811607,3000.0,0.6835047602653503,1.4431202411651611,28.57077916909548,3003.0,19361.91049265861,32074.034927845,19361.91049265861,12709.605189323423,0.6956663131713867,0.0 -54100,0.19691575,1.6051904,,,,,,,,,,,,,,,,, -54200,0.50131565,1.8254484,,,,,,,,,,,,,,,,, -54300,0.20102485,1.7687685,,,,,,,,,,,,,,,,, -54400,0.19881436,1.7076744,,,,,,,,,,,,,,,,, -54500,0.21846546,1.6764988,,,,,,,,,,,,,,,,, -54600,0.18755768,1.6688613,,,,,,,,,,,,,,,,, -54700,0.20361105,1.7161181,,,,,,,,,,,,,,,,, -54800,0.21013589,1.8040559,,,,,,,,,,,,,,,,, -54900,0.21076052,1.7329148,,,,,,,,,,,,,,,,, -55000,0.22126253,1.6747987,,,,,,,,,,,,,,,,, -55100,0.1967647,1.6825318,,,,,,,,,,,,,,,,, -55200,0.18953684,1.659773,,,,,,,,,,,,,,,,, -55300,0.18120362,1.7039678,,,,,,,,,,,,,,,,, -55400,0.21523471,1.7631238,,,,,,,,,,,,,,,,, -55500,0.20023878,1.7747957,,,,,,,,,,,,,,,,, -55600,0.20348045,1.6688993,,,,,,,,,,,,,,,,, -55700,0.17602299,1.6729821,,,,,,,,,,,,,,,,, -55800,0.19611213,1.6851374,,,,,,,,,,,,,,,,, -55900,0.18457828,1.7243427,,,,,,,,,,,,,,,,, -56000,0.20026773,1.7079198,,,,,,,,,,,,,,,,, -56100,0.19297627,1.6828936,,,,,,,,,,,,,,,,, -56200,0.21559478,1.7271293,,,,,,,,,,,,,,,,, -56300,0.18783398,1.696739,,,,,,,,,,,,,,,,, -56396,,,0.6806057691574097,1.4595482349395752,34.32327576098257,0.6707666516304016,1.515071988105774,29.178939066208603,3000.0,0.6842600703239441,1.435634970664978,28.68166147587714,3003.0,20202.092440128326,33446.429708480835,20202.092440128326,13241.70564031601,0.7274036407470703,0.0 -56400,0.2085873,1.698293,,,,,,,,,,,,,,,,, -56500,0.27477667,1.7313749,,,,,,,,,,,,,,,,, -56600,0.18658826,1.8087267,,,,,,,,,,,,,,,,, -56700,0.19766194,1.6962562,,,,,,,,,,,,,,,,, -56800,0.21278004,1.7403038,,,,,,,,,,,,,,,,, -56900,0.18674667,1.6285992,,,,,,,,,,,,,,,,, -57000,0.21338867,1.7049091,,,,,,,,,,,,,,,,, -57100,0.19062378,1.7265133,,,,,,,,,,,,,,,,, -57200,0.20752902,1.7468458,,,,,,,,,,,,,,,,, -57300,0.19858344,1.7455764,,,,,,,,,,,,,,,,, -57400,0.189396,1.7293587,,,,,,,,,,,,,,,,, -57500,0.2100167,1.651694,,,,,,,,,,,,,,,,, -57600,0.1963121,1.7535009,,,,,,,,,,,,,,,,, -57700,0.19869837,1.6758732,,,,,,,,,,,,,,,,, -57800,0.22933841,1.6897355,,,,,,,,,,,,,,,,, -57900,0.18449274,1.6371738,,,,,,,,,,,,,,,,, -58000,0.2162105,1.5926857,,,,,,,,,,,,,,,,, -58100,0.19368492,1.6747676,,,,,,,,,,,,,,,,, -58200,0.19544038,1.6578859,,,,,,,,,,,,,,,,, -58300,0.21898055,1.662831,,,,,,,,,,,,,,,,, -58400,0.19181766,1.6939559,,,,,,,,,,,,,,,,, -58500,0.18178761,1.6031907,,,,,,,,,,,,,,,,, -58600,0.1893169,1.671242,,,,,,,,,,,,,,,,, -58700,0.20642555,1.7436087,,,,,,,,,,,,,,,,, -58746,,,0.653881311416626,1.633278727531433,32.14687787371736,0.6720437407493591,1.5138766765594482,28.967526096442047,3000.0,0.6848527193069458,1.4315576553344729,28.62383352483516,3003.0,21042.018231630325,34912.5352704525,21042.018231630325,13867.775197029114,0.7592089176177979,0.0 -58800,0.22445148,1.7733067,,,,,,,,,,,,,,,,, -58900,0.18810815,1.6976461,,,,,,,,,,,,,,,,, -59000,0.19949198,1.7124424,,,,,,,,,,,,,,,,, -59100,0.206248,1.6830134,,,,,,,,,,,,,,,,, -59200,0.20535925,1.6553593,,,,,,,,,,,,,,,,, -59300,0.1841933,1.7897385,,,,,,,,,,,,,,,,, -59400,0.19539669,1.7242198,,,,,,,,,,,,,,,,, -59500,0.19982931,1.6609228,,,,,,,,,,,,,,,,, -59600,0.20384364,1.7029182,,,,,,,,,,,,,,,,, -59700,0.18912655,1.6922566,,,,,,,,,,,,,,,,, -59800,0.2123264,1.6993022,,,,,,,,,,,,,,,,, -59900,0.1883613,1.6853701,,,,,,,,,,,,,,,,, -60000,0.20268457,1.7945043,,,,,,,,,,,,,,,,, -60100,0.20498167,1.7227254,,,,,,,,,,,,,,,,, -60200,0.19353649,1.7371927,,,,,,,,,,,,,,,,, -60300,0.187491,1.7289759,,,,,,,,,,,,,,,,, -60400,0.20799847,1.7286093,,,,,,,,,,,,,,,,, -60500,0.19860798,1.6954267,,,,,,,,,,,,,,,,, -60600,0.19781904,1.6814996,,,,,,,,,,,,,,,,, -60700,0.19553557,1.7363839,,,,,,,,,,,,,,,,, -60800,0.19089855,1.7293143,,,,,,,,,,,,,,,,, -60900,0.20901406,1.7175078,,,,,,,,,,,,,,,,, -61000,0.22283116,1.7363447,,,,,,,,,,,,,,,,, -61097,,,0.6571126580238342,1.6194921731948853,32.149596664961024,0.6735812425613403,1.5096439123153689,29.221768396880584,3000.0,0.6872000694274902,1.4211808443069458,28.95892319870933,3003.0,21882.15102601052,36364.76279473305,21882.15102601052,14479.754987239838,0.7943167686462402,0.0 -61100,0.20780256,1.7644968,,,,,,,,,,,,,,,,, -61200,0.1873308,1.7108313,,,,,,,,,,,,,,,,, -61300,0.20454428,1.7202375,,,,,,,,,,,,,,,,, -61400,0.20646971,1.7090944,,,,,,,,,,,,,,,,, -61500,0.19363153,1.7138637,,,,,,,,,,,,,,,,, -61600,0.2005769,1.6672939,,,,,,,,,,,,,,,,, -61700,0.2085523,1.6867257,,,,,,,,,,,,,,,,, -61800,0.19817449,1.6727483,,,,,,,,,,,,,,,,, -61900,0.18142267,1.6309104,,,,,,,,,,,,,,,,, -62000,0.18621357,1.7804805,,,,,,,,,,,,,,,,, -62100,0.19768094,1.6188644,,,,,,,,,,,,,,,,, -62200,0.19817474,1.7326415,,,,,,,,,,,,,,,,, -62300,0.22731838,1.6984596,,,,,,,,,,,,,,,,, -62400,0.18808204,1.6691982,,,,,,,,,,,,,,,,, -62500,0.19998567,1.7443758,,,,,,,,,,,,,,,,, -62600,0.2335989,1.662472,,,,,,,,,,,,,,,,, -62700,0.20945467,1.7551973,,,,,,,,,,,,,,,,, -62800,0.21502307,1.7867773,,,,,,,,,,,,,,,,, -62900,0.23783292,1.6895242,,,,,,,,,,,,,,,,, -63000,0.20553806,1.6972095,,,,,,,,,,,,,,,,, -63100,0.20540872,1.6333224,,,,,,,,,,,,,,,,, -63200,0.20163123,1.6822457,,,,,,,,,,,,,,,,, -63300,0.18931584,1.6743501,,,,,,,,,,,,,,,,, -63400,0.20728217,1.7772076,,,,,,,,,,,,,,,,, -63447,,,0.6630586385726929,1.5684752464294434,32.501644119858895,0.6748087406158447,1.494695782661438,29.088602365219,3000.0,0.6873278617858887,1.4133059978485107,29.13141158817472,3003.0,22722.04600667953,37834.20046401024,22722.04600667953,15109.183556556702,0.8285095691680908,0.0 -63500,0.19355714,1.667659,,,,,,,,,,,,,,,,, -63600,0.20939195,1.7235072,,,,,,,,,,,,,,,,, -63700,0.18872613,1.6563209,,,,,,,,,,,,,,,,, -63800,0.19300051,1.6684891,,,,,,,,,,,,,,,,, -63900,0.20264858,1.6295168,,,,,,,,,,,,,,,,, -64000,0.20631975,1.6838589,,,,,,,,,,,,,,,,, -64100,0.19421904,1.652074,,,,,,,,,,,,,,,,, -64200,0.20689121,1.7553393,,,,,,,,,,,,,,,,, -64300,0.1910257,1.6877556,,,,,,,,,,,,,,,,, -64400,0.20443921,1.7622826,,,,,,,,,,,,,,,,, -64500,0.19524316,1.7849785,,,,,,,,,,,,,,,,, -64600,0.22038163,1.6773071,,,,,,,,,,,,,,,,, -64700,0.19688168,1.7275275,,,,,,,,,,,,,,,,, -64800,0.19576414,1.7282753,,,,,,,,,,,,,,,,, -64900,0.2003633,1.7247428,,,,,,,,,,,,,,,,, -65000,0.2615859,1.711278,,,,,,,,,,,,,,,,, -65100,0.18887594,1.6543802,,,,,,,,,,,,,,,,, -65200,0.20140801,1.725779,,,,,,,,,,,,,,,,, -65300,0.19350365,1.597883,,,,,,,,,,,,,,,,, -65400,0.19711134,1.6513271,,,,,,,,,,,,,,,,, -65500,0.20690885,1.6594212,,,,,,,,,,,,,,,,, -65600,0.21102294,1.749659,,,,,,,,,,,,,,,,, -65700,0.2197754,1.6078587,,,,,,,,,,,,,,,,, -65797,,,0.6596559286117554,1.6069821119308472,32.72167565803844,0.673159658908844,1.5000309944152832,29.10851777105609,3000.0,0.687897264957428,1.4136604070663452,28.904316917236,3003.0,23562.01352858544,39345.014755010605,23562.01352858544,15779.91496014595,0.8635752201080322,0.0 -65800,0.22189827,1.6698282,,,,,,,,,,,,,,,,, -65900,0.19563927,1.6278086,,,,,,,,,,,,,,,,, -66000,0.21526055,1.6438028,,,,,,,,,,,,,,,,, -66100,0.19906287,1.656173,,,,,,,,,,,,,,,,, -66200,0.2116653,1.6059475,,,,,,,,,,,,,,,,, -66300,0.24890925,1.6984075,,,,,,,,,,,,,,,,, -66400,0.18810168,1.6480963,,,,,,,,,,,,,,,,, -66500,0.19211671,1.6519699,,,,,,,,,,,,,,,,, -66600,0.43812513,1.6717551,,,,,,,,,,,,,,,,, -66700,0.19324912,1.6601905,,,,,,,,,,,,,,,,, -66800,0.23383833,1.6687635,,,,,,,,,,,,,,,,, -66900,0.19415933,1.6564443,,,,,,,,,,,,,,,,, -67000,0.19496688,1.6937797,,,,,,,,,,,,,,,,, -67100,0.21092796,1.7255799,,,,,,,,,,,,,,,,, -67200,0.17633362,1.5862367,,,,,,,,,,,,,,,,, -67300,0.21322608,1.6807299,,,,,,,,,,,,,,,,, -67400,0.22745688,1.7101059,,,,,,,,,,,,,,,,, -67500,0.20698828,1.6725411,,,,,,,,,,,,,,,,, -67600,0.2088557,1.5702612,,,,,,,,,,,,,,,,, -67700,0.19282618,1.6713208,,,,,,,,,,,,,,,,, -67800,0.21177253,1.6733661,,,,,,,,,,,,,,,,, -67900,0.25132447,1.719689,,,,,,,,,,,,,,,,, -68000,0.20051935,1.5952945,,,,,,,,,,,,,,,,, -68100,0.20789458,1.656794,,,,,,,,,,,,,,,,, -68147,,,0.6588523983955383,1.6124247312545776,32.49809723904921,0.6761106252670288,1.4879592657089231,29.24748990886783,3000.0,0.6913717985153198,1.398087501525879,29.46735070037606,3003.0,24402.17420578003,40759.86060547829,24402.17420578003,16354.474965810776,0.905400276184082,0.0 -68200,0.19409797,1.7006131,,,,,,,,,,,,,,,,, -68300,0.20026267,1.7399253,,,,,,,,,,,,,,,,, -68400,0.22731753,1.5872148,,,,,,,,,,,,,,,,, -68500,0.19310156,1.6334593,,,,,,,,,,,,,,,,, -68600,0.22076319,1.7631476,,,,,,,,,,,,,,,,, -68700,0.20255163,1.6743207,,,,,,,,,,,,,,,,, -68800,0.19562067,1.6876751,,,,,,,,,,,,,,,,, -68900,0.18590786,1.6176022,,,,,,,,,,,,,,,,, -69000,0.31447062,1.6224028,,,,,,,,,,,,,,,,, -69100,0.1887111,1.7214706,,,,,,,,,,,,,,,,, -69200,0.20576738,1.5717632,,,,,,,,,,,,,,,,, -69300,0.20519626,1.7006037,,,,,,,,,,,,,,,,, -69400,0.19377081,1.7027737,,,,,,,,,,,,,,,,, -69500,0.20284799,1.6783931,,,,,,,,,,,,,,,,, -69600,0.19862418,1.7188247,,,,,,,,,,,,,,,,, -69700,0.23500331,1.6807913,,,,,,,,,,,,,,,,, -69800,0.1957621,1.6238512,,,,,,,,,,,,,,,,, -69900,0.20511803,1.6555153,,,,,,,,,,,,,,,,, -70000,0.18308161,1.5715353,,,,,,,,,,,,,,,,, -70100,0.18814184,1.6348315,,,,,,,,,,,,,,,,, -70200,0.18189056,1.7418025,,,,,,,,,,,,,,,,, -70300,0.21090093,1.6767389,,,,,,,,,,,,,,,,, -70400,0.28771105,1.7818404,,,,,,,,,,,,,,,,, -70497,,,0.6634678840637207,1.5741132497787476,33.044339503927496,0.6757386922836304,1.4837726354599,29.433881912093547,3000.0,0.6908837556838989,1.3939534425735474,29.18475442807425,3003.0,25242.11082100868,42189.13254570961,25242.11082100868,16943.696088552475,0.9412546157836914,0.0 -70500,0.21176444,1.7030236,,,,,,,,,,,,,,,,, -70600,0.22101107,1.7338536,,,,,,,,,,,,,,,,, -70700,0.22490388,1.7068149,,,,,,,,,,,,,,,,, -70800,0.19548066,1.6887184,,,,,,,,,,,,,,,,, -70900,0.4518179,1.5789988,,,,,,,,,,,,,,,,, -71000,0.20751576,1.6212904,,,,,,,,,,,,,,,,, -71100,0.19058746,1.7098134,,,,,,,,,,,,,,,,, -71200,0.19497085,1.6784908,,,,,,,,,,,,,,,,, -71300,0.20356233,1.6784728,,,,,,,,,,,,,,,,, -71400,0.21373259,1.7021338,,,,,,,,,,,,,,,,, -71500,0.22466533,1.6503816,,,,,,,,,,,,,,,,, -71600,0.19236313,1.6232541,,,,,,,,,,,,,,,,, -71700,0.2098728,1.7835188,,,,,,,,,,,,,,,,, -71800,0.21491341,1.6808547,,,,,,,,,,,,,,,,, -71900,0.1958923,1.6727103,,,,,,,,,,,,,,,,, -72000,0.19380884,1.645856,,,,,,,,,,,,,,,,, -72100,0.21303628,1.6887838,,,,,,,,,,,,,,,,, -72200,0.2100243,1.6568681,,,,,,,,,,,,,,,,, -72300,0.20376003,1.5867025,,,,,,,,,,,,,,,,, -72400,0.20006374,1.6933633,,,,,,,,,,,,,,,,, -72500,0.1884733,1.6195197,,,,,,,,,,,,,,,,, -72600,0.20205256,1.6165067,,,,,,,,,,,,,,,,, -72700,0.19590795,1.7145479,,,,,,,,,,,,,,,,, -72800,0.23271887,1.6465231,,,,,,,,,,,,,,,,, -72847,,,0.6604095101356506,1.5908526182174685,32.75089059937514,0.6772513389587402,1.4792447090148926,29.40842404506689,3000.0,0.6912323832511902,1.3835713863372805,28.924374363232204,3003.0,26082.05291867256,43577.98728346825,26082.05291867256,17492.493832349777,0.9753479957580566,0.0 -72900,0.20301844,1.6394314,,,,,,,,,,,,,,,,, -73000,0.22237627,1.7143422,,,,,,,,,,,,,,,,, -73100,0.19251141,1.6172231,,,,,,,,,,,,,,,,, -73200,0.21456851,1.6628245,,,,,,,,,,,,,,,,, -73300,0.20032065,1.6286136,,,,,,,,,,,,,,,,, -73400,0.19328758,1.6882255,,,,,,,,,,,,,,,,, -73500,0.2019497,1.6301662,,,,,,,,,,,,,,,,, -73600,0.21330321,1.6358702,,,,,,,,,,,,,,,,, -73700,0.19595242,1.6883819,,,,,,,,,,,,,,,,, -73800,0.2018957,1.5921605,,,,,,,,,,,,,,,,, -73900,0.20613958,1.6638564,,,,,,,,,,,,,,,,, -74000,0.214151,1.6793964,,,,,,,,,,,,,,,,, -74100,0.19172694,1.6207685,,,,,,,,,,,,,,,,, -74200,0.21447816,1.6413832,,,,,,,,,,,,,,,,, -74300,0.19123688,1.6759762,,,,,,,,,,,,,,,,, -74400,0.22278948,1.7137632,,,,,,,,,,,,,,,,, -74500,0.21441658,1.6898581,,,,,,,,,,,,,,,,, -74600,0.18920301,1.6036944,,,,,,,,,,,,,,,,, -74700,0.19224307,1.6624997,,,,,,,,,,,,,,,,, -74800,0.20424072,1.6898577,,,,,,,,,,,,,,,,, -74900,0.20849563,1.6269242,,,,,,,,,,,,,,,,, -75000,0.2330903,1.7092838,,,,,,,,,,,,,,,,, -75100,0.20709029,1.6244895,,,,,,,,,,,,,,,,, -75197,,,0.6884849667549133,1.4089932441711426,34.55588996153888,0.6790616512298584,1.4695286750793457,29.55969900752161,3000.0,0.6939283013343811,1.3758292198181152,29.33074891716061,3003.0,26922.27769780159,44961.1496899128,26922.27769780159,18035.313774347305,1.0108327865600586,0.0 -75200,0.20312576,1.7675049,,,,,,,,,,,,,,,,, -75300,0.20376907,1.6563865,,,,,,,,,,,,,,,,, -75400,0.20770185,1.7072555,,,,,,,,,,,,,,,,, -75500,0.18931389,1.5646267,,,,,,,,,,,,,,,,, -75600,0.19292144,1.6054919,,,,,,,,,,,,,,,,, -75700,0.2182149,1.6972171,,,,,,,,,,,,,,,,, -75800,0.2007966,1.6559409,,,,,,,,,,,,,,,,, -75900,0.21423122,1.6566106,,,,,,,,,,,,,,,,, -76000,0.214423,1.6166023,,,,,,,,,,,,,,,,, -76100,0.20821647,1.6127347,,,,,,,,,,,,,,,,, -76200,0.20582262,1.6236515,,,,,,,,,,,,,,,,, -76300,0.19769298,1.6444885,,,,,,,,,,,,,,,,, -76400,0.1991342,1.6429652,,,,,,,,,,,,,,,,, -76500,0.20990846,1.672462,,,,,,,,,,,,,,,,, -76600,0.20251007,1.6499119,,,,,,,,,,,,,,,,, -76700,0.19721639,1.6726795,,,,,,,,,,,,,,,,, -76800,0.19623534,1.6835481,,,,,,,,,,,,,,,,, -76900,0.19399653,1.6116978,,,,,,,,,,,,,,,,, -77000,0.20396103,1.6547015,,,,,,,,,,,,,,,,, -77100,0.21357982,1.6023568,,,,,,,,,,,,,,,,, -77200,0.20931119,1.5999305,,,,,,,,,,,,,,,,, -77300,0.19957006,1.721056,,,,,,,,,,,,,,,,, -77400,0.20997913,1.6424907,,,,,,,,,,,,,,,,, -77500,0.20055456,1.6534625,,,,,,,,,,,,,,,,, -77546,,,0.6647745966911316,1.5595890283584597,33.22219566753689,0.6787516474723816,1.46135675907135,29.46783892268668,3000.0,0.6951600909233093,1.3723912239074707,29.366417158019804,3003.0,27762.29205560684,46386.2266972065,27762.29205560684,18620.25772929192,1.0464684963226318,0.0 -77600,0.1997289,1.6335794,,,,,,,,,,,,,,,,, -77700,0.18461825,1.5824864,,,,,,,,,,,,,,,,, -77800,0.21695364,1.659665,,,,,,,,,,,,,,,,, -77900,0.20234574,1.6713626,,,,,,,,,,,,,,,,, -78000,0.22845027,1.7026196,,,,,,,,,,,,,,,,, -78100,0.21326622,1.6571211,,,,,,,,,,,,,,,,, -78200,0.20382647,1.6193944,,,,,,,,,,,,,,,,, -78300,0.22148784,1.6578798,,,,,,,,,,,,,,,,, -78400,0.20805283,1.5787076,,,,,,,,,,,,,,,,, -78500,0.20127995,1.5969634,,,,,,,,,,,,,,,,, -78600,0.20205523,1.6389235,,,,,,,,,,,,,,,,, -78700,0.21632273,1.6562688,,,,,,,,,,,,,,,,, -78800,0.19337726,1.6202902,,,,,,,,,,,,,,,,, -78900,0.20017475,1.644011,,,,,,,,,,,,,,,,, -79000,0.21358702,1.5959929,,,,,,,,,,,,,,,,, -79100,0.2037278,1.644189,,,,,,,,,,,,,,,,, -79200,0.2017425,1.5852901,,,,,,,,,,,,,,,,, -79300,0.20412913,1.6723189,,,,,,,,,,,,,,,,, -79400,0.19348413,1.6341627,,,,,,,,,,,,,,,,, -79500,0.23801397,1.6889956,,,,,,,,,,,,,,,,, -79600,0.22223084,1.5969551,,,,,,,,,,,,,,,,, -79700,0.20107618,1.6424065,,,,,,,,,,,,,,,,, -79800,0.2016676,1.66075,,,,,,,,,,,,,,,,, -79896,,,0.6653583645820618,1.56242573261261,32.828529578925284,0.6814670562744141,1.453194260597229,29.776679592171543,3000.0,0.6968799233436584,1.3566237688064575,29.39433137317041,3003.0,28602.22843670845,47845.993671655655,28602.22843670845,19239.971554994583,1.083417892456055,0.0 -79900,0.20627019,1.5872209,,,,,,,,,,,,,,,,, -80000,0.19503772,1.6609827,,,,,,,,,,,,,,,,, -80100,0.205136,1.6511496,,,,,,,,,,,,,,,,, -80200,0.19054323,1.571163,,,,,,,,,,,,,,,,, -80300,0.19866738,1.622282,,,,,,,,,,,,,,,,, -80400,0.20438653,1.624765,,,,,,,,,,,,,,,,, -80500,0.20488319,1.6112865,,,,,,,,,,,,,,,,, -80600,0.2010725,1.6368444,,,,,,,,,,,,,,,,, -80700,0.20089228,1.5763555,,,,,,,,,,,,,,,,, -80800,0.20661354,1.6772354,,,,,,,,,,,,,,,,, -80900,0.19982205,1.5710473,,,,,,,,,,,,,,,,, -81000,0.21625538,1.6045233,,,,,,,,,,,,,,,,, -81100,0.21304992,1.6985251,,,,,,,,,,,,,,,,, -81200,0.20265879,1.5809524,,,,,,,,,,,,,,,,, -81300,0.18756537,1.637246,,,,,,,,,,,,,,,,, -81400,0.21963945,1.6374981,,,,,,,,,,,,,,,,, -81500,0.20030703,1.6010948,,,,,,,,,,,,,,,,, -81600,0.19702014,1.6804072,,,,,,,,,,,,,,,,, -81700,0.20905605,1.5723543,,,,,,,,,,,,,,,,, -81800,0.2339193,1.5863062,,,,,,,,,,,,,,,,, -81900,0.19728072,1.6551431,,,,,,,,,,,,,,,,, -82000,0.20051552,1.6589087,,,,,,,,,,,,,,,,, -82100,0.22690128,1.6148146,,,,,,,,,,,,,,,,, -82200,0.23051158,1.6236581,,,,,,,,,,,,,,,,, -82246,,,0.6765756011009216,1.4844000339508057,33.29825089454542,0.6817150115966797,1.444311022758484,29.627096514149848,3000.0,0.6970542073249817,1.3562180995941162,29.67077386186465,3003.0,29442.20606327057,49370.74531555176,29442.20606327057,19924.62496495247,1.1249215602874756,0.0 -82300,0.2072578,1.6674705,,,,,,,,,,,,,,,,, -82400,0.21519184,1.6095619,,,,,,,,,,,,,,,,, -82500,0.18684167,1.6242032,,,,,,,,,,,,,,,,, -82600,0.21142864,1.6486529,,,,,,,,,,,,,,,,, -82700,0.19819829,1.6817309,,,,,,,,,,,,,,,,, -82800,0.2028417,1.5942549,,,,,,,,,,,,,,,,, -82900,0.19838953,1.5833709,,,,,,,,,,,,,,,,, -83000,0.1953843,1.5562884,,,,,,,,,,,,,,,,, -83100,0.19561996,1.5627316,,,,,,,,,,,,,,,,, -83200,1.2967836,1.5680176,,,,,,,,,,,,,,,,, -83300,0.20356742,1.700491,,,,,,,,,,,,,,,,, -83400,0.22323701,1.5740598,,,,,,,,,,,,,,,,, -83500,0.20578009,1.674501,,,,,,,,,,,,,,,,, -83600,0.21561679,1.6050363,,,,,,,,,,,,,,,,, -83700,0.20719457,1.6588825,,,,,,,,,,,,,,,,, -83800,0.2125126,1.6971115,,,,,,,,,,,,,,,,, -83900,0.20336144,1.6050781,,,,,,,,,,,,,,,,, -84000,1.1719495,1.6131461,,,,,,,,,,,,,,,,, -84100,0.21639547,1.6001799,,,,,,,,,,,,,,,,, -84200,0.19863158,1.6023798,,,,,,,,,,,,,,,,, -84300,0.21237892,1.6756226,,,,,,,,,,,,,,,,, -84400,0.21983235,1.6166291,,,,,,,,,,,,,,,,, -84500,0.20503855,1.6393268,,,,,,,,,,,,,,,,, -84596,,,0.670854926109314,1.519789695739746,33.14134605170951,0.6831161379814148,1.4388524293899536,29.692176572418138,3000.0,0.6992272734642029,1.3421218395233154,29.647750123973445,3003.0,30282.19902396202,50771.86794400215,30282.19902396202,20485.63865876197,1.1610476970672607,0.0 -84600,0.19912119,1.6058824,,,,,,,,,,,,,,,,, -84700,0.20239317,1.5539474,,,,,,,,,,,,,,,,, -84800,0.19902079,1.6109556,,,,,,,,,,,,,,,,, -84900,0.38153836,1.6105278,,,,,,,,,,,,,,,,, -85000,0.19259028,1.5666716,,,,,,,,,,,,,,,,, -85100,0.20382407,1.6132858,,,,,,,,,,,,,,,,, -85200,0.19803971,1.5652215,,,,,,,,,,,,,,,,, -85300,0.21696104,1.7266241,,,,,,,,,,,,,,,,, -85400,0.19718704,1.5945485,,,,,,,,,,,,,,,,, -85500,0.20126578,1.6184518,,,,,,,,,,,,,,,,, -85600,0.2112407,1.5829574,,,,,,,,,,,,,,,,, -85700,0.21414201,1.5804777,,,,,,,,,,,,,,,,, -85800,0.21151054,1.6310177,,,,,,,,,,,,,,,,, -85900,0.2052557,1.5927795,,,,,,,,,,,,,,,,, -86000,0.21968544,1.5917674,,,,,,,,,,,,,,,,, -86100,0.19622897,1.6213214,,,,,,,,,,,,,,,,, -86200,0.21047612,1.5250849,,,,,,,,,,,,,,,,, -86300,0.22703373,1.6492908,,,,,,,,,,,,,,,,, -86400,0.21246321,1.5687172,,,,,,,,,,,,,,,,, -86500,0.21350858,1.639308,,,,,,,,,,,,,,,,, -86600,0.21606538,1.611107,,,,,,,,,,,,,,,,, -86700,0.20892873,1.6638798,,,,,,,,,,,,,,,,, -86800,0.21211001,1.5313077,,,,,,,,,,,,,,,,, -86900,0.23274235,1.590816,,,,,,,,,,,,,,,,, -86946,,,0.6692231893539429,1.5366313457489014,33.56647363891882,0.6852735877037048,1.4326162338256836,30.41250771457849,3000.0,0.7010865211486816,1.3369659185409546,30.05570302170628,3003.0,31122.1114256382,52163.96068120003,31122.1114256382,21037.6956114769,1.203956127166748,0.0 -87000,0.2031697,1.5484525,,,,,,,,,,,,,,,,, -87100,0.21063098,1.6812931,,,,,,,,,,,,,,,,, -87200,0.19054566,1.6032966,,,,,,,,,,,,,,,,, -87300,0.587601,1.5940518,,,,,,,,,,,,,,,,, -87400,0.5468998,1.6424513,,,,,,,,,,,,,,,,, -87500,0.18968019,1.4805444,,,,,,,,,,,,,,,,, -87600,0.21293217,1.5729351,,,,,,,,,,,,,,,,, -87700,0.20476598,1.6359429,,,,,,,,,,,,,,,,, -87800,0.22814783,1.6268464,,,,,,,,,,,,,,,,, -87900,0.20274667,1.5643035,,,,,,,,,,,,,,,,, -88000,0.20476751,1.6193094,,,,,,,,,,,,,,,,, -88100,0.20350055,1.6070187,,,,,,,,,,,,,,,,, -88200,0.20799601,1.6023316,,,,,,,,,,,,,,,,, -88300,0.21725443,1.5416598,,,,,,,,,,,,,,,,, -88400,0.22352125,1.7118301,,,,,,,,,,,,,,,,, -88500,0.19441749,1.5617594,,,,,,,,,,,,,,,,, -88600,0.19769008,1.5180697,,,,,,,,,,,,,,,,, -88700,0.21351387,1.6568123,,,,,,,,,,,,,,,,, -88800,0.20136699,1.6011997,,,,,,,,,,,,,,,,, -88900,0.21067424,1.5466298,,,,,,,,,,,,,,,,, -89000,0.21891327,1.5291002,,,,,,,,,,,,,,,,, -89100,0.21416222,1.605904,,,,,,,,,,,,,,,,, -89200,0.2249588,1.6129915,,,,,,,,,,,,,,,,, -89296,,,0.675879180431366,1.4836534261703491,34.04868441106427,0.6848767995834351,1.427587628364563,30.374787594004687,3000.0,0.7005636096000671,1.3306442499160769,29.961738819264145,3003.0,31962.09663248062,53538.661982774734,31962.09663248062,21572.29588246345,1.2391314506530762,0.0 -89300,0.20602037,1.6332974,,,,,,,,,,,,,,,,, -89400,0.20196338,1.5152494,,,,,,,,,,,,,,,,, -89500,0.20931347,1.6522963,,,,,,,,,,,,,,,,, -89600,0.21387401,1.6406995,,,,,,,,,,,,,,,,, -89700,0.21373932,1.5435177,,,,,,,,,,,,,,,,, -89800,0.22503556,1.6173564,,,,,,,,,,,,,,,,, -89900,0.20754161,1.5375206,,,,,,,,,,,,,,,,, -90000,0.20379715,1.6812292,,,,,,,,,,,,,,,,, -90100,0.20831457,1.6689905,,,,,,,,,,,,,,,,, -90200,0.20713498,1.6316468,,,,,,,,,,,,,,,,, -90300,0.1998844,1.626295,,,,,,,,,,,,,,,,, -90400,0.20941105,1.6731827,,,,,,,,,,,,,,,,, -90500,0.21535736,1.58483,,,,,,,,,,,,,,,,, -90600,0.2118041,1.6779078,,,,,,,,,,,,,,,,, -90700,0.22545749,1.6277891,,,,,,,,,,,,,,,,, -90800,0.25337386,1.6495719,,,,,,,,,,,,,,,,, -90900,0.1951375,1.4863356,,,,,,,,,,,,,,,,, -91000,0.2167421,1.5893222,,,,,,,,,,,,,,,,, -91100,0.20489413,1.5753342,,,,,,,,,,,,,,,,, -91200,0.20688388,1.5898459,,,,,,,,,,,,,,,,, -91300,0.19588351,1.4949199,,,,,,,,,,,,,,,,, -91400,0.19885243,1.5162143,,,,,,,,,,,,,,,,, -91500,0.21467158,1.5707196,,,,,,,,,,,,,,,,, -91600,0.20294449,1.6648898,,,,,,,,,,,,,,,,, -91647,,,0.6721989512443542,1.515886306762695,33.45237889132655,0.6875798106193542,1.4213722944259644,30.52288304117627,3000.0,0.7031317353248596,1.3237190246582031,30.19090552965387,3003.0,32802.2980568409,54922.926729917526,32802.2980568409,22116.24362039566,1.2771565914154053,0.0 -91700,0.21072079,1.6028558,,,,,,,,,,,,,,,,, -91800,0.32342508,1.6086763,,,,,,,,,,,,,,,,, -91900,0.19894603,1.5446061,,,,,,,,,,,,,,,,, -92000,0.22223306,1.6167436,,,,,,,,,,,,,,,,, -92100,0.20630547,1.5010169,,,,,,,,,,,,,,,,, -92200,0.20925324,1.526727,,,,,,,,,,,,,,,,, -92300,0.21316543,1.5339477,,,,,,,,,,,,,,,,, -92400,0.2062984,1.5875776,,,,,,,,,,,,,,,,, -92500,0.22304554,1.5365826,,,,,,,,,,,,,,,,, -92600,0.20950353,1.5381525,,,,,,,,,,,,,,,,, -92700,0.23712967,1.6439745,,,,,,,,,,,,,,,,, -92800,0.57716274,1.5542337,,,,,,,,,,,,,,,,, -92900,0.20772725,1.619464,,,,,,,,,,,,,,,,, -93000,0.2258851,1.645807,,,,,,,,,,,,,,,,, -93100,1.4538417,1.5068023,,,,,,,,,,,,,,,,, -93200,0.21125147,1.5624634,,,,,,,,,,,,,,,,, -93300,0.22530358,1.5428987,,,,,,,,,,,,,,,,, -93400,0.21400724,1.6256754,,,,,,,,,,,,,,,,, -93500,0.20976503,1.5699733,,,,,,,,,,,,,,,,, -93600,0.21171695,1.6102484,,,,,,,,,,,,,,,,, -93700,0.22927807,1.6743052,,,,,,,,,,,,,,,,, -93800,0.22978449,1.5605456,,,,,,,,,,,,,,,,, -93900,0.2181466,1.632631,,,,,,,,,,,,,,,,, -93998,,,0.6907885074615479,1.397413969039917,35.22884127960909,0.6862407326698303,1.4174182415008545,30.26388804807659,3000.0,0.7028993368148804,1.315637707710266,29.985501215724845,3003.0,33642.41419363022,56323.84543085098,33642.41419363022,22676.927169799805,1.3175652027130127,0.0 -94000,0.1992713,1.4769875,,,,,,,,,,,,,,,,, -94100,0.21295221,1.5699059,,,,,,,,,,,,,,,,, -94200,0.21451432,1.5802938,,,,,,,,,,,,,,,,, -94300,0.23658615,1.5687441,,,,,,,,,,,,,,,,, -94400,0.2249635,1.5595737,,,,,,,,,,,,,,,,, -94500,0.21857168,1.6388588,,,,,,,,,,,,,,,,, -94600,0.20279215,1.5195788,,,,,,,,,,,,,,,,, -94700,0.21931964,1.562892,,,,,,,,,,,,,,,,, -94800,0.22047544,1.5553155,,,,,,,,,,,,,,,,, -94900,0.21143657,1.5457511,,,,,,,,,,,,,,,,, -95000,0.2205131,1.6534525,,,,,,,,,,,,,,,,, -95100,0.22066003,1.6157756,,,,,,,,,,,,,,,,, -95200,0.21585211,1.6189001,,,,,,,,,,,,,,,,, -95300,0.21214698,1.5935104,,,,,,,,,,,,,,,,, -95400,0.35429233,1.5212201,,,,,,,,,,,,,,,,, -95500,0.30022964,1.6222814,,,,,,,,,,,,,,,,, -95600,0.22363572,1.5947485,,,,,,,,,,,,,,,,, -95700,0.2170928,1.6479652,,,,,,,,,,,,,,,,, -95800,0.31941566,1.5707703,,,,,,,,,,,,,,,,, -95900,0.23239146,1.6309485,,,,,,,,,,,,,,,,, -96000,0.22011648,1.5147282,,,,,,,,,,,,,,,,, -96100,0.22807121,1.5982331,,,,,,,,,,,,,,,,, -96200,0.44801182,1.558755,,,,,,,,,,,,,,,,, -96300,0.21083942,1.5826561,,,,,,,,,,,,,,,,, -96348,,,0.680406928062439,1.4625781774520874,33.95429486501357,0.6884601712226868,1.409622311592102,30.50475710065807,3000.0,0.7055023312568665,1.3074795007705688,30.463050020198107,3003.0,34482.45923447609,57838.22999000549,34482.45923447609,23351.149356365204,1.3545808792114258,0.0 -96400,0.25637674,1.5555501,,,,,,,,,,,,,,,,, -96500,0.2247897,1.5722964,,,,,,,,,,,,,,,,, -96600,0.24305548,1.6047482,,,,,,,,,,,,,,,,, -96700,0.20526128,1.6213735,,,,,,,,,,,,,,,,, -96800,0.21153396,1.6134158,,,,,,,,,,,,,,,,, -96900,0.22508852,1.5564717,,,,,,,,,,,,,,,,, -97000,0.20524204,1.582295,,,,,,,,,,,,,,,,, -97100,0.267061,1.6629292,,,,,,,,,,,,,,,,, -97200,0.22444233,1.5492086,,,,,,,,,,,,,,,,, -97300,0.23207399,1.5666248,,,,,,,,,,,,,,,,, -97400,0.2146509,1.5867031,,,,,,,,,,,,,,,,, -97500,0.22982912,1.5514246,,,,,,,,,,,,,,,,, -97600,0.21934637,1.5560539,,,,,,,,,,,,,,,,, -97700,0.21629961,1.5138642,,,,,,,,,,,,,,,,, -97800,0.21712388,1.6402249,,,,,,,,,,,,,,,,, -97900,0.21375014,1.5580889,,,,,,,,,,,,,,,,, -98000,0.20835795,1.5476979,,,,,,,,,,,,,,,,, -98100,0.23043905,1.5980165,,,,,,,,,,,,,,,,, -98200,0.21934453,1.6284424,,,,,,,,,,,,,,,,, -98300,0.22807264,1.5545704,,,,,,,,,,,,,,,,, -98400,0.22777575,1.5047553,,,,,,,,,,,,,,,,, -98500,0.21855302,1.5319092,,,,,,,,,,,,,,,,, -98600,0.22917798,1.5400778,,,,,,,,,,,,,,,,, -98698,,,0.6774348616600037,1.4779555797576904,34.15186940985161,0.6887453198432922,1.4022163152694702,30.495876301755462,3000.0,0.7058973908424377,1.300991177558899,30.41957079143748,3003.0,35322.54058718681,59280.45553016663,35322.54058718681,23953.17521595955,1.3932616710662842,0.0 -98700,0.21811034,1.570013,,,,,,,,,,,,,,,,, -98800,0.2196758,1.4712083,,,,,,,,,,,,,,,,, -98900,0.21411079,1.5448691,,,,,,,,,,,,,,,,, -99000,0.21306063,1.5018796,,,,,,,,,,,,,,,,, -99100,0.20854558,1.4912683,,,,,,,,,,,,,,,,, -99200,0.22622429,1.5245368,,,,,,,,,,,,,,,,, -99300,0.21564418,1.5519994,,,,,,,,,,,,,,,,, -99400,0.22361144,1.4724942,,,,,,,,,,,,,,,,, -99500,0.21123745,1.5339092,,,,,,,,,,,,,,,,, -99600,0.24895678,1.4993079,,,,,,,,,,,,,,,,, -99700,0.22097893,1.5557426,,,,,,,,,,,,,,,,, -99800,0.22488306,1.5704336,,,,,,,,,,,,,,,,, -99900,0.2151204,1.45155,,,,,,,,,,,,,,,,, -100000,0.23012961,1.540352,,,,,,,,,,,,,,,,, -100100,0.22767955,1.5727872,,,,,,,,,,,,,,,,, -100200,0.2182219,1.4979526,,,,,,,,,,,,,,,,, -100300,0.20718507,1.4992812,,,,,,,,,,,,,,,,, -100400,0.21354362,1.5548071,,,,,,,,,,,,,,,,, -100500,0.2203186,1.6408538,,,,,,,,,,,,,,,,, -100600,0.21766202,1.4886153,,,,,,,,,,,,,,,,, -100700,0.22368722,1.5682957,,,,,,,,,,,,,,,,, -100800,0.22490084,1.512476,,,,,,,,,,,,,,,,, -100900,0.2214874,1.5493507,,,,,,,,,,,,,,,,, -101000,0.24614142,1.5608239,,,,,,,,,,,,,,,,, -101047,,,0.6855158805847168,1.433314561843872,34.851935506476764,0.6893280744552612,1.3989043235778809,30.42614686854671,3000.0,0.7059903740882874,1.296967625617981,30.442794792692226,3003.0,36162.63892054558,60693.32236337662,36162.63892054558,24525.823969364166,1.4304945468902588,0.0 -101100,0.21338336,1.5919181,,,,,,,,,,,,,,,,, -101200,0.21521275,1.4988664,,,,,,,,,,,,,,,,, -101300,0.23579681,1.5349739,,,,,,,,,,,,,,,,, -101400,0.21515658,1.5672644,,,,,,,,,,,,,,,,, -101500,0.22247426,1.5498114,,,,,,,,,,,,,,,,, -101600,0.2225002,1.5909534,,,,,,,,,,,,,,,,, -101700,0.21950643,1.5060613,,,,,,,,,,,,,,,,, -101800,0.20631424,1.4984882,,,,,,,,,,,,,,,,, -101900,0.22140832,1.5887637,,,,,,,,,,,,,,,,, -102000,0.24069893,1.5366205,,,,,,,,,,,,,,,,, -102100,0.22016735,1.4798719,,,,,,,,,,,,,,,,, -102200,0.22383668,1.5233314,,,,,,,,,,,,,,,,, -102300,0.22301932,1.5273116,,,,,,,,,,,,,,,,, -102400,0.2176094,1.5250297,,,,,,,,,,,,,,,,, -102500,0.21988288,1.5326158,,,,,,,,,,,,,,,,, -102600,0.22434048,1.5205787,,,,,,,,,,,,,,,,, -102700,0.2419004,1.6542281,,,,,,,,,,,,,,,,, -102800,0.2105991,1.4374278,,,,,,,,,,,,,,,,, -102900,0.23141846,1.5730877,,,,,,,,,,,,,,,,, -103000,0.2144187,1.4897234,,,,,,,,,,,,,,,,, -103100,0.22153683,1.524421,,,,,,,,,,,,,,,,, -103200,0.22279564,1.5689442,,,,,,,,,,,,,,,,, -103300,0.22370653,1.5328095,,,,,,,,,,,,,,,,, -103397,,,0.6817169785499573,1.4496946334838867,34.91150543285352,0.6903820037841797,1.3946393728256226,30.451516312385404,3000.0,0.7083725929260254,1.2883847951889038,30.640346635700297,3003.0,37002.79528737068,62197.84350180626,37002.79528737068,25190.066334486008,1.471330642700195,0.0 -103400,0.21637157,1.5390797,,,,,,,,,,,,,,,,, -103500,0.210885,1.4686205,,,,,,,,,,,,,,,,, -103600,0.20893885,1.5057532,,,,,,,,,,,,,,,,, -103700,0.21605663,1.4602219,,,,,,,,,,,,,,,,, -103800,0.2275874,1.6128404,,,,,,,,,,,,,,,,, -103900,0.21177061,1.5306281,,,,,,,,,,,,,,,,, -104000,0.2301999,1.5507475,,,,,,,,,,,,,,,,, -104100,0.23091726,1.508759,,,,,,,,,,,,,,,,, -104200,0.2354601,1.5791818,,,,,,,,,,,,,,,,, -104300,0.21426943,1.5177033,,,,,,,,,,,,,,,,, -104400,0.22488019,1.5095354,,,,,,,,,,,,,,,,, -104500,0.21550709,1.5325387,,,,,,,,,,,,,,,,, -104600,0.22140613,1.5243295,,,,,,,,,,,,,,,,, -104700,0.23541392,1.5736141,,,,,,,,,,,,,,,,, -104800,0.2194704,1.4615235,,,,,,,,,,,,,,,,, -104900,0.21488194,1.5430223,,,,,,,,,,,,,,,,, -105000,0.23457871,1.5451847,,,,,,,,,,,,,,,,, -105100,0.22230709,1.5193645,,,,,,,,,,,,,,,,, -105200,0.23176569,1.5540466,,,,,,,,,,,,,,,,, -105300,0.21526317,1.5260928,,,,,,,,,,,,,,,,, -105400,0.21857862,1.4223912,,,,,,,,,,,,,,,,, -105500,0.21499556,1.4998252,,,,,,,,,,,,,,,,, -105600,0.23094101,1.4787309,,,,,,,,,,,,,,,,, -105700,0.22297631,1.5481776,,,,,,,,,,,,,,,,, -105747,,,0.683246910572052,1.4504034519195557,34.73362591682547,0.6902952194213867,1.389479160308838,30.559631401457903,3000.0,0.7087560296058655,1.285719394683838,30.557939741522667,3003.0,37842.91463184357,63620.76156711578,37842.91463184357,25772.742190361023,1.5127499103546145,0.0 -105800,0.25350574,1.513912,,,,,,,,,,,,,,,,, -105900,0.2412702,1.5689006,,,,,,,,,,,,,,,,, -106000,0.23676778,1.5592207,,,,,,,,,,,,,,,,, -106100,0.2193917,1.4624894,,,,,,,,,,,,,,,,, -106200,0.22893761,1.5762312,,,,,,,,,,,,,,,,, -106300,0.22119401,1.4983637,,,,,,,,,,,,,,,,, -106400,0.21756648,1.4869889,,,,,,,,,,,,,,,,, -106500,0.22778118,1.4944546,,,,,,,,,,,,,,,,, -106600,0.2245678,1.5074184,,,,,,,,,,,,,,,,, -106700,0.22573975,1.5156655,,,,,,,,,,,,,,,,, -106800,0.22147171,1.4560665,,,,,,,,,,,,,,,,, -106900,0.2375986,1.5039723,,,,,,,,,,,,,,,,, -107000,0.22887802,1.5411736,,,,,,,,,,,,,,,,, -107100,0.214711,1.5767989,,,,,,,,,,,,,,,,, -107200,0.26473466,1.5579267,,,,,,,,,,,,,,,,, -107300,0.24349748,1.4370539,,,,,,,,,,,,,,,,, -107400,0.22217983,1.502231,,,,,,,,,,,,,,,,, -107500,0.22450198,1.5356908,,,,,,,,,,,,,,,,, -107600,0.23352563,1.5865107,,,,,,,,,,,,,,,,, -107700,0.2263415,1.4634501,,,,,,,,,,,,,,,,, -107800,0.2232281,1.5360632,,,,,,,,,,,,,,,,, -107900,0.23074527,1.5135591,,,,,,,,,,,,,,,,, -108000,0.21878342,1.5625433,,,,,,,,,,,,,,,,, -108097,,,0.6888919472694397,1.4076614379882812,35.10855108546573,0.6916466951370239,1.383272647857666,30.89753459757305,3000.0,0.7101156115531921,1.277177333831787,30.65263189003389,3003.0,38682.86744952202,65031.18496370316,38682.86744952202,26343.09537410736,1.5515015125274658,0.0 -108100,0.22623064,1.4923317,,,,,,,,,,,,,,,,, -108200,0.22564171,1.5172627,,,,,,,,,,,,,,,,, -108300,0.2369023,1.5348812,,,,,,,,,,,,,,,,, -108400,0.22312687,1.4607656,,,,,,,,,,,,,,,,, -108500,0.21278103,1.464826,,,,,,,,,,,,,,,,, -108600,0.2225091,1.5128942,,,,,,,,,,,,,,,,, -108700,0.22106351,1.5299853,,,,,,,,,,,,,,,,, -108800,0.22193223,1.5438275,,,,,,,,,,,,,,,,, -108900,0.24629135,1.5597863,,,,,,,,,,,,,,,,, -109000,0.22492091,1.5197382,,,,,,,,,,,,,,,,, -109100,0.23255323,1.4509273,,,,,,,,,,,,,,,,, -109200,0.21053497,1.4784497,,,,,,,,,,,,,,,,, -109300,0.21419403,1.4153693,,,,,,,,,,,,,,,,, -109400,0.22389911,1.4564501,,,,,,,,,,,,,,,,, -109500,0.23110949,1.5056446,,,,,,,,,,,,,,,,, -109600,0.22674556,1.4287895,,,,,,,,,,,,,,,,, -109700,0.236242,1.5100996,,,,,,,,,,,,,,,,, -109800,0.2290146,1.4347934,,,,,,,,,,,,,,,,, -109900,0.23231742,1.5754011,,,,,,,,,,,,,,,,, -110000,0.22103037,1.4698211,,,,,,,,,,,,,,,,, -110100,0.22957335,1.4898137,,,,,,,,,,,,,,,,, -110200,0.21732527,1.4572405,,,,,,,,,,,,,,,,, -110300,0.24300258,1.4521677,,,,,,,,,,,,,,,,, -110400,0.23155066,1.4743994,,,,,,,,,,,,,,,,, -110446,,,0.6874507069587708,1.4143240451812744,34.81838615559656,0.6933826208114624,1.3791488409042358,30.886477140805987,3000.0,0.7107431292533875,1.2718663215637207,30.87753114059596,3003.0,39522.98477363586,66513.68182969093,39522.98477363586,26985.349088191982,1.5927612781524658,0.0 -110446,,,,,,,,,,,,,,39522.984773635864,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 89210b67a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,50 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -769.7907042503357,0.0,20.054532766342163,1,0,20.054532766342163,1.2727720309210526,95000000,789.845296382904,1.2750268968396217,1.27203968947502,83274637 -1389.8240342140198,0.0265195369720459,140.60763335227966,183,0,140.60763335227966,0.1314100198190789,95000000,1530.4645409584043,0.1306915200614141,0.1286313995871501,83274637 -2009.2675638198853,0.0491020679473876,260.8166162967682,366,0,260.8166162967682,0.1302061311369243,95000000,2270.1458444595337,0.1270057580919948,0.1276038683119875,83274637 -2629.859326124192,0.069005012512207,381.08206129074097,544,0,381.08206129074097,0.1293979709909539,95000000,3011.0285456180573,0.1258156687935006,0.1267161050770326,83274637 -3231.0204243659973,0.0892062187194824,501.1722431182861,722,0,501.1722431182861,0.1297202143708881,95000000,3732.305637359619,0.1254300063464251,0.1272029359363006,83274637 -3817.221437215805,0.110586404800415,621.7608590126038,905,0,621.7608590126038,0.1292632202508223,95000000,4439.122531175613,0.1261651381123928,0.1265232753421894,83274637 -4380.449725151062,0.1312270164489746,742.3804759979248,1083,0,742.3804759979248,0.1282775658614309,95000000,5122.996841669083,0.1266514414909688,0.1258599507896233,83274637 -4920.549924612045,0.1552424430847168,862.654314994812,1259,0,862.654314994812,0.1279592428659539,95000000,5783.400713682175,0.1241250919916157,0.1255488845671694,83274637 -5455.457439184189,0.1811528205871582,982.6249959468842,1444,0,982.6249959468842,0.1280510447985197,95000000,6438.310907840729,0.1254901969981081,0.1256811114304696,83274637 -5988.156452178955,0.2012228965759277,1103.1413357257843,1623,0,1103.1413357257843,0.1279304759046052,95000000,7091.552002668381,0.1250005254885123,0.1256984208138439,83274637 -6501.290910720825,0.2238349914550781,1223.7536084651947,1809,0,1223.7536084651947,0.1277026213199013,95000000,7725.327343940735,0.1246529179595926,0.1253543950441038,83274637 -6971.628367900848,0.2449679374694824,1344.3381459712982,1984,0,1344.3381459712982,0.1277069748046875,95000000,8316.276195287704,0.1249708107455908,0.1252997718687706,83274637 -7494.321210861206,0.2679464817047119,1464.9220495224,2170,0,1464.9220495224,0.1278315850637335,95000000,8959.58207988739,0.125709370668954,0.1256153176448625,83274637 -8021.292371273041,0.2880659103393554,1585.228083372116,2349,0,1585.228083372116,0.1276712290810032,95000000,9606.885175466536,0.1225543501171863,0.1253291580202077,83274637 -8552.149214029312,0.3084101676940918,1705.2277812957764,2525,0,1705.2277812957764,0.1273398111636513,95000000,10257.767736196518,0.123191749021036,0.1250012444585731,83274637 -9091.789171218872,0.3289697170257568,1825.53936123848,2702,0,1825.53936123848,0.1275802810958059,95000000,10917.745633125303,0.1248918822351491,0.1253410620845944,83274637 -9606.8914167881,0.3496608734130859,1945.523303985596,2878,0,1945.523303985596,0.1273002328330592,95000000,11552.858093500136,0.1217001618922881,0.1248983065180702,83274637 -10128.94816493988,0.3706555366516113,2065.6756069660187,3053,0,2065.6756069660187,0.1271365337890625,95000000,12195.09390592575,0.1233305929685538,0.1248150544981511,83274637 -10637.009717226028,0.3918793201446533,2185.868538618088,3232,0,2185.868538618088,0.1272028106085526,95000000,12823.375492572784,0.1239032166535561,0.1248819490923223,83274637 -11177.440727949142,0.4134020805358886,2306.386968612671,3408,0,2306.386968612671,0.1271035701788651,95000000,13484.352406978607,0.1238242089279792,0.1249167751339725,83274637 -11713.505402088163,0.4371469020843506,2426.4390614032745,3584,0,2426.4390614032745,0.1273806009662829,95000000,14140.498716831207,0.1221014177485269,0.1249408105908345,83274637 -12237.177653312683,0.4585399627685547,2546.913526058197,3770,0,2546.913526058197,0.1268082985711348,95000000,14784.672607421877,0.1238716137156171,0.1245102247641083,83274637 -12788.11208820343,0.4802253246307373,2666.8965351581573,3943,0,2666.8965351581573,0.1268686580489309,95000000,15455.617283582687,0.1250578743997235,0.124649814354405,83274637 -13339.73760986328,0.5014171600341797,2786.889313220978,4122,0,2786.889313220978,0.1278794684004934,95000000,16127.262543678284,0.1229165908558376,0.1249230898180862,83274637 -13884.941725730896,0.5227134227752686,2907.2045364379883,4300,0,2907.2045364379883,0.1268836971833881,95000000,16792.80887746811,0.1225354496342768,0.1246359960725872,83274637 -14422.877438545229,0.5451195240020752,3027.667422056198,4478,0,3027.667422056198,0.1272580874897204,95000000,17451.23566007614,0.1250219466691871,0.1247609821932487,83274637 -14966.952553272247,0.5720090866088867,3147.6467897892,4660,0,3147.6467897892,0.1270226580797697,95000000,18115.323104143143,0.125849908112355,0.1246348594086899,83274637 -15510.190085411072,0.5935556888580322,3268.2762217521667,4843,0,3268.2762217521667,0.1269167181332237,95000000,18779.21762752533,0.1248512399271599,0.1245842440914719,83274637 -16051.788796186447,0.6148788928985596,3388.835000038147,5025,0,3388.835000038147,0.1269516108655427,95000000,19441.40255951881,0.1222759853673618,0.1246284942222076,83274637 -16596.137999773026,0.6388421058654785,3509.062576532364,5205,0,3509.062576532364,0.1265438617084704,95000000,20106.009090423584,0.1194243619113036,0.1241413238073654,83274637 -17134.551802396774,0.6600420475006104,3629.413244247437,5385,0,3629.413244247437,0.126651141118421,95000000,20764.800588607788,0.1242270162504799,0.1243186725521428,83274637 -17675.89226746559,0.6861846446990967,3749.918417453766,5567,0,3749.918417453766,0.1265954343030427,95000000,21426.67833852768,0.1234444462746944,0.124154712235632,83274637 -18215.002883434296,0.7078261375427246,3870.579716682434,5744,0,3870.579716682434,0.1266719728824013,95000000,22086.477658748627,0.1225258158344142,0.1241873344085678,83274637 -18767.438177347183,0.7333683967590332,3990.8730747699738,5920,0,3990.8730747699738,0.1264998841077302,95000000,22759.23750019073,0.1233771085950001,0.1241138011540412,83274637 -19299.25735974312,0.7545573711395264,4110.999460935593,6100,0,4110.999460935593,0.1265815414165296,95000000,23411.2101802826,0.1220550997616172,0.1241251522136813,83274637 -19849.08947825432,0.7763080596923828,4231.590196371079,6280,0,4231.590196371079,0.1264107358244243,95000000,24081.660502672195,0.1226911115219945,0.1240183384787863,83274637 -20402.014936208725,0.8041188716888428,4351.695854663849,6458,0,4351.695854663849,0.1263291104749177,95000000,24754.72527098656,0.122776152952662,0.1239687272867987,83274637 -20951.32184123993,0.8290243148803711,4472.008096218109,6636,0,4472.008096218109,0.1262632901624177,95000000,25424.37505888939,0.1219295825715522,0.1239027230424068,83274637 -21501.974306821823,0.8511290550231934,4592.035045385361,6813,0,4592.035045385361,0.1265540494551809,95000000,26095.08243894577,0.122105953348039,0.1241584364994646,83274637 -22042.21883559227,0.8733363151550293,4712.584577083588,6994,0,4712.584577083588,0.1262784921875,95000000,26755.904494524,0.1226234874518225,0.1239216095151291,83274637 -22590.956671714783,0.8951373100280762,4832.942766189575,7169,0,4832.942766189575,0.1263248577302631,95000000,27425.027945756912,0.1214132299939604,0.1239500357873445,83274637 -23129.687811613083,0.917670726776123,4953.628441810608,7343,0,4953.628441810608,0.1261030070518092,95000000,28084.47293281555,0.1239069283196011,0.1237468096242459,83274637 -23674.528027772903,0.9421184062957764,5073.737655878067,7520,0,5073.737655878067,0.1262120737561677,95000000,28749.45256137848,0.1231251848545276,0.1238554156743033,83274637 -24221.43767905236,0.9639968872070312,5194.339241266251,7697,0,5194.339241266251,0.1262973823807565,95000000,29416.991481542587,0.1225887277872307,0.1239002044617888,83274637 -24751.25191283226,0.9881429672241212,5314.863927364349,7875,0,5314.863927364349,0.1260425471217105,95000000,30067.36024856568,0.1228803222963832,0.123725748377521,83274637 -25291.850727558136,1.0113427639007568,5434.98015499115,8052,0,5434.98015499115,0.1260829233038651,95000000,30728.10421323776,0.1231917627858665,0.1237757983355874,83274637 -25834.18929052353,1.036632061004639,5555.511506319046,8234,0,5555.511506319046,0.1261384195106908,95000000,31391.005412340164,0.1203018530336378,0.123735820307683,83274637 -26379.57387113571,1.0609583854675293,5676.200478792191,8416,0,5676.200478792191,0.1260427416426809,95000000,32057.109263181686,0.1235053100910201,0.1237182156968322,83274637 -26921.09217262268,1.082719326019287,5796.7442128658295,8592,0,5796.7442128658295,0.12596203797286185,95000000,32719.198880910873,0.11987802426114022,0.12362843281709719,83274637 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index a1d051f7a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,137 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,9.193443,1.275513,,,,,,,,,,, -1,,,1.2750268968396217,1.27203968947502,83274637.0,1.2727720309210526,95000000.0,20.054532766342163,789.845296382904,20.054532766342163,769.7907042503357,0.0,0.0 -100,0.17406148,0.12947348,,,,,,,,,,, -183,,,0.1306915200614141,0.1286313995871501,83274637.0,0.1314100198190789,95000000.0,140.60763335227966,1530.4645409584043,140.60763335227966,1389.8240342140198,0.0265195369720459,0.0 -200,0.29021636,0.13765678,,,,,,,,,,, -300,0.0057042865,0.12443295,,,,,,,,,,, -366,,,0.1270057580919948,0.1276038683119875,83274637.0,0.1302061311369243,95000000.0,260.8166162967682,2270.1458444595337,260.8166162967682,2009.2675638198853,0.0491020679473876,0.0 -400,0.0069647403,0.12963495,,,,,,,,,,, -500,0.05887551,0.12835485,,,,,,,,,,, -544,,,0.1258156687935006,0.1267161050770326,83274637.0,0.1293979709909539,95000000.0,381.08206129074097,3011.0285456180573,381.08206129074097,2629.859326124192,0.069005012512207,0.0 -600,0.0086504,0.11867523,,,,,,,,,,, -700,0.037642974,0.12500732,,,,,,,,,,, -722,,,0.1254300063464251,0.1272029359363006,83274637.0,0.1297202143708881,95000000.0,501.1722431182861,3732.305637359619,501.1722431182861,3231.0204243659973,0.0892062187194824,0.0 -800,0.011826224,0.12378581,,,,,,,,,,, -900,0.15060851,0.13373075,,,,,,,,,,, -905,,,0.1261651381123928,0.1265232753421894,83274637.0,0.1292632202508223,95000000.0,621.7608590126038,4439.122531175613,621.7608590126038,3817.221437215805,0.110586404800415,0.0 -1000,0.0070499736,0.12068532,,,,,,,,,,, -1083,,,0.1266514414909688,0.1258599507896233,83274637.0,0.1282775658614309,95000000.0,742.3804759979248,5122.996841669083,742.3804759979248,4380.449725151062,0.1312270164489746,0.0 -1100,0.07079528,0.12993388,,,,,,,,,,, -1200,0.046857532,0.124414966,,,,,,,,,,, -1259,,,0.1241250919916157,0.1255488845671694,83274637.0,0.1279592428659539,95000000.0,862.654314994812,5783.400713682175,862.654314994812,4920.549924612045,0.1552424430847168,0.0 -1300,0.033175938,0.1290081,,,,,,,,,,, -1400,0.01609739,0.12736219,,,,,,,,,,, -1444,,,0.1254901969981081,0.1256811114304696,83274637.0,0.1280510447985197,95000000.0,982.6249959468842,6438.310907840729,982.6249959468842,5455.457439184189,0.1811528205871582,0.0 -1500,0.107057385,0.13133015,,,,,,,,,,, -1600,0.04195135,0.11826688,,,,,,,,,,, -1623,,,0.1250005254885123,0.1256984208138439,83274637.0,0.1279304759046052,95000000.0,1103.1413357257843,7091.552002668381,1103.1413357257843,5988.156452178955,0.2012228965759277,0.0 -1700,0.039266452,0.1299193,,,,,,,,,,, -1800,0.025198404,0.13812539,,,,,,,,,,, -1809,,,0.1246529179595926,0.1253543950441038,83274637.0,0.1277026213199013,95000000.0,1223.7536084651947,7725.327343940735,1223.7536084651947,6501.290910720825,0.2238349914550781,0.0 -1900,0.014120885,0.12303955,,,,,,,,,,, -1984,,,0.1249708107455908,0.1252997718687706,83274637.0,0.1277069748046875,95000000.0,1344.3381459712982,8316.276195287704,1344.3381459712982,6971.628367900848,0.2449679374694824,0.0 -2000,0.04692828,0.13050185,,,,,,,,,,, -2100,0.0043330505,0.12000799,,,,,,,,,,, -2170,,,0.125709370668954,0.1256153176448625,83274637.0,0.1278315850637335,95000000.0,1464.9220495224,8959.58207988739,1464.9220495224,7494.321210861206,0.2679464817047119,0.0 -2200,0.014911163,0.13657174,,,,,,,,,,, -2300,0.042424463,0.12691496,,,,,,,,,,, -2349,,,0.1225543501171863,0.1253291580202077,83274637.0,0.1276712290810032,95000000.0,1585.228083372116,9606.885175466536,1585.228083372116,8021.292371273041,0.2880659103393554,0.0 -2400,0.02029007,0.13111672,,,,,,,,,,, -2500,0.03459655,0.13408193,,,,,,,,,,, -2525,,,0.123191749021036,0.1250012444585731,83274637.0,0.1273398111636513,95000000.0,1705.2277812957764,10257.767736196518,1705.2277812957764,8552.149214029312,0.3084101676940918,0.0 -2600,0.03207355,0.12317154,,,,,,,,,,, -2700,0.047921088,0.122743286,,,,,,,,,,, -2702,,,0.1248918822351491,0.1253410620845944,83274637.0,0.1275802810958059,95000000.0,1825.53936123848,10917.745633125303,1825.53936123848,9091.789171218872,0.3289697170257568,0.0 -2800,0.046090562,0.12849593,,,,,,,,,,, -2878,,,0.1217001618922881,0.1248983065180702,83274637.0,0.1273002328330592,95000000.0,1945.523303985596,11552.858093500136,1945.523303985596,9606.8914167881,0.3496608734130859,0.0 -2900,0.009070436,0.12944296,,,,,,,,,,, -3000,0.060671777,0.1311128,,,,,,,,,,, -3053,,,0.1233305929685538,0.1248150544981511,83274637.0,0.1271365337890625,95000000.0,2065.6756069660187,12195.09390592575,2065.6756069660187,10128.94816493988,0.3706555366516113,0.0 -3100,0.009896474,0.122252755,,,,,,,,,,, -3200,0.043682203,0.12375927,,,,,,,,,,, -3232,,,0.1239032166535561,0.1248819490923223,83274637.0,0.1272028106085526,95000000.0,2185.868538618088,12823.375492572784,2185.868538618088,10637.009717226028,0.3918793201446533,0.0 -3300,0.016766196,0.12798166,,,,,,,,,,, -3400,0.012344062,0.12014663,,,,,,,,,,, -3408,,,0.1238242089279792,0.1249167751339725,83274637.0,0.1271035701788651,95000000.0,2306.386968612671,13484.352406978607,2306.386968612671,11177.440727949142,0.4134020805358886,0.0 -3500,0.01310231,0.12605622,,,,,,,,,,, -3584,,,0.1221014177485269,0.1249408105908345,83274637.0,0.1273806009662829,95000000.0,2426.4390614032745,14140.498716831207,2426.4390614032745,11713.505402088163,0.4371469020843506,0.0 -3600,0.009425149,0.12211418,,,,,,,,,,, -3700,0.0047277957,0.12497344,,,,,,,,,,, -3770,,,0.1238716137156171,0.1245102247641083,83274637.0,0.1268082985711348,95000000.0,2546.913526058197,14784.672607421877,2546.913526058197,12237.177653312683,0.4585399627685547,0.0 -3800,0.012949807,0.12357981,,,,,,,,,,, -3900,0.030650258,0.11829205,,,,,,,,,,, -3943,,,0.1250578743997235,0.124649814354405,83274637.0,0.1268686580489309,95000000.0,2666.8965351581573,15455.617283582687,2666.8965351581573,12788.11208820343,0.4802253246307373,0.0 -4000,0.009529898,0.123761095,,,,,,,,,,, -4100,0.022971846,0.12185048,,,,,,,,,,, -4122,,,0.1229165908558376,0.1249230898180862,83274637.0,0.1278794684004934,95000000.0,2786.889313220978,16127.262543678284,2786.889313220978,13339.73760986328,0.5014171600341797,0.0 -4200,0.02356564,0.12624912,,,,,,,,,,, -4300,,,0.1225354496342768,0.1246359960725872,83274637.0,0.1268836971833881,95000000.0,2907.2045364379883,16792.80887746811,2907.2045364379883,13884.941725730896,0.5227134227752686,0.0 -4300,0.00535433,0.12115466,,,,,,,,,,, -4400,0.0071059414,0.12047344,,,,,,,,,,, -4478,,,0.1250219466691871,0.1247609821932487,83274637.0,0.1272580874897204,95000000.0,3027.667422056198,17451.23566007614,3027.667422056198,14422.877438545229,0.5451195240020752,0.0 -4500,0.013148362,0.124086276,,,,,,,,,,, -4600,0.005881113,0.124325275,,,,,,,,,,, -4660,,,0.125849908112355,0.1246348594086899,83274637.0,0.1270226580797697,95000000.0,3147.6467897892,18115.323104143143,3147.6467897892,14966.952553272247,0.5720090866088867,0.0 -4700,0.004637919,0.1323927,,,,,,,,,,, -4800,0.019627135,0.12196121,,,,,,,,,,, -4843,,,0.1248512399271599,0.1245842440914719,83274637.0,0.1269167181332237,95000000.0,3268.2762217521667,18779.21762752533,3268.2762217521667,15510.190085411072,0.5935556888580322,0.0 -4900,0.0058309054,0.1281078,,,,,,,,,,, -5000,0.007906484,0.12993887,,,,,,,,,,, -5025,,,0.1222759853673618,0.1246284942222076,83274637.0,0.1269516108655427,95000000.0,3388.835000038147,19441.40255951881,3388.835000038147,16051.788796186447,0.6148788928985596,0.0 -5100,0.0065259724,0.12941913,,,,,,,,,,, -5200,0.0062909937,0.119605675,,,,,,,,,,, -5205,,,0.1194243619113036,0.1241413238073654,83274637.0,0.1265438617084704,95000000.0,3509.062576532364,20106.009090423584,3509.062576532364,16596.137999773026,0.6388421058654785,0.0 -5300,0.0637239,0.12043012,,,,,,,,,,, -5385,,,0.1242270162504799,0.1243186725521428,83274637.0,0.126651141118421,95000000.0,3629.413244247437,20764.800588607788,3629.413244247437,17134.551802396774,0.6600420475006104,0.0 -5400,0.022803253,0.11806701,,,,,,,,,,, -5500,0.006032241,0.121685565,,,,,,,,,,, -5567,,,0.1234444462746944,0.124154712235632,83274637.0,0.1265954343030427,95000000.0,3749.918417453766,21426.67833852768,3749.918417453766,17675.89226746559,0.6861846446990967,0.0 -5600,0.01333635,0.12859362,,,,,,,,,,, -5700,0.018904848,0.1187994,,,,,,,,,,, -5744,,,0.1225258158344142,0.1241873344085678,83274637.0,0.1266719728824013,95000000.0,3870.579716682434,22086.477658748627,3870.579716682434,18215.002883434296,0.7078261375427246,0.0 -5800,0.020199738,0.12306995,,,,,,,,,,, -5900,0.010616281,0.13025217,,,,,,,,,,, -5920,,,0.1233771085950001,0.1241138011540412,83274637.0,0.1264998841077302,95000000.0,3990.8730747699738,22759.23750019073,3990.8730747699738,18767.438177347183,0.7333683967590332,0.0 -6000,0.00553147,0.12645099,,,,,,,,,,, -6100,,,0.1220550997616172,0.1241251522136813,83274637.0,0.1265815414165296,95000000.0,4110.999460935593,23411.2101802826,4110.999460935593,19299.25735974312,0.7545573711395264,0.0 -6100,0.0073349634,0.113641694,,,,,,,,,,, -6200,0.013990974,0.13040543,,,,,,,,,,, -6280,,,0.1226911115219945,0.1240183384787863,83274637.0,0.1264107358244243,95000000.0,4231.590196371079,24081.660502672195,4231.590196371079,19849.08947825432,0.7763080596923828,0.0 -6300,0.027143957,0.11945921,,,,,,,,,,, -6400,0.013295012,0.11907414,,,,,,,,,,, -6458,,,0.122776152952662,0.1239687272867987,83274637.0,0.1263291104749177,95000000.0,4351.695854663849,24754.72527098656,4351.695854663849,20402.014936208725,0.8041188716888428,0.0 -6500,0.009181279,0.13156824,,,,,,,,,,, -6600,0.016354606,0.13184029,,,,,,,,,,, -6636,,,0.1219295825715522,0.1239027230424068,83274637.0,0.1262632901624177,95000000.0,4472.008096218109,25424.37505888939,4472.008096218109,20951.32184123993,0.8290243148803711,0.0 -6700,0.008869216,0.12238988,,,,,,,,,,, -6800,0.025942046,0.12326771,,,,,,,,,,, -6813,,,0.122105953348039,0.1241584364994646,83274637.0,0.1265540494551809,95000000.0,4592.035045385361,26095.08243894577,4592.035045385361,21501.974306821823,0.8511290550231934,0.0 -6900,0.0057829977,0.12079891,,,,,,,,,,, -6994,,,0.1226234874518225,0.1239216095151291,83274637.0,0.1262784921875,95000000.0,4712.584577083588,26755.904494524,4712.584577083588,22042.21883559227,0.8733363151550293,0.0 -7000,0.0065383995,0.11927109,,,,,,,,,,, -7100,0.010537244,0.12599514,,,,,,,,,,, -7169,,,0.1214132299939604,0.1239500357873445,83274637.0,0.1263248577302631,95000000.0,4832.942766189575,27425.027945756912,4832.942766189575,22590.956671714783,0.8951373100280762,0.0 -7200,0.011179821,0.12362923,,,,,,,,,,, -7300,0.0055970117,0.13018736,,,,,,,,,,, -7343,,,0.1239069283196011,0.1237468096242459,83274637.0,0.1261030070518092,95000000.0,4953.628441810608,28084.47293281555,4953.628441810608,23129.687811613083,0.917670726776123,0.0 -7400,0.0074973335,0.1157335,,,,,,,,,,, -7500,0.0068598557,0.12932041,,,,,,,,,,, -7520,,,0.1231251848545276,0.1238554156743033,83274637.0,0.1262120737561677,95000000.0,5073.737655878067,28749.45256137848,5073.737655878067,23674.528027772903,0.9421184062957764,0.0 -7600,0.008074706,0.11606064,,,,,,,,,,, -7697,,,0.1225887277872307,0.1239002044617888,83274637.0,0.1262973823807565,95000000.0,5194.339241266251,29416.991481542587,5194.339241266251,24221.43767905236,0.9639968872070312,0.0 -7700,0.006675052,0.12449772,,,,,,,,,,, -7800,0.008349893,0.115552515,,,,,,,,,,, -7875,,,0.1228803222963832,0.123725748377521,83274637.0,0.1260425471217105,95000000.0,5314.863927364349,30067.36024856568,5314.863927364349,24751.25191283226,0.9881429672241212,0.0 -7900,0.0058765416,0.11952849,,,,,,,,,,, -8000,0.006465148,0.13160342,,,,,,,,,,, -8052,,,0.1231917627858665,0.1237757983355874,83274637.0,0.1260829233038651,95000000.0,5434.98015499115,30728.10421323776,5434.98015499115,25291.850727558136,1.0113427639007568,0.0 -8100,0.00886124,0.11351249,,,,,,,,,,, -8200,0.012977887,0.13272312,,,,,,,,,,, -8234,,,0.1203018530336378,0.123735820307683,83274637.0,0.1261384195106908,95000000.0,5555.511506319046,31391.005412340164,5555.511506319046,25834.18929052353,1.036632061004639,0.0 -8300,0.010996833,0.119543254,,,,,,,,,,, -8400,0.0061579575,0.11885491,,,,,,,,,,, -8416,,,0.1235053100910201,0.1237182156968322,83274637.0,0.1260427416426809,95000000.0,5676.200478792191,32057.109263181686,5676.200478792191,26379.57387113571,1.0609583854675293,0.0 -8500,0.00540291,0.12067863,,,,,,,,,,, -8592,,,0.1198780242611402,0.1236284328170971,83274637.0,0.1259620379728618,95000000.0,5796.7442128658295,32719.198880910877,5796.7442128658295,26921.09217262268,1.082719326019287,0.0 -8592,,,,,,,,5796.7442128658295,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index d42a893fa..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,37 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -211.87901759147644,0.0,57.167845487594604,1,0,57.167845487594604,0.9542218943076306,3581,0.1806810384887252,269.0472927093506,0.949392182486398,0.1666739497865949,0.9540861749085536,3554,0.1565500236858996 -216.2666301727295,0.0293574333190917,137.38201189041138,321,0,137.38201189041138,0.3166258811151738,3581,0.7127972229736805,353.6898150444031,0.296964304787772,0.7165457861764091,0.3145373789326287,3554,0.6952132362962508 -220.31308269500727,0.0702927112579345,217.40318703651428,541,0,217.40318703651428,0.3186125831209858,3581,0.7074705804070092,437.8061335086823,0.2996918814522879,0.7096298081534249,0.3163410588201674,3554,0.6907452018192529 -224.36898016929624,0.099755048751831,297.6003651618957,769,0,297.6003651618957,0.3051015027772444,3581,0.7222035569673275,522.0966551303864,0.2861111504690988,0.7260971750531878,0.3030438785413442,3554,0.7049438957424733 -228.4245643615723,0.1374483108520507,377.60139751434326,1006,0,377.60139751434326,0.2999985819254398,3581,0.7272784912995671,606.199057340622,0.2805786643709455,0.7319478988647461,0.2980949474843486,3554,0.7102617511738534 -232.47822403907776,0.1603643894195556,457.715913772583,1356,0,457.715913772583,0.2988607134442195,3581,0.7298116633360094,690.4028720855713,0.2789282117571149,0.734635625566755,0.2971919912532533,3554,0.7127662877655458 -236.5397520065308,0.1849915981292724,537.897935628891,1704,0,537.897935628891,0.295369659313041,3581,0.7326778783641091,774.6835222244263,0.2755681616919381,0.7371971266610282,0.2936600238068022,3554,0.7156436986713914 -240.59777998924253,0.2132840156555175,617.8677086830139,2053,0,617.8677086830139,0.295348592724623,3581,0.7331143453513335,858.7522346973419,0.2759322098323277,0.7376223291669574,0.2937862844901871,3554,0.7158771229380627 -244.5958752632141,0.236177921295166,697.8672807216644,2401,0,697.8672807216644,0.2937181137753944,3581,0.7331176178310876,942.7851028442384,0.274444648197719,0.7370893614632743,0.2920735219119302,3554,0.7161523822189786 -248.6550018787384,0.25958251953125,777.9301776885986,2754,0,777.9301776885986,0.2975957295775446,3581,0.7289458879110234,1026.943326473236,0.2777220862252371,0.7339754785810199,0.2959396199155001,3554,0.7120074184677476 -252.7120006084442,0.283416748046875,858.1966280937195,3104,0,858.1966280937195,0.2924171667197535,3581,0.7360860978515079,1111.3028681278229,0.2731503418513706,0.7403870310102191,0.2908985693936409,3554,0.7188271440410453 -256.76760053634644,0.3104734420776367,938.2925367355348,3455,0,938.2925367355348,0.2927997400560248,3581,0.735108171818277,1195.4938111305237,0.2730321543557303,0.7401855332510812,0.2911679896331598,3554,0.717900728602455 -260.8212649822235,0.3337705135345459,1018.3152334690094,3806,0,1018.3152334690094,0.291892683662472,3581,0.7367013240453085,1279.606038093567,0.2719219582421439,0.7418031692504883,0.2903851802656338,3554,0.7194951990714687 -264.8784046173096,0.3595108985900879,1098.3594090938568,4156,0,1098.3594090938568,0.29146777261938,3581,0.7368092477005376,1363.7455296516418,0.2717895167214529,0.7417457444327218,0.2899759665056098,3554,0.7195797621298186 -268.9368917942047,0.3831746578216553,1178.503762960434,4506,0,1178.503762960434,0.3034388785320965,3581,0.7287131327885018,1447.9846580028534,0.2847920826503208,0.7318805285862514,0.3016822827821644,3554,0.7119016287765546 -272.99322748184204,0.4068598747253418,1258.6917498111725,4861,0,1258.6917498111725,0.2912022586109851,3581,0.7375149443242112,1532.2653322219849,0.2719615868159702,0.7416975157601493,0.2896975816201815,3554,0.7205655984014491 -277.0551047325134,0.4309911727905273,1338.8532030582428,5212,0,1338.8532030582428,0.2904809154434865,3581,0.7392886283510193,1616.525263786316,0.2708852631705148,0.7440947805132184,0.2889249047068444,3554,0.7221770364378165 -281.11381006240845,0.4541096687316894,1418.8516051769257,5560,0,1418.8516051769257,0.2906327789570825,3581,0.7365404953007191,1700.6179819107056,0.2704907315117972,0.7420173372541156,0.2891139522588984,3554,0.7193733348427828 -285.165326833725,0.4809415340423584,1498.854224205017,5910,0,1498.854224205017,0.2903547886196244,3581,0.7378017635393396,1784.7112278938291,0.2704597541264125,0.7429211480276925,0.2889580498535981,3554,0.7205333119372538 -289.2194714546204,0.5062820911407471,1578.8222794532776,6260,0,1578.8222794532776,0.2918514708705669,3581,0.7374155427516755,1868.7712697982788,0.2717337608337402,0.7425744192940849,0.2902517753433631,3554,0.7202327730418191 -293.27813386917114,0.5319225788116455,1658.8429119586945,6606,0,1658.8429119586945,0.2898056256108628,3581,0.738683287773143,1952.8885478973389,0.269721473966326,0.7441070420401437,0.2883000929300612,3554,0.7215114544131612 -297.333603143692,0.5562727451324463,1738.8819456100464,6957,0,1738.8819456100464,0.2925878469919366,3581,0.7333682352389347,2037.0198502540588,0.2731671673910958,0.7371713774544852,0.2910327986511677,3554,0.7164769642260481 -301.3985674381256,0.5810074806213379,1818.9697902202608,7309,0,1818.9697902202608,0.2902679656411442,3581,0.7377455177935632,2121.2102065086365,0.270327159336635,0.7429877008710589,0.2888100473278348,3554,0.7205748034784749 -305.453408241272,0.6062729358673096,1899.003197669983,7657,0,1899.003197669983,0.2922854834979754,3581,0.7372101264704343,2205.3362097740173,0.271768365587507,0.7427032334463937,0.290655871355339,3554,0.7201644906047763 -309.5147559642792,0.6307752132415771,1979.119682073593,8010,0,1979.119682073593,0.2895987094430676,3581,0.7385266859815693,2289.551235675812,0.2699317421231951,0.7434444427490234,0.2882308831158554,3554,0.7212056947277715 -313.5734236240387,0.6558754444122314,2059.097370862961,8361,0,2059.097370862961,0.2890893275163188,3581,0.7405470331698548,2373.625335931778,0.269161650112697,0.7458525385175433,0.2876737698719752,3554,0.7233962282815138 -317.63561153411865,0.6815152168273926,2139.133618593216,8711,0,2139.133618593216,0.2887830438643186,3581,0.7406608200179768,2457.7619621753693,0.2688164710998535,0.7459323746817452,0.2873607800794439,3554,0.723511154355128 -321.69448947906494,0.7097055912017822,2219.263828992844,9065,0,2219.263828992844,0.2897788662712057,3581,0.7399505555579796,2541.992125272751,0.2700753211975097,0.7447177342006138,0.2883014668221546,3554,0.7228308716850732 -325.7490336894989,0.7341964244842529,2299.4439175128937,9417,0,2299.4439175128937,0.2896762944839081,3581,0.7402204669610444,2626.263649225235,0.2696090936660766,0.7454489299229213,0.2882844992548009,3554,0.7230337955472707 -329.8038098812103,0.7595300674438477,2379.613367795944,9767,0,2379.613367795944,0.2891634696357512,3581,0.7405805079106744,2710.5259046554565,0.2690895625523158,0.7458797182355609,0.2877059189469612,3554,0.7234043342448649 -333.8633136749268,0.7841873168945312,2459.616132259369,10117,0,2459.616132259369,0.2885292903300405,3581,0.7396588957998813,2794.6254320144653,0.2684458153588431,0.7448987279619489,0.2871320785668437,3554,0.7223642292355444 -337.9186999797821,0.8094024658203125,2539.793482542038,10466,0,2539.793482542038,0.2884751580607721,3581,0.7394358899399609,2878.8959987163544,0.2686894961765834,0.7444917815072196,0.2871836853886026,3554,0.7220280378402856 -341.9752879142761,0.8340375423431396,2619.763811588288,10815,0,2619.763811588288,0.2908301163039304,3581,0.7398273603305641,2962.9599463939667,0.2701017686298915,0.7452566283089774,0.2892038734964124,3554,0.7227582614879361 -346.03573846817017,0.8586001396179199,2699.914743423462,11166,0,2699.914743423462,0.2890474329577981,3581,0.7390223303110165,3047.208276987076,0.2688817977905273,0.744647707257952,0.2875809119401115,3554,0.7218623464538196 -350.09374594688416,0.8860418796539307,2779.982065677643,11518,0,2779.982065677643,0.2880069548375977,3581,0.7399262164898073,3131.3735089302063,0.2679336241313389,0.745260374886649,0.2866494817953802,3554,0.7225378205015476 -354.1460373401642,0.9109201431274414,2860.175224542618,11869,0,2860.175224542618,0.28770135295221305,3581,0.7416352008648073,3215.6563110351562,0.2677281413759504,0.7469045775277274,0.28639936473977384,3554,0.724327658426245 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 10101c6dd..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,157 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,6.0047956,0.8782616,,,,,,,,,,,,,, -1,,,0.1666739497865949,0.949392182486398,0.1565500236858996,0.9540861749085536,3554.0,0.1806810384887252,0.9542218943076306,3581.0,57.167845487594604,269.0472927093506,57.167845487594604,211.87901759147644,0.0,0.0 -100,0.19947258,0.32839608,,,,,,,,,,,,,, -200,0.2079307,0.2990676,,,,,,,,,,,,,, -300,0.5084673,0.31425044,,,,,,,,,,,,,, -321,,,0.7165457861764091,0.296964304787772,0.6952132362962508,0.3145373789326287,3554.0,0.7127972229736805,0.3166258811151738,3581.0,137.38201189041138,353.6898150444031,137.38201189041138,216.2666301727295,0.0293574333190917,0.0 -400,0.2963291,0.2915892,,,,,,,,,,,,,, -500,0.2312562,0.27345398,,,,,,,,,,,,,, -541,,,0.7096298081534249,0.2996918814522879,0.6907452018192529,0.3163410588201674,3554.0,0.7074705804070092,0.3186125831209858,3581.0,217.40318703651428,437.8061335086823,217.40318703651428,220.31308269500727,0.0702927112579345,0.0 -600,0.22478566,0.28373575,,,,,,,,,,,,,, -700,0.46591702,0.29267538,,,,,,,,,,,,,, -769,,,0.7260971750531878,0.2861111504690988,0.7049438957424733,0.3030438785413442,3554.0,0.7222035569673275,0.3051015027772444,3581.0,297.6003651618957,522.0966551303864,297.6003651618957,224.36898016929624,0.099755048751831,0.0 -800,0.15694126,0.25058588,,,,,,,,,,,,,, -900,0.09317388,0.34275672,,,,,,,,,,,,,, -1000,0.36214522,0.23676927,,,,,,,,,,,,,, -1006,,,0.7319478988647461,0.2805786643709455,0.7102617511738534,0.2980949474843486,3554.0,0.7272784912995671,0.2999985819254398,3581.0,377.60139751434326,606.199057340622,377.60139751434326,228.4245643615723,0.1374483108520507,0.0 -1100,0.36572388,0.31446502,,,,,,,,,,,,,, -1200,0.16738689,0.249224,,,,,,,,,,,,,, -1300,0.11910743,0.27250308,,,,,,,,,,,,,, -1356,,,0.734635625566755,0.2789282117571149,0.7127662877655458,0.2971919912532533,3554.0,0.7298116633360094,0.2988607134442195,3581.0,457.715913772583,690.4028720855713,457.715913772583,232.47822403907776,0.1603643894195556,0.0 -1400,0.19823116,0.27645075,,,,,,,,,,,,,, -1500,0.13268507,0.29177886,,,,,,,,,,,,,, -1600,0.28098547,0.2715404,,,,,,,,,,,,,, -1700,0.08273353,0.44929305,,,,,,,,,,,,,, -1704,,,0.7371971266610282,0.2755681616919381,0.7156436986713914,0.2936600238068022,3554.0,0.7326778783641091,0.295369659313041,3581.0,537.897935628891,774.6835222244263,537.897935628891,236.5397520065308,0.1849915981292724,0.0 -1800,0.14656837,0.24938671,,,,,,,,,,,,,, -1900,0.13903536,0.2409766,,,,,,,,,,,,,, -2000,0.07517548,0.26444608,,,,,,,,,,,,,, -2053,,,0.7376223291669574,0.2759322098323277,0.7158771229380627,0.2937862844901871,3554.0,0.7331143453513335,0.295348592724623,3581.0,617.8677086830139,858.7522346973419,617.8677086830139,240.59777998924253,0.2132840156555175,0.0 -2100,0.0950666,0.23008016,,,,,,,,,,,,,, -2200,0.33657303,0.20569095,,,,,,,,,,,,,, -2300,0.20898737,0.2431573,,,,,,,,,,,,,, -2400,0.2818645,0.33721977,,,,,,,,,,,,,, -2401,,,0.7370893614632743,0.274444648197719,0.7161523822189786,0.2920735219119302,3554.0,0.7331176178310876,0.2937181137753944,3581.0,697.8672807216644,942.7851028442384,697.8672807216644,244.5958752632141,0.236177921295166,0.0 -2500,0.11352337,0.22089839,,,,,,,,,,,,,, -2600,0.090647876,0.27092314,,,,,,,,,,,,,, -2700,0.084371865,0.27321714,,,,,,,,,,,,,, -2754,,,0.7339754785810199,0.2777220862252371,0.7120074184677476,0.2959396199155001,3554.0,0.7289458879110234,0.2975957295775446,3581.0,777.9301776885986,1026.943326473236,777.9301776885986,248.6550018787384,0.25958251953125,0.0 -2800,0.18306535,0.26681712,,,,,,,,,,,,,, -2900,0.12302786,0.21863891,,,,,,,,,,,,,, -3000,0.1457906,0.30724972,,,,,,,,,,,,,, -3100,0.12758388,0.22282146,,,,,,,,,,,,,, -3104,,,0.7403870310102191,0.2731503418513706,0.7188271440410453,0.2908985693936409,3554.0,0.7360860978515079,0.2924171667197535,3581.0,858.1966280937195,1111.3028681278229,858.1966280937195,252.7120006084442,0.283416748046875,0.0 -3200,0.07152758,0.29319328,,,,,,,,,,,,,, -3300,0.06558319,0.30747113,,,,,,,,,,,,,, -3400,0.07124475,0.4278422,,,,,,,,,,,,,, -3455,,,0.7401855332510812,0.2730321543557303,0.717900728602455,0.2911679896331598,3554.0,0.735108171818277,0.2927997400560248,3581.0,938.2925367355348,1195.4938111305237,938.2925367355348,256.76760053634644,0.3104734420776367,0.0 -3500,0.027598446,0.29487595,,,,,,,,,,,,,, -3600,0.047831975,0.28148848,,,,,,,,,,,,,, -3700,0.31697708,0.27066872,,,,,,,,,,,,,, -3800,0.11629858,0.2808988,,,,,,,,,,,,,, -3806,,,0.7418031692504883,0.2719219582421439,0.7194951990714687,0.2903851802656338,3554.0,0.7367013240453085,0.291892683662472,3581.0,1018.3152334690094,1279.606038093567,1018.3152334690094,260.8212649822235,0.3337705135345459,0.0 -3900,0.061210386,0.3477282,,,,,,,,,,,,,, -4000,0.19546281,0.19057763,,,,,,,,,,,,,, -4100,0.12984487,0.28707737,,,,,,,,,,,,,, -4156,,,0.7417457444327218,0.2717895167214529,0.7195797621298186,0.2899759665056098,3554.0,0.7368092477005376,0.29146777261938,3581.0,1098.3594090938568,1363.7455296516418,1098.3594090938568,264.8784046173096,0.3595108985900879,0.0 -4200,0.29106784,0.28050238,,,,,,,,,,,,,, -4300,0.13648793,0.3260742,,,,,,,,,,,,,, -4400,0.09143051,0.37124798,,,,,,,,,,,,,, -4500,0.12643336,0.27394107,,,,,,,,,,,,,, -4506,,,0.7318805285862514,0.2847920826503208,0.7119016287765546,0.3016822827821644,3554.0,0.7287131327885018,0.3034388785320965,3581.0,1178.503762960434,1447.9846580028534,1178.503762960434,268.9368917942047,0.3831746578216553,0.0 -4600,0.107572064,0.29190192,,,,,,,,,,,,,, -4700,0.17475241,0.26861697,,,,,,,,,,,,,, -4800,0.07495514,0.29144135,,,,,,,,,,,,,, -4861,,,0.7416975157601493,0.2719615868159702,0.7205655984014491,0.2896975816201815,3554.0,0.7375149443242112,0.2912022586109851,3581.0,1258.6917498111725,1532.2653322219849,1258.6917498111725,272.99322748184204,0.4068598747253418,0.0 -4900,0.054732393,0.39245686,,,,,,,,,,,,,, -5000,0.21915086,0.39035058,,,,,,,,,,,,,, -5100,0.08853672,0.30567136,,,,,,,,,,,,,, -5200,0.08967044,0.23717757,,,,,,,,,,,,,, -5212,,,0.7440947805132184,0.2708852631705148,0.7221770364378165,0.2889249047068444,3554.0,0.7392886283510193,0.2904809154434865,3581.0,1338.8532030582428,1616.525263786316,1338.8532030582428,277.0551047325134,0.4309911727905273,0.0 -5300,0.06308084,0.29200724,,,,,,,,,,,,,, -5400,0.097126566,0.30976415,,,,,,,,,,,,,, -5500,0.15357766,0.26407725,,,,,,,,,,,,,, -5560,,,0.7420173372541156,0.2704907315117972,0.7193733348427828,0.2891139522588984,3554.0,0.7365404953007191,0.2906327789570825,3581.0,1418.8516051769257,1700.6179819107056,1418.8516051769257,281.11381006240845,0.4541096687316894,0.0 -5600,0.07552649,0.2506204,,,,,,,,,,,,,, -5700,0.10056262,0.22994839,,,,,,,,,,,,,, -5800,0.096101,0.3581094,,,,,,,,,,,,,, -5900,0.05669721,0.31557757,,,,,,,,,,,,,, -5910,,,0.7429211480276925,0.2704597541264125,0.7205333119372538,0.2889580498535981,3554.0,0.7378017635393396,0.2903547886196244,3581.0,1498.854224205017,1784.7112278938291,1498.854224205017,285.165326833725,0.4809415340423584,0.0 -6000,0.17290746,0.2915334,,,,,,,,,,,,,, -6100,0.13837765,0.2784336,,,,,,,,,,,,,, -6200,0.053263538,0.2750971,,,,,,,,,,,,,, -6260,,,0.7425744192940849,0.2717337608337402,0.7202327730418191,0.2902517753433631,3554.0,0.7374155427516755,0.2918514708705669,3581.0,1578.8222794532776,1868.7712697982788,1578.8222794532776,289.2194714546204,0.5062820911407471,0.0 -6300,0.105820954,0.20120332,,,,,,,,,,,,,, -6400,0.10165949,0.24975541,,,,,,,,,,,,,, -6500,0.084800266,0.28728598,,,,,,,,,,,,,, -6600,0.08490534,0.2653175,,,,,,,,,,,,,, -6606,,,0.7441070420401437,0.269721473966326,0.7215114544131612,0.2883000929300612,3554.0,0.738683287773143,0.2898056256108628,3581.0,1658.8429119586945,1952.8885478973389,1658.8429119586945,293.27813386917114,0.5319225788116455,0.0 -6700,0.109833606,0.30272698,,,,,,,,,,,,,, -6800,0.13485071,0.27535167,,,,,,,,,,,,,, -6900,0.06421297,0.3410409,,,,,,,,,,,,,, -6957,,,0.7371713774544852,0.2731671673910958,0.7164769642260481,0.2910327986511677,3554.0,0.7333682352389347,0.2925878469919366,3581.0,1738.8819456100464,2037.0198502540588,1738.8819456100464,297.333603143692,0.5562727451324463,0.0 -7000,0.0679769,0.23027629,,,,,,,,,,,,,, -7100,0.091417,0.23971996,,,,,,,,,,,,,, -7200,0.13440073,0.26782358,,,,,,,,,,,,,, -7300,0.05783939,0.26172554,,,,,,,,,,,,,, -7309,,,0.7429877008710589,0.270327159336635,0.7205748034784749,0.2888100473278348,3554.0,0.7377455177935632,0.2902679656411442,3581.0,1818.9697902202608,2121.2102065086365,1818.9697902202608,301.3985674381256,0.5810074806213379,0.0 -7400,0.090588056,0.2936203,,,,,,,,,,,,,, -7500,0.061901204,0.21413556,,,,,,,,,,,,,, -7600,0.1471974,0.22135317,,,,,,,,,,,,,, -7657,,,0.7427032334463937,0.271768365587507,0.7201644906047763,0.290655871355339,3554.0,0.7372101264704343,0.2922854834979754,3581.0,1899.003197669983,2205.3362097740173,1899.003197669983,305.453408241272,0.6062729358673096,0.0 -7700,0.1776275,0.17771378,,,,,,,,,,,,,, -7800,0.049456205,0.25345305,,,,,,,,,,,,,, -7900,0.054863732,0.24318984,,,,,,,,,,,,,, -8000,0.08320521,0.24050775,,,,,,,,,,,,,, -8010,,,0.7434444427490234,0.2699317421231951,0.7212056947277715,0.2882308831158554,3554.0,0.7385266859815693,0.2895987094430676,3581.0,1979.119682073593,2289.551235675812,1979.119682073593,309.5147559642792,0.6307752132415771,0.0 -8100,0.14088908,0.19595897,,,,,,,,,,,,,, -8200,0.04923698,0.3142887,,,,,,,,,,,,,, -8300,0.30023918,0.22071676,,,,,,,,,,,,,, -8361,,,0.7458525385175433,0.269161650112697,0.7233962282815138,0.2876737698719752,3554.0,0.7405470331698548,0.2890893275163188,3581.0,2059.097370862961,2373.625335931778,2059.097370862961,313.5734236240387,0.6558754444122314,0.0 -8400,0.1150041,0.26212922,,,,,,,,,,,,,, -8500,0.15695462,0.28213134,,,,,,,,,,,,,, -8600,0.09194674,0.23015851,,,,,,,,,,,,,, -8700,0.16563678,0.25171876,,,,,,,,,,,,,, -8711,,,0.7459323746817452,0.2688164710998535,0.723511154355128,0.2873607800794439,3554.0,0.7406608200179768,0.2887830438643186,3581.0,2139.133618593216,2457.7619621753693,2139.133618593216,317.63561153411865,0.6815152168273926,0.0 -8800,0.17430481,0.24715403,,,,,,,,,,,,,, -8900,0.12610042,0.2909592,,,,,,,,,,,,,, -9000,0.2175621,0.29648092,,,,,,,,,,,,,, -9065,,,0.7447177342006138,0.2700753211975097,0.7228308716850732,0.2883014668221546,3554.0,0.7399505555579796,0.2897788662712057,3581.0,2219.263828992844,2541.992125272751,2219.263828992844,321.69448947906494,0.7097055912017822,0.0 -9100,0.040089134,0.3405608,,,,,,,,,,,,,, -9200,0.14756909,0.22383142,,,,,,,,,,,,,, -9300,0.16259328,0.3227496,,,,,,,,,,,,,, -9400,0.07292045,0.26158106,,,,,,,,,,,,,, -9417,,,0.7454489299229213,0.2696090936660766,0.7230337955472707,0.2882844992548009,3554.0,0.7402204669610444,0.2896762944839081,3581.0,2299.4439175128937,2626.263649225235,2299.4439175128937,325.7490336894989,0.7341964244842529,0.0 -9500,0.057001036,0.27998748,,,,,,,,,,,,,, -9600,0.05102667,0.2777011,,,,,,,,,,,,,, -9700,0.02529343,0.25962967,,,,,,,,,,,,,, -9767,,,0.7458797182355609,0.2690895625523158,0.7234043342448649,0.2877059189469612,3554.0,0.7405805079106744,0.2891634696357512,3581.0,2379.613367795944,2710.5259046554565,2379.613367795944,329.8038098812103,0.7595300674438477,0.0 -9800,0.12466395,0.21524419,,,,,,,,,,,,,, -9900,0.13484329,0.26114172,,,,,,,,,,,,,, -10000,0.054205984,0.21022293,,,,,,,,,,,,,, -10100,0.13685833,0.27502003,,,,,,,,,,,,,, -10117,,,0.7448987279619489,0.2684458153588431,0.7223642292355444,0.2871320785668437,3554.0,0.7396588957998813,0.2885292903300405,3581.0,2459.616132259369,2794.6254320144653,2459.616132259369,333.8633136749268,0.7841873168945312,0.0 -10200,0.057856604,0.28030607,,,,,,,,,,,,,, -10300,0.08051611,0.25937727,,,,,,,,,,,,,, -10400,0.1089267,0.26597688,,,,,,,,,,,,,, -10466,,,0.7444917815072196,0.2686894961765834,0.7220280378402856,0.2871836853886026,3554.0,0.7394358899399609,0.2884751580607721,3581.0,2539.793482542038,2878.8959987163544,2539.793482542038,337.9186999797821,0.8094024658203125,0.0 -10500,0.06374385,0.37624457,,,,,,,,,,,,,, -10600,0.0631693,0.22876218,,,,,,,,,,,,,, -10700,0.14779998,0.23684269,,,,,,,,,,,,,, -10800,0.08487035,0.33668166,,,,,,,,,,,,,, -10815,,,0.7452566283089774,0.2701017686298915,0.7227582614879361,0.2892038734964124,3554.0,0.7398273603305641,0.2908301163039304,3581.0,2619.763811588288,2962.9599463939667,2619.763811588288,341.9752879142761,0.8340375423431396,0.0 -10900,0.06182772,0.34851742,,,,,,,,,,,,,, -11000,0.08406753,0.3525005,,,,,,,,,,,,,, -11100,0.06992767,0.32123905,,,,,,,,,,,,,, -11166,,,0.744647707257952,0.2688817977905273,0.7218623464538196,0.2875809119401115,3554.0,0.7390223303110165,0.2890474329577981,3581.0,2699.914743423462,3047.208276987076,2699.914743423462,346.03573846817017,0.8586001396179199,0.0 -11200,0.047675464,0.28827724,,,,,,,,,,,,,, -11300,0.09423501,0.31927174,,,,,,,,,,,,,, -11400,0.07137801,0.34779793,,,,,,,,,,,,,, -11500,0.06156904,0.33353502,,,,,,,,,,,,,, -11518,,,0.745260374886649,0.2679336241313389,0.7225378205015476,0.2866494817953802,3554.0,0.7399262164898073,0.2880069548375977,3581.0,2779.982065677643,3131.3735089302063,2779.982065677643,350.09374594688416,0.8860418796539307,0.0 -11600,0.19963348,0.22291714,,,,,,,,,,,,,, -11700,0.071067885,0.22832581,,,,,,,,,,,,,, -11800,0.08529575,0.21123546,,,,,,,,,,,,,, -11869,,,0.7469045775277274,0.2677281413759504,0.724327658426245,0.2863993647397738,3554.0,0.7416352008648073,0.287701352952213,3581.0,2860.175224542618,3215.656311035156,2860.175224542618,354.1460373401642,0.9109201431274414,0.0 -11869,,,,,,,,,,,2860.175224542618,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index a651f361a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,372 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -37.02550530433655,0.0,50.636836767196655,1,0,50.636836767196655,0.0010999999940395,6.912676334381104,10000,87.66242718696594,0.0011559311533346,6.912258625030518,0.0011599999852478,6.912920475006104,50000 -54.821940660476685,0.0259764194488525,560.612095117569,1506,0,560.612095117569,0.1157000064849853,4.921426296234131,10000,615.5135684013367,0.1685267835855484,4.295879364013672,0.1553799957036972,4.445466995239258,50000 -72.51834774017334,0.0519242286682128,1070.6800384521484,3010,0,1070.6800384521484,0.2299000173807144,3.900252103805542,10000,1143.3546075820925,0.3398238122463226,3.0561516284942627,0.3103599846363067,3.23838472366333,50000 -90.5101797580719,0.0783896446228027,1580.6047410964966,4516,0,1580.6047410964966,0.3290000259876251,3.260066509246826,10000,1671.3481891155243,0.4670958220958709,2.33237099647522,0.4367599785327911,2.513633728027344,50000 -108.24045300483704,0.1055688858032226,2090.61950135231,6024,0,2090.61950135231,0.385200023651123,2.8951804637908936,10000,2199.170222520828,0.5397201776504517,1.9583563804626465,0.5008000135421753,2.161590337753296,50000 -126.03376865386964,0.1374855041503906,2600.636000394821,7532,0,2600.636000394821,0.4175000190734863,2.7358531951904297,10000,2727.06293296814,0.5772082209587097,1.7723355293273926,0.5385599732398987,1.980293869972229,50000 -143.8541338443756,0.1674284934997558,3110.8438572883606,9041,0,3110.8438572883606,0.4307000339031219,2.655898094177246,10000,3255.172870874405,0.6067641973495483,1.6266725063323977,0.5556600093841553,1.8939425945281985,50000 -162.1518476009369,0.1943204402923584,3621.028754234314,10549,0,3621.028754234314,0.4445000290870666,2.585268497467041,10000,3783.733434200287,0.6285674571990967,1.5007295608520508,0.5639199614524841,1.8578181266784668,50000 -180.35190606117249,0.2221417427062988,4131.20413517952,12058,0,4131.20413517952,0.4433000087738037,2.580861568450928,10000,4312.186669111252,0.6300023794174194,1.498435378074646,0.5737599730491638,1.8043254613876345,50000 -198.97168898582456,0.2587699890136719,4641.235483169556,13566,0,4641.235483169556,0.4456000328063965,2.5929534435272217,10000,4840.924443721771,0.6170080900192261,1.5616915225982666,0.5643399953842163,1.8516933917999268,50000 -217.94574451446533,0.2880523204803467,5151.217526912689,15075,0,5151.217526912689,0.4571000337600708,2.5280418395996094,10000,5369.960024833679,0.6428770422935486,1.4448336362838743,0.5908399820327759,1.7287929058074951,50000 -236.3728106021881,0.3328497409820556,5661.294228553772,16584,0,5661.294228553772,0.4767000079154968,2.377015352249145,10000,5898.560553789139,0.6567881107330322,1.3828415870666504,0.6065399646759033,1.6479262113571167,50000 -260.39603447914124,0.3701400756835937,6171.443657398224,18092,0,6171.443657398224,0.4919000267982483,2.33859920501709,10000,6432.820313692093,0.6660555005073547,1.3375362157821655,0.6100599765777588,1.6255896091461182,50000 -283.1347830295563,0.4076879024505615,6681.714090824127,19601,0,6681.714090824127,0.482200026512146,2.384953022003174,10000,6965.917953968048,0.6856465339660645,1.2491897344589231,0.6026399731636047,1.6453914642333984,50000 -306.5707335472107,0.4417273998260498,7191.711829662323,21110,0,7191.711829662323,0.4874000251293182,2.3641345500946045,10000,7499.435747146606,0.6740672588348389,1.284901738166809,0.612060010433197,1.6220479011535645,50000 -330.72790360450745,0.4803063869476318,7701.745020151138,22620,0,7701.745020151138,0.4837000370025635,2.382908344268799,10000,8033.715864419937,0.6695631146430969,1.30652117729187,0.6083199977874756,1.6256723403930664,50000 -355.2130854129791,0.5175192356109619,8211.9476480484,24130,0,8211.9476480484,0.4928000271320343,2.347294330596924,10000,8568.492515087128,0.6800462007522583,1.271593689918518,0.6187199950218201,1.5880154371261597,50000 -378.2750778198242,0.5450489521026611,8722.075249195099,25640,0,8722.075249195099,0.4976000189781189,2.3183577060699463,10000,9101.75936460495,0.6746651530265808,1.286739468574524,0.6209200024604797,1.5903011560440063,50000 -401.23756408691406,0.5751643180847168,9232.15855383873,27150,0,9232.15855383873,0.4777000248432159,2.413469314575196,10000,9634.885572433472,0.6598373651504517,1.3535218238830566,0.6033399701118469,1.653844118118286,50000 -423.27252864837646,0.8448653221130371,9742.014395713806,28659,0,9742.014395713806,0.5024999976158142,2.293640851974488,10000,10167.096479654312,0.7111367583274841,1.1349786520004272,0.6277799606323242,1.5449466705322266,50000 -448.03714537620544,0.8730454444885254,10252.25020313263,30169,0,10252.25020313263,0.4919000267982483,2.3065173625946045,10000,10702.17523431778,0.6936184763908386,1.211167335510254,0.6198599934577942,1.569377303123474,50000 -470.6238646507263,0.9068002700805664,10762.18581557274,31678,0,10762.18581557274,0.4880000352859497,2.335855722427368,10000,11234.781381845474,0.6753029227256775,1.2694958448410034,0.6157999634742737,1.6059370040893557,50000 -492.6578459739685,0.9439327716827391,11272.298690795898,33189,0,11272.298690795898,0.5026000142097473,2.295906782150269,10000,11767.017308950424,0.6901506781578064,1.2144831418991089,0.6292600035667419,1.5291776657104492,50000 -514.2654712200165,0.9805140495300292,11782.477632045746,34699,0,11782.477632045746,0.5005000233650208,2.252868890762329,10000,12298.890912532806,0.6905093789100647,1.2288453578948977,0.630079984664917,1.5220123529434204,50000 -536.3608770370483,1.0178706645965576,12292.503982305529,36208,0,12292.503982305529,0.4957000315189361,2.327148199081421,10000,12831.100484848022,0.6882174611091614,1.2332067489624023,0.6226599812507629,1.570199966430664,50000 -555.5504469871521,1.0512502193450928,12802.65355682373,37718,0,12802.65355682373,0.4864000082015991,2.3108022212982178,10000,13360.524811983109,0.6916055083274841,1.2146902084350586,0.6141799688339233,1.6082170009613037,50000 -574.8256969451904,1.0908870697021484,13312.825942516329,39228,0,13312.825942516329,0.5144000053405762,2.233076572418213,10000,13890.063405036926,0.7033043503761292,1.156270980834961,0.6322799921035767,1.516471028327942,50000 -596.7445023059845,1.1273493766784668,13822.809534072876,40737,0,13822.809534072876,0.5001000165939331,2.2545292377471924,10000,14422.052576303482,0.698640763759613,1.1931952238082886,0.6297799944877625,1.5363556146621704,50000 -621.6768667697906,1.1569111347198486,14333.007185459135,42247,0,14333.007185459135,0.515500009059906,2.202585220336914,10000,14957.263036727903,0.6972257494926453,1.1852744817733765,0.6369199752807617,1.5088731050491333,50000 -640.7257559299469,1.1918201446533203,14843.097437620165,43757,0,14843.097437620165,0.5047000050544739,2.261016845703125,10000,15486.488491773604,0.6907684803009033,1.2137516736984253,0.6337599754333496,1.513280153274536,50000 -658.4644737243652,1.2304582595825195,15353.215401887894,45267,0,15353.215401887894,0.5081000328063965,2.2523016929626465,10000,16014.435420036316,0.7050183415412903,1.1576184034347534,0.6333799958229065,1.5275466442108154,50000 -676.3059678077698,1.2712736129760742,15863.195637464523,46777,0,15863.195637464523,0.508400022983551,2.2467596530914307,10000,16542.35011291504,0.7170957922935486,1.097066044807434,0.6341800093650818,1.5051928758621216,50000 -693.9939639568329,1.3073816299438477,16373.403130292892,48287,0,16373.403130292892,0.5080000162124634,2.26758337020874,10000,17070.333611249924,0.6975845098495483,1.179138422012329,0.620199978351593,1.5764381885528564,50000 -713.4746758937836,1.3526594638824463,16883.60203051567,49798,0,16883.60203051567,0.5197000503540039,2.1758978366851807,10000,17600.109800100327,0.7109972834587097,1.1291890144348145,0.6435399651527405,1.4746588468551636,50000 -732.192950963974,1.389472246170044,17393.63419032097,51308,0,17393.63419032097,0.5217000246047974,2.146376132965088,10000,18128.94772911072,0.7145049571990967,1.11211895942688,0.6514599919319153,1.4398891925811768,50000 -750.9723279476166,1.429046869277954,17903.583858013153,52818,0,17903.583858013153,0.5200000405311584,2.168796300888061,10000,18657.767151117325,0.7062539458274841,1.140507698059082,0.6458199620246887,1.4497957229614258,50000 -769.1590249538422,1.4705872535705566,18413.752345323563,54329,0,18413.752345323563,0.5128999948501587,2.193343162536621,10000,19186.21563768387,0.7234932780265808,1.0733474493026731,0.6421399712562561,1.484086513519287,50000 -787.6337478160858,1.508671760559082,18923.7833173275,55839,0,18923.7833173275,0.517300009727478,2.189545154571533,10000,19714.81046271324,0.7308075428009033,1.0367093086242676,0.6415799856185913,1.4657151699066162,50000 -805.2087225914001,1.5456082820892334,19433.82699513436,57349,0,19433.82699513436,0.5111000537872314,2.2348709106445312,10000,20242.517218351364,0.7161989808082581,1.0922472476959229,0.6412999629974365,1.4972940683364868,50000 -823.1132650375366,1.5886414051055908,19943.966212511063,58859,0,19943.966212511063,0.5232000350952148,2.208247423171997,10000,20770.65550327301,0.7106584906578064,1.1283915042877195,0.6424799561500549,1.481645584106445,50000 -840.6038625240326,1.6262004375457764,20453.887604236603,60369,0,20453.887604236603,0.5243000388145447,2.127664804458618,10000,21298.15672469139,0.7179926633834839,1.0955880880355835,0.6552000045776367,1.42160165309906,50000 -858.3883543014526,1.667083978652954,20964.092700004578,61880,0,20964.092700004578,0.5288000106811523,2.1190402507781982,10000,21826.239169597626,0.7214205861091614,1.0676918029785156,0.6552199721336365,1.4101110696792605,50000 -876.7133748531342,1.7043514251708984,21474.092804193497,63391,0,21474.092804193497,0.5279000401496887,2.121713399887085,10000,22354.6533973217,0.7526506781578064,0.9483557343482972,0.6526399850845337,1.41191303730011,50000 -895.2762496471405,1.7477846145629885,21984.03058242798,64901,0,21984.03058242798,0.5368000268936157,2.108168363571167,10000,22883.24925875664,0.7511758208274841,0.947900414466858,0.6609999537467957,1.3837952613830566,50000 -912.3423182964324,1.7881202697753906,22494.092856168747,66412,0,22494.092856168747,0.5273000001907349,2.121701955795288,10000,23410.470444202423,0.7248086333274841,1.053501844406128,0.6454199552536011,1.4561376571655271,50000 -929.560394525528,1.828833818435669,23004.167016267776,67923,0,23004.167016267776,0.5348000526428223,2.116585493087769,10000,23937.85494709015,0.7313655614852905,1.0242691040039062,0.6623199582099915,1.3837313652038574,50000 -946.576143026352,1.868699073791504,23514.20720410347,69433,0,23514.20720410347,0.5420000553131104,2.1048424243927,10000,24465.004062891006,0.7354910373687744,1.011197805404663,0.6612199544906616,1.3969271183013916,50000 -963.7557821273804,1.906718254089356,24024.3210310936,70944,0,24024.3210310936,0.5263000130653381,2.148491382598877,10000,24992.387575387955,0.7195272445678711,1.0805131196975708,0.6513800024986267,1.4209266901016235,50000 -981.2310724258424,1.946107625961304,24534.373774051663,72455,0,24534.373774051663,0.5333000421524048,2.126142978668213,10000,25520.007241487503,0.7664819955825806,0.8874098658561707,0.6557999849319458,1.4078059196472168,50000 -998.6352922916412,1.987243413925171,25044.457715272903,73966,0,25044.457715272903,0.5290000438690186,2.135497093200684,10000,26047.58855366707,0.7506775856018066,0.9532562494277954,0.6530999541282654,1.414706826210022,50000 -1016.3779957294464,2.028724908828736,25554.508782863617,75476,0,25554.508782863617,0.5326000452041626,2.133927345275879,10000,26575.4754896164,0.7411909699440002,0.9859520196914672,0.6597200036048889,1.3939448595046997,50000 -1033.5246062278748,2.0691707134246826,26064.464957475662,76986,0,26064.464957475662,0.5375000238418579,2.1226179599761963,10000,27102.67097759247,0.7378627061843872,0.9981990456581116,0.6623799800872803,1.3914191722869873,50000 -1050.6655519008636,3.017031192779541,26573.513592481613,78494,0,26573.513592481613,0.5348000526428223,2.101693630218506,10000,27629.85961341858,0.7308474183082581,1.024781346321106,0.6592999696731567,1.396071195602417,50000 -1067.5592651367188,3.058248281478882,27083.48101949692,80005,0,27083.48101949692,0.5323000550270081,2.090766191482544,10000,28156.81354093552,0.73148512840271,1.0369644165039062,0.6642000079154968,1.3688702583312988,50000 -1084.7018973827362,3.098496198654175,27593.4535381794,81515,0,27593.4535381794,0.5491999983787537,2.030189275741577,10000,28684.019721269608,0.7939453125,0.7794730067253113,0.6765599846839905,1.3186352252960205,50000 -1102.1596643924713,3.142584085464477,28103.538105487823,83025,0,28103.538105487823,0.5421000123023987,2.0674638748168945,10000,29211.65815377236,0.7580117583274841,0.917476773262024,0.6669999957084656,1.3600958585739136,50000 -1119.2302613258362,3.1828737258911133,28613.781172037125,84537,0,28613.781172037125,0.5378000140190125,2.0711793899536133,10000,29739.06396436692,0.75390625,0.9180837273597716,0.6711199879646301,1.3440886735916138,50000 -1136.0348734855652,3.2278828620910645,29123.81152534485,86048,0,29123.81152534485,0.5491000413894653,2.0667781829833984,10000,30265.995129585262,0.7520527839660645,0.9375916719436646,0.670799970626831,1.3429946899414062,50000 -1153.1339037418363,3.2748494148254395,29633.89668869972,87559,0,29633.89668869972,0.5520000457763672,2.0441277027130127,10000,30793.2783703804,0.7551020383834839,0.9199984073638916,0.6786999702453613,1.3077771663665771,50000 -1170.2514476776123,3.320532321929932,30144.08286547661,89070,0,30144.08286547661,0.5527000427246094,2.053436517715454,10000,31320.67922115326,0.7512954473495483,0.9448497891426086,0.6762999892234802,1.3257722854614258,50000 -1187.541890144348,3.3698208332061768,30654.26819229126,90582,0,30654.26819229126,0.5542000532150269,1.987893223762512,10000,31848.25612831116,0.8028140664100647,0.7457603812217712,0.6789199709892273,1.2876888513565063,50000 -1204.5566387176514,3.4134774208068848,31164.4939289093,92095,0,31164.4939289093,0.5468000173568726,2.067296504974365,10000,32375.592450141907,0.7693120241165161,0.856544017791748,0.6754800081253052,1.3307112455368042,50000 -1222.084743976593,3.460716009140014,31674.687499523163,93607,0,31674.687499523163,0.5529000163078308,2.0210089683532715,10000,32903.412647247314,0.76566481590271,0.876564621925354,0.6783599853515625,1.3167260885238647,50000 -1239.3243083953855,3.507728338241577,32184.654997348785,95118,0,32184.654997348785,0.5593000054359436,1.996338963508606,10000,33430.718074798584,0.7731783986091614,0.8373833894729614,0.6893999576568604,1.2670496702194214,50000 -1256.4522774219513,3.5546374320983887,32694.792423963547,96629,0,32694.792423963547,0.5505000352859497,2.031044006347656,10000,33958.081731557846,0.756257951259613,0.9207841157913208,0.6714000105857849,1.3359733819961548,50000 -1273.7611339092257,3.6036319732666016,33204.81808638573,98140,0,33204.81808638573,0.5555000305175781,2.0382800102233887,10000,34485.51674103737,0.7623963356018066,0.890619695186615,0.6840199828147888,1.299177885055542,50000 -1290.823139667511,3.64790391921997,33714.866651535034,99651,0,33714.866651535034,0.5587000250816345,1.994417428970337,10000,35012.723252773285,0.807059109210968,0.7231523990631104,0.6853199601173401,1.2836941480636597,50000 -1307.958990097046,3.690621852874756,34224.85311436653,101162,0,34224.85311436653,0.5457000136375427,2.0852365493774414,10000,35539.94133090973,0.7729192972183228,0.8433666229248047,0.6765599846839905,1.3290696144104004,50000 -1325.1586797237396,3.736107349395752,34734.81702041626,102673,0,34734.81702041626,0.5649000406265259,1.9514639377594,10000,36067.20191526413,0.7886638641357422,0.7821224331855774,0.6950399875640869,1.240763783454895,50000 -1342.0841455459597,3.785308361053467,35245.00524544716,104184,0,35245.00524544716,0.5725000500679016,1.948253273963928,10000,36594.41708111763,0.7867307066917419,0.7857255339622498,0.6966599822044373,1.2365820407867432,50000 -1359.2032098770142,3.831068515777588,35755.16663765907,105695,0,35755.16663765907,0.5662000179290771,1.96476411819458,10000,37121.79448270798,0.7801339030265808,0.8069247603416443,0.6951000094413757,1.2463161945343018,50000 -1376.4114780426023,3.880370616912842,36265.1814968586,107206,0,36265.1814968586,0.554900050163269,2.0364840030670166,10000,37649.1190032959,0.7763273119926453,0.8324143290519714,0.6882799863815308,1.2644940614700315,50000 -1393.2988758087158,3.924722671508789,36775.30914545059,108717,0,36775.30914545059,0.5611000061035156,1.955988407135009,10000,38176.23004317284,0.8256935477256775,0.6464024782180786,0.6990599632263184,1.226111888885498,50000 -1410.4655022621157,3.9767305850982666,37285.4779791832,110229,0,37285.4779791832,0.5652000308036804,1.968800067901612,10000,38703.67038869858,0.8019770383834839,0.727361261844635,0.6938199996948242,1.250580072402954,50000 -1427.5435523986816,4.02356743812561,37795.405812978745,111739,0,37795.405812978745,0.5719000101089478,1.935737371444702,10000,39230.77417373657,0.80765700340271,0.6981836557388306,0.700939953327179,1.21205735206604,50000 -1444.5057072639463,4.071113586425781,38305.49462866783,113250,0,38305.49462866783,0.5667999982833862,1.932706356048584,10000,39757.9245569706,0.8006218075752258,0.7307600378990173,0.7016599774360657,1.2059155702590942,50000 -1461.475175857544,4.1177613735198975,38815.42671918869,114760,0,38815.42671918869,0.5746999979019165,1.917339324951172,10000,40284.924193143845,0.7948819994926453,0.7492088079452515,0.6979399919509888,1.2191317081451416,50000 -1478.638622522354,4.166247367858887,39325.39078187943,116270,0,39325.39078187943,0.5814000368118286,1.9192442893981927,10000,40812.15231490135,0.7964963316917419,0.7375555038452148,0.7020999789237976,1.2137975692749023,50000 -1496.0060350894928,4.21298360824585,39835.441056251526,117781,0,39835.441056251526,0.5736000537872314,1.9403326511383057,10000,41339.66882777214,0.8391461968421936,0.5893431305885315,0.702299952507019,1.215503215789795,50000 -1513.2711565494535,4.26213812828064,40345.60301208496,119292,0,40345.60301208496,0.5779000520706177,1.933236002922058,10000,41867.19692969322,0.8243981003761292,0.6250141263008118,0.7067999839782715,1.1909725666046145,50000 -1530.355375289917,4.317422151565552,40855.583112478256,120803,0,40855.583112478256,0.5700000524520874,1.950163722038269,10000,42394.367864370346,0.8103874325752258,0.6829017996788025,0.7016199827194214,1.216865062713623,50000 -1547.233092069626,4.367457628250122,41365.70845270157,122314,0,41365.70845270157,0.5781000256538391,1.9371180534362795,10000,42921.472259521484,0.8113440275192261,0.6753456592559814,0.7064399719238281,1.2009263038635254,50000 -1564.4061267375946,4.415964365005493,41875.71085715294,123825,0,41875.71085715294,0.5855000019073486,1.910176515579224,10000,43448.74796462059,0.8163862824440002,0.6607792973518372,0.7098199725151062,1.1880916357040403,50000 -1581.5561335086825,4.467309474945068,42385.60908794403,125336,0,42385.60908794403,0.5805000066757202,1.9390764236450195,10000,43975.89997935295,0.8089126348495483,0.6776189208030701,0.7065799832344055,1.1986366510391235,50000 -1598.7869279384613,4.516381502151489,42895.51242017746,126847,0,42895.51242017746,0.5924000144004822,1.8532426357269287,10000,44503.13484168053,0.8555285334587097,0.5188804268836975,0.7141599655151367,1.1573100090026855,50000 -1615.8899295330048,4.567685127258301,43405.469561100006,128357,0,43405.469561100006,0.5907000303268433,1.86967408657074,10000,45030.29795050621,0.8460618257522583,0.5457704067230225,0.7174400091171265,1.141788363456726,50000 -1633.0848352909088,4.626312971115112,43915.47292876244,129868,0,43915.47292876244,0.5879999995231628,1.8958600759506223,10000,45557.60752797127,0.833426296710968,0.5857810974121094,0.7130599617958069,1.1719478368759155,50000 -1650.1739237308502,4.677285432815552,44425.701722860336,131380,0,44425.701722860336,0.5936000347137451,1.8518024682998653,10000,46085.02875685692,0.8418765664100647,0.564940869808197,0.720579981803894,1.134805679321289,50000 -1667.4603853225708,4.729793787002564,44935.64200139046,132891,0,44935.64200139046,0.5984000563621521,1.8685274124145508,10000,46612.3599793911,0.8443080186843872,0.5527390241622925,0.7223599553108215,1.1370505094528198,50000 -1684.610284090042,4.779802560806274,45445.54392409325,134402,0,45445.54392409325,0.5957000255584717,1.8864117860794067,10000,47139.51445841789,0.8469586968421936,0.5429478883743286,0.722599983215332,1.1435267925262451,50000 -1701.8160552978516,4.829279661178589,45955.56926059723,135913,0,45955.56926059723,0.5994000434875488,1.854590654373169,10000,47666.84624814987,0.879902720451355,0.428302139043808,0.7254799604415894,1.13906729221344,50000 -1718.7475936412811,4.878940105438232,46465.56084442139,137424,0,46465.56084442139,0.5951000452041626,1.876573920249939,10000,48193.87060427666,0.8622847199440002,0.485957384109497,0.7216599583625793,1.1468485593795776,50000 -1735.6987011432648,4.933518886566162,46975.57577776909,138935,0,46975.57577776909,0.6015000343322754,1.8418588638305664,10000,48720.94380235672,0.8673469424247742,0.4645864367485046,0.7269600033760071,1.113786220550537,50000 -1753.1074166297913,4.989470481872559,47485.51925230026,140446,0,47485.51925230026,0.6013000011444092,1.8446753025054927,10000,49248.403594732285,0.8664301633834839,0.4641014337539673,0.7262399792671204,1.1191344261169434,50000 -1770.591741323471,5.043493986129761,47995.51835870743,141957,0,47995.51835870743,0.6045000553131104,1.8624248504638672,10000,49775.99292445183,0.8687818646430969,0.4507312178611755,0.7324999570846558,1.1073507070541382,50000 -1787.475771188736,5.096850633621216,48505.51794219017,143468,0,48505.51794219017,0.6025000214576721,1.8811101913452148,10000,50302.9829390049,0.8729272484779358,0.4402921795845032,0.7320399880409241,1.117584228515625,50000 -1804.427111625672,5.150873422622681,49015.45076060295,144979,0,49015.45076060295,0.6011000275611877,1.8705135583877563,10000,50829.97280144692,0.8941127061843872,0.3673299551010132,0.7300800085067749,1.1244101524353027,50000 -1821.4655013084407,5.202937126159668,49525.504454135895,146490,0,49525.504454135895,0.6080000400543213,1.8569191694259644,10000,51357.16882824898,0.8969627022743225,0.3568805456161499,0.7330399751663208,1.1061513423919678,50000 -1838.759839296341,5.256864070892334,50035.673583984375,148001,0,50035.673583984375,0.6052000522613525,1.8347212076187127,10000,51884.738174676895,0.8917012214660645,0.3716536462306976,0.7354599833488464,1.0957401990890503,50000 -1855.8230559825893,5.3106231689453125,50545.63154053688,149512,0,50545.63154053688,0.6164000034332275,1.8350088596344,10000,52411.86543726921,0.8974609375,0.349711924791336,0.7398799657821655,1.0879695415496826,50000 -1872.6232559680936,5.355967283248901,51055.59986066818,151023,0,51055.59986066818,0.6171000003814697,1.84127140045166,10000,52938.73140335083,0.8940728306770325,0.3650061786174774,0.7384399771690369,1.08740496635437,50000 -1890.3863129615784,5.407841682434082,51565.666075229645,152534,0,51565.666075229645,0.612000048160553,1.854403018951416,10000,53466.66446304321,0.902164340019226,0.3393253982067108,0.7376399636268616,1.093889236450195,50000 -1907.4831624031067,5.463268280029297,52075.750030994415,154045,0,52075.750030994415,0.6180000305175781,1.834379196166992,10000,53993.95304322243,0.924465835094452,0.2674159705638885,0.74263995885849,1.075542688369751,50000 -1924.760721445084,5.524292469024658,52585.85358142853,155555,0,52585.85358142853,0.6170000433921814,1.832737922668457,10000,54521.44651532173,0.9222337007522584,0.2677811682224273,0.7441799640655518,1.0739376544952393,50000 -1941.9596049785607,5.5780956745147705,53095.834916353226,157065,0,53095.834916353226,0.6144000291824341,1.843645453453064,10000,55048.73201870918,0.9177096486091614,0.2826022207736969,0.7434599995613098,1.0783482789993286,50000 -1959.165843486786,5.632391452789307,53605.92672085762,158576,0,53605.92672085762,0.6214000582695007,1.8301516771316528,10000,55576.13701224327,0.9215162396430968,0.2728367447853088,0.7452399730682373,1.074062705039978,50000 -1976.270096063614,5.6887383460998535,54115.92363142967,160087,0,54115.92363142967,0.6155000329017639,1.869329571723938,10000,56103.3468940258,0.9222137928009032,0.2624360918998718,0.7467199563980103,1.0808501243591309,50000 -1993.2992358207705,5.746572971343994,54625.90968847275,161598,0,54625.90968847275,0.6211000084877014,1.862369418144226,10000,56630.47075533867,0.9258609414100648,0.2548287212848663,0.7464199662208557,1.0761977434158323,50000 -2010.447353601456,5.804094076156616,55136.06539058685,163109,0,55136.06539058685,0.6243000030517578,1.836309671401977,10000,57157.884127378464,0.9452128410339355,0.1964490413665771,0.7478199601173401,1.0691890716552734,50000 -2027.654661655426,5.87807035446167,55645.96127533913,164619,0,55645.96127533913,0.6244000196456909,1.8305258750915527,10000,57685.11239886284,0.9431201815605164,0.2001777589321136,0.7487999796867371,1.062160611152649,50000 -2044.7875108718872,5.943117380142212,56156.17899298668,166130,0,56156.17899298668,0.6175000071525574,1.8367518186569207,10000,58212.58046770096,0.9429408311843872,0.2010091245174408,0.7507999539375305,1.063918113708496,50000 -2062.002586364746,5.999414920806885,56666.13904762268,167639,0,56666.13904762268,0.6234000325202942,1.833619952201844,10000,58739.86371612549,0.9417450428009032,0.2010733932256698,0.7512800097465515,1.0629782676696775,50000 -2079.134115457535,6.059413433074951,57176.32164978981,169150,0,57176.32164978981,0.625,1.8395988941192627,10000,59267.289298295975,0.9476044178009032,0.1906523704528808,0.7506200075149536,1.0631998777389526,50000 -2095.982746839524,6.118293046951294,57686.33897137642,170660,0,57686.33897137642,0.625700056552887,1.836114883422852,10000,59794.26624298096,0.9497169852256776,0.18296679854393,0.7519999742507935,1.061332106590271,50000 -2113.2451345920563,6.178622245788574,58196.301375865936,172170,0,58196.301375865936,0.6272000074386597,1.831020832061768,10000,60321.604273319244,0.9553770422935486,0.1642269641160965,0.7532199621200562,1.0589675903320312,50000 -2130.392186880112,6.2380759716033936,58706.51188850403,173681,0,58706.51188850403,0.6290000081062317,1.8309428691864007,10000,60849.07347011566,0.9561144709587096,0.1630240380764007,0.7536199688911438,1.0499507188796997,50000 -2147.473479747772,6.29523515701294,59216.62665820122,175192,0,59216.62665820122,0.6276000142097473,1.8195759057998653,10000,61376.379529476166,0.9563137292861938,0.1613334119319915,0.7529799938201904,1.0494953393936155,50000 -2164.808212280273,6.3645124435424805,59726.75988483429,176703,0,59726.75988483429,0.6301000118255615,1.824698686599732,10000,61903.96989917755,0.956273913383484,0.1597492545843124,0.7539399862289429,1.049542784690857,50000 -2182.069860935211,6.435606241226196,60236.88195681572,178213,0,60236.88195681572,0.6283000111579895,1.824266076087952,10000,62431.477404117584,0.9578084945678712,0.1552063673734665,0.7549200057983398,1.0473157167434692,50000 -2199.227989912033,6.507908821105957,60746.79544234276,179723,0,60746.79544234276,0.6289000511169434,1.826785683631897,10000,62958.6741335392,0.9579480290412904,0.1545311510562896,0.7556999921798706,1.0462963581085205,50000 -2216.197611093521,6.566876411437988,61256.721252441406,181233,0,61256.721252441406,0.6278000473976135,1.820835828781128,10000,63485.68079519272,0.9613759517669678,0.1460684090852737,0.7547599673271179,1.0432047843933103,50000 -2233.5767714977264,6.625716686248779,61766.63048315048,182743,0,61766.63048315048,0.6301000118255615,1.822758674621582,10000,64013.07961964607,0.9588847160339355,0.1517803221940994,0.7552799582481384,1.0448073148727417,50000 -2250.8308987617493,6.685147285461426,62276.70138978958,184253,0,62276.70138978958,0.6307000517845154,1.82177996635437,10000,64540.51701974869,0.9606584906578064,0.1485065221786499,0.7553399801254272,1.0438897609710691,50000 -2267.693324804306,6.7476887702941895,62786.637889146805,185763,0,62786.637889146805,0.6310000419616699,1.8208897113800049,10000,65067.43080019951,0.9605787396430968,0.1489373445510864,0.7553600072860718,1.0433369874954224,50000 -2284.7242891788483,6.810681581497192,63296.7076792717,187274,0,63296.7076792717,0.6310000419616699,1.822888970375061,10000,65594.64647817612,0.9614955186843872,0.1437671929597854,0.755299985408783,1.0444318056106567,50000 -2301.769454240799,6.878018856048584,63806.64422917366,188784,0,63806.64422917366,0.6305000185966492,1.8203967809677124,10000,66121.7480969429,0.9608577489852904,0.1480024456977844,0.7550999522209167,1.0436111688613892,50000 -2318.9982414245605,6.939778089523315,64316.67609715462,190294,0,64316.67609715462,0.6308000087738037,1.8202807903289795,10000,66649.12196302414,0.9602000713348388,0.1484966427087783,0.7550599575042725,1.0441384315490725,50000 -2336.7331914901733,6.99896240234375,64826.72256541252,191804,0,64826.72256541252,0.6303000450134277,1.8232477903366089,10000,67177.01352787018,0.9599011540412904,0.1491530537605285,0.7549200057983398,1.045419692993164,50000 -2353.772256135941,7.07848596572876,65336.91848921776,193315,0,65336.91848921776,0.6300000548362732,1.8218053579330444,10000,67704.38000226021,0.9594427347183228,0.1506325155496597,0.7553199529647827,1.0440016984939575,50000 -2370.7418541908264,7.142187833786011,65846.86537241936,194825,0,65846.86537241936,0.6312000155448914,1.820330023765564,10000,68231.4124391079,0.9610371589660645,0.1465877741575241,0.7549200057983398,1.0440150499343872,50000 -2387.5755808353424,7.208361387252808,66356.94939851761,196335,0,66356.94939851761,0.6327000260353088,1.8242987394332888,10000,68758.44803285599,0.9608577489852904,0.145789235830307,0.7549799680709839,1.0447232723236084,50000 -2404.9300212860107,7.270601749420166,66867.08155965805,197846,0,66867.08155965805,0.6309000253677368,1.822264552116394,10000,69286.04866456985,0.9601203799247742,0.1475706398487091,0.7552399635314941,1.0447214841842651,50000 -2422.416407346725,7.336972713470459,67376.99536323547,199356,0,67376.99536323547,0.6305000185966492,1.8182530403137207,10000,69813.56760764122,0.9608777165412904,0.1482481956481933,0.7546600103378296,1.042603850364685,50000 -2439.470073223114,7.393136739730835,67887.0665371418,200867,0,67887.0665371418,0.6312000155448914,1.821490168571472,10000,70340.8018951416,0.9594228267669678,0.150496631860733,0.7551800012588501,1.0434798002243042,50000 -2456.667491674423,7.457369804382324,68397.0556704998,202377,0,68397.0556704998,0.631600022315979,1.822011947631836,10000,70868.1044383049,0.9612165093421936,0.1434010863304138,0.7554000020027161,1.043828368186951,50000 -2473.7711822986603,7.51897931098938,68906.9993698597,203887,0,68906.9993698597,0.6301000118255615,1.822356104850769,10000,71395.26534843445,0.9611168503761292,0.1459710448980331,0.7547199726104736,1.0434318780899048,50000 -2490.9919068813324,7.580933094024658,69416.98277258873,205397,0,69416.98277258873,0.6310000419616699,1.8229455947875977,10000,71922.58283758163,0.9606584906578064,0.1462922245264053,0.7554199695587158,1.0448895692825315,50000 -2508.140120267868,7.6438798904418945,69926.99735879898,206907,0,69926.99735879898,0.6320000290870667,1.8228908777236936,10000,72449.86073350906,0.9609175324440002,0.1488041579723358,0.7552799582481384,1.044255256652832,50000 -2525.5167202949524,7.710052490234375,70436.92092514038,208417,0,70436.92092514038,0.6306000351905823,1.8218469619750977,10000,72977.27802371979,0.9602000713348388,0.1493316441774368,0.7552399635314941,1.0437886714935305,50000 -2542.445108652115,7.775446891784668,70946.95390248299,209928,0,70946.95390248299,0.631100058555603,1.8214828968048096,10000,73504.35719633102,0.9607381820678712,0.1480738967657089,0.7552399635314941,1.044136643409729,50000 -2559.6543271541595,7.840409278869629,71457.17092013359,211439,0,71457.17092013359,0.6317000389099121,1.823574542999268,10000,74031.89971113205,0.9592434167861938,0.1501782238483429,0.7551800012588501,1.0448315143585205,50000 -2576.7010428905487,7.906655311584473,71967.09794926643,212949,0,71967.09794926643,0.6307000517845154,1.822056531906128,10000,74558.99166703224,0.9595423936843872,0.1497287303209304,0.7553399801254272,1.0440635681152344,50000 -2593.694860935211,7.975393772125244,72477.27730298042,214460,0,72477.27730298042,0.6321000456809998,1.8225747346878047,10000,75086.28588485718,0.9602000713348388,0.1484657675027847,0.7551400065422058,1.0437440872192385,50000 -2610.5095081329346,8.043686389923096,72987.33560729027,215970,0,72987.33560729027,0.6306000351905823,1.823861837387085,10000,75613.27943301201,0.9599210619926452,0.148515373468399,0.7547399997711182,1.0446696281433103,50000 -2627.5460596084595,8.105120658874512,73497.32830357552,217481,0,73497.32830357552,0.6312000155448914,1.8224695920944207,10000,76140.42266964912,0.9605388641357422,0.1486754417419433,0.7548999786376953,1.0447502136230469,50000 -2644.8311898708344,8.175191640853882,74007.20657753944,218991,0,74007.20657753944,0.6301000118255615,1.8205660581588743,10000,76667.70803070068,0.9596021771430968,0.1492205560207367,0.7551199793815613,1.043259620666504,50000 -2661.6404991149902,8.243839025497437,74517.34069633484,220501,0,74517.34069633484,0.6314000487327576,1.822311758995056,10000,77194.7726213932,0.9606186151504515,0.1475935280323028,0.7549399733543396,1.044519543647766,50000 -2678.5212664604187,8.310691595077515,75027.50423502922,222012,0,75027.50423502922,0.6314000487327576,1.8228307962417605,10000,77721.93538951874,0.960339605808258,0.1480839848518371,0.7552199959754944,1.0450679063796997,50000 -2695.558828353882,8.374577283859253,75537.42972803116,223522,0,75537.42972803116,0.6297000050544739,1.822348356246948,10000,78249.01416754723,0.9602798223495485,0.148158848285675,0.7552399635314941,1.04426109790802,50000 -2712.5811915397644,8.44601035118103,76047.38732123375,225032,0,76047.38732123375,0.6327000260353088,1.8209985494613647,10000,78776.11684513092,0.960339605808258,0.1490835100412368,0.7552799582481384,1.0442862510681152,50000 -2729.6561329364777,8.513875246047974,76557.46827101707,226543,0,76557.46827101707,0.6315000057220459,1.82106614112854,10000,79303.39326834679,0.9614157676696776,0.1448323428630828,0.7552799582481384,1.043460488319397,50000 -2746.561534643173,8.578904390335083,77067.64736413956,228054,0,77067.64736413956,0.6314000487327576,1.82277250289917,10000,79830.59428882599,0.9608178734779358,0.1471483409404754,0.7548399567604065,1.044750094413757,50000 -2764.244913339615,8.647281408309937,77577.66748857498,229565,0,77577.66748857498,0.631600022315979,1.81840181350708,10000,80358.41845607758,0.9593231678009032,0.1507213413715362,0.7548999786376953,1.0427016019821167,50000 -2781.0031356811523,8.715314626693726,78087.64923286438,231075,0,78087.64923286438,0.6315000057220459,1.820623159408569,10000,80885.27916908264,0.9600605964660645,0.1495428681373596,0.7551999688148499,1.043567180633545,50000 -2798.164034128189,8.785706281661987,78597.64774560928,232585,0,78597.64774560928,0.6317000389099121,1.8235143423080444,10000,81412.56066179276,0.9606783986091614,0.1485619843006134,0.7554000020027161,1.0447583198547363,50000 -2815.323947429657,8.851818084716797,79107.72985172272,234095,0,79107.72985172272,0.6314000487327576,1.821435809135437,10000,81939.9204120636,0.9609375,0.1463867574930191,0.7549799680709839,1.0434765815734863,50000 -2832.093398332596,8.920358657836914,79617.69214963913,235605,0,79617.69214963913,0.6312000155448914,1.820827841758728,10000,82466.77335429192,0.9598214030265808,0.1489105075597763,0.7550999522209167,1.0432418584823608,50000 -2849.3516008853912,8.984033823013306,80127.71584177017,237115,0,80127.71584177017,0.6315000057220459,1.822920203208924,10000,82994.17069745064,0.9611168503761292,0.1452362686395645,0.7553799748420715,1.044255256652832,50000 -2866.3377170562744,9.051939964294434,80637.91840529442,238564,0,80637.91840529442,0.6309000253677368,1.8205711841583248,10000,83521.47772955894,0.9593430757522584,0.1516030877828598,0.7551599740982056,1.0438953638076782,50000 -2883.4794404506683,9.120002508163452,81147.93720412254,240075,0,81147.93720412254,0.6315000057220459,1.82076632976532,10000,84048.75836634636,0.961535394191742,0.1448425352573394,0.7551599740982056,1.0428738594055176,50000 -2900.567331552505,9.188629388809204,81658.0580971241,241586,0,81658.0580971241,0.6305000185966492,1.82180655002594,10000,84576.08765101433,0.9617147445678712,0.1425634920597076,0.7549200057983398,1.0439531803131104,50000 -2917.415878772736,9.261387825012209,82167.9572262764,243096,0,82167.9572262764,0.629800021648407,1.821830153465271,10000,85102.96196842194,0.96000075340271,0.148179680109024,0.7549200057983398,1.0450206995010376,50000 -2934.426279783249,9.335773944854736,82677.94487500191,244606,0,82677.94487500191,0.6308000087738037,1.8226317167282104,10000,85630.08571457863,0.9604990482330322,0.1490000039339065,0.7550399899482727,1.0443910360336304,50000 -2951.382652282715,9.406443357467651,83188.13959169388,246117,0,83188.13959169388,0.631100058555603,1.820813775062561,10000,86157.36020231247,0.9610570669174194,0.1478585451841354,0.7549200057983398,1.044091820716858,50000 -2968.502408027649,9.47542691230774,83698.27619314194,247628,0,83698.27619314194,0.6317000389099121,1.822364330291748,10000,86684.73764586449,0.959622085094452,0.148349180817604,0.7554799914360046,1.0434871912002563,50000 -2985.615665435791,9.542508840560911,84208.43679046631,249138,0,84208.43679046631,0.6308000087738037,1.82217025756836,10000,87212.13094115257,0.959382951259613,0.1505030393600464,0.7555599808692932,1.044371485710144,50000 -3002.359689235688,9.651327848434448,84718.3010494709,250648,0,84718.3010494709,0.6315000057220459,1.8228862285614007,10000,87738.89959907532,0.9588448405265808,0.151751235127449,0.7554000020027161,1.0438597202301023,50000 -3019.3289988040924,9.725855588912964,85228.23076438904,252158,0,85228.23076438904,0.6318000555038452,1.8210536241531368,10000,88265.92456793785,0.960957407951355,0.1462645232677459,0.7552199959754944,1.043168544769287,50000 -3036.42351937294,9.79535961151123,85738.42370462418,253669,0,85738.42370462418,0.6310000419616699,1.8215969800949097,10000,88793.33317494392,0.959980845451355,0.1491307467222213,0.7551800012588501,1.044544696807861,50000 -3053.542221546173,9.870186805725098,86248.32073330879,255179,0,86248.32073330879,0.6305000185966492,1.822287559509277,10000,89320.47562503815,0.9600805044174194,0.1473139524459839,0.7553399801254272,1.0440517663955688,50000 -3070.470389842987,9.949352264404297,86758.5216627121,256690,0,86758.5216627121,0.6304000020027161,1.8216809034347528,10000,89847.73591923714,0.9594427347183228,0.1510674059391021,0.7545799612998962,1.0446027517318726,50000 -3087.653629779816,10.03275179862976,87268.53696084023,258200,0,87268.53696084023,0.6319000124931335,1.8211063146591189,10000,90375.06952118874,0.960758090019226,0.1466726660728454,0.754859983921051,1.0436691045761108,50000 -3104.648650407791,10.103766679763794,87778.56299948692,259711,0,87778.56299948692,0.6309000253677368,1.822376608848572,10000,90902.21386170389,0.9605388641357422,0.1487253457307815,0.7547599673271179,1.0438485145568848,50000 -3121.9073457717896,10.17822551727295,88288.66622662544,261222,0,88288.66622662544,0.6309000253677368,1.820735216140747,10000,91429.70258498192,0.959741711616516,0.149654671549797,0.7551199793815613,1.0438201427459717,50000 -3139.123267889023,10.250165700912476,88798.86704826355,262733,0,88798.86704826355,0.6309000253677368,1.8212600946426392,10000,91957.24380397797,0.9611766338348388,0.148326426744461,0.754859983921051,1.0432186126708984,50000 -3156.001371860504,10.322932720184326,89308.83620667458,264243,0,89308.83620667458,0.6310000419616699,1.821357488632202,10000,92484.2161114216,0.9611965417861938,0.1453111469745636,0.7555800080299377,1.0443156957626345,50000 -3173.056991100312,10.393508911132812,89818.98686289787,265754,0,89818.98686289787,0.6318000555038452,1.8218551874160769,10000,93011.54561305046,0.9607979655265808,0.1484108865261078,0.7550999522209167,1.0444883108139038,50000 -3190.143748044968,10.469247817993164,90328.98372650146,267265,0,90328.98372650146,0.6308000087738037,1.82218599319458,10000,93538.75735902786,0.9601004123687744,0.14792300760746,0.7553600072860718,1.0435808897018433,50000 -3207.9533960819244,10.54270601272583,90839.05476880074,268775,0,90839.05476880074,0.6313000321388245,1.82094955444336,10000,94066.76383280754,0.9599011540412904,0.1491859406232834,0.754859983921051,1.0436774492263794,50000 -3225.401467323303,10.631753206253052,91349.01457190514,270285,0,91349.01457190514,0.6309000253677368,1.821266889572144,10000,94594.31363272668,0.9602997303009032,0.1504989713430404,0.7555599808692932,1.0429737567901611,50000 -3242.464148521424,10.695482969284058,91859.1024298668,271796,0,91859.1024298668,0.6319000124931335,1.819881677627564,10000,95121.58063220978,0.961535394191742,0.1452290415763855,0.7554000020027161,1.0430622100830078,50000 -3259.586245775223,10.765870571136476,92369.1267938614,273306,0,92369.1267938614,0.6305000185966492,1.8220733404159544,10000,95648.849401474,0.9596819281578064,0.1488217264413833,0.7553399801254272,1.0445902347564695,50000 -3276.891673803329,10.8425931930542,92879.02329158784,274816,0,92879.02329158784,0.6303000450134277,1.8224775791168213,10000,96176.17948126791,0.9607979655265808,0.1465967744588852,0.7552799582481384,1.0439108610153198,50000 -3294.133857011795,10.913537740707396,93389.19371938704,276327,0,93389.19371938704,0.6310000419616699,1.821550726890564,10000,96703.71585345268,0.9604591727256776,0.1478716731071472,0.7551800012588501,1.044476866722107,50000 -3311.18604016304,10.987645626068115,93899.33606386185,277837,0,93899.33606386185,0.6314000487327576,1.8213553428649905,10000,97231.03688788414,0.9602399468421936,0.1496829986572265,0.7553600072860718,1.0444915294647217,50000 -3328.050077676773,11.06198787689209,94409.34986424446,279347,0,94409.34986424446,0.6308000087738037,1.822886824607849,10000,97758.04048991203,0.9615154266357422,0.143006756901741,0.7550999522209167,1.043919563293457,50000 -3345.172520637512,11.137099504470823,94919.4937672615,280858,0,94919.4937672615,0.631100058555603,1.820765733718872,10000,98285.43342876434,0.960758090019226,0.1456627547740936,0.7552599906921387,1.043474197387695,50000 -3362.503289937973,11.210871934890749,95429.6256942749,282369,0,95429.6256942749,0.6309000253677368,1.821438193321228,10000,98813.02177882196,0.9601203799247742,0.1479266285896301,0.7554599642753601,1.0445046424865725,50000 -3379.660735845566,11.28923773765564,95939.5230576992,283879,0,95939.5230576992,0.6307000517845154,1.821829915046692,10000,99340.20707345007,0.9611766338348388,0.1473551839590072,0.7550599575042725,1.043495535850525,50000 -3396.690092563629,11.36655569076538,96449.7061998844,285390,0,96449.7061998844,0.6313000321388245,1.821364164352417,10000,99867.54928565024,0.9603196382522584,0.1493917554616928,0.7549799680709839,1.043468713760376,50000 -3413.703411340713,11.450378656387327,96959.79819345474,286900,0,96959.79819345474,0.6312000155448914,1.820935606956482,10000,100394.79163074492,0.960598647594452,0.1475573778152465,0.7552199959754944,1.0439929962158203,50000 -3430.952383518219,11.525851726531982,97469.97769403458,288411,0,97469.97769403458,0.6308000087738037,1.8218662738800049,10000,100922.34735965727,0.9584462642669678,0.1529013216495514,0.7551999688148499,1.043368577957153,50000 -3448.037952423096,11.608575582504272,97980.17452836037,289922,0,97980.17452836037,0.6312000155448914,1.8234632015228271,10000,101449.76487326622,0.9600605964660645,0.1496689617633819,0.7550999522209167,1.0449682474136353,50000 -3465.177909612656,11.683691501617432,98490.11977744102,291433,0,98490.11977744102,0.6320000290870667,1.8225699663162231,10000,101976.9769806862,0.9602997303009032,0.1464380621910095,0.7553199529647827,1.044140338897705,50000 -3482.211163520813,11.76318645477295,99000.29478907584,292944,0,99000.29478907584,0.6308000087738037,1.821649074554444,10000,102504.31647467612,0.9602598547935486,0.1485782861709594,0.7555999755859375,1.0444979667663574,50000 -3499.3769342899323,11.840510606765749,99510.4796898365,294455,0,99510.4796898365,0.6317000389099121,1.821648359298706,10000,103031.7966837883,0.9598014950752258,0.1485868841409683,0.7549799680709839,1.0446248054504397,50000 -3516.5390000343323,11.916550636291504,100020.47172164916,295965,0,100020.47172164916,0.631100058555603,1.821966528892517,10000,103559.07963442802,0.9590441584587096,0.1495958268642425,0.7550199627876282,1.043899655342102,50000 -3533.82601761818,11.992268562316896,100530.40967082976,297475,0,100530.40967082976,0.6300000548362732,1.8233592510223389,10000,104086.43240571022,0.961734652519226,0.1469607055187225,0.7551800012588501,1.0441290140151978,50000 -3551.3767199516296,12.078696966171265,101040.5618698597,298986,0,101040.5618698597,0.631100058555603,1.822619795799256,10000,104614.27379345894,0.9597018361091614,0.1497321426868438,0.7550599575042725,1.0440407991409302,50000 -3568.2104346752167,12.144399642944336,101550.59058713912,300496,0,101550.59058713912,0.6314000487327576,1.821979522705078,10000,105141.25382041933,0.9599011540412904,0.1493127793073654,0.7553399801254272,1.0435774326324463,50000 -3585.3242876529694,12.231438398361206,102060.7697827816,302007,0,102060.7697827816,0.631100058555603,1.821194648742676,10000,105668.68538475037,0.960758090019226,0.146996721625328,0.7557199597358704,1.043437123298645,50000 -3602.418655157089,12.31143856048584,102570.90319132803,303517,0,102570.90319132803,0.6321000456809998,1.822009444236756,10000,106196.04564929008,0.961316168308258,0.1453872919082641,0.7549600005149841,1.044714331626892,50000 -3619.393835544586,12.389586448669434,103081.1084947586,305028,0,103081.1084947586,0.631600022315979,1.821263194084168,10000,106723.35728740692,0.9611766338348388,0.1474113166332245,0.7551199793815613,1.0436537265777588,50000 -3637.1305088996887,12.47010898590088,103590.98895573616,306538,0,103590.98895573616,0.6306000351905823,1.8231303691864007,10000,107251.10673069954,0.9596021771430968,0.1488519310951233,0.7552799582481384,1.0441625118255615,50000 -3654.240547180176,12.552121877670288,104100.93312764168,308048,0,104100.93312764168,0.6309000253677368,1.822330355644226,10000,107778.29575705528,0.9596819281578064,0.1510353088378906,0.7555599808692932,1.044703722000122,50000 -3671.359024047852,12.631409883499146,104610.91036987305,309558,0,104610.91036987305,0.6309000253677368,1.820195198059082,10000,108305.5235171318,0.9604591727256776,0.1492012590169906,0.7558599710464478,1.0423375368118286,50000 -3688.6541543006897,12.713399410247805,105120.89776802064,311069,0,105120.89776802064,0.6317000389099121,1.8213917016983032,10000,108832.93958759308,0.960180163383484,0.1465490460395813,0.7554799914360046,1.0438048839569092,50000 -3705.587835550308,12.816117763519289,105630.90949559212,312579,0,105630.90949559212,0.6312000155448914,1.821340084075928,10000,109360.03984093666,0.960379421710968,0.147471010684967,0.754859983921051,1.043671727180481,50000 -3722.494250059128,12.90107798576355,106140.7857196331,314089,0,106140.7857196331,0.6302000284194946,1.8212146759033203,10000,109886.95943021774,0.9612165093421936,0.1459231972694397,0.7547199726104736,1.0452815294265747,50000 -3739.307128429413,12.98228931427002,106650.83040618896,315599,0,106650.83040618896,0.6309000253677368,1.820225715637207,10000,110413.95022773744,0.9605787396430968,0.1492339074611663,0.7549999952316284,1.0437467098236084,50000 -3756.320683717728,13.06489896774292,107160.75046777724,317109,0,107160.75046777724,0.6308000087738037,1.821238398551941,10000,110941.01823163033,0.9605189561843872,0.1477705836296081,0.7552399635314941,1.0443214178085327,50000 -3773.578710079193,13.148986101150513,107670.861992836,318619,0,107670.861992836,0.6307000517845154,1.82354736328125,10000,111468.52405381204,0.9615154266357422,0.1435057967901229,0.7549600005149841,1.0443395376205444,50000 -3790.567802667618,13.232507467269896,108180.89000105858,320129,0,108180.89000105858,0.6304000020027161,1.821660280227661,10000,111995.6765794754,0.960359513759613,0.1465110927820205,0.7546199560165405,1.044081687927246,50000 -3807.768210887909,13.314942359924316,108690.9446735382,321640,0,108690.9446735382,0.6313000321388245,1.8210450410842896,10000,112523.06674385072,0.9608178734779358,0.1478320807218551,0.7554199695587158,1.0437196493148804,50000 -3824.800094604492,13.394735336303713,109200.85324692726,323150,0,109200.85324692726,0.6308000087738037,1.8226555585861208,10000,113050.1400270462,0.960359513759613,0.1503463238477707,0.7552399635314941,1.04445481300354,50000 -3842.041315317154,13.47830367088318,109710.8930542469,324660,0,109710.8930542469,0.631100058555603,1.8214871883392327,10000,113577.55630922318,0.9606186151504515,0.1466918587684631,0.7547999620437622,1.0439778566360474,50000 -3859.1264731884,13.55803608894348,110220.93134093285,326171,0,110220.93134093285,0.6309000253677368,1.8210031986236568,10000,114104.812646389,0.9599409699440002,0.1505787074565887,0.755840003490448,1.0444331169128418,50000 -3876.194142580032,13.642706871032717,110731.02626633644,327682,0,110731.02626633644,0.631600022315979,1.820167899131775,10000,114632.11230945589,0.9587252736091614,0.1513209640979766,0.7554799914360046,1.0429681539535522,50000 -3893.1527485847473,14.410885095596312,111240.5034327507,329191,0,111240.5034327507,0.6309000253677368,1.822582721710205,10000,115159.3679318428,0.9612364172935486,0.1466499269008636,0.7554399967193604,1.044033765792847,50000 -3909.984509468079,14.49191951751709,111750.51376104356,330702,0,111750.51376104356,0.6299000382423401,1.8222780227661133,10000,115686.34260249138,0.9597018361091614,0.1489305049180984,0.7549200057983398,1.044543743133545,50000 -3927.011519670488,14.57741928100586,112260.58879756927,332213,0,112260.58879756927,0.6320000290870667,1.8226637840271,10000,116213.58219194412,0.9599609375,0.1484025716781616,0.7553600072860718,1.0446621179580688,50000 -3944.153341770172,14.66352605819702,112770.559705019,333723,0,112770.559705019,0.6308000087738037,1.822697997093201,10000,116740.83248353004,0.959741711616516,0.1488601267337799,0.7554999589920044,1.0437555313110352,50000 -3961.2871708869934,14.749180316925049,113280.64534378052,335234,0,113280.64534378052,0.6306000351905823,1.8214054107666016,10000,117268.18897390366,0.9608976244926452,0.1474035531282425,0.7553199529647827,1.0441795587539673,50000 -3978.467176914215,14.835916996002195,113790.71906757356,336744,0,113790.71906757356,0.6309000253677368,1.82236397266388,10000,117795.58216953278,0.9601004123687744,0.1491908580064773,0.7552799582481384,1.0451172590255735,50000 -3995.216495990753,14.918900728225708,114300.66180181503,338255,0,114300.66180181503,0.6303000450134277,1.821576833724976,10000,118322.4094736576,0.9598014950752258,0.1497568637132644,0.7551400065422058,1.0436608791351318,50000 -4012.285629749298,15.001896142959597,114810.5943918228,339765,0,114810.5943918228,0.6299000382423401,1.82118558883667,10000,118849.54559993744,0.9605189561843872,0.1501489579677581,0.7548399567604065,1.0430506467819214,50000 -4029.317915916443,15.082010746002195,115320.70908522606,341276,0,115320.70908522606,0.6304000020027161,1.8239904642105105,10000,119376.82467389108,0.9614955186843872,0.1441345363855362,0.7552599906921387,1.0443507432937622,50000 -4046.286930322647,15.220804691314695,115830.62089276314,342786,0,115830.62089276314,0.6305000185966492,1.821317672729492,10000,119903.89597392082,0.9616549611091614,0.145839437842369,0.754859983921051,1.0440630912780762,50000 -4063.241968870163,15.308415651321411,116340.69873285294,344296,0,116340.69873285294,0.6310000419616699,1.8192788362503047,10000,120431.06862139702,0.9596021771430968,0.1497314274311065,0.7553199529647827,1.043811321258545,50000 -4081.250019550324,15.40729808807373,116850.72356057169,345806,0,116850.72356057169,0.631100058555603,1.8209062814712524,10000,120959.25270080566,0.9596420526504515,0.1502621620893478,0.7553399801254272,1.0439318418502808,50000 -4098.288117408752,15.495001316070557,117360.82094693184,347316,0,117360.82094693184,0.6319000124931335,1.822487950325012,10000,121486.52737092972,0.9595025181770324,0.1501846611499786,0.7553199529647827,1.044198513031006,50000 -4115.436737775803,15.581097602844238,117870.91880130768,348827,0,117870.91880130768,0.6308000087738037,1.8201829195022583,10000,122013.91163277626,0.9618343114852904,0.1459579169750213,0.7551599740982056,1.0425646305084229,50000 -4132.369649171829,15.654744863510132,118380.99407410622,350338,0,118380.99407410622,0.6315000057220459,1.821236252784729,10000,122541.0457561016,0.9604790806770324,0.1469176858663559,0.7550199627876282,1.0435841083526611,50000 -4149.481686353684,15.740437507629396,118891.1866297722,351849,0,118891.1866297722,0.6301000118255615,1.8209481239318848,10000,123068.48770236968,0.9606584906578064,0.1471874266862869,0.7547799944877625,1.0444236993789673,50000 -4166.511382102966,15.826533794403076,119401.21796488762,353359,0,119401.21796488762,0.6307000517845154,1.8223930597305296,10000,123595.68791532516,0.9605787396430968,0.1489564776420593,0.7554199695587158,1.0429089069366455,50000 -4183.558404684067,15.91768193244934,119911.32834887505,354870,0,119911.32834887505,0.6306000351905823,1.8234483003616333,10000,124122.98861145972,0.9594626426696776,0.149523377418518,0.7547399997711182,1.0451470613479614,50000 -4200.7578365802765,16.00704550743103,120421.18785190582,356380,0,120421.18785190582,0.6318000555038452,1.8212215900421145,10000,124650.18924379347,0.9622528553009032,0.142277330160141,0.755620002746582,1.0428621768951416,50000 -4217.550870895386,16.096714735031128,120931.28691577911,357891,0,120931.28691577911,0.6312000155448914,1.8233301639556885,10000,125177.22383451462,0.9612563848495485,0.1443909555673599,0.7553199529647827,1.044643521308899,50000 -4234.424687862396,16.18539547920227,121441.40004301073,359401,0,121441.40004301073,0.6303000450134277,1.8215309381484983,10000,125704.3516037464,0.9583665132522584,0.1501764953136444,0.7549999952316284,1.0434049367904663,50000 -4251.351005554199,16.275821685791016,121951.36361050606,360911,0,121951.36361050606,0.6305000185966492,1.8208733797073364,10000,126231.38405561449,0.9622129797935486,0.1466683894395828,0.7549999952316284,1.0447399616241455,50000 -4268.519370794296,16.37822437286377,122461.54856848715,362422,0,122461.54856848715,0.6313000321388245,1.8205649852752688,10000,126758.89103007317,0.9601004123687744,0.1505505144596099,0.7554199695587158,1.0432157516479492,50000 -4285.543945074081,16.466397523880005,122971.61852169035,363933,0,122971.61852169035,0.6302000284194946,1.821946144104004,10000,127286.1263434887,0.9612165093421936,0.1459295153617859,0.7549999952316284,1.0441726446151731,50000 -4302.584145069122,16.560025453567505,123481.5219783783,365443,0,123481.5219783783,0.6312000155448914,1.8219724893569944,10000,127813.21567821504,0.957409918308258,0.1546299159526825,0.7553799748420715,1.043689846992493,50000 -4319.705858469009,16.655137300491333,123991.6851592064,366954,0,123991.6851592064,0.6315000057220459,1.8199297189712524,10000,128340.6479599476,0.960379421710968,0.1479264050722122,0.7549799680709839,1.043339490890503,50000 -4336.812728643417,16.74353575706482,124501.6871304512,368464,0,124501.6871304512,0.631100058555603,1.8230565786361688,10000,128867.89737272264,0.960758090019226,0.1465861648321151,0.7549999952316284,1.0438544750213623,50000 -4353.856767416,16.82906484603882,125011.75234508514,369974,0,125011.75234508514,0.6308000087738037,1.8213281631469729,10000,129395.1438281536,0.9587850570678712,0.1517333090305328,0.7553399801254272,1.0435724258422852,50000 -4370.871035575867,16.91614294052124,125521.72247958183,371484,0,125521.72247958183,0.631600022315979,1.821453213691712,10000,129922.26717352869,0.9606186151504515,0.1473256200551986,0.7552199959754944,1.0428035259246826,50000 -4387.892452478409,17.007720232009888,126031.5854511261,372994,0,126031.5854511261,0.6306000351905823,1.819501638412476,10000,130449.29526090622,0.9604392051696776,0.1476263850927353,0.7548799514770508,1.0441089868545532,50000 -4404.918319940567,17.09819459915161,126541.45552277564,374504,0,126541.45552277564,0.6314000487327576,1.8218950033187864,10000,130976.33390402794,0.9590441584587096,0.1508360058069229,0.7549799680709839,1.0448273420333862,50000 -4422.162630319595,17.189462661743164,127051.36449313164,376014,0,127051.36449313164,0.6307000517845154,1.8222171068191528,10000,131503.63059139252,0.9611168503761292,0.1466236859560012,0.7553799748420715,1.0439865589141846,50000 -4439.046007156372,17.278419017791748,127561.3567495346,377524,0,127561.3567495346,0.6303000450134277,1.822312355041504,10000,132030.64701890945,0.9603196382522584,0.1494261622428894,0.7545799612998962,1.0441468954086304,50000 -4455.830331802368,18.18609476089477,128070.5689780712,379032,0,128070.5689780712,0.6310000419616699,1.8221724033355715,10000,132557.6023557186,0.9614157676696776,0.1464610993862152,0.7551199793815613,1.0441802740097046,50000 -4472.877988100052,18.27739262580872,128580.7130572796,380543,0,128580.7130572796,0.6302000284194946,1.8225467205047607,10000,133084.9375550747,0.961355984210968,0.1447767466306686,0.7554199695587158,1.0452173948287964,50000 -4489.933167695999,18.371665239334103,129090.8795325756,382054,0,129090.8795325756,0.6312000155448914,1.822293996810913,10000,133612.30461883545,0.9597616195678712,0.1493446826934814,0.7551999688148499,1.0439260005950928,50000 -4507.924045085907,18.46406841278076,129600.94166183472,383565,0,129600.94166183472,0.631100058555603,1.823047757148743,10000,134140.50261998177,0.959980845451355,0.1489178836345672,0.7549799680709839,1.0436407327651978,50000 -4524.816039323807,18.555992364883423,130110.80349063872,385074,0,130110.80349063872,0.6301000118255615,1.822293400764465,10000,134667.40004825592,0.9599609375,0.1488925218582153,0.7550999522209167,1.0445308685302734,50000 -4541.817780017853,18.64685320854187,130620.97541832924,386585,0,130620.97541832924,0.6305000185966492,1.8213427066802976,10000,135194.71618199348,0.9601402878761292,0.1486131697893142,0.7552199959754944,1.044558882713318,50000 -4559.130956888199,18.741585731506348,131131.05405449867,388095,0,131131.05405449867,0.6306000351905823,1.821485996246338,10000,135722.25498723984,0.9613759517669678,0.1466667354106903,0.755620002746582,1.0440008640289309,50000 -4576.21217250824,18.844163179397583,131640.9479010105,389605,0,131640.9479010105,0.6303000450134277,1.8217018842697144,10000,136249.3849053383,0.9606186151504515,0.1472872942686081,0.755079984664917,1.043853998184204,50000 -4593.165332555771,18.94241428375244,132151.01629567146,391116,0,132151.01629567146,0.6303000450134277,1.8227473497390747,10000,136776.55661320686,0.9608378410339355,0.1467545479536056,0.7552199959754944,1.0440353155136108,50000 -4609.966013908386,19.038569688797,132660.9876010418,392626,0,132660.9876010418,0.6300000548362732,1.821921944618225,10000,137303.47638607025,0.9600406289100648,0.1507564932107925,0.7551599740982056,1.0441441535949707,50000 -4626.819089174271,19.13323402404785,133171.0541651249,394136,0,133171.0541651249,0.6296000480651855,1.8231137990951536,10000,137830.54284071922,0.9599609375,0.1487556546926498,0.7548999786376953,1.0435575246810913,50000 -4643.953272104263,19.22391462326049,133681.10687541962,395646,0,133681.10687541962,0.631100058555603,1.82287073135376,10000,138357.87222909927,0.961933970451355,0.1406219154596328,0.7546799778938293,1.044531226158142,50000 -4660.833715677261,19.33142256736756,134191.01133322716,397156,0,134191.01133322716,0.6309000253677368,1.8224536180496216,10000,138884.81725287437,0.9604591727256776,0.1473497450351715,0.7556999921798706,1.0434935092926023,50000 -4677.92427277565,19.43360996246338,134701.0166823864,398665,0,134701.0166823864,0.6309000253677368,1.8216071128845213,10000,139412.06704068184,0.9610171914100648,0.1485028862953186,0.7549200057983398,1.044009804725647,50000 -4694.866877555847,19.5312135219574,135211.10509705544,400176,0,135211.10509705544,0.6319000124931335,1.8218629360198968,10000,139939.248128891,0.9601203799247742,0.1494250744581222,0.755079984664917,1.04414701461792,50000 -4711.882912874222,19.63275671005249,135721.27109241486,401687,0,135721.27109241486,0.6308000087738037,1.8203022480010984,10000,140466.58408665657,0.961316168308258,0.1458020508289337,0.7552199959754944,1.0438878536224363,50000 -4728.720160961151,19.728036642074585,136231.34775304794,403197,0,136231.34775304794,0.6313000321388245,1.821870803833008,10000,140993.64425182345,0.9598413109779358,0.1497427225112915,0.755079984664917,1.0438313484191897,50000 -4745.791574001312,19.823182582855225,136741.38403439522,404708,0,136741.38403439522,0.6308000087738037,1.822520732879639,10000,141520.8993494511,0.959004282951355,0.1519351452589035,0.7553199529647827,1.0440607070922852,50000 -4762.691604375839,19.920071363449097,137251.34723711014,406218,0,137251.34723711014,0.6312000155448914,1.82229483127594,10000,142047.91184687614,0.9599609375,0.1494533419609069,0.7551999688148499,1.0442662239074707,50000 -4779.641663074493,20.02484154701233,137761.51915931702,407729,0,137761.51915931702,0.6303000450134277,1.8232182264328003,10000,142575.1911354065,0.960339605808258,0.1472578048706054,0.7548799514770508,1.044488787651062,50000 -4796.389009952545,20.122021436691284,138271.42797732353,409239,0,138271.42797732353,0.6307000517845154,1.8203986883163448,10000,143101.9966094494,0.9592832922935486,0.1489840149879455,0.7554599642753601,1.0428025722503662,50000 -4813.636505126953,20.22056555747986,138781.3643064499,410749,0,138781.3643064499,0.631100058555603,1.8239178657531736,10000,143629.33147835732,0.959980845451355,0.1499495357275009,0.7550399899482727,1.0459305047988892,50000 -4830.401615142822,20.31696081161499,139291.53511548042,412260,0,139291.53511548042,0.6308000087738037,1.822560429573059,10000,144156.41612815857,0.9607780575752258,0.146344318985939,0.7551800012588501,1.043354630470276,50000 -4847.201719760895,20.461642026901245,139801.42785167694,413770,0,139801.42785167694,0.6309000253677368,1.821332693099976,10000,144683.30664777756,0.9599409699440002,0.1482083797454834,0.7553399801254272,1.0445153713226318,50000 -4864.29502248764,20.562307596206665,140311.3820786476,415280,0,140311.3820786476,0.6317000389099121,1.820890545845032,10000,145210.5059850216,0.9595025181770324,0.1520635187625885,0.754859983921051,1.0431699752807615,50000 -4881.290358066559,20.66101765632629,140821.24357557297,416790,0,140821.24357557297,0.6305000185966492,1.8210554122924805,10000,145737.5135421753,0.961355984210968,0.1469427496194839,0.7549200057983398,1.0446892976760864,50000 -4898.1636373996735,20.759366750717163,141331.1424012184,418300,0,141331.1424012184,0.6302000284194946,1.82075834274292,10000,146264.43515825272,0.9614756107330322,0.1455113440752029,0.7551599740982056,1.0434072017669678,50000 -4915.135595321655,20.857829093933105,141841.25842308998,419810,0,141841.25842308998,0.631100058555603,1.822310447692871,10000,146791.6738820076,0.9610769748687744,0.1462026089429855,0.7550999522209167,1.0435361862182615,50000 -4932.147796154022,20.957899570465088,142351.206646204,421320,0,142351.206646204,0.631600022315979,1.8217707872390747,10000,147318.78655314443,0.9599409699440002,0.1498060375452041,0.7549999952316284,1.0440946817398071,50000 -4949.913951873779,21.057546615600582,142861.21098184586,422830,0,142861.21098184586,0.6304000020027161,1.821520566940308,10000,147846.70860219002,0.9587252736091614,0.1514067202806472,0.7552199959754944,1.0427519083023071,50000 -4966.923537492752,21.153748273849487,143371.2946677208,424340,0,143371.2946677208,0.6302000284194946,1.822304129600525,10000,148373.94969177246,0.960160195827484,0.1487235575914383,0.7555199861526489,1.043978929519653,50000 -4983.973979711533,21.25181031227112,143881.41513490677,425851,0,143881.41513490677,0.6309000253677368,1.820621132850647,10000,148901.27196621895,0.9612165093421936,0.1480762660503387,0.7553399801254272,1.0441192388534546,50000 -5000.952465295792,21.35924863815308,144391.50057291985,427362,0,144391.50057291985,0.631100058555603,1.8233418464660645,10000,149428.4954817295,0.9609375,0.1451647132635116,0.7549399733543396,1.0440856218338013,50000 -5017.714422941208,21.466845512390137,144901.63012313843,428872,0,144901.63012313843,0.6304000020027161,1.8220691680908203,10000,149955.54619264603,0.9600605964660645,0.1478695422410965,0.7552399635314941,1.0440679788589478,50000 -5034.476024627686,21.568559885025024,145411.6282696724,430383,0,145411.6282696724,0.6314000487327576,1.8212717771530151,10000,150482.45879983902,0.9606983065605164,0.1471541225910186,0.7550599575042725,1.0442098379135132,50000 -5051.591791629791,21.6670138835907,145921.7688229084,431894,0,145921.7688229084,0.6308000087738037,1.8222107887268064,10000,151009.8665678501,0.9595025181770324,0.1516856253147125,0.7554199695587158,1.043545842170715,50000 -5068.386824846268,21.767746925354004,146431.67700624466,433404,0,146431.67700624466,0.6315000057220459,1.8215786218643188,10000,151536.72223758698,0.962312638759613,0.1417900472879409,0.755299985408783,1.044046401977539,50000 -5085.313814640045,21.91230607032776,146941.73819971085,434914,0,146941.73819971085,0.631100058555603,1.823761820793152,10000,152063.90723013878,0.9612762928009032,0.1439260393381118,0.7549200057983398,1.0445345640182495,50000 -5102.110006809235,22.009751081466675,147451.59631967545,436424,0,147451.59631967545,0.6307000517845154,1.8222869634628296,10000,152590.7104575634,0.959203600883484,0.1486395001411438,0.7547999620437622,1.045411467552185,50000 -5118.958813905716,22.115402460098267,147961.5187318325,437934,0,147961.5187318325,0.6306000351905823,1.8189860582351685,10000,153117.63931131363,0.9607381820678712,0.1503721624612808,0.7549799680709839,1.042616367340088,50000 -5136.166883468628,22.215853929519653,148471.39300227165,439444,0,148471.39300227165,0.6320000290870667,1.8204666376113887,10000,153644.87471556664,0.9608577489852904,0.1458749920129776,0.7550199627876282,1.043423771858215,50000 -5153.135057687759,22.316500425338745,148981.3084001541,440954,0,148981.3084001541,0.6302000284194946,1.821756720542908,10000,154171.9116613865,0.961575210094452,0.1486123651266098,0.755299985408783,1.0435887575149536,50000 -5169.926432609558,22.424886465072632,149491.16506505013,442464,0,149491.16506505013,0.631100058555603,1.8232500553131104,10000,154698.7199485302,0.9580875039100648,0.1526551395654678,0.7554399967193604,1.0435420274734497,50000 -5186.662141799927,22.53334045410156,150001.00571346283,443974,0,150001.00571346283,0.6308000087738037,1.8222272396087649,10000,155225.45688009262,0.959582269191742,0.1498272866010666,0.7555999755859375,1.043811559677124,50000 -5203.670311450958,22.639395475387573,150511.18306398392,445485,0,150511.18306398392,0.6308000087738037,1.8205363750457764,10000,155752.8000433445,0.9597018361091614,0.1487534046173095,0.7551199793815613,1.0437120199203491,50000 -5220.766705274582,22.740891456604004,151021.32336831093,446996,0,151021.32336831093,0.6314000487327576,1.819707155227661,10000,156280.19078469276,0.9610769748687744,0.1466841101646423,0.7549399733543396,1.0435166358947754,50000 -5237.740381240845,22.843069553375244,151531.32101392746,448506,0,151531.32101392746,0.6306000351905823,1.8209294080734253,10000,156807.316532135,0.9595623016357422,0.1505237370729446,0.7554399967193604,1.044321060180664,50000 -5254.520807504654,22.94432306289673,152041.26451206207,450016,0,152041.26451206207,0.6318000555038452,1.8213356733322144,10000,157334.1936788559,0.959741711616516,0.1485694497823715,0.7551199793815613,1.0434327125549316,50000 -5271.759593009949,23.047905683517456,152551.20199108124,451526,0,152551.20199108124,0.6309000253677368,1.8225295543670648,10000,157861.52579021454,0.9610969424247742,0.147056832909584,0.7548999786376953,1.0432263612747192,50000 -5288.670228481293,23.14963865280152,153061.21781492233,453036,0,153061.21781492233,0.6308000087738037,1.8207491636276243,10000,158388.60536050797,0.9601004123687744,0.1490989327430725,0.7547999620437622,1.044902205467224,50000 -5305.7620232105255,23.25257182121277,153571.1104207039,454546,0,153571.1104207039,0.6305000185966492,1.8212093114852903,10000,158915.745316267,0.9602199792861938,0.1482055038213729,0.7547999620437622,1.0442836284637451,50000 -5322.696059703827,23.35454678535461,154081.2016234398,456056,0,154081.2016234398,0.6319000124931335,1.8223989009857176,10000,159442.92390727997,0.960379421710968,0.1483684480190277,0.7552599906921387,1.0436128377914429,50000 -5339.510857105255,23.473296642303467,154591.28657960892,457567,0,154591.28657960892,0.631100058555603,1.82321572303772,10000,159969.99428796768,0.9609375,0.1461519598960876,0.7552799582481384,1.0445982217788696,50000 -5356.392268419266,23.5770070552826,155101.33123731613,459078,0,155101.33123731613,0.6315000057220459,1.8196227550506592,10000,160497.07613682747,0.9615553021430968,0.1467585861682891,0.7550399899482727,1.0421971082687378,50000 -5373.915447711945,23.67885732650757,155611.23100996015,460588,0,155611.23100996015,0.6309000253677368,1.822048306465149,10000,161024.65453124046,0.9595623016357422,0.148551806807518,0.755079984664917,1.0454641580581665,50000 -5390.780106306076,23.7828586101532,156121.33530831337,462099,0,156121.33530831337,0.6312000155448914,1.821954727172852,10000,161551.77962732315,0.959363043308258,0.1515534073114395,0.7554000020027161,1.0442078113555908,50000 -5407.651072978973,23.887704133987427,156631.41730761528,463609,0,156631.41730761528,0.6315000057220459,1.8227782249450684,10000,162078.88946652412,0.9598413109779358,0.1484057307243347,0.7550399899482727,1.043148398399353,50000 -5424.55797290802,23.98750042915344,157141.29112553596,465119,0,157141.29112553596,0.6305000185966492,1.8231332302093504,10000,162605.82205343246,0.9618741869926452,0.1466239541769027,0.755079984664917,1.044501543045044,50000 -5441.46524143219,24.0924711227417,157651.32108592987,466629,0,157651.32108592987,0.6304000020027161,1.822500228881836,10000,163132.91656827927,0.9601004123687744,0.1466823518276214,0.7554199695587158,1.043945550918579,50000 -5458.459059715271,24.19330763816833,158161.37888002396,468140,0,158161.37888002396,0.6304000020027161,1.8211979866027832,10000,163660.12048506737,0.9606584906578064,0.1466185003519058,0.7549399733543396,1.0439494848251345,50000 -5475.40735244751,24.295584201812744,158671.50946760178,469651,0,158671.50946760178,0.6307000517845154,1.822662591934204,10000,164187.35395240784,0.960718274116516,0.1477012634277343,0.7555399537086487,1.0445095300674438,50000 -5492.339055299759,24.399834871292114,159181.595911026,471162,0,159181.595911026,0.6308000087738037,1.821772575378418,10000,164714.5281674862,0.9594826102256776,0.1514231264591217,0.7551800012588501,1.043542504310608,50000 -5509.320637464523,24.53001499176025,159691.7286350727,472673,0,159691.7286350727,0.6302000284194946,1.821508288383484,10000,165241.82578277588,0.9615553021430968,0.1431863456964492,0.755299985408783,1.0438499450683594,50000 -5526.427300453186,24.63716220855713,160201.8107573986,474183,0,160201.8107573986,0.6310000419616699,1.823304533958435,10000,165769.17288303375,0.9614756107330322,0.1437897384166717,0.754859983921051,1.045137643814087,50000 -5543.228472948074,24.7429301738739,160711.80350279808,475693,0,160711.80350279808,0.6305000185966492,1.8224012851715088,10000,166296.12507390976,0.9602798223495485,0.1477307230234146,0.7551999688148499,1.0440150499343872,50000 -5560.047877073288,24.84749460220337,161221.76031327248,477203,0,161221.76031327248,0.6299000382423401,1.8201723098754885,10000,166823.05702018738,0.9600406289100648,0.1513166725635528,0.7547199726104736,1.0432075262069702,50000 -5576.959014892578,24.95271611213684,161731.6183450222,478712,0,161731.6183450222,0.6303000450134277,1.824659824371338,10000,167349.98368883133,0.9610171914100648,0.1452740877866745,0.7547000050544739,1.045116901397705,50000 -5593.939316511154,25.06274676322937,162241.5458357334,480222,0,162241.5458357334,0.6310000419616699,1.821131706237793,10000,167877.05255436897,0.9602598547935486,0.1506530493497848,0.7551400065422058,1.043482542037964,50000 -5610.818668365479,25.167808771133423,162751.5722372532,481732,0,162751.5722372532,0.6317000389099121,1.822176098823548,10000,168404.1153049469,0.9599011540412904,0.1497306227684021,0.754859983921051,1.044361591339111,50000 -5627.59955739975,25.27507972717285,163261.70915150642,483243,0,163261.70915150642,0.6308000087738037,1.8228379487991333,10000,168931.19228696823,0.9590840339660645,0.1498535722494125,0.7548799514770508,1.0456264019012451,50000 -5644.491109132767,25.381606101989743,163771.55219697952,484753,0,163771.55219697952,0.6305000185966492,1.8222293853759768,10000,169458.08463525772,0.9606983065605164,0.1470476686954498,0.7554000020027161,1.0443992614746094,50000 -5661.443585395813,25.490796089172363,164281.594363451,486264,0,164281.594363451,0.6307000517845154,1.8217581510543823,10000,169985.24022579193,0.9604591727256776,0.147618219256401,0.7550399899482727,1.044253945350647,50000 -5678.37740445137,25.59970498085022,164791.71216368675,487775,0,164791.71216368675,0.6321000456809998,1.819995641708374,10000,170512.45286893845,0.9592434167861938,0.1517109423875808,0.7553399801254272,1.0429028272628784,50000 -5695.487970590591,25.70747256278992,165301.6225683689,489286,0,165301.6225683689,0.6303000450134277,1.820755958557129,10000,171039.63341140747,0.9607381820678712,0.1463460773229599,0.7548399567604065,1.0446679592132568,50000 -5712.22460770607,25.81318640708924,165811.5897936821,490796,0,165811.5897936821,0.6317000389099121,1.8227946758270264,10000,171566.49512457848,0.9604392051696776,0.1487823575735092,0.755299985408783,1.0443212985992432,50000 -5729.101854324341,25.92009258270264,166321.67253041267,492307,0,166321.67253041267,0.6322000026702881,1.823703646659851,10000,172093.6145040989,0.959203600883484,0.1508540958166122,0.7555199861526489,1.044857144355774,50000 -5746.009928226471,26.02876138687133,166831.68437623978,493817,0,166831.68437623978,0.631100058555603,1.8204790353775024,10000,172620.69486761093,0.9607381820678712,0.148208349943161,0.7550999522209167,1.0426260232925415,50000 -5763.065656900406,26.13835978507996,167341.80224132538,495327,0,167341.80224132538,0.6322000026702881,1.82174289226532,10000,173148.03052449226,0.961535394191742,0.1453877687454223,0.7549999952316284,1.0442039966583252,50000 -5780.004978179932,26.24807047843933,167851.83716320992,496837,0,167851.83716320992,0.6304000020027161,1.8195481300354004,10000,173675.16613030434,0.960339605808258,0.147151231765747,0.7551800012588501,1.0421451330184937,50000 -5796.880818128586,26.336756229400635,168361.71323132515,498347,0,168361.71323132515,0.6319000124931335,1.819997906684876,10000,174202.0587952137,0.9596021771430968,0.1485576331615448,0.7550199627876282,1.044399380683899,50000 -5814.488134860992,26.44757008552552,168871.77570962906,499857,0,168871.77570962906,0.631100058555603,1.821232557296753,10000,174729.8920071125,0.9595423936843872,0.151010975241661,0.7550399899482727,1.043644905090332,50000 -5831.401889562607,26.55976819992065,169381.6427283287,501366,0,169381.6427283287,0.6312000155448914,1.822257161140442,10000,175256.8369383812,0.9618343114852904,0.1473004668951034,0.7550599575042725,1.0455933809280396,50000 -5848.470964431763,26.6716742515564,169891.77300667763,502877,0,169891.77300667763,0.6317000389099121,1.8203349113464355,10000,175784.2019557953,0.9600406289100648,0.1488345712423324,0.7554799914360046,1.043159008026123,50000 -5865.335414886475,26.78855919837952,170401.88774085045,504388,0,170401.88774085045,0.6304000020027161,1.8220136165618896,10000,176311.35015702248,0.961136758327484,0.1460768282413482,0.7550199627876282,1.0443062782287598,50000 -5882.064656734467,26.900224924087524,170911.90631198883,505898,0,170911.90631198883,0.6313000321388245,1.822537422180176,10000,176838.2618751526,0.960359513759613,0.146673321723938,0.7550999522209167,1.0442750453948977,50000 -5899.219933986664,27.01107287406921,171422.06424570084,507409,0,171422.06424570084,0.6313000321388245,1.8244786262512207,10000,177365.73793911934,0.9600406289100648,0.1485704332590103,0.7552199959754944,1.044175386428833,50000 -5916.238545417786,27.122536659240723,171932.13884925842,508919,0,171932.13884925842,0.6301000118255615,1.8214324712753296,10000,177892.9942240715,0.9602997303009032,0.151060089468956,0.755899965763092,1.043476700782776,50000 -5933.223411083221,27.235366582870483,172442.27340078354,510430,0,172442.27340078354,0.6312000155448914,1.8233567476272583,10000,178420.2787978649,0.9616748690605164,0.1418393105268478,0.7552199959754944,1.0449419021606443,50000 -5950.154138565064,27.349958181381226,172952.19137954712,511940,0,172952.19137954712,0.6306000351905823,1.8227651119232176,10000,178947.29454159737,0.9614357352256776,0.1435249298810959,0.7557199597358704,1.0442218780517578,50000 -5967.08442735672,27.45971775054932,173462.10426592827,513450,0,173462.10426592827,0.6304000020027161,1.8211781978607176,10000,179474.29887628555,0.9593430757522584,0.148752212524414,0.7551400065422058,1.043470025062561,50000 -5983.905996799469,27.568768978118896,173971.96740150452,514960,0,173971.96740150452,0.6315000057220459,1.82271409034729,10000,180001.14465355873,0.9610769748687744,0.1501688510179519,0.7549999952316284,1.0442010164260864,50000 -6000.9197318553925,27.683451890945435,174482.09878492355,516471,0,174482.09878492355,0.6300000548362732,1.8224447965621948,10000,180528.45638537407,0.9609972834587096,0.1467299163341522,0.7552199959754944,1.0436804294586182,50000 -6018.418604850769,27.80076766014099,174992.04834723473,517981,0,174992.04834723473,0.6314000487327576,1.8219637870788568,10000,181056.0744562149,0.960558831691742,0.1473705470561981,0.7551400065422058,1.043861746788025,50000 -6035.283583641052,27.892229795455933,175502.09851241112,519491,0,175502.09851241112,0.6310000419616699,1.821869969367981,10000,181583.13382434845,0.958765149116516,0.1532300561666488,0.7552399635314941,1.043881893157959,50000 -6052.146709918976,28.011739253997803,176011.94000339508,521001,0,176011.94000339508,0.6307000517845154,1.8207950592041016,10000,182110.01022839543,0.9597018361091614,0.1484248042106628,0.7550599575042725,1.04348886013031,50000 -6069.219238519669,28.131194591522217,176522.079351902,522512,0,176522.079351902,0.6303000450134277,1.8228248357772827,10000,182637.39370679847,0.9596619606018066,0.150680735707283,0.7549200057983398,1.0442814826965332,50000 -6086.252090454102,28.246938705444336,177032.21652126312,524023,0,177032.21652126312,0.6313000321388245,1.819411039352417,10000,183164.7318851948,0.960558831691742,0.1465355008840561,0.755299985408783,1.0426671504974363,50000 -6103.131259441376,28.36412715911865,177542.2542464733,525533,0,177542.2542464733,0.6310000419616699,1.823042869567871,10000,183691.81848621368,0.9598612785339355,0.1497280746698379,0.7552199959754944,1.044359564781189,50000 -6119.841101884842,28.47494602203369,178052.15637159348,527043,0,178052.15637159348,0.6318000555038452,1.8209943771362305,10000,184218.59387516967,0.9602000713348388,0.1480627954006195,0.7551400065422058,1.043128252029419,50000 -6136.617638349533,28.58440470695496,178562.26889562607,528553,0,178562.26889562607,0.6310000419616699,1.8213881254196167,10000,184745.6449303627,0.9608777165412904,0.1473826766014099,0.7552799582481384,1.0443631410598757,50000 -6153.652619361877,28.698846340179443,179072.21165585518,530063,0,179072.21165585518,0.6301000118255615,1.8239092826843264,10000,185272.78962039948,0.9594626426696776,0.1503868401050567,0.755079984664917,1.0445592403411863,50000 -6170.457691907883,28.81775641441345,179582.16641449928,531573,0,179582.16641449928,0.6312000155448914,1.8203206062316888,10000,185799.72025728223,0.9595224857330322,0.150401160120964,0.7553600072860718,1.0430076122283936,50000 -6187.182283878326,28.934704780578613,180092.11389112473,533083,0,180092.11389112473,0.6320000290870667,1.8219304084777832,10000,186326.5613975525,0.9605388641357422,0.1487494260072708,0.7551999688148499,1.043909788131714,50000 -6203.998332023621,29.05104231834412,180602.2413403988,534594,0,180602.2413403988,0.6320000290870667,1.821923971176148,10000,186853.6732811928,0.9620535373687744,0.1427308619022369,0.7545999884605408,1.04520845413208,50000 -6221.071969509125,29.16448450088501,181112.3642761708,536105,0,181112.3642761708,0.6321000456809998,1.8213005065917969,10000,187381.0358531475,0.9607780575752258,0.1484987586736679,0.7554399967193604,1.0436646938323977,50000 -6238.322468757629,29.27834534645081,181622.24652838707,537615,0,181622.24652838707,0.6301000118255615,1.8213962316513064,10000,187908.334836483,0.9604392051696776,0.1466994285583496,0.7552599906921387,1.0449239015579224,50000 -6255.310319900513,29.389899730682373,182132.1446402073,539125,0,182132.1446402073,0.6310000419616699,1.822413206100464,10000,188435.3849275112,0.9596021771430968,0.151600956916809,0.7555599808692932,1.043686032295227,50000 -6272.058003902435,29.50988364219665,182642.13536715508,540635,0,182642.13536715508,0.6308000087738037,1.8208487033844,10000,188962.29661369324,0.9598612785339355,0.1493533551692962,0.7551400065422058,1.0445722341537476,50000 -6288.69265794754,29.62590265274048,183152.2509188652,542146,0,183152.2509188652,0.6310000419616699,1.822114944458008,10000,189489.21495723724,0.9615154266357422,0.1449406147003173,0.7549200057983398,1.044356346130371,50000 -6305.479026794434,29.74439120292664,183662.38544940948,543657,0,183662.38544940948,0.6319000124931335,1.8207831382751465,10000,190016.3076210022,0.9601402878761292,0.1462753862142563,0.7551400065422058,1.0432575941085815,50000 -6322.706331968308,29.857580184936523,184172.5107626915,545168,0,184172.5107626915,0.631100058555603,1.821388959884644,10000,190543.8264966011,0.96000075340271,0.1489509791135788,0.7549999952316284,1.044825792312622,50000 -6339.538756847382,29.97205090522766,184682.5400466919,546679,0,184682.5400466919,0.6315000057220459,1.821320414543152,10000,191070.85496354103,0.9606783986091614,0.1472610086202621,0.7549999952316284,1.043584942817688,50000 -6356.251218318939,30.08995032310486,185192.69767308235,548189,0,185192.69767308235,0.6309000253677368,1.82038676738739,10000,191597.8956928253,0.959741711616516,0.1502102166414261,0.7551800012588501,1.0438426733016968,50000 -6373.249268293381,30.20629596710205,185702.6072921753,549699,0,185702.6072921753,0.6307000517845154,1.8231457471847528,10000,192124.9710338116,0.9620535373687744,0.1419393718242645,0.7547799944877625,1.043861746788025,50000 -6390.202693462372,30.328657150268555,186212.7490684986,551209,0,186212.7490684986,0.6319000124931335,1.823291301727295,10000,192652.24084949493,0.9608178734779358,0.144424170255661,0.7553199529647827,1.0437524318695068,50000 -6407.26146531105,30.44870686531067,186722.60253715515,552719,0,186722.60253715515,0.6306000351905823,1.8228309154510496,10000,193179.3252959252,0.96097731590271,0.1480630040168762,0.7549999952316284,1.0442960262298584,50000 -6424.07945227623,30.56987929344177,187232.4409868717,554229,0,187232.4409868717,0.6302000284194946,1.8223854303359983,10000,193706.154392004,0.9595423936843872,0.1502232849597931,0.7549799680709839,1.0441011190414429,50000 -6440.839239120483,30.693721771240234,187742.46044564247,555740,0,187742.46044564247,0.631100058555603,1.8212271928787231,10000,194233.10874533653,0.9612165093421936,0.1473574638366699,0.7553600072860718,1.0433720350265503,50000 -6457.721947193146,30.81244969367981,188252.3400375843,557250,0,188252.3400375843,0.6312000155448914,1.8224512338638303,10000,194760.04110717773,0.9599011540412904,0.1505948454141616,0.755299985408783,1.0440723896026611,50000 -6474.599181890488,30.935455560684204,188762.30083870888,558760,0,188762.30083870888,0.6308000087738037,1.822641134262085,10000,195287.05368733406,0.95804762840271,0.1525946706533432,0.7549200057983398,1.043939232826233,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index 72b75fbfa..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5969 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.657526,6.9289646,,,,,,,,,,,,,, -1,,,0.0011559311533346,6.912258625030518,0.0011599999852478,6.912920475006104,50000.0,0.0010999999940395,6.912676334381104,10000.0,50.636836767196655,87.66242718696594,50.636836767196655,37.02550530433655,0.0,0.0 -100,0.67666143,6.8072534,,,,,,,,,,,,,, -200,0.78953254,6.5714636,,,,,,,,,,,,,, -300,1.0923727,6.297056,,,,,,,,,,,,,, -400,1.7477684,6.0672746,,,,,,,,,,,,,, -500,2.9927757,5.8596,,,,,,,,,,,,,, -600,2.1546376,5.5608215,,,,,,,,,,,,,, -700,4.517483,5.4487896,,,,,,,,,,,,,, -800,3.3539162,5.3285174,,,,,,,,,,,,,, -900,3.6423876,5.1471033,,,,,,,,,,,,,, -1000,7.5269604,5.0822763,,,,,,,,,,,,,, -1100,3.4165585,4.907184,,,,,,,,,,,,,, -1200,5.3943887,4.8350253,,,,,,,,,,,,,, -1300,5.1038704,4.753716,,,,,,,,,,,,,, -1400,4.227055,4.7139797,,,,,,,,,,,,,, -1500,6.338866,4.6383233,,,,,,,,,,,,,, -1506,,,0.1685267835855484,4.295879364013672,0.1553799957036972,4.445466995239258,50000.0,0.1157000064849853,4.921426296234131,10000.0,560.612095117569,615.5135684013367,560.612095117569,54.821940660476685,0.0259764194488525,0.0 -1600,3.602686,4.326953,,,,,,,,,,,,,, -1700,6.169819,4.2370305,,,,,,,,,,,,,, -1800,3.718819,4.359346,,,,,,,,,,,,,, -1900,5.597655,4.113846,,,,,,,,,,,,,, -2000,5.1305475,4.093711,,,,,,,,,,,,,, -2100,3.6263578,3.978268,,,,,,,,,,,,,, -2200,4.6840706,3.9218912,,,,,,,,,,,,,, -2300,5.1681156,3.760606,,,,,,,,,,,,,, -2400,5.457753,3.7075934,,,,,,,,,,,,,, -2500,3.4265144,3.8183892,,,,,,,,,,,,,, -2600,3.7355082,3.6895592,,,,,,,,,,,,,, -2700,4.214271,3.5903523,,,,,,,,,,,,,, -2800,3.3201556,3.5113857,,,,,,,,,,,,,, -2900,5.2850175,3.5779188,,,,,,,,,,,,,, -3000,4.145559,3.3186173,,,,,,,,,,,,,, -3010,,,0.3398238122463226,3.0561516284942627,0.3103599846363067,3.23838472366333,50000.0,0.2299000173807144,3.900252103805542,10000.0,1070.6800384521484,1143.3546075820925,1070.6800384521484,72.51834774017334,0.0519242286682128,0.0 -3100,3.246496,3.4448068,,,,,,,,,,,,,, -3200,3.4527159,3.264643,,,,,,,,,,,,,, -3300,4.1519017,3.2103176,,,,,,,,,,,,,, -3400,4.078003,3.1943722,,,,,,,,,,,,,, -3500,3.2231007,3.185532,,,,,,,,,,,,,, -3600,2.6749697,3.2584863,,,,,,,,,,,,,, -3700,3.2727952,3.1020627,,,,,,,,,,,,,, -3800,2.666453,3.0043623,,,,,,,,,,,,,, -3900,3.5453992,3.0854638,,,,,,,,,,,,,, -4000,4.5209274,2.976467,,,,,,,,,,,,,, -4100,2.6866682,2.9609876,,,,,,,,,,,,,, -4200,2.3391855,2.8132596,,,,,,,,,,,,,, -4300,1.8893396,2.9623075,,,,,,,,,,,,,, -4400,2.522821,2.9317822,,,,,,,,,,,,,, -4500,2.5833626,2.76918,,,,,,,,,,,,,, -4516,,,0.4670958220958709,2.33237099647522,0.4367599785327911,2.513633728027344,50000.0,0.3290000259876251,3.260066509246826,10000.0,1580.6047410964966,1671.3481891155243,1580.6047410964966,90.5101797580719,0.0783896446228027,0.0 -4600,2.4197211,2.6936593,,,,,,,,,,,,,, -4700,2.5546973,2.771473,,,,,,,,,,,,,, -4800,2.79085,2.7218592,,,,,,,,,,,,,, -4900,2.8869984,2.7673228,,,,,,,,,,,,,, -5000,2.4276319,2.4799004,,,,,,,,,,,,,, -5100,3.2785597,2.641571,,,,,,,,,,,,,, -5200,2.476498,2.6835773,,,,,,,,,,,,,, -5300,1.9361541,2.5609272,,,,,,,,,,,,,, -5400,2.1537385,2.5857716,,,,,,,,,,,,,, -5500,2.4376178,2.4532218,,,,,,,,,,,,,, -5600,2.1043835,2.4809277,,,,,,,,,,,,,, -5700,1.9887441,2.5969777,,,,,,,,,,,,,, -5800,2.5086627,2.3662534,,,,,,,,,,,,,, -5900,1.9563696,2.49879,,,,,,,,,,,,,, -6000,1.7836419,2.4159715,,,,,,,,,,,,,, -6024,,,0.5397201776504517,1.9583563804626465,0.5008000135421753,2.161590337753296,50000.0,0.385200023651123,2.8951804637908936,10000.0,2090.61950135231,2199.170222520828,2090.61950135231,108.24045300483704,0.1055688858032226,0.0 -6100,2.9876266,2.440539,,,,,,,,,,,,,, -6200,2.7071178,2.4052255,,,,,,,,,,,,,, -6300,2.5472949,2.3594134,,,,,,,,,,,,,, -6400,2.1158903,2.2873662,,,,,,,,,,,,,, -6500,1.9126642,2.414003,,,,,,,,,,,,,, -6600,2.2499678,2.3849027,,,,,,,,,,,,,, -6700,2.0178277,2.2122934,,,,,,,,,,,,,, -6800,1.8515116,2.4102569,,,,,,,,,,,,,, -6900,1.9297482,2.324358,,,,,,,,,,,,,, -7000,2.0023816,2.303808,,,,,,,,,,,,,, -7100,1.889976,2.1549196,,,,,,,,,,,,,, -7200,1.8127017,2.3308518,,,,,,,,,,,,,, -7300,1.9996278,2.1906328,,,,,,,,,,,,,, -7400,2.2833817,2.3201852,,,,,,,,,,,,,, -7500,2.2674625,2.320732,,,,,,,,,,,,,, -7532,,,0.5772082209587097,1.7723355293273926,0.5385599732398987,1.980293869972229,50000.0,0.4175000190734863,2.7358531951904297,10000.0,2600.636000394821,2727.06293296814,2600.636000394821,126.03376865386964,0.1374855041503906,0.0 -7600,1.8938649,2.2189138,,,,,,,,,,,,,, -7700,2.0402815,2.2864509,,,,,,,,,,,,,, -7800,1.47741,2.118788,,,,,,,,,,,,,, -7900,2.0307086,2.2176416,,,,,,,,,,,,,, -8000,1.9181911,2.2592525,,,,,,,,,,,,,, -8100,1.5913484,2.2665048,,,,,,,,,,,,,, -8200,2.527324,2.1704516,,,,,,,,,,,,,, -8300,1.4717153,2.1894093,,,,,,,,,,,,,, -8400,1.4844755,2.2654712,,,,,,,,,,,,,, -8500,2.0958068,2.1933382,,,,,,,,,,,,,, -8600,2.0629108,2.1816525,,,,,,,,,,,,,, -8700,1.5253865,2.0567057,,,,,,,,,,,,,, -8800,2.3026583,2.1241155,,,,,,,,,,,,,, -8900,1.7496341,2.2159452,,,,,,,,,,,,,, -9000,1.6352738,2.1241906,,,,,,,,,,,,,, -9041,,,0.6067641973495483,1.6266725063323977,0.5556600093841553,1.8939425945281985,50000.0,0.4307000339031219,2.655898094177246,10000.0,3110.8438572883606,3255.172870874405,3110.8438572883606,143.8541338443756,0.1674284934997558,0.0 -9100,1.5973908,2.1883776,,,,,,,,,,,,,, -9200,2.2739532,2.0848994,,,,,,,,,,,,,, -9300,2.0963585,2.1172912,,,,,,,,,,,,,, -9400,2.1055305,2.1734037,,,,,,,,,,,,,, -9500,1.8525425,2.0832136,,,,,,,,,,,,,, -9600,1.9400412,2.0993717,,,,,,,,,,,,,, -9700,1.5856556,2.0492663,,,,,,,,,,,,,, -9800,2.4669287,2.0696476,,,,,,,,,,,,,, -9900,1.5126727,2.0737658,,,,,,,,,,,,,, -10000,1.4863596,2.0302773,,,,,,,,,,,,,, -10100,1.6509439,2.0372763,,,,,,,,,,,,,, -10200,1.8251327,2.099382,,,,,,,,,,,,,, -10300,1.5538676,2.066015,,,,,,,,,,,,,, -10400,1.487028,2.1327446,,,,,,,,,,,,,, -10500,2.0135052,2.1769905,,,,,,,,,,,,,, -10549,,,0.6285674571990967,1.5007295608520508,0.5639199614524841,1.8578181266784668,50000.0,0.4445000290870666,2.585268497467041,10000.0,3621.028754234314,3783.733434200287,3621.028754234314,162.1518476009369,0.1943204402923584,0.0 -10600,1.5255847,2.087686,,,,,,,,,,,,,, -10700,1.719296,2.1051543,,,,,,,,,,,,,, -10800,2.2125907,1.9930427,,,,,,,,,,,,,, -10900,1.7535007,2.077032,,,,,,,,,,,,,, -11000,2.0309408,2.1616962,,,,,,,,,,,,,, -11100,2.488586,1.9766965,,,,,,,,,,,,,, -11200,1.7135177,2.04627,,,,,,,,,,,,,, -11300,1.5654675,1.9447681,,,,,,,,,,,,,, -11400,1.5609896,2.0489388,,,,,,,,,,,,,, -11500,1.5206928,1.9624803,,,,,,,,,,,,,, -11600,1.5373832,2.0752785,,,,,,,,,,,,,, -11700,1.5727689,2.158955,,,,,,,,,,,,,, -11800,1.8757215,2.1064224,,,,,,,,,,,,,, -11900,1.5821385,2.0201807,,,,,,,,,,,,,, -12000,1.6746471,2.02886,,,,,,,,,,,,,, -12058,,,0.6300023794174194,1.498435378074646,0.5737599730491638,1.8043254613876345,50000.0,0.4433000087738037,2.580861568450928,10000.0,4131.20413517952,4312.186669111252,4131.20413517952,180.35190606117249,0.2221417427062988,0.0 -12100,1.658227,2.0873363,,,,,,,,,,,,,, -12200,1.855952,1.9576082,,,,,,,,,,,,,, -12300,1.339763,2.082521,,,,,,,,,,,,,, -12400,1.5690141,2.1018014,,,,,,,,,,,,,, -12500,1.5111929,2.027324,,,,,,,,,,,,,, -12600,1.3649541,2.0338302,,,,,,,,,,,,,, -12700,1.4105831,1.9157702,,,,,,,,,,,,,, -12800,1.5480529,1.9164199,,,,,,,,,,,,,, -12900,1.9519033,1.9633393,,,,,,,,,,,,,, -13000,1.3843468,2.087409,,,,,,,,,,,,,, -13100,1.5676782,2.0605729,,,,,,,,,,,,,, -13200,1.5260268,2.0233545,,,,,,,,,,,,,, -13300,1.2949972,1.9754536,,,,,,,,,,,,,, -13400,1.590934,1.8545048,,,,,,,,,,,,,, -13500,1.3240463,1.9174535,,,,,,,,,,,,,, -13566,,,0.6170080900192261,1.5616915225982666,0.5643399953842163,1.8516933917999268,50000.0,0.4456000328063965,2.5929534435272217,10000.0,4641.235483169556,4840.924443721771,4641.235483169556,198.97168898582456,0.2587699890136719,0.0 -13600,1.4719676,1.932812,,,,,,,,,,,,,, -13700,1.3536607,1.9640911,,,,,,,,,,,,,, -13800,1.4348481,1.8662884,,,,,,,,,,,,,, -13900,1.9822999,1.9972303,,,,,,,,,,,,,, -14000,1.7510078,2.0242047,,,,,,,,,,,,,, -14100,1.6680275,1.9707996,,,,,,,,,,,,,, -14200,1.4641597,1.8997927,,,,,,,,,,,,,, -14300,1.7236509,1.78748,,,,,,,,,,,,,, -14400,1.5471612,1.9199665,,,,,,,,,,,,,, -14500,1.7531719,1.9320804,,,,,,,,,,,,,, -14600,1.5468107,1.9859284,,,,,,,,,,,,,, -14700,2.404155,1.9572066,,,,,,,,,,,,,, -14800,1.6559026,1.848001,,,,,,,,,,,,,, -14900,1.51352,1.8543683,,,,,,,,,,,,,, -15000,1.7479632,1.8666091,,,,,,,,,,,,,, -15075,,,0.6428770422935486,1.4448336362838743,0.5908399820327759,1.7287929058074951,50000.0,0.4571000337600708,2.5280418395996094,10000.0,5151.217526912689,5369.960024833679,5151.217526912689,217.94574451446533,0.2880523204803467,0.0 -15100,1.5974723,2.034294,,,,,,,,,,,,,, -15200,1.5440084,1.8865236,,,,,,,,,,,,,, -15300,1.4756474,1.8365606,,,,,,,,,,,,,, -15400,1.6768401,1.9486451,,,,,,,,,,,,,, -15500,1.5019075,1.9192582,,,,,,,,,,,,,, -15600,1.5877581,1.9243966,,,,,,,,,,,,,, -15700,1.7366568,2.03085,,,,,,,,,,,,,, -15800,1.590861,1.8640313,,,,,,,,,,,,,, -15900,1.8334125,1.8706483,,,,,,,,,,,,,, -16000,1.7321613,1.9076359,,,,,,,,,,,,,, -16100,1.664603,1.8389547,,,,,,,,,,,,,, -16200,1.7288182,1.8064649,,,,,,,,,,,,,, -16300,1.6548482,1.825959,,,,,,,,,,,,,, -16400,1.60188,1.7249397,,,,,,,,,,,,,, -16500,1.7274089,1.9874419,,,,,,,,,,,,,, -16584,,,0.6567881107330322,1.3828415870666504,0.6065399646759033,1.6479262113571167,50000.0,0.4767000079154968,2.377015352249145,10000.0,5661.294228553772,5898.560553789139,5661.294228553772,236.3728106021881,0.3328497409820556,0.0 -16600,1.5413164,1.7763066,,,,,,,,,,,,,, -16700,1.4880065,1.8561887,,,,,,,,,,,,,, -16800,1.5194566,1.9170506,,,,,,,,,,,,,, -16900,1.6794903,1.9641947,,,,,,,,,,,,,, -17000,1.5801694,1.7641113,,,,,,,,,,,,,, -17100,1.4838904,1.924068,,,,,,,,,,,,,, -17200,1.5187846,1.8163792,,,,,,,,,,,,,, -17300,2.0598967,1.920668,,,,,,,,,,,,,, -17400,1.3775738,1.851845,,,,,,,,,,,,,, -17500,1.6963931,1.8376346,,,,,,,,,,,,,, -17600,1.5671843,1.8985705,,,,,,,,,,,,,, -17700,1.368258,1.8607526,,,,,,,,,,,,,, -17800,1.7879398,2.0157409,,,,,,,,,,,,,, -17900,1.5939599,1.8498031,,,,,,,,,,,,,, -18000,1.6510878,1.8814671,,,,,,,,,,,,,, -18092,,,0.6660555005073547,1.3375362157821655,0.6100599765777588,1.6255896091461182,50000.0,0.4919000267982483,2.33859920501709,10000.0,6171.443657398224,6432.820313692093,6171.443657398224,260.39603447914124,0.3701400756835937,0.0 -18100,1.6221372,1.7181317,,,,,,,,,,,,,, -18200,1.8732018,1.849712,,,,,,,,,,,,,, -18300,1.676212,1.8262444,,,,,,,,,,,,,, -18400,1.4299773,1.9079188,,,,,,,,,,,,,, -18500,1.7019304,1.9087874,,,,,,,,,,,,,, -18600,1.5499674,1.8464532,,,,,,,,,,,,,, -18700,1.8406663,1.7957807,,,,,,,,,,,,,, -18800,1.5432419,1.7482039,,,,,,,,,,,,,, -18900,1.6088963,2.0695713,,,,,,,,,,,,,, -19000,1.702467,1.9227769,,,,,,,,,,,,,, -19100,1.545944,1.745446,,,,,,,,,,,,,, -19200,1.7957338,1.8969021,,,,,,,,,,,,,, -19300,1.6955572,1.9002655,,,,,,,,,,,,,, -19400,1.8521214,1.7810562,,,,,,,,,,,,,, -19500,1.6305152,1.8210282,,,,,,,,,,,,,, -19600,1.6911072,1.8123906,,,,,,,,,,,,,, -19601,,,0.6856465339660645,1.2491897344589231,0.6026399731636047,1.6453914642333984,50000.0,0.482200026512146,2.384953022003174,10000.0,6681.714090824127,6965.917953968048,6681.714090824127,283.1347830295563,0.4076879024505615,0.0 -19700,1.7906281,1.9234387,,,,,,,,,,,,,, -19800,1.4863101,1.7529292,,,,,,,,,,,,,, -19900,1.8019035,1.8824025,,,,,,,,,,,,,, -20000,1.8778663,1.7784798,,,,,,,,,,,,,, -20100,1.7151273,1.8702217,,,,,,,,,,,,,, -20200,1.710299,1.7885327,,,,,,,,,,,,,, -20300,1.7924308,1.9252694,,,,,,,,,,,,,, -20400,1.5036778,1.8057349,,,,,,,,,,,,,, -20500,1.497829,1.622548,,,,,,,,,,,,,, -20600,1.7851851,1.8929886,,,,,,,,,,,,,, -20700,1.7286341,1.8459973,,,,,,,,,,,,,, -20800,1.6973128,1.8814445,,,,,,,,,,,,,, -20900,1.6682091,1.8333416,,,,,,,,,,,,,, -21000,1.6882427,1.7420936,,,,,,,,,,,,,, -21100,1.7265638,1.7042547,,,,,,,,,,,,,, -21110,,,0.6740672588348389,1.284901738166809,0.612060010433197,1.6220479011535645,50000.0,0.4874000251293182,2.3641345500946045,10000.0,7191.711829662323,7499.435747146606,7191.711829662323,306.5707335472107,0.4417273998260498,0.0 -21200,1.5596406,1.924968,,,,,,,,,,,,,, -21300,1.879986,1.8029113,,,,,,,,,,,,,, -21400,2.066151,1.8225212,,,,,,,,,,,,,, -21500,1.6983615,1.7541295,,,,,,,,,,,,,, -21600,1.8042849,1.6764452,,,,,,,,,,,,,, -21700,1.6257566,1.7676536,,,,,,,,,,,,,, -21800,1.7317071,1.7945979,,,,,,,,,,,,,, -21900,1.5651419,1.7994449,,,,,,,,,,,,,, -22000,1.7842709,1.9096178,,,,,,,,,,,,,, -22100,1.778121,1.7733661,,,,,,,,,,,,,, -22200,1.7853632,1.7591995,,,,,,,,,,,,,, -22300,1.558144,1.8582287,,,,,,,,,,,,,, -22400,1.6249148,1.8633604,,,,,,,,,,,,,, -22500,1.7460213,1.7482637,,,,,,,,,,,,,, -22600,1.6214507,1.7900386,,,,,,,,,,,,,, -22620,,,0.6695631146430969,1.30652117729187,0.6083199977874756,1.6256723403930664,50000.0,0.4837000370025635,2.382908344268799,10000.0,7701.745020151138,8033.715864419937,7701.745020151138,330.72790360450745,0.4803063869476318,0.0 -22700,1.7587094,1.8644131,,,,,,,,,,,,,, -22800,1.9506658,1.8683403,,,,,,,,,,,,,, -22900,1.7523509,1.76462,,,,,,,,,,,,,, -23000,1.5696797,1.8278842,,,,,,,,,,,,,, -23100,1.8723147,1.8213828,,,,,,,,,,,,,, -23200,1.8135399,1.8899341,,,,,,,,,,,,,, -23300,1.8094598,1.8445787,,,,,,,,,,,,,, -23400,2.2246888,1.696157,,,,,,,,,,,,,, -23500,1.7368679,1.7157166,,,,,,,,,,,,,, -23600,1.7188703,1.6617723,,,,,,,,,,,,,, -23700,1.733491,1.8271645,,,,,,,,,,,,,, -23800,1.8799437,1.8792889,,,,,,,,,,,,,, -23900,1.5471514,1.6479362,,,,,,,,,,,,,, -24000,1.7001885,1.8414971,,,,,,,,,,,,,, -24100,1.84931,1.7290431,,,,,,,,,,,,,, -24130,,,0.6800462007522583,1.271593689918518,0.6187199950218201,1.5880154371261597,50000.0,0.4928000271320343,2.347294330596924,10000.0,8211.9476480484,8568.492515087128,8211.9476480484,355.2130854129791,0.5175192356109619,0.0 -24200,1.7490121,1.784798,,,,,,,,,,,,,, -24300,1.7963297,1.7782768,,,,,,,,,,,,,, -24400,1.6206878,1.7108192,,,,,,,,,,,,,, -24500,1.6071652,1.7952046,,,,,,,,,,,,,, -24600,1.8036009,1.7197125,,,,,,,,,,,,,, -24700,1.8873276,1.6964625,,,,,,,,,,,,,, -24800,1.791382,1.7499186,,,,,,,,,,,,,, -24900,1.8665003,1.8143167,,,,,,,,,,,,,, -25000,1.6369562,1.6531968,,,,,,,,,,,,,, -25100,1.6516848,1.8595772,,,,,,,,,,,,,, -25200,1.8326703,1.7079537,,,,,,,,,,,,,, -25300,1.5445597,1.7599589,,,,,,,,,,,,,, -25400,1.8236959,1.755665,,,,,,,,,,,,,, -25500,1.6666541,1.7111462,,,,,,,,,,,,,, -25600,2.4245243,1.7914404,,,,,,,,,,,,,, -25640,,,0.6746651530265808,1.286739468574524,0.6209200024604797,1.5903011560440063,50000.0,0.4976000189781189,2.3183577060699463,10000.0,8722.075249195099,9101.75936460495,8722.075249195099,378.2750778198242,0.5450489521026611,0.0 -25700,1.8131173,1.6794164,,,,,,,,,,,,,, -25800,1.6815867,1.6995132,,,,,,,,,,,,,, -25900,1.9654914,1.695153,,,,,,,,,,,,,, -26000,1.8062588,1.7587214,,,,,,,,,,,,,, -26100,1.6324015,1.7186155,,,,,,,,,,,,,, -26200,1.6710218,1.6549022,,,,,,,,,,,,,, -26300,1.6789964,1.6805855,,,,,,,,,,,,,, -26400,1.7350041,1.6417022,,,,,,,,,,,,,, -26500,1.5877893,1.6900386,,,,,,,,,,,,,, -26600,1.6204872,1.745413,,,,,,,,,,,,,, -26700,1.9019037,1.8320472,,,,,,,,,,,,,, -26800,1.6366546,1.8729955,,,,,,,,,,,,,, -26900,1.8641027,1.7771677,,,,,,,,,,,,,, -27000,1.8723584,1.8214545,,,,,,,,,,,,,, -27100,1.6713868,1.5775299,,,,,,,,,,,,,, -27150,,,0.6598373651504517,1.3535218238830566,0.6033399701118469,1.653844118118286,50000.0,0.4777000248432159,2.413469314575196,10000.0,9232.15855383873,9634.885572433472,9232.15855383873,401.23756408691406,0.5751643180847168,0.0 -27200,1.7192506,1.6112223,,,,,,,,,,,,,, -27300,1.5209048,1.7784256,,,,,,,,,,,,,, -27400,1.7164165,1.7207959,,,,,,,,,,,,,, -27500,1.9387162,1.7060632,,,,,,,,,,,,,, -27600,1.7510203,1.7345078,,,,,,,,,,,,,, -27700,2.083943,1.865593,,,,,,,,,,,,,, -27800,1.9245946,1.6957548,,,,,,,,,,,,,, -27900,1.661968,1.8318219,,,,,,,,,,,,,, -28000,1.8370718,1.6439306,,,,,,,,,,,,,, -28100,1.7022276,1.7651827,,,,,,,,,,,,,, -28200,1.770106,1.6417764,,,,,,,,,,,,,, -28300,1.7764943,1.8119698,,,,,,,,,,,,,, -28400,1.8422874,1.7494104,,,,,,,,,,,,,, -28500,1.5966942,1.7259176,,,,,,,,,,,,,, -28600,1.884869,1.7318807,,,,,,,,,,,,,, -28659,,,0.7111367583274841,1.1349786520004272,0.6277799606323242,1.5449466705322266,50000.0,0.5024999976158142,2.293640851974488,10000.0,9742.014395713806,10167.096479654312,9742.014395713806,423.27252864837646,0.8448653221130371,0.0 -28700,1.8206662,1.7525767,,,,,,,,,,,,,, -28800,1.6902691,1.6599512,,,,,,,,,,,,,, -28900,1.6437461,1.7588578,,,,,,,,,,,,,, -29000,2.0344846,1.8753979,,,,,,,,,,,,,, -29100,2.0213854,1.6954932,,,,,,,,,,,,,, -29200,1.8666939,1.7081801,,,,,,,,,,,,,, -29300,1.7911661,1.8441881,,,,,,,,,,,,,, -29400,1.819144,1.6306725,,,,,,,,,,,,,, -29500,1.6795365,1.7132949,,,,,,,,,,,,,, -29600,1.6342058,1.7576638,,,,,,,,,,,,,, -29700,1.5504731,1.7059171,,,,,,,,,,,,,, -29800,1.7574058,1.7144914,,,,,,,,,,,,,, -29900,1.9127833,1.6129664,,,,,,,,,,,,,, -30000,1.5415022,1.6054387,,,,,,,,,,,,,, -30100,1.7698935,1.7274365,,,,,,,,,,,,,, -30169,,,0.6936184763908386,1.211167335510254,0.6198599934577942,1.569377303123474,50000.0,0.4919000267982483,2.3065173625946045,10000.0,10252.25020313263,10702.17523431778,10252.25020313263,448.03714537620544,0.8730454444885254,0.0 -30200,1.742834,1.6591206,,,,,,,,,,,,,, -30300,1.7993448,1.7587137,,,,,,,,,,,,,, -30400,1.7322531,1.6983907,,,,,,,,,,,,,, -30500,2.0430787,1.7105032,,,,,,,,,,,,,, -30600,1.7278029,1.7519555,,,,,,,,,,,,,, -30700,1.7524561,1.6446375,,,,,,,,,,,,,, -30800,1.7038136,1.7646407,,,,,,,,,,,,,, -30900,1.669819,1.6709934,,,,,,,,,,,,,, -31000,1.9325709,1.7312987,,,,,,,,,,,,,, -31100,1.89061,1.8417746,,,,,,,,,,,,,, -31200,1.7925946,1.7297195,,,,,,,,,,,,,, -31300,1.9587976,1.9319575,,,,,,,,,,,,,, -31400,1.8336422,1.6342448,,,,,,,,,,,,,, -31500,1.7471477,1.717071,,,,,,,,,,,,,, -31600,1.791411,1.6968544,,,,,,,,,,,,,, -31678,,,0.6753029227256775,1.2694958448410034,0.6157999634742737,1.6059370040893557,50000.0,0.4880000352859497,2.335855722427368,10000.0,10762.18581557274,11234.781381845474,10762.18581557274,470.6238646507263,0.9068002700805664,0.0 -31700,1.796777,1.7334162,,,,,,,,,,,,,, -31800,1.7360209,1.6616857,,,,,,,,,,,,,, -31900,1.7873093,1.7737538,,,,,,,,,,,,,, -32000,1.7189443,1.6695966,,,,,,,,,,,,,, -32100,1.714766,1.7138152,,,,,,,,,,,,,, -32200,1.6508763,1.6182352,,,,,,,,,,,,,, -32300,1.8352093,1.7224029,,,,,,,,,,,,,, -32400,1.882979,1.7498517,,,,,,,,,,,,,, -32500,1.7688864,1.6218584,,,,,,,,,,,,,, -32600,1.7731193,1.7772197,,,,,,,,,,,,,, -32700,2.0977795,1.7414775,,,,,,,,,,,,,, -32800,1.9130617,1.7602935,,,,,,,,,,,,,, -32900,1.7555467,1.6035042,,,,,,,,,,,,,, -33000,1.7157973,1.7812157,,,,,,,,,,,,,, -33100,1.7735149,1.6682969,,,,,,,,,,,,,, -33189,,,0.6901506781578064,1.2144831418991089,0.6292600035667419,1.5291776657104492,50000.0,0.5026000142097473,2.295906782150269,10000.0,11272.298690795898,11767.017308950424,11272.298690795898,492.6578459739685,0.9439327716827391,0.0 -33200,1.7787464,1.6391326,,,,,,,,,,,,,, -33300,1.8474478,1.7692131,,,,,,,,,,,,,, -33400,1.7628161,1.6692401,,,,,,,,,,,,,, -33500,1.8212019,1.7304611,,,,,,,,,,,,,, -33600,1.6681957,1.5751684,,,,,,,,,,,,,, -33700,1.8911173,1.6441756,,,,,,,,,,,,,, -33800,1.8005543,1.7150049,,,,,,,,,,,,,, -33900,1.8343378,1.7507591,,,,,,,,,,,,,, -34000,1.6216199,1.7467289,,,,,,,,,,,,,, -34100,1.7399071,1.5997982,,,,,,,,,,,,,, -34200,1.9913797,1.7112939,,,,,,,,,,,,,, -34300,1.7009302,1.5852189,,,,,,,,,,,,,, -34400,1.632399,1.6937141,,,,,,,,,,,,,, -34500,1.7747447,1.6096387,,,,,,,,,,,,,, -34600,2.0537827,1.7222558,,,,,,,,,,,,,, -34699,,,0.6905093789100647,1.2288453578948977,0.630079984664917,1.5220123529434204,50000.0,0.5005000233650208,2.252868890762329,10000.0,11782.477632045746,12298.890912532806,11782.477632045746,514.2654712200165,0.9805140495300292,0.0 -34700,1.7223123,1.7602482,,,,,,,,,,,,,, -34800,2.0365844,1.5990106,,,,,,,,,,,,,, -34900,1.6045561,1.6542468,,,,,,,,,,,,,, -35000,1.5984986,1.6259842,,,,,,,,,,,,,, -35100,1.9453986,1.6582899,,,,,,,,,,,,,, -35200,1.6946201,1.7077261,,,,,,,,,,,,,, -35300,1.6624235,1.5907054,,,,,,,,,,,,,, -35400,1.8659524,1.7358803,,,,,,,,,,,,,, -35500,2.0451236,1.6738257,,,,,,,,,,,,,, -35600,2.0927505,1.7598732,,,,,,,,,,,,,, -35700,1.9419283,1.793203,,,,,,,,,,,,,, -35800,1.9174304,1.7141445,,,,,,,,,,,,,, -35900,1.9896129,1.706336,,,,,,,,,,,,,, -36000,1.7522495,1.7077136,,,,,,,,,,,,,, -36100,1.7559835,1.6542372,,,,,,,,,,,,,, -36200,1.675802,1.7580504,,,,,,,,,,,,,, -36208,,,0.6882174611091614,1.2332067489624023,0.6226599812507629,1.570199966430664,50000.0,0.4957000315189361,2.327148199081421,10000.0,12292.503982305529,12831.100484848022,12292.503982305529,536.3608770370483,1.0178706645965576,0.0 -36300,1.6290108,1.693816,,,,,,,,,,,,,, -36400,2.039879,1.7448531,,,,,,,,,,,,,, -36500,1.8366493,1.727591,,,,,,,,,,,,,, -36600,1.792847,1.7114365,,,,,,,,,,,,,, -36700,1.9588729,1.7308372,,,,,,,,,,,,,, -36800,1.9468944,1.72759,,,,,,,,,,,,,, -36900,1.7439281,1.715999,,,,,,,,,,,,,, -37000,1.8041468,1.7430505,,,,,,,,,,,,,, -37100,1.7561114,1.6628308,,,,,,,,,,,,,, -37200,1.7586311,1.635728,,,,,,,,,,,,,, -37300,1.9321364,1.6332009,,,,,,,,,,,,,, -37400,1.5692269,1.4506658,,,,,,,,,,,,,, -37500,1.7769091,1.6040113,,,,,,,,,,,,,, -37600,1.6737238,1.7047229,,,,,,,,,,,,,, -37700,2.0009356,1.8005102,,,,,,,,,,,,,, -37718,,,0.6916055083274841,1.2146902084350586,0.6141799688339233,1.6082170009613037,50000.0,0.4864000082015991,2.3108022212982178,10000.0,12802.65355682373,13360.524811983109,12802.65355682373,555.5504469871521,1.0512502193450928,0.0 -37800,2.252771,1.8966384,,,,,,,,,,,,,, -37900,2.009401,1.5889227,,,,,,,,,,,,,, -38000,2.0999575,1.6684089,,,,,,,,,,,,,, -38100,1.8346429,1.739705,,,,,,,,,,,,,, -38200,1.8277417,1.7337502,,,,,,,,,,,,,, -38300,1.8140199,1.6269679,,,,,,,,,,,,,, -38400,1.7821066,1.6717813,,,,,,,,,,,,,, -38500,1.7795471,1.5004183,,,,,,,,,,,,,, -38600,1.6669012,1.6192378,,,,,,,,,,,,,, -38700,1.9105822,1.689295,,,,,,,,,,,,,, -38800,1.6673281,1.7037165,,,,,,,,,,,,,, -38900,1.7279953,1.5842209,,,,,,,,,,,,,, -39000,1.6719478,1.5721961,,,,,,,,,,,,,, -39100,1.7019445,1.5309302,,,,,,,,,,,,,, -39200,1.7302061,1.7752314,,,,,,,,,,,,,, -39228,,,0.7033043503761292,1.156270980834961,0.6322799921035767,1.516471028327942,50000.0,0.5144000053405762,2.233076572418213,10000.0,13312.825942516329,13890.063405036926,13312.825942516329,574.8256969451904,1.0908870697021484,0.0 -39300,1.7277172,1.7098107,,,,,,,,,,,,,, -39400,1.7848756,1.7210481,,,,,,,,,,,,,, -39500,1.6815803,1.6840076,,,,,,,,,,,,,, -39600,1.7673572,1.7205039,,,,,,,,,,,,,, -39700,1.9115584,1.7305099,,,,,,,,,,,,,, -39800,1.8267156,1.6575351,,,,,,,,,,,,,, -39900,1.6974186,1.7080866,,,,,,,,,,,,,, -40000,1.9952015,1.6848284,,,,,,,,,,,,,, -40100,1.820532,1.5422632,,,,,,,,,,,,,, -40200,1.8187448,1.6725345,,,,,,,,,,,,,, -40300,1.7989808,1.737839,,,,,,,,,,,,,, -40400,1.7292097,1.6395802,,,,,,,,,,,,,, -40500,1.8402417,1.5639179,,,,,,,,,,,,,, -40600,1.6302727,1.6489731,,,,,,,,,,,,,, -40700,1.9396665,1.6423653,,,,,,,,,,,,,, -40737,,,0.698640763759613,1.1931952238082886,0.6297799944877625,1.5363556146621704,50000.0,0.5001000165939331,2.2545292377471924,10000.0,13822.809534072876,14422.052576303482,13822.809534072876,596.7445023059845,1.1273493766784668,0.0 -40800,1.9594685,1.7076793,,,,,,,,,,,,,, -40900,1.8604863,1.6576918,,,,,,,,,,,,,, -41000,1.7532798,1.6565111,,,,,,,,,,,,,, -41100,1.9345646,1.7245914,,,,,,,,,,,,,, -41200,1.7777971,1.7007394,,,,,,,,,,,,,, -41300,1.713852,1.6999615,,,,,,,,,,,,,, -41400,1.7224828,1.6745294,,,,,,,,,,,,,, -41500,1.7705204,1.7394631,,,,,,,,,,,,,, -41600,2.2268944,1.6595798,,,,,,,,,,,,,, -41700,2.0387468,1.7908326,,,,,,,,,,,,,, -41800,1.7799103,1.6043892,,,,,,,,,,,,,, -41900,1.7474456,1.6077076,,,,,,,,,,,,,, -42000,1.7341043,1.6271781,,,,,,,,,,,,,, -42100,2.02095,1.7636603,,,,,,,,,,,,,, -42200,1.7096403,1.5739901,,,,,,,,,,,,,, -42247,,,0.6972257494926453,1.1852744817733765,0.6369199752807617,1.5088731050491333,50000.0,0.515500009059906,2.202585220336914,10000.0,14333.007185459135,14957.263036727903,14333.007185459135,621.6768667697906,1.1569111347198486,0.0 -42300,1.8359053,1.6200116,,,,,,,,,,,,,, -42400,1.7703745,1.5690122,,,,,,,,,,,,,, -42500,1.8951184,1.6245493,,,,,,,,,,,,,, -42600,1.7634509,1.7562759,,,,,,,,,,,,,, -42700,1.873283,1.5441827,,,,,,,,,,,,,, -42800,1.8044264,1.5842702,,,,,,,,,,,,,, -42900,1.7063328,1.6424922,,,,,,,,,,,,,, -43000,2.2994773,1.6299282,,,,,,,,,,,,,, -43100,1.8221022,1.5985544,,,,,,,,,,,,,, -43200,1.9720418,1.7106035,,,,,,,,,,,,,, -43300,1.8378043,1.6642457,,,,,,,,,,,,,, -43400,1.8325174,1.6914405,,,,,,,,,,,,,, -43500,1.7878155,1.6555916,,,,,,,,,,,,,, -43600,1.7771066,1.6503184,,,,,,,,,,,,,, -43700,1.9715005,1.7482425,,,,,,,,,,,,,, -43757,,,0.6907684803009033,1.2137516736984253,0.6337599754333496,1.513280153274536,50000.0,0.5047000050544739,2.261016845703125,10000.0,14843.097437620165,15486.488491773604,14843.097437620165,640.7257559299469,1.1918201446533203,0.0 -43800,1.767679,1.6503019,,,,,,,,,,,,,, -43900,1.7426916,1.6642919,,,,,,,,,,,,,, -44000,1.8399348,1.6727483,,,,,,,,,,,,,, -44100,1.6037471,1.6726954,,,,,,,,,,,,,, -44200,1.7753708,1.6513315,,,,,,,,,,,,,, -44300,1.8071654,1.5543989,,,,,,,,,,,,,, -44400,1.7644874,1.6480267,,,,,,,,,,,,,, -44500,2.0887258,1.7651377,,,,,,,,,,,,,, -44600,1.6473517,1.651366,,,,,,,,,,,,,, -44700,1.6952776,1.5500978,,,,,,,,,,,,,, -44800,1.9939967,1.5654619,,,,,,,,,,,,,, -44900,1.7799861,1.7386746,,,,,,,,,,,,,, -45000,1.778256,1.6284565,,,,,,,,,,,,,, -45100,2.2300994,1.6740195,,,,,,,,,,,,,, -45200,1.9787798,1.6839418,,,,,,,,,,,,,, -45267,,,0.7050183415412903,1.1576184034347534,0.6333799958229065,1.5275466442108154,50000.0,0.5081000328063965,2.2523016929626465,10000.0,15353.215401887894,16014.435420036316,15353.215401887894,658.4644737243652,1.2304582595825195,0.0 -45300,1.6280663,1.6528281,,,,,,,,,,,,,, -45400,2.1077616,1.7196354,,,,,,,,,,,,,, -45500,1.9339894,1.690469,,,,,,,,,,,,,, -45600,1.7488849,1.6377223,,,,,,,,,,,,,, -45700,1.7505854,1.6294719,,,,,,,,,,,,,, -45800,1.7660711,1.5507677,,,,,,,,,,,,,, -45900,1.8751209,1.6642336,,,,,,,,,,,,,, -46000,1.7822208,1.6792858,,,,,,,,,,,,,, -46100,1.8726416,1.7238679,,,,,,,,,,,,,, -46200,1.8644325,1.7253413,,,,,,,,,,,,,, -46300,1.9225054,1.6909373,,,,,,,,,,,,,, -46400,1.8587077,1.7631397,,,,,,,,,,,,,, -46500,1.6629918,1.6650584,,,,,,,,,,,,,, -46600,2.0064597,1.6104946,,,,,,,,,,,,,, -46700,1.9163132,1.679878,,,,,,,,,,,,,, -46777,,,0.7170957922935486,1.097066044807434,0.6341800093650818,1.5051928758621216,50000.0,0.508400022983551,2.2467596530914307,10000.0,15863.195637464523,16542.35011291504,15863.195637464523,676.3059678077698,1.2712736129760742,0.0 -46800,1.9177414,1.5928043,,,,,,,,,,,,,, -46900,1.8302354,1.5836,,,,,,,,,,,,,, -47000,1.6472313,1.6538918,,,,,,,,,,,,,, -47100,1.8240206,1.7086703,,,,,,,,,,,,,, -47200,1.9269786,1.5679717,,,,,,,,,,,,,, -47300,1.8781618,1.759918,,,,,,,,,,,,,, -47400,1.7866176,1.5296456,,,,,,,,,,,,,, -47500,1.8551012,1.5492207,,,,,,,,,,,,,, -47600,1.880227,1.6733806,,,,,,,,,,,,,, -47700,2.2155285,1.7071849,,,,,,,,,,,,,, -47800,1.820588,1.484831,,,,,,,,,,,,,, -47900,1.934962,1.5896404,,,,,,,,,,,,,, -48000,1.9575833,1.6933097,,,,,,,,,,,,,, -48100,1.6750224,1.5874612,,,,,,,,,,,,,, -48200,1.8584492,1.6755152,,,,,,,,,,,,,, -48287,,,0.6975845098495483,1.179138422012329,0.620199978351593,1.5764381885528564,50000.0,0.5080000162124634,2.26758337020874,10000.0,16373.403130292892,17070.333611249924,16373.403130292892,693.9939639568329,1.3073816299438477,0.0 -48300,1.7522537,1.5510687,,,,,,,,,,,,,, -48400,1.7344123,1.5777847,,,,,,,,,,,,,, -48500,1.8531383,1.6209688,,,,,,,,,,,,,, -48600,1.864474,1.5822272,,,,,,,,,,,,,, -48700,1.732074,1.538962,,,,,,,,,,,,,, -48800,1.8535483,1.6375556,,,,,,,,,,,,,, -48900,2.0752413,1.6222417,,,,,,,,,,,,,, -49000,1.8821125,1.6270386,,,,,,,,,,,,,, -49100,2.0405009,1.6071254,,,,,,,,,,,,,, -49200,1.755229,1.6294715,,,,,,,,,,,,,, -49300,1.5917723,1.5919051,,,,,,,,,,,,,, -49400,2.0201645,1.5978065,,,,,,,,,,,,,, -49500,1.7403687,1.6250427,,,,,,,,,,,,,, -49600,1.9136873,1.5786268,,,,,,,,,,,,,, -49700,2.0402455,1.6429093,,,,,,,,,,,,,, -49798,,,0.7109972834587097,1.1291890144348145,0.6435399651527405,1.4746588468551636,50000.0,0.5197000503540039,2.1758978366851807,10000.0,16883.60203051567,17600.109800100327,16883.60203051567,713.4746758937836,1.3526594638824463,0.0 -49800,1.867466,1.5698688,,,,,,,,,,,,,, -49900,1.9514861,1.5703468,,,,,,,,,,,,,, -50000,2.0208602,1.7149978,,,,,,,,,,,,,, -50100,1.9140751,1.5378008,,,,,,,,,,,,,, -50200,2.1395383,1.7713537,,,,,,,,,,,,,, -50300,1.9726706,1.6620042,,,,,,,,,,,,,, -50400,2.1063135,1.478156,,,,,,,,,,,,,, -50500,1.8386077,1.6012312,,,,,,,,,,,,,, -50600,1.9725418,1.6678387,,,,,,,,,,,,,, -50700,2.0254788,1.5735459,,,,,,,,,,,,,, -50800,1.820372,1.5673256,,,,,,,,,,,,,, -50900,1.8827873,1.6001976,,,,,,,,,,,,,, -51000,1.9591266,1.7320172,,,,,,,,,,,,,, -51100,1.8483602,1.5887072,,,,,,,,,,,,,, -51200,1.7408051,1.4824269,,,,,,,,,,,,,, -51300,1.6900064,1.7468812,,,,,,,,,,,,,, -51308,,,0.7145049571990967,1.11211895942688,0.6514599919319153,1.4398891925811768,50000.0,0.5217000246047974,2.146376132965088,10000.0,17393.63419032097,18128.94772911072,17393.63419032097,732.192950963974,1.389472246170044,0.0 -51400,2.029097,1.6965544,,,,,,,,,,,,,, -51500,2.0012374,1.7095652,,,,,,,,,,,,,, -51600,1.7913293,1.7146963,,,,,,,,,,,,,, -51700,1.8268044,1.5708904,,,,,,,,,,,,,, -51800,1.8112545,1.6492848,,,,,,,,,,,,,, -51900,1.892094,1.6933881,,,,,,,,,,,,,, -52000,1.7155963,1.4733677,,,,,,,,,,,,,, -52100,2.0468485,1.6131672,,,,,,,,,,,,,, -52200,1.8273379,1.6661748,,,,,,,,,,,,,, -52300,2.1879156,1.6336656,,,,,,,,,,,,,, -52400,1.9266233,1.7687066,,,,,,,,,,,,,, -52500,1.8418143,1.5571297,,,,,,,,,,,,,, -52600,1.8578998,1.6408802,,,,,,,,,,,,,, -52700,1.874517,1.5770057,,,,,,,,,,,,,, -52800,1.7886808,1.4632524,,,,,,,,,,,,,, -52818,,,0.7062539458274841,1.140507698059082,0.6458199620246887,1.4497957229614258,50000.0,0.5200000405311584,2.168796300888061,10000.0,17903.583858013153,18657.767151117325,17903.583858013153,750.9723279476166,1.429046869277954,0.0 -52900,2.0166612,1.7296927,,,,,,,,,,,,,, -53000,2.0916076,1.5671326,,,,,,,,,,,,,, -53100,2.0085313,1.5809444,,,,,,,,,,,,,, -53200,2.125435,1.7250148,,,,,,,,,,,,,, -53300,1.7399822,1.5835334,,,,,,,,,,,,,, -53400,1.959204,1.495518,,,,,,,,,,,,,, -53500,1.8987181,1.5709693,,,,,,,,,,,,,, -53600,1.8736994,1.5818572,,,,,,,,,,,,,, -53700,1.825451,1.6188524,,,,,,,,,,,,,, -53800,1.7662919,1.6169645,,,,,,,,,,,,,, -53900,1.751166,1.5058708,,,,,,,,,,,,,, -54000,1.9001542,1.5991051,,,,,,,,,,,,,, -54100,1.9558458,1.6977202,,,,,,,,,,,,,, -54200,2.1205122,1.6661148,,,,,,,,,,,,,, -54300,2.306547,1.5701468,,,,,,,,,,,,,, -54329,,,0.7234932780265808,1.0733474493026731,0.6421399712562561,1.484086513519287,50000.0,0.5128999948501587,2.193343162536621,10000.0,18413.752345323563,19186.21563768387,18413.752345323563,769.1590249538422,1.4705872535705566,0.0 -54400,1.8150115,1.6615063,,,,,,,,,,,,,, -54500,1.8144217,1.5639228,,,,,,,,,,,,,, -54600,1.8661109,1.6244903,,,,,,,,,,,,,, -54700,1.8047577,1.6060137,,,,,,,,,,,,,, -54800,1.9036024,1.6202446,,,,,,,,,,,,,, -54900,1.8778878,1.4801384,,,,,,,,,,,,,, -55000,1.7857885,1.5189162,,,,,,,,,,,,,, -55100,1.8292515,1.560924,,,,,,,,,,,,,, -55200,1.8711827,1.5597069,,,,,,,,,,,,,, -55300,1.8244241,1.6358519,,,,,,,,,,,,,, -55400,1.9414614,1.7073423,,,,,,,,,,,,,, -55500,1.7432325,1.5671942,,,,,,,,,,,,,, -55600,1.822162,1.6688398,,,,,,,,,,,,,, -55700,1.6639408,1.4979032,,,,,,,,,,,,,, -55800,1.9886055,1.4395251,,,,,,,,,,,,,, -55839,,,0.7308075428009033,1.0367093086242676,0.6415799856185913,1.4657151699066162,50000.0,0.517300009727478,2.189545154571533,10000.0,18923.7833173275,19714.81046271324,18923.7833173275,787.6337478160858,1.508671760559082,0.0 -55900,1.7260658,1.6431487,,,,,,,,,,,,,, -56000,2.1284077,1.6811962,,,,,,,,,,,,,, -56100,1.9604607,1.6173102,,,,,,,,,,,,,, -56200,1.6923041,1.4304483,,,,,,,,,,,,,, -56300,1.8059443,1.583709,,,,,,,,,,,,,, -56400,2.0856912,1.6368334,,,,,,,,,,,,,, -56500,1.9888755,1.6023592,,,,,,,,,,,,,, -56600,1.9742833,1.5418167,,,,,,,,,,,,,, -56700,1.8844094,1.6618516,,,,,,,,,,,,,, -56800,1.8856773,1.555599,,,,,,,,,,,,,, -56900,1.8904722,1.5872505,,,,,,,,,,,,,, -57000,2.0962365,1.7013227,,,,,,,,,,,,,, -57100,2.0828912,1.6010004,,,,,,,,,,,,,, -57200,1.9658774,1.4825034,,,,,,,,,,,,,, -57300,2.0973554,1.5891986,,,,,,,,,,,,,, -57349,,,0.7161989808082581,1.0922472476959229,0.6412999629974365,1.4972940683364868,50000.0,0.5111000537872314,2.2348709106445312,10000.0,19433.82699513436,20242.517218351364,19433.82699513436,805.2087225914001,1.5456082820892334,0.0 -57400,1.9912565,1.584899,,,,,,,,,,,,,, -57500,1.8754185,1.6358649,,,,,,,,,,,,,, -57600,1.9143438,1.6096805,,,,,,,,,,,,,, -57700,2.0091147,1.6693383,,,,,,,,,,,,,, -57800,1.7577832,1.485954,,,,,,,,,,,,,, -57900,1.7743964,1.4912076,,,,,,,,,,,,,, -58000,1.9745908,1.4811561,,,,,,,,,,,,,, -58100,1.9954556,1.6375252,,,,,,,,,,,,,, -58200,2.080036,1.6169682,,,,,,,,,,,,,, -58300,1.8784658,1.5804639,,,,,,,,,,,,,, -58400,1.9815055,1.5466292,,,,,,,,,,,,,, -58500,1.8750451,1.5690317,,,,,,,,,,,,,, -58600,1.9047738,1.5153689,,,,,,,,,,,,,, -58700,1.7928398,1.4865729,,,,,,,,,,,,,, -58800,1.988202,1.6797798,,,,,,,,,,,,,, -58859,,,0.7106584906578064,1.1283915042877195,0.6424799561500549,1.481645584106445,50000.0,0.5232000350952148,2.208247423171997,10000.0,19943.966212511063,20770.65550327301,19943.966212511063,823.1132650375366,1.5886414051055908,0.0 -58900,1.8723166,1.5870318,,,,,,,,,,,,,, -59000,1.9848751,1.548733,,,,,,,,,,,,,, -59100,1.9270785,1.5968338,,,,,,,,,,,,,, -59200,1.7840928,1.5511764,,,,,,,,,,,,,, -59300,1.8724326,1.5791343,,,,,,,,,,,,,, -59400,1.812445,1.5239056,,,,,,,,,,,,,, -59500,1.778683,1.5044792,,,,,,,,,,,,,, -59600,1.9293321,1.5186158,,,,,,,,,,,,,, -59700,1.909164,1.5815868,,,,,,,,,,,,,, -59800,2.1695065,1.5691051,,,,,,,,,,,,,, -59900,1.8099929,1.5302037,,,,,,,,,,,,,, -60000,1.85226,1.519041,,,,,,,,,,,,,, -60100,1.825505,1.4836421,,,,,,,,,,,,,, -60200,1.9296054,1.6439244,,,,,,,,,,,,,, -60300,2.113171,1.5587168,,,,,,,,,,,,,, -60369,,,0.7179926633834839,1.0955880880355835,0.6552000045776367,1.42160165309906,50000.0,0.5243000388145447,2.127664804458618,10000.0,20453.887604236603,21298.15672469139,20453.887604236603,840.6038625240326,1.6262004375457764,0.0 -60400,1.8786023,1.4677254,,,,,,,,,,,,,, -60500,1.8975217,1.5061007,,,,,,,,,,,,,, -60600,1.8215008,1.486106,,,,,,,,,,,,,, -60700,2.0469167,1.6182783,,,,,,,,,,,,,, -60800,1.7658825,1.6193205,,,,,,,,,,,,,, -60900,1.9636256,1.5486275,,,,,,,,,,,,,, -61000,1.9737177,1.6156604,,,,,,,,,,,,,, -61100,1.8781352,1.6274055,,,,,,,,,,,,,, -61200,2.0003192,1.5213672,,,,,,,,,,,,,, -61300,1.920333,1.6650553,,,,,,,,,,,,,, -61400,1.8019346,1.5464591,,,,,,,,,,,,,, -61500,1.9485425,1.5615956,,,,,,,,,,,,,, -61600,2.0453687,1.5034602,,,,,,,,,,,,,, -61700,2.003582,1.5459566,,,,,,,,,,,,,, -61800,2.0508196,1.5281152,,,,,,,,,,,,,, -61880,,,0.7214205861091614,1.0676918029785156,0.6552199721336365,1.4101110696792605,50000.0,0.5288000106811523,2.1190402507781982,10000.0,20964.092700004578,21826.239169597626,20964.092700004578,858.3883543014526,1.667083978652954,0.0 -61900,1.879352,1.6522157,,,,,,,,,,,,,, -62000,1.8614016,1.6173899,,,,,,,,,,,,,, -62100,1.8361769,1.5133781,,,,,,,,,,,,,, -62200,2.165106,1.6112487,,,,,,,,,,,,,, -62300,1.967874,1.6690818,,,,,,,,,,,,,, -62400,2.222875,1.4829452,,,,,,,,,,,,,, -62500,1.760079,1.5645456,,,,,,,,,,,,,, -62600,1.9809847,1.6152252,,,,,,,,,,,,,, -62700,1.8940345,1.513588,,,,,,,,,,,,,, -62800,1.8781126,1.5601249,,,,,,,,,,,,,, -62900,1.8668795,1.5085909,,,,,,,,,,,,,, -63000,1.7837801,1.5819653,,,,,,,,,,,,,, -63100,1.6962923,1.4092643,,,,,,,,,,,,,, -63200,2.1912754,1.4943309,,,,,,,,,,,,,, -63300,1.9541527,1.5275654,,,,,,,,,,,,,, -63391,,,0.7526506781578064,0.9483557343482972,0.6526399850845337,1.41191303730011,50000.0,0.5279000401496887,2.121713399887085,10000.0,21474.092804193497,22354.6533973217,21474.092804193497,876.7133748531342,1.7043514251708984,0.0 -63400,1.7381068,1.4727739,,,,,,,,,,,,,, -63500,1.8859617,1.6188595,,,,,,,,,,,,,, -63600,1.9481431,1.5482514,,,,,,,,,,,,,, -63700,1.8897738,1.6586081,,,,,,,,,,,,,, -63800,1.8027002,1.5079939,,,,,,,,,,,,,, -63900,2.0206826,1.4973141,,,,,,,,,,,,,, -64000,2.167825,1.6882935,,,,,,,,,,,,,, -64100,1.8483188,1.5057067,,,,,,,,,,,,,, -64200,1.9177544,1.4923723,,,,,,,,,,,,,, -64300,1.8964206,1.5456594,,,,,,,,,,,,,, -64400,1.8784407,1.5393366,,,,,,,,,,,,,, -64500,1.9844457,1.6628537,,,,,,,,,,,,,, -64600,1.870933,1.4795444,,,,,,,,,,,,,, -64700,1.8949356,1.5025021,,,,,,,,,,,,,, -64800,1.9672779,1.5932165,,,,,,,,,,,,,, -64900,1.9656408,1.4406049,,,,,,,,,,,,,, -64901,,,0.7511758208274841,0.947900414466858,0.6609999537467957,1.3837952613830566,50000.0,0.5368000268936157,2.108168363571167,10000.0,21984.03058242798,22883.24925875664,21984.03058242798,895.2762496471405,1.7477846145629885,0.0 -65000,1.8558017,1.5098867,,,,,,,,,,,,,, -65100,1.8402073,1.5743726,,,,,,,,,,,,,, -65200,1.9755559,1.423976,,,,,,,,,,,,,, -65300,2.0158055,1.4011667,,,,,,,,,,,,,, -65400,2.0506532,1.4854115,,,,,,,,,,,,,, -65500,2.2029183,1.5358639,,,,,,,,,,,,,, -65600,1.8371226,1.4790261,,,,,,,,,,,,,, -65700,2.0995288,1.572542,,,,,,,,,,,,,, -65800,2.1346703,1.6185013,,,,,,,,,,,,,, -65900,1.8843719,1.5535139,,,,,,,,,,,,,, -66000,1.9641894,1.6345887,,,,,,,,,,,,,, -66100,1.8886521,1.5681189,,,,,,,,,,,,,, -66200,1.8563259,1.5623515,,,,,,,,,,,,,, -66300,2.1533237,1.6701498,,,,,,,,,,,,,, -66400,1.9957116,1.392433,,,,,,,,,,,,,, -66412,,,0.7248086333274841,1.053501844406128,0.6454199552536011,1.4561376571655271,50000.0,0.5273000001907349,2.121701955795288,10000.0,22494.092856168747,23410.470444202423,22494.092856168747,912.3423182964324,1.7881202697753906,0.0 -66500,1.9879428,1.5747061,,,,,,,,,,,,,, -66600,1.923818,1.4548879,,,,,,,,,,,,,, -66700,1.9723163,1.5678884,,,,,,,,,,,,,, -66800,1.8544655,1.5085855,,,,,,,,,,,,,, -66900,1.9031123,1.435703,,,,,,,,,,,,,, -67000,1.9117249,1.4955246,,,,,,,,,,,,,, -67100,1.9030411,1.4952064,,,,,,,,,,,,,, -67200,2.0331087,1.5977942,,,,,,,,,,,,,, -67300,1.9253894,1.5227349,,,,,,,,,,,,,, -67400,1.9936779,1.5281628,,,,,,,,,,,,,, -67500,2.1679244,1.52144,,,,,,,,,,,,,, -67600,2.1863923,1.7122982,,,,,,,,,,,,,, -67700,1.9667609,1.589575,,,,,,,,,,,,,, -67800,2.1193035,1.5694174,,,,,,,,,,,,,, -67900,1.9011626,1.4807935,,,,,,,,,,,,,, -67923,,,0.7313655614852905,1.0242691040039062,0.6623199582099915,1.3837313652038574,50000.0,0.5348000526428223,2.116585493087769,10000.0,23004.167016267776,23937.85494709015,23004.167016267776,929.560394525528,1.828833818435669,0.0 -68000,1.7887273,1.494933,,,,,,,,,,,,,, -68100,1.9969896,1.5569495,,,,,,,,,,,,,, -68200,1.9539721,1.46125,,,,,,,,,,,,,, -68300,2.088003,1.4811711,,,,,,,,,,,,,, -68400,1.9785699,1.5859784,,,,,,,,,,,,,, -68500,2.0949504,1.5288981,,,,,,,,,,,,,, -68600,2.0138583,1.5914865,,,,,,,,,,,,,, -68700,1.8916862,1.6002202,,,,,,,,,,,,,, -68800,2.0963612,1.3670743,,,,,,,,,,,,,, -68900,1.933576,1.4480709,,,,,,,,,,,,,, -69000,2.0820699,1.467475,,,,,,,,,,,,,, -69100,2.035829,1.5394945,,,,,,,,,,,,,, -69200,1.9140636,1.44888,,,,,,,,,,,,,, -69300,2.1414585,1.5019076,,,,,,,,,,,,,, -69400,2.2281163,1.568754,,,,,,,,,,,,,, -69433,,,0.7354910373687744,1.011197805404663,0.6612199544906616,1.3969271183013916,50000.0,0.5420000553131104,2.1048424243927,10000.0,23514.20720410347,24465.004062891006,23514.20720410347,946.576143026352,1.868699073791504,0.0 -69500,2.0962796,1.5367663,,,,,,,,,,,,,, -69600,2.0990577,1.6724812,,,,,,,,,,,,,, -69700,2.1247833,1.5577427,,,,,,,,,,,,,, -69800,1.9801983,1.4701504,,,,,,,,,,,,,, -69900,2.0976987,1.5519861,,,,,,,,,,,,,, -70000,2.1078453,1.4850155,,,,,,,,,,,,,, -70100,1.8527454,1.5585938,,,,,,,,,,,,,, -70200,2.0706267,1.4579098,,,,,,,,,,,,,, -70300,2.1381643,1.4526619,,,,,,,,,,,,,, -70400,2.0212154,1.4351034,,,,,,,,,,,,,, -70500,2.004909,1.5705998,,,,,,,,,,,,,, -70600,2.1447697,1.533513,,,,,,,,,,,,,, -70700,1.9748812,1.5543288,,,,,,,,,,,,,, -70800,2.0471606,1.5026846,,,,,,,,,,,,,, -70900,1.968992,1.5222749,,,,,,,,,,,,,, -70944,,,0.7195272445678711,1.0805131196975708,0.6513800024986267,1.4209266901016235,50000.0,0.5263000130653381,2.148491382598877,10000.0,24024.3210310936,24992.387575387955,24024.3210310936,963.7557821273804,1.906718254089356,0.0 -71000,2.1389167,1.6032646,,,,,,,,,,,,,, -71100,1.8832757,1.3925002,,,,,,,,,,,,,, -71200,1.8872066,1.4277233,,,,,,,,,,,,,, -71300,2.3053536,1.5171353,,,,,,,,,,,,,, -71400,2.0921555,1.6431241,,,,,,,,,,,,,, -71500,2.079323,1.5384934,,,,,,,,,,,,,, -71600,2.0155492,1.5834924,,,,,,,,,,,,,, -71700,1.9129664,1.4166985,,,,,,,,,,,,,, -71800,2.0173507,1.4953775,,,,,,,,,,,,,, -71900,1.9785918,1.5301878,,,,,,,,,,,,,, -72000,2.0902038,1.523776,,,,,,,,,,,,,, -72100,2.286257,1.6059245,,,,,,,,,,,,,, -72200,1.8574339,1.4276819,,,,,,,,,,,,,, -72300,2.2776537,1.5411897,,,,,,,,,,,,,, -72400,2.1492834,1.4851011,,,,,,,,,,,,,, -72455,,,0.7664819955825806,0.8874098658561707,0.6557999849319458,1.4078059196472168,50000.0,0.5333000421524048,2.126142978668213,10000.0,24534.373774051663,25520.007241487503,24534.373774051663,981.2310724258424,1.946107625961304,0.0 -72500,2.1013756,1.5183624,,,,,,,,,,,,,, -72600,2.050086,1.5525064,,,,,,,,,,,,,, -72700,1.9397349,1.4950273,,,,,,,,,,,,,, -72800,2.0352192,1.527005,,,,,,,,,,,,,, -72900,2.010668,1.5822414,,,,,,,,,,,,,, -73000,2.2004817,1.4680567,,,,,,,,,,,,,, -73100,2.0919595,1.5483279,,,,,,,,,,,,,, -73200,2.1516104,1.4729424,,,,,,,,,,,,,, -73300,2.2184346,1.6792501,,,,,,,,,,,,,, -73400,2.1653686,1.568413,,,,,,,,,,,,,, -73500,2.0766516,1.4373225,,,,,,,,,,,,,, -73600,2.2453735,1.4764495,,,,,,,,,,,,,, -73700,2.0777383,1.5075015,,,,,,,,,,,,,, -73800,2.1722794,1.4850143,,,,,,,,,,,,,, -73900,2.0816467,1.5578994,,,,,,,,,,,,,, -73966,,,0.7506775856018066,0.9532562494277954,0.6530999541282654,1.414706826210022,50000.0,0.5290000438690186,2.135497093200684,10000.0,25044.457715272903,26047.58855366707,25044.457715272903,998.6352922916412,1.987243413925171,0.0 -74000,2.0850043,1.6355247,,,,,,,,,,,,,, -74100,1.9774694,1.4682661,,,,,,,,,,,,,, -74200,2.0606222,1.3902626,,,,,,,,,,,,,, -74300,2.4140375,1.5468649,,,,,,,,,,,,,, -74400,2.109592,1.4764398,,,,,,,,,,,,,, -74500,1.9193884,1.4300089,,,,,,,,,,,,,, -74600,2.0302122,1.5667154,,,,,,,,,,,,,, -74700,2.039922,1.3439112,,,,,,,,,,,,,, -74800,1.9611223,1.5050972,,,,,,,,,,,,,, -74900,1.9620483,1.3710115,,,,,,,,,,,,,, -75000,2.0527933,1.4912691,,,,,,,,,,,,,, -75100,2.063584,1.6141299,,,,,,,,,,,,,, -75200,2.1502895,1.4673622,,,,,,,,,,,,,, -75300,2.2619567,1.3790464,,,,,,,,,,,,,, -75400,1.8730915,1.4403318,,,,,,,,,,,,,, -75476,,,0.7411909699440002,0.9859520196914672,0.6597200036048889,1.3939448595046997,50000.0,0.5326000452041626,2.133927345275879,10000.0,25554.508782863617,26575.4754896164,25554.508782863617,1016.3779957294464,2.028724908828736,0.0 -75500,2.2099469,1.4963233,,,,,,,,,,,,,, -75600,2.1883247,1.5029372,,,,,,,,,,,,,, -75700,2.005534,1.4197823,,,,,,,,,,,,,, -75800,1.9662894,1.4316947,,,,,,,,,,,,,, -75900,2.182726,1.4755219,,,,,,,,,,,,,, -76000,2.1300185,1.53908,,,,,,,,,,,,,, -76100,2.075971,1.4559994,,,,,,,,,,,,,, -76200,2.0797186,1.4004343,,,,,,,,,,,,,, -76300,2.232121,1.4881821,,,,,,,,,,,,,, -76400,1.925059,1.4618256,,,,,,,,,,,,,, -76500,2.134112,1.4468356,,,,,,,,,,,,,, -76600,2.1258252,1.5031509,,,,,,,,,,,,,, -76700,2.1368635,1.4437342,,,,,,,,,,,,,, -76800,2.350569,1.5026497,,,,,,,,,,,,,, -76900,2.2901359,1.5019816,,,,,,,,,,,,,, -76986,,,0.7378627061843872,0.9981990456581116,0.6623799800872803,1.3914191722869873,50000.0,0.5375000238418579,2.1226179599761963,10000.0,26064.464957475662,27102.67097759247,26064.464957475662,1033.5246062278748,2.0691707134246826,0.0 -77000,2.1832757,1.4958575,,,,,,,,,,,,,, -77100,2.020928,1.4929805,,,,,,,,,,,,,, -77200,2.139544,1.4582125,,,,,,,,,,,,,, -77300,2.0500486,1.3660557,,,,,,,,,,,,,, -77400,2.0623727,1.5509385,,,,,,,,,,,,,, -77500,2.0337267,1.4949117,,,,,,,,,,,,,, -77600,2.194452,1.5712273,,,,,,,,,,,,,, -77700,2.254498,1.4123138,,,,,,,,,,,,,, -77800,2.152646,1.4684423,,,,,,,,,,,,,, -77900,2.0000517,1.4328501,,,,,,,,,,,,,, -78000,2.1419556,1.6007212,,,,,,,,,,,,,, -78100,2.24114,1.5519631,,,,,,,,,,,,,, -78200,2.0258088,1.4993573,,,,,,,,,,,,,, -78300,2.15466,1.4849253,,,,,,,,,,,,,, -78400,2.223189,1.4325314,,,,,,,,,,,,,, -78494,,,0.7308474183082581,1.024781346321106,0.6592999696731567,1.396071195602417,50000.0,0.5348000526428223,2.101693630218506,10000.0,26573.513592481613,27629.85961341858,26573.513592481613,1050.6655519008636,3.017031192779541,0.0 -78500,2.1168165,1.440265,,,,,,,,,,,,,, -78600,2.1318693,1.4497124,,,,,,,,,,,,,, -78700,2.1576083,1.493122,,,,,,,,,,,,,, -78800,1.9920533,1.4992316,,,,,,,,,,,,,, -78900,2.00211,1.474389,,,,,,,,,,,,,, -79000,2.0524566,1.4991974,,,,,,,,,,,,,, -79100,2.132174,1.5117958,,,,,,,,,,,,,, -79200,2.1339748,1.4761373,,,,,,,,,,,,,, -79300,2.078384,1.4658446,,,,,,,,,,,,,, -79400,2.2036765,1.509901,,,,,,,,,,,,,, -79500,2.1221113,1.4994721,,,,,,,,,,,,,, -79600,2.0478477,1.4490294,,,,,,,,,,,,,, -79700,2.1857374,1.4905075,,,,,,,,,,,,,, -79800,2.1015317,1.547905,,,,,,,,,,,,,, -79900,2.168685,1.4474026,,,,,,,,,,,,,, -80000,2.0903792,1.4503822,,,,,,,,,,,,,, -80005,,,0.73148512840271,1.0369644165039062,0.6642000079154968,1.3688702583312988,50000.0,0.5323000550270081,2.090766191482544,10000.0,27083.48101949692,28156.81354093552,27083.48101949692,1067.5592651367188,3.058248281478882,0.0 -80100,2.2040458,1.3312135,,,,,,,,,,,,,, -80200,2.2598045,1.4203023,,,,,,,,,,,,,, -80300,2.2945652,1.4577564,,,,,,,,,,,,,, -80400,2.0838313,1.4013844,,,,,,,,,,,,,, -80500,2.0601897,1.4272122,,,,,,,,,,,,,, -80600,2.2499366,1.4365815,,,,,,,,,,,,,, -80700,2.3746216,1.4243164,,,,,,,,,,,,,, -80800,2.0862505,1.3828665,,,,,,,,,,,,,, -80900,2.1671135,1.4343216,,,,,,,,,,,,,, -81000,2.075273,1.4931538,,,,,,,,,,,,,, -81100,2.267901,1.4608558,,,,,,,,,,,,,, -81200,1.9015073,1.3615022,,,,,,,,,,,,,, -81300,2.3515406,1.5688384,,,,,,,,,,,,,, -81400,2.273854,1.5711339,,,,,,,,,,,,,, -81500,2.34432,1.5359145,,,,,,,,,,,,,, -81515,,,0.7939453125,0.7794730067253113,0.6765599846839905,1.3186352252960205,50000.0,0.5491999983787537,2.030189275741577,10000.0,27593.4535381794,28684.019721269608,27593.4535381794,1084.7018973827362,3.098496198654175,0.0 -81600,2.1456676,1.5033007,,,,,,,,,,,,,, -81700,2.1165934,1.4026966,,,,,,,,,,,,,, -81800,2.2011936,1.5129697,,,,,,,,,,,,,, -81900,2.04665,1.424099,,,,,,,,,,,,,, -82000,2.3597918,1.3258767,,,,,,,,,,,,,, -82100,2.0240662,1.527077,,,,,,,,,,,,,, -82200,1.9307811,1.496914,,,,,,,,,,,,,, -82300,2.1023135,1.363213,,,,,,,,,,,,,, -82400,2.146645,1.4074153,,,,,,,,,,,,,, -82500,2.0716262,1.426873,,,,,,,,,,,,,, -82600,2.0930433,1.4365295,,,,,,,,,,,,,, -82700,2.0867636,1.4486948,,,,,,,,,,,,,, -82800,2.3745804,1.5076323,,,,,,,,,,,,,, -82900,2.060197,1.4751992,,,,,,,,,,,,,, -83000,2.1064157,1.3837407,,,,,,,,,,,,,, -83025,,,0.7580117583274841,0.917476773262024,0.6669999957084656,1.3600958585739136,50000.0,0.5421000123023987,2.0674638748168945,10000.0,28103.538105487823,29211.65815377236,28103.538105487823,1102.1596643924713,3.142584085464477,0.0 -83100,2.1397526,1.4293818,,,,,,,,,,,,,, -83200,2.2329416,1.4667724,,,,,,,,,,,,,, -83300,2.042064,1.5287913,,,,,,,,,,,,,, -83400,2.1860058,1.4591436,,,,,,,,,,,,,, -83500,2.0952742,1.419093,,,,,,,,,,,,,, -83600,2.4214354,1.481568,,,,,,,,,,,,,, -83700,2.1922143,1.442039,,,,,,,,,,,,,, -83800,2.2849455,1.473649,,,,,,,,,,,,,, -83900,2.2315903,1.4585341,,,,,,,,,,,,,, -84000,2.3822253,1.4730129,,,,,,,,,,,,,, -84100,2.2581885,1.4906039,,,,,,,,,,,,,, -84200,2.1711133,1.4797491,,,,,,,,,,,,,, -84300,2.1123075,1.339057,,,,,,,,,,,,,, -84400,2.1557362,1.4318233,,,,,,,,,,,,,, -84500,2.1177888,1.3975642,,,,,,,,,,,,,, -84537,,,0.75390625,0.9180837273597716,0.6711199879646301,1.3440886735916138,50000.0,0.5378000140190125,2.0711793899536133,10000.0,28613.781172037125,29739.06396436692,28613.781172037125,1119.2302613258362,3.1828737258911133,0.0 -84600,2.183335,1.491099,,,,,,,,,,,,,, -84700,2.133655,1.5033579,,,,,,,,,,,,,, -84800,2.2028892,1.5041434,,,,,,,,,,,,,, -84900,2.115295,1.415982,,,,,,,,,,,,,, -85000,2.232502,1.3942648,,,,,,,,,,,,,, -85100,2.1462224,1.3928404,,,,,,,,,,,,,, -85200,2.1514795,1.5172623,,,,,,,,,,,,,, -85300,2.3516965,1.5444825,,,,,,,,,,,,,, -85400,2.4945416,1.467897,,,,,,,,,,,,,, -85500,2.2279193,1.4350673,,,,,,,,,,,,,, -85600,2.1810958,1.4658889,,,,,,,,,,,,,, -85700,2.3674183,1.4137537,,,,,,,,,,,,,, -85800,2.259125,1.4758558,,,,,,,,,,,,,, -85900,2.3700004,1.4508466,,,,,,,,,,,,,, -86000,2.6762097,1.4890404,,,,,,,,,,,,,, -86048,,,0.7520527839660645,0.9375916719436646,0.670799970626831,1.3429946899414062,50000.0,0.5491000413894653,2.0667781829833984,10000.0,29123.81152534485,30265.995129585262,29123.81152534485,1136.0348734855652,3.2278828620910645,0.0 -86100,2.185536,1.400596,,,,,,,,,,,,,, -86200,2.3639486,1.3720653,,,,,,,,,,,,,, -86300,2.1804247,1.3247899,,,,,,,,,,,,,, -86400,2.1799238,1.40488,,,,,,,,,,,,,, -86500,2.323375,1.4655795,,,,,,,,,,,,,, -86600,2.0562444,1.4338896,,,,,,,,,,,,,, -86700,2.1690164,1.4354935,,,,,,,,,,,,,, -86800,2.2707548,1.4767003,,,,,,,,,,,,,, -86900,2.2469487,1.3973756,,,,,,,,,,,,,, -87000,2.3299527,1.514744,,,,,,,,,,,,,, -87100,2.1979847,1.475323,,,,,,,,,,,,,, -87200,2.2500598,1.4541852,,,,,,,,,,,,,, -87300,2.467149,1.4806765,,,,,,,,,,,,,, -87400,2.4135864,1.5156885,,,,,,,,,,,,,, -87500,2.1689022,1.4465115,,,,,,,,,,,,,, -87559,,,0.7551020383834839,0.9199984073638916,0.6786999702453613,1.3077771663665771,50000.0,0.5520000457763672,2.0441277027130127,10000.0,29633.89668869972,30793.2783703804,29633.89668869972,1153.1339037418363,3.2748494148254395,0.0 -87600,2.301245,1.4195151,,,,,,,,,,,,,, -87700,2.2029808,1.4364383,,,,,,,,,,,,,, -87800,2.3041646,1.4539692,,,,,,,,,,,,,, -87900,2.0586123,1.3647084,,,,,,,,,,,,,, -88000,2.6282394,1.4465752,,,,,,,,,,,,,, -88100,2.3527126,1.5472807,,,,,,,,,,,,,, -88200,2.5196433,1.3772948,,,,,,,,,,,,,, -88300,2.1733508,1.4455669,,,,,,,,,,,,,, -88400,2.3613636,1.5085996,,,,,,,,,,,,,, -88500,2.1563034,1.407181,,,,,,,,,,,,,, -88600,2.3725634,1.4015816,,,,,,,,,,,,,, -88700,2.4389563,1.4826186,,,,,,,,,,,,,, -88800,2.175974,1.3442161,,,,,,,,,,,,,, -88900,2.1794295,1.3964885,,,,,,,,,,,,,, -89000,2.584467,1.4454149,,,,,,,,,,,,,, -89070,,,0.7512954473495483,0.9448497891426086,0.6762999892234802,1.3257722854614258,50000.0,0.5527000427246094,2.053436517715454,10000.0,30144.08286547661,31320.67922115326,30144.08286547661,1170.2514476776123,3.320532321929932,0.0 -89100,2.1202946,1.3401814,,,,,,,,,,,,,, -89200,2.4402604,1.5474061,,,,,,,,,,,,,, -89300,2.190522,1.4455053,,,,,,,,,,,,,, -89400,2.257546,1.4345565,,,,,,,,,,,,,, -89500,2.1229913,1.372346,,,,,,,,,,,,,, -89600,2.1838374,1.4243206,,,,,,,,,,,,,, -89700,2.532521,1.4646474,,,,,,,,,,,,,, -89800,2.0941594,1.5502353,,,,,,,,,,,,,, -89900,2.223487,1.3096802,,,,,,,,,,,,,, -90000,2.2767277,1.4559339,,,,,,,,,,,,,, -90100,2.2160294,1.319622,,,,,,,,,,,,,, -90200,2.2563717,1.3829193,,,,,,,,,,,,,, -90300,2.1903076,1.3444351,,,,,,,,,,,,,, -90400,2.3105097,1.3520169,,,,,,,,,,,,,, -90500,2.231322,1.4119841,,,,,,,,,,,,,, -90582,,,0.8028140664100647,0.7457603812217712,0.6789199709892273,1.2876888513565063,50000.0,0.5542000532150269,1.987893223762512,10000.0,30654.26819229126,31848.25612831116,30654.26819229126,1187.541890144348,3.3698208332061768,0.0 -90600,2.2960558,1.334184,,,,,,,,,,,,,, -90700,2.1690977,1.4644268,,,,,,,,,,,,,, -90800,2.3262258,1.4159456,,,,,,,,,,,,,, -90900,2.1743288,1.296401,,,,,,,,,,,,,, -91000,2.1981106,1.3610513,,,,,,,,,,,,,, -91100,2.5957427,1.3402301,,,,,,,,,,,,,, -91200,2.315253,1.402038,,,,,,,,,,,,,, -91300,2.128104,1.4016179,,,,,,,,,,,,,, -91400,2.5567033,1.4006025,,,,,,,,,,,,,, -91500,2.172975,1.4772426,,,,,,,,,,,,,, -91600,2.3199496,1.2710015,,,,,,,,,,,,,, -91700,2.6726413,1.414685,,,,,,,,,,,,,, -91800,2.651228,1.5589968,,,,,,,,,,,,,, -91900,2.2083614,1.3718116,,,,,,,,,,,,,, -92000,2.2006645,1.3278295,,,,,,,,,,,,,, -92095,,,0.7693120241165161,0.856544017791748,0.6754800081253052,1.3307112455368042,50000.0,0.5468000173568726,2.067296504974365,10000.0,31164.4939289093,32375.592450141907,31164.4939289093,1204.5566387176514,3.4134774208068848,0.0 -92100,2.241331,1.4152635,,,,,,,,,,,,,, -92200,2.2660756,1.3108644,,,,,,,,,,,,,, -92300,2.1902099,1.3154367,,,,,,,,,,,,,, -92400,2.261938,1.3366345,,,,,,,,,,,,,, -92500,2.499656,1.3999377,,,,,,,,,,,,,, -92600,2.2586026,1.3109243,,,,,,,,,,,,,, -92700,2.3378456,1.3306869,,,,,,,,,,,,,, -92800,2.2706947,1.4040979,,,,,,,,,,,,,, -92900,2.4697537,1.4546125,,,,,,,,,,,,,, -93000,2.5406923,1.3720868,,,,,,,,,,,,,, -93100,2.2426817,1.4177512,,,,,,,,,,,,,, -93200,2.2973866,1.301717,,,,,,,,,,,,,, -93300,2.3921144,1.4270039,,,,,,,,,,,,,, -93400,2.2569911,1.4415319,,,,,,,,,,,,,, -93500,2.5091703,1.4576409,,,,,,,,,,,,,, -93600,2.318091,1.3676994,,,,,,,,,,,,,, -93607,,,0.76566481590271,0.876564621925354,0.6783599853515625,1.3167260885238647,50000.0,0.5529000163078308,2.0210089683532715,10000.0,31674.687499523163,32903.412647247314,31674.687499523163,1222.084743976593,3.460716009140014,0.0 -93700,2.3607056,1.4283407,,,,,,,,,,,,,, -93800,2.7781427,1.466775,,,,,,,,,,,,,, -93900,2.244577,1.2999812,,,,,,,,,,,,,, -94000,2.412564,1.4451026,,,,,,,,,,,,,, -94100,2.3457794,1.3771738,,,,,,,,,,,,,, -94200,2.3119395,1.4631827,,,,,,,,,,,,,, -94300,2.3237967,1.2513888,,,,,,,,,,,,,, -94400,2.1503232,1.3026882,,,,,,,,,,,,,, -94500,2.2757907,1.388154,,,,,,,,,,,,,, -94600,2.192455,1.3326366,,,,,,,,,,,,,, -94700,2.178586,1.3049865,,,,,,,,,,,,,, -94800,2.310628,1.3716017,,,,,,,,,,,,,, -94900,2.2325373,1.4743701,,,,,,,,,,,,,, -95000,2.877894,1.4566737,,,,,,,,,,,,,, -95100,2.5325954,1.4363687,,,,,,,,,,,,,, -95118,,,0.7731783986091614,0.8373833894729614,0.6893999576568604,1.2670496702194214,50000.0,0.5593000054359436,1.996338963508606,10000.0,32184.654997348785,33430.718074798584,32184.654997348785,1239.3243083953855,3.507728338241577,0.0 -95200,2.2639332,1.354156,,,,,,,,,,,,,, -95300,2.606212,1.4298916,,,,,,,,,,,,,, -95400,2.281938,1.2773701,,,,,,,,,,,,,, -95500,2.3246677,1.3176104,,,,,,,,,,,,,, -95600,2.2624235,1.2141804,,,,,,,,,,,,,, -95700,2.222397,1.4071217,,,,,,,,,,,,,, -95800,2.6407704,1.5048225,,,,,,,,,,,,,, -95900,2.3667355,1.3960557,,,,,,,,,,,,,, -96000,2.2779598,1.394459,,,,,,,,,,,,,, -96100,2.3378847,1.3967175,,,,,,,,,,,,,, -96200,2.426786,1.3275207,,,,,,,,,,,,,, -96300,2.3071344,1.4233028,,,,,,,,,,,,,, -96400,2.3157153,1.41364,,,,,,,,,,,,,, -96500,2.4206073,1.331523,,,,,,,,,,,,,, -96600,2.2504354,1.2811818,,,,,,,,,,,,,, -96629,,,0.756257951259613,0.9207841157913208,0.6714000105857849,1.3359733819961548,50000.0,0.5505000352859497,2.031044006347656,10000.0,32694.792423963547,33958.081731557846,32694.792423963547,1256.4522774219513,3.5546374320983887,0.0 -96700,2.401539,1.3624976,,,,,,,,,,,,,, -96800,2.573468,1.4975231,,,,,,,,,,,,,, -96900,3.111574,1.3660504,,,,,,,,,,,,,, -97000,2.3870728,1.309902,,,,,,,,,,,,,, -97100,2.713386,1.5086111,,,,,,,,,,,,,, -97200,2.6296887,1.4193482,,,,,,,,,,,,,, -97300,2.329653,1.3882495,,,,,,,,,,,,,, -97400,2.4644203,1.4107322,,,,,,,,,,,,,, -97500,2.3812468,1.3002853,,,,,,,,,,,,,, -97600,2.361434,1.3706409,,,,,,,,,,,,,, -97700,2.4568949,1.3747565,,,,,,,,,,,,,, -97800,2.4645438,1.4736615,,,,,,,,,,,,,, -97900,2.3890803,1.3834795,,,,,,,,,,,,,, -98000,2.5124333,1.3497837,,,,,,,,,,,,,, -98100,2.4835725,1.3735331,,,,,,,,,,,,,, -98140,,,0.7623963356018066,0.890619695186615,0.6840199828147888,1.299177885055542,50000.0,0.5555000305175781,2.0382800102233887,10000.0,33204.81808638573,34485.51674103737,33204.81808638573,1273.7611339092257,3.6036319732666016,0.0 -98200,2.2233038,1.3318758,,,,,,,,,,,,,, -98300,2.4227753,1.3417045,,,,,,,,,,,,,, -98400,2.2687812,1.2958138,,,,,,,,,,,,,, -98500,2.4574065,1.396687,,,,,,,,,,,,,, -98600,2.4874966,1.3779795,,,,,,,,,,,,,, -98700,2.6491623,1.3728753,,,,,,,,,,,,,, -98800,2.6239674,1.4325876,,,,,,,,,,,,,, -98900,2.654301,1.2943878,,,,,,,,,,,,,, -99000,2.4799666,1.3177683,,,,,,,,,,,,,, -99100,2.398362,1.3501018,,,,,,,,,,,,,, -99200,2.338117,1.3473969,,,,,,,,,,,,,, -99300,2.2750986,1.3089619,,,,,,,,,,,,,, -99400,2.3064604,1.3086317,,,,,,,,,,,,,, -99500,2.4255383,1.3604774,,,,,,,,,,,,,, -99600,2.4386752,1.3118622,,,,,,,,,,,,,, -99651,,,0.807059109210968,0.7231523990631104,0.6853199601173401,1.2836941480636597,50000.0,0.5587000250816345,1.994417428970337,10000.0,33714.866651535034,35012.723252773285,33714.866651535034,1290.823139667511,3.64790391921997,0.0 -99700,2.6106508,1.429633,,,,,,,,,,,,,, -99800,2.8501124,1.474016,,,,,,,,,,,,,, -99900,2.5755615,1.4169058,,,,,,,,,,,,,, -100000,2.617731,1.3323007,,,,,,,,,,,,,, -100100,2.4730825,1.3963468,,,,,,,,,,,,,, -100200,2.599841,1.3288479,,,,,,,,,,,,,, -100300,2.4726062,1.3678578,,,,,,,,,,,,,, -100400,2.567966,1.3640643,,,,,,,,,,,,,, -100500,2.802126,1.4433948,,,,,,,,,,,,,, -100600,2.4178026,1.3705704,,,,,,,,,,,,,, -100700,2.4331026,1.2897744,,,,,,,,,,,,,, -100800,2.4764369,1.392723,,,,,,,,,,,,,, -100900,2.687376,1.2937384,,,,,,,,,,,,,, -101000,2.4400983,1.4419702,,,,,,,,,,,,,, -101100,2.26242,1.309097,,,,,,,,,,,,,, -101162,,,0.7729192972183228,0.8433666229248047,0.6765599846839905,1.3290696144104004,50000.0,0.5457000136375427,2.0852365493774414,10000.0,34224.85311436653,35539.94133090973,34224.85311436653,1307.958990097046,3.690621852874756,0.0 -101200,2.3931692,1.4172447,,,,,,,,,,,,,, -101300,2.3691704,1.2893759,,,,,,,,,,,,,, -101400,2.5725245,1.323585,,,,,,,,,,,,,, -101500,2.9795814,1.4119167,,,,,,,,,,,,,, -101600,2.466725,1.2553543,,,,,,,,,,,,,, -101700,2.5566235,1.308425,,,,,,,,,,,,,, -101800,2.45757,1.335237,,,,,,,,,,,,,, -101900,2.3459542,1.3647789,,,,,,,,,,,,,, -102000,2.5446475,1.4091493,,,,,,,,,,,,,, -102100,2.2003329,1.1772311,,,,,,,,,,,,,, -102200,2.45983,1.3729043,,,,,,,,,,,,,, -102300,2.5327332,1.3682853,,,,,,,,,,,,,, -102400,2.4623132,1.3792801,,,,,,,,,,,,,, -102500,2.5322058,1.3643388,,,,,,,,,,,,,, -102600,2.4904666,1.2862277,,,,,,,,,,,,,, -102673,,,0.7886638641357422,0.7821224331855774,0.6950399875640869,1.240763783454895,50000.0,0.5649000406265259,1.9514639377594,10000.0,34734.81702041626,36067.20191526413,34734.81702041626,1325.1586797237396,3.736107349395752,0.0 -102700,2.5425513,1.2443782,,,,,,,,,,,,,, -102800,2.657394,1.3943077,,,,,,,,,,,,,, -102900,2.4315364,1.3295497,,,,,,,,,,,,,, -103000,2.6237044,1.300917,,,,,,,,,,,,,, -103100,2.4949267,1.3559158,,,,,,,,,,,,,, -103200,2.3551319,1.3087497,,,,,,,,,,,,,, -103300,2.5694525,1.3890427,,,,,,,,,,,,,, -103400,2.5521433,1.318803,,,,,,,,,,,,,, -103500,2.7953825,1.2910305,,,,,,,,,,,,,, -103600,2.370627,1.2578362,,,,,,,,,,,,,, -103700,2.454809,1.25629,,,,,,,,,,,,,, -103800,2.6361477,1.4331025,,,,,,,,,,,,,, -103900,2.3562639,1.2537639,,,,,,,,,,,,,, -104000,2.4110394,1.2848366,,,,,,,,,,,,,, -104100,2.6069028,1.4433094,,,,,,,,,,,,,, -104184,,,0.7867307066917419,0.7857255339622498,0.6966599822044373,1.2365820407867432,50000.0,0.5725000500679016,1.948253273963928,10000.0,35245.00524544716,36594.41708111763,35245.00524544716,1342.0841455459597,3.785308361053467,0.0 -104200,2.687708,1.3440616,,,,,,,,,,,,,, -104300,2.6535745,1.2760987,,,,,,,,,,,,,, -104400,2.3729877,1.2183645,,,,,,,,,,,,,, -104500,2.5658488,1.3310473,,,,,,,,,,,,,, -104600,2.5128126,1.3161654,,,,,,,,,,,,,, -104700,2.5311172,1.3222182,,,,,,,,,,,,,, -104800,2.3908775,1.202034,,,,,,,,,,,,,, -104900,2.4481514,1.3607979,,,,,,,,,,,,,, -105000,2.5872452,1.2708011,,,,,,,,,,,,,, -105100,2.6493287,1.2418023,,,,,,,,,,,,,, -105200,2.5725524,1.3692632,,,,,,,,,,,,,, -105300,2.406677,1.2720028,,,,,,,,,,,,,, -105400,2.9011476,1.2459128,,,,,,,,,,,,,, -105500,2.5492122,1.2551216,,,,,,,,,,,,,, -105600,2.5409813,1.3267353,,,,,,,,,,,,,, -105695,,,0.7801339030265808,0.8069247603416443,0.6951000094413757,1.2463161945343018,50000.0,0.5662000179290771,1.96476411819458,10000.0,35755.16663765907,37121.79448270798,35755.16663765907,1359.2032098770142,3.831068515777588,0.0 -105700,2.5052168,1.352359,,,,,,,,,,,,,, -105800,2.7236667,1.3524865,,,,,,,,,,,,,, -105900,2.7412126,1.4383831,,,,,,,,,,,,,, -106000,2.6757338,1.3017573,,,,,,,,,,,,,, -106100,2.6925192,1.3888075,,,,,,,,,,,,,, -106200,2.8674064,1.2275059,,,,,,,,,,,,,, -106300,2.5664482,1.2998083,,,,,,,,,,,,,, -106400,2.62704,1.2884405,,,,,,,,,,,,,, -106500,2.2949107,1.273142,,,,,,,,,,,,,, -106600,2.5694802,1.3114159,,,,,,,,,,,,,, -106700,2.454967,1.2977554,,,,,,,,,,,,,, -106800,2.6375108,1.394093,,,,,,,,,,,,,, -106900,2.5742404,1.2838012,,,,,,,,,,,,,, -107000,2.881935,1.3491876,,,,,,,,,,,,,, -107100,2.5849752,1.3399497,,,,,,,,,,,,,, -107200,2.6601255,1.3743795,,,,,,,,,,,,,, -107206,,,0.7763273119926453,0.8324143290519714,0.6882799863815308,1.2644940614700315,50000.0,0.554900050163269,2.0364840030670166,10000.0,36265.1814968586,37649.1190032959,36265.1814968586,1376.4114780426023,3.880370616912842,0.0 -107300,2.7113416,1.4441032,,,,,,,,,,,,,, -107400,2.7949743,1.4344068,,,,,,,,,,,,,, -107500,2.8915446,1.3426023,,,,,,,,,,,,,, -107600,2.5842516,1.2539644,,,,,,,,,,,,,, -107700,2.8304992,1.3687956,,,,,,,,,,,,,, -107800,2.6789792,1.3060017,,,,,,,,,,,,,, -107900,2.4281523,1.2078514,,,,,,,,,,,,,, -108000,2.4158905,1.2831246,,,,,,,,,,,,,, -108100,2.8124928,1.3217659,,,,,,,,,,,,,, -108200,2.5427094,1.290436,,,,,,,,,,,,,, -108300,2.3901298,1.2840999,,,,,,,,,,,,,, -108400,2.9343522,1.2807673,,,,,,,,,,,,,, -108500,2.4052727,1.1436914,,,,,,,,,,,,,, -108600,2.8164992,1.3146188,,,,,,,,,,,,,, -108700,2.6973023,1.1797514,,,,,,,,,,,,,, -108717,,,0.8256935477256775,0.6464024782180786,0.6990599632263184,1.226111888885498,50000.0,0.5611000061035156,1.955988407135009,10000.0,36775.30914545059,38176.23004317284,36775.30914545059,1393.2988758087158,3.924722671508789,0.0 -108800,2.853484,1.2539537,,,,,,,,,,,,,, -108900,2.6310236,1.3766897,,,,,,,,,,,,,, -109000,2.6378517,1.2622396,,,,,,,,,,,,,, -109100,2.4467587,1.2863445,,,,,,,,,,,,,, -109200,2.9311643,1.2935787,,,,,,,,,,,,,, -109300,2.4744878,1.1236842,,,,,,,,,,,,,, -109400,2.8081894,1.3136716,,,,,,,,,,,,,, -109500,2.4772902,1.3133475,,,,,,,,,,,,,, -109600,2.7709665,1.2654599,,,,,,,,,,,,,, -109700,2.5557184,1.3254057,,,,,,,,,,,,,, -109800,2.5728092,1.1675537,,,,,,,,,,,,,, -109900,2.5511727,1.3903241,,,,,,,,,,,,,, -110000,2.7905197,1.2783104,,,,,,,,,,,,,, -110100,2.621077,1.2521518,,,,,,,,,,,,,, -110200,2.6239161,1.3658171,,,,,,,,,,,,,, -110229,,,0.8019770383834839,0.727361261844635,0.6938199996948242,1.250580072402954,50000.0,0.5652000308036804,1.968800067901612,10000.0,37285.4779791832,38703.67038869858,37285.4779791832,1410.4655022621157,3.9767305850982666,0.0 -110300,2.9485626,1.353384,,,,,,,,,,,,,, -110400,2.3178258,1.2115333,,,,,,,,,,,,,, -110500,2.751929,1.3533889,,,,,,,,,,,,,, -110600,2.6952844,1.3549641,,,,,,,,,,,,,, -110700,2.5388255,1.223265,,,,,,,,,,,,,, -110800,2.5353453,1.2564867,,,,,,,,,,,,,, -110900,2.5824604,1.2022455,,,,,,,,,,,,,, -111000,2.737979,1.2611554,,,,,,,,,,,,,, -111100,2.672346,1.249638,,,,,,,,,,,,,, -111200,2.784814,1.3371964,,,,,,,,,,,,,, -111300,2.541805,1.2755603,,,,,,,,,,,,,, -111400,2.6058965,1.3038536,,,,,,,,,,,,,, -111500,2.556255,1.2022411,,,,,,,,,,,,,, -111600,2.716242,1.2269561,,,,,,,,,,,,,, -111700,2.662406,1.2086972,,,,,,,,,,,,,, -111739,,,0.80765700340271,0.6981836557388306,0.700939953327179,1.21205735206604,50000.0,0.5719000101089478,1.935737371444702,10000.0,37795.405812978745,39230.77417373657,37795.405812978745,1427.5435523986816,4.02356743812561,0.0 -111800,2.8641598,1.3555541,,,,,,,,,,,,,, -111900,2.70007,1.2984257,,,,,,,,,,,,,, -112000,2.7356882,1.2042042,,,,,,,,,,,,,, -112100,2.6563253,1.2062047,,,,,,,,,,,,,, -112200,2.5343099,1.2229784,,,,,,,,,,,,,, -112300,2.9501445,1.285615,,,,,,,,,,,,,, -112400,2.645991,1.2338655,,,,,,,,,,,,,, -112500,2.6733248,1.2682389,,,,,,,,,,,,,, -112600,2.6095893,1.2632412,,,,,,,,,,,,,, -112700,2.5217912,1.1891278,,,,,,,,,,,,,, -112800,2.6800334,1.2137115,,,,,,,,,,,,,, -112900,2.8964424,1.3036505,,,,,,,,,,,,,, -113000,2.79517,1.2949541,,,,,,,,,,,,,, -113100,2.7747297,1.2322662,,,,,,,,,,,,,, -113200,2.8056881,1.2685934,,,,,,,,,,,,,, -113250,,,0.8006218075752258,0.7307600378990173,0.7016599774360657,1.2059155702590942,50000.0,0.5667999982833862,1.932706356048584,10000.0,38305.49462866783,39757.9245569706,38305.49462866783,1444.5057072639463,4.071113586425781,0.0 -113300,2.6782792,1.2517774,,,,,,,,,,,,,, -113400,2.7273355,1.3220448,,,,,,,,,,,,,, -113500,2.6011345,1.1936867,,,,,,,,,,,,,, -113600,2.8576388,1.3156871,,,,,,,,,,,,,, -113700,2.9311874,1.273541,,,,,,,,,,,,,, -113800,3.0279248,1.2387635,,,,,,,,,,,,,, -113900,2.8607657,1.2394491,,,,,,,,,,,,,, -114000,2.8476315,1.2325824,,,,,,,,,,,,,, -114100,2.8957334,1.2500262,,,,,,,,,,,,,, -114200,2.911218,1.228189,,,,,,,,,,,,,, -114300,2.933576,1.2530116,,,,,,,,,,,,,, -114400,2.8278725,1.3306437,,,,,,,,,,,,,, -114500,2.886223,1.2660768,,,,,,,,,,,,,, -114600,2.9453957,1.1784755,,,,,,,,,,,,,, -114700,2.6850533,1.3161803,,,,,,,,,,,,,, -114760,,,0.7948819994926453,0.7492088079452515,0.6979399919509888,1.2191317081451416,50000.0,0.5746999979019165,1.917339324951172,10000.0,38815.42671918869,40284.924193143845,38815.42671918869,1461.475175857544,4.1177613735198975,0.0 -114800,2.6993184,1.1888374,,,,,,,,,,,,,, -114900,2.6576304,1.2359276,,,,,,,,,,,,,, -115000,2.7121034,1.2353942,,,,,,,,,,,,,, -115100,2.6939394,1.2729229,,,,,,,,,,,,,, -115200,2.855966,1.2206144,,,,,,,,,,,,,, -115300,2.8511603,1.1960826,,,,,,,,,,,,,, -115400,3.0600924,1.3044298,,,,,,,,,,,,,, -115500,3.00847,1.3586965,,,,,,,,,,,,,, -115600,2.9104023,1.227293,,,,,,,,,,,,,, -115700,2.8632567,1.3167859,,,,,,,,,,,,,, -115800,2.5570939,1.1488749,,,,,,,,,,,,,, -115900,2.9210453,1.2551304,,,,,,,,,,,,,, -116000,2.8042607,1.1745458,,,,,,,,,,,,,, -116100,2.9208887,1.3062265,,,,,,,,,,,,,, -116200,2.9052336,1.2516402,,,,,,,,,,,,,, -116270,,,0.7964963316917419,0.7375555038452148,0.7020999789237976,1.2137975692749023,50000.0,0.5814000368118286,1.9192442893981927,10000.0,39325.39078187943,40812.15231490135,39325.39078187943,1478.638622522354,4.166247367858887,0.0 -116300,3.0181904,1.2498102,,,,,,,,,,,,,, -116400,2.9822097,1.116538,,,,,,,,,,,,,, -116500,3.0621336,1.2033126,,,,,,,,,,,,,, -116600,2.7416892,1.1443311,,,,,,,,,,,,,, -116700,2.8978472,1.1969724,,,,,,,,,,,,,, -116800,2.7747524,1.2431905,,,,,,,,,,,,,, -116900,3.014801,1.2320929,,,,,,,,,,,,,, -117000,2.6060631,1.2203188,,,,,,,,,,,,,, -117100,2.8048975,1.2863337,,,,,,,,,,,,,, -117200,2.721048,1.1264808,,,,,,,,,,,,,, -117300,3.237779,1.3089218,,,,,,,,,,,,,, -117400,2.7293687,1.1849847,,,,,,,,,,,,,, -117500,2.94827,1.1983641,,,,,,,,,,,,,, -117600,2.7996962,1.1524256,,,,,,,,,,,,,, -117700,3.0571322,1.2882779,,,,,,,,,,,,,, -117781,,,0.8391461968421936,0.5893431305885315,0.702299952507019,1.215503215789795,50000.0,0.5736000537872314,1.9403326511383057,10000.0,39835.441056251526,41339.66882777214,39835.441056251526,1496.0060350894928,4.21298360824585,0.0 -117800,2.9014509,1.1581129,,,,,,,,,,,,,, -117900,2.8368247,1.2097276,,,,,,,,,,,,,, -118000,2.880146,1.2111604,,,,,,,,,,,,,, -118100,3.0232472,1.2204682,,,,,,,,,,,,,, -118200,3.0604103,1.2253937,,,,,,,,,,,,,, -118300,3.1231582,1.1891371,,,,,,,,,,,,,, -118400,2.9190657,1.1701708,,,,,,,,,,,,,, -118500,2.8660977,1.1790472,,,,,,,,,,,,,, -118600,2.760709,1.1606793,,,,,,,,,,,,,, -118700,2.9622025,1.2651286,,,,,,,,,,,,,, -118800,2.6787226,1.1637888,,,,,,,,,,,,,, -118900,2.9361358,1.2081157,,,,,,,,,,,,,, -119000,2.8503625,1.1676838,,,,,,,,,,,,,, -119100,2.9893236,1.1777755,,,,,,,,,,,,,, -119200,2.91778,1.2380433,,,,,,,,,,,,,, -119292,,,0.8243981003761292,0.6250141263008118,0.7067999839782715,1.1909725666046145,50000.0,0.5779000520706177,1.933236002922058,10000.0,40345.60301208496,41867.19692969322,40345.60301208496,1513.2711565494535,4.26213812828064,0.0 -119300,2.8442004,1.0595227,,,,,,,,,,,,,, -119400,2.9270606,1.274801,,,,,,,,,,,,,, -119500,2.8046327,1.1849357,,,,,,,,,,,,,, -119600,2.8585575,1.2031134,,,,,,,,,,,,,, -119700,2.664564,1.186494,,,,,,,,,,,,,, -119800,3.0157437,1.2200434,,,,,,,,,,,,,, -119900,3.1144855,1.1870956,,,,,,,,,,,,,, -120000,2.6968856,1.1451422,,,,,,,,,,,,,, -120100,2.9297132,1.1979085,,,,,,,,,,,,,, -120200,2.8712218,1.2353746,,,,,,,,,,,,,, -120300,2.9235349,1.2391286,,,,,,,,,,,,,, -120400,2.916894,1.1435959,,,,,,,,,,,,,, -120500,3.1355038,1.0941277,,,,,,,,,,,,,, -120600,2.8793848,1.2301692,,,,,,,,,,,,,, -120700,2.9326441,1.1582792,,,,,,,,,,,,,, -120800,2.8091822,1.1758542,,,,,,,,,,,,,, -120803,,,0.8103874325752258,0.6829017996788025,0.7016199827194214,1.216865062713623,50000.0,0.5700000524520874,1.950163722038269,10000.0,40855.583112478256,42394.367864370346,40855.583112478256,1530.355375289917,4.317422151565552,0.0 -120900,2.930164,1.1391735,,,,,,,,,,,,,, -121000,2.9404924,1.2289506,,,,,,,,,,,,,, -121100,3.199249,1.2638463,,,,,,,,,,,,,, -121200,2.9966493,1.1999397,,,,,,,,,,,,,, -121300,2.8444653,1.1394371,,,,,,,,,,,,,, -121400,2.9053874,1.1823637,,,,,,,,,,,,,, -121500,2.838484,1.115864,,,,,,,,,,,,,, -121600,2.9354,1.2042772,,,,,,,,,,,,,, -121700,2.7255023,1.0857182,,,,,,,,,,,,,, -121800,2.8080842,1.1535096,,,,,,,,,,,,,, -121900,3.026197,1.1539571,,,,,,,,,,,,,, -122000,3.1437693,1.1649671,,,,,,,,,,,,,, -122100,2.8391953,1.1264709,,,,,,,,,,,,,, -122200,2.8870945,1.0970614,,,,,,,,,,,,,, -122300,2.943004,1.2331797,,,,,,,,,,,,,, -122314,,,0.8113440275192261,0.6753456592559814,0.7064399719238281,1.2009263038635254,50000.0,0.5781000256538391,1.9371180534362795,10000.0,41365.70845270157,42921.472259521484,41365.70845270157,1547.233092069626,4.367457628250122,0.0 -122400,2.985395,1.1294053,,,,,,,,,,,,,, -122500,3.0327735,1.1718106,,,,,,,,,,,,,, -122600,3.130691,1.3196726,,,,,,,,,,,,,, -122700,3.1506941,1.1492196,,,,,,,,,,,,,, -122800,2.839038,1.0478263,,,,,,,,,,,,,, -122900,3.125719,1.1899832,,,,,,,,,,,,,, -123000,2.9597294,1.1227722,,,,,,,,,,,,,, -123100,3.197289,1.264209,,,,,,,,,,,,,, -123200,2.9146678,1.0756929,,,,,,,,,,,,,, -123300,3.2572675,1.2468257,,,,,,,,,,,,,, -123400,2.9434705,1.1949064,,,,,,,,,,,,,, -123500,2.9553723,1.1102954,,,,,,,,,,,,,, -123600,3.2247634,1.2766745,,,,,,,,,,,,,, -123700,3.3450377,1.1857145,,,,,,,,,,,,,, -123800,3.3124647,1.165164,,,,,,,,,,,,,, -123825,,,0.8163862824440002,0.6607792973518372,0.7098199725151062,1.1880916357040403,50000.0,0.5855000019073486,1.910176515579224,10000.0,41875.71085715294,43448.74796462059,41875.71085715294,1564.4061267375946,4.415964365005493,0.0 -123900,3.4855049,1.1849635,,,,,,,,,,,,,, -124000,3.1875668,1.1518387,,,,,,,,,,,,,, -124100,2.9715273,1.1302679,,,,,,,,,,,,,, -124200,2.9137669,1.0716584,,,,,,,,,,,,,, -124300,2.8767333,1.0137681,,,,,,,,,,,,,, -124400,2.975008,1.211474,,,,,,,,,,,,,, -124500,3.154743,1.1387515,,,,,,,,,,,,,, -124600,3.3138192,1.2147208,,,,,,,,,,,,,, -124700,2.9322202,1.087541,,,,,,,,,,,,,, -124800,2.9648244,1.1772114,,,,,,,,,,,,,, -124900,3.131748,1.1508203,,,,,,,,,,,,,, -125000,2.9650736,1.038246,,,,,,,,,,,,,, -125100,3.011819,1.0734849,,,,,,,,,,,,,, -125200,2.9200034,1.0332199,,,,,,,,,,,,,, -125300,3.1959045,1.1297268,,,,,,,,,,,,,, -125336,,,0.8089126348495483,0.6776189208030701,0.7065799832344055,1.1986366510391235,50000.0,0.5805000066757202,1.9390764236450195,10000.0,42385.60908794403,43975.89997935295,42385.60908794403,1581.5561335086825,4.467309474945068,0.0 -125400,3.0248456,1.1625333,,,,,,,,,,,,,, -125500,3.4985123,1.1277657,,,,,,,,,,,,,, -125600,2.872399,1.0860769,,,,,,,,,,,,,, -125700,2.9874225,1.2003024,,,,,,,,,,,,,, -125800,3.0203004,1.0392274,,,,,,,,,,,,,, -125900,3.449068,1.1489419,,,,,,,,,,,,,, -126000,3.0721369,1.1077231,,,,,,,,,,,,,, -126100,3.0404828,1.0970436,,,,,,,,,,,,,, -126200,3.0837562,1.2013568,,,,,,,,,,,,,, -126300,2.9260247,1.0866983,,,,,,,,,,,,,, -126400,2.9954474,1.1532447,,,,,,,,,,,,,, -126500,2.8886852,1.0513024,,,,,,,,,,,,,, -126600,3.0701509,1.1283417,,,,,,,,,,,,,, -126700,3.2666924,1.0912039,,,,,,,,,,,,,, -126800,3.3921063,1.2913239,,,,,,,,,,,,,, -126847,,,0.8555285334587097,0.5188804268836975,0.7141599655151367,1.1573100090026855,50000.0,0.5924000144004822,1.8532426357269287,10000.0,42895.51242017746,44503.13484168053,42895.51242017746,1598.7869279384613,4.516381502151489,0.0 -126900,3.0218596,1.01686,,,,,,,,,,,,,, -127000,3.000757,1.0561123,,,,,,,,,,,,,, -127100,3.094686,1.2213793,,,,,,,,,,,,,, -127200,3.1430888,1.0512133,,,,,,,,,,,,,, -127300,3.210975,1.145015,,,,,,,,,,,,,, -127400,3.283797,1.13099,,,,,,,,,,,,,, -127500,3.3278906,1.1332586,,,,,,,,,,,,,, -127600,3.4280298,1.2196183,,,,,,,,,,,,,, -127700,3.284178,1.1490996,,,,,,,,,,,,,, -127800,3.446706,1.1981976,,,,,,,,,,,,,, -127900,3.1851892,1.224893,,,,,,,,,,,,,, -128000,2.8872695,1.0307093,,,,,,,,,,,,,, -128100,3.01458,1.161352,,,,,,,,,,,,,, -128200,3.5093818,1.1160859,,,,,,,,,,,,,, -128300,3.3898714,1.1622797,,,,,,,,,,,,,, -128357,,,0.8460618257522583,0.5457704067230225,0.7174400091171265,1.141788363456726,50000.0,0.5907000303268433,1.86967408657074,10000.0,43405.469561100006,45030.29795050621,43405.469561100006,1615.8899295330048,4.567685127258301,0.0 -128400,3.5412526,1.0802922,,,,,,,,,,,,,, -128500,2.9421656,1.0343577,,,,,,,,,,,,,, -128600,3.3151312,1.1307853,,,,,,,,,,,,,, -128700,3.3279033,1.124778,,,,,,,,,,,,,, -128800,3.6078894,1.2078252,,,,,,,,,,,,,, -128900,3.0479321,1.1348914,,,,,,,,,,,,,, -129000,3.268108,1.1182852,,,,,,,,,,,,,, -129100,3.2952807,1.1270616,,,,,,,,,,,,,, -129200,3.1833246,1.0145252,,,,,,,,,,,,,, -129300,3.1252322,1.0955601,,,,,,,,,,,,,, -129400,3.5120573,1.0826223,,,,,,,,,,,,,, -129500,3.1488209,1.1196386,,,,,,,,,,,,,, -129600,3.612187,1.1618997,,,,,,,,,,,,,, -129700,2.8761816,1.0409777,,,,,,,,,,,,,, -129800,3.3138998,1.0699056,,,,,,,,,,,,,, -129868,,,0.833426296710968,0.5857810974121094,0.7130599617958069,1.1719478368759155,50000.0,0.5879999995231628,1.8958600759506223,10000.0,43915.47292876244,45557.60752797127,43915.47292876244,1633.0848352909088,4.626312971115112,0.0 -129900,3.4672806,1.1097924,,,,,,,,,,,,,, -130000,3.2713716,1.0790579,,,,,,,,,,,,,, -130100,3.1905925,1.0461689,,,,,,,,,,,,,, -130200,2.8849368,1.0930052,,,,,,,,,,,,,, -130300,3.0980868,1.006651,,,,,,,,,,,,,, -130400,3.2245839,1.1231364,,,,,,,,,,,,,, -130500,3.0875165,1.0826699,,,,,,,,,,,,,, -130600,3.4608364,1.1429396,,,,,,,,,,,,,, -130700,3.0653844,0.97756803,,,,,,,,,,,,,, -130800,3.2263124,1.0819314,,,,,,,,,,,,,, -130900,2.9747882,1.0218908,,,,,,,,,,,,,, -131000,3.2392101,1.109566,,,,,,,,,,,,,, -131100,3.25315,1.0538721,,,,,,,,,,,,,, -131200,3.1794553,1.0542061,,,,,,,,,,,,,, -131300,3.4659326,0.985975,,,,,,,,,,,,,, -131380,,,0.8418765664100647,0.564940869808197,0.720579981803894,1.134805679321289,50000.0,0.5936000347137451,1.8518024682998653,10000.0,44425.701722860336,46085.02875685692,44425.701722860336,1650.1739237308502,4.677285432815552,0.0 -131400,3.2238803,1.0707463,,,,,,,,,,,,,, -131500,3.4722538,1.1421958,,,,,,,,,,,,,, -131600,3.1606443,1.1018363,,,,,,,,,,,,,, -131700,3.2265503,0.9803885,,,,,,,,,,,,,, -131800,3.8664143,1.0677016,,,,,,,,,,,,,, -131900,3.291297,1.0653698,,,,,,,,,,,,,, -132000,3.788709,1.1685166,,,,,,,,,,,,,, -132100,3.3344274,1.1520243,,,,,,,,,,,,,, -132200,3.4400723,1.0782584,,,,,,,,,,,,,, -132300,3.17518,1.0298989,,,,,,,,,,,,,, -132400,3.4188917,1.0434206,,,,,,,,,,,,,, -132500,3.5075495,1.1584582,,,,,,,,,,,,,, -132600,3.1177256,1.0285523,,,,,,,,,,,,,, -132700,3.3876765,1.0228757,,,,,,,,,,,,,, -132800,3.1428556,1.0190835,,,,,,,,,,,,,, -132891,,,0.8443080186843872,0.5527390241622925,0.7223599553108215,1.1370505094528198,50000.0,0.5984000563621521,1.8685274124145508,10000.0,44935.64200139046,46612.3599793911,44935.64200139046,1667.4603853225708,4.729793787002564,0.0 -132900,3.5733452,1.0718175,,,,,,,,,,,,,, -133000,3.3004124,1.0250854,,,,,,,,,,,,,, -133100,3.4092507,1.0203602,,,,,,,,,,,,,, -133200,3.2822876,1.0895611,,,,,,,,,,,,,, -133300,3.3692725,1.0689764,,,,,,,,,,,,,, -133400,3.2810779,1.0194672,,,,,,,,,,,,,, -133500,3.3600185,1.0807905,,,,,,,,,,,,,, -133600,3.5021334,1.0880888,,,,,,,,,,,,,, -133700,3.3135269,1.0784513,,,,,,,,,,,,,, -133800,3.563382,1.0072637,,,,,,,,,,,,,, -133900,3.5379198,1.096021,,,,,,,,,,,,,, -134000,3.20316,0.9696149,,,,,,,,,,,,,, -134100,3.3216398,1.0487093,,,,,,,,,,,,,, -134200,3.4750562,1.0500953,,,,,,,,,,,,,, -134300,3.8527207,1.095111,,,,,,,,,,,,,, -134400,3.3079848,1.0049644,,,,,,,,,,,,,, -134402,,,0.8469586968421936,0.5429478883743286,0.722599983215332,1.1435267925262451,50000.0,0.5957000255584717,1.8864117860794067,10000.0,45445.54392409325,47139.51445841789,45445.54392409325,1684.610284090042,4.779802560806274,0.0 -134500,3.3307104,0.9914248,,,,,,,,,,,,,, -134600,3.2553296,1.0286001,,,,,,,,,,,,,, -134700,3.783343,1.1277966,,,,,,,,,,,,,, -134800,3.3749013,1.1051164,,,,,,,,,,,,,, -134900,3.6677296,1.067064,,,,,,,,,,,,,, -135000,3.890593,1.2043607,,,,,,,,,,,,,, -135100,3.4764473,1.0981001,,,,,,,,,,,,,, -135200,3.6901355,1.0211401,,,,,,,,,,,,,, -135300,3.1354213,1.0322381,,,,,,,,,,,,,, -135400,3.580347,0.9942867,,,,,,,,,,,,,, -135500,3.403652,1.025566,,,,,,,,,,,,,, -135600,3.689888,1.1453888,,,,,,,,,,,,,, -135700,3.242946,1.0507056,,,,,,,,,,,,,, -135800,3.4480445,0.97316337,,,,,,,,,,,,,, -135900,3.5688362,1.024262,,,,,,,,,,,,,, -135913,,,0.879902720451355,0.428302139043808,0.7254799604415894,1.13906729221344,50000.0,0.5994000434875488,1.854590654373169,10000.0,45955.56926059723,47666.84624814987,45955.56926059723,1701.8160552978516,4.829279661178589,0.0 -136000,3.323379,0.9393049,,,,,,,,,,,,,, -136100,3.4852402,1.0284485,,,,,,,,,,,,,, -136200,3.201883,0.97311276,,,,,,,,,,,,,, -136300,3.2809854,1.0297531,,,,,,,,,,,,,, -136400,3.3193724,0.9453462,,,,,,,,,,,,,, -136500,3.7809174,1.141076,,,,,,,,,,,,,, -136600,3.406758,1.0779557,,,,,,,,,,,,,, -136700,3.7708154,1.061503,,,,,,,,,,,,,, -136800,3.4643786,0.9788989,,,,,,,,,,,,,, -136900,3.6136544,1.0432231,,,,,,,,,,,,,, -137000,3.7245471,0.9997653,,,,,,,,,,,,,, -137100,3.6105292,1.0409855,,,,,,,,,,,,,, -137200,3.3452482,1.0085135,,,,,,,,,,,,,, -137300,3.2370741,1.0052981,,,,,,,,,,,,,, -137400,3.5126822,1.0390158,,,,,,,,,,,,,, -137424,,,0.8622847199440002,0.485957384109497,0.7216599583625793,1.1468485593795776,50000.0,0.5951000452041626,1.876573920249939,10000.0,46465.56084442139,48193.87060427666,46465.56084442139,1718.7475936412811,4.878940105438232,0.0 -137500,3.3621979,1.007857,,,,,,,,,,,,,, -137600,4.1705303,1.0640429,,,,,,,,,,,,,, -137700,3.4650433,1.0343493,,,,,,,,,,,,,, -137800,3.8070116,1.0492388,,,,,,,,,,,,,, -137900,3.7192993,1.0744847,,,,,,,,,,,,,, -138000,3.400322,1.0512434,,,,,,,,,,,,,, -138100,3.5482337,1.0221038,,,,,,,,,,,,,, -138200,3.3488476,1.0708425,,,,,,,,,,,,,, -138300,3.6362321,1.0434514,,,,,,,,,,,,,, -138400,3.6632597,1.0691497,,,,,,,,,,,,,, -138500,3.1490448,0.8713118,,,,,,,,,,,,,, -138600,3.8204024,1.0047349,,,,,,,,,,,,,, -138700,3.7989738,1.0264844,,,,,,,,,,,,,, -138800,3.475012,1.0540289,,,,,,,,,,,,,, -138900,3.4943018,1.0249584,,,,,,,,,,,,,, -138935,,,0.8673469424247742,0.4645864367485046,0.7269600033760071,1.113786220550537,50000.0,0.6015000343322754,1.8418588638305664,10000.0,46975.57577776909,48720.94380235672,46975.57577776909,1735.6987011432648,4.933518886566162,0.0 -139000,3.5661824,1.0109708,,,,,,,,,,,,,, -139100,3.5895114,0.9821348,,,,,,,,,,,,,, -139200,3.9499393,1.067155,,,,,,,,,,,,,, -139300,3.8771887,1.077783,,,,,,,,,,,,,, -139400,3.7129583,1.0011225,,,,,,,,,,,,,, -139500,3.750558,1.0098286,,,,,,,,,,,,,, -139600,3.4042876,0.9793106,,,,,,,,,,,,,, -139700,3.4906516,0.9447668,,,,,,,,,,,,,, -139800,3.2398825,0.8621104,,,,,,,,,,,,,, -139900,3.9319167,0.99682367,,,,,,,,,,,,,, -140000,3.5043666,1.0090793,,,,,,,,,,,,,, -140100,3.497274,1.0202627,,,,,,,,,,,,,, -140200,3.513932,0.9767902,,,,,,,,,,,,,, -140300,3.7030108,0.99825245,,,,,,,,,,,,,, -140400,3.6803045,0.98556125,,,,,,,,,,,,,, -140446,,,0.8664301633834839,0.4641014337539673,0.7262399792671204,1.1191344261169434,50000.0,0.6013000011444092,1.8446753025054927,10000.0,47485.51925230026,49248.403594732285,47485.51925230026,1753.1074166297913,4.989470481872559,0.0 -140500,3.7606952,0.92702097,,,,,,,,,,,,,, -140600,3.5568352,1.0571318,,,,,,,,,,,,,, -140700,3.9971464,1.0377673,,,,,,,,,,,,,, -140800,3.337834,0.9246453,,,,,,,,,,,,,, -140900,3.7783904,1.0374875,,,,,,,,,,,,,, -141000,3.713449,0.9902687,,,,,,,,,,,,,, -141100,3.1930063,0.86989105,,,,,,,,,,,,,, -141200,3.8495283,1.0537069,,,,,,,,,,,,,, -141300,3.7414231,0.96067333,,,,,,,,,,,,,, -141400,3.7861555,0.9652432,,,,,,,,,,,,,, -141500,3.7356336,1.0130261,,,,,,,,,,,,,, -141600,3.6124794,1.0176424,,,,,,,,,,,,,, -141700,3.6364906,1.0557716,,,,,,,,,,,,,, -141800,3.649092,1.0556028,,,,,,,,,,,,,, -141900,3.615767,0.946807,,,,,,,,,,,,,, -141957,,,0.8687818646430969,0.4507312178611755,0.7324999570846558,1.1073507070541382,50000.0,0.6045000553131104,1.8624248504638672,10000.0,47995.51835870743,49775.99292445183,47995.51835870743,1770.591741323471,5.043493986129761,0.0 -142000,3.9402733,0.9875816,,,,,,,,,,,,,, -142100,3.706147,1.053385,,,,,,,,,,,,,, -142200,3.7846894,0.9794345,,,,,,,,,,,,,, -142300,3.770731,0.9721098,,,,,,,,,,,,,, -142400,3.7680643,0.9215427,,,,,,,,,,,,,, -142500,3.9723463,0.99450374,,,,,,,,,,,,,, -142600,3.816914,0.9172952,,,,,,,,,,,,,, -142700,4.267503,0.9869732,,,,,,,,,,,,,, -142800,3.9395354,0.98841345,,,,,,,,,,,,,, -142900,3.7474134,0.99503255,,,,,,,,,,,,,, -143000,3.6762738,1.0535593,,,,,,,,,,,,,, -143100,4.1222515,0.99720836,,,,,,,,,,,,,, -143200,3.716428,0.97275525,,,,,,,,,,,,,, -143300,3.4164526,0.8868442,,,,,,,,,,,,,, -143400,3.5902224,0.9006967,,,,,,,,,,,,,, -143468,,,0.8729272484779358,0.4402921795845032,0.7320399880409241,1.117584228515625,50000.0,0.6025000214576721,1.8811101913452148,10000.0,48505.51794219017,50302.9829390049,48505.51794219017,1787.475771188736,5.096850633621216,0.0 -143500,3.6584833,0.9827987,,,,,,,,,,,,,, -143600,3.7140658,0.96113706,,,,,,,,,,,,,, -143700,3.8458247,1.0045507,,,,,,,,,,,,,, -143800,3.590476,0.8789205,,,,,,,,,,,,,, -143900,3.5858028,0.90068984,,,,,,,,,,,,,, -144000,3.6313214,0.946666,,,,,,,,,,,,,, -144100,3.77105,1.0091857,,,,,,,,,,,,,, -144200,3.7125993,0.9320701,,,,,,,,,,,,,, -144300,3.659953,0.9100106,,,,,,,,,,,,,, -144400,3.738148,0.96612394,,,,,,,,,,,,,, -144500,4.0904512,0.9562795,,,,,,,,,,,,,, -144600,3.786868,0.9136754,,,,,,,,,,,,,, -144700,3.4658442,0.8982314,,,,,,,,,,,,,, -144800,3.8883686,0.868552,,,,,,,,,,,,,, -144900,3.6526155,0.9736115,,,,,,,,,,,,,, -144979,,,0.8941127061843872,0.3673299551010132,0.7300800085067749,1.1244101524353027,50000.0,0.6011000275611877,1.8705135583877563,10000.0,49015.45076060295,50829.97280144692,49015.45076060295,1804.427111625672,5.150873422622681,0.0 -145000,3.8392901,0.8672205,,,,,,,,,,,,,, -145100,3.6557574,0.8515601,,,,,,,,,,,,,, -145200,3.8549595,0.88608724,,,,,,,,,,,,,, -145300,3.6366518,0.8672603,,,,,,,,,,,,,, -145400,3.801355,0.921938,,,,,,,,,,,,,, -145500,3.8573503,0.895956,,,,,,,,,,,,,, -145600,3.9743576,0.94981414,,,,,,,,,,,,,, -145700,4.021166,0.8849928,,,,,,,,,,,,,, -145800,3.6520715,0.90286016,,,,,,,,,,,,,, -145900,3.8662336,0.9326151,,,,,,,,,,,,,, -146000,3.5183322,0.8672707,,,,,,,,,,,,,, -146100,3.5875123,0.7868027,,,,,,,,,,,,,, -146200,3.9143207,0.92110246,,,,,,,,,,,,,, -146300,4.129079,0.9314221,,,,,,,,,,,,,, -146400,4.077681,0.964615,,,,,,,,,,,,,, -146490,,,0.8969627022743225,0.3568805456161499,0.7330399751663208,1.1061513423919678,50000.0,0.6080000400543213,1.8569191694259644,10000.0,49525.504454135895,51357.16882824898,49525.504454135895,1821.4655013084407,5.202937126159668,0.0 -146500,4.054636,0.9888083,,,,,,,,,,,,,, -146600,3.833396,0.878733,,,,,,,,,,,,,, -146700,3.57279,0.8839741,,,,,,,,,,,,,, -146800,4.169057,1.0006742,,,,,,,,,,,,,, -146900,3.6275692,0.87963045,,,,,,,,,,,,,, -147000,3.7809014,0.9561358,,,,,,,,,,,,,, -147100,3.9434361,0.9864502,,,,,,,,,,,,,, -147200,3.6560404,0.8679781,,,,,,,,,,,,,, -147300,3.5745783,0.75829685,,,,,,,,,,,,,, -147400,3.8483562,0.9065311,,,,,,,,,,,,,, -147500,4.2635307,0.97083044,,,,,,,,,,,,,, -147600,3.8186479,0.89034384,,,,,,,,,,,,,, -147700,3.8713272,0.9522393,,,,,,,,,,,,,, -147800,4.0151267,0.9872953,,,,,,,,,,,,,, -147900,4.010917,0.89555347,,,,,,,,,,,,,, -148000,3.896772,0.91532236,,,,,,,,,,,,,, -148001,,,0.8917012214660645,0.3716536462306976,0.7354599833488464,1.0957401990890503,50000.0,0.6052000522613525,1.8347212076187127,10000.0,50035.673583984375,51884.738174676895,50035.673583984375,1838.759839296341,5.256864070892334,0.0 -148100,3.6898923,0.90122765,,,,,,,,,,,,,, -148200,3.6936615,0.95254624,,,,,,,,,,,,,, -148300,3.9628234,0.89936894,,,,,,,,,,,,,, -148400,4.1269426,0.8971952,,,,,,,,,,,,,, -148500,3.9347792,0.9133061,,,,,,,,,,,,,, -148600,3.8629987,0.94037807,,,,,,,,,,,,,, -148700,3.7180703,0.84767747,,,,,,,,,,,,,, -148800,3.513308,0.817526,,,,,,,,,,,,,, -148900,3.9105635,0.9449032,,,,,,,,,,,,,, -149000,3.872276,0.8491195,,,,,,,,,,,,,, -149100,4.2926493,0.93127173,,,,,,,,,,,,,, -149200,4.002565,0.88538116,,,,,,,,,,,,,, -149300,4.09222,0.84319687,,,,,,,,,,,,,, -149400,4.348185,0.93878335,,,,,,,,,,,,,, -149500,4.2002044,0.9221566,,,,,,,,,,,,,, -149512,,,0.8974609375,0.349711924791336,0.7398799657821655,1.0879695415496826,50000.0,0.6164000034332275,1.8350088596344,10000.0,50545.63154053688,52411.86543726921,50545.63154053688,1855.8230559825893,5.3106231689453125,0.0 -149600,3.8019052,0.8663503,,,,,,,,,,,,,, -149700,4.1097407,0.90568054,,,,,,,,,,,,,, -149800,3.6980286,0.8291345,,,,,,,,,,,,,, -149900,4.1283073,0.8716401,,,,,,,,,,,,,, -150000,3.911212,0.8890684,,,,,,,,,,,,,, -150100,4.422661,0.82023984,,,,,,,,,,,,,, -150200,3.7460172,0.8953324,,,,,,,,,,,,,, -150300,3.7393947,0.8655221,,,,,,,,,,,,,, -150400,4.1257834,0.8162238,,,,,,,,,,,,,, -150500,3.7763042,0.8275331,,,,,,,,,,,,,, -150600,4.2664485,0.9271859,,,,,,,,,,,,,, -150700,4.421042,0.9806626,,,,,,,,,,,,,, -150800,3.934685,0.8772572,,,,,,,,,,,,,, -150900,3.7879379,0.81134784,,,,,,,,,,,,,, -151000,3.7988172,0.81633496,,,,,,,,,,,,,, -151023,,,0.8940728306770325,0.3650061786174774,0.7384399771690369,1.08740496635437,50000.0,0.6171000003814697,1.84127140045166,10000.0,51055.59986066818,52938.73140335083,51055.59986066818,1872.6232559680936,5.355967283248901,0.0 -151100,3.7626233,0.8669652,,,,,,,,,,,,,, -151200,4.2877913,0.82007295,,,,,,,,,,,,,, -151300,4.072456,0.78770214,,,,,,,,,,,,,, -151400,3.7942033,0.84829915,,,,,,,,,,,,,, -151500,3.597151,0.796736,,,,,,,,,,,,,, -151600,4.0335813,0.9391403,,,,,,,,,,,,,, -151700,3.9959462,0.7784108,,,,,,,,,,,,,, -151800,4.147845,0.84449035,,,,,,,,,,,,,, -151900,4.02071,0.88662875,,,,,,,,,,,,,, -152000,4.3766036,0.9131307,,,,,,,,,,,,,, -152100,3.8113174,0.81038946,,,,,,,,,,,,,, -152200,3.8340952,0.8674668,,,,,,,,,,,,,, -152300,4.078218,0.7691019,,,,,,,,,,,,,, -152400,4.423545,0.894297,,,,,,,,,,,,,, -152500,4.3490033,0.8484596,,,,,,,,,,,,,, -152534,,,0.902164340019226,0.3393253982067108,0.7376399636268616,1.093889236450195,50000.0,0.612000048160553,1.854403018951416,10000.0,51565.666075229645,53466.66446304321,51565.666075229645,1890.3863129615784,5.407841682434082,0.0 -152600,3.8491404,0.84582347,,,,,,,,,,,,,, -152700,3.7723215,0.7221496,,,,,,,,,,,,,, -152800,3.9855258,0.88139725,,,,,,,,,,,,,, -152900,3.88402,0.88874936,,,,,,,,,,,,,, -153000,4.258109,0.87704265,,,,,,,,,,,,,, -153100,4.058061,0.86598325,,,,,,,,,,,,,, -153200,3.9955142,0.82119656,,,,,,,,,,,,,, -153300,3.9478865,0.84675974,,,,,,,,,,,,,, -153400,3.721542,0.77039325,,,,,,,,,,,,,, -153500,4.3090444,0.9124272,,,,,,,,,,,,,, -153600,3.9138513,0.7691932,,,,,,,,,,,,,, -153700,4.341572,0.9973228,,,,,,,,,,,,,, -153800,3.852941,0.82755834,,,,,,,,,,,,,, -153900,4.034486,0.9059608,,,,,,,,,,,,,, -154000,4.0069437,0.8451011,,,,,,,,,,,,,, -154045,,,0.924465835094452,0.2674159705638885,0.74263995885849,1.075542688369751,50000.0,0.6180000305175781,1.834379196166992,10000.0,52075.750030994415,53993.95304322243,52075.750030994415,1907.4831624031067,5.463268280029297,0.0 -154100,4.1712537,0.78097224,,,,,,,,,,,,,, -154200,3.9638004,0.79163176,,,,,,,,,,,,,, -154300,4.0133142,0.7330729,,,,,,,,,,,,,, -154400,3.9753954,0.6978769,,,,,,,,,,,,,, -154500,3.919756,0.7643331,,,,,,,,,,,,,, -154600,4.2718906,0.8468349,,,,,,,,,,,,,, -154700,3.9477184,0.815366,,,,,,,,,,,,,, -154800,4.204366,0.8707363,,,,,,,,,,,,,, -154900,3.8496366,0.76926434,,,,,,,,,,,,,, -155000,4.0314016,0.7363782,,,,,,,,,,,,,, -155100,3.9972312,0.7988514,,,,,,,,,,,,,, -155200,4.174678,0.8603363,,,,,,,,,,,,,, -155300,4.32781,0.8910652,,,,,,,,,,,,,, -155400,4.044367,0.7768462,,,,,,,,,,,,,, -155500,3.8868082,0.77003074,,,,,,,,,,,,,, -155555,,,0.9222337007522584,0.2677811682224273,0.7441799640655518,1.0739376544952393,50000.0,0.6170000433921814,1.832737922668457,10000.0,52585.85358142853,54521.44651532173,52585.85358142853,1924.760721445084,5.524292469024658,0.0 -155600,4.3546605,0.87074435,,,,,,,,,,,,,, -155700,4.1004405,0.83805597,,,,,,,,,,,,,, -155800,3.9110653,0.76254,,,,,,,,,,,,,, -155900,3.9542003,0.8175806,,,,,,,,,,,,,, -156000,4.002416,0.70375204,,,,,,,,,,,,,, -156100,4.2023277,0.8416116,,,,,,,,,,,,,, -156200,3.9204783,0.7355331,,,,,,,,,,,,,, -156300,4.434445,0.79284406,,,,,,,,,,,,,, -156400,4.218001,0.7973301,,,,,,,,,,,,,, -156500,4.3840017,0.83104473,,,,,,,,,,,,,, -156600,4.4748964,0.8617708,,,,,,,,,,,,,, -156700,4.6284328,0.85887265,,,,,,,,,,,,,, -156800,4.176584,0.83377147,,,,,,,,,,,,,, -156900,4.6714616,0.8518018,,,,,,,,,,,,,, -157000,3.8692892,0.76521397,,,,,,,,,,,,,, -157065,,,0.9177096486091614,0.2826022207736969,0.7434599995613098,1.0783482789993286,50000.0,0.6144000291824341,1.843645453453064,10000.0,53095.834916353226,55048.73201870918,53095.834916353226,1941.9596049785607,5.5780956745147705,0.0 -157100,4.1634903,0.77252847,,,,,,,,,,,,,, -157200,4.2181144,0.7985169,,,,,,,,,,,,,, -157300,4.419339,0.821111,,,,,,,,,,,,,, -157400,4.06774,0.77869844,,,,,,,,,,,,,, -157500,4.3701987,0.8052071,,,,,,,,,,,,,, -157600,4.151106,0.7472219,,,,,,,,,,,,,, -157700,4.5926104,0.7957686,,,,,,,,,,,,,, -157800,4.266747,0.8540102,,,,,,,,,,,,,, -157900,3.9846444,0.80806565,,,,,,,,,,,,,, -158000,4.58631,0.7792445,,,,,,,,,,,,,, -158100,4.644996,0.8243164,,,,,,,,,,,,,, -158200,4.3221145,0.78534657,,,,,,,,,,,,,, -158300,4.09141,0.7761271,,,,,,,,,,,,,, -158400,4.275282,0.7529697,,,,,,,,,,,,,, -158500,4.0117974,0.75481445,,,,,,,,,,,,,, -158576,,,0.9215162396430968,0.2728367447853088,0.7452399730682373,1.074062705039978,50000.0,0.6214000582695007,1.8301516771316528,10000.0,53605.92672085762,55576.13701224327,53605.92672085762,1959.165843486786,5.632391452789307,0.0 -158600,4.1389027,0.7602445,,,,,,,,,,,,,, -158700,4.260852,0.73169243,,,,,,,,,,,,,, -158800,4.3698845,0.7536354,,,,,,,,,,,,,, -158900,4.3551393,0.8250275,,,,,,,,,,,,,, -159000,4.538481,0.87237185,,,,,,,,,,,,,, -159100,4.2151346,0.8135047,,,,,,,,,,,,,, -159200,4.527341,0.8421396,,,,,,,,,,,,,, -159300,4.3008246,0.79141927,,,,,,,,,,,,,, -159400,4.775433,0.8028333,,,,,,,,,,,,,, -159500,4.3861585,0.7653523,,,,,,,,,,,,,, -159600,4.131353,0.71231717,,,,,,,,,,,,,, -159700,4.392562,0.8026663,,,,,,,,,,,,,, -159800,4.0957603,0.69321454,,,,,,,,,,,,,, -159900,4.545289,0.83312416,,,,,,,,,,,,,, -160000,4.394529,0.7614566,,,,,,,,,,,,,, -160087,,,0.9222137928009032,0.2624360918998718,0.7467199563980103,1.0808501243591309,50000.0,0.6155000329017639,1.869329571723938,10000.0,54115.92363142967,56103.3468940258,54115.92363142967,1976.270096063614,5.6887383460998535,0.0 -160100,4.364764,0.7849646,,,,,,,,,,,,,, -160200,4.10622,0.7027,,,,,,,,,,,,,, -160300,4.2166085,0.7291685,,,,,,,,,,,,,, -160400,4.8604693,0.8536905,,,,,,,,,,,,,, -160500,3.9765372,0.7251703,,,,,,,,,,,,,, -160600,4.480629,0.7763046,,,,,,,,,,,,,, -160700,4.4773726,0.82979625,,,,,,,,,,,,,, -160800,4.101084,0.67721343,,,,,,,,,,,,,, -160900,3.9826348,0.7326992,,,,,,,,,,,,,, -161000,4.121134,0.7424372,,,,,,,,,,,,,, -161100,4.296589,0.6822801,,,,,,,,,,,,,, -161200,4.1104198,0.76435083,,,,,,,,,,,,,, -161300,4.667599,0.8273724,,,,,,,,,,,,,, -161400,4.1657944,0.72213095,,,,,,,,,,,,,, -161500,4.426904,0.7232555,,,,,,,,,,,,,, -161598,,,0.9258609414100648,0.2548287212848663,0.7464199662208557,1.0761977434158323,50000.0,0.6211000084877014,1.862369418144226,10000.0,54625.90968847275,56630.47075533867,54625.90968847275,1993.2992358207705,5.746572971343994,0.0 -161600,4.3281345,0.70954007,,,,,,,,,,,,,, -161700,4.534993,0.82518315,,,,,,,,,,,,,, -161800,4.172908,0.7049387,,,,,,,,,,,,,, -161900,4.1325073,0.726851,,,,,,,,,,,,,, -162000,4.212146,0.73014235,,,,,,,,,,,,,, -162100,4.1319213,0.69683826,,,,,,,,,,,,,, -162200,4.225059,0.75682855,,,,,,,,,,,,,, -162300,4.60301,0.8840595,,,,,,,,,,,,,, -162400,4.826217,0.75866735,,,,,,,,,,,,,, -162500,4.28973,0.8314589,,,,,,,,,,,,,, -162600,4.277478,0.7209109,,,,,,,,,,,,,, -162700,4.1599746,0.7069489,,,,,,,,,,,,,, -162800,3.9494116,0.6438639,,,,,,,,,,,,,, -162900,4.235016,0.75823635,,,,,,,,,,,,,, -163000,4.4126415,0.73906934,,,,,,,,,,,,,, -163100,4.173922,0.7305336,,,,,,,,,,,,,, -163109,,,0.9452128410339355,0.1964490413665771,0.7478199601173401,1.0691890716552734,50000.0,0.6243000030517578,1.836309671401977,10000.0,55136.06539058685,57157.884127378464,55136.06539058685,2010.447353601456,5.804094076156616,0.0 -163200,4.6074004,0.8149825,,,,,,,,,,,,,, -163300,4.49098,0.6468247,,,,,,,,,,,,,, -163400,4.06319,0.7226368,,,,,,,,,,,,,, -163500,4.1868405,0.74621266,,,,,,,,,,,,,, -163600,4.5709743,0.7286827,,,,,,,,,,,,,, -163700,4.5613604,0.74711025,,,,,,,,,,,,,, -163800,4.0144567,0.7337779,,,,,,,,,,,,,, -163900,4.5669494,0.7614902,,,,,,,,,,,,,, -164000,3.8204796,0.6268219,,,,,,,,,,,,,, -164100,4.827168,0.7697157,,,,,,,,,,,,,, -164200,4.378694,0.7993373,,,,,,,,,,,,,, -164300,4.2864456,0.7046367,,,,,,,,,,,,,, -164400,4.166798,0.693105,,,,,,,,,,,,,, -164500,4.6041865,0.73767155,,,,,,,,,,,,,, -164600,4.43741,0.6966599,,,,,,,,,,,,,, -164619,,,0.9431201815605164,0.2001777589321136,0.7487999796867371,1.062160611152649,50000.0,0.6244000196456909,1.8305258750915527,10000.0,55645.96127533913,57685.11239886284,55645.96127533913,2027.654661655426,5.87807035446167,0.0 -164700,4.4235024,0.778041,,,,,,,,,,,,,, -164800,4.593303,0.7807888,,,,,,,,,,,,,, -164900,4.8468585,0.75943834,,,,,,,,,,,,,, -165000,4.2672124,0.74019897,,,,,,,,,,,,,, -165100,4.2913523,0.6985317,,,,,,,,,,,,,, -165200,4.4990444,0.68324643,,,,,,,,,,,,,, -165300,4.166616,0.68570936,,,,,,,,,,,,,, -165400,4.310874,0.71524495,,,,,,,,,,,,,, -165500,4.1672645,0.7101453,,,,,,,,,,,,,, -165600,4.540638,0.76645577,,,,,,,,,,,,,, -165700,5.24207,0.79415256,,,,,,,,,,,,,, -165800,4.855795,0.78747636,,,,,,,,,,,,,, -165900,4.1478963,0.65665686,,,,,,,,,,,,,, -166000,4.378522,0.6536953,,,,,,,,,,,,,, -166100,4.041286,0.6821348,,,,,,,,,,,,,, -166130,,,0.9429408311843872,0.2010091245174408,0.7507999539375305,1.063918113708496,50000.0,0.6175000071525574,1.8367518186569207,10000.0,56156.17899298668,58212.58046770096,56156.17899298668,2044.7875108718872,5.943117380142212,0.0 -166200,4.186523,0.68952495,,,,,,,,,,,,,, -166300,4.318854,0.6148511,,,,,,,,,,,,,, -166400,4.555764,0.735587,,,,,,,,,,,,,, -166500,4.6428275,0.6848044,,,,,,,,,,,,,, -166600,4.548101,0.719408,,,,,,,,,,,,,, -166700,4.753637,0.75194216,,,,,,,,,,,,,, -166800,4.3231335,0.72050655,,,,,,,,,,,,,, -166900,4.347667,0.6905598,,,,,,,,,,,,,, -167000,4.201731,0.6150662,,,,,,,,,,,,,, -167100,4.2522244,0.6614825,,,,,,,,,,,,,, -167200,4.7618933,0.71699506,,,,,,,,,,,,,, -167300,4.244775,0.7411322,,,,,,,,,,,,,, -167400,4.1666756,0.64387184,,,,,,,,,,,,,, -167500,4.6189895,0.72560465,,,,,,,,,,,,,, -167600,4.4508595,0.6930133,,,,,,,,,,,,,, -167639,,,0.9417450428009032,0.2010733932256698,0.7512800097465515,1.0629782676696775,50000.0,0.6234000325202942,1.833619952201844,10000.0,56666.13904762268,58739.86371612549,56666.13904762268,2062.002586364746,5.999414920806885,0.0 -167700,4.796656,0.75667346,,,,,,,,,,,,,, -167800,4.693856,0.6965413,,,,,,,,,,,,,, -167900,4.397344,0.68173754,,,,,,,,,,,,,, -168000,4.033231,0.71563697,,,,,,,,,,,,,, -168100,4.597017,0.68431437,,,,,,,,,,,,,, -168200,5.1131463,0.6747605,,,,,,,,,,,,,, -168300,4.4218745,0.6931959,,,,,,,,,,,,,, -168400,4.7274113,0.6989908,,,,,,,,,,,,,, -168500,4.533281,0.7061794,,,,,,,,,,,,,, -168600,4.376897,0.63706195,,,,,,,,,,,,,, -168700,4.4741774,0.6549561,,,,,,,,,,,,,, -168800,3.98181,0.6170865,,,,,,,,,,,,,, -168900,4.351503,0.66042477,,,,,,,,,,,,,, -169000,3.9878883,0.67367876,,,,,,,,,,,,,, -169100,4.3396063,0.72194195,,,,,,,,,,,,,, -169150,,,0.9476044178009032,0.1906523704528808,0.7506200075149536,1.0631998777389526,50000.0,0.625,1.8395988941192627,10000.0,57176.32164978981,59267.289298295975,57176.32164978981,2079.134115457535,6.059413433074951,0.0 -169200,4.127875,0.62409043,,,,,,,,,,,,,, -169300,4.653618,0.67746705,,,,,,,,,,,,,, -169400,4.6580467,0.6952398,,,,,,,,,,,,,, -169500,5.026081,0.7128527,,,,,,,,,,,,,, -169600,4.354574,0.66027737,,,,,,,,,,,,,, -169700,4.4813027,0.7350948,,,,,,,,,,,,,, -169800,4.43459,0.66349065,,,,,,,,,,,,,, -169900,4.462091,0.7680902,,,,,,,,,,,,,, -170000,4.2180686,0.6177521,,,,,,,,,,,,,, -170100,4.442996,0.6478487,,,,,,,,,,,,,, -170200,4.6434436,0.6795674,,,,,,,,,,,,,, -170300,4.347647,0.65523446,,,,,,,,,,,,,, -170400,4.5305524,0.71000314,,,,,,,,,,,,,, -170500,3.7599995,0.5436923,,,,,,,,,,,,,, -170600,4.2484436,0.6770095,,,,,,,,,,,,,, -170660,,,0.9497169852256776,0.18296679854393,0.7519999742507935,1.061332106590271,50000.0,0.625700056552887,1.836114883422852,10000.0,57686.33897137642,59794.26624298096,57686.33897137642,2095.982746839524,6.118293046951294,0.0 -170700,4.326831,0.7082716,,,,,,,,,,,,,, -170800,4.638218,0.6892321,,,,,,,,,,,,,, -170900,4.6628523,0.69656897,,,,,,,,,,,,,, -171000,4.4965734,0.7063831,,,,,,,,,,,,,, -171100,4.3455596,0.6092601,,,,,,,,,,,,,, -171200,4.3697047,0.6547731,,,,,,,,,,,,,, -171300,4.4959974,0.6150041,,,,,,,,,,,,,, -171400,4.6838164,0.7615878,,,,,,,,,,,,,, -171500,4.942896,0.7733563,,,,,,,,,,,,,, -171600,4.018375,0.68139756,,,,,,,,,,,,,, -171700,4.652283,0.74063134,,,,,,,,,,,,,, -171800,4.989338,0.6773344,,,,,,,,,,,,,, -171900,4.8542056,0.64444387,,,,,,,,,,,,,, -172000,4.3088593,0.61869276,,,,,,,,,,,,,, -172100,4.4969354,0.6890472,,,,,,,,,,,,,, -172170,,,0.9553770422935486,0.1642269641160965,0.7532199621200562,1.0589675903320312,50000.0,0.6272000074386597,1.831020832061768,10000.0,58196.301375865936,60321.604273319244,58196.301375865936,2113.2451345920563,6.178622245788574,0.0 -172200,4.4402156,0.6430271,,,,,,,,,,,,,, -172300,4.632494,0.7182412,,,,,,,,,,,,,, -172400,4.5766325,0.7548605,,,,,,,,,,,,,, -172500,4.444958,0.63999736,,,,,,,,,,,,,, -172600,4.489518,0.5727751,,,,,,,,,,,,,, -172700,5.0007906,0.7237117,,,,,,,,,,,,,, -172800,4.6002216,0.70294744,,,,,,,,,,,,,, -172900,4.9750795,0.71337116,,,,,,,,,,,,,, -173000,5.395134,0.7490244,,,,,,,,,,,,,, -173100,4.5359817,0.66575295,,,,,,,,,,,,,, -173200,4.2639985,0.6316662,,,,,,,,,,,,,, -173300,4.5152507,0.6673127,,,,,,,,,,,,,, -173400,4.2528195,0.6209861,,,,,,,,,,,,,, -173500,4.4184103,0.64324355,,,,,,,,,,,,,, -173600,5.1537156,0.696244,,,,,,,,,,,,,, -173681,,,0.9561144709587096,0.1630240380764007,0.7536199688911438,1.0499507188796997,50000.0,0.6290000081062317,1.8309428691864007,10000.0,58706.51188850403,60849.07347011566,58706.51188850403,2130.392186880112,6.2380759716033936,0.0 -173700,4.5907655,0.67547643,,,,,,,,,,,,,, -173800,4.702564,0.68807894,,,,,,,,,,,,,, -173900,4.546448,0.5634936,,,,,,,,,,,,,, -174000,4.6042166,0.610881,,,,,,,,,,,,,, -174100,4.9027114,0.67433447,,,,,,,,,,,,,, -174200,4.6405187,0.6790314,,,,,,,,,,,,,, -174300,4.272896,0.52560866,,,,,,,,,,,,,, -174400,4.136188,0.6090837,,,,,,,,,,,,,, -174500,4.8240075,0.72306556,,,,,,,,,,,,,, -174600,4.471736,0.69140744,,,,,,,,,,,,,, -174700,4.306844,0.7032969,,,,,,,,,,,,,, -174800,4.8719034,0.69732094,,,,,,,,,,,,,, -174900,4.7566195,0.67067015,,,,,,,,,,,,,, -175000,4.9351254,0.5950717,,,,,,,,,,,,,, -175100,4.469751,0.66466945,,,,,,,,,,,,,, -175192,,,0.9563137292861938,0.1613334119319915,0.7529799938201904,1.0494953393936155,50000.0,0.6276000142097473,1.8195759057998653,10000.0,59216.62665820122,61376.379529476166,59216.62665820122,2147.473479747772,6.29523515701294,0.0 -175200,5.0290413,0.6419823,,,,,,,,,,,,,, -175300,4.1243935,0.55989254,,,,,,,,,,,,,, -175400,4.3187037,0.68318456,,,,,,,,,,,,,, -175500,4.3530827,0.64685625,,,,,,,,,,,,,, -175600,4.712867,0.67128116,,,,,,,,,,,,,, -175700,4.282067,0.66121954,,,,,,,,,,,,,, -175800,4.0838842,0.570539,,,,,,,,,,,,,, -175900,4.799163,0.67515975,,,,,,,,,,,,,, -176000,4.768516,0.7827648,,,,,,,,,,,,,, -176100,4.464944,0.67685974,,,,,,,,,,,,,, -176200,4.2180023,0.6274394,,,,,,,,,,,,,, -176300,4.5402346,0.62861496,,,,,,,,,,,,,, -176400,4.3405776,0.6350927,,,,,,,,,,,,,, -176500,4.430924,0.6434984,,,,,,,,,,,,,, -176600,4.442279,0.6745442,,,,,,,,,,,,,, -176700,4.474059,0.6205723,,,,,,,,,,,,,, -176703,,,0.956273913383484,0.1597492545843124,0.7539399862289429,1.049542784690857,50000.0,0.6301000118255615,1.824698686599732,10000.0,59726.75988483429,61903.96989917755,59726.75988483429,2164.808212280273,6.3645124435424805,0.0 -176800,4.470491,0.6579272,,,,,,,,,,,,,, -176900,5.2839913,0.65263426,,,,,,,,,,,,,, -177000,4.6095743,0.6362249,,,,,,,,,,,,,, -177100,4.7402916,0.6137208,,,,,,,,,,,,,, -177200,4.4415627,0.6664093,,,,,,,,,,,,,, -177300,4.5108943,0.5909972,,,,,,,,,,,,,, -177400,4.331773,0.63402534,,,,,,,,,,,,,, -177500,4.4282346,0.6132486,,,,,,,,,,,,,, -177600,4.3121095,0.63223815,,,,,,,,,,,,,, -177700,4.5480037,0.6817312,,,,,,,,,,,,,, -177800,4.757517,0.5973452,,,,,,,,,,,,,, -177900,5.2152014,0.676913,,,,,,,,,,,,,, -178000,4.673347,0.6741855,,,,,,,,,,,,,, -178100,4.6873493,0.63124824,,,,,,,,,,,,,, -178200,4.7248855,0.60183436,,,,,,,,,,,,,, -178213,,,0.9578084945678712,0.1552063673734665,0.7549200057983398,1.0473157167434692,50000.0,0.6283000111579895,1.824266076087952,10000.0,60236.88195681572,62431.477404117584,60236.88195681572,2182.069860935211,6.435606241226196,0.0 -178300,4.5874057,0.60402274,,,,,,,,,,,,,, -178400,4.9991403,0.5955675,,,,,,,,,,,,,, -178500,4.3878956,0.5945381,,,,,,,,,,,,,, -178600,4.5737967,0.6270709,,,,,,,,,,,,,, -178700,4.546781,0.6527276,,,,,,,,,,,,,, -178800,4.311297,0.55196154,,,,,,,,,,,,,, -178900,4.4460745,0.6312182,,,,,,,,,,,,,, -179000,4.6231933,0.6718521,,,,,,,,,,,,,, -179100,4.6388335,0.6336218,,,,,,,,,,,,,, -179200,4.790469,0.65748495,,,,,,,,,,,,,, -179300,4.8698153,0.7033915,,,,,,,,,,,,,, -179400,4.804437,0.6733665,,,,,,,,,,,,,, -179500,4.2400503,0.6182412,,,,,,,,,,,,,, -179600,4.5814943,0.6086825,,,,,,,,,,,,,, -179700,4.762608,0.6205282,,,,,,,,,,,,,, -179723,,,0.9579480290412904,0.1545311510562896,0.7556999921798706,1.0462963581085205,50000.0,0.6289000511169434,1.826785683631897,10000.0,60746.79544234276,62958.6741335392,60746.79544234276,2199.227989912033,6.507908821105957,0.0 -179800,4.4012694,0.5494526,,,,,,,,,,,,,, -179900,4.516794,0.60574913,,,,,,,,,,,,,, -180000,4.4852495,0.67096674,,,,,,,,,,,,,, -180100,4.906906,0.66784596,,,,,,,,,,,,,, -180200,4.670009,0.6552207,,,,,,,,,,,,,, -180300,4.467976,0.6579267,,,,,,,,,,,,,, -180400,4.9067044,0.58200574,,,,,,,,,,,,,, -180500,4.8091836,0.6962398,,,,,,,,,,,,,, -180600,4.6056194,0.6768958,,,,,,,,,,,,,, -180700,4.7872286,0.6199655,,,,,,,,,,,,,, -180800,4.9709578,0.6835865,,,,,,,,,,,,,, -180900,4.2140436,0.5715584,,,,,,,,,,,,,, -181000,4.471087,0.63409245,,,,,,,,,,,,,, -181100,4.503526,0.599729,,,,,,,,,,,,,, -181200,4.766182,0.6545164,,,,,,,,,,,,,, -181233,,,0.9613759517669678,0.1460684090852737,0.7547599673271179,1.0432047843933103,50000.0,0.6278000473976135,1.820835828781128,10000.0,61256.721252441406,63485.68079519272,61256.721252441406,2216.197611093521,6.566876411437988,0.0 -181300,4.483192,0.6355294,,,,,,,,,,,,,, -181400,4.5380936,0.68647647,,,,,,,,,,,,,, -181500,4.691883,0.5974069,,,,,,,,,,,,,, -181600,4.6855507,0.65297264,,,,,,,,,,,,,, -181700,4.2381988,0.67630726,,,,,,,,,,,,,, -181800,4.416999,0.595099,,,,,,,,,,,,,, -181900,4.1661305,0.6047622,,,,,,,,,,,,,, -182000,4.26384,0.6049074,,,,,,,,,,,,,, -182100,4.46599,0.6282934,,,,,,,,,,,,,, -182200,4.5371156,0.6224248,,,,,,,,,,,,,, -182300,4.466651,0.6358,,,,,,,,,,,,,, -182400,4.4517174,0.64343286,,,,,,,,,,,,,, -182500,4.565958,0.7133688,,,,,,,,,,,,,, -182600,4.422542,0.5839517,,,,,,,,,,,,,, -182700,4.408926,0.5684337,,,,,,,,,,,,,, -182743,,,0.9588847160339355,0.1517803221940994,0.7552799582481384,1.0448073148727417,50000.0,0.6301000118255615,1.822758674621582,10000.0,61766.63048315048,64013.07961964607,61766.63048315048,2233.5767714977264,6.625716686248779,0.0 -182800,5.2525454,0.64924777,,,,,,,,,,,,,, -182900,4.3406844,0.60935867,,,,,,,,,,,,,, -183000,4.332507,0.6056252,,,,,,,,,,,,,, -183100,4.7454934,0.5956629,,,,,,,,,,,,,, -183200,4.200106,0.58274484,,,,,,,,,,,,,, -183300,4.318632,0.6190009,,,,,,,,,,,,,, -183400,3.9612522,0.54029423,,,,,,,,,,,,,, -183500,4.395366,0.5770209,,,,,,,,,,,,,, -183600,4.4891863,0.59733415,,,,,,,,,,,,,, -183700,4.136291,0.5787111,,,,,,,,,,,,,, -183800,4.6567874,0.6244177,,,,,,,,,,,,,, -183900,4.4378257,0.5710069,,,,,,,,,,,,,, -184000,4.435572,0.6236486,,,,,,,,,,,,,, -184100,4.4050555,0.68714917,,,,,,,,,,,,,, -184200,4.537154,0.60958296,,,,,,,,,,,,,, -184253,,,0.9606584906578064,0.1485065221786499,0.7553399801254272,1.0438897609710691,50000.0,0.6307000517845154,1.82177996635437,10000.0,62276.70138978958,64540.51701974869,62276.70138978958,2250.8308987617493,6.685147285461426,0.0 -184300,4.7115445,0.6015719,,,,,,,,,,,,,, -184400,4.7229757,0.74144715,,,,,,,,,,,,,, -184500,4.5942254,0.62606007,,,,,,,,,,,,,, -184600,4.2534275,0.5973101,,,,,,,,,,,,,, -184700,4.26321,0.58411074,,,,,,,,,,,,,, -184800,4.36747,0.6565177,,,,,,,,,,,,,, -184900,4.7688637,0.68323636,,,,,,,,,,,,,, -185000,4.278887,0.5530706,,,,,,,,,,,,,, -185100,4.531183,0.5845712,,,,,,,,,,,,,, -185200,4.4536123,0.6001047,,,,,,,,,,,,,, -185300,4.3128195,0.5756137,,,,,,,,,,,,,, -185400,4.0446806,0.5677591,,,,,,,,,,,,,, -185500,4.482218,0.5900242,,,,,,,,,,,,,, -185600,4.230077,0.5709932,,,,,,,,,,,,,, -185700,4.486318,0.58048654,,,,,,,,,,,,,, -185763,,,0.9605787396430968,0.1489373445510864,0.7553600072860718,1.0433369874954224,50000.0,0.6310000419616699,1.8208897113800049,10000.0,62786.637889146805,65067.43080019951,62786.637889146805,2267.693324804306,6.7476887702941895,0.0 -185800,4.494176,0.6019174,,,,,,,,,,,,,, -185900,4.793963,0.6225223,,,,,,,,,,,,,, -186000,4.423506,0.6376473,,,,,,,,,,,,,, -186100,4.443599,0.59799314,,,,,,,,,,,,,, -186200,4.375853,0.6463029,,,,,,,,,,,,,, -186300,4.260664,0.5686482,,,,,,,,,,,,,, -186400,4.52401,0.64525765,,,,,,,,,,,,,, -186500,4.328956,0.56900334,,,,,,,,,,,,,, -186600,4.4745374,0.5795376,,,,,,,,,,,,,, -186700,4.7815666,0.6949526,,,,,,,,,,,,,, -186800,4.4279604,0.60865366,,,,,,,,,,,,,, -186900,4.568332,0.6116177,,,,,,,,,,,,,, -187000,4.577717,0.6440309,,,,,,,,,,,,,, -187100,4.29199,0.64110893,,,,,,,,,,,,,, -187200,4.651545,0.6385149,,,,,,,,,,,,,, -187274,,,0.9614955186843872,0.1437671929597854,0.755299985408783,1.0444318056106567,50000.0,0.6310000419616699,1.822888970375061,10000.0,63296.7076792717,65594.64647817612,63296.7076792717,2284.7242891788483,6.810681581497192,0.0 -187300,4.131113,0.6124048,,,,,,,,,,,,,, -187400,4.277928,0.6337713,,,,,,,,,,,,,, -187500,4.2675996,0.5247005,,,,,,,,,,,,,, -187600,4.683096,0.63967097,,,,,,,,,,,,,, -187700,4.8030224,0.7008632,,,,,,,,,,,,,, -187800,4.3841143,0.64540654,,,,,,,,,,,,,, -187900,4.291183,0.62698275,,,,,,,,,,,,,, -188000,4.411113,0.62959814,,,,,,,,,,,,,, -188100,4.434585,0.6972861,,,,,,,,,,,,,, -188200,4.451084,0.61021507,,,,,,,,,,,,,, -188300,4.0910974,0.6443926,,,,,,,,,,,,,, -188400,5.0680656,0.64172727,,,,,,,,,,,,,, -188500,4.755955,0.66735643,,,,,,,,,,,,,, -188600,4.85107,0.66175365,,,,,,,,,,,,,, -188700,4.3149643,0.61757827,,,,,,,,,,,,,, -188784,,,0.9608577489852904,0.1480024456977844,0.7550999522209167,1.0436111688613892,50000.0,0.6305000185966492,1.8203967809677124,10000.0,63806.64422917366,66121.7480969429,63806.64422917366,2301.769454240799,6.878018856048584,0.0 -188800,4.3626676,0.60793805,,,,,,,,,,,,,, -188900,4.7782183,0.6632447,,,,,,,,,,,,,, -189000,4.615352,0.6589386,,,,,,,,,,,,,, -189100,4.6515975,0.6812794,,,,,,,,,,,,,, -189200,4.3907275,0.6222117,,,,,,,,,,,,,, -189300,4.208101,0.61544764,,,,,,,,,,,,,, -189400,4.548174,0.764408,,,,,,,,,,,,,, -189500,5.3524795,0.73447436,,,,,,,,,,,,,, -189600,4.4127584,0.5619513,,,,,,,,,,,,,, -189700,4.4659925,0.5742669,,,,,,,,,,,,,, -189800,4.306147,0.6285336,,,,,,,,,,,,,, -189900,4.7774415,0.72831744,,,,,,,,,,,,,, -190000,4.7351747,0.63172895,,,,,,,,,,,,,, -190100,4.602499,0.5927014,,,,,,,,,,,,,, -190200,4.1863165,0.53906065,,,,,,,,,,,,,, -190294,,,0.9602000713348388,0.1484966427087783,0.7550599575042725,1.0441384315490725,50000.0,0.6308000087738037,1.8202807903289795,10000.0,64316.67609715462,66649.12196302414,64316.67609715462,2318.9982414245605,6.939778089523315,0.0 -190300,4.4318657,0.6519297,,,,,,,,,,,,,, -190400,4.4002767,0.63752496,,,,,,,,,,,,,, -190500,4.563024,0.5963099,,,,,,,,,,,,,, -190600,4.5416403,0.6162042,,,,,,,,,,,,,, -190700,4.384508,0.65731573,,,,,,,,,,,,,, -190800,4.7970495,0.582167,,,,,,,,,,,,,, -190900,4.2548084,0.58189964,,,,,,,,,,,,,, -191000,4.5065393,0.64976585,,,,,,,,,,,,,, -191100,4.7088723,0.6884182,,,,,,,,,,,,,, -191200,4.4895926,0.518235,,,,,,,,,,,,,, -191300,4.6222243,0.6142855,,,,,,,,,,,,,, -191400,4.534037,0.59637976,,,,,,,,,,,,,, -191500,4.3569474,0.5977724,,,,,,,,,,,,,, -191600,4.4857016,0.67703533,,,,,,,,,,,,,, -191700,4.5582247,0.70087373,,,,,,,,,,,,,, -191800,4.536302,0.5651079,,,,,,,,,,,,,, -191804,,,0.9599011540412904,0.1491530537605285,0.7549200057983398,1.045419692993164,50000.0,0.6303000450134277,1.8232477903366089,10000.0,64826.72256541252,67177.01352787018,64826.72256541252,2336.7331914901733,6.99896240234375,0.0 -191900,4.375217,0.6189176,,,,,,,,,,,,,, -192000,4.4913874,0.59405625,,,,,,,,,,,,,, -192100,4.2329783,0.5748811,,,,,,,,,,,,,, -192200,4.6261296,0.65500236,,,,,,,,,,,,,, -192300,4.5056357,0.6635599,,,,,,,,,,,,,, -192400,4.3446317,0.6035414,,,,,,,,,,,,,, -192500,4.69082,0.6413487,,,,,,,,,,,,,, -192600,4.877741,0.710736,,,,,,,,,,,,,, -192700,4.4612727,0.63588727,,,,,,,,,,,,,, -192800,4.1243086,0.57486403,,,,,,,,,,,,,, -192900,4.8507223,0.74779785,,,,,,,,,,,,,, -193000,5.1908727,0.65977407,,,,,,,,,,,,,, -193100,4.627836,0.5996609,,,,,,,,,,,,,, -193200,4.550106,0.66544974,,,,,,,,,,,,,, -193300,4.134206,0.6144522,,,,,,,,,,,,,, -193315,,,0.9594427347183228,0.1506325155496597,0.7553199529647827,1.0440016984939575,50000.0,0.6300000548362732,1.8218053579330444,10000.0,65336.91848921776,67704.38000226021,65336.91848921776,2353.772256135941,7.07848596572876,0.0 -193400,4.736013,0.6553821,,,,,,,,,,,,,, -193500,4.1787786,0.6318613,,,,,,,,,,,,,, -193600,4.910081,0.6577955,,,,,,,,,,,,,, -193700,4.6135955,0.7490453,,,,,,,,,,,,,, -193800,4.622559,0.6121764,,,,,,,,,,,,,, -193900,4.436717,0.6727193,,,,,,,,,,,,,, -194000,4.5084734,0.6425961,,,,,,,,,,,,,, -194100,4.626844,0.6332809,,,,,,,,,,,,,, -194200,4.4651875,0.6127474,,,,,,,,,,,,,, -194300,4.5127735,0.66879165,,,,,,,,,,,,,, -194400,4.2793713,0.68702537,,,,,,,,,,,,,, -194500,4.6685247,0.6504923,,,,,,,,,,,,,, -194600,4.514757,0.61264217,,,,,,,,,,,,,, -194700,4.590528,0.595881,,,,,,,,,,,,,, -194800,4.18302,0.59787905,,,,,,,,,,,,,, -194825,,,0.9610371589660645,0.1465877741575241,0.7549200057983398,1.0440150499343872,50000.0,0.6312000155448914,1.820330023765564,10000.0,65846.86537241936,68231.4124391079,65846.86537241936,2370.7418541908264,7.142187833786011,0.0 -194900,4.426458,0.6174666,,,,,,,,,,,,,, -195000,4.534543,0.66587913,,,,,,,,,,,,,, -195100,4.820901,0.65714395,,,,,,,,,,,,,, -195200,4.7248173,0.6002861,,,,,,,,,,,,,, -195300,4.3958445,0.5753694,,,,,,,,,,,,,, -195400,4.6498375,0.62650514,,,,,,,,,,,,,, -195500,4.4431705,0.62054366,,,,,,,,,,,,,, -195600,4.083697,0.64231044,,,,,,,,,,,,,, -195700,4.771376,0.64521533,,,,,,,,,,,,,, -195800,4.3175254,0.6052932,,,,,,,,,,,,,, -195900,4.142555,0.65214455,,,,,,,,,,,,,, -196000,4.587454,0.60313267,,,,,,,,,,,,,, -196100,4.50022,0.69124687,,,,,,,,,,,,,, -196200,4.668233,0.55142933,,,,,,,,,,,,,, -196300,4.6521378,0.600594,,,,,,,,,,,,,, -196335,,,0.9608577489852904,0.145789235830307,0.7549799680709839,1.0447232723236084,50000.0,0.6327000260353088,1.8242987394332888,10000.0,66356.94939851761,68758.44803285599,66356.94939851761,2387.5755808353424,7.208361387252808,0.0 -196400,4.165467,0.6368572,,,,,,,,,,,,,, -196500,4.52472,0.6305403,,,,,,,,,,,,,, -196600,4.862229,0.62610036,,,,,,,,,,,,,, -196700,4.8450327,0.62765026,,,,,,,,,,,,,, -196800,4.881224,0.6477216,,,,,,,,,,,,,, -196900,4.568461,0.5282033,,,,,,,,,,,,,, -197000,4.609049,0.59391874,,,,,,,,,,,,,, -197100,4.299826,0.6330377,,,,,,,,,,,,,, -197200,4.6474214,0.6589926,,,,,,,,,,,,,, -197300,4.6357875,0.61775625,,,,,,,,,,,,,, -197400,4.1386104,0.54213554,,,,,,,,,,,,,, -197500,5.5202475,0.69905525,,,,,,,,,,,,,, -197600,5.1531425,0.68538463,,,,,,,,,,,,,, -197700,4.4980073,0.62612355,,,,,,,,,,,,,, -197800,4.580279,0.5744689,,,,,,,,,,,,,, -197846,,,0.9601203799247742,0.1475706398487091,0.7552399635314941,1.0447214841842651,50000.0,0.6309000253677368,1.822264552116394,10000.0,66867.08155965805,69286.04866456985,66867.08155965805,2404.9300212860107,7.270601749420166,0.0 -197900,4.483575,0.6698173,,,,,,,,,,,,,, -198000,4.546894,0.6714518,,,,,,,,,,,,,, -198100,4.5467477,0.69762963,,,,,,,,,,,,,, -198200,4.490564,0.60100967,,,,,,,,,,,,,, -198300,4.1256967,0.608966,,,,,,,,,,,,,, -198400,4.235964,0.6574001,,,,,,,,,,,,,, -198500,4.6499085,0.6505375,,,,,,,,,,,,,, -198600,5.2639427,0.68802834,,,,,,,,,,,,,, -198700,4.798356,0.63577163,,,,,,,,,,,,,, -198800,4.6095743,0.6133165,,,,,,,,,,,,,, -198900,3.8690412,0.54215276,,,,,,,,,,,,,, -199000,4.7177653,0.66633755,,,,,,,,,,,,,, -199100,4.7391434,0.6977635,,,,,,,,,,,,,, -199200,4.5412197,0.6473016,,,,,,,,,,,,,, -199300,4.90854,0.6819691,,,,,,,,,,,,,, -199356,,,0.9608777165412904,0.1482481956481933,0.7546600103378296,1.042603850364685,50000.0,0.6305000185966492,1.8182530403137207,10000.0,67376.99536323547,69813.56760764122,67376.99536323547,2422.416407346725,7.336972713470459,0.0 -199400,4.2691,0.58678824,,,,,,,,,,,,,, -199500,4.366387,0.63217556,,,,,,,,,,,,,, -199600,4.3487825,0.6280319,,,,,,,,,,,,,, -199700,4.3495107,0.5886576,,,,,,,,,,,,,, -199800,4.5065823,0.6326055,,,,,,,,,,,,,, -199900,4.5171685,0.5842303,,,,,,,,,,,,,, -200000,4.53646,0.63585865,,,,,,,,,,,,,, -200100,4.5600967,0.6187173,,,,,,,,,,,,,, -200200,4.5582104,0.64324623,,,,,,,,,,,,,, -200300,4.740706,0.6274046,,,,,,,,,,,,,, -200400,4.2857165,0.6012216,,,,,,,,,,,,,, -200500,4.528363,0.6093053,,,,,,,,,,,,,, -200600,4.1602902,0.6200187,,,,,,,,,,,,,, -200700,4.195561,0.62411135,,,,,,,,,,,,,, -200800,4.3970566,0.59235406,,,,,,,,,,,,,, -200867,,,0.9594228267669678,0.150496631860733,0.7551800012588501,1.0434798002243042,50000.0,0.6312000155448914,1.821490168571472,10000.0,67887.0665371418,70340.8018951416,67887.0665371418,2439.470073223114,7.393136739730835,0.0 -200900,4.9070716,0.6113216,,,,,,,,,,,,,, -201000,4.6749706,0.6686514,,,,,,,,,,,,,, -201100,4.4920173,0.6241832,,,,,,,,,,,,,, -201200,4.293805,0.7130569,,,,,,,,,,,,,, -201300,4.330487,0.6272891,,,,,,,,,,,,,, -201400,4.567149,0.58969486,,,,,,,,,,,,,, -201500,5.160995,0.6650999,,,,,,,,,,,,,, -201600,4.7507405,0.6695148,,,,,,,,,,,,,, -201700,4.1745725,0.64125156,,,,,,,,,,,,,, -201800,4.816754,0.6305424,,,,,,,,,,,,,, -201900,4.7492685,0.59609914,,,,,,,,,,,,,, -202000,4.4232206,0.56054926,,,,,,,,,,,,,, -202100,4.5330553,0.588573,,,,,,,,,,,,,, -202200,4.798939,0.6948288,,,,,,,,,,,,,, -202300,4.4374957,0.6232401,,,,,,,,,,,,,, -202377,,,0.9612165093421936,0.1434010863304138,0.7554000020027161,1.043828368186951,50000.0,0.631600022315979,1.822011947631836,10000.0,68397.0556704998,70868.1044383049,68397.0556704998,2456.667491674423,7.457369804382324,0.0 -202400,4.6993313,0.6427214,,,,,,,,,,,,,, -202500,4.606806,0.5939986,,,,,,,,,,,,,, -202600,4.5451236,0.570674,,,,,,,,,,,,,, -202700,4.483683,0.6316048,,,,,,,,,,,,,, -202800,4.3691554,0.6467543,,,,,,,,,,,,,, -202900,4.953503,0.7581762,,,,,,,,,,,,,, -203000,4.856586,0.7077525,,,,,,,,,,,,,, -203100,4.447156,0.5506544,,,,,,,,,,,,,, -203200,4.373798,0.5600729,,,,,,,,,,,,,, -203300,4.636086,0.6089443,,,,,,,,,,,,,, -203400,4.6496367,0.6744872,,,,,,,,,,,,,, -203500,4.644059,0.6053024,,,,,,,,,,,,,, -203600,4.9411116,0.5912466,,,,,,,,,,,,,, -203700,4.385749,0.65533626,,,,,,,,,,,,,, -203800,4.684251,0.6651447,,,,,,,,,,,,,, -203887,,,0.9611168503761292,0.1459710448980331,0.7547199726104736,1.0434318780899048,50000.0,0.6301000118255615,1.822356104850769,10000.0,68906.9993698597,71395.26534843445,68906.9993698597,2473.7711822986603,7.51897931098938,0.0 -203900,4.5465975,0.5811866,,,,,,,,,,,,,, -204000,4.3768573,0.68009454,,,,,,,,,,,,,, -204100,4.621255,0.6685858,,,,,,,,,,,,,, -204200,4.4276857,0.5923909,,,,,,,,,,,,,, -204300,4.526396,0.54715896,,,,,,,,,,,,,, -204400,4.6666365,0.64474124,,,,,,,,,,,,,, -204500,4.823121,0.597031,,,,,,,,,,,,,, -204600,4.791178,0.68596345,,,,,,,,,,,,,, -204700,4.4693856,0.6123722,,,,,,,,,,,,,, -204800,4.313058,0.5864278,,,,,,,,,,,,,, -204900,4.3727126,0.69130254,,,,,,,,,,,,,, -205000,5.057336,0.6263166,,,,,,,,,,,,,, -205100,4.8232617,0.66054875,,,,,,,,,,,,,, -205200,4.549309,0.65931916,,,,,,,,,,,,,, -205300,5.266182,0.6870703,,,,,,,,,,,,,, -205397,,,0.9606584906578064,0.1462922245264053,0.7554199695587158,1.0448895692825315,50000.0,0.6310000419616699,1.8229455947875977,10000.0,69416.98277258873,71922.58283758163,69416.98277258873,2490.9919068813324,7.580933094024658,0.0 -205400,4.6038175,0.5697212,,,,,,,,,,,,,, -205500,4.319364,0.5649612,,,,,,,,,,,,,, -205600,4.3081007,0.6891803,,,,,,,,,,,,,, -205700,4.5399647,0.5846633,,,,,,,,,,,,,, -205800,4.5920296,0.7168296,,,,,,,,,,,,,, -205900,4.990364,0.64498556,,,,,,,,,,,,,, -206000,4.091959,0.56309205,,,,,,,,,,,,,, -206100,4.731502,0.5854727,,,,,,,,,,,,,, -206200,4.6480527,0.54232633,,,,,,,,,,,,,, -206300,5.024163,0.692437,,,,,,,,,,,,,, -206400,4.8048797,0.7426253,,,,,,,,,,,,,, -206500,4.8687067,0.6713848,,,,,,,,,,,,,, -206600,4.6024604,0.61376023,,,,,,,,,,,,,, -206700,5.260116,0.6403827,,,,,,,,,,,,,, -206800,4.4840937,0.70828456,,,,,,,,,,,,,, -206900,4.5578156,0.58354324,,,,,,,,,,,,,, -206907,,,0.9609175324440002,0.1488041579723358,0.7552799582481384,1.044255256652832,50000.0,0.6320000290870667,1.8228908777236936,10000.0,69926.99735879898,72449.86073350906,69926.99735879898,2508.140120267868,7.6438798904418945,0.0 -207000,4.3928456,0.60036886,,,,,,,,,,,,,, -207100,4.7555037,0.7129425,,,,,,,,,,,,,, -207200,4.4384294,0.6851153,,,,,,,,,,,,,, -207300,4.448172,0.5712832,,,,,,,,,,,,,, -207400,4.586462,0.6797233,,,,,,,,,,,,,, -207500,4.6957645,0.5895976,,,,,,,,,,,,,, -207600,4.507029,0.6270031,,,,,,,,,,,,,, -207700,4.948182,0.72049105,,,,,,,,,,,,,, -207800,4.729465,0.6282205,,,,,,,,,,,,,, -207900,4.5053635,0.6488037,,,,,,,,,,,,,, -208000,4.7353234,0.60183775,,,,,,,,,,,,,, -208100,4.54572,0.5829582,,,,,,,,,,,,,, -208200,4.720237,0.6461068,,,,,,,,,,,,,, -208300,5.6561747,0.6690878,,,,,,,,,,,,,, -208400,4.175303,0.5830914,,,,,,,,,,,,,, -208417,,,0.9602000713348388,0.1493316441774368,0.7552399635314941,1.0437886714935305,50000.0,0.6306000351905823,1.8218469619750977,10000.0,70436.92092514038,72977.27802371979,70436.92092514038,2525.5167202949524,7.710052490234375,0.0 -208500,4.9068646,0.6252869,,,,,,,,,,,,,, -208600,4.2763014,0.62272537,,,,,,,,,,,,,, -208700,4.3139234,0.624517,,,,,,,,,,,,,, -208800,4.8011923,0.6277851,,,,,,,,,,,,,, -208900,4.7261853,0.6224869,,,,,,,,,,,,,, -209000,4.3157387,0.5823394,,,,,,,,,,,,,, -209100,4.558944,0.55440915,,,,,,,,,,,,,, -209200,4.5894423,0.6654039,,,,,,,,,,,,,, -209300,4.6212378,0.59475285,,,,,,,,,,,,,, -209400,4.7205014,0.6194306,,,,,,,,,,,,,, -209500,4.9040184,0.6523222,,,,,,,,,,,,,, -209600,4.509101,0.5296642,,,,,,,,,,,,,, -209700,4.569057,0.6088838,,,,,,,,,,,,,, -209800,4.3628798,0.6053045,,,,,,,,,,,,,, -209900,4.2194524,0.6626163,,,,,,,,,,,,,, -209928,,,0.9607381820678712,0.1480738967657089,0.7552399635314941,1.044136643409729,50000.0,0.631100058555603,1.8214828968048096,10000.0,70946.95390248299,73504.35719633102,70946.95390248299,2542.445108652115,7.775446891784668,0.0 -210000,4.7542863,0.65154624,,,,,,,,,,,,,, -210100,4.576086,0.61204624,,,,,,,,,,,,,, -210200,4.9756417,0.6224898,,,,,,,,,,,,,, -210300,4.4938903,0.62201524,,,,,,,,,,,,,, -210400,4.8268495,0.67334825,,,,,,,,,,,,,, -210500,4.581256,0.65497154,,,,,,,,,,,,,, -210600,5.026149,0.5924006,,,,,,,,,,,,,, -210700,4.6266794,0.58077615,,,,,,,,,,,,,, -210800,4.029239,0.55699056,,,,,,,,,,,,,, -210900,4.1977468,0.5713648,,,,,,,,,,,,,, -211000,4.26835,0.58002794,,,,,,,,,,,,,, -211100,4.346189,0.5746922,,,,,,,,,,,,,, -211200,4.633463,0.6791212,,,,,,,,,,,,,, -211300,4.158032,0.6087212,,,,,,,,,,,,,, -211400,4.498115,0.6610263,,,,,,,,,,,,,, -211439,,,0.9592434167861938,0.1501782238483429,0.7551800012588501,1.0448315143585205,50000.0,0.6317000389099121,1.823574542999268,10000.0,71457.17092013359,74031.89971113205,71457.17092013359,2559.6543271541595,7.840409278869629,0.0 -211500,5.435941,0.6239417,,,,,,,,,,,,,, -211600,4.471187,0.6114259,,,,,,,,,,,,,, -211700,4.377923,0.6054946,,,,,,,,,,,,,, -211800,4.676623,0.6244736,,,,,,,,,,,,,, -211900,4.630893,0.6916087,,,,,,,,,,,,,, -212000,4.24468,0.596989,,,,,,,,,,,,,, -212100,4.177576,0.5560407,,,,,,,,,,,,,, -212200,4.425124,0.6358645,,,,,,,,,,,,,, -212300,4.592689,0.62929636,,,,,,,,,,,,,, -212400,4.4661884,0.598366,,,,,,,,,,,,,, -212500,4.197128,0.53688085,,,,,,,,,,,,,, -212600,4.6897864,0.69649243,,,,,,,,,,,,,, -212700,4.7382298,0.6276504,,,,,,,,,,,,,, -212800,4.3089476,0.59056365,,,,,,,,,,,,,, -212900,4.908111,0.6912704,,,,,,,,,,,,,, -212949,,,0.9595423936843872,0.1497287303209304,0.7553399801254272,1.0440635681152344,50000.0,0.6307000517845154,1.822056531906128,10000.0,71967.09794926643,74558.99166703224,71967.09794926643,2576.7010428905487,7.906655311584473,0.0 -213000,4.593319,0.68721175,,,,,,,,,,,,,, -213100,4.3401866,0.59571934,,,,,,,,,,,,,, -213200,4.3869023,0.5919128,,,,,,,,,,,,,, -213300,4.733218,0.61332345,,,,,,,,,,,,,, -213400,4.974812,0.6471133,,,,,,,,,,,,,, -213500,4.1864243,0.6425973,,,,,,,,,,,,,, -213600,4.956278,0.6851472,,,,,,,,,,,,,, -213700,4.384967,0.57165784,,,,,,,,,,,,,, -213800,4.6674247,0.66007364,,,,,,,,,,,,,, -213900,4.6287956,0.7025642,,,,,,,,,,,,,, -214000,4.430198,0.6164609,,,,,,,,,,,,,, -214100,4.2664967,0.64618295,,,,,,,,,,,,,, -214200,5.0204644,0.6572935,,,,,,,,,,,,,, -214300,4.4394217,0.6435486,,,,,,,,,,,,,, -214400,4.860818,0.64824307,,,,,,,,,,,,,, -214460,,,0.9602000713348388,0.1484657675027847,0.7551400065422058,1.0437440872192385,50000.0,0.6321000456809998,1.8225747346878047,10000.0,72477.27730298042,75086.28588485718,72477.27730298042,2593.694860935211,7.975393772125244,0.0 -214500,4.2788634,0.6580584,,,,,,,,,,,,,, -214600,4.4331417,0.62079597,,,,,,,,,,,,,, -214700,4.784064,0.6530872,,,,,,,,,,,,,, -214800,4.492331,0.6150024,,,,,,,,,,,,,, -214900,4.5745792,0.6078414,,,,,,,,,,,,,, -215000,4.19227,0.566704,,,,,,,,,,,,,, -215100,4.5394406,0.6072802,,,,,,,,,,,,,, -215200,4.3247766,0.5403453,,,,,,,,,,,,,, -215300,4.28421,0.5914936,,,,,,,,,,,,,, -215400,4.9865966,0.589559,,,,,,,,,,,,,, -215500,4.6609783,0.6527662,,,,,,,,,,,,,, -215600,4.8984923,0.63422346,,,,,,,,,,,,,, -215700,4.2109523,0.5882445,,,,,,,,,,,,,, -215800,4.3727202,0.70158833,,,,,,,,,,,,,, -215900,4.7554955,0.6876545,,,,,,,,,,,,,, -215970,,,0.9599210619926452,0.148515373468399,0.7547399997711182,1.0446696281433103,50000.0,0.6306000351905823,1.823861837387085,10000.0,72987.33560729027,75613.27943301201,72987.33560729027,2610.5095081329346,8.043686389923096,0.0 -216000,4.1258183,0.5556974,,,,,,,,,,,,,, -216100,4.356915,0.6693631,,,,,,,,,,,,,, -216200,4.4576864,0.5953627,,,,,,,,,,,,,, -216300,4.1337357,0.58926386,,,,,,,,,,,,,, -216400,4.5212193,0.6577631,,,,,,,,,,,,,, -216500,4.761735,0.6848096,,,,,,,,,,,,,, -216600,4.4790463,0.6647279,,,,,,,,,,,,,, -216700,4.572527,0.71299845,,,,,,,,,,,,,, -216800,4.404029,0.6303674,,,,,,,,,,,,,, -216900,4.513017,0.6240803,,,,,,,,,,,,,, -217000,4.381716,0.61572826,,,,,,,,,,,,,, -217100,4.6427674,0.6955947,,,,,,,,,,,,,, -217200,4.1259785,0.6214771,,,,,,,,,,,,,, -217300,4.468728,0.634983,,,,,,,,,,,,,, -217400,4.6053505,0.7154045,,,,,,,,,,,,,, -217481,,,0.9605388641357422,0.1486754417419433,0.7548999786376953,1.0447502136230469,50000.0,0.6312000155448914,1.8224695920944207,10000.0,73497.32830357552,76140.42266964912,73497.32830357552,2627.5460596084595,8.105120658874512,0.0 -217500,4.4491277,0.6101242,,,,,,,,,,,,,, -217600,4.6624494,0.6406097,,,,,,,,,,,,,, -217700,4.2392015,0.58867633,,,,,,,,,,,,,, -217800,4.9586873,0.61555934,,,,,,,,,,,,,, -217900,4.397824,0.68214494,,,,,,,,,,,,,, -218000,4.54252,0.6001193,,,,,,,,,,,,,, -218100,4.6671147,0.63593686,,,,,,,,,,,,,, -218200,4.6219025,0.61364096,,,,,,,,,,,,,, -218300,4.464853,0.69341415,,,,,,,,,,,,,, -218400,4.9076433,0.64210826,,,,,,,,,,,,,, -218500,4.5905476,0.600529,,,,,,,,,,,,,, -218600,4.6106224,0.6421922,,,,,,,,,,,,,, -218700,4.6582117,0.6160916,,,,,,,,,,,,,, -218800,4.5923367,0.61440665,,,,,,,,,,,,,, -218900,4.627552,0.61146045,,,,,,,,,,,,,, -218991,,,0.9596021771430968,0.1492205560207367,0.7551199793815613,1.043259620666504,50000.0,0.6301000118255615,1.8205660581588743,10000.0,74007.20657753944,76667.70803070068,74007.20657753944,2644.8311898708344,8.175191640853882,0.0 -219000,4.4567676,0.6696239,,,,,,,,,,,,,, -219100,4.453956,0.60341704,,,,,,,,,,,,,, -219200,4.321008,0.70184636,,,,,,,,,,,,,, -219300,4.6081076,0.67302644,,,,,,,,,,,,,, -219400,4.4794655,0.6265862,,,,,,,,,,,,,, -219500,4.4504046,0.60620505,,,,,,,,,,,,,, -219600,4.3075137,0.6100622,,,,,,,,,,,,,, -219700,4.6486425,0.713946,,,,,,,,,,,,,, -219800,4.5624385,0.649353,,,,,,,,,,,,,, -219900,4.5776114,0.71791697,,,,,,,,,,,,,, -220000,4.6685195,0.629786,,,,,,,,,,,,,, -220100,4.4886312,0.61049235,,,,,,,,,,,,,, -220200,4.2144327,0.599754,,,,,,,,,,,,,, -220300,4.360977,0.5650327,,,,,,,,,,,,,, -220400,4.930068,0.70231044,,,,,,,,,,,,,, -220500,4.4948688,0.65876013,,,,,,,,,,,,,, -220501,,,0.9606186151504515,0.1475935280323028,0.7549399733543396,1.044519543647766,50000.0,0.6314000487327576,1.822311758995056,10000.0,74517.34069633484,77194.7726213932,74517.34069633484,2661.6404991149902,8.243839025497437,0.0 -220600,4.178011,0.66269344,,,,,,,,,,,,,, -220700,4.7260537,0.6610024,,,,,,,,,,,,,, -220800,4.7470164,0.6899575,,,,,,,,,,,,,, -220900,4.2736616,0.57384,,,,,,,,,,,,,, -221000,4.4442725,0.579035,,,,,,,,,,,,,, -221100,4.9539104,0.6914513,,,,,,,,,,,,,, -221200,5.1024756,0.7480687,,,,,,,,,,,,,, -221300,4.586971,0.66194564,,,,,,,,,,,,,, -221400,4.4547915,0.5494964,,,,,,,,,,,,,, -221500,4.453851,0.596562,,,,,,,,,,,,,, -221600,4.728398,0.6370202,,,,,,,,,,,,,, -221700,4.6055903,0.64427525,,,,,,,,,,,,,, -221800,4.559409,0.6388829,,,,,,,,,,,,,, -221900,4.330603,0.56262314,,,,,,,,,,,,,, -222000,4.1312785,0.5919408,,,,,,,,,,,,,, -222012,,,0.960339605808258,0.1480839848518371,0.7552199959754944,1.0450679063796997,50000.0,0.6314000487327576,1.8228307962417605,10000.0,75027.50423502922,77721.93538951874,75027.50423502922,2678.5212664604187,8.310691595077515,0.0 -222100,4.7464623,0.65776914,,,,,,,,,,,,,, -222200,4.3318,0.6007027,,,,,,,,,,,,,, -222300,3.9977837,0.5554819,,,,,,,,,,,,,, -222400,4.1482687,0.6017767,,,,,,,,,,,,,, -222500,4.1252317,0.57427007,,,,,,,,,,,,,, -222600,4.249686,0.6469556,,,,,,,,,,,,,, -222700,4.569852,0.632615,,,,,,,,,,,,,, -222800,4.603521,0.7189339,,,,,,,,,,,,,, -222900,4.510227,0.6553406,,,,,,,,,,,,,, -223000,4.6836534,0.66122425,,,,,,,,,,,,,, -223100,4.30849,0.64538765,,,,,,,,,,,,,, -223200,4.7352543,0.61574996,,,,,,,,,,,,,, -223300,4.8625736,0.6506151,,,,,,,,,,,,,, -223400,4.2672048,0.6266523,,,,,,,,,,,,,, -223500,4.439275,0.6031804,,,,,,,,,,,,,, -223522,,,0.9602798223495485,0.148158848285675,0.7552399635314941,1.04426109790802,50000.0,0.6297000050544739,1.822348356246948,10000.0,75537.42972803116,78249.01416754723,75537.42972803116,2695.558828353882,8.374577283859253,0.0 -223600,4.805964,0.6802703,,,,,,,,,,,,,, -223700,4.505285,0.625333,,,,,,,,,,,,,, -223800,4.436546,0.5964159,,,,,,,,,,,,,, -223900,4.6032963,0.66339636,,,,,,,,,,,,,, -224000,4.347881,0.58599573,,,,,,,,,,,,,, -224100,4.623231,0.686357,,,,,,,,,,,,,, -224200,4.6842294,0.62802255,,,,,,,,,,,,,, -224300,4.328017,0.55209523,,,,,,,,,,,,,, -224400,4.3916397,0.5918156,,,,,,,,,,,,,, -224500,4.1351924,0.53182834,,,,,,,,,,,,,, -224600,4.7312756,0.5771185,,,,,,,,,,,,,, -224700,4.6399045,0.64606833,,,,,,,,,,,,,, -224800,4.3173137,0.58498216,,,,,,,,,,,,,, -224900,4.2409816,0.6511154,,,,,,,,,,,,,, -225000,4.8774548,0.6753672,,,,,,,,,,,,,, -225032,,,0.960339605808258,0.1490835100412368,0.7552799582481384,1.0442862510681152,50000.0,0.6327000260353088,1.8209985494613647,10000.0,76047.38732123375,78776.11684513092,76047.38732123375,2712.5811915397644,8.44601035118103,0.0 -225100,4.7439127,0.6038723,,,,,,,,,,,,,, -225200,4.932991,0.6429987,,,,,,,,,,,,,, -225300,4.3095365,0.56530774,,,,,,,,,,,,,, -225400,4.401693,0.6179778,,,,,,,,,,,,,, -225500,4.771032,0.6316274,,,,,,,,,,,,,, -225600,4.519408,0.5478908,,,,,,,,,,,,,, -225700,4.874821,0.7322155,,,,,,,,,,,,,, -225800,5.954073,0.7159186,,,,,,,,,,,,,, -225900,4.488745,0.58252835,,,,,,,,,,,,,, -226000,4.354525,0.61848265,,,,,,,,,,,,,, -226100,4.510693,0.6584169,,,,,,,,,,,,,, -226200,4.770939,0.6414237,,,,,,,,,,,,,, -226300,4.3723845,0.5989124,,,,,,,,,,,,,, -226400,4.6908507,0.57888806,,,,,,,,,,,,,, -226500,4.4955063,0.5905864,,,,,,,,,,,,,, -226543,,,0.9614157676696776,0.1448323428630828,0.7552799582481384,1.043460488319397,50000.0,0.6315000057220459,1.82106614112854,10000.0,76557.46827101707,79303.39326834679,76557.46827101707,2729.6561329364777,8.513875246047974,0.0 -226600,4.390506,0.62839764,,,,,,,,,,,,,, -226700,4.604,0.6804487,,,,,,,,,,,,,, -226800,4.2562623,0.6086451,,,,,,,,,,,,,, -226900,4.0883617,0.581427,,,,,,,,,,,,,, -227000,4.3484645,0.5887928,,,,,,,,,,,,,, -227100,4.745682,0.6709094,,,,,,,,,,,,,, -227200,4.934975,0.66975904,,,,,,,,,,,,,, -227300,4.3717217,0.5956539,,,,,,,,,,,,,, -227400,4.67544,0.61896527,,,,,,,,,,,,,, -227500,5.021784,0.6878319,,,,,,,,,,,,,, -227600,5.2664275,0.6535903,,,,,,,,,,,,,, -227700,4.9816737,0.6458747,,,,,,,,,,,,,, -227800,4.4201784,0.59913605,,,,,,,,,,,,,, -227900,4.5652275,0.653257,,,,,,,,,,,,,, -228000,4.313815,0.6025061,,,,,,,,,,,,,, -228054,,,0.9608178734779358,0.1471483409404754,0.7548399567604065,1.044750094413757,50000.0,0.6314000487327576,1.82277250289917,10000.0,77067.64736413956,79830.59428882599,77067.64736413956,2746.561534643173,8.578904390335083,0.0 -228100,4.4520426,0.63728315,,,,,,,,,,,,,, -228200,4.2885814,0.65806955,,,,,,,,,,,,,, -228300,4.4167714,0.67082715,,,,,,,,,,,,,, -228400,4.3863306,0.59555304,,,,,,,,,,,,,, -228500,4.722205,0.7154317,,,,,,,,,,,,,, -228600,4.4885287,0.6266887,,,,,,,,,,,,,, -228700,4.1721454,0.5961648,,,,,,,,,,,,,, -228800,4.217614,0.57483935,,,,,,,,,,,,,, -228900,4.534977,0.63271105,,,,,,,,,,,,,, -229000,4.7726336,0.56241167,,,,,,,,,,,,,, -229100,4.726261,0.6770395,,,,,,,,,,,,,, -229200,4.7136216,0.6148176,,,,,,,,,,,,,, -229300,4.6763625,0.6577493,,,,,,,,,,,,,, -229400,4.820639,0.62261146,,,,,,,,,,,,,, -229500,5.070105,0.6313708,,,,,,,,,,,,,, -229565,,,0.9593231678009032,0.1507213413715362,0.7548999786376953,1.0427016019821167,50000.0,0.631600022315979,1.81840181350708,10000.0,77577.66748857498,80358.41845607758,77577.66748857498,2764.244913339615,8.647281408309937,0.0 -229600,5.1211667,0.71325135,,,,,,,,,,,,,, -229700,4.4011855,0.5965742,,,,,,,,,,,,,, -229800,4.7787914,0.6757574,,,,,,,,,,,,,, -229900,4.134889,0.5418009,,,,,,,,,,,,,, -230000,5.590955,0.5618908,,,,,,,,,,,,,, -230100,4.605025,0.7075874,,,,,,,,,,,,,, -230200,4.3422465,0.546373,,,,,,,,,,,,,, -230300,5.2350454,0.73266506,,,,,,,,,,,,,, -230400,4.243898,0.6315829,,,,,,,,,,,,,, -230500,4.5534625,0.65086174,,,,,,,,,,,,,, -230600,4.752192,0.6636396,,,,,,,,,,,,,, -230700,5.095875,0.64277846,,,,,,,,,,,,,, -230800,4.677676,0.6064962,,,,,,,,,,,,,, -230900,4.200425,0.6049631,,,,,,,,,,,,,, -231000,4.401554,0.58374673,,,,,,,,,,,,,, -231075,,,0.9600605964660645,0.1495428681373596,0.7551999688148499,1.043567180633545,50000.0,0.6315000057220459,1.820623159408569,10000.0,78087.64923286438,80885.27916908264,78087.64923286438,2781.0031356811523,8.715314626693726,0.0 -231100,4.7042756,0.674844,,,,,,,,,,,,,, -231200,4.350669,0.6183909,,,,,,,,,,,,,, -231300,4.562873,0.5904005,,,,,,,,,,,,,, -231400,4.37377,0.63769543,,,,,,,,,,,,,, -231500,4.4765353,0.5923459,,,,,,,,,,,,,, -231600,4.633013,0.6106999,,,,,,,,,,,,,, -231700,4.791099,0.68863386,,,,,,,,,,,,,, -231800,4.0729074,0.63231057,,,,,,,,,,,,,, -231900,4.6275544,0.5923872,,,,,,,,,,,,,, -232000,4.626075,0.62654406,,,,,,,,,,,,,, -232100,4.2500114,0.569809,,,,,,,,,,,,,, -232200,4.3280263,0.59729147,,,,,,,,,,,,,, -232300,4.439908,0.65294814,,,,,,,,,,,,,, -232400,5.1562953,0.6708871,,,,,,,,,,,,,, -232500,4.4389834,0.5992921,,,,,,,,,,,,,, -232585,,,0.9606783986091614,0.1485619843006134,0.7554000020027161,1.0447583198547363,50000.0,0.6317000389099121,1.8235143423080444,10000.0,78597.64774560928,81412.56066179276,78597.64774560928,2798.164034128189,8.785706281661987,0.0 -232600,4.7788353,0.6558602,,,,,,,,,,,,,, -232700,4.719293,0.64697564,,,,,,,,,,,,,, -232800,4.0900025,0.56646705,,,,,,,,,,,,,, -232900,4.91889,0.6095057,,,,,,,,,,,,,, -233000,4.6990232,0.6352985,,,,,,,,,,,,,, -233100,4.1216636,0.62448007,,,,,,,,,,,,,, -233200,4.562377,0.6132242,,,,,,,,,,,,,, -233300,4.537544,0.63195825,,,,,,,,,,,,,, -233400,4.9013953,0.6501586,,,,,,,,,,,,,, -233500,4.3405986,0.6280087,,,,,,,,,,,,,, -233600,4.8295617,0.60871434,,,,,,,,,,,,,, -233700,4.667642,0.6561295,,,,,,,,,,,,,, -233800,4.435598,0.65586525,,,,,,,,,,,,,, -233900,4.441266,0.6076317,,,,,,,,,,,,,, -234000,4.487857,0.61821705,,,,,,,,,,,,,, -234095,,,0.9609375,0.1463867574930191,0.7549799680709839,1.0434765815734863,50000.0,0.6314000487327576,1.821435809135437,10000.0,79107.72985172272,81939.9204120636,79107.72985172272,2815.323947429657,8.851818084716797,0.0 -234100,4.576126,0.612932,,,,,,,,,,,,,, -234200,4.3968787,0.656703,,,,,,,,,,,,,, -234300,4.177015,0.5526611,,,,,,,,,,,,,, -234400,4.300816,0.605359,,,,,,,,,,,,,, -234500,4.2837734,0.59445155,,,,,,,,,,,,,, -234600,4.348336,0.57764655,,,,,,,,,,,,,, -234700,4.5578,0.7009852,,,,,,,,,,,,,, -234800,4.489924,0.62267977,,,,,,,,,,,,,, -234900,4.692359,0.66697115,,,,,,,,,,,,,, -235000,4.2235446,0.5654581,,,,,,,,,,,,,, -235100,4.317966,0.5977329,,,,,,,,,,,,,, -235200,4.5742393,0.61027336,,,,,,,,,,,,,, -235300,4.1996737,0.5704939,,,,,,,,,,,,,, -235400,4.665511,0.6089844,,,,,,,,,,,,,, -235500,4.276324,0.5639356,,,,,,,,,,,,,, -235600,4.310693,0.63096553,,,,,,,,,,,,,, -235605,,,0.9598214030265808,0.1489105075597763,0.7550999522209167,1.0432418584823608,50000.0,0.6312000155448914,1.820827841758728,10000.0,79617.69214963913,82466.77335429192,79617.69214963913,2832.093398332596,8.920358657836914,0.0 -235700,4.461544,0.6523143,,,,,,,,,,,,,, -235800,4.7428775,0.67357016,,,,,,,,,,,,,, -235900,4.63321,0.6858756,,,,,,,,,,,,,, -236000,4.4104853,0.6447526,,,,,,,,,,,,,, -236100,4.651535,0.66349316,,,,,,,,,,,,,, -236200,4.371999,0.6332014,,,,,,,,,,,,,, -236300,4.0464787,0.5599696,,,,,,,,,,,,,, -236400,4.4089074,0.5975284,,,,,,,,,,,,,, -236500,4.4263544,0.60162866,,,,,,,,,,,,,, -236600,4.4248934,0.6072298,,,,,,,,,,,,,, -236700,4.133275,0.57717395,,,,,,,,,,,,,, -236800,4.3908186,0.58200455,,,,,,,,,,,,,, -236900,4.439224,0.59557396,,,,,,,,,,,,,, -237000,4.5672503,0.588237,,,,,,,,,,,,,, -237100,4.289353,0.58743054,,,,,,,,,,,,,, -237115,,,0.9611168503761292,0.1452362686395645,0.7553799748420715,1.044255256652832,50000.0,0.6315000057220459,1.822920203208924,10000.0,80127.71584177017,82994.17069745064,80127.71584177017,2849.3516008853912,8.984033823013306,0.0 -237200,4.3241954,0.6683086,,,,,,,,,,,,,, -237300,4.9096217,0.63938963,,,,,,,,,,,,,, -237400,4.5456314,0.6606314,,,,,,,,,,,,,, -237500,4.878351,0.6611473,,,,,,,,,,,,,, -237600,4.17398,0.6270728,,,,,,,,,,,,,, -237700,4.5169263,0.6540743,,,,,,,,,,,,,, -237800,4.5455375,0.682705,,,,,,,,,,,,,, -237900,4.9128904,0.6138406,,,,,,,,,,,,,, -238000,4.5117073,0.6940821,,,,,,,,,,,,,, -238100,4.426141,0.59419644,,,,,,,,,,,,,, -238200,4.3809814,0.64987475,,,,,,,,,,,,,, -238300,4.7534213,0.6506714,,,,,,,,,,,,,, -238400,4.5603423,0.6680853,,,,,,,,,,,,,, -238500,4.6965823,0.63635087,,,,,,,,,,,,,, -238564,,,0.9593430757522584,0.1516030877828598,0.7551599740982056,1.0438953638076782,50000.0,0.6309000253677368,1.8205711841583248,10000.0,80637.91840529442,83521.47772955894,80637.91840529442,2866.3377170562744,9.051939964294434,0.0 -238600,4.2047167,0.5814841,,,,,,,,,,,,,, -238700,4.47938,0.6001013,,,,,,,,,,,,,, -238800,4.8971286,0.62272584,,,,,,,,,,,,,, -238900,4.5941763,0.6556225,,,,,,,,,,,,,, -239000,4.4053006,0.6024153,,,,,,,,,,,,,, -239100,4.6278024,0.607608,,,,,,,,,,,,,, -239200,4.64566,0.6553988,,,,,,,,,,,,,, -239300,4.315225,0.6977055,,,,,,,,,,,,,, -239400,4.43072,0.631137,,,,,,,,,,,,,, -239500,4.8406234,0.6244812,,,,,,,,,,,,,, -239600,5.023853,0.7040487,,,,,,,,,,,,,, -239700,5.2545586,0.6562153,,,,,,,,,,,,,, -239800,4.9387774,0.6578702,,,,,,,,,,,,,, -239900,4.281119,0.57374984,,,,,,,,,,,,,, -240000,4.1983433,0.60122263,,,,,,,,,,,,,, -240075,,,0.961535394191742,0.1448425352573394,0.7551599740982056,1.0428738594055176,50000.0,0.6315000057220459,1.82076632976532,10000.0,81147.93720412254,84048.75836634636,81147.93720412254,2883.4794404506683,9.120002508163452,0.0 -240100,4.209437,0.5729371,,,,,,,,,,,,,, -240200,4.3889585,0.63076997,,,,,,,,,,,,,, -240300,4.7331,0.6349726,,,,,,,,,,,,,, -240400,4.176606,0.54470974,,,,,,,,,,,,,, -240500,4.4615836,0.7166108,,,,,,,,,,,,,, -240600,4.6834373,0.71791244,,,,,,,,,,,,,, -240700,4.990655,0.60549396,,,,,,,,,,,,,, -240800,4.6866097,0.684465,,,,,,,,,,,,,, -240900,4.295915,0.5497507,,,,,,,,,,,,,, -241000,4.312118,0.61275023,,,,,,,,,,,,,, -241100,4.164474,0.5481966,,,,,,,,,,,,,, -241200,4.987616,0.6726587,,,,,,,,,,,,,, -241300,4.585641,0.6110179,,,,,,,,,,,,,, -241400,4.8607225,0.60199225,,,,,,,,,,,,,, -241500,4.335467,0.66251963,,,,,,,,,,,,,, -241586,,,0.9617147445678712,0.1425634920597076,0.7549200057983398,1.0439531803131104,50000.0,0.6305000185966492,1.82180655002594,10000.0,81658.0580971241,84576.08765101433,81658.0580971241,2900.567331552505,9.188629388809204,0.0 -241600,4.492281,0.62707025,,,,,,,,,,,,,, -241700,3.9917922,0.57647216,,,,,,,,,,,,,, -241800,4.658312,0.6030885,,,,,,,,,,,,,, -241900,5.3070693,0.7278205,,,,,,,,,,,,,, -242000,4.8985567,0.6305877,,,,,,,,,,,,,, -242100,4.646127,0.67723197,,,,,,,,,,,,,, -242200,4.7730603,0.59916794,,,,,,,,,,,,,, -242300,4.5602574,0.649003,,,,,,,,,,,,,, -242400,4.314344,0.64756,,,,,,,,,,,,,, -242500,4.579018,0.6420215,,,,,,,,,,,,,, -242600,4.445739,0.69809884,,,,,,,,,,,,,, -242700,4.8017745,0.6302212,,,,,,,,,,,,,, -242800,4.3665624,0.6474363,,,,,,,,,,,,,, -242900,4.77606,0.6302555,,,,,,,,,,,,,, -243000,4.301306,0.5724236,,,,,,,,,,,,,, -243096,,,0.96000075340271,0.148179680109024,0.7549200057983398,1.0450206995010376,50000.0,0.629800021648407,1.821830153465271,10000.0,82167.9572262764,85102.96196842194,82167.9572262764,2917.415878772736,9.261387825012209,0.0 -243100,4.07852,0.5625188,,,,,,,,,,,,,, -243200,4.4559774,0.6059623,,,,,,,,,,,,,, -243300,4.285787,0.5590487,,,,,,,,,,,,,, -243400,4.597713,0.6119468,,,,,,,,,,,,,, -243500,4.5649457,0.64309937,,,,,,,,,,,,,, -243600,3.971567,0.5910971,,,,,,,,,,,,,, -243700,3.8994665,0.57758117,,,,,,,,,,,,,, -243800,4.3109403,0.655706,,,,,,,,,,,,,, -243900,4.61981,0.68626064,,,,,,,,,,,,,, -244000,4.5123177,0.6327007,,,,,,,,,,,,,, -244100,4.583943,0.66948354,,,,,,,,,,,,,, -244200,4.3362145,0.6058992,,,,,,,,,,,,,, -244300,4.343411,0.54319525,,,,,,,,,,,,,, -244400,4.2761827,0.5930059,,,,,,,,,,,,,, -244500,4.666833,0.6554733,,,,,,,,,,,,,, -244600,4.3390217,0.5454259,,,,,,,,,,,,,, -244606,,,0.9604990482330322,0.1490000039339065,0.7550399899482727,1.0443910360336304,50000.0,0.6308000087738037,1.8226317167282104,10000.0,82677.94487500191,85630.08571457863,82677.94487500191,2934.426279783249,9.335773944854736,0.0 -244700,5.148706,0.68094975,,,,,,,,,,,,,, -244800,4.682292,0.6249314,,,,,,,,,,,,,, -244900,4.449961,0.614531,,,,,,,,,,,,,, -245000,4.3768735,0.6665324,,,,,,,,,,,,,, -245100,4.0578156,0.61738473,,,,,,,,,,,,,, -245200,4.6323175,0.66366637,,,,,,,,,,,,,, -245300,3.9899085,0.5930494,,,,,,,,,,,,,, -245400,4.545382,0.6632805,,,,,,,,,,,,,, -245500,4.2372184,0.6604753,,,,,,,,,,,,,, -245600,4.498381,0.67507607,,,,,,,,,,,,,, -245700,4.737137,0.63721734,,,,,,,,,,,,,, -245800,4.022543,0.53652096,,,,,,,,,,,,,, -245900,4.4705777,0.65923226,,,,,,,,,,,,,, -246000,4.4763417,0.59457433,,,,,,,,,,,,,, -246100,4.3577566,0.6228123,,,,,,,,,,,,,, -246117,,,0.9610570669174194,0.1478585451841354,0.7549200057983398,1.044091820716858,50000.0,0.631100058555603,1.820813775062561,10000.0,83188.13959169388,86157.36020231247,83188.13959169388,2951.382652282715,9.406443357467651,0.0 -246200,4.7832885,0.7289177,,,,,,,,,,,,,, -246300,4.7539477,0.6144646,,,,,,,,,,,,,, -246400,4.384066,0.6439919,,,,,,,,,,,,,, -246500,4.6798553,0.7027557,,,,,,,,,,,,,, -246600,4.1657043,0.5664387,,,,,,,,,,,,,, -246700,4.8151183,0.6016749,,,,,,,,,,,,,, -246800,4.3365026,0.6332079,,,,,,,,,,,,,, -246900,3.9434543,0.54987466,,,,,,,,,,,,,, -247000,4.659525,0.66481334,,,,,,,,,,,,,, -247100,4.6461535,0.6577378,,,,,,,,,,,,,, -247200,4.896055,0.654925,,,,,,,,,,,,,, -247300,4.744791,0.69963515,,,,,,,,,,,,,, -247400,4.629317,0.6475131,,,,,,,,,,,,,, -247500,4.596103,0.5772624,,,,,,,,,,,,,, -247600,4.148288,0.6223045,,,,,,,,,,,,,, -247628,,,0.959622085094452,0.148349180817604,0.7554799914360046,1.0434871912002563,50000.0,0.6317000389099121,1.822364330291748,10000.0,83698.27619314194,86684.73764586449,83698.27619314194,2968.502408027649,9.47542691230774,0.0 -247700,4.4104476,0.6564216,,,,,,,,,,,,,, -247800,4.8040066,0.66632324,,,,,,,,,,,,,, -247900,4.4846883,0.66038203,,,,,,,,,,,,,, -248000,4.8370395,0.6104312,,,,,,,,,,,,,, -248100,4.605533,0.627002,,,,,,,,,,,,,, -248200,4.6724763,0.68443054,,,,,,,,,,,,,, -248300,4.264789,0.6071272,,,,,,,,,,,,,, -248400,4.301975,0.61788535,,,,,,,,,,,,,, -248500,4.4476776,0.6204152,,,,,,,,,,,,,, -248600,4.814029,0.6585674,,,,,,,,,,,,,, -248700,4.5022826,0.6160273,,,,,,,,,,,,,, -248800,4.4156218,0.62054306,,,,,,,,,,,,,, -248900,4.3355656,0.6373346,,,,,,,,,,,,,, -249000,4.008137,0.60040534,,,,,,,,,,,,,, -249100,4.5433984,0.5755785,,,,,,,,,,,,,, -249138,,,0.959382951259613,0.1505030393600464,0.7555599808692932,1.044371485710144,50000.0,0.6308000087738037,1.82217025756836,10000.0,84208.43679046631,87212.13094115257,84208.43679046631,2985.615665435791,9.542508840560911,0.0 -249200,4.3190722,0.5835178,,,,,,,,,,,,,, -249300,4.499891,0.64616704,,,,,,,,,,,,,, -249400,4.4903617,0.61553526,,,,,,,,,,,,,, -249500,4.4433904,0.65271413,,,,,,,,,,,,,, -249600,4.2788653,0.5827369,,,,,,,,,,,,,, -249700,4.0742025,0.56852466,,,,,,,,,,,,,, -249800,4.7687435,0.58244884,,,,,,,,,,,,,, -249900,4.8011327,0.65397066,,,,,,,,,,,,,, -250000,4.2892427,0.6005969,,,,,,,,,,,,,, -250100,4.2357836,0.59116495,,,,,,,,,,,,,, -250200,4.6454487,0.60328674,,,,,,,,,,,,,, -250300,4.5675735,0.7094201,,,,,,,,,,,,,, -250400,4.733259,0.6625702,,,,,,,,,,,,,, -250500,4.747846,0.6082971,,,,,,,,,,,,,, -250600,4.5302434,0.60082865,,,,,,,,,,,,,, -250648,,,0.9588448405265808,0.151751235127449,0.7554000020027161,1.0438597202301023,50000.0,0.6315000057220459,1.8228862285614007,10000.0,84718.3010494709,87738.89959907532,84718.3010494709,3002.359689235688,9.651327848434448,0.0 -250700,4.2964277,0.5472205,,,,,,,,,,,,,, -250800,4.3702717,0.58103496,,,,,,,,,,,,,, -250900,4.1226807,0.5761564,,,,,,,,,,,,,, -251000,4.775018,0.6391282,,,,,,,,,,,,,, -251100,4.42555,0.5893514,,,,,,,,,,,,,, -251200,4.4812226,0.68454003,,,,,,,,,,,,,, -251300,4.599976,0.65174145,,,,,,,,,,,,,, -251400,4.573093,0.6793418,,,,,,,,,,,,,, -251500,4.2411714,0.57460684,,,,,,,,,,,,,, -251600,4.4635725,0.6386864,,,,,,,,,,,,,, -251700,4.5894823,0.66992915,,,,,,,,,,,,,, -251800,4.3631067,0.6070194,,,,,,,,,,,,,, -251900,4.951169,0.70656544,,,,,,,,,,,,,, -252000,5.070484,0.555972,,,,,,,,,,,,,, -252100,4.965123,0.6581754,,,,,,,,,,,,,, -252158,,,0.960957407951355,0.1462645232677459,0.7552199959754944,1.043168544769287,50000.0,0.6318000555038452,1.8210536241531368,10000.0,85228.23076438904,88265.92456793785,85228.23076438904,3019.3289988040924,9.725855588912964,0.0 -252200,4.711724,0.61500704,,,,,,,,,,,,,, -252300,4.540889,0.6386044,,,,,,,,,,,,,, -252400,4.534395,0.5842401,,,,,,,,,,,,,, -252500,4.536348,0.5588568,,,,,,,,,,,,,, -252600,4.229348,0.5217196,,,,,,,,,,,,,, -252700,5.2467456,0.6786554,,,,,,,,,,,,,, -252800,3.9729846,0.551341,,,,,,,,,,,,,, -252900,4.3489513,0.63953865,,,,,,,,,,,,,, -253000,4.5719686,0.59233105,,,,,,,,,,,,,, -253100,4.292838,0.5940637,,,,,,,,,,,,,, -253200,4.9548254,0.67971003,,,,,,,,,,,,,, -253300,4.398703,0.6524621,,,,,,,,,,,,,, -253400,4.3872414,0.6028265,,,,,,,,,,,,,, -253500,4.4552927,0.6017089,,,,,,,,,,,,,, -253600,4.747718,0.5942421,,,,,,,,,,,,,, -253669,,,0.959980845451355,0.1491307467222213,0.7551800012588501,1.044544696807861,50000.0,0.6310000419616699,1.8215969800949097,10000.0,85738.42370462418,88793.33317494392,85738.42370462418,3036.42351937294,9.79535961151123,0.0 -253700,4.68071,0.6464093,,,,,,,,,,,,,, -253800,5.1012416,0.6119719,,,,,,,,,,,,,, -253900,5.0743313,0.62607086,,,,,,,,,,,,,, -254000,4.198392,0.5999853,,,,,,,,,,,,,, -254100,4.1205444,0.5967498,,,,,,,,,,,,,, -254200,5.089859,0.6002376,,,,,,,,,,,,,, -254300,4.3453703,0.6600702,,,,,,,,,,,,,, -254400,4.414814,0.6191368,,,,,,,,,,,,,, -254500,4.5954385,0.6885971,,,,,,,,,,,,,, -254600,4.52661,0.66943103,,,,,,,,,,,,,, -254700,4.4215364,0.6948713,,,,,,,,,,,,,, -254800,4.4432535,0.613156,,,,,,,,,,,,,, -254900,4.8878865,0.75167394,,,,,,,,,,,,,, -255000,5.006735,0.7068576,,,,,,,,,,,,,, -255100,4.5133224,0.6900573,,,,,,,,,,,,,, -255179,,,0.9600805044174194,0.1473139524459839,0.7553399801254272,1.0440517663955688,50000.0,0.6305000185966492,1.822287559509277,10000.0,86248.32073330879,89320.47562503815,86248.32073330879,3053.542221546173,9.870186805725098,0.0 -255200,4.834209,0.6590814,,,,,,,,,,,,,, -255300,4.543606,0.65166116,,,,,,,,,,,,,, -255400,4.596297,0.5904815,,,,,,,,,,,,,, -255500,4.7770743,0.6868745,,,,,,,,,,,,,, -255600,4.407538,0.60208094,,,,,,,,,,,,,, -255700,4.44083,0.5756657,,,,,,,,,,,,,, -255800,4.2687163,0.5836,,,,,,,,,,,,,, -255900,4.6883807,0.70231974,,,,,,,,,,,,,, -256000,4.489092,0.6142851,,,,,,,,,,,,,, -256100,4.3642445,0.6438633,,,,,,,,,,,,,, -256200,4.264015,0.58963054,,,,,,,,,,,,,, -256300,4.564598,0.657843,,,,,,,,,,,,,, -256400,4.5349097,0.7136867,,,,,,,,,,,,,, -256500,4.172427,0.6262115,,,,,,,,,,,,,, -256600,4.442971,0.6592424,,,,,,,,,,,,,, -256690,,,0.9594427347183228,0.1510674059391021,0.7545799612998962,1.0446027517318726,50000.0,0.6304000020027161,1.8216809034347528,10000.0,86758.5216627121,89847.73591923714,86758.5216627121,3070.470389842987,9.949352264404297,0.0 -256700,4.3439484,0.5983147,,,,,,,,,,,,,, -256800,4.5542884,0.6268838,,,,,,,,,,,,,, -256900,4.5969934,0.6645356,,,,,,,,,,,,,, -257000,4.622499,0.6208744,,,,,,,,,,,,,, -257100,4.9437027,0.6275606,,,,,,,,,,,,,, -257200,5.1187162,0.67370176,,,,,,,,,,,,,, -257300,4.7874026,0.6894151,,,,,,,,,,,,,, -257400,4.597794,0.63840145,,,,,,,,,,,,,, -257500,4.488408,0.59691644,,,,,,,,,,,,,, -257600,4.2435985,0.5949024,,,,,,,,,,,,,, -257700,4.613808,0.6847792,,,,,,,,,,,,,, -257800,4.113797,0.59633005,,,,,,,,,,,,,, -257900,4.768066,0.6571132,,,,,,,,,,,,,, -258000,4.5080047,0.71672827,,,,,,,,,,,,,, -258100,4.690427,0.6416718,,,,,,,,,,,,,, -258200,,,0.960758090019226,0.1466726660728454,0.754859983921051,1.0436691045761108,50000.0,0.6319000124931335,1.8211063146591189,10000.0,87268.53696084023,90375.06952118874,87268.53696084023,3087.653629779816,10.03275179862976,0.0 -258200,4.604303,0.6461203,,,,,,,,,,,,,, -258300,4.1820674,0.6300021,,,,,,,,,,,,,, -258400,4.2694435,0.5839321,,,,,,,,,,,,,, -258500,5.1335754,0.5954325,,,,,,,,,,,,,, -258600,4.724871,0.63992757,,,,,,,,,,,,,, -258700,4.4126997,0.62545913,,,,,,,,,,,,,, -258800,4.9600673,0.6152868,,,,,,,,,,,,,, -258900,4.410621,0.6486968,,,,,,,,,,,,,, -259000,4.525339,0.61064976,,,,,,,,,,,,,, -259100,4.8767195,0.6076918,,,,,,,,,,,,,, -259200,4.074328,0.52843285,,,,,,,,,,,,,, -259300,4.455363,0.64096344,,,,,,,,,,,,,, -259400,4.2511797,0.5766024,,,,,,,,,,,,,, -259500,4.2159667,0.5779276,,,,,,,,,,,,,, -259600,4.8729277,0.690271,,,,,,,,,,,,,, -259700,4.607116,0.5992384,,,,,,,,,,,,,, -259711,,,0.9605388641357422,0.1487253457307815,0.7547599673271179,1.0438485145568848,50000.0,0.6309000253677368,1.822376608848572,10000.0,87778.56299948692,90902.21386170389,87778.56299948692,3104.648650407791,10.103766679763794,0.0 -259800,4.071634,0.60944694,,,,,,,,,,,,,, -259900,4.5563993,0.66322845,,,,,,,,,,,,,, -260000,4.64828,0.65997416,,,,,,,,,,,,,, -260100,4.2853246,0.62415487,,,,,,,,,,,,,, -260200,4.204811,0.55164456,,,,,,,,,,,,,, -260300,4.188662,0.60962427,,,,,,,,,,,,,, -260400,4.8377743,0.5876225,,,,,,,,,,,,,, -260500,4.2416096,0.5695317,,,,,,,,,,,,,, -260600,5.05253,0.65819705,,,,,,,,,,,,,, -260700,4.693852,0.66887885,,,,,,,,,,,,,, -260800,4.901684,0.5970998,,,,,,,,,,,,,, -260900,4.4452634,0.6432701,,,,,,,,,,,,,, -261000,4.698372,0.61666423,,,,,,,,,,,,,, -261100,4.4031324,0.60301656,,,,,,,,,,,,,, -261200,4.430988,0.64135283,,,,,,,,,,,,,, -261222,,,0.959741711616516,0.149654671549797,0.7551199793815613,1.0438201427459717,50000.0,0.6309000253677368,1.820735216140747,10000.0,88288.66622662544,91429.70258498192,88288.66622662544,3121.9073457717896,10.17822551727295,0.0 -261300,4.6952405,0.6093369,,,,,,,,,,,,,, -261400,4.257907,0.5976832,,,,,,,,,,,,,, -261500,4.3186226,0.56061286,,,,,,,,,,,,,, -261600,4.611245,0.60514104,,,,,,,,,,,,,, -261700,3.9990456,0.56759,,,,,,,,,,,,,, -261800,4.357669,0.5847902,,,,,,,,,,,,,, -261900,4.0005393,0.5184863,,,,,,,,,,,,,, -262000,4.276822,0.5969049,,,,,,,,,,,,,, -262100,4.9852705,0.6398562,,,,,,,,,,,,,, -262200,4.3097763,0.60688835,,,,,,,,,,,,,, -262300,4.5557637,0.5920416,,,,,,,,,,,,,, -262400,4.222383,0.6149694,,,,,,,,,,,,,, -262500,4.5180492,0.6124533,,,,,,,,,,,,,, -262600,4.3208547,0.60770154,,,,,,,,,,,,,, -262700,4.226806,0.5932485,,,,,,,,,,,,,, -262733,,,0.9611766338348388,0.148326426744461,0.754859983921051,1.0432186126708984,50000.0,0.6309000253677368,1.8212600946426392,10000.0,88798.86704826355,91957.24380397797,88798.86704826355,3139.123267889023,10.250165700912476,0.0 -262800,4.7866583,0.6719828,,,,,,,,,,,,,, -262900,4.5991335,0.54778254,,,,,,,,,,,,,, -263000,4.2638855,0.6350527,,,,,,,,,,,,,, -263100,4.58644,0.610149,,,,,,,,,,,,,, -263200,4.3205256,0.63826865,,,,,,,,,,,,,, -263300,4.6905212,0.69065285,,,,,,,,,,,,,, -263400,4.8419724,0.59345144,,,,,,,,,,,,,, -263500,4.632596,0.6368588,,,,,,,,,,,,,, -263600,4.688415,0.6016014,,,,,,,,,,,,,, -263700,4.69195,0.5977906,,,,,,,,,,,,,, -263800,4.724606,0.6566411,,,,,,,,,,,,,, -263900,4.736965,0.6169211,,,,,,,,,,,,,, -264000,4.268974,0.6308239,,,,,,,,,,,,,, -264100,4.3300385,0.55425537,,,,,,,,,,,,,, -264200,4.326065,0.62138253,,,,,,,,,,,,,, -264243,,,0.9611965417861938,0.1453111469745636,0.7555800080299377,1.0443156957626345,50000.0,0.6310000419616699,1.821357488632202,10000.0,89308.83620667458,92484.2161114216,89308.83620667458,3156.001371860504,10.322932720184326,0.0 -264300,4.24602,0.5924264,,,,,,,,,,,,,, -264400,4.4940977,0.5315993,,,,,,,,,,,,,, -264500,4.1463118,0.57566935,,,,,,,,,,,,,, -264600,4.6032166,0.6080479,,,,,,,,,,,,,, -264700,4.6225533,0.6005381,,,,,,,,,,,,,, -264800,4.353854,0.62282705,,,,,,,,,,,,,, -264900,4.51063,0.6758257,,,,,,,,,,,,,, -265000,4.684776,0.70997953,,,,,,,,,,,,,, -265100,4.4613056,0.637007,,,,,,,,,,,,,, -265200,4.784561,0.57101434,,,,,,,,,,,,,, -265300,5.0891895,0.5980146,,,,,,,,,,,,,, -265400,6.2894635,0.6138452,,,,,,,,,,,,,, -265500,4.5360613,0.59006107,,,,,,,,,,,,,, -265600,4.8781567,0.65811574,,,,,,,,,,,,,, -265700,4.4237714,0.62530977,,,,,,,,,,,,,, -265754,,,0.9607979655265808,0.1484108865261078,0.7550999522209167,1.0444883108139038,50000.0,0.6318000555038452,1.8218551874160769,10000.0,89818.98686289787,93011.54561305046,89818.98686289787,3173.056991100312,10.393508911132812,0.0 -265800,4.4195175,0.62006104,,,,,,,,,,,,,, -265900,4.5926785,0.5473751,,,,,,,,,,,,,, -266000,4.338378,0.6048902,,,,,,,,,,,,,, -266100,4.9619074,0.62667716,,,,,,,,,,,,,, -266200,4.594506,0.62024534,,,,,,,,,,,,,, -266300,4.3975034,0.639856,,,,,,,,,,,,,, -266400,4.8380246,0.64082336,,,,,,,,,,,,,, -266500,4.451171,0.60494477,,,,,,,,,,,,,, -266600,5.02929,0.6302134,,,,,,,,,,,,,, -266700,4.3375144,0.6328545,,,,,,,,,,,,,, -266800,5.162942,0.6976415,,,,,,,,,,,,,, -266900,4.8406463,0.70392656,,,,,,,,,,,,,, -267000,4.4870872,0.6666038,,,,,,,,,,,,,, -267100,4.3634753,0.61375266,,,,,,,,,,,,,, -267200,5.011578,0.7323617,,,,,,,,,,,,,, -267265,,,0.9601004123687744,0.14792300760746,0.7553600072860718,1.0435808897018433,50000.0,0.6308000087738037,1.82218599319458,10000.0,90328.98372650146,93538.75735902786,90328.98372650146,3190.143748044968,10.469247817993164,0.0 -267300,4.4276576,0.65291107,,,,,,,,,,,,,, -267400,4.3674917,0.653478,,,,,,,,,,,,,, -267500,4.509684,0.6148432,,,,,,,,,,,,,, -267600,4.3920856,0.5503044,,,,,,,,,,,,,, -267700,4.2694874,0.5595005,,,,,,,,,,,,,, -267800,4.876256,0.64562804,,,,,,,,,,,,,, -267900,4.5872564,0.6536978,,,,,,,,,,,,,, -268000,4.443204,0.6100919,,,,,,,,,,,,,, -268100,4.344698,0.6326905,,,,,,,,,,,,,, -268200,3.9438062,0.6334216,,,,,,,,,,,,,, -268300,4.4540725,0.6453387,,,,,,,,,,,,,, -268400,4.403232,0.5931512,,,,,,,,,,,,,, -268500,4.596098,0.6413357,,,,,,,,,,,,,, -268600,4.310552,0.65763485,,,,,,,,,,,,,, -268700,4.6275635,0.6748338,,,,,,,,,,,,,, -268775,,,0.9599011540412904,0.1491859406232834,0.754859983921051,1.0436774492263794,50000.0,0.6313000321388245,1.82094955444336,10000.0,90839.05476880074,94066.76383280754,90839.05476880074,3207.9533960819244,10.54270601272583,0.0 -268800,5.0277123,0.6365391,,,,,,,,,,,,,, -268900,4.447708,0.61489606,,,,,,,,,,,,,, -269000,4.456033,0.5964438,,,,,,,,,,,,,, -269100,4.8712144,0.62809724,,,,,,,,,,,,,, -269200,4.201025,0.5417392,,,,,,,,,,,,,, -269300,4.5417852,0.64626175,,,,,,,,,,,,,, -269400,4.2092257,0.6039915,,,,,,,,,,,,,, -269500,4.4636736,0.61757845,,,,,,,,,,,,,, -269600,5.022908,0.61880857,,,,,,,,,,,,,, -269700,4.458371,0.66222227,,,,,,,,,,,,,, -269800,4.5060763,0.5794448,,,,,,,,,,,,,, -269900,5.323029,0.6076105,,,,,,,,,,,,,, -270000,4.7973514,0.64659584,,,,,,,,,,,,,, -270100,4.491851,0.624774,,,,,,,,,,,,,, -270200,4.5226817,0.61933684,,,,,,,,,,,,,, -270285,,,0.9602997303009032,0.1504989713430404,0.7555599808692932,1.0429737567901611,50000.0,0.6309000253677368,1.821266889572144,10000.0,91349.01457190514,94594.31363272668,91349.01457190514,3225.401467323303,10.631753206253052,0.0 -270300,4.5966415,0.6035567,,,,,,,,,,,,,, -270400,4.0203114,0.5509906,,,,,,,,,,,,,, -270500,4.711062,0.67232233,,,,,,,,,,,,,, -270600,4.754847,0.65078485,,,,,,,,,,,,,, -270700,4.3998284,0.61663693,,,,,,,,,,,,,, -270800,4.3831472,0.6370606,,,,,,,,,,,,,, -270900,4.3208404,0.64322174,,,,,,,,,,,,,, -271000,4.6979575,0.6632936,,,,,,,,,,,,,, -271100,4.7304482,0.6638094,,,,,,,,,,,,,, -271200,4.391735,0.6069833,,,,,,,,,,,,,, -271300,4.5267944,0.61941105,,,,,,,,,,,,,, -271400,4.3612285,0.60239625,,,,,,,,,,,,,, -271500,4.963251,0.60449153,,,,,,,,,,,,,, -271600,4.5665274,0.574045,,,,,,,,,,,,,, -271700,4.6685762,0.582539,,,,,,,,,,,,,, -271796,,,0.961535394191742,0.1452290415763855,0.7554000020027161,1.0430622100830078,50000.0,0.6319000124931335,1.819881677627564,10000.0,91859.1024298668,95121.58063220978,91859.1024298668,3242.464148521424,10.695482969284058,0.0 -271800,4.5242014,0.6428344,,,,,,,,,,,,,, -271900,4.79843,0.63181573,,,,,,,,,,,,,, -272000,4.3258414,0.62915325,,,,,,,,,,,,,, -272100,4.2367435,0.5938864,,,,,,,,,,,,,, -272200,4.502447,0.6293966,,,,,,,,,,,,,, -272300,4.472181,0.5908481,,,,,,,,,,,,,, -272400,4.2851624,0.5987314,,,,,,,,,,,,,, -272500,4.5047297,0.61763036,,,,,,,,,,,,,, -272600,4.249999,0.5521837,,,,,,,,,,,,,, -272700,4.5801935,0.606021,,,,,,,,,,,,,, -272800,4.4291997,0.55086005,,,,,,,,,,,,,, -272900,4.607282,0.6361484,,,,,,,,,,,,,, -273000,4.1801066,0.62351346,,,,,,,,,,,,,, -273100,4.337417,0.6508325,,,,,,,,,,,,,, -273200,4.3754787,0.6134895,,,,,,,,,,,,,, -273300,4.437032,0.60527366,,,,,,,,,,,,,, -273306,,,0.9596819281578064,0.1488217264413833,0.7553399801254272,1.0445902347564695,50000.0,0.6305000185966492,1.8220733404159544,10000.0,92369.1267938614,95648.849401474,92369.1267938614,3259.586245775223,10.765870571136476,0.0 -273400,4.356481,0.628586,,,,,,,,,,,,,, -273500,4.6174445,0.64934295,,,,,,,,,,,,,, -273600,4.572804,0.6845708,,,,,,,,,,,,,, -273700,4.754448,0.60359913,,,,,,,,,,,,,, -273800,4.701701,0.65230095,,,,,,,,,,,,,, -273900,4.3617406,0.66273564,,,,,,,,,,,,,, -274000,4.2598114,0.55782694,,,,,,,,,,,,,, -274100,4.382783,0.6324563,,,,,,,,,,,,,, -274200,4.8340926,0.68368036,,,,,,,,,,,,,, -274300,4.276106,0.6225362,,,,,,,,,,,,,, -274400,5.1125894,0.66278094,,,,,,,,,,,,,, -274500,4.3380804,0.6336533,,,,,,,,,,,,,, -274600,4.50407,0.6695016,,,,,,,,,,,,,, -274700,4.607499,0.59657437,,,,,,,,,,,,,, -274800,4.294669,0.58173174,,,,,,,,,,,,,, -274816,,,0.9607979655265808,0.1465967744588852,0.7552799582481384,1.0439108610153198,50000.0,0.6303000450134277,1.8224775791168213,10000.0,92879.02329158784,96176.17948126791,92879.02329158784,3276.891673803329,10.8425931930542,0.0 -274900,4.368857,0.6347215,,,,,,,,,,,,,, -275000,4.6209326,0.6179552,,,,,,,,,,,,,, -275100,4.431926,0.62505007,,,,,,,,,,,,,, -275200,4.743752,0.6566331,,,,,,,,,,,,,, -275300,4.7000732,0.6110027,,,,,,,,,,,,,, -275400,4.2921934,0.58359885,,,,,,,,,,,,,, -275500,4.26723,0.69067574,,,,,,,,,,,,,, -275600,4.3455615,0.6368203,,,,,,,,,,,,,, -275700,4.3585787,0.6344709,,,,,,,,,,,,,, -275800,4.4794044,0.667494,,,,,,,,,,,,,, -275900,4.772488,0.71043766,,,,,,,,,,,,,, -276000,4.6099787,0.63945377,,,,,,,,,,,,,, -276100,4.701054,0.70369,,,,,,,,,,,,,, -276200,3.9326396,0.58947265,,,,,,,,,,,,,, -276300,4.253179,0.6225804,,,,,,,,,,,,,, -276327,,,0.9604591727256776,0.1478716731071472,0.7551800012588501,1.044476866722107,50000.0,0.6310000419616699,1.821550726890564,10000.0,93389.19371938704,96703.71585345268,93389.19371938704,3294.133857011795,10.913537740707396,0.0 -276400,4.678078,0.6569172,,,,,,,,,,,,,, -276500,4.30889,0.5800465,,,,,,,,,,,,,, -276600,5.347528,0.6389999,,,,,,,,,,,,,, -276700,4.2275996,0.5759938,,,,,,,,,,,,,, -276800,4.339029,0.6484172,,,,,,,,,,,,,, -276900,4.256869,0.5403878,,,,,,,,,,,,,, -277000,4.541924,0.6489331,,,,,,,,,,,,,, -277100,4.369754,0.6461834,,,,,,,,,,,,,, -277200,4.8897877,0.6230149,,,,,,,,,,,,,, -277300,4.6148777,0.63196063,,,,,,,,,,,,,, -277400,4.4890947,0.63538754,,,,,,,,,,,,,, -277500,4.1953845,0.658809,,,,,,,,,,,,,, -277600,4.6930647,0.6711704,,,,,,,,,,,,,, -277700,4.2913475,0.6208688,,,,,,,,,,,,,, -277800,4.6492076,0.63597476,,,,,,,,,,,,,, -277837,,,0.9602399468421936,0.1496829986572265,0.7553600072860718,1.0444915294647217,50000.0,0.6314000487327576,1.8213553428649905,10000.0,93899.33606386185,97231.03688788414,93899.33606386185,3311.18604016304,10.987645626068115,0.0 -277900,4.645804,0.5651406,,,,,,,,,,,,,, -278000,4.692261,0.6094467,,,,,,,,,,,,,, -278100,4.6732473,0.63829666,,,,,,,,,,,,,, -278200,4.5245743,0.5919908,,,,,,,,,,,,,, -278300,4.4238443,0.5791074,,,,,,,,,,,,,, -278400,4.8085136,0.5958426,,,,,,,,,,,,,, -278500,4.294492,0.62829274,,,,,,,,,,,,,, -278600,4.556628,0.6378604,,,,,,,,,,,,,, -278700,4.514318,0.56285226,,,,,,,,,,,,,, -278800,4.704845,0.63613135,,,,,,,,,,,,,, -278900,4.574945,0.5807196,,,,,,,,,,,,,, -279000,4.36925,0.6754273,,,,,,,,,,,,,, -279100,4.6440463,0.6162103,,,,,,,,,,,,,, -279200,4.259758,0.6303066,,,,,,,,,,,,,, -279300,4.9255037,0.64949286,,,,,,,,,,,,,, -279347,,,0.9615154266357422,0.143006756901741,0.7550999522209167,1.043919563293457,50000.0,0.6308000087738037,1.822886824607849,10000.0,94409.34986424446,97758.04048991203,94409.34986424446,3328.050077676773,11.06198787689209,0.0 -279400,4.443064,0.6259595,,,,,,,,,,,,,, -279500,4.8897877,0.6099243,,,,,,,,,,,,,, -279600,4.5919924,0.59848005,,,,,,,,,,,,,, -279700,4.905027,0.5919877,,,,,,,,,,,,,, -279800,4.927436,0.6523677,,,,,,,,,,,,,, -279900,4.790815,0.61078537,,,,,,,,,,,,,, -280000,4.287532,0.58980155,,,,,,,,,,,,,, -280100,4.7513266,0.6069404,,,,,,,,,,,,,, -280200,4.452676,0.67500216,,,,,,,,,,,,,, -280300,5.110386,0.68695,,,,,,,,,,,,,, -280400,4.458989,0.60074514,,,,,,,,,,,,,, -280500,4.5603075,0.67376786,,,,,,,,,,,,,, -280600,4.770156,0.72904986,,,,,,,,,,,,,, -280700,4.586218,0.5918011,,,,,,,,,,,,,, -280800,4.5366054,0.6329157,,,,,,,,,,,,,, -280858,,,0.960758090019226,0.1456627547740936,0.7552599906921387,1.043474197387695,50000.0,0.631100058555603,1.820765733718872,10000.0,94919.4937672615,98285.43342876434,94919.4937672615,3345.172520637512,11.137099504470823,0.0 -280900,4.654646,0.63654304,,,,,,,,,,,,,, -281000,4.6410303,0.6679387,,,,,,,,,,,,,, -281100,4.511983,0.63711923,,,,,,,,,,,,,, -281200,4.073651,0.55596715,,,,,,,,,,,,,, -281300,4.7682767,0.6549618,,,,,,,,,,,,,, -281400,4.4707866,0.64116305,,,,,,,,,,,,,, -281500,4.513979,0.64805907,,,,,,,,,,,,,, -281600,4.3919477,0.69832754,,,,,,,,,,,,,, -281700,4.514187,0.610751,,,,,,,,,,,,,, -281800,4.5100055,0.6369501,,,,,,,,,,,,,, -281900,4.2799587,0.5935175,,,,,,,,,,,,,, -282000,5.0148864,0.6347765,,,,,,,,,,,,,, -282100,5.056244,0.6367855,,,,,,,,,,,,,, -282200,5.1902404,0.6014654,,,,,,,,,,,,,, -282300,4.6705565,0.6688996,,,,,,,,,,,,,, -282369,,,0.9601203799247742,0.1479266285896301,0.7554599642753601,1.0445046424865725,50000.0,0.6309000253677368,1.821438193321228,10000.0,95429.6256942749,98813.02177882196,95429.6256942749,3362.503289937973,11.210871934890749,0.0 -282400,4.2602806,0.5800599,,,,,,,,,,,,,, -282500,4.7339287,0.5916433,,,,,,,,,,,,,, -282600,3.955429,0.5590567,,,,,,,,,,,,,, -282700,4.6364837,0.59859085,,,,,,,,,,,,,, -282800,4.307811,0.6036378,,,,,,,,,,,,,, -282900,4.659474,0.5948951,,,,,,,,,,,,,, -283000,4.541368,0.63842833,,,,,,,,,,,,,, -283100,4.064434,0.62569404,,,,,,,,,,,,,, -283200,4.3220496,0.62852204,,,,,,,,,,,,,, -283300,4.652812,0.66340655,,,,,,,,,,,,,, -283400,4.4365406,0.619874,,,,,,,,,,,,,, -283500,4.086839,0.6007095,,,,,,,,,,,,,, -283600,4.3091326,0.6038901,,,,,,,,,,,,,, -283700,4.0076137,0.52555,,,,,,,,,,,,,, -283800,4.8501477,0.7142228,,,,,,,,,,,,,, -283879,,,0.9611766338348388,0.1473551839590072,0.7550599575042725,1.043495535850525,50000.0,0.6307000517845154,1.821829915046692,10000.0,95939.5230576992,99340.20707345007,95939.5230576992,3379.660735845566,11.28923773765564,0.0 -283900,4.368033,0.6424317,,,,,,,,,,,,,, -284000,4.460932,0.63058233,,,,,,,,,,,,,, -284100,4.3359737,0.54822534,,,,,,,,,,,,,, -284200,4.7023525,0.6464376,,,,,,,,,,,,,, -284300,4.727248,0.63091767,,,,,,,,,,,,,, -284400,4.810106,0.6260092,,,,,,,,,,,,,, -284500,4.4795837,0.63195586,,,,,,,,,,,,,, -284600,4.6072154,0.6821798,,,,,,,,,,,,,, -284700,4.7519317,0.6491096,,,,,,,,,,,,,, -284800,4.2270546,0.5888572,,,,,,,,,,,,,, -284900,4.403739,0.61630726,,,,,,,,,,,,,, -285000,4.8657093,0.68284035,,,,,,,,,,,,,, -285100,4.771719,0.6517507,,,,,,,,,,,,,, -285200,4.1746655,0.6412769,,,,,,,,,,,,,, -285300,4.71671,0.6442226,,,,,,,,,,,,,, -285390,,,0.9603196382522584,0.1493917554616928,0.7549799680709839,1.043468713760376,50000.0,0.6313000321388245,1.821364164352417,10000.0,96449.7061998844,99867.54928565024,96449.7061998844,3396.690092563629,11.36655569076538,0.0 -285400,4.701206,0.6057947,,,,,,,,,,,,,, -285500,4.6957774,0.6050122,,,,,,,,,,,,,, -285600,4.7722397,0.6589663,,,,,,,,,,,,,, -285700,4.7165318,0.61842084,,,,,,,,,,,,,, -285800,4.4910164,0.6779776,,,,,,,,,,,,,, -285900,4.2298875,0.6601521,,,,,,,,,,,,,, -286000,4.31646,0.5933293,,,,,,,,,,,,,, -286100,4.5420175,0.6287379,,,,,,,,,,,,,, -286200,4.979269,0.7214104,,,,,,,,,,,,,, -286300,4.3007917,0.644842,,,,,,,,,,,,,, -286400,4.9464855,0.65647066,,,,,,,,,,,,,, -286500,4.2236776,0.57549244,,,,,,,,,,,,,, -286600,4.2731395,0.6157265,,,,,,,,,,,,,, -286700,4.230634,0.55515826,,,,,,,,,,,,,, -286800,4.750002,0.60278547,,,,,,,,,,,,,, -286900,,,0.960598647594452,0.1475573778152465,0.7552199959754944,1.0439929962158203,50000.0,0.6312000155448914,1.820935606956482,10000.0,96959.79819345474,100394.79163074492,96959.79819345474,3413.703411340713,11.450378656387327,0.0 -286900,4.7196865,0.62904376,,,,,,,,,,,,,, -287000,4.6271815,0.6308833,,,,,,,,,,,,,, -287100,4.5168915,0.5939974,,,,,,,,,,,,,, -287200,4.274172,0.63507754,,,,,,,,,,,,,, -287300,4.3239374,0.5462396,,,,,,,,,,,,,, -287400,4.9360523,0.6814774,,,,,,,,,,,,,, -287500,3.9951305,0.5750646,,,,,,,,,,,,,, -287600,4.4853816,0.5630109,,,,,,,,,,,,,, -287700,4.4870787,0.61878526,,,,,,,,,,,,,, -287800,4.3830037,0.66761094,,,,,,,,,,,,,, -287900,4.8720703,0.62388647,,,,,,,,,,,,,, -288000,4.702601,0.6448717,,,,,,,,,,,,,, -288100,4.8500066,0.5987773,,,,,,,,,,,,,, -288200,4.7264614,0.6738346,,,,,,,,,,,,,, -288300,5.1050777,0.6394872,,,,,,,,,,,,,, -288400,4.770203,0.6530274,,,,,,,,,,,,,, -288411,,,0.9584462642669678,0.1529013216495514,0.7551999688148499,1.043368577957153,50000.0,0.6308000087738037,1.8218662738800049,10000.0,97469.97769403458,100922.34735965727,97469.97769403458,3430.952383518219,11.525851726531982,0.0 -288500,5.2142706,0.7156652,,,,,,,,,,,,,, -288600,4.4749813,0.58604586,,,,,,,,,,,,,, -288700,4.6389103,0.58265585,,,,,,,,,,,,,, -288800,4.7074904,0.7435928,,,,,,,,,,,,,, -288900,4.3882513,0.5923761,,,,,,,,,,,,,, -289000,4.4270735,0.61679834,,,,,,,,,,,,,, -289100,4.876461,0.65944654,,,,,,,,,,,,,, -289200,4.2952323,0.6119801,,,,,,,,,,,,,, -289300,4.7298727,0.6616418,,,,,,,,,,,,,, -289400,4.4104667,0.63969415,,,,,,,,,,,,,, -289500,4.399671,0.65213853,,,,,,,,,,,,,, -289600,4.475826,0.69408846,,,,,,,,,,,,,, -289700,4.3651876,0.63380325,,,,,,,,,,,,,, -289800,4.6368027,0.6624202,,,,,,,,,,,,,, -289900,4.3349576,0.68499494,,,,,,,,,,,,,, -289922,,,0.9600605964660645,0.1496689617633819,0.7550999522209167,1.0449682474136353,50000.0,0.6312000155448914,1.8234632015228271,10000.0,97980.17452836037,101449.76487326622,97980.17452836037,3448.037952423096,11.608575582504272,0.0 -290000,4.7634125,0.62374145,,,,,,,,,,,,,, -290100,4.6928525,0.6922414,,,,,,,,,,,,,, -290200,4.501153,0.61759615,,,,,,,,,,,,,, -290300,4.5861897,0.6696895,,,,,,,,,,,,,, -290400,4.481532,0.6013799,,,,,,,,,,,,,, -290500,4.431892,0.56613874,,,,,,,,,,,,,, -290600,4.2841554,0.6103016,,,,,,,,,,,,,, -290700,4.7791934,0.68487674,,,,,,,,,,,,,, -290800,4.59019,0.60096717,,,,,,,,,,,,,, -290900,4.5814705,0.65771914,,,,,,,,,,,,,, -291000,4.8998322,0.7244282,,,,,,,,,,,,,, -291100,4.433133,0.6587301,,,,,,,,,,,,,, -291200,4.511776,0.618289,,,,,,,,,,,,,, -291300,4.158167,0.54757255,,,,,,,,,,,,,, -291400,4.358779,0.6183502,,,,,,,,,,,,,, -291433,,,0.9602997303009032,0.1464380621910095,0.7553199529647827,1.044140338897705,50000.0,0.6320000290870667,1.8225699663162231,10000.0,98490.11977744102,101976.9769806862,98490.11977744102,3465.177909612656,11.683691501617432,0.0 -291500,4.625395,0.66338485,,,,,,,,,,,,,, -291600,4.798363,0.64820576,,,,,,,,,,,,,, -291700,4.4527864,0.62344426,,,,,,,,,,,,,, -291800,4.186189,0.5704547,,,,,,,,,,,,,, -291900,4.686159,0.6031387,,,,,,,,,,,,,, -292000,4.2068157,0.57394934,,,,,,,,,,,,,, -292100,4.610393,0.62398136,,,,,,,,,,,,,, -292200,4.0541644,0.5625203,,,,,,,,,,,,,, -292300,4.189758,0.61155987,,,,,,,,,,,,,, -292400,4.7128534,0.6229851,,,,,,,,,,,,,, -292500,4.841168,0.6772698,,,,,,,,,,,,,, -292600,4.439038,0.68258387,,,,,,,,,,,,,, -292700,4.714814,0.6454455,,,,,,,,,,,,,, -292800,4.780301,0.6816914,,,,,,,,,,,,,, -292900,5.0616503,0.6494627,,,,,,,,,,,,,, -292944,,,0.9602598547935486,0.1485782861709594,0.7555999755859375,1.0444979667663574,50000.0,0.6308000087738037,1.821649074554444,10000.0,99000.29478907584,102504.31647467612,99000.29478907584,3482.211163520813,11.76318645477295,0.0 -293000,4.3000317,0.60403603,,,,,,,,,,,,,, -293100,4.5602846,0.62130415,,,,,,,,,,,,,, -293200,4.5035753,0.67643476,,,,,,,,,,,,,, -293300,4.328189,0.58185756,,,,,,,,,,,,,, -293400,4.5142884,0.61594,,,,,,,,,,,,,, -293500,4.423261,0.58215004,,,,,,,,,,,,,, -293600,4.930098,0.6740501,,,,,,,,,,,,,, -293700,4.7668605,0.561153,,,,,,,,,,,,,, -293800,4.3446226,0.6139516,,,,,,,,,,,,,, -293900,4.15386,0.58025426,,,,,,,,,,,,,, -294000,4.677874,0.62234986,,,,,,,,,,,,,, -294100,4.425712,0.59768975,,,,,,,,,,,,,, -294200,4.74155,0.571054,,,,,,,,,,,,,, -294300,4.7105517,0.6743986,,,,,,,,,,,,,, -294400,4.900336,0.8143766,,,,,,,,,,,,,, -294455,,,0.9598014950752258,0.1485868841409683,0.7549799680709839,1.0446248054504397,50000.0,0.6317000389099121,1.821648359298706,10000.0,99510.4796898365,103031.7966837883,99510.4796898365,3499.3769342899323,11.840510606765749,0.0 -294500,4.3185363,0.6403203,,,,,,,,,,,,,, -294600,4.35669,0.6087688,,,,,,,,,,,,,, -294700,4.4538608,0.5941871,,,,,,,,,,,,,, -294800,4.397072,0.6319203,,,,,,,,,,,,,, -294900,4.315347,0.6046406,,,,,,,,,,,,,, -295000,4.425781,0.5752305,,,,,,,,,,,,,, -295100,4.8215246,0.6956695,,,,,,,,,,,,,, -295200,4.675867,0.70588446,,,,,,,,,,,,,, -295300,4.7926855,0.6112734,,,,,,,,,,,,,, -295400,4.9667125,0.6293291,,,,,,,,,,,,,, -295500,4.3813744,0.6787328,,,,,,,,,,,,,, -295600,4.953261,0.62785065,,,,,,,,,,,,,, -295700,4.5565352,0.64099956,,,,,,,,,,,,,, -295800,4.3289714,0.60514843,,,,,,,,,,,,,, -295900,4.344562,0.6230479,,,,,,,,,,,,,, -295965,,,0.9590441584587096,0.1495958268642425,0.7550199627876282,1.043899655342102,50000.0,0.631100058555603,1.821966528892517,10000.0,100020.47172164916,103559.07963442802,100020.47172164916,3516.5390000343323,11.916550636291504,0.0 -296000,4.1624794,0.6009482,,,,,,,,,,,,,, -296100,4.9739037,0.7313245,,,,,,,,,,,,,, -296200,4.23663,0.6821954,,,,,,,,,,,,,, -296300,4.483362,0.598722,,,,,,,,,,,,,, -296400,4.209904,0.62623566,,,,,,,,,,,,,, -296500,4.7818556,0.59770805,,,,,,,,,,,,,, -296600,4.876374,0.71354765,,,,,,,,,,,,,, -296700,4.468125,0.61802936,,,,,,,,,,,,,, -296800,4.0788326,0.5887711,,,,,,,,,,,,,, -296900,4.6803575,0.7547888,,,,,,,,,,,,,, -297000,4.560614,0.6443989,,,,,,,,,,,,,, -297100,5.050934,0.6775876,,,,,,,,,,,,,, -297200,4.2616014,0.61481214,,,,,,,,,,,,,, -297300,4.272287,0.5863554,,,,,,,,,,,,,, -297400,4.195932,0.61937,,,,,,,,,,,,,, -297475,,,0.961734652519226,0.1469607055187225,0.7551800012588501,1.0441290140151978,50000.0,0.6300000548362732,1.8233592510223389,10000.0,100530.40967082976,104086.43240571022,100530.40967082976,3533.82601761818,11.992268562316896,0.0 -297500,4.4176893,0.5760897,,,,,,,,,,,,,, -297600,4.711869,0.648707,,,,,,,,,,,,,, -297700,3.977405,0.55775523,,,,,,,,,,,,,, -297800,4.453221,0.6228595,,,,,,,,,,,,,, -297900,4.591533,0.6998664,,,,,,,,,,,,,, -298000,4.2294936,0.60144603,,,,,,,,,,,,,, -298100,4.650911,0.68861014,,,,,,,,,,,,,, -298200,4.9930205,0.7306224,,,,,,,,,,,,,, -298300,4.4872127,0.6155093,,,,,,,,,,,,,, -298400,5.4825644,0.647515,,,,,,,,,,,,,, -298500,4.482762,0.6394434,,,,,,,,,,,,,, -298600,4.796337,0.6252778,,,,,,,,,,,,,, -298700,4.5705237,0.69418406,,,,,,,,,,,,,, -298800,4.3791857,0.5460757,,,,,,,,,,,,,, -298900,5.4357305,0.61031103,,,,,,,,,,,,,, -298986,,,0.9597018361091614,0.1497321426868438,0.7550599575042725,1.0440407991409302,50000.0,0.631100058555603,1.822619795799256,10000.0,101040.5618698597,104614.27379345894,101040.5618698597,3551.3767199516296,12.078696966171265,0.0 -299000,4.238697,0.6276698,,,,,,,,,,,,,, -299100,4.453745,0.65591985,,,,,,,,,,,,,, -299200,4.4568305,0.6517109,,,,,,,,,,,,,, -299300,4.244901,0.58929014,,,,,,,,,,,,,, -299400,4.7292466,0.5759352,,,,,,,,,,,,,, -299500,4.2583265,0.58438796,,,,,,,,,,,,,, -299600,4.6566367,0.6439345,,,,,,,,,,,,,, -299700,4.812745,0.67175466,,,,,,,,,,,,,, -299800,4.8043566,0.638874,,,,,,,,,,,,,, -299900,4.296486,0.5762098,,,,,,,,,,,,,, -300000,4.913301,0.5979923,,,,,,,,,,,,,, -300100,4.4821115,0.63693404,,,,,,,,,,,,,, -300200,4.093786,0.6097462,,,,,,,,,,,,,, -300300,4.5980625,0.6526705,,,,,,,,,,,,,, -300400,4.637873,0.60066384,,,,,,,,,,,,,, -300496,,,0.9599011540412904,0.1493127793073654,0.7553399801254272,1.0435774326324463,50000.0,0.6314000487327576,1.821979522705078,10000.0,101550.59058713912,105141.25382041933,101550.59058713912,3568.2104346752167,12.144399642944336,0.0 -300500,4.905632,0.6840572,,,,,,,,,,,,,, -300600,4.619848,0.60219455,,,,,,,,,,,,,, -300700,4.4925838,0.6398193,,,,,,,,,,,,,, -300800,5.042421,0.6542162,,,,,,,,,,,,,, -300900,4.3816752,0.63175094,,,,,,,,,,,,,, -301000,4.3668222,0.5968246,,,,,,,,,,,,,, -301100,4.4052067,0.6566689,,,,,,,,,,,,,, -301200,4.732354,0.69596016,,,,,,,,,,,,,, -301300,4.21673,0.597668,,,,,,,,,,,,,, -301400,4.487298,0.66430753,,,,,,,,,,,,,, -301500,4.4506435,0.6572578,,,,,,,,,,,,,, -301600,4.259365,0.5760091,,,,,,,,,,,,,, -301700,4.678336,0.5964546,,,,,,,,,,,,,, -301800,4.6037817,0.6475205,,,,,,,,,,,,,, -301900,4.2157545,0.59229493,,,,,,,,,,,,,, -302000,4.5352407,0.55542976,,,,,,,,,,,,,, -302007,,,0.960758090019226,0.146996721625328,0.7557199597358704,1.043437123298645,50000.0,0.631100058555603,1.821194648742676,10000.0,102060.7697827816,105668.68538475037,102060.7697827816,3585.3242876529694,12.231438398361206,0.0 -302100,4.572152,0.64357096,,,,,,,,,,,,,, -302200,4.8324275,0.6617567,,,,,,,,,,,,,, -302300,4.533871,0.663981,,,,,,,,,,,,,, -302400,4.436307,0.58786726,,,,,,,,,,,,,, -302500,5.031862,0.6672423,,,,,,,,,,,,,, -302600,4.486847,0.6002369,,,,,,,,,,,,,, -302700,4.577111,0.6821781,,,,,,,,,,,,,, -302800,4.698533,0.6989126,,,,,,,,,,,,,, -302900,4.270834,0.58791494,,,,,,,,,,,,,, -303000,4.71042,0.6041156,,,,,,,,,,,,,, -303100,4.407433,0.6307641,,,,,,,,,,,,,, -303200,4.243944,0.5686755,,,,,,,,,,,,,, -303300,4.5714817,0.597384,,,,,,,,,,,,,, -303400,4.369348,0.6007416,,,,,,,,,,,,,, -303500,4.5967646,0.57788235,,,,,,,,,,,,,, -303517,,,0.961316168308258,0.1453872919082641,0.7549600005149841,1.044714331626892,50000.0,0.6321000456809998,1.822009444236756,10000.0,102570.90319132803,106196.04564929008,102570.90319132803,3602.418655157089,12.31143856048584,0.0 -303600,4.393567,0.6582232,,,,,,,,,,,,,, -303700,4.3803225,0.6198113,,,,,,,,,,,,,, -303800,4.504687,0.6108597,,,,,,,,,,,,,, -303900,4.2523527,0.5794461,,,,,,,,,,,,,, -304000,4.436913,0.553863,,,,,,,,,,,,,, -304100,5.297815,0.7489712,,,,,,,,,,,,,, -304200,4.8206134,0.7189488,,,,,,,,,,,,,, -304300,4.513944,0.6203786,,,,,,,,,,,,,, -304400,4.671884,0.64864814,,,,,,,,,,,,,, -304500,4.507957,0.6686337,,,,,,,,,,,,,, -304600,4.6266384,0.7068755,,,,,,,,,,,,,, -304700,4.10061,0.60686255,,,,,,,,,,,,,, -304800,4.532914,0.6561544,,,,,,,,,,,,,, -304900,4.3509436,0.6478746,,,,,,,,,,,,,, -305000,4.7154436,0.5911071,,,,,,,,,,,,,, -305028,,,0.9611766338348388,0.1474113166332245,0.7551199793815613,1.0436537265777588,50000.0,0.631600022315979,1.821263194084168,10000.0,103081.1084947586,106723.35728740692,103081.1084947586,3619.393835544586,12.389586448669434,0.0 -305100,4.5159545,0.6311032,,,,,,,,,,,,,, -305200,4.421678,0.61747664,,,,,,,,,,,,,, -305300,4.569581,0.5710524,,,,,,,,,,,,,, -305400,4.5297647,0.60937726,,,,,,,,,,,,,, -305500,4.2427197,0.59517694,,,,,,,,,,,,,, -305600,4.172353,0.58417124,,,,,,,,,,,,,, -305700,4.580051,0.6564833,,,,,,,,,,,,,, -305800,4.4697356,0.5693701,,,,,,,,,,,,,, -305900,4.726779,0.60293543,,,,,,,,,,,,,, -306000,4.3799314,0.61068183,,,,,,,,,,,,,, -306100,4.298533,0.603894,,,,,,,,,,,,,, -306200,4.6038766,0.6272681,,,,,,,,,,,,,, -306300,4.5020666,0.60638237,,,,,,,,,,,,,, -306400,4.6301284,0.64641833,,,,,,,,,,,,,, -306500,4.966701,0.5620903,,,,,,,,,,,,,, -306538,,,0.9596021771430968,0.1488519310951233,0.7552799582481384,1.0441625118255615,50000.0,0.6306000351905823,1.8231303691864007,10000.0,103590.98895573616,107251.10673069954,103590.98895573616,3637.1305088996887,12.47010898590088,0.0 -306600,4.5030136,0.57060295,,,,,,,,,,,,,, -306700,4.5725203,0.6167139,,,,,,,,,,,,,, -306800,4.464834,0.61487937,,,,,,,,,,,,,, -306900,4.2309146,0.59101814,,,,,,,,,,,,,, -307000,4.2502074,0.58276844,,,,,,,,,,,,,, -307100,5.0167265,0.6686276,,,,,,,,,,,,,, -307200,4.6744967,0.6520748,,,,,,,,,,,,,, -307300,4.5695305,0.6697713,,,,,,,,,,,,,, -307400,4.5461087,0.6467546,,,,,,,,,,,,,, -307500,4.3830304,0.6018708,,,,,,,,,,,,,, -307600,4.4341674,0.68335843,,,,,,,,,,,,,, -307700,4.0841284,0.572147,,,,,,,,,,,,,, -307800,4.86676,0.7072278,,,,,,,,,,,,,, -307900,4.272607,0.58496153,,,,,,,,,,,,,, -308000,4.895974,0.6634885,,,,,,,,,,,,,, -308048,,,0.9596819281578064,0.1510353088378906,0.7555599808692932,1.044703722000122,50000.0,0.6309000253677368,1.822330355644226,10000.0,104100.93312764168,107778.29575705528,104100.93312764168,3654.240547180176,12.552121877670288,0.0 -308100,4.509024,0.6175163,,,,,,,,,,,,,, -308200,4.340645,0.64101964,,,,,,,,,,,,,, -308300,4.797175,0.6476878,,,,,,,,,,,,,, -308400,4.923481,0.61591953,,,,,,,,,,,,,, -308500,4.6373277,0.73617643,,,,,,,,,,,,,, -308600,4.6224585,0.6171307,,,,,,,,,,,,,, -308700,4.449475,0.73873013,,,,,,,,,,,,,, -308800,4.5798593,0.67862,,,,,,,,,,,,,, -308900,4.35872,0.55060405,,,,,,,,,,,,,, -309000,4.4144135,0.66590184,,,,,,,,,,,,,, -309100,4.359337,0.60198617,,,,,,,,,,,,,, -309200,4.9949813,0.6499024,,,,,,,,,,,,,, -309300,4.777615,0.62477475,,,,,,,,,,,,,, -309400,4.339765,0.59821534,,,,,,,,,,,,,, -309500,4.630165,0.5983224,,,,,,,,,,,,,, -309558,,,0.9604591727256776,0.1492012590169906,0.7558599710464478,1.0423375368118286,50000.0,0.6309000253677368,1.820195198059082,10000.0,104610.91036987305,108305.5235171318,104610.91036987305,3671.359024047852,12.631409883499146,0.0 -309600,4.6567574,0.61515844,,,,,,,,,,,,,, -309700,4.5874476,0.67027026,,,,,,,,,,,,,, -309800,4.655118,0.6402689,,,,,,,,,,,,,, -309900,4.3167095,0.5680984,,,,,,,,,,,,,, -310000,4.5778937,0.68211746,,,,,,,,,,,,,, -310100,4.396378,0.64584005,,,,,,,,,,,,,, -310200,4.6079674,0.68809533,,,,,,,,,,,,,, -310300,4.291113,0.5635373,,,,,,,,,,,,,, -310400,4.663335,0.6572325,,,,,,,,,,,,,, -310500,4.587268,0.6028135,,,,,,,,,,,,,, -310600,4.770692,0.61815274,,,,,,,,,,,,,, -310700,4.8496885,0.60842,,,,,,,,,,,,,, -310800,4.4311366,0.62551045,,,,,,,,,,,,,, -310900,4.511292,0.570835,,,,,,,,,,,,,, -311000,4.4741797,0.6345409,,,,,,,,,,,,,, -311069,,,0.960180163383484,0.1465490460395813,0.7554799914360046,1.0438048839569092,50000.0,0.6317000389099121,1.8213917016983032,10000.0,105120.89776802064,108832.93958759308,105120.89776802064,3688.6541543006897,12.713399410247805,0.0 -311100,4.3142004,0.551234,,,,,,,,,,,,,, -311200,4.808642,0.6596344,,,,,,,,,,,,,, -311300,4.2601194,0.61161375,,,,,,,,,,,,,, -311400,4.5367956,0.6402646,,,,,,,,,,,,,, -311500,4.8444386,0.5994769,,,,,,,,,,,,,, -311600,4.9660606,0.650637,,,,,,,,,,,,,, -311700,4.30894,0.6077099,,,,,,,,,,,,,, -311800,4.7184567,0.6339408,,,,,,,,,,,,,, -311900,4.1796484,0.5969815,,,,,,,,,,,,,, -312000,4.6449485,0.6846937,,,,,,,,,,,,,, -312100,4.9326816,0.73309875,,,,,,,,,,,,,, -312200,4.3355722,0.59758556,,,,,,,,,,,,,, -312300,4.569557,0.59773964,,,,,,,,,,,,,, -312400,4.650357,0.69891083,,,,,,,,,,,,,, -312500,4.206397,0.58437103,,,,,,,,,,,,,, -312579,,,0.960379421710968,0.147471010684967,0.754859983921051,1.043671727180481,50000.0,0.6312000155448914,1.821340084075928,10000.0,105630.90949559212,109360.03984093666,105630.90949559212,3705.587835550308,12.816117763519289,0.0 -312600,6.1501136,0.6104528,,,,,,,,,,,,,, -312700,4.4172606,0.6390087,,,,,,,,,,,,,, -312800,4.625256,0.6686047,,,,,,,,,,,,,, -312900,4.785734,0.63655436,,,,,,,,,,,,,, -313000,4.314822,0.6448614,,,,,,,,,,,,,, -313100,4.7413216,0.6795776,,,,,,,,,,,,,, -313200,4.7225404,0.6120417,,,,,,,,,,,,,, -313300,4.7487016,0.63517845,,,,,,,,,,,,,, -313400,4.3733764,0.63838226,,,,,,,,,,,,,, -313500,4.569233,0.6662125,,,,,,,,,,,,,, -313600,4.214604,0.6129537,,,,,,,,,,,,,, -313700,4.6679373,0.6430191,,,,,,,,,,,,,, -313800,3.8471565,0.51946336,,,,,,,,,,,,,, -313900,5.013877,0.6136082,,,,,,,,,,,,,, -314000,4.1015277,0.6080612,,,,,,,,,,,,,, -314089,,,0.9612165093421936,0.1459231972694397,0.7547199726104736,1.0452815294265747,50000.0,0.6302000284194946,1.8212146759033203,10000.0,106140.7857196331,109886.95943021774,106140.7857196331,3722.494250059128,12.90107798576355,0.0 -314100,4.2712126,0.6089885,,,,,,,,,,,,,, -314200,4.538593,0.63403827,,,,,,,,,,,,,, -314300,4.468463,0.6404916,,,,,,,,,,,,,, -314400,4.4377236,0.63133,,,,,,,,,,,,,, -314500,4.4645677,0.6313496,,,,,,,,,,,,,, -314600,5.372984,0.71321046,,,,,,,,,,,,,, -314700,4.8824825,0.6424591,,,,,,,,,,,,,, -314800,4.3129926,0.6469408,,,,,,,,,,,,,, -314900,4.7546916,0.6774876,,,,,,,,,,,,,, -315000,4.557734,0.6648177,,,,,,,,,,,,,, -315100,4.826072,0.6227453,,,,,,,,,,,,,, -315200,4.5832195,0.6393745,,,,,,,,,,,,,, -315300,4.362477,0.61409605,,,,,,,,,,,,,, -315400,4.1815147,0.57705396,,,,,,,,,,,,,, -315500,4.624536,0.6244816,,,,,,,,,,,,,, -315599,,,0.9605787396430968,0.1492339074611663,0.7549999952316284,1.0437467098236084,50000.0,0.6309000253677368,1.820225715637207,10000.0,106650.83040618896,110413.95022773744,106650.83040618896,3739.307128429413,12.98228931427002,0.0 -315600,4.3053546,0.6285468,,,,,,,,,,,,,, -315700,4.450945,0.631353,,,,,,,,,,,,,, -315800,4.922214,0.58706343,,,,,,,,,,,,,, -315900,4.5151124,0.6806048,,,,,,,,,,,,,, -316000,4.7086253,0.7271057,,,,,,,,,,,,,, -316100,4.381059,0.65666914,,,,,,,,,,,,,, -316200,4.6349416,0.6613351,,,,,,,,,,,,,, -316300,4.4602437,0.60786164,,,,,,,,,,,,,, -316400,4.408993,0.63004345,,,,,,,,,,,,,, -316500,4.258681,0.7018821,,,,,,,,,,,,,, -316600,4.2411466,0.6105279,,,,,,,,,,,,,, -316700,4.823024,0.6534086,,,,,,,,,,,,,, -316800,4.6037207,0.65134835,,,,,,,,,,,,,, -316900,4.697258,0.5605706,,,,,,,,,,,,,, -317000,4.2245846,0.5979113,,,,,,,,,,,,,, -317100,5.2896247,0.71467394,,,,,,,,,,,,,, -317109,,,0.9605189561843872,0.1477705836296081,0.7552399635314941,1.0443214178085327,50000.0,0.6308000087738037,1.821238398551941,10000.0,107160.75046777724,110941.01823163033,107160.75046777724,3756.320683717728,13.06489896774292,0.0 -317200,4.170776,0.6549167,,,,,,,,,,,,,, -317300,4.7507844,0.6277966,,,,,,,,,,,,,, -317400,3.938964,0.51371455,,,,,,,,,,,,,, -317500,4.729806,0.6615238,,,,,,,,,,,,,, -317600,4.476373,0.6592658,,,,,,,,,,,,,, -317700,4.2576284,0.6382866,,,,,,,,,,,,,, -317800,4.3870535,0.62669265,,,,,,,,,,,,,, -317900,4.658796,0.5847657,,,,,,,,,,,,,, -318000,4.437911,0.6011289,,,,,,,,,,,,,, -318100,4.2920218,0.6228321,,,,,,,,,,,,,, -318200,4.651701,0.66404843,,,,,,,,,,,,,, -318300,4.1228104,0.573128,,,,,,,,,,,,,, -318400,4.286095,0.7003746,,,,,,,,,,,,,, -318500,4.7255025,0.6707751,,,,,,,,,,,,,, -318600,4.4003935,0.6163283,,,,,,,,,,,,,, -318619,,,0.9615154266357422,0.1435057967901229,0.7549600005149841,1.0443395376205444,50000.0,0.6307000517845154,1.82354736328125,10000.0,107670.861992836,111468.52405381204,107670.861992836,3773.578710079193,13.148986101150513,0.0 -318700,4.272143,0.555706,,,,,,,,,,,,,, -318800,4.100787,0.590168,,,,,,,,,,,,,, -318900,4.882275,0.6410018,,,,,,,,,,,,,, -319000,4.045515,0.565058,,,,,,,,,,,,,, -319100,4.5475583,0.6523077,,,,,,,,,,,,,, -319200,4.4035735,0.6228948,,,,,,,,,,,,,, -319300,5.024739,0.5925552,,,,,,,,,,,,,, -319400,5.056541,0.63340974,,,,,,,,,,,,,, -319500,4.520084,0.615094,,,,,,,,,,,,,, -319600,4.764039,0.63399804,,,,,,,,,,,,,, -319700,4.1818604,0.5838345,,,,,,,,,,,,,, -319800,4.410555,0.64333475,,,,,,,,,,,,,, -319900,4.234242,0.6177379,,,,,,,,,,,,,, -320000,4.4405355,0.6199501,,,,,,,,,,,,,, -320100,4.348188,0.6193944,,,,,,,,,,,,,, -320129,,,0.960359513759613,0.1465110927820205,0.7546199560165405,1.044081687927246,50000.0,0.6304000020027161,1.821660280227661,10000.0,108180.89000105858,111995.6765794754,108180.89000105858,3790.567802667618,13.232507467269896,0.0 -320200,4.4588447,0.6659399,,,,,,,,,,,,,, -320300,4.2678175,0.5904589,,,,,,,,,,,,,, -320400,4.513592,0.6152521,,,,,,,,,,,,,, -320500,4.3190355,0.6213745,,,,,,,,,,,,,, -320600,4.6072593,0.5815777,,,,,,,,,,,,,, -320700,4.7928853,0.6169122,,,,,,,,,,,,,, -320800,4.1755743,0.5843522,,,,,,,,,,,,,, -320900,4.6442666,0.6951459,,,,,,,,,,,,,, -321000,4.629324,0.6060413,,,,,,,,,,,,,, -321100,4.252908,0.59790015,,,,,,,,,,,,,, -321200,5.3441114,0.70143217,,,,,,,,,,,,,, -321300,4.393693,0.58448356,,,,,,,,,,,,,, -321400,4.7873836,0.6625726,,,,,,,,,,,,,, -321500,4.733441,0.6237834,,,,,,,,,,,,,, -321600,4.3320103,0.58547646,,,,,,,,,,,,,, -321640,,,0.9608178734779358,0.1478320807218551,0.7554199695587158,1.0437196493148804,50000.0,0.6313000321388245,1.8210450410842896,10000.0,108690.9446735382,112523.06674385072,108690.9446735382,3807.768210887909,13.314942359924316,0.0 -321700,4.728158,0.6860733,,,,,,,,,,,,,, -321800,4.5717983,0.6310719,,,,,,,,,,,,,, -321900,4.3199034,0.58411473,,,,,,,,,,,,,, -322000,4.2921877,0.60486114,,,,,,,,,,,,,, -322100,4.211375,0.61194193,,,,,,,,,,,,,, -322200,4.3869166,0.61951214,,,,,,,,,,,,,, -322300,4.629049,0.6550534,,,,,,,,,,,,,, -322400,4.3766246,0.6857778,,,,,,,,,,,,,, -322500,4.8220615,0.6810693,,,,,,,,,,,,,, -322600,4.5824313,0.6408514,,,,,,,,,,,,,, -322700,4.7426963,0.68916845,,,,,,,,,,,,,, -322800,4.2549944,0.5622139,,,,,,,,,,,,,, -322900,4.652383,0.6625373,,,,,,,,,,,,,, -323000,4.690531,0.66664886,,,,,,,,,,,,,, -323100,4.9041276,0.69985425,,,,,,,,,,,,,, -323150,,,0.960359513759613,0.1503463238477707,0.7552399635314941,1.04445481300354,50000.0,0.6308000087738037,1.8226555585861208,10000.0,109200.85324692726,113050.1400270462,109200.85324692726,3824.800094604492,13.394735336303713,0.0 -323200,4.836089,0.6248186,,,,,,,,,,,,,, -323300,4.6364717,0.62768143,,,,,,,,,,,,,, -323400,4.060523,0.61081517,,,,,,,,,,,,,, -323500,4.844454,0.6863073,,,,,,,,,,,,,, -323600,4.520371,0.6199918,,,,,,,,,,,,,, -323700,4.45539,0.65937877,,,,,,,,,,,,,, -323800,4.516538,0.63493896,,,,,,,,,,,,,, -323900,4.0808725,0.5405084,,,,,,,,,,,,,, -324000,4.7583156,0.6450133,,,,,,,,,,,,,, -324100,5.0950446,0.65335304,,,,,,,,,,,,,, -324200,4.160568,0.54858994,,,,,,,,,,,,,, -324300,4.251706,0.6428449,,,,,,,,,,,,,, -324400,4.5296564,0.58912355,,,,,,,,,,,,,, -324500,5.039533,0.6452639,,,,,,,,,,,,,, -324600,4.270864,0.66179967,,,,,,,,,,,,,, -324660,,,0.9606186151504515,0.1466918587684631,0.7547999620437622,1.0439778566360474,50000.0,0.631100058555603,1.8214871883392327,10000.0,109710.8930542469,113577.55630922318,109710.8930542469,3842.041315317154,13.47830367088318,0.0 -324700,5.18191,0.5824938,,,,,,,,,,,,,, -324800,4.261688,0.59696686,,,,,,,,,,,,,, -324900,4.925478,0.6770673,,,,,,,,,,,,,, -325000,4.7309427,0.68038607,,,,,,,,,,,,,, -325100,4.105464,0.61918163,,,,,,,,,,,,,, -325200,4.466709,0.574939,,,,,,,,,,,,,, -325300,4.5306735,0.6051161,,,,,,,,,,,,,, -325400,4.369443,0.64926267,,,,,,,,,,,,,, -325500,4.256743,0.6262712,,,,,,,,,,,,,, -325600,4.564369,0.6198607,,,,,,,,,,,,,, -325700,5.0670104,0.6472359,,,,,,,,,,,,,, -325800,4.684453,0.6534741,,,,,,,,,,,,,, -325900,4.633229,0.67459846,,,,,,,,,,,,,, -326000,4.5550084,0.5640547,,,,,,,,,,,,,, -326100,4.355063,0.6238819,,,,,,,,,,,,,, -326171,,,0.9599409699440002,0.1505787074565887,0.755840003490448,1.0444331169128418,50000.0,0.6309000253677368,1.8210031986236568,10000.0,110220.93134093285,114104.812646389,110220.93134093285,3859.1264731884,13.55803608894348,0.0 -326200,4.5509396,0.6137833,,,,,,,,,,,,,, -326300,4.3886924,0.62478334,,,,,,,,,,,,,, -326400,4.350261,0.59338796,,,,,,,,,,,,,, -326500,4.4622183,0.64790326,,,,,,,,,,,,,, -326600,4.2052717,0.5463145,,,,,,,,,,,,,, -326700,4.5482016,0.6306405,,,,,,,,,,,,,, -326800,4.3551283,0.648775,,,,,,,,,,,,,, -326900,4.456793,0.6244011,,,,,,,,,,,,,, -327000,4.0234942,0.51747084,,,,,,,,,,,,,, -327100,4.1586294,0.50346905,,,,,,,,,,,,,, -327200,4.114695,0.5622043,,,,,,,,,,,,,, -327300,4.145337,0.58590704,,,,,,,,,,,,,, -327400,4.2824845,0.5738388,,,,,,,,,,,,,, -327500,4.252728,0.57117563,,,,,,,,,,,,,, -327600,4.5163527,0.62803406,,,,,,,,,,,,,, -327682,,,0.9587252736091614,0.1513209640979766,0.7554799914360046,1.0429681539535522,50000.0,0.631600022315979,1.820167899131775,10000.0,110731.02626633644,114632.11230945589,110731.02626633644,3876.194142580032,13.642706871032717,0.0 -327700,4.65553,0.60876954,,,,,,,,,,,,,, -327800,5.1492233,0.7004529,,,,,,,,,,,,,, -327900,4.3140345,0.51841074,,,,,,,,,,,,,, -328000,4.2366123,0.6362867,,,,,,,,,,,,,, -328100,4.341503,0.64325494,,,,,,,,,,,,,, -328200,4.6208115,0.60713494,,,,,,,,,,,,,, -328300,4.6330037,0.6431429,,,,,,,,,,,,,, -328400,4.4870157,0.6347976,,,,,,,,,,,,,, -328500,4.344413,0.63406134,,,,,,,,,,,,,, -328600,4.4021344,0.6172145,,,,,,,,,,,,,, -328700,5.0015593,0.7359689,,,,,,,,,,,,,, -328800,4.263508,0.60012895,,,,,,,,,,,,,, -328900,4.499646,0.62001944,,,,,,,,,,,,,, -329000,4.6281414,0.6615053,,,,,,,,,,,,,, -329100,4.828665,0.60569316,,,,,,,,,,,,,, -329191,,,0.9612364172935486,0.1466499269008636,0.7554399967193604,1.044033765792847,50000.0,0.6309000253677368,1.822582721710205,10000.0,111240.5034327507,115159.3679318428,111240.5034327507,3893.1527485847473,14.410885095596312,0.0 -329200,4.1955724,0.5664949,,,,,,,,,,,,,, -329300,4.610326,0.63641167,,,,,,,,,,,,,, -329400,4.460232,0.7063536,,,,,,,,,,,,,, -329500,4.4763,0.630289,,,,,,,,,,,,,, -329600,4.7768893,0.6064066,,,,,,,,,,,,,, -329700,4.4608736,0.68684554,,,,,,,,,,,,,, -329800,4.7561746,0.5883056,,,,,,,,,,,,,, -329900,4.5956326,0.61226916,,,,,,,,,,,,,, -330000,4.7792645,0.6287632,,,,,,,,,,,,,, -330100,4.400588,0.6282403,,,,,,,,,,,,,, -330200,4.2335024,0.57173467,,,,,,,,,,,,,, -330300,4.93918,0.6226042,,,,,,,,,,,,,, -330400,4.1278048,0.5585098,,,,,,,,,,,,,, -330500,4.3014603,0.62454045,,,,,,,,,,,,,, -330600,4.734514,0.64194286,,,,,,,,,,,,,, -330700,4.2630777,0.65097606,,,,,,,,,,,,,, -330702,,,0.9597018361091614,0.1489305049180984,0.7549200057983398,1.044543743133545,50000.0,0.6299000382423401,1.8222780227661133,10000.0,111750.51376104356,115686.34260249138,111750.51376104356,3909.984509468079,14.49191951751709,0.0 -330800,4.877652,0.6921866,,,,,,,,,,,,,, -330900,4.8068514,0.63885665,,,,,,,,,,,,,, -331000,4.945281,0.6965366,,,,,,,,,,,,,, -331100,4.1116266,0.5450519,,,,,,,,,,,,,, -331200,4.8288245,0.6455841,,,,,,,,,,,,,, -331300,4.271689,0.6040521,,,,,,,,,,,,,, -331400,4.3940783,0.57327664,,,,,,,,,,,,,, -331500,4.326362,0.5870534,,,,,,,,,,,,,, -331600,4.9645658,0.66542614,,,,,,,,,,,,,, -331700,4.2174497,0.66320795,,,,,,,,,,,,,, -331800,4.3197255,0.6022344,,,,,,,,,,,,,, -331900,4.0442815,0.63706,,,,,,,,,,,,,, -332000,4.578792,0.6467252,,,,,,,,,,,,,, -332100,4.265368,0.67112905,,,,,,,,,,,,,, -332200,4.5045676,0.6552951,,,,,,,,,,,,,, -332213,,,0.9599609375,0.1484025716781616,0.7553600072860718,1.0446621179580688,50000.0,0.6320000290870667,1.8226637840271,10000.0,112260.58879756927,116213.58219194412,112260.58879756927,3927.011519670488,14.57741928100586,0.0 -332300,4.2263,0.6305565,,,,,,,,,,,,,, -332400,4.5479703,0.63505566,,,,,,,,,,,,,, -332500,4.535327,0.62112314,,,,,,,,,,,,,, -332600,4.4125876,0.57486707,,,,,,,,,,,,,, -332700,4.6691704,0.67329156,,,,,,,,,,,,,, -332800,4.628381,0.6441044,,,,,,,,,,,,,, -332900,4.4263177,0.5950648,,,,,,,,,,,,,, -333000,4.495971,0.665187,,,,,,,,,,,,,, -333100,4.2938337,0.632838,,,,,,,,,,,,,, -333200,4.799259,0.5959751,,,,,,,,,,,,,, -333300,4.6612453,0.69609106,,,,,,,,,,,,,, -333400,4.593593,0.65196943,,,,,,,,,,,,,, -333500,4.475958,0.6267179,,,,,,,,,,,,,, -333600,4.944255,0.65927607,,,,,,,,,,,,,, -333700,4.9761567,0.63156635,,,,,,,,,,,,,, -333723,,,0.959741711616516,0.1488601267337799,0.7554999589920044,1.0437555313110352,50000.0,0.6308000087738037,1.822697997093201,10000.0,112770.559705019,116740.83248353004,112770.559705019,3944.153341770172,14.66352605819702,0.0 -333800,4.5928183,0.62189275,,,,,,,,,,,,,, -333900,4.2914124,0.52298206,,,,,,,,,,,,,, -334000,4.827385,0.6626053,,,,,,,,,,,,,, -334100,5.0301375,0.6629652,,,,,,,,,,,,,, -334200,4.8425565,0.6238275,,,,,,,,,,,,,, -334300,4.3217273,0.6309262,,,,,,,,,,,,,, -334400,4.231749,0.62411106,,,,,,,,,,,,,, -334500,4.55317,0.6039857,,,,,,,,,,,,,, -334600,4.709236,0.6121796,,,,,,,,,,,,,, -334700,4.1906657,0.5847335,,,,,,,,,,,,,, -334800,4.608965,0.70103216,,,,,,,,,,,,,, -334900,4.8681355,0.61300695,,,,,,,,,,,,,, -335000,4.386235,0.60474503,,,,,,,,,,,,,, -335100,4.589452,0.6542378,,,,,,,,,,,,,, -335200,4.701464,0.71090215,,,,,,,,,,,,,, -335234,,,0.9608976244926452,0.1474035531282425,0.7553199529647827,1.0441795587539673,50000.0,0.6306000351905823,1.8214054107666016,10000.0,113280.64534378052,117268.18897390366,113280.64534378052,3961.2871708869934,14.749180316925049,0.0 -335300,4.5055666,0.64645547,,,,,,,,,,,,,, -335400,5.082211,0.70516264,,,,,,,,,,,,,, -335500,4.257423,0.6300465,,,,,,,,,,,,,, -335600,5.037985,0.65251017,,,,,,,,,,,,,, -335700,4.6513276,0.59515214,,,,,,,,,,,,,, -335800,5.075249,0.74635226,,,,,,,,,,,,,, -335900,4.406404,0.59897494,,,,,,,,,,,,,, -336000,4.6898,0.632119,,,,,,,,,,,,,, -336100,4.5171857,0.6237735,,,,,,,,,,,,,, -336200,4.9540014,0.6849362,,,,,,,,,,,,,, -336300,4.621771,0.6017549,,,,,,,,,,,,,, -336400,5.013174,0.7136148,,,,,,,,,,,,,, -336500,4.728623,0.6295944,,,,,,,,,,,,,, -336600,4.468341,0.64269346,,,,,,,,,,,,,, -336700,4.7450933,0.58966064,,,,,,,,,,,,,, -336744,,,0.9601004123687744,0.1491908580064773,0.7552799582481384,1.0451172590255735,50000.0,0.6309000253677368,1.82236397266388,10000.0,113790.71906757356,117795.58216953278,113790.71906757356,3978.467176914215,14.835916996002195,0.0 -336800,4.459439,0.60720325,,,,,,,,,,,,,, -336900,4.1271205,0.56447166,,,,,,,,,,,,,, -337000,4.1079674,0.58552516,,,,,,,,,,,,,, -337100,4.487074,0.610185,,,,,,,,,,,,,, -337200,4.35584,0.5600277,,,,,,,,,,,,,, -337300,4.723687,0.6360665,,,,,,,,,,,,,, -337400,4.4181323,0.6474556,,,,,,,,,,,,,, -337500,4.3662167,0.6440418,,,,,,,,,,,,,, -337600,4.629616,0.608811,,,,,,,,,,,,,, -337700,4.71268,0.63024074,,,,,,,,,,,,,, -337800,4.375984,0.6212273,,,,,,,,,,,,,, -337900,4.876711,0.7067719,,,,,,,,,,,,,, -338000,4.1271496,0.6257124,,,,,,,,,,,,,, -338100,3.873073,0.5704201,,,,,,,,,,,,,, -338200,4.316396,0.6172783,,,,,,,,,,,,,, -338255,,,0.9598014950752258,0.1497568637132644,0.7551400065422058,1.0436608791351318,50000.0,0.6303000450134277,1.821576833724976,10000.0,114300.66180181503,118322.4094736576,114300.66180181503,3995.216495990753,14.918900728225708,0.0 -338300,4.254495,0.64189756,,,,,,,,,,,,,, -338400,4.9650297,0.6178255,,,,,,,,,,,,,, -338500,4.1566644,0.5540129,,,,,,,,,,,,,, -338600,4.7209206,0.71130484,,,,,,,,,,,,,, -338700,4.1845813,0.56409717,,,,,,,,,,,,,, -338800,4.5874505,0.5732573,,,,,,,,,,,,,, -338900,4.824207,0.6055732,,,,,,,,,,,,,, -339000,4.796587,0.56474495,,,,,,,,,,,,,, -339100,4.971512,0.7310211,,,,,,,,,,,,,, -339200,4.2943087,0.5714293,,,,,,,,,,,,,, -339300,4.6432657,0.68106717,,,,,,,,,,,,,, -339400,4.6029477,0.57286626,,,,,,,,,,,,,, -339500,4.611236,0.61560774,,,,,,,,,,,,,, -339600,4.9444375,0.63068247,,,,,,,,,,,,,, -339700,4.558613,0.6359476,,,,,,,,,,,,,, -339765,,,0.9605189561843872,0.1501489579677581,0.7548399567604065,1.0430506467819214,50000.0,0.6299000382423401,1.82118558883667,10000.0,114810.5943918228,118849.54559993744,114810.5943918228,4012.285629749298,15.001896142959597,0.0 -339800,4.423157,0.6022332,,,,,,,,,,,,,, -339900,4.2769217,0.66990167,,,,,,,,,,,,,, -340000,4.6622934,0.6780727,,,,,,,,,,,,,, -340100,4.1462145,0.6094755,,,,,,,,,,,,,, -340200,4.93329,0.62004256,,,,,,,,,,,,,, -340300,4.8177304,0.7553351,,,,,,,,,,,,,, -340400,4.943122,0.6170428,,,,,,,,,,,,,, -340500,4.623042,0.5874324,,,,,,,,,,,,,, -340600,4.226245,0.6060945,,,,,,,,,,,,,, -340700,5.08736,0.6980356,,,,,,,,,,,,,, -340800,4.4184346,0.6193938,,,,,,,,,,,,,, -340900,4.5824046,0.6226447,,,,,,,,,,,,,, -341000,4.459386,0.5779933,,,,,,,,,,,,,, -341100,4.434814,0.61175513,,,,,,,,,,,,,, -341200,4.4957504,0.63518673,,,,,,,,,,,,,, -341276,,,0.9614955186843872,0.1441345363855362,0.7552599906921387,1.0443507432937622,50000.0,0.6304000020027161,1.8239904642105105,10000.0,115320.70908522606,119376.82467389108,115320.70908522606,4029.317915916443,15.082010746002195,0.0 -341300,4.272789,0.5884125,,,,,,,,,,,,,, -341400,4.7963953,0.6237286,,,,,,,,,,,,,, -341500,4.5447025,0.6303843,,,,,,,,,,,,,, -341600,4.725703,0.6166919,,,,,,,,,,,,,, -341700,4.984755,0.5761827,,,,,,,,,,,,,, -341800,4.647646,0.6751574,,,,,,,,,,,,,, -341900,4.4531837,0.5877343,,,,,,,,,,,,,, -342000,4.6947246,0.6888159,,,,,,,,,,,,,, -342100,3.9659967,0.49472925,,,,,,,,,,,,,, -342200,4.380357,0.5603084,,,,,,,,,,,,,, -342300,4.47288,0.6279193,,,,,,,,,,,,,, -342400,4.420159,0.65655017,,,,,,,,,,,,,, -342500,4.810935,0.6799359,,,,,,,,,,,,,, -342600,4.8821955,0.58613014,,,,,,,,,,,,,, -342700,4.037409,0.60332215,,,,,,,,,,,,,, -342786,,,0.9616549611091614,0.145839437842369,0.754859983921051,1.0440630912780762,50000.0,0.6305000185966492,1.821317672729492,10000.0,115830.62089276314,119903.89597392082,115830.62089276314,4046.286930322647,15.220804691314695,0.0 -342800,4.583791,0.59567416,,,,,,,,,,,,,, -342900,4.5376887,0.57133657,,,,,,,,,,,,,, -343000,4.4068003,0.58070225,,,,,,,,,,,,,, -343100,4.2815776,0.6399083,,,,,,,,,,,,,, -343200,4.8243303,0.6825659,,,,,,,,,,,,,, -343300,4.1772985,0.5995747,,,,,,,,,,,,,, -343400,4.4391427,0.60002923,,,,,,,,,,,,,, -343500,4.839648,0.64250207,,,,,,,,,,,,,, -343600,5.195608,0.65148485,,,,,,,,,,,,,, -343700,4.2413545,0.5912899,,,,,,,,,,,,,, -343800,4.782174,0.6311315,,,,,,,,,,,,,, -343900,4.722396,0.61918575,,,,,,,,,,,,,, -344000,4.4541306,0.61606294,,,,,,,,,,,,,, -344100,4.8802414,0.5747958,,,,,,,,,,,,,, -344200,4.843737,0.6480543,,,,,,,,,,,,,, -344296,,,0.9596021771430968,0.1497314274311065,0.7553199529647827,1.043811321258545,50000.0,0.6310000419616699,1.8192788362503047,10000.0,116340.69873285294,120431.06862139702,116340.69873285294,4063.241968870163,15.308415651321411,0.0 -344300,4.242319,0.6246077,,,,,,,,,,,,,, -344400,4.300746,0.63289225,,,,,,,,,,,,,, -344500,4.2662377,0.5833068,,,,,,,,,,,,,, -344600,4.4942136,0.67131233,,,,,,,,,,,,,, -344700,4.427407,0.6888087,,,,,,,,,,,,,, -344800,4.337468,0.6100854,,,,,,,,,,,,,, -344900,4.3510633,0.650058,,,,,,,,,,,,,, -345000,4.9171414,0.77851886,,,,,,,,,,,,,, -345100,4.144677,0.58690274,,,,,,,,,,,,,, -345200,4.262505,0.59792507,,,,,,,,,,,,,, -345300,4.7371235,0.6365441,,,,,,,,,,,,,, -345400,4.6337876,0.6148484,,,,,,,,,,,,,, -345500,5.3440633,0.5883683,,,,,,,,,,,,,, -345600,4.701932,0.60910356,,,,,,,,,,,,,, -345700,4.268211,0.60581857,,,,,,,,,,,,,, -345800,4.777165,0.5808233,,,,,,,,,,,,,, -345806,,,0.9596420526504515,0.1502621620893478,0.7553399801254272,1.0439318418502808,50000.0,0.631100058555603,1.8209062814712524,10000.0,116850.72356057169,120959.25270080566,116850.72356057169,4081.250019550324,15.40729808807373,0.0 -345900,4.6148534,0.6852627,,,,,,,,,,,,,, -346000,4.9040685,0.64757735,,,,,,,,,,,,,, -346100,4.5874047,0.6014814,,,,,,,,,,,,,, -346200,4.581597,0.66386265,,,,,,,,,,,,,, -346300,5.248886,0.70879817,,,,,,,,,,,,,, -346400,4.4717817,0.65869033,,,,,,,,,,,,,, -346500,4.4033227,0.6086572,,,,,,,,,,,,,, -346600,4.174605,0.6033028,,,,,,,,,,,,,, -346700,4.373903,0.6289367,,,,,,,,,,,,,, -346800,4.3670144,0.60728294,,,,,,,,,,,,,, -346900,4.2979126,0.6043406,,,,,,,,,,,,,, -347000,4.593967,0.63936245,,,,,,,,,,,,,, -347100,4.527046,0.64370924,,,,,,,,,,,,,, -347200,4.673721,0.6140696,,,,,,,,,,,,,, -347300,4.551396,0.6079234,,,,,,,,,,,,,, -347316,,,0.9595025181770324,0.1501846611499786,0.7553199529647827,1.044198513031006,50000.0,0.6319000124931335,1.822487950325012,10000.0,117360.82094693184,121486.52737092972,117360.82094693184,4098.288117408752,15.495001316070557,0.0 -347400,4.4038386,0.6138866,,,,,,,,,,,,,, -347500,4.5281324,0.60503536,,,,,,,,,,,,,, -347600,4.9630804,0.65116954,,,,,,,,,,,,,, -347700,4.436307,0.6941816,,,,,,,,,,,,,, -347800,4.520971,0.6319225,,,,,,,,,,,,,, -347900,4.905495,0.62577623,,,,,,,,,,,,,, -348000,4.063635,0.58582395,,,,,,,,,,,,,, -348100,4.505288,0.6001453,,,,,,,,,,,,,, -348200,4.5993724,0.58635956,,,,,,,,,,,,,, -348300,4.4298615,0.57216763,,,,,,,,,,,,,, -348400,4.838763,0.6592635,,,,,,,,,,,,,, -348500,4.5355954,0.6166656,,,,,,,,,,,,,, -348600,4.334435,0.58765495,,,,,,,,,,,,,, -348700,4.415137,0.63035583,,,,,,,,,,,,,, -348800,4.36604,0.6525249,,,,,,,,,,,,,, -348827,,,0.9618343114852904,0.1459579169750213,0.7551599740982056,1.0425646305084229,50000.0,0.6308000087738037,1.8201829195022583,10000.0,117870.91880130768,122013.91163277626,117870.91880130768,4115.436737775803,15.581097602844238,0.0 -348900,4.52231,0.6680585,,,,,,,,,,,,,, -349000,4.665917,0.6030464,,,,,,,,,,,,,, -349100,4.74714,0.61679834,,,,,,,,,,,,,, -349200,4.2783995,0.51688373,,,,,,,,,,,,,, -349300,4.383641,0.58200586,,,,,,,,,,,,,, -349400,4.3951054,0.58609104,,,,,,,,,,,,,, -349500,4.9116235,0.6404685,,,,,,,,,,,,,, -349600,4.529157,0.59852195,,,,,,,,,,,,,, -349700,4.1594706,0.61860836,,,,,,,,,,,,,, -349800,4.3241987,0.71111953,,,,,,,,,,,,,, -349900,4.6445374,0.6900372,,,,,,,,,,,,,, -350000,4.250851,0.5779382,,,,,,,,,,,,,, -350100,4.4709306,0.61105204,,,,,,,,,,,,,, -350200,4.542246,0.59819806,,,,,,,,,,,,,, -350300,4.3663154,0.62007624,,,,,,,,,,,,,, -350338,,,0.9604790806770324,0.1469176858663559,0.7550199627876282,1.0435841083526611,50000.0,0.6315000057220459,1.821236252784729,10000.0,118380.99407410622,122541.0457561016,118380.99407410622,4132.369649171829,15.654744863510132,0.0 -350400,4.3552375,0.63222396,,,,,,,,,,,,,, -350500,4.916649,0.6669859,,,,,,,,,,,,,, -350600,4.7991486,0.661814,,,,,,,,,,,,,, -350700,4.683614,0.64717543,,,,,,,,,,,,,, -350800,4.5927286,0.638458,,,,,,,,,,,,,, -350900,4.579399,0.66555935,,,,,,,,,,,,,, -351000,4.603587,0.61910594,,,,,,,,,,,,,, -351100,4.2912574,0.6286015,,,,,,,,,,,,,, -351200,4.436674,0.6248709,,,,,,,,,,,,,, -351300,4.737446,0.60246897,,,,,,,,,,,,,, -351400,4.8560996,0.65759325,,,,,,,,,,,,,, -351500,4.307962,0.6410311,,,,,,,,,,,,,, -351600,5.0735435,0.70310533,,,,,,,,,,,,,, -351700,4.359187,0.6086449,,,,,,,,,,,,,, -351800,4.267125,0.6167454,,,,,,,,,,,,,, -351849,,,0.9606584906578064,0.1471874266862869,0.7547799944877625,1.0444236993789673,50000.0,0.6301000118255615,1.8209481239318848,10000.0,118891.1866297722,123068.48770236968,118891.1866297722,4149.481686353684,15.740437507629396,0.0 -351900,4.4612484,0.6727486,,,,,,,,,,,,,, -352000,4.3669395,0.6244434,,,,,,,,,,,,,, -352100,3.8358657,0.5223379,,,,,,,,,,,,,, -352200,4.674436,0.67356014,,,,,,,,,,,,,, -352300,4.785098,0.6551326,,,,,,,,,,,,,, -352400,4.4894514,0.54079384,,,,,,,,,,,,,, -352500,4.6137767,0.6167978,,,,,,,,,,,,,, -352600,4.7921243,0.646137,,,,,,,,,,,,,, -352700,4.1894784,0.57804346,,,,,,,,,,,,,, -352800,4.2956266,0.6322135,,,,,,,,,,,,,, -352900,4.8715134,0.65863407,,,,,,,,,,,,,, -353000,4.0744524,0.57943857,,,,,,,,,,,,,, -353100,4.8528852,0.6925801,,,,,,,,,,,,,, -353200,4.308561,0.63503194,,,,,,,,,,,,,, -353300,4.828898,0.6204878,,,,,,,,,,,,,, -353359,,,0.9605787396430968,0.1489564776420593,0.7554199695587158,1.0429089069366455,50000.0,0.6307000517845154,1.8223930597305296,10000.0,119401.21796488762,123595.68791532516,119401.21796488762,4166.511382102966,15.826533794403076,0.0 -353400,4.3518033,0.6516839,,,,,,,,,,,,,, -353500,4.5762053,0.6554612,,,,,,,,,,,,,, -353600,4.6177044,0.5746867,,,,,,,,,,,,,, -353700,4.4713073,0.58071935,,,,,,,,,,,,,, -353800,4.5887885,0.6712977,,,,,,,,,,,,,, -353900,4.249141,0.6041452,,,,,,,,,,,,,, -354000,4.18273,0.60654074,,,,,,,,,,,,,, -354100,4.799495,0.6710135,,,,,,,,,,,,,, -354200,4.690852,0.60071445,,,,,,,,,,,,,, -354300,4.594671,0.59457684,,,,,,,,,,,,,, -354400,4.4927425,0.62429607,,,,,,,,,,,,,, -354500,4.3716598,0.6176541,,,,,,,,,,,,,, -354600,4.504863,0.5553749,,,,,,,,,,,,,, -354700,4.479429,0.64082634,,,,,,,,,,,,,, -354800,4.4275045,0.6265043,,,,,,,,,,,,,, -354870,,,0.9594626426696776,0.149523377418518,0.7547399997711182,1.0451470613479614,50000.0,0.6306000351905823,1.8234483003616333,10000.0,119911.32834887505,124122.98861145972,119911.32834887505,4183.558404684067,15.91768193244934,0.0 -354900,3.963908,0.5647302,,,,,,,,,,,,,, -355000,5.1319327,0.67481273,,,,,,,,,,,,,, -355100,4.280533,0.5775637,,,,,,,,,,,,,, -355200,4.618402,0.6528037,,,,,,,,,,,,,, -355300,4.617383,0.601248,,,,,,,,,,,,,, -355400,4.671448,0.70432705,,,,,,,,,,,,,, -355500,4.5465946,0.7029797,,,,,,,,,,,,,, -355600,4.2615542,0.5842426,,,,,,,,,,,,,, -355700,4.233365,0.5359529,,,,,,,,,,,,,, -355800,4.481881,0.63262147,,,,,,,,,,,,,, -355900,4.4034677,0.5408355,,,,,,,,,,,,,, -356000,4.304332,0.6467966,,,,,,,,,,,,,, -356100,4.359117,0.5865049,,,,,,,,,,,,,, -356200,4.6241403,0.6333636,,,,,,,,,,,,,, -356300,4.4133477,0.60680336,,,,,,,,,,,,,, -356380,,,0.9622528553009032,0.142277330160141,0.755620002746582,1.0428621768951416,50000.0,0.6318000555038452,1.8212215900421145,10000.0,120421.18785190582,124650.18924379347,120421.18785190582,4200.7578365802765,16.00704550743103,0.0 -356400,4.560022,0.60109735,,,,,,,,,,,,,, -356500,4.8450007,0.67057794,,,,,,,,,,,,,, -356600,5.072633,0.6811051,,,,,,,,,,,,,, -356700,4.646401,0.5813528,,,,,,,,,,,,,, -356800,4.5667434,0.64226913,,,,,,,,,,,,,, -356900,4.587924,0.60893744,,,,,,,,,,,,,, -357000,4.5561996,0.5710114,,,,,,,,,,,,,, -357100,4.3865275,0.60570854,,,,,,,,,,,,,, -357200,4.750735,0.64390785,,,,,,,,,,,,,, -357300,4.5486794,0.6810556,,,,,,,,,,,,,, -357400,4.5227776,0.6678262,,,,,,,,,,,,,, -357500,4.836165,0.66761315,,,,,,,,,,,,,, -357600,4.2590785,0.6517814,,,,,,,,,,,,,, -357700,4.6132536,0.6649428,,,,,,,,,,,,,, -357800,4.091428,0.56340736,,,,,,,,,,,,,, -357891,,,0.9612563848495485,0.1443909555673599,0.7553199529647827,1.044643521308899,50000.0,0.6312000155448914,1.8233301639556885,10000.0,120931.28691577911,125177.22383451462,120931.28691577911,4217.550870895386,16.096714735031128,0.0 -357900,4.3104477,0.5909636,,,,,,,,,,,,,, -358000,4.3636994,0.60759854,,,,,,,,,,,,,, -358100,4.463299,0.6545788,,,,,,,,,,,,,, -358200,4.9379916,0.5967696,,,,,,,,,,,,,, -358300,4.760008,0.6938922,,,,,,,,,,,,,, -358400,4.6995316,0.63945997,,,,,,,,,,,,,, -358500,4.3232403,0.60232127,,,,,,,,,,,,,, -358600,4.843598,0.68675107,,,,,,,,,,,,,, -358700,4.5790634,0.6258404,,,,,,,,,,,,,, -358800,4.3866496,0.58759135,,,,,,,,,,,,,, -358900,4.2730904,0.6180508,,,,,,,,,,,,,, -359000,5.025772,0.5923897,,,,,,,,,,,,,, -359100,4.352654,0.6495108,,,,,,,,,,,,,, -359200,4.538395,0.60005164,,,,,,,,,,,,,, -359300,4.7085896,0.6648318,,,,,,,,,,,,,, -359400,5.0101895,0.6625699,,,,,,,,,,,,,, -359401,,,0.9583665132522584,0.1501764953136444,0.7549999952316284,1.0434049367904663,50000.0,0.6303000450134277,1.8215309381484983,10000.0,121441.40004301073,125704.3516037464,121441.40004301073,4234.424687862396,16.18539547920227,0.0 -359500,4.7897735,0.6836976,,,,,,,,,,,,,, -359600,4.5550275,0.69216436,,,,,,,,,,,,,, -359700,4.329111,0.6057434,,,,,,,,,,,,,, -359800,4.262137,0.6580106,,,,,,,,,,,,,, -359900,4.3743806,0.62670547,,,,,,,,,,,,,, -360000,4.4402742,0.5822471,,,,,,,,,,,,,, -360100,4.4830904,0.6410933,,,,,,,,,,,,,, -360200,4.4526587,0.5883573,,,,,,,,,,,,,, -360300,4.4597836,0.5660503,,,,,,,,,,,,,, -360400,4.692614,0.6259405,,,,,,,,,,,,,, -360500,5.1165,0.72530204,,,,,,,,,,,,,, -360600,4.3441296,0.53006434,,,,,,,,,,,,,, -360700,4.43092,0.60397524,,,,,,,,,,,,,, -360800,4.433883,0.63399196,,,,,,,,,,,,,, -360900,3.9989083,0.5751542,,,,,,,,,,,,,, -360911,,,0.9622129797935486,0.1466683894395828,0.7549999952316284,1.0447399616241455,50000.0,0.6305000185966492,1.8208733797073364,10000.0,121951.36361050606,126231.38405561449,121951.36361050606,4251.351005554199,16.275821685791016,0.0 -361000,4.513967,0.6536544,,,,,,,,,,,,,, -361100,4.4693823,0.6010356,,,,,,,,,,,,,, -361200,4.995144,0.7204845,,,,,,,,,,,,,, -361300,4.755484,0.64588624,,,,,,,,,,,,,, -361400,4.253036,0.6266363,,,,,,,,,,,,,, -361500,4.5550933,0.65578395,,,,,,,,,,,,,, -361600,4.1434417,0.52547956,,,,,,,,,,,,,, -361700,4.669551,0.65778685,,,,,,,,,,,,,, -361800,4.2065115,0.5934024,,,,,,,,,,,,,, -361900,4.5464773,0.6121646,,,,,,,,,,,,,, -362000,4.5445747,0.6128238,,,,,,,,,,,,,, -362100,4.4068995,0.6269909,,,,,,,,,,,,,, -362200,4.7935295,0.628448,,,,,,,,,,,,,, -362300,4.859017,0.5776508,,,,,,,,,,,,,, -362400,4.5651464,0.6312303,,,,,,,,,,,,,, -362422,,,0.9601004123687744,0.1505505144596099,0.7554199695587158,1.0432157516479492,50000.0,0.6313000321388245,1.8205649852752688,10000.0,122461.54856848715,126758.89103007317,122461.54856848715,4268.519370794296,16.37822437286377,0.0 -362500,4.4529824,0.65336853,,,,,,,,,,,,,, -362600,4.193777,0.6339827,,,,,,,,,,,,,, -362700,4.4721174,0.6469021,,,,,,,,,,,,,, -362800,4.2510767,0.64848447,,,,,,,,,,,,,, -362900,4.5377593,0.63463473,,,,,,,,,,,,,, -363000,4.44877,0.59597456,,,,,,,,,,,,,, -363100,4.426653,0.6020628,,,,,,,,,,,,,, -363200,4.219294,0.57357335,,,,,,,,,,,,,, -363300,4.717938,0.6478136,,,,,,,,,,,,,, -363400,4.540447,0.5738151,,,,,,,,,,,,,, -363500,4.5217433,0.64178413,,,,,,,,,,,,,, -363600,4.3038244,0.5433017,,,,,,,,,,,,,, -363700,4.4907637,0.57658494,,,,,,,,,,,,,, -363800,4.400172,0.60673696,,,,,,,,,,,,,, -363900,4.7318287,0.6238963,,,,,,,,,,,,,, -363933,,,0.9612165093421936,0.1459295153617859,0.7549999952316284,1.0441726446151731,50000.0,0.6302000284194946,1.821946144104004,10000.0,122971.61852169035,127286.1263434887,122971.61852169035,4285.543945074081,16.466397523880005,0.0 -364000,4.2914743,0.5889954,,,,,,,,,,,,,, -364100,5.015732,0.64982104,,,,,,,,,,,,,, -364200,4.802732,0.5985404,,,,,,,,,,,,,, -364300,4.1568155,0.5888437,,,,,,,,,,,,,, -364400,4.1230693,0.63191396,,,,,,,,,,,,,, -364500,4.6392097,0.6002798,,,,,,,,,,,,,, -364600,4.7183504,0.6724505,,,,,,,,,,,,,, -364700,4.3215227,0.6001207,,,,,,,,,,,,,, -364800,4.7556195,0.6412364,,,,,,,,,,,,,, -364900,4.5296454,0.62201816,,,,,,,,,,,,,, -365000,4.4633117,0.57344675,,,,,,,,,,,,,, -365100,4.7073402,0.6848998,,,,,,,,,,,,,, -365200,4.567589,0.66296136,,,,,,,,,,,,,, -365300,4.3409185,0.63166535,,,,,,,,,,,,,, -365400,4.9191637,0.7289634,,,,,,,,,,,,,, -365443,,,0.957409918308258,0.1546299159526825,0.7553799748420715,1.043689846992493,50000.0,0.6312000155448914,1.8219724893569944,10000.0,123481.5219783783,127813.21567821504,123481.5219783783,4302.584145069122,16.560025453567505,0.0 -365500,4.3712997,0.5683696,,,,,,,,,,,,,, -365600,4.8764873,0.6792548,,,,,,,,,,,,,, -365700,4.740137,0.7261885,,,,,,,,,,,,,, -365800,4.6600595,0.62799686,,,,,,,,,,,,,, -365900,4.2957115,0.5815321,,,,,,,,,,,,,, -366000,4.62306,0.59224534,,,,,,,,,,,,,, -366100,4.4621587,0.59029484,,,,,,,,,,,,,, -366200,5.0703955,0.6636251,,,,,,,,,,,,,, -366300,4.9736223,0.65408695,,,,,,,,,,,,,, -366400,4.3628707,0.59298795,,,,,,,,,,,,,, -366500,5.124827,0.637763,,,,,,,,,,,,,, -366600,4.7188745,0.63072485,,,,,,,,,,,,,, -366700,4.5951138,0.6431401,,,,,,,,,,,,,, -366800,4.246493,0.62094796,,,,,,,,,,,,,, -366900,4.487307,0.6252215,,,,,,,,,,,,,, -366954,,,0.960379421710968,0.1479264050722122,0.7549799680709839,1.043339490890503,50000.0,0.6315000057220459,1.8199297189712524,10000.0,123991.6851592064,128340.6479599476,123991.6851592064,4319.705858469009,16.655137300491333,0.0 -367000,4.138177,0.66801345,,,,,,,,,,,,,, -367100,4.830752,0.6297987,,,,,,,,,,,,,, -367200,4.756441,0.6473847,,,,,,,,,,,,,, -367300,4.8284264,0.6876781,,,,,,,,,,,,,, -367400,4.1236,0.5005162,,,,,,,,,,,,,, -367500,4.577673,0.5943117,,,,,,,,,,,,,, -367600,4.0401006,0.5629032,,,,,,,,,,,,,, -367700,4.317487,0.60217047,,,,,,,,,,,,,, -367800,4.348481,0.64122355,,,,,,,,,,,,,, -367900,4.405948,0.5696329,,,,,,,,,,,,,, -368000,4.5262403,0.671677,,,,,,,,,,,,,, -368100,4.2280846,0.6844297,,,,,,,,,,,,,, -368200,4.232728,0.55021703,,,,,,,,,,,,,, -368300,4.8117394,0.66816455,,,,,,,,,,,,,, -368400,4.6381645,0.68338376,,,,,,,,,,,,,, -368464,,,0.960758090019226,0.1465861648321151,0.7549999952316284,1.0438544750213623,50000.0,0.631100058555603,1.8230565786361688,10000.0,124501.6871304512,128867.89737272264,124501.6871304512,4336.812728643417,16.74353575706482,0.0 -368500,4.9562817,0.609864,,,,,,,,,,,,,, -368600,4.5569725,0.6248079,,,,,,,,,,,,,, -368700,4.8324986,0.6070427,,,,,,,,,,,,,, -368800,4.9670215,0.6289196,,,,,,,,,,,,,, -368900,4.241761,0.5247057,,,,,,,,,,,,,, -369000,4.5735106,0.5923518,,,,,,,,,,,,,, -369100,4.569281,0.68267137,,,,,,,,,,,,,, -369200,4.9395013,0.63230616,,,,,,,,,,,,,, -369300,4.477403,0.6465312,,,,,,,,,,,,,, -369400,4.6017904,0.71604383,,,,,,,,,,,,,, -369500,4.438993,0.68748176,,,,,,,,,,,,,, -369600,4.8490696,0.66352594,,,,,,,,,,,,,, -369700,4.612499,0.71449685,,,,,,,,,,,,,, -369800,4.792235,0.6559474,,,,,,,,,,,,,, -369900,4.499542,0.64548326,,,,,,,,,,,,,, -369974,,,0.9587850570678712,0.1517333090305328,0.7553399801254272,1.0435724258422852,50000.0,0.6308000087738037,1.8213281631469729,10000.0,125011.75234508514,129395.1438281536,125011.75234508514,4353.856767416,16.82906484603882,0.0 -370000,4.7902703,0.66737825,,,,,,,,,,,,,, -370100,4.7013097,0.57920796,,,,,,,,,,,,,, -370200,4.4130692,0.66704345,,,,,,,,,,,,,, -370300,4.027591,0.58194536,,,,,,,,,,,,,, -370400,4.1727686,0.588007,,,,,,,,,,,,,, -370500,4.693547,0.67008156,,,,,,,,,,,,,, -370600,4.858539,0.600057,,,,,,,,,,,,,, -370700,4.4746003,0.58520496,,,,,,,,,,,,,, -370800,4.582694,0.66885096,,,,,,,,,,,,,, -370900,4.17799,0.5827849,,,,,,,,,,,,,, -371000,4.455674,0.6779199,,,,,,,,,,,,,, -371100,4.5998793,0.6873478,,,,,,,,,,,,,, -371200,4.1638923,0.560439,,,,,,,,,,,,,, -371300,4.451368,0.6103659,,,,,,,,,,,,,, -371400,4.0962553,0.5900583,,,,,,,,,,,,,, -371484,,,0.9606186151504515,0.1473256200551986,0.7552199959754944,1.0428035259246826,50000.0,0.631600022315979,1.821453213691712,10000.0,125521.72247958183,129922.26717352869,125521.72247958183,4370.871035575867,16.91614294052124,0.0 -371500,4.538962,0.6733291,,,,,,,,,,,,,, -371600,4.911783,0.64404684,,,,,,,,,,,,,, -371700,4.317806,0.6701954,,,,,,,,,,,,,, -371800,4.488081,0.56060964,,,,,,,,,,,,,, -371900,4.5767407,0.66830456,,,,,,,,,,,,,, -372000,4.8449626,0.66501105,,,,,,,,,,,,,, -372100,4.3708935,0.5871787,,,,,,,,,,,,,, -372200,4.6855555,0.63631344,,,,,,,,,,,,,, -372300,4.5380483,0.54665965,,,,,,,,,,,,,, -372400,4.2055335,0.6029309,,,,,,,,,,,,,, -372500,4.7632127,0.694069,,,,,,,,,,,,,, -372600,4.5205173,0.67689204,,,,,,,,,,,,,, -372700,4.7433853,0.65508825,,,,,,,,,,,,,, -372800,4.299359,0.5572078,,,,,,,,,,,,,, -372900,4.429009,0.6228867,,,,,,,,,,,,,, -372994,,,0.9604392051696776,0.1476263850927353,0.7548799514770508,1.0441089868545532,50000.0,0.6306000351905823,1.819501638412476,10000.0,126031.5854511261,130449.29526090622,126031.5854511261,4387.892452478409,17.007720232009888,0.0 -373000,4.7613363,0.65740466,,,,,,,,,,,,,, -373100,4.4663334,0.6476818,,,,,,,,,,,,,, -373200,4.8593044,0.6561755,,,,,,,,,,,,,, -373300,4.218313,0.5515928,,,,,,,,,,,,,, -373400,4.7897425,0.58495015,,,,,,,,,,,,,, -373500,4.7998056,0.6545365,,,,,,,,,,,,,, -373600,4.2580833,0.6083682,,,,,,,,,,,,,, -373700,3.990518,0.5936716,,,,,,,,,,,,,, -373800,4.6262207,0.6456147,,,,,,,,,,,,,, -373900,4.5577188,0.6108813,,,,,,,,,,,,,, -374000,4.4325733,0.5216349,,,,,,,,,,,,,, -374100,4.425458,0.561444,,,,,,,,,,,,,, -374200,4.5405717,0.55679226,,,,,,,,,,,,,, -374300,4.5346045,0.61355126,,,,,,,,,,,,,, -374400,4.9151664,0.7338946,,,,,,,,,,,,,, -374500,4.4427447,0.641056,,,,,,,,,,,,,, -374504,,,0.9590441584587096,0.1508360058069229,0.7549799680709839,1.0448273420333862,50000.0,0.6314000487327576,1.8218950033187864,10000.0,126541.45552277564,130976.33390402794,126541.45552277564,4404.918319940567,17.09819459915161,0.0 -374600,4.5782437,0.5680924,,,,,,,,,,,,,, -374700,4.6443944,0.62643147,,,,,,,,,,,,,, -374800,4.8936257,0.64875627,,,,,,,,,,,,,, -374900,4.4406667,0.65800804,,,,,,,,,,,,,, -375000,4.5575395,0.5897244,,,,,,,,,,,,,, -375100,4.0368457,0.55119157,,,,,,,,,,,,,, -375200,4.58173,0.60787314,,,,,,,,,,,,,, -375300,4.891362,0.7199491,,,,,,,,,,,,,, -375400,4.579065,0.6825095,,,,,,,,,,,,,, -375500,4.5295753,0.62308216,,,,,,,,,,,,,, -375600,4.606251,0.62571466,,,,,,,,,,,,,, -375700,4.412037,0.65923905,,,,,,,,,,,,,, -375800,4.7292023,0.66017437,,,,,,,,,,,,,, -375900,4.4588885,0.6121381,,,,,,,,,,,,,, -376000,4.586937,0.63560647,,,,,,,,,,,,,, -376014,,,0.9611168503761292,0.1466236859560012,0.7553799748420715,1.0439865589141846,50000.0,0.6307000517845154,1.8222171068191528,10000.0,127051.36449313164,131503.63059139252,127051.36449313164,4422.162630319595,17.189462661743164,0.0 -376100,4.8442783,0.6551793,,,,,,,,,,,,,, -376200,4.532183,0.706795,,,,,,,,,,,,,, -376300,4.3497596,0.67346144,,,,,,,,,,,,,, -376400,4.957488,0.6453266,,,,,,,,,,,,,, -376500,4.72631,0.615408,,,,,,,,,,,,,, -376600,4.246066,0.58073586,,,,,,,,,,,,,, -376700,5.1016865,0.6748232,,,,,,,,,,,,,, -376800,4.287272,0.5757633,,,,,,,,,,,,,, -376900,4.472467,0.583786,,,,,,,,,,,,,, -377000,4.496314,0.61616784,,,,,,,,,,,,,, -377100,4.3796554,0.61633843,,,,,,,,,,,,,, -377200,4.4912834,0.6077954,,,,,,,,,,,,,, -377300,4.6568837,0.6433818,,,,,,,,,,,,,, -377400,4.551801,0.5765505,,,,,,,,,,,,,, -377500,4.464263,0.5996605,,,,,,,,,,,,,, -377524,,,0.9603196382522584,0.1494261622428894,0.7545799612998962,1.0441468954086304,50000.0,0.6303000450134277,1.822312355041504,10000.0,127561.3567495346,132030.64701890945,127561.3567495346,4439.046007156372,17.278419017791748,0.0 -377600,4.306623,0.58591866,,,,,,,,,,,,,, -377700,4.189392,0.64355004,,,,,,,,,,,,,, -377800,4.692259,0.6764356,,,,,,,,,,,,,, -377900,4.403679,0.5678175,,,,,,,,,,,,,, -378000,5.096882,0.6345243,,,,,,,,,,,,,, -378100,4.3117323,0.65161383,,,,,,,,,,,,,, -378200,4.512261,0.67422694,,,,,,,,,,,,,, -378300,4.6009974,0.6422737,,,,,,,,,,,,,, -378400,4.428155,0.6529965,,,,,,,,,,,,,, -378500,4.587964,0.69556624,,,,,,,,,,,,,, -378600,4.5146565,0.624188,,,,,,,,,,,,,, -378700,4.610186,0.61971706,,,,,,,,,,,,,, -378800,4.438553,0.6518113,,,,,,,,,,,,,, -378900,4.954151,0.63643575,,,,,,,,,,,,,, -379000,4.2058735,0.5730977,,,,,,,,,,,,,, -379032,,,0.9614157676696776,0.1464610993862152,0.7551199793815613,1.0441802740097046,50000.0,0.6310000419616699,1.8221724033355715,10000.0,128070.5689780712,132557.6023557186,128070.5689780712,4455.830331802368,18.18609476089477,0.0 -379100,4.7708144,0.60194755,,,,,,,,,,,,,, -379200,4.4100237,0.5875262,,,,,,,,,,,,,, -379300,4.10556,0.59479004,,,,,,,,,,,,,, -379400,4.1273494,0.6624707,,,,,,,,,,,,,, -379500,4.1897483,0.5755204,,,,,,,,,,,,,, -379600,4.550544,0.64609516,,,,,,,,,,,,,, -379700,4.655651,0.55538106,,,,,,,,,,,,,, -379800,4.658205,0.6147764,,,,,,,,,,,,,, -379900,4.935426,0.62471557,,,,,,,,,,,,,, -380000,4.464746,0.63470656,,,,,,,,,,,,,, -380100,4.956603,0.636183,,,,,,,,,,,,,, -380200,4.491465,0.69601995,,,,,,,,,,,,,, -380300,4.267801,0.5992496,,,,,,,,,,,,,, -380400,4.4791193,0.6944105,,,,,,,,,,,,,, -380500,4.4142184,0.5827482,,,,,,,,,,,,,, -380543,,,0.961355984210968,0.1447767466306686,0.7554199695587158,1.0452173948287964,50000.0,0.6302000284194946,1.8225467205047607,10000.0,128580.7130572796,133084.9375550747,128580.7130572796,4472.877988100052,18.27739262580872,0.0 -380600,4.6544228,0.67698467,,,,,,,,,,,,,, -380700,4.8969007,0.5713654,,,,,,,,,,,,,, -380800,4.5782423,0.6956789,,,,,,,,,,,,,, -380900,4.4304214,0.61041725,,,,,,,,,,,,,, -381000,4.444904,0.60196,,,,,,,,,,,,,, -381100,4.801867,0.6567333,,,,,,,,,,,,,, -381200,4.5637836,0.60809624,,,,,,,,,,,,,, -381300,4.542249,0.5971778,,,,,,,,,,,,,, -381400,4.852166,0.67534286,,,,,,,,,,,,,, -381500,4.1817822,0.5729305,,,,,,,,,,,,,, -381600,4.6215587,0.58983517,,,,,,,,,,,,,, -381700,4.9119606,0.66332185,,,,,,,,,,,,,, -381800,4.801919,0.6712109,,,,,,,,,,,,,, -381900,4.358126,0.5899344,,,,,,,,,,,,,, -382000,4.7589574,0.67219347,,,,,,,,,,,,,, -382054,,,0.9597616195678712,0.1493446826934814,0.7551999688148499,1.0439260005950928,50000.0,0.6312000155448914,1.822293996810913,10000.0,129090.8795325756,133612.30461883545,129090.8795325756,4489.933167695999,18.371665239334103,0.0 -382100,4.4161305,0.6432015,,,,,,,,,,,,,, -382200,4.516719,0.7109054,,,,,,,,,,,,,, -382300,4.5728607,0.67669415,,,,,,,,,,,,,, -382400,4.438532,0.5716076,,,,,,,,,,,,,, -382500,4.248384,0.63909316,,,,,,,,,,,,,, -382600,4.7776017,0.6299703,,,,,,,,,,,,,, -382700,4.599664,0.63386685,,,,,,,,,,,,,, -382800,4.6175866,0.65833193,,,,,,,,,,,,,, -382900,5.032541,0.62377846,,,,,,,,,,,,,, -383000,4.305663,0.6537818,,,,,,,,,,,,,, -383100,4.7521486,0.6123431,,,,,,,,,,,,,, -383200,4.523291,0.59270215,,,,,,,,,,,,,, -383300,4.8101993,0.64163625,,,,,,,,,,,,,, -383400,4.1469455,0.6085943,,,,,,,,,,,,,, -383500,4.220108,0.5622877,,,,,,,,,,,,,, -383565,,,0.959980845451355,0.1489178836345672,0.7549799680709839,1.0436407327651978,50000.0,0.631100058555603,1.823047757148743,10000.0,129600.94166183472,134140.50261998177,129600.94166183472,4507.924045085907,18.46406841278076,0.0 -383600,4.382384,0.6152931,,,,,,,,,,,,,, -383700,4.483423,0.62886393,,,,,,,,,,,,,, -383800,4.5115194,0.624086,,,,,,,,,,,,,, -383900,4.451632,0.5776085,,,,,,,,,,,,,, -384000,4.668494,0.63088626,,,,,,,,,,,,,, -384100,4.693482,0.5477603,,,,,,,,,,,,,, -384200,4.3085785,0.5688641,,,,,,,,,,,,,, -384300,4.8359923,0.62702,,,,,,,,,,,,,, -384400,4.4824934,0.6642472,,,,,,,,,,,,,, -384500,4.578663,0.6271315,,,,,,,,,,,,,, -384600,4.9108586,0.6745724,,,,,,,,,,,,,, -384700,4.3578153,0.65958977,,,,,,,,,,,,,, -384800,4.7521915,0.6006347,,,,,,,,,,,,,, -384900,4.062014,0.5197516,,,,,,,,,,,,,, -385000,5.377037,0.66908896,,,,,,,,,,,,,, -385074,,,0.9599609375,0.1488925218582153,0.7550999522209167,1.0445308685302734,50000.0,0.6301000118255615,1.822293400764465,10000.0,130110.80349063872,134667.40004825592,130110.80349063872,4524.816039323807,18.555992364883423,0.0 -385100,4.346493,0.63721436,,,,,,,,,,,,,, -385200,4.526611,0.6875307,,,,,,,,,,,,,, -385300,4.710688,0.67615974,,,,,,,,,,,,,, -385400,4.3240128,0.5655112,,,,,,,,,,,,,, -385500,4.264409,0.6353658,,,,,,,,,,,,,, -385600,4.7924585,0.68428344,,,,,,,,,,,,,, -385700,4.0186386,0.5975846,,,,,,,,,,,,,, -385800,4.3360705,0.55910385,,,,,,,,,,,,,, -385900,4.3728867,0.61461383,,,,,,,,,,,,,, -386000,4.2343545,0.5935267,,,,,,,,,,,,,, -386100,4.328025,0.59671926,,,,,,,,,,,,,, -386200,4.6039934,0.62060845,,,,,,,,,,,,,, -386300,4.432838,0.61283094,,,,,,,,,,,,,, -386400,4.2793818,0.62604666,,,,,,,,,,,,,, -386500,4.2366323,0.56438744,,,,,,,,,,,,,, -386585,,,0.9601402878761292,0.1486131697893142,0.7552199959754944,1.044558882713318,50000.0,0.6305000185966492,1.8213427066802976,10000.0,130620.97541832924,135194.71618199348,130620.97541832924,4541.817780017853,18.64685320854187,0.0 -386600,4.386408,0.63224566,,,,,,,,,,,,,, -386700,4.3817644,0.6633603,,,,,,,,,,,,,, -386800,4.245788,0.57783294,,,,,,,,,,,,,, -386900,4.5624022,0.686208,,,,,,,,,,,,,, -387000,4.287187,0.5651381,,,,,,,,,,,,,, -387100,4.5553017,0.64441925,,,,,,,,,,,,,, -387200,4.739372,0.6276832,,,,,,,,,,,,,, -387300,4.4346504,0.6140689,,,,,,,,,,,,,, -387400,4.0206985,0.5986768,,,,,,,,,,,,,, -387500,4.228806,0.6276298,,,,,,,,,,,,,, -387600,4.391687,0.65131307,,,,,,,,,,,,,, -387700,4.6906037,0.68979096,,,,,,,,,,,,,, -387800,4.8485417,0.6141945,,,,,,,,,,,,,, -387900,4.4493756,0.60593224,,,,,,,,,,,,,, -388000,4.247711,0.55511045,,,,,,,,,,,,,, -388095,,,0.9613759517669678,0.1466667354106903,0.755620002746582,1.0440008640289309,50000.0,0.6306000351905823,1.821485996246338,10000.0,131131.05405449867,135722.25498723984,131131.05405449867,4559.130956888199,18.741585731506348,0.0 -388100,4.756614,0.6177046,,,,,,,,,,,,,, -388200,4.7732773,0.63980687,,,,,,,,,,,,,, -388300,4.58173,0.649686,,,,,,,,,,,,,, -388400,4.404065,0.6156608,,,,,,,,,,,,,, -388500,4.572812,0.70695794,,,,,,,,,,,,,, -388600,4.8789916,0.6582482,,,,,,,,,,,,,, -388700,4.4904222,0.6203259,,,,,,,,,,,,,, -388800,4.507381,0.6113539,,,,,,,,,,,,,, -388900,4.440836,0.62816346,,,,,,,,,,,,,, -389000,4.411829,0.6495764,,,,,,,,,,,,,, -389100,4.516658,0.608898,,,,,,,,,,,,,, -389200,4.5150676,0.58505833,,,,,,,,,,,,,, -389300,4.464946,0.62081635,,,,,,,,,,,,,, -389400,4.3774705,0.62986773,,,,,,,,,,,,,, -389500,5.1164474,0.74027234,,,,,,,,,,,,,, -389600,4.567562,0.63900876,,,,,,,,,,,,,, -389605,,,0.9606186151504515,0.1472872942686081,0.755079984664917,1.043853998184204,50000.0,0.6303000450134277,1.8217018842697144,10000.0,131640.9479010105,136249.3849053383,131640.9479010105,4576.21217250824,18.844163179397583,0.0 -389700,4.6667795,0.6612083,,,,,,,,,,,,,, -389800,4.9143233,0.6903637,,,,,,,,,,,,,, -389900,4.5723968,0.67478454,,,,,,,,,,,,,, -390000,4.2838655,0.625058,,,,,,,,,,,,,, -390100,4.5292516,0.6149483,,,,,,,,,,,,,, -390200,4.26341,0.6209732,,,,,,,,,,,,,, -390300,4.114225,0.54635704,,,,,,,,,,,,,, -390400,4.407043,0.6216994,,,,,,,,,,,,,, -390500,4.3766766,0.59490347,,,,,,,,,,,,,, -390600,4.366032,0.61309224,,,,,,,,,,,,,, -390700,4.523263,0.5912736,,,,,,,,,,,,,, -390800,5.242044,0.61873114,,,,,,,,,,,,,, -390900,4.457895,0.63885105,,,,,,,,,,,,,, -391000,4.4561534,0.6017627,,,,,,,,,,,,,, -391100,4.758124,0.70561713,,,,,,,,,,,,,, -391116,,,0.9608378410339355,0.1467545479536056,0.7552199959754944,1.0440353155136108,50000.0,0.6303000450134277,1.8227473497390747,10000.0,132151.01629567146,136776.55661320686,132151.01629567146,4593.165332555771,18.94241428375244,0.0 -391200,4.3986807,0.6325839,,,,,,,,,,,,,, -391300,4.1860356,0.6001991,,,,,,,,,,,,,, -391400,4.749872,0.62178624,,,,,,,,,,,,,, -391500,4.037261,0.5945566,,,,,,,,,,,,,, -391600,4.3881726,0.5996837,,,,,,,,,,,,,, -391700,4.1970572,0.57919896,,,,,,,,,,,,,, -391800,4.8737965,0.67937845,,,,,,,,,,,,,, -391900,4.492883,0.585682,,,,,,,,,,,,,, -392000,4.414247,0.6286851,,,,,,,,,,,,,, -392100,4.3801794,0.6360916,,,,,,,,,,,,,, -392200,4.856922,0.56858563,,,,,,,,,,,,,, -392300,4.655034,0.6309457,,,,,,,,,,,,,, -392400,4.7652974,0.70623857,,,,,,,,,,,,,, -392500,4.476329,0.64508796,,,,,,,,,,,,,, -392600,4.5969615,0.69688493,,,,,,,,,,,,,, -392626,,,0.9600406289100648,0.1507564932107925,0.7551599740982056,1.0441441535949707,50000.0,0.6300000548362732,1.821921944618225,10000.0,132660.9876010418,137303.47638607025,132660.9876010418,4609.966013908386,19.038569688797,0.0 -392700,4.497904,0.5112795,,,,,,,,,,,,,, -392800,4.9562993,0.5463806,,,,,,,,,,,,,, -392900,4.491833,0.6411917,,,,,,,,,,,,,, -393000,4.537868,0.6423711,,,,,,,,,,,,,, -393100,4.418803,0.62433547,,,,,,,,,,,,,, -393200,4.737185,0.6880913,,,,,,,,,,,,,, -393300,4.4308047,0.6790533,,,,,,,,,,,,,, -393400,4.5112886,0.5753353,,,,,,,,,,,,,, -393500,4.1371675,0.59868765,,,,,,,,,,,,,, -393600,4.5272546,0.62869835,,,,,,,,,,,,,, -393700,4.4668126,0.6011864,,,,,,,,,,,,,, -393800,4.447963,0.6612899,,,,,,,,,,,,,, -393900,4.3073497,0.56455976,,,,,,,,,,,,,, -394000,4.3958035,0.66313875,,,,,,,,,,,,,, -394100,4.5684495,0.5991342,,,,,,,,,,,,,, -394136,,,0.9599609375,0.1487556546926498,0.7548999786376953,1.0435575246810913,50000.0,0.6296000480651855,1.8231137990951536,10000.0,133171.0541651249,137830.54284071922,133171.0541651249,4626.819089174271,19.13323402404785,0.0 -394200,4.668317,0.59617007,,,,,,,,,,,,,, -394300,4.452398,0.62378573,,,,,,,,,,,,,, -394400,4.359587,0.59587413,,,,,,,,,,,,,, -394500,4.5447726,0.5969041,,,,,,,,,,,,,, -394600,4.7447033,0.70669687,,,,,,,,,,,,,, -394700,4.2720475,0.5706636,,,,,,,,,,,,,, -394800,4.452984,0.6140717,,,,,,,,,,,,,, -394900,4.430415,0.57263106,,,,,,,,,,,,,, -395000,4.385315,0.58224666,,,,,,,,,,,,,, -395100,4.0136704,0.5941535,,,,,,,,,,,,,, -395200,4.7363834,0.6804398,,,,,,,,,,,,,, -395300,4.5687914,0.6715973,,,,,,,,,,,,,, -395400,4.4175816,0.6299974,,,,,,,,,,,,,, -395500,4.6685934,0.6086776,,,,,,,,,,,,,, -395600,4.2389846,0.5964007,,,,,,,,,,,,,, -395646,,,0.961933970451355,0.1406219154596328,0.7546799778938293,1.044531226158142,50000.0,0.631100058555603,1.82287073135376,10000.0,133681.10687541962,138357.87222909927,133681.10687541962,4643.953272104263,19.22391462326049,0.0 -395700,4.276974,0.6085401,,,,,,,,,,,,,, -395800,4.7708406,0.63193756,,,,,,,,,,,,,, -395900,4.6693063,0.60627043,,,,,,,,,,,,,, -396000,5.215139,0.6356871,,,,,,,,,,,,,, -396100,4.700947,0.64722514,,,,,,,,,,,,,, -396200,4.7227154,0.65383387,,,,,,,,,,,,,, -396300,4.4639096,0.64028835,,,,,,,,,,,,,, -396400,4.2963786,0.60347855,,,,,,,,,,,,,, -396500,4.1724024,0.59988165,,,,,,,,,,,,,, -396600,4.2396045,0.5470424,,,,,,,,,,,,,, -396700,4.775567,0.6802696,,,,,,,,,,,,,, -396800,4.603508,0.57392883,,,,,,,,,,,,,, -396900,4.2274704,0.59953976,,,,,,,,,,,,,, -397000,4.4447784,0.5784265,,,,,,,,,,,,,, -397100,4.5596914,0.62975997,,,,,,,,,,,,,, -397156,,,0.9604591727256776,0.1473497450351715,0.7556999921798706,1.0434935092926023,50000.0,0.6309000253677368,1.8224536180496216,10000.0,134191.01133322716,138884.81725287437,134191.01133322716,4660.833715677261,19.33142256736756,0.0 -397200,4.6516604,0.6535317,,,,,,,,,,,,,, -397300,4.410518,0.65616506,,,,,,,,,,,,,, -397400,4.621838,0.6323866,,,,,,,,,,,,,, -397500,4.5782843,0.58849037,,,,,,,,,,,,,, -397600,4.6924706,0.671594,,,,,,,,,,,,,, -397700,4.2741933,0.60533124,,,,,,,,,,,,,, -397800,4.145277,0.5582078,,,,,,,,,,,,,, -397900,4.4874244,0.6050489,,,,,,,,,,,,,, -398000,4.560596,0.62491477,,,,,,,,,,,,,, -398100,4.5153112,0.63020015,,,,,,,,,,,,,, -398200,4.523512,0.61932063,,,,,,,,,,,,,, -398300,4.31014,0.5986344,,,,,,,,,,,,,, -398400,4.6085024,0.62358165,,,,,,,,,,,,,, -398500,3.9573636,0.5278222,,,,,,,,,,,,,, -398600,4.3653054,0.6594704,,,,,,,,,,,,,, -398665,,,0.9610171914100648,0.1485028862953186,0.7549200057983398,1.044009804725647,50000.0,0.6309000253677368,1.8216071128845213,10000.0,134701.0166823864,139412.06704068184,134701.0166823864,4677.92427277565,19.43360996246338,0.0 -398700,4.5758586,0.5826915,,,,,,,,,,,,,, -398800,4.4964223,0.7018443,,,,,,,,,,,,,, -398900,3.9882638,0.57869506,,,,,,,,,,,,,, -399000,4.7650585,0.65039223,,,,,,,,,,,,,, -399100,4.4057035,0.63066316,,,,,,,,,,,,,, -399200,4.3225217,0.6177331,,,,,,,,,,,,,, -399300,4.2344136,0.5381106,,,,,,,,,,,,,, -399400,4.6949987,0.65546197,,,,,,,,,,,,,, -399500,4.437369,0.5729178,,,,,,,,,,,,,, -399600,4.515066,0.63738424,,,,,,,,,,,,,, -399700,4.353492,0.56525457,,,,,,,,,,,,,, -399800,4.869704,0.6959687,,,,,,,,,,,,,, -399900,4.861106,0.67767555,,,,,,,,,,,,,, -400000,4.943708,0.61316615,,,,,,,,,,,,,, -400100,4.622196,0.58614516,,,,,,,,,,,,,, -400176,,,0.9601203799247742,0.1494250744581222,0.755079984664917,1.04414701461792,50000.0,0.6319000124931335,1.8218629360198968,10000.0,135211.10509705544,139939.248128891,135211.10509705544,4694.866877555847,19.5312135219574,0.0 -400200,4.7112865,0.6134238,,,,,,,,,,,,,, -400300,4.8428473,0.6672461,,,,,,,,,,,,,, -400400,4.329187,0.6441339,,,,,,,,,,,,,, -400500,4.5041137,0.61356825,,,,,,,,,,,,,, -400600,4.194601,0.5746524,,,,,,,,,,,,,, -400700,4.4045925,0.6140113,,,,,,,,,,,,,, -400800,4.721429,0.62690836,,,,,,,,,,,,,, -400900,5.1335344,0.7087053,,,,,,,,,,,,,, -401000,4.373278,0.65202206,,,,,,,,,,,,,, -401100,4.6467876,0.5379304,,,,,,,,,,,,,, -401200,4.250211,0.624115,,,,,,,,,,,,,, -401300,4.3546014,0.6358684,,,,,,,,,,,,,, -401400,4.5180807,0.6615999,,,,,,,,,,,,,, -401500,4.7715726,0.6114092,,,,,,,,,,,,,, -401600,4.594705,0.6671372,,,,,,,,,,,,,, -401687,,,0.961316168308258,0.1458020508289337,0.7552199959754944,1.0438878536224363,50000.0,0.6308000087738037,1.8203022480010984,10000.0,135721.27109241486,140466.58408665657,135721.27109241486,4711.882912874222,19.63275671005249,0.0 -401700,4.771971,0.60795975,,,,,,,,,,,,,, -401800,3.9738758,0.51660204,,,,,,,,,,,,,, -401900,4.594575,0.6281097,,,,,,,,,,,,,, -402000,4.222952,0.635993,,,,,,,,,,,,,, -402100,4.576642,0.5686087,,,,,,,,,,,,,, -402200,4.3461623,0.6305388,,,,,,,,,,,,,, -402300,4.778272,0.677734,,,,,,,,,,,,,, -402400,4.272832,0.6732649,,,,,,,,,,,,,, -402500,4.73189,0.6689174,,,,,,,,,,,,,, -402600,4.516448,0.56775737,,,,,,,,,,,,,, -402700,4.647284,0.58272076,,,,,,,,,,,,,, -402800,4.3545856,0.6482699,,,,,,,,,,,,,, -402900,4.7229295,0.6341808,,,,,,,,,,,,,, -403000,4.751778,0.6270279,,,,,,,,,,,,,, -403100,4.2326183,0.646436,,,,,,,,,,,,,, -403197,,,0.9598413109779358,0.1497427225112915,0.755079984664917,1.0438313484191897,50000.0,0.6313000321388245,1.821870803833008,10000.0,136231.34775304794,140993.64425182345,136231.34775304794,4728.720160961151,19.728036642074585,0.0 -403200,4.304349,0.5965432,,,,,,,,,,,,,, -403300,4.462982,0.5963653,,,,,,,,,,,,,, -403400,4.6208215,0.6213834,,,,,,,,,,,,,, -403500,4.870523,0.6331761,,,,,,,,,,,,,, -403600,4.1167436,0.62599665,,,,,,,,,,,,,, -403700,4.226539,0.66411024,,,,,,,,,,,,,, -403800,4.8525987,0.6183914,,,,,,,,,,,,,, -403900,4.3339634,0.6089247,,,,,,,,,,,,,, -404000,4.3493037,0.60806865,,,,,,,,,,,,,, -404100,4.2415276,0.60517627,,,,,,,,,,,,,, -404200,4.6266775,0.65086555,,,,,,,,,,,,,, -404300,5.1620283,0.70318544,,,,,,,,,,,,,, -404400,4.434551,0.6269851,,,,,,,,,,,,,, -404500,4.91917,0.65331703,,,,,,,,,,,,,, -404600,4.7037992,0.5822617,,,,,,,,,,,,,, -404700,4.449174,0.62309176,,,,,,,,,,,,,, -404708,,,0.959004282951355,0.1519351452589035,0.7553199529647827,1.0440607070922852,50000.0,0.6308000087738037,1.822520732879639,10000.0,136741.38403439522,141520.8993494511,136741.38403439522,4745.791574001312,19.823182582855225,0.0 -404800,4.187672,0.6269273,,,,,,,,,,,,,, -404900,4.759907,0.6646778,,,,,,,,,,,,,, -405000,4.8776608,0.64452815,,,,,,,,,,,,,, -405100,4.5388036,0.6117851,,,,,,,,,,,,,, -405200,4.673595,0.67196023,,,,,,,,,,,,,, -405300,4.0441346,0.62180537,,,,,,,,,,,,,, -405400,4.563534,0.6285642,,,,,,,,,,,,,, -405500,4.3932514,0.650733,,,,,,,,,,,,,, -405600,4.36591,0.6807729,,,,,,,,,,,,,, -405700,4.608009,0.63009965,,,,,,,,,,,,,, -405800,4.4460034,0.5868863,,,,,,,,,,,,,, -405900,4.7422066,0.64617383,,,,,,,,,,,,,, -406000,4.0914626,0.55704266,,,,,,,,,,,,,, -406100,4.3467646,0.6068214,,,,,,,,,,,,,, -406200,4.6083193,0.6414248,,,,,,,,,,,,,, -406218,,,0.9599609375,0.1494533419609069,0.7551999688148499,1.0442662239074707,50000.0,0.6312000155448914,1.82229483127594,10000.0,137251.34723711014,142047.91184687614,137251.34723711014,4762.691604375839,19.920071363449097,0.0 -406300,4.768581,0.6309322,,,,,,,,,,,,,, -406400,4.5784216,0.66296136,,,,,,,,,,,,,, -406500,4.4526234,0.4985649,,,,,,,,,,,,,, -406600,4.658469,0.60904527,,,,,,,,,,,,,, -406700,4.348334,0.62721074,,,,,,,,,,,,,, -406800,4.1947384,0.62503254,,,,,,,,,,,,,, -406900,4.179964,0.6156179,,,,,,,,,,,,,, -407000,4.5132446,0.6373807,,,,,,,,,,,,,, -407100,4.9662566,0.6929753,,,,,,,,,,,,,, -407200,4.4493084,0.59887356,,,,,,,,,,,,,, -407300,4.567476,0.6558388,,,,,,,,,,,,,, -407400,4.239835,0.61976737,,,,,,,,,,,,,, -407500,4.443603,0.59879243,,,,,,,,,,,,,, -407600,4.177994,0.5936895,,,,,,,,,,,,,, -407700,4.341662,0.63725233,,,,,,,,,,,,,, -407729,,,0.960339605808258,0.1472578048706054,0.7548799514770508,1.044488787651062,50000.0,0.6303000450134277,1.8232182264328003,10000.0,137761.51915931702,142575.1911354065,137761.51915931702,4779.641663074493,20.02484154701233,0.0 -407800,4.5816407,0.56771255,,,,,,,,,,,,,, -407900,4.4167156,0.5362891,,,,,,,,,,,,,, -408000,4.426267,0.6106217,,,,,,,,,,,,,, -408100,4.2055216,0.66096425,,,,,,,,,,,,,, -408200,4.696351,0.6417159,,,,,,,,,,,,,, -408300,4.322048,0.577569,,,,,,,,,,,,,, -408400,4.5943174,0.66564703,,,,,,,,,,,,,, -408500,4.0362597,0.62811255,,,,,,,,,,,,,, -408600,4.34292,0.5975747,,,,,,,,,,,,,, -408700,4.5204983,0.6295747,,,,,,,,,,,,,, -408800,4.967544,0.62465644,,,,,,,,,,,,,, -408900,4.5896387,0.62017137,,,,,,,,,,,,,, -409000,4.2391224,0.631326,,,,,,,,,,,,,, -409100,4.635485,0.6201494,,,,,,,,,,,,,, -409200,4.5980515,0.6934976,,,,,,,,,,,,,, -409239,,,0.9592832922935486,0.1489840149879455,0.7554599642753601,1.0428025722503662,50000.0,0.6307000517845154,1.8203986883163448,10000.0,138271.42797732353,143101.9966094494,138271.42797732353,4796.389009952545,20.122021436691284,0.0 -409300,4.540007,0.5729358,,,,,,,,,,,,,, -409400,4.1062803,0.60595155,,,,,,,,,,,,,, -409500,4.3470297,0.57784826,,,,,,,,,,,,,, -409600,4.6827555,0.6472316,,,,,,,,,,,,,, -409700,5.2338543,0.71391654,,,,,,,,,,,,,, -409800,4.1856856,0.5570706,,,,,,,,,,,,,, -409900,4.61627,0.6117697,,,,,,,,,,,,,, -410000,4.2485847,0.6125877,,,,,,,,,,,,,, -410100,4.519655,0.6200421,,,,,,,,,,,,,, -410200,4.304153,0.5911192,,,,,,,,,,,,,, -410300,4.5886483,0.64039207,,,,,,,,,,,,,, -410400,4.4177876,0.542757,,,,,,,,,,,,,, -410500,4.7288055,0.62854815,,,,,,,,,,,,,, -410600,4.483693,0.6625158,,,,,,,,,,,,,, -410700,4.8365564,0.6265348,,,,,,,,,,,,,, -410749,,,0.959980845451355,0.1499495357275009,0.7550399899482727,1.0459305047988892,50000.0,0.631100058555603,1.8239178657531736,10000.0,138781.3643064499,143629.33147835732,138781.3643064499,4813.636505126953,20.22056555747986,0.0 -410800,4.476941,0.6978173,,,,,,,,,,,,,, -410900,4.4926505,0.66968435,,,,,,,,,,,,,, -411000,4.3666267,0.6725453,,,,,,,,,,,,,, -411100,5.1557865,0.58606076,,,,,,,,,,,,,, -411200,4.770478,0.6144629,,,,,,,,,,,,,, -411300,4.4973574,0.65237653,,,,,,,,,,,,,, -411400,4.802119,0.6392845,,,,,,,,,,,,,, -411500,4.671646,0.6897661,,,,,,,,,,,,,, -411600,4.3335676,0.6284853,,,,,,,,,,,,,, -411700,4.29106,0.58902997,,,,,,,,,,,,,, -411800,4.685607,0.6571746,,,,,,,,,,,,,, -411900,4.76019,0.7442592,,,,,,,,,,,,,, -412000,4.1174645,0.6004632,,,,,,,,,,,,,, -412100,4.3484774,0.614742,,,,,,,,,,,,,, -412200,4.5015793,0.60280365,,,,,,,,,,,,,, -412260,,,0.9607780575752258,0.146344318985939,0.7551800012588501,1.043354630470276,50000.0,0.6308000087738037,1.822560429573059,10000.0,139291.53511548042,144156.41612815857,139291.53511548042,4830.401615142822,20.31696081161499,0.0 -412300,4.166977,0.55160934,,,,,,,,,,,,,, -412400,4.3262076,0.5924499,,,,,,,,,,,,,, -412500,4.1036572,0.5475038,,,,,,,,,,,,,, -412600,4.3283496,0.6740481,,,,,,,,,,,,,, -412700,4.3505487,0.5973386,,,,,,,,,,,,,, -412800,4.5860434,0.63360924,,,,,,,,,,,,,, -412900,4.282053,0.53498614,,,,,,,,,,,,,, -413000,4.719891,0.71983874,,,,,,,,,,,,,, -413100,4.8554316,0.6555357,,,,,,,,,,,,,, -413200,4.5005827,0.68109167,,,,,,,,,,,,,, -413300,4.047581,0.54966027,,,,,,,,,,,,,, -413400,4.762505,0.5992032,,,,,,,,,,,,,, -413500,4.176239,0.56949955,,,,,,,,,,,,,, -413600,4.2745366,0.57943135,,,,,,,,,,,,,, -413700,4.9226317,0.69925785,,,,,,,,,,,,,, -413770,,,0.9599409699440002,0.1482083797454834,0.7553399801254272,1.0445153713226318,50000.0,0.6309000253677368,1.821332693099976,10000.0,139801.42785167694,144683.30664777756,139801.42785167694,4847.201719760895,20.461642026901245,0.0 -413800,4.721216,0.6778228,,,,,,,,,,,,,, -413900,4.5035543,0.5967477,,,,,,,,,,,,,, -414000,4.3309493,0.6095449,,,,,,,,,,,,,, -414100,4.262153,0.6483609,,,,,,,,,,,,,, -414200,4.6267743,0.6601236,,,,,,,,,,,,,, -414300,4.9314785,0.6585349,,,,,,,,,,,,,, -414400,4.257014,0.5766781,,,,,,,,,,,,,, -414500,4.350443,0.63216335,,,,,,,,,,,,,, -414600,4.5958724,0.6134759,,,,,,,,,,,,,, -414700,4.1237893,0.5713672,,,,,,,,,,,,,, -414800,4.477055,0.6078663,,,,,,,,,,,,,, -414900,4.449522,0.579721,,,,,,,,,,,,,, -415000,4.3590226,0.58655196,,,,,,,,,,,,,, -415100,4.692694,0.6110629,,,,,,,,,,,,,, -415200,4.2358727,0.5866782,,,,,,,,,,,,,, -415280,,,0.9595025181770324,0.1520635187625885,0.754859983921051,1.0431699752807615,50000.0,0.6317000389099121,1.820890545845032,10000.0,140311.3820786476,145210.5059850216,140311.3820786476,4864.29502248764,20.562307596206665,0.0 -415300,4.3617535,0.59168595,,,,,,,,,,,,,, -415400,4.513478,0.6496321,,,,,,,,,,,,,, -415500,4.85424,0.6475534,,,,,,,,,,,,,, -415600,4.489136,0.62932676,,,,,,,,,,,,,, -415700,4.1263533,0.5672636,,,,,,,,,,,,,, -415800,4.565803,0.62930477,,,,,,,,,,,,,, -415900,4.1383457,0.6322061,,,,,,,,,,,,,, -416000,4.1972957,0.6000134,,,,,,,,,,,,,, -416100,4.4437575,0.60962355,,,,,,,,,,,,,, -416200,4.9869294,0.61962676,,,,,,,,,,,,,, -416300,4.4024353,0.5955664,,,,,,,,,,,,,, -416400,4.44039,0.6007237,,,,,,,,,,,,,, -416500,4.539361,0.6410562,,,,,,,,,,,,,, -416600,4.3521266,0.65700364,,,,,,,,,,,,,, -416700,4.6498175,0.62585706,,,,,,,,,,,,,, -416790,,,0.961355984210968,0.1469427496194839,0.7549200057983398,1.0446892976760864,50000.0,0.6305000185966492,1.8210554122924805,10000.0,140821.24357557297,145737.5135421753,140821.24357557297,4881.290358066559,20.66101765632629,0.0 -416800,4.7159176,0.6327665,,,,,,,,,,,,,, -416900,4.7070293,0.6223046,,,,,,,,,,,,,, -417000,4.3556113,0.57171094,,,,,,,,,,,,,, -417100,4.6590705,0.6690676,,,,,,,,,,,,,, -417200,5.072076,0.6835351,,,,,,,,,,,,,, -417300,4.7313366,0.6507815,,,,,,,,,,,,,, -417400,3.9773474,0.53903306,,,,,,,,,,,,,, -417500,4.4577293,0.7448588,,,,,,,,,,,,,, -417600,4.2661924,0.6118319,,,,,,,,,,,,,, -417700,4.6086693,0.63624114,,,,,,,,,,,,,, -417800,4.741885,0.72466415,,,,,,,,,,,,,, -417900,4.787756,0.6258498,,,,,,,,,,,,,, -418000,4.5832024,0.5295453,,,,,,,,,,,,,, -418100,4.2984934,0.5749834,,,,,,,,,,,,,, -418200,4.4602504,0.5661655,,,,,,,,,,,,,, -418300,,,0.9614756107330322,0.1455113440752029,0.7551599740982056,1.0434072017669678,50000.0,0.6302000284194946,1.82075834274292,10000.0,141331.1424012184,146264.43515825272,141331.1424012184,4898.1636373996735,20.759366750717163,0.0 -418300,4.5528855,0.6342006,,,,,,,,,,,,,, -418400,4.445088,0.6345971,,,,,,,,,,,,,, -418500,4.646307,0.67615473,,,,,,,,,,,,,, -418600,4.4268684,0.59067035,,,,,,,,,,,,,, -418700,4.258171,0.5541446,,,,,,,,,,,,,, -418800,4.354519,0.5924152,,,,,,,,,,,,,, -418900,4.7882695,0.7640941,,,,,,,,,,,,,, -419000,4.3675165,0.58096963,,,,,,,,,,,,,, -419100,4.10992,0.6061482,,,,,,,,,,,,,, -419200,4.1501584,0.64363235,,,,,,,,,,,,,, -419300,4.507395,0.6294999,,,,,,,,,,,,,, -419400,4.569844,0.5877376,,,,,,,,,,,,,, -419500,4.5498075,0.6149391,,,,,,,,,,,,,, -419600,4.290806,0.6046168,,,,,,,,,,,,,, -419700,4.5883617,0.6178812,,,,,,,,,,,,,, -419800,4.378222,0.59514356,,,,,,,,,,,,,, -419810,,,0.9610769748687744,0.1462026089429855,0.7550999522209167,1.0435361862182615,50000.0,0.631100058555603,1.822310447692871,10000.0,141841.25842308998,146791.6738820076,141841.25842308998,4915.135595321655,20.857829093933105,0.0 -419900,4.7732005,0.6699063,,,,,,,,,,,,,, -420000,4.4534764,0.66178477,,,,,,,,,,,,,, -420100,4.755249,0.65926266,,,,,,,,,,,,,, -420200,4.584262,0.69246966,,,,,,,,,,,,,, -420300,4.5611024,0.5650042,,,,,,,,,,,,,, -420400,4.658801,0.6674179,,,,,,,,,,,,,, -420500,4.42845,0.6178949,,,,,,,,,,,,,, -420600,5.0077376,0.73343056,,,,,,,,,,,,,, -420700,5.1062884,0.60841405,,,,,,,,,,,,,, -420800,4.676058,0.57544374,,,,,,,,,,,,,, -420900,4.530524,0.6108587,,,,,,,,,,,,,, -421000,4.6345415,0.5902994,,,,,,,,,,,,,, -421100,4.898165,0.6038933,,,,,,,,,,,,,, -421200,4.7666287,0.577253,,,,,,,,,,,,,, -421300,4.4038076,0.594032,,,,,,,,,,,,,, -421320,,,0.9599409699440002,0.1498060375452041,0.7549999952316284,1.0440946817398071,50000.0,0.631600022315979,1.8217707872390747,10000.0,142351.206646204,147318.78655314443,142351.206646204,4932.147796154022,20.957899570465088,0.0 -421400,4.4389324,0.61635846,,,,,,,,,,,,,, -421500,4.4759207,0.59559596,,,,,,,,,,,,,, -421600,4.7410336,0.6490212,,,,,,,,,,,,,, -421700,4.773439,0.71574783,,,,,,,,,,,,,, -421800,4.284188,0.59523094,,,,,,,,,,,,,, -421900,4.3304143,0.654602,,,,,,,,,,,,,, -422000,4.522346,0.5711817,,,,,,,,,,,,,, -422100,4.7949123,0.7232605,,,,,,,,,,,,,, -422200,4.782957,0.63104963,,,,,,,,,,,,,, -422300,3.8566725,0.52299726,,,,,,,,,,,,,, -422400,4.339633,0.6205969,,,,,,,,,,,,,, -422500,4.475902,0.6790151,,,,,,,,,,,,,, -422600,4.3501945,0.64592373,,,,,,,,,,,,,, -422700,4.288804,0.6596783,,,,,,,,,,,,,, -422800,4.251955,0.56023407,,,,,,,,,,,,,, -422830,,,0.9587252736091614,0.1514067202806472,0.7552199959754944,1.0427519083023071,50000.0,0.6304000020027161,1.821520566940308,10000.0,142861.21098184586,147846.70860219002,142861.21098184586,4949.913951873779,21.057546615600582,0.0 -422900,4.7186027,0.6606483,,,,,,,,,,,,,, -423000,4.517679,0.6289373,,,,,,,,,,,,,, -423100,4.5723724,0.6378558,,,,,,,,,,,,,, -423200,4.5294027,0.6046874,,,,,,,,,,,,,, -423300,4.9192724,0.6251645,,,,,,,,,,,,,, -423400,4.112966,0.6412577,,,,,,,,,,,,,, -423500,4.438287,0.58994275,,,,,,,,,,,,,, -423600,4.5674067,0.6490365,,,,,,,,,,,,,, -423700,4.398025,0.6151571,,,,,,,,,,,,,, -423800,5.2541394,0.7035185,,,,,,,,,,,,,, -423900,4.2047424,0.55687284,,,,,,,,,,,,,, -424000,4.700328,0.6707076,,,,,,,,,,,,,, -424100,4.556136,0.64351106,,,,,,,,,,,,,, -424200,4.222843,0.6027173,,,,,,,,,,,,,, -424300,4.7728515,0.6253927,,,,,,,,,,,,,, -424340,,,0.960160195827484,0.1487235575914383,0.7555199861526489,1.043978929519653,50000.0,0.6302000284194946,1.822304129600525,10000.0,143371.2946677208,148373.94969177246,143371.2946677208,4966.923537492752,21.153748273849487,0.0 -424400,4.680473,0.607074,,,,,,,,,,,,,, -424500,4.5778446,0.5593998,,,,,,,,,,,,,, -424600,4.711707,0.69059783,,,,,,,,,,,,,, -424700,4.7492304,0.67546415,,,,,,,,,,,,,, -424800,4.640775,0.59336317,,,,,,,,,,,,,, -424900,4.483248,0.6396179,,,,,,,,,,,,,, -425000,4.593361,0.6255828,,,,,,,,,,,,,, -425100,4.708993,0.67193174,,,,,,,,,,,,,, -425200,4.96879,0.6218059,,,,,,,,,,,,,, -425300,4.3314276,0.51724434,,,,,,,,,,,,,, -425400,4.443628,0.63213366,,,,,,,,,,,,,, -425500,4.9933677,0.5833725,,,,,,,,,,,,,, -425600,4.99861,0.6537007,,,,,,,,,,,,,, -425700,4.613189,0.64138937,,,,,,,,,,,,,, -425800,4.9291954,0.6729522,,,,,,,,,,,,,, -425851,,,0.9612165093421936,0.1480762660503387,0.7553399801254272,1.0441192388534546,50000.0,0.6309000253677368,1.820621132850647,10000.0,143881.41513490677,148901.27196621895,143881.41513490677,4983.973979711533,21.25181031227112,0.0 -425900,4.631137,0.6509346,,,,,,,,,,,,,, -426000,5.1984973,0.67205006,,,,,,,,,,,,,, -426100,4.514616,0.6601451,,,,,,,,,,,,,, -426200,4.5237336,0.6605964,,,,,,,,,,,,,, -426300,4.8991613,0.69362277,,,,,,,,,,,,,, -426400,4.3748245,0.5982895,,,,,,,,,,,,,, -426500,4.3482585,0.587175,,,,,,,,,,,,,, -426600,4.621985,0.66958135,,,,,,,,,,,,,, -426700,4.394517,0.62083656,,,,,,,,,,,,,, -426800,4.785831,0.566207,,,,,,,,,,,,,, -426900,4.359454,0.6318834,,,,,,,,,,,,,, -427000,4.9230514,0.7493447,,,,,,,,,,,,,, -427100,4.6827626,0.6704895,,,,,,,,,,,,,, -427200,4.2208695,0.63238007,,,,,,,,,,,,,, -427300,4.3369703,0.5898758,,,,,,,,,,,,,, -427362,,,0.9609375,0.1451647132635116,0.7549399733543396,1.0440856218338013,50000.0,0.631100058555603,1.8233418464660645,10000.0,144391.50057291985,149428.4954817295,144391.50057291985,5000.952465295792,21.35924863815308,0.0 -427400,4.435294,0.6296518,,,,,,,,,,,,,, -427500,4.7737846,0.61024463,,,,,,,,,,,,,, -427600,4.464138,0.67389035,,,,,,,,,,,,,, -427700,4.4859,0.61372656,,,,,,,,,,,,,, -427800,4.2518334,0.6006979,,,,,,,,,,,,,, -427900,4.498962,0.5436453,,,,,,,,,,,,,, -428000,4.366846,0.65495956,,,,,,,,,,,,,, -428100,4.1146,0.62851495,,,,,,,,,,,,,, -428200,4.6477594,0.69205076,,,,,,,,,,,,,, -428300,4.4339533,0.6189941,,,,,,,,,,,,,, -428400,4.372926,0.6155182,,,,,,,,,,,,,, -428500,4.5153556,0.6523007,,,,,,,,,,,,,, -428600,4.227442,0.5958563,,,,,,,,,,,,,, -428700,4.5463023,0.64926875,,,,,,,,,,,,,, -428800,4.4484534,0.5832281,,,,,,,,,,,,,, -428872,,,0.9600605964660645,0.1478695422410965,0.7552399635314941,1.0440679788589478,50000.0,0.6304000020027161,1.8220691680908203,10000.0,144901.63012313843,149955.54619264603,144901.63012313843,5017.714422941208,21.466845512390137,0.0 -428900,4.651461,0.6918337,,,,,,,,,,,,,, -429000,4.4409237,0.5961128,,,,,,,,,,,,,, -429100,4.440268,0.59558976,,,,,,,,,,,,,, -429200,4.1426296,0.5447028,,,,,,,,,,,,,, -429300,4.3862743,0.6800328,,,,,,,,,,,,,, -429400,4.7925777,0.659501,,,,,,,,,,,,,, -429500,4.222863,0.6480671,,,,,,,,,,,,,, -429600,4.81757,0.5841405,,,,,,,,,,,,,, -429700,4.95387,0.60996324,,,,,,,,,,,,,, -429800,4.520865,0.5901,,,,,,,,,,,,,, -429900,4.31784,0.6500257,,,,,,,,,,,,,, -430000,4.816654,0.69279444,,,,,,,,,,,,,, -430100,5.064945,0.6671941,,,,,,,,,,,,,, -430200,4.2014117,0.5796977,,,,,,,,,,,,,, -430300,4.768198,0.6949895,,,,,,,,,,,,,, -430383,,,0.9606983065605164,0.1471541225910186,0.7550599575042725,1.0442098379135132,50000.0,0.6314000487327576,1.8212717771530151,10000.0,145411.6282696724,150482.45879983902,145411.6282696724,5034.476024627686,21.568559885025024,0.0 -430400,4.970178,0.63751566,,,,,,,,,,,,,, -430500,5.3354006,0.70873487,,,,,,,,,,,,,, -430600,4.2421174,0.6005327,,,,,,,,,,,,,, -430700,4.3253765,0.6017178,,,,,,,,,,,,,, -430800,4.7600183,0.68304956,,,,,,,,,,,,,, -430900,4.677319,0.6471684,,,,,,,,,,,,,, -431000,4.471074,0.5967018,,,,,,,,,,,,,, -431100,4.3920474,0.580353,,,,,,,,,,,,,, -431200,4.364339,0.56166005,,,,,,,,,,,,,, -431300,4.178321,0.5675106,,,,,,,,,,,,,, -431400,4.5026045,0.6180424,,,,,,,,,,,,,, -431500,4.1755614,0.6062121,,,,,,,,,,,,,, -431600,4.2145495,0.541074,,,,,,,,,,,,,, -431700,4.5125656,0.6366811,,,,,,,,,,,,,, -431800,4.384005,0.5946047,,,,,,,,,,,,,, -431894,,,0.9595025181770324,0.1516856253147125,0.7554199695587158,1.043545842170715,50000.0,0.6308000087738037,1.8222107887268064,10000.0,145921.7688229084,151009.8665678501,145921.7688229084,5051.591791629791,21.6670138835907,0.0 -431900,4.811084,0.64988816,,,,,,,,,,,,,, -432000,4.7932053,0.6870843,,,,,,,,,,,,,, -432100,4.412971,0.6170212,,,,,,,,,,,,,, -432200,4.313453,0.60522795,,,,,,,,,,,,,, -432300,4.503416,0.64536643,,,,,,,,,,,,,, -432400,4.2234106,0.66509044,,,,,,,,,,,,,, -432500,4.148091,0.58629036,,,,,,,,,,,,,, -432600,4.6328535,0.68312836,,,,,,,,,,,,,, -432700,4.6389966,0.67397726,,,,,,,,,,,,,, -432800,4.5637,0.5488653,,,,,,,,,,,,,, -432900,4.5629478,0.6336198,,,,,,,,,,,,,, -433000,4.4369807,0.63005364,,,,,,,,,,,,,, -433100,4.74303,0.677418,,,,,,,,,,,,,, -433200,4.007694,0.6041939,,,,,,,,,,,,,, -433300,4.717128,0.66985124,,,,,,,,,,,,,, -433400,4.6747055,0.64910126,,,,,,,,,,,,,, -433404,,,0.962312638759613,0.1417900472879409,0.755299985408783,1.044046401977539,50000.0,0.6315000057220459,1.8215786218643188,10000.0,146431.67700624466,151536.72223758698,146431.67700624466,5068.386824846268,21.767746925354004,0.0 -433500,4.4218307,0.6469027,,,,,,,,,,,,,, -433600,4.627833,0.6705167,,,,,,,,,,,,,, -433700,4.527007,0.6169697,,,,,,,,,,,,,, -433800,4.4367633,0.57879066,,,,,,,,,,,,,, -433900,4.7915735,0.6688585,,,,,,,,,,,,,, -434000,4.7565327,0.5967206,,,,,,,,,,,,,, -434100,4.6616964,0.6623956,,,,,,,,,,,,,, -434200,4.7459245,0.664428,,,,,,,,,,,,,, -434300,4.3066688,0.5985158,,,,,,,,,,,,,, -434400,4.2491493,0.59518194,,,,,,,,,,,,,, -434500,4.3606358,0.6257951,,,,,,,,,,,,,, -434600,4.521359,0.6108606,,,,,,,,,,,,,, -434700,4.268964,0.6071358,,,,,,,,,,,,,, -434800,4.8663125,0.6645062,,,,,,,,,,,,,, -434900,4.4374714,0.5673142,,,,,,,,,,,,,, -434914,,,0.9612762928009032,0.1439260393381118,0.7549200057983398,1.0445345640182495,50000.0,0.631100058555603,1.823761820793152,10000.0,146941.73819971085,152063.90723013878,146941.73819971085,5085.313814640045,21.91230607032776,0.0 -435000,4.4816036,0.5633677,,,,,,,,,,,,,, -435100,5.105392,0.64359576,,,,,,,,,,,,,, -435200,4.3037033,0.6026437,,,,,,,,,,,,,, -435300,4.5173864,0.6710467,,,,,,,,,,,,,, -435400,4.8771315,0.64002025,,,,,,,,,,,,,, -435500,4.190773,0.5682702,,,,,,,,,,,,,, -435600,4.113007,0.5622791,,,,,,,,,,,,,, -435700,4.5132084,0.6842739,,,,,,,,,,,,,, -435800,4.560415,0.75151783,,,,,,,,,,,,,, -435900,4.46318,0.6355714,,,,,,,,,,,,,, -436000,4.105275,0.56312364,,,,,,,,,,,,,, -436100,4.6721287,0.69798374,,,,,,,,,,,,,, -436200,4.246842,0.64391166,,,,,,,,,,,,,, -436300,4.4162216,0.61227685,,,,,,,,,,,,,, -436400,4.167028,0.6469806,,,,,,,,,,,,,, -436424,,,0.959203600883484,0.1486395001411438,0.7547999620437622,1.045411467552185,50000.0,0.6307000517845154,1.8222869634628296,10000.0,147451.59631967545,152590.7104575634,147451.59631967545,5102.110006809235,22.009751081466675,0.0 -436500,4.7242074,0.58877647,,,,,,,,,,,,,, -436600,4.696057,0.6324504,,,,,,,,,,,,,, -436700,4.589883,0.67625153,,,,,,,,,,,,,, -436800,5.000338,0.6503765,,,,,,,,,,,,,, -436900,4.529641,0.58110046,,,,,,,,,,,,,, -437000,4.931578,0.6551088,,,,,,,,,,,,,, -437100,4.5730357,0.63287425,,,,,,,,,,,,,, -437200,4.7357917,0.6322286,,,,,,,,,,,,,, -437300,4.8061004,0.70534456,,,,,,,,,,,,,, -437400,4.443447,0.6546807,,,,,,,,,,,,,, -437500,4.525677,0.6408353,,,,,,,,,,,,,, -437600,4.537896,0.67988634,,,,,,,,,,,,,, -437700,4.2860937,0.6527345,,,,,,,,,,,,,, -437800,4.232167,0.6348868,,,,,,,,,,,,,, -437900,4.6398644,0.6387717,,,,,,,,,,,,,, -437934,,,0.9607381820678712,0.1503721624612808,0.7549799680709839,1.042616367340088,50000.0,0.6306000351905823,1.8189860582351685,10000.0,147961.5187318325,153117.63931131363,147961.5187318325,5118.958813905716,22.115402460098267,0.0 -438000,4.68805,0.58625436,,,,,,,,,,,,,, -438100,4.7454453,0.6365121,,,,,,,,,,,,,, -438200,4.4010615,0.6857103,,,,,,,,,,,,,, -438300,4.4502645,0.65545297,,,,,,,,,,,,,, -438400,4.377873,0.66090125,,,,,,,,,,,,,, -438500,4.3610682,0.58486766,,,,,,,,,,,,,, -438600,4.6263065,0.6910142,,,,,,,,,,,,,, -438700,4.557037,0.6130806,,,,,,,,,,,,,, -438800,4.8745418,0.5756223,,,,,,,,,,,,,, -438900,4.1788454,0.60515326,,,,,,,,,,,,,, -439000,5.0421023,0.6174725,,,,,,,,,,,,,, -439100,5.211868,0.69420505,,,,,,,,,,,,,, -439200,4.2291946,0.60694826,,,,,,,,,,,,,, -439300,4.3138933,0.583806,,,,,,,,,,,,,, -439400,4.575766,0.70579785,,,,,,,,,,,,,, -439444,,,0.9608577489852904,0.1458749920129776,0.7550199627876282,1.043423771858215,50000.0,0.6320000290870667,1.8204666376113887,10000.0,148471.39300227165,153644.87471556664,148471.39300227165,5136.166883468628,22.215853929519653,0.0 -439500,4.5485425,0.6816336,,,,,,,,,,,,,, -439600,4.3199973,0.6572479,,,,,,,,,,,,,, -439700,4.716182,0.60110646,,,,,,,,,,,,,, -439800,4.237199,0.60961056,,,,,,,,,,,,,, -439900,4.5398574,0.6260078,,,,,,,,,,,,,, -440000,4.2614594,0.66010994,,,,,,,,,,,,,, -440100,4.4363494,0.6251469,,,,,,,,,,,,,, -440200,4.4685335,0.63238263,,,,,,,,,,,,,, -440300,3.65708,0.48898518,,,,,,,,,,,,,, -440400,4.259062,0.6670416,,,,,,,,,,,,,, -440500,4.0778418,0.61237574,,,,,,,,,,,,,, -440600,4.1318493,0.62205493,,,,,,,,,,,,,, -440700,4.3911357,0.62363267,,,,,,,,,,,,,, -440800,4.3587227,0.6522771,,,,,,,,,,,,,, -440900,4.2025833,0.57116944,,,,,,,,,,,,,, -440954,,,0.961575210094452,0.1486123651266098,0.755299985408783,1.0435887575149536,50000.0,0.6302000284194946,1.821756720542908,10000.0,148981.3084001541,154171.9116613865,148981.3084001541,5153.135057687759,22.316500425338745,0.0 -441000,4.7623577,0.66463625,,,,,,,,,,,,,, -441100,4.7079377,0.62526596,,,,,,,,,,,,,, -441200,4.7144547,0.6720555,,,,,,,,,,,,,, -441300,4.623484,0.6614638,,,,,,,,,,,,,, -441400,4.5521374,0.6369472,,,,,,,,,,,,,, -441500,4.1056013,0.5797673,,,,,,,,,,,,,, -441600,4.3836055,0.59906256,,,,,,,,,,,,,, -441700,4.4890394,0.62316215,,,,,,,,,,,,,, -441800,4.194842,0.61976504,,,,,,,,,,,,,, -441900,4.4888754,0.62304944,,,,,,,,,,,,,, -442000,4.72286,0.64333147,,,,,,,,,,,,,, -442100,4.1147156,0.62572825,,,,,,,,,,,,,, -442200,4.371934,0.6522909,,,,,,,,,,,,,, -442300,4.2763257,0.58651954,,,,,,,,,,,,,, -442400,4.437245,0.59703845,,,,,,,,,,,,,, -442464,,,0.9580875039100648,0.1526551395654678,0.7554399967193604,1.0435420274734497,50000.0,0.631100058555603,1.8232500553131104,10000.0,149491.16506505013,154698.7199485302,149491.16506505013,5169.926432609558,22.424886465072632,0.0 -442500,4.6912007,0.6872738,,,,,,,,,,,,,, -442600,4.5341706,0.63129747,,,,,,,,,,,,,, -442700,4.4146223,0.6199265,,,,,,,,,,,,,, -442800,4.74303,0.5566675,,,,,,,,,,,,,, -442900,4.8754888,0.608374,,,,,,,,,,,,,, -443000,4.720038,0.7088319,,,,,,,,,,,,,, -443100,4.7790747,0.71686035,,,,,,,,,,,,,, -443200,4.1764255,0.59896237,,,,,,,,,,,,,, -443300,4.744662,0.6927772,,,,,,,,,,,,,, -443400,4.334228,0.5864405,,,,,,,,,,,,,, -443500,3.997075,0.59925425,,,,,,,,,,,,,, -443600,5.60605,0.6315466,,,,,,,,,,,,,, -443700,4.582335,0.6357766,,,,,,,,,,,,,, -443800,4.500127,0.61277723,,,,,,,,,,,,,, -443900,4.674364,0.57338095,,,,,,,,,,,,,, -443974,,,0.959582269191742,0.1498272866010666,0.7555999755859375,1.043811559677124,50000.0,0.6308000087738037,1.8222272396087649,10000.0,150001.00571346283,155225.45688009262,150001.00571346283,5186.662141799927,22.53334045410156,0.0 -444000,4.59665,0.6158542,,,,,,,,,,,,,, -444100,4.3087335,0.6100324,,,,,,,,,,,,,, -444200,4.7858367,0.6762744,,,,,,,,,,,,,, -444300,4.385029,0.6141155,,,,,,,,,,,,,, -444400,4.4415584,0.66228753,,,,,,,,,,,,,, -444500,4.979688,0.6541133,,,,,,,,,,,,,, -444600,4.631937,0.6280835,,,,,,,,,,,,,, -444700,4.6798725,0.6696298,,,,,,,,,,,,,, -444800,4.2878118,0.58354896,,,,,,,,,,,,,, -444900,4.8216977,0.6470627,,,,,,,,,,,,,, -445000,4.6483235,0.6730046,,,,,,,,,,,,,, -445100,4.5371675,0.6198783,,,,,,,,,,,,,, -445200,4.5160837,0.6077157,,,,,,,,,,,,,, -445300,4.393507,0.6280004,,,,,,,,,,,,,, -445400,4.553101,0.6112727,,,,,,,,,,,,,, -445485,,,0.9597018361091614,0.1487534046173095,0.7551199793815613,1.0437120199203491,50000.0,0.6308000087738037,1.8205363750457764,10000.0,150511.18306398392,155752.8000433445,150511.18306398392,5203.670311450958,22.639395475387573,0.0 -445500,4.2830553,0.66307086,,,,,,,,,,,,,, -445600,4.786008,0.59832567,,,,,,,,,,,,,, -445700,4.3799987,0.673757,,,,,,,,,,,,,, -445800,4.6827335,0.671816,,,,,,,,,,,,,, -445900,4.499504,0.6516613,,,,,,,,,,,,,, -446000,4.343484,0.57658505,,,,,,,,,,,,,, -446100,4.813288,0.67150426,,,,,,,,,,,,,, -446200,5.1594586,0.7239826,,,,,,,,,,,,,, -446300,4.968429,0.65672064,,,,,,,,,,,,,, -446400,4.3754606,0.67108834,,,,,,,,,,,,,, -446500,4.4667435,0.5890222,,,,,,,,,,,,,, -446600,4.331785,0.59993845,,,,,,,,,,,,,, -446700,4.276053,0.5667862,,,,,,,,,,,,,, -446800,4.041396,0.5406856,,,,,,,,,,,,,, -446900,4.5474596,0.6475615,,,,,,,,,,,,,, -446996,,,0.9610769748687744,0.1466841101646423,0.7549399733543396,1.0435166358947754,50000.0,0.6314000487327576,1.819707155227661,10000.0,151021.32336831093,156280.19078469276,151021.32336831093,5220.766705274582,22.740891456604004,0.0 -447000,5.209372,0.68497896,,,,,,,,,,,,,, -447100,4.5645156,0.66184425,,,,,,,,,,,,,, -447200,4.6530347,0.6557983,,,,,,,,,,,,,, -447300,5.0646477,0.6177816,,,,,,,,,,,,,, -447400,4.872377,0.6917193,,,,,,,,,,,,,, -447500,4.853247,0.6109743,,,,,,,,,,,,,, -447600,4.5357256,0.55966735,,,,,,,,,,,,,, -447700,5.0256696,0.64091104,,,,,,,,,,,,,, -447800,4.2662516,0.57841766,,,,,,,,,,,,,, -447900,4.5634303,0.56123877,,,,,,,,,,,,,, -448000,4.3082366,0.6470578,,,,,,,,,,,,,, -448100,4.117838,0.547773,,,,,,,,,,,,,, -448200,4.8189673,0.7522768,,,,,,,,,,,,,, -448300,4.906438,0.6146184,,,,,,,,,,,,,, -448400,4.8625493,0.6192756,,,,,,,,,,,,,, -448500,4.5611877,0.6210331,,,,,,,,,,,,,, -448506,,,0.9595623016357422,0.1505237370729446,0.7554399967193604,1.044321060180664,50000.0,0.6306000351905823,1.8209294080734253,10000.0,151531.32101392746,156807.316532135,151531.32101392746,5237.740381240845,22.843069553375244,0.0 -448600,4.4705153,0.6153752,,,,,,,,,,,,,, -448700,4.548991,0.6614553,,,,,,,,,,,,,, -448800,4.8420315,0.674721,,,,,,,,,,,,,, -448900,4.578206,0.7087894,,,,,,,,,,,,,, -449000,4.537894,0.611408,,,,,,,,,,,,,, -449100,4.51588,0.67454624,,,,,,,,,,,,,, -449200,4.5024695,0.6229363,,,,,,,,,,,,,, -449300,4.608154,0.6205802,,,,,,,,,,,,,, -449400,4.356547,0.6336899,,,,,,,,,,,,,, -449500,4.9947934,0.638339,,,,,,,,,,,,,, -449600,4.813056,0.6625432,,,,,,,,,,,,,, -449700,4.64688,0.695821,,,,,,,,,,,,,, -449800,4.221043,0.6067451,,,,,,,,,,,,,, -449900,5.0538616,0.68215847,,,,,,,,,,,,,, -450000,4.4925404,0.6188442,,,,,,,,,,,,,, -450016,,,0.959741711616516,0.1485694497823715,0.7551199793815613,1.0434327125549316,50000.0,0.6318000555038452,1.8213356733322144,10000.0,152041.26451206207,157334.1936788559,152041.26451206207,5254.520807504654,22.94432306289673,0.0 -450100,4.8559747,0.68208265,,,,,,,,,,,,,, -450200,4.4331493,0.59896064,,,,,,,,,,,,,, -450300,4.5907283,0.68095964,,,,,,,,,,,,,, -450400,4.340295,0.5846625,,,,,,,,,,,,,, -450500,4.7629633,0.6913421,,,,,,,,,,,,,, -450600,4.806151,0.69169146,,,,,,,,,,,,,, -450700,4.7332487,0.61212844,,,,,,,,,,,,,, -450800,4.303297,0.64912397,,,,,,,,,,,,,, -450900,4.5260744,0.5894461,,,,,,,,,,,,,, -451000,4.4377394,0.6647639,,,,,,,,,,,,,, -451100,4.4192977,0.5829366,,,,,,,,,,,,,, -451200,4.367954,0.63968176,,,,,,,,,,,,,, -451300,4.3754215,0.565546,,,,,,,,,,,,,, -451400,4.167916,0.57492155,,,,,,,,,,,,,, -451500,4.338763,0.66784996,,,,,,,,,,,,,, -451526,,,0.9610969424247742,0.147056832909584,0.7548999786376953,1.0432263612747192,50000.0,0.6309000253677368,1.8225295543670648,10000.0,152551.20199108124,157861.52579021454,152551.20199108124,5271.759593009949,23.047905683517456,0.0 -451600,4.671699,0.586231,,,,,,,,,,,,,, -451700,4.134428,0.6182591,,,,,,,,,,,,,, -451800,4.8299255,0.5959828,,,,,,,,,,,,,, -451900,4.5068398,0.62605953,,,,,,,,,,,,,, -452000,4.3805356,0.6263905,,,,,,,,,,,,,, -452100,4.2275314,0.56187475,,,,,,,,,,,,,, -452200,4.0828266,0.53559005,,,,,,,,,,,,,, -452300,4.839177,0.61624396,,,,,,,,,,,,,, -452400,4.308647,0.6455163,,,,,,,,,,,,,, -452500,4.532425,0.64172214,,,,,,,,,,,,,, -452600,4.971586,0.6085543,,,,,,,,,,,,,, -452700,3.9909315,0.54785466,,,,,,,,,,,,,, -452800,4.5364637,0.66852766,,,,,,,,,,,,,, -452900,4.088557,0.5805135,,,,,,,,,,,,,, -453000,4.7298164,0.71875614,,,,,,,,,,,,,, -453036,,,0.9601004123687744,0.1490989327430725,0.7547999620437622,1.044902205467224,50000.0,0.6308000087738037,1.8207491636276243,10000.0,153061.21781492233,158388.60536050797,153061.21781492233,5288.670228481293,23.14963865280152,0.0 -453100,4.5397186,0.59900296,,,,,,,,,,,,,, -453200,4.1902575,0.56772304,,,,,,,,,,,,,, -453300,4.1212926,0.6137784,,,,,,,,,,,,,, -453400,4.598694,0.68392533,,,,,,,,,,,,,, -453500,4.7419066,0.6633448,,,,,,,,,,,,,, -453600,4.524547,0.6045822,,,,,,,,,,,,,, -453700,4.80326,0.6021329,,,,,,,,,,,,,, -453800,4.6943226,0.64310753,,,,,,,,,,,,,, -453900,4.678102,0.6449115,,,,,,,,,,,,,, -454000,4.9882274,0.6799308,,,,,,,,,,,,,, -454100,4.8876033,0.5836764,,,,,,,,,,,,,, -454200,4.790571,0.69150716,,,,,,,,,,,,,, -454300,4.1462865,0.5819619,,,,,,,,,,,,,, -454400,4.9244843,0.67345965,,,,,,,,,,,,,, -454500,4.33959,0.60273445,,,,,,,,,,,,,, -454546,,,0.9602199792861938,0.1482055038213729,0.7547999620437622,1.0442836284637451,50000.0,0.6305000185966492,1.8212093114852903,10000.0,153571.1104207039,158915.745316267,153571.1104207039,5305.7620232105255,23.25257182121277,0.0 -454600,4.6270576,0.6052154,,,,,,,,,,,,,, -454700,4.313234,0.665584,,,,,,,,,,,,,, -454800,4.9167156,0.6989806,,,,,,,,,,,,,, -454900,4.2506504,0.5929562,,,,,,,,,,,,,, -455000,4.2952952,0.5476222,,,,,,,,,,,,,, -455100,4.688452,0.7151667,,,,,,,,,,,,,, -455200,5.009413,0.64405555,,,,,,,,,,,,,, -455300,4.3851595,0.6164532,,,,,,,,,,,,,, -455400,4.3682795,0.5722836,,,,,,,,,,,,,, -455500,4.1518674,0.5634035,,,,,,,,,,,,,, -455600,4.9530573,0.73340845,,,,,,,,,,,,,, -455700,4.590018,0.58572894,,,,,,,,,,,,,, -455800,4.678803,0.64002776,,,,,,,,,,,,,, -455900,4.761735,0.59150445,,,,,,,,,,,,,, -456000,4.49397,0.62931275,,,,,,,,,,,,,, -456056,,,0.960379421710968,0.1483684480190277,0.7552599906921387,1.0436128377914429,50000.0,0.6319000124931335,1.8223989009857176,10000.0,154081.2016234398,159442.92390727997,154081.2016234398,5322.696059703827,23.35454678535461,0.0 -456100,4.4676394,0.6757488,,,,,,,,,,,,,, -456200,4.381585,0.575196,,,,,,,,,,,,,, -456300,5.505338,0.72513163,,,,,,,,,,,,,, -456400,4.4626565,0.6019847,,,,,,,,,,,,,, -456500,4.4867682,0.61453164,,,,,,,,,,,,,, -456600,4.3122234,0.5667029,,,,,,,,,,,,,, -456700,3.9818802,0.52165335,,,,,,,,,,,,,, -456800,4.4341655,0.6083939,,,,,,,,,,,,,, -456900,5.0144234,0.630636,,,,,,,,,,,,,, -457000,4.477147,0.5948339,,,,,,,,,,,,,, -457100,4.4798636,0.5739906,,,,,,,,,,,,,, -457200,4.2591105,0.6014184,,,,,,,,,,,,,, -457300,4.4499683,0.616297,,,,,,,,,,,,,, -457400,4.4338765,0.63905674,,,,,,,,,,,,,, -457500,4.9723654,0.64914405,,,,,,,,,,,,,, -457567,,,0.9609375,0.1461519598960876,0.7552799582481384,1.0445982217788696,50000.0,0.631100058555603,1.82321572303772,10000.0,154591.28657960892,159969.99428796768,154591.28657960892,5339.510857105255,23.473296642303467,0.0 -457600,4.51712,0.59715086,,,,,,,,,,,,,, -457700,4.2098794,0.5786241,,,,,,,,,,,,,, -457800,4.2666464,0.62430954,,,,,,,,,,,,,, -457900,4.946792,0.602278,,,,,,,,,,,,,, -458000,4.3777046,0.59690976,,,,,,,,,,,,,, -458100,4.5398793,0.6202281,,,,,,,,,,,,,, -458200,4.298093,0.5655092,,,,,,,,,,,,,, -458300,4.6015177,0.67248285,,,,,,,,,,,,,, -458400,4.7177033,0.665475,,,,,,,,,,,,,, -458500,4.481178,0.6392075,,,,,,,,,,,,,, -458600,4.1300073,0.5795215,,,,,,,,,,,,,, -458700,4.5155106,0.5804714,,,,,,,,,,,,,, -458800,4.121764,0.59494084,,,,,,,,,,,,,, -458900,4.558913,0.64972043,,,,,,,,,,,,,, -459000,4.9114122,0.6071475,,,,,,,,,,,,,, -459078,,,0.9615553021430968,0.1467585861682891,0.7550399899482727,1.0421971082687378,50000.0,0.6315000057220459,1.8196227550506592,10000.0,155101.33123731613,160497.07613682747,155101.33123731613,5356.392268419266,23.5770070552826,0.0 -459100,4.8597674,0.71395206,,,,,,,,,,,,,, -459200,4.7624125,0.67953926,,,,,,,,,,,,,, -459300,4.299633,0.66582274,,,,,,,,,,,,,, -459400,4.3378367,0.59539276,,,,,,,,,,,,,, -459500,4.61659,0.6494104,,,,,,,,,,,,,, -459600,4.5222063,0.6365353,,,,,,,,,,,,,, -459700,4.293212,0.6209613,,,,,,,,,,,,,, -459800,4.806662,0.6028368,,,,,,,,,,,,,, -459900,4.9729624,0.6299975,,,,,,,,,,,,,, -460000,4.526882,0.6382036,,,,,,,,,,,,,, -460100,4.6374125,0.6751098,,,,,,,,,,,,,, -460200,4.5158477,0.5539241,,,,,,,,,,,,,, -460300,4.7934833,0.65834475,,,,,,,,,,,,,, -460400,4.4155464,0.597259,,,,,,,,,,,,,, -460500,4.754782,0.58269024,,,,,,,,,,,,,, -460588,,,0.9595623016357422,0.148551806807518,0.755079984664917,1.0454641580581665,50000.0,0.6309000253677368,1.822048306465149,10000.0,155611.23100996015,161024.65453124046,155611.23100996015,5373.915447711945,23.67885732650757,0.0 -460600,5.154563,0.66481173,,,,,,,,,,,,,, -460700,3.9596636,0.5082583,,,,,,,,,,,,,, -460800,4.523662,0.59812194,,,,,,,,,,,,,, -460900,4.8735046,0.64155734,,,,,,,,,,,,,, -461000,4.6164913,0.66348124,,,,,,,,,,,,,, -461100,4.2102423,0.58135855,,,,,,,,,,,,,, -461200,4.498002,0.65387434,,,,,,,,,,,,,, -461300,4.23725,0.5787692,,,,,,,,,,,,,, -461400,4.4739614,0.58142847,,,,,,,,,,,,,, -461500,4.1028094,0.579692,,,,,,,,,,,,,, -461600,4.7225895,0.6762119,,,,,,,,,,,,,, -461700,4.2496443,0.562617,,,,,,,,,,,,,, -461800,4.8363233,0.5523692,,,,,,,,,,,,,, -461900,4.5093765,0.6559304,,,,,,,,,,,,,, -462000,4.6069016,0.66250986,,,,,,,,,,,,,, -462099,,,0.959363043308258,0.1515534073114395,0.7554000020027161,1.0442078113555908,50000.0,0.6312000155448914,1.821954727172852,10000.0,156121.33530831337,161551.77962732315,156121.33530831337,5390.780106306076,23.7828586101532,0.0 -462100,4.3017344,0.54987985,,,,,,,,,,,,,, -462200,4.766195,0.63529015,,,,,,,,,,,,,, -462300,4.676579,0.6677235,,,,,,,,,,,,,, -462400,4.537385,0.6216227,,,,,,,,,,,,,, -462500,4.282032,0.5512028,,,,,,,,,,,,,, -462600,4.8994274,0.6169885,,,,,,,,,,,,,, -462700,4.5158205,0.63465595,,,,,,,,,,,,,, -462800,4.084899,0.6503573,,,,,,,,,,,,,, -462900,4.427699,0.5874765,,,,,,,,,,,,,, -463000,4.005652,0.5514161,,,,,,,,,,,,,, -463100,4.2448864,0.58615327,,,,,,,,,,,,,, -463200,4.4582634,0.65784955,,,,,,,,,,,,,, -463300,4.367999,0.60979027,,,,,,,,,,,,,, -463400,4.7048273,0.66530335,,,,,,,,,,,,,, -463500,4.8909583,0.6695914,,,,,,,,,,,,,, -463600,4.6981115,0.6572729,,,,,,,,,,,,,, -463609,,,0.9598413109779358,0.1484057307243347,0.7550399899482727,1.043148398399353,50000.0,0.6315000057220459,1.8227782249450684,10000.0,156631.41730761528,162078.88946652412,156631.41730761528,5407.651072978973,23.887704133987427,0.0 -463700,4.641538,0.669632,,,,,,,,,,,,,, -463800,4.3425593,0.64808196,,,,,,,,,,,,,, -463900,4.1566625,0.6348136,,,,,,,,,,,,,, -464000,4.5645356,0.62707853,,,,,,,,,,,,,, -464100,4.545432,0.57678735,,,,,,,,,,,,,, -464200,5.02023,0.60667884,,,,,,,,,,,,,, -464300,4.4387436,0.63267326,,,,,,,,,,,,,, -464400,4.8555956,0.61384046,,,,,,,,,,,,,, -464500,4.6814213,0.6354305,,,,,,,,,,,,,, -464600,4.8509917,0.60743725,,,,,,,,,,,,,, -464700,4.2782626,0.6467118,,,,,,,,,,,,,, -464800,4.392557,0.6363758,,,,,,,,,,,,,, -464900,4.5054474,0.576819,,,,,,,,,,,,,, -465000,4.997661,0.64745796,,,,,,,,,,,,,, -465100,4.6979036,0.61177397,,,,,,,,,,,,,, -465119,,,0.9618741869926452,0.1466239541769027,0.755079984664917,1.044501543045044,50000.0,0.6305000185966492,1.8231332302093504,10000.0,157141.29112553596,162605.82205343246,157141.29112553596,5424.55797290802,23.98750042915344,0.0 -465200,5.7040076,0.6937678,,,,,,,,,,,,,, -465300,4.729267,0.6270002,,,,,,,,,,,,,, -465400,4.2890544,0.62064296,,,,,,,,,,,,,, -465500,4.666466,0.6570261,,,,,,,,,,,,,, -465600,4.4713044,0.6498954,,,,,,,,,,,,,, -465700,4.792424,0.6917846,,,,,,,,,,,,,, -465800,4.8634214,0.62689155,,,,,,,,,,,,,, -465900,4.1743245,0.64732844,,,,,,,,,,,,,, -466000,4.2496185,0.6765598,,,,,,,,,,,,,, -466100,4.277641,0.5943781,,,,,,,,,,,,,, -466200,4.4774413,0.6655013,,,,,,,,,,,,,, -466300,4.43381,0.56834596,,,,,,,,,,,,,, -466400,4.440812,0.6245196,,,,,,,,,,,,,, -466500,4.424454,0.6285813,,,,,,,,,,,,,, -466600,4.535805,0.68682307,,,,,,,,,,,,,, -466629,,,0.9601004123687744,0.1466823518276214,0.7554199695587158,1.043945550918579,50000.0,0.6304000020027161,1.822500228881836,10000.0,157651.32108592987,163132.91656827927,157651.32108592987,5441.46524143219,24.0924711227417,0.0 -466700,4.250821,0.6595629,,,,,,,,,,,,,, -466800,5.069524,0.71722513,,,,,,,,,,,,,, -466900,4.4066405,0.5819845,,,,,,,,,,,,,, -467000,4.1623464,0.60603786,,,,,,,,,,,,,, -467100,4.7010446,0.6658689,,,,,,,,,,,,,, -467200,4.825895,0.6374433,,,,,,,,,,,,,, -467300,4.7625146,0.6649077,,,,,,,,,,,,,, -467400,4.8815427,0.6985451,,,,,,,,,,,,,, -467500,4.340188,0.5786233,,,,,,,,,,,,,, -467600,4.562907,0.6907954,,,,,,,,,,,,,, -467700,4.2534966,0.5270157,,,,,,,,,,,,,, -467800,4.6569505,0.7177255,,,,,,,,,,,,,, -467900,4.461199,0.63852954,,,,,,,,,,,,,, -468000,4.3763027,0.5637669,,,,,,,,,,,,,, -468100,4.3623786,0.6189381,,,,,,,,,,,,,, -468140,,,0.9606584906578064,0.1466185003519058,0.7549399733543396,1.0439494848251345,50000.0,0.6304000020027161,1.8211979866027832,10000.0,158161.37888002396,163660.12048506737,158161.37888002396,5458.459059715271,24.19330763816833,0.0 -468200,4.653138,0.63866353,,,,,,,,,,,,,, -468300,4.393749,0.60499847,,,,,,,,,,,,,, -468400,4.655407,0.6531623,,,,,,,,,,,,,, -468500,4.232372,0.5677241,,,,,,,,,,,,,, -468600,5.0559,0.66120666,,,,,,,,,,,,,, -468700,4.272148,0.65953386,,,,,,,,,,,,,, -468800,4.467399,0.6437193,,,,,,,,,,,,,, -468900,4.5818744,0.6725452,,,,,,,,,,,,,, -469000,4.380131,0.6274697,,,,,,,,,,,,,, -469100,4.439999,0.6355298,,,,,,,,,,,,,, -469200,4.267213,0.57778025,,,,,,,,,,,,,, -469300,4.3866143,0.6047557,,,,,,,,,,,,,, -469400,4.4471974,0.60327023,,,,,,,,,,,,,, -469500,4.5570946,0.66648483,,,,,,,,,,,,,, -469600,4.560404,0.5471555,,,,,,,,,,,,,, -469651,,,0.960718274116516,0.1477012634277343,0.7555399537086487,1.0445095300674438,50000.0,0.6307000517845154,1.822662591934204,10000.0,158671.50946760178,164187.35395240784,158671.50946760178,5475.40735244751,24.295584201812744,0.0 -469700,4.765326,0.63619053,,,,,,,,,,,,,, -469800,5.3078165,0.6443361,,,,,,,,,,,,,, -469900,4.6508603,0.7080753,,,,,,,,,,,,,, -470000,4.981439,0.65908647,,,,,,,,,,,,,, -470100,4.327877,0.5995122,,,,,,,,,,,,,, -470200,4.6204505,0.5672987,,,,,,,,,,,,,, -470300,4.3564987,0.642354,,,,,,,,,,,,,, -470400,4.658907,0.65763676,,,,,,,,,,,,,, -470500,4.764579,0.66784394,,,,,,,,,,,,,, -470600,4.8106093,0.6545141,,,,,,,,,,,,,, -470700,4.3203993,0.61316246,,,,,,,,,,,,,, -470800,4.920377,0.67395663,,,,,,,,,,,,,, -470900,4.3079977,0.5581045,,,,,,,,,,,,,, -471000,4.5626035,0.6073207,,,,,,,,,,,,,, -471100,4.1425486,0.55755687,,,,,,,,,,,,,, -471162,,,0.9594826102256776,0.1514231264591217,0.7551800012588501,1.043542504310608,50000.0,0.6308000087738037,1.821772575378418,10000.0,159181.595911026,164714.5281674862,159181.595911026,5492.339055299759,24.399834871292114,0.0 -471200,5.178603,0.6479538,,,,,,,,,,,,,, -471300,4.2122517,0.60934114,,,,,,,,,,,,,, -471400,4.297484,0.67599446,,,,,,,,,,,,,, -471500,4.7707725,0.64115506,,,,,,,,,,,,,, -471600,4.704744,0.58578324,,,,,,,,,,,,,, -471700,4.2620573,0.57913125,,,,,,,,,,,,,, -471800,4.2387166,0.6564841,,,,,,,,,,,,,, -471900,4.87468,0.67318046,,,,,,,,,,,,,, -472000,5.096248,0.6356496,,,,,,,,,,,,,, -472100,4.591918,0.63194233,,,,,,,,,,,,,, -472200,4.795487,0.6191283,,,,,,,,,,,,,, -472300,4.017132,0.5591152,,,,,,,,,,,,,, -472400,4.749545,0.6034799,,,,,,,,,,,,,, -472500,5.0669956,0.75395554,,,,,,,,,,,,,, -472600,4.4963994,0.6242256,,,,,,,,,,,,,, -472673,,,0.9615553021430968,0.1431863456964492,0.755299985408783,1.0438499450683594,50000.0,0.6302000284194946,1.821508288383484,10000.0,159691.7286350727,165241.82578277588,159691.7286350727,5509.320637464523,24.53001499176025,0.0 -472700,4.8743463,0.682404,,,,,,,,,,,,,, -472800,4.7422605,0.5644678,,,,,,,,,,,,,, -472900,4.5789213,0.62305427,,,,,,,,,,,,,, -473000,4.568968,0.651407,,,,,,,,,,,,,, -473100,4.488603,0.59979874,,,,,,,,,,,,,, -473200,4.3587146,0.6508455,,,,,,,,,,,,,, -473300,4.7735023,0.6389354,,,,,,,,,,,,,, -473400,4.2071023,0.56843305,,,,,,,,,,,,,, -473500,4.981959,0.74786687,,,,,,,,,,,,,, -473600,4.502957,0.57691807,,,,,,,,,,,,,, -473700,4.8568387,0.64583325,,,,,,,,,,,,,, -473800,5.1519012,0.61669624,,,,,,,,,,,,,, -473900,4.7430363,0.6787065,,,,,,,,,,,,,, -474000,4.5227757,0.6033018,,,,,,,,,,,,,, -474100,4.7267756,0.6371884,,,,,,,,,,,,,, -474183,,,0.9614756107330322,0.1437897384166717,0.754859983921051,1.045137643814087,50000.0,0.6310000419616699,1.823304533958435,10000.0,160201.8107573986,165769.17288303375,160201.8107573986,5526.427300453186,24.63716220855713,0.0 -474200,4.2533393,0.5469888,,,,,,,,,,,,,, -474300,4.502624,0.6477222,,,,,,,,,,,,,, -474400,4.441599,0.6248944,,,,,,,,,,,,,, -474500,4.529772,0.5613813,,,,,,,,,,,,,, -474600,4.424978,0.60487425,,,,,,,,,,,,,, -474700,4.3701158,0.58041084,,,,,,,,,,,,,, -474800,4.231902,0.62053704,,,,,,,,,,,,,, -474900,4.5292406,0.62074983,,,,,,,,,,,,,, -475000,4.58173,0.69191957,,,,,,,,,,,,,, -475100,4.5481844,0.60654867,,,,,,,,,,,,,, -475200,4.403591,0.63769233,,,,,,,,,,,,,, -475300,4.5369115,0.60417676,,,,,,,,,,,,,, -475400,4.616195,0.6689683,,,,,,,,,,,,,, -475500,4.4576244,0.5626366,,,,,,,,,,,,,, -475600,4.454062,0.55496895,,,,,,,,,,,,,, -475693,,,0.9602798223495485,0.1477307230234146,0.7551999688148499,1.0440150499343872,50000.0,0.6305000185966492,1.8224012851715088,10000.0,160711.80350279808,166296.12507390976,160711.80350279808,5543.228472948074,24.7429301738739,0.0 -475700,4.52957,0.6182459,,,,,,,,,,,,,, -475800,4.108006,0.55922014,,,,,,,,,,,,,, -475900,5.2534385,0.6660894,,,,,,,,,,,,,, -476000,4.6931396,0.63566005,,,,,,,,,,,,,, -476100,4.7211127,0.54944485,,,,,,,,,,,,,, -476200,4.2462583,0.54293364,,,,,,,,,,,,,, -476300,4.470227,0.5984984,,,,,,,,,,,,,, -476400,4.2630897,0.58966297,,,,,,,,,,,,,, -476500,4.6668715,0.65167964,,,,,,,,,,,,,, -476600,3.952532,0.48801613,,,,,,,,,,,,,, -476700,4.089924,0.54886687,,,,,,,,,,,,,, -476800,4.538237,0.6357369,,,,,,,,,,,,,, -476900,4.5701923,0.6219016,,,,,,,,,,,,,, -477000,4.1123395,0.5596193,,,,,,,,,,,,,, -477100,4.3427405,0.58782923,,,,,,,,,,,,,, -477200,5.1537676,0.6659484,,,,,,,,,,,,,, -477203,,,0.9600406289100648,0.1513166725635528,0.7547199726104736,1.0432075262069702,50000.0,0.6299000382423401,1.8201723098754885,10000.0,161221.76031327248,166823.05702018738,161221.76031327248,5560.047877073288,24.84749460220337,0.0 -477300,4.141937,0.61593,,,,,,,,,,,,,, -477400,4.5082717,0.6159118,,,,,,,,,,,,,, -477500,4.659184,0.64024436,,,,,,,,,,,,,, -477600,4.791139,0.6798424,,,,,,,,,,,,,, -477700,4.4590306,0.52875113,,,,,,,,,,,,,, -477800,4.494982,0.5953751,,,,,,,,,,,,,, -477900,4.2047477,0.5896369,,,,,,,,,,,,,, -478000,4.2226763,0.58507097,,,,,,,,,,,,,, -478100,4.5097275,0.63626665,,,,,,,,,,,,,, -478200,4.585941,0.64738643,,,,,,,,,,,,,, -478300,4.539441,0.61373746,,,,,,,,,,,,,, -478400,4.1870956,0.58623844,,,,,,,,,,,,,, -478500,4.673642,0.68236876,,,,,,,,,,,,,, -478600,4.2059555,0.5478135,,,,,,,,,,,,,, -478700,4.0511594,0.54008245,,,,,,,,,,,,,, -478712,,,0.9610171914100648,0.1452740877866745,0.7547000050544739,1.045116901397705,50000.0,0.6303000450134277,1.824659824371338,10000.0,161731.6183450222,167349.98368883133,161731.6183450222,5576.959014892578,24.95271611213684,0.0 -478800,4.4023786,0.64365464,,,,,,,,,,,,,, -478900,4.603094,0.6797065,,,,,,,,,,,,,, -479000,4.402555,0.60409504,,,,,,,,,,,,,, -479100,5.2118444,0.66704625,,,,,,,,,,,,,, -479200,4.6619473,0.68264693,,,,,,,,,,,,,, -479300,4.282835,0.6121144,,,,,,,,,,,,,, -479400,4.039835,0.54460406,,,,,,,,,,,,,, -479500,4.674588,0.6143835,,,,,,,,,,,,,, -479600,4.418338,0.6307583,,,,,,,,,,,,,, -479700,4.642447,0.65413153,,,,,,,,,,,,,, -479800,4.6845136,0.65952796,,,,,,,,,,,,,, -479900,4.381664,0.61639696,,,,,,,,,,,,,, -480000,4.8793383,0.5942534,,,,,,,,,,,,,, -480100,4.330062,0.6435545,,,,,,,,,,,,,, -480200,4.8934183,0.6305839,,,,,,,,,,,,,, -480222,,,0.9602598547935486,0.1506530493497848,0.7551400065422058,1.043482542037964,50000.0,0.6310000419616699,1.821131706237793,10000.0,162241.5458357334,167877.05255436897,162241.5458357334,5593.939316511154,25.06274676322937,0.0 -480300,4.3878036,0.6704753,,,,,,,,,,,,,, -480400,4.452258,0.5690486,,,,,,,,,,,,,, -480500,4.90996,0.64461637,,,,,,,,,,,,,, -480600,4.238967,0.5534554,,,,,,,,,,,,,, -480700,4.7136316,0.660825,,,,,,,,,,,,,, -480800,4.4866557,0.5743203,,,,,,,,,,,,,, -480900,4.5297813,0.61508334,,,,,,,,,,,,,, -481000,4.438168,0.59544057,,,,,,,,,,,,,, -481100,4.903957,0.67373884,,,,,,,,,,,,,, -481200,5.258862,0.66777325,,,,,,,,,,,,,, -481300,4.3511677,0.61514294,,,,,,,,,,,,,, -481400,4.830693,0.62812674,,,,,,,,,,,,,, -481500,4.1125474,0.56289786,,,,,,,,,,,,,, -481600,4.7717485,0.5698929,,,,,,,,,,,,,, -481700,4.658957,0.65399325,,,,,,,,,,,,,, -481732,,,0.9599011540412904,0.1497306227684021,0.754859983921051,1.044361591339111,50000.0,0.6317000389099121,1.822176098823548,10000.0,162751.5722372532,168404.1153049469,162751.5722372532,5610.818668365479,25.167808771133423,0.0 -481800,4.944804,0.6216212,,,,,,,,,,,,,, -481900,4.2947073,0.58227074,,,,,,,,,,,,,, -482000,4.0229893,0.6208751,,,,,,,,,,,,,, -482100,4.462094,0.6432761,,,,,,,,,,,,,, -482200,4.432854,0.61121476,,,,,,,,,,,,,, -482300,4.3265595,0.57454526,,,,,,,,,,,,,, -482400,4.4737616,0.6130642,,,,,,,,,,,,,, -482500,4.1690636,0.58259153,,,,,,,,,,,,,, -482600,5.200685,0.6286021,,,,,,,,,,,,,, -482700,4.3914366,0.5742602,,,,,,,,,,,,,, -482800,4.6356454,0.64167696,,,,,,,,,,,,,, -482900,4.4493895,0.6504802,,,,,,,,,,,,,, -483000,4.5868826,0.61287993,,,,,,,,,,,,,, -483100,4.37281,0.62819237,,,,,,,,,,,,,, -483200,4.436183,0.5980252,,,,,,,,,,,,,, -483243,,,0.9590840339660645,0.1498535722494125,0.7548799514770508,1.0456264019012451,50000.0,0.6308000087738037,1.8228379487991333,10000.0,163261.70915150642,168931.19228696823,163261.70915150642,5627.59955739975,25.27507972717285,0.0 -483300,4.4051094,0.6043451,,,,,,,,,,,,,, -483400,4.689438,0.62502784,,,,,,,,,,,,,, -483500,4.271551,0.5691852,,,,,,,,,,,,,, -483600,4.0169277,0.6064205,,,,,,,,,,,,,, -483700,4.3286,0.644902,,,,,,,,,,,,,, -483800,4.48044,0.6368699,,,,,,,,,,,,,, -483900,4.27379,0.59040475,,,,,,,,,,,,,, -484000,4.4762073,0.6314572,,,,,,,,,,,,,, -484100,4.7933607,0.702379,,,,,,,,,,,,,, -484200,4.2537527,0.63086313,,,,,,,,,,,,,, -484300,4.352073,0.5891937,,,,,,,,,,,,,, -484400,4.90182,0.63507324,,,,,,,,,,,,,, -484500,3.9652293,0.5475024,,,,,,,,,,,,,, -484600,4.209313,0.6004009,,,,,,,,,,,,,, -484700,4.794836,0.7082161,,,,,,,,,,,,,, -484753,,,0.9606983065605164,0.1470476686954498,0.7554000020027161,1.0443992614746094,50000.0,0.6305000185966492,1.8222293853759768,10000.0,163771.55219697952,169458.08463525772,163771.55219697952,5644.491109132767,25.381606101989743,0.0 -484800,4.5935006,0.66324466,,,,,,,,,,,,,, -484900,4.5691996,0.66678256,,,,,,,,,,,,,, -485000,4.491856,0.64741987,,,,,,,,,,,,,, -485100,4.5969877,0.65053356,,,,,,,,,,,,,, -485200,4.379791,0.60967016,,,,,,,,,,,,,, -485300,5.382715,0.6558995,,,,,,,,,,,,,, -485400,4.4981785,0.55872196,,,,,,,,,,,,,, -485500,4.2116632,0.6647005,,,,,,,,,,,,,, -485600,4.5990267,0.5702668,,,,,,,,,,,,,, -485700,4.1256537,0.5876048,,,,,,,,,,,,,, -485800,4.6245327,0.5656588,,,,,,,,,,,,,, -485900,4.397291,0.6166918,,,,,,,,,,,,,, -486000,4.2733064,0.4735648,,,,,,,,,,,,,, -486100,4.695363,0.68151385,,,,,,,,,,,,,, -486200,4.332528,0.6178988,,,,,,,,,,,,,, -486264,,,0.9604591727256776,0.147618219256401,0.7550399899482727,1.044253945350647,50000.0,0.6307000517845154,1.8217581510543823,10000.0,164281.594363451,169985.24022579193,164281.594363451,5661.443585395813,25.490796089172363,0.0 -486300,4.095379,0.6165172,,,,,,,,,,,,,, -486400,4.4679403,0.60818946,,,,,,,,,,,,,, -486500,4.9701858,0.57222843,,,,,,,,,,,,,, -486600,4.6442533,0.68521076,,,,,,,,,,,,,, -486700,4.651705,0.66985726,,,,,,,,,,,,,, -486800,4.403666,0.5616309,,,,,,,,,,,,,, -486900,4.538544,0.5747794,,,,,,,,,,,,,, -487000,4.5806837,0.6749329,,,,,,,,,,,,,, -487100,5.1634884,0.599087,,,,,,,,,,,,,, -487200,4.4732246,0.64014447,,,,,,,,,,,,,, -487300,4.1403685,0.60880286,,,,,,,,,,,,,, -487400,4.5992775,0.70084316,,,,,,,,,,,,,, -487500,4.663832,0.53912365,,,,,,,,,,,,,, -487600,4.0626125,0.56185824,,,,,,,,,,,,,, -487700,4.983717,0.6489701,,,,,,,,,,,,,, -487775,,,0.9592434167861938,0.1517109423875808,0.7553399801254272,1.0429028272628784,50000.0,0.6321000456809998,1.819995641708374,10000.0,164791.71216368675,170512.45286893845,164791.71216368675,5678.37740445137,25.59970498085022,0.0 -487800,4.7716045,0.6915621,,,,,,,,,,,,,, -487900,4.947003,0.5991562,,,,,,,,,,,,,, -488000,4.529561,0.60534096,,,,,,,,,,,,,, -488100,4.8288183,0.6521816,,,,,,,,,,,,,, -488200,4.8085337,0.6126164,,,,,,,,,,,,,, -488300,4.337339,0.5826458,,,,,,,,,,,,,, -488400,4.658769,0.66257393,,,,,,,,,,,,,, -488500,4.4839373,0.6452542,,,,,,,,,,,,,, -488600,5.155663,0.7025732,,,,,,,,,,,,,, -488700,4.7252784,0.6061154,,,,,,,,,,,,,, -488800,4.2554955,0.6668802,,,,,,,,,,,,,, -488900,4.459565,0.5980286,,,,,,,,,,,,,, -489000,4.5208144,0.5864484,,,,,,,,,,,,,, -489100,5.157795,0.6836169,,,,,,,,,,,,,, -489200,4.2437716,0.6083129,,,,,,,,,,,,,, -489286,,,0.9607381820678712,0.1463460773229599,0.7548399567604065,1.0446679592132568,50000.0,0.6303000450134277,1.820755958557129,10000.0,165301.6225683689,171039.63341140747,165301.6225683689,5695.487970590591,25.70747256278992,0.0 -489300,4.492688,0.66581804,,,,,,,,,,,,,, -489400,4.3663206,0.6530639,,,,,,,,,,,,,, -489500,4.6121316,0.60476077,,,,,,,,,,,,,, -489600,4.298069,0.58983284,,,,,,,,,,,,,, -489700,4.366928,0.62506634,,,,,,,,,,,,,, -489800,4.4390454,0.61070377,,,,,,,,,,,,,, -489900,4.9593964,0.58356655,,,,,,,,,,,,,, -490000,4.7351513,0.6904858,,,,,,,,,,,,,, -490100,4.243829,0.63576686,,,,,,,,,,,,,, -490200,4.6929584,0.6089687,,,,,,,,,,,,,, -490300,5.7749043,0.6836197,,,,,,,,,,,,,, -490400,4.711777,0.59537095,,,,,,,,,,,,,, -490500,4.6144447,0.63563013,,,,,,,,,,,,,, -490600,4.1195817,0.5099767,,,,,,,,,,,,,, -490700,4.3648148,0.65279245,,,,,,,,,,,,,, -490796,,,0.9604392051696776,0.1487823575735092,0.755299985408783,1.0443212985992432,50000.0,0.6317000389099121,1.8227946758270264,10000.0,165811.5897936821,171566.49512457848,165811.5897936821,5712.22460770607,25.81318640708924,0.0 -490800,4.505785,0.6038704,,,,,,,,,,,,,, -490900,5.4607177,0.6284674,,,,,,,,,,,,,, -491000,4.8629065,0.69987595,,,,,,,,,,,,,, -491100,4.8102217,0.6050541,,,,,,,,,,,,,, -491200,4.5009437,0.63243556,,,,,,,,,,,,,, -491300,4.6679606,0.58380586,,,,,,,,,,,,,, -491400,4.637425,0.62319565,,,,,,,,,,,,,, -491500,4.8491416,0.6872964,,,,,,,,,,,,,, -491600,4.247001,0.644439,,,,,,,,,,,,,, -491700,4.2228584,0.5625134,,,,,,,,,,,,,, -491800,4.3312054,0.641666,,,,,,,,,,,,,, -491900,4.275795,0.5669699,,,,,,,,,,,,,, -492000,4.3539257,0.6652846,,,,,,,,,,,,,, -492100,4.182046,0.61448646,,,,,,,,,,,,,, -492200,4.737873,0.68846315,,,,,,,,,,,,,, -492300,4.860278,0.6801854,,,,,,,,,,,,,, -492307,,,0.959203600883484,0.1508540958166122,0.7555199861526489,1.044857144355774,50000.0,0.6322000026702881,1.823703646659851,10000.0,166321.67253041267,172093.6145040989,166321.67253041267,5729.101854324341,25.92009258270264,0.0 -492400,4.641473,0.6077219,,,,,,,,,,,,,, -492500,4.283593,0.5403099,,,,,,,,,,,,,, -492600,4.698281,0.5553118,,,,,,,,,,,,,, -492700,5.238029,0.71002233,,,,,,,,,,,,,, -492800,4.891209,0.63282996,,,,,,,,,,,,,, -492900,4.9332232,0.659897,,,,,,,,,,,,,, -493000,4.2674084,0.5940976,,,,,,,,,,,,,, -493100,4.4933567,0.61289644,,,,,,,,,,,,,, -493200,4.6515856,0.6504149,,,,,,,,,,,,,, -493300,5.0069656,0.64978456,,,,,,,,,,,,,, -493400,4.4750533,0.6808454,,,,,,,,,,,,,, -493500,4.669651,0.60255396,,,,,,,,,,,,,, -493600,4.39292,0.61287063,,,,,,,,,,,,,, -493700,4.81054,0.57957757,,,,,,,,,,,,,, -493800,4.774825,0.6542509,,,,,,,,,,,,,, -493817,,,0.9607381820678712,0.148208349943161,0.7550999522209167,1.0426260232925415,50000.0,0.631100058555603,1.8204790353775024,10000.0,166831.68437623978,172620.69486761093,166831.68437623978,5746.009928226471,26.02876138687133,0.0 -493900,4.1844816,0.604165,,,,,,,,,,,,,, -494000,4.385233,0.62557346,,,,,,,,,,,,,, -494100,4.62453,0.63331133,,,,,,,,,,,,,, -494200,4.3059487,0.5701056,,,,,,,,,,,,,, -494300,4.828526,0.6546306,,,,,,,,,,,,,, -494400,4.82869,0.6729538,,,,,,,,,,,,,, -494500,4.327113,0.5745381,,,,,,,,,,,,,, -494600,4.5384007,0.64004385,,,,,,,,,,,,,, -494700,4.3151193,0.6027708,,,,,,,,,,,,,, -494800,5.07049,0.642572,,,,,,,,,,,,,, -494900,4.3957295,0.6305374,,,,,,,,,,,,,, -495000,4.4187756,0.60734344,,,,,,,,,,,,,, -495100,4.2452035,0.6455202,,,,,,,,,,,,,, -495200,4.468256,0.58791196,,,,,,,,,,,,,, -495300,4.3315954,0.5591023,,,,,,,,,,,,,, -495327,,,0.961535394191742,0.1453877687454223,0.7549999952316284,1.0442039966583252,50000.0,0.6322000026702881,1.82174289226532,10000.0,167341.80224132538,173148.03052449226,167341.80224132538,5763.065656900406,26.13835978507996,0.0 -495400,4.648281,0.679723,,,,,,,,,,,,,, -495500,4.62162,0.61030024,,,,,,,,,,,,,, -495600,4.4947267,0.5939907,,,,,,,,,,,,,, -495700,4.3506384,0.596382,,,,,,,,,,,,,, -495800,4.878456,0.65309423,,,,,,,,,,,,,, -495900,4.1322556,0.6067269,,,,,,,,,,,,,, -496000,4.719529,0.62028635,,,,,,,,,,,,,, -496100,4.645761,0.6650137,,,,,,,,,,,,,, -496200,4.5055656,0.6668533,,,,,,,,,,,,,, -496300,4.9573007,0.6557458,,,,,,,,,,,,,, -496400,4.6826572,0.72620153,,,,,,,,,,,,,, -496500,4.3835835,0.6223603,,,,,,,,,,,,,, -496600,4.4235497,0.673611,,,,,,,,,,,,,, -496700,4.5698905,0.67249227,,,,,,,,,,,,,, -496800,4.300912,0.6408708,,,,,,,,,,,,,, -496837,,,0.960339605808258,0.147151231765747,0.7551800012588501,1.0421451330184937,50000.0,0.6304000020027161,1.8195481300354004,10000.0,167851.83716320992,173675.16613030434,167851.83716320992,5780.004978179932,26.24807047843933,0.0 -496900,4.618508,0.6592016,,,,,,,,,,,,,, -497000,4.3529205,0.61572266,,,,,,,,,,,,,, -497100,4.595067,0.6130894,,,,,,,,,,,,,, -497200,4.828477,0.6572939,,,,,,,,,,,,,, -497300,4.6265483,0.6498049,,,,,,,,,,,,,, -497400,4.4555783,0.60267824,,,,,,,,,,,,,, -497500,4.6619287,0.6017902,,,,,,,,,,,,,, -497600,4.4021735,0.642609,,,,,,,,,,,,,, -497700,4.8902044,0.6892705,,,,,,,,,,,,,, -497800,4.841358,0.70140195,,,,,,,,,,,,,, -497900,5.326322,0.70830476,,,,,,,,,,,,,, -498000,4.4262567,0.6821053,,,,,,,,,,,,,, -498100,4.5692487,0.6161418,,,,,,,,,,,,,, -498200,4.4761095,0.6340934,,,,,,,,,,,,,, -498300,4.7288017,0.6994645,,,,,,,,,,,,,, -498347,,,0.9596021771430968,0.1485576331615448,0.7550199627876282,1.044399380683899,50000.0,0.6319000124931335,1.819997906684876,10000.0,168361.71323132515,174202.0587952137,168361.71323132515,5796.880818128586,26.336756229400635,0.0 -498400,4.5670943,0.6418874,,,,,,,,,,,,,, -498500,4.616075,0.59370387,,,,,,,,,,,,,, -498600,4.597287,0.6712498,,,,,,,,,,,,,, -498700,4.8361063,0.65698725,,,,,,,,,,,,,, -498800,4.552781,0.62026393,,,,,,,,,,,,,, -498900,4.606253,0.70509326,,,,,,,,,,,,,, -499000,4.000929,0.55585086,,,,,,,,,,,,,, -499100,4.613991,0.6107564,,,,,,,,,,,,,, -499200,4.478569,0.63609886,,,,,,,,,,,,,, -499300,4.6694107,0.61706567,,,,,,,,,,,,,, -499400,4.4656916,0.62306446,,,,,,,,,,,,,, -499500,4.7972507,0.69576925,,,,,,,,,,,,,, -499600,4.902307,0.7635137,,,,,,,,,,,,,, -499700,4.940825,0.6656463,,,,,,,,,,,,,, -499800,5.012559,0.6297413,,,,,,,,,,,,,, -499857,,,0.9595423936843872,0.151010975241661,0.7550399899482727,1.043644905090332,50000.0,0.631100058555603,1.821232557296753,10000.0,168871.77570962906,174729.8920071125,168871.77570962906,5814.488134860992,26.44757008552552,0.0 -499900,4.6719904,0.64670485,,,,,,,,,,,,,, -500000,4.374223,0.5968347,,,,,,,,,,,,,, -500100,5.0326166,0.69872737,,,,,,,,,,,,,, -500200,4.7669253,0.6967192,,,,,,,,,,,,,, -500300,4.582047,0.68491274,,,,,,,,,,,,,, -500400,4.2359595,0.5771956,,,,,,,,,,,,,, -500500,4.566252,0.64190817,,,,,,,,,,,,,, -500600,4.626654,0.6404837,,,,,,,,,,,,,, -500700,4.233566,0.5545659,,,,,,,,,,,,,, -500800,4.7087893,0.66649956,,,,,,,,,,,,,, -500900,4.7070923,0.60202086,,,,,,,,,,,,,, -501000,4.1256747,0.5802123,,,,,,,,,,,,,, -501100,4.4998574,0.63667625,,,,,,,,,,,,,, -501200,4.2942176,0.58377105,,,,,,,,,,,,,, -501300,4.592292,0.6448155,,,,,,,,,,,,,, -501366,,,0.9618343114852904,0.1473004668951034,0.7550599575042725,1.0455933809280396,50000.0,0.6312000155448914,1.822257161140442,10000.0,169381.6427283287,175256.8369383812,169381.6427283287,5831.401889562607,26.55976819992065,0.0 -501400,4.4676633,0.56418365,,,,,,,,,,,,,, -501500,4.718744,0.62610936,,,,,,,,,,,,,, -501600,4.6984386,0.6978194,,,,,,,,,,,,,, -501700,4.437569,0.61284995,,,,,,,,,,,,,, -501800,4.413366,0.64544696,,,,,,,,,,,,,, -501900,4.6518025,0.6420569,,,,,,,,,,,,,, -502000,3.9945774,0.5565711,,,,,,,,,,,,,, -502100,4.2586117,0.5889654,,,,,,,,,,,,,, -502200,4.597248,0.61273587,,,,,,,,,,,,,, -502300,4.6166263,0.65578896,,,,,,,,,,,,,, -502400,4.4338884,0.6117434,,,,,,,,,,,,,, -502500,4.5078073,0.6297419,,,,,,,,,,,,,, -502600,4.324806,0.60105723,,,,,,,,,,,,,, -502700,4.8005767,0.677437,,,,,,,,,,,,,, -502800,4.5116353,0.60842514,,,,,,,,,,,,,, -502877,,,0.9600406289100648,0.1488345712423324,0.7554799914360046,1.043159008026123,50000.0,0.6317000389099121,1.8203349113464355,10000.0,169891.77300667763,175784.2019557953,169891.77300667763,5848.470964431763,26.6716742515564,0.0 -502900,4.7343874,0.67115766,,,,,,,,,,,,,, -503000,4.956571,0.69736844,,,,,,,,,,,,,, -503100,4.6472917,0.6437882,,,,,,,,,,,,,, -503200,4.2941084,0.59708375,,,,,,,,,,,,,, -503300,4.643164,0.70466006,,,,,,,,,,,,,, -503400,4.194463,0.6043857,,,,,,,,,,,,,, -503500,4.5787954,0.6506976,,,,,,,,,,,,,, -503600,4.2892942,0.63550496,,,,,,,,,,,,,, -503700,4.8073187,0.5816438,,,,,,,,,,,,,, -503800,4.303134,0.5857101,,,,,,,,,,,,,, -503900,4.400408,0.57786745,,,,,,,,,,,,,, -504000,4.2024913,0.5734659,,,,,,,,,,,,,, -504100,5.01689,0.6690322,,,,,,,,,,,,,, -504200,4.2648764,0.5867434,,,,,,,,,,,,,, -504300,4.8602467,0.6666149,,,,,,,,,,,,,, -504388,,,0.961136758327484,0.1460768282413482,0.7550199627876282,1.0443062782287598,50000.0,0.6304000020027161,1.8220136165618896,10000.0,170401.88774085045,176311.35015702248,170401.88774085045,5865.335414886475,26.78855919837952,0.0 -504400,4.2947016,0.5673384,,,,,,,,,,,,,, -504500,4.5015087,0.614146,,,,,,,,,,,,,, -504600,4.3736515,0.66767263,,,,,,,,,,,,,, -504700,4.4848866,0.568415,,,,,,,,,,,,,, -504800,4.7616405,0.63411033,,,,,,,,,,,,,, -504900,4.5010986,0.6256623,,,,,,,,,,,,,, -505000,5.2323337,0.6378706,,,,,,,,,,,,,, -505100,4.228663,0.5352942,,,,,,,,,,,,,, -505200,4.4655547,0.6742335,,,,,,,,,,,,,, -505300,4.504193,0.6718765,,,,,,,,,,,,,, -505400,4.5914435,0.630027,,,,,,,,,,,,,, -505500,4.3433213,0.5938417,,,,,,,,,,,,,, -505600,4.4714417,0.6349082,,,,,,,,,,,,,, -505700,5.106566,0.642692,,,,,,,,,,,,,, -505800,4.3811493,0.64373446,,,,,,,,,,,,,, -505898,,,0.960359513759613,0.146673321723938,0.7550999522209167,1.0442750453948977,50000.0,0.6313000321388245,1.822537422180176,10000.0,170911.90631198883,176838.2618751526,170911.90631198883,5882.064656734467,26.900224924087524,0.0 -505900,4.69865,0.6817497,,,,,,,,,,,,,, -506000,4.4024353,0.58360887,,,,,,,,,,,,,, -506100,4.2760644,0.56123215,,,,,,,,,,,,,, -506200,4.056017,0.58521765,,,,,,,,,,,,,, -506300,4.71487,0.6445634,,,,,,,,,,,,,, -506400,4.4565,0.63970923,,,,,,,,,,,,,, -506500,4.6514583,0.606407,,,,,,,,,,,,,, -506600,4.661269,0.5941924,,,,,,,,,,,,,, -506700,5.4125404,0.6659538,,,,,,,,,,,,,, -506800,4.950224,0.6196071,,,,,,,,,,,,,, -506900,4.341829,0.6233758,,,,,,,,,,,,,, -507000,4.9590387,0.690498,,,,,,,,,,,,,, -507100,4.869772,0.65522003,,,,,,,,,,,,,, -507200,5.136381,0.66716516,,,,,,,,,,,,,, -507300,4.744243,0.6635984,,,,,,,,,,,,,, -507400,4.470688,0.64626294,,,,,,,,,,,,,, -507409,,,0.9600406289100648,0.1485704332590103,0.7552199959754944,1.044175386428833,50000.0,0.6313000321388245,1.8244786262512207,10000.0,171422.06424570084,177365.73793911934,171422.06424570084,5899.219933986664,27.01107287406921,0.0 -507500,4.40228,0.5552432,,,,,,,,,,,,,, -507600,4.7296414,0.7051114,,,,,,,,,,,,,, -507700,4.858852,0.72610885,,,,,,,,,,,,,, -507800,4.274753,0.5896846,,,,,,,,,,,,,, -507900,4.478568,0.65144503,,,,,,,,,,,,,, -508000,4.626708,0.6220362,,,,,,,,,,,,,, -508100,4.237359,0.62075865,,,,,,,,,,,,,, -508200,5.114043,0.71760565,,,,,,,,,,,,,, -508300,4.229861,0.5817183,,,,,,,,,,,,,, -508400,4.460863,0.60716355,,,,,,,,,,,,,, -508500,4.6316066,0.61191165,,,,,,,,,,,,,, -508600,4.860736,0.6381725,,,,,,,,,,,,,, -508700,4.2574854,0.6405652,,,,,,,,,,,,,, -508800,4.518437,0.5637044,,,,,,,,,,,,,, -508900,4.733188,0.66053396,,,,,,,,,,,,,, -508919,,,0.9602997303009032,0.151060089468956,0.755899965763092,1.043476700782776,50000.0,0.6301000118255615,1.8214324712753296,10000.0,171932.13884925842,177892.9942240715,171932.13884925842,5916.238545417786,27.122536659240723,0.0 -509000,4.598049,0.60017854,,,,,,,,,,,,,, -509100,4.220453,0.6660492,,,,,,,,,,,,,, -509200,4.7106395,0.64734584,,,,,,,,,,,,,, -509300,4.5090876,0.6308594,,,,,,,,,,,,,, -509400,5.58598,0.6548872,,,,,,,,,,,,,, -509500,4.6417594,0.6123666,,,,,,,,,,,,,, -509600,4.596589,0.64318717,,,,,,,,,,,,,, -509700,3.9506958,0.50582725,,,,,,,,,,,,,, -509800,4.5137343,0.5979175,,,,,,,,,,,,,, -509900,4.400786,0.6421049,,,,,,,,,,,,,, -510000,4.3599315,0.55254346,,,,,,,,,,,,,, -510100,4.20487,0.5407003,,,,,,,,,,,,,, -510200,4.261712,0.61062473,,,,,,,,,,,,,, -510300,4.9361005,0.6462875,,,,,,,,,,,,,, -510400,4.0961366,0.59267867,,,,,,,,,,,,,, -510430,,,0.9616748690605164,0.1418393105268478,0.7552199959754944,1.0449419021606443,50000.0,0.6312000155448914,1.8233567476272583,10000.0,172442.27340078354,178420.2787978649,172442.27340078354,5933.223411083221,27.235366582870483,0.0 -510500,4.22699,0.5367583,,,,,,,,,,,,,, -510600,4.7212877,0.6743002,,,,,,,,,,,,,, -510700,4.334752,0.6085911,,,,,,,,,,,,,, -510800,4.399517,0.59253955,,,,,,,,,,,,,, -510900,4.3206797,0.63763714,,,,,,,,,,,,,, -511000,4.5792837,0.5934069,,,,,,,,,,,,,, -511100,4.8029037,0.62797964,,,,,,,,,,,,,, -511200,4.9458933,0.6876408,,,,,,,,,,,,,, -511300,5.248185,0.6188583,,,,,,,,,,,,,, -511400,4.7249346,0.637609,,,,,,,,,,,,,, -511500,4.5506024,0.6446638,,,,,,,,,,,,,, -511600,4.4643064,0.5536027,,,,,,,,,,,,,, -511700,4.2933345,0.60311663,,,,,,,,,,,,,, -511800,4.467094,0.66585714,,,,,,,,,,,,,, -511900,4.0860467,0.590471,,,,,,,,,,,,,, -511940,,,0.9614357352256776,0.1435249298810959,0.7557199597358704,1.0442218780517578,50000.0,0.6306000351905823,1.8227651119232176,10000.0,172952.19137954712,178947.29454159737,172952.19137954712,5950.154138565064,27.349958181381226,0.0 -512000,4.8850355,0.62586474,,,,,,,,,,,,,, -512100,4.8315773,0.6427218,,,,,,,,,,,,,, -512200,4.2872086,0.5617199,,,,,,,,,,,,,, -512300,4.821671,0.6583345,,,,,,,,,,,,,, -512400,4.635418,0.59385663,,,,,,,,,,,,,, -512500,4.292997,0.67509305,,,,,,,,,,,,,, -512600,4.629879,0.6558324,,,,,,,,,,,,,, -512700,4.9612236,0.5969423,,,,,,,,,,,,,, -512800,4.27715,0.68613607,,,,,,,,,,,,,, -512900,4.5856147,0.63828325,,,,,,,,,,,,,, -513000,4.9345927,0.64268637,,,,,,,,,,,,,, -513100,4.883779,0.628994,,,,,,,,,,,,,, -513200,4.617939,0.5936136,,,,,,,,,,,,,, -513300,4.859621,0.70791215,,,,,,,,,,,,,, -513400,4.6954193,0.6442521,,,,,,,,,,,,,, -513450,,,0.9593430757522584,0.148752212524414,0.7551400065422058,1.043470025062561,50000.0,0.6304000020027161,1.8211781978607176,10000.0,173462.10426592827,179474.29887628555,173462.10426592827,5967.08442735672,27.45971775054932,0.0 -513500,4.3903456,0.5776291,,,,,,,,,,,,,, -513600,4.7376366,0.6984912,,,,,,,,,,,,,, -513700,4.3165693,0.5810689,,,,,,,,,,,,,, -513800,4.539535,0.59336877,,,,,,,,,,,,,, -513900,4.579109,0.6255789,,,,,,,,,,,,,, -514000,4.360973,0.558579,,,,,,,,,,,,,, -514100,4.962694,0.6538907,,,,,,,,,,,,,, -514200,4.769036,0.6366117,,,,,,,,,,,,,, -514300,4.2762427,0.58662,,,,,,,,,,,,,, -514400,4.9799953,0.65519714,,,,,,,,,,,,,, -514500,4.487752,0.61704016,,,,,,,,,,,,,, -514600,4.697396,0.6207566,,,,,,,,,,,,,, -514700,4.743564,0.6526133,,,,,,,,,,,,,, -514800,4.9030633,0.7329315,,,,,,,,,,,,,, -514900,5.024502,0.6658554,,,,,,,,,,,,,, -514960,,,0.9610769748687744,0.1501688510179519,0.7549999952316284,1.0442010164260864,50000.0,0.6315000057220459,1.82271409034729,10000.0,173971.96740150452,180001.14465355873,173971.96740150452,5983.905996799469,27.568768978118896,0.0 -515000,4.815013,0.5911658,,,,,,,,,,,,,, -515100,4.588932,0.64203954,,,,,,,,,,,,,, -515200,4.201463,0.5669682,,,,,,,,,,,,,, -515300,4.158421,0.5646502,,,,,,,,,,,,,, -515400,4.556538,0.6268115,,,,,,,,,,,,,, -515500,5.2447248,0.7803525,,,,,,,,,,,,,, -515600,4.521499,0.6514911,,,,,,,,,,,,,, -515700,4.7984037,0.60538816,,,,,,,,,,,,,, -515800,4.9225025,0.6600331,,,,,,,,,,,,,, -515900,4.445736,0.65445846,,,,,,,,,,,,,, -516000,4.2978,0.65954673,,,,,,,,,,,,,, -516100,4.489446,0.6399785,,,,,,,,,,,,,, -516200,4.850985,0.6445756,,,,,,,,,,,,,, -516300,4.1172915,0.6176123,,,,,,,,,,,,,, -516400,4.6661444,0.6235066,,,,,,,,,,,,,, -516471,,,0.9609972834587096,0.1467299163341522,0.7552199959754944,1.0436804294586182,50000.0,0.6300000548362732,1.8224447965621948,10000.0,174482.09878492355,180528.45638537407,174482.09878492355,6000.9197318553925,27.683451890945435,0.0 -516500,4.500883,0.5712052,,,,,,,,,,,,,, -516600,4.7839527,0.6817068,,,,,,,,,,,,,, -516700,4.573395,0.60675704,,,,,,,,,,,,,, -516800,4.490229,0.6364658,,,,,,,,,,,,,, -516900,5.0689993,0.65401167,,,,,,,,,,,,,, -517000,4.6388183,0.56541944,,,,,,,,,,,,,, -517100,5.3348136,0.6412531,,,,,,,,,,,,,, -517200,4.7520194,0.66000724,,,,,,,,,,,,,, -517300,4.545036,0.61053663,,,,,,,,,,,,,, -517400,4.417723,0.61575145,,,,,,,,,,,,,, -517500,4.290398,0.60834956,,,,,,,,,,,,,, -517600,4.340598,0.62690073,,,,,,,,,,,,,, -517700,4.5710425,0.68126065,,,,,,,,,,,,,, -517800,4.377546,0.6035689,,,,,,,,,,,,,, -517900,4.3741236,0.5814539,,,,,,,,,,,,,, -517981,,,0.960558831691742,0.1473705470561981,0.7551400065422058,1.043861746788025,50000.0,0.6314000487327576,1.8219637870788568,10000.0,174992.04834723473,181056.0744562149,174992.04834723473,6018.418604850769,27.80076766014099,0.0 -518000,5.4189806,0.6966559,,,,,,,,,,,,,, -518100,4.2488766,0.6378217,,,,,,,,,,,,,, -518200,4.647695,0.65431404,,,,,,,,,,,,,, -518300,4.714911,0.62500596,,,,,,,,,,,,,, -518400,4.624564,0.6184442,,,,,,,,,,,,,, -518500,4.4795094,0.644519,,,,,,,,,,,,,, -518600,4.802472,0.5798434,,,,,,,,,,,,,, -518700,4.517201,0.60625434,,,,,,,,,,,,,, -518800,4.3953013,0.57798946,,,,,,,,,,,,,, -518900,4.538496,0.6062783,,,,,,,,,,,,,, -519000,4.2656107,0.5702921,,,,,,,,,,,,,, -519100,4.5320992,0.5865716,,,,,,,,,,,,,, -519200,4.2454715,0.5753938,,,,,,,,,,,,,, -519300,4.3879137,0.591717,,,,,,,,,,,,,, -519400,4.145683,0.560478,,,,,,,,,,,,,, -519491,,,0.958765149116516,0.1532300561666488,0.7552399635314941,1.043881893157959,50000.0,0.6310000419616699,1.821869969367981,10000.0,175502.09851241112,181583.13382434845,175502.09851241112,6035.283583641052,27.892229795455933,0.0 -519500,4.5190024,0.7294507,,,,,,,,,,,,,, -519600,4.8580327,0.61632437,,,,,,,,,,,,,, -519700,4.5506334,0.6785544,,,,,,,,,,,,,, -519800,4.3191514,0.64280814,,,,,,,,,,,,,, -519900,4.5327864,0.6305128,,,,,,,,,,,,,, -520000,4.9054165,0.6137134,,,,,,,,,,,,,, -520100,4.1009564,0.56255966,,,,,,,,,,,,,, -520200,4.76479,0.66608864,,,,,,,,,,,,,, -520300,4.5656233,0.63830024,,,,,,,,,,,,,, -520400,4.4557676,0.66765094,,,,,,,,,,,,,, -520500,5.0787907,0.74657154,,,,,,,,,,,,,, -520600,4.4970727,0.61123997,,,,,,,,,,,,,, -520700,4.4702897,0.586887,,,,,,,,,,,,,, -520800,4.805333,0.66958684,,,,,,,,,,,,,, -520900,4.4452524,0.62903386,,,,,,,,,,,,,, -521000,4.175356,0.63141286,,,,,,,,,,,,,, -521001,,,0.9597018361091614,0.1484248042106628,0.7550599575042725,1.04348886013031,50000.0,0.6307000517845154,1.8207950592041016,10000.0,176011.94000339508,182110.01022839543,176011.94000339508,6052.146709918976,28.011739253997803,0.0 -521100,4.301445,0.5988398,,,,,,,,,,,,,, -521200,4.2129035,0.56955,,,,,,,,,,,,,, -521300,4.2946663,0.60565215,,,,,,,,,,,,,, -521400,4.357153,0.54980785,,,,,,,,,,,,,, -521500,4.5499606,0.6701027,,,,,,,,,,,,,, -521600,4.5271535,0.6013759,,,,,,,,,,,,,, -521700,4.0139756,0.5746583,,,,,,,,,,,,,, -521800,4.410585,0.6568506,,,,,,,,,,,,,, -521900,4.7708354,0.61397666,,,,,,,,,,,,,, -522000,4.609342,0.66024745,,,,,,,,,,,,,, -522100,4.440948,0.5910651,,,,,,,,,,,,,, -522200,4.9466777,0.6214154,,,,,,,,,,,,,, -522300,4.120552,0.58944273,,,,,,,,,,,,,, -522400,4.4332147,0.64418566,,,,,,,,,,,,,, -522500,4.869202,0.76000136,,,,,,,,,,,,,, -522512,,,0.9596619606018066,0.150680735707283,0.7549200057983398,1.0442814826965332,50000.0,0.6303000450134277,1.8228248357772827,10000.0,176522.079351902,182637.39370679847,176522.079351902,6069.219238519669,28.131194591522217,0.0 -522600,4.3477993,0.5532307,,,,,,,,,,,,,, -522700,4.5211334,0.59776866,,,,,,,,,,,,,, -522800,4.4146175,0.6171608,,,,,,,,,,,,,, -522900,4.7078357,0.5667834,,,,,,,,,,,,,, -523000,4.423865,0.60221815,,,,,,,,,,,,,, -523100,4.682122,0.6438405,,,,,,,,,,,,,, -523200,4.460134,0.6447329,,,,,,,,,,,,,, -523300,4.4086556,0.6287761,,,,,,,,,,,,,, -523400,4.4000206,0.5714128,,,,,,,,,,,,,, -523500,4.5121922,0.6691716,,,,,,,,,,,,,, -523600,4.1176343,0.54260635,,,,,,,,,,,,,, -523700,4.768423,0.6990168,,,,,,,,,,,,,, -523800,4.406682,0.64835256,,,,,,,,,,,,,, -523900,4.8024397,0.62453854,,,,,,,,,,,,,, -524000,5.0979867,0.6523471,,,,,,,,,,,,,, -524023,,,0.960558831691742,0.1465355008840561,0.755299985408783,1.0426671504974363,50000.0,0.6313000321388245,1.819411039352417,10000.0,177032.21652126312,183164.7318851948,177032.21652126312,6086.252090454102,28.246938705444336,0.0 -524100,4.846426,0.689172,,,,,,,,,,,,,, -524200,4.578257,0.63505155,,,,,,,,,,,,,, -524300,4.657677,0.6404226,,,,,,,,,,,,,, -524400,4.7412596,0.60881096,,,,,,,,,,,,,, -524500,4.8743777,0.6186805,,,,,,,,,,,,,, -524600,4.959366,0.66016567,,,,,,,,,,,,,, -524700,4.33423,0.5645815,,,,,,,,,,,,,, -524800,4.6423707,0.6239413,,,,,,,,,,,,,, -524900,4.5169897,0.6239135,,,,,,,,,,,,,, -525000,4.720837,0.6197128,,,,,,,,,,,,,, -525100,4.558533,0.6783408,,,,,,,,,,,,,, -525200,4.5526476,0.6079658,,,,,,,,,,,,,, -525300,5.212771,0.6755147,,,,,,,,,,,,,, -525400,4.3947325,0.6714854,,,,,,,,,,,,,, -525500,5.154385,0.65603256,,,,,,,,,,,,,, -525533,,,0.9598612785339355,0.1497280746698379,0.7552199959754944,1.044359564781189,50000.0,0.6310000419616699,1.823042869567871,10000.0,177542.2542464733,183691.81848621368,177542.2542464733,6103.131259441376,28.36412715911865,0.0 -525600,4.691498,0.62115383,,,,,,,,,,,,,, -525700,4.4344544,0.65222865,,,,,,,,,,,,,, -525800,4.967546,0.6332831,,,,,,,,,,,,,, -525900,4.075042,0.58559597,,,,,,,,,,,,,, -526000,4.3295574,0.62631595,,,,,,,,,,,,,, -526100,3.9651203,0.5415954,,,,,,,,,,,,,, -526200,4.6766605,0.6819427,,,,,,,,,,,,,, -526300,4.032012,0.5423802,,,,,,,,,,,,,, -526400,4.6042128,0.646532,,,,,,,,,,,,,, -526500,4.6923213,0.6433243,,,,,,,,,,,,,, -526600,4.5426273,0.65239656,,,,,,,,,,,,,, -526700,4.895683,0.6789637,,,,,,,,,,,,,, -526800,4.2783694,0.6346949,,,,,,,,,,,,,, -526900,4.289287,0.613994,,,,,,,,,,,,,, -527000,4.5379977,0.63675785,,,,,,,,,,,,,, -527043,,,0.9602000713348388,0.1480627954006195,0.7551400065422058,1.043128252029419,50000.0,0.6318000555038452,1.8209943771362305,10000.0,178052.15637159348,184218.59387516967,178052.15637159348,6119.841101884842,28.47494602203369,0.0 -527100,4.6051517,0.64924115,,,,,,,,,,,,,, -527200,4.3500485,0.5783401,,,,,,,,,,,,,, -527300,5.5469003,0.6895857,,,,,,,,,,,,,, -527400,4.272386,0.5674201,,,,,,,,,,,,,, -527500,4.7394686,0.6970634,,,,,,,,,,,,,, -527600,4.646131,0.6026715,,,,,,,,,,,,,, -527700,4.2193465,0.59074825,,,,,,,,,,,,,, -527800,4.808264,0.7191129,,,,,,,,,,,,,, -527900,4.37965,0.5815869,,,,,,,,,,,,,, -528000,4.68081,0.59969753,,,,,,,,,,,,,, -528100,4.475035,0.6973356,,,,,,,,,,,,,, -528200,4.313473,0.60663724,,,,,,,,,,,,,, -528300,4.6827936,0.70224893,,,,,,,,,,,,,, -528400,4.496797,0.72146624,,,,,,,,,,,,,, -528500,4.18569,0.56280226,,,,,,,,,,,,,, -528553,,,0.9608777165412904,0.1473826766014099,0.7552799582481384,1.0443631410598757,50000.0,0.6310000419616699,1.8213881254196167,10000.0,178562.26889562607,184745.6449303627,178562.26889562607,6136.617638349533,28.58440470695496,0.0 -528600,4.6955953,0.6210582,,,,,,,,,,,,,, -528700,5.269428,0.6908631,,,,,,,,,,,,,, -528800,4.323285,0.6150052,,,,,,,,,,,,,, -528900,4.044313,0.52908206,,,,,,,,,,,,,, -529000,4.9529433,0.6057992,,,,,,,,,,,,,, -529100,4.4691186,0.5782228,,,,,,,,,,,,,, -529200,4.632244,0.6873469,,,,,,,,,,,,,, -529300,4.4025893,0.6582315,,,,,,,,,,,,,, -529400,5.1214733,0.7230855,,,,,,,,,,,,,, -529500,4.961075,0.6242175,,,,,,,,,,,,,, -529600,4.589529,0.6351631,,,,,,,,,,,,,, -529700,4.2954,0.6048167,,,,,,,,,,,,,, -529800,4.925253,0.60608023,,,,,,,,,,,,,, -529900,4.4677477,0.5606127,,,,,,,,,,,,,, -530000,4.5601664,0.6372273,,,,,,,,,,,,,, -530063,,,0.9594626426696776,0.1503868401050567,0.755079984664917,1.0445592403411863,50000.0,0.6301000118255615,1.8239092826843264,10000.0,179072.21165585518,185272.78962039948,179072.21165585518,6153.652619361877,28.698846340179443,0.0 -530100,4.3862667,0.5359031,,,,,,,,,,,,,, -530200,4.425745,0.60700804,,,,,,,,,,,,,, -530300,4.7583795,0.6831579,,,,,,,,,,,,,, -530400,4.703573,0.64925873,,,,,,,,,,,,,, -530500,4.358976,0.62640595,,,,,,,,,,,,,, -530600,4.3936505,0.62135166,,,,,,,,,,,,,, -530700,4.403534,0.60689026,,,,,,,,,,,,,, -530800,4.25025,0.5806549,,,,,,,,,,,,,, -530900,4.6889596,0.63172495,,,,,,,,,,,,,, -531000,5.3097067,0.70324415,,,,,,,,,,,,,, -531100,4.347844,0.66635525,,,,,,,,,,,,,, -531200,4.4187336,0.6079284,,,,,,,,,,,,,, -531300,4.3059487,0.6305095,,,,,,,,,,,,,, -531400,4.7709007,0.5988468,,,,,,,,,,,,,, -531500,4.278719,0.5713496,,,,,,,,,,,,,, -531573,,,0.9595224857330322,0.150401160120964,0.7553600072860718,1.0430076122283936,50000.0,0.6312000155448914,1.8203206062316888,10000.0,179582.16641449928,185799.72025728223,179582.16641449928,6170.457691907883,28.81775641441345,0.0 -531600,4.4164886,0.57304794,,,,,,,,,,,,,, -531700,4.7075367,0.58893466,,,,,,,,,,,,,, -531800,4.1499825,0.5809401,,,,,,,,,,,,,, -531900,4.91519,0.60205346,,,,,,,,,,,,,, -532000,4.9667892,0.68175805,,,,,,,,,,,,,, -532100,4.862139,0.72136796,,,,,,,,,,,,,, -532200,4.3043284,0.64228505,,,,,,,,,,,,,, -532300,4.355563,0.5599428,,,,,,,,,,,,,, -532400,4.6069646,0.68181986,,,,,,,,,,,,,, -532500,5.105235,0.5922875,,,,,,,,,,,,,, -532600,4.8771377,0.68669504,,,,,,,,,,,,,, -532700,4.5404367,0.6098602,,,,,,,,,,,,,, -532800,4.5680256,0.67247945,,,,,,,,,,,,,, -532900,4.7944174,0.604645,,,,,,,,,,,,,, -533000,4.7839894,0.6750756,,,,,,,,,,,,,, -533083,,,0.9605388641357422,0.1487494260072708,0.7551999688148499,1.043909788131714,50000.0,0.6320000290870667,1.8219304084777832,10000.0,180092.11389112473,186326.5613975525,180092.11389112473,6187.182283878326,28.934704780578613,0.0 -533100,4.719447,0.64757174,,,,,,,,,,,,,, -533200,4.2363253,0.57062554,,,,,,,,,,,,,, -533300,4.425393,0.5662776,,,,,,,,,,,,,, -533400,4.105414,0.55451906,,,,,,,,,,,,,, -533500,4.539801,0.6297448,,,,,,,,,,,,,, -533600,4.2412367,0.58514667,,,,,,,,,,,,,, -533700,4.750932,0.65114665,,,,,,,,,,,,,, -533800,4.317304,0.649467,,,,,,,,,,,,,, -533900,4.243439,0.61872286,,,,,,,,,,,,,, -534000,4.4677954,0.62105644,,,,,,,,,,,,,, -534100,4.580251,0.6037586,,,,,,,,,,,,,, -534200,4.7380657,0.68944424,,,,,,,,,,,,,, -534300,4.7413306,0.62627625,,,,,,,,,,,,,, -534400,4.987763,0.7289524,,,,,,,,,,,,,, -534500,4.772633,0.69865584,,,,,,,,,,,,,, -534594,,,0.9620535373687744,0.1427308619022369,0.7545999884605408,1.04520845413208,50000.0,0.6320000290870667,1.821923971176148,10000.0,180602.2413403988,186853.6732811928,180602.2413403988,6203.998332023621,29.05104231834412,0.0 -534600,5.2375565,0.6277745,,,,,,,,,,,,,, -534700,4.678627,0.60205317,,,,,,,,,,,,,, -534800,4.9088273,0.62923276,,,,,,,,,,,,,, -534900,4.303334,0.55897653,,,,,,,,,,,,,, -535000,4.789563,0.6377071,,,,,,,,,,,,,, -535100,4.3400526,0.59983367,,,,,,,,,,,,,, -535200,4.806511,0.6147411,,,,,,,,,,,,,, -535300,4.7425327,0.64072734,,,,,,,,,,,,,, -535400,4.591207,0.66321707,,,,,,,,,,,,,, -535500,4.6937227,0.6219874,,,,,,,,,,,,,, -535600,4.4817433,0.5678809,,,,,,,,,,,,,, -535700,4.03967,0.599152,,,,,,,,,,,,,, -535800,4.4035444,0.67054975,,,,,,,,,,,,,, -535900,4.5989323,0.6701839,,,,,,,,,,,,,, -536000,3.964889,0.59870785,,,,,,,,,,,,,, -536100,4.438947,0.62185436,,,,,,,,,,,,,, -536105,,,0.9607780575752258,0.1484987586736679,0.7554399967193604,1.0436646938323977,50000.0,0.6321000456809998,1.8213005065917969,10000.0,181112.3642761708,187381.0358531475,181112.3642761708,6221.071969509125,29.16448450088501,0.0 -536200,4.8539104,0.72104883,,,,,,,,,,,,,, -536300,4.508302,0.6264204,,,,,,,,,,,,,, -536400,4.3892107,0.59123456,,,,,,,,,,,,,, -536500,4.4596906,0.62866694,,,,,,,,,,,,,, -536600,4.984455,0.61933243,,,,,,,,,,,,,, -536700,4.3759537,0.6248852,,,,,,,,,,,,,, -536800,4.444221,0.6748247,,,,,,,,,,,,,, -536900,4.5113263,0.59327054,,,,,,,,,,,,,, -537000,4.7054105,0.62956583,,,,,,,,,,,,,, -537100,4.273857,0.5661133,,,,,,,,,,,,,, -537200,4.40348,0.6297157,,,,,,,,,,,,,, -537300,4.629532,0.63098603,,,,,,,,,,,,,, -537400,4.622714,0.6477582,,,,,,,,,,,,,, -537500,4.4907427,0.63523483,,,,,,,,,,,,,, -537600,4.232776,0.6138636,,,,,,,,,,,,,, -537615,,,0.9604392051696776,0.1466994285583496,0.7552599906921387,1.0449239015579224,50000.0,0.6301000118255615,1.8213962316513064,10000.0,181622.24652838707,187908.334836483,181622.24652838707,6238.322468757629,29.27834534645081,0.0 -537700,4.999731,0.68391293,,,,,,,,,,,,,, -537800,4.612669,0.6430665,,,,,,,,,,,,,, -537900,4.3990993,0.5673131,,,,,,,,,,,,,, -538000,4.1459794,0.65232277,,,,,,,,,,,,,, -538100,4.19598,0.68446136,,,,,,,,,,,,,, -538200,4.3763547,0.67639816,,,,,,,,,,,,,, -538300,3.996391,0.58612114,,,,,,,,,,,,,, -538400,4.4950795,0.6528176,,,,,,,,,,,,,, -538500,4.622918,0.7059078,,,,,,,,,,,,,, -538600,4.2836113,0.55785716,,,,,,,,,,,,,, -538700,4.7119017,0.6229692,,,,,,,,,,,,,, -538800,4.785951,0.706666,,,,,,,,,,,,,, -538900,4.782261,0.6563876,,,,,,,,,,,,,, -539000,5.0108614,0.61463296,,,,,,,,,,,,,, -539100,4.9144206,0.6612322,,,,,,,,,,,,,, -539125,,,0.9596021771430968,0.151600956916809,0.7555599808692932,1.043686032295227,50000.0,0.6310000419616699,1.822413206100464,10000.0,182132.1446402073,188435.3849275112,182132.1446402073,6255.310319900513,29.389899730682373,0.0 -539200,4.3369026,0.63547206,,,,,,,,,,,,,, -539300,4.3563776,0.5496691,,,,,,,,,,,,,, -539400,4.9828444,0.63136524,,,,,,,,,,,,,, -539500,4.378321,0.6277782,,,,,,,,,,,,,, -539600,4.5447016,0.5700741,,,,,,,,,,,,,, -539700,5.119678,0.65516627,,,,,,,,,,,,,, -539800,4.6418986,0.63756114,,,,,,,,,,,,,, -539900,4.4919066,0.64040816,,,,,,,,,,,,,, -540000,4.5917373,0.69893473,,,,,,,,,,,,,, -540100,4.89608,0.6016893,,,,,,,,,,,,,, -540200,4.4652753,0.65791947,,,,,,,,,,,,,, -540300,4.6258187,0.64653844,,,,,,,,,,,,,, -540400,4.6652474,0.59970987,,,,,,,,,,,,,, -540500,4.520625,0.6064034,,,,,,,,,,,,,, -540600,4.596298,0.700099,,,,,,,,,,,,,, -540635,,,0.9598612785339355,0.1493533551692962,0.7551400065422058,1.0445722341537476,50000.0,0.6308000087738037,1.8208487033844,10000.0,182642.13536715508,188962.29661369324,182642.13536715508,6272.058003902435,29.50988364219665,0.0 -540700,4.3655586,0.58310276,,,,,,,,,,,,,, -540800,4.2841673,0.630983,,,,,,,,,,,,,, -540900,4.6499634,0.62834793,,,,,,,,,,,,,, -541000,4.6162553,0.67262155,,,,,,,,,,,,,, -541100,4.079356,0.5825615,,,,,,,,,,,,,, -541200,4.1783237,0.6213434,,,,,,,,,,,,,, -541300,4.17933,0.5750877,,,,,,,,,,,,,, -541400,4.6565056,0.6190361,,,,,,,,,,,,,, -541500,4.536057,0.6144148,,,,,,,,,,,,,, -541600,4.3560824,0.6174489,,,,,,,,,,,,,, -541700,4.3428845,0.6905096,,,,,,,,,,,,,, -541800,5.0817847,0.61397505,,,,,,,,,,,,,, -541900,5.3310356,0.6651479,,,,,,,,,,,,,, -542000,4.6638103,0.60215664,,,,,,,,,,,,,, -542100,4.556051,0.6772839,,,,,,,,,,,,,, -542146,,,0.9615154266357422,0.1449406147003173,0.7549200057983398,1.044356346130371,50000.0,0.6310000419616699,1.822114944458008,10000.0,183152.2509188652,189489.21495723724,183152.2509188652,6288.69265794754,29.62590265274048,0.0 -542200,4.2450237,0.553712,,,,,,,,,,,,,, -542300,4.016692,0.54414177,,,,,,,,,,,,,, -542400,4.441879,0.5882672,,,,,,,,,,,,,, -542500,4.785393,0.65169704,,,,,,,,,,,,,, -542600,4.6059833,0.6256763,,,,,,,,,,,,,, -542700,5.418835,0.7352946,,,,,,,,,,,,,, -542800,4.423035,0.58051884,,,,,,,,,,,,,, -542900,4.527285,0.5685313,,,,,,,,,,,,,, -543000,4.4120717,0.6327895,,,,,,,,,,,,,, -543100,4.6802583,0.6640987,,,,,,,,,,,,,, -543200,4.905261,0.6047251,,,,,,,,,,,,,, -543300,4.1725187,0.62107784,,,,,,,,,,,,,, -543400,4.4893723,0.61136687,,,,,,,,,,,,,, -543500,4.59102,0.67511374,,,,,,,,,,,,,, -543600,4.326304,0.54617137,,,,,,,,,,,,,, -543657,,,0.9601402878761292,0.1462753862142563,0.7551400065422058,1.0432575941085815,50000.0,0.6319000124931335,1.8207831382751465,10000.0,183662.38544940948,190016.3076210022,183662.38544940948,6305.479026794434,29.74439120292664,0.0 -543700,4.5549054,0.5998466,,,,,,,,,,,,,, -543800,4.325037,0.6241923,,,,,,,,,,,,,, -543900,5.0680566,0.6121184,,,,,,,,,,,,,, -544000,4.4168563,0.67751217,,,,,,,,,,,,,, -544100,4.483331,0.582512,,,,,,,,,,,,,, -544200,4.798765,0.64206797,,,,,,,,,,,,,, -544300,4.726019,0.63119435,,,,,,,,,,,,,, -544400,4.5450106,0.65439683,,,,,,,,,,,,,, -544500,4.410913,0.6441493,,,,,,,,,,,,,, -544600,4.67161,0.6196524,,,,,,,,,,,,,, -544700,4.8554473,0.6205667,,,,,,,,,,,,,, -544800,4.5434494,0.57133377,,,,,,,,,,,,,, -544900,4.70013,0.5996638,,,,,,,,,,,,,, -545000,5.4918694,0.7156036,,,,,,,,,,,,,, -545100,4.5972023,0.6218029,,,,,,,,,,,,,, -545168,,,0.96000075340271,0.1489509791135788,0.7549999952316284,1.044825792312622,50000.0,0.631100058555603,1.821388959884644,10000.0,184172.5107626915,190543.8264966011,184172.5107626915,6322.706331968308,29.857580184936523,0.0 -545200,4.265565,0.60147613,,,,,,,,,,,,,, -545300,4.0393405,0.61276233,,,,,,,,,,,,,, -545400,4.5675178,0.5956713,,,,,,,,,,,,,, -545500,5.506323,0.74564147,,,,,,,,,,,,,, -545600,4.214225,0.6448597,,,,,,,,,,,,,, -545700,4.738776,0.65674543,,,,,,,,,,,,,, -545800,4.433078,0.6018467,,,,,,,,,,,,,, -545900,4.4265037,0.5979388,,,,,,,,,,,,,, -546000,4.484163,0.6105263,,,,,,,,,,,,,, -546100,4.2499795,0.6110536,,,,,,,,,,,,,, -546200,5.081788,0.69193244,,,,,,,,,,,,,, -546300,4.7579966,0.6259045,,,,,,,,,,,,,, -546400,4.735447,0.6458742,,,,,,,,,,,,,, -546500,4.7613997,0.5970844,,,,,,,,,,,,,, -546600,4.2622523,0.58534026,,,,,,,,,,,,,, -546679,,,0.9606783986091614,0.1472610086202621,0.7549999952316284,1.043584942817688,50000.0,0.6315000057220459,1.821320414543152,10000.0,184682.5400466919,191070.85496354103,184682.5400466919,6339.538756847382,29.97205090522766,0.0 -546700,4.278649,0.61397827,,,,,,,,,,,,,, -546800,4.778564,0.6299236,,,,,,,,,,,,,, -546900,5.528581,0.6276895,,,,,,,,,,,,,, -547000,4.149754,0.59474224,,,,,,,,,,,,,, -547100,4.4298162,0.60451204,,,,,,,,,,,,,, -547200,4.786935,0.6107811,,,,,,,,,,,,,, -547300,4.1710253,0.52552086,,,,,,,,,,,,,, -547400,4.729897,0.6476226,,,,,,,,,,,,,, -547500,4.6627836,0.6745863,,,,,,,,,,,,,, -547600,4.457886,0.6521011,,,,,,,,,,,,,, -547700,5.0434747,0.6603428,,,,,,,,,,,,,, -547800,4.4895964,0.63540405,,,,,,,,,,,,,, -547900,5.0508075,0.6337827,,,,,,,,,,,,,, -548000,4.4456034,0.61634445,,,,,,,,,,,,,, -548100,4.626519,0.6351892,,,,,,,,,,,,,, -548189,,,0.959741711616516,0.1502102166414261,0.7551800012588501,1.0438426733016968,50000.0,0.6309000253677368,1.82038676738739,10000.0,185192.69767308235,191597.8956928253,185192.69767308235,6356.251218318939,30.08995032310486,0.0 -548200,4.6549845,0.613887,,,,,,,,,,,,,, -548300,4.403327,0.6726202,,,,,,,,,,,,,, -548400,4.5657907,0.6053965,,,,,,,,,,,,,, -548500,4.317863,0.6151088,,,,,,,,,,,,,, -548600,4.274548,0.662699,,,,,,,,,,,,,, -548700,4.503199,0.602823,,,,,,,,,,,,,, -548800,4.380048,0.587214,,,,,,,,,,,,,, -548900,4.631894,0.5975702,,,,,,,,,,,,,, -549000,5.076034,0.67690843,,,,,,,,,,,,,, -549100,4.239797,0.56657374,,,,,,,,,,,,,, -549200,4.562628,0.67769235,,,,,,,,,,,,,, -549300,4.7685676,0.7315896,,,,,,,,,,,,,, -549400,4.4903708,0.6208683,,,,,,,,,,,,,, -549500,4.3178596,0.54377127,,,,,,,,,,,,,, -549600,4.21827,0.61634934,,,,,,,,,,,,,, -549699,,,0.9620535373687744,0.1419393718242645,0.7547799944877625,1.043861746788025,50000.0,0.6307000517845154,1.8231457471847528,10000.0,185702.6072921753,192124.9710338116,185702.6072921753,6373.249268293381,30.20629596710205,0.0 -549700,4.4909873,0.59108514,,,,,,,,,,,,,, -549800,4.546327,0.6029011,,,,,,,,,,,,,, -549900,4.70209,0.70652586,,,,,,,,,,,,,, -550000,4.527167,0.68515456,,,,,,,,,,,,,, -550100,4.2765036,0.6183989,,,,,,,,,,,,,, -550200,4.547043,0.6429858,,,,,,,,,,,,,, -550300,4.151099,0.64153,,,,,,,,,,,,,, -550400,4.3729687,0.5371605,,,,,,,,,,,,,, -550500,4.609381,0.62097585,,,,,,,,,,,,,, -550600,4.567567,0.63076043,,,,,,,,,,,,,, -550700,4.532122,0.60906047,,,,,,,,,,,,,, -550800,4.2140846,0.5903391,,,,,,,,,,,,,, -550900,4.6443286,0.6506635,,,,,,,,,,,,,, -551000,4.4244366,0.5783208,,,,,,,,,,,,,, -551100,4.511647,0.5999435,,,,,,,,,,,,,, -551200,4.5639105,0.62066895,,,,,,,,,,,,,, -551209,,,0.9608178734779358,0.144424170255661,0.7553199529647827,1.0437524318695068,50000.0,0.6319000124931335,1.823291301727295,10000.0,186212.7490684986,192652.24084949493,186212.7490684986,6390.202693462372,30.328657150268555,0.0 -551300,4.3108544,0.5991353,,,,,,,,,,,,,, -551400,4.843742,0.66307205,,,,,,,,,,,,,, -551500,4.2855487,0.6285384,,,,,,,,,,,,,, -551600,4.398076,0.563147,,,,,,,,,,,,,, -551700,4.4708366,0.68288815,,,,,,,,,,,,,, -551800,4.7238665,0.59903395,,,,,,,,,,,,,, -551900,4.1517553,0.5884427,,,,,,,,,,,,,, -552000,4.862149,0.60923666,,,,,,,,,,,,,, -552100,4.9162774,0.6359803,,,,,,,,,,,,,, -552200,4.4983697,0.59224284,,,,,,,,,,,,,, -552300,4.333548,0.6225788,,,,,,,,,,,,,, -552400,4.8719835,0.5970786,,,,,,,,,,,,,, -552500,4.091266,0.60805905,,,,,,,,,,,,,, -552600,4.4134903,0.69150764,,,,,,,,,,,,,, -552700,4.430405,0.59975535,,,,,,,,,,,,,, -552719,,,0.96097731590271,0.1480630040168762,0.7549999952316284,1.0442960262298584,50000.0,0.6306000351905823,1.8228309154510496,10000.0,186722.60253715515,193179.3252959252,186722.60253715515,6407.26146531105,30.44870686531067,0.0 -552800,4.4171696,0.6064124,,,,,,,,,,,,,, -552900,4.3213706,0.6098513,,,,,,,,,,,,,, -553000,4.656624,0.6330942,,,,,,,,,,,,,, -553100,4.287061,0.61890906,,,,,,,,,,,,,, -553200,4.303845,0.5502897,,,,,,,,,,,,,, -553300,4.745962,0.6698539,,,,,,,,,,,,,, -553400,4.507776,0.66933036,,,,,,,,,,,,,, -553500,4.2542334,0.54714483,,,,,,,,,,,,,, -553600,4.4725804,0.6316346,,,,,,,,,,,,,, -553700,4.7319155,0.7104292,,,,,,,,,,,,,, -553800,4.933739,0.5928114,,,,,,,,,,,,,, -553900,4.1893106,0.60400593,,,,,,,,,,,,,, -554000,4.64249,0.5706653,,,,,,,,,,,,,, -554100,4.4775515,0.6130785,,,,,,,,,,,,,, -554200,4.4102507,0.6430321,,,,,,,,,,,,,, -554229,,,0.9595423936843872,0.1502232849597931,0.7549799680709839,1.0441011190414429,50000.0,0.6302000284194946,1.8223854303359983,10000.0,187232.4409868717,193706.154392004,187232.4409868717,6424.07945227623,30.56987929344177,0.0 -554300,4.308437,0.5686606,,,,,,,,,,,,,, -554400,4.468685,0.65445715,,,,,,,,,,,,,, -554500,4.3557673,0.61108243,,,,,,,,,,,,,, -554600,4.500843,0.5930164,,,,,,,,,,,,,, -554700,4.3742657,0.5772315,,,,,,,,,,,,,, -554800,4.6089253,0.62062395,,,,,,,,,,,,,, -554900,4.733728,0.60515964,,,,,,,,,,,,,, -555000,4.6008053,0.6392722,,,,,,,,,,,,,, -555100,4.4186206,0.576794,,,,,,,,,,,,,, -555200,4.970551,0.64553833,,,,,,,,,,,,,, -555300,5.135438,0.67193305,,,,,,,,,,,,,, -555400,4.272008,0.5894612,,,,,,,,,,,,,, -555500,4.575135,0.66939664,,,,,,,,,,,,,, -555600,4.2131076,0.6185876,,,,,,,,,,,,,, -555700,4.183729,0.63640386,,,,,,,,,,,,,, -555740,,,0.9612165093421936,0.1473574638366699,0.7553600072860718,1.0433720350265503,50000.0,0.631100058555603,1.8212271928787231,10000.0,187742.46044564247,194233.10874533653,187742.46044564247,6440.839239120483,30.693721771240234,0.0 -555800,4.5231276,0.62743896,,,,,,,,,,,,,, -555900,4.932854,0.6588337,,,,,,,,,,,,,, -556000,4.44893,0.59261465,,,,,,,,,,,,,, -556100,4.8432975,0.63979584,,,,,,,,,,,,,, -556200,4.428829,0.59840804,,,,,,,,,,,,,, -556300,4.4301963,0.65954435,,,,,,,,,,,,,, -556400,4.5416827,0.6217204,,,,,,,,,,,,,, -556500,4.553216,0.64496815,,,,,,,,,,,,,, -556600,4.4617667,0.6812351,,,,,,,,,,,,,, -556700,4.6915827,0.6741439,,,,,,,,,,,,,, -556800,4.2644534,0.578848,,,,,,,,,,,,,, -556900,4.744184,0.6145704,,,,,,,,,,,,,, -557000,4.9302173,0.6075108,,,,,,,,,,,,,, -557100,4.487297,0.59115887,,,,,,,,,,,,,, -557200,4.4950643,0.70086867,,,,,,,,,,,,,, -557250,,,0.9599011540412904,0.1505948454141616,0.755299985408783,1.0440723896026611,50000.0,0.6312000155448914,1.8224512338638303,10000.0,188252.3400375843,194760.04110717773,188252.3400375843,6457.721947193146,30.81244969367981,0.0 -557300,4.3311462,0.6288482,,,,,,,,,,,,,, -557400,4.358329,0.5911141,,,,,,,,,,,,,, -557500,4.5014606,0.6022625,,,,,,,,,,,,,, -557600,4.4997287,0.5755296,,,,,,,,,,,,,, -557700,4.3985868,0.637233,,,,,,,,,,,,,, -557800,4.2427487,0.6499346,,,,,,,,,,,,,, -557900,4.455419,0.6852443,,,,,,,,,,,,,, -558000,4.6205544,0.6474652,,,,,,,,,,,,,, -558100,4.1052165,0.5759109,,,,,,,,,,,,,, -558200,4.242577,0.60083723,,,,,,,,,,,,,, -558300,4.2998405,0.6179587,,,,,,,,,,,,,, -558400,4.042292,0.65979505,,,,,,,,,,,,,, -558500,4.4263787,0.6274431,,,,,,,,,,,,,, -558600,4.373238,0.54577047,,,,,,,,,,,,,, -558700,4.826638,0.6621222,,,,,,,,,,,,,, -558760,,,0.95804762840271,0.1525946706533432,0.7549200057983398,1.043939232826233,50000.0,0.6308000087738037,1.822641134262085,10000.0,188762.30083870888,195287.0536873341,188762.30083870888,6474.599181890488,30.935455560684204,0.0 -558800,4.4396105,0.61523163,,,,,,,,,,,,,, -558900,4.5793066,0.63558376,,,,,,,,,,,,,, -559000,4.8191295,0.62717134,,,,,,,,,,,,,, -559100,4.1544566,0.56601137,,,,,,,,,,,,,, -559200,4.4340425,0.5920622,,,,,,,,,,,,,, -559300,4.5407043,0.60989577,,,,,,,,,,,,,, -559400,4.4140086,0.6313203,,,,,,,,,,,,,, -559500,4.888227,0.7174549,,,,,,,,,,,,,, -559536,,,,,,,,,,,189024.18581700325,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 6e4867da7..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,555 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -39.74869656562805,0.0,42.4552264213562,1,0,42.4552264213562,0.0010000000474974,6.907756805419922,10000,82.20404887199402,0.0009374999790452,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -61.21135711669922,0.0267481803894042,462.6057026386261,916,0,462.6057026386261,0.0265000015497207,6.053519248962402,10000,523.893707036972,0.033281248062849,5.910701751708984,0.0322199985384941,5.930098056793213,50000 -82.97681331634521,0.0551979541778564,882.6500160694122,1880,0,882.6500160694122,0.0578000023961067,5.46556282043457,10000,965.783664226532,0.0802929699420929,5.186488628387451,0.0763599947094917,5.225624084472656,50000 -104.62666702270508,0.0829133987426757,1302.63339304924,2845,0,1302.63339304924,0.1098000034689903,4.911513328552246,10000,1407.4959456920624,0.1541601568460464,4.475606441497803,0.1417199969291687,4.5772247314453125,50000 -126.36748099327087,0.1124176979064941,1722.6479833126068,3809,0,1722.6479833126068,0.1480000019073486,4.524412155151367,10000,1849.3324444293976,0.2120117098093032,3.999783754348755,0.1994999945163726,4.077298641204834,50000 -148.49134397506714,0.1440119743347168,2142.643502473831,4771,0,2142.643502473831,0.1887000054121017,4.161242961883545,10000,2291.534899711609,0.2714648246765136,3.5543103218078613,0.2515200078487396,3.682871580123901,50000 -170.15383672714233,0.1743319034576416,2562.6839604377747,5732,0,2562.6839604377747,0.2252000123262405,3.929141998291016,10000,2733.3190398216248,0.3236913979053497,3.2366814613342285,0.2985199987888336,3.391303062438965,50000 -192.07357716560364,0.2041935920715332,2982.868365049362,6691,0,2982.868365049362,0.2591000199317932,3.685933828353882,10000,3175.5051844120026,0.3828320205211639,2.853161573410034,0.3342599868774414,3.135330438613892,50000 -217.67090773582456,0.2334585189819336,3403.056977033615,7649,0,3403.056977033615,0.2818000018596649,3.512402057647705,10000,3621.3701384067535,0.3971289098262787,2.773566722869873,0.3694399893283844,2.9177207946777344,50000 -249.9872395992279,0.2634403705596924,3823.41290307045,8609,0,3823.41290307045,0.300100028514862,3.4087908267974854,10000,4074.122909069061,0.4228320121765136,2.649387121200561,0.3889999985694885,2.813301086425781,50000 -275.773681640625,0.296008825302124,4243.586100816727,9570,0,4243.586100816727,0.3128000199794769,3.3069517612457275,10000,4520.165455341339,0.4484570324420929,2.475047826766968,0.4088599979877472,2.6899614334106445,50000 -301.3532176017761,0.3298141956329345,4663.614185810089,10527,0,4663.614185810089,0.323600023984909,3.2185275554656982,10000,4965.85725569725,0.4597460925579071,2.403148651123047,0.4257799983024597,2.5876495838165283,50000 -329.1883616447449,0.3632173538208008,5083.6017208099365,11485,0,5083.6017208099365,0.3461000025272369,3.1271629333496094,10000,5413.76374578476,0.4774218499660492,2.352294921875,0.4411799907684326,2.527583599090576,50000 -357.07783699035645,0.3932712078094482,5503.675626039505,12442,0,5503.675626039505,0.3476000130176544,3.0680062770843506,10000,5861.807518482208,0.5004491806030273,2.173110246658325,0.4589599967002868,2.39906644821167,50000 -385.8735525608063,0.4236998558044433,5923.727476358414,13395,0,5923.727476358414,0.3655000030994415,3.0263166427612305,10000,6310.735496520996,0.5154492259025574,2.150564432144165,0.4667999744415283,2.3906404972076416,50000 -416.6871347427368,0.4599125385284424,6343.8371758461,14346,0,6343.8371758461,0.3747000098228454,2.943275213241577,10000,6761.744855165482,0.519238293170929,2.1266376972198486,0.480239987373352,2.312174320220948,50000 -446.9011867046356,0.4952383041381836,6764.088172674179,15295,0,6764.088172674179,0.3838000297546386,2.8951008319854736,10000,7212.294851779938,0.5322265625,2.005976915359497,0.4880199730396271,2.22807240486145,50000 -476.6402099132538,0.5272412300109863,7184.348518371582,16244,0,7184.348518371582,0.3926000297069549,2.8501899242401123,10000,7662.375813245773,0.5482421517372131,1.9572895765304563,0.501259982585907,2.1907479763031006,50000 -507.73856592178345,0.5607478618621826,7604.298126220703,17194,0,7604.298126220703,0.38960000872612,2.829267263412476,10000,8113.509242534637,0.5625194907188416,1.8765634298324585,0.5061399936676025,2.151052713394165,50000 -539.5802881717682,0.6070611476898193,8024.374304771423,18141,0,8024.374304771423,0.3961000144481659,2.798175811767578,10000,8565.522846698761,0.5486718416213989,1.961691856384277,0.5062800049781799,2.1555776596069336,50000 -571.2340226173401,0.6397192478179932,8444.733457803726,19089,0,8444.733457803726,0.4065000116825104,2.7337896823883057,10000,9017.617252111437,0.5634179711341858,1.8526079654693604,0.5212599635124207,2.0621883869171143,50000 -602.41819190979,0.6724779605865479,8864.943633794785,20038,0,8864.943633794785,0.414900004863739,2.704115629196167,10000,9469.092945098875,0.5821679830551147,1.7724109888076782,0.5266799926757812,2.037682056427002,50000 -633.9534032344818,0.7069294452667236,9285.068341493608,20987,0,9285.068341493608,0.4214000105857849,2.6908113956451416,10000,9920.836732149124,0.5732030868530273,1.835625052452088,0.5278199911117554,2.045748949050904,50000 -665.4959726333618,0.7394287586212158,9705.1229493618,21934,0,9705.1229493618,0.4229000210762024,2.6369969844818115,10000,10372.515660524368,0.5780468583106995,1.77647066116333,0.5388399958610535,1.979406476020813,50000 -698.3626811504364,0.7748537063598633,10125.126637935638,22875,0,10125.126637935638,0.4220000207424164,2.671059846878052,10000,10825.470078229904,0.5887890458106995,1.7585443258285522,0.5386399626731873,2.00494122505188,50000 -731.4255640506744,0.802415132522583,10545.515405893326,23826,0,10545.515405893326,0.4359000325202942,2.585143566131592,10000,11278.99824810028,0.623339831829071,1.555069088935852,0.5521999597549438,1.9062819480896,50000 -763.2415297031403,0.8369770050048828,10965.437469005585,24775,0,10965.437469005585,0.4387000203132629,2.591357707977295,10000,11730.820744037628,0.5954492092132568,1.7052239179611206,0.5511800050735474,1.9209455251693728,50000 -794.7530353069305,0.8722774982452393,11385.874091625214,25720,0,11385.874091625214,0.4342000186443329,2.559943437576294,10000,12182.853080272676,0.6026171445846558,1.657276272773743,0.5577200055122375,1.8794413805007928,50000 -826.1598796844482,2.062011241912842,11804.853033542631,26657,0,11804.853033542631,0.4451000094413757,2.543614387512207,10000,12634.47670030594,0.6150780916213989,1.6037347316741943,0.5588799715042114,1.8744468688964844,50000 -858.7074551582336,2.100292921066284,12224.784992218018,27602,0,12224.784992218018,0.4532000124454498,2.5086917877197266,10000,13087.04343366623,0.6069140434265137,1.6219347715377808,0.5689600110054016,1.815601825714112,50000 -891.6379976272583,2.135059118270874,12645.118678808212,28549,0,12645.118678808212,0.449500024318695,2.522683620452881,10000,13540.391645908356,0.6090429425239563,1.6417275667190552,0.5631799697875977,1.8570791482925413,50000 -921.1676602363586,2.163169860839844,13065.383793115616,29497,0,13065.383793115616,0.4493000209331512,2.5253915786743164,10000,13990.263841152191,0.6200780868530273,1.5937964916229248,0.5712400078773499,1.8428938388824463,50000 -954.7928731441498,2.198180913925171,13485.3846244812,30441,0,13485.3846244812,0.4524000287055969,2.4883368015289307,10000,14443.974184513092,0.646484375,1.4701223373413086,0.5729599595069885,1.815598130226136,50000 -986.5768990516664,2.23132872581482,13905.559622764587,31388,0,13905.559622764587,0.4555000364780426,2.473357677459717,10000,14896.015183925629,0.6192382574081421,1.590422749519348,0.5780199766159058,1.7954529523849487,50000 -1021.1722357273102,2.2684500217437744,14325.485937595367,32334,0,14325.485937595367,0.4574000239372253,2.485887289047241,10000,15350.62278676033,0.6196874976158142,1.594934344291687,0.5723400115966797,1.8144795894622805,50000 -1052.1730644702911,2.2977893352508545,14745.835205078123,33281,0,14745.835205078123,0.4593000113964081,2.459869623184204,10000,15802.050520658491,0.6363476514816284,1.4893057346343994,0.5823799967765808,1.7727079391479492,50000 -1083.7390191555023,2.3401732444763184,15165.793439865112,34224,0,15165.793439865112,0.4670000076293945,2.411027908325196,10000,16253.66602897644,0.6278319954872131,1.517386555671692,0.5886200070381165,1.7128303050994873,50000 -1115.3563861846924,2.374427318572998,15585.799923658373,35169,0,15585.799923658373,0.4645000100135803,2.427992582321167,10000,16705.374019622803,0.6328905820846558,1.534470796585083,0.5838599801063538,1.7675602436065674,50000 -1147.792130947113,2.408688545227051,16005.73632979393,36114,0,16005.73632979393,0.4657000303268432,2.405973196029663,10000,17157.829701662064,0.6386132836341858,1.485224366188049,0.585099995136261,1.736467719078064,50000 -1180.2964413166046,2.449498414993286,16425.699372053146,37060,0,16425.699372053146,0.4737000167369842,2.4095163345336914,10000,17610.38723230362,0.6688476204872131,1.3819684982299805,0.5868600010871887,1.7398897409439087,50000 -1210.706288814545,2.4802818298339844,16845.78829050064,38005,0,16845.78829050064,0.4795000255107879,2.360228538513184,10000,18060.96590399742,0.639355480670929,1.477370262145996,0.592960000038147,1.6992011070251465,50000 -1241.442130804062,2.518287420272827,17265.709993124008,38949,0,17265.709993124008,0.4771000146865845,2.336911916732788,10000,18511.710064649586,0.6482812166213989,1.4299132823944092,0.601099967956543,1.6725398302078247,50000 -1272.2637612819672,2.5543577671051025,17685.679537296295,39893,0,17685.679537296295,0.4825000166893005,2.327033042907715,10000,18962.58662962913,0.6584374904632568,1.4017064571380615,0.6007599830627441,1.6774709224700928,50000 -1303.868986368179,2.595737934112549,18105.65966153145,40837,0,18105.65966153145,0.4824000298976898,2.323774814605713,10000,19414.26227951049,0.6469140648841858,1.4363796710968018,0.6001799702644348,1.6573528051376345,50000 -1334.4877030849457,2.63551926612854,18525.777707338333,41783,0,18525.777707338333,0.4859000146389007,2.331721544265747,10000,19865.08677005768,0.6481835842132568,1.4439572095870972,0.5995799899101257,1.6738475561141968,50000 -1365.0891120433807,2.6732630729675293,18945.784809350967,42727,0,18945.784809350967,0.4874000251293182,2.338141202926636,10000,20315.787051200867,0.6622461080551147,1.4292950630187988,0.6039800047874451,1.6770350933074951,50000 -1395.6237680912018,2.708239316940308,19365.801652669907,43675,0,19365.801652669907,0.4832000136375427,2.350550413131714,10000,20766.421213150024,0.6868749856948853,1.3137249946594238,0.6068400144577026,1.681557297706604,50000 -1426.566675901413,2.75048303604126,19785.997804403305,44621,0,19785.997804403305,0.4914000332355499,2.3031435012817383,10000,21217.65069699288,0.6566405892372131,1.4093447923660278,0.6092199683189392,1.6448744535446167,50000 -1457.8632283210754,2.787197589874268,20206.22102165222,45569,0,20206.22102165222,0.4869000315666199,2.2931628227233887,10000,21669.25616335869,0.6633593440055847,1.3881406784057615,0.6101799607276917,1.6343706846237185,50000 -1492.0283637046814,2.8271169662475586,20626.519107580185,46517,0,20626.519107580185,0.486700028181076,2.2925844192504883,10000,22123.807602643967,0.6712695360183716,1.3291295766830444,0.6121799945831299,1.6179096698760986,50000 -1523.521923303604,2.8597183227539062,21046.87879395485,47466,0,21046.87879395485,0.491100013256073,2.26879358291626,10000,22575.74212741852,0.6614062190055847,1.4059300422668457,0.6163600087165833,1.6147215366363523,50000 -1552.896357297897,2.897174119949341,21466.973664999008,48415,0,21466.973664999008,0.4886000156402588,2.3144147396087646,10000,23025.29788851738,0.6602538824081421,1.423980951309204,0.6101399660110474,1.6533193588256836,50000 -1584.982391357422,2.935287952423096,21887.021904945374,49359,0,21887.021904945374,0.4917000234127044,2.266906261444092,10000,23477.518564224243,0.6670702695846558,1.3476743698120115,0.6132599711418152,1.6010985374450684,50000 -1614.7766954898834,2.9790749549865723,22307.08506894112,50305,0,22307.08506894112,0.4980000257492065,2.233330488204956,10000,23927.468125104904,0.6949023008346558,1.217258095741272,0.6212999820709229,1.570192813873291,50000 -1648.4719729423523,3.016878128051758,22727.23337483406,51252,0,22727.23337483406,0.4948000311851501,2.2380619049072266,10000,24381.398369073868,0.6680468320846558,1.355788230895996,0.6198599934577942,1.5720373392105105,50000 -1681.53049826622,3.056202173233032,23147.45352268219,52203,0,23147.45352268219,0.495600014925003,2.272915840148926,10000,24834.76548576355,0.6677343845367432,1.3512169122695925,0.6184799671173096,1.5838528871536257,50000 -1712.6602900028229,3.0928637981414795,23567.636283397675,53154,0,23567.636283397675,0.5009000301361084,2.195390224456787,10000,25286.163171052933,0.6862499713897705,1.2557148933410645,0.6257599592208862,1.5392677783966064,50000 -1742.2668850421906,3.1326940059661865,23987.96748137474,54101,0,23987.96748137474,0.5051000118255615,2.196848630905152,10000,25736.189141988754,0.6750195026397705,1.320992112159729,0.62527996301651,1.5474814176559448,50000 -1775.3276269435885,3.173816442489624,24408.114142656326,55049,0,24408.114142656326,0.4979000091552734,2.2316365242004395,10000,26189.48584985733,0.6699999570846558,1.3624234199523926,0.6233199834823608,1.5818876028060913,50000 -1806.845562696457,3.211068630218506,24828.346135139465,55999,0,24828.346135139465,0.4998000264167785,2.230363368988037,10000,26641.32139706612,0.6770898103713989,1.3100413084030151,0.6246399879455566,1.567784070968628,50000 -1837.941088438034,3.255908966064453,25248.46753191948,56943,0,25248.46753191948,0.5057000517845154,2.21742582321167,10000,27092.631110429764,0.6991015672683716,1.2092502117156982,0.6245599985122681,1.5463355779647827,50000 -1868.226809501648,3.297309160232544,25668.446413993835,57889,0,25668.446413993835,0.506600022315979,2.182492971420288,10000,27542.98506808281,0.6811327934265137,1.2940772771835327,0.6329799890518188,1.5235350131988523,50000 -1900.8828961849213,3.3394837379455566,26088.65665245056,58835,0,26088.65665245056,0.5116000175476074,2.191758632659912,10000,27995.94226884842,0.68115234375,1.3139863014221191,0.6299600005149841,1.553807258605957,50000 -1930.6553165912628,3.3801913261413574,26508.756959676743,59782,0,26508.756959676743,0.5057000517845154,2.211388111114502,10000,28445.90362930298,0.6947851181030273,1.258604884147644,0.630899965763092,1.5560885667800903,50000 -1963.071844100952,3.4235599040985107,26929.041621923447,60689,0,26929.041621923447,0.5031999945640564,2.187201023101806,10000,28898.69361257553,0.6812695264816284,1.2990782260894775,0.6337000131607056,1.513222098350525,50000 -1995.115867614746,3.4631595611572266,27349.087817907333,61639,0,27349.087817907333,0.5178000330924988,2.1470038890838623,10000,29350.87168526649,0.6890038847923279,1.2490397691726685,0.6389399766921997,1.4967432022094729,50000 -2026.404606342316,3.505953550338745,27769.2222173214,62585,0,27769.2222173214,0.5063000321388245,2.162984848022461,10000,29802.38598752021,0.6906836032867432,1.238747477531433,0.6342399716377258,1.5011688470840454,50000 -2057.330437898636,3.545721530914306,28189.522920131683,63533,0,28189.522920131683,0.508400022983551,2.1880664825439453,10000,30253.7003839016,0.71107417345047,1.1600855588912964,0.6335799694061279,1.5163129568099976,50000 -2087.645537853241,3.588677167892456,28609.791985034943,64480,0,28609.791985034943,0.5195000171661377,2.1230661869049072,10000,30704.375244617466,0.690625011920929,1.2341567277908323,0.6427800059318542,1.4698628187179563,50000 -2121.744611263275,3.628873109817505,29030.007290124893,65429,0,29030.007290124893,0.5190000534057617,2.1363329887390137,10000,31158.778439998627,0.6970507502555847,1.2187137603759766,0.642520010471344,1.467975616455078,50000 -2155.1333100795746,3.671163320541382,29450.01219272613,66381,0,29450.01219272613,0.5151000022888184,2.12232518196106,10000,31612.26335000992,0.70703125,1.1558291912078855,0.6475600004196167,1.4485762119293213,50000 -2185.6091549396515,3.7146735191345215,29870.019545555115,67327,0,29870.019545555115,0.5243000388145447,2.134155511856079,10000,32062.83818912506,0.6959765553474426,1.256928563117981,0.6457599997520447,1.4908559322357178,50000 -2215.2276520729065,3.780029296875,30289.98365139961,68273,0,30289.98365139961,0.5218000411987305,2.0921084880828857,10000,32512.53608345985,0.6995898485183716,1.2015330791473389,0.6462399959564209,1.4489117860794067,50000 -2248.1401748657227,3.823965072631836,30710.1666765213,69218,0,30710.1666765213,0.5241000056266785,2.1208341121673584,10000,32965.72461724281,0.7053906321525574,1.2163230180740356,0.6477400064468384,1.4740873575210571,50000 -2279.395946502685,4.07750391960144,31130.051805496216,70164,0,31130.051805496216,0.5228000283241272,2.128782272338867,10000,33417.16694736481,0.7193554639816284,1.1367861032485962,0.6459199786186218,1.4766112565994265,50000 -2313.651697158813,4.12572455406189,31550.3475124836,71108,0,31550.3475124836,0.5250000357627869,2.09932017326355,10000,33871.81460762024,0.6978319883346558,1.2037420272827148,0.6468799710273743,1.4458647966384888,50000 -2344.7223691940308,4.1603617668151855,31970.288603544235,72058,0,31970.288603544235,0.5271000266075134,2.0786259174346924,10000,34322.90913772583,0.7070507407188416,1.153599977493286,0.652899980545044,1.4156707525253296,50000 -2376.42147231102,4.202556133270264,32390.45892190933,73003,0,32390.45892190933,0.5306000113487244,2.0731072425842285,10000,34774.869156360626,0.710253894329071,1.1438536643981934,0.6521599888801575,1.4241604804992676,50000 -2411.644124746322,4.24935245513916,32810.473249197006,73943,0,32810.473249197006,0.5343000292778015,2.051135301589966,10000,35230.20307135582,0.7079882621765137,1.155798077583313,0.6553399562835693,1.4021482467651367,50000 -2445.44240641594,4.287355422973633,33230.81537055969,74891,0,33230.81537055969,0.5301000475883484,2.0819716453552246,10000,35684.42928028107,0.706835925579071,1.19562828540802,0.6556000113487244,1.433531403541565,50000 -2476.5385341644287,4.334265470504761,33651.12571454048,75839,0,33651.12571454048,0.5286000370979309,2.1042492389678955,10000,36135.93005084992,0.71205073595047,1.21051287651062,0.6584199666976929,1.4606388807296753,50000 -2511.526807546616,4.380546808242798,34071.293254852295,76784,0,34071.293254852295,0.534000039100647,2.07847261428833,10000,36591.18041014671,0.72572261095047,1.1005898714065552,0.651419997215271,1.4404127597808838,50000 -2543.342089414597,4.420172214508057,34491.62502241135,77733,0,34491.62502241135,0.5369000434875488,2.0748746395111084,10000,37043.41521525383,0.7066601514816284,1.1821738481521606,0.6551199555397034,1.4311776161193848,50000 -2573.301589488983,4.47462272644043,34911.57203626633,78677,0,34911.57203626633,0.532200038433075,2.0688157081604004,10000,37493.42399716377,0.7102343440055847,1.176645040512085,0.6593799591064453,1.4185059070587158,50000 -2604.975376367569,4.523173093795776,35331.53860926628,79619,0,35331.53860926628,0.5317000150680542,2.070014476776123,10000,37945.161172389984,0.7199609279632568,1.122406244277954,0.657759964466095,1.4062657356262207,50000 -2636.451033353805,4.587049722671509,35751.56966614723,80567,0,35751.56966614723,0.5348000526428223,2.05126953125,10000,38396.780281066895,0.7152929306030273,1.1523548364639282,0.6591199636459351,1.416941523551941,50000 -2669.9925594329834,4.633692026138306,36171.70365715027,81513,0,36171.70365715027,0.536300003528595,2.0532751083374023,10000,38850.550045251846,0.7157031297683716,1.1238621473312378,0.6619799733161926,1.38247811794281,50000 -2702.765382766724,4.6731603145599365,36591.86593723297,82463,0,36591.86593723297,0.5412999987602234,2.030993938446045,10000,39303.57247233391,0.7237108945846558,1.1237812042236328,0.6624999642372131,1.4014748334884644,50000 -2732.4041872024536,4.720138072967529,37012.14245843887,83412,0,37012.14245843887,0.5382000207901001,2.0109994411468506,10000,39753.58320188522,0.7419531345367432,1.0138866901397705,0.6646199822425842,1.3580873012542725,50000 -2764.3009791374207,4.763558626174927,37432.25945806503,84356,0,37432.25945806503,0.5446000099182129,2.0185341835021973,10000,40205.688380241394,0.721484363079071,1.12127685546875,0.6665599942207336,1.3685060739517212,50000 -2794.892049312592,4.821456909179688,37852.270012140274,85303,0,37852.270012140274,0.5424000024795532,2.0130226612091064,10000,40656.39597964287,0.7251952886581421,1.107918620109558,0.6675199866294861,1.3741096258163452,50000 -2824.427891969681,4.871694087982178,38272.54418039322,86251,0,38272.54418039322,0.5398000478744507,2.01186466217041,10000,41106.30424046517,0.7322655916213989,1.052381992340088,0.6670399904251099,1.3552080392837524,50000 -2860.237210035324,4.915456771850586,38692.73499083519,87195,0,38692.73499083519,0.5415000319480896,2.000982999801636,10000,41562.39662957192,0.7362109422683716,1.0454906225204468,0.6697999835014343,1.338555097579956,50000 -2896.091214418412,4.953283309936523,39112.82124638557,88144,0,39112.82124638557,0.5432000160217285,1.992767691612244,10000,42018.42330169678,0.7266015410423279,1.0895438194274902,0.6708799600601196,1.3421494960784912,50000 -2926.929531097412,4.990481376647949,39532.90726733208,89093,0,39532.90726733208,0.5485000014305115,1.9791725873947144,10000,42469.43255209923,0.7369921803474426,1.0429660081863403,0.6749799847602844,1.325405240058899,50000 -2958.897340297699,5.042929649353027,39953.20360445976,90040,0,39953.20360445976,0.5498000383377075,1.9649418592453003,10000,42921.797057151794,0.748828113079071,1.0013405084609983,0.674340009689331,1.3295317888259888,50000 -2991.8312034606934,5.086370944976807,40373.19947838783,90985,0,40373.19947838783,0.5527999997138977,1.9561856985092163,10000,43374.818145513535,0.7356249690055847,1.0418883562088013,0.6781600117683411,1.3062312602996826,50000 -3022.284590482712,5.132043361663818,40793.36383152008,91932,0,40793.36383152008,0.5508000254631042,1.9430943727493288,10000,43825.52984857559,0.7396484017372131,1.0154486894607544,0.6789000034332275,1.2950810194015503,50000 -3055.57865357399,5.207129716873169,41213.46935200691,92877,0,41213.46935200691,0.5501000285148621,1.9638128280639648,10000,44279.05270528793,0.7428515553474426,1.0198140144348145,0.6730799674987793,1.332512617111206,50000 -3086.0911548137665,5.252807855606079,41633.70737886429,93825,0,41633.70737886429,0.5507000088691711,1.9684770107269287,10000,44729.89771127701,0.7576562166213989,0.9598555564880372,0.6772800087928772,1.3212695121765137,50000 -3118.114638566971,5.299704313278198,42053.8782582283,94770,0,42053.8782582283,0.554900050163269,1.955616116523743,10000,45182.18700695038,0.7390429377555847,1.0380719900131226,0.6790399551391602,1.304970145225525,50000 -3149.01988363266,5.345886945724487,42474.13750100136,95715,0,42474.13750100136,0.5614000558853149,1.9297590255737305,10000,45633.44561266899,0.7507030963897705,1.004400610923767,0.6848599910736084,1.291728377342224,50000 -3182.1036455631256,5.395781517028809,42894.14020085335,96661,0,42894.14020085335,0.5568000078201294,1.937228202819824,10000,46086.63012123108,0.7564648389816284,0.9674227833747864,0.6837799549102783,1.2896058559417725,50000 -3212.790130376816,5.443286895751953,43314.50114059448,97611,0,43314.50114059448,0.5580000281333923,1.915615677833557,10000,46537.77447581291,0.7490038871765137,0.9923749566078186,0.6890199780464172,1.2585322856903076,50000 -3244.002207517624,5.495993137359619,43734.45563292503,98558,0,43734.45563292503,0.5594000220298767,1.939386248588562,10000,46989.04209041596,0.7421093583106995,1.0115565061569214,0.6842199563980103,1.2888301610946655,50000 -3276.752377271652,5.544393062591553,44154.56234765053,99504,0,44154.56234765053,0.5587000250816345,1.922590732574463,10000,47441.99556803703,0.7502343654632568,0.9732296466827391,0.6872599720954895,1.2587881088256836,50000 -3307.483034849167,5.593884229660034,44574.62707424164,100452,0,44574.62707424164,0.5685000419616699,1.8823338747024536,10000,47892.88826847077,0.7751367092132568,0.8857801556587219,0.6929799914360046,1.2570501565933228,50000 -3339.8508801460266,5.645975828170776,44994.80569982529,101397,0,44994.80569982529,0.5728000402450562,1.8757268190383911,10000,48345.53457951546,0.7506640553474426,0.966569483280182,0.6922799944877625,1.245205640792847,50000 -3375.8747539520264,5.707857370376587,45415.07636475563,102342,0,45415.07636475563,0.5701000094413757,1.8743760585784912,10000,48801.93915319443,0.7575390338897705,0.9569990038871764,0.6941800117492676,1.2485471963882446,50000 -3406.97397518158,5.746920824050903,45835.02481293678,103292,0,45835.02481293678,0.5659000277519226,1.897047519683838,10000,49253.074348926544,0.7648046612739563,0.9258500933647156,0.6905399560928345,1.2477099895477295,50000 -3439.858054637909,5.793359994888306,46255.20332431793,104239,0,46255.20332431793,0.5667000412940979,1.8755282163619995,10000,49706.23113465309,0.7541210651397705,0.9695034027099608,0.6918599605560303,1.2417765855789185,50000 -3473.9791502952576,5.849823951721191,46675.59752130509,105187,0,46675.59752130509,0.5755000114440918,1.8609957695007324,10000,50160.85048317909,0.7598632574081421,0.940223515033722,0.6953999996185303,1.232001543045044,50000 -3506.291650056839,5.897094011306763,47095.5532104969,106136,0,47095.5532104969,0.5692000389099121,1.868541240692139,10000,50613.21381640434,0.7672656178474426,0.9064086079597472,0.6967399716377258,1.219843864440918,50000 -3539.1190111637115,5.948194980621338,47515.50264263153,107078,0,47515.50264263153,0.5711000561714172,1.8621702194213867,10000,51066.089199543,0.7860937118530273,0.841831624507904,0.6953200101852417,1.229534387588501,50000 -3570.824383974076,5.9956605434417725,47935.81768369675,108021,0,47935.81768369675,0.5750000476837158,1.8669058084487915,10000,51518.20453286171,0.7554101347923279,0.9594653844833374,0.6963199973106384,1.2319732904434204,50000 -3604.986275196075,6.047371864318848,48355.86082482338,108967,0,48355.86082482338,0.5755000114440918,1.845165491104126,10000,51972.50933337212,0.7698437571525574,0.8918425440788269,0.7005800008773804,1.197978138923645,50000 -3637.372734069824,6.091373920440674,48775.94476270676,109914,0,48775.94476270676,0.579300045967102,1.835110664367676,10000,52425.07089591026,0.7803710699081421,0.8482807874679565,0.7021200060844421,1.1901153326034546,50000 -3668.2350981235504,6.1399030685424805,49195.93100190163,110859,0,49195.93100190163,0.5750000476837158,1.831464886665344,10000,52876.01609253883,0.7680468559265137,0.9069384336471558,0.7002800107002258,1.197730302810669,50000 -3699.5143172740936,6.189666271209717,49616.14709401131,111804,0,49616.14709401131,0.5737000107765198,1.8351027965545648,10000,53327.60908675194,0.7716405987739563,0.9013556241989136,0.7028999924659729,1.2098357677459717,50000 -3735.432944059372,6.240252494812012,50036.39476442337,112748,0,50036.39476442337,0.5814000368118286,1.8124264478683472,10000,53783.87377977371,0.7766015529632568,0.8739011883735657,0.7063800096511841,1.1852818727493286,50000 -3767.642267227173,6.535725593566895,50456.44828510285,113693,0,50456.44828510285,0.589900016784668,1.7972122430801392,10000,54236.47974801064,0.7954491972923279,0.7992606163024902,0.7044999599456787,1.186686635017395,50000 -3799.7460713386536,6.585723876953125,50876.50444483757,114632,0,50876.50444483757,0.5830000042915344,1.8008441925048828,10000,54688.73802089691,0.7786718606948853,0.8812108635902405,0.7083799839019775,1.18280029296875,50000 -3830.8059952259055,6.634954214096069,51296.428337574005,115575,0,51296.428337574005,0.5892000198364258,1.7851496934890747,10000,55139.81853270531,0.7844336032867432,0.8428012132644653,0.7099199891090393,1.1657264232635498,50000 -3863.548504590988,6.685975074768066,51716.66536331177,116520,0,51716.66536331177,0.5809000134468079,1.824753761291504,10000,55592.89674210549,0.7871679663658142,0.8348548412322998,0.7083399891853333,1.1812052726745603,50000 -3894.236728906632,6.735541582107544,52136.707661151886,117464,0,52136.707661151886,0.5915000438690186,1.7694329023361206,10000,56043.724599123,0.7781640291213989,0.8482047319412231,0.7140199542045593,1.1441729068756104,50000 -3925.459734916687,6.7867701053619385,52556.71144080162,118406,0,52556.71144080162,0.586400032043457,1.818394660949707,10000,56495.050055503845,0.7785351276397705,0.8815587162971497,0.709119975566864,1.1885517835617063,50000 -3957.0508601665497,6.83685827255249,52976.91713619232,119351,0,52976.91713619232,0.5924000144004822,1.7503550052642822,10000,56946.94454836845,0.7892773151397705,0.8121300935745239,0.7165200114250183,1.1363714933395386,50000 -3989.827569723129,6.887446403503418,53397.10926628113,120296,0,53397.10926628113,0.5946000218391418,1.750155329704285,10000,57400.011506319046,0.8076757788658142,0.7323868274688721,0.7183399796485901,1.1285380125045776,50000 -4025.02843785286,6.941509246826172,53817.029266119,121243,0,53817.029266119,0.596500039100647,1.758715271949768,10000,57855.23564505577,0.7827538847923279,0.8527898192405701,0.7153399586677551,1.1425896883010864,50000 -4056.473051548004,6.989522695541382,54237.009996175766,122191,0,54237.009996175766,0.5915000438690186,1.7542343139648438,10000,58306.75683450699,0.7879882454872131,0.8077414631843567,0.7186999917030334,1.1167722940444946,50000 -4087.749156236649,7.041682243347168,54657.21346616745,123137,0,54657.21346616745,0.6009000539779663,1.726817607879639,10000,58758.33668446541,0.80322265625,0.7724707722663879,0.7219600081443787,1.1197479963302612,50000 -4120.542637825012,7.0932228565216064,55077.45772242546,124081,0,55077.45772242546,0.6030000448226929,1.7267364263534546,10000,59211.47411370277,0.7920898199081421,0.8035741448402405,0.7215999960899353,1.1221164464950562,50000 -4150.582463741303,7.149578332901001,55497.51085400581,125028,0,55497.51085400581,0.5985000133514404,1.7480947971343994,10000,59661.67133617401,0.7930859327316284,0.8153924345970154,0.7215200066566467,1.1258158683776855,50000 -4183.735275506973,7.201582193374634,55917.58813285828,125972,0,55917.58813285828,0.5933000445365906,1.7462410926818848,10000,60115.00229549408,0.8003320097923279,0.7735731601715088,0.7224999666213989,1.1156514883041382,50000 -4214.42046713829,7.252074480056763,56337.57939887047,126916,0,56337.57939887047,0.6083000302314758,1.7146964073181152,10000,60565.77696371079,0.8159765601158142,0.7205840945243835,0.7251399755477905,1.105201005935669,50000 -4245.306490898132,7.332320213317871,56757.66656923294,127863,0,56757.66656923294,0.6078000068664551,1.6983314752578735,10000,61016.87927532196,0.8004687428474426,0.7621183395385742,0.7263799905776978,1.087070107460022,50000 -4278.540325880051,7.388689041137695,57177.68268537521,128810,0,57177.68268537521,0.6009000539779663,1.718663215637207,10000,61470.233803510666,0.804492175579071,0.7556107044219971,0.7278800010681152,1.0926361083984375,50000 -4310.619389057159,7.430602788925171,57597.74878168106,129758,0,57597.74878168106,0.6118000149726868,1.675654411315918,10000,61922.46889066696,0.8183202743530273,0.6911411881446838,0.730139970779419,1.0635643005371094,50000 -4347.752557516098,7.481885433197021,58017.70781111717,130702,0,58017.70781111717,0.6089000105857849,1.6937024593353271,10000,62379.66068387032,0.8029296398162842,0.7509654760360718,0.731220006942749,1.0709768533706665,50000 -4380.077574729919,7.52693510055542,58437.64938545227,131648,0,58437.64938545227,0.6130000352859497,1.677901268005371,10000,62832.02004933357,0.80824214220047,0.7324361801147461,0.7309799790382385,1.0710631608963013,50000 -4411.736454963684,7.580765724182129,58857.876715660095,132594,0,58857.876715660095,0.6110000014305115,1.682750225067139,10000,63284.00793981552,0.8122069835662842,0.7282311320304871,0.7334399819374084,1.069677233695984,50000 -4445.912066459656,7.64042592048645,59278.13584399223,133537,0,59278.13584399223,0.6073000431060791,1.697361946105957,10000,63738.54997134209,0.8247460722923279,0.6852326989173889,0.7335799932479858,1.0745161771774292,50000 -4481.801281452179,7.701645374298096,59698.29487943649,134484,0,59698.29487943649,0.613800048828125,1.670855164527893,10000,64194.70729804039,0.81361323595047,0.7152171730995178,0.7375999689102173,1.051063895225525,50000 -4515.079110383987,7.746327400207519,60118.45872378349,135432,0,60118.45872378349,0.6119000315666199,1.6614116430282593,10000,64648.24386954308,0.8185351490974426,0.6898524165153503,0.7380599975585938,1.040588140487671,50000 -4547.873115539551,7.793517351150513,60538.52721905708,136378,0,60538.52721905708,0.6154000163078308,1.6687698364257812,10000,65101.20194649696,0.8243945240974426,0.6757507920265198,0.7364199757575989,1.0444821119308472,50000 -4582.273346424103,7.8503875732421875,60958.88538622856,137323,0,60958.88538622856,0.6136000156402588,1.662897706031799,10000,65556.06583595276,0.8189452886581421,0.6932224631309509,0.740339994430542,1.0378048419952393,50000 -4614.699839115143,7.911661148071289,61378.89130926132,138270,0,61378.89130926132,0.6142000555992126,1.658218264579773,10000,66008.60767126083,0.8212890625,0.7007355093955994,0.7413199543952942,1.0484044551849363,50000 -4646.778388261795,7.966315031051636,61798.8139359951,139212,0,61798.8139359951,0.6215000152587891,1.646551251411438,10000,66460.71085047722,0.8290624618530273,0.6643991470336914,0.7405399680137634,1.040321946144104,50000 -4678.231098890305,8.020896673202515,62218.95655655861,140156,0,62218.95655655861,0.6212000250816345,1.6370558738708496,10000,66912.40876984596,0.8376757502555847,0.613097608089447,0.7422199845314026,1.0231008529663086,50000 -4712.48251914978,8.083777904510498,62639.31224012375,141101,0,62639.31224012375,0.6190000176429749,1.6144688129425049,10000,67367.12671756744,0.8280078172683716,0.6409241557121277,0.7453199625015259,1.0019612312316897,50000 -4745.326453208923,8.137955904006958,63059.65165567398,142049,0,63059.65165567398,0.6236000061035156,1.622061252593994,10000,67820.41301584244,0.8278319835662842,0.646851658821106,0.7454800009727478,1.0134694576263428,50000 -4776.352700948715,8.195709705352783,63479.56883740425,142990,0,63479.56883740425,0.6205000281333923,1.6167891025543213,10000,68271.46160244942,0.8350585699081421,0.6243909001350403,0.7469799518585205,1.008358120918274,50000 -4813.3277044296265,8.251341104507446,63899.6438407898,143932,0,63899.6438407898,0.6215000152587891,1.6196413040161133,10000,68728.61555743217,0.8307421803474426,0.6419038772583008,0.7469799518585205,0.9995123744010924,50000 -4848.98010802269,8.308446168899536,64319.94776797295,144882,0,64319.94776797295,0.6274000406265259,1.5910381078720093,10000,69184.67727637291,0.8347265720367432,0.622641384601593,0.7500399947166443,0.9826123118400574,50000 -4879.781792402268,8.357537031173706,64740.01791214943,145829,0,64740.01791214943,0.629300057888031,1.5892504453659058,10000,69635.6463303566,0.8373242020606995,0.6188812255859375,0.7525999546051025,0.9853445887565612,50000 -4917.589969396591,8.413017749786377,65160.04265499115,146774,0,65160.04265499115,0.6303000450134277,1.581558108329773,10000,70093.58276224136,0.8498046398162842,0.5664092302322388,0.7521399855613708,0.978649377822876,50000 -4950.214616537094,8.461366415023804,65579.95571732521,147722,0,65579.95571732521,0.6295000314712524,1.5846375226974487,10000,70546.2168867588,0.8400195240974426,0.6110930442810059,0.7532599568367004,0.9788589477539062,50000 -4983.62385559082,8.5170156955719,65999.89698147774,148669,0,65999.89698147774,0.6278000473976135,1.5932260751724243,10000,70999.6721727848,0.840624988079071,0.6056962609291077,0.7522000074386597,0.9804185628890992,50000 -5020.925404548645,8.574753284454346,66420.16347050667,149616,0,66420.16347050667,0.6355000138282776,1.575750470161438,10000,71457.34581279755,0.8484765291213989,0.5737552046775818,0.7560999989509583,0.9599688649177552,50000 -5054.153319835663,8.622200965881348,66840.26841020584,150563,0,66840.26841020584,0.6318000555038452,1.5865153074264526,10000,71910.77382802963,0.8442773222923279,0.588003396987915,0.7562599778175354,0.963024377822876,50000 -5088.375540494919,8.670436143875122,67260.58755874634,151506,0,67260.58755874634,0.6328000426292419,1.570575714111328,10000,72365.41188597679,0.8475585579872131,0.5803536772727966,0.7574999928474426,0.9603232741355896,50000 -5119.106483459473,8.727611780166626,67680.61174440384,152452,0,67680.61174440384,0.636400043964386,1.5653671026229858,10000,72816.272285223,0.8537304401397705,0.5536989569664001,0.7576799988746643,0.9566033482551576,50000 -5152.867078065872,8.786496639251709,68100.60788464546,153394,0,68100.60788464546,0.6341000199317932,1.5579533576965332,10000,73270.136343956,0.8624218702316284,0.5283911824226379,0.7599799633026123,0.9521071910858154,50000 -5183.143657207489,8.846336364746094,68520.679936409,154339,0,68520.679936409,0.6384000182151794,1.555184006690979,10000,73720.59342503548,0.8519921898841858,0.5664136409759521,0.7613399624824524,0.9477301836013794,50000 -5220.722696304321,8.91114854812622,68940.78204774857,155283,0,68940.78204774857,0.6430000066757202,1.5386472940444946,10000,74178.38667702675,0.85498046875,0.5443105697631836,0.7622199654579163,0.9341087341308594,50000 -5257.861455440521,8.966075658798218,69361.06091952324,156229,0,69361.06091952324,0.6414000391960144,1.5460975170135498,10000,74635.90718269348,0.8586718440055847,0.5289605855941772,0.7625399827957153,0.937669038772583,50000 -5289.2215440273285,9.013855934143066,69781.32340240479,157174,0,69781.32340240479,0.643500030040741,1.5365839004516602,10000,75087.62546420097,0.8605468273162842,0.5198253989219666,0.7628399729728699,0.9245589375495912,50000 -5322.374501943588,9.09701156616211,70201.25111293793,158116,0,70201.25111293793,0.6447000503540039,1.522838115692139,10000,75540.83647584915,0.8600195050239563,0.5220351219177246,0.7657999992370605,0.922173261642456,50000 -5353.905155658722,9.155547142028809,70621.22784686089,159059,0,70621.22784686089,0.6422000527381897,1.5358188152313232,10000,75992.44965624809,0.8639453053474426,0.5182149410247803,0.7652599811553955,0.925984799861908,50000 -5385.8952486515045,9.221797943115234,71041.17511677742,160000,0,71041.17511677742,0.65010005235672,1.5326465368270874,10000,76444.50120282173,0.8701952695846558,0.4905823767185211,0.7658799886703491,0.9287203550338744,50000 -5419.464338064194,9.283828735351562,71461.22817921638,160944,0,71461.22817921638,0.6476000547409058,1.5069485902786257,10000,76898.23347449303,0.8650195002555847,0.5044265985488892,0.7679399847984314,0.9082586765289308,50000 -5450.181225776672,9.357109308242798,71881.19460654259,161886,0,71881.19460654259,0.648300051689148,1.5145506858825684,10000,77349.03749322891,0.8689648509025574,0.4930750429630279,0.7684199810028076,0.9124515652656556,50000 -5487.747537851334,9.417224645614624,72301.17582511902,162829,0,72301.17582511902,0.6515000462532043,1.5086804628372192,10000,77806.6926317215,0.87416011095047,0.4663747251033783,0.7705599665641785,0.8969687223434448,50000 -5517.879540681839,9.467840194702148,72721.10506033897,163779,0,72721.10506033897,0.6487000584602356,1.4998230934143066,10000,78256.85267496109,0.869921863079071,0.4806233942508697,0.7714200019836426,0.8954671025276184,50000 -5549.715945720673,9.528099298477173,73141.14754962921,164724,0,73141.14754962921,0.6450000405311584,1.5077391862869265,10000,78708.84034132957,0.8695507645606995,0.4878565371036529,0.771399974822998,0.9021183252334596,50000 -5583.6576244831085,9.585728645324709,73561.19763278961,165668,0,73561.19763278961,0.6477000117301941,1.5097239017486572,10000,79162.93744325638,0.8700780868530273,0.4788884222507477,0.7709999680519104,0.8976555466651917,50000 -5615.822008609772,9.64665174484253,73981.26328992844,166615,0,73981.26328992844,0.656000018119812,1.489260196685791,10000,79615.27650952339,0.8802929520606995,0.4507987201213836,0.7734799981117249,0.8893938660621643,50000 -5648.666755914688,9.712154865264893,74401.27312350273,167561,0,74401.27312350273,0.6522000432014465,1.4930800199508667,10000,80068.24415230751,0.8752539157867432,0.4655075669288635,0.7736200094223022,0.8859438300132751,50000 -5683.770456075668,9.778326272964478,74821.58446097374,168508,0,74821.58446097374,0.6541000604629517,1.4873871803283691,10000,80523.77368187904,0.87708979845047,0.4573198854923248,0.7735799551010132,0.8782155513763428,50000 -5717.919988632202,9.840453386306764,75241.93149113655,169454,0,75241.93149113655,0.6522000432014465,1.4902135133743286,10000,80978.3799700737,0.8761327862739563,0.4634994268417358,0.773859977722168,0.8852800726890564,50000 -5748.718489408493,9.904155015945436,75662.24620199203,170400,0,75662.24620199203,0.6583000421524048,1.4806007146835327,10000,81429.60545611382,0.8781054615974426,0.45171058177948,0.7761200070381165,0.874293863773346,50000 -5782.800493478775,9.963026523590088,76082.20815610886,171344,0,76082.20815610886,0.6561000347137451,1.4711897373199463,10000,81883.75682711601,0.8800976276397705,0.444496214389801,0.7761799693107605,0.873876690864563,50000 -5823.888249158859,10.027615547180176,76502.247112751,172261,0,76502.247112751,0.6562000513076782,1.4677006006240845,10000,82344.99448871613,0.8818749785423279,0.4351919591426849,0.7774199843406677,0.8687987923622131,50000 -5860.713069200516,10.078977823257446,76922.55945754051,173210,0,76922.55945754051,0.6541000604629517,1.4714024066925049,10000,82802.23096346855,0.8833593726158142,0.4367986917495727,0.7771399617195129,0.8698133230209351,50000 -5895.611342906952,10.13133668899536,77342.50984573364,174156,0,77342.50984573364,0.6601000428199768,1.4674209356307983,10000,83257.17953515053,0.8812890648841858,0.4501301944255829,0.7763199806213379,0.872589111328125,50000 -5934.702550172806,10.195268630981444,77762.4930267334,175101,0,77762.4930267334,0.6586000323295593,1.465294361114502,10000,83716.3657438755,0.8842968344688416,0.4302087426185608,0.7767399549484253,0.869044840335846,50000 -5966.238239049912,10.259270668029783,78182.75847244263,176047,0,78182.75847244263,0.6606000065803528,1.4649004936218262,10000,84168.2788734436,0.8846288919448853,0.4247990846633911,0.7781800031661987,0.8619899749755859,50000 -6000.178608179092,10.321425676345823,78602.89275169373,176987,0,78602.89275169373,0.6599000096321106,1.464371919631958,10000,84622.46363949776,0.8851367235183716,0.4297102391719818,0.7775399684906006,0.867326021194458,50000 -6037.081268548965,10.381896018981934,79022.98996186256,177929,0,79022.98996186256,0.6630000472068787,1.4578428268432615,10000,85079.57243037224,0.8857812285423279,0.4314631521701813,0.7797799706459045,0.8623688220977783,50000 -6069.630121469498,10.452614307403564,79443.26837658882,178879,0,79443.26837658882,0.6585000157356262,1.4597877264022827,10000,85532.51931023598,0.8853515386581421,0.4269904494285583,0.7788599729537964,0.8618048429489136,50000 -6102.052920103073,10.519617795944214,79863.53984189034,179824,0,79863.53984189034,0.661300003528595,1.4576455354690552,10000,85985.32821440697,0.887499988079071,0.4231514930725097,0.7784799933433533,0.8612497448921204,50000 -6138.60348033905,10.583715915679932,80283.59482526779,180768,0,80283.59482526779,0.6619000434875488,1.4548231363296509,10000,86442.04560875893,0.8854296803474426,0.4256410896778106,0.7800599932670593,0.8574149012565613,50000 -6169.210487604141,10.634981632232666,80703.73807430267,181716,0,80703.73807430267,0.6617000102996826,1.4499728679656982,10000,86892.89508104324,0.8868945240974426,0.4210363030433655,0.7802599668502808,0.8553614616394043,50000 -6203.603073835373,10.696644067764282,81123.65591287613,182659,0,81123.65591287613,0.6627000570297241,1.4520686864852903,10000,87347.31507396698,0.8882421851158142,0.4179194271564483,0.780299961566925,0.8557093739509583,50000 -6240.15052652359,10.76271915435791,81543.79722166061,183603,0,81543.79722166061,0.6643000245094299,1.4516526460647583,10000,87804.11721277237,0.8871288895606995,0.4207266271114349,0.7800599932670593,0.855748176574707,50000 -6272.948943138123,10.815591096878052,81963.94086170197,184549,0,81963.94086170197,0.663800060749054,1.4520243406295776,10000,88257.15979456902,0.8898437023162842,0.4142662286758423,0.7801600098609924,0.8554735779762268,50000 -6306.051533937454,10.878965616226196,82384.03376245499,185492,0,82384.03376245499,0.664400041103363,1.4500694274902344,10000,88710.46589899063,0.8886523246765137,0.4187787473201751,0.7800799608230591,0.8548168540000916,50000 -6342.082160711288,10.93463397026062,82804.15239548683,186437,0,82804.15239548683,0.6648000478744507,1.450608491897583,10000,89166.71850991249,0.8884961009025574,0.418078750371933,0.7803399562835693,0.8548141121864319,50000 -6379.588619709015,10.998920679092407,83224.0760512352,187384,0,83224.0760512352,0.6648000478744507,1.4506032466888428,10000,89624.2609269619,0.8875585794448853,0.4143358469009399,0.7803199887275696,0.8548135757446289,50000 -6411.748934745789,11.053326845169067,83644.06291556358,188327,0,83644.06291556358,0.6648000478744507,1.4506032466888428,10000,90076.51022052763,0.8861523270606995,0.4194840490818023,0.7803199887275696,0.8548135757446289,50000 -6445.480100631714,11.119504451751707,84064.30104327202,189270,0,84064.30104327202,0.6648000478744507,1.4506032466888428,10000,90530.59429717064,0.885058581829071,0.4213558435440063,0.7803199887275696,0.8548135757446289,50000 -6477.661994457245,11.18522047996521,84484.19007110596,190213,0,84484.19007110596,0.6648000478744507,1.4506032466888428,10000,90982.77915644646,0.8882812261581421,0.4192114472389221,0.7803199887275696,0.8548135757446289,50000 -6509.533877134323,11.248058319091797,84904.22356843948,191158,0,84904.22356843948,0.6648000478744507,1.4506032466888428,10000,91434.7952284813,0.8881054520606995,0.4155466854572296,0.7803199887275696,0.8548135757446289,50000 -6540.580693721771,11.31272554397583,85324.30487179756,192103,0,85324.30487179756,0.6648000478744507,1.4506032466888428,10000,91886.03667712212,0.8872460722923279,0.4198216497898102,0.7803199887275696,0.8548135757446289,50000 -6575.839470863342,11.38812780380249,85744.3128118515,193050,0,85744.3128118515,0.6648000478744507,1.4506032466888428,10000,92341.42679834366,0.8844921588897705,0.4295682907104492,0.7803199887275696,0.8548135757446289,50000 -6606.46160697937,11.450967073440552,86164.21161937714,193996,0,86164.21161937714,0.6648000478744507,1.4506032466888428,10000,92792.05870342256,0.8874413967132568,0.4196164309978485,0.7803199887275696,0.8548135757446289,50000 -6642.044916152954,11.519468545913696,86584.47300457954,194940,0,86584.47300457954,0.6648000478744507,1.4506032466888428,10000,93248.01982617378,0.8882421851158142,0.4126248955726623,0.7803199887275696,0.8548135757446289,50000 -6677.47377038002,11.58312463760376,87004.58102655411,195885,0,87004.58102655411,0.6648000478744507,1.4506032466888428,10000,93703.66821908952,0.8870898485183716,0.4183282852172851,0.7803199887275696,0.8548135757446289,50000 -6710.491514205933,11.646705865859984,87424.61718916893,196833,0,87424.61718916893,0.6648000478744507,1.4506032466888428,10000,94156.83369517326,0.8863085508346558,0.4206987619400024,0.7803199887275696,0.8548135757446289,50000 -6745.4324831962585,11.7142493724823,87844.54752922058,197775,0,87844.54752922058,0.6648000478744507,1.4506032466888428,10000,94611.82145762444,0.8870312571525574,0.4229859709739685,0.7803199887275696,0.8548135757446289,50000 -6779.643330097199,11.784050464630129,88264.76767849922,198719,0,88264.76767849922,0.6648000478744507,1.4506032466888428,10000,95066.36974191666,0.8882030844688416,0.4190461635589599,0.7803199887275696,0.8548135757446289,50000 -6810.564759016037,11.886321544647217,88684.78815793991,199663,0,88684.78815793991,0.6648000478744507,1.4506032466888428,10000,95517.46187448502,0.8855078220367432,0.4201179146766662,0.7803199887275696,0.8548135757446289,50000 -6849.069217681885,11.955017805099487,89105.10370564461,200604,0,89105.10370564461,0.6648000478744507,1.4506032466888428,10000,95976.39848995207,0.8885741829872131,0.4168844521045685,0.7803199887275696,0.8548135757446289,50000 -6881.090431928635,12.015576601028442,89525.00794053078,201549,0,89525.00794053078,0.6648000478744507,1.4506032466888428,10000,96428.4326581955,0.8887695074081421,0.4169636070728302,0.7803199887275696,0.8548135757446289,50000 -6918.256211996079,12.085902690887451,89945.28154015541,202495,0,89945.28154015541,0.6648000478744507,1.4506032466888428,10000,96885.9902985096,0.8854491710662842,0.4246847927570343,0.7803199887275696,0.8548135757446289,50000 -6954.5250408649445,12.15333890914917,90365.33030056952,203440,0,90365.33030056952,0.6648000478744507,1.4506032466888428,10000,97342.42331409454,0.8881250023841858,0.4221977889537811,0.7803199887275696,0.8548135757446289,50000 -6987.989857196808,12.211401224136353,90785.62109804152,204383,0,90785.62109804152,0.6648000478744507,1.4506032466888428,10000,97796.2846519947,0.8873046636581421,0.4180344343185425,0.7803199887275696,0.8548135757446289,50000 -7029.166851758957,12.281495094299316,91205.78978681564,205328,0,91205.78978681564,0.6648000478744507,1.4506032466888428,10000,98257.74850678444,0.8882616758346558,0.4138765633106231,0.7803199887275696,0.8548135757446289,50000 -7062.951898813248,12.340657472610474,91625.9358868599,206275,0,91625.9358868599,0.6648000478744507,1.4506032466888428,10000,98711.7869529724,0.8873828053474426,0.4200892448425293,0.7803199887275696,0.8548135757446289,50000 -7095.32580947876,12.412407398223875,92046.22657752036,207220,0,92046.22657752036,0.6648000478744507,1.4506032466888428,10000,99164.57158637048,0.8865429759025574,0.4198354482650757,0.7803199887275696,0.8548135757446289,50000 -7135.026664972305,12.482912302017212,92466.52595090866,208166,0,92466.52595090866,0.6648000478744507,1.4506032466888428,10000,99624.68998265266,0.8901171684265137,0.4154210090637207,0.7803199887275696,0.8548135757446289,50000 -7169.079889535904,12.53828740119934,92886.82992863657,209115,0,92886.82992863657,0.6648000478744507,1.4506032466888428,10000,100079.15053081512,0.8893554210662842,0.4160081148147583,0.7803199887275696,0.8548135757446289,50000 -7205.7183492183685,12.614384651184082,93306.80130004884,210059,0,93306.80130004884,0.6648000478744507,1.4506032466888428,10000,100535.8836786747,0.8882812261581421,0.4137320518493652,0.7803199887275696,0.8548135757446289,50000 -7245.537149429321,12.672444820404053,93726.69717526436,211003,0,93726.69717526436,0.6648000478744507,1.4506032466888428,10000,100995.70543003082,0.8879687190055847,0.4160420000553131,0.7803199887275696,0.8548135757446289,50000 -7278.940035820007,12.730435132980348,94146.89120841026,211947,0,94146.89120841026,0.6648000478744507,1.4506032466888428,10000,101449.40781760216,0.8864062428474426,0.4183524847030639,0.7803199887275696,0.8548135757446289,50000 -7320.083385229111,12.798975467681885,94566.94662880898,212888,0,94566.94662880898,0.6648000478744507,1.4506032466888428,10000,101910.72256612778,0.8856250047683716,0.4198617041110992,0.7803199887275696,0.8548135757446289,50000 -7352.1744792461395,12.85687255859375,94986.92520284653,213835,0,94986.92520284653,0.6648000478744507,1.4506032466888428,10000,102362.89806699751,0.8876562118530273,0.4190992712974548,0.7803199887275696,0.8548135757446289,50000 -7390.342691898346,12.929483413696287,95407.19551420212,214780,0,95407.19551420212,0.6648000478744507,1.4506032466888428,10000,102821.45707058908,0.8873242139816284,0.4169353246688843,0.7803199887275696,0.8548135757446289,50000 -7429.297716617584,13.001351118087769,95827.43306875227,215728,0,95827.43306875227,0.6648000478744507,1.4506032466888428,10000,103280.76931118964,0.8858398199081421,0.4253257513046264,0.7803199887275696,0.8548135757446289,50000 -7460.550596475601,13.059123516082764,96247.62993168832,216675,0,96247.62993168832,0.6648000478744507,1.4506032466888428,10000,103732.32486963272,0.8865820169448853,0.4243214726448059,0.7803199887275696,0.8548135757446289,50000 -7496.124892711639,13.133168935775757,96667.72405338287,217614,0,96667.72405338287,0.6648000478744507,1.4506032466888428,10000,104188.11450624466,0.8871874809265137,0.4233055114746094,0.7803199887275696,0.8548135757446289,50000 -7529.747981071472,13.193522214889526,97087.75162792206,218559,0,97087.75162792206,0.6648000478744507,1.4506032466888428,10000,104641.87342977524,0.8890820145606995,0.4117428064346313,0.7803199887275696,0.8548135757446289,50000 -7561.695729017258,13.267573595046995,97507.7348575592,219501,0,97507.7348575592,0.6648000478744507,1.4506032466888428,10000,105093.92755794524,0.8858202695846558,0.421209454536438,0.7803199887275696,0.8548135757446289,50000 -7601.125306606293,13.342846632003784,97927.69353795052,220445,0,97927.69353795052,0.6648000478744507,1.4506032466888428,10000,105553.43908929823,0.8859570026397705,0.4233082830905914,0.7803199887275696,0.8548135757446289,50000 -7649.0584235191345,13.41545057296753,98347.8355793953,221391,0,98347.8355793953,0.6648000478744507,1.4506032466888428,10000,106021.63544940948,0.8848632574081421,0.4247506260871887,0.7803199887275696,0.8548135757446289,50000 -7681.673318862915,13.478192329406738,98768.19242358208,222339,0,98768.19242358208,0.6648000478744507,1.4506032466888428,10000,106474.71693491936,0.8887499570846558,0.4126403629779815,0.7803199887275696,0.8548135757446289,50000 -7727.574855804443,13.553714990615845,99188.47766494752,223282,0,99188.47766494752,0.6648000478744507,1.4506032466888428,10000,106941.02724575996,0.8872851133346558,0.4197381734848022,0.7803199887275696,0.8548135757446289,50000 -7761.677759170532,13.621427297592165,99608.62461566924,224226,0,99608.62461566924,0.6648000478744507,1.4506032466888428,10000,107395.39229464532,0.8883593678474426,0.4179112017154693,0.7803199887275696,0.8548135757446289,50000 -7796.263030529022,13.69261384010315,100028.8138911724,225168,0,100028.8138911724,0.6648000478744507,1.4506032466888428,10000,107850.28579640388,0.8888476490974426,0.4202959835529327,0.7803199887275696,0.8548135757446289,50000 -7831.920604705811,13.755226135253906,100449.1869559288,226114,0,100449.1869559288,0.6648000478744507,1.4506032466888428,10000,108306.42710208891,0.8854687213897705,0.4251309037208557,0.7803199887275696,0.8548135757446289,50000 -7871.779316186905,13.82857346534729,100869.26940894128,227056,0,100869.26940894128,0.6648000478744507,1.4506032466888428,10000,108766.48868989944,0.8884961009025574,0.4193135797977447,0.7803199887275696,0.8548135757446289,50000 -7903.954140663147,13.88698935508728,101289.16657400133,228002,0,101289.16657400133,0.6648000478744507,1.4506032466888428,10000,109218.66709923744,0.8871288895606995,0.4178809225559234,0.7803199887275696,0.8548135757446289,50000 -7939.73398065567,13.966781616210938,101709.11290287971,228945,0,101709.11290287971,0.6648000478744507,1.4506032466888428,10000,109674.52127766608,0.8885351419448853,0.4126009345054626,0.7803199887275696,0.8548135757446289,50000 -7982.809792280197,14.043018579483032,102129.02837610243,229889,0,102129.02837610243,0.6648000478744507,1.4506032466888428,10000,110137.63670706748,0.8883984088897705,0.4152209758758545,0.7803199887275696,0.8548135757446289,50000 -8014.580620288849,14.102319240570068,102549.27608656885,230836,0,102549.27608656885,0.6648000478744507,1.4506032466888428,10000,110589.76342749596,0.8867382407188416,0.4209681153297424,0.7803199887275696,0.8548135757446289,50000 -8053.198905706406,14.175431966781616,102969.37177467346,231779,0,102969.37177467346,0.6648000478744507,1.4506032466888428,10000,111048.59785699844,0.8880859017372131,0.424091637134552,0.7803199887275696,0.8548135757446289,50000 -8090.729656934738,14.253586530685425,103389.27419066428,232726,0,103389.27419066428,0.6648000478744507,1.4506032466888428,10000,111506.15778017044,0.8880859017372131,0.4161204695701599,0.7803199887275696,0.8548135757446289,50000 -8130.175797224045,14.363630533218384,103809.17875909804,233671,0,103809.17875909804,0.6648000478744507,1.4506032466888428,10000,111965.66636037828,0.889941394329071,0.4111701250076294,0.7803199887275696,0.8548135757446289,50000 -8164.389766454697,14.44425344467163,104229.44356608392,234615,0,104229.44356608392,0.6648000478744507,1.4506032466888428,10000,112420.273696661,0.8862109184265137,0.4182306230068207,0.7803199887275696,0.8548135757446289,50000 -8207.39955830574,14.519267797470093,104649.33818554878,235560,0,104649.33818554878,0.6648000478744507,1.4506032466888428,10000,112883.30195975304,0.8863866925239563,0.4218202233314514,0.7803199887275696,0.8548135757446289,50000 -8246.647035121918,14.593485116958618,105069.57685279846,236507,0,105069.57685279846,0.6648000478744507,1.4506032466888428,10000,113342.91085219385,0.8854101300239563,0.4208202362060547,0.7803199887275696,0.8548135757446289,50000 -8279.95104432106,14.655708074569702,105489.77764558792,237453,0,105489.77764558792,0.6648000478744507,1.4506032466888428,10000,113796.5252687931,0.8888476490974426,0.4130127131938934,0.7803199887275696,0.8548135757446289,50000 -8315.858958244324,14.733111381530762,105909.77715063097,238397,0,105909.77715063097,0.6648000478744507,1.4506032466888428,10000,114252.5577852726,0.8869140148162842,0.4204769432544708,0.7803199887275696,0.8548135757446289,50000 -8356.481803894043,14.807372331619264,106329.69542336464,239341,0,106329.69542336464,0.6648000478744507,1.4506032466888428,10000,114713.22124695778,0.8874022960662842,0.4200328290462494,0.7803199887275696,0.8548135757446289,50000 -8391.442966938019,14.885765552520752,106749.73273301125,240287,0,106749.73273301125,0.6648000478744507,1.4506032466888428,10000,115168.345505476,0.8852929472923279,0.4266456067562103,0.7803199887275696,0.8548135757446289,50000 -8425.569839000702,14.957653284072876,107169.90454864502,241233,0,107169.90454864502,0.6648000478744507,1.4506032466888428,10000,115622.76381874084,0.8872656226158142,0.4201632738113403,0.7803199887275696,0.8548135757446289,50000 -8464.812617301941,15.034671306610107,107590.16424393654,242178,0,107590.16424393654,0.6648000478744507,1.4506032466888428,10000,116082.39153313635,0.887988269329071,0.4175548553466797,0.7803199887275696,0.8548135757446289,50000 -8507.924887418747,15.112300157546995,108010.23139071465,243120,0,108010.23139071465,0.6648000478744507,1.4506032466888428,10000,116545.69679760931,0.8871093392372131,0.4164575934410095,0.7803199887275696,0.8548135757446289,50000 -8541.744990348816,15.172938346862791,108430.22557520866,244061,0,108430.22557520866,0.6648000478744507,1.4506032466888428,10000,116999.6195745468,0.8856835961341858,0.4228624999523163,0.7803199887275696,0.8548135757446289,50000 -8582.07614517212,15.252984046936035,108850.48312044144,245008,0,108850.48312044144,0.6648000478744507,1.4506032466888428,10000,117460.33714318275,0.8870702981948853,0.418248176574707,0.7803199887275696,0.8548135757446289,50000 -8623.50287437439,15.32018232345581,109270.4581296444,245951,0,109270.4581296444,0.6648000478744507,1.4506032466888428,10000,117921.85386919975,0.8870898485183716,0.4204050004482269,0.7803199887275696,0.8548135757446289,50000 -8657.179294109344,15.38704538345337,109690.64672994614,246896,0,109690.64672994614,0.6648000478744507,1.4506032466888428,10000,118375.8336148262,0.8881640434265137,0.4150146245956421,0.7803199887275696,0.8548135757446289,50000 -8704.600757598877,15.464092493057253,110110.56650781631,247838,0,110110.56650781631,0.6648000478744507,1.4506032466888428,10000,118843.29978728294,0.8864648342132568,0.4250794947147369,0.7803199887275696,0.8548135757446289,50000 -8737.558348178864,15.524272203445436,110530.78303217888,248784,0,110530.78303217888,0.6648000478744507,1.4506032466888428,10000,119296.5821928978,0.8899218440055847,0.4135270118713379,0.7803199887275696,0.8548135757446289,50000 -8776.442138671875,15.602407932281494,110950.67167448996,249728,0,110950.67167448996,0.6648000478744507,1.4506032466888428,10000,119755.48139166832,0.887011706829071,0.4246560633182525,0.7803199887275696,0.8548135757446289,50000 -8808.311047792435,15.66840434074402,111370.60998010635,250671,0,111370.60998010635,0.6648000478744507,1.4506032466888428,10000,120207.40252876282,0.8860546946525574,0.4246920347213745,0.7803199887275696,0.8548135757446289,50000 -8849.612287044525,15.789583206176758,111790.58244228364,251616,0,111790.58244228364,0.6648000478744507,1.4506032466888428,10000,120668.84587788582,0.887011706829071,0.4161558449268341,0.7803199887275696,0.8548135757446289,50000 -8887.146436452866,15.8693106174469,112210.55100989342,252562,0,112210.55100989342,0.6648000478744507,1.4506032466888428,10000,121126.47691631316,0.8890429735183716,0.4178885519504547,0.7803199887275696,0.8548135757446289,50000 -8923.819966077805,15.935075044631958,112630.65174293518,253506,0,112630.65174293518,0.6648000478744507,1.4506032466888428,10000,121583.36565971376,0.8874413967132568,0.4144672751426697,0.7803199887275696,0.8548135757446289,50000 -8959.172646284103,16.009252786636353,113050.73615145683,254449,0,113050.73615145683,0.6648000478744507,1.4506032466888428,10000,122038.9248905182,0.88623046875,0.418743759393692,0.7803199887275696,0.8548135757446289,50000 -9004.886109113693,16.087283849716187,113470.6403222084,255392,0,113470.6403222084,0.6648000478744507,1.4506032466888428,10000,122504.66809415816,0.8873242139816284,0.4254908561706543,0.7803199887275696,0.8548135757446289,50000 -9049.49839234352,16.151769876480103,113890.8896214962,256336,0,113890.8896214962,0.6648000478744507,1.4506032466888428,10000,122969.64131331444,0.8916210532188416,0.4107459783554077,0.7803199887275696,0.8548135757446289,50000 -9090.309470415115,16.214890003204346,114310.84763216972,257282,0,114310.84763216972,0.6648000478744507,1.4506032466888428,10000,123430.52208042143,0.8884570002555847,0.4124673306941986,0.7803199887275696,0.8548135757446289,50000 -9128.919853687286,16.28192138671875,114730.95115160942,258220,0,114730.95115160942,0.6648000478744507,1.4506032466888428,10000,123889.350730896,0.8871874809265137,0.4193939566612243,0.7803199887275696,0.8548135757446289,50000 -9163.585408687592,16.34766459465027,115150.87838816644,259161,0,115150.87838816644,0.6648000478744507,1.4506032466888428,10000,124344.0568537712,0.8880859017372131,0.4152601063251495,0.7803199887275696,0.8548135757446289,50000 -9204.456241607666,16.427672386169434,115571.12068414688,260103,0,115571.12068414688,0.6648000478744507,1.4506032466888428,10000,124805.29729104042,0.8839648365974426,0.4242160022258758,0.7803199887275696,0.8548135757446289,50000 -9238.854766130447,16.508938312530518,115991.3933467865,261048,0,115991.3933467865,0.6648000478744507,1.4506032466888428,10000,125260.09818053246,0.8904882669448853,0.4089166224002838,0.7803199887275696,0.8548135757446289,50000 -9280.902841329576,16.59387707710266,116411.57664585114,261990,0,116411.57664585114,0.6648000478744507,1.4506032466888428,10000,125722.46207761765,0.8859374523162842,0.4217362105846405,0.7803199887275696,0.8548135757446289,50000 -9322.155301809313,16.67237138748169,116831.9596195221,262932,0,116831.9596195221,0.6648000478744507,1.4506032466888428,10000,126184.22378492355,0.8861132860183716,0.42596235871315,0.7803199887275696,0.8548135757446289,50000 -9355.41879272461,16.736831426620483,117252.12936592102,263881,0,117252.12936592102,0.6648000478744507,1.4506032466888428,10000,126637.7700855732,0.8869140148162842,0.4212360680103302,0.7803199887275696,0.8548135757446289,50000 -9399.8897895813,16.814719676971436,117672.37464976312,264823,0,117672.37464976312,0.6648000478744507,1.4506032466888428,10000,127102.61178898811,0.8873632550239563,0.418342113494873,0.7803199887275696,0.8548135757446289,50000 -9432.615530729294,16.88166904449463,118092.33594608308,265767,0,118092.33594608308,0.6648000478744507,1.4506032466888428,10000,127555.41320848464,0.88671875,0.420605331659317,0.7803199887275696,0.8548135757446289,50000 -9474.807350158691,16.959134101867676,118512.41550278664,266709,0,118512.41550278664,0.6648000478744507,1.4506032466888428,10000,128017.80984807014,0.8873828053474426,0.4178589880466461,0.7803199887275696,0.8548135757446289,50000 -9508.095239162443,17.03475069999695,118932.5847196579,267654,0,118932.5847196579,0.6648000478744507,1.4506032466888428,10000,128471.3901963234,0.8849608898162842,0.4240126311779022,0.7803199887275696,0.8548135757446289,50000 -9550.807997703552,17.11642861366272,119352.5260078907,268597,0,119352.5260078907,0.6648000478744507,1.4506032466888428,10000,128934.17479658128,0.8870702981948853,0.4228694140911102,0.7803199887275696,0.8548135757446289,50000 -9584.70620727539,17.181878328323364,119772.48274731636,269543,0,119772.48274731636,0.6648000478744507,1.4506032466888428,10000,129388.14326405524,0.8866991996765137,0.4204760789871216,0.7803199887275696,0.8548135757446289,50000 -9622.012517929075,17.249467372894287,120192.77802968024,270487,0,120192.77802968024,0.6648000478744507,1.4506032466888428,10000,129845.8600654602,0.88734370470047,0.4167658388614654,0.7803199887275696,0.8548135757446289,50000 -9659.829708576202,17.32775592803955,120612.67361688614,271429,0,120612.67361688614,0.6648000478744507,1.4506032466888428,10000,130303.69858121872,0.8888280987739563,0.4134787023067474,0.7803199887275696,0.8548135757446289,50000 -9692.951695919037,17.413233041763306,121032.8384103775,272378,0,121032.8384103775,0.6648000478744507,1.4506032466888428,10000,130757.11981844902,0.8880078196525574,0.4239377379417419,0.7803199887275696,0.8548135757446289,50000 -9736.4949696064,17.49224090576172,121452.86763739586,273322,0,121452.86763739586,0.6648000478744507,1.4506032466888428,10000,131220.82015681267,0.886035144329071,0.4260351955890655,0.7803199887275696,0.8548135757446289,50000 -9771.279657840729,17.556384801864624,121872.97571897508,274270,0,121872.97571897508,0.6648000478744507,1.4506032466888428,10000,131675.82518553734,0.8879492282867432,0.4197743535041809,0.7803199887275696,0.8548135757446289,50000 -9807.406393289566,17.63720178604126,122293.25101804732,275215,0,122293.25101804732,0.6648000478744507,1.4506032466888428,10000,132132.355342865,0.8881444931030273,0.4141191244125366,0.7803199887275696,0.8548135757446289,50000 -9839.993040561676,17.716108798980713,122713.5250134468,276158,0,122713.5250134468,0.6648000478744507,1.4506032466888428,10000,132585.34187698364,0.8878515362739563,0.4155722260475158,0.7803199887275696,0.8548135757446289,50000 -9876.90296959877,17.7959988117218,123133.66612696648,277100,0,123133.66612696648,0.6648000478744507,1.4506032466888428,10000,133042.52073717117,0.8867968320846558,0.4192649722099304,0.7803199887275696,0.8548135757446289,50000 -9915.872620105743,17.875136137008667,123553.5646944046,278042,0,123553.5646944046,0.6648000478744507,1.4506032466888428,10000,133501.51552557945,0.8861327767372131,0.4243049621582031,0.7803199887275696,0.8548135757446289,50000 -9952.77225279808,17.952096462249756,123973.49132466316,278985,0,123973.49132466316,0.6648000478744507,1.4506032466888428,10000,133958.46650099754,0.8908007740974426,0.4121743738651275,0.7803199887275696,0.8548135757446289,50000 -9995.665901184082,18.03502511978149,124393.40948843956,279931,0,124393.40948843956,0.6648000478744507,1.4506032466888428,10000,134421.4090359211,0.8890624642372131,0.420417308807373,0.7803199887275696,0.8548135757446289,50000 -10029.1331782341,18.101533889770508,124813.36013317108,280878,0,124813.36013317108,0.6648000478744507,1.4506032466888428,10000,134874.94128251076,0.8887109160423279,0.4124118089675903,0.7803199887275696,0.8548135757446289,50000 -10064.021770715714,18.181446075439453,125233.32454299928,281823,0,125233.32454299928,0.6648000478744507,1.4506032466888428,10000,135329.92215943336,0.8863085508346558,0.4174687564373016,0.7803199887275696,0.8548135757446289,50000 -10095.547145366669,18.245970249176025,125653.50318288805,282769,0,125653.50318288805,0.6648000478744507,1.4506032466888428,10000,135781.73765468597,0.8883593678474426,0.4142705202102661,0.7803199887275696,0.8548135757446289,50000 -10141.53751373291,18.329392433166504,126073.44761157036,283711,0,126073.44761157036,0.6648000478744507,1.4506032466888428,10000,136247.80353331566,0.8863866925239563,0.4193476140499115,0.7803199887275696,0.8548135757446289,50000 -10174.72700858116,18.393612146377563,126493.50467848778,284655,0,126493.50467848778,0.6648000478744507,1.4506032466888428,10000,136701.16173362732,0.8861523270606995,0.4190988838672638,0.7803199887275696,0.8548135757446289,50000 -10217.211532592772,18.482458353042603,126913.51318049432,285599,0,126913.51318049432,0.6648000478744507,1.4506032466888428,10000,137163.79151773453,0.8877733945846558,0.4184274971485138,0.7803199887275696,0.8548135757446289,50000 -10256.46224308014,18.5535101890564,127333.65675115584,286544,0,127333.65675115584,0.6648000478744507,1.4506032466888428,10000,137623.30382800102,0.8854687213897705,0.426918625831604,0.7803199887275696,0.8548135757446289,50000 -10292.60477733612,18.62070345878601,127753.75061297417,287489,0,127753.75061297417,0.6648000478744507,1.4506032466888428,10000,138079.65566945076,0.8874609470367432,0.4195753335952759,0.7803199887275696,0.8548135757446289,50000 -10332.257049798964,18.705175638198853,128174.01127171516,288425,0,128174.01127171516,0.6648000478744507,1.4506032466888428,10000,138539.70076584816,0.88525390625,0.4243783354759216,0.7803199887275696,0.8548135757446289,50000 -10367.236153364182,18.78646731376648,128594.14718818665,289368,0,128594.14718818665,0.6648000478744507,1.4506032466888428,10000,138994.94444060326,0.8885741829872131,0.4138832092285156,0.7803199887275696,0.8548135757446289,50000 -10403.037517786026,18.866474628448486,129014.123803854,290312,0,129014.123803854,0.6648000478744507,1.4506032466888428,10000,139450.84994983673,0.8881444931030273,0.4155890643596649,0.7803199887275696,0.8548135757446289,50000 -10447.48250246048,18.94978713989257,129434.1880440712,291257,0,129434.1880440712,0.6648000478744507,1.4506032466888428,10000,139915.48982930183,0.8851562142372131,0.4261019825935364,0.7803199887275696,0.8548135757446289,50000 -10483.340990066528,19.0203800201416,129854.40283584596,292204,0,129854.40283584596,0.6648000478744507,1.4506032466888428,10000,140371.68188858032,0.8864452838897705,0.4245300889015198,0.7803199887275696,0.8548135757446289,50000 -10519.48413681984,19.087130546569824,130274.47477340698,293150,0,130274.47477340698,0.6648000478744507,1.4506032466888428,10000,140828.01147723198,0.8864452838897705,0.4200622141361236,0.7803199887275696,0.8548135757446289,50000 -10556.83842110634,19.169212818145752,130694.4135849476,294091,0,130694.4135849476,0.6648000478744507,1.4506032466888428,10000,141285.43379950523,0.8884570002555847,0.4107588827610016,0.7803199887275696,0.8548135757446289,50000 -10604.105157375336,19.25607824325561,131114.35461211205,295034,0,131114.35461211205,0.6648000478744507,1.4506032466888428,10000,141752.77542972565,0.8865429759025574,0.4236370325088501,0.7803199887275696,0.8548135757446289,50000 -10644.770728588104,19.327077388763428,131534.4235270023,295980,0,131534.4235270023,0.6648000478744507,1.4506032466888428,10000,142213.62927746773,0.8881444931030273,0.4216232895851135,0.7803199887275696,0.8548135757446289,50000 -10686.206349372864,19.395113229751587,131954.65556120872,296924,0,131954.65556120872,0.6648000478744507,1.4506032466888428,10000,142675.4125571251,0.8879101276397705,0.4217333495616913,0.7803199887275696,0.8548135757446289,50000 -10731.38487648964,19.475355625152588,132374.87281131744,297867,0,132374.87281131744,0.6648000478744507,1.4506032466888428,10000,143140.9353840351,0.8878710865974426,0.4187273979187011,0.7803199887275696,0.8548135757446289,50000 -10764.83255124092,19.54459524154663,132794.94854402542,298811,0,132794.94854402542,0.6648000478744507,1.4506032466888428,10000,143594.57590150833,0.8876953125,0.4147098362445831,0.7803199887275696,0.8548135757446289,50000 -10801.17714715004,19.628756284713745,133215.20007777214,299750,0,133215.20007777214,0.6648000478744507,1.4506032466888428,10000,144051.30251002312,0.8880273103713989,0.4155641794204712,0.7803199887275696,0.8548135757446289,50000 -10846.996324777603,19.71213221549988,133635.22816705704,300691,0,133635.22816705704,0.6648000478744507,1.4506032466888428,10000,144517.28083229065,0.8881640434265137,0.4177298545837402,0.7803199887275696,0.8548135757446289,50000 -10887.950706481934,19.78096437454224,134055.29154729843,301638,0,134055.29154729843,0.6648000478744507,1.4506032466888428,10000,144978.41498851776,0.8869531154632568,0.4213081002235412,0.7803199887275696,0.8548135757446289,50000 -10925.710918664932,19.851025104522705,134475.1889846325,302583,0,134475.1889846325,0.6648000478744507,1.4506032466888428,10000,145436.19050335884,0.8878515362739563,0.4230779111385345,0.7803199887275696,0.8548135757446289,50000 -10966.428569555284,19.93898057937622,134895.2553062439,303524,0,134895.2553062439,0.6648000478744507,1.4506032466888428,10000,145897.11030721664,0.8898828029632568,0.4145982563495636,0.7803199887275696,0.8548135757446289,50000 -11005.957649946213,20.0074679851532,135315.58999705315,304470,0,135315.58999705315,0.6648000478744507,1.4506032466888428,10000,146357.09014439583,0.8882812261581421,0.4124145805835724,0.7803199887275696,0.8548135757446289,50000 -11039.826193094254,20.091625690460205,135735.683208704,305411,0,135735.683208704,0.6648000478744507,1.4506032466888428,10000,146811.18326807022,0.8877733945846558,0.4161712825298309,0.7803199887275696,0.8548135757446289,50000 -11084.333412647247,20.181591033935547,136155.78617358208,306355,0,136155.78617358208,0.6648000478744507,1.4506032466888428,10000,147275.9312517643,0.8863866925239563,0.4192544519901275,0.7803199887275696,0.8548135757446289,50000 -11121.695330381392,20.25333547592163,136576.10743117332,307301,0,136576.10743117332,0.6648000478744507,1.4506032466888428,10000,147733.73422026634,0.8872460722923279,0.4167300760746002,0.7803199887275696,0.8548135757446289,50000 -11155.53110909462,20.34227418899536,136996.03312301636,308242,0,136996.03312301636,0.6648000478744507,1.4506032466888428,10000,148187.63279628754,0.8857812285423279,0.4191692173480987,0.7803199887275696,0.8548135757446289,50000 -11196.13981294632,20.42733263969421,137416.065836668,309180,0,137416.065836668,0.6648000478744507,1.4506032466888428,10000,148648.40660881996,0.887988269329071,0.4194705188274383,0.7803199887275696,0.8548135757446289,50000 -11227.777802228928,20.50452852249145,137836.3591003418,310128,0,137836.3591003418,0.6648000478744507,1.4506032466888428,10000,149100.46276378632,0.8869140148162842,0.4242978692054748,0.7803199887275696,0.8548135757446289,50000 -11267.157366037369,20.590729236602783,138256.27770638466,311070,0,138256.27770638466,0.6648000478744507,1.4506032466888428,10000,149559.89429020882,0.8864843845367432,0.4207232892513275,0.7803199887275696,0.8548135757446289,50000 -11306.6024684906,20.678019762039185,138676.20446014404,312014,0,138676.20446014404,0.6648000478744507,1.4506032466888428,10000,150019.4015262127,0.8853515386581421,0.425651341676712,0.7803199887275696,0.8548135757446289,50000 -11341.178297281263,20.76548171043396,139096.4767267704,312961,0,139096.4767267704,0.6648000478744507,1.4506032466888428,10000,150474.38501358032,0.8886913657188416,0.4127460122108459,0.7803199887275696,0.8548135757446289,50000 -11379.45639872551,20.85531640052796,139516.60149145126,313905,0,139516.60149145126,0.6648000478744507,1.4506032466888428,10000,150932.9247033596,0.8878710865974426,0.4168930649757385,0.7803199887275696,0.8548135757446289,50000 -11414.343953609468,20.924494981765747,139936.63247156143,314851,0,139936.63247156143,0.6648000478744507,1.4506032466888428,10000,151387.96013522148,0.8854687213897705,0.4233700335025787,0.7803199887275696,0.8548135757446289,50000 -11457.576095819471,21.0146746635437,140356.7695221901,315793,0,140356.7695221901,0.6648000478744507,1.4506032466888428,10000,151851.4670085907,0.8864843845367432,0.4232950806617737,0.7803199887275696,0.8548135757446289,50000 -11488.859185695648,21.10109519958496,140776.95079994202,316739,0,140776.95079994202,0.6648000478744507,1.4506032466888428,10000,152303.06532382965,0.8882226347923279,0.4166227579116821,0.7803199887275696,0.8548135757446289,50000 -11527.099912643433,21.19072890281677,141197.1936097145,317682,0,141197.1936097145,0.6648000478744507,1.4506032466888428,10000,152761.68574547768,0.8858593702316284,0.4191854298114776,0.7803199887275696,0.8548135757446289,50000 -11571.23422384262,21.31320881843567,141617.29385328293,318622,0,141617.29385328293,0.6648000478744507,1.4506032466888428,10000,153226.09011101723,0.8885741829872131,0.4191855788230896,0.7803199887275696,0.8548135757446289,50000 -11606.005845546722,21.38654923439026,142037.1958515644,319569,0,142037.1958515644,0.6648000478744507,1.4506032466888428,10000,153680.8844909668,0.8874022960662842,0.4200645089149475,0.7803199887275696,0.8548135757446289,50000 -11641.539048671722,21.47724652290344,142457.30712151527,320513,0,142457.30712151527,0.6648000478744507,1.4506032466888428,10000,154136.66727113724,0.8875195384025574,0.4228684604167938,0.7803199887275696,0.8548135757446289,50000 -11685.531976222992,21.56491112709045,142877.6751279831,321457,0,142877.6751279831,0.6648000478744507,1.4506032466888428,10000,154601.16312241554,0.8886523246765137,0.4192448258399963,0.7803199887275696,0.8548135757446289,50000 -11727.38549900055,21.639564275741577,143297.70597314835,322403,0,143297.70597314835,0.6648000478744507,1.4506032466888428,10000,155063.17006206512,0.8864648342132568,0.4219618141651153,0.7803199887275696,0.8548135757446289,50000 -11760.326476812364,21.71282839775085,143717.72579455376,323349,0,143717.72579455376,0.6648000478744507,1.4506032466888428,10000,155516.2507390976,0.8867577910423279,0.4161965548992157,0.7803199887275696,0.8548135757446289,50000 -11804.24805378914,21.801358461380005,144137.9409184456,324290,0,144137.9409184456,0.6648000478744507,1.4506032466888428,10000,155980.52302742004,0.8874022960662842,0.4182784259319305,0.7803199887275696,0.8548135757446289,50000 -11837.51058626175,21.88632369041443,144558.39869880676,325237,0,144558.39869880676,0.6648000478744507,1.4506032466888428,10000,156434.3753077984,0.8866991996765137,0.4203983545303345,0.7803199887275696,0.8548135757446289,50000 -11873.732411623,21.974764585494995,144978.63285660744,326177,0,144978.63285660744,0.6648000478744507,1.4506032466888428,10000,156890.96754074097,0.88916015625,0.4168655574321747,0.7803199887275696,0.8548135757446289,50000 -11910.18722462654,22.070751667022705,145398.58534121513,327120,0,145398.58534121513,0.6648000478744507,1.4506032466888428,10000,157347.51926136017,0.8898046612739563,0.4155591130256653,0.7803199887275696,0.8548135757446289,50000 -11945.410465955734,22.175304174423218,145818.65999770164,328063,0,145818.65999770164,0.6648000478744507,1.4506032466888428,10000,157802.96967935562,0.8884375095367432,0.4140132665634155,0.7803199887275696,0.8548135757446289,50000 -11980.082593917848,22.26838541030884,146238.7226538658,329007,0,146238.7226538658,0.6648000478744507,1.4506032466888428,10000,158257.84491205215,0.8872265219688416,0.4143014550209045,0.7803199887275696,0.8548135757446289,50000 -12024.194583177568,22.361082553863525,146659.0313911438,329951,0,146659.0313911438,0.6648000478744507,1.4506032466888428,10000,158722.40548229218,0.8880859017372131,0.4180927276611328,0.7803199887275696,0.8548135757446289,50000 -12061.329034805298,22.44697642326355,147079.2585787773,330899,0,147079.2585787773,0.6648000478744507,1.4506032466888428,10000,159179.90055942535,0.8866601586341858,0.4193836450576782,0.7803199887275696,0.8548135757446289,50000 -12107.44764471054,23.31242060661316,147498.69048261642,331840,0,147498.69048261642,0.6648000478744507,1.4506032466888428,10000,159646.3641116619,0.8866406083106995,0.4111728072166443,0.7803199887275696,0.8548135757446289,50000 -12141.548724412918,23.3977952003479,147918.98912215233,332787,0,147918.98912215233,0.6648000478744507,1.4506032466888428,10000,160100.89725995064,0.8870898485183716,0.427180141210556,0.7803199887275696,0.8548135757446289,50000 -12185.530223608015,23.488813161849976,148338.922219038,333723,0,148338.922219038,0.6648000478744507,1.4506032466888428,10000,160564.95006656647,0.8864062428474426,0.4243050217628479,0.7803199887275696,0.8548135757446289,50000 -12226.760622501371,23.583487033844,148758.86404037476,334663,0,148758.86404037476,0.6648000478744507,1.4506032466888428,10000,161026.26504921913,0.8871679306030273,0.4189110994338989,0.7803199887275696,0.8548135757446289,50000 -12264.491759061812,23.673017024993896,149179.06500458717,335606,0,149179.06500458717,0.6648000478744507,1.4506032466888428,10000,161484.33428955078,0.88623046875,0.4251594543457031,0.7803199887275696,0.8548135757446289,50000 -12306.18427157402,23.765666007995605,149599.2316122055,336548,0,149599.2316122055,0.6648000478744507,1.4506032466888428,10000,161946.33445000648,0.8883788585662842,0.4107557237148285,0.7803199887275696,0.8548135757446289,50000 -12351.189871788025,23.86135625839233,150019.19176864624,337490,0,150019.19176864624,0.6648000478744507,1.4506032466888428,10000,162411.44314026833,0.8858007788658142,0.4200242757797241,0.7803199887275696,0.8548135757446289,50000 -12385.872237205504,23.952771425247192,150439.1379084587,338430,0,150439.1379084587,0.6648000478744507,1.4506032466888428,10000,162866.21015238762,0.8881054520606995,0.4224624037742615,0.7803199887275696,0.8548135757446289,50000 -12421.92520046234,24.09778380393982,150859.23227977753,339370,0,150859.23227977753,0.6648000478744507,1.4506032466888428,10000,163322.5501112938,0.8855859041213989,0.4248220324516296,0.7803199887275696,0.8548135757446289,50000 -12459.357788085938,24.19091844558716,151279.18514490128,340311,0,151279.18514490128,0.6648000478744507,1.4506032466888428,10000,163780.07631373403,0.8870312571525574,0.4190521538257599,0.7803199887275696,0.8548135757446289,50000 -12504.30022096634,24.288305044174194,151699.11443638802,341256,0,151699.11443638802,0.6648000478744507,1.4506032466888428,10000,164245.09310364723,0.8864062428474426,0.4155356585979461,0.7803199887275696,0.8548135757446289,50000 -12544.256209850311,24.38241219520569,152119.28486824036,342201,0,152119.28486824036,0.6648000478744507,1.4506032466888428,10000,164705.3630282879,0.8872656226158142,0.4203912913799286,0.7803199887275696,0.8548135757446289,50000 -12587.055959939957,24.475526571273804,152539.3480424881,343145,0,152539.3480424881,0.6648000478744507,1.4506032466888428,10000,165168.36728596687,0.8889843821525574,0.4194121360778808,0.7803199887275696,0.8548135757446289,50000 -12622.95605802536,24.554410457611084,152959.46968698502,344090,0,152959.46968698502,0.6648000478744507,1.4506032466888428,10000,165624.51542139053,0.8886132836341858,0.4171028733253479,0.7803199887275696,0.8548135757446289,50000 -12665.410373926165,24.657466411590576,153379.60203409195,345033,0,153379.60203409195,0.6648000478744507,1.4506032466888428,10000,166087.2524061203,0.8860741853713989,0.4267341494560241,0.7803199887275696,0.8548135757446289,50000 -12699.949487686155,24.752000331878666,153799.6966097355,345977,0,153799.6966097355,0.6648000478744507,1.4506032466888428,10000,166542.02790880203,0.8871288895606995,0.4171859622001648,0.7803199887275696,0.8548135757446289,50000 -12746.5108397007,24.84833550453186,154219.81880021095,346918,0,154219.81880021095,0.6648000478744507,1.4506032466888428,10000,167008.85578632355,0.887988269329071,0.4203246533870697,0.7803199887275696,0.8548135757446289,50000 -12784.672406435013,24.939489126205444,154639.78215909004,347864,0,154639.78215909004,0.6648000478744507,1.4506032466888428,10000,167467.11938381195,0.8874804377555847,0.4153309464454651,0.7803199887275696,0.8548135757446289,50000 -12819.244787454603,25.030781984329224,155059.82725691795,348810,0,155059.82725691795,0.6648000478744507,1.4506032466888428,10000,167921.87595415115,0.8859374523162842,0.4238024055957794,0.7803199887275696,0.8548135757446289,50000 -12855.2223072052,25.12555241584778,155480.12934875488,349753,0,155480.12934875488,0.6648000478744507,1.4506032466888428,10000,168378.29820513725,0.8904492259025574,0.4134562611579895,0.7803199887275696,0.8548135757446289,50000 -12898.995729207993,25.21956491470337,155900.0498933792,350694,0,155900.0498933792,0.6648000478744507,1.4506032466888428,10000,168842.13325691223,0.8905858993530273,0.4152538478374481,0.7803199887275696,0.8548135757446289,50000 -12936.320994138718,25.31429362297058,156320.398665905,351640,0,156320.398665905,0.6648000478744507,1.4506032466888428,10000,169299.9502491951,0.887499988079071,0.4161133468151092,0.7803199887275696,0.8548135757446289,50000 -12981.430022001266,25.41551327705384,156740.2870953083,352583,0,156740.2870953083,0.6648000478744507,1.4506032466888428,10000,169765.09603238106,0.88587886095047,0.4204886257648468,0.7803199887275696,0.8548135757446289,50000 -13021.78906273842,25.50675082206726,157160.15752458572,353525,0,157160.15752458572,0.6648000478744507,1.4506032466888428,10000,170225.4692606926,0.8861718773841858,0.4195727407932281,0.7803199887275696,0.8548135757446289,50000 -13060.845579624176,25.601975679397583,157580.12749171257,354468,0,157580.12749171257,0.6648000478744507,1.4506032466888428,10000,170684.63828587532,0.8877343535423279,0.4124457836151123,0.7803199887275696,0.8548135757446289,50000 -13101.89858865738,25.69657111167908,158000.41928219795,355410,0,158000.41928219795,0.6648000478744507,1.4506032466888428,10000,171146.12599873543,0.8864648342132568,0.4209831058979034,0.7803199887275696,0.8548135757446289,50000 -13144.91348528862,25.793184995651245,158420.6190032959,356353,0,158420.6190032959,0.6648000478744507,1.4506032466888428,10000,171609.484395504,0.8862890601158142,0.4223797619342804,0.7803199887275696,0.8548135757446289,50000 -13179.165003061296,25.876455068588257,158840.51721048355,357299,0,158840.51721048355,0.6648000478744507,1.4506032466888428,10000,172063.7651720047,0.8887499570846558,0.4185085296630859,0.7803199887275696,0.8548135757446289,50000 -13224.510422706604,25.97172403335572,159260.69488954544,358239,0,159260.69488954544,0.6648000478744507,1.4506032466888428,10000,172529.43064379692,0.8867577910423279,0.4219204485416412,0.7803199887275696,0.8548135757446289,50000 -13260.374766349792,26.063395261764526,159680.70743894577,359182,0,159680.70743894577,0.6648000478744507,1.4506032466888428,10000,172985.44624853134,0.8855664134025574,0.4237650334835052,0.7803199887275696,0.8548135757446289,50000 -13297.446984291077,26.16198706626892,160100.66970348358,360127,0,160100.66970348358,0.6648000478744507,1.4506032466888428,10000,173442.63001036644,0.8892382383346558,0.414288729429245,0.7803199887275696,0.8548135757446289,50000 -13330.930730581284,26.26072931289673,160520.55683374405,361070,0,160520.55683374405,0.6648000478744507,1.4506032466888428,10000,173896.1469092369,0.88832026720047,0.4119943976402282,0.7803199887275696,0.8548135757446289,50000 -13376.318539857864,26.35774397850037,160940.7024629116,362015,0,160940.7024629116,0.6648000478744507,1.4506032466888428,10000,174361.82907938957,0.8843945264816284,0.4265358746051788,0.7803199887275696,0.8548135757446289,50000 -13409.688135623932,26.49013066291809,161361.03342723846,362960,0,161361.03342723846,0.6648000478744507,1.4506032466888428,10000,174815.7095863819,0.8865624666213989,0.4253043234348297,0.7803199887275696,0.8548135757446289,50000 -13449.352699279783,26.583407163619995,161781.51662492752,363901,0,161781.51662492752,0.6648000478744507,1.4506032466888428,10000,175275.9978826046,0.8876562118530273,0.4183862209320068,0.7803199887275696,0.8548135757446289,50000 -13485.275573015211,26.678946018219,162201.56584835052,364845,0,162201.56584835052,0.6648000478744507,1.4506032466888428,10000,175732.1130809784,0.8862695097923279,0.4161961674690246,0.7803199887275696,0.8548135757446289,50000 -13518.878060102465,26.776683807373047,162621.76594257355,365789,0,162621.76594257355,0.6648000478744507,1.4506032466888428,10000,176186.0605404377,0.8881444931030273,0.419982761144638,0.7803199887275696,0.8548135757446289,50000 -13567.830877304075,26.87555503845215,163042.15055036545,366733,0,163042.15055036545,0.6648000478744507,1.4506032466888428,10000,176655.54402661324,0.8887109160423279,0.416282057762146,0.7803199887275696,0.8548135757446289,50000 -13608.676812171936,26.95441627502441,163462.26739168167,367683,0,163462.26739168167,0.6648000478744507,1.4506032466888428,10000,177116.63298726082,0.8882421851158142,0.4190069139003753,0.7803199887275696,0.8548135757446289,50000 -13655.980248212814,27.03390669822693,163882.3846206665,368627,0,163882.3846206665,0.6648000478744507,1.4506032466888428,10000,177584.1802227497,0.8858202695846558,0.4272821247577667,0.7803199887275696,0.8548135757446289,50000 -13689.44184088707,27.11146354675293,164302.32819080353,369571,0,164302.32819080353,0.6648000478744507,1.4506032466888428,10000,178037.7103281021,0.8865429759025574,0.4198426008224487,0.7803199887275696,0.8548135757446289,50000 -13726.37277674675,27.207013607025143,164722.48730134964,370512,0,164722.48730134964,0.6648000478744507,1.4506032466888428,10000,178494.94316363335,0.8881640434265137,0.4155753254890442,0.7803199887275696,0.8548135757446289,50000 -13766.942764282228,27.30364155769348,165142.49481630325,371455,0,165142.49481630325,0.6648000478744507,1.4506032466888428,10000,178955.66393136978,0.88916015625,0.4131352603435516,0.7803199887275696,0.8548135757446289,50000 -13800.73223042488,27.383100271224976,165562.63774585724,372401,0,165562.63774585724,0.6648000478744507,1.4506032466888428,10000,179409.72325754166,0.88525390625,0.4247466921806335,0.7803199887275696,0.8548135757446289,50000 -13843.269656419754,27.48445630073548,165982.9108800888,373342,0,165982.9108800888,0.6648000478744507,1.4506032466888428,10000,179872.6829378605,0.8884375095367432,0.4191824495792389,0.7803199887275696,0.8548135757446289,50000 -13886.068606615068,27.56479525566101,166403.01735568047,374287,0,166403.01735568047,0.6648000478744507,1.4506032466888428,10000,180335.7163796425,0.8906054496765137,0.4155364036560058,0.7803199887275696,0.8548135757446289,50000 -13917.866647958755,27.90357255935669,166822.80471730232,375230,0,166822.80471730232,0.6648000478744507,1.4506032466888428,10000,180787.6889693737,0.8887304663658142,0.4109485149383545,0.7803199887275696,0.8548135757446289,50000 -13962.844014406204,28.00294804573059,167243.02319812775,376172,0,167243.02319812775,0.6648000478744507,1.4506032466888428,10000,181253.0316066742,0.88587886095047,0.4229264855384826,0.7803199887275696,0.8548135757446289,50000 -14007.10686326027,28.10410571098328,167663.00806236267,377119,0,167663.00806236267,0.6648000478744507,1.4506032466888428,10000,181717.4283750057,0.8873242139816284,0.4139610826969147,0.7803199887275696,0.8548135757446289,50000 -14040.968413114548,28.18370485305786,168082.953353405,378063,0,168082.953353405,0.6648000478744507,1.4506032466888428,10000,182171.3619320393,0.8871679306030273,0.4175764620304107,0.7803199887275696,0.8548135757446289,50000 -14077.369477033615,28.28520369529724,168502.90359401703,379000,0,168502.90359401703,0.6648000478744507,1.4506032466888428,10000,182627.8619401455,0.8867382407188416,0.4216572046279907,0.7803199887275696,0.8548135757446289,50000 -14124.485233306885,28.38457489013672,168923.17162704468,379940,0,168923.17162704468,0.6648000478744507,1.4506032466888428,10000,183095.3925318718,0.8876171708106995,0.4176125228404999,0.7803199887275696,0.8548135757446289,50000 -14163.1655189991,28.470067501068115,169343.40920114517,380885,0,169343.40920114517,0.6648000478744507,1.4506032466888428,10000,183554.4429168701,0.8860155940055847,0.426149308681488,0.7803199887275696,0.8548135757446289,50000 -14201.979838609695,28.55268931388855,169763.40131759644,381830,0,169763.40131759644,0.6648000478744507,1.4506032466888428,10000,184013.37898135185,0.8877539038658142,0.4175314307212829,0.7803199887275696,0.8548135757446289,50000 -14242.088454008102,28.640817880630493,170183.68056845665,382773,0,170183.68056845665,0.6648000478744507,1.4506032466888428,10000,184473.9022257328,0.88636714220047,0.4217416942119598,0.7803199887275696,0.8548135757446289,50000 -14280.231041908264,28.722613096237183,170603.55612325668,383717,0,170603.55612325668,0.6648000478744507,1.4506032466888428,10000,184932.04956531525,0.8875781297683716,0.4189814925193786,0.7803199887275696,0.8548135757446289,50000 -14317.67663049698,28.820887088775635,171023.6480166912,384657,0,171023.6480166912,0.6648000478744507,1.4506032466888428,10000,185389.73262786865,0.887499988079071,0.4142645299434662,0.7803199887275696,0.8548135757446289,50000 -14362.350728034971,28.92026162147522,171443.81467032433,385599,0,171443.81467032433,0.6648000478744507,1.4506032466888428,10000,185854.7215340137,0.8870898485183716,0.4226619601249695,0.7803199887275696,0.8548135757446289,50000 -14394.879940986631,29.00299835205078,171863.97700834274,386545,0,171863.97700834274,0.6648000478744507,1.4506032466888428,10000,186307.54375863075,0.886035144329071,0.4234894216060638,0.7803199887275696,0.8548135757446289,50000 -14432.549918174744,29.103593349456787,172284.1622555256,387487,0,172284.1622555256,0.6648000478744507,1.4506032466888428,10000,186765.5475420952,0.8861913681030273,0.4212893843650818,0.7803199887275696,0.8548135757446289,50000 -14470.37551188469,29.203126668930054,172704.41179394722,388432,0,172704.41179394722,0.6648000478744507,1.4506032466888428,10000,187223.76943206787,0.887499988079071,0.4149161875247955,0.7803199887275696,0.8548135757446289,50000 -14505.541404247284,29.30449271202088,173124.3463385105,389377,0,173124.3463385105,0.6648000478744507,1.4506032466888428,10000,187679.0199213028,0.8874413967132568,0.4166697859764099,0.7803199887275696,0.8548135757446289,50000 -14545.296684980392,29.40795946121216,173544.2095386982,390321,0,173544.2095386982,0.6648000478744507,1.4506032466888428,10000,188138.7891843319,0.8872265219688416,0.4215186536312103,0.7803199887275696,0.8548135757446289,50000 -14579.097965717316,29.495544910430908,173964.30621051788,391267,0,173964.30621051788,0.6648000478744507,1.4506032466888428,10000,188592.8221981525,0.8884375095367432,0.4242092669010162,0.7803199887275696,0.8548135757446289,50000 -14618.065967798231,29.601163387298584,174384.32402396202,392207,0,174384.32402396202,0.6648000478744507,1.4506032466888428,10000,189051.9608139992,0.8878905773162842,0.4217130243778229,0.7803199887275696,0.8548135757446289,50000 -14655.7409658432,29.7026801109314,174804.2192106247,393149,0,174804.2192106247,0.6648000478744507,1.4506032466888428,10000,189509.67946982384,0.8867577910423279,0.4197116792201996,0.7803199887275696,0.8548135757446289,50000 -14697.594356775284,29.80150079727173,175224.10652852058,394093,0,175224.10652852058,0.6648000478744507,1.4506032466888428,10000,189971.56618881223,0.8876757621765137,0.416228324174881,0.7803199887275696,0.8548135757446289,50000 -14736.21154475212,29.8844575881958,175644.32959270477,395038,0,175644.32959270477,0.6648000478744507,1.4506032466888428,10000,190430.5368359089,0.8874413967132568,0.4183365404605865,0.7803199887275696,0.8548135757446289,50000 -14779.204402923584,29.98852634429932,176064.4482076168,395981,0,176064.4482076168,0.6648000478744507,1.4506032466888428,10000,190893.79946780205,0.8874804377555847,0.4170241057872772,0.7803199887275696,0.8548135757446289,50000 -14817.607572078705,30.07315492630005,176484.37260246277,396924,0,176484.37260246277,0.6648000478744507,1.4506032466888428,10000,191352.2600080967,0.8891991972923279,0.4172369539737701,0.7803199887275696,0.8548135757446289,50000 -14853.429702758787,30.15562510490417,176904.55462956429,397870,0,176904.55462956429,0.6648000478744507,1.4506032466888428,10000,191808.3955821991,0.8872265219688416,0.4217978119850158,0.7803199887275696,0.8548135757446289,50000 -14888.592317581177,30.25603485107422,177324.88240146637,398814,0,177324.88240146637,0.6648000478744507,1.4506032466888428,10000,192264.0342531204,0.8900976181030273,0.4069123864173889,0.7803199887275696,0.8548135757446289,50000 -14926.936472415924,30.356215715408325,177745.00523638725,399751,0,177745.00523638725,0.6648000478744507,1.4506032466888428,10000,192722.6485464573,0.8873632550239563,0.4189270436763763,0.7803199887275696,0.8548135757446289,50000 -14965.11044239998,30.468629837036133,178164.88333821297,400693,0,178164.88333821297,0.6648000478744507,1.4506032466888428,10000,193180.8615772724,0.8883788585662842,0.4117428958415985,0.7803199887275696,0.8548135757446289,50000 -15007.873401165009,30.57186913490296,178585.00466036797,401638,0,178585.00466036797,0.6648000478744507,1.4506032466888428,10000,193643.89689183235,0.8850390315055847,0.4258567094802856,0.7803199887275696,0.8548135757446289,50000 -15042.488958358765,30.655258893966675,179004.96719145775,402585,0,179004.96719145775,0.6648000478744507,1.4506032466888428,10000,194098.60548830032,0.8865624666213989,0.4191708266735077,0.7803199887275696,0.8548135757446289,50000 -15088.383921384811,30.7644944190979,179425.1990661621,403526,0,179425.1990661621,0.6648000478744507,1.4506032466888428,10000,194564.8892962933,0.8875585794448853,0.4197700917720794,0.7803199887275696,0.8548135757446289,50000 -15131.103526830671,30.86377716064453,179845.30004048347,404467,0,179845.30004048347,0.6648000478744507,1.4506032466888428,10000,195027.85599136355,0.8872265219688416,0.4191464781761169,0.7803199887275696,0.8548135757446289,50000 -15166.39280295372,30.94795870780945,180265.17570757863,405411,0,180265.17570757863,0.6648000478744507,1.4506032466888428,10000,195483.1527121067,0.8865624666213989,0.4221104681491852,0.7803199887275696,0.8548135757446289,50000 -15206.514477968216,31.055206060409542,180685.35479688644,406350,0,180685.35479688644,0.6648000478744507,1.4506032466888428,10000,195943.6081705093,0.8861327767372131,0.4223325252532959,0.7803199887275696,0.8548135757446289,50000 -15244.572483301165,31.16160798072815,181105.31928634644,407288,0,181105.31928634644,0.6648000478744507,1.4506032466888428,10000,196401.7843978405,0.888476550579071,0.4176208078861236,0.7803199887275696,0.8548135757446289,50000 -15283.238829135897,31.26688051223755,181525.1695902348,408228,0,181525.1695902348,0.6648000478744507,1.4506032466888428,10000,196860.4535934925,0.8869140148162842,0.4151829779148102,0.7803199887275696,0.8548135757446289,50000 -15324.498523712158,31.37296724319458,181945.4384617805,409172,0,181945.4384617805,0.6648000478744507,1.4506032466888428,10000,197322.14072799683,0.8871874809265137,0.4227842986583709,0.7803199887275696,0.8548135757446289,50000 -15366.648778438568,31.45956587791443,182365.6897525788,410120,0,182365.6897525788,0.6648000478744507,1.4506032466888428,10000,197784.67691898343,0.8855859041213989,0.4234414994716644,0.7803199887275696,0.8548135757446289,50000 -15401.639070272446,31.566688537597656,182785.7644975185,411060,0,182785.7644975185,0.6648000478744507,1.4506032466888428,10000,198239.89599585533,0.8866015672683716,0.4224465787410736,0.7803199887275696,0.8548135757446289,50000 -15444.472546577454,31.67509889602661,183206.07600712776,411999,0,183206.07600712776,0.6648000478744507,1.4506032466888428,10000,198703.19691753387,0.8878124952316284,0.4114729464054107,0.7803199887275696,0.8548135757446289,50000 -15478.594518899918,31.759069442749023,183626.06824731827,412943,0,183626.06824731827,0.6648000478744507,1.4506032466888428,10000,199157.4434304237,0.8874609470367432,0.419604629278183,0.7803199887275696,0.8548135757446289,50000 -15517.017944574356,31.868417978286743,184045.93671751025,413886,0,184045.93671751025,0.6648000478744507,1.4506032466888428,10000,199615.89252972603,0.8876367211341858,0.4223132729530334,0.7803199887275696,0.8548135757446289,50000 -15558.254828453064,31.97731304168701,184465.8177063465,414826,0,184465.8177063465,0.6648000478744507,1.4506032466888428,10000,200077.1666204929,0.8863281011581421,0.427275151014328,0.7803199887275696,0.8548135757446289,50000 -15593.294227838516,32.06606459617615,184885.7917456627,415771,0,184885.7917456627,0.6648000478744507,1.4506032466888428,10000,200532.31643295288,0.8896093368530273,0.4156470894813537,0.7803199887275696,0.8548135757446289,50000 -15628.56196141243,32.17793083190918,185305.9206287861,416709,0,185305.9206287861,0.6648000478744507,1.4506032466888428,10000,200987.87162804604,0.8847265243530273,0.4236800968647003,0.7803199887275696,0.8548135757446289,50000 -15665.49923491478,32.28747606277466,185726.07363700867,417649,0,185726.07363700867,0.6648000478744507,1.4506032466888428,10000,201445.11868214607,0.8902343511581421,0.4112899601459503,0.7803199887275696,0.8548135757446289,50000 -15709.535342693329,32.39239454269409,186146.1403939724,418587,0,186146.1403939724,0.6648000478744507,1.4506032466888428,10000,201909.37362527847,0.887499988079071,0.4170810580253601,0.7803199887275696,0.8548135757446289,50000 -15751.373156309128,32.48406267166138,186566.229950428,419523,0,186566.229950428,0.6648000478744507,1.4506032466888428,10000,202371.4406733513,0.8847265243530273,0.4248861670494079,0.7803199887275696,0.8548135757446289,50000 -15784.72314286232,32.571659564971924,186986.4781191349,420468,0,186986.4781191349,0.6648000478744507,1.4506032466888428,10000,202825.1754875183,0.8904687166213989,0.4141132235527038,0.7803199887275696,0.8548135757446289,50000 -15825.103631973268,32.67925763130188,187406.4428319931,421408,0,187406.4428319931,0.6648000478744507,1.4506032466888428,10000,203285.67652893063,0.8883984088897705,0.4225490093231201,0.7803199887275696,0.8548135757446289,50000 -15865.375450849531,32.78850722312927,187826.50917100903,422352,0,187826.50917100903,0.6648000478744507,1.4506032466888428,10000,203746.17261266708,0.8893554210662842,0.4090215563774109,0.7803199887275696,0.8548135757446289,50000 -15906.216959238052,32.89739751815796,188246.44891786567,423296,0,188246.44891786567,0.6648000478744507,1.4506032466888428,10000,204207.1106221676,0.8877343535423279,0.4170146882534027,0.7803199887275696,0.8548135757446289,50000 -15941.91059589386,33.02621150016785,188666.66090917587,424240,0,188666.66090917587,0.6648000478744507,1.4506032466888428,10000,204663.19302153587,0.8865820169448853,0.4170560240745544,0.7803199887275696,0.8548135757446289,50000 -15981.843644618988,33.13454604148865,189086.88722491264,425183,0,189086.88722491264,0.6648000478744507,1.4506032466888428,10000,205123.508122921,0.8876562118530273,0.4162637591361999,0.7803199887275696,0.8548135757446289,50000 -16026.771565198898,33.24440002441406,189507.0796597004,426119,0,189507.0796597004,0.6648000478744507,1.4506032466888428,10000,205588.7849709988,0.8857812285423279,0.4212409257888794,0.7803199887275696,0.8548135757446289,50000 -16063.07480955124,33.3340208530426,189927.3677642345,427064,0,189927.3677642345,0.6648000478744507,1.4506032466888428,10000,206045.5145745277,0.8869726657867432,0.4184642732143402,0.7803199887275696,0.8548135757446289,50000 -16098.452381849287,33.44491386413574,190347.5967078209,428007,0,190347.5967078209,0.6648000478744507,1.4506032466888428,10000,206501.27928853035,0.8869726657867432,0.4254200160503387,0.7803199887275696,0.8548135757446289,50000 -16144.987778663635,33.55448365211487,190767.7202951908,428950,0,190767.7202951908,0.6648000478744507,1.4506032466888428,10000,206968.0955746174,0.8871093392372131,0.4201515614986419,0.7803199887275696,0.8548135757446289,50000 -16180.92680835724,33.644153118133545,191187.6201922893,429896,0,191187.6201922893,0.6648000478744507,1.4506032466888428,10000,207424.07229161265,0.8857030868530273,0.4262813031673431,0.7803199887275696,0.8548135757446289,50000 -16217.364193677902,33.7524573802948,191607.90561056137,430841,0,191607.90561056137,0.6648000478744507,1.4506032466888428,10000,207880.95266985893,0.8873046636581421,0.4140742123126983,0.7803199887275696,0.8548135757446289,50000 -16259.892963647842,33.86008358001709,192027.8820848465,431784,0,192027.8820848465,0.6648000478744507,1.4506032466888428,10000,208343.6129004956,0.8875781297683716,0.4148840606212616,0.7803199887275696,0.8548135757446289,50000 -16294.80163049698,33.95030069351196,192447.784427166,432731,0,192447.784427166,0.6648000478744507,1.4506032466888428,10000,208798.56092476845,0.8882616758346558,0.4209084510803222,0.7803199887275696,0.8548135757446289,50000 -16332.21807217598,34.1000452041626,192867.92846989632,433674,0,192867.92846989632,0.6648000478744507,1.4506032466888428,10000,209256.31863760948,0.8856444954872131,0.4259785115718841,0.7803199887275696,0.8548135757446289,50000 -16366.020912408829,34.18677401542664,193288.04555511475,434615,0,193288.04555511475,0.6648000478744507,1.4506032466888428,10000,209710.3728477955,0.8866601586341858,0.4217454791069031,0.7803199887275696,0.8548135757446289,50000 -16409.411342144012,34.299379110336304,193708.16326212883,435556,0,193708.16326212883,0.6648000478744507,1.4506032466888428,10000,210174.040797472,0.8860546946525574,0.4153732657432556,0.7803199887275696,0.8548135757446289,50000 -16446.87612438202,34.396831035614014,194128.2306342125,436500,0,194128.2306342125,0.6648000478744507,1.4506032466888428,10000,210631.71815299988,0.8884179592132568,0.4191081523895263,0.7803199887275696,0.8548135757446289,50000 -16483.761477947235,34.50777578353882,194548.28590345383,437447,0,194548.28590345383,0.6648000478744507,1.4506032466888428,10000,211088.8171780109,0.8879492282867432,0.421378344297409,0.7803199887275696,0.8548135757446289,50000 -16528.610441684723,34.61800146102905,194968.3395268917,438388,0,194968.3395268917,0.6648000478744507,1.4506032466888428,10000,211553.87763905525,0.8866991996765137,0.4202461242675781,0.7803199887275696,0.8548135757446289,50000 -16571.6326918602,34.726662397384644,195388.52269601825,439332,0,195388.52269601825,0.6648000478744507,1.4506032466888428,10000,212017.23931241035,0.8890234231948853,0.4170263707637787,0.7803199887275696,0.8548135757446289,50000 -16609.84241771698,34.8164746761322,195808.53741788864,440278,0,195808.53741788864,0.6648000478744507,1.4506032466888428,10000,212475.6008841992,0.8864843845367432,0.4210885167121887,0.7803199887275696,0.8548135757446289,50000 -16656.709317207336,34.92500853538513,196228.62652277944,441221,0,196228.62652277944,0.6648000478744507,1.4506032466888428,10000,212942.71294903755,0.8875195384025574,0.4173341393470764,0.7803199887275696,0.8548135757446289,50000 -16700.883150815964,35.01406455039978,196648.6313741207,442159,0,196648.6313741207,0.6648000478744507,1.4506032466888428,10000,213407.027728796,0.8889257907867432,0.4166601002216339,0.7803199887275696,0.8548135757446289,50000 -16744.626331090927,35.10526466369629,197069.1317648888,443101,0,197069.1317648888,0.6648000478744507,1.4506032466888428,10000,213871.40952277184,0.8854296803474426,0.422122985124588,0.7803199887275696,0.8548135757446289,50000 -16785.014815568924,35.19585084915161,197489.02456712723,444043,0,197489.02456712723,0.6648000478744507,1.4506032466888428,10000,214331.8288094997,0.8884179592132568,0.4219891726970672,0.7803199887275696,0.8548135757446289,50000 -16827.93798494339,35.28695893287659,197909.07106089592,444988,0,197909.07106089592,0.6648000478744507,1.4506032466888428,10000,214794.9362232685,0.8905078172683716,0.4116933345794678,0.7803199887275696,0.8548135757446289,50000 -16870.90295481682,35.37960910797119,198329.2227098941,445931,0,198329.2227098941,0.6648000478744507,1.4506032466888428,10000,215258.19314956665,0.8903515338897705,0.4074158072471618,0.7803199887275696,0.8548135757446289,50000 -16915.30412220955,35.47115755081177,198749.5303769112,446876,0,198749.5303769112,0.6648000478744507,1.4506032466888428,10000,215723.04054903984,0.885546863079071,0.4211876988410949,0.7803199887275696,0.8548135757446289,50000 -16952.20995235443,35.563923597335815,199169.5191576481,447823,0,199169.5191576481,0.6648000478744507,1.4506032466888428,10000,216180.07539200783,0.887011706829071,0.4194007813930511,0.7803199887275696,0.8548135757446289,50000 -16990.376428365707,35.668083906173706,199589.37185049057,448761,0,199589.37185049057,0.6648000478744507,1.4506032466888428,10000,216638.24611639977,0.8859765529632568,0.421368658542633,0.7803199887275696,0.8548135757446289,50000 -17035.225475549698,35.76078271865845,200009.4314689636,449701,0,200009.4314689636,0.6648000478744507,1.4506032466888428,10000,217103.29435062408,0.8870702981948853,0.4166412651538849,0.7803199887275696,0.8548135757446289,50000 -17071.777951478958,35.85300397872925,200429.42027401924,450645,0,200429.42027401924,0.6648000478744507,1.4506032466888428,10000,217559.97447752955,0.8881444931030273,0.4162316024303436,0.7803199887275696,0.8548135757446289,50000 -17122.015122890472,35.96689581871033,200849.3727684021,451587,0,200849.3727684021,0.6648000478744507,1.4506032466888428,10000,218030.3255007267,0.8881054520606995,0.421012133359909,0.7803199887275696,0.8548135757446289,50000 -17157.460511684418,36.06063795089722,201269.47407794,452532,0,201269.47407794,0.6648000478744507,1.4506032466888428,10000,218486.012780428,0.8839452862739563,0.4290750026702881,0.7803199887275696,0.8548135757446289,50000 -17196.084241628647,36.18183422088623,201689.4728207588,453473,0,201689.4728207588,0.6648000478744507,1.4506032466888428,10000,218944.8034465313,0.8886913657188416,0.4158783555030823,0.7803199887275696,0.8548135757446289,50000 -17241.766832351685,36.29235625267029,202109.38947105408,454414,0,202109.38947105408,0.6648000478744507,1.4506032466888428,10000,219410.560652256,0.8854491710662842,0.4210602045059204,0.7803199887275696,0.8548135757446289,50000 -17278.042666196823,36.3836932182312,202529.48514270785,455360,0,202529.48514270785,0.6648000478744507,1.4506032466888428,10000,219867.0714662075,0.8898632526397705,0.4121298491954803,0.7803199887275696,0.8548135757446289,50000 -17317.149385929108,36.49660277366638,202949.4070637226,456304,0,202949.4070637226,0.6648000478744507,1.4506032466888428,10000,220326.2607011795,0.8844726085662842,0.427636057138443,0.7803199887275696,0.8548135757446289,50000 -17355.284680366516,36.609949350357056,203369.30669617653,457242,0,203369.30669617653,0.6648000478744507,1.4506032466888428,10000,220784.4556255341,0.8884375095367432,0.4146224856376648,0.7803199887275696,0.8548135757446289,50000 -17397.613396406174,36.72454309463501,203789.2432193756,458183,0,203789.2432193756,0.6648000478744507,1.4506032466888428,10000,221246.88278746605,0.88685542345047,0.4204631745815277,0.7803199887275696,0.8548135757446289,50000 -17436.533225536346,36.84003710746765,204209.13581633568,459129,0,204209.13581633568,0.6648000478744507,1.4506032466888428,10000,221705.85871863365,0.8864452838897705,0.4202134609222412,0.7803199887275696,0.8548135757446289,50000 -17481.431627988815,36.955846309661865,204629.39767479897,460072,0,204629.39767479897,0.6648000478744507,1.4506032466888428,10000,222171.1817200184,0.8882030844688416,0.416919469833374,0.7803199887275696,0.8548135757446289,50000 -17524.9306910038,37.0502290725708,205049.5205335617,461015,0,205049.5205335617,0.6648000478744507,1.4506032466888428,10000,222634.9448349476,0.8881054520606995,0.4188825786113739,0.7803199887275696,0.8548135757446289,50000 -17562.089664697647,37.14282035827637,205469.76189637184,461958,0,205469.76189637184,0.6648000478744507,1.4506032466888428,10000,223092.48494267464,0.8861327767372131,0.4285134375095367,0.7803199887275696,0.8548135757446289,50000 -17605.4763879776,37.25931525230408,205890.15830087665,462901,0,205890.15830087665,0.6648000478744507,1.4506032466888428,10000,223556.4319159985,0.8883007764816284,0.4149970114231109,0.7803199887275696,0.8548135757446289,50000 -17643.671263217926,37.37636876106262,206310.3851909637,463843,0,206310.3851909637,0.6648000478744507,1.4506032466888428,10000,224015.01740431783,0.887011706829071,0.4187023341655731,0.7803199887275696,0.8548135757446289,50000 -17679.339367866516,37.474828243255615,206730.5388054848,464785,0,206730.5388054848,0.6648000478744507,1.4506032466888428,10000,224470.98530101776,0.8866406083106995,0.4217980206012726,0.7803199887275696,0.8548135757446289,50000 -17718.601448774338,37.595885038375854,207150.3881704808,465727,0,207150.3881704808,0.6648000478744507,1.4506032466888428,10000,224930.26573109627,0.8874609470367432,0.4204729497432709,0.7803199887275696,0.8548135757446289,50000 -17757.488273620605,37.715657234191895,207570.5442082882,466668,0,207570.5442082882,0.6648000478744507,1.4506032466888428,10000,225389.47584223747,0.8872265219688416,0.4182624816894531,0.7803199887275696,0.8548135757446289,50000 -17800.95606303215,37.89576005935669,207990.6780860424,467613,0,207990.6780860424,0.6648000478744507,1.4506032466888428,10000,225853.3052201271,0.88978511095047,0.4156844913959503,0.7803199887275696,0.8548135757446289,50000 -17841.113161325455,37.99327063560486,208410.72726297376,468557,0,208410.72726297376,0.6648000478744507,1.4506032466888428,10000,226313.6565389633,0.8885741829872131,0.4178123772144317,0.7803199887275696,0.8548135757446289,50000 -17888.128987550735,38.08624720573425,208830.80399012569,469501,0,208830.80399012569,0.6648000478744507,1.4506032466888428,10000,226780.88972759247,0.8895312547683716,0.4101268053054809,0.7803199887275696,0.8548135757446289,50000 -17926.32308101654,38.17964243888855,209250.91027021408,470446,0,209250.91027021408,0.6648000478744507,1.4506032466888428,10000,227239.3307085037,0.8869921565055847,0.4194200932979584,0.7803199887275696,0.8548135757446289,50000 -17967.24115753174,38.296884059906006,209671.11908245087,471383,0,209671.11908245087,0.6648000478744507,1.4506032466888428,10000,227700.62208509445,0.8870898485183716,0.4167537093162536,0.7803199887275696,0.8548135757446289,50000 -18007.481026172638,38.41466212272644,210091.2394402027,472323,0,210091.2394402027,0.6648000478744507,1.4506032466888428,10000,228161.14769411087,0.8866991996765137,0.418349415063858,0.7803199887275696,0.8548135757446289,50000 -18044.091156482697,38.51059150695801,210511.18789815903,473266,0,210511.18789815903,0.6648000478744507,1.4506032466888428,10000,228617.8493804932,0.8862109184265137,0.419338971376419,0.7803199887275696,0.8548135757446289,50000 -18081.904654741287,38.62590575218201,210931.07780385017,474207,0,210931.07780385017,0.6648000478744507,1.4506032466888428,10000,229075.71579408649,0.8871679306030273,0.4201378226280212,0.7803199887275696,0.8548135757446289,50000 -18122.236476182938,38.743818521499634,211351.1733222008,475148,0,211351.1733222008,0.6648000478744507,1.4506032466888428,10000,229536.30795049667,0.8871874809265137,0.4222822189331054,0.7803199887275696,0.8548135757446289,50000 -18156.287446975708,38.8391637802124,211771.2523927689,476091,0,211771.2523927689,0.6648000478744507,1.4506032466888428,10000,229990.5814120769,0.887499988079071,0.4217891097068786,0.7803199887275696,0.8548135757446289,50000 -18195.48608493805,38.95311379432678,212191.43901252747,477034,0,212191.43901252747,0.6648000478744507,1.4506032466888428,10000,230450.12804293635,0.8870702981948853,0.420778214931488,0.7803199887275696,0.8548135757446289,50000 -18228.449116706848,39.07531595230103,212611.59212899208,477976,0,212611.59212899208,0.6648000478744507,1.4506032466888428,10000,230903.41416811943,0.8859765529632568,0.4200300574302673,0.7803199887275696,0.8548135757446289,50000 -18265.33827686309,39.19279146194458,213031.64986252785,478919,0,213031.64986252785,0.6648000478744507,1.4506032466888428,10000,231360.5252292156,0.8873242139816284,0.4137724637985229,0.7803199887275696,0.8548135757446289,50000 -18308.297819375992,39.30975008010864,213451.8512706757,479863,0,213451.8512706757,0.6648000478744507,1.4506032466888428,10000,231823.8510775566,0.8858593702316284,0.4239358007907867,0.7803199887275696,0.8548135757446289,50000 -18350.265253067017,39.41498994827271,213871.92631745336,480810,0,213871.92631745336,0.6648000478744507,1.4506032466888428,10000,232286.0463590622,0.8883398175239563,0.4194519817829132,0.7803199887275696,0.8548135757446289,50000 -18390.25630378723,39.5118944644928,214291.8306343556,481752,0,214291.8306343556,0.6648000478744507,1.4506032466888428,10000,232746.0855693817,0.8857226371765137,0.4243901968002319,0.7803199887275696,0.8548135757446289,50000 -18421.80694413185,39.63293313980103,214711.83514642715,482692,0,214711.83514642715,0.6648000478744507,1.4506032466888428,10000,233197.8092508316,0.8866796493530273,0.420348048210144,0.7803199887275696,0.8548135757446289,50000 -18468.27669906616,39.75440168380737,215131.989664793,483636,0,215131.989664793,0.6648000478744507,1.4506032466888428,10000,233664.60327482224,0.8885937333106995,0.4157139956951141,0.7803199887275696,0.8548135757446289,50000 -18508.301794052124,39.85543179512024,215551.972063303,484577,0,215551.972063303,0.6648000478744507,1.4506032466888428,10000,234124.7588033676,0.8880664110183716,0.4205309450626373,0.7803199887275696,0.8548135757446289,50000 -18541.02921271324,39.95377564430237,215972.1343464852,485518,0,215972.1343464852,0.6648000478744507,1.4506032466888428,10000,234577.79371976847,0.8874609470367432,0.4170184135437011,0.7803199887275696,0.8548135757446289,50000 -18584.534476995468,40.07387638092041,216392.3759646416,486459,0,216392.3759646416,0.6648000478744507,1.4506032466888428,10000,235041.70927786827,0.8861523270606995,0.4273178875446319,0.7803199887275696,0.8548135757446289,50000 -18623.737052679066,40.17147946357727,216812.4767594337,487404,0,216812.4767594337,0.6648000478744507,1.4506032466888428,10000,235501.1575639248,0.8877733945846558,0.4160880446434021,0.7803199887275696,0.8548135757446289,50000 -18668.17640280724,40.29322910308838,217232.3497395516,488346,0,217232.3497395516,0.6648000478744507,1.4506032466888428,10000,235965.63935542107,0.8884570002555847,0.4186307489871979,0.7803199887275696,0.8548135757446289,50000 -18705.208114624023,40.391053676605225,217652.6532354355,489290,0,217652.6532354355,0.6648000478744507,1.4506032466888428,10000,236423.1210463047,0.8884375095367432,0.4156650900840759,0.7803199887275696,0.8548135757446289,50000 -18751.618181467056,40.52058506011963,218072.6185135841,490233,0,218072.6185135841,0.6648000478744507,1.4506032466888428,10000,236889.67304754257,0.8855859041213989,0.418851226568222,0.7803199887275696,0.8548135757446289,50000 -18793.20076608658,40.64298367500305,218493.1903386116,491177,0,218493.1903386116,0.6648000478744507,1.4506032466888428,10000,237351.99741578105,0.8892577886581421,0.4170732498168945,0.7803199887275696,0.8548135757446289,50000 -18829.3963367939,40.74140095710754,218913.1488711834,492121,0,218913.1488711834,0.6648000478744507,1.4506032466888428,10000,237808.29712486267,0.8887695074081421,0.4192838668823242,0.7803199887275696,0.8548135757446289,50000 -18868.54022550583,40.892817974090576,219333.3186454773,493062,0,219333.3186454773,0.6648000478744507,1.4506032466888428,10000,238267.8093366623,0.8889257907867432,0.4122632443904876,0.7803199887275696,0.8548135757446289,50000 -18906.0834171772,41.02065348625183,219753.3205792904,494003,0,219753.3205792904,0.6648000478744507,1.4506032466888428,10000,238725.52990865707,0.8864257335662842,0.4204651415348053,0.7803199887275696,0.8548135757446289,50000 -18944.862055540085,41.14168167114258,220173.2574682236,494941,0,220173.2574682236,0.6648000478744507,1.4506032466888428,10000,239184.4129190445,0.8876562118530273,0.4156997203826904,0.7803199887275696,0.8548135757446289,50000 -18984.65509462357,41.261394739151,220593.34018707275,495883,0,220593.34018707275,0.6648000478744507,1.4506032466888428,10000,239644.45602679253,0.8883593678474426,0.414913535118103,0.7803199887275696,0.8548135757446289,50000 -19022.749748706818,41.38596391677856,221013.46103596687,496827,0,221013.46103596687,0.6648000478744507,1.4506032466888428,10000,240102.84364795685,0.8857421875,0.4213462769985199,0.7803199887275696,0.8548135757446289,50000 -19065.9390039444,41.51298332214356,221433.70231842995,497773,0,221433.70231842995,0.6648000478744507,1.4506032466888428,10000,240566.4491128921,0.8886523246765137,0.4136727452278137,0.7803199887275696,0.8548135757446289,50000 -19109.039420604706,41.63520693778992,221853.9527220726,498712,0,221853.9527220726,0.6648000478744507,1.4506032466888428,10000,241029.9697060585,0.8863085508346558,0.4256725311279297,0.7803199887275696,0.8548135757446289,50000 -19149.91508102417,41.73280024528504,222274.0769441128,499657,0,222274.0769441128,0.6648000478744507,1.4506032466888428,10000,241491.1147809029,0.8865820169448853,0.4248594641685486,0.7803199887275696,0.8548135757446289,50000 -19185.05141544342,41.83038401603699,222694.147285223,500603,0,222694.147285223,0.6648000478744507,1.4506032466888428,10000,241946.46710276604,0.8859961032867432,0.4199767112731933,0.7803199887275696,0.8548135757446289,50000 -19224.61433959008,41.95796823501587,223114.2418487072,501543,0,223114.2418487072,0.6648000478744507,1.4506032466888428,10000,242406.2993843556,0.8876171708106995,0.4205630719661712,0.7803199887275696,0.8548135757446289,50000 -19265.47282481193,42.08222556114197,223534.12419629097,502483,0,223534.12419629097,0.6648000478744507,1.4506032466888428,10000,242867.21188783649,0.8876757621765137,0.4135380089282989,0.7803199887275696,0.8548135757446289,50000 -19310.969926834103,42.20741128921509,223954.3198173046,503381,0,223954.3198173046,0.6648000478744507,1.4506032466888428,10000,243333.0746455193,0.8858984112739563,0.4253036081790924,0.7803199887275696,0.8548135757446289,50000 -19356.176063776016,42.30738377571106,224374.5869781971,504325,0,224374.5869781971,0.6648000478744507,1.4506032466888428,10000,243798.6955487728,0.8860546946525574,0.4211524724960327,0.7803199887275696,0.8548135757446289,50000 -19394.80228829384,42.43042516708374,224794.48518443108,505264,0,224794.48518443108,0.6648000478744507,1.4506032466888428,10000,244257.38936328888,0.8863866925239563,0.4218363761901855,0.7803199887275696,0.8548135757446289,50000 -19434.938640117645,42.55869674682617,225214.68184113505,506199,0,225214.68184113505,0.6648000478744507,1.4506032466888428,10000,244717.8974416256,0.8872656226158142,0.4188918173313141,0.7803199887275696,0.8548135757446289,50000 -19471.48560786248,42.659632205963135,225634.7327599525,507138,0,225634.7327599525,0.6648000478744507,1.4506032466888428,10000,245174.6440844536,0.8887499570846558,0.4152712821960449,0.7803199887275696,0.8548135757446289,50000 -19509.77243232727,42.78360247612,226054.76191449165,508079,0,226054.76191449165,0.6648000478744507,1.4506032466888428,10000,245633.1302809716,0.8889843821525574,0.4183682203292846,0.7803199887275696,0.8548135757446289,50000 -19558.337856054302,42.91069960594177,226475.2010500431,509023,0,226475.2010500431,0.6648000478744507,1.4506032466888428,10000,246102.30951094627,0.88525390625,0.4223649799823761,0.7803199887275696,0.8548135757446289,50000 -19593.461698532104,43.01064705848694,226895.11728739736,509966,0,226895.11728739736,0.6648000478744507,1.4506032466888428,10000,246557.4968225956,0.8899023532867432,0.4177020490169525,0.7803199887275696,0.8548135757446289,50000 -19630.303337335587,43.14096117019653,227315.35199666023,510908,0,227315.35199666023,0.6648000478744507,1.4506032466888428,10000,247014.7506687641,0.8851757645606995,0.4234828948974609,0.7803199887275696,0.8548135757446289,50000 -19674.89552092552,43.26824116706848,227735.47498559952,511850,0,227735.47498559952,0.6648000478744507,1.4506032466888428,10000,247479.6405696869,0.8869335651397705,0.4181903004646301,0.7803199887275696,0.8548135757446289,50000 -19717.473808527,43.37141489982605,228155.5176768303,512799,0,228155.5176768303,0.6648000478744507,1.4506032466888428,10000,247942.41296219823,0.8880273103713989,0.4198199510574341,0.7803199887275696,0.8548135757446289,50000 -19756.49240231514,43.47539925575256,228575.4135878086,513744,0,228575.4135878086,0.6648000478744507,1.4506032466888428,10000,248401.4799234867,0.8875195384025574,0.4158684015274048,0.7803199887275696,0.8548135757446289,50000 -19800.474741220474,43.62744307518005,228995.6517190933,514686,0,228995.6517190933,0.6648000478744507,1.4506032466888428,10000,248865.8999125957,0.8874218463897705,0.4217660427093506,0.7803199887275696,0.8548135757446289,50000 -19834.475734710693,43.72905158996582,229415.50728487968,515625,0,229415.50728487968,0.6648000478744507,1.4506032466888428,10000,249319.9051039219,0.8893554210662842,0.4156787693500519,0.7803199887275696,0.8548135757446289,50000 -19872.9740600586,43.85731816291809,229835.58886671063,516567,0,229835.58886671063,0.6648000478744507,1.4506032466888428,10000,249778.6609852314,0.8910155892372131,0.4081867933273315,0.7803199887275696,0.8548135757446289,50000 -19916.36367583275,43.98423504829407,230255.75998997688,517509,0,230255.75998997688,0.6648000478744507,1.4506032466888428,10000,250242.39598941803,0.8847851157188416,0.4224032461643219,0.7803199887275696,0.8548135757446289,50000 -19951.98074555397,44.08957409858704,230676.0453734398,518454,0,230676.0453734398,0.6648000478744507,1.4506032466888428,10000,250698.45132374763,0.8889062404632568,0.4124764800071716,0.7803199887275696,0.8548135757446289,50000 -19993.95086407661,44.21716904640198,231096.0907754898,519396,0,231096.0907754898,0.6648000478744507,1.4506032466888428,10000,251160.64197206497,0.8869726657867432,0.4183215200901031,0.7803199887275696,0.8548135757446289,50000 -20035.627217292786,44.343384742736816,231516.16703391075,520338,0,231516.16703391075,0.6648000478744507,1.4506032466888428,10000,251622.5686440468,0.8862890601158142,0.4208606481552124,0.7803199887275696,0.8548135757446289,50000 -20078.789390325543,44.4472975730896,231936.3868584633,521278,0,231936.3868584633,0.6648000478744507,1.4506032466888428,10000,252086.10162854195,0.8882226347923279,0.4174492061138153,0.7803199887275696,0.8548135757446289,50000 -20119.089794397354,44.54976439476013,232356.45809936523,522219,0,232356.45809936523,0.6648000478744507,1.4506032466888428,10000,252546.62259054184,0.8867968320846558,0.42232993245124817,0.7803199887275696,0.8548135757446289,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 3ec128f25..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5783 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.34623566,6.907756,,,,,,,,,,,,,, -1,,,0.0009374999790452,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.4552264213562,82.20404887199402,42.4552264213562,39.74869656562805,0.0,0.0 -100,0.52649134,6.8752217,,,,,,,,,,,,,, -200,0.6455489,6.762695,,,,,,,,,,,,,, -300,0.9554784,6.664156,,,,,,,,,,,,,, -400,1.2258662,6.5584536,,,,,,,,,,,,,, -500,1.0467201,6.606578,,,,,,,,,,,,,, -600,1.3212372,6.3255124,,,,,,,,,,,,,, -700,1.0250409,6.6177244,,,,,,,,,,,,,, -800,1.4597263,6.1847243,,,,,,,,,,,,,, -900,1.3834128,6.2025757,,,,,,,,,,,,,, -916,,,0.033281248062849,5.910701751708984,0.0322199985384941,5.930098056793213,50000.0,0.0265000015497207,6.053519248962402,10000.0,462.6057026386261,523.893707036972,462.6057026386261,61.21135711669922,0.0267481803894042,0.0 -1000,1.2575092,6.2155066,,,,,,,,,,,,,, -1100,1.7716651,6.315992,,,,,,,,,,,,,, -1200,1.2396065,6.387168,,,,,,,,,,,,,, -1300,1.1396884,6.582673,,,,,,,,,,,,,, -1400,2.2026162,6.0538697,,,,,,,,,,,,,, -1500,1.2853497,5.866876,,,,,,,,,,,,,, -1600,1.1889507,6.466545,,,,,,,,,,,,,, -1700,1.2360677,6.5629263,,,,,,,,,,,,,, -1800,0.9384642,5.618886,,,,,,,,,,,,,, -1880,,,0.0802929699420929,5.186488628387451,0.0763599947094917,5.225624084472656,50000.0,0.0578000023961067,5.46556282043457,10000.0,882.6500160694122,965.783664226532,882.6500160694122,82.97681331634521,0.0551979541778564,0.0 -1900,1.1694982,5.656164,,,,,,,,,,,,,, -2000,1.4329913,5.602625,,,,,,,,,,,,,, -2100,0.9682188,5.480915,,,,,,,,,,,,,, -2200,0.92416865,5.6872034,,,,,,,,,,,,,, -2300,1.2457442,5.5147076,,,,,,,,,,,,,, -2400,1.0248544,5.362468,,,,,,,,,,,,,, -2500,1.0472493,5.6927185,,,,,,,,,,,,,, -2600,1.1046997,5.264293,,,,,,,,,,,,,, -2700,1.1807617,5.2227693,,,,,,,,,,,,,, -2800,1.132992,5.45251,,,,,,,,,,,,,, -2845,,,0.1541601568460464,4.475606441497803,0.1417199969291687,4.5772247314453125,50000.0,0.1098000034689903,4.911513328552246,10000.0,1302.63339304924,1407.4959456920624,1302.63339304924,104.62666702270508,0.0829133987426757,0.0 -2900,1.0142568,5.3532143,,,,,,,,,,,,,, -3000,0.8810042,6.449263,,,,,,,,,,,,,, -3100,1.3235817,5.6135354,,,,,,,,,,,,,, -3200,1.0097533,5.069244,,,,,,,,,,,,,, -3300,1.0212032,5.456384,,,,,,,,,,,,,, -3400,1.1802398,4.9723635,,,,,,,,,,,,,, -3500,0.7265904,6.28454,,,,,,,,,,,,,, -3600,1.2875198,4.973568,,,,,,,,,,,,,, -3700,1.0270469,5.3896885,,,,,,,,,,,,,, -3800,1.0257565,4.8252897,,,,,,,,,,,,,, -3809,,,0.2120117098093032,3.999783754348755,0.1994999945163726,4.077298641204834,50000.0,0.1480000019073486,4.524412155151367,10000.0,1722.6479833126068,1849.3324444293976,1722.6479833126068,126.36748099327087,0.1124176979064941,0.0 -3900,0.8990842,4.6825237,,,,,,,,,,,,,, -4000,1.2313763,5.114217,,,,,,,,,,,,,, -4100,1.2196702,4.798887,,,,,,,,,,,,,, -4200,0.64662653,6.0468993,,,,,,,,,,,,,, -4300,0.84644186,5.2727156,,,,,,,,,,,,,, -4400,0.8930079,5.155258,,,,,,,,,,,,,, -4500,0.71823806,5.613675,,,,,,,,,,,,,, -4600,0.90707,6.084774,,,,,,,,,,,,,, -4700,0.7185641,6.1142864,,,,,,,,,,,,,, -4771,,,0.2714648246765136,3.5543103218078613,0.2515200078487396,3.682871580123901,50000.0,0.1887000054121017,4.161242961883545,10000.0,2142.643502473831,2291.534899711609,2142.643502473831,148.49134397506714,0.1440119743347168,0.0 -4800,1.0338947,4.58179,,,,,,,,,,,,,, -4900,0.8282867,4.6580076,,,,,,,,,,,,,, -5000,0.887419,4.3894954,,,,,,,,,,,,,, -5100,0.76173633,5.780642,,,,,,,,,,,,,, -5200,1.2754091,4.3424654,,,,,,,,,,,,,, -5300,0.75185955,4.287382,,,,,,,,,,,,,, -5400,1.0805875,4.1979213,,,,,,,,,,,,,, -5500,0.8496419,6.0587974,,,,,,,,,,,,,, -5600,0.9513956,4.181175,,,,,,,,,,,,,, -5700,0.85032487,4.1040215,,,,,,,,,,,,,, -5732,,,0.3236913979053497,3.2366814613342285,0.2985199987888336,3.391303062438965,50000.0,0.2252000123262405,3.929141998291016,10000.0,2562.6839604377747,2733.3190398216248,2562.6839604377747,170.15383672714233,0.1743319034576416,0.0 -5800,0.8184857,4.0789633,,,,,,,,,,,,,, -5900,0.78293556,4.3176575,,,,,,,,,,,,,, -6000,1.1006501,4.2455773,,,,,,,,,,,,,, -6100,0.68502426,5.425948,,,,,,,,,,,,,, -6200,0.9976946,4.054099,,,,,,,,,,,,,, -6300,0.69181263,5.999733,,,,,,,,,,,,,, -6400,1.0557741,3.9356482,,,,,,,,,,,,,, -6500,0.7275208,5.2727327,,,,,,,,,,,,,, -6600,0.8946457,4.2104974,,,,,,,,,,,,,, -6691,,,0.3828320205211639,2.853161573410034,0.3342599868774414,3.135330438613892,50000.0,0.2591000199317932,3.685933828353882,10000.0,2982.868365049362,3175.5051844120026,2982.868365049362,192.07357716560364,0.2041935920715332,0.0 -6700,0.9061523,3.8436608,,,,,,,,,,,,,, -6800,0.96557254,3.8462949,,,,,,,,,,,,,, -6900,0.8589657,5.759194,,,,,,,,,,,,,, -7000,1.0222623,3.8404164,,,,,,,,,,,,,, -7100,0.855709,5.822472,,,,,,,,,,,,,, -7200,0.7125733,5.575396,,,,,,,,,,,,,, -7300,0.9012677,3.8094442,,,,,,,,,,,,,, -7400,0.90289634,3.8452878,,,,,,,,,,,,,, -7500,1.0908474,3.7969627,,,,,,,,,,,,,, -7600,0.70966357,4.5165086,,,,,,,,,,,,,, -7649,,,0.3971289098262787,2.773566722869873,0.3694399893283844,2.9177207946777344,50000.0,0.2818000018596649,3.512402057647705,10000.0,3403.056977033615,3621.3701384067535,3403.056977033615,217.67090773582456,0.2334585189819336,0.0 -7700,0.91483486,3.7532995,,,,,,,,,,,,,, -7800,0.9582976,3.6600845,,,,,,,,,,,,,, -7900,0.751152,5.2503624,,,,,,,,,,,,,, -8000,0.89682436,3.7495856,,,,,,,,,,,,,, -8100,1.0207702,3.6787791,,,,,,,,,,,,,, -8200,0.86604685,4.511086,,,,,,,,,,,,,, -8300,1.0012045,3.6051366,,,,,,,,,,,,,, -8400,0.82432383,3.718683,,,,,,,,,,,,,, -8500,0.7590428,5.860086,,,,,,,,,,,,,, -8600,0.88406223,4.791514,,,,,,,,,,,,,, -8609,,,0.4228320121765136,2.649387121200561,0.3889999985694885,2.813301086425781,50000.0,0.300100028514862,3.4087908267974854,10000.0,3823.41290307045,4074.122909069061,3823.41290307045,249.9872395992279,0.2634403705596924,0.0 -8700,0.83436275,4.400118,,,,,,,,,,,,,, -8800,1.1320853,3.5076993,,,,,,,,,,,,,, -8900,0.8966866,3.8161244,,,,,,,,,,,,,, -9000,0.96780795,4.1887317,,,,,,,,,,,,,, -9100,0.9801889,3.4069881,,,,,,,,,,,,,, -9200,0.81984,5.9209566,,,,,,,,,,,,,, -9300,0.8799039,4.0573483,,,,,,,,,,,,,, -9400,0.846146,5.9266095,,,,,,,,,,,,,, -9500,0.9828367,3.3622415,,,,,,,,,,,,,, -9570,,,0.4484570324420929,2.475047826766968,0.4088599979877472,2.6899614334106445,50000.0,0.3128000199794769,3.3069517612457275,10000.0,4243.586100816727,4520.165455341339,4243.586100816727,275.773681640625,0.296008825302124,0.0 -9600,0.85397273,4.790209,,,,,,,,,,,,,, -9700,0.69653755,5.461704,,,,,,,,,,,,,, -9800,0.87689245,3.812744,,,,,,,,,,,,,, -9900,0.929278,5.3752174,,,,,,,,,,,,,, -10000,0.85364336,4.6580014,,,,,,,,,,,,,, -10100,0.83578897,5.5830255,,,,,,,,,,,,,, -10200,1.0251002,3.6030552,,,,,,,,,,,,,, -10300,0.93798953,3.8187735,,,,,,,,,,,,,, -10400,0.6768274,5.2514734,,,,,,,,,,,,,, -10500,0.95580673,3.4823587,,,,,,,,,,,,,, -10527,,,0.4597460925579071,2.403148651123047,0.4257799983024597,2.5876495838165283,50000.0,0.323600023984909,3.2185275554656982,10000.0,4663.614185810089,4965.85725569725,4663.614185810089,301.3532176017761,0.3298141956329345,0.0 -10600,1.067367,3.383986,,,,,,,,,,,,,, -10700,0.9524732,4.019099,,,,,,,,,,,,,, -10800,1.0135487,3.3227546,,,,,,,,,,,,,, -10900,1.0289837,3.2549539,,,,,,,,,,,,,, -11000,0.74915296,4.4924846,,,,,,,,,,,,,, -11100,0.96737885,3.386742,,,,,,,,,,,,,, -11200,0.8521608,4.0597806,,,,,,,,,,,,,, -11300,0.8900598,3.2623973,,,,,,,,,,,,,, -11400,1.1388012,3.329135,,,,,,,,,,,,,, -11485,,,0.4774218499660492,2.352294921875,0.4411799907684326,2.527583599090576,50000.0,0.3461000025272369,3.1271629333496094,10000.0,5083.6017208099365,5413.76374578476,5083.6017208099365,329.1883616447449,0.3632173538208008,0.0 -11500,0.9928252,3.3143148,,,,,,,,,,,,,, -11600,1.1077954,3.427373,,,,,,,,,,,,,, -11700,0.8266717,4.217048,,,,,,,,,,,,,, -11800,0.9846404,4.0721016,,,,,,,,,,,,,, -11900,1.2468655,3.4263475,,,,,,,,,,,,,, -12000,1.0527915,3.2734969,,,,,,,,,,,,,, -12100,0.95702416,3.2861378,,,,,,,,,,,,,, -12200,1.0500425,3.2791812,,,,,,,,,,,,,, -12300,1.07598,3.4754672,,,,,,,,,,,,,, -12400,1.0469462,3.1798368,,,,,,,,,,,,,, -12442,,,0.5004491806030273,2.173110246658325,0.4589599967002868,2.39906644821167,50000.0,0.3476000130176544,3.0680062770843506,10000.0,5503.675626039505,5861.807518482208,5503.675626039505,357.07783699035645,0.3932712078094482,0.0 -12500,0.80676657,3.8408556,,,,,,,,,,,,,, -12600,1.1828213,3.3934882,,,,,,,,,,,,,, -12700,1.1160352,3.2581525,,,,,,,,,,,,,, -12800,1.1236506,3.2053561,,,,,,,,,,,,,, -12900,0.88092303,3.5676506,,,,,,,,,,,,,, -13000,1.1519202,3.1648445,,,,,,,,,,,,,, -13100,1.3215302,3.2210479,,,,,,,,,,,,,, -13200,0.9869574,3.6578736,,,,,,,,,,,,,, -13300,0.9972899,3.5977545,,,,,,,,,,,,,, -13395,,,0.5154492259025574,2.150564432144165,0.4667999744415283,2.3906404972076416,50000.0,0.3655000030994415,3.0263166427612305,10000.0,5923.727476358414,6310.735496520996,5923.727476358414,385.8735525608063,0.4236998558044433,0.0 -13400,1.0526948,3.0301828,,,,,,,,,,,,,, -13500,1.0339258,3.2207205,,,,,,,,,,,,,, -13600,1.181764,4.670037,,,,,,,,,,,,,, -13700,0.95560217,4.077657,,,,,,,,,,,,,, -13800,0.7313352,5.6174355,,,,,,,,,,,,,, -13900,0.964398,3.1104362,,,,,,,,,,,,,, -14000,1.0061243,3.1285381,,,,,,,,,,,,,, -14100,0.8816753,3.81509,,,,,,,,,,,,,, -14200,1.0774865,3.1728697,,,,,,,,,,,,,, -14300,0.9610594,4.751065,,,,,,,,,,,,,, -14346,,,0.519238293170929,2.1266376972198486,0.480239987373352,2.312174320220948,50000.0,0.3747000098228454,2.943275213241577,10000.0,6343.8371758461,6761.744855165482,6343.8371758461,416.6871347427368,0.4599125385284424,0.0 -14400,0.98136526,3.1144056,,,,,,,,,,,,,, -14500,1.1019996,3.1711159,,,,,,,,,,,,,, -14600,1.0690984,3.1579309,,,,,,,,,,,,,, -14700,0.99312186,3.129111,,,,,,,,,,,,,, -14800,1.0031726,3.7286482,,,,,,,,,,,,,, -14900,0.834773,4.4543295,,,,,,,,,,,,,, -15000,1.1527951,3.0031052,,,,,,,,,,,,,, -15100,0.9822812,3.7972865,,,,,,,,,,,,,, -15200,1.1619378,2.9810834,,,,,,,,,,,,,, -15295,,,0.5322265625,2.005976915359497,0.4880199730396271,2.22807240486145,50000.0,0.3838000297546386,2.8951008319854736,10000.0,6764.088172674179,7212.294851779938,6764.088172674179,446.9011867046356,0.4952383041381836,0.0 -15300,1.0429521,3.7270575,,,,,,,,,,,,,, -15400,1.1588862,3.5533154,,,,,,,,,,,,,, -15500,1.4425645,3.002294,,,,,,,,,,,,,, -15600,1.0071324,4.5895934,,,,,,,,,,,,,, -15700,1.1896076,3.139247,,,,,,,,,,,,,, -15800,1.0984057,3.0179532,,,,,,,,,,,,,, -15900,0.81318134,4.2200375,,,,,,,,,,,,,, -16000,0.8331852,5.250404,,,,,,,,,,,,,, -16100,0.86696804,5.579104,,,,,,,,,,,,,, -16200,1.0089873,5.633645,,,,,,,,,,,,,, -16244,,,0.5482421517372131,1.9572895765304563,0.501259982585907,2.1907479763031006,50000.0,0.3926000297069549,2.8501899242401123,10000.0,7184.348518371582,7662.375813245773,7184.348518371582,476.6402099132538,0.5272412300109863,0.0 -16300,1.1941214,2.9658084,,,,,,,,,,,,,, -16400,0.7785389,5.463795,,,,,,,,,,,,,, -16500,0.9182188,4.2922053,,,,,,,,,,,,,, -16600,1.2322896,3.2313786,,,,,,,,,,,,,, -16700,1.0111521,4.60879,,,,,,,,,,,,,, -16800,0.9897801,3.9400878,,,,,,,,,,,,,, -16900,1.0534728,3.2909975,,,,,,,,,,,,,, -17000,0.8201897,4.7808714,,,,,,,,,,,,,, -17100,0.984798,3.1929276,,,,,,,,,,,,,, -17194,,,0.5625194907188416,1.8765634298324585,0.5061399936676025,2.151052713394165,50000.0,0.38960000872612,2.829267263412476,10000.0,7604.298126220703,8113.509242534637,7604.298126220703,507.73856592178345,0.5607478618621826,0.0 -17200,0.9448476,3.445211,,,,,,,,,,,,,, -17300,0.89525723,5.576943,,,,,,,,,,,,,, -17400,1.2081376,3.5261333,,,,,,,,,,,,,, -17500,1.0425302,2.920371,,,,,,,,,,,,,, -17600,1.1109194,3.011371,,,,,,,,,,,,,, -17700,1.2220799,2.9110625,,,,,,,,,,,,,, -17800,1.1700886,2.959792,,,,,,,,,,,,,, -17900,1.6422386,3.06146,,,,,,,,,,,,,, -18000,0.95613205,3.9766724,,,,,,,,,,,,,, -18100,1.1394893,2.8742044,,,,,,,,,,,,,, -18141,,,0.5486718416213989,1.961691856384277,0.5062800049781799,2.1555776596069336,50000.0,0.3961000144481659,2.798175811767578,10000.0,8024.374304771423,8565.522846698761,8024.374304771423,539.5802881717682,0.6070611476898193,0.0 -18200,1.0991712,2.8963916,,,,,,,,,,,,,, -18300,1.0893056,3.5102158,,,,,,,,,,,,,, -18400,1.0407798,3.1780834,,,,,,,,,,,,,, -18500,1.2468452,3.1588042,,,,,,,,,,,,,, -18600,1.3420256,3.461381,,,,,,,,,,,,,, -18700,0.8917214,3.7511263,,,,,,,,,,,,,, -18800,1.011483,3.0951293,,,,,,,,,,,,,, -18900,0.9572939,5.4161844,,,,,,,,,,,,,, -19000,0.8973353,4.335707,,,,,,,,,,,,,, -19089,,,0.5634179711341858,1.8526079654693604,0.5212599635124207,2.0621883869171143,50000.0,0.4065000116825104,2.7337896823883057,10000.0,8444.733457803726,9017.617252111437,8444.733457803726,571.2340226173401,0.6397192478179932,0.0 -19100,0.9011408,4.015272,,,,,,,,,,,,,, -19200,0.90471,5.095805,,,,,,,,,,,,,, -19300,0.9697214,4.072724,,,,,,,,,,,,,, -19400,0.9772347,4.6376014,,,,,,,,,,,,,, -19500,1.0718061,3.35622,,,,,,,,,,,,,, -19600,0.9483915,4.463668,,,,,,,,,,,,,, -19700,1.0238023,3.457419,,,,,,,,,,,,,, -19800,0.9444734,4.468792,,,,,,,,,,,,,, -19900,1.062793,2.8020954,,,,,,,,,,,,,, -20000,1.168367,3.1594274,,,,,,,,,,,,,, -20038,,,0.5821679830551147,1.7724109888076782,0.5266799926757812,2.037682056427002,50000.0,0.414900004863739,2.704115629196167,10000.0,8864.943633794785,9469.092945098875,8864.943633794785,602.41819190979,0.6724779605865479,0.0 -20100,1.065932,2.947472,,,,,,,,,,,,,, -20200,0.95143056,3.3645816,,,,,,,,,,,,,, -20300,0.89714766,4.5452013,,,,,,,,,,,,,, -20400,1.0009384,2.8524346,,,,,,,,,,,,,, -20500,1.2700154,2.8654492,,,,,,,,,,,,,, -20600,0.9161431,4.430026,,,,,,,,,,,,,, -20700,1.106778,2.8849266,,,,,,,,,,,,,, -20800,1.107489,3.880317,,,,,,,,,,,,,, -20900,0.9210259,4.6228557,,,,,,,,,,,,,, -20987,,,0.5732030868530273,1.835625052452088,0.5278199911117554,2.045748949050904,50000.0,0.4214000105857849,2.6908113956451416,10000.0,9285.068341493608,9920.836732149124,9285.068341493608,633.9534032344818,0.7069294452667236,0.0 -21000,1.1053523,3.4464464,,,,,,,,,,,,,, -21100,1.0249959,3.3955226,,,,,,,,,,,,,, -21200,0.9934037,4.71013,,,,,,,,,,,,,, -21300,1.0443302,2.740234,,,,,,,,,,,,,, -21400,1.5400034,2.8739169,,,,,,,,,,,,,, -21500,0.95974725,3.853918,,,,,,,,,,,,,, -21600,0.96519285,2.9370959,,,,,,,,,,,,,, -21700,0.8413724,5.1080036,,,,,,,,,,,,,, -21800,1.0799111,2.865087,,,,,,,,,,,,,, -21900,1.1156863,2.7930965,,,,,,,,,,,,,, -21934,,,0.5780468583106995,1.77647066116333,0.5388399958610535,1.979406476020813,50000.0,0.4229000210762024,2.6369969844818115,10000.0,9705.1229493618,10372.515660524368,9705.1229493618,665.4959726333618,0.7394287586212158,0.0 -22000,1.1809773,2.795168,,,,,,,,,,,,,, -22100,1.2153258,2.8018205,,,,,,,,,,,,,, -22200,0.9985454,2.7908702,,,,,,,,,,,,,, -22300,1.0151008,3.1642404,,,,,,,,,,,,,, -22400,1.1203496,2.8378808,,,,,,,,,,,,,, -22500,1.1906703,3.5130043,,,,,,,,,,,,,, -22600,0.93597955,4.894245,,,,,,,,,,,,,, -22700,0.9131479,3.5564947,,,,,,,,,,,,,, -22800,1.1330574,2.8244243,,,,,,,,,,,,,, -22875,,,0.5887890458106995,1.7585443258285522,0.5386399626731873,2.00494122505188,50000.0,0.4220000207424164,2.671059846878052,10000.0,10125.126637935638,10825.470078229904,10125.126637935638,698.3626811504364,0.7748537063598633,0.0 -22900,1.0222069,3.7660673,,,,,,,,,,,,,, -23000,0.9953414,3.8818955,,,,,,,,,,,,,, -23100,0.92763555,3.5427089,,,,,,,,,,,,,, -23200,1.2812532,2.93415,,,,,,,,,,,,,, -23300,1.229721,2.8522952,,,,,,,,,,,,,, -23400,1.1445717,3.0791383,,,,,,,,,,,,,, -23500,1.1528342,2.886561,,,,,,,,,,,,,, -23600,1.2552317,2.8416905,,,,,,,,,,,,,, -23700,0.96410775,4.6562724,,,,,,,,,,,,,, -23800,1.1174577,2.7919567,,,,,,,,,,,,,, -23826,,,0.623339831829071,1.555069088935852,0.5521999597549438,1.9062819480896,50000.0,0.4359000325202942,2.585143566131592,10000.0,10545.515405893326,11278.99824810028,10545.515405893326,731.4255640506744,0.802415132522583,0.0 -23900,1.0852693,3.1706402,,,,,,,,,,,,,, -24000,1.2828143,3.013649,,,,,,,,,,,,,, -24100,1.2483836,2.6339269,,,,,,,,,,,,,, -24200,1.1524262,2.745508,,,,,,,,,,,,,, -24300,0.983893,3.011723,,,,,,,,,,,,,, -24400,1.2415146,2.7274425,,,,,,,,,,,,,, -24500,1.2141371,2.899479,,,,,,,,,,,,,, -24600,1.1916348,2.7993107,,,,,,,,,,,,,, -24700,1.1876992,2.968753,,,,,,,,,,,,,, -24775,,,0.5954492092132568,1.7052239179611206,0.5511800050735474,1.9209455251693728,50000.0,0.4387000203132629,2.591357707977295,10000.0,10965.437469005585,11730.820744037628,10965.437469005585,763.2415297031403,0.8369770050048828,0.0 -24800,1.1066827,2.7408051,,,,,,,,,,,,,, -24900,1.0966454,2.5511029,,,,,,,,,,,,,, -25000,1.0903144,2.8507314,,,,,,,,,,,,,, -25100,1.2093023,2.7454767,,,,,,,,,,,,,, -25200,1.1992506,2.6775336,,,,,,,,,,,,,, -25300,1.1238497,2.9021995,,,,,,,,,,,,,, -25400,1.1926098,2.7466235,,,,,,,,,,,,,, -25500,1.0944408,2.976767,,,,,,,,,,,,,, -25600,1.1189625,2.7423697,,,,,,,,,,,,,, -25700,1.24398,2.7199702,,,,,,,,,,,,,, -25720,,,0.6026171445846558,1.657276272773743,0.5577200055122375,1.8794413805007928,50000.0,0.4342000186443329,2.559943437576294,10000.0,11385.874091625214,12182.853080272676,11385.874091625214,794.7530353069305,0.8722774982452393,0.0 -25800,1.1146587,3.29522,,,,,,,,,,,,,, -25900,1.0210552,2.9078586,,,,,,,,,,,,,, -26000,1.1295004,2.7184362,,,,,,,,,,,,,, -26100,1.144012,2.7628734,,,,,,,,,,,,,, -26200,1.1219857,2.9953272,,,,,,,,,,,,,, -26300,1.0884839,3.0062628,,,,,,,,,,,,,, -26400,1.0050367,3.2762656,,,,,,,,,,,,,, -26500,1.1783826,2.6505427,,,,,,,,,,,,,, -26600,1.0960122,4.3591924,,,,,,,,,,,,,, -26657,,,0.6150780916213989,1.6037347316741943,0.5588799715042114,1.8744468688964844,50000.0,0.4451000094413757,2.543614387512207,10000.0,11804.853033542631,12634.47670030594,11804.853033542631,826.1598796844482,2.062011241912842,0.0 -26700,0.9985722,2.7401488,,,,,,,,,,,,,, -26800,1.1036354,2.8215423,,,,,,,,,,,,,, -26900,1.1607598,2.6746473,,,,,,,,,,,,,, -27000,1.011452,5.0711927,,,,,,,,,,,,,, -27100,1.0655224,5.410843,,,,,,,,,,,,,, -27200,1.1920081,2.7064831,,,,,,,,,,,,,, -27300,0.9750459,3.6510966,,,,,,,,,,,,,, -27400,0.9638662,4.257717,,,,,,,,,,,,,, -27500,1.1260642,3.292092,,,,,,,,,,,,,, -27600,1.2294974,2.6608667,,,,,,,,,,,,,, -27602,,,0.6069140434265137,1.6219347715377808,0.5689600110054016,1.815601825714112,50000.0,0.4532000124454498,2.5086917877197266,10000.0,12224.784992218018,13087.04343366623,12224.784992218018,858.7074551582336,2.100292921066284,0.0 -27700,1.1877655,2.7390127,,,,,,,,,,,,,, -27800,1.3602095,5.4030685,,,,,,,,,,,,,, -27900,1.1916181,4.4694858,,,,,,,,,,,,,, -28000,1.1608822,2.7871478,,,,,,,,,,,,,, -28100,1.2105045,2.567799,,,,,,,,,,,,,, -28200,1.1777787,2.627718,,,,,,,,,,,,,, -28300,1.1408519,5.330885,,,,,,,,,,,,,, -28400,1.0768445,2.6367006,,,,,,,,,,,,,, -28500,1.1287854,2.8140926,,,,,,,,,,,,,, -28549,,,0.6090429425239563,1.6417275667190552,0.5631799697875977,1.8570791482925413,50000.0,0.449500024318695,2.522683620452881,10000.0,12645.118678808212,13540.391645908356,12645.118678808212,891.6379976272583,2.135059118270874,0.0 -28600,1.0104961,4.4001822,,,,,,,,,,,,,, -28700,1.0257972,3.5156236,,,,,,,,,,,,,, -28800,1.1243724,2.671914,,,,,,,,,,,,,, -28900,1.2520598,2.7029428,,,,,,,,,,,,,, -29000,1.2723665,2.6364172,,,,,,,,,,,,,, -29100,1.1366664,2.6198845,,,,,,,,,,,,,, -29200,1.1302819,2.748452,,,,,,,,,,,,,, -29300,1.0022851,3.8460875,,,,,,,,,,,,,, -29400,1.2311988,2.6572351,,,,,,,,,,,,,, -29497,,,0.6200780868530273,1.5937964916229248,0.5712400078773499,1.8428938388824463,50000.0,0.4493000209331512,2.5253915786743164,10000.0,13065.383793115616,13990.263841152191,13065.383793115616,921.1676602363586,2.163169860839844,0.0 -29500,1.071381,2.616974,,,,,,,,,,,,,, -29600,0.9117318,4.17467,,,,,,,,,,,,,, -29700,1.0916055,3.0233493,,,,,,,,,,,,,, -29800,1.0808337,2.5893426,,,,,,,,,,,,,, -29900,0.9927817,4.7605996,,,,,,,,,,,,,, -30000,1.2126865,2.6494048,,,,,,,,,,,,,, -30100,1.1654962,2.6745014,,,,,,,,,,,,,, -30200,0.9761638,4.1864686,,,,,,,,,,,,,, -30300,0.9424291,4.280094,,,,,,,,,,,,,, -30400,1.0128518,3.0931866,,,,,,,,,,,,,, -30441,,,0.646484375,1.4701223373413086,0.5729599595069885,1.815598130226136,50000.0,0.4524000287055969,2.4883368015289307,10000.0,13485.3846244812,14443.974184513092,13485.3846244812,954.7928731441498,2.198180913925171,0.0 -30500,1.2367383,2.589811,,,,,,,,,,,,,, -30600,1.0493463,4.6282754,,,,,,,,,,,,,, -30700,1.1934692,2.717774,,,,,,,,,,,,,, -30800,1.2010885,2.7728596,,,,,,,,,,,,,, -30900,1.17893,2.5968583,,,,,,,,,,,,,, -31000,1.0760506,2.556191,,,,,,,,,,,,,, -31100,0.9987111,3.7868814,,,,,,,,,,,,,, -31200,1.2042675,2.5897036,,,,,,,,,,,,,, -31300,0.90285105,5.2316947,,,,,,,,,,,,,, -31388,,,0.6192382574081421,1.590422749519348,0.5780199766159058,1.7954529523849487,50000.0,0.4555000364780426,2.473357677459717,10000.0,13905.559622764587,14896.015183925629,13905.559622764587,986.5768990516664,2.23132872581482,0.0 -31400,1.0084697,4.3948307,,,,,,,,,,,,,, -31500,1.1219262,2.5926952,,,,,,,,,,,,,, -31600,0.98899055,4.715502,,,,,,,,,,,,,, -31700,1.068072,2.8544652,,,,,,,,,,,,,, -31800,1.0299333,4.171941,,,,,,,,,,,,,, -31900,1.2068783,2.6270683,,,,,,,,,,,,,, -32000,1.0301292,3.2345319,,,,,,,,,,,,,, -32100,0.9611082,5.3134775,,,,,,,,,,,,,, -32200,1.240132,2.7470179,,,,,,,,,,,,,, -32300,1.1327513,3.778747,,,,,,,,,,,,,, -32334,,,0.6196874976158142,1.594934344291687,0.5723400115966797,1.8144795894622805,50000.0,0.4574000239372253,2.485887289047241,10000.0,14325.485937595367,15350.62278676033,14325.485937595367,1021.1722357273102,2.2684500217437744,0.0 -32400,1.0212518,3.9465656,,,,,,,,,,,,,, -32500,1.0089685,5.214801,,,,,,,,,,,,,, -32600,0.9836788,3.8713498,,,,,,,,,,,,,, -32700,1.1367413,2.6867259,,,,,,,,,,,,,, -32800,1.1922773,2.7287395,,,,,,,,,,,,,, -32900,1.247301,2.6714585,,,,,,,,,,,,,, -33000,1.2353436,2.6991167,,,,,,,,,,,,,, -33100,1.3660163,2.5946493,,,,,,,,,,,,,, -33200,0.92646545,4.7633386,,,,,,,,,,,,,, -33281,,,0.6363476514816284,1.4893057346343994,0.5823799967765808,1.7727079391479492,50000.0,0.4593000113964081,2.459869623184204,10000.0,14745.835205078123,15802.050520658491,14745.835205078123,1052.1730644702911,2.2977893352508545,0.0 -33300,0.92965007,3.4588258,,,,,,,,,,,,,, -33400,1.0553187,4.3830657,,,,,,,,,,,,,, -33500,1.2365819,2.622846,,,,,,,,,,,,,, -33600,0.97344357,5.1795807,,,,,,,,,,,,,, -33700,1.2331885,2.5398545,,,,,,,,,,,,,, -33800,0.97825235,4.1226435,,,,,,,,,,,,,, -33900,1.2694283,2.623284,,,,,,,,,,,,,, -34000,1.2005153,2.6623864,,,,,,,,,,,,,, -34100,1.3481485,2.5374398,,,,,,,,,,,,,, -34200,1.0191903,5.239567,,,,,,,,,,,,,, -34224,,,0.6278319954872131,1.517386555671692,0.5886200070381165,1.7128303050994873,50000.0,0.4670000076293945,2.411027908325196,10000.0,15165.793439865112,16253.66602897644,15165.793439865112,1083.7390191555023,2.3401732444763184,0.0 -34300,1.2316817,2.4745708,,,,,,,,,,,,,, -34400,1.1140995,2.865687,,,,,,,,,,,,,, -34500,1.0774777,3.1101975,,,,,,,,,,,,,, -34600,1.1438644,2.4353685,,,,,,,,,,,,,, -34700,1.2985766,2.58683,,,,,,,,,,,,,, -34800,1.0926192,4.851065,,,,,,,,,,,,,, -34900,1.1358078,2.7102022,,,,,,,,,,,,,, -35000,0.95844215,3.7190833,,,,,,,,,,,,,, -35100,1.0396769,5.113409,,,,,,,,,,,,,, -35169,,,0.6328905820846558,1.534470796585083,0.5838599801063538,1.7675602436065674,50000.0,0.4645000100135803,2.427992582321167,10000.0,15585.799923658373,16705.374019622803,15585.799923658373,1115.3563861846924,2.374427318572998,0.0 -35200,1.1067208,2.550784,,,,,,,,,,,,,, -35300,1.0979109,5.0909057,,,,,,,,,,,,,, -35400,1.0786015,5.1939955,,,,,,,,,,,,,, -35500,1.0794458,3.2092943,,,,,,,,,,,,,, -35600,1.1256208,3.3633895,,,,,,,,,,,,,, -35700,1.1902744,2.5745554,,,,,,,,,,,,,, -35800,1.2215744,2.530317,,,,,,,,,,,,,, -35900,1.138504,2.863388,,,,,,,,,,,,,, -36000,1.0266418,3.1233718,,,,,,,,,,,,,, -36100,1.2284586,2.5168464,,,,,,,,,,,,,, -36114,,,0.6386132836341858,1.485224366188049,0.585099995136261,1.736467719078064,50000.0,0.4657000303268432,2.405973196029663,10000.0,16005.73632979393,17157.829701662064,16005.73632979393,1147.792130947113,2.408688545227051,0.0 -36200,1.1269599,3.2170243,,,,,,,,,,,,,, -36300,1.1554309,2.7085323,,,,,,,,,,,,,, -36400,1.1176227,2.497357,,,,,,,,,,,,,, -36500,1.1218302,2.7279005,,,,,,,,,,,,,, -36600,1.1239182,3.1581864,,,,,,,,,,,,,, -36700,1.0375929,3.4987485,,,,,,,,,,,,,, -36800,1.1827582,2.4352157,,,,,,,,,,,,,, -36900,1.0121273,4.420497,,,,,,,,,,,,,, -37000,0.99440205,4.0426755,,,,,,,,,,,,,, -37060,,,0.6688476204872131,1.3819684982299805,0.5868600010871887,1.7398897409439087,50000.0,0.4737000167369842,2.4095163345336914,10000.0,16425.699372053146,17610.38723230362,16425.699372053146,1180.2964413166046,2.449498414993286,0.0 -37100,0.98594826,4.8766856,,,,,,,,,,,,,, -37200,1.1347187,2.3909009,,,,,,,,,,,,,, -37300,1.236188,2.5494227,,,,,,,,,,,,,, -37400,0.9812389,4.4514656,,,,,,,,,,,,,, -37500,0.9418548,3.8061144,,,,,,,,,,,,,, -37600,1.5702544,2.6385849,,,,,,,,,,,,,, -37700,1.2575518,2.5544913,,,,,,,,,,,,,, -37800,1.3664176,2.492715,,,,,,,,,,,,,, -37900,1.0555259,3.3658278,,,,,,,,,,,,,, -38000,1.0563486,3.8624296,,,,,,,,,,,,,, -38005,,,0.639355480670929,1.477370262145996,0.592960000038147,1.6992011070251465,50000.0,0.4795000255107879,2.360228538513184,10000.0,16845.78829050064,18060.96590399742,16845.78829050064,1210.706288814545,2.4802818298339844,0.0 -38100,1.3032886,2.5326016,,,,,,,,,,,,,, -38200,1.3561401,4.9053063,,,,,,,,,,,,,, -38300,1.0607159,3.1732337,,,,,,,,,,,,,, -38400,1.0028391,4.596732,,,,,,,,,,,,,, -38500,1.184615,2.4883013,,,,,,,,,,,,,, -38600,1.4508624,2.5258913,,,,,,,,,,,,,, -38700,1.1562653,3.120657,,,,,,,,,,,,,, -38800,1.0681123,3.1429212,,,,,,,,,,,,,, -38900,1.2373486,2.5084822,,,,,,,,,,,,,, -38949,,,0.6482812166213989,1.4299132823944092,0.601099967956543,1.6725398302078247,50000.0,0.4771000146865845,2.336911916732788,10000.0,17265.709993124008,18511.710064649586,17265.709993124008,1241.442130804062,2.518287420272827,0.0 -39000,1.1609954,2.576531,,,,,,,,,,,,,, -39100,1.0056709,2.9561043,,,,,,,,,,,,,, -39200,1.2936106,2.5740964,,,,,,,,,,,,,, -39300,1.2369452,2.4335706,,,,,,,,,,,,,, -39400,0.9975706,3.386596,,,,,,,,,,,,,, -39500,0.95206493,4.5432124,,,,,,,,,,,,,, -39600,1.1011769,3.7517006,,,,,,,,,,,,,, -39700,0.9470005,4.544914,,,,,,,,,,,,,, -39800,1.2787338,2.6436734,,,,,,,,,,,,,, -39893,,,0.6584374904632568,1.4017064571380615,0.6007599830627441,1.6774709224700928,50000.0,0.4825000166893005,2.327033042907715,10000.0,17685.679537296295,18962.58662962913,17685.679537296295,1272.2637612819672,2.5543577671051025,0.0 -39900,1.2393787,2.59765,,,,,,,,,,,,,, -40000,0.95687914,4.2842956,,,,,,,,,,,,,, -40100,1.1900703,2.4670708,,,,,,,,,,,,,, -40200,1.2662493,2.6761336,,,,,,,,,,,,,, -40300,1.10239,2.607838,,,,,,,,,,,,,, -40400,1.2844744,2.5415633,,,,,,,,,,,,,, -40500,1.0609529,3.5759246,,,,,,,,,,,,,, -40600,1.2401102,4.5457034,,,,,,,,,,,,,, -40700,1.0590366,3.6305113,,,,,,,,,,,,,, -40800,1.1634911,2.8024673,,,,,,,,,,,,,, -40837,,,0.6469140648841858,1.4363796710968018,0.6001799702644348,1.6573528051376345,50000.0,0.4824000298976898,2.323774814605713,10000.0,18105.65966153145,19414.26227951049,18105.65966153145,1303.868986368179,2.595737934112549,0.0 -40900,1.0629907,4.065464,,,,,,,,,,,,,, -41000,1.0171828,3.727144,,,,,,,,,,,,,, -41100,1.2105209,2.5940015,,,,,,,,,,,,,, -41200,1.256391,2.6300752,,,,,,,,,,,,,, -41300,1.1762364,2.5829546,,,,,,,,,,,,,, -41400,1.169902,2.666779,,,,,,,,,,,,,, -41500,1.2773194,2.5646918,,,,,,,,,,,,,, -41600,1.0356096,4.0538716,,,,,,,,,,,,,, -41700,1.1501144,2.4396105,,,,,,,,,,,,,, -41783,,,0.6481835842132568,1.4439572095870972,0.5995799899101257,1.6738475561141968,50000.0,0.4859000146389007,2.331721544265747,10000.0,18525.777707338333,19865.08677005768,18525.777707338333,1334.4877030849457,2.63551926612854,0.0 -41800,1.2181842,2.4962258,,,,,,,,,,,,,, -41900,1.158538,2.5115342,,,,,,,,,,,,,, -42000,1.3111582,2.3613946,,,,,,,,,,,,,, -42100,1.2099841,2.5119767,,,,,,,,,,,,,, -42200,1.1311092,2.735665,,,,,,,,,,,,,, -42300,1.13014,2.4535477,,,,,,,,,,,,,, -42400,1.0359981,4.3381968,,,,,,,,,,,,,, -42500,1.049916,3.8001184,,,,,,,,,,,,,, -42600,1.229364,2.5809183,,,,,,,,,,,,,, -42700,1.2244532,2.5828333,,,,,,,,,,,,,, -42727,,,0.6622461080551147,1.4292950630187988,0.6039800047874451,1.6770350933074951,50000.0,0.4874000251293182,2.338141202926636,10000.0,18945.784809350967,20315.787051200867,18945.784809350967,1365.0891120433807,2.6732630729675293,0.0 -42800,1.15798,2.9270763,,,,,,,,,,,,,, -42900,1.0503571,3.685683,,,,,,,,,,,,,, -43000,1.0243062,4.9420695,,,,,,,,,,,,,, -43100,1.0334996,5.0285654,,,,,,,,,,,,,, -43200,1.2264663,3.1481922,,,,,,,,,,,,,, -43300,1.1979302,2.4747083,,,,,,,,,,,,,, -43400,1.3822685,2.5349183,,,,,,,,,,,,,, -43500,1.3323306,2.3599336,,,,,,,,,,,,,, -43600,1.0340141,4.9488153,,,,,,,,,,,,,, -43675,,,0.6868749856948853,1.3137249946594238,0.6068400144577026,1.681557297706604,50000.0,0.4832000136375427,2.350550413131714,10000.0,19365.801652669907,20766.421213150024,19365.801652669907,1395.6237680912018,2.708239316940308,0.0 -43700,1.1344353,2.8012068,,,,,,,,,,,,,, -43800,1.270907,2.511134,,,,,,,,,,,,,, -43900,1.2832873,2.4644666,,,,,,,,,,,,,, -44000,1.1977835,3.17855,,,,,,,,,,,,,, -44100,1.0837656,4.8416085,,,,,,,,,,,,,, -44200,1.0648148,3.009045,,,,,,,,,,,,,, -44300,1.1943519,2.5684893,,,,,,,,,,,,,, -44400,1.2082485,2.478497,,,,,,,,,,,,,, -44500,1.0808963,2.8389723,,,,,,,,,,,,,, -44600,0.9976091,4.6549945,,,,,,,,,,,,,, -44621,,,0.6566405892372131,1.4093447923660278,0.6092199683189392,1.6448744535446167,50000.0,0.4914000332355499,2.3031435012817383,10000.0,19785.997804403305,21217.65069699288,19785.997804403305,1426.566675901413,2.75048303604126,0.0 -44700,1.3399117,2.6581702,,,,,,,,,,,,,, -44800,1.0342678,4.7248473,,,,,,,,,,,,,, -44900,1.1818504,2.6569436,,,,,,,,,,,,,, -45000,1.1709833,2.5024803,,,,,,,,,,,,,, -45100,1.2651688,2.506727,,,,,,,,,,,,,, -45200,1.2661936,2.4511466,,,,,,,,,,,,,, -45300,1.1793032,3.6718779,,,,,,,,,,,,,, -45400,1.2118453,2.392136,,,,,,,,,,,,,, -45500,1.1884153,2.5313368,,,,,,,,,,,,,, -45569,,,0.6633593440055847,1.3881406784057615,0.6101799607276917,1.6343706846237185,50000.0,0.4869000315666199,2.2931628227233887,10000.0,20206.22102165222,21669.25616335869,20206.22102165222,1457.8632283210754,2.787197589874268,0.0 -45600,1.3138824,2.4829452,,,,,,,,,,,,,, -45700,1.164748,2.4297075,,,,,,,,,,,,,, -45800,1.1365194,2.7701893,,,,,,,,,,,,,, -45900,1.0398507,4.3078365,,,,,,,,,,,,,, -46000,1.1282982,4.0221415,,,,,,,,,,,,,, -46100,1.0631036,3.1643746,,,,,,,,,,,,,, -46200,1.0990896,2.6803668,,,,,,,,,,,,,, -46300,1.1299433,2.7035742,,,,,,,,,,,,,, -46400,1.1336195,2.827585,,,,,,,,,,,,,, -46500,1.297276,2.5475233,,,,,,,,,,,,,, -46517,,,0.6712695360183716,1.3291295766830444,0.6121799945831299,1.6179096698760986,50000.0,0.486700028181076,2.2925844192504883,10000.0,20626.519107580185,22123.807602643967,20626.519107580185,1492.0283637046814,2.8271169662475586,0.0 -46600,1.2413131,2.4541285,,,,,,,,,,,,,, -46700,1.2610687,2.5011635,,,,,,,,,,,,,, -46800,1.0707386,3.5955756,,,,,,,,,,,,,, -46900,1.2403189,2.4403884,,,,,,,,,,,,,, -47000,1.3353574,2.4369493,,,,,,,,,,,,,, -47100,1.3357055,2.6171544,,,,,,,,,,,,,, -47200,1.0967009,3.7843103,,,,,,,,,,,,,, -47300,1.2752742,2.2832975,,,,,,,,,,,,,, -47400,1.2534814,2.3607652,,,,,,,,,,,,,, -47466,,,0.6614062190055847,1.4059300422668457,0.6163600087165833,1.6147215366363523,50000.0,0.491100013256073,2.26879358291626,10000.0,21046.87879395485,22575.74212741852,21046.87879395485,1523.521923303604,2.8597183227539062,0.0 -47500,1.3532995,2.3950596,,,,,,,,,,,,,, -47600,1.0574672,4.51038,,,,,,,,,,,,,, -47700,1.2342807,2.4975722,,,,,,,,,,,,,, -47800,1.1288869,2.7054205,,,,,,,,,,,,,, -47900,1.0776494,2.9073086,,,,,,,,,,,,,, -48000,1.0794805,3.9402049,,,,,,,,,,,,,, -48100,1.26215,2.4895294,,,,,,,,,,,,,, -48200,1.2735072,2.4027138,,,,,,,,,,,,,, -48300,1.0758188,3.3634253,,,,,,,,,,,,,, -48400,1.1554503,2.3417125,,,,,,,,,,,,,, -48415,,,0.6602538824081421,1.423980951309204,0.6101399660110474,1.6533193588256836,50000.0,0.4886000156402588,2.3144147396087646,10000.0,21466.973664999008,23025.29788851738,21466.973664999008,1552.896357297897,2.897174119949341,0.0 -48500,1.2039351,2.1843734,,,,,,,,,,,,,, -48600,1.0598747,4.6241326,,,,,,,,,,,,,, -48700,1.1407874,2.580406,,,,,,,,,,,,,, -48800,1.5151865,2.408414,,,,,,,,,,,,,, -48900,1.1287509,2.7352207,,,,,,,,,,,,,, -49000,1.650242,2.474592,,,,,,,,,,,,,, -49100,1.1445138,2.8474512,,,,,,,,,,,,,, -49200,1.256417,2.2806895,,,,,,,,,,,,,, -49300,1.019666,3.9578707,,,,,,,,,,,,,, -49359,,,0.6670702695846558,1.3476743698120115,0.6132599711418152,1.6010985374450684,50000.0,0.4917000234127044,2.266906261444092,10000.0,21887.021904945374,23477.518564224243,21887.021904945374,1584.982391357422,2.935287952423096,0.0 -49400,1.058957,4.8650575,,,,,,,,,,,,,, -49500,1.1972573,2.3974657,,,,,,,,,,,,,, -49600,1.2023107,3.035552,,,,,,,,,,,,,, -49700,1.1050917,4.594533,,,,,,,,,,,,,, -49800,1.1903759,3.1587083,,,,,,,,,,,,,, -49900,1.2639672,2.4080162,,,,,,,,,,,,,, -50000,1.2289721,2.2868834,,,,,,,,,,,,,, -50100,1.1317136,2.5011353,,,,,,,,,,,,,, -50200,1.1119143,4.9416933,,,,,,,,,,,,,, -50300,1.132828,4.994274,,,,,,,,,,,,,, -50305,,,0.6949023008346558,1.217258095741272,0.6212999820709229,1.570192813873291,50000.0,0.4980000257492065,2.233330488204956,10000.0,22307.08506894112,23927.468125104904,22307.08506894112,1614.7766954898834,2.9790749549865723,0.0 -50400,1.2348868,3.6934626,,,,,,,,,,,,,, -50500,1.3072366,2.435897,,,,,,,,,,,,,, -50600,1.0474244,3.327103,,,,,,,,,,,,,, -50700,1.2796144,2.3771808,,,,,,,,,,,,,, -50800,1.203852,2.676581,,,,,,,,,,,,,, -50900,1.2669584,2.4475543,,,,,,,,,,,,,, -51000,1.1781759,2.4747062,,,,,,,,,,,,,, -51100,1.3367192,2.4892383,,,,,,,,,,,,,, -51200,1.3487161,2.517078,,,,,,,,,,,,,, -51252,,,0.6680468320846558,1.355788230895996,0.6198599934577942,1.5720373392105105,50000.0,0.4948000311851501,2.2380619049072266,10000.0,22727.23337483406,24381.398369073868,22727.23337483406,1648.4719729423523,3.016878128051758,0.0 -51300,1.312851,2.3365674,,,,,,,,,,,,,, -51400,1.3444201,2.4354072,,,,,,,,,,,,,, -51500,1.124101,3.235051,,,,,,,,,,,,,, -51600,1.3027536,4.9027214,,,,,,,,,,,,,, -51700,1.2626724,2.5959184,,,,,,,,,,,,,, -51800,1.3860035,2.2947917,,,,,,,,,,,,,, -51900,1.3510741,5.088455,,,,,,,,,,,,,, -52000,1.3119901,2.4471228,,,,,,,,,,,,,, -52100,1.1475879,3.3124928,,,,,,,,,,,,,, -52200,1.4445635,2.3823974,,,,,,,,,,,,,, -52203,,,0.6677343845367432,1.3512169122695925,0.6184799671173096,1.5838528871536257,50000.0,0.495600014925003,2.272915840148926,10000.0,23147.45352268219,24834.76548576355,23147.45352268219,1681.53049826622,3.056202173233032,0.0 -52300,1.1832474,2.7249215,,,,,,,,,,,,,, -52400,1.1647217,3.1163087,,,,,,,,,,,,,, -52500,1.0911739,4.70481,,,,,,,,,,,,,, -52600,1.4044477,2.516977,,,,,,,,,,,,,, -52700,1.1394863,4.129528,,,,,,,,,,,,,, -52800,1.1370964,3.3103285,,,,,,,,,,,,,, -52900,1.20198,2.389107,,,,,,,,,,,,,, -53000,1.1888636,2.8976398,,,,,,,,,,,,,, -53100,1.168182,3.3571932,,,,,,,,,,,,,, -53154,,,0.6862499713897705,1.2557148933410645,0.6257599592208862,1.5392677783966064,50000.0,0.5009000301361084,2.195390224456787,10000.0,23567.636283397675,25286.163171052933,23567.636283397675,1712.6602900028229,3.0928637981414795,0.0 -53200,1.2881695,2.331972,,,,,,,,,,,,,, -53300,1.198324,3.2442508,,,,,,,,,,,,,, -53400,1.2451149,3.9262748,,,,,,,,,,,,,, -53500,1.1401726,2.804511,,,,,,,,,,,,,, -53600,1.2123814,4.7936563,,,,,,,,,,,,,, -53700,1.1659418,3.5727718,,,,,,,,,,,,,, -53800,1.2091902,2.3248234,,,,,,,,,,,,,, -53900,1.1432278,4.9866104,,,,,,,,,,,,,, -54000,1.11331,4.4375663,,,,,,,,,,,,,, -54100,1.1186423,3.7245135,,,,,,,,,,,,,, -54101,,,0.6750195026397705,1.320992112159729,0.62527996301651,1.5474814176559448,50000.0,0.5051000118255615,2.196848630905152,10000.0,23987.96748137474,25736.189141988754,23987.96748137474,1742.2668850421906,3.1326940059661865,0.0 -54200,1.1983585,2.7653894,,,,,,,,,,,,,, -54300,1.3698941,2.4521942,,,,,,,,,,,,,, -54400,1.3317069,2.6771522,,,,,,,,,,,,,, -54500,1.2340643,2.5141234,,,,,,,,,,,,,, -54600,1.070845,2.939975,,,,,,,,,,,,,, -54700,1.1841936,2.8923852,,,,,,,,,,,,,, -54800,1.297484,2.416328,,,,,,,,,,,,,, -54900,1.1874807,5.013619,,,,,,,,,,,,,, -55000,1.3080451,2.361621,,,,,,,,,,,,,, -55049,,,0.6699999570846558,1.3624234199523926,0.6233199834823608,1.5818876028060913,50000.0,0.4979000091552734,2.2316365242004395,10000.0,24408.114142656326,26189.48584985733,24408.114142656326,1775.3276269435885,3.173816442489624,0.0 -55100,1.3104445,2.6035285,,,,,,,,,,,,,, -55200,1.1391269,3.1447644,,,,,,,,,,,,,, -55300,1.2619804,2.374114,,,,,,,,,,,,,, -55400,1.2814109,2.2401893,,,,,,,,,,,,,, -55500,1.3008319,2.3591025,,,,,,,,,,,,,, -55600,1.2910295,2.2694082,,,,,,,,,,,,,, -55700,1.121298,3.2038422,,,,,,,,,,,,,, -55800,1.2509632,2.3808136,,,,,,,,,,,,,, -55900,1.3746895,2.4023592,,,,,,,,,,,,,, -55999,,,0.6770898103713989,1.3100413084030151,0.6246399879455566,1.567784070968628,50000.0,0.4998000264167785,2.230363368988037,10000.0,24828.346135139465,26641.32139706612,24828.346135139465,1806.845562696457,3.211068630218506,0.0 -56000,1.3001211,2.287633,,,,,,,,,,,,,, -56100,1.3727087,2.2977982,,,,,,,,,,,,,, -56200,1.2299402,2.4277644,,,,,,,,,,,,,, -56300,1.3655354,2.420429,,,,,,,,,,,,,, -56400,1.0901136,3.4711556,,,,,,,,,,,,,, -56500,1.3216778,2.547784,,,,,,,,,,,,,, -56600,1.2147121,2.6113374,,,,,,,,,,,,,, -56700,1.143375,4.833387,,,,,,,,,,,,,, -56800,1.0952798,4.4832053,,,,,,,,,,,,,, -56900,1.2238183,2.4321606,,,,,,,,,,,,,, -56943,,,0.6991015672683716,1.2092502117156982,0.6245599985122681,1.5463355779647827,50000.0,0.5057000517845154,2.21742582321167,10000.0,25248.46753191948,27092.631110429764,25248.46753191948,1837.941088438034,3.255908966064453,0.0 -57000,1.2857695,2.3633397,,,,,,,,,,,,,, -57100,1.0655069,3.4935527,,,,,,,,,,,,,, -57200,1.2170737,2.4626353,,,,,,,,,,,,,, -57300,1.1980926,4.834564,,,,,,,,,,,,,, -57400,1.2120339,2.2662442,,,,,,,,,,,,,, -57500,1.0692534,3.8697398,,,,,,,,,,,,,, -57600,1.1228248,2.9102583,,,,,,,,,,,,,, -57700,1.3253162,2.2651403,,,,,,,,,,,,,, -57800,1.324584,2.367289,,,,,,,,,,,,,, -57889,,,0.6811327934265137,1.2940772771835327,0.6329799890518188,1.5235350131988523,50000.0,0.506600022315979,2.182492971420288,10000.0,25668.446413993835,27542.98506808281,25668.446413993835,1868.226809501648,3.297309160232544,0.0 -57900,1.2827457,2.2655642,,,,,,,,,,,,,, -58000,1.2248367,2.8809023,,,,,,,,,,,,,, -58100,1.2118435,2.3485494,,,,,,,,,,,,,, -58200,1.1471348,2.8561082,,,,,,,,,,,,,, -58300,1.3161262,2.4292455,,,,,,,,,,,,,, -58400,1.410471,2.290347,,,,,,,,,,,,,, -58500,1.1578583,3.2388406,,,,,,,,,,,,,, -58600,1.1890606,2.7538407,,,,,,,,,,,,,, -58700,1.325144,2.370635,,,,,,,,,,,,,, -58800,1.1500605,2.4969244,,,,,,,,,,,,,, -58835,,,0.68115234375,1.3139863014221191,0.6299600005149841,1.553807258605957,50000.0,0.5116000175476074,2.191758632659912,10000.0,26088.65665245056,27995.94226884842,26088.65665245056,1900.8828961849213,3.3394837379455566,0.0 -58900,1.3054451,2.396189,,,,,,,,,,,,,, -59000,1.1787338,2.2545247,,,,,,,,,,,,,, -59100,1.0445517,4.049365,,,,,,,,,,,,,, -59200,1.2160639,2.2586188,,,,,,,,,,,,,, -59300,1.15622,3.556953,,,,,,,,,,,,,, -59400,1.4113929,2.2903476,,,,,,,,,,,,,, -59500,1.3005557,2.325342,,,,,,,,,,,,,, -59600,1.282219,2.2570937,,,,,,,,,,,,,, -59700,1.1401118,3.5940142,,,,,,,,,,,,,, -59782,,,0.6947851181030273,1.258604884147644,0.630899965763092,1.5560885667800903,50000.0,0.5057000517845154,2.211388111114502,10000.0,26508.756959676743,28445.90362930298,26508.756959676743,1930.6553165912628,3.3801913261413574,0.0 -59800,1.3334618,2.3521144,,,,,,,,,,,,,, -59900,1.3122214,4.916309,,,,,,,,,,,,,, -60000,1.4080839,2.3497014,,,,,,,,,,,,,, -60100,1.2644744,2.4306104,,,,,,,,,,,,,, -60200,1.3511481,2.5179665,,,,,,,,,,,,,, -60300,1.221245,3.0469916,,,,,,,,,,,,,, -60400,1.2266392,2.1679688,,,,,,,,,,,,,, -60500,1.3452427,2.2823193,,,,,,,,,,,,,, -60600,1.1617911,4.381491,,,,,,,,,,,,,, -60689,,,0.6812695264816284,1.2990782260894775,0.6337000131607056,1.513222098350525,50000.0,0.5031999945640564,2.187201023101806,10000.0,26929.041621923447,28898.69361257553,26929.041621923447,1963.071844100952,3.4235599040985107,0.0 -60700,1.2086335,4.1718006,,,,,,,,,,,,,, -60800,1.1798153,2.5728998,,,,,,,,,,,,,, -60900,1.249646,2.2408767,,,,,,,,,,,,,, -61000,1.2889763,2.234345,,,,,,,,,,,,,, -61100,1.3213661,2.2687304,,,,,,,,,,,,,, -61200,1.2850263,4.515423,,,,,,,,,,,,,, -61300,1.2455515,4.5752873,,,,,,,,,,,,,, -61400,1.3010381,2.2396843,,,,,,,,,,,,,, -61500,1.1958631,2.7353656,,,,,,,,,,,,,, -61600,1.0992969,2.9866762,,,,,,,,,,,,,, -61639,,,0.6890038847923279,1.2490397691726685,0.6389399766921997,1.4967432022094729,50000.0,0.5178000330924988,2.1470038890838623,10000.0,27349.087817907333,29350.87168526649,27349.087817907333,1995.115867614746,3.4631595611572266,0.0 -61700,1.3325883,2.2871027,,,,,,,,,,,,,, -61800,1.194752,4.5518875,,,,,,,,,,,,,, -61900,1.1693518,4.6665635,,,,,,,,,,,,,, -62000,1.3113847,2.5025516,,,,,,,,,,,,,, -62100,1.2270635,4.8301787,,,,,,,,,,,,,, -62200,1.3094153,2.5192199,,,,,,,,,,,,,, -62300,1.1809953,3.585725,,,,,,,,,,,,,, -62400,1.4547486,2.4707468,,,,,,,,,,,,,, -62500,1.3697357,2.122062,,,,,,,,,,,,,, -62585,,,0.6906836032867432,1.238747477531433,0.6342399716377258,1.5011688470840454,50000.0,0.5063000321388245,2.162984848022461,10000.0,27769.2222173214,29802.38598752021,27769.2222173214,2026.404606342316,3.505953550338745,0.0 -62600,1.216647,3.8407,,,,,,,,,,,,,, -62700,1.304944,2.2643347,,,,,,,,,,,,,, -62800,1.0838982,4.8148737,,,,,,,,,,,,,, -62900,1.0529169,4.6319933,,,,,,,,,,,,,, -63000,1.2780441,4.2370963,,,,,,,,,,,,,, -63100,1.1858159,3.2206364,,,,,,,,,,,,,, -63200,1.1867379,2.3854427,,,,,,,,,,,,,, -63300,1.2009196,2.426891,,,,,,,,,,,,,, -63400,1.1840594,2.8962893,,,,,,,,,,,,,, -63500,1.3220398,2.271965,,,,,,,,,,,,,, -63533,,,0.71107417345047,1.1600855588912964,0.6335799694061279,1.5163129568099976,50000.0,0.508400022983551,2.1880664825439453,10000.0,28189.522920131683,30253.7003839016,28189.522920131683,2057.330437898636,3.545721530914306,0.0 -63600,1.320479,2.2417552,,,,,,,,,,,,,, -63700,1.2937003,2.2905507,,,,,,,,,,,,,, -63800,1.4234399,2.2827446,,,,,,,,,,,,,, -63900,1.1678258,2.865718,,,,,,,,,,,,,, -64000,1.2684536,2.200644,,,,,,,,,,,,,, -64100,1.4743667,2.2342358,,,,,,,,,,,,,, -64200,1.1985091,2.7958064,,,,,,,,,,,,,, -64300,1.4169573,2.3818603,,,,,,,,,,,,,, -64400,1.0724827,3.0070684,,,,,,,,,,,,,, -64480,,,0.690625011920929,1.2341567277908323,0.6427800059318542,1.4698628187179563,50000.0,0.5195000171661377,2.1230661869049072,10000.0,28609.791985034943,30704.375244617466,28609.791985034943,2087.645537853241,3.588677167892456,0.0 -64500,1.1397667,3.5764403,,,,,,,,,,,,,, -64600,1.3232412,2.214672,,,,,,,,,,,,,, -64700,1.2488961,2.4378402,,,,,,,,,,,,,, -64800,1.2665223,2.3019845,,,,,,,,,,,,,, -64900,1.0963918,3.9444745,,,,,,,,,,,,,, -65000,1.0666007,4.5406065,,,,,,,,,,,,,, -65100,1.1623822,3.8636456,,,,,,,,,,,,,, -65200,1.3531443,4.2903433,,,,,,,,,,,,,, -65300,1.4315771,2.2998939,,,,,,,,,,,,,, -65400,1.3591962,2.378114,,,,,,,,,,,,,, -65429,,,0.6970507502555847,1.2187137603759766,0.642520010471344,1.467975616455078,50000.0,0.5190000534057617,2.1363329887390137,10000.0,29030.007290124893,31158.778439998627,29030.007290124893,2121.744611263275,3.628873109817505,0.0 -65500,1.2615608,2.2100306,,,,,,,,,,,,,, -65600,1.1910183,4.8020988,,,,,,,,,,,,,, -65700,1.2344332,4.043025,,,,,,,,,,,,,, -65800,1.3142582,2.1501093,,,,,,,,,,,,,, -65900,1.1097032,2.810781,,,,,,,,,,,,,, -66000,1.3376424,2.323352,,,,,,,,,,,,,, -66100,1.359137,2.2124739,,,,,,,,,,,,,, -66200,1.314565,2.2838838,,,,,,,,,,,,,, -66300,1.3425939,2.3676767,,,,,,,,,,,,,, -66381,,,0.70703125,1.1558291912078855,0.6475600004196167,1.4485762119293213,50000.0,0.5151000022888184,2.12232518196106,10000.0,29450.01219272613,31612.26335000992,29450.01219272613,2155.1333100795746,3.671163320541382,0.0 -66400,1.1198152,4.858129,,,,,,,,,,,,,, -66500,1.1914989,2.3377013,,,,,,,,,,,,,, -66600,1.239158,3.0828497,,,,,,,,,,,,,, -66700,1.14758,4.858638,,,,,,,,,,,,,, -66800,1.2759157,2.5409274,,,,,,,,,,,,,, -66900,1.1950575,3.050334,,,,,,,,,,,,,, -67000,1.3468547,2.282415,,,,,,,,,,,,,, -67100,1.6330612,2.2981088,,,,,,,,,,,,,, -67200,1.4818401,2.3307288,,,,,,,,,,,,,, -67300,1.5991974,2.5869093,,,,,,,,,,,,,, -67327,,,0.6959765553474426,1.256928563117981,0.6457599997520447,1.4908559322357178,50000.0,0.5243000388145447,2.134155511856079,10000.0,29870.019545555115,32062.83818912506,29870.019545555115,2185.6091549396515,3.7146735191345215,0.0 -67400,1.3001577,2.4268794,,,,,,,,,,,,,, -67500,1.2081112,2.3384643,,,,,,,,,,,,,, -67600,1.3637692,2.2644286,,,,,,,,,,,,,, -67700,1.1753299,3.1937916,,,,,,,,,,,,,, -67800,1.324448,2.2777238,,,,,,,,,,,,,, -67900,1.4577688,2.2621212,,,,,,,,,,,,,, -68000,1.3687845,2.1610153,,,,,,,,,,,,,, -68100,1.1778028,4.6979365,,,,,,,,,,,,,, -68200,1.2174157,2.4270425,,,,,,,,,,,,,, -68273,,,0.6995898485183716,1.2015330791473389,0.6462399959564209,1.4489117860794067,50000.0,0.5218000411987305,2.0921084880828857,10000.0,30289.98365139961,32512.53608345985,30289.98365139961,2215.2276520729065,3.780029296875,0.0 -68300,1.3879746,2.2877498,,,,,,,,,,,,,, -68400,1.1851839,4.0124893,,,,,,,,,,,,,, -68500,1.2486476,2.921348,,,,,,,,,,,,,, -68600,1.2419198,2.306662,,,,,,,,,,,,,, -68700,1.1801099,2.9673,,,,,,,,,,,,,, -68800,1.2596152,2.4446247,,,,,,,,,,,,,, -68900,1.5899304,2.2438679,,,,,,,,,,,,,, -69000,1.3301721,2.3093958,,,,,,,,,,,,,, -69100,1.3854811,2.1516504,,,,,,,,,,,,,, -69200,1.2842252,2.2356305,,,,,,,,,,,,,, -69218,,,0.7053906321525574,1.2163230180740356,0.6477400064468384,1.4740873575210571,50000.0,0.5241000056266785,2.1208341121673584,10000.0,30710.1666765213,32965.72461724281,30710.1666765213,2248.1401748657227,3.823965072631836,0.0 -69300,1.3144403,2.336306,,,,,,,,,,,,,, -69400,1.4722791,4.7587357,,,,,,,,,,,,,, -69500,1.2628679,4.7421145,,,,,,,,,,,,,, -69600,1.2853924,2.2784054,,,,,,,,,,,,,, -69700,1.1276553,4.591569,,,,,,,,,,,,,, -69800,1.2373905,2.6267357,,,,,,,,,,,,,, -69900,1.146632,3.691253,,,,,,,,,,,,,, -70000,1.4123034,2.2343063,,,,,,,,,,,,,, -70100,1.3001926,2.0693948,,,,,,,,,,,,,, -70164,,,0.7193554639816284,1.1367861032485962,0.6459199786186218,1.4766112565994265,50000.0,0.5228000283241272,2.128782272338867,10000.0,31130.051805496216,33417.16694736481,31130.051805496216,2279.395946502685,4.07750391960144,0.0 -70200,1.3411484,2.1540875,,,,,,,,,,,,,, -70300,1.4224309,2.283532,,,,,,,,,,,,,, -70400,1.3377659,2.0530443,,,,,,,,,,,,,, -70500,1.1802082,3.686941,,,,,,,,,,,,,, -70600,1.0936967,3.8555949,,,,,,,,,,,,,, -70700,1.2698898,2.4973524,,,,,,,,,,,,,, -70800,1.109096,3.4747028,,,,,,,,,,,,,, -70900,1.2197746,2.3240495,,,,,,,,,,,,,, -71000,1.1816709,4.360487,,,,,,,,,,,,,, -71100,1.2399579,2.4028053,,,,,,,,,,,,,, -71108,,,0.6978319883346558,1.2037420272827148,0.6468799710273743,1.4458647966384888,50000.0,0.5250000357627869,2.09932017326355,10000.0,31550.3475124836,33871.81460762024,31550.3475124836,2313.651697158813,4.12572455406189,0.0 -71200,1.19985,2.6212773,,,,,,,,,,,,,, -71300,1.3439234,2.6365569,,,,,,,,,,,,,, -71400,1.155005,2.9931042,,,,,,,,,,,,,, -71500,1.3380384,2.1754599,,,,,,,,,,,,,, -71600,1.4556755,2.1495695,,,,,,,,,,,,,, -71700,1.207554,3.3185568,,,,,,,,,,,,,, -71800,1.3010927,3.2447252,,,,,,,,,,,,,, -71900,1.300357,2.0747013,,,,,,,,,,,,,, -72000,1.3222823,2.8751974,,,,,,,,,,,,,, -72058,,,0.7070507407188416,1.153599977493286,0.652899980545044,1.4156707525253296,50000.0,0.5271000266075134,2.0786259174346924,10000.0,31970.288603544235,34322.90913772583,31970.288603544235,2344.7223691940308,4.1603617668151855,0.0 -72100,1.2521437,2.6348956,,,,,,,,,,,,,, -72200,1.1654273,3.2590926,,,,,,,,,,,,,, -72300,1.3855911,2.5972757,,,,,,,,,,,,,, -72400,1.6808773,2.2901351,,,,,,,,,,,,,, -72500,1.2158066,3.281144,,,,,,,,,,,,,, -72600,1.2011666,3.1281426,,,,,,,,,,,,,, -72700,1.4468609,2.3353946,,,,,,,,,,,,,, -72800,1.2139753,2.5891435,,,,,,,,,,,,,, -72900,1.180626,3.4537325,,,,,,,,,,,,,, -73000,1.355932,2.0428,,,,,,,,,,,,,, -73003,,,0.710253894329071,1.1438536643981934,0.6521599888801575,1.4241604804992676,50000.0,0.5306000113487244,2.0731072425842285,10000.0,32390.45892190933,34774.869156360626,32390.45892190933,2376.42147231102,4.202556133270264,0.0 -73100,1.3403221,4.7624116,,,,,,,,,,,,,, -73200,1.4150285,2.2043078,,,,,,,,,,,,,, -73300,1.4114904,2.2066336,,,,,,,,,,,,,, -73400,1.2242731,4.1659517,,,,,,,,,,,,,, -73500,1.1781468,3.2832675,,,,,,,,,,,,,, -73600,1.173655,2.5819032,,,,,,,,,,,,,, -73700,1.2132168,4.788659,,,,,,,,,,,,,, -73800,1.2700658,2.2934277,,,,,,,,,,,,,, -73900,1.4402229,2.3963728,,,,,,,,,,,,,, -73943,,,0.7079882621765137,1.155798077583313,0.6553399562835693,1.4021482467651367,50000.0,0.5343000292778015,2.051135301589966,10000.0,32810.473249197006,35230.20307135582,32810.473249197006,2411.644124746322,4.24935245513916,0.0 -74000,1.3584114,2.172721,,,,,,,,,,,,,, -74100,1.3691158,2.2249992,,,,,,,,,,,,,, -74200,1.3449293,3.2763543,,,,,,,,,,,,,, -74300,1.4852166,2.2525442,,,,,,,,,,,,,, -74400,1.2953619,3.598932,,,,,,,,,,,,,, -74500,1.3815383,2.368235,,,,,,,,,,,,,, -74600,1.3617086,2.4723363,,,,,,,,,,,,,, -74700,1.2544354,2.6507258,,,,,,,,,,,,,, -74800,1.3778336,4.225799,,,,,,,,,,,,,, -74891,,,0.706835925579071,1.19562828540802,0.6556000113487244,1.433531403541565,50000.0,0.5301000475883484,2.0819716453552246,10000.0,33230.81537055969,35684.42928028107,33230.81537055969,2445.44240641594,4.287355422973633,0.0 -74900,1.4261839,2.2196145,,,,,,,,,,,,,, -75000,1.2040572,2.0240278,,,,,,,,,,,,,, -75100,1.4526212,2.1183236,,,,,,,,,,,,,, -75200,1.2165927,4.4476604,,,,,,,,,,,,,, -75300,1.1675339,2.5128262,,,,,,,,,,,,,, -75400,1.3516871,2.0792716,,,,,,,,,,,,,, -75500,1.1587414,4.558337,,,,,,,,,,,,,, -75600,1.3395568,2.248332,,,,,,,,,,,,,, -75700,1.3562249,2.2710266,,,,,,,,,,,,,, -75800,1.494829,2.222955,,,,,,,,,,,,,, -75839,,,0.71205073595047,1.21051287651062,0.6584199666976929,1.4606388807296753,50000.0,0.5286000370979309,2.1042492389678955,10000.0,33651.12571454048,36135.93005084992,33651.12571454048,2476.5385341644287,4.334265470504761,0.0 -75900,1.4450766,2.1731598,,,,,,,,,,,,,, -76000,1.3636205,2.2421799,,,,,,,,,,,,,, -76100,1.2912812,2.533813,,,,,,,,,,,,,, -76200,1.4327532,2.2476993,,,,,,,,,,,,,, -76300,1.2817278,2.269162,,,,,,,,,,,,,, -76400,1.2955806,2.2517827,,,,,,,,,,,,,, -76500,1.3015858,2.2268968,,,,,,,,,,,,,, -76600,1.3750576,2.0952752,,,,,,,,,,,,,, -76700,1.3708298,2.2134936,,,,,,,,,,,,,, -76784,,,0.72572261095047,1.1005898714065552,0.651419997215271,1.4404127597808838,50000.0,0.534000039100647,2.07847261428833,10000.0,34071.293254852295,36591.18041014671,34071.293254852295,2511.526807546616,4.380546808242798,0.0 -76800,1.2338474,4.3119693,,,,,,,,,,,,,, -76900,1.1832852,3.145786,,,,,,,,,,,,,, -77000,1.1850815,4.299781,,,,,,,,,,,,,, -77100,1.3081183,3.0968292,,,,,,,,,,,,,, -77200,1.3315406,2.31608,,,,,,,,,,,,,, -77300,1.3035463,2.1231012,,,,,,,,,,,,,, -77400,1.427366,2.0459402,,,,,,,,,,,,,, -77500,1.4182843,2.1527724,,,,,,,,,,,,,, -77600,1.3637598,2.2674708,,,,,,,,,,,,,, -77700,1.4121224,2.1964815,,,,,,,,,,,,,, -77733,,,0.7066601514816284,1.1821738481521606,0.6551199555397034,1.4311776161193848,50000.0,0.5369000434875488,2.0748746395111084,10000.0,34491.62502241135,37043.41521525383,34491.62502241135,2543.342089414597,4.420172214508057,0.0 -77800,1.3091838,2.5057304,,,,,,,,,,,,,, -77900,1.4067787,2.20725,,,,,,,,,,,,,, -78000,1.4436655,2.0937576,,,,,,,,,,,,,, -78100,1.3198435,2.624492,,,,,,,,,,,,,, -78200,1.1900413,4.5000167,,,,,,,,,,,,,, -78300,1.3019332,3.48514,,,,,,,,,,,,,, -78400,1.3817515,4.371709,,,,,,,,,,,,,, -78500,1.372551,2.150344,,,,,,,,,,,,,, -78600,1.4733427,2.1180143,,,,,,,,,,,,,, -78677,,,0.7102343440055847,1.176645040512085,0.6593799591064453,1.4185059070587158,50000.0,0.532200038433075,2.0688157081604004,10000.0,34911.57203626633,37493.42399716377,34911.57203626633,2573.301589488983,4.47462272644043,0.0 -78700,1.3319943,2.0732422,,,,,,,,,,,,,, -78800,1.232129,3.278234,,,,,,,,,,,,,, -78900,1.4058584,2.103675,,,,,,,,,,,,,, -79000,1.2688981,3.124959,,,,,,,,,,,,,, -79100,1.3858999,2.2252975,,,,,,,,,,,,,, -79200,1.1625772,3.8719945,,,,,,,,,,,,,, -79300,1.1493925,4.5534463,,,,,,,,,,,,,, -79400,1.3122172,2.1786854,,,,,,,,,,,,,, -79500,1.490789,2.0394075,,,,,,,,,,,,,, -79600,1.3139738,3.188773,,,,,,,,,,,,,, -79619,,,0.7199609279632568,1.122406244277954,0.657759964466095,1.4062657356262207,50000.0,0.5317000150680542,2.070014476776123,10000.0,35331.53860926628,37945.161172389984,35331.53860926628,2604.975376367569,4.523173093795776,0.0 -79700,1.2762189,4.3795586,,,,,,,,,,,,,, -79800,1.3086898,3.943211,,,,,,,,,,,,,, -79900,1.3463451,3.213738,,,,,,,,,,,,,, -80000,1.2376176,2.4838414,,,,,,,,,,,,,, -80100,1.4885987,2.1879137,,,,,,,,,,,,,, -80200,1.1995223,2.3451772,,,,,,,,,,,,,, -80300,1.4934747,2.1523724,,,,,,,,,,,,,, -80400,1.3611215,2.7097516,,,,,,,,,,,,,, -80500,1.2256634,4.551468,,,,,,,,,,,,,, -80567,,,0.7152929306030273,1.1523548364639282,0.6591199636459351,1.416941523551941,50000.0,0.5348000526428223,2.05126953125,10000.0,35751.56966614723,38396.780281066895,35751.56966614723,2636.451033353805,4.587049722671509,0.0 -80600,1.1765336,3.0730798,,,,,,,,,,,,,, -80700,1.3679461,4.865703,,,,,,,,,,,,,, -80800,1.269732,2.4749928,,,,,,,,,,,,,, -80900,1.3394806,2.1923506,,,,,,,,,,,,,, -81000,1.353806,2.0875123,,,,,,,,,,,,,, -81100,1.3642775,2.110196,,,,,,,,,,,,,, -81200,1.3475926,2.8932133,,,,,,,,,,,,,, -81300,1.492426,2.2101636,,,,,,,,,,,,,, -81400,1.2356592,4.491831,,,,,,,,,,,,,, -81500,1.5849148,2.0895653,,,,,,,,,,,,,, -81513,,,0.7157031297683716,1.1238621473312378,0.6619799733161926,1.38247811794281,50000.0,0.536300003528595,2.0532751083374023,10000.0,36171.70365715027,38850.550045251846,36171.70365715027,2669.9925594329834,4.633692026138306,0.0 -81600,1.3464744,2.0095482,,,,,,,,,,,,,, -81700,1.3868988,4.8228617,,,,,,,,,,,,,, -81800,1.4237185,2.1599271,,,,,,,,,,,,,, -81900,1.2880614,2.371037,,,,,,,,,,,,,, -82000,1.6122279,2.02867,,,,,,,,,,,,,, -82100,1.4072983,2.0379345,,,,,,,,,,,,,, -82200,1.2489551,2.3563225,,,,,,,,,,,,,, -82300,1.3122587,3.1231213,,,,,,,,,,,,,, -82400,1.520556,2.1381986,,,,,,,,,,,,,, -82463,,,0.7237108945846558,1.1237812042236328,0.6624999642372131,1.4014748334884644,50000.0,0.5412999987602234,2.030993938446045,10000.0,36591.86593723297,39303.57247233391,36591.86593723297,2702.765382766724,4.6731603145599365,0.0 -82500,1.556724,3.9545555,,,,,,,,,,,,,, -82600,1.277069,2.0486643,,,,,,,,,,,,,, -82700,1.2397298,4.640013,,,,,,,,,,,,,, -82800,1.3618132,4.0745993,,,,,,,,,,,,,, -82900,1.4669602,2.123361,,,,,,,,,,,,,, -83000,1.2754754,3.915193,,,,,,,,,,,,,, -83100,1.1862924,3.0754764,,,,,,,,,,,,,, -83200,1.265475,2.7335973,,,,,,,,,,,,,, -83300,1.5499833,2.3749363,,,,,,,,,,,,,, -83400,1.4442561,3.764964,,,,,,,,,,,,,, -83412,,,0.7419531345367432,1.0138866901397705,0.6646199822425842,1.3580873012542725,50000.0,0.5382000207901001,2.0109994411468506,10000.0,37012.14245843887,39753.58320188522,37012.14245843887,2732.4041872024536,4.720138072967529,0.0 -83500,1.4768099,2.133234,,,,,,,,,,,,,, -83600,1.5288482,2.1312578,,,,,,,,,,,,,, -83700,1.4400706,2.1939917,,,,,,,,,,,,,, -83800,1.3244228,2.3785534,,,,,,,,,,,,,, -83900,1.3053159,2.8669121,,,,,,,,,,,,,, -84000,1.625798,2.0389597,,,,,,,,,,,,,, -84100,1.4602202,2.236606,,,,,,,,,,,,,, -84200,1.235983,3.723238,,,,,,,,,,,,,, -84300,1.2089097,3.1327214,,,,,,,,,,,,,, -84356,,,0.721484363079071,1.12127685546875,0.6665599942207336,1.3685060739517212,50000.0,0.5446000099182129,2.0185341835021973,10000.0,37432.25945806503,40205.688380241394,37432.25945806503,2764.3009791374207,4.763558626174927,0.0 -84400,1.5162771,2.1542885,,,,,,,,,,,,,, -84500,1.5271966,4.7284784,,,,,,,,,,,,,, -84600,1.5003221,2.0637846,,,,,,,,,,,,,, -84700,1.3763044,1.9247334,,,,,,,,,,,,,, -84800,1.1859571,3.169907,,,,,,,,,,,,,, -84900,1.4005473,2.563088,,,,,,,,,,,,,, -85000,1.2516418,2.896484,,,,,,,,,,,,,, -85100,1.319826,2.0836184,,,,,,,,,,,,,, -85200,1.3511953,2.2756393,,,,,,,,,,,,,, -85300,1.5855125,2.3343134,,,,,,,,,,,,,, -85303,,,0.7251952886581421,1.107918620109558,0.6675199866294861,1.3741096258163452,50000.0,0.5424000024795532,2.0130226612091064,10000.0,37852.270012140274,40656.39597964287,37852.270012140274,2794.892049312592,4.821456909179688,0.0 -85400,1.257396,2.4804077,,,,,,,,,,,,,, -85500,1.5144008,2.0596538,,,,,,,,,,,,,, -85600,1.5478963,2.1003613,,,,,,,,,,,,,, -85700,1.6928307,2.128552,,,,,,,,,,,,,, -85800,1.1704578,3.3523872,,,,,,,,,,,,,, -85900,1.4113667,2.3647995,,,,,,,,,,,,,, -86000,1.5086324,2.1134,,,,,,,,,,,,,, -86100,1.4430585,2.2053788,,,,,,,,,,,,,, -86200,1.6599091,4.6151333,,,,,,,,,,,,,, -86251,,,0.7322655916213989,1.052381992340088,0.6670399904251099,1.3552080392837524,50000.0,0.5398000478744507,2.01186466217041,10000.0,38272.54418039322,41106.30424046517,38272.54418039322,2824.427891969681,4.871694087982178,0.0 -86300,1.4846041,2.0537527,,,,,,,,,,,,,, -86400,1.3979751,2.041125,,,,,,,,,,,,,, -86500,1.4721932,2.3037736,,,,,,,,,,,,,, -86600,1.3519669,2.3594172,,,,,,,,,,,,,, -86700,1.5353122,2.0858727,,,,,,,,,,,,,, -86800,1.3193944,2.3876243,,,,,,,,,,,,,, -86900,1.3855017,1.9450425,,,,,,,,,,,,,, -87000,1.269413,3.6829975,,,,,,,,,,,,,, -87100,1.2545252,3.597194,,,,,,,,,,,,,, -87195,,,0.7362109422683716,1.0454906225204468,0.6697999835014343,1.338555097579956,50000.0,0.5415000319480896,2.000982999801636,10000.0,38692.73499083519,41562.39662957192,38692.73499083519,2860.237210035324,4.915456771850586,0.0 -87200,1.5171939,2.0272484,,,,,,,,,,,,,, -87300,1.4049797,4.67992,,,,,,,,,,,,,, -87400,1.4751332,2.4946876,,,,,,,,,,,,,, -87500,1.5622791,2.1786835,,,,,,,,,,,,,, -87600,1.4310292,1.9751672,,,,,,,,,,,,,, -87700,1.335889,2.2741964,,,,,,,,,,,,,, -87800,1.2962627,2.4251766,,,,,,,,,,,,,, -87900,1.5073612,2.101936,,,,,,,,,,,,,, -88000,1.3310728,3.4808166,,,,,,,,,,,,,, -88100,1.3096493,2.5916219,,,,,,,,,,,,,, -88144,,,0.7266015410423279,1.0895438194274902,0.6708799600601196,1.3421494960784912,50000.0,0.5432000160217285,1.992767691612244,10000.0,39112.82124638557,42018.42330169678,39112.82124638557,2896.091214418412,4.953283309936523,0.0 -88200,1.4584138,2.0318353,,,,,,,,,,,,,, -88300,1.2889768,4.5822115,,,,,,,,,,,,,, -88400,1.4859732,2.0352917,,,,,,,,,,,,,, -88500,1.4039315,4.57947,,,,,,,,,,,,,, -88600,1.4677918,3.0009382,,,,,,,,,,,,,, -88700,1.3981637,3.630652,,,,,,,,,,,,,, -88800,1.3475606,4.2537775,,,,,,,,,,,,,, -88900,1.3312888,4.439172,,,,,,,,,,,,,, -89000,1.6365463,2.0705965,,,,,,,,,,,,,, -89093,,,0.7369921803474426,1.0429660081863403,0.6749799847602844,1.325405240058899,50000.0,0.5485000014305115,1.9791725873947144,10000.0,39532.90726733208,42469.43255209923,39532.90726733208,2926.929531097412,4.990481376647949,0.0 -89100,1.403534,4.604636,,,,,,,,,,,,,, -89200,1.3315471,4.170901,,,,,,,,,,,,,, -89300,1.2643605,3.0705297,,,,,,,,,,,,,, -89400,1.5315604,2.0114872,,,,,,,,,,,,,, -89500,1.4303269,1.9394635,,,,,,,,,,,,,, -89600,1.4123098,1.9814725,,,,,,,,,,,,,, -89700,1.2847854,2.8876507,,,,,,,,,,,,,, -89800,1.2625797,3.2812326,,,,,,,,,,,,,, -89900,1.3135531,4.704539,,,,,,,,,,,,,, -90000,1.4567964,1.9857635,,,,,,,,,,,,,, -90040,,,0.748828113079071,1.0013405084609983,0.674340009689331,1.3295317888259888,50000.0,0.5498000383377075,1.9649418592453003,10000.0,39953.20360445976,42921.797057151794,39953.20360445976,2958.897340297699,5.042929649353027,0.0 -90100,1.550402,4.589543,,,,,,,,,,,,,, -90200,1.2804059,3.0566576,,,,,,,,,,,,,, -90300,1.3279544,3.0995977,,,,,,,,,,,,,, -90400,1.3090824,2.1011076,,,,,,,,,,,,,, -90500,1.6081598,2.041541,,,,,,,,,,,,,, -90600,1.2291695,3.047017,,,,,,,,,,,,,, -90700,1.4783444,2.1742716,,,,,,,,,,,,,, -90800,1.3506325,1.9756885,,,,,,,,,,,,,, -90900,1.491603,1.8993038,,,,,,,,,,,,,, -90985,,,0.7356249690055847,1.0418883562088013,0.6781600117683411,1.3062312602996826,50000.0,0.5527999997138977,1.9561856985092163,10000.0,40373.19947838783,43374.818145513535,40373.19947838783,2991.8312034606934,5.086370944976807,0.0 -91000,1.3143837,2.0624056,,,,,,,,,,,,,, -91100,1.3356457,4.615157,,,,,,,,,,,,,, -91200,1.5134538,2.048848,,,,,,,,,,,,,, -91300,1.474497,2.2091746,,,,,,,,,,,,,, -91400,1.5342118,2.125774,,,,,,,,,,,,,, -91500,1.391832,2.0224905,,,,,,,,,,,,,, -91600,1.36874,2.856839,,,,,,,,,,,,,, -91700,1.4329567,1.9198098,,,,,,,,,,,,,, -91800,1.4483004,2.23429,,,,,,,,,,,,,, -91900,1.7028403,1.9760029,,,,,,,,,,,,,, -91932,,,0.7396484017372131,1.0154486894607544,0.6789000034332275,1.2950810194015503,50000.0,0.5508000254631042,1.9430943727493288,10000.0,40793.36383152008,43825.52984857559,40793.36383152008,3022.284590482712,5.132043361663818,0.0 -92000,1.681025,2.0762942,,,,,,,,,,,,,, -92100,1.4429471,2.105249,,,,,,,,,,,,,, -92200,1.5964439,1.9469903,,,,,,,,,,,,,, -92300,1.3414518,3.4439096,,,,,,,,,,,,,, -92400,1.4039645,3.7278333,,,,,,,,,,,,,, -92500,1.467053,2.2281208,,,,,,,,,,,,,, -92600,1.4862957,2.0029655,,,,,,,,,,,,,, -92700,1.4742846,1.9206789,,,,,,,,,,,,,, -92800,1.3283795,2.4224894,,,,,,,,,,,,,, -92877,,,0.7428515553474426,1.0198140144348145,0.6730799674987793,1.332512617111206,50000.0,0.5501000285148621,1.9638128280639648,10000.0,41213.46935200691,44279.05270528793,41213.46935200691,3055.57865357399,5.207129716873169,0.0 -92900,1.3717241,4.4734077,,,,,,,,,,,,,, -93000,1.5255132,2.050179,,,,,,,,,,,,,, -93100,1.3933674,2.148676,,,,,,,,,,,,,, -93200,1.6069175,2.1107674,,,,,,,,,,,,,, -93300,1.3793682,4.2477226,,,,,,,,,,,,,, -93400,1.5580047,1.8725824,,,,,,,,,,,,,, -93500,1.5539533,1.968339,,,,,,,,,,,,,, -93600,1.264495,2.9867842,,,,,,,,,,,,,, -93700,1.469814,2.564003,,,,,,,,,,,,,, -93800,1.406971,4.165303,,,,,,,,,,,,,, -93825,,,0.7576562166213989,0.9598555564880372,0.6772800087928772,1.3212695121765137,50000.0,0.5507000088691711,1.9684770107269287,10000.0,41633.70737886429,44729.89771127701,41633.70737886429,3086.0911548137665,5.252807855606079,0.0 -93900,1.5120971,1.995661,,,,,,,,,,,,,, -94000,1.545041,2.1248972,,,,,,,,,,,,,, -94100,1.3886887,2.366074,,,,,,,,,,,,,, -94200,1.476544,2.02209,,,,,,,,,,,,,, -94300,1.439817,2.1697512,,,,,,,,,,,,,, -94400,1.6000088,2.002839,,,,,,,,,,,,,, -94500,1.4470897,2.3106182,,,,,,,,,,,,,, -94600,1.5327383,1.9276891,,,,,,,,,,,,,, -94700,1.4525789,2.1182516,,,,,,,,,,,,,, -94770,,,0.7390429377555847,1.0380719900131226,0.6790399551391602,1.304970145225525,50000.0,0.554900050163269,1.955616116523743,10000.0,42053.8782582283,45182.18700695038,42053.8782582283,3118.114638566971,5.299704313278198,0.0 -94800,1.5815988,2.0894094,,,,,,,,,,,,,, -94900,1.6624892,2.024401,,,,,,,,,,,,,, -95000,1.4600924,1.8471808,,,,,,,,,,,,,, -95100,1.4668998,4.169058,,,,,,,,,,,,,, -95200,1.5478841,2.3056774,,,,,,,,,,,,,, -95300,1.3002739,2.8576505,,,,,,,,,,,,,, -95400,1.555437,2.1144137,,,,,,,,,,,,,, -95500,1.45591,2.0938258,,,,,,,,,,,,,, -95600,1.6063575,4.432007,,,,,,,,,,,,,, -95700,1.3395628,2.6895258,,,,,,,,,,,,,, -95715,,,0.7507030963897705,1.004400610923767,0.6848599910736084,1.291728377342224,50000.0,0.5614000558853149,1.9297590255737305,10000.0,42474.13750100136,45633.44561266899,42474.13750100136,3149.01988363266,5.345886945724487,0.0 -95800,1.474938,1.9237733,,,,,,,,,,,,,, -95900,1.4354088,2.0685377,,,,,,,,,,,,,, -96000,1.4514138,3.4917612,,,,,,,,,,,,,, -96100,1.4612743,1.9381384,,,,,,,,,,,,,, -96200,1.4640399,4.4947886,,,,,,,,,,,,,, -96300,1.2446631,3.4383106,,,,,,,,,,,,,, -96400,1.5365154,4.4652486,,,,,,,,,,,,,, -96500,1.540979,1.8745164,,,,,,,,,,,,,, -96600,1.4998132,2.306901,,,,,,,,,,,,,, -96661,,,0.7564648389816284,0.9674227833747864,0.6837799549102783,1.2896058559417725,50000.0,0.5568000078201294,1.937228202819824,10000.0,42894.14020085335,46086.63012123108,42894.14020085335,3182.1036455631256,5.395781517028809,0.0 -96700,1.4589868,4.3794937,,,,,,,,,,,,,, -96800,1.4700644,1.979635,,,,,,,,,,,,,, -96900,1.5260173,2.5383704,,,,,,,,,,,,,, -97000,1.5287415,1.9107797,,,,,,,,,,,,,, -97100,1.4794253,2.2785718,,,,,,,,,,,,,, -97200,1.3370537,3.737619,,,,,,,,,,,,,, -97300,1.4436917,4.4159074,,,,,,,,,,,,,, -97400,1.5986278,1.9295615,,,,,,,,,,,,,, -97500,1.2476156,3.8604863,,,,,,,,,,,,,, -97600,1.3699524,3.15073,,,,,,,,,,,,,, -97611,,,0.7490038871765137,0.9923749566078186,0.6890199780464172,1.2585322856903076,50000.0,0.5580000281333923,1.915615677833557,10000.0,43314.50114059448,46537.77447581291,43314.50114059448,3212.790130376816,5.443286895751953,0.0 -97700,1.2957172,3.7546115,,,,,,,,,,,,,, -97800,1.6159225,3.8262262,,,,,,,,,,,,,, -97900,1.6627295,2.0956864,,,,,,,,,,,,,, -98000,1.4242332,3.9277692,,,,,,,,,,,,,, -98100,1.3624297,2.6123767,,,,,,,,,,,,,, -98200,1.3426802,2.2402303,,,,,,,,,,,,,, -98300,1.6376823,1.9642694,,,,,,,,,,,,,, -98400,1.5167122,2.2614176,,,,,,,,,,,,,, -98500,1.5420053,2.04612,,,,,,,,,,,,,, -98558,,,0.7421093583106995,1.0115565061569214,0.6842199563980103,1.2888301610946655,50000.0,0.5594000220298767,1.939386248588562,10000.0,43734.45563292503,46989.04209041596,43734.45563292503,3244.002207517624,5.495993137359619,0.0 -98600,1.3175027,3.2349086,,,,,,,,,,,,,, -98700,1.3178302,2.6522772,,,,,,,,,,,,,, -98800,1.4506118,2.0259538,,,,,,,,,,,,,, -98900,1.3705617,2.4184573,,,,,,,,,,,,,, -99000,1.4710068,4.1325016,,,,,,,,,,,,,, -99100,1.3746116,3.8519778,,,,,,,,,,,,,, -99200,1.534021,4.4767866,,,,,,,,,,,,,, -99300,1.647866,2.0010197,,,,,,,,,,,,,, -99400,1.457682,1.8203663,,,,,,,,,,,,,, -99500,1.399045,4.3262157,,,,,,,,,,,,,, -99504,,,0.7502343654632568,0.9732296466827391,0.6872599720954895,1.2587881088256836,50000.0,0.5587000250816345,1.922590732574463,10000.0,44154.56234765053,47441.99556803703,44154.56234765053,3276.752377271652,5.544393062591553,0.0 -99600,1.5512494,1.9729683,,,,,,,,,,,,,, -99700,1.2354367,3.2976494,,,,,,,,,,,,,, -99800,1.3603238,2.6900148,,,,,,,,,,,,,, -99900,1.3881004,2.2547286,,,,,,,,,,,,,, -100000,1.3457705,2.5132883,,,,,,,,,,,,,, -100100,1.6296402,1.940679,,,,,,,,,,,,,, -100200,1.5486937,4.476404,,,,,,,,,,,,,, -100300,1.6820202,1.8583412,,,,,,,,,,,,,, -100400,1.3207616,3.0316672,,,,,,,,,,,,,, -100452,,,0.7751367092132568,0.8857801556587219,0.6929799914360046,1.2570501565933228,50000.0,0.5685000419616699,1.8823338747024536,10000.0,44574.62707424164,47892.88826847077,44574.62707424164,3307.483034849167,5.593884229660034,0.0 -100500,1.5094532,4.111455,,,,,,,,,,,,,, -100600,1.4544817,3.4411108,,,,,,,,,,,,,, -100700,1.4875336,2.6457114,,,,,,,,,,,,,, -100800,1.3146598,3.0875826,,,,,,,,,,,,,, -100900,1.4408194,1.943955,,,,,,,,,,,,,, -101000,1.6449349,2.0064812,,,,,,,,,,,,,, -101100,1.3635883,4.391857,,,,,,,,,,,,,, -101200,1.5039015,2.2220001,,,,,,,,,,,,,, -101300,1.4552782,3.5065165,,,,,,,,,,,,,, -101397,,,0.7506640553474426,0.966569483280182,0.6922799944877625,1.245205640792847,50000.0,0.5728000402450562,1.8757268190383911,10000.0,44994.80569982529,48345.53457951546,44994.80569982529,3339.8508801460266,5.645975828170776,0.0 -101400,1.5894852,2.0588622,,,,,,,,,,,,,, -101500,1.5216613,1.9744699,,,,,,,,,,,,,, -101600,1.4824021,2.0702634,,,,,,,,,,,,,, -101700,1.48648,2.3956041,,,,,,,,,,,,,, -101800,1.3063155,2.491818,,,,,,,,,,,,,, -101900,1.4584194,1.7572935,,,,,,,,,,,,,, -102000,1.7505637,4.534113,,,,,,,,,,,,,, -102100,1.787077,2.0206082,,,,,,,,,,,,,, -102200,1.603656,2.0349455,,,,,,,,,,,,,, -102300,1.4522696,1.991729,,,,,,,,,,,,,, -102342,,,0.7575390338897705,0.9569990038871764,0.6941800117492676,1.2485471963882446,50000.0,0.5701000094413757,1.8743760585784912,10000.0,45415.07636475563,48801.93915319443,45415.07636475563,3375.8747539520264,5.707857370376587,0.0 -102400,1.5449935,1.8464158,,,,,,,,,,,,,, -102500,1.4231316,2.295156,,,,,,,,,,,,,, -102600,1.5243214,2.0182843,,,,,,,,,,,,,, -102700,1.4775162,4.181538,,,,,,,,,,,,,, -102800,1.4237622,2.3675478,,,,,,,,,,,,,, -102900,1.4563974,2.161204,,,,,,,,,,,,,, -103000,1.4890574,2.1621206,,,,,,,,,,,,,, -103100,1.4792012,1.893227,,,,,,,,,,,,,, -103200,1.5382141,1.8642635,,,,,,,,,,,,,, -103292,,,0.7648046612739563,0.9258500933647156,0.6905399560928345,1.2477099895477295,50000.0,0.5659000277519226,1.897047519683838,10000.0,45835.02481293678,49253.074348926544,45835.02481293678,3406.97397518158,5.746920824050903,0.0 -103300,1.7189525,2.0126016,,,,,,,,,,,,,, -103400,1.493799,1.8075377,,,,,,,,,,,,,, -103500,1.4984813,4.4537845,,,,,,,,,,,,,, -103600,1.4801273,2.813842,,,,,,,,,,,,,, -103700,1.6773971,4.439361,,,,,,,,,,,,,, -103800,1.3892584,3.0231886,,,,,,,,,,,,,, -103900,1.6313007,1.9661368,,,,,,,,,,,,,, -104000,1.4699298,1.8819253,,,,,,,,,,,,,, -104100,1.6298127,1.8078673,,,,,,,,,,,,,, -104200,1.5325856,2.005899,,,,,,,,,,,,,, -104239,,,0.7541210651397705,0.9695034027099608,0.6918599605560303,1.2417765855789185,50000.0,0.5667000412940979,1.8755282163619995,10000.0,46255.20332431793,49706.23113465309,46255.20332431793,3439.858054637909,5.793359994888306,0.0 -104300,1.4431275,2.8017468,,,,,,,,,,,,,, -104400,1.5056864,1.8822145,,,,,,,,,,,,,, -104500,1.6610098,1.8045645,,,,,,,,,,,,,, -104600,1.5867984,2.0138214,,,,,,,,,,,,,, -104700,1.4171779,3.146069,,,,,,,,,,,,,, -104800,1.4090679,2.3949947,,,,,,,,,,,,,, -104900,1.5225911,1.9746269,,,,,,,,,,,,,, -105000,1.5251364,2.8000226,,,,,,,,,,,,,, -105100,1.5335236,1.8112004,,,,,,,,,,,,,, -105187,,,0.7598632574081421,0.940223515033722,0.6953999996185303,1.232001543045044,50000.0,0.5755000114440918,1.8609957695007324,10000.0,46675.59752130509,50160.85048317909,46675.59752130509,3473.9791502952576,5.849823951721191,0.0 -105200,1.5950801,2.0598435,,,,,,,,,,,,,, -105300,1.6637992,1.8882015,,,,,,,,,,,,,, -105400,1.6771754,4.5042214,,,,,,,,,,,,,, -105500,1.4125056,3.6980252,,,,,,,,,,,,,, -105600,1.6448224,3.6640208,,,,,,,,,,,,,, -105700,1.7414163,1.8060105,,,,,,,,,,,,,, -105800,1.450404,3.368815,,,,,,,,,,,,,, -105900,1.5823137,4.1463923,,,,,,,,,,,,,, -106000,1.5223194,1.7942014,,,,,,,,,,,,,, -106100,1.7315034,1.869436,,,,,,,,,,,,,, -106136,,,0.7672656178474426,0.9064086079597472,0.6967399716377258,1.219843864440918,50000.0,0.5692000389099121,1.868541240692139,10000.0,47095.5532104969,50613.21381640434,47095.5532104969,3506.291650056839,5.897094011306763,0.0 -106200,1.6310126,1.8478796,,,,,,,,,,,,,, -106300,1.602339,1.8361102,,,,,,,,,,,,,, -106400,1.5499123,4.296976,,,,,,,,,,,,,, -106500,1.828645,1.8530773,,,,,,,,,,,,,, -106600,1.5939769,3.2945554,,,,,,,,,,,,,, -106700,1.4774256,3.9643428,,,,,,,,,,,,,, -106800,1.5852145,1.8840181,,,,,,,,,,,,,, -106900,1.9899539,1.9229883,,,,,,,,,,,,,, -107000,1.4032983,2.127353,,,,,,,,,,,,,, -107078,,,0.7860937118530273,0.841831624507904,0.6953200101852417,1.229534387588501,50000.0,0.5711000561714172,1.8621702194213867,10000.0,47515.50264263153,51066.089199543,47515.50264263153,3539.1190111637115,5.948194980621338,0.0 -107100,1.5057693,1.8257692,,,,,,,,,,,,,, -107200,1.4680809,4.2716837,,,,,,,,,,,,,, -107300,1.8376385,1.9798732,,,,,,,,,,,,,, -107400,1.5528797,1.9665854,,,,,,,,,,,,,, -107500,1.514388,3.508938,,,,,,,,,,,,,, -107600,1.631568,4.255771,,,,,,,,,,,,,, -107700,1.4993532,1.8262203,,,,,,,,,,,,,, -107800,1.573593,3.742515,,,,,,,,,,,,,, -107900,1.4002968,2.7788887,,,,,,,,,,,,,, -108000,1.6299897,3.1405053,,,,,,,,,,,,,, -108021,,,0.7554101347923279,0.9594653844833374,0.6963199973106384,1.2319732904434204,50000.0,0.5750000476837158,1.8669058084487915,10000.0,47935.81768369675,51518.20453286171,47935.81768369675,3570.824383974076,5.9956605434417725,0.0 -108100,1.7012072,1.8702002,,,,,,,,,,,,,, -108200,1.6204482,1.7339302,,,,,,,,,,,,,, -108300,1.6224017,2.2843528,,,,,,,,,,,,,, -108400,1.5449271,1.8941834,,,,,,,,,,,,,, -108500,1.4885691,2.0792954,,,,,,,,,,,,,, -108600,1.574553,1.9148378,,,,,,,,,,,,,, -108700,1.4752716,4.204311,,,,,,,,,,,,,, -108800,1.5375226,1.9109743,,,,,,,,,,,,,, -108900,1.6784362,1.8416039,,,,,,,,,,,,,, -108967,,,0.7698437571525574,0.8918425440788269,0.7005800008773804,1.197978138923645,50000.0,0.5755000114440918,1.845165491104126,10000.0,48355.86082482338,51972.50933337212,48355.86082482338,3604.986275196075,6.047371864318848,0.0 -109000,1.4544681,2.7702584,,,,,,,,,,,,,, -109100,1.523913,1.7730905,,,,,,,,,,,,,, -109200,1.6813774,3.7770274,,,,,,,,,,,,,, -109300,1.5884045,2.2964926,,,,,,,,,,,,,, -109400,1.5599793,2.7571015,,,,,,,,,,,,,, -109500,1.7532358,2.0812829,,,,,,,,,,,,,, -109600,1.6963875,1.8619617,,,,,,,,,,,,,, -109700,1.6450652,2.265211,,,,,,,,,,,,,, -109800,1.8425702,1.851151,,,,,,,,,,,,,, -109900,1.5771087,3.414869,,,,,,,,,,,,,, -109914,,,0.7803710699081421,0.8482807874679565,0.7021200060844421,1.1901153326034546,50000.0,0.579300045967102,1.835110664367676,10000.0,48775.94476270676,52425.07089591026,48775.94476270676,3637.372734069824,6.091373920440674,0.0 -110000,1.7141913,1.7941484,,,,,,,,,,,,,, -110100,1.7395635,2.0051675,,,,,,,,,,,,,, -110200,1.5773989,3.0482273,,,,,,,,,,,,,, -110300,1.5610039,4.2901125,,,,,,,,,,,,,, -110400,1.5622475,4.232313,,,,,,,,,,,,,, -110500,1.7327769,3.3209922,,,,,,,,,,,,,, -110600,1.4496218,3.6764388,,,,,,,,,,,,,, -110700,1.4446172,2.4014454,,,,,,,,,,,,,, -110800,1.7034369,3.042827,,,,,,,,,,,,,, -110859,,,0.7680468559265137,0.9069384336471558,0.7002800107002258,1.197730302810669,50000.0,0.5750000476837158,1.831464886665344,10000.0,49195.93100190163,52876.01609253883,49195.93100190163,3668.2350981235504,6.1399030685424805,0.0 -110900,1.5934889,3.7469623,,,,,,,,,,,,,, -111000,1.5111423,2.7562194,,,,,,,,,,,,,, -111100,1.470064,3.6172,,,,,,,,,,,,,, -111200,1.4952625,3.2154489,,,,,,,,,,,,,, -111300,1.5881386,2.5304167,,,,,,,,,,,,,, -111400,1.498418,3.508635,,,,,,,,,,,,,, -111500,1.8059748,1.8451364,,,,,,,,,,,,,, -111600,1.5478654,2.8876069,,,,,,,,,,,,,, -111700,1.5999172,2.6316447,,,,,,,,,,,,,, -111800,1.678009,1.7464995,,,,,,,,,,,,,, -111804,,,0.7716405987739563,0.9013556241989136,0.7028999924659729,1.2098357677459717,50000.0,0.5737000107765198,1.8351027965545648,10000.0,49616.14709401131,53327.60908675194,49616.14709401131,3699.5143172740936,6.189666271209717,0.0 -111900,1.6431392,1.8445135,,,,,,,,,,,,,, -112000,1.5793171,1.838198,,,,,,,,,,,,,, -112100,1.5836673,4.236927,,,,,,,,,,,,,, -112200,1.9121541,2.4651423,,,,,,,,,,,,,, -112300,1.6762385,2.1422768,,,,,,,,,,,,,, -112400,1.516289,3.316108,,,,,,,,,,,,,, -112500,1.9815555,4.266105,,,,,,,,,,,,,, -112600,1.6156278,1.7405889,,,,,,,,,,,,,, -112700,1.9559608,4.035251,,,,,,,,,,,,,, -112748,,,0.7766015529632568,0.8739011883735657,0.7063800096511841,1.1852818727493286,50000.0,0.5814000368118286,1.8124264478683472,10000.0,50036.39476442337,53783.87377977371,50036.39476442337,3735.432944059372,6.240252494812012,0.0 -112800,1.6658556,4.2743716,,,,,,,,,,,,,, -112900,1.5089626,3.6181722,,,,,,,,,,,,,, -113000,1.6459438,1.7042075,,,,,,,,,,,,,, -113100,1.6140534,1.681839,,,,,,,,,,,,,, -113200,1.6784537,1.787811,,,,,,,,,,,,,, -113300,1.5365064,3.2775037,,,,,,,,,,,,,, -113400,1.7298138,1.9813664,,,,,,,,,,,,,, -113500,1.5159104,3.3352866,,,,,,,,,,,,,, -113600,1.5393828,3.4332278,,,,,,,,,,,,,, -113693,,,0.7954491972923279,0.7992606163024902,0.7044999599456787,1.186686635017395,50000.0,0.589900016784668,1.7972122430801392,10000.0,50456.44828510285,54236.47974801064,50456.44828510285,3767.642267227173,6.535725593566895,0.0 -113700,1.7019758,3.6914864,,,,,,,,,,,,,, -113800,1.7324071,1.8032458,,,,,,,,,,,,,, -113900,1.5853499,2.623421,,,,,,,,,,,,,, -114000,1.8364378,1.8542614,,,,,,,,,,,,,, -114100,1.5918139,1.841449,,,,,,,,,,,,,, -114200,1.6553669,1.8389447,,,,,,,,,,,,,, -114300,1.8660874,2.6777086,,,,,,,,,,,,,, -114400,1.6266139,4.1095624,,,,,,,,,,,,,, -114500,1.5595345,2.2305036,,,,,,,,,,,,,, -114600,1.8737994,4.264323,,,,,,,,,,,,,, -114632,,,0.7786718606948853,0.8812108635902405,0.7083799839019775,1.18280029296875,50000.0,0.5830000042915344,1.8008441925048828,10000.0,50876.50444483757,54688.73802089691,50876.50444483757,3799.7460713386536,6.585723876953125,0.0 -114700,1.5925257,1.8916509,,,,,,,,,,,,,, -114800,1.510462,3.2555432,,,,,,,,,,,,,, -114900,1.9674275,4.33809,,,,,,,,,,,,,, -115000,1.6847962,1.8811524,,,,,,,,,,,,,, -115100,1.5642829,2.0539174,,,,,,,,,,,,,, -115200,1.7276193,3.9543767,,,,,,,,,,,,,, -115300,1.8395599,1.8479936,,,,,,,,,,,,,, -115400,1.4226986,2.7068238,,,,,,,,,,,,,, -115500,1.6729751,1.8646225,,,,,,,,,,,,,, -115575,,,0.7844336032867432,0.8428012132644653,0.7099199891090393,1.1657264232635498,50000.0,0.5892000198364258,1.7851496934890747,10000.0,51296.428337574005,55139.81853270531,51296.428337574005,3830.8059952259055,6.634954214096069,0.0 -115600,1.699764,3.0922143,,,,,,,,,,,,,, -115700,1.6225861,1.9135859,,,,,,,,,,,,,, -115800,1.7368408,1.7938573,,,,,,,,,,,,,, -115900,1.6345211,1.8041686,,,,,,,,,,,,,, -116000,1.490122,3.450506,,,,,,,,,,,,,, -116100,1.8132043,4.0064917,,,,,,,,,,,,,, -116200,1.658125,3.6860185,,,,,,,,,,,,,, -116300,1.8629642,1.8474009,,,,,,,,,,,,,, -116400,1.683202,3.463297,,,,,,,,,,,,,, -116500,1.9265575,1.7494936,,,,,,,,,,,,,, -116520,,,0.7871679663658142,0.8348548412322998,0.7083399891853333,1.1812052726745603,50000.0,0.5809000134468079,1.824753761291504,10000.0,51716.66536331177,55592.89674210549,51716.66536331177,3863.548504590988,6.685975074768066,0.0 -116600,1.7510523,1.8154515,,,,,,,,,,,,,, -116700,1.6567088,1.7692862,,,,,,,,,,,,,, -116800,1.6586316,1.7763457,,,,,,,,,,,,,, -116900,1.8412609,1.9955148,,,,,,,,,,,,,, -117000,1.6212813,3.8653235,,,,,,,,,,,,,, -117100,1.5428454,1.8758551,,,,,,,,,,,,,, -117200,1.6070393,2.526783,,,,,,,,,,,,,, -117300,1.6535372,1.6816256,,,,,,,,,,,,,, -117400,1.8048536,3.8249116,,,,,,,,,,,,,, -117464,,,0.7781640291213989,0.8482047319412231,0.7140199542045593,1.1441729068756104,50000.0,0.5915000438690186,1.7694329023361206,10000.0,52136.707661151886,56043.724599123,52136.707661151886,3894.236728906632,6.735541582107544,0.0 -117500,1.9284474,1.775527,,,,,,,,,,,,,, -117600,1.7127746,1.9321922,,,,,,,,,,,,,, -117700,1.6143366,2.5041876,,,,,,,,,,,,,, -117800,1.6817411,2.8375473,,,,,,,,,,,,,, -117900,1.5978187,2.101124,,,,,,,,,,,,,, -118000,1.6506989,1.7526416,,,,,,,,,,,,,, -118100,1.6866633,1.7086064,,,,,,,,,,,,,, -118200,1.7471019,4.1537976,,,,,,,,,,,,,, -118300,2.1616983,1.8094428,,,,,,,,,,,,,, -118400,1.726933,1.6561779,,,,,,,,,,,,,, -118406,,,0.7785351276397705,0.8815587162971497,0.709119975566864,1.1885517835617063,50000.0,0.586400032043457,1.818394660949707,10000.0,52556.71144080162,56495.050055503845,52556.71144080162,3925.459734916687,6.7867701053619385,0.0 -118500,1.6524317,1.7819749,,,,,,,,,,,,,, -118600,1.7269471,1.7713395,,,,,,,,,,,,,, -118700,1.7232157,3.6720364,,,,,,,,,,,,,, -118800,1.8768867,4.0811834,,,,,,,,,,,,,, -118900,1.7238387,1.8030374,,,,,,,,,,,,,, -119000,1.7331364,1.7723933,,,,,,,,,,,,,, -119100,1.721258,3.1794088,,,,,,,,,,,,,, -119200,1.6129305,3.0497687,,,,,,,,,,,,,, -119300,1.7443713,1.7761337,,,,,,,,,,,,,, -119351,,,0.7892773151397705,0.8121300935745239,0.7165200114250183,1.1363714933395386,50000.0,0.5924000144004822,1.7503550052642822,10000.0,52976.91713619232,56946.94454836845,52976.91713619232,3957.0508601665497,6.83685827255249,0.0 -119400,2.0035713,4.298068,,,,,,,,,,,,,, -119500,1.7795626,1.783211,,,,,,,,,,,,,, -119600,1.6457039,3.6992507,,,,,,,,,,,,,, -119700,1.7949739,2.0489855,,,,,,,,,,,,,, -119800,1.7203139,1.932061,,,,,,,,,,,,,, -119900,1.9194899,1.6853813,,,,,,,,,,,,,, -120000,1.7251828,2.0225544,,,,,,,,,,,,,, -120100,1.9780809,3.887326,,,,,,,,,,,,,, -120200,1.8432242,1.7861124,,,,,,,,,,,,,, -120296,,,0.8076757788658142,0.7323868274688721,0.7183399796485901,1.1285380125045776,50000.0,0.5946000218391418,1.750155329704285,10000.0,53397.10926628113,57400.011506319046,53397.10926628113,3989.827569723129,6.887446403503418,0.0 -120300,1.6943889,1.7547843,,,,,,,,,,,,,, -120400,1.6884389,1.7867218,,,,,,,,,,,,,, -120500,1.7550378,1.7003317,,,,,,,,,,,,,, -120600,1.9157263,2.5933611,,,,,,,,,,,,,, -120700,1.6681799,2.065671,,,,,,,,,,,,,, -120800,1.6923367,3.3004284,,,,,,,,,,,,,, -120900,1.667635,2.9356332,,,,,,,,,,,,,, -121000,1.8778638,3.4853878,,,,,,,,,,,,,, -121100,1.6710892,2.215318,,,,,,,,,,,,,, -121200,1.8906631,1.7632409,,,,,,,,,,,,,, -121243,,,0.7827538847923279,0.8527898192405701,0.7153399586677551,1.1425896883010864,50000.0,0.596500039100647,1.758715271949768,10000.0,53817.029266119,57855.23564505577,53817.029266119,4025.02843785286,6.941509246826172,0.0 -121300,1.7643749,1.7328607,,,,,,,,,,,,,, -121400,1.836906,1.711704,,,,,,,,,,,,,, -121500,1.6491289,1.6668478,,,,,,,,,,,,,, -121600,1.6647278,2.2018015,,,,,,,,,,,,,, -121700,1.7066376,1.8660219,,,,,,,,,,,,,, -121800,1.6640713,1.7267035,,,,,,,,,,,,,, -121900,1.759652,3.4034705,,,,,,,,,,,,,, -122000,1.6030183,2.7947016,,,,,,,,,,,,,, -122100,2.16867,4.223646,,,,,,,,,,,,,, -122191,,,0.7879882454872131,0.8077414631843567,0.7186999917030334,1.1167722940444946,50000.0,0.5915000438690186,1.7542343139648438,10000.0,54237.009996175766,58306.75683450699,54237.009996175766,4056.473051548004,6.989522695541382,0.0 -122200,1.7332395,3.2350793,,,,,,,,,,,,,, -122300,1.6366762,2.2534947,,,,,,,,,,,,,, -122400,1.6385939,2.5343413,,,,,,,,,,,,,, -122500,1.7869968,2.655301,,,,,,,,,,,,,, -122600,2.1140287,1.7188295,,,,,,,,,,,,,, -122700,1.882122,1.7383014,,,,,,,,,,,,,, -122800,1.8683779,1.6336696,,,,,,,,,,,,,, -122900,1.6162697,1.7973598,,,,,,,,,,,,,, -123000,1.6647754,3.2905066,,,,,,,,,,,,,, -123100,1.7050027,1.6648955,,,,,,,,,,,,,, -123137,,,0.80322265625,0.7724707722663879,0.7219600081443787,1.1197479963302612,50000.0,0.6009000539779663,1.726817607879639,10000.0,54657.21346616745,58758.33668446541,54657.21346616745,4087.749156236649,7.041682243347168,0.0 -123200,1.9999139,3.6118858,,,,,,,,,,,,,, -123300,1.6445075,3.0607872,,,,,,,,,,,,,, -123400,1.6838974,3.1430287,,,,,,,,,,,,,, -123500,1.705569,1.9197736,,,,,,,,,,,,,, -123600,1.7218044,1.8161464,,,,,,,,,,,,,, -123700,1.8753824,1.4980737,,,,,,,,,,,,,, -123800,1.7344966,1.5438911,,,,,,,,,,,,,, -123900,1.9713489,2.876726,,,,,,,,,,,,,, -124000,1.8738846,1.681619,,,,,,,,,,,,,, -124081,,,0.7920898199081421,0.8035741448402405,0.7215999960899353,1.1221164464950562,50000.0,0.6030000448226929,1.7267364263534546,10000.0,55077.45772242546,59211.47411370277,55077.45772242546,4120.542637825012,7.0932228565216064,0.0 -124100,1.8998282,3.6841967,,,,,,,,,,,,,, -124200,1.856413,1.6582252,,,,,,,,,,,,,, -124300,1.8529606,3.963126,,,,,,,,,,,,,, -124400,1.8740156,1.7737999,,,,,,,,,,,,,, -124500,1.7672559,1.7767122,,,,,,,,,,,,,, -124600,1.751184,1.7227962,,,,,,,,,,,,,, -124700,1.7978343,2.0675325,,,,,,,,,,,,,, -124800,1.8283284,1.6303447,,,,,,,,,,,,,, -124900,1.7641121,1.7387841,,,,,,,,,,,,,, -125000,1.8241153,3.4774742,,,,,,,,,,,,,, -125028,,,0.7930859327316284,0.8153924345970154,0.7215200066566467,1.1258158683776855,50000.0,0.5985000133514404,1.7480947971343994,10000.0,55497.51085400581,59661.67133617401,55497.51085400581,4150.582463741303,7.149578332901001,0.0 -125100,1.8948044,1.719596,,,,,,,,,,,,,, -125200,1.8634734,3.095821,,,,,,,,,,,,,, -125300,1.8357458,1.9810314,,,,,,,,,,,,,, -125400,2.0109937,1.622161,,,,,,,,,,,,,, -125500,1.9554732,1.7226381,,,,,,,,,,,,,, -125600,1.716936,2.6545174,,,,,,,,,,,,,, -125700,1.9894233,3.2460663,,,,,,,,,,,,,, -125800,1.7004786,2.8509977,,,,,,,,,,,,,, -125900,1.7608081,1.8655766,,,,,,,,,,,,,, -125972,,,0.8003320097923279,0.7735731601715088,0.7224999666213989,1.1156514883041382,50000.0,0.5933000445365906,1.7462410926818848,10000.0,55917.58813285828,60115.00229549408,55917.58813285828,4183.735275506973,7.201582193374634,0.0 -126000,1.6632062,2.0799432,,,,,,,,,,,,,, -126100,1.6651306,2.9621656,,,,,,,,,,,,,, -126200,1.7766124,2.1005857,,,,,,,,,,,,,, -126300,1.6321571,2.8056066,,,,,,,,,,,,,, -126400,1.9365687,1.8217752,,,,,,,,,,,,,, -126500,1.8749963,1.7149588,,,,,,,,,,,,,, -126600,2.1138532,4.0973277,,,,,,,,,,,,,, -126700,1.8186604,1.6130817,,,,,,,,,,,,,, -126800,1.7375925,2.3190057,,,,,,,,,,,,,, -126900,2.0095463,1.7497225,,,,,,,,,,,,,, -126916,,,0.8159765601158142,0.7205840945243835,0.7251399755477905,1.105201005935669,50000.0,0.6083000302314758,1.7146964073181152,10000.0,56337.57939887047,60565.77696371079,56337.57939887047,4214.42046713829,7.252074480056763,0.0 -127000,2.0022657,1.765072,,,,,,,,,,,,,, -127100,2.127909,1.7453989,,,,,,,,,,,,,, -127200,2.1609502,4.007538,,,,,,,,,,,,,, -127300,1.8012589,1.660583,,,,,,,,,,,,,, -127400,1.8921992,1.6509562,,,,,,,,,,,,,, -127500,1.8292079,1.9079618,,,,,,,,,,,,,, -127600,1.8431687,1.7257004,,,,,,,,,,,,,, -127700,1.9400188,1.6623288,,,,,,,,,,,,,, -127800,1.7388765,2.1262877,,,,,,,,,,,,,, -127863,,,0.8004687428474426,0.7621183395385742,0.7263799905776978,1.087070107460022,50000.0,0.6078000068664551,1.6983314752578735,10000.0,56757.66656923294,61016.87927532196,56757.66656923294,4245.306490898132,7.332320213317871,0.0 -127900,1.8018793,2.3965092,,,,,,,,,,,,,, -128000,2.02688,1.7912732,,,,,,,,,,,,,, -128100,1.7927425,1.5994201,,,,,,,,,,,,,, -128200,2.0331786,1.598011,,,,,,,,,,,,,, -128300,2.133408,1.7303672,,,,,,,,,,,,,, -128400,1.9041654,1.7227864,,,,,,,,,,,,,, -128500,1.982539,2.9858534,,,,,,,,,,,,,, -128600,1.9487793,1.6251783,,,,,,,,,,,,,, -128700,1.9790212,3.9682124,,,,,,,,,,,,,, -128800,1.7419356,2.8200939,,,,,,,,,,,,,, -128810,,,0.804492175579071,0.7556107044219971,0.7278800010681152,1.0926361083984375,50000.0,0.6009000539779663,1.718663215637207,10000.0,57177.68268537521,61470.233803510666,57177.68268537521,4278.540325880051,7.388689041137695,0.0 -128900,1.8488272,2.2102146,,,,,,,,,,,,,, -129000,1.8777751,2.0771184,,,,,,,,,,,,,, -129100,1.9934413,1.7327178,,,,,,,,,,,,,, -129200,1.8949549,1.7021155,,,,,,,,,,,,,, -129300,2.1841714,1.7155801,,,,,,,,,,,,,, -129400,1.9593492,1.7338452,,,,,,,,,,,,,, -129500,1.7872568,2.2366288,,,,,,,,,,,,,, -129600,1.7768046,3.357828,,,,,,,,,,,,,, -129700,2.289953,1.6171287,,,,,,,,,,,,,, -129758,,,0.8183202743530273,0.6911411881446838,0.730139970779419,1.0635643005371094,50000.0,0.6118000149726868,1.675654411315918,10000.0,57597.74878168106,61922.46889066696,57597.74878168106,4310.619389057159,7.430602788925171,0.0 -129800,2.145216,1.610166,,,,,,,,,,,,,, -129900,1.6871371,1.8746717,,,,,,,,,,,,,, -130000,1.8548408,2.1809993,,,,,,,,,,,,,, -130100,2.181918,4.052334,,,,,,,,,,,,,, -130200,1.8668727,1.7569096,,,,,,,,,,,,,, -130300,2.045528,1.6318574,,,,,,,,,,,,,, -130400,1.88894,1.6488769,,,,,,,,,,,,,, -130500,1.7621455,1.6530476,,,,,,,,,,,,,, -130600,2.031075,1.6386322,,,,,,,,,,,,,, -130700,1.9743905,1.6335822,,,,,,,,,,,,,, -130702,,,0.8029296398162842,0.7509654760360718,0.731220006942749,1.0709768533706665,50000.0,0.6089000105857849,1.6937024593353271,10000.0,58017.70781111717,62379.66068387032,58017.70781111717,4347.752557516098,7.481885433197021,0.0 -130800,2.1055753,1.5415225,,,,,,,,,,,,,, -130900,1.9771538,2.4540358,,,,,,,,,,,,,, -131000,2.0148509,1.6892364,,,,,,,,,,,,,, -131100,1.7411613,2.5716133,,,,,,,,,,,,,, -131200,1.8564672,1.718306,,,,,,,,,,,,,, -131300,1.9093872,1.5999085,,,,,,,,,,,,,, -131400,1.9968102,1.6025491,,,,,,,,,,,,,, -131500,1.951978,1.6456368,,,,,,,,,,,,,, -131600,1.9649111,3.5998101,,,,,,,,,,,,,, -131648,,,0.80824214220047,0.7324361801147461,0.7309799790382385,1.0710631608963013,50000.0,0.6130000352859497,1.677901268005371,10000.0,58437.64938545227,62832.02004933357,58437.64938545227,4380.077574729919,7.52693510055542,0.0 -131700,2.1606867,1.7124636,,,,,,,,,,,,,, -131800,1.7804397,2.7248769,,,,,,,,,,,,,, -131900,2.140687,1.5953641,,,,,,,,,,,,,, -132000,2.0828516,1.5055544,,,,,,,,,,,,,, -132100,2.0118494,1.5698537,,,,,,,,,,,,,, -132200,1.9639666,2.0131383,,,,,,,,,,,,,, -132300,1.8778092,1.7262992,,,,,,,,,,,,,, -132400,1.8353636,1.5518417,,,,,,,,,,,,,, -132500,1.8756255,1.5628898,,,,,,,,,,,,,, -132594,,,0.8122069835662842,0.7282311320304871,0.7334399819374084,1.069677233695984,50000.0,0.6110000014305115,1.682750225067139,10000.0,58857.876715660095,63284.00793981552,58857.876715660095,4411.736454963684,7.580765724182129,0.0 -132600,1.9514351,3.272162,,,,,,,,,,,,,, -132700,1.9449502,1.5253665,,,,,,,,,,,,,, -132800,2.0643337,3.9216592,,,,,,,,,,,,,, -132900,1.898278,2.2088678,,,,,,,,,,,,,, -133000,1.8583201,1.503083,,,,,,,,,,,,,, -133100,1.934651,1.4941274,,,,,,,,,,,,,, -133200,1.8501079,2.0267537,,,,,,,,,,,,,, -133300,1.9415318,3.2059696,,,,,,,,,,,,,, -133400,1.9194164,2.1684897,,,,,,,,,,,,,, -133500,2.0126479,3.1227255,,,,,,,,,,,,,, -133537,,,0.8247460722923279,0.6852326989173889,0.7335799932479858,1.0745161771774292,50000.0,0.6073000431060791,1.697361946105957,10000.0,59278.13584399223,63738.54997134209,59278.13584399223,4445.912066459656,7.64042592048645,0.0 -133600,2.1407533,1.5587186,,,,,,,,,,,,,, -133700,1.9755381,1.4930098,,,,,,,,,,,,,, -133800,1.9392761,1.7134349,,,,,,,,,,,,,, -133900,1.9536531,1.5213997,,,,,,,,,,,,,, -134000,1.8536831,2.5421443,,,,,,,,,,,,,, -134100,1.9911907,1.6397161,,,,,,,,,,,,,, -134200,1.9261289,3.171059,,,,,,,,,,,,,, -134300,1.8353251,2.3698003,,,,,,,,,,,,,, -134400,2.0099535,1.5225153,,,,,,,,,,,,,, -134484,,,0.81361323595047,0.7152171730995178,0.7375999689102173,1.051063895225525,50000.0,0.613800048828125,1.670855164527893,10000.0,59698.29487943649,64194.70729804039,59698.29487943649,4481.801281452179,7.701645374298096,0.0 -134500,1.9908549,1.6993531,,,,,,,,,,,,,, -134600,1.9259279,1.5780058,,,,,,,,,,,,,, -134700,1.959281,1.634247,,,,,,,,,,,,,, -134800,2.262637,1.5656025,,,,,,,,,,,,,, -134900,1.8440002,2.5421345,,,,,,,,,,,,,, -135000,2.146273,1.6479475,,,,,,,,,,,,,, -135100,1.8321295,2.467954,,,,,,,,,,,,,, -135200,1.7360862,2.3313727,,,,,,,,,,,,,, -135300,2.1038835,1.6692439,,,,,,,,,,,,,, -135400,2.257856,3.640189,,,,,,,,,,,,,, -135432,,,0.8185351490974426,0.6898524165153503,0.7380599975585938,1.040588140487671,50000.0,0.6119000315666199,1.6614116430282593,10000.0,60118.45872378349,64648.24386954308,60118.45872378349,4515.079110383987,7.746327400207519,0.0 -135500,1.9614817,2.5614488,,,,,,,,,,,,,, -135600,2.1178944,1.5666482,,,,,,,,,,,,,, -135700,2.0868018,3.826544,,,,,,,,,,,,,, -135800,1.9155326,2.466266,,,,,,,,,,,,,, -135900,2.025386,2.4403353,,,,,,,,,,,,,, -136000,1.8051183,2.7021396,,,,,,,,,,,,,, -136100,1.9186084,1.7820264,,,,,,,,,,,,,, -136200,1.9227618,3.1400163,,,,,,,,,,,,,, -136300,1.979969,2.5636487,,,,,,,,,,,,,, -136378,,,0.8243945240974426,0.6757507920265198,0.7364199757575989,1.0444821119308472,50000.0,0.6154000163078308,1.6687698364257812,10000.0,60538.52721905708,65101.20194649696,60538.52721905708,4547.873115539551,7.793517351150513,0.0 -136400,1.966766,1.5157475,,,,,,,,,,,,,, -136500,1.7842734,2.5647929,,,,,,,,,,,,,, -136600,2.011026,1.5979651,,,,,,,,,,,,,, -136700,2.2579255,1.6199017,,,,,,,,,,,,,, -136800,2.0048368,2.0444345,,,,,,,,,,,,,, -136900,1.8473092,2.76509,,,,,,,,,,,,,, -137000,2.0235593,1.5320643,,,,,,,,,,,,,, -137100,2.099291,1.5688753,,,,,,,,,,,,,, -137200,2.1183352,1.5997797,,,,,,,,,,,,,, -137300,2.2383325,1.6540364,,,,,,,,,,,,,, -137323,,,0.8189452886581421,0.6932224631309509,0.740339994430542,1.0378048419952393,50000.0,0.6136000156402588,1.662897706031799,10000.0,60958.88538622856,65556.06583595276,60958.88538622856,4582.273346424103,7.8503875732421875,0.0 -137400,2.2124164,1.5762267,,,,,,,,,,,,,, -137500,1.9915183,3.156508,,,,,,,,,,,,,, -137600,2.1644478,1.5179303,,,,,,,,,,,,,, -137700,2.2212825,3.1667073,,,,,,,,,,,,,, -137800,2.1089385,1.4456499,,,,,,,,,,,,,, -137900,2.021268,1.9788023,,,,,,,,,,,,,, -138000,1.9612851,2.6278083,,,,,,,,,,,,,, -138100,2.1712592,1.463591,,,,,,,,,,,,,, -138200,2.0490577,1.5245974,,,,,,,,,,,,,, -138270,,,0.8212890625,0.7007355093955994,0.7413199543952942,1.0484044551849363,50000.0,0.6142000555992126,1.658218264579773,10000.0,61378.89130926132,66008.60767126083,61378.89130926132,4614.699839115143,7.911661148071289,0.0 -138300,2.1364503,1.6354556,,,,,,,,,,,,,, -138400,2.472939,3.552651,,,,,,,,,,,,,, -138500,1.9686589,1.919209,,,,,,,,,,,,,, -138600,2.02907,1.4900957,,,,,,,,,,,,,, -138700,2.2918916,3.7938924,,,,,,,,,,,,,, -138800,2.0074742,1.3761148,,,,,,,,,,,,,, -138900,2.1408079,1.4517158,,,,,,,,,,,,,, -139000,2.7633197,4.0042133,,,,,,,,,,,,,, -139100,2.0885396,1.6137298,,,,,,,,,,,,,, -139200,2.1772017,1.6798334,,,,,,,,,,,,,, -139212,,,0.8290624618530273,0.6643991470336914,0.7405399680137634,1.040321946144104,50000.0,0.6215000152587891,1.646551251411438,10000.0,61798.8139359951,66460.71085047722,61798.8139359951,4646.778388261795,7.966315031051636,0.0 -139300,1.9353462,2.0123153,,,,,,,,,,,,,, -139400,2.1869438,1.619079,,,,,,,,,,,,,, -139500,2.215049,1.6853912,,,,,,,,,,,,,, -139600,1.9517417,2.263763,,,,,,,,,,,,,, -139700,1.9423774,1.866188,,,,,,,,,,,,,, -139800,2.016203,1.6302402,,,,,,,,,,,,,, -139900,2.3614063,3.8449533,,,,,,,,,,,,,, -140000,2.1701572,1.5373245,,,,,,,,,,,,,, -140100,2.051345,1.7553895,,,,,,,,,,,,,, -140156,,,0.8376757502555847,0.613097608089447,0.7422199845314026,1.0231008529663086,50000.0,0.6212000250816345,1.6370558738708496,10000.0,62218.95655655861,66912.40876984596,62218.95655655861,4678.231098890305,8.020896673202515,0.0 -140200,2.3402736,3.713598,,,,,,,,,,,,,, -140300,2.1624906,1.8396398,,,,,,,,,,,,,, -140400,2.8016646,3.7692695,,,,,,,,,,,,,, -140500,2.0963693,1.4716944,,,,,,,,,,,,,, -140600,2.155943,1.487782,,,,,,,,,,,,,, -140700,2.3277159,3.671545,,,,,,,,,,,,,, -140800,2.2138402,1.5942866,,,,,,,,,,,,,, -140900,2.0167909,1.9124129,,,,,,,,,,,,,, -141000,2.2711077,3.4594052,,,,,,,,,,,,,, -141100,2.2022965,1.5778216,,,,,,,,,,,,,, -141101,,,0.8280078172683716,0.6409241557121277,0.7453199625015259,1.0019612312316897,50000.0,0.6190000176429749,1.6144688129425049,10000.0,62639.31224012375,67367.12671756744,62639.31224012375,4712.48251914978,8.083777904510498,0.0 -141200,2.1969798,1.5297769,,,,,,,,,,,,,, -141300,2.1393473,1.6471286,,,,,,,,,,,,,, -141400,2.1952143,3.162817,,,,,,,,,,,,,, -141500,2.0703237,1.5010811,,,,,,,,,,,,,, -141600,2.3196564,3.2562857,,,,,,,,,,,,,, -141700,2.5749662,3.8245354,,,,,,,,,,,,,, -141800,2.1090488,1.5322351,,,,,,,,,,,,,, -141900,2.2173405,2.3311048,,,,,,,,,,,,,, -142000,2.1795092,3.036779,,,,,,,,,,,,,, -142049,,,0.8278319835662842,0.646851658821106,0.7454800009727478,1.0134694576263428,50000.0,0.6236000061035156,1.622061252593994,10000.0,63059.65165567398,67820.41301584244,63059.65165567398,4745.326453208923,8.137955904006958,0.0 -142100,2.3054538,3.8478773,,,,,,,,,,,,,, -142200,2.1394887,1.92609,,,,,,,,,,,,,, -142300,2.0228086,2.1597404,,,,,,,,,,,,,, -142400,2.367671,1.436569,,,,,,,,,,,,,, -142500,2.186705,2.3123782,,,,,,,,,,,,,, -142600,2.185886,1.4702532,,,,,,,,,,,,,, -142700,2.5484676,3.815847,,,,,,,,,,,,,, -142800,2.016334,2.1728933,,,,,,,,,,,,,, -142900,2.0753343,2.9239464,,,,,,,,,,,,,, -142990,,,0.8350585699081421,0.6243909001350403,0.7469799518585205,1.008358120918274,50000.0,0.6205000281333923,1.6167891025543213,10000.0,63479.56883740425,68271.46160244942,63479.56883740425,4776.352700948715,8.195709705352783,0.0 -143000,2.1615486,1.8270342,,,,,,,,,,,,,, -143100,2.1289902,1.4836961,,,,,,,,,,,,,, -143200,2.3260903,1.6839206,,,,,,,,,,,,,, -143300,1.9833773,2.9125295,,,,,,,,,,,,,, -143400,2.1051483,1.5587662,,,,,,,,,,,,,, -143500,2.4483852,2.310587,,,,,,,,,,,,,, -143600,2.2640655,1.4918277,,,,,,,,,,,,,, -143700,2.2935314,1.4127662,,,,,,,,,,,,,, -143800,2.1042225,1.7916192,,,,,,,,,,,,,, -143900,2.1277058,2.3059235,,,,,,,,,,,,,, -143932,,,0.8307421803474426,0.6419038772583008,0.7469799518585205,0.9995123744010924,50000.0,0.6215000152587891,1.6196413040161133,10000.0,63899.6438407898,68728.61555743217,63899.6438407898,4813.3277044296265,8.251341104507446,0.0 -144000,2.2845137,1.3680367,,,,,,,,,,,,,, -144100,1.9988983,1.8711377,,,,,,,,,,,,,, -144200,2.186977,1.6049335,,,,,,,,,,,,,, -144300,2.2170944,1.8294796,,,,,,,,,,,,,, -144400,2.5918276,3.6952255,,,,,,,,,,,,,, -144500,2.240554,1.6445851,,,,,,,,,,,,,, -144600,2.1568575,2.0738547,,,,,,,,,,,,,, -144700,2.2625062,3.3146925,,,,,,,,,,,,,, -144800,2.14792,1.3795645,,,,,,,,,,,,,, -144882,,,0.8347265720367432,0.622641384601593,0.7500399947166443,0.9826123118400574,50000.0,0.6274000406265259,1.5910381078720093,10000.0,64319.94776797295,69184.67727637291,64319.94776797295,4848.98010802269,8.308446168899536,0.0 -144900,1.9744036,2.0177462,,,,,,,,,,,,,, -145000,2.2206914,1.6636124,,,,,,,,,,,,,, -145100,2.287763,1.4857619,,,,,,,,,,,,,, -145200,2.0611308,2.1109025,,,,,,,,,,,,,, -145300,2.363769,3.1843398,,,,,,,,,,,,,, -145400,2.3275206,2.416407,,,,,,,,,,,,,, -145500,2.1843438,3.0487573,,,,,,,,,,,,,, -145600,2.6703022,3.7385879,,,,,,,,,,,,,, -145700,2.8686721,3.5374978,,,,,,,,,,,,,, -145800,2.27465,1.5215783,,,,,,,,,,,,,, -145829,,,0.8373242020606995,0.6188812255859375,0.7525999546051025,0.9853445887565612,50000.0,0.629300057888031,1.5892504453659058,10000.0,64740.01791214943,69635.6463303566,64740.01791214943,4879.781792402268,8.357537031173706,0.0 -145900,2.3537076,1.684353,,,,,,,,,,,,,, -146000,2.2023687,1.4072609,,,,,,,,,,,,,, -146100,2.1996741,1.4901681,,,,,,,,,,,,,, -146200,2.4027007,1.4344156,,,,,,,,,,,,,, -146300,2.2393975,1.3651665,,,,,,,,,,,,,, -146400,2.0558217,2.6485717,,,,,,,,,,,,,, -146500,2.150051,1.8863388,,,,,,,,,,,,,, -146600,2.4932494,3.425236,,,,,,,,,,,,,, -146700,2.2821023,3.371592,,,,,,,,,,,,,, -146774,,,0.8498046398162842,0.5664092302322388,0.7521399855613708,0.978649377822876,50000.0,0.6303000450134277,1.581558108329773,10000.0,65160.04265499115,70093.58276224136,65160.04265499115,4917.589969396591,8.413017749786377,0.0 -146800,2.2256093,1.5277619,,,,,,,,,,,,,, -146900,2.7021997,1.3720248,,,,,,,,,,,,,, -147000,2.4195993,2.7188237,,,,,,,,,,,,,, -147100,2.412402,1.4307796,,,,,,,,,,,,,, -147200,2.573652,1.8829271,,,,,,,,,,,,,, -147300,2.8807194,3.8048427,,,,,,,,,,,,,, -147400,2.0922682,1.4037802,,,,,,,,,,,,,, -147500,2.3933372,1.3652607,,,,,,,,,,,,,, -147600,2.2991827,1.4512806,,,,,,,,,,,,,, -147700,2.3468573,1.6269792,,,,,,,,,,,,,, -147722,,,0.8400195240974426,0.6110930442810059,0.7532599568367004,0.9788589477539062,50000.0,0.6295000314712524,1.5846375226974487,10000.0,65579.95571732521,70546.2168867588,65579.95571732521,4950.214616537094,8.461366415023804,0.0 -147800,2.3805542,1.5946257,,,,,,,,,,,,,, -147900,2.2816184,2.3184013,,,,,,,,,,,,,, -148000,2.270159,1.3319644,,,,,,,,,,,,,, -148100,2.3953283,1.5171599,,,,,,,,,,,,,, -148200,2.1919806,1.3609574,,,,,,,,,,,,,, -148300,2.3571532,1.4468677,,,,,,,,,,,,,, -148400,2.188408,3.246351,,,,,,,,,,,,,, -148500,2.4539783,2.4024694,,,,,,,,,,,,,, -148600,2.2729442,1.4239426,,,,,,,,,,,,,, -148669,,,0.840624988079071,0.6056962609291077,0.7522000074386597,0.9804185628890992,50000.0,0.6278000473976135,1.5932260751724243,10000.0,65999.89698147774,70999.6721727848,65999.89698147774,4983.62385559082,8.5170156955719,0.0 -148700,2.8491266,3.5833957,,,,,,,,,,,,,, -148800,2.4998248,1.4351809,,,,,,,,,,,,,, -148900,2.0487113,1.7632791,,,,,,,,,,,,,, -149000,2.58466,3.654098,,,,,,,,,,,,,, -149100,2.4600825,2.9741812,,,,,,,,,,,,,, -149200,2.2564628,1.6583735,,,,,,,,,,,,,, -149300,2.2422705,1.4449868,,,,,,,,,,,,,, -149400,2.3641522,1.4098859,,,,,,,,,,,,,, -149500,4.9376845,2.8429809,,,,,,,,,,,,,, -149600,2.3837066,1.6358197,,,,,,,,,,,,,, -149616,,,0.8484765291213989,0.5737552046775818,0.7560999989509583,0.9599688649177552,50000.0,0.6355000138282776,1.575750470161438,10000.0,66420.16347050667,71457.34581279755,66420.16347050667,5020.925404548645,8.574753284454346,0.0 -149700,2.3136604,1.5405755,,,,,,,,,,,,,, -149800,2.4059594,1.4313225,,,,,,,,,,,,,, -149900,2.2053356,2.2250228,,,,,,,,,,,,,, -150000,2.4570785,1.5610027,,,,,,,,,,,,,, -150100,2.6193748,1.7369405,,,,,,,,,,,,,, -150200,2.429391,1.4635563,,,,,,,,,,,,,, -150300,2.2928016,1.3723981,,,,,,,,,,,,,, -150400,2.237329,1.9066075,,,,,,,,,,,,,, -150500,2.282509,2.328608,,,,,,,,,,,,,, -150563,,,0.8442773222923279,0.588003396987915,0.7562599778175354,0.963024377822876,50000.0,0.6318000555038452,1.5865153074264526,10000.0,66840.26841020584,71910.77382802963,66840.26841020584,5054.153319835663,8.622200965881348,0.0 -150600,2.174834,2.0881407,,,,,,,,,,,,,, -150700,2.6492863,1.8811504,,,,,,,,,,,,,, -150800,2.4217658,1.6881741,,,,,,,,,,,,,, -150900,2.1781836,2.1396942,,,,,,,,,,,,,, -151000,2.2132287,1.4068366,,,,,,,,,,,,,, -151100,2.4575574,1.7118835,,,,,,,,,,,,,, -151200,2.871446,3.088022,,,,,,,,,,,,,, -151300,2.4467235,1.4190347,,,,,,,,,,,,,, -151400,2.6009312,3.5646577,,,,,,,,,,,,,, -151500,2.3441586,2.3513134,,,,,,,,,,,,,, -151506,,,0.8475585579872131,0.5803536772727966,0.7574999928474426,0.9603232741355896,50000.0,0.6328000426292419,1.570575714111328,10000.0,67260.58755874634,72365.41188597679,67260.58755874634,5088.375540494919,8.670436143875122,0.0 -151600,2.5295405,1.4143991,,,,,,,,,,,,,, -151700,2.3118744,1.4228872,,,,,,,,,,,,,, -151800,2.1886876,1.6790185,,,,,,,,,,,,,, -151900,2.6066082,1.3920026,,,,,,,,,,,,,, -152000,2.2750666,1.7786384,,,,,,,,,,,,,, -152100,2.3578038,1.6494825,,,,,,,,,,,,,, -152200,2.2117977,1.6450775,,,,,,,,,,,,,, -152300,2.6330862,1.3680406,,,,,,,,,,,,,, -152400,2.3444953,1.6883172,,,,,,,,,,,,,, -152452,,,0.8537304401397705,0.5536989569664001,0.7576799988746643,0.9566033482551576,50000.0,0.636400043964386,1.5653671026229858,10000.0,67680.61174440384,72816.272285223,67680.61174440384,5119.106483459473,8.727611780166626,0.0 -152500,2.6079433,1.3968776,,,,,,,,,,,,,, -152600,2.5484407,2.923205,,,,,,,,,,,,,, -152700,2.495211,3.1519933,,,,,,,,,,,,,, -152800,2.5470297,1.375028,,,,,,,,,,,,,, -152900,2.3243184,1.3299017,,,,,,,,,,,,,, -153000,2.3835385,1.3222065,,,,,,,,,,,,,, -153100,2.4576542,1.7470956,,,,,,,,,,,,,, -153200,2.4610362,1.3804473,,,,,,,,,,,,,, -153300,2.2930121,2.451984,,,,,,,,,,,,,, -153394,,,0.8624218702316284,0.5283911824226379,0.7599799633026123,0.9521071910858154,50000.0,0.6341000199317932,1.5579533576965332,10000.0,68100.60788464546,73270.136343956,68100.60788464546,5152.867078065872,8.786496639251709,0.0 -153400,2.3588293,1.3787994,,,,,,,,,,,,,, -153500,2.330448,2.1736028,,,,,,,,,,,,,, -153600,2.3044744,1.6216092,,,,,,,,,,,,,, -153700,2.9221156,3.4543157,,,,,,,,,,,,,, -153800,2.5277457,1.2949526,,,,,,,,,,,,,, -153900,2.2388868,1.9649787,,,,,,,,,,,,,, -154000,2.6535935,1.3274014,,,,,,,,,,,,,, -154100,2.6087363,1.3296922,,,,,,,,,,,,,, -154200,2.5812678,1.4338729,,,,,,,,,,,,,, -154300,2.6049304,1.3606819,,,,,,,,,,,,,, -154339,,,0.8519921898841858,0.5664136409759521,0.7613399624824524,0.9477301836013794,50000.0,0.6384000182151794,1.555184006690979,10000.0,68520.679936409,73720.59342503548,68520.679936409,5183.143657207489,8.846336364746094,0.0 -154400,2.4185464,1.3083773,,,,,,,,,,,,,, -154500,2.4455533,1.7935414,,,,,,,,,,,,,, -154600,2.419892,1.4211501,,,,,,,,,,,,,, -154700,2.5941024,1.3216498,,,,,,,,,,,,,, -154800,2.5374134,1.4121815,,,,,,,,,,,,,, -154900,2.3862038,1.6888742,,,,,,,,,,,,,, -155000,2.6116562,1.2693205,,,,,,,,,,,,,, -155100,2.6462016,2.330503,,,,,,,,,,,,,, -155200,2.8881671,3.1188745,,,,,,,,,,,,,, -155283,,,0.85498046875,0.5443105697631836,0.7622199654579163,0.9341087341308594,50000.0,0.6430000066757202,1.5386472940444946,10000.0,68940.78204774857,74178.38667702675,68940.78204774857,5220.722696304321,8.91114854812622,0.0 -155300,2.510568,1.345033,,,,,,,,,,,,,, -155400,2.5775962,1.5129712,,,,,,,,,,,,,, -155500,2.625775,1.348839,,,,,,,,,,,,,, -155600,2.477923,1.4729525,,,,,,,,,,,,,, -155700,2.3640203,1.5309448,,,,,,,,,,,,,, -155800,2.4418304,2.8221397,,,,,,,,,,,,,, -155900,2.6722589,1.3862307,,,,,,,,,,,,,, -156000,2.4212744,1.3839241,,,,,,,,,,,,,, -156100,2.318305,2.089006,,,,,,,,,,,,,, -156200,2.706392,1.7047293,,,,,,,,,,,,,, -156229,,,0.8586718440055847,0.5289605855941772,0.7625399827957153,0.937669038772583,50000.0,0.6414000391960144,1.5460975170135498,10000.0,69361.06091952324,74635.90718269348,69361.06091952324,5257.861455440521,8.966075658798218,0.0 -156300,3.6546633,3.3021092,,,,,,,,,,,,,, -156400,2.455456,2.089013,,,,,,,,,,,,,, -156500,2.7061017,3.3438356,,,,,,,,,,,,,, -156600,2.8600488,3.5235443,,,,,,,,,,,,,, -156700,2.9655483,3.5637226,,,,,,,,,,,,,, -156800,2.4276683,1.3718166,,,,,,,,,,,,,, -156900,2.7066693,1.302259,,,,,,,,,,,,,, -157000,2.9296577,3.402431,,,,,,,,,,,,,, -157100,2.5118525,1.779716,,,,,,,,,,,,,, -157174,,,0.8605468273162842,0.5198253989219666,0.7628399729728699,0.9245589375495912,50000.0,0.643500030040741,1.5365839004516602,10000.0,69781.32340240479,75087.62546420097,69781.32340240479,5289.2215440273285,9.013855934143066,0.0 -157200,2.4890249,1.2538104,,,,,,,,,,,,,, -157300,2.6140695,2.3940942,,,,,,,,,,,,,, -157400,2.6713858,1.3843584,,,,,,,,,,,,,, -157500,2.690935,1.5339761,,,,,,,,,,,,,, -157600,2.5289686,2.2581515,,,,,,,,,,,,,, -157700,2.704402,3.0210564,,,,,,,,,,,,,, -157800,2.412937,1.9853754,,,,,,,,,,,,,, -157900,2.3126686,1.7880893,,,,,,,,,,,,,, -158000,2.4935646,2.91593,,,,,,,,,,,,,, -158100,3.0224063,3.359198,,,,,,,,,,,,,, -158116,,,0.8600195050239563,0.5220351219177246,0.7657999992370605,0.922173261642456,50000.0,0.6447000503540039,1.522838115692139,10000.0,70201.25111293793,75540.83647584915,70201.25111293793,5322.374501943588,9.09701156616211,0.0 -158200,2.320342,1.7396642,,,,,,,,,,,,,, -158300,2.358191,2.081357,,,,,,,,,,,,,, -158400,2.8159487,1.3360709,,,,,,,,,,,,,, -158500,2.6035547,1.224038,,,,,,,,,,,,,, -158600,2.8942385,1.8503687,,,,,,,,,,,,,, -158700,3.1386392,3.5965154,,,,,,,,,,,,,, -158800,2.7597055,2.1552467,,,,,,,,,,,,,, -158900,2.8578568,2.9792476,,,,,,,,,,,,,, -159000,2.5606966,1.937374,,,,,,,,,,,,,, -159059,,,0.8639453053474426,0.5182149410247803,0.7652599811553955,0.925984799861908,50000.0,0.6422000527381897,1.5358188152313232,10000.0,70621.22784686089,75992.44965624809,70621.22784686089,5353.905155658722,9.155547142028809,0.0 -159100,2.3738358,1.265121,,,,,,,,,,,,,, -159200,2.6967106,1.2940718,,,,,,,,,,,,,, -159300,2.6305296,2.3987813,,,,,,,,,,,,,, -159400,2.587511,1.4525967,,,,,,,,,,,,,, -159500,2.7758327,2.9207616,,,,,,,,,,,,,, -159600,3.2541587,3.4885876,,,,,,,,,,,,,, -159700,2.955924,3.3915536,,,,,,,,,,,,,, -159800,3.1158469,3.4335144,,,,,,,,,,,,,, -159900,2.7910702,1.4614038,,,,,,,,,,,,,, -160000,,,0.8701952695846558,0.4905823767185211,0.7658799886703491,0.9287203550338744,50000.0,0.65010005235672,1.5326465368270874,10000.0,71041.17511677742,76444.50120282173,71041.17511677742,5385.8952486515045,9.221797943115234,0.0 -160000,2.821673,1.2449733,,,,,,,,,,,,,, -160100,2.870577,1.4236288,,,,,,,,,,,,,, -160200,3.0167434,3.2661653,,,,,,,,,,,,,, -160300,2.598069,1.218986,,,,,,,,,,,,,, -160400,3.0506208,1.2775441,,,,,,,,,,,,,, -160500,2.5731552,1.7066592,,,,,,,,,,,,,, -160600,2.4555638,2.2103536,,,,,,,,,,,,,, -160700,3.0591652,3.4831834,,,,,,,,,,,,,, -160800,2.9574385,1.6494389,,,,,,,,,,,,,, -160900,2.4587495,1.4816055,,,,,,,,,,,,,, -160944,,,0.8650195002555847,0.5044265985488892,0.7679399847984314,0.9082586765289308,50000.0,0.6476000547409058,1.5069485902786257,10000.0,71461.22817921638,76898.23347449303,71461.22817921638,5419.464338064194,9.283828735351562,0.0 -161000,2.5824902,1.6744256,,,,,,,,,,,,,, -161100,2.5722272,1.1890604,,,,,,,,,,,,,, -161200,2.9664123,1.1524159,,,,,,,,,,,,,, -161300,2.8181975,2.8522515,,,,,,,,,,,,,, -161400,2.9031522,1.2942241,,,,,,,,,,,,,, -161500,2.4619858,1.7700717,,,,,,,,,,,,,, -161600,2.5772645,2.3723464,,,,,,,,,,,,,, -161700,2.834281,1.3069001,,,,,,,,,,,,,, -161800,2.7612042,1.2425755,,,,,,,,,,,,,, -161886,,,0.8689648509025574,0.4930750429630279,0.7684199810028076,0.9124515652656556,50000.0,0.648300051689148,1.5145506858825684,10000.0,71881.19460654259,77349.03749322891,71881.19460654259,5450.181225776672,9.357109308242798,0.0 -161900,3.0350842,3.610971,,,,,,,,,,,,,, -162000,2.8351078,1.1711386,,,,,,,,,,,,,, -162100,3.1171072,3.3828397,,,,,,,,,,,,,, -162200,3.4008493,3.3121438,,,,,,,,,,,,,, -162300,2.671734,1.2800363,,,,,,,,,,,,,, -162400,2.493767,1.9538451,,,,,,,,,,,,,, -162500,2.8175159,1.4285305,,,,,,,,,,,,,, -162600,2.6627316,1.125288,,,,,,,,,,,,,, -162700,3.028661,3.1301942,,,,,,,,,,,,,, -162800,2.754732,2.6694388,,,,,,,,,,,,,, -162829,,,0.87416011095047,0.4663747251033783,0.7705599665641785,0.8969687223434448,50000.0,0.6515000462532043,1.5086804628372192,10000.0,72301.17582511902,77806.6926317215,72301.17582511902,5487.747537851334,9.417224645614624,0.0 -162900,2.588778,1.2461903,,,,,,,,,,,,,, -163000,2.7839081,1.3203809,,,,,,,,,,,,,, -163100,2.8665013,1.3233047,,,,,,,,,,,,,, -163200,2.9308012,1.4895277,,,,,,,,,,,,,, -163300,2.771022,1.2608209,,,,,,,,,,,,,, -163400,2.7507799,1.172332,,,,,,,,,,,,,, -163500,2.719122,1.2130408,,,,,,,,,,,,,, -163600,2.67063,1.8766743,,,,,,,,,,,,,, -163700,2.7465267,1.2484912,,,,,,,,,,,,,, -163779,,,0.869921863079071,0.4806233942508697,0.7714200019836426,0.8954671025276184,50000.0,0.6487000584602356,1.4998230934143066,10000.0,72721.10506033897,78256.85267496109,72721.10506033897,5517.879540681839,9.467840194702148,0.0 -163800,2.8830676,2.8815646,,,,,,,,,,,,,, -163900,3.0789075,1.3029872,,,,,,,,,,,,,, -164000,2.7479744,2.4154496,,,,,,,,,,,,,, -164100,3.0614467,2.9952965,,,,,,,,,,,,,, -164200,3.2714744,2.7622964,,,,,,,,,,,,,, -164300,2.753918,1.5617218,,,,,,,,,,,,,, -164400,2.800049,1.3228741,,,,,,,,,,,,,, -164500,3.0179033,3.2251313,,,,,,,,,,,,,, -164600,3.1396143,3.5092673,,,,,,,,,,,,,, -164700,2.9434292,1.3234392,,,,,,,,,,,,,, -164724,,,0.8695507645606995,0.4878565371036529,0.771399974822998,0.9021183252334596,50000.0,0.6450000405311584,1.5077391862869265,10000.0,73141.14754962921,78708.84034132957,73141.14754962921,5549.715945720673,9.528099298477173,0.0 -164800,3.0997055,3.034583,,,,,,,,,,,,,, -164900,2.7368252,1.1742308,,,,,,,,,,,,,, -165000,2.5665946,2.220548,,,,,,,,,,,,,, -165100,2.8644655,1.1617718,,,,,,,,,,,,,, -165200,2.6862946,1.3089821,,,,,,,,,,,,,, -165300,2.698233,1.9301394,,,,,,,,,,,,,, -165400,3.0607424,1.2604347,,,,,,,,,,,,,, -165500,2.7232444,2.276177,,,,,,,,,,,,,, -165600,3.736958,3.171809,,,,,,,,,,,,,, -165668,,,0.8700780868530273,0.4788884222507477,0.7709999680519104,0.8976555466651917,50000.0,0.6477000117301941,1.5097239017486572,10000.0,73561.19763278961,79162.93744325638,73561.19763278961,5583.6576244831085,9.585728645324709,0.0 -165700,2.8004246,1.277208,,,,,,,,,,,,,, -165800,3.2752423,3.3071964,,,,,,,,,,,,,, -165900,2.569473,1.5937872,,,,,,,,,,,,,, -166000,2.8592303,1.9712296,,,,,,,,,,,,,, -166100,2.957015,2.9965692,,,,,,,,,,,,,, -166200,3.034051,2.7016113,,,,,,,,,,,,,, -166300,2.8315651,1.1757078,,,,,,,,,,,,,, -166400,3.198051,1.2860808,,,,,,,,,,,,,, -166500,2.9562347,1.2078531,,,,,,,,,,,,,, -166600,2.983691,1.8795527,,,,,,,,,,,,,, -166615,,,0.8802929520606995,0.4507987201213836,0.7734799981117249,0.8893938660621643,50000.0,0.656000018119812,1.489260196685791,10000.0,73981.26328992844,79615.27650952339,73981.26328992844,5615.822008609772,9.64665174484253,0.0 -166700,3.3765967,1.278389,,,,,,,,,,,,,, -166800,3.165827,1.1810504,,,,,,,,,,,,,, -166900,2.7627928,1.3969681,,,,,,,,,,,,,, -167000,3.4266005,3.3763185,,,,,,,,,,,,,, -167100,2.662674,1.8585058,,,,,,,,,,,,,, -167200,2.6649835,1.2810824,,,,,,,,,,,,,, -167300,2.9356546,1.9125441,,,,,,,,,,,,,, -167400,3.2173252,1.2724293,,,,,,,,,,,,,, -167500,3.3000174,1.2439755,,,,,,,,,,,,,, -167561,,,0.8752539157867432,0.4655075669288635,0.7736200094223022,0.8859438300132751,50000.0,0.6522000432014465,1.4930800199508667,10000.0,74401.27312350273,80068.24415230751,74401.27312350273,5648.666755914688,9.712154865264893,0.0 -167600,2.951438,1.7528516,,,,,,,,,,,,,, -167700,2.9277658,1.3910311,,,,,,,,,,,,,, -167800,2.8609626,1.3501229,,,,,,,,,,,,,, -167900,2.9730968,1.4257381,,,,,,,,,,,,,, -168000,3.2853208,3.235751,,,,,,,,,,,,,, -168100,2.6982458,2.1656697,,,,,,,,,,,,,, -168200,2.7043023,1.3788232,,,,,,,,,,,,,, -168300,2.903778,1.1130028,,,,,,,,,,,,,, -168400,3.2198095,3.007401,,,,,,,,,,,,,, -168500,2.9164202,1.1002148,,,,,,,,,,,,,, -168508,,,0.87708979845047,0.4573198854923248,0.7735799551010132,0.8782155513763428,50000.0,0.6541000604629517,1.4873871803283691,10000.0,74821.58446097374,80523.77368187904,74821.58446097374,5683.770456075668,9.778326272964478,0.0 -168600,2.9831395,1.2503017,,,,,,,,,,,,,, -168700,2.7340586,1.3448479,,,,,,,,,,,,,, -168800,3.0878825,3.0306761,,,,,,,,,,,,,, -168900,3.4542644,3.307264,,,,,,,,,,,,,, -169000,2.8929608,1.7055128,,,,,,,,,,,,,, -169100,3.644455,3.203508,,,,,,,,,,,,,, -169200,3.0162795,1.1992332,,,,,,,,,,,,,, -169300,2.776784,1.1549704,,,,,,,,,,,,,, -169400,5.39023,1.161537,,,,,,,,,,,,,, -169454,,,0.8761327862739563,0.4634994268417358,0.773859977722168,0.8852800726890564,50000.0,0.6522000432014465,1.4902135133743286,10000.0,75241.93149113655,80978.3799700737,75241.93149113655,5717.919988632202,9.840453386306764,0.0 -169500,2.8797789,1.1932818,,,,,,,,,,,,,, -169600,2.7871552,1.4656115,,,,,,,,,,,,,, -169700,3.232564,1.2918267,,,,,,,,,,,,,, -169800,2.8978446,1.2247769,,,,,,,,,,,,,, -169900,3.062086,2.4373488,,,,,,,,,,,,,, -170000,3.447849,3.385211,,,,,,,,,,,,,, -170100,2.7739034,1.1643673,,,,,,,,,,,,,, -170200,2.9140663,1.2146367,,,,,,,,,,,,,, -170300,3.288126,1.2828789,,,,,,,,,,,,,, -170400,,,0.8781054615974426,0.45171058177948,0.7761200070381165,0.874293863773346,50000.0,0.6583000421524048,1.4806007146835327,10000.0,75662.24620199203,81429.60545611382,75662.24620199203,5748.718489408493,9.904155015945436,0.0 -170400,2.9033933,1.157591,,,,,,,,,,,,,, -170500,2.7769258,1.1890287,,,,,,,,,,,,,, -170600,3.07358,1.1801745,,,,,,,,,,,,,, -170700,3.102448,1.2517757,,,,,,,,,,,,,, -170800,3.034296,1.2128427,,,,,,,,,,,,,, -170900,2.9115934,2.160575,,,,,,,,,,,,,, -171000,2.6425939,1.6407608,,,,,,,,,,,,,, -171100,2.7849479,1.1011525,,,,,,,,,,,,,, -171200,3.0144122,1.2138805,,,,,,,,,,,,,, -171300,2.971736,1.2664192,,,,,,,,,,,,,, -171344,,,0.8800976276397705,0.444496214389801,0.7761799693107605,0.873876690864563,50000.0,0.6561000347137451,1.4711897373199463,10000.0,76082.20815610886,81883.75682711601,76082.20815610886,5782.800493478775,9.963026523590088,0.0 -171400,3.3129594,2.907146,,,,,,,,,,,,,, -171500,2.9728463,2.6445432,,,,,,,,,,,,,, -171600,2.7315564,1.8020811,,,,,,,,,,,,,, -171700,3.04576,1.2305994,,,,,,,,,,,,,, -171800,2.7701206,1.1205567,,,,,,,,,,,,,, -171900,2.9071693,1.1900346,,,,,,,,,,,,,, -172000,3.1298108,1.3304719,,,,,,,,,,,,,, -172100,3.0788631,1.3857046,,,,,,,,,,,,,, -172200,3.1831667,2.8949962,,,,,,,,,,,,,, -172261,,,0.8818749785423279,0.4351919591426849,0.7774199843406677,0.8687987923622131,50000.0,0.6562000513076782,1.4677006006240845,10000.0,76502.247112751,82344.99448871613,76502.247112751,5823.888249158859,10.027615547180176,0.0 -172300,2.9225786,1.251714,,,,,,,,,,,,,, -172400,2.9273033,1.4331398,,,,,,,,,,,,,, -172500,3.0143278,1.1142411,,,,,,,,,,,,,, -172600,3.6004643,3.3869588,,,,,,,,,,,,,, -172700,3.0272903,1.2690289,,,,,,,,,,,,,, -172800,3.2317922,2.4966893,,,,,,,,,,,,,, -172900,2.9540813,2.383348,,,,,,,,,,,,,, -173000,3.4798753,2.9809928,,,,,,,,,,,,,, -173100,2.7850618,1.6475787,,,,,,,,,,,,,, -173200,3.2897253,1.1486466,,,,,,,,,,,,,, -173210,,,0.8833593726158142,0.4367986917495727,0.7771399617195129,0.8698133230209351,50000.0,0.6541000604629517,1.4714024066925049,10000.0,76922.55945754051,82802.23096346855,76922.55945754051,5860.713069200516,10.078977823257446,0.0 -173300,3.13846,2.908142,,,,,,,,,,,,,, -173400,3.0185466,2.122577,,,,,,,,,,,,,, -173500,3.064919,1.0871508,,,,,,,,,,,,,, -173600,2.8547065,1.8440733,,,,,,,,,,,,,, -173700,2.9438694,1.2033737,,,,,,,,,,,,,, -173800,3.2494147,2.6823716,,,,,,,,,,,,,, -173900,3.347476,3.0709963,,,,,,,,,,,,,, -174000,2.9953744,1.2376096,,,,,,,,,,,,,, -174100,3.086837,1.1149714,,,,,,,,,,,,,, -174156,,,0.8812890648841858,0.4501301944255829,0.7763199806213379,0.872589111328125,50000.0,0.6601000428199768,1.4674209356307983,10000.0,77342.50984573364,83257.17953515053,77342.50984573364,5895.611342906952,10.13133668899536,0.0 -174200,3.2542753,1.7959113,,,,,,,,,,,,,, -174300,2.9114068,1.8370553,,,,,,,,,,,,,, -174400,2.8643916,1.6941987,,,,,,,,,,,,,, -174500,3.4055145,1.1923478,,,,,,,,,,,,,, -174600,3.5156865,3.133864,,,,,,,,,,,,,, -174700,3.1373572,1.2116984,,,,,,,,,,,,,, -174800,2.9006386,1.1935341,,,,,,,,,,,,,, -174900,3.226456,1.0861077,,,,,,,,,,,,,, -175000,2.8375409,1.1191024,,,,,,,,,,,,,, -175100,3.297145,2.9523766,,,,,,,,,,,,,, -175101,,,0.8842968344688416,0.4302087426185608,0.7767399549484253,0.869044840335846,50000.0,0.6586000323295593,1.465294361114502,10000.0,77762.4930267334,83716.3657438755,77762.4930267334,5934.702550172806,10.195268630981444,0.0 -175200,3.0743752,2.4512768,,,,,,,,,,,,,, -175300,3.4870968,3.2219806,,,,,,,,,,,,,, -175400,3.0225224,1.1389786,,,,,,,,,,,,,, -175500,2.8495798,1.146529,,,,,,,,,,,,,, -175600,3.0634713,2.1627202,,,,,,,,,,,,,, -175700,2.8913028,1.1225364,,,,,,,,,,,,,, -175800,3.3978252,2.8233585,,,,,,,,,,,,,, -175900,2.9803486,1.3083537,,,,,,,,,,,,,, -176000,3.0853236,1.1320491,,,,,,,,,,,,,, -176047,,,0.8846288919448853,0.4247990846633911,0.7781800031661987,0.8619899749755859,50000.0,0.6606000065803528,1.4649004936218262,10000.0,78182.75847244263,84168.2788734436,78182.75847244263,5966.238239049912,10.259270668029783,0.0 -176100,3.2828617,1.2379212,,,,,,,,,,,,,, -176200,3.250761,1.1205794,,,,,,,,,,,,,, -176300,3.2912006,1.8672498,,,,,,,,,,,,,, -176400,3.055233,1.1909878,,,,,,,,,,,,,, -176500,2.818725,1.0463725,,,,,,,,,,,,,, -176600,2.918685,1.9497212,,,,,,,,,,,,,, -176700,3.0477726,1.2581275,,,,,,,,,,,,,, -176800,3.3795114,2.771332,,,,,,,,,,,,,, -176900,3.3316216,2.8342395,,,,,,,,,,,,,, -176987,,,0.8851367235183716,0.4297102391719818,0.7775399684906006,0.867326021194458,50000.0,0.6599000096321106,1.464371919631958,10000.0,78602.89275169373,84622.46363949776,78602.89275169373,6000.178608179092,10.321425676345823,0.0 -177000,2.967188,1.0686189,,,,,,,,,,,,,, -177100,3.2686508,1.1812998,,,,,,,,,,,,,, -177200,3.1488798,1.1270647,,,,,,,,,,,,,, -177300,3.2989335,3.2063053,,,,,,,,,,,,,, -177400,3.32585,2.3073897,,,,,,,,,,,,,, -177500,3.2655888,1.2337067,,,,,,,,,,,,,, -177600,3.1498399,1.1729177,,,,,,,,,,,,,, -177700,3.114634,1.1271844,,,,,,,,,,,,,, -177800,3.2790139,2.7399693,,,,,,,,,,,,,, -177900,2.8285568,1.6906523,,,,,,,,,,,,,, -177929,,,0.8857812285423279,0.4314631521701813,0.7797799706459045,0.8623688220977783,50000.0,0.6630000472068787,1.4578428268432615,10000.0,79022.98996186256,85079.57243037224,79022.98996186256,6037.081268548965,10.381896018981934,0.0 -178000,3.1716495,1.1208683,,,,,,,,,,,,,, -178100,3.0972865,1.1235023,,,,,,,,,,,,,, -178200,2.8316145,1.1190797,,,,,,,,,,,,,, -178300,2.9867435,2.2904186,,,,,,,,,,,,,, -178400,2.9342487,1.4191791,,,,,,,,,,,,,, -178500,2.9214506,1.1402017,,,,,,,,,,,,,, -178600,2.8950503,1.6211963,,,,,,,,,,,,,, -178700,3.0861602,1.1148493,,,,,,,,,,,,,, -178800,2.854588,2.094261,,,,,,,,,,,,,, -178879,,,0.8853515386581421,0.4269904494285583,0.7788599729537964,0.8618048429489136,50000.0,0.6585000157356262,1.4597877264022827,10000.0,79443.26837658882,85532.51931023598,79443.26837658882,6069.630121469498,10.452614307403564,0.0 -178900,3.7140594,3.1462047,,,,,,,,,,,,,, -179000,3.104193,1.7217166,,,,,,,,,,,,,, -179100,2.9200351,1.0095611,,,,,,,,,,,,,, -179200,2.945836,1.0954461,,,,,,,,,,,,,, -179300,3.1724641,2.159879,,,,,,,,,,,,,, -179400,3.0630364,1.0468898,,,,,,,,,,,,,, -179500,3.8631382,3.2198365,,,,,,,,,,,,,, -179600,3.1595805,1.0492759,,,,,,,,,,,,,, -179700,3.3079636,2.9577916,,,,,,,,,,,,,, -179800,3.3528588,1.1610947,,,,,,,,,,,,,, -179824,,,0.887499988079071,0.4231514930725097,0.7784799933433533,0.8612497448921204,50000.0,0.661300003528595,1.4576455354690552,10000.0,79863.53984189034,85985.32821440697,79863.53984189034,6102.052920103073,10.519617795944214,0.0 -179900,2.9112482,1.8324978,,,,,,,,,,,,,, -180000,3.0331976,1.1663306,,,,,,,,,,,,,, -180100,3.11335,1.2457389,,,,,,,,,,,,,, -180200,3.0898085,1.094263,,,,,,,,,,,,,, -180300,3.0635242,1.1393749,,,,,,,,,,,,,, -180400,3.0922277,1.157527,,,,,,,,,,,,,, -180500,3.1908581,1.176069,,,,,,,,,,,,,, -180600,3.1013715,1.2062895,,,,,,,,,,,,,, -180700,3.1133313,1.1485333,,,,,,,,,,,,,, -180768,,,0.8854296803474426,0.4256410896778106,0.7800599932670593,0.8574149012565613,50000.0,0.6619000434875488,1.4548231363296509,10000.0,80283.59482526779,86442.04560875893,80283.59482526779,6138.60348033905,10.583715915679932,0.0 -180800,3.4463067,2.9616709,,,,,,,,,,,,,, -180900,2.9368556,1.9101021,,,,,,,,,,,,,, -181000,3.251412,1.2010176,,,,,,,,,,,,,, -181100,3.1984763,1.0608822,,,,,,,,,,,,,, -181200,2.9501815,1.0462322,,,,,,,,,,,,,, -181300,3.1560292,1.3633718,,,,,,,,,,,,,, -181400,3.3431566,2.7487807,,,,,,,,,,,,,, -181500,2.9960024,1.070681,,,,,,,,,,,,,, -181600,2.994087,1.3972958,,,,,,,,,,,,,, -181700,2.8203998,1.0275408,,,,,,,,,,,,,, -181716,,,0.8868945240974426,0.4210363030433655,0.7802599668502808,0.8553614616394043,50000.0,0.6617000102996826,1.4499728679656982,10000.0,80703.73807430267,86892.89508104324,80703.73807430267,6169.210487604141,10.634981632232666,0.0 -181800,4.6028495,3.2194028,,,,,,,,,,,,,, -181900,2.997532,1.3044868,,,,,,,,,,,,,, -182000,3.0254111,2.043988,,,,,,,,,,,,,, -182100,3.4139514,1.7877094,,,,,,,,,,,,,, -182200,2.957158,1.4177377,,,,,,,,,,,,,, -182300,3.5936904,2.834333,,,,,,,,,,,,,, -182400,3.0183635,1.8934519,,,,,,,,,,,,,, -182500,3.0403018,1.1402762,,,,,,,,,,,,,, -182600,3.0552394,1.1763253,,,,,,,,,,,,,, -182659,,,0.8882421851158142,0.4179194271564483,0.780299961566925,0.8557093739509583,50000.0,0.6627000570297241,1.4520686864852903,10000.0,81123.65591287613,87347.31507396698,81123.65591287613,6203.603073835373,10.696644067764282,0.0 -182700,2.9938536,1.4237795,,,,,,,,,,,,,, -182800,2.9918149,1.2566714,,,,,,,,,,,,,, -182900,2.8530993,1.2262986,,,,,,,,,,,,,, -183000,3.9222467,3.351523,,,,,,,,,,,,,, -183100,3.6120923,2.904223,,,,,,,,,,,,,, -183200,3.2519941,1.2067927,,,,,,,,,,,,,, -183300,3.1104016,2.2857857,,,,,,,,,,,,,, -183400,2.9697213,1.7579794,,,,,,,,,,,,,, -183500,3.1542535,1.6461238,,,,,,,,,,,,,, -183600,3.0866811,1.2472904,,,,,,,,,,,,,, -183603,,,0.8871288895606995,0.4207266271114349,0.7800599932670593,0.855748176574707,50000.0,0.6643000245094299,1.4516526460647583,10000.0,81543.79722166061,87804.11721277237,81543.79722166061,6240.15052652359,10.76271915435791,0.0 -183700,2.8607762,1.7253549,,,,,,,,,,,,,, -183800,3.333799,2.602024,,,,,,,,,,,,,, -183900,2.9800634,1.2144754,,,,,,,,,,,,,, -184000,4.0933437,3.2349937,,,,,,,,,,,,,, -184100,3.5562866,2.881312,,,,,,,,,,,,,, -184200,3.1462595,1.2036343,,,,,,,,,,,,,, -184300,3.0602782,1.0927396,,,,,,,,,,,,,, -184400,2.9702547,2.4826963,,,,,,,,,,,,,, -184500,3.298174,2.1491373,,,,,,,,,,,,,, -184549,,,0.8898437023162842,0.4142662286758423,0.7801600098609924,0.8554735779762268,50000.0,0.663800060749054,1.4520243406295776,10000.0,81963.94086170197,88257.15979456902,81963.94086170197,6272.948943138123,10.815591096878052,0.0 -184600,3.0629072,1.1872387,,,,,,,,,,,,,, -184700,3.0082424,1.0960729,,,,,,,,,,,,,, -184800,4.002882,1.1027161,,,,,,,,,,,,,, -184900,3.7966616,3.3027534,,,,,,,,,,,,,, -185000,3.3307016,3.020462,,,,,,,,,,,,,, -185100,3.4432068,2.897522,,,,,,,,,,,,,, -185200,3.279467,1.1351215,,,,,,,,,,,,,, -185300,3.0188317,1.1605716,,,,,,,,,,,,,, -185400,3.020704,2.4645605,,,,,,,,,,,,,, -185492,,,0.8886523246765137,0.4187787473201751,0.7800799608230591,0.8548168540000916,50000.0,0.664400041103363,1.4500694274902344,10000.0,82384.03376245499,88710.46589899063,82384.03376245499,6306.051533937454,10.878965616226196,0.0 -185500,5.0701756,1.4863701,,,,,,,,,,,,,, -185600,2.998792,1.7481164,,,,,,,,,,,,,, -185700,3.069875,1.7759413,,,,,,,,,,,,,, -185800,3.1191459,2.5165718,,,,,,,,,,,,,, -185900,3.0808876,1.1050065,,,,,,,,,,,,,, -186000,3.7297382,3.0793552,,,,,,,,,,,,,, -186100,3.0756357,2.7611198,,,,,,,,,,,,,, -186200,3.8187191,2.7642372,,,,,,,,,,,,,, -186300,3.0087373,1.3162686,,,,,,,,,,,,,, -186400,2.992211,2.0314221,,,,,,,,,,,,,, -186437,,,0.8884961009025574,0.418078750371933,0.7803399562835693,0.8548141121864319,50000.0,0.6648000478744507,1.450608491897583,10000.0,82804.15239548683,89166.71850991249,82804.15239548683,6342.082160711288,10.93463397026062,0.0 -186500,3.1188552,1.1564648,,,,,,,,,,,,,, -186600,3.9519472,2.3247175,,,,,,,,,,,,,, -186700,3.1332698,1.125345,,,,,,,,,,,,,, -186800,3.1780677,1.1665893,,,,,,,,,,,,,, -186900,3.5897696,3.1938453,,,,,,,,,,,,,, -187000,2.8477213,1.9711189,,,,,,,,,,,,,, -187100,3.1376154,1.1429276,,,,,,,,,,,,,, -187200,3.0532267,2.4725595,,,,,,,,,,,,,, -187300,3.1305068,1.2062056,,,,,,,,,,,,,, -187384,,,0.8875585794448853,0.4143358469009399,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,83224.0760512352,89624.2609269619,83224.0760512352,6379.588619709015,10.998920679092407,0.0 -187400,3.805293,3.3130505,,,,,,,,,,,,,, -187500,4.198159,3.2576973,,,,,,,,,,,,,, -187600,3.2211785,1.1187518,,,,,,,,,,,,,, -187700,3.1997259,1.8968103,,,,,,,,,,,,,, -187800,3.1053963,2.8122666,,,,,,,,,,,,,, -187900,3.1269715,1.031711,,,,,,,,,,,,,, -188000,4.1573596,2.9843097,,,,,,,,,,,,,, -188100,3.3200684,2.7294512,,,,,,,,,,,,,, -188200,3.8158164,3.3938794,,,,,,,,,,,,,, -188300,3.055643,1.1456386,,,,,,,,,,,,,, -188327,,,0.8861523270606995,0.4194840490818023,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,83644.06291556358,90076.51022052763,83644.06291556358,6411.748934745789,11.053326845169067,0.0 -188400,3.1873975,2.0974712,,,,,,,,,,,,,, -188500,3.262798,2.220341,,,,,,,,,,,,,, -188600,2.9332974,1.0713323,,,,,,,,,,,,,, -188700,3.3714242,2.973319,,,,,,,,,,,,,, -188800,2.8862226,1.9709717,,,,,,,,,,,,,, -188900,3.5353208,1.0427948,,,,,,,,,,,,,, -189000,3.062964,1.9001633,,,,,,,,,,,,,, -189100,3.8516114,2.832551,,,,,,,,,,,,,, -189200,3.063284,2.824517,,,,,,,,,,,,,, -189270,,,0.885058581829071,0.4213558435440063,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,84064.30104327202,90530.59429717064,84064.30104327202,6445.480100631714,11.119504451751707,0.0 -189300,3.5977445,3.0776381,,,,,,,,,,,,,, -189400,3.3246932,1.4980068,,,,,,,,,,,,,, -189500,2.98721,1.026208,,,,,,,,,,,,,, -189600,2.931271,2.4505868,,,,,,,,,,,,,, -189700,3.2878973,1.1372774,,,,,,,,,,,,,, -189800,3.1096458,2.248611,,,,,,,,,,,,,, -189900,2.935957,1.3786352,,,,,,,,,,,,,, -190000,3.1292884,2.3216898,,,,,,,,,,,,,, -190100,3.0440638,2.519177,,,,,,,,,,,,,, -190200,3.1730092,1.2483237,,,,,,,,,,,,,, -190213,,,0.8882812261581421,0.4192114472389221,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,84484.19007110596,90982.77915644646,84484.19007110596,6477.661994457245,11.18522047996521,0.0 -190300,3.137863,1.0782868,,,,,,,,,,,,,, -190400,3.3841977,1.4213622,,,,,,,,,,,,,, -190500,3.3674116,1.0914919,,,,,,,,,,,,,, -190600,3.0699952,1.306435,,,,,,,,,,,,,, -190700,3.02789,1.0702859,,,,,,,,,,,,,, -190800,3.106724,1.285416,,,,,,,,,,,,,, -190900,3.1451156,2.5139003,,,,,,,,,,,,,, -191000,3.335918,1.438863,,,,,,,,,,,,,, -191100,3.0304909,1.2227744,,,,,,,,,,,,,, -191158,,,0.8881054520606995,0.4155466854572296,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,84904.22356843948,91434.7952284813,84904.22356843948,6509.533877134323,11.248058319091797,0.0 -191200,2.992095,1.0612054,,,,,,,,,,,,,, -191300,3.3047488,1.0880616,,,,,,,,,,,,,, -191400,3.2255535,1.5888885,,,,,,,,,,,,,, -191500,2.87463,1.0390023,,,,,,,,,,,,,, -191600,3.0646942,1.2669406,,,,,,,,,,,,,, -191700,3.2134168,2.5263207,,,,,,,,,,,,,, -191800,3.029248,2.29318,,,,,,,,,,,,,, -191900,2.91557,1.2568319,,,,,,,,,,,,,, -192000,3.552851,2.421825,,,,,,,,,,,,,, -192100,3.1569395,1.1922402,,,,,,,,,,,,,, -192103,,,0.8872460722923279,0.4198216497898102,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,85324.30487179756,91886.03667712212,85324.30487179756,6540.580693721771,11.31272554397583,0.0 -192200,3.7054036,3.2120078,,,,,,,,,,,,,, -192300,3.1691256,1.9191399,,,,,,,,,,,,,, -192400,3.19878,1.0961158,,,,,,,,,,,,,, -192500,3.1889246,1.1700996,,,,,,,,,,,,,, -192600,3.2457688,2.8679821,,,,,,,,,,,,,, -192700,2.99739,1.7372242,,,,,,,,,,,,,, -192800,3.0887341,2.5353074,,,,,,,,,,,,,, -192900,3.655174,3.247873,,,,,,,,,,,,,, -193000,4.127968,2.9163306,,,,,,,,,,,,,, -193050,,,0.8844921588897705,0.4295682907104492,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,85744.3128118515,92341.42679834366,85744.3128118515,6575.839470863342,11.38812780380249,0.0 -193100,3.805877,1.0949111,,,,,,,,,,,,,, -193200,2.9372914,1.0685084,,,,,,,,,,,,,, -193300,3.8884099,1.1028204,,,,,,,,,,,,,, -193400,3.239223,2.8143656,,,,,,,,,,,,,, -193500,2.8361099,1.3888446,,,,,,,,,,,,,, -193600,3.0465584,1.2009279,,,,,,,,,,,,,, -193700,3.27296,2.1308742,,,,,,,,,,,,,, -193800,3.463171,1.1121855,,,,,,,,,,,,,, -193900,3.3916008,1.1278311,,,,,,,,,,,,,, -193996,,,0.8874413967132568,0.4196164309978485,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,86164.21161937714,92792.05870342256,86164.21161937714,6606.46160697937,11.450967073440552,0.0 -194000,3.1250072,1.411762,,,,,,,,,,,,,, -194100,2.7642694,1.9856322,,,,,,,,,,,,,, -194200,3.0885768,1.2219558,,,,,,,,,,,,,, -194300,3.579951,3.0002255,,,,,,,,,,,,,, -194400,3.5939116,2.7614043,,,,,,,,,,,,,, -194500,3.0289993,1.7947662,,,,,,,,,,,,,, -194600,3.4539845,1.2367942,,,,,,,,,,,,,, -194700,3.1177845,1.0745927,,,,,,,,,,,,,, -194800,3.7795875,2.052566,,,,,,,,,,,,,, -194900,3.037672,1.2318258,,,,,,,,,,,,,, -194940,,,0.8882421851158142,0.4126248955726623,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,86584.47300457954,93248.01982617378,86584.47300457954,6642.044916152954,11.519468545913696,0.0 -195000,3.0146494,1.1350693,,,,,,,,,,,,,, -195100,3.1774788,1.217775,,,,,,,,,,,,,, -195200,3.0002942,1.1267481,,,,,,,,,,,,,, -195300,3.072622,2.217681,,,,,,,,,,,,,, -195400,2.8972135,1.6000164,,,,,,,,,,,,,, -195500,3.3165786,1.1003914,,,,,,,,,,,,,, -195600,3.6777842,3.120297,,,,,,,,,,,,,, -195700,3.2391999,2.1923113,,,,,,,,,,,,,, -195800,3.1222177,1.1133446,,,,,,,,,,,,,, -195885,,,0.8870898485183716,0.4183282852172851,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,87004.58102655411,93703.66821908952,87004.58102655411,6677.47377038002,11.58312463760376,0.0 -195900,3.1008239,1.8955399,,,,,,,,,,,,,, -196000,3.0603685,1.1089597,,,,,,,,,,,,,, -196100,3.3376267,2.733663,,,,,,,,,,,,,, -196200,3.1861815,2.6507897,,,,,,,,,,,,,, -196300,3.689105,3.202373,,,,,,,,,,,,,, -196400,3.3500285,2.3676076,,,,,,,,,,,,,, -196500,2.9547381,2.2060783,,,,,,,,,,,,,, -196600,3.9378495,3.0583608,,,,,,,,,,,,,, -196700,2.885372,1.0674846,,,,,,,,,,,,,, -196800,2.98044,1.3161767,,,,,,,,,,,,,, -196833,,,0.8863085508346558,0.4206987619400024,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,87424.61718916893,94156.83369517326,87424.61718916893,6710.491514205933,11.646705865859984,0.0 -196900,3.1226938,1.1593567,,,,,,,,,,,,,, -197000,2.944673,1.2885301,,,,,,,,,,,,,, -197100,2.872563,2.1458573,,,,,,,,,,,,,, -197200,2.877975,2.4179137,,,,,,,,,,,,,, -197300,3.4544559,2.900679,,,,,,,,,,,,,, -197400,3.9488442,3.2431142,,,,,,,,,,,,,, -197500,3.4164317,1.1446545,,,,,,,,,,,,,, -197600,2.9198759,1.9154692,,,,,,,,,,,,,, -197700,3.1675985,1.7085819,,,,,,,,,,,,,, -197775,,,0.8870312571525574,0.4229859709739685,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,87844.54752922058,94611.82145762444,87844.54752922058,6745.4324831962585,11.7142493724823,0.0 -197800,3.2289138,1.5447717,,,,,,,,,,,,,, -197900,3.0996401,1.1360697,,,,,,,,,,,,,, -198000,3.2646608,1.1176913,,,,,,,,,,,,,, -198100,5.233738,1.1935512,,,,,,,,,,,,,, -198200,3.6232553,2.861463,,,,,,,,,,,,,, -198300,3.049663,1.1462152,,,,,,,,,,,,,, -198400,3.407325,1.1279092,,,,,,,,,,,,,, -198500,3.4642794,1.5417454,,,,,,,,,,,,,, -198600,3.4001863,1.3028331,,,,,,,,,,,,,, -198700,3.0311036,1.3790958,,,,,,,,,,,,,, -198719,,,0.8882030844688416,0.4190461635589599,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,88264.76767849922,95066.36974191666,88264.76767849922,6779.643330097199,11.784050464630129,0.0 -198800,3.2470448,1.0882224,,,,,,,,,,,,,, -198900,2.8635116,1.0375707,,,,,,,,,,,,,, -199000,3.465111,2.2617385,,,,,,,,,,,,,, -199100,3.1856482,1.1811348,,,,,,,,,,,,,, -199200,3.2111506,1.8742576,,,,,,,,,,,,,, -199300,3.3543372,1.2597154,,,,,,,,,,,,,, -199400,3.138085,1.1078088,,,,,,,,,,,,,, -199500,3.1131537,1.3107357,,,,,,,,,,,,,, -199600,3.0652115,1.159286,,,,,,,,,,,,,, -199663,,,0.8855078220367432,0.4201179146766662,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,88684.78815793991,95517.46187448502,88684.78815793991,6810.564759016037,11.886321544647217,0.0 -199700,3.0334828,1.2706003,,,,,,,,,,,,,, -199800,3.0259225,1.0759656,,,,,,,,,,,,,, -199900,2.8076181,1.5025632,,,,,,,,,,,,,, -200000,3.1365716,1.214685,,,,,,,,,,,,,, -200100,3.0322342,1.082679,,,,,,,,,,,,,, -200200,4.0143824,3.3051794,,,,,,,,,,,,,, -200300,3.1777928,1.3844428,,,,,,,,,,,,,, -200400,3.1553288,1.4536413,,,,,,,,,,,,,, -200500,3.1124856,1.1938471,,,,,,,,,,,,,, -200600,3.1762064,2.7592854,,,,,,,,,,,,,, -200604,,,0.8885741829872131,0.4168844521045685,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,89105.10370564461,95976.39848995207,89105.10370564461,6849.069217681885,11.955017805099487,0.0 -200700,3.0208595,1.3117095,,,,,,,,,,,,,, -200800,2.907745,2.017188,,,,,,,,,,,,,, -200900,3.0547981,1.4890659,,,,,,,,,,,,,, -201000,2.840638,1.067419,,,,,,,,,,,,,, -201100,3.0454443,1.7686582,,,,,,,,,,,,,, -201200,3.519093,1.957755,,,,,,,,,,,,,, -201300,3.5057423,3.1842287,,,,,,,,,,,,,, -201400,2.9273295,1.7005509,,,,,,,,,,,,,, -201500,2.9584188,1.1356539,,,,,,,,,,,,,, -201549,,,0.8887695074081421,0.4169636070728302,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,89525.00794053078,96428.4326581955,89525.00794053078,6881.090431928635,12.015576601028442,0.0 -201600,2.904252,1.336281,,,,,,,,,,,,,, -201700,3.0294862,1.15503,,,,,,,,,,,,,, -201800,3.0621963,1.1287552,,,,,,,,,,,,,, -201900,3.0800197,1.1085427,,,,,,,,,,,,,, -202000,3.357253,1.1826593,,,,,,,,,,,,,, -202100,3.7862964,3.1990187,,,,,,,,,,,,,, -202200,2.9246674,1.2151077,,,,,,,,,,,,,, -202300,2.9891572,1.9990684,,,,,,,,,,,,,, -202400,3.0658038,1.1940999,,,,,,,,,,,,,, -202495,,,0.8854491710662842,0.4246847927570343,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,89945.28154015541,96885.9902985096,89945.28154015541,6918.256211996079,12.085902690887451,0.0 -202500,3.174206,1.1340015,,,,,,,,,,,,,, -202600,3.07334,1.1837587,,,,,,,,,,,,,, -202700,2.8731093,1.2445159,,,,,,,,,,,,,, -202800,3.2617621,1.5432786,,,,,,,,,,,,,, -202900,3.2880456,1.4315808,,,,,,,,,,,,,, -203000,5.2578526,2.078594,,,,,,,,,,,,,, -203100,3.391674,2.912081,,,,,,,,,,,,,, -203200,2.8338945,1.5410713,,,,,,,,,,,,,, -203300,3.0898042,1.1209557,,,,,,,,,,,,,, -203400,3.1363475,1.3253162,,,,,,,,,,,,,, -203440,,,0.8881250023841858,0.4221977889537811,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,90365.33030056952,97342.42331409454,90365.33030056952,6954.5250408649445,12.15333890914917,0.0 -203500,3.1780732,1.8391929,,,,,,,,,,,,,, -203600,3.4622602,2.1940396,,,,,,,,,,,,,, -203700,3.8353972,2.8797085,,,,,,,,,,,,,, -203800,2.9210813,1.1749681,,,,,,,,,,,,,, -203900,3.721363,1.246505,,,,,,,,,,,,,, -204000,3.2353988,1.5491437,,,,,,,,,,,,,, -204100,3.0837343,1.6721333,,,,,,,,,,,,,, -204200,3.1574855,1.2334424,,,,,,,,,,,,,, -204300,3.1984508,1.1155894,,,,,,,,,,,,,, -204383,,,0.8873046636581421,0.4180344343185425,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,90785.62109804152,97796.2846519947,90785.62109804152,6987.989857196808,12.211401224136353,0.0 -204400,2.862946,1.084322,,,,,,,,,,,,,, -204500,2.8900568,2.0840142,,,,,,,,,,,,,, -204600,3.1294336,1.2357086,,,,,,,,,,,,,, -204700,2.955204,1.9459713,,,,,,,,,,,,,, -204800,3.0758867,2.1826856,,,,,,,,,,,,,, -204900,3.236575,2.1296394,,,,,,,,,,,,,, -205000,3.9381902,1.1226158,,,,,,,,,,,,,, -205100,2.9290967,1.3094873,,,,,,,,,,,,,, -205200,3.7361412,2.8524978,,,,,,,,,,,,,, -205300,3.0820274,1.0292327,,,,,,,,,,,,,, -205328,,,0.8882616758346558,0.4138765633106231,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,91205.78978681564,98257.74850678444,91205.78978681564,7029.166851758957,12.281495094299316,0.0 -205400,2.914954,1.0968083,,,,,,,,,,,,,, -205500,3.2422717,2.013093,,,,,,,,,,,,,, -205600,3.0712886,1.093203,,,,,,,,,,,,,, -205700,3.5829623,1.2256455,,,,,,,,,,,,,, -205800,3.1467702,1.2342522,,,,,,,,,,,,,, -205900,3.1091619,1.6771059,,,,,,,,,,,,,, -206000,3.1736116,1.1617471,,,,,,,,,,,,,, -206100,3.3872442,3.1131754,,,,,,,,,,,,,, -206200,2.9258778,1.2603682,,,,,,,,,,,,,, -206275,,,0.8873828053474426,0.4200892448425293,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,91625.9358868599,98711.7869529724,91625.9358868599,7062.951898813248,12.340657472610474,0.0 -206300,3.3403282,2.8580158,,,,,,,,,,,,,, -206400,3.2965288,2.8341668,,,,,,,,,,,,,, -206500,3.1757414,1.0449433,,,,,,,,,,,,,, -206600,3.037206,1.061612,,,,,,,,,,,,,, -206700,3.1339228,2.6473851,,,,,,,,,,,,,, -206800,3.0347369,1.1047609,,,,,,,,,,,,,, -206900,3.1139858,1.1382577,,,,,,,,,,,,,, -207000,3.2831182,1.5686177,,,,,,,,,,,,,, -207100,3.62052,3.033097,,,,,,,,,,,,,, -207200,3.180032,1.2709682,,,,,,,,,,,,,, -207220,,,0.8865429759025574,0.4198354482650757,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,92046.22657752036,99164.57158637048,92046.22657752036,7095.32580947876,12.412407398223875,0.0 -207300,3.000302,1.6602876,,,,,,,,,,,,,, -207400,3.022585,2.4944596,,,,,,,,,,,,,, -207500,3.3651855,0.91589016,,,,,,,,,,,,,, -207600,3.2687457,1.1084428,,,,,,,,,,,,,, -207700,2.9951622,2.4341488,,,,,,,,,,,,,, -207800,3.177199,1.4255035,,,,,,,,,,,,,, -207900,3.931516,3.2886214,,,,,,,,,,,,,, -208000,3.1958184,1.5860747,,,,,,,,,,,,,, -208100,2.9706562,1.1383322,,,,,,,,,,,,,, -208166,,,0.8901171684265137,0.4154210090637207,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,92466.52595090866,99624.68998265266,92466.52595090866,7135.026664972305,12.482912302017212,0.0 -208200,3.1420643,2.1698604,,,,,,,,,,,,,, -208300,4.5852585,2.9842315,,,,,,,,,,,,,, -208400,3.397518,1.1280884,,,,,,,,,,,,,, -208500,3.1860893,2.2117999,,,,,,,,,,,,,, -208600,3.5356324,3.162828,,,,,,,,,,,,,, -208700,2.9843519,1.095981,,,,,,,,,,,,,, -208800,3.3819191,1.1183244,,,,,,,,,,,,,, -208900,3.5471814,3.1760275,,,,,,,,,,,,,, -209000,3.5470512,3.0405707,,,,,,,,,,,,,, -209100,2.9341617,1.5513616,,,,,,,,,,,,,, -209115,,,0.8893554210662842,0.4160081148147583,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,92886.82992863657,100079.15053081512,92886.82992863657,7169.079889535904,12.53828740119934,0.0 -209200,3.1447768,1.789139,,,,,,,,,,,,,, -209300,3.3571887,1.1228827,,,,,,,,,,,,,, -209400,3.7920272,3.2429817,,,,,,,,,,,,,, -209500,3.1174562,1.1778646,,,,,,,,,,,,,, -209600,3.9835665,3.3191864,,,,,,,,,,,,,, -209700,3.0339766,1.1854936,,,,,,,,,,,,,, -209800,3.3432095,1.0371604,,,,,,,,,,,,,, -209900,3.0877776,1.1099422,,,,,,,,,,,,,, -210000,3.0877767,2.309256,,,,,,,,,,,,,, -210059,,,0.8882812261581421,0.4137320518493652,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,93306.80130004884,100535.8836786747,93306.80130004884,7205.7183492183685,12.614384651184082,0.0 -210100,3.0382814,1.0868456,,,,,,,,,,,,,, -210200,2.8443878,1.0826181,,,,,,,,,,,,,, -210300,3.1769657,1.0375063,,,,,,,,,,,,,, -210400,3.105965,2.316758,,,,,,,,,,,,,, -210500,3.8062549,3.1879532,,,,,,,,,,,,,, -210600,3.0047023,1.0842494,,,,,,,,,,,,,, -210700,3.1458576,1.655165,,,,,,,,,,,,,, -210800,3.134564,1.2786084,,,,,,,,,,,,,, -210900,3.2156549,1.1304657,,,,,,,,,,,,,, -211000,3.4581459,1.076341,,,,,,,,,,,,,, -211003,,,0.8879687190055847,0.4160420000553131,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,93726.69717526436,100995.70543003082,93726.69717526436,7245.537149429321,12.672444820404053,0.0 -211100,2.9330468,1.8090056,,,,,,,,,,,,,, -211200,3.2996953,1.1106244,,,,,,,,,,,,,, -211300,3.1504498,1.1959987,,,,,,,,,,,,,, -211400,2.8570256,1.1297705,,,,,,,,,,,,,, -211500,3.1646965,1.1621248,,,,,,,,,,,,,, -211600,4.1481094,3.2145357,,,,,,,,,,,,,, -211700,3.192699,1.3510197,,,,,,,,,,,,,, -211800,2.9817145,1.7851181,,,,,,,,,,,,,, -211900,3.8239458,3.1512492,,,,,,,,,,,,,, -211947,,,0.8864062428474426,0.4183524847030639,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,94146.89120841026,101449.40781760216,94146.89120841026,7278.940035820007,12.730435132980348,0.0 -212000,3.1577334,1.1606959,,,,,,,,,,,,,, -212100,3.7573347,3.220232,,,,,,,,,,,,,, -212200,3.1672537,1.0887909,,,,,,,,,,,,,, -212300,3.027212,1.0790081,,,,,,,,,,,,,, -212400,3.0334253,1.35771,,,,,,,,,,,,,, -212500,3.3500197,1.3396882,,,,,,,,,,,,,, -212600,3.2566502,1.8280145,,,,,,,,,,,,,, -212700,3.3092322,1.1850743,,,,,,,,,,,,,, -212800,3.1354551,2.7314122,,,,,,,,,,,,,, -212888,,,0.8856250047683716,0.4198617041110992,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,94566.94662880898,101910.72256612778,94566.94662880898,7320.083385229111,12.798975467681885,0.0 -212900,3.226547,2.830437,,,,,,,,,,,,,, -213000,3.0376036,1.0910484,,,,,,,,,,,,,, -213100,3.5108945,1.0551875,,,,,,,,,,,,,, -213200,2.924917,1.1113373,,,,,,,,,,,,,, -213300,3.000714,1.7553142,,,,,,,,,,,,,, -213400,3.6686604,3.2219248,,,,,,,,,,,,,, -213500,2.978556,1.9677929,,,,,,,,,,,,,, -213600,3.4979966,1.1526641,,,,,,,,,,,,,, -213700,3.0100632,1.242773,,,,,,,,,,,,,, -213800,3.1830165,1.8714588,,,,,,,,,,,,,, -213835,,,0.8876562118530273,0.4190992712974548,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,94986.92520284653,102362.89806699751,94986.92520284653,7352.1744792461395,12.85687255859375,0.0 -213900,2.936473,1.098801,,,,,,,,,,,,,, -214000,3.138969,1.992922,,,,,,,,,,,,,, -214100,3.2088535,1.7752213,,,,,,,,,,,,,, -214200,2.9973166,1.1536506,,,,,,,,,,,,,, -214300,3.5665793,1.1916301,,,,,,,,,,,,,, -214400,3.7338083,2.7538178,,,,,,,,,,,,,, -214500,3.1729875,2.3656006,,,,,,,,,,,,,, -214600,3.2691836,1.1709385,,,,,,,,,,,,,, -214700,3.260815,1.1969391,,,,,,,,,,,,,, -214780,,,0.8873242139816284,0.4169353246688843,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,95407.19551420212,102821.45707058908,95407.19551420212,7390.342691898346,12.929483413696287,0.0 -214800,3.7358472,3.1506798,,,,,,,,,,,,,, -214900,3.0035365,1.1560868,,,,,,,,,,,,,, -215000,3.176918,1.0998924,,,,,,,,,,,,,, -215100,3.8368092,3.3910913,,,,,,,,,,,,,, -215200,2.9670892,1.2203977,,,,,,,,,,,,,, -215300,2.9507394,1.0545696,,,,,,,,,,,,,, -215400,3.00379,1.072839,,,,,,,,,,,,,, -215500,3.3931184,2.7393653,,,,,,,,,,,,,, -215600,3.2235765,1.1487861,,,,,,,,,,,,,, -215700,3.1783078,1.1371431,,,,,,,,,,,,,, -215728,,,0.8858398199081421,0.4253257513046264,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,95827.43306875227,103280.76931118964,95827.43306875227,7429.297716617584,13.001351118087769,0.0 -215800,3.602558,1.7399894,,,,,,,,,,,,,, -215900,3.0161648,1.181986,,,,,,,,,,,,,, -216000,3.2293828,1.3923886,,,,,,,,,,,,,, -216100,3.119017,1.1647319,,,,,,,,,,,,,, -216200,3.1691291,1.1690038,,,,,,,,,,,,,, -216300,2.934224,1.4634236,,,,,,,,,,,,,, -216400,3.0866637,1.1587973,,,,,,,,,,,,,, -216500,3.201337,2.5436447,,,,,,,,,,,,,, -216600,3.3162298,1.165772,,,,,,,,,,,,,, -216675,,,0.8865820169448853,0.4243214726448059,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,96247.62993168832,103732.32486963272,96247.62993168832,7460.550596475601,13.059123516082764,0.0 -216700,3.5869932,2.8346705,,,,,,,,,,,,,, -216800,3.5387933,3.1025705,,,,,,,,,,,,,, -216900,2.9712098,2.0226474,,,,,,,,,,,,,, -217000,3.388447,1.7062991,,,,,,,,,,,,,, -217100,3.4001284,1.9234388,,,,,,,,,,,,,, -217200,3.7258186,3.0159922,,,,,,,,,,,,,, -217300,2.9125528,1.1579492,,,,,,,,,,,,,, -217400,3.1888206,1.554979,,,,,,,,,,,,,, -217500,2.9706287,1.0723231,,,,,,,,,,,,,, -217600,3.1972837,1.455039,,,,,,,,,,,,,, -217614,,,0.8871874809265137,0.4233055114746094,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,96667.72405338287,104188.11450624466,96667.72405338287,7496.124892711639,13.133168935775757,0.0 -217700,3.3332582,1.1952348,,,,,,,,,,,,,, -217800,3.6242416,3.200534,,,,,,,,,,,,,, -217900,2.9420388,2.414922,,,,,,,,,,,,,, -218000,3.1562486,2.5204303,,,,,,,,,,,,,, -218100,3.2507408,1.2573452,,,,,,,,,,,,,, -218200,3.405894,2.6868908,,,,,,,,,,,,,, -218300,2.98899,1.1379977,,,,,,,,,,,,,, -218400,2.9265208,1.061486,,,,,,,,,,,,,, -218500,3.040797,2.6652718,,,,,,,,,,,,,, -218559,,,0.8890820145606995,0.4117428064346313,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,97087.75162792206,104641.87342977524,97087.75162792206,7529.747981071472,13.193522214889526,0.0 -218600,2.9122384,1.8608036,,,,,,,,,,,,,, -218700,2.953427,1.0566272,,,,,,,,,,,,,, -218800,3.2784066,1.1973488,,,,,,,,,,,,,, -218900,3.858575,2.4916975,,,,,,,,,,,,,, -219000,3.0756106,1.1235646,,,,,,,,,,,,,, -219100,3.289843,1.1581028,,,,,,,,,,,,,, -219200,3.0992012,2.014864,,,,,,,,,,,,,, -219300,3.120254,2.351827,,,,,,,,,,,,,, -219400,2.899183,1.9420214,,,,,,,,,,,,,, -219500,3.3493311,1.1780633,,,,,,,,,,,,,, -219501,,,0.8858202695846558,0.421209454536438,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,97507.7348575592,105093.92755794524,97507.7348575592,7561.695729017258,13.267573595046995,0.0 -219600,3.0324042,1.9501648,,,,,,,,,,,,,, -219700,3.2660668,1.2854277,,,,,,,,,,,,,, -219800,3.6959686,2.9712152,,,,,,,,,,,,,, -219900,3.0132246,1.034358,,,,,,,,,,,,,, -220000,3.1687841,2.5444427,,,,,,,,,,,,,, -220100,2.9781199,1.128649,,,,,,,,,,,,,, -220200,3.090377,1.167109,,,,,,,,,,,,,, -220300,3.1760247,2.1253057,,,,,,,,,,,,,, -220400,2.9792264,1.4639313,,,,,,,,,,,,,, -220445,,,0.8859570026397705,0.4233082830905914,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,97927.69353795052,105553.43908929823,97927.69353795052,7601.125306606293,13.342846632003784,0.0 -220500,3.1269188,1.747123,,,,,,,,,,,,,, -220600,3.228506,2.9080205,,,,,,,,,,,,,, -220700,3.323142,1.0978518,,,,,,,,,,,,,, -220800,3.1101255,1.1365422,,,,,,,,,,,,,, -220900,3.1248739,1.7285373,,,,,,,,,,,,,, -221000,3.9901311,3.0042477,,,,,,,,,,,,,, -221100,2.9686534,1.1433961,,,,,,,,,,,,,, -221200,2.7127204,1.0519375,,,,,,,,,,,,,, -221300,3.231558,1.1156497,,,,,,,,,,,,,, -221391,,,0.8848632574081421,0.4247506260871887,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,98347.8355793953,106021.63544940948,98347.8355793953,7649.0584235191345,13.41545057296753,0.0 -221400,4.3674955,3.229063,,,,,,,,,,,,,, -221500,3.2128947,1.1548272,,,,,,,,,,,,,, -221600,3.223628,2.2725575,,,,,,,,,,,,,, -221700,3.153857,1.1809003,,,,,,,,,,,,,, -221800,3.1277933,1.3077396,,,,,,,,,,,,,, -221900,2.8585682,1.9266281,,,,,,,,,,,,,, -222000,3.519976,2.9534798,,,,,,,,,,,,,, -222100,3.4182374,2.8083382,,,,,,,,,,,,,, -222200,3.613972,1.1237376,,,,,,,,,,,,,, -222300,3.2740881,2.709376,,,,,,,,,,,,,, -222339,,,0.8887499570846558,0.4126403629779815,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,98768.19242358208,106474.71693491936,98768.19242358208,7681.673318862915,13.478192329406738,0.0 -222400,3.40859,1.2624227,,,,,,,,,,,,,, -222500,3.0337715,1.2109256,,,,,,,,,,,,,, -222600,3.5123994,1.1553397,,,,,,,,,,,,,, -222700,3.2770753,1.0763584,,,,,,,,,,,,,, -222800,3.0660732,1.1994148,,,,,,,,,,,,,, -222900,2.8662896,1.0315174,,,,,,,,,,,,,, -223000,3.046098,1.0589998,,,,,,,,,,,,,, -223100,3.131722,1.1026219,,,,,,,,,,,,,, -223200,3.4577024,3.063398,,,,,,,,,,,,,, -223282,,,0.8872851133346558,0.4197381734848022,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,99188.47766494752,106941.02724575996,99188.47766494752,7727.574855804443,13.553714990615845,0.0 -223300,3.0335162,1.5139924,,,,,,,,,,,,,, -223400,3.239661,2.577908,,,,,,,,,,,,,, -223500,3.9059072,3.1982305,,,,,,,,,,,,,, -223600,3.1898477,1.1290165,,,,,,,,,,,,,, -223700,3.136837,1.0933046,,,,,,,,,,,,,, -223800,2.880893,1.0081363,,,,,,,,,,,,,, -223900,3.3737519,1.1049304,,,,,,,,,,,,,, -224000,3.1751287,2.7267094,,,,,,,,,,,,,, -224100,3.0823236,1.1574653,,,,,,,,,,,,,, -224200,3.3597803,1.722202,,,,,,,,,,,,,, -224226,,,0.8883593678474426,0.4179112017154693,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,99608.62461566924,107395.39229464532,99608.62461566924,7761.677759170532,13.621427297592165,0.0 -224300,3.157406,1.1477685,,,,,,,,,,,,,, -224400,3.2008448,1.1841056,,,,,,,,,,,,,, -224500,2.91946,1.7634048,,,,,,,,,,,,,, -224600,2.9235225,1.1055727,,,,,,,,,,,,,, -224700,3.5809593,1.1857486,,,,,,,,,,,,,, -224800,3.2486837,2.5833356,,,,,,,,,,,,,, -224900,2.9462292,1.1820128,,,,,,,,,,,,,, -225000,3.1502504,1.9837265,,,,,,,,,,,,,, -225100,3.052244,1.0254096,,,,,,,,,,,,,, -225168,,,0.8888476490974426,0.4202959835529327,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,100028.8138911724,107850.28579640388,100028.8138911724,7796.263030529022,13.69261384010315,0.0 -225200,2.949042,1.195762,,,,,,,,,,,,,, -225300,2.9021933,1.4690844,,,,,,,,,,,,,, -225400,3.0334723,1.2836804,,,,,,,,,,,,,, -225500,3.218435,1.0309123,,,,,,,,,,,,,, -225600,2.8680186,1.627086,,,,,,,,,,,,,, -225700,2.9499862,1.2816774,,,,,,,,,,,,,, -225800,3.079021,1.0922966,,,,,,,,,,,,,, -225900,3.630815,2.7670121,,,,,,,,,,,,,, -226000,3.3782532,2.824876,,,,,,,,,,,,,, -226100,3.6274567,2.8156176,,,,,,,,,,,,,, -226114,,,0.8854687213897705,0.4251309037208557,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,100449.1869559288,108306.42710208891,100449.1869559288,7831.920604705811,13.755226135253906,0.0 -226200,2.8553293,1.6397192,,,,,,,,,,,,,, -226300,3.2788858,1.159844,,,,,,,,,,,,,, -226400,3.0806684,1.082591,,,,,,,,,,,,,, -226500,2.9811683,1.1261783,,,,,,,,,,,,,, -226600,3.8064706,3.205667,,,,,,,,,,,,,, -226700,3.9354594,3.1620388,,,,,,,,,,,,,, -226800,3.229923,2.870908,,,,,,,,,,,,,, -226900,3.0171902,1.2197429,,,,,,,,,,,,,, -227000,3.7651777,3.235305,,,,,,,,,,,,,, -227056,,,0.8884961009025574,0.4193135797977447,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,100869.26940894128,108766.48868989944,100869.26940894128,7871.779316186905,13.82857346534729,0.0 -227100,2.9028504,1.7358785,,,,,,,,,,,,,, -227200,3.0700195,1.2079971,,,,,,,,,,,,,, -227300,3.7320259,1.9683253,,,,,,,,,,,,,, -227400,3.1221774,1.1807369,,,,,,,,,,,,,, -227500,3.438349,2.639367,,,,,,,,,,,,,, -227600,2.7870202,1.0864599,,,,,,,,,,,,,, -227700,3.3804102,2.9443023,,,,,,,,,,,,,, -227800,2.9163742,1.189134,,,,,,,,,,,,,, -227900,2.9470177,1.1708378,,,,,,,,,,,,,, -228000,3.3452601,1.143228,,,,,,,,,,,,,, -228002,,,0.8871288895606995,0.4178809225559234,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,101289.16657400133,109218.66709923744,101289.16657400133,7903.954140663147,13.88698935508728,0.0 -228100,3.3768642,1.2677138,,,,,,,,,,,,,, -228200,3.3272874,1.1934346,,,,,,,,,,,,,, -228300,3.1379385,1.4032735,,,,,,,,,,,,,, -228400,3.0137186,1.1329045,,,,,,,,,,,,,, -228500,3.0071914,1.1008518,,,,,,,,,,,,,, -228600,3.146386,1.4965379,,,,,,,,,,,,,, -228700,3.0477874,1.1027613,,,,,,,,,,,,,, -228800,3.0069687,1.0219765,,,,,,,,,,,,,, -228900,3.911931,3.1667438,,,,,,,,,,,,,, -228945,,,0.8885351419448853,0.4126009345054626,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,101709.11290287971,109674.52127766608,101709.11290287971,7939.73398065567,13.966781616210938,0.0 -229000,3.082459,1.1249841,,,,,,,,,,,,,, -229100,2.848978,1.1714745,,,,,,,,,,,,,, -229200,3.2730527,2.7711508,,,,,,,,,,,,,, -229300,3.063968,1.3972608,,,,,,,,,,,,,, -229400,4.438338,2.6325264,,,,,,,,,,,,,, -229500,3.0714269,1.3416011,,,,,,,,,,,,,, -229600,3.0447106,1.0531694,,,,,,,,,,,,,, -229700,3.0191126,1.0886751,,,,,,,,,,,,,, -229800,3.1514966,2.102672,,,,,,,,,,,,,, -229889,,,0.8883984088897705,0.4152209758758545,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,102129.02837610243,110137.63670706748,102129.02837610243,7982.809792280197,14.043018579483032,0.0 -229900,2.9755995,1.1694987,,,,,,,,,,,,,, -230000,3.3143137,2.2559733,,,,,,,,,,,,,, -230100,3.132626,1.1608281,,,,,,,,,,,,,, -230200,3.0937054,1.6540514,,,,,,,,,,,,,, -230300,3.1781623,2.6200266,,,,,,,,,,,,,, -230400,3.3801942,1.3314754,,,,,,,,,,,,,, -230500,3.4177778,2.948639,,,,,,,,,,,,,, -230600,3.2215064,1.2419207,,,,,,,,,,,,,, -230700,3.1385603,2.2763977,,,,,,,,,,,,,, -230800,3.1810653,1.1989217,,,,,,,,,,,,,, -230836,,,0.8867382407188416,0.4209681153297424,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,102549.27608656885,110589.76342749596,102549.27608656885,8014.580620288849,14.102319240570068,0.0 -230900,3.0280561,1.3260688,,,,,,,,,,,,,, -231000,3.0792484,1.2938733,,,,,,,,,,,,,, -231100,3.21031,1.1017494,,,,,,,,,,,,,, -231200,3.7952716,3.2112246,,,,,,,,,,,,,, -231300,3.0324628,1.9241484,,,,,,,,,,,,,, -231400,3.2927065,2.5378048,,,,,,,,,,,,,, -231500,2.860002,1.2494662,,,,,,,,,,,,,, -231600,3.300763,1.1160301,,,,,,,,,,,,,, -231700,2.9443598,1.0534213,,,,,,,,,,,,,, -231779,,,0.8880859017372131,0.424091637134552,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,102969.37177467346,111048.59785699844,102969.37177467346,8053.198905706406,14.175431966781616,0.0 -231800,3.2859774,1.2077516,,,,,,,,,,,,,, -231900,3.241973,2.2807415,,,,,,,,,,,,,, -232000,2.9054148,1.1663167,,,,,,,,,,,,,, -232100,3.1198733,2.2581682,,,,,,,,,,,,,, -232200,3.3694956,1.0881162,,,,,,,,,,,,,, -232300,2.9731128,1.4644487,,,,,,,,,,,,,, -232400,3.1690536,1.1343305,,,,,,,,,,,,,, -232500,3.490222,2.944452,,,,,,,,,,,,,, -232600,3.0680778,2.1750813,,,,,,,,,,,,,, -232700,3.0008023,1.0630648,,,,,,,,,,,,,, -232726,,,0.8880859017372131,0.4161204695701599,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,103389.27419066428,111506.15778017044,103389.27419066428,8090.729656934738,14.253586530685425,0.0 -232800,3.1932008,1.5376292,,,,,,,,,,,,,, -232900,3.2165809,1.2499944,,,,,,,,,,,,,, -233000,3.3098533,1.7348388,,,,,,,,,,,,,, -233100,3.8413007,2.9328122,,,,,,,,,,,,,, -233200,3.315668,1.0378999,,,,,,,,,,,,,, -233300,2.9599502,1.1014736,,,,,,,,,,,,,, -233400,2.898374,1.1394563,,,,,,,,,,,,,, -233500,3.1433413,1.12848,,,,,,,,,,,,,, -233600,3.2214463,1.1744254,,,,,,,,,,,,,, -233671,,,0.889941394329071,0.4111701250076294,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,103809.17875909804,111965.66636037828,103809.17875909804,8130.175797224045,14.363630533218384,0.0 -233700,3.3294587,1.1852708,,,,,,,,,,,,,, -233800,3.8494956,3.202676,,,,,,,,,,,,,, -233900,3.0317724,1.1075621,,,,,,,,,,,,,, -234000,2.965251,1.1195127,,,,,,,,,,,,,, -234100,3.1755426,1.0934972,,,,,,,,,,,,,, -234200,2.8446698,1.0138388,,,,,,,,,,,,,, -234300,3.0761602,1.1971726,,,,,,,,,,,,,, -234400,3.4967031,1.3061409,,,,,,,,,,,,,, -234500,3.1028762,1.1381228,,,,,,,,,,,,,, -234600,2.981055,1.3216009,,,,,,,,,,,,,, -234615,,,0.8862109184265137,0.4182306230068207,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,104229.44356608392,112420.273696661,104229.44356608392,8164.389766454697,14.44425344467163,0.0 -234700,3.9471374,2.9431942,,,,,,,,,,,,,, -234800,3.1072743,1.2687438,,,,,,,,,,,,,, -234900,3.0284636,1.1812389,,,,,,,,,,,,,, -235000,3.1759024,1.1879109,,,,,,,,,,,,,, -235100,3.2151704,1.1088668,,,,,,,,,,,,,, -235200,3.9002552,3.0902023,,,,,,,,,,,,,, -235300,2.79689,1.410946,,,,,,,,,,,,,, -235400,3.1892552,2.1798096,,,,,,,,,,,,,, -235500,3.1329994,1.2354393,,,,,,,,,,,,,, -235560,,,0.8863866925239563,0.4218202233314514,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,104649.33818554878,112883.30195975304,104649.33818554878,8207.39955830574,14.519267797470093,0.0 -235600,3.495166,1.9133917,,,,,,,,,,,,,, -235700,2.9279017,0.92807037,,,,,,,,,,,,,, -235800,3.0839186,1.1375626,,,,,,,,,,,,,, -235900,3.1465366,1.1956216,,,,,,,,,,,,,, -236000,2.98065,1.3507502,,,,,,,,,,,,,, -236100,3.654589,2.4473782,,,,,,,,,,,,,, -236200,3.4717002,2.8440838,,,,,,,,,,,,,, -236300,2.9608154,1.5837517,,,,,,,,,,,,,, -236400,3.2746513,1.1867099,,,,,,,,,,,,,, -236500,3.145596,1.0981246,,,,,,,,,,,,,, -236507,,,0.8854101300239563,0.4208202362060547,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,105069.57685279846,113342.91085219385,105069.57685279846,8246.647035121918,14.593485116958618,0.0 -236600,2.9158473,2.458553,,,,,,,,,,,,,, -236700,2.916929,1.4312243,,,,,,,,,,,,,, -236800,3.659854,3.0127728,,,,,,,,,,,,,, -236900,3.2490108,2.3081167,,,,,,,,,,,,,, -237000,3.280779,2.5726604,,,,,,,,,,,,,, -237100,3.3856587,2.0357194,,,,,,,,,,,,,, -237200,3.8523602,3.349821,,,,,,,,,,,,,, -237300,3.3998785,1.1703295,,,,,,,,,,,,,, -237400,3.8998523,3.298735,,,,,,,,,,,,,, -237453,,,0.8888476490974426,0.4130127131938934,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,105489.77764558792,113796.5252687931,105489.77764558792,8279.95104432106,14.655708074569702,0.0 -237500,2.8852885,1.376384,,,,,,,,,,,,,, -237600,2.8172896,1.0035582,,,,,,,,,,,,,, -237700,3.1185582,1.1028714,,,,,,,,,,,,,, -237800,3.356497,1.496626,,,,,,,,,,,,,, -237900,3.1591277,1.8984392,,,,,,,,,,,,,, -238000,3.109632,1.242741,,,,,,,,,,,,,, -238100,3.002904,1.0897179,,,,,,,,,,,,,, -238200,2.8648667,1.1070061,,,,,,,,,,,,,, -238300,2.9880006,1.3342452,,,,,,,,,,,,,, -238397,,,0.8869140148162842,0.4204769432544708,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,105909.77715063097,114252.5577852726,105909.77715063097,8315.858958244324,14.733111381530762,0.0 -238400,3.4111602,2.8151004,,,,,,,,,,,,,, -238500,3.8540514,3.2255745,,,,,,,,,,,,,, -238600,3.2538152,2.360108,,,,,,,,,,,,,, -238700,3.1533523,1.3040762,,,,,,,,,,,,,, -238800,3.2485585,2.378615,,,,,,,,,,,,,, -238900,3.3653915,1.6099929,,,,,,,,,,,,,, -239000,3.5015159,2.8787913,,,,,,,,,,,,,, -239100,2.9119136,1.5964863,,,,,,,,,,,,,, -239200,3.0615964,1.1144744,,,,,,,,,,,,,, -239300,3.1227283,2.1426663,,,,,,,,,,,,,, -239341,,,0.8874022960662842,0.4200328290462494,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,106329.69542336464,114713.22124695778,106329.69542336464,8356.481803894043,14.807372331619264,0.0 -239400,3.262671,2.0346742,,,,,,,,,,,,,, -239500,3.01211,1.1439922,,,,,,,,,,,,,, -239600,3.1291332,2.1362598,,,,,,,,,,,,,, -239700,4.143275,3.288003,,,,,,,,,,,,,, -239800,3.2118998,1.1775401,,,,,,,,,,,,,, -239900,3.9560091,3.2330978,,,,,,,,,,,,,, -240000,3.0042615,1.5934638,,,,,,,,,,,,,, -240100,2.9215658,1.0776117,,,,,,,,,,,,,, -240200,3.7953203,3.2540019,,,,,,,,,,,,,, -240287,,,0.8852929472923279,0.4266456067562103,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,106749.73273301125,115168.345505476,106749.73273301125,8391.442966938019,14.885765552520752,0.0 -240300,3.6216633,2.91415,,,,,,,,,,,,,, -240400,3.6577618,1.25221,,,,,,,,,,,,,, -240500,3.1360445,2.357363,,,,,,,,,,,,,, -240600,3.2399793,1.1101688,,,,,,,,,,,,,, -240700,3.6006153,3.1961114,,,,,,,,,,,,,, -240800,3.0450253,2.229477,,,,,,,,,,,,,, -240900,3.4335454,1.1996174,,,,,,,,,,,,,, -241000,3.056978,1.7539315,,,,,,,,,,,,,, -241100,3.8138623,3.1755364,,,,,,,,,,,,,, -241200,3.0690665,1.1117954,,,,,,,,,,,,,, -241233,,,0.8872656226158142,0.4201632738113403,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,107169.90454864502,115622.76381874084,107169.90454864502,8425.569839000702,14.957653284072876,0.0 -241300,3.3117225,1.2162051,,,,,,,,,,,,,, -241400,3.1332316,1.9904277,,,,,,,,,,,,,, -241500,3.1378536,1.162556,,,,,,,,,,,,,, -241600,3.2208626,1.146163,,,,,,,,,,,,,, -241700,3.0948503,1.2612686,,,,,,,,,,,,,, -241800,3.1306448,1.3660307,,,,,,,,,,,,,, -241900,5.725975,1.1378009,,,,,,,,,,,,,, -242000,2.9956603,1.1786547,,,,,,,,,,,,,, -242100,3.4712727,1.1160063,,,,,,,,,,,,,, -242178,,,0.887988269329071,0.4175548553466797,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,107590.16424393654,116082.39153313635,107590.16424393654,8464.812617301941,15.034671306610107,0.0 -242200,3.2404656,1.0692204,,,,,,,,,,,,,, -242300,3.1548104,1.1699334,,,,,,,,,,,,,, -242400,3.15701,1.108331,,,,,,,,,,,,,, -242500,3.8195026,1.2822084,,,,,,,,,,,,,, -242600,3.2679555,2.4853034,,,,,,,,,,,,,, -242700,3.261243,1.0444639,,,,,,,,,,,,,, -242800,3.0084274,1.8629656,,,,,,,,,,,,,, -242900,4.2320256,3.291584,,,,,,,,,,,,,, -243000,3.5298204,2.8759642,,,,,,,,,,,,,, -243100,2.8217962,1.4642102,,,,,,,,,,,,,, -243120,,,0.8871093392372131,0.4164575934410095,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,108010.23139071465,116545.69679760931,108010.23139071465,8507.924887418747,15.112300157546995,0.0 -243200,2.9996233,1.2076151,,,,,,,,,,,,,, -243300,3.7790387,3.28926,,,,,,,,,,,,,, -243400,2.8681207,1.1454897,,,,,,,,,,,,,, -243500,3.037593,2.6027422,,,,,,,,,,,,,, -243600,2.8502707,2.2267065,,,,,,,,,,,,,, -243700,3.1194184,2.2211506,,,,,,,,,,,,,, -243800,3.1887543,1.1826141,,,,,,,,,,,,,, -243900,3.2338126,1.2723733,,,,,,,,,,,,,, -244000,3.1180928,1.9131441,,,,,,,,,,,,,, -244061,,,0.8856835961341858,0.4228624999523163,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,108430.22557520866,116999.6195745468,108430.22557520866,8541.744990348816,15.172938346862791,0.0 -244100,3.0727603,1.5129298,,,,,,,,,,,,,, -244200,3.610639,2.60293,,,,,,,,,,,,,, -244300,3.0277736,2.1448202,,,,,,,,,,,,,, -244400,3.42099,3.0204885,,,,,,,,,,,,,, -244500,2.8836274,2.2348206,,,,,,,,,,,,,, -244600,3.1865516,1.1665355,,,,,,,,,,,,,, -244700,3.0093472,1.7195703,,,,,,,,,,,,,, -244800,2.8972304,2.186134,,,,,,,,,,,,,, -244900,2.936436,1.114077,,,,,,,,,,,,,, -245000,3.0238829,1.2067468,,,,,,,,,,,,,, -245008,,,0.8870702981948853,0.418248176574707,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,108850.48312044144,117460.33714318275,108850.48312044144,8582.07614517212,15.252984046936035,0.0 -245100,3.7327368,2.9691083,,,,,,,,,,,,,, -245200,2.8490047,1.7645261,,,,,,,,,,,,,, -245300,3.004579,1.4057345,,,,,,,,,,,,,, -245400,3.2426648,1.1253082,,,,,,,,,,,,,, -245500,3.079374,1.0842236,,,,,,,,,,,,,, -245600,3.1012287,1.1489779,,,,,,,,,,,,,, -245700,3.189589,1.1373966,,,,,,,,,,,,,, -245800,3.7292044,3.15455,,,,,,,,,,,,,, -245900,2.9424152,1.0674362,,,,,,,,,,,,,, -245951,,,0.8870898485183716,0.4204050004482269,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,109270.4581296444,117921.85386919975,109270.4581296444,8623.50287437439,15.32018232345581,0.0 -246000,3.2255805,1.5545793,,,,,,,,,,,,,, -246100,3.2822537,1.097684,,,,,,,,,,,,,, -246200,3.4485214,1.4472429,,,,,,,,,,,,,, -246300,3.9065208,3.2776046,,,,,,,,,,,,,, -246400,3.131952,1.1878989,,,,,,,,,,,,,, -246500,3.2362092,2.561144,,,,,,,,,,,,,, -246600,3.118997,1.227963,,,,,,,,,,,,,, -246700,3.0709481,1.3825246,,,,,,,,,,,,,, -246800,2.9272676,1.053794,,,,,,,,,,,,,, -246896,,,0.8881640434265137,0.4150146245956421,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,109690.64672994614,118375.8336148262,109690.64672994614,8657.179294109344,15.38704538345337,0.0 -246900,3.2711372,1.1227729,,,,,,,,,,,,,, -247000,3.0038335,1.5505726,,,,,,,,,,,,,, -247100,2.8620968,1.1410033,,,,,,,,,,,,,, -247200,3.075893,1.184932,,,,,,,,,,,,,, -247300,3.0519965,2.3035183,,,,,,,,,,,,,, -247400,3.1241465,2.3318446,,,,,,,,,,,,,, -247500,3.2000828,1.2072189,,,,,,,,,,,,,, -247600,2.9666376,1.1039326,,,,,,,,,,,,,, -247700,3.2767851,1.0792139,,,,,,,,,,,,,, -247800,3.0023985,1.4167519,,,,,,,,,,,,,, -247838,,,0.8864648342132568,0.4250794947147369,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,110110.56650781631,118843.29978728294,110110.56650781631,8704.600757598877,15.464092493057253,0.0 -247900,3.0911453,2.2607284,,,,,,,,,,,,,, -248000,3.400016,1.1234771,,,,,,,,,,,,,, -248100,3.0957482,2.2575142,,,,,,,,,,,,,, -248200,3.2596538,2.4110062,,,,,,,,,,,,,, -248300,2.9758902,1.1116614,,,,,,,,,,,,,, -248400,8.306618,1.5820869,,,,,,,,,,,,,, -248500,3.1690478,2.4484086,,,,,,,,,,,,,, -248600,3.158894,1.1448438,,,,,,,,,,,,,, -248700,3.7536962,2.8646178,,,,,,,,,,,,,, -248784,,,0.8899218440055847,0.4135270118713379,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,110530.78303217888,119296.5821928978,110530.78303217888,8737.558348178864,15.524272203445436,0.0 -248800,3.1849318,1.4283743,,,,,,,,,,,,,, -248900,3.0248053,1.1386995,,,,,,,,,,,,,, -249000,3.153561,1.2548795,,,,,,,,,,,,,, -249100,3.324199,1.1291437,,,,,,,,,,,,,, -249200,3.3985603,2.648705,,,,,,,,,,,,,, -249300,2.9932945,1.1447283,,,,,,,,,,,,,, -249400,3.0802083,1.1067286,,,,,,,,,,,,,, -249500,3.0586498,1.1639789,,,,,,,,,,,,,, -249600,2.852628,1.0902305,,,,,,,,,,,,,, -249700,3.394187,1.9946432,,,,,,,,,,,,,, -249728,,,0.887011706829071,0.4246560633182525,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,110950.67167448996,119755.48139166832,110950.67167448996,8776.442138671875,15.602407932281494,0.0 -249800,2.9847803,1.1116246,,,,,,,,,,,,,, -249900,3.1066267,1.4206822,,,,,,,,,,,,,, -250000,3.1723287,1.3578705,,,,,,,,,,,,,, -250100,2.9095948,1.1533544,,,,,,,,,,,,,, -250200,3.0993485,1.2253329,,,,,,,,,,,,,, -250300,2.9885106,2.5251722,,,,,,,,,,,,,, -250400,3.7894425,3.1217895,,,,,,,,,,,,,, -250500,3.094342,1.1323401,,,,,,,,,,,,,, -250600,3.249158,2.7935345,,,,,,,,,,,,,, -250671,,,0.8860546946525574,0.4246920347213745,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,111370.60998010635,120207.40252876282,111370.60998010635,8808.311047792435,15.66840434074402,0.0 -250700,2.7540374,0.9954474,,,,,,,,,,,,,, -250800,3.3060462,2.54968,,,,,,,,,,,,,, -250900,3.2702308,1.1046796,,,,,,,,,,,,,, -251000,2.9639678,1.2516977,,,,,,,,,,,,,, -251100,3.1732657,1.0791491,,,,,,,,,,,,,, -251200,2.9421496,1.3909079,,,,,,,,,,,,,, -251300,3.1189024,1.0380284,,,,,,,,,,,,,, -251400,3.3081229,2.607159,,,,,,,,,,,,,, -251500,3.029265,1.4305937,,,,,,,,,,,,,, -251600,3.1860363,2.0628142,,,,,,,,,,,,,, -251616,,,0.887011706829071,0.4161558449268341,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,111790.58244228364,120668.84587788582,111790.58244228364,8849.612287044525,15.789583206176758,0.0 -251700,3.109943,1.0787405,,,,,,,,,,,,,, -251800,3.3654768,1.1113976,,,,,,,,,,,,,, -251900,3.5400858,2.430593,,,,,,,,,,,,,, -252000,3.0716143,1.5474918,,,,,,,,,,,,,, -252100,3.1422353,1.1511171,,,,,,,,,,,,,, -252200,3.3139622,1.0702184,,,,,,,,,,,,,, -252300,3.1097689,1.7452555,,,,,,,,,,,,,, -252400,3.8102546,3.2217515,,,,,,,,,,,,,, -252500,2.920601,1.3325912,,,,,,,,,,,,,, -252562,,,0.8890429735183716,0.4178885519504547,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,112210.55100989342,121126.47691631316,112210.55100989342,8887.146436452866,15.8693106174469,0.0 -252600,3.090744,1.2476717,,,,,,,,,,,,,, -252700,2.9556975,1.2754035,,,,,,,,,,,,,, -252800,3.032925,1.2002771,,,,,,,,,,,,,, -252900,3.2824223,2.3289213,,,,,,,,,,,,,, -253000,2.8897896,1.1261497,,,,,,,,,,,,,, -253100,3.161637,1.1619961,,,,,,,,,,,,,, -253200,4.0718617,2.8877606,,,,,,,,,,,,,, -253300,3.1903257,1.110125,,,,,,,,,,,,,, -253400,3.1519916,1.0510654,,,,,,,,,,,,,, -253500,3.0257993,1.0696461,,,,,,,,,,,,,, -253506,,,0.8874413967132568,0.4144672751426697,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,112630.65174293518,121583.36565971376,112630.65174293518,8923.819966077805,15.935075044631958,0.0 -253600,4.1555557,1.2208978,,,,,,,,,,,,,, -253700,3.6798284,1.5872104,,,,,,,,,,,,,, -253800,2.954793,1.0254228,,,,,,,,,,,,,, -253900,3.176927,1.5291687,,,,,,,,,,,,,, -254000,3.0589297,1.4261217,,,,,,,,,,,,,, -254100,3.2816558,1.1438198,,,,,,,,,,,,,, -254200,3.6687956,3.1376886,,,,,,,,,,,,,, -254300,2.9458973,1.4924612,,,,,,,,,,,,,, -254400,3.2756984,3.0546398,,,,,,,,,,,,,, -254449,,,0.88623046875,0.418743759393692,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,113050.73615145683,122038.9248905182,113050.73615145683,8959.172646284103,16.009252786636353,0.0 -254500,3.0602674,2.295178,,,,,,,,,,,,,, -254600,3.685805,2.8365307,,,,,,,,,,,,,, -254700,2.9239144,1.9559829,,,,,,,,,,,,,, -254800,3.4103239,2.8354979,,,,,,,,,,,,,, -254900,3.0714576,1.0639855,,,,,,,,,,,,,, -255000,3.632842,3.1635294,,,,,,,,,,,,,, -255100,2.930333,1.8747084,,,,,,,,,,,,,, -255200,3.0640328,1.283693,,,,,,,,,,,,,, -255300,3.5745254,3.2416844,,,,,,,,,,,,,, -255392,,,0.8873242139816284,0.4254908561706543,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,113470.6403222084,122504.66809415816,113470.6403222084,9004.886109113693,16.087283849716187,0.0 -255400,3.912892,3.3249967,,,,,,,,,,,,,, -255500,3.139866,2.2924256,,,,,,,,,,,,,, -255600,2.8087468,1.2066381,,,,,,,,,,,,,, -255700,3.1383522,1.4497296,,,,,,,,,,,,,, -255800,2.9594285,1.2182523,,,,,,,,,,,,,, -255900,3.2667463,1.2521946,,,,,,,,,,,,,, -256000,3.3168821,1.0960363,,,,,,,,,,,,,, -256100,2.8160877,1.9349456,,,,,,,,,,,,,, -256200,3.523367,1.4056844,,,,,,,,,,,,,, -256300,2.715699,1.3822563,,,,,,,,,,,,,, -256336,,,0.8916210532188416,0.4107459783554077,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,113890.8896214962,122969.64131331444,113890.8896214962,9049.49839234352,16.151769876480103,0.0 -256400,3.3273702,2.7695749,,,,,,,,,,,,,, -256500,3.4890854,2.7999597,,,,,,,,,,,,,, -256600,3.8126855,3.0142999,,,,,,,,,,,,,, -256700,3.023307,1.0993203,,,,,,,,,,,,,, -256800,2.8557892,1.6166766,,,,,,,,,,,,,, -256900,3.108609,1.0618621,,,,,,,,,,,,,, -257000,3.079471,1.4130481,,,,,,,,,,,,,, -257100,2.7148619,1.0629117,,,,,,,,,,,,,, -257200,3.0796647,1.0976154,,,,,,,,,,,,,, -257282,,,0.8884570002555847,0.4124673306941986,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,114310.84763216972,123430.52208042143,114310.84763216972,9090.309470415115,16.214890003204346,0.0 -257300,3.0549512,2.2064483,,,,,,,,,,,,,, -257400,3.24611,1.348287,,,,,,,,,,,,,, -257500,3.0301185,1.881335,,,,,,,,,,,,,, -257600,2.9912925,1.5079898,,,,,,,,,,,,,, -257700,2.9248378,1.3955642,,,,,,,,,,,,,, -257800,3.1950905,1.1082939,,,,,,,,,,,,,, -257900,2.977174,1.0524831,,,,,,,,,,,,,, -258000,3.3038042,2.0876846,,,,,,,,,,,,,, -258100,3.3576186,2.3254135,,,,,,,,,,,,,, -258200,2.6919448,1.0638804,,,,,,,,,,,,,, -258220,,,0.8871874809265137,0.4193939566612243,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,114730.95115160942,123889.350730896,114730.95115160942,9128.919853687286,16.28192138671875,0.0 -258300,3.2731657,2.5516663,,,,,,,,,,,,,, -258400,3.1698353,2.8197875,,,,,,,,,,,,,, -258500,3.2827187,1.218318,,,,,,,,,,,,,, -258600,2.9981074,2.1670463,,,,,,,,,,,,,, -258700,3.0063589,1.1740294,,,,,,,,,,,,,, -258800,2.9470963,1.4228753,,,,,,,,,,,,,, -258900,3.1503992,1.7284455,,,,,,,,,,,,,, -259000,2.973777,1.1074659,,,,,,,,,,,,,, -259100,3.0669117,1.0982063,,,,,,,,,,,,,, -259161,,,0.8880859017372131,0.4152601063251495,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,115150.87838816644,124344.0568537712,115150.87838816644,9163.585408687592,16.34766459465027,0.0 -259200,3.0725129,2.2371325,,,,,,,,,,,,,, -259300,2.9331527,1.1909099,,,,,,,,,,,,,, -259400,3.170147,1.1289722,,,,,,,,,,,,,, -259500,3.4067965,1.149573,,,,,,,,,,,,,, -259600,3.1311748,1.11029,,,,,,,,,,,,,, -259700,3.58462,2.5662043,,,,,,,,,,,,,, -259800,2.90208,1.6722022,,,,,,,,,,,,,, -259900,3.0607026,1.1975219,,,,,,,,,,,,,, -260000,3.0783355,2.4941442,,,,,,,,,,,,,, -260100,3.074105,1.0769784,,,,,,,,,,,,,, -260103,,,0.8839648365974426,0.4242160022258758,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,115571.12068414688,124805.29729104042,115571.12068414688,9204.456241607666,16.427672386169434,0.0 -260200,4.6069684,3.1229327,,,,,,,,,,,,,, -260300,2.8791883,1.2793804,,,,,,,,,,,,,, -260400,3.3436239,3.021242,,,,,,,,,,,,,, -260500,3.204006,1.1062748,,,,,,,,,,,,,, -260600,4.4717655,1.0882051,,,,,,,,,,,,,, -260700,3.0142634,1.2165983,,,,,,,,,,,,,, -260800,3.0203638,1.1670479,,,,,,,,,,,,,, -260900,3.1147788,1.7754278,,,,,,,,,,,,,, -261000,3.081039,1.3511889,,,,,,,,,,,,,, -261048,,,0.8904882669448853,0.4089166224002838,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,115991.3933467865,125260.09818053246,115991.3933467865,9238.854766130447,16.508938312530518,0.0 -261100,3.0725203,1.311295,,,,,,,,,,,,,, -261200,3.1012163,2.269635,,,,,,,,,,,,,, -261300,3.1122036,1.7494011,,,,,,,,,,,,,, -261400,2.914479,1.1380404,,,,,,,,,,,,,, -261500,3.1607325,1.1460826,,,,,,,,,,,,,, -261600,3.0573196,1.1059207,,,,,,,,,,,,,, -261700,3.0110402,1.6249998,,,,,,,,,,,,,, -261800,3.2375576,1.1982431,,,,,,,,,,,,,, -261900,3.6899297,1.1160327,,,,,,,,,,,,,, -261990,,,0.8859374523162842,0.4217362105846405,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,116411.57664585114,125722.46207761765,116411.57664585114,9280.902841329576,16.59387707710266,0.0 -262000,3.4124074,2.9394622,,,,,,,,,,,,,, -262100,3.081491,1.4408171,,,,,,,,,,,,,, -262200,3.5095007,2.6934862,,,,,,,,,,,,,, -262300,3.0154948,1.2589623,,,,,,,,,,,,,, -262400,3.120419,1.7994952,,,,,,,,,,,,,, -262500,3.1716905,2.6992953,,,,,,,,,,,,,, -262600,3.147509,2.0996287,,,,,,,,,,,,,, -262700,3.5406322,2.971857,,,,,,,,,,,,,, -262800,2.8467042,1.6304971,,,,,,,,,,,,,, -262900,2.983135,1.2750108,,,,,,,,,,,,,, -262932,,,0.8861132860183716,0.42596235871315,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,116831.9596195221,126184.22378492355,116831.9596195221,9322.155301809313,16.67237138748169,0.0 -263000,3.721679,3.1939914,,,,,,,,,,,,,, -263100,3.267282,2.6655753,,,,,,,,,,,,,, -263200,3.2172842,2.4757037,,,,,,,,,,,,,, -263300,3.3079133,2.6096776,,,,,,,,,,,,,, -263400,3.36842,1.771774,,,,,,,,,,,,,, -263500,3.103556,1.7928956,,,,,,,,,,,,,, -263600,3.0457716,1.1335574,,,,,,,,,,,,,, -263700,3.1226113,1.5721021,,,,,,,,,,,,,, -263800,3.090858,2.499455,,,,,,,,,,,,,, -263881,,,0.8869140148162842,0.4212360680103302,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,117252.12936592102,126637.7700855732,117252.12936592102,9355.41879272461,16.736831426620483,0.0 -263900,3.1399896,1.0966575,,,,,,,,,,,,,, -264000,2.9253466,1.2680497,,,,,,,,,,,,,, -264100,3.1135156,2.1545887,,,,,,,,,,,,,, -264200,3.2803972,1.1665906,,,,,,,,,,,,,, -264300,2.965638,1.8987919,,,,,,,,,,,,,, -264400,3.391103,2.8792772,,,,,,,,,,,,,, -264500,2.7927537,1.5181689,,,,,,,,,,,,,, -264600,4.2316074,3.3412902,,,,,,,,,,,,,, -264700,5.1171927,3.2091126,,,,,,,,,,,,,, -264800,3.01983,1.3135053,,,,,,,,,,,,,, -264823,,,0.8873632550239563,0.418342113494873,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,117672.37464976312,127102.61178898811,117672.37464976312,9399.8897895813,16.814719676971436,0.0 -264900,3.308759,1.0721834,,,,,,,,,,,,,, -265000,2.9914768,1.3264881,,,,,,,,,,,,,, -265100,3.4349997,1.2361785,,,,,,,,,,,,,, -265200,3.0823805,1.0624928,,,,,,,,,,,,,, -265300,3.7451391,1.1797843,,,,,,,,,,,,,, -265400,2.9715972,1.1283355,,,,,,,,,,,,,, -265500,2.9989166,2.4032307,,,,,,,,,,,,,, -265600,3.0649037,1.3083295,,,,,,,,,,,,,, -265700,3.4393582,3.0966976,,,,,,,,,,,,,, -265767,,,0.88671875,0.420605331659317,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,118092.33594608308,127555.41320848464,118092.33594608308,9432.615530729294,16.88166904449463,0.0 -265800,3.645456,3.0645337,,,,,,,,,,,,,, -265900,3.0658834,1.4784927,,,,,,,,,,,,,, -266000,3.031449,1.1861333,,,,,,,,,,,,,, -266100,3.109055,1.1897968,,,,,,,,,,,,,, -266200,2.9345808,1.1075372,,,,,,,,,,,,,, -266300,3.1497364,1.2765784,,,,,,,,,,,,,, -266400,3.0852644,1.210094,,,,,,,,,,,,,, -266500,3.5474381,2.869827,,,,,,,,,,,,,, -266600,3.221817,1.0956296,,,,,,,,,,,,,, -266700,2.8445544,1.078165,,,,,,,,,,,,,, -266709,,,0.8873828053474426,0.4178589880466461,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,118512.41550278664,128017.80984807014,118512.41550278664,9474.807350158691,16.959134101867676,0.0 -266800,3.2003477,1.1370748,,,,,,,,,,,,,, -266900,3.1770134,1.2486199,,,,,,,,,,,,,, -267000,3.0778475,1.3023918,,,,,,,,,,,,,, -267100,3.0785034,1.1417805,,,,,,,,,,,,,, -267200,3.412408,1.4912679,,,,,,,,,,,,,, -267300,3.1780736,2.6591005,,,,,,,,,,,,,, -267400,3.4967337,2.939874,,,,,,,,,,,,,, -267500,3.2193856,1.07694,,,,,,,,,,,,,, -267600,2.993161,1.3698426,,,,,,,,,,,,,, -267654,,,0.8849608898162842,0.4240126311779022,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,118932.5847196579,128471.3901963234,118932.5847196579,9508.095239162443,17.03475069999695,0.0 -267700,2.912727,1.1830071,,,,,,,,,,,,,, -267800,2.9756446,1.021252,,,,,,,,,,,,,, -267900,3.1243181,1.0628759,,,,,,,,,,,,,, -268000,2.9371467,1.2016083,,,,,,,,,,,,,, -268100,3.0867293,1.0514507,,,,,,,,,,,,,, -268200,3.00945,1.3530633,,,,,,,,,,,,,, -268300,2.9880154,1.3128749,,,,,,,,,,,,,, -268400,3.462341,3.0968838,,,,,,,,,,,,,, -268500,2.9660563,2.4085882,,,,,,,,,,,,,, -268597,,,0.8870702981948853,0.4228694140911102,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,119352.5260078907,128934.17479658128,119352.5260078907,9550.807997703552,17.11642861366272,0.0 -268600,3.8265822,3.2359028,,,,,,,,,,,,,, -268700,3.5926237,1.224248,,,,,,,,,,,,,, -268800,3.4057543,3.053018,,,,,,,,,,,,,, -268900,4.5898466,2.7567766,,,,,,,,,,,,,, -269000,3.4290075,2.8958402,,,,,,,,,,,,,, -269100,2.987233,1.0788965,,,,,,,,,,,,,, -269200,3.2845407,1.0815387,,,,,,,,,,,,,, -269300,3.1059759,2.744881,,,,,,,,,,,,,, -269400,3.6304705,3.017423,,,,,,,,,,,,,, -269500,3.0588765,1.1286092,,,,,,,,,,,,,, -269543,,,0.8866991996765137,0.4204760789871216,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,119772.48274731636,129388.14326405524,119772.48274731636,9584.70620727539,17.181878328323364,0.0 -269600,2.906283,2.2177415,,,,,,,,,,,,,, -269700,3.1951036,1.100455,,,,,,,,,,,,,, -269800,3.588986,3.1306012,,,,,,,,,,,,,, -269900,2.927217,1.2254109,,,,,,,,,,,,,, -270000,3.1622756,1.0683469,,,,,,,,,,,,,, -270100,2.8425117,1.0958105,,,,,,,,,,,,,, -270200,3.0907288,1.069163,,,,,,,,,,,,,, -270300,3.1333692,1.1785384,,,,,,,,,,,,,, -270400,2.9547563,1.4194108,,,,,,,,,,,,,, -270487,,,0.88734370470047,0.4167658388614654,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,120192.77802968024,129845.8600654602,120192.77802968024,9622.012517929075,17.249467372894287,0.0 -270500,3.1461267,1.261812,,,,,,,,,,,,,, -270600,3.0436068,1.1043646,,,,,,,,,,,,,, -270700,3.032419,1.7464954,,,,,,,,,,,,,, -270800,3.1277623,1.4215004,,,,,,,,,,,,,, -270900,3.2700522,1.1394387,,,,,,,,,,,,,, -271000,3.2483988,1.8737841,,,,,,,,,,,,,, -271100,2.8959627,0.980764,,,,,,,,,,,,,, -271200,2.74344,1.2101657,,,,,,,,,,,,,, -271300,2.9862452,1.296854,,,,,,,,,,,,,, -271400,3.2968938,2.5169346,,,,,,,,,,,,,, -271429,,,0.8888280987739563,0.4134787023067474,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,120612.67361688614,130303.69858121872,120612.67361688614,9659.829708576202,17.32775592803955,0.0 -271500,3.2058096,1.1660979,,,,,,,,,,,,,, -271600,3.534309,2.8420143,,,,,,,,,,,,,, -271700,3.2940843,1.3580735,,,,,,,,,,,,,, -271800,3.1332963,1.7946011,,,,,,,,,,,,,, -271900,3.2118893,1.1519525,,,,,,,,,,,,,, -272000,3.0390687,1.4587095,,,,,,,,,,,,,, -272100,3.1580904,1.1424966,,,,,,,,,,,,,, -272200,3.2746918,1.3887489,,,,,,,,,,,,,, -272300,3.1572702,2.3223007,,,,,,,,,,,,,, -272378,,,0.8880078196525574,0.4239377379417419,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,121032.8384103775,130757.11981844902,121032.8384103775,9692.951695919037,17.413233041763306,0.0 -272400,3.2509763,1.526645,,,,,,,,,,,,,, -272500,3.1849704,1.049424,,,,,,,,,,,,,, -272600,3.181073,1.2273643,,,,,,,,,,,,,, -272700,2.9438875,1.579989,,,,,,,,,,,,,, -272800,3.1414094,1.0989166,,,,,,,,,,,,,, -272900,3.2321744,1.0980109,,,,,,,,,,,,,, -273000,3.0679696,2.2933881,,,,,,,,,,,,,, -273100,3.1139112,1.3140715,,,,,,,,,,,,,, -273200,3.0860505,1.38761,,,,,,,,,,,,,, -273300,3.3601918,1.1668204,,,,,,,,,,,,,, -273322,,,0.886035144329071,0.4260351955890655,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,121452.86763739586,131220.82015681267,121452.86763739586,9736.4949696064,17.49224090576172,0.0 -273400,3.676966,3.2226093,,,,,,,,,,,,,, -273500,2.8548899,1.363704,,,,,,,,,,,,,, -273600,3.1701057,1.1228769,,,,,,,,,,,,,, -273700,3.651878,3.2789564,,,,,,,,,,,,,, -273800,2.9813714,1.0569763,,,,,,,,,,,,,, -273900,3.4493113,1.2511585,,,,,,,,,,,,,, -274000,2.8032548,1.5012896,,,,,,,,,,,,,, -274100,3.211677,1.5758158,,,,,,,,,,,,,, -274200,2.8055074,1.0582485,,,,,,,,,,,,,, -274270,,,0.8879492282867432,0.4197743535041809,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,121872.97571897508,131675.82518553734,121872.97571897508,9771.279657840729,17.556384801864624,0.0 -274300,3.225145,1.1801636,,,,,,,,,,,,,, -274400,3.0536215,1.1865534,,,,,,,,,,,,,, -274500,3.0147185,1.1098523,,,,,,,,,,,,,, -274600,3.9919953,3.1610932,,,,,,,,,,,,,, -274700,2.853129,2.046936,,,,,,,,,,,,,, -274800,3.0007637,1.308368,,,,,,,,,,,,,, -274900,3.1660943,1.3891239,,,,,,,,,,,,,, -275000,3.0600405,2.0416853,,,,,,,,,,,,,, -275100,3.0630217,1.0610969,,,,,,,,,,,,,, -275200,3.05766,2.1392338,,,,,,,,,,,,,, -275215,,,0.8881444931030273,0.4141191244125366,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,122293.25101804732,132132.355342865,122293.25101804732,9807.406393289566,17.63720178604126,0.0 -275300,3.1167996,1.2519174,,,,,,,,,,,,,, -275400,3.2545912,1.0466936,,,,,,,,,,,,,, -275500,3.2458723,2.128591,,,,,,,,,,,,,, -275600,2.9760199,1.1150502,,,,,,,,,,,,,, -275700,3.237723,1.1324581,,,,,,,,,,,,,, -275800,3.1440837,1.3672665,,,,,,,,,,,,,, -275900,3.1972568,1.4884061,,,,,,,,,,,,,, -276000,3.2843268,1.1833048,,,,,,,,,,,,,, -276100,3.0712256,1.4642414,,,,,,,,,,,,,, -276158,,,0.8878515362739563,0.4155722260475158,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,122713.5250134468,132585.34187698364,122713.5250134468,9839.993040561676,17.716108798980713,0.0 -276200,3.413864,2.9579594,,,,,,,,,,,,,, -276300,3.1285882,1.1718475,,,,,,,,,,,,,, -276400,3.170925,1.1043886,,,,,,,,,,,,,, -276500,3.3498251,2.3430996,,,,,,,,,,,,,, -276600,2.8997319,1.1348057,,,,,,,,,,,,,, -276700,3.0789225,2.6057193,,,,,,,,,,,,,, -276800,3.3118753,2.7445486,,,,,,,,,,,,,, -276900,3.3766177,2.7542179,,,,,,,,,,,,,, -277000,3.4988291,2.699725,,,,,,,,,,,,,, -277100,,,0.8867968320846558,0.4192649722099304,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,123133.66612696648,133042.52073717117,123133.66612696648,9876.90296959877,17.7959988117218,0.0 -277100,3.0218387,1.1547186,,,,,,,,,,,,,, -277200,3.0140135,1.4016861,,,,,,,,,,,,,, -277300,3.6723428,3.1326327,,,,,,,,,,,,,, -277400,2.9664254,2.1339355,,,,,,,,,,,,,, -277500,3.0743568,1.6896224,,,,,,,,,,,,,, -277600,3.089756,1.315676,,,,,,,,,,,,,, -277700,3.136761,1.0833014,,,,,,,,,,,,,, -277800,3.2417011,1.0648232,,,,,,,,,,,,,, -277900,3.0319686,1.9201345,,,,,,,,,,,,,, -278000,3.3719318,1.695844,,,,,,,,,,,,,, -278042,,,0.8861327767372131,0.4243049621582031,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,123553.5646944046,133501.51552557945,123553.5646944046,9915.872620105743,17.875136137008667,0.0 -278100,3.25971,1.2942866,,,,,,,,,,,,,, -278200,2.991989,1.0798464,,,,,,,,,,,,,, -278300,3.2290454,1.1523137,,,,,,,,,,,,,, -278400,3.85834,3.055194,,,,,,,,,,,,,, -278500,3.168723,1.150944,,,,,,,,,,,,,, -278600,3.3105667,2.9074485,,,,,,,,,,,,,, -278700,4.834256,3.1982565,,,,,,,,,,,,,, -278800,2.9886043,1.4083767,,,,,,,,,,,,,, -278900,3.0968366,2.096711,,,,,,,,,,,,,, -278985,,,0.8908007740974426,0.4121743738651275,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,123973.49132466316,133958.46650099754,123973.49132466316,9952.77225279808,17.952096462249756,0.0 -279000,3.0629761,1.6113763,,,,,,,,,,,,,, -279100,2.9166696,1.1749535,,,,,,,,,,,,,, -279200,3.0830164,1.4673398,,,,,,,,,,,,,, -279300,3.0426507,1.1546825,,,,,,,,,,,,,, -279400,3.3058445,1.0941907,,,,,,,,,,,,,, -279500,3.5235488,1.0690713,,,,,,,,,,,,,, -279600,3.0354738,1.1293395,,,,,,,,,,,,,, -279700,3.1272142,1.1346594,,,,,,,,,,,,,, -279800,3.0511491,1.704951,,,,,,,,,,,,,, -279900,2.940746,1.5202382,,,,,,,,,,,,,, -279931,,,0.8890624642372131,0.420417308807373,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,124393.40948843956,134421.4090359211,124393.40948843956,9995.665901184082,18.03502511978149,0.0 -280000,2.8405247,1.0418329,,,,,,,,,,,,,, -280100,2.814565,1.0002546,,,,,,,,,,,,,, -280200,2.9202898,1.8796688,,,,,,,,,,,,,, -280300,3.044874,2.665649,,,,,,,,,,,,,, -280400,3.3082368,2.5664954,,,,,,,,,,,,,, -280500,3.6371262,3.0206208,,,,,,,,,,,,,, -280600,3.1000047,1.1057962,,,,,,,,,,,,,, -280700,3.5304537,2.2275677,,,,,,,,,,,,,, -280800,3.4455688,2.902895,,,,,,,,,,,,,, -280878,,,0.8887109160423279,0.4124118089675903,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,124813.36013317108,134874.94128251076,124813.36013317108,10029.1331782341,18.101533889770508,0.0 -280900,3.075659,1.1372666,,,,,,,,,,,,,, -281000,3.0407724,1.3174052,,,,,,,,,,,,,, -281100,3.1734812,1.1326113,,,,,,,,,,,,,, -281200,3.1525495,1.1410562,,,,,,,,,,,,,, -281300,2.9561522,1.3304859,,,,,,,,,,,,,, -281400,3.092958,1.3205495,,,,,,,,,,,,,, -281500,3.2042477,2.4763565,,,,,,,,,,,,,, -281600,3.088611,1.3551601,,,,,,,,,,,,,, -281700,3.0907702,1.1170707,,,,,,,,,,,,,, -281800,3.1046836,1.1855102,,,,,,,,,,,,,, -281823,,,0.8863085508346558,0.4174687564373016,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,125233.32454299928,135329.92215943336,125233.32454299928,10064.021770715714,18.181446075439453,0.0 -281900,2.8839548,2.3521311,,,,,,,,,,,,,, -282000,3.190219,2.0183418,,,,,,,,,,,,,, -282100,2.8185816,1.3900925,,,,,,,,,,,,,, -282200,3.060462,2.0899816,,,,,,,,,,,,,, -282300,3.0400321,1.56597,,,,,,,,,,,,,, -282400,3.1606128,1.2282536,,,,,,,,,,,,,, -282500,3.1216486,2.4144392,,,,,,,,,,,,,, -282600,3.376038,2.93793,,,,,,,,,,,,,, -282700,3.0120375,0.9868666,,,,,,,,,,,,,, -282769,,,0.8883593678474426,0.4142705202102661,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,125653.50318288805,135781.73765468597,125653.50318288805,10095.547145366669,18.245970249176025,0.0 -282800,3.135378,2.574305,,,,,,,,,,,,,, -282900,3.0614486,1.6347578,,,,,,,,,,,,,, -283000,3.269262,1.0655037,,,,,,,,,,,,,, -283100,2.9181507,1.060104,,,,,,,,,,,,,, -283200,3.213749,1.133401,,,,,,,,,,,,,, -283300,3.1553419,1.1390011,,,,,,,,,,,,,, -283400,3.3337903,1.1305265,,,,,,,,,,,,,, -283500,3.5118613,3.1335258,,,,,,,,,,,,,, -283600,3.1736174,1.9838822,,,,,,,,,,,,,, -283700,3.0389209,1.1833856,,,,,,,,,,,,,, -283711,,,0.8863866925239563,0.4193476140499115,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,126073.44761157036,136247.80353331566,126073.44761157036,10141.53751373291,18.329392433166504,0.0 -283800,3.5765357,2.0274425,,,,,,,,,,,,,, -283900,3.7607236,3.1388814,,,,,,,,,,,,,, -284000,3.4661505,2.128829,,,,,,,,,,,,,, -284100,3.1909418,1.2914038,,,,,,,,,,,,,, -284200,3.0839057,1.2286431,,,,,,,,,,,,,, -284300,3.2453105,2.5417507,,,,,,,,,,,,,, -284400,4.1793613,1.1971424,,,,,,,,,,,,,, -284500,2.9091187,1.2503645,,,,,,,,,,,,,, -284600,2.9220493,1.3186119,,,,,,,,,,,,,, -284655,,,0.8861523270606995,0.4190988838672638,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,126493.50467848778,136701.16173362732,126493.50467848778,10174.72700858116,18.393612146377563,0.0 -284700,3.1466296,2.2674928,,,,,,,,,,,,,, -284800,3.5218215,2.6436448,,,,,,,,,,,,,, -284900,3.409166,2.7632473,,,,,,,,,,,,,, -285000,2.970742,1.0306298,,,,,,,,,,,,,, -285100,3.6203017,2.9379938,,,,,,,,,,,,,, -285200,3.1304996,1.6592674,,,,,,,,,,,,,, -285300,2.844767,1.0962671,,,,,,,,,,,,,, -285400,3.9789636,3.2056751,,,,,,,,,,,,,, -285500,3.0950217,2.086913,,,,,,,,,,,,,, -285599,,,0.8877733945846558,0.4184274971485138,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,126913.51318049432,137163.79151773453,126913.51318049432,10217.211532592772,18.482458353042603,0.0 -285600,3.1672595,2.0038269,,,,,,,,,,,,,, -285700,3.3154476,1.1481081,,,,,,,,,,,,,, -285800,3.039437,2.098208,,,,,,,,,,,,,, -285900,3.0358615,2.257881,,,,,,,,,,,,,, -286000,3.2731984,1.1543577,,,,,,,,,,,,,, -286100,3.343642,1.0852563,,,,,,,,,,,,,, -286200,2.9268126,1.119819,,,,,,,,,,,,,, -286300,3.1642888,2.7956004,,,,,,,,,,,,,, -286400,3.517556,3.1353822,,,,,,,,,,,,,, -286500,2.995862,1.4990265,,,,,,,,,,,,,, -286544,,,0.8854687213897705,0.426918625831604,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,127333.65675115584,137623.30382800102,127333.65675115584,10256.46224308014,18.5535101890564,0.0 -286600,3.176692,1.1692023,,,,,,,,,,,,,, -286700,3.0049725,1.1244786,,,,,,,,,,,,,, -286800,3.5261736,1.1639577,,,,,,,,,,,,,, -286900,3.1864166,1.1267985,,,,,,,,,,,,,, -287000,2.993395,1.3859034,,,,,,,,,,,,,, -287100,3.9937842,3.2809641,,,,,,,,,,,,,, -287200,3.0544567,0.9744762,,,,,,,,,,,,,, -287300,3.345343,2.8589911,,,,,,,,,,,,,, -287400,3.5632048,2.9905758,,,,,,,,,,,,,, -287489,,,0.8874609470367432,0.4195753335952759,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,127753.75061297417,138079.65566945076,127753.75061297417,10292.60477733612,18.62070345878601,0.0 -287500,3.4410534,1.6464062,,,,,,,,,,,,,, -287600,4.027263,3.2796597,,,,,,,,,,,,,, -287700,2.943232,1.0086664,,,,,,,,,,,,,, -287800,3.557404,2.018941,,,,,,,,,,,,,, -287900,3.0586634,1.405205,,,,,,,,,,,,,, -288000,3.1583169,1.2508638,,,,,,,,,,,,,, -288100,2.91183,2.1541772,,,,,,,,,,,,,, -288200,3.58434,2.835644,,,,,,,,,,,,,, -288300,3.0147173,1.3165191,,,,,,,,,,,,,, -288400,3.431722,2.3860717,,,,,,,,,,,,,, -288425,,,0.88525390625,0.4243783354759216,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,128174.01127171516,138539.70076584816,128174.01127171516,10332.257049798964,18.705175638198853,0.0 -288500,2.9613407,1.6954849,,,,,,,,,,,,,, -288600,3.2443326,2.181746,,,,,,,,,,,,,, -288700,3.7090008,3.4101107,,,,,,,,,,,,,, -288800,3.3585906,1.1364683,,,,,,,,,,,,,, -288900,3.13519,1.0675015,,,,,,,,,,,,,, -289000,2.8777506,2.1132417,,,,,,,,,,,,,, -289100,3.0674992,1.9607443,,,,,,,,,,,,,, -289200,3.5472214,1.1536785,,,,,,,,,,,,,, -289300,3.2876925,1.635216,,,,,,,,,,,,,, -289368,,,0.8885741829872131,0.4138832092285156,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,128594.14718818665,138994.94444060326,128594.14718818665,10367.236153364182,18.78646731376648,0.0 -289400,3.1043797,1.3391607,,,,,,,,,,,,,, -289500,3.0452244,1.1938399,,,,,,,,,,,,,, -289600,3.1519277,1.3584102,,,,,,,,,,,,,, -289700,3.13232,1.4872596,,,,,,,,,,,,,, -289800,3.5198197,3.0630724,,,,,,,,,,,,,, -289900,3.0251775,1.1310229,,,,,,,,,,,,,, -290000,2.8686006,1.7656696,,,,,,,,,,,,,, -290100,3.1722896,1.1659048,,,,,,,,,,,,,, -290200,2.7914474,1.7776307,,,,,,,,,,,,,, -290300,3.3076801,1.192396,,,,,,,,,,,,,, -290312,,,0.8881444931030273,0.4155890643596649,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,129014.123803854,139450.84994983673,129014.123803854,10403.037517786026,18.866474628448486,0.0 -290400,3.1992128,1.2229049,,,,,,,,,,,,,, -290500,3.0819442,1.0725574,,,,,,,,,,,,,, -290600,3.2581673,1.0665851,,,,,,,,,,,,,, -290700,3.216458,1.699784,,,,,,,,,,,,,, -290800,3.226475,1.2149771,,,,,,,,,,,,,, -290900,3.702964,3.338235,,,,,,,,,,,,,, -291000,3.281143,2.5594137,,,,,,,,,,,,,, -291100,3.247349,2.562993,,,,,,,,,,,,,, -291200,3.0099008,1.5325271,,,,,,,,,,,,,, -291257,,,0.8851562142372131,0.4261019825935364,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,129434.1880440712,139915.48982930183,129434.1880440712,10447.48250246048,18.94978713989257,0.0 -291300,3.3006551,1.2566116,,,,,,,,,,,,,, -291400,3.241291,2.1980357,,,,,,,,,,,,,, -291500,3.2785885,2.4368641,,,,,,,,,,,,,, -291600,3.1254072,1.0321858,,,,,,,,,,,,,, -291700,3.2114916,1.1554673,,,,,,,,,,,,,, -291800,3.1767998,2.514285,,,,,,,,,,,,,, -291900,3.6606016,1.1728114,,,,,,,,,,,,,, -292000,2.856216,1.3026954,,,,,,,,,,,,,, -292100,3.224386,1.1792402,,,,,,,,,,,,,, -292200,4.600947,3.2194169,,,,,,,,,,,,,, -292204,,,0.8864452838897705,0.4245300889015198,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,129854.40283584596,140371.68188858032,129854.40283584596,10483.340990066528,19.0203800201416,0.0 -292300,3.2189136,1.1851684,,,,,,,,,,,,,, -292400,3.001949,1.0824757,,,,,,,,,,,,,, -292500,3.1978743,2.76914,,,,,,,,,,,,,, -292600,3.2610183,2.7940352,,,,,,,,,,,,,, -292700,3.4973626,2.5923686,,,,,,,,,,,,,, -292800,3.4212642,1.0808827,,,,,,,,,,,,,, -292900,3.3391287,2.8908446,,,,,,,,,,,,,, -293000,3.486594,1.0663028,,,,,,,,,,,,,, -293100,3.1588879,1.9357508,,,,,,,,,,,,,, -293150,,,0.8864452838897705,0.4200622141361236,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,130274.47477340698,140828.01147723198,130274.47477340698,10519.48413681984,19.087130546569824,0.0 -293200,3.261757,1.1704968,,,,,,,,,,,,,, -293300,3.1195588,1.1565071,,,,,,,,,,,,,, -293400,3.4388816,1.198633,,,,,,,,,,,,,, -293500,3.0317452,1.7246457,,,,,,,,,,,,,, -293600,3.2128417,1.2001019,,,,,,,,,,,,,, -293700,3.1989765,1.1461488,,,,,,,,,,,,,, -293800,3.074893,2.23395,,,,,,,,,,,,,, -293900,3.3195732,1.2716033,,,,,,,,,,,,,, -294000,3.4858575,1.0967104,,,,,,,,,,,,,, -294091,,,0.8884570002555847,0.4107588827610016,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,130694.4135849476,141285.43379950523,130694.4135849476,10556.83842110634,19.169212818145752,0.0 -294100,3.1469462,1.1194732,,,,,,,,,,,,,, -294200,3.2734478,1.1239241,,,,,,,,,,,,,, -294300,3.1869733,1.1079706,,,,,,,,,,,,,, -294400,3.2697027,1.2430108,,,,,,,,,,,,,, -294500,3.050861,1.2512382,,,,,,,,,,,,,, -294600,3.083561,1.7777406,,,,,,,,,,,,,, -294700,3.16495,1.1151611,,,,,,,,,,,,,, -294800,3.0032368,1.3591957,,,,,,,,,,,,,, -294900,3.448372,1.8177003,,,,,,,,,,,,,, -295000,3.2818437,1.1471431,,,,,,,,,,,,,, -295034,,,0.8865429759025574,0.4236370325088501,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,131114.35461211205,141752.77542972565,131114.35461211205,10604.105157375336,19.25607824325561,0.0 -295100,3.6893563,3.0641341,,,,,,,,,,,,,, -295200,4.1333356,1.1415734,,,,,,,,,,,,,, -295300,2.9352856,1.9141128,,,,,,,,,,,,,, -295400,3.2431324,1.1158772,,,,,,,,,,,,,, -295500,3.9267163,2.934184,,,,,,,,,,,,,, -295600,3.7297745,3.000656,,,,,,,,,,,,,, -295700,4.3302646,3.2011592,,,,,,,,,,,,,, -295800,3.0512352,1.5736125,,,,,,,,,,,,,, -295900,3.0328352,2.0465624,,,,,,,,,,,,,, -295980,,,0.8881444931030273,0.4216232895851135,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,131534.4235270023,142213.62927746773,131534.4235270023,10644.770728588104,19.327077388763428,0.0 -296000,3.2197886,2.6189024,,,,,,,,,,,,,, -296100,3.2030144,2.907357,,,,,,,,,,,,,, -296200,3.359564,1.3349544,,,,,,,,,,,,,, -296300,3.9814913,1.3614672,,,,,,,,,,,,,, -296400,3.094396,1.084251,,,,,,,,,,,,,, -296500,2.8702621,1.7235057,,,,,,,,,,,,,, -296600,3.8413835,3.1429949,,,,,,,,,,,,,, -296700,3.2201777,1.315864,,,,,,,,,,,,,, -296800,4.051361,3.1919165,,,,,,,,,,,,,, -296900,3.0946906,1.165954,,,,,,,,,,,,,, -296924,,,0.8879101276397705,0.4217333495616913,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,131954.65556120872,142675.4125571251,131954.65556120872,10686.206349372864,19.395113229751587,0.0 -297000,2.758857,1.8730774,,,,,,,,,,,,,, -297100,3.280037,2.6510108,,,,,,,,,,,,,, -297200,3.3389044,2.7715466,,,,,,,,,,,,,, -297300,3.0314856,2.4815145,,,,,,,,,,,,,, -297400,3.177109,1.1061611,,,,,,,,,,,,,, -297500,3.494827,3.0735483,,,,,,,,,,,,,, -297600,3.3710556,1.2641947,,,,,,,,,,,,,, -297700,3.2554047,1.1032746,,,,,,,,,,,,,, -297800,2.8945868,1.3142955,,,,,,,,,,,,,, -297867,,,0.8878710865974426,0.4187273979187011,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,132374.87281131744,143140.9353840351,132374.87281131744,10731.38487648964,19.475355625152588,0.0 -297900,3.2387974,2.5262446,,,,,,,,,,,,,, -298000,3.2976694,2.7345636,,,,,,,,,,,,,, -298100,3.040875,2.5464723,,,,,,,,,,,,,, -298200,3.160728,1.8072404,,,,,,,,,,,,,, -298300,3.1526337,1.1784707,,,,,,,,,,,,,, -298400,3.7191799,3.3065596,,,,,,,,,,,,,, -298500,3.3691776,2.8075213,,,,,,,,,,,,,, -298600,3.164326,1.746457,,,,,,,,,,,,,, -298700,3.2266386,1.2688669,,,,,,,,,,,,,, -298800,3.0386338,1.9535158,,,,,,,,,,,,,, -298811,,,0.8876953125,0.4147098362445831,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,132794.94854402542,143594.57590150833,132794.94854402542,10764.83255124092,19.54459524154663,0.0 -298900,3.0899768,1.1966455,,,,,,,,,,,,,, -299000,3.1031337,1.8637387,,,,,,,,,,,,,, -299100,4.9591026,3.2822423,,,,,,,,,,,,,, -299200,3.0549464,1.1767968,,,,,,,,,,,,,, -299300,3.4548986,3.179985,,,,,,,,,,,,,, -299400,3.1323574,1.2451473,,,,,,,,,,,,,, -299500,3.019882,1.1154988,,,,,,,,,,,,,, -299600,2.8347204,1.1645014,,,,,,,,,,,,,, -299700,2.9801612,1.1241771,,,,,,,,,,,,,, -299750,,,0.8880273103713989,0.4155641794204712,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,133215.20007777214,144051.30251002312,133215.20007777214,10801.17714715004,19.628756284713745,0.0 -299800,3.3035843,1.1242085,,,,,,,,,,,,,, -299900,2.979063,0.9946408,,,,,,,,,,,,,, -300000,2.9084325,1.1682909,,,,,,,,,,,,,, -300100,3.103277,1.2622564,,,,,,,,,,,,,, -300200,3.113425,1.992256,,,,,,,,,,,,,, -300300,3.1952207,1.4076927,,,,,,,,,,,,,, -300400,3.1829948,1.2323382,,,,,,,,,,,,,, -300500,2.9050708,1.1556851,,,,,,,,,,,,,, -300600,3.0812333,1.3931116,,,,,,,,,,,,,, -300691,,,0.8881640434265137,0.4177298545837402,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,133635.22816705704,144517.28083229065,133635.22816705704,10846.996324777603,19.71213221549988,0.0 -300700,3.2295551,1.1383538,,,,,,,,,,,,,, -300800,3.3450713,1.230649,,,,,,,,,,,,,, -300900,3.2313938,2.3012707,,,,,,,,,,,,,, -301000,3.4371715,1.4213778,,,,,,,,,,,,,, -301100,3.6042266,1.1815867,,,,,,,,,,,,,, -301200,3.4673965,2.8065674,,,,,,,,,,,,,, -301300,3.5705006,3.0718365,,,,,,,,,,,,,, -301400,3.2356281,2.53892,,,,,,,,,,,,,, -301500,3.3741636,2.5906594,,,,,,,,,,,,,, -301600,3.1787453,1.20927,,,,,,,,,,,,,, -301638,,,0.8869531154632568,0.4213081002235412,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,134055.29154729843,144978.41498851776,134055.29154729843,10887.950706481934,19.78096437454224,0.0 -301700,3.9752626,2.7883525,,,,,,,,,,,,,, -301800,3.2598956,2.570443,,,,,,,,,,,,,, -301900,3.2514536,2.6655357,,,,,,,,,,,,,, -302000,3.1933043,1.2734503,,,,,,,,,,,,,, -302100,3.20587,1.0106452,,,,,,,,,,,,,, -302200,2.9346473,1.2816377,,,,,,,,,,,,,, -302300,3.3081303,1.1647512,,,,,,,,,,,,,, -302400,3.3247297,2.0323637,,,,,,,,,,,,,, -302500,3.1017647,1.1645535,,,,,,,,,,,,,, -302583,,,0.8878515362739563,0.4230779111385345,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,134475.1889846325,145436.19050335884,134475.1889846325,10925.710918664932,19.851025104522705,0.0 -302600,4.0110183,3.3103802,,,,,,,,,,,,,, -302700,3.001866,2.274451,,,,,,,,,,,,,, -302800,3.0403278,1.1330657,,,,,,,,,,,,,, -302900,3.743719,2.5701122,,,,,,,,,,,,,, -303000,4.016358,2.985993,,,,,,,,,,,,,, -303100,3.4733417,2.9978452,,,,,,,,,,,,,, -303200,3.5182724,1.6713959,,,,,,,,,,,,,, -303300,3.3302324,1.121385,,,,,,,,,,,,,, -303400,2.9512703,1.4212449,,,,,,,,,,,,,, -303500,3.0135376,1.1929833,,,,,,,,,,,,,, -303524,,,0.8898828029632568,0.4145982563495636,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,134895.2553062439,145897.11030721664,134895.2553062439,10966.428569555284,19.93898057937622,0.0 -303600,2.865679,1.1165116,,,,,,,,,,,,,, -303700,2.9662642,2.0437481,,,,,,,,,,,,,, -303800,3.0101652,1.1708509,,,,,,,,,,,,,, -303900,3.0572658,1.1669077,,,,,,,,,,,,,, -304000,2.8548841,1.14224,,,,,,,,,,,,,, -304100,2.999833,2.0176375,,,,,,,,,,,,,, -304200,3.8881075,3.1487317,,,,,,,,,,,,,, -304300,3.7304926,3.0033922,,,,,,,,,,,,,, -304400,3.1032915,1.1054933,,,,,,,,,,,,,, -304470,,,0.8882812261581421,0.4124145805835724,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,135315.58999705315,146357.09014439583,135315.58999705315,11005.957649946213,20.0074679851532,0.0 -304500,3.0642486,2.0067708,,,,,,,,,,,,,, -304600,2.858282,1.5594494,,,,,,,,,,,,,, -304700,3.0449886,1.3268187,,,,,,,,,,,,,, -304800,3.1697195,1.1559904,,,,,,,,,,,,,, -304900,3.0823252,2.5139613,,,,,,,,,,,,,, -305000,3.7182655,2.8921413,,,,,,,,,,,,,, -305100,3.3288934,1.0999964,,,,,,,,,,,,,, -305200,3.3155522,1.6652093,,,,,,,,,,,,,, -305300,3.2805932,1.1999707,,,,,,,,,,,,,, -305400,3.1807065,1.1377237,,,,,,,,,,,,,, -305411,,,0.8877733945846558,0.4161712825298309,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,135735.683208704,146811.18326807022,135735.683208704,11039.826193094254,20.091625690460205,0.0 -305500,3.3024697,2.1975038,,,,,,,,,,,,,, -305600,3.521471,1.2322608,,,,,,,,,,,,,, -305700,3.2395763,1.173751,,,,,,,,,,,,,, -305800,3.2109594,1.2351129,,,,,,,,,,,,,, -305900,2.8331857,1.5475177,,,,,,,,,,,,,, -306000,3.2892213,1.2910029,,,,,,,,,,,,,, -306100,3.1040883,1.2601688,,,,,,,,,,,,,, -306200,3.1319933,1.3638644,,,,,,,,,,,,,, -306300,2.8560364,2.2115674,,,,,,,,,,,,,, -306355,,,0.8863866925239563,0.4192544519901275,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,136155.78617358208,147275.9312517643,136155.78617358208,11084.333412647247,20.181591033935547,0.0 -306400,2.994093,1.2465451,,,,,,,,,,,,,, -306500,3.1548274,1.1444045,,,,,,,,,,,,,, -306600,3.3101218,1.2170496,,,,,,,,,,,,,, -306700,3.052102,1.5343289,,,,,,,,,,,,,, -306800,3.5614438,2.8268895,,,,,,,,,,,,,, -306900,3.0661469,1.9155138,,,,,,,,,,,,,, -307000,2.9842873,1.2809743,,,,,,,,,,,,,, -307100,3.8291817,3.3001385,,,,,,,,,,,,,, -307200,3.4469028,2.8924298,,,,,,,,,,,,,, -307300,3.056481,1.121369,,,,,,,,,,,,,, -307301,,,0.8872460722923279,0.4167300760746002,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,136576.10743117332,147733.73422026634,136576.10743117332,11121.695330381392,20.25333547592163,0.0 -307400,3.0598104,1.0579871,,,,,,,,,,,,,, -307500,2.977727,1.6529844,,,,,,,,,,,,,, -307600,2.900227,1.8475308,,,,,,,,,,,,,, -307700,3.1003206,1.1305121,,,,,,,,,,,,,, -307800,3.0814245,1.9304475,,,,,,,,,,,,,, -307900,3.1398053,1.476846,,,,,,,,,,,,,, -308000,2.829172,1.5696135,,,,,,,,,,,,,, -308100,3.0783641,1.2380211,,,,,,,,,,,,,, -308200,3.024735,1.4525249,,,,,,,,,,,,,, -308242,,,0.8857812285423279,0.4191692173480987,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,136996.03312301636,148187.63279628754,136996.03312301636,11155.53110909462,20.34227418899536,0.0 -308300,3.4357643,1.055589,,,,,,,,,,,,,, -308400,3.4839842,3.0495427,,,,,,,,,,,,,, -308500,3.14944,1.2000921,,,,,,,,,,,,,, -308600,3.3733711,1.3350996,,,,,,,,,,,,,, -308700,3.4606717,1.1098518,,,,,,,,,,,,,, -308800,3.1507206,2.4264207,,,,,,,,,,,,,, -308900,3.6012084,2.274318,,,,,,,,,,,,,, -309000,2.9455361,1.1812252,,,,,,,,,,,,,, -309100,3.3913932,2.9367807,,,,,,,,,,,,,, -309180,,,0.887988269329071,0.4194705188274383,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,137416.065836668,148648.40660881996,137416.065836668,11196.13981294632,20.42733263969421,0.0 -309200,3.214831,1.103986,,,,,,,,,,,,,, -309300,3.0363812,2.4941661,,,,,,,,,,,,,, -309400,3.2025173,1.4564018,,,,,,,,,,,,,, -309500,3.1598349,1.4006962,,,,,,,,,,,,,, -309600,3.5493333,3.0536644,,,,,,,,,,,,,, -309700,2.9280262,1.087553,,,,,,,,,,,,,, -309800,3.1762805,1.1470375,,,,,,,,,,,,,, -309900,3.3494902,2.427471,,,,,,,,,,,,,, -310000,3.304487,1.1689199,,,,,,,,,,,,,, -310100,3.1535926,1.1478829,,,,,,,,,,,,,, -310128,,,0.8869140148162842,0.4242978692054748,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,137836.3591003418,149100.46276378632,137836.3591003418,11227.777802228928,20.50452852249145,0.0 -310200,3.2129943,1.2990261,,,,,,,,,,,,,, -310300,2.9907103,1.1012237,,,,,,,,,,,,,, -310400,3.1839826,1.1706519,,,,,,,,,,,,,, -310500,2.8822348,1.9267952,,,,,,,,,,,,,, -310600,3.2108014,1.9383096,,,,,,,,,,,,,, -310700,2.9693868,2.2540104,,,,,,,,,,,,,, -310800,3.0501752,1.974895,,,,,,,,,,,,,, -310900,3.061221,1.3810881,,,,,,,,,,,,,, -311000,3.0765045,2.422118,,,,,,,,,,,,,, -311070,,,0.8864843845367432,0.4207232892513275,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,138256.27770638466,149559.89429020882,138256.27770638466,11267.157366037369,20.590729236602783,0.0 -311100,3.1203434,1.5347307,,,,,,,,,,,,,, -311200,3.0232773,1.0577267,,,,,,,,,,,,,, -311300,3.4571204,1.1243161,,,,,,,,,,,,,, -311400,3.2098548,1.0303499,,,,,,,,,,,,,, -311500,3.0025885,2.2128592,,,,,,,,,,,,,, -311600,3.4092672,3.0336716,,,,,,,,,,,,,, -311700,3.1318874,2.82775,,,,,,,,,,,,,, -311800,2.9579654,1.0838952,,,,,,,,,,,,,, -311900,3.013841,1.1787754,,,,,,,,,,,,,, -312000,3.3437636,1.2452291,,,,,,,,,,,,,, -312014,,,0.8853515386581421,0.425651341676712,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,138676.20446014404,150019.4015262127,138676.20446014404,11306.6024684906,20.678019762039185,0.0 -312100,3.2146237,1.4126024,,,,,,,,,,,,,, -312200,3.0790858,1.0745372,,,,,,,,,,,,,, -312300,3.2254086,2.3851018,,,,,,,,,,,,,, -312400,3.0106514,1.523158,,,,,,,,,,,,,, -312500,3.6748092,3.1283886,,,,,,,,,,,,,, -312600,3.0683262,1.3894204,,,,,,,,,,,,,, -312700,3.2540185,1.2343091,,,,,,,,,,,,,, -312800,2.8699715,1.3631527,,,,,,,,,,,,,, -312900,3.1063333,2.5421999,,,,,,,,,,,,,, -312961,,,0.8886913657188416,0.4127460122108459,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,139096.4767267704,150474.38501358032,139096.4767267704,11341.178297281263,20.76548171043396,0.0 -313000,3.0476558,1.5931079,,,,,,,,,,,,,, -313100,3.6705105,3.2679706,,,,,,,,,,,,,, -313200,3.0288064,1.4432058,,,,,,,,,,,,,, -313300,3.064697,1.1764947,,,,,,,,,,,,,, -313400,3.126502,1.9104171,,,,,,,,,,,,,, -313500,3.41427,1.20001,,,,,,,,,,,,,, -313600,3.1154602,1.0988371,,,,,,,,,,,,,, -313700,3.056511,1.5673714,,,,,,,,,,,,,, -313800,3.0967329,1.1416044,,,,,,,,,,,,,, -313900,4.245199,3.2431347,,,,,,,,,,,,,, -313905,,,0.8878710865974426,0.4168930649757385,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,139516.60149145126,150932.9247033596,139516.60149145126,11379.45639872551,20.85531640052796,0.0 -314000,3.377135,2.62758,,,,,,,,,,,,,, -314100,3.1764886,1.4623305,,,,,,,,,,,,,, -314200,3.1945546,2.8499773,,,,,,,,,,,,,, -314300,3.1355898,1.6487281,,,,,,,,,,,,,, -314400,3.6551414,2.9910707,,,,,,,,,,,,,, -314500,3.4787605,3.0512066,,,,,,,,,,,,,, -314600,3.6558652,2.8976662,,,,,,,,,,,,,, -314700,3.0729074,1.6390191,,,,,,,,,,,,,, -314800,3.08252,1.0604905,,,,,,,,,,,,,, -314851,,,0.8854687213897705,0.4233700335025787,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,139936.63247156143,151387.96013522148,139936.63247156143,11414.343953609468,20.924494981765747,0.0 -314900,3.1686964,1.1096946,,,,,,,,,,,,,, -315000,3.2258728,2.93386,,,,,,,,,,,,,, -315100,3.1199687,1.0937696,,,,,,,,,,,,,, -315200,3.2074165,2.6541066,,,,,,,,,,,,,, -315300,3.2993996,2.2471366,,,,,,,,,,,,,, -315400,3.0862641,1.1126037,,,,,,,,,,,,,, -315500,3.4891672,3.0582185,,,,,,,,,,,,,, -315600,3.1497252,1.0944197,,,,,,,,,,,,,, -315700,3.0880818,1.0331018,,,,,,,,,,,,,, -315793,,,0.8864843845367432,0.4232950806617737,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,140356.7695221901,151851.4670085907,140356.7695221901,11457.576095819471,21.0146746635437,0.0 -315800,3.024107,1.0564615,,,,,,,,,,,,,, -315900,3.1641974,1.0901278,,,,,,,,,,,,,, -316000,3.1754363,1.4117107,,,,,,,,,,,,,, -316100,3.066734,1.4924345,,,,,,,,,,,,,, -316200,4.0491977,3.248469,,,,,,,,,,,,,, -316300,3.2745278,1.1301235,,,,,,,,,,,,,, -316400,3.3851128,1.9986973,,,,,,,,,,,,,, -316500,3.2343917,2.5971541,,,,,,,,,,,,,, -316600,3.038522,1.2804587,,,,,,,,,,,,,, -316700,3.2386358,1.2020952,,,,,,,,,,,,,, -316739,,,0.8882226347923279,0.4166227579116821,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,140776.95079994202,152303.06532382965,140776.95079994202,11488.859185695648,21.10109519958496,0.0 -316800,2.955102,1.0493937,,,,,,,,,,,,,, -316900,3.4015641,2.8104424,,,,,,,,,,,,,, -317000,3.2776184,1.8100773,,,,,,,,,,,,,, -317100,3.0987754,2.073472,,,,,,,,,,,,,, -317200,3.2359564,2.7503884,,,,,,,,,,,,,, -317300,3.6764376,1.3550301,,,,,,,,,,,,,, -317400,2.921618,1.135984,,,,,,,,,,,,,, -317500,3.292212,1.1510707,,,,,,,,,,,,,, -317600,2.97406,1.088168,,,,,,,,,,,,,, -317682,,,0.8858593702316284,0.4191854298114776,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,141197.1936097145,152761.68574547768,141197.1936097145,11527.099912643433,21.19072890281677,0.0 -317700,4.0557733,3.3346171,,,,,,,,,,,,,, -317800,3.432332,1.6707568,,,,,,,,,,,,,, -317900,3.1578448,2.092759,,,,,,,,,,,,,, -318000,3.6873398,3.1059537,,,,,,,,,,,,,, -318100,3.1033683,1.1918778,,,,,,,,,,,,,, -318200,3.0180738,1.1703286,,,,,,,,,,,,,, -318300,3.2554054,2.229115,,,,,,,,,,,,,, -318400,3.0460525,2.4931867,,,,,,,,,,,,,, -318500,2.8558695,1.8678368,,,,,,,,,,,,,, -318600,3.2434006,1.1395863,,,,,,,,,,,,,, -318622,,,0.8885741829872131,0.4191855788230896,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,141617.29385328293,153226.09011101723,141617.29385328293,11571.23422384262,21.31320881843567,0.0 -318700,3.0828128,1.1649296,,,,,,,,,,,,,, -318800,3.4155855,3.1155472,,,,,,,,,,,,,, -318900,3.9059193,3.1894484,,,,,,,,,,,,,, -319000,2.9313438,1.053481,,,,,,,,,,,,,, -319100,3.8304198,3.1870198,,,,,,,,,,,,,, -319200,3.3684402,1.2273729,,,,,,,,,,,,,, -319300,3.310133,1.1686509,,,,,,,,,,,,,, -319400,3.1010485,1.3735027,,,,,,,,,,,,,, -319500,2.825543,1.0538114,,,,,,,,,,,,,, -319569,,,0.8874022960662842,0.4200645089149475,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,142037.1958515644,153680.8844909668,142037.1958515644,11606.005845546722,21.38654923439026,0.0 -319600,3.6486025,3.0623155,,,,,,,,,,,,,, -319700,3.011448,1.1495624,,,,,,,,,,,,,, -319800,3.0171568,1.3523861,,,,,,,,,,,,,, -319900,4.8950424,2.246667,,,,,,,,,,,,,, -320000,3.017172,1.5098795,,,,,,,,,,,,,, -320100,3.1853426,1.5454253,,,,,,,,,,,,,, -320200,2.916534,1.2337521,,,,,,,,,,,,,, -320300,3.0156329,1.1026325,,,,,,,,,,,,,, -320400,3.0116217,1.3486042,,,,,,,,,,,,,, -320500,3.2763684,2.3720737,,,,,,,,,,,,,, -320513,,,0.8875195384025574,0.4228684604167938,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,142457.30712151527,154136.66727113724,142457.30712151527,11641.539048671722,21.47724652290344,0.0 -320600,3.0729532,1.542419,,,,,,,,,,,,,, -320700,4.106023,3.1957455,,,,,,,,,,,,,, -320800,3.1128683,1.3968885,,,,,,,,,,,,,, -320900,3.4679222,3.1152437,,,,,,,,,,,,,, -321000,2.7152715,1.5338508,,,,,,,,,,,,,, -321100,3.105986,1.1463661,,,,,,,,,,,,,, -321200,3.6070216,2.9891856,,,,,,,,,,,,,, -321300,3.4506655,2.9272506,,,,,,,,,,,,,, -321400,2.940304,1.0744971,,,,,,,,,,,,,, -321457,,,0.8886523246765137,0.4192448258399963,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,142877.6751279831,154601.16312241554,142877.6751279831,11685.531976222992,21.56491112709045,0.0 -321500,3.1411047,2.6755486,,,,,,,,,,,,,, -321600,3.5062099,1.0307186,,,,,,,,,,,,,, -321700,3.058406,1.4249632,,,,,,,,,,,,,, -321800,2.8790293,1.2731668,,,,,,,,,,,,,, -321900,3.5984266,2.9811592,,,,,,,,,,,,,, -322000,3.052309,1.1059542,,,,,,,,,,,,,, -322100,3.1827798,1.1998215,,,,,,,,,,,,,, -322200,3.0240324,1.8639067,,,,,,,,,,,,,, -322300,3.277453,2.2850287,,,,,,,,,,,,,, -322400,2.9732661,2.1619914,,,,,,,,,,,,,, -322403,,,0.8864648342132568,0.4219618141651153,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,143297.70597314835,155063.17006206512,143297.70597314835,11727.38549900055,21.639564275741577,0.0 -322500,3.18266,1.0726262,,,,,,,,,,,,,, -322600,3.086258,1.0620333,,,,,,,,,,,,,, -322700,3.603308,3.1505618,,,,,,,,,,,,,, -322800,3.0093143,1.1557817,,,,,,,,,,,,,, -322900,3.0984182,1.1823999,,,,,,,,,,,,,, -323000,3.2272823,2.2417364,,,,,,,,,,,,,, -323100,3.4108849,3.001545,,,,,,,,,,,,,, -323200,4.167128,3.153524,,,,,,,,,,,,,, -323300,3.7536638,3.2559335,,,,,,,,,,,,,, -323349,,,0.8867577910423279,0.4161965548992157,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,143717.72579455376,155516.2507390976,143717.72579455376,11760.326476812364,21.71282839775085,0.0 -323400,3.235801,1.2449268,,,,,,,,,,,,,, -323500,3.5139656,2.3109224,,,,,,,,,,,,,, -323600,3.0438533,1.1209512,,,,,,,,,,,,,, -323700,3.817548,1.0794618,,,,,,,,,,,,,, -323800,3.0987403,1.1667671,,,,,,,,,,,,,, -323900,3.127875,2.3359132,,,,,,,,,,,,,, -324000,2.9453785,1.074918,,,,,,,,,,,,,, -324100,3.0939045,1.3251183,,,,,,,,,,,,,, -324200,3.4846728,1.1277704,,,,,,,,,,,,,, -324290,,,0.8874022960662842,0.4182784259319305,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,144137.9409184456,155980.52302742004,144137.9409184456,11804.24805378914,21.801358461380005,0.0 -324300,3.2592587,1.1936985,,,,,,,,,,,,,, -324400,3.4469705,2.9771085,,,,,,,,,,,,,, -324500,3.3263283,2.4288738,,,,,,,,,,,,,, -324600,3.18136,1.0242901,,,,,,,,,,,,,, -324700,3.2272298,2.6270804,,,,,,,,,,,,,, -324800,3.0096686,1.1119331,,,,,,,,,,,,,, -324900,3.0843484,2.6854167,,,,,,,,,,,,,, -325000,3.4039838,2.287519,,,,,,,,,,,,,, -325100,3.0922823,1.2664441,,,,,,,,,,,,,, -325200,3.1349556,1.3570626,,,,,,,,,,,,,, -325237,,,0.8866991996765137,0.4203983545303345,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,144558.39869880676,156434.3753077984,144558.39869880676,11837.51058626175,21.88632369041443,0.0 -325300,3.1999042,1.1550567,,,,,,,,,,,,,, -325400,2.8103418,1.737326,,,,,,,,,,,,,, -325500,3.0438707,1.1555395,,,,,,,,,,,,,, -325600,3.195178,1.1992791,,,,,,,,,,,,,, -325700,3.038398,1.2498357,,,,,,,,,,,,,, -325800,3.1647067,1.1184893,,,,,,,,,,,,,, -325900,2.916797,0.96048105,,,,,,,,,,,,,, -326000,3.137079,1.1969771,,,,,,,,,,,,,, -326100,2.931664,2.0279002,,,,,,,,,,,,,, -326177,,,0.88916015625,0.4168655574321747,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,144978.63285660744,156890.96754074097,144978.63285660744,11873.732411623,21.974764585494995,0.0 -326200,3.1687038,1.2246335,,,,,,,,,,,,,, -326300,3.393959,2.9675941,,,,,,,,,,,,,, -326400,3.8802624,3.0858302,,,,,,,,,,,,,, -326500,3.2865858,2.4826255,,,,,,,,,,,,,, -326600,3.3308182,2.8228161,,,,,,,,,,,,,, -326700,2.8499987,1.062303,,,,,,,,,,,,,, -326800,3.277943,1.890536,,,,,,,,,,,,,, -326900,3.390476,2.4228091,,,,,,,,,,,,,, -327000,3.4426122,2.7919867,,,,,,,,,,,,,, -327100,3.054983,1.2202082,,,,,,,,,,,,,, -327120,,,0.8898046612739563,0.4155591130256653,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,145398.58534121513,157347.51926136017,145398.58534121513,11910.18722462654,22.070751667022705,0.0 -327200,3.018175,1.3889539,,,,,,,,,,,,,, -327300,3.1684198,1.1891595,,,,,,,,,,,,,, -327400,3.488349,3.1470912,,,,,,,,,,,,,, -327500,2.8984537,1.2209268,,,,,,,,,,,,,, -327600,3.5124545,3.1658201,,,,,,,,,,,,,, -327700,3.0290694,1.5449317,,,,,,,,,,,,,, -327800,3.0256162,1.340337,,,,,,,,,,,,,, -327900,3.1528177,1.1074692,,,,,,,,,,,,,, -328000,2.7789123,1.0852809,,,,,,,,,,,,,, -328063,,,0.8884375095367432,0.4140132665634155,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,145818.65999770164,157802.96967935562,145818.65999770164,11945.410465955734,22.175304174423218,0.0 -328100,3.769704,3.2529008,,,,,,,,,,,,,, -328200,2.9123893,1.5644454,,,,,,,,,,,,,, -328300,2.9282932,1.4246902,,,,,,,,,,,,,, -328400,3.4249082,1.9286782,,,,,,,,,,,,,, -328500,3.2945094,1.1517351,,,,,,,,,,,,,, -328600,3.2296114,1.1807991,,,,,,,,,,,,,, -328700,3.2363026,1.1209605,,,,,,,,,,,,,, -328800,3.0944588,1.0474477,,,,,,,,,,,,,, -328900,3.0914345,1.0851495,,,,,,,,,,,,,, -329000,3.5487993,1.0854313,,,,,,,,,,,,,, -329007,,,0.8872265219688416,0.4143014550209045,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,146238.7226538658,158257.84491205215,146238.7226538658,11980.082593917848,22.26838541030884,0.0 -329100,3.016024,1.1591982,,,,,,,,,,,,,, -329200,2.9401321,1.7946281,,,,,,,,,,,,,, -329300,3.1361716,1.2860796,,,,,,,,,,,,,, -329400,4.568606,2.9750075,,,,,,,,,,,,,, -329500,3.6141207,2.712143,,,,,,,,,,,,,, -329600,3.1008325,2.2266045,,,,,,,,,,,,,, -329700,2.9327796,1.6414857,,,,,,,,,,,,,, -329800,3.0106633,1.7731143,,,,,,,,,,,,,, -329900,3.5991907,1.1351082,,,,,,,,,,,,,, -329951,,,0.8880859017372131,0.4180927276611328,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,146659.0313911438,158722.40548229218,146659.0313911438,12024.194583177568,22.361082553863525,0.0 -330000,3.1312132,1.0572647,,,,,,,,,,,,,, -330100,3.0411222,1.0488803,,,,,,,,,,,,,, -330200,2.850113,1.1884323,,,,,,,,,,,,,, -330300,3.181094,1.2084384,,,,,,,,,,,,,, -330400,3.0729902,1.1071144,,,,,,,,,,,,,, -330500,3.3225622,1.2320343,,,,,,,,,,,,,, -330600,2.8645759,1.7228705,,,,,,,,,,,,,, -330700,2.991039,1.3489027,,,,,,,,,,,,,, -330800,3.0749788,1.0936363,,,,,,,,,,,,,, -330899,,,0.8866601586341858,0.4193836450576782,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,147079.2585787773,159179.90055942535,147079.2585787773,12061.329034805298,22.44697642326355,0.0 -330900,3.0464656,1.37031,,,,,,,,,,,,,, -331000,5.2856517,2.7851229,,,,,,,,,,,,,, -331100,3.6868937,3.0839796,,,,,,,,,,,,,, -331200,3.2349856,1.2060192,,,,,,,,,,,,,, -331300,3.38522,2.9661617,,,,,,,,,,,,,, -331400,3.1587384,1.0957508,,,,,,,,,,,,,, -331500,2.7903523,1.1146756,,,,,,,,,,,,,, -331600,2.9332497,1.2182091,,,,,,,,,,,,,, -331700,3.1638918,1.214403,,,,,,,,,,,,,, -331800,3.293644,1.176086,,,,,,,,,,,,,, -331840,,,0.8866406083106995,0.4111728072166443,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,147498.69048261642,159646.3641116619,147498.69048261642,12107.44764471054,23.31242060661316,0.0 -331900,3.0646029,1.8775403,,,,,,,,,,,,,, -332000,3.2818062,1.1574208,,,,,,,,,,,,,, -332100,2.9933314,1.0080336,,,,,,,,,,,,,, -332200,3.1527064,1.2002943,,,,,,,,,,,,,, -332300,3.2093213,1.7158535,,,,,,,,,,,,,, -332400,2.9031203,1.0977015,,,,,,,,,,,,,, -332500,3.103576,1.1390498,,,,,,,,,,,,,, -332600,3.2111325,2.4869137,,,,,,,,,,,,,, -332700,3.0535102,2.3844337,,,,,,,,,,,,,, -332787,,,0.8870898485183716,0.427180141210556,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,147918.98912215233,160100.89725995064,147918.98912215233,12141.548724412918,23.3977952003479,0.0 -332800,2.9658527,1.554858,,,,,,,,,,,,,, -332900,3.166294,1.125427,,,,,,,,,,,,,, -333000,3.91668,3.309608,,,,,,,,,,,,,, -333100,2.864263,1.3976523,,,,,,,,,,,,,, -333200,2.8301253,1.8070393,,,,,,,,,,,,,, -333300,2.93669,1.3314911,,,,,,,,,,,,,, -333400,3.383487,2.937097,,,,,,,,,,,,,, -333500,3.1821153,1.5113369,,,,,,,,,,,,,, -333600,4.2126265,2.9209623,,,,,,,,,,,,,, -333700,2.8983467,1.3770576,,,,,,,,,,,,,, -333723,,,0.8864062428474426,0.4243050217628479,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,148338.922219038,160564.95006656647,148338.922219038,12185.530223608015,23.488813161849976,0.0 -333800,2.9704814,1.3173639,,,,,,,,,,,,,, -333900,3.1885781,0.95361364,,,,,,,,,,,,,, -334000,2.820031,1.9054053,,,,,,,,,,,,,, -334100,2.817735,2.0761893,,,,,,,,,,,,,, -334200,3.1861377,1.2115905,,,,,,,,,,,,,, -334300,3.0374703,1.1167601,,,,,,,,,,,,,, -334400,3.343526,1.9613363,,,,,,,,,,,,,, -334500,3.7818263,3.2778842,,,,,,,,,,,,,, -334600,3.060737,2.0660987,,,,,,,,,,,,,, -334663,,,0.8871679306030273,0.4189110994338989,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,148758.86404037476,161026.26504921913,148758.86404037476,12226.760622501371,23.583487033844,0.0 -334700,3.102182,1.3329546,,,,,,,,,,,,,, -334800,2.909277,1.1502091,,,,,,,,,,,,,, -334900,3.0319433,2.232159,,,,,,,,,,,,,, -335000,4.210429,2.0342884,,,,,,,,,,,,,, -335100,3.0614767,1.0855964,,,,,,,,,,,,,, -335200,2.9000773,1.3768731,,,,,,,,,,,,,, -335300,2.765698,1.1103339,,,,,,,,,,,,,, -335400,3.685188,3.1711345,,,,,,,,,,,,,, -335500,3.756397,3.1528556,,,,,,,,,,,,,, -335600,3.0806353,1.4303952,,,,,,,,,,,,,, -335606,,,0.88623046875,0.4251594543457031,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,149179.06500458717,161484.33428955078,149179.06500458717,12264.491759061812,23.673017024993896,0.0 -335700,3.0071495,1.2580202,,,,,,,,,,,,,, -335800,3.3573895,2.2580786,,,,,,,,,,,,,, -335900,3.2801406,1.1784533,,,,,,,,,,,,,, -336000,3.2784176,3.1378322,,,,,,,,,,,,,, -336100,3.0080168,1.1249402,,,,,,,,,,,,,, -336200,2.9275415,1.1165369,,,,,,,,,,,,,, -336300,3.6212308,1.0479363,,,,,,,,,,,,,, -336400,3.2293088,1.1089234,,,,,,,,,,,,,, -336500,3.4860787,1.0916126,,,,,,,,,,,,,, -336548,,,0.8883788585662842,0.4107557237148285,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,149599.2316122055,161946.33445000648,149599.2316122055,12306.18427157402,23.765666007995605,0.0 -336600,3.228198,1.0896842,,,,,,,,,,,,,, -336700,3.213133,1.2272465,,,,,,,,,,,,,, -336800,3.2913504,1.3696879,,,,,,,,,,,,,, -336900,3.22865,1.072578,,,,,,,,,,,,,, -337000,3.4791012,3.155623,,,,,,,,,,,,,, -337100,3.3409762,2.8815265,,,,,,,,,,,,,, -337200,3.8456535,3.222848,,,,,,,,,,,,,, -337300,3.0018404,1.8922567,,,,,,,,,,,,,, -337400,3.046357,2.6584184,,,,,,,,,,,,,, -337490,,,0.8858007788658142,0.4200242757797241,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,150019.19176864624,162411.44314026833,150019.19176864624,12351.189871788025,23.86135625839233,0.0 -337500,3.2418904,1.254417,,,,,,,,,,,,,, -337600,3.348961,1.1749777,,,,,,,,,,,,,, -337700,3.1754522,1.1449599,,,,,,,,,,,,,, -337800,2.9541743,1.0743589,,,,,,,,,,,,,, -337900,3.1515205,1.0828493,,,,,,,,,,,,,, -338000,3.0786812,1.222773,,,,,,,,,,,,,, -338100,3.3076484,2.7940607,,,,,,,,,,,,,, -338200,3.1168559,1.0820866,,,,,,,,,,,,,, -338300,3.1237345,2.4388032,,,,,,,,,,,,,, -338400,3.1769664,2.7929478,,,,,,,,,,,,,, -338430,,,0.8881054520606995,0.4224624037742615,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,150439.1379084587,162866.21015238762,150439.1379084587,12385.872237205504,23.952771425247192,0.0 -338500,3.504966,1.1570117,,,,,,,,,,,,,, -338600,3.0868564,1.6614121,,,,,,,,,,,,,, -338700,3.19947,2.6707308,,,,,,,,,,,,,, -338800,3.2286541,1.6492492,,,,,,,,,,,,,, -338900,2.9765782,1.1760843,,,,,,,,,,,,,, -339000,3.0010526,1.0798832,,,,,,,,,,,,,, -339100,3.1023164,1.2511581,,,,,,,,,,,,,, -339200,3.3426697,1.0907128,,,,,,,,,,,,,, -339300,3.5727396,2.980569,,,,,,,,,,,,,, -339370,,,0.8855859041213989,0.4248220324516296,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,150859.23227977753,163322.5501112938,150859.23227977753,12421.92520046234,24.09778380393982,0.0 -339400,2.683879,1.638478,,,,,,,,,,,,,, -339500,3.5965047,3.1288617,,,,,,,,,,,,,, -339600,4.8255777,3.104632,,,,,,,,,,,,,, -339700,3.0649526,1.078695,,,,,,,,,,,,,, -339800,3.1655536,1.5750515,,,,,,,,,,,,,, -339900,3.0224175,1.7544358,,,,,,,,,,,,,, -340000,2.9587786,1.3954566,,,,,,,,,,,,,, -340100,2.9404883,1.4726771,,,,,,,,,,,,,, -340200,2.9886775,1.168749,,,,,,,,,,,,,, -340300,3.1355917,1.1118165,,,,,,,,,,,,,, -340311,,,0.8870312571525574,0.4190521538257599,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,151279.18514490128,163780.07631373403,151279.18514490128,12459.357788085938,24.19091844558716,0.0 -340400,3.3094506,2.6964,,,,,,,,,,,,,, -340500,3.1788025,1.6815113,,,,,,,,,,,,,, -340600,2.96131,1.132831,,,,,,,,,,,,,, -340700,3.308659,1.1073966,,,,,,,,,,,,,, -340800,3.629457,1.0455258,,,,,,,,,,,,,, -340900,3.1245553,1.7735134,,,,,,,,,,,,,, -341000,3.1175625,1.1705756,,,,,,,,,,,,,, -341100,2.93755,1.2848372,,,,,,,,,,,,,, -341200,3.6067183,3.2161322,,,,,,,,,,,,,, -341256,,,0.8864062428474426,0.4155356585979461,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,151699.11443638802,164245.09310364723,151699.11443638802,12504.30022096634,24.288305044174194,0.0 -341300,3.2618976,0.9966079,,,,,,,,,,,,,, -341400,3.6915483,1.247982,,,,,,,,,,,,,, -341500,2.9824145,1.07252,,,,,,,,,,,,,, -341600,3.6051617,3.12609,,,,,,,,,,,,,, -341700,3.0508816,1.071607,,,,,,,,,,,,,, -341800,2.908144,2.0100603,,,,,,,,,,,,,, -341900,3.5637002,2.979936,,,,,,,,,,,,,, -342000,4.1076493,1.3746886,,,,,,,,,,,,,, -342100,3.174221,1.3392609,,,,,,,,,,,,,, -342200,3.61627,3.1064868,,,,,,,,,,,,,, -342201,,,0.8872656226158142,0.4203912913799286,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,152119.28486824036,164705.3630282879,152119.28486824036,12544.256209850311,24.38241219520569,0.0 -342300,2.9727962,1.059701,,,,,,,,,,,,,, -342400,3.0717762,1.1533549,,,,,,,,,,,,,, -342500,2.984416,1.2463593,,,,,,,,,,,,,, -342600,3.0862517,1.0930791,,,,,,,,,,,,,, -342700,3.1754107,1.15801,,,,,,,,,,,,,, -342800,3.4272976,1.0588038,,,,,,,,,,,,,, -342900,2.9237251,1.4423718,,,,,,,,,,,,,, -343000,2.9737453,1.1110919,,,,,,,,,,,,,, -343100,3.1593308,2.2915606,,,,,,,,,,,,,, -343145,,,0.8889843821525574,0.4194121360778808,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,152539.3480424881,165168.36728596687,152539.3480424881,12587.055959939957,24.475526571273804,0.0 -343200,2.9585414,1.448905,,,,,,,,,,,,,, -343300,3.002105,1.2265375,,,,,,,,,,,,,, -343400,3.3680134,1.1102141,,,,,,,,,,,,,, -343500,2.9313865,1.0188475,,,,,,,,,,,,,, -343600,2.9916651,2.1097648,,,,,,,,,,,,,, -343700,3.0787435,1.0634278,,,,,,,,,,,,,, -343800,3.5958962,1.1072197,,,,,,,,,,,,,, -343900,3.0243466,1.1085111,,,,,,,,,,,,,, -344000,3.1843078,2.4749718,,,,,,,,,,,,,, -344090,,,0.8886132836341858,0.4171028733253479,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,152959.46968698502,165624.51542139053,152959.46968698502,12622.95605802536,24.554410457611084,0.0 -344100,3.0523252,1.049653,,,,,,,,,,,,,, -344200,3.058867,1.3295135,,,,,,,,,,,,,, -344300,3.4605694,1.1223028,,,,,,,,,,,,,, -344400,3.52823,1.2056122,,,,,,,,,,,,,, -344500,2.9368556,2.4693737,,,,,,,,,,,,,, -344600,2.902642,1.1736407,,,,,,,,,,,,,, -344700,3.23101,2.9075792,,,,,,,,,,,,,, -344800,3.1201754,1.1673046,,,,,,,,,,,,,, -344900,3.0900054,1.2075071,,,,,,,,,,,,,, -345000,3.2457073,2.287604,,,,,,,,,,,,,, -345033,,,0.8860741853713989,0.4267341494560241,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,153379.60203409195,166087.2524061203,153379.60203409195,12665.410373926165,24.657466411590576,0.0 -345100,3.0977163,1.2039572,,,,,,,,,,,,,, -345200,3.24736,2.6673625,,,,,,,,,,,,,, -345300,3.8615766,1.1224036,,,,,,,,,,,,,, -345400,2.987219,1.1696218,,,,,,,,,,,,,, -345500,3.6265852,3.2337646,,,,,,,,,,,,,, -345600,3.9987156,3.2046373,,,,,,,,,,,,,, -345700,2.9211237,1.4766684,,,,,,,,,,,,,, -345800,3.7049124,3.076357,,,,,,,,,,,,,, -345900,3.1513832,1.052766,,,,,,,,,,,,,, -345977,,,0.8871288895606995,0.4171859622001648,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,153799.6966097355,166542.02790880203,153799.6966097355,12699.949487686155,24.752000331878666,0.0 -346000,3.334996,1.1255696,,,,,,,,,,,,,, -346100,2.9161594,1.4251556,,,,,,,,,,,,,, -346200,3.3537996,1.2138739,,,,,,,,,,,,,, -346300,3.029407,2.0017385,,,,,,,,,,,,,, -346400,2.8284004,2.2089229,,,,,,,,,,,,,, -346500,3.6083357,2.2285223,,,,,,,,,,,,,, -346600,3.061523,1.8980876,,,,,,,,,,,,,, -346700,2.9392302,1.7381088,,,,,,,,,,,,,, -346800,3.1920729,1.0861789,,,,,,,,,,,,,, -346900,3.2411158,2.701992,,,,,,,,,,,,,, -346918,,,0.887988269329071,0.4203246533870697,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,154219.81880021095,167008.85578632355,154219.81880021095,12746.5108397007,24.84833550453186,0.0 -347000,3.640176,3.192625,,,,,,,,,,,,,, -347100,3.7771893,3.2117686,,,,,,,,,,,,,, -347200,3.5017312,3.0875287,,,,,,,,,,,,,, -347300,3.1455584,2.6828632,,,,,,,,,,,,,, -347400,3.4667435,3.102004,,,,,,,,,,,,,, -347500,3.0630991,1.1194489,,,,,,,,,,,,,, -347600,3.11289,1.5708756,,,,,,,,,,,,,, -347700,3.1367216,1.1526537,,,,,,,,,,,,,, -347800,2.8286507,1.3723665,,,,,,,,,,,,,, -347864,,,0.8874804377555847,0.4153309464454651,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,154639.78215909004,167467.11938381195,154639.78215909004,12784.672406435013,24.939489126205444,0.0 -347900,3.2779138,1.9346396,,,,,,,,,,,,,, -348000,3.3950577,1.5965272,,,,,,,,,,,,,, -348100,3.1268506,1.1934826,,,,,,,,,,,,,, -348200,3.1280036,1.0805557,,,,,,,,,,,,,, -348300,3.0857902,1.3961301,,,,,,,,,,,,,, -348400,3.2758844,1.4989374,,,,,,,,,,,,,, -348500,3.2637386,1.147546,,,,,,,,,,,,,, -348600,3.3440828,1.2647331,,,,,,,,,,,,,, -348700,2.829523,2.188809,,,,,,,,,,,,,, -348800,3.2518096,1.128611,,,,,,,,,,,,,, -348810,,,0.8859374523162842,0.4238024055957794,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,155059.82725691795,167921.87595415115,155059.82725691795,12819.244787454603,25.030781984329224,0.0 -348900,3.1645591,1.12637,,,,,,,,,,,,,, -349000,2.8774257,1.4779584,,,,,,,,,,,,,, -349100,3.8232427,3.3022714,,,,,,,,,,,,,, -349200,3.054112,1.5674156,,,,,,,,,,,,,, -349300,3.2033298,0.98420143,,,,,,,,,,,,,, -349400,3.5673892,0.98287314,,,,,,,,,,,,,, -349500,3.1730223,1.1846776,,,,,,,,,,,,,, -349600,3.0096788,2.7175317,,,,,,,,,,,,,, -349700,3.7288022,3.1778364,,,,,,,,,,,,,, -349753,,,0.8904492259025574,0.4134562611579895,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,155480.12934875488,168378.29820513725,155480.12934875488,12855.2223072052,25.12555241584778,0.0 -349800,3.087073,1.1587473,,,,,,,,,,,,,, -349900,3.1202564,1.2596762,,,,,,,,,,,,,, -350000,2.9456248,1.1887797,,,,,,,,,,,,,, -350100,2.8536463,2.0436306,,,,,,,,,,,,,, -350200,3.2062347,1.2167271,,,,,,,,,,,,,, -350300,3.601278,1.4916134,,,,,,,,,,,,,, -350400,3.3169615,1.1410384,,,,,,,,,,,,,, -350500,3.0209858,1.260989,,,,,,,,,,,,,, -350600,2.927985,2.1123536,,,,,,,,,,,,,, -350694,,,0.8905858993530273,0.4152538478374481,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,155900.0498933792,168842.13325691223,155900.0498933792,12898.995729207993,25.21956491470337,0.0 -350700,3.6157157,1.0862797,,,,,,,,,,,,,, -350800,3.1877804,1.2527548,,,,,,,,,,,,,, -350900,3.2071197,1.2630557,,,,,,,,,,,,,, -351000,3.116374,1.3013226,,,,,,,,,,,,,, -351100,3.8372383,1.1480386,,,,,,,,,,,,,, -351200,3.194794,1.1569006,,,,,,,,,,,,,, -351300,3.1338,1.4825445,,,,,,,,,,,,,, -351400,3.0757546,1.0913744,,,,,,,,,,,,,, -351500,2.9330566,1.2808502,,,,,,,,,,,,,, -351600,6.6231923,3.274484,,,,,,,,,,,,,, -351640,,,0.887499988079071,0.4161133468151092,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,156320.398665905,169299.9502491951,156320.398665905,12936.320994138718,25.31429362297058,0.0 -351700,3.4492698,2.7244334,,,,,,,,,,,,,, -351800,3.4395816,3.0038605,,,,,,,,,,,,,, -351900,3.0300608,1.9073427,,,,,,,,,,,,,, -352000,3.3061707,1.4851745,,,,,,,,,,,,,, -352100,2.9063084,2.2574263,,,,,,,,,,,,,, -352200,3.3932793,2.641758,,,,,,,,,,,,,, -352300,2.9868653,2.2591105,,,,,,,,,,,,,, -352400,3.0479476,1.3893019,,,,,,,,,,,,,, -352500,3.361525,1.7147403,,,,,,,,,,,,,, -352583,,,0.88587886095047,0.4204886257648468,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,156740.2870953083,169765.09603238106,156740.2870953083,12981.430022001266,25.41551327705384,0.0 -352600,2.8734646,1.3537176,,,,,,,,,,,,,, -352700,3.3845572,1.1380894,,,,,,,,,,,,,, -352800,3.550266,3.059954,,,,,,,,,,,,,, -352900,3.5005665,1.146255,,,,,,,,,,,,,, -353000,3.3663545,1.1187966,,,,,,,,,,,,,, -353100,3.8080845,3.3155704,,,,,,,,,,,,,, -353200,3.2335865,1.1025493,,,,,,,,,,,,,, -353300,3.0048163,2.2743957,,,,,,,,,,,,,, -353400,2.9832475,1.3597689,,,,,,,,,,,,,, -353500,3.6227262,2.3426862,,,,,,,,,,,,,, -353525,,,0.8861718773841858,0.4195727407932281,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,157160.15752458572,170225.4692606926,157160.15752458572,13021.78906273842,25.50675082206726,0.0 -353600,3.4132762,1.3409886,,,,,,,,,,,,,, -353700,2.9618437,1.1352359,,,,,,,,,,,,,, -353800,2.905384,1.05758,,,,,,,,,,,,,, -353900,3.327068,2.4903636,,,,,,,,,,,,,, -354000,3.3170927,1.0126663,,,,,,,,,,,,,, -354100,3.011503,2.4620924,,,,,,,,,,,,,, -354200,3.5755424,2.9587529,,,,,,,,,,,,,, -354300,3.1654687,1.1677247,,,,,,,,,,,,,, -354400,3.3149695,1.129283,,,,,,,,,,,,,, -354468,,,0.8877343535423279,0.4124457836151123,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,157580.12749171257,170684.63828587532,157580.12749171257,13060.845579624176,25.601975679397583,0.0 -354500,3.247015,1.16205,,,,,,,,,,,,,, -354600,3.395825,1.2469286,,,,,,,,,,,,,, -354700,3.1349719,2.4200022,,,,,,,,,,,,,, -354800,2.8578787,1.9642231,,,,,,,,,,,,,, -354900,2.807169,1.6838264,,,,,,,,,,,,,, -355000,3.2007377,1.1315343,,,,,,,,,,,,,, -355100,3.3717198,1.1671582,,,,,,,,,,,,,, -355200,3.122435,1.1270837,,,,,,,,,,,,,, -355300,3.120117,2.7115183,,,,,,,,,,,,,, -355400,2.9620245,1.0988837,,,,,,,,,,,,,, -355410,,,0.8864648342132568,0.4209831058979034,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,158000.41928219795,171146.12599873543,158000.41928219795,13101.89858865738,25.69657111167908,0.0 -355500,3.4856818,2.9073563,,,,,,,,,,,,,, -355600,2.9644845,1.5041895,,,,,,,,,,,,,, -355700,3.9476712,1.1710087,,,,,,,,,,,,,, -355800,3.1269827,1.0679603,,,,,,,,,,,,,, -355900,3.1274953,1.0629337,,,,,,,,,,,,,, -356000,3.3192663,1.2647063,,,,,,,,,,,,,, -356100,3.4046192,3.084692,,,,,,,,,,,,,, -356200,3.7462227,2.8990626,,,,,,,,,,,,,, -356300,3.255225,1.203591,,,,,,,,,,,,,, -356353,,,0.8862890601158142,0.4223797619342804,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,158420.6190032959,171609.484395504,158420.6190032959,13144.91348528862,25.793184995651245,0.0 -356400,4.024259,3.2414794,,,,,,,,,,,,,, -356500,3.1581059,2.7896988,,,,,,,,,,,,,, -356600,2.9328973,1.1342942,,,,,,,,,,,,,, -356700,3.1655264,1.3832256,,,,,,,,,,,,,, -356800,3.2191045,2.7497911,,,,,,,,,,,,,, -356900,2.982057,1.1773604,,,,,,,,,,,,,, -357000,3.3227317,1.1774487,,,,,,,,,,,,,, -357100,2.9615912,2.177215,,,,,,,,,,,,,, -357200,3.0609343,1.3843822,,,,,,,,,,,,,, -357299,,,0.8887499570846558,0.4185085296630859,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,158840.51721048355,172063.7651720047,158840.51721048355,13179.165003061296,25.876455068588257,0.0 -357300,3.243897,1.0595579,,,,,,,,,,,,,, -357400,3.7627852,2.7048566,,,,,,,,,,,,,, -357500,2.9840791,1.110856,,,,,,,,,,,,,, -357600,3.0783978,1.0344709,,,,,,,,,,,,,, -357700,3.1285803,1.0769852,,,,,,,,,,,,,, -357800,3.164321,1.0943421,,,,,,,,,,,,,, -357900,2.963172,1.6672293,,,,,,,,,,,,,, -358000,2.8933067,1.888685,,,,,,,,,,,,,, -358100,3.2082634,1.1567447,,,,,,,,,,,,,, -358200,3.3565817,2.5813735,,,,,,,,,,,,,, -358239,,,0.8867577910423279,0.4219204485416412,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,159260.69488954544,172529.43064379692,159260.69488954544,13224.510422706604,25.97172403335572,0.0 -358300,2.9164467,1.7752552,,,,,,,,,,,,,, -358400,4.048757,3.0696015,,,,,,,,,,,,,, -358500,4.7561975,3.2597153,,,,,,,,,,,,,, -358600,4.3150034,3.2370617,,,,,,,,,,,,,, -358700,2.9821577,1.0622678,,,,,,,,,,,,,, -358800,2.9803874,1.145201,,,,,,,,,,,,,, -358900,3.137213,1.2532355,,,,,,,,,,,,,, -359000,3.6273608,1.1758618,,,,,,,,,,,,,, -359100,2.9497235,1.0075351,,,,,,,,,,,,,, -359182,,,0.8855664134025574,0.4237650334835052,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,159680.70743894577,172985.44624853134,159680.70743894577,13260.374766349792,26.063395261764526,0.0 -359200,3.0479336,2.352693,,,,,,,,,,,,,, -359300,2.8656583,1.1539948,,,,,,,,,,,,,, -359400,3.2164419,1.2572496,,,,,,,,,,,,,, -359500,2.9038162,1.1575261,,,,,,,,,,,,,, -359600,3.04176,2.0035896,,,,,,,,,,,,,, -359700,3.0210402,1.0974686,,,,,,,,,,,,,, -359800,3.3357716,1.110929,,,,,,,,,,,,,, -359900,3.4891171,1.1636279,,,,,,,,,,,,,, -360000,3.2336893,1.133085,,,,,,,,,,,,,, -360100,3.0832064,2.2303782,,,,,,,,,,,,,, -360127,,,0.8892382383346558,0.414288729429245,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,160100.66970348358,173442.63001036644,160100.66970348358,13297.446984291077,26.16198706626892,0.0 -360200,3.3041663,1.3789207,,,,,,,,,,,,,, -360300,3.3271809,2.1619163,,,,,,,,,,,,,, -360400,2.91347,2.3868215,,,,,,,,,,,,,, -360500,3.0293207,1.7411671,,,,,,,,,,,,,, -360600,3.0861468,0.97514063,,,,,,,,,,,,,, -360700,3.7599094,3.222546,,,,,,,,,,,,,, -360800,3.4394226,2.7269158,,,,,,,,,,,,,, -360900,3.3523345,1.128281,,,,,,,,,,,,,, -361000,3.2428582,1.0545105,,,,,,,,,,,,,, -361070,,,0.88832026720047,0.4119943976402282,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,160520.55683374405,173896.1469092369,160520.55683374405,13330.930730581284,26.26072931289673,0.0 -361100,3.1379333,1.2125005,,,,,,,,,,,,,, -361200,3.780294,3.1193368,,,,,,,,,,,,,, -361300,3.027356,1.0026362,,,,,,,,,,,,,, -361400,3.227976,1.8436298,,,,,,,,,,,,,, -361500,2.8408103,2.1765053,,,,,,,,,,,,,, -361600,3.025935,1.0572624,,,,,,,,,,,,,, -361700,2.9807494,2.1968942,,,,,,,,,,,,,, -361800,3.5209305,1.0754995,,,,,,,,,,,,,, -361900,3.1567638,1.8027797,,,,,,,,,,,,,, -362000,3.163722,2.3950088,,,,,,,,,,,,,, -362015,,,0.8843945264816284,0.4265358746051788,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,160940.7024629116,174361.82907938957,160940.7024629116,13376.318539857864,26.35774397850037,0.0 -362100,3.3340635,1.229678,,,,,,,,,,,,,, -362200,3.032211,2.2265239,,,,,,,,,,,,,, -362300,2.9593549,1.2819694,,,,,,,,,,,,,, -362400,3.600643,2.0120993,,,,,,,,,,,,,, -362500,4.9433947,2.687392,,,,,,,,,,,,,, -362600,3.286765,2.7805467,,,,,,,,,,,,,, -362700,2.9332876,1.7173362,,,,,,,,,,,,,, -362800,2.9561946,1.5904727,,,,,,,,,,,,,, -362900,2.9958303,1.2307146,,,,,,,,,,,,,, -362960,,,0.8865624666213989,0.4253043234348297,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,161361.03342723846,174815.7095863819,161361.03342723846,13409.688135623932,26.49013066291809,0.0 -363000,3.1171465,1.8615744,,,,,,,,,,,,,, -363100,3.89191,2.839844,,,,,,,,,,,,,, -363200,4.099939,3.2246792,,,,,,,,,,,,,, -363300,3.4764948,1.1878216,,,,,,,,,,,,,, -363400,3.1713731,1.0769285,,,,,,,,,,,,,, -363500,3.2873065,1.1350693,,,,,,,,,,,,,, -363600,3.084005,1.1344314,,,,,,,,,,,,,, -363700,3.081883,2.691821,,,,,,,,,,,,,, -363800,3.0794556,2.7193067,,,,,,,,,,,,,, -363900,3.0324614,1.2795775,,,,,,,,,,,,,, -363901,,,0.8876562118530273,0.4183862209320068,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,161781.51662492752,175275.9978826046,161781.51662492752,13449.352699279783,26.583407163619995,0.0 -364000,3.6092544,2.969331,,,,,,,,,,,,,, -364100,3.0585458,1.1322601,,,,,,,,,,,,,, -364200,3.8015158,3.277768,,,,,,,,,,,,,, -364300,3.3027093,1.2885896,,,,,,,,,,,,,, -364400,3.3104131,3.0913935,,,,,,,,,,,,,, -364500,3.2012217,2.234947,,,,,,,,,,,,,, -364600,3.5513935,1.0554266,,,,,,,,,,,,,, -364700,3.0703747,2.6853783,,,,,,,,,,,,,, -364800,3.1680963,1.2154603,,,,,,,,,,,,,, -364845,,,0.8862695097923279,0.4161961674690246,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,162201.56584835052,175732.1130809784,162201.56584835052,13485.275573015211,26.678946018219,0.0 -364900,3.013286,1.0783939,,,,,,,,,,,,,, -365000,2.879123,1.3782371,,,,,,,,,,,,,, -365100,3.12776,1.1439501,,,,,,,,,,,,,, -365200,2.872369,1.1085774,,,,,,,,,,,,,, -365300,3.132417,2.5316458,,,,,,,,,,,,,, -365400,3.2546954,2.6919603,,,,,,,,,,,,,, -365500,3.337861,1.4456012,,,,,,,,,,,,,, -365600,3.0836365,1.4853176,,,,,,,,,,,,,, -365700,3.6895955,1.4867176,,,,,,,,,,,,,, -365789,,,0.8881444931030273,0.419982761144638,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,162621.76594257355,176186.0605404377,162621.76594257355,13518.878060102465,26.776683807373047,0.0 -365800,3.024431,1.1233327,,,,,,,,,,,,,, -365900,3.1640916,2.0026362,,,,,,,,,,,,,, -366000,3.2924027,1.1868519,,,,,,,,,,,,,, -366100,3.1287708,1.0685276,,,,,,,,,,,,,, -366200,3.2154353,1.0844252,,,,,,,,,,,,,, -366300,3.0852747,1.2116433,,,,,,,,,,,,,, -366400,3.79523,3.2553093,,,,,,,,,,,,,, -366500,3.1309948,1.1754608,,,,,,,,,,,,,, -366600,2.8380156,1.1125947,,,,,,,,,,,,,, -366700,3.1596608,1.3581378,,,,,,,,,,,,,, -366733,,,0.8887109160423279,0.416282057762146,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,163042.15055036545,176655.54402661324,163042.15055036545,13567.830877304075,26.87555503845215,0.0 -366800,2.8187864,1.0877895,,,,,,,,,,,,,, -366900,3.223838,1.1117008,,,,,,,,,,,,,, -367000,2.9300942,1.6585922,,,,,,,,,,,,,, -367100,3.136809,1.1771487,,,,,,,,,,,,,, -367200,3.3852272,2.863877,,,,,,,,,,,,,, -367300,3.2094984,2.0208406,,,,,,,,,,,,,, -367400,2.9361916,1.156455,,,,,,,,,,,,,, -367500,3.178804,1.1658223,,,,,,,,,,,,,, -367600,3.2478685,1.3866388,,,,,,,,,,,,,, -367683,,,0.8882421851158142,0.4190069139003753,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,163462.26739168167,177116.63298726082,163462.26739168167,13608.676812171936,26.95441627502441,0.0 -367700,3.18539,1.673415,,,,,,,,,,,,,, -367800,3.508186,2.8021114,,,,,,,,,,,,,, -367900,3.4239793,2.344928,,,,,,,,,,,,,, -368000,2.8785763,1.0902431,,,,,,,,,,,,,, -368100,3.8443677,2.7027247,,,,,,,,,,,,,, -368200,3.2497902,2.8121018,,,,,,,,,,,,,, -368300,2.9765658,1.9109769,,,,,,,,,,,,,, -368400,3.2097824,1.101865,,,,,,,,,,,,,, -368500,3.1763244,2.2779226,,,,,,,,,,,,,, -368600,3.2362995,1.7798947,,,,,,,,,,,,,, -368627,,,0.8858202695846558,0.4272821247577667,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,163882.3846206665,177584.1802227497,163882.3846206665,13655.980248212814,27.03390669822693,0.0 -368700,3.124837,1.0748824,,,,,,,,,,,,,, -368800,3.2493894,1.2547405,,,,,,,,,,,,,, -368900,3.3670223,1.473967,,,,,,,,,,,,,, -369000,3.7014585,3.3604705,,,,,,,,,,,,,, -369100,3.1357868,1.1332622,,,,,,,,,,,,,, -369200,2.9714184,1.4369204,,,,,,,,,,,,,, -369300,2.9910126,1.5739257,,,,,,,,,,,,,, -369400,2.9766936,1.0212687,,,,,,,,,,,,,, -369500,3.1660037,1.7425387,,,,,,,,,,,,,, -369571,,,0.8865429759025574,0.4198426008224487,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,164302.32819080353,178037.7103281021,164302.32819080353,13689.44184088707,27.11146354675293,0.0 -369600,3.8725195,3.3698177,,,,,,,,,,,,,, -369700,3.0593114,1.2384212,,,,,,,,,,,,,, -369800,2.8770792,1.2023368,,,,,,,,,,,,,, -369900,2.8367875,1.0628481,,,,,,,,,,,,,, -370000,3.7333937,3.0932713,,,,,,,,,,,,,, -370100,2.9953823,1.4016094,,,,,,,,,,,,,, -370200,3.0452,1.2273757,,,,,,,,,,,,,, -370300,2.9939973,1.200676,,,,,,,,,,,,,, -370400,2.9588072,1.2561516,,,,,,,,,,,,,, -370500,2.991419,1.1501682,,,,,,,,,,,,,, -370512,,,0.8881640434265137,0.4155753254890442,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,164722.48730134964,178494.94316363335,164722.48730134964,13726.37277674675,27.207013607025143,0.0 -370600,3.1138797,1.0104356,,,,,,,,,,,,,, -370700,2.8905747,1.1329764,,,,,,,,,,,,,, -370800,2.984297,1.8482956,,,,,,,,,,,,,, -370900,2.9494755,1.283565,,,,,,,,,,,,,, -371000,3.1489944,1.11219,,,,,,,,,,,,,, -371100,2.9559298,1.1349415,,,,,,,,,,,,,, -371200,2.9777005,1.1604826,,,,,,,,,,,,,, -371300,3.2028341,1.1101718,,,,,,,,,,,,,, -371400,2.969285,1.1968676,,,,,,,,,,,,,, -371455,,,0.88916015625,0.4131352603435516,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,165142.49481630325,178955.66393136978,165142.49481630325,13766.942764282228,27.30364155769348,0.0 -371500,3.1292489,1.4944026,,,,,,,,,,,,,, -371600,2.9093273,1.2883191,,,,,,,,,,,,,, -371700,2.9905422,1.0802226,,,,,,,,,,,,,, -371800,3.0463223,1.0962147,,,,,,,,,,,,,, -371900,2.988844,1.087799,,,,,,,,,,,,,, -372000,3.6868424,2.185937,,,,,,,,,,,,,, -372100,3.2251344,1.2408142,,,,,,,,,,,,,, -372200,2.9848711,1.0183021,,,,,,,,,,,,,, -372300,3.6592302,3.304814,,,,,,,,,,,,,, -372400,2.8542526,1.0783856,,,,,,,,,,,,,, -372401,,,0.88525390625,0.4247466921806335,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,165562.63774585724,179409.72325754166,165562.63774585724,13800.73223042488,27.383100271224976,0.0 -372500,3.2896204,2.8199344,,,,,,,,,,,,,, -372600,3.0766673,1.0620514,,,,,,,,,,,,,, -372700,3.2396867,2.306148,,,,,,,,,,,,,, -372800,3.2991333,2.3246484,,,,,,,,,,,,,, -372900,3.2209773,2.773637,,,,,,,,,,,,,, -373000,3.1030726,1.1835095,,,,,,,,,,,,,, -373100,3.4562454,2.7823527,,,,,,,,,,,,,, -373200,2.8544843,1.4576957,,,,,,,,,,,,,, -373300,3.414758,2.946449,,,,,,,,,,,,,, -373342,,,0.8884375095367432,0.4191824495792389,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,165982.9108800888,179872.6829378605,165982.9108800888,13843.269656419754,27.48445630073548,0.0 -373400,4.1006193,2.303382,,,,,,,,,,,,,, -373500,3.0944989,1.160101,,,,,,,,,,,,,, -373600,3.5367897,2.9806738,,,,,,,,,,,,,, -373700,2.96072,1.1508844,,,,,,,,,,,,,, -373800,3.422102,1.0898558,,,,,,,,,,,,,, -373900,2.8375378,1.9702809,,,,,,,,,,,,,, -374000,2.9580631,1.0486841,,,,,,,,,,,,,, -374100,3.109762,1.0946865,,,,,,,,,,,,,, -374200,3.3412797,1.9620792,,,,,,,,,,,,,, -374287,,,0.8906054496765137,0.4155364036560058,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,166403.01735568047,180335.7163796425,166403.01735568047,13886.068606615068,27.56479525566101,0.0 -374300,3.0635161,1.4279037,,,,,,,,,,,,,, -374400,3.1595469,2.7377768,,,,,,,,,,,,,, -374500,3.0023286,1.12082,,,,,,,,,,,,,, -374600,3.583522,3.0819097,,,,,,,,,,,,,, -374700,4.8623295,3.312085,,,,,,,,,,,,,, -374800,3.237565,1.0884929,,,,,,,,,,,,,, -374900,2.8958902,1.3053485,,,,,,,,,,,,,, -375000,3.1067913,2.1209002,,,,,,,,,,,,,, -375100,3.0897293,2.368268,,,,,,,,,,,,,, -375200,3.070176,1.9088087,,,,,,,,,,,,,, -375230,,,0.8887304663658142,0.4109485149383545,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,166822.80471730232,180787.6889693737,166822.80471730232,13917.866647958755,27.90357255935669,0.0 -375300,2.896469,1.0871584,,,,,,,,,,,,,, -375400,3.1422646,1.1587453,,,,,,,,,,,,,, -375500,3.218546,1.2469485,,,,,,,,,,,,,, -375600,3.6462207,3.0214953,,,,,,,,,,,,,, -375700,3.3159773,2.8112066,,,,,,,,,,,,,, -375800,3.1756175,1.8393393,,,,,,,,,,,,,, -375900,3.020261,1.8945837,,,,,,,,,,,,,, -376000,3.0881095,1.0924311,,,,,,,,,,,,,, -376100,3.0915687,1.1717814,,,,,,,,,,,,,, -376172,,,0.88587886095047,0.4229264855384826,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,167243.02319812775,181253.0316066742,167243.02319812775,13962.844014406204,28.00294804573059,0.0 -376200,3.1157067,1.0902869,,,,,,,,,,,,,, -376300,3.3427374,2.6653175,,,,,,,,,,,,,, -376400,3.1583858,1.1481199,,,,,,,,,,,,,, -376500,3.0196073,2.0347352,,,,,,,,,,,,,, -376600,3.3747559,1.848757,,,,,,,,,,,,,, -376700,3.1841078,1.1562594,,,,,,,,,,,,,, -376800,2.8806186,1.289335,,,,,,,,,,,,,, -376900,3.387115,1.2053229,,,,,,,,,,,,,, -377000,3.0138173,2.6532981,,,,,,,,,,,,,, -377100,3.1403575,1.0944362,,,,,,,,,,,,,, -377119,,,0.8873242139816284,0.4139610826969147,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,167663.00806236267,181717.4283750057,167663.00806236267,14007.10686326027,28.10410571098328,0.0 -377200,3.6515303,2.8776917,,,,,,,,,,,,,, -377300,3.454532,1.1823815,,,,,,,,,,,,,, -377400,3.0259542,1.2582414,,,,,,,,,,,,,, -377500,3.174276,2.62789,,,,,,,,,,,,,, -377600,2.9083447,2.1903288,,,,,,,,,,,,,, -377700,2.8884425,1.9546361,,,,,,,,,,,,,, -377800,3.023432,1.3965561,,,,,,,,,,,,,, -377900,3.109933,1.1428044,,,,,,,,,,,,,, -378000,4.0137506,3.1000948,,,,,,,,,,,,,, -378063,,,0.8871679306030273,0.4175764620304107,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,168082.953353405,182171.3619320393,168082.953353405,14040.968413114548,28.18370485305786,0.0 -378100,3.0934074,1.3117181,,,,,,,,,,,,,, -378200,4.159212,3.059535,,,,,,,,,,,,,, -378300,3.1220117,1.7355852,,,,,,,,,,,,,, -378400,3.1554418,1.6355895,,,,,,,,,,,,,, -378500,2.9214542,1.0944755,,,,,,,,,,,,,, -378600,3.833359,3.275504,,,,,,,,,,,,,, -378700,3.8613603,3.2577186,,,,,,,,,,,,,, -378800,3.7756622,1.0730696,,,,,,,,,,,,,, -378900,3.1147356,1.1403931,,,,,,,,,,,,,, -379000,,,0.8867382407188416,0.4216572046279907,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,168502.90359401703,182627.8619401455,168502.90359401703,14077.369477033615,28.28520369529724,0.0 -379000,3.2030382,2.2512321,,,,,,,,,,,,,, -379100,2.9336185,1.3342284,,,,,,,,,,,,,, -379200,3.062302,1.1058376,,,,,,,,,,,,,, -379300,2.9867382,1.1006898,,,,,,,,,,,,,, -379400,3.0924923,1.8964617,,,,,,,,,,,,,, -379500,3.2289832,1.708422,,,,,,,,,,,,,, -379600,3.0277276,1.0599743,,,,,,,,,,,,,, -379700,3.0317938,1.1033629,,,,,,,,,,,,,, -379800,3.1452432,1.1392683,,,,,,,,,,,,,, -379900,3.0242357,1.1968964,,,,,,,,,,,,,, -379940,,,0.8876171708106995,0.4176125228404999,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,168923.17162704468,183095.3925318718,168923.17162704468,14124.485233306885,28.38457489013672,0.0 -380000,3.029076,1.108013,,,,,,,,,,,,,, -380100,3.0558152,1.2874445,,,,,,,,,,,,,, -380200,3.0549495,2.7422967,,,,,,,,,,,,,, -380300,3.2581224,2.879668,,,,,,,,,,,,,, -380400,2.969334,1.0732993,,,,,,,,,,,,,, -380500,2.8919642,2.1992855,,,,,,,,,,,,,, -380600,3.0784616,1.1987128,,,,,,,,,,,,,, -380700,3.3552785,2.1901968,,,,,,,,,,,,,, -380800,2.9497914,1.1199933,,,,,,,,,,,,,, -380885,,,0.8860155940055847,0.426149308681488,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,169343.40920114517,183554.4429168701,169343.40920114517,14163.1655189991,28.470067501068115,0.0 -380900,2.909833,1.9779963,,,,,,,,,,,,,, -381000,3.033342,1.0260915,,,,,,,,,,,,,, -381100,4.268845,3.1278203,,,,,,,,,,,,,, -381200,4.2464747,2.4550855,,,,,,,,,,,,,, -381300,3.179054,2.2378109,,,,,,,,,,,,,, -381400,3.336779,1.1069723,,,,,,,,,,,,,, -381500,2.9522033,1.1777,,,,,,,,,,,,,, -381600,3.5461009,2.9591491,,,,,,,,,,,,,, -381700,2.972987,1.0468196,,,,,,,,,,,,,, -381800,3.2166946,2.062744,,,,,,,,,,,,,, -381830,,,0.8877539038658142,0.4175314307212829,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,169763.40131759644,184013.37898135185,169763.40131759644,14201.979838609695,28.55268931388855,0.0 -381900,3.4418628,2.616854,,,,,,,,,,,,,, -382000,3.1713977,1.212741,,,,,,,,,,,,,, -382100,2.901184,2.276609,,,,,,,,,,,,,, -382200,3.8578305,3.3236935,,,,,,,,,,,,,, -382300,2.9961436,1.0275671,,,,,,,,,,,,,, -382400,3.129888,2.4692132,,,,,,,,,,,,,, -382500,3.4538245,3.0486958,,,,,,,,,,,,,, -382600,2.887098,1.931664,,,,,,,,,,,,,, -382700,3.254214,1.0688353,,,,,,,,,,,,,, -382773,,,0.88636714220047,0.4217416942119598,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,170183.68056845665,184473.9022257328,170183.68056845665,14242.088454008102,28.640817880630493,0.0 -382800,2.9764383,1.0210911,,,,,,,,,,,,,, -382900,3.3309352,2.4944344,,,,,,,,,,,,,, -383000,3.0059035,1.143773,,,,,,,,,,,,,, -383100,2.9995682,1.1563891,,,,,,,,,,,,,, -383200,2.9834409,2.0626397,,,,,,,,,,,,,, -383300,3.5243838,1.3468361,,,,,,,,,,,,,, -383400,3.2283125,2.3228893,,,,,,,,,,,,,, -383500,3.0648348,1.0788445,,,,,,,,,,,,,, -383600,3.3777313,1.2025883,,,,,,,,,,,,,, -383700,3.274608,2.5943882,,,,,,,,,,,,,, -383717,,,0.8875781297683716,0.4189814925193786,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,170603.55612325668,184932.04956531525,170603.55612325668,14280.231041908264,28.722613096237183,0.0 -383800,3.1774476,1.4556614,,,,,,,,,,,,,, -383900,3.9783165,2.7206616,,,,,,,,,,,,,, -384000,3.4869947,2.307129,,,,,,,,,,,,,, -384100,4.3110523,3.3165774,,,,,,,,,,,,,, -384200,3.1163723,2.394508,,,,,,,,,,,,,, -384300,3.7317011,2.946835,,,,,,,,,,,,,, -384400,2.9882393,2.1678696,,,,,,,,,,,,,, -384500,3.2419915,1.1834906,,,,,,,,,,,,,, -384600,3.3056426,1.2687126,,,,,,,,,,,,,, -384657,,,0.887499988079071,0.4142645299434662,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,171023.6480166912,185389.73262786865,171023.6480166912,14317.67663049698,28.820887088775635,0.0 -384700,3.8339005,3.3067868,,,,,,,,,,,,,, -384800,3.1175117,1.1303279,,,,,,,,,,,,,, -384900,2.9934928,1.2524273,,,,,,,,,,,,,, -385000,4.269878,3.2438762,,,,,,,,,,,,,, -385100,2.9787936,2.6338801,,,,,,,,,,,,,, -385200,3.3658645,2.8367083,,,,,,,,,,,,,, -385300,3.0133533,1.1552908,,,,,,,,,,,,,, -385400,3.1627936,1.274296,,,,,,,,,,,,,, -385500,3.3523822,1.1993843,,,,,,,,,,,,,, -385599,,,0.8870898485183716,0.4226619601249695,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,171443.81467032433,185854.7215340137,171443.81467032433,14362.350728034971,28.92026162147522,0.0 -385600,3.2251358,1.2416935,,,,,,,,,,,,,, -385700,2.8953931,1.0052474,,,,,,,,,,,,,, -385800,3.0873098,1.1923614,,,,,,,,,,,,,, -385900,2.8866901,1.1771055,,,,,,,,,,,,,, -386000,3.6129303,3.2351575,,,,,,,,,,,,,, -386100,3.3385866,2.5784807,,,,,,,,,,,,,, -386200,2.8705132,1.6906309,,,,,,,,,,,,,, -386300,3.9554837,3.2121038,,,,,,,,,,,,,, -386400,3.2301078,2.3650484,,,,,,,,,,,,,, -386500,3.3410115,1.211848,,,,,,,,,,,,,, -386545,,,0.886035144329071,0.4234894216060638,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,171863.97700834274,186307.54375863075,171863.97700834274,14394.879940986631,29.00299835205078,0.0 -386600,3.1135662,1.2765676,,,,,,,,,,,,,, -386700,3.0529754,1.2197988,,,,,,,,,,,,,, -386800,3.3073578,1.2558662,,,,,,,,,,,,,, -386900,3.0132065,2.4215977,,,,,,,,,,,,,, -387000,3.0466058,1.206888,,,,,,,,,,,,,, -387100,3.5215392,1.1942713,,,,,,,,,,,,,, -387200,3.2432835,1.1155642,,,,,,,,,,,,,, -387300,2.9470224,2.1486826,,,,,,,,,,,,,, -387400,2.9985743,1.2987105,,,,,,,,,,,,,, -387487,,,0.8861913681030273,0.4212893843650818,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,172284.1622555256,186765.5475420952,172284.1622555256,14432.549918174744,29.103593349456787,0.0 -387500,3.0185728,2.1275668,,,,,,,,,,,,,, -387600,2.9149365,1.061032,,,,,,,,,,,,,, -387700,2.7529004,1.9638369,,,,,,,,,,,,,, -387800,3.150157,1.1292406,,,,,,,,,,,,,, -387900,2.7792666,1.575047,,,,,,,,,,,,,, -388000,3.237061,1.1189637,,,,,,,,,,,,,, -388100,2.9074924,1.1219411,,,,,,,,,,,,,, -388200,2.938879,2.0410244,,,,,,,,,,,,,, -388300,3.2433062,1.0574249,,,,,,,,,,,,,, -388400,3.6082935,1.2133578,,,,,,,,,,,,,, -388432,,,0.887499988079071,0.4149161875247955,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,172704.41179394722,187223.76943206787,172704.41179394722,14470.37551188469,29.203126668930054,0.0 -388500,2.8950083,2.0543635,,,,,,,,,,,,,, -388600,3.0408108,1.3048302,,,,,,,,,,,,,, -388700,3.1574686,1.7301695,,,,,,,,,,,,,, -388800,2.934718,1.3023791,,,,,,,,,,,,,, -388900,3.305919,1.373155,,,,,,,,,,,,,, -389000,2.9535313,1.0549611,,,,,,,,,,,,,, -389100,3.6891587,2.9496856,,,,,,,,,,,,,, -389200,2.9737613,1.0882344,,,,,,,,,,,,,, -389300,3.0529163,1.0608456,,,,,,,,,,,,,, -389377,,,0.8874413967132568,0.4166697859764099,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,173124.3463385105,187679.0199213028,173124.3463385105,14505.541404247284,29.30449271202088,0.0 -389400,3.9166453,3.180008,,,,,,,,,,,,,, -389500,3.1326551,1.1549404,,,,,,,,,,,,,, -389600,3.5467513,2.9207072,,,,,,,,,,,,,, -389700,2.808117,1.6178477,,,,,,,,,,,,,, -389800,3.193115,1.1537551,,,,,,,,,,,,,, -389900,9.637664,2.3750768,,,,,,,,,,,,,, -390000,3.1499023,1.2363971,,,,,,,,,,,,,, -390100,3.0381684,1.205933,,,,,,,,,,,,,, -390200,2.9016097,1.0830237,,,,,,,,,,,,,, -390300,3.2529876,1.1525875,,,,,,,,,,,,,, -390321,,,0.8872265219688416,0.4215186536312103,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,173544.2095386982,188138.7891843319,173544.2095386982,14545.296684980392,29.40795946121216,0.0 -390400,2.8651693,1.022617,,,,,,,,,,,,,, -390500,3.0962598,1.2211673,,,,,,,,,,,,,, -390600,3.3435326,2.733994,,,,,,,,,,,,,, -390700,3.434796,1.2084059,,,,,,,,,,,,,, -390800,3.6505375,3.1684783,,,,,,,,,,,,,, -390900,2.9623687,1.4888406,,,,,,,,,,,,,, -391000,3.2106357,1.0704193,,,,,,,,,,,,,, -391100,3.0360208,1.4417585,,,,,,,,,,,,,, -391200,3.1225212,1.8566397,,,,,,,,,,,,,, -391267,,,0.8884375095367432,0.4242092669010162,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,173964.30621051788,188592.8221981525,173964.30621051788,14579.097965717316,29.495544910430908,0.0 -391300,3.0100188,1.1245692,,,,,,,,,,,,,, -391400,3.0782752,1.3101743,,,,,,,,,,,,,, -391500,3.9351609,3.1554146,,,,,,,,,,,,,, -391600,3.2254055,1.3193221,,,,,,,,,,,,,, -391700,3.1063886,1.1692039,,,,,,,,,,,,,, -391800,3.0611973,1.2439967,,,,,,,,,,,,,, -391900,3.0824854,1.0690073,,,,,,,,,,,,,, -392000,3.2154033,1.14562,,,,,,,,,,,,,, -392100,3.1056335,1.1273112,,,,,,,,,,,,,, -392200,3.2338774,1.1361301,,,,,,,,,,,,,, -392207,,,0.8878905773162842,0.4217130243778229,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,174384.32402396202,189051.9608139992,174384.32402396202,14618.065967798231,29.601163387298584,0.0 -392300,2.972864,2.3159122,,,,,,,,,,,,,, -392400,2.9928124,1.1322813,,,,,,,,,,,,,, -392500,2.9718199,2.3645613,,,,,,,,,,,,,, -392600,5.4764886,1.7660718,,,,,,,,,,,,,, -392700,3.5765166,1.5549588,,,,,,,,,,,,,, -392800,3.0764651,1.712215,,,,,,,,,,,,,, -392900,3.9789495,1.2017074,,,,,,,,,,,,,, -393000,5.0056195,3.153533,,,,,,,,,,,,,, -393100,3.4761846,1.0920467,,,,,,,,,,,,,, -393149,,,0.8867577910423279,0.4197116792201996,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,174804.2192106247,189509.67946982384,174804.2192106247,14655.7409658432,29.7026801109314,0.0 -393200,3.1533706,1.0848794,,,,,,,,,,,,,, -393300,3.256093,1.390435,,,,,,,,,,,,,, -393400,3.0738409,1.3974543,,,,,,,,,,,,,, -393500,3.4699583,2.6936433,,,,,,,,,,,,,, -393600,3.5017645,1.1931957,,,,,,,,,,,,,, -393700,3.1324308,2.0934691,,,,,,,,,,,,,, -393800,3.1450565,1.1924304,,,,,,,,,,,,,, -393900,3.2377408,1.0700814,,,,,,,,,,,,,, -394000,3.3801901,1.1299164,,,,,,,,,,,,,, -394093,,,0.8876757621765137,0.416228324174881,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,175224.10652852058,189971.56618881223,175224.10652852058,14697.594356775284,29.80150079727173,0.0 -394100,3.7941213,3.0019767,,,,,,,,,,,,,, -394200,3.6054966,3.2410128,,,,,,,,,,,,,, -394300,3.327021,2.9493725,,,,,,,,,,,,,, -394400,3.035068,1.4107113,,,,,,,,,,,,,, -394500,3.457843,1.22556,,,,,,,,,,,,,, -394600,3.1055932,1.4891458,,,,,,,,,,,,,, -394700,2.9849434,1.0374324,,,,,,,,,,,,,, -394800,2.9197211,1.1013471,,,,,,,,,,,,,, -394900,3.0663462,1.3713253,,,,,,,,,,,,,, -395000,3.0507185,1.1301612,,,,,,,,,,,,,, -395038,,,0.8874413967132568,0.4183365404605865,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,175644.32959270477,190430.5368359089,175644.32959270477,14736.21154475212,29.8844575881958,0.0 -395100,2.946443,1.04243,,,,,,,,,,,,,, -395200,3.3473825,1.111016,,,,,,,,,,,,,, -395300,3.0040755,1.0206351,,,,,,,,,,,,,, -395400,3.7501743,2.238812,,,,,,,,,,,,,, -395500,3.01876,1.9010112,,,,,,,,,,,,,, -395600,3.0129647,1.1158844,,,,,,,,,,,,,, -395700,2.9509914,1.4619772,,,,,,,,,,,,,, -395800,4.1494527,3.165967,,,,,,,,,,,,,, -395900,3.2354188,1.1324042,,,,,,,,,,,,,, -395981,,,0.8874804377555847,0.4170241057872772,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,176064.4482076168,190893.79946780205,176064.4482076168,14779.204402923584,29.98852634429932,0.0 -396000,3.53067,1.1617275,,,,,,,,,,,,,, -396100,2.8325675,1.1072924,,,,,,,,,,,,,, -396200,3.0968437,1.0939355,,,,,,,,,,,,,, -396300,3.529036,2.893601,,,,,,,,,,,,,, -396400,3.1705217,1.1857753,,,,,,,,,,,,,, -396500,3.1283786,1.50603,,,,,,,,,,,,,, -396600,3.367873,2.6298735,,,,,,,,,,,,,, -396700,3.1021938,1.1586045,,,,,,,,,,,,,, -396800,3.0560162,1.1609185,,,,,,,,,,,,,, -396900,3.2888606,2.8292224,,,,,,,,,,,,,, -396924,,,0.8891991972923279,0.4172369539737701,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,176484.37260246277,191352.2600080967,176484.37260246277,14817.607572078705,30.07315492630005,0.0 -397000,3.0188932,1.1234362,,,,,,,,,,,,,, -397100,3.2105725,1.0257504,,,,,,,,,,,,,, -397200,3.9091303,3.2622468,,,,,,,,,,,,,, -397300,3.0053983,1.082055,,,,,,,,,,,,,, -397400,2.9023547,1.9640839,,,,,,,,,,,,,, -397500,2.8139446,1.0985954,,,,,,,,,,,,,, -397600,3.1556475,1.0506971,,,,,,,,,,,,,, -397700,2.6683762,1.6356146,,,,,,,,,,,,,, -397800,2.9228017,1.1805829,,,,,,,,,,,,,, -397870,,,0.8872265219688416,0.4217978119850158,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,176904.55462956429,191808.3955821991,176904.55462956429,14853.429702758787,30.15562510490417,0.0 -397900,3.0933347,1.5081615,,,,,,,,,,,,,, -398000,3.227717,2.7156358,,,,,,,,,,,,,, -398100,3.2041278,2.1692505,,,,,,,,,,,,,, -398200,4.0761137,3.1327348,,,,,,,,,,,,,, -398300,3.1573172,1.2172172,,,,,,,,,,,,,, -398400,3.0430553,1.2065588,,,,,,,,,,,,,, -398500,3.1060796,1.5912784,,,,,,,,,,,,,, -398600,3.5518355,2.9266837,,,,,,,,,,,,,, -398700,2.9412963,1.7006657,,,,,,,,,,,,,, -398800,3.251749,1.2229302,,,,,,,,,,,,,, -398814,,,0.8900976181030273,0.4069123864173889,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,177324.88240146637,192264.0342531204,177324.88240146637,14888.592317581177,30.25603485107422,0.0 -398900,3.907846,3.164788,,,,,,,,,,,,,, -399000,3.0541854,1.2726518,,,,,,,,,,,,,, -399100,2.9541821,1.0724375,,,,,,,,,,,,,, -399200,2.9009814,1.7916682,,,,,,,,,,,,,, -399300,3.1247547,2.7153924,,,,,,,,,,,,,, -399400,3.3506074,1.2254474,,,,,,,,,,,,,, -399500,2.930031,1.5034889,,,,,,,,,,,,,, -399600,3.0110207,1.1609019,,,,,,,,,,,,,, -399700,3.286207,2.833455,,,,,,,,,,,,,, -399751,,,0.8873632550239563,0.4189270436763763,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,177745.00523638725,192722.6485464573,177745.00523638725,14926.936472415924,30.356215715408325,0.0 -399800,3.171714,1.0874875,,,,,,,,,,,,,, -399900,2.9705904,1.081481,,,,,,,,,,,,,, -400000,3.424288,2.8523984,,,,,,,,,,,,,, -400100,3.1740959,2.564507,,,,,,,,,,,,,, -400200,2.8664484,1.3020339,,,,,,,,,,,,,, -400300,3.1277533,1.9535958,,,,,,,,,,,,,, -400400,2.9423127,1.1577477,,,,,,,,,,,,,, -400500,3.487013,2.9752777,,,,,,,,,,,,,, -400600,3.1464663,2.45884,,,,,,,,,,,,,, -400693,,,0.8883788585662842,0.4117428958415985,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,178164.88333821297,193180.8615772724,178164.88333821297,14965.11044239998,30.468629837036133,0.0 -400700,3.0265224,1.0801595,,,,,,,,,,,,,, -400800,3.1100698,1.8118316,,,,,,,,,,,,,, -400900,3.3435068,1.4690851,,,,,,,,,,,,,, -401000,2.92148,1.2195572,,,,,,,,,,,,,, -401100,3.2892025,1.2271441,,,,,,,,,,,,,, -401200,2.9861042,1.180604,,,,,,,,,,,,,, -401300,3.14785,1.1653677,,,,,,,,,,,,,, -401400,3.3778522,1.5086911,,,,,,,,,,,,,, -401500,3.766246,3.2216678,,,,,,,,,,,,,, -401600,3.1902788,2.2178807,,,,,,,,,,,,,, -401638,,,0.8850390315055847,0.4258567094802856,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,178585.00466036797,193643.89689183235,178585.00466036797,15007.873401165009,30.57186913490296,0.0 -401700,2.9931736,1.1670829,,,,,,,,,,,,,, -401800,3.8017113,3.26232,,,,,,,,,,,,,, -401900,3.2624724,1.2285315,,,,,,,,,,,,,, -402000,2.9628592,1.149251,,,,,,,,,,,,,, -402100,3.2104623,2.6778588,,,,,,,,,,,,,, -402200,3.2192779,2.87682,,,,,,,,,,,,,, -402300,3.366858,3.0686388,,,,,,,,,,,,,, -402400,2.96639,1.7192144,,,,,,,,,,,,,, -402500,4.311824,3.234054,,,,,,,,,,,,,, -402585,,,0.8865624666213989,0.4191708266735077,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,179004.96719145775,194098.60548830032,179004.96719145775,15042.488958358765,30.655258893966675,0.0 -402600,3.2118773,1.0818728,,,,,,,,,,,,,, -402700,3.97165,3.270104,,,,,,,,,,,,,, -402800,2.9687483,1.0964833,,,,,,,,,,,,,, -402900,2.9224079,1.1534805,,,,,,,,,,,,,, -403000,3.6791048,1.4102821,,,,,,,,,,,,,, -403100,3.3667393,1.1654812,,,,,,,,,,,,,, -403200,3.0104804,2.0153332,,,,,,,,,,,,,, -403300,3.3029354,1.0799264,,,,,,,,,,,,,, -403400,3.9019375,3.17222,,,,,,,,,,,,,, -403500,2.8147647,1.3928491,,,,,,,,,,,,,, -403526,,,0.8875585794448853,0.4197700917720794,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,179425.1990661621,194564.8892962933,179425.1990661621,15088.383921384811,30.7644944190979,0.0 -403600,2.9909763,1.0893927,,,,,,,,,,,,,, -403700,3.1765072,1.2537129,,,,,,,,,,,,,, -403800,2.9426775,0.9934006,,,,,,,,,,,,,, -403900,3.0603237,1.3267679,,,,,,,,,,,,,, -404000,5.152493,3.073393,,,,,,,,,,,,,, -404100,3.1231995,1.352088,,,,,,,,,,,,,, -404200,3.2381868,1.467883,,,,,,,,,,,,,, -404300,3.1861434,2.0793433,,,,,,,,,,,,,, -404400,3.0478506,2.0733302,,,,,,,,,,,,,, -404467,,,0.8872265219688416,0.4191464781761169,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,179845.30004048347,195027.85599136355,179845.30004048347,15131.103526830671,30.86377716064453,0.0 -404500,3.246492,1.1618916,,,,,,,,,,,,,, -404600,3.0885046,2.1879945,,,,,,,,,,,,,, -404700,3.0288608,1.3029372,,,,,,,,,,,,,, -404800,3.0659318,1.8566039,,,,,,,,,,,,,, -404900,3.4317894,1.1210625,,,,,,,,,,,,,, -405000,3.0597394,1.1130667,,,,,,,,,,,,,, -405100,3.0760844,1.1183982,,,,,,,,,,,,,, -405200,3.1805725,1.1668943,,,,,,,,,,,,,, -405300,3.3762343,1.1608086,,,,,,,,,,,,,, -405400,3.0007117,1.9355174,,,,,,,,,,,,,, -405411,,,0.8865624666213989,0.4221104681491852,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,180265.17570757863,195483.1527121067,180265.17570757863,15166.39280295372,30.94795870780945,0.0 -405500,3.0257406,1.1507239,,,,,,,,,,,,,, -405600,3.1533134,1.1667212,,,,,,,,,,,,,, -405700,3.3508592,1.1041263,,,,,,,,,,,,,, -405800,3.015422,1.178477,,,,,,,,,,,,,, -405900,3.0576575,1.2620519,,,,,,,,,,,,,, -406000,3.8035016,3.1005147,,,,,,,,,,,,,, -406100,4.7205567,1.1336472,,,,,,,,,,,,,, -406200,3.4150236,1.368664,,,,,,,,,,,,,, -406300,3.2841542,1.1283957,,,,,,,,,,,,,, -406350,,,0.8861327767372131,0.4223325252532959,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,180685.35479688644,195943.6081705093,180685.35479688644,15206.514477968216,31.055206060409542,0.0 -406400,2.9877055,1.5392684,,,,,,,,,,,,,, -406500,3.0638478,1.132698,,,,,,,,,,,,,, -406600,3.0557249,1.0336592,,,,,,,,,,,,,, -406700,3.12107,1.1418922,,,,,,,,,,,,,, -406800,3.0491629,1.4790602,,,,,,,,,,,,,, -406900,3.0154958,1.068339,,,,,,,,,,,,,, -407000,4.9353046,2.8746774,,,,,,,,,,,,,, -407100,3.0469508,2.3634052,,,,,,,,,,,,,, -407200,3.20646,1.0530927,,,,,,,,,,,,,, -407288,,,0.888476550579071,0.4176208078861236,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,181105.31928634644,196401.7843978405,181105.31928634644,15244.572483301165,31.16160798072815,0.0 -407300,2.9084256,2.1647449,,,,,,,,,,,,,, -407400,3.9461718,3.3633616,,,,,,,,,,,,,, -407500,3.3180804,1.6348937,,,,,,,,,,,,,, -407600,2.995521,1.9604175,,,,,,,,,,,,,, -407700,3.132797,1.2338041,,,,,,,,,,,,,, -407800,3.178459,1.2115986,,,,,,,,,,,,,, -407900,3.0947099,1.2883384,,,,,,,,,,,,,, -408000,3.2353895,1.1123679,,,,,,,,,,,,,, -408100,3.0571568,1.79438,,,,,,,,,,,,,, -408200,2.959432,1.025058,,,,,,,,,,,,,, -408228,,,0.8869140148162842,0.4151829779148102,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,181525.1695902348,196860.4535934925,181525.1695902348,15283.238829135897,31.26688051223755,0.0 -408300,2.9556582,1.6648413,,,,,,,,,,,,,, -408400,3.3710308,3.0445306,,,,,,,,,,,,,, -408500,3.473667,1.1840777,,,,,,,,,,,,,, -408600,3.3974087,3.1035433,,,,,,,,,,,,,, -408700,3.0839913,1.2890779,,,,,,,,,,,,,, -408800,3.071048,1.1414642,,,,,,,,,,,,,, -408900,5.3244014,1.4640794,,,,,,,,,,,,,, -409000,3.2192438,1.1335381,,,,,,,,,,,,,, -409100,3.270426,2.4919078,,,,,,,,,,,,,, -409172,,,0.8871874809265137,0.4227842986583709,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,181945.4384617805,197322.14072799683,181945.4384617805,15324.498523712158,31.37296724319458,0.0 -409200,3.6841297,2.7457929,,,,,,,,,,,,,, -409300,3.163289,2.2652116,,,,,,,,,,,,,, -409400,3.5647886,3.0681605,,,,,,,,,,,,,, -409500,2.9548187,1.9645822,,,,,,,,,,,,,, -409600,3.2054257,1.551727,,,,,,,,,,,,,, -409700,3.304254,1.1986375,,,,,,,,,,,,,, -409800,3.082349,1.760377,,,,,,,,,,,,,, -409900,3.5152822,3.0085871,,,,,,,,,,,,,, -410000,3.0045483,1.8886769,,,,,,,,,,,,,, -410100,2.9601924,1.5291631,,,,,,,,,,,,,, -410120,,,0.8855859041213989,0.4234414994716644,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,182365.6897525788,197784.67691898343,182365.6897525788,15366.648778438568,31.45956587791443,0.0 -410200,2.9599886,1.1386582,,,,,,,,,,,,,, -410300,2.9407578,1.0143425,,,,,,,,,,,,,, -410400,2.7912772,1.1738815,,,,,,,,,,,,,, -410500,2.9929454,1.1473696,,,,,,,,,,,,,, -410600,2.957985,2.3028016,,,,,,,,,,,,,, -410700,2.9351122,2.5508792,,,,,,,,,,,,,, -410800,2.8875208,1.7539642,,,,,,,,,,,,,, -410900,3.1360087,1.1644382,,,,,,,,,,,,,, -411000,3.0781066,1.1466265,,,,,,,,,,,,,, -411060,,,0.8866015672683716,0.4224465787410736,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,182785.7644975185,198239.89599585533,182785.7644975185,15401.639070272446,31.566688537597656,0.0 -411100,3.171512,1.2603025,,,,,,,,,,,,,, -411200,3.090679,2.072135,,,,,,,,,,,,,, -411300,3.1053896,1.9585534,,,,,,,,,,,,,, -411400,3.1770687,2.8772197,,,,,,,,,,,,,, -411500,2.8142781,1.9422464,,,,,,,,,,,,,, -411600,3.13164,1.2739784,,,,,,,,,,,,,, -411700,2.9568443,1.0592242,,,,,,,,,,,,,, -411800,3.263769,1.156152,,,,,,,,,,,,,, -411900,3.019198,1.7666626,,,,,,,,,,,,,, -411999,,,0.8878124952316284,0.4114729464054107,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,183206.07600712776,198703.19691753387,183206.07600712776,15444.472546577454,31.67509889602661,0.0 -412000,3.229859,1.1465205,,,,,,,,,,,,,, -412100,3.1512563,1.4410928,,,,,,,,,,,,,, -412200,2.9925303,2.2039592,,,,,,,,,,,,,, -412300,3.1791153,1.4039692,,,,,,,,,,,,,, -412400,2.898224,1.2945307,,,,,,,,,,,,,, -412500,3.3867085,1.163479,,,,,,,,,,,,,, -412600,3.6306927,3.288067,,,,,,,,,,,,,, -412700,2.9327202,1.1677647,,,,,,,,,,,,,, -412800,3.0514138,1.0562565,,,,,,,,,,,,,, -412900,3.0848243,1.1581476,,,,,,,,,,,,,, -412943,,,0.8874609470367432,0.419604629278183,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,183626.06824731827,199157.4434304237,183626.06824731827,15478.594518899918,31.759069442749023,0.0 -413000,3.7705686,3.1987262,,,,,,,,,,,,,, -413100,2.9749484,1.0612863,,,,,,,,,,,,,, -413200,3.3179657,1.108685,,,,,,,,,,,,,, -413300,3.2249951,1.1210718,,,,,,,,,,,,,, -413400,3.4097674,2.8748348,,,,,,,,,,,,,, -413500,3.7851858,3.192749,,,,,,,,,,,,,, -413600,3.4510674,1.6294264,,,,,,,,,,,,,, -413700,2.8780048,1.1109178,,,,,,,,,,,,,, -413800,3.136998,2.3266907,,,,,,,,,,,,,, -413886,,,0.8876367211341858,0.4223132729530334,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,184045.93671751025,199615.89252972603,184045.93671751025,15517.017944574356,31.868417978286743,0.0 -413900,3.9693882,2.7248971,,,,,,,,,,,,,, -414000,3.008263,1.2524426,,,,,,,,,,,,,, -414100,2.7600284,1.0595709,,,,,,,,,,,,,, -414200,3.2927277,2.8501918,,,,,,,,,,,,,, -414300,3.0166435,1.1173893,,,,,,,,,,,,,, -414400,3.4191835,2.3630972,,,,,,,,,,,,,, -414500,3.1279948,1.2261096,,,,,,,,,,,,,, -414600,3.340715,1.0671233,,,,,,,,,,,,,, -414700,2.8854172,1.0235354,,,,,,,,,,,,,, -414800,3.2173295,1.3362502,,,,,,,,,,,,,, -414826,,,0.8863281011581421,0.427275151014328,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,184465.8177063465,200077.1666204929,184465.8177063465,15558.254828453064,31.97731304168701,0.0 -414900,3.4834223,2.4491453,,,,,,,,,,,,,, -415000,3.0513508,1.1853358,,,,,,,,,,,,,, -415100,3.1438076,1.2316306,,,,,,,,,,,,,, -415200,3.1868808,1.1866755,,,,,,,,,,,,,, -415300,3.222845,1.154109,,,,,,,,,,,,,, -415400,3.222659,2.2249856,,,,,,,,,,,,,, -415500,3.211746,2.4878979,,,,,,,,,,,,,, -415600,3.0704963,1.1707878,,,,,,,,,,,,,, -415700,3.1083438,1.125573,,,,,,,,,,,,,, -415771,,,0.8896093368530273,0.4156470894813537,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,184885.7917456627,200532.31643295288,184885.7917456627,15593.294227838516,32.06606459617615,0.0 -415800,2.8914204,1.7099581,,,,,,,,,,,,,, -415900,3.0930924,1.0985318,,,,,,,,,,,,,, -416000,3.5011501,2.9128828,,,,,,,,,,,,,, -416100,3.7039165,3.1544554,,,,,,,,,,,,,, -416200,3.2707798,2.8320708,,,,,,,,,,,,,, -416300,3.8217623,3.3504791,,,,,,,,,,,,,, -416400,3.0726333,1.8407751,,,,,,,,,,,,,, -416500,3.0594873,1.2309839,,,,,,,,,,,,,, -416600,3.0792112,1.1544116,,,,,,,,,,,,,, -416700,2.9520257,1.8571837,,,,,,,,,,,,,, -416709,,,0.8847265243530273,0.4236800968647003,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,185305.9206287861,200987.87162804604,185305.9206287861,15628.56196141243,32.17793083190918,0.0 -416800,3.0747437,1.1162812,,,,,,,,,,,,,, -416900,3.8699007,1.117941,,,,,,,,,,,,,, -417000,2.9221756,1.416959,,,,,,,,,,,,,, -417100,3.2859309,2.9391067,,,,,,,,,,,,,, -417200,3.2082572,2.788513,,,,,,,,,,,,,, -417300,3.2243292,2.202289,,,,,,,,,,,,,, -417400,3.016863,1.1011509,,,,,,,,,,,,,, -417500,3.177167,1.9168541,,,,,,,,,,,,,, -417600,3.0500388,1.7673817,,,,,,,,,,,,,, -417649,,,0.8902343511581421,0.4112899601459503,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,185726.07363700867,201445.11868214607,185726.07363700867,15665.49923491478,32.28747606277466,0.0 -417700,3.6517901,3.0121806,,,,,,,,,,,,,, -417800,2.928657,1.8267083,,,,,,,,,,,,,, -417900,3.4700463,3.0340538,,,,,,,,,,,,,, -418000,3.8978984,3.1875904,,,,,,,,,,,,,, -418100,3.2328386,3.0079925,,,,,,,,,,,,,, -418200,3.0036364,1.1371859,,,,,,,,,,,,,, -418300,3.94377,3.2288911,,,,,,,,,,,,,, -418400,3.1209548,1.18259,,,,,,,,,,,,,, -418500,3.126968,1.1821119,,,,,,,,,,,,,, -418587,,,0.887499988079071,0.4170810580253601,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,186146.1403939724,201909.37362527847,186146.1403939724,15709.535342693329,32.39239454269409,0.0 -418600,3.162863,2.1495838,,,,,,,,,,,,,, -418700,3.4588268,1.8853964,,,,,,,,,,,,,, -418800,3.2379434,2.676468,,,,,,,,,,,,,, -418900,3.0677164,1.2292763,,,,,,,,,,,,,, -419000,3.309395,2.222642,,,,,,,,,,,,,, -419100,3.2116468,1.091327,,,,,,,,,,,,,, -419200,3.0460532,1.1225872,,,,,,,,,,,,,, -419300,3.1241093,1.0586411,,,,,,,,,,,,,, -419400,2.9756508,1.3777776,,,,,,,,,,,,,, -419500,3.0017366,1.5629395,,,,,,,,,,,,,, -419523,,,0.8847265243530273,0.4248861670494079,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,186566.229950428,202371.4406733513,186566.229950428,15751.373156309128,32.48406267166138,0.0 -419600,3.1967075,1.9762547,,,,,,,,,,,,,, -419700,2.9667299,1.7635032,,,,,,,,,,,,,, -419800,3.275628,1.1413437,,,,,,,,,,,,,, -419900,3.5201578,3.057301,,,,,,,,,,,,,, -420000,2.9389498,1.8953214,,,,,,,,,,,,,, -420100,3.092882,1.3030481,,,,,,,,,,,,,, -420200,3.0932806,1.1513813,,,,,,,,,,,,,, -420300,3.2344282,1.1289954,,,,,,,,,,,,,, -420400,2.9781916,1.2615058,,,,,,,,,,,,,, -420468,,,0.8904687166213989,0.4141132235527038,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,186986.4781191349,202825.1754875183,186986.4781191349,15784.72314286232,32.571659564971924,0.0 -420500,3.0734427,1.1669703,,,,,,,,,,,,,, -420600,3.2135704,1.1398548,,,,,,,,,,,,,, -420700,3.4478936,1.2013546,,,,,,,,,,,,,, -420800,3.08184,2.7739134,,,,,,,,,,,,,, -420900,3.431615,1.7748265,,,,,,,,,,,,,, -421000,3.2927048,1.1163522,,,,,,,,,,,,,, -421100,3.0945213,1.462061,,,,,,,,,,,,,, -421200,3.1142669,1.2192638,,,,,,,,,,,,,, -421300,3.7568247,3.25206,,,,,,,,,,,,,, -421400,3.4548337,2.8229132,,,,,,,,,,,,,, -421408,,,0.8883984088897705,0.4225490093231201,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,187406.4428319931,203285.67652893063,187406.4428319931,15825.103631973268,32.67925763130188,0.0 -421500,3.2600415,2.8605895,,,,,,,,,,,,,, -421600,3.1700141,1.2782016,,,,,,,,,,,,,, -421700,3.070622,1.1120188,,,,,,,,,,,,,, -421800,3.8235688,3.3199053,,,,,,,,,,,,,, -421900,2.9675665,1.1192331,,,,,,,,,,,,,, -422000,3.2135596,2.679838,,,,,,,,,,,,,, -422100,3.342985,2.7201521,,,,,,,,,,,,,, -422200,3.4116905,2.8404903,,,,,,,,,,,,,, -422300,2.9181364,1.9862548,,,,,,,,,,,,,, -422352,,,0.8893554210662842,0.4090215563774109,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,187826.50917100903,203746.17261266708,187826.50917100903,15865.375450849531,32.78850722312927,0.0 -422400,3.2964141,2.7655225,,,,,,,,,,,,,, -422500,3.0958636,1.2023275,,,,,,,,,,,,,, -422600,5.921455,1.2151331,,,,,,,,,,,,,, -422700,2.984485,1.8645463,,,,,,,,,,,,,, -422800,2.8924074,1.0064628,,,,,,,,,,,,,, -422900,3.212112,1.5497009,,,,,,,,,,,,,, -423000,3.6247082,3.1266422,,,,,,,,,,,,,, -423100,3.0161102,1.2300997,,,,,,,,,,,,,, -423200,3.3616204,1.300312,,,,,,,,,,,,,, -423296,,,0.8877343535423279,0.4170146882534027,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,188246.44891786567,204207.1106221676,188246.44891786567,15906.216959238052,32.89739751815796,0.0 -423300,3.0905864,1.2933817,,,,,,,,,,,,,, -423400,2.849155,1.6613276,,,,,,,,,,,,,, -423500,2.9435163,1.623979,,,,,,,,,,,,,, -423600,3.7206347,3.228584,,,,,,,,,,,,,, -423700,3.1965632,1.1805807,,,,,,,,,,,,,, -423800,3.3933735,2.3002315,,,,,,,,,,,,,, -423900,3.1892703,1.0804749,,,,,,,,,,,,,, -424000,3.2382507,1.1212262,,,,,,,,,,,,,, -424100,3.1353397,1.6732426,,,,,,,,,,,,,, -424200,3.0428343,1.2544601,,,,,,,,,,,,,, -424240,,,0.8865820169448853,0.4170560240745544,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,188666.66090917587,204663.19302153587,188666.66090917587,15941.91059589386,33.02621150016785,0.0 -424300,3.05617,1.4555359,,,,,,,,,,,,,, -424400,3.143913,2.7208943,,,,,,,,,,,,,, -424500,2.9275856,1.1838616,,,,,,,,,,,,,, -424600,3.2252455,1.1956273,,,,,,,,,,,,,, -424700,3.3335028,2.7383037,,,,,,,,,,,,,, -424800,3.3380094,1.1561339,,,,,,,,,,,,,, -424900,3.1631324,1.0655721,,,,,,,,,,,,,, -425000,3.213257,1.9595684,,,,,,,,,,,,,, -425100,3.2008681,1.7681172,,,,,,,,,,,,,, -425183,,,0.8876562118530273,0.4162637591361999,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,189086.88722491264,205123.508122921,189086.88722491264,15981.843644618988,33.13454604148865,0.0 -425200,3.3951614,1.50388,,,,,,,,,,,,,, -425300,3.765379,2.8398452,,,,,,,,,,,,,, -425400,3.085048,1.070427,,,,,,,,,,,,,, -425500,3.956205,3.2555878,,,,,,,,,,,,,, -425600,3.3822916,2.8731313,,,,,,,,,,,,,, -425700,3.052359,1.1541387,,,,,,,,,,,,,, -425800,3.2848506,1.9778364,,,,,,,,,,,,,, -425900,2.838496,1.0785326,,,,,,,,,,,,,, -426000,2.8697095,1.2624966,,,,,,,,,,,,,, -426100,3.5921302,1.1315317,,,,,,,,,,,,,, -426119,,,0.8857812285423279,0.4212409257888794,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,189507.0796597004,205588.7849709988,189507.0796597004,16026.771565198898,33.24440002441406,0.0 -426200,2.9913955,2.2793431,,,,,,,,,,,,,, -426300,3.220547,1.2918143,,,,,,,,,,,,,, -426400,3.096689,2.2738535,,,,,,,,,,,,,, -426500,3.3260949,1.1628087,,,,,,,,,,,,,, -426600,3.4179175,1.3817785,,,,,,,,,,,,,, -426700,2.9754968,2.0934997,,,,,,,,,,,,,, -426800,2.7867162,1.351943,,,,,,,,,,,,,, -426900,2.9717002,1.0828611,,,,,,,,,,,,,, -427000,3.5537393,1.0824791,,,,,,,,,,,,,, -427064,,,0.8869726657867432,0.4184642732143402,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,189927.3677642345,206045.5145745277,189927.3677642345,16063.07480955124,33.3340208530426,0.0 -427100,3.939063,1.9084445,,,,,,,,,,,,,, -427200,3.0245247,1.0818087,,,,,,,,,,,,,, -427300,2.7847214,1.5390553,,,,,,,,,,,,,, -427400,3.3197148,2.606236,,,,,,,,,,,,,, -427500,3.0990424,1.0679065,,,,,,,,,,,,,, -427600,3.4352024,1.3715159,,,,,,,,,,,,,, -427700,3.146535,2.7703135,,,,,,,,,,,,,, -427800,4.2936006,3.1952393,,,,,,,,,,,,,, -427900,3.1015456,1.6204658,,,,,,,,,,,,,, -428000,2.9301689,1.0754597,,,,,,,,,,,,,, -428007,,,0.8869726657867432,0.4254200160503387,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,190347.5967078209,206501.27928853035,190347.5967078209,16098.452381849287,33.44491386413574,0.0 -428100,2.889043,1.1529111,,,,,,,,,,,,,, -428200,3.001221,1.1178159,,,,,,,,,,,,,, -428300,3.2137828,1.0472444,,,,,,,,,,,,,, -428400,2.9743268,2.058193,,,,,,,,,,,,,, -428500,2.8707194,2.1415,,,,,,,,,,,,,, -428600,3.8041248,3.0698247,,,,,,,,,,,,,, -428700,2.8968985,1.1484946,,,,,,,,,,,,,, -428800,3.0092273,1.2929096,,,,,,,,,,,,,, -428900,3.1970787,2.507803,,,,,,,,,,,,,, -428950,,,0.8871093392372131,0.4201515614986419,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,190767.7202951908,206968.0955746174,190767.7202951908,16144.987778663635,33.55448365211487,0.0 -429000,2.9376614,1.4079934,,,,,,,,,,,,,, -429100,3.2369838,1.2592869,,,,,,,,,,,,,, -429200,2.950099,1.9266298,,,,,,,,,,,,,, -429300,3.191445,2.4625762,,,,,,,,,,,,,, -429400,3.2832868,1.1712174,,,,,,,,,,,,,, -429500,3.1876857,1.234389,,,,,,,,,,,,,, -429600,3.3722186,1.1286243,,,,,,,,,,,,,, -429700,3.0355253,1.0638832,,,,,,,,,,,,,, -429800,3.4713569,2.9880657,,,,,,,,,,,,,, -429896,,,0.8857030868530273,0.4262813031673431,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,191187.6201922893,207424.07229161265,191187.6201922893,16180.92680835724,33.644153118133545,0.0 -429900,2.8795478,1.0446634,,,,,,,,,,,,,, -430000,3.2004669,1.0843658,,,,,,,,,,,,,, -430100,3.1778083,1.5736123,,,,,,,,,,,,,, -430200,2.9814596,1.7539048,,,,,,,,,,,,,, -430300,3.2658234,2.8317757,,,,,,,,,,,,,, -430400,3.1072807,1.6152368,,,,,,,,,,,,,, -430500,3.655045,1.149599,,,,,,,,,,,,,, -430600,3.1422353,1.5272956,,,,,,,,,,,,,, -430700,3.1702092,1.1568952,,,,,,,,,,,,,, -430800,2.9825084,1.4684236,,,,,,,,,,,,,, -430841,,,0.8873046636581421,0.4140742123126983,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,191607.90561056137,207880.95266985893,191607.90561056137,16217.364193677902,33.7524573802948,0.0 -430900,2.9897795,1.8150394,,,,,,,,,,,,,, -431000,2.9309916,1.3895589,,,,,,,,,,,,,, -431100,3.2786734,1.1091143,,,,,,,,,,,,,, -431200,4.0797987,2.8031824,,,,,,,,,,,,,, -431300,3.9044533,3.3477736,,,,,,,,,,,,,, -431400,3.07554,2.6188195,,,,,,,,,,,,,, -431500,3.176229,1.0152413,,,,,,,,,,,,,, -431600,2.9349124,1.0902058,,,,,,,,,,,,,, -431700,2.8866062,1.0777988,,,,,,,,,,,,,, -431784,,,0.8875781297683716,0.4148840606212616,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,192027.8820848465,208343.6129004956,192027.8820848465,16259.892963647842,33.86008358001709,0.0 -431800,3.6689053,1.2036141,,,,,,,,,,,,,, -431900,3.3772745,1.2068728,,,,,,,,,,,,,, -432000,6.0352206,3.143766,,,,,,,,,,,,,, -432100,3.295223,1.2118177,,,,,,,,,,,,,, -432200,3.4020188,1.420127,,,,,,,,,,,,,, -432300,3.107161,1.2015183,,,,,,,,,,,,,, -432400,3.0759115,2.046232,,,,,,,,,,,,,, -432500,2.988278,1.404085,,,,,,,,,,,,,, -432600,2.9984362,1.1100717,,,,,,,,,,,,,, -432700,3.2971566,1.1231331,,,,,,,,,,,,,, -432731,,,0.8882616758346558,0.4209084510803222,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,192447.784427166,208798.56092476845,192447.784427166,16294.80163049698,33.95030069351196,0.0 -432800,3.1889894,2.191484,,,,,,,,,,,,,, -432900,3.550177,3.1587138,,,,,,,,,,,,,, -433000,3.199989,1.0933652,,,,,,,,,,,,,, -433100,3.6039512,2.076242,,,,,,,,,,,,,, -433200,3.1099174,2.5737386,,,,,,,,,,,,,, -433300,3.4878376,2.6171944,,,,,,,,,,,,,, -433400,3.5346413,3.118138,,,,,,,,,,,,,, -433500,3.03257,1.193959,,,,,,,,,,,,,, -433600,3.5270684,1.8658383,,,,,,,,,,,,,, -433674,,,0.8856444954872131,0.4259785115718841,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,192867.92846989632,209256.31863760948,192867.92846989632,16332.21807217598,34.1000452041626,0.0 -433700,3.6703615,3.2550902,,,,,,,,,,,,,, -433800,3.3757527,1.0568858,,,,,,,,,,,,,, -433900,4.1286845,3.336625,,,,,,,,,,,,,, -434000,3.3924253,2.5944927,,,,,,,,,,,,,, -434100,2.9803903,1.7736293,,,,,,,,,,,,,, -434200,3.0170572,1.381745,,,,,,,,,,,,,, -434300,3.0383918,1.104665,,,,,,,,,,,,,, -434400,2.8279104,1.1270066,,,,,,,,,,,,,, -434500,3.0396748,2.3617773,,,,,,,,,,,,,, -434600,2.9743445,2.1368973,,,,,,,,,,,,,, -434615,,,0.8866601586341858,0.4217454791069031,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,193288.04555511475,209710.3728477955,193288.04555511475,16366.020912408829,34.18677401542664,0.0 -434700,3.054832,1.2472538,,,,,,,,,,,,,, -434800,4.2269316,2.98651,,,,,,,,,,,,,, -434900,3.4116437,3.167846,,,,,,,,,,,,,, -435000,3.2259407,1.1079106,,,,,,,,,,,,,, -435100,3.1538463,1.1475961,,,,,,,,,,,,,, -435200,3.2119143,1.2763807,,,,,,,,,,,,,, -435300,3.6995306,3.2578976,,,,,,,,,,,,,, -435400,3.0683408,1.2719713,,,,,,,,,,,,,, -435500,3.4987006,2.8191547,,,,,,,,,,,,,, -435556,,,0.8860546946525574,0.4153732657432556,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,193708.16326212883,210174.040797472,193708.16326212883,16409.411342144012,34.299379110336304,0.0 -435600,2.8859355,1.9387538,,,,,,,,,,,,,, -435700,3.135236,1.4479038,,,,,,,,,,,,,, -435800,2.9952967,1.1802983,,,,,,,,,,,,,, -435900,3.033555,1.3361065,,,,,,,,,,,,,, -436000,3.1202679,2.0821314,,,,,,,,,,,,,, -436100,3.2231472,2.4466698,,,,,,,,,,,,,, -436200,3.1139,1.0310289,,,,,,,,,,,,,, -436300,3.019667,1.6921214,,,,,,,,,,,,,, -436400,3.1001742,1.0973654,,,,,,,,,,,,,, -436500,,,0.8884179592132568,0.4191081523895263,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,194128.2306342125,210631.71815299988,194128.2306342125,16446.87612438202,34.396831035614014,0.0 -436500,3.2789507,2.5960605,,,,,,,,,,,,,, -436600,3.1563,1.2802908,,,,,,,,,,,,,, -436700,3.0202975,2.3059757,,,,,,,,,,,,,, -436800,2.95658,1.2234602,,,,,,,,,,,,,, -436900,3.1400023,1.2883288,,,,,,,,,,,,,, -437000,2.8940628,1.1714755,,,,,,,,,,,,,, -437100,4.698124,3.1749215,,,,,,,,,,,,,, -437200,3.121953,1.6425772,,,,,,,,,,,,,, -437300,3.5879388,2.880903,,,,,,,,,,,,,, -437400,3.391938,1.8052049,,,,,,,,,,,,,, -437447,,,0.8879492282867432,0.421378344297409,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,194548.28590345383,211088.8171780109,194548.28590345383,16483.761477947235,34.50777578353882,0.0 -437500,3.6810012,3.069866,,,,,,,,,,,,,, -437600,3.1312404,1.0996485,,,,,,,,,,,,,, -437700,3.0999057,1.9253627,,,,,,,,,,,,,, -437800,2.9783187,2.0553489,,,,,,,,,,,,,, -437900,3.0658724,1.2287453,,,,,,,,,,,,,, -438000,3.0507214,1.2354695,,,,,,,,,,,,,, -438100,3.3284261,1.1251289,,,,,,,,,,,,,, -438200,3.8244283,3.2471144,,,,,,,,,,,,,, -438300,3.0066798,1.4646075,,,,,,,,,,,,,, -438388,,,0.8866991996765137,0.4202461242675781,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,194968.3395268917,211553.87763905525,194968.3395268917,16528.610441684723,34.61800146102905,0.0 -438400,3.2709954,2.3748662,,,,,,,,,,,,,, -438500,3.2505827,1.6403792,,,,,,,,,,,,,, -438600,2.9057992,1.0340514,,,,,,,,,,,,,, -438700,3.184177,1.6428187,,,,,,,,,,,,,, -438800,3.3072078,1.2086364,,,,,,,,,,,,,, -438900,3.2827835,1.1715221,,,,,,,,,,,,,, -439000,3.0311844,1.3816395,,,,,,,,,,,,,, -439100,3.0535722,1.5192163,,,,,,,,,,,,,, -439200,3.3962808,2.883956,,,,,,,,,,,,,, -439300,3.2923503,1.7562332,,,,,,,,,,,,,, -439332,,,0.8890234231948853,0.4170263707637787,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,195388.52269601825,212017.23931241035,195388.52269601825,16571.6326918602,34.726662397384644,0.0 -439400,3.0520074,1.0710521,,,,,,,,,,,,,, -439500,3.2225084,1.151319,,,,,,,,,,,,,, -439600,3.0394702,1.6052408,,,,,,,,,,,,,, -439700,3.3798504,2.290306,,,,,,,,,,,,,, -439800,3.4498067,1.0621581,,,,,,,,,,,,,, -439900,4.4344554,3.1956787,,,,,,,,,,,,,, -440000,2.8977017,1.2769548,,,,,,,,,,,,,, -440100,3.1451907,0.97695684,,,,,,,,,,,,,, -440200,2.916829,1.0868958,,,,,,,,,,,,,, -440278,,,0.8864843845367432,0.4210885167121887,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,195808.53741788864,212475.6008841992,195808.53741788864,16609.84241771698,34.8164746761322,0.0 -440300,2.7348096,1.1106237,,,,,,,,,,,,,, -440400,3.2100973,1.2386808,,,,,,,,,,,,,, -440500,3.0003686,1.1948797,,,,,,,,,,,,,, -440600,3.135582,1.2010443,,,,,,,,,,,,,, -440700,2.8551834,1.561404,,,,,,,,,,,,,, -440800,3.2663527,2.7618325,,,,,,,,,,,,,, -440900,3.16323,1.1738738,,,,,,,,,,,,,, -441000,3.447348,2.9989605,,,,,,,,,,,,,, -441100,3.0430036,1.5205683,,,,,,,,,,,,,, -441200,3.516063,1.2675151,,,,,,,,,,,,,, -441221,,,0.8875195384025574,0.4173341393470764,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,196228.62652277944,212942.71294903755,196228.62652277944,16656.709317207336,34.92500853538513,0.0 -441300,3.6485186,2.933256,,,,,,,,,,,,,, -441400,3.3938618,1.170224,,,,,,,,,,,,,, -441500,3.0438738,1.1730413,,,,,,,,,,,,,, -441600,3.3519113,2.7866244,,,,,,,,,,,,,, -441700,3.7235458,2.7622766,,,,,,,,,,,,,, -441800,2.908204,1.1429902,,,,,,,,,,,,,, -441900,3.8426106,3.3634462,,,,,,,,,,,,,, -442000,3.2490118,1.558631,,,,,,,,,,,,,, -442100,3.297742,2.8328052,,,,,,,,,,,,,, -442159,,,0.8889257907867432,0.4166601002216339,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,196648.6313741207,213407.027728796,196648.6313741207,16700.883150815964,35.01406455039978,0.0 -442200,3.358823,1.1795074,,,,,,,,,,,,,, -442300,3.5151358,1.2278395,,,,,,,,,,,,,, -442400,3.1208167,1.306386,,,,,,,,,,,,,, -442500,3.5113797,1.1981372,,,,,,,,,,,,,, -442600,3.2610588,1.7722086,,,,,,,,,,,,,, -442700,3.2657719,1.1211791,,,,,,,,,,,,,, -442800,2.8853996,1.9270182,,,,,,,,,,,,,, -442900,2.9434168,1.3733178,,,,,,,,,,,,,, -443000,3.021943,1.5402181,,,,,,,,,,,,,, -443100,3.292246,1.3046055,,,,,,,,,,,,,, -443101,,,0.8854296803474426,0.422122985124588,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,197069.1317648888,213871.40952277184,197069.1317648888,16744.626331090927,35.10526466369629,0.0 -443200,3.0601046,1.7634689,,,,,,,,,,,,,, -443300,3.2861834,1.1863358,,,,,,,,,,,,,, -443400,3.3494642,2.9898114,,,,,,,,,,,,,, -443500,3.045347,1.7284554,,,,,,,,,,,,,, -443600,3.2995193,1.0644367,,,,,,,,,,,,,, -443700,3.4524825,1.1909657,,,,,,,,,,,,,, -443800,3.0283988,1.6720164,,,,,,,,,,,,,, -443900,3.0545998,1.1612028,,,,,,,,,,,,,, -444000,3.750696,3.294655,,,,,,,,,,,,,, -444043,,,0.8884179592132568,0.4219891726970672,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,197489.02456712723,214331.8288094997,197489.02456712723,16785.014815568924,35.19585084915161,0.0 -444100,3.8301525,2.9999962,,,,,,,,,,,,,, -444200,3.1570423,1.0655304,,,,,,,,,,,,,, -444300,3.9214275,1.1524422,,,,,,,,,,,,,, -444400,3.204091,1.1597962,,,,,,,,,,,,,, -444500,2.9606786,1.0097355,,,,,,,,,,,,,, -444600,3.1908636,1.3802104,,,,,,,,,,,,,, -444700,4.2098227,3.1363983,,,,,,,,,,,,,, -444800,3.157223,1.2514004,,,,,,,,,,,,,, -444900,3.1621604,1.0786498,,,,,,,,,,,,,, -444988,,,0.8905078172683716,0.4116933345794678,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,197909.07106089592,214794.9362232685,197909.07106089592,16827.93798494339,35.28695893287659,0.0 -445000,2.9830256,1.0793557,,,,,,,,,,,,,, -445100,3.4446218,3.1874843,,,,,,,,,,,,,, -445200,3.2752235,2.797108,,,,,,,,,,,,,, -445300,3.03209,2.3244734,,,,,,,,,,,,,, -445400,3.1207242,2.3268926,,,,,,,,,,,,,, -445500,3.1975694,2.2628613,,,,,,,,,,,,,, -445600,3.1949792,2.903082,,,,,,,,,,,,,, -445700,3.059381,1.0273255,,,,,,,,,,,,,, -445800,2.9846811,1.2368394,,,,,,,,,,,,,, -445900,3.2363153,1.2217376,,,,,,,,,,,,,, -445931,,,0.8903515338897705,0.4074158072471618,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,198329.2227098941,215258.19314956665,198329.2227098941,16870.90295481682,35.37960910797119,0.0 -446000,2.8842666,1.6350179,,,,,,,,,,,,,, -446100,3.1848474,1.2010279,,,,,,,,,,,,,, -446200,2.988748,1.0050447,,,,,,,,,,,,,, -446300,3.0921338,1.0344017,,,,,,,,,,,,,, -446400,3.8613014,2.8711467,,,,,,,,,,,,,, -446500,3.2427804,1.3047327,,,,,,,,,,,,,, -446600,3.0413184,1.121497,,,,,,,,,,,,,, -446700,3.7157085,3.264086,,,,,,,,,,,,,, -446800,3.178181,2.6562662,,,,,,,,,,,,,, -446876,,,0.885546863079071,0.4211876988410949,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,198749.5303769112,215723.04054903984,198749.5303769112,16915.30412220955,35.47115755081177,0.0 -446900,3.3897443,2.395718,,,,,,,,,,,,,, -447000,3.6634927,1.2078992,,,,,,,,,,,,,, -447100,3.2542284,1.2277236,,,,,,,,,,,,,, -447200,3.2082853,1.0623976,,,,,,,,,,,,,, -447300,4.4315104,3.0433447,,,,,,,,,,,,,, -447400,3.0536158,1.0937293,,,,,,,,,,,,,, -447500,3.1526618,1.1352339,,,,,,,,,,,,,, -447600,3.9278572,3.253165,,,,,,,,,,,,,, -447700,2.9081335,1.3309346,,,,,,,,,,,,,, -447800,3.7984776,3.2215605,,,,,,,,,,,,,, -447823,,,0.887011706829071,0.4194007813930511,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,199169.5191576481,216180.07539200783,199169.5191576481,16952.20995235443,35.563923597335815,0.0 -447900,3.1553776,1.2878178,,,,,,,,,,,,,, -448000,3.1689122,2.047883,,,,,,,,,,,,,, -448100,3.3300817,1.6745774,,,,,,,,,,,,,, -448200,3.4764435,1.3347893,,,,,,,,,,,,,, -448300,2.737029,1.3203447,,,,,,,,,,,,,, -448400,3.1790192,1.2116232,,,,,,,,,,,,,, -448500,3.7883139,1.1099695,,,,,,,,,,,,,, -448600,2.9912603,1.1091155,,,,,,,,,,,,,, -448700,3.2778144,1.0289397,,,,,,,,,,,,,, -448761,,,0.8859765529632568,0.421368658542633,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,199589.37185049057,216638.24611639977,199589.37185049057,16990.376428365707,35.668083906173706,0.0 -448800,3.191733,1.3511146,,,,,,,,,,,,,, -448900,3.142849,1.1320903,,,,,,,,,,,,,, -449000,3.1180449,1.9194701,,,,,,,,,,,,,, -449100,3.2568588,2.9836674,,,,,,,,,,,,,, -449200,3.2927823,1.1180631,,,,,,,,,,,,,, -449300,2.8590758,1.0489076,,,,,,,,,,,,,, -449400,3.4345624,2.7524347,,,,,,,,,,,,,, -449500,2.9503393,1.9492037,,,,,,,,,,,,,, -449600,3.703178,2.9185448,,,,,,,,,,,,,, -449700,3.2538526,1.1390004,,,,,,,,,,,,,, -449701,,,0.8870702981948853,0.4166412651538849,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,200009.4314689636,217103.29435062408,200009.4314689636,17035.225475549698,35.76078271865845,0.0 -449800,3.3067472,1.079432,,,,,,,,,,,,,, -449900,5.1982317,2.7348635,,,,,,,,,,,,,, -450000,3.1577363,1.0847325,,,,,,,,,,,,,, -450100,3.0996857,1.124594,,,,,,,,,,,,,, -450200,3.2973382,1.1461537,,,,,,,,,,,,,, -450300,3.1888404,1.1481416,,,,,,,,,,,,,, -450400,3.1365654,1.1521728,,,,,,,,,,,,,, -450500,3.1434371,2.7728448,,,,,,,,,,,,,, -450600,3.1253138,1.2294061,,,,,,,,,,,,,, -450645,,,0.8881444931030273,0.4162316024303436,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,200429.42027401924,217559.97447752955,200429.42027401924,17071.777951478958,35.85300397872925,0.0 -450700,3.0096269,2.2040343,,,,,,,,,,,,,, -450800,3.2552783,1.25645,,,,,,,,,,,,,, -450900,3.0642264,1.0315882,,,,,,,,,,,,,, -451000,2.9147797,1.4299525,,,,,,,,,,,,,, -451100,3.8316627,3.310535,,,,,,,,,,,,,, -451200,3.1736062,1.0862507,,,,,,,,,,,,,, -451300,3.1254072,1.1314212,,,,,,,,,,,,,, -451400,3.2080183,2.5796053,,,,,,,,,,,,,, -451500,3.1823933,1.1701039,,,,,,,,,,,,,, -451587,,,0.8881054520606995,0.421012133359909,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,200849.3727684021,218030.3255007267,200849.3727684021,17122.015122890472,35.96689581871033,0.0 -451600,3.1265988,2.9084167,,,,,,,,,,,,,, -451700,3.7970543,1.6266031,,,,,,,,,,,,,, -451800,3.5678005,3.1611907,,,,,,,,,,,,,, -451900,3.0283315,2.0121825,,,,,,,,,,,,,, -452000,3.102532,1.0976739,,,,,,,,,,,,,, -452100,3.0620036,1.1362562,,,,,,,,,,,,,, -452200,3.3016012,1.0810726,,,,,,,,,,,,,, -452300,3.2278762,1.1367667,,,,,,,,,,,,,, -452400,3.3930662,1.614562,,,,,,,,,,,,,, -452500,3.289604,2.9016533,,,,,,,,,,,,,, -452532,,,0.8839452862739563,0.4290750026702881,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,201269.47407794,218486.012780428,201269.47407794,17157.460511684418,36.06063795089722,0.0 -452600,3.353381,2.9036098,,,,,,,,,,,,,, -452700,3.319605,2.6187005,,,,,,,,,,,,,, -452800,3.5820494,2.7588806,,,,,,,,,,,,,, -452900,3.0936615,1.18551,,,,,,,,,,,,,, -453000,3.6970012,3.1143515,,,,,,,,,,,,,, -453100,3.1826313,1.2426957,,,,,,,,,,,,,, -453200,2.839131,0.9608082,,,,,,,,,,,,,, -453300,3.0533016,1.1014978,,,,,,,,,,,,,, -453400,3.1651,1.6079677,,,,,,,,,,,,,, -453473,,,0.8886913657188416,0.4158783555030823,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,201689.4728207588,218944.8034465313,201689.4728207588,17196.084241628647,36.18183422088623,0.0 -453500,3.3562455,2.1849244,,,,,,,,,,,,,, -453600,2.9218907,1.6048427,,,,,,,,,,,,,, -453700,3.1225693,1.1150674,,,,,,,,,,,,,, -453800,3.2055933,1.5816967,,,,,,,,,,,,,, -453900,3.0028288,2.108015,,,,,,,,,,,,,, -454000,3.3765643,3.0514631,,,,,,,,,,,,,, -454100,3.170674,1.0858108,,,,,,,,,,,,,, -454200,3.0556238,2.022274,,,,,,,,,,,,,, -454300,3.0350025,1.1560401,,,,,,,,,,,,,, -454400,3.0464487,1.6177788,,,,,,,,,,,,,, -454414,,,0.8854491710662842,0.4210602045059204,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,202109.38947105408,219410.560652256,202109.38947105408,17241.766832351685,36.29235625267029,0.0 -454500,3.555757,2.4882293,,,,,,,,,,,,,, -454600,3.2624955,1.157965,,,,,,,,,,,,,, -454700,2.98878,1.9280082,,,,,,,,,,,,,, -454800,3.701441,3.2594872,,,,,,,,,,,,,, -454900,2.9191463,2.067562,,,,,,,,,,,,,, -455000,3.3638496,1.121794,,,,,,,,,,,,,, -455100,2.976488,1.6889267,,,,,,,,,,,,,, -455200,3.2127347,1.1428616,,,,,,,,,,,,,, -455300,3.5788262,1.2305727,,,,,,,,,,,,,, -455360,,,0.8898632526397705,0.4121298491954803,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,202529.48514270785,219867.0714662075,202529.48514270785,17278.042666196823,36.3836932182312,0.0 -455400,3.1057143,1.087848,,,,,,,,,,,,,, -455500,3.2993062,1.189529,,,,,,,,,,,,,, -455600,3.725907,3.1777012,,,,,,,,,,,,,, -455700,2.939318,1.1524589,,,,,,,,,,,,,, -455800,3.0942075,1.500215,,,,,,,,,,,,,, -455900,3.1003942,2.2057872,,,,,,,,,,,,,, -456000,3.1738405,1.1491958,,,,,,,,,,,,,, -456100,3.0649202,1.2043054,,,,,,,,,,,,,, -456200,3.6356943,3.190445,,,,,,,,,,,,,, -456300,2.9706743,1.9106865,,,,,,,,,,,,,, -456304,,,0.8844726085662842,0.427636057138443,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,202949.4070637226,220326.2607011795,202949.4070637226,17317.149385929108,36.49660277366638,0.0 -456400,3.0423899,1.0690185,,,,,,,,,,,,,, -456500,4.9602375,3.203227,,,,,,,,,,,,,, -456600,3.1956599,2.6264985,,,,,,,,,,,,,, -456700,3.3078935,2.833118,,,,,,,,,,,,,, -456800,3.6429982,3.246268,,,,,,,,,,,,,, -456900,3.3505077,3.0536134,,,,,,,,,,,,,, -457000,3.1214724,1.3687395,,,,,,,,,,,,,, -457100,3.5805078,2.6281302,,,,,,,,,,,,,, -457200,3.0261679,1.0395675,,,,,,,,,,,,,, -457242,,,0.8884375095367432,0.4146224856376648,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,203369.30669617653,220784.4556255341,203369.30669617653,17355.284680366516,36.609949350357056,0.0 -457300,3.0835805,1.2314773,,,,,,,,,,,,,, -457400,2.924645,2.1187088,,,,,,,,,,,,,, -457500,3.1136343,1.0402558,,,,,,,,,,,,,, -457600,3.285744,1.1164953,,,,,,,,,,,,,, -457700,3.2756064,2.2903342,,,,,,,,,,,,,, -457800,3.608545,2.4669065,,,,,,,,,,,,,, -457900,3.0870433,1.1662757,,,,,,,,,,,,,, -458000,3.4420357,1.178689,,,,,,,,,,,,,, -458100,3.378785,2.4912744,,,,,,,,,,,,,, -458183,,,0.88685542345047,0.4204631745815277,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,203789.2432193756,221246.88278746605,203789.2432193756,17397.613396406174,36.72454309463501,0.0 -458200,3.123512,2.0516043,,,,,,,,,,,,,, -458300,3.0804958,1.0532112,,,,,,,,,,,,,, -458400,2.9688528,1.2417511,,,,,,,,,,,,,, -458500,3.2034485,1.5533425,,,,,,,,,,,,,, -458600,3.1259153,1.2478212,,,,,,,,,,,,,, -458700,2.9388845,1.5721219,,,,,,,,,,,,,, -458800,3.4259899,1.1087389,,,,,,,,,,,,,, -458900,3.5403876,2.3694956,,,,,,,,,,,,,, -459000,3.0095181,1.0279655,,,,,,,,,,,,,, -459100,2.9570417,1.0736313,,,,,,,,,,,,,, -459129,,,0.8864452838897705,0.4202134609222412,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,204209.13581633568,221705.85871863365,204209.13581633568,17436.533225536346,36.84003710746765,0.0 -459200,3.143835,1.044555,,,,,,,,,,,,,, -459300,3.743371,3.0437782,,,,,,,,,,,,,, -459400,2.9040544,1.1616858,,,,,,,,,,,,,, -459500,3.4127626,1.1780752,,,,,,,,,,,,,, -459600,3.2924547,3.048639,,,,,,,,,,,,,, -459700,3.0860078,2.6707249,,,,,,,,,,,,,, -459800,3.267712,1.2254032,,,,,,,,,,,,,, -459900,3.0887504,1.1261659,,,,,,,,,,,,,, -460000,3.2672272,2.1652136,,,,,,,,,,,,,, -460072,,,0.8882030844688416,0.416919469833374,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,204629.39767479897,222171.1817200184,204629.39767479897,17481.431627988815,36.955846309661865,0.0 -460100,2.9779794,1.0669053,,,,,,,,,,,,,, -460200,2.9353116,1.8879161,,,,,,,,,,,,,, -460300,3.076754,1.0951668,,,,,,,,,,,,,, -460400,3.8484845,3.3583422,,,,,,,,,,,,,, -460500,3.8660257,3.2253413,,,,,,,,,,,,,, -460600,3.06009,2.156166,,,,,,,,,,,,,, -460700,3.0678077,1.6679637,,,,,,,,,,,,,, -460800,2.963323,1.2877185,,,,,,,,,,,,,, -460900,2.9016538,1.6133215,,,,,,,,,,,,,, -461000,2.9600196,2.2977533,,,,,,,,,,,,,, -461015,,,0.8881054520606995,0.4188825786113739,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,205049.5205335617,222634.9448349476,205049.5205335617,17524.9306910038,37.0502290725708,0.0 -461100,3.0149934,1.106025,,,,,,,,,,,,,, -461200,3.1435585,1.0822245,,,,,,,,,,,,,, -461300,2.9773223,1.87953,,,,,,,,,,,,,, -461400,2.9918544,2.206936,,,,,,,,,,,,,, -461500,3.0649078,1.260656,,,,,,,,,,,,,, -461600,3.2696857,1.6700672,,,,,,,,,,,,,, -461700,3.3140123,3.0075145,,,,,,,,,,,,,, -461800,3.6384747,3.1513267,,,,,,,,,,,,,, -461900,3.036285,1.1297231,,,,,,,,,,,,,, -461958,,,0.8861327767372131,0.4285134375095367,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,205469.76189637184,223092.48494267464,205469.76189637184,17562.089664697647,37.14282035827637,0.0 -462000,3.214984,1.6063116,,,,,,,,,,,,,, -462100,3.2941942,1.1487027,,,,,,,,,,,,,, -462200,2.991848,1.0486832,,,,,,,,,,,,,, -462300,2.8812892,1.109415,,,,,,,,,,,,,, -462400,3.0791168,1.2440732,,,,,,,,,,,,,, -462500,2.9227173,1.7583792,,,,,,,,,,,,,, -462600,3.1566834,2.3641052,,,,,,,,,,,,,, -462700,2.8202338,1.8429197,,,,,,,,,,,,,, -462800,3.1900408,1.1622032,,,,,,,,,,,,,, -462900,3.3854253,3.1157043,,,,,,,,,,,,,, -462901,,,0.8883007764816284,0.4149970114231109,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,205890.15830087665,223556.4319159985,205890.15830087665,17605.4763879776,37.25931525230408,0.0 -463000,3.044625,2.530011,,,,,,,,,,,,,, -463100,2.9758077,1.6601653,,,,,,,,,,,,,, -463200,3.32726,2.962846,,,,,,,,,,,,,, -463300,3.6956096,2.8494437,,,,,,,,,,,,,, -463400,3.0747972,1.1691378,,,,,,,,,,,,,, -463500,2.9825723,1.1970928,,,,,,,,,,,,,, -463600,3.0738666,1.1780534,,,,,,,,,,,,,, -463700,2.7785914,1.6715057,,,,,,,,,,,,,, -463800,3.0719333,1.8153131,,,,,,,,,,,,,, -463843,,,0.887011706829071,0.4187023341655731,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,206310.3851909637,224015.01740431783,206310.3851909637,17643.671263217926,37.37636876106262,0.0 -463900,3.392962,2.7304475,,,,,,,,,,,,,, -464000,3.1746926,1.15461,,,,,,,,,,,,,, -464100,3.8552825,2.9666982,,,,,,,,,,,,,, -464200,3.6377375,3.0032992,,,,,,,,,,,,,, -464300,3.3452206,1.5601243,,,,,,,,,,,,,, -464400,3.246714,1.2695882,,,,,,,,,,,,,, -464500,2.7807722,1.5031469,,,,,,,,,,,,,, -464600,3.6234264,1.4341806,,,,,,,,,,,,,, -464700,3.696573,1.1755068,,,,,,,,,,,,,, -464785,,,0.8866406083106995,0.4217980206012726,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,206730.5388054848,224470.98530101776,206730.5388054848,17679.339367866516,37.474828243255615,0.0 -464800,3.2167935,2.5928893,,,,,,,,,,,,,, -464900,3.8190873,3.1431837,,,,,,,,,,,,,, -465000,3.015527,2.4685388,,,,,,,,,,,,,, -465100,3.0613136,1.9672934,,,,,,,,,,,,,, -465200,3.4891949,3.008816,,,,,,,,,,,,,, -465300,3.9826906,3.14862,,,,,,,,,,,,,, -465400,3.179277,1.0969561,,,,,,,,,,,,,, -465500,3.149571,1.1925176,,,,,,,,,,,,,, -465600,3.1495063,1.2259595,,,,,,,,,,,,,, -465700,3.0152762,1.137679,,,,,,,,,,,,,, -465727,,,0.8874609470367432,0.4204729497432709,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,207150.3881704808,224930.26573109627,207150.3881704808,17718.601448774338,37.595885038375854,0.0 -465800,3.0484126,1.0493678,,,,,,,,,,,,,, -465900,3.421572,1.1511002,,,,,,,,,,,,,, -466000,3.015769,1.2592745,,,,,,,,,,,,,, -466100,3.5972111,2.8120978,,,,,,,,,,,,,, -466200,3.310637,1.2410815,,,,,,,,,,,,,, -466300,3.4280317,1.2569795,,,,,,,,,,,,,, -466400,3.4493065,2.4193141,,,,,,,,,,,,,, -466500,3.319164,1.2187191,,,,,,,,,,,,,, -466600,3.3498201,1.2211636,,,,,,,,,,,,,, -466668,,,0.8872265219688416,0.4182624816894531,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,207570.5442082882,225389.47584223747,207570.5442082882,17757.488273620605,37.715657234191895,0.0 -466700,2.9908762,1.180109,,,,,,,,,,,,,, -466800,2.926559,1.0804719,,,,,,,,,,,,,, -466900,2.9520206,1.1716914,,,,,,,,,,,,,, -467000,3.162697,1.0808614,,,,,,,,,,,,,, -467100,3.0901012,1.2395273,,,,,,,,,,,,,, -467200,3.0167768,2.1667273,,,,,,,,,,,,,, -467300,3.0632885,1.1022427,,,,,,,,,,,,,, -467400,3.2692187,2.5714104,,,,,,,,,,,,,, -467500,2.8233666,2.0241857,,,,,,,,,,,,,, -467600,3.5573459,2.73691,,,,,,,,,,,,,, -467613,,,0.88978511095047,0.4156844913959503,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,207990.6780860424,225853.3052201271,207990.6780860424,17800.95606303215,37.89576005935669,0.0 -467700,2.871078,1.1715889,,,,,,,,,,,,,, -467800,3.2363198,1.1619334,,,,,,,,,,,,,, -467900,3.1719306,2.613181,,,,,,,,,,,,,, -468000,3.2981758,1.2311574,,,,,,,,,,,,,, -468100,2.9715257,2.1188114,,,,,,,,,,,,,, -468200,2.861329,1.1581075,,,,,,,,,,,,,, -468300,3.3050525,1.1421547,,,,,,,,,,,,,, -468400,3.0240152,1.8818711,,,,,,,,,,,,,, -468500,3.5396018,2.9694395,,,,,,,,,,,,,, -468557,,,0.8885741829872131,0.4178123772144317,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,208410.72726297376,226313.6565389633,208410.72726297376,17841.113161325455,37.99327063560486,0.0 -468600,3.3294277,2.7173238,,,,,,,,,,,,,, -468700,2.964757,1.1595027,,,,,,,,,,,,,, -468800,3.1223516,1.0408632,,,,,,,,,,,,,, -468900,3.4259486,1.4081179,,,,,,,,,,,,,, -469000,3.1197298,1.0695696,,,,,,,,,,,,,, -469100,3.443508,2.821442,,,,,,,,,,,,,, -469200,3.0640209,1.3089662,,,,,,,,,,,,,, -469300,3.3522859,1.7938769,,,,,,,,,,,,,, -469400,3.0993526,2.3637476,,,,,,,,,,,,,, -469500,2.9786716,1.0051208,,,,,,,,,,,,,, -469501,,,0.8895312547683716,0.4101268053054809,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,208830.80399012569,226780.88972759247,208830.80399012569,17888.128987550735,38.08624720573425,0.0 -469600,3.5428567,1.734204,,,,,,,,,,,,,, -469700,3.191655,1.1751709,,,,,,,,,,,,,, -469800,2.9223328,1.3680702,,,,,,,,,,,,,, -469900,3.4108875,2.4943438,,,,,,,,,,,,,, -470000,3.6100667,3.153059,,,,,,,,,,,,,, -470100,3.0066414,1.2064424,,,,,,,,,,,,,, -470200,3.1677666,1.1282935,,,,,,,,,,,,,, -470300,3.2191744,2.911337,,,,,,,,,,,,,, -470400,2.959887,1.3685689,,,,,,,,,,,,,, -470446,,,0.8869921565055847,0.4194200932979584,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,209250.91027021408,227239.3307085037,209250.91027021408,17926.32308101654,38.17964243888855,0.0 -470500,3.122643,1.8023173,,,,,,,,,,,,,, -470600,3.013576,1.498439,,,,,,,,,,,,,, -470700,3.065337,2.5576286,,,,,,,,,,,,,, -470800,3.0439177,1.144872,,,,,,,,,,,,,, -470900,3.1843405,2.5838652,,,,,,,,,,,,,, -471000,2.9551666,1.1197131,,,,,,,,,,,,,, -471100,3.27697,2.4171796,,,,,,,,,,,,,, -471200,3.1190066,2.2393332,,,,,,,,,,,,,, -471300,2.961235,1.3495629,,,,,,,,,,,,,, -471383,,,0.8870898485183716,0.4167537093162536,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,209671.11908245087,227700.62208509445,209671.11908245087,17967.24115753174,38.296884059906006,0.0 -471400,3.141161,1.0635389,,,,,,,,,,,,,, -471500,3.5097923,1.1587551,,,,,,,,,,,,,, -471600,3.632951,3.1763592,,,,,,,,,,,,,, -471700,3.2350557,2.240122,,,,,,,,,,,,,, -471800,3.070647,1.0734187,,,,,,,,,,,,,, -471900,3.339801,1.5132638,,,,,,,,,,,,,, -472000,3.5201604,3.2021854,,,,,,,,,,,,,, -472100,3.7838295,2.9224052,,,,,,,,,,,,,, -472200,3.3410954,1.9932323,,,,,,,,,,,,,, -472300,3.182885,1.123054,,,,,,,,,,,,,, -472323,,,0.8866991996765137,0.418349415063858,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,210091.2394402027,228161.14769411087,210091.2394402027,18007.481026172638,38.41466212272644,0.0 -472400,3.2633746,1.7853067,,,,,,,,,,,,,, -472500,3.2506232,1.0854878,,,,,,,,,,,,,, -472600,3.5248094,1.0597247,,,,,,,,,,,,,, -472700,2.962116,1.8517789,,,,,,,,,,,,,, -472800,3.6766884,1.1620327,,,,,,,,,,,,,, -472900,3.7843635,1.3102405,,,,,,,,,,,,,, -473000,3.3355484,2.6450355,,,,,,,,,,,,,, -473100,3.4943035,3.024518,,,,,,,,,,,,,, -473200,2.7925892,1.9991074,,,,,,,,,,,,,, -473266,,,0.8862109184265137,0.419338971376419,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,210511.18789815903,228617.8493804932,210511.18789815903,18044.091156482697,38.51059150695801,0.0 -473300,3.36688,1.0149267,,,,,,,,,,,,,, -473400,3.2377465,1.4235927,,,,,,,,,,,,,, -473500,3.0311246,1.7882335,,,,,,,,,,,,,, -473600,3.0180216,1.2467427,,,,,,,,,,,,,, -473700,3.00061,1.6076875,,,,,,,,,,,,,, -473800,3.2908142,1.2346067,,,,,,,,,,,,,, -473900,3.2030263,2.2022374,,,,,,,,,,,,,, -474000,2.9786353,2.3510134,,,,,,,,,,,,,, -474100,3.2481165,1.1067168,,,,,,,,,,,,,, -474200,3.2420442,1.1686617,,,,,,,,,,,,,, -474207,,,0.8871679306030273,0.4201378226280212,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,210931.07780385017,229075.71579408649,210931.07780385017,18081.904654741287,38.62590575218201,0.0 -474300,3.834309,2.971809,,,,,,,,,,,,,, -474400,2.8734798,2.021968,,,,,,,,,,,,,, -474500,3.0900116,1.0999142,,,,,,,,,,,,,, -474600,3.1350348,1.1702,,,,,,,,,,,,,, -474700,3.2232563,1.1513169,,,,,,,,,,,,,, -474800,2.9651387,1.8668764,,,,,,,,,,,,,, -474900,3.0749197,1.1504068,,,,,,,,,,,,,, -475000,3.021024,1.1015892,,,,,,,,,,,,,, -475100,3.3419461,2.473911,,,,,,,,,,,,,, -475148,,,0.8871874809265137,0.4222822189331054,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,211351.1733222008,229536.30795049667,211351.1733222008,18122.236476182938,38.743818521499634,0.0 -475200,3.2098022,1.5671812,,,,,,,,,,,,,, -475300,2.9856467,1.2318854,,,,,,,,,,,,,, -475400,3.0731328,2.3288822,,,,,,,,,,,,,, -475500,2.9379442,1.106616,,,,,,,,,,,,,, -475600,3.9471018,3.2158322,,,,,,,,,,,,,, -475700,2.9644206,1.6200329,,,,,,,,,,,,,, -475800,3.3043287,1.6581147,,,,,,,,,,,,,, -475900,2.9952412,1.1667325,,,,,,,,,,,,,, -476000,2.8521726,1.4123635,,,,,,,,,,,,,, -476091,,,0.887499988079071,0.4217891097068786,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,211771.2523927689,229990.5814120769,211771.2523927689,18156.287446975708,38.8391637802124,0.0 -476100,3.060454,1.3433502,,,,,,,,,,,,,, -476200,3.195971,1.1144229,,,,,,,,,,,,,, -476300,2.9357667,2.1663036,,,,,,,,,,,,,, -476400,2.965598,1.2711785,,,,,,,,,,,,,, -476500,3.1558695,2.5328689,,,,,,,,,,,,,, -476600,3.361993,1.3400482,,,,,,,,,,,,,, -476700,3.0577712,2.3766487,,,,,,,,,,,,,, -476800,2.9938064,1.5631499,,,,,,,,,,,,,, -476900,2.9835114,1.4112794,,,,,,,,,,,,,, -477000,2.9304383,1.3847424,,,,,,,,,,,,,, -477034,,,0.8870702981948853,0.420778214931488,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,212191.43901252747,230450.12804293635,212191.43901252747,18195.48608493805,38.95311379432678,0.0 -477100,2.7551305,1.5548543,,,,,,,,,,,,,, -477200,4.5169024,3.282689,,,,,,,,,,,,,, -477300,3.9346585,1.9696363,,,,,,,,,,,,,, -477400,3.405755,1.0253191,,,,,,,,,,,,,, -477500,2.8291788,2.0437562,,,,,,,,,,,,,, -477600,3.0913002,1.2868685,,,,,,,,,,,,,, -477700,2.9416833,1.0601163,,,,,,,,,,,,,, -477800,3.9004984,3.1653223,,,,,,,,,,,,,, -477900,3.443629,1.2324662,,,,,,,,,,,,,, -477976,,,0.8859765529632568,0.4200300574302673,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,212611.59212899208,230903.41416811943,212611.59212899208,18228.449116706848,39.07531595230103,0.0 -478000,4.10714,3.1208177,,,,,,,,,,,,,, -478100,3.5737672,2.8885956,,,,,,,,,,,,,, -478200,3.0599303,1.1907547,,,,,,,,,,,,,, -478300,2.9252412,1.0750408,,,,,,,,,,,,,, -478400,2.8543026,2.2764447,,,,,,,,,,,,,, -478500,3.6293583,3.270597,,,,,,,,,,,,,, -478600,3.0819137,1.0951873,,,,,,,,,,,,,, -478700,3.314848,1.1754806,,,,,,,,,,,,,, -478800,3.3956828,1.7209548,,,,,,,,,,,,,, -478900,3.6827524,3.2679038,,,,,,,,,,,,,, -478919,,,0.8873242139816284,0.4137724637985229,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,213031.64986252785,231360.5252292156,213031.64986252785,18265.33827686309,39.19279146194458,0.0 -479000,3.1398187,1.1129134,,,,,,,,,,,,,, -479100,3.5859888,1.6727939,,,,,,,,,,,,,, -479200,3.1934364,1.1606194,,,,,,,,,,,,,, -479300,3.3360238,1.1069322,,,,,,,,,,,,,, -479400,3.430939,2.6555252,,,,,,,,,,,,,, -479500,3.250504,1.4762546,,,,,,,,,,,,,, -479600,3.3091943,2.4096143,,,,,,,,,,,,,, -479700,3.3393743,1.9475851,,,,,,,,,,,,,, -479800,3.0452485,1.9421895,,,,,,,,,,,,,, -479863,,,0.8858593702316284,0.4239358007907867,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,213451.8512706757,231823.8510775566,213451.8512706757,18308.297819375992,39.30975008010864,0.0 -479900,2.8353395,1.5777539,,,,,,,,,,,,,, -480000,2.8238814,1.6239495,,,,,,,,,,,,,, -480100,2.7685382,1.2931169,,,,,,,,,,,,,, -480200,3.1066642,1.2121009,,,,,,,,,,,,,, -480300,3.0845869,1.9794996,,,,,,,,,,,,,, -480400,3.374718,1.1096435,,,,,,,,,,,,,, -480500,3.432132,1.1555154,,,,,,,,,,,,,, -480600,2.9404833,1.0739207,,,,,,,,,,,,,, -480700,3.8152213,3.2976534,,,,,,,,,,,,,, -480800,2.9096713,1.9929442,,,,,,,,,,,,,, -480810,,,0.8883398175239563,0.4194519817829132,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,213871.92631745336,232286.0463590622,213871.92631745336,18350.265253067017,39.41498994827271,0.0 -480900,3.187399,2.1780298,,,,,,,,,,,,,, -481000,3.3181963,1.0848992,,,,,,,,,,,,,, -481100,3.8449104,3.141507,,,,,,,,,,,,,, -481200,3.1534295,2.2970717,,,,,,,,,,,,,, -481300,3.0731986,1.2026739,,,,,,,,,,,,,, -481400,3.66259,1.154032,,,,,,,,,,,,,, -481500,3.788009,2.9986036,,,,,,,,,,,,,, -481600,2.8556244,1.3409842,,,,,,,,,,,,,, -481700,3.0494337,1.1933668,,,,,,,,,,,,,, -481752,,,0.8857226371765137,0.4243901968002319,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,214291.8306343556,232746.0855693817,214291.8306343556,18390.25630378723,39.5118944644928,0.0 -481800,3.2675574,1.1362733,,,,,,,,,,,,,, -481900,3.174083,1.2110014,,,,,,,,,,,,,, -482000,3.9946804,3.2322183,,,,,,,,,,,,,, -482100,3.028259,1.1393093,,,,,,,,,,,,,, -482200,3.1239002,1.15413,,,,,,,,,,,,,, -482300,2.9841082,1.1320977,,,,,,,,,,,,,, -482400,3.0248244,1.3596778,,,,,,,,,,,,,, -482500,2.938131,1.6405064,,,,,,,,,,,,,, -482600,3.1618671,2.176015,,,,,,,,,,,,,, -482692,,,0.8866796493530273,0.420348048210144,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,214711.83514642715,233197.8092508316,214711.83514642715,18421.80694413185,39.63293313980103,0.0 -482700,3.319253,2.6791372,,,,,,,,,,,,,, -482800,3.6584952,3.1701179,,,,,,,,,,,,,, -482900,3.066921,1.6642879,,,,,,,,,,,,,, -483000,3.3544483,1.1606803,,,,,,,,,,,,,, -483100,3.1719031,1.4872503,,,,,,,,,,,,,, -483200,4.0600834,3.3866882,,,,,,,,,,,,,, -483300,3.238037,1.3197669,,,,,,,,,,,,,, -483400,2.9488163,1.1347278,,,,,,,,,,,,,, -483500,3.9490986,3.1894157,,,,,,,,,,,,,, -483600,3.188705,1.1513557,,,,,,,,,,,,,, -483636,,,0.8885937333106995,0.4157139956951141,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,215131.989664793,233664.60327482224,215131.989664793,18468.27669906616,39.75440168380737,0.0 -483700,3.6093805,2.928472,,,,,,,,,,,,,, -483800,3.5426056,3.1392767,,,,,,,,,,,,,, -483900,3.8777084,3.1820815,,,,,,,,,,,,,, -484000,3.1537712,1.6359365,,,,,,,,,,,,,, -484100,3.1909723,1.9627054,,,,,,,,,,,,,, -484200,3.2265053,1.527688,,,,,,,,,,,,,, -484300,3.3688948,1.963678,,,,,,,,,,,,,, -484400,3.8504868,2.8754396,,,,,,,,,,,,,, -484500,3.1282282,1.7076557,,,,,,,,,,,,,, -484577,,,0.8880664110183716,0.4205309450626373,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,215551.972063303,234124.7588033676,215551.972063303,18508.301794052124,39.85543179512024,0.0 -484600,3.01071,1.1163996,,,,,,,,,,,,,, -484700,3.2776368,2.1109638,,,,,,,,,,,,,, -484800,3.1527681,1.1651024,,,,,,,,,,,,,, -484900,3.2673707,1.1525514,,,,,,,,,,,,,, -485000,3.2466037,1.3746622,,,,,,,,,,,,,, -485100,3.0715127,2.0970163,,,,,,,,,,,,,, -485200,3.0295048,1.5309162,,,,,,,,,,,,,, -485300,3.0492082,1.2622464,,,,,,,,,,,,,, -485400,3.2057939,1.1229534,,,,,,,,,,,,,, -485500,3.4245975,1.1556948,,,,,,,,,,,,,, -485518,,,0.8874609470367432,0.4170184135437011,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,215972.1343464852,234577.79371976847,215972.1343464852,18541.02921271324,39.95377564430237,0.0 -485600,3.897976,3.299739,,,,,,,,,,,,,, -485700,2.942469,1.118743,,,,,,,,,,,,,, -485800,3.5529947,1.0534738,,,,,,,,,,,,,, -485900,3.1140268,1.1443812,,,,,,,,,,,,,, -486000,3.1283214,2.6525693,,,,,,,,,,,,,, -486100,3.0579844,1.0982705,,,,,,,,,,,,,, -486200,3.4688148,1.8500206,,,,,,,,,,,,,, -486300,3.094226,1.2967983,,,,,,,,,,,,,, -486400,3.501198,1.2582735,,,,,,,,,,,,,, -486459,,,0.8861523270606995,0.4273178875446319,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,216392.3759646416,235041.70927786827,216392.3759646416,18584.534476995468,40.07387638092041,0.0 -486500,3.158494,1.9839736,,,,,,,,,,,,,, -486600,3.3376305,1.0721298,,,,,,,,,,,,,, -486700,3.209766,2.5716496,,,,,,,,,,,,,, -486800,3.8625844,3.2485373,,,,,,,,,,,,,, -486900,3.2784278,1.258868,,,,,,,,,,,,,, -487000,2.9839647,1.1799774,,,,,,,,,,,,,, -487100,3.020685,1.8870997,,,,,,,,,,,,,, -487200,3.2948682,1.1434592,,,,,,,,,,,,,, -487300,3.5238516,2.8829558,,,,,,,,,,,,,, -487400,3.0015996,1.3434321,,,,,,,,,,,,,, -487404,,,0.8877733945846558,0.4160880446434021,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,216812.4767594337,235501.1575639248,216812.4767594337,18623.737052679066,40.17147946357727,0.0 -487500,3.3574567,2.9956875,,,,,,,,,,,,,, -487600,3.096644,2.0594847,,,,,,,,,,,,,, -487700,3.0647209,2.073326,,,,,,,,,,,,,, -487800,3.5147562,2.2372687,,,,,,,,,,,,,, -487900,2.8133337,1.0125761,,,,,,,,,,,,,, -488000,3.1822584,1.1472365,,,,,,,,,,,,,, -488100,3.1292167,2.5405624,,,,,,,,,,,,,, -488200,3.0319324,1.4507654,,,,,,,,,,,,,, -488300,3.867528,3.1386704,,,,,,,,,,,,,, -488346,,,0.8884570002555847,0.4186307489871979,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,217232.3497395516,235965.63935542107,217232.3497395516,18668.17640280724,40.29322910308838,0.0 -488400,3.4345307,1.1889323,,,,,,,,,,,,,, -488500,3.1935854,1.0398556,,,,,,,,,,,,,, -488600,3.0476875,2.0435038,,,,,,,,,,,,,, -488700,3.1262002,1.2207212,,,,,,,,,,,,,, -488800,3.4132755,2.6260936,,,,,,,,,,,,,, -488900,5.7388043,2.9619203,,,,,,,,,,,,,, -489000,3.0032158,1.095586,,,,,,,,,,,,,, -489100,2.7396886,1.5293245,,,,,,,,,,,,,, -489200,3.128317,1.1501496,,,,,,,,,,,,,, -489290,,,0.8884375095367432,0.4156650900840759,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,217652.6532354355,236423.1210463047,217652.6532354355,18705.208114624023,40.391053676605225,0.0 -489300,2.962754,1.2139065,,,,,,,,,,,,,, -489400,2.9762113,2.0754802,,,,,,,,,,,,,, -489500,3.090875,1.7178674,,,,,,,,,,,,,, -489600,3.006971,1.3720722,,,,,,,,,,,,,, -489700,3.5333712,3.2107773,,,,,,,,,,,,,, -489800,3.5606027,2.1717668,,,,,,,,,,,,,, -489900,3.9616404,3.3691325,,,,,,,,,,,,,, -490000,2.959573,1.0841154,,,,,,,,,,,,,, -490100,3.126097,1.2945365,,,,,,,,,,,,,, -490200,3.1620142,1.0970386,,,,,,,,,,,,,, -490233,,,0.8855859041213989,0.418851226568222,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,218072.6185135841,236889.67304754257,218072.6185135841,18751.618181467056,40.52058506011963,0.0 -490300,3.146057,1.1775249,,,,,,,,,,,,,, -490400,3.2571836,2.479837,,,,,,,,,,,,,, -490500,3.1013145,1.4382046,,,,,,,,,,,,,, -490600,3.1761794,2.3468208,,,,,,,,,,,,,, -490700,2.9831018,2.1430526,,,,,,,,,,,,,, -490800,3.3527837,1.1101791,,,,,,,,,,,,,, -490900,3.1111784,1.1240032,,,,,,,,,,,,,, -491000,3.1561933,2.423194,,,,,,,,,,,,,, -491100,3.3172522,2.3487136,,,,,,,,,,,,,, -491177,,,0.8892577886581421,0.4170732498168945,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,218493.1903386116,237351.99741578105,218493.1903386116,18793.20076608658,40.64298367500305,0.0 -491200,3.0400872,1.183408,,,,,,,,,,,,,, -491300,3.0778298,1.419972,,,,,,,,,,,,,, -491400,3.1895459,1.1825022,,,,,,,,,,,,,, -491500,3.1577592,2.053928,,,,,,,,,,,,,, -491600,3.419003,1.2052091,,,,,,,,,,,,,, -491700,3.3338923,1.1916984,,,,,,,,,,,,,, -491800,3.57849,2.6233828,,,,,,,,,,,,,, -491900,3.3472216,1.1765637,,,,,,,,,,,,,, -492000,3.053355,1.1355423,,,,,,,,,,,,,, -492100,3.1383648,1.3633853,,,,,,,,,,,,,, -492121,,,0.8887695074081421,0.4192838668823242,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,218913.1488711834,237808.29712486267,218913.1488711834,18829.3963367939,40.74140095710754,0.0 -492200,2.9416802,1.5866499,,,,,,,,,,,,,, -492300,3.0919454,1.2371821,,,,,,,,,,,,,, -492400,3.1322074,2.7882016,,,,,,,,,,,,,, -492500,3.2592812,1.052917,,,,,,,,,,,,,, -492600,3.139279,1.771366,,,,,,,,,,,,,, -492700,3.028051,1.1603926,,,,,,,,,,,,,, -492800,3.108079,1.2262044,,,,,,,,,,,,,, -492900,2.8181293,1.5316924,,,,,,,,,,,,,, -493000,3.7141223,3.1957057,,,,,,,,,,,,,, -493062,,,0.8889257907867432,0.4122632443904876,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,219333.3186454773,238267.8093366623,219333.3186454773,18868.54022550583,40.892817974090576,0.0 -493100,4.0383425,3.2045057,,,,,,,,,,,,,, -493200,3.780872,3.383819,,,,,,,,,,,,,, -493300,3.1073184,1.1166441,,,,,,,,,,,,,, -493400,2.8585327,1.0980449,,,,,,,,,,,,,, -493500,2.9917674,1.3765066,,,,,,,,,,,,,, -493600,3.0126042,1.087382,,,,,,,,,,,,,, -493700,3.172742,1.4856365,,,,,,,,,,,,,, -493800,3.184284,1.2813538,,,,,,,,,,,,,, -493900,3.0657992,1.8025029,,,,,,,,,,,,,, -494000,3.152144,1.1275463,,,,,,,,,,,,,, -494003,,,0.8864257335662842,0.4204651415348053,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,219753.3205792904,238725.52990865707,219753.3205792904,18906.0834171772,41.02065348625183,0.0 -494100,3.2128565,1.0695827,,,,,,,,,,,,,, -494200,3.02812,1.2140058,,,,,,,,,,,,,, -494300,3.3389611,2.9283035,,,,,,,,,,,,,, -494400,3.8615193,3.298937,,,,,,,,,,,,,, -494500,3.057954,1.1197319,,,,,,,,,,,,,, -494600,3.1429055,1.0580727,,,,,,,,,,,,,, -494700,3.265702,2.3613937,,,,,,,,,,,,,, -494800,3.54733,2.82984,,,,,,,,,,,,,, -494900,2.9003901,1.4739234,,,,,,,,,,,,,, -494941,,,0.8876562118530273,0.4156997203826904,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,220173.2574682236,239184.4129190445,220173.2574682236,18944.862055540085,41.14168167114258,0.0 -495000,3.2215333,2.375674,,,,,,,,,,,,,, -495100,3.1361098,1.2278157,,,,,,,,,,,,,, -495200,3.0053284,1.2576401,,,,,,,,,,,,,, -495300,3.337267,1.1600499,,,,,,,,,,,,,, -495400,2.8703644,1.1016779,,,,,,,,,,,,,, -495500,2.8299265,1.068743,,,,,,,,,,,,,, -495600,3.164823,1.0977962,,,,,,,,,,,,,, -495700,3.1025794,1.1087763,,,,,,,,,,,,,, -495800,2.997787,1.6253846,,,,,,,,,,,,,, -495883,,,0.8883593678474426,0.414913535118103,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,220593.34018707275,239644.45602679253,220593.34018707275,18984.65509462357,41.261394739151,0.0 -495900,3.2235653,2.8400161,,,,,,,,,,,,,, -496000,3.0325918,1.0942298,,,,,,,,,,,,,, -496100,2.9641764,1.0902746,,,,,,,,,,,,,, -496200,3.0524583,1.9454104,,,,,,,,,,,,,, -496300,2.7241871,1.1426504,,,,,,,,,,,,,, -496400,2.9994876,1.3074048,,,,,,,,,,,,,, -496500,2.9565039,1.0762277,,,,,,,,,,,,,, -496600,3.323131,1.2033174,,,,,,,,,,,,,, -496700,2.8708186,1.580168,,,,,,,,,,,,,, -496800,2.9461167,1.2191336,,,,,,,,,,,,,, -496827,,,0.8857421875,0.4213462769985199,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,221013.46103596687,240102.84364795685,221013.46103596687,19022.749748706818,41.38596391677856,0.0 -496900,3.2795355,1.1728147,,,,,,,,,,,,,, -497000,3.118084,1.141636,,,,,,,,,,,,,, -497100,3.5362291,2.9199245,,,,,,,,,,,,,, -497200,3.0904238,1.6363746,,,,,,,,,,,,,, -497300,2.9466224,1.19101,,,,,,,,,,,,,, -497400,3.7904494,1.1634943,,,,,,,,,,,,,, -497500,3.419679,1.1908132,,,,,,,,,,,,,, -497600,3.83727,3.289236,,,,,,,,,,,,,, -497700,3.6182497,3.358664,,,,,,,,,,,,,, -497773,,,0.8886523246765137,0.4136727452278137,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,221433.70231842995,240566.4491128921,221433.70231842995,19065.9390039444,41.51298332214356,0.0 -497800,5.14171,2.1695669,,,,,,,,,,,,,, -497900,3.0207777,1.4407998,,,,,,,,,,,,,, -498000,2.9801326,1.7145629,,,,,,,,,,,,,, -498100,3.2652447,1.1311696,,,,,,,,,,,,,, -498200,3.007814,1.1201072,,,,,,,,,,,,,, -498300,3.162748,1.1494974,,,,,,,,,,,,,, -498400,3.0153975,1.1923144,,,,,,,,,,,,,, -498500,3.218363,2.552783,,,,,,,,,,,,,, -498600,3.3763978,3.0078104,,,,,,,,,,,,,, -498700,3.1811857,1.1258361,,,,,,,,,,,,,, -498712,,,0.8863085508346558,0.4256725311279297,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,221853.9527220726,241029.9697060585,221853.9527220726,19109.039420604706,41.63520693778992,0.0 -498800,3.0886974,1.3047316,,,,,,,,,,,,,, -498900,2.885569,1.658084,,,,,,,,,,,,,, -499000,3.320412,1.4155706,,,,,,,,,,,,,, -499100,2.8398614,1.7449163,,,,,,,,,,,,,, -499200,3.0142004,2.682734,,,,,,,,,,,,,, -499300,3.8487887,3.191139,,,,,,,,,,,,,, -499400,3.5631423,3.0472636,,,,,,,,,,,,,, -499500,3.0291772,1.1379826,,,,,,,,,,,,,, -499600,3.45535,1.5749809,,,,,,,,,,,,,, -499657,,,0.8865820169448853,0.4248594641685486,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,222274.0769441128,241491.1147809029,222274.0769441128,19149.91508102417,41.73280024528504,0.0 -499700,3.2722795,2.8231049,,,,,,,,,,,,,, -499800,3.3145897,2.5592167,,,,,,,,,,,,,, -499900,3.0127802,2.2443588,,,,,,,,,,,,,, -500000,3.1569147,1.4445252,,,,,,,,,,,,,, -500100,3.1298077,2.5169806,,,,,,,,,,,,,, -500200,3.23906,2.0962374,,,,,,,,,,,,,, -500300,2.9935079,1.0829252,,,,,,,,,,,,,, -500400,3.1020315,1.6850318,,,,,,,,,,,,,, -500500,3.22574,1.6567737,,,,,,,,,,,,,, -500600,2.7594051,1.10089,,,,,,,,,,,,,, -500603,,,0.8859961032867432,0.4199767112731933,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,222694.147285223,241946.46710276604,222694.147285223,19185.05141544342,41.83038401603699,0.0 -500700,3.660809,2.8973055,,,,,,,,,,,,,, -500800,3.3267782,1.0623223,,,,,,,,,,,,,, -500900,3.0029516,1.0802972,,,,,,,,,,,,,, -501000,3.7418623,3.319207,,,,,,,,,,,,,, -501100,3.6907063,3.1473413,,,,,,,,,,,,,, -501200,3.243749,1.0361165,,,,,,,,,,,,,, -501300,3.3438354,1.1048656,,,,,,,,,,,,,, -501400,2.9182515,1.4842821,,,,,,,,,,,,,, -501500,3.8484216,3.0258176,,,,,,,,,,,,,, -501543,,,0.8876171708106995,0.4205630719661712,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,223114.2418487072,242406.2993843556,223114.2418487072,19224.61433959008,41.95796823501587,0.0 -501600,3.6180272,3.202408,,,,,,,,,,,,,, -501700,3.749409,3.031563,,,,,,,,,,,,,, -501800,3.1249707,2.4819465,,,,,,,,,,,,,, -501900,3.1641452,1.0527105,,,,,,,,,,,,,, -502000,3.6153831,1.1196638,,,,,,,,,,,,,, -502100,3.4374313,1.1566098,,,,,,,,,,,,,, -502200,3.7592075,1.8830833,,,,,,,,,,,,,, -502300,2.8135912,1.0218521,,,,,,,,,,,,,, -502400,2.9492505,1.0421498,,,,,,,,,,,,,, -502483,,,0.8876757621765137,0.4135380089282989,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,223534.12419629097,242867.21188783649,223534.12419629097,19265.47282481193,42.08222556114197,0.0 -502500,3.1523895,1.0387578,,,,,,,,,,,,,, -502600,3.2220259,1.2546196,,,,,,,,,,,,,, -502700,3.069433,1.6958283,,,,,,,,,,,,,, -502800,2.985024,1.6588291,,,,,,,,,,,,,, -502900,3.180142,1.1293063,,,,,,,,,,,,,, -503000,2.7663538,1.9752154,,,,,,,,,,,,,, -503100,2.8687196,1.9234613,,,,,,,,,,,,,, -503200,3.6888082,3.2726614,,,,,,,,,,,,,, -503300,3.380542,1.2006735,,,,,,,,,,,,,, -503381,,,0.8858984112739563,0.4253036081790924,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,223954.3198173046,243333.0746455193,223954.3198173046,19310.969926834103,42.20741128921509,0.0 -503400,2.9556103,1.1178615,,,,,,,,,,,,,, -503500,3.895798,3.2483685,,,,,,,,,,,,,, -503600,3.5214224,1.1892489,,,,,,,,,,,,,, -503700,3.3022764,2.5395377,,,,,,,,,,,,,, -503800,3.0995104,1.6329818,,,,,,,,,,,,,, -503900,3.481493,1.117457,,,,,,,,,,,,,, -504000,3.1382809,1.304282,,,,,,,,,,,,,, -504100,3.3035815,2.8354673,,,,,,,,,,,,,, -504200,3.4657917,1.2303631,,,,,,,,,,,,,, -504300,3.037952,1.4476986,,,,,,,,,,,,,, -504325,,,0.8860546946525574,0.4211524724960327,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,224374.5869781971,243798.6955487728,224374.5869781971,19356.176063776016,42.30738377571106,0.0 -504400,2.847134,1.9289043,,,,,,,,,,,,,, -504500,3.9424646,3.2924118,,,,,,,,,,,,,, -504600,3.2953897,1.0852953,,,,,,,,,,,,,, -504700,3.1753914,1.1483572,,,,,,,,,,,,,, -504800,3.0551226,1.0348235,,,,,,,,,,,,,, -504900,3.362893,1.1680148,,,,,,,,,,,,,, -505000,3.1370947,1.236294,,,,,,,,,,,,,, -505100,3.4870052,3.1301491,,,,,,,,,,,,,, -505200,2.9753056,1.1143787,,,,,,,,,,,,,, -505264,,,0.8863866925239563,0.4218363761901855,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,224794.48518443108,244257.38936328888,224794.48518443108,19394.80228829384,42.43042516708374,0.0 -505300,2.9022403,1.6070956,,,,,,,,,,,,,, -505400,3.193239,1.1400706,,,,,,,,,,,,,, -505500,4.5053625,2.8636487,,,,,,,,,,,,,, -505600,3.1577501,1.1170565,,,,,,,,,,,,,, -505700,4.1832705,3.2772806,,,,,,,,,,,,,, -505800,3.1021397,1.1554067,,,,,,,,,,,,,, -505900,3.5125117,1.9122813,,,,,,,,,,,,,, -506000,3.144811,1.2645104,,,,,,,,,,,,,, -506100,3.1064613,1.3566383,,,,,,,,,,,,,, -506199,,,0.8872656226158142,0.4188918173313141,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,225214.68184113505,244717.8974416256,225214.68184113505,19434.938640117645,42.55869674682617,0.0 -506200,3.4780235,3.0923254,,,,,,,,,,,,,, -506300,3.9683788,3.287833,,,,,,,,,,,,,, -506400,3.1895726,1.0409642,,,,,,,,,,,,,, -506500,3.3259964,2.7725096,,,,,,,,,,,,,, -506600,3.2430778,1.2162976,,,,,,,,,,,,,, -506700,2.7879765,1.5150626,,,,,,,,,,,,,, -506800,3.179823,1.1784978,,,,,,,,,,,,,, -506900,3.179013,1.7333066,,,,,,,,,,,,,, -507000,3.373775,3.032692,,,,,,,,,,,,,, -507100,3.3439276,3.0042353,,,,,,,,,,,,,, -507138,,,0.8887499570846558,0.4152712821960449,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,225634.7327599525,245174.6440844536,225634.7327599525,19471.48560786248,42.659632205963135,0.0 -507200,2.9845548,1.0356144,,,,,,,,,,,,,, -507300,3.0897608,1.209512,,,,,,,,,,,,,, -507400,3.2563028,1.1673352,,,,,,,,,,,,,, -507500,3.341613,1.6138564,,,,,,,,,,,,,, -507600,3.629276,2.971469,,,,,,,,,,,,,, -507700,3.220418,1.2663871,,,,,,,,,,,,,, -507800,2.9912252,1.8173977,,,,,,,,,,,,,, -507900,3.8860972,3.1767542,,,,,,,,,,,,,, -508000,3.7904353,1.2686208,,,,,,,,,,,,,, -508079,,,0.8889843821525574,0.4183682203292846,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,226054.76191449165,245633.1302809716,226054.76191449165,19509.77243232727,42.78360247612,0.0 -508100,2.9760447,1.3259987,,,,,,,,,,,,,, -508200,3.0826793,1.5478591,,,,,,,,,,,,,, -508300,3.4082942,2.7988806,,,,,,,,,,,,,, -508400,3.8107848,3.1135597,,,,,,,,,,,,,, -508500,3.9972878,3.3219213,,,,,,,,,,,,,, -508600,3.6382952,3.130799,,,,,,,,,,,,,, -508700,3.0964448,1.1496679,,,,,,,,,,,,,, -508800,2.9660106,1.5692813,,,,,,,,,,,,,, -508900,3.714628,3.2212164,,,,,,,,,,,,,, -509000,2.8948734,1.6284089,,,,,,,,,,,,,, -509023,,,0.88525390625,0.4223649799823761,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,226475.2010500431,246102.30951094627,226475.2010500431,19558.337856054302,42.91069960594177,0.0 -509100,3.0663586,1.1093705,,,,,,,,,,,,,, -509200,3.440696,2.8670948,,,,,,,,,,,,,, -509300,2.90708,1.0949868,,,,,,,,,,,,,, -509400,3.3307872,1.3277587,,,,,,,,,,,,,, -509500,3.1489222,1.0560836,,,,,,,,,,,,,, -509600,3.3413143,1.3163584,,,,,,,,,,,,,, -509700,3.4102247,2.770829,,,,,,,,,,,,,, -509800,3.8580892,3.2240224,,,,,,,,,,,,,, -509900,3.4262803,1.1648257,,,,,,,,,,,,,, -509966,,,0.8899023532867432,0.4177020490169525,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,226895.11728739736,246557.4968225956,226895.11728739736,19593.461698532104,43.01064705848694,0.0 -510000,3.3083563,2.8008938,,,,,,,,,,,,,, -510100,3.0753348,1.1674147,,,,,,,,,,,,,, -510200,3.734505,2.7730205,,,,,,,,,,,,,, -510300,3.1431465,1.0488452,,,,,,,,,,,,,, -510400,3.2976947,1.222151,,,,,,,,,,,,,, -510500,2.9222088,1.766914,,,,,,,,,,,,,, -510600,3.4979417,1.115972,,,,,,,,,,,,,, -510700,3.196259,2.5602417,,,,,,,,,,,,,, -510800,3.5725367,2.8872414,,,,,,,,,,,,,, -510900,3.0479121,1.6788952,,,,,,,,,,,,,, -510908,,,0.8851757645606995,0.4234828948974609,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,227315.35199666023,247014.7506687641,227315.35199666023,19630.303337335587,43.14096117019653,0.0 -511000,3.44966,3.0667396,,,,,,,,,,,,,, -511100,2.8403933,1.943361,,,,,,,,,,,,,, -511200,3.3160057,1.1812537,,,,,,,,,,,,,, -511300,2.8931937,2.10723,,,,,,,,,,,,,, -511400,3.4410853,3.0523033,,,,,,,,,,,,,, -511500,3.2384973,1.0662519,,,,,,,,,,,,,, -511600,3.1659389,1.1162285,,,,,,,,,,,,,, -511700,3.4559886,2.7882216,,,,,,,,,,,,,, -511800,2.776658,1.7141427,,,,,,,,,,,,,, -511850,,,0.8869335651397705,0.4181903004646301,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,227735.47498559952,247479.6405696869,227735.47498559952,19674.89552092552,43.26824116706848,0.0 -511900,3.415302,1.1330917,,,,,,,,,,,,,, -512000,3.1362655,2.7652214,,,,,,,,,,,,,, -512100,3.1591482,1.2373993,,,,,,,,,,,,,, -512200,3.003401,2.346759,,,,,,,,,,,,,, -512300,3.0248628,1.3902158,,,,,,,,,,,,,, -512400,3.66027,2.928529,,,,,,,,,,,,,, -512500,3.4322586,3.0255919,,,,,,,,,,,,,, -512600,3.1159642,1.1342925,,,,,,,,,,,,,, -512700,3.0764148,1.0753412,,,,,,,,,,,,,, -512799,,,0.8880273103713989,0.4198199510574341,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,228155.5176768303,247942.41296219823,228155.5176768303,19717.473808527,43.37141489982605,0.0 -512800,3.102396,1.2005451,,,,,,,,,,,,,, -512900,3.3708937,1.1299584,,,,,,,,,,,,,, -513000,3.7285924,3.1273887,,,,,,,,,,,,,, -513100,2.9497387,1.1034915,,,,,,,,,,,,,, -513200,2.9952264,1.6786041,,,,,,,,,,,,,, -513300,2.8781362,1.3766247,,,,,,,,,,,,,, -513400,3.0692446,1.575799,,,,,,,,,,,,,, -513500,2.8882825,1.0486366,,,,,,,,,,,,,, -513600,3.139176,1.1939952,,,,,,,,,,,,,, -513700,3.3251104,2.8903599,,,,,,,,,,,,,, -513744,,,0.8875195384025574,0.4158684015274048,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,228575.4135878086,248401.4799234867,228575.4135878086,19756.49240231514,43.47539925575256,0.0 -513800,3.1954,1.0388038,,,,,,,,,,,,,, -513900,3.2129765,1.386763,,,,,,,,,,,,,, -514000,3.0686526,1.0415289,,,,,,,,,,,,,, -514100,3.0131593,1.8909335,,,,,,,,,,,,,, -514200,3.7978897,3.0656776,,,,,,,,,,,,,, -514300,2.8004122,1.6579223,,,,,,,,,,,,,, -514400,2.8195047,1.0888522,,,,,,,,,,,,,, -514500,3.131248,1.248768,,,,,,,,,,,,,, -514600,3.0585465,1.1916739,,,,,,,,,,,,,, -514686,,,0.8874218463897705,0.4217660427093506,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,228995.6517190933,248865.8999125957,228995.6517190933,19800.474741220474,43.62744307518005,0.0 -514700,3.028159,1.7282959,,,,,,,,,,,,,, -514800,3.565663,2.3065019,,,,,,,,,,,,,, -514900,3.4056184,1.1991799,,,,,,,,,,,,,, -515000,2.9727561,1.7735369,,,,,,,,,,,,,, -515100,2.9268317,1.0461699,,,,,,,,,,,,,, -515200,3.7836654,3.2296286,,,,,,,,,,,,,, -515300,3.003331,1.0304474,,,,,,,,,,,,,, -515400,2.9758904,1.2063967,,,,,,,,,,,,,, -515500,3.62757,3.0180569,,,,,,,,,,,,,, -515600,3.237955,1.1332798,,,,,,,,,,,,,, -515625,,,0.8893554210662842,0.4156787693500519,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,229415.50728487968,249319.9051039219,229415.50728487968,19834.475734710693,43.72905158996582,0.0 -515700,3.0456147,1.1376302,,,,,,,,,,,,,, -515800,3.6440275,2.996581,,,,,,,,,,,,,, -515900,2.8766513,1.1682873,,,,,,,,,,,,,, -516000,3.2348764,1.3469803,,,,,,,,,,,,,, -516100,3.1259675,1.1417966,,,,,,,,,,,,,, -516200,2.8866346,1.1771387,,,,,,,,,,,,,, -516300,3.1142523,1.0340048,,,,,,,,,,,,,, -516400,2.8758001,1.7385755,,,,,,,,,,,,,, -516500,3.3873641,1.3395907,,,,,,,,,,,,,, -516567,,,0.8910155892372131,0.4081867933273315,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,229835.58886671063,249778.6609852314,229835.58886671063,19872.9740600586,43.85731816291809,0.0 -516600,3.1150408,1.1617814,,,,,,,,,,,,,, -516700,3.2480783,1.0685809,,,,,,,,,,,,,, -516800,3.188034,1.2060078,,,,,,,,,,,,,, -516900,3.266432,2.3146129,,,,,,,,,,,,,, -517000,3.4100337,3.140187,,,,,,,,,,,,,, -517100,3.009782,1.7399713,,,,,,,,,,,,,, -517200,3.1368635,1.1588057,,,,,,,,,,,,,, -517300,3.0698395,1.1173863,,,,,,,,,,,,,, -517400,4.1905737,3.279366,,,,,,,,,,,,,, -517500,3.117842,2.530925,,,,,,,,,,,,,, -517509,,,0.8847851157188416,0.4224032461643219,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,230255.75998997688,250242.39598941803,230255.75998997688,19916.36367583275,43.98423504829407,0.0 -517600,2.8386364,2.0437143,,,,,,,,,,,,,, -517700,3.18505,1.1005656,,,,,,,,,,,,,, -517800,3.2934492,1.0529752,,,,,,,,,,,,,, -517900,2.9920754,1.0892956,,,,,,,,,,,,,, -518000,3.2176282,1.0863686,,,,,,,,,,,,,, -518100,3.926822,3.1664138,,,,,,,,,,,,,, -518200,3.088356,1.082847,,,,,,,,,,,,,, -518300,3.572976,3.125886,,,,,,,,,,,,,, -518400,3.0510716,1.1800351,,,,,,,,,,,,,, -518454,,,0.8889062404632568,0.4124764800071716,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,230676.0453734398,250698.45132374763,230676.0453734398,19951.98074555397,44.08957409858704,0.0 -518500,3.5267868,3.1387906,,,,,,,,,,,,,, -518600,3.061735,1.1793728,,,,,,,,,,,,,, -518700,3.223344,1.1835257,,,,,,,,,,,,,, -518800,3.0389802,2.4534824,,,,,,,,,,,,,, -518900,2.925258,1.0267668,,,,,,,,,,,,,, -519000,3.6362932,3.0796502,,,,,,,,,,,,,, -519100,3.418342,2.9164531,,,,,,,,,,,,,, -519200,3.0206604,1.080596,,,,,,,,,,,,,, -519300,3.2828937,1.1470367,,,,,,,,,,,,,, -519396,,,0.8869726657867432,0.4183215200901031,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,231096.0907754898,251160.64197206497,231096.0907754898,19993.95086407661,44.21716904640198,0.0 -519400,3.3460584,1.2231447,,,,,,,,,,,,,, -519500,3.8829434,3.3830395,,,,,,,,,,,,,, -519600,2.8614862,1.8799536,,,,,,,,,,,,,, -519700,3.9983666,3.131715,,,,,,,,,,,,,, -519800,3.0666702,1.1222216,,,,,,,,,,,,,, -519900,3.289721,2.810081,,,,,,,,,,,,,, -520000,2.9835176,1.639486,,,,,,,,,,,,,, -520100,3.3755074,1.5211879,,,,,,,,,,,,,, -520200,3.2629838,1.2005273,,,,,,,,,,,,,, -520300,2.9785967,1.8286582,,,,,,,,,,,,,, -520338,,,0.8862890601158142,0.4208606481552124,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,231516.16703391075,251622.5686440468,231516.16703391075,20035.627217292786,44.343384742736816,0.0 -520400,2.7891374,1.5539901,,,,,,,,,,,,,, -520500,4.0710006,3.088256,,,,,,,,,,,,,, -520600,3.0245662,1.9133039,,,,,,,,,,,,,, -520700,4.393525,3.2738967,,,,,,,,,,,,,, -520800,4.228197,3.2143831,,,,,,,,,,,,,, -520900,3.0562618,1.1518598,,,,,,,,,,,,,, -521000,3.1704774,1.1329595,,,,,,,,,,,,,, -521100,4.0057616,3.3950164,,,,,,,,,,,,,, -521200,3.0507624,1.6608772,,,,,,,,,,,,,, -521278,,,0.8882226347923279,0.4174492061138153,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,231936.3868584633,252086.10162854195,231936.3868584633,20078.789390325543,44.4472975730896,0.0 -521300,3.449647,2.5092468,,,,,,,,,,,,,, -521400,3.3945265,2.5766306,,,,,,,,,,,,,, -521500,4.0553775,3.0681593,,,,,,,,,,,,,, -521600,3.051878,1.5843681,,,,,,,,,,,,,, -521700,3.2931535,2.2023356,,,,,,,,,,,,,, -521800,3.2343254,2.5905623,,,,,,,,,,,,,, -521900,3.5730324,3.0594285,,,,,,,,,,,,,, -522000,3.160044,1.1245382,,,,,,,,,,,,,, -522100,3.117535,1.1695733,,,,,,,,,,,,,, -522200,2.9937606,1.0847389,,,,,,,,,,,,,, -522219,,,0.8867968320846558,0.4223299324512481,0.7803199887275696,0.8548135757446289,50000.0,0.6648000478744507,1.4506032466888428,10000.0,232356.45809936523,252546.62259054184,232356.45809936523,20119.08979439736,44.54976439476013,0.0 -522300,3.1408012,2.3429255,,,,,,,,,,,,,, -522400,4.00904,2.892801,,,,,,,,,,,,,, -522500,3.0653822,1.114464,,,,,,,,,,,,,, -522600,4.59593,3.0881882,,,,,,,,,,,,,, -522682,,,,,,,,,,,232560.16098117828,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index b2996e395..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,42 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -173.14654231071472,0.0,61.19409894943237,1,0,61.19409894943237,30.174337,2472,1.4389332358377511,234.3406932353973,31.231562,1.1201143202215604,30.051725,5348,1.4357627658650087 -296.4638662338257,0.0394837856292724,1502.108412981033,1791,0,1502.108412981033,3.14853,2472,0.5956980074340381,1798.6836760044098,3.0233684,0.6003480875780107,3.4465265,5348,0.6467941724514129 -424.61270332336426,0.090696096420288,2942.371209859848,3610,0,2942.371209859848,0.63590884,2472,0.2060000406231592,3367.225006580353,0.59163195,0.1998401832565325,0.9076137,5348,0.265068499763461 -554.705050945282,0.1443111896514892,4382.683770656586,5415,0,4382.683770656586,0.48688084,2472,0.161436434911543,4937.763836860657,0.4084525,0.146960376572787,0.73416996,5348,0.2211687922994487 -684.926840543747,0.1929576396942138,5823.219467878342,7194,0,5823.219467878342,0.42035604,2472,0.1408811163244165,6508.647459983826,0.36481315,0.1303396202703976,0.65578574,5348,0.1987699972001506 -815.4455087184906,0.2454867362976074,7263.481862545013,8953,0,7263.481862545013,0.3899724,2472,0.1313143623179574,8079.55590224266,0.33626315,0.1227793602818065,0.62236243,5348,0.1883815905075451 -945.2633166313173,0.299727201461792,8704.104737520218,10732,0,8704.104737520218,0.36532426,2472,0.12373814311539,9650.128216028214,0.3294852,0.1173426655499303,0.59158397,5348,0.178553153692422 -1075.877376317978,0.3549344539642334,10144.391798257828,12489,0,10144.391798257828,0.35092798,2472,0.1189242987427132,11221.16163301468,0.29332313,0.1067789760279071,0.56526625,5348,0.1701149869179451 -1207.8728342056274,0.4063010215759277,11585.176102399826,14244,0,11585.176102399826,0.33576742,2472,0.1140901427904048,12794.070428848268,0.2573273,0.0954731012246808,0.5519761,5348,0.1652007685103835 -1340.1641960144043,0.4635465145111084,13025.042648553848,16014,0,13025.042648553848,0.3204342,2472,0.1075295025694148,14366.364121437073,0.23602745,0.0879293926126116,0.528238,5348,0.1569267308379273 -1471.3020498752594,0.5215849876403809,14465.142292261124,17763,0,14465.142292261124,0.3125017,2472,0.1053561635488392,15937.739381551744,0.23291294,0.0894099223991376,0.5173014,5348,0.1559129922666229 -1602.643741607666,0.571497917175293,15905.654467105864,19499,0,15905.654467105864,0.3041531,2472,0.1017610139540552,17509.719973564148,0.24477959,0.0886326978502868,0.5056676,5348,0.1514815065120635 -1734.2814166545868,0.6246070861816406,17346.12431693077,21243,0,17346.12431693077,0.2986674,2472,0.099790790729795,19081.95854330063,0.24407542,0.0899781043079114,0.4981141,5348,0.1499174527163366 -1867.4430181980133,0.6834304332733154,18786.23919892311,23011,0,18786.23919892311,0.28946954,2472,0.0951394389941705,20655.370346546173,0.22209774,0.0820917789896079,0.48654518,5348,0.1449259970842947 -1998.485382080078,0.7370121479034424,20226.15410375595,24741,0,20226.15410375595,0.28323427,2472,0.0961347063961164,22226.457488298416,0.20882161,0.0786548287656771,0.4894069,5348,0.1448391052067544 -2129.420209884644,0.7897412776947021,21666.161897182465,26512,0,21666.161897182465,0.2768741,2472,0.0906505798955984,23797.530810832977,0.1967695,0.0746498343723101,0.4649453,5348,0.1382835957789857 -2261.6781487464905,0.8389241695404053,23106.45556807518,28310,0,23106.45556807518,0.26881522,2472,0.0909958767493347,25370.210528612137,0.1956766,0.0728948256640517,0.46214008,5348,0.1379843015341243 -2391.651878118515,0.8914635181427002,24546.814392089844,30074,0,24546.814392089844,0.26378438,2472,0.0888225377287591,26940.67303466797,0.20220742,0.0759609551048708,0.4574062,5348,0.1352037614528322 -2521.70818734169,0.942244291305542,25987.314188480377,31824,0,25987.314188480377,0.2555178,2472,0.0871163650397091,28511.35634493828,0.18969351,0.0693004302790574,0.44134903,5348,0.1326163144327408 -2660.7630712985992,0.9998691082000732,27427.76982665062,33573,0,27427.76982665062,0.24836451,2472,0.0830540491133995,30091.00088500977,0.19449702,0.0708998351047162,0.43590006,5348,0.1275669308823387 -2792.369478940964,1.0533790588378906,28868.35710072517,35337,0,28868.35710072517,0.24490088,2472,0.0808807100928239,31663.327996253967,0.1883224,0.0683738823704935,0.4340027,5348,0.1286096334128233 -2924.0652084350586,1.1087830066680908,30308.52282428741,37063,0,30308.52282428741,0.23621194,2472,0.0795198342575102,33235.32124829292,0.14173692,0.0543289976732982,0.41600996,5348,0.1218031030054935 -3055.453970432281,1.1656112670898438,31749.41622185707,38807,0,31749.41622185707,0.23410633,2472,0.0759449962423577,34807.73562335968,0.16449861,0.0600990475385947,0.410855,5348,0.1203935236587273 -3185.8636391162872,1.2198340892791748,33189.860087394714,40584,0,33189.860087394714,0.23114812,2472,0.0747059898848333,36378.72210025787,0.20285141,0.0741797432239657,0.40358555,5348,0.1180184790059569 -3316.1265754699707,1.2786462306976318,34629.90348124504,42318,0,34629.90348124504,0.22265244,2472,0.0721670424308898,37949.16435742378,0.20199291,0.0740769525416766,0.39573586,5348,0.1171109416183129 -3445.9241964817047,1.341212272644043,36070.11869764328,44053,0,36070.11869764328,0.21891987,2472,0.072248288749416,39519.31596660614,0.22752737,0.08577896590251,0.38816926,5348,0.1146393504349421 -3576.526694059372,1.396867036819458,37509.99157762528,45799,0,37509.99157762528,0.21051346,2472,0.0695468486584201,41089.9249560833,0.2029801,0.0721935366119317,0.3828147,5348,0.1124284348841924 -3706.613924980164,1.4564719200134275,38950.62455177307,47560,0,38950.62455177307,0.21286468,2472,0.0688359433713159,42660.78283596039,0.18338831,0.068447774821017,0.37283552,5348,0.1085569190071154 -3837.757830858231,1.5125069618225098,40390.96273756027,49288,0,40390.96273756027,0.20216686,2472,0.0661751264395832,44232.39781737328,0.15685317,0.0596471957280666,0.3723242,5348,0.1092520540274385 -3968.502446889877,1.5761635303497314,41831.39310407639,51047,0,41831.39310407639,0.20247392,2472,0.0669063433063189,45803.714185237885,0.17114201,0.0648845633720203,0.3645804,5348,0.106403931374726 -4101.162638664246,1.6361210346221924,43271.88686680794,52804,0,43271.88686680794,0.19651937,2472,0.0634533747689557,47377.0074198246,0.14830387,0.0573462463494245,0.35951543,5348,0.1038647576199349 -4233.571403264999,1.7039833068847656,44711.75618267059,54528,0,44711.75618267059,0.1865691,2472,0.0607519346779599,48949.43096876144,0.14442445,0.0558153158528515,0.35094666,5348,0.1008235419060216 -4364.2648112773895,1.765101194381714,46151.96347117424,56266,0,46151.96347117424,0.18515553,2472,0.0613003473280117,50520.468977451324,0.1373797,0.0521849551414768,0.34084773,5348,0.0978981820288288 -4493.534925699234,1.8188579082489007,47592.51656937599,57996,0,47592.51656937599,0.18242605,2472,0.0595738630593301,52090.4231069088,0.13606107,0.051268843789705,0.33888745,5348,0.0969134074167044 -4624.469212293625,1.8802030086517327,49032.50491023064,59699,0,49032.50491023064,0.17691371,2472,0.0585379724981211,53661.48346114159,0.13671345,0.0510678617715287,0.33184198,5348,0.0946928372128947 -4755.3815841674805,1.945399284362793,50472.42002773285,61411,0,50472.42002773285,0.17471576,2472,0.0569333577072288,55232.452363967896,0.11137697,0.0431546618243773,0.32254496,5348,0.0926653600702858 -4885.515163421631,2.0081372261047363,51912.341347932816,63155,0,51912.341347932816,0.17560743,2472,0.0565271261145979,56802.64616537094,0.11432057,0.0435226598512162,0.3218405,5348,0.0907054654990972 -5018.692674398422,2.0654304027557373,53352.50097608566,64852,0,53352.50097608566,0.16428222,2472,0.0530741575772347,58376.11468625069,0.104371265,0.04056934339151,0.3143445,5348,0.0894117419890516 -5148.585556030273,2.1266865730285645,54792.7521352768,66559,0,54792.7521352768,0.16560408,2472,0.0542522291958645,59946.395359277725,0.10364472,0.0397125890330625,0.3120098,5348,0.0878766521525049 -5278.0486397743225,2.187870979309082,56232.73989892006,68288,0,56232.73989892006,0.161813,2472,0.0521804480734466,61515.9838206768,0.09611028,0.0371413877149437,0.30804402,5348,0.0861581239078173 -5409.477321624756,2.2469236850738525,57673.28281927109,69983,0,57673.28281927109,0.15796709,2472,0.05222107123270977,63088.09074831009,0.092972785,0.03547037401901641,0.30447015,5348,0.08587813896907615 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index eca293336..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,743 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,57.549614,32.68696,,,,,,,,,,,,,, -1,,,31.231562,1.1201143202215604,30.051725,1.4357627658650087,5348.0,30.174337,1.4389332358377511,2472.0,61.19409894943237,234.3406932353973,61.19409894943237,173.14654231071472,0.0,0.0 -100,0.5291343,5.9551167,,,,,,,,,,,,,, -200,0.7200822,5.819937,,,,,,,,,,,,,, -300,1.0256727,5.82024,,,,,,,,,,,,,, -400,1.5901171,5.7644134,,,,,,,,,,,,,, -500,0.45581523,5.760982,,,,,,,,,,,,,, -600,1.9840364,5.784789,,,,,,,,,,,,,, -700,2.5852633,5.5646954,,,,,,,,,,,,,, -800,1.991635,5.318555,,,,,,,,,,,,,, -900,0.9209955,4.1451774,,,,,,,,,,,,,, -1000,0.9893736,3.5109003,,,,,,,,,,,,,, -1100,1.5917709,3.233417,,,,,,,,,,,,,, -1200,0.6814035,3.0025542,,,,,,,,,,,,,, -1300,0.7630158,2.86026,,,,,,,,,,,,,, -1400,1.1338953,2.8161988,,,,,,,,,,,,,, -1500,1.2050257,2.6688805,,,,,,,,,,,,,, -1600,0.71932185,2.5635848,,,,,,,,,,,,,, -1700,0.7436488,2.4778073,,,,,,,,,,,,,, -1791,,,3.0233684,0.6003480875780107,3.4465265,0.6467941724514129,5348.0,3.14853,0.5956980074340381,2472.0,1502.108412981033,1798.6836760044098,1502.108412981033,296.4638662338257,0.0394837856292724,0.0 -1800,1.1641936,2.3983424,,,,,,,,,,,,,, -1900,0.6861397,2.3067358,,,,,,,,,,,,,, -2000,0.747141,2.275282,,,,,,,,,,,,,, -2100,0.77769613,2.181769,,,,,,,,,,,,,, -2200,0.8679224,2.194666,,,,,,,,,,,,,, -2300,0.738713,2.0995667,,,,,,,,,,,,,, -2400,1.2712567,2.099398,,,,,,,,,,,,,, -2500,0.70654815,1.9705663,,,,,,,,,,,,,, -2600,0.9189004,2.0454438,,,,,,,,,,,,,, -2700,0.9979005,2.0588355,,,,,,,,,,,,,, -2800,0.65578705,2.0021374,,,,,,,,,,,,,, -2900,0.89222854,1.9855012,,,,,,,,,,,,,, -3000,0.61468107,1.9590611,,,,,,,,,,,,,, -3100,0.5891937,1.8973802,,,,,,,,,,,,,, -3200,0.6134435,1.9280387,,,,,,,,,,,,,, -3300,0.6003213,1.8847165,,,,,,,,,,,,,, -3400,0.5985744,1.9183903,,,,,,,,,,,,,, -3500,0.68554527,1.8326272,,,,,,,,,,,,,, -3600,0.7240346,1.8358896,,,,,,,,,,,,,, -3610,,,0.59163195,0.1998401832565325,0.9076137,0.265068499763461,5348.0,0.63590884,0.2060000406231592,2472.0,2942.371209859848,3367.225006580353,2942.371209859848,424.61270332336426,0.090696096420288,0.0 -3700,0.6143134,1.7831413,,,,,,,,,,,,,, -3800,0.5506622,1.8442405,,,,,,,,,,,,,, -3900,0.63076353,1.800431,,,,,,,,,,,,,, -4000,0.62629527,1.777475,,,,,,,,,,,,,, -4100,0.7607052,1.8164601,,,,,,,,,,,,,, -4200,0.5742911,1.8035163,,,,,,,,,,,,,, -4300,0.5571632,1.7354863,,,,,,,,,,,,,, -4400,0.4948258,1.7041279,,,,,,,,,,,,,, -4500,0.66219074,1.7764097,,,,,,,,,,,,,, -4600,0.62353784,1.8296325,,,,,,,,,,,,,, -4700,0.53011817,1.7385819,,,,,,,,,,,,,, -4800,0.58308,1.7587624,,,,,,,,,,,,,, -4900,0.49954584,1.700293,,,,,,,,,,,,,, -5000,0.58714116,1.7267019,,,,,,,,,,,,,, -5100,0.56932855,1.6764729,,,,,,,,,,,,,, -5200,0.43270877,1.6977955,,,,,,,,,,,,,, -5300,0.5265885,1.6699892,,,,,,,,,,,,,, -5400,0.56635696,1.632443,,,,,,,,,,,,,, -5415,,,0.4084525,0.146960376572787,0.73416996,0.2211687922994487,5348.0,0.48688084,0.161436434911543,2472.0,4382.683770656586,4937.763836860657,4382.683770656586,554.705050945282,0.1443111896514892,0.0 -5500,0.5809053,1.6762621,,,,,,,,,,,,,, -5600,0.6359267,1.6750693,,,,,,,,,,,,,, -5700,0.4216226,1.649662,,,,,,,,,,,,,, -5800,0.5887393,1.6627089,,,,,,,,,,,,,, -5900,0.60868645,1.653303,,,,,,,,,,,,,, -6000,0.48234507,1.606213,,,,,,,,,,,,,, -6100,0.5381482,1.6504363,,,,,,,,,,,,,, -6200,0.53927505,1.6448147,,,,,,,,,,,,,, -6300,0.49513915,1.6425595,,,,,,,,,,,,,, -6400,0.5168737,1.6723715,,,,,,,,,,,,,, -6500,0.65983456,1.6299622,,,,,,,,,,,,,, -6600,0.50261915,1.5619805,,,,,,,,,,,,,, -6700,0.46532702,1.5932176,,,,,,,,,,,,,, -6800,0.547731,1.652613,,,,,,,,,,,,,, -6900,0.63365436,1.6432676,,,,,,,,,,,,,, -7000,0.45675322,1.6106875,,,,,,,,,,,,,, -7100,0.5611325,1.5296451,,,,,,,,,,,,,, -7194,,,0.36481315,0.1303396202703976,0.65578574,0.1987699972001506,5348.0,0.42035604,0.1408811163244165,2472.0,5823.219467878342,6508.647459983826,5823.219467878342,684.926840543747,0.1929576396942138,0.0 -7200,0.65926826,1.5715326,,,,,,,,,,,,,, -7300,0.5469623,1.577925,,,,,,,,,,,,,, -7400,0.79801756,1.606821,,,,,,,,,,,,,, -7500,0.49380538,1.6034577,,,,,,,,,,,,,, -7600,0.60982674,1.5952252,,,,,,,,,,,,,, -7700,0.6435374,1.5653499,,,,,,,,,,,,,, -7800,0.5701338,1.6465791,,,,,,,,,,,,,, -7900,0.4884371,1.6288203,,,,,,,,,,,,,, -8000,0.5678629,1.5729731,,,,,,,,,,,,,, -8100,0.46460205,1.6000307,,,,,,,,,,,,,, -8200,0.4732476,1.6335396,,,,,,,,,,,,,, -8300,0.5377221,1.5797234,,,,,,,,,,,,,, -8400,0.5264165,1.5602561,,,,,,,,,,,,,, -8500,0.4989956,1.6092174,,,,,,,,,,,,,, -8600,0.5893933,1.4576327,,,,,,,,,,,,,, -8700,0.42680418,1.5368919,,,,,,,,,,,,,, -8800,0.6213263,1.6012261,,,,,,,,,,,,,, -8900,0.47643188,1.5700397,,,,,,,,,,,,,, -8953,,,0.33626315,0.1227793602818065,0.62236243,0.1883815905075451,5348.0,0.3899724,0.1313143623179574,2472.0,7263.481862545013,8079.55590224266,7263.481862545013,815.4455087184906,0.2454867362976074,0.0 -9000,0.5096922,1.5268081,,,,,,,,,,,,,, -9100,0.49137852,1.4949733,,,,,,,,,,,,,, -9200,0.4803939,1.5901893,,,,,,,,,,,,,, -9300,0.5266543,1.539298,,,,,,,,,,,,,, -9400,0.4944701,1.5140427,,,,,,,,,,,,,, -9500,0.4934444,1.4915621,,,,,,,,,,,,,, -9600,0.5426326,1.5849317,,,,,,,,,,,,,, -9700,0.5295103,1.4609107,,,,,,,,,,,,,, -9800,0.6239318,1.5212693,,,,,,,,,,,,,, -9900,0.4267903,1.5198007,,,,,,,,,,,,,, -10000,0.5784571,1.5667603,,,,,,,,,,,,,, -10100,0.62454414,1.5277113,,,,,,,,,,,,,, -10200,0.45003256,1.5323304,,,,,,,,,,,,,, -10300,0.4465586,1.451582,,,,,,,,,,,,,, -10400,0.43135858,1.4665661,,,,,,,,,,,,,, -10500,0.38681456,1.4285275,,,,,,,,,,,,,, -10600,0.5990269,1.4669166,,,,,,,,,,,,,, -10700,0.65500945,1.490895,,,,,,,,,,,,,, -10732,,,0.3294852,0.1173426655499303,0.59158397,0.178553153692422,5348.0,0.36532426,0.12373814311539,2472.0,8704.104737520218,9650.128216028214,8704.104737520218,945.2633166313173,0.299727201461792,0.0 -10800,0.4931559,1.4838082,,,,,,,,,,,,,, -10900,0.52358943,1.5552138,,,,,,,,,,,,,, -11000,0.49090198,1.4857899,,,,,,,,,,,,,, -11100,0.5849099,1.5303012,,,,,,,,,,,,,, -11200,0.45476612,1.4933014,,,,,,,,,,,,,, -11300,0.5642379,1.4677582,,,,,,,,,,,,,, -11400,0.6463124,1.467438,,,,,,,,,,,,,, -11500,0.4857045,1.4962823,,,,,,,,,,,,,, -11600,0.46468857,1.4665275,,,,,,,,,,,,,, -11700,0.52537274,1.467051,,,,,,,,,,,,,, -11800,0.45850584,1.4465959,,,,,,,,,,,,,, -11900,0.4719053,1.4413074,,,,,,,,,,,,,, -12000,0.44203445,1.4439723,,,,,,,,,,,,,, -12100,0.5242597,1.4554958,,,,,,,,,,,,,, -12200,0.63888955,1.4627662,,,,,,,,,,,,,, -12300,0.44319052,1.4704391,,,,,,,,,,,,,, -12400,0.5068385,1.4497316,,,,,,,,,,,,,, -12489,,,0.29332313,0.1067789760279071,0.56526625,0.1701149869179451,5348.0,0.35092798,0.1189242987427132,2472.0,10144.391798257828,11221.16163301468,10144.391798257828,1075.877376317978,0.3549344539642334,0.0 -12500,0.51994073,1.4498748,,,,,,,,,,,,,, -12600,0.51595724,1.4534954,,,,,,,,,,,,,, -12700,0.5756588,1.3922946,,,,,,,,,,,,,, -12800,0.38245422,1.4431328,,,,,,,,,,,,,, -12900,0.52300584,1.4131413,,,,,,,,,,,,,, -13000,0.52874666,1.447836,,,,,,,,,,,,,, -13100,0.5041785,1.3957671,,,,,,,,,,,,,, -13200,0.51923823,1.461888,,,,,,,,,,,,,, -13300,0.42595556,1.4574376,,,,,,,,,,,,,, -13400,0.5417939,1.384817,,,,,,,,,,,,,, -13500,0.49784732,1.45403,,,,,,,,,,,,,, -13600,0.5959624,1.4625822,,,,,,,,,,,,,, -13700,0.59320277,1.5070835,,,,,,,,,,,,,, -13800,0.47576708,1.4366684,,,,,,,,,,,,,, -13900,0.46139413,1.3997447,,,,,,,,,,,,,, -14000,0.51107395,1.45203,,,,,,,,,,,,,, -14100,0.62596625,1.4062965,,,,,,,,,,,,,, -14200,0.42538565,1.3961436,,,,,,,,,,,,,, -14244,,,0.2573273,0.0954731012246808,0.5519761,0.1652007685103835,5348.0,0.33576742,0.1140901427904048,2472.0,11585.176102399826,12794.070428848268,11585.176102399826,1207.8728342056274,0.4063010215759277,0.0 -14300,0.5731492,1.4253376,,,,,,,,,,,,,, -14400,0.5611746,1.4160967,,,,,,,,,,,,,, -14500,0.5958022,1.4355963,,,,,,,,,,,,,, -14600,0.50347394,1.425972,,,,,,,,,,,,,, -14700,0.47203428,1.4048046,,,,,,,,,,,,,, -14800,0.5152795,1.385948,,,,,,,,,,,,,, -14900,0.5609923,1.4185989,,,,,,,,,,,,,, -15000,0.5800141,1.4125701,,,,,,,,,,,,,, -15100,0.42092666,1.4247755,,,,,,,,,,,,,, -15200,0.4864524,1.3957692,,,,,,,,,,,,,, -15300,0.4792912,1.4083014,,,,,,,,,,,,,, -15400,0.53367585,1.4352571,,,,,,,,,,,,,, -15500,0.57541275,1.4040958,,,,,,,,,,,,,, -15600,0.5493712,1.3998374,,,,,,,,,,,,,, -15700,0.50337,1.4057325,,,,,,,,,,,,,, -15800,0.48603722,1.4275968,,,,,,,,,,,,,, -15900,0.48729602,1.3188618,,,,,,,,,,,,,, -16000,0.5105705,1.3769097,,,,,,,,,,,,,, -16014,,,0.23602745,0.0879293926126116,0.528238,0.1569267308379273,5348.0,0.3204342,0.1075295025694148,2472.0,13025.042648553848,14366.364121437073,13025.042648553848,1340.1641960144043,0.4635465145111084,0.0 -16100,0.4408184,1.453906,,,,,,,,,,,,,, -16200,0.4657649,1.4010211,,,,,,,,,,,,,, -16300,0.4656711,1.3436983,,,,,,,,,,,,,, -16400,0.5104836,1.3817477,,,,,,,,,,,,,, -16500,0.39592347,1.3421775,,,,,,,,,,,,,, -16600,0.5381342,1.4099092,,,,,,,,,,,,,, -16700,0.57622707,1.3671819,,,,,,,,,,,,,, -16800,0.5575615,1.3891089,,,,,,,,,,,,,, -16900,0.5515418,1.342441,,,,,,,,,,,,,, -17000,0.4799947,1.4001472,,,,,,,,,,,,,, -17100,0.4968511,1.3801484,,,,,,,,,,,,,, -17200,0.5823726,1.3711807,,,,,,,,,,,,,, -17300,0.41392034,1.384476,,,,,,,,,,,,,, -17400,0.44849524,1.4178839,,,,,,,,,,,,,, -17500,0.60487765,1.3531958,,,,,,,,,,,,,, -17600,0.52113414,1.3609545,,,,,,,,,,,,,, -17700,0.67095536,1.4053916,,,,,,,,,,,,,, -17763,,,0.23291294,0.0894099223991376,0.5173014,0.1559129922666229,5348.0,0.3125017,0.1053561635488392,2472.0,14465.142292261124,15937.739381551744,14465.142292261124,1471.3020498752594,0.5215849876403809,0.0 -17800,0.4984604,1.4089247,,,,,,,,,,,,,, -17900,0.5468965,1.4429762,,,,,,,,,,,,,, -18000,0.49270174,1.3911353,,,,,,,,,,,,,, -18100,0.48818883,1.4031577,,,,,,,,,,,,,, -18200,0.45365617,1.3801947,,,,,,,,,,,,,, -18300,0.5309122,1.3715383,,,,,,,,,,,,,, -18400,0.5515788,1.366283,,,,,,,,,,,,,, -18500,0.64936775,1.3720721,,,,,,,,,,,,,, -18600,0.493,1.3843336,,,,,,,,,,,,,, -18700,0.42686638,1.3379785,,,,,,,,,,,,,, -18800,0.56925696,1.3900831,,,,,,,,,,,,,, -18900,0.5370798,1.373349,,,,,,,,,,,,,, -19000,0.5302896,1.3615917,,,,,,,,,,,,,, -19100,0.51615936,1.394166,,,,,,,,,,,,,, -19200,0.49714977,1.3626243,,,,,,,,,,,,,, -19300,0.5232854,1.3888437,,,,,,,,,,,,,, -19400,0.62469614,1.4446671,,,,,,,,,,,,,, -19499,,,0.24477959,0.0886326978502868,0.5056676,0.1514815065120635,5348.0,0.3041531,0.1017610139540552,2472.0,15905.654467105864,17509.719973564148,15905.654467105864,1602.643741607666,0.571497917175293,0.0 -19500,0.6012165,1.3636413,,,,,,,,,,,,,, -19600,0.58787364,1.321244,,,,,,,,,,,,,, -19700,0.514933,1.3058665,,,,,,,,,,,,,, -19800,0.5448355,1.3444148,,,,,,,,,,,,,, -19900,0.51517516,1.3601642,,,,,,,,,,,,,, -20000,0.45277664,1.3325922,,,,,,,,,,,,,, -20100,0.6581369,1.3669215,,,,,,,,,,,,,, -20200,0.46611613,1.3061947,,,,,,,,,,,,,, -20300,0.49834526,1.3149198,,,,,,,,,,,,,, -20400,0.44913736,1.3732584,,,,,,,,,,,,,, -20500,0.4420346,1.3575898,,,,,,,,,,,,,, -20600,0.5488225,1.2975055,,,,,,,,,,,,,, -20700,0.56871206,1.376689,,,,,,,,,,,,,, -20800,0.46097335,1.4331087,,,,,,,,,,,,,, -20900,0.5122685,1.3864529,,,,,,,,,,,,,, -21000,0.5621698,1.3292888,,,,,,,,,,,,,, -21100,0.51957387,1.3231198,,,,,,,,,,,,,, -21200,0.54677355,1.3498935,,,,,,,,,,,,,, -21243,,,0.24407542,0.0899781043079114,0.4981141,0.1499174527163366,5348.0,0.2986674,0.099790790729795,2472.0,17346.12431693077,19081.95854330063,17346.12431693077,1734.2814166545868,0.6246070861816406,0.0 -21300,0.48344412,1.3769048,,,,,,,,,,,,,, -21400,0.54173845,1.3856322,,,,,,,,,,,,,, -21500,0.58565223,1.3711345,,,,,,,,,,,,,, -21600,0.49428445,1.3054384,,,,,,,,,,,,,, -21700,0.5629585,1.333765,,,,,,,,,,,,,, -21800,0.60051286,1.3337142,,,,,,,,,,,,,, -21900,0.6875498,1.3004794,,,,,,,,,,,,,, -22000,0.53411496,1.3892275,,,,,,,,,,,,,, -22100,0.47292534,1.30475,,,,,,,,,,,,,, -22200,0.6393217,1.343558,,,,,,,,,,,,,, -22300,0.5359614,1.385193,,,,,,,,,,,,,, -22400,0.5678413,1.337957,,,,,,,,,,,,,, -22500,0.5493821,1.3752817,,,,,,,,,,,,,, -22600,0.48015437,1.4057561,,,,,,,,,,,,,, -22700,0.48112977,1.236968,,,,,,,,,,,,,, -22800,0.5285423,1.3345003,,,,,,,,,,,,,, -22900,0.5886238,1.3145708,,,,,,,,,,,,,, -23000,0.45261514,1.3143201,,,,,,,,,,,,,, -23011,,,0.22209774,0.0820917789896079,0.48654518,0.1449259970842947,5348.0,0.28946954,0.0951394389941705,2472.0,18786.23919892311,20655.370346546173,18786.23919892311,1867.4430181980133,0.6834304332733154,0.0 -23100,0.5070313,1.3594373,,,,,,,,,,,,,, -23200,0.56603605,1.3263394,,,,,,,,,,,,,, -23300,0.58160955,1.3836286,,,,,,,,,,,,,, -23400,0.5947543,1.3068581,,,,,,,,,,,,,, -23500,0.6825194,1.3584925,,,,,,,,,,,,,, -23600,0.4253458,1.3005394,,,,,,,,,,,,,, -23700,0.60102636,1.3626105,,,,,,,,,,,,,, -23800,0.52939665,1.3103728,,,,,,,,,,,,,, -23900,0.6061914,1.3638126,,,,,,,,,,,,,, -24000,0.45817208,1.3445417,,,,,,,,,,,,,, -24100,0.49350125,1.2710512,,,,,,,,,,,,,, -24200,0.7060601,1.3345932,,,,,,,,,,,,,, -24300,0.63604057,1.3270637,,,,,,,,,,,,,, -24400,0.6227888,1.3216398,,,,,,,,,,,,,, -24500,0.417763,1.3579054,,,,,,,,,,,,,, -24600,0.61116797,1.303956,,,,,,,,,,,,,, -24700,0.48165846,1.3100804,,,,,,,,,,,,,, -24741,,,0.20882161,0.0786548287656771,0.4894069,0.1448391052067544,5348.0,0.28323427,0.0961347063961164,2472.0,20226.15410375595,22226.457488298416,20226.15410375595,1998.485382080078,0.7370121479034424,0.0 -24800,0.435524,1.3221123,,,,,,,,,,,,,, -24900,0.53630173,1.3063008,,,,,,,,,,,,,, -25000,0.5224784,1.2690442,,,,,,,,,,,,,, -25100,0.660603,1.3520511,,,,,,,,,,,,,, -25200,0.57303035,1.364431,,,,,,,,,,,,,, -25300,0.5110195,1.2904094,,,,,,,,,,,,,, -25400,0.4486678,1.2946994,,,,,,,,,,,,,, -25500,0.63980675,1.3756586,,,,,,,,,,,,,, -25600,0.5594826,1.3027972,,,,,,,,,,,,,, -25700,0.6607045,1.2845052,,,,,,,,,,,,,, -25800,0.5579332,1.2718896,,,,,,,,,,,,,, -25900,0.5417814,1.335471,,,,,,,,,,,,,, -26000,0.5092717,1.2807949,,,,,,,,,,,,,, -26100,0.45912465,1.3141416,,,,,,,,,,,,,, -26200,0.49675855,1.2340791,,,,,,,,,,,,,, -26300,0.46291143,1.3178246,,,,,,,,,,,,,, -26400,0.5881176,1.308587,,,,,,,,,,,,,, -26500,0.5950944,1.2984987,,,,,,,,,,,,,, -26512,,,0.1967695,0.0746498343723101,0.4649453,0.1382835957789857,5348.0,0.2768741,0.0906505798955984,2472.0,21666.161897182465,23797.530810832977,21666.161897182465,2129.420209884644,0.7897412776947021,0.0 -26600,0.4687142,1.3196777,,,,,,,,,,,,,, -26700,0.59248906,1.3003436,,,,,,,,,,,,,, -26800,0.5921146,1.274456,,,,,,,,,,,,,, -26900,0.47083342,1.2835644,,,,,,,,,,,,,, -27000,0.53015095,1.354094,,,,,,,,,,,,,, -27100,0.5880229,1.2393882,,,,,,,,,,,,,, -27200,0.5745496,1.3330895,,,,,,,,,,,,,, -27300,0.55295336,1.2937908,,,,,,,,,,,,,, -27400,0.44324997,1.3102331,,,,,,,,,,,,,, -27500,0.5906453,1.2745863,,,,,,,,,,,,,, -27600,0.5951812,1.3306278,,,,,,,,,,,,,, -27700,0.7240942,1.2990438,,,,,,,,,,,,,, -27800,0.7357238,1.3041494,,,,,,,,,,,,,, -27900,0.4360523,1.2525859,,,,,,,,,,,,,, -28000,0.4797226,1.2413527,,,,,,,,,,,,,, -28100,0.61071616,1.3094807,,,,,,,,,,,,,, -28200,0.52567756,1.2353877,,,,,,,,,,,,,, -28300,0.5423241,1.2992787,,,,,,,,,,,,,, -28310,,,0.1956766,0.0728948256640517,0.46214008,0.1379843015341243,5348.0,0.26881522,0.0909958767493347,2472.0,23106.45556807518,25370.210528612137,23106.45556807518,2261.6781487464905,0.8389241695404053,0.0 -28400,0.5673416,1.288983,,,,,,,,,,,,,, -28500,0.5928205,1.3160625,,,,,,,,,,,,,, -28600,0.63848525,1.3224721,,,,,,,,,,,,,, -28700,0.5388093,1.2860056,,,,,,,,,,,,,, -28800,0.45542765,1.330739,,,,,,,,,,,,,, -28900,0.60019696,1.2964817,,,,,,,,,,,,,, -29000,0.4925931,1.2488025,,,,,,,,,,,,,, -29100,0.4246568,1.2391412,,,,,,,,,,,,,, -29200,0.45691624,1.2528877,,,,,,,,,,,,,, -29300,0.5277187,1.2565788,,,,,,,,,,,,,, -29400,0.5056779,1.3321211,,,,,,,,,,,,,, -29500,0.51015806,1.2914599,,,,,,,,,,,,,, -29600,0.6032464,1.2443933,,,,,,,,,,,,,, -29700,0.47178546,1.257815,,,,,,,,,,,,,, -29800,0.60144645,1.2787044,,,,,,,,,,,,,, -29900,0.43992618,1.2340769,,,,,,,,,,,,,, -30000,0.50219613,1.229751,,,,,,,,,,,,,, -30074,,,0.20220742,0.0759609551048708,0.4574062,0.1352037614528322,5348.0,0.26378438,0.0888225377287591,2472.0,24546.814392089844,26940.67303466797,24546.814392089844,2391.651878118515,0.8914635181427002,0.0 -30100,0.63161784,1.2587501,,,,,,,,,,,,,, -30200,0.5695069,1.2745656,,,,,,,,,,,,,, -30300,0.5061,1.2256968,,,,,,,,,,,,,, -30400,0.6824529,1.3037953,,,,,,,,,,,,,, -30500,0.48864526,1.2684939,,,,,,,,,,,,,, -30600,0.4694447,1.2350702,,,,,,,,,,,,,, -30700,0.5440893,1.2659146,,,,,,,,,,,,,, -30800,0.5715653,1.2324289,,,,,,,,,,,,,, -30900,0.45173615,1.2555166,,,,,,,,,,,,,, -31000,0.5607161,1.2465875,,,,,,,,,,,,,, -31100,0.5106442,1.3422545,,,,,,,,,,,,,, -31200,0.48155487,1.2136872,,,,,,,,,,,,,, -31300,0.49551845,1.2275136,,,,,,,,,,,,,, -31400,0.5889168,1.3513352,,,,,,,,,,,,,, -31500,0.45115712,1.2321458,,,,,,,,,,,,,, -31600,0.66079354,1.2382314,,,,,,,,,,,,,, -31700,0.7126502,1.2657089,,,,,,,,,,,,,, -31800,0.5207657,1.3164426,,,,,,,,,,,,,, -31824,,,0.18969351,0.0693004302790574,0.44134903,0.1326163144327408,5348.0,0.2555178,0.0871163650397091,2472.0,25987.314188480377,28511.35634493828,25987.314188480377,2521.70818734169,0.942244291305542,0.0 -31900,0.65909904,1.2625504,,,,,,,,,,,,,, -32000,0.6788327,1.2607144,,,,,,,,,,,,,, -32100,0.50753593,1.2578841,,,,,,,,,,,,,, -32200,0.51818985,1.2687627,,,,,,,,,,,,,, -32300,0.6222148,1.2125342,,,,,,,,,,,,,, -32400,0.6581438,1.2086254,,,,,,,,,,,,,, -32500,0.44985396,1.242421,,,,,,,,,,,,,, -32600,0.49820304,1.1907248,,,,,,,,,,,,,, -32700,0.55623734,1.2271776,,,,,,,,,,,,,, -32800,0.5211414,1.3216909,,,,,,,,,,,,,, -32900,0.5235466,1.2985196,,,,,,,,,,,,,, -33000,0.52501047,1.2677642,,,,,,,,,,,,,, -33100,0.4732177,1.1996571,,,,,,,,,,,,,, -33200,0.551502,1.2811162,,,,,,,,,,,,,, -33300,0.50774354,1.2098173,,,,,,,,,,,,,, -33400,0.5355201,1.2177423,,,,,,,,,,,,,, -33500,0.5190376,1.2619133,,,,,,,,,,,,,, -33573,,,0.19449702,0.0708998351047162,0.43590006,0.1275669308823387,5348.0,0.24836451,0.0830540491133995,2472.0,27427.76982665062,30091.00088500977,27427.76982665062,2660.7630712985992,0.9998691082000732,0.0 -33600,0.536905,1.2637573,,,,,,,,,,,,,, -33700,0.53395116,1.2370558,,,,,,,,,,,,,, -33800,0.67019105,1.200627,,,,,,,,,,,,,, -33900,0.5508301,1.2608078,,,,,,,,,,,,,, -34000,0.5967234,1.1915168,,,,,,,,,,,,,, -34100,0.67028105,1.2704269,,,,,,,,,,,,,, -34200,0.5387054,1.1970118,,,,,,,,,,,,,, -34300,0.55692124,1.1978979,,,,,,,,,,,,,, -34400,0.4617164,1.2456021,,,,,,,,,,,,,, -34500,0.6267957,1.1806637,,,,,,,,,,,,,, -34600,0.52365446,1.2611986,,,,,,,,,,,,,, -34700,0.51973385,1.2426682,,,,,,,,,,,,,, -34800,0.48415855,1.2702352,,,,,,,,,,,,,, -34900,0.5394601,1.2245417,,,,,,,,,,,,,, -35000,0.6307022,1.2286445,,,,,,,,,,,,,, -35100,0.61178076,1.2457088,,,,,,,,,,,,,, -35200,0.49160215,1.1993715,,,,,,,,,,,,,, -35300,0.5425399,1.252946,,,,,,,,,,,,,, -35337,,,0.1883224,0.0683738823704935,0.4340027,0.1286096334128233,5348.0,0.24490088,0.0808807100928239,2472.0,28868.35710072517,31663.327996253967,28868.35710072517,2792.369478940964,1.0533790588378906,0.0 -35400,0.47521695,1.2338682,,,,,,,,,,,,,, -35500,0.48150724,1.2405509,,,,,,,,,,,,,, -35600,0.47501084,1.2340448,,,,,,,,,,,,,, -35700,0.5126067,1.223328,,,,,,,,,,,,,, -35800,0.73276573,1.2205094,,,,,,,,,,,,,, -35900,0.61820513,1.2423987,,,,,,,,,,,,,, -36000,0.49467114,1.2578346,,,,,,,,,,,,,, -36100,0.5540451,1.2499928,,,,,,,,,,,,,, -36200,0.6019249,1.1848665,,,,,,,,,,,,,, -36300,0.4769104,1.2026031,,,,,,,,,,,,,, -36400,0.5539255,1.203367,,,,,,,,,,,,,, -36500,0.87726164,1.2108505,,,,,,,,,,,,,, -36600,0.6300856,1.1613197,,,,,,,,,,,,,, -36700,0.5383397,1.2398497,,,,,,,,,,,,,, -36800,0.66912466,1.2290492,,,,,,,,,,,,,, -36900,0.56887585,1.2387449,,,,,,,,,,,,,, -37000,0.6481514,1.2189901,,,,,,,,,,,,,, -37063,,,0.14173692,0.0543289976732982,0.41600996,0.1218031030054935,5348.0,0.23621194,0.0795198342575102,2472.0,30308.52282428741,33235.32124829292,30308.52282428741,2924.0652084350586,1.1087830066680908,0.0 -37100,0.49558637,1.1984268,,,,,,,,,,,,,, -37200,0.623314,1.2208294,,,,,,,,,,,,,, -37300,0.41589892,1.1482568,,,,,,,,,,,,,, -37400,0.5577865,1.2113793,,,,,,,,,,,,,, -37500,0.5615355,1.1788459,,,,,,,,,,,,,, -37600,0.5374774,1.2292358,,,,,,,,,,,,,, -37700,0.4953275,1.1595055,,,,,,,,,,,,,, -37800,0.5042544,1.188515,,,,,,,,,,,,,, -37900,0.5189762,1.2039726,,,,,,,,,,,,,, -38000,0.8204037,1.2523469,,,,,,,,,,,,,, -38100,0.5882912,1.2130066,,,,,,,,,,,,,, -38200,0.50023776,1.1528896,,,,,,,,,,,,,, -38300,0.68648773,1.2490107,,,,,,,,,,,,,, -38400,0.5307393,1.2339233,,,,,,,,,,,,,, -38500,0.64619255,1.2372534,,,,,,,,,,,,,, -38600,0.44160885,1.1648551,,,,,,,,,,,,,, -38700,0.6583839,1.1511933,,,,,,,,,,,,,, -38800,0.5091886,1.1880289,,,,,,,,,,,,,, -38807,,,0.16449861,0.0600990475385947,0.410855,0.1203935236587273,5348.0,0.23410633,0.0759449962423577,2472.0,31749.41622185707,34807.73562335968,31749.41622185707,3055.453970432281,1.1656112670898438,0.0 -38900,0.55987793,1.1829387,,,,,,,,,,,,,, -39000,0.45305732,1.1882744,,,,,,,,,,,,,, -39100,0.5770364,1.181959,,,,,,,,,,,,,, -39200,0.6158332,1.125843,,,,,,,,,,,,,, -39300,0.5145481,1.122371,,,,,,,,,,,,,, -39400,0.5837908,1.1932032,,,,,,,,,,,,,, -39500,0.548464,1.2092333,,,,,,,,,,,,,, -39600,0.49640948,1.1990031,,,,,,,,,,,,,, -39700,0.530194,1.1725093,,,,,,,,,,,,,, -39800,0.74804604,1.1835092,,,,,,,,,,,,,, -39900,0.50658363,1.1352834,,,,,,,,,,,,,, -40000,0.56523186,1.196823,,,,,,,,,,,,,, -40100,0.57626396,1.1552373,,,,,,,,,,,,,, -40200,0.5617929,1.1737957,,,,,,,,,,,,,, -40300,0.5545943,1.1580735,,,,,,,,,,,,,, -40400,0.6391276,1.1665471,,,,,,,,,,,,,, -40500,0.7492543,1.1966119,,,,,,,,,,,,,, -40584,,,0.20285141,0.0741797432239657,0.40358555,0.1180184790059569,5348.0,0.23114812,0.0747059898848333,2472.0,33189.860087394714,36378.72210025787,33189.860087394714,3185.8636391162872,1.2198340892791748,0.0 -40600,0.49528965,1.1408426,,,,,,,,,,,,,, -40700,0.54503554,1.1990329,,,,,,,,,,,,,, -40800,0.5366819,1.1817031,,,,,,,,,,,,,, -40900,0.5487447,1.181415,,,,,,,,,,,,,, -41000,0.70109403,1.1557318,,,,,,,,,,,,,, -41100,0.61453706,1.1570534,,,,,,,,,,,,,, -41200,0.50441533,1.1075113,,,,,,,,,,,,,, -41300,0.48526782,1.1504189,,,,,,,,,,,,,, -41400,0.573687,1.1995034,,,,,,,,,,,,,, -41500,0.6162398,1.197981,,,,,,,,,,,,,, -41600,0.58592194,1.1594169,,,,,,,,,,,,,, -41700,0.49841136,1.1772844,,,,,,,,,,,,,, -41800,0.48178864,1.1490225,,,,,,,,,,,,,, -41900,0.61975545,1.1935667,,,,,,,,,,,,,, -42000,0.7300919,1.1761882,,,,,,,,,,,,,, -42100,0.464607,1.1667969,,,,,,,,,,,,,, -42200,0.56267506,1.144813,,,,,,,,,,,,,, -42300,0.45468843,1.1182725,,,,,,,,,,,,,, -42318,,,0.20199291,0.0740769525416766,0.39573586,0.1171109416183129,5348.0,0.22265244,0.0721670424308898,2472.0,34629.90348124504,37949.16435742378,34629.90348124504,3316.1265754699707,1.2786462306976318,0.0 -42400,0.5123347,1.1104003,,,,,,,,,,,,,, -42500,0.67004013,1.1192268,,,,,,,,,,,,,, -42600,0.4790164,1.1848874,,,,,,,,,,,,,, -42700,0.5277906,1.1672709,,,,,,,,,,,,,, -42800,0.5099757,1.1243384,,,,,,,,,,,,,, -42900,0.70880055,1.1157101,,,,,,,,,,,,,, -43000,0.49811998,1.187918,,,,,,,,,,,,,, -43100,0.5280082,1.1547655,,,,,,,,,,,,,, -43200,0.57745737,1.1189697,,,,,,,,,,,,,, -43300,0.6841632,1.1522626,,,,,,,,,,,,,, -43400,0.62930316,1.1558189,,,,,,,,,,,,,, -43500,0.50399435,1.1677977,,,,,,,,,,,,,, -43600,0.58737975,1.1646667,,,,,,,,,,,,,, -43700,0.604568,1.120648,,,,,,,,,,,,,, -43800,0.5752443,1.1896394,,,,,,,,,,,,,, -43900,0.47610196,1.1633837,,,,,,,,,,,,,, -44000,0.7951192,1.1661015,,,,,,,,,,,,,, -44053,,,0.22752737,0.08577896590251,0.38816926,0.1146393504349421,5348.0,0.21891987,0.072248288749416,2472.0,36070.11869764328,39519.31596660614,36070.11869764328,3445.9241964817047,1.341212272644043,0.0 -44100,0.5989198,1.1052232,,,,,,,,,,,,,, -44200,0.53723943,1.1380205,,,,,,,,,,,,,, -44300,0.51577175,1.1374775,,,,,,,,,,,,,, -44400,0.5648325,1.1762258,,,,,,,,,,,,,, -44500,0.57322794,1.1168898,,,,,,,,,,,,,, -44600,0.6113897,1.1127102,,,,,,,,,,,,,, -44700,0.5580629,1.1336466,,,,,,,,,,,,,, -44800,0.5858027,1.0954028,,,,,,,,,,,,,, -44900,0.57554543,1.177603,,,,,,,,,,,,,, -45000,0.5407974,1.1288813,,,,,,,,,,,,,, -45100,0.6410517,1.1121804,,,,,,,,,,,,,, -45200,0.78260964,1.2164072,,,,,,,,,,,,,, -45300,0.66775095,1.124632,,,,,,,,,,,,,, -45400,0.6892948,1.144581,,,,,,,,,,,,,, -45500,0.47777933,1.0893424,,,,,,,,,,,,,, -45600,0.68506145,1.1700379,,,,,,,,,,,,,, -45700,0.7801659,1.126898,,,,,,,,,,,,,, -45799,,,0.2029801,0.0721935366119317,0.3828147,0.1124284348841924,5348.0,0.21051346,0.0695468486584201,2472.0,37509.99157762528,41089.9249560833,37509.99157762528,3576.526694059372,1.396867036819458,0.0 -45800,0.8572724,1.133635,,,,,,,,,,,,,, -45900,0.6072569,1.1294849,,,,,,,,,,,,,, -46000,0.66654974,1.077102,,,,,,,,,,,,,, -46100,0.627205,1.1115844,,,,,,,,,,,,,, -46200,0.51287526,1.1582862,,,,,,,,,,,,,, -46300,0.7058312,1.1298788,,,,,,,,,,,,,, -46400,0.5987654,1.0991707,,,,,,,,,,,,,, -46500,0.5586312,1.0962828,,,,,,,,,,,,,, -46600,0.6997541,1.1139146,,,,,,,,,,,,,, -46700,0.68838257,1.0855294,,,,,,,,,,,,,, -46800,0.547958,1.1574229,,,,,,,,,,,,,, -46900,0.53428453,1.1509805,,,,,,,,,,,,,, -47000,0.5711557,1.1292344,,,,,,,,,,,,,, -47100,0.66213524,1.1160473,,,,,,,,,,,,,, -47200,0.567544,1.1420106,,,,,,,,,,,,,, -47300,0.5862509,1.0716386,,,,,,,,,,,,,, -47400,0.61671996,1.1235282,,,,,,,,,,,,,, -47500,0.56095606,1.0406101,,,,,,,,,,,,,, -47560,,,0.18338831,0.068447774821017,0.37283552,0.1085569190071154,5348.0,0.21286468,0.0688359433713159,2472.0,38950.62455177307,42660.78283596039,38950.62455177307,3706.613924980164,1.4564719200134275,0.0 -47600,0.71912026,1.0555768,,,,,,,,,,,,,, -47700,0.51122564,1.0994159,,,,,,,,,,,,,, -47800,0.5617215,1.1209182,,,,,,,,,,,,,, -47900,0.487238,1.084432,,,,,,,,,,,,,, -48000,0.7885131,1.1042222,,,,,,,,,,,,,, -48100,0.68492466,1.0966938,,,,,,,,,,,,,, -48200,0.60518956,1.1559782,,,,,,,,,,,,,, -48300,0.7096223,1.0980258,,,,,,,,,,,,,, -48400,0.68730193,1.1202987,,,,,,,,,,,,,, -48500,0.6916633,1.1006944,,,,,,,,,,,,,, -48600,0.57521176,1.102693,,,,,,,,,,,,,, -48700,0.6175388,1.1002746,,,,,,,,,,,,,, -48800,0.5103845,1.0900594,,,,,,,,,,,,,, -48900,1.0282325,1.1076237,,,,,,,,,,,,,, -49000,1.0751824,1.0992815,,,,,,,,,,,,,, -49100,0.57994825,1.1255584,,,,,,,,,,,,,, -49200,0.5701346,1.0671996,,,,,,,,,,,,,, -49288,,,0.15685317,0.0596471957280666,0.3723242,0.1092520540274385,5348.0,0.20216686,0.0661751264395832,2472.0,40390.96273756027,44232.39781737328,40390.96273756027,3837.757830858231,1.5125069618225098,0.0 -49300,0.6268431,1.0925909,,,,,,,,,,,,,, -49400,0.57355106,1.1560804,,,,,,,,,,,,,, -49500,0.5065048,1.0909517,,,,,,,,,,,,,, -49600,0.8451638,1.0581917,,,,,,,,,,,,,, -49700,0.61393565,1.1053967,,,,,,,,,,,,,, -49800,0.6214304,1.0802621,,,,,,,,,,,,,, -49900,0.57554346,1.1119547,,,,,,,,,,,,,, -50000,0.80128175,1.1007997,,,,,,,,,,,,,, -50100,0.5567479,1.1361696,,,,,,,,,,,,,, -50200,0.5458636,1.0707388,,,,,,,,,,,,,, -50300,0.65458465,1.0931237,,,,,,,,,,,,,, -50400,0.74084455,1.1319308,,,,,,,,,,,,,, -50500,0.7178344,1.069188,,,,,,,,,,,,,, -50600,0.531457,1.0524316,,,,,,,,,,,,,, -50700,0.58845955,1.033515,,,,,,,,,,,,,, -50800,0.53668857,1.0633541,,,,,,,,,,,,,, -50900,0.6379332,1.1402712,,,,,,,,,,,,,, -51000,0.6024789,1.0930339,,,,,,,,,,,,,, -51047,,,0.17114201,0.0648845633720203,0.3645804,0.106403931374726,5348.0,0.20247392,0.0669063433063189,2472.0,41831.39310407639,45803.714185237885,41831.39310407639,3968.502446889877,1.5761635303497314,0.0 -51100,0.64083296,1.1457784,,,,,,,,,,,,,, -51200,0.5816702,1.1080549,,,,,,,,,,,,,, -51300,0.63204336,1.0770522,,,,,,,,,,,,,, -51400,0.6163845,1.0426915,,,,,,,,,,,,,, -51500,0.6370536,1.0756372,,,,,,,,,,,,,, -51600,0.67116934,1.0579084,,,,,,,,,,,,,, -51700,0.52046394,1.0820357,,,,,,,,,,,,,, -51800,0.5138169,1.032791,,,,,,,,,,,,,, -51900,0.70821625,1.0609221,,,,,,,,,,,,,, -52000,0.6331116,1.063086,,,,,,,,,,,,,, -52100,0.6261203,1.095887,,,,,,,,,,,,,, -52200,0.71134365,1.0577416,,,,,,,,,,,,,, -52300,0.6060276,1.1013713,,,,,,,,,,,,,, -52400,0.5558924,1.0295304,,,,,,,,,,,,,, -52500,0.7254388,1.0477477,,,,,,,,,,,,,, -52600,0.5762525,1.0319651,,,,,,,,,,,,,, -52700,0.65537804,1.0733047,,,,,,,,,,,,,, -52800,0.70872164,1.1222073,,,,,,,,,,,,,, -52804,,,0.14830387,0.0573462463494245,0.35951543,0.1038647576199349,5348.0,0.19651937,0.0634533747689557,2472.0,43271.88686680794,47377.0074198246,43271.88686680794,4101.162638664246,1.6361210346221924,0.0 -52900,0.5740304,1.0742447,,,,,,,,,,,,,, -53000,0.7194057,1.0660727,,,,,,,,,,,,,, -53100,0.7642675,1.0342458,,,,,,,,,,,,,, -53200,0.5808894,1.061892,,,,,,,,,,,,,, -53300,0.5103456,1.0256464,,,,,,,,,,,,,, -53400,0.6996967,1.0732358,,,,,,,,,,,,,, -53500,0.8751302,1.1314034,,,,,,,,,,,,,, -53600,0.7883605,1.0740714,,,,,,,,,,,,,, -53700,0.60732245,1.0975864,,,,,,,,,,,,,, -53800,0.61836034,1.0690451,,,,,,,,,,,,,, -53900,0.7655027,1.0596861,,,,,,,,,,,,,, -54000,0.6316084,1.0411534,,,,,,,,,,,,,, -54100,0.68710566,1.0543412,,,,,,,,,,,,,, -54200,0.6903546,1.0885177,,,,,,,,,,,,,, -54300,0.70862883,1.0367134,,,,,,,,,,,,,, -54400,0.52534527,1.0279636,,,,,,,,,,,,,, -54500,0.5912515,1.0179869,,,,,,,,,,,,,, -54528,,,0.14442445,0.0558153158528515,0.35094666,0.1008235419060216,5348.0,0.1865691,0.0607519346779599,2472.0,44711.75618267059,48949.43096876144,44711.75618267059,4233.571403264999,1.7039833068847656,0.0 -54600,0.7805065,1.0358617,,,,,,,,,,,,,, -54700,0.5574345,1.0321816,,,,,,,,,,,,,, -54800,0.52933353,1.0776336,,,,,,,,,,,,,, -54900,0.6172369,1.0444796,,,,,,,,,,,,,, -55000,0.95094055,1.0350362,,,,,,,,,,,,,, -55100,0.5543378,1.0123924,,,,,,,,,,,,,, -55200,0.76518816,1.057976,,,,,,,,,,,,,, -55300,0.6658173,1.0060785,,,,,,,,,,,,,, -55400,0.5116255,1.0419663,,,,,,,,,,,,,, -55500,0.6818933,1.062872,,,,,,,,,,,,,, -55600,0.6946267,1.0469979,,,,,,,,,,,,,, -55700,0.58147275,1.0240115,,,,,,,,,,,,,, -55800,0.66982037,0.99091285,,,,,,,,,,,,,, -55900,0.9861698,1.0695236,,,,,,,,,,,,,, -56000,0.5068805,1.0491385,,,,,,,,,,,,,, -56100,0.58021593,1.01467,,,,,,,,,,,,,, -56200,0.59341866,1.0462562,,,,,,,,,,,,,, -56266,,,0.1373797,0.0521849551414768,0.34084773,0.0978981820288288,5348.0,0.18515553,0.0613003473280117,2472.0,46151.96347117424,50520.468977451324,46151.96347117424,4364.2648112773895,1.765101194381714,0.0 -56300,0.5396494,0.9664485,,,,,,,,,,,,,, -56400,0.7909593,1.0561447,,,,,,,,,,,,,, -56500,0.6658363,1.0760626,,,,,,,,,,,,,, -56600,0.6306729,1.0898746,,,,,,,,,,,,,, -56700,0.6431625,1.0321826,,,,,,,,,,,,,, -56800,0.7036578,1.0319725,,,,,,,,,,,,,, -56900,0.62552524,1.067982,,,,,,,,,,,,,, -57000,0.58469254,1.025071,,,,,,,,,,,,,, -57100,0.6545204,0.98948425,,,,,,,,,,,,,, -57200,0.63153166,1.0325662,,,,,,,,,,,,,, -57300,0.6757891,1.0478117,,,,,,,,,,,,,, -57400,1.223707,1.0005481,,,,,,,,,,,,,, -57500,0.5843135,1.0120003,,,,,,,,,,,,,, -57600,0.9233177,0.9956638,,,,,,,,,,,,,, -57700,0.6991036,1.015568,,,,,,,,,,,,,, -57800,0.6659013,0.9828381,,,,,,,,,,,,,, -57900,0.5046647,0.9682937,,,,,,,,,,,,,, -57996,,,0.13606107,0.051268843789705,0.33888745,0.0969134074167044,5348.0,0.18242605,0.0595738630593301,2472.0,47592.51656937599,52090.4231069088,47592.51656937599,4493.534925699234,1.8188579082489007,0.0 -58000,0.6838363,1.0210928,,,,,,,,,,,,,, -58100,0.711179,1.0589278,,,,,,,,,,,,,, -58200,0.70803005,0.9785564,,,,,,,,,,,,,, -58300,0.8438225,1.0158379,,,,,,,,,,,,,, -58400,0.65082747,0.99938035,,,,,,,,,,,,,, -58500,0.57978165,0.9910564,,,,,,,,,,,,,, -58600,0.7489369,1.0266377,,,,,,,,,,,,,, -58700,0.6623321,0.9944265,,,,,,,,,,,,,, -58800,0.6370118,0.96700627,,,,,,,,,,,,,, -58900,0.7215274,0.98756176,,,,,,,,,,,,,, -59000,0.75466865,0.98719615,,,,,,,,,,,,,, -59100,0.7438207,1.0071813,,,,,,,,,,,,,, -59200,0.7983601,0.98237175,,,,,,,,,,,,,, -59300,0.8149673,0.97308105,,,,,,,,,,,,,, -59400,0.6574475,1.0198133,,,,,,,,,,,,,, -59500,0.60379165,1.0416518,,,,,,,,,,,,,, -59600,0.5417982,0.95369905,,,,,,,,,,,,,, -59699,,,0.13671345,0.0510678617715287,0.33184198,0.0946928372128947,5348.0,0.17691371,0.0585379724981211,2472.0,49032.50491023064,53661.48346114159,49032.50491023064,4624.469212293625,1.8802030086517327,0.0 -59700,0.621994,0.9807433,,,,,,,,,,,,,, -59800,0.7393173,1.0248003,,,,,,,,,,,,,, -59900,0.60650444,0.9866752,,,,,,,,,,,,,, -60000,0.6086139,1.009099,,,,,,,,,,,,,, -60100,0.5854969,1.0053148,,,,,,,,,,,,,, -60200,0.6857567,0.98119086,,,,,,,,,,,,,, -60300,0.6052375,0.95923233,,,,,,,,,,,,,, -60400,0.90475506,1.0392663,,,,,,,,,,,,,, -60500,0.72769725,1.0074692,,,,,,,,,,,,,, -60600,0.57044965,1.0243554,,,,,,,,,,,,,, -60700,0.55990744,0.9997868,,,,,,,,,,,,,, -60800,0.6423661,0.9799867,,,,,,,,,,,,,, -60900,0.6264235,1.018088,,,,,,,,,,,,,, -61000,0.64133006,0.95085996,,,,,,,,,,,,,, -61100,0.68740076,0.9892993,,,,,,,,,,,,,, -61200,0.73715186,0.984816,,,,,,,,,,,,,, -61300,0.57467353,0.9700627,,,,,,,,,,,,,, -61400,0.6666862,0.9809466,,,,,,,,,,,,,, -61411,,,0.11137697,0.0431546618243773,0.32254496,0.0926653600702858,5348.0,0.17471576,0.0569333577072288,2472.0,50472.42002773285,55232.452363967896,50472.42002773285,4755.3815841674805,1.945399284362793,0.0 -61500,0.6570466,1.0058534,,,,,,,,,,,,,, -61600,0.87434804,0.9830123,,,,,,,,,,,,,, -61700,0.6739858,0.9285281,,,,,,,,,,,,,, -61800,0.7551869,0.9539923,,,,,,,,,,,,,, -61900,0.9675873,0.9297673,,,,,,,,,,,,,, -62000,0.711021,1.0002791,,,,,,,,,,,,,, -62100,1.043682,0.93850154,,,,,,,,,,,,,, -62200,0.7575945,0.960275,,,,,,,,,,,,,, -62300,0.72057545,1.0018768,,,,,,,,,,,,,, -62400,0.7313148,1.0184866,,,,,,,,,,,,,, -62500,0.6274728,0.9317154,,,,,,,,,,,,,, -62600,0.61941516,0.97364247,,,,,,,,,,,,,, -62700,0.70244116,0.97832257,,,,,,,,,,,,,, -62800,0.62602353,0.9300811,,,,,,,,,,,,,, -62900,0.7105006,0.9779538,,,,,,,,,,,,,, -63000,0.81638527,0.98335886,,,,,,,,,,,,,, -63100,0.5976712,0.9696216,,,,,,,,,,,,,, -63155,,,0.11432057,0.0435226598512162,0.3218405,0.0907054654990972,5348.0,0.17560743,0.0565271261145979,2472.0,51912.341347932816,56802.64616537094,51912.341347932816,4885.515163421631,2.0081372261047363,0.0 -63200,0.6797105,0.9587087,,,,,,,,,,,,,, -63300,0.61400527,0.99491334,,,,,,,,,,,,,, -63400,0.62339747,0.96131104,,,,,,,,,,,,,, -63500,0.6434192,0.9855549,,,,,,,,,,,,,, -63600,0.7950567,0.97325784,,,,,,,,,,,,,, -63700,0.6826425,0.95145303,,,,,,,,,,,,,, -63800,0.6161528,0.9317969,,,,,,,,,,,,,, -63900,0.64393425,0.9580608,,,,,,,,,,,,,, -64000,0.5678547,0.9301604,,,,,,,,,,,,,, -64100,0.753005,1.0140116,,,,,,,,,,,,,, -64200,0.8260473,0.97363764,,,,,,,,,,,,,, -64300,0.66519046,0.9534943,,,,,,,,,,,,,, -64400,1.0052621,0.97677064,,,,,,,,,,,,,, -64500,0.7181348,0.95424193,,,,,,,,,,,,,, -64600,0.94190425,0.9536833,,,,,,,,,,,,,, -64700,0.76961935,0.99598265,,,,,,,,,,,,,, -64800,0.96977156,0.9670276,,,,,,,,,,,,,, -64852,,,0.104371265,0.04056934339151,0.3143445,0.0894117419890516,5348.0,0.16428222,0.0530741575772347,2472.0,53352.50097608566,58376.11468625069,53352.50097608566,5018.692674398422,2.0654304027557373,0.0 -64900,0.84816366,0.92735267,,,,,,,,,,,,,, -65000,0.80956084,0.9614062,,,,,,,,,,,,,, -65100,0.7623953,0.95654035,,,,,,,,,,,,,, -65200,0.6529738,0.97988003,,,,,,,,,,,,,, -65300,0.6111324,0.96631426,,,,,,,,,,,,,, -65400,0.67657334,0.97750795,,,,,,,,,,,,,, -65500,0.9984878,0.9731796,,,,,,,,,,,,,, -65600,0.77520496,0.96645415,,,,,,,,,,,,,, -65700,0.7228196,0.9624217,,,,,,,,,,,,,, -65800,0.89172643,0.9520336,,,,,,,,,,,,,, -65900,0.58790463,0.9348422,,,,,,,,,,,,,, -66000,0.6316168,0.95493996,,,,,,,,,,,,,, -66100,0.6477789,0.9458384,,,,,,,,,,,,,, -66200,0.9670115,0.91571236,,,,,,,,,,,,,, -66300,0.57567614,0.9658978,,,,,,,,,,,,,, -66400,0.73199666,0.9592927,,,,,,,,,,,,,, -66500,0.86398375,0.93268776,,,,,,,,,,,,,, -66559,,,0.10364472,0.0397125890330625,0.3120098,0.0878766521525049,5348.0,0.16560408,0.0542522291958645,2472.0,54792.7521352768,59946.395359277725,54792.7521352768,5148.585556030273,2.1266865730285645,0.0 -66600,0.7427457,0.9456684,,,,,,,,,,,,,, -66700,0.65274704,0.93688315,,,,,,,,,,,,,, -66800,0.6901633,0.9404891,,,,,,,,,,,,,, -66900,0.74520695,0.99369746,,,,,,,,,,,,,, -67000,0.8507121,0.90285933,,,,,,,,,,,,,, -67100,0.66521627,0.8839998,,,,,,,,,,,,,, -67200,0.7145483,0.9424075,,,,,,,,,,,,,, -67300,0.8064061,0.90294105,,,,,,,,,,,,,, -67400,0.685106,0.91727215,,,,,,,,,,,,,, -67500,0.78584886,0.97629815,,,,,,,,,,,,,, -67600,0.837244,0.914121,,,,,,,,,,,,,, -67700,0.603796,0.9478733,,,,,,,,,,,,,, -67800,0.80726093,0.8953168,,,,,,,,,,,,,, -67900,1.2439871,0.9704955,,,,,,,,,,,,,, -68000,0.636828,0.8923713,,,,,,,,,,,,,, -68100,0.6133541,0.9084843,,,,,,,,,,,,,, -68200,0.8699441,0.90250224,,,,,,,,,,,,,, -68288,,,0.09611028,0.0371413877149437,0.30804402,0.0861581239078173,5348.0,0.161813,0.0521804480734466,2472.0,56232.73989892006,61515.9838206768,56232.73989892006,5278.0486397743225,2.187870979309082,0.0 -68300,0.7651989,0.9258523,,,,,,,,,,,,,, -68400,0.8641355,0.95165825,,,,,,,,,,,,,, -68500,0.72322553,0.91474414,,,,,,,,,,,,,, -68600,0.69015086,0.9093419,,,,,,,,,,,,,, -68700,0.9425339,0.9265461,,,,,,,,,,,,,, -68800,0.8715294,0.907848,,,,,,,,,,,,,, -68900,0.66645604,0.94594806,,,,,,,,,,,,,, -69000,1.5970981,0.9365563,,,,,,,,,,,,,, -69100,0.7246988,0.9693435,,,,,,,,,,,,,, -69200,0.75459856,0.9114769,,,,,,,,,,,,,, -69300,0.7333976,0.9210862,,,,,,,,,,,,,, -69400,0.7511244,0.9093225,,,,,,,,,,,,,, -69500,0.70997304,0.8783005,,,,,,,,,,,,,, -69600,0.7348962,0.8997055,,,,,,,,,,,,,, -69700,0.604371,0.9168842,,,,,,,,,,,,,, -69800,0.8001469,0.89838475,,,,,,,,,,,,,, -69900,0.8465921,0.9050518,,,,,,,,,,,,,, -69983,,,0.092972785,0.0354703740190164,0.30447015,0.0858781389690761,5348.0,0.15796709,0.0522210712327097,2472.0,57673.28281927109,63088.09074831009,57673.28281927109,5409.477321624756,2.2469236850738525,0.0 -69983,,,,,,,,,,,57673.28281927109,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 41440ef4f..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,27 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -183.5684108734131,0.0,42.37398552894592,1,0,42.37398552894592,30.705326,2472,1.8779477179940285,225.94245171546936,31.636915,1.985486121233576,30.702776,5348,1.707309537831758 -305.5697865486145,0.0383594036102294,1482.5158276557922,1835,0,1482.5158276557922,2.2324085,2472,0.4973087156988199,1788.1893684864044,2.3297267,0.505952503552965,2.6970866,5348,0.5646716935226933 -432.54165267944336,0.083158254623413,2923.0815935134888,3675,0,2923.0815935134888,0.5609731,2472,0.1808746166189344,3355.8438744544983,0.5273757,0.1752337319873212,0.8686737,5348,0.2495824362551531 -559.979326248169,0.1280710697174072,4362.980728149414,5515,0,4362.980728149414,0.46814135,2472,0.1503260008530863,4923.29837346077,0.39720902,0.1383947161735481,0.7544181,5348,0.2186103092385375 -687.5585751533508,0.1763753890991211,5803.595576286316,7316,0,5803.595576286316,0.41549453,2472,0.1346251497978997,6491.611656188965,0.36748856,0.1257736493208641,0.69803226,5348,0.2021008525058652 -813.9893782138824,0.2295944690704345,7243.767488241196,9137,0,7243.767488241196,0.39312792,2472,0.1298925517437491,8058.339680671692,0.33860832,0.1189414860075235,0.66417503,5348,0.1937881962211688 -942.3130850791932,0.2782130241394043,8684.250477790833,10948,0,8684.250477790833,0.3835853,2472,0.1250787073710722,9627.266298294067,0.35432062,0.1211623304223503,0.6541005,5348,0.1902642478542533 -1070.5347871780396,0.3274891376495361,10124.378967046738,12763,0,10124.378967046738,0.3579518,2472,0.1162634818109804,11195.741937160492,0.31071076,0.1082360198243039,0.62130105,5348,0.1810440541819129 -1199.4069213867188,0.3752939701080322,11564.656034946442,14541,0,11564.656034946442,0.35419095,2472,0.1150650986127191,12765.01008605957,0.2729434,0.0984684989209287,0.61381125,5348,0.1769987545497552 -1329.8330297470093,0.4251205921173095,13004.882545471191,16309,0,13004.882545471191,0.34476194,2472,0.1120183616679869,14335.784494638445,0.2645288,0.0924194947039139,0.59791833,5348,0.1743244156521235 -1458.9472970962524,0.4777882099151611,14444.854892969131,18091,0,14444.854892969131,0.34444848,2472,0.1112262100623565,15904.995854139328,0.26972863,0.0982425844787549,0.59498703,5348,0.1731272386726783 -1587.2185769081116,0.5276048183441162,15885.44274878502,19865,0,15885.44274878502,0.31974483,2472,0.1054374098673653,17473.980586767197,0.27147388,0.0936460509267828,0.5683497,5348,0.1661662338163878 -1714.918186187744,0.5763967037200928,17325.437792778015,21623,0,17325.437792778015,0.31626993,2472,0.1033250055856844,19041.79708814621,0.2707888,0.0946052180336085,0.5585864,5348,0.1652586964287438 -1846.3830785751345,0.6237497329711914,18766.05039000511,23367,0,18766.05039000511,0.30071133,2472,0.0974143359129039,20613.99510908127,0.2480218,0.0869533902664883,0.5418796,5348,0.1582880369193933 -1975.4561262130733,0.6766774654388428,20206.7496676445,25165,0,20206.7496676445,0.29173753,2472,0.0952613084719598,22183.8956758976,0.22527014,0.0796946919471306,0.525986,5348,0.1530455603077903 -2105.6693222522736,0.7321693897247314,21647.13750886917,26917,0,21647.13750886917,0.28797644,2472,0.0925598683809639,23754.626539707184,0.21498787,0.0771903283862385,0.51132315,5348,0.1494154107572144 -2236.7508721351624,0.7884025573730469,23088.148204803467,28676,0,23088.148204803467,0.27436754,2472,0.0901834135640728,25326.84839820861,0.20073771,0.070516819915926,0.49496615,5348,0.1445977388802533 -2367.8951456546783,0.8454592227935791,24528.5071554184,30457,0,24528.5071554184,0.2712866,2472,0.0887412914102329,26898.482776403427,0.21881136,0.0770118304732189,0.49151084,5348,0.1446073935333134 -2496.7887375354767,0.8944950103759766,25968.977142572403,32198,0,25968.977142572403,0.26078537,2472,0.0849227144395019,28467.96927213669,0.19692653,0.0684702575705782,0.472254,5348,0.1387084005136275 -2627.1548767089844,0.9515886306762696,27409.48743200302,33952,0,27409.48743200302,0.25536835,2472,0.0824853248837162,30038.97890615464,0.2030439,0.0699495040299174,0.46502018,5348,0.1353292719426127 -2757.4602587223053,1.0017600059509275,28849.65223646164,35692,0,28849.65223646164,0.24218376,2472,0.0774074299758292,31609.57398867607,0.18933187,0.0637221991325383,0.44784793,5348,0.1313419002288152 -2886.308877468109,1.0653839111328125,30289.932039260864,37500,0,30289.932039260864,0.2352704,2472,0.0771433794406191,33178.84275341034,0.14027981,0.0504955772167727,0.4387208,5348,0.1268235225967154 -3017.221733808517,1.116579532623291,31730.45033931732,39245,0,31730.45033931732,0.22939885,2472,0.074360693031097,34750.40091919899,0.15786321,0.055283943316931,0.42685777,5348,0.1245546791276055 -3147.7732717990875,1.1740765571594238,33170.74593114853,41018,0,33170.74593114853,0.22491564,2472,0.0739747730180976,36321.37967252731,0.19925556,0.0705835769650092,0.419292,5348,0.1226237485155971 -3276.707732439041,1.225447177886963,34610.90734243393,42794,0,34610.90734243393,0.22072789,2472,0.0729185708772571,37890.603238105774,0.20634504,0.0723615914553922,0.4132973,5348,0.1202390492097666 -3402.704583644867,1.2802519798278809,36051.12446498871,44559,0,36051.12446498871,0.22010094,2472,0.0721467308512583,39456.94681978226,0.24091095,0.08543968739580347,0.41089597,5348,0.11951495023026347 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index 67966b1ed..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,474 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,21.99945,33.24736,,,,,,,,,,,,,, -1,,,31.636915,1.985486121233576,30.702776,1.707309537831758,5348.0,30.705326,1.8779477179940285,2472.0,42.37398552894592,225.94245171546936,42.37398552894592,183.5684108734131,0.0,0.0 -100,0.8869887,5.9982686,,,,,,,,,,,,,, -200,0.9783815,5.8225994,,,,,,,,,,,,,, -300,1.0230471,5.7006454,,,,,,,,,,,,,, -400,0.782777,5.130528,,,,,,,,,,,,,, -500,1.3458927,4.349212,,,,,,,,,,,,,, -600,2.1326241,3.779672,,,,,,,,,,,,,, -700,2.548518,3.3797019,,,,,,,,,,,,,, -800,2.6040616,3.1172767,,,,,,,,,,,,,, -900,2.5833325,3.0388694,,,,,,,,,,,,,, -1000,2.4418616,2.7909038,,,,,,,,,,,,,, -1100,2.3157227,2.697493,,,,,,,,,,,,,, -1200,4.1586504,2.602493,,,,,,,,,,,,,, -1300,2.4649808,2.600618,,,,,,,,,,,,,, -1400,1.8749237,2.4949658,,,,,,,,,,,,,, -1500,1.6315514,2.3959925,,,,,,,,,,,,,, -1600,3.3813422,2.4005613,,,,,,,,,,,,,, -1700,2.7743611,2.2875893,,,,,,,,,,,,,, -1800,2.084383,2.2773626,,,,,,,,,,,,,, -1835,,,2.3297267,0.505952503552965,2.6970866,0.5646716935226933,5348.0,2.2324085,0.4973087156988199,2472.0,1482.5158276557922,1788.1893684864044,1482.5158276557922,305.5697865486145,0.0383594036102294,0.0 -1900,3.445948,2.2451217,,,,,,,,,,,,,, -2000,2.5891814,2.173324,,,,,,,,,,,,,, -2100,1.6982245,2.1572022,,,,,,,,,,,,,, -2200,3.6416674,2.1409588,,,,,,,,,,,,,, -2300,2.1780605,2.1232903,,,,,,,,,,,,,, -2400,2.7763333,2.0699658,,,,,,,,,,,,,, -2500,1.8598765,2.049362,,,,,,,,,,,,,, -2600,2.3865125,1.9726619,,,,,,,,,,,,,, -2700,1.997621,1.9723717,,,,,,,,,,,,,, -2800,4.1924224,2.0356529,,,,,,,,,,,,,, -2900,2.3602424,1.9288645,,,,,,,,,,,,,, -3000,1.8679129,1.8940705,,,,,,,,,,,,,, -3100,3.2585986,1.9188794,,,,,,,,,,,,,, -3200,2.0762749,1.8868645,,,,,,,,,,,,,, -3300,2.6445718,1.9077383,,,,,,,,,,,,,, -3400,2.7288659,1.9069352,,,,,,,,,,,,,, -3500,3.2173507,1.9191452,,,,,,,,,,,,,, -3600,2.399165,1.8308412,,,,,,,,,,,,,, -3675,,,0.5273757,0.1752337319873212,0.8686737,0.2495824362551531,5348.0,0.5609731,0.1808746166189344,2472.0,2923.0815935134888,3355.8438744544983,2923.0815935134888,432.54165267944336,0.083158254623413,0.0 -3700,2.482278,1.8176063,,,,,,,,,,,,,, -3800,2.9954524,1.8644382,,,,,,,,,,,,,, -3900,2.848104,1.8144151,,,,,,,,,,,,,, -4000,2.5182147,1.776022,,,,,,,,,,,,,, -4100,2.1595619,1.8048826,,,,,,,,,,,,,, -4200,2.6006663,1.7903899,,,,,,,,,,,,,, -4300,2.514251,1.7376808,,,,,,,,,,,,,, -4400,2.2862487,1.7928908,,,,,,,,,,,,,, -4500,2.2900221,1.686795,,,,,,,,,,,,,, -4600,3.3764236,1.7317027,,,,,,,,,,,,,, -4700,2.697171,1.7093822,,,,,,,,,,,,,, -4800,1.9253975,1.7600341,,,,,,,,,,,,,, -4900,1.8287364,1.7140915,,,,,,,,,,,,,, -5000,3.2298741,1.6950976,,,,,,,,,,,,,, -5100,2.1939309,1.6945498,,,,,,,,,,,,,, -5200,1.8368113,1.7725289,,,,,,,,,,,,,, -5300,2.589619,1.6878146,,,,,,,,,,,,,, -5400,3.982808,1.762145,,,,,,,,,,,,,, -5500,1.9799315,1.7261149,,,,,,,,,,,,,, -5515,,,0.39720902,0.1383947161735481,0.7544181,0.2186103092385375,5348.0,0.46814135,0.1503260008530863,2472.0,4362.980728149414,4923.29837346077,4362.980728149414,559.979326248169,0.1280710697174072,0.0 -5600,2.2388425,1.6286037,,,,,,,,,,,,,, -5700,3.1846285,1.6763899,,,,,,,,,,,,,, -5800,2.7300298,1.6284281,,,,,,,,,,,,,, -5900,2.1085835,1.6179647,,,,,,,,,,,,,, -6000,2.8462958,1.6739084,,,,,,,,,,,,,, -6100,3.57122,1.6492862,,,,,,,,,,,,,, -6200,2.6992369,1.6482743,,,,,,,,,,,,,, -6300,1.8671219,1.6618234,,,,,,,,,,,,,, -6400,1.972264,1.6344419,,,,,,,,,,,,,, -6500,2.1631472,1.6498303,,,,,,,,,,,,,, -6600,3.0630617,1.6152653,,,,,,,,,,,,,, -6700,3.8275018,1.61179,,,,,,,,,,,,,, -6800,5.3236227,1.5959084,,,,,,,,,,,,,, -6900,3.5598812,1.6069595,,,,,,,,,,,,,, -7000,2.2606456,1.6011993,,,,,,,,,,,,,, -7100,2.675988,1.679802,,,,,,,,,,,,,, -7200,2.9272628,1.6388372,,,,,,,,,,,,,, -7300,1.9107757,1.621007,,,,,,,,,,,,,, -7316,,,0.36748856,0.1257736493208641,0.69803226,0.2021008525058652,5348.0,0.41549453,0.1346251497978997,2472.0,5803.595576286316,6491.611656188965,5803.595576286316,687.5585751533508,0.1763753890991211,0.0 -7400,2.435545,1.5964221,,,,,,,,,,,,,, -7500,2.513138,1.6048719,,,,,,,,,,,,,, -7600,2.3266222,1.5628475,,,,,,,,,,,,,, -7700,2.0071943,1.5796508,,,,,,,,,,,,,, -7800,2.85105,1.5822678,,,,,,,,,,,,,, -7900,2.9599068,1.6277596,,,,,,,,,,,,,, -8000,3.163659,1.5767845,,,,,,,,,,,,,, -8100,2.465216,1.5699679,,,,,,,,,,,,,, -8200,2.368698,1.5950632,,,,,,,,,,,,,, -8300,3.3467917,1.5479842,,,,,,,,,,,,,, -8400,2.1610913,1.6312025,,,,,,,,,,,,,, -8500,2.9547133,1.6367118,,,,,,,,,,,,,, -8600,2.4890134,1.5734229,,,,,,,,,,,,,, -8700,3.375867,1.5468348,,,,,,,,,,,,,, -8800,2.1411166,1.6108894,,,,,,,,,,,,,, -8900,2.159539,1.6109785,,,,,,,,,,,,,, -9000,3.1400204,1.5747901,,,,,,,,,,,,,, -9100,2.8083215,1.5176582,,,,,,,,,,,,,, -9137,,,0.33860832,0.1189414860075235,0.66417503,0.1937881962211688,5348.0,0.39312792,0.1298925517437491,2472.0,7243.767488241196,8058.339680671692,7243.767488241196,813.9893782138824,0.2295944690704345,0.0 -9200,3.5119371,1.5806123,,,,,,,,,,,,,, -9300,3.2640758,1.4780608,,,,,,,,,,,,,, -9400,2.9220161,1.5602777,,,,,,,,,,,,,, -9500,2.7901964,1.5393778,,,,,,,,,,,,,, -9600,1.9752636,1.5276991,,,,,,,,,,,,,, -9700,2.8611066,1.6185577,,,,,,,,,,,,,, -9800,3.4441648,1.4646949,,,,,,,,,,,,,, -9900,2.288367,1.4978806,,,,,,,,,,,,,, -10000,3.6534097,1.6327777,,,,,,,,,,,,,, -10100,2.4951355,1.5693411,,,,,,,,,,,,,, -10200,2.788275,1.5569345,,,,,,,,,,,,,, -10300,3.1919954,1.4273444,,,,,,,,,,,,,, -10400,3.000966,1.5891182,,,,,,,,,,,,,, -10500,3.3975627,1.5273713,,,,,,,,,,,,,, -10600,1.9760627,1.5453533,,,,,,,,,,,,,, -10700,2.411465,1.5467721,,,,,,,,,,,,,, -10800,2.2230713,1.561932,,,,,,,,,,,,,, -10900,3.0051596,1.5884355,,,,,,,,,,,,,, -10948,,,0.35432062,0.1211623304223503,0.6541005,0.1902642478542533,5348.0,0.3835853,0.1250787073710722,2472.0,8684.250477790833,9627.266298294067,8684.250477790833,942.3130850791932,0.2782130241394043,0.0 -11000,2.0749483,1.5026238,,,,,,,,,,,,,, -11100,2.3583968,1.5812411,,,,,,,,,,,,,, -11200,3.0407934,1.5941935,,,,,,,,,,,,,, -11300,3.4077034,1.5244856,,,,,,,,,,,,,, -11400,1.9424003,1.5151093,,,,,,,,,,,,,, -11500,2.4004405,1.4683082,,,,,,,,,,,,,, -11600,3.041633,1.4980301,,,,,,,,,,,,,, -11700,2.8038144,1.511308,,,,,,,,,,,,,, -11800,2.923119,1.5052962,,,,,,,,,,,,,, -11900,3.4838054,1.5003338,,,,,,,,,,,,,, -12000,2.6439517,1.5153493,,,,,,,,,,,,,, -12100,2.3476696,1.4977672,,,,,,,,,,,,,, -12200,3.1399813,1.4862558,,,,,,,,,,,,,, -12300,3.3492014,1.6687077,,,,,,,,,,,,,, -12400,2.2376726,1.5158817,,,,,,,,,,,,,, -12500,3.1635392,1.5641502,,,,,,,,,,,,,, -12600,2.2550287,1.4411653,,,,,,,,,,,,,, -12700,2.0226493,1.4804528,,,,,,,,,,,,,, -12763,,,0.31071076,0.1082360198243039,0.62130105,0.1810440541819129,5348.0,0.3579518,0.1162634818109804,2472.0,10124.378967046738,11195.741937160492,10124.378967046738,1070.5347871780396,0.3274891376495361,0.0 -12800,4.754427,1.4726055,,,,,,,,,,,,,, -12900,4.1422243,1.5130202,,,,,,,,,,,,,, -13000,3.1807036,1.567217,,,,,,,,,,,,,, -13100,1.8874983,1.5121328,,,,,,,,,,,,,, -13200,2.8050873,1.5169803,,,,,,,,,,,,,, -13300,2.216116,1.4481767,,,,,,,,,,,,,, -13400,2.9972227,1.4914194,,,,,,,,,,,,,, -13500,2.8656974,1.4548367,,,,,,,,,,,,,, -13600,3.5880558,1.5149307,,,,,,,,,,,,,, -13700,4.0096893,1.4479359,,,,,,,,,,,,,, -13800,1.9719486,1.5073837,,,,,,,,,,,,,, -13900,2.337426,1.5018959,,,,,,,,,,,,,, -14000,2.8105214,1.4793838,,,,,,,,,,,,,, -14100,3.5051003,1.5137748,,,,,,,,,,,,,, -14200,2.670113,1.4821926,,,,,,,,,,,,,, -14300,3.351204,1.5604719,,,,,,,,,,,,,, -14400,2.2789183,1.5910581,,,,,,,,,,,,,, -14500,5.460126,1.4646022,,,,,,,,,,,,,, -14541,,,0.2729434,0.0984684989209287,0.61381125,0.1769987545497552,5348.0,0.35419095,0.1150650986127191,2472.0,11564.656034946442,12765.01008605957,11564.656034946442,1199.4069213867188,0.3752939701080322,0.0 -14600,1.8396542,1.4381161,,,,,,,,,,,,,, -14700,2.5799503,1.4635829,,,,,,,,,,,,,, -14800,3.2583296,1.5367028,,,,,,,,,,,,,, -14900,2.8270476,1.4855008,,,,,,,,,,,,,, -15000,3.1469896,1.4652553,,,,,,,,,,,,,, -15100,1.7736228,1.4845278,,,,,,,,,,,,,, -15200,3.1272278,1.4952239,,,,,,,,,,,,,, -15300,2.8471909,1.5055883,,,,,,,,,,,,,, -15400,4.6009736,1.516678,,,,,,,,,,,,,, -15500,2.003168,1.4661283,,,,,,,,,,,,,, -15600,1.9324163,1.4449013,,,,,,,,,,,,,, -15700,2.010022,1.4670197,,,,,,,,,,,,,, -15800,2.179329,1.4604492,,,,,,,,,,,,,, -15900,3.553894,1.4384384,,,,,,,,,,,,,, -16000,3.0688064,1.4650753,,,,,,,,,,,,,, -16100,2.7942295,1.4578992,,,,,,,,,,,,,, -16200,2.8639824,1.4347892,,,,,,,,,,,,,, -16300,2.5599627,1.4697318,,,,,,,,,,,,,, -16309,,,0.2645288,0.0924194947039139,0.59791833,0.1743244156521235,5348.0,0.34476194,0.1120183616679869,2472.0,13004.882545471191,14335.784494638445,13004.882545471191,1329.8330297470093,0.4251205921173095,0.0 -16400,2.5531526,1.534394,,,,,,,,,,,,,, -16500,4.693992,1.4828198,,,,,,,,,,,,,, -16600,3.0379784,1.4724253,,,,,,,,,,,,,, -16700,1.9504023,1.4320538,,,,,,,,,,,,,, -16800,4.8271127,1.4528942,,,,,,,,,,,,,, -16900,3.3980932,1.478587,,,,,,,,,,,,,, -17000,2.713086,1.4854732,,,,,,,,,,,,,, -17100,2.37041,1.4723568,,,,,,,,,,,,,, -17200,2.1387455,1.491452,,,,,,,,,,,,,, -17300,4.68761,1.511054,,,,,,,,,,,,,, -17400,3.7676125,1.4760708,,,,,,,,,,,,,, -17500,2.2496235,1.4679097,,,,,,,,,,,,,, -17600,3.802906,1.4440098,,,,,,,,,,,,,, -17700,3.0305383,1.4393306,,,,,,,,,,,,,, -17800,1.4276348,1.3632102,,,,,,,,,,,,,, -17900,2.334777,1.4326204,,,,,,,,,,,,,, -18000,3.2919018,1.5290762,,,,,,,,,,,,,, -18091,,,0.26972863,0.0982425844787549,0.59498703,0.1731272386726783,5348.0,0.34444848,0.1112262100623565,2472.0,14444.854892969131,15904.995854139328,14444.854892969131,1458.9472970962524,0.4777882099151611,0.0 -18100,2.942988,1.498286,,,,,,,,,,,,,, -18200,2.8240044,1.4652799,,,,,,,,,,,,,, -18300,2.9370308,1.4156246,,,,,,,,,,,,,, -18400,2.5857882,1.4500841,,,,,,,,,,,,,, -18500,5.604708,1.453555,,,,,,,,,,,,,, -18600,2.282981,1.4012187,,,,,,,,,,,,,, -18700,2.1204834,1.4145682,,,,,,,,,,,,,, -18800,2.6960707,1.4757941,,,,,,,,,,,,,, -18900,3.398049,1.4201844,,,,,,,,,,,,,, -19000,3.0799909,1.4523513,,,,,,,,,,,,,, -19100,3.2986438,1.4592348,,,,,,,,,,,,,, -19200,3.493085,1.438312,,,,,,,,,,,,,, -19300,3.528006,1.4599135,,,,,,,,,,,,,, -19400,2.1315851,1.477514,,,,,,,,,,,,,, -19500,3.414375,1.4323368,,,,,,,,,,,,,, -19600,3.326721,1.4391491,,,,,,,,,,,,,, -19700,2.6757197,1.4004332,,,,,,,,,,,,,, -19800,3.6791027,1.4730664,,,,,,,,,,,,,, -19865,,,0.27147388,0.0936460509267828,0.5683497,0.1661662338163878,5348.0,0.31974483,0.1054374098673653,2472.0,15885.44274878502,17473.980586767197,15885.44274878502,1587.2185769081116,0.5276048183441162,0.0 -19900,3.487421,1.3823198,,,,,,,,,,,,,, -20000,2.59879,1.3828678,,,,,,,,,,,,,, -20100,3.282346,1.4256577,,,,,,,,,,,,,, -20200,2.30958,1.4366412,,,,,,,,,,,,,, -20300,3.5616894,1.4163325,,,,,,,,,,,,,, -20400,2.5666225,1.482261,,,,,,,,,,,,,, -20500,3.649535,1.4628121,,,,,,,,,,,,,, -20600,2.615801,1.4719833,,,,,,,,,,,,,, -20700,2.7781494,1.4312905,,,,,,,,,,,,,, -20800,3.0327196,1.3956659,,,,,,,,,,,,,, -20900,2.8091593,1.4084615,,,,,,,,,,,,,, -21000,2.4590936,1.4171901,,,,,,,,,,,,,, -21100,4.1887946,1.4304631,,,,,,,,,,,,,, -21200,3.0978653,1.3578532,,,,,,,,,,,,,, -21300,2.4878018,1.4589728,,,,,,,,,,,,,, -21400,1.7543296,1.4273114,,,,,,,,,,,,,, -21500,2.4635258,1.3949859,,,,,,,,,,,,,, -21600,3.1842053,1.427052,,,,,,,,,,,,,, -21623,,,0.2707888,0.0946052180336085,0.5585864,0.1652586964287438,5348.0,0.31626993,0.1033250055856844,2472.0,17325.437792778015,19041.79708814621,17325.437792778015,1714.918186187744,0.5763967037200928,0.0 -21700,2.7431674,1.43487,,,,,,,,,,,,,, -21800,2.3095953,1.411173,,,,,,,,,,,,,, -21900,2.19631,1.4057537,,,,,,,,,,,,,, -22000,3.7126591,1.4628516,,,,,,,,,,,,,, -22100,3.000986,1.3811103,,,,,,,,,,,,,, -22200,7.6040726,1.4660579,,,,,,,,,,,,,, -22300,2.5642152,1.433729,,,,,,,,,,,,,, -22400,2.3453557,1.4061539,,,,,,,,,,,,,, -22500,4.4316196,1.4464285,,,,,,,,,,,,,, -22600,4.573266,1.4060386,,,,,,,,,,,,,, -22700,4.3036304,1.3346399,,,,,,,,,,,,,, -22800,2.353251,1.4632592,,,,,,,,,,,,,, -22900,3.6084561,1.4895235,,,,,,,,,,,,,, -23000,2.6667511,1.42706,,,,,,,,,,,,,, -23100,2.271,1.4136283,,,,,,,,,,,,,, -23200,3.2340586,1.4252292,,,,,,,,,,,,,, -23300,2.4503396,1.3658175,,,,,,,,,,,,,, -23367,,,0.2480218,0.0869533902664883,0.5418796,0.1582880369193933,5348.0,0.30071133,0.0974143359129039,2472.0,18766.05039000511,20613.99510908127,18766.05039000511,1846.3830785751345,0.6237497329711914,0.0 -23400,2.434761,1.4141998,,,,,,,,,,,,,, -23500,2.2009723,1.3734002,,,,,,,,,,,,,, -23600,2.2688498,1.3817838,,,,,,,,,,,,,, -23700,2.827133,1.3791772,,,,,,,,,,,,,, -23800,4.391348,1.4497974,,,,,,,,,,,,,, -23900,2.015046,1.4262999,,,,,,,,,,,,,, -24000,3.668645,1.3725939,,,,,,,,,,,,,, -24100,2.404187,1.392926,,,,,,,,,,,,,, -24200,2.7786882,1.3357687,,,,,,,,,,,,,, -24300,1.785115,1.4069978,,,,,,,,,,,,,, -24400,2.1185844,1.5064368,,,,,,,,,,,,,, -24500,2.7712672,1.3731779,,,,,,,,,,,,,, -24600,1.9916104,1.3685635,,,,,,,,,,,,,, -24700,3.4136727,1.3659565,,,,,,,,,,,,,, -24800,3.0341272,1.4131362,,,,,,,,,,,,,, -24900,2.094389,1.3657886,,,,,,,,,,,,,, -25000,2.0770588,1.3776239,,,,,,,,,,,,,, -25100,3.137873,1.3798317,,,,,,,,,,,,,, -25165,,,0.22527014,0.0796946919471306,0.525986,0.1530455603077903,5348.0,0.29173753,0.0952613084719598,2472.0,20206.7496676445,22183.8956758976,20206.7496676445,1975.4561262130733,0.6766774654388428,0.0 -25200,3.1138616,1.3880352,,,,,,,,,,,,,, -25300,2.4853027,1.3965956,,,,,,,,,,,,,, -25400,2.3042455,1.3838232,,,,,,,,,,,,,, -25500,3.0179396,1.3981231,,,,,,,,,,,,,, -25600,8.921978,1.3802267,,,,,,,,,,,,,, -25700,1.9943637,1.364481,,,,,,,,,,,,,, -25800,2.242098,1.3513734,,,,,,,,,,,,,, -25900,3.321183,1.3731,,,,,,,,,,,,,, -26000,2.4396486,1.3839257,,,,,,,,,,,,,, -26100,3.130766,1.399545,,,,,,,,,,,,,, -26200,2.0342448,1.293035,,,,,,,,,,,,,, -26300,2.3610291,1.4333096,,,,,,,,,,,,,, -26400,2.7381139,1.3222277,,,,,,,,,,,,,, -26500,2.4882863,1.2874614,,,,,,,,,,,,,, -26600,2.6213884,1.3789241,,,,,,,,,,,,,, -26700,2.6937099,1.3944492,,,,,,,,,,,,,, -26800,2.5252829,1.3319435,,,,,,,,,,,,,, -26900,3.647032,1.3295703,,,,,,,,,,,,,, -26917,,,0.21498787,0.0771903283862385,0.51132315,0.1494154107572144,5348.0,0.28797644,0.0925598683809639,2472.0,21647.13750886917,23754.626539707184,21647.13750886917,2105.6693222522736,0.7321693897247314,0.0 -27000,3.3045108,1.3312782,,,,,,,,,,,,,, -27100,4.357745,1.3277199,,,,,,,,,,,,,, -27200,3.1633496,1.3730597,,,,,,,,,,,,,, -27300,2.5736263,1.3709664,,,,,,,,,,,,,, -27400,3.4019954,1.3900145,,,,,,,,,,,,,, -27500,1.4314619,1.3245143,,,,,,,,,,,,,, -27600,2.7335992,1.3351921,,,,,,,,,,,,,, -27700,4.419072,1.4392723,,,,,,,,,,,,,, -27800,1.9906461,1.3506745,,,,,,,,,,,,,, -27900,2.1106563,1.3216585,,,,,,,,,,,,,, -28000,1.8388207,1.3600026,,,,,,,,,,,,,, -28100,2.1905,1.3467658,,,,,,,,,,,,,, -28200,2.2535746,1.3466808,,,,,,,,,,,,,, -28300,2.027454,1.2935737,,,,,,,,,,,,,, -28400,5.1608634,1.2868615,,,,,,,,,,,,,, -28500,3.2741766,1.3501215,,,,,,,,,,,,,, -28600,2.8537724,1.3314501,,,,,,,,,,,,,, -28676,,,0.20073771,0.070516819915926,0.49496615,0.1445977388802533,5348.0,0.27436754,0.0901834135640728,2472.0,23088.148204803467,25326.84839820861,23088.148204803467,2236.7508721351624,0.7884025573730469,0.0 -28700,2.3003702,1.3184584,,,,,,,,,,,,,, -28800,3.3990285,1.3141819,,,,,,,,,,,,,, -28900,1.9544914,1.2801565,,,,,,,,,,,,,, -29000,2.5642805,1.3169264,,,,,,,,,,,,,, -29100,3.3136935,1.320023,,,,,,,,,,,,,, -29200,2.3952045,1.3347464,,,,,,,,,,,,,, -29300,7.443598,1.3176751,,,,,,,,,,,,,, -29400,4.812481,1.2701033,,,,,,,,,,,,,, -29500,3.5994477,1.311686,,,,,,,,,,,,,, -29600,2.3608663,1.3791614,,,,,,,,,,,,,, -29700,1.8917392,1.3139113,,,,,,,,,,,,,, -29800,6.781118,1.3532283,,,,,,,,,,,,,, -29900,2.9433854,1.2931204,,,,,,,,,,,,,, -30000,2.0378559,1.3103029,,,,,,,,,,,,,, -30100,5.4033775,1.2161309,,,,,,,,,,,,,, -30200,2.2089043,1.2645146,,,,,,,,,,,,,, -30300,4.7114186,1.2869909,,,,,,,,,,,,,, -30400,3.2883723,1.3153478,,,,,,,,,,,,,, -30457,,,0.21881136,0.0770118304732189,0.49151084,0.1446073935333134,5348.0,0.2712866,0.0887412914102329,2472.0,24528.5071554184,26898.482776403427,24528.5071554184,2367.8951456546783,0.8454592227935791,0.0 -30500,2.2009425,1.2848951,,,,,,,,,,,,,, -30600,3.556293,1.3050934,,,,,,,,,,,,,, -30700,6.796073,1.2793307,,,,,,,,,,,,,, -30800,2.6533604,1.2863117,,,,,,,,,,,,,, -30900,4.636167,1.3098403,,,,,,,,,,,,,, -31000,1.8267485,1.3448744,,,,,,,,,,,,,, -31100,5.166768,1.2407348,,,,,,,,,,,,,, -31200,3.9947972,1.2778585,,,,,,,,,,,,,, -31300,4.1404004,1.3303047,,,,,,,,,,,,,, -31400,2.0259235,1.3589281,,,,,,,,,,,,,, -31500,3.3525693,1.2894233,,,,,,,,,,,,,, -31600,2.0955358,1.27322,,,,,,,,,,,,,, -31700,1.8672494,1.2735715,,,,,,,,,,,,,, -31800,1.8140059,1.3669977,,,,,,,,,,,,,, -31900,5.908447,1.2470517,,,,,,,,,,,,,, -32000,3.8843608,1.2454153,,,,,,,,,,,,,, -32100,8.131752,1.294619,,,,,,,,,,,,,, -32198,,,0.19692653,0.0684702575705782,0.472254,0.1387084005136275,5348.0,0.26078537,0.0849227144395019,2472.0,25968.977142572403,28467.96927213669,25968.977142572403,2496.7887375354767,0.8944950103759766,0.0 -32200,3.421142,1.3020971,,,,,,,,,,,,,, -32300,1.9658281,1.2735523,,,,,,,,,,,,,, -32400,2.4439967,1.3387766,,,,,,,,,,,,,, -32500,3.2607334,1.2786112,,,,,,,,,,,,,, -32600,2.301806,1.3525217,,,,,,,,,,,,,, -32700,2.6004136,1.2970431,,,,,,,,,,,,,, -32800,3.4272704,1.2818012,,,,,,,,,,,,,, -32900,2.1728418,1.3022951,,,,,,,,,,,,,, -33000,2.1334996,1.253115,,,,,,,,,,,,,, -33100,2.3515234,1.204372,,,,,,,,,,,,,, -33200,2.3367972,1.2631887,,,,,,,,,,,,,, -33300,3.8744605,1.2895916,,,,,,,,,,,,,, -33400,3.0860367,1.2119125,,,,,,,,,,,,,, -33500,4.1734815,1.1868742,,,,,,,,,,,,,, -33600,2.470018,1.222341,,,,,,,,,,,,,, -33700,4.410639,1.243658,,,,,,,,,,,,,, -33800,2.856041,1.2981019,,,,,,,,,,,,,, -33900,2.061318,1.2011378,,,,,,,,,,,,,, -33952,,,0.2030439,0.0699495040299174,0.46502018,0.1353292719426127,5348.0,0.25536835,0.0824853248837162,2472.0,27409.48743200302,30038.97890615464,27409.48743200302,2627.1548767089844,0.9515886306762696,0.0 -34000,4.0293,1.2814077,,,,,,,,,,,,,, -34100,1.7357892,1.3025317,,,,,,,,,,,,,, -34200,2.776603,1.2424678,,,,,,,,,,,,,, -34300,2.216663,1.2512432,,,,,,,,,,,,,, -34400,1.9649324,1.2713621,,,,,,,,,,,,,, -34500,3.80125,1.2562503,,,,,,,,,,,,,, -34600,2.733254,1.2275083,,,,,,,,,,,,,, -34700,5.289,1.205434,,,,,,,,,,,,,, -34800,2.8706167,1.1878308,,,,,,,,,,,,,, -34900,2.051632,1.2596428,,,,,,,,,,,,,, -35000,2.0355504,1.199499,,,,,,,,,,,,,, -35100,1.8018476,1.1810579,,,,,,,,,,,,,, -35200,2.5499709,1.243882,,,,,,,,,,,,,, -35300,2.7545474,1.1782627,,,,,,,,,,,,,, -35400,11.038237,1.3189389,,,,,,,,,,,,,, -35500,1.9803619,1.2297575,,,,,,,,,,,,,, -35600,1.7717441,1.1984735,,,,,,,,,,,,,, -35692,,,0.18933187,0.0637221991325383,0.44784793,0.1313419002288152,5348.0,0.24218376,0.0774074299758292,2472.0,28849.65223646164,31609.57398867607,28849.65223646164,2757.4602587223053,1.0017600059509275,0.0 -35700,2.5206017,1.2237792,,,,,,,,,,,,,, -35800,2.4375913,1.2949818,,,,,,,,,,,,,, -35900,3.3213181,1.2091141,,,,,,,,,,,,,, -36000,3.147571,1.1661754,,,,,,,,,,,,,, -36100,3.8081064,1.2296053,,,,,,,,,,,,,, -36200,3.4016361,1.2245622,,,,,,,,,,,,,, -36300,2.7601671,1.2055633,,,,,,,,,,,,,, -36400,3.6995714,1.1920897,,,,,,,,,,,,,, -36500,1.9200448,1.2087395,,,,,,,,,,,,,, -36600,1.7000884,1.2110795,,,,,,,,,,,,,, -36700,1.9253305,1.1722034,,,,,,,,,,,,,, -36800,3.3326068,1.1752055,,,,,,,,,,,,,, -36900,3.1762412,1.2358154,,,,,,,,,,,,,, -37000,2.8380444,1.2281898,,,,,,,,,,,,,, -37100,4.0972123,1.248226,,,,,,,,,,,,,, -37200,2.3653116,1.2050437,,,,,,,,,,,,,, -37300,2.7679687,1.191444,,,,,,,,,,,,,, -37400,1.8346462,1.1814061,,,,,,,,,,,,,, -37500,,,0.14027981,0.0504955772167727,0.4387208,0.1268235225967154,5348.0,0.2352704,0.0771433794406191,2472.0,30289.932039260864,33178.84275341034,30289.932039260864,2886.308877468109,1.0653839111328125,0.0 -37500,1.7872049,1.1915329,,,,,,,,,,,,,, -37600,1.5075938,1.2071462,,,,,,,,,,,,,, -37700,2.415625,1.2363813,,,,,,,,,,,,,, -37800,1.8822871,1.23398,,,,,,,,,,,,,, -37900,5.541638,1.1880784,,,,,,,,,,,,,, -38000,3.8807998,1.2057794,,,,,,,,,,,,,, -38100,2.3654997,1.1653577,,,,,,,,,,,,,, -38200,2.2098365,1.1948965,,,,,,,,,,,,,, -38300,1.6054435,1.2335348,,,,,,,,,,,,,, -38400,3.6041632,1.2441117,,,,,,,,,,,,,, -38500,2.3920338,1.2467871,,,,,,,,,,,,,, -38600,3.2229044,1.2043135,,,,,,,,,,,,,, -38700,3.1809049,1.1749194,,,,,,,,,,,,,, -38800,2.93359,1.2108164,,,,,,,,,,,,,, -38900,7.0848923,1.180996,,,,,,,,,,,,,, -39000,2.1033278,1.1941715,,,,,,,,,,,,,, -39100,1.7957195,1.1223266,,,,,,,,,,,,,, -39200,2.1434672,1.1861808,,,,,,,,,,,,,, -39245,,,0.15786321,0.055283943316931,0.42685777,0.1245546791276055,5348.0,0.22939885,0.074360693031097,2472.0,31730.45033931732,34750.40091919899,31730.45033931732,3017.221733808517,1.116579532623291,0.0 -39300,2.2754772,1.1942147,,,,,,,,,,,,,, -39400,2.14765,1.1949866,,,,,,,,,,,,,, -39500,1.6793436,1.168385,,,,,,,,,,,,,, -39600,2.567636,1.1565905,,,,,,,,,,,,,, -39700,4.635197,1.2418997,,,,,,,,,,,,,, -39800,2.5227609,1.1666756,,,,,,,,,,,,,, -39900,1.9481395,1.1812189,,,,,,,,,,,,,, -40000,2.3941352,1.1383352,,,,,,,,,,,,,, -40100,2.41834,1.1646793,,,,,,,,,,,,,, -40200,3.3845313,1.221494,,,,,,,,,,,,,, -40300,5.932262,1.2112868,,,,,,,,,,,,,, -40400,2.6993048,1.2316585,,,,,,,,,,,,,, -40500,2.7863467,1.177277,,,,,,,,,,,,,, -40600,10.465865,1.0543464,,,,,,,,,,,,,, -40700,2.0331023,1.1093928,,,,,,,,,,,,,, -40800,3.107963,1.2270889,,,,,,,,,,,,,, -40900,1.5672371,1.2171155,,,,,,,,,,,,,, -41000,2.0231302,1.1610554,,,,,,,,,,,,,, -41018,,,0.19925556,0.0705835769650092,0.419292,0.1226237485155971,5348.0,0.22491564,0.0739747730180976,2472.0,33170.74593114853,36321.37967252731,33170.74593114853,3147.7732717990875,1.1740765571594238,0.0 -41100,2.0496237,1.1734816,,,,,,,,,,,,,, -41200,4.186741,1.2048783,,,,,,,,,,,,,, -41300,1.6034528,1.1675397,,,,,,,,,,,,,, -41400,2.6612198,1.1431853,,,,,,,,,,,,,, -41500,2.048598,1.1659325,,,,,,,,,,,,,, -41600,2.0470076,1.1500262,,,,,,,,,,,,,, -41700,2.5769274,1.1105245,,,,,,,,,,,,,, -41800,3.1405032,1.1777635,,,,,,,,,,,,,, -41900,3.9676309,1.1845723,,,,,,,,,,,,,, -42000,2.9438205,1.174678,,,,,,,,,,,,,, -42100,3.3840253,1.1282344,,,,,,,,,,,,,, -42200,2.599609,1.1962792,,,,,,,,,,,,,, -42300,3.4071176,1.157556,,,,,,,,,,,,,, -42400,1.9754975,1.1283048,,,,,,,,,,,,,, -42500,4.143108,1.0922157,,,,,,,,,,,,,, -42600,2.4883657,1.1187086,,,,,,,,,,,,,, -42700,4.6243644,1.1640884,,,,,,,,,,,,,, -42794,,,0.20634504,0.0723615914553922,0.4132973,0.1202390492097666,5348.0,0.22072789,0.0729185708772571,2472.0,34610.90734243393,37890.603238105774,34610.90734243393,3276.707732439041,1.225447177886963,0.0 -42800,2.2111635,1.125945,,,,,,,,,,,,,, -42900,4.8718767,1.1413716,,,,,,,,,,,,,, -43000,3.8051794,1.1463292,,,,,,,,,,,,,, -43100,4.5588865,1.128494,,,,,,,,,,,,,, -43200,2.2360988,1.1430886,,,,,,,,,,,,,, -43300,2.043672,1.1093355,,,,,,,,,,,,,, -43400,2.2164874,1.1202106,,,,,,,,,,,,,, -43500,3.202901,1.1284233,,,,,,,,,,,,,, -43600,2.132236,1.1200333,,,,,,,,,,,,,, -43700,4.3116155,1.1135157,,,,,,,,,,,,,, -43800,2.5288544,1.1618009,,,,,,,,,,,,,, -43900,2.6437573,1.1110556,,,,,,,,,,,,,, -44000,4.2198367,1.1651617,,,,,,,,,,,,,, -44100,2.6204076,1.1855907,,,,,,,,,,,,,, -44200,2.2931273,1.1661978,,,,,,,,,,,,,, -44300,2.9406683,1.1649412,,,,,,,,,,,,,, -44400,3.268897,1.1014454,,,,,,,,,,,,,, -44500,2.384972,1.0929598,,,,,,,,,,,,,, -44559,,,0.24091095,0.0854396873958034,0.41089597,0.1195149502302634,5348.0,0.22010094,0.0721467308512583,2472.0,36051.12446498871,39456.94681978226,36051.12446498871,3402.704583644867,1.2802519798278809,0.0 -44559,,,,,,,,,,,36051.12446498871,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 756ccbbf6..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,232 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -529.7613813877106,0.0,19.528467893600464,1,0,19.528467893600464,0.5401853322982788,0.7146537899971008,0.0280191991350671,43793,549.2898893356323,0.5404419302940369,0.7089899182319641,0.022265429341623,0.5404320955276489,0.7129108905792236,0.0261747405043297,43793 -648.1611218452454,0.0277159214019775,259.61792516708374,748,0,259.61792516708374,0.983219563961029,0.0641781017184257,0.0540834326521347,43793,907.8264088630676,0.986663281917572,0.0516927912831306,0.0518440765289784,0.9842145442962646,0.0608852803707122,0.0512608568202372,43793 -770.0165569782257,0.0568201541900634,499.6568791866303,1491,0,499.6568791866303,0.9837393164634703,0.0582070387899875,0.1037207408094717,43793,1269.7700910568235,0.9875317811965942,0.0461782217025756,0.102587681895728,0.9847463369369508,0.0551133230328559,0.1048863280755017,43793 -895.7120683193207,0.084691047668457,739.8452451229095,2235,0,739.8452451229095,0.984077513217926,0.0562639608979225,0.1320783047225885,43793,1635.7016966342926,0.9879155158996582,0.043444886803627,0.1422481736933233,0.9850674271583556,0.0531889759004116,0.1327050623244102,43793 -1019.7300884723664,0.1144242286682128,980.0934472084044,2973,0,980.0934472084044,0.9843037128448486,0.0542911104857921,0.1567664400875005,43793,2000.018956899643,0.9879909753799438,0.0420600734651088,0.1689640605027597,0.985255002975464,0.0514483675360679,0.1546322777983081,43793 -1146.2826611995697,0.1429240703582763,1220.1030368804932,3716,0,1220.1030368804932,0.9846242666244508,0.0522505082190036,0.1775325516586747,43793,2366.63014125824,0.9885411262512208,0.0397534817457199,0.2006911287897099,0.9854920506477356,0.0495907105505466,0.1726104452563086,43793 -1273.235831975937,0.1699335575103759,1460.2951698303225,4455,0,1460.2951698303225,0.9848222136497498,0.0510275997221469,0.189843762419542,43793,2733.823101758957,0.988589644432068,0.0388961508870124,0.2133832001783586,0.9857372045516968,0.0483437851071357,0.1880250188182171,43793 -1401.9401717185974,0.1983261108398437,1700.3518633842468,5192,0,1700.3518633842468,0.985047161579132,0.0498691536486148,0.2044275570064991,43793,3102.633382081985,0.9890533089637756,0.0369926616549491,0.253568965894043,0.9858285784721376,0.0473449490964412,0.203995115895378,43793 -1527.7839286327362,0.2261953353881836,1940.53954744339,5927,0,1940.53954744339,0.9852859377861024,0.0491692684590816,0.2194756435087781,43793,3468.713948726654,0.9893468022346495,0.0361139774322509,0.2782895930164883,0.9861021637916564,0.0466593876481056,0.2093368453462271,43793 -1652.4074277877808,0.2549798488616943,2180.5288002491,6660,0,2180.5288002491,0.9852400422096252,0.0492401048541069,0.2220660015979791,43793,3833.375732898712,0.9893681406974792,0.0361036546528339,0.2615927564089334,0.986129343509674,0.0465039052069187,0.2194878491913071,43793 -1781.4448142051697,0.283771276473999,2420.737658262253,7396,0,2420.737658262253,0.9854978322982788,0.0484322234988212,0.231361468160634,43793,4202.670556783676,0.9895302057266236,0.0353137105703353,0.3048981406299598,0.9863741397857666,0.0459171384572982,0.2346207459164243,43793 -1909.870851516724,0.3127274513244629,2660.835664749145,8119,0,2660.835664749145,0.9853870272636414,0.0484150499105453,0.2312606216910576,43793,4571.243372917175,0.9896426796913148,0.0347865670919418,0.297149318468285,0.9862799644470216,0.0459003075957298,0.2356317973826318,43793 -2040.66327047348,0.3399653434753418,2900.8369381427765,8871,0,2900.8369381427765,0.985410213470459,0.0479228235781192,0.2439316950400646,43793,4942.084951400757,0.9896151423454284,0.0344796739518642,0.3239638788346503,0.9862819910049438,0.0454410649836063,0.2402566717270283,43793 -2168.9167437553406,0.3717496395111084,3141.071623802185,9603,0,3141.071623802185,0.9856237769126892,0.0479062609374523,0.2502284643444161,43793,5310.628950357437,0.9900298118591307,0.0330846048891544,0.339935185024856,0.9865061044692992,0.0452296324074268,0.2436097289388764,43793 -2294.598527908325,0.4007346630096435,3381.203540802002,10355,0,3381.203540802002,0.985651969909668,0.0473545342683792,0.246920308466755,43793,5676.491842985153,0.9901762008666992,0.0326098129153251,0.3561634795300548,0.9865044355392456,0.0447402521967887,0.2490307962152749,43793 -2424.878109931946,0.4318912029266357,3621.350810050965,11106,0,3621.350810050965,0.9858317971229552,0.0472535192966461,0.2559278961595786,43793,6046.970125198364,0.9905160665512084,0.0314555391669273,0.3813649471138655,0.9866416454315186,0.0444622375071048,0.2543336478332198,43793 -2553.246810436249,0.4600224494934082,3861.3633439540863,11838,0,3861.3633439540863,0.985789716243744,0.0473696626722812,0.2570963040451753,43793,6415.400445222855,0.9904680848121644,0.0314260125160217,0.3899980857227915,0.9865442514419556,0.0448398254811763,0.2530719494491042,43793 -2684.271157026291,0.5138566493988037,4101.431351184845,12583,0,4101.431351184845,0.9858494997024536,0.0471902675926685,0.2632409372541082,43793,6786.566886663437,0.990330457687378,0.0318564511835575,0.3694480247887148,0.9867208003997804,0.0444731898605823,0.259595101609442,43793 -2815.033797264099,0.5426039695739746,4341.640905618668,13332,0,4341.640905618668,0.9859135150909424,0.0467578880488872,0.2592363877832236,43793,7157.587838172913,0.990576148033142,0.0312654264271259,0.3798342046666803,0.9867805242538452,0.0439825765788555,0.2672474720559514,43793 -2943.0571522712708,0.571929931640625,4581.922067165375,14079,0,4581.922067165375,0.985912263393402,0.0469102598726749,0.2601537864182929,43793,7525.94166302681,0.9905685782432556,0.0311376955360174,0.3765914331016559,0.9868125915527344,0.0441166646778583,0.2629989339221388,43793 -3073.758298397064,0.6007485389709473,4822.095349788666,14815,0,4822.095349788666,0.9859982132911682,0.0468400716781616,0.2640174170571049,43793,7896.865571737289,0.990611970424652,0.0308564845472574,0.3953191977620491,0.9868409633636476,0.0441383384168148,0.2656919440829268,43793 -3207.43451666832,0.6317923069000244,5062.213956356049,15547,0,5062.213956356049,0.9859632253646852,0.0472245104610919,0.260707375649573,43793,8270.714021921158,0.99070805311203,0.0302942767739295,0.4081647168737971,0.9868353009223938,0.0444324798882007,0.261461695531231,43793 -3334.1568129062653,0.661334753036499,5302.277789592743,16279,0,5302.277789592743,0.9859278798103333,0.0468454733490943,0.2631692741851291,43793,8637.550012588501,0.9908111095428468,0.0298472344875335,0.415534257377386,0.9867788553237916,0.0441870130598545,0.2656293101641374,43793 -3462.667748451233,0.6905078887939453,5542.288363218308,17025,0,5542.288363218308,0.9859012961387634,0.0468287318944931,0.2625971975460897,43793,9006.121246814728,0.9908833503723145,0.0295868404209613,0.4339656447805606,0.9867642521858216,0.0441272146999836,0.2659126980593257,43793 -3590.7058432102203,0.7197377681732178,5782.3681898117065,17766,0,5782.3681898117065,0.9859914779663086,0.046521496027708,0.2605604429493556,43793,9374.288189411163,0.9913222193717957,0.0284822024405002,0.4490868252663456,0.9868288040161132,0.0438111871480941,0.2682480897307769,43793 -3718.7504889965057,0.7523410320281982,6022.423826932907,18490,0,6022.423826932907,0.986002802848816,0.0469948165118694,0.26194694370496,43793,9742.443471431732,0.9911221265792848,0.0291727930307388,0.4280779237068718,0.9867683053016664,0.0443317145109176,0.2622161551066213,43793 -3847.483499765396,0.7836909294128418,6262.385018110275,19233,0,6262.385018110275,0.9859328866004944,0.0467655733227729,0.262611156496814,43793,10111.189734220505,0.9910096526145936,0.0293345991522073,0.4320504665074334,0.986805260181427,0.0439628921449184,0.2754075840554567,43793 -3975.6875672340393,0.813539981842041,6502.49117732048,19976,0,6502.49117732048,0.9860508441925048,0.0472310557961463,0.2681115070087855,43793,10479.550280809402,0.9907341599464417,0.0300019308924675,0.4278425982866988,0.986909568309784,0.0444412007927894,0.2761037932319848,43793 -4102.936810493469,0.8434133529663086,6742.673607110977,20722,0,6742.673607110977,0.9859228134155272,0.0470405109226703,0.2631733995044497,43793,10847.032025575638,0.9907103776931764,0.0302657280117273,0.4083187747450125,0.9867829084396362,0.0442190542817115,0.2743305827789721,43793 -4232.218222379684,0.8738067150115967,6982.719746828079,21465,0,6982.719746828079,0.9860706329345704,0.0468168444931507,0.2728187773762324,43793,11216.410384893416,0.9911373853683472,0.028706619516015,0.4494239762717231,0.9869254231452942,0.0439803414046764,0.2773801293033675,43793 -4357.820563793182,0.9055871963500975,7222.9662935733795,22210,0,7222.9662935733795,0.985990583896637,0.0468807965517044,0.2618272747604446,43793,11582.311628341677,0.9912395477294922,0.0284947175532579,0.4497181502557184,0.9868494868278505,0.0440571866929531,0.2787659404410005,43793 -4488.238070964813,0.937244176864624,7462.987103223801,22951,0,7462.987103223801,0.9859901666641236,0.0469595305621624,0.2666838640350755,43793,11952.801618337631,0.9914074540138244,0.0279013980180025,0.4588849204040839,0.9868494868278505,0.0440556108951568,0.2763261030184993,43793 -4616.471472978592,0.9694650173187256,7703.116607427597,23693,0,7703.116607427597,0.9859842658042908,0.0466601550579071,0.265365051821998,43793,12321.216604471208,0.9915399551391602,0.0275176540017128,0.4827089102235441,0.9867687225341796,0.0440514981746673,0.2806043098786221,43793 -4746.633130073547,1.000164270401001,7943.241847038269,24437,0,7943.241847038269,0.9860904216766356,0.0465195775032043,0.2728047702282302,43793,12691.554522037506,0.9916155338287354,0.027268037199974,0.4807372090344972,0.9868682026863098,0.0439805127680301,0.2749072176084976,43793 -4877.8257603645325,1.0311439037322998,8183.310019493103,25167,0,8183.310019493103,0.9861481189727784,0.0468418523669242,0.2706147259221962,43793,13062.866502285004,0.9914122819900512,0.0279057696461677,0.4571382522024744,0.9869388341903688,0.0440827757120132,0.280402054904702,43793 -5004.138481140137,1.0636134147644043,8423.271643161774,25896,0,8423.271643161774,0.9862020611763,0.0467101261019706,0.2640981604920295,43793,13429.19395661354,0.991398811340332,0.0279161781072616,0.4598515014503374,0.9869047403335572,0.0441686175763607,0.2727832495909499,43793 -5132.878929376602,1.0949442386627195,8663.474416017532,26634,0,8663.474416017532,0.9862193465232848,0.0467636249959468,0.2770079309687364,43793,13798.188932180405,0.9914059638977052,0.0278588086366653,0.4804148735269029,0.987074375152588,0.0439664088189601,0.2818839899912088,43793 -5258.856852293015,1.128938913345337,8903.555037021637,27370,0,8903.555037021637,0.9862067103385924,0.0468411184847354,0.2719546221242722,43793,14164.301958322523,0.9915772080421448,0.0271905399858951,0.4743581862163639,0.9870736002922058,0.0438411273062229,0.2876096184277866,43793 -5389.0773758888245,1.1617298126220703,9143.596687078476,28112,0,9143.596687078476,0.9860523343086244,0.0470948740839958,0.2704405059868041,43793,14534.617344141006,0.9914939999580384,0.0273632034659385,0.4776075046106587,0.9868556261062622,0.0443020462989807,0.2781089525019385,43793 -5516.761331558228,1.194563627243042,9383.813615322111,28846,0,9383.813615322111,0.98611319065094,0.0470629222691059,0.2691851735216722,43793,14902.571242570875,0.9917556643486024,0.0266566574573516,0.4946562271208262,0.9868446588516236,0.0443001873791217,0.2766937713708444,43793 -5643.975158452988,1.228114128112793,9623.958882570269,29584,0,9623.958882570269,0.9861321449279784,0.047050341963768,0.2701175138067016,43793,15269.985672235489,0.9919654130935668,0.0259400550276041,0.5117402117318774,0.9869948625564576,0.0441078841686248,0.2880325126914176,43793 -5767.308829545975,1.2603094577789309,9864.012751102448,30326,0,9864.012751102448,0.9860247373580932,0.04674918577075,0.2688654699671372,43793,15633.425725221634,0.9920579195022584,0.0257496107369661,0.5215447573111239,0.9868953824043274,0.0439697168767452,0.2858635197434587,43793 -5892.495110750198,1.2933599948883057,10104.057809352877,31067,0,10104.057809352877,0.9861097931861876,0.0469236336648464,0.2746272969201792,43793,15998.710246562958,0.991892635822296,0.0262051355093717,0.5086030748444564,0.986970067024231,0.0440684333443641,0.2834956332839969,43793 -6019.689756393433,1.3267717361450195,10344.097812891006,31816,0,10344.097812891006,0.9861460328102112,0.0468500703573226,0.2760263600114768,43793,16365.99844264984,0.9916676878929138,0.026850014925003,0.4904087292640114,0.9869644045829772,0.0441396757960319,0.2862625865048266,43793 -6147.57537651062,1.3592548370361328,10584.372139453888,32560,0,10584.372139453888,0.9861759543418884,0.0464902445673942,0.2793258645120916,43793,16734.21134352684,0.9917237758636476,0.0267841536551713,0.493660033538621,0.9869781732559204,0.0438555255532264,0.290015872842108,43793 -6275.525854349136,1.3940739631652832,10824.466272115707,33296,0,10824.466272115707,0.9861931800842284,0.0470806285738945,0.2709264239678933,43793,17102.31195449829,0.9918603897094728,0.0262102521955966,0.4982211958330765,0.9870301485061646,0.0441683307290077,0.287087614806626,43793 -6399.477682113648,1.4283545017242432,11064.580405950546,34029,0,11064.580405950546,0.9860681295394896,0.0469147376716136,0.2748894485436163,43793,17466.4331305027,0.9919319152832032,0.0259202476590871,0.5161529871779962,0.9869481325149536,0.0443512499332428,0.2823631179528553,43793 -6529.37206697464,1.4609463214874268,11304.724274158478,34779,0,11304.724274158478,0.986177623271942,0.0469932220876216,0.2721970894810648,43793,17836.523913621902,0.992088496685028,0.0254748146981,0.5129617949169412,0.9869903922080994,0.044027104973793,0.2915628601867511,43793 -6656.179391384125,1.4955766201019287,11544.948768615724,35509,0,11544.948768615724,0.9862454533576964,0.0470156781375408,0.2753576260925718,43793,18203.61198425293,0.9922428131103516,0.0247446577996015,0.5505906445570861,0.9870585799217224,0.0441386215388774,0.2907984307159181,43793 -6784.992396831512,1.5288910865783691,11784.974064826964,36250,0,11784.974064826964,0.9862197637557985,0.0467967800796031,0.2766800615256882,43793,18572.50408053398,0.992393672466278,0.0244789887219667,0.5408486038115388,0.9870277047157288,0.0441088415682315,0.2908113959289101,43793 -6913.606185436249,1.5621047019958496,12024.976420640944,36993,0,12024.976420640944,0.9861136078834534,0.0472501963376998,0.2734903307389538,43793,18941.173763275143,0.9923135042190552,0.0246789418160915,0.5443136465959396,0.9869603514671326,0.0445223152637481,0.2863923956318354,43793 -7042.153408050537,1.5955908298492432,12265.204212665558,37716,0,12265.204212665558,0.9861860275268556,0.0470856502652168,0.2764250250297177,43793,19310.005308389664,0.9922302961349488,0.0250727999955415,0.5250435594834874,0.9870094656944276,0.044437400996685,0.2870718545969231,43793 -7167.328474283218,1.6293635368347168,12505.271298408508,38451,0,12505.271298408508,0.98612242937088,0.0472948960959911,0.2705252738596758,43793,19675.301374912266,0.9921622276306152,0.0252226088196039,0.5285684492729031,0.9870207905769348,0.0444123074412345,0.287736763336245,43793 -7293.105636358261,1.66961932182312,12745.515764474869,39188,0,12745.515764474869,0.9861502647399902,0.0476831793785095,0.2739721380819767,43793,20041.384852409363,0.9923495650291444,0.0245155375450849,0.5353314789084255,0.9869980812072754,0.0446966439485549,0.2880907345675527,43793 -7424.434202432632,1.7036066055297852,12985.472758769987,39921,0,12985.472758769987,0.9862037301063538,0.0472649373114109,0.2791759747694861,43793,20412.724797964096,0.9920818209648132,0.0250757187604904,0.5387114399975329,0.9870151281356812,0.0445631518959999,0.2937449771905799,43793 -7550.27251458168,1.7443230152130127,13225.699823617935,40650,0,13225.699823617935,0.9861915111541748,0.0471347123384475,0.2772578004381141,43793,20778.853857278824,0.992318868637085,0.0244046170264482,0.5562079477709713,0.98701673746109,0.0444464944303035,0.2909518713632352,43793 -7677.433129072189,1.7788498401641846,13465.783110141754,41388,0,13465.783110141754,0.9861556887626648,0.0476178787648677,0.2709540144738322,43793,21146.153143405914,0.9926362037658693,0.0233683511614799,0.5768032113155684,0.9870049953460692,0.0447624586522579,0.2829433254848256,43793 -7808.087415456772,1.815146446228028,13705.972064971924,42116,0,13705.972064971924,0.986143946647644,0.0475996248424053,0.2726043926098506,43793,21517.054282665253,0.992948591709137,0.022594491019845,0.5889403009136849,0.9868965744972228,0.0448220036923885,0.2870109267260117,43793 -7932.13530087471,1.8508188724517824,13946.16392469406,42857,0,13946.16392469406,0.9862349033355712,0.0478113815188407,0.277432456781912,43793,21881.35036206245,0.9928516745567322,0.0227420013397932,0.5831679997609343,0.9870119094848632,0.0449406802654266,0.2870644899198072,43793 -8057.833017587662,1.88677716255188,14186.274589061735,43598,0,14186.274589061735,0.9860474467277528,0.0475262962281703,0.2770334013150399,43793,22247.215339899063,0.9924556016921996,0.0241395328193902,0.5471580974598048,0.9868945479393004,0.0446267016232013,0.2904814723725562,43793 -8182.722589492798,1.921675443649292,14426.232063055038,44327,0,14426.232063055038,0.986100137233734,0.0481578856706619,0.2753276791308924,43793,22612.117126464844,0.9924808740615844,0.0239379573613405,0.5512886909498305,0.9869850873947144,0.0452325120568275,0.2895107733140303,43793 -8310.012906312943,1.9601173400878904,14666.217765569689,45063,0,14666.217765569689,0.9860609769821168,0.0477090999484062,0.276624794918469,43793,22979.45182442665,0.992731750011444,0.0231606904417276,0.5747661718707735,0.9868564009666444,0.0449821837246418,0.28365360157518,43793 -8439.401960372925,1.9960315227508545,14906.2397480011,45795,0,14906.2397480011,0.9862298369407654,0.0482310391962528,0.2777345965558692,43793,23348.91964435577,0.9927511811256408,0.022958293557167,0.586343797232631,0.987152338027954,0.045345164835453,0.2902990328130586,43793 -8567.942576169968,2.0327556133270264,15146.333956003187,46535,0,15146.333956003187,0.98613041639328,0.0485120005905628,0.2745446524126901,43793,23717.61184549332,0.9927822351455688,0.0226614810526371,0.5853062788070533,0.9870321750640868,0.0456238947808742,0.288616217569316,43793 -8695.869742393494,2.068514823913574,15386.32931470871,47272,0,15386.32931470871,0.9861814379692078,0.0481357239186763,0.2751556729047205,43793,24085.590614795685,0.9932552576065063,0.0213545765727758,0.6282544798621549,0.9870585799217224,0.0451401658356189,0.2918916130401043,43793 -8817.450694084167,2.1042933464050293,15626.565325260162,48016,0,15626.565325260162,0.986250936985016,0.0485204793512821,0.2787191662258837,43793,24447.46418762207,0.9932361245155334,0.0213189255446195,0.6125954411107427,0.9870853424072266,0.0456184148788452,0.2917766721293825,43793 -8944.631282806396,2.1402597427368164,15866.78537106514,48757,0,15866.78537106514,0.9862517714500428,0.0484604015946388,0.275624097328915,43793,24814.92128801346,0.9933009743690492,0.0212412904947996,0.6215911847182629,0.9870471954345704,0.0455567426979541,0.2858171389788816,43793 -9068.64112663269,2.176506996154785,16106.996109962463,49494,0,16106.996109962463,0.9862500429153442,0.0490105859935283,0.2732506608323254,43793,25179.198776960373,0.9931593537330629,0.0215862058103084,0.6026874616431173,0.9870740175247192,0.0459361635148525,0.2909390577597649,43793 -9197.52271914482,2.213009119033813,16347.118661403656,50221,0,16347.118661403656,0.9861556887626648,0.0486082173883914,0.2786955782953111,43793,25548.262518405914,0.9930923581123352,0.0217242259532213,0.6127489989284732,0.9870272874832152,0.0457522124052047,0.2891217197250926,43793 -9326.890407800674,2.250386476516724,16587.07647919655,50958,0,16587.07647919655,0.9862281680107116,0.048733152449131,0.2759372437124913,43793,25917.645943164825,0.9932236075401306,0.0213300064206123,0.6119174869729177,0.98704195022583,0.0457834005355834,0.2952523480539175,43793 -9452.768271923063,2.288546562194824,16827.240026474,51693,0,16827.240026474,0.9862563610076904,0.048782855272293,0.2809863309832373,43793,26283.74558782577,0.993235409259796,0.02134177275002,0.6205512641982079,0.987038254737854,0.0458993092179298,0.2932043078510703,43793 -9575.785498142242,2.326702117919922,17067.183010816574,52431,0,17067.183010816574,0.9862008094787598,0.0494671426713466,0.275198819814396,43793,26646.76397776604,0.9933454394340516,0.0208866987377405,0.6233633128448957,0.9870431423187256,0.0463724546134471,0.2947568615806708,43793 -9702.37763953209,2.363882303237915,17307.16687989235,53174,0,17307.16687989235,0.9861982464790344,0.0490794591605663,0.277876897106266,43793,27013.39756894112,0.993765652179718,0.0198266431689262,0.648274235699503,0.9869290590286256,0.0461834780871868,0.293072159487718,43793 -9829.080079555511,2.401202440261841,17547.396492242813,53902,0,17547.396492242813,0.9861683249473572,0.0492582395672798,0.2728726141685595,43793,27380.38984155655,0.9939457774162292,0.0190976541489362,0.6707530227995426,0.9869185090065002,0.046229638159275,0.2894399170703837,43793 -9960.751490354538,2.438985824584961,17787.61997103691,54639,0,17787.61997103691,0.9862428903579712,0.049737349152565,0.275510165065686,43793,27752.34281229973,0.9939642548561096,0.0190879423171281,0.6836719989171522,0.986976146697998,0.0466739870607852,0.2931826954710649,43793 -10089.079269647598,2.4811816215515137,18027.63859820366,55378,0,18027.63859820366,0.9861915111541748,0.0497394278645515,0.2762871000737157,43793,28120.75264573097,0.9935755729675292,0.0200575366616249,0.6444034424776388,0.9869863390922546,0.0467824675142765,0.2931136241225678,43793 -10218.807644367218,2.518969774246216,18267.65841269493,56114,0,18267.65841269493,0.9862412214279176,0.0501952022314071,0.2766432008581418,43793,28490.55920910836,0.9936255216598512,0.0199688728898763,0.6374130133398599,0.9870532751083374,0.0471296124160289,0.2916289264718201,43793 -10345.98371219635,2.55792236328125,18507.83859324456,56855,0,18507.83859324456,0.986143946647644,0.0501659922301769,0.2744961808523273,43793,28857.974859952927,0.9937307238578796,0.0195559505373239,0.6575340063625909,0.987097144126892,0.0470090880990028,0.2928508141046779,43793 -10472.135902881622,2.5963222980499268,18747.98366856575,57592,0,18747.98366856575,0.9862513542175292,0.0502744019031524,0.2765929581524281,43793,29224.3308801651,0.9938116073608398,0.0193178337067365,0.6544946164639226,0.9870354533195496,0.0472130291163921,0.2911910347651354,43793 -10601.76334810257,2.6340818405151367,18988.1043651104,58329,0,18988.1043651104,0.9861683249473572,0.0507872477173805,0.2757953912061542,43793,29594.138032197952,0.9938645958900452,0.0190196689218282,0.6740505018592804,0.9870337843894958,0.0474877506494522,0.2891203495535188,43793 -10722.174923658373,2.673880100250244,19228.10386610031,59073,0,19228.10386610031,0.9862648248672484,0.0509006977081298,0.2718785727090479,43793,29954.60914540291,0.9941751956939696,0.0180447380989789,0.7022220962344548,0.9870853424072266,0.0476047173142433,0.2895204188531082,43793 -10850.485614299774,2.712521314620972,19468.106098413467,59813,0,19468.106098413467,0.9861982464790344,0.0511804521083831,0.2744422124499764,43793,30322.98073625565,0.9943927526474,0.0174613632261753,0.7008946711494934,0.9870354533195496,0.0478969104588031,0.2893068462662133,43793 -10969.589729309082,2.7505152225494385,19708.216701745987,60557,0,19708.216701745987,0.986163318157196,0.0511066801846027,0.2768247622583396,43793,30682.25395989418,0.9945353865623474,0.0171850584447383,0.7161099193970502,0.987028956413269,0.0478489883244037,0.2911537483110391,43793 -11091.201363563538,2.789128541946411,19948.361443281174,61294,0,19948.361443281174,0.9862361550331116,0.0511085912585258,0.2774663366633761,43793,31044.06996655464,0.9944341778755188,0.0173479840159416,0.6998307741889802,0.9870285391807556,0.048157338052988,0.2906187634034315,43793 -11210.443829536438,2.8275270462036133,20188.502468585968,62029,0,20188.502468585968,0.9862399697303772,0.0514232739806175,0.2774561780744564,43793,31403.512269973755,0.9944174885749816,0.0174654368311166,0.7034989601227458,0.987064242362976,0.0482059121131897,0.294787254294828,43793 -11329.68220448494,2.8689663410186768,20428.56449484825,62776,0,20428.56449484825,0.9862205982208252,0.0515215322375297,0.2804855365186107,43793,31762.874539613724,0.9943106174468994,0.0174315962940454,0.6984793166748001,0.9869989156723022,0.0483455397188663,0.2943116500300501,43793 -11453.315598249435,2.9075872898101807,20668.679448366165,63524,0,20668.679448366165,0.9861738085746764,0.0517492853105068,0.2760619205947037,43793,32126.68216776848,0.9942901730537416,0.0177717264741659,0.6904933370506032,0.9868978261947632,0.0487305261194705,0.2893957604156967,43793 -11580.409008979796,2.946688413619995,20908.896485328674,64266,0,20908.896485328674,0.9861649870872498,0.0520334914326667,0.2751095010284893,43793,32494.05240297317,0.99448424577713,0.0171282012015581,0.7091291766287441,0.987028956413269,0.0488222688436508,0.295160113444928,43793 -11697.220544338226,2.9854743480682373,21149.007489204407,65013,0,21149.007489204407,0.9861902594566344,0.0524006597697734,0.2766862655102707,43793,32851.034240722656,0.9946561455726624,0.0164926163852214,0.7217163236524171,0.9869893789291382,0.0493305474519729,0.2885682063827535,43793 -11817.548296689987,3.0303094387054443,21388.94634079933,65754,0,21388.94634079933,0.986180543899536,0.052390594035387,0.2744711419810177,43793,33211.36627578735,0.994984209537506,0.0157103948295116,0.751098217448867,0.9869558811187744,0.0491550415754318,0.2917384044633542,43793 -11937.971812486649,3.069685459136963,21628.92591571808,66488,0,21628.92591571808,0.9861359000205994,0.0525679588317871,0.2745586815742383,43793,33571.829740047455,0.9949955940246582,0.015580584295094,0.743294926692901,0.9869229793548584,0.0493643842637538,0.2929701131226171,43793 -12060.300013542175,3.110136747360229,21868.94931316376,67226,0,21868.94931316376,0.9861502647399902,0.0527020953595638,0.2758090958557493,43793,33934.243431806564,0.995172679424286,0.0151609126478433,0.752292266305013,0.9870151281356812,0.0494721122086048,0.2935105053459713,43793 -12179.768604755402,3.154653310775757,22108.92947244644,67962,0,22108.92947244644,0.9861688017845154,0.053035944700241,0.2768718733391812,43793,34293.75819349289,0.9949970245361328,0.0155762517824769,0.7404822138643696,0.9870175719261168,0.0498333163559436,0.2956191146081193,43793 -12304.749731063845,3.194903612136841,22349.121163129807,68695,0,22349.121163129807,0.9861283302307128,0.0528410337865352,0.2770126310365181,43793,34658.99287319183,0.9949944615364076,0.0155779477208852,0.7456122928440716,0.9869124293327332,0.0496644973754882,0.2947325860111379,43793 -12423.543547868729,3.256009817123413,22589.33414030075,69439,0,22589.33414030075,0.9861435294151306,0.0531154498457908,0.2764680653602601,43793,35018.08116745949,0.9948853850364684,0.0157805941998958,0.7421132570454949,0.9869562983512878,0.0499289967119693,0.2922757941451337,43793 -12549.790276288986,3.2973122596740723,22829.57661533356,70182,0,22829.57661533356,0.9861018061637878,0.0530740618705749,0.2752897350878246,43793,35384.631796598434,0.9949990510940552,0.0154972299933433,0.7395333602277423,0.986955463886261,0.049820426851511,0.2946346190690715,43793 -12668.670906543732,3.3372585773468018,23069.656600236893,70928,0,23069.656600236893,0.9861287474632264,0.0536137297749519,0.2747126367614844,43793,35743.65252280235,0.9952114820480348,0.0149235017597675,0.7624726762016496,0.987009048461914,0.0501800142228603,0.292888755323264,43793 -12794.468041181564,3.378878831863404,23309.76019668579,71669,0,23309.76019668579,0.9861708879470824,0.0535033456981182,0.2766858995086599,43793,36109.615293741226,0.9953161478042604,0.0145877562463283,0.7535696656906885,0.9870131015777588,0.0501160286366939,0.2949806488973432,43793 -12916.356031417848,3.4264605045318604,23550.00160861016,72398,0,23550.00160861016,0.9861599206924438,0.0534425675868988,0.2757994127724285,43793,36471.81417179108,0.9954724311828612,0.0142005616798996,0.7746771146110315,0.9869810342788696,0.0500811003148555,0.2950637680379031,43793 -13040.898215293884,3.469341993331909,23790.09094119072,73141,0,23790.09094119072,0.9861510992050172,0.0536785386502742,0.2758017772485302,43793,36836.50923585892,0.9954382181167604,0.0143452454358339,0.7782669256073729,0.9870086312294006,0.050230972468853,0.2939535993961091,43793 -13161.084740161896,3.51848578453064,24030.04671812057,73865,0,24030.04671812057,0.9861910939216614,0.0537983998656272,0.2746981960824259,43793,37196.72438669205,0.9954219460487366,0.014223325997591,0.77684622631777,0.9870402812957764,0.0503243319690227,0.2935714816912196,43793 -13281.721255779266,3.5607388019561768,24270.110892534256,74606,0,24270.110892534256,0.9861281514167786,0.0537189431488513,0.2753094927661393,43793,37557.48768091202,0.9953264594078064,0.0145053453743457,0.7675750118704716,0.9869623780250548,0.0502690747380256,0.2933164821603256,43793 -13404.296205759048,3.6379876136779785,24510.027459144592,75339,0,24510.027459144592,0.9861477017402648,0.0538434907793998,0.2762096863110901,43793,37920.07783651352,0.9953559637069702,0.0144855827093124,0.7559158070368784,0.9869757294654846,0.0504377521574497,0.2932075498164984,43793 -13525.030313014984,3.6831717491149902,24750.19059085846,76068,0,24750.19059085846,0.9861472845077516,0.053802452981472,0.2767602742047537,43793,38281.042036771774,0.9953692555427552,0.0144447674974799,0.7646718876340945,0.9869931936264038,0.0504125356674194,0.2931289894339166,43793 -13641.3963701725,3.724697113037109,24990.2333111763,76810,0,24990.2333111763,0.986127495765686,0.0538610033690929,0.275816124039496,43793,38637.51325106621,0.995473325252533,0.0141021022573113,0.7793426664482765,0.9869822263717652,0.0504016950726509,0.2939869283770364,43793 -13760.451066493988,3.767266511917114,25230.289999961853,77543,0,25230.289999961853,0.98613041639328,0.0538419187068939,0.2757654531108339,43793,38996.68863105774,0.995535135269165,0.013972028158605,0.7854258956826089,0.9869814515113832,0.0504111647605896,0.2938731668240078,43793 -13881.799325227736,3.8089919090271,25470.411824703217,78284,0,25470.411824703217,0.9861502647399902,0.0538810566067695,0.2762060660463843,43793,39358.22065544128,0.9955268502235411,0.0140170855447649,0.7766022190074704,0.9870017170906068,0.0504525080323219,0.2935173090637798,43793 -14004.569951534271,3.8572590351104736,25710.480096817017,79005,0,25710.480096817017,0.9861544370651244,0.0539372749626636,0.2760413132298542,43793,39721.13111758232,0.9955037236213684,0.0140764387324452,0.7753496218501175,0.9870216250419616,0.050493486225605,0.2931579732099656,43793 -14122.105741977692,3.906131267547608,25950.60027909279,79743,0,25950.60027909279,0.986162006855011,0.0539317578077316,0.2758741018370922,43793,40078.85696077347,0.9954732656478882,0.0141325732693076,0.7680403320439355,0.9870207905769348,0.0504891276359558,0.2931052477542326,43793 -14239.878486156464,3.949522018432617,26190.832530498505,80484,0,26190.832530498505,0.9861586689949036,0.0539308004081249,0.2760013251638327,43793,40436.92608857155,0.9954727292060852,0.0140985697507858,0.7734134501513834,0.9870207905769348,0.0504884012043476,0.2932042189287904,43793 -14356.332607507706,3.993171215057373,26430.913827180862,81227,0,26430.913827180862,0.9861586689949036,0.0539308078587055,0.2759893265669187,43793,40793.525622844696,0.9955683350563048,0.0138638988137245,0.7826635520548781,0.9870207905769348,0.0504884012043476,0.2933164355896822,43793 -14480.77075123787,4.034946918487549,26671.08064627648,81964,0,26671.08064627648,0.9861586689949036,0.0539308004081249,0.2758879511772302,43793,41158.192217350006,0.9954887628555298,0.014108401723206,0.7777829012838969,0.9870207905769348,0.0504884012043476,0.2932682244947826,43793 -14599.658869504929,4.078474044799805,26911.22525882721,82684,0,26911.22525882721,0.9861586689949036,0.0539308078587055,0.2759877260387026,43793,41517.2912569046,0.9954996109008788,0.0140416109934449,0.773361696380792,0.9870207905769348,0.0504884012043476,0.2932646215050519,43793 -14720.061032772064,4.120722055435181,27151.46985912323,83424,0,27151.46985912323,0.9861586689949036,0.0539308078587055,0.2759603496560459,43793,41878.00029158592,0.995476007461548,0.0141355525702238,0.7666465209970005,0.9870207905769348,0.0504884012043476,0.2931580788234526,43793 -14839.390854358671,4.164043664932251,27391.469383001328,84163,0,27391.469383001328,0.9861586689949036,0.0539308004081249,0.2759228082006909,43793,42237.39330744743,0.9954543709754944,0.0141666149720549,0.7774010441861712,0.9870207905769348,0.0504884012043476,0.2932389238302139,43793 -14958.367205619812,4.2065746784210205,27631.43978309632,84890,0,27631.43978309632,0.9861586689949036,0.0539308078587055,0.2759544450025004,43793,42596.40437030792,0.9955111145973206,0.0140264285728335,0.7820682909038227,0.9870207905769348,0.0504884012043476,0.2933124181851586,43793 -15081.195603370668,4.2489824295043945,27871.646685361862,85621,0,27871.646685361862,0.9861586689949036,0.0539308004081249,0.2759841329736199,43793,42959.50312495232,0.9955190420150756,0.013955594971776,0.7841828855297852,0.9870207905769348,0.0504884012043476,0.2931340940827518,43793 -15206.80028629303,4.29235053062439,28111.63972616196,86345,0,28111.63972616196,0.9861586689949036,0.0539308078587055,0.275911657756955,43793,43325.16566514969,0.99550598859787,0.0140627259388566,0.772074521198249,0.9870207905769348,0.0504884012043476,0.2932729704005158,43793 -15325.437908411026,4.341654539108276,28351.80605864525,87074,0,28351.80605864525,0.9861586689949036,0.0539308004081249,0.2760227125616841,43793,43684.0396668911,0.9954939484596252,0.0140989450737833,0.7701163354328072,0.9870207905769348,0.0504884012043476,0.2930998900386766,43793 -15443.043601989746,4.391093969345093,28591.916484594345,87801,0,28591.916484594345,0.9861586689949036,0.0539308078587055,0.2759060362578897,43793,44041.82602286339,0.9954929947853088,0.0141305215656757,0.771353052064392,0.9870207905769348,0.0504884012043476,0.2932041771177117,43793 -15557.111010789871,4.433941125869751,28831.876116752625,88532,0,28831.876116752625,0.9861586689949036,0.0539308078587055,0.2759228496513186,43793,44395.9159309864,0.9954800009727478,0.0140950288623571,0.781205735501981,0.9870207905769348,0.0504884012043476,0.2932369735857291,43793 -15676.584495544434,4.477561473846436,29071.96212220192,89274,0,29071.96212220192,0.9861586689949036,0.0539308078587055,0.2758838902077703,43793,44755.5393948555,0.9955880641937256,0.0138135934248566,0.7830603803705166,0.9870207905769348,0.0504884012043476,0.2932198575902922,43793 -15792.836474895475,4.530365467071533,29312.07736825943,90005,0,29312.07736825943,0.9861586689949036,0.0539308004081249,0.2758904980645204,43793,45111.980585575104,0.9954330921173096,0.0141439158469438,0.7768471852559291,0.9870207905769348,0.0504884012043476,0.2932248225737306,43793 -15916.857501029968,4.576200008392334,29552.14755630493,90745,0,29552.14755630493,0.9861586689949036,0.0539308004081249,0.2759802245458693,43793,45476.13833808899,0.9954633712768556,0.0141695011407136,0.7759647711457989,0.9870207905769348,0.0504884012043476,0.2933183661124961,43793 -16034.809426784515,4.627282619476318,29792.3046181202,91480,0,29792.3046181202,0.9861586689949036,0.0539308078587055,0.2760763937091977,43793,45834.32055068016,0.9955159425735474,0.0140622379258275,0.7617297949769014,0.9870207905769348,0.0504884012043476,0.2932435573506383,43793 -16145.526446580889,4.671795606613159,30032.342000246048,92223,0,30032.342000246048,0.9861586689949036,0.0539308078587055,0.2759261753171792,43793,46185.14011597633,0.9955058097839355,0.014057345688343,0.7794176027937787,0.9870207905769348,0.0504884012043476,0.2930912488721401,43793 -16259.84417271614,4.716129779815674,30272.507484912872,92972,0,30272.507484912872,0.9861586689949036,0.0539308004081249,0.2760432220871991,43793,46539.68813109398,0.9955204129219056,0.0139325829222798,0.7852867760278921,0.9870207905769348,0.0504884012043476,0.2931630314705667,43793 -16382.511649608612,4.762882471084595,30512.44143342972,93718,0,30512.44143342972,0.9861586689949036,0.0539308078587055,0.2760487360250577,43793,46902.35689115524,0.9955332279205322,0.0139707894995808,0.7733520647228189,0.9870207905769348,0.0504884012043476,0.2931728770329977,43793 -16498.03994011879,4.806723833084106,30752.5343811512,94453,0,30752.5343811512,0.9861586689949036,0.0539308078587055,0.2758970474532101,43793,47258.04355049133,0.9954761266708374,0.0141239240765571,0.7818585828536757,0.9870207905769348,0.0504884012043476,0.2932003254588689,43793 -16611.206999063492,4.85191011428833,30992.618040323257,95184,0,30992.618040323257,0.9861586689949036,0.0539308078587055,0.2759964447818105,43793,47611.36131834984,0.9954842329025269,0.0140874776989221,0.7673942179190336,0.9870207905769348,0.0504884012043476,0.2931838900380704,43793 -16730.642671346664,4.896559953689575,31232.86323785782,95924,0,31232.86323785782,0.9861586689949036,0.0539308078587055,0.2759322101408783,43793,47971.10742545128,0.9954838156700134,0.014161848463118,0.7696215638128405,0.9870207905769348,0.0504884012043476,0.2931871120912127,43793 -16848.551023483276,4.945436477661133,31473.12002182007,96659,0,31473.12002182007,0.9861586689949036,0.0539308078587055,0.2759397003339527,43793,48329.34169435501,0.9955300688743592,0.0139715131372213,0.7872416582912839,0.9870207905769348,0.0504884012043476,0.2932467264780176,43793 -16960.928512334824,4.992716550827026,31713.225883245468,97395,0,31713.225883245468,0.9861586689949036,0.0539308078587055,0.2760300954617187,43793,48681.89289999008,0.9954854249954224,0.0140521274879574,0.7760487820006098,0.9870207905769348,0.0504884012043476,0.2932641153462439,43793 -17078.096727132797,5.03900408744812,31953.380586862564,98139,0,31953.380586862564,0.9861586689949036,0.0539308078587055,0.2759224751185829,43793,49039.28237915039,0.99550861120224,0.0140364868566393,0.775531835162473,0.9870207905769348,0.0504884012043476,0.2930844133223836,43793 -17189.921184539795,5.08355450630188,32193.38186430931,98892,0,32193.38186430931,0.9861586689949036,0.0539308078587055,0.2759939712103814,43793,49391.1727848053,0.9955294132232666,0.0139757683500647,0.7690663991926115,0.9870207905769348,0.0504884012043476,0.2931506027499619,43793 -17303.229751110077,5.12899374961853,32433.47683668137,99635,0,32433.47683668137,0.9861586689949036,0.0539308078587055,0.2758982340598774,43793,49744.64137840271,0.9954495429992676,0.0142050180584192,0.770540888307726,0.9870207905769348,0.0504884012043476,0.2932130341888327,43793 -17411.270839214325,5.1742777824401855,32673.473799943924,100387,0,32673.473799943924,0.9861586689949036,0.0539308078587055,0.2759344706039897,43793,50092.74442219734,0.995512068271637,0.0140080070123076,0.7855383859173153,0.9870207905769348,0.0504884012043476,0.2930783521607571,43793 -17523.80086541176,5.221256732940674,32913.52817153931,101136,0,32913.52817153931,0.9861586689949036,0.0539308078587055,0.2759272854940975,43793,50445.39581871033,0.9954873919487,0.0140253193676471,0.7812430634313685,0.9870207905769348,0.0504884012043476,0.293090200785777,43793 -17639.583562850952,5.267125368118286,33153.59733271599,101875,0,33153.59733271599,0.9861586689949036,0.0539308004081249,0.2759245002265661,43793,50801.31637907028,0.995502471923828,0.014068104326725,0.775914528450701,0.9870207905769348,0.0504884012043476,0.2931656029885052,43793 -17751.57058095932,5.313305616378784,33393.808596372604,102621,0,33393.808596372604,0.9861586689949036,0.0539308078587055,0.2758684385834484,43793,51153.58129143715,0.995550572872162,0.0140123721212148,0.776397050638031,0.9870207905769348,0.0504884012043476,0.2933286895597028,43793 -17866.617556095123,5.3591132164001465,33633.803196430206,103365,0,33633.803196430206,0.9861586689949036,0.0539308004081249,0.2760386250555259,43793,51508.68910455704,0.995477855205536,0.0140792066231369,0.7638450164442087,0.9870207905769348,0.0504884012043476,0.2932093756277367,43793 -17979.562771081924,5.405600786209106,33873.81412649155,104114,0,33873.81412649155,0.9861586689949036,0.0539308004081249,0.2759427131775175,43793,51861.71270442009,0.9954692721366882,0.0141795184463262,0.7776649159084392,0.9870207905769348,0.0504884012043476,0.2932161456923653,43793 -18089.034460544583,5.451803922653198,34113.82161974907,104849,0,34113.82161974907,0.9861586689949036,0.0539308004081249,0.2760546916619492,43793,52211.25935649872,0.995512306690216,0.013982149772346,0.7832667867806614,0.9870207905769348,0.0504884012043476,0.2931969650094557,43793 -18200.998265981674,5.49837327003479,34353.93736720085,105602,0,34353.93736720085,0.9861586689949036,0.0539308004081249,0.2759460188977808,43793,52563.40596866608,0.9954873919487,0.0140405539423227,0.7785648148482405,0.9870207905769348,0.0504884012043476,0.2932527154438649,43793 -18312.1730606556,5.544644832611084,34594.05852675438,106346,0,34594.05852675438,0.9861586689949036,0.0539308004081249,0.2760089875594442,43793,52914.76798701286,0.9955300688743592,0.0140305142849683,0.7752093696893039,0.9870207905769348,0.0504884012043476,0.2931428330709554,43793 -18422.77271294593,5.592383623123169,34834.22911691666,107096,0,34834.22911691666,0.9861586689949036,0.0539308004081249,0.275923875209314,43793,53265.606330394745,0.995497703552246,0.0140161113813519,0.7694205666819729,0.9870207905769348,0.0504884012043476,0.293144991287255,43793 -18537.066703557968,5.642268657684326,35074.33372759819,107836,0,35074.33372759819,0.9861586689949036,0.0539308078587055,0.2759135417957489,43793,53620.07576179504,0.995468020439148,0.0141827668994665,0.7754401981422524,0.9870207905769348,0.0504884012043476,0.2932577357417439,43793 -18649.70942378044,5.689713478088379,35314.36759090424,108589,0,35314.36759090424,0.9861586689949036,0.0539308004081249,0.2759655696714156,43793,53972.82032966614,0.9955382347106934,0.0139516443014144,0.7749240428541433,0.9870207905769348,0.0504884012043476,0.2932017358982969,43793 -18761.4860200882,5.736135482788086,35554.36666512489,109339,0,35554.36666512489,0.9861586689949036,0.0539308078587055,0.275938599498735,43793,54324.66298913956,0.995488941669464,0.0140517679974436,0.783534154624986,0.9870207905769348,0.0504884012043476,0.2930661018682657,43793 -18875.78289008141,5.784965753555298,35794.49884533882,110088,0,35794.49884533882,0.9861586689949036,0.0539308004081249,0.2759169175102017,43793,54679.16133594513,0.9954979419708252,0.0140773402526974,0.7767405525270463,0.9870207905769348,0.0504884012043476,0.2932059522267767,43793 -18984.285687446594,5.839473009109497,36034.63081741333,110825,0,36034.63081741333,0.9861586689949036,0.0539308004081249,0.2759136667682885,43793,55027.872133016586,0.99551659822464,0.0140205714851617,0.7701554597187098,0.9870207905769348,0.0504884012043476,0.2933111682688331,43793 -19094.74762225151,5.886816501617432,36274.76571893692,111571,0,36274.76571893692,0.9861586689949036,0.0539308004081249,0.2758906413013993,43793,55378.53616476059,0.995445728302002,0.0141978915780782,0.7755410798187834,0.9870207905769348,0.0504884012043476,0.2932146665047697,43793 -19202.802329540253,5.942864656448364,36514.839708566666,112296,0,36514.839708566666,0.9861586689949036,0.0539308004081249,0.2759193595748794,43793,55726.74456644058,0.9955244064331056,0.0139960274100303,0.7807075584140233,0.9870207905769348,0.0504884012043476,0.2931795272519896,43793 -19315.524201393127,5.990647315979004,36754.81792402268,113040,0,36754.81792402268,0.9861586689949036,0.0539308078587055,0.2760884196260506,43793,56079.51251745224,0.9955311417579652,0.0139847984537482,0.7829461868920284,0.9870207905769348,0.0504884012043476,0.2933635049845934,43793 -19423.286600351334,6.043809175491333,36994.844656705856,113788,0,36994.844656705856,0.9861586689949036,0.0539308078587055,0.2759159872101339,43793,56427.37561607361,0.9954670071601868,0.0140660954639315,0.7718478538114525,0.9870207905769348,0.0504884012043476,0.293139372132737,43793 -19538.20528769493,6.430257320404053,37234.62434768677,114532,0,37234.62434768677,0.9861586689949036,0.0539308004081249,0.2759121370281739,43793,56782.48134112358,0.9955251216888428,0.0140108959749341,0.7738367760206022,0.9870207905769348,0.0504884012043476,0.293188261011175,43793 -19648.142145633698,6.480026960372925,37474.78272938728,115275,0,37474.78272938728,0.9861586689949036,0.0539308004081249,0.2759276565912759,43793,57132.64653587341,0.9954529404640198,0.0141806257888674,0.7671019554487162,0.9870207905769348,0.0504884012043476,0.2932588652343885,43793 -19761.3506834507,6.528173208236694,37714.924875974655,116018,0,37714.924875974655,0.9861586689949036,0.0539308078587055,0.2760005242200415,43793,57486.06534719467,0.9955293536186218,0.0140388701111078,0.7772930990901283,0.9870207905769348,0.0504884012043476,0.2932376430301499,43793 -19868.06057667732,6.577084302902222,37954.86515974999,116766,0,37954.86515974999,0.9861586689949036,0.0539308078587055,0.2760248373857741,43793,57832.78515815735,0.9954851865768432,0.0140848821029067,0.7850383991852307,0.9870207905769348,0.0504884012043476,0.2931388648347254,43793 -19982.06836414337,6.626428604125977,38195.090396404266,117507,0,38195.090396404266,0.9861586689949036,0.0539308004081249,0.2759380183987922,43793,58187.087896347046,0.9955120086669922,0.0139834303408861,0.7743394761970276,0.9870207905769348,0.0504884012043476,0.2931852186260689,43793 -20096.57867860794,6.675387859344482,38435.11602497101,118253,0,38435.11602497101,0.9861586689949036,0.0539308078587055,0.276064940503458,43793,58541.69343018532,0.995533049106598,0.013999680057168,0.7822424577658693,0.9870207905769348,0.0504884012043476,0.2932161305695458,43793 -20206.304889678955,6.732793807983398,38675.15397500992,118999,0,38675.15397500992,0.9861586689949036,0.0539308078587055,0.2759337036347488,43793,58891.5354681015,0.9954657554626464,0.0141053376719355,0.7663495065328356,0.9870207905769348,0.0504884012043476,0.2931702847497467,43793 -20316.15882253647,6.782490015029907,38915.19271636009,119739,0,38915.19271636009,0.9861586689949036,0.0539308004081249,0.2761250320439972,43793,59241.49839806557,0.9954808950424194,0.0141744362190365,0.7775733664311058,0.9870207905769348,0.0504884012043476,0.2932722189513038,43793 -20422.99523949623,6.833070516586304,39155.25708556175,120477,0,39155.25708556175,0.9861586689949036,0.0539308078587055,0.2759202887756614,43793,59588.46991467476,0.9955210089683532,0.0139794861897826,0.7815879575595298,0.9870207905769348,0.0504884012043476,0.2930735441592322,43793 -20533.277759552,6.884111404418945,39395.50072574616,121214,0,39395.50072574616,0.9861586689949036,0.0539308078587055,0.2759633130124874,43793,59939.0671851635,0.9954981207847596,0.0139925740659236,0.7821686742341195,0.9870207905769348,0.0504884012043476,0.2931317994411526,43793 -20642.88063430786,6.936408996582031,39635.43482732773,121949,0,39635.43482732773,0.9861586689949036,0.0539308004081249,0.2758850451009829,43793,60288.67901420593,0.9954849481582642,0.0141089959070086,0.7756348087463598,0.9870207905769348,0.0504883974790573,0.2932923447189138,43793 -20753.818819522858,6.987528085708618,39875.50879430771,122693,0,39875.50879430771,0.9861586689949036,0.0539308078587055,0.2759276980217821,43793,60639.76265883446,0.99551123380661,0.0140492506325244,0.7678146983914897,0.9870207905769348,0.0504884012043476,0.2932222266958525,43793 -20864.3991625309,7.039609432220459,40115.57969236374,123435,0,40115.57969236374,0.9861586689949036,0.0539308078587055,0.2759675452993614,43793,60990.48681926727,0.9954609870910645,0.0141681898385286,0.7764094208096587,0.9870207905769348,0.0504884012043476,0.2931615148093934,43793 -20970.238475561146,7.089834451675415,40355.65077519417,124180,0,40355.65077519417,0.9861586689949036,0.0539308078587055,0.2760686342789627,43793,61336.46803617477,0.9955407977104188,0.0139289442449808,0.7768192297472776,0.9870207905769348,0.0504884012043476,0.2932437338587668,43793 -21081.669250011444,7.141475677490234,40595.69643044472,124928,0,40595.69643044472,0.9861586689949036,0.0539308078587055,0.2760136439954811,43793,61688.01654314995,0.9954872131347656,0.0140464529395103,0.7832406491909317,0.9870207905769348,0.0504884012043476,0.2931946410851083,43793 -21196.552217245106,7.191240072250366,40835.86017179489,125672,0,40835.86017179489,0.9861586689949036,0.0539308004081249,0.2760047208967673,43793,62043.13323545456,0.9954885840415956,0.0140621056780219,0.774091442743637,0.9870207905769348,0.0504884012043476,0.2933094306738013,43793 -21306.72452545166,7.242794752120972,41075.91949534416,126416,0,41075.91949534416,0.9861586689949036,0.0539308004081249,0.2760524931298125,43793,62393.43706226349,0.995519518852234,0.0140607506036758,0.7807933473108024,0.9870207905769348,0.0504884012043476,0.2932310735617742,43793 -21415.063493013386,7.293832540512085,41316.159787893295,127160,0,41316.159787893295,0.9861586689949036,0.0539308078587055,0.2760117545460758,43793,62742.088060855865,0.9954697489738464,0.0141372112557291,0.7587198449639805,0.9870207905769348,0.0504884012043476,0.2932764543116762,43793 -21520.694525957108,7.344229459762573,41556.275861501694,127901,0,41556.275861501694,0.9861586689949036,0.0539308078587055,0.275960832651128,43793,63087.90578913689,0.995512843132019,0.0140344286337494,0.7794579884619639,0.9870207905769348,0.0504884012043476,0.2931954360456731,43793 -21627.90496778488,7.395256757736206,41796.47999668121,128646,0,41796.47999668121,0.9861586689949036,0.0539308004081249,0.2760132085099576,43793,63435.39198184013,0.99550598859787,0.0140278497710824,0.7830336621813878,0.9870207905769348,0.0504884012043476,0.2931289849079299,43793 -21738.40746998787,7.44689416885376,42036.49334049225,129368,0,42036.49334049225,0.9861586689949036,0.0539308078587055,0.276029149677099,43793,63785.981810092926,0.9955154657363892,0.0139666367322206,0.7766582892696063,0.9870207905769348,0.0504884012043476,0.293199850367342,43793 -21848.249663591385,7.499475479125977,42276.55060958862,130105,0,42276.55060958862,0.9861586689949036,0.0539308004081249,0.2759899810924676,43793,64135.95481061936,0.9955015778541564,0.0140527635812759,0.7812440709210555,0.9870207905769348,0.0504884012043476,0.2932153395867975,43793 -21959.45810699463,7.557837963104248,42516.48381781578,130840,0,42516.48381781578,0.9861586689949036,0.0539308078587055,0.2759373339764978,43793,64487.17621850968,0.9954874515533448,0.0140948053449392,0.7717035971463917,0.9870207905769348,0.0504884012043476,0.2930906746882962,43793 -22067.263786792755,7.610391139984131,42756.610813617706,131586,0,42756.610813617706,0.9861586689949036,0.0539308004081249,0.2758972165051127,43793,64835.18199014664,0.995464563369751,0.0142739154398441,0.7646260937715028,0.9870207905769348,0.0504884012043476,0.2932155802724631,43793 -22177.67981219292,7.664034605026245,42996.56942439079,132333,0,42996.56942439079,0.9861586689949036,0.0539308078587055,0.2759797114188226,43793,65185.630719423294,0.9955654144287108,0.0138419400900602,0.7805771446483449,0.9870207905769348,0.0504884012043476,0.2932781435099585,43793 -22286.250038146973,7.716859579086304,43236.740758657455,133080,0,43236.740758657455,0.9861586689949036,0.0539308078587055,0.2759513026070447,43793,65534.44535279274,0.9954608678817748,0.0140660284087061,0.7814404601298929,0.9870207905769348,0.0504884012043476,0.2933030047203722,43793 -22399.64813780785,7.770319223403931,43476.80657362938,133820,0,43476.80657362938,0.9861586689949036,0.0539308078587055,0.2759277858320216,43793,65887.98328089714,0.9954729676246644,0.0141444317996501,0.7757278511595392,0.9870207905769348,0.0504884012043476,0.293336474690113,43793 -22508.24461197853,7.830138921737671,43716.8826019764,134551,0,43716.8826019764,0.9861586689949036,0.0539308004081249,0.2759842286678688,43793,66236.73903632164,0.9955356121063232,0.0139828957617282,0.7739661532021433,0.9870207905769348,0.0504884012043476,0.2931877864285181,43793 -22621.700261354446,7.88405442237854,43956.86270856857,135297,0,43956.86270856857,0.9861586689949036,0.0539308078587055,0.2759395228017564,43793,66590.24928569794,0.995487630367279,0.0141261909157037,0.7703792187677057,0.9870207905769348,0.0504884012043476,0.2932659339598245,43793 -22734.069207906723,7.935742139816284,44196.990394830704,136036,0,44196.990394830704,0.9861586689949036,0.0539308078587055,0.2759236846054778,43793,66942.81797623634,0.995522141456604,0.014012542553246,0.7761438917983973,0.9870207905769348,0.0504884012043476,0.2932118139807845,43793 -22846.563081502914,7.987916707992554,44437.059807538986,136772,0,44437.059807538986,0.9861586689949036,0.0539308004081249,0.2759907501293877,43793,67295.4547200203,0.995475709438324,0.0140751581639051,0.7810047617352318,0.9870207905769348,0.0504884012043476,0.2931009444456341,43793 -22957.16247224808,8.046531915664673,44677.20724821091,137508,0,44677.20724821091,0.9861586689949036,0.0539308078587055,0.27596270562728,43793,67646.28160524368,0.9954957365989684,0.0140201319009065,0.7739173571564724,0.9870207905769348,0.0504884012043476,0.2932000694716932,43793 -23065.58174324036,8.102015018463135,44917.2401702404,138247,0,44917.2401702404,0.9861586689949036,0.0539308078587055,0.2759087396709697,43793,67994.80980920792,0.9955067038536072,0.0140622872859239,0.779246398721551,0.9870207905769348,0.0504884012043476,0.2931393059919693,43793 -23171.964718818665,8.154671430587769,45157.17242622376,138991,0,45157.17242622376,0.9861586689949036,0.0539308078587055,0.2759749694299141,43793,68341.19829106331,0.9954759478569032,0.0141209280118346,0.7633788395828764,0.9870207905769348,0.0504884012043476,0.2931961719757504,43793 -23278.051063776016,8.208355903625488,45397.25456047058,139731,0,45397.25456047058,0.9861586689949036,0.0539308078587055,0.2759331733707005,43793,68687.44159507751,0.9955406785011292,0.0140155563130974,0.7742222547996289,0.9870207905769348,0.0504884012043476,0.293142011524399,43793 -23387.48162317276,8.26203179359436,45637.525134563446,140480,0,45637.525134563446,0.9861586689949036,0.0539308004081249,0.2759290721408751,43793,69037.21703863144,0.9954973459243774,0.0140028344467282,0.786076248561501,0.9870207905769348,0.0504884012043476,0.2932708711021463,43793 -23498.272793293,8.319572448730469,45877.65914797783,141228,0,45877.65914797783,0.9861586689949036,0.0539308004081249,0.2760197268325939,43793,69388.22018957138,0.995505392551422,0.0140246078372001,0.7788733946088782,0.9870207905769348,0.0504884012043476,0.2932886916868568,43793 -23601.612596273422,8.372910976409912,46117.658863544464,141970,0,46117.658863544464,0.9861586689949036,0.0539308078587055,0.2760298456241528,43793,69731.63406729698,0.9954705834388732,0.0141415288671851,0.7747681263770184,0.9870207905769348,0.0504884012043476,0.293163724473554,43793 -23711.222553491592,8.42588186264038,46357.69627165794,142710,0,46357.69627165794,0.9861586689949036,0.0539308078587055,0.2759479146602972,43793,70081.35587143898,0.9955161213874816,0.0139932902529835,0.76743690980945,0.9870207905769348,0.0504884012043476,0.2932152944727341,43793 -23817.034957647324,8.481031894683838,46597.8717443943,143450,0,46597.8717443943,0.9861586689949036,0.0539308078587055,0.2759225501940873,43793,70427.41950631142,0.9954485893249512,0.0142313251271843,0.7793896921346442,0.9870207905769348,0.0504884012043476,0.2931962193358611,43793 -23930.11753678322,8.535125494003296,46837.9917037487,144195,0,46837.9917037487,0.9861586689949036,0.0539308078587055,0.2759723667264904,43793,70780.69690322876,0.9955180287361144,0.0140170911327004,0.7712021251836819,0.9870207905769348,0.0504884012043476,0.2932476380325232,43793 -24041.820734739304,8.595582008361816,47078.04253005981,144936,0,47078.04253005981,0.9861586689949036,0.0539308078587055,0.275894328449335,43793,71132.5337510109,0.9955515265464784,0.0138782355934381,0.7903820845194295,0.9870207905769348,0.0504884012043476,0.2932307445075147,43793 -24152.11885714531,8.649686574935913,47318.212882995605,145685,0,47318.212882995605,0.9861586689949036,0.0539308078587055,0.2759183809349992,43793,71483.07665705681,0.9954928755760192,0.0140475835651159,0.7702413222315503,0.9870207905769348,0.0504884012043476,0.2932309214595385,43793 -24256.440824985504,8.703256368637085,47558.26383471489,146433,0,47558.26383471489,0.9861586689949036,0.0539308004081249,0.276092969195866,43793,71827.52420902252,0.9955329895019532,0.0140459137037396,0.7693459707086195,0.9870207905769348,0.0504884012043476,0.2932822507404936,43793 -24359.37511229515,8.758103132247925,47798.32083177567,147184,0,47798.32083177567,0.9861586689949036,0.0539308004081249,0.2759415898627759,43793,72170.59030127525,0.9954391717910768,0.0142329670488834,0.7748171236902126,0.9870207905769348,0.0504884012043476,0.2933502504790597,43793 -24465.72505545616,8.8131422996521,48038.32773327828,147932,0,48038.32773327828,0.9861586689949036,0.0539308004081249,0.2759515052454083,43793,72517.02304172516,0.9954963326454164,0.0140677941963076,0.7755266511394573,0.9870207905769348,0.0504884012043476,0.2931353810233713,43793 -24571.807502031326,8.868491649627686,48278.375985860825,148677,0,48278.375985860825,0.9861586689949036,0.0539308078587055,0.2759377518485115,43793,72863.22947764397,0.995520830154419,0.01393973082304,0.7862674566821604,0.9870207905769348,0.0504884012043476,0.2931562245813233,43793 -24683.352266073227,8.925062656402588,48518.42352437973,149418,0,48518.42352437973,0.9861586689949036,0.0539308004081249,0.2758873321968428,43793,73214.89954638481,0.9955045580863952,0.0140247605741024,0.7792347084590541,0.9870207905769348,0.0504884012043476,0.2932035091562113,43793 -24794.351008176804,8.982405424118042,48758.59159564972,150160,0,48758.59159564972,0.9861586689949036,0.0539308078587055,0.2759207687265104,43793,73566.14442420006,0.9955043196678162,0.0140610709786415,0.7787874434069905,0.9870207905769348,0.0504884012043476,0.2932349068101686,43793 -24899.233260393143,9.03764295578003,48998.60104203224,150901,0,48998.60104203224,0.9861586689949036,0.0539308078587055,0.2759093645543544,43793,73911.11200237274,0.995482325553894,0.0140950242057442,0.7615456821250393,0.9870207905769348,0.0504884012043476,0.2930889760262496,43793 -25009.81607890129,9.094947576522827,49238.62616467476,151624,0,49238.62616467476,0.9861586689949036,0.0539308078587055,0.2759801610305899,43793,74261.79939770699,0.9954485893249512,0.0142013663426041,0.7779760903966846,0.9870207905769348,0.0504884012043476,0.2931897409197678,43793 -25115.36889028549,9.151697158813477,49478.61374115944,152363,0,49478.61374115944,0.9861586689949036,0.0539308078587055,0.2759371192052875,43793,74607.41888213158,0.9955525994300842,0.0139115117490291,0.7837436456993829,0.9870207905769348,0.0504884012043476,0.2934063355225442,43793 -25226.03780841828,9.207793951034546,49718.60764694214,153102,0,49718.60764694214,0.9861586689949036,0.0539308078587055,0.27590422734712,43793,74958.15905070305,0.9954980611801147,0.0140454312786459,0.7813701511430701,0.9870207905769348,0.0504884012043476,0.2933010124942687,43793 -25330.334392786022,9.263737916946411,49958.73828434944,153849,0,49958.73828434944,0.9861586689949036,0.0539308004081249,0.2758850884604144,43793,75302.66285181046,0.995534360408783,0.0139383375644683,0.7729496941587721,0.9870207905769348,0.0504884012043476,0.2932470890447527,43793 -25436.34400558472,9.328450679779053,50198.92618584633,154588,0,50198.92618584633,0.9861586689949036,0.0539308004081249,0.2760416363120542,43793,75648.9467959404,0.9954615831375122,0.014223264530301,0.7669854386099493,0.9870207905769348,0.0504884012043476,0.2932654833117584,43793 -25543.98120045662,9.3852961063385,50438.856477975845,155327,0,50438.856477975845,0.9861586689949036,0.0539308078587055,0.2759604086981166,43793,75996.59171676636,0.9954792261123656,0.0141309844329953,0.77602726156619,0.9870207905769348,0.0504884012043476,0.2932259288464124,43793 -25650.809689760208,9.441604852676392,50679.07929563522,156065,0,50679.07929563522,0.9861586689949036,0.0539308004081249,0.2759401832193064,43793,76343.71928310394,0.9955069422721864,0.0139641826972365,0.7782063057074599,0.9870207905769348,0.0504884012043476,0.2930974827205313,43793 -25764.367182970047,9.498466730117798,50919.26484918594,156808,0,50919.26484918594,0.9861586689949036,0.0539308078587055,0.2760160452841331,43793,76697.53965449333,0.995521605014801,0.0139957116916775,0.7851263115303592,0.9870207905769348,0.0504884012043476,0.2933636890766024,43793 -25874.278917074203,9.567306756973268,51159.42909312248,157535,0,51159.42909312248,0.9861586689949036,0.0539308004081249,0.2758944622297735,43793,77047.70661783218,0.9954743981361388,0.0140967210754752,0.7758255457820338,0.9870207905769348,0.0504884012043476,0.293116103459514,43793 -25981.39734506607,9.625344038009644,51399.40223193169,158280,0,51399.40223193169,0.9861586689949036,0.0539308004081249,0.275884767550971,43793,77394.87665104866,0.9955391883850098,0.0139821004122495,0.7747015856380461,0.9870207905769348,0.0504884012043476,0.2931608099330871,43793 -26092.07701206208,9.682848691940308,51639.65240359306,159021,0,51639.65240359306,0.9861586689949036,0.0539308078587055,0.2759102206121796,43793,77745.88492846489,0.9954434037208556,0.0142225446179509,0.7661622075147424,0.9870207905769348,0.0504884012043476,0.2931163238204452,43793 -26198.158568143845,9.7459876537323,51879.57717633248,159762,0,51879.57717633248,0.9861586689949036,0.0539308004081249,0.275975848750648,43793,78091.97503495216,0.99553245306015,0.014043060131371,0.7827541050873437,0.9870207905769348,0.0504884012043476,0.293129897049583,43793 -26301.383195638657,9.815279960632324,52119.58177232742,160498,0,52119.58177232742,0.9861586689949036,0.0539308078587055,0.2759441100966021,43793,78435.2933254242,0.995540201663971,0.0139052579179406,0.7825101913603374,0.9870207905769348,0.0504884012043476,0.2931839937265932,43793 -26408.17783999443,9.880241632461548,52359.627802848816,161235,0,52359.627802848816,0.9861586689949036,0.0539308078587055,0.2759489303136601,43793,78782.22266626358,0.9954307675361632,0.0141254318878054,0.7702380539786756,0.9870207905769348,0.0504884012043476,0.2932830918273247,43793 -26514.83008337021,9.939560890197754,52599.55924654007,161978,0,52599.55924654007,0.9861586689949036,0.0539308004081249,0.275967983978321,43793,79128.88596534729,0.9955178499221802,0.0140783488750457,0.778865225406159,0.9870207905769348,0.0504884012043476,0.2931667231247588,43793 -26619.158552885056,9.997602939605711,52839.79007482529,162719,0,52839.79007482529,0.9861586689949036,0.0539308078587055,0.2760785878656386,43793,79473.52391839027,0.9954817891120912,0.014129121787846,0.7618262995083183,0.9870207905769348,0.0504884012043476,0.2932867883896753,43793 -26726.71019911766,10.055290699005129,53079.76247572899,163454,0,53079.76247572899,0.9861586689949036,0.0539308078587055,0.275908853560883,43793,79821.12660717964,0.9954960942268372,0.0141142038628458,0.7793173555613062,0.9870207905769348,0.0504884012043476,0.2932046730429082,43793 -26830.27416396141,10.113618850708008,53319.88084149361,164196,0,53319.88084149361,0.9861586689949036,0.0539308004081249,0.2760912667040872,43793,80164.88756608963,0.9955257773399352,0.0139328949153423,0.7799171908740734,0.9870207905769348,0.0504884012043476,0.2933241683517332,43793 -26933.934188604355,10.171966552734377,53559.96479272842,164940,0,53559.96479272842,0.9861586689949036,0.0539308078587055,0.2759160436130984,43793,80508.71048593521,0.9955264925956726,0.0139749469235539,0.781608569478042,0.9870207905769348,0.0504884012043476,0.2930783591027502,43793 -27041.776702404022,10.230546236038208,53800.09856343269,165683,0,53800.09856343269,0.9861586689949036,0.0539308004081249,0.2758898549006813,43793,80856.76532030106,0.9954779148101808,0.0141129968687891,0.7717062319651007,0.9870207905769348,0.0504884012043476,0.2932151454680108,43793 -27149.88500189781,10.289138793945312,54040.2394015789,166430,0,54040.2394015789,0.9861586689949036,0.0539308004081249,0.2759094356033681,43793,81205.09357523918,0.9955308437347412,0.013994694687426,0.7687587064128196,0.9870207905769348,0.0504884012043476,0.2931156943476471,43793 -27258.364639759064,10.35668396949768,54280.27514696121,167172,0,54280.27514696121,0.9861586689949036,0.0539308078587055,0.276066158956547,43793,81553.69691586494,0.9954442977905272,0.0142495296895504,0.7769547425832453,0.9870207905769348,0.0504884012043476,0.2930756406702172,43793 -27361.902376651764,10.414924383163452,54520.21311235428,167916,0,54520.21311235428,0.9861586689949036,0.0539308078587055,0.2759194664446954,43793,81897.25141072273,0.9955223202705384,0.0139664039015769,0.778954704864464,0.9870207905769348,0.0504884012043476,0.2930991619901983,43793 -27468.845179080963,10.472869873046877,54760.16240620613,168667,0,54760.16240620613,0.9861586689949036,0.0539308078587055,0.2760429089668513,43793,82244.221940279,0.9955002069473268,0.0140024097636342,0.781999000653092,0.9870207905769348,0.0504884012043476,0.2931762744871093,43793 -27570.706958293915,10.529196977615356,55000.32378101349,169410,0,55000.32378101349,0.9861586689949036,0.0539308078587055,0.2759854657923026,43793,82586.32157897949,0.995485544204712,0.0140526164323091,0.7778016832839559,0.9870207905769348,0.0504884012043476,0.2932229489257906,43793 -27675.19109106064,10.588984966278076,55240.26776766777,170162,0,55240.26776766777,0.9861586689949036,0.05393080785870552,0.27592051244043825,43793,82930.82968711853,0.9955334067344666,0.013999897986650467,0.768694405169295,0.9870207905769348,0.05048840120434761,0.2932143589528342,43793 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index 98209a29a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1941 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.935929,0.7161926,,,,,,,,,,,,,,,,, -1,,,0.5404419302940369,0.7089899182319641,0.022265429341623,0.5404320955276489,0.7129108905792236,0.0261747405043297,43793.0,0.5401853322982788,0.7146537899971008,0.0280191991350671,43793.0,19.528467893600464,549.2898893356323,19.528467893600464,529.7613813877106,0.0,0.0 -100,0.29004562,0.269423,,,,,,,,,,,,,,,,, -200,0.09319386,0.10455389,,,,,,,,,,,,,,,,, -300,0.029113209,0.067329705,,,,,,,,,,,,,,,,, -400,0.024603032,0.060855325,,,,,,,,,,,,,,,,, -500,0.015349899,0.05241393,,,,,,,,,,,,,,,,, -600,0.016219355,0.057228573,,,,,,,,,,,,,,,,, -700,0.015388397,0.051509432,,,,,,,,,,,,,,,,, -748,,,0.986663281917572,0.0516927912831306,0.0518440765289784,0.9842145442962646,0.0608852803707122,0.0512608568202372,43793.0,0.983219563961029,0.0641781017184257,0.0540834326521347,43793.0,259.61792516708374,907.8264088630676,259.61792516708374,648.1611218452454,0.0277159214019775,0.0 -800,0.019057063,0.053982712,,,,,,,,,,,,,,,,, -900,0.030707862,0.05276609,,,,,,,,,,,,,,,,, -1000,0.025111223,0.0469112,,,,,,,,,,,,,,,,, -1100,0.02290161,0.048773997,,,,,,,,,,,,,,,,, -1200,0.025079237,0.053601976,,,,,,,,,,,,,,,,, -1300,0.018196229,0.048621666,,,,,,,,,,,,,,,,, -1400,0.025427088,0.045936167,,,,,,,,,,,,,,,,, -1491,,,0.9875317811965942,0.0461782217025756,0.102587681895728,0.9847463369369508,0.0551133230328559,0.1048863280755017,43793.0,0.9837393164634703,0.0582070387899875,0.1037207408094717,43793.0,499.6568791866303,1269.7700910568235,499.6568791866303,770.0165569782257,0.0568201541900634,0.0 -1500,0.023155827,0.04535239,,,,,,,,,,,,,,,,, -1600,0.02740579,0.04540094,,,,,,,,,,,,,,,,, -1700,0.027410038,0.046357516,,,,,,,,,,,,,,,,, -1800,0.027165052,0.045217454,,,,,,,,,,,,,,,,, -1900,0.031123327,0.04650807,,,,,,,,,,,,,,,,, -2000,0.017809376,0.041228324,,,,,,,,,,,,,,,,, -2100,0.01571742,0.048328046,,,,,,,,,,,,,,,,, -2200,0.021313375,0.044131104,,,,,,,,,,,,,,,,, -2235,,,0.9879155158996582,0.043444886803627,0.1422481736933233,0.9850674271583556,0.0531889759004116,0.1327050623244102,43793.0,0.984077513217926,0.0562639608979225,0.1320783047225885,43793.0,739.8452451229095,1635.7016966342926,739.8452451229095,895.7120683193207,0.084691047668457,0.0 -2300,0.020433161,0.042580914,,,,,,,,,,,,,,,,, -2400,0.019155975,0.046994254,,,,,,,,,,,,,,,,, -2500,0.0098871235,0.04038045,,,,,,,,,,,,,,,,, -2600,0.016372355,0.040667374,,,,,,,,,,,,,,,,, -2700,0.014405611,0.040157195,,,,,,,,,,,,,,,,, -2800,0.0125482585,0.045218352,,,,,,,,,,,,,,,,, -2900,0.013567665,0.041076604,,,,,,,,,,,,,,,,, -2973,,,0.9879909753799438,0.0420600734651088,0.1689640605027597,0.985255002975464,0.0514483675360679,0.1546322777983081,43793.0,0.9843037128448486,0.0542911104857921,0.1567664400875005,43793.0,980.0934472084044,2000.018956899643,980.0934472084044,1019.7300884723664,0.1144242286682128,0.0 -3000,0.016450172,0.045963656,,,,,,,,,,,,,,,,, -3100,0.014151948,0.04252089,,,,,,,,,,,,,,,,, -3200,0.033646528,0.044156488,,,,,,,,,,,,,,,,, -3300,0.015774481,0.040520515,,,,,,,,,,,,,,,,, -3400,0.017164607,0.04189171,,,,,,,,,,,,,,,,, -3500,0.01299017,0.037375517,,,,,,,,,,,,,,,,, -3600,0.010679212,0.041765414,,,,,,,,,,,,,,,,, -3700,0.02255251,0.04119892,,,,,,,,,,,,,,,,, -3716,,,0.9885411262512208,0.0397534817457199,0.2006911287897099,0.9854920506477356,0.0495907105505466,0.1726104452563086,43793.0,0.9846242666244508,0.0522505082190036,0.1775325516586747,43793.0,1220.1030368804932,2366.63014125824,1220.1030368804932,1146.2826611995697,0.1429240703582763,0.0 -3800,0.012506285,0.03905615,,,,,,,,,,,,,,,,, -3900,0.012748895,0.043136325,,,,,,,,,,,,,,,,, -4000,0.01781841,0.039250713,,,,,,,,,,,,,,,,, -4100,0.012371403,0.03854257,,,,,,,,,,,,,,,,, -4200,0.016505316,0.043073,,,,,,,,,,,,,,,,, -4300,0.009317643,0.041271474,,,,,,,,,,,,,,,,, -4400,0.009673785,0.037105866,,,,,,,,,,,,,,,,, -4455,,,0.988589644432068,0.0388961508870124,0.2133832001783586,0.9857372045516968,0.0483437851071357,0.1880250188182171,43793.0,0.9848222136497498,0.0510275997221469,0.189843762419542,43793.0,1460.2951698303225,2733.823101758957,1460.2951698303225,1273.235831975937,0.1699335575103759,0.0 -4500,0.01303669,0.03883566,,,,,,,,,,,,,,,,, -4600,0.014362093,0.03751414,,,,,,,,,,,,,,,,, -4700,0.014330367,0.04137805,,,,,,,,,,,,,,,,, -4800,0.043370858,0.037923172,,,,,,,,,,,,,,,,, -4900,0.024038084,0.038431372,,,,,,,,,,,,,,,,, -5000,0.016242448,0.040203754,,,,,,,,,,,,,,,,, -5100,0.024200907,0.040941328,,,,,,,,,,,,,,,,, -5192,,,0.9890533089637756,0.0369926616549491,0.253568965894043,0.9858285784721376,0.0473449490964412,0.203995115895378,43793.0,0.985047161579132,0.0498691536486148,0.2044275570064991,43793.0,1700.3518633842468,3102.633382081985,1700.3518633842468,1401.9401717185974,0.1983261108398437,0.0 -5200,0.00979402,0.03704553,,,,,,,,,,,,,,,,, -5300,0.011190268,0.04043368,,,,,,,,,,,,,,,,, -5400,0.013759306,0.038912404,,,,,,,,,,,,,,,,, -5500,0.013374401,0.040721014,,,,,,,,,,,,,,,,, -5600,0.013094854,0.04153103,,,,,,,,,,,,,,,,, -5700,0.02222943,0.037934836,,,,,,,,,,,,,,,,, -5800,0.014440475,0.038917813,,,,,,,,,,,,,,,,, -5900,0.019353192,0.03785087,,,,,,,,,,,,,,,,, -5927,,,0.9893468022346495,0.0361139774322509,0.2782895930164883,0.9861021637916564,0.0466593876481056,0.2093368453462271,43793.0,0.9852859377861024,0.0491692684590816,0.2194756435087781,43793.0,1940.53954744339,3468.713948726654,1940.53954744339,1527.7839286327362,0.2261953353881836,0.0 -6000,0.027475365,0.03794378,,,,,,,,,,,,,,,,, -6100,0.011654961,0.035214093,,,,,,,,,,,,,,,,, -6200,0.011190327,0.03591469,,,,,,,,,,,,,,,,, -6300,0.02081115,0.044328213,,,,,,,,,,,,,,,,, -6400,0.01834914,0.040679507,,,,,,,,,,,,,,,,, -6500,0.019279145,0.03811389,,,,,,,,,,,,,,,,, -6600,0.013127054,0.034668576,,,,,,,,,,,,,,,,, -6660,,,0.9893681406974792,0.0361036546528339,0.2615927564089334,0.986129343509674,0.0465039052069187,0.2194878491913071,43793.0,0.9852400422096252,0.0492401048541069,0.2220660015979791,43793.0,2180.5288002491,3833.375732898712,2180.5288002491,1652.4074277877808,0.2549798488616943,0.0 -6700,0.021764616,0.035035413,,,,,,,,,,,,,,,,, -6800,0.013218509,0.03757256,,,,,,,,,,,,,,,,, -6900,0.013352403,0.03537241,,,,,,,,,,,,,,,,, -7000,0.014349796,0.03339148,,,,,,,,,,,,,,,,, -7100,0.015666094,0.03923815,,,,,,,,,,,,,,,,, -7200,0.025414374,0.0336059,,,,,,,,,,,,,,,,, -7300,0.013976642,0.0363739,,,,,,,,,,,,,,,,, -7396,,,0.9895302057266236,0.0353137105703353,0.3048981406299598,0.9863741397857666,0.0459171384572982,0.2346207459164243,43793.0,0.9854978322982788,0.0484322234988212,0.231361468160634,43793.0,2420.737658262253,4202.670556783676,2420.737658262253,1781.4448142051697,0.283771276473999,0.0 -7400,0.025741352,0.03414412,,,,,,,,,,,,,,,,, -7500,0.017210107,0.038309734,,,,,,,,,,,,,,,,, -7600,0.017336916,0.03254927,,,,,,,,,,,,,,,,, -7700,0.019357823,0.03667493,,,,,,,,,,,,,,,,, -7800,0.014957935,0.038446393,,,,,,,,,,,,,,,,, -7900,0.019791335,0.032854848,,,,,,,,,,,,,,,,, -8000,0.03131292,0.032845452,,,,,,,,,,,,,,,,, -8100,0.016134337,0.037194192,,,,,,,,,,,,,,,,, -8119,,,0.9896426796913148,0.0347865670919418,0.297149318468285,0.9862799644470216,0.0459003075957298,0.2356317973826318,43793.0,0.9853870272636414,0.0484150499105453,0.2312606216910576,43793.0,2660.835664749145,4571.243372917175,2660.835664749145,1909.870851516724,0.3127274513244629,0.0 -8200,0.025885386,0.036696993,,,,,,,,,,,,,,,,, -8300,0.017100574,0.03460853,,,,,,,,,,,,,,,,, -8400,0.020738678,0.03501413,,,,,,,,,,,,,,,,, -8500,0.017057914,0.038600676,,,,,,,,,,,,,,,,, -8600,0.013382739,0.03126877,,,,,,,,,,,,,,,,, -8700,0.024147155,0.03762609,,,,,,,,,,,,,,,,, -8800,0.02706047,0.03234538,,,,,,,,,,,,,,,,, -8871,,,0.9896151423454284,0.0344796739518642,0.3239638788346503,0.9862819910049438,0.0454410649836063,0.2402566717270283,43793.0,0.985410213470459,0.0479228235781192,0.2439316950400646,43793.0,2900.8369381427765,4942.084951400757,2900.8369381427765,2040.66327047348,0.3399653434753418,0.0 -8900,0.015290289,0.03988706,,,,,,,,,,,,,,,,, -9000,0.023409266,0.0397412,,,,,,,,,,,,,,,,, -9100,0.025558358,0.03584542,,,,,,,,,,,,,,,,, -9200,0.016747074,0.038298007,,,,,,,,,,,,,,,,, -9300,0.015229834,0.030830199,,,,,,,,,,,,,,,,, -9400,0.02401642,0.039808545,,,,,,,,,,,,,,,,, -9500,0.022371061,0.03492017,,,,,,,,,,,,,,,,, -9600,0.02067722,0.03521614,,,,,,,,,,,,,,,,, -9603,,,0.9900298118591307,0.0330846048891544,0.339935185024856,0.9865061044692992,0.0452296324074268,0.2436097289388764,43793.0,0.9856237769126892,0.0479062609374523,0.2502284643444161,43793.0,3141.071623802185,5310.628950357437,3141.071623802185,2168.9167437553406,0.3717496395111084,0.0 -9700,0.01895095,0.03158784,,,,,,,,,,,,,,,,, -9800,0.02510164,0.03382033,,,,,,,,,,,,,,,,, -9900,0.023971016,0.038149774,,,,,,,,,,,,,,,,, -10000,0.019171633,0.03703318,,,,,,,,,,,,,,,,, -10100,0.028101366,0.036519922,,,,,,,,,,,,,,,,, -10200,0.019744802,0.03619929,,,,,,,,,,,,,,,,, -10300,0.019149227,0.03390954,,,,,,,,,,,,,,,,, -10355,,,0.9901762008666992,0.0326098129153251,0.3561634795300548,0.9865044355392456,0.0447402521967887,0.2490307962152749,43793.0,0.985651969909668,0.0473545342683792,0.246920308466755,43793.0,3381.203540802002,5676.491842985153,3381.203540802002,2294.598527908325,0.4007346630096435,0.0 -10400,0.027776208,0.036848057,,,,,,,,,,,,,,,,, -10500,0.030157654,0.033844948,,,,,,,,,,,,,,,,, -10600,0.032471005,0.03262158,,,,,,,,,,,,,,,,, -10700,0.020681463,0.03160123,,,,,,,,,,,,,,,,, -10800,0.027629972,0.03246709,,,,,,,,,,,,,,,,, -10900,0.03264613,0.034173,,,,,,,,,,,,,,,,, -11000,0.024420366,0.031970575,,,,,,,,,,,,,,,,, -11100,0.023142273,0.030603787,,,,,,,,,,,,,,,,, -11106,,,0.9905160665512084,0.0314555391669273,0.3813649471138655,0.9866416454315186,0.0444622375071048,0.2543336478332198,43793.0,0.9858317971229552,0.0472535192966461,0.2559278961595786,43793.0,3621.350810050965,6046.970125198364,3621.350810050965,2424.878109931946,0.4318912029266357,0.0 -11200,0.024388907,0.03318196,,,,,,,,,,,,,,,,, -11300,0.026726855,0.032684084,,,,,,,,,,,,,,,,, -11400,0.02189619,0.034557786,,,,,,,,,,,,,,,,, -11500,0.022791935,0.03854434,,,,,,,,,,,,,,,,, -11600,0.028832385,0.032625455,,,,,,,,,,,,,,,,, -11700,0.029865952,0.03238131,,,,,,,,,,,,,,,,, -11800,0.030214794,0.034059085,,,,,,,,,,,,,,,,, -11838,,,0.9904680848121644,0.0314260125160217,0.3899980857227915,0.9865442514419556,0.0448398254811763,0.2530719494491042,43793.0,0.985789716243744,0.0473696626722812,0.2570963040451753,43793.0,3861.3633439540863,6415.400445222855,3861.3633439540863,2553.246810436249,0.4600224494934082,0.0 -11900,0.032576956,0.03879125,,,,,,,,,,,,,,,,, -12000,0.033996306,0.03409137,,,,,,,,,,,,,,,,, -12100,0.027409406,0.035666022,,,,,,,,,,,,,,,,, -12200,0.022771666,0.031614643,,,,,,,,,,,,,,,,, -12300,0.03297316,0.034721386,,,,,,,,,,,,,,,,, -12400,0.0359084,0.03384205,,,,,,,,,,,,,,,,, -12500,0.024377495,0.031168357,,,,,,,,,,,,,,,,, -12583,,,0.990330457687378,0.0318564511835575,0.3694480247887148,0.9867208003997804,0.0444731898605823,0.259595101609442,43793.0,0.9858494997024536,0.0471902675926685,0.2632409372541082,43793.0,4101.431351184845,6786.566886663437,4101.431351184845,2684.271157026291,0.5138566493988037,0.0 -12600,0.024767429,0.031674214,,,,,,,,,,,,,,,,, -12700,0.028316097,0.032548845,,,,,,,,,,,,,,,,, -12800,0.033012412,0.029769467,,,,,,,,,,,,,,,,, -12900,0.032952506,0.03484405,,,,,,,,,,,,,,,,, -13000,0.031256586,0.03321953,,,,,,,,,,,,,,,,, -13100,0.02735039,0.032208286,,,,,,,,,,,,,,,,, -13200,0.03132951,0.03483708,,,,,,,,,,,,,,,,, -13300,0.023765642,0.031654835,,,,,,,,,,,,,,,,, -13332,,,0.990576148033142,0.0312654264271259,0.3798342046666803,0.9867805242538452,0.0439825765788555,0.2672474720559514,43793.0,0.9859135150909424,0.0467578880488872,0.2592363877832236,43793.0,4341.640905618668,7157.587838172913,4341.640905618668,2815.033797264099,0.5426039695739746,0.0 -13400,0.026338926,0.03264628,,,,,,,,,,,,,,,,, -13500,0.034086548,0.036550924,,,,,,,,,,,,,,,,, -13600,0.03389835,0.03427552,,,,,,,,,,,,,,,,, -13700,0.029192507,0.033601344,,,,,,,,,,,,,,,,, -13800,0.033005446,0.03410873,,,,,,,,,,,,,,,,, -13900,0.042147644,0.03788698,,,,,,,,,,,,,,,,, -14000,0.039954357,0.036555413,,,,,,,,,,,,,,,,, -14079,,,0.9905685782432556,0.0311376955360174,0.3765914331016559,0.9868125915527344,0.0441166646778583,0.2629989339221388,43793.0,0.985912263393402,0.0469102598726749,0.2601537864182929,43793.0,4581.922067165375,7525.94166302681,4581.922067165375,2943.0571522712708,0.571929931640625,0.0 -14100,0.025190132,0.030771088,,,,,,,,,,,,,,,,, -14200,0.025300452,0.030971674,,,,,,,,,,,,,,,,, -14300,0.038632147,0.03448578,,,,,,,,,,,,,,,,, -14400,0.04359469,0.03107058,,,,,,,,,,,,,,,,, -14500,0.036630135,0.03319159,,,,,,,,,,,,,,,,, -14600,0.033422187,0.028688021,,,,,,,,,,,,,,,,, -14700,0.031535815,0.03423169,,,,,,,,,,,,,,,,, -14800,0.03245003,0.03082989,,,,,,,,,,,,,,,,, -14815,,,0.990611970424652,0.0308564845472574,0.3953191977620491,0.9868409633636476,0.0441383384168148,0.2656919440829268,43793.0,0.9859982132911682,0.0468400716781616,0.2640174170571049,43793.0,4822.095349788666,7896.865571737289,4822.095349788666,3073.758298397064,0.6007485389709473,0.0 -14900,0.040700916,0.03372019,,,,,,,,,,,,,,,,, -15000,0.03267021,0.030057313,,,,,,,,,,,,,,,,, -15100,0.04350272,0.03208989,,,,,,,,,,,,,,,,, -15200,0.032311775,0.033207897,,,,,,,,,,,,,,,,, -15300,0.036105532,0.033855624,,,,,,,,,,,,,,,,, -15400,0.04298045,0.03187805,,,,,,,,,,,,,,,,, -15500,0.03667267,0.03711731,,,,,,,,,,,,,,,,, -15547,,,0.99070805311203,0.0302942767739295,0.4081647168737971,0.9868353009223938,0.0444324798882007,0.261461695531231,43793.0,0.9859632253646852,0.0472245104610919,0.260707375649573,43793.0,5062.213956356049,8270.714021921158,5062.213956356049,3207.43451666832,0.6317923069000244,0.0 -15600,0.03325095,0.029436318,,,,,,,,,,,,,,,,, -15700,0.03374705,0.031663846,,,,,,,,,,,,,,,,, -15800,0.034191996,0.032522716,,,,,,,,,,,,,,,,, -15900,0.034239687,0.030908678,,,,,,,,,,,,,,,,, -16000,0.03974965,0.032689694,,,,,,,,,,,,,,,,, -16100,0.033296827,0.030003633,,,,,,,,,,,,,,,,, -16200,0.03812378,0.028327327,,,,,,,,,,,,,,,,, -16279,,,0.9908111095428468,0.0298472344875335,0.415534257377386,0.9867788553237916,0.0441870130598545,0.2656293101641374,43793.0,0.9859278798103333,0.0468454733490943,0.2631692741851291,43793.0,5302.277789592743,8637.550012588501,5302.277789592743,3334.1568129062653,0.661334753036499,0.0 -16300,0.040296797,0.03146144,,,,,,,,,,,,,,,,, -16400,0.037306793,0.030487925,,,,,,,,,,,,,,,,, -16500,0.03370778,0.02987373,,,,,,,,,,,,,,,,, -16600,0.042522285,0.030647956,,,,,,,,,,,,,,,,, -16700,0.052015103,0.032918017,,,,,,,,,,,,,,,,, -16800,0.035054304,0.028797664,,,,,,,,,,,,,,,,, -16900,0.03430894,0.029643219,,,,,,,,,,,,,,,,, -17000,0.037794232,0.03345768,,,,,,,,,,,,,,,,, -17025,,,0.9908833503723145,0.0295868404209613,0.4339656447805606,0.9867642521858216,0.0441272146999836,0.2659126980593257,43793.0,0.9859012961387634,0.0468287318944931,0.2625971975460897,43793.0,5542.288363218308,9006.121246814728,5542.288363218308,3462.667748451233,0.6905078887939453,0.0 -17100,0.0355089,0.02924295,,,,,,,,,,,,,,,,, -17200,0.04458476,0.030386714,,,,,,,,,,,,,,,,, -17300,0.034619972,0.027370539,,,,,,,,,,,,,,,,, -17400,0.036591064,0.032257363,,,,,,,,,,,,,,,,, -17500,0.03944034,0.03372431,,,,,,,,,,,,,,,,, -17600,0.042890333,0.03177414,,,,,,,,,,,,,,,,, -17700,0.057873785,0.031223843,,,,,,,,,,,,,,,,, -17766,,,0.9913222193717957,0.0284822024405002,0.4490868252663456,0.9868288040161132,0.0438111871480941,0.2682480897307769,43793.0,0.9859914779663086,0.046521496027708,0.2605604429493556,43793.0,5782.3681898117065,9374.288189411163,5782.3681898117065,3590.7058432102203,0.7197377681732178,0.0 -17800,0.06254016,0.03399284,,,,,,,,,,,,,,,,, -17900,0.077010274,0.032145146,,,,,,,,,,,,,,,,, -18000,0.051091697,0.03251256,,,,,,,,,,,,,,,,, -18100,0.040515397,0.03118311,,,,,,,,,,,,,,,,, -18200,0.047503937,0.033468455,,,,,,,,,,,,,,,,, -18300,0.0421842,0.029776841,,,,,,,,,,,,,,,,, -18400,0.04490624,0.029954936,,,,,,,,,,,,,,,,, -18490,,,0.9911221265792848,0.0291727930307388,0.4280779237068718,0.9867683053016664,0.0443317145109176,0.2622161551066213,43793.0,0.986002802848816,0.0469948165118694,0.26194694370496,43793.0,6022.423826932907,9742.443471431732,6022.423826932907,3718.7504889965057,0.7523410320281982,0.0 -18500,0.04626356,0.031745873,,,,,,,,,,,,,,,,, -18600,0.045235477,0.027028503,,,,,,,,,,,,,,,,, -18700,0.21340585,0.033934183,,,,,,,,,,,,,,,,, -18800,0.042629585,0.029820455,,,,,,,,,,,,,,,,, -18900,0.038912058,0.0302553,,,,,,,,,,,,,,,,, -19000,0.046626803,0.028233815,,,,,,,,,,,,,,,,, -19100,0.036281314,0.028949855,,,,,,,,,,,,,,,,, -19200,0.03540264,0.028225552,,,,,,,,,,,,,,,,, -19233,,,0.9910096526145936,0.0293345991522073,0.4320504665074334,0.986805260181427,0.0439628921449184,0.2754075840554567,43793.0,0.9859328866004944,0.0467655733227729,0.262611156496814,43793.0,6262.385018110275,10111.189734220505,6262.385018110275,3847.483499765396,0.7836909294128418,0.0 -19300,0.036918987,0.03116872,,,,,,,,,,,,,,,,, -19400,0.04082407,0.025486095,,,,,,,,,,,,,,,,, -19500,0.04446175,0.03091943,,,,,,,,,,,,,,,,, -19600,0.04128679,0.029153038,,,,,,,,,,,,,,,,, -19700,0.04966513,0.032565974,,,,,,,,,,,,,,,,, -19800,0.04677715,0.030360544,,,,,,,,,,,,,,,,, -19900,0.04704949,0.029475693,,,,,,,,,,,,,,,,, -19976,,,0.9907341599464417,0.0300019308924675,0.4278425982866988,0.986909568309784,0.0444412007927894,0.2761037932319848,43793.0,0.9860508441925048,0.0472310557961463,0.2681115070087855,43793.0,6502.49117732048,10479.550280809402,6502.49117732048,3975.6875672340393,0.813539981842041,0.0 -20000,0.04954091,0.030219419,,,,,,,,,,,,,,,,, -20100,0.042683013,0.028130107,,,,,,,,,,,,,,,,, -20200,0.04398751,0.031536683,,,,,,,,,,,,,,,,, -20300,0.042661287,0.029495107,,,,,,,,,,,,,,,,, -20400,0.04752557,0.033366222,,,,,,,,,,,,,,,,, -20500,0.05252185,0.031610753,,,,,,,,,,,,,,,,, -20600,0.041580394,0.031732358,,,,,,,,,,,,,,,,, -20700,0.050435275,0.03278649,,,,,,,,,,,,,,,,, -20722,,,0.9907103776931764,0.0302657280117273,0.4083187747450125,0.9867829084396362,0.0442190542817115,0.2743305827789721,43793.0,0.9859228134155272,0.0470405109226703,0.2631733995044497,43793.0,6742.673607110977,10847.032025575638,6742.673607110977,4102.936810493469,0.8434133529663086,0.0 -20800,0.041655615,0.030670673,,,,,,,,,,,,,,,,, -20900,0.04534166,0.030079689,,,,,,,,,,,,,,,,, -21000,0.051598158,0.036390297,,,,,,,,,,,,,,,,, -21100,0.041901916,0.02955124,,,,,,,,,,,,,,,,, -21200,0.053525917,0.032073148,,,,,,,,,,,,,,,,, -21300,0.042252965,0.03151238,,,,,,,,,,,,,,,,, -21400,0.04158563,0.029603858,,,,,,,,,,,,,,,,, -21465,,,0.9911373853683472,0.028706619516015,0.4494239762717231,0.9869254231452942,0.0439803414046764,0.2773801293033675,43793.0,0.9860706329345704,0.0468168444931507,0.2728187773762324,43793.0,6982.719746828079,11216.410384893416,6982.719746828079,4232.218222379684,0.8738067150115967,0.0 -21500,0.055863403,0.028955359,,,,,,,,,,,,,,,,, -21600,0.047910918,0.030743953,,,,,,,,,,,,,,,,, -21700,0.044716094,0.032343008,,,,,,,,,,,,,,,,, -21800,0.043393437,0.029290602,,,,,,,,,,,,,,,,, -21900,0.050086528,0.033292964,,,,,,,,,,,,,,,,, -22000,0.043737646,0.032098524,,,,,,,,,,,,,,,,, -22100,0.04504176,0.02970301,,,,,,,,,,,,,,,,, -22200,0.0509573,0.0310639,,,,,,,,,,,,,,,,, -22210,,,0.9912395477294922,0.0284947175532579,0.4497181502557184,0.9868494868278505,0.0440571866929531,0.2787659404410005,43793.0,0.985990583896637,0.0468807965517044,0.2618272747604446,43793.0,7222.9662935733795,11582.311628341677,7222.9662935733795,4357.820563793182,0.9055871963500975,0.0 -22300,0.051298816,0.03333797,,,,,,,,,,,,,,,,, -22400,0.04274308,0.029037599,,,,,,,,,,,,,,,,, -22500,0.068019606,0.030020356,,,,,,,,,,,,,,,,, -22600,0.04483852,0.030335778,,,,,,,,,,,,,,,,, -22700,0.048870433,0.03299718,,,,,,,,,,,,,,,,, -22800,0.05417491,0.028352626,,,,,,,,,,,,,,,,, -22900,0.04359511,0.027395515,,,,,,,,,,,,,,,,, -22951,,,0.9914074540138244,0.0279013980180025,0.4588849204040839,0.9868494868278505,0.0440556108951568,0.2763261030184993,43793.0,0.9859901666641236,0.0469595305621624,0.2666838640350755,43793.0,7462.987103223801,11952.801618337631,7462.987103223801,4488.238070964813,0.937244176864624,0.0 -23000,0.05593955,0.029659754,,,,,,,,,,,,,,,,, -23100,0.051621288,0.030849474,,,,,,,,,,,,,,,,, -23200,0.0536997,0.030434025,,,,,,,,,,,,,,,,, -23300,0.044088464,0.029269358,,,,,,,,,,,,,,,,, -23400,0.05549231,0.03432328,,,,,,,,,,,,,,,,, -23500,0.051903646,0.034379,,,,,,,,,,,,,,,,, -23600,0.06155432,0.028666224,,,,,,,,,,,,,,,,, -23693,,,0.9915399551391602,0.0275176540017128,0.4827089102235441,0.9867687225341796,0.0440514981746673,0.2806043098786221,43793.0,0.9859842658042908,0.0466601550579071,0.265365051821998,43793.0,7703.116607427597,12321.216604471208,7703.116607427597,4616.471472978592,0.9694650173187256,0.0 -23700,0.04763127,0.029349126,,,,,,,,,,,,,,,,, -23800,0.04596472,0.028745016,,,,,,,,,,,,,,,,, -23900,0.052728612,0.02919948,,,,,,,,,,,,,,,,, -24000,0.072198875,0.029859226,,,,,,,,,,,,,,,,, -24100,0.04240445,0.025829569,,,,,,,,,,,,,,,,, -24200,0.053400073,0.028441755,,,,,,,,,,,,,,,,, -24300,0.055958357,0.03137128,,,,,,,,,,,,,,,,, -24400,0.0631064,0.032014214,,,,,,,,,,,,,,,,, -24437,,,0.9916155338287354,0.027268037199974,0.4807372090344972,0.9868682026863098,0.0439805127680301,0.2749072176084976,43793.0,0.9860904216766356,0.0465195775032043,0.2728047702282302,43793.0,7943.241847038269,12691.554522037506,7943.241847038269,4746.633130073547,1.000164270401001,0.0 -24500,0.0657092,0.034512725,,,,,,,,,,,,,,,,, -24600,0.047695834,0.028638048,,,,,,,,,,,,,,,,, -24700,0.043971866,0.029528145,,,,,,,,,,,,,,,,, -24800,0.056585494,0.029504199,,,,,,,,,,,,,,,,, -24900,0.058418717,0.02631937,,,,,,,,,,,,,,,,, -25000,0.052663103,0.030880382,,,,,,,,,,,,,,,,, -25100,0.046406325,0.028392853,,,,,,,,,,,,,,,,, -25167,,,0.9914122819900512,0.0279057696461677,0.4571382522024744,0.9869388341903688,0.0440827757120132,0.280402054904702,43793.0,0.9861481189727784,0.0468418523669242,0.2706147259221962,43793.0,8183.310019493103,13062.866502285004,8183.310019493103,4877.8257603645325,1.0311439037322998,0.0 -25200,0.04717,0.029869227,,,,,,,,,,,,,,,,, -25300,0.057393238,0.028546903,,,,,,,,,,,,,,,,, -25400,0.046670016,0.028178867,,,,,,,,,,,,,,,,, -25500,0.05709975,0.033321634,,,,,,,,,,,,,,,,, -25600,0.053057816,0.030894674,,,,,,,,,,,,,,,,, -25700,0.05011387,0.029379437,,,,,,,,,,,,,,,,, -25800,0.05327356,0.028592108,,,,,,,,,,,,,,,,, -25896,,,0.991398811340332,0.0279161781072616,0.4598515014503374,0.9869047403335572,0.0441686175763607,0.2727832495909499,43793.0,0.9862020611763,0.0467101261019706,0.2640981604920295,43793.0,8423.271643161774,13429.19395661354,8423.271643161774,5004.138481140137,1.0636134147644043,0.0 -25900,0.07289481,0.040306028,,,,,,,,,,,,,,,,, -26000,0.056462083,0.030923313,,,,,,,,,,,,,,,,, -26100,0.05488643,0.034841876,,,,,,,,,,,,,,,,, -26200,0.049558412,0.029322365,,,,,,,,,,,,,,,,, -26300,0.06411981,0.031814247,,,,,,,,,,,,,,,,, -26400,0.054951303,0.030687457,,,,,,,,,,,,,,,,, -26500,0.06523062,0.031336475,,,,,,,,,,,,,,,,, -26600,0.058304287,0.030338263,,,,,,,,,,,,,,,,, -26634,,,0.9914059638977052,0.0278588086366653,0.4804148735269029,0.987074375152588,0.0439664088189601,0.2818839899912088,43793.0,0.9862193465232848,0.0467636249959468,0.2770079309687364,43793.0,8663.474416017532,13798.188932180405,8663.474416017532,5132.878929376602,1.0949442386627195,0.0 -26700,0.06420899,0.031530213,,,,,,,,,,,,,,,,, -26800,0.05739329,0.029014437,,,,,,,,,,,,,,,,, -26900,0.057577398,0.027702885,,,,,,,,,,,,,,,,, -27000,0.05751229,0.02830501,,,,,,,,,,,,,,,,, -27100,0.06942106,0.02680235,,,,,,,,,,,,,,,,, -27200,0.055025037,0.029423315,,,,,,,,,,,,,,,,, -27300,0.05149404,0.030018523,,,,,,,,,,,,,,,,, -27370,,,0.9915772080421448,0.0271905399858951,0.4743581862163639,0.9870736002922058,0.0438411273062229,0.2876096184277866,43793.0,0.9862067103385924,0.0468411184847354,0.2719546221242722,43793.0,8903.555037021637,14164.301958322523,8903.555037021637,5258.856852293015,1.128938913345337,0.0 -27400,0.0607777,0.03065319,,,,,,,,,,,,,,,,, -27500,0.06109905,0.03221721,,,,,,,,,,,,,,,,, -27600,0.05586348,0.03006329,,,,,,,,,,,,,,,,, -27700,0.062019046,0.029036514,,,,,,,,,,,,,,,,, -27800,0.05561531,0.029795107,,,,,,,,,,,,,,,,, -27900,0.05075626,0.027041767,,,,,,,,,,,,,,,,, -28000,0.05069806,0.02971394,,,,,,,,,,,,,,,,, -28100,0.063261166,0.028043628,,,,,,,,,,,,,,,,, -28112,,,0.9914939999580384,0.0273632034659385,0.4776075046106587,0.9868556261062622,0.0443020462989807,0.2781089525019385,43793.0,0.9860523343086244,0.0470948740839958,0.2704405059868041,43793.0,9143.596687078476,14534.617344141006,9143.596687078476,5389.0773758888245,1.1617298126220703,0.0 -28200,0.059275925,0.028083839,,,,,,,,,,,,,,,,, -28300,0.05668153,0.029032024,,,,,,,,,,,,,,,,, -28400,0.06702686,0.02947461,,,,,,,,,,,,,,,,, -28500,0.05099794,0.029317422,,,,,,,,,,,,,,,,, -28600,0.054191954,0.02723988,,,,,,,,,,,,,,,,, -28700,0.05592376,0.030868283,,,,,,,,,,,,,,,,, -28800,0.05156082,0.029392688,,,,,,,,,,,,,,,,, -28846,,,0.9917556643486024,0.0266566574573516,0.4946562271208262,0.9868446588516236,0.0443001873791217,0.2766937713708444,43793.0,0.98611319065094,0.0470629222691059,0.2691851735216722,43793.0,9383.813615322111,14902.571242570875,9383.813615322111,5516.761331558228,1.194563627243042,0.0 -28900,0.048368208,0.02641993,,,,,,,,,,,,,,,,, -29000,0.055071816,0.027159993,,,,,,,,,,,,,,,,, -29100,0.079686895,0.0368435,,,,,,,,,,,,,,,,, -29200,0.06158022,0.03149969,,,,,,,,,,,,,,,,, -29300,0.052043825,0.027968619,,,,,,,,,,,,,,,,, -29400,0.055842478,0.029497238,,,,,,,,,,,,,,,,, -29500,0.050475936,0.028429486,,,,,,,,,,,,,,,,, -29584,,,0.9919654130935668,0.0259400550276041,0.5117402117318774,0.9869948625564576,0.0441078841686248,0.2880325126914176,43793.0,0.9861321449279784,0.047050341963768,0.2701175138067016,43793.0,9623.958882570269,15269.985672235489,9623.958882570269,5643.975158452988,1.228114128112793,0.0 -29600,0.064212166,0.03310059,,,,,,,,,,,,,,,,, -29700,0.064918466,0.027579637,,,,,,,,,,,,,,,,, -29800,0.05643648,0.028486414,,,,,,,,,,,,,,,,, -29900,0.05542605,0.028130304,,,,,,,,,,,,,,,,, -30000,0.0644141,0.033383705,,,,,,,,,,,,,,,,, -30100,0.052407846,0.028936794,,,,,,,,,,,,,,,,, -30200,0.06188855,0.029852672,,,,,,,,,,,,,,,,, -30300,0.052362587,0.027714087,,,,,,,,,,,,,,,,, -30326,,,0.9920579195022584,0.0257496107369661,0.5215447573111239,0.9868953824043274,0.0439697168767452,0.2858635197434587,43793.0,0.9860247373580932,0.04674918577075,0.2688654699671372,43793.0,9864.012751102448,15633.425725221634,9864.012751102448,5767.308829545975,1.2603094577789309,0.0 -30400,0.051699854,0.026705582,,,,,,,,,,,,,,,,, -30500,0.059421856,0.027339097,,,,,,,,,,,,,,,,, -30600,0.053871375,0.028164158,,,,,,,,,,,,,,,,, -30700,0.054256227,0.025620913,,,,,,,,,,,,,,,,, -30800,0.07340335,0.032238282,,,,,,,,,,,,,,,,, -30900,0.06690443,0.026047561,,,,,,,,,,,,,,,,, -31000,0.06009657,0.026959507,,,,,,,,,,,,,,,,, -31067,,,0.991892635822296,0.0262051355093717,0.5086030748444564,0.986970067024231,0.0440684333443641,0.2834956332839969,43793.0,0.9861097931861876,0.0469236336648464,0.2746272969201792,43793.0,10104.057809352877,15998.710246562958,10104.057809352877,5892.495110750198,1.2933599948883057,0.0 -31100,0.064024195,0.026744341,,,,,,,,,,,,,,,,, -31200,0.055279817,0.028040607,,,,,,,,,,,,,,,,, -31300,0.06246899,0.026664859,,,,,,,,,,,,,,,,, -31400,0.05022863,0.02620009,,,,,,,,,,,,,,,,, -31500,0.057326917,0.02814812,,,,,,,,,,,,,,,,, -31600,0.06076301,0.029442284,,,,,,,,,,,,,,,,, -31700,0.060314428,0.030564165,,,,,,,,,,,,,,,,, -31800,0.064136386,0.031125657,,,,,,,,,,,,,,,,, -31816,,,0.9916676878929138,0.026850014925003,0.4904087292640114,0.9869644045829772,0.0441396757960319,0.2862625865048266,43793.0,0.9861460328102112,0.0468500703573226,0.2760263600114768,43793.0,10344.097812891006,16365.99844264984,10344.097812891006,6019.689756393433,1.3267717361450195,0.0 -31900,0.05342401,0.024149679,,,,,,,,,,,,,,,,, -32000,0.05387119,0.026007336,,,,,,,,,,,,,,,,, -32100,0.05174546,0.027344102,,,,,,,,,,,,,,,,, -32200,0.06198532,0.02869427,,,,,,,,,,,,,,,,, -32300,0.05204534,0.030274006,,,,,,,,,,,,,,,,, -32400,0.07143624,0.03336765,,,,,,,,,,,,,,,,, -32500,0.052932087,0.027515294,,,,,,,,,,,,,,,,, -32560,,,0.9917237758636476,0.0267841536551713,0.493660033538621,0.9869781732559204,0.0438555255532264,0.290015872842108,43793.0,0.9861759543418884,0.0464902445673942,0.2793258645120916,43793.0,10584.372139453888,16734.21134352684,10584.372139453888,6147.57537651062,1.3592548370361328,0.0 -32600,0.05446056,0.026810016,,,,,,,,,,,,,,,,, -32700,0.08985935,0.031437363,,,,,,,,,,,,,,,,, -32800,0.08751804,0.029215671,,,,,,,,,,,,,,,,, -32900,0.0533782,0.02712535,,,,,,,,,,,,,,,,, -33000,0.050908774,0.029243644,,,,,,,,,,,,,,,,, -33100,0.057273783,0.029188203,,,,,,,,,,,,,,,,, -33200,0.054863304,0.025639558,,,,,,,,,,,,,,,,, -33296,,,0.9918603897094728,0.0262102521955966,0.4982211958330765,0.9870301485061646,0.0441683307290077,0.287087614806626,43793.0,0.9861931800842284,0.0470806285738945,0.2709264239678933,43793.0,10824.466272115707,17102.31195449829,10824.466272115707,6275.525854349136,1.3940739631652832,0.0 -33300,0.057318453,0.025646344,,,,,,,,,,,,,,,,, -33400,0.06188294,0.025296515,,,,,,,,,,,,,,,,, -33500,0.05094684,0.02732706,,,,,,,,,,,,,,,,, -33600,0.06330642,0.0319127,,,,,,,,,,,,,,,,, -33700,0.05420823,0.027589386,,,,,,,,,,,,,,,,, -33800,0.061590333,0.03042163,,,,,,,,,,,,,,,,, -33900,0.055094257,0.028373633,,,,,,,,,,,,,,,,, -34000,0.053679295,0.026272012,,,,,,,,,,,,,,,,, -34029,,,0.9919319152832032,0.0259202476590871,0.5161529871779962,0.9869481325149536,0.0443512499332428,0.2823631179528553,43793.0,0.9860681295394896,0.0469147376716136,0.2748894485436163,43793.0,11064.580405950546,17466.4331305027,11064.580405950546,6399.477682113648,1.4283545017242432,0.0 -34100,0.056622285,0.027684001,,,,,,,,,,,,,,,,, -34200,0.060968682,0.028868709,,,,,,,,,,,,,,,,, -34300,0.06312649,0.028674103,,,,,,,,,,,,,,,,, -34400,0.061782483,0.028921423,,,,,,,,,,,,,,,,, -34500,0.058559097,0.028711444,,,,,,,,,,,,,,,,, -34600,0.056649644,0.026148561,,,,,,,,,,,,,,,,, -34700,0.076985806,0.033302475,,,,,,,,,,,,,,,,, -34779,,,0.992088496685028,0.0254748146981,0.5129617949169412,0.9869903922080994,0.044027104973793,0.2915628601867511,43793.0,0.986177623271942,0.0469932220876216,0.2721970894810648,43793.0,11304.724274158478,17836.523913621902,11304.724274158478,6529.37206697464,1.4609463214874268,0.0 -34800,0.064565726,0.026122665,,,,,,,,,,,,,,,,, -34900,0.063004255,0.031016188,,,,,,,,,,,,,,,,, -35000,0.05138947,0.024303107,,,,,,,,,,,,,,,,, -35100,0.06598122,0.025877312,,,,,,,,,,,,,,,,, -35200,0.056133904,0.027241932,,,,,,,,,,,,,,,,, -35300,0.05954493,0.026613766,,,,,,,,,,,,,,,,, -35400,0.0579109,0.025037903,,,,,,,,,,,,,,,,, -35500,0.06649895,0.028482629,,,,,,,,,,,,,,,,, -35509,,,0.9922428131103516,0.0247446577996015,0.5505906445570861,0.9870585799217224,0.0441386215388774,0.2907984307159181,43793.0,0.9862454533576964,0.0470156781375408,0.2753576260925718,43793.0,11544.948768615724,18203.61198425293,11544.948768615724,6656.179391384125,1.4955766201019287,0.0 -35600,0.054731775,0.025498757,,,,,,,,,,,,,,,,, -35700,0.06145589,0.02859945,,,,,,,,,,,,,,,,, -35800,0.07210765,0.024908116,,,,,,,,,,,,,,,,, -35900,0.07803512,0.026304774,,,,,,,,,,,,,,,,, -36000,0.0793027,0.027754996,,,,,,,,,,,,,,,,, -36100,0.05907004,0.02619876,,,,,,,,,,,,,,,,, -36200,0.07884063,0.028884705,,,,,,,,,,,,,,,,, -36250,,,0.992393672466278,0.0244789887219667,0.5408486038115388,0.9870277047157288,0.0441088415682315,0.2908113959289101,43793.0,0.9862197637557985,0.0467967800796031,0.2766800615256882,43793.0,11784.974064826964,18572.50408053398,11784.974064826964,6784.992396831512,1.5288910865783691,0.0 -36300,0.06797825,0.032011453,,,,,,,,,,,,,,,,, -36400,0.08456826,0.031793363,,,,,,,,,,,,,,,,, -36500,0.06397566,0.030541373,,,,,,,,,,,,,,,,, -36600,0.08065389,0.027756123,,,,,,,,,,,,,,,,, -36700,0.061528806,0.02863926,,,,,,,,,,,,,,,,, -36800,0.07460822,0.02951571,,,,,,,,,,,,,,,,, -36900,0.056640785,0.024703197,,,,,,,,,,,,,,,,, -36993,,,0.9923135042190552,0.0246789418160915,0.5443136465959396,0.9869603514671326,0.0445223152637481,0.2863923956318354,43793.0,0.9861136078834534,0.0472501963376998,0.2734903307389538,43793.0,12024.976420640944,18941.173763275143,12024.976420640944,6913.606185436249,1.5621047019958496,0.0 -37000,0.058728594,0.024406716,,,,,,,,,,,,,,,,, -37100,0.05740258,0.026250646,,,,,,,,,,,,,,,,, -37200,0.06594787,0.027880585,,,,,,,,,,,,,,,,, -37300,0.05993324,0.025549114,,,,,,,,,,,,,,,,, -37400,0.06191692,0.02570981,,,,,,,,,,,,,,,,, -37500,0.05745138,0.025585875,,,,,,,,,,,,,,,,, -37600,0.060951915,0.028218348,,,,,,,,,,,,,,,,, -37700,0.066758595,0.024995683,,,,,,,,,,,,,,,,, -37716,,,0.9922302961349488,0.0250727999955415,0.5250435594834874,0.9870094656944276,0.044437400996685,0.2870718545969231,43793.0,0.9861860275268556,0.0470856502652168,0.2764250250297177,43793.0,12265.204212665558,19310.005308389664,12265.204212665558,7042.153408050537,1.5955908298492432,0.0 -37800,0.11824077,0.033158552,,,,,,,,,,,,,,,,, -37900,0.06683033,0.02528263,,,,,,,,,,,,,,,,, -38000,0.079231225,0.028274532,,,,,,,,,,,,,,,,, -38100,0.06268689,0.026909139,,,,,,,,,,,,,,,,, -38200,0.066801704,0.031077916,,,,,,,,,,,,,,,,, -38300,0.0764827,0.032277294,,,,,,,,,,,,,,,,, -38400,0.06447134,0.027896728,,,,,,,,,,,,,,,,, -38451,,,0.9921622276306152,0.0252226088196039,0.5285684492729031,0.9870207905769348,0.0444123074412345,0.287736763336245,43793.0,0.98612242937088,0.0472948960959911,0.2705252738596758,43793.0,12505.271298408508,19675.301374912266,12505.271298408508,7167.328474283218,1.6293635368347168,0.0 -38500,0.065985255,0.027022073,,,,,,,,,,,,,,,,, -38600,0.062058385,0.02986273,,,,,,,,,,,,,,,,, -38700,0.06492322,0.027095726,,,,,,,,,,,,,,,,, -38800,0.06982808,0.027948136,,,,,,,,,,,,,,,,, -38900,0.06703295,0.028534288,,,,,,,,,,,,,,,,, -39000,0.062188502,0.026508523,,,,,,,,,,,,,,,,, -39100,0.12251968,0.031099837,,,,,,,,,,,,,,,,, -39188,,,0.9923495650291444,0.0245155375450849,0.5353314789084255,0.9869980812072754,0.0446966439485549,0.2880907345675527,43793.0,0.9861502647399902,0.0476831793785095,0.2739721380819767,43793.0,12745.515764474869,20041.384852409363,12745.515764474869,7293.105636358261,1.66961932182312,0.0 -39200,0.07545386,0.029688785,,,,,,,,,,,,,,,,, -39300,0.05556681,0.023978565,,,,,,,,,,,,,,,,, -39400,0.066512644,0.027351024,,,,,,,,,,,,,,,,, -39500,0.07252562,0.027911192,,,,,,,,,,,,,,,,, -39600,0.07309992,0.029088613,,,,,,,,,,,,,,,,, -39700,0.07023008,0.024758238,,,,,,,,,,,,,,,,, -39800,0.065774366,0.025943058,,,,,,,,,,,,,,,,, -39900,0.07501997,0.027727505,,,,,,,,,,,,,,,,, -39921,,,0.9920818209648132,0.0250757187604904,0.5387114399975329,0.9870151281356812,0.0445631518959999,0.2937449771905799,43793.0,0.9862037301063538,0.0472649373114109,0.2791759747694861,43793.0,12985.472758769987,20412.724797964096,12985.472758769987,7424.434202432632,1.7036066055297852,0.0 -40000,0.065359026,0.025612969,,,,,,,,,,,,,,,,, -40100,0.06827082,0.025871966,,,,,,,,,,,,,,,,, -40200,0.078831375,0.028765813,,,,,,,,,,,,,,,,, -40300,0.063428186,0.026496483,,,,,,,,,,,,,,,,, -40400,0.062926576,0.02574508,,,,,,,,,,,,,,,,, -40500,0.06423223,0.023110926,,,,,,,,,,,,,,,,, -40600,0.08466144,0.029558169,,,,,,,,,,,,,,,,, -40650,,,0.992318868637085,0.0244046170264482,0.5562079477709713,0.98701673746109,0.0444464944303035,0.2909518713632352,43793.0,0.9861915111541748,0.0471347123384475,0.2772578004381141,43793.0,13225.699823617935,20778.853857278824,13225.699823617935,7550.27251458168,1.7443230152130127,0.0 -40700,0.068402424,0.029806059,,,,,,,,,,,,,,,,, -40800,0.071535274,0.026619092,,,,,,,,,,,,,,,,, -40900,0.07391201,0.026127562,,,,,,,,,,,,,,,,, -41000,0.08125178,0.026271444,,,,,,,,,,,,,,,,, -41100,0.07271082,0.024672125,,,,,,,,,,,,,,,,, -41200,0.09090445,0.027955865,,,,,,,,,,,,,,,,, -41300,0.08244576,0.027832173,,,,,,,,,,,,,,,,, -41388,,,0.9926362037658693,0.0233683511614799,0.5768032113155684,0.9870049953460692,0.0447624586522579,0.2829433254848256,43793.0,0.9861556887626648,0.0476178787648677,0.2709540144738322,43793.0,13465.783110141754,21146.153143405914,13465.783110141754,7677.433129072189,1.7788498401641846,0.0 -41400,0.083102,0.028063592,,,,,,,,,,,,,,,,, -41500,0.072396025,0.026583742,,,,,,,,,,,,,,,,, -41600,0.08530804,0.030500956,,,,,,,,,,,,,,,,, -41700,0.08692914,0.026684316,,,,,,,,,,,,,,,,, -41800,0.07295118,0.025525112,,,,,,,,,,,,,,,,, -41900,0.07212522,0.02664014,,,,,,,,,,,,,,,,, -42000,0.076275274,0.025389206,,,,,,,,,,,,,,,,, -42100,0.07183555,0.02549032,,,,,,,,,,,,,,,,, -42116,,,0.992948591709137,0.022594491019845,0.5889403009136849,0.9868965744972228,0.0448220036923885,0.2870109267260117,43793.0,0.986143946647644,0.0475996248424053,0.2726043926098506,43793.0,13705.972064971924,21517.054282665253,13705.972064971924,7808.087415456772,1.815146446228028,0.0 -42200,0.07410889,0.028331779,,,,,,,,,,,,,,,,, -42300,0.07778998,0.024431465,,,,,,,,,,,,,,,,, -42400,0.083738245,0.029697135,,,,,,,,,,,,,,,,, -42500,0.070011616,0.025548419,,,,,,,,,,,,,,,,, -42600,0.06739425,0.026586462,,,,,,,,,,,,,,,,, -42700,0.07230254,0.025058437,,,,,,,,,,,,,,,,, -42800,0.08407361,0.028309148,,,,,,,,,,,,,,,,, -42857,,,0.9928516745567322,0.0227420013397932,0.5831679997609343,0.9870119094848632,0.0449406802654266,0.2870644899198072,43793.0,0.9862349033355712,0.0478113815188407,0.277432456781912,43793.0,13946.16392469406,21881.35036206245,13946.16392469406,7932.13530087471,1.8508188724517824,0.0 -42900,0.07206131,0.024609532,,,,,,,,,,,,,,,,, -43000,0.06579774,0.024487928,,,,,,,,,,,,,,,,, -43100,0.0814524,0.027383398,,,,,,,,,,,,,,,,, -43200,0.07448095,0.025785577,,,,,,,,,,,,,,,,, -43300,0.07682888,0.02665708,,,,,,,,,,,,,,,,, -43400,0.07529534,0.02534282,,,,,,,,,,,,,,,,, -43500,0.078369625,0.025001427,,,,,,,,,,,,,,,,, -43598,,,0.9924556016921996,0.0241395328193902,0.5471580974598048,0.9868945479393004,0.0446267016232013,0.2904814723725562,43793.0,0.9860474467277528,0.0475262962281703,0.2770334013150399,43793.0,14186.274589061735,22247.215339899063,14186.274589061735,8057.833017587662,1.88677716255188,0.0 -43600,0.071620025,0.024461143,,,,,,,,,,,,,,,,, -43700,0.0788827,0.025470005,,,,,,,,,,,,,,,,, -43800,0.073311135,0.025287805,,,,,,,,,,,,,,,,, -43900,0.06681001,0.023464557,,,,,,,,,,,,,,,,, -44000,0.069550045,0.023440132,,,,,,,,,,,,,,,,, -44100,0.07631209,0.026429787,,,,,,,,,,,,,,,,, -44200,0.06803203,0.023476345,,,,,,,,,,,,,,,,, -44300,0.0682713,0.02174135,,,,,,,,,,,,,,,,, -44327,,,0.9924808740615844,0.0239379573613405,0.5512886909498305,0.9869850873947144,0.0452325120568275,0.2895107733140303,43793.0,0.986100137233734,0.0481578856706619,0.2753276791308924,43793.0,14426.232063055038,22612.117126464844,14426.232063055038,8182.722589492798,1.921675443649292,0.0 -44400,0.0916494,0.025147403,,,,,,,,,,,,,,,,, -44500,0.072096564,0.024725322,,,,,,,,,,,,,,,,, -44600,0.079761125,0.025666054,,,,,,,,,,,,,,,,, -44700,0.07139467,0.027348746,,,,,,,,,,,,,,,,, -44800,0.088232756,0.029798862,,,,,,,,,,,,,,,,, -44900,0.07738235,0.025954999,,,,,,,,,,,,,,,,, -45000,0.08327954,0.022504672,,,,,,,,,,,,,,,,, -45063,,,0.992731750011444,0.0231606904417276,0.5747661718707735,0.9868564009666444,0.0449821837246418,0.28365360157518,43793.0,0.9860609769821168,0.0477090999484062,0.276624794918469,43793.0,14666.217765569689,22979.45182442665,14666.217765569689,8310.012906312943,1.9601173400878904,0.0 -45100,0.07097564,0.021600854,,,,,,,,,,,,,,,,, -45200,0.072759405,0.025438914,,,,,,,,,,,,,,,,, -45300,0.07719142,0.026513932,,,,,,,,,,,,,,,,, -45400,0.07102966,0.024439143,,,,,,,,,,,,,,,,, -45500,0.09334257,0.029310206,,,,,,,,,,,,,,,,, -45600,0.076779805,0.025301283,,,,,,,,,,,,,,,,, -45700,0.080305286,0.026553214,,,,,,,,,,,,,,,,, -45795,,,0.9927511811256408,0.022958293557167,0.586343797232631,0.987152338027954,0.045345164835453,0.2902990328130586,43793.0,0.9862298369407654,0.0482310391962528,0.2777345965558692,43793.0,14906.2397480011,23348.91964435577,14906.2397480011,8439.401960372925,1.9960315227508545,0.0 -45800,0.07005532,0.024812508,,,,,,,,,,,,,,,,, -45900,0.08032841,0.02691062,,,,,,,,,,,,,,,,, -46000,0.072938725,0.024537183,,,,,,,,,,,,,,,,, -46100,0.088536985,0.028610893,,,,,,,,,,,,,,,,, -46200,0.07776403,0.025750624,,,,,,,,,,,,,,,,, -46300,0.086705536,0.027886955,,,,,,,,,,,,,,,,, -46400,0.08439925,0.026267903,,,,,,,,,,,,,,,,, -46500,0.086224794,0.02542649,,,,,,,,,,,,,,,,, -46535,,,0.9927822351455688,0.0226614810526371,0.5853062788070533,0.9870321750640868,0.0456238947808742,0.288616217569316,43793.0,0.98613041639328,0.0485120005905628,0.2745446524126901,43793.0,15146.333956003187,23717.61184549332,15146.333956003187,8567.942576169968,2.0327556133270264,0.0 -46600,0.06769857,0.026177336,,,,,,,,,,,,,,,,, -46700,0.087716445,0.028499484,,,,,,,,,,,,,,,,, -46800,0.0819651,0.024431692,,,,,,,,,,,,,,,,, -46900,0.07611637,0.024341773,,,,,,,,,,,,,,,,, -47000,0.09310828,0.025438773,,,,,,,,,,,,,,,,, -47100,0.082638144,0.028172545,,,,,,,,,,,,,,,,, -47200,0.09165749,0.02603653,,,,,,,,,,,,,,,,, -47272,,,0.9932552576065063,0.0213545765727758,0.6282544798621549,0.9870585799217224,0.0451401658356189,0.2918916130401043,43793.0,0.9861814379692078,0.0481357239186763,0.2751556729047205,43793.0,15386.32931470871,24085.590614795685,15386.32931470871,8695.869742393494,2.068514823913574,0.0 -47300,0.07766961,0.026965259,,,,,,,,,,,,,,,,, -47400,0.081014976,0.024253814,,,,,,,,,,,,,,,,, -47500,0.081091456,0.025641434,,,,,,,,,,,,,,,,, -47600,0.08075495,0.02430228,,,,,,,,,,,,,,,,, -47700,0.085914984,0.027481463,,,,,,,,,,,,,,,,, -47800,0.08943145,0.026448961,,,,,,,,,,,,,,,,, -47900,0.10118799,0.026937047,,,,,,,,,,,,,,,,, -48000,0.08519481,0.025697378,,,,,,,,,,,,,,,,, -48016,,,0.9932361245155334,0.0213189255446195,0.6125954411107427,0.9870853424072266,0.0456184148788452,0.2917766721293825,43793.0,0.986250936985016,0.0485204793512821,0.2787191662258837,43793.0,15626.565325260162,24447.46418762207,15626.565325260162,8817.450694084167,2.1042933464050293,0.0 -48100,0.083674595,0.025319915,,,,,,,,,,,,,,,,, -48200,0.08622265,0.025334697,,,,,,,,,,,,,,,,, -48300,0.08393668,0.023139587,,,,,,,,,,,,,,,,, -48400,0.076013155,0.02474332,,,,,,,,,,,,,,,,, -48500,0.0993349,0.025271783,,,,,,,,,,,,,,,,, -48600,0.088422276,0.02371487,,,,,,,,,,,,,,,,, -48700,0.08639203,0.024422964,,,,,,,,,,,,,,,,, -48757,,,0.9933009743690492,0.0212412904947996,0.6215911847182629,0.9870471954345704,0.0455567426979541,0.2858171389788816,43793.0,0.9862517714500428,0.0484604015946388,0.275624097328915,43793.0,15866.78537106514,24814.92128801346,15866.78537106514,8944.631282806396,2.1402597427368164,0.0 -48800,0.09156648,0.023009414,,,,,,,,,,,,,,,,, -48900,0.084089816,0.024989707,,,,,,,,,,,,,,,,, -49000,0.08764921,0.026519617,,,,,,,,,,,,,,,,, -49100,0.095591106,0.027203292,,,,,,,,,,,,,,,,, -49200,0.09070899,0.025600692,,,,,,,,,,,,,,,,, -49300,0.1007514,0.028775772,,,,,,,,,,,,,,,,, -49400,0.08391576,0.022701876,,,,,,,,,,,,,,,,, -49494,,,0.9931593537330629,0.0215862058103084,0.6026874616431173,0.9870740175247192,0.0459361635148525,0.2909390577597649,43793.0,0.9862500429153442,0.0490105859935283,0.2732506608323254,43793.0,16106.996109962463,25179.198776960373,16106.996109962463,9068.64112663269,2.176506996154785,0.0 -49500,0.08331668,0.025609953,,,,,,,,,,,,,,,,, -49600,0.077689394,0.022675693,,,,,,,,,,,,,,,,, -49700,0.08390728,0.023901066,,,,,,,,,,,,,,,,, -49800,0.08988595,0.024603179,,,,,,,,,,,,,,,,, -49900,0.079494566,0.02363174,,,,,,,,,,,,,,,,, -50000,0.09243244,0.022856653,,,,,,,,,,,,,,,,, -50100,0.08513662,0.02801155,,,,,,,,,,,,,,,,, -50200,0.087496966,0.024600038,,,,,,,,,,,,,,,,, -50221,,,0.9930923581123352,0.0217242259532213,0.6127489989284732,0.9870272874832152,0.0457522124052047,0.2891217197250926,43793.0,0.9861556887626648,0.0486082173883914,0.2786955782953111,43793.0,16347.118661403656,25548.262518405914,16347.118661403656,9197.52271914482,2.213009119033813,0.0 -50300,0.10724114,0.024079334,,,,,,,,,,,,,,,,, -50400,0.08925257,0.025858881,,,,,,,,,,,,,,,,, -50500,0.085036784,0.023785392,,,,,,,,,,,,,,,,, -50600,0.09141618,0.02304934,,,,,,,,,,,,,,,,, -50700,0.08842773,0.022968879,,,,,,,,,,,,,,,,, -50800,0.094382204,0.024961298,,,,,,,,,,,,,,,,, -50900,0.084754884,0.023864852,,,,,,,,,,,,,,,,, -50958,,,0.9932236075401306,0.0213300064206123,0.6119174869729177,0.98704195022583,0.0457834005355834,0.2952523480539175,43793.0,0.9862281680107116,0.048733152449131,0.2759372437124913,43793.0,16587.07647919655,25917.645943164825,16587.07647919655,9326.890407800674,2.250386476516724,0.0 -51000,0.09466232,0.024203824,,,,,,,,,,,,,,,,, -51100,0.08315764,0.022896564,,,,,,,,,,,,,,,,, -51200,0.09057501,0.022420745,,,,,,,,,,,,,,,,, -51300,0.10061184,0.025443468,,,,,,,,,,,,,,,,, -51400,0.0866125,0.022852212,,,,,,,,,,,,,,,,, -51500,0.10539997,0.025158487,,,,,,,,,,,,,,,,, -51600,0.07876079,0.023202993,,,,,,,,,,,,,,,,, -51693,,,0.993235409259796,0.02134177275002,0.6205512641982079,0.987038254737854,0.0458993092179298,0.2932043078510703,43793.0,0.9862563610076904,0.048782855272293,0.2809863309832373,43793.0,16827.240026474,26283.74558782577,16827.240026474,9452.768271923063,2.288546562194824,0.0 -51700,0.09457076,0.024755413,,,,,,,,,,,,,,,,, -51800,0.08810257,0.023206633,,,,,,,,,,,,,,,,, -51900,0.087437384,0.02390925,,,,,,,,,,,,,,,,, -52000,0.10836627,0.025275612,,,,,,,,,,,,,,,,, -52100,0.08385967,0.01904661,,,,,,,,,,,,,,,,, -52200,0.08433134,0.023295399,,,,,,,,,,,,,,,,, -52300,0.09797684,0.024084246,,,,,,,,,,,,,,,,, -52400,0.0966711,0.026011318,,,,,,,,,,,,,,,,, -52431,,,0.9933454394340516,0.0208866987377405,0.6233633128448957,0.9870431423187256,0.0463724546134471,0.2947568615806708,43793.0,0.9862008094787598,0.0494671426713466,0.275198819814396,43793.0,17067.183010816574,26646.76397776604,17067.183010816574,9575.785498142242,2.326702117919922,0.0 -52500,0.082472414,0.022663359,,,,,,,,,,,,,,,,, -52600,0.11075064,0.025402455,,,,,,,,,,,,,,,,, -52700,0.096833415,0.025336154,,,,,,,,,,,,,,,,, -52800,0.11799472,0.024835348,,,,,,,,,,,,,,,,, -52900,0.08159507,0.021608748,,,,,,,,,,,,,,,,, -53000,0.10054516,0.025432277,,,,,,,,,,,,,,,,, -53100,0.09265455,0.021554409,,,,,,,,,,,,,,,,, -53174,,,0.993765652179718,0.0198266431689262,0.648274235699503,0.9869290590286256,0.0461834780871868,0.293072159487718,43793.0,0.9861982464790344,0.0490794591605663,0.277876897106266,43793.0,17307.16687989235,27013.39756894112,17307.16687989235,9702.37763953209,2.363882303237915,0.0 -53200,0.09622556,0.021537993,,,,,,,,,,,,,,,,, -53300,0.09155439,0.022268355,,,,,,,,,,,,,,,,, -53400,0.09735198,0.023086347,,,,,,,,,,,,,,,,, -53500,0.09760654,0.023046162,,,,,,,,,,,,,,,,, -53600,0.099731915,0.023309305,,,,,,,,,,,,,,,,, -53700,0.07935494,0.019994007,,,,,,,,,,,,,,,,, -53800,0.12381974,0.023914855,,,,,,,,,,,,,,,,, -53900,0.10925677,0.024492456,,,,,,,,,,,,,,,,, -53902,,,0.9939457774162292,0.0190976541489362,0.6707530227995426,0.9869185090065002,0.046229638159275,0.2894399170703837,43793.0,0.9861683249473572,0.0492582395672798,0.2728726141685595,43793.0,17547.396492242813,27380.38984155655,17547.396492242813,9829.080079555511,2.401202440261841,0.0 -54000,0.10256202,0.022598987,,,,,,,,,,,,,,,,, -54100,0.10461978,0.025157942,,,,,,,,,,,,,,,,, -54200,0.09949943,0.022071552,,,,,,,,,,,,,,,,, -54300,0.088346906,0.021734638,,,,,,,,,,,,,,,,, -54400,0.104454,0.023584366,,,,,,,,,,,,,,,,, -54500,0.10366485,0.022819022,,,,,,,,,,,,,,,,, -54600,0.08488291,0.022736395,,,,,,,,,,,,,,,,, -54639,,,0.9939642548561096,0.0190879423171281,0.6836719989171522,0.986976146697998,0.0466739870607852,0.2931826954710649,43793.0,0.9862428903579712,0.049737349152565,0.275510165065686,43793.0,17787.61997103691,27752.34281229973,17787.61997103691,9960.751490354538,2.438985824584961,0.0 -54700,0.101251744,0.020011919,,,,,,,,,,,,,,,,, -54800,0.09535916,0.022550924,,,,,,,,,,,,,,,,, -54900,0.10289986,0.023320474,,,,,,,,,,,,,,,,, -55000,0.111696504,0.023559524,,,,,,,,,,,,,,,,, -55100,0.10497626,0.02224917,,,,,,,,,,,,,,,,, -55200,0.09502899,0.021461757,,,,,,,,,,,,,,,,, -55300,0.09833744,0.020372082,,,,,,,,,,,,,,,,, -55378,,,0.9935755729675292,0.0200575366616249,0.6444034424776388,0.9869863390922546,0.0467824675142765,0.2931136241225678,43793.0,0.9861915111541748,0.0497394278645515,0.2762871000737157,43793.0,18027.63859820366,28120.75264573097,18027.63859820366,10089.079269647598,2.4811816215515137,0.0 -55400,0.09368225,0.021304982,,,,,,,,,,,,,,,,, -55500,0.11320194,0.023885248,,,,,,,,,,,,,,,,, -55600,0.11545751,0.02054501,,,,,,,,,,,,,,,,, -55700,0.096890785,0.023989208,,,,,,,,,,,,,,,,, -55800,0.117868,0.0225496,,,,,,,,,,,,,,,,, -55900,0.10189563,0.023987442,,,,,,,,,,,,,,,,, -56000,0.095735416,0.02135519,,,,,,,,,,,,,,,,, -56100,0.10949822,0.02381324,,,,,,,,,,,,,,,,, -56114,,,0.9936255216598512,0.0199688728898763,0.6374130133398599,0.9870532751083374,0.0471296124160289,0.2916289264718201,43793.0,0.9862412214279176,0.0501952022314071,0.2766432008581418,43793.0,18267.65841269493,28490.55920910836,18267.65841269493,10218.807644367218,2.518969774246216,0.0 -56200,0.09668948,0.02199921,,,,,,,,,,,,,,,,, -56300,0.10124029,0.021930385,,,,,,,,,,,,,,,,, -56400,0.1095461,0.021941464,,,,,,,,,,,,,,,,, -56500,0.11118507,0.023704719,,,,,,,,,,,,,,,,, -56600,0.10783285,0.02121931,,,,,,,,,,,,,,,,, -56700,0.10502976,0.022234406,,,,,,,,,,,,,,,,, -56800,0.106728904,0.019272221,,,,,,,,,,,,,,,,, -56855,,,0.9937307238578796,0.0195559505373239,0.6575340063625909,0.987097144126892,0.0470090880990028,0.2928508141046779,43793.0,0.986143946647644,0.0501659922301769,0.2744961808523273,43793.0,18507.83859324456,28857.974859952927,18507.83859324456,10345.98371219635,2.55792236328125,0.0 -56900,0.10761294,0.023458645,,,,,,,,,,,,,,,,, -57000,0.10479982,0.02276241,,,,,,,,,,,,,,,,, -57100,0.124343514,0.022627633,,,,,,,,,,,,,,,,, -57200,0.10940963,0.021556072,,,,,,,,,,,,,,,,, -57300,0.100559406,0.021059502,,,,,,,,,,,,,,,,, -57400,0.10928501,0.02264866,,,,,,,,,,,,,,,,, -57500,0.10346756,0.0223036,,,,,,,,,,,,,,,,, -57592,,,0.9938116073608398,0.0193178337067365,0.6544946164639226,0.9870354533195496,0.0472130291163921,0.2911910347651354,43793.0,0.9862513542175292,0.0502744019031524,0.2765929581524281,43793.0,18747.98366856575,29224.3308801651,18747.98366856575,10472.135902881622,2.5963222980499268,0.0 -57600,0.10506211,0.018154118,,,,,,,,,,,,,,,,, -57700,0.10737256,0.02220252,,,,,,,,,,,,,,,,, -57800,0.10715688,0.021535259,,,,,,,,,,,,,,,,, -57900,0.107682794,0.020778682,,,,,,,,,,,,,,,,, -58000,0.10762068,0.020702595,,,,,,,,,,,,,,,,, -58100,0.107163474,0.021600729,,,,,,,,,,,,,,,,, -58200,0.10874211,0.022638742,,,,,,,,,,,,,,,,, -58300,0.11086383,0.02342639,,,,,,,,,,,,,,,,, -58329,,,0.9938645958900452,0.0190196689218282,0.6740505018592804,0.9870337843894958,0.0474877506494522,0.2891203495535188,43793.0,0.9861683249473572,0.0507872477173805,0.2757953912061542,43793.0,18988.1043651104,29594.138032197952,18988.1043651104,10601.76334810257,2.6340818405151367,0.0 -58400,0.11129722,0.020876514,,,,,,,,,,,,,,,,, -58500,0.10367599,0.019870466,,,,,,,,,,,,,,,,, -58600,0.11647089,0.024114484,,,,,,,,,,,,,,,,, -58700,0.11286285,0.022168644,,,,,,,,,,,,,,,,, -58800,0.11101033,0.021550866,,,,,,,,,,,,,,,,, -58900,0.10842541,0.019972337,,,,,,,,,,,,,,,,, -59000,0.13038792,0.022349661,,,,,,,,,,,,,,,,, -59073,,,0.9941751956939696,0.0180447380989789,0.7022220962344548,0.9870853424072266,0.0476047173142433,0.2895204188531082,43793.0,0.9862648248672484,0.0509006977081298,0.2718785727090479,43793.0,19228.10386610031,29954.60914540291,19228.10386610031,10722.174923658373,2.673880100250244,0.0 -59100,0.11617518,0.021247689,,,,,,,,,,,,,,,,, -59200,0.113943994,0.021915378,,,,,,,,,,,,,,,,, -59300,0.13006665,0.022006352,,,,,,,,,,,,,,,,, -59400,0.102213666,0.02190729,,,,,,,,,,,,,,,,, -59500,0.12242408,0.021907978,,,,,,,,,,,,,,,,, -59600,0.13128191,0.024074089,,,,,,,,,,,,,,,,, -59700,0.10651588,0.019471752,,,,,,,,,,,,,,,,, -59800,0.12826853,0.02328773,,,,,,,,,,,,,,,,, -59813,,,0.9943927526474,0.0174613632261753,0.7008946711494934,0.9870354533195496,0.0478969104588031,0.2893068462662133,43793.0,0.9861982464790344,0.0511804521083831,0.2744422124499764,43793.0,19468.106098413467,30322.98073625565,19468.106098413467,10850.485614299774,2.712521314620972,0.0 -59900,0.119624354,0.02189821,,,,,,,,,,,,,,,,, -60000,0.10479622,0.02001673,,,,,,,,,,,,,,,,, -60100,0.13092467,0.021328604,,,,,,,,,,,,,,,,, -60200,0.11066933,0.020075025,,,,,,,,,,,,,,,,, -60300,0.11771563,0.018871201,,,,,,,,,,,,,,,,, -60400,0.11428441,0.021542944,,,,,,,,,,,,,,,,, -60500,0.12883373,0.020227183,,,,,,,,,,,,,,,,, -60557,,,0.9945353865623474,0.0171850584447383,0.7161099193970502,0.987028956413269,0.0478489883244037,0.2911537483110391,43793.0,0.986163318157196,0.0511066801846027,0.2768247622583396,43793.0,19708.216701745987,30682.25395989418,19708.216701745987,10969.589729309082,2.7505152225494385,0.0 -60600,0.12097762,0.023555048,,,,,,,,,,,,,,,,, -60700,0.10688589,0.018858938,,,,,,,,,,,,,,,,, -60800,0.11680216,0.021147585,,,,,,,,,,,,,,,,, -60900,0.1179429,0.021747267,,,,,,,,,,,,,,,,, -61000,0.12871215,0.022749111,,,,,,,,,,,,,,,,, -61100,0.1357123,0.023879802,,,,,,,,,,,,,,,,, -61200,0.11730891,0.020004312,,,,,,,,,,,,,,,,, -61294,,,0.9944341778755188,0.0173479840159416,0.6998307741889802,0.9870285391807556,0.048157338052988,0.2906187634034315,43793.0,0.9862361550331116,0.0511085912585258,0.2774663366633761,43793.0,19948.361443281174,31044.06996655464,19948.361443281174,11091.201363563538,2.789128541946411,0.0 -61300,0.11323896,0.020026883,,,,,,,,,,,,,,,,, -61400,0.108237006,0.020705324,,,,,,,,,,,,,,,,, -61500,0.104279116,0.01915243,,,,,,,,,,,,,,,,, -61600,0.11930552,0.020824997,,,,,,,,,,,,,,,,, -61700,0.11036013,0.020233162,,,,,,,,,,,,,,,,, -61800,0.12362546,0.018039567,,,,,,,,,,,,,,,,, -61900,0.11138732,0.020946875,,,,,,,,,,,,,,,,, -62000,0.11204113,0.020918878,,,,,,,,,,,,,,,,, -62029,,,0.9944174885749816,0.0174654368311166,0.7034989601227458,0.987064242362976,0.0482059121131897,0.294787254294828,43793.0,0.9862399697303772,0.0514232739806175,0.2774561780744564,43793.0,20188.502468585968,31403.512269973755,20188.502468585968,11210.443829536438,2.8275270462036133,0.0 -62100,0.13970236,0.021078086,,,,,,,,,,,,,,,,, -62200,0.11253787,0.019956285,,,,,,,,,,,,,,,,, -62300,0.116205715,0.018403305,,,,,,,,,,,,,,,,, -62400,0.13857757,0.02133606,,,,,,,,,,,,,,,,, -62500,0.13400444,0.020458553,,,,,,,,,,,,,,,,, -62600,0.111569345,0.020220188,,,,,,,,,,,,,,,,, -62700,0.119446576,0.018370483,,,,,,,,,,,,,,,,, -62776,,,0.9943106174468994,0.0174315962940454,0.6984793166748001,0.9869989156723022,0.0483455397188663,0.2943116500300501,43793.0,0.9862205982208252,0.0515215322375297,0.2804855365186107,43793.0,20428.56449484825,31762.874539613724,20428.56449484825,11329.68220448494,2.8689663410186768,0.0 -62800,0.11674138,0.019764291,,,,,,,,,,,,,,,,, -62900,0.113934875,0.018848294,,,,,,,,,,,,,,,,, -63000,0.12180241,0.019198451,,,,,,,,,,,,,,,,, -63100,0.1309504,0.018955745,,,,,,,,,,,,,,,,, -63200,0.13441455,0.019119184,,,,,,,,,,,,,,,,, -63300,0.1317403,0.02152911,,,,,,,,,,,,,,,,, -63400,0.14412956,0.020695984,,,,,,,,,,,,,,,,, -63500,0.1581123,0.022579072,,,,,,,,,,,,,,,,, -63524,,,0.9942901730537416,0.0177717264741659,0.6904933370506032,0.9868978261947632,0.0487305261194705,0.2893957604156967,43793.0,0.9861738085746764,0.0517492853105068,0.2760619205947037,43793.0,20668.679448366165,32126.68216776848,20668.679448366165,11453.315598249435,2.9075872898101807,0.0 -63600,0.1320917,0.019472778,,,,,,,,,,,,,,,,, -63700,0.12788467,0.017803857,,,,,,,,,,,,,,,,, -63800,0.12626076,0.018845644,,,,,,,,,,,,,,,,, -63900,0.13193257,0.021442054,,,,,,,,,,,,,,,,, -64000,0.11076095,0.017061153,,,,,,,,,,,,,,,,, -64100,0.14410245,0.023600547,,,,,,,,,,,,,,,,, -64200,0.12712054,0.019439034,,,,,,,,,,,,,,,,, -64266,,,0.99448424577713,0.0171282012015581,0.7091291766287441,0.987028956413269,0.0488222688436508,0.295160113444928,43793.0,0.9861649870872498,0.0520334914326667,0.2751095010284893,43793.0,20908.896485328674,32494.05240297317,20908.896485328674,11580.409008979796,2.946688413619995,0.0 -64300,0.13668205,0.018911071,,,,,,,,,,,,,,,,, -64400,0.1214686,0.019289795,,,,,,,,,,,,,,,,, -64500,0.13720529,0.020848613,,,,,,,,,,,,,,,,, -64600,0.1382448,0.019067228,,,,,,,,,,,,,,,,, -64700,0.123843335,0.019478887,,,,,,,,,,,,,,,,, -64800,0.11960843,0.019269457,,,,,,,,,,,,,,,,, -64900,0.124733455,0.020951618,,,,,,,,,,,,,,,,, -65000,0.1442651,0.021093203,,,,,,,,,,,,,,,,, -65013,,,0.9946561455726624,0.0164926163852214,0.7217163236524171,0.9869893789291382,0.0493305474519729,0.2885682063827535,43793.0,0.9861902594566344,0.0524006597697734,0.2766862655102707,43793.0,21149.007489204407,32851.034240722656,21149.007489204407,11697.220544338226,2.9854743480682373,0.0 -65100,0.12740369,0.018274948,,,,,,,,,,,,,,,,, -65200,0.13345511,0.020451734,,,,,,,,,,,,,,,,, -65300,0.11675694,0.019076291,,,,,,,,,,,,,,,,, -65400,0.14519197,0.020282477,,,,,,,,,,,,,,,,, -65500,0.124043755,0.019203398,,,,,,,,,,,,,,,,, -65600,0.112449035,0.017069876,,,,,,,,,,,,,,,,, -65700,0.13068917,0.02027465,,,,,,,,,,,,,,,,, -65754,,,0.994984209537506,0.0157103948295116,0.751098217448867,0.9869558811187744,0.0491550415754318,0.2917384044633542,43793.0,0.986180543899536,0.052390594035387,0.2744711419810177,43793.0,21388.94634079933,33211.36627578735,21388.94634079933,11817.548296689987,3.0303094387054443,0.0 -65800,0.13774346,0.019714134,,,,,,,,,,,,,,,,, -65900,0.13542454,0.019850152,,,,,,,,,,,,,,,,, -66000,0.1356874,0.0206243,,,,,,,,,,,,,,,,, -66100,0.13653445,0.020324504,,,,,,,,,,,,,,,,, -66200,0.1275696,0.021388642,,,,,,,,,,,,,,,,, -66300,0.13290223,0.021450395,,,,,,,,,,,,,,,,, -66400,0.12810345,0.019695746,,,,,,,,,,,,,,,,, -66488,,,0.9949955940246582,0.015580584295094,0.743294926692901,0.9869229793548584,0.0493643842637538,0.2929701131226171,43793.0,0.9861359000205994,0.0525679588317871,0.2745586815742383,43793.0,21628.92591571808,33571.829740047455,21628.92591571808,11937.971812486649,3.069685459136963,0.0 -66500,0.13524196,0.01901317,,,,,,,,,,,,,,,,, -66600,0.12490438,0.020132985,,,,,,,,,,,,,,,,, -66700,0.122485414,0.019154735,,,,,,,,,,,,,,,,, -66800,0.14208348,0.019931903,,,,,,,,,,,,,,,,, -66900,0.11433101,0.017998334,,,,,,,,,,,,,,,,, -67000,0.13588808,0.018800054,,,,,,,,,,,,,,,,, -67100,0.11726905,0.017481253,,,,,,,,,,,,,,,,, -67200,0.11231855,0.017874233,,,,,,,,,,,,,,,,, -67226,,,0.995172679424286,0.0151609126478433,0.752292266305013,0.9870151281356812,0.0494721122086048,0.2935105053459713,43793.0,0.9861502647399902,0.0527020953595638,0.2758090958557493,43793.0,21868.94931316376,33934.243431806564,21868.94931316376,12060.300013542175,3.110136747360229,0.0 -67300,0.12332869,0.018359642,,,,,,,,,,,,,,,,, -67400,0.1256248,0.017170295,,,,,,,,,,,,,,,,, -67500,0.1254153,0.018607581,,,,,,,,,,,,,,,,, -67600,0.11873998,0.018343309,,,,,,,,,,,,,,,,, -67700,0.13726199,0.022129878,,,,,,,,,,,,,,,,, -67800,0.11763323,0.01657614,,,,,,,,,,,,,,,,, -67900,0.1318822,0.01895565,,,,,,,,,,,,,,,,, -67962,,,0.9949970245361328,0.0155762517824769,0.7404822138643696,0.9870175719261168,0.0498333163559436,0.2956191146081193,43793.0,0.9861688017845154,0.053035944700241,0.2768718733391812,43793.0,22108.92947244644,34293.75819349289,22108.92947244644,12179.768604755402,3.154653310775757,0.0 -68000,0.12338884,0.02004365,,,,,,,,,,,,,,,,, -68100,0.111271836,0.016002165,,,,,,,,,,,,,,,,, -68200,0.13647996,0.021014474,,,,,,,,,,,,,,,,, -68300,0.12913,0.018553196,,,,,,,,,,,,,,,,, -68400,0.13858072,0.019125164,,,,,,,,,,,,,,,,, -68500,0.13088793,0.017889516,,,,,,,,,,,,,,,,, -68600,0.12749647,0.017354049,,,,,,,,,,,,,,,,, -68695,,,0.9949944615364076,0.0155779477208852,0.7456122928440716,0.9869124293327332,0.0496644973754882,0.2947325860111379,43793.0,0.9861283302307128,0.0528410337865352,0.2770126310365181,43793.0,22349.121163129807,34658.99287319183,22349.121163129807,12304.749731063845,3.194903612136841,0.0 -68700,0.15409325,0.022390962,,,,,,,,,,,,,,,,, -68800,0.13172598,0.019711105,,,,,,,,,,,,,,,,, -68900,0.12153055,0.020217784,,,,,,,,,,,,,,,,, -69000,0.1594726,0.020929338,,,,,,,,,,,,,,,,, -69100,0.18378076,0.020400854,,,,,,,,,,,,,,,,, -69200,0.13581716,0.018827332,,,,,,,,,,,,,,,,, -69300,0.1407894,0.017457021,,,,,,,,,,,,,,,,, -69400,0.16152334,0.020606397,,,,,,,,,,,,,,,,, -69439,,,0.9948853850364684,0.0157805941998958,0.7421132570454949,0.9869562983512878,0.0499289967119693,0.2922757941451337,43793.0,0.9861435294151306,0.0531154498457908,0.2764680653602601,43793.0,22589.33414030075,35018.08116745949,22589.33414030075,12423.543547868729,3.256009817123413,0.0 -69500,0.12844208,0.014954738,,,,,,,,,,,,,,,,, -69600,0.1493294,0.019116662,,,,,,,,,,,,,,,,, -69700,0.12602009,0.018599493,,,,,,,,,,,,,,,,, -69800,0.14205904,0.019067405,,,,,,,,,,,,,,,,, -69900,0.1427471,0.021314211,,,,,,,,,,,,,,,,, -70000,0.11717898,0.016049676,,,,,,,,,,,,,,,,, -70100,0.119189,0.01676679,,,,,,,,,,,,,,,,, -70182,,,0.9949990510940552,0.0154972299933433,0.7395333602277423,0.986955463886261,0.049820426851511,0.2946346190690715,43793.0,0.9861018061637878,0.0530740618705749,0.2752897350878246,43793.0,22829.57661533356,35384.631796598434,22829.57661533356,12549.790276288986,3.2973122596740723,0.0 -70200,0.15842335,0.022511046,,,,,,,,,,,,,,,,, -70300,0.13315,0.01773868,,,,,,,,,,,,,,,,, -70400,0.14508897,0.020318883,,,,,,,,,,,,,,,,, -70500,0.13924067,0.018336423,,,,,,,,,,,,,,,,, -70600,0.1541756,0.021283945,,,,,,,,,,,,,,,,, -70700,0.13489659,0.018528005,,,,,,,,,,,,,,,,, -70800,0.13631445,0.018689226,,,,,,,,,,,,,,,,, -70900,0.14117743,0.017856501,,,,,,,,,,,,,,,,, -70928,,,0.9952114820480348,0.0149235017597675,0.7624726762016496,0.987009048461914,0.0501800142228603,0.292888755323264,43793.0,0.9861287474632264,0.0536137297749519,0.2747126367614844,43793.0,23069.656600236893,35743.65252280235,23069.656600236893,12668.670906543732,3.3372585773468018,0.0 -71000,0.14941747,0.01797999,,,,,,,,,,,,,,,,, -71100,0.13732325,0.01793948,,,,,,,,,,,,,,,,, -71200,0.12969725,0.017177546,,,,,,,,,,,,,,,,, -71300,0.16176368,0.01984375,,,,,,,,,,,,,,,,, -71400,0.14760116,0.02113524,,,,,,,,,,,,,,,,, -71500,0.1430824,0.01914459,,,,,,,,,,,,,,,,, -71600,0.12511699,0.018832205,,,,,,,,,,,,,,,,, -71669,,,0.9953161478042604,0.0145877562463283,0.7535696656906885,0.9870131015777588,0.0501160286366939,0.2949806488973432,43793.0,0.9861708879470824,0.0535033456981182,0.2766858995086599,43793.0,23309.76019668579,36109.615293741226,23309.76019668579,12794.468041181564,3.378878831863404,0.0 -71700,0.13827844,0.019045543,,,,,,,,,,,,,,,,, -71800,0.138404,0.017972331,,,,,,,,,,,,,,,,, -71900,0.142162,0.019623712,,,,,,,,,,,,,,,,, -72000,0.13440765,0.016547859,,,,,,,,,,,,,,,,, -72100,0.13725904,0.020029524,,,,,,,,,,,,,,,,, -72200,0.13760017,0.018525815,,,,,,,,,,,,,,,,, -72300,0.1370438,0.017638555,,,,,,,,,,,,,,,,, -72398,,,0.9954724311828612,0.0142005616798996,0.7746771146110315,0.9869810342788696,0.0500811003148555,0.2950637680379031,43793.0,0.9861599206924438,0.0534425675868988,0.2757994127724285,43793.0,23550.00160861016,36471.81417179108,23550.00160861016,12916.356031417848,3.4264605045318604,0.0 -72400,0.14109202,0.018049864,,,,,,,,,,,,,,,,, -72500,0.12698084,0.016573325,,,,,,,,,,,,,,,,, -72600,0.13818124,0.018180145,,,,,,,,,,,,,,,,, -72700,0.15074559,0.019932684,,,,,,,,,,,,,,,,, -72800,0.15108848,0.020412814,,,,,,,,,,,,,,,,, -72900,0.14320846,0.018028777,,,,,,,,,,,,,,,,, -73000,0.14783244,0.019545224,,,,,,,,,,,,,,,,, -73100,0.14245021,0.018734513,,,,,,,,,,,,,,,,, -73141,,,0.9954382181167604,0.0143452454358339,0.7782669256073729,0.9870086312294006,0.050230972468853,0.2939535993961091,43793.0,0.9861510992050172,0.0536785386502742,0.2758017772485302,43793.0,23790.09094119072,36836.50923585892,23790.09094119072,13040.898215293884,3.469341993331909,0.0 -73200,0.12288396,0.016800892,,,,,,,,,,,,,,,,, -73300,0.14264111,0.01873261,,,,,,,,,,,,,,,,, -73400,0.1413763,0.017550806,,,,,,,,,,,,,,,,, -73500,0.13504292,0.01766274,,,,,,,,,,,,,,,,, -73600,0.15460916,0.019734655,,,,,,,,,,,,,,,,, -73700,0.12811007,0.017222226,,,,,,,,,,,,,,,,, -73800,0.15337181,0.01881259,,,,,,,,,,,,,,,,, -73865,,,0.9954219460487366,0.014223325997591,0.77684622631777,0.9870402812957764,0.0503243319690227,0.2935714816912196,43793.0,0.9861910939216614,0.0537983998656272,0.2746981960824259,43793.0,24030.04671812057,37196.72438669205,24030.04671812057,13161.084740161896,3.51848578453064,0.0 -73900,0.1434904,0.019267436,,,,,,,,,,,,,,,,, -74000,0.13627127,0.01798719,,,,,,,,,,,,,,,,, -74100,0.14930913,0.017159251,,,,,,,,,,,,,,,,, -74200,0.14334229,0.018713979,,,,,,,,,,,,,,,,, -74300,0.14889993,0.019693008,,,,,,,,,,,,,,,,, -74400,0.14179018,0.017766893,,,,,,,,,,,,,,,,, -74500,0.15094648,0.02140798,,,,,,,,,,,,,,,,, -74600,0.12893732,0.017600177,,,,,,,,,,,,,,,,, -74606,,,0.9953264594078064,0.0145053453743457,0.7675750118704716,0.9869623780250548,0.0502690747380256,0.2933164821603256,43793.0,0.9861281514167786,0.0537189431488513,0.2753094927661393,43793.0,24270.110892534256,37557.48768091202,24270.110892534256,13281.721255779266,3.5607388019561768,0.0 -74700,0.13519621,0.015969934,,,,,,,,,,,,,,,,, -74800,0.13484581,0.018498523,,,,,,,,,,,,,,,,, -74900,0.13925411,0.017872008,,,,,,,,,,,,,,,,, -75000,0.14225467,0.01844257,,,,,,,,,,,,,,,,, -75100,0.14100252,0.019736106,,,,,,,,,,,,,,,,, -75200,0.14316474,0.01935876,,,,,,,,,,,,,,,,, -75300,0.13756868,0.017317316,,,,,,,,,,,,,,,,, -75339,,,0.9953559637069702,0.0144855827093124,0.7559158070368784,0.9869757294654846,0.0504377521574497,0.2932075498164984,43793.0,0.9861477017402648,0.0538434907793998,0.2762096863110901,43793.0,24510.027459144592,37920.07783651352,24510.027459144592,13404.296205759048,3.6379876136779785,0.0 -75400,0.15574029,0.018461017,,,,,,,,,,,,,,,,, -75500,0.13474561,0.016800884,,,,,,,,,,,,,,,,, -75600,0.14091791,0.018379876,,,,,,,,,,,,,,,,, -75700,0.13205825,0.01608644,,,,,,,,,,,,,,,,, -75800,0.16147698,0.018714081,,,,,,,,,,,,,,,,, -75900,0.140732,0.019754339,,,,,,,,,,,,,,,,, -76000,0.12469514,0.01763211,,,,,,,,,,,,,,,,, -76068,,,0.9953692555427552,0.0144447674974799,0.7646718876340945,0.9869931936264038,0.0504125356674194,0.2931289894339166,43793.0,0.9861472845077516,0.053802452981472,0.2767602742047537,43793.0,24750.19059085846,38281.042036771774,24750.19059085846,13525.030313014984,3.6831717491149902,0.0 -76100,0.14902018,0.01868058,,,,,,,,,,,,,,,,, -76200,0.11732294,0.015415348,,,,,,,,,,,,,,,,, -76300,0.15141018,0.017891683,,,,,,,,,,,,,,,,, -76400,0.13631262,0.018648388,,,,,,,,,,,,,,,,, -76500,0.13876246,0.016225724,,,,,,,,,,,,,,,,, -76600,0.14451952,0.016778288,,,,,,,,,,,,,,,,, -76700,0.13989684,0.018676734,,,,,,,,,,,,,,,,, -76800,0.12976007,0.015959464,,,,,,,,,,,,,,,,, -76810,,,0.995473325252533,0.0141021022573113,0.7793426664482765,0.9869822263717652,0.0504016950726509,0.2939869283770364,43793.0,0.986127495765686,0.0538610033690929,0.275816124039496,43793.0,24990.2333111763,38637.51325106621,24990.2333111763,13641.3963701725,3.724697113037109,0.0 -76900,0.14662534,0.01759448,,,,,,,,,,,,,,,,, -77000,0.14472894,0.018246597,,,,,,,,,,,,,,,,, -77100,0.15056096,0.019896016,,,,,,,,,,,,,,,,, -77200,0.122821845,0.014983941,,,,,,,,,,,,,,,,, -77300,0.15287034,0.020737106,,,,,,,,,,,,,,,,, -77400,0.1503993,0.018055946,,,,,,,,,,,,,,,,, -77500,0.13296697,0.015375957,,,,,,,,,,,,,,,,, -77543,,,0.995535135269165,0.013972028158605,0.7854258956826089,0.9869814515113832,0.0504111647605896,0.2938731668240078,43793.0,0.98613041639328,0.0538419187068939,0.2757654531108339,43793.0,25230.289999961853,38996.68863105774,25230.289999961853,13760.451066493988,3.767266511917114,0.0 -77600,0.14045347,0.01782123,,,,,,,,,,,,,,,,, -77700,0.15241593,0.019381596,,,,,,,,,,,,,,,,, -77800,0.13169624,0.016601548,,,,,,,,,,,,,,,,, -77900,0.14742446,0.01974331,,,,,,,,,,,,,,,,, -78000,0.12992586,0.017335467,,,,,,,,,,,,,,,,, -78100,0.15369512,0.021362888,,,,,,,,,,,,,,,,, -78200,0.13108483,0.017135452,,,,,,,,,,,,,,,,, -78284,,,0.9955268502235411,0.0140170855447649,0.7766022190074704,0.9870017170906068,0.0504525080323219,0.2935173090637798,43793.0,0.9861502647399902,0.0538810566067695,0.2762060660463843,43793.0,25470.411824703217,39358.22065544128,25470.411824703217,13881.799325227736,3.8089919090271,0.0 -78300,0.14591993,0.018044088,,,,,,,,,,,,,,,,, -78400,0.12879723,0.017001664,,,,,,,,,,,,,,,,, -78500,0.1345491,0.016864698,,,,,,,,,,,,,,,,, -78600,0.1342511,0.017064648,,,,,,,,,,,,,,,,, -78700,0.14972118,0.017120728,,,,,,,,,,,,,,,,, -78800,0.15621826,0.018373404,,,,,,,,,,,,,,,,, -78900,0.15561709,0.017953321,,,,,,,,,,,,,,,,, -79000,0.13674621,0.01857381,,,,,,,,,,,,,,,,, -79005,,,0.9955037236213684,0.0140764387324452,0.7753496218501175,0.9870216250419616,0.050493486225605,0.2931579732099656,43793.0,0.9861544370651244,0.0539372749626636,0.2760413132298542,43793.0,25710.480096817017,39721.13111758232,25710.480096817017,14004.569951534271,3.8572590351104736,0.0 -79100,0.14245117,0.017523209,,,,,,,,,,,,,,,,, -79200,0.1573789,0.017779542,,,,,,,,,,,,,,,,, -79300,0.1462768,0.01889063,,,,,,,,,,,,,,,,, -79400,0.13829648,0.01664496,,,,,,,,,,,,,,,,, -79500,0.13058077,0.017824603,,,,,,,,,,,,,,,,, -79600,0.1277846,0.017488686,,,,,,,,,,,,,,,,, -79700,0.14804608,0.017270364,,,,,,,,,,,,,,,,, -79743,,,0.9954732656478882,0.0141325732693076,0.7680403320439355,0.9870207905769348,0.0504891276359558,0.2931052477542326,43793.0,0.986162006855011,0.0539317578077316,0.2758741018370922,43793.0,25950.60027909279,40078.85696077347,25950.60027909279,14122.105741977692,3.906131267547608,0.0 -79800,0.13678321,0.017673288,,,,,,,,,,,,,,,,, -79900,0.14071803,0.018142305,,,,,,,,,,,,,,,,, -80000,0.14739402,0.017156558,,,,,,,,,,,,,,,,, -80100,0.16010132,0.0193055,,,,,,,,,,,,,,,,, -80200,0.16461496,0.020857092,,,,,,,,,,,,,,,,, -80300,0.17257975,0.017289313,,,,,,,,,,,,,,,,, -80400,0.13679563,0.017947003,,,,,,,,,,,,,,,,, -80484,,,0.9954727292060852,0.0140985697507858,0.7734134501513834,0.9870207905769348,0.0504884012043476,0.2932042189287904,43793.0,0.9861586689949036,0.0539308004081249,0.2760013251638327,43793.0,26190.832530498505,40436.92608857155,26190.832530498505,14239.878486156464,3.949522018432617,0.0 -80500,0.14784135,0.018021153,,,,,,,,,,,,,,,,, -80600,0.13841367,0.01866863,,,,,,,,,,,,,,,,, -80700,0.14459741,0.017428737,,,,,,,,,,,,,,,,, -80800,0.13929538,0.018011188,,,,,,,,,,,,,,,,, -80900,0.1349884,0.018690879,,,,,,,,,,,,,,,,, -81000,0.140292,0.017371586,,,,,,,,,,,,,,,,, -81100,0.13589355,0.017601768,,,,,,,,,,,,,,,,, -81200,0.14384757,0.01814095,,,,,,,,,,,,,,,,, -81227,,,0.9955683350563048,0.0138638988137245,0.7826635520548781,0.9870207905769348,0.0504884012043476,0.2933164355896822,43793.0,0.9861586689949036,0.0539308078587055,0.2759893265669187,43793.0,26430.913827180862,40793.525622844696,26430.913827180862,14356.332607507706,3.993171215057373,0.0 -81300,0.13030559,0.01722715,,,,,,,,,,,,,,,,, -81400,0.1353843,0.016823353,,,,,,,,,,,,,,,,, -81500,0.1445357,0.01713204,,,,,,,,,,,,,,,,, -81600,0.14510202,0.01853564,,,,,,,,,,,,,,,,, -81700,0.15021254,0.019671805,,,,,,,,,,,,,,,,, -81800,0.15536204,0.018534169,,,,,,,,,,,,,,,,, -81900,0.15465863,0.018212402,,,,,,,,,,,,,,,,, -81964,,,0.9954887628555298,0.014108401723206,0.7777829012838969,0.9870207905769348,0.0504884012043476,0.2932682244947826,43793.0,0.9861586689949036,0.0539308004081249,0.2758879511772302,43793.0,26671.08064627648,41158.192217350006,26671.08064627648,14480.77075123787,4.034946918487549,0.0 -82000,0.13252722,0.017571,,,,,,,,,,,,,,,,, -82100,0.13474755,0.018058304,,,,,,,,,,,,,,,,, -82200,0.13582559,0.017677173,,,,,,,,,,,,,,,,, -82300,0.14458074,0.01876116,,,,,,,,,,,,,,,,, -82400,0.14584236,0.019484477,,,,,,,,,,,,,,,,, -82500,0.1311143,0.01741667,,,,,,,,,,,,,,,,, -82600,0.12859033,0.017192207,,,,,,,,,,,,,,,,, -82684,,,0.9954996109008788,0.0140416109934449,0.773361696380792,0.9870207905769348,0.0504884012043476,0.2932646215050519,43793.0,0.9861586689949036,0.0539308078587055,0.2759877260387026,43793.0,26911.22525882721,41517.2912569046,26911.22525882721,14599.658869504929,4.078474044799805,0.0 -82700,0.13170715,0.017171806,,,,,,,,,,,,,,,,, -82800,0.12986082,0.016382022,,,,,,,,,,,,,,,,, -82900,0.15081035,0.01860652,,,,,,,,,,,,,,,,, -83000,0.14486125,0.01877108,,,,,,,,,,,,,,,,, -83100,0.13283297,0.016147725,,,,,,,,,,,,,,,,, -83200,0.14256941,0.01881572,,,,,,,,,,,,,,,,, -83300,0.14256652,0.018910864,,,,,,,,,,,,,,,,, -83400,0.1288206,0.017889012,,,,,,,,,,,,,,,,, -83424,,,0.995476007461548,0.0141355525702238,0.7666465209970005,0.9870207905769348,0.0504884012043476,0.2931580788234526,43793.0,0.9861586689949036,0.0539308078587055,0.2759603496560459,43793.0,27151.46985912323,41878.00029158592,27151.46985912323,14720.061032772064,4.120722055435181,0.0 -83500,0.14336582,0.017854054,,,,,,,,,,,,,,,,, -83600,0.13546099,0.018206896,,,,,,,,,,,,,,,,, -83700,0.15808508,0.019460626,,,,,,,,,,,,,,,,, -83800,0.13850275,0.01869957,,,,,,,,,,,,,,,,, -83900,0.14356908,0.020836571,,,,,,,,,,,,,,,,, -84000,0.15149838,0.019559221,,,,,,,,,,,,,,,,, -84100,0.13857771,0.019750323,,,,,,,,,,,,,,,,, -84163,,,0.9954543709754944,0.0141666149720549,0.7774010441861712,0.9870207905769348,0.0504884012043476,0.2932389238302139,43793.0,0.9861586689949036,0.0539308004081249,0.2759228082006909,43793.0,27391.469383001328,42237.39330744743,27391.469383001328,14839.390854358671,4.164043664932251,0.0 -84200,0.15812577,0.01983177,,,,,,,,,,,,,,,,, -84300,0.14061268,0.019550934,,,,,,,,,,,,,,,,, -84400,0.14219682,0.01785862,,,,,,,,,,,,,,,,, -84500,0.1351222,0.016189175,,,,,,,,,,,,,,,,, -84600,0.14496744,0.019199034,,,,,,,,,,,,,,,,, -84700,0.15414944,0.01847926,,,,,,,,,,,,,,,,, -84800,0.13591088,0.018759795,,,,,,,,,,,,,,,,, -84890,,,0.9955111145973206,0.0140264285728335,0.7820682909038227,0.9870207905769348,0.0504884012043476,0.2933124181851586,43793.0,0.9861586689949036,0.0539308078587055,0.2759544450025004,43793.0,27631.43978309632,42596.40437030792,27631.43978309632,14958.367205619812,4.2065746784210205,0.0 -84900,0.14565146,0.018853664,,,,,,,,,,,,,,,,, -85000,0.14673774,0.017867465,,,,,,,,,,,,,,,,, -85100,0.1640435,0.019667082,,,,,,,,,,,,,,,,, -85200,0.13350277,0.01808094,,,,,,,,,,,,,,,,, -85300,0.1468828,0.017477289,,,,,,,,,,,,,,,,, -85400,0.13735396,0.018733865,,,,,,,,,,,,,,,,, -85500,0.14132981,0.018459734,,,,,,,,,,,,,,,,, -85600,0.11797129,0.014262203,,,,,,,,,,,,,,,,, -85621,,,0.9955190420150756,0.013955594971776,0.7841828855297852,0.9870207905769348,0.0504884012043476,0.2931340940827518,43793.0,0.9861586689949036,0.0539308004081249,0.2759841329736199,43793.0,27871.646685361862,42959.50312495232,27871.646685361862,15081.195603370668,4.2489824295043945,0.0 -85700,0.12437001,0.01674617,,,,,,,,,,,,,,,,, -85800,0.14901982,0.018064085,,,,,,,,,,,,,,,,, -85900,0.14603212,0.020196548,,,,,,,,,,,,,,,,, -86000,0.13461082,0.018071357,,,,,,,,,,,,,,,,, -86100,0.1384828,0.017685188,,,,,,,,,,,,,,,,, -86200,0.13961421,0.01824008,,,,,,,,,,,,,,,,, -86300,0.13884865,0.019256787,,,,,,,,,,,,,,,,, -86345,,,0.99550598859787,0.0140627259388566,0.772074521198249,0.9870207905769348,0.0504884012043476,0.2932729704005158,43793.0,0.9861586689949036,0.0539308078587055,0.275911657756955,43793.0,28111.63972616196,43325.16566514969,28111.63972616196,15206.80028629303,4.29235053062439,0.0 -86400,0.14712206,0.016707286,,,,,,,,,,,,,,,,, -86500,0.15825157,0.01768835,,,,,,,,,,,,,,,,, -86600,0.12603454,0.017593019,,,,,,,,,,,,,,,,, -86700,0.12895595,0.01617703,,,,,,,,,,,,,,,,, -86800,0.1576209,0.018256972,,,,,,,,,,,,,,,,, -86900,0.14827287,0.017287884,,,,,,,,,,,,,,,,, -87000,0.14450864,0.01685747,,,,,,,,,,,,,,,,, -87074,,,0.9954939484596252,0.0140989450737833,0.7701163354328072,0.9870207905769348,0.0504884012043476,0.2930998900386766,43793.0,0.9861586689949036,0.0539308004081249,0.2760227125616841,43793.0,28351.80605864525,43684.0396668911,28351.80605864525,15325.437908411026,4.341654539108276,0.0 -87100,0.14955221,0.01914194,,,,,,,,,,,,,,,,, -87200,0.14328162,0.018746685,,,,,,,,,,,,,,,,, -87300,0.15977417,0.020932363,,,,,,,,,,,,,,,,, -87400,0.12882073,0.017793681,,,,,,,,,,,,,,,,, -87500,0.163185,0.017636364,,,,,,,,,,,,,,,,, -87600,0.1348452,0.01769206,,,,,,,,,,,,,,,,, -87700,0.13871577,0.018912798,,,,,,,,,,,,,,,,, -87800,0.14138429,0.016079267,,,,,,,,,,,,,,,,, -87801,,,0.9954929947853088,0.0141305215656757,0.771353052064392,0.9870207905769348,0.0504884012043476,0.2932041771177117,43793.0,0.9861586689949036,0.0539308078587055,0.2759060362578897,43793.0,28591.916484594345,44041.82602286339,28591.916484594345,15443.043601989746,4.391093969345093,0.0 -87900,0.13670844,0.017492509,,,,,,,,,,,,,,,,, -88000,0.14887698,0.019531399,,,,,,,,,,,,,,,,, -88100,0.1410408,0.018800007,,,,,,,,,,,,,,,,, -88200,0.1540235,0.02014619,,,,,,,,,,,,,,,,, -88300,0.15568647,0.018672427,,,,,,,,,,,,,,,,, -88400,0.12229537,0.017216504,,,,,,,,,,,,,,,,, -88500,0.14871278,0.018672029,,,,,,,,,,,,,,,,, -88532,,,0.9954800009727478,0.0140950288623571,0.781205735501981,0.9870207905769348,0.0504884012043476,0.2932369735857291,43793.0,0.9861586689949036,0.0539308078587055,0.2759228496513186,43793.0,28831.876116752625,44395.9159309864,28831.876116752625,15557.111010789871,4.433941125869751,0.0 -88600,0.15132858,0.018039519,,,,,,,,,,,,,,,,, -88700,0.14299895,0.017934145,,,,,,,,,,,,,,,,, -88800,0.14677382,0.017087666,,,,,,,,,,,,,,,,, -88900,0.123167306,0.015887072,,,,,,,,,,,,,,,,, -89000,0.13180694,0.018195923,,,,,,,,,,,,,,,,, -89100,0.14458232,0.018403918,,,,,,,,,,,,,,,,, -89200,0.13522749,0.017800365,,,,,,,,,,,,,,,,, -89274,,,0.9955880641937256,0.0138135934248566,0.7830603803705166,0.9870207905769348,0.0504884012043476,0.2932198575902922,43793.0,0.9861586689949036,0.0539308078587055,0.2758838902077703,43793.0,29071.96212220192,44755.5393948555,29071.96212220192,15676.584495544434,4.477561473846436,0.0 -89300,0.13311177,0.017470133,,,,,,,,,,,,,,,,, -89400,0.1276168,0.017984772,,,,,,,,,,,,,,,,, -89500,0.1303824,0.017525138,,,,,,,,,,,,,,,,, -89600,0.1607196,0.019191891,,,,,,,,,,,,,,,,, -89700,0.1279574,0.015529652,,,,,,,,,,,,,,,,, -89800,0.15252511,0.020816324,,,,,,,,,,,,,,,,, -89900,0.15130569,0.019274475,,,,,,,,,,,,,,,,, -90000,0.13488816,0.017857574,,,,,,,,,,,,,,,,, -90005,,,0.9954330921173096,0.0141439158469438,0.7768471852559291,0.9870207905769348,0.0504884012043476,0.2932248225737306,43793.0,0.9861586689949036,0.0539308004081249,0.2758904980645204,43793.0,29312.07736825943,45111.980585575104,29312.07736825943,15792.836474895475,4.530365467071533,0.0 -90100,0.14727844,0.018112585,,,,,,,,,,,,,,,,, -90200,0.12809043,0.018352186,,,,,,,,,,,,,,,,, -90300,0.13274069,0.017588863,,,,,,,,,,,,,,,,, -90400,0.12962218,0.01714698,,,,,,,,,,,,,,,,, -90500,0.15108708,0.018986888,,,,,,,,,,,,,,,,, -90600,0.15056776,0.01726895,,,,,,,,,,,,,,,,, -90700,0.15008274,0.018054774,,,,,,,,,,,,,,,,, -90745,,,0.9954633712768556,0.0141695011407136,0.7759647711457989,0.9870207905769348,0.0504884012043476,0.2933183661124961,43793.0,0.9861586689949036,0.0539308004081249,0.2759802245458693,43793.0,29552.14755630493,45476.13833808899,29552.14755630493,15916.857501029968,4.576200008392334,0.0 -90800,0.12965876,0.016624741,,,,,,,,,,,,,,,,, -90900,0.15577061,0.01948124,,,,,,,,,,,,,,,,, -91000,0.14428106,0.018014083,,,,,,,,,,,,,,,,, -91100,0.15603983,0.018613812,,,,,,,,,,,,,,,,, -91200,0.13957024,0.019597316,,,,,,,,,,,,,,,,, -91300,0.13528724,0.018343087,,,,,,,,,,,,,,,,, -91400,0.12975484,0.01667662,,,,,,,,,,,,,,,,, -91480,,,0.9955159425735474,0.0140622379258275,0.7617297949769014,0.9870207905769348,0.0504884012043476,0.2932435573506383,43793.0,0.9861586689949036,0.0539308078587055,0.2760763937091977,43793.0,29792.3046181202,45834.32055068016,29792.3046181202,16034.809426784515,4.627282619476318,0.0 -91500,0.15038015,0.018093657,,,,,,,,,,,,,,,,, -91600,0.13742316,0.017199783,,,,,,,,,,,,,,,,, -91700,0.15603058,0.016631564,,,,,,,,,,,,,,,,, -91800,0.14799702,0.017485896,,,,,,,,,,,,,,,,, -91900,0.12404871,0.017706903,,,,,,,,,,,,,,,,, -92000,0.13559271,0.018096944,,,,,,,,,,,,,,,,, -92100,0.13496214,0.016199984,,,,,,,,,,,,,,,,, -92200,0.14255138,0.018173229,,,,,,,,,,,,,,,,, -92223,,,0.9955058097839355,0.014057345688343,0.7794176027937787,0.9870207905769348,0.0504884012043476,0.2930912488721401,43793.0,0.9861586689949036,0.0539308078587055,0.2759261753171792,43793.0,30032.342000246048,46185.14011597633,30032.342000246048,16145.526446580889,4.671795606613159,0.0 -92300,0.15260367,0.017639823,,,,,,,,,,,,,,,,, -92400,0.13424182,0.017378133,,,,,,,,,,,,,,,,, -92500,0.15172245,0.018073248,,,,,,,,,,,,,,,,, -92600,0.15518066,0.020676063,,,,,,,,,,,,,,,,, -92700,0.15575251,0.018636582,,,,,,,,,,,,,,,,, -92800,0.12679707,0.01622234,,,,,,,,,,,,,,,,, -92900,0.1376493,0.017718326,,,,,,,,,,,,,,,,, -92972,,,0.9955204129219056,0.0139325829222798,0.7852867760278921,0.9870207905769348,0.0504884012043476,0.2931630314705667,43793.0,0.9861586689949036,0.0539308004081249,0.2760432220871991,43793.0,30272.507484912872,46539.68813109398,30272.507484912872,16259.84417271614,4.716129779815674,0.0 -93000,0.13732846,0.018277511,,,,,,,,,,,,,,,,, -93100,0.1340368,0.01613308,,,,,,,,,,,,,,,,, -93200,0.14637399,0.016468525,,,,,,,,,,,,,,,,, -93300,0.14683756,0.016549231,,,,,,,,,,,,,,,,, -93400,0.1406019,0.016001364,,,,,,,,,,,,,,,,, -93500,0.13209663,0.016748302,,,,,,,,,,,,,,,,, -93600,0.15236115,0.018969012,,,,,,,,,,,,,,,,, -93700,0.1424079,0.017064223,,,,,,,,,,,,,,,,, -93718,,,0.9955332279205322,0.0139707894995808,0.7733520647228189,0.9870207905769348,0.0504884012043476,0.2931728770329977,43793.0,0.9861586689949036,0.0539308078587055,0.2760487360250577,43793.0,30512.44143342972,46902.35689115524,30512.44143342972,16382.511649608612,4.762882471084595,0.0 -93800,0.13677457,0.015881607,,,,,,,,,,,,,,,,, -93900,0.13505125,0.017525565,,,,,,,,,,,,,,,,, -94000,0.15020725,0.019891214,,,,,,,,,,,,,,,,, -94100,0.13749072,0.018302212,,,,,,,,,,,,,,,,, -94200,0.14091398,0.017070074,,,,,,,,,,,,,,,,, -94300,0.13831267,0.019385608,,,,,,,,,,,,,,,,, -94400,0.14195628,0.01717742,,,,,,,,,,,,,,,,, -94453,,,0.9954761266708374,0.0141239240765571,0.7818585828536757,0.9870207905769348,0.0504884012043476,0.2932003254588689,43793.0,0.9861586689949036,0.0539308078587055,0.2758970474532101,43793.0,30752.5343811512,47258.04355049133,30752.5343811512,16498.03994011879,4.806723833084106,0.0 -94500,0.15519957,0.01934858,,,,,,,,,,,,,,,,, -94600,0.13711685,0.017319178,,,,,,,,,,,,,,,,, -94700,0.13595264,0.017203925,,,,,,,,,,,,,,,,, -94800,0.1378547,0.016132357,,,,,,,,,,,,,,,,, -94900,0.123002075,0.015107166,,,,,,,,,,,,,,,,, -95000,0.13677387,0.016538931,,,,,,,,,,,,,,,,, -95100,0.14703186,0.017898392,,,,,,,,,,,,,,,,, -95184,,,0.9954842329025269,0.0140874776989221,0.7673942179190336,0.9870207905769348,0.0504884012043476,0.2931838900380704,43793.0,0.9861586689949036,0.0539308078587055,0.2759964447818105,43793.0,30992.618040323257,47611.36131834984,30992.618040323257,16611.206999063492,4.85191011428833,0.0 -95200,0.15986696,0.017734798,,,,,,,,,,,,,,,,, -95300,0.13406879,0.01765818,,,,,,,,,,,,,,,,, -95400,0.16389747,0.016420493,,,,,,,,,,,,,,,,, -95500,0.1279454,0.01598899,,,,,,,,,,,,,,,,, -95600,0.1323502,0.017266273,,,,,,,,,,,,,,,,, -95700,0.14820032,0.016727164,,,,,,,,,,,,,,,,, -95800,0.13815187,0.016666206,,,,,,,,,,,,,,,,, -95900,0.14647661,0.01695809,,,,,,,,,,,,,,,,, -95924,,,0.9954838156700134,0.014161848463118,0.7696215638128405,0.9870207905769348,0.0504884012043476,0.2931871120912127,43793.0,0.9861586689949036,0.0539308078587055,0.2759322101408783,43793.0,31232.86323785782,47971.10742545128,31232.86323785782,16730.642671346664,4.896559953689575,0.0 -96000,0.14392513,0.017060487,,,,,,,,,,,,,,,,, -96100,0.13515429,0.017266674,,,,,,,,,,,,,,,,, -96200,0.12766603,0.01655768,,,,,,,,,,,,,,,,, -96300,0.13557446,0.018805912,,,,,,,,,,,,,,,,, -96400,0.14208424,0.018573597,,,,,,,,,,,,,,,,, -96500,0.13977985,0.016506692,,,,,,,,,,,,,,,,, -96600,0.14639933,0.019248066,,,,,,,,,,,,,,,,, -96659,,,0.9955300688743592,0.0139715131372213,0.7872416582912839,0.9870207905769348,0.0504884012043476,0.2932467264780176,43793.0,0.9861586689949036,0.0539308078587055,0.2759397003339527,43793.0,31473.12002182007,48329.34169435501,31473.12002182007,16848.551023483276,4.945436477661133,0.0 -96700,0.14820293,0.018967628,,,,,,,,,,,,,,,,, -96800,0.17532493,0.022322686,,,,,,,,,,,,,,,,, -96900,0.14090997,0.01919475,,,,,,,,,,,,,,,,, -97000,0.13406785,0.017279284,,,,,,,,,,,,,,,,, -97100,0.17205675,0.017465323,,,,,,,,,,,,,,,,, -97200,0.13716276,0.01923438,,,,,,,,,,,,,,,,, -97300,0.14466783,0.018028418,,,,,,,,,,,,,,,,, -97395,,,0.9954854249954224,0.0140521274879574,0.7760487820006098,0.9870207905769348,0.0504884012043476,0.2932641153462439,43793.0,0.9861586689949036,0.0539308078587055,0.2760300954617187,43793.0,31713.225883245468,48681.89289999008,31713.225883245468,16960.928512334824,4.992716550827026,0.0 -97400,0.15012589,0.018115321,,,,,,,,,,,,,,,,, -97500,0.14158995,0.018170692,,,,,,,,,,,,,,,,, -97600,0.13074027,0.017151134,,,,,,,,,,,,,,,,, -97700,0.12944023,0.01535202,,,,,,,,,,,,,,,,, -97800,0.1329367,0.017929714,,,,,,,,,,,,,,,,, -97900,0.13515873,0.016800726,,,,,,,,,,,,,,,,, -98000,0.13490011,0.017982941,,,,,,,,,,,,,,,,, -98100,0.14594868,0.01740945,,,,,,,,,,,,,,,,, -98139,,,0.99550861120224,0.0140364868566393,0.775531835162473,0.9870207905769348,0.0504884012043476,0.2930844133223836,43793.0,0.9861586689949036,0.0539308078587055,0.2759224751185829,43793.0,31953.380586862564,49039.28237915039,31953.380586862564,17078.096727132797,5.03900408744812,0.0 -98200,0.14063254,0.016732642,,,,,,,,,,,,,,,,, -98300,0.14160797,0.018603861,,,,,,,,,,,,,,,,, -98400,0.14041086,0.016923117,,,,,,,,,,,,,,,,, -98500,0.1451229,0.021004673,,,,,,,,,,,,,,,,, -98600,0.13494423,0.017406939,,,,,,,,,,,,,,,,, -98700,0.14342159,0.017839544,,,,,,,,,,,,,,,,, -98800,0.16619354,0.019297231,,,,,,,,,,,,,,,,, -98892,,,0.9955294132232666,0.0139757683500647,0.7690663991926115,0.9870207905769348,0.0504884012043476,0.2931506027499619,43793.0,0.9861586689949036,0.0539308078587055,0.2759939712103814,43793.0,32193.38186430931,49391.1727848053,32193.38186430931,17189.921184539795,5.08355450630188,0.0 -98900,0.13981853,0.01753954,,,,,,,,,,,,,,,,, -99000,0.13390791,0.016036369,,,,,,,,,,,,,,,,, -99100,0.12550303,0.016189614,,,,,,,,,,,,,,,,, -99200,0.14880553,0.019377867,,,,,,,,,,,,,,,,, -99300,0.14164545,0.019030573,,,,,,,,,,,,,,,,, -99400,0.14277565,0.018182734,,,,,,,,,,,,,,,,, -99500,0.1314918,0.01787305,,,,,,,,,,,,,,,,, -99600,0.13369879,0.014960822,,,,,,,,,,,,,,,,, -99635,,,0.9954495429992676,0.0142050180584192,0.770540888307726,0.9870207905769348,0.0504884012043476,0.2932130341888327,43793.0,0.9861586689949036,0.0539308078587055,0.2758982340598774,43793.0,32433.47683668137,49744.64137840271,32433.47683668137,17303.229751110077,5.12899374961853,0.0 -99700,0.14937724,0.018030997,,,,,,,,,,,,,,,,, -99800,0.14504553,0.019355228,,,,,,,,,,,,,,,,, -99900,0.14368702,0.017650168,,,,,,,,,,,,,,,,, -100000,0.13398048,0.018066503,,,,,,,,,,,,,,,,, -100100,0.14445503,0.018236557,,,,,,,,,,,,,,,,, -100200,0.14735027,0.01826548,,,,,,,,,,,,,,,,, -100300,0.1368939,0.018001858,,,,,,,,,,,,,,,,, -100387,,,0.995512068271637,0.0140080070123076,0.7855383859173153,0.9870207905769348,0.0504884012043476,0.2930783521607571,43793.0,0.9861586689949036,0.0539308078587055,0.2759344706039897,43793.0,32673.473799943924,50092.74442219734,32673.473799943924,17411.270839214325,5.1742777824401855,0.0 -100400,0.14417662,0.017443115,,,,,,,,,,,,,,,,, -100500,0.13720298,0.015398592,,,,,,,,,,,,,,,,, -100600,0.15683979,0.018922312,,,,,,,,,,,,,,,,, -100700,0.14201671,0.019240241,,,,,,,,,,,,,,,,, -100800,0.14151572,0.017266065,,,,,,,,,,,,,,,,, -100900,0.1522225,0.019500094,,,,,,,,,,,,,,,,, -101000,0.15416074,0.019538838,,,,,,,,,,,,,,,,, -101100,0.14206329,0.016151309,,,,,,,,,,,,,,,,, -101136,,,0.9954873919487,0.0140253193676471,0.7812430634313685,0.9870207905769348,0.0504884012043476,0.293090200785777,43793.0,0.9861586689949036,0.0539308078587055,0.2759272854940975,43793.0,32913.52817153931,50445.39581871033,32913.52817153931,17523.80086541176,5.221256732940674,0.0 -101200,0.13331057,0.01672442,,,,,,,,,,,,,,,,, -101300,0.14078403,0.017171627,,,,,,,,,,,,,,,,, -101400,0.13860579,0.018766839,,,,,,,,,,,,,,,,, -101500,0.12930399,0.017887719,,,,,,,,,,,,,,,,, -101600,0.14293022,0.017884415,,,,,,,,,,,,,,,,, -101700,0.1461058,0.020213554,,,,,,,,,,,,,,,,, -101800,0.13535534,0.0148542905,,,,,,,,,,,,,,,,, -101875,,,0.995502471923828,0.014068104326725,0.775914528450701,0.9870207905769348,0.0504884012043476,0.2931656029885052,43793.0,0.9861586689949036,0.0539308004081249,0.2759245002265661,43793.0,33153.59733271599,50801.31637907028,33153.59733271599,17639.583562850952,5.267125368118286,0.0 -101900,0.13113108,0.016688121,,,,,,,,,,,,,,,,, -102000,0.1422915,0.017991835,,,,,,,,,,,,,,,,, -102100,0.13769025,0.017061964,,,,,,,,,,,,,,,,, -102200,0.13921396,0.016347129,,,,,,,,,,,,,,,,, -102300,0.13599886,0.016767615,,,,,,,,,,,,,,,,, -102400,0.13322414,0.018313041,,,,,,,,,,,,,,,,, -102500,0.1423773,0.019692916,,,,,,,,,,,,,,,,, -102600,0.1391722,0.018049343,,,,,,,,,,,,,,,,, -102621,,,0.995550572872162,0.0140123721212148,0.776397050638031,0.9870207905769348,0.0504884012043476,0.2933286895597028,43793.0,0.9861586689949036,0.0539308078587055,0.2758684385834484,43793.0,33393.808596372604,51153.58129143715,33393.808596372604,17751.57058095932,5.313305616378784,0.0 -102700,0.13612969,0.018113079,,,,,,,,,,,,,,,,, -102800,0.14678818,0.019702844,,,,,,,,,,,,,,,,, -102900,0.14847563,0.017284475,,,,,,,,,,,,,,,,, -103000,0.13375369,0.017790372,,,,,,,,,,,,,,,,, -103100,0.14280327,0.017872782,,,,,,,,,,,,,,,,, -103200,0.12257039,0.016322913,,,,,,,,,,,,,,,,, -103300,0.1445406,0.019780727,,,,,,,,,,,,,,,,, -103365,,,0.995477855205536,0.0140792066231369,0.7638450164442087,0.9870207905769348,0.0504884012043476,0.2932093756277367,43793.0,0.9861586689949036,0.0539308004081249,0.2760386250555259,43793.0,33633.803196430206,51508.68910455704,33633.803196430206,17866.617556095123,5.3591132164001465,0.0 -103400,0.13669667,0.018577045,,,,,,,,,,,,,,,,, -103500,0.14400007,0.016988661,,,,,,,,,,,,,,,,, -103600,0.13454753,0.016361428,,,,,,,,,,,,,,,,, -103700,0.13374494,0.016851312,,,,,,,,,,,,,,,,, -103800,0.14509296,0.018222595,,,,,,,,,,,,,,,,, -103900,0.15388097,0.019524412,,,,,,,,,,,,,,,,, -104000,0.12247548,0.015885625,,,,,,,,,,,,,,,,, -104100,0.13117829,0.015126694,,,,,,,,,,,,,,,,, -104114,,,0.9954692721366882,0.0141795184463262,0.7776649159084392,0.9870207905769348,0.0504884012043476,0.2932161456923653,43793.0,0.9861586689949036,0.0539308004081249,0.2759427131775175,43793.0,33873.81412649155,51861.71270442009,33873.81412649155,17979.562771081924,5.405600786209106,0.0 -104200,0.14941257,0.01804605,,,,,,,,,,,,,,,,, -104300,0.14064534,0.019232495,,,,,,,,,,,,,,,,, -104400,0.13777815,0.01661393,,,,,,,,,,,,,,,,, -104500,0.14956959,0.020827077,,,,,,,,,,,,,,,,, -104600,0.15123571,0.018786898,,,,,,,,,,,,,,,,, -104700,0.13951445,0.018325308,,,,,,,,,,,,,,,,, -104800,0.1385507,0.018221458,,,,,,,,,,,,,,,,, -104849,,,0.995512306690216,0.013982149772346,0.7832667867806614,0.9870207905769348,0.0504884012043476,0.2931969650094557,43793.0,0.9861586689949036,0.0539308004081249,0.2760546916619492,43793.0,34113.82161974907,52211.25935649872,34113.82161974907,18089.034460544583,5.451803922653198,0.0 -104900,0.13941625,0.018156374,,,,,,,,,,,,,,,,, -105000,0.13915738,0.020104451,,,,,,,,,,,,,,,,, -105100,0.14496481,0.019909836,,,,,,,,,,,,,,,,, -105200,0.1403258,0.016725946,,,,,,,,,,,,,,,,, -105300,0.14364083,0.01799261,,,,,,,,,,,,,,,,, -105400,0.13226178,0.016309746,,,,,,,,,,,,,,,,, -105500,0.1615122,0.019722095,,,,,,,,,,,,,,,,, -105600,0.1504616,0.016966969,,,,,,,,,,,,,,,,, -105602,,,0.9954873919487,0.0140405539423227,0.7785648148482405,0.9870207905769348,0.0504884012043476,0.2932527154438649,43793.0,0.9861586689949036,0.0539308004081249,0.2759460188977808,43793.0,34353.93736720085,52563.40596866608,34353.93736720085,18200.998265981674,5.49837327003479,0.0 -105700,0.1456249,0.018100075,,,,,,,,,,,,,,,,, -105800,0.1321239,0.019115508,,,,,,,,,,,,,,,,, -105900,0.15897615,0.01868429,,,,,,,,,,,,,,,,, -106000,0.14308698,0.017499827,,,,,,,,,,,,,,,,, -106100,0.13218486,0.017553326,,,,,,,,,,,,,,,,, -106200,0.1631953,0.019307988,,,,,,,,,,,,,,,,, -106300,0.1339446,0.01667958,,,,,,,,,,,,,,,,, -106346,,,0.9955300688743592,0.0140305142849683,0.7752093696893039,0.9870207905769348,0.0504884012043476,0.2931428330709554,43793.0,0.9861586689949036,0.0539308004081249,0.2760089875594442,43793.0,34594.05852675438,52914.76798701286,34594.05852675438,18312.1730606556,5.544644832611084,0.0 -106400,0.14198901,0.016118484,,,,,,,,,,,,,,,,, -106500,0.13322467,0.018473357,,,,,,,,,,,,,,,,, -106600,0.1239292,0.015840441,,,,,,,,,,,,,,,,, -106700,0.12467937,0.0154498,,,,,,,,,,,,,,,,, -106800,0.14512475,0.016810916,,,,,,,,,,,,,,,,, -106900,0.12858075,0.01646074,,,,,,,,,,,,,,,,, -107000,0.12921354,0.015692472,,,,,,,,,,,,,,,,, -107096,,,0.995497703552246,0.0140161113813519,0.7694205666819729,0.9870207905769348,0.0504884012043476,0.293144991287255,43793.0,0.9861586689949036,0.0539308004081249,0.275923875209314,43793.0,34834.22911691666,53265.606330394745,34834.22911691666,18422.77271294593,5.592383623123169,0.0 -107100,0.13240758,0.016904425,,,,,,,,,,,,,,,,, -107200,0.15537679,0.018127937,,,,,,,,,,,,,,,,, -107300,0.15690756,0.01903339,,,,,,,,,,,,,,,,, -107400,0.14671698,0.018169181,,,,,,,,,,,,,,,,, -107500,0.14882164,0.01845048,,,,,,,,,,,,,,,,, -107600,0.17824997,0.020810489,,,,,,,,,,,,,,,,, -107700,0.1448715,0.01874718,,,,,,,,,,,,,,,,, -107800,0.13735645,0.018339485,,,,,,,,,,,,,,,,, -107836,,,0.995468020439148,0.0141827668994665,0.7754401981422524,0.9870207905769348,0.0504884012043476,0.2932577357417439,43793.0,0.9861586689949036,0.0539308078587055,0.2759135417957489,43793.0,35074.33372759819,53620.07576179504,35074.33372759819,18537.066703557968,5.642268657684326,0.0 -107900,0.14029181,0.017699905,,,,,,,,,,,,,,,,, -108000,0.1638379,0.02155449,,,,,,,,,,,,,,,,, -108100,0.13450167,0.017077142,,,,,,,,,,,,,,,,, -108200,0.1416193,0.015630603,,,,,,,,,,,,,,,,, -108300,0.1338012,0.017657371,,,,,,,,,,,,,,,,, -108400,0.13378246,0.018549051,,,,,,,,,,,,,,,,, -108500,0.13853681,0.018226301,,,,,,,,,,,,,,,,, -108589,,,0.9955382347106934,0.0139516443014144,0.7749240428541433,0.9870207905769348,0.0504884012043476,0.2932017358982969,43793.0,0.9861586689949036,0.0539308004081249,0.2759655696714156,43793.0,35314.36759090424,53972.82032966614,35314.36759090424,18649.70942378044,5.689713478088379,0.0 -108600,0.14697294,0.016835047,,,,,,,,,,,,,,,,, -108700,0.15583257,0.018583797,,,,,,,,,,,,,,,,, -108800,0.124120384,0.018057683,,,,,,,,,,,,,,,,, -108900,0.13805385,0.016082434,,,,,,,,,,,,,,,,, -109000,0.13201135,0.017439473,,,,,,,,,,,,,,,,, -109100,0.1468288,0.018525552,,,,,,,,,,,,,,,,, -109200,0.13911791,0.01873382,,,,,,,,,,,,,,,,, -109300,0.14814615,0.020079767,,,,,,,,,,,,,,,,, -109339,,,0.995488941669464,0.0140517679974436,0.783534154624986,0.9870207905769348,0.0504884012043476,0.2930661018682657,43793.0,0.9861586689949036,0.0539308078587055,0.275938599498735,43793.0,35554.36666512489,54324.66298913956,35554.36666512489,18761.4860200882,5.736135482788086,0.0 -109400,0.13906725,0.018993907,,,,,,,,,,,,,,,,, -109500,0.15557125,0.017644322,,,,,,,,,,,,,,,,, -109600,0.12903921,0.015912918,,,,,,,,,,,,,,,,, -109700,0.14003187,0.019410318,,,,,,,,,,,,,,,,, -109800,0.13915709,0.01811094,,,,,,,,,,,,,,,,, -109900,0.14565536,0.016464667,,,,,,,,,,,,,,,,, -110000,0.1394034,0.020731747,,,,,,,,,,,,,,,,, -110088,,,0.9954979419708252,0.0140773402526974,0.7767405525270463,0.9870207905769348,0.0504884012043476,0.2932059522267767,43793.0,0.9861586689949036,0.0539308004081249,0.2759169175102017,43793.0,35794.49884533882,54679.16133594513,35794.49884533882,18875.78289008141,5.784965753555298,0.0 -110100,0.13362478,0.018170329,,,,,,,,,,,,,,,,, -110200,0.15104806,0.016833106,,,,,,,,,,,,,,,,, -110300,0.14018233,0.018509094,,,,,,,,,,,,,,,,, -110400,0.15318419,0.01740897,,,,,,,,,,,,,,,,, -110500,0.12447828,0.015304405,,,,,,,,,,,,,,,,, -110600,0.13374637,0.016425837,,,,,,,,,,,,,,,,, -110700,0.14136893,0.01844713,,,,,,,,,,,,,,,,, -110800,0.13276361,0.018571807,,,,,,,,,,,,,,,,, -110825,,,0.99551659822464,0.0140205714851617,0.7701554597187098,0.9870207905769348,0.0504884012043476,0.2933111682688331,43793.0,0.9861586689949036,0.0539308004081249,0.2759136667682885,43793.0,36034.63081741333,55027.872133016586,36034.63081741333,18984.285687446594,5.839473009109497,0.0 -110900,0.13773572,0.017749809,,,,,,,,,,,,,,,,, -111000,0.15922433,0.019248059,,,,,,,,,,,,,,,,, -111100,0.18307224,0.01797043,,,,,,,,,,,,,,,,, -111200,0.13555542,0.018029904,,,,,,,,,,,,,,,,, -111300,0.13329777,0.0177726,,,,,,,,,,,,,,,,, -111400,0.1402655,0.019344034,,,,,,,,,,,,,,,,, -111500,0.12460314,0.01730499,,,,,,,,,,,,,,,,, -111571,,,0.995445728302002,0.0141978915780782,0.7755410798187834,0.9870207905769348,0.0504884012043476,0.2932146665047697,43793.0,0.9861586689949036,0.0539308004081249,0.2758906413013993,43793.0,36274.76571893692,55378.53616476059,36274.76571893692,19094.74762225151,5.886816501617432,0.0 -111600,0.15715414,0.018894508,,,,,,,,,,,,,,,,, -111700,0.13462801,0.019224858,,,,,,,,,,,,,,,,, -111800,0.12267219,0.01730501,,,,,,,,,,,,,,,,, -111900,0.15790823,0.019296773,,,,,,,,,,,,,,,,, -112000,0.1441634,0.0184985,,,,,,,,,,,,,,,,, -112100,0.13846648,0.017844344,,,,,,,,,,,,,,,,, -112200,0.14368875,0.017465223,,,,,,,,,,,,,,,,, -112296,,,0.9955244064331056,0.0139960274100303,0.7807075584140233,0.9870207905769348,0.0504884012043476,0.2931795272519896,43793.0,0.9861586689949036,0.0539308004081249,0.2759193595748794,43793.0,36514.839708566666,55726.74456644058,36514.839708566666,19202.802329540253,5.942864656448364,0.0 -112300,0.13529179,0.015733825,,,,,,,,,,,,,,,,, -112400,0.15872668,0.016242206,,,,,,,,,,,,,,,,, -112500,0.142346,0.018481247,,,,,,,,,,,,,,,,, -112600,0.13953026,0.015723126,,,,,,,,,,,,,,,,, -112700,0.15792103,0.01533871,,,,,,,,,,,,,,,,, -112800,0.1355364,0.017466618,,,,,,,,,,,,,,,,, -112900,0.13494791,0.016955264,,,,,,,,,,,,,,,,, -113000,0.1437738,0.018597385,,,,,,,,,,,,,,,,, -113040,,,0.9955311417579652,0.0139847984537482,0.7829461868920284,0.9870207905769348,0.0504884012043476,0.2933635049845934,43793.0,0.9861586689949036,0.0539308078587055,0.2760884196260506,43793.0,36754.81792402268,56079.51251745224,36754.81792402268,19315.524201393127,5.990647315979004,0.0 -113100,0.12989663,0.017826231,,,,,,,,,,,,,,,,, -113200,0.1487563,0.018685833,,,,,,,,,,,,,,,,, -113300,0.13634428,0.016995167,,,,,,,,,,,,,,,,, -113400,0.14694701,0.017354487,,,,,,,,,,,,,,,,, -113500,0.1327337,0.017625023,,,,,,,,,,,,,,,,, -113600,0.136678,0.015994208,,,,,,,,,,,,,,,,, -113700,0.12671642,0.017713375,,,,,,,,,,,,,,,,, -113788,,,0.9954670071601868,0.0140660954639315,0.7718478538114525,0.9870207905769348,0.0504884012043476,0.293139372132737,43793.0,0.9861586689949036,0.0539308078587055,0.2759159872101339,43793.0,36994.844656705856,56427.37561607361,36994.844656705856,19423.286600351334,6.043809175491333,0.0 -113800,0.14796574,0.018568635,,,,,,,,,,,,,,,,, -113900,0.12918626,0.019596135,,,,,,,,,,,,,,,,, -114000,0.1257488,0.015524527,,,,,,,,,,,,,,,,, -114100,0.14089175,0.018285329,,,,,,,,,,,,,,,,, -114200,0.14378367,0.017886473,,,,,,,,,,,,,,,,, -114300,0.13597192,0.016148424,,,,,,,,,,,,,,,,, -114400,0.12387131,0.0169786,,,,,,,,,,,,,,,,, -114500,0.13511986,0.018521711,,,,,,,,,,,,,,,,, -114532,,,0.9955251216888428,0.0140108959749341,0.7738367760206022,0.9870207905769348,0.0504884012043476,0.293188261011175,43793.0,0.9861586689949036,0.0539308004081249,0.2759121370281739,43793.0,37234.62434768677,56782.48134112358,37234.62434768677,19538.20528769493,6.430257320404053,0.0 -114600,0.14371417,0.01881657,,,,,,,,,,,,,,,,, -114700,0.15208863,0.018660357,,,,,,,,,,,,,,,,, -114800,0.13845576,0.016545422,,,,,,,,,,,,,,,,, -114900,0.13966307,0.01729806,,,,,,,,,,,,,,,,, -115000,0.13846672,0.017094428,,,,,,,,,,,,,,,,, -115100,0.13391206,0.016578969,,,,,,,,,,,,,,,,, -115200,0.14478803,0.0202352,,,,,,,,,,,,,,,,, -115275,,,0.9954529404640198,0.0141806257888674,0.7671019554487162,0.9870207905769348,0.0504884012043476,0.2932588652343885,43793.0,0.9861586689949036,0.0539308004081249,0.2759276565912759,43793.0,37474.78272938728,57132.64653587341,37474.78272938728,19648.142145633698,6.480026960372925,0.0 -115300,0.13324966,0.0139926,,,,,,,,,,,,,,,,, -115400,0.14426321,0.018863384,,,,,,,,,,,,,,,,, -115500,0.15063569,0.018900601,,,,,,,,,,,,,,,,, -115600,0.13851342,0.01738066,,,,,,,,,,,,,,,,, -115700,0.13410382,0.01719949,,,,,,,,,,,,,,,,, -115800,0.14238915,0.01691441,,,,,,,,,,,,,,,,, -115900,0.13578592,0.018700937,,,,,,,,,,,,,,,,, -116000,0.13517843,0.01730043,,,,,,,,,,,,,,,,, -116018,,,0.9955293536186218,0.0140388701111078,0.7772930990901283,0.9870207905769348,0.0504884012043476,0.2932376430301499,43793.0,0.9861586689949036,0.0539308078587055,0.2760005242200415,43793.0,37714.924875974655,57486.06534719467,37714.924875974655,19761.3506834507,6.528173208236694,0.0 -116100,0.1276045,0.015683342,,,,,,,,,,,,,,,,, -116200,0.14649244,0.017845472,,,,,,,,,,,,,,,,, -116300,0.13219023,0.016582005,,,,,,,,,,,,,,,,, -116400,0.14037418,0.016622635,,,,,,,,,,,,,,,,, -116500,0.14153624,0.019119328,,,,,,,,,,,,,,,,, -116600,0.13563423,0.019029308,,,,,,,,,,,,,,,,, -116700,0.14873007,0.019369874,,,,,,,,,,,,,,,,, -116766,,,0.9954851865768432,0.0140848821029067,0.7850383991852307,0.9870207905769348,0.0504884012043476,0.2931388648347254,43793.0,0.9861586689949036,0.0539308078587055,0.2760248373857741,43793.0,37954.86515974999,57832.78515815735,37954.86515974999,19868.06057667732,6.577084302902222,0.0 -116800,0.13178353,0.018924084,,,,,,,,,,,,,,,,, -116900,0.14936192,0.018455222,,,,,,,,,,,,,,,,, -117000,0.1352391,0.01727542,,,,,,,,,,,,,,,,, -117100,0.13456705,0.01765511,,,,,,,,,,,,,,,,, -117200,0.15178394,0.01835033,,,,,,,,,,,,,,,,, -117300,0.1363353,0.017487524,,,,,,,,,,,,,,,,, -117400,0.16124508,0.019941965,,,,,,,,,,,,,,,,, -117500,0.15485853,0.016777918,,,,,,,,,,,,,,,,, -117507,,,0.9955120086669922,0.0139834303408861,0.7743394761970276,0.9870207905769348,0.0504884012043476,0.2931852186260689,43793.0,0.9861586689949036,0.0539308004081249,0.2759380183987922,43793.0,38195.090396404266,58187.087896347046,38195.090396404266,19982.06836414337,6.626428604125977,0.0 -117600,0.14124255,0.017309224,,,,,,,,,,,,,,,,, -117700,0.14222687,0.017608544,,,,,,,,,,,,,,,,, -117800,0.14296414,0.018580027,,,,,,,,,,,,,,,,, -117900,0.14297763,0.01811329,,,,,,,,,,,,,,,,, -118000,0.12839381,0.015150154,,,,,,,,,,,,,,,,, -118100,0.13981526,0.018378982,,,,,,,,,,,,,,,,, -118200,0.121628985,0.016820783,,,,,,,,,,,,,,,,, -118253,,,0.995533049106598,0.013999680057168,0.7822424577658693,0.9870207905769348,0.0504884012043476,0.2932161305695458,43793.0,0.9861586689949036,0.0539308078587055,0.276064940503458,43793.0,38435.11602497101,58541.69343018532,38435.11602497101,20096.57867860794,6.675387859344482,0.0 -118300,0.13106717,0.017218562,,,,,,,,,,,,,,,,, -118400,0.14323065,0.018147968,,,,,,,,,,,,,,,,, -118500,0.14074981,0.016767984,,,,,,,,,,,,,,,,, -118600,0.12863353,0.018682089,,,,,,,,,,,,,,,,, -118700,0.14147471,0.016772648,,,,,,,,,,,,,,,,, -118800,0.15022905,0.017057197,,,,,,,,,,,,,,,,, -118900,0.12963659,0.017218215,,,,,,,,,,,,,,,,, -118999,,,0.9954657554626464,0.0141053376719355,0.7663495065328356,0.9870207905769348,0.0504884012043476,0.2931702847497467,43793.0,0.9861586689949036,0.0539308078587055,0.2759337036347488,43793.0,38675.15397500992,58891.5354681015,38675.15397500992,20206.304889678955,6.732793807983398,0.0 -119000,0.14727697,0.017784366,,,,,,,,,,,,,,,,, -119100,0.12918246,0.016909976,,,,,,,,,,,,,,,,, -119200,0.14011636,0.019128667,,,,,,,,,,,,,,,,, -119300,0.13067959,0.017281532,,,,,,,,,,,,,,,,, -119400,0.15133074,0.021506805,,,,,,,,,,,,,,,,, -119500,0.13841273,0.01712583,,,,,,,,,,,,,,,,, -119600,0.14921679,0.019760367,,,,,,,,,,,,,,,,, -119700,0.14185067,0.020350864,,,,,,,,,,,,,,,,, -119739,,,0.9954808950424194,0.0141744362190365,0.7775733664311058,0.9870207905769348,0.0504884012043476,0.2932722189513038,43793.0,0.9861586689949036,0.0539308004081249,0.2761250320439972,43793.0,38915.19271636009,59241.49839806557,38915.19271636009,20316.15882253647,6.782490015029907,0.0 -119800,0.1470654,0.017441412,,,,,,,,,,,,,,,,, -119900,0.16500513,0.019190453,,,,,,,,,,,,,,,,, -120000,0.14320277,0.015209719,,,,,,,,,,,,,,,,, -120100,0.13222112,0.017644817,,,,,,,,,,,,,,,,, -120200,0.12362792,0.01706911,,,,,,,,,,,,,,,,, -120300,0.13369866,0.01710038,,,,,,,,,,,,,,,,, -120400,0.15006354,0.018967703,,,,,,,,,,,,,,,,, -120477,,,0.9955210089683532,0.0139794861897826,0.7815879575595298,0.9870207905769348,0.0504884012043476,0.2930735441592322,43793.0,0.9861586689949036,0.0539308078587055,0.2759202887756614,43793.0,39155.25708556175,59588.46991467476,39155.25708556175,20422.99523949623,6.833070516586304,0.0 -120500,0.15692599,0.01988875,,,,,,,,,,,,,,,,, -120600,0.13685672,0.017710023,,,,,,,,,,,,,,,,, -120700,0.12989625,0.015852982,,,,,,,,,,,,,,,,, -120800,0.1349817,0.017258065,,,,,,,,,,,,,,,,, -120900,0.15509573,0.01833648,,,,,,,,,,,,,,,,, -121000,0.15436852,0.018240536,,,,,,,,,,,,,,,,, -121100,0.13109072,0.017298432,,,,,,,,,,,,,,,,, -121200,0.1367637,0.016704686,,,,,,,,,,,,,,,,, -121214,,,0.9954981207847596,0.0139925740659236,0.7821686742341195,0.9870207905769348,0.0504884012043476,0.2931317994411526,43793.0,0.9861586689949036,0.0539308078587055,0.2759633130124874,43793.0,39395.50072574616,59939.0671851635,39395.50072574616,20533.277759552,6.884111404418945,0.0 -121300,0.14352047,0.01699418,,,,,,,,,,,,,,,,, -121400,0.14371929,0.019759756,,,,,,,,,,,,,,,,, -121500,0.14765134,0.020154105,,,,,,,,,,,,,,,,, -121600,0.16076422,0.01903236,,,,,,,,,,,,,,,,, -121700,0.14414375,0.017688608,,,,,,,,,,,,,,,,, -121800,0.14134963,0.016863283,,,,,,,,,,,,,,,,, -121900,0.15097271,0.018966,,,,,,,,,,,,,,,,, -121949,,,0.9954849481582642,0.0141089959070086,0.7756348087463598,0.9870207905769348,0.0504883974790573,0.2932923447189138,43793.0,0.9861586689949036,0.0539308004081249,0.2758850451009829,43793.0,39635.43482732773,60288.67901420593,39635.43482732773,20642.88063430786,6.936408996582031,0.0 -122000,0.15510844,0.020138046,,,,,,,,,,,,,,,,, -122100,0.15037891,0.018359972,,,,,,,,,,,,,,,,, -122200,0.12805346,0.018329563,,,,,,,,,,,,,,,,, -122300,0.13065109,0.017481837,,,,,,,,,,,,,,,,, -122400,0.1417705,0.019940337,,,,,,,,,,,,,,,,, -122500,0.15178452,0.018615833,,,,,,,,,,,,,,,,, -122600,0.1413611,0.017689997,,,,,,,,,,,,,,,,, -122693,,,0.99551123380661,0.0140492506325244,0.7678146983914897,0.9870207905769348,0.0504884012043476,0.2932222266958525,43793.0,0.9861586689949036,0.0539308078587055,0.2759276980217821,43793.0,39875.50879430771,60639.76265883446,39875.50879430771,20753.818819522858,6.987528085708618,0.0 -122700,0.13559029,0.016598582,,,,,,,,,,,,,,,,, -122800,0.16300695,0.018841425,,,,,,,,,,,,,,,,, -122900,0.14703017,0.018466232,,,,,,,,,,,,,,,,, -123000,0.13613361,0.015453363,,,,,,,,,,,,,,,,, -123100,0.14378369,0.015952678,,,,,,,,,,,,,,,,, -123200,0.13488585,0.018844845,,,,,,,,,,,,,,,,, -123300,0.16010413,0.019122528,,,,,,,,,,,,,,,,, -123400,0.12087178,0.01719541,,,,,,,,,,,,,,,,, -123435,,,0.9954609870910645,0.0141681898385286,0.7764094208096587,0.9870207905769348,0.0504884012043476,0.2931615148093934,43793.0,0.9861586689949036,0.0539308078587055,0.2759675452993614,43793.0,40115.57969236374,60990.48681926727,40115.57969236374,20864.3991625309,7.039609432220459,0.0 -123500,0.14954752,0.02087518,,,,,,,,,,,,,,,,, -123600,0.14780708,0.020427711,,,,,,,,,,,,,,,,, -123700,0.14330891,0.018985014,,,,,,,,,,,,,,,,, -123800,0.14717393,0.019068964,,,,,,,,,,,,,,,,, -123900,0.14556487,0.01852057,,,,,,,,,,,,,,,,, -124000,0.16497588,0.019981433,,,,,,,,,,,,,,,,, -124100,0.12636028,0.016221039,,,,,,,,,,,,,,,,, -124180,,,0.9955407977104188,0.0139289442449808,0.7768192297472776,0.9870207905769348,0.0504884012043476,0.2932437338587668,43793.0,0.9861586689949036,0.0539308078587055,0.2760686342789627,43793.0,40355.65077519417,61336.46803617477,40355.65077519417,20970.238475561146,7.089834451675415,0.0 -124200,0.14182317,0.017528815,,,,,,,,,,,,,,,,, -124300,0.13500962,0.017515453,,,,,,,,,,,,,,,,, -124400,0.14187405,0.018700251,,,,,,,,,,,,,,,,, -124500,0.15322413,0.02024026,,,,,,,,,,,,,,,,, -124600,0.12717943,0.017834935,,,,,,,,,,,,,,,,, -124700,0.13355432,0.016845275,,,,,,,,,,,,,,,,, -124800,0.14123999,0.016074471,,,,,,,,,,,,,,,,, -124900,0.13433324,0.017558536,,,,,,,,,,,,,,,,, -124928,,,0.9954872131347656,0.0140464529395103,0.7832406491909317,0.9870207905769348,0.0504884012043476,0.2931946410851083,43793.0,0.9861586689949036,0.0539308078587055,0.2760136439954811,43793.0,40595.69643044472,61688.01654314995,40595.69643044472,21081.669250011444,7.141475677490234,0.0 -125000,0.14556979,0.01800148,,,,,,,,,,,,,,,,, -125100,0.1467915,0.019559586,,,,,,,,,,,,,,,,, -125200,0.11623795,0.017180886,,,,,,,,,,,,,,,,, -125300,0.14082715,0.016351564,,,,,,,,,,,,,,,,, -125400,0.14661255,0.019028908,,,,,,,,,,,,,,,,, -125500,0.14643236,0.019506546,,,,,,,,,,,,,,,,, -125600,0.13172737,0.017297782,,,,,,,,,,,,,,,,, -125672,,,0.9954885840415956,0.0140621056780219,0.774091442743637,0.9870207905769348,0.0504884012043476,0.2933094306738013,43793.0,0.9861586689949036,0.0539308004081249,0.2760047208967673,43793.0,40835.86017179489,62043.13323545456,40835.86017179489,21196.552217245106,7.191240072250366,0.0 -125700,0.14521252,0.01651182,,,,,,,,,,,,,,,,, -125800,0.15157355,0.017748136,,,,,,,,,,,,,,,,, -125900,0.16065595,0.016387148,,,,,,,,,,,,,,,,, -126000,0.14094028,0.019335598,,,,,,,,,,,,,,,,, -126100,0.13652183,0.018943993,,,,,,,,,,,,,,,,, -126200,0.11876491,0.017780228,,,,,,,,,,,,,,,,, -126300,0.14990617,0.01955064,,,,,,,,,,,,,,,,, -126400,0.14965281,0.018397667,,,,,,,,,,,,,,,,, -126416,,,0.995519518852234,0.0140607506036758,0.7807933473108024,0.9870207905769348,0.0504884012043476,0.2932310735617742,43793.0,0.9861586689949036,0.0539308004081249,0.2760524931298125,43793.0,41075.91949534416,62393.43706226349,41075.91949534416,21306.72452545166,7.242794752120972,0.0 -126500,0.15518327,0.019274242,,,,,,,,,,,,,,,,, -126600,0.13881367,0.016448578,,,,,,,,,,,,,,,,, -126700,0.15145394,0.01871141,,,,,,,,,,,,,,,,, -126800,0.1354863,0.017397566,,,,,,,,,,,,,,,,, -126900,0.14117183,0.0186324,,,,,,,,,,,,,,,,, -127000,0.14185756,0.01927039,,,,,,,,,,,,,,,,, -127100,0.15107436,0.018612426,,,,,,,,,,,,,,,,, -127160,,,0.9954697489738464,0.0141372112557291,0.7587198449639805,0.9870207905769348,0.0504884012043476,0.2932764543116762,43793.0,0.9861586689949036,0.0539308078587055,0.2760117545460758,43793.0,41316.159787893295,62742.088060855865,41316.159787893295,21415.063493013386,7.293832540512085,0.0 -127200,0.14702775,0.019873353,,,,,,,,,,,,,,,,, -127300,0.1301065,0.017722229,,,,,,,,,,,,,,,,, -127400,0.13525875,0.016105276,,,,,,,,,,,,,,,,, -127500,0.16055718,0.019481042,,,,,,,,,,,,,,,,, -127600,0.15566635,0.01890073,,,,,,,,,,,,,,,,, -127700,0.123981796,0.016212013,,,,,,,,,,,,,,,,, -127800,0.13098823,0.018456122,,,,,,,,,,,,,,,,, -127900,0.14554094,0.01819167,,,,,,,,,,,,,,,,, -127901,,,0.995512843132019,0.0140344286337494,0.7794579884619639,0.9870207905769348,0.0504884012043476,0.2931954360456731,43793.0,0.9861586689949036,0.0539308078587055,0.275960832651128,43793.0,41556.275861501694,63087.90578913689,41556.275861501694,21520.694525957108,7.344229459762573,0.0 -128000,0.15203664,0.019539427,,,,,,,,,,,,,,,,, -128100,0.13766523,0.01781137,,,,,,,,,,,,,,,,, -128200,0.14154257,0.017937431,,,,,,,,,,,,,,,,, -128300,0.14723966,0.018196369,,,,,,,,,,,,,,,,, -128400,0.12983789,0.014713643,,,,,,,,,,,,,,,,, -128500,0.13704343,0.018404322,,,,,,,,,,,,,,,,, -128600,0.1532645,0.01785505,,,,,,,,,,,,,,,,, -128646,,,0.99550598859787,0.0140278497710824,0.7830336621813878,0.9870207905769348,0.0504884012043476,0.2931289849079299,43793.0,0.9861586689949036,0.0539308004081249,0.2760132085099576,43793.0,41796.47999668121,63435.39198184013,41796.47999668121,21627.90496778488,7.395256757736206,0.0 -128700,0.14875722,0.01808667,,,,,,,,,,,,,,,,, -128800,0.13962947,0.017472688,,,,,,,,,,,,,,,,, -128900,0.16874771,0.020608973,,,,,,,,,,,,,,,,, -129000,0.16537936,0.020927059,,,,,,,,,,,,,,,,, -129100,0.13900161,0.017621163,,,,,,,,,,,,,,,,, -129200,0.16874751,0.018790014,,,,,,,,,,,,,,,,, -129300,0.1382345,0.017667875,,,,,,,,,,,,,,,,, -129368,,,0.9955154657363892,0.0139666367322206,0.7766582892696063,0.9870207905769348,0.0504884012043476,0.293199850367342,43793.0,0.9861586689949036,0.0539308078587055,0.276029149677099,43793.0,42036.49334049225,63785.981810092926,42036.49334049225,21738.40746998787,7.44689416885376,0.0 -129400,0.13872987,0.016658947,,,,,,,,,,,,,,,,, -129500,0.14481203,0.018947564,,,,,,,,,,,,,,,,, -129600,0.13342674,0.017758667,,,,,,,,,,,,,,,,, -129700,0.14261189,0.017434387,,,,,,,,,,,,,,,,, -129800,0.15358041,0.017601676,,,,,,,,,,,,,,,,, -129900,0.13955152,0.016653594,,,,,,,,,,,,,,,,, -130000,0.13863242,0.018269269,,,,,,,,,,,,,,,,, -130100,0.13712609,0.017565511,,,,,,,,,,,,,,,,, -130105,,,0.9955015778541564,0.0140527635812759,0.7812440709210555,0.9870207905769348,0.0504884012043476,0.2932153395867975,43793.0,0.9861586689949036,0.0539308004081249,0.2759899810924676,43793.0,42276.55060958862,64135.95481061936,42276.55060958862,21848.249663591385,7.499475479125977,0.0 -130200,0.14325355,0.01846693,,,,,,,,,,,,,,,,, -130300,0.14886545,0.020057647,,,,,,,,,,,,,,,,, -130400,0.1434342,0.018058037,,,,,,,,,,,,,,,,, -130500,0.13815737,0.0187675,,,,,,,,,,,,,,,,, -130600,0.1480629,0.017140776,,,,,,,,,,,,,,,,, -130700,0.14288943,0.018148603,,,,,,,,,,,,,,,,, -130800,0.15100034,0.017228143,,,,,,,,,,,,,,,,, -130840,,,0.9954874515533448,0.0140948053449392,0.7717035971463917,0.9870207905769348,0.0504884012043476,0.2930906746882962,43793.0,0.9861586689949036,0.0539308078587055,0.2759373339764978,43793.0,42516.48381781578,64487.17621850968,42516.48381781578,21959.45810699463,7.557837963104248,0.0 -130900,0.14082271,0.019587874,,,,,,,,,,,,,,,,, -131000,0.14261748,0.018662687,,,,,,,,,,,,,,,,, -131100,0.12572971,0.016617209,,,,,,,,,,,,,,,,, -131200,0.13427459,0.018483093,,,,,,,,,,,,,,,,, -131300,0.1215433,0.014188449,,,,,,,,,,,,,,,,, -131400,0.14233528,0.018680338,,,,,,,,,,,,,,,,, -131500,0.15489475,0.017798338,,,,,,,,,,,,,,,,, -131586,,,0.995464563369751,0.0142739154398441,0.7646260937715028,0.9870207905769348,0.0504884012043476,0.2932155802724631,43793.0,0.9861586689949036,0.0539308004081249,0.2758972165051127,43793.0,42756.610813617706,64835.18199014664,42756.610813617706,22067.263786792755,7.610391139984131,0.0 -131600,0.13785358,0.017689249,,,,,,,,,,,,,,,,, -131700,0.1143866,0.013368597,,,,,,,,,,,,,,,,, -131800,0.16005197,0.019298142,,,,,,,,,,,,,,,,, -131900,0.13614,0.01841079,,,,,,,,,,,,,,,,, -132000,0.15479824,0.019544521,,,,,,,,,,,,,,,,, -132100,0.1505036,0.018074408,,,,,,,,,,,,,,,,, -132200,0.14637592,0.018956007,,,,,,,,,,,,,,,,, -132300,0.13431127,0.019133093,,,,,,,,,,,,,,,,, -132333,,,0.9955654144287108,0.0138419400900602,0.7805771446483449,0.9870207905769348,0.0504884012043476,0.2932781435099585,43793.0,0.9861586689949036,0.0539308078587055,0.2759797114188226,43793.0,42996.56942439079,65185.630719423294,42996.56942439079,22177.67981219292,7.664034605026245,0.0 -132400,0.13027805,0.018270416,,,,,,,,,,,,,,,,, -132500,0.14147007,0.01850687,,,,,,,,,,,,,,,,, -132600,0.14929004,0.01996543,,,,,,,,,,,,,,,,, -132700,0.1399272,0.018358195,,,,,,,,,,,,,,,,, -132800,0.13806506,0.01692947,,,,,,,,,,,,,,,,, -132900,0.1354001,0.017039942,,,,,,,,,,,,,,,,, -133000,0.13477884,0.015490217,,,,,,,,,,,,,,,,, -133080,,,0.9954608678817748,0.0140660284087061,0.7814404601298929,0.9870207905769348,0.0504884012043476,0.2933030047203722,43793.0,0.9861586689949036,0.0539308078587055,0.2759513026070447,43793.0,43236.740758657455,65534.44535279274,43236.740758657455,22286.250038146973,7.716859579086304,0.0 -133100,0.13556086,0.017141705,,,,,,,,,,,,,,,,, -133200,0.14680424,0.02016594,,,,,,,,,,,,,,,,, -133300,0.12505685,0.017407961,,,,,,,,,,,,,,,,, -133400,0.14397874,0.01765976,,,,,,,,,,,,,,,,, -133500,0.13729441,0.018621987,,,,,,,,,,,,,,,,, -133600,0.14502397,0.018052718,,,,,,,,,,,,,,,,, -133700,0.1328649,0.016420195,,,,,,,,,,,,,,,,, -133800,0.14553419,0.01799012,,,,,,,,,,,,,,,,, -133820,,,0.9954729676246644,0.0141444317996501,0.7757278511595392,0.9870207905769348,0.0504884012043476,0.293336474690113,43793.0,0.9861586689949036,0.0539308078587055,0.2759277858320216,43793.0,43476.80657362938,65887.98328089714,43476.80657362938,22399.64813780785,7.770319223403931,0.0 -133900,0.1404312,0.020146608,,,,,,,,,,,,,,,,, -134000,0.15173537,0.019531058,,,,,,,,,,,,,,,,, -134100,0.12801182,0.015670093,,,,,,,,,,,,,,,,, -134200,0.13784988,0.015784858,,,,,,,,,,,,,,,,, -134300,0.122164495,0.017652214,,,,,,,,,,,,,,,,, -134400,0.14662288,0.018441863,,,,,,,,,,,,,,,,, -134500,0.13337483,0.015858002,,,,,,,,,,,,,,,,, -134551,,,0.9955356121063232,0.0139828957617282,0.7739661532021433,0.9870207905769348,0.0504884012043476,0.2931877864285181,43793.0,0.9861586689949036,0.0539308004081249,0.2759842286678688,43793.0,43716.8826019764,66236.73903632164,43716.8826019764,22508.24461197853,7.830138921737671,0.0 -134600,0.15121426,0.017521897,,,,,,,,,,,,,,,,, -134700,0.13942845,0.01775659,,,,,,,,,,,,,,,,, -134800,0.14558455,0.017966006,,,,,,,,,,,,,,,,, -134900,0.14196548,0.016653389,,,,,,,,,,,,,,,,, -135000,0.14797981,0.018703017,,,,,,,,,,,,,,,,, -135100,0.15489627,0.019161884,,,,,,,,,,,,,,,,, -135200,0.13653585,0.015645036,,,,,,,,,,,,,,,,, -135297,,,0.995487630367279,0.0141261909157037,0.7703792187677057,0.9870207905769348,0.0504884012043476,0.2932659339598245,43793.0,0.9861586689949036,0.0539308078587055,0.2759395228017564,43793.0,43956.86270856857,66590.24928569794,43956.86270856857,22621.700261354446,7.88405442237854,0.0 -135300,0.1386519,0.017072367,,,,,,,,,,,,,,,,, -135400,0.15661122,0.020154243,,,,,,,,,,,,,,,,, -135500,0.14091669,0.017475478,,,,,,,,,,,,,,,,, -135600,0.15410912,0.018538412,,,,,,,,,,,,,,,,, -135700,0.13522157,0.017758917,,,,,,,,,,,,,,,,, -135800,0.13158906,0.016238397,,,,,,,,,,,,,,,,, -135900,0.14724535,0.018964626,,,,,,,,,,,,,,,,, -136000,0.14222857,0.01856769,,,,,,,,,,,,,,,,, -136036,,,0.995522141456604,0.014012542553246,0.7761438917983973,0.9870207905769348,0.0504884012043476,0.2932118139807845,43793.0,0.9861586689949036,0.0539308078587055,0.2759236846054778,43793.0,44196.990394830704,66942.81797623634,44196.990394830704,22734.069207906723,7.935742139816284,0.0 -136100,0.14238557,0.017149944,,,,,,,,,,,,,,,,, -136200,0.14768922,0.018968565,,,,,,,,,,,,,,,,, -136300,0.15628259,0.01907755,,,,,,,,,,,,,,,,, -136400,0.1358539,0.01726276,,,,,,,,,,,,,,,,, -136500,0.12921913,0.016706863,,,,,,,,,,,,,,,,, -136600,0.1572128,0.021320695,,,,,,,,,,,,,,,,, -136700,0.15109977,0.019908428,,,,,,,,,,,,,,,,, -136772,,,0.995475709438324,0.0140751581639051,0.7810047617352318,0.9870207905769348,0.0504884012043476,0.2931009444456341,43793.0,0.9861586689949036,0.0539308004081249,0.2759907501293877,43793.0,44437.059807538986,67295.4547200203,44437.059807538986,22846.563081502914,7.987916707992554,0.0 -136800,0.14815739,0.019118374,,,,,,,,,,,,,,,,, -136900,0.1373408,0.018987661,,,,,,,,,,,,,,,,, -137000,0.14572382,0.019171644,,,,,,,,,,,,,,,,, -137100,0.1553245,0.01823893,,,,,,,,,,,,,,,,, -137200,0.13281883,0.016216738,,,,,,,,,,,,,,,,, -137300,0.17120117,0.01911153,,,,,,,,,,,,,,,,, -137400,0.12970173,0.017613918,,,,,,,,,,,,,,,,, -137500,0.14474532,0.018106868,,,,,,,,,,,,,,,,, -137508,,,0.9954957365989684,0.0140201319009065,0.7739173571564724,0.9870207905769348,0.0504884012043476,0.2932000694716932,43793.0,0.9861586689949036,0.0539308078587055,0.27596270562728,43793.0,44677.20724821091,67646.28160524368,44677.20724821091,22957.16247224808,8.046531915664673,0.0 -137600,0.13959248,0.017741704,,,,,,,,,,,,,,,,, -137700,0.14274736,0.020204507,,,,,,,,,,,,,,,,, -137800,0.13057917,0.016121916,,,,,,,,,,,,,,,,, -137900,0.14375521,0.019235132,,,,,,,,,,,,,,,,, -138000,0.158411,0.019135881,,,,,,,,,,,,,,,,, -138100,0.13847254,0.02014574,,,,,,,,,,,,,,,,, -138200,0.1314495,0.016817141,,,,,,,,,,,,,,,,, -138247,,,0.9955067038536072,0.0140622872859239,0.779246398721551,0.9870207905769348,0.0504884012043476,0.2931393059919693,43793.0,0.9861586689949036,0.0539308078587055,0.2759087396709697,43793.0,44917.2401702404,67994.80980920792,44917.2401702404,23065.58174324036,8.102015018463135,0.0 -138300,0.16557261,0.019363169,,,,,,,,,,,,,,,,, -138400,0.15014508,0.017595315,,,,,,,,,,,,,,,,, -138500,0.14307676,0.018380862,,,,,,,,,,,,,,,,, -138600,0.14576955,0.017617613,,,,,,,,,,,,,,,,, -138700,0.16157228,0.018837964,,,,,,,,,,,,,,,,, -138800,0.16209428,0.018020308,,,,,,,,,,,,,,,,, -138900,0.1509743,0.01890624,,,,,,,,,,,,,,,,, -138991,,,0.9954759478569032,0.0141209280118346,0.7633788395828764,0.9870207905769348,0.0504884012043476,0.2931961719757504,43793.0,0.9861586689949036,0.0539308078587055,0.2759749694299141,43793.0,45157.17242622376,68341.19829106331,45157.17242622376,23171.964718818665,8.154671430587769,0.0 -139000,0.15691514,0.017590335,,,,,,,,,,,,,,,,, -139100,0.14025277,0.017281765,,,,,,,,,,,,,,,,, -139200,0.15003808,0.018341335,,,,,,,,,,,,,,,,, -139300,0.123441786,0.015634084,,,,,,,,,,,,,,,,, -139400,0.14533189,0.019911282,,,,,,,,,,,,,,,,, -139500,0.14403185,0.01831918,,,,,,,,,,,,,,,,, -139600,0.14725253,0.01695343,,,,,,,,,,,,,,,,, -139700,0.13332328,0.015899274,,,,,,,,,,,,,,,,, -139731,,,0.9955406785011292,0.0140155563130974,0.7742222547996289,0.9870207905769348,0.0504884012043476,0.293142011524399,43793.0,0.9861586689949036,0.0539308078587055,0.2759331733707005,43793.0,45397.25456047058,68687.44159507751,45397.25456047058,23278.051063776016,8.208355903625488,0.0 -139800,0.14433156,0.01668976,,,,,,,,,,,,,,,,, -139900,0.123519056,0.016856879,,,,,,,,,,,,,,,,, -140000,0.15505856,0.020038165,,,,,,,,,,,,,,,,, -140100,0.13967279,0.017983299,,,,,,,,,,,,,,,,, -140200,0.1527394,0.019625774,,,,,,,,,,,,,,,,, -140300,0.12366192,0.016735315,,,,,,,,,,,,,,,,, -140400,0.14532624,0.018441409,,,,,,,,,,,,,,,,, -140480,,,0.9954973459243774,0.0140028344467282,0.786076248561501,0.9870207905769348,0.0504884012043476,0.2932708711021463,43793.0,0.9861586689949036,0.0539308004081249,0.2759290721408751,43793.0,45637.525134563446,69037.21703863144,45637.525134563446,23387.48162317276,8.26203179359436,0.0 -140500,0.1400356,0.018118022,,,,,,,,,,,,,,,,, -140600,0.13671297,0.019146606,,,,,,,,,,,,,,,,, -140700,0.12906882,0.016652979,,,,,,,,,,,,,,,,, -140800,0.12813595,0.016547691,,,,,,,,,,,,,,,,, -140900,0.15216085,0.017475694,,,,,,,,,,,,,,,,, -141000,0.12524174,0.01732303,,,,,,,,,,,,,,,,, -141100,0.14566426,0.01759732,,,,,,,,,,,,,,,,, -141200,0.14047238,0.01654521,,,,,,,,,,,,,,,,, -141228,,,0.995505392551422,0.0140246078372001,0.7788733946088782,0.9870207905769348,0.0504884012043476,0.2932886916868568,43793.0,0.9861586689949036,0.0539308004081249,0.2760197268325939,43793.0,45877.65914797783,69388.22018957138,45877.65914797783,23498.272793293,8.319572448730469,0.0 -141300,0.124132715,0.016177356,,,,,,,,,,,,,,,,, -141400,0.14982398,0.018735778,,,,,,,,,,,,,,,,, -141500,0.15985204,0.019034356,,,,,,,,,,,,,,,,, -141600,0.14596944,0.0191561,,,,,,,,,,,,,,,,, -141700,0.13506426,0.017303549,,,,,,,,,,,,,,,,, -141800,0.15290931,0.019570667,,,,,,,,,,,,,,,,, -141900,0.14738671,0.017923998,,,,,,,,,,,,,,,,, -141970,,,0.9954705834388732,0.0141415288671851,0.7747681263770184,0.9870207905769348,0.0504884012043476,0.293163724473554,43793.0,0.9861586689949036,0.0539308078587055,0.2760298456241528,43793.0,46117.658863544464,69731.63406729698,46117.658863544464,23601.612596273422,8.372910976409912,0.0 -142000,0.13596655,0.018517405,,,,,,,,,,,,,,,,, -142100,0.1414316,0.01911954,,,,,,,,,,,,,,,,, -142200,0.13233319,0.019008292,,,,,,,,,,,,,,,,, -142300,0.13300733,0.01647316,,,,,,,,,,,,,,,,, -142400,0.14357205,0.018474488,,,,,,,,,,,,,,,,, -142500,0.15526858,0.017398825,,,,,,,,,,,,,,,,, -142600,0.14171757,0.020320825,,,,,,,,,,,,,,,,, -142700,0.14580707,0.018566592,,,,,,,,,,,,,,,,, -142710,,,0.9955161213874816,0.0139932902529835,0.76743690980945,0.9870207905769348,0.0504884012043476,0.2932152944727341,43793.0,0.9861586689949036,0.0539308078587055,0.2759479146602972,43793.0,46357.69627165794,70081.35587143898,46357.69627165794,23711.222553491592,8.42588186264038,0.0 -142800,0.16218917,0.019074516,,,,,,,,,,,,,,,,, -142900,0.15121247,0.016480932,,,,,,,,,,,,,,,,, -143000,0.13870937,0.016886465,,,,,,,,,,,,,,,,, -143100,0.13644071,0.017447965,,,,,,,,,,,,,,,,, -143200,0.15247822,0.017990636,,,,,,,,,,,,,,,,, -143300,0.12445054,0.015843512,,,,,,,,,,,,,,,,, -143400,0.14179988,0.01833162,,,,,,,,,,,,,,,,, -143450,,,0.9954485893249512,0.0142313251271843,0.7793896921346442,0.9870207905769348,0.0504884012043476,0.2931962193358611,43793.0,0.9861586689949036,0.0539308078587055,0.2759225501940873,43793.0,46597.8717443943,70427.41950631142,46597.8717443943,23817.034957647324,8.481031894683838,0.0 -143500,0.14366873,0.01889533,,,,,,,,,,,,,,,,, -143600,0.14191578,0.016889492,,,,,,,,,,,,,,,,, -143700,0.15143116,0.01854647,,,,,,,,,,,,,,,,, -143800,0.13796791,0.01878285,,,,,,,,,,,,,,,,, -143900,0.14353445,0.018049078,,,,,,,,,,,,,,,,, -144000,0.13680139,0.017412618,,,,,,,,,,,,,,,,, -144100,0.17171273,0.01948654,,,,,,,,,,,,,,,,, -144195,,,0.9955180287361144,0.0140170911327004,0.7712021251836819,0.9870207905769348,0.0504884012043476,0.2932476380325232,43793.0,0.9861586689949036,0.0539308078587055,0.2759723667264904,43793.0,46837.9917037487,70780.69690322876,46837.9917037487,23930.11753678322,8.535125494003296,0.0 -144200,0.13043182,0.017490873,,,,,,,,,,,,,,,,, -144300,0.14189655,0.017077917,,,,,,,,,,,,,,,,, -144400,0.1347645,0.01709087,,,,,,,,,,,,,,,,, -144500,0.14874093,0.018014861,,,,,,,,,,,,,,,,, -144600,0.13969679,0.015787695,,,,,,,,,,,,,,,,, -144700,0.13652173,0.016034896,,,,,,,,,,,,,,,,, -144800,0.12399685,0.015831346,,,,,,,,,,,,,,,,, -144900,0.1497396,0.017637858,,,,,,,,,,,,,,,,, -144936,,,0.9955515265464784,0.0138782355934381,0.7903820845194295,0.9870207905769348,0.0504884012043476,0.2932307445075147,43793.0,0.9861586689949036,0.0539308078587055,0.275894328449335,43793.0,47078.04253005981,71132.5337510109,47078.04253005981,24041.820734739304,8.595582008361816,0.0 -145000,0.1589006,0.01694246,,,,,,,,,,,,,,,,, -145100,0.14280763,0.017272457,,,,,,,,,,,,,,,,, -145200,0.13232283,0.01672337,,,,,,,,,,,,,,,,, -145300,0.1351104,0.017226093,,,,,,,,,,,,,,,,, -145400,0.1299724,0.016425114,,,,,,,,,,,,,,,,, -145500,0.15383726,0.019257834,,,,,,,,,,,,,,,,, -145600,0.13849293,0.017575487,,,,,,,,,,,,,,,,, -145685,,,0.9954928755760192,0.0140475835651159,0.7702413222315503,0.9870207905769348,0.0504884012043476,0.2932309214595385,43793.0,0.9861586689949036,0.0539308078587055,0.2759183809349992,43793.0,47318.212882995605,71483.07665705681,47318.212882995605,24152.11885714531,8.649686574935913,0.0 -145700,0.14300837,0.019147296,,,,,,,,,,,,,,,,, -145800,0.132142,0.016236337,,,,,,,,,,,,,,,,, -145900,0.14574485,0.014728941,,,,,,,,,,,,,,,,, -146000,0.15374023,0.019240923,,,,,,,,,,,,,,,,, -146100,0.15114835,0.018685559,,,,,,,,,,,,,,,,, -146200,0.13961159,0.017539933,,,,,,,,,,,,,,,,, -146300,0.1157718,0.015687816,,,,,,,,,,,,,,,,, -146400,0.13601905,0.01715009,,,,,,,,,,,,,,,,, -146433,,,0.9955329895019532,0.0140459137037396,0.7693459707086195,0.9870207905769348,0.0504884012043476,0.2932822507404936,43793.0,0.9861586689949036,0.0539308004081249,0.276092969195866,43793.0,47558.26383471489,71827.52420902252,47558.26383471489,24256.440824985504,8.703256368637085,0.0 -146500,0.12933347,0.015544668,,,,,,,,,,,,,,,,, -146600,0.15133089,0.01792419,,,,,,,,,,,,,,,,, -146700,0.14704254,0.018266253,,,,,,,,,,,,,,,,, -146800,0.15745448,0.018963704,,,,,,,,,,,,,,,,, -146900,0.13833378,0.0176094,,,,,,,,,,,,,,,,, -147000,0.13231517,0.018359497,,,,,,,,,,,,,,,,, -147100,0.13817506,0.01781809,,,,,,,,,,,,,,,,, -147184,,,0.9954391717910768,0.0142329670488834,0.7748171236902126,0.9870207905769348,0.0504884012043476,0.2933502504790597,43793.0,0.9861586689949036,0.0539308004081249,0.2759415898627759,43793.0,47798.32083177567,72170.59030127525,47798.32083177567,24359.37511229515,8.758103132247925,0.0 -147200,0.14631447,0.017711684,,,,,,,,,,,,,,,,, -147300,0.1326763,0.016946202,,,,,,,,,,,,,,,,, -147400,0.13278827,0.016390156,,,,,,,,,,,,,,,,, -147500,0.13822414,0.018486694,,,,,,,,,,,,,,,,, -147600,0.15601304,0.019680189,,,,,,,,,,,,,,,,, -147700,0.13778456,0.017351322,,,,,,,,,,,,,,,,, -147800,0.122428164,0.0149293635,,,,,,,,,,,,,,,,, -147900,0.12468984,0.017252248,,,,,,,,,,,,,,,,, -147932,,,0.9954963326454164,0.0140677941963076,0.7755266511394573,0.9870207905769348,0.0504884012043476,0.2931353810233713,43793.0,0.9861586689949036,0.0539308004081249,0.2759515052454083,43793.0,48038.32773327828,72517.02304172516,48038.32773327828,24465.72505545616,8.8131422996521,0.0 -148000,0.15210806,0.01802582,,,,,,,,,,,,,,,,, -148100,0.15620716,0.019647926,,,,,,,,,,,,,,,,, -148200,0.1470912,0.018073615,,,,,,,,,,,,,,,,, -148300,0.13814193,0.0189097,,,,,,,,,,,,,,,,, -148400,0.13357694,0.016789097,,,,,,,,,,,,,,,,, -148500,0.14260633,0.018513942,,,,,,,,,,,,,,,,, -148600,0.14680558,0.017551001,,,,,,,,,,,,,,,,, -148677,,,0.995520830154419,0.01393973082304,0.7862674566821604,0.9870207905769348,0.0504884012043476,0.2931562245813233,43793.0,0.9861586689949036,0.0539308078587055,0.2759377518485115,43793.0,48278.375985860825,72863.22947764397,48278.375985860825,24571.807502031326,8.868491649627686,0.0 -148700,0.14935215,0.017526051,,,,,,,,,,,,,,,,, -148800,0.16311538,0.019318882,,,,,,,,,,,,,,,,, -148900,0.13136171,0.020010302,,,,,,,,,,,,,,,,, -149000,0.1611483,0.019749591,,,,,,,,,,,,,,,,, -149100,0.15159558,0.021110937,,,,,,,,,,,,,,,,, -149200,0.150154,0.018451601,,,,,,,,,,,,,,,,, -149300,0.14045382,0.0154882325,,,,,,,,,,,,,,,,, -149400,0.12700708,0.01788406,,,,,,,,,,,,,,,,, -149418,,,0.9955045580863952,0.0140247605741024,0.7792347084590541,0.9870207905769348,0.0504884012043476,0.2932035091562113,43793.0,0.9861586689949036,0.0539308004081249,0.2758873321968428,43793.0,48518.42352437973,73214.89954638481,48518.42352437973,24683.352266073227,8.925062656402588,0.0 -149500,0.1454683,0.017421933,,,,,,,,,,,,,,,,, -149600,0.13904855,0.017067144,,,,,,,,,,,,,,,,, -149700,0.124904975,0.014938438,,,,,,,,,,,,,,,,, -149800,0.14348732,0.017687086,,,,,,,,,,,,,,,,, -149900,0.15379319,0.018113095,,,,,,,,,,,,,,,,, -150000,0.13971125,0.016386712,,,,,,,,,,,,,,,,, -150100,0.1514379,0.019490458,,,,,,,,,,,,,,,,, -150160,,,0.9955043196678162,0.0140610709786415,0.7787874434069905,0.9870207905769348,0.0504884012043476,0.2932349068101686,43793.0,0.9861586689949036,0.0539308078587055,0.2759207687265104,43793.0,48758.59159564972,73566.14442420006,48758.59159564972,24794.351008176804,8.982405424118042,0.0 -150200,0.14416969,0.017466955,,,,,,,,,,,,,,,,, -150300,0.14333566,0.017479222,,,,,,,,,,,,,,,,, -150400,0.14522818,0.019042417,,,,,,,,,,,,,,,,, -150500,0.13499908,0.018330565,,,,,,,,,,,,,,,,, -150600,0.16073081,0.019422254,,,,,,,,,,,,,,,,, -150700,0.14114738,0.016286382,,,,,,,,,,,,,,,,, -150800,0.13111071,0.01461841,,,,,,,,,,,,,,,,, -150900,0.14977397,0.017972345,,,,,,,,,,,,,,,,, -150901,,,0.995482325553894,0.0140950242057442,0.7615456821250393,0.9870207905769348,0.0504884012043476,0.2930889760262496,43793.0,0.9861586689949036,0.0539308078587055,0.2759093645543544,43793.0,48998.60104203224,73911.11200237274,48998.60104203224,24899.233260393143,9.03764295578003,0.0 -151000,0.14220703,0.017859263,,,,,,,,,,,,,,,,, -151100,0.1339721,0.016370557,,,,,,,,,,,,,,,,, -151200,0.14686202,0.0209937,,,,,,,,,,,,,,,,, -151300,0.13542627,0.016025612,,,,,,,,,,,,,,,,, -151400,0.14295484,0.019102842,,,,,,,,,,,,,,,,, -151500,0.147509,0.017898591,,,,,,,,,,,,,,,,, -151600,0.161534,0.02072909,,,,,,,,,,,,,,,,, -151624,,,0.9954485893249512,0.0142013663426041,0.7779760903966846,0.9870207905769348,0.0504884012043476,0.2931897409197678,43793.0,0.9861586689949036,0.0539308078587055,0.2759801610305899,43793.0,49238.62616467476,74261.79939770699,49238.62616467476,25009.81607890129,9.094947576522827,0.0 -151700,0.14932358,0.015315761,,,,,,,,,,,,,,,,, -151800,0.14454772,0.019551996,,,,,,,,,,,,,,,,, -151900,0.1313218,0.016651418,,,,,,,,,,,,,,,,, -152000,0.14899464,0.018308429,,,,,,,,,,,,,,,,, -152100,0.14588015,0.018107388,,,,,,,,,,,,,,,,, -152200,0.12771621,0.016777812,,,,,,,,,,,,,,,,, -152300,0.13854185,0.016307043,,,,,,,,,,,,,,,,, -152363,,,0.9955525994300842,0.0139115117490291,0.7837436456993829,0.9870207905769348,0.0504884012043476,0.2934063355225442,43793.0,0.9861586689949036,0.0539308078587055,0.2759371192052875,43793.0,49478.61374115944,74607.41888213158,49478.61374115944,25115.36889028549,9.151697158813477,0.0 -152400,0.13084598,0.017286764,,,,,,,,,,,,,,,,, -152500,0.15879129,0.02008204,,,,,,,,,,,,,,,,, -152600,0.1589219,0.017803846,,,,,,,,,,,,,,,,, -152700,0.14208162,0.018855512,,,,,,,,,,,,,,,,, -152800,0.12608537,0.016999356,,,,,,,,,,,,,,,,, -152900,0.13577186,0.018155087,,,,,,,,,,,,,,,,, -153000,0.13892217,0.01749522,,,,,,,,,,,,,,,,, -153100,0.15672491,0.017513191,,,,,,,,,,,,,,,,, -153102,,,0.9954980611801147,0.0140454312786459,0.7813701511430701,0.9870207905769348,0.0504884012043476,0.2933010124942687,43793.0,0.9861586689949036,0.0539308078587055,0.27590422734712,43793.0,49718.60764694214,74958.15905070305,49718.60764694214,25226.03780841828,9.207793951034546,0.0 -153200,0.13078111,0.016927956,,,,,,,,,,,,,,,,, -153300,0.16829039,0.0219337,,,,,,,,,,,,,,,,, -153400,0.13795719,0.01882468,,,,,,,,,,,,,,,,, -153500,0.13425463,0.019352764,,,,,,,,,,,,,,,,, -153600,0.14433242,0.017413577,,,,,,,,,,,,,,,,, -153700,0.13623615,0.017449351,,,,,,,,,,,,,,,,, -153800,0.12450092,0.0159268,,,,,,,,,,,,,,,,, -153849,,,0.995534360408783,0.0139383375644683,0.7729496941587721,0.9870207905769348,0.0504884012043476,0.2932470890447527,43793.0,0.9861586689949036,0.0539308004081249,0.2758850884604144,43793.0,49958.73828434944,75302.66285181046,49958.73828434944,25330.334392786022,9.263737916946411,0.0 -153900,0.14164197,0.01567599,,,,,,,,,,,,,,,,, -154000,0.14632055,0.0176069,,,,,,,,,,,,,,,,, -154100,0.13996388,0.019083783,,,,,,,,,,,,,,,,, -154200,0.14077377,0.0168342,,,,,,,,,,,,,,,,, -154300,0.14759547,0.019589074,,,,,,,,,,,,,,,,, -154400,0.14068371,0.016734533,,,,,,,,,,,,,,,,, -154500,0.15221539,0.0207385,,,,,,,,,,,,,,,,, -154588,,,0.9954615831375122,0.014223264530301,0.7669854386099493,0.9870207905769348,0.0504884012043476,0.2932654833117584,43793.0,0.9861586689949036,0.0539308004081249,0.2760416363120542,43793.0,50198.92618584633,75648.9467959404,50198.92618584633,25436.34400558472,9.328450679779053,0.0 -154600,0.13459148,0.017073061,,,,,,,,,,,,,,,,, -154700,0.12806448,0.01683171,,,,,,,,,,,,,,,,, -154800,0.14594245,0.017947245,,,,,,,,,,,,,,,,, -154900,0.1285165,0.016512567,,,,,,,,,,,,,,,,, -155000,0.131409,0.016373003,,,,,,,,,,,,,,,,, -155100,0.14251718,0.018294001,,,,,,,,,,,,,,,,, -155200,0.12978211,0.016770108,,,,,,,,,,,,,,,,, -155300,0.13717325,0.018041836,,,,,,,,,,,,,,,,, -155327,,,0.9954792261123656,0.0141309844329953,0.77602726156619,0.9870207905769348,0.0504884012043476,0.2932259288464124,43793.0,0.9861586689949036,0.0539308078587055,0.2759604086981166,43793.0,50438.856477975845,75996.59171676636,50438.856477975845,25543.98120045662,9.3852961063385,0.0 -155400,0.13941413,0.017278522,,,,,,,,,,,,,,,,, -155500,0.13961586,0.018884208,,,,,,,,,,,,,,,,, -155600,0.14723241,0.018201113,,,,,,,,,,,,,,,,, -155700,0.13546629,0.016726157,,,,,,,,,,,,,,,,, -155800,0.15122971,0.019899292,,,,,,,,,,,,,,,,, -155900,0.14010644,0.017862389,,,,,,,,,,,,,,,,, -156000,0.15086415,0.018024605,,,,,,,,,,,,,,,,, -156065,,,0.9955069422721864,0.0139641826972365,0.7782063057074599,0.9870207905769348,0.0504884012043476,0.2930974827205313,43793.0,0.9861586689949036,0.0539308004081249,0.2759401832193064,43793.0,50679.07929563522,76343.71928310394,50679.07929563522,25650.809689760208,9.441604852676392,0.0 -156100,0.15374298,0.018336978,,,,,,,,,,,,,,,,, -156200,0.13765289,0.0185884,,,,,,,,,,,,,,,,, -156300,0.13311307,0.018600732,,,,,,,,,,,,,,,,, -156400,0.14354423,0.018474406,,,,,,,,,,,,,,,,, -156500,0.15575688,0.01613428,,,,,,,,,,,,,,,,, -156600,0.1374428,0.019118257,,,,,,,,,,,,,,,,, -156700,0.11789398,0.016555507,,,,,,,,,,,,,,,,, -156800,0.14722928,0.016016334,,,,,,,,,,,,,,,,, -156808,,,0.995521605014801,0.0139957116916775,0.7851263115303592,0.9870207905769348,0.0504884012043476,0.2933636890766024,43793.0,0.9861586689949036,0.0539308078587055,0.2760160452841331,43793.0,50919.26484918594,76697.53965449333,50919.26484918594,25764.367182970047,9.498466730117798,0.0 -156900,0.13327822,0.017600406,,,,,,,,,,,,,,,,, -157000,0.13405192,0.017943725,,,,,,,,,,,,,,,,, -157100,0.13903011,0.01737117,,,,,,,,,,,,,,,,, -157200,0.14558864,0.020147132,,,,,,,,,,,,,,,,, -157300,0.15436438,0.01766102,,,,,,,,,,,,,,,,, -157400,0.13749528,0.015981747,,,,,,,,,,,,,,,,, -157500,0.1335507,0.017060434,,,,,,,,,,,,,,,,, -157535,,,0.9954743981361388,0.0140967210754752,0.7758255457820338,0.9870207905769348,0.0504884012043476,0.293116103459514,43793.0,0.9861586689949036,0.0539308004081249,0.2758944622297735,43793.0,51159.42909312248,77047.70661783218,51159.42909312248,25874.278917074203,9.567306756973268,0.0 -157600,0.16217466,0.019244444,,,,,,,,,,,,,,,,, -157700,0.12891601,0.017687606,,,,,,,,,,,,,,,,, -157800,0.15781192,0.02056337,,,,,,,,,,,,,,,,, -157900,0.1384495,0.017681323,,,,,,,,,,,,,,,,, -158000,0.13810268,0.020145414,,,,,,,,,,,,,,,,, -158100,0.13641097,0.018021973,,,,,,,,,,,,,,,,, -158200,0.13088825,0.01677061,,,,,,,,,,,,,,,,, -158280,,,0.9955391883850098,0.0139821004122495,0.7747015856380461,0.9870207905769348,0.0504884012043476,0.2931608099330871,43793.0,0.9861586689949036,0.0539308004081249,0.275884767550971,43793.0,51399.40223193169,77394.87665104866,51399.40223193169,25981.39734506607,9.625344038009644,0.0 -158300,0.15835159,0.01860197,,,,,,,,,,,,,,,,, -158400,0.16678593,0.017596109,,,,,,,,,,,,,,,,, -158500,0.14125338,0.018233683,,,,,,,,,,,,,,,,, -158600,0.15130499,0.017313678,,,,,,,,,,,,,,,,, -158700,0.12978537,0.016441295,,,,,,,,,,,,,,,,, -158800,0.14787468,0.020117404,,,,,,,,,,,,,,,,, -158900,0.14204496,0.017101867,,,,,,,,,,,,,,,,, -159000,0.13498105,0.017082253,,,,,,,,,,,,,,,,, -159021,,,0.9954434037208556,0.0142225446179509,0.7661622075147424,0.9870207905769348,0.0504884012043476,0.2931163238204452,43793.0,0.9861586689949036,0.0539308078587055,0.2759102206121796,43793.0,51639.65240359306,77745.88492846489,51639.65240359306,26092.07701206208,9.682848691940308,0.0 -159100,0.14071196,0.015555609,,,,,,,,,,,,,,,,, -159200,0.15137272,0.019190406,,,,,,,,,,,,,,,,, -159300,0.138706,0.01770922,,,,,,,,,,,,,,,,, -159400,0.13977963,0.01740101,,,,,,,,,,,,,,,,, -159500,0.12770404,0.017669769,,,,,,,,,,,,,,,,, -159600,0.12567852,0.019188443,,,,,,,,,,,,,,,,, -159700,0.15031911,0.018610673,,,,,,,,,,,,,,,,, -159762,,,0.99553245306015,0.014043060131371,0.7827541050873437,0.9870207905769348,0.0504884012043476,0.293129897049583,43793.0,0.9861586689949036,0.0539308004081249,0.275975848750648,43793.0,51879.57717633248,78091.97503495216,51879.57717633248,26198.158568143845,9.7459876537323,0.0 -159800,0.1543338,0.021352582,,,,,,,,,,,,,,,,, -159900,0.13911068,0.016831761,,,,,,,,,,,,,,,,, -160000,0.1505983,0.01874735,,,,,,,,,,,,,,,,, -160100,0.13938569,0.017390085,,,,,,,,,,,,,,,,, -160200,0.15280245,0.018383836,,,,,,,,,,,,,,,,, -160300,0.1422977,0.018348202,,,,,,,,,,,,,,,,, -160400,0.13603932,0.017131316,,,,,,,,,,,,,,,,, -160498,,,0.995540201663971,0.0139052579179406,0.7825101913603374,0.9870207905769348,0.0504884012043476,0.2931839937265932,43793.0,0.9861586689949036,0.0539308078587055,0.2759441100966021,43793.0,52119.58177232742,78435.2933254242,52119.58177232742,26301.383195638657,9.815279960632324,0.0 -160500,0.1452722,0.019670824,,,,,,,,,,,,,,,,, -160600,0.14076614,0.018137297,,,,,,,,,,,,,,,,, -160700,0.13624428,0.017201329,,,,,,,,,,,,,,,,, -160800,0.16738793,0.020570697,,,,,,,,,,,,,,,,, -160900,0.14411724,0.01697091,,,,,,,,,,,,,,,,, -161000,0.13517176,0.017137423,,,,,,,,,,,,,,,,, -161100,0.13328107,0.015709154,,,,,,,,,,,,,,,,, -161200,0.13900886,0.018648043,,,,,,,,,,,,,,,,, -161235,,,0.9954307675361632,0.0141254318878054,0.7702380539786756,0.9870207905769348,0.0504884012043476,0.2932830918273247,43793.0,0.9861586689949036,0.0539308078587055,0.2759489303136601,43793.0,52359.627802848816,78782.22266626358,52359.627802848816,26408.17783999443,9.880241632461548,0.0 -161300,0.15912056,0.020296818,,,,,,,,,,,,,,,,, -161400,0.1534839,0.018227862,,,,,,,,,,,,,,,,, -161500,0.13249366,0.01748017,,,,,,,,,,,,,,,,, -161600,0.15201494,0.01916288,,,,,,,,,,,,,,,,, -161700,0.1340053,0.017608235,,,,,,,,,,,,,,,,, -161800,0.14610368,0.016812224,,,,,,,,,,,,,,,,, -161900,0.15756196,0.016883375,,,,,,,,,,,,,,,,, -161978,,,0.9955178499221802,0.0140783488750457,0.778865225406159,0.9870207905769348,0.0504884012043476,0.2931667231247588,43793.0,0.9861586689949036,0.0539308004081249,0.275967983978321,43793.0,52599.55924654007,79128.88596534729,52599.55924654007,26514.83008337021,9.939560890197754,0.0 -162000,0.15865733,0.017457278,,,,,,,,,,,,,,,,, -162100,0.14381729,0.018591572,,,,,,,,,,,,,,,,, -162200,0.15214723,0.017854476,,,,,,,,,,,,,,,,, -162300,0.13461298,0.017998265,,,,,,,,,,,,,,,,, -162400,0.16268057,0.019470269,,,,,,,,,,,,,,,,, -162500,0.13373671,0.017618414,,,,,,,,,,,,,,,,, -162600,0.12937224,0.01766212,,,,,,,,,,,,,,,,, -162700,0.13385217,0.017862584,,,,,,,,,,,,,,,,, -162719,,,0.9954817891120912,0.014129121787846,0.7618262995083183,0.9870207905769348,0.0504884012043476,0.2932867883896753,43793.0,0.9861586689949036,0.0539308078587055,0.2760785878656386,43793.0,52839.79007482529,79473.52391839027,52839.79007482529,26619.158552885056,9.997602939605711,0.0 -162800,0.15325119,0.0178577,,,,,,,,,,,,,,,,, -162900,0.15269731,0.018401252,,,,,,,,,,,,,,,,, -163000,0.12910153,0.017649386,,,,,,,,,,,,,,,,, -163100,0.13973834,0.01707989,,,,,,,,,,,,,,,,, -163200,0.13237722,0.016894931,,,,,,,,,,,,,,,,, -163300,0.13495956,0.016111044,,,,,,,,,,,,,,,,, -163400,0.13814174,0.017803704,,,,,,,,,,,,,,,,, -163454,,,0.9954960942268372,0.0141142038628458,0.7793173555613062,0.9870207905769348,0.0504884012043476,0.2932046730429082,43793.0,0.9861586689949036,0.0539308078587055,0.275908853560883,43793.0,53079.76247572899,79821.12660717964,53079.76247572899,26726.71019911766,10.055290699005129,0.0 -163500,0.14160092,0.01909023,,,,,,,,,,,,,,,,, -163600,0.13799296,0.018277362,,,,,,,,,,,,,,,,, -163700,0.13560586,0.017316809,,,,,,,,,,,,,,,,, -163800,0.13546768,0.018455582,,,,,,,,,,,,,,,,, -163900,0.14733633,0.02003605,,,,,,,,,,,,,,,,, -164000,0.14765224,0.01754377,,,,,,,,,,,,,,,,, -164100,0.13525,0.017854167,,,,,,,,,,,,,,,,, -164196,,,0.9955257773399352,0.0139328949153423,0.7799171908740734,0.9870207905769348,0.0504884012043476,0.2933241683517332,43793.0,0.9861586689949036,0.0539308004081249,0.2760912667040872,43793.0,53319.88084149361,80164.88756608963,53319.88084149361,26830.27416396141,10.113618850708008,0.0 -164200,0.14460722,0.018118134,,,,,,,,,,,,,,,,, -164300,0.13446471,0.019124988,,,,,,,,,,,,,,,,, -164400,0.14943193,0.017540976,,,,,,,,,,,,,,,,, -164500,0.13240473,0.014874897,,,,,,,,,,,,,,,,, -164600,0.15472895,0.019418461,,,,,,,,,,,,,,,,, -164700,0.1450578,0.016093258,,,,,,,,,,,,,,,,, -164800,0.14793232,0.018453034,,,,,,,,,,,,,,,,, -164900,0.1445198,0.017599327,,,,,,,,,,,,,,,,, -164940,,,0.9955264925956726,0.0139749469235539,0.781608569478042,0.9870207905769348,0.0504884012043476,0.2930783591027502,43793.0,0.9861586689949036,0.0539308078587055,0.2759160436130984,43793.0,53559.96479272842,80508.71048593521,53559.96479272842,26933.934188604355,10.171966552734377,0.0 -165000,0.16092105,0.019343428,,,,,,,,,,,,,,,,, -165100,0.1449883,0.018590366,,,,,,,,,,,,,,,,, -165200,0.14679892,0.018248226,,,,,,,,,,,,,,,,, -165300,0.14105771,0.017781474,,,,,,,,,,,,,,,,, -165400,0.14911288,0.018416293,,,,,,,,,,,,,,,,, -165500,0.14798385,0.01746173,,,,,,,,,,,,,,,,, -165600,0.15727939,0.019855615,,,,,,,,,,,,,,,,, -165683,,,0.9954779148101808,0.0141129968687891,0.7717062319651007,0.9870207905769348,0.0504884012043476,0.2932151454680108,43793.0,0.9861586689949036,0.0539308004081249,0.2758898549006813,43793.0,53800.09856343269,80856.76532030106,53800.09856343269,27041.776702404022,10.230546236038208,0.0 -165700,0.1506328,0.021764018,,,,,,,,,,,,,,,,, -165800,0.15078045,0.018352093,,,,,,,,,,,,,,,,, -165900,0.14528494,0.018112142,,,,,,,,,,,,,,,,, -166000,0.14131452,0.019094655,,,,,,,,,,,,,,,,, -166100,0.14653428,0.01835831,,,,,,,,,,,,,,,,, -166200,0.13818769,0.015986988,,,,,,,,,,,,,,,,, -166300,0.13605054,0.017618867,,,,,,,,,,,,,,,,, -166400,0.13240892,0.017670749,,,,,,,,,,,,,,,,, -166430,,,0.9955308437347412,0.013994694687426,0.7687587064128196,0.9870207905769348,0.0504884012043476,0.2931156943476471,43793.0,0.9861586689949036,0.0539308004081249,0.2759094356033681,43793.0,54040.2394015789,81205.09357523918,54040.2394015789,27149.88500189781,10.289138793945312,0.0 -166500,0.12944694,0.016808828,,,,,,,,,,,,,,,,, -166600,0.13629295,0.016703295,,,,,,,,,,,,,,,,, -166700,0.15267704,0.019593377,,,,,,,,,,,,,,,,, -166800,0.1498189,0.019566372,,,,,,,,,,,,,,,,, -166900,0.14590837,0.018280867,,,,,,,,,,,,,,,,, -167000,0.15734255,0.020357305,,,,,,,,,,,,,,,,, -167100,0.1398205,0.01864877,,,,,,,,,,,,,,,,, -167172,,,0.9954442977905272,0.0142495296895504,0.7769547425832453,0.9870207905769348,0.0504884012043476,0.2930756406702172,43793.0,0.9861586689949036,0.0539308078587055,0.276066158956547,43793.0,54280.27514696121,81553.69691586494,54280.27514696121,27258.364639759064,10.35668396949768,0.0 -167200,0.16316348,0.018066864,,,,,,,,,,,,,,,,, -167300,0.15295957,0.018080128,,,,,,,,,,,,,,,,, -167400,0.13145362,0.015545991,,,,,,,,,,,,,,,,, -167500,0.13605455,0.017748816,,,,,,,,,,,,,,,,, -167600,0.1377551,0.01685699,,,,,,,,,,,,,,,,, -167700,0.14567854,0.01906314,,,,,,,,,,,,,,,,, -167800,0.16146544,0.01806423,,,,,,,,,,,,,,,,, -167900,0.14090766,0.015837371,,,,,,,,,,,,,,,,, -167916,,,0.9955223202705384,0.0139664039015769,0.778954704864464,0.9870207905769348,0.0504884012043476,0.2930991619901983,43793.0,0.9861586689949036,0.0539308078587055,0.2759194664446954,43793.0,54520.21311235428,81897.25141072273,54520.21311235428,27361.902376651764,10.414924383163452,0.0 -168000,0.13203654,0.018289212,,,,,,,,,,,,,,,,, -168100,0.14430527,0.017774338,,,,,,,,,,,,,,,,, -168200,0.12797748,0.016365223,,,,,,,,,,,,,,,,, -168300,0.12482215,0.018723886,,,,,,,,,,,,,,,,, -168400,0.15285985,0.017161706,,,,,,,,,,,,,,,,, -168500,0.14301933,0.01727951,,,,,,,,,,,,,,,,, -168600,0.13337359,0.017244173,,,,,,,,,,,,,,,,, -168667,,,0.9955002069473268,0.0140024097636342,0.781999000653092,0.9870207905769348,0.0504884012043476,0.2931762744871093,43793.0,0.9861586689949036,0.0539308078587055,0.2760429089668513,43793.0,54760.16240620613,82244.221940279,54760.16240620613,27468.845179080963,10.472869873046877,0.0 -168700,0.13361084,0.016811118,,,,,,,,,,,,,,,,, -168800,0.13451089,0.016955111,,,,,,,,,,,,,,,,, -168900,0.124888055,0.014420641,,,,,,,,,,,,,,,,, -169000,0.15384617,0.01741418,,,,,,,,,,,,,,,,, -169100,0.13081057,0.017212471,,,,,,,,,,,,,,,,, -169200,0.14685898,0.014985806,,,,,,,,,,,,,,,,, -169300,0.1435203,0.018736657,,,,,,,,,,,,,,,,, -169400,0.15044853,0.018922046,,,,,,,,,,,,,,,,, -169410,,,0.995485544204712,0.0140526164323091,0.7778016832839559,0.9870207905769348,0.0504884012043476,0.2932229489257906,43793.0,0.9861586689949036,0.0539308078587055,0.2759854657923026,43793.0,55000.32378101349,82586.32157897949,55000.32378101349,27570.706958293915,10.529196977615356,0.0 -169500,0.1650487,0.018813074,,,,,,,,,,,,,,,,, -169600,0.15269808,0.017838458,,,,,,,,,,,,,,,,, -169700,0.15020089,0.01683086,,,,,,,,,,,,,,,,, -169800,0.15180898,0.018523254,,,,,,,,,,,,,,,,, -169900,0.1380504,0.01694285,,,,,,,,,,,,,,,,, -170000,0.14141004,0.018817423,,,,,,,,,,,,,,,,, -170100,0.14067134,0.016911477,,,,,,,,,,,,,,,,, -170162,,,0.9955334067344666,0.0139998979866504,0.768694405169295,0.9870207905769348,0.0504884012043476,0.2932143589528342,43793.0,0.9861586689949036,0.0539308078587055,0.2759205124404382,43793.0,55240.26776766777,82930.82968711853,55240.26776766777,27675.19109106064,10.588984966278076,0.0 -170200,0.12920044,0.01806071,,,,,,,,,,,,,,,,, -170300,0.1488318,0.018439213,,,,,,,,,,,,,,,,, -170400,0.14175124,0.019653285,,,,,,,,,,,,,,,,, -170500,0.13925849,0.017810384,,,,,,,,,,,,,,,,, -170600,0.14193769,0.01782693,,,,,,,,,,,,,,,,, -170700,0.14852712,0.018070633,,,,,,,,,,,,,,,,, -170758,,,,,,,,,,,,,,55431.01209068298,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index f4cfe617e..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,50 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -900.4765393733978,0.0,37.97205352783203,1,0,37.97205352783203,0.0007088489946909,0.0,11.114374160766602,3003,938.4486303329468,0.0006708127912133,0.0,11.12966251373291,0.0004835649742744,0.0,11.12306308746338,3000 -1370.3998427391052,0.0915694236755371,877.8537795543671,2302,0,877.8537795543671,0.5070478320121765,16.23465916489031,2.945228576660156,3003,2248.415539741516,0.5118868350982666,21.798427704443863,2.8945295810699463,0.5098015069961548,17.50913139290973,2.891077995300293,3000 -1819.5305128097527,0.1161665916442871,1717.9510114192965,4607,0,1717.9510114192965,0.5886119604110718,21.83645237670241,2.14756727218628,3003,3537.737144947052,0.5724933743476868,26.642778119238887,2.27587366104126,0.5883870124816895,23.35456405805425,2.16740345954895,3000 -2269.650138616562,0.1388981342315673,2557.946650266648,6912,0,2557.946650266648,0.6200685501098633,24.104531101372096,1.8967156410217283,3003,4827.942455291748,0.6066136360168457,29.017854244266523,2.0027647018432617,0.6156030297279358,25.35619328040451,1.9432120323181152,3000 -2728.376378059387,0.1646599769592285,3397.8931415081024,9215,0,3397.8931415081024,0.6364883184432983,25.23227095534961,1.7690547704696655,3003,6126.708575963974,0.6120672225952148,29.916192478814693,1.959139347076416,0.6295396089553833,26.2413759186834,1.8218027353286743,3000 -3228.05371260643,0.1887707710266113,4237.998886823654,11519,0,4237.998886823654,0.6449480056762695,25.399824245978028,1.7010329961776731,3003,7466.583305358887,0.6196917295455933,30.07146527648772,1.887145757675171,0.6392481327056885,26.88392878478057,1.7518912553787231,3000 -3682.720309734345,0.214097261428833,5078.007721185684,13824,0,5078.007721185684,0.6524780988693237,26.42720410787868,1.6474509239196775,3003,8761.350906610489,0.6305181980133057,30.40540911721149,1.8076016902923584,0.6429430246353149,27.113700672480107,1.710771918296814,3000 -4210.324150323868,0.2381901741027832,5918.018773555756,16128,0,5918.018773555756,0.6546278595924377,26.11191716618898,1.6145895719528198,3003,10129.056498527529,0.6288556456565857,30.45977095033596,1.831342458724976,0.6467743515968323,27.353765269924896,1.6920876502990725,3000 -4678.479519367218,0.2643682956695556,6757.959350347519,18432,0,6757.959350347519,0.661158561706543,26.92102044075361,1.5822104215621948,3003,11437.245300531387,0.6309792995452881,30.74278095983484,1.8097506761550903,0.6506924629211426,27.71685002958468,1.6586767435073853,3000 -5239.943262815476,0.2907545566558838,7598.0957589149475,20737,0,7598.0957589149475,0.6625065803527832,27.02671804348759,1.571100831031799,3003,12838.938373327255,0.6350600719451904,31.086818959270097,1.7639600038528442,0.6527507305145264,27.800610918108475,1.6468219757080078,3000 -5752.122546672821,0.3188722133636474,8438.187176465988,23042,0,8438.187176465988,0.6654813885688782,27.471505111759694,1.552293062210083,3003,14191.30432844162,0.6321536302566528,30.843176318892503,1.7879424095153809,0.6557885408401489,27.80706443069435,1.6294809579849243,3000 -6292.985018253326,0.3446135520935058,9278.177888393402,25346,0,9278.177888393402,0.6647841930389404,26.93845269141764,1.544036865234375,3003,15572.249928712845,0.6534295678138733,31.96841674854174,1.6287175416946411,0.6557636857032776,27.763272642865847,1.620172142982483,3000 -6871.334754228592,0.3750591278076172,10118.242518424988,27651,0,10118.242518424988,0.6674219965934753,27.34712870999217,1.5395432710647583,3003,16990.762141942978,0.6354812383651733,30.85710459621641,1.7555207014083862,0.655738890171051,27.76954025263315,1.6140739917755127,3000 -7409.287669897079,0.4019620418548584,10958.493296861649,29957,0,10958.493296861649,0.6674801111221313,26.982702387093063,1.530782699584961,3003,18369.059331417084,0.6409510970115662,30.828748013400578,1.732843995094299,0.6587767004966736,27.44061347257815,1.6031513214111328,3000 -7955.819365501404,0.4307701587677002,11798.761286258698,32263,0,11798.761286258698,0.6707687377929688,27.344738706710345,1.5176211595535278,3003,19755.955070257187,0.6501732468605042,31.79653325067196,1.671488881111145,0.6608969569206238,27.934553907088265,1.5915027856826782,3000 -8483.97195148468,0.4588100910186767,12638.940054178238,34568,0,12638.940054178238,0.6739526987075806,27.71240378788925,1.5071828365325928,3003,21124.38235092163,0.639989972114563,31.40815149263302,1.737483024597168,0.6610333323478699,28.18224271573024,1.5890058279037476,3000 -9033.718392133713,0.4855470657348633,13478.908013343813,36873,0,13478.908013343813,0.6739643216133118,27.51490274102142,1.4984428882598877,3003,22514.191542625427,0.6420833468437195,31.19616011212892,1.7289228439331057,0.662744402885437,28.243285991467825,1.5755648612976074,3000 -9664.304621696472,0.5149462223052979,14318.914702415466,39178,0,14318.914702415466,0.6731741428375244,27.401873969473264,1.4919666051864624,3003,23984.88139367104,0.6463394165039062,31.62721259487478,1.6828906536102295,0.6632651686668396,28.16472535589393,1.5713107585906982,3000 -10241.513010263445,0.5448994636535645,15159.026709794998,41484,0,15159.026709794998,0.677241325378418,27.910100164262925,1.4807902574539185,3003,25402.29855298996,0.6445146203041077,31.45555725558557,1.705758571624756,0.6643810868263245,28.44802331863066,1.5640370845794678,3000 -10730.098761081696,0.5752518177032471,15999.166652917862,43790,0,15999.166652917862,0.6773110628128052,28.13029009785281,1.4803187847137451,3003,26731.12109351158,0.678406298160553,34.31987309324694,1.4894598722457886,0.666340172290802,28.83492635339662,1.5590335130691528,3000 -11318.53396177292,0.605954647064209,16839.33149933815,46095,0,16839.33149933815,0.6787520051002502,28.32559704180684,1.4696136713027954,3003,28159.8185441494,0.6512089371681213,31.58748186605699,1.6554925441741943,0.6669477224349976,28.66998281072759,1.552070379257202,3000 -11905.7664437294,0.6351094245910645,17679.514671325684,48401,0,17679.514671325684,0.6809017658233643,28.30155804844667,1.458929181098938,3003,29587.330949544907,0.648673415184021,31.704655840258024,1.6775165796279907,0.6679148077964783,28.49673433425077,1.5459349155426023,3000 -12453.499742031096,0.663583517074585,18519.6294836998,50707,0,18519.6294836998,0.6794027090072632,27.870305365442896,1.4583882093429563,3003,30975.2739675045,0.6610821485519409,32.63553191464314,1.5793545246124268,0.6676668524742126,28.773472505424525,1.5392777919769287,3000 -13041.082812786102,0.6946089267730713,19359.6201646328,53012,0,19359.6201646328,0.6824240684509277,28.27440388210995,1.4426162242889404,3003,32402.945701360703,0.6509921550750732,31.95503741876957,1.663438320159912,0.6693159341812134,28.66539324958268,1.528120040893555,3000 -13607.088655233383,0.7248268127441406,20199.62682914734,55318,0,20199.62682914734,0.6829469799995422,28.26515132787788,1.442746877670288,3003,33809.05586147308,0.6490379571914673,31.85959430543849,1.6658644676208496,0.668584406375885,28.63456609958289,1.5241024494171145,3000 -14177.54667854309,0.7564868927001953,21039.827073812485,57624,0,21039.827073812485,0.6851897239685059,28.56584009832488,1.4309513568878174,3003,35219.81306219101,0.6594597697257996,32.67480079360041,1.596582531929016,0.6716469526290894,29.14386669057417,1.5170725584030151,3000 -14757.277026891708,0.7875394821166992,21880.05396389961,59929,0,21880.05396389961,0.6837139129638672,28.37492052326649,1.4270353317260742,3003,36639.86911034584,0.6542331576347351,32.1923354088288,1.6329275369644165,0.6711261868476868,28.66304524085199,1.514757752418518,3000 -15261.25569176674,0.8184239864349365,22720.285170316696,62235,0,22720.285170316696,0.6874208450317383,28.6996542227149,1.4156063795089722,3003,37984.17668604851,0.6547898054122925,32.29131481800968,1.6457945108413696,0.6748583316802979,29.144210761682988,1.5035715103149414,3000 -15842.44847869873,0.8506321907043457,23560.25806093216,64540,0,23560.25806093216,0.6869560480117798,29.01646062679263,1.4149410724639893,3003,39405.44167017937,0.6574546694755554,32.33868366983711,1.6062756776809692,0.6742879748344421,29.16130690093247,1.499944806098938,3000 -16379.987210035324,0.8830595016479492,24400.42298603058,66847,0,24400.42298603058,0.6883156299591064,28.97097313380021,1.4015823602676392,3003,40783.24477171898,0.6589851975440979,32.50040292950039,1.6024383306503296,0.675317108631134,29.14662722173597,1.4933512210845947,3000 -16991.00972509384,0.9139461517333984,25240.656269311905,69153,0,25240.656269311905,0.689233660697937,29.031414589198896,1.3974084854125977,3003,42234.59798383713,0.6765788197517395,33.541029229499564,1.4881170988082886,0.676656186580658,29.058710968380737,1.4876593351364136,3000 -17630.459098100662,0.9451935291290284,26080.56746864319,71459,0,26080.56746864319,0.6905351281166077,28.987678393285474,1.3871610164642334,3003,43714.0570063591,0.665115475654602,32.916180646277745,1.5632176399230957,0.6775985360145569,29.321900177988866,1.4751025438308716,3000 -18194.033081531525,0.9769787788391112,26920.736166715626,73765,0,26920.736166715626,0.6927546262741089,28.91781925583517,1.381793975830078,3003,45117.89793419838,0.6569005846977234,33.01953946347801,1.6199302673339844,0.6786896586418152,29.54283922963444,1.4717580080032349,3000 -18745.29731321335,1.009082555770874,27760.940415620804,76071,0,27760.940415620804,0.6914531588554382,28.920786485322584,1.3808372020721436,3003,46509.46474575997,0.6729021668434143,33.29321228860028,1.506928563117981,0.6776605248451233,29.190714508145124,1.4689043760299685,3000 -19440.16348552704,1.0446865558624268,28600.997208356857,78377,0,28600.997208356857,0.6954041123390198,29.17774261731953,1.366335391998291,3003,48044.48940658569,0.6650096774101257,33.123997561018435,1.5661970376968384,0.6799543499946594,29.59472036218692,1.4611165523529053,3000 -20023.029673337936,1.0782907009124756,29441.037866830826,80684,0,29441.037866830826,0.6970890760421753,29.614426414284896,1.3583707809448242,3003,49467.49631810188,0.66160649061203,33.317098467297384,1.584054946899414,0.6825581789016724,29.821797940794205,1.45094633102417,3000 -20611.38403081894,1.1141114234924316,30281.05094337464,82990,0,30281.05094337464,0.6958457231521606,29.73206205813485,1.355960488319397,3003,50895.96606874466,0.6741752624511719,33.5210667999066,1.5034538507461548,0.6817522048950195,29.79512560255042,1.444745421409607,3000 -21244.30550432205,1.1484274864196775,31121.205088377,85296,0,31121.205088377,0.6988786458969116,29.391099964317444,1.3426960706710815,3003,52369.14261388779,0.6712327003479004,33.1370007701296,1.5273113250732422,0.6831037402153015,29.524563702331893,1.4382604360580444,3000 -21825.86384224892,1.1863529682159424,31961.25281882286,87601,0,31961.25281882286,0.7010400295257568,29.907318585279327,1.3378076553344729,3003,53790.85534501076,0.7025133371353149,36.16159872807506,1.3403351306915283,0.6858935356140137,30.037808770372497,1.4293599128723145,3000 -22489.192754983906,1.2217700481414795,32801.215609788895,89907,0,32801.215609788895,0.7012608647346497,29.77899500255573,1.3300431966781616,3003,55294.24873948097,0.6764506101608276,33.920997173484295,1.486994981765747,0.6865506768226624,30.330018269000004,1.4231069087982178,3000 -23181.273468256,1.2571892738342283,33641.39935684204,92214,0,33641.39935684204,0.7024693489074707,29.70131795830228,1.323116898536682,3003,56826.61493492127,0.6773262619972229,33.9196685232179,1.4910597801208496,0.6862035393714905,30.39234321071231,1.423099160194397,3000 -23671.67979907989,1.292123556137085,34481.39404511452,94520,0,34481.39404511452,0.7016210556030273,29.794721474995853,1.318655252456665,3003,58157.11773443222,0.6835636496543884,34.798970108395004,1.435005784034729,0.6891916990280151,30.261402817179807,1.411435604095459,3000 -24450.89505577088,1.328049898147583,35321.29417848587,96826,0,35321.29417848587,0.7032014727592468,29.931591119075662,1.310566782951355,3003,59776.336155653,0.6781773567199707,34.505094291513295,1.476668119430542,0.6873690485954285,30.01696468775817,1.410524845123291,3000 -24991.73729467392,1.3659710884094238,36161.29120993614,99133,0,36161.29120993614,0.7041775584220886,30.258098751022786,1.3046365976333618,3003,61157.28055071831,0.6766044497489929,34.415588200909774,1.4872363805770874,0.687480628490448,30.11480655939417,1.401458740234375,3000 -25580.229484319687,1.402970314025879,37001.29831576347,101439,0,37001.29831576347,0.7062227725982666,30.72127951020204,1.295928120613098,3003,62585.88406729698,0.6870901584625244,34.776603339122985,1.424283742904663,0.6906920075416565,30.415153484049604,1.3949700593948364,3000 -26135.257561206818,1.4406273365020752,37841.34276676178,103746,0,37841.34276676178,0.7068967819213867,30.42128479296152,1.29006826877594,3003,63981.060650110245,0.6832632422447205,34.29670099865465,1.444430947303772,0.6907168030738831,30.255927333565843,1.3926249742507937,3000 -26683.23649024964,1.4774634838104248,38681.39353013039,106052,0,38681.39353013039,0.7068037986755371,30.07426208954982,1.2878295183181765,3003,65369.194038152695,0.6834360361099243,34.66426654290793,1.44869065284729,0.692055881023407,30.47464467803755,1.3876729011535645,3000 -27293.159260988235,1.5152764320373535,39521.30562400818,108357,0,39521.30562400818,0.7074894309043884,30.3378661903933,1.2821885347366333,3003,66819.13472270966,0.6897974610328674,35.04683613009326,1.4114888906478882,0.6923534870147705,30.593538745220897,1.3850077390670776,3000 -27895.077735185623,1.5552773475646973,40361.46251535416,110664,0,40361.46251535416,0.7101272344589233,30.811757487590313,1.273249864578247,3003,68261.31683158875,0.6857077479362488,34.912174504037615,1.4301151037216187,0.6944116950035095,30.888292461111718,1.377577304840088,3000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 2c439890a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1158 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.725228,11.120249,,,,,,,,,,,,,,,,, -1,,,0.0006708127912133,11.12966251373291,0.0,0.0004835649742744,11.12306308746338,0.0,3000.0,0.0007088489946909,11.114374160766602,0.0,3003.0,37.97205352783203,938.4486303329468,37.97205352783203,900.4765393733978,0.0,0.0 -100,0.21887526,8.272218,,,,,,,,,,,,,,,,, -200,0.38372335,7.484184,,,,,,,,,,,,,,,,, -300,0.60493875,6.904033,,,,,,,,,,,,,,,,, -400,0.4638379,6.318812,,,,,,,,,,,,,,,,, -500,0.3883469,5.914494,,,,,,,,,,,,,,,,, -600,0.48332554,5.589924,,,,,,,,,,,,,,,,, -700,0.51485074,5.3105464,,,,,,,,,,,,,,,,, -800,0.35538694,5.034418,,,,,,,,,,,,,,,,, -900,0.47955674,4.7558756,,,,,,,,,,,,,,,,, -1000,0.6570616,4.662245,,,,,,,,,,,,,,,,, -1100,0.6227917,4.310379,,,,,,,,,,,,,,,,, -1200,0.57548285,4.0753136,,,,,,,,,,,,,,,,, -1300,0.47599283,4.01641,,,,,,,,,,,,,,,,, -1400,0.52494824,3.7851775,,,,,,,,,,,,,,,,, -1500,0.5289649,3.5601988,,,,,,,,,,,,,,,,, -1600,0.44953576,3.3836448,,,,,,,,,,,,,,,,, -1700,0.4503388,3.4285252,,,,,,,,,,,,,,,,, -1800,0.39155906,3.3259916,,,,,,,,,,,,,,,,, -1900,0.5687533,3.2698247,,,,,,,,,,,,,,,,, -2000,0.44308496,3.159562,,,,,,,,,,,,,,,,, -2100,0.35789463,3.066498,,,,,,,,,,,,,,,,, -2200,0.41886663,3.1343932,,,,,,,,,,,,,,,,, -2300,0.38303238,3.0801651,,,,,,,,,,,,,,,,, -2302,,,0.5118868350982666,2.8945295810699463,21.798427704443863,0.5098015069961548,2.891077995300293,17.50913139290973,3000.0,0.5070478320121765,2.945228576660156,16.23465916489031,3003.0,877.8537795543671,2248.415539741516,877.8537795543671,1370.3998427391052,0.0915694236755371,0.0 -2400,0.39508376,2.8956304,,,,,,,,,,,,,,,,, -2500,0.33646703,2.9960992,,,,,,,,,,,,,,,,, -2600,0.3957514,2.8347135,,,,,,,,,,,,,,,,, -2700,0.31320435,2.8561444,,,,,,,,,,,,,,,,, -2800,0.26729307,2.7026067,,,,,,,,,,,,,,,,, -2900,0.26953614,2.7543688,,,,,,,,,,,,,,,,, -3000,0.24678704,2.709902,,,,,,,,,,,,,,,,, -3100,0.30836302,2.7077296,,,,,,,,,,,,,,,,, -3200,0.35101584,2.5858142,,,,,,,,,,,,,,,,, -3300,0.23874788,2.5429776,,,,,,,,,,,,,,,,, -3400,0.2481794,2.5790415,,,,,,,,,,,,,,,,, -3500,0.22397053,2.4988787,,,,,,,,,,,,,,,,, -3600,0.26213962,2.5137556,,,,,,,,,,,,,,,,, -3700,0.20281321,2.5625422,,,,,,,,,,,,,,,,, -3800,0.20899118,2.4696174,,,,,,,,,,,,,,,,, -3900,0.2304461,2.4649239,,,,,,,,,,,,,,,,, -4000,0.19710103,2.6012518,,,,,,,,,,,,,,,,, -4100,0.22324988,2.3901448,,,,,,,,,,,,,,,,, -4200,0.20932771,2.3625617,,,,,,,,,,,,,,,,, -4300,0.22953755,2.356823,,,,,,,,,,,,,,,,, -4400,0.18655264,2.3360684,,,,,,,,,,,,,,,,, -4500,0.17303692,2.3230548,,,,,,,,,,,,,,,,, -4600,0.18371367,2.3282235,,,,,,,,,,,,,,,,, -4607,,,0.5724933743476868,2.27587366104126,26.642778119238887,0.5883870124816895,2.16740345954895,23.35456405805425,3000.0,0.5886119604110718,2.14756727218628,21.83645237670241,3003.0,1717.9510114192965,3537.737144947052,1717.9510114192965,1819.5305128097527,0.1161665916442871,0.0 -4700,0.19333638,2.3663044,,,,,,,,,,,,,,,,, -4800,0.16526279,2.2958837,,,,,,,,,,,,,,,,, -4900,0.18257584,2.3164883,,,,,,,,,,,,,,,,, -5000,0.16423742,2.219531,,,,,,,,,,,,,,,,, -5100,0.1905681,2.2659192,,,,,,,,,,,,,,,,, -5200,0.1963034,2.209758,,,,,,,,,,,,,,,,, -5300,0.17518897,2.211074,,,,,,,,,,,,,,,,, -5400,0.17478536,2.264068,,,,,,,,,,,,,,,,, -5500,0.15040073,2.2820966,,,,,,,,,,,,,,,,, -5600,0.17927234,2.1624415,,,,,,,,,,,,,,,,, -5700,0.15802926,2.1998866,,,,,,,,,,,,,,,,, -5800,0.19136491,2.1412961,,,,,,,,,,,,,,,,, -5900,0.16127153,2.1341417,,,,,,,,,,,,,,,,, -6000,0.1583638,2.1834874,,,,,,,,,,,,,,,,, -6100,0.1867782,2.255277,,,,,,,,,,,,,,,,, -6200,0.15530622,2.1372895,,,,,,,,,,,,,,,,, -6300,0.19860497,2.2156324,,,,,,,,,,,,,,,,, -6400,0.1615611,2.2132847,,,,,,,,,,,,,,,,, -6500,0.17532712,2.234128,,,,,,,,,,,,,,,,, -6600,0.16268836,2.1157367,,,,,,,,,,,,,,,,, -6700,0.18521269,2.1610837,,,,,,,,,,,,,,,,, -6800,0.14746332,2.1763203,,,,,,,,,,,,,,,,, -6900,0.17076617,2.094633,,,,,,,,,,,,,,,,, -6912,,,0.6066136360168457,2.0027647018432617,29.017854244266523,0.6156030297279358,1.9432120323181152,25.35619328040451,3000.0,0.6200685501098633,1.8967156410217283,24.104531101372096,3003.0,2557.946650266648,4827.942455291748,2557.946650266648,2269.650138616562,0.1388981342315673,0.0 -7000,0.1868866,2.246604,,,,,,,,,,,,,,,,, -7100,0.16178605,2.1798859,,,,,,,,,,,,,,,,, -7200,0.18504375,2.1500673,,,,,,,,,,,,,,,,, -7300,0.15706912,2.003743,,,,,,,,,,,,,,,,, -7400,0.17228512,2.1404243,,,,,,,,,,,,,,,,, -7500,0.14443426,2.069366,,,,,,,,,,,,,,,,, -7600,0.16331434,2.164281,,,,,,,,,,,,,,,,, -7700,0.15993975,2.1808534,,,,,,,,,,,,,,,,, -7800,0.15012805,1.9865309,,,,,,,,,,,,,,,,, -7900,0.17578469,2.0673082,,,,,,,,,,,,,,,,, -8000,0.18078229,2.1677597,,,,,,,,,,,,,,,,, -8100,0.23737773,2.0866435,,,,,,,,,,,,,,,,, -8200,0.1824337,2.0817595,,,,,,,,,,,,,,,,, -8300,0.17011201,2.0971446,,,,,,,,,,,,,,,,, -8400,0.16869423,2.08519,,,,,,,,,,,,,,,,, -8500,0.19561577,2.045548,,,,,,,,,,,,,,,,, -8600,0.16517131,2.0041316,,,,,,,,,,,,,,,,, -8700,0.1756661,2.0450118,,,,,,,,,,,,,,,,, -8800,0.16353321,2.0790687,,,,,,,,,,,,,,,,, -8900,0.17540805,2.0843074,,,,,,,,,,,,,,,,, -9000,0.1590442,2.057417,,,,,,,,,,,,,,,,, -9100,0.15968852,1.9399925,,,,,,,,,,,,,,,,, -9200,0.18134679,2.1622221,,,,,,,,,,,,,,,,, -9215,,,0.6120672225952148,1.959139347076416,29.916192478814693,0.6295396089553833,1.8218027353286743,26.2413759186834,3000.0,0.6364883184432983,1.7690547704696655,25.23227095534961,3003.0,3397.8931415081024,6126.708575963974,3397.8931415081024,2728.376378059387,0.1646599769592285,0.0 -9300,0.16704239,2.0158057,,,,,,,,,,,,,,,,, -9400,0.19519399,1.985942,,,,,,,,,,,,,,,,, -9500,0.2041135,2.0560675,,,,,,,,,,,,,,,,, -9600,0.1790225,2.0975206,,,,,,,,,,,,,,,,, -9700,0.20747028,1.9740381,,,,,,,,,,,,,,,,, -9800,0.16681235,2.0648475,,,,,,,,,,,,,,,,, -9900,0.1882708,1.9613676,,,,,,,,,,,,,,,,, -10000,0.17375496,1.9480528,,,,,,,,,,,,,,,,, -10100,0.20238723,2.0462132,,,,,,,,,,,,,,,,, -10200,0.1811101,1.9192091,,,,,,,,,,,,,,,,, -10300,0.16705176,2.0296757,,,,,,,,,,,,,,,,, -10400,0.18359502,1.8924712,,,,,,,,,,,,,,,,, -10500,0.17497672,2.0670815,,,,,,,,,,,,,,,,, -10600,0.19588187,1.9625057,,,,,,,,,,,,,,,,, -10700,0.18217286,1.995643,,,,,,,,,,,,,,,,, -10800,0.17692468,2.0994346,,,,,,,,,,,,,,,,, -10900,0.18130854,1.9136074,,,,,,,,,,,,,,,,, -11000,0.20167996,2.0956905,,,,,,,,,,,,,,,,, -11100,0.18807012,1.8671209,,,,,,,,,,,,,,,,, -11200,0.18328637,1.9882109,,,,,,,,,,,,,,,,, -11300,0.18194526,2.0085137,,,,,,,,,,,,,,,,, -11400,0.1825603,2.021753,,,,,,,,,,,,,,,,, -11500,0.1936278,1.9847482,,,,,,,,,,,,,,,,, -11519,,,0.6196917295455933,1.887145757675171,30.07146527648772,0.6392481327056885,1.7518912553787231,26.88392878478057,3000.0,0.6449480056762695,1.7010329961776731,25.399824245978028,3003.0,4237.998886823654,7466.583305358887,4237.998886823654,3228.05371260643,0.1887707710266113,0.0 -11600,0.19617115,1.8582898,,,,,,,,,,,,,,,,, -11700,0.16742776,1.9844165,,,,,,,,,,,,,,,,, -11800,0.26806894,1.8809295,,,,,,,,,,,,,,,,, -11900,0.18868488,2.0458505,,,,,,,,,,,,,,,,, -12000,0.24352168,1.9172976,,,,,,,,,,,,,,,,, -12100,0.172744,1.9128417,,,,,,,,,,,,,,,,, -12200,0.1895464,1.9983252,,,,,,,,,,,,,,,,, -12300,0.25966784,1.9732307,,,,,,,,,,,,,,,,, -12400,0.22509979,1.8585854,,,,,,,,,,,,,,,,, -12500,0.19753747,1.9881876,,,,,,,,,,,,,,,,, -12600,0.19299574,1.9044908,,,,,,,,,,,,,,,,, -12700,0.26574627,1.9485823,,,,,,,,,,,,,,,,, -12800,0.17857145,1.904555,,,,,,,,,,,,,,,,, -12900,0.24465479,1.9572692,,,,,,,,,,,,,,,,, -13000,0.25389248,1.8901299,,,,,,,,,,,,,,,,, -13100,0.25377494,1.9284517,,,,,,,,,,,,,,,,, -13200,0.25322846,2.0437965,,,,,,,,,,,,,,,,, -13300,0.2180294,1.977274,,,,,,,,,,,,,,,,, -13400,0.21445183,1.9545237,,,,,,,,,,,,,,,,, -13500,0.16347837,1.9281029,,,,,,,,,,,,,,,,, -13600,0.22107595,1.9339895,,,,,,,,,,,,,,,,, -13700,0.20206705,1.9219893,,,,,,,,,,,,,,,,, -13800,0.19050057,1.8847668,,,,,,,,,,,,,,,,, -13824,,,0.6305181980133057,1.8076016902923584,30.40540911721149,0.6429430246353149,1.710771918296814,27.113700672480107,3000.0,0.6524780988693237,1.6474509239196775,26.42720410787868,3003.0,5078.007721185684,8761.350906610489,5078.007721185684,3682.720309734345,0.214097261428833,0.0 -13900,0.19399033,1.9362454,,,,,,,,,,,,,,,,, -14000,0.1899593,1.8591722,,,,,,,,,,,,,,,,, -14100,0.28497428,1.880201,,,,,,,,,,,,,,,,, -14200,0.22189671,1.9487259,,,,,,,,,,,,,,,,, -14300,0.17895195,1.8664428,,,,,,,,,,,,,,,,, -14400,0.20045403,2.0218806,,,,,,,,,,,,,,,,, -14500,0.18691398,1.9748067,,,,,,,,,,,,,,,,, -14600,0.19554853,1.8745332,,,,,,,,,,,,,,,,, -14700,0.17413194,1.8735292,,,,,,,,,,,,,,,,, -14800,0.21735467,1.9285421,,,,,,,,,,,,,,,,, -14900,0.22916612,1.9125522,,,,,,,,,,,,,,,,, -15000,0.1952695,1.9357249,,,,,,,,,,,,,,,,, -15100,0.17317423,1.8637515,,,,,,,,,,,,,,,,, -15200,0.2012051,1.911244,,,,,,,,,,,,,,,,, -15300,0.24223152,1.9193363,,,,,,,,,,,,,,,,, -15400,0.18475069,1.9276538,,,,,,,,,,,,,,,,, -15500,0.18922082,1.874103,,,,,,,,,,,,,,,,, -15600,0.21002844,1.9642277,,,,,,,,,,,,,,,,, -15700,0.17845955,1.8770584,,,,,,,,,,,,,,,,, -15800,0.20422831,1.8951536,,,,,,,,,,,,,,,,, -15900,0.22421527,1.9599253,,,,,,,,,,,,,,,,, -16000,0.19351572,1.9563444,,,,,,,,,,,,,,,,, -16100,0.18566555,1.8978081,,,,,,,,,,,,,,,,, -16128,,,0.6288556456565857,1.831342458724976,30.45977095033596,0.6467743515968323,1.6920876502990725,27.353765269924896,3000.0,0.6546278595924377,1.6145895719528198,26.11191716618898,3003.0,5918.018773555756,10129.056498527529,5918.018773555756,4210.324150323868,0.2381901741027832,0.0 -16200,0.23309629,1.8227935,,,,,,,,,,,,,,,,, -16300,0.23193292,1.8497376,,,,,,,,,,,,,,,,, -16400,0.206105,1.9057126,,,,,,,,,,,,,,,,, -16500,0.29089853,1.8826979,,,,,,,,,,,,,,,,, -16600,0.20648298,1.8835407,,,,,,,,,,,,,,,,, -16700,0.20999098,1.8298522,,,,,,,,,,,,,,,,, -16800,0.27127612,1.8695679,,,,,,,,,,,,,,,,, -16900,0.2041268,1.9104047,,,,,,,,,,,,,,,,, -17000,0.17713448,1.9452796,,,,,,,,,,,,,,,,, -17100,0.18155882,2.0274844,,,,,,,,,,,,,,,,, -17200,0.23303054,1.8337766,,,,,,,,,,,,,,,,, -17300,0.3405526,1.9013281,,,,,,,,,,,,,,,,, -17400,0.2066327,1.8826042,,,,,,,,,,,,,,,,, -17500,0.23360905,1.8366085,,,,,,,,,,,,,,,,, -17600,0.17989606,1.8543257,,,,,,,,,,,,,,,,, -17700,0.25300914,1.8907616,,,,,,,,,,,,,,,,, -17800,0.22472404,1.8461522,,,,,,,,,,,,,,,,, -17900,0.21531165,1.8946003,,,,,,,,,,,,,,,,, -18000,0.19906,1.8153063,,,,,,,,,,,,,,,,, -18100,0.18849868,1.9244578,,,,,,,,,,,,,,,,, -18200,0.18775623,1.8724413,,,,,,,,,,,,,,,,, -18300,0.21660459,1.7810409,,,,,,,,,,,,,,,,, -18400,0.2624984,1.8262403,,,,,,,,,,,,,,,,, -18432,,,0.6309792995452881,1.8097506761550903,30.74278095983484,0.6506924629211426,1.6586767435073853,27.71685002958468,3000.0,0.661158561706543,1.5822104215621948,26.92102044075361,3003.0,6757.959350347519,11437.245300531387,6757.959350347519,4678.479519367218,0.2643682956695556,0.0 -18500,0.2194399,1.9572475,,,,,,,,,,,,,,,,, -18600,0.2579745,1.8786278,,,,,,,,,,,,,,,,, -18700,0.22866727,1.8663156,,,,,,,,,,,,,,,,, -18800,0.21179776,1.9065053,,,,,,,,,,,,,,,,, -18900,0.27929685,1.8740318,,,,,,,,,,,,,,,,, -19000,0.19104786,1.9164963,,,,,,,,,,,,,,,,, -19100,0.21359906,1.9203401,,,,,,,,,,,,,,,,, -19200,0.17767271,1.7975224,,,,,,,,,,,,,,,,, -19300,0.2041025,1.8898406,,,,,,,,,,,,,,,,, -19400,0.2541415,1.8186282,,,,,,,,,,,,,,,,, -19500,0.88708794,1.9043449,,,,,,,,,,,,,,,,, -19600,0.20382455,1.9470683,,,,,,,,,,,,,,,,, -19700,0.23151776,1.8457482,,,,,,,,,,,,,,,,, -19800,0.17765188,1.8821541,,,,,,,,,,,,,,,,, -19900,0.19381574,1.861106,,,,,,,,,,,,,,,,, -20000,0.24116425,1.8825033,,,,,,,,,,,,,,,,, -20100,0.18496484,1.8177859,,,,,,,,,,,,,,,,, -20200,0.2074206,1.8822348,,,,,,,,,,,,,,,,, -20300,0.19372642,1.9187021,,,,,,,,,,,,,,,,, -20400,0.20954755,1.864337,,,,,,,,,,,,,,,,, -20500,0.19141047,1.8388512,,,,,,,,,,,,,,,,, -20600,0.20090076,1.7451503,,,,,,,,,,,,,,,,, -20700,0.20141557,1.9007944,,,,,,,,,,,,,,,,, -20737,,,0.6350600719451904,1.7639600038528442,31.086818959270097,0.6527507305145264,1.6468219757080078,27.800610918108475,3000.0,0.6625065803527832,1.571100831031799,27.02671804348759,3003.0,7598.0957589149475,12838.938373327255,7598.0957589149475,5239.943262815476,0.2907545566558838,0.0 -20800,0.21672323,1.8000901,,,,,,,,,,,,,,,,, -20900,0.21489151,1.8449923,,,,,,,,,,,,,,,,, -21000,0.1702988,1.8415383,,,,,,,,,,,,,,,,, -21100,0.22616506,1.8485434,,,,,,,,,,,,,,,,, -21200,0.19599774,1.879248,,,,,,,,,,,,,,,,, -21300,0.23794775,1.9080986,,,,,,,,,,,,,,,,, -21400,0.18661982,1.8020455,,,,,,,,,,,,,,,,, -21500,0.18899453,1.9449914,,,,,,,,,,,,,,,,, -21600,0.21938188,1.8581544,,,,,,,,,,,,,,,,, -21700,0.23376843,1.8573976,,,,,,,,,,,,,,,,, -21800,0.2000783,1.9125057,,,,,,,,,,,,,,,,, -21900,0.23279163,1.9556597,,,,,,,,,,,,,,,,, -22000,0.229415,1.8421594,,,,,,,,,,,,,,,,, -22100,0.21524917,1.8975619,,,,,,,,,,,,,,,,, -22200,0.2075936,1.899851,,,,,,,,,,,,,,,,, -22300,0.20642112,1.9120712,,,,,,,,,,,,,,,,, -22400,0.22296691,1.9138896,,,,,,,,,,,,,,,,, -22500,0.19819915,1.9067614,,,,,,,,,,,,,,,,, -22600,0.2212368,1.8231748,,,,,,,,,,,,,,,,, -22700,0.26154745,1.865338,,,,,,,,,,,,,,,,, -22800,0.18485032,1.9060357,,,,,,,,,,,,,,,,, -22900,0.21461456,1.8878677,,,,,,,,,,,,,,,,, -23000,0.19744939,1.8228345,,,,,,,,,,,,,,,,, -23042,,,0.6321536302566528,1.7879424095153809,30.843176318892503,0.6557885408401489,1.6294809579849243,27.80706443069435,3000.0,0.6654813885688782,1.552293062210083,27.471505111759694,3003.0,8438.187176465988,14191.30432844162,8438.187176465988,5752.122546672821,0.3188722133636474,0.0 -23100,0.21679504,1.8202388,,,,,,,,,,,,,,,,, -23200,0.22278449,1.9358603,,,,,,,,,,,,,,,,, -23300,0.2121848,1.8138794,,,,,,,,,,,,,,,,, -23400,0.27858755,1.9360181,,,,,,,,,,,,,,,,, -23500,0.20382872,1.8191433,,,,,,,,,,,,,,,,, -23600,0.34468782,1.9308604,,,,,,,,,,,,,,,,, -23700,0.18789414,1.7567244,,,,,,,,,,,,,,,,, -23800,0.22213082,1.8209732,,,,,,,,,,,,,,,,, -23900,0.25383896,1.9195099,,,,,,,,,,,,,,,,, -24000,0.19316146,1.8263482,,,,,,,,,,,,,,,,, -24100,0.28567806,1.8243202,,,,,,,,,,,,,,,,, -24200,0.2151045,1.9091214,,,,,,,,,,,,,,,,, -24300,0.21136595,1.870337,,,,,,,,,,,,,,,,, -24400,0.23184016,1.8694725,,,,,,,,,,,,,,,,, -24500,0.22499248,1.7698404,,,,,,,,,,,,,,,,, -24600,0.21198614,1.885006,,,,,,,,,,,,,,,,, -24700,0.22189565,1.8937255,,,,,,,,,,,,,,,,, -24800,0.22922094,1.8402493,,,,,,,,,,,,,,,,, -24900,0.21530488,1.8586183,,,,,,,,,,,,,,,,, -25000,0.19325922,1.8350323,,,,,,,,,,,,,,,,, -25100,0.19209653,1.8058814,,,,,,,,,,,,,,,,, -25200,0.23737195,1.8633024,,,,,,,,,,,,,,,,, -25300,0.29300997,1.8435543,,,,,,,,,,,,,,,,, -25346,,,0.6534295678138733,1.6287175416946411,31.96841674854174,0.6557636857032776,1.620172142982483,27.763272642865847,3000.0,0.6647841930389404,1.544036865234375,26.93845269141764,3003.0,9278.177888393402,15572.249928712845,9278.177888393402,6292.985018253326,0.3446135520935058,0.0 -25400,0.19549444,1.8638182,,,,,,,,,,,,,,,,, -25500,0.20152883,1.7373201,,,,,,,,,,,,,,,,, -25600,0.27310967,1.7633777,,,,,,,,,,,,,,,,, -25700,0.20439003,1.800768,,,,,,,,,,,,,,,,, -25800,0.19584431,1.8751693,,,,,,,,,,,,,,,,, -25900,0.21887337,1.8186239,,,,,,,,,,,,,,,,, -26000,0.21371488,1.7590178,,,,,,,,,,,,,,,,, -26100,0.2059834,1.9042809,,,,,,,,,,,,,,,,, -26200,0.18478426,1.8143431,,,,,,,,,,,,,,,,, -26300,0.24724546,1.8136845,,,,,,,,,,,,,,,,, -26400,0.21436891,1.7674319,,,,,,,,,,,,,,,,, -26500,0.2317531,1.8372827,,,,,,,,,,,,,,,,, -26600,0.1922243,1.8838383,,,,,,,,,,,,,,,,, -26700,0.2227068,1.785513,,,,,,,,,,,,,,,,, -26800,0.22052382,1.8404841,,,,,,,,,,,,,,,,, -26900,0.23231202,1.8225113,,,,,,,,,,,,,,,,, -27000,0.19157995,1.7772663,,,,,,,,,,,,,,,,, -27100,0.21766944,1.823278,,,,,,,,,,,,,,,,, -27200,0.24090263,1.846098,,,,,,,,,,,,,,,,, -27300,0.18524708,1.8033412,,,,,,,,,,,,,,,,, -27400,0.2745156,1.7728082,,,,,,,,,,,,,,,,, -27500,0.24675693,1.8288391,,,,,,,,,,,,,,,,, -27600,0.24448043,1.774571,,,,,,,,,,,,,,,,, -27651,,,0.6354812383651733,1.7555207014083862,30.85710459621641,0.655738890171051,1.6140739917755127,27.76954025263315,3000.0,0.6674219965934753,1.5395432710647583,27.34712870999217,3003.0,10118.242518424988,16990.762141942978,10118.242518424988,6871.334754228592,0.3750591278076172,0.0 -27700,0.20209208,1.8285884,,,,,,,,,,,,,,,,, -27800,0.2154314,1.8016835,,,,,,,,,,,,,,,,, -27900,0.20263688,1.7493265,,,,,,,,,,,,,,,,, -28000,0.21286963,1.8503228,,,,,,,,,,,,,,,,, -28100,0.19960427,1.8399311,,,,,,,,,,,,,,,,, -28200,0.18308677,1.8590282,,,,,,,,,,,,,,,,, -28300,0.22825767,1.7652057,,,,,,,,,,,,,,,,, -28400,0.32754615,1.8577001,,,,,,,,,,,,,,,,, -28500,0.20178565,1.9027889,,,,,,,,,,,,,,,,, -28600,0.20633605,1.8458,,,,,,,,,,,,,,,,, -28700,0.17387724,1.7602073,,,,,,,,,,,,,,,,, -28800,0.4792249,1.9239275,,,,,,,,,,,,,,,,, -28900,0.19005127,1.9192818,,,,,,,,,,,,,,,,, -29000,0.2777935,1.7693895,,,,,,,,,,,,,,,,, -29100,0.21997178,1.8778462,,,,,,,,,,,,,,,,, -29200,0.2415114,1.844625,,,,,,,,,,,,,,,,, -29300,0.22840641,1.8595914,,,,,,,,,,,,,,,,, -29400,0.49759552,1.8800012,,,,,,,,,,,,,,,,, -29500,0.20674886,1.8564223,,,,,,,,,,,,,,,,, -29600,0.22233847,1.7687769,,,,,,,,,,,,,,,,, -29700,0.25394955,1.8111584,,,,,,,,,,,,,,,,, -29800,0.2468758,1.8742718,,,,,,,,,,,,,,,,, -29900,0.254914,1.7650115,,,,,,,,,,,,,,,,, -29957,,,0.6409510970115662,1.732843995094299,30.828748013400578,0.6587767004966736,1.6031513214111328,27.44061347257815,3000.0,0.6674801111221313,1.530782699584961,26.982702387093063,3003.0,10958.493296861649,18369.059331417084,10958.493296861649,7409.287669897079,0.4019620418548584,0.0 -30000,0.19241421,1.7607818,,,,,,,,,,,,,,,,, -30100,0.19808823,1.7956398,,,,,,,,,,,,,,,,, -30200,0.18407828,1.8392296,,,,,,,,,,,,,,,,, -30300,0.19516855,1.812278,,,,,,,,,,,,,,,,, -30400,0.20580128,1.8139771,,,,,,,,,,,,,,,,, -30500,0.19276561,1.8140093,,,,,,,,,,,,,,,,, -30600,0.20539975,1.7246318,,,,,,,,,,,,,,,,, -30700,0.24389663,1.7786806,,,,,,,,,,,,,,,,, -30800,0.23670459,1.8232634,,,,,,,,,,,,,,,,, -30900,0.18403502,1.8235582,,,,,,,,,,,,,,,,, -31000,0.19593683,1.7989206,,,,,,,,,,,,,,,,, -31100,0.21113285,1.8468152,,,,,,,,,,,,,,,,, -31200,0.19311985,1.811303,,,,,,,,,,,,,,,,, -31300,0.19387454,1.869241,,,,,,,,,,,,,,,,, -31400,0.20393308,1.8415551,,,,,,,,,,,,,,,,, -31500,0.213504,1.7694716,,,,,,,,,,,,,,,,, -31600,0.21602738,1.7323735,,,,,,,,,,,,,,,,, -31700,0.2128564,1.8659235,,,,,,,,,,,,,,,,, -31800,0.34121233,1.8504794,,,,,,,,,,,,,,,,, -31900,0.19603932,1.7478994,,,,,,,,,,,,,,,,, -32000,0.19256555,1.8032103,,,,,,,,,,,,,,,,, -32100,0.18514118,1.8417727,,,,,,,,,,,,,,,,, -32200,0.20124635,1.8141757,,,,,,,,,,,,,,,,, -32263,,,0.6501732468605042,1.671488881111145,31.79653325067196,0.6608969569206238,1.5915027856826782,27.934553907088265,3000.0,0.6707687377929688,1.5176211595535278,27.344738706710345,3003.0,11798.761286258698,19755.955070257187,11798.761286258698,7955.819365501404,0.4307701587677002,0.0 -32300,0.19431567,1.6990482,,,,,,,,,,,,,,,,, -32400,0.22443864,1.7860925,,,,,,,,,,,,,,,,, -32500,0.21269569,1.7607458,,,,,,,,,,,,,,,,, -32600,0.22579871,1.8453435,,,,,,,,,,,,,,,,, -32700,0.19660163,1.8175523,,,,,,,,,,,,,,,,, -32800,0.24838595,1.893913,,,,,,,,,,,,,,,,, -32900,0.18267727,1.8018612,,,,,,,,,,,,,,,,, -33000,0.20767872,1.7913563,,,,,,,,,,,,,,,,, -33100,0.18302211,1.7886332,,,,,,,,,,,,,,,,, -33200,0.23524779,1.8387854,,,,,,,,,,,,,,,,, -33300,0.29138392,1.8046961,,,,,,,,,,,,,,,,, -33400,0.20463224,1.854581,,,,,,,,,,,,,,,,, -33500,0.20145123,1.902606,,,,,,,,,,,,,,,,, -33600,0.20809308,1.8335843,,,,,,,,,,,,,,,,, -33700,0.18386735,1.749725,,,,,,,,,,,,,,,,, -33800,0.18122219,1.7396417,,,,,,,,,,,,,,,,, -33900,0.21464382,1.7876956,,,,,,,,,,,,,,,,, -34000,0.20181252,1.8903598,,,,,,,,,,,,,,,,, -34100,0.20069543,1.8086677,,,,,,,,,,,,,,,,, -34200,0.19453253,1.8621507,,,,,,,,,,,,,,,,, -34300,0.2141208,1.7701533,,,,,,,,,,,,,,,,, -34400,0.2488802,1.8194069,,,,,,,,,,,,,,,,, -34500,0.20145008,1.7991111,,,,,,,,,,,,,,,,, -34568,,,0.639989972114563,1.737483024597168,31.40815149263302,0.6610333323478699,1.5890058279037476,28.18224271573024,3000.0,0.6739526987075806,1.5071828365325928,27.71240378788925,3003.0,12638.940054178238,21124.38235092163,12638.940054178238,8483.97195148468,0.4588100910186767,0.0 -34600,0.22468671,1.8104818,,,,,,,,,,,,,,,,, -34700,0.18113282,1.8197374,,,,,,,,,,,,,,,,, -34800,0.19566411,1.7189071,,,,,,,,,,,,,,,,, -34900,0.23439658,1.8065034,,,,,,,,,,,,,,,,, -35000,0.18641604,1.8068769,,,,,,,,,,,,,,,,, -35100,0.23003106,1.7609025,,,,,,,,,,,,,,,,, -35200,0.23544651,1.7763442,,,,,,,,,,,,,,,,, -35300,0.2453307,1.8014655,,,,,,,,,,,,,,,,, -35400,0.1990334,1.8355614,,,,,,,,,,,,,,,,, -35500,0.2257675,1.7855372,,,,,,,,,,,,,,,,, -35600,0.21580872,1.7588601,,,,,,,,,,,,,,,,, -35700,0.20144366,1.8094373,,,,,,,,,,,,,,,,, -35800,0.20805474,1.7694271,,,,,,,,,,,,,,,,, -35900,0.19273488,1.7388456,,,,,,,,,,,,,,,,, -36000,0.19623157,1.7588022,,,,,,,,,,,,,,,,, -36100,0.7499222,1.7999254,,,,,,,,,,,,,,,,, -36200,0.73479676,1.7979845,,,,,,,,,,,,,,,,, -36300,0.4159825,1.7875409,,,,,,,,,,,,,,,,, -36400,0.2161127,1.8233871,,,,,,,,,,,,,,,,, -36500,0.20779887,1.8248249,,,,,,,,,,,,,,,,, -36600,0.1974043,1.7536877,,,,,,,,,,,,,,,,, -36700,0.29695162,1.8274662,,,,,,,,,,,,,,,,, -36800,0.2177079,1.7468479,,,,,,,,,,,,,,,,, -36873,,,0.6420833468437195,1.7289228439331057,31.19616011212892,0.662744402885437,1.5755648612976074,28.243285991467825,3000.0,0.6739643216133118,1.4984428882598877,27.51490274102142,3003.0,13478.908013343813,22514.191542625427,13478.908013343813,9033.718392133713,0.4855470657348633,0.0 -36900,0.22276421,1.8332689,,,,,,,,,,,,,,,,, -37000,0.20423459,1.7966105,,,,,,,,,,,,,,,,, -37100,0.2247872,1.7516197,,,,,,,,,,,,,,,,, -37200,0.2031689,1.8169822,,,,,,,,,,,,,,,,, -37300,0.22517665,1.7128404,,,,,,,,,,,,,,,,, -37400,0.21555462,1.7800026,,,,,,,,,,,,,,,,, -37500,0.20230313,1.8685077,,,,,,,,,,,,,,,,, -37600,0.26254827,1.7837726,,,,,,,,,,,,,,,,, -37700,0.20664722,1.8014035,,,,,,,,,,,,,,,,, -37800,0.20358686,1.8162571,,,,,,,,,,,,,,,,, -37900,0.20822968,1.7878747,,,,,,,,,,,,,,,,, -38000,0.17681296,1.7639724,,,,,,,,,,,,,,,,, -38100,0.18371373,1.7968692,,,,,,,,,,,,,,,,, -38200,0.21692012,1.8495094,,,,,,,,,,,,,,,,, -38300,0.21688683,1.826138,,,,,,,,,,,,,,,,, -38400,0.19987616,1.7958779,,,,,,,,,,,,,,,,, -38500,0.2221066,1.7970871,,,,,,,,,,,,,,,,, -38600,0.21221209,1.8562343,,,,,,,,,,,,,,,,, -38700,0.25378537,1.7460737,,,,,,,,,,,,,,,,, -38800,0.2214892,1.7173792,,,,,,,,,,,,,,,,, -38900,0.20059912,1.7432387,,,,,,,,,,,,,,,,, -39000,0.22517836,1.811445,,,,,,,,,,,,,,,,, -39100,0.19995715,1.7504742,,,,,,,,,,,,,,,,, -39178,,,0.6463394165039062,1.6828906536102295,31.62721259487478,0.6632651686668396,1.5713107585906982,28.16472535589393,3000.0,0.6731741428375244,1.4919666051864624,27.401873969473264,3003.0,14318.914702415466,23984.88139367104,14318.914702415466,9664.304621696472,0.5149462223052979,0.0 -39200,0.36271286,1.8349108,,,,,,,,,,,,,,,,, -39300,0.2136281,1.8763108,,,,,,,,,,,,,,,,, -39400,0.1990496,1.7709653,,,,,,,,,,,,,,,,, -39500,0.2355059,1.8282566,,,,,,,,,,,,,,,,, -39600,0.33609685,1.7042962,,,,,,,,,,,,,,,,, -39700,0.24302772,1.822921,,,,,,,,,,,,,,,,, -39800,0.23989585,1.770578,,,,,,,,,,,,,,,,, -39900,0.20227085,1.7586297,,,,,,,,,,,,,,,,, -40000,0.20518793,1.7506549,,,,,,,,,,,,,,,,, -40100,0.19476028,1.7497349,,,,,,,,,,,,,,,,, -40200,0.22687715,1.7823774,,,,,,,,,,,,,,,,, -40300,0.20014136,1.80208,,,,,,,,,,,,,,,,, -40400,0.22359492,1.8064144,,,,,,,,,,,,,,,,, -40500,0.20938294,1.7993443,,,,,,,,,,,,,,,,, -40600,0.19624975,1.8080262,,,,,,,,,,,,,,,,, -40700,0.20161912,1.7635117,,,,,,,,,,,,,,,,, -40800,0.20934153,1.7095233,,,,,,,,,,,,,,,,, -40900,0.20889995,1.7726257,,,,,,,,,,,,,,,,, -41000,0.22600785,1.7733179,,,,,,,,,,,,,,,,, -41100,0.53162193,1.784725,,,,,,,,,,,,,,,,, -41200,0.22770816,1.7524548,,,,,,,,,,,,,,,,, -41300,0.18748607,1.7264541,,,,,,,,,,,,,,,,, -41400,0.19904096,1.7501731,,,,,,,,,,,,,,,,, -41484,,,0.6445146203041077,1.705758571624756,31.45555725558557,0.6643810868263245,1.5640370845794678,28.44802331863066,3000.0,0.677241325378418,1.4807902574539185,27.910100164262925,3003.0,15159.026709794998,25402.29855298996,15159.026709794998,10241.513010263445,0.5448994636535645,0.0 -41500,0.2543409,1.7784773,,,,,,,,,,,,,,,,, -41600,0.23214312,1.7788495,,,,,,,,,,,,,,,,, -41700,0.22668715,1.8026798,,,,,,,,,,,,,,,,, -41800,0.24731438,1.744937,,,,,,,,,,,,,,,,, -41900,0.22578338,1.8201811,,,,,,,,,,,,,,,,, -42000,0.21343404,1.7496907,,,,,,,,,,,,,,,,, -42100,0.19623275,1.8335418,,,,,,,,,,,,,,,,, -42200,0.23399673,1.8325818,,,,,,,,,,,,,,,,, -42300,0.23631649,1.8001218,,,,,,,,,,,,,,,,, -42400,0.21772397,1.778767,,,,,,,,,,,,,,,,, -42500,0.30535924,1.722139,,,,,,,,,,,,,,,,, -42600,0.22024627,1.771434,,,,,,,,,,,,,,,,, -42700,0.19198538,1.8103914,,,,,,,,,,,,,,,,, -42800,0.24801691,1.7795166,,,,,,,,,,,,,,,,, -42900,0.20883715,1.7174338,,,,,,,,,,,,,,,,, -43000,0.252283,1.7733179,,,,,,,,,,,,,,,,, -43100,0.2459779,1.7387822,,,,,,,,,,,,,,,,, -43200,0.225083,1.7211642,,,,,,,,,,,,,,,,, -43300,0.19102252,1.7652105,,,,,,,,,,,,,,,,, -43400,0.19308412,1.7335162,,,,,,,,,,,,,,,,, -43500,0.20394342,1.8212669,,,,,,,,,,,,,,,,, -43600,0.21195899,1.7268984,,,,,,,,,,,,,,,,, -43700,0.19812049,1.8253504,,,,,,,,,,,,,,,,, -43790,,,0.678406298160553,1.4894598722457886,34.31987309324694,0.666340172290802,1.5590335130691528,28.83492635339662,3000.0,0.6773110628128052,1.4803187847137451,28.13029009785281,3003.0,15999.166652917862,26731.12109351158,15999.166652917862,10730.098761081696,0.5752518177032471,0.0 -43800,0.20645963,1.8062682,,,,,,,,,,,,,,,,, -43900,0.20607546,1.7567165,,,,,,,,,,,,,,,,, -44000,0.22395675,1.7515104,,,,,,,,,,,,,,,,, -44100,0.20239368,1.7624986,,,,,,,,,,,,,,,,, -44200,0.22781153,1.7522864,,,,,,,,,,,,,,,,, -44300,0.21411274,1.7029922,,,,,,,,,,,,,,,,, -44400,0.2570739,1.7437408,,,,,,,,,,,,,,,,, -44500,0.19099586,1.8071986,,,,,,,,,,,,,,,,, -44600,0.18781976,1.7359833,,,,,,,,,,,,,,,,, -44700,0.28933817,1.7951968,,,,,,,,,,,,,,,,, -44800,0.23018385,1.8026785,,,,,,,,,,,,,,,,, -44900,0.18952537,1.7825646,,,,,,,,,,,,,,,,, -45000,0.3129982,1.6841245,,,,,,,,,,,,,,,,, -45100,0.19649349,1.8013464,,,,,,,,,,,,,,,,, -45200,0.2394724,1.7868503,,,,,,,,,,,,,,,,, -45300,0.19486049,1.7495532,,,,,,,,,,,,,,,,, -45400,0.21660024,1.8048731,,,,,,,,,,,,,,,,, -45500,0.20093387,1.7602983,,,,,,,,,,,,,,,,, -45600,0.1883239,1.7400584,,,,,,,,,,,,,,,,, -45700,0.21708673,1.7045876,,,,,,,,,,,,,,,,, -45800,0.20910136,1.7003654,,,,,,,,,,,,,,,,, -45900,0.29309738,1.8066146,,,,,,,,,,,,,,,,, -46000,0.20364942,1.7580011,,,,,,,,,,,,,,,,, -46095,,,0.6512089371681213,1.6554925441741943,31.58748186605699,0.6669477224349976,1.552070379257202,28.66998281072759,3000.0,0.6787520051002502,1.4696136713027954,28.32559704180684,3003.0,16839.33149933815,28159.8185441494,16839.33149933815,11318.53396177292,0.605954647064209,0.0 -46100,0.22283968,1.7405652,,,,,,,,,,,,,,,,, -46200,0.21205726,1.7326288,,,,,,,,,,,,,,,,, -46300,0.197305,1.7233561,,,,,,,,,,,,,,,,, -46400,0.2269588,1.8327677,,,,,,,,,,,,,,,,, -46500,0.27360034,1.807146,,,,,,,,,,,,,,,,, -46600,0.19826716,1.7856946,,,,,,,,,,,,,,,,, -46700,0.22592615,1.7814759,,,,,,,,,,,,,,,,, -46800,0.20868474,1.7250535,,,,,,,,,,,,,,,,, -46900,0.19321296,1.7484026,,,,,,,,,,,,,,,,, -47000,0.21079521,1.729932,,,,,,,,,,,,,,,,, -47100,0.21274558,1.8528306,,,,,,,,,,,,,,,,, -47200,0.19963104,1.7117925,,,,,,,,,,,,,,,,, -47300,0.23395003,1.6882746,,,,,,,,,,,,,,,,, -47400,0.19556287,1.7719196,,,,,,,,,,,,,,,,, -47500,0.23737556,1.8210357,,,,,,,,,,,,,,,,, -47600,0.21559381,1.7080623,,,,,,,,,,,,,,,,, -47700,0.20792727,1.7766423,,,,,,,,,,,,,,,,, -47800,0.21095103,1.6783867,,,,,,,,,,,,,,,,, -47900,0.20300417,1.7019539,,,,,,,,,,,,,,,,, -48000,0.2793281,1.7935442,,,,,,,,,,,,,,,,, -48100,0.30379567,1.7096204,,,,,,,,,,,,,,,,, -48200,0.18714261,1.7820028,,,,,,,,,,,,,,,,, -48300,0.2177389,1.7455065,,,,,,,,,,,,,,,,, -48400,0.20077612,1.7407163,,,,,,,,,,,,,,,,, -48401,,,0.648673415184021,1.6775165796279907,31.704655840258024,0.6679148077964783,1.5459349155426023,28.49673433425077,3000.0,0.6809017658233643,1.458929181098938,28.30155804844667,3003.0,17679.514671325684,29587.330949544907,17679.514671325684,11905.7664437294,0.6351094245910645,0.0 -48500,0.21070458,1.7761191,,,,,,,,,,,,,,,,, -48600,0.20464368,1.7747568,,,,,,,,,,,,,,,,, -48700,0.21381985,1.8136033,,,,,,,,,,,,,,,,, -48800,0.19120376,1.7014621,,,,,,,,,,,,,,,,, -48900,0.20037854,1.7393708,,,,,,,,,,,,,,,,, -49000,0.23005041,1.7554191,,,,,,,,,,,,,,,,, -49100,0.18498173,1.7434293,,,,,,,,,,,,,,,,, -49200,0.18383993,1.7085724,,,,,,,,,,,,,,,,, -49300,0.38405457,1.7583963,,,,,,,,,,,,,,,,, -49400,0.21044903,1.790432,,,,,,,,,,,,,,,,, -49500,0.23652998,1.7335827,,,,,,,,,,,,,,,,, -49600,0.18315238,1.7845157,,,,,,,,,,,,,,,,, -49700,0.20563242,1.759367,,,,,,,,,,,,,,,,, -49800,0.18012105,1.725721,,,,,,,,,,,,,,,,, -49900,0.19689102,1.7860835,,,,,,,,,,,,,,,,, -50000,0.2046072,1.7581544,,,,,,,,,,,,,,,,, -50100,0.20569274,1.7471818,,,,,,,,,,,,,,,,, -50200,0.21193448,1.767768,,,,,,,,,,,,,,,,, -50300,0.27419204,1.7135377,,,,,,,,,,,,,,,,, -50400,0.20406988,1.772164,,,,,,,,,,,,,,,,, -50500,0.2062477,1.6858474,,,,,,,,,,,,,,,,, -50600,0.20117296,1.7370665,,,,,,,,,,,,,,,,, -50700,0.21445243,1.7729195,,,,,,,,,,,,,,,,, -50707,,,0.6610821485519409,1.5793545246124268,32.63553191464314,0.6676668524742126,1.5392777919769287,28.773472505424525,3000.0,0.6794027090072632,1.4583882093429563,27.870305365442896,3003.0,18519.6294836998,30975.2739675045,18519.6294836998,12453.499742031096,0.663583517074585,0.0 -50800,0.20951208,1.824441,,,,,,,,,,,,,,,,, -50900,0.26078174,1.7090379,,,,,,,,,,,,,,,,, -51000,0.21148467,1.6786604,,,,,,,,,,,,,,,,, -51100,0.20784344,1.7496656,,,,,,,,,,,,,,,,, -51200,0.32811177,1.7453779,,,,,,,,,,,,,,,,, -51300,0.19823408,1.7859701,,,,,,,,,,,,,,,,, -51400,0.21337481,1.739386,,,,,,,,,,,,,,,,, -51500,0.22229843,1.7131201,,,,,,,,,,,,,,,,, -51600,0.19529797,1.8042864,,,,,,,,,,,,,,,,, -51700,0.2065068,1.7073617,,,,,,,,,,,,,,,,, -51800,0.23260103,1.7808311,,,,,,,,,,,,,,,,, -51900,0.2202326,1.7883818,,,,,,,,,,,,,,,,, -52000,0.19155055,1.6679384,,,,,,,,,,,,,,,,, -52100,0.20913175,1.7713826,,,,,,,,,,,,,,,,, -52200,0.21103048,1.700663,,,,,,,,,,,,,,,,, -52300,0.18329321,1.7100784,,,,,,,,,,,,,,,,, -52400,0.21473274,1.7039591,,,,,,,,,,,,,,,,, -52500,0.19830641,1.7253342,,,,,,,,,,,,,,,,, -52600,0.1962596,1.7791535,,,,,,,,,,,,,,,,, -52700,0.22917058,1.7017132,,,,,,,,,,,,,,,,, -52800,0.19406195,1.7409241,,,,,,,,,,,,,,,,, -52900,0.18584925,1.7032076,,,,,,,,,,,,,,,,, -53000,0.23253995,1.74997,,,,,,,,,,,,,,,,, -53012,,,0.6509921550750732,1.663438320159912,31.95503741876957,0.6693159341812134,1.528120040893555,28.66539324958268,3000.0,0.6824240684509277,1.4426162242889404,28.27440388210995,3003.0,19359.6201646328,32402.945701360703,19359.6201646328,13041.082812786102,0.6946089267730713,0.0 -53100,0.19009171,1.6734719,,,,,,,,,,,,,,,,, -53200,0.19275391,1.7374054,,,,,,,,,,,,,,,,, -53300,0.4065494,1.7233453,,,,,,,,,,,,,,,,, -53400,0.25014764,1.7148856,,,,,,,,,,,,,,,,, -53500,0.21885508,1.7093344,,,,,,,,,,,,,,,,, -53600,0.18150873,1.6706719,,,,,,,,,,,,,,,,, -53700,0.2429211,1.6852002,,,,,,,,,,,,,,,,, -53800,0.28181502,1.7664762,,,,,,,,,,,,,,,,, -53900,0.21320885,1.758334,,,,,,,,,,,,,,,,, -54000,0.19890617,1.7730502,,,,,,,,,,,,,,,,, -54100,0.19703962,1.650518,,,,,,,,,,,,,,,,, -54200,0.21172227,1.7952282,,,,,,,,,,,,,,,,, -54300,0.22722481,1.7292565,,,,,,,,,,,,,,,,, -54400,0.22345266,1.725178,,,,,,,,,,,,,,,,, -54500,0.20951085,1.7572381,,,,,,,,,,,,,,,,, -54600,0.21572098,1.698798,,,,,,,,,,,,,,,,, -54700,0.39929688,1.6479423,,,,,,,,,,,,,,,,, -54800,0.2019564,1.7194661,,,,,,,,,,,,,,,,, -54900,0.19576663,1.6908938,,,,,,,,,,,,,,,,, -55000,0.19241224,1.6664637,,,,,,,,,,,,,,,,, -55100,0.2034906,1.7662598,,,,,,,,,,,,,,,,, -55200,0.19734724,1.6271143,,,,,,,,,,,,,,,,, -55300,0.19330288,1.7879642,,,,,,,,,,,,,,,,, -55318,,,0.6490379571914673,1.6658644676208496,31.85959430543849,0.668584406375885,1.5241024494171145,28.63456609958289,3000.0,0.6829469799995422,1.442746877670288,28.26515132787788,3003.0,20199.62682914734,33809.05586147308,20199.62682914734,13607.088655233383,0.7248268127441406,0.0 -55400,0.20331298,1.6687992,,,,,,,,,,,,,,,,, -55500,0.19727376,1.6601051,,,,,,,,,,,,,,,,, -55600,0.19336303,1.7253752,,,,,,,,,,,,,,,,, -55700,0.3785516,1.6596057,,,,,,,,,,,,,,,,, -55800,0.23917426,1.7072212,,,,,,,,,,,,,,,,, -55900,0.20105053,1.6618092,,,,,,,,,,,,,,,,, -56000,0.20462713,1.8207705,,,,,,,,,,,,,,,,, -56100,0.26749304,1.6364346,,,,,,,,,,,,,,,,, -56200,0.28029406,1.7366309,,,,,,,,,,,,,,,,, -56300,0.19796951,1.6848934,,,,,,,,,,,,,,,,, -56400,0.19285716,1.740619,,,,,,,,,,,,,,,,, -56500,0.18455832,1.7097569,,,,,,,,,,,,,,,,, -56600,0.19292471,1.7097027,,,,,,,,,,,,,,,,, -56700,0.22518525,1.7193407,,,,,,,,,,,,,,,,, -56800,0.18834043,1.6764297,,,,,,,,,,,,,,,,, -56900,0.21605101,1.7187953,,,,,,,,,,,,,,,,, -57000,0.20141989,1.6714813,,,,,,,,,,,,,,,,, -57100,0.22947495,1.7782254,,,,,,,,,,,,,,,,, -57200,0.19962248,1.6772348,,,,,,,,,,,,,,,,, -57300,0.20570673,1.6665558,,,,,,,,,,,,,,,,, -57400,0.21851532,1.7531263,,,,,,,,,,,,,,,,, -57500,0.2014447,1.6991014,,,,,,,,,,,,,,,,, -57600,0.19320346,1.6136364,,,,,,,,,,,,,,,,, -57624,,,0.6594597697257996,1.596582531929016,32.67480079360041,0.6716469526290894,1.5170725584030151,29.14386669057417,3000.0,0.6851897239685059,1.4309513568878174,28.56584009832488,3003.0,21039.827073812485,35219.81306219101,21039.827073812485,14177.54667854309,0.7564868927001953,0.0 -57700,0.22915272,1.7685778,,,,,,,,,,,,,,,,, -57800,0.19324556,1.6967695,,,,,,,,,,,,,,,,, -57900,0.22226441,1.6238776,,,,,,,,,,,,,,,,, -58000,0.19595575,1.6683112,,,,,,,,,,,,,,,,, -58100,0.19940485,1.7400588,,,,,,,,,,,,,,,,, -58200,0.18999057,1.6111863,,,,,,,,,,,,,,,,, -58300,0.21084692,1.6650193,,,,,,,,,,,,,,,,, -58400,0.19968605,1.6650454,,,,,,,,,,,,,,,,, -58500,0.1994697,1.685,,,,,,,,,,,,,,,,, -58600,0.18064824,1.6666666,,,,,,,,,,,,,,,,, -58700,0.21890296,1.6910163,,,,,,,,,,,,,,,,, -58800,0.19305101,1.7291318,,,,,,,,,,,,,,,,, -58900,0.22190312,1.7102367,,,,,,,,,,,,,,,,, -59000,0.19277906,1.7134261,,,,,,,,,,,,,,,,, -59100,0.2004091,1.7408639,,,,,,,,,,,,,,,,, -59200,0.21330108,1.7291247,,,,,,,,,,,,,,,,, -59300,0.26589495,1.7657923,,,,,,,,,,,,,,,,, -59400,0.21803345,1.7356029,,,,,,,,,,,,,,,,, -59500,0.21589917,1.6633899,,,,,,,,,,,,,,,,, -59600,0.2046597,1.6243293,,,,,,,,,,,,,,,,, -59700,0.18213663,1.7090003,,,,,,,,,,,,,,,,, -59800,0.193442,1.698142,,,,,,,,,,,,,,,,, -59900,0.19385666,1.6487288,,,,,,,,,,,,,,,,, -59929,,,0.6542331576347351,1.6329275369644165,32.1923354088288,0.6711261868476868,1.514757752418518,28.66304524085199,3000.0,0.6837139129638672,1.4270353317260742,28.37492052326649,3003.0,21880.05396389961,36639.86911034584,21880.05396389961,14757.277026891708,0.7875394821166992,0.0 -60000,0.21210417,1.7158306,,,,,,,,,,,,,,,,, -60100,0.19325247,1.7409427,,,,,,,,,,,,,,,,, -60200,0.19117877,1.6777735,,,,,,,,,,,,,,,,, -60300,0.21989352,1.7815399,,,,,,,,,,,,,,,,, -60400,0.20510079,1.7508209,,,,,,,,,,,,,,,,, -60500,0.30644083,1.7487352,,,,,,,,,,,,,,,,, -60600,0.22276695,1.6751941,,,,,,,,,,,,,,,,, -60700,0.193485,1.6484174,,,,,,,,,,,,,,,,, -60800,0.26403144,1.7107536,,,,,,,,,,,,,,,,, -60900,0.20278281,1.6648494,,,,,,,,,,,,,,,,, -61000,0.24114047,1.6904848,,,,,,,,,,,,,,,,, -61100,0.19117334,1.6559778,,,,,,,,,,,,,,,,, -61200,0.23028581,1.717435,,,,,,,,,,,,,,,,, -61300,0.19214445,1.6757698,,,,,,,,,,,,,,,,, -61400,0.2160384,1.706743,,,,,,,,,,,,,,,,, -61500,0.19173516,1.6669943,,,,,,,,,,,,,,,,, -61600,0.18804522,1.6541866,,,,,,,,,,,,,,,,, -61700,0.21502422,1.6588266,,,,,,,,,,,,,,,,, -61800,0.22264415,1.7484381,,,,,,,,,,,,,,,,, -61900,0.22027586,1.7124546,,,,,,,,,,,,,,,,, -62000,0.25208455,1.7536293,,,,,,,,,,,,,,,,, -62100,0.22082117,1.6639587,,,,,,,,,,,,,,,,, -62200,0.22687574,1.7361202,,,,,,,,,,,,,,,,, -62235,,,0.6547898054122925,1.6457945108413696,32.29131481800968,0.6748583316802979,1.5035715103149414,29.144210761682988,3000.0,0.6874208450317383,1.4156063795089722,28.6996542227149,3003.0,22720.285170316696,37984.17668604851,22720.285170316696,15261.25569176674,0.8184239864349365,0.0 -62300,0.21523853,1.7132916,,,,,,,,,,,,,,,,, -62400,0.20758903,1.5922985,,,,,,,,,,,,,,,,, -62500,0.20229588,1.6484137,,,,,,,,,,,,,,,,, -62600,0.20027883,1.6567776,,,,,,,,,,,,,,,,, -62700,0.19950743,1.6842188,,,,,,,,,,,,,,,,, -62800,0.22579269,1.6448532,,,,,,,,,,,,,,,,, -62900,0.21207036,1.7208425,,,,,,,,,,,,,,,,, -63000,0.25597453,1.7277753,,,,,,,,,,,,,,,,, -63100,0.17977287,1.6188565,,,,,,,,,,,,,,,,, -63200,0.1970549,1.7215298,,,,,,,,,,,,,,,,, -63300,0.19877347,1.6802378,,,,,,,,,,,,,,,,, -63400,0.21996349,1.7637193,,,,,,,,,,,,,,,,, -63500,0.21108976,1.7153533,,,,,,,,,,,,,,,,, -63600,0.18803349,1.668376,,,,,,,,,,,,,,,,, -63700,0.2180895,1.673912,,,,,,,,,,,,,,,,, -63800,0.20424768,1.7529054,,,,,,,,,,,,,,,,, -63900,0.24226835,1.7161436,,,,,,,,,,,,,,,,, -64000,0.19476156,1.7498065,,,,,,,,,,,,,,,,, -64100,0.20469318,1.6548946,,,,,,,,,,,,,,,,, -64200,0.22859052,1.7538116,,,,,,,,,,,,,,,,, -64300,0.27705663,1.792267,,,,,,,,,,,,,,,,, -64400,0.20596567,1.723329,,,,,,,,,,,,,,,,, -64500,0.20861968,1.7306298,,,,,,,,,,,,,,,,, -64540,,,0.6574546694755554,1.6062756776809692,32.33868366983711,0.6742879748344421,1.499944806098938,29.16130690093247,3000.0,0.6869560480117798,1.4149410724639893,29.01646062679263,3003.0,23560.25806093216,39405.44167017937,23560.25806093216,15842.44847869873,0.8506321907043457,0.0 -64600,0.22018312,1.7849963,,,,,,,,,,,,,,,,, -64700,0.1956993,1.5985395,,,,,,,,,,,,,,,,, -64800,0.20387127,1.6177552,,,,,,,,,,,,,,,,, -64900,0.20435461,1.6374146,,,,,,,,,,,,,,,,, -65000,0.20427041,1.6640741,,,,,,,,,,,,,,,,, -65100,0.19985646,1.6219746,,,,,,,,,,,,,,,,, -65200,0.21295927,1.7102686,,,,,,,,,,,,,,,,, -65300,0.2065419,1.6691686,,,,,,,,,,,,,,,,, -65400,0.193711,1.68275,,,,,,,,,,,,,,,,, -65500,0.22094949,1.742182,,,,,,,,,,,,,,,,, -65600,0.1854285,1.6966524,,,,,,,,,,,,,,,,, -65700,0.1850329,1.5780112,,,,,,,,,,,,,,,,, -65800,0.21035331,1.6689079,,,,,,,,,,,,,,,,, -65900,0.21714815,1.7291743,,,,,,,,,,,,,,,,, -66000,0.1926891,1.6406822,,,,,,,,,,,,,,,,, -66100,0.19488566,1.7171805,,,,,,,,,,,,,,,,, -66200,0.2235625,1.684792,,,,,,,,,,,,,,,,, -66300,0.1994467,1.6398981,,,,,,,,,,,,,,,,, -66400,0.19417974,1.5649301,,,,,,,,,,,,,,,,, -66500,0.18925399,1.5941502,,,,,,,,,,,,,,,,, -66600,0.2217431,1.734451,,,,,,,,,,,,,,,,, -66700,0.18717718,1.6666538,,,,,,,,,,,,,,,,, -66800,0.21821705,1.6636004,,,,,,,,,,,,,,,,, -66847,,,0.6589851975440979,1.6024383306503296,32.50040292950039,0.675317108631134,1.4933512210845947,29.14662722173597,3000.0,0.6883156299591064,1.4015823602676392,28.97097313380021,3003.0,24400.42298603058,40783.24477171898,24400.42298603058,16379.987210035324,0.8830595016479492,0.0 -66900,0.19734712,1.6550617,,,,,,,,,,,,,,,,, -67000,0.20140883,1.7163193,,,,,,,,,,,,,,,,, -67100,0.20762752,1.6939579,,,,,,,,,,,,,,,,, -67200,0.20502587,1.6801497,,,,,,,,,,,,,,,,, -67300,0.21704225,1.7018176,,,,,,,,,,,,,,,,, -67400,0.2189018,1.682886,,,,,,,,,,,,,,,,, -67500,0.2140187,1.5870291,,,,,,,,,,,,,,,,, -67600,0.19976375,1.6030908,,,,,,,,,,,,,,,,, -67700,0.1955638,1.6938437,,,,,,,,,,,,,,,,, -67800,0.19399078,1.6931682,,,,,,,,,,,,,,,,, -67900,0.2090459,1.7727486,,,,,,,,,,,,,,,,, -68000,0.2889842,1.6349194,,,,,,,,,,,,,,,,, -68100,0.2294471,1.6712668,,,,,,,,,,,,,,,,, -68200,0.21018513,1.7894498,,,,,,,,,,,,,,,,, -68300,0.2115702,1.6956959,,,,,,,,,,,,,,,,, -68400,0.1888846,1.638264,,,,,,,,,,,,,,,,, -68500,0.21215397,1.6847378,,,,,,,,,,,,,,,,, -68600,0.19827043,1.6821431,,,,,,,,,,,,,,,,, -68700,0.2077332,1.7306811,,,,,,,,,,,,,,,,, -68800,0.23521484,1.7223555,,,,,,,,,,,,,,,,, -68900,0.19482157,1.6595467,,,,,,,,,,,,,,,,, -69000,0.20278314,1.644595,,,,,,,,,,,,,,,,, -69100,0.21168491,1.7171823,,,,,,,,,,,,,,,,, -69153,,,0.6765788197517395,1.4881170988082886,33.541029229499564,0.676656186580658,1.4876593351364136,29.058710968380737,3000.0,0.689233660697937,1.3974084854125977,29.031414589198896,3003.0,25240.656269311905,42234.59798383713,25240.656269311905,16991.00972509384,0.9139461517333984,0.0 -69200,0.23263429,1.7389927,,,,,,,,,,,,,,,,, -69300,0.21880893,1.6543081,,,,,,,,,,,,,,,,, -69400,0.20461844,1.6294072,,,,,,,,,,,,,,,,, -69500,0.2046605,1.6659364,,,,,,,,,,,,,,,,, -69600,0.21734259,1.6769112,,,,,,,,,,,,,,,,, -69700,0.21622966,1.6416333,,,,,,,,,,,,,,,,, -69800,0.20544505,1.6322263,,,,,,,,,,,,,,,,, -69900,0.2011793,1.7015777,,,,,,,,,,,,,,,,, -70000,0.19648671,1.5805947,,,,,,,,,,,,,,,,, -70100,0.19895397,1.6774431,,,,,,,,,,,,,,,,, -70200,0.19854422,1.6234609,,,,,,,,,,,,,,,,, -70300,0.2461557,1.6332424,,,,,,,,,,,,,,,,, -70400,0.23089476,1.6966465,,,,,,,,,,,,,,,,, -70500,0.20132717,1.6518743,,,,,,,,,,,,,,,,, -70600,0.19552289,1.7318206,,,,,,,,,,,,,,,,, -70700,0.22793499,1.693096,,,,,,,,,,,,,,,,, -70800,0.20536529,1.6338233,,,,,,,,,,,,,,,,, -70900,0.21686164,1.7040668,,,,,,,,,,,,,,,,, -71000,0.20811717,1.678328,,,,,,,,,,,,,,,,, -71100,0.20172071,1.6464659,,,,,,,,,,,,,,,,, -71200,0.22793084,1.6433566,,,,,,,,,,,,,,,,, -71300,0.20358461,1.711304,,,,,,,,,,,,,,,,, -71400,0.20666914,1.6640965,,,,,,,,,,,,,,,,, -71459,,,0.665115475654602,1.5632176399230957,32.916180646277745,0.6775985360145569,1.4751025438308716,29.321900177988866,3000.0,0.6905351281166077,1.3871610164642334,28.987678393285474,3003.0,26080.56746864319,43714.0570063591,26080.56746864319,17630.459098100662,0.9451935291290284,0.0 -71500,0.19655554,1.7093782,,,,,,,,,,,,,,,,, -71600,0.19366397,1.6944374,,,,,,,,,,,,,,,,, -71700,0.2675658,1.6563418,,,,,,,,,,,,,,,,, -71800,0.20108525,1.6621352,,,,,,,,,,,,,,,,, -71900,0.1848118,1.6434007,,,,,,,,,,,,,,,,, -72000,0.20145418,1.6338449,,,,,,,,,,,,,,,,, -72100,0.21032049,1.6573876,,,,,,,,,,,,,,,,, -72200,0.2113841,1.6654681,,,,,,,,,,,,,,,,, -72300,0.22918625,1.6880151,,,,,,,,,,,,,,,,, -72400,0.23122808,1.5942497,,,,,,,,,,,,,,,,, -72500,0.21342206,1.6484617,,,,,,,,,,,,,,,,, -72600,0.20622478,1.6425643,,,,,,,,,,,,,,,,, -72700,0.20275958,1.6114259,,,,,,,,,,,,,,,,, -72800,0.2198327,1.7171328,,,,,,,,,,,,,,,,, -72900,0.2078231,1.6221507,,,,,,,,,,,,,,,,, -73000,0.21790892,1.6626979,,,,,,,,,,,,,,,,, -73100,0.22496034,1.6482825,,,,,,,,,,,,,,,,, -73200,0.21073277,1.6302788,,,,,,,,,,,,,,,,, -73300,0.25487387,1.6645746,,,,,,,,,,,,,,,,, -73400,0.22751306,1.7401402,,,,,,,,,,,,,,,,, -73500,0.19239156,1.5707176,,,,,,,,,,,,,,,,, -73600,0.2421986,1.6739408,,,,,,,,,,,,,,,,, -73700,0.19346613,1.6659933,,,,,,,,,,,,,,,,, -73765,,,0.6569005846977234,1.6199302673339844,33.01953946347801,0.6786896586418152,1.4717580080032349,29.54283922963444,3000.0,0.6927546262741089,1.381793975830078,28.91781925583517,3003.0,26920.736166715626,45117.89793419838,26920.736166715626,18194.033081531525,0.9769787788391112,0.0 -73800,0.20340522,1.6761031,,,,,,,,,,,,,,,,, -73900,0.23760559,1.6377791,,,,,,,,,,,,,,,,, -74000,0.21688972,1.6797917,,,,,,,,,,,,,,,,, -74100,0.22506346,1.6786679,,,,,,,,,,,,,,,,, -74200,0.21553205,1.6878108,,,,,,,,,,,,,,,,, -74300,0.21475421,1.6495723,,,,,,,,,,,,,,,,, -74400,0.20245218,1.6367568,,,,,,,,,,,,,,,,, -74500,0.2246441,1.6779597,,,,,,,,,,,,,,,,, -74600,0.20787244,1.6173165,,,,,,,,,,,,,,,,, -74700,0.2156998,1.6642752,,,,,,,,,,,,,,,,, -74800,1.6292732,1.6771742,,,,,,,,,,,,,,,,, -74900,0.20084372,1.6229336,,,,,,,,,,,,,,,,, -75000,0.26123625,1.599568,,,,,,,,,,,,,,,,, -75100,0.19165337,1.6559373,,,,,,,,,,,,,,,,, -75200,0.2164001,1.6746408,,,,,,,,,,,,,,,,, -75300,0.21228206,1.6375865,,,,,,,,,,,,,,,,, -75400,0.20552558,1.6205394,,,,,,,,,,,,,,,,, -75500,0.18867649,1.6096563,,,,,,,,,,,,,,,,, -75600,0.2023422,1.6542343,,,,,,,,,,,,,,,,, -75700,0.19918066,1.7163768,,,,,,,,,,,,,,,,, -75800,0.22177994,1.7392048,,,,,,,,,,,,,,,,, -75900,0.21526244,1.7008288,,,,,,,,,,,,,,,,, -76000,0.21037604,1.6362296,,,,,,,,,,,,,,,,, -76071,,,0.6729021668434143,1.506928563117981,33.29321228860028,0.6776605248451233,1.4689043760299685,29.190714508145124,3000.0,0.6914531588554382,1.3808372020721436,28.920786485322584,3003.0,27760.940415620804,46509.46474575997,27760.940415620804,18745.29731321335,1.009082555770874,0.0 -76100,0.19836701,1.5759238,,,,,,,,,,,,,,,,, -76200,0.21220548,1.6415067,,,,,,,,,,,,,,,,, -76300,0.1929588,1.6062329,,,,,,,,,,,,,,,,, -76400,0.21513681,1.7371083,,,,,,,,,,,,,,,,, -76500,0.20463045,1.6183604,,,,,,,,,,,,,,,,, -76600,0.20265594,1.6023382,,,,,,,,,,,,,,,,, -76700,0.21425855,1.6576807,,,,,,,,,,,,,,,,, -76800,0.20925003,1.7484372,,,,,,,,,,,,,,,,, -76900,0.20745,1.640087,,,,,,,,,,,,,,,,, -77000,0.21385309,1.6570843,,,,,,,,,,,,,,,,, -77100,0.21090765,1.6732688,,,,,,,,,,,,,,,,, -77200,0.22184072,1.6142185,,,,,,,,,,,,,,,,, -77300,0.21533002,1.6508497,,,,,,,,,,,,,,,,, -77400,0.19966076,1.603609,,,,,,,,,,,,,,,,, -77500,0.21840481,1.6366963,,,,,,,,,,,,,,,,, -77600,0.20806305,1.6472843,,,,,,,,,,,,,,,,, -77700,0.20940265,1.6300527,,,,,,,,,,,,,,,,, -77800,0.22886047,1.6946323,,,,,,,,,,,,,,,,, -77900,0.20224781,1.6635875,,,,,,,,,,,,,,,,, -78000,0.20822378,1.6582389,,,,,,,,,,,,,,,,, -78100,0.2113013,1.6505507,,,,,,,,,,,,,,,,, -78200,0.2189721,1.7107676,,,,,,,,,,,,,,,,, -78300,0.20976594,1.6336496,,,,,,,,,,,,,,,,, -78377,,,0.6650096774101257,1.5661970376968384,33.123997561018435,0.6799543499946594,1.4611165523529053,29.59472036218692,3000.0,0.6954041123390198,1.366335391998291,29.17774261731953,3003.0,28600.997208356857,48044.48940658569,28600.997208356857,19440.16348552704,1.0446865558624268,0.0 -78400,0.20609424,1.5514709,,,,,,,,,,,,,,,,, -78500,0.19863433,1.6385881,,,,,,,,,,,,,,,,, -78600,0.21303497,1.7163092,,,,,,,,,,,,,,,,, -78700,0.21632549,1.6071746,,,,,,,,,,,,,,,,, -78800,0.22481765,1.660467,,,,,,,,,,,,,,,,, -78900,0.19273068,1.699265,,,,,,,,,,,,,,,,, -79000,0.2081595,1.5679435,,,,,,,,,,,,,,,,, -79100,0.223007,1.6086843,,,,,,,,,,,,,,,,, -79200,0.20472558,1.6727351,,,,,,,,,,,,,,,,, -79300,0.20359592,1.5759485,,,,,,,,,,,,,,,,, -79400,0.23479693,1.6834913,,,,,,,,,,,,,,,,, -79500,0.19333397,1.6961855,,,,,,,,,,,,,,,,, -79600,0.21544741,1.6196655,,,,,,,,,,,,,,,,, -79700,0.19898143,1.722031,,,,,,,,,,,,,,,,, -79800,0.23088233,1.7091943,,,,,,,,,,,,,,,,, -79900,0.2037375,1.6225389,,,,,,,,,,,,,,,,, -80000,0.2210052,1.6507144,,,,,,,,,,,,,,,,, -80100,0.20743994,1.619804,,,,,,,,,,,,,,,,, -80200,0.2107366,1.5749642,,,,,,,,,,,,,,,,, -80300,0.20365445,1.7061816,,,,,,,,,,,,,,,,, -80400,0.21939671,1.6729563,,,,,,,,,,,,,,,,, -80500,0.21574248,1.643286,,,,,,,,,,,,,,,,, -80600,0.23083477,1.6139153,,,,,,,,,,,,,,,,, -80684,,,0.66160649061203,1.584054946899414,33.317098467297384,0.6825581789016724,1.45094633102417,29.821797940794205,3000.0,0.6970890760421753,1.3583707809448242,29.614426414284896,3003.0,29441.037866830826,49467.49631810188,29441.037866830826,20023.029673337936,1.0782907009124756,0.0 -80700,0.20196883,1.5804815,,,,,,,,,,,,,,,,, -80800,0.20211759,1.6075016,,,,,,,,,,,,,,,,, -80900,0.20350486,1.6979226,,,,,,,,,,,,,,,,, -81000,0.19705689,1.53482,,,,,,,,,,,,,,,,, -81100,0.18646282,1.6123623,,,,,,,,,,,,,,,,, -81200,0.21791394,1.6037323,,,,,,,,,,,,,,,,, -81300,0.21687497,1.6776688,,,,,,,,,,,,,,,,, -81400,0.21652853,1.6841007,,,,,,,,,,,,,,,,, -81500,0.2059576,1.6012634,,,,,,,,,,,,,,,,, -81600,0.21880232,1.6704928,,,,,,,,,,,,,,,,, -81700,0.19778176,1.554482,,,,,,,,,,,,,,,,, -81800,0.2130788,1.6580482,,,,,,,,,,,,,,,,, -81900,0.19198811,1.6375681,,,,,,,,,,,,,,,,, -82000,0.25226924,1.5865226,,,,,,,,,,,,,,,,, -82100,0.22239159,1.614553,,,,,,,,,,,,,,,,, -82200,0.20437308,1.5947675,,,,,,,,,,,,,,,,, -82300,0.21364151,1.6618787,,,,,,,,,,,,,,,,, -82400,0.22253801,1.6453292,,,,,,,,,,,,,,,,, -82500,0.19105498,1.5950954,,,,,,,,,,,,,,,,, -82600,0.2056042,1.6565312,,,,,,,,,,,,,,,,, -82700,0.20683806,1.6684157,,,,,,,,,,,,,,,,, -82800,0.25321388,1.6476489,,,,,,,,,,,,,,,,, -82900,0.20577037,1.6114286,,,,,,,,,,,,,,,,, -82990,,,0.6741752624511719,1.5034538507461548,33.5210667999066,0.6817522048950195,1.444745421409607,29.79512560255042,3000.0,0.6958457231521606,1.355960488319397,29.73206205813485,3003.0,30281.05094337464,50895.96606874466,30281.05094337464,20611.38403081894,1.1141114234924316,0.0 -83000,0.21025997,1.5527805,,,,,,,,,,,,,,,,, -83100,0.19338109,1.5747255,,,,,,,,,,,,,,,,, -83200,0.20710103,1.5916741,,,,,,,,,,,,,,,,, -83300,0.22544663,1.6895511,,,,,,,,,,,,,,,,, -83400,0.21185666,1.5880212,,,,,,,,,,,,,,,,, -83500,0.20257327,1.6301911,,,,,,,,,,,,,,,,, -83600,0.22088273,1.6285506,,,,,,,,,,,,,,,,, -83700,0.20789024,1.6530417,,,,,,,,,,,,,,,,, -83800,0.20798403,1.6062573,,,,,,,,,,,,,,,,, -83900,0.20430504,1.6288877,,,,,,,,,,,,,,,,, -84000,0.21821532,1.6569384,,,,,,,,,,,,,,,,, -84100,0.20823891,1.5993414,,,,,,,,,,,,,,,,, -84200,0.20264865,1.600827,,,,,,,,,,,,,,,,, -84300,0.19433425,1.5464745,,,,,,,,,,,,,,,,, -84400,0.2266402,1.597581,,,,,,,,,,,,,,,,, -84500,0.20657587,1.5704049,,,,,,,,,,,,,,,,, -84600,0.20959415,1.5651904,,,,,,,,,,,,,,,,, -84700,0.23389922,1.6051105,,,,,,,,,,,,,,,,, -84800,0.20175952,1.6594667,,,,,,,,,,,,,,,,, -84900,0.19209376,1.5084672,,,,,,,,,,,,,,,,, -85000,0.20966496,1.5881037,,,,,,,,,,,,,,,,, -85100,0.2159229,1.502844,,,,,,,,,,,,,,,,, -85200,0.21195588,1.5785438,,,,,,,,,,,,,,,,, -85296,,,0.6712327003479004,1.5273113250732422,33.1370007701296,0.6831037402153015,1.4382604360580444,29.524563702331893,3000.0,0.6988786458969116,1.3426960706710815,29.391099964317444,3003.0,31121.205088377,52369.14261388779,31121.205088377,21244.30550432205,1.1484274864196775,0.0 -85300,0.206991,1.6516097,,,,,,,,,,,,,,,,, -85400,0.21630162,1.6166462,,,,,,,,,,,,,,,,, -85500,0.198949,1.5495266,,,,,,,,,,,,,,,,, -85600,0.1979535,1.6139523,,,,,,,,,,,,,,,,, -85700,0.26673082,1.6321505,,,,,,,,,,,,,,,,, -85800,0.22066408,1.5814049,,,,,,,,,,,,,,,,, -85900,0.1974829,1.5814646,,,,,,,,,,,,,,,,, -86000,0.20810679,1.6124687,,,,,,,,,,,,,,,,, -86100,0.20799454,1.5854715,,,,,,,,,,,,,,,,, -86200,0.22020379,1.6466454,,,,,,,,,,,,,,,,, -86300,0.20677052,1.6584322,,,,,,,,,,,,,,,,, -86400,0.21408531,1.567904,,,,,,,,,,,,,,,,, -86500,0.22030927,1.6778443,,,,,,,,,,,,,,,,, -86600,0.22362638,1.6196319,,,,,,,,,,,,,,,,, -86700,0.20751013,1.5781603,,,,,,,,,,,,,,,,, -86800,0.20922062,1.5585189,,,,,,,,,,,,,,,,, -86900,0.21804519,1.6194826,,,,,,,,,,,,,,,,, -87000,0.21605621,1.5804571,,,,,,,,,,,,,,,,, -87100,0.22909413,1.5773677,,,,,,,,,,,,,,,,, -87200,0.2038158,1.5413936,,,,,,,,,,,,,,,,, -87300,0.2032545,1.6228807,,,,,,,,,,,,,,,,, -87400,0.22453798,1.640661,,,,,,,,,,,,,,,,, -87500,0.21019298,1.4960618,,,,,,,,,,,,,,,,, -87600,0.2181963,1.6542664,,,,,,,,,,,,,,,,, -87601,,,0.7025133371353149,1.3403351306915283,36.16159872807506,0.6858935356140137,1.4293599128723145,30.037808770372497,3000.0,0.7010400295257568,1.3378076553344729,29.907318585279327,3003.0,31961.25281882286,53790.85534501076,31961.25281882286,21825.86384224892,1.1863529682159424,0.0 -87700,0.22014281,1.7069452,,,,,,,,,,,,,,,,, -87800,0.20252812,1.4849428,,,,,,,,,,,,,,,,, -87900,0.21786048,1.6448405,,,,,,,,,,,,,,,,, -88000,0.20435709,1.6310037,,,,,,,,,,,,,,,,, -88100,0.21581079,1.5684881,,,,,,,,,,,,,,,,, -88200,0.22241655,1.5587554,,,,,,,,,,,,,,,,, -88300,0.19349688,1.568634,,,,,,,,,,,,,,,,, -88400,0.20872965,1.6603007,,,,,,,,,,,,,,,,, -88500,0.21435091,1.614987,,,,,,,,,,,,,,,,, -88600,0.20126845,1.5877863,,,,,,,,,,,,,,,,, -88700,0.22555287,1.6022781,,,,,,,,,,,,,,,,, -88800,0.25178793,1.6164612,,,,,,,,,,,,,,,,, -88900,0.2007361,1.5435481,,,,,,,,,,,,,,,,, -89000,0.2231687,1.5810484,,,,,,,,,,,,,,,,, -89100,0.22553109,1.6011357,,,,,,,,,,,,,,,,, -89200,0.22046989,1.6317483,,,,,,,,,,,,,,,,, -89300,0.20938523,1.5875254,,,,,,,,,,,,,,,,, -89400,0.3239607,1.529159,,,,,,,,,,,,,,,,, -89500,0.21893309,1.5336947,,,,,,,,,,,,,,,,, -89600,0.22398143,1.6051388,,,,,,,,,,,,,,,,, -89700,0.21339105,1.6017053,,,,,,,,,,,,,,,,, -89800,0.20971425,1.531111,,,,,,,,,,,,,,,,, -89900,0.21212551,1.5334992,,,,,,,,,,,,,,,,, -89907,,,0.6764506101608276,1.486994981765747,33.920997173484295,0.6865506768226624,1.4231069087982178,30.330018269000004,3000.0,0.7012608647346497,1.3300431966781616,29.77899500255573,3003.0,32801.215609788895,55294.24873948097,32801.215609788895,22489.192754983906,1.2217700481414795,0.0 -90000,0.20571165,1.6329259,,,,,,,,,,,,,,,,, -90100,0.21081194,1.5589819,,,,,,,,,,,,,,,,, -90200,0.2152712,1.6198785,,,,,,,,,,,,,,,,, -90300,0.2292551,1.5137558,,,,,,,,,,,,,,,,, -90400,0.29771212,1.5527316,,,,,,,,,,,,,,,,, -90500,0.20873861,1.5833966,,,,,,,,,,,,,,,,, -90600,0.20286323,1.6131797,,,,,,,,,,,,,,,,, -90700,0.20329037,1.5157545,,,,,,,,,,,,,,,,, -90800,0.24590006,1.5867116,,,,,,,,,,,,,,,,, -90900,0.22587125,1.6014915,,,,,,,,,,,,,,,,, -91000,0.21404825,1.5307071,,,,,,,,,,,,,,,,, -91100,0.21976902,1.5472933,,,,,,,,,,,,,,,,, -91200,0.22389549,1.6544827,,,,,,,,,,,,,,,,, -91300,0.22375514,1.5648925,,,,,,,,,,,,,,,,, -91400,0.20510873,1.5673832,,,,,,,,,,,,,,,,, -91500,0.20955755,1.5866768,,,,,,,,,,,,,,,,, -91600,0.20507742,1.5208921,,,,,,,,,,,,,,,,, -91700,0.20985447,1.6175605,,,,,,,,,,,,,,,,, -91800,0.22386862,1.5704323,,,,,,,,,,,,,,,,, -91900,0.21330763,1.5397519,,,,,,,,,,,,,,,,, -92000,0.2237701,1.6174959,,,,,,,,,,,,,,,,, -92100,0.2141318,1.63981,,,,,,,,,,,,,,,,, -92200,0.20509678,1.5443975,,,,,,,,,,,,,,,,, -92214,,,0.6773262619972229,1.4910597801208496,33.9196685232179,0.6862035393714905,1.423099160194397,30.39234321071231,3000.0,0.7024693489074707,1.323116898536682,29.70131795830228,3003.0,33641.39935684204,56826.61493492127,33641.39935684204,23181.273468256,1.2571892738342283,0.0 -92300,0.22337836,1.5589938,,,,,,,,,,,,,,,,, -92400,0.21836847,1.622348,,,,,,,,,,,,,,,,, -92500,0.22515231,1.5915067,,,,,,,,,,,,,,,,, -92600,0.2180278,1.5892041,,,,,,,,,,,,,,,,, -92700,0.22460134,1.6499083,,,,,,,,,,,,,,,,, -92800,0.20931704,1.5457495,,,,,,,,,,,,,,,,, -92900,0.22072543,1.6912316,,,,,,,,,,,,,,,,, -93000,0.23653723,1.5976981,,,,,,,,,,,,,,,,, -93100,0.24783075,1.6597443,,,,,,,,,,,,,,,,, -93200,0.2084919,1.5908849,,,,,,,,,,,,,,,,, -93300,0.20739767,1.6338831,,,,,,,,,,,,,,,,, -93400,0.22286336,1.5833225,,,,,,,,,,,,,,,,, -93500,0.23373492,1.5515386,,,,,,,,,,,,,,,,, -93600,0.20971261,1.6006504,,,,,,,,,,,,,,,,, -93700,0.22980267,1.5949634,,,,,,,,,,,,,,,,, -93800,0.2185176,1.5226836,,,,,,,,,,,,,,,,, -93900,0.27020863,1.6133286,,,,,,,,,,,,,,,,, -94000,0.2171331,1.654913,,,,,,,,,,,,,,,,, -94100,0.20824069,1.5196699,,,,,,,,,,,,,,,,, -94200,0.21570706,1.5724937,,,,,,,,,,,,,,,,, -94300,0.22043574,1.5879639,,,,,,,,,,,,,,,,, -94400,0.2256648,1.6256078,,,,,,,,,,,,,,,,, -94500,0.21425715,1.5938276,,,,,,,,,,,,,,,,, -94520,,,0.6835636496543884,1.435005784034729,34.798970108395004,0.6891916990280151,1.411435604095459,30.261402817179807,3000.0,0.7016210556030273,1.318655252456665,29.794721474995853,3003.0,34481.39404511452,58157.11773443222,34481.39404511452,23671.67979907989,1.292123556137085,0.0 -94600,0.22979282,1.560348,,,,,,,,,,,,,,,,, -94700,0.22135743,1.5728686,,,,,,,,,,,,,,,,, -94800,0.21373643,1.5594339,,,,,,,,,,,,,,,,, -94900,0.22146983,1.5655047,,,,,,,,,,,,,,,,, -95000,0.21282038,1.576276,,,,,,,,,,,,,,,,, -95100,0.22154124,1.5991527,,,,,,,,,,,,,,,,, -95200,0.22637606,1.5889518,,,,,,,,,,,,,,,,, -95300,0.21843873,1.5626676,,,,,,,,,,,,,,,,, -95400,0.21066429,1.5118556,,,,,,,,,,,,,,,,, -95500,0.22937942,1.6023535,,,,,,,,,,,,,,,,, -95600,0.20856221,1.5579232,,,,,,,,,,,,,,,,, -95700,0.22312056,1.6750863,,,,,,,,,,,,,,,,, -95800,0.23383172,1.6816,,,,,,,,,,,,,,,,, -95900,0.2408482,1.620601,,,,,,,,,,,,,,,,, -96000,0.20362802,1.4688475,,,,,,,,,,,,,,,,, -96100,0.21001874,1.5624604,,,,,,,,,,,,,,,,, -96200,0.2208119,1.6060529,,,,,,,,,,,,,,,,, -96300,0.23338516,1.5809748,,,,,,,,,,,,,,,,, -96400,0.21896908,1.6339743,,,,,,,,,,,,,,,,, -96500,0.23627548,1.5549022,,,,,,,,,,,,,,,,, -96600,0.2163361,1.5978218,,,,,,,,,,,,,,,,, -96700,0.21622008,1.6777024,,,,,,,,,,,,,,,,, -96800,0.21428621,1.5694202,,,,,,,,,,,,,,,,, -96826,,,0.6781773567199707,1.476668119430542,34.505094291513295,0.6873690485954285,1.410524845123291,30.01696468775817,3000.0,0.7032014727592468,1.310566782951355,29.931591119075662,3003.0,35321.29417848587,59776.336155653,35321.29417848587,24450.89505577088,1.328049898147583,0.0 -96900,0.22414264,1.5653425,,,,,,,,,,,,,,,,, -97000,0.2222563,1.6337464,,,,,,,,,,,,,,,,, -97100,0.20322016,1.5698934,,,,,,,,,,,,,,,,, -97200,0.21138789,1.5767473,,,,,,,,,,,,,,,,, -97300,0.2458223,1.5330758,,,,,,,,,,,,,,,,, -97400,0.21625444,1.572167,,,,,,,,,,,,,,,,, -97500,0.2415617,1.6072123,,,,,,,,,,,,,,,,, -97600,0.21045369,1.5249534,,,,,,,,,,,,,,,,, -97700,0.21391612,1.4916883,,,,,,,,,,,,,,,,, -97800,0.22171056,1.5569265,,,,,,,,,,,,,,,,, -97900,0.21514735,1.5848203,,,,,,,,,,,,,,,,, -98000,0.22036324,1.4735366,,,,,,,,,,,,,,,,, -98100,0.21492384,1.5595471,,,,,,,,,,,,,,,,, -98200,0.21206276,1.5792668,,,,,,,,,,,,,,,,, -98300,0.21918896,1.5753046,,,,,,,,,,,,,,,,, -98400,0.22164436,1.5979611,,,,,,,,,,,,,,,,, -98500,0.20891817,1.5052557,,,,,,,,,,,,,,,,, -98600,0.20914648,1.5293747,,,,,,,,,,,,,,,,, -98700,0.23397815,1.5304111,,,,,,,,,,,,,,,,, -98800,0.22023888,1.5815002,,,,,,,,,,,,,,,,, -98900,0.22855699,1.5532279,,,,,,,,,,,,,,,,, -99000,0.21543446,1.5595438,,,,,,,,,,,,,,,,, -99100,0.21859674,1.612536,,,,,,,,,,,,,,,,, -99133,,,0.6766044497489929,1.4872363805770874,34.415588200909774,0.687480628490448,1.401458740234375,30.11480655939417,3000.0,0.7041775584220886,1.3046365976333618,30.258098751022786,3003.0,36161.29120993614,61157.28055071831,36161.29120993614,24991.73729467392,1.3659710884094238,0.0 -99200,0.20916343,1.5188596,,,,,,,,,,,,,,,,, -99300,0.22162138,1.5472302,,,,,,,,,,,,,,,,, -99400,0.22225125,1.5166655,,,,,,,,,,,,,,,,, -99500,0.2292963,1.4869522,,,,,,,,,,,,,,,,, -99600,0.21841766,1.4940175,,,,,,,,,,,,,,,,, -99700,0.2272652,1.4759815,,,,,,,,,,,,,,,,, -99800,0.22249934,1.5527653,,,,,,,,,,,,,,,,, -99900,0.20759821,1.5396899,,,,,,,,,,,,,,,,, -100000,0.22972852,1.6161957,,,,,,,,,,,,,,,,, -100100,0.2203855,1.612844,,,,,,,,,,,,,,,,, -100200,0.22807309,1.5325102,,,,,,,,,,,,,,,,, -100300,0.21018022,1.5554246,,,,,,,,,,,,,,,,, -100400,0.22128749,1.5778172,,,,,,,,,,,,,,,,, -100500,0.2336957,1.5472198,,,,,,,,,,,,,,,,, -100600,0.21949981,1.5341747,,,,,,,,,,,,,,,,, -100700,0.22125289,1.5531026,,,,,,,,,,,,,,,,, -100800,0.24577016,1.5268896,,,,,,,,,,,,,,,,, -100900,0.22056788,1.5899519,,,,,,,,,,,,,,,,, -101000,0.23135039,1.5630087,,,,,,,,,,,,,,,,, -101100,0.2227387,1.5744953,,,,,,,,,,,,,,,,, -101200,0.2397934,1.5748622,,,,,,,,,,,,,,,,, -101300,0.21371552,1.5345467,,,,,,,,,,,,,,,,, -101400,0.22264636,1.6136853,,,,,,,,,,,,,,,,, -101439,,,0.6870901584625244,1.424283742904663,34.776603339122985,0.6906920075416565,1.3949700593948364,30.415153484049604,3000.0,0.7062227725982666,1.295928120613098,30.72127951020204,3003.0,37001.29831576347,62585.88406729698,37001.29831576347,25580.229484319687,1.402970314025879,0.0 -101500,0.24280366,1.5108352,,,,,,,,,,,,,,,,, -101600,0.22542675,1.527254,,,,,,,,,,,,,,,,, -101700,0.23569588,1.4812106,,,,,,,,,,,,,,,,, -101800,0.22137527,1.540759,,,,,,,,,,,,,,,,, -101900,0.2137904,1.6118759,,,,,,,,,,,,,,,,, -102000,0.20035312,1.5221903,,,,,,,,,,,,,,,,, -102100,0.21768478,1.495189,,,,,,,,,,,,,,,,, -102200,0.21576361,1.4814037,,,,,,,,,,,,,,,,, -102300,0.23869574,1.5081502,,,,,,,,,,,,,,,,, -102400,0.22198245,1.5299066,,,,,,,,,,,,,,,,, -102500,0.23314331,1.4704294,,,,,,,,,,,,,,,,, -102600,0.21766028,1.5205027,,,,,,,,,,,,,,,,, -102700,0.2807365,1.5409418,,,,,,,,,,,,,,,,, -102800,0.22413133,1.5351415,,,,,,,,,,,,,,,,, -102900,0.23639968,1.534573,,,,,,,,,,,,,,,,, -103000,0.22653769,1.4811825,,,,,,,,,,,,,,,,, -103100,0.22116265,1.5595212,,,,,,,,,,,,,,,,, -103200,0.22477683,1.5249375,,,,,,,,,,,,,,,,, -103300,0.24188961,1.5725942,,,,,,,,,,,,,,,,, -103400,0.22106205,1.5738869,,,,,,,,,,,,,,,,, -103500,0.22037777,1.5018941,,,,,,,,,,,,,,,,, -103600,0.22986525,1.5360913,,,,,,,,,,,,,,,,, -103700,0.22481261,1.5286046,,,,,,,,,,,,,,,,, -103746,,,0.6832632422447205,1.444430947303772,34.29670099865465,0.6907168030738831,1.3926249742507937,30.255927333565843,3000.0,0.7068967819213867,1.29006826877594,30.42128479296152,3003.0,37841.34276676178,63981.060650110245,37841.34276676178,26135.257561206818,1.4406273365020752,0.0 -103800,0.2209895,1.4828193,,,,,,,,,,,,,,,,, -103900,0.22380695,1.546403,,,,,,,,,,,,,,,,, -104000,0.23358779,1.5710534,,,,,,,,,,,,,,,,, -104100,0.2163566,1.5452713,,,,,,,,,,,,,,,,, -104200,0.24000649,1.5939951,,,,,,,,,,,,,,,,, -104300,0.22339346,1.5733947,,,,,,,,,,,,,,,,, -104400,0.24459854,1.5587897,,,,,,,,,,,,,,,,, -104500,0.22568488,1.5319518,,,,,,,,,,,,,,,,, -104600,0.22336514,1.5111798,,,,,,,,,,,,,,,,, -104700,0.23068047,1.5076241,,,,,,,,,,,,,,,,, -104800,0.22060145,1.4871619,,,,,,,,,,,,,,,,, -104900,0.22385506,1.5000918,,,,,,,,,,,,,,,,, -105000,0.22617264,1.5057124,,,,,,,,,,,,,,,,, -105100,0.23377995,1.4916012,,,,,,,,,,,,,,,,, -105200,0.22994496,1.5802081,,,,,,,,,,,,,,,,, -105300,0.22726126,1.519307,,,,,,,,,,,,,,,,, -105400,0.22617094,1.5798061,,,,,,,,,,,,,,,,, -105500,0.24387562,1.4714798,,,,,,,,,,,,,,,,, -105600,0.22340712,1.5097235,,,,,,,,,,,,,,,,, -105700,0.2256579,1.4722371,,,,,,,,,,,,,,,,, -105800,0.22310422,1.5136285,,,,,,,,,,,,,,,,, -105900,0.22774696,1.5640379,,,,,,,,,,,,,,,,, -106000,0.2358454,1.567882,,,,,,,,,,,,,,,,, -106052,,,0.6834360361099243,1.44869065284729,34.66426654290793,0.692055881023407,1.3876729011535645,30.47464467803755,3000.0,0.7068037986755371,1.2878295183181765,30.07426208954982,3003.0,38681.39353013039,65369.194038152695,38681.39353013039,26683.23649024964,1.4774634838104248,0.0 -106100,0.21858886,1.480781,,,,,,,,,,,,,,,,, -106200,0.22391447,1.5382942,,,,,,,,,,,,,,,,, -106300,0.23556995,1.58821,,,,,,,,,,,,,,,,, -106400,0.22690819,1.5906006,,,,,,,,,,,,,,,,, -106500,0.21912754,1.4770684,,,,,,,,,,,,,,,,, -106600,0.22115256,1.5149835,,,,,,,,,,,,,,,,, -106700,0.22120896,1.5129337,,,,,,,,,,,,,,,,, -106800,0.21655205,1.5206223,,,,,,,,,,,,,,,,, -106900,0.23013918,1.3770502,,,,,,,,,,,,,,,,, -107000,0.22315781,1.5295185,,,,,,,,,,,,,,,,, -107100,0.25218076,1.5711684,,,,,,,,,,,,,,,,, -107200,0.22401579,1.4829857,,,,,,,,,,,,,,,,, -107300,0.22977519,1.515583,,,,,,,,,,,,,,,,, -107400,0.23391557,1.5099645,,,,,,,,,,,,,,,,, -107500,0.23841725,1.4929684,,,,,,,,,,,,,,,,, -107600,0.22049722,1.5084372,,,,,,,,,,,,,,,,, -107700,0.21494357,1.4879338,,,,,,,,,,,,,,,,, -107800,0.22460791,1.5227494,,,,,,,,,,,,,,,,, -107900,0.23156068,1.5226427,,,,,,,,,,,,,,,,, -108000,0.22603554,1.5478164,,,,,,,,,,,,,,,,, -108100,0.22831817,1.5596371,,,,,,,,,,,,,,,,, -108200,0.22313674,1.5197378,,,,,,,,,,,,,,,,, -108300,0.22860254,1.5401227,,,,,,,,,,,,,,,,, -108357,,,0.6897974610328674,1.4114888906478882,35.04683613009326,0.6923534870147705,1.3850077390670776,30.593538745220897,3000.0,0.7074894309043884,1.2821885347366333,30.3378661903933,3003.0,39521.30562400818,66819.13472270966,39521.30562400818,27293.159260988235,1.5152764320373535,0.0 -108400,0.2308269,1.538604,,,,,,,,,,,,,,,,, -108500,0.22224537,1.472747,,,,,,,,,,,,,,,,, -108600,0.23465009,1.5592258,,,,,,,,,,,,,,,,, -108700,0.23989174,1.5097529,,,,,,,,,,,,,,,,, -108800,0.22619437,1.5055997,,,,,,,,,,,,,,,,, -108900,0.22706486,1.4782374,,,,,,,,,,,,,,,,, -109000,0.22110707,1.4795512,,,,,,,,,,,,,,,,, -109100,0.23603137,1.4854307,,,,,,,,,,,,,,,,, -109200,0.22479321,1.4977953,,,,,,,,,,,,,,,,, -109300,0.22911339,1.4612373,,,,,,,,,,,,,,,,, -109400,0.23178355,1.5147636,,,,,,,,,,,,,,,,, -109500,0.22728045,1.530952,,,,,,,,,,,,,,,,, -109600,0.22859989,1.5642718,,,,,,,,,,,,,,,,, -109700,0.21767855,1.4358689,,,,,,,,,,,,,,,,, -109800,0.22905928,1.4730037,,,,,,,,,,,,,,,,, -109900,0.23830843,1.5215815,,,,,,,,,,,,,,,,, -110000,0.23529553,1.5379289,,,,,,,,,,,,,,,,, -110100,0.22840132,1.5019611,,,,,,,,,,,,,,,,, -110200,0.23458317,1.4772136,,,,,,,,,,,,,,,,, -110300,0.22638607,1.4622885,,,,,,,,,,,,,,,,, -110400,0.2441511,1.5061889,,,,,,,,,,,,,,,,, -110500,0.24153954,1.524141,,,,,,,,,,,,,,,,, -110600,0.23226483,1.5024064,,,,,,,,,,,,,,,,, -110664,,,0.6857077479362488,1.4301151037216189,34.912174504037615,0.6944116950035095,1.377577304840088,30.88829246111172,3000.0,0.7101272344589233,1.273249864578247,30.811757487590317,3003.0,40361.46251535416,68261.31683158875,40361.46251535416,27895.077735185623,1.5552773475646973,0.0 -110664,,,,,,,,,,,,,,40361.46251535416,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 46934cded..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,48 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -779.5880715847015,0.0,19.17674994468689,1,0,19.17674994468689,0.43054574140625,95000000,798.7648780345917,0.4309603513786628,0.4267215120682919,83274637 -1412.1931610107422,0.0281963348388671,139.9085566997528,183,0,139.9085566997528,0.1319288563322368,95000000,1552.1369013786316,0.1290529801410699,0.1295675382628244,83274637 -2048.160971879959,0.0513989925384521,260.5636398792267,359,0,260.5636398792267,0.1298185745990954,95000000,2308.7897160053253,0.1257673784293843,0.1272758297169723,83274637 -2674.5712316036224,0.0720260143280029,381.12420296669006,534,0,381.12420296669006,0.1291501434004934,95000000,3055.787654876709,0.1237547100761776,0.1265307136263113,83274637 -3280.2100269794464,0.0934879779815673,501.3299443721771,716,0,501.3299443721771,0.1288308596011513,95000000,3781.6602625846863,0.1251109601496338,0.1264131890678866,83274637 -3890.018922567368,0.1171457767486572,621.7531876564026,896,0,621.7531876564026,0.1282574816714638,95000000,4511.923014163971,0.1260557845352018,0.1258670893255064,83274637 -4459.738814115524,0.1395285129547119,741.8462219238281,1069,0,741.8462219238281,0.1288197573190789,95000000,5201.764748811722,0.1242104866102619,0.1263148251948587,83274637 -5021.499953269959,0.1599841117858886,862.0807681083679,1251,0,862.0807681083679,0.1285263898745888,95000000,5883.787642478943,0.1227888965030323,0.1261936183963416,83274637 -5551.082055568695,0.1806581020355224,982.0765647888184,1442,0,982.0765647888184,0.128418237942023,95000000,6533.393225431442,0.1241370248930446,0.126015634340798,83274637 -6086.701997041702,0.2041034698486328,1102.328135251999,1627,0,1102.328135251999,0.1280614014802631,95000000,7189.295044898987,0.1250419927491519,0.1259286173995026,83274637 -6578.271548032761,0.2253038883209228,1223.171159029007,1801,0,1223.171159029007,0.1277675512952302,95000000,7801.735506534576,0.124996210325439,0.1254014500645429,83274637 -7100.556094169617,0.2461154460906982,1343.1663393974304,1981,0,1343.1663393974304,0.1279253799033717,95000000,8444.042902231216,0.1222820501868267,0.1253713186264399,83274637 -7634.562827587128,0.2673537731170654,1463.3568828105929,2161,0,1463.3568828105929,0.127584342732319,95000000,9098.268038511276,0.1245490766522832,0.1252325630200491,83274637 -8174.649076461792,0.2893369197845459,1583.834689617157,2337,0,1583.834689617157,0.1278088888774671,95000000,9758.860672712326,0.1220589123150837,0.1253653343449691,83274637 -8685.467759132385,0.3133678436279297,1704.0993795394895,2515,0,1704.0993795394895,0.1277990022203947,95000000,10389.974791765211,0.1256873317695451,0.1254345174065048,83274637 -9243.28431081772,0.3339982032775879,1824.14817070961,2690,0,1824.14817070961,0.1272656276007401,95000000,11067.867232322693,0.1233852642395024,0.1249681097859832,83274637 -9770.026929616928,0.3554120063781738,1944.3549394607544,2871,0,1944.3549394607544,0.1273142332648026,95000000,11714.844737052916,0.1240398846683824,0.1248725952128456,83274637 -10328.035050153732,0.3767735958099365,2064.422380447388,3046,0,2064.422380447388,0.1272158850842927,95000000,12392.948280096054,0.1220686655936751,0.1249816561067124,83274637 -10870.541580438614,0.3982052803039551,2184.762518644333,3221,0,2184.762518644333,0.1270780076377467,95000000,13055.822746276855,0.1233471658127675,0.1247227697722413,83274637 -11406.606538772585,0.4202721118927002,2305.140183210373,3400,0,2305.140183210373,0.1270692928659539,95000000,13712.294030189514,0.1236477804563517,0.1247065230611947,83274637 -11941.78291940689,0.4414503574371338,2425.458587169647,3578,0,2425.458587169647,0.1271274240234375,95000000,14367.816534042358,0.1227167904283265,0.124776033640529,83274637 -12482.893585205078,0.4628136157989502,2545.525310277939,3752,0,2545.525310277939,0.1269764299239309,95000000,15029.021893501282,0.1226627091941593,0.1246724846055355,83274637 -13025.119114637377,0.4842236042022705,2665.8254606723785,3929,0,2665.8254606723785,0.127308511040296,95000000,15691.575644254684,0.1242814511177862,0.1248581031556688,83274637 -13575.154886245728,0.5061235427856445,2786.0510370731354,4104,0,2786.0510370731354,0.1270873604440789,95000000,16361.86536693573,0.1241287340720494,0.1246956142095109,83274637 -14107.437706708908,0.5311899185180664,2906.668580770493,4280,0,2906.668580770493,0.1269933303145559,95000000,17014.7972843647,0.1230434426135799,0.1245712757422936,83274637 -14647.807305336,0.5556049346923828,3026.6890342235565,4465,0,3026.6890342235565,0.1269713012335526,95000000,17675.21853160858,0.1214922842153775,0.1245158436671288,83274637 -15189.587604999542,0.5773332118988037,3147.345049381256,4642,0,3147.345049381256,0.1268638210526315,95000000,18337.68328022957,0.1236645429857871,0.1243984857846971,83274637 -15753.776514053345,0.5985479354858398,3267.979365348816,4821,0,3267.979365348816,0.1267540081311677,95000000,19022.5343644619,0.1229192059159091,0.124411130831937,83274637 -16321.096604585648,0.6199290752410889,3388.0267992019653,5002,0,3388.0267992019653,0.1267046369346217,95000000,19709.929924726486,0.1233405592212887,0.1243999604679303,83274637 -16864.94277358055,0.642521858215332,3508.23348236084,5178,0,3508.23348236084,0.1268297857421875,95000000,20374.01232600212,0.1241698907512538,0.1244587499098296,83274637 -17412.443184375763,0.6688754558563232,3628.45765209198,5354,0,3628.45765209198,0.1269416361842105,95000000,21041.76970744133,0.122747369122299,0.1245628027958073,83274637 -17974.503488063812,0.6905982494354248,3748.835675954818,5535,0,3748.835675954818,0.1266410290296052,95000000,21724.23626804352,0.1241950874550724,0.1242881210994501,83274637 -18512.399346351624,0.7121307849884033,3869.279004096985,5712,0,3869.279004096985,0.1268404406558388,95000000,22382.60340952873,0.1209589424045205,0.124491540160136,83274637 -19074.74055337906,0.740128755569458,3989.6233706474304,5886,0,3989.6233706474304,0.1266918185032894,95000000,23065.32358169556,0.123182110617277,0.1242398312541766,83274637 -19611.05926060677,0.7634031772613525,4110.365835428238,6063,0,4110.365835428238,0.1264803563116776,95000000,23722.414705514908,0.1216745592958334,0.1241032846587941,83274637 -20155.92942047119,0.7852051258087158,4231.051466464996,6239,0,4231.051466464996,0.1260748144428454,95000000,24387.99911737442,0.1216520412514607,0.1238063229334273,83274637 -20705.06143975258,0.8067381381988525,4351.294145584106,6414,0,4351.294145584106,0.1263578943873355,95000000,25057.40193104744,0.122006993787656,0.1239933237755456,83274637 -21273.723149061203,0.8337926864624023,4471.66161942482,6599,0,4471.66161942482,0.126194436379523,95000000,25746.46507501602,0.1218283845100013,0.123925746182722,83274637 -21823.25283265114,0.8554544448852539,4591.978357076645,6774,0,4591.978357076645,0.1263712296258223,95000000,26416.339547872543,0.1202651443159055,0.124064176264906,83274637 -22368.196395874023,0.8878428936004639,4712.086730480194,6951,0,4712.086730480194,0.1261761017783717,95000000,27081.430657863617,0.1239679206115832,0.1238483778206645,83274637 -22923.81343126297,0.910447120666504,4832.796728849411,7129,0,4832.796728849411,0.1262103525185033,95000000,27757.786990880966,0.1212623873081222,0.1238867547305775,83274637 -23460.436990499496,0.932957649230957,4953.348455429077,7310,0,4953.348455429077,0.1262782567845394,95000000,28414.99127578736,0.1225593172177766,0.1238999159480256,83274637 -24003.04792308808,0.9626514911651612,5073.562851428986,7487,0,5073.562851428986,0.1260569997841283,95000000,29077.852850198746,0.1222998002462042,0.1238021491439843,83274637 -24562.72512698173,0.9888694286346436,5193.944893836975,7662,0,5193.944893836975,0.1261222990542763,95000000,29757.94474029541,0.1224185839271957,0.1238043794453792,83274637 -25124.88265299797,1.0150315761566162,5314.483114957809,7837,0,5314.483114957809,0.1260904157894737,95000000,30440.6731274128,0.1214677194519987,0.1237525353486004,83274637 -25667.88701057434,1.038297414779663,5434.737701892853,8019,0,5434.737701892853,0.126142082915296,95000000,31103.96208691597,0.1211583566201745,0.1238081131142738,83274637 -26204.221406936646,1.067702054977417,5555.337151288986,8193,0,5555.337151288986,0.12591006790707238,95000000,31760.931745529175,0.11978451514019156,0.12363587442582645,83274637 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 021384755..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,131 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,3.9360924,0.43507662,,,,,,,,,,, -1,,,0.4309603513786628,0.4267215120682919,83274637.0,0.43054574140625,95000000.0,19.17674994468689,798.7648780345917,19.17674994468689,779.5880715847015,0.0,0.0 -100,0.05709005,0.12944393,,,,,,,,,,, -183,,,0.1290529801410699,0.1295675382628244,83274637.0,0.1319288563322368,95000000.0,139.9085566997528,1552.1369013786316,139.9085566997528,1412.1931610107422,0.0281963348388671,0.0 -200,0.039863233,0.12567216,,,,,,,,,,, -300,0.006608013,0.124249496,,,,,,,,,,, -359,,,0.1257673784293843,0.1272758297169723,83274637.0,0.1298185745990954,95000000.0,260.5636398792267,2308.7897160053253,260.5636398792267,2048.160971879959,0.0513989925384521,0.0 -400,0.04877414,0.13177015,,,,,,,,,,, -500,0.031658117,0.12464927,,,,,,,,,,, -534,,,0.1237547100761776,0.1265307136263113,83274637.0,0.1291501434004934,95000000.0,381.12420296669006,3055.787654876709,381.12420296669006,2674.5712316036224,0.0720260143280029,0.0 -600,0.022134468,0.119190976,,,,,,,,,,, -700,0.047413833,0.12213755,,,,,,,,,,, -716,,,0.1251109601496338,0.1264131890678866,83274637.0,0.1288308596011513,95000000.0,501.3299443721771,3781.6602625846863,501.3299443721771,3280.2100269794464,0.0934879779815673,0.0 -800,0.0176207,0.11906594,,,,,,,,,,, -896,,,0.1260557845352018,0.1258670893255064,83274637.0,0.1282574816714638,95000000.0,621.7531876564026,4511.923014163971,621.7531876564026,3890.018922567368,0.1171457767486572,0.0 -900,0.028485501,0.123912044,,,,,,,,,,, -1000,0.02441325,0.12005597,,,,,,,,,,, -1069,,,0.1242104866102619,0.1263148251948587,83274637.0,0.1288197573190789,95000000.0,741.8462219238281,5201.764748811722,741.8462219238281,4459.738814115524,0.1395285129547119,0.0 -1100,0.012173343,0.12980479,,,,,,,,,,, -1200,0.025473557,0.12973458,,,,,,,,,,, -1251,,,0.1227888965030323,0.1261936183963416,83274637.0,0.1285263898745888,95000000.0,862.0807681083679,5883.787642478943,862.0807681083679,5021.499953269959,0.1599841117858886,0.0 -1300,0.00902799,0.12146268,,,,,,,,,,, -1400,0.0042857993,0.11846672,,,,,,,,,,, -1442,,,0.1241370248930446,0.126015634340798,83274637.0,0.128418237942023,95000000.0,982.0765647888184,6533.393225431442,982.0765647888184,5551.082055568695,0.1806581020355224,0.0 -1500,0.02111355,0.12340115,,,,,,,,,,, -1600,0.0042321533,0.1229988,,,,,,,,,,, -1627,,,0.1250419927491519,0.1259286173995026,83274637.0,0.1280614014802631,95000000.0,1102.328135251999,7189.295044898987,1102.328135251999,6086.701997041702,0.2041034698486328,0.0 -1700,0.018751781,0.12667602,,,,,,,,,,, -1800,0.030665075,0.12425961,,,,,,,,,,, -1801,,,0.124996210325439,0.1254014500645429,83274637.0,0.1277675512952302,95000000.0,1223.171159029007,7801.735506534576,1223.171159029007,6578.271548032761,0.2253038883209228,0.0 -1900,0.0041113812,0.12241669,,,,,,,,,,, -1981,,,0.1222820501868267,0.1253713186264399,83274637.0,0.1279253799033717,95000000.0,1343.1663393974304,8444.042902231216,1343.1663393974304,7100.556094169617,0.2461154460906982,0.0 -2000,0.011595049,0.1184201,,,,,,,,,,, -2100,0.020671304,0.13230817,,,,,,,,,,, -2161,,,0.1245490766522832,0.1252325630200491,83274637.0,0.127584342732319,95000000.0,1463.3568828105929,9098.268038511276,1463.3568828105929,7634.562827587128,0.2673537731170654,0.0 -2200,0.005515536,0.13086484,,,,,,,,,,, -2300,0.014820219,0.12879844,,,,,,,,,,, -2337,,,0.1220589123150837,0.1253653343449691,83274637.0,0.1278088888774671,95000000.0,1583.834689617157,9758.860672712326,1583.834689617157,8174.649076461792,0.2893369197845459,0.0 -2400,0.023019424,0.124028236,,,,,,,,,,, -2500,0.0045582224,0.12599315,,,,,,,,,,, -2515,,,0.1256873317695451,0.1254345174065048,83274637.0,0.1277990022203947,95000000.0,1704.0993795394895,10389.974791765211,1704.0993795394895,8685.467759132385,0.3133678436279297,0.0 -2600,0.029395418,0.124570444,,,,,,,,,,, -2690,,,0.1233852642395024,0.1249681097859832,83274637.0,0.1272656276007401,95000000.0,1824.14817070961,11067.867232322693,1824.14817070961,9243.28431081772,0.3339982032775879,0.0 -2700,0.009638276,0.117418766,,,,,,,,,,, -2800,0.009396469,0.12100652,,,,,,,,,,, -2871,,,0.1240398846683824,0.1248725952128456,83274637.0,0.1273142332648026,95000000.0,1944.3549394607544,11714.844737052916,1944.3549394607544,9770.026929616928,0.3554120063781738,0.0 -2900,0.007092391,0.13121009,,,,,,,,,,, -3000,0.010838089,0.130713,,,,,,,,,,, -3046,,,0.1220686655936751,0.1249816561067124,83274637.0,0.1272158850842927,95000000.0,2064.422380447388,12392.948280096054,2064.422380447388,10328.035050153732,0.3767735958099365,0.0 -3100,0.03202035,0.12970921,,,,,,,,,,, -3200,0.0045139873,0.120193966,,,,,,,,,,, -3221,,,0.1233471658127675,0.1247227697722413,83274637.0,0.1270780076377467,95000000.0,2184.762518644333,13055.822746276855,2184.762518644333,10870.541580438614,0.3982052803039551,0.0 -3300,0.007196172,0.13129672,,,,,,,,,,, -3400,,,0.1236477804563517,0.1247065230611947,83274637.0,0.1270692928659539,95000000.0,2305.140183210373,13712.294030189514,2305.140183210373,11406.606538772585,0.4202721118927002,0.0 -3400,0.0066139484,0.12169848,,,,,,,,,,, -3500,0.006613795,0.12709662,,,,,,,,,,, -3578,,,0.1227167904283265,0.124776033640529,83274637.0,0.1271274240234375,95000000.0,2425.458587169647,14367.816534042358,2425.458587169647,11941.78291940689,0.4414503574371338,0.0 -3600,0.0063947346,0.1344451,,,,,,,,,,, -3700,0.008983965,0.13300493,,,,,,,,,,, -3752,,,0.1226627091941593,0.1246724846055355,83274637.0,0.1269764299239309,95000000.0,2545.525310277939,15029.021893501282,2545.525310277939,12482.893585205078,0.4628136157989502,0.0 -3800,0.0049025533,0.118760265,,,,,,,,,,, -3900,0.0053046364,0.12154223,,,,,,,,,,, -3929,,,0.1242814511177862,0.1248581031556688,83274637.0,0.127308511040296,95000000.0,2665.8254606723785,15691.575644254684,2665.8254606723785,13025.119114637377,0.4842236042022705,0.0 -4000,0.008221152,0.121760815,,,,,,,,,,, -4100,0.031703185,0.11387305,,,,,,,,,,, -4104,,,0.1241287340720494,0.1246956142095109,83274637.0,0.1270873604440789,95000000.0,2786.0510370731354,16361.86536693573,2786.0510370731354,13575.154886245728,0.5061235427856445,0.0 -4200,0.008413831,0.13103718,,,,,,,,,,, -4280,,,0.1230434426135799,0.1245712757422936,83274637.0,0.1269933303145559,95000000.0,2906.668580770493,17014.7972843647,2906.668580770493,14107.437706708908,0.5311899185180664,0.0 -4300,0.015028944,0.11194843,,,,,,,,,,, -4400,0.0043497924,0.120270066,,,,,,,,,,, -4465,,,0.1214922842153775,0.1245158436671288,83274637.0,0.1269713012335526,95000000.0,3026.6890342235565,17675.21853160858,3026.6890342235565,14647.807305336,0.5556049346923828,0.0 -4500,0.010306578,0.12359634,,,,,,,,,,, -4600,0.0053308136,0.13086227,,,,,,,,,,, -4642,,,0.1236645429857871,0.1243984857846971,83274637.0,0.1268638210526315,95000000.0,3147.345049381256,18337.68328022957,3147.345049381256,15189.587604999542,0.5773332118988037,0.0 -4700,0.0048113293,0.12563778,,,,,,,,,,, -4800,0.005843924,0.112806864,,,,,,,,,,, -4821,,,0.1229192059159091,0.124411130831937,83274637.0,0.1267540081311677,95000000.0,3267.979365348816,19022.5343644619,3267.979365348816,15753.776514053345,0.5985479354858398,0.0 -4900,0.022066746,0.12540239,,,,,,,,,,, -5000,0.010844285,0.12438074,,,,,,,,,,, -5002,,,0.1233405592212887,0.1243999604679303,83274637.0,0.1267046369346217,95000000.0,3388.0267992019653,19709.929924726486,3388.0267992019653,16321.096604585648,0.6199290752410889,0.0 -5100,0.014556532,0.13842319,,,,,,,,,,, -5178,,,0.1241698907512538,0.1244587499098296,83274637.0,0.1268297857421875,95000000.0,3508.23348236084,20374.01232600212,3508.23348236084,16864.94277358055,0.642521858215332,0.0 -5200,0.005591566,0.11290233,,,,,,,,,,, -5300,0.005138731,0.116698086,,,,,,,,,,, -5354,,,0.122747369122299,0.1245628027958073,83274637.0,0.1269416361842105,95000000.0,3628.45765209198,21041.76970744133,3628.45765209198,17412.443184375763,0.6688754558563232,0.0 -5400,0.008778748,0.122243784,,,,,,,,,,, -5500,0.0047843046,0.12114176,,,,,,,,,,, -5535,,,0.1241950874550724,0.1242881210994501,83274637.0,0.1266410290296052,95000000.0,3748.835675954818,21724.23626804352,3748.835675954818,17974.503488063812,0.6905982494354248,0.0 -5600,0.015178842,0.12303834,,,,,,,,,,, -5700,0.013363073,0.123717785,,,,,,,,,,, -5712,,,0.1209589424045205,0.124491540160136,83274637.0,0.1268404406558388,95000000.0,3869.279004096985,22382.60340952873,3869.279004096985,18512.399346351624,0.7121307849884033,0.0 -5800,0.010534889,0.12060778,,,,,,,,,,, -5886,,,0.123182110617277,0.1242398312541766,83274637.0,0.1266918185032894,95000000.0,3989.6233706474304,23065.32358169556,3989.6233706474304,19074.74055337906,0.740128755569458,0.0 -5900,0.0060796556,0.12350195,,,,,,,,,,, -6000,0.008758977,0.1317566,,,,,,,,,,, -6063,,,0.1216745592958334,0.1241032846587941,83274637.0,0.1264803563116776,95000000.0,4110.365835428238,23722.414705514908,4110.365835428238,19611.05926060677,0.7634031772613525,0.0 -6100,0.007767477,0.12504159,,,,,,,,,,, -6200,0.008343344,0.11973611,,,,,,,,,,, -6239,,,0.1216520412514607,0.1238063229334273,83274637.0,0.1260748144428454,95000000.0,4231.051466464996,24387.99911737442,4231.051466464996,20155.92942047119,0.7852051258087158,0.0 -6300,0.0065229503,0.11509027,,,,,,,,,,, -6400,0.0098875705,0.12765823,,,,,,,,,,, -6414,,,0.122006993787656,0.1239933237755456,83274637.0,0.1263578943873355,95000000.0,4351.294145584106,25057.40193104744,4351.294145584106,20705.06143975258,0.8067381381988525,0.0 -6500,0.0075627263,0.121489055,,,,,,,,,,, -6599,,,0.1218283845100013,0.123925746182722,83274637.0,0.126194436379523,95000000.0,4471.66161942482,25746.46507501602,4471.66161942482,21273.723149061203,0.8337926864624023,0.0 -6600,0.0054781204,0.120760106,,,,,,,,,,, -6700,0.004444042,0.11894067,,,,,,,,,,, -6774,,,0.1202651443159055,0.124064176264906,83274637.0,0.1263712296258223,95000000.0,4591.978357076645,26416.339547872543,4591.978357076645,21823.25283265114,0.8554544448852539,0.0 -6800,0.0055877157,0.12519377,,,,,,,,,,, -6900,0.006655612,0.12095548,,,,,,,,,,, -6951,,,0.1239679206115832,0.1238483778206645,83274637.0,0.1261761017783717,95000000.0,4712.086730480194,27081.430657863617,4712.086730480194,22368.196395874023,0.8878428936004639,0.0 -7000,0.008799141,0.11888005,,,,,,,,,,, -7100,0.0062353145,0.11702229,,,,,,,,,,, -7129,,,0.1212623873081222,0.1238867547305775,83274637.0,0.1262103525185033,95000000.0,4832.796728849411,27757.786990880966,4832.796728849411,22923.81343126297,0.910447120666504,0.0 -7200,0.015043348,0.11793436,,,,,,,,,,, -7300,0.0065625138,0.119749606,,,,,,,,,,, -7310,,,0.1225593172177766,0.1238999159480256,83274637.0,0.1262782567845394,95000000.0,4953.348455429077,28414.99127578736,4953.348455429077,23460.436990499496,0.932957649230957,0.0 -7400,0.0064175497,0.12538418,,,,,,,,,,, -7487,,,0.1222998002462042,0.1238021491439843,83274637.0,0.1260569997841283,95000000.0,5073.562851428986,29077.852850198746,5073.562851428986,24003.04792308808,0.9626514911651612,0.0 -7500,0.010728486,0.12995608,,,,,,,,,,, -7600,0.0069591464,0.11923873,,,,,,,,,,, -7662,,,0.1224185839271957,0.1238043794453792,83274637.0,0.1261222990542763,95000000.0,5193.944893836975,29757.94474029541,5193.944893836975,24562.72512698173,0.9888694286346436,0.0 -7700,0.005587487,0.11946197,,,,,,,,,,, -7800,0.0065531647,0.12828636,,,,,,,,,,, -7837,,,0.1214677194519987,0.1237525353486004,83274637.0,0.1260904157894737,95000000.0,5314.483114957809,30440.6731274128,5314.483114957809,25124.88265299797,1.0150315761566162,0.0 -7900,0.0063129338,0.12564006,,,,,,,,,,, -8000,0.006726888,0.12376168,,,,,,,,,,, -8019,,,0.1211583566201745,0.1238081131142738,83274637.0,0.126142082915296,95000000.0,5434.737701892853,31103.96208691597,5434.737701892853,25667.88701057434,1.038297414779663,0.0 -8100,0.0062676966,0.12032779,,,,,,,,,,, -8193,,,0.1197845151401915,0.1236358744258264,83274637.0,0.1259100679070723,95000000.0,5555.337151288986,31760.93174552917,5555.337151288986,26204.22140693665,1.067702054977417,0.0 -8193,,,,,,,,5555.337151288986,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index e739657ef..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,26 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -202.559326171875,0.0,56.391112089157104,1,0,56.391112089157104,1.0956219265960974,3581,0.2584000805575432,258.9508192539215,1.0998219762529646,0.2384166376931326,1.0996187174662353,3554,0.2354174921468328 -206.9951455593109,0.0345721244812011,136.3739686012268,337,0,136.3739686012268,0.3128357359697885,3581,0.7142116841707973,343.41506576538086,0.2892417567116873,0.7199898447309222,0.310480309928074,3554,0.6970055472267164 -211.06392359733584,0.0766663551330566,216.5176022052765,582,0,216.5176022052765,0.309186784690467,3581,0.7143449695441217,427.6772720813751,0.2864444255828857,0.7176377432686942,0.3073082335154931,3554,0.6977407168859032 -215.13433575630188,0.1128833293914794,296.8364827632904,832,0,296.8364827632904,0.2989926011959822,3581,0.7280837258490994,512.1105699539185,0.2757127285003662,0.7342916897365025,0.2970286011360439,3554,0.7110329168058878 -219.2087481021881,0.1449482440948486,376.9990062713623,1128,0,376.9990062713623,0.3016011425317648,3581,0.7270126023195337,596.3894002437592,0.2785420077187674,0.7332958493913923,0.2995688932694675,3554,0.7100512022105374 -223.27605557441711,0.1681487560272216,457.1517524719238,1473,0,457.1517524719238,0.2942166215246265,3581,0.7326609705520455,680.6443018913269,0.2712824514933994,0.7385317257472447,0.292653441764561,3554,0.7153012560671075 -227.3480653762817,0.1919243335723877,537.2392702102661,1820,0,537.2392702102661,0.2979964379057875,3581,0.7325720000087266,764.8391892910004,0.2751588821411133,0.7382958275931222,0.2962919888901589,3554,0.7154874184457654 -231.41916298866272,0.2149121761322021,617.2763676643372,2161,0,617.2763676643372,0.2934923808490121,3581,0.7356083158073862,848.9818904399872,0.2701513086046491,0.7415499005998883,0.2918283852151624,3554,0.7184250745199071 -235.4900426864624,0.2387645244598388,697.3838765621185,2508,0,697.3838765621185,0.2914261507675056,3581,0.7364003922612399,933.196018218994,0.2681118590491159,0.7426836150033134,0.2898782140831633,3554,0.7191463678689505 -239.561717748642,0.2617554664611816,777.5420806407928,2854,0,777.5420806407928,0.2924938313756632,3581,0.7343460249188425,1017.460663318634,0.2691434962408883,0.7407185690743583,0.2910748054019239,3554,0.7170123012802476 -243.637444972992,0.2851660251617431,857.5538339614868,3197,0,857.5538339614868,0.2908867029330145,3581,0.7381361700642278,1101.583218574524,0.2674750770841326,0.7446472985403878,0.2893060567208603,3554,0.7210393163952589 -247.7060973644257,0.3091092109680176,937.6811022758484,3545,0,937.6811022758484,0.2907484747517279,3581,0.7352689323862049,1185.814774274826,0.2675669022968837,0.7404789924621582,0.289257592677265,3554,0.7183150944578293 -251.7816228866577,0.3340692520141601,1017.7604601383208,3894,0,1017.7604601383208,0.2907551560645595,3581,0.7374393364065555,1270.006408214569,0.2677767276763916,0.742952687399728,0.2893170135103053,3554,0.7204911334499859 -255.8511726856232,0.3580706119537353,1097.8150520324707,4239,0,1097.8150520324707,0.2891597880960276,3581,0.739725777104859,1354.166056394577,0.2658652407782418,0.7461169787815639,0.2877639315406056,3554,0.7225389196152223 -259.9235010147095,0.3812909126281738,1178.0337393283844,4587,0,1178.0337393283844,0.2900792185536512,3581,0.7370788863969562,1438.4921023845673,0.2663721527372087,0.7441460745675224,0.2887165883181802,3554,0.7196655616910523 -263.99485969543457,0.404461145401001,1258.2251710891724,4937,0,1258.2251710891724,0.2892136817469806,3581,0.739319989615331,1522.78990483284,0.2657783882958548,0.7458579880850655,0.2878372286837894,3554,0.7220892447330473 -268.0642638206482,0.4267683029174804,1338.2833817005155,5281,0,1338.2833817005155,0.289568984418633,3581,0.7398108615784696,1606.9511659145355,0.2662630251475743,0.7460690225873675,0.2881349854477349,3554,0.7228021573403207 -272.1420018672943,0.4503428936004638,1418.387591600418,5627,0,1418.387591600418,0.2887484101202527,3581,0.7408307844352137,1691.168239593506,0.264906849179949,0.7479839324951172,0.2873642491569798,3554,0.7236124102024127 -276.2174243927002,0.4755406379699707,1498.3514828681946,5975,0,1498.3514828681946,0.288400368263055,3581,0.7401961960695337,1775.244146347046,0.2647814069475446,0.7471363885062081,0.2870195911516953,3554,0.7228996349843486 -280.2893924713135,0.5003418922424316,1578.4904873371124,6323,0,1578.4904873371124,0.2885712530652227,3581,0.7403354809890743,1859.4913849830627,0.2650374174118042,0.7470765113830566,0.2872435699102243,3554,0.7230641585625351 -284.363648891449,0.5252387523651123,1658.508949995041,6670,0,1658.508949995041,0.2885813091228009,3581,0.73911150538432,1943.62032198906,0.264569810458592,0.7462148666381836,0.2871640730789691,3554,0.7219730134619443 -288.43237042427063,0.5486984252929688,1738.528511285782,7018,0,1738.528511285782,0.2888908311662245,3581,0.7390810985932701,2027.7436203956604,0.2654059273856027,0.745499679020473,0.2875164592272791,3554,0.7219162717184862 -292.50692319869995,0.5723929405212402,1818.5980372428887,7366,0,1818.5980372428887,0.2886655413881422,3581,0.7401437000401424,2111.922720432281,0.2649506500789097,0.7471366609845843,0.2871706334137152,3554,0.7230871712550999 -296.5782120227814,0.5983188152313232,1898.6324636936188,7712,0,1898.6324636936188,0.2892601100534941,3581,0.7377579259459648,2196.0656032562256,0.2655230249677385,0.74489654813494,0.2879726944442002,3554,0.7204773945290518 -300.6479024887085,0.6236989498138428,1978.7172708511353,8062,0,1978.7172708511353,0.28839941378979334,3581,0.7416393596411617,2280.256668329239,0.26483871255602154,0.7482733726501465,0.2870159675112989,3554,0.7244876481605234 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 582e13fa1..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,108 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.130725,1.1220623,,,,,,,,,,,,,, -1,,,0.2384166376931326,1.0998219762529646,0.2354174921468328,1.0996187174662353,3554.0,0.2584000805575432,1.0956219265960974,3581.0,56.391112089157104,258.9508192539215,56.391112089157104,202.559326171875,0.0,0.0 -100,0.08013673,0.3275945,,,,,,,,,,,,,, -200,0.24918945,0.30402273,,,,,,,,,,,,,, -300,0.34409615,0.3138375,,,,,,,,,,,,,, -337,,,0.7199898447309222,0.2892417567116873,0.6970055472267164,0.310480309928074,3554.0,0.7142116841707973,0.3128357359697885,3581.0,136.3739686012268,343.41506576538086,136.3739686012268,206.9951455593109,0.0345721244812011,0.0 -400,0.40607318,0.3285567,,,,,,,,,,,,,, -500,0.53232586,0.29132095,,,,,,,,,,,,,, -582,,,0.7176377432686942,0.2864444255828857,0.6977407168859032,0.3073082335154931,3554.0,0.7143449695441217,0.309186784690467,3581.0,216.5176022052765,427.6772720813751,216.5176022052765,211.06392359733584,0.0766663551330566,0.0 -600,0.058449563,0.25742507,,,,,,,,,,,,,, -700,0.52254415,0.2359663,,,,,,,,,,,,,, -800,0.26233798,0.24539708,,,,,,,,,,,,,, -832,,,0.7342916897365025,0.2757127285003662,0.7110329168058878,0.2970286011360439,3554.0,0.7280837258490994,0.2989926011959822,3581.0,296.8364827632904,512.1105699539185,296.8364827632904,215.13433575630188,0.1128833293914794,0.0 -900,0.1583623,0.32193533,,,,,,,,,,,,,, -1000,0.09597534,0.22472113,,,,,,,,,,,,,, -1100,0.18203972,0.32535866,,,,,,,,,,,,,, -1128,,,0.7332958493913923,0.2785420077187674,0.7100512022105374,0.2995688932694675,3554.0,0.7270126023195337,0.3016011425317648,3581.0,376.9990062713623,596.3894002437592,376.9990062713623,219.2087481021881,0.1449482440948486,0.0 -1200,0.094448276,0.28865162,,,,,,,,,,,,,, -1300,0.2599754,0.25903004,,,,,,,,,,,,,, -1400,0.3638671,0.24560462,,,,,,,,,,,,,, -1473,,,0.7385317257472447,0.2712824514933994,0.7153012560671075,0.292653441764561,3554.0,0.7326609705520455,0.2942166215246265,3581.0,457.1517524719238,680.6443018913269,457.1517524719238,223.27605557441711,0.1681487560272216,0.0 -1500,0.44677216,0.31872678,,,,,,,,,,,,,, -1600,0.18950914,0.23406258,,,,,,,,,,,,,, -1700,0.12797159,0.3326422,,,,,,,,,,,,,, -1800,0.17747803,0.20437652,,,,,,,,,,,,,, -1820,,,0.7382958275931222,0.2751588821411133,0.7154874184457654,0.2962919888901589,3554.0,0.7325720000087266,0.2979964379057875,3581.0,537.2392702102661,764.8391892910004,537.2392702102661,227.3480653762817,0.1919243335723877,0.0 -1900,0.07745711,0.29840702,,,,,,,,,,,,,, -2000,0.25118476,0.3058094,,,,,,,,,,,,,, -2100,0.04799412,0.31828246,,,,,,,,,,,,,, -2161,,,0.7415499005998883,0.2701513086046491,0.7184250745199071,0.2918283852151624,3554.0,0.7356083158073862,0.2934923808490121,3581.0,617.2763676643372,848.9818904399872,617.2763676643372,231.41916298866272,0.2149121761322021,0.0 -2200,0.2499049,0.23620202,,,,,,,,,,,,,, -2300,0.05843782,0.2930182,,,,,,,,,,,,,, -2400,0.19244525,0.24558763,,,,,,,,,,,,,, -2500,0.1311621,0.25883913,,,,,,,,,,,,,, -2508,,,0.7426836150033134,0.2681118590491159,0.7191463678689505,0.2898782140831633,3554.0,0.7364003922612399,0.2914261507675056,3581.0,697.3838765621185,933.196018218994,697.3838765621185,235.4900426864624,0.2387645244598388,0.0 -2600,0.18265642,0.23491712,,,,,,,,,,,,,, -2700,0.07213435,0.3115383,,,,,,,,,,,,,, -2800,0.0928279,0.27458748,,,,,,,,,,,,,, -2854,,,0.7407185690743583,0.2691434962408883,0.7170123012802476,0.2910748054019239,3554.0,0.7343460249188425,0.2924938313756632,3581.0,777.5420806407928,1017.460663318634,777.5420806407928,239.561717748642,0.2617554664611816,0.0 -2900,0.16429672,0.2938419,,,,,,,,,,,,,, -3000,0.09822041,0.26016802,,,,,,,,,,,,,, -3100,0.11980322,0.23845834,,,,,,,,,,,,,, -3197,,,0.7446472985403878,0.2674750770841326,0.7210393163952589,0.2893060567208603,3554.0,0.7381361700642278,0.2908867029330145,3581.0,857.5538339614868,1101.583218574524,857.5538339614868,243.637444972992,0.2851660251617431,0.0 -3200,0.20207025,0.30822918,,,,,,,,,,,,,, -3300,0.13592179,0.2536915,,,,,,,,,,,,,, -3400,0.031421624,0.30269858,,,,,,,,,,,,,, -3500,0.086290985,0.25007337,,,,,,,,,,,,,, -3545,,,0.7404789924621582,0.2675669022968837,0.7183150944578293,0.289257592677265,3554.0,0.7352689323862049,0.2907484747517279,3581.0,937.6811022758484,1185.814774274826,937.6811022758484,247.7060973644257,0.3091092109680176,0.0 -3600,0.077255584,0.31853363,,,,,,,,,,,,,, -3700,0.3100473,0.3526002,,,,,,,,,,,,,, -3800,0.048381675,0.27417725,,,,,,,,,,,,,, -3894,,,0.742952687399728,0.2677767276763916,0.7204911334499859,0.2893170135103053,3554.0,0.7374393364065555,0.2907551560645595,3581.0,1017.7604601383208,1270.006408214569,1017.7604601383208,251.7816228866577,0.3340692520141601,0.0 -3900,0.10747234,0.299761,,,,,,,,,,,,,, -4000,0.13284613,0.2544521,,,,,,,,,,,,,, -4100,0.14264257,0.19326648,,,,,,,,,,,,,, -4200,0.13502698,0.2712496,,,,,,,,,,,,,, -4239,,,0.7461169787815639,0.2658652407782418,0.7225389196152223,0.2877639315406056,3554.0,0.739725777104859,0.2891597880960276,3581.0,1097.8150520324707,1354.166056394577,1097.8150520324707,255.8511726856232,0.3580706119537353,0.0 -4300,0.0880667,0.38084793,,,,,,,,,,,,,, -4400,0.046944726,0.3165134,,,,,,,,,,,,,, -4500,0.15188596,0.2589569,,,,,,,,,,,,,, -4587,,,0.7441460745675224,0.2663721527372087,0.7196655616910523,0.2887165883181802,3554.0,0.7370788863969562,0.2900792185536512,3581.0,1178.0337393283844,1438.4921023845673,1178.0337393283844,259.9235010147095,0.3812909126281738,0.0 -4600,0.10976906,0.2629108,,,,,,,,,,,,,, -4700,0.15732892,0.25771275,,,,,,,,,,,,,, -4800,0.14643991,0.22009856,,,,,,,,,,,,,, -4900,0.07395892,0.30662218,,,,,,,,,,,,,, -4937,,,0.7458579880850655,0.2657783882958548,0.7220892447330473,0.2878372286837894,3554.0,0.739319989615331,0.2892136817469806,3581.0,1258.2251710891724,1522.78990483284,1258.2251710891724,263.99485969543457,0.404461145401001,0.0 -5000,0.1268579,0.2465822,,,,,,,,,,,,,, -5100,0.14708252,0.24585675,,,,,,,,,,,,,, -5200,0.14521506,0.23001847,,,,,,,,,,,,,, -5281,,,0.7460690225873675,0.2662630251475743,0.7228021573403207,0.2881349854477349,3554.0,0.7398108615784696,0.289568984418633,3581.0,1338.2833817005155,1606.9511659145355,1338.2833817005155,268.0642638206482,0.4267683029174804,0.0 -5300,0.08943622,0.20991458,,,,,,,,,,,,,, -5400,0.16697484,0.26337817,,,,,,,,,,,,,, -5500,0.163502,0.2836933,,,,,,,,,,,,,, -5600,0.06651434,0.26581958,,,,,,,,,,,,,, -5627,,,0.7479839324951172,0.264906849179949,0.7236124102024127,0.2873642491569798,3554.0,0.7408307844352137,0.2887484101202527,3581.0,1418.387591600418,1691.168239593506,1418.387591600418,272.1420018672943,0.4503428936004638,0.0 -5700,0.098243445,0.33385253,,,,,,,,,,,,,, -5800,0.05587707,0.36025366,,,,,,,,,,,,,, -5900,0.043268718,0.32516852,,,,,,,,,,,,,, -5975,,,0.7471363885062081,0.2647814069475446,0.7228996349843486,0.2870195911516953,3554.0,0.7401961960695337,0.288400368263055,3581.0,1498.3514828681946,1775.244146347046,1498.3514828681946,276.2174243927002,0.4755406379699707,0.0 -6000,0.3010619,0.24601421,,,,,,,,,,,,,, -6100,0.08374709,0.2885128,,,,,,,,,,,,,, -6200,0.11889995,0.2947234,,,,,,,,,,,,,, -6300,0.29899684,0.2516868,,,,,,,,,,,,,, -6323,,,0.7470765113830566,0.2650374174118042,0.7230641585625351,0.2872435699102243,3554.0,0.7403354809890743,0.2885712530652227,3581.0,1578.4904873371124,1859.4913849830627,1578.4904873371124,280.2893924713135,0.5003418922424316,0.0 -6400,0.1485045,0.2569887,,,,,,,,,,,,,, -6500,0.051899336,0.20906222,,,,,,,,,,,,,, -6600,0.24322747,0.1966787,,,,,,,,,,,,,, -6670,,,0.7462148666381836,0.264569810458592,0.7219730134619443,0.2871640730789691,3554.0,0.73911150538432,0.2885813091228009,3581.0,1658.508949995041,1943.62032198906,1658.508949995041,284.363648891449,0.5252387523651123,0.0 -6700,0.11618533,0.33645377,,,,,,,,,,,,,, -6800,0.063251786,0.30806798,,,,,,,,,,,,,, -6900,0.09041195,0.34454072,,,,,,,,,,,,,, -7000,0.25850224,0.23599207,,,,,,,,,,,,,, -7018,,,0.745499679020473,0.2654059273856027,0.7219162717184862,0.2875164592272791,3554.0,0.7390810985932701,0.2888908311662245,3581.0,1738.528511285782,2027.7436203956604,1738.528511285782,288.43237042427063,0.5486984252929688,0.0 -7100,0.17099896,0.28328955,,,,,,,,,,,,,, -7200,0.11657627,0.3139846,,,,,,,,,,,,,, -7300,0.09019787,0.281382,,,,,,,,,,,,,, -7366,,,0.7471366609845843,0.2649506500789097,0.7230871712550999,0.2871706334137152,3554.0,0.7401437000401424,0.2886655413881422,3581.0,1818.5980372428887,2111.922720432281,1818.5980372428887,292.50692319869995,0.5723929405212402,0.0 -7400,0.17899609,0.21628088,,,,,,,,,,,,,, -7500,0.17904757,0.22918083,,,,,,,,,,,,,, -7600,0.1869908,0.21273714,,,,,,,,,,,,,, -7700,0.20082739,0.19402641,,,,,,,,,,,,,, -7712,,,0.74489654813494,0.2655230249677385,0.7204773945290518,0.2879726944442002,3554.0,0.7377579259459648,0.2892601100534941,3581.0,1898.6324636936188,2196.0656032562256,1898.6324636936188,296.5782120227814,0.5983188152313232,0.0 -7800,0.12208676,0.29802403,,,,,,,,,,,,,, -7900,0.08308976,0.35803717,,,,,,,,,,,,,, -8000,0.11416018,0.24554859,,,,,,,,,,,,,, -8062,,,0.7482733726501465,0.2648387125560215,0.7244876481605234,0.2870159675112989,3554.0,0.7416393596411617,0.2883994137897933,3581.0,1978.7172708511357,2280.256668329239,1978.7172708511357,300.6479024887085,0.6236989498138428,0.0 -8062,,,,,,,,,,,1978.7172708511353,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 26c2491f7..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,370 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -36.71697783470154,0.0,51.41184377670288,1,0,51.41184377670288,0.0017000001389533,6.912716865539551,10000,88.12890410423279,0.0008769132546149,6.912364482879639,0.0011399999493733,6.912585258483887,50000 -54.24302554130554,0.0274267196655273,561.3822112083435,1519,0,561.3822112083435,0.1094000041484832,4.919663429260254,10000,615.7079174518585,0.161371961236,4.343627452850342,0.1440799981355667,4.503424167633057,50000 -71.78705096244812,0.0556445121765136,1071.4606931209564,3037,0,1071.4606931209564,0.239200010895729,3.840391159057617,10000,1143.4133098125458,0.3479352593421936,3.004847288131714,0.3234999775886535,3.174034833908081,50000 -89.53337836265564,0.083956241607666,1581.70814204216,4558,0,1581.70814204216,0.323600023984909,3.25213885307312,10000,1671.4895164966583,0.4669762253761291,2.329172372817993,0.4377799928188324,2.5032687187194824,50000 -107.06326746940611,0.1095569133758544,2091.871442079544,6079,0,2091.871442079544,0.3765000104904175,2.998728036880493,10000,2199.260755300522,0.5306122303009033,2.0072100162506104,0.4969799816608429,2.2160658836364746,50000 -124.62890672683716,0.1368377208709716,2602.244828939438,7601,0,2602.244828939438,0.417600005865097,2.7481794357299805,10000,2727.28044962883,0.5672432780265808,1.8172202110290527,0.5317000150680542,2.017817258834839,50000 -142.38618516921997,0.1638340950012207,3112.356943130493,9123,0,3112.356943130493,0.4397000074386596,2.59698748588562,10000,3255.2293422222137,0.6525231003761292,1.404765486717224,0.5640599727630615,1.8469562530517576,50000 -160.78351044654846,0.1913130283355713,3622.518341064453,10645,0,3622.518341064453,0.4448000192642212,2.552358627319336,10000,3783.868696928024,0.6332509517669678,1.4783037900924685,0.5731399655342102,1.8122446537017824,50000 -178.64738535881042,0.2245426177978515,4132.719137430191,12167,0,4132.719137430191,0.4365000128746032,2.630974054336548,10000,4312.018725633621,0.6190608739852905,1.5576162338256836,0.5658199787139893,1.834885835647583,50000 -196.4004583358765,0.2521340847015381,4642.680076122284,13689,0,4642.680076122284,0.4770000278949737,2.4406228065490723,10000,4839.813012838364,0.6477000713348389,1.4246323108673096,0.5928399562835693,1.7087825536727903,50000 -214.842568397522,0.289806604385376,5152.619967222214,15211,0,5152.619967222214,0.4726000130176544,2.4248533248901367,10000,5368.286201000214,0.6474409699440002,1.4198980331420898,0.5953800082206726,1.6889138221740725,50000 -233.00697684288025,0.3311009407043457,5662.831291437149,16734,0,5662.831291437149,0.468500018119812,2.439729690551758,10000,5896.757298469544,0.6482182741165161,1.4272924661636353,0.6008999943733215,1.6777706146240234,50000 -251.19899654388428,0.3682253360748291,6173.075484752655,18258,0,6173.075484752655,0.4749000370502472,2.4517030715942383,10000,6425.283870458603,0.6813018321990967,1.2537661790847778,0.6010000109672546,1.6689846515655518,50000 -270.4668297767639,0.4093070030212402,6683.053819656372,19781,0,6683.053819656372,0.4679000079631805,2.46029019355774,10000,6954.625028371811,0.6559510231018066,1.365241765975952,0.5920000076293945,1.7166610956192017,50000 -292.76573491096497,0.4507136344909668,7193.051939487457,21304,0,7193.051939487457,0.4934000372886657,2.318434953689575,10000,7487.018528699875,0.6759207248687744,1.2931432723999023,0.6177799701690674,1.5928702354431152,50000 -316.6596989631653,0.4778110980987549,7703.096212387085,22827,0,7703.096212387085,0.490200012922287,2.343901395797729,10000,8021.0365245342255,0.6716158986091614,1.3155529499053955,0.6108999848365784,1.6090680360794067,50000 -338.99554920196533,0.5040531158447266,8213.247112512589,24351,0,8213.247112512589,0.4942000210285187,2.327371120452881,10000,8553.60300731659,0.6676697731018066,1.3202012777328491,0.6128999590873718,1.6204156875610352,50000 -363.0035355091095,0.53145432472229,8723.437615156174,25875,0,8723.437615156174,0.5009000301361084,2.293175458908081,10000,9087.882278203964,0.71097731590271,1.140007257461548,0.6279399991035461,1.5446945428848269,50000 -387.4523963928223,0.5643725395202637,9233.4411110878,27398,0,9233.4411110878,0.4933000206947326,2.3293328285217285,10000,9622.420518398283,0.697265625,1.176541090011597,0.623479962348938,1.5636550188064575,50000 -411.3435943126679,0.5972199440002441,9743.443282604218,28922,0,9743.443282604218,0.4853000342845917,2.3837692737579346,10000,10156.39930844307,0.6786311864852905,1.2781400680541992,0.6108599901199341,1.6249542236328125,50000 -436.1221706867218,0.6296067237854004,10253.598109006882,30446,0,10253.598109006882,0.4905000329017639,2.291940689086914,10000,10691.41821050644,0.6844507455825806,1.2520849704742432,0.6162399649620056,1.5863946676254272,50000 -459.7998206615448,0.6647074222564697,10763.550921201706,31970,0,10763.550921201706,0.5026000142097473,2.285898447036743,10000,11225.136703014374,0.6871213316917419,1.239983081817627,0.6253799796104431,1.5454217195510864,50000 -484.0868966579437,0.7003006935119629,11273.611344337463,33494,0,11273.611344337463,0.4869000315666199,2.363152027130127,10000,11759.572400569916,0.66796875,1.3174986839294434,0.6146999597549438,1.6063886880874634,50000 -508.71211862564087,0.7353689670562744,11783.628342866898,35018,0,11783.628342866898,0.5064000487327576,2.267633438110352,10000,12294.303589820862,0.7241310477256775,1.0600242614746094,0.6274799704551697,1.5452181100845337,50000 -533.906332731247,0.7664787769317627,12293.778350830078,36542,0,12293.778350830078,0.5040000081062317,2.234344959259033,10000,12829.732397556303,0.7159797549247742,1.0924153327941897,0.6389200091362,1.4935661554336548,50000 -556.6493656635284,0.8046107292175293,12803.850924253464,38066,0,12803.850924253464,0.497700035572052,2.274075508117676,10000,13362.639755010605,0.693757951259613,1.1842937469482422,0.6302199959754944,1.5333516597747805,50000 -578.2093193531036,0.8341331481933594,13313.921353816986,39590,0,13313.921353816986,0.5046000480651855,2.22383189201355,10000,13894.354410886765,0.6944355964660645,1.1813580989837646,0.6322799921035767,1.505229353904724,50000 -600.675609588623,0.8630940914154053,13823.889991521835,41114,0,13823.889991521835,0.5070000290870667,2.210167407989502,10000,14426.871874332428,0.6988998651504517,1.1760797500610352,0.6380599737167358,1.4915393590927124,50000 -623.4824783802032,0.8966398239135742,14334.058574199677,42639,0,14334.058574199677,0.4996000230312347,2.271448135375977,10000,14959.933693647385,0.6919044852256775,1.2171635627746582,0.632319986820221,1.5132615566253662,50000 -645.7052969932556,0.9284241199493408,14844.04565525055,44163,0,14844.04565525055,0.5051000118255615,2.2277207374572754,10000,15492.22899198532,0.7145447731018066,1.0940722227096558,0.633359968662262,1.494249939918518,50000 -667.4693231582642,0.9678726196289062,15354.069417715073,45687,0,15354.069417715073,0.5095000267028809,2.23911714553833,10000,16024.108890295029,0.7106783986091614,1.1211135387420654,0.6406999826431274,1.4792100191116333,50000 -687.938235282898,1.0020246505737305,15864.12791633606,47211,0,15864.12791633606,0.5143000483512878,2.2181320190429688,10000,16554.72419142723,0.6990792155265808,1.1850415468215942,0.6307199597358704,1.5144232511520386,50000 -706.0882074832916,1.0412023067474363,16374.140513420103,48734,0,16374.140513420103,0.517300009727478,2.1623549461364746,10000,17082.979299545288,0.71097731590271,1.1214135885238647,0.6475399732589722,1.454314947128296,50000 -725.4657227993011,1.08018159866333,16884.34206557274,50257,0,16884.34206557274,0.513700008392334,2.207813501358032,10000,17612.652390241623,0.7004743218421936,1.1613342761993408,0.6417399644851685,1.4614994525909424,50000 -744.2578091621399,1.111807346343994,17394.389335393906,51781,0,17394.389335393906,0.5035000443458557,2.28792405128479,10000,18141.577426195145,0.7308274507522583,1.039522409439087,0.6245399713516235,1.5539544820785522,50000 -762.2627856731415,1.1545979976654053,17904.58046245575,53306,0,17904.58046245575,0.5109000205993652,2.1753416061401367,10000,18669.87014555931,0.7253667116165161,1.0528578758239746,0.6495199799537659,1.4327813386917114,50000 -779.59827876091,1.1950109004974363,18414.588896036148,54830,0,18414.588896036148,0.5209000110626221,2.2052316665649414,10000,19197.308834314343,0.7131297588348389,1.118245244026184,0.6433199644088745,1.4688860177993774,50000 -798.1648647785187,1.235038995742798,18924.505990982056,56354,0,18924.505990982056,0.4989000260829925,2.3033556938171387,10000,19725.887673854828,0.6937978267669678,1.2075111865997314,0.6305800080299377,1.5379019975662231,50000 -814.8658380508423,1.271176815032959,19434.580425977707,57879,0,19434.580425977707,0.5266000032424927,2.125169515609741,10000,20252.75336956978,0.7224768400192261,1.0849560499191284,0.6535800099372864,1.421253681182861,50000 -831.8755283355713,1.3114573955535889,19944.806575775143,59404,0,19944.806575775143,0.5297000408172607,2.1135621070861816,10000,20780.083662986755,0.7204440236091614,1.080566167831421,0.6579599976539612,1.4034535884857178,50000 -848.7308654785156,1.3519139289855957,20454.81146836281,60927,0,20454.81146836281,0.5304000377655029,2.15209436416626,10000,21307.03911685944,0.750398576259613,0.9552063941955566,0.6536399722099304,1.428490400314331,50000 -865.6110301017761,1.3921701908111572,20965.004019737244,62452,0,20965.004019737244,0.5215000510215759,2.1652936935424805,10000,21834.207016468048,0.7303292155265808,1.0258498191833496,0.6525799632072449,1.428292751312256,50000 -882.3958539962769,1.4340605735778809,21475.078028202057,63977,0,21475.078028202057,0.5161000490188599,2.201335906982422,10000,22361.16133069992,0.7126116156578064,1.0953108072280884,0.6471799612045288,1.4452688694000244,50000 -899.1315841674805,1.4839389324188232,21985.282883882523,65503,0,21985.282883882523,0.5282000303268433,2.1661126613616943,10000,22888.20652484893,0.7235730290412903,1.0580490827560425,0.6538400053977966,1.4230231046676636,50000 -915.8876085281372,1.521956205368042,22495.20993900299,67028,0,22495.20993900299,0.539900004863739,2.0818004608154297,10000,23414.981965065,0.7344347834587097,1.0239489078521729,0.6676200032234192,1.360720157623291,50000 -932.5114860534668,1.5625011920928955,23005.27852010727,68553,0,23005.27852010727,0.5121000409126282,2.1813247203826904,10000,23941.77029824257,0.7496412396430969,0.9645988941192628,0.641319990158081,1.464681625366211,50000 -949.3906710147858,1.6044843196868896,23515.206540346146,70078,0,23515.206540346146,0.5301000475883484,2.1011109352111816,10000,24468.674884796143,0.7494021058082581,0.9496089220046996,0.6649399995803833,1.3731956481933594,50000 -966.1160097122192,1.6491477489471436,24025.367934703827,71603,0,24025.367934703827,0.5294000506401062,2.140475511550904,10000,24995.66038680077,0.7323421239852905,1.0162076950073242,0.6559000015258789,1.420893669128418,50000 -982.704963684082,1.689319372177124,24535.35845017433,73128,0,24535.35845017433,0.5420000553131104,2.0839552879333496,10000,25522.33491587639,0.7429647445678711,0.9706432223320008,0.6710000038146973,1.3437200784683228,50000 -999.3370826244354,1.7352948188781738,25045.55570292473,74653,0,25045.55570292473,0.51500004529953,2.1855597496032715,10000,26049.26572537422,0.7179328799247742,1.0882242918014526,0.6489799618721008,1.4370622634887695,50000 -1016.7017018795012,1.7748703956604004,25555.715767622,76178,0,25555.715767622,0.5300000309944153,2.1334023475646973,10000,26576.88431286812,0.73246169090271,1.026774287223816,0.6608399748802185,1.3760240077972412,50000 -1033.5776641368866,1.8206498622894287,26065.67103695869,77702,0,26065.67103695869,0.5368000268936157,2.1114277839660645,10000,27103.81628870964,0.7727798223495483,0.8478021025657654,0.670199990272522,1.3518798351287842,50000 -1050.3322749137878,1.860451936721801,26575.838837385178,79227,0,26575.838837385178,0.5470000505447388,2.039279699325561,10000,27630.83352947235,0.7599848508834839,0.9066538214683532,0.6744799613952637,1.3212937116622925,50000 -1066.9713323116302,1.9235823154449463,27085.72404384613,80751,0,27085.72404384613,0.5384000539779663,2.07574200630188,10000,28157.477063655853,0.7531289458274841,0.9358610510826112,0.6725999712944031,1.338200569152832,50000 -1083.7439770698547,1.9646568298339844,27595.813599586487,82276,0,27595.813599586487,0.5415000319480896,2.055564641952514,10000,28684.43613266945,0.7472097873687744,0.9645992517471312,0.6704999804496765,1.3342227935791016,50000 -1100.5029304027555,2.0091564655303955,28105.84876608849,83801,0,28105.84876608849,0.5383000373840332,2.107455253601074,10000,29211.3304066658,0.7385004758834839,0.9893120527267456,0.664900004863739,1.3621970415115356,50000 -1117.165348291397,2.049769401550293,28616.067785024643,85327,0,28616.067785024643,0.5385000109672546,2.0425949096679688,10000,29738.30677652359,0.7606425285339355,0.9143653512001038,0.6753199696540833,1.3267306089401243,50000 -1133.8120720386505,2.0952630043029785,29126.03837108612,86852,0,29126.03837108612,0.5397000312805176,2.076277732849121,10000,30265.02364993096,0.7731385231018066,0.8398574590682983,0.6742599606513977,1.3247716426849363,50000 -1150.2982242107391,2.140028476715088,29636.034289360046,88376,0,29636.034289360046,0.5455000400543213,2.0423593521118164,10000,30791.60491251945,0.7651267647743225,0.8875182867050171,0.6764000058174133,1.306697130203247,50000 -1167.15305352211,2.1838526725769043,30145.949870347977,89900,0,30145.949870347977,0.5515000224113464,2.004889488220215,10000,31318.47411704064,0.7606425285339355,0.8991544842720032,0.6783599853515625,1.304556965827942,50000 -1184.3367433547974,2.234565019607544,30655.910806179047,91424,0,30655.910806179047,0.5540000200271606,2.0211398601531982,10000,31845.72444677353,0.7586694955825806,0.9094310998916626,0.6788399815559387,1.301895022392273,50000 -1201.0060527324677,2.280010223388672,31165.815375089645,92948,0,31165.815375089645,0.5422000288963318,2.0770034790039062,10000,32372.397426128387,0.7504783272743225,0.9368448853492736,0.6730200052261353,1.32490336894989,50000 -1217.6655213832855,2.319309711456299,31675.780433416367,94472,0,31675.780433416367,0.5512000322341919,2.0316712856292725,10000,32899.11539578438,0.802754282951355,0.7244368195533752,0.6849600076675415,1.2787599563598633,50000 -1234.2645723819733,2.3646020889282227,32186.006425857544,95996,0,32186.006425857544,0.530500054359436,2.1422789096832275,10000,33426.04164767265,0.7570551633834839,0.9185478091239928,0.6616399884223938,1.392592191696167,50000 -1251.119171857834,2.4251294136047363,32696.02758693695,97520,0,32696.02758693695,0.5635000467300415,1.9730379581451416,10000,33953.034247636795,0.7798349857330322,0.8255358338356018,0.6879799962043762,1.2594505548477173,50000 -1267.745992422104,2.4731388092041016,33206.21092581749,99045,0,33206.21092581749,0.5490000247955322,2.031160354614258,10000,34479.94657087326,0.7669602632522583,0.872205913066864,0.6832599639892578,1.2858682870864868,50000 -1284.3461983203888,2.517145156860352,33716.131663799286,100569,0,33716.131663799286,0.5618000030517578,2.010899305343628,10000,35006.567068099976,0.7652662396430969,0.8769233822822571,0.6861599683761597,1.2804961204528809,50000 -1301.061465740204,2.5662214756011963,34226.252163648605,102094,0,34226.252163648605,0.5538000464439392,1.9944852590560915,10000,35533.50708389282,0.7729990482330322,0.8372607827186584,0.6893799901008606,1.2580443620681765,50000 -1317.6769466400146,2.6166603565216064,34736.3680062294,103619,0,34736.3680062294,0.5582000017166138,2.011260747909546,10000,36060.34365081787,0.7914739847183228,0.7607530951499939,0.6838200092315674,1.2859368324279783,50000 -1334.381745815277,2.665559768676758,35246.5287566185,105144,0,35246.5287566185,0.5649999976158142,1.953182339668274,10000,36587.31213951111,0.7905572056770325,0.7790660858154297,0.6904199719429016,1.2633452415466309,50000 -1351.1192100048063,2.7153921127319336,35756.58667945862,106669,0,35756.58667945862,0.5599000453948975,2.0167007446289062,10000,37114.21175909042,0.7826650142669678,0.8016412258148193,0.6915199756622314,1.2654014825820925,50000 -1367.7784917354584,2.7660505771636963,36266.73639631271,108194,0,36266.73639631271,0.5564000010490417,2.0124053955078125,10000,37641.12651062012,0.7798748016357422,0.8180897831916809,0.6882199645042419,1.2640862464904783,50000 -1384.3214082717896,2.8171417713165283,36776.9069879055,109719,0,36776.9069879055,0.5622000098228455,1.9760229587554927,10000,38167.94488573074,0.786531388759613,0.7900282740592957,0.6940799951553345,1.2394052743911743,50000 -1401.1726398468018,2.862706422805786,37287.03126382828,111244,0,37287.03126382828,0.5749000310897827,1.9428856372833248,10000,38695.02125692368,0.8427534699440002,0.5710492134094238,0.6987000107765198,1.2164056301116943,50000 -1417.827962398529,2.9112823009490967,37797.02880716324,112768,0,37797.02880716324,0.5710000395774841,1.9627013206481927,10000,39221.77711343765,0.812898576259613,0.6848605871200562,0.698419988155365,1.2241036891937256,50000 -1434.3773369789124,2.959364175796509,38307.227852106094,114293,0,38307.227852106094,0.5749000310897827,1.921513319015503,10000,39748.62899613381,0.8019969463348389,0.7204362154006958,0.7006999850273132,1.2203973531723022,50000 -1450.955587387085,3.0082755088806152,38817.41033864021,115818,0,38817.41033864021,0.5778000354766846,1.939614176750183,10000,40275.49406194687,0.8029735088348389,0.7088333964347839,0.7026199698448181,1.2160744667053225,50000 -1467.459722518921,3.056226253509521,39327.47944116592,117343,0,39327.47944116592,0.5725000500679016,1.961987257003784,10000,40802.17005300522,0.7983697056770325,0.7344706654548645,0.7015999555587769,1.2197657823562622,50000 -1484.165627002716,3.1047351360321045,39837.70289897919,118869,0,39837.70289897919,0.5705000162124634,1.956821322441101,10000,41329.202719688416,0.8006417155265808,0.7289294600486755,0.700439989566803,1.2137749195098877,50000 -1500.6749844551086,3.155070304870605,40347.663089990616,120394,0,40347.663089990616,0.579800009727478,1.906605124473572,10000,41855.777509212494,0.8328284025192261,0.6055227518081665,0.7037999629974365,1.199912428855896,50000 -1517.159220457077,3.204002857208252,40857.58970141411,121919,0,40857.58970141411,0.5807000398635864,1.9057252407073968,10000,42382.29209589958,0.82425856590271,0.62799072265625,0.7076799869537354,1.1814346313476562,50000 -1533.836226463318,3.2530641555786133,41367.651956796646,123444,0,41367.651956796646,0.5823000073432922,1.890337824821472,10000,42909.13440656662,0.8248166441917419,0.625325620174408,0.7119799852371216,1.170873522758484,50000 -1550.5613188743591,3.304357051849365,41877.71378183365,124969,0,41877.71378183365,0.5931000113487244,1.8835399150848389,10000,43436.02739739418,0.8237603306770325,0.6367748379707336,0.7152199745178223,1.161817193031311,50000 -1567.2161169052124,3.3536806106567383,42387.83588695526,126494,0,42387.83588695526,0.5841000080108643,1.881402969360352,10000,43962.90839862824,0.8254543542861938,0.6237724423408508,0.7156199812889099,1.1618980169296265,50000 -1583.872680425644,3.401171922683716,42897.97383713722,128019,0,42897.97383713722,0.5861000418663025,1.9103562831878664,10000,44489.80436134338,0.8600525856018066,0.4976900219917297,0.7145199775695801,1.1730865240097046,50000 -1600.5715289115906,3.4611611366271973,43407.957873106,129544,0,43407.957873106,0.5879000425338745,1.882236361503601,10000,45016.60299420357,0.8498086333274841,0.5317939519882202,0.7183399796485901,1.1576786041259766,50000 -1617.2234988212583,3.512183427810669,43918.01558470726,131069,0,43918.01558470726,0.5843000411987305,1.9149274826049805,10000,45543.41831231117,0.8396045565605164,0.5685706734657288,0.7088800072669983,1.184848427772522,50000 -1633.9201967716217,3.5640017986297607,44428.10785365105,132594,0,44428.10785365105,0.5939000248908997,1.8933364152908323,10000,46070.31389427185,0.8439094424247742,0.5533644556999207,0.7190600037574768,1.15404212474823,50000 -1650.5535411834717,3.6218745708465576,44938.21067500114,134120,0,44938.21067500114,0.5896000266075134,1.887064933776856,10000,46597.1632874012,0.8478555083274841,0.5405254364013672,0.7234199643135071,1.133036494255066,50000 -1667.193475484848,3.674959659576416,45448.31420826912,135646,0,45448.31420826912,0.593000054359436,1.909881353378296,10000,47124.01470160484,0.8484534025192261,0.5400087833404541,0.7185199856758118,1.1664830446243286,50000 -1683.880009889603,3.7243666648864746,45958.44335961342,137172,0,45958.44335961342,0.5922000408172607,1.8785035610198968,10000,47650.93498468399,0.8808394074440002,0.4171130955219269,0.7249799966812134,1.1274539232254028,50000 -1700.59228515625,3.780108690261841,46468.43311071396,138698,0,46468.43311071396,0.6017000079154968,1.8933743238449097,10000,48177.746673583984,0.8705556392669678,0.4483226835727691,0.7280600070953369,1.13063383102417,50000 -1717.3363778591156,3.829201221466065,46978.51990914345,140223,0,46978.51990914345,0.6011000275611877,1.8609994649887085,10000,48704.682052373886,0.8729272484779358,0.4394337832927704,0.7323200106620789,1.1090658903121948,50000 -1733.9795546531675,3.879467010498047,47488.66160154343,141749,0,47488.66160154343,0.602400004863739,1.851275086402893,10000,49231.571748018265,0.8720304369926453,0.4448748230934143,0.7297999858856201,1.107924222946167,50000 -1750.5503687858582,3.9371063709259033,47998.768881082535,143275,0,47998.768881082535,0.6026000380516052,1.867939591407776,10000,49758.36278486252,0.8688217401504517,0.4498252868652344,0.7297799587249756,1.1145541667938232,50000 -1767.1221826076508,3.994133949279785,48508.867156744,144800,0,48508.867156744,0.6032000184059143,1.8936916589736936,10000,50285.14401555061,0.89164137840271,0.3792960345745086,0.7303000092506409,1.120832443237305,50000 -1783.9339122772217,4.045078992843628,49019.03108620644,146325,0,49019.03108620644,0.6065000295639038,1.8500038385391235,10000,50812.22521138191,0.8965840339660645,0.36079141497612,0.7327799797058105,1.1133227348327637,50000 -1800.472449302673,4.096417427062988,49529.17516756058,147851,0,49529.17516756058,0.6067000031471252,1.86048686504364,10000,51339.01440405846,0.8950892686843872,0.3633810877799988,0.7359399795532227,1.0984907150268557,50000 -1817.333134889603,4.147741079330444,50039.28535270691,149376,0,50039.28535270691,0.6077000498771667,1.858315110206604,10000,51866.091879844666,0.8971220850944519,0.3583529591560364,0.7389199733734131,1.0933120250701904,50000 -1833.978541135788,4.200202941894531,50549.420686244965,150901,0,50549.420686244965,0.6073000431060791,1.8836166858673096,10000,52392.9816904068,0.8970025181770325,0.3487862348556518,0.740339994430542,1.1000778675079346,50000 -1850.6516172885893,4.255874156951904,51059.445302248,152426,0,51059.445302248,0.6089000105857849,1.82943332195282,10000,52919.78895497322,0.9053730964660645,0.3358666598796844,0.738319993019104,1.0839452743530271,50000 -1867.9935710430143,4.310412168502808,51569.52954649925,153951,0,51569.52954649925,0.6115000247955322,1.877955794334412,10000,53447.32394480705,0.9297871589660645,0.2487999647855758,0.7396799921989441,1.0863995552062988,50000 -1884.5948798656464,4.365032911300659,52079.50380158424,155476,0,52079.50380158424,0.617400050163269,1.85482394695282,10000,53974.00850534439,0.919343888759613,0.2705561220645904,0.743939995765686,1.0910815000534058,50000 -1901.0806233882904,4.418842077255249,52589.87325882912,157001,0,52589.87325882912,0.6187000274658203,1.8395804166793823,10000,54500.97246050835,0.9226123690605164,0.2707573473453522,0.7443199753761292,1.079198122024536,50000 -1917.6260554790497,4.475511312484741,53099.86465787888,158525,0,53099.86465787888,0.6181000471115112,1.8403584957122805,10000,55027.62099266052,0.9240074753761292,0.2637510001659393,0.7441999912261963,1.0730669498443604,50000 -1934.39634346962,4.528732776641846,53610.03590750694,160050,0,53610.03590750694,0.6213000416755676,1.8441084623336792,10000,55554.67021775246,0.9275350570678712,0.2511384785175323,0.7457000017166138,1.075924515724182,50000 -1950.8956489562988,4.591880559921265,54120.14770627022,161575,0,54120.14770627022,0.6184000372886658,1.8437731266021729,10000,56081.3993742466,0.9306042790412904,0.2416285127401352,0.7472400069236755,1.0739290714263916,50000 -1967.4637095928192,4.646895170211792,54630.08940792084,163099,0,54630.08940792084,0.6213000416755676,1.836686849594116,10000,56608.01881337166,0.9433194994926452,0.2045804262161255,0.7457999587059021,1.067560791969299,50000 -1984.099481344223,4.70342230796814,55140.205631017685,164624,0,55140.205631017685,0.6182000041007996,1.8608922958374023,10000,57134.88123440743,0.9432597160339355,0.200706347823143,0.7478399872779846,1.0719140768051147,50000 -2000.7171757221224,4.762126207351685,55650.33197426796,166149,0,55650.33197426796,0.6215000152587891,1.8464528322219849,10000,57661.73883676529,0.9441764950752258,0.1963724493980407,0.7493000030517578,1.0691767930984497,50000 -2017.3095281124115,4.821733713150024,56160.3334069252,167673,0,56160.3334069252,0.6243000030517578,1.844048857688904,10000,58188.44730424881,0.9439173936843872,0.1987826824188232,0.7494999766349792,1.0682849884033203,50000 -2033.912229537964,4.877910375595093,56670.40325450897,169198,0,56670.40325450897,0.6219000220298767,1.847320795059204,10000,58715.23123264313,0.9467872977256776,0.1883253902196884,0.7502999901771545,1.0689047574996948,50000 -2050.458500146866,4.937411308288574,57180.59017777443,170723,0,57180.59017777443,0.6260000467300415,1.84285831451416,10000,59242.078933000565,0.9580875039100648,0.1585330516099929,0.7505199909210205,1.065174221992493,50000 -2066.960639238357,4.997782468795776,57690.77718710899,172248,0,57690.77718710899,0.6273000240325928,1.85725200176239,10000,59768.884503126144,0.9566724896430968,0.1599704623222351,0.7523999810218811,1.0659375190734863,50000 -2083.6383051872253,5.058853149414063,58200.87525296211,173772,0,58200.87525296211,0.628000020980835,1.849697709083557,10000,60295.775837898254,0.9553371667861938,0.1635049283504486,0.7538999915122986,1.0613198280334473,50000 -2100.2404623031616,5.118907690048218,58711.03318500519,175297,0,58711.03318500519,0.628000020980835,1.851643443107605,10000,60822.650486946106,0.9546197056770324,0.1615114212036132,0.7534799575805664,1.060207486152649,50000 -2116.711015462876,5.191303253173828,59220.95730185509,176821,0,59220.95730185509,0.6261000037193298,1.847219467163086,10000,61349.172367334366,0.9572106003761292,0.158622071146965,0.7540599703788757,1.0589282512664795,50000 -2133.417886257172,5.247831106185913,59730.89150452614,178345,0,59730.89150452614,0.6258000135421753,1.8444610834121704,10000,61875.924149274826,0.95804762840271,0.153394877910614,0.7551199793815613,1.0549086332321167,50000 -2149.941981315613,5.304441213607788,60240.81658220291,179869,0,60240.81658220291,0.6279000043869019,1.8462977409362795,10000,62402.48623228073,0.9615951776504515,0.1446188241243362,0.7554000020027161,1.056932806968689,50000 -2166.7651064395905,5.369521379470825,60750.70178294182,181393,0,60750.70178294182,0.6279000043869019,1.8437385559082031,10000,62929.314858675,0.96000075340271,0.1480083018541336,0.7560999989509583,1.0549815893173218,50000 -2183.3145368099213,5.4302978515625,61260.74364876747,182918,0,61260.74364876747,0.6274000406265259,1.8420039415359497,10000,63456.022471666336,0.9614157676696776,0.1445799022912979,0.7551400065422058,1.0538650751113892,50000 -2199.8553504943848,5.492738246917725,61770.76652598381,184442,0,61770.76652598381,0.6267000436782837,1.8438823223114007,10000,63982.70475888252,0.960598647594452,0.1481765508651733,0.7552399635314941,1.054193139076233,50000 -2216.459883451462,5.554203510284424,62280.75594091416,185967,0,62280.75594091416,0.6271000504493713,1.8431257009506223,10000,64509.414836645126,0.9606983065605164,0.1485299617052078,0.7558599710464478,1.0535261631011963,50000 -2233.191954135895,5.617657661437988,62790.8897960186,187492,0,62790.8897960186,0.626800000667572,1.8425025939941408,10000,65036.399695158005,0.9605787396430968,0.146380066871643,0.7558000087738037,1.053703546524048,50000 -2249.7101545333862,5.676936149597168,63300.8772277832,189017,0,63300.8772277832,0.6281000375747681,1.843846201896668,10000,65563.01993250847,0.9619937539100648,0.14264976978302,0.7558599710464478,1.0539004802703855,50000 -2266.336406469345,5.737735748291016,63811.00440359116,190542,0,63811.00440359116,0.6265000104904175,1.8426756858825684,10000,66089.88953256607,0.9632294178009032,0.142451986670494,0.7558799982070923,1.0543040037155151,50000 -2282.8199610710144,5.799633979797363,64321.091645240784,192067,0,64321.091645240784,0.6270000338554382,1.8448634147644043,10000,66616.57750272751,0.9621531963348388,0.1447095423936844,0.7556399703025818,1.0540560483932495,50000 -2300.1227231025696,5.86241602897644,64831.305763959885,193593,0,64831.305763959885,0.6272000074386597,1.841867446899414,10000,67144.21295118332,0.9602199792861938,0.1460577994585037,0.7560200095176697,1.0536680221557615,50000 -2316.826899766922,5.926252603530884,65341.203255176544,195118,0,65341.203255176544,0.6266000270843506,1.84319007396698,10000,67670.93338441849,0.9604591727256776,0.1471811532974243,0.7560399770736694,1.053143858909607,50000 -2333.264596939087,5.989895343780518,65851.13903975487,196643,0,65851.13903975487,0.6277000308036804,1.840379357337952,10000,68197.4244966507,0.961933970451355,0.1429389268159866,0.756119966506958,1.0525805950164795,50000 -2349.8306188583374,6.049070358276367,66361.20763134956,198168,0,66361.20763134956,0.627500057220459,1.841537594795227,10000,68724.17261481285,0.9618542790412904,0.1434179842472076,0.7559199929237366,1.052737832069397,50000 -2366.4155824184418,6.112751245498657,66871.37352252007,199693,0,66871.37352252007,0.626800000667572,1.8423495292663568,10000,69251.04145288467,0.9616350531578064,0.1440503299236297,0.7558000087738037,1.0533536672592163,50000 -2383.205046653748,6.177959680557251,67381.27967381477,201217,0,67381.27967381477,0.6271000504493713,1.840630888938904,10000,69777.85696268082,0.960160195827484,0.1471266746520996,0.7561599612236023,1.0531275272369385,50000 -2399.942197084427,6.244342565536499,67891.42406487465,202742,0,67891.42406487465,0.6278000473976135,1.8418047428131104,10000,70304.85989117622,0.9604591727256776,0.1461871117353439,0.755620002746582,1.0531686544418335,50000 -2416.4498698711395,6.305903196334839,68401.58019638062,204268,0,68401.58019638062,0.6273000240325928,1.84238874912262,10000,70831.64112019539,0.9620735049247742,0.1424334943294525,0.7561399936676025,1.0534887313842771,50000 -2432.962220907212,6.379040241241455,68911.47825837135,205793,0,68911.47825837135,0.6279000043869019,1.840342879295349,10000,71358.17896771431,0.9616150856018066,0.145903930068016,0.755840003490448,1.0521328449249268,50000 -2449.6562552452087,6.44709062576294,69421.45237731934,207318,0,69421.45237731934,0.6272000074386597,1.8419071435928345,10000,71884.97119355202,0.9592832922935486,0.1478613168001175,0.7560399770736694,1.0538548231124878,50000 -2466.2402641773224,6.519393682479858,69931.56385087967,208843,0,69931.56385087967,0.6274000406265259,1.8434849977493288,10000,72411.7939863205,0.9614756107330322,0.1456343531608581,0.7562199831008911,1.0526604652404783,50000 -2482.894499540329,6.58108377456665,70441.48584985733,210368,0,70441.48584985733,0.6283000111579895,1.842454433441162,10000,72938.48751044273,0.9622129797935486,0.1451591402292251,0.7560799717903137,1.0537360906600952,50000 -2499.4245150089264,6.648257732391357,70951.64380955696,211893,0,70951.64380955696,0.6271000504493713,1.8440243005752563,10000,73465.29731369019,0.962312638759613,0.1424638777971267,0.7555199861526489,1.0546331405639648,50000 -2516.0319879055023,6.712458610534668,71461.56416463852,213418,0,71461.56416463852,0.6278000473976135,1.843787789344788,10000,73991.9448082447,0.9602997303009032,0.1486295908689499,0.7561999559402466,1.0539764165878296,50000 -2532.8387157917023,6.777054309844971,71971.70177531242,214943,0,71971.70177531242,0.626800000667572,1.8391730785369875,10000,74519.00781488419,0.9592633843421936,0.1488154977560043,0.7554599642753601,1.0519776344299316,50000 -2549.4425597190857,6.845051527023315,72481.71206188202,216468,0,72481.71206188202,0.6269000172615051,1.8440710306167605,10000,75045.74508166313,0.961336076259613,0.145295962691307,0.756060004234314,1.0536319017410278,50000 -2565.9037024974823,6.909708738327026,72991.72962641716,217992,0,72991.72962641716,0.6264000535011292,1.841581106185913,10000,75572.34304332733,0.9616350531578064,0.1456587761640548,0.755899965763092,1.0531891584396362,50000 -2582.430620908737,6.99518346786499,73501.86399936676,219517,0,73501.86399936676,0.6279000043869019,1.8403915166854856,10000,76099.1445813179,0.9599609375,0.1475105881690979,0.7558000087738037,1.052554965019226,50000 -2599.18460726738,7.058007955551148,74011.82157802582,221041,0,74011.82157802582,0.6272000074386597,1.843128561973572,10000,76625.97312402725,0.9603196382522584,0.1453995704650879,0.7556799650192261,1.0543261766433716,50000 -2615.8400671482086,7.123383760452271,74521.89327073097,222567,0,74521.89327073097,0.6279000043869019,1.84185791015625,10000,77152.81997966766,0.961355984210968,0.1447490602731704,0.7556399703025818,1.052621603012085,50000 -2632.435756444931,7.188803672790527,75031.88020396233,224092,0,75031.88020396233,0.6265000104904175,1.8424419164657595,10000,77679.52306723595,0.9618343114852904,0.1450876146554947,0.7557799816131592,1.0538315773010254,50000 -2649.096224308014,7.255248308181763,75541.82284140587,225617,0,75541.82284140587,0.6282000541687012,1.8433343172073364,10000,78206.24735283852,0.9606385231018066,0.1497254520654678,0.7558799982070923,1.053784728050232,50000 -2665.7563643455505,7.321617126464844,76052.00051903725,227143,0,76052.00051903725,0.6274000406265259,1.8423489332199097,10000,78733.20663499832,0.9608577489852904,0.1467025727033615,0.7559599876403809,1.0523626804351809,50000 -2682.2827892303467,7.397505044937134,76561.98649954796,228667,0,76561.98649954796,0.6272000074386597,1.8436956405639648,10000,79259.84963774681,0.962890625,0.1419253796339035,0.7560200095176697,1.0544745922088623,50000 -2699.087404489517,7.466514587402344,77071.98637270927,230192,0,77071.98637270927,0.627500057220459,1.8432239294052124,10000,79786.77798008919,0.962133288383484,0.1451293677091598,0.7562199831008911,1.053865909576416,50000 -2716.8570907115936,7.537179708480835,77581.86835837364,231716,0,77581.86835837364,0.6273000240325928,1.8422167301177976,10000,80314.55636286736,0.960957407951355,0.1444891989231109,0.7560799717903137,1.053759217262268,50000 -2733.3647408485413,7.594546318054199,78091.93625879288,233241,0,78091.93625879288,0.6271000504493713,1.8426487445831297,10000,80841.24501657486,0.9602000713348388,0.1468252092599868,0.7556799650192261,1.0530420541763306,50000 -2750.1510181427,7.663392543792725,78601.9127805233,234766,0,78601.9127805233,0.6276000142097473,1.8427488803863523,10000,81368.13152265549,0.961694836616516,0.1441184133291244,0.7562599778175354,1.0534793138504028,50000 -2766.808699846268,7.740317106246948,79112.02974677086,236291,0,79112.02974677086,0.6277000308036804,1.8410484790802,10000,81895.03840327263,0.961355984210968,0.1435891985893249,0.7560799717903137,1.0528885126113892,50000 -2783.3465077877045,7.810421228408813,79621.93251681328,237816,0,79621.93251681328,0.6279000043869019,1.8438612222671509,10000,82421.6057536602,0.961734652519226,0.1432908326387405,0.7560999989509583,1.054379105567932,50000 -2799.8947973251343,7.878611326217651,80132.0266327858,239341,0,80132.0266327858,0.6271000504493713,1.8431426286697388,10000,82948.37228608131,0.9607780575752258,0.1474478989839553,0.7555999755859375,1.0538862943649292,50000 -2816.4147934913635,7.948309659957886,80642.20257234573,240867,0,80642.20257234573,0.6279000043869019,1.843974471092224,10000,83475.19399499893,0.961316168308258,0.1430841088294983,0.7557399868965149,1.0552139282226562,50000 -2832.955675125122,8.019760608673096,81152.19447517395,242392,0,81152.19447517395,0.6281000375747681,1.843250632286072,10000,84001.85370469093,0.961535394191742,0.145120620727539,0.7561399936676025,1.054344654083252,50000 -2849.7008497715,8.091844081878662,81662.36057519913,243918,0,81662.36057519913,0.6273000240325928,1.8438608646392824,10000,84528.89245724678,0.9616748690605164,0.1430436074733734,0.7561799883842468,1.0540276765823364,50000 -2866.384510755539,8.159352779388428,82172.32745957375,245443,0,82172.32745957375,0.6276000142097473,1.8433830738067627,10000,85055.66488361359,0.9590441584587096,0.1470848023891449,0.755899965763092,1.0540157556533811,50000 -2882.9736964702606,8.230422973632812,82682.42796611786,246968,0,82682.42796611786,0.6270000338554382,1.843010663986206,10000,85582.4812579155,0.9604790806770324,0.1470487266778946,0.756060004234314,1.053966760635376,50000 -2899.4492712020874,8.30149245262146,83192.36039113998,248492,0,83192.36039113998,0.6277000308036804,1.8438180685043333,10000,86109.01463675499,0.9622329473495485,0.1438832581043243,0.7556799650192261,1.054873824119568,50000 -2916.1209042072296,8.375693798065186,83702.24043631554,250017,0,83702.24043631554,0.6274000406265259,1.8433725833892824,10000,86635.69530034065,0.9619140625,0.1431471705436706,0.756060004234314,1.0536892414093018,50000 -2932.704018354416,8.445951461791992,84212.27120828629,251542,0,84212.27120828629,0.6281000375747681,1.8413687944412231,10000,87162.43416666985,0.960718274116516,0.1478823870420456,0.7559199929237366,1.0535964965820312,50000 -2949.2420732975006,8.548237800598145,84722.27003097534,253067,0,84722.27003097534,0.6267000436782837,1.843252420425415,10000,87689.12922596931,0.9599011540412904,0.1474985778331756,0.7556599974632263,1.054338455200195,50000 -2965.849251270294,8.616014003753662,85232.41868042946,254592,0,85232.41868042946,0.627500057220459,1.8423309326171875,10000,88216.00778889656,0.9603196382522584,0.1466724723577499,0.7559599876403809,1.0540852546691897,50000 -2982.4352350234985,8.684788465499878,85742.36588978767,256117,0,85742.36588978767,0.6273000240325928,1.8429205417633057,10000,88742.66361165047,0.9610570669174194,0.1451389640569687,0.7560799717903137,1.0543256998062134,50000 -2999.1208930015564,8.751543760299683,86252.392973423,257642,0,86252.392973423,0.626800000667572,1.842377066612244,10000,89269.49858641624,0.9624322056770324,0.1431660652160644,0.7559599876403809,1.0529989004135132,50000 -3015.657775402069,8.82495379447937,86762.40819621086,259166,0,86762.40819621086,0.6267000436782837,1.8413981199264529,10000,89796.18056607246,0.959402859210968,0.1494778543710708,0.7557399868965149,1.0539062023162842,50000 -3032.370623588562,8.895149946212769,87272.5882089138,260691,0,87272.5882089138,0.6272000074386597,1.8419866561889648,10000,90323.19861745834,0.9612762928009032,0.14541095495224,0.755899965763092,1.0525975227355957,50000 -3048.934982776642,8.971031188964844,87782.54205465317,262215,0,87782.54205465317,0.6269000172615051,1.8433120250701904,10000,90849.84784388542,0.961136758327484,0.1466261595487594,0.7559399604797363,1.0539066791534424,50000 -3065.657157897949,9.04961919784546,88292.70301318169,263740,0,88292.70301318169,0.6279000043869019,1.843176245689392,10000,91376.86492967606,0.9615154266357422,0.146799087524414,0.7556999921798706,1.0533528327941897,50000 -3082.2946906089783,9.11883282661438,88802.68064022064,265265,0,88802.68064022064,0.626800000667572,1.8413385152816768,10000,91903.60470747948,0.9604192972183228,0.1475045830011367,0.755840003490448,1.0530436038970947,50000 -3098.7946536540985,9.190809488296509,89312.86989212036,266791,0,89312.86989212036,0.6271000504493713,1.8437012434005733,10000,92430.42082619669,0.9616549611091614,0.1436686366796493,0.7555999755859375,1.053915023803711,50000 -3115.280136346817,9.261037111282349,89823.01347899437,268315,0,89823.01347899437,0.6284000277519226,1.842480182647705,10000,92957.17655205728,0.9625318646430968,0.1464009433984756,0.7557199597358704,1.0538593530654907,50000 -3131.8267362117767,9.355196475982666,90333.03701162338,269840,0,90333.03701162338,0.6276000142097473,1.8441766500473025,10000,93483.89601063728,0.9630300998687744,0.1411333978176117,0.7552599906921387,1.0552451610565186,50000 -3149.2292981147766,9.430007934570312,90843.17112255096,271366,0,90843.17112255096,0.626800000667572,1.844045639038086,10000,94011.56312346458,0.960180163383484,0.1452442556619644,0.7560399770736694,1.054051399230957,50000 -3165.8072340488434,9.508738994598389,91353.35414123537,272892,0,91353.35414123537,0.6278000473976135,1.842432737350464,10000,94538.45765447617,0.959363043308258,0.1494936347007751,0.7558000087738037,1.0542312860488892,50000 -3182.3766729831696,9.585581302642822,91863.55009460448,274418,0,91863.55009460448,0.6262000203132629,1.8420876264572144,10000,95065.35556936264,0.962312638759613,0.1411945521831512,0.7556999921798706,1.053430438041687,50000 -3198.919429302216,9.665200233459473,92373.46947193146,275942,0,92373.46947193146,0.6269000172615051,1.843717694282532,10000,95591.95217609406,0.9617745280265808,0.1434799134731292,0.755620002746582,1.0534061193466189,50000 -3215.657660007477,9.739282369613647,92883.548268795,277466,0,92883.548268795,0.6273000240325928,1.8438247442245483,10000,96118.89766335487,0.961933970451355,0.1448633521795272,0.7560799717903137,1.0545258522033691,50000 -3232.344386816025,9.817296981811523,93393.59517908096,278989,0,93393.59517908096,0.6269000172615051,1.843254327774048,10000,96645.7634203434,0.960379421710968,0.1471159309148788,0.7558000087738037,1.0548198223114014,50000 -3248.8543951511383,9.893784284591677,93903.50009322166,280511,0,93903.50009322166,0.6272000074386597,1.842954397201538,10000,97172.31008005142,0.9603196382522584,0.1468950361013412,0.755840003490448,1.0539928674697876,50000 -3265.3635454177856,10.01119351387024,94413.65333604813,282035,0,94413.65333604813,0.6262000203132629,1.842752814292908,10000,97699.14486837389,0.9624720811843872,0.1407591849565506,0.7559399604797363,1.053431510925293,50000 -3281.825499534607,10.0881450176239,94923.66025543211,283558,0,94923.66025543211,0.6278000473976135,1.8434096574783323,10000,98225.74566698074,0.9611965417861938,0.1465084254741668,0.756060004234314,1.053581476211548,50000 -3298.5862398147583,10.166528940200806,95433.71224570274,285082,0,95433.71224570274,0.6277000308036804,1.8420534133911133,10000,98752.69161009789,0.9600207209587096,0.1455494463443756,0.7560200095176697,1.0529658794403076,50000 -3315.0661194324493,10.243273496627808,95943.7070596218,286605,0,95943.7070596218,0.6269000172615051,1.844022750854492,10000,99279.29757618904,0.9612563848495485,0.1459049135446548,0.7559799551963806,1.054235577583313,50000 -3331.6242294311523,10.32362985610962,96453.7833352089,288129,0,96453.7833352089,0.6270000338554382,1.8442836999893188,10000,99806.06667613985,0.9620535373687744,0.1462266147136688,0.7553799748420715,1.0544589757919312,50000 -3348.3108434677124,10.402316093444824,96963.97358345984,289653,0,96963.97358345984,0.6270000338554382,1.843524932861328,10000,100333.07609534264,0.9625717401504515,0.1420051157474517,0.7561999559402466,1.053595781326294,50000 -3364.91783618927,10.482990264892578,97473.8518486023,291176,0,97473.8518486023,0.6282000541687012,1.8418251276016235,10000,100859.69595813753,0.9592633843421936,0.1484569460153579,0.7561599612236023,1.0529533624649048,50000 -3381.796614646912,10.558565378189089,97983.82679986954,292699,0,97983.82679986954,0.6266000270843506,1.842485427856445,10000,101386.6808810234,0.9601402878761292,0.1482071578502655,0.7561799883842468,1.052278757095337,50000 -3398.401019334793,10.63289713859558,98493.7816810608,294222,0,98493.7816810608,0.626300036907196,1.8423364162445068,10000,101913.36937975883,0.960718274116516,0.1471677124500274,0.7562400102615356,1.0532735586166382,50000 -3414.94277882576,10.699851036071776,99003.77694940568,295746,0,99003.77694940568,0.6274000406265259,1.8412050008773804,10000,102440.0279521942,0.9618542790412904,0.144152745604515,0.756060004234314,1.0529661178588867,50000 -3431.464282512665,10.78010392189026,99513.71591758728,297270,0,99513.71591758728,0.6274000406265259,1.8436675071716309,10000,102966.62348175047,0.9604192972183228,0.145902469754219,0.7557599544525146,1.05452299118042,50000 -3448.119199991226,10.856154441833496,100023.7195968628,298794,0,100023.7195968628,0.6273000240325928,1.8408688306808472,10000,103493.4132258892,0.9609375,0.145173043012619,0.7562800049781799,1.0528514385223389,50000 -3464.7071413993835,10.93600869178772,100533.74055194856,300318,0,100533.74055194856,0.628000020980835,1.844310998916626,10000,104020.15699887276,0.9599409699440002,0.1476414948701858,0.7556399703025818,1.0543302297592163,50000 -3481.1808309555054,11.012544393539429,101043.92377829552,301843,0,101043.92377829552,0.628000020980835,1.8417717218399048,10000,104546.94538855553,0.9615154266357422,0.1449480205774307,0.7561799883842468,1.0539348125457764,50000 -3497.745079278946,11.105444431304932,101553.90138459206,303366,0,101553.90138459206,0.6267000436782837,1.8432505130767824,10000,105073.63487243652,0.9603196382522584,0.1485083252191543,0.7553399801254272,1.054638147354126,50000 -3514.2617585659027,11.190539121627808,102063.76321053503,304890,0,102063.76321053503,0.6270000338554382,1.8422266244888303,10000,105600.15265583992,0.9614756107330322,0.1455578804016113,0.7556599974632263,1.0540353059768677,50000 -3530.9448194503784,11.271536588668823,102573.88640189172,306414,0,102573.88640189172,0.626800000667572,1.841051697731018,10000,106127.09458732604,0.9628507494926452,0.1430933475494384,0.7557799816131592,1.0519496202468872,50000 -3547.7343316078186,11.354299306869509,103084.04380178452,307938,0,103084.04380178452,0.6266000270843506,1.842125654220581,10000,106654.18004608154,0.9618542790412904,0.1455347239971161,0.7562599778175354,1.0540136098861694,50000 -3565.0480420589447,11.41740655899048,103594.16372561456,309462,0,103594.16372561456,0.6272000074386597,1.8425387144088743,10000,107181.73225021362,0.9624919891357422,0.1417141854763031,0.7557399868965149,1.053352117538452,50000 -3581.7052421569824,11.497244358062744,104104.17977261543,310984,0,104104.17977261543,0.6276000142097473,1.842854619026184,10000,107708.53921103476,0.9591238498687744,0.149687573313713,0.7559199929237366,1.0534484386444092,50000 -3598.469397068024,11.572259664535522,104614.0937845707,312507,0,104614.0937845707,0.6279000043869019,1.8433014154434204,10000,108235.34715676308,0.9618343114852904,0.1447564959526062,0.7558799982070923,1.053959846496582,50000 -3615.086302042008,11.65283203125,105123.96016263962,314029,0,105123.96016263962,0.6272000074386597,1.8429999351501465,10000,108761.96542572977,0.9614556431770324,0.1427123695611953,0.7558000087738037,1.0546585321426392,50000 -3631.61336350441,11.731301069259644,105633.98216438292,315553,0,105633.98216438292,0.626800000667572,1.8422795534133911,10000,109288.64826965332,0.9622528553009032,0.1424500048160553,0.7561399936676025,1.053011775016785,50000 -3648.164837121964,11.824159622192385,106143.96684598924,317077,0,106143.96684598924,0.6277000308036804,1.8440029621124268,10000,109815.33223104475,0.9612762928009032,0.1458015292882919,0.7560399770736694,1.0532641410827637,50000 -3664.6562349796295,11.90298581123352,106654.22787618636,318601,0,106654.22787618636,0.6276000142097473,1.8443596363067627,10000,110342.21775627136,0.9610171914100648,0.1459923684597015,0.7558799982070923,1.0537469387054443,50000 -3681.319794178009,11.985785961151125,107164.28194069862,320125,0,107164.28194069862,0.6277000308036804,1.839996337890625,10000,110869.0744421482,0.9613759517669678,0.146775797009468,0.7560999989509583,1.0519914627075195,50000 -3699.360595703125,12.066577196121216,107674.23223781586,321629,0,107674.23223781586,0.6273000240325928,1.8425594568252563,10000,111397.20132136343,0.9612762928009032,0.1422847658395767,0.7560399770736694,1.0536316633224487,50000 -3716.6544332504272,12.149267673492432,108184.30120563509,323153,0,108184.30120563509,0.6271000504493713,1.844212532043457,10000,111924.70087218285,0.9610371589660645,0.145976260304451,0.7555999755859375,1.0547045469284058,50000 -3733.7039256095886,12.232320070266724,108694.235871315,324677,0,108694.235871315,0.6270000338554382,1.8409771919250488,10000,112451.8219499588,0.9600406289100648,0.1472441256046295,0.756119966506958,1.0521292686462402,50000 -3750.8754110336304,12.315529584884644,109204.26220989227,326201,0,109204.26220989227,0.6288000345230103,1.8415600061416624,10000,112979.1577911377,0.9612563848495485,0.1457024663686752,0.7558199763298035,1.053580641746521,50000 -3767.899305582048,12.398417234420776,109714.14976787569,327724,0,109714.14976787569,0.6270000338554382,1.84445321559906,10000,113506.20581841467,0.962332546710968,0.1425944864749908,0.7559399604797363,1.0540307760238647,50000 -3784.79211306572,12.491395235061646,110224.18912863731,329248,0,110224.18912863731,0.6265000104904175,1.841633915901184,10000,114033.2861776352,0.9614955186843872,0.1465763002634048,0.7559199929237366,1.0538091659545898,50000 -3801.7138855457306,12.616500854492188,110734.2365758419,330772,0,110734.2365758419,0.6279000043869019,1.84337055683136,10000,114560.4348537922,0.9596420526504515,0.1490146070718765,0.7561399936676025,1.0536932945251465,50000 -3818.649199962616,12.698453664779665,111244.34099173546,332296,0,111244.34099173546,0.6276000142097473,1.843773365020752,10000,115087.61276984216,0.9601203799247742,0.1480126678943634,0.7555999755859375,1.053968071937561,50000 -3835.370189905167,12.785258054733276,111754.31213140488,333820,0,111754.31213140488,0.6271000504493713,1.841819643974304,10000,115614.44660925864,0.9615951776504515,0.1442433297634124,0.7560200095176697,1.0530657768249512,50000 -3852.092576265335,12.87584137916565,112264.47166848184,335344,0,112264.47166848184,0.6282000541687012,1.8424676656723025,10000,116141.47416877748,0.9606385231018066,0.1442461311817169,0.7555599808692932,1.053433537483215,50000 -3868.7485876083374,12.985454320907593,112774.3374478817,336868,0,112774.3374478817,0.626800000667572,1.841143488883972,10000,116668.15985655785,0.9594228267669678,0.1478379368782043,0.7559999823570251,1.0530071258544922,50000 -3885.435347318649,13.066871643066406,113284.472843647,338393,0,113284.472843647,0.6278000473976135,1.84160315990448,10000,117195.11727333067,0.9607780575752258,0.1457557678222656,0.7559199929237366,1.0537904500961304,50000 -3902.153110980988,13.152750968933104,113794.48411178587,339918,0,113794.48411178587,0.6276000142097473,1.843093991279602,10000,117721.9868991375,0.9614157676696776,0.1437328457832336,0.7555599808692932,1.0540083646774292,50000 -3918.9126632213593,13.23846983909607,114304.55684185028,341443,0,114304.55684185028,0.6276000142097473,1.843406319618225,10000,118248.95959162712,0.9618343114852904,0.1472954601049423,0.7556799650192261,1.0539937019348145,50000 -3935.580096244812,13.327780723571776,114814.5960021019,342968,0,114814.5960021019,0.6274000406265259,1.8429653644561768,10000,118775.81064414978,0.9596819281578064,0.14888696372509,0.7560799717903137,1.0539145469665527,50000 -3952.235322237015,13.418944597244264,115324.70213341711,344493,0,115324.70213341711,0.627500057220459,1.84093177318573,10000,119302.71807384492,0.9618741869926452,0.1440886259078979,0.7561799883842468,1.052625298500061,50000 -3968.7403090000153,13.50695252418518,115834.6356317997,346017,0,115834.6356317997,0.628000020980835,1.8426051139831543,10000,119829.2996737957,0.9626514315605164,0.1440666168928146,0.7559799551963806,1.0532101392745972,50000 -3985.4980008602142,13.59234380722046,116344.7417371273,347542,0,116344.7417371273,0.627500057220459,1.844247579574585,10000,120356.30416107178,0.9629504084587096,0.142887681722641,0.7559599876403809,1.0534168481826782,50000 -4002.7495708465576,13.679721355438232,116854.75827503204,349067,0,116854.75827503204,0.627500057220459,1.8426477909088133,10000,120883.71487736702,0.9593430757522584,0.1477753520011901,0.7563599944114685,1.05348002910614,50000 -4019.250088214874,13.76552438735962,117364.76794171332,350591,0,117364.76794171332,0.6266000270843506,1.8422508239746087,10000,121410.36559605598,0.9602199792861938,0.1474127322435379,0.7559199929237366,1.0539977550506592,50000 -4035.767923355103,13.854008674621582,117874.66694450378,352115,0,117874.66694450378,0.628000020980835,1.8433915376663208,10000,121936.92624044418,0.9626514315605164,0.1426747590303421,0.7557599544525146,1.054054856300354,50000 -4052.32389998436,13.944149494171144,118384.54222488403,353639,0,118384.54222488403,0.627500057220459,1.84516704082489,10000,122463.50339007378,0.9620535373687744,0.1417208313941955,0.7559399604797363,1.0543521642684937,50000 -4069.257670402527,14.031378030776978,118894.63495087624,355163,0,118894.63495087624,0.626800000667572,1.8399124145507808,10000,122990.6729967594,0.9616549611091614,0.1446973383426666,0.7563599944114685,1.0527466535568235,50000 -4085.834371328354,14.115801095962524,119404.62611865996,356687,0,119404.62611865996,0.626800000667572,1.842963933944702,10000,123517.38052773476,0.9599210619926452,0.1480130702257156,0.7562999725341797,1.0537092685699463,50000 -4102.438501119614,14.206216812133787,119914.49568676949,358211,0,119914.49568676949,0.6278000473976135,1.8446106910705569,10000,124043.99854040146,0.9610769748687744,0.1448649168014526,0.7556999921798706,1.0543800592422483,50000 -4119.014720439911,14.295626163482666,120424.40429615974,359735,0,120424.40429615974,0.6267000436782837,1.8425129652023315,10000,124570.62741327286,0.9620934128761292,0.1429737657308578,0.7559399604797363,1.0538039207458496,50000 -4135.7907111644745,14.385024547576904,120934.38921999931,361259,0,120934.38921999931,0.6271000504493713,1.8423242568969729,10000,125097.53248429298,0.960957407951355,0.1437092125415802,0.7561599612236023,1.053349852561951,50000 -4152.445852518082,14.476088762283323,121444.4788107872,362784,0,121444.4788107872,0.6282000541687012,1.8418200016021729,10000,125624.42392516136,0.9610570669174194,0.1475917845964431,0.7564399838447571,1.053044080734253,50000 -4169.068185567856,14.566959142684937,121954.5669374466,364309,0,121954.5669374466,0.6279000043869019,1.841619849205017,10000,126151.28058385848,0.960598647594452,0.1446724534034729,0.7563199996948242,1.0528602600097656,50000 -4185.526557207108,14.665619611740112,122464.4394042492,365833,0,122464.4394042492,0.626800000667572,1.841666460037232,10000,126677.76492261888,0.9619140625,0.1466599553823471,0.7559599876403809,1.0532474517822266,50000 -4202.016310453415,14.755435466766356,122974.46217608452,367358,0,122974.46217608452,0.6270000338554382,1.843050837516785,10000,127204.42240476608,0.9614955186843872,0.1442684829235077,0.7562800049781799,1.0540716648101809,50000 -4218.683201313019,14.845154762268066,123484.38162231444,368883,0,123484.38162231444,0.6273000240325928,1.842363953590393,10000,127731.15286946297,0.9606983065605164,0.1458490192890167,0.7562800049781799,1.0531740188598633,50000 -4235.262308835983,14.935875177383425,123994.45357394218,370408,0,123994.45357394218,0.6269000172615051,1.842165470123291,10000,128257.9499156475,0.9596619606018066,0.1495345383882522,0.7561599612236023,1.0537803173065186,50000 -4251.81521987915,15.023420572280884,124504.32865929604,371933,0,124504.32865929604,0.6274000406265259,1.841818809509277,10000,128784.52232336998,0.9612364172935486,0.1453320086002349,0.756060004234314,1.0531513690948486,50000 -4268.265427350998,15.114747524261476,125014.37292766573,373458,0,125014.37292766573,0.6274000406265259,1.8427166938781736,10000,129311.163510561,0.9603196382522584,0.1460485756397247,0.756060004234314,1.054241180419922,50000 -4285.071767568588,15.203657388687134,125524.3220334053,374982,0,125524.3220334053,0.6284000277519226,1.8403536081314087,10000,129838.06326699255,0.961136758327484,0.1452155411243438,0.7561799883842468,1.052392840385437,50000 -4301.894116163254,15.297083377838137,126034.21823072432,376506,0,126034.21823072432,0.6279000043869019,1.839387059211731,10000,130364.93006849287,0.9604990482330322,0.1470222175121307,0.756339967250824,1.051685094833374,50000 -4318.468457937241,15.392176628112791,126544.26707220078,378031,0,126544.26707220078,0.626800000667572,1.8416091203689573,10000,130891.7034125328,0.9604392051696776,0.1461730599403381,0.7559999823570251,1.0533385276794434,50000 -4335.007391452789,15.482544660568236,127054.2448322773,379555,0,127054.2448322773,0.626800000667572,1.84324848651886,10000,131418.3648777008,0.9622528553009032,0.146114632487297,0.7559799551963806,1.0539250373840332,50000 -4351.570056438446,15.580342531204224,127564.3393175602,381080,0,127564.3393175602,0.6282000541687012,1.8413126468658447,10000,131945.17438149452,0.9605189561843872,0.1476415544748306,0.7562800049781799,1.0527596473693848,50000 -4368.340341567993,15.680543184280396,128074.41396832466,382605,0,128074.41396832466,0.6281000375747681,1.8423075675964355,10000,132472.1737203598,0.9604392051696776,0.1470810174942016,0.7559999823570251,1.0539182424545288,50000 -4384.839435815811,15.771271228790283,128584.54533720016,384130,0,128584.54533720016,0.626800000667572,1.8423826694488523,10000,132998.9505019188,0.9626514315605164,0.1441816985607147,0.7557799816131592,1.0531625747680664,50000 -4401.344930171967,15.865436553955078,129094.4659075737,385654,0,129094.4659075737,0.6274000406265259,1.84333074092865,10000,133525.52531456947,0.9628706574440002,0.1429570764303207,0.7558000087738037,1.0541527271270752,50000 -4418.605168104172,15.95547103881836,129604.40239357948,387179,0,129604.40239357948,0.628000020980835,1.8408176898956297,10000,134052.86649990082,0.9623923301696776,0.143512025475502,0.7556999921798706,1.053622841835022,50000 -4435.200484991074,16.04998469352722,130114.50559592248,388704,0,130114.50559592248,0.6267000436782837,1.8424186706542969,10000,134579.7143881321,0.959582269191742,0.14606274664402,0.7560999989509583,1.053932547569275,50000 -4451.753606081009,16.146687269210815,130624.4790096283,390229,0,130624.4790096283,0.6267000436782837,1.843072533607483,10000,135106.39228892326,0.9609375,0.1475533097982406,0.7559999823570251,1.0535551309585571,50000 -4468.369817733765,16.23644709587097,131134.4970755577,391754,0,131134.4970755577,0.6271000504493713,1.8442800045013428,10000,135633.17174863815,0.962511956691742,0.141088455915451,0.7557799816131592,1.0546152591705322,50000 -4484.873900651932,16.379529237747192,131644.5588364601,393279,0,131644.5588364601,0.626800000667572,1.84110689163208,10000,136159.93689537048,0.960957407951355,0.145300954580307,0.7558799982070923,1.0527784824371338,50000 -4501.413993358612,16.484485864639282,132154.6266951561,394804,0,132154.6266951561,0.626800000667572,1.8425297737121584,10000,136686.70456910133,0.9612962007522584,0.1456609070301056,0.7558199763298035,1.0530370473861694,50000 -4517.964068174362,16.57689118385315,132664.5911166668,396328,0,132664.5911166668,0.6260000467300415,1.84261167049408,10000,137213.3666226864,0.9614955186843872,0.1447633355855941,0.7559199929237366,1.0543824434280396,50000 -4534.476091146469,16.671013355255127,133174.60023641586,397853,0,133174.60023641586,0.626800000667572,1.8429162502288816,10000,137740.03743171692,0.960558831691742,0.1463161557912826,0.7558199763298035,1.0536457300186155,50000 -4550.910442829132,16.769649982452393,133684.56909918785,399378,0,133684.56909918785,0.627500057220459,1.8440628051757808,10000,138266.5942044258,0.9625717401504515,0.1407064348459243,0.7558599710464478,1.053896427154541,50000 -4567.4318034648895,16.86515712738037,134194.7005777359,400903,0,134194.7005777359,0.6270000338554382,1.8432790040969849,10000,138793.39784646034,0.9602399468421936,0.1468479335308075,0.7557199597358704,1.0537818670272827,50000 -4584.428512334824,16.963099002838135,134704.63186764717,402428,0,134704.63186764717,0.6276000142097473,1.842927098274231,10000,139320.47827506065,0.9604392051696776,0.1478034555912017,0.7559399604797363,1.053891658782959,50000 -4601.067002296448,17.061861038208008,135214.64415454865,403953,0,135214.64415454865,0.6271000504493713,1.839210629463196,10000,139847.28141379356,0.961336076259613,0.1461579501628875,0.756119966506958,1.0521721839904783,50000 -4617.628590106964,17.158963441848755,135724.62863755226,405478,0,135724.62863755226,0.6273000240325928,1.84215247631073,10000,140373.97967410088,0.9627909660339355,0.1421767473220825,0.756119966506958,1.0536324977874756,50000 -4634.187923192978,17.261436939239502,136234.4849574566,407002,0,136234.4849574566,0.6276000142097473,1.8435428142547607,10000,140900.5523967743,0.9618144035339355,0.1452556103467941,0.7559999823570251,1.0538339614868164,50000 -4650.723289728165,17.356701374053955,136744.42813515663,408527,0,136744.42813515663,0.6272000074386597,1.842506766319275,10000,141427.18142938614,0.9584860801696776,0.1510231494903564,0.7560399770736694,1.0538654327392578,50000 -4667.281981468201,17.456937551498413,137254.58539009094,410052,0,137254.58539009094,0.6276000142097473,1.843324899673462,10000,141954.05271077156,0.9606385231018066,0.1461514383554458,0.7562400102615356,1.053558349609375,50000 -4683.87499833107,17.549427270889282,137764.53922367096,411576,0,137764.53922367096,0.6277000308036804,1.8429654836654663,10000,142480.7475554943,0.9613958597183228,0.1457064151763916,0.7561399936676025,1.0536881685256958,50000 -4700.526404619217,17.64485478401184,138274.40833592415,413100,0,138274.40833592415,0.6271000504493713,1.843506932258606,10000,143007.41964292526,0.9624919891357422,0.1417132914066314,0.7558000087738037,1.054227352142334,50000 -4717.203856706619,17.738426208496094,138784.3907828331,414625,0,138784.3907828331,0.6271000504493713,1.843660473823548,10000,143534.22866630554,0.9585259556770324,0.1491367220878601,0.7554399967193604,1.0536048412322998,50000 -4733.677309751511,17.83676266670227,139294.40902137756,416150,0,139294.40902137756,0.6269000172615051,1.8419307470321653,10000,144060.8727862835,0.9607780575752258,0.1463647484779358,0.7557199597358704,1.052968978881836,50000 -4750.356118917465,17.932732820510864,139804.53283405304,417675,0,139804.53283405304,0.6277000308036804,1.8424909114837649,10000,144587.82571840286,0.9611766338348388,0.1453454792499542,0.7560399770736694,1.053350567817688,50000 -4766.973588705063,18.03618574142456,140314.65873599052,419200,0,140314.65873599052,0.6277000308036804,1.8426084518432613,10000,145114.72708678246,0.9610171914100648,0.148466870188713,0.7556799650192261,1.053810477256775,50000 -4783.490590810776,18.113587141036987,140824.52908945084,420724,0,140824.52908945084,0.626300036907196,1.843507647514344,10000,145641.24696421623,0.9606584906578064,0.1469320952892303,0.7560399770736694,1.0542471408843994,50000 -4800.164356708527,18.211303234100345,141334.59196543694,422249,0,141334.59196543694,0.6279000043869019,1.8419170379638672,10000,146168.13632249832,0.9617944359779358,0.1449295580387115,0.7558000087738037,1.0536588430404663,50000 -4816.875243186951,18.31092357635498,141844.57013821602,423774,0,141844.57013821602,0.628000020980835,1.8425239324569704,10000,146694.98079276085,0.9623525142669678,0.1428000479936599,0.7557599544525146,1.0537769794464111,50000 -4833.446505069733,18.412598848342896,142354.64644885063,425299,0,142354.64644885063,0.6271000504493713,1.8420350551605225,10000,147221.78683519363,0.9629504084587096,0.1425699740648269,0.7560799717903137,1.053001046180725,50000 -4850.79785990715,18.51312065124512,142864.6662731171,426824,0,142864.6662731171,0.627500057220459,1.840661644935608,10000,147749.3140487671,0.9593231678009032,0.1488415002822876,0.7558799982070923,1.0532394647598269,50000 -4867.446770191193,18.61627769470215,143374.65701889992,428349,0,143374.65701889992,0.6274000406265259,1.8414006233215328,10000,148276.11263155937,0.9606783986091614,0.1454301327466964,0.7557399868965149,1.0528854131698608,50000 -4883.9972012043,18.715560913085938,143884.65881085396,429873,0,143884.65881085396,0.6272000074386597,1.8428633213043213,10000,148802.81915712357,0.9620535373687744,0.144050195813179,0.7560999989509583,1.053205966949463,50000 -4900.618678569794,18.810232400894165,144394.69930911064,431398,0,144394.69930911064,0.6273000240325928,1.842931151390076,10000,149329.63092327118,0.961933970451355,0.142555832862854,0.7561399936676025,1.0534591674804688,50000 -4917.084590911865,18.91004538536072,144904.76746439934,432923,0,144904.76746439934,0.6272000074386597,1.842706322669983,10000,149856.3191523552,0.961355984210968,0.1434606611728668,0.755899965763092,1.0536073446273804,50000 -4933.675719738007,19.008899688720703,145414.69251775742,434447,0,145414.69251775742,0.6271000504493713,1.842933297157288,10000,150382.98892378807,0.9610969424247742,0.1472132056951522,0.7562199831008911,1.0545414686203003,50000 -4950.279651641846,19.110421895980835,145924.79586172104,435972,0,145924.79586172104,0.627500057220459,1.8414922952651973,10000,150909.85272932053,0.9607780575752258,0.145490288734436,0.755840003490448,1.053587555885315,50000 -4966.919440507889,19.20820665359497,146434.88862538338,437497,0,146434.88862538338,0.6271000504493713,1.8432046175003047,10000,151436.73866152763,0.9619937539100648,0.1432654112577438,0.7560999989509583,1.053731918334961,50000 -4983.461073637009,19.31168293952942,146944.86338567734,439022,0,146944.86338567734,0.6266000270843506,1.8422584533691408,10000,151963.4141998291,0.9611168503761292,0.1442411392927169,0.7558599710464478,1.053155779838562,50000 -5000.237932682037,19.418365716934204,147454.86088490486,440547,0,147454.86088490486,0.6271000504493713,1.8422164916992188,10000,152490.3506128788,0.9596819281578064,0.1478977352380752,0.7560799717903137,1.053455114364624,50000 -5016.71445608139,19.52862310409546,147964.97935509682,442073,0,147964.97935509682,0.6277000308036804,1.841708302497864,10000,153017.11069369316,0.9610171914100648,0.1458736658096313,0.7556599974632263,1.052662372589111,50000 -5033.272862672806,19.630709886550903,148475.09104943275,443599,0,148475.09104943275,0.6272000074386597,1.843441247940064,10000,153543.9378838539,0.962332546710968,0.1446838676929474,0.7557799816131592,1.053775429725647,50000 -5049.900293827057,19.734724521636963,148985.1442747116,445124,0,148985.1442747116,0.6267000436782837,1.84416663646698,10000,154070.77855920792,0.9621731042861938,0.1434157192707061,0.7562599778175354,1.053500056266785,50000 -5066.508980512619,19.83415412902832,149495.27058815956,446649,0,149495.27058815956,0.6273000240325928,1.8439942598342896,10000,154597.66735219955,0.9598612785339355,0.1488574594259262,0.7561599612236023,1.0542877912521362,50000 -5082.940105676651,19.93660569190979,150005.38194799423,448174,0,150005.38194799423,0.627500057220459,1.8451577425003047,10000,155124.3677699566,0.959781527519226,0.1462220698595047,0.7554799914360046,1.0554804801940918,50000 -5099.444447517395,20.040841579437256,150515.29887747765,449699,0,150515.29887747765,0.6270000338554382,1.8444410562515257,10000,155650.94755601883,0.9611168503761292,0.1465611755847931,0.7557399868965149,1.0540775060653689,50000 -5116.119092226028,20.15410733222961,151025.3807592392,451224,0,151025.3807592392,0.627500057220459,1.8438587188720703,10000,156177.87268805504,0.9607979655265808,0.1459456384181976,0.7552599906921387,1.054438829421997,50000 -5132.721889019012,20.254146337509155,151535.39058494568,452749,0,151535.39058494568,0.6274000406265259,1.8419733047485352,10000,156704.6409471035,0.9612165093421936,0.1441743522882461,0.755840003490448,1.052942156791687,50000 -5149.499848604202,20.36068964004517,152045.5179359913,454275,0,152045.5179359913,0.6279000043869019,1.8424499034881592,10000,157231.7075505257,0.960359513759613,0.145753726363182,0.756060004234314,1.0537270307540894,50000 -5165.890701532364,20.46353912353516,152555.53424096107,455800,0,152555.53424096107,0.626800000667572,1.8414360284805296,10000,157758.27311992645,0.9604192972183228,0.1473921835422516,0.7557599544525146,1.0532865524291992,50000 -5182.322814702988,20.57215189933777,153065.48213148117,457325,0,153065.48213148117,0.6271000504493713,1.84191644191742,10000,158284.81679821014,0.9620735049247742,0.1456650346517563,0.7555800080299377,1.0540398359298706,50000 -5199.036423683167,20.67511820793152,153575.37730908394,458850,0,153575.37730908394,0.6274000406265259,1.8431851863861084,10000,158811.58496284485,0.960558831691742,0.1474863737821579,0.7563999891281128,1.0531036853790283,50000 -5215.4732303619385,20.78009033203125,154085.50613236427,460376,0,154085.50613236427,0.6276000142097473,1.842554211616516,10000,159338.31057286265,0.960558831691742,0.1477661728858947,0.7560399770736694,1.0535650253295898,50000 -5232.09307384491,20.888579607009888,154595.75695943832,461901,0,154595.75695943832,0.628000020980835,1.841142654418945,10000,159865.34474110603,0.963109850883484,0.1407929807901382,0.7557399868965149,1.0532169342041016,50000 -5248.627458095551,20.992839574813843,155105.88724708557,463426,0,155105.88724708557,0.627500057220459,1.8413161039352417,10000,160392.16827869415,0.9622329473495485,0.1450704485177993,0.7559799551963806,1.0533766746520996,50000 -5265.688413143158,21.09821939468384,155615.8431122303,464950,0,155615.8431122303,0.6267000436782837,1.8424440622329712,10000,160919.34607696533,0.9618343114852904,0.1451112926006317,0.7556399703025818,1.054116129875183,50000 -5282.34098482132,21.203664302825928,156125.88497161865,466475,0,156125.88497161865,0.6279000043869019,1.840462565422058,10000,161446.20129728317,0.9597217440605164,0.1456145346164703,0.7562599778175354,1.052058219909668,50000 -5298.78008556366,21.30401349067688,156635.74769496918,468000,0,156635.74769496918,0.6278000473976135,1.840748310089112,10000,161972.6580798626,0.9604192972183228,0.1474740803241729,0.7558199763298035,1.0522536039352417,50000 -5315.314467906952,21.41022562980652,157145.59330034256,469524,0,157145.59330034256,0.6272000074386597,1.8413631916046145,10000,162499.1994020939,0.9623923301696776,0.143751248717308,0.7559199929237366,1.0532188415527344,50000 -5331.7469527721405,21.52133750915528,157655.4903459549,471049,0,157655.4903459549,0.6276000142097473,1.842286944389344,10000,163025.69514727592,0.9620934128761292,0.1425232142210006,0.7557399868965149,1.0534279346466064,50000 -5348.365239858627,21.62071895599365,158165.50344014168,472574,0,158165.50344014168,0.6266000270843506,1.8426686525344849,10000,163552.47993898392,0.9614756107330322,0.1465017646551132,0.7559199929237366,1.054178237915039,50000 -5364.82608294487,21.726683616638184,158675.39512062073,474099,0,158675.39512062073,0.6260000467300415,1.8445450067520144,10000,164078.99399518967,0.960598647594452,0.1443846672773361,0.7561799883842468,1.0545673370361328,50000 -5381.293584108353,21.83575701713562,159185.36111426353,475624,0,159185.36111426353,0.6276000142097473,1.8428512811660769,10000,164605.59284853935,0.961734652519226,0.145376443862915,0.7560999989509583,1.0535876750946045,50000 -5397.987356424332,21.93658208847046,159695.37145113945,477149,0,159695.37145113945,0.6272000074386597,1.841827273368836,10000,165132.4526028633,0.9620137214660645,0.1419821977615356,0.756339967250824,1.0527808666229248,50000 -5414.497764825821,22.03964066505432,160205.27926325798,478674,0,160205.27926325798,0.6279000043869019,1.8423027992248533,10000,165659.02926802635,0.96000075340271,0.1474557220935821,0.7561799883842468,1.052963137626648,50000 -5431.116099834442,22.14845061302185,160715.17672729492,480198,0,160715.17672729492,0.6277000308036804,1.8432202339172363,10000,166185.70827054977,0.9604990482330322,0.1476502567529678,0.7561999559402466,1.053896188735962,50000 -5447.69705748558,22.248322248458862,161225.3224759102,481724,0,161225.3224759102,0.6271000504493713,1.841601610183716,10000,166712.59029626846,0.9612364172935486,0.1440017223358154,0.7558599710464478,1.0537447929382324,50000 -5464.181796789169,22.35664367675781,161735.1811146736,483249,0,161735.1811146736,0.626300036907196,1.843042254447937,10000,167239.09651231766,0.9628507494926452,0.1429650485515594,0.7554599642753601,1.0544278621673584,50000 -5480.593630313873,22.470579624176025,162245.09081196785,484774,0,162245.09081196785,0.628000020980835,1.8424336910247805,10000,167765.5865190029,0.9619140625,0.1458676904439926,0.755899965763092,1.054257035255432,50000 -5497.265798330307,22.580674409866333,162754.97107815742,486299,0,162754.97107815742,0.6274000406265259,1.8425171375274656,10000,168292.3037648201,0.9597616195678712,0.1485474854707718,0.7559199929237366,1.0532896518707275,50000 -5513.765965938568,22.693618059158325,163264.9387593269,487825,0,163264.9387593269,0.6271000504493713,1.8440321683883667,10000,168818.93969726562,0.9598014950752258,0.1468220502138137,0.7554999589920044,1.0537186861038208,50000 -5530.3492596149445,22.80390238761902,163774.9257595539,489350,0,163774.9257595539,0.626800000667572,1.8408963680267327,10000,169345.6744902134,0.9611766338348388,0.1468583345413208,0.7562400102615356,1.0533497333526611,50000 -5546.840213775635,22.913722276687626,164284.85869836807,490876,0,164284.85869836807,0.6274000406265259,1.8428215980529783,10000,169872.2647330761,0.9616748690605164,0.1434008181095123,0.7559399604797363,1.0531055927276611,50000 -5563.665554523468,23.0304811000824,164794.85840845108,492401,0,164794.85840845108,0.628000020980835,1.8415027856826784,10000,170399.26122307777,0.95902419090271,0.1491030007600784,0.7559999823570251,1.0529998540878296,50000 -5580.262861967087,23.13922953605652,165304.89320635796,493927,0,165304.89320635796,0.6270000338554382,1.842665672302246,10000,170926.0573911667,0.961535394191742,0.1440262198448181,0.7557599544525146,1.0536928176879885,50000 -5596.775694847107,23.250207901000977,165814.74216461182,495452,0,165814.74216461182,0.625700056552887,1.8420727252960205,10000,171452.58564066887,0.9607979655265808,0.146340012550354,0.7558799982070923,1.053356647491455,50000 -5613.217988491058,23.358347415924072,166324.6839056015,496977,0,166324.6839056015,0.6273000240325928,1.841199278831482,10000,171979.1335196495,0.9622329473495485,0.1473593860864639,0.7556599974632263,1.0541554689407349,50000 -5629.716914653778,23.470913410186768,166834.6341896057,498502,0,166834.6341896057,0.6270000338554382,1.8436610698699951,10000,172505.75029206276,0.9600805044174194,0.1480731964111328,0.7558199763298035,1.0539394617080688,50000 -5646.337370872498,23.58010768890381,167344.79048371315,500028,0,167344.79048371315,0.6279000043869019,1.8444896936416624,10000,173032.69118475914,0.961933970451355,0.1435198634862899,0.755899965763092,1.05421781539917,50000 -5662.989856958389,23.68991541862488,167854.95328998566,501554,0,167854.95328998566,0.6272000074386597,1.8417606353759768,10000,173559.6701028347,0.9622528553009032,0.1441863924264907,0.7561999559402466,1.0531775951385498,50000 -5679.408395290375,23.802663803100582,168364.83893346786,503044,0,168364.83893346786,0.626800000667572,1.84154212474823,10000,174086.14171028137,0.961933970451355,0.1459151953458786,0.7558799982070923,1.053280234336853,50000 -5696.652493000031,23.912421226501465,168874.7902429104,504568,0,168874.7902429104,0.6267000436782837,1.8432775735855105,10000,174613.502784729,0.9605189561843872,0.1453207582235336,0.7561799883842468,1.0537110567092896,50000 -5713.153592586517,24.021644353866577,169384.76676797867,506093,0,169384.76676797867,0.6278000473976135,1.8410217761993408,10000,175140.1451165676,0.9606186151504515,0.1453305780887603,0.756060004234314,1.0525985956192017,50000 -5729.777534723282,24.134642362594604,169894.76099991798,507618,0,169894.76099991798,0.6282000541687012,1.8442169427871704,10000,175666.93125104904,0.9618144035339355,0.1438751220703125,0.7559199929237366,1.0549389123916626,50000 -5746.335768461227,24.246805906295776,170404.71442842484,509143,0,170404.71442842484,0.6276000142097473,1.841456770896912,10000,176193.6103644371,0.9618542790412904,0.142910823225975,0.756339967250824,1.0527304410934448,50000 -5762.775901317596,24.35623812675476,170914.71456050873,510668,0,170914.71456050873,0.6284000277519226,1.842883825302124,10000,176720.21445131302,0.961355984210968,0.1437069475650787,0.7561599612236023,1.0545107126235962,50000 -5779.257610321045,24.469741344451904,171424.77597403526,512193,0,171424.77597403526,0.6279000043869019,1.8420346975326536,10000,177246.92549419403,0.960339605808258,0.1483767330646515,0.7562400102615356,1.0535874366760254,50000 -5795.874584913254,24.5813729763031,171934.87462615967,513718,0,171934.87462615967,0.626800000667572,1.8419123888015747,10000,177773.80760240555,0.9610371589660645,0.144314095377922,0.7560399770736694,1.052869439125061,50000 -5812.326997756958,24.69336938858032,172444.89008188248,515242,0,172444.89008188248,0.626800000667572,1.843067049980164,10000,178300.44254755974,0.9625318646430968,0.1427466124296188,0.755620002746582,1.054432988166809,50000 -5828.814296245575,24.806583642959595,172954.96911978722,516767,0,172954.96911978722,0.6270000338554382,1.8397737741470337,10000,178827.1764163971,0.9606186151504515,0.1462768763303756,0.7558799982070923,1.0523682832717896,50000 -5845.223039627075,24.920645713806152,173464.90362358093,518291,0,173464.90362358093,0.6271000504493713,1.843677639961243,10000,179353.6888029575,0.9604790806770324,0.1468463689088821,0.755899965763092,1.0545798540115356,50000 -5861.859467506409,25.033472061157227,173974.819116354,519816,0,173974.819116354,0.6278000473976135,1.8394850492477417,10000,179880.40839219093,0.9612962007522584,0.1456597298383712,0.7555999755859375,1.0524157285690308,50000 -5878.530686855316,25.147634506225582,174484.72979211807,521341,0,174484.72979211807,0.6273000240325928,1.8433527946472168,10000,180407.1592879296,0.9618343114852904,0.1445403099060058,0.7559999823570251,1.05410897731781,50000 -5895.08545088768,25.262465476989743,174994.86545610428,522866,0,174994.86545610428,0.6267000436782837,1.842559814453125,10000,180934.01882386208,0.961933970451355,0.1441616415977478,0.7557599544525146,1.0527174472808838,50000 -5911.562630414963,25.43892741203308,175504.70123004913,524390,0,175504.70123004913,0.6276000142097473,1.8422954082489007,10000,181460.5635290146,0.9607780575752258,0.1470319777727127,0.7557599544525146,1.053484320640564,50000 -5928.032803058624,25.55296421051025,176014.61433887482,525914,0,176014.61433887482,0.6270000338554382,1.843161582946777,10000,181987.117274046,0.958984375,0.1497214436531067,0.7561599612236023,1.0536900758743286,50000 -5944.878372192383,25.67080807685852,176524.59498858452,527438,0,176524.59498858452,0.6278000473976135,1.841441512107849,10000,182514.1163351536,0.9616350531578064,0.1456404328346252,0.756060004234314,1.0531024932861328,50000 -5961.299426794052,25.78989005088806,177034.4949579239,528962,0,177034.4949579239,0.627500057220459,1.841924071311951,10000,183040.6117610932,0.9606186151504515,0.1452937424182891,0.7559199929237366,1.05355703830719,50000 -5978.160856962204,26.731829404830933,177543.64189291,530484,0,177543.64189291,0.6266000270843506,1.84251868724823,10000,183567.61677789688,0.960957407951355,0.1444409638643264,0.7559399604797363,1.053533673286438,50000 -5994.678615808487,26.84842109680176,178053.5409553051,532009,0,178053.5409553051,0.6270000338554382,1.843008637428284,10000,184094.20546674728,0.9603993892669678,0.1479032784700393,0.7558799982070923,1.0539860725402832,50000 -6011.105018854141,26.96674156188965,178563.4308130741,533533,0,178563.4308130741,0.6273000240325928,1.8435173034667969,10000,184620.69504141808,0.961336076259613,0.1444842964410781,0.7559799551963806,1.0537562370300293,50000 -6027.668654680252,27.084050178527832,179073.46681928635,535058,0,179073.46681928635,0.6281000375747681,1.842696189880371,10000,185147.4665656089,0.96097731590271,0.1471495926380157,0.7558000087738037,1.0533289909362793,50000 -6044.2437698841095,27.201574563980103,179583.57815003395,536505,0,179583.57815003395,0.6271000504493713,1.84367835521698,10000,185674.32324552536,0.9612962007522584,0.1462796926498413,0.7561399936676025,1.0543361902236938,50000 -6060.702314376831,27.3272008895874,180093.41668057442,538029,0,180093.41668057442,0.6276000142097473,1.841093063354492,10000,186200.8007659912,0.9604790806770324,0.1484165489673614,0.7555800080299377,1.0530855655670166,50000 -6077.2591071128845,27.44335126876831,180603.4564025402,539554,0,180603.4564025402,0.6273000240325928,1.8427072763442995,10000,186727.56794548035,0.9626315236091614,0.1421427875757217,0.7562999725341797,1.053134799003601,50000 -6093.973418951035,27.557010173797607,181113.3877146244,541078,0,181113.3877146244,0.627500057220459,1.8408427238464355,10000,187254.3824417591,0.961933970451355,0.1462670266628265,0.7557599544525146,1.0533305406570437,50000 -6110.797034263611,27.67525243759156,181623.31734347343,542602,0,181623.31734347343,0.6265000104904175,1.8440637588500977,10000,187781.3084187508,0.962890625,0.1419719606637954,0.7557399868965149,1.0550309419631958,50000 -6127.386559247971,27.79384064674377,182133.297362566,544127,0,182133.297362566,0.6279000043869019,1.8415024280548096,10000,188308.0508189201,0.9596819281578064,0.1471571922302246,0.755899965763092,1.0522381067276,50000 -6143.728442192078,27.91211032867432,182643.4147348404,545652,0,182643.4147348404,0.6266000270843506,1.8414411544799805,10000,188834.6835012436,0.9605787396430968,0.145737811923027,0.7560999989509583,1.053802251815796,50000 -6160.168985366821,28.02666687965393,183153.32903194427,547176,0,183153.32903194427,0.6276000142097473,1.8418182134628296,10000,189361.20776104927,0.9627709984779358,0.1430668830871582,0.7557599544525146,1.0532115697860718,50000 -6176.777256727219,28.14390015602112,183663.2982008457,548700,0,183663.2982008457,0.6270000338554382,1.84215247631073,10000,189887.9575984478,0.9610371589660645,0.1452589631080627,0.7557799816131592,1.0536679029464722,50000 -6193.291128158569,28.26856780052185,184173.2213394642,550225,0,184173.2213394642,0.6277000308036804,1.8426450490951536,10000,190414.57474303248,0.9610969424247742,0.1436194628477096,0.756060004234314,1.0538721084594729,50000 -6209.7367441654205,28.450235843658447,184683.1692752838,551749,0,184683.1692752838,0.6265000104904175,1.842309713363648,10000,190941.2047362328,0.9610371589660645,0.1476167142391204,0.7555999755859375,1.0538707971572876,50000 -6226.223253250122,28.5725200176239,185193.2670288086,553274,0,185193.2670288086,0.6269000172615051,1.841801643371582,10000,191467.9659619332,0.961136758327484,0.1457731872797012,0.7558599710464478,1.0535832643508911,50000 -6242.835218429565,28.690925359725952,185703.3607866764,554799,0,185703.3607866764,0.6267000436782837,1.842273473739624,10000,191994.8460161686,0.9621731042861938,0.1420705914497375,0.7558199763298035,1.052920937538147,50000 -6259.27534365654,28.811352014541622,186213.40687561035,556324,0,186213.40687561035,0.6276000142097473,1.8436654806137085,10000,192521.5072004795,0.9607979655265808,0.1460682153701782,0.7557199597358704,1.053757905960083,50000 -6275.80174779892,28.93087267875672,186723.44114279747,557849,0,186723.44114279747,0.6270000338554382,1.843936562538147,10000,193048.2424557209,0.9605787396430968,0.1464001685380935,0.7560200095176697,1.0539249181747437,50000 -6292.241653680801,29.049833059310917,187233.40432572365,559374,0,187233.40432572365,0.6279000043869019,1.8411613702774048,10000,193574.8201098442,0.9610171914100648,0.1468623131513595,0.7558000087738037,1.052876591682434,50000 -6308.638211250305,29.170271158218384,187441.93405771255,559998,0,187441.93405771255,0.6265000104904175,1.8419371843338013,10000,193799.88971567154,0.9618144035339355,0.1436149626970291,0.7557799816131592,1.0530809164047241,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index 782c9677a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5971 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.665139,6.9233155,,,,,,,,,,,,,, -1,,,0.0008769132546149,6.912364482879639,0.0011399999493733,6.912585258483887,50000.0,0.0017000001389533,6.912716865539551,10000.0,51.41184377670288,88.12890410423279,51.41184377670288,36.71697783470154,0.0,0.0 -100,0.6590658,6.8222494,,,,,,,,,,,,,, -200,0.78555757,6.57683,,,,,,,,,,,,,, -300,0.92572737,6.281175,,,,,,,,,,,,,, -400,1.6252986,5.977209,,,,,,,,,,,,,, -500,5.455198,5.8064556,,,,,,,,,,,,,, -600,2.7916892,5.646673,,,,,,,,,,,,,, -700,4.076124,5.4626446,,,,,,,,,,,,,, -800,3.0420108,5.35788,,,,,,,,,,,,,, -900,4.411283,5.161393,,,,,,,,,,,,,, -1000,5.462913,5.098168,,,,,,,,,,,,,, -1100,4.8197713,4.8853736,,,,,,,,,,,,,, -1200,5.1507807,4.801757,,,,,,,,,,,,,, -1300,3.6745691,4.70689,,,,,,,,,,,,,, -1400,7.3698845,4.6300025,,,,,,,,,,,,,, -1500,3.3774402,4.510443,,,,,,,,,,,,,, -1519,,,0.161371961236,4.343627452850342,0.1440799981355667,4.503424167633057,50000.0,0.1094000041484832,4.919663429260254,10000.0,561.3822112083435,615.7079174518585,561.3822112083435,54.24302554130554,0.0274267196655273,0.0 -1600,3.5852602,4.3741155,,,,,,,,,,,,,, -1700,3.7626095,4.3047485,,,,,,,,,,,,,, -1800,4.3903823,4.2928777,,,,,,,,,,,,,, -1900,3.2502728,4.140803,,,,,,,,,,,,,, -2000,4.8889494,4.0262976,,,,,,,,,,,,,, -2100,6.9118505,3.9091718,,,,,,,,,,,,,, -2200,3.5044029,3.8772764,,,,,,,,,,,,,, -2300,6.272111,3.851929,,,,,,,,,,,,,, -2400,5.3651648,3.805567,,,,,,,,,,,,,, -2500,3.927331,3.8045173,,,,,,,,,,,,,, -2600,3.3356109,3.6084666,,,,,,,,,,,,,, -2700,3.7984998,3.4511921,,,,,,,,,,,,,, -2800,4.549656,3.6312814,,,,,,,,,,,,,, -2900,3.9052622,3.565258,,,,,,,,,,,,,, -3000,2.833333,3.3753366,,,,,,,,,,,,,, -3037,,,0.3479352593421936,3.004847288131714,0.3234999775886535,3.174034833908081,50000.0,0.239200010895729,3.840391159057617,10000.0,1071.4606931209564,1143.4133098125458,1071.4606931209564,71.78705096244812,0.0556445121765136,0.0 -3100,3.5848448,3.543594,,,,,,,,,,,,,, -3200,3.5136065,3.2353573,,,,,,,,,,,,,, -3300,4.2362466,3.2633572,,,,,,,,,,,,,, -3400,3.4003832,3.217549,,,,,,,,,,,,,, -3500,2.9312239,3.1602397,,,,,,,,,,,,,, -3600,4.2102437,3.060515,,,,,,,,,,,,,, -3700,3.2607546,3.1797612,,,,,,,,,,,,,, -3800,3.4166486,3.051098,,,,,,,,,,,,,, -3900,3.1666083,3.1052012,,,,,,,,,,,,,, -4000,2.7012913,2.8365014,,,,,,,,,,,,,, -4100,3.372876,2.9770555,,,,,,,,,,,,,, -4200,2.092073,3.019944,,,,,,,,,,,,,, -4300,2.6649196,2.7817764,,,,,,,,,,,,,, -4400,1.9235871,2.8151765,,,,,,,,,,,,,, -4500,1.9796189,2.7617943,,,,,,,,,,,,,, -4558,,,0.4669762253761291,2.329172372817993,0.4377799928188324,2.5032687187194824,50000.0,0.323600023984909,3.25213885307312,10000.0,1581.70814204216,1671.4895164966583,1581.70814204216,89.53337836265564,0.083956241607666,0.0 -4600,2.2354076,2.8499053,,,,,,,,,,,,,, -4700,3.159547,2.788797,,,,,,,,,,,,,, -4800,2.6193712,2.7980728,,,,,,,,,,,,,, -4900,2.3042676,2.6735363,,,,,,,,,,,,,, -5000,2.135039,2.6540306,,,,,,,,,,,,,, -5100,3.2964017,2.6750164,,,,,,,,,,,,,, -5200,2.6254704,2.7093205,,,,,,,,,,,,,, -5300,2.6932616,2.6366346,,,,,,,,,,,,,, -5400,2.6200354,2.6463962,,,,,,,,,,,,,, -5500,3.0457647,2.5517592,,,,,,,,,,,,,, -5600,1.9290946,2.3733425,,,,,,,,,,,,,, -5700,2.2857902,2.3422875,,,,,,,,,,,,,, -5800,2.1958303,2.4761956,,,,,,,,,,,,,, -5900,2.0451455,2.4355664,,,,,,,,,,,,,, -6000,1.968429,2.500198,,,,,,,,,,,,,, -6079,,,0.5306122303009033,2.0072100162506104,0.4969799816608429,2.2160658836364746,50000.0,0.3765000104904175,2.998728036880493,10000.0,2091.871442079544,2199.260755300522,2091.871442079544,107.06326746940611,0.1095569133758544,0.0 -6100,2.4810915,2.392868,,,,,,,,,,,,,, -6200,2.260974,2.455267,,,,,,,,,,,,,, -6300,1.8102134,2.4760702,,,,,,,,,,,,,, -6400,2.445865,2.3212028,,,,,,,,,,,,,, -6500,1.810933,2.3876922,,,,,,,,,,,,,, -6600,2.3991008,2.4591465,,,,,,,,,,,,,, -6700,2.1342993,2.4588096,,,,,,,,,,,,,, -6800,1.9989898,2.2618446,,,,,,,,,,,,,, -6900,1.7286512,2.4045844,,,,,,,,,,,,,, -7000,1.8733954,2.3233743,,,,,,,,,,,,,, -7100,2.2520876,2.317935,,,,,,,,,,,,,, -7200,2.084926,2.4049115,,,,,,,,,,,,,, -7300,2.184294,2.218565,,,,,,,,,,,,,, -7400,1.8211495,2.2256663,,,,,,,,,,,,,, -7500,2.213704,2.2418258,,,,,,,,,,,,,, -7600,3.1834645,2.1890569,,,,,,,,,,,,,, -7601,,,0.5672432780265808,1.8172202110290527,0.5317000150680542,2.017817258834839,50000.0,0.417600005865097,2.7481794357299805,10000.0,2602.244828939438,2727.28044962883,2602.244828939438,124.62890672683716,0.1368377208709716,0.0 -7700,1.7672453,2.2312362,,,,,,,,,,,,,, -7800,1.768785,2.1522324,,,,,,,,,,,,,, -7900,1.5349395,2.1595058,,,,,,,,,,,,,, -8000,1.6182768,2.2495584,,,,,,,,,,,,,, -8100,1.9892079,2.169372,,,,,,,,,,,,,, -8200,2.1735864,2.2151344,,,,,,,,,,,,,, -8300,1.2818519,2.0845766,,,,,,,,,,,,,, -8400,2.0007389,2.120809,,,,,,,,,,,,,, -8500,1.6040903,2.176741,,,,,,,,,,,,,, -8600,1.4705904,2.0390625,,,,,,,,,,,,,, -8700,1.7349837,2.156631,,,,,,,,,,,,,, -8800,1.7756243,2.1192887,,,,,,,,,,,,,, -8900,1.5293155,2.1408021,,,,,,,,,,,,,, -9000,1.9981587,2.2205248,,,,,,,,,,,,,, -9100,1.4134771,2.0747852,,,,,,,,,,,,,, -9123,,,0.6525231003761292,1.404765486717224,0.5640599727630615,1.8469562530517576,50000.0,0.4397000074386596,2.59698748588562,10000.0,3112.356943130493,3255.2293422222137,3112.356943130493,142.38618516921997,0.1638340950012207,0.0 -9200,1.9854565,2.1381261,,,,,,,,,,,,,, -9300,1.7829711,2.3099804,,,,,,,,,,,,,, -9400,2.0594342,1.9884001,,,,,,,,,,,,,, -9500,1.4637622,2.1965933,,,,,,,,,,,,,, -9600,1.8583136,2.1454527,,,,,,,,,,,,,, -9700,1.6223799,2.2053924,,,,,,,,,,,,,, -9800,1.5481464,2.1016974,,,,,,,,,,,,,, -9900,1.7440674,2.1405292,,,,,,,,,,,,,, -10000,1.8628002,2.1662345,,,,,,,,,,,,,, -10100,1.6575541,2.1706932,,,,,,,,,,,,,, -10200,1.553414,2.1252093,,,,,,,,,,,,,, -10300,1.5992339,2.1590078,,,,,,,,,,,,,, -10400,1.5174006,2.1621773,,,,,,,,,,,,,, -10500,1.4626331,1.8755566,,,,,,,,,,,,,, -10600,1.6510634,2.1372454,,,,,,,,,,,,,, -10645,,,0.6332509517669678,1.4783037900924685,0.5731399655342102,1.8122446537017824,50000.0,0.4448000192642212,2.552358627319336,10000.0,3622.518341064453,3783.868696928024,3622.518341064453,160.78351044654846,0.1913130283355713,0.0 -10700,1.8061233,2.0046086,,,,,,,,,,,,,, -10800,1.7186908,1.9887735,,,,,,,,,,,,,, -10900,1.7869563,2.054155,,,,,,,,,,,,,, -11000,1.9129874,2.0732188,,,,,,,,,,,,,, -11100,1.7939178,2.0483432,,,,,,,,,,,,,, -11200,1.6470375,2.0467281,,,,,,,,,,,,,, -11300,1.9096677,1.9453428,,,,,,,,,,,,,, -11400,1.3624477,1.9681485,,,,,,,,,,,,,, -11500,2.3127851,1.9806839,,,,,,,,,,,,,, -11600,1.6134497,2.0833948,,,,,,,,,,,,,, -11700,1.7963148,2.0833366,,,,,,,,,,,,,, -11800,1.716267,2.0834649,,,,,,,,,,,,,, -11900,1.3667005,1.9745309,,,,,,,,,,,,,, -12000,1.674946,1.8665955,,,,,,,,,,,,,, -12100,2.0617335,1.974114,,,,,,,,,,,,,, -12167,,,0.6190608739852905,1.5576162338256836,0.5658199787139893,1.834885835647583,50000.0,0.4365000128746032,2.630974054336548,10000.0,4132.719137430191,4312.018725633621,4132.719137430191,178.64738535881042,0.2245426177978515,0.0 -12200,1.317779,1.9102386,,,,,,,,,,,,,, -12300,1.6681383,2.0567033,,,,,,,,,,,,,, -12400,1.6568856,2.0368938,,,,,,,,,,,,,, -12500,1.6701827,1.8757653,,,,,,,,,,,,,, -12600,1.4720421,1.9470596,,,,,,,,,,,,,, -12700,1.3411194,1.8934997,,,,,,,,,,,,,, -12800,1.7306834,1.9452908,,,,,,,,,,,,,, -12900,1.4574857,1.9019413,,,,,,,,,,,,,, -13000,1.5968356,1.9470382,,,,,,,,,,,,,, -13100,2.0213976,1.9189639,,,,,,,,,,,,,, -13200,1.703134,1.906262,,,,,,,,,,,,,, -13300,1.8746132,1.781275,,,,,,,,,,,,,, -13400,1.5838845,1.9241806,,,,,,,,,,,,,, -13500,1.5680238,1.9574066,,,,,,,,,,,,,, -13600,1.4405615,1.8263253,,,,,,,,,,,,,, -13689,,,0.6477000713348389,1.4246323108673096,0.5928399562835693,1.7087825536727903,50000.0,0.4770000278949737,2.4406228065490723,10000.0,4642.680076122284,4839.813012838364,4642.680076122284,196.4004583358765,0.2521340847015381,0.0 -13700,1.7133737,1.8249147,,,,,,,,,,,,,, -13800,1.5975428,1.9064883,,,,,,,,,,,,,, -13900,1.4287702,2.011334,,,,,,,,,,,,,, -14000,1.571991,1.9783595,,,,,,,,,,,,,, -14100,1.5663625,2.0072317,,,,,,,,,,,,,, -14200,1.4764934,1.8739836,,,,,,,,,,,,,, -14300,1.6252826,1.9214208,,,,,,,,,,,,,, -14400,2.0344918,1.9610643,,,,,,,,,,,,,, -14500,1.6121795,1.8636001,,,,,,,,,,,,,, -14600,1.4546472,2.0112622,,,,,,,,,,,,,, -14700,1.5390452,1.8879921,,,,,,,,,,,,,, -14800,1.7012953,1.9386739,,,,,,,,,,,,,, -14900,1.8303392,1.9716221,,,,,,,,,,,,,, -15000,2.0096848,2.0780241,,,,,,,,,,,,,, -15100,1.6852281,1.7712505,,,,,,,,,,,,,, -15200,1.4257736,1.9367142,,,,,,,,,,,,,, -15211,,,0.6474409699440002,1.4198980331420898,0.5953800082206726,1.6889138221740725,50000.0,0.4726000130176544,2.4248533248901367,10000.0,5152.619967222214,5368.286201000214,5152.619967222214,214.842568397522,0.289806604385376,0.0 -15300,1.6349075,1.9324207,,,,,,,,,,,,,, -15400,2.1581173,1.8992865,,,,,,,,,,,,,, -15500,1.6249721,1.8890573,,,,,,,,,,,,,, -15600,1.752845,1.8851142,,,,,,,,,,,,,, -15700,1.4920875,1.7865393,,,,,,,,,,,,,, -15800,1.8581041,1.9522622,,,,,,,,,,,,,, -15900,1.8216395,1.8446146,,,,,,,,,,,,,, -16000,1.7485597,1.8587819,,,,,,,,,,,,,, -16100,1.3955742,1.8481843,,,,,,,,,,,,,, -16200,1.5874435,1.8699064,,,,,,,,,,,,,, -16300,1.5327433,1.8736721,,,,,,,,,,,,,, -16400,1.6182258,1.9652252,,,,,,,,,,,,,, -16500,1.604727,1.8785279,,,,,,,,,,,,,, -16600,1.4265872,1.8820436,,,,,,,,,,,,,, -16700,1.7203844,1.8451422,,,,,,,,,,,,,, -16734,,,0.6482182741165161,1.4272924661636353,0.6008999943733215,1.6777706146240234,50000.0,0.468500018119812,2.439729690551758,10000.0,5662.831291437149,5896.757298469544,5662.831291437149,233.00697684288025,0.3311009407043457,0.0 -16800,1.5699452,1.8490207,,,,,,,,,,,,,, -16900,1.4669015,1.8735471,,,,,,,,,,,,,, -17000,1.5789155,1.8818268,,,,,,,,,,,,,, -17100,1.4799162,1.856448,,,,,,,,,,,,,, -17200,1.5301899,1.7776785,,,,,,,,,,,,,, -17300,1.2825135,1.8799646,,,,,,,,,,,,,, -17400,1.8088204,1.9041073,,,,,,,,,,,,,, -17500,1.6247445,1.9822867,,,,,,,,,,,,,, -17600,1.6237407,1.85021,,,,,,,,,,,,,, -17700,1.8673819,1.8215519,,,,,,,,,,,,,, -17800,1.6348825,1.9150143,,,,,,,,,,,,,, -17900,1.6389898,1.8583884,,,,,,,,,,,,,, -18000,1.6547824,1.8491566,,,,,,,,,,,,,, -18100,1.9096357,1.7242506,,,,,,,,,,,,,, -18200,1.590811,1.8184866,,,,,,,,,,,,,, -18258,,,0.6813018321990967,1.2537661790847778,0.6010000109672546,1.6689846515655518,50000.0,0.4749000370502472,2.4517030715942383,10000.0,6173.075484752655,6425.283870458603,6173.075484752655,251.19899654388428,0.3682253360748291,0.0 -18300,1.6955577,1.8376254,,,,,,,,,,,,,, -18400,1.684792,1.9226224,,,,,,,,,,,,,, -18500,1.7156203,1.8690073,,,,,,,,,,,,,, -18600,1.635929,1.8500384,,,,,,,,,,,,,, -18700,1.7400348,1.8609879,,,,,,,,,,,,,, -18800,1.7122585,1.8806127,,,,,,,,,,,,,, -18900,1.46085,1.8614799,,,,,,,,,,,,,, -19000,1.7187872,1.8591753,,,,,,,,,,,,,, -19100,1.5184863,1.7799209,,,,,,,,,,,,,, -19200,1.5032989,1.7630119,,,,,,,,,,,,,, -19300,1.6740049,1.9623873,,,,,,,,,,,,,, -19400,1.8747311,1.9409729,,,,,,,,,,,,,, -19500,1.5955851,1.8453822,,,,,,,,,,,,,, -19600,1.8858881,1.7034926,,,,,,,,,,,,,, -19700,1.5767794,1.8099973,,,,,,,,,,,,,, -19781,,,0.6559510231018066,1.365241765975952,0.5920000076293945,1.7166610956192017,50000.0,0.4679000079631805,2.46029019355774,10000.0,6683.053819656372,6954.625028371811,6683.053819656372,270.4668297767639,0.4093070030212402,0.0 -19800,1.8245595,1.8174362,,,,,,,,,,,,,, -19900,1.447946,1.7822129,,,,,,,,,,,,,, -20000,1.5741014,1.7867197,,,,,,,,,,,,,, -20100,1.6952785,1.7611675,,,,,,,,,,,,,, -20200,1.598883,1.8021758,,,,,,,,,,,,,, -20300,1.699018,1.9070101,,,,,,,,,,,,,, -20400,1.7848376,1.9077537,,,,,,,,,,,,,, -20500,1.7177423,1.8773743,,,,,,,,,,,,,, -20600,2.0096376,1.8229767,,,,,,,,,,,,,, -20700,1.4643246,1.7748125,,,,,,,,,,,,,, -20800,1.980762,1.8745964,,,,,,,,,,,,,, -20900,1.635468,1.8304875,,,,,,,,,,,,,, -21000,1.7598573,1.8575978,,,,,,,,,,,,,, -21100,1.6779073,1.9410906,,,,,,,,,,,,,, -21200,1.7978996,1.8111353,,,,,,,,,,,,,, -21300,1.4809932,1.8631594,,,,,,,,,,,,,, -21304,,,0.6759207248687744,1.2931432723999023,0.6177799701690674,1.5928702354431152,50000.0,0.4934000372886657,2.318434953689575,10000.0,7193.051939487457,7487.018528699875,7193.051939487457,292.76573491096497,0.4507136344909668,0.0 -21400,1.5647839,1.8260627,,,,,,,,,,,,,, -21500,1.5743439,1.8235484,,,,,,,,,,,,,, -21600,1.7080536,1.8152432,,,,,,,,,,,,,, -21700,1.5721246,1.6506373,,,,,,,,,,,,,, -21800,1.6056107,1.7365491,,,,,,,,,,,,,, -21900,1.6881657,1.7603065,,,,,,,,,,,,,, -22000,1.6397196,1.8551646,,,,,,,,,,,,,, -22100,1.360387,1.7335358,,,,,,,,,,,,,, -22200,1.6203647,1.8265887,,,,,,,,,,,,,, -22300,1.9340633,1.8156092,,,,,,,,,,,,,, -22400,1.6342362,1.7575753,,,,,,,,,,,,,, -22500,1.494328,1.6066518,,,,,,,,,,,,,, -22600,1.9667689,1.8455245,,,,,,,,,,,,,, -22700,1.6378584,1.7251662,,,,,,,,,,,,,, -22800,1.9408168,1.794676,,,,,,,,,,,,,, -22827,,,0.6716158986091614,1.3155529499053955,0.6108999848365784,1.6090680360794067,50000.0,0.490200012922287,2.343901395797729,10000.0,7703.096212387085,8021.0365245342255,7703.096212387085,316.6596989631653,0.4778110980987549,0.0 -22900,1.7171754,1.6670735,,,,,,,,,,,,,, -23000,1.5752169,1.7022841,,,,,,,,,,,,,, -23100,1.6275517,1.7220675,,,,,,,,,,,,,, -23200,1.5740683,1.6625721,,,,,,,,,,,,,, -23300,1.8588562,1.8669015,,,,,,,,,,,,,, -23400,1.619627,1.8520645,,,,,,,,,,,,,, -23500,2.0177457,1.8558372,,,,,,,,,,,,,, -23600,1.669084,1.7182281,,,,,,,,,,,,,, -23700,1.7722627,1.6985018,,,,,,,,,,,,,, -23800,1.70319,1.6795974,,,,,,,,,,,,,, -23900,1.7411343,1.823328,,,,,,,,,,,,,, -24000,1.8247881,1.673985,,,,,,,,,,,,,, -24100,1.686493,1.6976893,,,,,,,,,,,,,, -24200,1.8983037,1.8371173,,,,,,,,,,,,,, -24300,2.3118422,1.730114,,,,,,,,,,,,,, -24351,,,0.6676697731018066,1.3202012777328491,0.6128999590873718,1.6204156875610352,50000.0,0.4942000210285187,2.327371120452881,10000.0,8213.247112512589,8553.60300731659,8213.247112512589,338.99554920196533,0.5040531158447266,0.0 -24400,1.621727,1.8432226,,,,,,,,,,,,,, -24500,1.6305685,1.7607614,,,,,,,,,,,,,, -24600,2.070208,1.8759662,,,,,,,,,,,,,, -24700,1.8148077,1.8195184,,,,,,,,,,,,,, -24800,1.6508936,1.7881356,,,,,,,,,,,,,, -24900,1.9155413,1.7835021,,,,,,,,,,,,,, -25000,1.6596445,1.6403441,,,,,,,,,,,,,, -25100,2.0096433,1.7676795,,,,,,,,,,,,,, -25200,1.8129535,1.8298137,,,,,,,,,,,,,, -25300,1.4542346,1.6520476,,,,,,,,,,,,,, -25400,1.606117,1.7542522,,,,,,,,,,,,,, -25500,1.4898183,1.7489406,,,,,,,,,,,,,, -25600,1.7847693,1.8248256,,,,,,,,,,,,,, -25700,1.7846954,1.7928994,,,,,,,,,,,,,, -25800,1.7192991,1.7155837,,,,,,,,,,,,,, -25875,,,0.71097731590271,1.140007257461548,0.6279399991035461,1.5446945428848269,50000.0,0.5009000301361084,2.293175458908081,10000.0,8723.437615156174,9087.882278203964,8723.437615156174,363.0035355091095,0.53145432472229,0.0 -25900,1.6221653,1.7652483,,,,,,,,,,,,,, -26000,1.7807631,1.7221488,,,,,,,,,,,,,, -26100,1.9305447,1.8658715,,,,,,,,,,,,,, -26200,1.7038043,1.8197566,,,,,,,,,,,,,, -26300,1.9170132,1.7724133,,,,,,,,,,,,,, -26400,1.6657882,1.835736,,,,,,,,,,,,,, -26500,1.8108028,1.70577,,,,,,,,,,,,,, -26600,1.6130413,1.7046363,,,,,,,,,,,,,, -26700,1.8972938,1.6028261,,,,,,,,,,,,,, -26800,2.0549457,1.7405088,,,,,,,,,,,,,, -26900,1.540445,1.8216386,,,,,,,,,,,,,, -27000,1.7437993,1.908591,,,,,,,,,,,,,, -27100,1.6555004,1.7253122,,,,,,,,,,,,,, -27200,1.9654363,1.820762,,,,,,,,,,,,,, -27300,1.810817,1.7223752,,,,,,,,,,,,,, -27398,,,0.697265625,1.176541090011597,0.623479962348938,1.5636550188064575,50000.0,0.4933000206947326,2.3293328285217285,10000.0,9233.4411110878,9622.420518398283,9233.4411110878,387.4523963928223,0.5643725395202637,0.0 -27400,1.6504576,1.7278097,,,,,,,,,,,,,, -27500,1.6504529,1.7094458,,,,,,,,,,,,,, -27600,1.8402139,1.8837976,,,,,,,,,,,,,, -27700,1.6452296,1.8393841,,,,,,,,,,,,,, -27800,1.656801,1.7793496,,,,,,,,,,,,,, -27900,2.1041355,1.8353266,,,,,,,,,,,,,, -28000,1.8112139,1.7785846,,,,,,,,,,,,,, -28100,2.1600478,1.7723063,,,,,,,,,,,,,, -28200,1.8540549,1.7313846,,,,,,,,,,,,,, -28300,1.7649734,1.7607654,,,,,,,,,,,,,, -28400,1.7565321,1.7791789,,,,,,,,,,,,,, -28500,1.8413336,1.8090866,,,,,,,,,,,,,, -28600,1.7036871,1.5959023,,,,,,,,,,,,,, -28700,1.9238262,1.7499293,,,,,,,,,,,,,, -28800,1.59924,1.7564304,,,,,,,,,,,,,, -28900,1.6435316,1.6771834,,,,,,,,,,,,,, -28922,,,0.6786311864852905,1.2781400680541992,0.6108599901199341,1.6249542236328125,50000.0,0.4853000342845917,2.3837692737579346,10000.0,9743.443282604218,10156.39930844307,9743.443282604218,411.3435943126679,0.5972199440002441,0.0 -29000,1.673646,1.7267118,,,,,,,,,,,,,, -29100,2.261179,1.8747513,,,,,,,,,,,,,, -29200,1.7620246,1.748618,,,,,,,,,,,,,, -29300,1.7365786,1.716656,,,,,,,,,,,,,, -29400,1.7282722,1.6408132,,,,,,,,,,,,,, -29500,1.8832142,1.7373397,,,,,,,,,,,,,, -29600,1.7925326,1.6816558,,,,,,,,,,,,,, -29700,2.0005193,1.8490238,,,,,,,,,,,,,, -29800,1.7610612,1.7118665,,,,,,,,,,,,,, -29900,1.7685024,1.726465,,,,,,,,,,,,,, -30000,1.7492185,1.6148787,,,,,,,,,,,,,, -30100,1.8815488,1.7490023,,,,,,,,,,,,,, -30200,1.7785444,1.7643044,,,,,,,,,,,,,, -30300,1.6879904,1.6470015,,,,,,,,,,,,,, -30400,1.8084781,1.740499,,,,,,,,,,,,,, -30446,,,0.6844507455825806,1.2520849704742432,0.6162399649620056,1.5863946676254272,50000.0,0.4905000329017639,2.291940689086914,10000.0,10253.598109006882,10691.41821050644,10253.598109006882,436.1221706867218,0.6296067237854004,0.0 -30500,1.7651083,1.7745965,,,,,,,,,,,,,, -30600,1.8701106,1.7078753,,,,,,,,,,,,,, -30700,1.8151104,1.6875101,,,,,,,,,,,,,, -30800,1.7732607,1.7497958,,,,,,,,,,,,,, -30900,1.5790584,1.7130486,,,,,,,,,,,,,, -31000,1.7156963,1.6362514,,,,,,,,,,,,,, -31100,1.7677724,1.7903966,,,,,,,,,,,,,, -31200,1.7419535,1.5731808,,,,,,,,,,,,,, -31300,1.7673794,1.7136232,,,,,,,,,,,,,, -31400,1.7114315,1.6108803,,,,,,,,,,,,,, -31500,1.6717548,1.7447582,,,,,,,,,,,,,, -31600,1.8261055,1.6668507,,,,,,,,,,,,,, -31700,1.6858137,1.7234749,,,,,,,,,,,,,, -31800,1.6430606,1.6117003,,,,,,,,,,,,,, -31900,1.8352444,1.7618588,,,,,,,,,,,,,, -31970,,,0.6871213316917419,1.239983081817627,0.6253799796104431,1.5454217195510864,50000.0,0.5026000142097473,2.285898447036743,10000.0,10763.550921201706,11225.136703014374,10763.550921201706,459.7998206615448,0.6647074222564697,0.0 -32000,1.8790128,1.7236398,,,,,,,,,,,,,, -32100,1.7497479,1.681829,,,,,,,,,,,,,, -32200,1.6620907,1.6656728,,,,,,,,,,,,,, -32300,1.7691687,1.6925195,,,,,,,,,,,,,, -32400,1.6590084,1.7690268,,,,,,,,,,,,,, -32500,1.664187,1.66433,,,,,,,,,,,,,, -32600,1.9565666,1.7107548,,,,,,,,,,,,,, -32700,1.7956719,1.6677231,,,,,,,,,,,,,, -32800,1.81928,1.715868,,,,,,,,,,,,,, -32900,1.9400101,1.7346652,,,,,,,,,,,,,, -33000,1.7166119,1.6447955,,,,,,,,,,,,,, -33100,1.9623439,1.8190858,,,,,,,,,,,,,, -33200,1.6269416,1.7361016,,,,,,,,,,,,,, -33300,1.9235709,1.6880571,,,,,,,,,,,,,, -33400,1.7284061,1.7267544,,,,,,,,,,,,,, -33494,,,0.66796875,1.3174986839294434,0.6146999597549438,1.6063886880874634,50000.0,0.4869000315666199,2.363152027130127,10000.0,11273.611344337463,11759.572400569916,11273.611344337463,484.0868966579437,0.7003006935119629,0.0 -33500,1.6507152,1.679858,,,,,,,,,,,,,, -33600,1.6282043,1.7097499,,,,,,,,,,,,,, -33700,1.674408,1.7413312,,,,,,,,,,,,,, -33800,1.7445128,1.7190251,,,,,,,,,,,,,, -33900,1.9392247,1.776444,,,,,,,,,,,,,, -34000,1.845139,1.810733,,,,,,,,,,,,,, -34100,1.7578568,1.5899494,,,,,,,,,,,,,, -34200,1.9164815,1.7044735,,,,,,,,,,,,,, -34300,1.7575989,1.6456776,,,,,,,,,,,,,, -34400,1.7075189,1.7849164,,,,,,,,,,,,,, -34500,2.015447,1.7835306,,,,,,,,,,,,,, -34600,1.833047,1.7679619,,,,,,,,,,,,,, -34700,2.0095348,1.746776,,,,,,,,,,,,,, -34800,1.8810923,1.667825,,,,,,,,,,,,,, -34900,1.8336072,1.712987,,,,,,,,,,,,,, -35000,1.6736724,1.6159811,,,,,,,,,,,,,, -35018,,,0.7241310477256775,1.0600242614746094,0.6274799704551697,1.5452181100845337,50000.0,0.5064000487327576,2.267633438110352,10000.0,11783.628342866898,12294.303589820862,11783.628342866898,508.71211862564087,0.7353689670562744,0.0 -35100,1.8102196,1.8181446,,,,,,,,,,,,,, -35200,1.79937,1.6965514,,,,,,,,,,,,,, -35300,1.6028539,1.7437991,,,,,,,,,,,,,, -35400,2.4117258,1.7447675,,,,,,,,,,,,,, -35500,1.880332,1.6751668,,,,,,,,,,,,,, -35600,1.8523171,1.694917,,,,,,,,,,,,,, -35700,1.7245709,1.6917762,,,,,,,,,,,,,, -35800,1.827348,1.7048515,,,,,,,,,,,,,, -35900,1.6675735,1.6684202,,,,,,,,,,,,,, -36000,1.7477683,1.677294,,,,,,,,,,,,,, -36100,1.7132988,1.5932637,,,,,,,,,,,,,, -36200,1.7168177,1.7380403,,,,,,,,,,,,,, -36300,1.7944915,1.6608806,,,,,,,,,,,,,, -36400,1.8593916,1.6871495,,,,,,,,,,,,,, -36500,1.8318273,1.6922866,,,,,,,,,,,,,, -36542,,,0.7159797549247742,1.0924153327941897,0.6389200091362,1.4935661554336548,50000.0,0.5040000081062317,2.234344959259033,10000.0,12293.778350830078,12829.732397556303,12293.778350830078,533.906332731247,0.7664787769317627,0.0 -36600,2.086481,1.72033,,,,,,,,,,,,,, -36700,1.7922901,1.7276629,,,,,,,,,,,,,, -36800,1.8727329,1.8203524,,,,,,,,,,,,,, -36900,1.7265553,1.6628988,,,,,,,,,,,,,, -37000,1.7361716,1.6705029,,,,,,,,,,,,,, -37100,1.7526133,1.6279393,,,,,,,,,,,,,, -37200,1.7876924,1.7420837,,,,,,,,,,,,,, -37300,1.6637363,1.7491574,,,,,,,,,,,,,, -37400,1.8277749,1.7281523,,,,,,,,,,,,,, -37500,1.8473349,1.7309017,,,,,,,,,,,,,, -37600,1.6772214,1.7466025,,,,,,,,,,,,,, -37700,2.0861647,1.669658,,,,,,,,,,,,,, -37800,1.6269115,1.6250898,,,,,,,,,,,,,, -37900,1.5913342,1.6883608,,,,,,,,,,,,,, -38000,1.6045572,1.7416897,,,,,,,,,,,,,, -38066,,,0.693757951259613,1.1842937469482422,0.6302199959754944,1.5333516597747805,50000.0,0.497700035572052,2.274075508117676,10000.0,12803.850924253464,13362.639755010605,12803.850924253464,556.6493656635284,0.8046107292175293,0.0 -38100,1.843256,1.613426,,,,,,,,,,,,,, -38200,1.7583386,1.6918314,,,,,,,,,,,,,, -38300,1.7972542,1.6538175,,,,,,,,,,,,,, -38400,1.7670562,1.6806566,,,,,,,,,,,,,, -38500,1.8208889,1.6545928,,,,,,,,,,,,,, -38600,1.744687,1.6415082,,,,,,,,,,,,,, -38700,1.5992141,1.607683,,,,,,,,,,,,,, -38800,1.7674469,1.6237495,,,,,,,,,,,,,, -38900,1.8479879,1.6706278,,,,,,,,,,,,,, -39000,1.9162896,1.6671913,,,,,,,,,,,,,, -39100,1.66502,1.6548678,,,,,,,,,,,,,, -39200,1.7819912,1.6637094,,,,,,,,,,,,,, -39300,1.8311669,1.6700461,,,,,,,,,,,,,, -39400,1.9639376,1.7272633,,,,,,,,,,,,,, -39500,1.9392887,1.6427093,,,,,,,,,,,,,, -39590,,,0.6944355964660645,1.1813580989837646,0.6322799921035767,1.505229353904724,50000.0,0.5046000480651855,2.22383189201355,10000.0,13313.921353816986,13894.354410886765,13313.921353816986,578.2093193531036,0.8341331481933594,0.0 -39600,1.7505672,1.6674638,,,,,,,,,,,,,, -39700,1.9613876,1.7272254,,,,,,,,,,,,,, -39800,1.8012695,1.7123449,,,,,,,,,,,,,, -39900,1.9554296,1.8077309,,,,,,,,,,,,,, -40000,1.7815026,1.6761895,,,,,,,,,,,,,, -40100,1.8864236,1.6333756,,,,,,,,,,,,,, -40200,1.9030343,1.78957,,,,,,,,,,,,,, -40300,1.9010162,1.6095635,,,,,,,,,,,,,, -40400,2.251788,1.8061649,,,,,,,,,,,,,, -40500,1.7912827,1.8202616,,,,,,,,,,,,,, -40600,1.7035196,1.60539,,,,,,,,,,,,,, -40700,1.820652,1.7427611,,,,,,,,,,,,,, -40800,1.7602675,1.7376766,,,,,,,,,,,,,, -40900,1.7891527,1.6351511,,,,,,,,,,,,,, -41000,1.9819398,1.590733,,,,,,,,,,,,,, -41100,1.7803731,1.8011849,,,,,,,,,,,,,, -41114,,,0.6988998651504517,1.1760797500610352,0.6380599737167358,1.4915393590927124,50000.0,0.5070000290870667,2.210167407989502,10000.0,13823.889991521835,14426.871874332428,13823.889991521835,600.675609588623,0.8630940914154053,0.0 -41200,1.8331008,1.7792016,,,,,,,,,,,,,, -41300,1.7876786,1.669358,,,,,,,,,,,,,, -41400,1.6561223,1.5040668,,,,,,,,,,,,,, -41500,1.7776436,1.6016831,,,,,,,,,,,,,, -41600,1.8661537,1.5930583,,,,,,,,,,,,,, -41700,1.7838821,1.5921242,,,,,,,,,,,,,, -41800,1.7138312,1.627726,,,,,,,,,,,,,, -41900,1.9294516,1.7303666,,,,,,,,,,,,,, -42000,1.9937539,1.675832,,,,,,,,,,,,,, -42100,1.7169963,1.5437467,,,,,,,,,,,,,, -42200,1.9768007,1.5863779,,,,,,,,,,,,,, -42300,1.8289732,1.6682421,,,,,,,,,,,,,, -42400,1.9772061,1.6849482,,,,,,,,,,,,,, -42500,1.9184672,1.6793021,,,,,,,,,,,,,, -42600,1.6791939,1.6096675,,,,,,,,,,,,,, -42639,,,0.6919044852256775,1.2171635627746582,0.632319986820221,1.5132615566253662,50000.0,0.4996000230312347,2.271448135375977,10000.0,14334.058574199677,14959.933693647385,14334.058574199677,623.4824783802032,0.8966398239135742,0.0 -42700,1.8950365,1.7489691,,,,,,,,,,,,,, -42800,1.7969632,1.613241,,,,,,,,,,,,,, -42900,2.0907059,1.704772,,,,,,,,,,,,,, -43000,1.9421245,1.6977012,,,,,,,,,,,,,, -43100,1.7918833,1.6021693,,,,,,,,,,,,,, -43200,1.603777,1.4897631,,,,,,,,,,,,,, -43300,1.9486448,1.7472234,,,,,,,,,,,,,, -43400,1.8700302,1.6678042,,,,,,,,,,,,,, -43500,1.8442279,1.7541395,,,,,,,,,,,,,, -43600,1.8258753,1.6474942,,,,,,,,,,,,,, -43700,1.9364176,1.6847959,,,,,,,,,,,,,, -43800,1.782053,1.6303037,,,,,,,,,,,,,, -43900,1.9433709,1.7082014,,,,,,,,,,,,,, -44000,1.7687011,1.6190622,,,,,,,,,,,,,, -44100,1.8719585,1.6116143,,,,,,,,,,,,,, -44163,,,0.7145447731018066,1.0940722227096558,0.633359968662262,1.494249939918518,50000.0,0.5051000118255615,2.2277207374572754,10000.0,14844.04565525055,15492.22899198532,14844.04565525055,645.7052969932556,0.9284241199493408,0.0 -44200,1.9479209,1.7567931,,,,,,,,,,,,,, -44300,2.0151944,1.75437,,,,,,,,,,,,,, -44400,1.7226173,1.656322,,,,,,,,,,,,,, -44500,1.8283472,1.6030318,,,,,,,,,,,,,, -44600,1.8636478,1.5328699,,,,,,,,,,,,,, -44700,1.7010382,1.6614108,,,,,,,,,,,,,, -44800,1.8231809,1.5864965,,,,,,,,,,,,,, -44900,1.9319413,1.6155562,,,,,,,,,,,,,, -45000,1.9190985,1.5308965,,,,,,,,,,,,,, -45100,1.9012386,1.5118175,,,,,,,,,,,,,, -45200,1.8336297,1.5818502,,,,,,,,,,,,,, -45300,1.8676621,1.6918259,,,,,,,,,,,,,, -45400,1.7124785,1.612776,,,,,,,,,,,,,, -45500,1.7584237,1.6623973,,,,,,,,,,,,,, -45600,1.8227403,1.7050564,,,,,,,,,,,,,, -45687,,,0.7106783986091614,1.1211135387420654,0.6406999826431274,1.4792100191116333,50000.0,0.5095000267028809,2.23911714553833,10000.0,15354.069417715073,16024.108890295029,15354.069417715073,667.4693231582642,0.9678726196289062,0.0 -45700,1.6909579,1.6111226,,,,,,,,,,,,,, -45800,1.8290106,1.6110284,,,,,,,,,,,,,, -45900,1.7581534,1.6463486,,,,,,,,,,,,,, -46000,1.700196,1.6923373,,,,,,,,,,,,,, -46100,1.6946293,1.5865799,,,,,,,,,,,,,, -46200,1.7659096,1.5232494,,,,,,,,,,,,,, -46300,1.8358312,1.6336753,,,,,,,,,,,,,, -46400,1.9378517,1.687703,,,,,,,,,,,,,, -46500,1.9523243,1.6517147,,,,,,,,,,,,,, -46600,1.8890588,1.5035175,,,,,,,,,,,,,, -46700,1.7590331,1.5910971,,,,,,,,,,,,,, -46800,1.9886675,1.6513491,,,,,,,,,,,,,, -46900,1.8549825,1.6044073,,,,,,,,,,,,,, -47000,1.9281929,1.5432268,,,,,,,,,,,,,, -47100,1.8313179,1.6762751,,,,,,,,,,,,,, -47200,1.9278216,1.7341523,,,,,,,,,,,,,, -47211,,,0.6990792155265808,1.1850415468215942,0.6307199597358704,1.5144232511520386,50000.0,0.5143000483512878,2.2181320190429688,10000.0,15864.12791633606,16554.72419142723,15864.12791633606,687.938235282898,1.0020246505737305,0.0 -47300,1.6969334,1.5566362,,,,,,,,,,,,,, -47400,1.8450673,1.7011827,,,,,,,,,,,,,, -47500,1.9834082,1.6890984,,,,,,,,,,,,,, -47600,1.9287521,1.6885302,,,,,,,,,,,,,, -47700,1.8505782,1.6560569,,,,,,,,,,,,,, -47800,1.8874383,1.6801952,,,,,,,,,,,,,, -47900,1.8865806,1.6061541,,,,,,,,,,,,,, -48000,1.9639935,1.6040319,,,,,,,,,,,,,, -48100,1.7384669,1.5921786,,,,,,,,,,,,,, -48200,2.0555499,1.6915326,,,,,,,,,,,,,, -48300,1.6600063,1.6036294,,,,,,,,,,,,,, -48400,1.9639567,1.6273984,,,,,,,,,,,,,, -48500,1.9610837,1.5881305,,,,,,,,,,,,,, -48600,1.6881812,1.6141171,,,,,,,,,,,,,, -48700,1.890047,1.6261495,,,,,,,,,,,,,, -48734,,,0.71097731590271,1.1214135885238647,0.6475399732589722,1.454314947128296,50000.0,0.517300009727478,2.1623549461364746,10000.0,16374.140513420103,17082.979299545288,16374.140513420103,706.0882074832916,1.0412023067474363,0.0 -48800,1.7530595,1.5406133,,,,,,,,,,,,,, -48900,1.6498876,1.5857121,,,,,,,,,,,,,, -49000,1.7276083,1.6376655,,,,,,,,,,,,,, -49100,1.9515817,1.6592089,,,,,,,,,,,,,, -49200,1.912226,1.7200024,,,,,,,,,,,,,, -49300,2.0016387,1.6420835,,,,,,,,,,,,,, -49400,1.9515021,1.542825,,,,,,,,,,,,,, -49500,1.7681922,1.5954677,,,,,,,,,,,,,, -49600,1.9616506,1.5917188,,,,,,,,,,,,,, -49700,1.7057059,1.5208291,,,,,,,,,,,,,, -49800,1.8105524,1.6909047,,,,,,,,,,,,,, -49900,2.0654392,1.6366429,,,,,,,,,,,,,, -50000,1.8654251,1.6299276,,,,,,,,,,,,,, -50100,1.7926368,1.5240927,,,,,,,,,,,,,, -50200,2.2420983,1.6085578,,,,,,,,,,,,,, -50257,,,0.7004743218421936,1.1613342761993408,0.6417399644851685,1.4614994525909424,50000.0,0.513700008392334,2.207813501358032,10000.0,16884.34206557274,17612.652390241623,16884.34206557274,725.4657227993011,1.08018159866333,0.0 -50300,1.6348014,1.5882084,,,,,,,,,,,,,, -50400,1.8671284,1.5183579,,,,,,,,,,,,,, -50500,1.7580118,1.561411,,,,,,,,,,,,,, -50600,1.8820904,1.5727718,,,,,,,,,,,,,, -50700,1.7228397,1.5542247,,,,,,,,,,,,,, -50800,1.9634002,1.687189,,,,,,,,,,,,,, -50900,1.7675357,1.6049445,,,,,,,,,,,,,, -51000,1.8655646,1.6334423,,,,,,,,,,,,,, -51100,1.7952927,1.5546,,,,,,,,,,,,,, -51200,1.750461,1.5270437,,,,,,,,,,,,,, -51300,1.9735334,1.5840155,,,,,,,,,,,,,, -51400,1.775698,1.5498418,,,,,,,,,,,,,, -51500,1.7382011,1.6726651,,,,,,,,,,,,,, -51600,2.0175827,1.6099567,,,,,,,,,,,,,, -51700,2.196378,1.5054471,,,,,,,,,,,,,, -51781,,,0.7308274507522583,1.039522409439087,0.6245399713516235,1.5539544820785522,50000.0,0.5035000443458557,2.28792405128479,10000.0,17394.389335393906,18141.577426195145,17394.389335393906,744.2578091621399,1.111807346343994,0.0 -51800,1.8045002,1.6696174,,,,,,,,,,,,,, -51900,1.8587966,1.5165132,,,,,,,,,,,,,, -52000,1.7261145,1.5404348,,,,,,,,,,,,,, -52100,1.8552897,1.544865,,,,,,,,,,,,,, -52200,1.8687533,1.6292604,,,,,,,,,,,,,, -52300,1.8390149,1.522418,,,,,,,,,,,,,, -52400,1.9251292,1.6198348,,,,,,,,,,,,,, -52500,1.8642851,1.5882123,,,,,,,,,,,,,, -52600,1.997957,1.6502882,,,,,,,,,,,,,, -52700,1.865122,1.563566,,,,,,,,,,,,,, -52800,1.8782322,1.6492598,,,,,,,,,,,,,, -52900,1.8842041,1.6700782,,,,,,,,,,,,,, -53000,2.0358613,1.5844465,,,,,,,,,,,,,, -53100,1.8028638,1.4943044,,,,,,,,,,,,,, -53200,1.9400612,1.6456803,,,,,,,,,,,,,, -53300,1.875195,1.6515476,,,,,,,,,,,,,, -53306,,,0.7253667116165161,1.0528578758239746,0.6495199799537659,1.4327813386917114,50000.0,0.5109000205993652,2.1753416061401367,10000.0,17904.58046245575,18669.87014555931,17904.58046245575,762.2627856731415,1.1545979976654053,0.0 -53400,2.0823548,1.5205255,,,,,,,,,,,,,, -53500,1.7752969,1.6347775,,,,,,,,,,,,,, -53600,1.7867182,1.6011643,,,,,,,,,,,,,, -53700,1.9572037,1.7317845,,,,,,,,,,,,,, -53800,1.8778385,1.5410982,,,,,,,,,,,,,, -53900,1.7361231,1.4710326,,,,,,,,,,,,,, -54000,1.7930825,1.6118017,,,,,,,,,,,,,, -54100,1.9433362,1.5098679,,,,,,,,,,,,,, -54200,2.079601,1.6437645,,,,,,,,,,,,,, -54300,1.779267,1.5314667,,,,,,,,,,,,,, -54400,1.7512835,1.4835404,,,,,,,,,,,,,, -54500,1.9693733,1.5597855,,,,,,,,,,,,,, -54600,1.9617187,1.5943475,,,,,,,,,,,,,, -54700,2.164371,1.5828756,,,,,,,,,,,,,, -54800,1.8139974,1.5466638,,,,,,,,,,,,,, -54830,,,0.7131297588348389,1.118245244026184,0.6433199644088745,1.4688860177993774,50000.0,0.5209000110626221,2.2052316665649414,10000.0,18414.588896036148,19197.308834314343,18414.588896036148,779.59827876091,1.1950109004974363,0.0 -54900,2.0028837,1.6466252,,,,,,,,,,,,,, -55000,1.8407911,1.6160696,,,,,,,,,,,,,, -55100,1.7930739,1.5595038,,,,,,,,,,,,,, -55200,2.0484493,1.6205152,,,,,,,,,,,,,, -55300,1.7556667,1.5978296,,,,,,,,,,,,,, -55400,1.9208615,1.5144967,,,,,,,,,,,,,, -55500,2.0286663,1.6599411,,,,,,,,,,,,,, -55600,1.8634287,1.5947844,,,,,,,,,,,,,, -55700,1.865218,1.6366855,,,,,,,,,,,,,, -55800,1.9206089,1.615535,,,,,,,,,,,,,, -55900,1.7934539,1.4748673,,,,,,,,,,,,,, -56000,1.7768103,1.4551754,,,,,,,,,,,,,, -56100,1.8770212,1.6106322,,,,,,,,,,,,,, -56200,1.8082895,1.6615998,,,,,,,,,,,,,, -56300,1.8269533,1.6840568,,,,,,,,,,,,,, -56354,,,0.6937978267669678,1.2075111865997314,0.6305800080299377,1.5379019975662231,50000.0,0.4989000260829925,2.3033556938171387,10000.0,18924.505990982056,19725.887673854828,18924.505990982056,798.1648647785187,1.235038995742798,0.0 -56400,1.8194814,1.4889103,,,,,,,,,,,,,, -56500,2.1110618,1.52549,,,,,,,,,,,,,, -56600,1.9484339,1.6351243,,,,,,,,,,,,,, -56700,1.8433391,1.5275308,,,,,,,,,,,,,, -56800,1.9370222,1.6217431,,,,,,,,,,,,,, -56900,1.8750519,1.5844893,,,,,,,,,,,,,, -57000,1.9056202,1.6074219,,,,,,,,,,,,,, -57100,1.9354361,1.5328671,,,,,,,,,,,,,, -57200,1.9783313,1.639324,,,,,,,,,,,,,, -57300,1.8281724,1.5492182,,,,,,,,,,,,,, -57400,1.8829495,1.5140469,,,,,,,,,,,,,, -57500,1.7238519,1.6119324,,,,,,,,,,,,,, -57600,1.8522204,1.6337497,,,,,,,,,,,,,, -57700,1.8867902,1.59762,,,,,,,,,,,,,, -57800,1.9545448,1.6317542,,,,,,,,,,,,,, -57879,,,0.7224768400192261,1.0849560499191284,0.6535800099372864,1.421253681182861,50000.0,0.5266000032424927,2.125169515609741,10000.0,19434.580425977707,20252.75336956978,19434.580425977707,814.8658380508423,1.271176815032959,0.0 -57900,1.8224814,1.4721376,,,,,,,,,,,,,, -58000,1.8745708,1.3914025,,,,,,,,,,,,,, -58100,2.4314485,1.5397552,,,,,,,,,,,,,, -58200,2.0437639,1.583492,,,,,,,,,,,,,, -58300,1.982026,1.5170788,,,,,,,,,,,,,, -58400,1.8756821,1.557419,,,,,,,,,,,,,, -58500,2.4487329,1.6471789,,,,,,,,,,,,,, -58600,1.945936,1.5895895,,,,,,,,,,,,,, -58700,1.9453275,1.624421,,,,,,,,,,,,,, -58800,1.8720369,1.4833075,,,,,,,,,,,,,, -58900,1.8944106,1.5904374,,,,,,,,,,,,,, -59000,1.8818306,1.506001,,,,,,,,,,,,,, -59100,2.0090837,1.6706318,,,,,,,,,,,,,, -59200,1.9442284,1.6967343,,,,,,,,,,,,,, -59300,1.9047662,1.5070716,,,,,,,,,,,,,, -59400,2.117815,1.5436045,,,,,,,,,,,,,, -59404,,,0.7204440236091614,1.080566167831421,0.6579599976539612,1.4034535884857178,50000.0,0.5297000408172607,2.1135621070861816,10000.0,19944.806575775143,20780.083662986755,19944.806575775143,831.8755283355713,1.3114573955535889,0.0 -59500,1.8114829,1.4817004,,,,,,,,,,,,,, -59600,1.8120013,1.6201015,,,,,,,,,,,,,, -59700,1.9217182,1.6367259,,,,,,,,,,,,,, -59800,1.958583,1.5074669,,,,,,,,,,,,,, -59900,1.9517846,1.5128733,,,,,,,,,,,,,, -60000,2.117896,1.5742066,,,,,,,,,,,,,, -60100,1.9644114,1.5734701,,,,,,,,,,,,,, -60200,1.909434,1.4513942,,,,,,,,,,,,,, -60300,2.183339,1.6515548,,,,,,,,,,,,,, -60400,2.0344732,1.4777036,,,,,,,,,,,,,, -60500,1.9177207,1.4708762,,,,,,,,,,,,,, -60600,2.0328999,1.6683762,,,,,,,,,,,,,, -60700,1.9441714,1.5453674,,,,,,,,,,,,,, -60800,2.1037877,1.6519511,,,,,,,,,,,,,, -60900,1.8524418,1.5478833,,,,,,,,,,,,,, -60927,,,0.750398576259613,0.9552063941955566,0.6536399722099304,1.428490400314331,50000.0,0.5304000377655029,2.15209436416626,10000.0,20454.81146836281,21307.03911685944,20454.81146836281,848.7308654785156,1.3519139289855957,0.0 -61000,2.085944,1.5776994,,,,,,,,,,,,,, -61100,2.0935557,1.4934092,,,,,,,,,,,,,, -61200,2.0811508,1.5992005,,,,,,,,,,,,,, -61300,1.783974,1.531671,,,,,,,,,,,,,, -61400,1.868508,1.5918486,,,,,,,,,,,,,, -61500,1.8560275,1.6410114,,,,,,,,,,,,,, -61600,1.8785232,1.594636,,,,,,,,,,,,,, -61700,1.9012752,1.6850473,,,,,,,,,,,,,, -61800,1.9285798,1.4436809,,,,,,,,,,,,,, -61900,1.849795,1.4836705,,,,,,,,,,,,,, -62000,1.9567184,1.5384611,,,,,,,,,,,,,, -62100,2.0111287,1.4646078,,,,,,,,,,,,,, -62200,1.876768,1.4693439,,,,,,,,,,,,,, -62300,1.9212875,1.5095059,,,,,,,,,,,,,, -62400,1.9484848,1.554996,,,,,,,,,,,,,, -62452,,,0.7303292155265808,1.0258498191833496,0.6525799632072449,1.428292751312256,50000.0,0.5215000510215759,2.1652936935424805,10000.0,20965.004019737244,21834.207016468048,20965.004019737244,865.6110301017761,1.3921701908111572,0.0 -62500,2.182581,1.7046847,,,,,,,,,,,,,, -62600,2.0696573,1.6043732,,,,,,,,,,,,,, -62700,1.8305627,1.5138057,,,,,,,,,,,,,, -62800,1.9495441,1.5687258,,,,,,,,,,,,,, -62900,1.9919828,1.4891477,,,,,,,,,,,,,, -63000,2.140699,1.5112219,,,,,,,,,,,,,, -63100,2.1922922,1.6848392,,,,,,,,,,,,,, -63200,2.144757,1.5980564,,,,,,,,,,,,,, -63300,1.8659102,1.5487169,,,,,,,,,,,,,, -63400,1.7922266,1.4177701,,,,,,,,,,,,,, -63500,1.8889337,1.5290191,,,,,,,,,,,,,, -63600,1.9701263,1.6738806,,,,,,,,,,,,,, -63700,1.931637,1.5927172,,,,,,,,,,,,,, -63800,2.022091,1.4995347,,,,,,,,,,,,,, -63900,2.020367,1.6200173,,,,,,,,,,,,,, -63977,,,0.7126116156578064,1.0953108072280884,0.6471799612045288,1.4452688694000244,50000.0,0.5161000490188599,2.201335906982422,10000.0,21475.078028202057,22361.16133069992,21475.078028202057,882.3958539962769,1.4340605735778809,0.0 -64000,1.9887615,1.5405817,,,,,,,,,,,,,, -64100,1.8614202,1.4158353,,,,,,,,,,,,,, -64200,1.7272972,1.5384015,,,,,,,,,,,,,, -64300,1.9922302,1.5124853,,,,,,,,,,,,,, -64400,1.8739427,1.4517164,,,,,,,,,,,,,, -64500,2.039133,1.5224626,,,,,,,,,,,,,, -64600,1.8535347,1.5347347,,,,,,,,,,,,,, -64700,2.046722,1.6377012,,,,,,,,,,,,,, -64800,2.0695317,1.5243806,,,,,,,,,,,,,, -64900,1.882358,1.6468196,,,,,,,,,,,,,, -65000,1.8086451,1.6803033,,,,,,,,,,,,,, -65100,1.9210861,1.5856311,,,,,,,,,,,,,, -65200,1.8357756,1.3933402,,,,,,,,,,,,,, -65300,2.2809057,1.5180013,,,,,,,,,,,,,, -65400,1.9486468,1.4104066,,,,,,,,,,,,,, -65500,2.0523503,1.6377325,,,,,,,,,,,,,, -65503,,,0.7235730290412903,1.0580490827560425,0.6538400053977966,1.4230231046676636,50000.0,0.5282000303268433,2.1661126613616943,10000.0,21985.282883882523,22888.20652484893,21985.282883882523,899.1315841674805,1.4839389324188232,0.0 -65600,1.8503426,1.5697126,,,,,,,,,,,,,, -65700,1.9710796,1.4434369,,,,,,,,,,,,,, -65800,2.1528416,1.6163505,,,,,,,,,,,,,, -65900,2.1200087,1.5428102,,,,,,,,,,,,,, -66000,2.0054672,1.5278689,,,,,,,,,,,,,, -66100,1.9609908,1.6016827,,,,,,,,,,,,,, -66200,1.959627,1.5590296,,,,,,,,,,,,,, -66300,2.1880815,1.6190155,,,,,,,,,,,,,, -66400,2.0329592,1.5515004,,,,,,,,,,,,,, -66500,2.0292442,1.548203,,,,,,,,,,,,,, -66600,2.0449333,1.6313864,,,,,,,,,,,,,, -66700,2.0967913,1.5002077,,,,,,,,,,,,,, -66800,2.079073,1.5860028,,,,,,,,,,,,,, -66900,2.0293117,1.5362308,,,,,,,,,,,,,, -67000,1.7785116,1.4574702,,,,,,,,,,,,,, -67028,,,0.7344347834587097,1.0239489078521729,0.6676200032234192,1.360720157623291,50000.0,0.539900004863739,2.0818004608154297,10000.0,22495.20993900299,23414.981965065,22495.20993900299,915.8876085281372,1.521956205368042,0.0 -67100,1.9988542,1.539715,,,,,,,,,,,,,, -67200,2.011742,1.5409284,,,,,,,,,,,,,, -67300,1.9366788,1.4423668,,,,,,,,,,,,,, -67400,2.228107,1.5281026,,,,,,,,,,,,,, -67500,2.1077523,1.5673405,,,,,,,,,,,,,, -67600,2.147949,1.5131892,,,,,,,,,,,,,, -67700,1.8673264,1.4414244,,,,,,,,,,,,,, -67800,1.838115,1.5230856,,,,,,,,,,,,,, -67900,2.1386642,1.4575791,,,,,,,,,,,,,, -68000,1.9140763,1.5316396,,,,,,,,,,,,,, -68100,1.8199719,1.5349561,,,,,,,,,,,,,, -68200,2.0592208,1.6237884,,,,,,,,,,,,,, -68300,2.1713312,1.464282,,,,,,,,,,,,,, -68400,2.0703185,1.4738076,,,,,,,,,,,,,, -68500,2.3316495,1.5843863,,,,,,,,,,,,,, -68553,,,0.7496412396430969,0.9645988941192628,0.641319990158081,1.464681625366211,50000.0,0.5121000409126282,2.1813247203826904,10000.0,23005.27852010727,23941.77029824257,23005.27852010727,932.5114860534668,1.5625011920928955,0.0 -68600,1.9490427,1.4289165,,,,,,,,,,,,,, -68700,2.155017,1.546927,,,,,,,,,,,,,, -68800,2.093471,1.542616,,,,,,,,,,,,,, -68900,2.073553,1.5994399,,,,,,,,,,,,,, -69000,1.8425748,1.4369272,,,,,,,,,,,,,, -69100,2.0699112,1.5188681,,,,,,,,,,,,,, -69200,2.210272,1.6196136,,,,,,,,,,,,,, -69300,1.9153339,1.4629958,,,,,,,,,,,,,, -69400,1.895715,1.5018873,,,,,,,,,,,,,, -69500,2.1001546,1.4498408,,,,,,,,,,,,,, -69600,2.0762618,1.5198456,,,,,,,,,,,,,, -69700,1.9750488,1.4752737,,,,,,,,,,,,,, -69800,1.9011313,1.513524,,,,,,,,,,,,,, -69900,1.8699973,1.5171789,,,,,,,,,,,,,, -70000,1.896931,1.439841,,,,,,,,,,,,,, -70078,,,0.7494021058082581,0.9496089220046996,0.6649399995803833,1.3731956481933594,50000.0,0.5301000475883484,2.1011109352111816,10000.0,23515.206540346146,24468.674884796143,23515.206540346146,949.3906710147858,1.6044843196868896,0.0 -70100,2.0914717,1.5496676,,,,,,,,,,,,,, -70200,1.8766354,1.4939142,,,,,,,,,,,,,, -70300,2.0780792,1.555737,,,,,,,,,,,,,, -70400,2.2624574,1.5460635,,,,,,,,,,,,,, -70500,1.9805236,1.4465499,,,,,,,,,,,,,, -70600,2.296647,1.5328908,,,,,,,,,,,,,, -70700,2.0035324,1.5721719,,,,,,,,,,,,,, -70800,2.090951,1.5187614,,,,,,,,,,,,,, -70900,2.1172392,1.4917884,,,,,,,,,,,,,, -71000,2.2293155,1.4950578,,,,,,,,,,,,,, -71100,1.9177189,1.4009452,,,,,,,,,,,,,, -71200,1.9937935,1.4840543,,,,,,,,,,,,,, -71300,1.9292755,1.5330416,,,,,,,,,,,,,, -71400,2.0209064,1.4525359,,,,,,,,,,,,,, -71500,1.9079973,1.3906825,,,,,,,,,,,,,, -71600,2.144585,1.5218862,,,,,,,,,,,,,, -71603,,,0.7323421239852905,1.0162076950073242,0.6559000015258789,1.420893669128418,50000.0,0.5294000506401062,2.140475511550904,10000.0,24025.367934703827,24995.66038680077,24025.367934703827,966.1160097122192,1.6491477489471436,0.0 -71700,2.109907,1.5196913,,,,,,,,,,,,,, -71800,1.8816816,1.4717067,,,,,,,,,,,,,, -71900,2.2719092,1.5331537,,,,,,,,,,,,,, -72000,2.1236587,1.474916,,,,,,,,,,,,,, -72100,2.0472353,1.450006,,,,,,,,,,,,,, -72200,2.1613379,1.5297945,,,,,,,,,,,,,, -72300,2.063013,1.4939109,,,,,,,,,,,,,, -72400,1.9991645,1.3563904,,,,,,,,,,,,,, -72500,1.9373772,1.4961677,,,,,,,,,,,,,, -72600,1.9199706,1.4047501,,,,,,,,,,,,,, -72700,1.9465843,1.5199965,,,,,,,,,,,,,, -72800,2.2591488,1.4814042,,,,,,,,,,,,,, -72900,2.0127063,1.4099585,,,,,,,,,,,,,, -73000,1.9465901,1.4208198,,,,,,,,,,,,,, -73100,1.9324073,1.4784395,,,,,,,,,,,,,, -73128,,,0.7429647445678711,0.9706432223320008,0.6710000038146973,1.3437200784683228,50000.0,0.5420000553131104,2.0839552879333496,10000.0,24535.35845017433,25522.33491587639,24535.35845017433,982.704963684082,1.689319372177124,0.0 -73200,1.7314714,1.4483413,,,,,,,,,,,,,, -73300,1.8795605,1.4055064,,,,,,,,,,,,,, -73400,2.1116872,1.4754865,,,,,,,,,,,,,, -73500,1.9969584,1.503699,,,,,,,,,,,,,, -73600,2.044153,1.4331284,,,,,,,,,,,,,, -73700,2.0795767,1.5700582,,,,,,,,,,,,,, -73800,1.9233108,1.4899404,,,,,,,,,,,,,, -73900,2.1549256,1.4583104,,,,,,,,,,,,,, -74000,1.9169713,1.4510114,,,,,,,,,,,,,, -74100,1.9955195,1.5448477,,,,,,,,,,,,,, -74200,2.1632652,1.5580784,,,,,,,,,,,,,, -74300,2.017018,1.3399734,,,,,,,,,,,,,, -74400,2.1418724,1.5296408,,,,,,,,,,,,,, -74500,2.0852797,1.4731026,,,,,,,,,,,,,, -74600,2.3431602,1.4477283,,,,,,,,,,,,,, -74653,,,0.7179328799247742,1.0882242918014526,0.6489799618721008,1.4370622634887695,50000.0,0.51500004529953,2.1855597496032715,10000.0,25045.55570292473,26049.26572537422,25045.55570292473,999.3370826244354,1.7352948188781738,0.0 -74700,1.8567746,1.4124451,,,,,,,,,,,,,, -74800,2.3930354,1.4239732,,,,,,,,,,,,,, -74900,2.0468347,1.4665207,,,,,,,,,,,,,, -75000,2.263626,1.4177158,,,,,,,,,,,,,, -75100,1.9114888,1.4661183,,,,,,,,,,,,,, -75200,2.012335,1.4951868,,,,,,,,,,,,,, -75300,2.2181292,1.5444703,,,,,,,,,,,,,, -75400,2.1354053,1.4850371,,,,,,,,,,,,,, -75500,2.151162,1.4792283,,,,,,,,,,,,,, -75600,2.0670447,1.4992598,,,,,,,,,,,,,, -75700,2.1968622,1.5465304,,,,,,,,,,,,,, -75800,1.9728222,1.445534,,,,,,,,,,,,,, -75900,2.1751714,1.548703,,,,,,,,,,,,,, -76000,2.093183,1.5988975,,,,,,,,,,,,,, -76100,2.0284774,1.4074305,,,,,,,,,,,,,, -76178,,,0.73246169090271,1.026774287223816,0.6608399748802185,1.3760240077972412,50000.0,0.5300000309944153,2.1334023475646973,10000.0,25555.715767622,26576.88431286812,25555.715767622,1016.7017018795012,1.7748703956604004,0.0 -76200,2.204152,1.4871022,,,,,,,,,,,,,, -76300,2.1373868,1.4752009,,,,,,,,,,,,,, -76400,1.9155027,1.3469715,,,,,,,,,,,,,, -76500,2.1010754,1.4572345,,,,,,,,,,,,,, -76600,2.0453877,1.511772,,,,,,,,,,,,,, -76700,2.1199965,1.4344391,,,,,,,,,,,,,, -76800,1.9740453,1.487668,,,,,,,,,,,,,, -76900,2.1512291,1.5079726,,,,,,,,,,,,,, -77000,1.9858681,1.4601206,,,,,,,,,,,,,, -77100,2.1618936,1.4893458,,,,,,,,,,,,,, -77200,2.1411037,1.4927835,,,,,,,,,,,,,, -77300,2.0604491,1.4524133,,,,,,,,,,,,,, -77400,2.141511,1.3833607,,,,,,,,,,,,,, -77500,2.0990832,1.4622331,,,,,,,,,,,,,, -77600,2.116823,1.5332842,,,,,,,,,,,,,, -77700,2.1275158,1.4155793,,,,,,,,,,,,,, -77702,,,0.7727798223495483,0.8478021025657654,0.670199990272522,1.3518798351287842,50000.0,0.5368000268936157,2.1114277839660645,10000.0,26065.67103695869,27103.81628870964,26065.67103695869,1033.5776641368866,1.8206498622894287,0.0 -77800,1.9211094,1.4702336,,,,,,,,,,,,,, -77900,2.1556406,1.5290042,,,,,,,,,,,,,, -78000,2.1651618,1.4791834,,,,,,,,,,,,,, -78100,1.9554183,1.3941648,,,,,,,,,,,,,, -78200,2.0998375,1.5550768,,,,,,,,,,,,,, -78300,2.1928248,1.4190367,,,,,,,,,,,,,, -78400,2.1174843,1.3992155,,,,,,,,,,,,,, -78500,2.129953,1.4815915,,,,,,,,,,,,,, -78600,2.2258096,1.4091288,,,,,,,,,,,,,, -78700,1.9520012,1.5354702,,,,,,,,,,,,,, -78800,2.2753477,1.3976706,,,,,,,,,,,,,, -78900,2.0621583,1.3702369,,,,,,,,,,,,,, -79000,2.3217614,1.5709686,,,,,,,,,,,,,, -79100,2.0569334,1.4955972,,,,,,,,,,,,,, -79200,2.11295,1.40023,,,,,,,,,,,,,, -79227,,,0.7599848508834839,0.9066538214683532,0.6744799613952637,1.3212937116622925,50000.0,0.5470000505447388,2.039279699325561,10000.0,26575.838837385178,27630.83352947235,26575.838837385178,1050.3322749137878,1.860451936721801,0.0 -79300,2.5154774,1.5013119,,,,,,,,,,,,,, -79400,2.1884053,1.4995948,,,,,,,,,,,,,, -79500,1.9541079,1.3986077,,,,,,,,,,,,,, -79600,2.1038506,1.4298239,,,,,,,,,,,,,, -79700,2.0992568,1.4827081,,,,,,,,,,,,,, -79800,2.1629617,1.4643197,,,,,,,,,,,,,, -79900,1.7999164,1.3440969,,,,,,,,,,,,,, -80000,2.513599,1.535686,,,,,,,,,,,,,, -80100,2.190663,1.4915493,,,,,,,,,,,,,, -80200,2.2019293,1.4707357,,,,,,,,,,,,,, -80300,2.1093981,1.548765,,,,,,,,,,,,,, -80400,2.2682683,1.5180708,,,,,,,,,,,,,, -80500,2.2206852,1.5388815,,,,,,,,,,,,,, -80600,2.032493,1.4095707,,,,,,,,,,,,,, -80700,1.8879118,1.362522,,,,,,,,,,,,,, -80751,,,0.7531289458274841,0.9358610510826112,0.6725999712944031,1.338200569152832,50000.0,0.5384000539779663,2.07574200630188,10000.0,27085.72404384613,28157.477063655853,27085.72404384613,1066.9713323116302,1.9235823154449463,0.0 -80800,2.281141,1.402973,,,,,,,,,,,,,, -80900,2.1594014,1.5290349,,,,,,,,,,,,,, -81000,2.0279086,1.5216166,,,,,,,,,,,,,, -81100,2.1867242,1.5060884,,,,,,,,,,,,,, -81200,2.1417716,1.3972899,,,,,,,,,,,,,, -81300,2.1075869,1.4243823,,,,,,,,,,,,,, -81400,2.1841702,1.4730892,,,,,,,,,,,,,, -81500,2.1539614,1.5033412,,,,,,,,,,,,,, -81600,2.022577,1.2393157,,,,,,,,,,,,,, -81700,2.0223434,1.4270881,,,,,,,,,,,,,, -81800,2.2159092,1.4315695,,,,,,,,,,,,,, -81900,2.0590353,1.533319,,,,,,,,,,,,,, -82000,2.1340132,1.4857676,,,,,,,,,,,,,, -82100,2.0889847,1.4632574,,,,,,,,,,,,,, -82200,2.1299372,1.4561404,,,,,,,,,,,,,, -82276,,,0.7472097873687744,0.9645992517471312,0.6704999804496765,1.3342227935791016,50000.0,0.5415000319480896,2.055564641952514,10000.0,27595.813599586487,28684.43613266945,27595.813599586487,1083.7439770698547,1.9646568298339844,0.0 -82300,2.031586,1.3654256,,,,,,,,,,,,,, -82400,2.2028227,1.4703007,,,,,,,,,,,,,, -82500,2.2111478,1.4683349,,,,,,,,,,,,,, -82600,2.01344,1.457124,,,,,,,,,,,,,, -82700,2.1993744,1.4272617,,,,,,,,,,,,,, -82800,1.9596225,1.4450606,,,,,,,,,,,,,, -82900,2.0539744,1.4604607,,,,,,,,,,,,,, -83000,1.9504857,1.352099,,,,,,,,,,,,,, -83100,2.1831887,1.4031638,,,,,,,,,,,,,, -83200,2.3662372,1.5149667,,,,,,,,,,,,,, -83300,1.9845284,1.4284165,,,,,,,,,,,,,, -83400,2.339251,1.3520012,,,,,,,,,,,,,, -83500,2.2424898,1.417207,,,,,,,,,,,,,, -83600,2.0616658,1.2950201,,,,,,,,,,,,,, -83700,2.4614234,1.557484,,,,,,,,,,,,,, -83800,2.1523228,1.5061938,,,,,,,,,,,,,, -83801,,,0.7385004758834839,0.9893120527267456,0.664900004863739,1.3621970415115356,50000.0,0.5383000373840332,2.107455253601074,10000.0,28105.84876608849,29211.3304066658,28105.84876608849,1100.5029304027555,2.0091564655303955,0.0 -83900,2.2429824,1.49137,,,,,,,,,,,,,, -84000,2.2731695,1.4615638,,,,,,,,,,,,,, -84100,2.14951,1.4324597,,,,,,,,,,,,,, -84200,2.3267233,1.5504708,,,,,,,,,,,,,, -84300,2.0825796,1.5255835,,,,,,,,,,,,,, -84400,2.0919228,1.4555119,,,,,,,,,,,,,, -84500,2.1421432,1.3760662,,,,,,,,,,,,,, -84600,2.1997426,1.3805718,,,,,,,,,,,,,, -84700,2.1985748,1.4022965,,,,,,,,,,,,,, -84800,2.253579,1.4651283,,,,,,,,,,,,,, -84900,2.1008525,1.4843824,,,,,,,,,,,,,, -85000,2.6415448,1.403314,,,,,,,,,,,,,, -85100,2.1420016,1.429649,,,,,,,,,,,,,, -85200,2.337393,1.3535082,,,,,,,,,,,,,, -85300,2.4097505,1.4478129,,,,,,,,,,,,,, -85327,,,0.7606425285339355,0.9143653512001038,0.6753199696540833,1.3267306089401243,50000.0,0.5385000109672546,2.0425949096679688,10000.0,28616.067785024643,29738.30677652359,28616.067785024643,1117.165348291397,2.049769401550293,0.0 -85400,2.2002747,1.3166587,,,,,,,,,,,,,, -85500,2.082619,1.3963573,,,,,,,,,,,,,, -85600,2.0447056,1.3870252,,,,,,,,,,,,,, -85700,2.1291053,1.4611192,,,,,,,,,,,,,, -85800,2.3215437,1.4023846,,,,,,,,,,,,,, -85900,2.2905025,1.3967437,,,,,,,,,,,,,, -86000,2.1794016,1.5122877,,,,,,,,,,,,,, -86100,2.1701248,1.4277327,,,,,,,,,,,,,, -86200,2.2154984,1.4575878,,,,,,,,,,,,,, -86300,2.309963,1.4286076,,,,,,,,,,,,,, -86400,2.148583,1.3793283,,,,,,,,,,,,,, -86500,2.4725885,1.4326149,,,,,,,,,,,,,, -86600,2.442551,1.4370184,,,,,,,,,,,,,, -86700,2.2637987,1.3349434,,,,,,,,,,,,,, -86800,2.1126983,1.4282305,,,,,,,,,,,,,, -86852,,,0.7731385231018066,0.8398574590682983,0.6742599606513977,1.3247716426849363,50000.0,0.5397000312805176,2.076277732849121,10000.0,29126.03837108612,30265.02364993096,29126.03837108612,1133.8120720386505,2.0952630043029785,0.0 -86900,2.548082,1.528818,,,,,,,,,,,,,, -87000,2.1281786,1.4618555,,,,,,,,,,,,,, -87100,2.2543864,1.4689426,,,,,,,,,,,,,, -87200,2.1637995,1.3981665,,,,,,,,,,,,,, -87300,2.2128544,1.4021708,,,,,,,,,,,,,, -87400,2.3514154,1.413546,,,,,,,,,,,,,, -87500,2.273911,1.4705199,,,,,,,,,,,,,, -87600,2.276034,1.4998665,,,,,,,,,,,,,, -87700,2.0599036,1.3835275,,,,,,,,,,,,,, -87800,2.2365675,1.4511912,,,,,,,,,,,,,, -87900,2.3527408,1.408362,,,,,,,,,,,,,, -88000,2.2261662,1.4726204,,,,,,,,,,,,,, -88100,2.0552175,1.3818309,,,,,,,,,,,,,, -88200,2.3943167,1.4804327,,,,,,,,,,,,,, -88300,2.1385098,1.3528497,,,,,,,,,,,,,, -88376,,,0.7651267647743225,0.8875182867050171,0.6764000058174133,1.306697130203247,50000.0,0.5455000400543213,2.0423593521118164,10000.0,29636.034289360046,30791.60491251945,29636.034289360046,1150.2982242107391,2.140028476715088,0.0 -88400,2.299866,1.4193251,,,,,,,,,,,,,, -88500,2.19842,1.4158081,,,,,,,,,,,,,, -88600,2.2972045,1.4962718,,,,,,,,,,,,,, -88700,2.1105654,1.3101883,,,,,,,,,,,,,, -88800,2.2809544,1.512213,,,,,,,,,,,,,, -88900,2.428346,1.3312938,,,,,,,,,,,,,, -89000,2.1989248,1.3878164,,,,,,,,,,,,,, -89100,2.263841,1.4811623,,,,,,,,,,,,,, -89200,2.3044798,1.4611837,,,,,,,,,,,,,, -89300,2.3095787,1.3959981,,,,,,,,,,,,,, -89400,2.3734696,1.4732972,,,,,,,,,,,,,, -89500,2.2674527,1.4222864,,,,,,,,,,,,,, -89600,2.2183805,1.4074289,,,,,,,,,,,,,, -89700,2.2594361,1.4726021,,,,,,,,,,,,,, -89800,2.2481377,1.4210869,,,,,,,,,,,,,, -89900,,,0.7606425285339355,0.8991544842720032,0.6783599853515625,1.304556965827942,50000.0,0.5515000224113464,2.004889488220215,10000.0,30145.949870347977,31318.47411704064,30145.949870347977,1167.15305352211,2.1838526725769043,0.0 -89900,2.2449496,1.3925874,,,,,,,,,,,,,, -90000,2.311595,1.4001493,,,,,,,,,,,,,, -90100,2.2478347,1.5015419,,,,,,,,,,,,,, -90200,2.3302953,1.4083171,,,,,,,,,,,,,, -90300,2.4321063,1.3676862,,,,,,,,,,,,,, -90400,2.1830542,1.4301977,,,,,,,,,,,,,, -90500,2.352191,1.4501063,,,,,,,,,,,,,, -90600,2.377161,1.397711,,,,,,,,,,,,,, -90700,2.2600527,1.4198422,,,,,,,,,,,,,, -90800,2.2483685,1.3835186,,,,,,,,,,,,,, -90900,2.1971154,1.3842492,,,,,,,,,,,,,, -91000,2.5846946,1.4061592,,,,,,,,,,,,,, -91100,2.1494875,1.4459039,,,,,,,,,,,,,, -91200,2.1761477,1.3693643,,,,,,,,,,,,,, -91300,2.2478232,1.3542885,,,,,,,,,,,,,, -91400,2.2749236,1.4131154,,,,,,,,,,,,,, -91424,,,0.7586694955825806,0.9094310998916626,0.6788399815559387,1.301895022392273,50000.0,0.5540000200271606,2.0211398601531982,10000.0,30655.910806179047,31845.72444677353,30655.910806179047,1184.3367433547974,2.234565019607544,0.0 -91500,2.312146,1.3398964,,,,,,,,,,,,,, -91600,2.3483772,1.4577085,,,,,,,,,,,,,, -91700,2.341708,1.5337839,,,,,,,,,,,,,, -91800,2.2089972,1.3094625,,,,,,,,,,,,,, -91900,2.2437801,1.3901811,,,,,,,,,,,,,, -92000,2.4799073,1.4376097,,,,,,,,,,,,,, -92100,2.361995,1.3641241,,,,,,,,,,,,,, -92200,2.5013726,1.3942909,,,,,,,,,,,,,, -92300,2.423444,1.5281024,,,,,,,,,,,,,, -92400,2.2043443,1.3249284,,,,,,,,,,,,,, -92500,2.4420795,1.3460845,,,,,,,,,,,,,, -92600,2.068878,1.334197,,,,,,,,,,,,,, -92700,2.6655889,1.3418038,,,,,,,,,,,,,, -92800,2.3717368,1.3830087,,,,,,,,,,,,,, -92900,2.6854033,1.3069105,,,,,,,,,,,,,, -92948,,,0.7504783272743225,0.9368448853492736,0.6730200052261353,1.32490336894989,50000.0,0.5422000288963318,2.0770034790039062,10000.0,31165.815375089645,32372.397426128387,31165.815375089645,1201.0060527324677,2.280010223388672,0.0 -93000,2.4722342,1.3434154,,,,,,,,,,,,,, -93100,2.3891206,1.3203554,,,,,,,,,,,,,, -93200,2.3066862,1.3493521,,,,,,,,,,,,,, -93300,2.0837674,1.2391326,,,,,,,,,,,,,, -93400,2.517635,1.3298414,,,,,,,,,,,,,, -93500,2.38598,1.4433403,,,,,,,,,,,,,, -93600,2.354366,1.3955235,,,,,,,,,,,,,, -93700,2.3756711,1.3201591,,,,,,,,,,,,,, -93800,2.4602454,1.4029067,,,,,,,,,,,,,, -93900,2.5184577,1.3013905,,,,,,,,,,,,,, -94000,2.462025,1.4638152,,,,,,,,,,,,,, -94100,2.3242705,1.4036238,,,,,,,,,,,,,, -94200,2.435095,1.4074914,,,,,,,,,,,,,, -94300,2.5612652,1.4144008,,,,,,,,,,,,,, -94400,2.2507253,1.3296398,,,,,,,,,,,,,, -94472,,,0.802754282951355,0.7244368195533752,0.6849600076675415,1.2787599563598633,50000.0,0.5512000322341919,2.0316712856292725,10000.0,31675.780433416367,32899.11539578438,31675.780433416367,1217.6655213832855,2.319309711456299,0.0 -94500,2.4200635,1.3437274,,,,,,,,,,,,,, -94600,2.2726429,1.2856878,,,,,,,,,,,,,, -94700,2.4628398,1.4063611,,,,,,,,,,,,,, -94800,2.2779093,1.3577545,,,,,,,,,,,,,, -94900,2.3232734,1.4598159,,,,,,,,,,,,,, -95000,2.4237578,1.3841416,,,,,,,,,,,,,, -95100,2.2750406,1.3838385,,,,,,,,,,,,,, -95200,2.3671694,1.3805693,,,,,,,,,,,,,, -95300,2.7183475,1.4379072,,,,,,,,,,,,,, -95400,2.216843,1.3535317,,,,,,,,,,,,,, -95500,2.2617176,1.5149819,,,,,,,,,,,,,, -95600,2.6311538,1.3702587,,,,,,,,,,,,,, -95700,2.4136612,1.359628,,,,,,,,,,,,,, -95800,2.6226203,1.4265516,,,,,,,,,,,,,, -95900,2.3033867,1.2952735,,,,,,,,,,,,,, -95996,,,0.7570551633834839,0.9185478091239928,0.6616399884223938,1.392592191696167,50000.0,0.530500054359436,2.1422789096832275,10000.0,32186.006425857544,33426.04164767265,32186.006425857544,1234.2645723819733,2.3646020889282227,0.0 -96000,2.7047408,1.3603883,,,,,,,,,,,,,, -96100,2.358571,1.3503884,,,,,,,,,,,,,, -96200,2.576259,1.3199008,,,,,,,,,,,,,, -96300,2.6152573,1.3758805,,,,,,,,,,,,,, -96400,2.2972753,1.3964674,,,,,,,,,,,,,, -96500,2.7380314,1.3756992,,,,,,,,,,,,,, -96600,2.2254422,1.2932692,,,,,,,,,,,,,, -96700,2.4240453,1.4251617,,,,,,,,,,,,,, -96800,2.2784746,1.3496605,,,,,,,,,,,,,, -96900,2.5033925,1.331973,,,,,,,,,,,,,, -97000,2.2359798,1.3509306,,,,,,,,,,,,,, -97100,2.3849375,1.3281251,,,,,,,,,,,,,, -97200,2.618553,1.455105,,,,,,,,,,,,,, -97300,2.4980195,1.3357538,,,,,,,,,,,,,, -97400,2.5637732,1.3661911,,,,,,,,,,,,,, -97500,2.5133734,1.4105042,,,,,,,,,,,,,, -97520,,,0.7798349857330322,0.8255358338356018,0.6879799962043762,1.2594505548477173,50000.0,0.5635000467300415,1.9730379581451416,10000.0,32696.02758693695,33953.034247636795,32696.02758693695,1251.119171857834,2.4251294136047363,0.0 -97600,2.346483,1.414444,,,,,,,,,,,,,, -97700,2.4502683,1.3513114,,,,,,,,,,,,,, -97800,2.3091733,1.3031924,,,,,,,,,,,,,, -97900,2.296368,1.343151,,,,,,,,,,,,,, -98000,2.4529002,1.2341222,,,,,,,,,,,,,, -98100,2.8364694,1.4397669,,,,,,,,,,,,,, -98200,2.4664567,1.4074563,,,,,,,,,,,,,, -98300,2.4557137,1.3893071,,,,,,,,,,,,,, -98400,2.317491,1.385971,,,,,,,,,,,,,, -98500,2.4172819,1.3328289,,,,,,,,,,,,,, -98600,2.515286,1.3644087,,,,,,,,,,,,,, -98700,2.5601609,1.3508264,,,,,,,,,,,,,, -98800,2.3892362,1.3929626,,,,,,,,,,,,,, -98900,2.3779528,1.3833848,,,,,,,,,,,,,, -99000,2.6764045,1.4238328,,,,,,,,,,,,,, -99045,,,0.7669602632522583,0.872205913066864,0.6832599639892578,1.2858682870864868,50000.0,0.5490000247955322,2.031160354614258,10000.0,33206.21092581749,34479.94657087326,33206.21092581749,1267.745992422104,2.4731388092041016,0.0 -99100,2.61833,1.3520105,,,,,,,,,,,,,, -99200,2.519807,1.2963398,,,,,,,,,,,,,, -99300,2.485322,1.3535779,,,,,,,,,,,,,, -99400,2.3586352,1.3814512,,,,,,,,,,,,,, -99500,2.4688566,1.4105747,,,,,,,,,,,,,, -99600,2.2079287,1.3073386,,,,,,,,,,,,,, -99700,2.4182434,1.265982,,,,,,,,,,,,,, -99800,2.749387,1.3226135,,,,,,,,,,,,,, -99900,2.7476025,1.38004,,,,,,,,,,,,,, -100000,2.4549944,1.3769499,,,,,,,,,,,,,, -100100,2.2627783,1.2514396,,,,,,,,,,,,,, -100200,2.2892532,1.2953885,,,,,,,,,,,,,, -100300,2.524049,1.3142886,,,,,,,,,,,,,, -100400,2.4633052,1.4331325,,,,,,,,,,,,,, -100500,2.4524274,1.3009378,,,,,,,,,,,,,, -100569,,,0.7652662396430969,0.8769233822822571,0.6861599683761597,1.2804961204528809,50000.0,0.5618000030517578,2.010899305343628,10000.0,33716.131663799286,35006.567068099976,33716.131663799286,1284.3461983203888,2.517145156860352,0.0 -100600,2.4841895,1.3138155,,,,,,,,,,,,,, -100700,2.4204538,1.2933319,,,,,,,,,,,,,, -100800,2.6196244,1.288703,,,,,,,,,,,,,, -100900,2.752721,1.3528781,,,,,,,,,,,,,, -101000,2.683876,1.4478419,,,,,,,,,,,,,, -101100,2.8173563,1.3418559,,,,,,,,,,,,,, -101200,2.4358287,1.2462379,,,,,,,,,,,,,, -101300,2.3714192,1.3562565,,,,,,,,,,,,,, -101400,2.6164818,1.3838089,,,,,,,,,,,,,, -101500,2.7359211,1.4023461,,,,,,,,,,,,,, -101600,2.6411917,1.3513157,,,,,,,,,,,,,, -101700,2.5488853,1.3398244,,,,,,,,,,,,,, -101800,2.562471,1.3926706,,,,,,,,,,,,,, -101900,2.5121338,1.4006224,,,,,,,,,,,,,, -102000,2.2102978,1.3606002,,,,,,,,,,,,,, -102094,,,0.7729990482330322,0.8372607827186584,0.6893799901008606,1.2580443620681765,50000.0,0.5538000464439392,1.9944852590560915,10000.0,34226.252163648605,35533.50708389282,34226.252163648605,1301.061465740204,2.5662214756011963,0.0 -102100,2.5505862,1.3086135,,,,,,,,,,,,,, -102200,2.425068,1.3617101,,,,,,,,,,,,,, -102300,2.5437956,1.256139,,,,,,,,,,,,,, -102400,2.3888893,1.2744524,,,,,,,,,,,,,, -102500,2.438934,1.1777999,,,,,,,,,,,,,, -102600,2.3313148,1.3340805,,,,,,,,,,,,,, -102700,2.6642919,1.4836125,,,,,,,,,,,,,, -102800,2.575542,1.4238127,,,,,,,,,,,,,, -102900,2.4174035,1.4432592,,,,,,,,,,,,,, -103000,2.6898217,1.3624009,,,,,,,,,,,,,, -103100,2.694439,1.3795738,,,,,,,,,,,,,, -103200,2.5082793,1.3130101,,,,,,,,,,,,,, -103300,2.429562,1.306828,,,,,,,,,,,,,, -103400,2.5096674,1.3424656,,,,,,,,,,,,,, -103500,2.5709512,1.2640568,,,,,,,,,,,,,, -103600,2.7186987,1.3542066,,,,,,,,,,,,,, -103619,,,0.7914739847183228,0.7607530951499939,0.6838200092315674,1.2859368324279783,50000.0,0.5582000017166138,2.011260747909546,10000.0,34736.3680062294,36060.34365081787,34736.3680062294,1317.6769466400146,2.6166603565216064,0.0 -103700,2.5587282,1.3032212,,,,,,,,,,,,,, -103800,2.6319206,1.4552841,,,,,,,,,,,,,, -103900,2.6749353,1.302749,,,,,,,,,,,,,, -104000,2.6182768,1.3359923,,,,,,,,,,,,,, -104100,2.728622,1.3604037,,,,,,,,,,,,,, -104200,2.4837189,1.2409371,,,,,,,,,,,,,, -104300,2.601109,1.3262615,,,,,,,,,,,,,, -104400,2.605202,1.4392372,,,,,,,,,,,,,, -104500,2.6302984,1.4289497,,,,,,,,,,,,,, -104600,2.4367762,1.3187187,,,,,,,,,,,,,, -104700,2.4699037,1.3052597,,,,,,,,,,,,,, -104800,2.6263783,1.2903488,,,,,,,,,,,,,, -104900,2.695023,1.2815942,,,,,,,,,,,,,, -105000,2.5013661,1.3256421,,,,,,,,,,,,,, -105100,2.4221542,1.3006648,,,,,,,,,,,,,, -105144,,,0.7905572056770325,0.7790660858154297,0.6904199719429016,1.2633452415466309,50000.0,0.5649999976158142,1.953182339668274,10000.0,35246.5287566185,36587.31213951111,35246.5287566185,1334.381745815277,2.665559768676758,0.0 -105200,2.5473838,1.3006301,,,,,,,,,,,,,, -105300,2.6752658,1.2374706,,,,,,,,,,,,,, -105400,2.9362345,1.2840378,,,,,,,,,,,,,, -105500,2.5376303,1.246581,,,,,,,,,,,,,, -105600,2.435149,1.3685052,,,,,,,,,,,,,, -105700,2.4672656,1.291419,,,,,,,,,,,,,, -105800,2.551831,1.3166384,,,,,,,,,,,,,, -105900,2.4357693,1.2644677,,,,,,,,,,,,,, -106000,2.6078055,1.2648464,,,,,,,,,,,,,, -106100,2.438701,1.3796965,,,,,,,,,,,,,, -106200,2.4682083,1.2849143,,,,,,,,,,,,,, -106300,2.5840182,1.2349449,,,,,,,,,,,,,, -106400,2.8174443,1.2812068,,,,,,,,,,,,,, -106500,2.4887877,1.1893167,,,,,,,,,,,,,, -106600,2.4825375,1.2836106,,,,,,,,,,,,,, -106669,,,0.7826650142669678,0.8016412258148193,0.6915199756622314,1.2654014825820925,50000.0,0.5599000453948975,2.0167007446289062,10000.0,35756.58667945862,37114.21175909042,35756.58667945862,1351.1192100048063,2.7153921127319336,0.0 -106700,2.668068,1.2887329,,,,,,,,,,,,,, -106800,2.5799346,1.185777,,,,,,,,,,,,,, -106900,2.673964,1.2757826,,,,,,,,,,,,,, -107000,2.524468,1.2957357,,,,,,,,,,,,,, -107100,2.6932156,1.3846506,,,,,,,,,,,,,, -107200,2.9688737,1.3409188,,,,,,,,,,,,,, -107300,2.7658238,1.4112191,,,,,,,,,,,,,, -107400,2.8124638,1.2168642,,,,,,,,,,,,,, -107500,2.5463383,1.3424768,,,,,,,,,,,,,, -107600,2.3050482,1.2177767,,,,,,,,,,,,,, -107700,2.9874177,1.330934,,,,,,,,,,,,,, -107800,2.701813,1.3081442,,,,,,,,,,,,,, -107900,2.7267842,1.2408016,,,,,,,,,,,,,, -108000,2.668957,1.3286802,,,,,,,,,,,,,, -108100,2.5338206,1.3673658,,,,,,,,,,,,,, -108194,,,0.7798748016357422,0.8180897831916809,0.6882199645042419,1.2640862464904783,50000.0,0.5564000010490417,2.0124053955078125,10000.0,36266.73639631271,37641.12651062012,36266.73639631271,1367.7784917354584,2.7660505771636963,0.0 -108200,2.6148417,1.3619907,,,,,,,,,,,,,, -108300,2.5596712,1.3095484,,,,,,,,,,,,,, -108400,2.4738443,1.2525473,,,,,,,,,,,,,, -108500,2.61169,1.2302923,,,,,,,,,,,,,, -108600,2.675576,1.3267581,,,,,,,,,,,,,, -108700,2.4563105,1.2345015,,,,,,,,,,,,,, -108800,2.7809176,1.358244,,,,,,,,,,,,,, -108900,2.6739705,1.2798777,,,,,,,,,,,,,, -109000,2.6591792,1.223604,,,,,,,,,,,,,, -109100,2.7282693,1.3477322,,,,,,,,,,,,,, -109200,2.641854,1.1666951,,,,,,,,,,,,,, -109300,2.8823864,1.2669774,,,,,,,,,,,,,, -109400,2.64594,1.2819252,,,,,,,,,,,,,, -109500,2.7055895,1.2507172,,,,,,,,,,,,,, -109600,2.6869495,1.3121812,,,,,,,,,,,,,, -109700,2.8237288,1.2582495,,,,,,,,,,,,,, -109719,,,0.786531388759613,0.7900282740592957,0.6940799951553345,1.2394052743911743,50000.0,0.5622000098228455,1.9760229587554927,10000.0,36776.9069879055,38167.94488573074,36776.9069879055,1384.3214082717896,2.8171417713165283,0.0 -109800,2.781433,1.2774408,,,,,,,,,,,,,, -109900,2.7188044,1.3236837,,,,,,,,,,,,,, -110000,2.631438,1.2895348,,,,,,,,,,,,,, -110100,2.5025914,1.1800151,,,,,,,,,,,,,, -110200,2.7550216,1.2320044,,,,,,,,,,,,,, -110300,2.617338,1.2684889,,,,,,,,,,,,,, -110400,2.6203365,1.2911401,,,,,,,,,,,,,, -110500,2.7107658,1.2590854,,,,,,,,,,,,,, -110600,2.5797598,1.2719116,,,,,,,,,,,,,, -110700,2.8122487,1.2689159,,,,,,,,,,,,,, -110800,2.730555,1.2492831,,,,,,,,,,,,,, -110900,2.621419,1.2495747,,,,,,,,,,,,,, -111000,2.876,1.3184258,,,,,,,,,,,,,, -111100,2.956787,1.3575069,,,,,,,,,,,,,, -111200,2.6922536,1.1969398,,,,,,,,,,,,,, -111244,,,0.8427534699440002,0.5710492134094238,0.6987000107765198,1.2164056301116943,50000.0,0.5749000310897827,1.9428856372833248,10000.0,37287.03126382828,38695.02125692368,37287.03126382828,1401.1726398468018,2.862706422805786,0.0 -111300,2.6157863,1.123354,,,,,,,,,,,,,, -111400,2.6426544,1.2774972,,,,,,,,,,,,,, -111500,2.9142368,1.2952013,,,,,,,,,,,,,, -111600,2.7240136,1.3447598,,,,,,,,,,,,,, -111700,2.6396663,1.1783608,,,,,,,,,,,,,, -111800,2.6884618,1.2690256,,,,,,,,,,,,,, -111900,2.5544298,1.2478004,,,,,,,,,,,,,, -112000,3.3845863,1.2837946,,,,,,,,,,,,,, -112100,2.4917297,1.291471,,,,,,,,,,,,,, -112200,2.5391953,1.2451868,,,,,,,,,,,,,, -112300,2.9843106,1.2655032,,,,,,,,,,,,,, -112400,2.8884523,1.3226111,,,,,,,,,,,,,, -112500,2.7337606,1.2775283,,,,,,,,,,,,,, -112600,2.6557748,1.2334445,,,,,,,,,,,,,, -112700,2.7742407,1.2270073,,,,,,,,,,,,,, -112768,,,0.812898576259613,0.6848605871200562,0.698419988155365,1.2241036891937256,50000.0,0.5710000395774841,1.9627013206481927,10000.0,37797.02880716324,39221.77711343765,37797.02880716324,1417.827962398529,2.9112823009490967,0.0 -112800,2.9165716,1.201056,,,,,,,,,,,,,, -112900,2.7433949,1.289719,,,,,,,,,,,,,, -113000,2.696533,1.1892407,,,,,,,,,,,,,, -113100,2.788737,1.229691,,,,,,,,,,,,,, -113200,2.6861064,1.2559832,,,,,,,,,,,,,, -113300,2.6506243,1.2407812,,,,,,,,,,,,,, -113400,2.8333247,1.2746942,,,,,,,,,,,,,, -113500,2.855918,1.3042401,,,,,,,,,,,,,, -113600,3.1761506,1.2566193,,,,,,,,,,,,,, -113700,2.6678717,1.274591,,,,,,,,,,,,,, -113800,2.7472808,1.1878551,,,,,,,,,,,,,, -113900,2.8214967,1.2381284,,,,,,,,,,,,,, -114000,2.5644364,1.2778184,,,,,,,,,,,,,, -114100,2.773311,1.2205384,,,,,,,,,,,,,, -114200,2.7927957,1.2022212,,,,,,,,,,,,,, -114293,,,0.8019969463348389,0.7204362154006958,0.7006999850273132,1.2203973531723022,50000.0,0.5749000310897827,1.921513319015503,10000.0,38307.227852106094,39748.62899613381,38307.227852106094,1434.3773369789124,2.959364175796509,0.0 -114300,2.7219584,1.1743828,,,,,,,,,,,,,, -114400,2.8933575,1.306354,,,,,,,,,,,,,, -114500,3.0216036,1.2659574,,,,,,,,,,,,,, -114600,2.7513258,1.2871771,,,,,,,,,,,,,, -114700,2.6805751,1.2087147,,,,,,,,,,,,,, -114800,2.8394537,1.2276565,,,,,,,,,,,,,, -114900,2.7900352,1.2263339,,,,,,,,,,,,,, -115000,2.6537921,1.2874446,,,,,,,,,,,,,, -115100,2.8226652,1.2455535,,,,,,,,,,,,,, -115200,2.6061153,1.1432514,,,,,,,,,,,,,, -115300,2.6517985,1.2310095,,,,,,,,,,,,,, -115400,2.8800776,1.2191848,,,,,,,,,,,,,, -115500,2.8685005,1.2531129,,,,,,,,,,,,,, -115600,2.9794648,1.1289859,,,,,,,,,,,,,, -115700,2.9416761,1.2563539,,,,,,,,,,,,,, -115800,3.136192,1.2894725,,,,,,,,,,,,,, -115818,,,0.8029735088348389,0.7088333964347839,0.7026199698448181,1.2160744667053225,50000.0,0.5778000354766846,1.939614176750183,10000.0,38817.41033864021,40275.49406194687,38817.41033864021,1450.955587387085,3.0082755088806152,0.0 -115900,2.8889494,1.2397909,,,,,,,,,,,,,, -116000,2.778295,1.238367,,,,,,,,,,,,,, -116100,2.66477,1.0860256,,,,,,,,,,,,,, -116200,2.807542,1.2803003,,,,,,,,,,,,,, -116300,2.819132,1.1936238,,,,,,,,,,,,,, -116400,2.772446,1.2225995,,,,,,,,,,,,,, -116500,2.9520063,1.2726424,,,,,,,,,,,,,, -116600,2.825274,1.1981246,,,,,,,,,,,,,, -116700,2.8624315,1.3086913,,,,,,,,,,,,,, -116800,2.8792095,1.1988686,,,,,,,,,,,,,, -116900,2.8682778,1.1968416,,,,,,,,,,,,,, -117000,2.9096718,1.2540311,,,,,,,,,,,,,, -117100,2.9479501,1.1638569,,,,,,,,,,,,,, -117200,3.0136743,1.1425345,,,,,,,,,,,,,, -117300,2.8502455,1.2135465,,,,,,,,,,,,,, -117343,,,0.7983697056770325,0.7344706654548645,0.7015999555587769,1.2197657823562622,50000.0,0.5725000500679016,1.961987257003784,10000.0,39327.47944116592,40802.17005300522,39327.47944116592,1467.459722518921,3.056226253509521,0.0 -117400,2.9646254,1.2150581,,,,,,,,,,,,,, -117500,3.0244322,1.2563858,,,,,,,,,,,,,, -117600,2.987891,1.3049783,,,,,,,,,,,,,, -117700,2.7931983,1.2225122,,,,,,,,,,,,,, -117800,2.9966316,1.309832,,,,,,,,,,,,,, -117900,2.8628337,1.2926953,,,,,,,,,,,,,, -118000,2.848206,1.1752166,,,,,,,,,,,,,, -118100,2.859142,1.1934701,,,,,,,,,,,,,, -118200,2.762339,1.166431,,,,,,,,,,,,,, -118300,2.764454,1.1830549,,,,,,,,,,,,,, -118400,2.8135967,1.242,,,,,,,,,,,,,, -118500,3.0370452,1.2149625,,,,,,,,,,,,,, -118600,3.5267363,1.2051549,,,,,,,,,,,,,, -118700,2.968801,1.1722839,,,,,,,,,,,,,, -118800,3.0796564,1.2367936,,,,,,,,,,,,,, -118869,,,0.8006417155265808,0.7289294600486755,0.700439989566803,1.2137749195098877,50000.0,0.5705000162124634,1.956821322441101,10000.0,39837.70289897919,41329.202719688416,39837.70289897919,1484.165627002716,3.1047351360321045,0.0 -118900,3.3008733,1.2019264,,,,,,,,,,,,,, -119000,2.986352,1.1720064,,,,,,,,,,,,,, -119100,3.132426,1.2291795,,,,,,,,,,,,,, -119200,2.9227293,1.2334447,,,,,,,,,,,,,, -119300,2.932172,1.1733656,,,,,,,,,,,,,, -119400,3.0307755,1.1040791,,,,,,,,,,,,,, -119500,2.7870703,1.0963342,,,,,,,,,,,,,, -119600,3.155844,1.2499788,,,,,,,,,,,,,, -119700,2.79913,1.1915097,,,,,,,,,,,,,, -119800,2.9175537,1.2787107,,,,,,,,,,,,,, -119900,3.068456,1.1740966,,,,,,,,,,,,,, -120000,3.022858,1.1931602,,,,,,,,,,,,,, -120100,3.0636032,1.2715358,,,,,,,,,,,,,, -120200,2.9252262,1.195239,,,,,,,,,,,,,, -120300,2.8525045,1.1606866,,,,,,,,,,,,,, -120394,,,0.8328284025192261,0.6055227518081665,0.7037999629974365,1.199912428855896,50000.0,0.579800009727478,1.906605124473572,10000.0,40347.663089990616,41855.777509212494,40347.663089990616,1500.6749844551086,3.155070304870605,0.0 -120400,2.8937294,1.1817243,,,,,,,,,,,,,, -120500,2.7809353,1.2203987,,,,,,,,,,,,,, -120600,3.0092278,1.1977621,,,,,,,,,,,,,, -120700,3.0957296,1.1470156,,,,,,,,,,,,,, -120800,3.0660503,1.1799352,,,,,,,,,,,,,, -120900,2.7708435,1.1022733,,,,,,,,,,,,,, -121000,3.0583975,1.2041336,,,,,,,,,,,,,, -121100,3.0440476,1.1665032,,,,,,,,,,,,,, -121200,2.8895533,1.1298134,,,,,,,,,,,,,, -121300,2.951661,1.1646204,,,,,,,,,,,,,, -121400,3.1259997,1.1993898,,,,,,,,,,,,,, -121500,2.8698175,1.2080348,,,,,,,,,,,,,, -121600,2.996938,1.2273622,,,,,,,,,,,,,, -121700,3.1952486,1.2182667,,,,,,,,,,,,,, -121800,2.9215355,1.2290697,,,,,,,,,,,,,, -121900,3.1131823,1.2329042,,,,,,,,,,,,,, -121919,,,0.82425856590271,0.62799072265625,0.7076799869537354,1.1814346313476562,50000.0,0.5807000398635864,1.9057252407073968,10000.0,40857.58970141411,42382.29209589958,40857.58970141411,1517.159220457077,3.204002857208252,0.0 -122000,2.8750327,1.1170218,,,,,,,,,,,,,, -122100,3.079211,1.1197951,,,,,,,,,,,,,, -122200,2.8358343,1.1632996,,,,,,,,,,,,,, -122300,3.0226061,1.2578943,,,,,,,,,,,,,, -122400,3.0720599,1.1639483,,,,,,,,,,,,,, -122500,2.9336104,1.1394079,,,,,,,,,,,,,, -122600,2.88138,1.1359046,,,,,,,,,,,,,, -122700,2.9982755,1.1084298,,,,,,,,,,,,,, -122800,3.2254584,1.1921208,,,,,,,,,,,,,, -122900,2.936501,1.1265335,,,,,,,,,,,,,, -123000,3.2973084,1.2271559,,,,,,,,,,,,,, -123100,2.9619672,1.219726,,,,,,,,,,,,,, -123200,3.0501897,1.1326663,,,,,,,,,,,,,, -123300,2.8307145,1.04864,,,,,,,,,,,,,, -123400,2.8725917,1.1160519,,,,,,,,,,,,,, -123444,,,0.8248166441917419,0.625325620174408,0.7119799852371216,1.170873522758484,50000.0,0.5823000073432922,1.890337824821472,10000.0,41367.651956796646,42909.13440656662,41367.651956796646,1533.836226463318,3.2530641555786133,0.0 -123500,3.0797443,1.2171004,,,,,,,,,,,,,, -123600,3.0558054,1.1196427,,,,,,,,,,,,,, -123700,3.262392,1.1531395,,,,,,,,,,,,,, -123800,2.8759189,1.0685734,,,,,,,,,,,,,, -123900,3.0241368,1.2130796,,,,,,,,,,,,,, -124000,2.9177043,1.1373286,,,,,,,,,,,,,, -124100,3.014361,1.1553806,,,,,,,,,,,,,, -124200,3.0809464,1.1564765,,,,,,,,,,,,,, -124300,3.1754842,1.1961515,,,,,,,,,,,,,, -124400,3.0943274,1.0520967,,,,,,,,,,,,,, -124500,2.9002616,1.103271,,,,,,,,,,,,,, -124600,2.9365053,1.1216027,,,,,,,,,,,,,, -124700,2.9405699,1.0977863,,,,,,,,,,,,,, -124800,2.995707,1.1633561,,,,,,,,,,,,,, -124900,3.166001,1.173823,,,,,,,,,,,,,, -124969,,,0.8237603306770325,0.6367748379707336,0.7152199745178223,1.161817193031311,50000.0,0.5931000113487244,1.8835399150848389,10000.0,41877.71378183365,43436.02739739418,41877.71378183365,1550.5613188743591,3.304357051849365,0.0 -125000,2.7471151,1.0938728,,,,,,,,,,,,,, -125100,3.260613,1.1633655,,,,,,,,,,,,,, -125200,3.0304363,1.1213534,,,,,,,,,,,,,, -125300,2.9833548,1.1804266,,,,,,,,,,,,,, -125400,3.1227272,1.1940012,,,,,,,,,,,,,, -125500,3.2178264,1.1644098,,,,,,,,,,,,,, -125600,2.976901,1.1938946,,,,,,,,,,,,,, -125700,3.1542768,1.0757942,,,,,,,,,,,,,, -125800,2.9774942,1.0718521,,,,,,,,,,,,,, -125900,3.1591673,1.1565473,,,,,,,,,,,,,, -126000,3.1072497,1.1857772,,,,,,,,,,,,,, -126100,2.976184,1.116709,,,,,,,,,,,,,, -126200,3.121962,1.1266402,,,,,,,,,,,,,, -126300,3.1321287,1.0938377,,,,,,,,,,,,,, -126400,2.9563844,1.0409639,,,,,,,,,,,,,, -126494,,,0.8254543542861938,0.6237724423408508,0.7156199812889099,1.1618980169296265,50000.0,0.5841000080108643,1.881402969360352,10000.0,42387.83588695526,43962.90839862824,42387.83588695526,1567.2161169052124,3.3536806106567383,0.0 -126500,3.2481985,1.0828263,,,,,,,,,,,,,, -126600,3.2637193,1.1140784,,,,,,,,,,,,,, -126700,3.3583684,1.0802934,,,,,,,,,,,,,, -126800,3.2176673,1.0304899,,,,,,,,,,,,,, -126900,3.2097325,1.097333,,,,,,,,,,,,,, -127000,2.7858016,1.0727963,,,,,,,,,,,,,, -127100,3.0569143,1.0960414,,,,,,,,,,,,,, -127200,3.2863038,1.1494503,,,,,,,,,,,,,, -127300,3.1655743,1.1807985,,,,,,,,,,,,,, -127400,3.3257594,1.0163971,,,,,,,,,,,,,, -127500,3.2659671,1.2487099,,,,,,,,,,,,,, -127600,3.0179439,1.0609038,,,,,,,,,,,,,, -127700,3.310479,1.1687438,,,,,,,,,,,,,, -127800,3.0833595,1.0755101,,,,,,,,,,,,,, -127900,3.1454992,1.1452528,,,,,,,,,,,,,, -128000,3.295317,1.083816,,,,,,,,,,,,,, -128019,,,0.8600525856018066,0.4976900219917297,0.7145199775695801,1.1730865240097046,50000.0,0.5861000418663025,1.9103562831878664,10000.0,42897.97383713722,44489.80436134338,42897.97383713722,1583.872680425644,3.401171922683716,0.0 -128100,3.05264,1.0537572,,,,,,,,,,,,,, -128200,3.23927,1.1311369,,,,,,,,,,,,,, -128300,3.1419365,1.0741664,,,,,,,,,,,,,, -128400,3.567979,1.1504515,,,,,,,,,,,,,, -128500,3.2103467,1.2076504,,,,,,,,,,,,,, -128600,3.3778238,1.043677,,,,,,,,,,,,,, -128700,3.2944329,1.1136774,,,,,,,,,,,,,, -128800,2.9740744,0.99131113,,,,,,,,,,,,,, -128900,3.2371235,1.1377823,,,,,,,,,,,,,, -129000,3.1826773,1.187706,,,,,,,,,,,,,, -129100,3.4036222,1.1169271,,,,,,,,,,,,,, -129200,3.3089557,1.0996432,,,,,,,,,,,,,, -129300,3.0410898,1.0516427,,,,,,,,,,,,,, -129400,3.3606417,1.2001764,,,,,,,,,,,,,, -129500,3.550434,1.1473992,,,,,,,,,,,,,, -129544,,,0.8498086333274841,0.5317939519882202,0.7183399796485901,1.1576786041259766,50000.0,0.5879000425338745,1.882236361503601,10000.0,43407.957873106,45016.60299420357,43407.957873106,1600.5715289115906,3.4611611366271973,0.0 -129600,3.2972274,1.0614228,,,,,,,,,,,,,, -129700,3.5365531,1.1498137,,,,,,,,,,,,,, -129800,3.1744344,1.0616099,,,,,,,,,,,,,, -129900,3.288741,1.0816176,,,,,,,,,,,,,, -130000,3.301408,1.058411,,,,,,,,,,,,,, -130100,3.016835,1.0757877,,,,,,,,,,,,,, -130200,3.215256,1.2274688,,,,,,,,,,,,,, -130300,3.4229596,1.095082,,,,,,,,,,,,,, -130400,3.422038,1.120026,,,,,,,,,,,,,, -130500,3.1187603,1.0762877,,,,,,,,,,,,,, -130600,3.116424,1.0794382,,,,,,,,,,,,,, -130700,3.5164523,1.0952114,,,,,,,,,,,,,, -130800,3.2601748,1.0638311,,,,,,,,,,,,,, -130900,3.351967,1.063542,,,,,,,,,,,,,, -131000,3.3958652,1.0858433,,,,,,,,,,,,,, -131069,,,0.8396045565605164,0.5685706734657288,0.7088800072669983,1.184848427772522,50000.0,0.5843000411987305,1.9149274826049805,10000.0,43918.01558470726,45543.41831231117,43918.01558470726,1617.2234988212583,3.512183427810669,0.0 -131100,3.1762385,1.0952486,,,,,,,,,,,,,, -131200,3.268914,1.1551925,,,,,,,,,,,,,, -131300,3.2461863,1.0287399,,,,,,,,,,,,,, -131400,3.0933597,1.0815895,,,,,,,,,,,,,, -131500,3.3293242,1.1163303,,,,,,,,,,,,,, -131600,3.077721,1.1123422,,,,,,,,,,,,,, -131700,3.3338048,1.1221257,,,,,,,,,,,,,, -131800,3.2508607,1.0694804,,,,,,,,,,,,,, -131900,3.4243634,1.1254485,,,,,,,,,,,,,, -132000,3.2945535,1.0073563,,,,,,,,,,,,,, -132100,3.2483873,1.1289513,,,,,,,,,,,,,, -132200,3.2328112,1.0227735,,,,,,,,,,,,,, -132300,3.0856092,1.1506491,,,,,,,,,,,,,, -132400,3.02603,0.9472218,,,,,,,,,,,,,, -132500,3.7300541,1.068479,,,,,,,,,,,,,, -132594,,,0.8439094424247742,0.5533644556999207,0.7190600037574768,1.15404212474823,50000.0,0.5939000248908997,1.8933364152908323,10000.0,44428.10785365105,46070.31389427185,44428.10785365105,1633.9201967716217,3.5640017986297607,0.0 -132600,3.2825165,1.0383341,,,,,,,,,,,,,, -132700,3.3724082,1.0343858,,,,,,,,,,,,,, -132800,3.7078812,1.1149272,,,,,,,,,,,,,, -132900,3.5274227,1.1353252,,,,,,,,,,,,,, -133000,3.055468,1.0757422,,,,,,,,,,,,,, -133100,4.1265564,1.1080941,,,,,,,,,,,,,, -133200,3.5751166,1.0760251,,,,,,,,,,,,,, -133300,3.266792,1.0361503,,,,,,,,,,,,,, -133400,3.231196,1.0526923,,,,,,,,,,,,,, -133500,3.5315726,0.9792923,,,,,,,,,,,,,, -133600,3.2737541,1.0343093,,,,,,,,,,,,,, -133700,3.2651134,1.086957,,,,,,,,,,,,,, -133800,3.3785963,1.0411913,,,,,,,,,,,,,, -133900,3.53986,0.98877555,,,,,,,,,,,,,, -134000,3.4067755,1.0951678,,,,,,,,,,,,,, -134100,3.204141,0.9439184,,,,,,,,,,,,,, -134120,,,0.8478555083274841,0.5405254364013672,0.7234199643135071,1.133036494255066,50000.0,0.5896000266075134,1.887064933776856,10000.0,44938.21067500114,46597.1632874012,44938.21067500114,1650.5535411834717,3.6218745708465576,0.0 -134200,3.4096406,1.1069318,,,,,,,,,,,,,, -134300,3.230833,1.02599,,,,,,,,,,,,,, -134400,3.175383,0.9963484,,,,,,,,,,,,,, -134500,3.6840546,1.1484709,,,,,,,,,,,,,, -134600,3.1915097,1.0476825,,,,,,,,,,,,,, -134700,3.6447158,1.1236084,,,,,,,,,,,,,, -134800,3.5173287,1.049799,,,,,,,,,,,,,, -134900,3.4873269,0.93445605,,,,,,,,,,,,,, -135000,3.3574493,1.0695049,,,,,,,,,,,,,, -135100,3.6171272,0.99662685,,,,,,,,,,,,,, -135200,3.3015823,0.9399015,,,,,,,,,,,,,, -135300,3.318001,1.0135322,,,,,,,,,,,,,, -135400,3.7800903,1.1362143,,,,,,,,,,,,,, -135500,3.4356117,1.0343355,,,,,,,,,,,,,, -135600,3.3889403,1.0081332,,,,,,,,,,,,,, -135646,,,0.8484534025192261,0.5400087833404541,0.7185199856758118,1.1664830446243286,50000.0,0.593000054359436,1.909881353378296,10000.0,45448.31420826912,47124.01470160484,45448.31420826912,1667.193475484848,3.674959659576416,0.0 -135700,3.387652,0.9663273,,,,,,,,,,,,,, -135800,3.540967,1.0158198,,,,,,,,,,,,,, -135900,3.3714864,1.1214979,,,,,,,,,,,,,, -136000,3.6056476,1.0022606,,,,,,,,,,,,,, -136100,3.219635,0.91634357,,,,,,,,,,,,,, -136200,3.4037533,1.0437744,,,,,,,,,,,,,, -136300,3.2339542,0.891593,,,,,,,,,,,,,, -136400,3.3524945,1.0688015,,,,,,,,,,,,,, -136500,3.1704485,0.9497631,,,,,,,,,,,,,, -136600,3.7027998,0.9858988,,,,,,,,,,,,,, -136700,3.6900804,0.98389095,,,,,,,,,,,,,, -136800,3.7744694,0.9033548,,,,,,,,,,,,,, -136900,3.5757852,1.0285823,,,,,,,,,,,,,, -137000,3.684742,1.1287619,,,,,,,,,,,,,, -137100,3.15461,1.0680091,,,,,,,,,,,,,, -137172,,,0.8808394074440002,0.4171130955219269,0.7249799966812134,1.1274539232254028,50000.0,0.5922000408172607,1.8785035610198968,10000.0,45958.44335961342,47650.93498468399,45958.44335961342,1683.880009889603,3.7243666648864746,0.0 -137200,3.423238,1.0194665,,,,,,,,,,,,,, -137300,3.2866518,0.9758387,,,,,,,,,,,,,, -137400,3.9855926,1.0275993,,,,,,,,,,,,,, -137500,3.9036894,1.0130174,,,,,,,,,,,,,, -137600,3.3723624,0.9841448,,,,,,,,,,,,,, -137700,3.5920868,1.0568144,,,,,,,,,,,,,, -137800,3.8058727,1.0843235,,,,,,,,,,,,,, -137900,3.2879164,1.0507634,,,,,,,,,,,,,, -138000,3.5700114,0.9010065,,,,,,,,,,,,,, -138100,3.391778,0.9062251,,,,,,,,,,,,,, -138200,3.6731143,1.0556389,,,,,,,,,,,,,, -138300,3.7084951,1.1170189,,,,,,,,,,,,,, -138400,3.4821372,1.0585971,,,,,,,,,,,,,, -138500,3.4936051,0.9664407,,,,,,,,,,,,,, -138600,3.355198,0.920678,,,,,,,,,,,,,, -138698,,,0.8705556392669678,0.4483226835727691,0.7280600070953369,1.13063383102417,50000.0,0.6017000079154968,1.8933743238449097,10000.0,46468.43311071396,48177.746673583984,46468.43311071396,1700.59228515625,3.780108690261841,0.0 -138700,3.3723996,0.97423196,,,,,,,,,,,,,, -138800,3.724696,0.9508847,,,,,,,,,,,,,, -138900,3.4620101,0.9908097,,,,,,,,,,,,,, -139000,3.4749005,0.99092317,,,,,,,,,,,,,, -139100,3.4518988,0.950201,,,,,,,,,,,,,, -139200,3.725613,0.9244622,,,,,,,,,,,,,, -139300,3.781661,1.0450027,,,,,,,,,,,,,, -139400,4.1815915,1.1492281,,,,,,,,,,,,,, -139500,4.148693,1.0642865,,,,,,,,,,,,,, -139600,3.6081684,0.9915893,,,,,,,,,,,,,, -139700,3.608748,1.0393342,,,,,,,,,,,,,, -139800,3.3578234,1.0049853,,,,,,,,,,,,,, -139900,3.515164,0.96374655,,,,,,,,,,,,,, -140000,3.688957,1.0616987,,,,,,,,,,,,,, -140100,3.6762695,0.9507212,,,,,,,,,,,,,, -140200,3.7657373,1.0092356,,,,,,,,,,,,,, -140223,,,0.8729272484779358,0.4394337832927704,0.7323200106620789,1.1090658903121948,50000.0,0.6011000275611877,1.8609994649887085,10000.0,46978.51990914345,48704.682052373886,46978.51990914345,1717.3363778591156,3.829201221466065,0.0 -140300,3.5921147,1.0283786,,,,,,,,,,,,,, -140400,3.7937503,1.0131572,,,,,,,,,,,,,, -140500,3.5522232,0.96634173,,,,,,,,,,,,,, -140600,3.332572,0.91581726,,,,,,,,,,,,,, -140700,3.8129992,0.9893954,,,,,,,,,,,,,, -140800,3.5597973,0.99181634,,,,,,,,,,,,,, -140900,3.5916872,0.92148113,,,,,,,,,,,,,, -141000,3.5990682,0.92546636,,,,,,,,,,,,,, -141100,3.7085953,0.962622,,,,,,,,,,,,,, -141200,4.0069494,0.96163774,,,,,,,,,,,,,, -141300,3.531903,0.9327866,,,,,,,,,,,,,, -141400,3.7510152,1.0154331,,,,,,,,,,,,,, -141500,3.4503891,0.9088617,,,,,,,,,,,,,, -141600,3.4091172,0.8807807,,,,,,,,,,,,,, -141700,3.944468,0.93736756,,,,,,,,,,,,,, -141749,,,0.8720304369926453,0.4448748230934143,0.7297999858856201,1.107924222946167,50000.0,0.602400004863739,1.851275086402893,10000.0,47488.66160154343,49231.571748018265,47488.66160154343,1733.9795546531675,3.879467010498047,0.0 -141800,3.754309,0.9351745,,,,,,,,,,,,,, -141900,3.622632,0.9931459,,,,,,,,,,,,,, -142000,3.690542,0.97508365,,,,,,,,,,,,,, -142100,3.4448428,0.92239726,,,,,,,,,,,,,, -142200,3.7700086,0.940973,,,,,,,,,,,,,, -142300,3.37281,0.9307103,,,,,,,,,,,,,, -142400,3.6216362,0.87727875,,,,,,,,,,,,,, -142500,3.388075,0.88909817,,,,,,,,,,,,,, -142600,3.673515,0.95904994,,,,,,,,,,,,,, -142700,3.62252,0.8629909,,,,,,,,,,,,,, -142800,3.8671231,0.97329694,,,,,,,,,,,,,, -142900,3.5630884,0.9635531,,,,,,,,,,,,,, -143000,3.759522,0.98856384,,,,,,,,,,,,,, -143100,3.633548,0.9498006,,,,,,,,,,,,,, -143200,3.6907697,0.9291606,,,,,,,,,,,,,, -143275,,,0.8688217401504517,0.4498252868652344,0.7297799587249756,1.1145541667938232,50000.0,0.6026000380516052,1.867939591407776,10000.0,47998.768881082535,49758.36278486252,47998.768881082535,1750.5503687858582,3.9371063709259033,0.0 -143300,3.5289268,0.92168856,,,,,,,,,,,,,, -143400,3.7333393,0.98852026,,,,,,,,,,,,,, -143500,3.7242482,0.9317118,,,,,,,,,,,,,, -143600,3.777241,0.94044185,,,,,,,,,,,,,, -143700,3.596701,0.9446575,,,,,,,,,,,,,, -143800,4.1418853,0.99459946,,,,,,,,,,,,,, -143900,4.038819,1.0073776,,,,,,,,,,,,,, -144000,3.759789,0.9371134,,,,,,,,,,,,,, -144100,3.4481359,0.86522925,,,,,,,,,,,,,, -144200,3.824941,0.9107995,,,,,,,,,,,,,, -144300,3.6230607,0.88755924,,,,,,,,,,,,,, -144400,3.859325,0.92423826,,,,,,,,,,,,,, -144500,3.7691524,0.8754552,,,,,,,,,,,,,, -144600,3.353054,0.89531046,,,,,,,,,,,,,, -144700,3.907755,0.8584452,,,,,,,,,,,,,, -144800,,,0.89164137840271,0.3792960345745086,0.7303000092506409,1.120832443237305,50000.0,0.6032000184059143,1.8936916589736936,10000.0,48508.867156744,50285.14401555061,48508.867156744,1767.1221826076508,3.994133949279785,0.0 -144800,3.9594157,0.9393618,,,,,,,,,,,,,, -144900,4.2168794,1.0376022,,,,,,,,,,,,,, -145000,3.6427314,0.9292389,,,,,,,,,,,,,, -145100,3.857217,0.9930156,,,,,,,,,,,,,, -145200,3.9266021,0.9367784,,,,,,,,,,,,,, -145300,3.7898815,0.9358007,,,,,,,,,,,,,, -145400,3.9979794,0.9273205,,,,,,,,,,,,,, -145500,3.9878561,0.9443153,,,,,,,,,,,,,, -145600,3.996493,1.0515006,,,,,,,,,,,,,, -145700,3.9337888,0.9924442,,,,,,,,,,,,,, -145800,4.130034,0.94429207,,,,,,,,,,,,,, -145900,3.7585979,0.89194983,,,,,,,,,,,,,, -146000,3.7504084,0.87777615,,,,,,,,,,,,,, -146100,3.644857,0.9848318,,,,,,,,,,,,,, -146200,3.9174726,0.9198554,,,,,,,,,,,,,, -146300,3.4872482,0.88086486,,,,,,,,,,,,,, -146325,,,0.8965840339660645,0.36079141497612,0.7327799797058105,1.1133227348327637,50000.0,0.6065000295639038,1.8500038385391235,10000.0,49019.03108620644,50812.22521138191,49019.03108620644,1783.9339122772217,4.045078992843628,0.0 -146400,4.3305774,0.99917126,,,,,,,,,,,,,, -146500,3.9683619,0.96378654,,,,,,,,,,,,,, -146600,3.434784,0.808324,,,,,,,,,,,,,, -146700,3.979092,0.96136284,,,,,,,,,,,,,, -146800,3.9689116,0.8782102,,,,,,,,,,,,,, -146900,3.688203,1.009091,,,,,,,,,,,,,, -147000,3.9174726,0.9460396,,,,,,,,,,,,,, -147100,4.0124784,0.96241105,,,,,,,,,,,,,, -147200,3.724102,0.87757856,,,,,,,,,,,,,, -147300,3.8318188,0.94314533,,,,,,,,,,,,,, -147400,3.690521,0.8867165,,,,,,,,,,,,,, -147500,3.9077318,0.9237125,,,,,,,,,,,,,, -147600,4.171967,0.92434424,,,,,,,,,,,,,, -147700,3.8523083,0.8783641,,,,,,,,,,,,,, -147800,4.1591873,0.8672637,,,,,,,,,,,,,, -147851,,,0.8950892686843872,0.3633810877799988,0.7359399795532227,1.0984907150268557,50000.0,0.6067000031471252,1.86048686504364,10000.0,49529.17516756058,51339.01440405846,49529.17516756058,1800.472449302673,4.096417427062988,0.0 -147900,4.0425534,0.8625507,,,,,,,,,,,,,, -148000,3.5825357,0.8475842,,,,,,,,,,,,,, -148100,3.9613755,0.82769847,,,,,,,,,,,,,, -148200,3.7072418,0.8850412,,,,,,,,,,,,,, -148300,3.8663442,0.8791357,,,,,,,,,,,,,, -148400,3.9475143,0.8989388,,,,,,,,,,,,,, -148500,3.6896389,0.8064378,,,,,,,,,,,,,, -148600,4.2512755,0.81327915,,,,,,,,,,,,,, -148700,3.6976795,0.835314,,,,,,,,,,,,,, -148800,4.352868,0.87138486,,,,,,,,,,,,,, -148900,4.277066,0.94870406,,,,,,,,,,,,,, -149000,3.795388,0.8314153,,,,,,,,,,,,,, -149100,4.253574,0.9281124,,,,,,,,,,,,,, -149200,3.7956986,0.95270395,,,,,,,,,,,,,, -149300,3.865817,0.95680195,,,,,,,,,,,,,, -149376,,,0.8971220850944519,0.3583529591560364,0.7389199733734131,1.0933120250701904,50000.0,0.6077000498771667,1.858315110206604,10000.0,50039.28535270691,51866.091879844666,50039.28535270691,1817.333134889603,4.147741079330444,0.0 -149400,3.784322,0.8481184,,,,,,,,,,,,,, -149500,3.8713617,0.889278,,,,,,,,,,,,,, -149600,4.1104646,0.8487432,,,,,,,,,,,,,, -149700,3.8754895,0.8754538,,,,,,,,,,,,,, -149800,4.455315,0.8852707,,,,,,,,,,,,,, -149900,4.2066865,0.87789047,,,,,,,,,,,,,, -150000,3.937558,0.85193217,,,,,,,,,,,,,, -150100,4.060028,0.9292487,,,,,,,,,,,,,, -150200,3.6479714,0.8281156,,,,,,,,,,,,,, -150300,4.071697,0.8123211,,,,,,,,,,,,,, -150400,3.901854,0.84057283,,,,,,,,,,,,,, -150500,3.903717,0.869985,,,,,,,,,,,,,, -150600,4.184001,0.8718632,,,,,,,,,,,,,, -150700,4.4074354,0.88606757,,,,,,,,,,,,,, -150800,3.9551113,0.80858815,,,,,,,,,,,,,, -150900,4.6464677,0.8884156,,,,,,,,,,,,,, -150901,,,0.8970025181770325,0.3487862348556518,0.740339994430542,1.1000778675079346,50000.0,0.6073000431060791,1.8836166858673096,10000.0,50549.420686244965,52392.9816904068,50549.420686244965,1833.978541135788,4.200202941894531,0.0 -151000,4.0758038,0.88841563,,,,,,,,,,,,,, -151100,4.021389,0.78266025,,,,,,,,,,,,,, -151200,3.9117217,0.8376513,,,,,,,,,,,,,, -151300,4.093054,0.8836634,,,,,,,,,,,,,, -151400,3.661122,0.7097199,,,,,,,,,,,,,, -151500,3.8615339,0.8707017,,,,,,,,,,,,,, -151600,4.3064256,0.92855525,,,,,,,,,,,,,, -151700,4.0229626,0.87042344,,,,,,,,,,,,,, -151800,4.2104306,0.79686296,,,,,,,,,,,,,, -151900,4.313771,0.8864031,,,,,,,,,,,,,, -152000,4.232473,0.8874165,,,,,,,,,,,,,, -152100,4.074572,0.81210434,,,,,,,,,,,,,, -152200,4.24945,0.90725857,,,,,,,,,,,,,, -152300,4.7108297,0.86052376,,,,,,,,,,,,,, -152400,4.169077,0.8019716,,,,,,,,,,,,,, -152426,,,0.9053730964660645,0.3358666598796844,0.738319993019104,1.0839452743530271,50000.0,0.6089000105857849,1.82943332195282,10000.0,51059.445302248,52919.78895497322,51059.445302248,1850.6516172885893,4.255874156951904,0.0 -152500,4.148617,0.81901807,,,,,,,,,,,,,, -152600,4.174143,0.84757245,,,,,,,,,,,,,, -152700,4.1782928,0.82290673,,,,,,,,,,,,,, -152800,4.193158,0.82972467,,,,,,,,,,,,,, -152900,4.209256,0.84711874,,,,,,,,,,,,,, -153000,4.1798406,0.79321945,,,,,,,,,,,,,, -153100,4.078963,0.7822647,,,,,,,,,,,,,, -153200,4.4023457,0.8190866,,,,,,,,,,,,,, -153300,3.7100263,0.76109344,,,,,,,,,,,,,, -153400,4.596656,0.94493836,,,,,,,,,,,,,, -153500,3.6466262,0.77967423,,,,,,,,,,,,,, -153600,4.208678,0.83355534,,,,,,,,,,,,,, -153700,3.948829,0.7593697,,,,,,,,,,,,,, -153800,4.3262835,0.89911896,,,,,,,,,,,,,, -153900,4.0453553,0.8181865,,,,,,,,,,,,,, -153951,,,0.9297871589660645,0.2487999647855758,0.7396799921989441,1.0863995552062988,50000.0,0.6115000247955322,1.877955794334412,10000.0,51569.52954649925,53447.32394480705,51569.52954649925,1867.9935710430143,4.310412168502808,0.0 -154000,4.0209527,0.77196914,,,,,,,,,,,,,, -154100,4.1027417,0.80570936,,,,,,,,,,,,,, -154200,4.1477447,0.8256688,,,,,,,,,,,,,, -154300,3.9024231,0.8562386,,,,,,,,,,,,,, -154400,3.8540323,0.7298251,,,,,,,,,,,,,, -154500,4.0056543,0.8303781,,,,,,,,,,,,,, -154600,4.453283,0.9121132,,,,,,,,,,,,,, -154700,4.6051397,0.8337522,,,,,,,,,,,,,, -154800,4.193294,0.80958354,,,,,,,,,,,,,, -154900,4.3210945,0.73314536,,,,,,,,,,,,,, -155000,4.0452204,0.7929686,,,,,,,,,,,,,, -155100,4.9256067,0.89888626,,,,,,,,,,,,,, -155200,4.1171227,0.7902128,,,,,,,,,,,,,, -155300,3.97008,0.81870526,,,,,,,,,,,,,, -155400,4.086364,0.837964,,,,,,,,,,,,,, -155476,,,0.919343888759613,0.2705561220645904,0.743939995765686,1.0910815000534058,50000.0,0.617400050163269,1.85482394695282,10000.0,52079.50380158424,53974.00850534439,52079.50380158424,1884.5948798656464,4.365032911300659,0.0 -155500,4.074111,0.7592513,,,,,,,,,,,,,, -155600,4.3940444,0.8996132,,,,,,,,,,,,,, -155700,4.8965244,0.87430686,,,,,,,,,,,,,, -155800,4.346132,0.8056489,,,,,,,,,,,,,, -155900,3.8815107,0.775242,,,,,,,,,,,,,, -156000,3.9777768,0.7658779,,,,,,,,,,,,,, -156100,4.111319,0.7686346,,,,,,,,,,,,,, -156200,4.7727323,0.8859716,,,,,,,,,,,,,, -156300,4.1001167,0.8007605,,,,,,,,,,,,,, -156400,4.1016135,0.80721295,,,,,,,,,,,,,, -156500,4.0686517,0.6967369,,,,,,,,,,,,,, -156600,4.313261,0.79758394,,,,,,,,,,,,,, -156700,4.1743016,0.80337787,,,,,,,,,,,,,, -156800,4.119027,0.7763306,,,,,,,,,,,,,, -156900,4.13593,0.8638173,,,,,,,,,,,,,, -157000,3.9686072,0.7223306,,,,,,,,,,,,,, -157001,,,0.9226123690605164,0.2707573473453522,0.7443199753761292,1.079198122024536,50000.0,0.6187000274658203,1.8395804166793823,10000.0,52589.87325882912,54500.97246050835,52589.87325882912,1901.0806233882904,4.418842077255249,0.0 -157100,4.4478636,0.84178376,,,,,,,,,,,,,, -157200,4.2459917,0.7921435,,,,,,,,,,,,,, -157300,4.2385116,0.8212911,,,,,,,,,,,,,, -157400,4.0931053,0.75053936,,,,,,,,,,,,,, -157500,4.2932158,0.7731466,,,,,,,,,,,,,, -157600,4.262903,0.78163105,,,,,,,,,,,,,, -157700,4.385808,0.8167735,,,,,,,,,,,,,, -157800,4.642258,0.814713,,,,,,,,,,,,,, -157900,4.493055,0.7874158,,,,,,,,,,,,,, -158000,4.439838,0.8666004,,,,,,,,,,,,,, -158100,4.5899644,0.87728256,,,,,,,,,,,,,, -158200,4.208107,0.8188647,,,,,,,,,,,,,, -158300,4.576313,0.84437907,,,,,,,,,,,,,, -158400,3.9848802,0.776037,,,,,,,,,,,,,, -158500,4.4528265,0.7630591,,,,,,,,,,,,,, -158525,,,0.9240074753761292,0.2637510001659393,0.7441999912261963,1.0730669498443604,50000.0,0.6181000471115112,1.8403584957122805,10000.0,53099.86465787888,55027.62099266052,53099.86465787888,1917.6260554790497,4.475511312484741,0.0 -158600,4.048344,0.7381617,,,,,,,,,,,,,, -158700,4.2726603,0.7408602,,,,,,,,,,,,,, -158800,4.664089,0.8086327,,,,,,,,,,,,,, -158900,4.2000732,0.7544327,,,,,,,,,,,,,, -159000,4.5189567,0.778499,,,,,,,,,,,,,, -159100,4.16238,0.73590004,,,,,,,,,,,,,, -159200,4.2946,0.80993074,,,,,,,,,,,,,, -159300,4.157605,0.82610554,,,,,,,,,,,,,, -159400,3.8186276,0.67143744,,,,,,,,,,,,,, -159500,4.0673127,0.6833958,,,,,,,,,,,,,, -159600,4.589474,0.83492035,,,,,,,,,,,,,, -159700,4.4810085,0.75866264,,,,,,,,,,,,,, -159800,4.0632334,0.7295109,,,,,,,,,,,,,, -159900,4.4979396,0.7582183,,,,,,,,,,,,,, -160000,4.3491435,0.8409915,,,,,,,,,,,,,, -160050,,,0.9275350570678712,0.2511384785175323,0.7457000017166138,1.075924515724182,50000.0,0.6213000416755676,1.8441084623336792,10000.0,53610.03590750694,55554.67021775246,53610.03590750694,1934.39634346962,4.528732776641846,0.0 -160100,4.224381,0.6938829,,,,,,,,,,,,,, -160200,4.6174083,0.74175024,,,,,,,,,,,,,, -160300,4.683777,0.7793061,,,,,,,,,,,,,, -160400,5.292832,0.7436423,,,,,,,,,,,,,, -160500,4.6746273,0.71001005,,,,,,,,,,,,,, -160600,4.5107903,0.75559324,,,,,,,,,,,,,, -160700,4.300135,0.7858442,,,,,,,,,,,,,, -160800,4.3588605,0.70074195,,,,,,,,,,,,,, -160900,4.284051,0.7427711,,,,,,,,,,,,,, -161000,4.390279,0.81084096,,,,,,,,,,,,,, -161100,4.49023,0.74430394,,,,,,,,,,,,,, -161200,4.297149,0.7648211,,,,,,,,,,,,,, -161300,4.143503,0.71718,,,,,,,,,,,,,, -161400,4.4179797,0.7133662,,,,,,,,,,,,,, -161500,4.5228863,0.81248313,,,,,,,,,,,,,, -161575,,,0.9306042790412904,0.2416285127401352,0.7472400069236755,1.0739290714263916,50000.0,0.6184000372886658,1.8437731266021729,10000.0,54120.14770627022,56081.3993742466,54120.14770627022,1950.8956489562988,4.591880559921265,0.0 -161600,4.1149006,0.75575644,,,,,,,,,,,,,, -161700,4.4728823,0.79001874,,,,,,,,,,,,,, -161800,4.3814025,0.7643086,,,,,,,,,,,,,, -161900,4.2387857,0.6734602,,,,,,,,,,,,,, -162000,4.516899,0.7757991,,,,,,,,,,,,,, -162100,4.426669,0.7081777,,,,,,,,,,,,,, -162200,4.0679226,0.65476775,,,,,,,,,,,,,, -162300,4.523077,0.78759193,,,,,,,,,,,,,, -162400,4.063435,0.71758074,,,,,,,,,,,,,, -162500,4.4931016,0.791903,,,,,,,,,,,,,, -162600,4.220044,0.69242775,,,,,,,,,,,,,, -162700,4.386594,0.7355969,,,,,,,,,,,,,, -162800,4.1178384,0.720667,,,,,,,,,,,,,, -162900,4.3562226,0.74426126,,,,,,,,,,,,,, -163000,4.7622514,0.8481285,,,,,,,,,,,,,, -163099,,,0.9433194994926452,0.2045804262161255,0.7457999587059021,1.067560791969299,50000.0,0.6213000416755676,1.836686849594116,10000.0,54630.08940792084,56608.01881337166,54630.08940792084,1967.4637095928192,4.646895170211792,0.0 -163100,4.4688444,0.7656994,,,,,,,,,,,,,, -163200,4.2882686,0.7751878,,,,,,,,,,,,,, -163300,4.324904,0.6861729,,,,,,,,,,,,,, -163400,5.0016055,0.6941674,,,,,,,,,,,,,, -163500,4.3105865,0.66234887,,,,,,,,,,,,,, -163600,4.0573826,0.71397036,,,,,,,,,,,,,, -163700,4.5285926,0.79150534,,,,,,,,,,,,,, -163800,4.827763,0.7706569,,,,,,,,,,,,,, -163900,3.8275678,0.6595312,,,,,,,,,,,,,, -164000,4.76669,0.76424503,,,,,,,,,,,,,, -164100,5.2771173,0.74803805,,,,,,,,,,,,,, -164200,4.1201243,0.70939213,,,,,,,,,,,,,, -164300,4.5514903,0.701944,,,,,,,,,,,,,, -164400,4.294675,0.77745515,,,,,,,,,,,,,, -164500,4.7419667,0.7434788,,,,,,,,,,,,,, -164600,4.5744123,0.7371968,,,,,,,,,,,,,, -164624,,,0.9432597160339355,0.200706347823143,0.7478399872779846,1.0719140768051147,50000.0,0.6182000041007996,1.8608922958374023,10000.0,55140.205631017685,57134.88123440743,55140.205631017685,1984.099481344223,4.70342230796814,0.0 -164700,4.7559657,0.85375524,,,,,,,,,,,,,, -164800,4.518356,0.72928476,,,,,,,,,,,,,, -164900,4.329751,0.66419274,,,,,,,,,,,,,, -165000,4.0904756,0.64833355,,,,,,,,,,,,,, -165100,4.3961434,0.7140724,,,,,,,,,,,,,, -165200,4.6804504,0.705368,,,,,,,,,,,,,, -165300,4.553439,0.72441494,,,,,,,,,,,,,, -165400,4.608599,0.7228123,,,,,,,,,,,,,, -165500,4.219409,0.6905979,,,,,,,,,,,,,, -165600,4.6602616,0.7046766,,,,,,,,,,,,,, -165700,4.457911,0.66166055,,,,,,,,,,,,,, -165800,4.2436795,0.61946326,,,,,,,,,,,,,, -165900,4.0664062,0.6231059,,,,,,,,,,,,,, -166000,4.6435094,0.7435483,,,,,,,,,,,,,, -166100,4.3613343,0.7551544,,,,,,,,,,,,,, -166149,,,0.9441764950752258,0.1963724493980407,0.7493000030517578,1.0691767930984497,50000.0,0.6215000152587891,1.8464528322219849,10000.0,55650.33197426796,57661.73883676529,55650.33197426796,2000.7171757221224,4.762126207351685,0.0 -166200,4.526275,0.6643131,,,,,,,,,,,,,, -166300,4.7372236,0.76323533,,,,,,,,,,,,,, -166400,4.444132,0.7559626,,,,,,,,,,,,,, -166500,4.3731637,0.69668937,,,,,,,,,,,,,, -166600,4.8327403,0.6796274,,,,,,,,,,,,,, -166700,4.145549,0.6539461,,,,,,,,,,,,,, -166800,4.4631333,0.69577837,,,,,,,,,,,,,, -166900,4.1311564,0.66182244,,,,,,,,,,,,,, -167000,4.8331733,0.6704236,,,,,,,,,,,,,, -167100,4.556855,0.7243376,,,,,,,,,,,,,, -167200,4.4697285,0.73237497,,,,,,,,,,,,,, -167300,4.814448,0.79836833,,,,,,,,,,,,,, -167400,4.449346,0.75042754,,,,,,,,,,,,,, -167500,4.7933955,0.62662566,,,,,,,,,,,,,, -167600,4.336264,0.6519434,,,,,,,,,,,,,, -167673,,,0.9439173936843872,0.1987826824188232,0.7494999766349792,1.0682849884033203,50000.0,0.6243000030517578,1.844048857688904,10000.0,56160.3334069252,58188.44730424881,56160.3334069252,2017.3095281124115,4.821733713150024,0.0 -167700,4.14137,0.62545,,,,,,,,,,,,,, -167800,4.7227135,0.6586065,,,,,,,,,,,,,, -167900,4.5684247,0.70319587,,,,,,,,,,,,,, -168000,4.7518377,0.6368008,,,,,,,,,,,,,, -168100,4.3244925,0.7067339,,,,,,,,,,,,,, -168200,4.616556,0.69187844,,,,,,,,,,,,,, -168300,4.992765,0.6706894,,,,,,,,,,,,,, -168400,4.677968,0.662555,,,,,,,,,,,,,, -168500,4.816456,0.67893314,,,,,,,,,,,,,, -168600,4.543952,0.6869004,,,,,,,,,,,,,, -168700,4.332266,0.6792195,,,,,,,,,,,,,, -168800,4.5470977,0.62266785,,,,,,,,,,,,,, -168900,4.296862,0.62769145,,,,,,,,,,,,,, -169000,4.625564,0.7164234,,,,,,,,,,,,,, -169100,4.4537015,0.64679945,,,,,,,,,,,,,, -169198,,,0.9467872977256776,0.1883253902196884,0.7502999901771545,1.0689047574996948,50000.0,0.6219000220298767,1.847320795059204,10000.0,56670.40325450897,58715.23123264313,56670.40325450897,2033.912229537964,4.877910375595093,0.0 -169200,4.593506,0.71746385,,,,,,,,,,,,,, -169300,4.3372197,0.6891014,,,,,,,,,,,,,, -169400,4.8107653,0.6924339,,,,,,,,,,,,,, -169500,4.5505137,0.65841585,,,,,,,,,,,,,, -169600,4.475323,0.6909502,,,,,,,,,,,,,, -169700,4.2951217,0.65562916,,,,,,,,,,,,,, -169800,4.6111455,0.6625843,,,,,,,,,,,,,, -169900,4.3738027,0.68271625,,,,,,,,,,,,,, -170000,4.317917,0.6279687,,,,,,,,,,,,,, -170100,4.3614326,0.7048875,,,,,,,,,,,,,, -170200,4.908007,0.700184,,,,,,,,,,,,,, -170300,4.4557066,0.65485495,,,,,,,,,,,,,, -170400,4.97006,0.6649741,,,,,,,,,,,,,, -170500,5.124007,0.67964447,,,,,,,,,,,,,, -170600,5.0292907,0.6622744,,,,,,,,,,,,,, -170700,4.632879,0.6292606,,,,,,,,,,,,,, -170723,,,0.9580875039100648,0.1585330516099929,0.7505199909210205,1.065174221992493,50000.0,0.6260000467300415,1.84285831451416,10000.0,57180.59017777443,59242.078933000565,57180.59017777443,2050.458500146866,4.937411308288574,0.0 -170800,4.6559124,0.6358342,,,,,,,,,,,,,, -170900,4.2092004,0.6882115,,,,,,,,,,,,,, -171000,4.657514,0.72951305,,,,,,,,,,,,,, -171100,4.3043704,0.70588064,,,,,,,,,,,,,, -171200,4.7475038,0.7555104,,,,,,,,,,,,,, -171300,4.681131,0.64575756,,,,,,,,,,,,,, -171400,4.8400884,0.60637075,,,,,,,,,,,,,, -171500,4.2340136,0.6478944,,,,,,,,,,,,,, -171600,4.605728,0.7551388,,,,,,,,,,,,,, -171700,4.2664433,0.60261625,,,,,,,,,,,,,, -171800,4.970092,0.66867703,,,,,,,,,,,,,, -171900,4.480967,0.63192767,,,,,,,,,,,,,, -172000,4.717367,0.7137002,,,,,,,,,,,,,, -172100,4.537001,0.6841831,,,,,,,,,,,,,, -172200,5.0284595,0.6616992,,,,,,,,,,,,,, -172248,,,0.9566724896430968,0.1599704623222351,0.7523999810218811,1.0659375190734863,50000.0,0.6273000240325928,1.85725200176239,10000.0,57690.77718710899,59768.884503126144,57690.77718710899,2066.960639238357,4.997782468795776,0.0 -172300,4.5150476,0.7207936,,,,,,,,,,,,,, -172400,4.8210216,0.68498814,,,,,,,,,,,,,, -172500,5.045757,0.6733304,,,,,,,,,,,,,, -172600,4.3416853,0.66430604,,,,,,,,,,,,,, -172700,4.442243,0.61546063,,,,,,,,,,,,,, -172800,5.175176,0.6857218,,,,,,,,,,,,,, -172900,4.287768,0.66524255,,,,,,,,,,,,,, -173000,4.6273894,0.6109251,,,,,,,,,,,,,, -173100,4.4694834,0.70193577,,,,,,,,,,,,,, -173200,4.2032948,0.62277615,,,,,,,,,,,,,, -173300,5.034359,0.7080878,,,,,,,,,,,,,, -173400,4.029258,0.6458738,,,,,,,,,,,,,, -173500,4.6469455,0.66721416,,,,,,,,,,,,,, -173600,4.3022156,0.6868352,,,,,,,,,,,,,, -173700,5.004288,0.713043,,,,,,,,,,,,,, -173772,,,0.9553371667861938,0.1635049283504486,0.7538999915122986,1.0613198280334473,50000.0,0.628000020980835,1.849697709083557,10000.0,58200.87525296211,60295.775837898254,58200.87525296211,2083.6383051872253,5.058853149414063,0.0 -173800,4.5038376,0.67712176,,,,,,,,,,,,,, -173900,4.3213816,0.6176567,,,,,,,,,,,,,, -174000,4.825276,0.6321171,,,,,,,,,,,,,, -174100,4.203861,0.63918096,,,,,,,,,,,,,, -174200,4.611768,0.6932209,,,,,,,,,,,,,, -174300,4.1259375,0.6312384,,,,,,,,,,,,,, -174400,4.1568527,0.67359555,,,,,,,,,,,,,, -174500,4.6605816,0.6577172,,,,,,,,,,,,,, -174600,4.7674756,0.75522214,,,,,,,,,,,,,, -174700,4.5866637,0.6431564,,,,,,,,,,,,,, -174800,5.062666,0.64174306,,,,,,,,,,,,,, -174900,4.850379,0.61377555,,,,,,,,,,,,,, -175000,4.9644203,0.6610252,,,,,,,,,,,,,, -175100,4.221117,0.62030137,,,,,,,,,,,,,, -175200,4.7332406,0.6729354,,,,,,,,,,,,,, -175297,,,0.9546197056770324,0.1615114212036132,0.7534799575805664,1.060207486152649,50000.0,0.628000020980835,1.851643443107605,10000.0,58711.03318500519,60822.650486946106,58711.03318500519,2100.2404623031616,5.118907690048218,0.0 -175300,4.333253,0.5969477,,,,,,,,,,,,,, -175400,4.5441456,0.6690797,,,,,,,,,,,,,, -175500,4.8595653,0.61717343,,,,,,,,,,,,,, -175600,4.736058,0.6547613,,,,,,,,,,,,,, -175700,4.739286,0.7424473,,,,,,,,,,,,,, -175800,4.725636,0.78974426,,,,,,,,,,,,,, -175900,4.6465836,0.6838891,,,,,,,,,,,,,, -176000,4.706835,0.7002587,,,,,,,,,,,,,, -176100,4.2259917,0.6153575,,,,,,,,,,,,,, -176200,4.3915854,0.67087334,,,,,,,,,,,,,, -176300,4.4761124,0.67665374,,,,,,,,,,,,,, -176400,4.5418987,0.5859632,,,,,,,,,,,,,, -176500,4.617876,0.6106389,,,,,,,,,,,,,, -176600,5.112998,0.6588256,,,,,,,,,,,,,, -176700,4.7240424,0.6461899,,,,,,,,,,,,,, -176800,5.046757,0.6596187,,,,,,,,,,,,,, -176821,,,0.9572106003761292,0.158622071146965,0.7540599703788757,1.0589282512664795,50000.0,0.6261000037193298,1.847219467163086,10000.0,59220.95730185509,61349.172367334366,59220.95730185509,2116.711015462876,5.191303253173828,0.0 -176900,4.4453783,0.64875734,,,,,,,,,,,,,, -177000,5.2761197,0.61846656,,,,,,,,,,,,,, -177100,4.396248,0.59128404,,,,,,,,,,,,,, -177200,4.274762,0.5832386,,,,,,,,,,,,,, -177300,4.4721704,0.65834177,,,,,,,,,,,,,, -177400,4.023632,0.6215932,,,,,,,,,,,,,, -177500,4.581212,0.6845817,,,,,,,,,,,,,, -177600,4.336203,0.65202767,,,,,,,,,,,,,, -177700,4.5149255,0.6330411,,,,,,,,,,,,,, -177800,4.925359,0.6856431,,,,,,,,,,,,,, -177900,4.6666627,0.63862914,,,,,,,,,,,,,, -178000,4.171452,0.5782809,,,,,,,,,,,,,, -178100,5.0228515,0.63539034,,,,,,,,,,,,,, -178200,4.2130513,0.6025311,,,,,,,,,,,,,, -178300,4.579269,0.6293534,,,,,,,,,,,,,, -178345,,,0.95804762840271,0.153394877910614,0.7551199793815613,1.0549086332321167,50000.0,0.6258000135421753,1.8444610834121704,10000.0,59730.89150452614,61875.924149274826,59730.89150452614,2133.417886257172,5.247831106185913,0.0 -178400,4.794644,0.59855336,,,,,,,,,,,,,, -178500,4.8718786,0.6022413,,,,,,,,,,,,,, -178600,4.36683,0.5945735,,,,,,,,,,,,,, -178700,4.3206816,0.59744906,,,,,,,,,,,,,, -178800,4.835785,0.66847616,,,,,,,,,,,,,, -178900,4.181367,0.5980505,,,,,,,,,,,,,, -179000,5.3903103,0.6236309,,,,,,,,,,,,,, -179100,4.4614406,0.6343087,,,,,,,,,,,,,, -179200,4.657704,0.60762626,,,,,,,,,,,,,, -179300,4.3096604,0.5882867,,,,,,,,,,,,,, -179400,4.706169,0.5850271,,,,,,,,,,,,,, -179500,4.725593,0.65385896,,,,,,,,,,,,,, -179600,4.3585744,0.6562314,,,,,,,,,,,,,, -179700,4.2095985,0.65097314,,,,,,,,,,,,,, -179800,4.512711,0.6145919,,,,,,,,,,,,,, -179869,,,0.9615951776504515,0.1446188241243362,0.7554000020027161,1.056932806968689,50000.0,0.6279000043869019,1.8462977409362795,10000.0,60240.81658220291,62402.48623228073,60240.81658220291,2149.941981315613,5.304441213607788,0.0 -179900,4.5713916,0.61094457,,,,,,,,,,,,,, -180000,4.7735953,0.58717716,,,,,,,,,,,,,, -180100,4.6710644,0.60038596,,,,,,,,,,,,,, -180200,4.406329,0.60692894,,,,,,,,,,,,,, -180300,4.4562583,0.57981396,,,,,,,,,,,,,, -180400,4.529541,0.6736008,,,,,,,,,,,,,, -180500,4.774353,0.7152056,,,,,,,,,,,,,, -180600,4.595582,0.66394544,,,,,,,,,,,,,, -180700,4.5347266,0.62256134,,,,,,,,,,,,,, -180800,4.4460382,0.6373657,,,,,,,,,,,,,, -180900,4.816585,0.6604863,,,,,,,,,,,,,, -181000,4.5567656,0.72499037,,,,,,,,,,,,,, -181100,4.4709377,0.566408,,,,,,,,,,,,,, -181200,4.764737,0.7016922,,,,,,,,,,,,,, -181300,4.581568,0.618261,,,,,,,,,,,,,, -181393,,,0.96000075340271,0.1480083018541336,0.7560999989509583,1.0549815893173218,50000.0,0.6279000043869019,1.8437385559082031,10000.0,60750.70178294182,62929.314858675,60750.70178294182,2166.7651064395905,5.369521379470825,0.0 -181400,4.727734,0.6042969,,,,,,,,,,,,,, -181500,4.22591,0.58877856,,,,,,,,,,,,,, -181600,4.252052,0.54712987,,,,,,,,,,,,,, -181700,4.6592584,0.6222678,,,,,,,,,,,,,, -181800,4.1528807,0.55724275,,,,,,,,,,,,,, -181900,4.5290437,0.6166196,,,,,,,,,,,,,, -182000,4.4387536,0.5880647,,,,,,,,,,,,,, -182100,4.257504,0.6436801,,,,,,,,,,,,,, -182200,4.81505,0.57408255,,,,,,,,,,,,,, -182300,4.5470953,0.62854457,,,,,,,,,,,,,, -182400,4.1819735,0.60560036,,,,,,,,,,,,,, -182500,4.821156,0.60977125,,,,,,,,,,,,,, -182600,4.969169,0.6456752,,,,,,,,,,,,,, -182700,4.2337804,0.557495,,,,,,,,,,,,,, -182800,5.108711,0.7160347,,,,,,,,,,,,,, -182900,4.5245733,0.6758719,,,,,,,,,,,,,, -182918,,,0.9614157676696776,0.1445799022912979,0.7551400065422058,1.0538650751113892,50000.0,0.6274000406265259,1.8420039415359497,10000.0,61260.74364876747,63456.022471666336,61260.74364876747,2183.3145368099213,5.4302978515625,0.0 -183000,4.5286684,0.5693959,,,,,,,,,,,,,, -183100,4.2463174,0.56756246,,,,,,,,,,,,,, -183200,4.465297,0.58437634,,,,,,,,,,,,,, -183300,4.543274,0.63312644,,,,,,,,,,,,,, -183400,4.993707,0.684513,,,,,,,,,,,,,, -183500,4.9638596,0.6343465,,,,,,,,,,,,,, -183600,4.691371,0.6072858,,,,,,,,,,,,,, -183700,4.722298,0.6759856,,,,,,,,,,,,,, -183800,4.3058953,0.583384,,,,,,,,,,,,,, -183900,4.208772,0.5936471,,,,,,,,,,,,,, -184000,4.7487864,0.66367406,,,,,,,,,,,,,, -184100,4.480052,0.57275844,,,,,,,,,,,,,, -184200,4.662003,0.57652724,,,,,,,,,,,,,, -184300,4.8272634,0.7087578,,,,,,,,,,,,,, -184400,4.136803,0.5910493,,,,,,,,,,,,,, -184442,,,0.960598647594452,0.1481765508651733,0.7552399635314941,1.054193139076233,50000.0,0.6267000436782837,1.8438823223114007,10000.0,61770.76652598381,63982.70475888252,61770.76652598381,2199.8553504943848,5.492738246917725,0.0 -184500,4.5667663,0.67993665,,,,,,,,,,,,,, -184600,4.4004993,0.66505635,,,,,,,,,,,,,, -184700,4.488286,0.6500674,,,,,,,,,,,,,, -184800,4.688794,0.6328016,,,,,,,,,,,,,, -184900,4.6220055,0.6014725,,,,,,,,,,,,,, -185000,4.6119466,0.5961825,,,,,,,,,,,,,, -185100,4.8226953,0.60449535,,,,,,,,,,,,,, -185200,4.571762,0.6734634,,,,,,,,,,,,,, -185300,4.9487433,0.6753925,,,,,,,,,,,,,, -185400,4.7282352,0.6603843,,,,,,,,,,,,,, -185500,4.5921974,0.64108336,,,,,,,,,,,,,, -185600,4.2059703,0.58208567,,,,,,,,,,,,,, -185700,4.2619114,0.61234784,,,,,,,,,,,,,, -185800,4.4503903,0.62167263,,,,,,,,,,,,,, -185900,4.3503804,0.627903,,,,,,,,,,,,,, -185967,,,0.9606983065605164,0.1485299617052078,0.7558599710464478,1.0535261631011963,50000.0,0.6271000504493713,1.8431257009506223,10000.0,62280.75594091416,64509.414836645126,62280.75594091416,2216.459883451462,5.554203510284424,0.0 -186000,4.778067,0.60371345,,,,,,,,,,,,,, -186100,4.977308,0.6357751,,,,,,,,,,,,,, -186200,4.5713487,0.61097646,,,,,,,,,,,,,, -186300,4.518051,0.62029994,,,,,,,,,,,,,, -186400,4.9012523,0.6137061,,,,,,,,,,,,,, -186500,4.318491,0.62837064,,,,,,,,,,,,,, -186600,4.661849,0.5677598,,,,,,,,,,,,,, -186700,4.32995,0.6013956,,,,,,,,,,,,,, -186800,4.371547,0.53644794,,,,,,,,,,,,,, -186900,5.0317907,0.68566597,,,,,,,,,,,,,, -187000,4.3030434,0.5940549,,,,,,,,,,,,,, -187100,4.843406,0.6112141,,,,,,,,,,,,,, -187200,4.781218,0.6129647,,,,,,,,,,,,,, -187300,4.358046,0.58343995,,,,,,,,,,,,,, -187400,4.4602294,0.6389829,,,,,,,,,,,,,, -187492,,,0.9605787396430968,0.146380066871643,0.7558000087738037,1.053703546524048,50000.0,0.626800000667572,1.8425025939941408,10000.0,62790.8897960186,65036.399695158005,62790.8897960186,2233.191954135895,5.617657661437988,0.0 -187500,5.0489936,0.5794029,,,,,,,,,,,,,, -187600,4.750179,0.62313515,,,,,,,,,,,,,, -187700,4.794527,0.6913767,,,,,,,,,,,,,, -187800,4.7721915,0.6252954,,,,,,,,,,,,,, -187900,4.611618,0.62060446,,,,,,,,,,,,,, -188000,4.748086,0.62442666,,,,,,,,,,,,,, -188100,4.677743,0.63898754,,,,,,,,,,,,,, -188200,4.6240582,0.58994085,,,,,,,,,,,,,, -188300,4.254281,0.6365951,,,,,,,,,,,,,, -188400,5.3593726,0.6568261,,,,,,,,,,,,,, -188500,4.295638,0.5919374,,,,,,,,,,,,,, -188600,4.916492,0.6572626,,,,,,,,,,,,,, -188700,4.056114,0.64890194,,,,,,,,,,,,,, -188800,4.163162,0.5376083,,,,,,,,,,,,,, -188900,4.680171,0.5777939,,,,,,,,,,,,,, -189000,4.3372717,0.59539926,,,,,,,,,,,,,, -189017,,,0.9619937539100648,0.14264976978302,0.7558599710464478,1.0539004802703855,50000.0,0.6281000375747681,1.843846201896668,10000.0,63300.8772277832,65563.01993250847,63300.8772277832,2249.7101545333862,5.676936149597168,0.0 -189100,4.63907,0.6244134,,,,,,,,,,,,,, -189200,4.320063,0.57692766,,,,,,,,,,,,,, -189300,4.4445977,0.61248946,,,,,,,,,,,,,, -189400,4.743136,0.63762265,,,,,,,,,,,,,, -189500,5.2750726,0.69084156,,,,,,,,,,,,,, -189600,4.6224875,0.5974082,,,,,,,,,,,,,, -189700,4.9778576,0.57800984,,,,,,,,,,,,,, -189800,4.6816893,0.6275601,,,,,,,,,,,,,, -189900,4.3281155,0.6136992,,,,,,,,,,,,,, -190000,4.9157653,0.68064857,,,,,,,,,,,,,, -190100,4.7221174,0.613418,,,,,,,,,,,,,, -190200,5.246296,0.63686574,,,,,,,,,,,,,, -190300,5.164946,0.54773074,,,,,,,,,,,,,, -190400,4.1689405,0.61148643,,,,,,,,,,,,,, -190500,4.669248,0.5665632,,,,,,,,,,,,,, -190542,,,0.9632294178009032,0.142451986670494,0.7558799982070923,1.0543040037155151,50000.0,0.6265000104904175,1.8426756858825684,10000.0,63811.00440359116,66089.88953256607,63811.00440359116,2266.336406469345,5.737735748291016,0.0 -190600,4.1446714,0.57617885,,,,,,,,,,,,,, -190700,4.690025,0.74229276,,,,,,,,,,,,,, -190800,4.9495096,0.6267536,,,,,,,,,,,,,, -190900,4.8152103,0.6953001,,,,,,,,,,,,,, -191000,4.596708,0.6358609,,,,,,,,,,,,,, -191100,5.1507916,0.6264835,,,,,,,,,,,,,, -191200,3.9694057,0.5730002,,,,,,,,,,,,,, -191300,4.5601172,0.61147016,,,,,,,,,,,,,, -191400,4.7284904,0.58547133,,,,,,,,,,,,,, -191500,4.1820636,0.63656604,,,,,,,,,,,,,, -191600,4.656747,0.61601794,,,,,,,,,,,,,, -191700,5.0000467,0.67736757,,,,,,,,,,,,,, -191800,5.237472,0.6135149,,,,,,,,,,,,,, -191900,5.0389338,0.58121234,,,,,,,,,,,,,, -192000,4.199538,0.63030916,,,,,,,,,,,,,, -192067,,,0.9621531963348388,0.1447095423936844,0.7556399703025818,1.0540560483932495,50000.0,0.6270000338554382,1.8448634147644043,10000.0,64321.091645240784,66616.57750272751,64321.091645240784,2282.8199610710144,5.799633979797363,0.0 -192100,4.7974296,0.62134844,,,,,,,,,,,,,, -192200,5.0760293,0.6013468,,,,,,,,,,,,,, -192300,4.805531,0.61365837,,,,,,,,,,,,,, -192400,4.5851026,0.67990834,,,,,,,,,,,,,, -192500,4.172595,0.5849544,,,,,,,,,,,,,, -192600,4.616903,0.55718887,,,,,,,,,,,,,, -192700,5.135476,0.619019,,,,,,,,,,,,,, -192800,4.6761317,0.65287375,,,,,,,,,,,,,, -192900,4.9368224,0.6257325,,,,,,,,,,,,,, -193000,4.3523455,0.5610512,,,,,,,,,,,,,, -193100,4.7113066,0.61789477,,,,,,,,,,,,,, -193200,4.619186,0.69966114,,,,,,,,,,,,,, -193300,4.368329,0.58964974,,,,,,,,,,,,,, -193400,4.3379664,0.5724737,,,,,,,,,,,,,, -193500,4.530273,0.61078846,,,,,,,,,,,,,, -193593,,,0.9602199792861938,0.1460577994585037,0.7560200095176697,1.0536680221557615,50000.0,0.6272000074386597,1.841867446899414,10000.0,64831.305763959885,67144.21295118332,64831.305763959885,2300.1227231025696,5.86241602897644,0.0 -193600,4.5805335,0.6174793,,,,,,,,,,,,,, -193700,4.7327304,0.61656326,,,,,,,,,,,,,, -193800,4.788758,0.5849665,,,,,,,,,,,,,, -193900,4.2671537,0.6340693,,,,,,,,,,,,,, -194000,4.2093124,0.60238594,,,,,,,,,,,,,, -194100,5.369802,0.71529716,,,,,,,,,,,,,, -194200,4.7314405,0.581465,,,,,,,,,,,,,, -194300,4.43318,0.5668912,,,,,,,,,,,,,, -194400,4.4015956,0.61849874,,,,,,,,,,,,,, -194500,5.1086073,0.6744321,,,,,,,,,,,,,, -194600,5.0582137,0.6388465,,,,,,,,,,,,,, -194700,4.2627463,0.58254564,,,,,,,,,,,,,, -194800,4.597193,0.6675541,,,,,,,,,,,,,, -194900,4.327204,0.5345559,,,,,,,,,,,,,, -195000,4.4655423,0.58867633,,,,,,,,,,,,,, -195100,4.1873727,0.5609428,,,,,,,,,,,,,, -195118,,,0.9604591727256776,0.1471811532974243,0.7560399770736694,1.053143858909607,50000.0,0.6266000270843506,1.84319007396698,10000.0,65341.203255176544,67670.93338441849,65341.203255176544,2316.826899766922,5.926252603530884,0.0 -195200,4.749942,0.6706478,,,,,,,,,,,,,, -195300,4.7269588,0.6408461,,,,,,,,,,,,,, -195400,4.5777946,0.65385985,,,,,,,,,,,,,, -195500,4.4115343,0.64296514,,,,,,,,,,,,,, -195600,4.6601143,0.63324344,,,,,,,,,,,,,, -195700,4.356062,0.6178953,,,,,,,,,,,,,, -195800,4.2449684,0.59277713,,,,,,,,,,,,,, -195900,4.812993,0.66769475,,,,,,,,,,,,,, -196000,4.3074875,0.6050923,,,,,,,,,,,,,, -196100,4.2793956,0.63898265,,,,,,,,,,,,,, -196200,5.04191,0.6116216,,,,,,,,,,,,,, -196300,4.9657054,0.61059374,,,,,,,,,,,,,, -196400,4.5405927,0.63349307,,,,,,,,,,,,,, -196500,4.901871,0.6302955,,,,,,,,,,,,,, -196600,4.49448,0.5774644,,,,,,,,,,,,,, -196643,,,0.961933970451355,0.1429389268159866,0.756119966506958,1.0525805950164795,50000.0,0.6277000308036804,1.840379357337952,10000.0,65851.13903975487,68197.4244966507,65851.13903975487,2333.264596939087,5.989895343780518,0.0 -196700,3.9386368,0.56966627,,,,,,,,,,,,,, -196800,4.2172103,0.591982,,,,,,,,,,,,,, -196900,4.020962,0.56418395,,,,,,,,,,,,,, -197000,4.584681,0.6052733,,,,,,,,,,,,,, -197100,4.3311977,0.61864954,,,,,,,,,,,,,, -197200,4.1830916,0.60706204,,,,,,,,,,,,,, -197300,5.1629725,0.5896218,,,,,,,,,,,,,, -197400,4.356447,0.61559,,,,,,,,,,,,,, -197500,4.646155,0.61781454,,,,,,,,,,,,,, -197600,4.878864,0.6350398,,,,,,,,,,,,,, -197700,4.183366,0.59089255,,,,,,,,,,,,,, -197800,4.780615,0.58091265,,,,,,,,,,,,,, -197900,4.4377475,0.6474062,,,,,,,,,,,,,, -198000,4.2623124,0.61224633,,,,,,,,,,,,,, -198100,4.4246492,0.58209395,,,,,,,,,,,,,, -198168,,,0.9618542790412904,0.1434179842472076,0.7559199929237366,1.052737832069397,50000.0,0.627500057220459,1.841537594795227,10000.0,66361.20763134956,68724.17261481285,66361.20763134956,2349.8306188583374,6.049070358276367,0.0 -198200,4.8691144,0.6514562,,,,,,,,,,,,,, -198300,4.221611,0.5723713,,,,,,,,,,,,,, -198400,4.567626,0.6152266,,,,,,,,,,,,,, -198500,4.3966904,0.6044664,,,,,,,,,,,,,, -198600,4.4192586,0.5472542,,,,,,,,,,,,,, -198700,4.8752265,0.6555702,,,,,,,,,,,,,, -198800,4.7731934,0.589688,,,,,,,,,,,,,, -198900,4.779182,0.5391842,,,,,,,,,,,,,, -199000,4.5928764,0.6758506,,,,,,,,,,,,,, -199100,4.3458347,0.63016486,,,,,,,,,,,,,, -199200,4.322171,0.59884703,,,,,,,,,,,,,, -199300,4.1592045,0.58740425,,,,,,,,,,,,,, -199400,4.5613337,0.5631384,,,,,,,,,,,,,, -199500,5.002876,0.68629384,,,,,,,,,,,,,, -199600,4.39432,0.58906895,,,,,,,,,,,,,, -199693,,,0.9616350531578064,0.1440503299236297,0.7558000087738037,1.0533536672592163,50000.0,0.626800000667572,1.8423495292663568,10000.0,66871.37352252007,69251.04145288467,66871.37352252007,2366.4155824184418,6.112751245498657,0.0 -199700,4.656458,0.57842624,,,,,,,,,,,,,, -199800,4.303844,0.590797,,,,,,,,,,,,,, -199900,4.9612303,0.6321973,,,,,,,,,,,,,, -200000,4.8878508,0.6325579,,,,,,,,,,,,,, -200100,4.3721056,0.549521,,,,,,,,,,,,,, -200200,4.8539524,0.7175001,,,,,,,,,,,,,, -200300,4.643678,0.63024026,,,,,,,,,,,,,, -200400,4.5456896,0.66668993,,,,,,,,,,,,,, -200500,4.7645206,0.6364316,,,,,,,,,,,,,, -200600,4.8625617,0.64127326,,,,,,,,,,,,,, -200700,4.330499,0.6750839,,,,,,,,,,,,,, -200800,5.2041736,0.64239985,,,,,,,,,,,,,, -200900,4.750402,0.622426,,,,,,,,,,,,,, -201000,4.7259293,0.5884366,,,,,,,,,,,,,, -201100,4.819882,0.7132729,,,,,,,,,,,,,, -201200,4.6205115,0.54992366,,,,,,,,,,,,,, -201217,,,0.960160195827484,0.1471266746520996,0.7561599612236023,1.0531275272369385,50000.0,0.6271000504493713,1.840630888938904,10000.0,67381.27967381477,69777.85696268082,67381.27967381477,2383.205046653748,6.177959680557251,0.0 -201300,4.631244,0.5907687,,,,,,,,,,,,,, -201400,4.4195423,0.5805301,,,,,,,,,,,,,, -201500,5.822747,0.6314508,,,,,,,,,,,,,, -201600,4.615281,0.5676573,,,,,,,,,,,,,, -201700,4.5775905,0.6613409,,,,,,,,,,,,,, -201800,4.5523167,0.60975015,,,,,,,,,,,,,, -201900,4.7159495,0.5974388,,,,,,,,,,,,,, -202000,4.3493867,0.6375743,,,,,,,,,,,,,, -202100,4.4899917,0.61220807,,,,,,,,,,,,,, -202200,4.613104,0.62681425,,,,,,,,,,,,,, -202300,4.732911,0.6877659,,,,,,,,,,,,,, -202400,4.943324,0.63822025,,,,,,,,,,,,,, -202500,4.396016,0.67444074,,,,,,,,,,,,,, -202600,4.4932528,0.665733,,,,,,,,,,,,,, -202700,4.783888,0.6015621,,,,,,,,,,,,,, -202742,,,0.9604591727256776,0.1461871117353439,0.755620002746582,1.0531686544418335,50000.0,0.6278000473976135,1.8418047428131104,10000.0,67891.42406487465,70304.85989117622,67891.42406487465,2399.942197084427,6.244342565536499,0.0 -202800,4.277577,0.5916049,,,,,,,,,,,,,, -202900,4.5646048,0.65296495,,,,,,,,,,,,,, -203000,4.8659515,0.5997405,,,,,,,,,,,,,, -203100,5.3364377,0.6939598,,,,,,,,,,,,,, -203200,4.4939218,0.652946,,,,,,,,,,,,,, -203300,4.8703175,0.5706502,,,,,,,,,,,,,, -203400,4.1982484,0.5698377,,,,,,,,,,,,,, -203500,4.3397875,0.60955507,,,,,,,,,,,,,, -203600,4.228259,0.64671993,,,,,,,,,,,,,, -203700,4.5485616,0.59918356,,,,,,,,,,,,,, -203800,4.834974,0.60701203,,,,,,,,,,,,,, -203900,4.9966555,0.5821848,,,,,,,,,,,,,, -204000,4.4280224,0.61551285,,,,,,,,,,,,,, -204100,4.580682,0.64485955,,,,,,,,,,,,,, -204200,4.6215744,0.6032283,,,,,,,,,,,,,, -204268,,,0.9620735049247742,0.1424334943294525,0.7561399936676025,1.0534887313842771,50000.0,0.6273000240325928,1.84238874912262,10000.0,68401.58019638062,70831.64112019539,68401.58019638062,2416.4498698711395,6.305903196334839,0.0 -204300,4.1808343,0.61940867,,,,,,,,,,,,,, -204400,4.8193183,0.63278514,,,,,,,,,,,,,, -204500,4.654717,0.6533071,,,,,,,,,,,,,, -204600,4.4016147,0.60277414,,,,,,,,,,,,,, -204700,4.869156,0.56769043,,,,,,,,,,,,,, -204800,4.373248,0.6221738,,,,,,,,,,,,,, -204900,4.6336803,0.67086816,,,,,,,,,,,,,, -205000,4.434231,0.58926183,,,,,,,,,,,,,, -205100,4.9845443,0.6624554,,,,,,,,,,,,,, -205200,4.673436,0.5902698,,,,,,,,,,,,,, -205300,4.2903347,0.64305866,,,,,,,,,,,,,, -205400,4.6235037,0.6012746,,,,,,,,,,,,,, -205500,4.5136204,0.61125284,,,,,,,,,,,,,, -205600,4.6719775,0.68128836,,,,,,,,,,,,,, -205700,4.79875,0.65232724,,,,,,,,,,,,,, -205793,,,0.9616150856018066,0.145903930068016,0.755840003490448,1.0521328449249268,50000.0,0.6279000043869019,1.840342879295349,10000.0,68911.47825837135,71358.17896771431,68911.47825837135,2432.962220907212,6.379040241241455,0.0 -205800,4.6975503,0.6863871,,,,,,,,,,,,,, -205900,5.0719376,0.63164353,,,,,,,,,,,,,, -206000,4.4312186,0.6218225,,,,,,,,,,,,,, -206100,4.9982834,0.6334209,,,,,,,,,,,,,, -206200,4.604789,0.6762661,,,,,,,,,,,,,, -206300,4.4492626,0.5946099,,,,,,,,,,,,,, -206400,4.595174,0.6171985,,,,,,,,,,,,,, -206500,4.2356796,0.5862456,,,,,,,,,,,,,, -206600,4.991721,0.6653282,,,,,,,,,,,,,, -206700,4.2261176,0.61431384,,,,,,,,,,,,,, -206800,5.2954316,0.6557501,,,,,,,,,,,,,, -206900,4.2422667,0.57798827,,,,,,,,,,,,,, -207000,4.48108,0.60445035,,,,,,,,,,,,,, -207100,4.5816197,0.61698097,,,,,,,,,,,,,, -207200,4.5634427,0.6786706,,,,,,,,,,,,,, -207300,4.7525005,0.6392671,,,,,,,,,,,,,, -207318,,,0.9592832922935486,0.1478613168001175,0.7560399770736694,1.0538548231124878,50000.0,0.6272000074386597,1.8419071435928345,10000.0,69421.45237731934,71884.97119355202,69421.45237731934,2449.6562552452087,6.44709062576294,0.0 -207400,4.692963,0.6924693,,,,,,,,,,,,,, -207500,4.79198,0.65837955,,,,,,,,,,,,,, -207600,4.764326,0.5583876,,,,,,,,,,,,,, -207700,4.5536327,0.58838123,,,,,,,,,,,,,, -207800,4.463266,0.6412239,,,,,,,,,,,,,, -207900,4.065365,0.5710458,,,,,,,,,,,,,, -208000,4.641369,0.63164264,,,,,,,,,,,,,, -208100,5.320874,0.67304957,,,,,,,,,,,,,, -208200,4.611995,0.60334563,,,,,,,,,,,,,, -208300,4.464503,0.5323915,,,,,,,,,,,,,, -208400,4.5234184,0.6751807,,,,,,,,,,,,,, -208500,4.653197,0.6134876,,,,,,,,,,,,,, -208600,4.6682587,0.6070131,,,,,,,,,,,,,, -208700,4.587603,0.6296314,,,,,,,,,,,,,, -208800,4.6796265,0.6017198,,,,,,,,,,,,,, -208843,,,0.9614756107330322,0.1456343531608581,0.7562199831008911,1.0526604652404783,50000.0,0.6274000406265259,1.8434849977493288,10000.0,69931.56385087967,72411.7939863205,69931.56385087967,2466.2402641773224,6.519393682479858,0.0 -208900,4.6443744,0.58461624,,,,,,,,,,,,,, -209000,4.6912045,0.612483,,,,,,,,,,,,,, -209100,4.9179144,0.60418755,,,,,,,,,,,,,, -209200,4.168638,0.62023985,,,,,,,,,,,,,, -209300,4.5745583,0.6207493,,,,,,,,,,,,,, -209400,4.5794816,0.65590155,,,,,,,,,,,,,, -209500,4.1339483,0.5952171,,,,,,,,,,,,,, -209600,4.2790337,0.58150995,,,,,,,,,,,,,, -209700,4.755843,0.6603086,,,,,,,,,,,,,, -209800,4.3785667,0.62276065,,,,,,,,,,,,,, -209900,4.3375397,0.59519506,,,,,,,,,,,,,, -210000,4.3149533,0.63166183,,,,,,,,,,,,,, -210100,5.716091,0.60730404,,,,,,,,,,,,,, -210200,4.4837356,0.5482886,,,,,,,,,,,,,, -210300,4.7554936,0.66988254,,,,,,,,,,,,,, -210368,,,0.9622129797935486,0.1451591402292251,0.7560799717903137,1.0537360906600952,50000.0,0.6283000111579895,1.842454433441162,10000.0,70441.48584985733,72938.48751044273,70441.48584985733,2482.894499540329,6.58108377456665,0.0 -210400,4.4029255,0.6734768,,,,,,,,,,,,,, -210500,4.720629,0.644292,,,,,,,,,,,,,, -210600,4.5162287,0.5956804,,,,,,,,,,,,,, -210700,4.835653,0.6809577,,,,,,,,,,,,,, -210800,4.8596797,0.6195995,,,,,,,,,,,,,, -210900,4.570009,0.63683414,,,,,,,,,,,,,, -211000,4.630522,0.64885783,,,,,,,,,,,,,, -211100,4.6466813,0.5787555,,,,,,,,,,,,,, -211200,4.9983096,0.6296696,,,,,,,,,,,,,, -211300,4.600572,0.6175081,,,,,,,,,,,,,, -211400,4.8222237,0.6274011,,,,,,,,,,,,,, -211500,4.5250187,0.63552225,,,,,,,,,,,,,, -211600,4.3882165,0.64882714,,,,,,,,,,,,,, -211700,4.7108874,0.5968639,,,,,,,,,,,,,, -211800,4.6019197,0.66834134,,,,,,,,,,,,,, -211893,,,0.962312638759613,0.1424638777971267,0.7555199861526489,1.0546331405639648,50000.0,0.6271000504493713,1.8440243005752563,10000.0,70951.64380955696,73465.29731369019,70951.64380955696,2499.4245150089264,6.648257732391357,0.0 -211900,4.114729,0.5492199,,,,,,,,,,,,,, -212000,4.6325374,0.64176774,,,,,,,,,,,,,, -212100,4.545258,0.6486004,,,,,,,,,,,,,, -212200,4.7658715,0.5972746,,,,,,,,,,,,,, -212300,4.166558,0.586886,,,,,,,,,,,,,, -212400,4.4837146,0.60655797,,,,,,,,,,,,,, -212500,4.91389,0.62821174,,,,,,,,,,,,,, -212600,4.834135,0.63447046,,,,,,,,,,,,,, -212700,4.386222,0.5897434,,,,,,,,,,,,,, -212800,4.9319253,0.6494665,,,,,,,,,,,,,, -212900,4.2270017,0.62837,,,,,,,,,,,,,, -213000,4.547633,0.64832985,,,,,,,,,,,,,, -213100,5.2652826,0.62380725,,,,,,,,,,,,,, -213200,4.319219,0.6254428,,,,,,,,,,,,,, -213300,4.7703676,0.58805335,,,,,,,,,,,,,, -213400,4.585982,0.5673613,,,,,,,,,,,,,, -213418,,,0.9602997303009032,0.1486295908689499,0.7561999559402466,1.0539764165878296,50000.0,0.6278000473976135,1.843787789344788,10000.0,71461.56416463852,73991.9448082447,71461.56416463852,2516.0319879055023,6.712458610534668,0.0 -213500,4.4009285,0.65057,,,,,,,,,,,,,, -213600,4.4711614,0.66359293,,,,,,,,,,,,,, -213700,4.245042,0.5582308,,,,,,,,,,,,,, -213800,5.850785,0.60213894,,,,,,,,,,,,,, -213900,4.4652176,0.6356415,,,,,,,,,,,,,, -214000,4.449743,0.5707801,,,,,,,,,,,,,, -214100,4.0964313,0.5250257,,,,,,,,,,,,,, -214200,4.5503273,0.63406914,,,,,,,,,,,,,, -214300,4.248073,0.5345737,,,,,,,,,,,,,, -214400,4.313966,0.5639684,,,,,,,,,,,,,, -214500,4.3530307,0.59765065,,,,,,,,,,,,,, -214600,4.7956524,0.61498284,,,,,,,,,,,,,, -214700,5.1246333,0.6372053,,,,,,,,,,,,,, -214800,5.1756253,0.605952,,,,,,,,,,,,,, -214900,4.8100066,0.6504768,,,,,,,,,,,,,, -214943,,,0.9592633843421936,0.1488154977560043,0.7554599642753601,1.0519776344299316,50000.0,0.626800000667572,1.8391730785369875,10000.0,71971.70177531242,74519.00781488419,71971.70177531242,2532.8387157917023,6.777054309844971,0.0 -215000,4.5166073,0.63843787,,,,,,,,,,,,,, -215100,4.672736,0.65411055,,,,,,,,,,,,,, -215200,4.313977,0.5956601,,,,,,,,,,,,,, -215300,4.203008,0.5303398,,,,,,,,,,,,,, -215400,4.9850864,0.6502285,,,,,,,,,,,,,, -215500,4.658325,0.63382363,,,,,,,,,,,,,, -215600,4.647406,0.64263463,,,,,,,,,,,,,, -215700,4.6915355,0.648244,,,,,,,,,,,,,, -215800,4.906663,0.7592658,,,,,,,,,,,,,, -215900,4.9430842,0.6803779,,,,,,,,,,,,,, -216000,4.567771,0.6177635,,,,,,,,,,,,,, -216100,4.631772,0.64307,,,,,,,,,,,,,, -216200,4.3802752,0.69111174,,,,,,,,,,,,,, -216300,4.637812,0.6531898,,,,,,,,,,,,,, -216400,5.025363,0.6140898,,,,,,,,,,,,,, -216468,,,0.961336076259613,0.145295962691307,0.756060004234314,1.0536319017410278,50000.0,0.6269000172615051,1.8440710306167605,10000.0,72481.71206188202,75045.74508166313,72481.71206188202,2549.4425597190857,6.845051527023315,0.0 -216500,4.6495385,0.56284946,,,,,,,,,,,,,, -216600,4.520073,0.5975642,,,,,,,,,,,,,, -216700,4.4845467,0.6317318,,,,,,,,,,,,,, -216800,4.3852487,0.60807306,,,,,,,,,,,,,, -216900,4.6169934,0.6140763,,,,,,,,,,,,,, -217000,4.6347313,0.6110676,,,,,,,,,,,,,, -217100,4.5931215,0.60581434,,,,,,,,,,,,,, -217200,4.7323604,0.5722888,,,,,,,,,,,,,, -217300,4.7531304,0.6420028,,,,,,,,,,,,,, -217400,4.7494955,0.6825088,,,,,,,,,,,,,, -217500,4.7528234,0.58583534,,,,,,,,,,,,,, -217600,4.9602695,0.66445845,,,,,,,,,,,,,, -217700,4.5941706,0.56685823,,,,,,,,,,,,,, -217800,4.674105,0.65343624,,,,,,,,,,,,,, -217900,4.7198377,0.6331351,,,,,,,,,,,,,, -217992,,,0.9616350531578064,0.1456587761640548,0.755899965763092,1.0531891584396362,50000.0,0.6264000535011292,1.841581106185913,10000.0,72991.72962641716,75572.34304332733,72991.72962641716,2565.9037024974823,6.909708738327026,0.0 -218000,6.332373,0.74418753,,,,,,,,,,,,,, -218100,4.226698,0.5993576,,,,,,,,,,,,,, -218200,4.783189,0.6592575,,,,,,,,,,,,,, -218300,4.5387764,0.6585045,,,,,,,,,,,,,, -218400,4.4090533,0.65283144,,,,,,,,,,,,,, -218500,3.9697804,0.50773835,,,,,,,,,,,,,, -218600,5.231098,0.6476697,,,,,,,,,,,,,, -218700,4.5615883,0.6239214,,,,,,,,,,,,,, -218800,4.915067,0.6136451,,,,,,,,,,,,,, -218900,4.157439,0.5819263,,,,,,,,,,,,,, -219000,4.7108135,0.6597141,,,,,,,,,,,,,, -219100,4.883922,0.6889219,,,,,,,,,,,,,, -219200,4.3635554,0.616084,,,,,,,,,,,,,, -219300,4.186877,0.5836036,,,,,,,,,,,,,, -219400,4.9086423,0.619913,,,,,,,,,,,,,, -219500,4.102337,0.56418055,,,,,,,,,,,,,, -219517,,,0.9599609375,0.1475105881690979,0.7558000087738037,1.052554965019226,50000.0,0.6279000043869019,1.8403915166854856,10000.0,73501.86399936676,76099.1445813179,73501.86399936676,2582.430620908737,6.99518346786499,0.0 -219600,4.618686,0.57127506,,,,,,,,,,,,,, -219700,4.7508125,0.64700043,,,,,,,,,,,,,, -219800,4.4023757,0.6084845,,,,,,,,,,,,,, -219900,4.2311893,0.6272022,,,,,,,,,,,,,, -220000,4.115531,0.58704054,,,,,,,,,,,,,, -220100,4.8054366,0.61222696,,,,,,,,,,,,,, -220200,5.0532312,0.68019867,,,,,,,,,,,,,, -220300,4.5287747,0.6587006,,,,,,,,,,,,,, -220400,4.213055,0.54975104,,,,,,,,,,,,,, -220500,4.449879,0.62512016,,,,,,,,,,,,,, -220600,4.6704874,0.5631055,,,,,,,,,,,,,, -220700,4.407319,0.6023921,,,,,,,,,,,,,, -220800,4.214423,0.57652205,,,,,,,,,,,,,, -220900,4.335378,0.6292566,,,,,,,,,,,,,, -221000,4.5589695,0.6609192,,,,,,,,,,,,,, -221041,,,0.9603196382522584,0.1453995704650879,0.7556799650192261,1.0543261766433716,50000.0,0.6272000074386597,1.843128561973572,10000.0,74011.82157802582,76625.97312402725,74011.82157802582,2599.18460726738,7.058007955551148,0.0 -221100,5.1393423,0.7302278,,,,,,,,,,,,,, -221200,4.709648,0.64585716,,,,,,,,,,,,,, -221300,4.6525836,0.6515411,,,,,,,,,,,,,, -221400,4.0411463,0.51706576,,,,,,,,,,,,,, -221500,4.920149,0.6039604,,,,,,,,,,,,,, -221600,4.328347,0.61556906,,,,,,,,,,,,,, -221700,4.55087,0.6060947,,,,,,,,,,,,,, -221800,4.475647,0.6294736,,,,,,,,,,,,,, -221900,4.243983,0.5822994,,,,,,,,,,,,,, -222000,4.403658,0.5809368,,,,,,,,,,,,,, -222100,4.57248,0.65642506,,,,,,,,,,,,,, -222200,4.4588466,0.6388,,,,,,,,,,,,,, -222300,4.376145,0.6690388,,,,,,,,,,,,,, -222400,4.5660768,0.68599606,,,,,,,,,,,,,, -222500,5.765942,0.6588327,,,,,,,,,,,,,, -222567,,,0.961355984210968,0.1447490602731704,0.7556399703025818,1.052621603012085,50000.0,0.6279000043869019,1.84185791015625,10000.0,74521.89327073097,77152.81997966766,74521.89327073097,2615.8400671482086,7.123383760452271,0.0 -222600,4.5968184,0.6105047,,,,,,,,,,,,,, -222700,4.5803704,0.60249877,,,,,,,,,,,,,, -222800,4.6598883,0.60151356,,,,,,,,,,,,,, -222900,4.5386205,0.54260266,,,,,,,,,,,,,, -223000,5.5204296,0.69205636,,,,,,,,,,,,,, -223100,5.044399,0.60440147,,,,,,,,,,,,,, -223200,4.7242613,0.6535729,,,,,,,,,,,,,, -223300,4.310483,0.6358491,,,,,,,,,,,,,, -223400,4.5396247,0.5132554,,,,,,,,,,,,,, -223500,4.450763,0.6752224,,,,,,,,,,,,,, -223600,4.7936554,0.65807563,,,,,,,,,,,,,, -223700,4.6043553,0.6726137,,,,,,,,,,,,,, -223800,4.328692,0.65887475,,,,,,,,,,,,,, -223900,5.322882,0.77429146,,,,,,,,,,,,,, -224000,4.2427945,0.56760347,,,,,,,,,,,,,, -224092,,,0.9618343114852904,0.1450876146554947,0.7557799816131592,1.0538315773010254,50000.0,0.6265000104904175,1.8424419164657595,10000.0,75031.88020396233,77679.52306723595,75031.88020396233,2632.435756444931,7.188803672790527,0.0 -224100,4.748082,0.68469703,,,,,,,,,,,,,, -224200,4.769686,0.6677333,,,,,,,,,,,,,, -224300,4.555495,0.55867404,,,,,,,,,,,,,, -224400,4.2414823,0.5932631,,,,,,,,,,,,,, -224500,4.3575487,0.591269,,,,,,,,,,,,,, -224600,4.466118,0.6388119,,,,,,,,,,,,,, -224700,4.1919165,0.58928704,,,,,,,,,,,,,, -224800,4.268874,0.5555041,,,,,,,,,,,,,, -224900,4.6923423,0.66405547,,,,,,,,,,,,,, -225000,4.681554,0.6606358,,,,,,,,,,,,,, -225100,4.482715,0.6288408,,,,,,,,,,,,,, -225200,4.384555,0.6845148,,,,,,,,,,,,,, -225300,4.221152,0.56628454,,,,,,,,,,,,,, -225400,4.638729,0.60468984,,,,,,,,,,,,,, -225500,4.9033303,0.68852735,,,,,,,,,,,,,, -225600,4.633722,0.6414955,,,,,,,,,,,,,, -225617,,,0.9606385231018066,0.1497254520654678,0.7558799982070923,1.053784728050232,50000.0,0.6282000541687012,1.8433343172073364,10000.0,75541.82284140587,78206.24735283852,75541.82284140587,2649.096224308014,7.255248308181763,0.0 -225700,4.7674356,0.5870052,,,,,,,,,,,,,, -225800,4.310015,0.59750706,,,,,,,,,,,,,, -225900,4.3993773,0.68661124,,,,,,,,,,,,,, -226000,4.730622,0.60238075,,,,,,,,,,,,,, -226100,4.5157614,0.6739578,,,,,,,,,,,,,, -226200,4.583623,0.5853354,,,,,,,,,,,,,, -226300,4.2382817,0.5741209,,,,,,,,,,,,,, -226400,4.8964376,0.5525401,,,,,,,,,,,,,, -226500,4.8605537,0.6759101,,,,,,,,,,,,,, -226600,4.5160146,0.68512493,,,,,,,,,,,,,, -226700,5.044348,0.6058378,,,,,,,,,,,,,, -226800,4.553723,0.6062326,,,,,,,,,,,,,, -226900,4.0592456,0.5854742,,,,,,,,,,,,,, -227000,4.874325,0.6343555,,,,,,,,,,,,,, -227100,4.49965,0.62783206,,,,,,,,,,,,,, -227143,,,0.9608577489852904,0.1467025727033615,0.7559599876403809,1.0523626804351809,50000.0,0.6274000406265259,1.8423489332199097,10000.0,76052.00051903725,78733.20663499832,76052.00051903725,2665.7563643455505,7.321617126464844,0.0 -227200,4.554543,0.6100862,,,,,,,,,,,,,, -227300,4.074432,0.5436452,,,,,,,,,,,,,, -227400,4.571905,0.5534626,,,,,,,,,,,,,, -227500,4.357079,0.6201374,,,,,,,,,,,,,, -227600,4.359911,0.57562506,,,,,,,,,,,,,, -227700,5.254115,0.63053805,,,,,,,,,,,,,, -227800,5.0624585,0.6090131,,,,,,,,,,,,,, -227900,4.7845654,0.62050414,,,,,,,,,,,,,, -228000,4.4705925,0.6244242,,,,,,,,,,,,,, -228100,4.763494,0.65384007,,,,,,,,,,,,,, -228200,4.2654777,0.61046064,,,,,,,,,,,,,, -228300,4.593439,0.66325885,,,,,,,,,,,,,, -228400,4.05041,0.5563814,,,,,,,,,,,,,, -228500,4.3084955,0.54627866,,,,,,,,,,,,,, -228600,4.976978,0.6996847,,,,,,,,,,,,,, -228667,,,0.962890625,0.1419253796339035,0.7560200095176697,1.0544745922088623,50000.0,0.6272000074386597,1.8436956405639648,10000.0,76561.98649954796,79259.84963774681,76561.98649954796,2682.2827892303467,7.397505044937134,0.0 -228700,5.4114285,0.6337555,,,,,,,,,,,,,, -228800,4.325077,0.60216004,,,,,,,,,,,,,, -228900,4.5642323,0.6306733,,,,,,,,,,,,,, -229000,4.1218314,0.55799884,,,,,,,,,,,,,, -229100,4.9453697,0.66040134,,,,,,,,,,,,,, -229200,4.556922,0.6223544,,,,,,,,,,,,,, -229300,4.4794965,0.616232,,,,,,,,,,,,,, -229400,4.1533856,0.5516706,,,,,,,,,,,,,, -229500,4.4898643,0.67184955,,,,,,,,,,,,,, -229600,4.2276807,0.6575822,,,,,,,,,,,,,, -229700,4.821733,0.5865245,,,,,,,,,,,,,, -229800,4.490629,0.5455759,,,,,,,,,,,,,, -229900,4.27791,0.60224867,,,,,,,,,,,,,, -230000,4.470223,0.584931,,,,,,,,,,,,,, -230100,4.887096,0.71043557,,,,,,,,,,,,,, -230192,,,0.962133288383484,0.1451293677091598,0.7562199831008911,1.053865909576416,50000.0,0.627500057220459,1.8432239294052124,10000.0,77071.98637270927,79786.77798008919,77071.98637270927,2699.087404489517,7.466514587402344,0.0 -230200,5.206189,0.6766154,,,,,,,,,,,,,, -230300,4.477676,0.5886464,,,,,,,,,,,,,, -230400,4.5120907,0.66671634,,,,,,,,,,,,,, -230500,4.3029637,0.65473807,,,,,,,,,,,,,, -230600,4.586443,0.58818316,,,,,,,,,,,,,, -230700,4.93755,0.6310895,,,,,,,,,,,,,, -230800,4.4757524,0.62950593,,,,,,,,,,,,,, -230900,4.326603,0.63094133,,,,,,,,,,,,,, -231000,4.7359962,0.6420288,,,,,,,,,,,,,, -231100,4.007981,0.52545834,,,,,,,,,,,,,, -231200,5.1395016,0.71227425,,,,,,,,,,,,,, -231300,4.4823604,0.61978495,,,,,,,,,,,,,, -231400,4.578143,0.6927129,,,,,,,,,,,,,, -231500,4.4258156,0.6294011,,,,,,,,,,,,,, -231600,4.528422,0.6638124,,,,,,,,,,,,,, -231700,4.8106575,0.6404534,,,,,,,,,,,,,, -231716,,,0.960957407951355,0.1444891989231109,0.7560799717903137,1.053759217262268,50000.0,0.6273000240325928,1.8422167301177976,10000.0,77581.86835837364,80314.55636286736,77581.86835837364,2716.8570907115936,7.537179708480835,0.0 -231800,4.928691,0.7008832,,,,,,,,,,,,,, -231900,4.562515,0.6417435,,,,,,,,,,,,,, -232000,4.3740697,0.5475471,,,,,,,,,,,,,, -232100,4.5709615,0.6210967,,,,,,,,,,,,,, -232200,4.580722,0.6611434,,,,,,,,,,,,,, -232300,4.665267,0.714219,,,,,,,,,,,,,, -232400,4.7074695,0.577452,,,,,,,,,,,,,, -232500,4.2192016,0.5781035,,,,,,,,,,,,,, -232600,4.340318,0.64869606,,,,,,,,,,,,,, -232700,4.7541637,0.6185765,,,,,,,,,,,,,, -232800,4.6489563,0.6446879,,,,,,,,,,,,,, -232900,4.030468,0.5757271,,,,,,,,,,,,,, -233000,4.315762,0.60913885,,,,,,,,,,,,,, -233100,4.9778733,0.5851484,,,,,,,,,,,,,, -233200,4.455233,0.6214115,,,,,,,,,,,,,, -233241,,,0.9602000713348388,0.1468252092599868,0.7556799650192261,1.0530420541763306,50000.0,0.6271000504493713,1.8426487445831297,10000.0,78091.93625879288,80841.24501657486,78091.93625879288,2733.3647408485413,7.594546318054199,0.0 -233300,4.31693,0.5793712,,,,,,,,,,,,,, -233400,4.5891566,0.6065791,,,,,,,,,,,,,, -233500,5.3935957,0.65687144,,,,,,,,,,,,,, -233600,4.516614,0.53146344,,,,,,,,,,,,,, -233700,4.5674458,0.65933305,,,,,,,,,,,,,, -233800,4.3365397,0.64571625,,,,,,,,,,,,,, -233900,4.1873817,0.5074326,,,,,,,,,,,,,, -234000,4.783894,0.57259876,,,,,,,,,,,,,, -234100,4.267869,0.5041272,,,,,,,,,,,,,, -234200,4.800155,0.5457546,,,,,,,,,,,,,, -234300,4.2341695,0.6160978,,,,,,,,,,,,,, -234400,4.706141,0.59725446,,,,,,,,,,,,,, -234500,4.808832,0.6304105,,,,,,,,,,,,,, -234600,4.702715,0.6089984,,,,,,,,,,,,,, -234700,4.67723,0.66425705,,,,,,,,,,,,,, -234766,,,0.961694836616516,0.1441184133291244,0.7562599778175354,1.0534793138504028,50000.0,0.6276000142097473,1.8427488803863523,10000.0,78601.9127805233,81368.13152265549,78601.9127805233,2750.1510181427,7.663392543792725,0.0 -234800,4.311533,0.63420045,,,,,,,,,,,,,, -234900,5.2083416,0.70298374,,,,,,,,,,,,,, -235000,4.845182,0.61776954,,,,,,,,,,,,,, -235100,5.685459,0.6567767,,,,,,,,,,,,,, -235200,4.2951303,0.58972037,,,,,,,,,,,,,, -235300,4.7921867,0.62890184,,,,,,,,,,,,,, -235400,5.003264,0.6423352,,,,,,,,,,,,,, -235500,4.587876,0.60989773,,,,,,,,,,,,,, -235600,4.5113754,0.6406089,,,,,,,,,,,,,, -235700,4.813362,0.6447365,,,,,,,,,,,,,, -235800,4.717925,0.61290646,,,,,,,,,,,,,, -235900,4.4247746,0.586136,,,,,,,,,,,,,, -236000,4.45355,0.6247788,,,,,,,,,,,,,, -236100,4.412606,0.6216599,,,,,,,,,,,,,, -236200,4.2492204,0.5874536,,,,,,,,,,,,,, -236291,,,0.961355984210968,0.1435891985893249,0.7560799717903137,1.0528885126113892,50000.0,0.6277000308036804,1.8410484790802,10000.0,79112.02974677086,81895.03840327263,79112.02974677086,2766.808699846268,7.740317106246948,0.0 -236300,4.561129,0.6230022,,,,,,,,,,,,,, -236400,4.4063873,0.59762096,,,,,,,,,,,,,, -236500,4.455272,0.66487044,,,,,,,,,,,,,, -236600,4.553278,0.6768767,,,,,,,,,,,,,, -236700,4.7508664,0.6473468,,,,,,,,,,,,,, -236800,4.505404,0.6384254,,,,,,,,,,,,,, -236900,4.0528927,0.54629576,,,,,,,,,,,,,, -237000,4.433602,0.6593646,,,,,,,,,,,,,, -237100,4.076115,0.63758504,,,,,,,,,,,,,, -237200,4.6301436,0.72341037,,,,,,,,,,,,,, -237300,4.332231,0.5742053,,,,,,,,,,,,,, -237400,5.268266,0.6476865,,,,,,,,,,,,,, -237500,4.720273,0.68009466,,,,,,,,,,,,,, -237600,4.587965,0.6856659,,,,,,,,,,,,,, -237700,4.8422146,0.6548518,,,,,,,,,,,,,, -237800,4.6866508,0.6874426,,,,,,,,,,,,,, -237816,,,0.961734652519226,0.1432908326387405,0.7560999989509583,1.054379105567932,50000.0,0.6279000043869019,1.8438612222671509,10000.0,79621.93251681328,82421.6057536602,79621.93251681328,2783.3465077877045,7.810421228408813,0.0 -237900,4.560271,0.5877424,,,,,,,,,,,,,, -238000,4.1645837,0.6159555,,,,,,,,,,,,,, -238100,4.57445,0.62276274,,,,,,,,,,,,,, -238200,4.5344667,0.5449572,,,,,,,,,,,,,, -238300,4.2285156,0.5103838,,,,,,,,,,,,,, -238400,4.214178,0.59211886,,,,,,,,,,,,,, -238500,4.498871,0.63489157,,,,,,,,,,,,,, -238600,4.5124884,0.67897224,,,,,,,,,,,,,, -238700,4.77429,0.6055565,,,,,,,,,,,,,, -238800,5.027693,0.6003463,,,,,,,,,,,,,, -238900,4.317276,0.64059126,,,,,,,,,,,,,, -239000,4.6296325,0.67711073,,,,,,,,,,,,,, -239100,4.856865,0.64828795,,,,,,,,,,,,,, -239200,4.442301,0.63826126,,,,,,,,,,,,,, -239300,4.2965527,0.5779584,,,,,,,,,,,,,, -239341,,,0.9607780575752258,0.1474478989839553,0.7555999755859375,1.0538862943649292,50000.0,0.6271000504493713,1.8431426286697388,10000.0,80132.0266327858,82948.37228608131,80132.0266327858,2799.8947973251343,7.878611326217651,0.0 -239400,4.519123,0.6077757,,,,,,,,,,,,,, -239500,4.317131,0.6162842,,,,,,,,,,,,,, -239600,4.508172,0.6752521,,,,,,,,,,,,,, -239700,5.0141225,0.6462634,,,,,,,,,,,,,, -239800,4.490554,0.5851557,,,,,,,,,,,,,, -239900,4.4182363,0.60363364,,,,,,,,,,,,,, -240000,4.6337886,0.6777457,,,,,,,,,,,,,, -240100,4.0477147,0.5211762,,,,,,,,,,,,,, -240200,5.4213715,0.69027126,,,,,,,,,,,,,, -240300,4.1962357,0.53726614,,,,,,,,,,,,,, -240400,4.763869,0.62421083,,,,,,,,,,,,,, -240500,4.589292,0.6537027,,,,,,,,,,,,,, -240600,4.0311227,0.5253806,,,,,,,,,,,,,, -240700,4.766544,0.67529905,,,,,,,,,,,,,, -240800,4.56489,0.6164154,,,,,,,,,,,,,, -240867,,,0.961316168308258,0.1430841088294983,0.7557399868965149,1.0552139282226562,50000.0,0.6279000043869019,1.843974471092224,10000.0,80642.20257234573,83475.19399499893,80642.20257234573,2816.4147934913635,7.948309659957886,0.0 -240900,4.4618573,0.6093195,,,,,,,,,,,,,, -241000,4.2393875,0.5712545,,,,,,,,,,,,,, -241100,4.914525,0.614087,,,,,,,,,,,,,, -241200,5.305566,0.6419169,,,,,,,,,,,,,, -241300,4.3915277,0.57928437,,,,,,,,,,,,,, -241400,4.3433642,0.6221343,,,,,,,,,,,,,, -241500,4.7671423,0.6113236,,,,,,,,,,,,,, -241600,4.383199,0.6262183,,,,,,,,,,,,,, -241700,4.7018385,0.6319055,,,,,,,,,,,,,, -241800,4.453774,0.60981405,,,,,,,,,,,,,, -241900,4.7438655,0.6659842,,,,,,,,,,,,,, -242000,4.6895323,0.58951455,,,,,,,,,,,,,, -242100,4.73566,0.5338721,,,,,,,,,,,,,, -242200,4.650126,0.6863195,,,,,,,,,,,,,, -242300,4.3497963,0.62360764,,,,,,,,,,,,,, -242392,,,0.961535394191742,0.145120620727539,0.7561399936676025,1.054344654083252,50000.0,0.6281000375747681,1.843250632286072,10000.0,81152.19447517395,84001.85370469093,81152.19447517395,2832.955675125122,8.019760608673096,0.0 -242400,4.2916703,0.5625854,,,,,,,,,,,,,, -242500,4.995604,0.64431816,,,,,,,,,,,,,, -242600,4.233731,0.5917307,,,,,,,,,,,,,, -242700,4.5639577,0.6432897,,,,,,,,,,,,,, -242800,4.762829,0.6487053,,,,,,,,,,,,,, -242900,4.64629,0.6383129,,,,,,,,,,,,,, -243000,4.6499567,0.6400499,,,,,,,,,,,,,, -243100,4.433269,0.5500029,,,,,,,,,,,,,, -243200,4.6294193,0.63411885,,,,,,,,,,,,,, -243300,4.4229445,0.56671554,,,,,,,,,,,,,, -243400,4.339429,0.60438406,,,,,,,,,,,,,, -243500,4.382505,0.6169247,,,,,,,,,,,,,, -243600,4.259821,0.6249678,,,,,,,,,,,,,, -243700,4.7848063,0.67810404,,,,,,,,,,,,,, -243800,4.812136,0.6537008,,,,,,,,,,,,,, -243900,4.3245373,0.5432939,,,,,,,,,,,,,, -243918,,,0.9616748690605164,0.1430436074733734,0.7561799883842468,1.0540276765823364,50000.0,0.6273000240325928,1.8438608646392824,10000.0,81662.36057519913,84528.89245724678,81662.36057519913,2849.7008497715,8.091844081878662,0.0 -244000,4.468274,0.69058967,,,,,,,,,,,,,, -244100,4.297966,0.6025113,,,,,,,,,,,,,, -244200,4.813727,0.60734224,,,,,,,,,,,,,, -244300,4.389643,0.5734941,,,,,,,,,,,,,, -244400,4.678243,0.677251,,,,,,,,,,,,,, -244500,4.833476,0.5733552,,,,,,,,,,,,,, -244600,4.2851806,0.6272423,,,,,,,,,,,,,, -244700,4.1486325,0.6303395,,,,,,,,,,,,,, -244800,4.7103705,0.68821347,,,,,,,,,,,,,, -244900,4.7419558,0.6528953,,,,,,,,,,,,,, -245000,4.4724245,0.6848822,,,,,,,,,,,,,, -245100,4.6027875,0.61088806,,,,,,,,,,,,,, -245200,4.37192,0.6125411,,,,,,,,,,,,,, -245300,4.731192,0.605879,,,,,,,,,,,,,, -245400,4.3873606,0.57326466,,,,,,,,,,,,,, -245443,,,0.9590441584587096,0.1470848023891449,0.755899965763092,1.0540157556533811,50000.0,0.6276000142097473,1.8433830738067627,10000.0,82172.32745957375,85055.66488361359,82172.32745957375,2866.384510755539,8.159352779388428,0.0 -245500,4.7176657,0.622706,,,,,,,,,,,,,, -245600,4.6598754,0.63140255,,,,,,,,,,,,,, -245700,4.7871375,0.62537867,,,,,,,,,,,,,, -245800,4.669308,0.6404827,,,,,,,,,,,,,, -245900,4.192938,0.62524575,,,,,,,,,,,,,, -246000,4.675858,0.6213865,,,,,,,,,,,,,, -246100,4.449722,0.55999184,,,,,,,,,,,,,, -246200,4.3874035,0.5372494,,,,,,,,,,,,,, -246300,5.389301,0.6725757,,,,,,,,,,,,,, -246400,5.7076807,0.6018783,,,,,,,,,,,,,, -246500,4.2925153,0.5743616,,,,,,,,,,,,,, -246600,4.315358,0.56843376,,,,,,,,,,,,,, -246700,4.578206,0.61551094,,,,,,,,,,,,,, -246800,4.0924473,0.5985607,,,,,,,,,,,,,, -246900,4.378364,0.6108043,,,,,,,,,,,,,, -246968,,,0.9604790806770324,0.1470487266778946,0.756060004234314,1.053966760635376,50000.0,0.6270000338554382,1.843010663986206,10000.0,82682.42796611786,85582.4812579155,82682.42796611786,2882.9736964702606,8.230422973632812,0.0 -247000,4.6398153,0.6629964,,,,,,,,,,,,,, -247100,4.2272296,0.6219789,,,,,,,,,,,,,, -247200,4.302164,0.6391885,,,,,,,,,,,,,, -247300,4.179651,0.6243262,,,,,,,,,,,,,, -247400,4.3550878,0.6436467,,,,,,,,,,,,,, -247500,4.4126954,0.59836733,,,,,,,,,,,,,, -247600,4.573876,0.5966098,,,,,,,,,,,,,, -247700,4.060799,0.5652076,,,,,,,,,,,,,, -247800,4.3481393,0.5872327,,,,,,,,,,,,,, -247900,4.612388,0.628278,,,,,,,,,,,,,, -248000,4.510288,0.6077297,,,,,,,,,,,,,, -248100,4.426469,0.59836555,,,,,,,,,,,,,, -248200,4.4140477,0.6318969,,,,,,,,,,,,,, -248300,4.85197,0.6864545,,,,,,,,,,,,,, -248400,4.2126064,0.5495397,,,,,,,,,,,,,, -248492,,,0.9622329473495485,0.1438832581043243,0.7556799650192261,1.054873824119568,50000.0,0.6277000308036804,1.8438180685043333,10000.0,83192.36039113998,86109.01463675499,83192.36039113998,2899.4492712020874,8.30149245262146,0.0 -248500,4.595107,0.66876566,,,,,,,,,,,,,, -248600,4.545738,0.66669756,,,,,,,,,,,,,, -248700,4.8929963,0.58442646,,,,,,,,,,,,,, -248800,4.467103,0.6180851,,,,,,,,,,,,,, -248900,4.3215885,0.581719,,,,,,,,,,,,,, -249000,5.059455,0.6390098,,,,,,,,,,,,,, -249100,4.7708216,0.6342909,,,,,,,,,,,,,, -249200,4.249172,0.60640156,,,,,,,,,,,,,, -249300,4.521028,0.6197088,,,,,,,,,,,,,, -249400,4.61227,0.6410406,,,,,,,,,,,,,, -249500,4.133993,0.55178803,,,,,,,,,,,,,, -249600,4.896564,0.646702,,,,,,,,,,,,,, -249700,4.62291,0.65732,,,,,,,,,,,,,, -249800,4.59378,0.683033,,,,,,,,,,,,,, -249900,5.052596,0.6017124,,,,,,,,,,,,,, -250000,4.350204,0.6053522,,,,,,,,,,,,,, -250017,,,0.9619140625,0.1431471705436706,0.756060004234314,1.0536892414093018,50000.0,0.6274000406265259,1.8433725833892824,10000.0,83702.24043631554,86635.69530034065,83702.24043631554,2916.1209042072296,8.375693798065186,0.0 -250100,4.67731,0.66912115,,,,,,,,,,,,,, -250200,4.85096,0.6914492,,,,,,,,,,,,,, -250300,4.6769814,0.62182766,,,,,,,,,,,,,, -250400,5.1613226,0.76820344,,,,,,,,,,,,,, -250500,4.583392,0.6757736,,,,,,,,,,,,,, -250600,4.1485934,0.587228,,,,,,,,,,,,,, -250700,4.3412695,0.65301573,,,,,,,,,,,,,, -250800,4.822074,0.6072915,,,,,,,,,,,,,, -250900,4.4017663,0.71566105,,,,,,,,,,,,,, -251000,4.534703,0.61094075,,,,,,,,,,,,,, -251100,4.6676946,0.5961642,,,,,,,,,,,,,, -251200,4.5570426,0.6788338,,,,,,,,,,,,,, -251300,4.326822,0.63669825,,,,,,,,,,,,,, -251400,4.0983005,0.566809,,,,,,,,,,,,,, -251500,4.951911,0.71724695,,,,,,,,,,,,,, -251542,,,0.960718274116516,0.1478823870420456,0.7559199929237366,1.0535964965820312,50000.0,0.6281000375747681,1.8413687944412231,10000.0,84212.27120828629,87162.43416666985,84212.27120828629,2932.704018354416,8.445951461791992,0.0 -251600,4.6488667,0.61812335,,,,,,,,,,,,,, -251700,4.794132,0.61069727,,,,,,,,,,,,,, -251800,4.2862077,0.5850079,,,,,,,,,,,,,, -251900,4.638602,0.68432677,,,,,,,,,,,,,, -252000,4.634264,0.66142607,,,,,,,,,,,,,, -252100,4.8125143,0.56755394,,,,,,,,,,,,,, -252200,4.2792625,0.5911171,,,,,,,,,,,,,, -252300,4.193281,0.6331488,,,,,,,,,,,,,, -252400,4.806395,0.6591733,,,,,,,,,,,,,, -252500,4.3399553,0.6186027,,,,,,,,,,,,,, -252600,4.6787286,0.6698258,,,,,,,,,,,,,, -252700,4.851439,0.6876218,,,,,,,,,,,,,, -252800,4.4962916,0.62698,,,,,,,,,,,,,, -252900,4.4063826,0.6244127,,,,,,,,,,,,,, -253000,4.1291194,0.5452157,,,,,,,,,,,,,, -253067,,,0.9599011540412904,0.1474985778331756,0.7556599974632263,1.054338455200195,50000.0,0.6267000436782837,1.843252420425415,10000.0,84722.27003097534,87689.12922596931,84722.27003097534,2949.2420732975006,8.548237800598145,0.0 -253100,5.146723,0.6998376,,,,,,,,,,,,,, -253200,4.690844,0.6647651,,,,,,,,,,,,,, -253300,5.181615,0.6793432,,,,,,,,,,,,,, -253400,4.824922,0.6135914,,,,,,,,,,,,,, -253500,4.3860583,0.6131132,,,,,,,,,,,,,, -253600,4.3156495,0.5788524,,,,,,,,,,,,,, -253700,4.406955,0.6253116,,,,,,,,,,,,,, -253800,4.5202327,0.64619505,,,,,,,,,,,,,, -253900,4.2951703,0.61081886,,,,,,,,,,,,,, -254000,4.55456,0.57027274,,,,,,,,,,,,,, -254100,4.725277,0.5574451,,,,,,,,,,,,,, -254200,4.5535226,0.61828715,,,,,,,,,,,,,, -254300,4.24867,0.5783616,,,,,,,,,,,,,, -254400,4.5848465,0.61509836,,,,,,,,,,,,,, -254500,4.5732737,0.56736,,,,,,,,,,,,,, -254592,,,0.9603196382522584,0.1466724723577499,0.7559599876403809,1.0540852546691897,50000.0,0.627500057220459,1.8423309326171875,10000.0,85232.41868042946,88216.00778889656,85232.41868042946,2965.849251270294,8.616014003753662,0.0 -254600,4.5351934,0.61323524,,,,,,,,,,,,,, -254700,4.852862,0.76057076,,,,,,,,,,,,,, -254800,4.5981717,0.6132605,,,,,,,,,,,,,, -254900,3.8999362,0.5178189,,,,,,,,,,,,,, -255000,4.540632,0.63372153,,,,,,,,,,,,,, -255100,4.373575,0.6556654,,,,,,,,,,,,,, -255200,4.8084173,0.67614514,,,,,,,,,,,,,, -255300,4.9925876,0.6299565,,,,,,,,,,,,,, -255400,4.7452083,0.6577286,,,,,,,,,,,,,, -255500,4.9162064,0.7044064,,,,,,,,,,,,,, -255600,4.428699,0.58191085,,,,,,,,,,,,,, -255700,4.6652513,0.6837827,,,,,,,,,,,,,, -255800,4.781039,0.6084866,,,,,,,,,,,,,, -255900,4.5957546,0.62739325,,,,,,,,,,,,,, -256000,4.5162287,0.6630572,,,,,,,,,,,,,, -256100,4.8754754,0.67920065,,,,,,,,,,,,,, -256117,,,0.9610570669174194,0.1451389640569687,0.7560799717903137,1.0543256998062134,50000.0,0.6273000240325928,1.8429205417633057,10000.0,85742.36588978767,88742.66361165047,85742.36588978767,2982.4352350234985,8.684788465499878,0.0 -256200,4.737348,0.5809171,,,,,,,,,,,,,, -256300,4.4099293,0.64496917,,,,,,,,,,,,,, -256400,4.048384,0.5863364,,,,,,,,,,,,,, -256500,5.535199,0.6343273,,,,,,,,,,,,,, -256600,4.2706184,0.5756607,,,,,,,,,,,,,, -256700,4.694576,0.6814279,,,,,,,,,,,,,, -256800,4.8813944,0.715544,,,,,,,,,,,,,, -256900,4.35906,0.66749513,,,,,,,,,,,,,, -257000,4.5975432,0.61553264,,,,,,,,,,,,,, -257100,4.4759307,0.62773985,,,,,,,,,,,,,, -257200,4.565339,0.6012658,,,,,,,,,,,,,, -257300,4.577997,0.6751992,,,,,,,,,,,,,, -257400,4.6805286,0.6420847,,,,,,,,,,,,,, -257500,4.3126917,0.62472606,,,,,,,,,,,,,, -257600,5.11594,0.64460766,,,,,,,,,,,,,, -257642,,,0.9624322056770324,0.1431660652160644,0.7559599876403809,1.0529989004135132,50000.0,0.626800000667572,1.842377066612244,10000.0,86252.392973423,89269.49858641624,86252.392973423,2999.1208930015564,8.751543760299683,0.0 -257700,4.3287077,0.6204251,,,,,,,,,,,,,, -257800,4.8519454,0.6553568,,,,,,,,,,,,,, -257900,3.9716527,0.5550877,,,,,,,,,,,,,, -258000,4.827025,0.60243046,,,,,,,,,,,,,, -258100,4.2484426,0.55239147,,,,,,,,,,,,,, -258200,4.6588707,0.6569184,,,,,,,,,,,,,, -258300,4.7572074,0.5940659,,,,,,,,,,,,,, -258400,4.696093,0.5713176,,,,,,,,,,,,,, -258500,4.29797,0.6257723,,,,,,,,,,,,,, -258600,4.607335,0.67058116,,,,,,,,,,,,,, -258700,5.7026744,0.61660266,,,,,,,,,,,,,, -258800,5.1215725,0.659361,,,,,,,,,,,,,, -258900,4.9977913,0.64753616,,,,,,,,,,,,,, -259000,4.553768,0.6331417,,,,,,,,,,,,,, -259100,5.0136905,0.66482687,,,,,,,,,,,,,, -259166,,,0.959402859210968,0.1494778543710708,0.7557399868965149,1.0539062023162842,50000.0,0.6267000436782837,1.8413981199264529,10000.0,86762.40819621086,89796.18056607246,86762.40819621086,3015.657775402069,8.82495379447937,0.0 -259200,4.6967626,0.6130215,,,,,,,,,,,,,, -259300,4.636855,0.65246814,,,,,,,,,,,,,, -259400,4.849182,0.70819956,,,,,,,,,,,,,, -259500,4.658605,0.6841762,,,,,,,,,,,,,, -259600,4.8302755,0.7020896,,,,,,,,,,,,,, -259700,4.6361036,0.6335727,,,,,,,,,,,,,, -259800,4.4885406,0.55121076,,,,,,,,,,,,,, -259900,4.1508784,0.56878847,,,,,,,,,,,,,, -260000,4.9252214,0.59712166,,,,,,,,,,,,,, -260100,4.343209,0.6014324,,,,,,,,,,,,,, -260200,4.185982,0.5402125,,,,,,,,,,,,,, -260300,4.6361613,0.588328,,,,,,,,,,,,,, -260400,4.5209565,0.5795659,,,,,,,,,,,,,, -260500,4.775785,0.62688124,,,,,,,,,,,,,, -260600,4.8819795,0.63783276,,,,,,,,,,,,,, -260691,,,0.9612762928009032,0.14541095495224,0.755899965763092,1.0525975227355957,50000.0,0.6272000074386597,1.8419866561889648,10000.0,87272.5882089138,90323.19861745834,87272.5882089138,3032.370623588562,8.895149946212769,0.0 -260700,4.374735,0.594708,,,,,,,,,,,,,, -260800,4.387408,0.5704021,,,,,,,,,,,,,, -260900,4.362649,0.6005903,,,,,,,,,,,,,, -261000,4.7598453,0.6184276,,,,,,,,,,,,,, -261100,4.4715023,0.62903106,,,,,,,,,,,,,, -261200,5.0188117,0.6237446,,,,,,,,,,,,,, -261300,4.9336853,0.64603996,,,,,,,,,,,,,, -261400,4.319525,0.5783607,,,,,,,,,,,,,, -261500,4.5730753,0.65162486,,,,,,,,,,,,,, -261600,5.0093646,0.6200587,,,,,,,,,,,,,, -261700,4.496904,0.6681188,,,,,,,,,,,,,, -261800,4.5796394,0.62405044,,,,,,,,,,,,,, -261900,4.5951395,0.70207614,,,,,,,,,,,,,, -262000,4.622803,0.59642714,,,,,,,,,,,,,, -262100,4.1834636,0.61606836,,,,,,,,,,,,,, -262200,4.5113134,0.61668473,,,,,,,,,,,,,, -262215,,,0.961136758327484,0.1466261595487594,0.7559399604797363,1.0539066791534424,50000.0,0.6269000172615051,1.8433120250701904,10000.0,87782.54205465317,90849.84784388542,87782.54205465317,3048.934982776642,8.971031188964844,0.0 -262300,5.029445,0.6868282,,,,,,,,,,,,,, -262400,4.8589377,0.6392681,,,,,,,,,,,,,, -262500,5.093133,0.67021143,,,,,,,,,,,,,, -262600,4.4954495,0.6079396,,,,,,,,,,,,,, -262700,4.646637,0.63618654,,,,,,,,,,,,,, -262800,4.9084516,0.6905469,,,,,,,,,,,,,, -262900,4.6769776,0.5801197,,,,,,,,,,,,,, -263000,4.249322,0.57375115,,,,,,,,,,,,,, -263100,4.6497626,0.60681915,,,,,,,,,,,,,, -263200,5.044744,0.64944804,,,,,,,,,,,,,, -263300,5.0170546,0.6148402,,,,,,,,,,,,,, -263400,4.8147726,0.6389142,,,,,,,,,,,,,, -263500,5.0982423,0.6101561,,,,,,,,,,,,,, -263600,4.613959,0.6155888,,,,,,,,,,,,,, -263700,4.861864,0.5981324,,,,,,,,,,,,,, -263740,,,0.9615154266357422,0.146799087524414,0.7556999921798706,1.0533528327941897,50000.0,0.6279000043869019,1.843176245689392,10000.0,88292.70301318169,91376.86492967606,88292.70301318169,3065.657157897949,9.04961919784546,0.0 -263800,4.6377535,0.58707875,,,,,,,,,,,,,, -263900,4.4653535,0.6038122,,,,,,,,,,,,,, -264000,4.5978584,0.6387858,,,,,,,,,,,,,, -264100,4.292287,0.62991804,,,,,,,,,,,,,, -264200,4.648652,0.57922965,,,,,,,,,,,,,, -264300,4.518424,0.6081145,,,,,,,,,,,,,, -264400,5.0163617,0.63652796,,,,,,,,,,,,,, -264500,4.604945,0.6779667,,,,,,,,,,,,,, -264600,4.8291297,0.63163257,,,,,,,,,,,,,, -264700,4.5352907,0.6109945,,,,,,,,,,,,,, -264800,4.7150836,0.6287221,,,,,,,,,,,,,, -264900,4.842271,0.6760926,,,,,,,,,,,,,, -265000,4.3031807,0.59385884,,,,,,,,,,,,,, -265100,4.3350134,0.6087962,,,,,,,,,,,,,, -265200,4.595827,0.6160716,,,,,,,,,,,,,, -265265,,,0.9604192972183228,0.1475045830011367,0.755840003490448,1.0530436038970947,50000.0,0.626800000667572,1.8413385152816768,10000.0,88802.68064022064,91903.60470747948,88802.68064022064,3082.2946906089783,9.11883282661438,0.0 -265300,4.856592,0.70531696,,,,,,,,,,,,,, -265400,4.466982,0.5535483,,,,,,,,,,,,,, -265500,4.5712633,0.61244,,,,,,,,,,,,,, -265600,4.454579,0.6335954,,,,,,,,,,,,,, -265700,4.6392264,0.6154289,,,,,,,,,,,,,, -265800,4.8235373,0.61746085,,,,,,,,,,,,,, -265900,4.5132303,0.62626296,,,,,,,,,,,,,, -266000,4.4473934,0.6733357,,,,,,,,,,,,,, -266100,4.1756816,0.659665,,,,,,,,,,,,,, -266200,4.76857,0.60415965,,,,,,,,,,,,,, -266300,4.9598203,0.6651953,,,,,,,,,,,,,, -266400,5.106262,0.6377451,,,,,,,,,,,,,, -266500,4.690888,0.6492738,,,,,,,,,,,,,, -266600,4.4088235,0.6388221,,,,,,,,,,,,,, -266700,4.4163194,0.5928323,,,,,,,,,,,,,, -266791,,,0.9616549611091614,0.1436686366796493,0.7555999755859375,1.053915023803711,50000.0,0.6271000504493713,1.8437012434005733,10000.0,89312.86989212036,92430.42082619669,89312.86989212036,3098.7946536540985,9.190809488296509,0.0 -266800,4.2416205,0.6935221,,,,,,,,,,,,,, -266900,4.5030475,0.63962674,,,,,,,,,,,,,, -267000,5.0458536,0.64055884,,,,,,,,,,,,,, -267100,4.598095,0.63563406,,,,,,,,,,,,,, -267200,4.7993627,0.66588575,,,,,,,,,,,,,, -267300,4.668628,0.5971296,,,,,,,,,,,,,, -267400,5.1226096,0.58854735,,,,,,,,,,,,,, -267500,4.201324,0.6073398,,,,,,,,,,,,,, -267600,4.482912,0.5960902,,,,,,,,,,,,,, -267700,4.534901,0.6134405,,,,,,,,,,,,,, -267800,4.5594583,0.5875435,,,,,,,,,,,,,, -267900,4.4898267,0.6932474,,,,,,,,,,,,,, -268000,4.7032638,0.65598553,,,,,,,,,,,,,, -268100,4.916646,0.63176936,,,,,,,,,,,,,, -268200,4.450625,0.5335462,,,,,,,,,,,,,, -268300,4.504006,0.66274494,,,,,,,,,,,,,, -268315,,,0.9625318646430968,0.1464009433984756,0.7557199597358704,1.0538593530654907,50000.0,0.6284000277519226,1.842480182647705,10000.0,89823.01347899437,92957.17655205728,89823.01347899437,3115.280136346817,9.261037111282349,0.0 -268400,4.3010807,0.58372396,,,,,,,,,,,,,, -268500,5.0232334,0.6406843,,,,,,,,,,,,,, -268600,4.507121,0.63547146,,,,,,,,,,,,,, -268700,4.348438,0.5742646,,,,,,,,,,,,,, -268800,4.8986793,0.6136981,,,,,,,,,,,,,, -268900,4.4620843,0.6320288,,,,,,,,,,,,,, -269000,4.2464466,0.6482028,,,,,,,,,,,,,, -269100,4.523979,0.6670839,,,,,,,,,,,,,, -269200,4.1792336,0.6465274,,,,,,,,,,,,,, -269300,4.70198,0.6632153,,,,,,,,,,,,,, -269400,4.416546,0.58415926,,,,,,,,,,,,,, -269500,4.0363107,0.54151034,,,,,,,,,,,,,, -269600,4.827672,0.6162126,,,,,,,,,,,,,, -269700,4.438334,0.60791224,,,,,,,,,,,,,, -269800,4.5543704,0.6493112,,,,,,,,,,,,,, -269840,,,0.9630300998687744,0.1411333978176117,0.7552599906921387,1.0552451610565186,50000.0,0.6276000142097473,1.8441766500473025,10000.0,90333.03701162338,93483.89601063728,90333.03701162338,3131.8267362117767,9.355196475982666,0.0 -269900,4.5178647,0.5902344,,,,,,,,,,,,,, -270000,4.34042,0.57602566,,,,,,,,,,,,,, -270100,4.670971,0.69076025,,,,,,,,,,,,,, -270200,4.249677,0.58687854,,,,,,,,,,,,,, -270300,4.490158,0.6073243,,,,,,,,,,,,,, -270400,5.1572523,0.71059203,,,,,,,,,,,,,, -270500,4.484278,0.6206589,,,,,,,,,,,,,, -270600,5.02467,0.6458339,,,,,,,,,,,,,, -270700,4.9957705,0.7168228,,,,,,,,,,,,,, -270800,4.711217,0.62328655,,,,,,,,,,,,,, -270900,4.452776,0.65074384,,,,,,,,,,,,,, -271000,4.1723375,0.55476516,,,,,,,,,,,,,, -271100,4.513071,0.6248825,,,,,,,,,,,,,, -271200,4.774862,0.6633736,,,,,,,,,,,,,, -271300,5.0046496,0.581584,,,,,,,,,,,,,, -271366,,,0.960180163383484,0.1452442556619644,0.7560399770736694,1.054051399230957,50000.0,0.626800000667572,1.844045639038086,10000.0,90843.17112255096,94011.56312346458,90843.17112255096,3149.2292981147766,9.430007934570312,0.0 -271400,4.5229144,0.6337649,,,,,,,,,,,,,, -271500,4.5408235,0.63127637,,,,,,,,,,,,,, -271600,4.620112,0.6223153,,,,,,,,,,,,,, -271700,4.4503946,0.62272847,,,,,,,,,,,,,, -271800,4.7723203,0.61777884,,,,,,,,,,,,,, -271900,4.8329844,0.62765706,,,,,,,,,,,,,, -272000,4.6992764,0.6286864,,,,,,,,,,,,,, -272100,4.097,0.58189803,,,,,,,,,,,,,, -272200,4.5368166,0.6186664,,,,,,,,,,,,,, -272300,5.3670855,0.6914526,,,,,,,,,,,,,, -272400,4.9820127,0.62100697,,,,,,,,,,,,,, -272500,4.448211,0.6407878,,,,,,,,,,,,,, -272600,4.918193,0.5811414,,,,,,,,,,,,,, -272700,4.4995556,0.6832592,,,,,,,,,,,,,, -272800,4.392799,0.6880696,,,,,,,,,,,,,, -272892,,,0.959363043308258,0.1494936347007751,0.7558000087738037,1.0542312860488892,50000.0,0.6278000473976135,1.842432737350464,10000.0,91353.35414123537,94538.45765447617,91353.35414123537,3165.8072340488434,9.508738994598389,0.0 -272900,4.318316,0.5816182,,,,,,,,,,,,,, -273000,5.061232,0.7076811,,,,,,,,,,,,,, -273100,4.7940855,0.60534286,,,,,,,,,,,,,, -273200,4.2499547,0.60815144,,,,,,,,,,,,,, -273300,4.3205595,0.5781424,,,,,,,,,,,,,, -273400,4.641359,0.63675517,,,,,,,,,,,,,, -273500,4.3478384,0.57618004,,,,,,,,,,,,,, -273600,4.271379,0.6041756,,,,,,,,,,,,,, -273700,5.050911,0.6812678,,,,,,,,,,,,,, -273800,4.7211967,0.63026184,,,,,,,,,,,,,, -273900,4.654699,0.62795776,,,,,,,,,,,,,, -274000,4.6534495,0.6829652,,,,,,,,,,,,,, -274100,4.71266,0.65751517,,,,,,,,,,,,,, -274200,4.57323,0.5835375,,,,,,,,,,,,,, -274300,4.6707506,0.62966716,,,,,,,,,,,,,, -274400,4.221119,0.6311662,,,,,,,,,,,,,, -274418,,,0.962312638759613,0.1411945521831512,0.7556999921798706,1.053430438041687,50000.0,0.6262000203132629,1.8420876264572144,10000.0,91863.55009460448,95065.35556936264,91863.55009460448,3182.3766729831696,9.585581302642822,0.0 -274500,4.374876,0.6597178,,,,,,,,,,,,,, -274600,4.5333257,0.6507046,,,,,,,,,,,,,, -274700,4.4968204,0.5877999,,,,,,,,,,,,,, -274800,4.558702,0.60926366,,,,,,,,,,,,,, -274900,3.9128916,0.5244758,,,,,,,,,,,,,, -275000,4.238487,0.5631741,,,,,,,,,,,,,, -275100,4.2607913,0.5697718,,,,,,,,,,,,,, -275200,4.502503,0.5968596,,,,,,,,,,,,,, -275300,5.466414,0.7060226,,,,,,,,,,,,,, -275400,4.319066,0.5536965,,,,,,,,,,,,,, -275500,4.412664,0.6252657,,,,,,,,,,,,,, -275600,4.421685,0.6288083,,,,,,,,,,,,,, -275700,4.187361,0.58943814,,,,,,,,,,,,,, -275800,4.4295254,0.61216056,,,,,,,,,,,,,, -275900,4.867925,0.66856426,,,,,,,,,,,,,, -275942,,,0.9617745280265808,0.1434799134731292,0.755620002746582,1.0534061193466189,50000.0,0.6269000172615051,1.843717694282532,10000.0,92373.46947193146,95591.95217609406,92373.46947193146,3198.919429302216,9.665200233459473,0.0 -276000,4.6355386,0.6731014,,,,,,,,,,,,,, -276100,4.76269,0.66883194,,,,,,,,,,,,,, -276200,4.360813,0.59784925,,,,,,,,,,,,,, -276300,4.1412253,0.5745738,,,,,,,,,,,,,, -276400,4.5231524,0.5735788,,,,,,,,,,,,,, -276500,4.911973,0.60818106,,,,,,,,,,,,,, -276600,4.5167556,0.6058935,,,,,,,,,,,,,, -276700,5.147902,0.6427676,,,,,,,,,,,,,, -276800,4.3990498,0.6281997,,,,,,,,,,,,,, -276900,5.3321857,0.62903154,,,,,,,,,,,,,, -277000,4.869804,0.7155188,,,,,,,,,,,,,, -277100,4.5988646,0.5960589,,,,,,,,,,,,,, -277200,4.5023923,0.7024741,,,,,,,,,,,,,, -277300,4.377043,0.59792405,,,,,,,,,,,,,, -277400,4.3402214,0.59796107,,,,,,,,,,,,,, -277466,,,0.961933970451355,0.1448633521795272,0.7560799717903137,1.0545258522033691,50000.0,0.6273000240325928,1.8438247442245483,10000.0,92883.548268795,96118.89766335487,92883.548268795,3215.657660007477,9.739282369613647,0.0 -277500,4.5944967,0.6462449,,,,,,,,,,,,,, -277600,4.711406,0.6575813,,,,,,,,,,,,,, -277700,4.623167,0.6243164,,,,,,,,,,,,,, -277800,4.4394255,0.62161225,,,,,,,,,,,,,, -277900,4.2040505,0.5407022,,,,,,,,,,,,,, -278000,4.489658,0.6071246,,,,,,,,,,,,,, -278100,4.3312535,0.5309782,,,,,,,,,,,,,, -278200,4.6712685,0.6549935,,,,,,,,,,,,,, -278300,4.499597,0.64321303,,,,,,,,,,,,,, -278400,4.629425,0.63341564,,,,,,,,,,,,,, -278500,4.322396,0.6004926,,,,,,,,,,,,,, -278600,4.5248585,0.53905696,,,,,,,,,,,,,, -278700,4.765097,0.6013468,,,,,,,,,,,,,, -278800,4.369496,0.6346735,,,,,,,,,,,,,, -278900,4.5048165,0.71291745,,,,,,,,,,,,,, -278989,,,0.960379421710968,0.1471159309148788,0.7558000087738037,1.0548198223114014,50000.0,0.6269000172615051,1.843254327774048,10000.0,93393.59517908096,96645.7634203434,93393.59517908096,3232.344386816025,9.817296981811523,0.0 -279000,5.23907,0.60885835,,,,,,,,,,,,,, -279100,4.4828186,0.6689822,,,,,,,,,,,,,, -279200,4.0039287,0.5185772,,,,,,,,,,,,,, -279300,4.9003263,0.6434202,,,,,,,,,,,,,, -279400,4.577628,0.64861614,,,,,,,,,,,,,, -279500,4.5505476,0.62873584,,,,,,,,,,,,,, -279600,4.1874475,0.597031,,,,,,,,,,,,,, -279700,5.196849,0.68414116,,,,,,,,,,,,,, -279800,4.296672,0.5894657,,,,,,,,,,,,,, -279900,4.7585726,0.6403233,,,,,,,,,,,,,, -280000,4.706943,0.6268222,,,,,,,,,,,,,, -280100,4.5703435,0.5890758,,,,,,,,,,,,,, -280200,4.355831,0.5600162,,,,,,,,,,,,,, -280300,4.861775,0.661134,,,,,,,,,,,,,, -280400,4.407025,0.6272265,,,,,,,,,,,,,, -280500,4.1300573,0.63898987,,,,,,,,,,,,,, -280511,,,0.9603196382522584,0.1468950361013412,0.755840003490448,1.0539928674697876,50000.0,0.6272000074386597,1.842954397201538,10000.0,93903.50009322166,97172.31008005142,93903.50009322166,3248.8543951511383,9.893784284591677,0.0 -280600,4.8185124,0.68381786,,,,,,,,,,,,,, -280700,4.477459,0.5800545,,,,,,,,,,,,,, -280800,4.484513,0.66179216,,,,,,,,,,,,,, -280900,4.2850313,0.59842366,,,,,,,,,,,,,, -281000,4.8710113,0.6207898,,,,,,,,,,,,,, -281100,4.745605,0.65321726,,,,,,,,,,,,,, -281200,4.411532,0.5884887,,,,,,,,,,,,,, -281300,4.622534,0.727375,,,,,,,,,,,,,, -281400,4.8195415,0.59219205,,,,,,,,,,,,,, -281500,4.5249605,0.6454804,,,,,,,,,,,,,, -281600,4.908253,0.6882852,,,,,,,,,,,,,, -281700,4.720685,0.61253977,,,,,,,,,,,,,, -281800,4.778465,0.67717105,,,,,,,,,,,,,, -281900,4.281706,0.57924724,,,,,,,,,,,,,, -282000,4.5287795,0.62612563,,,,,,,,,,,,,, -282035,,,0.9624720811843872,0.1407591849565506,0.7559399604797363,1.053431510925293,50000.0,0.6262000203132629,1.842752814292908,10000.0,94413.65333604813,97699.14486837389,94413.65333604813,3265.3635454177856,10.01119351387024,0.0 -282100,4.708757,0.6164996,,,,,,,,,,,,,, -282200,4.6353965,0.6470256,,,,,,,,,,,,,, -282300,4.761958,0.61898446,,,,,,,,,,,,,, -282400,4.2578382,0.54213715,,,,,,,,,,,,,, -282500,4.4659257,0.62546456,,,,,,,,,,,,,, -282600,4.7670856,0.58443195,,,,,,,,,,,,,, -282700,4.489568,0.60245126,,,,,,,,,,,,,, -282800,5.1965203,0.61033523,,,,,,,,,,,,,, -282900,4.7186203,0.57908434,,,,,,,,,,,,,, -283000,4.346937,0.6083392,,,,,,,,,,,,,, -283100,4.6124616,0.66495264,,,,,,,,,,,,,, -283200,4.9463553,0.6991823,,,,,,,,,,,,,, -283300,4.593094,0.58173525,,,,,,,,,,,,,, -283400,4.174962,0.5514657,,,,,,,,,,,,,, -283500,4.1201534,0.58685213,,,,,,,,,,,,,, -283558,,,0.9611965417861938,0.1465084254741668,0.756060004234314,1.053581476211548,50000.0,0.6278000473976135,1.8434096574783323,10000.0,94923.66025543211,98225.74566698074,94923.66025543211,3281.825499534607,10.0881450176239,0.0 -283600,4.4363832,0.7048811,,,,,,,,,,,,,, -283700,4.556039,0.61855954,,,,,,,,,,,,,, -283800,4.4079056,0.6311674,,,,,,,,,,,,,, -283900,4.1417737,0.5584312,,,,,,,,,,,,,, -284000,4.2547784,0.5533608,,,,,,,,,,,,,, -284100,4.40315,0.561631,,,,,,,,,,,,,, -284200,4.7252603,0.5931463,,,,,,,,,,,,,, -284300,4.6666126,0.5956985,,,,,,,,,,,,,, -284400,5.253418,0.7341809,,,,,,,,,,,,,, -284500,4.5980124,0.62973154,,,,,,,,,,,,,, -284600,4.664675,0.646035,,,,,,,,,,,,,, -284700,4.843252,0.6240209,,,,,,,,,,,,,, -284800,5.4134917,0.6183202,,,,,,,,,,,,,, -284900,4.5483646,0.6153055,,,,,,,,,,,,,, -285000,4.5700088,0.54768324,,,,,,,,,,,,,, -285082,,,0.9600207209587096,0.1455494463443756,0.7560200095176697,1.0529658794403076,50000.0,0.6277000308036804,1.8420534133911133,10000.0,95433.71224570274,98752.69161009789,95433.71224570274,3298.5862398147583,10.166528940200806,0.0 -285100,4.050741,0.57521236,,,,,,,,,,,,,, -285200,4.001466,0.5988444,,,,,,,,,,,,,, -285300,4.9344296,0.63388693,,,,,,,,,,,,,, -285400,4.4078383,0.5708188,,,,,,,,,,,,,, -285500,4.931127,0.76336706,,,,,,,,,,,,,, -285600,4.3721113,0.5382273,,,,,,,,,,,,,, -285700,4.4635124,0.6337477,,,,,,,,,,,,,, -285800,4.7121754,0.6215968,,,,,,,,,,,,,, -285900,5.1770563,0.6057423,,,,,,,,,,,,,, -286000,4.89976,0.59820086,,,,,,,,,,,,,, -286100,4.5275855,0.6186386,,,,,,,,,,,,,, -286200,4.52027,0.5673502,,,,,,,,,,,,,, -286300,4.2555437,0.65351397,,,,,,,,,,,,,, -286400,4.541202,0.6025474,,,,,,,,,,,,,, -286500,4.6344123,0.6607534,,,,,,,,,,,,,, -286600,4.1254683,0.5529026,,,,,,,,,,,,,, -286605,,,0.9612563848495485,0.1459049135446548,0.7559799551963806,1.054235577583313,50000.0,0.6269000172615051,1.844022750854492,10000.0,95943.7070596218,99279.29757618904,95943.7070596218,3315.0661194324493,10.243273496627808,0.0 -286700,4.4530706,0.65128636,,,,,,,,,,,,,, -286800,4.454782,0.6981395,,,,,,,,,,,,,, -286900,5.010634,0.6003062,,,,,,,,,,,,,, -287000,4.129536,0.6141185,,,,,,,,,,,,,, -287100,4.2633443,0.64324325,,,,,,,,,,,,,, -287200,4.45717,0.6805685,,,,,,,,,,,,,, -287300,5.069561,0.6848858,,,,,,,,,,,,,, -287400,4.1529975,0.56542766,,,,,,,,,,,,,, -287500,5.1086693,0.6680931,,,,,,,,,,,,,, -287600,4.439696,0.6676748,,,,,,,,,,,,,, -287700,4.262868,0.58546823,,,,,,,,,,,,,, -287800,4.676843,0.6745992,,,,,,,,,,,,,, -287900,4.687344,0.6488604,,,,,,,,,,,,,, -288000,4.9380493,0.64065903,,,,,,,,,,,,,, -288100,4.5323396,0.595779,,,,,,,,,,,,,, -288129,,,0.9620535373687744,0.1462266147136688,0.7553799748420715,1.0544589757919312,50000.0,0.6270000338554382,1.8442836999893188,10000.0,96453.7833352089,99806.06667613985,96453.7833352089,3331.6242294311523,10.32362985610962,0.0 -288200,4.590095,0.6795979,,,,,,,,,,,,,, -288300,4.6718364,0.61412233,,,,,,,,,,,,,, -288400,4.776696,0.72969973,,,,,,,,,,,,,, -288500,4.386845,0.5949565,,,,,,,,,,,,,, -288600,4.681433,0.6766757,,,,,,,,,,,,,, -288700,4.7256246,0.63086385,,,,,,,,,,,,,, -288800,4.872753,0.55214095,,,,,,,,,,,,,, -288900,4.2961903,0.6433983,,,,,,,,,,,,,, -289000,4.13696,0.55651474,,,,,,,,,,,,,, -289100,4.708814,0.7184869,,,,,,,,,,,,,, -289200,4.9511,0.625328,,,,,,,,,,,,,, -289300,4.3707585,0.6316648,,,,,,,,,,,,,, -289400,4.9099264,0.650134,,,,,,,,,,,,,, -289500,4.5484543,0.6561781,,,,,,,,,,,,,, -289600,4.584152,0.667183,,,,,,,,,,,,,, -289653,,,0.9625717401504515,0.1420051157474517,0.7561999559402466,1.053595781326294,50000.0,0.6270000338554382,1.843524932861328,10000.0,96963.97358345984,100333.07609534264,96963.97358345984,3348.3108434677124,10.402316093444824,0.0 -289700,4.0910673,0.6245397,,,,,,,,,,,,,, -289800,4.393037,0.63124454,,,,,,,,,,,,,, -289900,4.8685513,0.6016013,,,,,,,,,,,,,, -290000,4.948643,0.6092464,,,,,,,,,,,,,, -290100,4.4225307,0.56531733,,,,,,,,,,,,,, -290200,4.820553,0.66681397,,,,,,,,,,,,,, -290300,4.23144,0.6100167,,,,,,,,,,,,,, -290400,4.554235,0.62373567,,,,,,,,,,,,,, -290500,4.595616,0.5736986,,,,,,,,,,,,,, -290600,5.3840666,0.6382864,,,,,,,,,,,,,, -290700,4.4798927,0.6520901,,,,,,,,,,,,,, -290800,4.460582,0.6654938,,,,,,,,,,,,,, -290900,4.1731596,0.5791789,,,,,,,,,,,,,, -291000,4.467194,0.63825375,,,,,,,,,,,,,, -291100,4.737364,0.6943321,,,,,,,,,,,,,, -291176,,,0.9592633843421936,0.1484569460153579,0.7561599612236023,1.0529533624649048,50000.0,0.6282000541687012,1.8418251276016235,10000.0,97473.8518486023,100859.69595813753,97473.8518486023,3364.91783618927,10.482990264892578,0.0 -291200,4.784255,0.6591174,,,,,,,,,,,,,, -291300,4.9478955,0.6184526,,,,,,,,,,,,,, -291400,4.4179196,0.61810267,,,,,,,,,,,,,, -291500,5.2564287,0.60134554,,,,,,,,,,,,,, -291600,5.100582,0.6838043,,,,,,,,,,,,,, -291700,4.5557914,0.61322695,,,,,,,,,,,,,, -291800,4.404966,0.6160487,,,,,,,,,,,,,, -291900,4.459108,0.65879726,,,,,,,,,,,,,, -292000,4.5871754,0.6157361,,,,,,,,,,,,,, -292100,4.5316396,0.65028816,,,,,,,,,,,,,, -292200,4.383114,0.57432944,,,,,,,,,,,,,, -292300,4.6005416,0.60974544,,,,,,,,,,,,,, -292400,4.4864326,0.60271424,,,,,,,,,,,,,, -292500,5.1537304,0.64814895,,,,,,,,,,,,,, -292600,5.0842867,0.73382735,,,,,,,,,,,,,, -292699,,,0.9601402878761292,0.1482071578502655,0.7561799883842468,1.052278757095337,50000.0,0.6266000270843506,1.842485427856445,10000.0,97983.82679986954,101386.6808810234,97983.82679986954,3381.796614646912,10.558565378189089,0.0 -292700,4.267696,0.59591424,,,,,,,,,,,,,, -292800,4.5692024,0.61795896,,,,,,,,,,,,,, -292900,4.3713555,0.663043,,,,,,,,,,,,,, -293000,4.4132533,0.5774296,,,,,,,,,,,,,, -293100,4.5439444,0.54835,,,,,,,,,,,,,, -293200,4.644828,0.56053215,,,,,,,,,,,,,, -293300,3.9883687,0.5820076,,,,,,,,,,,,,, -293400,4.9378386,0.6099008,,,,,,,,,,,,,, -293500,4.214818,0.55016404,,,,,,,,,,,,,, -293600,4.5863767,0.68904537,,,,,,,,,,,,,, -293700,4.4859033,0.59567326,,,,,,,,,,,,,, -293800,4.9056287,0.68306834,,,,,,,,,,,,,, -293900,4.9923487,0.7341775,,,,,,,,,,,,,, -294000,4.2606387,0.6109233,,,,,,,,,,,,,, -294100,4.955699,0.64181143,,,,,,,,,,,,,, -294200,4.2908554,0.629679,,,,,,,,,,,,,, -294222,,,0.960718274116516,0.1471677124500274,0.7562400102615356,1.0532735586166382,50000.0,0.626300036907196,1.8423364162445068,10000.0,98493.7816810608,101913.36937975883,98493.7816810608,3398.401019334793,10.63289713859558,0.0 -294300,4.9647846,0.65469927,,,,,,,,,,,,,, -294400,4.424223,0.5992075,,,,,,,,,,,,,, -294500,4.3366003,0.6068509,,,,,,,,,,,,,, -294600,4.449832,0.6650643,,,,,,,,,,,,,, -294700,4.3711524,0.60131913,,,,,,,,,,,,,, -294800,5.074647,0.63590705,,,,,,,,,,,,,, -294900,4.2753224,0.635441,,,,,,,,,,,,,, -295000,4.6393423,0.6338775,,,,,,,,,,,,,, -295100,4.4423594,0.63692635,,,,,,,,,,,,,, -295200,4.685819,0.6632873,,,,,,,,,,,,,, -295300,5.112066,0.6091541,,,,,,,,,,,,,, -295400,4.4468937,0.59382117,,,,,,,,,,,,,, -295500,4.745668,0.6718295,,,,,,,,,,,,,, -295600,4.746931,0.63990605,,,,,,,,,,,,,, -295700,4.416452,0.6196154,,,,,,,,,,,,,, -295746,,,0.9618542790412904,0.144152745604515,0.756060004234314,1.0529661178588867,50000.0,0.6274000406265259,1.8412050008773804,10000.0,99003.77694940568,102440.0279521942,99003.77694940568,3414.94277882576,10.699851036071776,0.0 -295800,4.6315928,0.72553796,,,,,,,,,,,,,, -295900,4.2751427,0.64265996,,,,,,,,,,,,,, -296000,4.327942,0.5617002,,,,,,,,,,,,,, -296100,4.804827,0.69987893,,,,,,,,,,,,,, -296200,4.685161,0.6163278,,,,,,,,,,,,,, -296300,4.7997556,0.53163,,,,,,,,,,,,,, -296400,4.259522,0.5796936,,,,,,,,,,,,,, -296500,4.261299,0.63342744,,,,,,,,,,,,,, -296600,4.7892914,0.7248661,,,,,,,,,,,,,, -296700,4.6037154,0.63456017,,,,,,,,,,,,,, -296800,4.39673,0.60541517,,,,,,,,,,,,,, -296900,4.90516,0.6043066,,,,,,,,,,,,,, -297000,5.369972,0.7402449,,,,,,,,,,,,,, -297100,4.693268,0.6660847,,,,,,,,,,,,,, -297200,4.649361,0.620913,,,,,,,,,,,,,, -297270,,,0.9604192972183228,0.145902469754219,0.7557599544525146,1.05452299118042,50000.0,0.6274000406265259,1.8436675071716309,10000.0,99513.71591758728,102966.62348175047,99513.71591758728,3431.464282512665,10.78010392189026,0.0 -297300,5.0319467,0.56827587,,,,,,,,,,,,,, -297400,4.8288536,0.6428986,,,,,,,,,,,,,, -297500,4.454829,0.6681965,,,,,,,,,,,,,, -297600,4.509408,0.57131124,,,,,,,,,,,,,, -297700,4.988337,0.6743453,,,,,,,,,,,,,, -297800,4.6380363,0.6135789,,,,,,,,,,,,,, -297900,4.7686243,0.6749674,,,,,,,,,,,,,, -298000,4.1109104,0.58587706,,,,,,,,,,,,,, -298100,5.2433496,0.64729285,,,,,,,,,,,,,, -298200,4.43873,0.6535495,,,,,,,,,,,,,, -298300,4.7234945,0.6141205,,,,,,,,,,,,,, -298400,4.195497,0.5733954,,,,,,,,,,,,,, -298500,5.0298257,0.65896463,,,,,,,,,,,,,, -298600,4.6955504,0.5884378,,,,,,,,,,,,,, -298700,4.7702928,0.58419234,,,,,,,,,,,,,, -298794,,,0.9609375,0.145173043012619,0.7562800049781799,1.0528514385223389,50000.0,0.6273000240325928,1.8408688306808472,10000.0,100023.7195968628,103493.4132258892,100023.7195968628,3448.119199991226,10.856154441833496,0.0 -298800,4.7250986,0.6604372,,,,,,,,,,,,,, -298900,4.55055,0.6447342,,,,,,,,,,,,,, -299000,4.4148283,0.6117442,,,,,,,,,,,,,, -299100,5.1276603,0.64171773,,,,,,,,,,,,,, -299200,5.17643,0.6170473,,,,,,,,,,,,,, -299300,4.7391167,0.6918294,,,,,,,,,,,,,, -299400,4.711346,0.66160315,,,,,,,,,,,,,, -299500,4.9632,0.61654454,,,,,,,,,,,,,, -299600,4.1977305,0.6071017,,,,,,,,,,,,,, -299700,4.849786,0.6184796,,,,,,,,,,,,,, -299800,4.7547436,0.56789,,,,,,,,,,,,,, -299900,4.464037,0.5815638,,,,,,,,,,,,,, -300000,4.549873,0.5886809,,,,,,,,,,,,,, -300100,4.178148,0.60693204,,,,,,,,,,,,,, -300200,4.7351255,0.5666194,,,,,,,,,,,,,, -300300,4.407746,0.56538236,,,,,,,,,,,,,, -300318,,,0.9599409699440002,0.1476414948701858,0.7556399703025818,1.0543302297592163,50000.0,0.628000020980835,1.844310998916626,10000.0,100533.74055194856,104020.15699887276,100533.74055194856,3464.7071413993835,10.93600869178772,0.0 -300400,4.9300194,0.6440038,,,,,,,,,,,,,, -300500,4.653803,0.6145392,,,,,,,,,,,,,, -300600,5.0454736,0.66937345,,,,,,,,,,,,,, -300700,4.3967357,0.6185503,,,,,,,,,,,,,, -300800,4.3569803,0.6278671,,,,,,,,,,,,,, -300900,4.878039,0.6821184,,,,,,,,,,,,,, -301000,4.200558,0.517197,,,,,,,,,,,,,, -301100,4.4837394,0.58205557,,,,,,,,,,,,,, -301200,4.408638,0.59523094,,,,,,,,,,,,,, -301300,5.3158913,0.62560815,,,,,,,,,,,,,, -301400,4.848719,0.67071575,,,,,,,,,,,,,, -301500,4.1923113,0.5717503,,,,,,,,,,,,,, -301600,4.505037,0.67194045,,,,,,,,,,,,,, -301700,4.64936,0.5960151,,,,,,,,,,,,,, -301800,4.4089417,0.61228395,,,,,,,,,,,,,, -301843,,,0.9615154266357422,0.1449480205774307,0.7561799883842468,1.0539348125457764,50000.0,0.628000020980835,1.8417717218399048,10000.0,101043.92377829552,104546.94538855553,101043.92377829552,3481.1808309555054,11.012544393539429,0.0 -301900,4.3811784,0.60337144,,,,,,,,,,,,,, -302000,4.4247704,0.6504895,,,,,,,,,,,,,, -302100,4.7074056,0.6684758,,,,,,,,,,,,,, -302200,4.540697,0.66974515,,,,,,,,,,,,,, -302300,4.457777,0.578271,,,,,,,,,,,,,, -302400,4.4136624,0.58775663,,,,,,,,,,,,,, -302500,4.3734155,0.54156166,,,,,,,,,,,,,, -302600,4.7148466,0.57019347,,,,,,,,,,,,,, -302700,4.540966,0.5946241,,,,,,,,,,,,,, -302800,4.5299964,0.6193967,,,,,,,,,,,,,, -302900,4.753752,0.5872272,,,,,,,,,,,,,, -303000,4.6621842,0.6469792,,,,,,,,,,,,,, -303100,4.566357,0.65715957,,,,,,,,,,,,,, -303200,4.7684455,0.6150872,,,,,,,,,,,,,, -303300,4.325204,0.63410294,,,,,,,,,,,,,, -303366,,,0.9603196382522584,0.1485083252191543,0.7553399801254272,1.054638147354126,50000.0,0.6267000436782837,1.8432505130767824,10000.0,101553.90138459206,105073.63487243652,101553.90138459206,3497.745079278946,11.105444431304932,0.0 -303400,5.0102186,0.7443321,,,,,,,,,,,,,, -303500,4.6979346,0.606755,,,,,,,,,,,,,, -303600,4.07014,0.57267624,,,,,,,,,,,,,, -303700,5.1834693,0.68701965,,,,,,,,,,,,,, -303800,4.375941,0.60053426,,,,,,,,,,,,,, -303900,4.857063,0.6988271,,,,,,,,,,,,,, -304000,4.900278,0.5849335,,,,,,,,,,,,,, -304100,4.609953,0.5687051,,,,,,,,,,,,,, -304200,4.3739142,0.6011267,,,,,,,,,,,,,, -304300,4.5835166,0.55179787,,,,,,,,,,,,,, -304400,4.495349,0.5954538,,,,,,,,,,,,,, -304500,4.971292,0.68224716,,,,,,,,,,,,,, -304600,4.702497,0.6261587,,,,,,,,,,,,,, -304700,4.892335,0.69751924,,,,,,,,,,,,,, -304800,4.9985886,0.7015004,,,,,,,,,,,,,, -304890,,,0.9614756107330322,0.1455578804016113,0.7556599974632263,1.0540353059768677,50000.0,0.6270000338554382,1.8422266244888303,10000.0,102063.76321053503,105600.15265583992,102063.76321053503,3514.2617585659027,11.190539121627808,0.0 -304900,4.6905785,0.6112002,,,,,,,,,,,,,, -305000,4.72137,0.6660887,,,,,,,,,,,,,, -305100,4.795095,0.62363,,,,,,,,,,,,,, -305200,4.9408436,0.6728689,,,,,,,,,,,,,, -305300,4.433143,0.584901,,,,,,,,,,,,,, -305400,4.2144227,0.5963091,,,,,,,,,,,,,, -305500,5.037545,0.6506548,,,,,,,,,,,,,, -305600,4.2861257,0.5957733,,,,,,,,,,,,,, -305700,4.3429885,0.6007328,,,,,,,,,,,,,, -305800,4.777757,0.65043485,,,,,,,,,,,,,, -305900,4.488534,0.5806406,,,,,,,,,,,,,, -306000,4.357891,0.58646685,,,,,,,,,,,,,, -306100,4.4241962,0.6087787,,,,,,,,,,,,,, -306200,4.7602825,0.5954231,,,,,,,,,,,,,, -306300,4.3235445,0.5834839,,,,,,,,,,,,,, -306400,4.5102153,0.6721195,,,,,,,,,,,,,, -306414,,,0.9628507494926452,0.1430933475494384,0.7557799816131592,1.0519496202468872,50000.0,0.626800000667572,1.841051697731018,10000.0,102573.88640189172,106127.09458732604,102573.88640189172,3530.9448194503784,11.271536588668823,0.0 -306500,4.5653257,0.6705938,,,,,,,,,,,,,, -306600,4.8610506,0.6398923,,,,,,,,,,,,,, -306700,4.5494833,0.591709,,,,,,,,,,,,,, -306800,5.368993,0.62943435,,,,,,,,,,,,,, -306900,4.464685,0.56166065,,,,,,,,,,,,,, -307000,4.71057,0.61792654,,,,,,,,,,,,,, -307100,4.560654,0.5123524,,,,,,,,,,,,,, -307200,4.2085276,0.54090923,,,,,,,,,,,,,, -307300,4.096064,0.61565804,,,,,,,,,,,,,, -307400,4.442332,0.5765962,,,,,,,,,,,,,, -307500,4.9727926,0.61248434,,,,,,,,,,,,,, -307600,4.4074225,0.6491709,,,,,,,,,,,,,, -307700,4.39249,0.6279054,,,,,,,,,,,,,, -307800,4.94816,0.6796726,,,,,,,,,,,,,, -307900,4.832234,0.64858294,,,,,,,,,,,,,, -307938,,,0.9618542790412904,0.1455347239971161,0.7562599778175354,1.0540136098861694,50000.0,0.6266000270843506,1.842125654220581,10000.0,103084.04380178452,106654.18004608154,103084.04380178452,3547.7343316078186,11.354299306869509,0.0 -308000,4.335396,0.64473474,,,,,,,,,,,,,, -308100,4.3994555,0.6535061,,,,,,,,,,,,,, -308200,4.507838,0.6155778,,,,,,,,,,,,,, -308300,4.384886,0.5984193,,,,,,,,,,,,,, -308400,4.891445,0.6707673,,,,,,,,,,,,,, -308500,4.469103,0.6195412,,,,,,,,,,,,,, -308600,4.346347,0.67272735,,,,,,,,,,,,,, -308700,4.3667197,0.6044877,,,,,,,,,,,,,, -308800,4.270488,0.6189116,,,,,,,,,,,,,, -308900,4.7301526,0.7318491,,,,,,,,,,,,,, -309000,4.2522206,0.60732985,,,,,,,,,,,,,, -309100,4.7531056,0.5789436,,,,,,,,,,,,,, -309200,4.49205,0.59371674,,,,,,,,,,,,,, -309300,4.9999795,0.64184904,,,,,,,,,,,,,, -309400,4.983782,0.6826398,,,,,,,,,,,,,, -309462,,,0.9624919891357422,0.1417141854763031,0.7557399868965149,1.053352117538452,50000.0,0.6272000074386597,1.8425387144088743,10000.0,103594.16372561456,107181.73225021362,103594.16372561456,3565.0480420589447,11.41740655899048,0.0 -309500,4.5411143,0.66829884,,,,,,,,,,,,,, -309600,4.828741,0.5715203,,,,,,,,,,,,,, -309700,4.345805,0.59301853,,,,,,,,,,,,,, -309800,5.006516,0.6385778,,,,,,,,,,,,,, -309900,5.517789,0.6925849,,,,,,,,,,,,,, -310000,4.5592437,0.60326624,,,,,,,,,,,,,, -310100,4.0922294,0.6106488,,,,,,,,,,,,,, -310200,4.3800087,0.60323256,,,,,,,,,,,,,, -310300,4.4747767,0.6322151,,,,,,,,,,,,,, -310400,4.477677,0.6701954,,,,,,,,,,,,,, -310500,4.082841,0.61960644,,,,,,,,,,,,,, -310600,4.350035,0.6292732,,,,,,,,,,,,,, -310700,4.6649528,0.6865926,,,,,,,,,,,,,, -310800,4.2135115,0.55750334,,,,,,,,,,,,,, -310900,4.3301167,0.59109616,,,,,,,,,,,,,, -310984,,,0.9591238498687744,0.149687573313713,0.7559199929237366,1.0534484386444092,50000.0,0.6276000142097473,1.842854619026184,10000.0,104104.17977261543,107708.53921103476,104104.17977261543,3581.7052421569824,11.497244358062744,0.0 -311000,4.350616,0.69959754,,,,,,,,,,,,,, -311100,4.8186984,0.57687116,,,,,,,,,,,,,, -311200,4.8341007,0.6369562,,,,,,,,,,,,,, -311300,4.4166255,0.6310605,,,,,,,,,,,,,, -311400,4.398016,0.5951159,,,,,,,,,,,,,, -311500,4.8930807,0.6887593,,,,,,,,,,,,,, -311600,4.6911697,0.6413017,,,,,,,,,,,,,, -311700,4.51558,0.59544283,,,,,,,,,,,,,, -311800,4.940367,0.69054085,,,,,,,,,,,,,, -311900,4.4552646,0.5955782,,,,,,,,,,,,,, -312000,5.1621375,0.6188645,,,,,,,,,,,,,, -312100,4.8551183,0.6240226,,,,,,,,,,,,,, -312200,4.7924943,0.7008072,,,,,,,,,,,,,, -312300,4.839925,0.6730187,,,,,,,,,,,,,, -312400,4.6006775,0.65519667,,,,,,,,,,,,,, -312500,4.0719366,0.5953741,,,,,,,,,,,,,, -312507,,,0.9618343114852904,0.1447564959526062,0.7558799982070923,1.053959846496582,50000.0,0.6279000043869019,1.8433014154434204,10000.0,104614.0937845707,108235.34715676308,104614.0937845707,3598.469397068024,11.572259664535522,0.0 -312600,4.289654,0.6007257,,,,,,,,,,,,,, -312700,4.7753515,0.62924004,,,,,,,,,,,,,, -312800,4.3539157,0.6293396,,,,,,,,,,,,,, -312900,4.258391,0.59606415,,,,,,,,,,,,,, -313000,4.342664,0.5512294,,,,,,,,,,,,,, -313100,5.063212,0.6534157,,,,,,,,,,,,,, -313200,4.5530787,0.5809277,,,,,,,,,,,,,, -313300,4.8807926,0.73740125,,,,,,,,,,,,,, -313400,4.5536,0.6087148,,,,,,,,,,,,,, -313500,4.399895,0.59856737,,,,,,,,,,,,,, -313600,4.857042,0.6251277,,,,,,,,,,,,,, -313700,4.356521,0.60619825,,,,,,,,,,,,,, -313800,4.064584,0.57871354,,,,,,,,,,,,,, -313900,4.1877894,0.57548815,,,,,,,,,,,,,, -314000,4.3642664,0.590665,,,,,,,,,,,,,, -314029,,,0.9614556431770324,0.1427123695611953,0.7558000087738037,1.0546585321426392,50000.0,0.6272000074386597,1.8429999351501465,10000.0,105123.96016263962,108761.96542572977,105123.96016263962,3615.086302042008,11.65283203125,0.0 -314100,5.0285583,0.6660447,,,,,,,,,,,,,, -314200,4.3984404,0.6037798,,,,,,,,,,,,,, -314300,4.8457522,0.63952804,,,,,,,,,,,,,, -314400,4.945631,0.6728357,,,,,,,,,,,,,, -314500,4.563634,0.63175976,,,,,,,,,,,,,, -314600,4.9750977,0.74863464,,,,,,,,,,,,,, -314700,4.3218474,0.60190576,,,,,,,,,,,,,, -314800,5.0556626,0.65133893,,,,,,,,,,,,,, -314900,4.3508677,0.6392043,,,,,,,,,,,,,, -315000,5.05345,0.662102,,,,,,,,,,,,,, -315100,4.3807034,0.6260228,,,,,,,,,,,,,, -315200,4.206508,0.5865609,,,,,,,,,,,,,, -315300,4.6613464,0.65392023,,,,,,,,,,,,,, -315400,4.410764,0.5836853,,,,,,,,,,,,,, -315500,5.4900007,0.68516326,,,,,,,,,,,,,, -315553,,,0.9622528553009032,0.1424500048160553,0.7561399936676025,1.053011775016785,50000.0,0.626800000667572,1.8422795534133911,10000.0,105633.98216438292,109288.64826965332,105633.98216438292,3631.61336350441,11.731301069259644,0.0 -315600,4.600593,0.6746215,,,,,,,,,,,,,, -315700,4.712124,0.6130856,,,,,,,,,,,,,, -315800,4.305679,0.58284074,,,,,,,,,,,,,, -315900,4.6955266,0.6757399,,,,,,,,,,,,,, -316000,4.312469,0.60321546,,,,,,,,,,,,,, -316100,4.5615864,0.62580967,,,,,,,,,,,,,, -316200,4.054935,0.62628883,,,,,,,,,,,,,, -316300,4.3772845,0.6280026,,,,,,,,,,,,,, -316400,5.150899,0.61355364,,,,,,,,,,,,,, -316500,4.7494674,0.63663477,,,,,,,,,,,,,, -316600,4.5046997,0.6686999,,,,,,,,,,,,,, -316700,4.499671,0.62203544,,,,,,,,,,,,,, -316800,4.9739733,0.691914,,,,,,,,,,,,,, -316900,4.716785,0.6668749,,,,,,,,,,,,,, -317000,4.668119,0.64661294,,,,,,,,,,,,,, -317077,,,0.9612762928009032,0.1458015292882919,0.7560399770736694,1.0532641410827637,50000.0,0.6277000308036804,1.8440029621124268,10000.0,106143.96684598924,109815.33223104475,106143.96684598924,3648.164837121964,11.824159622192385,0.0 -317100,4.3758717,0.65113926,,,,,,,,,,,,,, -317200,4.785661,0.59131604,,,,,,,,,,,,,, -317300,4.102146,0.57725656,,,,,,,,,,,,,, -317400,4.3867493,0.6102419,,,,,,,,,,,,,, -317500,4.828653,0.62873507,,,,,,,,,,,,,, -317600,4.6334605,0.6201692,,,,,,,,,,,,,, -317700,4.9469013,0.68584645,,,,,,,,,,,,,, -317800,4.5176373,0.64099807,,,,,,,,,,,,,, -317900,4.232627,0.6417591,,,,,,,,,,,,,, -318000,4.3491993,0.5620634,,,,,,,,,,,,,, -318100,4.4460626,0.5960232,,,,,,,,,,,,,, -318200,4.1199675,0.5955463,,,,,,,,,,,,,, -318300,4.847421,0.72560346,,,,,,,,,,,,,, -318400,4.26894,0.59009683,,,,,,,,,,,,,, -318500,4.678149,0.6360326,,,,,,,,,,,,,, -318600,4.3756065,0.6314483,,,,,,,,,,,,,, -318601,,,0.9610171914100648,0.1459923684597015,0.7558799982070923,1.0537469387054443,50000.0,0.6276000142097473,1.8443596363067627,10000.0,106654.22787618636,110342.21775627136,106654.22787618636,3664.6562349796295,11.90298581123352,0.0 -318700,5.620263,0.6657694,,,,,,,,,,,,,, -318800,4.704752,0.5849332,,,,,,,,,,,,,, -318900,4.2390895,0.5878793,,,,,,,,,,,,,, -319000,4.358485,0.6427774,,,,,,,,,,,,,, -319100,4.355743,0.6525634,,,,,,,,,,,,,, -319200,4.3743815,0.59924585,,,,,,,,,,,,,, -319300,4.44831,0.6034406,,,,,,,,,,,,,, -319400,4.413855,0.63280857,,,,,,,,,,,,,, -319500,4.2526913,0.6234842,,,,,,,,,,,,,, -319600,4.287504,0.6358484,,,,,,,,,,,,,, -319700,4.3229012,0.60446113,,,,,,,,,,,,,, -319800,4.4040594,0.6401227,,,,,,,,,,,,,, -319900,4.266972,0.6496276,,,,,,,,,,,,,, -320000,4.973948,0.68722224,,,,,,,,,,,,,, -320100,4.2388983,0.59527373,,,,,,,,,,,,,, -320125,,,0.9613759517669678,0.146775797009468,0.7560999989509583,1.0519914627075195,50000.0,0.6277000308036804,1.839996337890625,10000.0,107164.28194069862,110869.0744421482,107164.28194069862,3681.319794178009,11.985785961151125,0.0 -320200,4.247873,0.59504306,,,,,,,,,,,,,, -320300,4.8169827,0.67263365,,,,,,,,,,,,,, -320400,4.179503,0.5337598,,,,,,,,,,,,,, -320500,4.5821376,0.6370076,,,,,,,,,,,,,, -320600,4.551484,0.5851761,,,,,,,,,,,,,, -320700,4.610871,0.64440155,,,,,,,,,,,,,, -320800,4.871104,0.5872687,,,,,,,,,,,,,, -320900,4.8951426,0.658083,,,,,,,,,,,,,, -321000,4.679669,0.69593954,,,,,,,,,,,,,, -321100,4.595119,0.69381917,,,,,,,,,,,,,, -321200,4.835796,0.67633617,,,,,,,,,,,,,, -321300,4.4894447,0.65218425,,,,,,,,,,,,,, -321400,4.419818,0.594412,,,,,,,,,,,,,, -321500,4.4946346,0.5977094,,,,,,,,,,,,,, -321600,4.634228,0.74148864,,,,,,,,,,,,,, -321629,,,0.9612762928009032,0.1422847658395767,0.7560399770736694,1.0536316633224487,50000.0,0.6273000240325928,1.8425594568252563,10000.0,107674.23223781586,111397.20132136343,107674.23223781586,3699.360595703125,12.066577196121216,0.0 -321700,4.411343,0.63967776,,,,,,,,,,,,,, -321800,4.2567883,0.5952423,,,,,,,,,,,,,, -321900,4.5640492,0.59895927,,,,,,,,,,,,,, -322000,4.9428573,0.5672132,,,,,,,,,,,,,, -322100,4.8435826,0.62843066,,,,,,,,,,,,,, -322200,4.2653418,0.60001934,,,,,,,,,,,,,, -322300,4.601494,0.63718516,,,,,,,,,,,,,, -322400,4.7244525,0.6289087,,,,,,,,,,,,,, -322500,4.4841356,0.67491204,,,,,,,,,,,,,, -322600,4.5308084,0.6572776,,,,,,,,,,,,,, -322700,4.630702,0.5584693,,,,,,,,,,,,,, -322800,4.4137287,0.579581,,,,,,,,,,,,,, -322900,4.520963,0.64597726,,,,,,,,,,,,,, -323000,4.6948767,0.6269561,,,,,,,,,,,,,, -323100,4.527742,0.6275791,,,,,,,,,,,,,, -323153,,,0.9610371589660645,0.145976260304451,0.7555999755859375,1.0547045469284058,50000.0,0.6271000504493713,1.844212532043457,10000.0,108184.30120563509,111924.70087218285,108184.30120563509,3716.6544332504272,12.149267673492432,0.0 -323200,4.594905,0.6094838,,,,,,,,,,,,,, -323300,4.9208865,0.6182875,,,,,,,,,,,,,, -323400,4.2961617,0.56195354,,,,,,,,,,,,,, -323500,4.726416,0.70589614,,,,,,,,,,,,,, -323600,4.79389,0.5662298,,,,,,,,,,,,,, -323700,4.6093173,0.5660513,,,,,,,,,,,,,, -323800,4.4373055,0.6353971,,,,,,,,,,,,,, -323900,5.021484,0.71983045,,,,,,,,,,,,,, -324000,4.3069916,0.6197183,,,,,,,,,,,,,, -324100,4.1804667,0.52790624,,,,,,,,,,,,,, -324200,4.678522,0.60303545,,,,,,,,,,,,,, -324300,4.6651754,0.6872467,,,,,,,,,,,,,, -324400,4.2704825,0.55739087,,,,,,,,,,,,,, -324500,4.443783,0.5934659,,,,,,,,,,,,,, -324600,4.8239346,0.64543116,,,,,,,,,,,,,, -324677,,,0.9600406289100648,0.1472441256046295,0.756119966506958,1.0521292686462402,50000.0,0.6270000338554382,1.8409771919250488,10000.0,108694.235871315,112451.8219499588,108694.235871315,3733.7039256095886,12.232320070266724,0.0 -324700,5.133732,0.6395155,,,,,,,,,,,,,, -324800,5.4298167,0.616972,,,,,,,,,,,,,, -324900,4.6816936,0.5719383,,,,,,,,,,,,,, -325000,4.717876,0.6496892,,,,,,,,,,,,,, -325100,4.6040735,0.6858172,,,,,,,,,,,,,, -325200,4.312839,0.63485163,,,,,,,,,,,,,, -325300,4.3277345,0.5938772,,,,,,,,,,,,,, -325400,4.1613345,0.5851366,,,,,,,,,,,,,, -325500,4.8420835,0.6796772,,,,,,,,,,,,,, -325600,4.5789227,0.56757617,,,,,,,,,,,,,, -325700,4.6140237,0.6836451,,,,,,,,,,,,,, -325800,4.0420594,0.61317074,,,,,,,,,,,,,, -325900,4.3908744,0.6072189,,,,,,,,,,,,,, -326000,4.419718,0.5601298,,,,,,,,,,,,,, -326100,4.5157037,0.5846167,,,,,,,,,,,,,, -326200,4.4443593,0.5959998,,,,,,,,,,,,,, -326201,,,0.9612563848495485,0.1457024663686752,0.7558199763298035,1.053580641746521,50000.0,0.6288000345230103,1.8415600061416624,10000.0,109204.26220989227,112979.1577911377,109204.26220989227,3750.8754110336304,12.315529584884644,0.0 -326300,4.577042,0.6519662,,,,,,,,,,,,,, -326400,5.2210655,0.6074681,,,,,,,,,,,,,, -326500,4.618804,0.58987516,,,,,,,,,,,,,, -326600,4.509489,0.6301748,,,,,,,,,,,,,, -326700,4.5090523,0.5129379,,,,,,,,,,,,,, -326800,4.3098855,0.6288774,,,,,,,,,,,,,, -326900,4.7494674,0.6088107,,,,,,,,,,,,,, -327000,4.51164,0.5978379,,,,,,,,,,,,,, -327100,4.784511,0.62010777,,,,,,,,,,,,,, -327200,5.0789356,0.65813124,,,,,,,,,,,,,, -327300,4.440452,0.5847781,,,,,,,,,,,,,, -327400,4.392024,0.6091965,,,,,,,,,,,,,, -327500,4.4018636,0.6141335,,,,,,,,,,,,,, -327600,4.513411,0.56013525,,,,,,,,,,,,,, -327700,4.157088,0.55720586,,,,,,,,,,,,,, -327724,,,0.962332546710968,0.1425944864749908,0.7559399604797363,1.0540307760238647,50000.0,0.6270000338554382,1.84445321559906,10000.0,109714.14976787569,113506.20581841467,109714.14976787569,3767.899305582048,12.398417234420776,0.0 -327800,4.6425505,0.6112752,,,,,,,,,,,,,, -327900,4.1569047,0.6552156,,,,,,,,,,,,,, -328000,4.955568,0.69451624,,,,,,,,,,,,,, -328100,4.7235966,0.6781241,,,,,,,,,,,,,, -328200,4.4646254,0.66283023,,,,,,,,,,,,,, -328300,4.8661504,0.63367194,,,,,,,,,,,,,, -328400,4.6936784,0.61034393,,,,,,,,,,,,,, -328500,4.671816,0.61194223,,,,,,,,,,,,,, -328600,5.0646286,0.65811884,,,,,,,,,,,,,, -328700,4.491097,0.6485975,,,,,,,,,,,,,, -328800,4.133434,0.6110842,,,,,,,,,,,,,, -328900,4.045182,0.590281,,,,,,,,,,,,,, -329000,4.587124,0.6512917,,,,,,,,,,,,,, -329100,4.4691787,0.57159156,,,,,,,,,,,,,, -329200,4.9500813,0.65586066,,,,,,,,,,,,,, -329248,,,0.9614955186843872,0.1465763002634048,0.7559199929237366,1.0538091659545898,50000.0,0.6265000104904175,1.841633915901184,10000.0,110224.18912863731,114033.2861776352,110224.18912863731,3784.79211306572,12.491395235061646,0.0 -329300,4.7213273,0.5775912,,,,,,,,,,,,,, -329400,3.9405236,0.52720284,,,,,,,,,,,,,, -329500,4.945654,0.6157236,,,,,,,,,,,,,, -329600,4.5948157,0.65584695,,,,,,,,,,,,,, -329700,4.8359904,0.71603185,,,,,,,,,,,,,, -329800,4.534518,0.65845716,,,,,,,,,,,,,, -329900,4.4973483,0.64664304,,,,,,,,,,,,,, -330000,4.5312347,0.66920173,,,,,,,,,,,,,, -330100,5.066451,0.6567005,,,,,,,,,,,,,, -330200,4.457487,0.6217041,,,,,,,,,,,,,, -330300,4.5379114,0.613952,,,,,,,,,,,,,, -330400,4.6532598,0.597669,,,,,,,,,,,,,, -330500,4.4224443,0.60249645,,,,,,,,,,,,,, -330600,4.3410816,0.6448701,,,,,,,,,,,,,, -330700,4.458352,0.5967507,,,,,,,,,,,,,, -330772,,,0.9596420526504515,0.1490146070718765,0.7561399936676025,1.0536932945251465,50000.0,0.6279000043869019,1.84337055683136,10000.0,110734.2365758419,114560.4348537922,110734.2365758419,3801.7138855457306,12.616500854492188,0.0 -330800,4.7432218,0.5832037,,,,,,,,,,,,,, -330900,4.5305047,0.67191756,,,,,,,,,,,,,, -331000,4.5079923,0.6220644,,,,,,,,,,,,,, -331100,4.6037903,0.6124425,,,,,,,,,,,,,, -331200,4.62439,0.6333555,,,,,,,,,,,,,, -331300,4.242093,0.6104363,,,,,,,,,,,,,, -331400,4.2670727,0.55809104,,,,,,,,,,,,,, -331500,4.7223954,0.66218984,,,,,,,,,,,,,, -331600,4.3261228,0.6336334,,,,,,,,,,,,,, -331700,4.6722317,0.65620816,,,,,,,,,,,,,, -331800,4.544023,0.6147425,,,,,,,,,,,,,, -331900,4.6529155,0.5831391,,,,,,,,,,,,,, -332000,4.85957,0.64526004,,,,,,,,,,,,,, -332100,4.657986,0.59570026,,,,,,,,,,,,,, -332200,4.753078,0.71015304,,,,,,,,,,,,,, -332296,,,0.9601203799247742,0.1480126678943634,0.7555999755859375,1.053968071937561,50000.0,0.6276000142097473,1.843773365020752,10000.0,111244.34099173546,115087.61276984216,111244.34099173546,3818.649199962616,12.698453664779665,0.0 -332300,4.6968427,0.633965,,,,,,,,,,,,,, -332400,4.1077814,0.5455712,,,,,,,,,,,,,, -332500,5.221973,0.6801703,,,,,,,,,,,,,, -332600,4.2665606,0.628582,,,,,,,,,,,,,, -332700,4.915156,0.6942594,,,,,,,,,,,,,, -332800,5.086795,0.6448002,,,,,,,,,,,,,, -332900,4.207908,0.60023874,,,,,,,,,,,,,, -333000,4.4420557,0.57514185,,,,,,,,,,,,,, -333100,4.147278,0.61897045,,,,,,,,,,,,,, -333200,4.6128316,0.58394647,,,,,,,,,,,,,, -333300,4.516454,0.6091752,,,,,,,,,,,,,, -333400,5.0176315,0.61039114,,,,,,,,,,,,,, -333500,4.3852525,0.58994603,,,,,,,,,,,,,, -333600,4.546962,0.63690686,,,,,,,,,,,,,, -333700,4.842576,0.63350666,,,,,,,,,,,,,, -333800,4.7374053,0.61165345,,,,,,,,,,,,,, -333820,,,0.9615951776504515,0.1442433297634124,0.7560200095176697,1.0530657768249512,50000.0,0.6271000504493713,1.841819643974304,10000.0,111754.31213140488,115614.44660925864,111754.31213140488,3835.370189905167,12.785258054733276,0.0 -333900,4.440299,0.61359143,,,,,,,,,,,,,, -334000,4.860217,0.69384766,,,,,,,,,,,,,, -334100,4.456586,0.6282486,,,,,,,,,,,,,, -334200,4.8276362,0.67101645,,,,,,,,,,,,,, -334300,4.391693,0.5815157,,,,,,,,,,,,,, -334400,5.0088696,0.6444443,,,,,,,,,,,,,, -334500,4.0115895,0.5622814,,,,,,,,,,,,,, -334600,4.4003577,0.6352922,,,,,,,,,,,,,, -334700,4.810286,0.6082578,,,,,,,,,,,,,, -334800,4.273022,0.5979216,,,,,,,,,,,,,, -334900,4.420077,0.6677989,,,,,,,,,,,,,, -335000,4.43768,0.6821375,,,,,,,,,,,,,, -335100,4.863035,0.6695246,,,,,,,,,,,,,, -335200,5.383155,0.6840083,,,,,,,,,,,,,, -335300,4.3135285,0.63410264,,,,,,,,,,,,,, -335344,,,0.9606385231018066,0.1442461311817169,0.7555599808692932,1.053433537483215,50000.0,0.6282000541687012,1.8424676656723025,10000.0,112264.47166848184,116141.47416877748,112264.47166848184,3852.092576265335,12.87584137916565,0.0 -335400,5.2355623,0.7021054,,,,,,,,,,,,,, -335500,4.167891,0.56407976,,,,,,,,,,,,,, -335600,4.8787036,0.6658014,,,,,,,,,,,,,, -335700,5.235644,0.64394027,,,,,,,,,,,,,, -335800,4.422055,0.6421227,,,,,,,,,,,,,, -335900,4.245442,0.64530486,,,,,,,,,,,,,, -336000,3.8931499,0.5494794,,,,,,,,,,,,,, -336100,4.5718074,0.66694134,,,,,,,,,,,,,, -336200,4.1703677,0.5890503,,,,,,,,,,,,,, -336300,4.7214155,0.6564649,,,,,,,,,,,,,, -336400,4.3497267,0.59931636,,,,,,,,,,,,,, -336500,4.7562313,0.63889134,,,,,,,,,,,,,, -336600,4.6862288,0.5870539,,,,,,,,,,,,,, -336700,4.1014233,0.544674,,,,,,,,,,,,,, -336800,4.537149,0.6835208,,,,,,,,,,,,,, -336868,,,0.9594228267669678,0.1478379368782043,0.7559999823570251,1.0530071258544922,50000.0,0.626800000667572,1.841143488883972,10000.0,112774.3374478817,116668.15985655785,112774.3374478817,3868.7485876083374,12.985454320907593,0.0 -336900,5.0479937,0.7199941,,,,,,,,,,,,,, -337000,4.95998,0.6822443,,,,,,,,,,,,,, -337100,4.4289484,0.5698561,,,,,,,,,,,,,, -337200,4.7430573,0.6446163,,,,,,,,,,,,,, -337300,4.533377,0.59456277,,,,,,,,,,,,,, -337400,5.068889,0.709036,,,,,,,,,,,,,, -337500,4.422884,0.55862457,,,,,,,,,,,,,, -337600,4.400978,0.68442005,,,,,,,,,,,,,, -337700,4.6642046,0.5822145,,,,,,,,,,,,,, -337800,4.5428195,0.6206493,,,,,,,,,,,,,, -337900,4.94666,0.6007383,,,,,,,,,,,,,, -338000,5.0915723,0.6238722,,,,,,,,,,,,,, -338100,5.0150943,0.68250835,,,,,,,,,,,,,, -338200,4.0927844,0.54990864,,,,,,,,,,,,,, -338300,4.4693084,0.6122922,,,,,,,,,,,,,, -338393,,,0.9607780575752258,0.1457557678222656,0.7559199929237366,1.0537904500961304,50000.0,0.6278000473976135,1.84160315990448,10000.0,113284.472843647,117195.11727333067,113284.472843647,3885.435347318649,13.066871643066406,0.0 -338400,4.675785,0.6848557,,,,,,,,,,,,,, -338500,4.962055,0.6918913,,,,,,,,,,,,,, -338600,4.1851735,0.6150007,,,,,,,,,,,,,, -338700,4.3987975,0.6482734,,,,,,,,,,,,,, -338800,4.9461145,0.63492286,,,,,,,,,,,,,, -338900,4.71501,0.6328682,,,,,,,,,,,,,, -339000,4.478322,0.5997659,,,,,,,,,,,,,, -339100,5.2110033,0.7162022,,,,,,,,,,,,,, -339200,4.454237,0.6118002,,,,,,,,,,,,,, -339300,4.866991,0.67173433,,,,,,,,,,,,,, -339400,4.6614933,0.62518704,,,,,,,,,,,,,, -339500,4.958797,0.5980142,,,,,,,,,,,,,, -339600,4.4223156,0.54417926,,,,,,,,,,,,,, -339700,4.5505075,0.56267345,,,,,,,,,,,,,, -339800,4.7936587,0.6874146,,,,,,,,,,,,,, -339900,4.611163,0.6671251,,,,,,,,,,,,,, -339918,,,0.9614157676696776,0.1437328457832336,0.7555599808692932,1.0540083646774292,50000.0,0.6276000142097473,1.843093991279602,10000.0,113794.48411178587,117721.9868991375,113794.48411178587,3902.153110980988,13.152750968933104,0.0 -340000,5.2988763,0.6618776,,,,,,,,,,,,,, -340100,4.6465235,0.60883313,,,,,,,,,,,,,, -340200,5.13873,0.5847331,,,,,,,,,,,,,, -340300,4.412714,0.6015386,,,,,,,,,,,,,, -340400,4.787122,0.5781982,,,,,,,,,,,,,, -340500,4.743993,0.63655865,,,,,,,,,,,,,, -340600,4.6359096,0.6234817,,,,,,,,,,,,,, -340700,5.067225,0.57972574,,,,,,,,,,,,,, -340800,4.4641914,0.61605316,,,,,,,,,,,,,, -340900,3.9626656,0.50631094,,,,,,,,,,,,,, -341000,4.880807,0.65565944,,,,,,,,,,,,,, -341100,4.9145513,0.64066243,,,,,,,,,,,,,, -341200,4.39325,0.56465095,,,,,,,,,,,,,, -341300,4.372186,0.5968086,,,,,,,,,,,,,, -341400,4.05052,0.5903695,,,,,,,,,,,,,, -341443,,,0.9618343114852904,0.1472954601049423,0.7556799650192261,1.0539937019348145,50000.0,0.6276000142097473,1.843406319618225,10000.0,114304.55684185028,118248.95959162712,114304.55684185028,3918.9126632213593,13.23846983909607,0.0 -341500,4.5648937,0.5808953,,,,,,,,,,,,,, -341600,4.367955,0.6298344,,,,,,,,,,,,,, -341700,4.5569773,0.64923865,,,,,,,,,,,,,, -341800,5.1062045,0.6101595,,,,,,,,,,,,,, -341900,4.9121976,0.6449069,,,,,,,,,,,,,, -342000,4.792958,0.6813066,,,,,,,,,,,,,, -342100,4.252537,0.6650547,,,,,,,,,,,,,, -342200,4.5618196,0.6248827,,,,,,,,,,,,,, -342300,4.4668827,0.5019843,,,,,,,,,,,,,, -342400,4.69954,0.6794586,,,,,,,,,,,,,, -342500,4.207176,0.6240276,,,,,,,,,,,,,, -342600,4.7915525,0.6073335,,,,,,,,,,,,,, -342700,4.0795655,0.60079956,,,,,,,,,,,,,, -342800,4.398669,0.5862116,,,,,,,,,,,,,, -342900,4.2947197,0.57440066,,,,,,,,,,,,,, -342968,,,0.9596819281578064,0.14888696372509,0.7560799717903137,1.0539145469665527,50000.0,0.6274000406265259,1.8429653644561768,10000.0,114814.5960021019,118775.81064414978,114814.5960021019,3935.580096244812,13.327780723571776,0.0 -343000,4.803675,0.64895016,,,,,,,,,,,,,, -343100,4.3229656,0.6278236,,,,,,,,,,,,,, -343200,4.734392,0.668716,,,,,,,,,,,,,, -343300,5.0468936,0.6639297,,,,,,,,,,,,,, -343400,4.4870496,0.7126829,,,,,,,,,,,,,, -343500,4.9502273,0.60082334,,,,,,,,,,,,,, -343600,4.4738135,0.55614126,,,,,,,,,,,,,, -343700,4.256067,0.5824509,,,,,,,,,,,,,, -343800,4.3941674,0.6485964,,,,,,,,,,,,,, -343900,4.1783547,0.56446856,,,,,,,,,,,,,, -344000,5.025523,0.56380767,,,,,,,,,,,,,, -344100,4.1928616,0.5626945,,,,,,,,,,,,,, -344200,4.818286,0.6238879,,,,,,,,,,,,,, -344300,4.4116263,0.5799164,,,,,,,,,,,,,, -344400,4.144482,0.60636634,,,,,,,,,,,,,, -344493,,,0.9618741869926452,0.1440886259078979,0.7561799883842468,1.052625298500061,50000.0,0.627500057220459,1.84093177318573,10000.0,115324.70213341711,119302.71807384492,115324.70213341711,3952.235322237015,13.418944597244264,0.0 -344500,4.382407,0.5923748,,,,,,,,,,,,,, -344600,4.583708,0.6535166,,,,,,,,,,,,,, -344700,4.059207,0.50244683,,,,,,,,,,,,,, -344800,4.645161,0.5348308,,,,,,,,,,,,,, -344900,4.879456,0.6500835,,,,,,,,,,,,,, -345000,4.370725,0.577426,,,,,,,,,,,,,, -345100,4.2516856,0.61906356,,,,,,,,,,,,,, -345200,4.4151516,0.6935525,,,,,,,,,,,,,, -345300,4.3453445,0.5860958,,,,,,,,,,,,,, -345400,4.9800587,0.68360996,,,,,,,,,,,,,, -345500,4.8472567,0.65922475,,,,,,,,,,,,,, -345600,4.8611465,0.643688,,,,,,,,,,,,,, -345700,4.8781295,0.660145,,,,,,,,,,,,,, -345800,4.9225063,0.6161298,,,,,,,,,,,,,, -345900,5.071549,0.6426507,,,,,,,,,,,,,, -346000,4.4471745,0.5954699,,,,,,,,,,,,,, -346017,,,0.9626514315605164,0.1440666168928146,0.7559799551963806,1.0532101392745972,50000.0,0.628000020980835,1.8426051139831543,10000.0,115834.6356317997,119829.2996737957,115834.6356317997,3968.7403090000153,13.50695252418518,0.0 -346100,4.3530593,0.5860162,,,,,,,,,,,,,, -346200,4.882333,0.57001686,,,,,,,,,,,,,, -346300,4.4664044,0.6308638,,,,,,,,,,,,,, -346400,4.2386484,0.65731096,,,,,,,,,,,,,, -346500,5.0874124,0.7387228,,,,,,,,,,,,,, -346600,4.5929375,0.6543485,,,,,,,,,,,,,, -346700,4.48911,0.6528621,,,,,,,,,,,,,, -346800,4.828794,0.58604676,,,,,,,,,,,,,, -346900,4.6784744,0.62911654,,,,,,,,,,,,,, -347000,5.043734,0.6081531,,,,,,,,,,,,,, -347100,4.6699204,0.61742216,,,,,,,,,,,,,, -347200,5.271939,0.7404923,,,,,,,,,,,,,, -347300,4.2445426,0.598528,,,,,,,,,,,,,, -347400,4.886315,0.66218525,,,,,,,,,,,,,, -347500,4.380382,0.645456,,,,,,,,,,,,,, -347542,,,0.9629504084587096,0.142887681722641,0.7559599876403809,1.0534168481826782,50000.0,0.627500057220459,1.844247579574585,10000.0,116344.7417371273,120356.30416107178,116344.7417371273,3985.4980008602142,13.59234380722046,0.0 -347600,4.3984594,0.6231453,,,,,,,,,,,,,, -347700,4.6095147,0.5921535,,,,,,,,,,,,,, -347800,4.904277,0.6269226,,,,,,,,,,,,,, -347900,4.577824,0.61260486,,,,,,,,,,,,,, -348000,4.851025,0.6791562,,,,,,,,,,,,,, -348100,4.5190454,0.5928877,,,,,,,,,,,,,, -348200,4.159926,0.54674584,,,,,,,,,,,,,, -348300,4.5752754,0.66094553,,,,,,,,,,,,,, -348400,4.403257,0.6530038,,,,,,,,,,,,,, -348500,4.252965,0.5851627,,,,,,,,,,,,,, -348600,4.484901,0.6437403,,,,,,,,,,,,,, -348700,5.3767776,0.7074002,,,,,,,,,,,,,, -348800,4.5299964,0.5618746,,,,,,,,,,,,,, -348900,4.321933,0.6129445,,,,,,,,,,,,,, -349000,4.1631007,0.6004666,,,,,,,,,,,,,, -349067,,,0.9593430757522584,0.1477753520011901,0.7563599944114685,1.05348002910614,50000.0,0.627500057220459,1.8426477909088133,10000.0,116854.75827503204,120883.71487736702,116854.75827503204,4002.7495708465576,13.679721355438232,0.0 -349100,4.720754,0.6123542,,,,,,,,,,,,,, -349200,4.7528324,0.6151663,,,,,,,,,,,,,, -349300,5.448831,0.6890011,,,,,,,,,,,,,, -349400,4.224067,0.5877481,,,,,,,,,,,,,, -349500,4.428458,0.63561,,,,,,,,,,,,,, -349600,4.537114,0.6070879,,,,,,,,,,,,,, -349700,4.875844,0.64863384,,,,,,,,,,,,,, -349800,4.170148,0.59761417,,,,,,,,,,,,,, -349900,4.498982,0.55908823,,,,,,,,,,,,,, -350000,5.1315084,0.6937246,,,,,,,,,,,,,, -350100,4.287399,0.5942621,,,,,,,,,,,,,, -350200,4.638564,0.57478833,,,,,,,,,,,,,, -350300,4.632641,0.57795507,,,,,,,,,,,,,, -350400,4.2459483,0.5919893,,,,,,,,,,,,,, -350500,4.7028184,0.6830132,,,,,,,,,,,,,, -350591,,,0.9602199792861938,0.1474127322435379,0.7559199929237366,1.0539977550506592,50000.0,0.6266000270843506,1.8422508239746087,10000.0,117364.76794171332,121410.36559605598,117364.76794171332,4019.250088214874,13.76552438735962,0.0 -350600,4.9982862,0.6123718,,,,,,,,,,,,,, -350700,4.808106,0.5311044,,,,,,,,,,,,,, -350800,4.2228913,0.6097907,,,,,,,,,,,,,, -350900,4.8349223,0.7141974,,,,,,,,,,,,,, -351000,4.596806,0.6287464,,,,,,,,,,,,,, -351100,4.6769342,0.57606846,,,,,,,,,,,,,, -351200,4.546255,0.63134533,,,,,,,,,,,,,, -351300,4.3492637,0.5998862,,,,,,,,,,,,,, -351400,5.020262,0.7371373,,,,,,,,,,,,,, -351500,4.536487,0.64538395,,,,,,,,,,,,,, -351600,4.688909,0.5831513,,,,,,,,,,,,,, -351700,4.167508,0.54219115,,,,,,,,,,,,,, -351800,4.5328727,0.5742035,,,,,,,,,,,,,, -351900,4.3326616,0.627595,,,,,,,,,,,,,, -352000,4.244889,0.5971096,,,,,,,,,,,,,, -352100,4.693995,0.607016,,,,,,,,,,,,,, -352115,,,0.9626514315605164,0.1426747590303421,0.7557599544525146,1.054054856300354,50000.0,0.628000020980835,1.8433915376663208,10000.0,117874.66694450378,121936.92624044418,117874.66694450378,4035.767923355103,13.854008674621582,0.0 -352200,4.5920978,0.660217,,,,,,,,,,,,,, -352300,5.320482,0.6421391,,,,,,,,,,,,,, -352400,5.2248616,0.6800363,,,,,,,,,,,,,, -352500,4.6831846,0.6279019,,,,,,,,,,,,,, -352600,4.2978697,0.4997063,,,,,,,,,,,,,, -352700,4.9112196,0.5977183,,,,,,,,,,,,,, -352800,4.7092147,0.65293247,,,,,,,,,,,,,, -352900,4.3277993,0.5488674,,,,,,,,,,,,,, -353000,4.754089,0.69608223,,,,,,,,,,,,,, -353100,4.3775597,0.5655991,,,,,,,,,,,,,, -353200,4.066921,0.525309,,,,,,,,,,,,,, -353300,4.0015187,0.53906125,,,,,,,,,,,,,, -353400,4.4909825,0.5712743,,,,,,,,,,,,,, -353500,4.6550617,0.60845935,,,,,,,,,,,,,, -353600,4.6933236,0.5891309,,,,,,,,,,,,,, -353639,,,0.9620535373687744,0.1417208313941955,0.7559399604797363,1.0543521642684937,50000.0,0.627500057220459,1.84516704082489,10000.0,118384.54222488403,122463.50339007378,118384.54222488403,4052.32389998436,13.944149494171144,0.0 -353700,4.5537553,0.6158,,,,,,,,,,,,,, -353800,4.5556087,0.6415412,,,,,,,,,,,,,, -353900,4.746703,0.64241934,,,,,,,,,,,,,, -354000,4.4344306,0.57767016,,,,,,,,,,,,,, -354100,4.64579,0.63890874,,,,,,,,,,,,,, -354200,5.00494,0.6797525,,,,,,,,,,,,,, -354300,4.6414595,0.6498972,,,,,,,,,,,,,, -354400,5.2685943,0.65350133,,,,,,,,,,,,,, -354500,4.871385,0.6716379,,,,,,,,,,,,,, -354600,3.8485525,0.54202735,,,,,,,,,,,,,, -354700,4.871304,0.58653677,,,,,,,,,,,,,, -354800,4.726465,0.68682444,,,,,,,,,,,,,, -354900,5.088848,0.6493019,,,,,,,,,,,,,, -355000,4.2656384,0.60508233,,,,,,,,,,,,,, -355100,4.2855115,0.5935808,,,,,,,,,,,,,, -355163,,,0.9616549611091614,0.1446973383426666,0.7563599944114685,1.0527466535568235,50000.0,0.626800000667572,1.8399124145507808,10000.0,118894.63495087624,122990.6729967594,118894.63495087624,4069.257670402527,14.031378030776978,0.0 -355200,4.7621517,0.5937921,,,,,,,,,,,,,, -355300,4.8964863,0.6460832,,,,,,,,,,,,,, -355400,4.863089,0.5794146,,,,,,,,,,,,,, -355500,4.2869625,0.61961794,,,,,,,,,,,,,, -355600,3.995392,0.52578133,,,,,,,,,,,,,, -355700,4.673471,0.60932344,,,,,,,,,,,,,, -355800,5.0868716,0.6430489,,,,,,,,,,,,,, -355900,4.511787,0.5656302,,,,,,,,,,,,,, -356000,4.171567,0.5871324,,,,,,,,,,,,,, -356100,4.471042,0.6089712,,,,,,,,,,,,,, -356200,4.2478848,0.5766009,,,,,,,,,,,,,, -356300,4.6210055,0.59499156,,,,,,,,,,,,,, -356400,4.9885526,0.65476674,,,,,,,,,,,,,, -356500,4.152834,0.55797976,,,,,,,,,,,,,, -356600,4.4971476,0.51326257,,,,,,,,,,,,,, -356687,,,0.9599210619926452,0.1480130702257156,0.7562999725341797,1.0537092685699463,50000.0,0.626800000667572,1.842963933944702,10000.0,119404.62611865996,123517.38052773476,119404.62611865996,4085.834371328354,14.115801095962524,0.0 -356700,4.324127,0.6040836,,,,,,,,,,,,,, -356800,4.276443,0.5970985,,,,,,,,,,,,,, -356900,4.4575996,0.63553506,,,,,,,,,,,,,, -357000,5.018613,0.69469833,,,,,,,,,,,,,, -357100,4.7033954,0.65526825,,,,,,,,,,,,,, -357200,4.542482,0.5757791,,,,,,,,,,,,,, -357300,4.568016,0.62138265,,,,,,,,,,,,,, -357400,4.831605,0.66214985,,,,,,,,,,,,,, -357500,4.741725,0.62384343,,,,,,,,,,,,,, -357600,4.644375,0.58460784,,,,,,,,,,,,,, -357700,5.095617,0.622954,,,,,,,,,,,,,, -357800,4.574928,0.65063936,,,,,,,,,,,,,, -357900,4.3507524,0.57225907,,,,,,,,,,,,,, -358000,4.256233,0.62945557,,,,,,,,,,,,,, -358100,4.4034486,0.5899952,,,,,,,,,,,,,, -358200,4.250397,0.5949594,,,,,,,,,,,,,, -358211,,,0.9610769748687744,0.1448649168014526,0.7556999921798706,1.0543800592422483,50000.0,0.6278000473976135,1.8446106910705569,10000.0,119914.49568676949,124043.99854040146,119914.49568676949,4102.438501119614,14.206216812133787,0.0 -358300,4.7705736,0.59722507,,,,,,,,,,,,,, -358400,4.266934,0.59850407,,,,,,,,,,,,,, -358500,4.5324206,0.68991435,,,,,,,,,,,,,, -358600,4.366515,0.6431181,,,,,,,,,,,,,, -358700,4.3492136,0.5788277,,,,,,,,,,,,,, -358800,4.8947945,0.72088426,,,,,,,,,,,,,, -358900,4.24084,0.524834,,,,,,,,,,,,,, -359000,4.6204557,0.63825953,,,,,,,,,,,,,, -359100,5.1978655,0.6439369,,,,,,,,,,,,,, -359200,4.5724664,0.650754,,,,,,,,,,,,,, -359300,4.1748238,0.6403942,,,,,,,,,,,,,, -359400,4.7779775,0.65695393,,,,,,,,,,,,,, -359500,4.974237,0.67518723,,,,,,,,,,,,,, -359600,4.682846,0.66138494,,,,,,,,,,,,,, -359700,4.3613377,0.6751977,,,,,,,,,,,,,, -359735,,,0.9620934128761292,0.1429737657308578,0.7559399604797363,1.0538039207458496,50000.0,0.6267000436782837,1.8425129652023315,10000.0,120424.40429615974,124570.62741327286,120424.40429615974,4119.014720439911,14.295626163482666,0.0 -359800,4.481721,0.587748,,,,,,,,,,,,,, -359900,4.8428783,0.62757117,,,,,,,,,,,,,, -360000,4.7709785,0.6502143,,,,,,,,,,,,,, -360100,4.548736,0.575144,,,,,,,,,,,,,, -360200,4.788381,0.6364978,,,,,,,,,,,,,, -360300,4.3180223,0.5983803,,,,,,,,,,,,,, -360400,4.6452546,0.6112074,,,,,,,,,,,,,, -360500,4.3478966,0.6175717,,,,,,,,,,,,,, -360600,4.4507704,0.56899124,,,,,,,,,,,,,, -360700,4.083419,0.6074369,,,,,,,,,,,,,, -360800,4.602016,0.6104084,,,,,,,,,,,,,, -360900,4.7621512,0.6051906,,,,,,,,,,,,,, -361000,4.309736,0.58126724,,,,,,,,,,,,,, -361100,4.465329,0.6158145,,,,,,,,,,,,,, -361200,4.517334,0.6105576,,,,,,,,,,,,,, -361259,,,0.960957407951355,0.1437092125415802,0.7561599612236023,1.053349852561951,50000.0,0.6271000504493713,1.8423242568969729,10000.0,120934.38921999931,125097.53248429298,120934.38921999931,4135.7907111644745,14.385024547576904,0.0 -361300,4.6296525,0.7173545,,,,,,,,,,,,,, -361400,4.8526945,0.60655606,,,,,,,,,,,,,, -361500,4.334566,0.61974645,,,,,,,,,,,,,, -361600,4.368927,0.5516511,,,,,,,,,,,,,, -361700,4.5360174,0.62579465,,,,,,,,,,,,,, -361800,4.411802,0.650821,,,,,,,,,,,,,, -361900,4.776441,0.6153837,,,,,,,,,,,,,, -362000,4.2054753,0.5942057,,,,,,,,,,,,,, -362100,4.329519,0.52145445,,,,,,,,,,,,,, -362200,4.2719135,0.6086024,,,,,,,,,,,,,, -362300,4.9516606,0.61176187,,,,,,,,,,,,,, -362400,4.339695,0.6434571,,,,,,,,,,,,,, -362500,4.8670597,0.61119425,,,,,,,,,,,,,, -362600,4.3839417,0.61720204,,,,,,,,,,,,,, -362700,4.7037573,0.67761433,,,,,,,,,,,,,, -362784,,,0.9610570669174194,0.1475917845964431,0.7564399838447571,1.053044080734253,50000.0,0.6282000541687012,1.8418200016021729,10000.0,121444.4788107872,125624.42392516136,121444.4788107872,4152.445852518082,14.476088762283323,0.0 -362800,4.6454973,0.6871325,,,,,,,,,,,,,, -362900,4.217958,0.5805412,,,,,,,,,,,,,, -363000,4.8422217,0.6609948,,,,,,,,,,,,,, -363100,4.420447,0.60286653,,,,,,,,,,,,,, -363200,4.620276,0.57027245,,,,,,,,,,,,,, -363300,4.903365,0.6761302,,,,,,,,,,,,,, -363400,4.6130695,0.64837974,,,,,,,,,,,,,, -363500,4.41427,0.55743563,,,,,,,,,,,,,, -363600,4.7377667,0.6742585,,,,,,,,,,,,,, -363700,4.494305,0.5683935,,,,,,,,,,,,,, -363800,4.3556733,0.59887296,,,,,,,,,,,,,, -363900,4.4557843,0.67120224,,,,,,,,,,,,,, -364000,5.144181,0.62398964,,,,,,,,,,,,,, -364100,4.471108,0.6036712,,,,,,,,,,,,,, -364200,5.110281,0.62657064,,,,,,,,,,,,,, -364300,4.3892565,0.5833502,,,,,,,,,,,,,, -364309,,,0.960598647594452,0.1446724534034729,0.7563199996948242,1.0528602600097656,50000.0,0.6279000043869019,1.841619849205017,10000.0,121954.5669374466,126151.28058385848,121954.5669374466,4169.068185567856,14.566959142684937,0.0 -364400,4.9128404,0.69588786,,,,,,,,,,,,,, -364500,4.426611,0.6154803,,,,,,,,,,,,,, -364600,5.416843,0.6620182,,,,,,,,,,,,,, -364700,4.9065447,0.68053883,,,,,,,,,,,,,, -364800,4.5982914,0.5486194,,,,,,,,,,,,,, -364900,4.371597,0.5760294,,,,,,,,,,,,,, -365000,4.6589427,0.6635458,,,,,,,,,,,,,, -365100,4.466582,0.62709075,,,,,,,,,,,,,, -365200,4.359466,0.62265325,,,,,,,,,,,,,, -365300,4.8466377,0.7208202,,,,,,,,,,,,,, -365400,5.1785808,0.66601825,,,,,,,,,,,,,, -365500,4.3422465,0.617755,,,,,,,,,,,,,, -365600,4.65175,0.6190231,,,,,,,,,,,,,, -365700,5.008307,0.5894514,,,,,,,,,,,,,, -365800,4.3343987,0.59938705,,,,,,,,,,,,,, -365833,,,0.9619140625,0.1466599553823471,0.7559599876403809,1.0532474517822266,50000.0,0.626800000667572,1.841666460037232,10000.0,122464.4394042492,126677.76492261888,122464.4394042492,4185.526557207108,14.665619611740112,0.0 -365900,4.416085,0.5714052,,,,,,,,,,,,,, -366000,4.799513,0.6988641,,,,,,,,,,,,,, -366100,4.6394215,0.65185213,,,,,,,,,,,,,, -366200,4.1407037,0.61320674,,,,,,,,,,,,,, -366300,4.6319466,0.61282957,,,,,,,,,,,,,, -366400,4.4228387,0.5524545,,,,,,,,,,,,,, -366500,4.6489263,0.69901013,,,,,,,,,,,,,, -366600,4.6157,0.541678,,,,,,,,,,,,,, -366700,4.623209,0.6815089,,,,,,,,,,,,,, -366800,4.3407793,0.59629905,,,,,,,,,,,,,, -366900,5.1628103,0.70967776,,,,,,,,,,,,,, -367000,4.466499,0.60572284,,,,,,,,,,,,,, -367100,4.546742,0.64450586,,,,,,,,,,,,,, -367200,4.377045,0.6872741,,,,,,,,,,,,,, -367300,4.489026,0.55696166,,,,,,,,,,,,,, -367358,,,0.9614955186843872,0.1442684829235077,0.7562800049781799,1.0540716648101809,50000.0,0.6270000338554382,1.843050837516785,10000.0,122974.46217608452,127204.42240476608,122974.46217608452,4202.016310453415,14.755435466766356,0.0 -367400,4.8250265,0.69709873,,,,,,,,,,,,,, -367500,4.8694057,0.6981131,,,,,,,,,,,,,, -367600,4.711163,0.6743454,,,,,,,,,,,,,, -367700,4.407826,0.54932564,,,,,,,,,,,,,, -367800,4.3099318,0.6333386,,,,,,,,,,,,,, -367900,4.8889256,0.6797991,,,,,,,,,,,,,, -368000,4.785171,0.54610735,,,,,,,,,,,,,, -368100,4.804336,0.62402046,,,,,,,,,,,,,, -368200,4.2538605,0.6237004,,,,,,,,,,,,,, -368300,4.596433,0.60100734,,,,,,,,,,,,,, -368400,4.413168,0.6091642,,,,,,,,,,,,,, -368500,4.725089,0.57611805,,,,,,,,,,,,,, -368600,4.662274,0.55458814,,,,,,,,,,,,,, -368700,4.9006705,0.59822,,,,,,,,,,,,,, -368800,4.448562,0.5347446,,,,,,,,,,,,,, -368883,,,0.9606983065605164,0.1458490192890167,0.7562800049781799,1.0531740188598633,50000.0,0.6273000240325928,1.842363953590393,10000.0,123484.38162231444,127731.15286946297,123484.38162231444,4218.683201313019,14.845154762268066,0.0 -368900,4.2674103,0.64299166,,,,,,,,,,,,,, -369000,4.7653866,0.5879875,,,,,,,,,,,,,, -369100,4.7617354,0.6517372,,,,,,,,,,,,,, -369200,4.576288,0.6450016,,,,,,,,,,,,,, -369300,4.487176,0.62720925,,,,,,,,,,,,,, -369400,4.5061684,0.62774265,,,,,,,,,,,,,, -369500,4.5254016,0.64605427,,,,,,,,,,,,,, -369600,4.2254457,0.6508032,,,,,,,,,,,,,, -369700,4.7939987,0.64948654,,,,,,,,,,,,,, -369800,4.641887,0.56543636,,,,,,,,,,,,,, -369900,4.3103747,0.6490163,,,,,,,,,,,,,, -370000,4.554194,0.60302234,,,,,,,,,,,,,, -370100,4.9651346,0.58671373,,,,,,,,,,,,,, -370200,4.359807,0.6136077,,,,,,,,,,,,,, -370300,5.193597,0.62443024,,,,,,,,,,,,,, -370400,5.30325,0.64616823,,,,,,,,,,,,,, -370408,,,0.9596619606018066,0.1495345383882522,0.7561599612236023,1.0537803173065186,50000.0,0.6269000172615051,1.842165470123291,10000.0,123994.45357394218,128257.9499156475,123994.45357394218,4235.262308835983,14.935875177383425,0.0 -370500,4.734705,0.64015275,,,,,,,,,,,,,, -370600,4.494413,0.65149903,,,,,,,,,,,,,, -370700,4.7493243,0.55868834,,,,,,,,,,,,,, -370800,4.461955,0.6098252,,,,,,,,,,,,,, -370900,4.3255954,0.6417287,,,,,,,,,,,,,, -371000,4.7091975,0.71429485,,,,,,,,,,,,,, -371100,4.934977,0.6729047,,,,,,,,,,,,,, -371200,4.490164,0.54829156,,,,,,,,,,,,,, -371300,4.9185834,0.65131974,,,,,,,,,,,,,, -371400,4.7678,0.6374922,,,,,,,,,,,,,, -371500,4.267775,0.52421767,,,,,,,,,,,,,, -371600,5.1319466,0.6786487,,,,,,,,,,,,,, -371700,5.0373945,0.65147525,,,,,,,,,,,,,, -371800,4.4762545,0.59093523,,,,,,,,,,,,,, -371900,4.5049067,0.6796789,,,,,,,,,,,,,, -371933,,,0.9612364172935486,0.1453320086002349,0.756060004234314,1.0531513690948486,50000.0,0.6274000406265259,1.841818809509277,10000.0,124504.32865929604,128784.52232336998,124504.32865929604,4251.81521987915,15.023420572280884,0.0 -372000,4.4058194,0.61801755,,,,,,,,,,,,,, -372100,4.233122,0.55002886,,,,,,,,,,,,,, -372200,4.644502,0.7036308,,,,,,,,,,,,,, -372300,4.1319056,0.6056924,,,,,,,,,,,,,, -372400,4.4895797,0.61421597,,,,,,,,,,,,,, -372500,5.026216,0.6675355,,,,,,,,,,,,,, -372600,4.7194405,0.64570934,,,,,,,,,,,,,, -372700,4.7676473,0.71594125,,,,,,,,,,,,,, -372800,4.726285,0.6435366,,,,,,,,,,,,,, -372900,5.379614,0.6297941,,,,,,,,,,,,,, -373000,5.0430846,0.6347373,,,,,,,,,,,,,, -373100,4.5195084,0.5994432,,,,,,,,,,,,,, -373200,4.507314,0.6092942,,,,,,,,,,,,,, -373300,4.2087784,0.6091007,,,,,,,,,,,,,, -373400,4.3757973,0.5112827,,,,,,,,,,,,,, -373458,,,0.9603196382522584,0.1460485756397247,0.756060004234314,1.054241180419922,50000.0,0.6274000406265259,1.8427166938781736,10000.0,125014.37292766573,129311.163510561,125014.37292766573,4268.265427350998,15.114747524261476,0.0 -373500,4.6051445,0.6471777,,,,,,,,,,,,,, -373600,5.1513753,0.7037379,,,,,,,,,,,,,, -373700,4.8705435,0.6358535,,,,,,,,,,,,,, -373800,4.367916,0.59436953,,,,,,,,,,,,,, -373900,4.624549,0.6558679,,,,,,,,,,,,,, -374000,5.275878,0.7258253,,,,,,,,,,,,,, -374100,4.4355626,0.6073169,,,,,,,,,,,,,, -374200,4.6671166,0.59789747,,,,,,,,,,,,,, -374300,5.202804,0.6652226,,,,,,,,,,,,,, -374400,4.5504193,0.5913762,,,,,,,,,,,,,, -374500,4.6144977,0.6995575,,,,,,,,,,,,,, -374600,4.929761,0.6164821,,,,,,,,,,,,,, -374700,4.5366297,0.6523293,,,,,,,,,,,,,, -374800,4.224766,0.5925292,,,,,,,,,,,,,, -374900,5.4439073,0.69955355,,,,,,,,,,,,,, -374982,,,0.961136758327484,0.1452155411243438,0.7561799883842468,1.052392840385437,50000.0,0.6284000277519226,1.8403536081314087,10000.0,125524.3220334053,129838.06326699255,125524.3220334053,4285.071767568588,15.203657388687134,0.0 -375000,5.3364773,0.6388816,,,,,,,,,,,,,, -375100,4.4653873,0.6303929,,,,,,,,,,,,,, -375200,4.524828,0.65869457,,,,,,,,,,,,,, -375300,4.269465,0.6582655,,,,,,,,,,,,,, -375400,4.774466,0.5978791,,,,,,,,,,,,,, -375500,4.4848585,0.6891537,,,,,,,,,,,,,, -375600,4.937458,0.7176263,,,,,,,,,,,,,, -375700,4.591949,0.6038764,,,,,,,,,,,,,, -375800,4.45897,0.5914264,,,,,,,,,,,,,, -375900,4.569869,0.6501758,,,,,,,,,,,,,, -376000,4.472265,0.61526144,,,,,,,,,,,,,, -376100,4.237666,0.63746685,,,,,,,,,,,,,, -376200,5.9369335,0.61401874,,,,,,,,,,,,,, -376300,4.414469,0.61861074,,,,,,,,,,,,,, -376400,4.5662475,0.56963336,,,,,,,,,,,,,, -376500,4.538837,0.63542414,,,,,,,,,,,,,, -376506,,,0.9604990482330322,0.1470222175121307,0.756339967250824,1.051685094833374,50000.0,0.6279000043869019,1.839387059211731,10000.0,126034.21823072432,130364.93006849287,126034.21823072432,4301.894116163254,15.297083377838137,0.0 -376600,4.963301,0.7027613,,,,,,,,,,,,,, -376700,4.271407,0.6012196,,,,,,,,,,,,,, -376800,4.5377016,0.5718601,,,,,,,,,,,,,, -376900,4.437039,0.5744441,,,,,,,,,,,,,, -377000,4.5325794,0.5976281,,,,,,,,,,,,,, -377100,4.3960247,0.60661966,,,,,,,,,,,,,, -377200,4.4470916,0.5689784,,,,,,,,,,,,,, -377300,4.446763,0.6004401,,,,,,,,,,,,,, -377400,4.2217336,0.5350914,,,,,,,,,,,,,, -377500,4.7421665,0.67927146,,,,,,,,,,,,,, -377600,4.9104457,0.61514705,,,,,,,,,,,,,, -377700,4.1810374,0.5497334,,,,,,,,,,,,,, -377800,4.69415,0.6151268,,,,,,,,,,,,,, -377900,4.745077,0.6044414,,,,,,,,,,,,,, -378000,4.862809,0.70782995,,,,,,,,,,,,,, -378031,,,0.9604392051696776,0.1461730599403381,0.7559999823570251,1.0533385276794434,50000.0,0.626800000667572,1.8416091203689573,10000.0,126544.26707220078,130891.7034125328,126544.26707220078,4318.468457937241,15.392176628112791,0.0 -378100,5.07349,0.56617737,,,,,,,,,,,,,, -378200,4.6115904,0.6366991,,,,,,,,,,,,,, -378300,4.8111563,0.6562048,,,,,,,,,,,,,, -378400,4.595291,0.6615926,,,,,,,,,,,,,, -378500,4.2763968,0.53414094,,,,,,,,,,,,,, -378600,4.434537,0.64662325,,,,,,,,,,,,,, -378700,4.5122824,0.608935,,,,,,,,,,,,,, -378800,4.570555,0.66977775,,,,,,,,,,,,,, -378900,5.1025205,0.6203991,,,,,,,,,,,,,, -379000,4.2380886,0.59024835,,,,,,,,,,,,,, -379100,3.9928696,0.57826936,,,,,,,,,,,,,, -379200,4.3335824,0.60898453,,,,,,,,,,,,,, -379300,4.1048055,0.5804056,,,,,,,,,,,,,, -379400,4.6396947,0.6361918,,,,,,,,,,,,,, -379500,4.047763,0.5570131,,,,,,,,,,,,,, -379555,,,0.9622528553009032,0.146114632487297,0.7559799551963806,1.0539250373840332,50000.0,0.626800000667572,1.84324848651886,10000.0,127054.2448322773,131418.3648777008,127054.2448322773,4335.007391452789,15.482544660568236,0.0 -379600,4.893714,0.6348964,,,,,,,,,,,,,, -379700,5.0514264,0.7395831,,,,,,,,,,,,,, -379800,4.91245,0.70266676,,,,,,,,,,,,,, -379900,4.6216598,0.637868,,,,,,,,,,,,,, -380000,4.7132344,0.60873526,,,,,,,,,,,,,, -380100,4.850449,0.6701118,,,,,,,,,,,,,, -380200,4.6650114,0.5911781,,,,,,,,,,,,,, -380300,4.4497457,0.67090154,,,,,,,,,,,,,, -380400,4.6637554,0.5952564,,,,,,,,,,,,,, -380500,4.2759233,0.5641481,,,,,,,,,,,,,, -380600,4.1793685,0.5974206,,,,,,,,,,,,,, -380700,5.6553993,0.67974436,,,,,,,,,,,,,, -380800,4.129677,0.6312752,,,,,,,,,,,,,, -380900,4.165393,0.6084823,,,,,,,,,,,,,, -381000,5.2079887,0.67285573,,,,,,,,,,,,,, -381080,,,0.9605189561843872,0.1476415544748306,0.7562800049781799,1.0527596473693848,50000.0,0.6282000541687012,1.8413126468658447,10000.0,127564.3393175602,131945.17438149452,127564.3393175602,4351.570056438446,15.580342531204224,0.0 -381100,4.755302,0.5647427,,,,,,,,,,,,,, -381200,4.8311296,0.6614523,,,,,,,,,,,,,, -381300,4.6864476,0.6244268,,,,,,,,,,,,,, -381400,4.6720643,0.6906994,,,,,,,,,,,,,, -381500,4.921909,0.7038113,,,,,,,,,,,,,, -381600,4.52992,0.60566366,,,,,,,,,,,,,, -381700,4.2422886,0.6015773,,,,,,,,,,,,,, -381800,4.769185,0.6708354,,,,,,,,,,,,,, -381900,4.484877,0.6658801,,,,,,,,,,,,,, -382000,4.87947,0.63377285,,,,,,,,,,,,,, -382100,4.3490233,0.6388315,,,,,,,,,,,,,, -382200,4.5318203,0.57636774,,,,,,,,,,,,,, -382300,4.006895,0.51163346,,,,,,,,,,,,,, -382400,4.4593253,0.6518216,,,,,,,,,,,,,, -382500,4.4525576,0.6424578,,,,,,,,,,,,,, -382600,4.8046107,0.70563865,,,,,,,,,,,,,, -382605,,,0.9604392051696776,0.1470810174942016,0.7559999823570251,1.0539182424545288,50000.0,0.6281000375747681,1.8423075675964355,10000.0,128074.41396832466,132472.1737203598,128074.41396832466,4368.340341567993,15.680543184280396,0.0 -382700,4.859433,0.65386856,,,,,,,,,,,,,, -382800,4.613312,0.68933976,,,,,,,,,,,,,, -382900,4.606376,0.6291009,,,,,,,,,,,,,, -383000,4.890321,0.6250304,,,,,,,,,,,,,, -383100,4.1122155,0.6089286,,,,,,,,,,,,,, -383200,4.610498,0.6376902,,,,,,,,,,,,,, -383300,4.294504,0.6267297,,,,,,,,,,,,,, -383400,4.663025,0.6257566,,,,,,,,,,,,,, -383500,4.335266,0.6139528,,,,,,,,,,,,,, -383600,4.686686,0.6718884,,,,,,,,,,,,,, -383700,5.3064632,0.6191869,,,,,,,,,,,,,, -383800,4.3366923,0.5729034,,,,,,,,,,,,,, -383900,4.90602,0.74793524,,,,,,,,,,,,,, -384000,4.543725,0.6279003,,,,,,,,,,,,,, -384100,4.294996,0.6361302,,,,,,,,,,,,,, -384130,,,0.9626514315605164,0.1441816985607147,0.7557799816131592,1.0531625747680664,50000.0,0.626800000667572,1.8423826694488523,10000.0,128584.54533720016,132998.9505019188,128584.54533720016,4384.839435815811,15.771271228790283,0.0 -384200,4.294458,0.53820485,,,,,,,,,,,,,, -384300,4.4001617,0.5909945,,,,,,,,,,,,,, -384400,4.86938,0.6382538,,,,,,,,,,,,,, -384500,4.485346,0.61882466,,,,,,,,,,,,,, -384600,4.2950287,0.5994375,,,,,,,,,,,,,, -384700,5.1237435,0.7229401,,,,,,,,,,,,,, -384800,4.468465,0.6094871,,,,,,,,,,,,,, -384900,4.5747266,0.65083796,,,,,,,,,,,,,, -385000,4.157121,0.6405848,,,,,,,,,,,,,, -385100,4.774449,0.6819928,,,,,,,,,,,,,, -385200,4.525742,0.6391913,,,,,,,,,,,,,, -385300,5.289426,0.6610926,,,,,,,,,,,,,, -385400,4.802059,0.6288741,,,,,,,,,,,,,, -385500,4.89622,0.62937546,,,,,,,,,,,,,, -385600,4.5279655,0.6019144,,,,,,,,,,,,,, -385654,,,0.9628706574440002,0.1429570764303207,0.7558000087738037,1.0541527271270752,50000.0,0.6274000406265259,1.84333074092865,10000.0,129094.4659075737,133525.52531456947,129094.4659075737,4401.344930171967,15.865436553955078,0.0 -385700,5.284791,0.5837804,,,,,,,,,,,,,, -385800,5.0472817,0.67359746,,,,,,,,,,,,,, -385900,4.3618865,0.5857278,,,,,,,,,,,,,, -386000,4.5620103,0.6201713,,,,,,,,,,,,,, -386100,4.841622,0.64516205,,,,,,,,,,,,,, -386200,4.570787,0.64285415,,,,,,,,,,,,,, -386300,4.47199,0.60964817,,,,,,,,,,,,,, -386400,4.449661,0.5779889,,,,,,,,,,,,,, -386500,4.8961363,0.6367835,,,,,,,,,,,,,, -386600,4.258984,0.5762513,,,,,,,,,,,,,, -386700,4.1271186,0.5803536,,,,,,,,,,,,,, -386800,4.711425,0.6510763,,,,,,,,,,,,,, -386900,4.6192355,0.6315368,,,,,,,,,,,,,, -387000,4.4869146,0.5782626,,,,,,,,,,,,,, -387100,4.605291,0.59885895,,,,,,,,,,,,,, -387179,,,0.9623923301696776,0.143512025475502,0.7556999921798706,1.053622841835022,50000.0,0.628000020980835,1.8408176898956297,10000.0,129604.40239357948,134052.86649990082,129604.40239357948,4418.605168104172,15.95547103881836,0.0 -387200,4.163128,0.5586188,,,,,,,,,,,,,, -387300,4.446133,0.621164,,,,,,,,,,,,,, -387400,4.16607,0.558805,,,,,,,,,,,,,, -387500,4.355813,0.6580829,,,,,,,,,,,,,, -387600,4.317009,0.5707589,,,,,,,,,,,,,, -387700,4.607883,0.61017776,,,,,,,,,,,,,, -387800,4.3962736,0.6600845,,,,,,,,,,,,,, -387900,4.5812764,0.55272603,,,,,,,,,,,,,, -388000,4.1637626,0.6329666,,,,,,,,,,,,,, -388100,4.797816,0.645648,,,,,,,,,,,,,, -388200,4.727269,0.60134536,,,,,,,,,,,,,, -388300,4.186329,0.61181563,,,,,,,,,,,,,, -388400,4.631975,0.56127065,,,,,,,,,,,,,, -388500,4.847202,0.61159515,,,,,,,,,,,,,, -388600,4.685605,0.57739615,,,,,,,,,,,,,, -388700,4.6559296,0.6148033,,,,,,,,,,,,,, -388704,,,0.959582269191742,0.14606274664402,0.7560999989509583,1.053932547569275,50000.0,0.6267000436782837,1.8424186706542969,10000.0,130114.50559592248,134579.7143881321,130114.50559592248,4435.200484991074,16.04998469352722,0.0 -388800,4.614242,0.6317114,,,,,,,,,,,,,, -388900,4.660522,0.61685234,,,,,,,,,,,,,, -389000,4.22918,0.60070646,,,,,,,,,,,,,, -389100,4.3360662,0.6611316,,,,,,,,,,,,,, -389200,4.490624,0.64327574,,,,,,,,,,,,,, -389300,4.420104,0.55889004,,,,,,,,,,,,,, -389400,5.0765505,0.5922036,,,,,,,,,,,,,, -389500,4.479797,0.64651835,,,,,,,,,,,,,, -389600,4.247595,0.5760437,,,,,,,,,,,,,, -389700,4.637978,0.62548125,,,,,,,,,,,,,, -389800,4.740063,0.65039825,,,,,,,,,,,,,, -389900,4.323768,0.57188874,,,,,,,,,,,,,, -390000,4.7068295,0.66700053,,,,,,,,,,,,,, -390100,4.3156238,0.5717961,,,,,,,,,,,,,, -390200,4.836597,0.69892913,,,,,,,,,,,,,, -390229,,,0.9609375,0.1475533097982406,0.7559999823570251,1.0535551309585571,50000.0,0.6267000436782837,1.843072533607483,10000.0,130624.4790096283,135106.39228892326,130624.4790096283,4451.753606081009,16.146687269210815,0.0 -390300,4.4582705,0.56058985,,,,,,,,,,,,,, -390400,4.706483,0.5939152,,,,,,,,,,,,,, -390500,4.975193,0.6606972,,,,,,,,,,,,,, -390600,4.6812167,0.7065346,,,,,,,,,,,,,, -390700,4.388336,0.5849318,,,,,,,,,,,,,, -390800,4.339349,0.5485289,,,,,,,,,,,,,, -390900,4.4621873,0.6725968,,,,,,,,,,,,,, -391000,4.560135,0.64055115,,,,,,,,,,,,,, -391100,4.307185,0.6479758,,,,,,,,,,,,,, -391200,5.2859573,0.6347821,,,,,,,,,,,,,, -391300,5.1133747,0.67247045,,,,,,,,,,,,,, -391400,4.6223874,0.6969708,,,,,,,,,,,,,, -391500,4.481446,0.6058425,,,,,,,,,,,,,, -391600,4.7240653,0.61150485,,,,,,,,,,,,,, -391700,4.3216968,0.5921908,,,,,,,,,,,,,, -391754,,,0.962511956691742,0.141088455915451,0.7557799816131592,1.0546152591705322,50000.0,0.6271000504493713,1.8442800045013428,10000.0,131134.4970755577,135633.17174863815,131134.4970755577,4468.369817733765,16.23644709587097,0.0 -391800,4.3522487,0.66300166,,,,,,,,,,,,,, -391900,4.772311,0.642238,,,,,,,,,,,,,, -392000,4.789434,0.69277513,,,,,,,,,,,,,, -392100,4.344675,0.5742065,,,,,,,,,,,,,, -392200,4.3115616,0.51913613,,,,,,,,,,,,,, -392300,4.5614457,0.62010604,,,,,,,,,,,,,, -392400,4.5149646,0.6433862,,,,,,,,,,,,,, -392500,4.856274,0.63824403,,,,,,,,,,,,,, -392600,4.1991224,0.5658716,,,,,,,,,,,,,, -392700,4.5043693,0.58493406,,,,,,,,,,,,,, -392800,4.4359083,0.6276027,,,,,,,,,,,,,, -392900,4.6736875,0.6648737,,,,,,,,,,,,,, -393000,4.9423356,0.7120233,,,,,,,,,,,,,, -393100,4.3669624,0.6451704,,,,,,,,,,,,,, -393200,5.2660933,0.6560161,,,,,,,,,,,,,, -393279,,,0.960957407951355,0.145300954580307,0.7558799982070923,1.0527784824371338,50000.0,0.626800000667572,1.84110689163208,10000.0,131644.5588364601,136159.93689537048,131644.5588364601,4484.873900651932,16.379529237747192,0.0 -393300,4.224614,0.6249531,,,,,,,,,,,,,, -393400,4.839012,0.65649664,,,,,,,,,,,,,, -393500,4.3382435,0.6173708,,,,,,,,,,,,,, -393600,4.6516795,0.6471585,,,,,,,,,,,,,, -393700,4.2562637,0.5663575,,,,,,,,,,,,,, -393800,4.9490232,0.5967469,,,,,,,,,,,,,, -393900,4.2314987,0.5435777,,,,,,,,,,,,,, -394000,4.6409087,0.6651332,,,,,,,,,,,,,, -394100,4.244904,0.58774173,,,,,,,,,,,,,, -394200,4.3654666,0.64899,,,,,,,,,,,,,, -394300,4.5261993,0.6402139,,,,,,,,,,,,,, -394400,4.5889416,0.6489229,,,,,,,,,,,,,, -394500,4.987117,0.6771895,,,,,,,,,,,,,, -394600,4.4318686,0.6220155,,,,,,,,,,,,,, -394700,4.6265473,0.6469878,,,,,,,,,,,,,, -394800,4.5867205,0.6156324,,,,,,,,,,,,,, -394804,,,0.9612962007522584,0.1456609070301056,0.7558199763298035,1.0530370473861694,50000.0,0.626800000667572,1.8425297737121584,10000.0,132154.6266951561,136686.70456910133,132154.6266951561,4501.413993358612,16.484485864639282,0.0 -394900,4.2787757,0.55676657,,,,,,,,,,,,,, -395000,4.7575955,0.61533403,,,,,,,,,,,,,, -395100,4.868369,0.60918164,,,,,,,,,,,,,, -395200,5.269859,0.6302693,,,,,,,,,,,,,, -395300,4.820743,0.5959239,,,,,,,,,,,,,, -395400,4.608847,0.62924516,,,,,,,,,,,,,, -395500,4.5329227,0.6724152,,,,,,,,,,,,,, -395600,4.94033,0.6718744,,,,,,,,,,,,,, -395700,4.5037336,0.544772,,,,,,,,,,,,,, -395800,4.187406,0.62914294,,,,,,,,,,,,,, -395900,4.217268,0.5372906,,,,,,,,,,,,,, -396000,4.3405147,0.59384114,,,,,,,,,,,,,, -396100,4.3476686,0.6245761,,,,,,,,,,,,,, -396200,5.1680446,0.66577435,,,,,,,,,,,,,, -396300,4.32628,0.6309395,,,,,,,,,,,,,, -396328,,,0.9614955186843872,0.1447633355855941,0.7559199929237366,1.0543824434280396,50000.0,0.6260000467300415,1.84261167049408,10000.0,132664.5911166668,137213.3666226864,132664.5911166668,4517.964068174362,16.57689118385315,0.0 -396400,4.746836,0.6348984,,,,,,,,,,,,,, -396500,4.590897,0.58669424,,,,,,,,,,,,,, -396600,4.2749667,0.6153916,,,,,,,,,,,,,, -396700,4.0807915,0.5980109,,,,,,,,,,,,,, -396800,4.584858,0.5783884,,,,,,,,,,,,,, -396900,4.031484,0.53239,,,,,,,,,,,,,, -397000,4.3296156,0.6548803,,,,,,,,,,,,,, -397100,5.3792205,0.6590669,,,,,,,,,,,,,, -397200,4.5888357,0.64850795,,,,,,,,,,,,,, -397300,4.551784,0.62995654,,,,,,,,,,,,,, -397400,4.7102222,0.6457459,,,,,,,,,,,,,, -397500,5.046377,0.69644266,,,,,,,,,,,,,, -397600,4.2635374,0.6126768,,,,,,,,,,,,,, -397700,4.4332304,0.5894612,,,,,,,,,,,,,, -397800,4.300905,0.5825961,,,,,,,,,,,,,, -397853,,,0.960558831691742,0.1463161557912826,0.7558199763298035,1.0536457300186155,50000.0,0.626800000667572,1.8429162502288816,10000.0,133174.60023641586,137740.03743171692,133174.60023641586,4534.476091146469,16.671013355255127,0.0 -397900,4.4314094,0.5833553,,,,,,,,,,,,,, -398000,4.678286,0.58510333,,,,,,,,,,,,,, -398100,4.634236,0.66838324,,,,,,,,,,,,,, -398200,4.858269,0.66020346,,,,,,,,,,,,,, -398300,5.216693,0.6422256,,,,,,,,,,,,,, -398400,5.044062,0.661703,,,,,,,,,,,,,, -398500,4.4122577,0.5908829,,,,,,,,,,,,,, -398600,4.330435,0.6529561,,,,,,,,,,,,,, -398700,4.6775494,0.62322795,,,,,,,,,,,,,, -398800,4.228154,0.5929808,,,,,,,,,,,,,, -398900,4.717573,0.6309696,,,,,,,,,,,,,, -399000,4.679638,0.65092254,,,,,,,,,,,,,, -399100,4.619471,0.62308323,,,,,,,,,,,,,, -399200,4.544109,0.661158,,,,,,,,,,,,,, -399300,4.750721,0.6177571,,,,,,,,,,,,,, -399378,,,0.9625717401504515,0.1407064348459243,0.7558599710464478,1.053896427154541,50000.0,0.627500057220459,1.8440628051757808,10000.0,133684.56909918785,138266.5942044258,133684.56909918785,4550.910442829132,16.769649982452393,0.0 -399400,4.7569065,0.6480768,,,,,,,,,,,,,, -399500,5.264253,0.5848545,,,,,,,,,,,,,, -399600,4.8558807,0.65325344,,,,,,,,,,,,,, -399700,4.6845584,0.6614884,,,,,,,,,,,,,, -399800,4.75021,0.72267616,,,,,,,,,,,,,, -399900,5.9304852,0.63714594,,,,,,,,,,,,,, -400000,5.7956743,0.68130267,,,,,,,,,,,,,, -400100,4.528327,0.5800484,,,,,,,,,,,,,, -400200,4.8496113,0.6171241,,,,,,,,,,,,,, -400300,4.400009,0.6455011,,,,,,,,,,,,,, -400400,4.994782,0.6873103,,,,,,,,,,,,,, -400500,4.4382496,0.63885117,,,,,,,,,,,,,, -400600,4.4817796,0.5892385,,,,,,,,,,,,,, -400700,4.364994,0.5528294,,,,,,,,,,,,,, -400800,4.6420803,0.63484234,,,,,,,,,,,,,, -400900,4.3440914,0.5727992,,,,,,,,,,,,,, -400903,,,0.9602399468421936,0.1468479335308075,0.7557199597358704,1.0537818670272827,50000.0,0.6270000338554382,1.8432790040969849,10000.0,134194.7005777359,138793.39784646034,134194.7005777359,4567.4318034648895,16.86515712738037,0.0 -401000,4.622767,0.6822197,,,,,,,,,,,,,, -401100,4.7915797,0.6698351,,,,,,,,,,,,,, -401200,4.9365644,0.6730305,,,,,,,,,,,,,, -401300,4.686367,0.60096073,,,,,,,,,,,,,, -401400,4.6793785,0.6358323,,,,,,,,,,,,,, -401500,5.2283607,0.6161854,,,,,,,,,,,,,, -401600,4.694048,0.59079254,,,,,,,,,,,,,, -401700,4.4691534,0.64042187,,,,,,,,,,,,,, -401800,4.2098465,0.5933615,,,,,,,,,,,,,, -401900,4.860131,0.6726862,,,,,,,,,,,,,, -402000,5.1596813,0.5761758,,,,,,,,,,,,,, -402100,4.3741107,0.6368756,,,,,,,,,,,,,, -402200,5.0761776,0.6755959,,,,,,,,,,,,,, -402300,4.6340647,0.5799416,,,,,,,,,,,,,, -402400,5.1291585,0.62495685,,,,,,,,,,,,,, -402428,,,0.9604392051696776,0.1478034555912017,0.7559399604797363,1.053891658782959,50000.0,0.6276000142097473,1.842927098274231,10000.0,134704.63186764717,139320.47827506065,134704.63186764717,4584.428512334824,16.963099002838135,0.0 -402500,4.4119034,0.6693549,,,,,,,,,,,,,, -402600,4.291568,0.603702,,,,,,,,,,,,,, -402700,5.0078106,0.67277074,,,,,,,,,,,,,, -402800,4.57624,0.6552582,,,,,,,,,,,,,, -402900,4.9063125,0.61834663,,,,,,,,,,,,,, -403000,4.641386,0.7018486,,,,,,,,,,,,,, -403100,4.1755915,0.6200383,,,,,,,,,,,,,, -403200,4.7371826,0.58113605,,,,,,,,,,,,,, -403300,4.5911665,0.6800363,,,,,,,,,,,,,, -403400,4.6052194,0.665705,,,,,,,,,,,,,, -403500,4.6871877,0.61759335,,,,,,,,,,,,,, -403600,4.6216288,0.5577509,,,,,,,,,,,,,, -403700,4.515743,0.66541296,,,,,,,,,,,,,, -403800,4.60238,0.67610306,,,,,,,,,,,,,, -403900,4.4252734,0.5871929,,,,,,,,,,,,,, -403953,,,0.961336076259613,0.1461579501628875,0.756119966506958,1.0521721839904783,50000.0,0.6271000504493713,1.839210629463196,10000.0,135214.64415454865,139847.28141379356,135214.64415454865,4601.067002296448,17.061861038208008,0.0 -404000,4.729621,0.6510471,,,,,,,,,,,,,, -404100,5.032044,0.6305551,,,,,,,,,,,,,, -404200,4.654224,0.62118995,,,,,,,,,,,,,, -404300,4.4230747,0.6799193,,,,,,,,,,,,,, -404400,4.455571,0.64623034,,,,,,,,,,,,,, -404500,4.7345595,0.60750103,,,,,,,,,,,,,, -404600,4.925088,0.6863076,,,,,,,,,,,,,, -404700,4.2591786,0.5382761,,,,,,,,,,,,,, -404800,4.6892996,0.6360494,,,,,,,,,,,,,, -404900,5.933216,0.683077,,,,,,,,,,,,,, -405000,4.538146,0.6227972,,,,,,,,,,,,,, -405100,4.4209046,0.6522627,,,,,,,,,,,,,, -405200,4.44455,0.55298847,,,,,,,,,,,,,, -405300,4.458955,0.6158901,,,,,,,,,,,,,, -405400,4.3682914,0.5969825,,,,,,,,,,,,,, -405478,,,0.9627909660339355,0.1421767473220825,0.756119966506958,1.0536324977874756,50000.0,0.6273000240325928,1.84215247631073,10000.0,135724.62863755226,140373.97967410088,135724.62863755226,4617.628590106964,17.158963441848755,0.0 -405500,4.620757,0.64745337,,,,,,,,,,,,,, -405600,4.3850036,0.63023764,,,,,,,,,,,,,, -405700,4.298974,0.59224033,,,,,,,,,,,,,, -405800,4.441116,0.59045744,,,,,,,,,,,,,, -405900,5.2184772,0.57192075,,,,,,,,,,,,,, -406000,4.178902,0.56957316,,,,,,,,,,,,,, -406100,4.9327073,0.5712701,,,,,,,,,,,,,, -406200,4.3116403,0.6228913,,,,,,,,,,,,,, -406300,4.6195507,0.6068976,,,,,,,,,,,,,, -406400,4.633013,0.686225,,,,,,,,,,,,,, -406500,4.102326,0.5663711,,,,,,,,,,,,,, -406600,4.169232,0.6736542,,,,,,,,,,,,,, -406700,5.0643106,0.6715395,,,,,,,,,,,,,, -406800,4.4267945,0.6566409,,,,,,,,,,,,,, -406900,4.39809,0.6206911,,,,,,,,,,,,,, -407000,4.5956697,0.61443913,,,,,,,,,,,,,, -407002,,,0.9618144035339355,0.1452556103467941,0.7559999823570251,1.0538339614868164,50000.0,0.6276000142097473,1.8435428142547607,10000.0,136234.4849574566,140900.5523967743,136234.4849574566,4634.187923192978,17.261436939239502,0.0 -407100,4.7437654,0.64562327,,,,,,,,,,,,,, -407200,4.647052,0.6658292,,,,,,,,,,,,,, -407300,4.3422236,0.6634902,,,,,,,,,,,,,, -407400,4.2612543,0.5917662,,,,,,,,,,,,,, -407500,4.4125524,0.59374756,,,,,,,,,,,,,, -407600,4.362571,0.62894183,,,,,,,,,,,,,, -407700,4.391084,0.619883,,,,,,,,,,,,,, -407800,4.718494,0.58856165,,,,,,,,,,,,,, -407900,4.9980736,0.67770857,,,,,,,,,,,,,, -408000,5.1609855,0.69601846,,,,,,,,,,,,,, -408100,4.498583,0.6338179,,,,,,,,,,,,,, -408200,4.489147,0.64031696,,,,,,,,,,,,,, -408300,4.554712,0.559798,,,,,,,,,,,,,, -408400,5.069632,0.5788724,,,,,,,,,,,,,, -408500,4.747878,0.65004,,,,,,,,,,,,,, -408527,,,0.9584860801696776,0.1510231494903564,0.7560399770736694,1.0538654327392578,50000.0,0.6272000074386597,1.842506766319275,10000.0,136744.42813515663,141427.18142938614,136744.42813515663,4650.723289728165,17.356701374053955,0.0 -408600,4.470951,0.72139615,,,,,,,,,,,,,, -408700,4.507941,0.66558564,,,,,,,,,,,,,, -408800,4.2945385,0.66820455,,,,,,,,,,,,,, -408900,4.354513,0.60090023,,,,,,,,,,,,,, -409000,4.9105167,0.6348066,,,,,,,,,,,,,, -409100,4.6102753,0.6594156,,,,,,,,,,,,,, -409200,5.147737,0.60055816,,,,,,,,,,,,,, -409300,4.6519647,0.6049172,,,,,,,,,,,,,, -409400,4.4102006,0.65493166,,,,,,,,,,,,,, -409500,4.9204655,0.60941064,,,,,,,,,,,,,, -409600,4.8571267,0.6583173,,,,,,,,,,,,,, -409700,4.1973944,0.56401837,,,,,,,,,,,,,, -409800,4.239894,0.5519592,,,,,,,,,,,,,, -409900,4.73809,0.72051984,,,,,,,,,,,,,, -410000,4.462313,0.5900214,,,,,,,,,,,,,, -410052,,,0.9606385231018066,0.1461514383554458,0.7562400102615356,1.053558349609375,50000.0,0.6276000142097473,1.843324899673462,10000.0,137254.58539009094,141954.05271077156,137254.58539009094,4667.281981468201,17.456937551498413,0.0 -410100,4.517074,0.6400081,,,,,,,,,,,,,, -410200,5.4335856,0.74417394,,,,,,,,,,,,,, -410300,4.371431,0.65350163,,,,,,,,,,,,,, -410400,4.691273,0.648659,,,,,,,,,,,,,, -410500,4.221381,0.57848704,,,,,,,,,,,,,, -410600,4.5073853,0.6987788,,,,,,,,,,,,,, -410700,4.2302623,0.5960101,,,,,,,,,,,,,, -410800,4.472657,0.573034,,,,,,,,,,,,,, -410900,4.603691,0.572567,,,,,,,,,,,,,, -411000,5.227068,0.68679804,,,,,,,,,,,,,, -411100,4.539701,0.6088723,,,,,,,,,,,,,, -411200,4.4185348,0.58644223,,,,,,,,,,,,,, -411300,4.886457,0.6920631,,,,,,,,,,,,,, -411400,4.603728,0.6599175,,,,,,,,,,,,,, -411500,4.7574887,0.607257,,,,,,,,,,,,,, -411576,,,0.9613958597183228,0.1457064151763916,0.7561399936676025,1.0536881685256958,50000.0,0.6277000308036804,1.8429654836654663,10000.0,137764.53922367096,142480.7475554943,137764.53922367096,4683.87499833107,17.549427270889282,0.0 -411600,4.376093,0.6630196,,,,,,,,,,,,,, -411700,4.110746,0.6076139,,,,,,,,,,,,,, -411800,5.0482903,0.60440797,,,,,,,,,,,,,, -411900,4.8373365,0.624116,,,,,,,,,,,,,, -412000,4.596595,0.6544603,,,,,,,,,,,,,, -412100,4.714539,0.6783198,,,,,,,,,,,,,, -412200,4.483939,0.678851,,,,,,,,,,,,,, -412300,4.2107353,0.57276046,,,,,,,,,,,,,, -412400,4.305084,0.634815,,,,,,,,,,,,,, -412500,4.978358,0.66910917,,,,,,,,,,,,,, -412600,4.3735723,0.6267759,,,,,,,,,,,,,, -412700,4.7166915,0.6809192,,,,,,,,,,,,,, -412800,4.868145,0.60400915,,,,,,,,,,,,,, -412900,5.087831,0.66679496,,,,,,,,,,,,,, -413000,4.4521236,0.5970567,,,,,,,,,,,,,, -413100,,,0.9624919891357422,0.1417132914066314,0.7558000087738037,1.054227352142334,50000.0,0.6271000504493713,1.843506932258606,10000.0,138274.40833592415,143007.41964292526,138274.40833592415,4700.526404619217,17.64485478401184,0.0 -413100,4.254164,0.56608737,,,,,,,,,,,,,, -413200,4.2098265,0.6005903,,,,,,,,,,,,,, -413300,4.3387976,0.65466595,,,,,,,,,,,,,, -413400,4.568479,0.6283982,,,,,,,,,,,,,, -413500,5.017947,0.7001427,,,,,,,,,,,,,, -413600,4.743122,0.6216303,,,,,,,,,,,,,, -413700,4.178593,0.60054266,,,,,,,,,,,,,, -413800,4.3898535,0.64790684,,,,,,,,,,,,,, -413900,4.441518,0.6893931,,,,,,,,,,,,,, -414000,4.687839,0.601827,,,,,,,,,,,,,, -414100,4.674128,0.65598154,,,,,,,,,,,,,, -414200,4.4674873,0.5997714,,,,,,,,,,,,,, -414300,4.5406523,0.544796,,,,,,,,,,,,,, -414400,5.023528,0.65496016,,,,,,,,,,,,,, -414500,4.7196217,0.63615066,,,,,,,,,,,,,, -414600,4.5776615,0.58422256,,,,,,,,,,,,,, -414625,,,0.9585259556770324,0.1491367220878601,0.7554399967193604,1.0536048412322998,50000.0,0.6271000504493713,1.843660473823548,10000.0,138784.3907828331,143534.22866630554,138784.3907828331,4717.203856706619,17.738426208496094,0.0 -414700,4.2331796,0.6436284,,,,,,,,,,,,,, -414800,4.410841,0.65758324,,,,,,,,,,,,,, -414900,4.381276,0.62527686,,,,,,,,,,,,,, -415000,4.4350486,0.5626104,,,,,,,,,,,,,, -415100,4.160778,0.6056225,,,,,,,,,,,,,, -415200,3.8804984,0.5474038,,,,,,,,,,,,,, -415300,4.9348903,0.6430699,,,,,,,,,,,,,, -415400,4.3475137,0.5325743,,,,,,,,,,,,,, -415500,4.454976,0.601827,,,,,,,,,,,,,, -415600,5.654922,0.6567897,,,,,,,,,,,,,, -415700,4.55758,0.5570304,,,,,,,,,,,,,, -415800,4.418561,0.5745663,,,,,,,,,,,,,, -415900,5.0881195,0.62365407,,,,,,,,,,,,,, -416000,4.7663956,0.6227924,,,,,,,,,,,,,, -416100,4.139088,0.53820205,,,,,,,,,,,,,, -416150,,,0.9607780575752258,0.1463647484779358,0.7557199597358704,1.052968978881836,50000.0,0.6269000172615051,1.8419307470321653,10000.0,139294.40902137756,144060.8727862835,139294.40902137756,4733.677309751511,17.83676266670227,0.0 -416200,4.4809303,0.6268796,,,,,,,,,,,,,, -416300,4.463036,0.69613767,,,,,,,,,,,,,, -416400,4.329256,0.5149255,,,,,,,,,,,,,, -416500,4.665633,0.69136316,,,,,,,,,,,,,, -416600,4.36158,0.69827926,,,,,,,,,,,,,, -416700,4.444125,0.56913733,,,,,,,,,,,,,, -416800,4.0867214,0.54007107,,,,,,,,,,,,,, -416900,4.409653,0.6318919,,,,,,,,,,,,,, -417000,4.1396437,0.5634303,,,,,,,,,,,,,, -417100,4.5496125,0.6440508,,,,,,,,,,,,,, -417200,4.22109,0.6046731,,,,,,,,,,,,,, -417300,5.384428,0.6165914,,,,,,,,,,,,,, -417400,4.544859,0.6385131,,,,,,,,,,,,,, -417500,5.0234094,0.60674155,,,,,,,,,,,,,, -417600,4.5082493,0.67772394,,,,,,,,,,,,,, -417675,,,0.9611766338348388,0.1453454792499542,0.7560399770736694,1.053350567817688,50000.0,0.6277000308036804,1.8424909114837649,10000.0,139804.53283405304,144587.82571840286,139804.53283405304,4750.356118917465,17.932732820510864,0.0 -417700,4.4343038,0.6688932,,,,,,,,,,,,,, -417800,4.618514,0.60854983,,,,,,,,,,,,,, -417900,4.2415104,0.5591085,,,,,,,,,,,,,, -418000,5.3966646,0.6157758,,,,,,,,,,,,,, -418100,4.835492,0.6506851,,,,,,,,,,,,,, -418200,4.382182,0.57493794,,,,,,,,,,,,,, -418300,4.3511734,0.60895646,,,,,,,,,,,,,, -418400,4.6924996,0.6690457,,,,,,,,,,,,,, -418500,5.1010823,0.6913465,,,,,,,,,,,,,, -418600,4.632395,0.5995457,,,,,,,,,,,,,, -418700,4.7241626,0.568545,,,,,,,,,,,,,, -418800,4.5268636,0.6791159,,,,,,,,,,,,,, -418900,4.731163,0.6007519,,,,,,,,,,,,,, -419000,4.548099,0.5759154,,,,,,,,,,,,,, -419100,4.937176,0.6654259,,,,,,,,,,,,,, -419200,,,0.9610171914100648,0.148466870188713,0.7556799650192261,1.053810477256775,50000.0,0.6277000308036804,1.8426084518432613,10000.0,140314.65873599052,145114.72708678246,140314.65873599052,4766.973588705063,18.03618574142456,0.0 -419200,4.543442,0.6167116,,,,,,,,,,,,,, -419300,4.9248276,0.631954,,,,,,,,,,,,,, -419400,4.4795585,0.6762295,,,,,,,,,,,,,, -419500,4.809747,0.6333758,,,,,,,,,,,,,, -419600,4.2138395,0.60818386,,,,,,,,,,,,,, -419700,4.292684,0.60119253,,,,,,,,,,,,,, -419800,4.5838323,0.6428998,,,,,,,,,,,,,, -419900,4.6487536,0.6515496,,,,,,,,,,,,,, -420000,4.579339,0.6232911,,,,,,,,,,,,,, -420100,4.6049066,0.63935435,,,,,,,,,,,,,, -420200,4.921886,0.6597763,,,,,,,,,,,,,, -420300,4.2879553,0.59005153,,,,,,,,,,,,,, -420400,5.0355363,0.70914483,,,,,,,,,,,,,, -420500,4.889178,0.59132624,,,,,,,,,,,,,, -420600,4.310732,0.62667215,,,,,,,,,,,,,, -420700,4.7785263,0.6660126,,,,,,,,,,,,,, -420724,,,0.9606584906578064,0.1469320952892303,0.7560399770736694,1.0542471408843994,50000.0,0.626300036907196,1.843507647514344,10000.0,140824.52908945084,145641.24696421623,140824.52908945084,4783.490590810776,18.113587141036987,0.0 -420800,4.5192666,0.6425535,,,,,,,,,,,,,, -420900,4.515926,0.64493537,,,,,,,,,,,,,, -421000,4.971899,0.65571725,,,,,,,,,,,,,, -421100,4.5189505,0.6056071,,,,,,,,,,,,,, -421200,3.9603035,0.56149143,,,,,,,,,,,,,, -421300,4.413006,0.6184279,,,,,,,,,,,,,, -421400,4.5312715,0.5874444,,,,,,,,,,,,,, -421500,4.582631,0.5945538,,,,,,,,,,,,,, -421600,4.333894,0.57735026,,,,,,,,,,,,,, -421700,4.469681,0.6144267,,,,,,,,,,,,,, -421800,4.5928073,0.6339379,,,,,,,,,,,,,, -421900,4.864942,0.64938307,,,,,,,,,,,,,, -422000,4.8461747,0.6188678,,,,,,,,,,,,,, -422100,5.1162767,0.7263046,,,,,,,,,,,,,, -422200,4.2619815,0.635921,,,,,,,,,,,,,, -422249,,,0.9617944359779358,0.1449295580387115,0.7558000087738037,1.0536588430404663,50000.0,0.6279000043869019,1.8419170379638672,10000.0,141334.59196543694,146168.13632249832,141334.59196543694,4800.164356708527,18.211303234100345,0.0 -422300,4.5691686,0.6132636,,,,,,,,,,,,,, -422400,4.774273,0.6666524,,,,,,,,,,,,,, -422500,4.605125,0.57784593,,,,,,,,,,,,,, -422600,4.4173074,0.55765927,,,,,,,,,,,,,, -422700,4.670621,0.6207554,,,,,,,,,,,,,, -422800,4.08223,0.59024906,,,,,,,,,,,,,, -422900,4.413762,0.5947779,,,,,,,,,,,,,, -423000,4.5550895,0.62972605,,,,,,,,,,,,,, -423100,4.8070817,0.61682826,,,,,,,,,,,,,, -423200,4.565702,0.7201559,,,,,,,,,,,,,, -423300,4.6735716,0.63121873,,,,,,,,,,,,,, -423400,4.4329176,0.60401803,,,,,,,,,,,,,, -423500,4.8230186,0.66240495,,,,,,,,,,,,,, -423600,4.8795004,0.774312,,,,,,,,,,,,,, -423700,4.7781353,0.6362083,,,,,,,,,,,,,, -423774,,,0.9623525142669678,0.1428000479936599,0.7557599544525146,1.0537769794464111,50000.0,0.628000020980835,1.8425239324569704,10000.0,141844.57013821602,146694.98079276085,141844.57013821602,4816.875243186951,18.31092357635498,0.0 -423800,4.357248,0.62936103,,,,,,,,,,,,,, -423900,4.5014987,0.63486683,,,,,,,,,,,,,, -424000,4.678394,0.6164177,,,,,,,,,,,,,, -424100,4.5093107,0.6317677,,,,,,,,,,,,,, -424200,4.518099,0.6363466,,,,,,,,,,,,,, -424300,4.347304,0.60978323,,,,,,,,,,,,,, -424400,4.687145,0.61558896,,,,,,,,,,,,,, -424500,4.692449,0.6002393,,,,,,,,,,,,,, -424600,4.4979105,0.62376046,,,,,,,,,,,,,, -424700,4.2081466,0.6372811,,,,,,,,,,,,,, -424800,4.418576,0.71119887,,,,,,,,,,,,,, -424900,4.744467,0.5708199,,,,,,,,,,,,,, -425000,4.687252,0.66136014,,,,,,,,,,,,,, -425100,4.383711,0.62202907,,,,,,,,,,,,,, -425200,4.444802,0.61644423,,,,,,,,,,,,,, -425299,,,0.9629504084587096,0.1425699740648269,0.7560799717903137,1.053001046180725,50000.0,0.6271000504493713,1.8420350551605225,10000.0,142354.64644885063,147221.78683519363,142354.64644885063,4833.446505069733,18.412598848342896,0.0 -425300,4.3683057,0.6631397,,,,,,,,,,,,,, -425400,4.0018415,0.583867,,,,,,,,,,,,,, -425500,5.2090716,0.6382113,,,,,,,,,,,,,, -425600,4.5493226,0.69304615,,,,,,,,,,,,,, -425700,4.8322306,0.69302285,,,,,,,,,,,,,, -425800,4.743022,0.62734103,,,,,,,,,,,,,, -425900,4.5515633,0.63579315,,,,,,,,,,,,,, -426000,4.7470546,0.59952456,,,,,,,,,,,,,, -426100,4.3630533,0.56651056,,,,,,,,,,,,,, -426200,4.7162814,0.6792013,,,,,,,,,,,,,, -426300,5.01875,0.67393327,,,,,,,,,,,,,, -426400,4.5109415,0.59664226,,,,,,,,,,,,,, -426500,4.6334605,0.6731119,,,,,,,,,,,,,, -426600,4.612125,0.7157286,,,,,,,,,,,,,, -426700,4.6629944,0.6586204,,,,,,,,,,,,,, -426800,4.6966968,0.6090911,,,,,,,,,,,,,, -426824,,,0.9593231678009032,0.1488415002822876,0.7558799982070923,1.0532394647598269,50000.0,0.627500057220459,1.840661644935608,10000.0,142864.6662731171,147749.3140487671,142864.6662731171,4850.79785990715,18.51312065124512,0.0 -426900,4.7723145,0.5666015,,,,,,,,,,,,,, -427000,4.4531684,0.5817789,,,,,,,,,,,,,, -427100,4.435735,0.5831469,,,,,,,,,,,,,, -427200,4.262737,0.60720766,,,,,,,,,,,,,, -427300,4.5986853,0.6500027,,,,,,,,,,,,,, -427400,4.1600747,0.60898113,,,,,,,,,,,,,, -427500,4.7531815,0.61458945,,,,,,,,,,,,,, -427600,4.507546,0.5829843,,,,,,,,,,,,,, -427700,5.356961,0.6395339,,,,,,,,,,,,,, -427800,4.7398868,0.6574441,,,,,,,,,,,,,, -427900,4.3624177,0.6824819,,,,,,,,,,,,,, -428000,4.7993383,0.6964421,,,,,,,,,,,,,, -428100,4.704367,0.6095327,,,,,,,,,,,,,, -428200,4.4412465,0.53082806,,,,,,,,,,,,,, -428300,5.1349945,0.62591124,,,,,,,,,,,,,, -428349,,,0.9606783986091614,0.1454301327466964,0.7557399868965149,1.0528854131698608,50000.0,0.6274000406265259,1.8414006233215328,10000.0,143374.65701889992,148276.11263155937,143374.65701889992,4867.446770191193,18.61627769470215,0.0 -428400,4.4648557,0.6559587,,,,,,,,,,,,,, -428500,4.803275,0.71120477,,,,,,,,,,,,,, -428600,4.189527,0.6248584,,,,,,,,,,,,,, -428700,4.595642,0.6884554,,,,,,,,,,,,,, -428800,4.575704,0.6319764,,,,,,,,,,,,,, -428900,4.967797,0.66336864,,,,,,,,,,,,,, -429000,4.7368374,0.6038978,,,,,,,,,,,,,, -429100,4.597162,0.62332964,,,,,,,,,,,,,, -429200,4.7677407,0.6751844,,,,,,,,,,,,,, -429300,4.7886777,0.610882,,,,,,,,,,,,,, -429400,4.666062,0.66231745,,,,,,,,,,,,,, -429500,4.907286,0.7119634,,,,,,,,,,,,,, -429600,4.555275,0.6321975,,,,,,,,,,,,,, -429700,4.235726,0.60504264,,,,,,,,,,,,,, -429800,4.6688166,0.6541659,,,,,,,,,,,,,, -429873,,,0.9620535373687744,0.144050195813179,0.7560999989509583,1.053205966949463,50000.0,0.6272000074386597,1.8428633213043213,10000.0,143884.65881085396,148802.81915712357,143884.65881085396,4883.9972012043,18.715560913085938,0.0 -429900,5.458669,0.7295656,,,,,,,,,,,,,, -430000,4.908867,0.6628287,,,,,,,,,,,,,, -430100,4.527543,0.55257154,,,,,,,,,,,,,, -430200,4.582652,0.69244146,,,,,,,,,,,,,, -430300,4.6433926,0.65754163,,,,,,,,,,,,,, -430400,4.288617,0.60418916,,,,,,,,,,,,,, -430500,4.406377,0.55602914,,,,,,,,,,,,,, -430600,5.020114,0.6449331,,,,,,,,,,,,,, -430700,4.485681,0.58936715,,,,,,,,,,,,,, -430800,4.3208513,0.58168924,,,,,,,,,,,,,, -430900,4.7228837,0.6513155,,,,,,,,,,,,,, -431000,4.6145287,0.60708904,,,,,,,,,,,,,, -431100,4.416632,0.5982597,,,,,,,,,,,,,, -431200,4.799037,0.6008912,,,,,,,,,,,,,, -431300,5.112901,0.6647548,,,,,,,,,,,,,, -431398,,,0.961933970451355,0.142555832862854,0.7561399936676025,1.0534591674804688,50000.0,0.6273000240325928,1.842931151390076,10000.0,144394.69930911064,149329.63092327118,144394.69930911064,4900.618678569794,18.810232400894165,0.0 -431400,4.0524287,0.56242186,,,,,,,,,,,,,, -431500,4.5544205,0.66346526,,,,,,,,,,,,,, -431600,4.1777205,0.6095137,,,,,,,,,,,,,, -431700,4.295172,0.60488695,,,,,,,,,,,,,, -431800,4.371417,0.6056514,,,,,,,,,,,,,, -431900,5.2584267,0.61886036,,,,,,,,,,,,,, -432000,4.262075,0.55505735,,,,,,,,,,,,,, -432100,4.4931207,0.6601652,,,,,,,,,,,,,, -432200,5.2994404,0.6678972,,,,,,,,,,,,,, -432300,4.227121,0.60255957,,,,,,,,,,,,,, -432400,5.0162888,0.6717732,,,,,,,,,,,,,, -432500,4.9159536,0.67443895,,,,,,,,,,,,,, -432600,4.6602855,0.80051297,,,,,,,,,,,,,, -432700,5.094258,0.69359607,,,,,,,,,,,,,, -432800,4.7084994,0.6090778,,,,,,,,,,,,,, -432900,4.694358,0.6871397,,,,,,,,,,,,,, -432923,,,0.961355984210968,0.1434606611728668,0.755899965763092,1.0536073446273804,50000.0,0.6272000074386597,1.842706322669983,10000.0,144904.76746439934,149856.3191523552,144904.76746439934,4917.084590911865,18.91004538536072,0.0 -433000,4.8906574,0.60397863,,,,,,,,,,,,,, -433100,5.049358,0.73336977,,,,,,,,,,,,,, -433200,4.56174,0.6519762,,,,,,,,,,,,,, -433300,4.6182384,0.65929604,,,,,,,,,,,,,, -433400,4.625616,0.62905467,,,,,,,,,,,,,, -433500,4.8811393,0.66368103,,,,,,,,,,,,,, -433600,4.4372196,0.61097366,,,,,,,,,,,,,, -433700,4.7961764,0.6833524,,,,,,,,,,,,,, -433800,4.7438326,0.65959126,,,,,,,,,,,,,, -433900,5.0356183,0.6124139,,,,,,,,,,,,,, -434000,4.483565,0.58506966,,,,,,,,,,,,,, -434100,4.451104,0.645432,,,,,,,,,,,,,, -434200,5.0545936,0.627715,,,,,,,,,,,,,, -434300,4.7832303,0.6356329,,,,,,,,,,,,,, -434400,4.3234396,0.58122903,,,,,,,,,,,,,, -434447,,,0.9610969424247742,0.1472132056951522,0.7562199831008911,1.0545414686203003,50000.0,0.6271000504493713,1.842933297157288,10000.0,145414.69251775742,150382.98892378807,145414.69251775742,4933.675719738007,19.008899688720703,0.0 -434500,4.4581227,0.5822287,,,,,,,,,,,,,, -434600,4.8358746,0.674388,,,,,,,,,,,,,, -434700,4.387824,0.5667411,,,,,,,,,,,,,, -434800,4.2361393,0.5529746,,,,,,,,,,,,,, -434900,4.874669,0.65189576,,,,,,,,,,,,,, -435000,4.66074,0.59875023,,,,,,,,,,,,,, -435100,4.493224,0.54151475,,,,,,,,,,,,,, -435200,4.7716613,0.5820278,,,,,,,,,,,,,, -435300,4.513714,0.6274629,,,,,,,,,,,,,, -435400,5.147728,0.6361611,,,,,,,,,,,,,, -435500,4.6705637,0.68338305,,,,,,,,,,,,,, -435600,4.747831,0.6523062,,,,,,,,,,,,,, -435700,5.4904175,0.5836695,,,,,,,,,,,,,, -435800,4.7600455,0.63956213,,,,,,,,,,,,,, -435900,3.9728172,0.5466381,,,,,,,,,,,,,, -435972,,,0.9607780575752258,0.145490288734436,0.755840003490448,1.053587555885315,50000.0,0.627500057220459,1.8414922952651973,10000.0,145924.79586172104,150909.85272932053,145924.79586172104,4950.279651641846,19.110421895980835,0.0 -436000,4.8992167,0.69066244,,,,,,,,,,,,,, -436100,4.6350513,0.5870864,,,,,,,,,,,,,, -436200,5.0364256,0.6402713,,,,,,,,,,,,,, -436300,4.1617455,0.6037124,,,,,,,,,,,,,, -436400,4.741469,0.60514534,,,,,,,,,,,,,, -436500,4.215293,0.5589724,,,,,,,,,,,,,, -436600,4.5063863,0.59097224,,,,,,,,,,,,,, -436700,4.3006806,0.6258894,,,,,,,,,,,,,, -436800,4.335789,0.55299765,,,,,,,,,,,,,, -436900,4.5886097,0.63465583,,,,,,,,,,,,,, -437000,4.5255146,0.63158846,,,,,,,,,,,,,, -437100,4.573469,0.66865814,,,,,,,,,,,,,, -437200,4.7352357,0.69412917,,,,,,,,,,,,,, -437300,4.505202,0.58640474,,,,,,,,,,,,,, -437400,4.458678,0.58394885,,,,,,,,,,,,,, -437497,,,0.9619937539100648,0.1432654112577438,0.7560999989509583,1.053731918334961,50000.0,0.6271000504493713,1.8432046175003047,10000.0,146434.88862538338,151436.73866152763,146434.88862538338,4966.919440507889,19.20820665359497,0.0 -437500,4.5663166,0.58670485,,,,,,,,,,,,,, -437600,4.311152,0.6101133,,,,,,,,,,,,,, -437700,4.5425124,0.63514984,,,,,,,,,,,,,, -437800,4.9605837,0.65175927,,,,,,,,,,,,,, -437900,4.6280117,0.6302209,,,,,,,,,,,,,, -438000,4.584025,0.60437393,,,,,,,,,,,,,, -438100,4.8202095,0.6894734,,,,,,,,,,,,,, -438200,4.73992,0.7020794,,,,,,,,,,,,,, -438300,4.52939,0.5663449,,,,,,,,,,,,,, -438400,4.498139,0.6048444,,,,,,,,,,,,,, -438500,4.6313944,0.70195657,,,,,,,,,,,,,, -438600,4.3894415,0.6253024,,,,,,,,,,,,,, -438700,5.123816,0.62502253,,,,,,,,,,,,,, -438800,4.5560775,0.6320363,,,,,,,,,,,,,, -438900,4.4568534,0.6123826,,,,,,,,,,,,,, -439000,4.404314,0.60312974,,,,,,,,,,,,,, -439022,,,0.9611168503761292,0.1442411392927169,0.7558599710464478,1.053155779838562,50000.0,0.6266000270843506,1.8422584533691408,10000.0,146944.86338567734,151963.4141998291,146944.86338567734,4983.461073637009,19.31168293952942,0.0 -439100,4.658617,0.6321351,,,,,,,,,,,,,, -439200,4.420503,0.6407468,,,,,,,,,,,,,, -439300,4.3406696,0.5895117,,,,,,,,,,,,,, -439400,4.5115166,0.6337674,,,,,,,,,,,,,, -439500,4.109468,0.576636,,,,,,,,,,,,,, -439600,4.288458,0.5250102,,,,,,,,,,,,,, -439700,4.5771804,0.6364454,,,,,,,,,,,,,, -439800,4.564931,0.5954468,,,,,,,,,,,,,, -439900,4.5037856,0.60589045,,,,,,,,,,,,,, -440000,4.6919,0.61267203,,,,,,,,,,,,,, -440100,4.8901625,0.5719696,,,,,,,,,,,,,, -440200,4.095008,0.60106623,,,,,,,,,,,,,, -440300,4.2579737,0.6363376,,,,,,,,,,,,,, -440400,5.1138864,0.6539312,,,,,,,,,,,,,, -440500,4.7863483,0.6361536,,,,,,,,,,,,,, -440547,,,0.9596819281578064,0.1478977352380752,0.7560799717903137,1.053455114364624,50000.0,0.6271000504493713,1.8422164916992188,10000.0,147454.86088490486,152490.3506128788,147454.86088490486,5000.237932682037,19.418365716934204,0.0 -440600,4.8363256,0.61275935,,,,,,,,,,,,,, -440700,4.3584814,0.57604504,,,,,,,,,,,,,, -440800,4.496217,0.6246537,,,,,,,,,,,,,, -440900,4.3941183,0.6199232,,,,,,,,,,,,,, -441000,4.1994095,0.60603285,,,,,,,,,,,,,, -441100,4.1996202,0.4846625,,,,,,,,,,,,,, -441200,4.6779194,0.6489414,,,,,,,,,,,,,, -441300,4.31224,0.5844295,,,,,,,,,,,,,, -441400,4.57631,0.614534,,,,,,,,,,,,,, -441500,4.816916,0.6310277,,,,,,,,,,,,,, -441600,4.273928,0.60555935,,,,,,,,,,,,,, -441700,4.701213,0.565537,,,,,,,,,,,,,, -441800,5.0267396,0.63698035,,,,,,,,,,,,,, -441900,4.2490935,0.65573424,,,,,,,,,,,,,, -442000,4.303952,0.6336525,,,,,,,,,,,,,, -442073,,,0.9610171914100648,0.1458736658096313,0.7556599974632263,1.052662372589111,50000.0,0.6277000308036804,1.841708302497864,10000.0,147964.97935509682,153017.11069369316,147964.97935509682,5016.71445608139,19.52862310409546,0.0 -442100,4.99455,0.64308953,,,,,,,,,,,,,, -442200,4.5076656,0.65624464,,,,,,,,,,,,,, -442300,4.451197,0.6255224,,,,,,,,,,,,,, -442400,4.902543,0.6402774,,,,,,,,,,,,,, -442500,4.4703007,0.624552,,,,,,,,,,,,,, -442600,5.0846944,0.7056716,,,,,,,,,,,,,, -442700,4.712432,0.5930199,,,,,,,,,,,,,, -442800,4.6370573,0.5927107,,,,,,,,,,,,,, -442900,4.2883296,0.63417417,,,,,,,,,,,,,, -443000,4.6011972,0.6511328,,,,,,,,,,,,,, -443100,4.4157662,0.6044576,,,,,,,,,,,,,, -443200,4.3513975,0.63053954,,,,,,,,,,,,,, -443300,4.8196588,0.6429097,,,,,,,,,,,,,, -443400,5.016656,0.65805477,,,,,,,,,,,,,, -443500,4.53149,0.59053594,,,,,,,,,,,,,, -443599,,,0.962332546710968,0.1446838676929474,0.7557799816131592,1.053775429725647,50000.0,0.6272000074386597,1.843441247940064,10000.0,148475.09104943275,153543.9378838539,148475.09104943275,5033.272862672806,19.630709886550903,0.0 -443600,5.009988,0.57463235,,,,,,,,,,,,,, -443700,4.96188,0.60284936,,,,,,,,,,,,,, -443800,4.4317913,0.62387955,,,,,,,,,,,,,, -443900,4.467961,0.6233463,,,,,,,,,,,,,, -444000,4.962142,0.6411375,,,,,,,,,,,,,, -444100,4.3484554,0.63240653,,,,,,,,,,,,,, -444200,4.4758015,0.6099887,,,,,,,,,,,,,, -444300,5.0497513,0.5286821,,,,,,,,,,,,,, -444400,4.5394135,0.64098763,,,,,,,,,,,,,, -444500,4.2834983,0.59388036,,,,,,,,,,,,,, -444600,4.6180367,0.60695374,,,,,,,,,,,,,, -444700,4.3397694,0.5961517,,,,,,,,,,,,,, -444800,5.30262,0.6786592,,,,,,,,,,,,,, -444900,4.125838,0.5784478,,,,,,,,,,,,,, -445000,4.8421187,0.61382455,,,,,,,,,,,,,, -445100,4.6349683,0.65712297,,,,,,,,,,,,,, -445124,,,0.9621731042861938,0.1434157192707061,0.7562599778175354,1.053500056266785,50000.0,0.6267000436782837,1.84416663646698,10000.0,148985.1442747116,154070.77855920792,148985.1442747116,5049.900293827057,19.734724521636963,0.0 -445200,4.7820983,0.59949064,,,,,,,,,,,,,, -445300,4.272617,0.60513353,,,,,,,,,,,,,, -445400,4.7470555,0.64217997,,,,,,,,,,,,,, -445500,4.4184494,0.6308546,,,,,,,,,,,,,, -445600,4.1504445,0.64374876,,,,,,,,,,,,,, -445700,4.3452682,0.621525,,,,,,,,,,,,,, -445800,4.7857556,0.6883983,,,,,,,,,,,,,, -445900,5.4862986,0.64933854,,,,,,,,,,,,,, -446000,4.8024616,0.6892485,,,,,,,,,,,,,, -446100,4.217276,0.59559476,,,,,,,,,,,,,, -446200,4.335389,0.60503626,,,,,,,,,,,,,, -446300,4.6485243,0.71250933,,,,,,,,,,,,,, -446400,4.391453,0.640215,,,,,,,,,,,,,, -446500,5.3356066,0.6290961,,,,,,,,,,,,,, -446600,4.396145,0.5687833,,,,,,,,,,,,,, -446649,,,0.9598612785339355,0.1488574594259262,0.7561599612236023,1.0542877912521362,50000.0,0.6273000240325928,1.8439942598342896,10000.0,149495.27058815956,154597.66735219955,149495.27058815956,5066.508980512619,19.83415412902832,0.0 -446700,4.753288,0.6766119,,,,,,,,,,,,,, -446800,4.8128667,0.5943053,,,,,,,,,,,,,, -446900,4.360545,0.5948774,,,,,,,,,,,,,, -447000,4.600617,0.6693846,,,,,,,,,,,,,, -447100,4.247311,0.64643747,,,,,,,,,,,,,, -447200,4.226285,0.630136,,,,,,,,,,,,,, -447300,5.1116586,0.62583727,,,,,,,,,,,,,, -447400,4.5369425,0.6105983,,,,,,,,,,,,,, -447500,4.454961,0.60094625,,,,,,,,,,,,,, -447600,4.2666574,0.63027036,,,,,,,,,,,,,, -447700,5.7875433,0.6525559,,,,,,,,,,,,,, -447800,4.849117,0.632364,,,,,,,,,,,,,, -447900,4.7239428,0.6336069,,,,,,,,,,,,,, -448000,4.7261734,0.66227233,,,,,,,,,,,,,, -448100,4.5629015,0.5999237,,,,,,,,,,,,,, -448174,,,0.959781527519226,0.1462220698595047,0.7554799914360046,1.0554804801940918,50000.0,0.627500057220459,1.8451577425003047,10000.0,150005.38194799423,155124.3677699566,150005.38194799423,5082.940105676651,19.93660569190979,0.0 -448200,4.2999883,0.55680984,,,,,,,,,,,,,, -448300,4.521458,0.5530536,,,,,,,,,,,,,, -448400,4.5391064,0.66208667,,,,,,,,,,,,,, -448500,4.7838197,0.67021096,,,,,,,,,,,,,, -448600,4.1570735,0.57776475,,,,,,,,,,,,,, -448700,4.438527,0.598549,,,,,,,,,,,,,, -448800,4.9114504,0.6826807,,,,,,,,,,,,,, -448900,4.938761,0.6470429,,,,,,,,,,,,,, -449000,4.565526,0.5912401,,,,,,,,,,,,,, -449100,4.548634,0.6728922,,,,,,,,,,,,,, -449200,4.54989,0.61755955,,,,,,,,,,,,,, -449300,4.4249053,0.61090976,,,,,,,,,,,,,, -449400,4.491169,0.6218585,,,,,,,,,,,,,, -449500,4.86922,0.63945395,,,,,,,,,,,,,, -449600,3.9236593,0.51588976,,,,,,,,,,,,,, -449699,,,0.9611168503761292,0.1465611755847931,0.7557399868965149,1.0540775060653689,50000.0,0.6270000338554382,1.8444410562515257,10000.0,150515.29887747765,155650.94755601883,150515.29887747765,5099.444447517395,20.040841579437256,0.0 -449700,4.681255,0.5606502,,,,,,,,,,,,,, -449800,4.203196,0.605044,,,,,,,,,,,,,, -449900,4.3253827,0.5580428,,,,,,,,,,,,,, -450000,4.6752605,0.58851796,,,,,,,,,,,,,, -450100,4.6903462,0.6543485,,,,,,,,,,,,,, -450200,4.2917757,0.6063721,,,,,,,,,,,,,, -450300,4.748244,0.66996825,,,,,,,,,,,,,, -450400,4.4005013,0.58624095,,,,,,,,,,,,,, -450500,5.352464,0.6488463,,,,,,,,,,,,,, -450600,5.1965094,0.58272934,,,,,,,,,,,,,, -450700,4.3308744,0.60028374,,,,,,,,,,,,,, -450800,4.7250557,0.64566886,,,,,,,,,,,,,, -450900,4.010184,0.56931317,,,,,,,,,,,,,, -451000,4.8031,0.71780187,,,,,,,,,,,,,, -451100,4.654403,0.68532443,,,,,,,,,,,,,, -451200,4.8618803,0.6525001,,,,,,,,,,,,,, -451224,,,0.9607979655265808,0.1459456384181976,0.7552599906921387,1.054438829421997,50000.0,0.627500057220459,1.8438587188720703,10000.0,151025.3807592392,156177.87268805504,151025.3807592392,5116.119092226028,20.15410733222961,0.0 -451300,4.781226,0.6354559,,,,,,,,,,,,,, -451400,4.611247,0.653077,,,,,,,,,,,,,, -451500,4.3327694,0.5634651,,,,,,,,,,,,,, -451600,4.60121,0.70009357,,,,,,,,,,,,,, -451700,4.180804,0.55820054,,,,,,,,,,,,,, -451800,4.716607,0.71697974,,,,,,,,,,,,,, -451900,4.1079874,0.57004964,,,,,,,,,,,,,, -452000,4.7313933,0.58236617,,,,,,,,,,,,,, -452100,4.085942,0.5839902,,,,,,,,,,,,,, -452200,4.755446,0.64597464,,,,,,,,,,,,,, -452300,4.583612,0.56195426,,,,,,,,,,,,,, -452400,4.6442776,0.6366135,,,,,,,,,,,,,, -452500,5.1620917,0.6952844,,,,,,,,,,,,,, -452600,4.469288,0.60094976,,,,,,,,,,,,,, -452700,4.2846494,0.63253987,,,,,,,,,,,,,, -452749,,,0.9612165093421936,0.1441743522882461,0.755840003490448,1.052942156791687,50000.0,0.6274000406265259,1.8419733047485352,10000.0,151535.39058494568,156704.6409471035,151535.39058494568,5132.721889019012,20.254146337509155,0.0 -452800,4.796109,0.6331893,,,,,,,,,,,,,, -452900,4.4527617,0.65865886,,,,,,,,,,,,,, -453000,5.0772567,0.6063902,,,,,,,,,,,,,, -453100,4.5619717,0.60578454,,,,,,,,,,,,,, -453200,4.891881,0.67877454,,,,,,,,,,,,,, -453300,4.772368,0.6390638,,,,,,,,,,,,,, -453400,4.224537,0.67875856,,,,,,,,,,,,,, -453500,4.1131525,0.5785583,,,,,,,,,,,,,, -453600,4.3999796,0.6135228,,,,,,,,,,,,,, -453700,5.243698,0.5861004,,,,,,,,,,,,,, -453800,4.663303,0.65611595,,,,,,,,,,,,,, -453900,4.353862,0.60013103,,,,,,,,,,,,,, -454000,4.8685822,0.6319613,,,,,,,,,,,,,, -454100,4.351211,0.5715119,,,,,,,,,,,,,, -454200,4.986472,0.7198485,,,,,,,,,,,,,, -454275,,,0.960359513759613,0.145753726363182,0.756060004234314,1.0537270307540894,50000.0,0.6279000043869019,1.8424499034881592,10000.0,152045.5179359913,157231.7075505257,152045.5179359913,5149.499848604202,20.36068964004517,0.0 -454300,4.9914246,0.60267013,,,,,,,,,,,,,, -454400,4.209264,0.58283687,,,,,,,,,,,,,, -454500,3.9468799,0.5813238,,,,,,,,,,,,,, -454600,4.2750144,0.6394415,,,,,,,,,,,,,, -454700,5.3507376,0.6812972,,,,,,,,,,,,,, -454800,4.3489404,0.5505628,,,,,,,,,,,,,, -454900,5.2781496,0.60801077,,,,,,,,,,,,,, -455000,4.431879,0.60888857,,,,,,,,,,,,,, -455100,4.3603835,0.58545154,,,,,,,,,,,,,, -455200,4.1592703,0.5704387,,,,,,,,,,,,,, -455300,4.2850327,0.5956007,,,,,,,,,,,,,, -455400,4.6677775,0.63761693,,,,,,,,,,,,,, -455500,4.503996,0.6079781,,,,,,,,,,,,,, -455600,4.311458,0.67941046,,,,,,,,,,,,,, -455700,4.780799,0.58110225,,,,,,,,,,,,,, -455800,,,0.9604192972183228,0.1473921835422516,0.7557599544525146,1.0532865524291992,50000.0,0.626800000667572,1.8414360284805296,10000.0,152555.53424096107,157758.27311992645,152555.53424096107,5165.890701532364,20.46353912353516,0.0 -455800,4.423581,0.6137057,,,,,,,,,,,,,, -455900,4.351661,0.6355582,,,,,,,,,,,,,, -456000,4.124264,0.5711584,,,,,,,,,,,,,, -456100,4.494295,0.6593704,,,,,,,,,,,,,, -456200,4.03432,0.58577245,,,,,,,,,,,,,, -456300,4.4800367,0.6597267,,,,,,,,,,,,,, -456400,4.7648945,0.59167707,,,,,,,,,,,,,, -456500,4.767723,0.7008842,,,,,,,,,,,,,, -456600,4.091941,0.60835993,,,,,,,,,,,,,, -456700,4.5511827,0.66592014,,,,,,,,,,,,,, -456800,5.2946815,0.67068523,,,,,,,,,,,,,, -456900,4.669515,0.61790615,,,,,,,,,,,,,, -457000,4.3354516,0.5925881,,,,,,,,,,,,,, -457100,4.305539,0.59702045,,,,,,,,,,,,,, -457200,4.532235,0.6450723,,,,,,,,,,,,,, -457300,4.681274,0.62907743,,,,,,,,,,,,,, -457325,,,0.9620735049247742,0.1456650346517563,0.7555800080299377,1.0540398359298706,50000.0,0.6271000504493713,1.84191644191742,10000.0,153065.48213148117,158284.81679821014,153065.48213148117,5182.322814702988,20.57215189933777,0.0 -457400,4.285461,0.61098075,,,,,,,,,,,,,, -457500,4.6395435,0.6130373,,,,,,,,,,,,,, -457600,4.763303,0.6813014,,,,,,,,,,,,,, -457700,4.519191,0.607445,,,,,,,,,,,,,, -457800,4.7246046,0.65915346,,,,,,,,,,,,,, -457900,4.594471,0.6313906,,,,,,,,,,,,,, -458000,4.7005415,0.59322125,,,,,,,,,,,,,, -458100,4.959586,0.65761226,,,,,,,,,,,,,, -458200,4.0263286,0.5082258,,,,,,,,,,,,,, -458300,4.4239464,0.6345545,,,,,,,,,,,,,, -458400,4.151599,0.55093044,,,,,,,,,,,,,, -458500,4.4588656,0.5628855,,,,,,,,,,,,,, -458600,4.315013,0.57135266,,,,,,,,,,,,,, -458700,4.441721,0.5678767,,,,,,,,,,,,,, -458800,3.8125007,0.52694654,,,,,,,,,,,,,, -458850,,,0.960558831691742,0.1474863737821579,0.7563999891281128,1.0531036853790283,50000.0,0.6274000406265259,1.8431851863861084,10000.0,153575.37730908394,158811.58496284485,153575.37730908394,5199.036423683167,20.67511820793152,0.0 -458900,4.505648,0.5812396,,,,,,,,,,,,,, -459000,4.7104883,0.57476014,,,,,,,,,,,,,, -459100,4.257397,0.6304294,,,,,,,,,,,,,, -459200,4.0480804,0.51637805,,,,,,,,,,,,,, -459300,4.8482885,0.65293086,,,,,,,,,,,,,, -459400,4.3906693,0.59884113,,,,,,,,,,,,,, -459500,4.4544835,0.6680415,,,,,,,,,,,,,, -459600,4.8840437,0.6485547,,,,,,,,,,,,,, -459700,4.750864,0.6711596,,,,,,,,,,,,,, -459800,4.386032,0.59061867,,,,,,,,,,,,,, -459900,4.7317314,0.7002389,,,,,,,,,,,,,, -460000,4.5645533,0.568756,,,,,,,,,,,,,, -460100,4.2418785,0.57359934,,,,,,,,,,,,,, -460200,4.1627326,0.59541476,,,,,,,,,,,,,, -460300,4.627932,0.7055506,,,,,,,,,,,,,, -460376,,,0.960558831691742,0.1477661728858947,0.7560399770736694,1.0535650253295898,50000.0,0.6276000142097473,1.842554211616516,10000.0,154085.50613236427,159338.31057286265,154085.50613236427,5215.4732303619385,20.78009033203125,0.0 -460400,4.846162,0.6065312,,,,,,,,,,,,,, -460500,4.3229456,0.56364673,,,,,,,,,,,,,, -460600,4.321166,0.6353154,,,,,,,,,,,,,, -460700,4.488588,0.62546766,,,,,,,,,,,,,, -460800,4.6770353,0.59739524,,,,,,,,,,,,,, -460900,5.253739,0.5581163,,,,,,,,,,,,,, -461000,4.2082677,0.57216847,,,,,,,,,,,,,, -461100,4.648835,0.52142143,,,,,,,,,,,,,, -461200,4.5279546,0.624001,,,,,,,,,,,,,, -461300,4.530606,0.6074017,,,,,,,,,,,,,, -461400,4.9421706,0.6313947,,,,,,,,,,,,,, -461500,4.5215917,0.6670736,,,,,,,,,,,,,, -461600,4.751604,0.689403,,,,,,,,,,,,,, -461700,4.4658103,0.60769767,,,,,,,,,,,,,, -461800,4.47976,0.580531,,,,,,,,,,,,,, -461900,4.746514,0.6393147,,,,,,,,,,,,,, -461901,,,0.963109850883484,0.1407929807901382,0.7557399868965149,1.0532169342041016,50000.0,0.628000020980835,1.841142654418945,10000.0,154595.75695943832,159865.34474110603,154595.75695943832,5232.09307384491,20.888579607009888,0.0 -462000,4.740638,0.65801567,,,,,,,,,,,,,, -462100,4.9655232,0.6586064,,,,,,,,,,,,,, -462200,4.51625,0.6251606,,,,,,,,,,,,,, -462300,4.5161276,0.63104403,,,,,,,,,,,,,, -462400,4.056143,0.5686635,,,,,,,,,,,,,, -462500,4.18219,0.6049531,,,,,,,,,,,,,, -462600,5.028352,0.6791828,,,,,,,,,,,,,, -462700,4.485493,0.6374634,,,,,,,,,,,,,, -462800,4.238605,0.5375297,,,,,,,,,,,,,, -462900,4.784902,0.6903775,,,,,,,,,,,,,, -463000,4.6672664,0.619529,,,,,,,,,,,,,, -463100,4.6714745,0.5672818,,,,,,,,,,,,,, -463200,4.2924504,0.61914766,,,,,,,,,,,,,, -463300,4.2938175,0.5800131,,,,,,,,,,,,,, -463400,4.6753216,0.60058606,,,,,,,,,,,,,, -463426,,,0.9622329473495485,0.1450704485177993,0.7559799551963806,1.0533766746520996,50000.0,0.627500057220459,1.8413161039352417,10000.0,155105.88724708557,160392.16827869415,155105.88724708557,5248.627458095551,20.992839574813843,0.0 -463500,4.5059366,0.66972667,,,,,,,,,,,,,, -463600,4.391022,0.59787405,,,,,,,,,,,,,, -463700,4.9170384,0.6532251,,,,,,,,,,,,,, -463800,4.378106,0.5857082,,,,,,,,,,,,,, -463900,5.2381163,0.71887344,,,,,,,,,,,,,, -464000,4.6662974,0.6585951,,,,,,,,,,,,,, -464100,4.4071045,0.59597445,,,,,,,,,,,,,, -464200,4.283333,0.5718521,,,,,,,,,,,,,, -464300,4.496006,0.6196211,,,,,,,,,,,,,, -464400,4.31416,0.5755762,,,,,,,,,,,,,, -464500,4.6213846,0.63945186,,,,,,,,,,,,,, -464600,4.435226,0.66260064,,,,,,,,,,,,,, -464700,4.6861367,0.58939457,,,,,,,,,,,,,, -464800,4.673388,0.5878285,,,,,,,,,,,,,, -464900,4.366634,0.6360575,,,,,,,,,,,,,, -464950,,,0.9618343114852904,0.1451112926006317,0.7556399703025818,1.054116129875183,50000.0,0.6267000436782837,1.8424440622329712,10000.0,155615.8431122303,160919.34607696533,155615.8431122303,5265.688413143158,21.09821939468384,0.0 -465000,4.789527,0.5499841,,,,,,,,,,,,,, -465100,4.5129423,0.5851132,,,,,,,,,,,,,, -465200,4.3246164,0.6279024,,,,,,,,,,,,,, -465300,4.7278676,0.6200338,,,,,,,,,,,,,, -465400,4.3248196,0.64857394,,,,,,,,,,,,,, -465500,5.428687,0.7273007,,,,,,,,,,,,,, -465600,4.6910267,0.6757398,,,,,,,,,,,,,, -465700,5.476553,0.6395473,,,,,,,,,,,,,, -465800,4.529981,0.6123695,,,,,,,,,,,,,, -465900,5.473354,0.70072097,,,,,,,,,,,,,, -466000,4.165847,0.5526978,,,,,,,,,,,,,, -466100,4.2095304,0.58232343,,,,,,,,,,,,,, -466200,4.366562,0.57855785,,,,,,,,,,,,,, -466300,4.4419303,0.5523476,,,,,,,,,,,,,, -466400,4.6527414,0.5462064,,,,,,,,,,,,,, -466475,,,0.9597217440605164,0.1456145346164703,0.7562599778175354,1.052058219909668,50000.0,0.6279000043869019,1.840462565422058,10000.0,156125.88497161865,161446.20129728317,156125.88497161865,5282.34098482132,21.203664302825928,0.0 -466500,4.1919665,0.53945637,,,,,,,,,,,,,, -466600,4.339739,0.6032371,,,,,,,,,,,,,, -466700,4.361605,0.58577275,,,,,,,,,,,,,, -466800,4.265592,0.61943185,,,,,,,,,,,,,, -466900,4.578592,0.6697492,,,,,,,,,,,,,, -467000,4.797105,0.64500326,,,,,,,,,,,,,, -467100,4.4206476,0.60329866,,,,,,,,,,,,,, -467200,4.0271473,0.55558527,,,,,,,,,,,,,, -467300,4.469739,0.60636896,,,,,,,,,,,,,, -467400,4.2795954,0.61197567,,,,,,,,,,,,,, -467500,4.686002,0.6752865,,,,,,,,,,,,,, -467600,4.616176,0.71019244,,,,,,,,,,,,,, -467700,4.291232,0.60844225,,,,,,,,,,,,,, -467800,4.416436,0.6080047,,,,,,,,,,,,,, -467900,5.279144,0.62853307,,,,,,,,,,,,,, -468000,,,0.9604192972183228,0.1474740803241729,0.7558199763298035,1.0522536039352417,50000.0,0.6278000473976135,1.840748310089112,10000.0,156635.74769496918,161972.6580798626,156635.74769496918,5298.78008556366,21.30401349067688,0.0 -468000,4.7326045,0.5823145,,,,,,,,,,,,,, -468100,4.4909234,0.55451655,,,,,,,,,,,,,, -468200,4.5329356,0.61919296,,,,,,,,,,,,,, -468300,4.48633,0.6597432,,,,,,,,,,,,,, -468400,4.6717596,0.65207344,,,,,,,,,,,,,, -468500,4.4112387,0.59601945,,,,,,,,,,,,,, -468600,4.5074224,0.57294756,,,,,,,,,,,,,, -468700,4.1142254,0.58965355,,,,,,,,,,,,,, -468800,4.539187,0.65600145,,,,,,,,,,,,,, -468900,4.939082,0.6185413,,,,,,,,,,,,,, -469000,4.417056,0.5630435,,,,,,,,,,,,,, -469100,5.198957,0.6332033,,,,,,,,,,,,,, -469200,4.8903875,0.6176553,,,,,,,,,,,,,, -469300,4.8831615,0.71005327,,,,,,,,,,,,,, -469400,4.467642,0.57462776,,,,,,,,,,,,,, -469500,4.635161,0.6361099,,,,,,,,,,,,,, -469524,,,0.9623923301696776,0.143751248717308,0.7559199929237366,1.0532188415527344,50000.0,0.6272000074386597,1.8413631916046145,10000.0,157145.59330034256,162499.1994020939,157145.59330034256,5315.314467906952,21.41022562980652,0.0 -469600,4.6134453,0.6571113,,,,,,,,,,,,,, -469700,4.4422946,0.63648814,,,,,,,,,,,,,, -469800,4.4915133,0.67528844,,,,,,,,,,,,,, -469900,4.3478703,0.5731031,,,,,,,,,,,,,, -470000,4.143252,0.58762515,,,,,,,,,,,,,, -470100,4.417697,0.6335534,,,,,,,,,,,,,, -470200,4.7137685,0.59466016,,,,,,,,,,,,,, -470300,4.603637,0.66272306,,,,,,,,,,,,,, -470400,4.3987484,0.5894207,,,,,,,,,,,,,, -470500,4.77372,0.64836156,,,,,,,,,,,,,, -470600,4.669475,0.57773703,,,,,,,,,,,,,, -470700,4.553305,0.6451446,,,,,,,,,,,,,, -470800,4.554162,0.6609609,,,,,,,,,,,,,, -470900,4.5810804,0.5587275,,,,,,,,,,,,,, -471000,4.1795125,0.597957,,,,,,,,,,,,,, -471049,,,0.9620934128761292,0.1425232142210006,0.7557399868965149,1.0534279346466064,50000.0,0.6276000142097473,1.842286944389344,10000.0,157655.4903459549,163025.69514727592,157655.4903459549,5331.7469527721405,21.52133750915528,0.0 -471100,4.6359477,0.7011052,,,,,,,,,,,,,, -471200,4.406514,0.60304624,,,,,,,,,,,,,, -471300,4.237603,0.54555184,,,,,,,,,,,,,, -471400,4.480926,0.5807482,,,,,,,,,,,,,, -471500,4.7723827,0.60454655,,,,,,,,,,,,,, -471600,4.9457846,0.6776686,,,,,,,,,,,,,, -471700,4.291858,0.5941752,,,,,,,,,,,,,, -471800,4.790321,0.6648251,,,,,,,,,,,,,, -471900,5.1485944,0.59600425,,,,,,,,,,,,,, -472000,4.0666175,0.59611046,,,,,,,,,,,,,, -472100,4.9713087,0.61727494,,,,,,,,,,,,,, -472200,5.686345,0.6052496,,,,,,,,,,,,,, -472300,4.7440553,0.62364334,,,,,,,,,,,,,, -472400,4.482463,0.57671154,,,,,,,,,,,,,, -472500,4.5929885,0.5712929,,,,,,,,,,,,,, -472574,,,0.9614756107330322,0.1465017646551132,0.7559199929237366,1.054178237915039,50000.0,0.6266000270843506,1.8426686525344849,10000.0,158165.50344014168,163552.47993898392,158165.50344014168,5348.365239858627,21.62071895599365,0.0 -472600,4.249711,0.5652744,,,,,,,,,,,,,, -472700,4.535915,0.6266185,,,,,,,,,,,,,, -472800,4.5552135,0.6507634,,,,,,,,,,,,,, -472900,4.7233114,0.6484628,,,,,,,,,,,,,, -473000,4.540245,0.6141596,,,,,,,,,,,,,, -473100,4.6009407,0.5976197,,,,,,,,,,,,,, -473200,4.698856,0.66658014,,,,,,,,,,,,,, -473300,4.338219,0.6364316,,,,,,,,,,,,,, -473400,4.5515776,0.68526596,,,,,,,,,,,,,, -473500,4.023422,0.55402607,,,,,,,,,,,,,, -473600,4.861092,0.633435,,,,,,,,,,,,,, -473700,5.154351,0.65728253,,,,,,,,,,,,,, -473800,4.306108,0.56682795,,,,,,,,,,,,,, -473900,4.7553096,0.65768,,,,,,,,,,,,,, -474000,4.8838816,0.6446909,,,,,,,,,,,,,, -474099,,,0.960598647594452,0.1443846672773361,0.7561799883842468,1.0545673370361328,50000.0,0.6260000467300415,1.8445450067520144,10000.0,158675.39512062073,164078.99399518967,158675.39512062073,5364.82608294487,21.726683616638184,0.0 -474100,5.1566496,0.7044561,,,,,,,,,,,,,, -474200,4.602358,0.643038,,,,,,,,,,,,,, -474300,5.217817,0.6282569,,,,,,,,,,,,,, -474400,4.4418826,0.5728403,,,,,,,,,,,,,, -474500,4.6799326,0.6148782,,,,,,,,,,,,,, -474600,5.172929,0.64307815,,,,,,,,,,,,,, -474700,5.096512,0.62144554,,,,,,,,,,,,,, -474800,4.1203485,0.583463,,,,,,,,,,,,,, -474900,4.7709913,0.6472944,,,,,,,,,,,,,, -475000,4.46868,0.64501905,,,,,,,,,,,,,, -475100,4.3573074,0.58095944,,,,,,,,,,,,,, -475200,4.9428854,0.6509994,,,,,,,,,,,,,, -475300,4.1847563,0.6589055,,,,,,,,,,,,,, -475400,4.5060034,0.68199503,,,,,,,,,,,,,, -475500,4.5522304,0.59836453,,,,,,,,,,,,,, -475600,4.442284,0.62890273,,,,,,,,,,,,,, -475624,,,0.961734652519226,0.145376443862915,0.7560999989509583,1.0535876750946045,50000.0,0.6276000142097473,1.8428512811660769,10000.0,159185.36111426353,164605.59284853935,159185.36111426353,5381.293584108353,21.83575701713562,0.0 -475700,4.4817667,0.6612792,,,,,,,,,,,,,, -475800,4.341201,0.6009702,,,,,,,,,,,,,, -475900,5.360606,0.62180436,,,,,,,,,,,,,, -476000,4.4106355,0.63758963,,,,,,,,,,,,,, -476100,4.40606,0.5907555,,,,,,,,,,,,,, -476200,4.919465,0.7040031,,,,,,,,,,,,,, -476300,4.2967315,0.526677,,,,,,,,,,,,,, -476400,5.2751036,0.6987467,,,,,,,,,,,,,, -476500,4.6823826,0.59310734,,,,,,,,,,,,,, -476600,4.4165397,0.64209616,,,,,,,,,,,,,, -476700,4.317212,0.5409175,,,,,,,,,,,,,, -476800,4.707351,0.566974,,,,,,,,,,,,,, -476900,4.3521323,0.6502808,,,,,,,,,,,,,, -477000,5.0061655,0.61740386,,,,,,,,,,,,,, -477100,4.4821568,0.61080235,,,,,,,,,,,,,, -477149,,,0.9620137214660645,0.1419821977615356,0.756339967250824,1.0527808666229248,50000.0,0.6272000074386597,1.841827273368836,10000.0,159695.37145113945,165132.4526028633,159695.37145113945,5397.987356424332,21.93658208847046,0.0 -477200,4.5531526,0.65928155,,,,,,,,,,,,,, -477300,4.794081,0.6798487,,,,,,,,,,,,,, -477400,4.6417255,0.7179853,,,,,,,,,,,,,, -477500,4.625769,0.5771541,,,,,,,,,,,,,, -477600,4.8134007,0.6883118,,,,,,,,,,,,,, -477700,4.536417,0.5747119,,,,,,,,,,,,,, -477800,4.5987315,0.63662165,,,,,,,,,,,,,, -477900,4.3510294,0.61631244,,,,,,,,,,,,,, -478000,4.13691,0.6049595,,,,,,,,,,,,,, -478100,4.860081,0.62660944,,,,,,,,,,,,,, -478200,4.604631,0.6415016,,,,,,,,,,,,,, -478300,4.2614045,0.6221091,,,,,,,,,,,,,, -478400,4.8733845,0.6598207,,,,,,,,,,,,,, -478500,4.403467,0.53919655,,,,,,,,,,,,,, -478600,4.7535477,0.65793824,,,,,,,,,,,,,, -478674,,,0.96000075340271,0.1474557220935821,0.7561799883842468,1.052963137626648,50000.0,0.6279000043869019,1.8423027992248533,10000.0,160205.27926325798,165659.02926802635,160205.27926325798,5414.497764825821,22.03964066505432,0.0 -478700,4.385972,0.5980126,,,,,,,,,,,,,, -478800,4.3258033,0.58465564,,,,,,,,,,,,,, -478900,4.653346,0.658162,,,,,,,,,,,,,, -479000,4.6939526,0.62414587,,,,,,,,,,,,,, -479100,4.2015834,0.5801031,,,,,,,,,,,,,, -479200,4.7077374,0.61043376,,,,,,,,,,,,,, -479300,4.620136,0.6723326,,,,,,,,,,,,,, -479400,4.8270116,0.63961685,,,,,,,,,,,,,, -479500,4.4824367,0.65738475,,,,,,,,,,,,,, -479600,4.523421,0.6270629,,,,,,,,,,,,,, -479700,4.6549697,0.60089,,,,,,,,,,,,,, -479800,4.8226705,0.66099,,,,,,,,,,,,,, -479900,4.3281307,0.62669384,,,,,,,,,,,,,, -480000,4.6186748,0.6011903,,,,,,,,,,,,,, -480100,5.004118,0.64885294,,,,,,,,,,,,,, -480198,,,0.9604990482330322,0.1476502567529678,0.7561999559402466,1.053896188735962,50000.0,0.6277000308036804,1.8432202339172363,10000.0,160715.17672729492,166185.70827054977,160715.17672729492,5431.116099834442,22.14845061302185,0.0 -480200,4.326776,0.6261751,,,,,,,,,,,,,, -480300,4.2005835,0.6236304,,,,,,,,,,,,,, -480400,4.537783,0.5784264,,,,,,,,,,,,,, -480500,4.5647326,0.6041374,,,,,,,,,,,,,, -480600,4.5423956,0.6133046,,,,,,,,,,,,,, -480700,4.5558224,0.63916796,,,,,,,,,,,,,, -480800,4.647112,0.61887175,,,,,,,,,,,,,, -480900,4.3471837,0.6037658,,,,,,,,,,,,,, -481000,4.4361663,0.67091626,,,,,,,,,,,,,, -481100,4.610709,0.62422633,,,,,,,,,,,,,, -481200,4.8102818,0.71329975,,,,,,,,,,,,,, -481300,4.4318404,0.6277058,,,,,,,,,,,,,, -481400,4.490926,0.67796355,,,,,,,,,,,,,, -481500,4.514211,0.62702054,,,,,,,,,,,,,, -481600,4.8851867,0.67923,,,,,,,,,,,,,, -481700,4.4078226,0.5812306,,,,,,,,,,,,,, -481724,,,0.9612364172935486,0.1440017223358154,0.7558599710464478,1.0537447929382324,50000.0,0.6271000504493713,1.841601610183716,10000.0,161225.3224759102,166712.59029626846,161225.3224759102,5447.69705748558,22.248322248458862,0.0 -481800,4.7642727,0.6987426,,,,,,,,,,,,,, -481900,5.0549626,0.58185035,,,,,,,,,,,,,, -482000,4.382037,0.64963007,,,,,,,,,,,,,, -482100,4.581539,0.6347487,,,,,,,,,,,,,, -482200,4.1065445,0.5831294,,,,,,,,,,,,,, -482300,4.6523685,0.6055156,,,,,,,,,,,,,, -482400,4.322072,0.5625337,,,,,,,,,,,,,, -482500,4.448874,0.6197338,,,,,,,,,,,,,, -482600,5.067746,0.6728647,,,,,,,,,,,,,, -482700,4.4980793,0.66325164,,,,,,,,,,,,,, -482800,4.5999317,0.6663071,,,,,,,,,,,,,, -482900,4.7480025,0.7111116,,,,,,,,,,,,,, -483000,4.4020414,0.6155368,,,,,,,,,,,,,, -483100,4.3389177,0.576325,,,,,,,,,,,,,, -483200,4.7962074,0.67062926,,,,,,,,,,,,,, -483249,,,0.9628507494926452,0.1429650485515594,0.7554599642753601,1.0544278621673584,50000.0,0.626300036907196,1.843042254447937,10000.0,161735.1811146736,167239.09651231766,161735.1811146736,5464.181796789169,22.35664367675781,0.0 -483300,4.8175883,0.61291254,,,,,,,,,,,,,, -483400,4.475192,0.67943394,,,,,,,,,,,,,, -483500,4.4731245,0.59650755,,,,,,,,,,,,,, -483600,4.121833,0.61857057,,,,,,,,,,,,,, -483700,4.558166,0.6209873,,,,,,,,,,,,,, -483800,4.4880705,0.6012471,,,,,,,,,,,,,, -483900,4.865153,0.6321233,,,,,,,,,,,,,, -484000,4.713775,0.6417858,,,,,,,,,,,,,, -484100,4.396047,0.65128946,,,,,,,,,,,,,, -484200,4.5668683,0.67969334,,,,,,,,,,,,,, -484300,4.660645,0.5394304,,,,,,,,,,,,,, -484400,5.285971,0.70119995,,,,,,,,,,,,,, -484500,5.091409,0.6468226,,,,,,,,,,,,,, -484600,5.018604,0.6427838,,,,,,,,,,,,,, -484700,4.506955,0.5953277,,,,,,,,,,,,,, -484774,,,0.9619140625,0.1458676904439926,0.755899965763092,1.054257035255432,50000.0,0.628000020980835,1.8424336910247805,10000.0,162245.09081196785,167765.5865190029,162245.09081196785,5480.593630313873,22.470579624176025,0.0 -484800,5.1371136,0.5861385,,,,,,,,,,,,,, -484900,4.5187464,0.6540455,,,,,,,,,,,,,, -485000,4.5240626,0.68340683,,,,,,,,,,,,,, -485100,4.344922,0.55978143,,,,,,,,,,,,,, -485200,4.971452,0.6578689,,,,,,,,,,,,,, -485300,4.7195616,0.6753786,,,,,,,,,,,,,, -485400,4.476355,0.5994179,,,,,,,,,,,,,, -485500,4.4126863,0.5506822,,,,,,,,,,,,,, -485600,4.811784,0.6801249,,,,,,,,,,,,,, -485700,4.33868,0.6522454,,,,,,,,,,,,,, -485800,4.6054263,0.6489172,,,,,,,,,,,,,, -485900,4.8869076,0.61246496,,,,,,,,,,,,,, -486000,5.184149,0.742789,,,,,,,,,,,,,, -486100,4.4085,0.6154285,,,,,,,,,,,,,, -486200,4.990835,0.642063,,,,,,,,,,,,,, -486299,,,0.9597616195678712,0.1485474854707718,0.7559199929237366,1.0532896518707275,50000.0,0.6274000406265259,1.8425171375274656,10000.0,162754.97107815742,168292.3037648201,162754.97107815742,5497.265798330307,22.580674409866333,0.0 -486300,4.7844887,0.63041866,,,,,,,,,,,,,, -486400,5.1533628,0.6280092,,,,,,,,,,,,,, -486500,4.7844563,0.6475599,,,,,,,,,,,,,, -486600,4.6245565,0.71880436,,,,,,,,,,,,,, -486700,4.7032647,0.6347091,,,,,,,,,,,,,, -486800,4.1727552,0.5660378,,,,,,,,,,,,,, -486900,4.5802593,0.6463851,,,,,,,,,,,,,, -487000,4.5797477,0.6170707,,,,,,,,,,,,,, -487100,4.451083,0.54338974,,,,,,,,,,,,,, -487200,4.7787957,0.68817556,,,,,,,,,,,,,, -487300,4.4080024,0.58064616,,,,,,,,,,,,,, -487400,4.605834,0.64967513,,,,,,,,,,,,,, -487500,4.559372,0.673734,,,,,,,,,,,,,, -487600,4.9984617,0.6085038,,,,,,,,,,,,,, -487700,4.4252977,0.6122447,,,,,,,,,,,,,, -487800,4.5303326,0.67449504,,,,,,,,,,,,,, -487825,,,0.9598014950752258,0.1468220502138137,0.7554999589920044,1.0537186861038208,50000.0,0.6271000504493713,1.8440321683883667,10000.0,163264.9387593269,168818.93969726562,163264.9387593269,5513.765965938568,22.693618059158325,0.0 -487900,4.4248357,0.565069,,,,,,,,,,,,,, -488000,5.002386,0.6325163,,,,,,,,,,,,,, -488100,4.7986765,0.65784943,,,,,,,,,,,,,, -488200,4.5579424,0.61174494,,,,,,,,,,,,,, -488300,5.0096016,0.6072334,,,,,,,,,,,,,, -488400,4.3198934,0.58428353,,,,,,,,,,,,,, -488500,4.490044,0.6580756,,,,,,,,,,,,,, -488600,4.6085896,0.6425649,,,,,,,,,,,,,, -488700,4.8612375,0.65698975,,,,,,,,,,,,,, -488800,4.667008,0.74087703,,,,,,,,,,,,,, -488900,4.814975,0.6157671,,,,,,,,,,,,,, -489000,4.767084,0.6511239,,,,,,,,,,,,,, -489100,4.3679423,0.5849635,,,,,,,,,,,,,, -489200,4.4420476,0.57413816,,,,,,,,,,,,,, -489300,4.670324,0.6435735,,,,,,,,,,,,,, -489350,,,0.9611766338348388,0.1468583345413208,0.7562400102615356,1.0533497333526611,50000.0,0.626800000667572,1.8408963680267327,10000.0,163774.9257595539,169345.6744902134,163774.9257595539,5530.3492596149445,22.80390238761902,0.0 -489400,4.5686755,0.5439143,,,,,,,,,,,,,, -489500,4.7703824,0.5501581,,,,,,,,,,,,,, -489600,4.748449,0.7149804,,,,,,,,,,,,,, -489700,4.3082037,0.59831965,,,,,,,,,,,,,, -489800,4.292296,0.56266224,,,,,,,,,,,,,, -489900,4.4907126,0.56818867,,,,,,,,,,,,,, -490000,4.4471717,0.6020443,,,,,,,,,,,,,, -490100,4.5774136,0.54190695,,,,,,,,,,,,,, -490200,4.3514876,0.619471,,,,,,,,,,,,,, -490300,5.1859956,0.572352,,,,,,,,,,,,,, -490400,4.764102,0.64552045,,,,,,,,,,,,,, -490500,4.9301467,0.7544696,,,,,,,,,,,,,, -490600,4.6597385,0.5747497,,,,,,,,,,,,,, -490700,4.8587413,0.7586398,,,,,,,,,,,,,, -490800,4.4954634,0.61286354,,,,,,,,,,,,,, -490876,,,0.9616748690605164,0.1434008181095123,0.7559399604797363,1.0531055927276611,50000.0,0.6274000406265259,1.8428215980529783,10000.0,164284.85869836807,169872.2647330761,164284.85869836807,5546.840213775635,22.913722276687626,0.0 -490900,4.1269197,0.52646005,,,,,,,,,,,,,, -491000,5.0145383,0.61435604,,,,,,,,,,,,,, -491100,4.6941113,0.6160658,,,,,,,,,,,,,, -491200,4.750326,0.6830719,,,,,,,,,,,,,, -491300,4.7097116,0.6402563,,,,,,,,,,,,,, -491400,4.4945173,0.636083,,,,,,,,,,,,,, -491500,4.4632177,0.64050466,,,,,,,,,,,,,, -491600,4.619197,0.61376256,,,,,,,,,,,,,, -491700,4.721036,0.6704756,,,,,,,,,,,,,, -491800,4.8337603,0.57712257,,,,,,,,,,,,,, -491900,5.2462626,0.6759409,,,,,,,,,,,,,, -492000,4.6964445,0.614269,,,,,,,,,,,,,, -492100,4.7583666,0.5657931,,,,,,,,,,,,,, -492200,4.330109,0.62663263,,,,,,,,,,,,,, -492300,4.46083,0.58587474,,,,,,,,,,,,,, -492400,4.3232336,0.565804,,,,,,,,,,,,,, -492401,,,0.95902419090271,0.1491030007600784,0.7559999823570251,1.0529998540878296,50000.0,0.628000020980835,1.8415027856826784,10000.0,164794.85840845108,170399.26122307777,164794.85840845108,5563.665554523468,23.0304811000824,0.0 -492500,4.4326105,0.55380064,,,,,,,,,,,,,, -492600,4.932294,0.6815879,,,,,,,,,,,,,, -492700,4.3769875,0.63537425,,,,,,,,,,,,,, -492800,4.860992,0.67170316,,,,,,,,,,,,,, -492900,4.872268,0.6829661,,,,,,,,,,,,,, -493000,4.977144,0.65404737,,,,,,,,,,,,,, -493100,4.1769834,0.55600923,,,,,,,,,,,,,, -493200,4.304905,0.54346323,,,,,,,,,,,,,, -493300,4.529541,0.6407398,,,,,,,,,,,,,, -493400,4.24452,0.5347159,,,,,,,,,,,,,, -493500,4.7526555,0.69364786,,,,,,,,,,,,,, -493600,4.73664,0.66846406,,,,,,,,,,,,,, -493700,4.635628,0.6564827,,,,,,,,,,,,,, -493800,4.598159,0.71184343,,,,,,,,,,,,,, -493900,4.6155696,0.641464,,,,,,,,,,,,,, -493927,,,0.961535394191742,0.1440262198448181,0.7557599544525146,1.0536928176879885,50000.0,0.6270000338554382,1.842665672302246,10000.0,165304.89320635796,170926.0573911667,165304.89320635796,5580.262861967087,23.13922953605652,0.0 -494000,4.351068,0.6146974,,,,,,,,,,,,,, -494100,4.2249136,0.5666126,,,,,,,,,,,,,, -494200,4.20577,0.6279916,,,,,,,,,,,,,, -494300,4.556978,0.6040622,,,,,,,,,,,,,, -494400,4.3340917,0.62719285,,,,,,,,,,,,,, -494500,3.9947503,0.572607,,,,,,,,,,,,,, -494600,4.3769774,0.6276308,,,,,,,,,,,,,, -494700,4.093511,0.57234484,,,,,,,,,,,,,, -494800,4.6374006,0.6343936,,,,,,,,,,,,,, -494900,5.147204,0.6622697,,,,,,,,,,,,,, -495000,4.512298,0.57885313,,,,,,,,,,,,,, -495100,4.6005044,0.61245203,,,,,,,,,,,,,, -495200,5.2397957,0.7131282,,,,,,,,,,,,,, -495300,4.498318,0.6050428,,,,,,,,,,,,,, -495400,4.5068593,0.5852612,,,,,,,,,,,,,, -495452,,,0.9607979655265808,0.146340012550354,0.7558799982070923,1.053356647491455,50000.0,0.625700056552887,1.8420727252960205,10000.0,165814.74216461182,171452.58564066887,165814.74216461182,5596.775694847107,23.250207901000977,0.0 -495500,4.655079,0.5777209,,,,,,,,,,,,,, -495600,4.5506597,0.6387628,,,,,,,,,,,,,, -495700,4.1242146,0.5467887,,,,,,,,,,,,,, -495800,4.1976943,0.66009176,,,,,,,,,,,,,, -495900,4.3657665,0.61536473,,,,,,,,,,,,,, -496000,4.6976323,0.6666388,,,,,,,,,,,,,, -496100,4.40412,0.5885495,,,,,,,,,,,,,, -496200,5.0155888,0.7158313,,,,,,,,,,,,,, -496300,4.020798,0.59471303,,,,,,,,,,,,,, -496400,4.569906,0.6246079,,,,,,,,,,,,,, -496500,4.018623,0.56959164,,,,,,,,,,,,,, -496600,4.800584,0.6287447,,,,,,,,,,,,,, -496700,4.6014013,0.6846334,,,,,,,,,,,,,, -496800,4.7670875,0.6170636,,,,,,,,,,,,,, -496900,4.2759247,0.60545456,,,,,,,,,,,,,, -496977,,,0.9622329473495485,0.1473593860864639,0.7556599974632263,1.0541554689407349,50000.0,0.6273000240325928,1.841199278831482,10000.0,166324.6839056015,171979.1335196495,166324.6839056015,5613.217988491058,23.358347415924072,0.0 -497000,5.0576334,0.8318155,,,,,,,,,,,,,, -497100,5.123899,0.7167519,,,,,,,,,,,,,, -497200,4.8334775,0.647539,,,,,,,,,,,,,, -497300,4.7448287,0.603479,,,,,,,,,,,,,, -497400,4.301151,0.55104285,,,,,,,,,,,,,, -497500,4.679431,0.5856072,,,,,,,,,,,,,, -497600,4.0743175,0.5582123,,,,,,,,,,,,,, -497700,5.3352304,0.7197027,,,,,,,,,,,,,, -497800,4.6625347,0.67771643,,,,,,,,,,,,,, -497900,4.727439,0.5441408,,,,,,,,,,,,,, -498000,4.6265535,0.74550486,,,,,,,,,,,,,, -498100,4.4598837,0.629204,,,,,,,,,,,,,, -498200,4.738768,0.6567799,,,,,,,,,,,,,, -498300,4.588686,0.62560433,,,,,,,,,,,,,, -498400,4.7094617,0.66471446,,,,,,,,,,,,,, -498500,4.5909557,0.5720119,,,,,,,,,,,,,, -498502,,,0.9600805044174194,0.1480731964111328,0.7558199763298035,1.0539394617080688,50000.0,0.6270000338554382,1.8436610698699951,10000.0,166834.6341896057,172505.75029206276,166834.6341896057,5629.716914653778,23.470913410186768,0.0 -498600,4.6160355,0.6238266,,,,,,,,,,,,,, -498700,4.301345,0.61447227,,,,,,,,,,,,,, -498800,4.738711,0.6652129,,,,,,,,,,,,,, -498900,4.7776027,0.6134021,,,,,,,,,,,,,, -499000,5.0496054,0.64988965,,,,,,,,,,,,,, -499100,4.203109,0.5627842,,,,,,,,,,,,,, -499200,4.29635,0.57296723,,,,,,,,,,,,,, -499300,4.768571,0.63737816,,,,,,,,,,,,,, -499400,4.5522423,0.59069455,,,,,,,,,,,,,, -499500,4.6464925,0.64855933,,,,,,,,,,,,,, -499600,4.6294513,0.5962375,,,,,,,,,,,,,, -499700,4.4805446,0.6503971,,,,,,,,,,,,,, -499800,4.4629235,0.6366524,,,,,,,,,,,,,, -499900,4.3828735,0.65674156,,,,,,,,,,,,,, -500000,4.7307115,0.684623,,,,,,,,,,,,,, -500028,,,0.961933970451355,0.1435198634862899,0.755899965763092,1.05421781539917,50000.0,0.6279000043869019,1.8444896936416624,10000.0,167344.79048371315,173032.69118475914,167344.79048371315,5646.337370872498,23.58010768890381,0.0 -500100,4.1703286,0.6031622,,,,,,,,,,,,,, -500200,4.5071282,0.6657397,,,,,,,,,,,,,, -500300,4.7397175,0.6659973,,,,,,,,,,,,,, -500400,4.7134085,0.60516024,,,,,,,,,,,,,, -500500,4.685778,0.65277195,,,,,,,,,,,,,, -500600,4.4375854,0.63815725,,,,,,,,,,,,,, -500700,4.410193,0.65389115,,,,,,,,,,,,,, -500800,4.6195703,0.62127703,,,,,,,,,,,,,, -500900,4.24646,0.5793415,,,,,,,,,,,,,, -501000,4.38331,0.5818758,,,,,,,,,,,,,, -501100,4.123699,0.55212474,,,,,,,,,,,,,, -501200,4.7488666,0.64979243,,,,,,,,,,,,,, -501300,4.1185017,0.60782,,,,,,,,,,,,,, -501400,4.320708,0.59893215,,,,,,,,,,,,,, -501500,4.7763033,0.76946217,,,,,,,,,,,,,, -501554,,,0.9622528553009032,0.1441863924264907,0.7561999559402466,1.0531775951385498,50000.0,0.6272000074386597,1.8417606353759768,10000.0,167854.95328998566,173559.6701028347,167854.95328998566,5662.989856958389,23.68991541862488,0.0 -501600,4.7394123,0.6477067,,,,,,,,,,,,,, -501700,4.474443,0.6033011,,,,,,,,,,,,,, -501800,4.2891483,0.5674514,,,,,,,,,,,,,, -501900,4.617287,0.7018042,,,,,,,,,,,,,, -502000,4.4603405,0.6298684,,,,,,,,,,,,,, -502100,4.3407145,0.6310133,,,,,,,,,,,,,, -502200,4.641955,0.6086564,,,,,,,,,,,,,, -502300,4.3986855,0.56139535,,,,,,,,,,,,,, -502400,4.056003,0.5492617,,,,,,,,,,,,,, -502500,4.355138,0.582343,,,,,,,,,,,,,, -502600,4.5758457,0.6229078,,,,,,,,,,,,,, -502700,4.5300727,0.60553914,,,,,,,,,,,,,, -502800,4.107407,0.5352565,,,,,,,,,,,,,, -502900,4.660281,0.6266607,,,,,,,,,,,,,, -503000,4.770134,0.68942755,,,,,,,,,,,,,, -503044,,,0.961933970451355,0.1459151953458786,0.7558799982070923,1.053280234336853,50000.0,0.626800000667572,1.84154212474823,10000.0,168364.83893346786,174086.14171028137,168364.83893346786,5679.408395290375,23.802663803100582,0.0 -503100,4.487931,0.61365587,,,,,,,,,,,,,, -503200,4.171625,0.5702226,,,,,,,,,,,,,, -503300,5.4754186,0.72425485,,,,,,,,,,,,,, -503400,4.2159524,0.56069666,,,,,,,,,,,,,, -503500,4.657732,0.6528466,,,,,,,,,,,,,, -503600,4.4488254,0.6448202,,,,,,,,,,,,,, -503700,4.552936,0.64932317,,,,,,,,,,,,,, -503800,4.5888414,0.65705705,,,,,,,,,,,,,, -503900,4.37835,0.6010487,,,,,,,,,,,,,, -504000,4.5785336,0.60231405,,,,,,,,,,,,,, -504100,4.7806683,0.6580609,,,,,,,,,,,,,, -504200,4.847267,0.5886575,,,,,,,,,,,,,, -504300,4.7492433,0.57982266,,,,,,,,,,,,,, -504400,4.640085,0.67735887,,,,,,,,,,,,,, -504500,4.375025,0.6641911,,,,,,,,,,,,,, -504568,,,0.9605189561843872,0.1453207582235336,0.7561799883842468,1.0537110567092896,50000.0,0.6267000436782837,1.8432775735855105,10000.0,168874.7902429104,174613.502784729,168874.7902429104,5696.652493000031,23.912421226501465,0.0 -504600,4.417261,0.635112,,,,,,,,,,,,,, -504700,4.6060095,0.5855178,,,,,,,,,,,,,, -504800,4.468836,0.60675055,,,,,,,,,,,,,, -504900,4.6296124,0.6576758,,,,,,,,,,,,,, -505000,4.3907113,0.6119272,,,,,,,,,,,,,, -505100,4.4266624,0.59738696,,,,,,,,,,,,,, -505200,4.814717,0.70059365,,,,,,,,,,,,,, -505300,4.7734857,0.6216326,,,,,,,,,,,,,, -505400,4.793711,0.6021548,,,,,,,,,,,,,, -505500,4.74708,0.6901301,,,,,,,,,,,,,, -505600,4.74081,0.6183402,,,,,,,,,,,,,, -505700,4.3724656,0.61818147,,,,,,,,,,,,,, -505800,4.1237397,0.5475384,,,,,,,,,,,,,, -505900,4.833098,0.6041574,,,,,,,,,,,,,, -506000,4.569167,0.6404958,,,,,,,,,,,,,, -506093,,,0.9606186151504515,0.1453305780887603,0.756060004234314,1.0525985956192017,50000.0,0.6278000473976135,1.8410217761993408,10000.0,169384.76676797867,175140.1451165676,169384.76676797867,5713.153592586517,24.021644353866577,0.0 -506100,4.588873,0.55333835,,,,,,,,,,,,,, -506200,4.8675528,0.5967714,,,,,,,,,,,,,, -506300,4.4899898,0.600133,,,,,,,,,,,,,, -506400,4.347247,0.69970345,,,,,,,,,,,,,, -506500,4.141884,0.55181146,,,,,,,,,,,,,, -506600,4.4080462,0.6696468,,,,,,,,,,,,,, -506700,4.751103,0.6192421,,,,,,,,,,,,,, -506800,4.236886,0.594312,,,,,,,,,,,,,, -506900,4.676292,0.6250964,,,,,,,,,,,,,, -507000,4.36323,0.5678941,,,,,,,,,,,,,, -507100,4.3267436,0.5726657,,,,,,,,,,,,,, -507200,4.869409,0.6196149,,,,,,,,,,,,,, -507300,4.8984427,0.57849467,,,,,,,,,,,,,, -507400,5.170106,0.6932842,,,,,,,,,,,,,, -507500,4.039262,0.5733673,,,,,,,,,,,,,, -507600,4.4504914,0.62138695,,,,,,,,,,,,,, -507618,,,0.9618144035339355,0.1438751220703125,0.7559199929237366,1.0549389123916626,50000.0,0.6282000541687012,1.8442169427871704,10000.0,169894.76099991798,175666.93125104904,169894.76099991798,5729.777534723282,24.134642362594604,0.0 -507700,4.6215615,0.62089354,,,,,,,,,,,,,, -507800,4.528426,0.64906317,,,,,,,,,,,,,, -507900,4.561792,0.61783576,,,,,,,,,,,,,, -508000,4.7072763,0.5809514,,,,,,,,,,,,,, -508100,4.254691,0.6470643,,,,,,,,,,,,,, -508200,4.679493,0.59638923,,,,,,,,,,,,,, -508300,4.56377,0.65608746,,,,,,,,,,,,,, -508400,4.9132156,0.637711,,,,,,,,,,,,,, -508500,5.383348,0.688647,,,,,,,,,,,,,, -508600,4.466284,0.723018,,,,,,,,,,,,,, -508700,4.68731,0.6453605,,,,,,,,,,,,,, -508800,4.379806,0.5924809,,,,,,,,,,,,,, -508900,4.949202,0.630326,,,,,,,,,,,,,, -509000,4.692222,0.6183515,,,,,,,,,,,,,, -509100,4.8512244,0.6640951,,,,,,,,,,,,,, -509143,,,0.9618542790412904,0.142910823225975,0.756339967250824,1.0527304410934448,50000.0,0.6276000142097473,1.841456770896912,10000.0,170404.71442842484,176193.6103644371,170404.71442842484,5746.335768461227,24.246805906295776,0.0 -509200,4.325314,0.5672668,,,,,,,,,,,,,, -509300,4.8788023,0.6637169,,,,,,,,,,,,,, -509400,4.800893,0.68169224,,,,,,,,,,,,,, -509500,4.6972756,0.5974701,,,,,,,,,,,,,, -509600,4.702321,0.645822,,,,,,,,,,,,,, -509700,4.8130884,0.6908909,,,,,,,,,,,,,, -509800,4.324666,0.6037988,,,,,,,,,,,,,, -509900,4.664465,0.66401625,,,,,,,,,,,,,, -510000,4.274202,0.58092123,,,,,,,,,,,,,, -510100,4.4173536,0.57896024,,,,,,,,,,,,,, -510200,4.589279,0.6509609,,,,,,,,,,,,,, -510300,4.6467566,0.63599837,,,,,,,,,,,,,, -510400,4.5362563,0.64549303,,,,,,,,,,,,,, -510500,5.044271,0.71743417,,,,,,,,,,,,,, -510600,4.3300376,0.6147836,,,,,,,,,,,,,, -510668,,,0.961355984210968,0.1437069475650787,0.7561599612236023,1.0545107126235962,50000.0,0.6284000277519226,1.842883825302124,10000.0,170914.71456050873,176720.21445131302,170914.71456050873,5762.775901317596,24.35623812675476,0.0 -510700,4.465799,0.58405775,,,,,,,,,,,,,, -510800,4.060627,0.5511999,,,,,,,,,,,,,, -510900,4.098231,0.58910096,,,,,,,,,,,,,, -511000,5.3334007,0.6856693,,,,,,,,,,,,,, -511100,4.834436,0.68957037,,,,,,,,,,,,,, -511200,5.1852965,0.6522072,,,,,,,,,,,,,, -511300,4.3688745,0.5724721,,,,,,,,,,,,,, -511400,5.0055118,0.6231561,,,,,,,,,,,,,, -511500,5.0978775,0.58178926,,,,,,,,,,,,,, -511600,4.4713407,0.6176639,,,,,,,,,,,,,, -511700,4.7196407,0.60098916,,,,,,,,,,,,,, -511800,4.4849296,0.6479451,,,,,,,,,,,,,, -511900,4.231701,0.6011597,,,,,,,,,,,,,, -512000,4.840965,0.60350865,,,,,,,,,,,,,, -512100,4.28944,0.6207932,,,,,,,,,,,,,, -512193,,,0.960339605808258,0.1483767330646515,0.7562400102615356,1.0535874366760254,50000.0,0.6279000043869019,1.8420346975326536,10000.0,171424.77597403526,177246.92549419403,171424.77597403526,5779.257610321045,24.469741344451904,0.0 -512200,5.234527,0.6357159,,,,,,,,,,,,,, -512300,4.815878,0.59764457,,,,,,,,,,,,,, -512400,4.4331927,0.62533236,,,,,,,,,,,,,, -512500,4.677942,0.57428235,,,,,,,,,,,,,, -512600,4.559441,0.583383,,,,,,,,,,,,,, -512700,4.528979,0.6018347,,,,,,,,,,,,,, -512800,4.65356,0.64154243,,,,,,,,,,,,,, -512900,4.433058,0.62526435,,,,,,,,,,,,,, -513000,4.11994,0.6121774,,,,,,,,,,,,,, -513100,4.7050066,0.68343073,,,,,,,,,,,,,, -513200,4.4681396,0.6574478,,,,,,,,,,,,,, -513300,4.870691,0.578915,,,,,,,,,,,,,, -513400,4.6942725,0.615682,,,,,,,,,,,,,, -513500,4.961204,0.6152663,,,,,,,,,,,,,, -513600,4.859407,0.67526317,,,,,,,,,,,,,, -513700,4.953802,0.67852974,,,,,,,,,,,,,, -513718,,,0.9610371589660645,0.144314095377922,0.7560399770736694,1.052869439125061,50000.0,0.626800000667572,1.8419123888015747,10000.0,171934.87462615967,177773.80760240555,171934.87462615967,5795.874584913254,24.5813729763031,0.0 -513800,4.547638,0.56717974,,,,,,,,,,,,,, -513900,4.281021,0.6073224,,,,,,,,,,,,,, -514000,4.7776313,0.6247446,,,,,,,,,,,,,, -514100,4.3942947,0.64338994,,,,,,,,,,,,,, -514200,4.3776402,0.65002644,,,,,,,,,,,,,, -514300,5.090743,0.5996111,,,,,,,,,,,,,, -514400,5.0123916,0.65623933,,,,,,,,,,,,,, -514500,4.402135,0.5779224,,,,,,,,,,,,,, -514600,4.4275727,0.5923965,,,,,,,,,,,,,, -514700,4.275314,0.6345765,,,,,,,,,,,,,, -514800,4.387221,0.6535382,,,,,,,,,,,,,, -514900,3.9908664,0.5252773,,,,,,,,,,,,,, -515000,4.4356303,0.5515445,,,,,,,,,,,,,, -515100,4.516123,0.6179369,,,,,,,,,,,,,, -515200,4.805815,0.62429863,,,,,,,,,,,,,, -515242,,,0.9625318646430968,0.1427466124296188,0.755620002746582,1.054432988166809,50000.0,0.626800000667572,1.843067049980164,10000.0,172444.89008188248,178300.44254755974,172444.89008188248,5812.326997756958,24.69336938858032,0.0 -515300,4.4915147,0.60207456,,,,,,,,,,,,,, -515400,4.5196905,0.64321494,,,,,,,,,,,,,, -515500,4.498189,0.6228649,,,,,,,,,,,,,, -515600,4.5960255,0.60679567,,,,,,,,,,,,,, -515700,4.3857193,0.63071305,,,,,,,,,,,,,, -515800,4.7360463,0.59994465,,,,,,,,,,,,,, -515900,4.235222,0.543245,,,,,,,,,,,,,, -516000,4.9216447,0.7045395,,,,,,,,,,,,,, -516100,4.469709,0.63615656,,,,,,,,,,,,,, -516200,4.276394,0.56887895,,,,,,,,,,,,,, -516300,4.2371464,0.5912485,,,,,,,,,,,,,, -516400,4.7321362,0.6471418,,,,,,,,,,,,,, -516500,4.9466467,0.615396,,,,,,,,,,,,,, -516600,4.992017,0.6344851,,,,,,,,,,,,,, -516700,4.6865716,0.60437304,,,,,,,,,,,,,, -516767,,,0.9606186151504515,0.1462768763303756,0.7558799982070923,1.0523682832717896,50000.0,0.6270000338554382,1.8397737741470337,10000.0,172954.96911978722,178827.1764163971,172954.96911978722,5828.814296245575,24.806583642959595,0.0 -516800,4.4231105,0.5695746,,,,,,,,,,,,,, -516900,4.941811,0.610664,,,,,,,,,,,,,, -517000,4.7226653,0.6169535,,,,,,,,,,,,,, -517100,4.489122,0.59895754,,,,,,,,,,,,,, -517200,4.6637897,0.6800977,,,,,,,,,,,,,, -517300,4.44527,0.64032185,,,,,,,,,,,,,, -517400,4.584763,0.6123536,,,,,,,,,,,,,, -517500,4.536417,0.64905626,,,,,,,,,,,,,, -517600,4.893542,0.6511481,,,,,,,,,,,,,, -517700,4.3066607,0.542493,,,,,,,,,,,,,, -517800,4.9678993,0.639667,,,,,,,,,,,,,, -517900,4.7671423,0.69510716,,,,,,,,,,,,,, -518000,4.4967947,0.63690555,,,,,,,,,,,,,, -518100,5.149195,0.59254104,,,,,,,,,,,,,, -518200,4.286205,0.52041286,,,,,,,,,,,,,, -518291,,,0.9604790806770324,0.1468463689088821,0.755899965763092,1.0545798540115356,50000.0,0.6271000504493713,1.843677639961243,10000.0,173464.90362358093,179353.6888029575,173464.90362358093,5845.223039627075,24.920645713806152,0.0 -518300,4.007168,0.6238733,,,,,,,,,,,,,, -518400,4.6022606,0.60528785,,,,,,,,,,,,,, -518500,4.4235635,0.6338849,,,,,,,,,,,,,, -518600,4.5873833,0.6028842,,,,,,,,,,,,,, -518700,4.2504435,0.5735903,,,,,,,,,,,,,, -518800,4.408902,0.5975211,,,,,,,,,,,,,, -518900,4.527135,0.613292,,,,,,,,,,,,,, -519000,4.391265,0.66521156,,,,,,,,,,,,,, -519100,4.7998934,0.6932848,,,,,,,,,,,,,, -519200,4.3665137,0.6664358,,,,,,,,,,,,,, -519300,4.284249,0.6288806,,,,,,,,,,,,,, -519400,4.5090694,0.6592475,,,,,,,,,,,,,, -519500,4.9726834,0.61174774,,,,,,,,,,,,,, -519600,4.5101237,0.58929175,,,,,,,,,,,,,, -519700,4.9498973,0.64008826,,,,,,,,,,,,,, -519800,4.3341494,0.6393042,,,,,,,,,,,,,, -519816,,,0.9612962007522584,0.1456597298383712,0.7555999755859375,1.0524157285690308,50000.0,0.6278000473976135,1.8394850492477417,10000.0,173974.819116354,179880.40839219093,173974.819116354,5861.859467506409,25.033472061157227,0.0 -519900,4.2226305,0.5844637,,,,,,,,,,,,,, -520000,4.267606,0.5230713,,,,,,,,,,,,,, -520100,4.779983,0.64209265,,,,,,,,,,,,,, -520200,4.48386,0.68759453,,,,,,,,,,,,,, -520300,5.2225847,0.71552473,,,,,,,,,,,,,, -520400,4.478705,0.5913094,,,,,,,,,,,,,, -520500,5.0737815,0.685373,,,,,,,,,,,,,, -520600,4.402547,0.49269885,,,,,,,,,,,,,, -520700,5.0337296,0.6843885,,,,,,,,,,,,,, -520800,4.583979,0.5835464,,,,,,,,,,,,,, -520900,4.8561254,0.5919979,,,,,,,,,,,,,, -521000,4.549122,0.61971396,,,,,,,,,,,,,, -521100,4.3651786,0.541514,,,,,,,,,,,,,, -521200,4.990673,0.6294947,,,,,,,,,,,,,, -521300,4.359197,0.62261486,,,,,,,,,,,,,, -521341,,,0.9618343114852904,0.1445403099060058,0.7559999823570251,1.05410897731781,50000.0,0.6273000240325928,1.8433527946472168,10000.0,174484.72979211807,180407.1592879296,174484.72979211807,5878.530686855316,25.147634506225582,0.0 -521400,4.8197923,0.58714956,,,,,,,,,,,,,, -521500,5.088758,0.70967484,,,,,,,,,,,,,, -521600,4.538136,0.5713005,,,,,,,,,,,,,, -521700,4.3773894,0.6244231,,,,,,,,,,,,,, -521800,4.825193,0.59030026,,,,,,,,,,,,,, -521900,4.519647,0.6049236,,,,,,,,,,,,,, -522000,4.8268833,0.68221474,,,,,,,,,,,,,, -522100,4.4286175,0.58730716,,,,,,,,,,,,,, -522200,4.610958,0.6413654,,,,,,,,,,,,,, -522300,4.614852,0.65460885,,,,,,,,,,,,,, -522400,4.138092,0.58757937,,,,,,,,,,,,,, -522500,4.6441083,0.6193123,,,,,,,,,,,,,, -522600,4.932289,0.6538849,,,,,,,,,,,,,, -522700,4.662079,0.60734487,,,,,,,,,,,,,, -522800,4.5573573,0.65462965,,,,,,,,,,,,,, -522866,,,0.961933970451355,0.1441616415977478,0.7557599544525146,1.0527174472808838,50000.0,0.6267000436782837,1.842559814453125,10000.0,174994.86545610428,180934.01882386208,174994.86545610428,5895.08545088768,25.262465476989743,0.0 -522900,4.695127,0.652891,,,,,,,,,,,,,, -523000,4.474781,0.5630832,,,,,,,,,,,,,, -523100,4.1968927,0.619519,,,,,,,,,,,,,, -523200,4.5892854,0.6121873,,,,,,,,,,,,,, -523300,4.475266,0.6362339,,,,,,,,,,,,,, -523400,4.7717214,0.6138316,,,,,,,,,,,,,, -523500,5.4140306,0.60590917,,,,,,,,,,,,,, -523600,4.8034997,0.6234239,,,,,,,,,,,,,, -523700,4.505289,0.63136655,,,,,,,,,,,,,, -523800,4.411556,0.5931481,,,,,,,,,,,,,, -523900,4.7047634,0.711261,,,,,,,,,,,,,, -524000,4.5683494,0.7020595,,,,,,,,,,,,,, -524100,4.6483817,0.6902298,,,,,,,,,,,,,, -524200,4.4780045,0.6204505,,,,,,,,,,,,,, -524300,4.1316123,0.5569994,,,,,,,,,,,,,, -524390,,,0.9607780575752258,0.1470319777727127,0.7557599544525146,1.053484320640564,50000.0,0.6276000142097473,1.8422954082489007,10000.0,175504.70123004913,181460.5635290146,175504.70123004913,5911.562630414963,25.43892741203308,0.0 -524400,4.47533,0.5750861,,,,,,,,,,,,,, -524500,4.6976275,0.58541536,,,,,,,,,,,,,, -524600,4.540474,0.6038301,,,,,,,,,,,,,, -524700,4.992748,0.6051353,,,,,,,,,,,,,, -524800,4.249572,0.57925373,,,,,,,,,,,,,, -524900,4.605364,0.6309927,,,,,,,,,,,,,, -525000,4.400017,0.5794605,,,,,,,,,,,,,, -525100,4.567641,0.6705934,,,,,,,,,,,,,, -525200,6.3223047,0.6656358,,,,,,,,,,,,,, -525300,4.580685,0.68591696,,,,,,,,,,,,,, -525400,4.2793436,0.5965566,,,,,,,,,,,,,, -525500,4.423162,0.59158915,,,,,,,,,,,,,, -525600,4.0610523,0.55233663,,,,,,,,,,,,,, -525700,5.968299,0.71561563,,,,,,,,,,,,,, -525800,4.627531,0.64231217,,,,,,,,,,,,,, -525900,4.4168468,0.6503611,,,,,,,,,,,,,, -525914,,,0.958984375,0.1497214436531067,0.7561599612236023,1.0536900758743286,50000.0,0.6270000338554382,1.843161582946777,10000.0,176014.61433887482,181987.117274046,176014.61433887482,5928.032803058624,25.55296421051025,0.0 -526000,4.776521,0.67613983,,,,,,,,,,,,,, -526100,4.643796,0.5957318,,,,,,,,,,,,,, -526200,4.561967,0.6334095,,,,,,,,,,,,,, -526300,4.7800484,0.5902704,,,,,,,,,,,,,, -526400,4.6078887,0.5845027,,,,,,,,,,,,,, -526500,5.2102523,0.7031007,,,,,,,,,,,,,, -526600,4.768908,0.6131947,,,,,,,,,,,,,, -526700,5.0647206,0.5975107,,,,,,,,,,,,,, -526800,4.976682,0.654761,,,,,,,,,,,,,, -526900,4.2261224,0.5320519,,,,,,,,,,,,,, -527000,4.344735,0.5704705,,,,,,,,,,,,,, -527100,4.5753384,0.6248565,,,,,,,,,,,,,, -527200,4.524,0.6477424,,,,,,,,,,,,,, -527300,5.07694,0.6131855,,,,,,,,,,,,,, -527400,4.207258,0.59213114,,,,,,,,,,,,,, -527438,,,0.9616350531578064,0.1456404328346252,0.756060004234314,1.0531024932861328,50000.0,0.6278000473976135,1.841441512107849,10000.0,176524.59498858452,182514.1163351536,176524.59498858452,5944.878372192383,25.67080807685852,0.0 -527500,4.1282463,0.595126,,,,,,,,,,,,,, -527600,4.5238924,0.6611868,,,,,,,,,,,,,, -527700,4.568972,0.62792444,,,,,,,,,,,,,, -527800,4.46744,0.68249524,,,,,,,,,,,,,, -527900,4.7204165,0.7169945,,,,,,,,,,,,,, -528000,4.307324,0.603876,,,,,,,,,,,,,, -528100,4.514718,0.68219876,,,,,,,,,,,,,, -528200,4.9053164,0.6817671,,,,,,,,,,,,,, -528300,4.5569806,0.62665313,,,,,,,,,,,,,, -528400,4.0583057,0.52784526,,,,,,,,,,,,,, -528500,4.8231297,0.6158608,,,,,,,,,,,,,, -528600,4.7360644,0.5855665,,,,,,,,,,,,,, -528700,4.5292716,0.6581434,,,,,,,,,,,,,, -528800,4.8396616,0.6563848,,,,,,,,,,,,,, -528900,5.0669374,0.57818174,,,,,,,,,,,,,, -528962,,,0.9606186151504515,0.1452937424182891,0.7559199929237366,1.05355703830719,50000.0,0.627500057220459,1.841924071311951,10000.0,177034.4949579239,183040.6117610932,177034.4949579239,5961.299426794052,25.78989005088806,0.0 -529000,5.2993755,0.6761712,,,,,,,,,,,,,, -529100,4.137909,0.62060726,,,,,,,,,,,,,, -529200,4.805147,0.60394996,,,,,,,,,,,,,, -529300,4.9543304,0.59239703,,,,,,,,,,,,,, -529400,4.3381,0.57908326,,,,,,,,,,,,,, -529500,4.3098974,0.6121148,,,,,,,,,,,,,, -529600,3.919882,0.57360786,,,,,,,,,,,,,, -529700,4.4417443,0.66682744,,,,,,,,,,,,,, -529800,4.550863,0.6069283,,,,,,,,,,,,,, -529900,4.571077,0.612324,,,,,,,,,,,,,, -530000,4.57463,0.6355525,,,,,,,,,,,,,, -530100,4.456437,0.61051273,,,,,,,,,,,,,, -530200,3.9090583,0.5333886,,,,,,,,,,,,,, -530300,4.7857895,0.6638588,,,,,,,,,,,,,, -530400,4.65207,0.650016,,,,,,,,,,,,,, -530484,,,0.960957407951355,0.1444409638643264,0.7559399604797363,1.053533673286438,50000.0,0.6266000270843506,1.84251868724823,10000.0,177543.64189291,183567.61677789688,177543.64189291,5978.160856962204,26.731829404830933,0.0 -530500,4.6734257,0.6167467,,,,,,,,,,,,,, -530600,4.2664037,0.56335294,,,,,,,,,,,,,, -530700,4.644364,0.6520174,,,,,,,,,,,,,, -530800,4.6340857,0.6478283,,,,,,,,,,,,,, -530900,4.816488,0.67946655,,,,,,,,,,,,,, -531000,4.8858256,0.6860359,,,,,,,,,,,,,, -531100,4.2386475,0.6116959,,,,,,,,,,,,,, -531200,4.336876,0.6060748,,,,,,,,,,,,,, -531300,4.740811,0.69780207,,,,,,,,,,,,,, -531400,4.731096,0.5920732,,,,,,,,,,,,,, -531500,4.378274,0.6214011,,,,,,,,,,,,,, -531600,4.7076235,0.6360512,,,,,,,,,,,,,, -531700,4.5505557,0.6137471,,,,,,,,,,,,,, -531800,4.3511095,0.56899893,,,,,,,,,,,,,, -531900,4.645086,0.6455606,,,,,,,,,,,,,, -532000,4.541383,0.6426327,,,,,,,,,,,,,, -532009,,,0.9603993892669678,0.1479032784700393,0.7558799982070923,1.0539860725402832,50000.0,0.6270000338554382,1.843008637428284,10000.0,178053.5409553051,184094.20546674728,178053.5409553051,5994.678615808487,26.84842109680176,0.0 -532100,4.632339,0.6281428,,,,,,,,,,,,,, -532200,4.4410167,0.5647638,,,,,,,,,,,,,, -532300,4.6139245,0.6660494,,,,,,,,,,,,,, -532400,4.526094,0.622298,,,,,,,,,,,,,, -532500,5.189747,0.6359619,,,,,,,,,,,,,, -532600,4.819828,0.66842914,,,,,,,,,,,,,, -532700,4.351941,0.5991864,,,,,,,,,,,,,, -532800,4.503009,0.62315094,,,,,,,,,,,,,, -532900,4.5041676,0.68845487,,,,,,,,,,,,,, -533000,4.5600286,0.6526836,,,,,,,,,,,,,, -533100,4.6615753,0.66505665,,,,,,,,,,,,,, -533200,4.4468203,0.6212861,,,,,,,,,,,,,, -533300,4.721327,0.60756993,,,,,,,,,,,,,, -533400,4.5651135,0.66827774,,,,,,,,,,,,,, -533500,4.753378,0.5892643,,,,,,,,,,,,,, -533533,,,0.961336076259613,0.1444842964410781,0.7559799551963806,1.0537562370300293,50000.0,0.6273000240325928,1.8435173034667969,10000.0,178563.4308130741,184620.69504141808,178563.4308130741,6011.105018854141,26.96674156188965,0.0 -533600,4.4691153,0.5611923,,,,,,,,,,,,,, -533700,4.582348,0.6130618,,,,,,,,,,,,,, -533800,4.6094913,0.57035583,,,,,,,,,,,,,, -533900,4.5953135,0.6157236,,,,,,,,,,,,,, -534000,5.181254,0.65769744,,,,,,,,,,,,,, -534100,4.51727,0.6795444,,,,,,,,,,,,,, -534200,4.150234,0.61914235,,,,,,,,,,,,,, -534300,4.580542,0.65498954,,,,,,,,,,,,,, -534400,4.5157566,0.6022147,,,,,,,,,,,,,, -534500,4.515368,0.6090156,,,,,,,,,,,,,, -534600,4.3239794,0.5823294,,,,,,,,,,,,,, -534700,4.2233205,0.61349684,,,,,,,,,,,,,, -534800,4.261032,0.54328585,,,,,,,,,,,,,, -534900,4.5299172,0.63791823,,,,,,,,,,,,,, -535000,4.458574,0.6094297,,,,,,,,,,,,,, -535058,,,0.96097731590271,0.1471495926380157,0.7558000087738037,1.0533289909362793,50000.0,0.6281000375747681,1.842696189880371,10000.0,179073.46681928635,185147.4665656089,179073.46681928635,6027.668654680252,27.084050178527832,0.0 -535100,4.765002,0.6469121,,,,,,,,,,,,,, -535200,4.2734804,0.5309758,,,,,,,,,,,,,, -535300,4.2038684,0.60229903,,,,,,,,,,,,,, -535400,4.4290833,0.6662001,,,,,,,,,,,,,, -535500,4.3842096,0.56476384,,,,,,,,,,,,,, -535600,4.5577216,0.6189643,,,,,,,,,,,,,, -535700,4.561962,0.58736455,,,,,,,,,,,,,, -535800,4.7518215,0.6462135,,,,,,,,,,,,,, -535900,4.699509,0.5929978,,,,,,,,,,,,,, -536000,4.865716,0.647278,,,,,,,,,,,,,, -536100,4.3811593,0.63248026,,,,,,,,,,,,,, -536200,4.260075,0.5670941,,,,,,,,,,,,,, -536300,4.514418,0.6238811,,,,,,,,,,,,,, -536400,4.1969314,0.64896804,,,,,,,,,,,,,, -536500,4.7847176,0.57903445,,,,,,,,,,,,,, -536505,,,0.9612962007522584,0.1462796926498413,0.7561399936676025,1.0543361902236938,50000.0,0.6271000504493713,1.84367835521698,10000.0,179583.57815003395,185674.32324552536,179583.57815003395,6044.2437698841095,27.201574563980103,0.0 -536600,4.5334873,0.63781226,,,,,,,,,,,,,, -536700,4.4716043,0.59379494,,,,,,,,,,,,,, -536800,4.9229956,0.7315494,,,,,,,,,,,,,, -536900,4.5011973,0.6468653,,,,,,,,,,,,,, -537000,4.1553783,0.5796433,,,,,,,,,,,,,, -537100,4.2305923,0.58866876,,,,,,,,,,,,,, -537200,4.7338157,0.5545239,,,,,,,,,,,,,, -537300,4.446261,0.6672071,,,,,,,,,,,,,, -537400,4.4084563,0.6218942,,,,,,,,,,,,,, -537500,4.234739,0.54490703,,,,,,,,,,,,,, -537600,4.526903,0.6337438,,,,,,,,,,,,,, -537700,5.133844,0.6595319,,,,,,,,,,,,,, -537800,4.7470818,0.65018535,,,,,,,,,,,,,, -537900,4.4875813,0.6595112,,,,,,,,,,,,,, -538000,4.8761415,0.7273535,,,,,,,,,,,,,, -538029,,,0.9604790806770324,0.1484165489673614,0.7555800080299377,1.0530855655670166,50000.0,0.6276000142097473,1.841093063354492,10000.0,180093.41668057442,186200.8007659912,180093.41668057442,6060.702314376831,27.3272008895874,0.0 -538100,4.3575625,0.6136626,,,,,,,,,,,,,, -538200,4.4842944,0.6195931,,,,,,,,,,,,,, -538300,4.4932623,0.6447457,,,,,,,,,,,,,, -538400,4.3974576,0.5274883,,,,,,,,,,,,,, -538500,4.433351,0.5847382,,,,,,,,,,,,,, -538600,4.346185,0.6359593,,,,,,,,,,,,,, -538700,4.1166925,0.54988647,,,,,,,,,,,,,, -538800,4.574645,0.64299685,,,,,,,,,,,,,, -538900,4.7934294,0.63424724,,,,,,,,,,,,,, -539000,4.5721126,0.6232057,,,,,,,,,,,,,, -539100,4.195144,0.5924355,,,,,,,,,,,,,, -539200,4.504337,0.60034204,,,,,,,,,,,,,, -539300,4.8180337,0.6146891,,,,,,,,,,,,,, -539400,4.241727,0.6085331,,,,,,,,,,,,,, -539500,4.3245625,0.6180578,,,,,,,,,,,,,, -539554,,,0.9626315236091614,0.1421427875757217,0.7562999725341797,1.053134799003601,50000.0,0.6273000240325928,1.8427072763442995,10000.0,180603.4564025402,186727.56794548035,180603.4564025402,6077.2591071128845,27.44335126876831,0.0 -539600,4.740704,0.55987245,,,,,,,,,,,,,, -539700,4.1570144,0.561297,,,,,,,,,,,,,, -539800,4.588585,0.6143201,,,,,,,,,,,,,, -539900,5.0822597,0.67310977,,,,,,,,,,,,,, -540000,4.5882425,0.6473323,,,,,,,,,,,,,, -540100,4.756973,0.6316428,,,,,,,,,,,,,, -540200,4.8517566,0.6351491,,,,,,,,,,,,,, -540300,4.1634884,0.60351807,,,,,,,,,,,,,, -540400,4.9491763,0.6064054,,,,,,,,,,,,,, -540500,5.5169835,0.70794004,,,,,,,,,,,,,, -540600,4.0078797,0.55883676,,,,,,,,,,,,,, -540700,4.8066435,0.68483496,,,,,,,,,,,,,, -540800,4.7926893,0.647758,,,,,,,,,,,,,, -540900,4.485292,0.69417113,,,,,,,,,,,,,, -541000,4.3089185,0.60586286,,,,,,,,,,,,,, -541078,,,0.961933970451355,0.1462670266628265,0.7557599544525146,1.0533305406570437,50000.0,0.627500057220459,1.8408427238464355,10000.0,181113.3877146244,187254.3824417591,181113.3877146244,6093.973418951035,27.557010173797607,0.0 -541100,4.4424067,0.5836692,,,,,,,,,,,,,, -541200,4.37156,0.5400392,,,,,,,,,,,,,, -541300,4.4390035,0.6018193,,,,,,,,,,,,,, -541400,4.4870253,0.6885827,,,,,,,,,,,,,, -541500,4.3784556,0.60773164,,,,,,,,,,,,,, -541600,4.777118,0.6026366,,,,,,,,,,,,,, -541700,4.9460235,0.68677306,,,,,,,,,,,,,, -541800,4.7883186,0.6264186,,,,,,,,,,,,,, -541900,4.17121,0.5958662,,,,,,,,,,,,,, -542000,4.699129,0.6378609,,,,,,,,,,,,,, -542100,4.0786815,0.55964434,,,,,,,,,,,,,, -542200,4.2319903,0.5872939,,,,,,,,,,,,,, -542300,5.1068125,0.60827404,,,,,,,,,,,,,, -542400,4.532334,0.6598502,,,,,,,,,,,,,, -542500,4.3453517,0.60445863,,,,,,,,,,,,,, -542600,4.7604322,0.62726456,,,,,,,,,,,,,, -542602,,,0.962890625,0.1419719606637954,0.7557399868965149,1.0550309419631958,50000.0,0.6265000104904175,1.8440637588500977,10000.0,181623.31734347343,187781.3084187508,181623.31734347343,6110.797034263611,27.67525243759156,0.0 -542700,4.5485835,0.6260137,,,,,,,,,,,,,, -542800,4.460746,0.5846852,,,,,,,,,,,,,, -542900,4.412121,0.5798645,,,,,,,,,,,,,, -543000,4.2810698,0.5900954,,,,,,,,,,,,,, -543100,4.182479,0.59887356,,,,,,,,,,,,,, -543200,4.536531,0.6473115,,,,,,,,,,,,,, -543300,4.8237867,0.6290068,,,,,,,,,,,,,, -543400,4.4209237,0.57642066,,,,,,,,,,,,,, -543500,5.321451,0.57423466,,,,,,,,,,,,,, -543600,4.289183,0.6117723,,,,,,,,,,,,,, -543700,4.6317806,0.6538396,,,,,,,,,,,,,, -543800,4.544804,0.61506844,,,,,,,,,,,,,, -543900,4.719925,0.61803275,,,,,,,,,,,,,, -544000,4.684595,0.6366585,,,,,,,,,,,,,, -544100,5.0491548,0.6444664,,,,,,,,,,,,,, -544127,,,0.9596819281578064,0.1471571922302246,0.755899965763092,1.0522381067276,50000.0,0.6279000043869019,1.8415024280548096,10000.0,182133.297362566,188308.0508189201,182133.297362566,6127.386559247971,27.79384064674377,0.0 -544200,4.701121,0.65213704,,,,,,,,,,,,,, -544300,4.6377726,0.6005422,,,,,,,,,,,,,, -544400,4.3671403,0.6562505,,,,,,,,,,,,,, -544500,3.9343605,0.569952,,,,,,,,,,,,,, -544600,4.6575017,0.63301057,,,,,,,,,,,,,, -544700,4.4706583,0.6300749,,,,,,,,,,,,,, -544800,5.861497,0.6144641,,,,,,,,,,,,,, -544900,5.3398395,0.65160197,,,,,,,,,,,,,, -545000,4.6099215,0.6597897,,,,,,,,,,,,,, -545100,4.227875,0.5945039,,,,,,,,,,,,,, -545200,5.3852253,0.7068888,,,,,,,,,,,,,, -545300,4.4458723,0.5636585,,,,,,,,,,,,,, -545400,5.0179987,0.72992814,,,,,,,,,,,,,, -545500,4.6793966,0.63719106,,,,,,,,,,,,,, -545600,4.575453,0.60920185,,,,,,,,,,,,,, -545652,,,0.9605787396430968,0.145737811923027,0.7560999989509583,1.053802251815796,50000.0,0.6266000270843506,1.8414411544799805,10000.0,182643.4147348404,188834.6835012436,182643.4147348404,6143.728442192078,27.91211032867432,0.0 -545700,4.7143335,0.68772745,,,,,,,,,,,,,, -545800,4.329837,0.6025178,,,,,,,,,,,,,, -545900,4.9741216,0.5347481,,,,,,,,,,,,,, -546000,4.554979,0.62944776,,,,,,,,,,,,,, -546100,4.865916,0.623157,,,,,,,,,,,,,, -546200,4.308937,0.57236165,,,,,,,,,,,,,, -546300,4.264748,0.6328747,,,,,,,,,,,,,, -546400,4.9403358,0.69806504,,,,,,,,,,,,,, -546500,4.6761484,0.690206,,,,,,,,,,,,,, -546600,4.644952,0.69449323,,,,,,,,,,,,,, -546700,4.7603602,0.6011862,,,,,,,,,,,,,, -546800,4.2645082,0.5693356,,,,,,,,,,,,,, -546900,5.3993506,0.61519,,,,,,,,,,,,,, -547000,4.3420534,0.5654418,,,,,,,,,,,,,, -547100,5.1754403,0.6323074,,,,,,,,,,,,,, -547176,,,0.9627709984779358,0.1430668830871582,0.7557599544525146,1.0532115697860718,50000.0,0.6276000142097473,1.8418182134628296,10000.0,183153.32903194427,189361.20776104927,183153.32903194427,6160.168985366821,28.02666687965393,0.0 -547200,4.5764594,0.6626648,,,,,,,,,,,,,, -547300,4.4966664,0.58333796,,,,,,,,,,,,,, -547400,4.4077187,0.627046,,,,,,,,,,,,,, -547500,4.351196,0.6574564,,,,,,,,,,,,,, -547600,4.3196373,0.62901413,,,,,,,,,,,,,, -547700,6.114619,0.65073586,,,,,,,,,,,,,, -547800,4.2015934,0.6168034,,,,,,,,,,,,,, -547900,4.554284,0.60131925,,,,,,,,,,,,,, -548000,5.305532,0.6729773,,,,,,,,,,,,,, -548100,4.3361096,0.5512657,,,,,,,,,,,,,, -548200,4.453591,0.58770734,,,,,,,,,,,,,, -548300,4.5716667,0.6359855,,,,,,,,,,,,,, -548400,4.4031987,0.6790807,,,,,,,,,,,,,, -548500,4.4665194,0.6217831,,,,,,,,,,,,,, -548600,4.3519483,0.61678666,,,,,,,,,,,,,, -548700,,,0.9610371589660645,0.1452589631080627,0.7557799816131592,1.0536679029464722,50000.0,0.6270000338554382,1.84215247631073,10000.0,183663.2982008457,189887.9575984478,183663.2982008457,6176.777256727219,28.14390015602112,0.0 -548700,5.0245876,0.63652337,,,,,,,,,,,,,, -548800,4.2248626,0.5974211,,,,,,,,,,,,,, -548900,4.6743803,0.64785385,,,,,,,,,,,,,, -549000,5.367786,0.62301064,,,,,,,,,,,,,, -549100,4.3153925,0.55050325,,,,,,,,,,,,,, -549200,4.43697,0.65383375,,,,,,,,,,,,,, -549300,4.265665,0.61950976,,,,,,,,,,,,,, -549400,4.5704684,0.60172296,,,,,,,,,,,,,, -549500,4.9834366,0.6294066,,,,,,,,,,,,,, -549600,4.152566,0.576192,,,,,,,,,,,,,, -549700,4.543679,0.6136726,,,,,,,,,,,,,, -549800,4.2384944,0.59903646,,,,,,,,,,,,,, -549900,4.2578883,0.6546815,,,,,,,,,,,,,, -550000,5.0527945,0.66023827,,,,,,,,,,,,,, -550100,4.4543023,0.6516067,,,,,,,,,,,,,, -550200,4.8604555,0.6158688,,,,,,,,,,,,,, -550225,,,0.9610969424247742,0.1436194628477096,0.756060004234314,1.0538721084594729,50000.0,0.6277000308036804,1.8426450490951536,10000.0,184173.2213394642,190414.57474303248,184173.2213394642,6193.291128158569,28.26856780052185,0.0 -550300,4.631459,0.59065634,,,,,,,,,,,,,, -550400,4.5177813,0.6128503,,,,,,,,,,,,,, -550500,4.26663,0.62540376,,,,,,,,,,,,,, -550600,4.565075,0.6285277,,,,,,,,,,,,,, -550700,4.7602377,0.6461604,,,,,,,,,,,,,, -550800,4.645562,0.6541556,,,,,,,,,,,,,, -550900,4.225107,0.5906207,,,,,,,,,,,,,, -551000,4.7478905,0.6181554,,,,,,,,,,,,,, -551100,4.938554,0.6277505,,,,,,,,,,,,,, -551200,4.6245365,0.6209064,,,,,,,,,,,,,, -551300,4.2950587,0.6168442,,,,,,,,,,,,,, -551400,4.6340404,0.6247663,,,,,,,,,,,,,, -551500,4.547006,0.636678,,,,,,,,,,,,,, -551600,5.0045595,0.6778442,,,,,,,,,,,,,, -551700,4.4179363,0.6190818,,,,,,,,,,,,,, -551749,,,0.9610371589660645,0.1476167142391204,0.7555999755859375,1.0538707971572876,50000.0,0.6265000104904175,1.842309713363648,10000.0,184683.1692752838,190941.2047362328,184683.1692752838,6209.7367441654205,28.450235843658447,0.0 -551800,4.2507153,0.57797724,,,,,,,,,,,,,, -551900,4.4318676,0.5485876,,,,,,,,,,,,,, -552000,4.749039,0.7497356,,,,,,,,,,,,,, -552100,4.5545096,0.5665038,,,,,,,,,,,,,, -552200,4.771418,0.64529216,,,,,,,,,,,,,, -552300,4.5018954,0.59035254,,,,,,,,,,,,,, -552400,5.2252326,0.6888699,,,,,,,,,,,,,, -552500,4.1383395,0.52231264,,,,,,,,,,,,,, -552600,5.5690904,0.56263566,,,,,,,,,,,,,, -552700,4.717896,0.7070052,,,,,,,,,,,,,, -552800,4.5697145,0.6839464,,,,,,,,,,,,,, -552900,4.9133587,0.69151837,,,,,,,,,,,,,, -553000,4.604476,0.546784,,,,,,,,,,,,,, -553100,4.2702475,0.6488395,,,,,,,,,,,,,, -553200,4.433193,0.59904003,,,,,,,,,,,,,, -553274,,,0.961136758327484,0.1457731872797012,0.7558599710464478,1.0535832643508911,50000.0,0.6269000172615051,1.841801643371582,10000.0,185193.2670288086,191467.9659619332,185193.2670288086,6226.223253250122,28.5725200176239,0.0 -553300,4.5651155,0.6579961,,,,,,,,,,,,,, -553400,5.0558515,0.71117616,,,,,,,,,,,,,, -553500,3.8647115,0.54680514,,,,,,,,,,,,,, -553600,4.7465878,0.6495926,,,,,,,,,,,,,, -553700,4.5474143,0.6119635,,,,,,,,,,,,,, -553800,4.524197,0.6170745,,,,,,,,,,,,,, -553900,4.2804713,0.58606213,,,,,,,,,,,,,, -554000,4.7214165,0.62561345,,,,,,,,,,,,,, -554100,4.835414,0.65818316,,,,,,,,,,,,,, -554200,4.477627,0.6738634,,,,,,,,,,,,,, -554300,4.192527,0.57162017,,,,,,,,,,,,,, -554400,4.8711967,0.6908881,,,,,,,,,,,,,, -554500,4.2518554,0.60292214,,,,,,,,,,,,,, -554600,4.3901,0.6336104,,,,,,,,,,,,,, -554700,4.2090845,0.60647064,,,,,,,,,,,,,, -554799,,,0.9621731042861938,0.1420705914497375,0.7558199763298035,1.052920937538147,50000.0,0.6267000436782837,1.842273473739624,10000.0,185703.3607866764,191994.8460161686,185703.3607866764,6242.835218429565,28.690925359725952,0.0 -554800,4.5174313,0.5701225,,,,,,,,,,,,,, -554900,4.7087135,0.61187845,,,,,,,,,,,,,, -555000,4.516648,0.6305645,,,,,,,,,,,,,, -555100,4.4607472,0.58657396,,,,,,,,,,,,,, -555200,5.0747495,0.57347506,,,,,,,,,,,,,, -555300,4.504955,0.6219678,,,,,,,,,,,,,, -555400,4.7417884,0.674124,,,,,,,,,,,,,, -555500,4.732375,0.6739352,,,,,,,,,,,,,, -555600,4.6890388,0.63184786,,,,,,,,,,,,,, -555700,4.472231,0.61485577,,,,,,,,,,,,,, -555800,4.743808,0.5994712,,,,,,,,,,,,,, -555900,4.519266,0.61639535,,,,,,,,,,,,,, -556000,4.4761133,0.6029073,,,,,,,,,,,,,, -556100,4.566155,0.5472228,,,,,,,,,,,,,, -556200,4.845163,0.66366965,,,,,,,,,,,,,, -556300,4.195363,0.5272296,,,,,,,,,,,,,, -556324,,,0.9607979655265808,0.1460682153701782,0.7557199597358704,1.053757905960083,50000.0,0.6276000142097473,1.8436654806137085,10000.0,186213.40687561035,192521.5072004795,186213.40687561035,6259.27534365654,28.811352014541622,0.0 -556400,4.224648,0.574523,,,,,,,,,,,,,, -556500,4.7373743,0.63635296,,,,,,,,,,,,,, -556600,4.585903,0.6192189,,,,,,,,,,,,,, -556700,4.1421576,0.58386445,,,,,,,,,,,,,, -556800,4.879759,0.64999765,,,,,,,,,,,,,, -556900,4.720467,0.5892005,,,,,,,,,,,,,, -557000,4.605098,0.7060687,,,,,,,,,,,,,, -557100,4.108129,0.57733226,,,,,,,,,,,,,, -557200,4.466889,0.5997936,,,,,,,,,,,,,, -557300,4.259212,0.5779652,,,,,,,,,,,,,, -557400,4.6538157,0.66874516,,,,,,,,,,,,,, -557500,4.5507455,0.57741624,,,,,,,,,,,,,, -557600,4.48176,0.62778044,,,,,,,,,,,,,, -557700,4.881882,0.6408174,,,,,,,,,,,,,, -557800,4.822842,0.63469005,,,,,,,,,,,,,, -557849,,,0.9605787396430968,0.1464001685380935,0.7560200095176697,1.0539249181747437,50000.0,0.6270000338554382,1.843936562538147,10000.0,186723.44114279747,193048.2424557209,186723.44114279747,6275.80174779892,28.93087267875672,0.0 -557900,4.6014376,0.64706105,,,,,,,,,,,,,, -558000,4.0831413,0.5889343,,,,,,,,,,,,,, -558100,4.4221244,0.6099956,,,,,,,,,,,,,, -558200,4.300793,0.6057701,,,,,,,,,,,,,, -558300,4.3926716,0.65638167,,,,,,,,,,,,,, -558400,4.5037518,0.5966309,,,,,,,,,,,,,, -558500,4.922428,0.615217,,,,,,,,,,,,,, -558600,4.455594,0.60470587,,,,,,,,,,,,,, -558700,4.673192,0.62268,,,,,,,,,,,,,, -558800,4.5616503,0.6321124,,,,,,,,,,,,,, -558900,4.226509,0.58909214,,,,,,,,,,,,,, -559000,4.519611,0.6299919,,,,,,,,,,,,,, -559100,4.8954067,0.6054379,,,,,,,,,,,,,, -559200,4.8473554,0.6737082,,,,,,,,,,,,,, -559300,4.859651,0.6167862,,,,,,,,,,,,,, -559374,,,0.9610171914100648,0.1468623131513595,0.7558000087738037,1.052876591682434,50000.0,0.6279000043869019,1.8411613702774048,10000.0,187233.40432572365,193574.8201098442,187233.40432572365,6292.241653680801,29.04983305931092,0.0 -559400,4.637806,0.67424476,,,,,,,,,,,,,, -559500,4.6320686,0.5998235,,,,,,,,,,,,,, -559600,5.0787973,0.5977039,,,,,,,,,,,,,, -559700,4.3449483,0.616843,,,,,,,,,,,,,, -559800,5.492806,0.6768433,,,,,,,,,,,,,, -559900,4.1258326,0.6288353,,,,,,,,,,,,,, -559998,,,0.9618144035339355,0.1436149626970291,0.7557799816131592,1.053080916404724,50000.0,0.6265000104904175,1.8419371843338013,10000.0,187441.93405771253,193799.88971567157,187441.93405771253,6308.638211250305,29.170271158218384,0.0 -559998,,,,,,,,,,,187441.93405771255,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index e2eb5a50e..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,555 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.48432898521423,0.0,42.01926302909851,1,0,42.01926302909851,0.0010000000474974,6.907756805419922,10000,82.50370907783508,0.0010351561941206,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -61.537635803222656,0.0274052619934082,462.3217113018036,926,0,462.3217113018036,0.0266000013798475,6.037107467651367,10000,523.9356322288513,0.034375000745058,5.894576072692871,0.032999999821186,5.924591541290283,50000 -81.992928981781,0.0549311637878418,882.624751329422,1901,0,882.624751329422,0.0599000044167041,5.397579193115234,10000,964.7744688987732,0.0853906199336052,5.1030755043029785,0.0809599980711937,5.148300647735596,50000 -103.20603942871094,0.0836653709411621,1302.5669553279877,2877,0,1302.5669553279877,0.1026000082492828,4.934473514556885,10000,1406.0095813274384,0.1455273479223251,4.522874355316162,0.1358200013637542,4.604587554931641,50000 -124.155011177063,0.1127946376800537,1722.5828483104706,3854,0,1722.5828483104706,0.1454000025987625,4.553320407867432,10000,1847.055628299713,0.2039648443460464,4.073319435119629,0.1905999928712844,4.152755260467529,50000 -147.63187551498413,0.1450769901275634,2142.8305237293243,4828,0,2142.8305237293243,0.195700004696846,4.1686320304870605,10000,2290.86334848404,0.2770312428474426,3.557164430618286,0.2561799883842468,3.682000160217285,50000 -168.98122477531433,0.1753633022308349,2563.035956144333,5795,0,2563.035956144333,0.2262000143527984,3.907768249511719,10000,2732.498987197876,0.3194335997104645,3.239790916442871,0.2967000007629394,3.379960775375366,50000 -193.0559239387512,0.2038640975952148,2983.3539803028107,6764,0,2983.3539803028107,0.2533000111579895,3.719790458679199,10000,3176.9725415706635,0.36865234375,2.939857721328736,0.3310999870300293,3.159950494766236,50000 -218.21373295784,0.2344825267791748,3403.5411076545715,7732,0,3403.5411076545715,0.2832000255584717,3.5490877628326416,10000,3622.407206058502,0.3927929699420929,2.838428497314453,0.3671599924564361,2.9767770767211914,50000 -249.78102779388428,0.2738537788391113,3823.8924305439,8701,0,3823.8924305439,0.2953000068664551,3.4372708797454834,10000,4074.4161858558655,0.41845703125,2.6635642051696777,0.389739990234375,2.8288674354553223,50000 -275.4603519439697,0.3063430786132812,4244.15566444397,9674,0,4244.15566444397,0.3086000084877014,3.3669235706329346,10000,4520.442771434784,0.4361914098262787,2.5769314765930176,0.3981599807739258,2.77098035812378,50000 -300.88571643829346,0.3400704860687256,4664.272991895676,10642,0,4664.272991895676,0.3243000209331512,3.2398502826690674,10000,4966.070648193359,0.4862890541553497,2.301498889923096,0.4239600002765655,2.618833065032959,50000 -332.84100675582886,0.3719866275787353,5084.614083766937,11610,0,5084.614083766937,0.3436000049114227,3.111103773117065,10000,5418.4511461257935,0.4810546636581421,2.291144609451294,0.4471199810504913,2.464445114135742,50000 -364.16368222236633,0.4047729969024658,5504.788379192352,12579,0,5504.788379192352,0.3518000245094299,3.076727867126465,10000,5870.032203912735,0.4947851598262787,2.2269725799560547,0.4521400034427643,2.441700935363769,50000 -390.2532448768616,0.436398983001709,5924.957451581955,13548,0,5924.957451581955,0.3651000261306762,2.996798276901245,10000,6316.37331199646,0.5150390267372131,2.110365390777588,0.4720799922943115,2.342987060546875,50000 -419.6220715045929,0.4737794399261474,6345.0926015377045,14515,0,6345.0926015377045,0.3729000091552734,2.967996597290039,10000,6765.965216636658,0.5230273604393005,2.1068804264068604,0.4803399741649627,2.3169755935668945,50000 -451.0975050926209,0.50453782081604,6765.137340545654,15480,0,6765.137340545654,0.3804000318050384,2.88616681098938,10000,7217.566586494446,0.5296484231948853,2.038872718811035,0.4894999861717224,2.242551803588867,50000 -485.8946087360382,0.53570556640625,7185.215455293655,16446,0,7185.215455293655,0.3939000070095062,2.830028295516968,10000,7672.523951292038,0.5384374856948853,1.977601408958435,0.502299964427948,2.171231746673584,50000 -520.9082908630371,0.5695958137512207,7605.359039306641,17412,0,7605.359039306641,0.3972000181674957,2.81754994392395,10000,8127.765516996384,0.5560351610183716,1.9124348163604736,0.507420003414154,2.158041477203369,50000 -555.1550786495209,0.5985524654388428,8025.422616481781,18373,0,8025.422616481781,0.4043000340461731,2.763168811798096,10000,8582.154694318771,0.5530468821525574,1.9021896123886108,0.5165199637413025,2.088361978530884,50000 -592.0320270061493,0.6309165954589844,8445.353466033936,19330,0,8445.353466033936,0.3993000090122223,2.775531530380249,10000,9039.04480600357,0.5592187643051147,1.8991457223892207,0.5181199908256531,2.118349313735962,50000 -628.9637792110443,0.6646013259887695,8865.62778878212,20286,0,8865.62778878212,0.4166000187397003,2.704303026199341,10000,9496.333614110948,0.5762304663658142,1.809704303741455,0.5256400108337402,2.042741775512696,50000 -664.3080334663391,0.694582462310791,9285.970098257065,21245,0,9285.970098257065,0.4236000180244446,2.6766977310180664,10000,9952.100473880768,0.60107421875,1.7186756134033203,0.5374400019645691,2.0195109844207764,50000 -701.2506122589111,0.7266604900360107,9706.001689434052,22200,0,9706.001689434052,0.4204000234603882,2.663136005401612,10000,10409.156054019928,0.5738281011581421,1.7905476093292236,0.5320799946784973,1.9965107440948489,50000 -734.7533724308014,0.7566776275634766,10126.287100076675,23162,0,10126.287100076675,0.4255000054836273,2.6321370601654053,10000,10863.024575471878,0.5878320336341858,1.741782546043396,0.5437399744987488,1.9589741230010984,50000 -766.6508240699768,0.7921888828277588,10546.216496706007,24122,0,10546.216496706007,0.425100028514862,2.689875602722168,10000,11314.93691921234,0.5895312428474426,1.806769847869873,0.5410400032997131,2.0305187702178955,50000 -798.524719953537,0.8267090320587158,10966.567569971085,25079,0,10966.567569971085,0.4388000071048736,2.607574462890625,10000,11767.246729373932,0.6007226705551147,1.691822528839111,0.5526800155639648,1.9392848014831543,50000 -832.1123118400574,0.8609256744384766,11386.746564149857,26027,0,11386.746564149857,0.4345000088214874,2.5865228176116943,10000,12221.09743642807,0.5960351228713989,1.6998733282089231,0.5539199709892273,1.9093241691589355,50000 -865.8403429985046,0.8936762809753418,11806.727120399475,26975,0,11806.727120399475,0.4431000351905823,2.551481008529663,10000,12674.887342214584,0.6092773079872131,1.6344122886657717,0.5594599843025208,1.863364815711975,50000 -900.530428647995,0.9298102855682372,12227.05313396454,27935,0,12227.05313396454,0.4442000091075897,2.544827938079834,10000,13129.989423274994,0.6161523461341858,1.614219307899475,0.5626599788665771,1.8807896375656128,50000 -934.0426671504974,0.9640679359436036,12647.381070375444,28890,0,12647.381070375444,0.445000022649765,2.535609483718872,10000,13583.913120269775,0.6077929735183716,1.6699862480163574,0.5683799982070923,1.8518006801605225,50000 -967.7215456962584,0.9999191761016846,13067.404882907867,29846,0,13067.404882907867,0.4506000280380249,2.4998276233673096,10000,14037.70155787468,0.6141015291213989,1.616439938545227,0.5706599950790405,1.8262709379196167,50000 -999.9251253604888,1.0317416191101074,13487.583720207214,30801,0,13487.583720207214,0.458400011062622,2.4593968391418457,10000,14490.164947509766,0.6283984184265137,1.5483167171478271,0.5783599615097046,1.7855757474899292,50000 -1033.848325252533,1.0704412460327148,13907.73057460785,31758,0,13907.73057460785,0.4554000198841095,2.479957342147827,10000,14944.324080467224,0.6522851586341858,1.4674758911132812,0.5711199641227722,1.8236415386199951,50000 -1067.429335355759,1.102602243423462,14327.77366900444,32715,0,14327.77366900444,0.4630000293254852,2.479883909225464,10000,15398.03229379654,0.6245507597923279,1.5911355018615725,0.5801599621772766,1.807516932487488,50000 -1097.7876312732697,1.136333703994751,14747.711050987244,33670,0,14747.711050987244,0.4677000343799591,2.4360461235046387,10000,15848.411272764206,0.6285351514816284,1.5775716304779053,0.5806999802589417,1.7970975637435913,50000 -1131.367756843567,1.1800377368927002,15167.88524198532,34622,0,15167.88524198532,0.4688000082969665,2.4291539192199707,10000,16302.25860452652,0.639355480670929,1.4995332956314087,0.5854399800300598,1.7577393054962158,50000 -1165.467206954956,1.21533465385437,15588.195026397703,35577,0,15588.195026397703,0.4697000086307525,2.40012788772583,10000,16756.752949476242,0.6327343583106995,1.5223309993743896,0.5879799723625183,1.7351510524749756,50000 -1198.2166216373444,1.2557530403137207,16008.226280927658,36534,0,16008.226280927658,0.4677000343799591,2.421023368835449,10000,17209.62438249588,0.6396093368530273,1.5190812349319458,0.5896199941635132,1.7429559230804443,50000 -1230.3696205615995,1.2913818359375,16428.334725379944,37483,0,16428.334725379944,0.4678000211715698,2.402662515640259,10000,17661.970279455185,0.6413280963897705,1.495647668838501,0.5907399654388428,1.7329176664352417,50000 -1263.7805182933807,1.3276479244232178,16848.304398536682,38436,0,16848.304398536682,0.4671000242233276,2.4115984439849854,10000,18115.436757087708,0.6577734351158142,1.416632056236267,0.5920599699020386,1.7245413064956665,50000 -1296.6276342868805,1.3599748611450195,17268.549030303955,39393,0,17268.549030303955,0.4754000306129455,2.362500667572021,10000,18568.60999751091,0.6365038752555847,1.513041615486145,0.5950599908828735,1.7173501253128052,50000 -1331.296977519989,1.3953208923339844,17688.56049466133,40347,0,17688.56049466133,0.483100026845932,2.3189759254455566,10000,19023.37623500824,0.6510156393051147,1.4378249645233154,0.6019399762153625,1.6753579378128052,50000 -1365.8039565086365,1.434863567352295,18108.91366672516,41301,0,18108.91366672516,0.4743000268936157,2.380539894104004,10000,19478.32561135292,0.6549023389816284,1.4472695589065552,0.6011599898338318,1.710437893867493,50000 -1399.6762397289276,1.4674532413482666,18529.00963950157,42256,0,18529.00963950157,0.4777000248432159,2.3627915382385254,10000,19932.3771443367,0.662109375,1.3703303337097168,0.6002799868583679,1.6656572818756104,50000 -1436.2877542972565,1.508568525314331,18949.35471701622,43214,0,18949.35471701622,0.4887000322341919,2.299259185791016,10000,20389.42469477653,0.6528710722923279,1.4166855812072754,0.6093400120735168,1.6414382457733154,50000 -1470.940866470337,1.5481336116790771,19369.56394290924,44170,0,19369.56394290924,0.4886000156402588,2.2939584255218506,10000,20844.375911474228,0.6615234017372131,1.3803268671035769,0.6105799674987793,1.6226335763931274,50000 -1505.269142627716,1.5860624313354492,19789.847430944443,45127,0,19789.847430944443,0.4873000085353851,2.2900307178497314,10000,21299.075829267505,0.6758593320846558,1.3159501552581787,0.6125400066375732,1.623989820480347,50000 -1537.1973378658297,1.619213342666626,20210.07796788216,46084,0,20210.07796788216,0.4923000335693359,2.3095061779022217,10000,21751.31724905968,0.6569140553474426,1.404916763305664,0.6086599826812744,1.6356945037841797,50000 -1572.094933271408,1.6627938747406006,20630.187987327576,47041,0,20630.187987327576,0.4798000156879425,2.3286588191986084,10000,22206.4185230732,0.6555468440055847,1.4303855895996094,0.6114799976348877,1.6497628688812256,50000 -1606.835849761963,1.6966009140014648,21050.396213054657,47996,0,21050.396213054657,0.4918000102043152,2.2892038822174072,10000,22661.451620817184,0.6664062142372131,1.364644169807434,0.611739993095398,1.6247352361679075,50000 -1641.1119446754456,1.742445945739746,21470.66002368927,48950,0,21470.66002368927,0.4933000206947326,2.256746530532837,10000,23116.08723402024,0.6929882764816284,1.2464438676834106,0.6179400086402893,1.586337924003601,50000 -1675.7305736541748,1.7791416645050049,21890.65754508972,49906,0,21890.65754508972,0.4933000206947326,2.2742130756378174,10000,23570.7903380394,0.6653515696525574,1.385420560836792,0.6169799566268921,1.6083929538726809,50000 -1710.644835472107,1.812294244766236,22310.686877965927,50862,0,22310.686877965927,0.4958000183105469,2.254197597503662,10000,24025.816939115524,0.6701366901397705,1.3385523557662964,0.6211999654769897,1.5810306072235107,50000 -1744.2933213710785,1.84682846069336,22730.87821960449,51818,0,22730.87821960449,0.5011000037193298,2.231066942214966,10000,24479.74113345146,0.6743554472923279,1.317443609237671,0.6154999732971191,1.582787036895752,50000 -1781.1761672496796,1.8918776512146,23150.86624217033,52770,0,23150.86624217033,0.5,2.2460379600524902,10000,24936.70695376396,0.6700195074081421,1.3372474908828735,0.6182599663734436,1.5742107629776,50000 -1815.730684518814,1.9351487159729004,23571.08939766884,53726,0,23571.08939766884,0.4960000216960907,2.242409467697144,10000,25391.57753252983,0.6646288633346558,1.3680269718170166,0.6178399920463562,1.5924180746078491,50000 -1848.4059002399445,1.9695231914520264,23991.25001358986,54684,0,23991.25001358986,0.4976000189781189,2.2368738651275635,10000,25844.49798321724,0.6791015267372131,1.3228133916854858,0.6257199645042419,1.5733473300933838,50000 -1881.3608667850488,2.004138708114624,24411.1894197464,55641,0,24411.1894197464,0.4954000115394592,2.2573320865631104,10000,26297.4784488678,0.6865624785423279,1.307192087173462,0.6205799579620361,1.607119917869568,50000 -1915.2504620552063,2.0507349967956543,24831.47842264176,56598,0,24831.47842264176,0.5113000273704529,2.1875674724578857,10000,26751.75327205658,0.6784765720367432,1.305625319480896,0.6305400133132935,1.534404158592224,50000 -1949.0071649551392,2.090480327606201,25251.627435684204,57554,0,25251.627435684204,0.5078999996185303,2.219440460205078,10000,27205.74866914749,0.6798437237739563,1.33681058883667,0.6278600096702576,1.564784049987793,50000 -1982.4562137126925,2.126615047454834,25671.693274259567,58510,0,25671.693274259567,0.5093000531196594,2.194056749343872,10000,27659.35000705719,0.6871093511581421,1.2710402011871338,0.6285799741744995,1.5473343133926392,50000 -2015.0270056724548,2.170564889907837,26091.714426994324,59466,0,26091.714426994324,0.5082000494003296,2.2486259937286377,10000,28112.035681962967,0.7090429663658142,1.2248718738555908,0.6277599930763245,1.591611385345459,50000 -2050.053850412369,2.206193447113037,26511.63784146309,60421,0,26511.63784146309,0.511900007724762,2.176940679550171,10000,28567.07047390937,0.6841210722923279,1.2804559469223022,0.632099986076355,1.5170131921768188,50000 -2085.340404987335,2.2476589679718018,26931.86723518372,61377,0,26931.86723518372,0.5074000358581543,2.163461208343506,10000,29022.67796754837,0.6870312094688416,1.2437814474105835,0.6377999782562256,1.4895076751708984,50000 -2118.1527137756348,2.290111780166626,27352.072308063507,62332,0,27352.072308063507,0.5136000514030457,2.16353702545166,10000,29475.78770780564,0.7010741829872131,1.2027630805969238,0.6353999972343445,1.4932037591934204,50000 -2153.5841794013977,2.325376510620117,27772.3682949543,63289,0,27772.3682949543,0.5095000267028809,2.2137796878814697,10000,29931.60078239441,0.6837499737739563,1.3085410594940186,0.6333799958229065,1.5339182615280151,50000 -2188.6370465755463,2.3665175437927246,28192.69956564904,64244,0,28192.69956564904,0.5075000524520874,2.188990592956543,10000,30387.0754737854,0.6842578053474426,1.3000692129135132,0.6389999985694885,1.5179568529129028,50000 -2223.919200897217,2.40610671043396,28613.02726626396,65201,0,28613.02726626396,0.5093000531196594,2.22403335571289,10000,30842.774602413177,0.6913281083106995,1.2925525903701782,0.6354599595069885,1.5564727783203125,50000 -2254.170946121216,2.441678762435913,29033.197748422623,66160,0,29033.197748422623,0.5128999948501587,2.1780552864074707,10000,31293.28220319748,0.7096288800239563,1.2008816003799438,0.6379199624061584,1.5290579795837402,50000 -2288.9244425296783,2.4865760803222656,29453.328754663467,67115,0,29453.328754663467,0.5189000368118286,2.1419031620025635,10000,31748.261180639267,0.6934179663658142,1.2447786331176758,0.6399399638175964,1.487149953842163,50000 -2324.6626167297363,2.5216784477233887,29873.27702856064,68073,0,29873.27702856064,0.522599995136261,2.106295347213745,10000,32204.031671524048,0.7007812261581421,1.195706486701965,0.6473599672317505,1.448695421218872,50000 -2360.7755336761475,2.5654780864715576,30293.492072343823,69028,0,30293.492072343823,0.5178000330924988,2.1673731803894043,10000,32660.45233750344,0.7033398151397705,1.235210657119751,0.6412999629974365,1.5109248161315918,50000 -2392.7489824295044,2.603192090988159,30713.74888730049,69982,0,30713.74888730049,0.5182000398635864,2.133967399597168,10000,33112.77065825462,0.7116405963897705,1.1737558841705322,0.6464200019836426,1.4845588207244873,50000 -2427.8326795101166,2.6537318229675293,31133.713386058807,70937,0,31133.713386058807,0.5223000049591064,2.1306920051574707,10000,33567.919130563736,0.69447261095047,1.2243953943252563,0.6431599855422974,1.4684062004089355,50000 -2463.0030977725983,2.694608211517334,31553.676849365234,71889,0,31553.676849365234,0.5151000022888184,2.1532177925109863,10000,34023.14230251312,0.697460949420929,1.2218685150146484,0.6411399841308594,1.4857412576675415,50000 -2497.761979341507,2.7404704093933105,31974.063121318817,72846,0,31974.063121318817,0.5228000283241272,2.113523483276367,10000,34478.38299822807,0.7125195264816284,1.142621397972107,0.650119960308075,1.44003164768219,50000 -2533.733570098877,2.7803633213043213,32394.13893151284,73804,0,32394.13893151284,0.527999997138977,2.113095760345459,10000,34934.52019953728,0.699999988079071,1.2046812772750854,0.6521199941635132,1.439921736717224,50000 -2569.016624450684,2.836053371429444,32814.26351618767,74763,0,32814.26351618767,0.5234000086784363,2.103261709213257,10000,35390.0334186554,0.7062109112739563,1.1783642768859863,0.6523199677467346,1.4315232038497925,50000 -2604.0515208244324,2.8736648559570312,33234.302689790726,75717,0,33234.302689790726,0.5235000252723694,2.121488332748413,10000,35845.19429016113,0.7077734470367432,1.1573503017425537,0.6542999744415283,1.419303297996521,50000 -2640.2735633850098,2.9117650985717773,33654.46897649765,76674,0,33654.46897649765,0.5303000211715698,2.0822157859802246,10000,36301.674439907074,0.7352538704872131,1.0566195249557495,0.6536200046539307,1.416456937789917,50000 -2675.3321437835693,2.9543845653533936,34074.84289550781,77634,0,34074.84289550781,0.5340999960899353,2.075775384902954,10000,36757.19959259033,0.7109179496765137,1.169638752937317,0.6612199544906616,1.404836654663086,50000 -2712.6424124240875,2.9909353256225586,34494.79498171806,78592,0,34494.79498171806,0.5379000306129456,2.034130811691284,10000,37214.54881882668,0.7156835794448853,1.1127052307128906,0.662559986114502,1.369977593421936,50000 -2747.687013626098,3.044427394866944,34914.722329854965,79548,0,34914.722329854965,0.5356000065803528,2.0531795024871826,10000,37669.63617515564,0.7233788967132568,1.1037510633468628,0.659280002117157,1.397465705871582,50000 -2783.3004961013794,3.0824363231658936,35334.804896593094,80505,0,35334.804896593094,0.5367000102996826,2.0481276512146,10000,38125.42041826248,0.7168163657188416,1.139781832695007,0.6624999642372131,1.3918650150299072,50000 -2819.2582035064697,3.1309750080108643,35754.91748666763,81462,0,35754.91748666763,0.526199996471405,2.07778549194336,10000,38581.58862805367,0.7132421731948853,1.1469886302947998,0.6568199992179871,1.4000229835510254,50000 -2853.3048980236053,3.1728460788726807,36174.90492129326,82421,0,36174.90492129326,0.5421000123023987,2.0170681476593018,10000,39035.7145614624,0.7280859351158142,1.0940496921539309,0.6650399565696716,1.375693678855896,50000 -2888.807193994522,3.2124273777008057,36595.001002788544,83380,0,36595.001002788544,0.5343000292778015,2.052915334701538,10000,39491.40196967125,0.7390820384025574,1.0498539209365845,0.6658599972724915,1.3794326782226562,50000 -2923.7270460128784,3.254138708114624,37015.265879392624,84339,0,37015.265879392624,0.5428000092506409,1.999638795852661,10000,39946.67766284943,0.7233007550239563,1.0985373258590698,0.668720006942749,1.348969340324402,50000 -2959.251959323883,3.2952709197998047,37435.61778759956,85298,0,37435.61778759956,0.5355000495910645,2.0667412281036377,10000,40402.64578437805,0.7191796898841858,1.1635682582855225,0.6645399928092957,1.4133321046829224,50000 -2990.723461866379,3.3366193771362305,37855.74568748474,86259,0,37855.74568748474,0.5408000349998474,2.054095506668091,10000,40854.33689212799,0.7269921898841858,1.1259955167770386,0.6650800108909607,1.396593451499939,50000 -3022.891412258148,3.387751340866089,38275.79717183113,87217,0,38275.79717183113,0.5439000129699707,1.98435378074646,10000,41306.657376527786,0.7546288967132568,0.9544236660003662,0.6703999638557434,1.32748544216156,50000 -3056.3972618579865,3.430220603942871,38695.72180867195,88173,0,38695.72180867195,0.5463000535964966,1.977535843849182,10000,41760.179690122604,0.7283398509025574,1.082810401916504,0.6726599931716919,1.337311029434204,50000 -3092.58779001236,3.471376419067383,39116.09435725212,89131,0,39116.09435725212,0.5505000352859497,1.975814938545227,10000,42216.833818912506,0.7352343797683716,1.056477427482605,0.671779990196228,1.3395153284072876,50000 -3134.03320145607,3.513042688369751,39536.358402490616,90091,0,39536.358402490616,0.5543000102043152,1.976184606552124,10000,42678.63556289673,0.7414257526397705,1.0378321409225464,0.6741200089454651,1.3346014022827148,50000 -3171.4239320755005,3.5524487495422363,39956.39900755882,91048,0,39956.39900755882,0.5467000007629395,1.9848815202713013,10000,43136.15661621094,0.7334179282188416,1.0601955652236938,0.6719399690628052,1.3371931314468384,50000 -3211.6039159297943,3.5961568355560303,40376.65997195244,92005,0,40376.65997195244,0.5501000285148621,2.014688491821289,10000,43596.69050884247,0.732128918170929,1.0990062952041626,0.6755799651145935,1.3497874736785889,50000 -3246.4359505176544,3.63786244392395,40796.6041162014,92962,0,40796.6041162014,0.5496000051498413,1.969154953956604,10000,44051.557891607285,0.7373241782188416,1.0257611274719238,0.6764799952507019,1.3098903894424438,50000 -3283.058206319809,3.680174589157105,41216.86024641991,93923,0,41216.86024641991,0.5472000241279602,1.9758864641189573,10000,44508.52853775024,0.74916011095047,0.9873992800712584,0.6765999794006348,1.3194611072540283,50000 -3320.365930557251,3.7214975357055664,41637.13806056976,94883,0,41637.13806056976,0.553600013256073,1.94719660282135,10000,44966.2047367096,0.740917980670929,1.0271331071853638,0.6797199845314026,1.3003922700881958,50000 -3363.7155849933624,3.76453709602356,42057.04630446434,95838,0,42057.04630446434,0.5525000095367432,1.9437371492385864,10000,45429.55492138863,0.7427148222923279,1.0067245960235596,0.6803999543190002,1.2884384393692017,50000 -3402.8420236110687,3.809993982315064,42476.98233127594,96794,0,42476.98233127594,0.5600000023841858,1.9353830814361568,10000,45888.71178412437,0.7479296922683716,0.9932562112808228,0.6821199655532837,1.2954468727111816,50000 -3444.739847421646,3.855117321014404,42897.07012176514,97752,0,42897.07012176514,0.5525000095367432,1.9546915292739868,10000,46350.79205417633,0.7708593606948853,0.9360008239746094,0.6833399534225464,1.3163715600967407,50000 -3479.8404698371887,3.900003433227539,43317.06185436249,98707,0,43317.06185436249,0.5577000379562378,1.93332839012146,10000,46805.97884583473,0.7409374713897705,1.0260651111602783,0.6859999895095825,1.2778408527374268,50000 -3516.056496620178,3.943291664123535,43737.05564260483,99664,0,43737.05564260483,0.5599000453948975,1.931175708770752,10000,47262.28113818169,0.7491015195846558,0.9865267872810364,0.6865599751472473,1.276195764541626,50000 -3551.346296310425,3.98929500579834,44157.2749619484,100623,0,44157.2749619484,0.5639000535011292,1.9013869762420648,10000,47717.88592863083,0.7603319883346558,0.9452508091926576,0.6875,1.2674585580825806,50000 -3587.5535452365875,4.040088415145874,44577.33863568306,101584,0,44577.33863568306,0.562000036239624,1.9139764308929443,10000,48174.25753450394,0.7491992115974426,1.0066653490066528,0.6896399855613708,1.2784479856491089,50000 -3619.7777602672577,4.091718912124634,44997.26434326172,102541,0,44997.26434326172,0.5687000155448914,1.890002012252808,10000,48626.50860714912,0.7575390338897705,0.9742467999458312,0.694920003414154,1.2482126951217651,50000 -3656.063153505325,4.135879993438721,45417.35946679115,103500,0,45417.35946679115,0.563800036907196,1.9133514165878296,10000,49082.983598947525,0.7555859088897705,0.9570018649101256,0.6918999552726746,1.2508394718170166,50000 -3692.076496124268,4.183363676071167,45837.52560162544,104460,0,45837.52560162544,0.5600000023841858,1.9212124347686768,10000,49539.26046657562,0.7704687118530273,0.9185051321983336,0.6934799551963806,1.2655454874038696,50000 -3727.893661260605,4.22884464263916,46257.73142623901,105421,0,46257.73142623901,0.5672000050544739,1.8987038135528564,10000,49995.38002538681,0.7509570121765137,0.986622154712677,0.6902799606323242,1.2532689571380615,50000 -3773.502552270889,4.281575441360474,46677.66419768333,106380,0,46677.66419768333,0.5681000351905823,1.871907353401184,10000,50461.02531313896,0.7627343535423279,0.9351568222045898,0.6969799995422363,1.2286380529403689,50000 -3812.488776922226,4.329556226730347,47097.86094522476,107336,0,47097.86094522476,0.5746000409126282,1.867994785308838,10000,50920.30652046204,0.7720312476158142,0.901677131652832,0.6996600031852722,1.2177472114562988,50000 -3856.903764724731,4.379816055297852,47518.158441782,108291,0,47518.158441782,0.5742000341415405,1.8423296213150024,10000,51385.118559122086,0.7859960794448853,0.8419576287269592,0.6972599625587463,1.2124354839324951,50000 -3901.020851612091,4.426162004470825,47938.24389958382,109246,0,47938.24389958382,0.5749000310897827,1.869747757911682,10000,51849.41720533371,0.7622460722923279,0.9467803835868835,0.6993599534034729,1.227556586265564,50000 -3942.307982444763,4.472722053527832,48358.40622162819,110201,0,48358.40622162819,0.5772000551223755,1.8481879234313965,10000,52310.96383190155,0.7684179544448853,0.9048970341682434,0.7038999795913696,1.196832299232483,50000 -3984.3732771873474,4.523653030395508,48778.63722419739,111158,0,48778.63722419739,0.5800999999046326,1.8331900835037231,10000,52773.36154890061,0.7777343392372131,0.8578104972839355,0.7034199833869934,1.197596788406372,50000 -4024.359180688858,4.575310945510864,49198.80200004578,112113,0,49198.80200004578,0.5842000246047974,1.8372535705566408,10000,53233.61412191391,0.7675976157188416,0.9133527278900146,0.7023599743843079,1.210939884185791,50000 -4064.391876220703,4.628418207168579,49619.058660030365,113068,0,49619.058660030365,0.5800000429153442,1.843191146850586,10000,53694.006588459015,0.76917964220047,0.9122159481048584,0.703819990158081,1.2008824348449707,50000 -4106.380417108536,4.684727907180786,50039.012590408325,114020,0,50039.012590408325,0.5845000147819519,1.8225212097167969,10000,54156.05561089516,0.77685546875,0.8710706830024719,0.707539975643158,1.1852104663848877,50000 -4142.604337930679,4.731132507324219,50459.272194862366,114978,0,50459.272194862366,0.5812000036239624,1.831921935081482,10000,54612.63543081284,0.7953710556030273,0.8017292022705078,0.7105000019073486,1.1759315729141235,50000 -4188.624214410782,4.783179759979248,50879.5531334877,115933,0,50879.5531334877,0.5839000344276428,1.8186765909194944,10000,55079.03768348694,0.775097668170929,0.878689169883728,0.7120400071144104,1.159434199333191,50000 -4231.638665437698,4.826200723648071,51299.98179316521,116889,0,51299.98179316521,0.5835000276565552,1.794126033782959,10000,55542.57398247719,0.7811132669448853,0.8491606712341309,0.712119996547699,1.1562445163726809,50000 -4269.591933727264,4.876907825469971,51720.19587230682,117847,0,51720.19587230682,0.5782000422477722,1.8094024658203125,10000,56000.84201049805,0.78822261095047,0.8326475620269775,0.710919976234436,1.1654090881347656,50000 -4309.85601067543,4.921182155609131,52140.52798819542,118805,0,52140.52798819542,0.5823000073432922,1.799047350883484,10000,56461.5328617096,0.7870702743530273,0.817166268825531,0.7132399678230286,1.1419477462768557,50000 -4349.591715335846,4.970265626907349,52560.54592466354,119764,0,52560.54592466354,0.5859000086784363,1.8054711818695068,10000,56921.385232687,0.7832226157188416,0.831150472164154,0.7125399708747864,1.1509326696395874,50000 -4390.941907405853,5.0248682498931885,52980.62607479096,120716,0,52980.62607479096,0.5871000289916992,1.7703475952148438,10000,57382.91956567764,0.7913867235183716,0.8026143312454224,0.7174599766731262,1.1281665563583374,50000 -4432.157784461975,5.077365398406982,53400.83889579773,121672,0,53400.83889579773,0.5902000069618225,1.784095048904419,10000,57844.45234084129,0.7996289134025574,0.7926232814788818,0.7188000082969666,1.1456674337387085,50000 -4470.588375091553,5.134720325469971,53820.74506020546,122625,0,53820.74506020546,0.5949000120162964,1.749130368232727,10000,58302.89676403999,0.7887304425239563,0.8185681104660034,0.7198799848556519,1.1269749402999878,50000 -4511.416751623154,5.179789066314697,54240.97389984131,123580,0,54240.97389984131,0.598300039768219,1.7739555835723877,10000,58764.0489218235,0.7926952838897705,0.8255926966667175,0.7219199538230896,1.1399753093719482,50000 -4548.689545869827,5.231285095214844,54660.88675928116,124534,0,54660.88675928116,0.5958999991416931,1.755924940109253,10000,59221.33545231819,0.7994140386581421,0.7840669751167297,0.7232399582862854,1.110971450805664,50000 -4587.305732250214,5.2786900997161865,55080.96838617325,125490,0,55080.96838617325,0.6070000529289246,1.718909502029419,10000,59680.13125133514,0.8218359351158142,0.6901920437812805,0.7267999649047852,1.1000694036483765,50000 -4631.069079637528,5.3379082679748535,55501.05866575241,126445,0,55501.05866575241,0.5904000401496887,1.757405400276184,10000,60144.09332036972,0.79296875,0.8103262782096863,0.7235199809074402,1.1214704513549805,50000 -4675.814074754715,5.386435270309448,55921.156155347824,127403,0,55921.156155347824,0.600100040435791,1.7393742799758911,10000,60609.03382444382,0.80140620470047,0.7748098373413086,0.7250999808311462,1.1140674352645874,50000 -4720.664792776108,5.441857099533081,56341.33557915688,128362,0,56341.33557915688,0.6046000123023987,1.7480357885360718,10000,61074.16904568672,0.808886706829071,0.7777705192565918,0.7276399731636047,1.1254874467849731,50000 -4766.389065265656,5.490319013595581,56761.44100022316,129319,0,56761.44100022316,0.6008000373840332,1.7313231229782104,10000,61540.09782385826,0.8020898103713989,0.7728776335716248,0.7276600003242493,1.092849850654602,50000 -4809.945647716522,5.539454460144043,57181.441122055054,130272,0,57181.441122055054,0.6047000288963318,1.703953981399536,10000,62003.752519369125,0.8078320026397705,0.7349570989608765,0.7328199744224548,1.0716028213500977,50000 -4852.868507862091,5.591620922088623,57601.39503288269,131223,0,57601.39503288269,0.6057000160217285,1.6870981454849243,10000,62466.73106837273,0.81068354845047,0.7343043088912964,0.734499990940094,1.0711177587509155,50000 -4892.319691181183,5.639818906784058,58021.39619445801,132177,0,58021.39619445801,0.6094000339508057,1.702414870262146,10000,62926.28131008148,0.8233007788658142,0.6953759789466858,0.7353000044822693,1.072637915611267,50000 -4937.602854728699,5.699962615966797,58441.57426691055,133130,0,58441.57426691055,0.6080000400543213,1.6945815086364746,10000,63391.85244774818,0.8082226514816284,0.7451169490814209,0.7351199984550476,1.071921944618225,50000 -4983.517159461975,5.74588418006897,58861.696504592896,134083,0,58861.696504592896,0.6075000166893005,1.6868488788604736,10000,63857.983344078064,0.8153125047683716,0.7177433371543884,0.7369799613952637,1.052171230316162,50000 -5028.039650678635,5.795482873916626,59281.61557817459,135039,0,59281.61557817459,0.6113000512123108,1.669729232788086,10000,64322.52414488792,0.8215429782867432,0.6732996702194214,0.7372199892997742,1.0496302843093872,50000 -5071.571031808853,5.841874361038208,59701.6204688549,135996,0,59701.6204688549,0.6158000230789185,1.6552666425704956,10000,64786.15665626526,0.8267382383346558,0.6621195673942566,0.7412199974060059,1.0311529636383057,50000 -5118.237744569778,5.8914642333984375,60121.56489992142,136949,0,60121.56489992142,0.6136000156402588,1.6724987030029297,10000,65252.8662545681,0.81689453125,0.7047690153121948,0.7401599884033203,1.0472062826156616,50000 -5155.756701231003,5.945829391479492,60541.76283168793,137905,0,60541.76283168793,0.6154000163078308,1.663904905319214,10000,65710.68705844879,0.8246093392372131,0.6600923538208008,0.7413600087165833,1.0288509130477903,50000 -5195.878809213638,5.996556520462036,60962.057092905045,138860,0,60962.057092905045,0.6155000329017639,1.6657668352127075,10000,66171.20323944092,0.8321484327316284,0.6528514623641968,0.7429599761962891,1.038433313369751,50000 -5238.357710123062,6.045234203338623,61382.15120720863,139816,0,61382.15120720863,0.6185000538825989,1.638699293136597,10000,66633.87408804893,0.8234570026397705,0.6648285388946533,0.7445200085639954,1.0157908201217651,50000 -5277.053035020828,6.096495151519775,61802.4379234314,140773,0,61802.4379234314,0.6203000545501709,1.6408082246780396,10000,67092.95656824112,0.8301171660423279,0.6384050846099854,0.7467399835586548,1.0131675004959106,50000 -5320.5501408576965,6.1448814868927,62222.53354215622,141729,0,62222.53354215622,0.6189000010490417,1.641674518585205,10000,67556.64731407166,0.8341015577316284,0.6274772882461548,0.7470999956130981,1.0052226781845093,50000 -5366.112824201584,6.20355749130249,62642.4606757164,142683,0,62642.4606757164,0.6233000159263611,1.6336963176727295,10000,68022.24512600899,0.8469530940055847,0.5803266167640686,0.749239981174469,0.993854284286499,50000 -5407.827532052994,6.254222393035889,63062.7713201046,143641,0,63062.7713201046,0.6241000294685364,1.6271873712539673,10000,68484.37150168419,0.8289452791213989,0.6445606350898743,0.7488399744033813,0.999646782875061,50000 -5449.443675994873,6.337775707244873,63482.79140949249,144596,0,63482.79140949249,0.6246000528335571,1.6112269163131714,10000,68946.13994860649,0.8374804258346558,0.6231352686882019,0.7511799931526184,0.9905893206596376,50000 -5496.952882766724,6.399603128433228,63903.018033504486,145551,0,63903.018033504486,0.6222000122070312,1.6284139156341553,10000,69413.98632764816,0.842578113079071,0.6011013984680176,0.7515400052070618,0.997151792049408,50000 -5539.651907920837,6.44763708114624,64323.128088235855,146507,0,64323.128088235855,0.6258000135421753,1.6044663190841677,10000,69876.89319229126,0.8387304544448853,0.6057491898536682,0.7545599937438965,0.975864589214325,50000 -5576.707340240479,6.506728172302246,64743.03775429726,147462,0,64743.03775429726,0.6244000196456909,1.6166449785232544,10000,70333.96669006348,0.8384374976158142,0.630463719367981,0.7534799575805664,0.9893783330917358,50000 -5618.667296886444,6.567543268203735,65163.15673828125,148421,0,65163.15673828125,0.6320000290870667,1.5996073484420776,10000,70796.16014623642,0.8456249833106995,0.5935382843017578,0.7538999915122986,0.978970229625702,50000 -5660.120743513107,6.616364240646362,65583.5103931427,149378,0,65583.5103931427,0.6299000382423401,1.5885131359100342,10000,71258.06540131569,0.8524218797683716,0.5529924035072327,0.7565799951553345,0.9650366306304932,50000 -5707.843321084976,6.665755748748779,66003.76546001434,150334,0,66003.76546001434,0.6344000101089478,1.5788832902908323,10000,71726.14133667946,0.8441210985183716,0.5873963832855225,0.7583999633789062,0.9608394503593444,50000 -5753.7766098976135,6.717709302902222,66424.10148477554,151290,0,66424.10148477554,0.6296000480651855,1.575633525848389,10000,72192.51217556,0.8487499952316284,0.5678128004074097,0.7597799897193909,0.949147641658783,50000 -5793.150819063187,6.770383834838867,66844.15052103996,152246,0,66844.15052103996,0.6332000494003296,1.5770167112350464,10000,72652.03750824928,0.8516015410423279,0.5689120888710022,0.7590599656105042,0.9657217860221864,50000 -5834.262246847153,6.820079565048218,67264.09875321388,153200,0,67264.09875321388,0.638700008392334,1.55450701713562,10000,73113.19673418999,0.8597851395606995,0.5119295120239258,0.7633399963378906,0.9361425638198853,50000 -5875.818645477295,6.8744797706604,67684.29598784447,154154,0,67684.29598784447,0.6413000226020813,1.5551538467407229,10000,73575.05439019203,0.8531835675239563,0.5538952350616455,0.7621200084686279,0.944446623325348,50000 -5919.092953443527,6.933387756347656,68104.2943456173,155112,0,68104.2943456173,0.6392000317573547,1.5500725507736206,10000,74038.43568396568,0.8586328029632568,0.5334701538085938,0.7653200030326843,0.94097101688385,50000 -5961.080224990845,6.984781265258789,68524.47574973106,156069,0,68524.47574973106,0.6394000053405762,1.5504746437072754,10000,74500.70484685898,0.8615820407867432,0.5220953822135925,0.7646399736404419,0.9310302138328552,50000 -6001.720447778702,7.037832736968994,68944.65562939644,157026,0,68944.65562939644,0.6401000022888184,1.5474493503570557,10000,74961.62737870216,0.8561913967132568,0.5413217544555664,0.7650399804115295,0.9309642910957336,50000 -6041.917126655579,7.940981864929199,69364.08955550194,157979,0,69364.08955550194,0.6416000127792358,1.548611760139465,10000,75422.21031832695,0.8600585460662842,0.5279272794723511,0.7641599774360657,0.9306029677391052,50000 -6085.347447156906,8.00169324874878,69784.15007662773,158936,0,69784.15007662773,0.6457000374794006,1.527753233909607,10000,75885.81394910812,0.8643164038658142,0.5127362608909607,0.768619954586029,0.924987494945526,50000 -6127.854299068451,8.064128637313843,70204.42936730385,159891,0,70204.42936730385,0.6384000182151794,1.5375657081604004,10000,76348.71235513687,0.8702148199081421,0.484088271856308,0.7680000066757202,0.9140318632125854,50000 -6175.40148806572,8.118890523910522,70624.54157876968,160847,0,70624.54157876968,0.64410001039505,1.5222582817077637,10000,76816.47623872757,0.8659570217132568,0.5069062113761902,0.7679399847984314,0.9168761372566224,50000 -6221.217273712158,8.172521114349365,71044.70680117607,161801,0,71044.70680117607,0.6473000049591064,1.5221868753433228,10000,77282.55980610847,0.8662499785423279,0.4994567334651947,0.7704600095748901,0.9088558554649352,50000 -6265.452960968018,8.222732543945312,71464.65548014641,162757,0,71464.65548014641,0.6461000442504883,1.5066665410995483,10000,77746.84324288368,0.8711718320846558,0.479285329580307,0.7711199522018433,0.9005151987075806,50000 -6306.055119752884,8.277120113372803,71884.6285545826,163711,0,71884.6285545826,0.6484000086784363,1.51751446723938,10000,78207.52199673653,0.87451171875,0.4802699983119964,0.7723999619483948,0.9044002890586852,50000 -6347.265980482101,8.328987121582031,72304.52912330627,164667,0,72304.52912330627,0.6477000117301941,1.507075309753418,10000,78668.73527359962,0.8724218606948853,0.4697950780391693,0.774679958820343,0.8892613053321838,50000 -6389.079388380051,8.3844153881073,72724.6625881195,165623,0,72724.6625881195,0.650600016117096,1.5049769878387451,10000,79130.78731393814,0.8748242259025574,0.4756855964660644,0.7743600010871887,0.8956592679023743,50000 -6432.890777826309,8.447266101837158,73144.80176186562,166577,0,73144.80176186562,0.6516000032424927,1.505468726158142,10000,79594.84973239899,0.8760351538658142,0.4655555188655853,0.7749999761581421,0.8960703611373901,50000 -6476.155483722687,8.500588655471802,73564.99614834785,167536,0,73564.99614834785,0.6520000100135803,1.4951791763305664,10000,80058.41223526001,0.8737109303474426,0.4756890833377838,0.7744999527931213,0.8931192755699158,50000 -6523.7312026023865,8.55799913406372,73985.04015851021,168492,0,73985.04015851021,0.653700053691864,1.4894458055496216,10000,80526.1391685009,0.8785156011581421,0.4542920589447021,0.776479959487915,0.885757565498352,50000 -6563.126978397369,8.622390508651733,74405.21560502052,169404,0,74405.21560502052,0.6568000316619873,1.4942373037338257,10000,80985.82221055031,0.8810546398162842,0.4524443447589874,0.7756399512290955,0.8834952712059021,50000 -6606.119580030441,8.677473068237305,74825.23263335228,170356,0,74825.23263335228,0.6581000089645386,1.4730889797210691,10000,81448.93747091293,0.8832812309265137,0.4339889287948608,0.7788199782371521,0.870005190372467,50000 -6645.805203676224,8.72992992401123,75245.5376894474,171311,0,75245.5376894474,0.6554000377655029,1.4801712036132812,10000,81909.02955842018,0.8805468678474426,0.4478434324264526,0.7775999903678894,0.8766506314277649,50000 -6688.530457019806,8.78211498260498,75665.74482417107,172266,0,75665.74482417107,0.6565000414848328,1.4745608568191528,10000,82372.0632212162,0.8803319931030273,0.4448411166667938,0.7797999978065491,0.8689996600151062,50000 -6731.652981758118,8.847268104553223,76085.70857977867,173218,0,76085.70857977867,0.656000018119812,1.47455096244812,10000,82835.26369142532,0.8841796517372131,0.430279940366745,0.7793799638748169,0.8702945709228516,50000 -6772.623242139816,8.900946855545044,76505.80146336555,174174,0,76505.80146336555,0.6546000242233276,1.47637140750885,10000,83296.43021583557,0.8836718797683716,0.4322193562984466,0.7788800001144409,0.8668044805526733,50000 -6815.133854389191,8.958553314208984,76925.99597835541,175131,0,76925.99597835541,0.6604000329971313,1.4671146869659424,10000,83759.2424018383,0.8831640481948853,0.4393573999404907,0.781059980392456,0.8640338778495789,50000 -6858.408992290497,9.01213812828064,77346.0910449028,176084,0,77346.0910449028,0.6589000225067139,1.4714841842651367,10000,84222.7153236866,0.8835546970367432,0.4316456019878387,0.7802599668502808,0.8601840138435364,50000 -6900.202219963074,9.065497159957886,77766.39183497429,177041,0,77766.39183497429,0.6598000526428223,1.4636272192001345,10000,84684.91393399239,0.8862109184265137,0.4194622039794922,0.780739963054657,0.8591325283050537,50000 -6940.258265972137,9.121346712112429,78186.408213377,177997,0,78186.408213377,0.6583000421524048,1.462801814079285,10000,85145.09214758873,0.8862695097923279,0.4204770624637604,0.7815799713134766,0.8558434247970581,50000 -6989.28365111351,9.188668251037598,78606.38973283768,178953,0,78606.38973283768,0.6598000526428223,1.461968183517456,10000,85614.21636533737,0.8860937356948853,0.4298219084739685,0.7818399667739868,0.8621166944503784,50000 -7031.9412133693695,9.244869709014893,79026.4203722477,179911,0,79026.4203722477,0.6611000299453735,1.4549766778945925,10000,86077.0102751255,0.8874218463897705,0.4167936742305755,0.7824999690055847,0.8534600734710693,50000 -7077.8280510902405,9.303065538406372,79446.53990650177,180869,0,79446.53990650177,0.661300003528595,1.453679442405701,10000,86543.125269413,0.8839648365974426,0.4304944276809692,0.7829399704933167,0.8545869588851929,50000 -7123.566502571106,9.364516973495483,79866.53466320038,181824,0,79866.53466320038,0.6606000065803528,1.4558229446411133,10000,87008.9695174694,0.885546863079071,0.4214950501918793,0.782480001449585,0.8549716472625732,50000 -7172.918921947479,9.42094874382019,80286.43485283852,182779,0,80286.43485283852,0.660800039768219,1.4561413526535034,10000,87478.32796931267,0.8897070288658142,0.411954402923584,0.7824400067329407,0.8548983931541443,50000 -7213.88618850708,9.477174997329712,80706.45871520042,183735,0,80706.45871520042,0.6620000600814819,1.4533320665359497,10000,87939.42469787598,0.8878905773162842,0.4117888808250427,0.782759964466095,0.852722704410553,50000 -7260.604390859604,9.539127111434937,81126.73969316483,184689,0,81126.73969316483,0.6610000133514404,1.4538012742996216,10000,88406.53539347649,0.8874609470367432,0.4206721782684326,0.7829999923706055,0.8523223400115967,50000 -7300.4049389362335,9.598478078842165,81546.66754245758,185645,0,81546.66754245758,0.6612000465393066,1.4538347721099854,10000,88866.37303447723,0.8898437023162842,0.4111132025718689,0.7829999923706055,0.8524206280708313,50000 -7344.524179458618,9.663646221160889,81966.87752747536,186599,0,81966.87752747536,0.6614000201225281,1.4538928270339966,10000,89330.82261490822,0.8867577910423279,0.4253540337085724,0.7829399704933167,0.8525302410125732,50000 -7391.805383205414,9.732877016067505,82386.98583173752,187550,0,82386.98583173752,0.6614000201225281,1.453892707824707,10000,89798.33035802841,0.8905078172683716,0.4123665988445282,0.7829399704933167,0.8525302410125732,50000 -7430.450542926788,9.789993047714232,82807.08119153976,188507,0,82807.08119153976,0.6614000201225281,1.453892707824707,10000,90257.17705130576,0.8867382407188416,0.4162900745868683,0.7829399704933167,0.8525302410125732,50000 -7475.613933086395,9.851112365722656,83227.10395240784,189465,0,83227.10395240784,0.6614000201225281,1.453892707824707,10000,90722.47425937653,0.8881444931030273,0.415896475315094,0.7829399704933167,0.8525302410125732,50000 -7517.003622770309,9.918134927749634,83646.99283194542,190421,0,83646.99283194542,0.6614000201225281,1.453892707824707,10000,91183.8697359562,0.8898632526397705,0.4129151701927185,0.7829399704933167,0.8525302410125732,50000 -7557.5557742118835,9.97598123550415,84066.93710017204,191376,0,84066.93710017204,0.6614000201225281,1.453892707824707,10000,91644.4737186432,0.8852733969688416,0.4186182618141174,0.7829399704933167,0.8525302410125732,50000 -7598.079861402512,10.03218388557434,84486.87456440926,192333,0,84486.87456440926,0.6614000201225281,1.453892707824707,10000,92105.04097628592,0.8894140720367432,0.415103018283844,0.7829399704933167,0.8525302410125732,50000 -7636.415265798569,10.100399255752563,84906.87883043289,193288,0,84906.87883043289,0.6614000201225281,1.453892707824707,10000,92563.49833774568,0.8891796469688416,0.4154730439186096,0.7829399704933167,0.8525302410125732,50000 -7685.801539182663,10.174046993255615,85327.06767630577,194241,0,85327.06767630577,0.6614000201225281,1.453892707824707,10000,93033.19660949708,0.8891015648841858,0.4140681624412536,0.7829399704933167,0.8525302410125732,50000 -7729.244254589081,10.232911586761476,85746.978110075,195200,0,85746.978110075,0.6614000201225281,1.453892707824707,10000,93496.65905618668,0.8866796493530273,0.4165627956390381,0.7829399704933167,0.8525302410125732,50000 -7771.67853140831,10.292376279830933,86167.28060626984,196157,0,86167.28060626984,0.6614000201225281,1.453892707824707,10000,93959.505079031,0.8865429759025574,0.425489604473114,0.7829399704933167,0.8525302410125732,50000 -7809.706657886505,10.347052335739136,86587.31809902191,197109,0,86587.31809902191,0.6614000201225281,1.453892707824707,10000,94417.6742913723,0.8899609446525574,0.4093320369720459,0.7829399704933167,0.8525302410125732,50000 -7848.151168823242,10.403776168823242,87007.37144398689,198061,0,87007.37144398689,0.6614000201225281,1.453892707824707,10000,94876.27783370018,0.8864452838897705,0.4223592579364776,0.7829399704933167,0.8525302410125732,50000 -7891.088857412338,10.479464530944824,87427.40733599663,199014,0,87427.40733599663,0.6614000201225281,1.453892707824707,10000,95339.37646174432,0.8879101276397705,0.4180936813354492,0.7829399704933167,0.8525302410125732,50000 -7936.020844221115,10.547487020492554,87847.40805268288,199970,0,87847.40805268288,0.6614000201225281,1.453892707824707,10000,95804.42789506912,0.8871484398841858,0.4199837148189544,0.7829399704933167,0.8525302410125732,50000 -7972.082072734833,10.60279655456543,88267.30465459824,200927,0,88267.30465459824,0.6614000201225281,1.453892707824707,10000,96260.49082493782,0.8883788585662842,0.4186073839664459,0.7829399704933167,0.8525302410125732,50000 -8011.646758794785,11.446199417114258,88686.67367577553,201881,0,88686.67367577553,0.6614000201225281,1.453892707824707,10000,96720.3174176216,0.8891015648841858,0.4090310931205749,0.7829399704933167,0.8525302410125732,50000 -8049.51655459404,11.518182754516602,89106.6478202343,202837,0,89106.6478202343,0.6614000201225281,1.453892707824707,10000,97178.28229618073,0.88636714220047,0.4229801595211029,0.7829399704933167,0.8525302410125732,50000 -8087.728713512421,11.590760707855225,89526.93627262115,203791,0,89526.93627262115,0.6614000201225281,1.453892707824707,10000,97636.9053592682,0.888476550579071,0.4164254367351532,0.7829399704933167,0.8525302410125732,50000 -8123.699931621551,11.662600040435793,89946.85517311096,204745,0,89946.85517311096,0.6614000201225281,1.453892707824707,10000,98092.91681575777,0.8869921565055847,0.4216805696487427,0.7829399704933167,0.8525302410125732,50000 -8164.820489406586,11.73024082183838,90367.04581069946,205697,0,90367.04581069946,0.6614000201225281,1.453892707824707,10000,98554.3462498188,0.8867382407188416,0.4210879802703857,0.7829399704933167,0.8525302410125732,50000 -8197.201642036438,11.799509048461914,90787.44080162048,206657,0,90787.44080162048,0.6614000201225281,1.453892707824707,10000,99007.24154257774,0.8891015648841858,0.4129902124404907,0.7829399704933167,0.8525302410125732,50000 -8231.754022359848,11.868899822235107,91207.7722082138,207611,0,91207.7722082138,0.6614000201225281,1.453892707824707,10000,99462.24376773834,0.88818359375,0.4146837592124939,0.7829399704933167,0.8525302410125732,50000 -8269.022126674652,11.942992687225342,91627.7535970211,208563,0,91627.7535970211,0.6614000201225281,1.453892707824707,10000,99919.61623358728,0.8888476490974426,0.4169521927833557,0.7829399704933167,0.8525302410125732,50000 -8315.414091348648,12.020440578460692,92047.78227806091,209518,0,92047.78227806091,0.6614000201225281,1.453892707824707,10000,100386.16372036934,0.887011706829071,0.4165227115154266,0.7829399704933167,0.8525302410125732,50000 -8352.31871843338,12.077834129333496,92468.02927017212,210479,0,92468.02927017212,0.6614000201225281,1.453892707824707,10000,100843.42229747772,0.8866015672683716,0.4254970550537109,0.7829399704933167,0.8525302410125732,50000 -8387.379422187805,12.153021335601808,92887.907848835,211435,0,92887.907848835,0.6614000201225281,1.453892707824707,10000,101298.48625206947,0.8909765481948853,0.4111464917659759,0.7829399704933167,0.8525302410125732,50000 -8422.15352678299,12.22426986694336,93307.9272289276,212387,0,93307.9272289276,0.6614000201225281,1.453892707824707,10000,101753.40010499954,0.8878515362739563,0.4138557016849518,0.7829399704933167,0.8525302410125732,50000 -8468.450547218323,12.29885721206665,93728.0548722744,213342,0,93728.0548722744,0.6614000201225281,1.453892707824707,10000,102219.9491224289,0.8890234231948853,0.4153840839862823,0.7829399704933167,0.8525302410125732,50000 -8503.7742228508,12.362093687057495,94148.00626373292,214300,0,94148.00626373292,0.6614000201225281,1.453892707824707,10000,102675.33705687524,0.8884375095367432,0.4121640026569366,0.7829399704933167,0.8525302410125732,50000 -8540.573637008667,12.436352014541626,94567.9120604992,215256,0,94567.9120604992,0.6614000201225281,1.453892707824707,10000,103132.16620612144,0.8880078196525574,0.4170250296592712,0.7829399704933167,0.8525302410125732,50000 -8586.329691886902,12.511003971099854,94987.81290459631,216205,0,94987.81290459631,0.6614000201225281,1.453892707824707,10000,103597.94665384293,0.8862890601158142,0.4187511801719665,0.7829399704933167,0.8525302410125732,50000 -8617.764047384262,12.574751377105711,95407.8063764572,217161,0,95407.8063764572,0.6614000201225281,1.453892707824707,10000,104049.49093866348,0.8904101252555847,0.4095988273620605,0.7829399704933167,0.8525302410125732,50000 -8665.94267654419,12.649450063705444,95827.81059122086,218114,0,95827.81059122086,0.6614000201225281,1.453892707824707,10000,104517.7988626957,0.888476550579071,0.4177097082138061,0.7829399704933167,0.8525302410125732,50000 -8708.01383304596,12.715829849243164,96248.04551506042,219073,0,96248.04551506042,0.6614000201225281,1.453892707824707,10000,104980.22134375572,0.8863476514816284,0.4216277599334717,0.7829399704933167,0.8525302410125732,50000 -8746.991399526596,12.78156876564026,96667.96599316596,220029,0,96667.96599316596,0.6614000201225281,1.453892707824707,10000,105439.2344288826,0.8868749737739563,0.4227591156959533,0.7829399704933167,0.8525302410125732,50000 -8784.837461471558,12.851918935775757,97088.07246875764,220985,0,97088.07246875764,0.6614000201225281,1.453892707824707,10000,105897.30706262589,0.8891015648841858,0.4091844856739044,0.7829399704933167,0.8525302410125732,50000 -8828.260638713837,12.930778980255129,97508.19736647606,221939,0,97508.19736647606,0.6614000201225281,1.453892707824707,10000,106360.98352241516,0.8893359303474426,0.4144483804702759,0.7829399704933167,0.8525302410125732,50000 -8864.390403270721,12.99502658843994,97928.41378474236,222897,0,97928.41378474236,0.6614000201225281,1.453892707824707,10000,106817.44368696211,0.8866796493530273,0.4247627854347229,0.7829399704933167,0.8525302410125732,50000 -8898.96737408638,13.071343421936035,98348.38157367706,223846,0,98348.38157367706,0.6614000201225281,1.453892707824707,10000,107272.11375570296,0.8874804377555847,0.4221131503582001,0.7829399704933167,0.8525302410125732,50000 -8946.43201994896,13.144215106964111,98768.6687772274,224799,0,98768.6687772274,0.6614000201225281,1.453892707824707,10000,107739.98773026466,0.8865624666213989,0.4165645837783813,0.7829399704933167,0.8525302410125732,50000 -8978.581200838089,13.209935903549194,99188.98446559906,225757,0,99188.98446559906,0.6614000201225281,1.453892707824707,10000,108192.5675497055,0.8892187476158142,0.4134066998958587,0.7829399704933167,0.8525302410125732,50000 -9029.657244682312,13.288574695587158,99608.9012246132,226713,0,99608.9012246132,0.6614000201225281,1.453892707824707,10000,108663.69064497948,0.8882030844688416,0.4167915880680084,0.7829399704933167,0.8525302410125732,50000 -9075.538096427916,13.347972631454468,100029.07903766632,227672,0,100029.07903766632,0.6614000201225281,1.453892707824707,10000,109129.8581469059,0.8873828053474426,0.4202496111392975,0.7829399704933167,0.8525302410125732,50000 -9109.95995593071,13.414387702941896,100449.26325941086,228631,0,100449.26325941086,0.6614000201225281,1.453892707824707,10000,109584.58092451096,0.8864648342132568,0.4214226305484772,0.7829399704933167,0.8525302410125732,50000 -9148.38451385498,13.495585680007936,100869.45522499084,229584,0,100869.45522499084,0.6614000201225281,1.453892707824707,10000,110043.32839989662,0.8885351419448853,0.4174453616142273,0.7829399704933167,0.8525302410125732,50000 -9183.717219114304,13.598756313323976,101289.73925161362,230534,0,101289.73925161362,0.6614000201225281,1.453892707824707,10000,110499.09790325163,0.8880273103713989,0.4105813503265381,0.7829399704933167,0.8525302410125732,50000 -9222.439522743223,13.67095160484314,101709.70796656609,231490,0,101709.70796656609,0.6614000201225281,1.453892707824707,10000,110957.91074514388,0.887499988079071,0.4188116490840912,0.7829399704933167,0.8525302410125732,50000 -9260.00146317482,13.75056290626526,102129.62954187392,232446,0,102129.62954187392,0.6614000201225281,1.453892707824707,10000,111415.52277565002,0.8887304663658142,0.4167661666870117,0.7829399704933167,0.8525302410125732,50000 -9300.13559770584,13.828028678894045,102549.59725284576,233400,0,102549.59725284576,0.6614000201225281,1.453892707824707,10000,111875.75222849846,0.887499988079071,0.4170728921890259,0.7829399704933167,0.8525302410125732,50000 -9340.741424560549,13.89189600944519,102969.69378328323,234360,0,102969.69378328323,0.6614000201225281,1.453892707824707,10000,112336.56853604317,0.8882030844688416,0.4220650792121887,0.7829399704933167,0.8525302410125732,50000 -9380.29682302475,13.971137523651125,103389.81013679504,235318,0,103389.81013679504,0.6614000201225281,1.453892707824707,10000,112796.37017011642,0.88818359375,0.4170994460582733,0.7829399704933167,0.8525302410125732,50000 -9428.746124267578,14.04552173614502,103809.80734848976,236272,0,103809.80734848976,0.6614000201225281,1.453892707824707,10000,113264.9406042099,0.888671875,0.4124659597873688,0.7829399704933167,0.8525302410125732,50000 -9468.708633422852,14.108554124832152,104230.0831682682,237229,0,104230.0831682682,0.6614000201225281,1.453892707824707,10000,113725.29149913788,0.8904492259025574,0.4119705855846405,0.7829399704933167,0.8525302410125732,50000 -9516.149178743362,14.232004404067991,104650.29148387907,238186,0,104650.29148387907,0.6614000201225281,1.453892707824707,10000,114193.11408925056,0.8873632550239563,0.4150916635990143,0.7829399704933167,0.8525302410125732,50000 -9553.618543624878,14.30665922164917,105070.2694966793,239143,0,105070.2694966793,0.6614000201225281,1.453892707824707,10000,114650.68592834473,0.8875585794448853,0.4165133833885193,0.7829399704933167,0.8525302410125732,50000 -9590.587503671646,14.38660478591919,105490.25032567978,240097,0,105490.25032567978,0.6614000201225281,1.453892707824707,10000,115107.77051401138,0.8878905773162842,0.4163110256195068,0.7829399704933167,0.8525302410125732,50000 -9629.101864814758,14.51517915725708,105910.28881311417,241051,0,105910.28881311417,0.6614000201225281,1.453892707824707,10000,115566.50127744676,0.8893749713897705,0.416501522064209,0.7829399704933167,0.8525302410125732,50000 -9668.963171720505,14.59090256690979,106330.19357085228,242003,0,106330.19357085228,0.6614000201225281,1.453892707824707,10000,116026.39215540886,0.8878124952316284,0.4150179922580719,0.7829399704933167,0.8525302410125732,50000 -9705.530426979063,14.666574001312256,106750.3352985382,242957,0,106750.3352985382,0.6614000201225281,1.453892707824707,10000,116483.22671723366,0.8879687190055847,0.4157115519046783,0.7829399704933167,0.8525302410125732,50000 -9752.409542560576,14.741652011871338,107170.53363466264,243911,0,107170.53363466264,0.6614000201225281,1.453892707824707,10000,116950.42909789084,0.8875976204872131,0.4220549166202545,0.7829399704933167,0.8525302410125732,50000 -9795.938479661942,14.807200908660889,107590.54052233696,244867,0,107590.54052233696,0.6614000201225281,1.453892707824707,10000,117414.08001947404,0.8887304663658142,0.4134644269943237,0.7829399704933167,0.8525302410125732,50000 -9835.25951552391,14.880392789840698,108010.750831604,245826,0,108010.750831604,0.6614000201225281,1.453892707824707,10000,117873.7346212864,0.8887304663658142,0.4132008254528045,0.7829399704933167,0.8525302410125732,50000 -9881.516390562056,14.945569515228271,108430.8633108139,246779,0,108430.8633108139,0.6614000201225281,1.453892707824707,10000,118340.21789956091,0.88685542345047,0.4224175214767456,0.7829399704933167,0.8525302410125732,50000 -9922.546244859695,15.019672870635986,108850.87331867218,247734,0,108850.87331867218,0.6614000201225281,1.453892707824707,10000,118801.3809273243,0.8893359303474426,0.4157688319683075,0.7829399704933167,0.8525302410125732,50000 -9961.588781833649,15.095571517944336,109271.09632349014,248690,0,109271.09632349014,0.6614000201225281,1.453892707824707,10000,119260.77212262154,0.8864257335662842,0.4195002615451813,0.7829399704933167,0.8525302410125732,50000 -9998.935094356537,15.17807126045227,109691.39629292488,249643,0,109691.39629292488,0.6614000201225281,1.453892707824707,10000,119718.55010271072,0.8881640434265137,0.4143919646739959,0.7829399704933167,0.8525302410125732,50000 -10037.054334640505,15.254687309265137,110111.48210144044,250598,0,110111.48210144044,0.6614000201225281,1.453892707824707,10000,120176.88127589226,0.8877539038658142,0.4176174700260162,0.7829399704933167,0.8525302410125732,50000 -10077.62155175209,15.338451147079468,110531.47845578194,251552,0,110531.47845578194,0.6614000201225281,1.453892707824707,10000,120637.57777690887,0.8870507478713989,0.4220073521137237,0.7829399704933167,0.8525302410125732,50000 -10119.577965021132,15.408671617507936,110951.68130636217,252511,0,110951.68130636217,0.6614000201225281,1.453892707824707,10000,121099.85851335526,0.8853319883346558,0.4239788651466369,0.7829399704933167,0.8525302410125732,50000 -10161.588672876358,15.490390062332152,111371.58928275108,253469,0,111371.58928275108,0.6614000201225281,1.453892707824707,10000,121561.9081993103,0.88671875,0.4214539229869842,0.7829399704933167,0.8525302410125732,50000 -10199.127958536148,15.558288812637327,111791.53270983696,254423,0,111791.53270983696,0.6614000201225281,1.453892707824707,10000,122019.50779938698,0.8893554210662842,0.4129403531551361,0.7829399704933167,0.8525302410125732,50000 -10248.407604455948,15.63877296447754,112211.63786792757,255378,0,112211.63786792757,0.6614000201225281,1.453892707824707,10000,122489.02125787736,0.8873632550239563,0.4160084426403045,0.7829399704933167,0.8525302410125732,50000 -10282.272267341614,15.70403218269348,112631.70882320404,256336,0,112631.70882320404,0.6614000201225281,1.453892707824707,10000,122943.07149362564,0.8903515338897705,0.4116775095462799,0.7829399704933167,0.8525302410125732,50000 -10332.705268383026,15.78389596939087,113052.01241707802,257290,0,113052.01241707802,0.6614000201225281,1.453892707824707,10000,123413.9364683628,0.8877929449081421,0.4188413619995117,0.7829399704933167,0.8525302410125732,50000 -10373.404185056686,15.850019216537476,113471.94405174255,258246,0,113471.94405174255,0.6614000201225281,1.453892707824707,10000,123874.68347358704,0.8865820169448853,0.4241874516010284,0.7829399704933167,0.8525302410125732,50000 -10412.161917448044,15.914101839065552,113892.2898645401,259201,0,113892.2898645401,0.6614000201225281,1.453892707824707,10000,124333.91425085068,0.8890624642372131,0.415942519903183,0.7829399704933167,0.8525302410125732,50000 -10454.48885679245,16.005540370941162,114312.50053668022,260147,0,114312.50053668022,0.6614000201225281,1.453892707824707,10000,124796.59203076364,0.8899999856948853,0.407744437456131,0.7829399704933167,0.8525302410125732,50000 -10492.80323791504,16.09019136428833,114732.67173314096,261098,0,114732.67173314096,0.6614000201225281,1.453892707824707,10000,125255.21079921722,0.8884179592132568,0.4153020083904266,0.7829399704933167,0.8525302410125732,50000 -10529.83626651764,16.163026809692383,115152.60396766664,262052,0,115152.60396766664,0.6614000201225281,1.453892707824707,10000,125712.29814648628,0.8883398175239563,0.4166604876518249,0.7829399704933167,0.8525302410125732,50000 -10577.499941587448,16.24975323677063,115572.77207779884,263006,0,115572.77207779884,0.6614000201225281,1.453892707824707,10000,126180.26666045187,0.8867968320846558,0.4171631932258606,0.7829399704933167,0.8525302410125732,50000 -10618.628512144089,16.316112995147705,115992.9700987339,263965,0,115992.9700987339,0.6614000201225281,1.453892707824707,10000,126641.70953035356,0.8879296779632568,0.4193416833877563,0.7829399704933167,0.8525302410125732,50000 -10655.76745057106,16.38568639755249,116412.8924088478,264920,0,116412.8924088478,0.6614000201225281,1.453892707824707,10000,127098.88965320589,0.8895702958106995,0.4095839560031891,0.7829399704933167,0.8525302410125732,50000 -10702.066796064377,16.466209650039673,116832.95955109596,265874,0,116832.95955109596,0.6614000201225281,1.453892707824707,10000,127565.38592290878,0.8883984088897705,0.4182425737380981,0.7829399704933167,0.8525302410125732,50000 -10737.159190177916,16.54511594772339,117253.22350525856,266832,0,117253.22350525856,0.6614000201225281,1.453892707824707,10000,128020.87030768394,0.8867382407188416,0.4178032875061035,0.7829399704933167,0.8525302410125732,50000 -10776.096898555756,16.626272678375244,117673.31190681458,267786,0,117673.31190681458,0.6614000201225281,1.453892707824707,10000,128480.0262553692,0.8861132860183716,0.4290775954723358,0.7829399704933167,0.8525302410125732,50000 -10814.793027639387,16.707128047943115,118093.4741909504,268739,0,118093.4741909504,0.6614000201225281,1.453892707824707,10000,128939.01418423653,0.8898437023162842,0.4073122441768646,0.7829399704933167,0.8525302410125732,50000 -10856.493202209473,16.789478540420532,118513.58444952963,269697,0,118513.58444952963,0.6614000201225281,1.453892707824707,10000,129400.95628118516,0.8878124952316284,0.414100170135498,0.7829399704933167,0.8525302410125732,50000 -10891.216101408005,16.87260341644287,118933.8285355568,270655,0,118933.8285355568,0.6614000201225281,1.453892707824707,10000,129856.05528616904,0.8884375095367432,0.4185597002506256,0.7829399704933167,0.8525302410125732,50000 -10927.63363814354,16.95663595199585,119353.74111104012,271611,0,119353.74111104012,0.6614000201225281,1.453892707824707,10000,130312.5183684826,0.8877343535423279,0.4192695617675781,0.7829399704933167,0.8525302410125732,50000 -10968.197939157486,17.039021015167236,119773.71231675148,272566,0,119773.71231675148,0.6614000201225281,1.453892707824707,10000,130773.18575668336,0.8866015672683716,0.4227052330970764,0.7829399704933167,0.8525302410125732,50000 -11016.682899475098,17.105791568756104,120193.92505979538,273525,0,120193.92505979538,0.6614000201225281,1.453892707824707,10000,131241.99991846085,0.8893554210662842,0.4108372330665588,0.7829399704933167,0.8525302410125732,50000 -11057.888793230057,17.17463517189026,120614.32047319412,274484,0,120614.32047319412,0.6614000201225281,1.453892707824707,10000,131703.71882891655,0.8886327743530273,0.4180715084075928,0.7829399704933167,0.8525302410125732,50000 -11096.530444145204,17.246686697006226,121034.56578993796,275441,0,121034.56578993796,0.6614000201225281,1.453892707824707,10000,132162.7273669243,0.8881054520606995,0.4140663146972656,0.7829399704933167,0.8525302410125732,50000 -11129.518389940262,17.333673238754272,121454.5478489399,276382,0,121454.5478489399,0.6614000201225281,1.453892707824707,10000,132615.83280420303,0.8871874809265137,0.4224247932434082,0.7829399704933167,0.8525302410125732,50000 -11165.648698568344,17.41505718231201,121874.64784169196,277326,0,121874.64784169196,0.6614000201225281,1.453892707824707,10000,133072.1931116581,0.8856640458106995,0.4246830940246582,0.7829399704933167,0.8525302410125732,50000 -11204.90605711937,17.502548217773438,122294.87970471382,278279,0,122294.87970471382,0.6614000201225281,1.453892707824707,10000,133531.81871771812,0.8876367211341858,0.4154536426067352,0.7829399704933167,0.8525302410125732,50000 -11243.041862249374,17.574153184890747,122715.11296343803,279237,0,122715.11296343803,0.6614000201225281,1.453892707824707,10000,133990.30941224098,0.88832026720047,0.4142638742923736,0.7829399704933167,0.8525302410125732,50000 -11274.032299041748,17.657609224319458,123135.0280020237,280193,0,123135.0280020237,0.6614000201225281,1.453892707824707,10000,134441.34727716446,0.8883007764816284,0.4194623529911041,0.7829399704933167,0.8525302410125732,50000 -11319.689100265505,17.738707065582275,123555.07653808594,281148,0,123555.07653808594,0.6614000201225281,1.453892707824707,10000,134907.18300938606,0.8873046636581421,0.4161858260631561,0.7829399704933167,0.8525302410125732,50000 -11356.789668560028,17.809013605117798,123975.168166399,282104,0,123975.168166399,0.6614000201225281,1.453892707824707,10000,135364.49425768852,0.8891406059265137,0.4142415225505829,0.7829399704933167,0.8525302410125732,50000 -11400.639189958572,17.891976594924927,124395.1518342495,283058,0,124395.1518342495,0.6614000201225281,1.453892707824707,10000,135828.45954036713,0.8882812261581421,0.4211524426937103,0.7829399704933167,0.8525302410125732,50000 -11440.33082962036,17.974778175354004,124815.4282989502,284012,0,124815.4282989502,0.6614000201225281,1.453892707824707,10000,136288.5594909191,0.8887304663658142,0.4123408496379852,0.7829399704933167,0.8525302410125732,50000 -11476.264526367188,18.068246841430664,125235.39047026634,284966,0,125235.39047026634,0.6614000201225281,1.453892707824707,10000,136744.59841275215,0.8908007740974426,0.410673975944519,0.7829399704933167,0.8525302410125732,50000 -11519.718435525894,18.15044403076172,125655.38073897362,285915,0,125655.38073897362,0.6614000201225281,1.453892707824707,10000,137208.1740090847,0.8874804377555847,0.4160492718219757,0.7829399704933167,0.8525302410125732,50000 -11557.939510822296,18.23632121086121,126075.25695848464,286871,0,126075.25695848464,0.6614000201225281,1.453892707824707,10000,137666.40572619438,0.88623046875,0.4197134375572204,0.7829399704933167,0.8525302410125732,50000 -11595.891792058945,18.34083724021912,126495.22283816338,287826,0,126495.22283816338,0.6614000201225281,1.453892707824707,10000,138124.47844481468,0.8890624642372131,0.41220623254776,0.7829399704933167,0.8525302410125732,50000 -11636.082073688509,18.42564058303833,126915.48208451273,288774,0,126915.48208451273,0.6614000201225281,1.453892707824707,10000,138585.06207585335,0.8875976204872131,0.4183759093284607,0.7829399704933167,0.8525302410125732,50000 -11675.6428399086,18.51556301116944,127335.6986641884,289728,0,127335.6986641884,0.6614000201225281,1.453892707824707,10000,139044.97914934158,0.8895702958106995,0.4143651723861694,0.7829399704933167,0.8525302410125732,50000 -11718.595754623411,18.59935998916626,127755.82724237442,290682,0,127755.82724237442,0.6614000201225281,1.453892707824707,10000,139508.1933040619,0.884765625,0.4241220355033874,0.7829399704933167,0.8525302410125732,50000 -11756.851325035095,18.68771505355835,128176.1004807949,291638,0,128176.1004807949,0.6614000201225281,1.453892707824707,10000,139966.8594853878,0.8886523246765137,0.417254239320755,0.7829399704933167,0.8525302410125732,50000 -11797.527698516846,18.773333311080933,128596.28019595146,292593,0,128596.28019595146,0.6614000201225281,1.453892707824707,10000,140427.85061454773,0.8885546922683716,0.4134081602096557,0.7829399704933167,0.8525302410125732,50000 -11835.907625436785,18.86729431152344,129016.2508816719,293549,0,129016.2508816719,0.6614000201225281,1.453892707824707,10000,140886.34425234795,0.8882030844688416,0.4148542284965515,0.7829399704933167,0.8525302410125732,50000 -11873.716660499573,18.954487562179565,129436.32824611664,294501,0,129436.32824611664,0.6614000201225281,1.453892707824707,10000,141344.36743688583,0.8893749713897705,0.414853423833847,0.7829399704933167,0.8525302410125732,50000 -11914.05112528801,19.04377579689026,129856.59120893478,295453,0,129856.59120893478,0.6614000201225281,1.453892707824707,10000,141805.1029598713,0.88929682970047,0.4171392321586609,0.7829399704933167,0.8525302410125732,50000 -11961.55670619011,19.131915807724,130276.792396307,296408,0,130276.792396307,0.6614000201225281,1.453892707824707,10000,142272.9465227127,0.8854687213897705,0.4234060049057007,0.7829399704933167,0.8525302410125732,50000 -12007.264422655106,19.20346760749817,130696.7673239708,297369,0,130696.7673239708,0.6614000201225281,1.453892707824707,10000,142738.7496433258,0.8866601586341858,0.4202454388141632,0.7829399704933167,0.8525302410125732,50000 -12052.48368382454,19.28312730789185,131116.9161388874,298328,0,131116.9161388874,0.6614000201225281,1.453892707824707,10000,143204.24835014343,0.8893163800239563,0.4147822558879852,0.7829399704933167,0.8525302410125732,50000 -12091.130192756653,19.356364011764526,131537.1429643631,299282,0,131537.1429643631,0.6614000201225281,1.453892707824707,10000,143663.2446167469,0.8881054520606995,0.4154012203216553,0.7829399704933167,0.8525302410125732,50000 -12128.053198814392,19.444679021835327,131957.35796141624,300226,0,131957.35796141624,0.6614000201225281,1.453892707824707,10000,144120.52024149895,0.8876757621765137,0.4198732674121856,0.7829399704933167,0.8525302410125732,50000 -12169.511289596558,19.53000783920288,132377.5421822071,301179,0,132377.5421822071,0.6614000201225281,1.453892707824707,10000,144582.29728770256,0.88636714220047,0.4219211935997009,0.7829399704933167,0.8525302410125732,50000 -12207.755256175997,19.618409633636475,132797.46908450127,302132,0,132797.46908450127,0.6614000201225281,1.453892707824707,10000,145040.6056535244,0.8879687190055847,0.4166455268859863,0.7829399704933167,0.8525302410125732,50000 -12251.54194355011,19.716617345809937,133217.7573838234,303087,0,133217.7573838234,0.6614000201225281,1.453892707824707,10000,145504.82874035835,0.8862499594688416,0.4185951948165893,0.7829399704933167,0.8525302410125732,50000 -12290.439830303192,19.80525302886963,133637.78238511086,304045,0,133637.78238511086,0.6614000201225281,1.453892707824707,10000,145963.89064216614,0.8896093368530273,0.4145287573337555,0.7829399704933167,0.8525302410125732,50000 -12330.843381643295,19.895825386047363,134058.0566318035,305000,0,134058.0566318035,0.6614000201225281,1.453892707824707,10000,146424.70875573158,0.8884570002555847,0.4126504063606262,0.7829399704933167,0.8525302410125732,50000 -12376.776423931122,19.98555564880371,134478.2640724182,305955,0,134478.2640724182,0.6614000201225281,1.453892707824707,10000,146890.98871064186,0.88783198595047,0.4241567254066467,0.7829399704933167,0.8525302410125732,50000 -12414.676441431046,20.075188636779785,134898.1319179535,306912,0,134898.1319179535,0.6614000201225281,1.453892707824707,10000,147348.89574956894,0.8887695074081421,0.4157889485359192,0.7829399704933167,0.8525302410125732,50000 -12453.731521844864,20.169724941253666,135318.12880396843,307867,0,135318.12880396843,0.6614000201225281,1.453892707824707,10000,147808.09196567535,0.8900195360183716,0.4082540273666382,0.7829399704933167,0.8525302410125732,50000 -12490.151092767715,20.24302864074707,135738.11697626114,308824,0,135738.11697626114,0.6614000201225281,1.453892707824707,10000,148264.62213468552,0.8877343535423279,0.4190030992031097,0.7829399704933167,0.8525302410125732,50000 -12534.50862789154,20.331170082092285,136158.0286180973,309777,0,136158.0286180973,0.6614000201225281,1.453892707824707,10000,148729.02848935127,0.8886523246765137,0.4138353765010834,0.7829399704933167,0.8525302410125732,50000 -12571.31841301918,20.42776656150818,136577.9388947487,310731,0,136577.9388947487,0.6614000201225281,1.453892707824707,10000,149185.8946583271,0.8862695097923279,0.4201231002807617,0.7829399704933167,0.8525302410125732,50000 -12609.654133558271,20.53228735923767,136997.87422275543,311684,0,136997.87422275543,0.6614000201225281,1.453892707824707,10000,149644.31943631172,0.8872851133346558,0.4167616665363312,0.7829399704933167,0.8525302410125732,50000 -12657.285651922226,20.62012791633606,137418.08788132668,312638,0,137418.08788132668,0.6614000201225281,1.453892707824707,10000,150112.3021736145,0.8905078172683716,0.4076966643333435,0.7829399704933167,0.8525302410125732,50000 -12697.472255945206,20.69444990158081,137838.1891798973,313597,0,137838.1891798973,0.6614000201225281,1.453892707824707,10000,150572.7144780159,0.8882616758346558,0.4182638525962829,0.7829399704933167,0.8525302410125732,50000 -12738.952292442322,20.77275776863098,138258.14862036705,314553,0,138258.14862036705,0.6614000201225281,1.453892707824707,10000,151034.2818892002,0.8880859017372131,0.4185377359390259,0.7829399704933167,0.8525302410125732,50000 -12780.446078777311,20.86136531829834,138678.3341538906,315505,0,138678.3341538906,0.6614000201225281,1.453892707824707,10000,151496.09946393967,0.8875390291213989,0.4202651381492615,0.7829399704933167,0.8525302410125732,50000 -12829.110315561296,20.9499146938324,139098.41466093063,316454,0,139098.41466093063,0.6614000201225281,1.453892707824707,10000,151964.981477499,0.88671875,0.4202517867088318,0.7829399704933167,0.8525302410125732,50000 -12871.982561588287,21.02366185188293,139518.60390734673,317411,0,139518.60390734673,0.6614000201225281,1.453892707824707,10000,152428.16596484184,0.8899804353713989,0.4071928262710571,0.7829399704933167,0.8525302410125732,50000 -12910.795140266418,21.09812617301941,139938.67020440102,318367,0,139938.67020440102,0.6614000201225281,1.453892707824707,10000,152887.1681947708,0.8865624666213989,0.4238818287849426,0.7829399704933167,0.8525302410125732,50000 -12951.488573789597,21.18717908859253,140358.56078743935,319304,0,140358.56078743935,0.6614000201225281,1.453892707824707,10000,153347.8892352581,0.8884961009025574,0.4180271029472351,0.7829399704933167,0.8525302410125732,50000 -12990.127300024033,21.276575088500977,140778.834120512,320258,0,140778.834120512,0.6614000201225281,1.453892707824707,10000,153806.9401268959,0.8869335651397705,0.419131875038147,0.7829399704933167,0.8525302410125732,50000 -13027.79325413704,21.4208242893219,141198.73992681503,321213,0,141198.73992681503,0.6614000201225281,1.453892707824707,10000,154264.70492196083,0.8894921541213989,0.4086560010910034,0.7829399704933167,0.8525302410125732,50000 -13066.877766132357,21.51161813735962,141619.0731432438,322164,0,141619.0731432438,0.6614000201225281,1.453892707824707,10000,154724.26458621025,0.8868749737739563,0.4255317449569702,0.7829399704933167,0.8525302410125732,50000 -13111.180988073347,21.60645341873169,142039.1426100731,323119,0,142039.1426100731,0.6614000201225281,1.453892707824707,10000,155188.78133320808,0.8875390291213989,0.4166724681854248,0.7829399704933167,0.8525302410125732,50000 -13159.376702070236,21.68626976013184,142459.23724484444,324077,0,142459.23724484444,0.6614000201225281,1.453892707824707,10000,155657.20071411133,0.88783198595047,0.4228438138961792,0.7829399704933167,0.8525302410125732,50000 -13207.968095541,21.76215624809265,142879.33050012589,325034,0,142879.33050012589,0.6614000201225281,1.453892707824707,10000,156126.01038050652,0.8855078220367432,0.4231415390968323,0.7829399704933167,0.8525302410125732,50000 -13247.191797733309,21.83880043029785,143299.4483640194,325990,0,143299.4483640194,0.6614000201225281,1.453892707824707,10000,156585.47791337967,0.8892773389816284,0.4138164818286896,0.7829399704933167,0.8525302410125732,50000 -13296.705656290054,21.93300485610962,143719.5908768177,326935,0,143719.5908768177,0.6614000201225281,1.453892707824707,10000,157055.27760457993,0.8871874809265137,0.4137941896915436,0.7829399704933167,0.8525302410125732,50000 -13335.431085586548,22.008670568466187,144139.76654458046,327889,0,144139.76654458046,0.6614000201225281,1.453892707824707,10000,157514.30312132835,0.8894726634025574,0.4152855277061462,0.7829399704933167,0.8525302410125732,50000 -13375.67746925354,22.125629425048828,144560.12257027626,328841,0,144560.12257027626,0.6614000201225281,1.453892707824707,10000,157975.0710504055,0.8877539038658142,0.416228175163269,0.7829399704933167,0.8525302410125732,50000 -13419.763077259064,22.216106176376343,144980.0615439415,329785,0,144980.0615439415,0.6614000201225281,1.453892707824707,10000,158439.23374319077,0.8871679306030273,0.4185080528259277,0.7829399704933167,0.8525302410125732,50000 -13458.372012853622,22.297227382659912,145400.25120687485,330743,0,145400.25120687485,0.6614000201225281,1.453892707824707,10000,158898.16319799423,0.8885937333106995,0.4203934073448181,0.7829399704933167,0.8525302410125732,50000 -13504.289674520493,22.39159369468689,145820.2736287117,331697,0,145820.2736287117,0.6614000201225281,1.453892707824707,10000,159364.24648237228,0.8895702958106995,0.4124229848384857,0.7829399704933167,0.8525302410125732,50000 -13551.26194858551,22.474181175231934,146240.3351507187,332651,0,146240.3351507187,0.6614000201225281,1.453892707824707,10000,159831.4127779007,0.8883398175239563,0.4132517576217651,0.7829399704933167,0.8525302410125732,50000 -13597.88392996788,22.552911043167114,146660.5423769951,333609,0,146660.5423769951,0.6614000201225281,1.453892707824707,10000,160298.3706278801,0.8883788585662842,0.4145821630954742,0.7829399704933167,0.8525302410125732,50000 -13646.050462961197,22.62759017944336,147080.81252145767,334567,0,147080.81252145767,0.6614000201225281,1.453892707824707,10000,160766.93076348305,0.8887499570846558,0.4145262539386749,0.7829399704933167,0.8525302410125732,50000 -13681.3886988163,22.70599865913391,147500.88044071198,335521,0,147500.88044071198,0.6614000201225281,1.453892707824707,10000,161222.46529269218,0.8877343535423279,0.4150673449039459,0.7829399704933167,0.8525302410125732,50000 -13723.035228729248,22.79630279541016,147921.03347206116,336456,0,147921.03347206116,0.6614000201225281,1.453892707824707,10000,161684.40320611,0.8890820145606995,0.4130926430225372,0.7829399704933167,0.8525302410125732,50000 -13764.488429307938,22.89143419265747,148341.1167113781,337407,0,148341.1167113781,0.6614000201225281,1.453892707824707,10000,162146.08368301392,0.88671875,0.4238328039646148,0.7829399704933167,0.8525302410125732,50000 -13804.790137529371,22.987659454345703,148761.1108095646,338363,0,148761.1108095646,0.6614000201225281,1.453892707824707,10000,162606.52558994293,0.88832026720047,0.4145942628383636,0.7829399704933167,0.8525302410125732,50000 -13846.27517938614,23.08351254463196,149180.98473072052,339317,0,149180.98473072052,0.6614000201225281,1.453892707824707,10000,163068.02965641022,0.8875976204872131,0.4212967455387115,0.7829399704933167,0.8525302410125732,50000 -13884.049741983414,23.206378698349,149600.92562270164,340272,0,149600.92562270164,0.6614000201225281,1.453892707824707,10000,163525.91829109192,0.8876757621765137,0.4158148467540741,0.7829399704933167,0.8525302410125732,50000 -13926.174630880356,23.29926109313965,150021.06316399574,341227,0,150021.06316399574,0.6614000201225281,1.453892707824707,10000,163988.32326960564,0.8883007764816284,0.4141983985900879,0.7829399704933167,0.8525302410125732,50000 -13966.889407157898,23.39518094062805,150441.1364018917,342183,0,150441.1364018917,0.6614000201225281,1.453892707824707,10000,164449.25674581528,0.8895702958106995,0.41547891497612,0.7829399704933167,0.8525302410125732,50000 -14012.772404670715,23.496416807174683,150861.22448587418,343140,0,150861.22448587418,0.6614000201225281,1.453892707824707,10000,164915.37854075432,0.8856054544448853,0.4266979992389679,0.7829399704933167,0.8525302410125732,50000 -14051.161545276642,23.57652187347412,151281.30113005638,344097,0,151281.30113005638,0.6614000201225281,1.453892707824707,10000,165373.97410154343,0.8867382407188416,0.419879138469696,0.7829399704933167,0.8525302410125732,50000 -14095.20967411995,23.672836303710938,151701.54492998123,345051,0,151701.54492998123,0.6614000201225281,1.453892707824707,10000,165838.41160845757,0.8881250023841858,0.4131699800491333,0.7829399704933167,0.8525302410125732,50000 -14138.956525087357,23.76843428611756,152121.5381603241,346001,0,152121.5381603241,0.6614000201225281,1.453892707824707,10000,166302.2970738411,0.8878905773162842,0.4152270853519439,0.7829399704933167,0.8525302410125732,50000 -14174.279473781586,23.846048831939697,152541.55381274223,346956,0,152541.55381274223,0.6614000201225281,1.453892707824707,10000,166757.76275634766,0.8885155916213989,0.4200997948646545,0.7829399704933167,0.8525302410125732,50000 -14219.28055357933,23.9422287940979,152961.44713258743,347909,0,152961.44713258743,0.6614000201225281,1.453892707824707,10000,167222.80291485786,0.8882226347923279,0.4179379642009735,0.7829399704933167,0.8525302410125732,50000 -14260.740604877472,24.04466724395752,153381.35877251625,348863,0,153381.35877251625,0.6614000201225281,1.453892707824707,10000,167684.3264117241,0.8864648342132568,0.4202927052974701,0.7829399704933167,0.8525302410125732,50000 -14300.969812393188,24.14210081100464,153801.53926444054,349818,0,153801.53926444054,0.6614000201225281,1.453892707824707,10000,168144.88361668587,0.8893945217132568,0.4127410650253296,0.7829399704933167,0.8525302410125732,50000 -14340.72930598259,24.23790383338928,154221.6833667755,350775,0,154221.6833667755,0.6614000201225281,1.453892707824707,10000,168604.9321911335,0.8864843845367432,0.4197324514389038,0.7829399704933167,0.8525302410125732,50000 -14390.32236981392,24.33543372154236,154641.72449684143,351730,0,154641.72449684143,0.6614000201225281,1.453892707824707,10000,169074.7146832943,0.8885937333106995,0.4155570268630981,0.7829399704933167,0.8525302410125732,50000 -14431.321568489077,24.41545557975769,155061.90138220787,352688,0,155061.90138220787,0.6614000201225281,1.453892707824707,10000,169536.02051210403,0.8866210579872131,0.4199872612953186,0.7829399704933167,0.8525302410125732,50000 -14475.808803319933,24.493077278137207,155481.78094792366,353645,0,155481.78094792366,0.6614000201225281,1.453892707824707,10000,170000.51396226883,0.8891991972923279,0.4173726439476013,0.7829399704933167,0.8525302410125732,50000 -14513.76334285736,24.60359287261963,155901.67514777184,354599,0,155901.67514777184,0.6614000201225281,1.453892707824707,10000,170458.52275180817,0.8884375095367432,0.4171162843704223,0.7829399704933167,0.8525302410125732,50000 -14561.556003570557,24.704198122024536,156321.83713293076,355550,0,156321.83713293076,0.6614000201225281,1.453892707824707,10000,170926.6277961731,0.890625,0.4058804512023926,0.7829399704933167,0.8525302410125732,50000 -14598.595462560654,24.78358292579651,156742.0211148262,356506,0,156742.0211148262,0.6614000201225281,1.453892707824707,10000,171383.98013997078,0.8883984088897705,0.416498452425003,0.7829399704933167,0.8525302410125732,50000 -14637.712956905363,24.881208181381226,157162.01435351372,357459,0,157162.01435351372,0.6614000201225281,1.453892707824707,10000,171843.23746800423,0.88818359375,0.4171052873134613,0.7829399704933167,0.8525302410125732,50000 -14687.286793231964,24.98118376731873,157582.07474136353,358413,0,157582.07474136353,0.6614000201225281,1.453892707824707,10000,172313.02126026154,0.8861718773841858,0.4189674258232116,0.7829399704933167,0.8525302410125732,50000 -14727.204654693604,25.068183422088623,158002.15069007874,359367,0,158002.15069007874,0.6614000201225281,1.453892707824707,10000,172773.15025138855,0.8883398175239563,0.4130707383155823,0.7829399704933167,0.8525302410125732,50000 -14767.577633619308,25.1567451953888,158422.29297685623,360324,0,158422.29297685623,0.6614000201225281,1.453892707824707,10000,173233.80309987068,0.88929682970047,0.4147252142429352,0.7829399704933167,0.8525302410125732,50000 -14807.759229898453,25.254371643066406,158842.19017791748,361268,0,158842.19017791748,0.6614000201225281,1.453892707824707,10000,173694.02808856964,0.88832026720047,0.4192902147769928,0.7829399704933167,0.8525302410125732,50000 -14853.150671720505,25.353637218475345,159262.09999990463,362214,0,159262.09999990463,0.6614000201225281,1.453892707824707,10000,174159.47754120827,0.8856054544448853,0.4205483794212341,0.7829399704933167,0.8525302410125732,50000 -14892.368756771088,25.447653770446777,159682.17498278618,363169,0,159682.17498278618,0.6614000201225281,1.453892707824707,10000,174618.91343593597,0.8899999856948853,0.4131052792072296,0.7829399704933167,0.8525302410125732,50000 -14934.86432981491,25.54854822158813,160102.2657313347,364121,0,160102.2657313347,0.6614000201225281,1.453892707824707,10000,175081.6497273445,0.8870507478713989,0.4179600775241852,0.7829399704933167,0.8525302410125732,50000 -14975.874757051468,25.65269923210144,160522.2911117077,365073,0,160522.2911117077,0.6614000201225281,1.453892707824707,10000,175542.83787035942,0.8878124952316284,0.414995789527893,0.7829399704933167,0.8525302410125732,50000 -15019.486609220505,25.7526330947876,160942.30407452583,366025,0,160942.30407452583,0.6614000201225281,1.453892707824707,10000,176006.61175465584,0.8901953101158142,0.4158586859703064,0.7829399704933167,0.8525302410125732,50000 -15057.952741146088,25.909390687942505,161362.19010972977,366978,0,161362.19010972977,0.6614000201225281,1.453892707824707,10000,176465.17006731033,0.8854882717132568,0.4268224537372589,0.7829399704933167,0.8525302410125732,50000 -15102.571516036987,26.01111960411072,161782.4360189438,367933,0,161782.4360189438,0.6614000201225281,1.453892707824707,10000,176930.18487286568,0.8866796493530273,0.4201131463050842,0.7829399704933167,0.8525302410125732,50000 -15145.225117206572,26.113094091415405,162202.54720520973,368887,0,162202.54720520973,0.6614000201225281,1.453892707824707,10000,177393.1012325287,0.8903710842132568,0.4076790809631347,0.7829399704933167,0.8525302410125732,50000 -15183.17496752739,26.19810652732849,162622.64628648758,369844,0,162622.64628648758,0.6614000201225281,1.453892707824707,10000,177851.28444957733,0.887011706829071,0.4195477068424225,0.7829399704933167,0.8525302410125732,50000 -15225.187652349472,26.29881167411804,163042.84110546112,370799,0,163042.84110546112,0.6614000201225281,1.453892707824707,10000,178313.641654253,0.8887695074081421,0.4166419804096222,0.7829399704933167,0.8525302410125732,50000 -15268.341364622116,26.399053812026978,163463.0209646225,371753,0,163463.0209646225,0.6614000201225281,1.453892707824707,10000,178777.1244843006,0.8869921565055847,0.4218083918094635,0.7829399704933167,0.8525302410125732,50000 -15305.892180919647,26.482391119003296,163883.04605412483,372710,0,163883.04605412483,0.6614000201225281,1.453892707824707,10000,179234.83329749107,0.8858007788658142,0.4265918731689453,0.7829399704933167,0.8525302410125732,50000 -15354.65480685234,26.608601808547974,164302.91035580635,373664,0,164302.91035580635,0.6614000201225281,1.453892707824707,10000,179703.6354522705,0.8891406059265137,0.4111347496509552,0.7829399704933167,0.8525302410125732,50000 -15400.136773586271,26.708449602127075,164722.90692043304,374620,0,164722.90692043304,0.6614000201225281,1.453892707824707,10000,180169.262591362,0.8878710865974426,0.4127155244350433,0.7829399704933167,0.8525302410125732,50000 -15445.25285935402,26.79526329040528,165143.0420908928,375576,0,165143.0420908928,0.6614000201225281,1.453892707824707,10000,180634.65010356903,0.8874609470367432,0.4219913482666015,0.7829399704933167,0.8525302410125732,50000 -15490.088018417358,26.87953495979309,165563.02144479752,376526,0,165563.02144479752,0.6614000201225281,1.453892707824707,10000,181099.59773135185,0.887011706829071,0.4182519316673279,0.7829399704933167,0.8525302410125732,50000 -15528.273938655851,26.961669921875,165983.27627158165,377483,0,165983.27627158165,0.6614000201225281,1.453892707824707,10000,181558.17009282112,0.8883398175239563,0.4191173315048218,0.7829399704933167,0.8525302410125732,50000 -15576.504554748535,27.059396266937256,166403.22260427475,378434,0,166403.22260427475,0.6614000201225281,1.453892707824707,10000,182026.49364376068,0.8884570002555847,0.4142103791236877,0.7829399704933167,0.8525302410125732,50000 -15614.70454120636,27.141717195510864,166823.30525374413,379390,0,166823.30525374413,0.6614000201225281,1.453892707824707,10000,182484.9082839489,0.8917577862739563,0.4062821865081787,0.7829399704933167,0.8525302410125732,50000 -15651.926760196686,27.245988607406616,167243.23605418205,380344,0,167243.23605418205,0.6614000201225281,1.453892707824707,10000,182942.2154922485,0.8888476490974426,0.4164382219314575,0.7829399704933167,0.8525302410125732,50000 -15693.9504032135,27.34863257408142,167663.43669724464,381297,0,167663.43669724464,0.6614000201225281,1.453892707824707,10000,183404.5911934376,0.8883007764816284,0.4152159392833709,0.7829399704933167,0.8525302410125732,50000 -15732.852969169617,27.450412034988403,168083.33858251572,382250,0,168083.33858251572,0.6614000201225281,1.453892707824707,10000,183863.54661631584,0.8875195384025574,0.416832834482193,0.7829399704933167,0.8525302410125732,50000 -15776.350043535233,27.55328798294068,168503.38748073578,383205,0,168503.38748073578,0.6614000201225281,1.453892707824707,10000,184327.24445915225,0.8854296803474426,0.4210447669029236,0.7829399704933167,0.8525302410125732,50000 -15817.368630886078,27.65798783302307,168923.58438396454,384159,0,168923.58438396454,0.6614000201225281,1.453892707824707,10000,184788.61385440824,0.889941394329071,0.409912645816803,0.7829399704933167,0.8525302410125732,50000 -15859.136594057083,27.77664041519165,169343.8165242672,385117,0,169343.8165242672,0.6614000201225281,1.453892707824707,10000,185250.7816569805,0.8879101276397705,0.4181992709636688,0.7829399704933167,0.8525302410125732,50000 -15897.479062080383,27.886797189712524,169763.95485639572,386074,0,169763.95485639572,0.6614000201225281,1.453892707824707,10000,185709.42883133888,0.8887499570846558,0.4167336821556091,0.7829399704933167,0.8525302410125732,50000 -15938.935741901398,27.993557691574097,170184.15087461472,387027,0,170184.15087461472,0.6614000201225281,1.453892707824707,10000,186171.23771500587,0.8863476514816284,0.420527309179306,0.7829399704933167,0.8525302410125732,50000 -15985.61399960518,28.094316720962524,170604.10696482658,387982,0,170604.10696482658,0.6614000201225281,1.453892707824707,10000,186638.0218274593,0.8876562118530273,0.4173603355884552,0.7829399704933167,0.8525302410125732,50000 -16020.70000576973,28.178041458129883,171024.2012052536,388940,0,171024.2012052536,0.6614000201225281,1.453892707824707,10000,187093.33483076096,0.8876757621765137,0.4181455373764038,0.7829399704933167,0.8525302410125732,50000 -16061.401812076569,28.2806077003479,171444.37899065018,389894,0,171444.37899065018,0.6614000201225281,1.453892707824707,10000,187554.36606407168,0.8898437023162842,0.4151351451873779,0.7829399704933167,0.8525302410125732,50000 -16111.968054294586,28.38612174987793,171864.35090208054,390845,0,171864.35090208054,0.6614000201225281,1.453892707824707,10000,188025.0598976612,0.8887304663658142,0.4167289733886719,0.7829399704933167,0.8525302410125732,50000 -16149.739753246307,28.470293283462524,172284.2420580387,391801,0,172284.2420580387,0.6614000201225281,1.453892707824707,10000,188482.8562738896,0.88818359375,0.4137005805969238,0.7829399704933167,0.8525302410125732,50000 -16189.30582332611,28.58032822608948,172704.4234380722,392756,0,172704.4234380722,0.6614000201225281,1.453892707824707,10000,188942.7630982399,0.8877343535423279,0.4162317514419555,0.7829399704933167,0.8525302410125732,50000 -16228.974833250046,28.68292617797852,173124.35608148575,393710,0,173124.35608148575,0.6614000201225281,1.453892707824707,10000,189402.51794099808,0.8866796493530273,0.4246825277805328,0.7829399704933167,0.8525302410125732,50000 -16268.61748099327,28.78764533996582,173544.5121805668,394665,0,173544.5121805668,0.6614000201225281,1.453892707824707,10000,189862.47140693665,0.8871679306030273,0.4199066758155823,0.7829399704933167,0.8525302410125732,50000 -16315.68248319626,28.8935661315918,173964.45985746384,395619,0,173964.45985746384,0.6614000201225281,1.453892707824707,10000,190329.6402170658,0.8871093392372131,0.4206791222095489,0.7829399704933167,0.8525302410125732,50000 -16353.387253761292,28.98282265663147,174384.35671782494,396579,0,174384.35671782494,0.6614000201225281,1.453892707824707,10000,190787.3806118965,0.8859961032867432,0.4233563840389251,0.7829399704933167,0.8525302410125732,50000 -16394.567228794098,29.087074756622314,174804.3445544243,397534,0,174804.3445544243,0.6614000201225281,1.453892707824707,10000,191248.7022612095,0.8891991972923279,0.4108793437480926,0.7829399704933167,0.8525302410125732,50000 -16437.66381263733,29.20992875099182,175224.36603736877,398491,0,175224.36603736877,0.6614000201225281,1.453892707824707,10000,191711.99235510823,0.8869531154632568,0.4180715978145599,0.7829399704933167,0.8525302410125732,50000 -16480.50748872757,29.315454959869385,175644.36986541748,399446,0,175644.36986541748,0.6614000201225281,1.453892707824707,10000,192174.9947481156,0.8896874785423279,0.4124925434589386,0.7829399704933167,0.8525302410125732,50000 -16529.972445726395,29.421252727508545,176064.40971302986,400401,0,176064.40971302986,0.6614000201225281,1.453892707824707,10000,192644.6549062729,0.8876757621765137,0.4160926342010498,0.7829399704933167,0.8525302410125732,50000 -16569.809871912003,29.519530534744263,176484.6373269558,401357,0,176484.6373269558,0.6614000201225281,1.453892707824707,10000,193104.86778235435,0.88880854845047,0.417743444442749,0.7829399704933167,0.8525302410125732,50000 -16609.966049671173,29.63040018081665,176904.59464097023,402300,0,176904.59464097023,0.6614000201225281,1.453892707824707,10000,193565.14063882828,0.8880664110183716,0.4243686497211456,0.7829399704933167,0.8525302410125732,50000 -16657.998848199844,29.737738132476807,177324.74939084053,403247,0,177324.74939084053,0.6614000201225281,1.453892707824707,10000,194033.4842557907,0.8893554210662842,0.4064174890518188,0.7829399704933167,0.8525302410125732,50000 -16697.615966558456,29.82686495780945,177744.90615653992,404203,0,177744.90615653992,0.6614000201225281,1.453892707824707,10000,194493.39645719528,0.8868359327316284,0.4202545583248138,0.7829399704933167,0.8525302410125732,50000 -16741.23883986473,29.913761138916016,178165.01534843445,405161,0,178165.01534843445,0.6614000201225281,1.453892707824707,10000,194957.2669413089,0.8908202648162842,0.4089140594005584,0.7829399704933167,0.8525302410125732,50000 -16782.42190861702,30.02169752120972,178585.18918466568,406117,0,178585.18918466568,0.6614000201225281,1.453892707824707,10000,195418.7814545632,0.8858398199081421,0.4200993180274963,0.7829399704933167,0.8525302410125732,50000 -16827.65636920929,30.131014585494995,179005.29097795486,407053,0,179005.29097795486,0.6614000201225281,1.453892707824707,10000,195884.275844574,0.8891991972923279,0.4103685021400451,0.7829399704933167,0.8525302410125732,50000 -16867.759991645813,30.217520475387573,179425.34422326088,408006,0,179425.34422326088,0.6614000201225281,1.453892707824707,10000,196344.5682406425,0.8874609470367432,0.4193182587623596,0.7829399704933167,0.8525302410125732,50000 -16914.508714675903,30.320476055145264,179845.50100183487,408955,0,179845.50100183487,0.6614000201225281,1.453892707824707,10000,196811.6254954338,0.8893163800239563,0.4162834584712982,0.7829399704933167,0.8525302410125732,50000 -16963.657194137573,30.428412675857544,180265.3992426396,409870,0,180265.3992426396,0.6614000201225281,1.453892707824707,10000,197280.82704353333,0.8877733945846558,0.4182986617088318,0.7829399704933167,0.8525302410125732,50000 -17006.726067066193,30.515804767608643,180685.4873828888,410825,0,180685.4873828888,0.6614000201225281,1.453892707824707,10000,197744.1201946736,0.8863281011581421,0.4216166734695434,0.7829399704933167,0.8525302410125732,50000 -17044.269870519638,30.603248834609985,181105.803034544,411778,0,181105.803034544,0.6614000201225281,1.453892707824707,10000,198202.1157720089,0.890625,0.4106134176254272,0.7829399704933167,0.8525302410125732,50000 -17083.35184264183,30.71147727966309,181525.74473643303,412712,0,181525.74473643303,0.6614000201225281,1.453892707824707,10000,198661.29543995857,0.8864843845367432,0.4179598391056061,0.7829399704933167,0.8525302410125732,50000 -17134.33325767517,30.818686723709103,181945.9853141308,413662,0,181945.9853141308,0.6614000201225281,1.453892707824707,10000,199132.67255043983,0.8892187476158142,0.4130817055702209,0.7829399704933167,0.8525302410125732,50000 -17172.24253630638,30.905985593795776,182366.14447641373,414618,0,182366.14447641373,0.6614000201225281,1.453892707824707,10000,199590.87727713585,0.8873242139816284,0.4246455430984497,0.7829399704933167,0.8525302410125732,50000 -17214.97789669037,31.05756092071533,182786.24632263184,415576,0,182786.24632263184,0.6614000201225281,1.453892707824707,10000,200053.9158146381,0.8868749737739563,0.4205108880996704,0.7829399704933167,0.8525302410125732,50000 -17258.430349826813,31.167073965072632,183206.3391830921,416528,0,183206.3391830921,0.6614000201225281,1.453892707824707,10000,200517.62397146225,0.888964831829071,0.4104123115539551,0.7829399704933167,0.8525302410125732,50000 -17299.904410362244,31.27821397781372,183626.3146479129,417479,0,183626.3146479129,0.6614000201225281,1.453892707824707,10000,200979.2339589596,0.8873632550239563,0.4178951978683471,0.7829399704933167,0.8525302410125732,50000 -17338.990936517715,31.370931386947632,184046.30386471748,418433,0,184046.30386471748,0.6614000201225281,1.453892707824707,10000,201438.4517595768,0.8864062428474426,0.4254805743694305,0.7829399704933167,0.8525302410125732,50000 -17388.526458263397,31.48215794563293,184466.1449456215,419385,0,184466.1449456215,0.6614000201225281,1.453892707824707,10000,201907.9888682365,0.8881250023841858,0.4159334599971771,0.7829399704933167,0.8525302410125732,50000 -17426.62258195877,31.82004451751709,184885.83039593697,420341,0,184885.83039593697,0.6614000201225281,1.453892707824707,10000,202366.15796732905,0.8867773413658142,0.423092246055603,0.7829399704933167,0.8525302410125732,50000 -17469.606053113937,31.93477702140808,185305.8543047905,421294,0,185305.8543047905,0.6614000201225281,1.453892707824707,10000,202829.32935857773,0.8882421851158142,0.4161365032196045,0.7829399704933167,0.8525302410125732,50000 -17508.005979537964,32.02518367767334,185725.99530100825,422250,0,185725.99530100825,0.6614000201225281,1.453892707824707,10000,203288.01004314423,0.888476550579071,0.4120833575725555,0.7829399704933167,0.8525302410125732,50000 -17548.692868709564,32.13347125053406,186145.98817968369,423203,0,186145.98817968369,0.6614000201225281,1.453892707824707,10000,203748.846982956,0.8883398175239563,0.414626270532608,0.7829399704933167,0.8525302410125732,50000 -17593.436250925064,32.24198937416077,186565.88954353333,424158,0,186565.88954353333,0.6614000201225281,1.453892707824707,10000,204213.6493074894,0.8863085508346558,0.421446144580841,0.7829399704933167,0.8525302410125732,50000 -17630.724541187286,32.34537315368652,186986.0024909973,425112,0,186986.0024909973,0.6614000201225281,1.453892707824707,10000,204671.20263504985,0.8889062404632568,0.416959673166275,0.7829399704933167,0.8525302410125732,50000 -17676.37200140953,32.45468544960022,187406.1336224079,426067,0,187406.1336224079,0.6614000201225281,1.453892707824707,10000,205137.1404938697,0.888476550579071,0.4230095446109772,0.7829399704933167,0.8525302410125732,50000 -17720.457312583923,32.5679783821106,187826.3649520874,427022,0,187826.3649520874,0.6614000201225281,1.453892707824707,10000,205601.6198780537,0.8894140720367432,0.4090057015419006,0.7829399704933167,0.8525302410125732,50000 -17767.809679031372,32.65678024291992,188246.64551234245,427979,0,188246.64551234245,0.6614000201225281,1.453892707824707,10000,206069.3907442093,0.8889452815055847,0.4116565585136413,0.7829399704933167,0.8525302410125732,50000 -17809.26633501053,32.74723792076111,188666.6430413723,428934,0,188666.6430413723,0.6614000201225281,1.453892707824707,10000,206530.98461413383,0.8880664110183716,0.4169478118419647,0.7829399704933167,0.8525302410125732,50000 -17853.379207134247,32.857391119003296,189086.55633735657,429888,0,189086.55633735657,0.6614000201225281,1.453892707824707,10000,206995.16974568367,0.8879101276397705,0.4141132533550262,0.7829399704933167,0.8525302410125732,50000 -17896.074108600616,32.970332860946655,189506.4773423672,430833,0,189506.4773423672,0.6614000201225281,1.453892707824707,10000,207457.9479892254,0.8874804377555847,0.4158562123775482,0.7829399704933167,0.8525302410125732,50000 -17935.916864156723,33.06070804595947,189926.61285805705,431786,0,189926.61285805705,0.6614000201225281,1.453892707824707,10000,207918.0655157566,0.8893945217132568,0.4129515886306762,0.7829399704933167,0.8525302410125732,50000 -17976.231738567352,33.17487382888794,190346.702862978,432730,0,190346.702862978,0.6614000201225281,1.453892707824707,10000,208378.6335697174,0.8885937333106995,0.4198548793792724,0.7829399704933167,0.8525302410125732,50000 -18022.161470651627,33.28832411766052,190766.5861680508,433681,0,190766.5861680508,0.6614000201225281,1.453892707824707,10000,208844.6084537506,0.8869335651397705,0.418942928314209,0.7829399704933167,0.8525302410125732,50000 -18061.090587615967,33.39108967781067,191186.5708837509,434636,0,191186.5708837509,0.6614000201225281,1.453892707824707,10000,209303.67406725883,0.8886327743530273,0.4150497019290924,0.7829399704933167,0.8525302410125732,50000 -18103.057653665543,33.505431175231934,191606.7920079232,435590,0,191606.7920079232,0.6614000201225281,1.453892707824707,10000,209766.02551412585,0.8875195384025574,0.4184292554855346,0.7829399704933167,0.8525302410125732,50000 -18144.561408996586,33.61554431915283,192026.9568374157,436546,0,192026.9568374157,0.6614000201225281,1.453892707824707,10000,210227.8538463116,0.8882812261581421,0.4159662425518036,0.7829399704933167,0.8525302410125732,50000 -18186.55794620514,33.72833752632141,192447.2765059471,437500,0,192447.2765059471,0.6614000201225281,1.453892707824707,10000,210690.33194756508,0.8877148032188416,0.4200031161308288,0.7829399704933167,0.8525302410125732,50000 -18225.270318984985,33.887531042099,192867.13353562355,438455,0,192867.13353562355,0.6614000201225281,1.453892707824707,10000,211149.1098475456,0.88734370470047,0.4195761084556579,0.7829399704933167,0.8525302410125732,50000 -18267.40621161461,34.00122618675232,193287.17481279373,439407,0,193287.17481279373,0.6614000201225281,1.453892707824707,10000,211611.44933223724,0.88832026720047,0.4196644127368927,0.7829399704933167,0.8525302410125732,50000 -18308.33525133133,34.11226868629456,193707.19133090973,440359,0,193707.19133090973,0.6614000201225281,1.453892707824707,10000,212072.5546195507,0.88671875,0.4149238169193268,0.7829399704933167,0.8525302410125732,50000 -18354.741428136826,34.21055889129639,194127.15976166725,441318,0,194127.15976166725,0.6614000201225281,1.453892707824707,10000,212539.07664871216,0.8875781297683716,0.4181137084960937,0.7829399704933167,0.8525302410125732,50000 -18399.771948337555,34.30404305458069,194547.1519122124,442273,0,194547.1519122124,0.6614000201225281,1.453892707824707,10000,213004.2418987751,0.8886913657188416,0.4170783460140228,0.7829399704933167,0.8525302410125732,50000 -18439.80325841904,34.41198301315308,194967.35440659523,443227,0,194967.35440659523,0.6614000201225281,1.453892707824707,10000,213464.63312387464,0.8878515362739563,0.4187040328979492,0.7829399704933167,0.8525302410125732,50000 -18486.59080529213,34.52042746543884,195387.2975564003,444170,0,195387.2975564003,0.6614000201225281,1.453892707824707,10000,213931.5210936069,0.8865429759025574,0.4236267507076263,0.7829399704933167,0.8525302410125732,50000 -18527.39293718338,34.61554837226868,195807.40478396416,445126,0,195807.40478396416,0.6614000201225281,1.453892707824707,10000,214392.5746004581,0.8879492282867432,0.4153803884983063,0.7829399704933167,0.8525302410125732,50000 -18570.385375261307,34.730725049972534,196227.24235582352,446075,0,196227.24235582352,0.6614000201225281,1.453892707824707,10000,214855.56814837456,0.8886132836341858,0.4135034084320068,0.7829399704933167,0.8525302410125732,50000 -18616.15464353561,34.85476279258728,196647.45865631104,447008,0,196647.45865631104,0.6614000201225281,1.453892707824707,10000,215321.72642993927,0.8878905773162842,0.4189607799053192,0.7829399704933167,0.8525302410125732,50000 -18656.66638660431,34.947226762771606,197067.3557920456,447962,0,197067.3557920456,0.6614000201225281,1.453892707824707,10000,215782.27796936035,0.8869726657867432,0.4157951772212982,0.7829399704933167,0.8525302410125732,50000 -18706.196996450424,35.063289403915405,197487.6512079239,448911,0,197487.6512079239,0.6614000201225281,1.453892707824707,10000,216252.2682375908,0.8879492282867432,0.4203372597694397,0.7829399704933167,0.8525302410125732,50000 -18753.732135295868,35.15776562690735,197907.5611524582,449865,0,197907.5611524582,0.6614000201225281,1.453892707824707,10000,216719.85712742803,0.8902929425239563,0.4134377539157867,0.7829399704933167,0.8525302410125732,50000 -18803.01878094673,35.27369546890259,198327.7874581813,450819,0,198327.7874581813,0.6614000201225281,1.453892707824707,10000,217189.5351667404,0.8892773389816284,0.4141132533550262,0.7829399704933167,0.8525302410125732,50000 -18843.86488962173,35.36725568771362,198747.9303052425,451771,0,198747.9303052425,0.6614000201225281,1.453892707824707,10000,217650.66642785072,0.8887499570846558,0.4107620418071747,0.7829399704933167,0.8525302410125732,50000 -18882.94619178772,35.465445041656494,199168.05628609657,452726,0,199168.05628609657,0.6614000201225281,1.453892707824707,10000,218110.0209414959,0.8876367211341858,0.4182771444320678,0.7829399704933167,0.8525302410125732,50000 -18924.48981571197,35.583218812942505,199588.1751565933,453667,0,199588.1751565933,0.6614000201225281,1.453892707824707,10000,218571.8494119644,0.8885546922683716,0.4138567149639129,0.7829399704933167,0.8525302410125732,50000 -18974.62998723984,35.69882941246033,200008.38172197345,454608,0,200008.38172197345,0.6614000201225281,1.453892707824707,10000,219042.3599162101,0.88636714220047,0.4170846343040466,0.7829399704933167,0.8525302410125732,50000 -19014.16286468506,35.791934967041016,200428.2955994606,455562,0,200428.2955994606,0.6614000201225281,1.453892707824707,10000,219501.9488492012,0.88978511095047,0.4132919609546661,0.7829399704933167,0.8525302410125732,50000 -19054.00243878365,35.90884757041931,200848.2231376171,456515,0,200848.2231376171,0.6614000201225281,1.453892707824707,10000,219961.8817877769,0.8884179592132568,0.4171240925788879,0.7829399704933167,0.8525302410125732,50000 -19096.94891095161,36.02534198760986,201268.2205114365,457457,0,201268.2205114365,0.6614000201225281,1.453892707824707,10000,220424.9903256893,0.8882616758346558,0.41612708568573,0.7829399704933167,0.8525302410125732,50000 -19140.144728183743,36.14205241203308,201688.31512522697,458408,0,201688.31512522697,0.6614000201225281,1.453892707824707,10000,220888.446969986,0.8860155940055847,0.4237580001354217,0.7829399704933167,0.8525302410125732,50000 -19180.52417993545,36.32553815841675,202108.22252106667,459363,0,202108.22252106667,0.6614000201225281,1.453892707824707,10000,221348.96659827232,0.88734370470047,0.4194290935993194,0.7829399704933167,0.8525302410125732,50000 -19224.78936481476,36.44341278076172,202528.4370057583,460316,0,202528.4370057583,0.6614000201225281,1.453892707824707,10000,221813.6135840416,0.890429675579071,0.4093267023563385,0.7829399704933167,0.8525302410125732,50000 -19273.821818590164,36.55770182609558,202948.36471748352,461268,0,202948.36471748352,0.6614000201225281,1.453892707824707,10000,222282.7366580963,0.8883007764816284,0.4184071719646454,0.7829399704933167,0.8525302410125732,50000 -19322.230692386627,36.65404319763184,203368.44725489616,462224,0,203368.44725489616,0.6614000201225281,1.453892707824707,10000,222751.3733520508,0.88623046875,0.422183096408844,0.7829399704933167,0.8525302410125732,50000 -19361.34489059448,36.750526428222656,203788.41516184807,463180,0,203788.41516184807,0.6614000201225281,1.453892707824707,10000,223210.6052725315,0.8882421851158142,0.4180953502655029,0.7829399704933167,0.8525302410125732,50000 -19404.15842437744,37.63438177108765,204207.8207669258,464120,0,204207.8207669258,0.6614000201225281,1.453892707824707,10000,223673.75593829155,0.887499988079071,0.4132614433765411,0.7829399704933167,0.8525302410125732,50000 -19454.58660507202,37.75303912162781,204627.69776153564,465070,0,204627.69776153564,0.6614000201225281,1.453892707824707,10000,224144.2284386158,0.8865624666213989,0.4224892854690552,0.7829399704933167,0.8525302410125732,50000 -19494.825929641724,37.85163021087647,205047.81872677803,466025,0,205047.81872677803,0.6614000201225281,1.453892707824707,10000,224604.7361364365,0.8872656226158142,0.420401781797409,0.7829399704933167,0.8525302410125732,50000 -19537.241803646088,37.98124074935913,205468.0308253765,466975,0,205468.0308253765,0.6614000201225281,1.453892707824707,10000,225067.5422782898,0.8876757621765137,0.4198971390724182,0.7829399704933167,0.8525302410125732,50000 -19583.554235219955,38.10013508796692,205887.92675709724,467912,0,205887.92675709724,0.6614000201225281,1.453892707824707,10000,225533.91818737984,0.887988269329071,0.4176282882690429,0.7829399704933167,0.8525302410125732,50000 -19634.826660633087,38.21635723114014,206307.7999806404,468863,0,206307.7999806404,0.6614000201225281,1.453892707824707,10000,226005.22905516624,0.8880664110183716,0.415490984916687,0.7829399704933167,0.8525302410125732,50000 -19675.311785697937,38.31474304199219,206727.8733928204,469817,0,206727.8733928204,0.6614000201225281,1.453892707824707,10000,226465.93561315536,0.8888671398162842,0.4122638702392578,0.7829399704933167,0.8525302410125732,50000 -19720.53402042389,38.43796992301941,207148.1556749344,470759,0,207148.1556749344,0.6614000201225281,1.453892707824707,10000,226931.6114668846,0.8885351419448853,0.4153727293014526,0.7829399704933167,0.8525302410125732,50000 -19762.82291984558,38.55806303024292,207568.38267040253,471704,0,207568.38267040253,0.6614000201225281,1.453892707824707,10000,227394.2978703976,0.8871874809265137,0.4193627834320068,0.7829399704933167,0.8525302410125732,50000 -19804.49430608749,38.68139624595642,207988.270154953,472652,0,207988.270154953,0.6614000201225281,1.453892707824707,10000,227856.02899241447,0.8884179592132568,0.4182244837284088,0.7829399704933167,0.8525302410125732,50000 -19843.9731669426,38.7846896648407,208408.5045876503,473606,0,208408.5045876503,0.6614000201225281,1.453892707824707,10000,228315.894469738,0.8891015648841858,0.4154936075210571,0.7829399704933167,0.8525302410125732,50000 -19884.347821950912,38.901485204696655,208828.76788640025,474560,0,208828.76788640025,0.6614000201225281,1.453892707824707,10000,228776.6988329888,0.8894921541213989,0.4117230474948883,0.7829399704933167,0.8525302410125732,50000 -19934.0034570694,39.02295017242432,209248.7898492813,475507,0,209248.7898492813,0.6614000201225281,1.453892707824707,10000,229246.5458858013,0.8870507478713989,0.4228179156780243,0.7829399704933167,0.8525302410125732,50000 -19973.2469329834,39.11786460876465,209668.7404808998,476464,0,209668.7404808998,0.6614000201225281,1.453892707824707,10000,229705.8835630417,0.8886327743530273,0.4156919717788696,0.7829399704933167,0.8525302410125732,50000 -20019.682067871094,39.2409257888794,210088.61083316803,477411,0,210088.61083316803,0.6614000201225281,1.453892707824707,10000,230172.3604860305,0.8877343535423279,0.4107499718666076,0.7829399704933167,0.8525302410125732,50000 -20060.430149316788,39.36173105239868,210508.6423151493,478363,0,210508.6423151493,0.6614000201225281,1.453892707824707,10000,230633.310161829,0.8869726657867432,0.4187818169593811,0.7829399704933167,0.8525302410125732,50000 -20105.6329100132,39.48548531532288,210928.602619648,479305,0,210928.602619648,0.6614000201225281,1.453892707824707,10000,231098.6457479,0.8893749713897705,0.4133050739765167,0.7829399704933167,0.8525302410125732,50000 -20145.21943593025,39.585007667541504,211348.75797367096,480260,0,211348.75797367096,0.6614000201225281,1.453892707824707,10000,231558.53671503067,0.8883788585662842,0.4162740111351013,0.7829399704933167,0.8525302410125732,50000 -20188.25280070305,39.7059965133667,211768.6008861065,481210,0,211768.6008861065,0.6614000201225281,1.453892707824707,10000,232021.582574606,0.8876757621765137,0.418246865272522,0.7829399704933167,0.8525302410125732,50000 -20237.58907341957,39.82587885856629,212188.4453895092,482163,0,212188.4453895092,0.6614000201225281,1.453892707824707,10000,232490.9327290058,0.8851562142372131,0.4248987138271332,0.7829399704933167,0.8525302410125732,50000 -20277.970195531845,39.92865061759949,212608.6507983208,483121,0,212608.6507983208,0.6614000201225281,1.453892707824707,10000,232951.67104578007,0.8895312547683716,0.4151040911674499,0.7829399704933167,0.8525302410125732,50000 -20317.68110179901,40.04916906356812,213028.6021828652,484076,0,213028.6021828652,0.6614000201225281,1.453892707824707,10000,233411.50339460373,0.888964831829071,0.4096009731292724,0.7829399704933167,0.8525302410125732,50000 -20358.64752459526,40.17074942588806,213448.84131765369,485011,0,213448.84131765369,0.6614000201225281,1.453892707824707,10000,233872.877859354,0.88832026720047,0.4179162383079529,0.7829399704933167,0.8525302410125732,50000 -20407.51692557335,40.30259609222412,213868.94949388504,485943,0,213868.94949388504,0.6614000201225281,1.453892707824707,10000,234342.0355579853,0.8872656226158142,0.4240639507770538,0.7829399704933167,0.8525302410125732,50000 -20453.481875658035,40.40373396873474,214289.19634270668,486899,0,214289.19634270668,0.6614000201225281,1.453892707824707,10000,234808.39871358871,0.88832026720047,0.416032999753952,0.7829399704933167,0.8525302410125732,50000 -20494.503543138504,40.596789598464966,214709.33081889155,487850,0,214709.33081889155,0.6614000201225281,1.453892707824707,10000,235269.7968466282,0.8885351419448853,0.4118785858154297,0.7829399704933167,0.8525302410125732,50000 -20536.487267017365,40.7184145450592,215129.6508102417,488786,0,215129.6508102417,0.6614000201225281,1.453892707824707,10000,235732.26984477043,0.8866210579872131,0.4215874373912811,0.7829399704933167,0.8525302410125732,50000 -20581.34592533112,40.841949224472046,215549.8371298313,489735,0,215549.8371298313,0.6614000201225281,1.453892707824707,10000,236197.4879083633,0.8875585794448853,0.4183970689773559,0.7829399704933167,0.8525302410125732,50000 -20622.188474416733,40.951939821243286,215970.1008954048,490689,0,215970.1008954048,0.6614000201225281,1.453892707824707,10000,236658.7532222271,0.8875976204872131,0.4221860468387604,0.7829399704933167,0.8525302410125732,50000 -20662.358990192413,41.06981325149536,216390.0034880638,491638,0,216390.0034880638,0.6614000201225281,1.453892707824707,10000,237118.9931571484,0.88623046875,0.4179167151451111,0.7829399704933167,0.8525302410125732,50000 -20713.29764199257,41.19071340560913,216810.23352575305,492593,0,216810.23352575305,0.6614000201225281,1.453892707824707,10000,237590.33246302605,0.888964831829071,0.4156960546970367,0.7829399704933167,0.8525302410125732,50000 -20761.96725344658,41.29007768630981,217230.2256433964,493552,0,217230.2256433964,0.6614000201225281,1.453892707824707,10000,238059.1432621479,0.8879101276397705,0.4139544665813446,0.7829399704933167,0.8525302410125732,50000 -20801.961450099945,41.39179611206055,217650.34621620167,494507,0,217650.34621620167,0.6614000201225281,1.453892707824707,10000,238519.40918159485,0.887988269329071,0.4172173142433166,0.7829399704933167,0.8525302410125732,50000 -20842.219407081604,41.5134813785553,218070.4790790081,495449,0,218070.4790790081,0.6614000201225281,1.453892707824707,10000,238979.97016358376,0.8880664110183716,0.4167808294296264,0.7829399704933167,0.8525302410125732,50000 -20893.64836382866,41.63759517669678,218490.4862658977,496385,0,218490.4862658977,0.6614000201225281,1.453892707824707,10000,239451.57863640785,0.8890038728713989,0.4185583293437958,0.7829399704933167,0.8525302410125732,50000 -20942.15079689026,41.74072694778442,218910.6556572914,497340,0,218910.6556572914,0.6614000201225281,1.453892707824707,10000,239920.405535698,0.8873828053474426,0.4235077500343323,0.7829399704933167,0.8525302410125732,50000 -20986.72572159767,41.84137916564941,219330.9631397724,498296,0,219330.9631397724,0.6614000201225281,1.453892707824707,10000,240385.43779563904,0.8897070288658142,0.4100601077079773,0.7829399704933167,0.8525302410125732,50000 -21032.17861413956,41.97613787651062,219751.18331694603,499248,0,219751.18331694603,0.6614000201225281,1.453892707824707,10000,240851.29478430748,0.8898437023162842,0.4133493602275848,0.7829399704933167,0.8525302410125732,50000 -21074.464852809902,42.10923957824707,220171.2231209278,500199,0,220171.2231209278,0.6614000201225281,1.453892707824707,10000,241313.8032577037,0.8876171708106995,0.4118110835552215,0.7829399704933167,0.8525302410125732,50000 -21120.94376969337,42.23678421974182,220591.38791012764,501146,0,220591.38791012764,0.6614000201225281,1.453892707824707,10000,241780.624147892,0.8863085508346558,0.4214934110641479,0.7829399704933167,0.8525302410125732,50000 -21161.110082149506,42.36047720909119,221011.6120646,502100,0,221011.6120646,0.6614000201225281,1.453892707824707,10000,242241.1875770092,0.8866991996765137,0.4188326299190521,0.7829399704933167,0.8525302410125732,50000 -21208.55673956871,42.48579788208008,221431.6432979107,503036,0,221431.6432979107,0.6614000201225281,1.453892707824707,10000,242708.8383128643,0.8915234208106995,0.4066730737686157,0.7829399704933167,0.8525302410125732,50000 -21253.096702337265,42.610806465148926,221851.79825234413,503987,0,221851.79825234413,0.6614000201225281,1.453892707824707,10000,243173.70783042908,0.8878710865974426,0.4177476763725281,0.7829399704933167,0.8525302410125732,50000 -21296.14944720268,42.73886179924011,222271.69676876068,504939,0,222271.69676876068,0.6614000201225281,1.453892707824707,10000,243636.8362257481,0.8858398199081421,0.4197298288345337,0.7829399704933167,0.8525302410125732,50000 -21340.503289461136,42.846314430236816,222691.6384379864,505895,0,222691.6384379864,0.6614000201225281,1.453892707824707,10000,244101.2883806229,0.8896679282188416,0.4172879457473755,0.7829399704933167,0.8525302410125732,50000 -21381.57668352127,42.97149038314819,223111.82583379743,506850,0,223111.82583379743,0.6614000201225281,1.453892707824707,10000,244562.72336554527,0.88720703125,0.417205810546875,0.7829399704933167,0.8525302410125732,50000 -21422.78080868721,43.89462304115296,223530.96260380745,507797,0,223530.96260380745,0.6614000201225281,1.453892707824707,10000,245024.03587937355,0.8874022960662842,0.4177261292934418,0.7829399704933167,0.8525302410125732,50000 -21464.441920757294,44.01762056350708,223951.0003323555,508734,0,223951.0003323555,0.6614000201225281,1.453892707824707,10000,245485.90597319603,0.8900390267372131,0.4131047427654266,0.7829399704933167,0.8525302410125732,50000 -21508.74162054062,44.14262199401856,224370.9559493065,509680,0,224370.9559493065,0.6614000201225281,1.453892707824707,10000,245950.3347005844,0.8871093392372131,0.4248417615890503,0.7829399704933167,0.8525302410125732,50000 -21554.337733268738,44.26745653152466,224790.7845821381,510633,0,224790.7845821381,0.6614000201225281,1.453892707824707,10000,246415.9335541725,0.8889452815055847,0.4144680798053741,0.7829399704933167,0.8525302410125732,50000 -21599.514624118805,44.40690755844116,225211.0207374096,511589,0,225211.0207374096,0.6614000201225281,1.453892707824707,10000,246881.53498005867,0.8864648342132568,0.4167560040950775,0.7829399704933167,0.8525302410125732,50000 -21637.92317771912,44.51779532432556,225630.9021072388,512540,0,225630.9021072388,0.6614000201225281,1.453892707824707,10000,247339.98562526703,0.8864648342132568,0.4207066595554352,0.7829399704933167,0.8525302410125732,50000 -21682.883960962296,44.65179920196533,226050.8605763912,513462,0,226050.8605763912,0.6614000201225281,1.453892707824707,10000,247805.08733654025,0.8897070288658142,0.414121150970459,0.7829399704933167,0.8525302410125732,50000 -21725.83594822884,44.77778220176697,226470.79879546163,514413,0,226470.79879546163,0.6614000201225281,1.453892707824707,10000,248268.15232920647,0.8862890601158142,0.4216505289077759,0.7829399704933167,0.8525302410125732,50000 -21767.78396916389,44.90311050415039,226890.87377142903,515363,0,226890.87377142903,0.6614000201225281,1.453892707824707,10000,248730.3489470482,0.8855664134025574,0.4250814616680145,0.7829399704933167,0.8525302410125732,50000 -21815.223826885223,45.10130214691162,227310.91848754883,516319,0,227310.91848754883,0.6614000201225281,1.453892707824707,10000,249198.0817565918,0.8897656202316284,0.4119044840335846,0.7829399704933167,0.8525302410125732,50000 -21854.596625089645,45.2050838470459,227731.0243735313,517273,0,227731.0243735313,0.6614000201225281,1.453892707824707,10000,249657.7138082981,0.888671875,0.4123574793338775,0.7829399704933167,0.8525302410125732,50000 -21899.41953110695,45.33416724205017,228151.06265687945,518217,0,228151.06265687945,0.6614000201225281,1.453892707824707,10000,250122.7528753281,0.8880273103713989,0.4175519347190857,0.7829399704933167,0.8525302410125732,50000 -21936.388018369675,45.43733549118042,228570.9589471817,519169,0,228570.9589471817,0.6614000201225281,1.453892707824707,10000,250579.7715086937,0.8868945240974426,0.4196770787239074,0.7829399704933167,0.8525302410125732,50000 -21979.60288333893,45.56658720970154,228990.81290125847,520121,0,228990.81290125847,0.6614000201225281,1.453892707824707,10000,251043.0187008381,0.8871679306030273,0.4226775765419006,0.7829399704933167,0.8525302410125732,50000 -22024.37559247017,45.68960404396057,229410.6971549988,521066,0,229410.6971549988,0.6614000201225281,1.453892707824707,10000,251507.8479943276,0.8896874785423279,0.4149737358093261,0.7829399704933167,0.8525302410125732,50000 -22069.55735874176,45.82239007949829,229830.59431529045,522018,0,229830.59431529045,0.6614000201225281,1.453892707824707,10000,251973.10938835144,0.8896093368530273,0.4115025997161865,0.7829399704933167,0.8525302410125732,50000 -22109.15285468101,45.93182587623596,230250.5108890533,522975,0,230250.5108890533,0.6614000201225281,1.453892707824707,10000,252432.78043317795,0.8873828053474426,0.4191846847534179,0.7829399704933167,0.8525302410125732,50000 -22160.45391464233,46.063761472702026,230670.47365617752,523921,0,230670.47365617752,0.6614000201225281,1.453892707824707,10000,252904.2251083851,0.8888671398162842,0.4110132753849029,0.7829399704933167,0.8525302410125732,50000 -22202.33893918991,46.168038845062256,231090.6543712616,524878,0,231090.6543712616,0.6614000201225281,1.453892707824707,10000,253366.4442384243,0.8877148032188416,0.4167254567146301,0.7829399704933167,0.8525302410125732,50000 -22245.91069793701,46.29690170288086,231510.7819397449,525826,0,231510.7819397449,0.6614000201225281,1.453892707824707,10000,253830.32165122032,0.8874022960662842,0.4148069620132446,0.7829399704933167,0.8525302410125732,50000 -22294.248700141907,46.49507021903992,231930.7475101948,526778,0,231930.7475101948,0.6614000201225281,1.453892707824707,10000,254298.8728711605,0.8897070288658142,0.4127437472343445,0.7829399704933167,0.8525302410125732,50000 -22334.084921598434,46.60599493980408,232350.62709617615,527733,0,232350.62709617615,0.6614000201225281,1.453892707824707,10000,254758.7492594719,0.8895898461341858,0.4154190421104431,0.7829399704933167,0.8525302410125732,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 55748d604..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5839 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.38732913,6.907756,,,,,,,,,,,,,, -1,,,0.0010351561941206,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.01926302909851,82.50370907783508,42.01926302909851,40.48432898521423,0.0,0.0 -100,0.47300062,6.8772135,,,,,,,,,,,,,, -200,0.6493789,6.746748,,,,,,,,,,,,,, -300,0.89038146,6.6332474,,,,,,,,,,,,,, -400,0.9373463,6.5644765,,,,,,,,,,,,,, -500,1.0129205,6.465884,,,,,,,,,,,,,, -600,1.0467464,6.3281612,,,,,,,,,,,,,, -700,1.3808005,6.7671847,,,,,,,,,,,,,, -800,1.6040305,6.182112,,,,,,,,,,,,,, -900,1.0841154,6.1717124,,,,,,,,,,,,,, -926,,,0.034375000745058,5.894576072692871,0.032999999821186,5.924591541290283,50000.0,0.0266000013798475,6.037107467651367,10000.0,462.3217113018036,523.9356322288513,462.3217113018036,61.537635803222656,0.0274052619934082,0.0 -1000,1.1204011,6.1545305,,,,,,,,,,,,,, -1100,1.2675164,6.1373453,,,,,,,,,,,,,, -1200,1.4749066,5.8729615,,,,,,,,,,,,,, -1300,1.0675726,5.87135,,,,,,,,,,,,,, -1400,1.103511,5.9173617,,,,,,,,,,,,,, -1500,1.1751494,5.804575,,,,,,,,,,,,,, -1600,1.120998,5.8388886,,,,,,,,,,,,,, -1700,1.2167246,6.668438,,,,,,,,,,,,,, -1800,1.0309289,5.6746106,,,,,,,,,,,,,, -1900,0.93020195,5.7745743,,,,,,,,,,,,,, -1901,,,0.0853906199336052,5.1030755043029785,0.0809599980711937,5.148300647735596,50000.0,0.0599000044167041,5.397579193115234,10000.0,882.624751329422,964.7744688987732,882.624751329422,81.992928981781,0.0549311637878418,0.0 -2000,1.0103729,5.6342154,,,,,,,,,,,,,, -2100,1.5933118,5.6329975,,,,,,,,,,,,,, -2200,0.8991654,6.5579095,,,,,,,,,,,,,, -2300,1.5010616,5.502831,,,,,,,,,,,,,, -2400,1.045278,5.291459,,,,,,,,,,,,,, -2500,1.0281262,5.344346,,,,,,,,,,,,,, -2600,1.4424853,5.356646,,,,,,,,,,,,,, -2700,1.7104876,6.636169,,,,,,,,,,,,,, -2800,1.2406664,5.7381644,,,,,,,,,,,,,, -2877,,,0.1455273479223251,4.522874355316162,0.1358200013637542,4.604587554931641,50000.0,0.1026000082492828,4.934473514556885,10000.0,1302.5669553279877,1406.0095813274384,1302.5669553279877,103.20603942871094,0.0836653709411621,0.0 -2900,0.9527946,5.2501335,,,,,,,,,,,,,, -3000,0.8463239,6.2479734,,,,,,,,,,,,,, -3100,0.9457314,5.073087,,,,,,,,,,,,,, -3200,0.8480853,5.54012,,,,,,,,,,,,,, -3300,0.8887188,5.208357,,,,,,,,,,,,,, -3400,1.0653993,5.026526,,,,,,,,,,,,,, -3500,1.0249418,4.9477615,,,,,,,,,,,,,, -3600,1.0827569,4.9442477,,,,,,,,,,,,,, -3700,0.9319472,5.9352016,,,,,,,,,,,,,, -3800,0.8555127,4.8076963,,,,,,,,,,,,,, -3854,,,0.2039648443460464,4.073319435119629,0.1905999928712844,4.152755260467529,50000.0,0.1454000025987625,4.553320407867432,10000.0,1722.5828483104706,1847.055628299713,1722.5828483104706,124.155011177063,0.1127946376800537,0.0 -3900,0.95821565,6.358395,,,,,,,,,,,,,, -4000,1.0012063,4.769633,,,,,,,,,,,,,, -4100,0.93894565,5.2930474,,,,,,,,,,,,,, -4200,0.8639516,4.628836,,,,,,,,,,,,,, -4300,0.7900479,4.976015,,,,,,,,,,,,,, -4400,1.0757428,4.7719965,,,,,,,,,,,,,, -4500,0.8694384,4.396705,,,,,,,,,,,,,, -4600,0.93970114,4.6779246,,,,,,,,,,,,,, -4700,0.96444297,4.2273645,,,,,,,,,,,,,, -4800,1.0367763,4.49852,,,,,,,,,,,,,, -4828,,,0.2770312428474426,3.557164430618286,0.2561799883842468,3.682000160217285,50000.0,0.195700004696846,4.1686320304870605,10000.0,2142.8305237293243,2290.86334848404,2142.8305237293243,147.63187551498413,0.1450769901275634,0.0 -4900,0.9923043,4.6165714,,,,,,,,,,,,,, -5000,0.926102,4.3079605,,,,,,,,,,,,,, -5100,0.76382405,5.3816094,,,,,,,,,,,,,, -5200,0.93281245,6.2572374,,,,,,,,,,,,,, -5300,0.99300075,4.2889214,,,,,,,,,,,,,, -5400,1.1943803,4.2035637,,,,,,,,,,,,,, -5500,1.0449545,5.251244,,,,,,,,,,,,,, -5600,0.92376465,4.275272,,,,,,,,,,,,,, -5700,0.6923041,4.8841276,,,,,,,,,,,,,, -5795,,,0.3194335997104645,3.239790916442871,0.2967000007629394,3.379960775375366,50000.0,0.2262000143527984,3.907768249511719,10000.0,2563.035956144333,2732.498987197876,2563.035956144333,168.98122477531433,0.1753633022308349,0.0 -5800,1.0191642,4.3300877,,,,,,,,,,,,,, -5900,0.90182024,5.268446,,,,,,,,,,,,,, -6000,0.9379677,4.163128,,,,,,,,,,,,,, -6100,1.0444008,4.158493,,,,,,,,,,,,,, -6200,1.0443726,4.261236,,,,,,,,,,,,,, -6300,0.76632434,5.347913,,,,,,,,,,,,,, -6400,0.7600005,4.059296,,,,,,,,,,,,,, -6500,0.9524235,4.144712,,,,,,,,,,,,,, -6600,0.8747229,4.637163,,,,,,,,,,,,,, -6700,0.95043296,3.8229747,,,,,,,,,,,,,, -6764,,,0.36865234375,2.939857721328736,0.3310999870300293,3.159950494766236,50000.0,0.2533000111579895,3.719790458679199,10000.0,2983.3539803028107,3176.9725415706635,2983.3539803028107,193.0559239387512,0.2038640975952148,0.0 -6800,0.74432504,4.5047574,,,,,,,,,,,,,, -6900,0.5869833,5.1120367,,,,,,,,,,,,,, -7000,0.7723554,4.1902556,,,,,,,,,,,,,, -7100,0.82849497,4.335855,,,,,,,,,,,,,, -7200,1.006207,3.887609,,,,,,,,,,,,,, -7300,0.88783425,3.8646233,,,,,,,,,,,,,, -7400,0.74711794,5.048344,,,,,,,,,,,,,, -7500,0.6163855,5.8100834,,,,,,,,,,,,,, -7600,0.71005297,5.6516085,,,,,,,,,,,,,, -7700,0.7878963,4.1264024,,,,,,,,,,,,,, -7732,,,0.3927929699420929,2.838428497314453,0.3671599924564361,2.9767770767211914,50000.0,0.2832000255584717,3.5490877628326416,10000.0,3403.5411076545715,3622.407206058502,3403.5411076545715,218.21373295784,0.2344825267791748,0.0 -7800,1.1095774,3.7308614,,,,,,,,,,,,,, -7900,1.0062449,3.9019542,,,,,,,,,,,,,, -8000,0.95815015,3.730879,,,,,,,,,,,,,, -8100,0.8742973,3.8154216,,,,,,,,,,,,,, -8200,0.95988804,3.9628937,,,,,,,,,,,,,, -8300,0.79116637,5.5182066,,,,,,,,,,,,,, -8400,0.665524,5.4917564,,,,,,,,,,,,,, -8500,1.0341631,3.7566354,,,,,,,,,,,,,, -8600,0.75199556,4.367637,,,,,,,,,,,,,, -8700,0.87849206,3.8237047,,,,,,,,,,,,,, -8701,,,0.41845703125,2.6635642051696777,0.389739990234375,2.8288674354553223,50000.0,0.2953000068664551,3.4372708797454834,10000.0,3823.8924305439,4074.4161858558655,3823.8924305439,249.78102779388428,0.2738537788391113,0.0 -8800,0.97078824,3.5953248,,,,,,,,,,,,,, -8900,0.9243159,3.6516323,,,,,,,,,,,,,, -9000,0.7215835,4.5269284,,,,,,,,,,,,,, -9100,1.0118124,3.6449976,,,,,,,,,,,,,, -9200,0.8774474,3.5675929,,,,,,,,,,,,,, -9300,1.0146694,3.5841513,,,,,,,,,,,,,, -9400,0.9132383,3.4917521,,,,,,,,,,,,,, -9500,0.8740046,4.6624384,,,,,,,,,,,,,, -9600,1.0102301,3.5532475,,,,,,,,,,,,,, -9674,,,0.4361914098262787,2.5769314765930176,0.3981599807739258,2.77098035812378,50000.0,0.3086000084877014,3.3669235706329346,10000.0,4244.15566444397,4520.442771434784,4244.15566444397,275.4603519439697,0.3063430786132812,0.0 -9700,0.77162147,4.5167513,,,,,,,,,,,,,, -9800,0.95207566,3.5226367,,,,,,,,,,,,,, -9900,0.8269004,4.660453,,,,,,,,,,,,,, -10000,0.9076865,5.8728952,,,,,,,,,,,,,, -10100,0.8973094,3.3835115,,,,,,,,,,,,,, -10200,1.0959779,3.6483777,,,,,,,,,,,,,, -10300,1.0119926,3.398452,,,,,,,,,,,,,, -10400,0.99152917,3.4958882,,,,,,,,,,,,,, -10500,1.02256,3.4773097,,,,,,,,,,,,,, -10600,0.6210804,5.550068,,,,,,,,,,,,,, -10642,,,0.4862890541553497,2.301498889923096,0.4239600002765655,2.618833065032959,50000.0,0.3243000209331512,3.2398502826690674,10000.0,4664.272991895676,4966.070648193359,4664.272991895676,300.88571643829346,0.3400704860687256,0.0 -10700,1.0093993,3.2830877,,,,,,,,,,,,,, -10800,0.82846624,4.503332,,,,,,,,,,,,,, -10900,0.9389571,3.4363587,,,,,,,,,,,,,, -11000,0.8738347,3.8205998,,,,,,,,,,,,,, -11100,0.7356636,5.783767,,,,,,,,,,,,,, -11200,1.025262,3.4505658,,,,,,,,,,,,,, -11300,0.86626506,5.866368,,,,,,,,,,,,,, -11400,0.9530553,3.282422,,,,,,,,,,,,,, -11500,1.0387229,3.4954367,,,,,,,,,,,,,, -11600,1.1165999,3.3433518,,,,,,,,,,,,,, -11610,,,0.4810546636581421,2.291144609451294,0.4471199810504913,2.464445114135742,50000.0,0.3436000049114227,3.111103773117065,10000.0,5084.614083766937,5418.4511461257935,5084.614083766937,332.84100675582886,0.3719866275787353,0.0 -11700,0.82971895,5.738659,,,,,,,,,,,,,, -11800,0.76122916,5.1583304,,,,,,,,,,,,,, -11900,0.9468449,3.4924846,,,,,,,,,,,,,, -12000,0.7489407,4.440767,,,,,,,,,,,,,, -12100,1.1963345,3.2238073,,,,,,,,,,,,,, -12200,1.0776428,3.347383,,,,,,,,,,,,,, -12300,1.004993,3.3057206,,,,,,,,,,,,,, -12400,1.046233,3.300161,,,,,,,,,,,,,, -12500,0.97526354,3.2339988,,,,,,,,,,,,,, -12579,,,0.4947851598262787,2.2269725799560547,0.4521400034427643,2.441700935363769,50000.0,0.3518000245094299,3.076727867126465,10000.0,5504.788379192352,5870.032203912735,5504.788379192352,364.16368222236633,0.4047729969024658,0.0 -12600,0.8820997,3.5979285,,,,,,,,,,,,,, -12700,1.1068295,3.27406,,,,,,,,,,,,,, -12800,1.0332698,3.266356,,,,,,,,,,,,,, -12900,1.0970556,3.2386212,,,,,,,,,,,,,, -13000,0.8347224,3.9472976,,,,,,,,,,,,,, -13100,0.9141038,3.3863614,,,,,,,,,,,,,, -13200,1.1181418,3.0698602,,,,,,,,,,,,,, -13300,0.80592006,4.5620084,,,,,,,,,,,,,, -13400,1.0113977,5.6561656,,,,,,,,,,,,,, -13500,0.99725085,3.250574,,,,,,,,,,,,,, -13548,,,0.5150390267372131,2.110365390777588,0.4720799922943115,2.342987060546875,50000.0,0.3651000261306762,2.996798276901245,10000.0,5924.957451581955,6316.37331199646,5924.957451581955,390.2532448768616,0.436398983001709,0.0 -13600,1.0684938,3.3998876,,,,,,,,,,,,,, -13700,0.72181594,5.5430837,,,,,,,,,,,,,, -13800,1.0860895,3.1219788,,,,,,,,,,,,,, -13900,1.04549,3.2549567,,,,,,,,,,,,,, -14000,1.008785,3.813355,,,,,,,,,,,,,, -14100,1.0812409,3.0994177,,,,,,,,,,,,,, -14200,1.2206458,5.5714374,,,,,,,,,,,,,, -14300,0.8833195,4.8650055,,,,,,,,,,,,,, -14400,1.0421488,3.1845484,,,,,,,,,,,,,, -14500,1.0951988,3.1160333,,,,,,,,,,,,,, -14515,,,0.5230273604393005,2.1068804264068604,0.4803399741649627,2.3169755935668945,50000.0,0.3729000091552734,2.967996597290039,10000.0,6345.0926015377045,6765.965216636658,6345.0926015377045,419.6220715045929,0.4737794399261474,0.0 -14600,1.0067141,3.0339782,,,,,,,,,,,,,, -14700,1.1041883,3.1300561,,,,,,,,,,,,,, -14800,0.8780669,5.3987007,,,,,,,,,,,,,, -14900,1.1191933,2.9411492,,,,,,,,,,,,,, -15000,1.1875335,3.1839147,,,,,,,,,,,,,, -15100,1.1201612,3.0714338,,,,,,,,,,,,,, -15200,1.1124109,3.1560273,,,,,,,,,,,,,, -15300,0.9566744,5.591735,,,,,,,,,,,,,, -15400,1.0517961,3.0152917,,,,,,,,,,,,,, -15480,,,0.5296484231948853,2.038872718811035,0.4894999861717224,2.242551803588867,50000.0,0.3804000318050384,2.88616681098938,10000.0,6765.137340545654,7217.566586494446,6765.137340545654,451.0975050926209,0.50453782081604,0.0 -15500,1.375148,3.0252352,,,,,,,,,,,,,, -15600,1.0489335,3.0601048,,,,,,,,,,,,,, -15700,0.892565,4.841108,,,,,,,,,,,,,, -15800,1.039224,3.0728784,,,,,,,,,,,,,, -15900,0.8162554,4.502153,,,,,,,,,,,,,, -16000,1.0141537,5.3831897,,,,,,,,,,,,,, -16100,0.99443305,3.0782049,,,,,,,,,,,,,, -16200,0.8269745,5.4382763,,,,,,,,,,,,,, -16300,0.8358973,5.5771637,,,,,,,,,,,,,, -16400,1.0868633,3.0100708,,,,,,,,,,,,,, -16446,,,0.5384374856948853,1.977601408958435,0.502299964427948,2.171231746673584,50000.0,0.3939000070095062,2.830028295516968,10000.0,7185.215455293655,7672.523951292038,7185.215455293655,485.8946087360382,0.53570556640625,0.0 -16500,1.1681263,3.4547985,,,,,,,,,,,,,, -16600,1.0639682,3.135181,,,,,,,,,,,,,, -16700,1.0141786,3.2973273,,,,,,,,,,,,,, -16800,1.0033937,3.2863872,,,,,,,,,,,,,, -16900,1.1222677,2.940519,,,,,,,,,,,,,, -17000,1.0853935,2.8361714,,,,,,,,,,,,,, -17100,0.99810725,4.178808,,,,,,,,,,,,,, -17200,1.2280239,3.126961,,,,,,,,,,,,,, -17300,1.0890857,3.358355,,,,,,,,,,,,,, -17400,1.4474741,3.0789092,,,,,,,,,,,,,, -17412,,,0.5560351610183716,1.9124348163604736,0.507420003414154,2.158041477203369,50000.0,0.3972000181674957,2.81754994392395,10000.0,7605.359039306641,8127.765516996384,7605.359039306641,520.9082908630371,0.5695958137512207,0.0 -17500,1.0736667,2.9615507,,,,,,,,,,,,,, -17600,1.1318488,2.9233184,,,,,,,,,,,,,, -17700,1.0670183,3.0659232,,,,,,,,,,,,,, -17800,1.0508615,3.0076509,,,,,,,,,,,,,, -17900,1.1292502,2.9990504,,,,,,,,,,,,,, -18000,1.1177554,3.1190364,,,,,,,,,,,,,, -18100,0.9476169,3.9383526,,,,,,,,,,,,,, -18200,1.1590462,2.92942,,,,,,,,,,,,,, -18300,0.9934928,3.0570323,,,,,,,,,,,,,, -18373,,,0.5530468821525574,1.9021896123886108,0.5165199637413025,2.088361978530884,50000.0,0.4043000340461731,2.763168811798096,10000.0,8025.422616481781,8582.154694318771,8025.422616481781,555.1550786495209,0.5985524654388428,0.0 -18400,1.2319328,2.916369,,,,,,,,,,,,,, -18500,0.91252995,4.776577,,,,,,,,,,,,,, -18600,1.0831052,3.010851,,,,,,,,,,,,,, -18700,1.1477897,2.9074275,,,,,,,,,,,,,, -18800,1.1353452,3.0268102,,,,,,,,,,,,,, -18900,1.066961,2.9064445,,,,,,,,,,,,,, -19000,1.0260692,3.2523022,,,,,,,,,,,,,, -19100,1.0621811,2.969618,,,,,,,,,,,,,, -19200,1.1163853,2.7099257,,,,,,,,,,,,,, -19300,1.146397,3.0273256,,,,,,,,,,,,,, -19330,,,0.5592187643051147,1.8991457223892207,0.5181199908256531,2.118349313735962,50000.0,0.3993000090122223,2.775531530380249,10000.0,8445.353466033936,9039.04480600357,8445.353466033936,592.0320270061493,0.6309165954589844,0.0 -19400,1.160056,2.9832983,,,,,,,,,,,,,, -19500,1.055325,3.0536456,,,,,,,,,,,,,, -19600,0.9412661,3.8997295,,,,,,,,,,,,,, -19700,1.0187328,2.9110992,,,,,,,,,,,,,, -19800,0.9388886,4.9281673,,,,,,,,,,,,,, -19900,1.2073971,2.9202766,,,,,,,,,,,,,, -20000,1.0912273,3.6707551,,,,,,,,,,,,,, -20100,0.93055415,5.2441325,,,,,,,,,,,,,, -20200,0.88316983,5.459878,,,,,,,,,,,,,, -20286,,,0.5762304663658142,1.809704303741455,0.5256400108337402,2.042741775512696,50000.0,0.4166000187397003,2.704303026199341,10000.0,8865.62778878212,9496.333614110948,8865.62778878212,628.9637792110443,0.6646013259887695,0.0 -20300,0.9541621,4.7664075,,,,,,,,,,,,,, -20400,1.0431799,3.042375,,,,,,,,,,,,,, -20500,1.0685784,2.9402092,,,,,,,,,,,,,, -20600,1.0068913,3.1566474,,,,,,,,,,,,,, -20700,0.8314804,4.820638,,,,,,,,,,,,,, -20800,0.9847646,3.4005382,,,,,,,,,,,,,, -20900,1.0155531,3.0045946,,,,,,,,,,,,,, -21000,1.1505959,2.8803246,,,,,,,,,,,,,, -21100,1.2412547,3.1745725,,,,,,,,,,,,,, -21200,0.8926757,4.0621223,,,,,,,,,,,,,, -21245,,,0.60107421875,1.7186756134033203,0.5374400019645691,2.0195109844207764,50000.0,0.4236000180244446,2.6766977310180664,10000.0,9285.970098257065,9952.100473880768,9285.970098257065,664.3080334663391,0.694582462310791,0.0 -21300,1.0804659,2.8486493,,,,,,,,,,,,,, -21400,1.1783704,2.8118937,,,,,,,,,,,,,, -21500,1.1817962,2.9046147,,,,,,,,,,,,,, -21600,1.3008204,2.8397005,,,,,,,,,,,,,, -21700,1.1746949,3.0401049,,,,,,,,,,,,,, -21800,1.0548499,3.943305,,,,,,,,,,,,,, -21900,0.93059015,5.394566,,,,,,,,,,,,,, -22000,0.92349863,5.3108377,,,,,,,,,,,,,, -22100,1.1889658,2.8066623,,,,,,,,,,,,,, -22200,,,0.5738281011581421,1.7905476093292236,0.5320799946784973,1.9965107440948489,50000.0,0.4204000234603882,2.663136005401612,10000.0,9706.001689434052,10409.156054019928,9706.001689434052,701.2506122589111,0.7266604900360107,0.0 -22200,1.1560571,2.8074026,,,,,,,,,,,,,, -22300,1.2435902,2.72107,,,,,,,,,,,,,, -22400,1.0683386,2.897757,,,,,,,,,,,,,, -22500,1.0631523,3.0156062,,,,,,,,,,,,,, -22600,1.1115791,2.7701325,,,,,,,,,,,,,, -22700,1.2076827,2.8165355,,,,,,,,,,,,,, -22800,1.0862952,5.4010115,,,,,,,,,,,,,, -22900,1.1189821,2.9563458,,,,,,,,,,,,,, -23000,1.1085646,2.8093615,,,,,,,,,,,,,, -23100,1.1403883,2.728451,,,,,,,,,,,,,, -23162,,,0.5878320336341858,1.741782546043396,0.5437399744987488,1.9589741230010984,50000.0,0.4255000054836273,2.6321370601654053,10000.0,10126.287100076675,10863.024575471878,10126.287100076675,734.7533724308014,0.7566776275634766,0.0 -23200,1.1162505,2.698742,,,,,,,,,,,,,, -23300,1.3328942,2.8711061,,,,,,,,,,,,,, -23400,1.1753666,2.678231,,,,,,,,,,,,,, -23500,1.1733686,2.9736319,,,,,,,,,,,,,, -23600,1.0387477,3.3102977,,,,,,,,,,,,,, -23700,1.1324464,2.8006964,,,,,,,,,,,,,, -23800,1.066493,2.8768406,,,,,,,,,,,,,, -23900,0.9381661,4.8282576,,,,,,,,,,,,,, -24000,0.9954037,4.942324,,,,,,,,,,,,,, -24100,1.1154069,3.0484378,,,,,,,,,,,,,, -24122,,,0.5895312428474426,1.806769847869873,0.5410400032997131,2.0305187702178955,50000.0,0.425100028514862,2.689875602722168,10000.0,10546.216496706007,11314.93691921234,10546.216496706007,766.6508240699768,0.7921888828277588,0.0 -24200,0.8353648,5.2885695,,,,,,,,,,,,,, -24300,1.2091615,2.7646296,,,,,,,,,,,,,, -24400,0.92762226,3.379088,,,,,,,,,,,,,, -24500,1.0133395,5.3330717,,,,,,,,,,,,,, -24600,0.91001135,5.2577257,,,,,,,,,,,,,, -24700,1.1838993,4.40609,,,,,,,,,,,,,, -24800,1.3439035,2.7421417,,,,,,,,,,,,,, -24900,1.0519104,3.6500676,,,,,,,,,,,,,, -25000,1.2186595,2.7028205,,,,,,,,,,,,,, -25079,,,0.6007226705551147,1.691822528839111,0.5526800155639648,1.9392848014831543,50000.0,0.4388000071048736,2.607574462890625,10000.0,10966.567569971085,11767.246729373932,10966.567569971085,798.524719953537,0.8267090320587158,0.0 -25100,0.91847044,5.345995,,,,,,,,,,,,,, -25200,0.9346249,5.0443506,,,,,,,,,,,,,, -25300,1.0697203,5.2080526,,,,,,,,,,,,,, -25400,1.1289641,2.567557,,,,,,,,,,,,,, -25500,0.95723575,5.2058387,,,,,,,,,,,,,, -25600,0.9103539,5.2788634,,,,,,,,,,,,,, -25700,1.2897488,2.6028938,,,,,,,,,,,,,, -25800,1.4919183,5.1997085,,,,,,,,,,,,,, -25900,1.188257,2.6527243,,,,,,,,,,,,,, -26000,1.0962496,4.817388,,,,,,,,,,,,,, -26027,,,0.5960351228713989,1.6998733282089231,0.5539199709892273,1.9093241691589355,50000.0,0.4345000088214874,2.5865228176116943,10000.0,11386.746564149857,12221.09743642807,11386.746564149857,832.1123118400574,0.8609256744384766,0.0 -26100,1.05989,5.3375587,,,,,,,,,,,,,, -26200,1.1306491,3.1049087,,,,,,,,,,,,,, -26300,1.1423221,2.7651079,,,,,,,,,,,,,, -26400,1.138386,2.7446032,,,,,,,,,,,,,, -26500,1.2569276,2.702849,,,,,,,,,,,,,, -26600,1.2246972,2.6588848,,,,,,,,,,,,,, -26700,1.2675495,2.7821155,,,,,,,,,,,,,, -26800,1.1990851,3.1649213,,,,,,,,,,,,,, -26900,1.1022446,2.7403922,,,,,,,,,,,,,, -26975,,,0.6092773079872131,1.6344122886657717,0.5594599843025208,1.863364815711975,50000.0,0.4431000351905823,2.551481008529663,10000.0,11806.727120399475,12674.887342214584,11806.727120399475,865.8403429985046,0.8936762809753418,0.0 -27000,1.110292,2.652313,,,,,,,,,,,,,, -27100,1.1528653,2.6016397,,,,,,,,,,,,,, -27200,1.0911967,2.687332,,,,,,,,,,,,,, -27300,0.85091543,5.244171,,,,,,,,,,,,,, -27400,1.091271,4.1428204,,,,,,,,,,,,,, -27500,0.96637595,3.6886916,,,,,,,,,,,,,, -27600,1.1159526,2.860858,,,,,,,,,,,,,, -27700,0.94446754,4.488002,,,,,,,,,,,,,, -27800,0.86681247,4.1481495,,,,,,,,,,,,,, -27900,0.90235555,4.641292,,,,,,,,,,,,,, -27935,,,0.6161523461341858,1.614219307899475,0.5626599788665771,1.8807896375656128,50000.0,0.4442000091075897,2.544827938079834,10000.0,12227.05313396454,13129.989423274994,12227.05313396454,900.530428647995,0.9298102855682372,0.0 -28000,1.1616563,2.7751894,,,,,,,,,,,,,, -28100,1.2516004,3.4991643,,,,,,,,,,,,,, -28200,1.3703846,2.6536036,,,,,,,,,,,,,, -28300,1.1498715,2.8643043,,,,,,,,,,,,,, -28400,1.0761104,5.235845,,,,,,,,,,,,,, -28500,0.9990469,4.9205775,,,,,,,,,,,,,, -28600,1.1338595,2.9323537,,,,,,,,,,,,,, -28700,1.1671745,2.6679833,,,,,,,,,,,,,, -28800,1.1471672,2.813876,,,,,,,,,,,,,, -28890,,,0.6077929735183716,1.6699862480163574,0.5683799982070923,1.8518006801605225,50000.0,0.445000022649765,2.535609483718872,10000.0,12647.381070375444,13583.913120269775,12647.381070375444,934.0426671504974,0.9640679359436036,0.0 -28900,1.2331932,2.7067285,,,,,,,,,,,,,, -29000,1.0855042,4.595937,,,,,,,,,,,,,, -29100,1.0464516,4.307212,,,,,,,,,,,,,, -29200,1.1825968,3.2050006,,,,,,,,,,,,,, -29300,1.0212034,3.3590302,,,,,,,,,,,,,, -29400,1.131137,2.8602693,,,,,,,,,,,,,, -29500,1.1043342,2.603988,,,,,,,,,,,,,, -29600,1.2005391,2.5667722,,,,,,,,,,,,,, -29700,1.1310649,2.971027,,,,,,,,,,,,,, -29800,1.0709058,5.035444,,,,,,,,,,,,,, -29846,,,0.6141015291213989,1.616439938545227,0.5706599950790405,1.8262709379196167,50000.0,0.4506000280380249,2.4998276233673096,10000.0,13067.404882907867,14037.70155787468,13067.404882907867,967.7215456962584,0.9999191761016846,0.0 -29900,1.1607943,2.7978354,,,,,,,,,,,,,, -30000,1.1474618,2.6874664,,,,,,,,,,,,,, -30100,1.1570104,2.5173674,,,,,,,,,,,,,, -30200,1.4604563,2.740901,,,,,,,,,,,,,, -30300,1.1640868,2.8955832,,,,,,,,,,,,,, -30400,0.95278114,4.346856,,,,,,,,,,,,,, -30500,1.2647572,2.9598236,,,,,,,,,,,,,, -30600,1.2642102,2.805042,,,,,,,,,,,,,, -30700,1.1640316,4.7345877,,,,,,,,,,,,,, -30800,1.1563181,2.7517283,,,,,,,,,,,,,, -30801,,,0.6283984184265137,1.5483167171478271,0.5783599615097046,1.7855757474899292,50000.0,0.458400011062622,2.4593968391418457,10000.0,13487.583720207214,14490.164947509766,13487.583720207214,999.9251253604888,1.0317416191101074,0.0 -30900,1.1327815,2.777274,,,,,,,,,,,,,, -31000,1.2864784,2.720458,,,,,,,,,,,,,, -31100,1.2457663,2.7065387,,,,,,,,,,,,,, -31200,1.1772424,2.6654215,,,,,,,,,,,,,, -31300,1.3356918,2.7678885,,,,,,,,,,,,,, -31400,1.2954769,2.5622313,,,,,,,,,,,,,, -31500,0.92953277,5.052957,,,,,,,,,,,,,, -31600,1.1094892,2.529438,,,,,,,,,,,,,, -31700,1.0061526,4.347381,,,,,,,,,,,,,, -31758,,,0.6522851586341858,1.4674758911132812,0.5711199641227722,1.8236415386199951,50000.0,0.4554000198841095,2.479957342147827,10000.0,13907.73057460785,14944.324080467224,13907.73057460785,1033.848325252533,1.0704412460327148,0.0 -31800,1.2856051,2.7117338,,,,,,,,,,,,,, -31900,1.0408102,3.1623163,,,,,,,,,,,,,, -32000,1.2079297,2.6582975,,,,,,,,,,,,,, -32100,1.0726501,4.9607244,,,,,,,,,,,,,, -32200,0.90259635,5.1624084,,,,,,,,,,,,,, -32300,0.84982055,4.776862,,,,,,,,,,,,,, -32400,1.1064489,2.5183446,,,,,,,,,,,,,, -32500,1.2366828,2.6175323,,,,,,,,,,,,,, -32600,0.9288892,4.171254,,,,,,,,,,,,,, -32700,1.0652213,4.174108,,,,,,,,,,,,,, -32715,,,0.6245507597923279,1.5911355018615725,0.5801599621772766,1.807516932487488,50000.0,0.4630000293254852,2.479883909225464,10000.0,14327.77366900444,15398.03229379654,14327.77366900444,1067.429335355759,1.102602243423462,0.0 -32800,1.0866799,3.3862772,,,,,,,,,,,,,, -32900,1.1687678,2.6741428,,,,,,,,,,,,,, -33000,1.0720603,2.8080597,,,,,,,,,,,,,, -33100,1.2411615,2.6987355,,,,,,,,,,,,,, -33200,1.2274301,2.4614522,,,,,,,,,,,,,, -33300,1.101981,2.441578,,,,,,,,,,,,,, -33400,1.2525158,2.5492554,,,,,,,,,,,,,, -33500,1.2256513,2.7155795,,,,,,,,,,,,,, -33600,1.0224526,2.9937062,,,,,,,,,,,,,, -33670,,,0.6285351514816284,1.5775716304779053,0.5806999802589417,1.7970975637435913,50000.0,0.4677000343799591,2.4360461235046387,10000.0,14747.711050987244,15848.411272764206,14747.711050987244,1097.7876312732697,1.136333703994751,0.0 -33700,1.155207,2.430864,,,,,,,,,,,,,, -33800,1.0432063,3.3087459,,,,,,,,,,,,,, -33900,0.9229368,5.0462747,,,,,,,,,,,,,, -34000,1.1429074,2.5567183,,,,,,,,,,,,,, -34100,0.9439118,4.858121,,,,,,,,,,,,,, -34200,0.91204107,4.85149,,,,,,,,,,,,,, -34300,0.9837717,4.496055,,,,,,,,,,,,,, -34400,1.0845398,3.6178396,,,,,,,,,,,,,, -34500,1.0091602,4.983824,,,,,,,,,,,,,, -34600,1.2818466,2.5947917,,,,,,,,,,,,,, -34622,,,0.639355480670929,1.4995332956314087,0.5854399800300598,1.7577393054962158,50000.0,0.4688000082969665,2.4291539192199707,10000.0,15167.88524198532,16302.25860452652,15167.88524198532,1131.367756843567,1.1800377368927002,0.0 -34700,1.1826925,2.9835577,,,,,,,,,,,,,, -34800,1.2412721,2.4919472,,,,,,,,,,,,,, -34900,1.1770921,3.7394497,,,,,,,,,,,,,, -35000,1.0791216,3.156749,,,,,,,,,,,,,, -35100,1.3964224,2.4621692,,,,,,,,,,,,,, -35200,1.1816034,2.586779,,,,,,,,,,,,,, -35300,1.2298385,2.478469,,,,,,,,,,,,,, -35400,1.2079996,2.5179014,,,,,,,,,,,,,, -35500,1.0428367,3.1499543,,,,,,,,,,,,,, -35577,,,0.6327343583106995,1.5223309993743896,0.5879799723625183,1.7351510524749756,50000.0,0.4697000086307525,2.40012788772583,10000.0,15588.195026397703,16756.752949476242,15588.195026397703,1165.467206954956,1.21533465385437,0.0 -35600,1.1458842,2.6119046,,,,,,,,,,,,,, -35700,1.3767388,2.5624363,,,,,,,,,,,,,, -35800,1.0416591,3.9611998,,,,,,,,,,,,,, -35900,1.297656,2.6189914,,,,,,,,,,,,,, -36000,1.007239,2.973907,,,,,,,,,,,,,, -36100,1.1723483,2.6190891,,,,,,,,,,,,,, -36200,1.1330421,2.5342224,,,,,,,,,,,,,, -36300,1.191844,2.6499746,,,,,,,,,,,,,, -36400,1.1912484,2.5334053,,,,,,,,,,,,,, -36500,1.1887126,2.6584604,,,,,,,,,,,,,, -36534,,,0.6396093368530273,1.5190812349319458,0.5896199941635132,1.7429559230804443,50000.0,0.4677000343799591,2.421023368835449,10000.0,16008.226280927658,17209.62438249588,16008.226280927658,1198.2166216373444,1.2557530403137207,0.0 -36600,1.2206215,2.7206595,,,,,,,,,,,,,, -36700,1.1989413,2.6030946,,,,,,,,,,,,,, -36800,1.2150078,2.5131428,,,,,,,,,,,,,, -36900,1.2990539,2.7730756,,,,,,,,,,,,,, -37000,1.1979905,2.6064286,,,,,,,,,,,,,, -37100,1.157815,5.1808434,,,,,,,,,,,,,, -37200,1.0496266,3.1512532,,,,,,,,,,,,,, -37300,1.2408489,2.584663,,,,,,,,,,,,,, -37400,1.2782841,2.5493128,,,,,,,,,,,,,, -37483,,,0.6413280963897705,1.495647668838501,0.5907399654388428,1.7329176664352417,50000.0,0.4678000211715698,2.402662515640259,10000.0,16428.334725379944,17661.970279455185,16428.334725379944,1230.3696205615995,1.2913818359375,0.0 -37500,1.3626344,2.5270958,,,,,,,,,,,,,, -37600,1.0190668,5.045201,,,,,,,,,,,,,, -37700,0.9672668,3.8448308,,,,,,,,,,,,,, -37800,1.0856489,5.157548,,,,,,,,,,,,,, -37900,1.1691027,2.5493126,,,,,,,,,,,,,, -38000,1.2213879,3.0078895,,,,,,,,,,,,,, -38100,1.179752,2.4030087,,,,,,,,,,,,,, -38200,1.2264532,2.607698,,,,,,,,,,,,,, -38300,1.085356,4.755173,,,,,,,,,,,,,, -38400,0.9977613,3.2333999,,,,,,,,,,,,,, -38436,,,0.6577734351158142,1.416632056236267,0.5920599699020386,1.7245413064956665,50000.0,0.4671000242233276,2.4115984439849854,10000.0,16848.304398536682,18115.436757087708,16848.304398536682,1263.7805182933807,1.3276479244232178,0.0 -38500,1.3221626,4.789532,,,,,,,,,,,,,, -38600,1.197009,2.5454743,,,,,,,,,,,,,, -38700,1.2781829,2.5507522,,,,,,,,,,,,,, -38800,1.361965,2.5043926,,,,,,,,,,,,,, -38900,1.0834744,5.0552063,,,,,,,,,,,,,, -39000,1.1611973,2.5521033,,,,,,,,,,,,,, -39100,1.0386255,5.0983515,,,,,,,,,,,,,, -39200,0.9486762,4.4141984,,,,,,,,,,,,,, -39300,1.332035,2.5649145,,,,,,,,,,,,,, -39393,,,0.6365038752555847,1.513041615486145,0.5950599908828735,1.7173501253128052,50000.0,0.4754000306129455,2.362500667572021,10000.0,17268.549030303955,18568.60999751091,17268.549030303955,1296.6276342868805,1.3599748611450195,0.0 -39400,1.198249,2.4387531,,,,,,,,,,,,,, -39500,1.1235359,2.7603796,,,,,,,,,,,,,, -39600,0.97263026,4.977895,,,,,,,,,,,,,, -39700,1.3528662,2.6156602,,,,,,,,,,,,,, -39800,1.04627,3.3277771,,,,,,,,,,,,,, -39900,1.190958,2.5671434,,,,,,,,,,,,,, -40000,1.0947081,2.7989612,,,,,,,,,,,,,, -40100,0.9604375,4.00754,,,,,,,,,,,,,, -40200,1.3012023,2.5792418,,,,,,,,,,,,,, -40300,1.0455762,3.62871,,,,,,,,,,,,,, -40347,,,0.6510156393051147,1.4378249645233154,0.6019399762153625,1.6753579378128052,50000.0,0.483100026845932,2.3189759254455566,10000.0,17688.56049466133,19023.37623500824,17688.56049466133,1331.296977519989,1.3953208923339844,0.0 -40400,0.9628813,4.3271174,,,,,,,,,,,,,, -40500,1.0788188,2.3912594,,,,,,,,,,,,,, -40600,1.2056272,4.5941386,,,,,,,,,,,,,, -40700,1.2178804,2.5255702,,,,,,,,,,,,,, -40800,1.2509657,2.734342,,,,,,,,,,,,,, -40900,1.1134719,2.848031,,,,,,,,,,,,,, -41000,1.2183768,2.509691,,,,,,,,,,,,,, -41100,1.0017353,3.4844518,,,,,,,,,,,,,, -41200,1.0911715,4.4522457,,,,,,,,,,,,,, -41300,1.0570343,3.7276547,,,,,,,,,,,,,, -41301,,,0.6549023389816284,1.4472695589065552,0.6011599898338318,1.710437893867493,50000.0,0.4743000268936157,2.380539894104004,10000.0,18108.91366672516,19478.32561135292,18108.91366672516,1365.8039565086365,1.434863567352295,0.0 -41400,1.2756857,2.5406842,,,,,,,,,,,,,, -41500,1.0529839,5.093581,,,,,,,,,,,,,, -41600,1.5494652,2.5037453,,,,,,,,,,,,,, -41700,0.9537952,4.0320635,,,,,,,,,,,,,, -41800,1.2661762,3.8453782,,,,,,,,,,,,,, -41900,1.1619515,4.1519547,,,,,,,,,,,,,, -42000,1.2362266,2.6004248,,,,,,,,,,,,,, -42100,1.1637006,2.8484669,,,,,,,,,,,,,, -42200,1.1374091,3.168937,,,,,,,,,,,,,, -42256,,,0.662109375,1.3703303337097168,0.6002799868583679,1.6656572818756104,50000.0,0.4777000248432159,2.3627915382385254,10000.0,18529.00963950157,19932.3771443367,18529.00963950157,1399.6762397289276,1.4674532413482666,0.0 -42300,1.1221472,3.2560952,,,,,,,,,,,,,, -42400,1.1863494,2.5409017,,,,,,,,,,,,,, -42500,1.0149797,3.8553495,,,,,,,,,,,,,, -42600,1.2023472,2.341997,,,,,,,,,,,,,, -42700,1.5470111,2.48569,,,,,,,,,,,,,, -42800,1.0035245,4.4750185,,,,,,,,,,,,,, -42900,1.3838829,5.0797625,,,,,,,,,,,,,, -43000,1.2415184,2.613046,,,,,,,,,,,,,, -43100,1.1582158,3.1268654,,,,,,,,,,,,,, -43200,1.2162153,2.4837768,,,,,,,,,,,,,, -43214,,,0.6528710722923279,1.4166855812072754,0.6093400120735168,1.6414382457733154,50000.0,0.4887000322341919,2.299259185791016,10000.0,18949.35471701622,20389.42469477653,18949.35471701622,1436.2877542972565,1.508568525314331,0.0 -43300,0.8852088,4.2501225,,,,,,,,,,,,,, -43400,1.2277681,2.4953122,,,,,,,,,,,,,, -43500,1.4279279,2.4411383,,,,,,,,,,,,,, -43600,1.1719216,4.8775196,,,,,,,,,,,,,, -43700,1.0507431,3.066437,,,,,,,,,,,,,, -43800,1.2923045,2.986589,,,,,,,,,,,,,, -43900,1.2586938,5.1742296,,,,,,,,,,,,,, -44000,0.96142465,4.0961704,,,,,,,,,,,,,, -44100,1.4467632,2.522572,,,,,,,,,,,,,, -44170,,,0.6615234017372131,1.3803268671035769,0.6105799674987793,1.6226335763931274,50000.0,0.4886000156402588,2.2939584255218506,10000.0,19369.56394290924,20844.375911474228,19369.56394290924,1470.940866470337,1.5481336116790771,0.0 -44200,1.0200306,4.7073507,,,,,,,,,,,,,, -44300,1.0409951,3.6342819,,,,,,,,,,,,,, -44400,1.0259432,5.091766,,,,,,,,,,,,,, -44500,1.4836458,2.6664062,,,,,,,,,,,,,, -44600,1.309817,2.2938216,,,,,,,,,,,,,, -44700,0.9869874,3.9040062,,,,,,,,,,,,,, -44800,1.098788,3.635701,,,,,,,,,,,,,, -44900,1.1818794,2.5609558,,,,,,,,,,,,,, -45000,1.2162265,2.493926,,,,,,,,,,,,,, -45100,1.3172671,2.4437842,,,,,,,,,,,,,, -45127,,,0.6758593320846558,1.3159501552581787,0.6125400066375732,1.623989820480347,50000.0,0.4873000085353851,2.2900307178497314,10000.0,19789.847430944443,21299.075829267505,19789.847430944443,1505.269142627716,1.5860624313354492,0.0 -45200,1.3065655,2.6809638,,,,,,,,,,,,,, -45300,1.2399971,4.909924,,,,,,,,,,,,,, -45400,1.1878035,3.0957534,,,,,,,,,,,,,, -45500,1.1021591,4.148614,,,,,,,,,,,,,, -45600,1.0464991,4.9029546,,,,,,,,,,,,,, -45700,1.1984265,3.3585227,,,,,,,,,,,,,, -45800,1.1755363,2.4000745,,,,,,,,,,,,,, -45900,1.2382042,2.8283346,,,,,,,,,,,,,, -46000,1.0177314,4.5303936,,,,,,,,,,,,,, -46084,,,0.6569140553474426,1.404916763305664,0.6086599826812744,1.6356945037841797,50000.0,0.4923000335693359,2.3095061779022217,10000.0,20210.07796788216,21751.31724905968,20210.07796788216,1537.1973378658297,1.619213342666626,0.0 -46100,1.0299559,4.2003436,,,,,,,,,,,,,, -46200,1.2765737,4.865142,,,,,,,,,,,,,, -46300,1.1360979,4.1317153,,,,,,,,,,,,,, -46400,1.4196056,2.4640741,,,,,,,,,,,,,, -46500,1.2734435,2.6803954,,,,,,,,,,,,,, -46600,1.424666,2.5913744,,,,,,,,,,,,,, -46700,1.3307565,2.8307679,,,,,,,,,,,,,, -46800,1.374118,2.461258,,,,,,,,,,,,,, -46900,1.1923643,4.217698,,,,,,,,,,,,,, -47000,1.1388799,2.7425685,,,,,,,,,,,,,, -47041,,,0.6555468440055847,1.4303855895996094,0.6114799976348877,1.6497628688812256,50000.0,0.4798000156879425,2.3286588191986084,10000.0,20630.187987327576,22206.4185230732,20630.187987327576,1572.094933271408,1.6627938747406006,0.0 -47100,1.0999554,3.2607622,,,,,,,,,,,,,, -47200,1.1691496,2.9921415,,,,,,,,,,,,,, -47300,1.0758344,5.050608,,,,,,,,,,,,,, -47400,1.1156223,2.8500931,,,,,,,,,,,,,, -47500,1.1921031,3.8301637,,,,,,,,,,,,,, -47600,1.196044,3.5006711,,,,,,,,,,,,,, -47700,1.2793332,2.3321075,,,,,,,,,,,,,, -47800,1.1770595,2.3503249,,,,,,,,,,,,,, -47900,1.0328006,3.7121928,,,,,,,,,,,,,, -47996,,,0.6664062142372131,1.364644169807434,0.611739993095398,1.6247352361679075,50000.0,0.4918000102043152,2.2892038822174072,10000.0,21050.396213054657,22661.451620817184,21050.396213054657,1606.835849761963,1.6966009140014648,0.0 -48000,1.5270797,2.4187493,,,,,,,,,,,,,, -48100,1.0293359,3.2616887,,,,,,,,,,,,,, -48200,1.4018703,2.3911743,,,,,,,,,,,,,, -48300,1.4617448,2.3138661,,,,,,,,,,,,,, -48400,1.0524079,3.775479,,,,,,,,,,,,,, -48500,1.3631535,4.8452625,,,,,,,,,,,,,, -48600,1.1208521,2.3979955,,,,,,,,,,,,,, -48700,1.3522979,2.3683527,,,,,,,,,,,,,, -48800,1.2322171,2.4234405,,,,,,,,,,,,,, -48900,1.2790129,2.2935588,,,,,,,,,,,,,, -48950,,,0.6929882764816284,1.2464438676834106,0.6179400086402893,1.586337924003601,50000.0,0.4933000206947326,2.256746530532837,10000.0,21470.66002368927,23116.08723402024,21470.66002368927,1641.1119446754456,1.742445945739746,0.0 -49000,1.3403431,2.3246024,,,,,,,,,,,,,, -49100,1.1905404,4.9649205,,,,,,,,,,,,,, -49200,1.2108656,2.4211476,,,,,,,,,,,,,, -49300,1.3653251,2.2545524,,,,,,,,,,,,,, -49400,1.2615223,2.5233455,,,,,,,,,,,,,, -49500,1.1764289,2.7095246,,,,,,,,,,,,,, -49600,1.2419763,2.4589214,,,,,,,,,,,,,, -49700,1.3087373,4.0220304,,,,,,,,,,,,,, -49800,1.2301104,3.2030542,,,,,,,,,,,,,, -49900,1.1459452,5.0124683,,,,,,,,,,,,,, -49906,,,0.6653515696525574,1.385420560836792,0.6169799566268921,1.6083929538726809,50000.0,0.4933000206947326,2.2742130756378174,10000.0,21890.65754508972,23570.7903380394,21890.65754508972,1675.7305736541748,1.7791416645050049,0.0 -50000,1.3067682,2.4216194,,,,,,,,,,,,,, -50100,1.3626254,2.5202692,,,,,,,,,,,,,, -50200,1.2857049,2.449642,,,,,,,,,,,,,, -50300,1.2946606,2.471509,,,,,,,,,,,,,, -50400,1.1450391,4.518323,,,,,,,,,,,,,, -50500,1.3796365,2.3551297,,,,,,,,,,,,,, -50600,1.3014929,2.6926942,,,,,,,,,,,,,, -50700,1.2628005,2.7200847,,,,,,,,,,,,,, -50800,1.2186408,4.825332,,,,,,,,,,,,,, -50862,,,0.6701366901397705,1.3385523557662964,0.6211999654769897,1.5810306072235107,50000.0,0.4958000183105469,2.254197597503662,10000.0,22310.686877965927,24025.816939115524,22310.686877965927,1710.644835472107,1.812294244766236,0.0 -50900,1.2219071,2.4770825,,,,,,,,,,,,,, -51000,1.3239808,2.4907625,,,,,,,,,,,,,, -51100,1.0865045,3.664672,,,,,,,,,,,,,, -51200,1.0404792,4.532588,,,,,,,,,,,,,, -51300,1.1263776,3.5154886,,,,,,,,,,,,,, -51400,0.9848787,4.2838554,,,,,,,,,,,,,, -51500,1.3372285,2.457851,,,,,,,,,,,,,, -51600,1.3291556,2.423716,,,,,,,,,,,,,, -51700,1.2215625,2.784899,,,,,,,,,,,,,, -51800,1.1039743,4.2942963,,,,,,,,,,,,,, -51818,,,0.6743554472923279,1.317443609237671,0.6154999732971191,1.582787036895752,50000.0,0.5011000037193298,2.231066942214966,10000.0,22730.87821960449,24479.74113345146,22730.87821960449,1744.2933213710785,1.84682846069336,0.0 -51900,1.2352248,2.589727,,,,,,,,,,,,,, -52000,1.4401546,2.5374086,,,,,,,,,,,,,, -52100,1.3002278,2.3559055,,,,,,,,,,,,,, -52200,1.2772188,2.517753,,,,,,,,,,,,,, -52300,1.2039698,2.4063792,,,,,,,,,,,,,, -52400,1.1908712,2.3930717,,,,,,,,,,,,,, -52500,1.186076,2.5421662,,,,,,,,,,,,,, -52600,1.278728,2.3410788,,,,,,,,,,,,,, -52700,1.0605016,4.2396955,,,,,,,,,,,,,, -52770,,,0.6700195074081421,1.3372474908828735,0.6182599663734436,1.5742107629776,50000.0,0.5,2.2460379600524902,10000.0,23150.86624217033,24936.70695376396,23150.86624217033,1781.1761672496796,1.8918776512146,0.0 -52800,1.1774896,2.50003,,,,,,,,,,,,,, -52900,1.067963,3.8627315,,,,,,,,,,,,,, -53000,1.2623425,2.3532877,,,,,,,,,,,,,, -53100,1.2402333,4.2420335,,,,,,,,,,,,,, -53200,1.121803,3.9641087,,,,,,,,,,,,,, -53300,1.4286503,2.2645621,,,,,,,,,,,,,, -53400,1.0073,3.2667267,,,,,,,,,,,,,, -53500,1.3473022,2.3572311,,,,,,,,,,,,,, -53600,1.2057344,2.4283826,,,,,,,,,,,,,, -53700,1.2817686,2.5335658,,,,,,,,,,,,,, -53726,,,0.6646288633346558,1.3680269718170166,0.6178399920463562,1.5924180746078491,50000.0,0.4960000216960907,2.242409467697144,10000.0,23571.08939766884,25391.57753252983,23571.08939766884,1815.730684518814,1.9351487159729004,0.0 -53800,1.1658659,3.0037947,,,,,,,,,,,,,, -53900,1.2093613,2.5492845,,,,,,,,,,,,,, -54000,1.7247155,2.6720753,,,,,,,,,,,,,, -54100,1.3661294,2.6130261,,,,,,,,,,,,,, -54200,1.1700372,4.960824,,,,,,,,,,,,,, -54300,1.1965803,2.6621523,,,,,,,,,,,,,, -54400,1.3728518,2.337523,,,,,,,,,,,,,, -54500,1.247154,2.403393,,,,,,,,,,,,,, -54600,1.3652315,2.348096,,,,,,,,,,,,,, -54684,,,0.6791015267372131,1.3228133916854858,0.6257199645042419,1.5733473300933838,50000.0,0.4976000189781189,2.2368738651275635,10000.0,23991.25001358986,25844.49798321724,23991.25001358986,1848.4059002399445,1.9695231914520264,0.0 -54700,1.257557,2.7745624,,,,,,,,,,,,,, -54800,1.4901847,2.284854,,,,,,,,,,,,,, -54900,1.3682766,2.4318266,,,,,,,,,,,,,, -55000,1.2985672,2.4612677,,,,,,,,,,,,,, -55100,1.1027163,3.9748435,,,,,,,,,,,,,, -55200,1.3592557,2.37911,,,,,,,,,,,,,, -55300,1.3743674,2.3834329,,,,,,,,,,,,,, -55400,1.2490183,2.2942343,,,,,,,,,,,,,, -55500,1.2881886,2.2228863,,,,,,,,,,,,,, -55600,1.3107928,2.3286088,,,,,,,,,,,,,, -55641,,,0.6865624785423279,1.307192087173462,0.6205799579620361,1.607119917869568,50000.0,0.4954000115394592,2.2573320865631104,10000.0,24411.1894197464,26297.4784488678,24411.1894197464,1881.3608667850488,2.004138708114624,0.0 -55700,0.9903094,4.223136,,,,,,,,,,,,,, -55800,1.1494259,2.86794,,,,,,,,,,,,,, -55900,1.1193765,2.8615346,,,,,,,,,,,,,, -56000,1.1754435,2.8292503,,,,,,,,,,,,,, -56100,1.1150082,4.7067733,,,,,,,,,,,,,, -56200,1.2024761,2.5179307,,,,,,,,,,,,,, -56300,1.012635,4.259282,,,,,,,,,,,,,, -56400,1.279144,2.7164977,,,,,,,,,,,,,, -56500,1.4227477,2.4715316,,,,,,,,,,,,,, -56598,,,0.6784765720367432,1.305625319480896,0.6305400133132935,1.534404158592224,50000.0,0.5113000273704529,2.1875674724578857,10000.0,24831.47842264176,26751.75327205658,24831.47842264176,1915.2504620552063,2.0507349967956543,0.0 -56600,1.0479468,4.812432,,,,,,,,,,,,,, -56700,1.1182705,4.218926,,,,,,,,,,,,,, -56800,1.2849072,2.0535536,,,,,,,,,,,,,, -56900,1.1236649,2.5486119,,,,,,,,,,,,,, -57000,1.1606967,2.5644474,,,,,,,,,,,,,, -57100,1.0865241,4.205668,,,,,,,,,,,,,, -57200,1.26254,2.3255587,,,,,,,,,,,,,, -57300,1.1173337,4.4578733,,,,,,,,,,,,,, -57400,1.1608934,2.7294266,,,,,,,,,,,,,, -57500,1.3321894,2.35015,,,,,,,,,,,,,, -57554,,,0.6798437237739563,1.33681058883667,0.6278600096702576,1.564784049987793,50000.0,0.5078999996185303,2.219440460205078,10000.0,25251.627435684204,27205.74866914749,25251.627435684204,1949.0071649551392,2.090480327606201,0.0 -57600,1.1611933,4.8074837,,,,,,,,,,,,,, -57700,1.2390877,2.7025812,,,,,,,,,,,,,, -57800,1.132043,3.5952141,,,,,,,,,,,,,, -57900,1.1358848,3.255252,,,,,,,,,,,,,, -58000,1.3822303,2.3631575,,,,,,,,,,,,,, -58100,1.1709161,3.9244223,,,,,,,,,,,,,, -58200,1.2124572,2.4560957,,,,,,,,,,,,,, -58300,1.4324232,2.433634,,,,,,,,,,,,,, -58400,1.1799946,4.254667,,,,,,,,,,,,,, -58500,1.2969791,2.244323,,,,,,,,,,,,,, -58510,,,0.6871093511581421,1.2710402011871338,0.6285799741744995,1.5473343133926392,50000.0,0.5093000531196594,2.194056749343872,10000.0,25671.693274259567,27659.35000705719,25671.693274259567,1982.4562137126925,2.126615047454834,0.0 -58600,1.3005973,2.760292,,,,,,,,,,,,,, -58700,1.4075686,2.432611,,,,,,,,,,,,,, -58800,1.3296778,2.2834349,,,,,,,,,,,,,, -58900,1.1582326,2.7042751,,,,,,,,,,,,,, -59000,1.1714281,3.1757436,,,,,,,,,,,,,, -59100,1.2629863,2.7100754,,,,,,,,,,,,,, -59200,1.3372121,2.63328,,,,,,,,,,,,,, -59300,1.2132673,4.3803163,,,,,,,,,,,,,, -59400,1.2643216,2.342386,,,,,,,,,,,,,, -59466,,,0.7090429663658142,1.2248718738555908,0.6277599930763245,1.591611385345459,50000.0,0.5082000494003296,2.2486259937286377,10000.0,26091.714426994324,28112.035681962967,26091.714426994324,2015.0270056724548,2.170564889907837,0.0 -59500,1.3664753,2.303888,,,,,,,,,,,,,, -59600,1.359284,4.272784,,,,,,,,,,,,,, -59700,1.3309926,2.406796,,,,,,,,,,,,,, -59800,1.3015906,2.442606,,,,,,,,,,,,,, -59900,1.3164847,2.4626467,,,,,,,,,,,,,, -60000,1.2199776,2.5956004,,,,,,,,,,,,,, -60100,1.431544,2.325516,,,,,,,,,,,,,, -60200,1.2707137,2.2926693,,,,,,,,,,,,,, -60300,1.3484585,2.3427982,,,,,,,,,,,,,, -60400,1.1249187,3.9245057,,,,,,,,,,,,,, -60421,,,0.6841210722923279,1.2804559469223022,0.632099986076355,1.5170131921768188,50000.0,0.511900007724762,2.176940679550171,10000.0,26511.63784146309,28567.07047390937,26511.63784146309,2050.053850412369,2.206193447113037,0.0 -60500,1.2971442,3.3999047,,,,,,,,,,,,,, -60600,1.2385128,2.3021142,,,,,,,,,,,,,, -60700,1.1652997,3.3405905,,,,,,,,,,,,,, -60800,1.1571819,3.2653203,,,,,,,,,,,,,, -60900,1.1841614,2.802119,,,,,,,,,,,,,, -61000,1.4445776,2.3027263,,,,,,,,,,,,,, -61100,1.1683189,4.854738,,,,,,,,,,,,,, -61200,1.2856908,4.9930515,,,,,,,,,,,,,, -61300,1.2425468,2.4533875,,,,,,,,,,,,,, -61377,,,0.6870312094688416,1.2437814474105835,0.6377999782562256,1.4895076751708984,50000.0,0.5074000358581543,2.163461208343506,10000.0,26931.86723518372,29022.67796754837,26931.86723518372,2085.340404987335,2.2476589679718018,0.0 -61400,1.2991669,2.8952422,,,,,,,,,,,,,, -61500,1.1754681,4.751719,,,,,,,,,,,,,, -61600,1.0311095,4.0519013,,,,,,,,,,,,,, -61700,1.1775386,3.1322403,,,,,,,,,,,,,, -61800,1.2223898,3.0316207,,,,,,,,,,,,,, -61900,1.3476446,2.0938256,,,,,,,,,,,,,, -62000,1.3464462,2.287632,,,,,,,,,,,,,, -62100,1.2131178,2.4615211,,,,,,,,,,,,,, -62200,1.3388553,2.3033257,,,,,,,,,,,,,, -62300,1.2454157,2.602899,,,,,,,,,,,,,, -62332,,,0.7010741829872131,1.2027630805969238,0.6353999972343445,1.4932037591934204,50000.0,0.5136000514030457,2.16353702545166,10000.0,27352.072308063507,29475.78770780564,27352.072308063507,2118.1527137756348,2.290111780166626,0.0 -62400,1.1462381,4.7648673,,,,,,,,,,,,,, -62500,1.3089142,2.4948907,,,,,,,,,,,,,, -62600,1.2669594,2.2658472,,,,,,,,,,,,,, -62700,1.0668621,4.6447754,,,,,,,,,,,,,, -62800,1.3021765,2.3548124,,,,,,,,,,,,,, -62900,1.3370765,2.2539191,,,,,,,,,,,,,, -63000,1.0990188,3.2224627,,,,,,,,,,,,,, -63100,1.2799474,2.2209322,,,,,,,,,,,,,, -63200,1.1877202,3.497594,,,,,,,,,,,,,, -63289,,,0.6837499737739563,1.3085410594940186,0.6333799958229065,1.5339182615280151,50000.0,0.5095000267028809,2.2137796878814697,10000.0,27772.3682949543,29931.60078239441,27772.3682949543,2153.5841794013977,2.325376510620117,0.0 -63300,1.5390859,2.5964582,,,,,,,,,,,,,, -63400,1.3220413,2.379963,,,,,,,,,,,,,, -63500,1.291957,2.283136,,,,,,,,,,,,,, -63600,1.1850514,3.315764,,,,,,,,,,,,,, -63700,1.2912025,2.2346475,,,,,,,,,,,,,, -63800,1.1443844,3.1786544,,,,,,,,,,,,,, -63900,1.3289199,2.1145794,,,,,,,,,,,,,, -64000,1.1234882,4.3208685,,,,,,,,,,,,,, -64100,1.1033618,2.4877691,,,,,,,,,,,,,, -64200,1.2688341,2.265555,,,,,,,,,,,,,, -64244,,,0.6842578053474426,1.3000692129135132,0.6389999985694885,1.5179568529129028,50000.0,0.5075000524520874,2.188990592956543,10000.0,28192.69956564904,30387.0754737854,28192.69956564904,2188.6370465755463,2.3665175437927246,0.0 -64300,1.2221515,3.3419945,,,,,,,,,,,,,, -64400,1.1038947,4.000433,,,,,,,,,,,,,, -64500,1.3176587,2.2493184,,,,,,,,,,,,,, -64600,1.1193086,4.559126,,,,,,,,,,,,,, -64700,1.1594411,4.461648,,,,,,,,,,,,,, -64800,1.1781245,2.3785744,,,,,,,,,,,,,, -64900,1.262361,4.5758395,,,,,,,,,,,,,, -65000,1.415196,2.3214576,,,,,,,,,,,,,, -65100,1.2202022,2.1524405,,,,,,,,,,,,,, -65200,1.1484538,4.812312,,,,,,,,,,,,,, -65201,,,0.6913281083106995,1.2925525903701782,0.6354599595069885,1.5564727783203125,50000.0,0.5093000531196594,2.22403335571289,10000.0,28613.02726626396,30842.774602413177,28613.02726626396,2223.919200897217,2.40610671043396,0.0 -65300,1.1691729,2.239303,,,,,,,,,,,,,, -65400,1.2875832,2.1474094,,,,,,,,,,,,,, -65500,1.2755343,2.611347,,,,,,,,,,,,,, -65600,1.1264685,3.7584026,,,,,,,,,,,,,, -65700,1.136089,3.4245238,,,,,,,,,,,,,, -65800,1.3222587,2.3159,,,,,,,,,,,,,, -65900,1.2911655,2.2941532,,,,,,,,,,,,,, -66000,1.2013059,2.6531596,,,,,,,,,,,,,, -66100,1.4029534,2.2818239,,,,,,,,,,,,,, -66160,,,0.7096288800239563,1.2008816003799438,0.6379199624061584,1.5290579795837402,50000.0,0.5128999948501587,2.1780552864074707,10000.0,29033.197748422623,31293.28220319748,29033.197748422623,2254.170946121216,2.441678762435913,0.0 -66200,1.5705609,2.4145656,,,,,,,,,,,,,, -66300,1.236687,2.5964699,,,,,,,,,,,,,, -66400,1.27081,2.3589494,,,,,,,,,,,,,, -66500,1.1986855,2.4699185,,,,,,,,,,,,,, -66600,1.3347732,4.3201156,,,,,,,,,,,,,, -66700,1.3888855,2.4227834,,,,,,,,,,,,,, -66800,1.1480637,4.6902294,,,,,,,,,,,,,, -66900,1.1276842,3.2099938,,,,,,,,,,,,,, -67000,1.1918744,4.5259047,,,,,,,,,,,,,, -67100,1.3845156,4.543748,,,,,,,,,,,,,, -67115,,,0.6934179663658142,1.2447786331176758,0.6399399638175964,1.487149953842163,50000.0,0.5189000368118286,2.1419031620025635,10000.0,29453.328754663467,31748.261180639267,29453.328754663467,2288.9244425296783,2.4865760803222656,0.0 -67200,1.132838,3.031125,,,,,,,,,,,,,, -67300,1.2949021,2.9893026,,,,,,,,,,,,,, -67400,1.2856895,2.0956848,,,,,,,,,,,,,, -67500,1.1468726,4.8852954,,,,,,,,,,,,,, -67600,1.3023702,2.2443595,,,,,,,,,,,,,, -67700,1.3767513,2.5172539,,,,,,,,,,,,,, -67800,1.2021888,3.1386933,,,,,,,,,,,,,, -67900,1.2994492,2.2691784,,,,,,,,,,,,,, -68000,1.3801546,2.2764978,,,,,,,,,,,,,, -68073,,,0.7007812261581421,1.195706486701965,0.6473599672317505,1.448695421218872,50000.0,0.522599995136261,2.106295347213745,10000.0,29873.27702856064,32204.031671524048,29873.27702856064,2324.6626167297363,2.5216784477233887,0.0 -68100,1.2788483,2.586469,,,,,,,,,,,,,, -68200,1.294166,2.2844472,,,,,,,,,,,,,, -68300,1.3509623,2.2459974,,,,,,,,,,,,,, -68400,1.0916685,3.3139858,,,,,,,,,,,,,, -68500,1.3848441,2.27024,,,,,,,,,,,,,, -68600,1.1470176,4.311465,,,,,,,,,,,,,, -68700,1.0887289,2.9690428,,,,,,,,,,,,,, -68800,1.1587298,4.349029,,,,,,,,,,,,,, -68900,1.1739143,3.389491,,,,,,,,,,,,,, -69000,1.1691482,4.284396,,,,,,,,,,,,,, -69028,,,0.7033398151397705,1.235210657119751,0.6412999629974365,1.5109248161315918,50000.0,0.5178000330924988,2.1673731803894043,10000.0,30293.492072343823,32660.45233750344,30293.492072343823,2360.7755336761475,2.5654780864715576,0.0 -69100,1.3500912,3.6085613,,,,,,,,,,,,,, -69200,1.2097113,3.9888535,,,,,,,,,,,,,, -69300,1.1765496,3.172457,,,,,,,,,,,,,, -69400,1.3921947,2.1878924,,,,,,,,,,,,,, -69500,1.3619094,2.3015447,,,,,,,,,,,,,, -69600,1.3171122,3.3426135,,,,,,,,,,,,,, -69700,1.4395214,2.2761059,,,,,,,,,,,,,, -69800,1.3198988,2.7263634,,,,,,,,,,,,,, -69900,1.2265239,3.9580703,,,,,,,,,,,,,, -69982,,,0.7116405963897705,1.1737558841705322,0.6464200019836426,1.4845588207244873,50000.0,0.5182000398635864,2.133967399597168,10000.0,30713.74888730049,33112.77065825462,30713.74888730049,2392.7489824295044,2.603192090988159,0.0 -70000,1.1131794,2.923377,,,,,,,,,,,,,, -70100,1.3846935,2.2686274,,,,,,,,,,,,,, -70200,1.229404,2.6552813,,,,,,,,,,,,,, -70300,1.1554316,3.7831407,,,,,,,,,,,,,, -70400,1.3453689,2.2508907,,,,,,,,,,,,,, -70500,1.12678,4.0923357,,,,,,,,,,,,,, -70600,1.2707634,3.438995,,,,,,,,,,,,,, -70700,1.1098312,4.1620083,,,,,,,,,,,,,, -70800,1.4081011,2.2472954,,,,,,,,,,,,,, -70900,1.3857119,4.639096,,,,,,,,,,,,,, -70937,,,0.69447261095047,1.2243953943252563,0.6431599855422974,1.4684062004089355,50000.0,0.5223000049591064,2.1306920051574707,10000.0,31133.713386058807,33567.919130563736,31133.713386058807,2427.8326795101166,2.6537318229675293,0.0 -71000,1.2973819,2.3611739,,,,,,,,,,,,,, -71100,1.3183526,4.1725473,,,,,,,,,,,,,, -71200,1.1078123,4.194028,,,,,,,,,,,,,, -71300,1.1634842,4.4485188,,,,,,,,,,,,,, -71400,1.3223023,2.239676,,,,,,,,,,,,,, -71500,1.346726,2.2984934,,,,,,,,,,,,,, -71600,1.3940175,2.220444,,,,,,,,,,,,,, -71700,1.2552142,2.2641761,,,,,,,,,,,,,, -71800,1.400573,2.2786503,,,,,,,,,,,,,, -71889,,,0.697460949420929,1.2218685150146484,0.6411399841308594,1.4857412576675415,50000.0,0.5151000022888184,2.1532177925109863,10000.0,31553.676849365234,34023.14230251312,31553.676849365234,2463.0030977725983,2.694608211517334,0.0 -71900,1.3654957,2.2008357,,,,,,,,,,,,,, -72000,1.3604834,2.1849551,,,,,,,,,,,,,, -72100,1.270092,2.184692,,,,,,,,,,,,,, -72200,1.3785281,2.2098122,,,,,,,,,,,,,, -72300,1.3507036,4.686123,,,,,,,,,,,,,, -72400,1.3056571,2.3427553,,,,,,,,,,,,,, -72500,1.3522649,2.1814744,,,,,,,,,,,,,, -72600,1.2272891,2.6705642,,,,,,,,,,,,,, -72700,1.4933238,2.2680829,,,,,,,,,,,,,, -72800,1.3752606,2.1840882,,,,,,,,,,,,,, -72846,,,0.7125195264816284,1.142621397972107,0.650119960308075,1.44003164768219,50000.0,0.5228000283241272,2.113523483276367,10000.0,31974.063121318817,34478.38299822807,31974.063121318817,2497.761979341507,2.7404704093933105,0.0 -72900,1.3272796,2.2655315,,,,,,,,,,,,,, -73000,1.3944882,2.4863074,,,,,,,,,,,,,, -73100,1.3908031,2.1465132,,,,,,,,,,,,,, -73200,1.1950577,3.0864427,,,,,,,,,,,,,, -73300,1.2875831,2.6502419,,,,,,,,,,,,,, -73400,1.1675972,2.8335752,,,,,,,,,,,,,, -73500,1.3896388,2.2444055,,,,,,,,,,,,,, -73600,1.3185222,4.714774,,,,,,,,,,,,,, -73700,1.2987963,2.2127783,,,,,,,,,,,,,, -73800,1.2582059,4.4234195,,,,,,,,,,,,,, -73804,,,0.699999988079071,1.2046812772750854,0.6521199941635132,1.439921736717224,50000.0,0.527999997138977,2.113095760345459,10000.0,32394.13893151284,34934.52019953728,32394.13893151284,2533.733570098877,2.7803633213043213,0.0 -73900,1.4929037,2.1029444,,,,,,,,,,,,,, -74000,1.2112844,2.611001,,,,,,,,,,,,,, -74100,1.4201983,2.1877952,,,,,,,,,,,,,, -74200,1.4403741,2.313997,,,,,,,,,,,,,, -74300,1.3310926,2.2060065,,,,,,,,,,,,,, -74400,1.3617371,2.0870445,,,,,,,,,,,,,, -74500,1.3803927,2.1139393,,,,,,,,,,,,,, -74600,1.4479607,2.3097997,,,,,,,,,,,,,, -74700,1.2183318,2.3317192,,,,,,,,,,,,,, -74763,,,0.7062109112739563,1.1783642768859863,0.6523199677467346,1.4315232038497925,50000.0,0.5234000086784363,2.103261709213257,10000.0,32814.26351618767,35390.0334186554,32814.26351618767,2569.016624450684,2.836053371429444,0.0 -74800,1.2938219,2.0457478,,,,,,,,,,,,,, -74900,1.3163384,2.389043,,,,,,,,,,,,,, -75000,1.2521162,4.2293363,,,,,,,,,,,,,, -75100,1.2302625,3.3475368,,,,,,,,,,,,,, -75200,1.1945187,4.021247,,,,,,,,,,,,,, -75300,1.3001418,2.2517009,,,,,,,,,,,,,, -75400,1.4679102,2.1712077,,,,,,,,,,,,,, -75500,1.3218135,2.1714373,,,,,,,,,,,,,, -75600,1.3600444,2.182702,,,,,,,,,,,,,, -75700,1.2488561,2.7801805,,,,,,,,,,,,,, -75717,,,0.7077734470367432,1.1573503017425537,0.6542999744415283,1.419303297996521,50000.0,0.5235000252723694,2.121488332748413,10000.0,33234.302689790726,35845.19429016113,33234.302689790726,2604.0515208244324,2.8736648559570312,0.0 -75800,1.1603029,4.6795936,,,,,,,,,,,,,, -75900,1.4035424,2.1880875,,,,,,,,,,,,,, -76000,1.6229599,2.185435,,,,,,,,,,,,,, -76100,1.3912201,2.0716558,,,,,,,,,,,,,, -76200,1.2840227,4.5595164,,,,,,,,,,,,,, -76300,1.396499,2.1150072,,,,,,,,,,,,,, -76400,1.3218943,2.2089653,,,,,,,,,,,,,, -76500,1.241243,3.1647096,,,,,,,,,,,,,, -76600,1.2285551,2.3303022,,,,,,,,,,,,,, -76674,,,0.7352538704872131,1.0566195249557495,0.6536200046539307,1.416456937789917,50000.0,0.5303000211715698,2.0822157859802246,10000.0,33654.46897649765,36301.674439907074,33654.46897649765,2640.2735633850098,2.9117650985717773,0.0 -76700,1.2533482,4.3013816,,,,,,,,,,,,,, -76800,1.2075999,2.9631026,,,,,,,,,,,,,, -76900,1.4012954,2.0897896,,,,,,,,,,,,,, -77000,1.2303183,2.2132769,,,,,,,,,,,,,, -77100,1.2928041,3.2683177,,,,,,,,,,,,,, -77200,1.2704289,4.1006656,,,,,,,,,,,,,, -77300,1.4614304,2.0873165,,,,,,,,,,,,,, -77400,1.5003982,2.2960343,,,,,,,,,,,,,, -77500,1.3045446,3.4252954,,,,,,,,,,,,,, -77600,1.4800912,2.5620203,,,,,,,,,,,,,, -77634,,,0.7109179496765137,1.169638752937317,0.6612199544906616,1.404836654663086,50000.0,0.5340999960899353,2.075775384902954,10000.0,34074.84289550781,36757.19959259033,34074.84289550781,2675.3321437835693,2.9543845653533936,0.0 -77700,1.4116756,2.1318877,,,,,,,,,,,,,, -77800,1.4752336,2.1386242,,,,,,,,,,,,,, -77900,1.374208,2.1744018,,,,,,,,,,,,,, -78000,1.3093492,2.039871,,,,,,,,,,,,,, -78100,1.1893554,3.586369,,,,,,,,,,,,,, -78200,1.5300095,4.6297736,,,,,,,,,,,,,, -78300,1.316141,4.271508,,,,,,,,,,,,,, -78400,1.2082009,3.823961,,,,,,,,,,,,,, -78500,1.2774979,1.8934102,,,,,,,,,,,,,, -78592,,,0.7156835794448853,1.1127052307128906,0.662559986114502,1.369977593421936,50000.0,0.5379000306129456,2.034130811691284,10000.0,34494.79498171806,37214.54881882668,34494.79498171806,2712.6424124240875,2.9909353256225586,0.0 -78600,1.2954699,2.2407153,,,,,,,,,,,,,, -78700,1.2387412,3.9841642,,,,,,,,,,,,,, -78800,1.2124627,3.4050212,,,,,,,,,,,,,, -78900,1.2877811,2.4527454,,,,,,,,,,,,,, -79000,1.2569398,2.2801068,,,,,,,,,,,,,, -79100,1.4432783,2.2538528,,,,,,,,,,,,,, -79200,1.41661,2.230831,,,,,,,,,,,,,, -79300,1.3199954,3.4087496,,,,,,,,,,,,,, -79400,1.3530896,2.1956065,,,,,,,,,,,,,, -79500,1.3514092,4.8046184,,,,,,,,,,,,,, -79548,,,0.7233788967132568,1.1037510633468628,0.659280002117157,1.397465705871582,50000.0,0.5356000065803528,2.0531795024871826,10000.0,34914.722329854965,37669.63617515564,34914.722329854965,2747.687013626098,3.044427394866944,0.0 -79600,1.2121756,4.4317827,,,,,,,,,,,,,, -79700,1.4071317,2.4200177,,,,,,,,,,,,,, -79800,1.3853443,2.1697254,,,,,,,,,,,,,, -79900,1.2465353,2.6681304,,,,,,,,,,,,,, -80000,1.1931368,2.9088383,,,,,,,,,,,,,, -80100,1.3396616,3.5628402,,,,,,,,,,,,,, -80200,1.4502493,3.3700054,,,,,,,,,,,,,, -80300,1.1733352,4.5431576,,,,,,,,,,,,,, -80400,1.2213686,4.179052,,,,,,,,,,,,,, -80500,1.2973932,2.3650584,,,,,,,,,,,,,, -80505,,,0.7168163657188416,1.139781832695007,0.6624999642372131,1.3918650150299072,50000.0,0.5367000102996826,2.0481276512146,10000.0,35334.804896593094,38125.42041826248,35334.804896593094,2783.3004961013794,3.0824363231658936,0.0 -80600,1.3737137,2.0082097,,,,,,,,,,,,,, -80700,1.4862994,2.13804,,,,,,,,,,,,,, -80800,1.2864943,2.6396413,,,,,,,,,,,,,, -80900,1.314311,2.2557244,,,,,,,,,,,,,, -81000,1.4152398,2.0939221,,,,,,,,,,,,,, -81100,1.1873388,4.5353246,,,,,,,,,,,,,, -81200,1.3118036,3.3571289,,,,,,,,,,,,,, -81300,1.3804011,2.2516277,,,,,,,,,,,,,, -81400,1.2783986,4.713823,,,,,,,,,,,,,, -81462,,,0.7132421731948853,1.1469886302947998,0.6568199992179871,1.4000229835510254,50000.0,0.526199996471405,2.07778549194336,10000.0,35754.91748666763,38581.58862805367,35754.91748666763,2819.2582035064697,3.1309750080108643,0.0 -81500,1.301363,4.136373,,,,,,,,,,,,,, -81600,1.2474873,3.4195962,,,,,,,,,,,,,, -81700,1.2514846,2.2643056,,,,,,,,,,,,,, -81800,1.3300931,3.8909807,,,,,,,,,,,,,, -81900,1.4832307,1.9848514,,,,,,,,,,,,,, -82000,1.160334,4.22079,,,,,,,,,,,,,, -82100,1.5642279,2.1659074,,,,,,,,,,,,,, -82200,1.6228678,2.1934972,,,,,,,,,,,,,, -82300,1.3695717,4.646786,,,,,,,,,,,,,, -82400,1.3160789,2.0273843,,,,,,,,,,,,,, -82421,,,0.7280859351158142,1.0940496921539309,0.6650399565696716,1.375693678855896,50000.0,0.5421000123023987,2.0170681476593018,10000.0,36174.90492129326,39035.7145614624,36174.90492129326,2853.3048980236053,3.1728460788726807,0.0 -82500,1.2705069,2.4884326,,,,,,,,,,,,,, -82600,1.3352821,2.0276384,,,,,,,,,,,,,, -82700,1.2186916,2.5955036,,,,,,,,,,,,,, -82800,1.4879246,2.1262267,,,,,,,,,,,,,, -82900,1.3469224,3.5916123,,,,,,,,,,,,,, -83000,1.4209863,2.0953288,,,,,,,,,,,,,, -83100,1.2871435,2.6896477,,,,,,,,,,,,,, -83200,1.3610636,2.1369524,,,,,,,,,,,,,, -83300,1.2891641,2.1043646,,,,,,,,,,,,,, -83380,,,0.7390820384025574,1.0498539209365845,0.6658599972724915,1.3794326782226562,50000.0,0.5343000292778015,2.052915334701538,10000.0,36595.001002788544,39491.40196967125,36595.001002788544,2888.807193994522,3.2124273777008057,0.0 -83400,1.3233696,2.307684,,,,,,,,,,,,,, -83500,1.3932548,2.9425437,,,,,,,,,,,,,, -83600,1.3108299,3.2480009,,,,,,,,,,,,,, -83700,1.4237695,2.1377683,,,,,,,,,,,,,, -83800,1.4766141,2.047493,,,,,,,,,,,,,, -83900,1.3707272,2.0428095,,,,,,,,,,,,,, -84000,1.4122111,2.069181,,,,,,,,,,,,,, -84100,1.4913344,2.4015589,,,,,,,,,,,,,, -84200,1.4600776,3.263821,,,,,,,,,,,,,, -84300,1.3641104,1.9169624,,,,,,,,,,,,,, -84339,,,0.7233007550239563,1.0985373258590698,0.668720006942749,1.348969340324402,50000.0,0.5428000092506409,1.999638795852661,10000.0,37015.265879392624,39946.67766284943,37015.265879392624,2923.7270460128784,3.254138708114624,0.0 -84400,1.4278653,1.9585598,,,,,,,,,,,,,, -84500,1.3918927,2.072132,,,,,,,,,,,,,, -84600,1.6024481,2.2625828,,,,,,,,,,,,,, -84700,1.3636622,1.9999439,,,,,,,,,,,,,, -84800,1.4500852,2.122309,,,,,,,,,,,,,, -84900,1.2912155,3.0017354,,,,,,,,,,,,,, -85000,1.3056273,4.4509864,,,,,,,,,,,,,, -85100,1.290349,2.8034801,,,,,,,,,,,,,, -85200,1.5249988,2.1511989,,,,,,,,,,,,,, -85298,,,0.7191796898841858,1.1635682582855225,0.6645399928092957,1.4133321046829224,50000.0,0.5355000495910645,2.0667412281036377,10000.0,37435.61778759956,40402.64578437805,37435.61778759956,2959.251959323883,3.2952709197998047,0.0 -85300,1.3461142,2.8560553,,,,,,,,,,,,,, -85400,1.4988964,2.0490208,,,,,,,,,,,,,, -85500,1.3907542,2.0109358,,,,,,,,,,,,,, -85600,1.5829376,2.0356941,,,,,,,,,,,,,, -85700,1.4199804,1.9721454,,,,,,,,,,,,,, -85800,1.2360286,3.6420894,,,,,,,,,,,,,, -85900,1.4060314,2.330695,,,,,,,,,,,,,, -86000,1.1822052,2.9550924,,,,,,,,,,,,,, -86100,1.2634083,2.736551,,,,,,,,,,,,,, -86200,1.5194873,1.9745727,,,,,,,,,,,,,, -86259,,,0.7269921898841858,1.1259955167770386,0.6650800108909607,1.396593451499939,50000.0,0.5408000349998474,2.054095506668091,10000.0,37855.74568748474,40854.33689212799,37855.74568748474,2990.723461866379,3.3366193771362305,0.0 -86300,1.3047452,3.983051,,,,,,,,,,,,,, -86400,1.308654,4.4376593,,,,,,,,,,,,,, -86500,1.5792556,2.2406895,,,,,,,,,,,,,, -86600,1.3170804,2.6253607,,,,,,,,,,,,,, -86700,1.1942878,3.908403,,,,,,,,,,,,,, -86800,1.3399911,2.583165,,,,,,,,,,,,,, -86900,1.2951117,4.487061,,,,,,,,,,,,,, -87000,1.2626609,2.791729,,,,,,,,,,,,,, -87100,1.4371437,2.360487,,,,,,,,,,,,,, -87200,1.2586468,3.7599587,,,,,,,,,,,,,, -87217,,,0.7546288967132568,0.9544236660003662,0.6703999638557434,1.32748544216156,50000.0,0.5439000129699707,1.98435378074646,10000.0,38275.79717183113,41306.657376527786,38275.79717183113,3022.891412258148,3.387751340866089,0.0 -87300,1.3425487,2.6933832,,,,,,,,,,,,,, -87400,1.3487545,2.1124644,,,,,,,,,,,,,, -87500,1.5106008,4.516044,,,,,,,,,,,,,, -87600,1.3607885,2.005143,,,,,,,,,,,,,, -87700,1.2582848,2.768888,,,,,,,,,,,,,, -87800,1.5636669,2.1012704,,,,,,,,,,,,,, -87900,1.2699718,4.109259,,,,,,,,,,,,,, -88000,1.6436341,2.054793,,,,,,,,,,,,,, -88100,1.6017373,4.6676083,,,,,,,,,,,,,, -88173,,,0.7283398509025574,1.082810401916504,0.6726599931716919,1.337311029434204,50000.0,0.5463000535964966,1.977535843849182,10000.0,38695.72180867195,41760.179690122604,38695.72180867195,3056.3972618579865,3.430220603942871,0.0 -88200,1.3841802,2.0990367,,,,,,,,,,,,,, -88300,1.3260721,4.5862017,,,,,,,,,,,,,, -88400,1.4064603,2.0700755,,,,,,,,,,,,,, -88500,1.5096091,2.1499326,,,,,,,,,,,,,, -88600,1.485053,2.2913034,,,,,,,,,,,,,, -88700,1.2856802,3.092912,,,,,,,,,,,,,, -88800,1.2834574,3.5900025,,,,,,,,,,,,,, -88900,1.4270265,2.2108867,,,,,,,,,,,,,, -89000,1.3693516,2.092332,,,,,,,,,,,,,, -89100,1.4076127,2.0171003,,,,,,,,,,,,,, -89131,,,0.7352343797683716,1.056477427482605,0.671779990196228,1.3395153284072876,50000.0,0.5505000352859497,1.975814938545227,10000.0,39116.09435725212,42216.833818912506,39116.09435725212,3092.58779001236,3.471376419067383,0.0 -89200,1.451187,2.0344954,,,,,,,,,,,,,, -89300,1.5388045,2.1277149,,,,,,,,,,,,,, -89400,1.364984,4.633152,,,,,,,,,,,,,, -89500,1.3411295,2.8722765,,,,,,,,,,,,,, -89600,1.2229733,3.152984,,,,,,,,,,,,,, -89700,1.340694,2.858252,,,,,,,,,,,,,, -89800,1.2026881,3.320084,,,,,,,,,,,,,, -89900,1.3432093,4.6400185,,,,,,,,,,,,,, -90000,1.3578913,3.903924,,,,,,,,,,,,,, -90091,,,0.7414257526397705,1.0378321409225464,0.6741200089454651,1.3346014022827148,50000.0,0.5543000102043152,1.976184606552124,10000.0,39536.358402490616,42678.63556289673,39536.358402490616,3134.03320145607,3.513042688369751,0.0 -90100,1.4816921,2.010296,,,,,,,,,,,,,, -90200,1.3976961,2.40623,,,,,,,,,,,,,, -90300,1.3494606,4.1789026,,,,,,,,,,,,,, -90400,1.3103937,2.5004983,,,,,,,,,,,,,, -90500,1.4595665,2.2422948,,,,,,,,,,,,,, -90600,1.4284524,2.0321383,,,,,,,,,,,,,, -90700,1.3152229,4.559638,,,,,,,,,,,,,, -90800,1.2725451,2.7880714,,,,,,,,,,,,,, -90900,1.3718336,3.6786432,,,,,,,,,,,,,, -91000,1.4180064,2.1947863,,,,,,,,,,,,,, -91048,,,0.7334179282188416,1.0601955652236938,0.6719399690628052,1.3371931314468384,50000.0,0.5467000007629395,1.9848815202713013,10000.0,39956.39900755882,43136.15661621094,39956.39900755882,3171.4239320755005,3.5524487495422363,0.0 -91100,1.3482909,2.6995096,,,,,,,,,,,,,, -91200,1.5822512,4.622669,,,,,,,,,,,,,, -91300,1.5544616,2.2060876,,,,,,,,,,,,,, -91400,1.3158854,4.236249,,,,,,,,,,,,,, -91500,1.4444088,2.096882,,,,,,,,,,,,,, -91600,1.5805807,4.587179,,,,,,,,,,,,,, -91700,1.4559368,2.0123863,,,,,,,,,,,,,, -91800,1.3367317,2.3710172,,,,,,,,,,,,,, -91900,1.4741327,2.0237873,,,,,,,,,,,,,, -92000,1.2917776,4.227817,,,,,,,,,,,,,, -92005,,,0.732128918170929,1.0990062952041626,0.6755799651145935,1.3497874736785889,50000.0,0.5501000285148621,2.014688491821289,10000.0,40376.65997195244,43596.69050884247,40376.65997195244,3211.6039159297943,3.5961568355560303,0.0 -92100,1.5330977,2.0060112,,,,,,,,,,,,,, -92200,1.3140494,3.9796374,,,,,,,,,,,,,, -92300,1.4187481,2.268392,,,,,,,,,,,,,, -92400,1.4916742,2.1163604,,,,,,,,,,,,,, -92500,1.6199782,2.030309,,,,,,,,,,,,,, -92600,1.5179416,2.1104078,,,,,,,,,,,,,, -92700,1.302055,3.0204391,,,,,,,,,,,,,, -92800,1.562703,4.278313,,,,,,,,,,,,,, -92900,1.4535204,4.616191,,,,,,,,,,,,,, -92962,,,0.7373241782188416,1.0257611274719238,0.6764799952507019,1.3098903894424438,50000.0,0.5496000051498413,1.969154953956604,10000.0,40796.6041162014,44051.557891607285,40796.6041162014,3246.4359505176544,3.63786244392395,0.0 -93000,1.6342325,2.1133015,,,,,,,,,,,,,, -93100,1.5672017,2.1659367,,,,,,,,,,,,,, -93200,1.5492406,2.0267625,,,,,,,,,,,,,, -93300,1.2852378,3.7703516,,,,,,,,,,,,,, -93400,1.4915688,4.0469575,,,,,,,,,,,,,, -93500,1.3775158,3.4687889,,,,,,,,,,,,,, -93600,1.3834562,2.5370765,,,,,,,,,,,,,, -93700,1.5154262,1.9899056,,,,,,,,,,,,,, -93800,1.3234881,3.7049952,,,,,,,,,,,,,, -93900,1.4541818,2.6382728,,,,,,,,,,,,,, -93923,,,0.74916011095047,0.9873992800712584,0.6765999794006348,1.3194611072540283,50000.0,0.5472000241279602,1.9758864641189573,10000.0,41216.86024641991,44508.52853775024,41216.86024641991,3283.058206319809,3.680174589157105,0.0 -94000,1.3465531,1.9822025,,,,,,,,,,,,,, -94100,1.5201621,2.2083027,,,,,,,,,,,,,, -94200,1.4231334,1.9995391,,,,,,,,,,,,,, -94300,1.3799958,4.333173,,,,,,,,,,,,,, -94400,1.5575355,2.0036695,,,,,,,,,,,,,, -94500,1.5088168,2.2756667,,,,,,,,,,,,,, -94600,1.6546565,4.525064,,,,,,,,,,,,,, -94700,1.5514269,2.0862257,,,,,,,,,,,,,, -94800,1.4276031,3.294247,,,,,,,,,,,,,, -94883,,,0.740917980670929,1.0271331071853638,0.6797199845314026,1.3003922700881958,50000.0,0.553600013256073,1.94719660282135,10000.0,41637.13806056976,44966.2047367096,41637.13806056976,3320.365930557251,3.7214975357055664,0.0 -94900,1.4157664,3.9816704,,,,,,,,,,,,,, -95000,1.2599733,3.8212905,,,,,,,,,,,,,, -95100,1.7228527,2.0569983,,,,,,,,,,,,,, -95200,1.3494984,4.2983675,,,,,,,,,,,,,, -95300,1.7214836,1.9709225,,,,,,,,,,,,,, -95400,1.3410298,2.3674803,,,,,,,,,,,,,, -95500,1.3908762,1.9311645,,,,,,,,,,,,,, -95600,1.4584672,4.481081,,,,,,,,,,,,,, -95700,1.3129228,3.328601,,,,,,,,,,,,,, -95800,1.5025362,2.1106527,,,,,,,,,,,,,, -95838,,,0.7427148222923279,1.0067245960235596,0.6803999543190002,1.2884384393692017,50000.0,0.5525000095367432,1.9437371492385864,10000.0,42057.04630446434,45429.55492138863,42057.04630446434,3363.7155849933624,3.76453709602356,0.0 -95900,1.3901498,2.0770893,,,,,,,,,,,,,, -96000,1.3278334,3.4500494,,,,,,,,,,,,,, -96100,1.3407218,4.5474067,,,,,,,,,,,,,, -96200,1.4863595,2.0139782,,,,,,,,,,,,,, -96300,1.2797004,2.826554,,,,,,,,,,,,,, -96400,1.385847,1.7561363,,,,,,,,,,,,,, -96500,1.269106,3.3257992,,,,,,,,,,,,,, -96600,1.2994605,4.301977,,,,,,,,,,,,,, -96700,1.5277262,2.028708,,,,,,,,,,,,,, -96794,,,0.7479296922683716,0.9932562112808228,0.6821199655532837,1.2954468727111816,50000.0,0.5600000023841858,1.9353830814361568,10000.0,42476.98233127594,45888.71178412437,42476.98233127594,3402.8420236110687,3.809993982315064,0.0 -96800,1.4020452,2.0460494,,,,,,,,,,,,,, -96900,1.4495926,1.8777725,,,,,,,,,,,,,, -97000,1.4374589,4.169097,,,,,,,,,,,,,, -97100,1.430081,4.3247104,,,,,,,,,,,,,, -97200,1.4796575,2.166157,,,,,,,,,,,,,, -97300,1.4407072,1.9016094,,,,,,,,,,,,,, -97400,1.556575,1.9192688,,,,,,,,,,,,,, -97500,1.4491218,1.818677,,,,,,,,,,,,,, -97600,1.4146056,2.4328127,,,,,,,,,,,,,, -97700,1.6408888,1.9419947,,,,,,,,,,,,,, -97752,,,0.7708593606948853,0.9360008239746094,0.6833399534225464,1.3163715600967407,50000.0,0.5525000095367432,1.9546915292739868,10000.0,42897.07012176514,46350.79205417633,42897.07012176514,3444.739847421646,3.855117321014404,0.0 -97800,1.4595236,4.3274736,,,,,,,,,,,,,, -97900,1.2313507,3.0239122,,,,,,,,,,,,,, -98000,1.401357,3.0673816,,,,,,,,,,,,,, -98100,1.52924,2.0747492,,,,,,,,,,,,,, -98200,1.6570343,2.0267007,,,,,,,,,,,,,, -98300,1.4121972,2.4332933,,,,,,,,,,,,,, -98400,1.3405917,2.3507147,,,,,,,,,,,,,, -98500,1.3553988,3.78374,,,,,,,,,,,,,, -98600,1.3397275,3.3862114,,,,,,,,,,,,,, -98700,1.4603838,4.1811724,,,,,,,,,,,,,, -98707,,,0.7409374713897705,1.0260651111602783,0.6859999895095825,1.2778408527374268,50000.0,0.5577000379562378,1.93332839012146,10000.0,43317.06185436249,46805.97884583473,43317.06185436249,3479.8404698371887,3.900003433227539,0.0 -98800,1.2799319,3.862558,,,,,,,,,,,,,, -98900,1.4989455,1.9713236,,,,,,,,,,,,,, -99000,1.5178162,1.9726472,,,,,,,,,,,,,, -99100,1.5289537,2.7393994,,,,,,,,,,,,,, -99200,1.5196991,1.8244071,,,,,,,,,,,,,, -99300,1.3593065,1.9430736,,,,,,,,,,,,,, -99400,1.3588183,4.009842,,,,,,,,,,,,,, -99500,1.5094606,2.0470836,,,,,,,,,,,,,, -99600,1.4290162,1.9647859,,,,,,,,,,,,,, -99664,,,0.7491015195846558,0.9865267872810364,0.6865599751472473,1.276195764541626,50000.0,0.5599000453948975,1.931175708770752,10000.0,43737.05564260483,47262.28113818169,43737.05564260483,3516.056496620178,3.943291664123535,0.0 -99700,1.6275811,4.3567195,,,,,,,,,,,,,, -99800,1.3748251,2.6405413,,,,,,,,,,,,,, -99900,1.6089569,2.0568342,,,,,,,,,,,,,, -100000,1.516675,1.8375666,,,,,,,,,,,,,, -100100,1.5074222,2.1898718,,,,,,,,,,,,,, -100200,1.496179,1.9007177,,,,,,,,,,,,,, -100300,1.337886,3.1336074,,,,,,,,,,,,,, -100400,1.5269619,2.7515929,,,,,,,,,,,,,, -100500,1.5457621,1.958772,,,,,,,,,,,,,, -100600,1.3513696,3.151242,,,,,,,,,,,,,, -100623,,,0.7603319883346558,0.9452508091926576,0.6875,1.2674585580825806,50000.0,0.5639000535011292,1.9013869762420648,10000.0,44157.2749619484,47717.88592863083,44157.2749619484,3551.346296310425,3.98929500579834,0.0 -100700,1.4140211,2.498886,,,,,,,,,,,,,, -100800,1.3883497,3.1148493,,,,,,,,,,,,,, -100900,1.4109006,2.9554007,,,,,,,,,,,,,, -101000,1.4025669,3.6108563,,,,,,,,,,,,,, -101100,1.3625352,2.4314744,,,,,,,,,,,,,, -101200,1.5332187,1.9325535,,,,,,,,,,,,,, -101300,1.5378629,2.024626,,,,,,,,,,,,,, -101400,1.4402667,1.9780166,,,,,,,,,,,,,, -101500,1.3447222,2.993538,,,,,,,,,,,,,, -101584,,,0.7491992115974426,1.0066653490066528,0.6896399855613708,1.2784479856491089,50000.0,0.562000036239624,1.9139764308929443,10000.0,44577.33863568306,48174.25753450394,44577.33863568306,3587.5535452365875,4.040088415145874,0.0 -101600,1.3530487,2.5062811,,,,,,,,,,,,,, -101700,1.5245491,2.7447011,,,,,,,,,,,,,, -101800,1.4125924,3.565047,,,,,,,,,,,,,, -101900,1.4495792,3.1887405,,,,,,,,,,,,,, -102000,1.4573811,2.1447144,,,,,,,,,,,,,, -102100,1.3213681,4.4444075,,,,,,,,,,,,,, -102200,1.5644041,1.8876629,,,,,,,,,,,,,, -102300,1.4539857,4.0728326,,,,,,,,,,,,,, -102400,1.5103378,2.4426095,,,,,,,,,,,,,, -102500,1.7851276,1.9376404,,,,,,,,,,,,,, -102541,,,0.7575390338897705,0.9742467999458312,0.694920003414154,1.2482126951217651,50000.0,0.5687000155448914,1.890002012252808,10000.0,44997.26434326172,48626.50860714912,44997.26434326172,3619.7777602672577,4.091718912124634,0.0 -102600,1.4974053,1.9677085,,,,,,,,,,,,,, -102700,1.4826161,1.8169645,,,,,,,,,,,,,, -102800,1.4011036,4.301801,,,,,,,,,,,,,, -102900,1.4283549,3.898123,,,,,,,,,,,,,, -103000,1.4573612,2.5001714,,,,,,,,,,,,,, -103100,1.4442277,2.0303335,,,,,,,,,,,,,, -103200,1.5165737,2.0165174,,,,,,,,,,,,,, -103300,1.8840169,1.9710329,,,,,,,,,,,,,, -103400,1.5214767,2.102419,,,,,,,,,,,,,, -103500,,,0.7555859088897705,0.9570018649101256,0.6918999552726746,1.2508394718170166,50000.0,0.563800036907196,1.9133514165878296,10000.0,45417.35946679115,49082.983598947525,45417.35946679115,3656.063153505325,4.135879993438721,0.0 -103500,1.4537714,3.8093727,,,,,,,,,,,,,, -103600,1.5490252,1.8573788,,,,,,,,,,,,,, -103700,1.358192,2.7638042,,,,,,,,,,,,,, -103800,1.5982913,1.8033457,,,,,,,,,,,,,, -103900,1.432743,2.3152847,,,,,,,,,,,,,, -104000,1.4372854,2.609227,,,,,,,,,,,,,, -104100,1.4988633,2.0807962,,,,,,,,,,,,,, -104200,1.4639467,2.22229,,,,,,,,,,,,,, -104300,1.5423331,2.0859897,,,,,,,,,,,,,, -104400,1.6970533,1.8346587,,,,,,,,,,,,,, -104460,,,0.7704687118530273,0.9185051321983336,0.6934799551963806,1.2655454874038696,50000.0,0.5600000023841858,1.9212124347686768,10000.0,45837.52560162544,49539.26046657562,45837.52560162544,3692.076496124268,4.183363676071167,0.0 -104500,1.5851539,2.1795378,,,,,,,,,,,,,, -104600,1.5304396,3.4247465,,,,,,,,,,,,,, -104700,1.4795343,4.205735,,,,,,,,,,,,,, -104800,1.4076343,2.8008122,,,,,,,,,,,,,, -104900,1.6112757,2.2508423,,,,,,,,,,,,,, -105000,1.5226895,3.579279,,,,,,,,,,,,,, -105100,1.5777233,1.917427,,,,,,,,,,,,,, -105200,1.5303733,2.0120697,,,,,,,,,,,,,, -105300,1.3613734,3.8274388,,,,,,,,,,,,,, -105400,1.486309,1.7633475,,,,,,,,,,,,,, -105421,,,0.7509570121765137,0.986622154712677,0.6902799606323242,1.2532689571380615,50000.0,0.5672000050544739,1.8987038135528564,10000.0,46257.73142623901,49995.38002538681,46257.73142623901,3727.893661260605,4.22884464263916,0.0 -105500,1.4046798,4.2791305,,,,,,,,,,,,,, -105600,1.5941101,1.7279317,,,,,,,,,,,,,, -105700,1.5890027,2.002722,,,,,,,,,,,,,, -105800,1.5232475,1.9113314,,,,,,,,,,,,,, -105900,1.4842927,2.0951447,,,,,,,,,,,,,, -106000,1.4938872,1.7338479,,,,,,,,,,,,,, -106100,1.4746816,4.0813227,,,,,,,,,,,,,, -106200,1.4802469,2.0292294,,,,,,,,,,,,,, -106300,1.6498411,1.9579782,,,,,,,,,,,,,, -106380,,,0.7627343535423279,0.9351568222045898,0.6969799995422363,1.2286380529403689,50000.0,0.5681000351905823,1.871907353401184,10000.0,46677.66419768333,50461.02531313896,46677.66419768333,3773.502552270889,4.281575441360474,0.0 -106400,1.3833652,3.6420922,,,,,,,,,,,,,, -106500,1.5841408,1.8601171,,,,,,,,,,,,,, -106600,1.6294446,1.9187292,,,,,,,,,,,,,, -106700,1.5807486,1.8012195,,,,,,,,,,,,,, -106800,1.5624474,2.0981302,,,,,,,,,,,,,, -106900,1.6383004,1.9185529,,,,,,,,,,,,,, -107000,1.4073976,3.1171668,,,,,,,,,,,,,, -107100,1.7039226,1.8854324,,,,,,,,,,,,,, -107200,1.5931473,2.0736117,,,,,,,,,,,,,, -107300,1.405696,2.9170787,,,,,,,,,,,,,, -107336,,,0.7720312476158142,0.901677131652832,0.6996600031852722,1.2177472114562988,50000.0,0.5746000409126282,1.867994785308838,10000.0,47097.86094522476,50920.30652046204,47097.86094522476,3812.488776922226,4.329556226730347,0.0 -107400,1.8358014,1.9313964,,,,,,,,,,,,,, -107500,1.6680032,1.9927073,,,,,,,,,,,,,, -107600,1.4989358,3.1598458,,,,,,,,,,,,,, -107700,1.4846994,2.154459,,,,,,,,,,,,,, -107800,1.6121163,1.8728074,,,,,,,,,,,,,, -107900,1.4980811,1.9705669,,,,,,,,,,,,,, -108000,1.661718,1.7867094,,,,,,,,,,,,,, -108100,1.5014831,3.2692866,,,,,,,,,,,,,, -108200,1.586252,4.3681135,,,,,,,,,,,,,, -108291,,,0.7859960794448853,0.8419576287269592,0.6972599625587463,1.2124354839324951,50000.0,0.5742000341415405,1.8423296213150024,10000.0,47518.158441782,51385.118559122086,47518.158441782,3856.903764724731,4.379816055297852,0.0 -108300,1.542266,4.2197356,,,,,,,,,,,,,, -108400,1.5668972,3.2609499,,,,,,,,,,,,,, -108500,1.605294,1.878566,,,,,,,,,,,,,, -108600,1.4772953,4.3558784,,,,,,,,,,,,,, -108700,1.5111951,3.20259,,,,,,,,,,,,,, -108800,1.4514716,3.3606906,,,,,,,,,,,,,, -108900,1.625217,1.9654955,,,,,,,,,,,,,, -109000,1.6442988,1.8745222,,,,,,,,,,,,,, -109100,1.4468572,3.55252,,,,,,,,,,,,,, -109200,1.4356141,3.5954626,,,,,,,,,,,,,, -109246,,,0.7622460722923279,0.9467803835868835,0.6993599534034729,1.227556586265564,50000.0,0.5749000310897827,1.869747757911682,10000.0,47938.24389958382,51849.41720533371,47938.24389958382,3901.020851612091,4.426162004470825,0.0 -109300,1.7059554,1.8269331,,,,,,,,,,,,,, -109400,1.7380922,1.7585812,,,,,,,,,,,,,, -109500,1.5803708,1.8351853,,,,,,,,,,,,,, -109600,1.5556546,1.9492291,,,,,,,,,,,,,, -109700,1.4246348,1.8690135,,,,,,,,,,,,,, -109800,1.7928978,1.8567971,,,,,,,,,,,,,, -109900,1.6793381,1.957928,,,,,,,,,,,,,, -110000,1.7201002,1.7065449,,,,,,,,,,,,,, -110100,1.7228376,1.8752092,,,,,,,,,,,,,, -110200,1.5953246,2.231229,,,,,,,,,,,,,, -110201,,,0.7684179544448853,0.9048970341682434,0.7038999795913696,1.196832299232483,50000.0,0.5772000551223755,1.8481879234313965,10000.0,48358.40622162819,52310.96383190155,48358.40622162819,3942.307982444763,4.472722053527832,0.0 -110300,1.5103569,4.154955,,,,,,,,,,,,,, -110400,1.7709663,1.8778746,,,,,,,,,,,,,, -110500,1.5959142,1.8656391,,,,,,,,,,,,,, -110600,1.4209476,2.6054173,,,,,,,,,,,,,, -110700,1.5906404,2.068546,,,,,,,,,,,,,, -110800,1.6396677,2.084935,,,,,,,,,,,,,, -110900,1.510772,2.2431252,,,,,,,,,,,,,, -111000,1.5505203,1.8284985,,,,,,,,,,,,,, -111100,1.6264592,1.8405159,,,,,,,,,,,,,, -111158,,,0.7777343392372131,0.8578104972839355,0.7034199833869934,1.197596788406372,50000.0,0.5800999999046326,1.8331900835037231,10000.0,48778.63722419739,52773.36154890061,48778.63722419739,3984.3732771873474,4.523653030395508,0.0 -111200,1.3920006,2.3522124,,,,,,,,,,,,,, -111300,1.4197186,2.6520762,,,,,,,,,,,,,, -111400,1.7116574,4.286585,,,,,,,,,,,,,, -111500,1.6544629,1.7234645,,,,,,,,,,,,,, -111600,1.7760452,1.7860601,,,,,,,,,,,,,, -111700,1.7357265,2.2974272,,,,,,,,,,,,,, -111800,1.6678678,4.2720137,,,,,,,,,,,,,, -111900,1.7463573,1.9264662,,,,,,,,,,,,,, -112000,1.4925499,2.1893158,,,,,,,,,,,,,, -112100,1.727202,1.7684304,,,,,,,,,,,,,, -112113,,,0.7675976157188416,0.9133527278900146,0.7023599743843079,1.210939884185791,50000.0,0.5842000246047974,1.8372535705566408,10000.0,49198.80200004578,53233.61412191391,49198.80200004578,4024.359180688858,4.575310945510864,0.0 -112200,1.6065787,2.5818255,,,,,,,,,,,,,, -112300,1.7155749,1.8825886,,,,,,,,,,,,,, -112400,1.6974912,1.849365,,,,,,,,,,,,,, -112500,1.5723885,1.8328513,,,,,,,,,,,,,, -112600,1.7003771,1.7292961,,,,,,,,,,,,,, -112700,1.4672732,3.2207327,,,,,,,,,,,,,, -112800,1.872654,1.9384536,,,,,,,,,,,,,, -112900,1.656717,1.7389246,,,,,,,,,,,,,, -113000,1.6962738,1.8061051,,,,,,,,,,,,,, -113068,,,0.76917964220047,0.9122159481048584,0.703819990158081,1.2008824348449707,50000.0,0.5800000429153442,1.843191146850586,10000.0,49619.058660030365,53694.006588459015,49619.058660030365,4064.391876220703,4.628418207168579,0.0 -113100,1.3989006,2.884606,,,,,,,,,,,,,, -113200,1.5337982,2.6939921,,,,,,,,,,,,,, -113300,1.6806297,1.8922076,,,,,,,,,,,,,, -113400,1.5518798,3.34115,,,,,,,,,,,,,, -113500,1.6676526,3.1904867,,,,,,,,,,,,,, -113600,1.5354098,3.866717,,,,,,,,,,,,,, -113700,1.5577286,1.7243758,,,,,,,,,,,,,, -113800,1.6318735,2.3591309,,,,,,,,,,,,,, -113900,1.7096587,2.0591407,,,,,,,,,,,,,, -114000,1.6786107,2.9357355,,,,,,,,,,,,,, -114020,,,0.77685546875,0.8710706830024719,0.707539975643158,1.1852104663848877,50000.0,0.5845000147819519,1.8225212097167969,10000.0,50039.012590408325,54156.05561089516,50039.012590408325,4106.380417108536,4.684727907180786,0.0 -114100,1.804361,1.8710636,,,,,,,,,,,,,, -114200,1.5603415,2.7004461,,,,,,,,,,,,,, -114300,1.6373587,1.7397904,,,,,,,,,,,,,, -114400,1.5108136,3.3731635,,,,,,,,,,,,,, -114500,1.830076,4.286707,,,,,,,,,,,,,, -114600,1.6844302,1.8024778,,,,,,,,,,,,,, -114700,1.5964574,1.8641713,,,,,,,,,,,,,, -114800,1.5124265,3.3029718,,,,,,,,,,,,,, -114900,1.5540444,4.2331295,,,,,,,,,,,,,, -114978,,,0.7953710556030273,0.8017292022705078,0.7105000019073486,1.1759315729141235,50000.0,0.5812000036239624,1.831921935081482,10000.0,50459.272194862366,54612.63543081284,50459.272194862366,4142.604337930679,4.731132507324219,0.0 -115000,1.7034342,4.135503,,,,,,,,,,,,,, -115100,1.8656961,1.773,,,,,,,,,,,,,, -115200,1.6112164,1.7954891,,,,,,,,,,,,,, -115300,1.7839223,1.799248,,,,,,,,,,,,,, -115400,1.5710589,1.6781542,,,,,,,,,,,,,, -115500,1.7278131,1.7835824,,,,,,,,,,,,,, -115600,1.5795641,1.8065631,,,,,,,,,,,,,, -115700,1.86126,4.1214614,,,,,,,,,,,,,, -115800,1.6759596,1.8453043,,,,,,,,,,,,,, -115900,1.6678383,3.9514096,,,,,,,,,,,,,, -115933,,,0.775097668170929,0.878689169883728,0.7120400071144104,1.159434199333191,50000.0,0.5839000344276428,1.8186765909194944,10000.0,50879.5531334877,55079.03768348694,50879.5531334877,4188.624214410782,4.783179759979248,0.0 -116000,1.9332713,1.9361149,,,,,,,,,,,,,, -116100,1.828589,1.7324448,,,,,,,,,,,,,, -116200,1.6948078,1.7192936,,,,,,,,,,,,,, -116300,1.4855111,2.7366562,,,,,,,,,,,,,, -116400,1.5504713,2.7122972,,,,,,,,,,,,,, -116500,1.7495035,1.7557254,,,,,,,,,,,,,, -116600,1.7319639,1.8141255,,,,,,,,,,,,,, -116700,1.6307648,1.676708,,,,,,,,,,,,,, -116800,1.6776328,2.031585,,,,,,,,,,,,,, -116889,,,0.7811132669448853,0.8491606712341309,0.712119996547699,1.1562445163726809,50000.0,0.5835000276565552,1.794126033782959,10000.0,51299.98179316521,55542.57398247719,51299.98179316521,4231.638665437698,4.826200723648071,0.0 -116900,1.781568,1.8220925,,,,,,,,,,,,,, -117000,1.9948254,1.757308,,,,,,,,,,,,,, -117100,1.733175,4.1333485,,,,,,,,,,,,,, -117200,1.7204328,3.0116339,,,,,,,,,,,,,, -117300,1.7640285,3.1426806,,,,,,,,,,,,,, -117400,1.6807705,1.7734193,,,,,,,,,,,,,, -117500,1.687845,1.7024794,,,,,,,,,,,,,, -117600,1.6376143,2.104639,,,,,,,,,,,,,, -117700,1.8151038,1.8030431,,,,,,,,,,,,,, -117800,1.613324,3.1626172,,,,,,,,,,,,,, -117847,,,0.78822261095047,0.8326475620269775,0.710919976234436,1.1654090881347656,50000.0,0.5782000422477722,1.8094024658203125,10000.0,51720.19587230682,56000.84201049805,51720.19587230682,4269.591933727264,4.876907825469971,0.0 -117900,1.7952353,1.7385348,,,,,,,,,,,,,, -118000,1.6646175,1.763938,,,,,,,,,,,,,, -118100,1.8289397,1.7434429,,,,,,,,,,,,,, -118200,1.6517185,1.7868408,,,,,,,,,,,,,, -118300,1.5466182,2.370915,,,,,,,,,,,,,, -118400,1.9091187,1.7712789,,,,,,,,,,,,,, -118500,1.7293522,4.1436014,,,,,,,,,,,,,, -118600,1.7381555,3.4690022,,,,,,,,,,,,,, -118700,1.8019211,1.7343292,,,,,,,,,,,,,, -118800,1.7501829,1.6736546,,,,,,,,,,,,,, -118805,,,0.7870702743530273,0.817166268825531,0.7132399678230286,1.1419477462768557,50000.0,0.5823000073432922,1.799047350883484,10000.0,52140.52798819542,56461.5328617096,52140.52798819542,4309.85601067543,4.921182155609131,0.0 -118900,1.8171753,2.0051174,,,,,,,,,,,,,, -119000,1.6194564,2.2070112,,,,,,,,,,,,,, -119100,1.8757777,1.7460203,,,,,,,,,,,,,, -119200,1.8902184,4.048197,,,,,,,,,,,,,, -119300,1.4777396,3.2160249,,,,,,,,,,,,,, -119400,1.8545007,1.7831751,,,,,,,,,,,,,, -119500,1.6438433,1.9015479,,,,,,,,,,,,,, -119600,1.7472141,3.4094288,,,,,,,,,,,,,, -119700,1.6365399,3.866115,,,,,,,,,,,,,, -119764,,,0.7832226157188416,0.831150472164154,0.7125399708747864,1.1509326696395874,50000.0,0.5859000086784363,1.8054711818695068,10000.0,52560.54592466354,56921.385232687,52560.54592466354,4349.591715335846,4.970265626907349,0.0 -119800,1.7142341,2.7219124,,,,,,,,,,,,,, -119900,1.6935623,3.8170605,,,,,,,,,,,,,, -120000,1.7755631,3.7967658,,,,,,,,,,,,,, -120100,1.5565889,2.5418882,,,,,,,,,,,,,, -120200,1.597063,3.3753414,,,,,,,,,,,,,, -120300,1.7305982,2.954452,,,,,,,,,,,,,, -120400,1.5014186,2.5675406,,,,,,,,,,,,,, -120500,1.9623317,3.6944916,,,,,,,,,,,,,, -120600,1.8499209,1.824844,,,,,,,,,,,,,, -120700,1.8990319,1.7508814,,,,,,,,,,,,,, -120716,,,0.7913867235183716,0.8026143312454224,0.7174599766731262,1.1281665563583374,50000.0,0.5871000289916992,1.7703475952148438,10000.0,52980.62607479096,57382.91956567764,52980.62607479096,4390.941907405853,5.0248682498931885,0.0 -120800,1.6879972,2.2524626,,,,,,,,,,,,,, -120900,1.6514976,2.631507,,,,,,,,,,,,,, -121000,1.661693,1.7503797,,,,,,,,,,,,,, -121100,1.6289666,2.3361554,,,,,,,,,,,,,, -121200,1.7952139,1.7590436,,,,,,,,,,,,,, -121300,1.7463979,2.6328537,,,,,,,,,,,,,, -121400,1.6576984,1.750124,,,,,,,,,,,,,, -121500,1.7267069,1.7324712,,,,,,,,,,,,,, -121600,1.768995,1.8220131,,,,,,,,,,,,,, -121672,,,0.7996289134025574,0.7926232814788818,0.7188000082969666,1.1456674337387085,50000.0,0.5902000069618225,1.784095048904419,10000.0,53400.83889579773,57844.45234084129,53400.83889579773,4432.157784461975,5.077365398406982,0.0 -121700,1.8482813,1.7910092,,,,,,,,,,,,,, -121800,1.6751088,3.214961,,,,,,,,,,,,,, -121900,1.8178129,1.7027286,,,,,,,,,,,,,, -122000,1.7115037,3.2912536,,,,,,,,,,,,,, -122100,1.7075077,2.842921,,,,,,,,,,,,,, -122200,1.5792918,2.2658105,,,,,,,,,,,,,, -122300,1.5918313,2.5770035,,,,,,,,,,,,,, -122400,2.0845366,4.244634,,,,,,,,,,,,,, -122500,1.6981165,3.5734372,,,,,,,,,,,,,, -122600,1.7591268,3.2254262,,,,,,,,,,,,,, -122625,,,0.7887304425239563,0.8185681104660034,0.7198799848556519,1.1269749402999878,50000.0,0.5949000120162964,1.749130368232727,10000.0,53820.74506020546,58302.89676403999,53820.74506020546,4470.588375091553,5.134720325469971,0.0 -122700,1.5924314,2.4603329,,,,,,,,,,,,,, -122800,1.5492076,2.6990848,,,,,,,,,,,,,, -122900,1.7620456,1.6271678,,,,,,,,,,,,,, -123000,1.7565404,1.8206033,,,,,,,,,,,,,, -123100,1.8627002,1.6146768,,,,,,,,,,,,,, -123200,1.7494478,2.093778,,,,,,,,,,,,,, -123300,1.6783338,1.9777046,,,,,,,,,,,,,, -123400,1.8195596,1.7545831,,,,,,,,,,,,,, -123500,1.8846154,4.1568894,,,,,,,,,,,,,, -123580,,,0.7926952838897705,0.8255926966667175,0.7219199538230896,1.1399753093719482,50000.0,0.598300039768219,1.7739555835723877,10000.0,54240.97389984131,58764.0489218235,54240.97389984131,4511.416751623154,5.179789066314697,0.0 -123600,1.9576229,1.6898528,,,,,,,,,,,,,, -123700,1.5972,1.9428537,,,,,,,,,,,,,, -123800,1.6407006,2.0289376,,,,,,,,,,,,,, -123900,1.994756,1.6532279,,,,,,,,,,,,,, -124000,1.802893,1.6119272,,,,,,,,,,,,,, -124100,1.8384739,1.7401822,,,,,,,,,,,,,, -124200,1.9024496,4.027767,,,,,,,,,,,,,, -124300,1.6266776,2.9532733,,,,,,,,,,,,,, -124400,2.0449235,1.5681057,,,,,,,,,,,,,, -124500,1.6504097,2.3368316,,,,,,,,,,,,,, -124534,,,0.7994140386581421,0.7840669751167297,0.7232399582862854,1.110971450805664,50000.0,0.5958999991416931,1.755924940109253,10000.0,54660.88675928116,59221.33545231819,54660.88675928116,4548.689545869827,5.231285095214844,0.0 -124600,1.6520845,2.5693886,,,,,,,,,,,,,, -124700,1.8258429,1.8089476,,,,,,,,,,,,,, -124800,1.7199533,1.6398109,,,,,,,,,,,,,, -124900,1.596286,2.348102,,,,,,,,,,,,,, -125000,1.9345324,1.9342896,,,,,,,,,,,,,, -125100,1.883489,1.8118551,,,,,,,,,,,,,, -125200,1.8326591,1.8360101,,,,,,,,,,,,,, -125300,1.8133616,3.1086454,,,,,,,,,,,,,, -125400,1.6324605,2.840488,,,,,,,,,,,,,, -125490,,,0.8218359351158142,0.6901920437812805,0.7267999649047852,1.1000694036483765,50000.0,0.6070000529289246,1.718909502029419,10000.0,55080.96838617325,59680.13125133514,55080.96838617325,4587.305732250214,5.2786900997161865,0.0 -125500,1.7728257,1.7631108,,,,,,,,,,,,,, -125600,1.743321,2.1331499,,,,,,,,,,,,,, -125700,1.8935764,1.5974512,,,,,,,,,,,,,, -125800,1.7561998,1.5390512,,,,,,,,,,,,,, -125900,1.7996018,1.7740426,,,,,,,,,,,,,, -126000,1.751025,2.0011415,,,,,,,,,,,,,, -126100,1.8136321,1.641078,,,,,,,,,,,,,, -126200,1.7643522,1.9014862,,,,,,,,,,,,,, -126300,1.6236383,3.0200317,,,,,,,,,,,,,, -126400,1.8610795,1.6707253,,,,,,,,,,,,,, -126445,,,0.79296875,0.8103262782096863,0.7235199809074402,1.1214704513549805,50000.0,0.5904000401496887,1.757405400276184,10000.0,55501.05866575241,60144.09332036972,55501.05866575241,4631.069079637528,5.3379082679748535,0.0 -126500,1.9045718,1.571725,,,,,,,,,,,,,, -126600,1.8869117,1.6735783,,,,,,,,,,,,,, -126700,1.8815848,1.7925802,,,,,,,,,,,,,, -126800,1.904708,1.6623743,,,,,,,,,,,,,, -126900,2.0569644,4.0301256,,,,,,,,,,,,,, -127000,2.1924028,3.8639386,,,,,,,,,,,,,, -127100,1.8392996,2.1456227,,,,,,,,,,,,,, -127200,1.927675,1.5735946,,,,,,,,,,,,,, -127300,1.9329439,1.8986079,,,,,,,,,,,,,, -127400,1.9554759,1.8948386,,,,,,,,,,,,,, -127403,,,0.80140620470047,0.7748098373413086,0.7250999808311462,1.1140674352645874,50000.0,0.600100040435791,1.7393742799758911,10000.0,55921.156155347824,60609.03382444382,55921.156155347824,4675.814074754715,5.386435270309448,0.0 -127500,2.0090926,4.104971,,,,,,,,,,,,,, -127600,1.8641676,1.6306244,,,,,,,,,,,,,, -127700,1.7968597,1.8550645,,,,,,,,,,,,,, -127800,1.6974993,1.7553797,,,,,,,,,,,,,, -127900,1.7049379,2.3408852,,,,,,,,,,,,,, -128000,1.8534629,2.265924,,,,,,,,,,,,,, -128100,1.9899832,1.6752219,,,,,,,,,,,,,, -128200,1.9646237,3.6650906,,,,,,,,,,,,,, -128300,1.9069738,1.7169569,,,,,,,,,,,,,, -128362,,,0.808886706829071,0.7777705192565918,0.7276399731636047,1.1254874467849731,50000.0,0.6046000123023987,1.7480357885360718,10000.0,56341.33557915688,61074.16904568672,56341.33557915688,4720.664792776108,5.441857099533081,0.0 -128400,1.8417609,2.3359861,,,,,,,,,,,,,, -128500,1.8315358,3.086624,,,,,,,,,,,,,, -128600,1.9629072,1.892456,,,,,,,,,,,,,, -128700,2.0668764,1.5176377,,,,,,,,,,,,,, -128800,1.7597014,2.3521752,,,,,,,,,,,,,, -128900,1.9559654,1.733309,,,,,,,,,,,,,, -129000,1.8532745,1.5413067,,,,,,,,,,,,,, -129100,1.9868156,3.7159958,,,,,,,,,,,,,, -129200,1.9069444,1.6222879,,,,,,,,,,,,,, -129300,1.8377714,1.7520325,,,,,,,,,,,,,, -129319,,,0.8020898103713989,0.7728776335716248,0.7276600003242493,1.092849850654602,50000.0,0.6008000373840332,1.7313231229782104,10000.0,56761.44100022316,61540.09782385826,56761.44100022316,4766.389065265656,5.490319013595581,0.0 -129400,1.7690715,1.521528,,,,,,,,,,,,,, -129500,1.7837087,2.0462964,,,,,,,,,,,,,, -129600,1.9783018,3.7500434,,,,,,,,,,,,,, -129700,1.7512755,1.7051612,,,,,,,,,,,,,, -129800,1.8595598,1.920472,,,,,,,,,,,,,, -129900,1.6917332,2.5714877,,,,,,,,,,,,,, -130000,1.680911,3.3207393,,,,,,,,,,,,,, -130100,1.8053435,1.5822718,,,,,,,,,,,,,, -130200,1.8782609,1.54679,,,,,,,,,,,,,, -130272,,,0.8078320026397705,0.7349570989608765,0.7328199744224548,1.0716028213500977,50000.0,0.6047000288963318,1.703953981399536,10000.0,57181.441122055054,62003.752519369125,57181.441122055054,4809.945647716522,5.539454460144043,0.0 -130300,2.0061936,1.5967482,,,,,,,,,,,,,, -130400,2.0367773,1.626516,,,,,,,,,,,,,, -130500,1.9286011,2.5023775,,,,,,,,,,,,,, -130600,2.0195444,1.6007669,,,,,,,,,,,,,, -130700,1.9926547,3.791287,,,,,,,,,,,,,, -130800,1.8903759,1.8402809,,,,,,,,,,,,,, -130900,1.800514,1.8137863,,,,,,,,,,,,,, -131000,2.0090065,1.5411673,,,,,,,,,,,,,, -131100,2.0484736,1.7831476,,,,,,,,,,,,,, -131200,1.9274274,1.574146,,,,,,,,,,,,,, -131223,,,0.81068354845047,0.7343043088912964,0.734499990940094,1.0711177587509155,50000.0,0.6057000160217285,1.6870981454849243,10000.0,57601.39503288269,62466.73106837273,57601.39503288269,4852.868507862091,5.591620922088623,0.0 -131300,1.8216988,3.4763186,,,,,,,,,,,,,, -131400,2.0312574,2.0284295,,,,,,,,,,,,,, -131500,2.0056438,1.6243782,,,,,,,,,,,,,, -131600,1.6972312,2.5468411,,,,,,,,,,,,,, -131700,1.9269582,1.5615088,,,,,,,,,,,,,, -131800,2.0768988,4.0373545,,,,,,,,,,,,,, -131900,2.0172944,1.5698351,,,,,,,,,,,,,, -132000,2.073368,3.815538,,,,,,,,,,,,,, -132100,2.032612,3.9096336,,,,,,,,,,,,,, -132177,,,0.8233007788658142,0.6953759789466858,0.7353000044822693,1.072637915611267,50000.0,0.6094000339508057,1.702414870262146,10000.0,58021.39619445801,62926.28131008148,58021.39619445801,4892.319691181183,5.639818906784058,0.0 -132200,2.142835,3.7554226,,,,,,,,,,,,,, -132300,2.101913,1.5442669,,,,,,,,,,,,,, -132400,1.8321522,1.5318342,,,,,,,,,,,,,, -132500,2.070024,3.8360028,,,,,,,,,,,,,, -132600,2.0579162,1.5739013,,,,,,,,,,,,,, -132700,1.8083066,2.57792,,,,,,,,,,,,,, -132800,1.9299312,1.8381643,,,,,,,,,,,,,, -132900,2.0046306,3.872737,,,,,,,,,,,,,, -133000,2.0495584,1.5883967,,,,,,,,,,,,,, -133100,1.942715,1.6167732,,,,,,,,,,,,,, -133130,,,0.8082226514816284,0.7451169490814209,0.7351199984550476,1.071921944618225,50000.0,0.6080000400543213,1.6945815086364746,10000.0,58441.57426691055,63391.85244774818,58441.57426691055,4937.602854728699,5.699962615966797,0.0 -133200,1.9078652,1.6378669,,,,,,,,,,,,,, -133300,1.8210069,3.5971105,,,,,,,,,,,,,, -133400,1.9628148,1.8939435,,,,,,,,,,,,,, -133500,2.0110457,3.7928042,,,,,,,,,,,,,, -133600,1.7682223,2.3528473,,,,,,,,,,,,,, -133700,2.1136768,1.5451375,,,,,,,,,,,,,, -133800,2.0892544,3.971651,,,,,,,,,,,,,, -133900,2.1159682,1.5528891,,,,,,,,,,,,,, -134000,2.021564,3.9134576,,,,,,,,,,,,,, -134083,,,0.8153125047683716,0.7177433371543884,0.7369799613952637,1.052171230316162,50000.0,0.6075000166893005,1.6868488788604736,10000.0,58861.696504592896,63857.983344078064,58861.696504592896,4983.517159461975,5.74588418006897,0.0 -134100,1.8913296,1.9461058,,,,,,,,,,,,,, -134200,2.0396488,2.7118244,,,,,,,,,,,,,, -134300,2.1664414,1.4927671,,,,,,,,,,,,,, -134400,1.8779758,2.4219482,,,,,,,,,,,,,, -134500,1.9528002,3.4093795,,,,,,,,,,,,,, -134600,1.9399945,1.6662042,,,,,,,,,,,,,, -134700,1.8315458,1.7285154,,,,,,,,,,,,,, -134800,1.9878877,1.5672762,,,,,,,,,,,,,, -134900,2.280529,3.952086,,,,,,,,,,,,,, -135000,1.9556068,1.481242,,,,,,,,,,,,,, -135039,,,0.8215429782867432,0.6732996702194214,0.7372199892997742,1.0496302843093872,50000.0,0.6113000512123108,1.669729232788086,10000.0,59281.61557817459,64322.52414488792,59281.61557817459,5028.039650678635,5.795482873916626,0.0 -135100,2.0730999,1.6018608,,,,,,,,,,,,,, -135200,2.2342935,3.9354396,,,,,,,,,,,,,, -135300,2.2005389,3.4215508,,,,,,,,,,,,,, -135400,1.9645376,1.5886285,,,,,,,,,,,,,, -135500,1.9224654,2.399817,,,,,,,,,,,,,, -135600,1.9766726,3.282778,,,,,,,,,,,,,, -135700,1.9563222,1.7909683,,,,,,,,,,,,,, -135800,2.2059712,3.7720172,,,,,,,,,,,,,, -135900,1.8377552,1.7555002,,,,,,,,,,,,,, -135996,,,0.8267382383346558,0.6621195673942566,0.7412199974060059,1.0311529636383057,50000.0,0.6158000230789185,1.6552666425704956,10000.0,59701.6204688549,64786.15665626526,59701.6204688549,5071.571031808853,5.841874361038208,0.0 -136000,2.3031304,3.9113798,,,,,,,,,,,,,, -136100,2.0649586,1.5770273,,,,,,,,,,,,,, -136200,1.9577714,2.8469236,,,,,,,,,,,,,, -136300,1.8715913,1.4256094,,,,,,,,,,,,,, -136400,2.271355,3.6106255,,,,,,,,,,,,,, -136500,2.067631,1.6019126,,,,,,,,,,,,,, -136600,1.8910638,2.0061975,,,,,,,,,,,,,, -136700,1.752321,2.0113857,,,,,,,,,,,,,, -136800,2.1329277,1.8020835,,,,,,,,,,,,,, -136900,2.1036248,2.9702108,,,,,,,,,,,,,, -136949,,,0.81689453125,0.7047690153121948,0.7401599884033203,1.0472062826156616,50000.0,0.6136000156402588,1.6724987030029297,10000.0,60121.56489992142,65252.8662545681,60121.56489992142,5118.237744569778,5.8914642333984375,0.0 -137000,1.9997419,1.5083814,,,,,,,,,,,,,, -137100,1.9491597,2.003998,,,,,,,,,,,,,, -137200,2.0592172,3.4895508,,,,,,,,,,,,,, -137300,2.072222,2.993202,,,,,,,,,,,,,, -137400,1.9928992,2.5702229,,,,,,,,,,,,,, -137500,1.9569206,1.820081,,,,,,,,,,,,,, -137600,2.3111534,1.5874932,,,,,,,,,,,,,, -137700,2.098385,2.255365,,,,,,,,,,,,,, -137800,2.145639,1.4773393,,,,,,,,,,,,,, -137900,2.0138032,1.4780173,,,,,,,,,,,,,, -137905,,,0.8246093392372131,0.6600923538208008,0.7413600087165833,1.0288509130477903,50000.0,0.6154000163078308,1.663904905319214,10000.0,60541.76283168793,65710.68705844879,60541.76283168793,5155.756701231003,5.945829391479492,0.0 -138000,2.0154629,2.6099205,,,,,,,,,,,,,, -138100,1.8835726,3.3358045,,,,,,,,,,,,,, -138200,2.0498457,1.5733663,,,,,,,,,,,,,, -138300,2.084805,1.6728706,,,,,,,,,,,,,, -138400,2.1404953,1.8141202,,,,,,,,,,,,,, -138500,2.0420747,1.8189062,,,,,,,,,,,,,, -138600,2.0776339,1.626832,,,,,,,,,,,,,, -138700,2.027259,1.5714025,,,,,,,,,,,,,, -138800,2.0333867,1.4783027,,,,,,,,,,,,,, -138860,,,0.8321484327316284,0.6528514623641968,0.7429599761962891,1.038433313369751,50000.0,0.6155000329017639,1.6657668352127075,10000.0,60962.057092905045,66171.20323944092,60962.057092905045,5195.878809213638,5.996556520462036,0.0 -138900,2.0377479,1.631638,,,,,,,,,,,,,, -139000,2.0514598,1.4218894,,,,,,,,,,,,,, -139100,2.147542,3.441763,,,,,,,,,,,,,, -139200,1.9423279,2.3358388,,,,,,,,,,,,,, -139300,1.990847,2.472457,,,,,,,,,,,,,, -139400,2.2151625,1.4304861,,,,,,,,,,,,,, -139500,2.149748,1.6593754,,,,,,,,,,,,,, -139600,2.2095425,3.4028132,,,,,,,,,,,,,, -139700,2.130948,1.6357603,,,,,,,,,,,,,, -139800,2.0987704,1.4993987,,,,,,,,,,,,,, -139816,,,0.8234570026397705,0.6648285388946533,0.7445200085639954,1.0157908201217651,50000.0,0.6185000538825989,1.638699293136597,10000.0,61382.15120720863,66633.87408804893,61382.15120720863,5238.357710123062,6.045234203338623,0.0 -139900,2.0179114,1.5326579,,,,,,,,,,,,,, -140000,2.0811396,2.9887393,,,,,,,,,,,,,, -140100,2.0268805,2.44455,,,,,,,,,,,,,, -140200,2.3324075,3.4113538,,,,,,,,,,,,,, -140300,2.1520429,1.880096,,,,,,,,,,,,,, -140400,2.1303897,1.6858397,,,,,,,,,,,,,, -140500,2.1870027,1.5039327,,,,,,,,,,,,,, -140600,2.3035219,1.4938827,,,,,,,,,,,,,, -140700,2.4072206,1.5780215,,,,,,,,,,,,,, -140773,,,0.8301171660423279,0.6384050846099854,0.7467399835586548,1.0131675004959106,50000.0,0.6203000545501709,1.6408082246780396,10000.0,61802.4379234314,67092.95656824112,61802.4379234314,5277.053035020828,6.096495151519775,0.0 -140800,2.1816895,1.4869214,,,,,,,,,,,,,, -140900,2.303734,3.7317216,,,,,,,,,,,,,, -141000,2.1344655,1.7070563,,,,,,,,,,,,,, -141100,2.0252752,1.627423,,,,,,,,,,,,,, -141200,2.0197027,2.2461073,,,,,,,,,,,,,, -141300,1.9913694,2.407028,,,,,,,,,,,,,, -141400,2.1797402,1.6066921,,,,,,,,,,,,,, -141500,1.9919463,1.3915116,,,,,,,,,,,,,, -141600,2.1672127,1.5093513,,,,,,,,,,,,,, -141700,2.010712,1.8688217,,,,,,,,,,,,,, -141729,,,0.8341015577316284,0.6274772882461548,0.7470999956130981,1.0052226781845093,50000.0,0.6189000010490417,1.641674518585205,10000.0,62222.53354215622,67556.64731407166,62222.53354215622,5320.5501408576965,6.1448814868927,0.0 -141800,2.2029133,3.242278,,,,,,,,,,,,,, -141900,2.3883493,3.6100605,,,,,,,,,,,,,, -142000,2.247284,1.5392861,,,,,,,,,,,,,, -142100,2.11506,1.4771988,,,,,,,,,,,,,, -142200,1.9543597,2.1278796,,,,,,,,,,,,,, -142300,2.116654,2.9992402,,,,,,,,,,,,,, -142400,2.1068947,2.9323802,,,,,,,,,,,,,, -142500,2.3508666,3.329239,,,,,,,,,,,,,, -142600,2.309082,1.3714539,,,,,,,,,,,,,, -142683,,,0.8469530940055847,0.5803266167640686,0.749239981174469,0.993854284286499,50000.0,0.6233000159263611,1.6336963176727295,10000.0,62642.4606757164,68022.24512600899,62642.4606757164,5366.112824201584,6.20355749130249,0.0 -142700,2.882976,3.7410784,,,,,,,,,,,,,, -142800,2.1384723,3.0751162,,,,,,,,,,,,,, -142900,2.1253688,1.630659,,,,,,,,,,,,,, -143000,2.3872762,3.6645043,,,,,,,,,,,,,, -143100,2.0640423,2.0011723,,,,,,,,,,,,,, -143200,2.1539257,2.6112995,,,,,,,,,,,,,, -143300,2.1708891,1.5215867,,,,,,,,,,,,,, -143400,2.222057,1.573273,,,,,,,,,,,,,, -143500,2.4638317,3.427557,,,,,,,,,,,,,, -143600,2.143902,3.3994677,,,,,,,,,,,,,, -143641,,,0.8289452791213989,0.6445606350898743,0.7488399744033813,0.999646782875061,50000.0,0.6241000294685364,1.6271873712539673,10000.0,63062.7713201046,68484.37150168419,63062.7713201046,5407.827532052994,6.254222393035889,0.0 -143700,2.0669725,1.4507304,,,,,,,,,,,,,, -143800,2.07934,1.548505,,,,,,,,,,,,,, -143900,2.2971547,1.9534347,,,,,,,,,,,,,, -144000,2.2597258,1.5240384,,,,,,,,,,,,,, -144100,2.1322188,1.3990833,,,,,,,,,,,,,, -144200,2.360902,1.5044386,,,,,,,,,,,,,, -144300,2.4133956,3.2470644,,,,,,,,,,,,,, -144400,2.3110852,1.3325486,,,,,,,,,,,,,, -144500,2.2360778,1.5453167,,,,,,,,,,,,,, -144596,,,0.8374804258346558,0.6231352686882019,0.7511799931526184,0.9905893206596376,50000.0,0.6246000528335571,1.6112269163131714,10000.0,63482.79140949249,68946.13994860649,63482.79140949249,5449.443675994873,6.337775707244873,0.0 -144600,2.30782,1.8826172,,,,,,,,,,,,,, -144700,2.3881428,3.511036,,,,,,,,,,,,,, -144800,2.3959956,1.4050896,,,,,,,,,,,,,, -144900,2.3084064,1.4217429,,,,,,,,,,,,,, -145000,2.365071,1.4365298,,,,,,,,,,,,,, -145100,2.1235168,1.5084356,,,,,,,,,,,,,, -145200,2.1347644,1.8031968,,,,,,,,,,,,,, -145300,2.4118862,1.4683907,,,,,,,,,,,,,, -145400,2.216575,2.8842304,,,,,,,,,,,,,, -145500,2.3511078,1.481437,,,,,,,,,,,,,, -145551,,,0.842578113079071,0.6011013984680176,0.7515400052070618,0.997151792049408,50000.0,0.6222000122070312,1.6284139156341553,10000.0,63903.018033504486,69413.98632764816,63903.018033504486,5496.952882766724,6.399603128433228,0.0 -145600,2.220283,3.0934355,,,,,,,,,,,,,, -145700,2.1344635,1.3826661,,,,,,,,,,,,,, -145800,2.2679796,1.5584598,,,,,,,,,,,,,, -145900,2.1691985,1.3959364,,,,,,,,,,,,,, -146000,2.572234,3.2224035,,,,,,,,,,,,,, -146100,2.1212568,1.464894,,,,,,,,,,,,,, -146200,2.3087811,1.619028,,,,,,,,,,,,,, -146300,2.2201972,1.8417172,,,,,,,,,,,,,, -146400,2.1943445,1.3712289,,,,,,,,,,,,,, -146500,2.2935436,3.3869236,,,,,,,,,,,,,, -146507,,,0.8387304544448853,0.6057491898536682,0.7545599937438965,0.975864589214325,50000.0,0.6258000135421753,1.6044663190841677,10000.0,64323.128088235855,69876.89319229126,64323.128088235855,5539.651907920837,6.44763708114624,0.0 -146600,2.267912,1.4799215,,,,,,,,,,,,,, -146700,2.1423814,3.2597594,,,,,,,,,,,,,, -146800,2.32282,2.0280616,,,,,,,,,,,,,, -146900,2.402969,1.5290203,,,,,,,,,,,,,, -147000,2.4436824,1.5609373,,,,,,,,,,,,,, -147100,2.3341331,2.0342653,,,,,,,,,,,,,, -147200,2.3239808,1.4737636,,,,,,,,,,,,,, -147300,2.6379611,3.6304066,,,,,,,,,,,,,, -147400,2.5666611,3.5146623,,,,,,,,,,,,,, -147462,,,0.8384374976158142,0.630463719367981,0.7534799575805664,0.9893783330917358,50000.0,0.6244000196456909,1.6166449785232544,10000.0,64743.03775429726,70333.96669006348,64743.03775429726,5576.707340240479,6.506728172302246,0.0 -147500,2.4280307,3.2363024,,,,,,,,,,,,,, -147600,2.4206185,1.574375,,,,,,,,,,,,,, -147700,2.449498,3.5146945,,,,,,,,,,,,,, -147800,2.2153835,2.7533534,,,,,,,,,,,,,, -147900,2.772617,1.5384047,,,,,,,,,,,,,, -148000,2.1858454,1.8083906,,,,,,,,,,,,,, -148100,2.7567284,2.8807356,,,,,,,,,,,,,, -148200,2.2080786,1.4357052,,,,,,,,,,,,,, -148300,2.2378154,2.0715916,,,,,,,,,,,,,, -148400,2.0681622,2.511636,,,,,,,,,,,,,, -148421,,,0.8456249833106995,0.5935382843017578,0.7538999915122986,0.978970229625702,50000.0,0.6320000290870667,1.5996073484420776,10000.0,65163.15673828125,70796.16014623642,65163.15673828125,5618.667296886444,6.567543268203735,0.0 -148500,2.3028607,1.4412105,,,,,,,,,,,,,, -148600,2.1344712,2.1574523,,,,,,,,,,,,,, -148700,2.7690008,3.7053416,,,,,,,,,,,,,, -148800,2.2762978,3.1774082,,,,,,,,,,,,,, -148900,2.2549803,1.4644235,,,,,,,,,,,,,, -149000,2.385164,3.3597407,,,,,,,,,,,,,, -149100,2.3799608,1.4021382,,,,,,,,,,,,,, -149200,2.666831,3.192853,,,,,,,,,,,,,, -149300,2.4713006,2.758146,,,,,,,,,,,,,, -149378,,,0.8524218797683716,0.5529924035072327,0.7565799951553345,0.9650366306304932,50000.0,0.6299000382423401,1.5885131359100342,10000.0,65583.5103931427,71258.06540131569,65583.5103931427,5660.120743513107,6.616364240646362,0.0 -149400,2.3889596,1.8376367,,,,,,,,,,,,,, -149500,2.668669,1.4690738,,,,,,,,,,,,,, -149600,2.5715215,3.7258177,,,,,,,,,,,,,, -149700,2.3328948,1.3896769,,,,,,,,,,,,,, -149800,2.350559,1.3683108,,,,,,,,,,,,,, -149900,2.4826026,1.4362257,,,,,,,,,,,,,, -150000,2.2362072,1.3245211,,,,,,,,,,,,,, -150100,2.573077,1.340922,,,,,,,,,,,,,, -150200,2.4812856,2.1000996,,,,,,,,,,,,,, -150300,2.407062,1.3811011,,,,,,,,,,,,,, -150334,,,0.8441210985183716,0.5873963832855225,0.7583999633789062,0.9608394503593444,50000.0,0.6344000101089478,1.5788832902908323,10000.0,66003.76546001434,71726.14133667946,66003.76546001434,5707.843321084976,6.665755748748779,0.0 -150400,2.3150425,1.3453046,,,,,,,,,,,,,, -150500,2.4407096,1.4648402,,,,,,,,,,,,,, -150600,2.604899,1.4054623,,,,,,,,,,,,,, -150700,2.3003542,1.4557508,,,,,,,,,,,,,, -150800,2.5207753,1.2954618,,,,,,,,,,,,,, -150900,2.811036,3.63658,,,,,,,,,,,,,, -151000,2.7245991,3.1304622,,,,,,,,,,,,,, -151100,2.3178601,1.6298971,,,,,,,,,,,,,, -151200,2.2902908,1.4613075,,,,,,,,,,,,,, -151290,,,0.8487499952316284,0.5678128004074097,0.7597799897193909,0.949147641658783,50000.0,0.6296000480651855,1.575633525848389,10000.0,66424.10148477554,72192.51217556,66424.10148477554,5753.7766098976135,6.717709302902222,0.0 -151300,2.3362834,2.1783736,,,,,,,,,,,,,, -151400,3.1218877,3.604145,,,,,,,,,,,,,, -151500,2.8137753,3.595712,,,,,,,,,,,,,, -151600,2.3590648,1.4645724,,,,,,,,,,,,,, -151700,2.5072331,3.5769029,,,,,,,,,,,,,, -151800,2.6792645,3.3486063,,,,,,,,,,,,,, -151900,2.3693976,1.4184116,,,,,,,,,,,,,, -152000,2.4660106,2.7832997,,,,,,,,,,,,,, -152100,2.1812131,1.9809737,,,,,,,,,,,,,, -152200,2.3480144,2.732202,,,,,,,,,,,,,, -152246,,,0.8516015410423279,0.5689120888710022,0.7590599656105042,0.9657217860221864,50000.0,0.6332000494003296,1.5770167112350464,10000.0,66844.15052103996,72652.03750824928,66844.15052103996,5793.150819063187,6.770383834838867,0.0 -152300,2.4935668,3.134371,,,,,,,,,,,,,, -152400,2.3519197,2.2108269,,,,,,,,,,,,,, -152500,2.6110332,1.3142836,,,,,,,,,,,,,, -152600,2.1972828,2.5821798,,,,,,,,,,,,,, -152700,2.365633,1.3048376,,,,,,,,,,,,,, -152800,2.5515957,1.4445567,,,,,,,,,,,,,, -152900,2.3463373,1.6666012,,,,,,,,,,,,,, -153000,2.536872,1.6470375,,,,,,,,,,,,,, -153100,2.2881074,1.5575856,,,,,,,,,,,,,, -153200,,,0.8597851395606995,0.5119295120239258,0.7633399963378906,0.9361425638198853,50000.0,0.638700008392334,1.55450701713562,10000.0,67264.09875321388,73113.19673418999,67264.09875321388,5834.262246847153,6.820079565048218,0.0 -153200,2.6482027,1.3478793,,,,,,,,,,,,,, -153300,2.400726,1.3733139,,,,,,,,,,,,,, -153400,2.265418,2.3161252,,,,,,,,,,,,,, -153500,2.7567377,1.3112282,,,,,,,,,,,,,, -153600,2.4027956,1.3733141,,,,,,,,,,,,,, -153700,2.3532882,2.4460375,,,,,,,,,,,,,, -153800,2.5577862,1.3974881,,,,,,,,,,,,,, -153900,2.6232667,2.9728365,,,,,,,,,,,,,, -154000,2.3974762,2.1908228,,,,,,,,,,,,,, -154100,2.5908935,1.3467448,,,,,,,,,,,,,, -154154,,,0.8531835675239563,0.5538952350616455,0.7621200084686279,0.944446623325348,50000.0,0.6413000226020813,1.5551538467407229,10000.0,67684.29598784447,73575.05439019203,67684.29598784447,5875.818645477295,6.8744797706604,0.0 -154200,2.632485,1.3884586,,,,,,,,,,,,,, -154300,2.4179063,1.415672,,,,,,,,,,,,,, -154400,2.7607718,1.5948273,,,,,,,,,,,,,, -154500,2.5050452,1.4165261,,,,,,,,,,,,,, -154600,2.4648893,1.402899,,,,,,,,,,,,,, -154700,2.2951057,1.3223602,,,,,,,,,,,,,, -154800,2.72318,1.4922678,,,,,,,,,,,,,, -154900,2.495998,1.3679632,,,,,,,,,,,,,, -155000,2.539037,1.2990978,,,,,,,,,,,,,, -155100,2.574317,2.2019,,,,,,,,,,,,,, -155112,,,0.8586328029632568,0.5334701538085938,0.7653200030326843,0.94097101688385,50000.0,0.6392000317573547,1.5500725507736206,10000.0,68104.2943456173,74038.43568396568,68104.2943456173,5919.092953443527,6.933387756347656,0.0 -155200,2.4821894,2.7127945,,,,,,,,,,,,,, -155300,2.407977,1.2334106,,,,,,,,,,,,,, -155400,2.4405153,1.2529085,,,,,,,,,,,,,, -155500,2.5739229,1.384456,,,,,,,,,,,,,, -155600,2.586432,1.2911124,,,,,,,,,,,,,, -155700,2.4727514,1.3218086,,,,,,,,,,,,,, -155800,2.9174454,3.529467,,,,,,,,,,,,,, -155900,2.201589,2.2719853,,,,,,,,,,,,,, -156000,2.6903691,1.4369316,,,,,,,,,,,,,, -156069,,,0.8615820407867432,0.5220953822135925,0.7646399736404419,0.9310302138328552,50000.0,0.6394000053405762,1.5504746437072754,10000.0,68524.47574973106,74500.70484685898,68524.47574973106,5961.080224990845,6.984781265258789,0.0 -156100,2.5949457,1.227464,,,,,,,,,,,,,, -156200,2.498283,1.236285,,,,,,,,,,,,,, -156300,2.433062,1.4591956,,,,,,,,,,,,,, -156400,2.29851,2.4269211,,,,,,,,,,,,,, -156500,2.7937932,1.2494609,,,,,,,,,,,,,, -156600,2.7532928,3.0695448,,,,,,,,,,,,,, -156700,2.740723,3.2381704,,,,,,,,,,,,,, -156800,2.6716352,1.4190497,,,,,,,,,,,,,, -156900,2.6740146,1.5697824,,,,,,,,,,,,,, -157000,2.5359993,1.249769,,,,,,,,,,,,,, -157026,,,0.8561913967132568,0.5413217544555664,0.7650399804115295,0.9309642910957336,50000.0,0.6401000022888184,1.5474493503570557,10000.0,68944.65562939644,74961.62737870216,68944.65562939644,6001.720447778702,7.037832736968994,0.0 -157100,2.5784159,1.2917941,,,,,,,,,,,,,, -157200,2.771235,1.2949181,,,,,,,,,,,,,, -157300,2.7226336,2.257953,,,,,,,,,,,,,, -157400,2.495594,1.2834692,,,,,,,,,,,,,, -157500,2.5466793,1.2779822,,,,,,,,,,,,,, -157600,2.2728078,1.225334,,,,,,,,,,,,,, -157700,2.4737408,2.2579498,,,,,,,,,,,,,, -157800,2.7344556,1.3898181,,,,,,,,,,,,,, -157900,2.740065,1.4848604,,,,,,,,,,,,,, -157979,,,0.8600585460662842,0.5279272794723511,0.7641599774360657,0.9306029677391052,50000.0,0.6416000127792358,1.548611760139465,10000.0,69364.08955550194,75422.21031832695,69364.08955550194,6041.917126655579,7.940981864929199,0.0 -158000,2.4568655,1.2647738,,,,,,,,,,,,,, -158100,2.573893,2.0468864,,,,,,,,,,,,,, -158200,2.5209358,3.064961,,,,,,,,,,,,,, -158300,2.7309787,2.0131302,,,,,,,,,,,,,, -158400,2.415098,2.7968934,,,,,,,,,,,,,, -158500,2.4019115,1.6681099,,,,,,,,,,,,,, -158600,2.3926435,1.7684007,,,,,,,,,,,,,, -158700,3.5229316,3.5396192,,,,,,,,,,,,,, -158800,2.5702982,1.7346723,,,,,,,,,,,,,, -158900,2.4542627,1.3566616,,,,,,,,,,,,,, -158936,,,0.8643164038658142,0.5127362608909607,0.768619954586029,0.924987494945526,50000.0,0.6457000374794006,1.527753233909607,10000.0,69784.15007662773,75885.81394910812,69784.15007662773,6085.347447156906,8.00169324874878,0.0 -159000,2.7597766,2.9325027,,,,,,,,,,,,,, -159100,2.6736643,1.2657498,,,,,,,,,,,,,, -159200,2.4335442,2.0576367,,,,,,,,,,,,,, -159300,2.626966,1.2907633,,,,,,,,,,,,,, -159400,2.593074,1.2761302,,,,,,,,,,,,,, -159500,2.6078925,1.2363288,,,,,,,,,,,,,, -159600,2.5842664,1.8406013,,,,,,,,,,,,,, -159700,2.67369,1.3582647,,,,,,,,,,,,,, -159800,2.772332,3.2818973,,,,,,,,,,,,,, -159891,,,0.8702148199081421,0.484088271856308,0.7680000066757202,0.9140318632125854,50000.0,0.6384000182151794,1.5375657081604004,10000.0,70204.42936730385,76348.71235513687,70204.42936730385,6127.854299068451,8.064128637313843,0.0 -159900,2.7025115,2.580702,,,,,,,,,,,,,, -160000,3.93003,3.5993533,,,,,,,,,,,,,, -160100,3.2083647,3.1239145,,,,,,,,,,,,,, -160200,3.0973198,1.8265443,,,,,,,,,,,,,, -160300,2.4341938,1.219862,,,,,,,,,,,,,, -160400,3.9816515,3.525909,,,,,,,,,,,,,, -160500,2.5420575,1.1673856,,,,,,,,,,,,,, -160600,2.9872823,2.475301,,,,,,,,,,,,,, -160700,2.803062,1.3504128,,,,,,,,,,,,,, -160800,2.7270176,2.1815474,,,,,,,,,,,,,, -160847,,,0.8659570217132568,0.5069062113761902,0.7679399847984314,0.9168761372566224,50000.0,0.64410001039505,1.5222582817077637,10000.0,70624.54157876968,76816.47623872757,70624.54157876968,6175.40148806572,8.118890523910522,0.0 -160900,3.2172217,3.635425,,,,,,,,,,,,,, -161000,2.8057146,1.2779807,,,,,,,,,,,,,, -161100,2.9570997,1.2582223,,,,,,,,,,,,,, -161200,2.6661766,1.6218237,,,,,,,,,,,,,, -161300,2.6218138,1.341371,,,,,,,,,,,,,, -161400,3.8169725,3.276723,,,,,,,,,,,,,, -161500,2.8371727,1.3201323,,,,,,,,,,,,,, -161600,2.5858135,1.1748035,,,,,,,,,,,,,, -161700,2.7060843,2.4550729,,,,,,,,,,,,,, -161800,2.9848511,3.0303757,,,,,,,,,,,,,, -161801,,,0.8662499785423279,0.4994567334651947,0.7704600095748901,0.9088558554649352,50000.0,0.6473000049591064,1.5221868753433228,10000.0,71044.70680117607,77282.55980610847,71044.70680117607,6221.217273712158,8.172521114349365,0.0 -161900,3.125392,3.0692127,,,,,,,,,,,,,, -162000,2.7786238,1.3420064,,,,,,,,,,,,,, -162100,4.192619,3.3916698,,,,,,,,,,,,,, -162200,2.5165317,1.165777,,,,,,,,,,,,,, -162300,2.620969,1.3734324,,,,,,,,,,,,,, -162400,2.5839655,1.6539375,,,,,,,,,,,,,, -162500,2.663773,1.2511104,,,,,,,,,,,,,, -162600,3.3131554,3.2306104,,,,,,,,,,,,,, -162700,2.6295862,1.6430581,,,,,,,,,,,,,, -162757,,,0.8711718320846558,0.479285329580307,0.7711199522018433,0.9005151987075806,50000.0,0.6461000442504883,1.5066665410995483,10000.0,71464.65548014641,77746.84324288368,71464.65548014641,6265.452960968018,8.222732543945312,0.0 -162800,2.5144126,1.5346692,,,,,,,,,,,,,, -162900,2.432085,2.0810761,,,,,,,,,,,,,, -163000,2.832986,1.1969578,,,,,,,,,,,,,, -163100,2.7797382,1.3336233,,,,,,,,,,,,,, -163200,2.587973,2.0918946,,,,,,,,,,,,,, -163300,2.818338,2.7015967,,,,,,,,,,,,,, -163400,2.9185739,1.2442698,,,,,,,,,,,,,, -163500,2.8055859,1.3428913,,,,,,,,,,,,,, -163600,2.7676437,1.5760777,,,,,,,,,,,,,, -163700,2.6740096,1.3657118,,,,,,,,,,,,,, -163711,,,0.87451171875,0.4802699983119964,0.7723999619483948,0.9044002890586852,50000.0,0.6484000086784363,1.51751446723938,10000.0,71884.6285545826,78207.52199673653,71884.6285545826,6306.055119752884,8.277120113372803,0.0 -163800,3.0864294,2.1658149,,,,,,,,,,,,,, -163900,3.3060377,3.2505898,,,,,,,,,,,,,, -164000,3.1074266,1.3046186,,,,,,,,,,,,,, -164100,2.657486,1.2336857,,,,,,,,,,,,,, -164200,2.8871696,1.2359812,,,,,,,,,,,,,, -164300,3.219126,1.4660437,,,,,,,,,,,,,, -164400,2.7182083,1.1731225,,,,,,,,,,,,,, -164500,3.1252437,3.052845,,,,,,,,,,,,,, -164600,3.1679838,3.4843001,,,,,,,,,,,,,, -164667,,,0.8724218606948853,0.4697950780391693,0.774679958820343,0.8892613053321838,50000.0,0.6477000117301941,1.507075309753418,10000.0,72304.52912330627,78668.73527359962,72304.52912330627,6347.265980482101,8.328987121582031,0.0 -164700,3.0600123,1.2713563,,,,,,,,,,,,,, -164800,3.5654557,3.4254577,,,,,,,,,,,,,, -164900,2.718704,1.6531343,,,,,,,,,,,,,, -165000,3.1025045,1.2610574,,,,,,,,,,,,,, -165100,2.7511156,1.4422448,,,,,,,,,,,,,, -165200,2.6049027,1.8677135,,,,,,,,,,,,,, -165300,3.3428507,1.3179293,,,,,,,,,,,,,, -165400,2.6710515,1.4641747,,,,,,,,,,,,,, -165500,3.4054935,1.1973962,,,,,,,,,,,,,, -165600,3.0536366,1.2384565,,,,,,,,,,,,,, -165623,,,0.8748242259025574,0.4756855964660644,0.7743600010871887,0.8956592679023743,50000.0,0.650600016117096,1.5049769878387451,10000.0,72724.6625881195,79130.78731393814,72724.6625881195,6389.079388380051,8.3844153881073,0.0 -165700,3.2574615,3.2707431,,,,,,,,,,,,,, -165800,3.1061695,2.3044033,,,,,,,,,,,,,, -165900,2.5675685,1.1664331,,,,,,,,,,,,,, -166000,3.414493,3.4856381,,,,,,,,,,,,,, -166100,3.1241672,3.0588753,,,,,,,,,,,,,, -166200,2.761581,2.9088802,,,,,,,,,,,,,, -166300,3.1259604,2.4853396,,,,,,,,,,,,,, -166400,3.1983774,3.0851398,,,,,,,,,,,,,, -166500,2.902116,1.1921947,,,,,,,,,,,,,, -166577,,,0.8760351538658142,0.4655555188655853,0.7749999761581421,0.8960703611373901,50000.0,0.6516000032424927,1.505468726158142,10000.0,73144.80176186562,79594.84973239899,73144.80176186562,6432.890777826309,8.447266101837158,0.0 -166600,2.7070565,1.1768582,,,,,,,,,,,,,, -166700,2.690493,1.1309943,,,,,,,,,,,,,, -166800,2.739235,2.62283,,,,,,,,,,,,,, -166900,2.8547995,1.2368077,,,,,,,,,,,,,, -167000,2.7628121,1.3577578,,,,,,,,,,,,,, -167100,3.4847856,3.2736561,,,,,,,,,,,,,, -167200,3.1795437,1.2239096,,,,,,,,,,,,,, -167300,2.5588748,1.5531666,,,,,,,,,,,,,, -167400,2.5258467,2.089204,,,,,,,,,,,,,, -167500,2.7809784,1.7111776,,,,,,,,,,,,,, -167536,,,0.8737109303474426,0.4756890833377838,0.7744999527931213,0.8931192755699158,50000.0,0.6520000100135803,1.4951791763305664,10000.0,73564.99614834785,80058.41223526001,73564.99614834785,6476.155483722687,8.500588655471802,0.0 -167600,2.79382,1.191382,,,,,,,,,,,,,, -167700,2.873179,1.8923974,,,,,,,,,,,,,, -167800,2.9348853,2.5774095,,,,,,,,,,,,,, -167900,3.265596,1.2818692,,,,,,,,,,,,,, -168000,2.8239808,1.160306,,,,,,,,,,,,,, -168100,2.925018,1.2745687,,,,,,,,,,,,,, -168200,3.401033,3.380836,,,,,,,,,,,,,, -168300,2.9390264,2.0933256,,,,,,,,,,,,,, -168400,3.6026063,3.251147,,,,,,,,,,,,,, -168492,,,0.8785156011581421,0.4542920589447021,0.776479959487915,0.885757565498352,50000.0,0.653700053691864,1.4894458055496216,10000.0,73985.04015851021,80526.1391685009,73985.04015851021,6523.7312026023865,8.55799913406372,0.0 -168500,2.7733145,1.6986523,,,,,,,,,,,,,, -168600,2.8258896,1.2195368,,,,,,,,,,,,,, -168700,2.8376534,2.8323278,,,,,,,,,,,,,, -168800,2.862683,2.4513085,,,,,,,,,,,,,, -168900,3.4495406,1.1964102,,,,,,,,,,,,,, -169000,3.0331614,2.564101,,,,,,,,,,,,,, -169100,3.2872076,3.340724,,,,,,,,,,,,,, -169200,3.0343864,1.3223492,,,,,,,,,,,,,, -169300,3.714167,3.297708,,,,,,,,,,,,,, -169400,3.344085,2.9019735,,,,,,,,,,,,,, -169404,,,0.8810546398162842,0.4524443447589874,0.7756399512290955,0.8834952712059021,50000.0,0.6568000316619873,1.4942373037338257,10000.0,74405.21560502052,80985.82221055031,74405.21560502052,6563.126978397369,8.622390508651733,0.0 -169500,2.6112778,1.4081703,,,,,,,,,,,,,, -169600,3.349666,1.1970283,,,,,,,,,,,,,, -169700,2.9114015,1.2863951,,,,,,,,,,,,,, -169800,2.951241,1.8511658,,,,,,,,,,,,,, -169900,2.9093552,1.1082534,,,,,,,,,,,,,, -170000,2.8239188,1.155602,,,,,,,,,,,,,, -170100,2.9770558,1.2333702,,,,,,,,,,,,,, -170200,3.3938427,1.2543966,,,,,,,,,,,,,, -170300,3.4641018,3.4067793,,,,,,,,,,,,,, -170356,,,0.8832812309265137,0.4339889287948608,0.7788199782371521,0.870005190372467,50000.0,0.6581000089645386,1.4730889797210691,10000.0,74825.23263335228,81448.93747091293,74825.23263335228,6606.119580030441,8.677473068237305,0.0 -170400,3.7925122,3.344276,,,,,,,,,,,,,, -170500,2.7392273,1.1586316,,,,,,,,,,,,,, -170600,2.821528,1.2459972,,,,,,,,,,,,,, -170700,2.8610601,1.7581065,,,,,,,,,,,,,, -170800,2.9345584,2.7364676,,,,,,,,,,,,,, -170900,3.0341666,2.5346074,,,,,,,,,,,,,, -171000,3.666405,2.601757,,,,,,,,,,,,,, -171100,2.9224002,1.9380487,,,,,,,,,,,,,, -171200,2.9977636,1.3112253,,,,,,,,,,,,,, -171300,3.5120747,3.3464198,,,,,,,,,,,,,, -171311,,,0.8805468678474426,0.4478434324264526,0.7775999903678894,0.8766506314277649,50000.0,0.6554000377655029,1.4801712036132812,10000.0,75245.5376894474,81909.02955842018,75245.5376894474,6645.805203676224,8.72992992401123,0.0 -171400,2.7918458,1.1452622,,,,,,,,,,,,,, -171500,3.005211,1.1493613,,,,,,,,,,,,,, -171600,2.9255805,1.9855361,,,,,,,,,,,,,, -171700,3.028447,1.2828236,,,,,,,,,,,,,, -171800,6.5797806,1.2368392,,,,,,,,,,,,,, -171900,2.9584336,1.2021008,,,,,,,,,,,,,, -172000,3.1254225,1.2332587,,,,,,,,,,,,,, -172100,3.4086304,3.2093554,,,,,,,,,,,,,, -172200,3.3704689,3.2172334,,,,,,,,,,,,,, -172266,,,0.8803319931030273,0.4448411166667938,0.7797999978065491,0.8689996600151062,50000.0,0.6565000414848328,1.4745608568191528,10000.0,75665.74482417107,82372.0632212162,75665.74482417107,6688.530457019806,8.78211498260498,0.0 -172300,2.7933612,1.3537097,,,,,,,,,,,,,, -172400,2.9471028,2.3242178,,,,,,,,,,,,,, -172500,2.7946846,1.1104372,,,,,,,,,,,,,, -172600,3.149831,1.551088,,,,,,,,,,,,,, -172700,3.0089953,1.2899528,,,,,,,,,,,,,, -172800,2.8266685,1.176621,,,,,,,,,,,,,, -172900,2.8430915,2.0479777,,,,,,,,,,,,,, -173000,2.9174695,1.2463095,,,,,,,,,,,,,, -173100,3.7111623,3.3997,,,,,,,,,,,,,, -173200,3.4620087,1.2256606,,,,,,,,,,,,,, -173218,,,0.8841796517372131,0.430279940366745,0.7793799638748169,0.8702945709228516,50000.0,0.656000018119812,1.47455096244812,10000.0,76085.70857977867,82835.26369142532,76085.70857977867,6731.652981758118,8.847268104553223,0.0 -173300,3.2060306,2.7815704,,,,,,,,,,,,,, -173400,2.9703414,1.1509851,,,,,,,,,,,,,, -173500,3.0333428,2.4138117,,,,,,,,,,,,,, -173600,2.6481688,1.4571073,,,,,,,,,,,,,, -173700,2.9153068,1.2188979,,,,,,,,,,,,,, -173800,3.0257878,1.8394594,,,,,,,,,,,,,, -173900,3.012063,1.1215347,,,,,,,,,,,,,, -174000,2.8656058,1.6538799,,,,,,,,,,,,,, -174100,3.1129532,1.1121122,,,,,,,,,,,,,, -174174,,,0.8836718797683716,0.4322193562984466,0.7788800001144409,0.8668044805526733,50000.0,0.6546000242233276,1.47637140750885,10000.0,76505.80146336555,83296.43021583557,76505.80146336555,6772.623242139816,8.900946855545044,0.0 -174200,3.1049087,1.1791481,,,,,,,,,,,,,, -174300,3.2605333,3.1633587,,,,,,,,,,,,,, -174400,3.394098,2.9892478,,,,,,,,,,,,,, -174500,3.0744922,1.2420617,,,,,,,,,,,,,, -174600,2.8803337,1.9462377,,,,,,,,,,,,,, -174700,3.0116124,1.2810138,,,,,,,,,,,,,, -174800,3.3217492,1.6416289,,,,,,,,,,,,,, -174900,3.132895,1.1498129,,,,,,,,,,,,,, -175000,3.9591904,3.170928,,,,,,,,,,,,,, -175100,2.8768742,1.2069409,,,,,,,,,,,,,, -175131,,,0.8831640481948853,0.4393573999404907,0.781059980392456,0.8640338778495789,50000.0,0.6604000329971313,1.4671146869659424,10000.0,76925.99597835541,83759.2424018383,76925.99597835541,6815.133854389191,8.958553314208984,0.0 -175200,2.9074748,1.8003812,,,,,,,,,,,,,, -175300,3.0746872,1.2077072,,,,,,,,,,,,,, -175400,3.0849302,2.538476,,,,,,,,,,,,,, -175500,2.735071,1.7188646,,,,,,,,,,,,,, -175600,3.392919,3.030988,,,,,,,,,,,,,, -175700,3.2471094,2.4079823,,,,,,,,,,,,,, -175800,3.1453118,1.2969636,,,,,,,,,,,,,, -175900,2.919833,1.1425594,,,,,,,,,,,,,, -176000,3.1419275,1.7875347,,,,,,,,,,,,,, -176084,,,0.8835546970367432,0.4316456019878387,0.7802599668502808,0.8601840138435364,50000.0,0.6589000225067139,1.4714841842651367,10000.0,77346.0910449028,84222.7153236866,77346.0910449028,6858.408992290497,9.01213812828064,0.0 -176100,3.5861652,3.2228081,,,,,,,,,,,,,, -176200,3.5916035,2.3331187,,,,,,,,,,,,,, -176300,2.945079,1.4560258,,,,,,,,,,,,,, -176400,2.9648492,1.4939426,,,,,,,,,,,,,, -176500,3.3819253,2.6268563,,,,,,,,,,,,,, -176600,3.3246868,2.0853598,,,,,,,,,,,,,, -176700,3.058241,1.7483224,,,,,,,,,,,,,, -176800,3.1806498,2.6894784,,,,,,,,,,,,,, -176900,3.280397,1.0495433,,,,,,,,,,,,,, -177000,3.1616974,1.4686625,,,,,,,,,,,,,, -177041,,,0.8862109184265137,0.4194622039794922,0.780739963054657,0.8591325283050537,50000.0,0.6598000526428223,1.4636272192001345,10000.0,77766.39183497429,84684.91393399239,77766.39183497429,6900.202219963074,9.065497159957886,0.0 -177100,3.081928,1.1590215,,,,,,,,,,,,,, -177200,3.0543892,2.378821,,,,,,,,,,,,,, -177300,3.0495408,1.1442447,,,,,,,,,,,,,, -177400,2.956068,1.0600821,,,,,,,,,,,,,, -177500,3.5291348,2.8950233,,,,,,,,,,,,,, -177600,3.0145485,2.107892,,,,,,,,,,,,,, -177700,3.8603837,3.1399846,,,,,,,,,,,,,, -177800,2.9441512,1.3877866,,,,,,,,,,,,,, -177900,3.1748567,1.1971091,,,,,,,,,,,,,, -177997,,,0.8862695097923279,0.4204770624637604,0.7815799713134766,0.8558434247970581,50000.0,0.6583000421524048,1.462801814079285,10000.0,78186.408213377,85145.09214758873,78186.408213377,6940.258265972137,9.121346712112429,0.0 -178000,3.0130014,1.1343095,,,,,,,,,,,,,, -178100,3.2956834,1.3402554,,,,,,,,,,,,,, -178200,3.1659725,2.5253057,,,,,,,,,,,,,, -178300,3.1913943,1.2190942,,,,,,,,,,,,,, -178400,3.4755058,3.0030334,,,,,,,,,,,,,, -178500,3.7100804,3.2983165,,,,,,,,,,,,,, -178600,5.1887565,3.2857783,,,,,,,,,,,,,, -178700,2.7494783,1.6581167,,,,,,,,,,,,,, -178800,3.656422,3.0690372,,,,,,,,,,,,,, -178900,5.092248,1.0921199,,,,,,,,,,,,,, -178953,,,0.8860937356948853,0.4298219084739685,0.7818399667739868,0.8621166944503784,50000.0,0.6598000526428223,1.461968183517456,10000.0,78606.38973283768,85614.21636533737,78606.38973283768,6989.28365111351,9.188668251037598,0.0 -179000,3.1370354,1.1555239,,,,,,,,,,,,,, -179100,3.2746034,2.906672,,,,,,,,,,,,,, -179200,3.134971,1.1735492,,,,,,,,,,,,,, -179300,3.129068,1.5477731,,,,,,,,,,,,,, -179400,3.1802394,1.195394,,,,,,,,,,,,,, -179500,3.5025141,3.0746531,,,,,,,,,,,,,, -179600,2.9796324,1.1594881,,,,,,,,,,,,,, -179700,3.40569,3.0361667,,,,,,,,,,,,,, -179800,3.070709,1.6550772,,,,,,,,,,,,,, -179900,2.900244,1.1317848,,,,,,,,,,,,,, -179911,,,0.8874218463897705,0.4167936742305755,0.7824999690055847,0.8534600734710693,50000.0,0.6611000299453735,1.4549766778945925,10000.0,79026.4203722477,86077.0102751255,79026.4203722477,7031.9412133693695,9.244869709014893,0.0 -180000,3.298501,1.1251446,,,,,,,,,,,,,, -180100,3.0459285,1.2279334,,,,,,,,,,,,,, -180200,4.709093,1.1413211,,,,,,,,,,,,,, -180300,3.058626,1.0711656,,,,,,,,,,,,,, -180400,3.2967536,1.1315942,,,,,,,,,,,,,, -180500,2.9124088,1.0584068,,,,,,,,,,,,,, -180600,3.5365815,1.8369129,,,,,,,,,,,,,, -180700,3.4605136,1.235132,,,,,,,,,,,,,, -180800,3.0677483,1.9509826,,,,,,,,,,,,,, -180869,,,0.8839648365974426,0.4304944276809692,0.7829399704933167,0.8545869588851929,50000.0,0.661300003528595,1.453679442405701,10000.0,79446.53990650177,86543.125269413,79446.53990650177,7077.8280510902405,9.303065538406372,0.0 -180900,3.3781996,1.0958247,,,,,,,,,,,,,, -181000,2.90438,1.2067856,,,,,,,,,,,,,, -181100,3.0901241,1.121116,,,,,,,,,,,,,, -181200,3.3125648,2.9449167,,,,,,,,,,,,,, -181300,3.160262,1.1361939,,,,,,,,,,,,,, -181400,3.1419973,1.2237294,,,,,,,,,,,,,, -181500,2.8702638,1.2927493,,,,,,,,,,,,,, -181600,2.8785565,1.312192,,,,,,,,,,,,,, -181700,3.0854936,1.9492263,,,,,,,,,,,,,, -181800,4.1020336,3.2355924,,,,,,,,,,,,,, -181824,,,0.885546863079071,0.4214950501918793,0.782480001449585,0.8549716472625732,50000.0,0.6606000065803528,1.4558229446411133,10000.0,79866.53466320038,87008.9695174694,79866.53466320038,7123.566502571106,9.364516973495483,0.0 -181900,3.262254,1.2175688,,,,,,,,,,,,,, -182000,3.09394,1.1720047,,,,,,,,,,,,,, -182100,3.6639762,3.184204,,,,,,,,,,,,,, -182200,3.7219563,3.1790588,,,,,,,,,,,,,, -182300,3.351074,1.2366409,,,,,,,,,,,,,, -182400,3.271112,1.1566018,,,,,,,,,,,,,, -182500,2.9636748,1.5535071,,,,,,,,,,,,,, -182600,3.0532603,2.5088584,,,,,,,,,,,,,, -182700,3.152706,1.0535492,,,,,,,,,,,,,, -182779,,,0.8897070288658142,0.411954402923584,0.7824400067329407,0.8548983931541443,50000.0,0.660800039768219,1.4561413526535034,10000.0,80286.43485283852,87478.32796931267,80286.43485283852,7172.918921947479,9.42094874382019,0.0 -182800,3.1304238,2.8129346,,,,,,,,,,,,,, -182900,3.162564,1.9135003,,,,,,,,,,,,,, -183000,3.3593426,2.8645363,,,,,,,,,,,,,, -183100,3.254154,1.1470772,,,,,,,,,,,,,, -183200,3.1220639,1.9352291,,,,,,,,,,,,,, -183300,3.0886457,2.1588345,,,,,,,,,,,,,, -183400,3.0734208,1.3131824,,,,,,,,,,,,,, -183500,3.5174236,1.9484872,,,,,,,,,,,,,, -183600,3.5536513,2.224946,,,,,,,,,,,,,, -183700,2.9826477,1.806155,,,,,,,,,,,,,, -183735,,,0.8878905773162842,0.4117888808250427,0.782759964466095,0.852722704410553,50000.0,0.6620000600814819,1.4533320665359497,10000.0,80706.45871520042,87939.42469787598,80706.45871520042,7213.88618850708,9.477174997329712,0.0 -183800,3.1228027,1.0928388,,,,,,,,,,,,,, -183900,3.114101,1.1217536,,,,,,,,,,,,,, -184000,3.1012545,2.3233824,,,,,,,,,,,,,, -184100,3.1086617,1.0842652,,,,,,,,,,,,,, -184200,2.7687478,1.5134808,,,,,,,,,,,,,, -184300,3.1653793,1.0647074,,,,,,,,,,,,,, -184400,3.1746328,2.282972,,,,,,,,,,,,,, -184500,2.9947767,1.2000499,,,,,,,,,,,,,, -184600,3.3478677,2.4608877,,,,,,,,,,,,,, -184689,,,0.8874609470367432,0.4206721782684326,0.7829999923706055,0.8523223400115967,50000.0,0.6610000133514404,1.4538012742996216,10000.0,81126.73969316483,88406.53539347649,81126.73969316483,7260.604390859604,9.539127111434937,0.0 -184700,3.2448542,1.1409057,,,,,,,,,,,,,, -184800,3.0055354,2.177352,,,,,,,,,,,,,, -184900,2.98551,1.077185,,,,,,,,,,,,,, -185000,3.6033697,3.077945,,,,,,,,,,,,,, -185100,3.0094478,1.1187364,,,,,,,,,,,,,, -185200,3.1835568,2.6966727,,,,,,,,,,,,,, -185300,3.102168,2.7232783,,,,,,,,,,,,,, -185400,3.211481,1.9356838,,,,,,,,,,,,,, -185500,3.0314455,1.0764195,,,,,,,,,,,,,, -185600,2.995234,1.1169306,,,,,,,,,,,,,, -185645,,,0.8898437023162842,0.4111132025718689,0.7829999923706055,0.8524206280708313,50000.0,0.6612000465393066,1.4538347721099854,10000.0,81546.66754245758,88866.37303447723,81546.66754245758,7300.4049389362335,9.598478078842165,0.0 -185700,3.2039464,1.4061859,,,,,,,,,,,,,, -185800,3.32323,1.1765869,,,,,,,,,,,,,, -185900,2.9389882,1.0578198,,,,,,,,,,,,,, -186000,3.0557983,1.4874433,,,,,,,,,,,,,, -186100,3.023746,2.22816,,,,,,,,,,,,,, -186200,3.3918204,2.1344306,,,,,,,,,,,,,, -186300,3.148069,1.1857141,,,,,,,,,,,,,, -186400,3.344888,1.1195483,,,,,,,,,,,,,, -186500,2.9165182,1.3962663,,,,,,,,,,,,,, -186599,,,0.8867577910423279,0.4253540337085724,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.4538928270339966,10000.0,81966.87752747536,89330.82261490822,81966.87752747536,7344.524179458618,9.663646221160889,0.0 -186600,3.4597504,1.4662642,,,,,,,,,,,,,, -186700,2.9988372,1.2226208,,,,,,,,,,,,,, -186800,3.716844,3.1355653,,,,,,,,,,,,,, -186900,3.0386817,1.3580289,,,,,,,,,,,,,, -187000,3.1424487,2.6712713,,,,,,,,,,,,,, -187100,3.1720731,2.2651608,,,,,,,,,,,,,, -187200,3.0913322,1.127998,,,,,,,,,,,,,, -187300,2.9642634,1.0578206,,,,,,,,,,,,,, -187400,3.1453297,1.2029631,,,,,,,,,,,,,, -187500,3.0286996,1.1482762,,,,,,,,,,,,,, -187550,,,0.8905078172683716,0.4123665988445282,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,82386.98583173752,89798.33035802841,82386.98583173752,7391.805383205414,9.732877016067505,0.0 -187600,3.3495655,1.0648266,,,,,,,,,,,,,, -187700,3.5152586,2.8960826,,,,,,,,,,,,,, -187800,3.027325,1.7679319,,,,,,,,,,,,,, -187900,3.2499754,1.4570148,,,,,,,,,,,,,, -188000,3.1088717,1.0621841,,,,,,,,,,,,,, -188100,3.216502,1.1793629,,,,,,,,,,,,,, -188200,3.256152,2.3547804,,,,,,,,,,,,,, -188300,3.5526152,1.124579,,,,,,,,,,,,,, -188400,4.0210056,3.2396047,,,,,,,,,,,,,, -188500,4.3867416,3.170289,,,,,,,,,,,,,, -188507,,,0.8867382407188416,0.4162900745868683,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,82807.08119153976,90257.17705130576,82807.08119153976,7430.450542926788,9.789993047714232,0.0 -188600,3.0187266,1.4940882,,,,,,,,,,,,,, -188700,3.5668843,1.1626385,,,,,,,,,,,,,, -188800,3.1980677,1.0669911,,,,,,,,,,,,,, -188900,2.7679427,1.2493148,,,,,,,,,,,,,, -189000,3.2182229,1.1678796,,,,,,,,,,,,,, -189100,3.0997694,1.2997166,,,,,,,,,,,,,, -189200,3.302368,2.638638,,,,,,,,,,,,,, -189300,3.2339396,1.60322,,,,,,,,,,,,,, -189400,3.1756513,1.8776255,,,,,,,,,,,,,, -189465,,,0.8881444931030273,0.415896475315094,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,83227.10395240784,90722.47425937653,83227.10395240784,7475.613933086395,9.851112365722656,0.0 -189500,3.7489645,3.2058163,,,,,,,,,,,,,, -189600,3.17406,1.1955435,,,,,,,,,,,,,, -189700,3.0649655,2.2059317,,,,,,,,,,,,,, -189800,3.2700253,2.48621,,,,,,,,,,,,,, -189900,3.3470223,1.2179612,,,,,,,,,,,,,, -190000,2.812618,2.1251311,,,,,,,,,,,,,, -190100,3.1042402,1.9644974,,,,,,,,,,,,,, -190200,3.2769055,2.8451753,,,,,,,,,,,,,, -190300,3.525148,1.172544,,,,,,,,,,,,,, -190400,2.8902793,0.98960286,,,,,,,,,,,,,, -190421,,,0.8898632526397705,0.4129151701927185,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,83646.99283194542,91183.8697359562,83646.99283194542,7517.003622770309,9.918134927749634,0.0 -190500,3.3847582,0.9560254,,,,,,,,,,,,,, -190600,3.2288926,1.10643,,,,,,,,,,,,,, -190700,3.7309637,2.9858985,,,,,,,,,,,,,, -190800,2.9195926,1.4146436,,,,,,,,,,,,,, -190900,3.2033122,1.230589,,,,,,,,,,,,,, -191000,3.3384411,1.2353407,,,,,,,,,,,,,, -191100,3.185966,1.1992601,,,,,,,,,,,,,, -191200,3.197064,2.4249578,,,,,,,,,,,,,, -191300,3.1187766,1.8815036,,,,,,,,,,,,,, -191376,,,0.8852733969688416,0.4186182618141174,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,84066.93710017204,91644.4737186432,84066.93710017204,7557.5557742118835,9.97598123550415,0.0 -191400,3.0743692,2.6902323,,,,,,,,,,,,,, -191500,3.1215305,1.1921408,,,,,,,,,,,,,, -191600,3.5708022,2.4779115,,,,,,,,,,,,,, -191700,2.9435623,2.271564,,,,,,,,,,,,,, -191800,3.5485168,3.0088284,,,,,,,,,,,,,, -191900,3.2279525,1.6018218,,,,,,,,,,,,,, -192000,3.2788093,1.1382736,,,,,,,,,,,,,, -192100,2.8630903,1.7055374,,,,,,,,,,,,,, -192200,3.0603619,1.180898,,,,,,,,,,,,,, -192300,3.6751497,3.1632624,,,,,,,,,,,,,, -192333,,,0.8894140720367432,0.415103018283844,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,84486.87456440926,92105.04097628592,84486.87456440926,7598.079861402512,10.03218388557434,0.0 -192400,3.516607,2.9257348,,,,,,,,,,,,,, -192500,3.071617,1.0706847,,,,,,,,,,,,,, -192600,3.56444,3.0824234,,,,,,,,,,,,,, -192700,3.0245063,1.2638518,,,,,,,,,,,,,, -192800,3.3061357,2.7856166,,,,,,,,,,,,,, -192900,3.241989,1.1530452,,,,,,,,,,,,,, -193000,3.3129213,1.1098552,,,,,,,,,,,,,, -193100,3.5468516,2.714337,,,,,,,,,,,,,, -193200,3.438621,1.207556,,,,,,,,,,,,,, -193288,,,0.8891796469688416,0.4154730439186096,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,84906.87883043289,92563.49833774568,84906.87883043289,7636.415265798569,10.100399255752563,0.0 -193300,2.882967,1.1175386,,,,,,,,,,,,,, -193400,3.3533845,1.3414932,,,,,,,,,,,,,, -193500,3.1594958,1.2406497,,,,,,,,,,,,,, -193600,3.090011,1.6218848,,,,,,,,,,,,,, -193700,2.9658096,1.3371334,,,,,,,,,,,,,, -193800,3.279915,1.4941778,,,,,,,,,,,,,, -193900,2.8343163,1.5573134,,,,,,,,,,,,,, -194000,3.3451266,1.6068133,,,,,,,,,,,,,, -194100,4.421242,3.0091243,,,,,,,,,,,,,, -194200,2.9795136,1.1975945,,,,,,,,,,,,,, -194241,,,0.8891015648841858,0.4140681624412536,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,85327.06767630577,93033.19660949708,85327.06767630577,7685.801539182663,10.174046993255615,0.0 -194300,3.0559618,2.4642189,,,,,,,,,,,,,, -194400,3.3376796,1.1395242,,,,,,,,,,,,,, -194500,3.140771,1.1627511,,,,,,,,,,,,,, -194600,3.0679786,1.0441368,,,,,,,,,,,,,, -194700,3.2210348,1.1653548,,,,,,,,,,,,,, -194800,2.9650464,1.1030067,,,,,,,,,,,,,, -194900,3.201996,1.0977486,,,,,,,,,,,,,, -195000,3.1900866,1.1087809,,,,,,,,,,,,,, -195100,3.0706995,1.0713139,,,,,,,,,,,,,, -195200,,,0.8866796493530273,0.4165627956390381,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,85746.978110075,93496.65905618668,85746.978110075,7729.244254589081,10.232911586761476,0.0 -195200,3.2173824,1.2746165,,,,,,,,,,,,,, -195300,3.405438,1.106155,,,,,,,,,,,,,, -195400,2.968658,1.0120058,,,,,,,,,,,,,, -195500,3.3579295,1.024258,,,,,,,,,,,,,, -195600,2.9830773,2.28729,,,,,,,,,,,,,, -195700,3.352834,1.10147,,,,,,,,,,,,,, -195800,3.57119,3.2692027,,,,,,,,,,,,,, -195900,3.2687716,1.0846398,,,,,,,,,,,,,, -196000,3.8189242,3.243184,,,,,,,,,,,,,, -196100,3.1497066,1.1170299,,,,,,,,,,,,,, -196157,,,0.8865429759025574,0.425489604473114,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,86167.28060626984,93959.505079031,86167.28060626984,7771.67853140831,10.292376279830933,0.0 -196200,3.2309275,2.453545,,,,,,,,,,,,,, -196300,2.976653,1.1201125,,,,,,,,,,,,,, -196400,3.4379911,2.9368236,,,,,,,,,,,,,, -196500,3.482289,3.0395093,,,,,,,,,,,,,, -196600,2.8651545,0.96953744,,,,,,,,,,,,,, -196700,3.905625,3.071516,,,,,,,,,,,,,, -196800,2.8167205,1.4858584,,,,,,,,,,,,,, -196900,3.3284898,2.8526645,,,,,,,,,,,,,, -197000,3.0359914,1.1653566,,,,,,,,,,,,,, -197100,3.609127,2.86558,,,,,,,,,,,,,, -197109,,,0.8899609446525574,0.4093320369720459,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,86587.31809902191,94417.6742913723,86587.31809902191,7809.706657886505,10.347052335739136,0.0 -197200,3.3285563,1.5329388,,,,,,,,,,,,,, -197300,3.3386095,1.2000347,,,,,,,,,,,,,, -197400,2.973482,1.7295922,,,,,,,,,,,,,, -197500,3.3051937,1.119831,,,,,,,,,,,,,, -197600,3.9433222,3.0728161,,,,,,,,,,,,,, -197700,3.3322217,1.129025,,,,,,,,,,,,,, -197800,3.3384871,2.298079,,,,,,,,,,,,,, -197900,2.9775252,1.8815776,,,,,,,,,,,,,, -198000,3.079885,1.1714832,,,,,,,,,,,,,, -198061,,,0.8864452838897705,0.4223592579364776,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,87007.37144398689,94876.27783370018,87007.37144398689,7848.151168823242,10.403776168823242,0.0 -198100,3.2848306,2.861828,,,,,,,,,,,,,, -198200,2.8869128,1.9748266,,,,,,,,,,,,,, -198300,3.0434964,1.1458447,,,,,,,,,,,,,, -198400,3.264108,1.8241438,,,,,,,,,,,,,, -198500,3.1947007,2.3076568,,,,,,,,,,,,,, -198600,3.2187595,1.245939,,,,,,,,,,,,,, -198700,2.9733477,1.0684807,,,,,,,,,,,,,, -198800,3.4837809,3.1634152,,,,,,,,,,,,,, -198900,2.9414294,1.0992706,,,,,,,,,,,,,, -199000,3.4978578,1.1459819,,,,,,,,,,,,,, -199014,,,0.8879101276397705,0.4180936813354492,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,87427.40733599663,95339.37646174432,87427.40733599663,7891.088857412338,10.479464530944824,0.0 -199100,3.2149193,1.3033922,,,,,,,,,,,,,, -199200,2.9418068,2.467455,,,,,,,,,,,,,, -199300,3.840474,3.303615,,,,,,,,,,,,,, -199400,3.598794,1.4160428,,,,,,,,,,,,,, -199500,3.0804284,1.8096204,,,,,,,,,,,,,, -199600,3.2001352,1.1397257,,,,,,,,,,,,,, -199700,3.2899368,2.36273,,,,,,,,,,,,,, -199800,3.2599547,1.1770232,,,,,,,,,,,,,, -199900,4.337386,3.3127193,,,,,,,,,,,,,, -199970,,,0.8871484398841858,0.4199837148189544,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,87847.40805268288,95804.42789506912,87847.40805268288,7936.020844221115,10.547487020492554,0.0 -200000,3.8267496,3.1701634,,,,,,,,,,,,,, -200100,3.5762875,3.0108643,,,,,,,,,,,,,, -200200,3.7159698,3.2993927,,,,,,,,,,,,,, -200300,3.1027708,1.1027191,,,,,,,,,,,,,, -200400,3.0136642,1.1721742,,,,,,,,,,,,,, -200500,3.3668277,1.0757723,,,,,,,,,,,,,, -200600,3.2481165,1.1949359,,,,,,,,,,,,,, -200700,3.5401871,3.1857371,,,,,,,,,,,,,, -200800,3.117517,2.61357,,,,,,,,,,,,,, -200900,3.31096,1.178932,,,,,,,,,,,,,, -200927,,,0.8883788585662842,0.4186073839664459,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,88267.30465459824,96260.49082493782,88267.30465459824,7972.082072734833,10.60279655456543,0.0 -201000,3.1751745,1.1827704,,,,,,,,,,,,,, -201100,3.0106065,1.2449791,,,,,,,,,,,,,, -201200,3.1507435,2.8313282,,,,,,,,,,,,,, -201300,2.8662915,1.7012043,,,,,,,,,,,,,, -201400,2.807403,1.9101691,,,,,,,,,,,,,, -201500,3.6821818,3.2034898,,,,,,,,,,,,,, -201600,3.1486058,1.378016,,,,,,,,,,,,,, -201700,3.1369114,1.6299981,,,,,,,,,,,,,, -201800,3.8895104,1.2104688,,,,,,,,,,,,,, -201881,,,0.8891015648841858,0.4090310931205749,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,88686.67367577553,96720.3174176216,88686.67367577553,8011.646758794785,11.446199417114258,0.0 -201900,3.2818995,1.1363162,,,,,,,,,,,,,, -202000,3.1650884,1.1293694,,,,,,,,,,,,,, -202100,3.4185655,2.4401836,,,,,,,,,,,,,, -202200,2.932616,1.3882352,,,,,,,,,,,,,, -202300,3.0428789,1.4155966,,,,,,,,,,,,,, -202400,4.0123096,3.1772037,,,,,,,,,,,,,, -202500,3.2934046,2.2477047,,,,,,,,,,,,,, -202600,3.3943768,2.9025006,,,,,,,,,,,,,, -202700,3.0173635,1.1882406,,,,,,,,,,,,,, -202800,3.4549174,2.3904514,,,,,,,,,,,,,, -202837,,,0.88636714220047,0.4229801595211029,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,89106.6478202343,97178.28229618073,89106.6478202343,8049.51655459404,11.518182754516602,0.0 -202900,3.270326,1.180737,,,,,,,,,,,,,, -203000,3.3418849,1.1503605,,,,,,,,,,,,,, -203100,3.0662806,1.1602538,,,,,,,,,,,,,, -203200,3.7755764,3.2399883,,,,,,,,,,,,,, -203300,2.8830101,1.0898821,,,,,,,,,,,,,, -203400,3.3749664,2.856708,,,,,,,,,,,,,, -203500,3.1526556,1.6317382,,,,,,,,,,,,,, -203600,3.327343,1.1157176,,,,,,,,,,,,,, -203700,3.1154299,1.0285633,,,,,,,,,,,,,, -203791,,,0.888476550579071,0.4164254367351532,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,89526.93627262115,97636.9053592682,89526.93627262115,8087.728713512421,11.590760707855225,0.0 -203800,2.9297302,1.1279844,,,,,,,,,,,,,, -203900,3.1542063,1.3780948,,,,,,,,,,,,,, -204000,3.2198365,3.0275679,,,,,,,,,,,,,, -204100,3.1164987,1.1609573,,,,,,,,,,,,,, -204200,3.2237053,1.2414961,,,,,,,,,,,,,, -204300,3.1372519,1.2837267,,,,,,,,,,,,,, -204400,3.175511,1.1118474,,,,,,,,,,,,,, -204500,2.8484836,1.5337094,,,,,,,,,,,,,, -204600,3.211027,2.680399,,,,,,,,,,,,,, -204700,2.989402,1.2994109,,,,,,,,,,,,,, -204745,,,0.8869921565055847,0.4216805696487427,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,89946.85517311096,98092.91681575777,89946.85517311096,8123.699931621551,11.662600040435793,0.0 -204800,3.1076233,2.9109862,,,,,,,,,,,,,, -204900,2.7588873,1.1052397,,,,,,,,,,,,,, -205000,3.301493,1.2025819,,,,,,,,,,,,,, -205100,3.5072672,1.1899154,,,,,,,,,,,,,, -205200,3.3338296,2.246422,,,,,,,,,,,,,, -205300,3.6071012,1.2924757,,,,,,,,,,,,,, -205400,3.0416453,1.1643026,,,,,,,,,,,,,, -205500,2.8357582,1.2381107,,,,,,,,,,,,,, -205600,2.9664202,1.0820566,,,,,,,,,,,,,, -205697,,,0.8867382407188416,0.4210879802703857,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,90367.04581069946,98554.3462498188,90367.04581069946,8164.820489406586,11.73024082183838,0.0 -205700,3.0556521,1.0979323,,,,,,,,,,,,,, -205800,2.949609,1.1785576,,,,,,,,,,,,,, -205900,3.0714111,1.2729545,,,,,,,,,,,,,, -206000,3.1639757,1.0777528,,,,,,,,,,,,,, -206100,3.1408975,2.79832,,,,,,,,,,,,,, -206200,3.0918875,1.0698488,,,,,,,,,,,,,, -206300,3.5154896,1.5859097,,,,,,,,,,,,,, -206400,3.1833878,1.1521761,,,,,,,,,,,,,, -206500,4.327448,3.2359638,,,,,,,,,,,,,, -206600,3.01378,1.2388165,,,,,,,,,,,,,, -206657,,,0.8891015648841858,0.4129902124404907,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,90787.44080162048,99007.24154257774,90787.44080162048,8197.201642036438,11.799509048461914,0.0 -206700,3.2288148,1.0481795,,,,,,,,,,,,,, -206800,3.1542718,1.9167279,,,,,,,,,,,,,, -206900,3.1337242,2.4590464,,,,,,,,,,,,,, -207000,3.101426,1.1603953,,,,,,,,,,,,,, -207100,3.8239,1.6607841,,,,,,,,,,,,,, -207200,3.451393,1.3039311,,,,,,,,,,,,,, -207300,3.1962605,1.0707122,,,,,,,,,,,,,, -207400,3.3799174,3.0301046,,,,,,,,,,,,,, -207500,3.176399,1.9991211,,,,,,,,,,,,,, -207600,3.3854365,3.0224621,,,,,,,,,,,,,, -207611,,,0.88818359375,0.4146837592124939,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,91207.7722082138,99462.24376773834,91207.7722082138,8231.754022359848,11.868899822235107,0.0 -207700,2.8892305,1.0477989,,,,,,,,,,,,,, -207800,3.1440883,1.025791,,,,,,,,,,,,,, -207900,3.254473,1.119927,,,,,,,,,,,,,, -208000,3.4373422,1.1721294,,,,,,,,,,,,,, -208100,3.0910985,1.2613896,,,,,,,,,,,,,, -208200,3.218342,3.0212903,,,,,,,,,,,,,, -208300,3.1682298,2.5258524,,,,,,,,,,,,,, -208400,3.2810268,1.6786509,,,,,,,,,,,,,, -208500,3.8930054,3.1998363,,,,,,,,,,,,,, -208563,,,0.8888476490974426,0.4169521927833557,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,91627.7535970211,99919.61623358728,91627.7535970211,8269.022126674652,11.942992687225342,0.0 -208600,2.7865672,1.6450655,,,,,,,,,,,,,, -208700,3.112542,1.3509867,,,,,,,,,,,,,, -208800,3.7667112,3.1097984,,,,,,,,,,,,,, -208900,3.1543505,2.4645524,,,,,,,,,,,,,, -209000,3.173783,2.5129547,,,,,,,,,,,,,, -209100,3.3712974,1.2568955,,,,,,,,,,,,,, -209200,2.889051,1.2374063,,,,,,,,,,,,,, -209300,3.326411,1.322256,,,,,,,,,,,,,, -209400,3.0081542,2.7577083,,,,,,,,,,,,,, -209500,3.233878,1.2097884,,,,,,,,,,,,,, -209518,,,0.887011706829071,0.4165227115154266,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,92047.78227806091,100386.16372036934,92047.78227806091,8315.414091348648,12.020440578460692,0.0 -209600,2.958099,1.6658816,,,,,,,,,,,,,, -209700,2.977169,1.1249651,,,,,,,,,,,,,, -209800,2.9272833,1.8135839,,,,,,,,,,,,,, -209900,2.7754455,1.1529962,,,,,,,,,,,,,, -210000,3.3087776,1.0963436,,,,,,,,,,,,,, -210100,3.0792665,1.0646555,,,,,,,,,,,,,, -210200,2.9497201,1.2023424,,,,,,,,,,,,,, -210300,3.42293,2.955173,,,,,,,,,,,,,, -210400,4.4674106,1.1141082,,,,,,,,,,,,,, -210479,,,0.8866015672683716,0.4254970550537109,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,92468.02927017212,100843.42229747772,92468.02927017212,8352.31871843338,12.077834129333496,0.0 -210500,3.185053,1.1322722,,,,,,,,,,,,,, -210600,2.9776978,1.314537,,,,,,,,,,,,,, -210700,2.9041493,1.2432003,,,,,,,,,,,,,, -210800,2.9229124,1.2626961,,,,,,,,,,,,,, -210900,3.2686694,1.1227082,,,,,,,,,,,,,, -211000,2.7829688,1.4665718,,,,,,,,,,,,,, -211100,3.1865795,1.1028944,,,,,,,,,,,,,, -211200,3.0954862,1.2457871,,,,,,,,,,,,,, -211300,3.6419628,1.1873363,,,,,,,,,,,,,, -211400,3.1719675,1.3052158,,,,,,,,,,,,,, -211435,,,0.8909765481948853,0.4111464917659759,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,92887.907848835,101298.48625206947,92887.907848835,8387.379422187805,12.153021335601808,0.0 -211500,3.1539028,1.9652774,,,,,,,,,,,,,, -211600,3.2635117,2.920053,,,,,,,,,,,,,, -211700,3.249,1.1165322,,,,,,,,,,,,,, -211800,3.281845,1.8535911,,,,,,,,,,,,,, -211900,3.5759165,2.994427,,,,,,,,,,,,,, -212000,3.2384036,1.2036062,,,,,,,,,,,,,, -212100,3.3602664,1.2581645,,,,,,,,,,,,,, -212200,2.735136,1.255832,,,,,,,,,,,,,, -212300,2.9256198,1.5481927,,,,,,,,,,,,,, -212387,,,0.8878515362739563,0.4138557016849518,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,93307.9272289276,101753.40010499954,93307.9272289276,8422.15352678299,12.22426986694336,0.0 -212400,3.481138,1.1835968,,,,,,,,,,,,,, -212500,3.1860528,1.1869053,,,,,,,,,,,,,, -212600,3.4547307,3.024565,,,,,,,,,,,,,, -212700,2.920779,1.1179721,,,,,,,,,,,,,, -212800,3.1781964,2.6422071,,,,,,,,,,,,,, -212900,3.0869453,1.8647755,,,,,,,,,,,,,, -213000,2.9738505,2.2361426,,,,,,,,,,,,,, -213100,3.0550146,1.1230564,,,,,,,,,,,,,, -213200,3.4680548,1.194775,,,,,,,,,,,,,, -213300,2.9244618,1.1457546,,,,,,,,,,,,,, -213342,,,0.8890234231948853,0.4153840839862823,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,93728.0548722744,102219.9491224289,93728.0548722744,8468.450547218323,12.29885721206665,0.0 -213400,3.0400124,1.1139212,,,,,,,,,,,,,, -213500,2.9377248,2.1531725,,,,,,,,,,,,,, -213600,3.979797,3.1998858,,,,,,,,,,,,,, -213700,3.4629247,2.8816974,,,,,,,,,,,,,, -213800,3.131738,1.8542614,,,,,,,,,,,,,, -213900,2.89779,1.1185565,,,,,,,,,,,,,, -214000,3.347716,2.606276,,,,,,,,,,,,,, -214100,2.9087207,1.0421996,,,,,,,,,,,,,, -214200,2.9517398,1.2148296,,,,,,,,,,,,,, -214300,,,0.8884375095367432,0.4121640026569366,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,94148.00626373292,102675.33705687524,94148.00626373292,8503.7742228508,12.362093687057495,0.0 -214300,3.5324628,3.0655303,,,,,,,,,,,,,, -214400,3.13231,1.4657627,,,,,,,,,,,,,, -214500,3.3035529,1.6872041,,,,,,,,,,,,,, -214600,3.4171762,1.1354427,,,,,,,,,,,,,, -214700,2.9906743,1.112381,,,,,,,,,,,,,, -214800,3.5178497,1.0777752,,,,,,,,,,,,,, -214900,3.839043,2.7417424,,,,,,,,,,,,,, -215000,3.4628968,2.6256304,,,,,,,,,,,,,, -215100,2.874151,1.1409223,,,,,,,,,,,,,, -215200,3.0942159,1.1432276,,,,,,,,,,,,,, -215256,,,0.8880078196525574,0.4170250296592712,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,94567.9120604992,103132.16620612144,94567.9120604992,8540.573637008667,12.436352014541626,0.0 -215300,3.08542,1.1105796,,,,,,,,,,,,,, -215400,2.9510448,1.0911138,,,,,,,,,,,,,, -215500,3.0887473,2.1496806,,,,,,,,,,,,,, -215600,3.2034948,1.194885,,,,,,,,,,,,,, -215700,3.0625992,1.2330756,,,,,,,,,,,,,, -215800,3.0399134,1.2762277,,,,,,,,,,,,,, -215900,2.8729486,1.858388,,,,,,,,,,,,,, -216000,3.1917038,2.2645223,,,,,,,,,,,,,, -216100,3.2089388,2.7209084,,,,,,,,,,,,,, -216200,3.2191586,1.2076573,,,,,,,,,,,,,, -216205,,,0.8862890601158142,0.4187511801719665,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,94987.81290459631,103597.94665384293,94987.81290459631,8586.329691886902,12.511003971099854,0.0 -216300,3.2113175,1.3459393,,,,,,,,,,,,,, -216400,3.2319646,1.3525515,,,,,,,,,,,,,, -216500,3.076183,1.146714,,,,,,,,,,,,,, -216600,3.0764387,1.1400295,,,,,,,,,,,,,, -216700,3.3730848,1.1551266,,,,,,,,,,,,,, -216800,2.9720683,1.1355541,,,,,,,,,,,,,, -216900,3.5970125,1.3300519,,,,,,,,,,,,,, -217000,2.881048,1.8572509,,,,,,,,,,,,,, -217100,3.100997,1.1695751,,,,,,,,,,,,,, -217161,,,0.8904101252555847,0.4095988273620605,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,95407.8063764572,104049.49093866348,95407.8063764572,8617.764047384262,12.574751377105711,0.0 -217200,3.3486679,2.8943398,,,,,,,,,,,,,, -217300,3.7547758,3.0701764,,,,,,,,,,,,,, -217400,3.162925,1.7610549,,,,,,,,,,,,,, -217500,3.1608706,2.2348661,,,,,,,,,,,,,, -217600,2.9154568,1.3613575,,,,,,,,,,,,,, -217700,3.367172,1.2751367,,,,,,,,,,,,,, -217800,4.2241707,1.114326,,,,,,,,,,,,,, -217900,3.1509635,1.2073911,,,,,,,,,,,,,, -218000,3.1086123,2.0316665,,,,,,,,,,,,,, -218100,3.775953,1.5767219,,,,,,,,,,,,,, -218114,,,0.888476550579071,0.4177097082138061,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,95827.81059122086,104517.7988626957,95827.81059122086,8665.94267654419,12.649450063705444,0.0 -218200,3.11913,1.091579,,,,,,,,,,,,,, -218300,3.210892,1.3449707,,,,,,,,,,,,,, -218400,3.3855672,1.1380328,,,,,,,,,,,,,, -218500,3.2740512,1.590173,,,,,,,,,,,,,, -218600,3.1184247,1.6699394,,,,,,,,,,,,,, -218700,3.007411,1.7379398,,,,,,,,,,,,,, -218800,2.9242628,1.4283297,,,,,,,,,,,,,, -218900,3.2328074,1.313684,,,,,,,,,,,,,, -219000,3.6111844,3.0449202,,,,,,,,,,,,,, -219073,,,0.8863476514816284,0.4216277599334717,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,96248.04551506042,104980.22134375572,96248.04551506042,8708.01383304596,12.715829849243164,0.0 -219100,3.6844883,2.834745,,,,,,,,,,,,,, -219200,4.017938,3.234201,,,,,,,,,,,,,, -219300,2.9031804,1.3715429,,,,,,,,,,,,,, -219400,2.9849517,1.798763,,,,,,,,,,,,,, -219500,3.0222616,1.1221137,,,,,,,,,,,,,, -219600,3.2374227,1.1435243,,,,,,,,,,,,,, -219700,2.9986868,1.0870417,,,,,,,,,,,,,, -219800,3.5247083,1.0823185,,,,,,,,,,,,,, -219900,3.6504714,2.2907722,,,,,,,,,,,,,, -220000,3.1549294,1.0381426,,,,,,,,,,,,,, -220029,,,0.8868749737739563,0.4227591156959533,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,96667.96599316596,105439.2344288826,96667.96599316596,8746.991399526596,12.78156876564026,0.0 -220100,3.9512684,3.2904322,,,,,,,,,,,,,, -220200,3.8644962,3.3049798,,,,,,,,,,,,,, -220300,3.2292674,1.2138507,,,,,,,,,,,,,, -220400,3.4185786,3.04361,,,,,,,,,,,,,, -220500,3.66943,3.1437464,,,,,,,,,,,,,, -220600,3.0950136,1.5506048,,,,,,,,,,,,,, -220700,3.2889645,1.917576,,,,,,,,,,,,,, -220800,3.218486,1.1472015,,,,,,,,,,,,,, -220900,2.929822,2.189452,,,,,,,,,,,,,, -220985,,,0.8891015648841858,0.4091844856739044,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,97088.07246875764,105897.30706262589,97088.07246875764,8784.837461471558,12.851918935775757,0.0 -221000,2.9579039,1.1718333,,,,,,,,,,,,,, -221100,3.337339,1.3991711,,,,,,,,,,,,,, -221200,3.1946635,1.1425736,,,,,,,,,,,,,, -221300,3.2047033,1.1866232,,,,,,,,,,,,,, -221400,3.7556512,2.7513144,,,,,,,,,,,,,, -221500,3.1569574,1.1466894,,,,,,,,,,,,,, -221600,3.2417614,1.0896277,,,,,,,,,,,,,, -221700,4.298221,3.2768064,,,,,,,,,,,,,, -221800,3.176201,1.0555527,,,,,,,,,,,,,, -221900,3.275796,1.1927073,,,,,,,,,,,,,, -221939,,,0.8893359303474426,0.4144483804702759,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,97508.19736647606,106360.98352241516,97508.19736647606,8828.260638713837,12.930778980255129,0.0 -222000,2.9147887,1.3342608,,,,,,,,,,,,,, -222100,3.0521743,2.0123456,,,,,,,,,,,,,, -222200,3.779694,1.1789846,,,,,,,,,,,,,, -222300,3.0558536,1.9683235,,,,,,,,,,,,,, -222400,3.0494452,1.2146211,,,,,,,,,,,,,, -222500,2.9702556,1.0385411,,,,,,,,,,,,,, -222600,3.0417533,1.6791115,,,,,,,,,,,,,, -222700,2.8777916,2.12359,,,,,,,,,,,,,, -222800,3.2870378,2.221792,,,,,,,,,,,,,, -222897,,,0.8866796493530273,0.4247627854347229,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,97928.41378474236,106817.44368696211,97928.41378474236,8864.390403270721,12.99502658843994,0.0 -222900,3.0907393,1.1693575,,,,,,,,,,,,,, -223000,3.167612,1.4100181,,,,,,,,,,,,,, -223100,3.1850047,1.1311061,,,,,,,,,,,,,, -223200,3.2147791,1.953866,,,,,,,,,,,,,, -223300,3.01788,1.2265458,,,,,,,,,,,,,, -223400,3.1380544,2.2138548,,,,,,,,,,,,,, -223500,2.9969888,1.0749335,,,,,,,,,,,,,, -223600,2.9970434,1.0831289,,,,,,,,,,,,,, -223700,3.4632132,3.0264897,,,,,,,,,,,,,, -223800,3.125978,1.216764,,,,,,,,,,,,,, -223846,,,0.8874804377555847,0.4221131503582001,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,98348.38157367706,107272.11375570296,98348.38157367706,8898.96737408638,13.071343421936035,0.0 -223900,3.49822,2.6978288,,,,,,,,,,,,,, -224000,3.3788,1.883278,,,,,,,,,,,,,, -224100,3.1798942,1.1860127,,,,,,,,,,,,,, -224200,2.9980507,1.9371123,,,,,,,,,,,,,, -224300,3.1856194,2.1328516,,,,,,,,,,,,,, -224400,3.3816059,1.2159173,,,,,,,,,,,,,, -224500,3.5262775,1.138131,,,,,,,,,,,,,, -224600,4.031432,3.2731173,,,,,,,,,,,,,, -224700,3.1105487,1.3084067,,,,,,,,,,,,,, -224799,,,0.8865624666213989,0.4165645837783813,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,98768.6687772274,107739.98773026466,98768.6687772274,8946.43201994896,13.144215106964111,0.0 -224800,3.648743,3.0776916,,,,,,,,,,,,,, -224900,3.153006,1.5956277,,,,,,,,,,,,,, -225000,3.1102562,1.704078,,,,,,,,,,,,,, -225100,3.1604767,1.4433445,,,,,,,,,,,,,, -225200,3.0562088,1.0726494,,,,,,,,,,,,,, -225300,2.9953747,1.7720975,,,,,,,,,,,,,, -225400,3.4589977,2.7504811,,,,,,,,,,,,,, -225500,3.4326212,1.1557692,,,,,,,,,,,,,, -225600,3.1040962,1.1279527,,,,,,,,,,,,,, -225700,4.6221027,3.2821598,,,,,,,,,,,,,, -225757,,,0.8892187476158142,0.4134066998958587,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,99188.98446559906,108192.5675497055,99188.98446559906,8978.581200838089,13.209935903549194,0.0 -225800,2.96896,1.1541147,,,,,,,,,,,,,, -225900,3.0964136,1.1358229,,,,,,,,,,,,,, -226000,3.197091,1.2031682,,,,,,,,,,,,,, -226100,3.085689,1.1672242,,,,,,,,,,,,,, -226200,3.1235774,1.4130539,,,,,,,,,,,,,, -226300,3.0794897,1.1552541,,,,,,,,,,,,,, -226400,3.1936407,2.5446527,,,,,,,,,,,,,, -226500,3.2990303,1.1563197,,,,,,,,,,,,,, -226600,3.2901394,2.1476164,,,,,,,,,,,,,, -226700,3.4619782,1.5762813,,,,,,,,,,,,,, -226713,,,0.8882030844688416,0.4167915880680084,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,99608.9012246132,108663.69064497948,99608.9012246132,9029.657244682312,13.288574695587158,0.0 -226800,3.2984874,1.051616,,,,,,,,,,,,,, -226900,2.947352,1.373816,,,,,,,,,,,,,, -227000,3.3690443,1.3297331,,,,,,,,,,,,,, -227100,3.2728922,1.9804065,,,,,,,,,,,,,, -227200,3.6003542,1.2863512,,,,,,,,,,,,,, -227300,2.7458546,1.9105612,,,,,,,,,,,,,, -227400,3.2918599,1.1182586,,,,,,,,,,,,,, -227500,3.1967716,1.0887274,,,,,,,,,,,,,, -227600,3.1490629,1.2784278,,,,,,,,,,,,,, -227672,,,0.8873828053474426,0.4202496111392975,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,100029.07903766632,109129.8581469059,100029.07903766632,9075.538096427916,13.347972631454468,0.0 -227700,3.1664927,1.1584446,,,,,,,,,,,,,, -227800,3.170177,2.3097465,,,,,,,,,,,,,, -227900,3.3090613,2.3594177,,,,,,,,,,,,,, -228000,3.0525885,1.7524037,,,,,,,,,,,,,, -228100,2.9497306,1.0500536,,,,,,,,,,,,,, -228200,3.0447311,1.1646703,,,,,,,,,,,,,, -228300,3.0116827,1.2803023,,,,,,,,,,,,,, -228400,3.3488579,2.024958,,,,,,,,,,,,,, -228500,3.107905,2.2358372,,,,,,,,,,,,,, -228600,3.060994,2.3779447,,,,,,,,,,,,,, -228631,,,0.8864648342132568,0.4214226305484772,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,100449.26325941086,109584.58092451096,100449.26325941086,9109.95995593071,13.414387702941896,0.0 -228700,3.1088011,2.2879267,,,,,,,,,,,,,, -228800,3.097888,2.6402705,,,,,,,,,,,,,, -228900,3.1201463,1.0625745,,,,,,,,,,,,,, -229000,2.8000448,1.3704766,,,,,,,,,,,,,, -229100,3.318836,1.115973,,,,,,,,,,,,,, -229200,3.4819946,1.2837121,,,,,,,,,,,,,, -229300,3.473264,1.0434225,,,,,,,,,,,,,, -229400,3.6444418,1.0332704,,,,,,,,,,,,,, -229500,3.1080983,2.5140715,,,,,,,,,,,,,, -229584,,,0.8885351419448853,0.4174453616142273,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,100869.45522499084,110043.32839989662,100869.45522499084,9148.38451385498,13.495585680007936,0.0 -229600,3.7010183,3.2040126,,,,,,,,,,,,,, -229700,3.0108356,1.1343021,,,,,,,,,,,,,, -229800,3.1217942,1.1724186,,,,,,,,,,,,,, -229900,2.9076855,1.0771661,,,,,,,,,,,,,, -230000,2.8010564,1.5097929,,,,,,,,,,,,,, -230100,4.0807714,3.2103288,,,,,,,,,,,,,, -230200,3.2367854,1.2683316,,,,,,,,,,,,,, -230300,3.1035974,1.6512742,,,,,,,,,,,,,, -230400,3.0092552,1.0668907,,,,,,,,,,,,,, -230500,3.1274035,2.8023639,,,,,,,,,,,,,, -230534,,,0.8880273103713989,0.4105813503265381,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,101289.73925161362,110499.09790325163,101289.73925161362,9183.717219114304,13.598756313323976,0.0 -230600,3.3831317,1.0613155,,,,,,,,,,,,,, -230700,3.1371317,1.0957193,,,,,,,,,,,,,, -230800,2.8593733,1.2316419,,,,,,,,,,,,,, -230900,3.1273935,2.4994187,,,,,,,,,,,,,, -231000,3.009243,1.1368842,,,,,,,,,,,,,, -231100,3.1022284,1.0451584,,,,,,,,,,,,,, -231200,3.0559485,1.2063324,,,,,,,,,,,,,, -231300,3.034599,1.1436707,,,,,,,,,,,,,, -231400,3.4205573,1.4771681,,,,,,,,,,,,,, -231490,,,0.887499988079071,0.4188116490840912,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,101709.70796656609,110957.91074514388,101709.70796656609,9222.439522743223,13.67095160484314,0.0 -231500,3.315052,1.4461282,,,,,,,,,,,,,, -231600,3.214583,1.1028994,,,,,,,,,,,,,, -231700,3.3601718,2.81094,,,,,,,,,,,,,, -231800,3.053485,1.4083251,,,,,,,,,,,,,, -231900,3.9900548,3.2913284,,,,,,,,,,,,,, -232000,3.505371,1.1210401,,,,,,,,,,,,,, -232100,3.4562802,1.4839041,,,,,,,,,,,,,, -232200,3.0660727,2.4016256,,,,,,,,,,,,,, -232300,3.1257768,1.2967778,,,,,,,,,,,,,, -232400,2.9052212,1.9899173,,,,,,,,,,,,,, -232446,,,0.8887304663658142,0.4167661666870117,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,102129.62954187392,111415.52277565002,102129.62954187392,9260.00146317482,13.75056290626526,0.0 -232500,3.6883724,3.2099154,,,,,,,,,,,,,, -232600,2.981727,1.0378876,,,,,,,,,,,,,, -232700,3.2528608,1.1452594,,,,,,,,,,,,,, -232800,3.3706884,1.1285741,,,,,,,,,,,,,, -232900,3.2018964,1.161404,,,,,,,,,,,,,, -233000,4.024102,3.4001822,,,,,,,,,,,,,, -233100,3.8303142,3.286399,,,,,,,,,,,,,, -233200,3.2694445,1.197203,,,,,,,,,,,,,, -233300,2.8357544,1.7495638,,,,,,,,,,,,,, -233400,,,0.887499988079071,0.4170728921890259,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,102549.59725284576,111875.75222849846,102549.59725284576,9300.13559770584,13.828028678894045,0.0 -233400,2.9002883,1.0885193,,,,,,,,,,,,,, -233500,2.8883026,2.0312154,,,,,,,,,,,,,, -233600,3.7946527,1.1670731,,,,,,,,,,,,,, -233700,3.3155203,1.1254659,,,,,,,,,,,,,, -233800,3.1226716,1.1051974,,,,,,,,,,,,,, -233900,2.8360527,1.8175582,,,,,,,,,,,,,, -234000,3.469065,3.1302726,,,,,,,,,,,,,, -234100,3.209738,1.2394378,,,,,,,,,,,,,, -234200,4.1261253,1.0822852,,,,,,,,,,,,,, -234300,3.0938256,1.1542339,,,,,,,,,,,,,, -234360,,,0.8882030844688416,0.4220650792121887,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,102969.69378328323,112336.56853604317,102969.69378328323,9340.741424560549,13.89189600944519,0.0 -234400,3.2042727,1.1080389,,,,,,,,,,,,,, -234500,3.0297487,1.1855996,,,,,,,,,,,,,, -234600,3.0613723,1.3156888,,,,,,,,,,,,,, -234700,3.1947823,1.6416612,,,,,,,,,,,,,, -234800,3.0935297,2.375156,,,,,,,,,,,,,, -234900,3.0167263,2.2995043,,,,,,,,,,,,,, -235000,3.5481038,2.8433502,,,,,,,,,,,,,, -235100,3.0488596,1.0711824,,,,,,,,,,,,,, -235200,3.1408954,1.1146691,,,,,,,,,,,,,, -235300,3.0930307,1.1368631,,,,,,,,,,,,,, -235318,,,0.88818359375,0.4170994460582733,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,103389.81013679504,112796.37017011642,103389.81013679504,9380.29682302475,13.971137523651125,0.0 -235400,4.6994443,1.1287047,,,,,,,,,,,,,, -235500,2.9825199,1.2307565,,,,,,,,,,,,,, -235600,2.9931061,1.1195457,,,,,,,,,,,,,, -235700,3.3376403,2.6842883,,,,,,,,,,,,,, -235800,3.4903018,2.2233589,,,,,,,,,,,,,, -235900,3.6979246,3.0612862,,,,,,,,,,,,,, -236000,3.4793582,1.1657963,,,,,,,,,,,,,, -236100,3.219516,1.1235824,,,,,,,,,,,,,, -236200,2.9858575,1.6192156,,,,,,,,,,,,,, -236272,,,0.888671875,0.4124659597873688,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,103809.80734848976,113264.9406042099,103809.80734848976,9428.746124267578,14.04552173614502,0.0 -236300,3.6471996,2.9422822,,,,,,,,,,,,,, -236400,3.1968465,1.2302871,,,,,,,,,,,,,, -236500,3.5324273,3.0778759,,,,,,,,,,,,,, -236600,2.8249238,1.0786474,,,,,,,,,,,,,, -236700,3.2462227,1.6730945,,,,,,,,,,,,,, -236800,3.0649166,1.1170738,,,,,,,,,,,,,, -236900,4.28897,3.241792,,,,,,,,,,,,,, -237000,3.0717406,2.2139359,,,,,,,,,,,,,, -237100,3.6759062,3.1684444,,,,,,,,,,,,,, -237200,3.0468187,1.0236609,,,,,,,,,,,,,, -237229,,,0.8904492259025574,0.4119705855846405,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,104230.0831682682,113725.29149913788,104230.0831682682,9468.708633422852,14.108554124832152,0.0 -237300,2.9191778,1.1537226,,,,,,,,,,,,,, -237400,3.2380104,1.1162258,,,,,,,,,,,,,, -237500,3.4267344,2.4407783,,,,,,,,,,,,,, -237600,3.7524524,1.1955558,,,,,,,,,,,,,, -237700,3.0528264,1.0807132,,,,,,,,,,,,,, -237800,3.5336576,1.1732277,,,,,,,,,,,,,, -237900,2.9599261,1.4262887,,,,,,,,,,,,,, -238000,4.1503406,1.5161434,,,,,,,,,,,,,, -238100,3.4516156,1.1166953,,,,,,,,,,,,,, -238186,,,0.8873632550239563,0.4150916635990143,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,104650.29148387907,114193.11408925056,104650.29148387907,9516.149178743362,14.232004404067991,0.0 -238200,3.1513069,1.1276001,,,,,,,,,,,,,, -238300,2.9541612,1.6485802,,,,,,,,,,,,,, -238400,3.0965965,1.7873951,,,,,,,,,,,,,, -238500,3.0999365,1.1017267,,,,,,,,,,,,,, -238600,3.688501,3.1080625,,,,,,,,,,,,,, -238700,3.0164697,1.0812752,,,,,,,,,,,,,, -238800,2.9894867,1.154176,,,,,,,,,,,,,, -238900,2.9535375,1.0494739,,,,,,,,,,,,,, -239000,3.1067579,1.0222383,,,,,,,,,,,,,, -239100,3.2375042,1.3766241,,,,,,,,,,,,,, -239143,,,0.8875585794448853,0.4165133833885193,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,105070.2694966793,114650.68592834473,105070.2694966793,9553.618543624878,14.30665922164917,0.0 -239200,2.9273705,2.0175843,,,,,,,,,,,,,, -239300,3.2643535,2.8187659,,,,,,,,,,,,,, -239400,3.3414555,1.2622459,,,,,,,,,,,,,, -239500,3.3041818,1.0484705,,,,,,,,,,,,,, -239600,3.210871,1.9956805,,,,,,,,,,,,,, -239700,3.2785075,1.1708105,,,,,,,,,,,,,, -239800,3.1769936,2.2573295,,,,,,,,,,,,,, -239900,4.130312,3.0553923,,,,,,,,,,,,,, -240000,3.1109927,2.119183,,,,,,,,,,,,,, -240097,,,0.8878905773162842,0.4163110256195068,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,105490.25032567978,115107.77051401138,105490.25032567978,9590.587503671646,14.38660478591919,0.0 -240100,3.459508,2.886786,,,,,,,,,,,,,, -240200,3.4766805,1.2260053,,,,,,,,,,,,,, -240300,3.7466292,3.150562,,,,,,,,,,,,,, -240400,3.272599,1.1727427,,,,,,,,,,,,,, -240500,3.1062648,2.0398057,,,,,,,,,,,,,, -240600,2.9137533,1.2927763,,,,,,,,,,,,,, -240700,4.04671,3.0469146,,,,,,,,,,,,,, -240800,3.2253072,1.4054753,,,,,,,,,,,,,, -240900,3.4141986,1.933948,,,,,,,,,,,,,, -241000,3.726843,3.1274607,,,,,,,,,,,,,, -241051,,,0.8893749713897705,0.416501522064209,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,105910.28881311417,115566.50127744676,105910.28881311417,9629.101864814758,14.51517915725708,0.0 -241100,3.633458,1.4388398,,,,,,,,,,,,,, -241200,3.603291,3.140806,,,,,,,,,,,,,, -241300,3.0498612,1.1984489,,,,,,,,,,,,,, -241400,3.2705488,2.1018128,,,,,,,,,,,,,, -241500,3.1801522,1.140496,,,,,,,,,,,,,, -241600,3.7009168,3.2777426,,,,,,,,,,,,,, -241700,3.2763715,1.133393,,,,,,,,,,,,,, -241800,3.0772722,1.1575031,,,,,,,,,,,,,, -241900,3.0151355,1.2273717,,,,,,,,,,,,,, -242000,3.0388875,1.0670244,,,,,,,,,,,,,, -242003,,,0.8878124952316284,0.4150179922580719,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,106330.19357085228,116026.39215540886,106330.19357085228,9668.963171720505,14.59090256690979,0.0 -242100,3.2455857,1.1691763,,,,,,,,,,,,,, -242200,3.6210475,3.025086,,,,,,,,,,,,,, -242300,2.9914281,1.2036837,,,,,,,,,,,,,, -242400,3.1655772,1.3275464,,,,,,,,,,,,,, -242500,3.0614693,1.9191349,,,,,,,,,,,,,, -242600,3.2058306,2.8623161,,,,,,,,,,,,,, -242700,3.341011,1.1123074,,,,,,,,,,,,,, -242800,3.0661054,0.99496967,,,,,,,,,,,,,, -242900,4.519301,3.3606234,,,,,,,,,,,,,, -242957,,,0.8879687190055847,0.4157115519046783,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,106750.3352985382,116483.22671723366,106750.3352985382,9705.530426979063,14.666574001312256,0.0 -243000,3.5262916,2.604723,,,,,,,,,,,,,, -243100,3.936572,1.1914214,,,,,,,,,,,,,, -243200,3.266804,2.6448941,,,,,,,,,,,,,, -243300,3.1566045,1.1356902,,,,,,,,,,,,,, -243400,3.4869347,2.113672,,,,,,,,,,,,,, -243500,3.3325765,1.1852288,,,,,,,,,,,,,, -243600,3.0022576,1.051366,,,,,,,,,,,,,, -243700,3.332879,2.9419162,,,,,,,,,,,,,, -243800,2.9047167,1.1291916,,,,,,,,,,,,,, -243900,3.1771915,2.628121,,,,,,,,,,,,,, -243911,,,0.8875976204872131,0.4220549166202545,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,107170.53363466264,116950.42909789084,107170.53363466264,9752.409542560576,14.741652011871338,0.0 -244000,3.375789,1.1284902,,,,,,,,,,,,,, -244100,2.782337,1.5721682,,,,,,,,,,,,,, -244200,3.03911,1.1059885,,,,,,,,,,,,,, -244300,2.9682796,1.1381354,,,,,,,,,,,,,, -244400,3.6036503,3.3486588,,,,,,,,,,,,,, -244500,2.8406582,1.282754,,,,,,,,,,,,,, -244600,3.8963487,1.1340388,,,,,,,,,,,,,, -244700,2.8996918,2.0556154,,,,,,,,,,,,,, -244800,2.9650755,1.544076,,,,,,,,,,,,,, -244867,,,0.8887304663658142,0.4134644269943237,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,107590.54052233696,117414.08001947404,107590.54052233696,9795.938479661942,14.807200908660889,0.0 -244900,3.1621993,1.3853948,,,,,,,,,,,,,, -245000,3.0795174,1.2408997,,,,,,,,,,,,,, -245100,3.2427256,1.1645331,,,,,,,,,,,,,, -245200,3.2218387,2.7893693,,,,,,,,,,,,,, -245300,3.032645,1.0688127,,,,,,,,,,,,,, -245400,3.2332594,1.1463947,,,,,,,,,,,,,, -245500,3.1415648,1.2260289,,,,,,,,,,,,,, -245600,3.2437923,1.1642252,,,,,,,,,,,,,, -245700,3.412273,1.968212,,,,,,,,,,,,,, -245800,3.204316,2.0182583,,,,,,,,,,,,,, -245826,,,0.8887304663658142,0.4132008254528045,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,108010.750831604,117873.7346212864,108010.750831604,9835.25951552391,14.880392789840698,0.0 -245900,3.538644,2.6882808,,,,,,,,,,,,,, -246000,3.9999204,3.2519395,,,,,,,,,,,,,, -246100,3.0037875,1.099905,,,,,,,,,,,,,, -246200,3.2440574,1.3136243,,,,,,,,,,,,,, -246300,3.0726268,1.1333979,,,,,,,,,,,,,, -246400,3.0871024,1.329799,,,,,,,,,,,,,, -246500,3.270305,2.078106,,,,,,,,,,,,,, -246600,3.595122,3.015017,,,,,,,,,,,,,, -246700,3.460926,1.1947979,,,,,,,,,,,,,, -246779,,,0.88685542345047,0.4224175214767456,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,108430.8633108139,118340.21789956091,108430.8633108139,9881.516390562056,14.945569515228271,0.0 -246800,3.2057297,1.1111228,,,,,,,,,,,,,, -246900,3.413122,1.9345689,,,,,,,,,,,,,, -247000,3.0867794,1.1601145,,,,,,,,,,,,,, -247100,3.0071683,1.1198452,,,,,,,,,,,,,, -247200,3.3897913,1.442126,,,,,,,,,,,,,, -247300,3.3525727,1.7349966,,,,,,,,,,,,,, -247400,3.4833746,2.883646,,,,,,,,,,,,,, -247500,3.656686,3.185748,,,,,,,,,,,,,, -247600,3.2547956,2.7233753,,,,,,,,,,,,,, -247700,3.0958478,1.0934006,,,,,,,,,,,,,, -247734,,,0.8893359303474426,0.4157688319683075,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,108850.87331867218,118801.3809273243,108850.87331867218,9922.546244859695,15.019672870635986,0.0 -247800,4.038685,3.2414885,,,,,,,,,,,,,, -247900,3.7360146,3.2161703,,,,,,,,,,,,,, -248000,3.641934,3.1184022,,,,,,,,,,,,,, -248100,3.175461,1.2445588,,,,,,,,,,,,,, -248200,3.1412535,1.4760174,,,,,,,,,,,,,, -248300,3.2010453,1.155978,,,,,,,,,,,,,, -248400,3.1089811,2.2804158,,,,,,,,,,,,,, -248500,2.9769974,1.7605159,,,,,,,,,,,,,, -248600,3.0169868,1.3214136,,,,,,,,,,,,,, -248690,,,0.8864257335662842,0.4195002615451813,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,109271.09632349014,119260.77212262154,109271.09632349014,9961.588781833649,15.095571517944336,0.0 -248700,3.3681684,1.5189117,,,,,,,,,,,,,, -248800,3.0099866,2.2596085,,,,,,,,,,,,,, -248900,3.0449035,1.4591513,,,,,,,,,,,,,, -249000,3.2980614,1.0945289,,,,,,,,,,,,,, -249100,3.2551613,1.1362362,,,,,,,,,,,,,, -249200,3.0916212,1.6640099,,,,,,,,,,,,,, -249300,3.6093466,3.0002813,,,,,,,,,,,,,, -249400,3.3126123,2.4524665,,,,,,,,,,,,,, -249500,3.013404,2.0276866,,,,,,,,,,,,,, -249600,3.1263328,1.1128767,,,,,,,,,,,,,, -249643,,,0.8881640434265137,0.4143919646739959,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,109691.39629292488,119718.55010271072,109691.39629292488,9998.935094356537,15.17807126045227,0.0 -249700,3.259727,1.0716877,,,,,,,,,,,,,, -249800,2.9510703,1.0627791,,,,,,,,,,,,,, -249900,3.1476424,1.1603507,,,,,,,,,,,,,, -250000,2.906457,2.225141,,,,,,,,,,,,,, -250100,4.0003734,3.1410248,,,,,,,,,,,,,, -250200,3.066951,1.1223409,,,,,,,,,,,,,, -250300,3.0811982,1.2044001,,,,,,,,,,,,,, -250400,3.1358557,1.1587813,,,,,,,,,,,,,, -250500,3.3266404,2.284944,,,,,,,,,,,,,, -250598,,,0.8877539038658142,0.4176174700260162,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,110111.48210144044,120176.88127589226,110111.48210144044,10037.054334640505,15.254687309265137,0.0 -250600,3.441415,2.201058,,,,,,,,,,,,,, -250700,3.408155,1.1611254,,,,,,,,,,,,,, -250800,3.0691857,1.2285764,,,,,,,,,,,,,, -250900,3.262621,1.5511608,,,,,,,,,,,,,, -251000,3.067181,1.070348,,,,,,,,,,,,,, -251100,3.0957124,1.1872435,,,,,,,,,,,,,, -251200,2.988599,1.3130773,,,,,,,,,,,,,, -251300,3.095159,2.2503624,,,,,,,,,,,,,, -251400,3.024568,1.0845487,,,,,,,,,,,,,, -251500,3.1520658,2.394155,,,,,,,,,,,,,, -251552,,,0.8870507478713989,0.4220073521137237,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,110531.47845578194,120637.57777690887,110531.47845578194,10077.62155175209,15.338451147079468,0.0 -251600,2.9170666,1.1475337,,,,,,,,,,,,,, -251700,3.8757305,3.298609,,,,,,,,,,,,,, -251800,3.000672,1.658614,,,,,,,,,,,,,, -251900,3.5971587,3.256772,,,,,,,,,,,,,, -252000,3.0413725,1.4631361,,,,,,,,,,,,,, -252100,2.9941516,1.325408,,,,,,,,,,,,,, -252200,3.1329536,1.1725168,,,,,,,,,,,,,, -252300,3.00076,1.5655746,,,,,,,,,,,,,, -252400,3.2150314,1.1313747,,,,,,,,,,,,,, -252500,3.2537887,1.1331267,,,,,,,,,,,,,, -252511,,,0.8853319883346558,0.4239788651466369,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,110951.68130636217,121099.85851335526,110951.68130636217,10119.577965021132,15.408671617507936,0.0 -252600,3.018055,1.737585,,,,,,,,,,,,,, -252700,3.2854805,1.1121986,,,,,,,,,,,,,, -252800,3.3575075,1.1678079,,,,,,,,,,,,,, -252900,3.073928,1.0942777,,,,,,,,,,,,,, -253000,3.1951888,1.1170746,,,,,,,,,,,,,, -253100,3.4157639,1.1524935,,,,,,,,,,,,,, -253200,3.3355935,2.7600741,,,,,,,,,,,,,, -253300,3.2651422,1.3280857,,,,,,,,,,,,,, -253400,3.0156496,1.4187918,,,,,,,,,,,,,, -253469,,,0.88671875,0.4214539229869842,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,111371.58928275108,121561.9081993103,111371.58928275108,10161.588672876358,15.490390062332152,0.0 -253500,3.1838694,1.2204752,,,,,,,,,,,,,, -253600,3.073894,2.558175,,,,,,,,,,,,,, -253700,3.2291884,1.6029197,,,,,,,,,,,,,, -253800,3.0580382,2.3223276,,,,,,,,,,,,,, -253900,2.983148,1.1275483,,,,,,,,,,,,,, -254000,3.0166054,1.5056065,,,,,,,,,,,,,, -254100,3.809176,3.268714,,,,,,,,,,,,,, -254200,3.2595851,1.6507633,,,,,,,,,,,,,, -254300,3.8037019,3.0126827,,,,,,,,,,,,,, -254400,3.4515204,3.0873084,,,,,,,,,,,,,, -254423,,,0.8893554210662842,0.4129403531551361,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,111791.53270983696,122019.50779938698,111791.53270983696,10199.127958536148,15.558288812637327,0.0 -254500,3.239258,1.9291219,,,,,,,,,,,,,, -254600,3.8024588,3.385386,,,,,,,,,,,,,, -254700,3.3882635,2.9777532,,,,,,,,,,,,,, -254800,3.5296085,1.1309704,,,,,,,,,,,,,, -254900,3.0835872,1.2543502,,,,,,,,,,,,,, -255000,3.1225028,2.025946,,,,,,,,,,,,,, -255100,2.8986933,1.9595249,,,,,,,,,,,,,, -255200,3.388011,1.1503752,,,,,,,,,,,,,, -255300,3.8000338,2.0195346,,,,,,,,,,,,,, -255378,,,0.8873632550239563,0.4160084426403045,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,112211.63786792757,122489.02125787736,112211.63786792757,10248.407604455948,15.63877296447754,0.0 -255400,4.4267864,3.38485,,,,,,,,,,,,,, -255500,2.8308663,2.190609,,,,,,,,,,,,,, -255600,3.231443,1.8172033,,,,,,,,,,,,,, -255700,3.005757,1.1956917,,,,,,,,,,,,,, -255800,4.136733,3.1193416,,,,,,,,,,,,,, -255900,2.821478,1.398383,,,,,,,,,,,,,, -256000,3.068193,1.0393721,,,,,,,,,,,,,, -256100,3.3121405,2.766635,,,,,,,,,,,,,, -256200,3.1742685,1.0915263,,,,,,,,,,,,,, -256300,3.0150766,1.2625415,,,,,,,,,,,,,, -256336,,,0.8903515338897705,0.4116775095462799,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,112631.70882320404,122943.07149362564,112631.70882320404,10282.272267341614,15.70403218269348,0.0 -256400,4.0546384,3.0090609,,,,,,,,,,,,,, -256500,3.2675517,1.1042647,,,,,,,,,,,,,, -256600,3.1305265,1.3914984,,,,,,,,,,,,,, -256700,3.0327952,1.2053448,,,,,,,,,,,,,, -256800,3.2601323,2.107667,,,,,,,,,,,,,, -256900,3.320481,1.110458,,,,,,,,,,,,,, -257000,3.1442337,1.4919087,,,,,,,,,,,,,, -257100,3.0701675,1.066604,,,,,,,,,,,,,, -257200,3.193284,2.7600708,,,,,,,,,,,,,, -257290,,,0.8877929449081421,0.4188413619995117,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,113052.01241707802,123413.9364683628,113052.01241707802,10332.705268383026,15.78389596939087,0.0 -257300,3.139405,1.210401,,,,,,,,,,,,,, -257400,3.3777313,1.1473926,,,,,,,,,,,,,, -257500,3.0689664,1.1879653,,,,,,,,,,,,,, -257600,3.2115793,2.4801588,,,,,,,,,,,,,, -257700,3.172965,1.9908166,,,,,,,,,,,,,, -257800,2.9205647,2.0818086,,,,,,,,,,,,,, -257900,3.1223404,1.387126,,,,,,,,,,,,,, -258000,3.1791654,1.1515453,,,,,,,,,,,,,, -258100,2.9754868,1.1091216,,,,,,,,,,,,,, -258200,3.3429363,1.9598746,,,,,,,,,,,,,, -258246,,,0.8865820169448853,0.4241874516010284,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,113471.94405174255,123874.68347358704,113471.94405174255,10373.404185056686,15.850019216537476,0.0 -258300,3.147708,2.1842716,,,,,,,,,,,,,, -258400,3.3260474,1.1151273,,,,,,,,,,,,,, -258500,3.2691822,2.5593379,,,,,,,,,,,,,, -258600,3.0985718,1.213272,,,,,,,,,,,,,, -258700,3.1263855,0.995229,,,,,,,,,,,,,, -258800,3.290146,1.2643347,,,,,,,,,,,,,, -258900,3.090248,1.0791453,,,,,,,,,,,,,, -259000,3.9201908,3.349528,,,,,,,,,,,,,, -259100,3.2016196,1.2672923,,,,,,,,,,,,,, -259200,3.0804195,1.1603558,,,,,,,,,,,,,, -259201,,,0.8890624642372131,0.415942519903183,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,113892.2898645401,124333.91425085068,113892.2898645401,10412.161917448044,15.914101839065552,0.0 -259300,3.3381622,2.5280905,,,,,,,,,,,,,, -259400,3.352676,2.542524,,,,,,,,,,,,,, -259500,3.654228,3.0723062,,,,,,,,,,,,,, -259600,2.7766242,1.933799,,,,,,,,,,,,,, -259700,3.6782017,3.1987524,,,,,,,,,,,,,, -259800,2.9135482,1.2904925,,,,,,,,,,,,,, -259900,3.0103323,1.4527608,,,,,,,,,,,,,, -260000,2.8984964,1.8201294,,,,,,,,,,,,,, -260100,3.058367,1.0820775,,,,,,,,,,,,,, -260147,,,0.8899999856948853,0.407744437456131,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,114312.50053668022,124796.59203076364,114312.50053668022,10454.48885679245,16.005540370941162,0.0 -260200,3.641067,2.3720784,,,,,,,,,,,,,, -260300,3.8458598,3.1571069,,,,,,,,,,,,,, -260400,3.2591856,2.3581343,,,,,,,,,,,,,, -260500,3.0447717,1.2693226,,,,,,,,,,,,,, -260600,3.047546,2.0492651,,,,,,,,,,,,,, -260700,3.8833373,3.4235935,,,,,,,,,,,,,, -260800,3.632857,2.9230175,,,,,,,,,,,,,, -260900,3.6968386,1.2083187,,,,,,,,,,,,,, -261000,3.3774607,2.0820324,,,,,,,,,,,,,, -261098,,,0.8884179592132568,0.4153020083904266,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,114732.67173314096,125255.21079921722,114732.67173314096,10492.80323791504,16.09019136428833,0.0 -261100,3.1038055,2.951885,,,,,,,,,,,,,, -261200,3.068866,0.97854555,,,,,,,,,,,,,, -261300,3.0234358,1.387064,,,,,,,,,,,,,, -261400,3.2392962,1.142882,,,,,,,,,,,,,, -261500,3.463419,3.1616986,,,,,,,,,,,,,, -261600,3.3731153,1.5342121,,,,,,,,,,,,,, -261700,3.2803714,2.6212301,,,,,,,,,,,,,, -261800,2.9313142,1.7078508,,,,,,,,,,,,,, -261900,3.3418396,1.1171579,,,,,,,,,,,,,, -262000,3.8072526,1.2705396,,,,,,,,,,,,,, -262052,,,0.8883398175239563,0.4166604876518249,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,115152.60396766664,125712.29814648628,115152.60396766664,10529.83626651764,16.163026809692383,0.0 -262100,3.0179605,1.1269498,,,,,,,,,,,,,, -262200,3.0286608,2.0275939,,,,,,,,,,,,,, -262300,3.3612888,1.1274731,,,,,,,,,,,,,, -262400,2.9792428,1.1865685,,,,,,,,,,,,,, -262500,3.4002068,2.6136396,,,,,,,,,,,,,, -262600,3.3388093,2.5991647,,,,,,,,,,,,,, -262700,3.0911696,1.0993063,,,,,,,,,,,,,, -262800,3.9029884,3.3048337,,,,,,,,,,,,,, -262900,3.2186737,1.2162647,,,,,,,,,,,,,, -263000,3.0348766,1.6406116,,,,,,,,,,,,,, -263006,,,0.8867968320846558,0.4171631932258606,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,115572.77207779884,126180.26666045187,115572.77207779884,10577.499941587448,16.24975323677063,0.0 -263100,3.4397535,2.2119117,,,,,,,,,,,,,, -263200,3.1725595,1.2905631,,,,,,,,,,,,,, -263300,3.1549988,1.1072582,,,,,,,,,,,,,, -263400,3.2572024,2.63246,,,,,,,,,,,,,, -263500,2.9302616,1.0640469,,,,,,,,,,,,,, -263600,3.069477,2.054931,,,,,,,,,,,,,, -263700,3.9456575,2.2978256,,,,,,,,,,,,,, -263800,3.540786,2.7550385,,,,,,,,,,,,,, -263900,3.047333,1.0903482,,,,,,,,,,,,,, -263965,,,0.8879296779632568,0.4193416833877563,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,115992.9700987339,126641.70953035356,115992.9700987339,10618.628512144089,16.316112995147705,0.0 -264000,3.0574965,1.1837395,,,,,,,,,,,,,, -264100,3.3358953,1.0857589,,,,,,,,,,,,,, -264200,3.6477969,1.540107,,,,,,,,,,,,,, -264300,2.8567865,1.6258364,,,,,,,,,,,,,, -264400,3.0407658,1.613427,,,,,,,,,,,,,, -264500,2.9062366,1.077967,,,,,,,,,,,,,, -264600,3.219336,1.1604307,,,,,,,,,,,,,, -264700,3.0755765,1.3023753,,,,,,,,,,,,,, -264800,3.0098238,1.17831,,,,,,,,,,,,,, -264900,3.152744,1.1148108,,,,,,,,,,,,,, -264920,,,0.8895702958106995,0.4095839560031891,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,116412.8924088478,127098.88965320589,116412.8924088478,10655.76745057106,16.38568639755249,0.0 -265000,3.1183774,1.3654985,,,,,,,,,,,,,, -265100,3.217218,1.1773893,,,,,,,,,,,,,, -265200,3.1057436,1.1107012,,,,,,,,,,,,,, -265300,2.9758012,1.918772,,,,,,,,,,,,,, -265400,3.0069826,1.1350694,,,,,,,,,,,,,, -265500,2.929637,1.6378975,,,,,,,,,,,,,, -265600,3.9445152,1.1838082,,,,,,,,,,,,,, -265700,3.010138,1.2302536,,,,,,,,,,,,,, -265800,3.3861048,3.0645864,,,,,,,,,,,,,, -265874,,,0.8883984088897705,0.4182425737380981,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,116832.95955109596,127565.38592290878,116832.95955109596,10702.066796064377,16.466209650039673,0.0 -265900,3.86813,3.1789815,,,,,,,,,,,,,, -266000,3.1350572,2.119687,,,,,,,,,,,,,, -266100,3.0016162,1.0711962,,,,,,,,,,,,,, -266200,3.2675302,2.1600785,,,,,,,,,,,,,, -266300,3.1245522,1.1367561,,,,,,,,,,,,,, -266400,3.61662,3.320282,,,,,,,,,,,,,, -266500,3.5713315,1.3083763,,,,,,,,,,,,,, -266600,3.1105998,2.4214456,,,,,,,,,,,,,, -266700,3.3117232,2.7499602,,,,,,,,,,,,,, -266800,3.3170815,2.4458506,,,,,,,,,,,,,, -266832,,,0.8867382407188416,0.4178032875061035,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,117253.22350525856,128020.87030768394,117253.22350525856,10737.159190177916,16.54511594772339,0.0 -266900,3.2749116,1.1273391,,,,,,,,,,,,,, -267000,3.6286387,1.1898417,,,,,,,,,,,,,, -267100,3.0988817,1.1556029,,,,,,,,,,,,,, -267200,2.980457,1.0959634,,,,,,,,,,,,,, -267300,3.2772467,1.1084063,,,,,,,,,,,,,, -267400,3.3041747,1.7996873,,,,,,,,,,,,,, -267500,3.4552274,1.2718631,,,,,,,,,,,,,, -267600,2.9077728,2.4525766,,,,,,,,,,,,,, -267700,3.297741,1.17637,,,,,,,,,,,,,, -267786,,,0.8861132860183716,0.4290775954723358,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,117673.31190681458,128480.0262553692,117673.31190681458,10776.096898555756,16.626272678375244,0.0 -267800,3.233673,1.122156,,,,,,,,,,,,,, -267900,3.1695547,2.405002,,,,,,,,,,,,,, -268000,3.1518564,1.37643,,,,,,,,,,,,,, -268100,2.949061,1.0309972,,,,,,,,,,,,,, -268200,2.8681734,1.4346957,,,,,,,,,,,,,, -268300,3.0046926,1.5171299,,,,,,,,,,,,,, -268400,3.2213435,1.0472631,,,,,,,,,,,,,, -268500,3.0205593,1.0257751,,,,,,,,,,,,,, -268600,3.2585819,1.8742294,,,,,,,,,,,,,, -268700,3.0651648,1.6960564,,,,,,,,,,,,,, -268739,,,0.8898437023162842,0.4073122441768646,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,118093.4741909504,128939.01418423653,118093.4741909504,10814.793027639387,16.707128047943115,0.0 -268800,2.9929242,1.2277857,,,,,,,,,,,,,, -268900,3.0122762,1.737993,,,,,,,,,,,,,, -269000,3.5464854,1.3290539,,,,,,,,,,,,,, -269100,2.7797577,1.9011152,,,,,,,,,,,,,, -269200,3.2476344,2.836554,,,,,,,,,,,,,, -269300,3.3506708,1.1605259,,,,,,,,,,,,,, -269400,2.9232464,1.7266269,,,,,,,,,,,,,, -269500,2.8545668,1.6245205,,,,,,,,,,,,,, -269600,3.2947993,1.2602131,,,,,,,,,,,,,, -269697,,,0.8878124952316284,0.414100170135498,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,118513.58444952963,129400.95628118516,118513.58444952963,10856.493202209473,16.789478540420532,0.0 -269700,3.392314,2.6913283,,,,,,,,,,,,,, -269800,2.8933876,1.6439935,,,,,,,,,,,,,, -269900,3.460042,1.1193194,,,,,,,,,,,,,, -270000,3.4534838,2.9979637,,,,,,,,,,,,,, -270100,3.3038633,2.1050456,,,,,,,,,,,,,, -270200,3.194836,2.470438,,,,,,,,,,,,,, -270300,3.6147501,1.1925722,,,,,,,,,,,,,, -270400,3.0699,1.1214437,,,,,,,,,,,,,, -270500,3.7221625,3.129975,,,,,,,,,,,,,, -270600,3.0140843,1.5597144,,,,,,,,,,,,,, -270655,,,0.8884375095367432,0.4185597002506256,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,118933.8285355568,129856.05528616904,118933.8285355568,10891.216101408005,16.87260341644287,0.0 -270700,3.2067654,1.9106635,,,,,,,,,,,,,, -270800,2.9100869,1.2205484,,,,,,,,,,,,,, -270900,2.9992046,2.0258393,,,,,,,,,,,,,, -271000,3.314685,2.810482,,,,,,,,,,,,,, -271100,3.2583163,1.1058303,,,,,,,,,,,,,, -271200,3.0008361,1.0582826,,,,,,,,,,,,,, -271300,3.6920066,1.7089527,,,,,,,,,,,,,, -271400,3.20595,1.8433845,,,,,,,,,,,,,, -271500,3.9063787,3.170279,,,,,,,,,,,,,, -271600,3.148329,1.2141747,,,,,,,,,,,,,, -271611,,,0.8877343535423279,0.4192695617675781,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,119353.74111104012,130312.5183684826,119353.74111104012,10927.63363814354,16.95663595199585,0.0 -271700,3.0641136,1.203507,,,,,,,,,,,,,, -271800,2.9508677,1.3206388,,,,,,,,,,,,,, -271900,3.6709464,1.1423465,,,,,,,,,,,,,, -272000,3.0556357,1.3518965,,,,,,,,,,,,,, -272100,3.4712393,1.0409367,,,,,,,,,,,,,, -272200,3.1374722,1.1570648,,,,,,,,,,,,,, -272300,3.5111551,2.9552557,,,,,,,,,,,,,, -272400,2.990958,1.464833,,,,,,,,,,,,,, -272500,2.5885115,1.6987062,,,,,,,,,,,,,, -272566,,,0.8866015672683716,0.4227052330970764,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,119773.71231675148,130773.18575668336,119773.71231675148,10968.197939157486,17.039021015167236,0.0 -272600,3.3288517,1.1028234,,,,,,,,,,,,,, -272700,3.1924932,1.6106341,,,,,,,,,,,,,, -272800,3.5211194,1.1097102,,,,,,,,,,,,,, -272900,3.1491826,1.0876184,,,,,,,,,,,,,, -273000,3.5197864,2.8818038,,,,,,,,,,,,,, -273100,3.0926208,1.7354648,,,,,,,,,,,,,, -273200,3.3731694,3.1845806,,,,,,,,,,,,,, -273300,2.8936381,1.7036735,,,,,,,,,,,,,, -273400,2.9927306,1.1813681,,,,,,,,,,,,,, -273500,3.083445,2.5359433,,,,,,,,,,,,,, -273525,,,0.8893554210662842,0.4108372330665588,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,120193.92505979538,131241.99991846085,120193.92505979538,11016.682899475098,17.105791568756104,0.0 -273600,3.7313864,3.264911,,,,,,,,,,,,,, -273700,2.888548,1.2640632,,,,,,,,,,,,,, -273800,3.1544962,2.666558,,,,,,,,,,,,,, -273900,3.1874945,1.1557647,,,,,,,,,,,,,, -274000,3.4848003,1.2333614,,,,,,,,,,,,,, -274100,2.8612266,1.6397519,,,,,,,,,,,,,, -274200,3.147991,1.0232263,,,,,,,,,,,,,, -274300,3.1131368,1.123935,,,,,,,,,,,,,, -274400,3.1301105,2.4493494,,,,,,,,,,,,,, -274484,,,0.8886327743530273,0.4180715084075928,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,120614.32047319412,131703.71882891655,120614.32047319412,11057.888793230057,17.17463517189026,0.0 -274500,3.0338368,1.14198,,,,,,,,,,,,,, -274600,3.301009,1.0925859,,,,,,,,,,,,,, -274700,3.5205138,2.9628837,,,,,,,,,,,,,, -274800,3.010574,1.9946352,,,,,,,,,,,,,, -274900,3.1579409,1.1321031,,,,,,,,,,,,,, -275000,2.9751513,1.241799,,,,,,,,,,,,,, -275100,3.3712478,2.7036424,,,,,,,,,,,,,, -275200,3.3348837,1.1664919,,,,,,,,,,,,,, -275300,3.451131,2.9799974,,,,,,,,,,,,,, -275400,3.186951,1.3682706,,,,,,,,,,,,,, -275441,,,0.8881054520606995,0.4140663146972656,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,121034.56578993796,132162.7273669243,121034.56578993796,11096.530444145204,17.246686697006226,0.0 -275500,3.2489648,1.3330137,,,,,,,,,,,,,, -275600,3.5253007,2.844765,,,,,,,,,,,,,, -275700,3.6320066,3.3182511,,,,,,,,,,,,,, -275800,3.140614,1.1569533,,,,,,,,,,,,,, -275900,3.2286155,1.2637647,,,,,,,,,,,,,, -276000,3.2066867,1.1869489,,,,,,,,,,,,,, -276100,2.9188,2.0436404,,,,,,,,,,,,,, -276200,3.3507404,2.346631,,,,,,,,,,,,,, -276300,3.315805,1.1601052,,,,,,,,,,,,,, -276382,,,0.8871874809265137,0.4224247932434082,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,121454.5478489399,132615.83280420303,121454.5478489399,11129.518389940262,17.333673238754272,0.0 -276400,3.4674861,1.4181674,,,,,,,,,,,,,, -276500,3.8760536,3.2299623,,,,,,,,,,,,,, -276600,2.9216845,1.281771,,,,,,,,,,,,,, -276700,3.1538754,1.1026734,,,,,,,,,,,,,, -276800,4.120682,3.2378218,,,,,,,,,,,,,, -276900,3.4470131,1.4581617,,,,,,,,,,,,,, -277000,3.4193816,1.3775234,,,,,,,,,,,,,, -277100,3.016479,1.7954451,,,,,,,,,,,,,, -277200,3.359738,2.2285056,,,,,,,,,,,,,, -277300,3.523052,2.9565494,,,,,,,,,,,,,, -277326,,,0.8856640458106995,0.4246830940246582,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,121874.64784169196,133072.1931116581,121874.64784169196,11165.648698568344,17.41505718231201,0.0 -277400,3.2851603,1.1032343,,,,,,,,,,,,,, -277500,4.0644884,3.3452063,,,,,,,,,,,,,, -277600,3.2857504,2.8349783,,,,,,,,,,,,,, -277700,3.03002,2.349602,,,,,,,,,,,,,, -277800,3.1057107,2.1261265,,,,,,,,,,,,,, -277900,3.7113943,3.0886786,,,,,,,,,,,,,, -278000,2.8998754,1.238448,,,,,,,,,,,,,, -278100,4.4792767,3.2805405,,,,,,,,,,,,,, -278200,3.2505193,1.1105707,,,,,,,,,,,,,, -278279,,,0.8876367211341858,0.4154536426067352,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,122294.87970471382,133531.81871771812,122294.87970471382,11204.90605711937,17.502548217773438,0.0 -278300,3.2058797,1.2196231,,,,,,,,,,,,,, -278400,2.9569106,1.2071972,,,,,,,,,,,,,, -278500,3.0417469,2.0611625,,,,,,,,,,,,,, -278600,3.2799044,2.7827175,,,,,,,,,,,,,, -278700,3.2455614,1.1667265,,,,,,,,,,,,,, -278800,3.1291797,2.1497004,,,,,,,,,,,,,, -278900,3.1267865,1.1041319,,,,,,,,,,,,,, -279000,3.4095364,3.0716941,,,,,,,,,,,,,, -279100,2.9419692,1.7160431,,,,,,,,,,,,,, -279200,3.1508036,2.37102,,,,,,,,,,,,,, -279237,,,0.88832026720047,0.4142638742923736,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,122715.11296343803,133990.30941224098,122715.11296343803,11243.041862249374,17.574153184890747,0.0 -279300,3.1178727,1.1577821,,,,,,,,,,,,,, -279400,3.3507562,2.7328122,,,,,,,,,,,,,, -279500,2.844956,1.2430633,,,,,,,,,,,,,, -279600,3.2192078,1.1948197,,,,,,,,,,,,,, -279700,3.127745,2.557343,,,,,,,,,,,,,, -279800,3.9430375,3.2817898,,,,,,,,,,,,,, -279900,3.13672,1.6760057,,,,,,,,,,,,,, -280000,3.2375338,1.5723426,,,,,,,,,,,,,, -280100,3.615632,3.3084474,,,,,,,,,,,,,, -280193,,,0.8883007764816284,0.4194623529911041,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,123135.0280020237,134441.34727716446,123135.0280020237,11274.032299041748,17.657609224319458,0.0 -280200,3.1796584,1.0940392,,,,,,,,,,,,,, -280300,3.242215,2.840715,,,,,,,,,,,,,, -280400,3.4406316,2.1789932,,,,,,,,,,,,,, -280500,3.258551,1.186025,,,,,,,,,,,,,, -280600,3.1974585,2.169892,,,,,,,,,,,,,, -280700,3.2427814,1.1413808,,,,,,,,,,,,,, -280800,2.9100015,1.1885237,,,,,,,,,,,,,, -280900,4.0463285,3.267323,,,,,,,,,,,,,, -281000,2.873598,1.5534694,,,,,,,,,,,,,, -281100,3.182178,1.2893689,,,,,,,,,,,,,, -281148,,,0.8873046636581421,0.4161858260631561,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,123555.07653808594,134907.18300938606,123555.07653808594,11319.689100265505,17.738707065582275,0.0 -281200,3.3104005,1.1600295,,,,,,,,,,,,,, -281300,2.9378362,1.6084429,,,,,,,,,,,,,, -281400,2.9760332,1.4421976,,,,,,,,,,,,,, -281500,2.980939,1.6707286,,,,,,,,,,,,,, -281600,3.3521125,0.96336746,,,,,,,,,,,,,, -281700,3.2659018,2.0526137,,,,,,,,,,,,,, -281800,3.490763,2.5831938,,,,,,,,,,,,,, -281900,3.4259079,1.1187353,,,,,,,,,,,,,, -282000,3.5581217,1.1564882,,,,,,,,,,,,,, -282100,3.0342915,1.1846235,,,,,,,,,,,,,, -282104,,,0.8891406059265137,0.4142415225505829,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,123975.168166399,135364.49425768852,123975.168166399,11356.789668560028,17.809013605117798,0.0 -282200,3.2014031,1.1720788,,,,,,,,,,,,,, -282300,3.081549,1.1192862,,,,,,,,,,,,,, -282400,3.0429592,1.8321244,,,,,,,,,,,,,, -282500,3.9752598,3.3190627,,,,,,,,,,,,,, -282600,3.2738082,1.7768905,,,,,,,,,,,,,, -282700,3.351399,2.3020558,,,,,,,,,,,,,, -282800,3.4300067,2.6697578,,,,,,,,,,,,,, -282900,3.223833,1.2692825,,,,,,,,,,,,,, -283000,3.1185062,1.4211508,,,,,,,,,,,,,, -283058,,,0.8882812261581421,0.4211524426937103,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,124395.1518342495,135828.45954036713,124395.1518342495,11400.639189958572,17.891976594924927,0.0 -283100,3.0211954,1.0816209,,,,,,,,,,,,,, -283200,2.9471486,1.3819996,,,,,,,,,,,,,, -283300,3.4308593,2.9486272,,,,,,,,,,,,,, -283400,3.9850514,3.2625701,,,,,,,,,,,,,, -283500,3.2463171,1.0903724,,,,,,,,,,,,,, -283600,3.238787,1.2067231,,,,,,,,,,,,,, -283700,2.8760982,1.4737208,,,,,,,,,,,,,, -283800,3.055449,1.4891406,,,,,,,,,,,,,, -283900,3.127305,1.2147205,,,,,,,,,,,,,, -284000,3.0302958,1.154207,,,,,,,,,,,,,, -284012,,,0.8887304663658142,0.4123408496379852,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,124815.4282989502,136288.5594909191,124815.4282989502,11440.33082962036,17.974778175354004,0.0 -284100,3.1285899,2.7126048,,,,,,,,,,,,,, -284200,3.8256745,3.3374612,,,,,,,,,,,,,, -284300,3.0283973,1.510385,,,,,,,,,,,,,, -284400,3.5288606,2.069666,,,,,,,,,,,,,, -284500,3.0906508,1.1510403,,,,,,,,,,,,,, -284600,2.9838207,1.7957482,,,,,,,,,,,,,, -284700,3.1123872,1.0776975,,,,,,,,,,,,,, -284800,3.7131991,3.2947016,,,,,,,,,,,,,, -284900,3.0226796,1.5318149,,,,,,,,,,,,,, -284966,,,0.8908007740974426,0.410673975944519,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,125235.39047026634,136744.59841275215,125235.39047026634,11476.264526367188,18.068246841430664,0.0 -285000,3.2038486,2.6829684,,,,,,,,,,,,,, -285100,2.9564872,1.0164031,,,,,,,,,,,,,, -285200,3.206678,1.0335941,,,,,,,,,,,,,, -285300,3.1138222,1.0165555,,,,,,,,,,,,,, -285400,3.3114946,2.1005301,,,,,,,,,,,,,, -285500,3.7912612,1.0842811,,,,,,,,,,,,,, -285600,3.0056922,1.0761082,,,,,,,,,,,,,, -285700,3.304003,1.4715207,,,,,,,,,,,,,, -285800,3.1624475,1.1506243,,,,,,,,,,,,,, -285900,3.4167707,1.2545186,,,,,,,,,,,,,, -285915,,,0.8874804377555847,0.4160492718219757,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,125655.38073897362,137208.1740090847,125655.38073897362,11519.718435525894,18.15044403076172,0.0 -286000,3.1583037,1.1675825,,,,,,,,,,,,,, -286100,3.0629864,1.07939,,,,,,,,,,,,,, -286200,3.1266372,1.1096808,,,,,,,,,,,,,, -286300,3.2470083,1.6393392,,,,,,,,,,,,,, -286400,2.79851,1.7912004,,,,,,,,,,,,,, -286500,3.3828676,1.0049843,,,,,,,,,,,,,, -286600,4.0290365,1.4378755,,,,,,,,,,,,,, -286700,3.235266,1.6225076,,,,,,,,,,,,,, -286800,3.8613749,3.3319564,,,,,,,,,,,,,, -286871,,,0.88623046875,0.4197134375572204,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,126075.25695848464,137666.40572619438,126075.25695848464,11557.939510822296,18.23632121086121,0.0 -286900,2.9004233,1.2955651,,,,,,,,,,,,,, -287000,3.3648345,2.8933253,,,,,,,,,,,,,, -287100,3.3306963,1.241822,,,,,,,,,,,,,, -287200,3.129507,1.1326692,,,,,,,,,,,,,, -287300,4.077969,1.200525,,,,,,,,,,,,,, -287400,2.9473104,1.6362772,,,,,,,,,,,,,, -287500,3.3648186,2.1959293,,,,,,,,,,,,,, -287600,2.9868884,1.1486237,,,,,,,,,,,,,, -287700,2.952162,1.2489477,,,,,,,,,,,,,, -287800,4.243385,3.473217,,,,,,,,,,,,,, -287826,,,0.8890624642372131,0.41220623254776,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,126495.22283816338,138124.47844481468,126495.22283816338,11595.891792058945,18.34083724021912,0.0 -287900,3.4380991,3.1273718,,,,,,,,,,,,,, -288000,3.7558486,3.0864978,,,,,,,,,,,,,, -288100,3.148085,1.5642129,,,,,,,,,,,,,, -288200,3.2160606,1.1017237,,,,,,,,,,,,,, -288300,3.4631586,2.999746,,,,,,,,,,,,,, -288400,2.9413774,1.966994,,,,,,,,,,,,,, -288500,3.0960288,1.0955449,,,,,,,,,,,,,, -288600,3.229153,1.099035,,,,,,,,,,,,,, -288700,3.2470002,1.1602683,,,,,,,,,,,,,, -288774,,,0.8875976204872131,0.4183759093284607,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,126915.48208451273,138585.06207585335,126915.48208451273,11636.082073688509,18.42564058303833,0.0 -288800,3.0134845,1.305628,,,,,,,,,,,,,, -288900,2.9427497,1.1306134,,,,,,,,,,,,,, -289000,3.0958822,1.0693945,,,,,,,,,,,,,, -289100,3.265908,1.1711702,,,,,,,,,,,,,, -289200,3.0881817,1.0887563,,,,,,,,,,,,,, -289300,3.5640998,3.0419235,,,,,,,,,,,,,, -289400,3.2711535,1.1246719,,,,,,,,,,,,,, -289500,3.018699,1.0546808,,,,,,,,,,,,,, -289600,3.6553884,3.2349339,,,,,,,,,,,,,, -289700,3.5497434,2.152591,,,,,,,,,,,,,, -289728,,,0.8895702958106995,0.4143651723861694,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,127335.6986641884,139044.97914934158,127335.6986641884,11675.6428399086,18.51556301116944,0.0 -289800,2.980631,1.0916493,,,,,,,,,,,,,, -289900,3.002574,1.1663828,,,,,,,,,,,,,, -290000,3.0754297,1.0484632,,,,,,,,,,,,,, -290100,3.2576425,1.2914783,,,,,,,,,,,,,, -290200,3.4750276,2.7996025,,,,,,,,,,,,,, -290300,3.220566,1.0569557,,,,,,,,,,,,,, -290400,3.0811057,1.1456031,,,,,,,,,,,,,, -290500,3.7116997,3.2207675,,,,,,,,,,,,,, -290600,3.1497614,1.2081815,,,,,,,,,,,,,, -290682,,,0.884765625,0.4241220355033874,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,127755.82724237442,139508.1933040619,127755.82724237442,11718.595754623411,18.59935998916626,0.0 -290700,3.1176274,1.1922214,,,,,,,,,,,,,, -290800,3.1987295,1.2383606,,,,,,,,,,,,,, -290900,3.0186174,1.4131513,,,,,,,,,,,,,, -291000,3.2958841,1.236841,,,,,,,,,,,,,, -291100,3.2102737,1.3729845,,,,,,,,,,,,,, -291200,2.9430885,1.1285738,,,,,,,,,,,,,, -291300,3.5081735,2.6199045,,,,,,,,,,,,,, -291400,3.4382615,2.7987533,,,,,,,,,,,,,, -291500,2.9946363,1.0090793,,,,,,,,,,,,,, -291600,3.08117,0.9972866,,,,,,,,,,,,,, -291638,,,0.8886523246765137,0.417254239320755,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,128176.1004807949,139966.8594853878,128176.1004807949,11756.851325035095,18.68771505355835,0.0 -291700,3.8416884,3.164355,,,,,,,,,,,,,, -291800,2.8318443,1.6120113,,,,,,,,,,,,,, -291900,3.14417,1.1565636,,,,,,,,,,,,,, -292000,3.299387,1.5893145,,,,,,,,,,,,,, -292100,2.9231725,1.1466129,,,,,,,,,,,,,, -292200,3.7673,3.1977963,,,,,,,,,,,,,, -292300,3.1829634,1.2289026,,,,,,,,,,,,,, -292400,3.3663826,2.6231759,,,,,,,,,,,,,, -292500,3.0465074,1.9554527,,,,,,,,,,,,,, -292593,,,0.8885546922683716,0.4134081602096557,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,128596.28019595146,140427.85061454773,128596.28019595146,11797.527698516846,18.773333311080933,0.0 -292600,3.0085626,1.1175826,,,,,,,,,,,,,, -292700,3.0254538,1.1148716,,,,,,,,,,,,,, -292800,3.177009,1.933104,,,,,,,,,,,,,, -292900,3.2753837,1.1097529,,,,,,,,,,,,,, -293000,3.0509024,1.1029285,,,,,,,,,,,,,, -293100,2.8437386,1.1730293,,,,,,,,,,,,,, -293200,3.0972743,1.1546159,,,,,,,,,,,,,, -293300,3.0926359,1.5753015,,,,,,,,,,,,,, -293400,3.0353882,2.786749,,,,,,,,,,,,,, -293500,3.0017524,1.088298,,,,,,,,,,,,,, -293549,,,0.8882030844688416,0.4148542284965515,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,129016.2508816719,140886.34425234795,129016.2508816719,11835.907625436785,18.86729431152344,0.0 -293600,3.3446336,1.157823,,,,,,,,,,,,,, -293700,3.924482,3.1606674,,,,,,,,,,,,,, -293800,3.169425,1.4366302,,,,,,,,,,,,,, -293900,3.1762745,2.7213943,,,,,,,,,,,,,, -294000,3.0435724,1.4682249,,,,,,,,,,,,,, -294100,3.6732981,1.5562439,,,,,,,,,,,,,, -294200,2.9668262,1.4717877,,,,,,,,,,,,,, -294300,3.1842752,2.6252203,,,,,,,,,,,,,, -294400,3.1539235,1.1312845,,,,,,,,,,,,,, -294500,3.1653104,1.083491,,,,,,,,,,,,,, -294501,,,0.8893749713897705,0.414853423833847,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,129436.32824611664,141344.36743688583,129436.32824611664,11873.716660499573,18.954487562179565,0.0 -294600,3.2063606,1.1150203,,,,,,,,,,,,,, -294700,3.1666462,1.1776754,,,,,,,,,,,,,, -294800,3.008181,1.3924928,,,,,,,,,,,,,, -294900,3.0559058,1.131311,,,,,,,,,,,,,, -295000,3.4364004,2.8953469,,,,,,,,,,,,,, -295100,2.9656444,1.6862961,,,,,,,,,,,,,, -295200,3.3020031,1.1122601,,,,,,,,,,,,,, -295300,3.305051,2.3442895,,,,,,,,,,,,,, -295400,3.7231882,3.037642,,,,,,,,,,,,,, -295453,,,0.88929682970047,0.4171392321586609,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,129856.59120893478,141805.1029598713,129856.59120893478,11914.05112528801,19.04377579689026,0.0 -295500,3.0860012,2.382329,,,,,,,,,,,,,, -295600,3.0236108,1.0743542,,,,,,,,,,,,,, -295700,4.136633,3.28468,,,,,,,,,,,,,, -295800,3.1512501,1.1597037,,,,,,,,,,,,,, -295900,3.5972846,1.1529205,,,,,,,,,,,,,, -296000,3.6420393,3.1007423,,,,,,,,,,,,,, -296100,3.3814669,1.1214957,,,,,,,,,,,,,, -296200,3.2365484,1.0823803,,,,,,,,,,,,,, -296300,3.0670595,1.0226494,,,,,,,,,,,,,, -296400,3.269351,1.3193972,,,,,,,,,,,,,, -296408,,,0.8854687213897705,0.4234060049057007,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,130276.792396307,142272.9465227127,130276.792396307,11961.55670619011,19.131915807724,0.0 -296500,3.2604797,1.1056085,,,,,,,,,,,,,, -296600,3.1192749,1.1359199,,,,,,,,,,,,,, -296700,3.3866942,1.2118692,,,,,,,,,,,,,, -296800,3.7766428,3.0687258,,,,,,,,,,,,,, -296900,3.4093335,3.0515423,,,,,,,,,,,,,, -297000,3.306071,1.2365901,,,,,,,,,,,,,, -297100,3.1770105,1.5626955,,,,,,,,,,,,,, -297200,3.4805844,1.1070065,,,,,,,,,,,,,, -297300,3.5461383,3.220193,,,,,,,,,,,,,, -297369,,,0.8866601586341858,0.4202454388141632,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,130696.7673239708,142738.7496433258,130696.7673239708,12007.264422655106,19.20346760749817,0.0 -297400,3.7038958,3.1711216,,,,,,,,,,,,,, -297500,3.336319,1.1534705,,,,,,,,,,,,,, -297600,3.1643746,1.1423925,,,,,,,,,,,,,, -297700,2.7846525,1.9742899,,,,,,,,,,,,,, -297800,3.2292597,1.7762907,,,,,,,,,,,,,, -297900,3.780136,1.1617692,,,,,,,,,,,,,, -298000,2.91393,1.797421,,,,,,,,,,,,,, -298100,3.1892421,1.5991536,,,,,,,,,,,,,, -298200,3.2245338,2.7956038,,,,,,,,,,,,,, -298300,3.0947137,1.0736674,,,,,,,,,,,,,, -298328,,,0.8893163800239563,0.4147822558879852,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,131116.9161388874,143204.24835014343,131116.9161388874,12052.48368382454,19.28312730789185,0.0 -298400,5.3570995,2.6278768,,,,,,,,,,,,,, -298500,3.1022198,1.0940287,,,,,,,,,,,,,, -298600,3.0330174,1.095135,,,,,,,,,,,,,, -298700,3.384075,1.4573789,,,,,,,,,,,,,, -298800,2.8882837,1.0488129,,,,,,,,,,,,,, -298900,3.15929,1.138753,,,,,,,,,,,,,, -299000,3.3113074,2.3266196,,,,,,,,,,,,,, -299100,3.1708877,2.6916342,,,,,,,,,,,,,, -299200,2.9726481,1.1993917,,,,,,,,,,,,,, -299282,,,0.8881054520606995,0.4154012203216553,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,131537.1429643631,143663.2446167469,131537.1429643631,12091.130192756653,19.356364011764526,0.0 -299300,3.249367,1.220312,,,,,,,,,,,,,, -299400,3.1159658,2.918762,,,,,,,,,,,,,, -299500,3.3108277,1.0590036,,,,,,,,,,,,,, -299600,2.9322124,1.7371403,,,,,,,,,,,,,, -299700,3.0934694,1.0926839,,,,,,,,,,,,,, -299800,3.1299074,2.4721506,,,,,,,,,,,,,, -299900,3.0952215,1.5053986,,,,,,,,,,,,,, -300000,3.0995848,1.141564,,,,,,,,,,,,,, -300100,3.2940638,1.1819078,,,,,,,,,,,,,, -300200,3.568455,1.217786,,,,,,,,,,,,,, -300226,,,0.8876757621765137,0.4198732674121856,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,131957.35796141624,144120.52024149895,131957.35796141624,12128.053198814392,19.444679021835327,0.0 -300300,3.6859484,1.5737448,,,,,,,,,,,,,, -300400,3.0551283,1.1372932,,,,,,,,,,,,,, -300500,4.3070965,3.252706,,,,,,,,,,,,,, -300600,3.13956,1.650584,,,,,,,,,,,,,, -300700,3.223224,1.5632833,,,,,,,,,,,,,, -300800,3.604872,2.9873607,,,,,,,,,,,,,, -300900,3.1945715,1.1463379,,,,,,,,,,,,,, -301000,3.2002933,1.5604547,,,,,,,,,,,,,, -301100,3.0392485,1.0387602,,,,,,,,,,,,,, -301179,,,0.88636714220047,0.4219211935997009,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,132377.5421822071,144582.29728770256,132377.5421822071,12169.511289596558,19.53000783920288,0.0 -301200,4.7979255,1.2363153,,,,,,,,,,,,,, -301300,3.0735078,2.2238297,,,,,,,,,,,,,, -301400,3.3898897,1.6542559,,,,,,,,,,,,,, -301500,3.0095966,1.156276,,,,,,,,,,,,,, -301600,3.939294,3.2030044,,,,,,,,,,,,,, -301700,2.9873416,1.7831262,,,,,,,,,,,,,, -301800,3.1468132,1.9005151,,,,,,,,,,,,,, -301900,2.9490495,1.8249263,,,,,,,,,,,,,, -302000,3.4840088,1.185756,,,,,,,,,,,,,, -302100,3.1651115,1.3835597,,,,,,,,,,,,,, -302132,,,0.8879687190055847,0.4166455268859863,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,132797.46908450127,145040.6056535244,132797.46908450127,12207.755256175997,19.618409633636475,0.0 -302200,3.549473,2.5704606,,,,,,,,,,,,,, -302300,3.1090994,1.1124073,,,,,,,,,,,,,, -302400,2.8748443,1.5317776,,,,,,,,,,,,,, -302500,3.3642178,1.3927014,,,,,,,,,,,,,, -302600,2.8502269,1.3470254,,,,,,,,,,,,,, -302700,3.285225,1.1488119,,,,,,,,,,,,,, -302800,3.2695215,1.4331704,,,,,,,,,,,,,, -302900,2.9070072,1.3560691,,,,,,,,,,,,,, -303000,3.1240423,2.4989946,,,,,,,,,,,,,, -303087,,,0.8862499594688416,0.4185951948165893,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,133217.7573838234,145504.82874035835,133217.7573838234,12251.54194355011,19.716617345809937,0.0 -303100,3.0898793,1.1391151,,,,,,,,,,,,,, -303200,3.826513,3.328833,,,,,,,,,,,,,, -303300,3.0925791,1.8612074,,,,,,,,,,,,,, -303400,3.0396903,1.1404678,,,,,,,,,,,,,, -303500,3.5993652,3.2223399,,,,,,,,,,,,,, -303600,3.8006084,2.4046469,,,,,,,,,,,,,, -303700,3.0332189,2.4155376,,,,,,,,,,,,,, -303800,3.2353406,1.2694356,,,,,,,,,,,,,, -303900,3.0779579,1.7775497,,,,,,,,,,,,,, -304000,3.0131598,1.1695944,,,,,,,,,,,,,, -304045,,,0.8896093368530273,0.4145287573337555,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,133637.78238511086,145963.89064216614,133637.78238511086,12290.439830303192,19.80525302886963,0.0 -304100,3.2642782,1.0922121,,,,,,,,,,,,,, -304200,3.8370306,3.3076487,,,,,,,,,,,,,, -304300,4.4889326,3.3084366,,,,,,,,,,,,,, -304400,2.914906,1.1453336,,,,,,,,,,,,,, -304500,3.00644,1.2702527,,,,,,,,,,,,,, -304600,3.4992282,1.1679868,,,,,,,,,,,,,, -304700,3.346217,1.108563,,,,,,,,,,,,,, -304800,3.2168584,1.3630526,,,,,,,,,,,,,, -304900,3.233654,1.3266251,,,,,,,,,,,,,, -305000,,,0.8884570002555847,0.4126504063606262,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,134058.0566318035,146424.70875573158,134058.0566318035,12330.843381643295,19.895825386047363,0.0 -305000,3.5172017,2.0215814,,,,,,,,,,,,,, -305100,3.1425989,1.1335543,,,,,,,,,,,,,, -305200,3.0725505,1.6649898,,,,,,,,,,,,,, -305300,3.1113055,1.4429675,,,,,,,,,,,,,, -305400,2.881336,1.0629362,,,,,,,,,,,,,, -305500,2.9743319,1.4025109,,,,,,,,,,,,,, -305600,3.1082187,1.2624705,,,,,,,,,,,,,, -305700,2.7813408,1.2980483,,,,,,,,,,,,,, -305800,3.0039852,1.522867,,,,,,,,,,,,,, -305900,3.9954455,3.2540252,,,,,,,,,,,,,, -305955,,,0.88783198595047,0.4241567254066467,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,134478.2640724182,146890.98871064186,134478.2640724182,12376.776423931122,19.98555564880371,0.0 -306000,3.2598042,1.2270527,,,,,,,,,,,,,, -306100,3.1292777,2.192526,,,,,,,,,,,,,, -306200,3.0938358,1.1837374,,,,,,,,,,,,,, -306300,3.2086163,1.1119906,,,,,,,,,,,,,, -306400,3.0540586,1.6595678,,,,,,,,,,,,,, -306500,3.3768685,1.0417657,,,,,,,,,,,,,, -306600,3.2753403,1.0573106,,,,,,,,,,,,,, -306700,3.3441687,1.2782993,,,,,,,,,,,,,, -306800,3.7422328,3.1891577,,,,,,,,,,,,,, -306900,3.055789,2.3392699,,,,,,,,,,,,,, -306912,,,0.8887695074081421,0.4157889485359192,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,134898.1319179535,147348.89574956894,134898.1319179535,12414.676441431046,20.075188636779785,0.0 -307000,5.057618,3.2019775,,,,,,,,,,,,,, -307100,3.0251007,1.109877,,,,,,,,,,,,,, -307200,3.1648412,1.1483376,,,,,,,,,,,,,, -307300,3.2078352,2.5950978,,,,,,,,,,,,,, -307400,2.8677719,2.2358909,,,,,,,,,,,,,, -307500,3.1433113,2.085655,,,,,,,,,,,,,, -307600,3.152725,2.0464196,,,,,,,,,,,,,, -307700,3.0647798,1.5756539,,,,,,,,,,,,,, -307800,3.7145653,3.2305646,,,,,,,,,,,,,, -307867,,,0.8900195360183716,0.4082540273666382,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,135318.12880396843,147808.09196567535,135318.12880396843,12453.731521844864,20.169724941253666,0.0 -307900,3.031784,1.8566021,,,,,,,,,,,,,, -308000,3.0261912,2.147963,,,,,,,,,,,,,, -308100,3.0292048,1.4453821,,,,,,,,,,,,,, -308200,2.8822856,2.0185633,,,,,,,,,,,,,, -308300,3.0480595,2.2493327,,,,,,,,,,,,,, -308400,3.0774035,1.6627369,,,,,,,,,,,,,, -308500,3.0906954,1.2563933,,,,,,,,,,,,,, -308600,2.8525896,1.1339495,,,,,,,,,,,,,, -308700,3.390009,2.910546,,,,,,,,,,,,,, -308800,3.3012629,1.1226708,,,,,,,,,,,,,, -308824,,,0.8877343535423279,0.4190030992031097,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,135738.11697626114,148264.62213468552,135738.11697626114,12490.151092767715,20.24302864074707,0.0 -308900,2.9831257,1.5156211,,,,,,,,,,,,,, -309000,2.9714687,1.201066,,,,,,,,,,,,,, -309100,2.8584514,1.8433493,,,,,,,,,,,,,, -309200,3.4082608,1.231785,,,,,,,,,,,,,, -309300,3.4407876,1.0966679,,,,,,,,,,,,,, -309400,3.0928066,2.5796833,,,,,,,,,,,,,, -309500,3.1268497,1.1691419,,,,,,,,,,,,,, -309600,3.1094403,2.5584526,,,,,,,,,,,,,, -309700,3.2425838,2.8016434,,,,,,,,,,,,,, -309777,,,0.8886523246765137,0.4138353765010834,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,136158.0286180973,148729.02848935127,136158.0286180973,12534.50862789154,20.331170082092285,0.0 -309800,3.075083,2.4964533,,,,,,,,,,,,,, -309900,3.4504702,1.2314374,,,,,,,,,,,,,, -310000,2.9909167,1.7602221,,,,,,,,,,,,,, -310100,3.7153888,1.1469097,,,,,,,,,,,,,, -310200,3.0587614,1.1396015,,,,,,,,,,,,,, -310300,3.049814,2.0367932,,,,,,,,,,,,,, -310400,3.0169225,1.782913,,,,,,,,,,,,,, -310500,3.1562831,1.6012572,,,,,,,,,,,,,, -310600,3.1820843,1.535384,,,,,,,,,,,,,, -310700,3.1160133,1.1958009,,,,,,,,,,,,,, -310731,,,0.8862695097923279,0.4201231002807617,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,136577.9388947487,149185.8946583271,136577.9388947487,12571.31841301918,20.42776656150818,0.0 -310800,3.2522838,1.7355018,,,,,,,,,,,,,, -310900,3.0538344,1.0582857,,,,,,,,,,,,,, -311000,3.4664497,1.6595713,,,,,,,,,,,,,, -311100,3.271731,2.6605968,,,,,,,,,,,,,, -311200,3.4336674,2.8952267,,,,,,,,,,,,,, -311300,3.025301,1.045505,,,,,,,,,,,,,, -311400,3.0337772,1.1315435,,,,,,,,,,,,,, -311500,3.1260395,1.2499093,,,,,,,,,,,,,, -311600,3.0524719,1.1364558,,,,,,,,,,,,,, -311684,,,0.8872851133346558,0.4167616665363312,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,136997.87422275543,149644.31943631172,136997.87422275543,12609.654133558271,20.53228735923767,0.0 -311700,3.0834901,2.084682,,,,,,,,,,,,,, -311800,3.1934483,1.1439886,,,,,,,,,,,,,, -311900,3.4902103,2.8667865,,,,,,,,,,,,,, -312000,2.8643878,1.1443274,,,,,,,,,,,,,, -312100,3.3082755,2.8911138,,,,,,,,,,,,,, -312200,2.9676425,1.3556644,,,,,,,,,,,,,, -312300,3.2328947,1.1398761,,,,,,,,,,,,,, -312400,3.1666477,1.1234283,,,,,,,,,,,,,, -312500,3.275775,1.5551599,,,,,,,,,,,,,, -312600,3.1782808,1.497782,,,,,,,,,,,,,, -312638,,,0.8905078172683716,0.4076966643333435,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,137418.08788132668,150112.3021736145,137418.08788132668,12657.285651922226,20.62012791633606,0.0 -312700,2.981424,1.1497698,,,,,,,,,,,,,, -312800,3.0339937,1.1738132,,,,,,,,,,,,,, -312900,3.3576066,1.1281567,,,,,,,,,,,,,, -313000,3.0917716,1.3387992,,,,,,,,,,,,,, -313100,2.900022,0.9721284,,,,,,,,,,,,,, -313200,3.0035717,1.1664234,,,,,,,,,,,,,, -313300,2.965118,2.1134765,,,,,,,,,,,,,, -313400,3.1086988,2.3185754,,,,,,,,,,,,,, -313500,3.5886736,2.3655179,,,,,,,,,,,,,, -313597,,,0.8882616758346558,0.4182638525962829,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,137838.1891798973,150572.7144780159,137838.1891798973,12697.472255945206,20.69444990158081,0.0 -313600,2.824909,1.4504112,,,,,,,,,,,,,, -313700,3.121523,1.3587118,,,,,,,,,,,,,, -313800,3.0528934,1.0596836,,,,,,,,,,,,,, -313900,3.811973,1.1709399,,,,,,,,,,,,,, -314000,2.784725,1.4566627,,,,,,,,,,,,,, -314100,3.0316882,1.690959,,,,,,,,,,,,,, -314200,3.3069415,1.2187717,,,,,,,,,,,,,, -314300,3.0176775,1.3696877,,,,,,,,,,,,,, -314400,3.420921,2.900742,,,,,,,,,,,,,, -314500,3.1949348,2.7512953,,,,,,,,,,,,,, -314553,,,0.8880859017372131,0.4185377359390259,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,138258.14862036705,151034.2818892002,138258.14862036705,12738.952292442322,20.77275776863098,0.0 -314600,2.9832509,1.1633646,,,,,,,,,,,,,, -314700,3.0738137,1.1413319,,,,,,,,,,,,,, -314800,3.2660706,1.5173802,,,,,,,,,,,,,, -314900,2.9017947,1.1890501,,,,,,,,,,,,,, -315000,3.092691,1.3107322,,,,,,,,,,,,,, -315100,3.2366333,1.2042627,,,,,,,,,,,,,, -315200,3.312086,1.1140084,,,,,,,,,,,,,, -315300,3.058299,1.1290526,,,,,,,,,,,,,, -315400,3.002796,1.9188592,,,,,,,,,,,,,, -315500,3.1788504,1.098157,,,,,,,,,,,,,, -315505,,,0.8875390291213989,0.4202651381492615,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,138678.3341538906,151496.09946393967,138678.3341538906,12780.446078777311,20.86136531829834,0.0 -315600,3.3663933,3.0949588,,,,,,,,,,,,,, -315700,2.9512765,1.137265,,,,,,,,,,,,,, -315800,2.9570851,1.0824934,,,,,,,,,,,,,, -315900,2.953526,1.4611676,,,,,,,,,,,,,, -316000,3.3106005,2.8258827,,,,,,,,,,,,,, -316100,3.075749,1.7896711,,,,,,,,,,,,,, -316200,3.3531775,1.1336657,,,,,,,,,,,,,, -316300,3.1755593,2.662554,,,,,,,,,,,,,, -316400,3.008231,2.5256908,,,,,,,,,,,,,, -316454,,,0.88671875,0.4202517867088318,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,139098.41466093063,151964.981477499,139098.41466093063,12829.110315561296,20.9499146938324,0.0 -316500,3.1196718,1.1452281,,,,,,,,,,,,,, -316600,3.6995552,2.8506167,,,,,,,,,,,,,, -316700,3.2855692,2.5962553,,,,,,,,,,,,,, -316800,3.1590054,1.8290708,,,,,,,,,,,,,, -316900,2.9882176,1.7036277,,,,,,,,,,,,,, -317000,3.9521835,3.1344566,,,,,,,,,,,,,, -317100,3.12841,1.1586635,,,,,,,,,,,,,, -317200,3.426966,2.8520198,,,,,,,,,,,,,, -317300,3.025859,1.1445395,,,,,,,,,,,,,, -317400,2.9394825,1.4666662,,,,,,,,,,,,,, -317411,,,0.8899804353713989,0.4071928262710571,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,139518.60390734673,152428.16596484184,139518.60390734673,12871.982561588287,21.02366185188293,0.0 -317500,2.9326806,1.4479394,,,,,,,,,,,,,, -317600,3.197695,1.6376252,,,,,,,,,,,,,, -317700,3.7717373,3.2627218,,,,,,,,,,,,,, -317800,2.9080842,1.082684,,,,,,,,,,,,,, -317900,2.95038,1.3773396,,,,,,,,,,,,,, -318000,2.8754637,2.3094497,,,,,,,,,,,,,, -318100,2.8304462,2.244872,,,,,,,,,,,,,, -318200,3.03691,1.3013389,,,,,,,,,,,,,, -318300,2.9007661,1.3179512,,,,,,,,,,,,,, -318367,,,0.8865624666213989,0.4238818287849426,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,139938.67020440102,152887.1681947708,139938.67020440102,12910.795140266418,21.09812617301941,0.0 -318400,3.0066364,2.2760344,,,,,,,,,,,,,, -318500,2.8179495,1.0492442,,,,,,,,,,,,,, -318600,3.2263806,1.1123905,,,,,,,,,,,,,, -318700,3.00794,2.3438182,,,,,,,,,,,,,, -318800,3.5233095,3.1974926,,,,,,,,,,,,,, -318900,2.8165078,1.0811465,,,,,,,,,,,,,, -319000,3.2547636,1.1681669,,,,,,,,,,,,,, -319100,4.076026,3.1300278,,,,,,,,,,,,,, -319200,3.1016223,1.7592944,,,,,,,,,,,,,, -319300,3.1871881,1.1048489,,,,,,,,,,,,,, -319304,,,0.8884961009025574,0.4180271029472351,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,140358.56078743935,153347.8892352581,140358.56078743935,12951.488573789597,21.18717908859253,0.0 -319400,3.134469,1.3578238,,,,,,,,,,,,,, -319500,3.542825,1.1551625,,,,,,,,,,,,,, -319600,3.229332,1.7728064,,,,,,,,,,,,,, -319700,2.9305158,1.2333817,,,,,,,,,,,,,, -319800,3.149128,1.2436764,,,,,,,,,,,,,, -319900,3.1432178,1.610369,,,,,,,,,,,,,, -320000,3.1749141,1.0566667,,,,,,,,,,,,,, -320100,3.6533437,1.4844385,,,,,,,,,,,,,, -320200,3.9882672,3.2038493,,,,,,,,,,,,,, -320258,,,0.8869335651397705,0.419131875038147,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,140778.834120512,153806.9401268959,140778.834120512,12990.127300024033,21.276575088500977,0.0 -320300,3.6268811,1.4445751,,,,,,,,,,,,,, -320400,3.029677,1.032592,,,,,,,,,,,,,, -320500,3.0961561,1.1042951,,,,,,,,,,,,,, -320600,3.0904145,1.1621543,,,,,,,,,,,,,, -320700,3.2147834,1.1505026,,,,,,,,,,,,,, -320800,3.1036012,1.0599082,,,,,,,,,,,,,, -320900,2.8954723,1.0254701,,,,,,,,,,,,,, -321000,3.0195873,1.7381704,,,,,,,,,,,,,, -321100,3.243417,1.5915229,,,,,,,,,,,,,, -321200,3.0346315,1.3150661,,,,,,,,,,,,,, -321213,,,0.8894921541213989,0.4086560010910034,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,141198.73992681503,154264.70492196083,141198.73992681503,13027.79325413704,21.4208242893219,0.0 -321300,3.929727,3.0590215,,,,,,,,,,,,,, -321400,3.0846744,2.794269,,,,,,,,,,,,,, -321500,3.1466143,1.180809,,,,,,,,,,,,,, -321600,3.0558937,1.2396014,,,,,,,,,,,,,, -321700,2.9382222,1.1599356,,,,,,,,,,,,,, -321800,3.428603,1.1704749,,,,,,,,,,,,,, -321900,2.9707253,2.376949,,,,,,,,,,,,,, -322000,3.1355903,1.1504542,,,,,,,,,,,,,, -322100,3.1731842,2.0497792,,,,,,,,,,,,,, -322164,,,0.8868749737739563,0.4255317449569702,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,141619.0731432438,154724.26458621025,141619.0731432438,13066.877766132357,21.51161813735962,0.0 -322200,3.335821,1.1383291,,,,,,,,,,,,,, -322300,3.1467369,2.956955,,,,,,,,,,,,,, -322400,3.4648561,2.7851233,,,,,,,,,,,,,, -322500,3.0485864,1.341851,,,,,,,,,,,,,, -322600,3.56675,1.7148539,,,,,,,,,,,,,, -322700,3.2955496,1.1354158,,,,,,,,,,,,,, -322800,3.0731018,2.164762,,,,,,,,,,,,,, -322900,3.1020164,1.1211843,,,,,,,,,,,,,, -323000,3.266142,1.0743761,,,,,,,,,,,,,, -323100,3.5626569,2.9028168,,,,,,,,,,,,,, -323119,,,0.8875390291213989,0.4166724681854248,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,142039.1426100731,155188.78133320808,142039.1426100731,13111.180988073347,21.60645341873169,0.0 -323200,2.8898182,1.1191804,,,,,,,,,,,,,, -323300,2.992083,1.0759801,,,,,,,,,,,,,, -323400,3.5243082,1.8067104,,,,,,,,,,,,,, -323500,3.102482,1.1897308,,,,,,,,,,,,,, -323600,3.4765515,2.120471,,,,,,,,,,,,,, -323700,3.256577,1.5707242,,,,,,,,,,,,,, -323800,3.4081135,1.0805892,,,,,,,,,,,,,, -323900,3.035269,1.1162007,,,,,,,,,,,,,, -324000,2.9321551,1.2197652,,,,,,,,,,,,,, -324077,,,0.88783198595047,0.4228438138961792,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,142459.23724484444,155657.20071411133,142459.23724484444,13159.376702070236,21.68626976013184,0.0 -324100,3.4996002,1.1830177,,,,,,,,,,,,,, -324200,3.123301,1.1080014,,,,,,,,,,,,,, -324300,3.1847546,2.08935,,,,,,,,,,,,,, -324400,4.4629765,3.279356,,,,,,,,,,,,,, -324500,3.0270133,2.022356,,,,,,,,,,,,,, -324600,2.9506412,1.1501167,,,,,,,,,,,,,, -324700,2.9195497,1.1776173,,,,,,,,,,,,,, -324800,3.308731,1.3774842,,,,,,,,,,,,,, -324900,3.362905,1.0979083,,,,,,,,,,,,,, -325000,3.0244443,1.1588595,,,,,,,,,,,,,, -325034,,,0.8855078220367432,0.4231415390968323,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,142879.33050012589,156126.01038050652,142879.33050012589,13207.968095541,21.76215624809265,0.0 -325100,3.035269,1.1532254,,,,,,,,,,,,,, -325200,2.964609,1.0342468,,,,,,,,,,,,,, -325300,2.8938377,1.665707,,,,,,,,,,,,,, -325400,3.1054475,2.518659,,,,,,,,,,,,,, -325500,3.2505698,1.2296913,,,,,,,,,,,,,, -325600,3.4655826,2.0444944,,,,,,,,,,,,,, -325700,3.271261,1.1745542,,,,,,,,,,,,,, -325800,3.159129,2.1781235,,,,,,,,,,,,,, -325900,3.0613246,1.3366215,,,,,,,,,,,,,, -325990,,,0.8892773389816284,0.4138164818286896,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,143299.4483640194,156585.47791337967,143299.4483640194,13247.191797733309,21.83880043029785,0.0 -326000,2.9781113,1.2946196,,,,,,,,,,,,,, -326100,2.9267995,1.1410024,,,,,,,,,,,,,, -326200,3.144209,1.1205204,,,,,,,,,,,,,, -326300,3.0973923,1.3734151,,,,,,,,,,,,,, -326400,3.045773,1.3243097,,,,,,,,,,,,,, -326500,3.2562897,2.3641033,,,,,,,,,,,,,, -326600,3.1590488,1.1012335,,,,,,,,,,,,,, -326700,2.8791087,1.6312054,,,,,,,,,,,,,, -326800,3.4111223,1.1734039,,,,,,,,,,,,,, -326900,4.085909,3.2981274,,,,,,,,,,,,,, -326935,,,0.8871874809265137,0.4137941896915436,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,143719.5908768177,157055.27760457993,143719.5908768177,13296.705656290054,21.93300485610962,0.0 -327000,3.876712,3.2044592,,,,,,,,,,,,,, -327100,3.1856043,1.177466,,,,,,,,,,,,,, -327200,2.8038268,1.4615363,,,,,,,,,,,,,, -327300,3.082991,1.110636,,,,,,,,,,,,,, -327400,3.2911804,1.4778872,,,,,,,,,,,,,, -327500,3.5588038,3.241526,,,,,,,,,,,,,, -327600,2.8286743,1.2063305,,,,,,,,,,,,,, -327700,3.1045756,1.4338366,,,,,,,,,,,,,, -327800,4.149275,2.2868943,,,,,,,,,,,,,, -327889,,,0.8894726634025574,0.4152855277061462,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,144139.76654458046,157514.30312132835,144139.76654458046,13335.431085586548,22.008670568466187,0.0 -327900,3.0320055,1.714678,,,,,,,,,,,,,, -328000,3.1911151,2.8267088,,,,,,,,,,,,,, -328100,3.0133455,1.2864932,,,,,,,,,,,,,, -328200,3.116041,1.1807201,,,,,,,,,,,,,, -328300,3.0906062,1.2603081,,,,,,,,,,,,,, -328400,3.312448,2.8917718,,,,,,,,,,,,,, -328500,3.4073958,2.5839667,,,,,,,,,,,,,, -328600,2.7887893,1.2175227,,,,,,,,,,,,,, -328700,3.1435623,1.087352,,,,,,,,,,,,,, -328800,3.0639799,1.4792917,,,,,,,,,,,,,, -328841,,,0.8877539038658142,0.416228175163269,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,144560.12257027626,157975.0710504055,144560.12257027626,13375.67746925354,22.125629425048828,0.0 -328900,3.3624594,1.1955153,,,,,,,,,,,,,, -329000,3.2628317,1.1070263,,,,,,,,,,,,,, -329100,3.8722253,3.3695812,,,,,,,,,,,,,, -329200,2.9493334,1.0732213,,,,,,,,,,,,,, -329300,3.1891072,1.0559647,,,,,,,,,,,,,, -329400,3.4376473,1.132561,,,,,,,,,,,,,, -329500,5.438139,1.1649903,,,,,,,,,,,,,, -329600,3.6575327,1.6330441,,,,,,,,,,,,,, -329700,2.930596,1.2467996,,,,,,,,,,,,,, -329785,,,0.8871679306030273,0.4185080528259277,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,144980.0615439415,158439.23374319077,144980.0615439415,13419.763077259064,22.216106176376343,0.0 -329800,3.6016276,3.1275966,,,,,,,,,,,,,, -329900,2.9936492,1.5783837,,,,,,,,,,,,,, -330000,3.0860417,1.1453868,,,,,,,,,,,,,, -330100,3.486176,1.265073,,,,,,,,,,,,,, -330200,3.2380915,1.1852334,,,,,,,,,,,,,, -330300,3.3249164,1.1110396,,,,,,,,,,,,,, -330400,3.2629519,2.4154067,,,,,,,,,,,,,, -330500,2.9735367,1.0126367,,,,,,,,,,,,,, -330600,3.7825127,3.1905513,,,,,,,,,,,,,, -330700,3.6850758,1.1644232,,,,,,,,,,,,,, -330743,,,0.8885937333106995,0.4203934073448181,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,145400.25120687485,158898.16319799423,145400.25120687485,13458.372012853622,22.297227382659912,0.0 -330800,3.169573,1.1176058,,,,,,,,,,,,,, -330900,3.4698172,1.064388,,,,,,,,,,,,,, -331000,3.0692987,1.0278949,,,,,,,,,,,,,, -331100,3.139553,1.1470588,,,,,,,,,,,,,, -331200,2.918867,1.8256487,,,,,,,,,,,,,, -331300,3.0478282,1.1422299,,,,,,,,,,,,,, -331400,3.1408906,2.063177,,,,,,,,,,,,,, -331500,3.0196457,1.1138399,,,,,,,,,,,,,, -331600,3.1179442,1.7827193,,,,,,,,,,,,,, -331697,,,0.8895702958106995,0.4124229848384857,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,145820.2736287117,159364.24648237228,145820.2736287117,13504.289674520493,22.39159369468689,0.0 -331700,3.173089,1.1604478,,,,,,,,,,,,,, -331800,3.1320636,2.4240808,,,,,,,,,,,,,, -331900,2.9541445,1.993849,,,,,,,,,,,,,, -332000,3.2517905,1.1338714,,,,,,,,,,,,,, -332100,3.7189317,3.2999337,,,,,,,,,,,,,, -332200,3.2597773,1.3734478,,,,,,,,,,,,,, -332300,2.9426386,1.5741026,,,,,,,,,,,,,, -332400,3.1067648,1.1757381,,,,,,,,,,,,,, -332500,3.0699759,1.2901423,,,,,,,,,,,,,, -332600,3.282382,1.1986353,,,,,,,,,,,,,, -332651,,,0.8883398175239563,0.4132517576217651,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,146240.3351507187,159831.4127779007,146240.3351507187,13551.26194858551,22.474181175231934,0.0 -332700,3.226206,2.1004317,,,,,,,,,,,,,, -332800,3.0850704,1.0587299,,,,,,,,,,,,,, -332900,2.9241943,1.8425124,,,,,,,,,,,,,, -333000,2.991316,1.1038058,,,,,,,,,,,,,, -333100,3.6719174,3.2562633,,,,,,,,,,,,,, -333200,3.2933404,1.1875654,,,,,,,,,,,,,, -333300,4.252702,3.1945214,,,,,,,,,,,,,, -333400,3.1466959,1.1368707,,,,,,,,,,,,,, -333500,3.3804386,2.6080654,,,,,,,,,,,,,, -333600,3.0092554,2.3594475,,,,,,,,,,,,,, -333609,,,0.8883788585662842,0.4145821630954742,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,146660.5423769951,160298.3706278801,146660.5423769951,13597.88392996788,22.552911043167114,0.0 -333700,3.2878637,1.1414819,,,,,,,,,,,,,, -333800,3.7706919,1.1948878,,,,,,,,,,,,,, -333900,3.3665695,1.1875012,,,,,,,,,,,,,, -334000,3.37218,1.3610532,,,,,,,,,,,,,, -334100,3.6527076,3.2877188,,,,,,,,,,,,,, -334200,3.2248218,2.522022,,,,,,,,,,,,,, -334300,4.0830054,1.1848868,,,,,,,,,,,,,, -334400,3.0838845,1.1153214,,,,,,,,,,,,,, -334500,3.0338163,1.205377,,,,,,,,,,,,,, -334567,,,0.8887499570846558,0.4145262539386749,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,147080.81252145767,160766.93076348305,147080.81252145767,13646.050462961197,22.62759017944336,0.0 -334600,3.1703188,1.0922339,,,,,,,,,,,,,, -334700,2.8398392,1.316335,,,,,,,,,,,,,, -334800,2.9574773,1.5478659,,,,,,,,,,,,,, -334900,2.9076166,1.1549988,,,,,,,,,,,,,, -335000,3.2670345,1.3514714,,,,,,,,,,,,,, -335100,3.2275934,1.1693449,,,,,,,,,,,,,, -335200,3.06753,1.0421788,,,,,,,,,,,,,, -335300,3.3234866,1.1211729,,,,,,,,,,,,,, -335400,3.3217604,1.1617655,,,,,,,,,,,,,, -335500,2.9206684,1.0204582,,,,,,,,,,,,,, -335521,,,0.8877343535423279,0.4150673449039459,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,147500.88044071198,161222.46529269218,147500.88044071198,13681.3886988163,22.70599865913391,0.0 -335600,2.9821744,1.0427498,,,,,,,,,,,,,, -335700,3.177664,1.5917529,,,,,,,,,,,,,, -335800,3.0884833,1.1276829,,,,,,,,,,,,,, -335900,3.0595071,1.1738199,,,,,,,,,,,,,, -336000,3.3268409,1.1822697,,,,,,,,,,,,,, -336100,3.3282275,2.1941662,,,,,,,,,,,,,, -336200,4.0882387,3.1732402,,,,,,,,,,,,,, -336300,3.4596474,3.0615537,,,,,,,,,,,,,, -336400,3.0862718,1.1931366,,,,,,,,,,,,,, -336456,,,0.8890820145606995,0.4130926430225372,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,147921.03347206116,161684.40320611,147921.03347206116,13723.035228729248,22.79630279541016,0.0 -336500,3.1298215,1.1933496,,,,,,,,,,,,,, -336600,3.7297137,2.6175082,,,,,,,,,,,,,, -336700,3.0960503,1.4664233,,,,,,,,,,,,,, -336800,3.264507,1.1285751,,,,,,,,,,,,,, -336900,3.659574,3.2666495,,,,,,,,,,,,,, -337000,3.9430735,3.1677697,,,,,,,,,,,,,, -337100,3.2634997,2.7899287,,,,,,,,,,,,,, -337200,3.0243604,2.4182866,,,,,,,,,,,,,, -337300,3.9369614,3.2934113,,,,,,,,,,,,,, -337400,3.5462265,3.1975377,,,,,,,,,,,,,, -337407,,,0.88671875,0.4238328039646148,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,148341.1167113781,162146.08368301392,148341.1167113781,13764.488429307938,22.89143419265747,0.0 -337500,2.8843217,1.0482476,,,,,,,,,,,,,, -337600,3.1686766,1.1051611,,,,,,,,,,,,,, -337700,3.183747,2.4808207,,,,,,,,,,,,,, -337800,3.2997608,1.1711233,,,,,,,,,,,,,, -337900,3.1171257,1.1888777,,,,,,,,,,,,,, -338000,3.333842,2.655683,,,,,,,,,,,,,, -338100,3.2823858,1.0997611,,,,,,,,,,,,,, -338200,3.2601457,1.2245682,,,,,,,,,,,,,, -338300,3.2276838,1.1822585,,,,,,,,,,,,,, -338363,,,0.88832026720047,0.4145942628383636,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,148761.1108095646,162606.52558994293,148761.1108095646,13804.790137529371,22.987659454345703,0.0 -338400,3.1747482,1.2238647,,,,,,,,,,,,,, -338500,3.3553,1.1342323,,,,,,,,,,,,,, -338600,3.2017446,1.163455,,,,,,,,,,,,,, -338700,3.3400185,1.3712351,,,,,,,,,,,,,, -338800,4.100817,2.9262314,,,,,,,,,,,,,, -338900,3.225479,1.1241009,,,,,,,,,,,,,, -339000,3.0811179,2.2459104,,,,,,,,,,,,,, -339100,2.9151883,1.1262299,,,,,,,,,,,,,, -339200,3.0603201,1.0954759,,,,,,,,,,,,,, -339300,3.2822783,2.8844748,,,,,,,,,,,,,, -339317,,,0.8875976204872131,0.4212967455387115,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,149180.98473072052,163068.02965641022,149180.98473072052,13846.27517938614,23.08351254463196,0.0 -339400,3.2371185,1.8806479,,,,,,,,,,,,,, -339500,3.2034032,1.1744373,,,,,,,,,,,,,, -339600,3.0696647,2.6746135,,,,,,,,,,,,,, -339700,2.9952407,2.1420352,,,,,,,,,,,,,, -339800,2.9866383,1.5855095,,,,,,,,,,,,,, -339900,3.1074467,1.0613185,,,,,,,,,,,,,, -340000,3.3181624,2.7873745,,,,,,,,,,,,,, -340100,2.883197,1.0243224,,,,,,,,,,,,,, -340200,3.6459801,3.2061982,,,,,,,,,,,,,, -340272,,,0.8876757621765137,0.4158148467540741,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,149600.92562270164,163525.91829109192,149600.92562270164,13884.049741983414,23.206378698349,0.0 -340300,3.6584666,3.0789194,,,,,,,,,,,,,, -340400,4.166898,3.2578588,,,,,,,,,,,,,, -340500,3.2102697,1.1535833,,,,,,,,,,,,,, -340600,3.098297,2.4475174,,,,,,,,,,,,,, -340700,3.7648406,2.791313,,,,,,,,,,,,,, -340800,3.0661948,1.0664704,,,,,,,,,,,,,, -340900,3.0533776,1.1567976,,,,,,,,,,,,,, -341000,3.0247664,1.3776722,,,,,,,,,,,,,, -341100,3.1597307,1.0778364,,,,,,,,,,,,,, -341200,3.0795903,1.6606998,,,,,,,,,,,,,, -341227,,,0.8883007764816284,0.4141983985900879,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,150021.06316399574,163988.32326960564,150021.06316399574,13926.174630880356,23.29926109313965,0.0 -341300,3.1508398,1.2200588,,,,,,,,,,,,,, -341400,3.1328125,1.1382778,,,,,,,,,,,,,, -341500,3.062565,0.9600744,,,,,,,,,,,,,, -341600,3.0269709,1.5836637,,,,,,,,,,,,,, -341700,2.938997,2.4385624,,,,,,,,,,,,,, -341800,3.7753837,2.8811758,,,,,,,,,,,,,, -341900,3.2653418,2.2896273,,,,,,,,,,,,,, -342000,3.817469,3.2338507,,,,,,,,,,,,,, -342100,3.5650117,3.036711,,,,,,,,,,,,,, -342183,,,0.8895702958106995,0.41547891497612,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,150441.1364018917,164449.25674581528,150441.1364018917,13966.889407157898,23.39518094062805,0.0 -342200,3.0137236,1.2780714,,,,,,,,,,,,,, -342300,3.3608122,1.1706885,,,,,,,,,,,,,, -342400,3.219438,1.1378638,,,,,,,,,,,,,, -342500,3.2903864,1.7723138,,,,,,,,,,,,,, -342600,3.123654,1.146004,,,,,,,,,,,,,, -342700,3.2183979,1.0917861,,,,,,,,,,,,,, -342800,3.1256902,2.5497398,,,,,,,,,,,,,, -342900,2.9075515,0.9871837,,,,,,,,,,,,,, -343000,3.3279793,1.1772139,,,,,,,,,,,,,, -343100,4.3747225,3.28331,,,,,,,,,,,,,, -343140,,,0.8856054544448853,0.4266979992389679,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,150861.22448587418,164915.37854075432,150861.22448587418,14012.772404670715,23.496416807174683,0.0 -343200,3.1010008,2.6622968,,,,,,,,,,,,,, -343300,2.805667,1.4021533,,,,,,,,,,,,,, -343400,2.931931,1.1431416,,,,,,,,,,,,,, -343500,3.0836506,1.1245564,,,,,,,,,,,,,, -343600,2.9739354,1.0562254,,,,,,,,,,,,,, -343700,3.3847518,2.5917366,,,,,,,,,,,,,, -343800,3.0923574,1.1352333,,,,,,,,,,,,,, -343900,3.2122476,2.65684,,,,,,,,,,,,,, -344000,3.2960706,1.2082434,,,,,,,,,,,,,, -344097,,,0.8867382407188416,0.419879138469696,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,151281.30113005638,165373.97410154343,151281.30113005638,14051.161545276642,23.57652187347412,0.0 -344100,4.1965656,1.1796398,,,,,,,,,,,,,, -344200,3.2028546,1.1820018,,,,,,,,,,,,,, -344300,2.7333643,1.2557501,,,,,,,,,,,,,, -344400,3.8793685,3.1156592,,,,,,,,,,,,,, -344500,3.010984,1.6880268,,,,,,,,,,,,,, -344600,3.4125657,2.9747639,,,,,,,,,,,,,, -344700,2.9605498,1.128482,,,,,,,,,,,,,, -344800,3.3485045,1.3113321,,,,,,,,,,,,,, -344900,3.3812428,2.9232697,,,,,,,,,,,,,, -345000,3.201311,1.1104301,,,,,,,,,,,,,, -345051,,,0.8881250023841858,0.4131699800491333,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,151701.54492998123,165838.41160845757,151701.54492998123,14095.20967411995,23.672836303710938,0.0 -345100,3.273873,1.1590983,,,,,,,,,,,,,, -345200,3.0300987,1.3589877,,,,,,,,,,,,,, -345300,2.9804876,2.3648226,,,,,,,,,,,,,, -345400,2.9546337,2.1011925,,,,,,,,,,,,,, -345500,4.4091883,3.189938,,,,,,,,,,,,,, -345600,3.3701637,1.1486328,,,,,,,,,,,,,, -345700,3.1575787,1.1231347,,,,,,,,,,,,,, -345800,4.0893598,3.1324587,,,,,,,,,,,,,, -345900,3.0685246,1.1141139,,,,,,,,,,,,,, -346000,3.758314,2.7145655,,,,,,,,,,,,,, -346001,,,0.8878905773162842,0.4152270853519439,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,152121.5381603241,166302.2970738411,152121.5381603241,14138.956525087357,23.76843428611756,0.0 -346100,3.4513166,1.6578071,,,,,,,,,,,,,, -346200,3.2663443,2.8021278,,,,,,,,,,,,,, -346300,3.2261252,1.4708799,,,,,,,,,,,,,, -346400,3.1942925,2.5685658,,,,,,,,,,,,,, -346500,3.2657342,1.2229638,,,,,,,,,,,,,, -346600,3.2578561,1.0516523,,,,,,,,,,,,,, -346700,3.0075872,1.8365587,,,,,,,,,,,,,, -346800,3.0622976,1.1723483,,,,,,,,,,,,,, -346900,3.0599804,1.5263004,,,,,,,,,,,,,, -346956,,,0.8885155916213989,0.4200997948646545,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,152541.55381274223,166757.76275634766,152541.55381274223,14174.279473781586,23.846048831939697,0.0 -347000,3.0205307,1.7290564,,,,,,,,,,,,,, -347100,3.5776577,2.4244714,,,,,,,,,,,,,, -347200,4.122407,3.255645,,,,,,,,,,,,,, -347300,2.897392,1.6817837,,,,,,,,,,,,,, -347400,3.808056,3.2201672,,,,,,,,,,,,,, -347500,3.672434,1.1925815,,,,,,,,,,,,,, -347600,3.030439,1.8880606,,,,,,,,,,,,,, -347700,3.1513119,1.0645174,,,,,,,,,,,,,, -347800,2.9314044,1.0469266,,,,,,,,,,,,,, -347900,3.5847394,1.1878493,,,,,,,,,,,,,, -347909,,,0.8882226347923279,0.4179379642009735,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,152961.44713258743,167222.80291485786,152961.44713258743,14219.28055357933,23.9422287940979,0.0 -348000,3.390504,2.9320047,,,,,,,,,,,,,, -348100,3.4585116,2.6048794,,,,,,,,,,,,,, -348200,3.4551322,1.1506836,,,,,,,,,,,,,, -348300,3.1091661,1.4134277,,,,,,,,,,,,,, -348400,3.0211792,1.2876618,,,,,,,,,,,,,, -348500,2.9517334,1.3112961,,,,,,,,,,,,,, -348600,3.3240185,1.1803432,,,,,,,,,,,,,, -348700,3.2364578,1.0773073,,,,,,,,,,,,,, -348800,2.999855,1.3804061,,,,,,,,,,,,,, -348863,,,0.8864648342132568,0.4202927052974701,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,153381.35877251625,167684.3264117241,153381.35877251625,14260.740604877472,24.04466724395752,0.0 -348900,3.8372028,3.136021,,,,,,,,,,,,,, -349000,2.9992132,1.1835842,,,,,,,,,,,,,, -349100,3.4260368,1.1379638,,,,,,,,,,,,,, -349200,3.25564,1.1456921,,,,,,,,,,,,,, -349300,3.7032645,3.006895,,,,,,,,,,,,,, -349400,2.9322586,1.3721375,,,,,,,,,,,,,, -349500,3.0510926,1.1792476,,,,,,,,,,,,,, -349600,3.0407705,1.0929095,,,,,,,,,,,,,, -349700,3.364805,1.2345046,,,,,,,,,,,,,, -349800,3.2119012,1.1463054,,,,,,,,,,,,,, -349818,,,0.8893945217132568,0.4127410650253296,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,153801.53926444054,168144.88361668587,153801.53926444054,14300.969812393188,24.14210081100464,0.0 -349900,3.2728083,1.0509267,,,,,,,,,,,,,, -350000,3.209803,1.6400628,,,,,,,,,,,,,, -350100,3.5843394,3.0206504,,,,,,,,,,,,,, -350200,3.351555,2.9072745,,,,,,,,,,,,,, -350300,3.2839491,2.4456463,,,,,,,,,,,,,, -350400,3.4717908,3.0356956,,,,,,,,,,,,,, -350500,3.5932703,1.2552145,,,,,,,,,,,,,, -350600,3.1135874,1.156993,,,,,,,,,,,,,, -350700,3.097653,2.3304124,,,,,,,,,,,,,, -350775,,,0.8864843845367432,0.4197324514389038,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,154221.6833667755,168604.9321911335,154221.6833667755,14340.72930598259,24.23790383338928,0.0 -350800,3.2286353,1.1732407,,,,,,,,,,,,,, -350900,3.2367578,1.1608918,,,,,,,,,,,,,, -351000,2.887017,1.0423225,,,,,,,,,,,,,, -351100,3.1092768,2.270264,,,,,,,,,,,,,, -351200,2.9378095,1.5211697,,,,,,,,,,,,,, -351300,3.3408928,1.135452,,,,,,,,,,,,,, -351400,3.1128714,1.1916648,,,,,,,,,,,,,, -351500,3.1259284,2.1281974,,,,,,,,,,,,,, -351600,3.70506,2.8362803,,,,,,,,,,,,,, -351700,3.499925,1.118511,,,,,,,,,,,,,, -351730,,,0.8885937333106995,0.4155570268630981,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,154641.72449684143,169074.7146832943,154641.72449684143,14390.32236981392,24.33543372154236,0.0 -351800,3.8455193,1.1662465,,,,,,,,,,,,,, -351900,3.6996436,3.1049747,,,,,,,,,,,,,, -352000,3.253344,1.1895853,,,,,,,,,,,,,, -352100,4.6674743,3.2002833,,,,,,,,,,,,,, -352200,3.014785,1.0733078,,,,,,,,,,,,,, -352300,3.1051023,2.678984,,,,,,,,,,,,,, -352400,3.1616342,1.3216033,,,,,,,,,,,,,, -352500,3.1910274,2.160964,,,,,,,,,,,,,, -352600,3.163026,1.0355538,,,,,,,,,,,,,, -352688,,,0.8866210579872131,0.4199872612953186,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,155061.90138220787,169536.02051210403,155061.90138220787,14431.321568489077,24.41545557975769,0.0 -352700,2.807948,1.6818925,,,,,,,,,,,,,, -352800,4.1947336,2.9095023,,,,,,,,,,,,,, -352900,3.078714,2.6440074,,,,,,,,,,,,,, -353000,2.8674672,1.4539485,,,,,,,,,,,,,, -353100,2.9704585,1.1216202,,,,,,,,,,,,,, -353200,3.2969358,1.1259184,,,,,,,,,,,,,, -353300,3.6385329,2.9441638,,,,,,,,,,,,,, -353400,3.1414883,1.114174,,,,,,,,,,,,,, -353500,3.0053678,1.1218662,,,,,,,,,,,,,, -353600,3.133222,1.5252503,,,,,,,,,,,,,, -353645,,,0.8891991972923279,0.4173726439476013,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,155481.78094792366,170000.51396226883,155481.78094792366,14475.808803319933,24.493077278137207,0.0 -353700,2.9590325,1.0889136,,,,,,,,,,,,,, -353800,3.4663444,2.9167259,,,,,,,,,,,,,, -353900,3.0413547,1.1019249,,,,,,,,,,,,,, -354000,3.056038,1.2783542,,,,,,,,,,,,,, -354100,3.828179,3.1076326,,,,,,,,,,,,,, -354200,3.1828148,2.6512175,,,,,,,,,,,,,, -354300,3.1510572,1.1506658,,,,,,,,,,,,,, -354400,3.577931,2.4439342,,,,,,,,,,,,,, -354500,3.1308556,2.2579246,,,,,,,,,,,,,, -354599,,,0.8884375095367432,0.4171162843704223,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,155901.67514777184,170458.52275180817,155901.67514777184,14513.76334285736,24.60359287261963,0.0 -354600,3.1213124,1.4757848,,,,,,,,,,,,,, -354700,3.3378286,1.2538183,,,,,,,,,,,,,, -354800,3.0430028,2.7003605,,,,,,,,,,,,,, -354900,3.0124843,1.0130087,,,,,,,,,,,,,, -355000,3.1785753,1.1826816,,,,,,,,,,,,,, -355100,3.1043537,1.2012806,,,,,,,,,,,,,, -355200,3.0505912,1.4611787,,,,,,,,,,,,,, -355300,3.1369379,2.4497206,,,,,,,,,,,,,, -355400,3.039351,1.2108495,,,,,,,,,,,,,, -355500,3.7515314,3.2407305,,,,,,,,,,,,,, -355550,,,0.890625,0.4058804512023926,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,156321.83713293076,170926.6277961731,156321.83713293076,14561.556003570557,24.704198122024536,0.0 -355600,3.6255066,2.729381,,,,,,,,,,,,,, -355700,2.9949944,2.1775177,,,,,,,,,,,,,, -355800,3.3939128,1.1289909,,,,,,,,,,,,,, -355900,3.2416263,2.1567214,,,,,,,,,,,,,, -356000,3.165164,1.027978,,,,,,,,,,,,,, -356100,2.9991195,1.4520329,,,,,,,,,,,,,, -356200,3.5374732,2.7832592,,,,,,,,,,,,,, -356300,2.9869838,1.4020662,,,,,,,,,,,,,, -356400,3.85929,3.2701652,,,,,,,,,,,,,, -356500,3.3963232,1.5166572,,,,,,,,,,,,,, -356506,,,0.8883984088897705,0.416498452425003,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,156742.0211148262,171383.98013997078,156742.0211148262,14598.595462560654,24.78358292579651,0.0 -356600,3.860192,3.2026243,,,,,,,,,,,,,, -356700,3.2464652,1.2890599,,,,,,,,,,,,,, -356800,2.9981477,1.1847268,,,,,,,,,,,,,, -356900,3.1810358,2.3238,,,,,,,,,,,,,, -357000,2.805712,2.0522037,,,,,,,,,,,,,, -357100,3.2562692,1.5553786,,,,,,,,,,,,,, -357200,3.1072187,1.2498161,,,,,,,,,,,,,, -357300,3.4262311,1.1679006,,,,,,,,,,,,,, -357400,3.0962753,1.1788726,,,,,,,,,,,,,, -357459,,,0.88818359375,0.4171052873134613,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,157162.01435351372,171843.23746800423,157162.01435351372,14637.712956905363,24.881208181381226,0.0 -357500,3.0709689,1.2783339,,,,,,,,,,,,,, -357600,3.6412094,3.1375809,,,,,,,,,,,,,, -357700,3.1249707,1.0877314,,,,,,,,,,,,,, -357800,3.1289086,1.7946934,,,,,,,,,,,,,, -357900,3.370766,1.1478828,,,,,,,,,,,,,, -358000,3.3159206,1.2495263,,,,,,,,,,,,,, -358100,2.7857378,1.1645029,,,,,,,,,,,,,, -358200,2.8311477,1.5125829,,,,,,,,,,,,,, -358300,3.4899068,1.3267071,,,,,,,,,,,,,, -358400,3.266693,1.1663746,,,,,,,,,,,,,, -358413,,,0.8861718773841858,0.4189674258232116,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,157582.07474136353,172313.02126026154,157582.07474136353,14687.286793231964,24.98118376731873,0.0 -358500,4.0222826,1.2424402,,,,,,,,,,,,,, -358600,3.0865433,1.7514188,,,,,,,,,,,,,, -358700,3.0899851,2.164206,,,,,,,,,,,,,, -358800,3.2066703,1.119492,,,,,,,,,,,,,, -358900,3.031565,2.2168343,,,,,,,,,,,,,, -359000,3.2144828,1.7625147,,,,,,,,,,,,,, -359100,3.1980102,1.1368188,,,,,,,,,,,,,, -359200,3.829107,2.762167,,,,,,,,,,,,,, -359300,3.765471,3.19038,,,,,,,,,,,,,, -359367,,,0.8883398175239563,0.4130707383155823,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,158002.15069007874,172773.15025138855,158002.15069007874,14727.204654693604,25.068183422088623,0.0 -359400,3.5406942,1.0643007,,,,,,,,,,,,,, -359500,3.073902,1.1374062,,,,,,,,,,,,,, -359600,4.1790876,2.946034,,,,,,,,,,,,,, -359700,3.4606328,2.145166,,,,,,,,,,,,,, -359800,3.4909844,2.9258063,,,,,,,,,,,,,, -359900,3.2091155,1.1713021,,,,,,,,,,,,,, -360000,3.2765028,2.5572171,,,,,,,,,,,,,, -360100,3.0488923,0.9939374,,,,,,,,,,,,,, -360200,3.031844,1.112262,,,,,,,,,,,,,, -360300,2.9148045,1.9392524,,,,,,,,,,,,,, -360324,,,0.88929682970047,0.4147252142429352,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,158422.29297685623,173233.80309987068,158422.29297685623,14767.577633619308,25.1567451953888,0.0 -360400,3.2490013,1.1826457,,,,,,,,,,,,,, -360500,3.5976682,1.1082996,,,,,,,,,,,,,, -360600,3.143692,1.3790737,,,,,,,,,,,,,, -360700,3.3647072,1.157966,,,,,,,,,,,,,, -360800,3.204882,1.9844129,,,,,,,,,,,,,, -360900,3.2259753,1.3654135,,,,,,,,,,,,,, -361000,3.37602,2.3985348,,,,,,,,,,,,,, -361100,3.621375,3.185697,,,,,,,,,,,,,, -361200,3.134695,1.1395955,,,,,,,,,,,,,, -361268,,,0.88832026720047,0.4192902147769928,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,158842.19017791748,173694.02808856964,158842.19017791748,14807.759229898453,25.254371643066406,0.0 -361300,3.0198777,1.0948111,,,,,,,,,,,,,, -361400,3.3003905,2.9219706,,,,,,,,,,,,,, -361500,3.4181361,1.2870524,,,,,,,,,,,,,, -361600,3.280689,1.1609808,,,,,,,,,,,,,, -361700,2.905724,1.785083,,,,,,,,,,,,,, -361800,3.1100917,2.4447715,,,,,,,,,,,,,, -361900,3.1592073,1.0903025,,,,,,,,,,,,,, -362000,3.083682,1.4113462,,,,,,,,,,,,,, -362100,3.936246,3.182396,,,,,,,,,,,,,, -362200,3.1989303,2.2907834,,,,,,,,,,,,,, -362214,,,0.8856054544448853,0.4205483794212341,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,159262.09999990463,174159.47754120827,159262.09999990463,14853.150671720505,25.353637218475345,0.0 -362300,3.2432945,1.0820225,,,,,,,,,,,,,, -362400,3.0327384,1.2923696,,,,,,,,,,,,,, -362500,3.0968497,1.1604047,,,,,,,,,,,,,, -362600,2.819546,2.1847396,,,,,,,,,,,,,, -362700,2.9964516,1.3116376,,,,,,,,,,,,,, -362800,3.123166,1.2006053,,,,,,,,,,,,,, -362900,3.1903086,1.5668833,,,,,,,,,,,,,, -363000,3.1760364,1.7199746,,,,,,,,,,,,,, -363100,2.9458425,1.720335,,,,,,,,,,,,,, -363169,,,0.8899999856948853,0.4131052792072296,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,159682.17498278618,174618.91343593597,159682.17498278618,14892.368756771088,25.447653770446777,0.0 -363200,3.0612829,2.1539707,,,,,,,,,,,,,, -363300,3.3686194,2.915226,,,,,,,,,,,,,, -363400,3.1288555,1.1394323,,,,,,,,,,,,,, -363500,3.1122844,1.2154602,,,,,,,,,,,,,, -363600,4.6754775,3.3148558,,,,,,,,,,,,,, -363700,2.8866568,1.1019707,,,,,,,,,,,,,, -363800,3.2703025,1.1299934,,,,,,,,,,,,,, -363900,2.7493842,1.5916239,,,,,,,,,,,,,, -364000,3.7980905,3.1455836,,,,,,,,,,,,,, -364100,2.8613677,1.76213,,,,,,,,,,,,,, -364121,,,0.8870507478713989,0.4179600775241852,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,160102.2657313347,175081.6497273445,160102.2657313347,14934.86432981491,25.54854822158813,0.0 -364200,3.052517,1.2109525,,,,,,,,,,,,,, -364300,3.2162185,2.060167,,,,,,,,,,,,,, -364400,3.2708385,1.0920193,,,,,,,,,,,,,, -364500,3.3361573,1.0546716,,,,,,,,,,,,,, -364600,3.2772262,1.1750938,,,,,,,,,,,,,, -364700,3.6243157,3.054141,,,,,,,,,,,,,, -364800,3.1591806,2.176344,,,,,,,,,,,,,, -364900,3.1266081,1.7088007,,,,,,,,,,,,,, -365000,3.9505363,3.011939,,,,,,,,,,,,,, -365073,,,0.8878124952316284,0.414995789527893,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,160522.2911117077,175542.83787035942,160522.2911117077,14975.874757051468,25.65269923210144,0.0 -365100,3.1987097,3.0314605,,,,,,,,,,,,,, -365200,3.3203638,2.5830712,,,,,,,,,,,,,, -365300,2.9458234,1.1715589,,,,,,,,,,,,,, -365400,3.3638759,1.1973598,,,,,,,,,,,,,, -365500,2.8509784,1.3641524,,,,,,,,,,,,,, -365600,3.3991249,2.645697,,,,,,,,,,,,,, -365700,3.1283422,2.6457946,,,,,,,,,,,,,, -365800,2.9840095,2.253332,,,,,,,,,,,,,, -365900,3.1052673,1.1436639,,,,,,,,,,,,,, -366000,3.2977333,1.3839716,,,,,,,,,,,,,, -366025,,,0.8901953101158142,0.4158586859703064,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,160942.30407452583,176006.61175465584,160942.30407452583,15019.486609220505,25.7526330947876,0.0 -366100,3.1990445,2.883037,,,,,,,,,,,,,, -366200,3.1120749,1.1784626,,,,,,,,,,,,,, -366300,3.0494754,2.407968,,,,,,,,,,,,,, -366400,3.491789,2.4746244,,,,,,,,,,,,,, -366500,2.9920936,1.818257,,,,,,,,,,,,,, -366600,3.1947463,1.0453408,,,,,,,,,,,,,, -366700,3.8637826,3.2021239,,,,,,,,,,,,,, -366800,3.1731668,1.4468293,,,,,,,,,,,,,, -366900,3.1088645,1.7379309,,,,,,,,,,,,,, -366978,,,0.8854882717132568,0.4268224537372589,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,161362.19010972977,176465.17006731033,161362.19010972977,15057.952741146088,25.909390687942505,0.0 -367000,4.220937,3.1469328,,,,,,,,,,,,,, -367100,3.1286004,1.074502,,,,,,,,,,,,,, -367200,3.0384815,1.111373,,,,,,,,,,,,,, -367300,3.2279146,1.1888076,,,,,,,,,,,,,, -367400,3.7228858,2.9493866,,,,,,,,,,,,,, -367500,3.1602385,1.0578029,,,,,,,,,,,,,, -367600,3.0690975,1.096597,,,,,,,,,,,,,, -367700,3.2924316,1.5437799,,,,,,,,,,,,,, -367800,2.78566,2.0239992,,,,,,,,,,,,,, -367900,3.3969333,1.1325648,,,,,,,,,,,,,, -367933,,,0.8866796493530273,0.4201131463050842,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,161782.4360189438,176930.18487286568,161782.4360189438,15102.571516036987,26.01111960411072,0.0 -368000,2.953886,1.2654713,,,,,,,,,,,,,, -368100,3.195904,1.9866898,,,,,,,,,,,,,, -368200,3.2850864,2.834534,,,,,,,,,,,,,, -368300,3.0253267,2.472291,,,,,,,,,,,,,, -368400,3.0858037,1.0829325,,,,,,,,,,,,,, -368500,3.1071491,2.7134337,,,,,,,,,,,,,, -368600,3.135646,1.0994967,,,,,,,,,,,,,, -368700,3.3255627,2.6008615,,,,,,,,,,,,,, -368800,3.0784056,1.4006224,,,,,,,,,,,,,, -368887,,,0.8903710842132568,0.4076790809631347,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,162202.54720520973,177393.1012325287,162202.54720520973,15145.225117206572,26.113094091415405,0.0 -368900,3.194605,1.3662516,,,,,,,,,,,,,, -369000,3.0589502,1.1001236,,,,,,,,,,,,,, -369100,3.1654606,1.1141946,,,,,,,,,,,,,, -369200,3.140129,2.3762348,,,,,,,,,,,,,, -369300,3.2128787,2.058283,,,,,,,,,,,,,, -369400,3.4812734,2.7286403,,,,,,,,,,,,,, -369500,3.0812352,1.264766,,,,,,,,,,,,,, -369600,3.5443654,3.2463696,,,,,,,,,,,,,, -369700,4.602711,1.1949556,,,,,,,,,,,,,, -369800,3.5405104,2.7260282,,,,,,,,,,,,,, -369844,,,0.887011706829071,0.4195477068424225,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,162622.64628648758,177851.28444957733,162622.64628648758,15183.17496752739,26.19810652732849,0.0 -369900,3.2188878,1.6727424,,,,,,,,,,,,,, -370000,3.119636,1.0454334,,,,,,,,,,,,,, -370100,2.9774175,1.6332797,,,,,,,,,,,,,, -370200,3.0447438,1.29281,,,,,,,,,,,,,, -370300,3.006762,1.1176907,,,,,,,,,,,,,, -370400,3.1010914,1.1019365,,,,,,,,,,,,,, -370500,3.2381895,2.6614244,,,,,,,,,,,,,, -370600,3.1809554,1.1662748,,,,,,,,,,,,,, -370700,3.9843736,3.2380393,,,,,,,,,,,,,, -370799,,,0.8887695074081421,0.4166419804096222,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,163042.84110546112,178313.641654253,163042.84110546112,15225.187652349472,26.29881167411804,0.0 -370800,3.0220263,2.3411863,,,,,,,,,,,,,, -370900,2.90225,1.0782255,,,,,,,,,,,,,, -371000,3.1592646,1.1842606,,,,,,,,,,,,,, -371100,3.3618062,1.0993136,,,,,,,,,,,,,, -371200,3.2494338,1.1636013,,,,,,,,,,,,,, -371300,3.1304684,2.8608766,,,,,,,,,,,,,, -371400,3.4061928,2.3450017,,,,,,,,,,,,,, -371500,3.8626447,3.2735152,,,,,,,,,,,,,, -371600,2.9639947,1.2050915,,,,,,,,,,,,,, -371700,3.1235745,1.1794447,,,,,,,,,,,,,, -371753,,,0.8869921565055847,0.4218083918094635,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,163463.0209646225,178777.1244843006,163463.0209646225,15268.341364622116,26.399053812026978,0.0 -371800,3.6051028,2.4227228,,,,,,,,,,,,,, -371900,4.252451,3.1658506,,,,,,,,,,,,,, -372000,3.031154,1.0453928,,,,,,,,,,,,,, -372100,3.420534,1.1909511,,,,,,,,,,,,,, -372200,3.1549363,2.4758055,,,,,,,,,,,,,, -372300,3.0496,2.6766825,,,,,,,,,,,,,, -372400,3.0233164,1.0958773,,,,,,,,,,,,,, -372500,3.2253797,2.8485656,,,,,,,,,,,,,, -372600,2.9075284,1.4737175,,,,,,,,,,,,,, -372700,3.100065,1.1585284,,,,,,,,,,,,,, -372710,,,0.8858007788658142,0.4265918731689453,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,163883.04605412483,179234.83329749107,163883.04605412483,15305.892180919647,26.482391119003296,0.0 -372800,3.0601065,1.0145255,,,,,,,,,,,,,, -372900,3.0085487,1.0294689,,,,,,,,,,,,,, -373000,3.2081134,1.1719232,,,,,,,,,,,,,, -373100,3.6985404,3.2372606,,,,,,,,,,,,,, -373200,3.2523723,1.1824234,,,,,,,,,,,,,, -373300,3.0228093,1.0280774,,,,,,,,,,,,,, -373400,3.1265564,2.2723813,,,,,,,,,,,,,, -373500,3.1960294,1.481692,,,,,,,,,,,,,, -373600,3.5581696,1.258379,,,,,,,,,,,,,, -373664,,,0.8891406059265137,0.4111347496509552,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,164302.91035580635,179703.6354522705,164302.91035580635,15354.65480685234,26.608601808547974,0.0 -373700,3.2231016,1.2053169,,,,,,,,,,,,,, -373800,3.1367955,1.1642494,,,,,,,,,,,,,, -373900,3.065145,1.9968858,,,,,,,,,,,,,, -374000,3.1676168,1.1637759,,,,,,,,,,,,,, -374100,3.1963434,1.1973107,,,,,,,,,,,,,, -374200,3.1107588,1.2643696,,,,,,,,,,,,,, -374300,3.1342683,1.466027,,,,,,,,,,,,,, -374400,3.1677463,1.0503976,,,,,,,,,,,,,, -374500,4.0870748,3.3531752,,,,,,,,,,,,,, -374600,3.5386136,2.265921,,,,,,,,,,,,,, -374620,,,0.8878710865974426,0.4127155244350433,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,164722.90692043304,180169.262591362,164722.90692043304,15400.136773586271,26.708449602127075,0.0 -374700,3.2628105,2.8104017,,,,,,,,,,,,,, -374800,3.0216348,1.1532093,,,,,,,,,,,,,, -374900,2.8852627,1.4347351,,,,,,,,,,,,,, -375000,2.7370594,1.2615503,,,,,,,,,,,,,, -375100,2.9242885,1.9356271,,,,,,,,,,,,,, -375200,4.001499,3.2890491,,,,,,,,,,,,,, -375300,3.4324422,2.184618,,,,,,,,,,,,,, -375400,2.986424,1.6905214,,,,,,,,,,,,,, -375500,3.5341566,1.2328572,,,,,,,,,,,,,, -375576,,,0.8874609470367432,0.4219913482666015,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,165143.0420908928,180634.65010356903,165143.0420908928,15445.25285935402,26.79526329040528,0.0 -375600,3.308783,2.7380495,,,,,,,,,,,,,, -375700,3.5006268,1.139159,,,,,,,,,,,,,, -375800,3.5804312,1.1705642,,,,,,,,,,,,,, -375900,3.6800544,3.2116964,,,,,,,,,,,,,, -376000,3.1744077,1.1231879,,,,,,,,,,,,,, -376100,3.428873,2.4861195,,,,,,,,,,,,,, -376200,3.0502663,1.060289,,,,,,,,,,,,,, -376300,2.968718,1.4165148,,,,,,,,,,,,,, -376400,3.0970573,1.6822314,,,,,,,,,,,,,, -376500,2.9921706,1.7940285,,,,,,,,,,,,,, -376526,,,0.887011706829071,0.4182519316673279,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,165563.02144479752,181099.59773135185,165563.02144479752,15490.088018417358,26.87953495979309,0.0 -376600,3.4370804,2.489208,,,,,,,,,,,,,, -376700,3.1963,1.3221895,,,,,,,,,,,,,, -376800,3.0875757,1.4524028,,,,,,,,,,,,,, -376900,2.9300134,2.089274,,,,,,,,,,,,,, -377000,3.2375753,1.1251979,,,,,,,,,,,,,, -377100,3.2460353,2.86946,,,,,,,,,,,,,, -377200,2.96485,1.0964925,,,,,,,,,,,,,, -377300,2.9958584,1.080196,,,,,,,,,,,,,, -377400,3.4802096,1.1609628,,,,,,,,,,,,,, -377483,,,0.8883398175239563,0.4191173315048218,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,165983.27627158165,181558.17009282112,165983.27627158165,15528.273938655851,26.961669921875,0.0 -377500,3.0074508,1.3200421,,,,,,,,,,,,,, -377600,3.123503,2.9452157,,,,,,,,,,,,,, -377700,3.7357113,3.2591722,,,,,,,,,,,,,, -377800,3.236232,2.5832918,,,,,,,,,,,,,, -377900,3.8392,3.2771869,,,,,,,,,,,,,, -378000,3.1657317,1.3288022,,,,,,,,,,,,,, -378100,3.257566,1.2815228,,,,,,,,,,,,,, -378200,3.5450332,2.953461,,,,,,,,,,,,,, -378300,3.465467,2.6210015,,,,,,,,,,,,,, -378400,3.3788168,1.1091166,,,,,,,,,,,,,, -378434,,,0.8884570002555847,0.4142103791236877,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,166403.22260427475,182026.49364376068,166403.22260427475,15576.504554748535,27.059396266937256,0.0 -378500,3.3477242,3.0035129,,,,,,,,,,,,,, -378600,3.6793582,3.3569262,,,,,,,,,,,,,, -378700,3.1351938,1.1773751,,,,,,,,,,,,,, -378800,3.2479138,1.1188825,,,,,,,,,,,,,, -378900,2.8838038,1.2051384,,,,,,,,,,,,,, -379000,3.0394814,1.0715585,,,,,,,,,,,,,, -379100,3.252836,1.1797876,,,,,,,,,,,,,, -379200,3.1237426,1.0268289,,,,,,,,,,,,,, -379300,3.0227604,1.081575,,,,,,,,,,,,,, -379390,,,0.8917577862739563,0.4062821865081787,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,166823.30525374413,182484.9082839489,166823.30525374413,15614.70454120636,27.141717195510864,0.0 -379400,3.3994553,1.1275855,,,,,,,,,,,,,, -379500,2.9009206,1.0901932,,,,,,,,,,,,,, -379600,3.083734,1.136456,,,,,,,,,,,,,, -379700,3.499744,1.3497914,,,,,,,,,,,,,, -379800,3.2316391,1.9083265,,,,,,,,,,,,,, -379900,3.070486,1.2657472,,,,,,,,,,,,,, -380000,3.2721493,1.1509573,,,,,,,,,,,,,, -380100,3.421045,2.8844573,,,,,,,,,,,,,, -380200,2.9710248,1.2110813,,,,,,,,,,,,,, -380300,3.3325396,1.054929,,,,,,,,,,,,,, -380344,,,0.8888476490974426,0.4164382219314575,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,167243.23605418205,182942.2154922485,167243.23605418205,15651.926760196686,27.245988607406616,0.0 -380400,3.0914617,2.4048498,,,,,,,,,,,,,, -380500,3.27903,3.038392,,,,,,,,,,,,,, -380600,3.00733,2.443723,,,,,,,,,,,,,, -380700,3.217211,1.062633,,,,,,,,,,,,,, -380800,3.1046093,1.0254776,,,,,,,,,,,,,, -380900,3.1538613,1.8964889,,,,,,,,,,,,,, -381000,2.9271789,1.1418236,,,,,,,,,,,,,, -381100,3.143535,1.018512,,,,,,,,,,,,,, -381200,2.935188,1.5339085,,,,,,,,,,,,,, -381297,,,0.8883007764816284,0.4152159392833709,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,167663.43669724464,183404.5911934376,167663.43669724464,15693.9504032135,27.34863257408142,0.0 -381300,3.0261676,1.1375661,,,,,,,,,,,,,, -381400,4.0683618,3.1589108,,,,,,,,,,,,,, -381500,4.0869174,3.2540216,,,,,,,,,,,,,, -381600,3.3199008,2.9042602,,,,,,,,,,,,,, -381700,3.327613,1.0686572,,,,,,,,,,,,,, -381800,3.0565858,1.1981218,,,,,,,,,,,,,, -381900,3.722263,3.3407657,,,,,,,,,,,,,, -382000,3.2218215,1.2736932,,,,,,,,,,,,,, -382100,3.3016982,1.3993735,,,,,,,,,,,,,, -382200,2.8201764,1.4355025,,,,,,,,,,,,,, -382250,,,0.8875195384025574,0.416832834482193,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,168083.33858251572,183863.54661631584,168083.33858251572,15732.852969169617,27.450412034988403,0.0 -382300,3.555355,3.047874,,,,,,,,,,,,,, -382400,3.1441164,1.4125733,,,,,,,,,,,,,, -382500,3.001048,1.3659312,,,,,,,,,,,,,, -382600,3.2975757,1.0817429,,,,,,,,,,,,,, -382700,2.8957012,1.0404027,,,,,,,,,,,,,, -382800,3.6398976,3.0853395,,,,,,,,,,,,,, -382900,3.1670222,1.0846047,,,,,,,,,,,,,, -383000,2.973905,1.16652,,,,,,,,,,,,,, -383100,3.062041,1.3146476,,,,,,,,,,,,,, -383200,3.225537,1.167219,,,,,,,,,,,,,, -383205,,,0.8854296803474426,0.4210447669029236,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,168503.38748073578,184327.24445915225,168503.38748073578,15776.350043535233,27.55328798294068,0.0 -383300,3.0167434,1.3664131,,,,,,,,,,,,,, -383400,2.9814243,1.6432393,,,,,,,,,,,,,, -383500,3.061832,1.880471,,,,,,,,,,,,,, -383600,3.401027,2.9917636,,,,,,,,,,,,,, -383700,3.0467503,2.7843924,,,,,,,,,,,,,, -383800,3.4654183,2.931409,,,,,,,,,,,,,, -383900,3.2568603,1.3856704,,,,,,,,,,,,,, -384000,3.3195817,2.9025733,,,,,,,,,,,,,, -384100,3.1821668,1.0652151,,,,,,,,,,,,,, -384159,,,0.889941394329071,0.409912645816803,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,168923.58438396454,184788.61385440824,168923.58438396454,15817.368630886078,27.65798783302307,0.0 -384200,3.6294281,3.2504349,,,,,,,,,,,,,, -384300,2.8535638,1.9701675,,,,,,,,,,,,,, -384400,3.8862,1.1768253,,,,,,,,,,,,,, -384500,4.1592975,3.261787,,,,,,,,,,,,,, -384600,3.0007052,1.0573721,,,,,,,,,,,,,, -384700,3.6928024,3.2593055,,,,,,,,,,,,,, -384800,3.416268,1.1668108,,,,,,,,,,,,,, -384900,3.311892,1.4053599,,,,,,,,,,,,,, -385000,3.9277391,3.2049599,,,,,,,,,,,,,, -385100,3.1792686,2.5249715,,,,,,,,,,,,,, -385117,,,0.8879101276397705,0.4181992709636688,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,169343.8165242672,185250.7816569805,169343.8165242672,15859.136594057083,27.77664041519165,0.0 -385200,3.0106897,1.2884867,,,,,,,,,,,,,, -385300,3.4461725,1.3240746,,,,,,,,,,,,,, -385400,3.5558155,3.2396016,,,,,,,,,,,,,, -385500,2.9469435,1.1953251,,,,,,,,,,,,,, -385600,3.1517136,1.1401328,,,,,,,,,,,,,, -385700,4.3623676,1.1362509,,,,,,,,,,,,,, -385800,4.3003764,2.8305035,,,,,,,,,,,,,, -385900,3.5360343,1.1216922,,,,,,,,,,,,,, -386000,3.4483142,2.2900918,,,,,,,,,,,,,, -386074,,,0.8887499570846558,0.4167336821556091,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,169763.95485639572,185709.42883133888,169763.95485639572,15897.479062080383,27.886797189712524,0.0 -386100,2.8013828,1.0495987,,,,,,,,,,,,,, -386200,3.5980275,3.1430526,,,,,,,,,,,,,, -386300,3.3755944,1.234968,,,,,,,,,,,,,, -386400,2.8270805,1.1307163,,,,,,,,,,,,,, -386500,3.0434039,1.4652663,,,,,,,,,,,,,, -386600,3.1051617,1.4059979,,,,,,,,,,,,,, -386700,3.0189216,1.8210757,,,,,,,,,,,,,, -386800,3.0672758,1.1991782,,,,,,,,,,,,,, -386900,3.1489046,1.2043955,,,,,,,,,,,,,, -387000,3.0721912,1.1237248,,,,,,,,,,,,,, -387027,,,0.8863476514816284,0.420527309179306,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,170184.15087461472,186171.23771500587,170184.15087461472,15938.935741901398,27.993557691574097,0.0 -387100,3.9250917,1.957716,,,,,,,,,,,,,, -387200,2.907496,1.0527515,,,,,,,,,,,,,, -387300,2.8732665,1.7380769,,,,,,,,,,,,,, -387400,2.871035,1.065106,,,,,,,,,,,,,, -387500,3.1477468,2.8524997,,,,,,,,,,,,,, -387600,3.105829,1.1408596,,,,,,,,,,,,,, -387700,3.1462653,1.9012419,,,,,,,,,,,,,, -387800,3.30602,1.258816,,,,,,,,,,,,,, -387900,3.0581732,2.0346286,,,,,,,,,,,,,, -387982,,,0.8876562118530273,0.4173603355884552,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,170604.10696482658,186638.0218274593,170604.10696482658,15985.61399960518,28.094316720962524,0.0 -388000,3.399041,3.0582855,,,,,,,,,,,,,, -388100,3.2449145,2.7747304,,,,,,,,,,,,,, -388200,3.8135233,3.2885919,,,,,,,,,,,,,, -388300,3.3972673,1.1594617,,,,,,,,,,,,,, -388400,3.1982543,1.9479543,,,,,,,,,,,,,, -388500,3.7059207,3.2552636,,,,,,,,,,,,,, -388600,3.1376047,1.1903847,,,,,,,,,,,,,, -388700,2.8942058,1.7983018,,,,,,,,,,,,,, -388800,3.5258806,2.5437574,,,,,,,,,,,,,, -388900,2.8940752,2.0008442,,,,,,,,,,,,,, -388940,,,0.8876757621765137,0.4181455373764038,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,171024.2012052536,187093.33483076096,171024.2012052536,16020.70000576973,28.178041458129883,0.0 -389000,3.4316149,2.981659,,,,,,,,,,,,,, -389100,2.9006422,2.0517874,,,,,,,,,,,,,, -389200,4.050631,3.2490873,,,,,,,,,,,,,, -389300,3.2949598,1.2207565,,,,,,,,,,,,,, -389400,2.9812183,2.0061924,,,,,,,,,,,,,, -389500,3.1484563,1.1459578,,,,,,,,,,,,,, -389600,3.2676628,1.1314644,,,,,,,,,,,,,, -389700,3.0059912,1.1096098,,,,,,,,,,,,,, -389800,3.3896058,1.14343,,,,,,,,,,,,,, -389894,,,0.8898437023162842,0.4151351451873779,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,171444.37899065018,187554.36606407168,171444.37899065018,16061.401812076569,28.2806077003479,0.0 -389900,3.9443583,2.993862,,,,,,,,,,,,,, -390000,3.849729,3.2218208,,,,,,,,,,,,,, -390100,3.0875099,2.0605736,,,,,,,,,,,,,, -390200,2.9889925,1.4047813,,,,,,,,,,,,,, -390300,3.09156,1.2174098,,,,,,,,,,,,,, -390400,3.589193,1.1130018,,,,,,,,,,,,,, -390500,3.2812152,2.7429972,,,,,,,,,,,,,, -390600,3.1849096,1.3267293,,,,,,,,,,,,,, -390700,3.150212,1.1580638,,,,,,,,,,,,,, -390800,3.1923552,2.3458524,,,,,,,,,,,,,, -390845,,,0.8887304663658142,0.4167289733886719,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,171864.35090208054,188025.0598976612,171864.35090208054,16111.968054294586,28.38612174987793,0.0 -390900,3.2236674,1.1447945,,,,,,,,,,,,,, -391000,3.058585,2.0221636,,,,,,,,,,,,,, -391100,3.1456873,1.1123816,,,,,,,,,,,,,, -391200,3.3304625,1.056204,,,,,,,,,,,,,, -391300,3.056967,1.1486145,,,,,,,,,,,,,, -391400,2.9232419,1.0230864,,,,,,,,,,,,,, -391500,3.1101713,1.6827636,,,,,,,,,,,,,, -391600,3.3005924,2.0593643,,,,,,,,,,,,,, -391700,3.390578,1.2732308,,,,,,,,,,,,,, -391800,3.1834445,1.2376413,,,,,,,,,,,,,, -391801,,,0.88818359375,0.4137005805969238,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,172284.2420580387,188482.8562738896,172284.2420580387,16149.739753246307,28.470293283462524,0.0 -391900,3.3862557,1.1151706,,,,,,,,,,,,,, -392000,2.9043324,1.2989225,,,,,,,,,,,,,, -392100,3.1491141,1.3967695,,,,,,,,,,,,,, -392200,3.3844156,3.037901,,,,,,,,,,,,,, -392300,3.4214869,1.1734393,,,,,,,,,,,,,, -392400,3.4396315,1.2606877,,,,,,,,,,,,,, -392500,2.9942555,1.1585271,,,,,,,,,,,,,, -392600,3.1074376,2.5656855,,,,,,,,,,,,,, -392700,3.207022,1.015004,,,,,,,,,,,,,, -392756,,,0.8877343535423279,0.4162317514419555,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,172704.4234380722,188942.7630982399,172704.4234380722,16189.30582332611,28.58032822608948,0.0 -392800,3.1108384,1.1140844,,,,,,,,,,,,,, -392900,3.1805198,1.1841078,,,,,,,,,,,,,, -393000,3.4334202,1.2197518,,,,,,,,,,,,,, -393100,3.0551364,2.5521412,,,,,,,,,,,,,, -393200,4.1055064,3.2020226,,,,,,,,,,,,,, -393300,3.0607026,1.2106929,,,,,,,,,,,,,, -393400,3.3711617,2.5872123,,,,,,,,,,,,,, -393500,3.4926941,2.7077308,,,,,,,,,,,,,, -393600,3.3636293,1.1223918,,,,,,,,,,,,,, -393700,3.141761,1.0705755,,,,,,,,,,,,,, -393710,,,0.8866796493530273,0.4246825277805328,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,173124.35608148575,189402.51794099808,173124.35608148575,16228.974833250046,28.68292617797852,0.0 -393800,3.1599352,1.4538735,,,,,,,,,,,,,, -393900,3.1967742,1.9347348,,,,,,,,,,,,,, -394000,3.6353302,2.949182,,,,,,,,,,,,,, -394100,3.0399795,1.2218971,,,,,,,,,,,,,, -394200,3.0124786,1.3256203,,,,,,,,,,,,,, -394300,4.13122,3.2362647,,,,,,,,,,,,,, -394400,3.2253788,1.1052912,,,,,,,,,,,,,, -394500,2.9199188,1.0834891,,,,,,,,,,,,,, -394600,3.3125238,1.1770055,,,,,,,,,,,,,, -394665,,,0.8871679306030273,0.4199066758155823,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,173544.5121805668,189862.47140693665,173544.5121805668,16268.61748099327,28.78764533996582,0.0 -394700,3.0984905,1.7177677,,,,,,,,,,,,,, -394800,3.467075,2.7652273,,,,,,,,,,,,,, -394900,3.2785664,2.808282,,,,,,,,,,,,,, -395000,3.4613466,1.0953006,,,,,,,,,,,,,, -395100,3.843933,3.007959,,,,,,,,,,,,,, -395200,3.2484221,1.4077435,,,,,,,,,,,,,, -395300,3.1913972,1.4878242,,,,,,,,,,,,,, -395400,3.0973692,1.0391533,,,,,,,,,,,,,, -395500,3.06525,1.2756141,,,,,,,,,,,,,, -395600,4.097432,3.2597675,,,,,,,,,,,,,, -395619,,,0.8871093392372131,0.4206791222095489,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,173964.45985746384,190329.6402170658,173964.45985746384,16315.68248319626,28.8935661315918,0.0 -395700,3.093805,2.4696164,,,,,,,,,,,,,, -395800,3.1066911,2.1784534,,,,,,,,,,,,,, -395900,3.2182016,1.1628461,,,,,,,,,,,,,, -396000,3.8127244,3.2003455,,,,,,,,,,,,,, -396100,3.3096578,1.1594139,,,,,,,,,,,,,, -396200,3.3217854,2.4170775,,,,,,,,,,,,,, -396300,3.3118763,1.070925,,,,,,,,,,,,,, -396400,3.3095484,1.4619071,,,,,,,,,,,,,, -396500,3.0418255,2.033301,,,,,,,,,,,,,, -396579,,,0.8859961032867432,0.4233563840389251,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,174384.35671782494,190787.3806118965,174384.35671782494,16353.387253761292,28.98282265663147,0.0 -396600,3.1236246,1.6857306,,,,,,,,,,,,,, -396700,3.1851156,1.3511434,,,,,,,,,,,,,, -396800,3.1589663,1.9212658,,,,,,,,,,,,,, -396900,3.0356498,1.3382363,,,,,,,,,,,,,, -397000,3.350757,1.1982436,,,,,,,,,,,,,, -397100,3.4438217,2.4067438,,,,,,,,,,,,,, -397200,3.135953,2.704321,,,,,,,,,,,,,, -397300,4.7460012,2.0886338,,,,,,,,,,,,,, -397400,3.4868765,2.696008,,,,,,,,,,,,,, -397500,3.3507442,1.3008072,,,,,,,,,,,,,, -397534,,,0.8891991972923279,0.4108793437480926,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,174804.3445544243,191248.7022612095,174804.3445544243,16394.567228794098,29.087074756622314,0.0 -397600,2.8809998,2.0837963,,,,,,,,,,,,,, -397700,3.0202203,1.1890931,,,,,,,,,,,,,, -397800,3.7019494,2.2431302,,,,,,,,,,,,,, -397900,3.1734536,1.1751486,,,,,,,,,,,,,, -398000,4.033476,3.2258716,,,,,,,,,,,,,, -398100,3.1574202,1.9223937,,,,,,,,,,,,,, -398200,2.9433136,1.6965153,,,,,,,,,,,,,, -398300,3.310876,1.1344889,,,,,,,,,,,,,, -398400,3.1417837,2.5322852,,,,,,,,,,,,,, -398491,,,0.8869531154632568,0.4180715978145599,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,175224.36603736877,191711.99235510823,175224.36603736877,16437.66381263733,29.20992875099182,0.0 -398500,3.1727135,1.1939726,,,,,,,,,,,,,, -398600,3.3226378,1.1153136,,,,,,,,,,,,,, -398700,4.108811,3.014369,,,,,,,,,,,,,, -398800,3.0896852,1.4792776,,,,,,,,,,,,,, -398900,3.0665956,2.0891304,,,,,,,,,,,,,, -399000,3.0669305,1.1270608,,,,,,,,,,,,,, -399100,3.5000489,1.2441989,,,,,,,,,,,,,, -399200,3.0668228,1.0902984,,,,,,,,,,,,,, -399300,3.2879543,1.1043637,,,,,,,,,,,,,, -399400,3.8517194,3.0203748,,,,,,,,,,,,,, -399446,,,0.8896874785423279,0.4124925434589386,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,175644.36986541748,192174.9947481156,175644.36986541748,16480.50748872757,29.315454959869385,0.0 -399500,3.0455627,2.4156604,,,,,,,,,,,,,, -399600,2.9925072,2.4353554,,,,,,,,,,,,,, -399700,3.0287335,1.0232577,,,,,,,,,,,,,, -399800,3.0798848,1.146644,,,,,,,,,,,,,, -399900,3.0331187,1.6167287,,,,,,,,,,,,,, -400000,3.2786164,1.1568348,,,,,,,,,,,,,, -400100,2.9249406,1.0929629,,,,,,,,,,,,,, -400200,2.924809,1.56819,,,,,,,,,,,,,, -400300,4.9603286,2.9684854,,,,,,,,,,,,,, -400400,3.4823582,1.1315799,,,,,,,,,,,,,, -400401,,,0.8876757621765137,0.4160926342010498,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,176064.40971302986,192644.6549062729,176064.40971302986,16529.972445726395,29.421252727508545,0.0 -400500,2.696793,1.6060714,,,,,,,,,,,,,, -400600,3.029787,1.0811744,,,,,,,,,,,,,, -400700,3.0806942,1.1813602,,,,,,,,,,,,,, -400800,3.1532445,1.1271718,,,,,,,,,,,,,, -400900,3.155804,1.1333202,,,,,,,,,,,,,, -401000,3.0929108,1.4148953,,,,,,,,,,,,,, -401100,2.9556453,1.1434481,,,,,,,,,,,,,, -401200,3.3197267,2.9618697,,,,,,,,,,,,,, -401300,3.1115513,1.6404655,,,,,,,,,,,,,, -401357,,,0.88880854845047,0.417743444442749,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,176484.6373269558,193104.86778235435,176484.6373269558,16569.809871912003,29.519530534744263,0.0 -401400,2.7757564,1.874105,,,,,,,,,,,,,, -401500,2.930589,1.037863,,,,,,,,,,,,,, -401600,3.5617342,1.0569994,,,,,,,,,,,,,, -401700,2.9494636,1.0522223,,,,,,,,,,,,,, -401800,3.0560217,1.0822717,,,,,,,,,,,,,, -401900,3.8607554,3.344038,,,,,,,,,,,,,, -402000,3.4373815,2.8148956,,,,,,,,,,,,,, -402100,3.252867,2.1689277,,,,,,,,,,,,,, -402200,3.6506178,1.1889726,,,,,,,,,,,,,, -402300,,,0.8880664110183716,0.4243686497211456,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,176904.59464097023,193565.14063882828,176904.59464097023,16609.966049671173,29.63040018081665,0.0 -402300,3.0228806,2.0522988,,,,,,,,,,,,,, -402400,3.1965647,1.0874671,,,,,,,,,,,,,, -402500,3.004938,2.14121,,,,,,,,,,,,,, -402600,3.3074627,3.0311346,,,,,,,,,,,,,, -402700,3.6198967,1.112279,,,,,,,,,,,,,, -402800,3.9611588,3.0713935,,,,,,,,,,,,,, -402900,3.2439222,1.2144189,,,,,,,,,,,,,, -403000,2.8926213,1.8073192,,,,,,,,,,,,,, -403100,3.2048812,2.6556933,,,,,,,,,,,,,, -403200,3.3194797,1.0653816,,,,,,,,,,,,,, -403247,,,0.8893554210662842,0.4064174890518188,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,177324.74939084053,194033.4842557907,177324.74939084053,16657.998848199844,29.737738132476807,0.0 -403300,3.695065,3.227507,,,,,,,,,,,,,, -403400,3.1295664,2.0986745,,,,,,,,,,,,,, -403500,2.9903326,1.0838289,,,,,,,,,,,,,, -403600,2.7560496,1.1103166,,,,,,,,,,,,,, -403700,3.0977995,1.095139,,,,,,,,,,,,,, -403800,2.8707552,1.9218824,,,,,,,,,,,,,, -403900,3.313362,2.8153005,,,,,,,,,,,,,, -404000,2.9247265,1.7226291,,,,,,,,,,,,,, -404100,3.0091927,1.2522601,,,,,,,,,,,,,, -404200,3.0500324,1.1899372,,,,,,,,,,,,,, -404203,,,0.8868359327316284,0.4202545583248138,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,177744.90615653992,194493.39645719528,177744.90615653992,16697.615966558456,29.82686495780945,0.0 -404300,4.0978556,3.2255301,,,,,,,,,,,,,, -404400,3.0277078,1.0771538,,,,,,,,,,,,,, -404500,3.3047462,1.1291695,,,,,,,,,,,,,, -404600,3.0182056,1.2397864,,,,,,,,,,,,,, -404700,3.1761172,1.309063,,,,,,,,,,,,,, -404800,3.0962567,1.6507299,,,,,,,,,,,,,, -404900,3.9685016,3.1732295,,,,,,,,,,,,,, -405000,3.0342355,2.3214588,,,,,,,,,,,,,, -405100,3.3659182,1.0454936,,,,,,,,,,,,,, -405161,,,0.8908202648162842,0.4089140594005584,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,178165.01534843445,194957.2669413089,178165.01534843445,16741.23883986473,29.913761138916016,0.0 -405200,3.0723727,1.860442,,,,,,,,,,,,,, -405300,3.25703,1.0607748,,,,,,,,,,,,,, -405400,3.210422,1.983483,,,,,,,,,,,,,, -405500,3.430651,1.1275784,,,,,,,,,,,,,, -405600,3.3283467,1.3187697,,,,,,,,,,,,,, -405700,3.0727088,1.2549738,,,,,,,,,,,,,, -405800,2.971233,1.1313479,,,,,,,,,,,,,, -405900,3.085208,2.0314312,,,,,,,,,,,,,, -406000,2.9455442,1.5753063,,,,,,,,,,,,,, -406100,2.943686,1.0621344,,,,,,,,,,,,,, -406117,,,0.8858398199081421,0.4200993180274963,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,178585.18918466568,195418.7814545632,178585.18918466568,16782.42190861702,30.02169752120972,0.0 -406200,3.231604,1.4808661,,,,,,,,,,,,,, -406300,3.067547,1.0812443,,,,,,,,,,,,,, -406400,3.0721831,1.1063842,,,,,,,,,,,,,, -406500,3.317972,1.0250682,,,,,,,,,,,,,, -406600,3.2600539,1.1050735,,,,,,,,,,,,,, -406700,2.9717104,1.9793452,,,,,,,,,,,,,, -406800,3.1993217,1.5386488,,,,,,,,,,,,,, -406900,2.9358191,2.2909436,,,,,,,,,,,,,, -407000,3.2389965,2.7896364,,,,,,,,,,,,,, -407053,,,0.8891991972923279,0.4103685021400451,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,179005.29097795486,195884.275844574,179005.29097795486,16827.65636920929,30.131014585494995,0.0 -407100,3.187737,1.0776978,,,,,,,,,,,,,, -407200,3.2021115,1.1998492,,,,,,,,,,,,,, -407300,3.2784524,1.0887525,,,,,,,,,,,,,, -407400,3.511225,3.1289525,,,,,,,,,,,,,, -407500,3.1489947,2.3503537,,,,,,,,,,,,,, -407600,3.053765,1.0884349,,,,,,,,,,,,,, -407700,3.0265093,1.2111964,,,,,,,,,,,,,, -407800,3.130694,1.0701977,,,,,,,,,,,,,, -407900,2.9316025,1.3110821,,,,,,,,,,,,,, -408000,3.1567113,1.1115971,,,,,,,,,,,,,, -408006,,,0.8874609470367432,0.4193182587623596,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,179425.34422326088,196344.5682406425,179425.34422326088,16867.759991645813,30.217520475387573,0.0 -408100,3.2642004,2.8341274,,,,,,,,,,,,,, -408200,2.9640453,1.6949863,,,,,,,,,,,,,, -408300,2.946493,1.2528336,,,,,,,,,,,,,, -408400,2.9927533,1.917608,,,,,,,,,,,,,, -408500,3.1527522,2.7530227,,,,,,,,,,,,,, -408600,3.0248485,2.1378222,,,,,,,,,,,,,, -408700,3.1137109,1.042707,,,,,,,,,,,,,, -408800,2.9783866,2.2191443,,,,,,,,,,,,,, -408900,2.8765924,1.5927601,,,,,,,,,,,,,, -408955,,,0.8893163800239563,0.4162834584712982,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,179845.50100183487,196811.6254954338,179845.50100183487,16914.508714675903,30.320476055145264,0.0 -409000,3.0712736,2.2883344,,,,,,,,,,,,,, -409100,4.829988,1.1068056,,,,,,,,,,,,,, -409200,2.9454904,1.0301727,,,,,,,,,,,,,, -409300,3.121609,1.1253234,,,,,,,,,,,,,, -409400,3.0724745,1.2174656,,,,,,,,,,,,,, -409500,3.3039758,1.2612962,,,,,,,,,,,,,, -409600,3.458228,1.3918551,,,,,,,,,,,,,, -409700,3.016773,1.0292645,,,,,,,,,,,,,, -409800,3.7639117,3.204476,,,,,,,,,,,,,, -409870,,,0.8877733945846558,0.4182986617088318,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,180265.3992426396,197280.82704353333,180265.3992426396,16963.657194137573,30.428412675857544,0.0 -409900,3.1137478,1.0962511,,,,,,,,,,,,,, -410000,3.2128158,2.0345309,,,,,,,,,,,,,, -410100,3.4720693,1.8361845,,,,,,,,,,,,,, -410200,3.389447,2.0286365,,,,,,,,,,,,,, -410300,3.219652,2.7455711,,,,,,,,,,,,,, -410400,3.033412,1.1574962,,,,,,,,,,,,,, -410500,3.1400611,1.1533372,,,,,,,,,,,,,, -410600,3.318586,3.0676596,,,,,,,,,,,,,, -410700,3.0669618,1.1627073,,,,,,,,,,,,,, -410800,3.3099957,1.0034889,,,,,,,,,,,,,, -410825,,,0.8863281011581421,0.4216166734695434,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,180685.4873828888,197744.1201946736,180685.4873828888,17006.726067066193,30.515804767608643,0.0 -410900,3.88546,2.7390375,,,,,,,,,,,,,, -411000,3.1183305,1.05815,,,,,,,,,,,,,, -411100,3.3497121,2.0848498,,,,,,,,,,,,,, -411200,3.815457,2.733765,,,,,,,,,,,,,, -411300,3.5013041,3.2286777,,,,,,,,,,,,,, -411400,2.9995384,1.038877,,,,,,,,,,,,,, -411500,2.9184787,1.8403456,,,,,,,,,,,,,, -411600,3.4336057,3.300177,,,,,,,,,,,,,, -411700,2.955238,1.3092984,,,,,,,,,,,,,, -411778,,,0.890625,0.4106134176254272,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,181105.803034544,198202.1157720089,181105.803034544,17044.269870519638,30.603248834609985,0.0 -411800,3.3459463,2.982147,,,,,,,,,,,,,, -411900,3.457023,3.0108275,,,,,,,,,,,,,, -412000,2.9449863,1.0084037,,,,,,,,,,,,,, -412100,3.0511131,1.6117022,,,,,,,,,,,,,, -412200,3.0605783,1.0970035,,,,,,,,,,,,,, -412300,3.038306,1.1479483,,,,,,,,,,,,,, -412400,3.8908179,3.208087,,,,,,,,,,,,,, -412500,3.0132718,1.9020644,,,,,,,,,,,,,, -412600,3.9892948,3.3308506,,,,,,,,,,,,,, -412700,3.184506,1.2151395,,,,,,,,,,,,,, -412712,,,0.8864843845367432,0.4179598391056061,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,181525.74473643303,198661.29543995857,181525.74473643303,17083.35184264183,30.71147727966309,0.0 -412800,3.0160158,1.0606489,,,,,,,,,,,,,, -412900,2.9474497,1.1455412,,,,,,,,,,,,,, -413000,2.875196,1.0797698,,,,,,,,,,,,,, -413100,2.932558,1.0854446,,,,,,,,,,,,,, -413200,3.1693347,1.2553483,,,,,,,,,,,,,, -413300,2.8310795,1.2401574,,,,,,,,,,,,,, -413400,3.432703,3.0253918,,,,,,,,,,,,,, -413500,3.4854147,1.7755613,,,,,,,,,,,,,, -413600,3.6651006,1.743633,,,,,,,,,,,,,, -413662,,,0.8892187476158142,0.4130817055702209,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,181945.9853141308,199132.67255043983,181945.9853141308,17134.33325767517,30.818686723709103,0.0 -413700,3.0632565,1.1888119,,,,,,,,,,,,,, -413800,3.0847433,1.1357884,,,,,,,,,,,,,, -413900,3.1086395,1.5828886,,,,,,,,,,,,,, -414000,3.125004,2.6519916,,,,,,,,,,,,,, -414100,3.310339,1.323124,,,,,,,,,,,,,, -414200,3.754543,2.991259,,,,,,,,,,,,,, -414300,3.1455562,1.2590277,,,,,,,,,,,,,, -414400,2.979838,1.1766635,,,,,,,,,,,,,, -414500,3.1794312,2.210093,,,,,,,,,,,,,, -414600,3.292263,1.2196376,,,,,,,,,,,,,, -414618,,,0.8873242139816284,0.4246455430984497,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,182366.14447641373,199590.87727713585,182366.14447641373,17172.24253630638,30.905985593795776,0.0 -414700,2.9759507,1.1447184,,,,,,,,,,,,,, -414800,3.0526204,1.4992454,,,,,,,,,,,,,, -414900,3.3117409,2.6874313,,,,,,,,,,,,,, -415000,3.3574574,1.1340151,,,,,,,,,,,,,, -415100,2.855269,1.0890181,,,,,,,,,,,,,, -415200,3.0355222,1.1768095,,,,,,,,,,,,,, -415300,2.9635887,1.1881198,,,,,,,,,,,,,, -415400,3.525076,3.0145485,,,,,,,,,,,,,, -415500,3.3414598,1.6204308,,,,,,,,,,,,,, -415576,,,0.8868749737739563,0.4205108880996704,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,182786.24632263184,200053.9158146381,182786.24632263184,17214.97789669037,31.05756092071533,0.0 -415600,2.8617938,1.7625272,,,,,,,,,,,,,, -415700,2.944226,1.1882402,,,,,,,,,,,,,, -415800,3.2443726,1.078924,,,,,,,,,,,,,, -415900,3.2039294,1.1018208,,,,,,,,,,,,,, -416000,3.0276854,1.3268611,,,,,,,,,,,,,, -416100,2.8835335,1.3108606,,,,,,,,,,,,,, -416200,2.9889405,1.2929497,,,,,,,,,,,,,, -416300,3.1942797,2.1291256,,,,,,,,,,,,,, -416400,3.7765958,3.3537993,,,,,,,,,,,,,, -416500,2.9552743,1.8290998,,,,,,,,,,,,,, -416528,,,0.888964831829071,0.4104123115539551,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,183206.3391830921,200517.62397146225,183206.3391830921,17258.430349826813,31.167073965072632,0.0 -416600,3.1958418,2.7072663,,,,,,,,,,,,,, -416700,3.4162269,1.3977022,,,,,,,,,,,,,, -416800,3.1592102,1.1135529,,,,,,,,,,,,,, -416900,2.9858334,1.1454442,,,,,,,,,,,,,, -417000,3.63326,3.0097508,,,,,,,,,,,,,, -417100,3.1844432,1.4801985,,,,,,,,,,,,,, -417200,3.1640155,1.3732334,,,,,,,,,,,,,, -417300,2.9048104,1.8536494,,,,,,,,,,,,,, -417400,3.2557025,1.1693153,,,,,,,,,,,,,, -417479,,,0.8873632550239563,0.4178951978683471,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,183626.3146479129,200979.2339589596,183626.3146479129,17299.904410362244,31.27821397781372,0.0 -417500,2.9654942,1.0680214,,,,,,,,,,,,,, -417600,2.8851612,1.6892616,,,,,,,,,,,,,, -417700,3.2139065,1.3222451,,,,,,,,,,,,,, -417800,3.992992,3.3166964,,,,,,,,,,,,,, -417900,3.0881965,2.0508542,,,,,,,,,,,,,, -418000,3.3176992,1.1612353,,,,,,,,,,,,,, -418100,3.1428483,1.072617,,,,,,,,,,,,,, -418200,2.9809306,1.2293024,,,,,,,,,,,,,, -418300,3.06731,1.5638489,,,,,,,,,,,,,, -418400,3.5504808,3.096312,,,,,,,,,,,,,, -418433,,,0.8864062428474426,0.4254805743694305,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,184046.30386471748,201438.4517595768,184046.30386471748,17338.990936517715,31.370931386947632,0.0 -418500,3.5821283,1.0904503,,,,,,,,,,,,,, -418600,3.16852,2.8935132,,,,,,,,,,,,,, -418700,3.3596392,1.9398913,,,,,,,,,,,,,, -418800,2.971176,2.6755679,,,,,,,,,,,,,, -418900,3.0241156,1.6184546,,,,,,,,,,,,,, -419000,2.9557397,1.492803,,,,,,,,,,,,,, -419100,3.0074308,1.1566267,,,,,,,,,,,,,, -419200,3.0977006,2.7369611,,,,,,,,,,,,,, -419300,2.8813298,1.1353691,,,,,,,,,,,,,, -419385,,,0.8881250023841858,0.4159334599971771,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,184466.1449456215,201907.9888682365,184466.1449456215,17388.526458263397,31.48215794563293,0.0 -419400,3.551154,2.9383388,,,,,,,,,,,,,, -419500,3.0323024,1.692867,,,,,,,,,,,,,, -419600,3.0213783,2.4009695,,,,,,,,,,,,,, -419700,3.0329046,1.0337021,,,,,,,,,,,,,, -419800,3.2928355,1.9663838,,,,,,,,,,,,,, -419900,3.2511413,2.104643,,,,,,,,,,,,,, -420000,3.1693482,1.600241,,,,,,,,,,,,,, -420100,3.1962163,2.6194751,,,,,,,,,,,,,, -420200,3.3002608,1.8585742,,,,,,,,,,,,,, -420300,3.140478,2.483365,,,,,,,,,,,,,, -420341,,,0.8867773413658142,0.423092246055603,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,184885.83039593697,202366.15796732905,184885.83039593697,17426.62258195877,31.82004451751709,0.0 -420400,2.9688542,1.016749,,,,,,,,,,,,,, -420500,3.1819546,1.8121349,,,,,,,,,,,,,, -420600,3.0618563,1.9980035,,,,,,,,,,,,,, -420700,3.0317674,1.618199,,,,,,,,,,,,,, -420800,3.1972802,1.2821207,,,,,,,,,,,,,, -420900,3.5923452,1.8214241,,,,,,,,,,,,,, -421000,3.0245001,1.179646,,,,,,,,,,,,,, -421100,3.703391,1.1201165,,,,,,,,,,,,,, -421200,3.6585953,1.0827441,,,,,,,,,,,,,, -421294,,,0.8882421851158142,0.4161365032196045,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,185305.8543047905,202829.32935857773,185305.8543047905,17469.606053113937,31.93477702140808,0.0 -421300,3.8271785,1.1466894,,,,,,,,,,,,,, -421400,2.98986,1.8038166,,,,,,,,,,,,,, -421500,3.0729434,2.7384267,,,,,,,,,,,,,, -421600,3.4542506,2.2600315,,,,,,,,,,,,,, -421700,3.1213658,1.1282074,,,,,,,,,,,,,, -421800,2.89498,1.2524067,,,,,,,,,,,,,, -421900,2.9826763,2.298256,,,,,,,,,,,,,, -422000,4.1084456,3.18941,,,,,,,,,,,,,, -422100,3.0874362,1.1218169,,,,,,,,,,,,,, -422200,3.457324,2.4810665,,,,,,,,,,,,,, -422250,,,0.888476550579071,0.4120833575725555,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,185725.99530100825,203288.01004314423,185725.99530100825,17508.005979537964,32.02518367767334,0.0 -422300,3.1121945,1.1453911,,,,,,,,,,,,,, -422400,3.2742686,1.0745224,,,,,,,,,,,,,, -422500,4.299244,3.1729703,,,,,,,,,,,,,, -422600,3.8863084,3.1594493,,,,,,,,,,,,,, -422700,3.0043788,2.6124055,,,,,,,,,,,,,, -422800,3.4311957,1.077332,,,,,,,,,,,,,, -422900,3.3102288,2.5567076,,,,,,,,,,,,,, -423000,3.2708042,2.3658223,,,,,,,,,,,,,, -423100,3.1367126,1.1363827,,,,,,,,,,,,,, -423200,3.1882885,1.6342639,,,,,,,,,,,,,, -423203,,,0.8883398175239563,0.414626270532608,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,186145.98817968369,203748.846982956,186145.98817968369,17548.692868709564,32.13347125053406,0.0 -423300,3.526235,2.9347248,,,,,,,,,,,,,, -423400,3.108139,1.0759538,,,,,,,,,,,,,, -423500,3.1328378,1.1752062,,,,,,,,,,,,,, -423600,3.278899,1.4899778,,,,,,,,,,,,,, -423700,2.9760277,2.0724523,,,,,,,,,,,,,, -423800,2.862836,1.6219946,,,,,,,,,,,,,, -423900,3.1006303,1.2774504,,,,,,,,,,,,,, -424000,3.1885116,1.7964596,,,,,,,,,,,,,, -424100,3.2979574,1.1296902,,,,,,,,,,,,,, -424158,,,0.8863085508346558,0.421446144580841,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,186565.88954353333,204213.6493074894,186565.88954353333,17593.436250925064,32.24198937416077,0.0 -424200,2.976734,1.3225954,,,,,,,,,,,,,, -424300,3.3147678,1.0860996,,,,,,,,,,,,,, -424400,3.3445413,1.1998767,,,,,,,,,,,,,, -424500,3.13496,1.1335835,,,,,,,,,,,,,, -424600,3.1458983,2.5207634,,,,,,,,,,,,,, -424700,2.7551723,1.4436812,,,,,,,,,,,,,, -424800,3.1923172,1.0388566,,,,,,,,,,,,,, -424900,3.05656,1.6447918,,,,,,,,,,,,,, -425000,3.3121305,1.3403727,,,,,,,,,,,,,, -425100,3.0813766,1.1079788,,,,,,,,,,,,,, -425112,,,0.8889062404632568,0.416959673166275,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,186986.0024909973,204671.20263504985,186986.0024909973,17630.724541187286,32.34537315368652,0.0 -425200,3.4431076,1.1400315,,,,,,,,,,,,,, -425300,3.1803825,1.1638805,,,,,,,,,,,,,, -425400,3.0992131,1.7479339,,,,,,,,,,,,,, -425500,3.6564245,1.1033998,,,,,,,,,,,,,, -425600,3.183756,2.1944299,,,,,,,,,,,,,, -425700,3.0080585,1.2064799,,,,,,,,,,,,,, -425800,3.155225,1.1249398,,,,,,,,,,,,,, -425900,2.9554641,1.121768,,,,,,,,,,,,,, -426000,3.2774673,1.1818132,,,,,,,,,,,,,, -426067,,,0.888476550579071,0.4230095446109772,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,187406.1336224079,205137.1404938697,187406.1336224079,17676.37200140953,32.45468544960022,0.0 -426100,3.02655,2.0693746,,,,,,,,,,,,,, -426200,2.88614,1.6930305,,,,,,,,,,,,,, -426300,3.2900105,1.5627381,,,,,,,,,,,,,, -426400,2.946846,1.8835905,,,,,,,,,,,,,, -426500,3.400645,2.6532493,,,,,,,,,,,,,, -426600,2.9718614,2.0191047,,,,,,,,,,,,,, -426700,3.24352,1.1816695,,,,,,,,,,,,,, -426800,3.3310468,2.6945279,,,,,,,,,,,,,, -426900,2.9632373,1.5133564,,,,,,,,,,,,,, -427000,3.0609303,1.1329983,,,,,,,,,,,,,, -427022,,,0.8894140720367432,0.4090057015419006,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,187826.3649520874,205601.6198780537,187826.3649520874,17720.457312583923,32.5679783821106,0.0 -427100,3.2254407,1.2079306,,,,,,,,,,,,,, -427200,3.082797,2.4201722,,,,,,,,,,,,,, -427300,4.3547597,2.984988,,,,,,,,,,,,,, -427400,3.162564,1.134508,,,,,,,,,,,,,, -427500,3.0507364,1.215362,,,,,,,,,,,,,, -427600,3.0585115,1.2594998,,,,,,,,,,,,,, -427700,2.95616,1.1107143,,,,,,,,,,,,,, -427800,2.9050262,1.6178203,,,,,,,,,,,,,, -427900,3.3561008,1.1689019,,,,,,,,,,,,,, -427979,,,0.8889452815055847,0.4116565585136413,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,188246.64551234245,206069.3907442093,188246.64551234245,17767.809679031372,32.65678024291992,0.0 -428000,3.1565595,1.3474749,,,,,,,,,,,,,, -428100,3.536057,2.4768207,,,,,,,,,,,,,, -428200,4.0032835,3.4078412,,,,,,,,,,,,,, -428300,3.3999476,1.6900188,,,,,,,,,,,,,, -428400,3.2519832,2.6111848,,,,,,,,,,,,,, -428500,2.9698467,1.1420038,,,,,,,,,,,,,, -428600,3.7989562,1.8526844,,,,,,,,,,,,,, -428700,3.0458734,1.8802421,,,,,,,,,,,,,, -428800,3.5952392,1.4296486,,,,,,,,,,,,,, -428900,3.0697355,1.8251987,,,,,,,,,,,,,, -428934,,,0.8880664110183716,0.4169478118419647,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,188666.6430413723,206530.98461413383,188666.6430413723,17809.26633501053,32.74723792076111,0.0 -429000,3.0882747,1.083167,,,,,,,,,,,,,, -429100,3.0398042,1.1428307,,,,,,,,,,,,,, -429200,2.9337254,1.5262077,,,,,,,,,,,,,, -429300,3.0560112,1.5349922,,,,,,,,,,,,,, -429400,2.867999,1.4114722,,,,,,,,,,,,,, -429500,3.0105858,1.2868905,,,,,,,,,,,,,, -429600,2.9609573,1.1209111,,,,,,,,,,,,,, -429700,3.199513,1.2557687,,,,,,,,,,,,,, -429800,4.653924,3.1432924,,,,,,,,,,,,,, -429888,,,0.8879101276397705,0.4141132533550262,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,189086.55633735657,206995.16974568367,189086.55633735657,17853.379207134247,32.857391119003296,0.0 -429900,3.3137226,1.1415557,,,,,,,,,,,,,, -430000,2.9231396,1.6934478,,,,,,,,,,,,,, -430100,3.147288,1.170105,,,,,,,,,,,,,, -430200,3.3115356,1.1832006,,,,,,,,,,,,,, -430300,3.1047573,2.6569023,,,,,,,,,,,,,, -430400,3.1073961,1.2987208,,,,,,,,,,,,,, -430500,3.002769,1.091127,,,,,,,,,,,,,, -430600,3.08613,1.1797955,,,,,,,,,,,,,, -430700,3.394902,2.960874,,,,,,,,,,,,,, -430800,3.0126786,1.0843372,,,,,,,,,,,,,, -430833,,,0.8874804377555847,0.4158562123775482,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,189506.4773423672,207457.9479892254,189506.4773423672,17896.074108600616,32.970332860946655,0.0 -430900,2.842348,1.130095,,,,,,,,,,,,,, -431000,4.2605534,2.8323245,,,,,,,,,,,,,, -431100,3.1446881,1.1279191,,,,,,,,,,,,,, -431200,3.2298868,1.1405886,,,,,,,,,,,,,, -431300,3.2081761,1.2104344,,,,,,,,,,,,,, -431400,2.9941869,1.154612,,,,,,,,,,,,,, -431500,3.124091,1.1509477,,,,,,,,,,,,,, -431600,3.228884,1.4599599,,,,,,,,,,,,,, -431700,3.1281736,2.0423508,,,,,,,,,,,,,, -431786,,,0.8893945217132568,0.4129515886306762,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,189926.61285805705,207918.0655157566,189926.61285805705,17935.916864156723,33.06070804595947,0.0 -431800,3.2327611,2.3197484,,,,,,,,,,,,,, -431900,3.2608554,1.1500088,,,,,,,,,,,,,, -432000,3.1431437,1.0477747,,,,,,,,,,,,,, -432100,3.2880943,1.6076577,,,,,,,,,,,,,, -432200,3.1367383,1.1155486,,,,,,,,,,,,,, -432300,3.0317557,1.6378661,,,,,,,,,,,,,, -432400,3.109502,1.2672334,,,,,,,,,,,,,, -432500,3.2536104,1.5846065,,,,,,,,,,,,,, -432600,3.283468,1.1175696,,,,,,,,,,,,,, -432700,3.7859662,3.3810852,,,,,,,,,,,,,, -432730,,,0.8885937333106995,0.4198548793792724,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,190346.702862978,208378.6335697174,190346.702862978,17976.231738567352,33.17487382888794,0.0 -432800,2.780401,1.977798,,,,,,,,,,,,,, -432900,3.0666747,1.1423943,,,,,,,,,,,,,, -433000,3.1527767,1.1958227,,,,,,,,,,,,,, -433100,3.7391617,1.1260288,,,,,,,,,,,,,, -433200,3.2864141,1.1761404,,,,,,,,,,,,,, -433300,3.642892,3.1905086,,,,,,,,,,,,,, -433400,3.288359,1.1791185,,,,,,,,,,,,,, -433500,3.1651447,1.1522264,,,,,,,,,,,,,, -433600,3.2539277,1.0376127,,,,,,,,,,,,,, -433681,,,0.8869335651397705,0.418942928314209,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,190766.5861680508,208844.6084537506,190766.5861680508,18022.161470651627,33.28832411766052,0.0 -433700,3.5720592,3.137795,,,,,,,,,,,,,, -433800,3.1315303,2.1496525,,,,,,,,,,,,,, -433900,2.9207373,1.718406,,,,,,,,,,,,,, -434000,2.8875704,1.8301817,,,,,,,,,,,,,, -434100,3.381891,1.0778967,,,,,,,,,,,,,, -434200,3.179655,1.1634071,,,,,,,,,,,,,, -434300,3.0984447,1.0908649,,,,,,,,,,,,,, -434400,3.3738015,1.215316,,,,,,,,,,,,,, -434500,3.0373268,1.9858212,,,,,,,,,,,,,, -434600,3.4245656,3.1233802,,,,,,,,,,,,,, -434636,,,0.8886327743530273,0.4150497019290924,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,191186.5708837509,209303.67406725883,191186.5708837509,18061.090587615967,33.39108967781067,0.0 -434700,3.6296084,3.1304705,,,,,,,,,,,,,, -434800,4.1946316,3.198624,,,,,,,,,,,,,, -434900,3.5781791,3.0135322,,,,,,,,,,,,,, -435000,3.0447235,1.413461,,,,,,,,,,,,,, -435100,3.3236365,1.0596776,,,,,,,,,,,,,, -435200,3.1940908,1.1539135,,,,,,,,,,,,,, -435300,2.9643943,1.0860136,,,,,,,,,,,,,, -435400,4.1070504,3.0359178,,,,,,,,,,,,,, -435500,3.3180027,2.8536787,,,,,,,,,,,,,, -435590,,,0.8875195384025574,0.4184292554855346,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,191606.7920079232,209766.02551412585,191606.7920079232,18103.057653665543,33.505431175231934,0.0 -435600,2.9970622,1.0404577,,,,,,,,,,,,,, -435700,3.4023101,1.1898216,,,,,,,,,,,,,, -435800,3.309255,2.5037923,,,,,,,,,,,,,, -435900,3.1036177,1.2939568,,,,,,,,,,,,,, -436000,3.045924,1.2114303,,,,,,,,,,,,,, -436100,3.0730906,1.1086783,,,,,,,,,,,,,, -436200,3.4123187,2.9436207,,,,,,,,,,,,,, -436300,3.732627,1.2091686,,,,,,,,,,,,,, -436400,3.9728348,3.2594702,,,,,,,,,,,,,, -436500,3.4165318,1.6656394,,,,,,,,,,,,,, -436546,,,0.8882812261581421,0.4159662425518036,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,192026.9568374157,210227.8538463116,192026.9568374157,18144.561408996586,33.61554431915283,0.0 -436600,3.1008303,1.1368922,,,,,,,,,,,,,, -436700,2.9177651,1.571726,,,,,,,,,,,,,, -436800,3.9853969,3.320723,,,,,,,,,,,,,, -436900,3.6785803,2.7940726,,,,,,,,,,,,,, -437000,3.2035348,1.146671,,,,,,,,,,,,,, -437100,3.3539515,1.2016362,,,,,,,,,,,,,, -437200,3.3122659,1.1125896,,,,,,,,,,,,,, -437300,3.3534033,1.151705,,,,,,,,,,,,,, -437400,3.2243488,1.8204916,,,,,,,,,,,,,, -437500,,,0.8877148032188416,0.4200031161308288,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,192447.2765059471,210690.33194756508,192447.2765059471,18186.55794620514,33.72833752632141,0.0 -437500,2.98775,1.0928473,,,,,,,,,,,,,, -437600,2.996091,1.5890102,,,,,,,,,,,,,, -437700,3.2811284,1.6080883,,,,,,,,,,,,,, -437800,2.9014082,1.1435933,,,,,,,,,,,,,, -437900,2.9477334,1.7772623,,,,,,,,,,,,,, -438000,3.227636,1.0779388,,,,,,,,,,,,,, -438100,3.0338907,1.6387186,,,,,,,,,,,,,, -438200,3.430952,1.1386808,,,,,,,,,,,,,, -438300,3.5215862,3.229173,,,,,,,,,,,,,, -438400,2.8884003,1.0985008,,,,,,,,,,,,,, -438455,,,0.88734370470047,0.4195761084556579,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,192867.13353562355,211149.1098475456,192867.13353562355,18225.270318984985,33.887531042099,0.0 -438500,3.2855222,1.5219452,,,,,,,,,,,,,, -438600,3.0489995,1.3164233,,,,,,,,,,,,,, -438700,3.2027097,1.7230048,,,,,,,,,,,,,, -438800,3.3534732,1.8951465,,,,,,,,,,,,,, -438900,3.1426423,1.1898994,,,,,,,,,,,,,, -439000,3.4244015,2.8888497,,,,,,,,,,,,,, -439100,4.3112817,3.0943756,,,,,,,,,,,,,, -439200,3.2676497,1.2856895,,,,,,,,,,,,,, -439300,3.0376484,1.1432593,,,,,,,,,,,,,, -439400,2.9862175,1.1501869,,,,,,,,,,,,,, -439407,,,0.88832026720047,0.4196644127368927,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,193287.17481279373,211611.44933223724,193287.17481279373,18267.40621161461,34.00122618675232,0.0 -439500,3.0814383,1.1222644,,,,,,,,,,,,,, -439600,3.3846748,2.4628737,,,,,,,,,,,,,, -439700,3.3474445,2.6048806,,,,,,,,,,,,,, -439800,3.107474,1.6792717,,,,,,,,,,,,,, -439900,2.9233882,1.7871681,,,,,,,,,,,,,, -440000,3.3621538,2.3692205,,,,,,,,,,,,,, -440100,3.2974198,1.100921,,,,,,,,,,,,,, -440200,2.9328353,1.1405545,,,,,,,,,,,,,, -440300,3.25682,1.0121976,,,,,,,,,,,,,, -440359,,,0.88671875,0.4149238169193268,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,193707.19133090973,212072.5546195507,193707.19133090973,18308.33525133133,34.11226868629456,0.0 -440400,3.2019663,1.1815546,,,,,,,,,,,,,, -440500,5.4734325,1.1046818,,,,,,,,,,,,,, -440600,3.3366086,2.3190432,,,,,,,,,,,,,, -440700,3.1077297,2.7324896,,,,,,,,,,,,,, -440800,2.9352002,2.1436272,,,,,,,,,,,,,, -440900,3.0811784,2.1261168,,,,,,,,,,,,,, -441000,3.1774585,1.0713531,,,,,,,,,,,,,, -441100,3.2178164,1.1470333,,,,,,,,,,,,,, -441200,2.879457,1.0778269,,,,,,,,,,,,,, -441300,3.2046752,1.5587944,,,,,,,,,,,,,, -441318,,,0.8875781297683716,0.4181137084960937,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,194127.15976166725,212539.07664871216,194127.15976166725,18354.741428136826,34.21055889129639,0.0 -441400,2.92461,1.0045259,,,,,,,,,,,,,, -441500,2.9609587,1.2304025,,,,,,,,,,,,,, -441600,2.9850218,1.136983,,,,,,,,,,,,,, -441700,3.0921192,1.5312415,,,,,,,,,,,,,, -441800,3.019494,1.1838764,,,,,,,,,,,,,, -441900,3.2183764,1.0770811,,,,,,,,,,,,,, -442000,2.997971,1.1833453,,,,,,,,,,,,,, -442100,3.465458,1.1055555,,,,,,,,,,,,,, -442200,2.9608757,1.6821657,,,,,,,,,,,,,, -442273,,,0.8886913657188416,0.4170783460140228,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,194547.1519122124,213004.2418987751,194547.1519122124,18399.771948337555,34.30404305458069,0.0 -442300,3.2121181,0.98675114,,,,,,,,,,,,,, -442400,3.2973483,1.2482097,,,,,,,,,,,,,, -442500,3.2499435,1.7501032,,,,,,,,,,,,,, -442600,3.2207708,1.1760625,,,,,,,,,,,,,, -442700,3.2438195,1.140877,,,,,,,,,,,,,, -442800,3.1142087,1.0495079,,,,,,,,,,,,,, -442900,4.0718727,3.2595553,,,,,,,,,,,,,, -443000,3.3832886,1.0982178,,,,,,,,,,,,,, -443100,3.1086519,1.1982787,,,,,,,,,,,,,, -443200,3.3528364,1.2217926,,,,,,,,,,,,,, -443227,,,0.8878515362739563,0.4187040328979492,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,194967.35440659523,213464.63312387464,194967.35440659523,18439.80325841904,34.41198301315308,0.0 -443300,4.0576024,3.1930287,,,,,,,,,,,,,, -443400,3.0333228,1.0768836,,,,,,,,,,,,,, -443500,2.6941476,1.845531,,,,,,,,,,,,,, -443600,2.8306773,1.7584329,,,,,,,,,,,,,, -443700,3.3535414,1.2980711,,,,,,,,,,,,,, -443800,3.0688758,1.2806773,,,,,,,,,,,,,, -443900,3.117377,1.1198173,,,,,,,,,,,,,, -444000,3.4691973,2.758214,,,,,,,,,,,,,, -444100,2.9987435,1.8163633,,,,,,,,,,,,,, -444170,,,0.8865429759025574,0.4236267507076263,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,195387.2975564003,213931.5210936069,195387.2975564003,18486.59080529213,34.52042746543884,0.0 -444200,3.2294202,1.1416523,,,,,,,,,,,,,, -444300,3.3870044,1.2205081,,,,,,,,,,,,,, -444400,3.181643,1.2089777,,,,,,,,,,,,,, -444500,3.0943072,1.0385875,,,,,,,,,,,,,, -444600,3.0517495,2.021666,,,,,,,,,,,,,, -444700,3.3674264,2.7225113,,,,,,,,,,,,,, -444800,2.973558,1.1372033,,,,,,,,,,,,,, -444900,3.3476548,1.1371076,,,,,,,,,,,,,, -445000,2.9794097,2.1816187,,,,,,,,,,,,,, -445100,3.0062728,1.1201061,,,,,,,,,,,,,, -445126,,,0.8879492282867432,0.4153803884983063,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,195807.40478396416,214392.5746004581,195807.40478396416,18527.39293718338,34.61554837226868,0.0 -445200,3.140568,2.7581854,,,,,,,,,,,,,, -445300,3.6691458,2.9355714,,,,,,,,,,,,,, -445400,2.9550319,1.1197886,,,,,,,,,,,,,, -445500,3.168858,1.0380005,,,,,,,,,,,,,, -445600,3.0347188,1.5840036,,,,,,,,,,,,,, -445700,3.1869888,1.1163465,,,,,,,,,,,,,, -445800,2.9744995,1.2723873,,,,,,,,,,,,,, -445900,3.5572426,1.0561452,,,,,,,,,,,,,, -446000,3.172727,1.3531939,,,,,,,,,,,,,, -446075,,,0.8886132836341858,0.4135034084320068,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,196227.24235582352,214855.56814837456,196227.24235582352,18570.385375261307,34.730725049972534,0.0 -446100,3.228879,1.225471,,,,,,,,,,,,,, -446200,3.1988337,2.5943115,,,,,,,,,,,,,, -446300,3.4784615,1.2180322,,,,,,,,,,,,,, -446400,3.1598227,1.5192149,,,,,,,,,,,,,, -446500,2.824277,1.460084,,,,,,,,,,,,,, -446600,3.037324,1.1690701,,,,,,,,,,,,,, -446700,2.8346736,1.358016,,,,,,,,,,,,,, -446800,2.9975126,1.231098,,,,,,,,,,,,,, -446900,3.4589837,1.1708518,,,,,,,,,,,,,, -447000,3.246183,1.1987002,,,,,,,,,,,,,, -447008,,,0.8878905773162842,0.4189607799053192,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,196647.45865631104,215321.72642993927,196647.45865631104,18616.15464353561,34.85476279258728,0.0 -447100,2.9648077,1.1456668,,,,,,,,,,,,,, -447200,3.5579765,3.053683,,,,,,,,,,,,,, -447300,3.239972,2.4877954,,,,,,,,,,,,,, -447400,3.1289985,1.1350518,,,,,,,,,,,,,, -447500,3.8612711,3.225936,,,,,,,,,,,,,, -447600,2.9799175,1.0960003,,,,,,,,,,,,,, -447700,3.2178738,1.1583109,,,,,,,,,,,,,, -447800,3.6262455,3.214461,,,,,,,,,,,,,, -447900,3.5282986,3.0333226,,,,,,,,,,,,,, -447962,,,0.8869726657867432,0.4157951772212982,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,197067.3557920456,215782.27796936035,197067.3557920456,18656.66638660431,34.947226762771606,0.0 -448000,3.260963,1.0289555,,,,,,,,,,,,,, -448100,3.1916497,1.0776365,,,,,,,,,,,,,, -448200,3.2798915,1.5195993,,,,,,,,,,,,,, -448300,3.4696457,1.1243948,,,,,,,,,,,,,, -448400,3.56131,1.1145214,,,,,,,,,,,,,, -448500,3.8293915,1.5540372,,,,,,,,,,,,,, -448600,2.9590182,1.1486981,,,,,,,,,,,,,, -448700,3.1802604,1.2171808,,,,,,,,,,,,,, -448800,3.0147552,1.070065,,,,,,,,,,,,,, -448900,2.9197373,1.0786816,,,,,,,,,,,,,, -448911,,,0.8879492282867432,0.4203372597694397,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,197487.6512079239,216252.2682375908,197487.6512079239,18706.196996450424,35.063289403915405,0.0 -449000,3.1536608,1.3400521,,,,,,,,,,,,,, -449100,3.622925,3.1589193,,,,,,,,,,,,,, -449200,3.1398604,2.2342174,,,,,,,,,,,,,, -449300,3.5116048,3.010326,,,,,,,,,,,,,, -449400,3.477306,3.2547204,,,,,,,,,,,,,, -449500,2.9361677,2.1129277,,,,,,,,,,,,,, -449600,3.1320775,1.189397,,,,,,,,,,,,,, -449700,3.3143191,1.3438776,,,,,,,,,,,,,, -449800,3.1320248,1.1645862,,,,,,,,,,,,,, -449865,,,0.8902929425239563,0.4134377539157867,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,197907.5611524582,216719.85712742803,197907.5611524582,18753.732135295868,35.15776562690735,0.0 -449900,3.2000585,1.271491,,,,,,,,,,,,,, -450000,2.8824096,1.2607481,,,,,,,,,,,,,, -450100,3.1283703,1.1193101,,,,,,,,,,,,,, -450200,3.1627991,1.0888748,,,,,,,,,,,,,, -450300,3.0875366,1.5279312,,,,,,,,,,,,,, -450400,3.2566354,1.1836697,,,,,,,,,,,,,, -450500,3.195148,1.1378844,,,,,,,,,,,,,, -450600,3.4509077,2.9793742,,,,,,,,,,,,,, -450700,3.6459491,3.2629218,,,,,,,,,,,,,, -450800,3.1819122,1.2426931,,,,,,,,,,,,,, -450819,,,0.8892773389816284,0.4141132533550262,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,198327.7874581813,217189.5351667404,198327.7874581813,18803.01878094673,35.27369546890259,0.0 -450900,3.123627,1.3553922,,,,,,,,,,,,,, -451000,2.8267171,1.0973966,,,,,,,,,,,,,, -451100,2.9861152,1.2743598,,,,,,,,,,,,,, -451200,3.0430145,2.2837439,,,,,,,,,,,,,, -451300,3.8192906,3.0570354,,,,,,,,,,,,,, -451400,3.015846,1.0918896,,,,,,,,,,,,,, -451500,3.3881218,1.1008389,,,,,,,,,,,,,, -451600,2.9717224,1.6662469,,,,,,,,,,,,,, -451700,3.1299646,1.1678112,,,,,,,,,,,,,, -451771,,,0.8887499570846558,0.4107620418071747,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,198747.9303052425,217650.66642785072,198747.9303052425,18843.86488962173,35.36725568771362,0.0 -451800,2.906244,1.6113672,,,,,,,,,,,,,, -451900,3.877584,3.331596,,,,,,,,,,,,,, -452000,3.0991495,2.037433,,,,,,,,,,,,,, -452100,3.1436818,2.1480036,,,,,,,,,,,,,, -452200,4.109047,3.170939,,,,,,,,,,,,,, -452300,3.1071393,1.4914411,,,,,,,,,,,,,, -452400,2.9350562,1.0402713,,,,,,,,,,,,,, -452500,3.2423675,2.898368,,,,,,,,,,,,,, -452600,3.9199686,3.309691,,,,,,,,,,,,,, -452700,4.971704,1.1702513,,,,,,,,,,,,,, -452726,,,0.8876367211341858,0.4182771444320678,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,199168.05628609657,218110.0209414959,199168.05628609657,18882.94619178772,35.465445041656494,0.0 -452800,3.0737026,1.7301903,,,,,,,,,,,,,, -452900,3.8676417,3.258621,,,,,,,,,,,,,, -453000,3.0266736,1.1663283,,,,,,,,,,,,,, -453100,3.8722584,3.2470543,,,,,,,,,,,,,, -453200,3.3173783,2.6865308,,,,,,,,,,,,,, -453300,2.925275,1.0063344,,,,,,,,,,,,,, -453400,3.5200782,3.0466945,,,,,,,,,,,,,, -453500,3.487388,1.1638184,,,,,,,,,,,,,, -453600,3.5954666,1.0557659,,,,,,,,,,,,,, -453667,,,0.8885546922683716,0.4138567149639129,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,199588.1751565933,218571.8494119644,199588.1751565933,18924.48981571197,35.583218812942505,0.0 -453700,3.1960087,2.633799,,,,,,,,,,,,,, -453800,3.1145499,1.4814701,,,,,,,,,,,,,, -453900,3.0227444,1.0671512,,,,,,,,,,,,,, -454000,3.3914032,3.0611033,,,,,,,,,,,,,, -454100,2.927805,1.1450087,,,,,,,,,,,,,, -454200,3.0077302,1.1657395,,,,,,,,,,,,,, -454300,3.2141652,1.3781468,,,,,,,,,,,,,, -454400,3.4101012,1.1491631,,,,,,,,,,,,,, -454500,2.817007,1.1576833,,,,,,,,,,,,,, -454600,2.7684593,1.0560735,,,,,,,,,,,,,, -454608,,,0.88636714220047,0.4170846343040466,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,200008.38172197345,219042.3599162101,200008.38172197345,18974.62998723984,35.69882941246033,0.0 -454700,3.5732443,3.035974,,,,,,,,,,,,,, -454800,3.291637,2.644298,,,,,,,,,,,,,, -454900,3.0964508,1.207335,,,,,,,,,,,,,, -455000,3.1892018,2.361897,,,,,,,,,,,,,, -455100,3.2115996,1.0632727,,,,,,,,,,,,,, -455200,3.8652573,3.2258291,,,,,,,,,,,,,, -455300,3.3577478,3.000539,,,,,,,,,,,,,, -455400,3.434112,2.1962652,,,,,,,,,,,,,, -455500,3.54115,3.0680466,,,,,,,,,,,,,, -455562,,,0.88978511095047,0.4132919609546661,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,200428.2955994606,219501.9488492012,200428.2955994606,19014.16286468506,35.791934967041016,0.0 -455600,3.2912464,1.1300962,,,,,,,,,,,,,, -455700,3.0892057,1.1024172,,,,,,,,,,,,,, -455800,3.7638512,3.1470146,,,,,,,,,,,,,, -455900,3.2199824,1.3628707,,,,,,,,,,,,,, -456000,2.9450674,2.3254848,,,,,,,,,,,,,, -456100,3.2139401,1.251206,,,,,,,,,,,,,, -456200,2.88982,1.0324712,,,,,,,,,,,,,, -456300,4.011862,3.0403523,,,,,,,,,,,,,, -456400,2.9194558,1.058384,,,,,,,,,,,,,, -456500,3.0112493,1.0413892,,,,,,,,,,,,,, -456515,,,0.8884179592132568,0.4171240925788879,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,200848.2231376171,219961.8817877769,200848.2231376171,19054.00243878365,35.90884757041931,0.0 -456600,3.065257,1.0987664,,,,,,,,,,,,,, -456700,4.2128167,3.2320836,,,,,,,,,,,,,, -456800,3.3924928,1.1130959,,,,,,,,,,,,,, -456900,3.0819423,1.1023833,,,,,,,,,,,,,, -457000,3.3239524,1.1251154,,,,,,,,,,,,,, -457100,3.5511625,3.07306,,,,,,,,,,,,,, -457200,2.853437,1.0154182,,,,,,,,,,,,,, -457300,3.8686786,3.133854,,,,,,,,,,,,,, -457400,3.3401442,2.8677838,,,,,,,,,,,,,, -457457,,,0.8882616758346558,0.41612708568573,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,201268.2205114365,220424.9903256893,201268.2205114365,19096.94891095161,36.02534198760986,0.0 -457500,3.0486455,1.0401562,,,,,,,,,,,,,, -457600,3.0834897,1.1230161,,,,,,,,,,,,,, -457700,3.4583716,3.192037,,,,,,,,,,,,,, -457800,3.2078338,1.0844876,,,,,,,,,,,,,, -457900,2.9807706,1.1563945,,,,,,,,,,,,,, -458000,4.091015,1.2426188,,,,,,,,,,,,,, -458100,4.001153,3.1221359,,,,,,,,,,,,,, -458200,3.1942565,1.1029131,,,,,,,,,,,,,, -458300,2.8934774,1.2246875,,,,,,,,,,,,,, -458400,2.9934764,1.5632399,,,,,,,,,,,,,, -458408,,,0.8860155940055847,0.4237580001354217,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,201688.31512522697,220888.446969986,201688.31512522697,19140.144728183743,36.14205241203308,0.0 -458500,3.2729626,2.0662637,,,,,,,,,,,,,, -458600,3.2846198,1.0323541,,,,,,,,,,,,,, -458700,3.1230967,1.4428188,,,,,,,,,,,,,, -458800,3.0217354,1.7507634,,,,,,,,,,,,,, -458900,3.0785604,1.1807094,,,,,,,,,,,,,, -459000,3.2258275,1.262985,,,,,,,,,,,,,, -459100,3.3702836,1.2061608,,,,,,,,,,,,,, -459200,2.9849946,1.4473525,,,,,,,,,,,,,, -459300,2.970936,1.1017257,,,,,,,,,,,,,, -459363,,,0.88734370470047,0.4194290935993194,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,202108.22252106667,221348.96659827232,202108.22252106667,19180.52417993545,36.32553815841675,0.0 -459400,3.1171846,1.1975204,,,,,,,,,,,,,, -459500,3.171628,2.8193538,,,,,,,,,,,,,, -459600,3.0988967,1.215212,,,,,,,,,,,,,, -459700,3.5574868,2.952112,,,,,,,,,,,,,, -459800,3.772525,3.2421718,,,,,,,,,,,,,, -459900,3.1042678,2.5083084,,,,,,,,,,,,,, -460000,3.4412448,2.9164944,,,,,,,,,,,,,, -460100,3.1979346,1.1024377,,,,,,,,,,,,,, -460200,3.0555382,1.3702227,,,,,,,,,,,,,, -460300,3.3397927,1.3498398,,,,,,,,,,,,,, -460316,,,0.890429675579071,0.4093267023563385,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,202528.4370057583,221813.6135840416,202528.4370057583,19224.78936481476,36.44341278076172,0.0 -460400,3.459875,1.2812114,,,,,,,,,,,,,, -460500,3.0475717,1.1256675,,,,,,,,,,,,,, -460600,4.9690075,3.11348,,,,,,,,,,,,,, -460700,3.9188058,1.1345718,,,,,,,,,,,,,, -460800,3.3491447,1.1330861,,,,,,,,,,,,,, -460900,3.6062865,2.7604074,,,,,,,,,,,,,, -461000,3.174707,1.449873,,,,,,,,,,,,,, -461100,3.1983867,1.1189482,,,,,,,,,,,,,, -461200,3.2712617,2.6084602,,,,,,,,,,,,,, -461268,,,0.8883007764816284,0.4184071719646454,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,202948.36471748352,222282.7366580963,202948.36471748352,19273.821818590164,36.55770182609558,0.0 -461300,3.8186846,2.5330315,,,,,,,,,,,,,, -461400,3.3503482,1.8114046,,,,,,,,,,,,,, -461500,3.077968,1.2707105,,,,,,,,,,,,,, -461600,3.1102953,1.1477509,,,,,,,,,,,,,, -461700,3.3013737,2.370631,,,,,,,,,,,,,, -461800,3.453158,2.7732852,,,,,,,,,,,,,, -461900,3.7508,3.1808944,,,,,,,,,,,,,, -462000,3.04739,1.5772732,,,,,,,,,,,,,, -462100,3.2853577,1.9775214,,,,,,,,,,,,,, -462200,3.1247325,1.427344,,,,,,,,,,,,,, -462224,,,0.88623046875,0.422183096408844,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,203368.44725489616,222751.3733520508,203368.44725489616,19322.230692386627,36.65404319763184,0.0 -462300,3.0057635,1.197921,,,,,,,,,,,,,, -462400,3.0941434,1.0913472,,,,,,,,,,,,,, -462500,3.0530512,1.1499131,,,,,,,,,,,,,, -462600,3.2430284,1.3279556,,,,,,,,,,,,,, -462700,3.111845,1.1533128,,,,,,,,,,,,,, -462800,3.1190913,1.1344314,,,,,,,,,,,,,, -462900,2.8403542,1.3189691,,,,,,,,,,,,,, -463000,3.550249,3.2436216,,,,,,,,,,,,,, -463100,3.281767,1.1215156,,,,,,,,,,,,,, -463180,,,0.8882421851158142,0.4180953502655029,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,203788.41516184807,223210.6052725315,203788.41516184807,19361.34489059448,36.750526428222656,0.0 -463200,3.4068294,2.7091627,,,,,,,,,,,,,, -463300,3.045276,1.1259826,,,,,,,,,,,,,, -463400,3.2983782,1.1544499,,,,,,,,,,,,,, -463500,3.2903144,1.146251,,,,,,,,,,,,,, -463600,3.3526247,1.2713602,,,,,,,,,,,,,, -463700,2.920607,1.2404188,,,,,,,,,,,,,, -463800,3.77777,3.0635743,,,,,,,,,,,,,, -463900,3.209788,1.0609967,,,,,,,,,,,,,, -464000,3.0920439,2.1228921,,,,,,,,,,,,,, -464100,3.2103574,1.2041187,,,,,,,,,,,,,, -464120,,,0.887499988079071,0.4132614433765411,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,204207.8207669258,223673.75593829155,204207.8207669258,19404.15842437744,37.63438177108765,0.0 -464200,2.8002193,1.5709214,,,,,,,,,,,,,, -464300,2.935229,1.1405996,,,,,,,,,,,,,, -464400,2.8729887,1.7076557,,,,,,,,,,,,,, -464500,3.250716,1.1396992,,,,,,,,,,,,,, -464600,3.1918025,2.8007421,,,,,,,,,,,,,, -464700,3.6907187,1.2540491,,,,,,,,,,,,,, -464800,2.8011203,1.0390925,,,,,,,,,,,,,, -464900,3.219202,2.4553554,,,,,,,,,,,,,, -465000,3.5395703,1.0579927,,,,,,,,,,,,,, -465070,,,0.8865624666213989,0.4224892854690552,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,204627.69776153564,224144.2284386158,204627.69776153564,19454.58660507202,37.75303912162781,0.0 -465100,3.289236,1.6703472,,,,,,,,,,,,,, -465200,3.4150832,1.2296001,,,,,,,,,,,,,, -465300,3.1290011,2.7800531,,,,,,,,,,,,,, -465400,3.5219414,1.1859797,,,,,,,,,,,,,, -465500,3.350482,2.4467273,,,,,,,,,,,,,, -465600,3.1842375,1.1091107,,,,,,,,,,,,,, -465700,3.2436676,1.1094428,,,,,,,,,,,,,, -465800,3.2289863,1.1374364,,,,,,,,,,,,,, -465900,3.2013626,1.6275616,,,,,,,,,,,,,, -466000,3.280174,1.4489987,,,,,,,,,,,,,, -466025,,,0.8872656226158142,0.420401781797409,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,205047.81872677803,224604.7361364365,205047.81872677803,19494.825929641724,37.85163021087647,0.0 -466100,3.2684534,1.3689075,,,,,,,,,,,,,, -466200,3.0550404,1.1621602,,,,,,,,,,,,,, -466300,3.4686475,2.988941,,,,,,,,,,,,,, -466400,3.2308035,1.1199532,,,,,,,,,,,,,, -466500,3.1630695,1.1814227,,,,,,,,,,,,,, -466600,3.6333222,3.1560178,,,,,,,,,,,,,, -466700,2.894775,1.6483467,,,,,,,,,,,,,, -466800,3.2271066,1.1032073,,,,,,,,,,,,,, -466900,3.0190475,1.0613089,,,,,,,,,,,,,, -466975,,,0.8876757621765137,0.4198971390724182,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,205468.0308253765,225067.5422782898,205468.0308253765,19537.241803646088,37.98124074935913,0.0 -467000,3.075193,2.393894,,,,,,,,,,,,,, -467100,3.0984018,1.2541128,,,,,,,,,,,,,, -467200,3.2434835,1.1121192,,,,,,,,,,,,,, -467300,2.989586,1.0733669,,,,,,,,,,,,,, -467400,3.0848048,1.2017679,,,,,,,,,,,,,, -467500,2.957293,2.0573494,,,,,,,,,,,,,, -467600,3.2424235,1.118156,,,,,,,,,,,,,, -467700,2.8722346,1.5079741,,,,,,,,,,,,,, -467800,3.5201387,3.1328516,,,,,,,,,,,,,, -467900,3.3042579,2.5429554,,,,,,,,,,,,,, -467912,,,0.887988269329071,0.4176282882690429,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,205887.92675709724,225533.91818737984,205887.92675709724,19583.554235219955,38.10013508796692,0.0 -468000,3.9585135,3.0886958,,,,,,,,,,,,,, -468100,3.115053,1.0776244,,,,,,,,,,,,,, -468200,3.0605578,1.5273635,,,,,,,,,,,,,, -468300,3.0042353,1.2835722,,,,,,,,,,,,,, -468400,3.1352925,1.0764085,,,,,,,,,,,,,, -468500,3.477883,2.3614795,,,,,,,,,,,,,, -468600,3.015879,1.0573797,,,,,,,,,,,,,, -468700,3.600539,2.9920979,,,,,,,,,,,,,, -468800,3.2821567,2.2183967,,,,,,,,,,,,,, -468863,,,0.8880664110183716,0.415490984916687,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,206307.7999806404,226005.22905516624,206307.7999806404,19634.826660633087,38.21635723114014,0.0 -468900,3.4313126,3.0542989,,,,,,,,,,,,,, -469000,6.968658,3.2514095,,,,,,,,,,,,,, -469100,3.155918,1.9795603,,,,,,,,,,,,,, -469200,3.0592694,1.0888424,,,,,,,,,,,,,, -469300,3.3487453,2.781035,,,,,,,,,,,,,, -469400,3.0862463,1.0521507,,,,,,,,,,,,,, -469500,4.204188,1.0483465,,,,,,,,,,,,,, -469600,3.1331275,1.5140561,,,,,,,,,,,,,, -469700,3.0148883,1.5287193,,,,,,,,,,,,,, -469800,3.182117,1.1853383,,,,,,,,,,,,,, -469817,,,0.8888671398162842,0.4122638702392578,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,206727.8733928204,226465.93561315536,206727.8733928204,19675.311785697937,38.31474304199219,0.0 -469900,3.0637348,1.0710404,,,,,,,,,,,,,, -470000,3.0865552,1.216666,,,,,,,,,,,,,, -470100,3.395152,1.219712,,,,,,,,,,,,,, -470200,8.374682,2.1532667,,,,,,,,,,,,,, -470300,3.344067,1.10043,,,,,,,,,,,,,, -470400,3.9436913,1.341012,,,,,,,,,,,,,, -470500,3.0970929,2.2089643,,,,,,,,,,,,,, -470600,3.2009435,2.8943977,,,,,,,,,,,,,, -470700,3.1081958,1.0974534,,,,,,,,,,,,,, -470759,,,0.8885351419448853,0.4153727293014526,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,207148.1556749344,226931.6114668846,207148.1556749344,19720.53402042389,38.43796992301941,0.0 -470800,3.014604,1.2248116,,,,,,,,,,,,,, -470900,3.1252284,1.0260551,,,,,,,,,,,,,, -471000,3.1881983,1.1256204,,,,,,,,,,,,,, -471100,3.3816042,3.0125809,,,,,,,,,,,,,, -471200,3.3009145,1.1887786,,,,,,,,,,,,,, -471300,3.1840215,2.4574454,,,,,,,,,,,,,, -471400,4.4642653,3.256998,,,,,,,,,,,,,, -471500,3.0234022,1.3213081,,,,,,,,,,,,,, -471600,2.8429244,1.0283086,,,,,,,,,,,,,, -471700,3.0406303,2.6868975,,,,,,,,,,,,,, -471704,,,0.8871874809265137,0.4193627834320068,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,207568.38267040253,227394.2978703976,207568.38267040253,19762.82291984558,38.55806303024292,0.0 -471800,3.044906,1.4363621,,,,,,,,,,,,,, -471900,3.635907,3.1881971,,,,,,,,,,,,,, -472000,3.1905127,1.0756918,,,,,,,,,,,,,, -472100,3.352192,1.1919264,,,,,,,,,,,,,, -472200,3.3357358,1.0824368,,,,,,,,,,,,,, -472300,2.899137,1.0593195,,,,,,,,,,,,,, -472400,3.2560582,1.2008156,,,,,,,,,,,,,, -472500,2.9341888,1.5594119,,,,,,,,,,,,,, -472600,4.1304026,3.2130904,,,,,,,,,,,,,, -472652,,,0.8884179592132568,0.4182244837284088,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,207988.270154953,227856.02899241447,207988.270154953,19804.49430608749,38.68139624595642,0.0 -472700,3.1241157,1.0035992,,,,,,,,,,,,,, -472800,3.426297,3.0975986,,,,,,,,,,,,,, -472900,2.8813567,1.6976715,,,,,,,,,,,,,, -473000,3.6577315,1.1602664,,,,,,,,,,,,,, -473100,3.129265,1.4744054,,,,,,,,,,,,,, -473200,3.4820685,1.1922064,,,,,,,,,,,,,, -473300,3.2259922,1.1041785,,,,,,,,,,,,,, -473400,3.2636886,1.1623495,,,,,,,,,,,,,, -473500,3.2725832,1.2247736,,,,,,,,,,,,,, -473600,3.885867,1.1712222,,,,,,,,,,,,,, -473606,,,0.8891015648841858,0.4154936075210571,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,208408.5045876503,228315.894469738,208408.5045876503,19843.9731669426,38.7846896648407,0.0 -473700,2.826769,1.0068462,,,,,,,,,,,,,, -473800,3.0663664,2.053419,,,,,,,,,,,,,, -473900,3.2904465,2.676445,,,,,,,,,,,,,, -474000,2.8748226,1.0504321,,,,,,,,,,,,,, -474100,3.211897,1.0954411,,,,,,,,,,,,,, -474200,3.333203,1.3296094,,,,,,,,,,,,,, -474300,3.1625772,2.7775533,,,,,,,,,,,,,, -474400,3.0415006,2.2089405,,,,,,,,,,,,,, -474500,2.8698907,1.3361343,,,,,,,,,,,,,, -474560,,,0.8894921541213989,0.4117230474948883,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,208828.76788640025,228776.6988329888,208828.76788640025,19884.347821950912,38.901485204696655,0.0 -474600,3.3242526,2.265518,,,,,,,,,,,,,, -474700,3.321554,2.3098054,,,,,,,,,,,,,, -474800,4.136687,3.312624,,,,,,,,,,,,,, -474900,3.161874,2.9796565,,,,,,,,,,,,,, -475000,3.0588224,2.0050557,,,,,,,,,,,,,, -475100,3.0482638,1.8738736,,,,,,,,,,,,,, -475200,4.1963053,3.229815,,,,,,,,,,,,,, -475300,3.5543683,2.507194,,,,,,,,,,,,,, -475400,3.056976,1.1892624,,,,,,,,,,,,,, -475500,3.186933,1.1556436,,,,,,,,,,,,,, -475507,,,0.8870507478713989,0.4228179156780243,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,209248.7898492813,229246.5458858013,209248.7898492813,19934.0034570694,39.02295017242432,0.0 -475600,3.137846,2.670979,,,,,,,,,,,,,, -475700,3.3122995,2.9293122,,,,,,,,,,,,,, -475800,3.70684,2.5472941,,,,,,,,,,,,,, -475900,3.1297238,2.625293,,,,,,,,,,,,,, -476000,3.0680017,2.0068078,,,,,,,,,,,,,, -476100,2.981886,2.1009684,,,,,,,,,,,,,, -476200,3.7997553,3.1691432,,,,,,,,,,,,,, -476300,2.8413012,2.0803962,,,,,,,,,,,,,, -476400,2.9612532,1.3675996,,,,,,,,,,,,,, -476464,,,0.8886327743530273,0.4156919717788696,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,209668.7404808998,229705.8835630417,209668.7404808998,19973.2469329834,39.11786460876465,0.0 -476500,5.3685904,3.2915158,,,,,,,,,,,,,, -476600,2.9170208,1.7502265,,,,,,,,,,,,,, -476700,3.0970762,0.9942083,,,,,,,,,,,,,, -476800,3.790413,2.8014538,,,,,,,,,,,,,, -476900,3.1511304,1.180268,,,,,,,,,,,,,, -477000,3.335256,1.1583002,,,,,,,,,,,,,, -477100,3.5236504,1.4825925,,,,,,,,,,,,,, -477200,3.5671554,3.0536826,,,,,,,,,,,,,, -477300,3.4893577,2.2584941,,,,,,,,,,,,,, -477400,3.7205086,2.8597128,,,,,,,,,,,,,, -477411,,,0.8877343535423279,0.4107499718666076,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,210088.61083316803,230172.3604860305,210088.61083316803,20019.682067871094,39.2409257888794,0.0 -477500,3.0903025,2.4051569,,,,,,,,,,,,,, -477600,2.9045074,1.10984,,,,,,,,,,,,,, -477700,3.138497,2.1209304,,,,,,,,,,,,,, -477800,4.0138497,3.163641,,,,,,,,,,,,,, -477900,3.0997488,2.5518737,,,,,,,,,,,,,, -478000,2.9740787,1.0544138,,,,,,,,,,,,,, -478100,2.9945135,1.0872176,,,,,,,,,,,,,, -478200,3.615199,2.2615168,,,,,,,,,,,,,, -478300,2.9968133,1.097283,,,,,,,,,,,,,, -478363,,,0.8869726657867432,0.4187818169593811,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,210508.6423151493,230633.310161829,210508.6423151493,20060.430149316788,39.36173105239868,0.0 -478400,3.138669,1.0622324,,,,,,,,,,,,,, -478500,3.1238265,1.1441451,,,,,,,,,,,,,, -478600,3.2669165,1.1093194,,,,,,,,,,,,,, -478700,3.7361038,2.9390864,,,,,,,,,,,,,, -478800,3.1588025,1.0724628,,,,,,,,,,,,,, -478900,3.4586806,2.0500069,,,,,,,,,,,,,, -479000,2.7315602,1.3541915,,,,,,,,,,,,,, -479100,3.7291274,2.9076238,,,,,,,,,,,,,, -479200,2.9944398,1.5741543,,,,,,,,,,,,,, -479300,3.2386632,2.7804258,,,,,,,,,,,,,, -479305,,,0.8893749713897705,0.4133050739765167,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,210928.602619648,231098.6457479,210928.602619648,20105.6329100132,39.48548531532288,0.0 -479400,2.9310696,1.1180162,,,,,,,,,,,,,, -479500,3.8115861,1.3226902,,,,,,,,,,,,,, -479600,3.119114,2.689081,,,,,,,,,,,,,, -479700,3.3813186,1.1303968,,,,,,,,,,,,,, -479800,3.5470216,1.2131243,,,,,,,,,,,,,, -479900,3.0368595,1.139268,,,,,,,,,,,,,, -480000,3.0987754,1.2118858,,,,,,,,,,,,,, -480100,3.9389782,3.1392016,,,,,,,,,,,,,, -480200,3.0359776,1.0016302,,,,,,,,,,,,,, -480260,,,0.8883788585662842,0.4162740111351013,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,211348.75797367096,231558.53671503067,211348.75797367096,20145.21943593025,39.585007667541504,0.0 -480300,3.6273506,1.1601496,,,,,,,,,,,,,, -480400,3.0150673,1.4940661,,,,,,,,,,,,,, -480500,3.124702,1.1770502,,,,,,,,,,,,,, -480600,2.9508114,1.1806574,,,,,,,,,,,,,, -480700,3.1380217,2.1960113,,,,,,,,,,,,,, -480800,2.9943492,1.5404592,,,,,,,,,,,,,, -480900,3.4628685,1.1999981,,,,,,,,,,,,,, -481000,3.781468,2.9615426,,,,,,,,,,,,,, -481100,3.1009068,2.1092603,,,,,,,,,,,,,, -481200,3.1367,1.5314693,,,,,,,,,,,,,, -481210,,,0.8876757621765137,0.418246865272522,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,211768.6008861065,232021.582574606,211768.6008861065,20188.25280070305,39.7059965133667,0.0 -481300,3.1154683,1.1151228,,,,,,,,,,,,,, -481400,3.5189111,1.1248602,,,,,,,,,,,,,, -481500,3.7167337,3.167097,,,,,,,,,,,,,, -481600,3.3110096,1.1692832,,,,,,,,,,,,,, -481700,2.835711,2.1871912,,,,,,,,,,,,,, -481800,3.2753808,1.1102955,,,,,,,,,,,,,, -481900,3.2285721,1.2340379,,,,,,,,,,,,,, -482000,3.0593555,1.0753071,,,,,,,,,,,,,, -482100,3.2082496,1.9797527,,,,,,,,,,,,,, -482163,,,0.8851562142372131,0.4248987138271332,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,212188.4453895092,232490.9327290058,212188.4453895092,20237.58907341957,39.82587885856629,0.0 -482200,2.798779,1.1161907,,,,,,,,,,,,,, -482300,3.074797,1.1213279,,,,,,,,,,,,,, -482400,3.2886848,3.0582356,,,,,,,,,,,,,, -482500,3.3765812,1.0185603,,,,,,,,,,,,,, -482600,3.369567,1.4418496,,,,,,,,,,,,,, -482700,3.105047,2.7389338,,,,,,,,,,,,,, -482800,4.2649746,3.27616,,,,,,,,,,,,,, -482900,3.266813,2.8469744,,,,,,,,,,,,,, -483000,2.9069736,1.8595424,,,,,,,,,,,,,, -483100,3.3722928,2.9089456,,,,,,,,,,,,,, -483121,,,0.8895312547683716,0.4151040911674499,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,212608.6507983208,232951.67104578007,212608.6507983208,20277.970195531845,39.92865061759949,0.0 -483200,5.179128,2.7789588,,,,,,,,,,,,,, -483300,3.1221652,2.3360064,,,,,,,,,,,,,, -483400,3.0822246,1.1270337,,,,,,,,,,,,,, -483500,3.0666273,2.2101655,,,,,,,,,,,,,, -483600,3.3132713,1.1873798,,,,,,,,,,,,,, -483700,3.4074385,1.1896175,,,,,,,,,,,,,, -483800,4.090967,2.7065337,,,,,,,,,,,,,, -483900,3.1570792,1.2028128,,,,,,,,,,,,,, -484000,3.090783,1.7521003,,,,,,,,,,,,,, -484076,,,0.888964831829071,0.4096009731292724,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,213028.6021828652,233411.50339460373,213028.6021828652,20317.68110179901,40.04916906356812,0.0 -484100,3.307847,1.4133563,,,,,,,,,,,,,, -484200,3.4432693,2.979559,,,,,,,,,,,,,, -484300,3.134243,2.3412056,,,,,,,,,,,,,, -484400,2.895888,1.6018455,,,,,,,,,,,,,, -484500,3.4687767,1.1580122,,,,,,,,,,,,,, -484600,3.2523358,1.1270477,,,,,,,,,,,,,, -484700,3.7206423,3.237866,,,,,,,,,,,,,, -484800,3.012299,2.1802447,,,,,,,,,,,,,, -484900,3.0258741,1.8045689,,,,,,,,,,,,,, -485000,2.8121865,1.7174916,,,,,,,,,,,,,, -485011,,,0.88832026720047,0.4179162383079529,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,213448.84131765369,233872.877859354,213448.84131765369,20358.64752459526,40.17074942588806,0.0 -485100,3.2168567,1.100181,,,,,,,,,,,,,, -485200,2.8141265,1.5206717,,,,,,,,,,,,,, -485300,3.0993865,1.0837286,,,,,,,,,,,,,, -485400,2.8973055,1.0571235,,,,,,,,,,,,,, -485500,3.324053,2.555606,,,,,,,,,,,,,, -485600,2.9640121,1.1412258,,,,,,,,,,,,,, -485700,3.1768587,2.5581985,,,,,,,,,,,,,, -485800,3.1571083,2.5351255,,,,,,,,,,,,,, -485900,3.0747578,1.9688449,,,,,,,,,,,,,, -485943,,,0.8872656226158142,0.4240639507770538,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,213868.94949388504,234342.0355579853,213868.94949388504,20407.51692557335,40.30259609222412,0.0 -486000,3.1957827,1.2273839,,,,,,,,,,,,,, -486100,3.3119333,2.088047,,,,,,,,,,,,,, -486200,3.4729178,1.6089667,,,,,,,,,,,,,, -486300,3.7982295,3.2687006,,,,,,,,,,,,,, -486400,3.2893953,1.1677594,,,,,,,,,,,,,, -486500,3.7196383,3.3200336,,,,,,,,,,,,,, -486600,3.3169246,1.1562537,,,,,,,,,,,,,, -486700,3.059612,2.494008,,,,,,,,,,,,,, -486800,3.001787,1.3048807,,,,,,,,,,,,,, -486899,,,0.88832026720047,0.416032999753952,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,214289.19634270668,234808.39871358871,214289.19634270668,20453.481875658035,40.40373396873474,0.0 -486900,3.2534053,2.640226,,,,,,,,,,,,,, -487000,3.2146442,1.1542993,,,,,,,,,,,,,, -487100,3.6378577,2.9929593,,,,,,,,,,,,,, -487200,3.0066986,1.9666115,,,,,,,,,,,,,, -487300,3.092553,1.084982,,,,,,,,,,,,,, -487400,3.0134058,1.1041853,,,,,,,,,,,,,, -487500,2.9447813,1.7756864,,,,,,,,,,,,,, -487600,3.391837,1.0784353,,,,,,,,,,,,,, -487700,3.1648228,2.839989,,,,,,,,,,,,,, -487800,3.1645684,1.0843569,,,,,,,,,,,,,, -487850,,,0.8885351419448853,0.4118785858154297,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,214709.33081889155,235269.7968466282,214709.33081889155,20494.503543138504,40.596789598464966,0.0 -487900,2.9604008,1.48739,,,,,,,,,,,,,, -488000,3.48013,2.8503633,,,,,,,,,,,,,, -488100,3.2429857,1.0629349,,,,,,,,,,,,,, -488200,3.2314074,1.2085434,,,,,,,,,,,,,, -488300,2.9332108,1.1854329,,,,,,,,,,,,,, -488400,3.725405,3.2590632,,,,,,,,,,,,,, -488500,3.2298048,1.0900009,,,,,,,,,,,,,, -488600,3.7954972,3.3206944,,,,,,,,,,,,,, -488700,3.125518,2.2946243,,,,,,,,,,,,,, -488786,,,0.8866210579872131,0.4215874373912811,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,215129.6508102417,235732.26984477043,215129.6508102417,20536.487267017365,40.7184145450592,0.0 -488800,4.338479,2.8545516,,,,,,,,,,,,,, -488900,3.220788,1.1143321,,,,,,,,,,,,,, -489000,3.407822,1.0716437,,,,,,,,,,,,,, -489100,3.614638,3.1603384,,,,,,,,,,,,,, -489200,3.2522907,1.1244376,,,,,,,,,,,,,, -489300,3.1140914,1.1703252,,,,,,,,,,,,,, -489400,2.7588372,1.3928611,,,,,,,,,,,,,, -489500,3.3357491,1.2805622,,,,,,,,,,,,,, -489600,3.101884,1.1911784,,,,,,,,,,,,,, -489700,3.2921932,1.2347898,,,,,,,,,,,,,, -489735,,,0.8875585794448853,0.4183970689773559,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,215549.8371298313,236197.4879083633,215549.8371298313,20581.34592533112,40.841949224472046,0.0 -489800,2.9893706,1.1114707,,,,,,,,,,,,,, -489900,3.2961237,1.1425526,,,,,,,,,,,,,, -490000,3.1464338,1.6111732,,,,,,,,,,,,,, -490100,3.1611533,1.5124055,,,,,,,,,,,,,, -490200,3.0787284,1.1627476,,,,,,,,,,,,,, -490300,3.1612098,1.5143572,,,,,,,,,,,,,, -490400,3.8769119,1.8071678,,,,,,,,,,,,,, -490500,3.1290615,1.1160276,,,,,,,,,,,,,, -490600,2.9290729,1.2314942,,,,,,,,,,,,,, -490689,,,0.8875976204872131,0.4221860468387604,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,215970.1008954048,236658.7532222271,215970.1008954048,20622.188474416733,40.951939821243286,0.0 -490700,3.0962498,2.10183,,,,,,,,,,,,,, -490800,3.2228336,2.6933677,,,,,,,,,,,,,, -490900,3.20169,2.3140604,,,,,,,,,,,,,, -491000,3.4161522,1.1197567,,,,,,,,,,,,,, -491100,2.9720392,1.1137156,,,,,,,,,,,,,, -491200,3.0577226,2.1005151,,,,,,,,,,,,,, -491300,3.5109286,1.1666663,,,,,,,,,,,,,, -491400,3.276684,1.9176615,,,,,,,,,,,,,, -491500,2.841035,1.6398913,,,,,,,,,,,,,, -491600,2.9629567,2.116079,,,,,,,,,,,,,, -491638,,,0.88623046875,0.4179167151451111,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,216390.0034880638,237118.9931571484,216390.0034880638,20662.358990192413,41.06981325149536,0.0 -491700,3.1414652,1.2767054,,,,,,,,,,,,,, -491800,3.5012722,3.0432286,,,,,,,,,,,,,, -491900,2.9516406,1.1552124,,,,,,,,,,,,,, -492000,3.0125794,1.1640711,,,,,,,,,,,,,, -492100,3.2100313,1.2795665,,,,,,,,,,,,,, -492200,3.052361,1.848188,,,,,,,,,,,,,, -492300,3.2048433,2.655507,,,,,,,,,,,,,, -492400,2.9966967,1.1311464,,,,,,,,,,,,,, -492500,3.201101,1.9674792,,,,,,,,,,,,,, -492593,,,0.888964831829071,0.4156960546970367,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,216810.23352575305,237590.33246302605,216810.23352575305,20713.29764199257,41.19071340560913,0.0 -492600,3.372304,1.2742779,,,,,,,,,,,,,, -492700,3.3458583,1.0229973,,,,,,,,,,,,,, -492800,3.0872123,1.0941126,,,,,,,,,,,,,, -492900,3.2083724,1.1968489,,,,,,,,,,,,,, -493000,3.029257,1.1653264,,,,,,,,,,,,,, -493100,3.078193,1.1073062,,,,,,,,,,,,,, -493200,3.2900476,1.3464366,,,,,,,,,,,,,, -493300,2.9897087,1.1385447,,,,,,,,,,,,,, -493400,2.9310627,1.7409917,,,,,,,,,,,,,, -493500,3.810585,3.0832255,,,,,,,,,,,,,, -493552,,,0.8879101276397705,0.4139544665813446,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,217230.2256433964,238059.1432621479,217230.2256433964,20761.96725344658,41.29007768630981,0.0 -493600,3.2834134,3.0599258,,,,,,,,,,,,,, -493700,3.0481234,1.2000983,,,,,,,,,,,,,, -493800,3.2210424,1.2524773,,,,,,,,,,,,,, -493900,2.982373,1.0397208,,,,,,,,,,,,,, -494000,3.3569832,1.1029432,,,,,,,,,,,,,, -494100,3.2868636,1.3872801,,,,,,,,,,,,,, -494200,3.1993928,2.5146236,,,,,,,,,,,,,, -494300,2.9149547,1.6106772,,,,,,,,,,,,,, -494400,3.0653906,1.930709,,,,,,,,,,,,,, -494500,3.083615,2.147836,,,,,,,,,,,,,, -494507,,,0.887988269329071,0.4172173142433166,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,217650.34621620167,238519.40918159485,217650.34621620167,20801.961450099945,41.39179611206055,0.0 -494600,2.8712263,1.8150609,,,,,,,,,,,,,, -494700,3.1127925,1.2293137,,,,,,,,,,,,,, -494800,3.2117155,2.9213934,,,,,,,,,,,,,, -494900,3.257637,1.1711118,,,,,,,,,,,,,, -495000,3.1047835,1.2237555,,,,,,,,,,,,,, -495100,3.2977815,1.1262773,,,,,,,,,,,,,, -495200,3.275065,1.1902364,,,,,,,,,,,,,, -495300,4.871347,3.2150633,,,,,,,,,,,,,, -495400,2.8803334,1.0814289,,,,,,,,,,,,,, -495449,,,0.8880664110183716,0.4167808294296264,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,218070.4790790081,238979.97016358376,218070.4790790081,20842.219407081604,41.5134813785553,0.0 -495500,3.5597916,1.6943588,,,,,,,,,,,,,, -495600,3.018447,1.1203811,,,,,,,,,,,,,, -495700,3.1356587,1.0830204,,,,,,,,,,,,,, -495800,3.2717702,2.7220879,,,,,,,,,,,,,, -495900,4.0863543,2.5011792,,,,,,,,,,,,,, -496000,3.1859493,1.2538267,,,,,,,,,,,,,, -496100,3.6348436,3.183107,,,,,,,,,,,,,, -496200,3.0574963,1.3884866,,,,,,,,,,,,,, -496300,3.0870583,1.0866187,,,,,,,,,,,,,, -496385,,,0.8890038728713989,0.4185583293437958,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,218490.4862658977,239451.57863640785,218490.4862658977,20893.64836382866,41.63759517669678,0.0 -496400,3.1453054,1.0652959,,,,,,,,,,,,,, -496500,3.3608007,1.7622776,,,,,,,,,,,,,, -496600,3.74705,3.2465405,,,,,,,,,,,,,, -496700,3.1605406,1.1587422,,,,,,,,,,,,,, -496800,3.4015038,2.3292928,,,,,,,,,,,,,, -496900,2.8597627,1.1184474,,,,,,,,,,,,,, -497000,3.100795,1.4694433,,,,,,,,,,,,,, -497100,3.9118493,3.3422737,,,,,,,,,,,,,, -497200,2.9882715,1.9007925,,,,,,,,,,,,,, -497300,3.2658513,2.8686364,,,,,,,,,,,,,, -497340,,,0.8873828053474426,0.4235077500343323,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,218910.6556572914,239920.405535698,218910.6556572914,20942.15079689026,41.74072694778442,0.0 -497400,3.114439,1.5290065,,,,,,,,,,,,,, -497500,2.8639503,1.8678421,,,,,,,,,,,,,, -497600,3.1472273,2.444002,,,,,,,,,,,,,, -497700,3.2116892,1.1969049,,,,,,,,,,,,,, -497800,4.2018485,1.1389844,,,,,,,,,,,,,, -497900,3.3399143,1.2121414,,,,,,,,,,,,,, -498000,3.172119,2.513102,,,,,,,,,,,,,, -498100,2.9708934,1.9919777,,,,,,,,,,,,,, -498200,2.8936887,1.1913236,,,,,,,,,,,,,, -498296,,,0.8897070288658142,0.4100601077079773,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,219330.9631397724,240385.43779563904,219330.9631397724,20986.72572159767,41.84137916564941,0.0 -498300,4.213931,3.2643456,,,,,,,,,,,,,, -498400,3.430421,3.0666356,,,,,,,,,,,,,, -498500,3.7014782,3.1309032,,,,,,,,,,,,,, -498600,3.0813327,1.1346122,,,,,,,,,,,,,, -498700,3.3024263,1.1216193,,,,,,,,,,,,,, -498800,3.1577477,1.4540501,,,,,,,,,,,,,, -498900,3.3479807,1.1612957,,,,,,,,,,,,,, -499000,3.076581,1.1518604,,,,,,,,,,,,,, -499100,3.33826,2.2363977,,,,,,,,,,,,,, -499200,3.0576494,1.4396623,,,,,,,,,,,,,, -499248,,,0.8898437023162842,0.4133493602275848,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,219751.18331694603,240851.29478430748,219751.18331694603,21032.17861413956,41.97613787651062,0.0 -499300,3.4829147,2.8413718,,,,,,,,,,,,,, -499400,3.0661287,1.012907,,,,,,,,,,,,,, -499500,3.4732757,2.275621,,,,,,,,,,,,,, -499600,3.0489717,1.4265513,,,,,,,,,,,,,, -499700,3.039837,1.2399566,,,,,,,,,,,,,, -499800,3.3452044,1.1885242,,,,,,,,,,,,,, -499900,4.188578,3.207016,,,,,,,,,,,,,, -500000,3.3697963,2.4634087,,,,,,,,,,,,,, -500100,3.0445974,1.8473213,,,,,,,,,,,,,, -500199,,,0.8876171708106995,0.4118110835552215,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,220171.2231209278,241313.8032577037,220171.2231209278,21074.464852809902,42.10923957824707,0.0 -500200,3.0009747,1.439731,,,,,,,,,,,,,, -500300,2.9479642,1.2604764,,,,,,,,,,,,,, -500400,2.9194329,1.171267,,,,,,,,,,,,,, -500500,3.172395,1.11497,,,,,,,,,,,,,, -500600,3.0708022,1.4068424,,,,,,,,,,,,,, -500700,3.0794623,1.0948297,,,,,,,,,,,,,, -500800,2.8821626,1.9222771,,,,,,,,,,,,,, -500900,3.811095,3.249653,,,,,,,,,,,,,, -501000,3.1277413,2.481668,,,,,,,,,,,,,, -501100,4.4654346,2.5220742,,,,,,,,,,,,,, -501146,,,0.8863085508346558,0.4214934110641479,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,220591.38791012764,241780.624147892,220591.38791012764,21120.94376969337,42.23678421974182,0.0 -501200,3.1395342,1.647833,,,,,,,,,,,,,, -501300,3.1184268,1.9185091,,,,,,,,,,,,,, -501400,2.9204106,2.1787517,,,,,,,,,,,,,, -501500,3.5557466,3.1439834,,,,,,,,,,,,,, -501600,3.6060736,3.0517101,,,,,,,,,,,,,, -501700,2.9588203,1.9579458,,,,,,,,,,,,,, -501800,3.0045066,1.1687211,,,,,,,,,,,,,, -501900,3.8561804,3.25739,,,,,,,,,,,,,, -502000,3.915754,3.273212,,,,,,,,,,,,,, -502100,,,0.8866991996765137,0.4188326299190521,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,221011.6120646,242241.1875770092,221011.6120646,21161.110082149506,42.36047720909119,0.0 -502100,3.0101497,1.2943976,,,,,,,,,,,,,, -502200,2.9763045,1.0976262,,,,,,,,,,,,,, -502300,2.9790528,2.4972887,,,,,,,,,,,,,, -502400,3.4494357,1.1237059,,,,,,,,,,,,,, -502500,2.9440722,1.4650812,,,,,,,,,,,,,, -502600,3.034197,1.1721025,,,,,,,,,,,,,, -502700,3.9904058,3.2931123,,,,,,,,,,,,,, -502800,3.0580177,1.0599785,,,,,,,,,,,,,, -502900,3.501539,2.960865,,,,,,,,,,,,,, -503000,3.087969,1.1465955,,,,,,,,,,,,,, -503036,,,0.8915234208106995,0.4066730737686157,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,221431.6432979107,242708.8383128643,221431.6432979107,21208.55673956871,42.48579788208008,0.0 -503100,2.834432,1.8563412,,,,,,,,,,,,,, -503200,3.284776,1.8434438,,,,,,,,,,,,,, -503300,3.0557613,2.452512,,,,,,,,,,,,,, -503400,2.9831746,2.3777838,,,,,,,,,,,,,, -503500,4.4742517,3.0959182,,,,,,,,,,,,,, -503600,3.060357,2.0715792,,,,,,,,,,,,,, -503700,2.8933003,1.1435621,,,,,,,,,,,,,, -503800,3.2044039,2.5351934,,,,,,,,,,,,,, -503900,3.187797,1.1765832,,,,,,,,,,,,,, -503987,,,0.8878710865974426,0.4177476763725281,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,221851.79825234413,243173.70783042908,221851.79825234413,21253.096702337265,42.610806465148926,0.0 -504000,3.5463254,3.0010402,,,,,,,,,,,,,, -504100,3.2173016,1.143243,,,,,,,,,,,,,, -504200,3.2766376,1.4689577,,,,,,,,,,,,,, -504300,3.2757306,1.1212754,,,,,,,,,,,,,, -504400,2.7564392,1.3864928,,,,,,,,,,,,,, -504500,3.0256572,1.7966754,,,,,,,,,,,,,, -504600,3.1005328,1.0516423,,,,,,,,,,,,,, -504700,4.3620996,3.206589,,,,,,,,,,,,,, -504800,3.6390073,2.681391,,,,,,,,,,,,,, -504900,3.3522766,1.1670911,,,,,,,,,,,,,, -504939,,,0.8858398199081421,0.4197298288345337,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,222271.69676876068,243636.8362257481,222271.69676876068,21296.14944720268,42.73886179924011,0.0 -505000,2.921428,1.2234042,,,,,,,,,,,,,, -505100,3.0699925,2.581613,,,,,,,,,,,,,, -505200,3.3519676,1.0853179,,,,,,,,,,,,,, -505300,3.286998,1.5344117,,,,,,,,,,,,,, -505400,3.245077,3.0086944,,,,,,,,,,,,,, -505500,3.082269,1.1385047,,,,,,,,,,,,,, -505600,3.0904553,1.6758868,,,,,,,,,,,,,, -505700,3.3413332,1.1791165,,,,,,,,,,,,,, -505800,3.362464,2.8463364,,,,,,,,,,,,,, -505895,,,0.8896679282188416,0.4172879457473755,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,222691.6384379864,244101.2883806229,222691.6384379864,21340.503289461136,42.846314430236816,0.0 -505900,3.0281682,2.140852,,,,,,,,,,,,,, -506000,3.6120453,1.1393124,,,,,,,,,,,,,, -506100,3.0780592,1.6114533,,,,,,,,,,,,,, -506200,3.261434,1.1464976,,,,,,,,,,,,,, -506300,3.4820483,2.7882197,,,,,,,,,,,,,, -506400,3.2193847,2.9073408,,,,,,,,,,,,,, -506500,3.2022057,1.4018447,,,,,,,,,,,,,, -506600,3.3326983,1.110661,,,,,,,,,,,,,, -506700,3.2294066,1.1525944,,,,,,,,,,,,,, -506800,3.1874158,1.0666972,,,,,,,,,,,,,, -506850,,,0.88720703125,0.417205810546875,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,223111.82583379743,244562.72336554527,223111.82583379743,21381.57668352127,42.97149038314819,0.0 -506900,2.8736312,1.2916098,,,,,,,,,,,,,, -507000,4.46863,3.2889426,,,,,,,,,,,,,, -507100,3.1405811,1.1098421,,,,,,,,,,,,,, -507200,3.8786805,3.0335596,,,,,,,,,,,,,, -507300,3.3326697,1.1135703,,,,,,,,,,,,,, -507400,4.145327,3.2402506,,,,,,,,,,,,,, -507500,3.4971542,2.8122573,,,,,,,,,,,,,, -507600,3.2429996,1.1910307,,,,,,,,,,,,,, -507700,3.1865656,2.6510508,,,,,,,,,,,,,, -507797,,,0.8874022960662842,0.4177261292934418,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,223530.96260380745,245024.03587937355,223530.96260380745,21422.78080868721,43.89462304115296,0.0 -507800,3.1601188,1.1810791,,,,,,,,,,,,,, -507900,3.3024786,1.1543447,,,,,,,,,,,,,, -508000,3.0163083,1.2945218,,,,,,,,,,,,,, -508100,3.1615872,1.2268013,,,,,,,,,,,,,, -508200,3.1426945,1.0872765,,,,,,,,,,,,,, -508300,3.0420094,1.1659154,,,,,,,,,,,,,, -508400,3.2420945,1.2439836,,,,,,,,,,,,,, -508500,4.44751,3.2604487,,,,,,,,,,,,,, -508600,3.1099343,1.1263729,,,,,,,,,,,,,, -508700,3.4187956,2.7880933,,,,,,,,,,,,,, -508734,,,0.8900390267372131,0.4131047427654266,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,223951.0003323555,245485.90597319603,223951.0003323555,21464.441920757294,44.01762056350708,0.0 -508800,3.15997,1.1882033,,,,,,,,,,,,,, -508900,3.1284707,1.2274256,,,,,,,,,,,,,, -509000,3.3990557,2.718895,,,,,,,,,,,,,, -509100,3.2033207,1.1478643,,,,,,,,,,,,,, -509200,3.4315553,1.3952136,,,,,,,,,,,,,, -509300,3.0064816,1.11326,,,,,,,,,,,,,, -509400,3.1151202,1.8612771,,,,,,,,,,,,,, -509500,3.0670831,1.1707666,,,,,,,,,,,,,, -509600,4.0744042,3.366965,,,,,,,,,,,,,, -509680,,,0.8871093392372131,0.4248417615890503,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,224370.9559493065,245950.3347005844,224370.9559493065,21508.74162054062,44.14262199401856,0.0 -509700,3.294049,2.3868787,,,,,,,,,,,,,, -509800,2.9152486,1.8535293,,,,,,,,,,,,,, -509900,3.4686954,1.2336338,,,,,,,,,,,,,, -510000,3.2730417,1.8162466,,,,,,,,,,,,,, -510100,3.0676358,1.588171,,,,,,,,,,,,,, -510200,4.352845,2.8870525,,,,,,,,,,,,,, -510300,3.1966357,1.2317395,,,,,,,,,,,,,, -510400,3.3003333,1.0630425,,,,,,,,,,,,,, -510500,3.3160934,1.1257207,,,,,,,,,,,,,, -510600,3.2182465,2.4637022,,,,,,,,,,,,,, -510633,,,0.8889452815055847,0.4144680798053741,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,224790.7845821381,246415.9335541725,224790.7845821381,21554.337733268738,44.26745653152466,0.0 -510700,2.9794166,2.4328835,,,,,,,,,,,,,, -510800,3.4169815,1.8604498,,,,,,,,,,,,,, -510900,3.1786945,1.0825846,,,,,,,,,,,,,, -511000,4.137804,3.3108327,,,,,,,,,,,,,, -511100,5.2575192,3.2312114,,,,,,,,,,,,,, -511200,3.7488108,1.803845,,,,,,,,,,,,,, -511300,3.130073,1.4450365,,,,,,,,,,,,,, -511400,3.6117575,3.05449,,,,,,,,,,,,,, -511500,3.2228954,0.9911604,,,,,,,,,,,,,, -511589,,,0.8864648342132568,0.4167560040950775,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,225211.0207374096,246881.53498005867,225211.0207374096,21599.514624118805,44.40690755844116,0.0 -511600,3.8956652,3.145565,,,,,,,,,,,,,, -511700,3.0457246,1.2041725,,,,,,,,,,,,,, -511800,2.8744948,1.0439754,,,,,,,,,,,,,, -511900,3.374956,1.4075067,,,,,,,,,,,,,, -512000,3.1762657,1.1110104,,,,,,,,,,,,,, -512100,3.035561,2.4989471,,,,,,,,,,,,,, -512200,2.9465318,1.9403534,,,,,,,,,,,,,, -512300,2.9481454,2.0570838,,,,,,,,,,,,,, -512400,3.1575766,1.0516224,,,,,,,,,,,,,, -512500,3.164087,1.1576544,,,,,,,,,,,,,, -512540,,,0.8864648342132568,0.4207066595554352,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,225630.9021072388,247339.98562526703,225630.9021072388,21637.92317771912,44.51779532432556,0.0 -512600,3.4176862,1.6094031,,,,,,,,,,,,,, -512700,3.3605075,2.6950474,,,,,,,,,,,,,, -512800,3.3824744,1.2359154,,,,,,,,,,,,,, -512900,2.9125504,1.0564232,,,,,,,,,,,,,, -513000,3.053005,1.941595,,,,,,,,,,,,,, -513100,2.9476016,1.5874093,,,,,,,,,,,,,, -513200,3.1260915,1.3270141,,,,,,,,,,,,,, -513300,3.1163726,1.3375998,,,,,,,,,,,,,, -513400,3.193343,1.0342652,,,,,,,,,,,,,, -513462,,,0.8897070288658142,0.414121150970459,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,226050.8605763912,247805.08733654025,226050.8605763912,21682.883960962296,44.65179920196533,0.0 -513500,3.643499,3.1595411,,,,,,,,,,,,,, -513600,3.0074377,1.1284051,,,,,,,,,,,,,, -513700,3.6944792,2.8564253,,,,,,,,,,,,,, -513800,3.2595408,2.932208,,,,,,,,,,,,,, -513900,3.9000692,3.1231592,,,,,,,,,,,,,, -514000,3.1345713,1.3910369,,,,,,,,,,,,,, -514100,3.167823,1.1061652,,,,,,,,,,,,,, -514200,2.8190544,1.0878292,,,,,,,,,,,,,, -514300,3.1040576,2.0528777,,,,,,,,,,,,,, -514400,3.0050159,1.0154988,,,,,,,,,,,,,, -514413,,,0.8862890601158142,0.4216505289077759,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,226470.79879546163,248268.15232920647,226470.79879546163,21725.83594822884,44.77778220176697,0.0 -514500,3.1686876,2.6462398,,,,,,,,,,,,,, -514600,3.0158262,2.121314,,,,,,,,,,,,,, -514700,3.08945,1.8273746,,,,,,,,,,,,,, -514800,2.9496722,1.0614171,,,,,,,,,,,,,, -514900,3.1468205,1.3530226,,,,,,,,,,,,,, -515000,2.8572357,0.98000693,,,,,,,,,,,,,, -515100,3.2256014,1.1425335,,,,,,,,,,,,,, -515200,3.1181607,1.1993661,,,,,,,,,,,,,, -515300,2.9361045,1.5708025,,,,,,,,,,,,,, -515363,,,0.8855664134025574,0.4250814616680145,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,226890.87377142903,248730.3489470482,226890.87377142903,21767.78396916389,44.90311050415039,0.0 -515400,2.8528435,1.8912157,,,,,,,,,,,,,, -515500,3.1484249,2.2453158,,,,,,,,,,,,,, -515600,3.0842614,1.7481084,,,,,,,,,,,,,, -515700,2.8055642,1.0654624,,,,,,,,,,,,,, -515800,2.8527095,1.3110743,,,,,,,,,,,,,, -515900,3.0403535,2.6859157,,,,,,,,,,,,,, -516000,3.1172726,1.1260766,,,,,,,,,,,,,, -516100,2.9986534,1.113718,,,,,,,,,,,,,, -516200,3.1985548,1.1118975,,,,,,,,,,,,,, -516300,2.9746392,1.6916597,,,,,,,,,,,,,, -516319,,,0.8897656202316284,0.4119044840335846,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,227310.91848754883,249198.0817565918,227310.91848754883,21815.223826885223,45.10130214691162,0.0 -516400,4.1571283,3.1363354,,,,,,,,,,,,,, -516500,3.1225939,1.1902101,,,,,,,,,,,,,, -516600,3.2078743,1.820623,,,,,,,,,,,,,, -516700,2.9173071,1.4892699,,,,,,,,,,,,,, -516800,3.1169994,2.4843602,,,,,,,,,,,,,, -516900,3.1475685,2.7571764,,,,,,,,,,,,,, -517000,3.096852,1.1070915,,,,,,,,,,,,,, -517100,2.9456096,1.1351452,,,,,,,,,,,,,, -517200,3.5953948,3.1092937,,,,,,,,,,,,,, -517273,,,0.888671875,0.4123574793338775,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,227731.0243735313,249657.7138082981,227731.0243735313,21854.596625089645,45.2050838470459,0.0 -517300,2.9985895,1.0541873,,,,,,,,,,,,,, -517400,3.0347474,1.0627313,,,,,,,,,,,,,, -517500,3.086911,1.1488643,,,,,,,,,,,,,, -517600,3.0794299,1.1355568,,,,,,,,,,,,,, -517700,3.2744632,2.225443,,,,,,,,,,,,,, -517800,3.5524502,3.0479305,,,,,,,,,,,,,, -517900,3.1127672,1.1599909,,,,,,,,,,,,,, -518000,3.0846035,1.0985752,,,,,,,,,,,,,, -518100,3.1627939,1.2341956,,,,,,,,,,,,,, -518200,2.9877734,1.6434042,,,,,,,,,,,,,, -518217,,,0.8880273103713989,0.4175519347190857,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,228151.06265687945,250122.7528753281,228151.06265687945,21899.41953110695,45.33416724205017,0.0 -518300,2.9275932,2.6046991,,,,,,,,,,,,,, -518400,3.3444195,2.5238996,,,,,,,,,,,,,, -518500,3.6871123,2.9939237,,,,,,,,,,,,,, -518600,3.4540663,2.7814207,,,,,,,,,,,,,, -518700,3.2922158,1.3146337,,,,,,,,,,,,,, -518800,3.6292078,3.108736,,,,,,,,,,,,,, -518900,2.9984744,1.0573575,,,,,,,,,,,,,, -519000,2.898137,1.3434794,,,,,,,,,,,,,, -519100,3.2046907,1.5054425,,,,,,,,,,,,,, -519169,,,0.8868945240974426,0.4196770787239074,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,228570.9589471817,250579.7715086937,228570.9589471817,21936.388018369675,45.43733549118042,0.0 -519200,3.155872,1.3410561,,,,,,,,,,,,,, -519300,3.4199996,2.8961952,,,,,,,,,,,,,, -519400,2.9309042,1.8959823,,,,,,,,,,,,,, -519500,4.4650397,3.1944005,,,,,,,,,,,,,, -519600,3.2363808,1.5355518,,,,,,,,,,,,,, -519700,3.035422,1.089787,,,,,,,,,,,,,, -519800,3.2581258,1.4604187,,,,,,,,,,,,,, -519900,3.2820275,1.0168115,,,,,,,,,,,,,, -520000,3.1617136,1.1033443,,,,,,,,,,,,,, -520100,3.3187113,1.3937031,,,,,,,,,,,,,, -520121,,,0.8871679306030273,0.4226775765419006,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,228990.81290125847,251043.0187008381,228990.81290125847,21979.60288333893,45.56658720970154,0.0 -520200,3.1901512,1.1139883,,,,,,,,,,,,,, -520300,3.4580407,2.9342437,,,,,,,,,,,,,, -520400,3.1350884,1.4299649,,,,,,,,,,,,,, -520500,3.6826036,2.9409454,,,,,,,,,,,,,, -520600,3.1715825,1.2526803,,,,,,,,,,,,,, -520700,2.9078693,1.1118995,,,,,,,,,,,,,, -520800,3.2669215,2.68458,,,,,,,,,,,,,, -520900,3.201276,1.1609291,,,,,,,,,,,,,, -521000,3.1443777,1.9347651,,,,,,,,,,,,,, -521066,,,0.8896874785423279,0.4149737358093261,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,229410.6971549988,251507.8479943276,229410.6971549988,22024.37559247017,45.68960404396057,0.0 -521100,3.269793,2.7154622,,,,,,,,,,,,,, -521200,4.0528574,3.1636922,,,,,,,,,,,,,, -521300,3.062837,1.5719934,,,,,,,,,,,,,, -521400,3.3755937,1.3404585,,,,,,,,,,,,,, -521500,3.5842366,1.1762519,,,,,,,,,,,,,, -521600,3.279236,1.5808853,,,,,,,,,,,,,, -521700,3.024702,1.1021814,,,,,,,,,,,,,, -521800,3.125138,2.3810914,,,,,,,,,,,,,, -521900,3.3950264,2.9219508,,,,,,,,,,,,,, -522000,2.934682,1.2453688,,,,,,,,,,,,,, -522018,,,0.8896093368530273,0.4115025997161865,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,229830.59431529045,251973.10938835144,229830.59431529045,22069.55735874176,45.82239007949829,0.0 -522100,3.497163,1.1258852,,,,,,,,,,,,,, -522200,3.4429026,2.889344,,,,,,,,,,,,,, -522300,3.34908,1.4884641,,,,,,,,,,,,,, -522400,3.1384313,1.7034814,,,,,,,,,,,,,, -522500,3.005923,1.1627531,,,,,,,,,,,,,, -522600,3.507098,2.494102,,,,,,,,,,,,,, -522700,3.5081108,1.1580685,,,,,,,,,,,,,, -522800,3.1583064,1.3358175,,,,,,,,,,,,,, -522900,2.9014533,1.1088097,,,,,,,,,,,,,, -522975,,,0.8873828053474426,0.4191846847534179,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,230250.5108890533,252432.78043317795,230250.5108890533,22109.15285468101,45.93182587623596,0.0 -523000,3.141704,1.560313,,,,,,,,,,,,,, -523100,3.9079535,3.1625605,,,,,,,,,,,,,, -523200,3.539324,2.831478,,,,,,,,,,,,,, -523300,3.1486485,1.0884911,,,,,,,,,,,,,, -523400,3.7225096,1.612648,,,,,,,,,,,,,, -523500,3.3143919,1.2560067,,,,,,,,,,,,,, -523600,3.0378969,1.4650205,,,,,,,,,,,,,, -523700,3.365876,1.1242521,,,,,,,,,,,,,, -523800,3.2602568,1.1704618,,,,,,,,,,,,,, -523900,3.511813,1.3117367,,,,,,,,,,,,,, -523921,,,0.8888671398162842,0.4110132753849029,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,230670.47365617752,252904.2251083851,230670.47365617752,22160.45391464233,46.063761472702026,0.0 -524000,3.651651,3.240475,,,,,,,,,,,,,, -524100,3.3616087,1.1590354,,,,,,,,,,,,,, -524200,3.131838,1.0748421,,,,,,,,,,,,,, -524300,2.992968,1.0608416,,,,,,,,,,,,,, -524400,3.0135407,1.0750915,,,,,,,,,,,,,, -524500,3.83202,3.190945,,,,,,,,,,,,,, -524600,3.8970363,3.105249,,,,,,,,,,,,,, -524700,4.0210495,3.2101433,,,,,,,,,,,,,, -524800,3.0754416,1.0759957,,,,,,,,,,,,,, -524878,,,0.8877148032188416,0.4167254567146301,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,231090.6543712616,253366.4442384243,231090.6543712616,22202.33893918991,46.168038845062256,0.0 -524900,3.124151,1.9173598,,,,,,,,,,,,,, -525000,4.195497,3.3433661,,,,,,,,,,,,,, -525100,3.0619695,1.0728798,,,,,,,,,,,,,, -525200,3.115086,1.1023673,,,,,,,,,,,,,, -525300,2.9726112,1.4519923,,,,,,,,,,,,,, -525400,3.7126503,1.1296344,,,,,,,,,,,,,, -525500,3.0790596,1.1820652,,,,,,,,,,,,,, -525600,3.1489828,1.0626457,,,,,,,,,,,,,, -525700,4.622759,3.2383876,,,,,,,,,,,,,, -525800,4.0318403,2.8607135,,,,,,,,,,,,,, -525826,,,0.8874022960662842,0.4148069620132446,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,231510.7819397449,253830.32165122032,231510.7819397449,22245.91069793701,46.29690170288086,0.0 -525900,3.2342627,1.3876572,,,,,,,,,,,,,, -526000,3.1345022,1.0416348,,,,,,,,,,,,,, -526100,3.0406923,1.2411789,,,,,,,,,,,,,, -526200,3.5531185,1.1354725,,,,,,,,,,,,,, -526300,3.3796751,3.098136,,,,,,,,,,,,,, -526400,3.1464257,1.5697199,,,,,,,,,,,,,, -526500,3.0393348,1.0965848,,,,,,,,,,,,,, -526600,3.0680053,1.9011546,,,,,,,,,,,,,, -526700,3.0335472,1.0874629,,,,,,,,,,,,,, -526778,,,0.8897070288658142,0.4127437472343445,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,231930.7475101948,254298.8728711605,231930.7475101948,22294.248700141907,46.49507021903992,0.0 -526800,2.918771,2.0084279,,,,,,,,,,,,,, -526900,2.8905146,1.4318442,,,,,,,,,,,,,, -527000,3.045827,1.1440287,,,,,,,,,,,,,, -527100,3.1222842,1.2485167,,,,,,,,,,,,,, -527200,3.6290693,3.0358014,,,,,,,,,,,,,, -527300,3.188304,1.1645412,,,,,,,,,,,,,, -527400,3.117844,1.0368766,,,,,,,,,,,,,, -527500,3.0592954,1.2197839,,,,,,,,,,,,,, -527600,3.4134073,1.2571638,,,,,,,,,,,,,, -527700,3.0780435,1.1857151,,,,,,,,,,,,,, -527733,,,0.8895898461341858,0.4154190421104431,0.7829399704933167,0.8525302410125732,50000.0,0.6614000201225281,1.453892707824707,10000.0,232350.6270961761,254758.7492594719,232350.6270961761,22334.084921598434,46.60599493980408,0.0 -527800,3.4449224,1.2115521,,,,,,,,,,,,,, -527900,3.0057423,1.4517173,,,,,,,,,,,,,, -528000,3.5398655,1.2156609,,,,,,,,,,,,,, -528100,4.278433,3.18335,,,,,,,,,,,,,, -528200,2.9886036,2.574868,,,,,,,,,,,,,, -528201,,,,,,,,,,,232560.57017302513,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index e126901d7..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,42 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -179.3943247795105,0.0,62.45750379562378,1,0,62.45750379562378,30.067327,2472,1.1492494871326142,241.85189867019653,30.63818,1.394191324990109,30.082218,5348,1.1100051169661218 -304.4973177909851,0.041710615158081,1502.963063955307,1757,0,1502.963063955307,3.1675358,2472,0.5846281965348445,1807.5719513893127,3.0410984,0.5905542625280633,3.4923468,5348,0.6396207652278015 -434.86101627349854,0.0924489498138427,2942.972272872925,3550,0,2942.972272872925,0.66489387,2472,0.212479434525623,3378.072374105453,0.62109214,0.2086354313720267,0.9490379,5348,0.2759010204968284 -566.6209635734558,0.1419525146484375,4383.247010231018,5317,0,4383.247010231018,0.49097118,2472,0.1631222960209615,4950.232921600342,0.41384113,0.1480673234551502,0.7433237,5348,0.2233410892379582 -699.9576814174652,0.1899044513702392,5823.124886751175,7055,0,5823.124886751175,0.43877867,2472,0.1471370828509333,6523.570137262344,0.37714663,0.134654253369498,0.67645335,5348,0.2034042306689709 -831.6092147827148,0.2443428039550781,7263.046508550644,8792,0,7263.046508550644,0.39292622,2472,0.1334877013385331,8095.271713018417,0.33957744,0.1227793602818065,0.63309586,5348,0.1919151935275206 -964.390949010849,0.3064408302307129,8702.932670116425,10525,0,8702.932670116425,0.38322657,2472,0.1270692421749639,9668.075123786926,0.34834158,0.1233335645048572,0.6127697,5348,0.1846355851202487 -1097.5230059623718,0.357492446899414,10143.555659532549,12256,0,10143.555659532549,0.35471162,2472,0.1210570146040257,11241.953930139542,0.3036773,0.1103630910599196,0.5770447,5348,0.1742375237745831 -1231.768848657608,0.4228217601776123,11583.86165523529,13990,0,11583.86165523529,0.34388256,2472,0.1162025470720858,12816.645854234695,0.26289067,0.0994409631035648,0.55949336,5348,0.1682902574895971 -1370.736781835556,0.48415207862854,13023.903942584991,15738,0,13023.903942584991,0.32644108,2472,0.1121808543050393,14395.793190717695,0.24273707,0.0902063638153099,0.5420629,5348,0.1638973903472778 -1502.307198524475,0.5404567718505859,14464.083364963531,17455,0,14464.083364963531,0.31930014,2472,0.1086060162898868,15967.672088384628,0.24323381,0.0948961478438544,0.53616065,5348,0.1624298830821514 -1636.3237302303314,0.5938196182250977,15904.700827360151,19158,0,15904.700827360151,0.31532496,2472,0.1049296203765766,17542.432443857193,0.25264207,0.0925830931542346,0.5239434,5348,0.1574867007154098 -1771.5046956539154,0.6519536972045898,17345.316667318344,20916,0,17345.316667318344,0.3009962,2472,0.1009282391891617,19118.36318206787,0.24772409,0.0907958952172422,0.5120749,5348,0.1534800196954922 -1903.2082245349884,0.7093734741210938,18785.592749118805,22635,0,18785.592749118805,0.29497865,2472,0.0990189507037962,20690.474380731583,0.22900142,0.0860479473196831,0.5003814,5348,0.1506029330835996 -2037.400043010712,0.7659990787506104,20225.857773780823,24368,0,20225.857773780823,0.28867492,2472,0.0962971990331688,22265.06233000756,0.21313213,0.079963895105781,0.48328742,5348,0.1474651708390859 -2172.017728328705,0.8262643814086914,21666.32645058632,26122,0,21666.32645058632,0.27718583,2472,0.093006723132858,23840.28493332863,0.2005794,0.0753592947129554,0.47255135,5348,0.1411896463500584 -2305.922488927841,0.885556697845459,23106.63402581215,27836,0,23106.63402581215,0.2700468,2472,0.0896553124936526,25414.62991690636,0.19883163,0.0731044992891552,0.46602303,5348,0.1379360282688241 -2438.7332191467285,0.9390220642089844,24546.826768159863,29580,0,24546.826768159863,0.26341733,2472,0.0888631608880222,26987.76125073433,0.20785119,0.0773260097070549,0.4579564,5348,0.1361112988404761 -2571.6975836753845,0.9967617988586426,25986.92082810402,31326,0,25986.92082810402,0.25740403,2472,0.0864663944914996,28560.95162129402,0.18994193,0.0693275601061318,0.4489827,5348,0.1332149029224635 -2706.4022014141083,1.061279058456421,27426.860374689106,33048,0,27426.860374689106,0.25247926,2472,0.0856539313062376,30135.734421014786,0.20247611,0.0736891852323688,0.4474757,5348,0.1316894677389768 -2839.3952023983,1.1202399730682373,28867.33429455757,34778,0,28867.33429455757,0.25044668,2472,0.0848820912802388,31709.335352897644,0.19859408,0.0700789667188848,0.44250348,5348,0.1293433870453865 -2977.8900032043457,1.179589033126831,30307.208152771,36520,0,30307.208152771,0.24829671,2472,0.0823837669855584,33287.83837580681,0.1440633,0.0544417453337843,0.42724466,5348,0.1259546038213116 -3110.9470710754395,1.235142946243286,31747.13672208786,38245,0,31747.13672208786,0.24066494,2472,0.0790729795056161,34860.952849149704,0.1679655,0.0616090473274059,0.41607982,5348,0.1227878776176178 -3242.813079357147,1.2900941371917725,33187.74602842331,39983,0,33187.74602842331,0.23006636,2472,0.0759449962423577,36433.55738687515,0.20621349,0.0749445387632008,0.40232724,5348,0.1190708361895015 -3373.170571565628,1.361530303955078,34628.43957090378,41717,0,34628.43957090378,0.22757748,2472,0.0745028740885178,38004.75534749031,0.21006845,0.07618645808468,0.40071264,5348,0.1179605510875966 -3508.219400167465,1.4211549758911133,36068.67969703674,43454,0,36068.67969703674,0.22053853,2472,0.0726951435013101,39580.1800005436,0.23857479,0.0877853887956198,0.39650992,5348,0.1165413170877704 -3639.807451248169,1.488295316696167,37509.16482591629,45175,0,37509.16482591629,0.21159218,2472,0.0723701582272053,41152.39415502548,0.20972961,0.0750666768718096,0.38620195,5348,0.1124284348841924 -3776.064968347549,1.5486819744110107,38949.38499760628,46908,0,38949.38499760628,0.21320014,2472,0.0701561960473666,42729.00663161278,0.18874624,0.0710643225368265,0.38043752,5348,0.1116077893740888 -3910.092271089554,1.622194766998291,40389.63473272324,48643,0,40389.63473272324,0.20846424,2472,0.0671094591026344,44303.433913230896,0.16027203,0.0604193629076988,0.37589723,5348,0.1079872944765729 -4041.277694702149,1.6876039505004885,41829.74850130081,50360,0,41829.74850130081,0.19980165,2472,0.0668657201470558,45874.87132453919,0.17267397,0.0656883564061761,0.3635543,5348,0.105708796354403 -4175.834066867828,1.761640548706055,43270.4296848774,52092,0,43270.4296848774,0.19757481,2472,0.0647330042857433,47450.25897097588,0.15158781,0.0584736299604878,0.36148843,5348,0.1051391718238605 -4308.934076786041,1.827185869216919,44711.04065990448,53833,0,44711.04065990448,0.19231722,2472,0.0629862084374301,49024.11247205734,0.15212244,0.0581043404680024,0.35791597,5348,0.1030441121098313 -4443.142770767212,1.886054515838623,46150.963654994965,55564,0,46150.963654994965,0.18797308,2472,0.0615847094428533,50598.37662386894,0.14071093,0.0529579020013802,0.34808445,5348,0.0988153740695328 -4574.003474712372,1.948566198348999,47591.68971157074,57296,0,47591.68971157074,0.1845956,2472,0.0598988483334348,52170.102900505066,0.1408747,0.0539587697308929,0.34242028,5348,0.0970968458248452 -4707.52090382576,2.02084755897522,49031.99387717247,59032,0,49031.99387717247,0.17841998,2472,0.0572583429813336,53744.072734594345,0.14205861,0.0535074439930733,0.33538797,5348,0.0947990383965552 -4840.933924913406,2.087460994720459,50472.62706637383,60739,0,50472.62706637383,0.17275436,2472,0.0567302419109134,55318.258915901184,0.11626529,0.0448752437958168,0.33002877,5348,0.0928777624376068 -4972.952255249023,2.15358567237854,51912.70821213722,62451,0,51912.70821213722,0.17404887,2472,0.0566286840127556,56890.49936628342,0.11678836,0.0445723609146091,0.3245294,5348,0.0920571169275032 -5108.296470880508,2.219554662704468,53352.64188218117,64191,0,53352.64188218117,0.16827966,2472,0.0546381492088639,58465.91937327385,0.10589888,0.0406809355889836,0.31565297,5348,0.0894117419890516 -5238.6785979270935,2.280007123947144,54793.7216861248,65914,0,54793.7216861248,0.16650042,2472,0.0526476144049722,60037.51570367813,0.10621779,0.0402648644647767,0.3113982,5348,0.0871235892138216 -5371.498435974121,2.340143918991089,56233.85720348358,67624,0,56233.85720348358,0.16369732,2472,0.0521398249141835,61610.60500574112,0.09754225,0.0369401399789807,0.30889,5348,0.0859939948057966 -5503.210609436035,2.404115915298462,57674.4818482399,69373,0,57674.4818482399,0.16329591,2472,0.052017955436394286,63183.08140993118,0.09445868,0.03591866553025691,0.30673182,5348,0.08555953541809475 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index db6684989..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,737 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,66.43476,31.226974,,,,,,,,,,,,,, -1,,,30.63818,1.394191324990109,30.082218,1.1100051169661218,5348.0,30.067327,1.1492494871326142,2472.0,62.45750379562378,241.85189867019653,62.45750379562378,179.3943247795105,0.0,0.0 -100,0.7838031,5.9505577,,,,,,,,,,,,,, -200,0.7218252,5.8124294,,,,,,,,,,,,,, -300,2.2197638,5.793088,,,,,,,,,,,,,, -400,0.8995598,5.803171,,,,,,,,,,,,,, -500,0.4345341,5.7829814,,,,,,,,,,,,,, -600,6.347502,5.800556,,,,,,,,,,,,,, -700,2.1589062,5.573463,,,,,,,,,,,,,, -800,1.5705203,5.016434,,,,,,,,,,,,,, -900,0.84231704,3.924136,,,,,,,,,,,,,, -1000,2.3405023,3.5519474,,,,,,,,,,,,,, -1100,1.0987102,3.237439,,,,,,,,,,,,,, -1200,0.8433298,3.0859532,,,,,,,,,,,,,, -1300,1.3865834,3.1022093,,,,,,,,,,,,,, -1400,0.62995934,2.846231,,,,,,,,,,,,,, -1500,0.67828494,2.6956143,,,,,,,,,,,,,, -1600,0.96397865,2.633719,,,,,,,,,,,,,, -1700,0.75980306,2.5346625,,,,,,,,,,,,,, -1757,,,3.0410984,0.5905542625280633,3.4923468,0.6396207652278015,5348.0,3.1675358,0.5846281965348445,2472.0,1502.963063955307,1807.5719513893127,1502.963063955307,304.4973177909851,0.041710615158081,0.0 -1800,0.63673806,2.3967297,,,,,,,,,,,,,, -1900,0.8880259,2.379304,,,,,,,,,,,,,, -2000,0.73297703,2.292478,,,,,,,,,,,,,, -2100,0.7546445,2.2423832,,,,,,,,,,,,,, -2200,0.6334413,2.1592894,,,,,,,,,,,,,, -2300,1.9950361,2.1537168,,,,,,,,,,,,,, -2400,0.5230759,2.0989306,,,,,,,,,,,,,, -2500,0.6062974,2.075724,,,,,,,,,,,,,, -2600,0.6433494,2.0494986,,,,,,,,,,,,,, -2700,0.80823463,2.0032885,,,,,,,,,,,,,, -2800,0.62321395,2.0237248,,,,,,,,,,,,,, -2900,0.92592317,2.0193398,,,,,,,,,,,,,, -3000,1.0179956,1.9666897,,,,,,,,,,,,,, -3100,0.5777755,1.9717695,,,,,,,,,,,,,, -3200,0.51113826,1.8881223,,,,,,,,,,,,,, -3300,0.55000323,1.8215854,,,,,,,,,,,,,, -3400,0.6257687,1.9195797,,,,,,,,,,,,,, -3500,0.696385,1.8609334,,,,,,,,,,,,,, -3550,,,0.62109214,0.2086354313720267,0.9490379,0.2759010204968284,5348.0,0.66489387,0.212479434525623,2472.0,2942.972272872925,3378.072374105453,2942.972272872925,434.86101627349854,0.0924489498138427,0.0 -3600,0.5360621,1.878287,,,,,,,,,,,,,, -3700,0.7286329,1.9054009,,,,,,,,,,,,,, -3800,0.63236415,1.8151624,,,,,,,,,,,,,, -3900,0.6133627,1.7684894,,,,,,,,,,,,,, -4000,0.511878,1.7883213,,,,,,,,,,,,,, -4100,0.51701885,1.8523449,,,,,,,,,,,,,, -4200,0.66716367,1.7759104,,,,,,,,,,,,,, -4300,0.6549353,1.8180468,,,,,,,,,,,,,, -4400,0.49734506,1.7046841,,,,,,,,,,,,,, -4500,0.45376143,1.6996561,,,,,,,,,,,,,, -4600,0.5712273,1.724109,,,,,,,,,,,,,, -4700,0.53129953,1.6771404,,,,,,,,,,,,,, -4800,0.6112358,1.6608028,,,,,,,,,,,,,, -4900,0.54209256,1.7592567,,,,,,,,,,,,,, -5000,0.50148153,1.7135218,,,,,,,,,,,,,, -5100,0.56563175,1.7316558,,,,,,,,,,,,,, -5200,0.64386994,1.7090571,,,,,,,,,,,,,, -5300,0.5178512,1.6177465,,,,,,,,,,,,,, -5317,,,0.41384113,0.1480673234551502,0.7433237,0.2233410892379582,5348.0,0.49097118,0.1631222960209615,2472.0,4383.247010231018,4950.232921600342,4383.247010231018,566.6209635734558,0.1419525146484375,0.0 -5400,0.6193933,1.6991044,,,,,,,,,,,,,, -5500,0.51688856,1.7002528,,,,,,,,,,,,,, -5600,0.47352225,1.6998394,,,,,,,,,,,,,, -5700,0.5544598,1.6397225,,,,,,,,,,,,,, -5800,0.53526413,1.6880187,,,,,,,,,,,,,, -5900,0.4771771,1.6581149,,,,,,,,,,,,,, -6000,0.437725,1.6829376,,,,,,,,,,,,,, -6100,0.51451254,1.6105434,,,,,,,,,,,,,, -6200,0.61802506,1.5959342,,,,,,,,,,,,,, -6300,0.5953034,1.6740454,,,,,,,,,,,,,, -6400,0.50690454,1.6312987,,,,,,,,,,,,,, -6500,0.68005383,1.6261749,,,,,,,,,,,,,, -6600,0.4536917,1.6346202,,,,,,,,,,,,,, -6700,0.45188332,1.6298735,,,,,,,,,,,,,, -6800,0.5180606,1.5379285,,,,,,,,,,,,,, -6900,0.4758267,1.5942097,,,,,,,,,,,,,, -7000,0.42600587,1.597694,,,,,,,,,,,,,, -7055,,,0.37714663,0.134654253369498,0.67645335,0.2034042306689709,5348.0,0.43877867,0.1471370828509333,2472.0,5823.124886751175,6523.570137262344,5823.124886751175,699.9576814174652,0.1899044513702392,0.0 -7100,0.44075137,1.6637926,,,,,,,,,,,,,, -7200,0.48162603,1.6815469,,,,,,,,,,,,,, -7300,0.47042432,1.5975072,,,,,,,,,,,,,, -7400,0.5067065,1.5857815,,,,,,,,,,,,,, -7500,0.7956273,1.5957887,,,,,,,,,,,,,, -7600,0.42591813,1.4966384,,,,,,,,,,,,,, -7700,0.48084706,1.596031,,,,,,,,,,,,,, -7800,0.5353084,1.5896327,,,,,,,,,,,,,, -7900,0.4895227,1.5768367,,,,,,,,,,,,,, -8000,0.56921905,1.6112356,,,,,,,,,,,,,, -8100,0.42040405,1.5756298,,,,,,,,,,,,,, -8200,0.44452095,1.5655401,,,,,,,,,,,,,, -8300,0.66391283,1.5903795,,,,,,,,,,,,,, -8400,0.48498815,1.5078261,,,,,,,,,,,,,, -8500,0.4298941,1.6108797,,,,,,,,,,,,,, -8600,0.47487733,1.5584422,,,,,,,,,,,,,, -8700,0.479187,1.5203621,,,,,,,,,,,,,, -8792,,,0.33957744,0.1227793602818065,0.63309586,0.1919151935275206,5348.0,0.39292622,0.1334877013385331,2472.0,7263.046508550644,8095.271713018417,7263.046508550644,831.6092147827148,0.2443428039550781,0.0 -8800,0.5510005,1.5799446,,,,,,,,,,,,,, -8900,0.6483215,1.5599177,,,,,,,,,,,,,, -9000,0.7681776,1.5279785,,,,,,,,,,,,,, -9100,0.5093946,1.6159234,,,,,,,,,,,,,, -9200,0.6610401,1.5945231,,,,,,,,,,,,,, -9300,0.62556714,1.4790921,,,,,,,,,,,,,, -9400,0.5247631,1.5753361,,,,,,,,,,,,,, -9500,0.5291571,1.5107703,,,,,,,,,,,,,, -9600,0.48933366,1.5643654,,,,,,,,,,,,,, -9700,0.4728441,1.5342182,,,,,,,,,,,,,, -9800,0.46766454,1.5381106,,,,,,,,,,,,,, -9900,0.5078729,1.5254681,,,,,,,,,,,,,, -10000,0.53114235,1.5143332,,,,,,,,,,,,,, -10100,0.46743596,1.6076816,,,,,,,,,,,,,, -10200,0.4846272,1.4918909,,,,,,,,,,,,,, -10300,0.5320031,1.4714968,,,,,,,,,,,,,, -10400,0.56887335,1.477861,,,,,,,,,,,,,, -10500,0.541687,1.5509682,,,,,,,,,,,,,, -10525,,,0.34834158,0.1233335645048572,0.6127697,0.1846355851202487,5348.0,0.38322657,0.1270692421749639,2472.0,8702.932670116425,9668.075123786926,8702.932670116425,964.390949010849,0.3064408302307129,0.0 -10600,0.4302563,1.5032381,,,,,,,,,,,,,, -10700,0.6318899,1.5025728,,,,,,,,,,,,,, -10800,0.44682997,1.4366404,,,,,,,,,,,,,, -10900,0.56209546,1.5138429,,,,,,,,,,,,,, -11000,0.54899085,1.5047538,,,,,,,,,,,,,, -11100,0.71625465,1.5010237,,,,,,,,,,,,,, -11200,0.56083596,1.5429059,,,,,,,,,,,,,, -11300,0.65368116,1.4621985,,,,,,,,,,,,,, -11400,0.47248346,1.4932965,,,,,,,,,,,,,, -11500,0.53895813,1.4983485,,,,,,,,,,,,,, -11600,0.49556202,1.4801619,,,,,,,,,,,,,, -11700,0.5436151,1.5340346,,,,,,,,,,,,,, -11800,0.47978133,1.4749944,,,,,,,,,,,,,, -11900,0.55258256,1.5143893,,,,,,,,,,,,,, -12000,0.5271204,1.4102654,,,,,,,,,,,,,, -12100,0.51011974,1.4739007,,,,,,,,,,,,,, -12200,0.6388358,1.4742635,,,,,,,,,,,,,, -12256,,,0.3036773,0.1103630910599196,0.5770447,0.1742375237745831,5348.0,0.35471162,0.1210570146040257,2472.0,10143.555659532549,11241.953930139542,10143.555659532549,1097.5230059623718,0.357492446899414,0.0 -12300,0.57884514,1.4903238,,,,,,,,,,,,,, -12400,0.5028015,1.4479643,,,,,,,,,,,,,, -12500,0.5990213,1.4429195,,,,,,,,,,,,,, -12600,0.560973,1.4450747,,,,,,,,,,,,,, -12700,0.45125717,1.4252381,,,,,,,,,,,,,, -12800,0.5525917,1.5293695,,,,,,,,,,,,,, -12900,0.6440034,1.4456079,,,,,,,,,,,,,, -13000,0.78241146,1.4881157,,,,,,,,,,,,,, -13100,0.53233314,1.4299592,,,,,,,,,,,,,, -13200,0.54119164,1.5119826,,,,,,,,,,,,,, -13300,0.46406186,1.493184,,,,,,,,,,,,,, -13400,0.5681821,1.3739316,,,,,,,,,,,,,, -13500,0.460429,1.4808983,,,,,,,,,,,,,, -13600,0.43332946,1.454046,,,,,,,,,,,,,, -13700,0.52262133,1.4887127,,,,,,,,,,,,,, -13800,0.5060225,1.4306792,,,,,,,,,,,,,, -13900,0.4403455,1.463208,,,,,,,,,,,,,, -13990,,,0.26289067,0.0994409631035648,0.55949336,0.1682902574895971,5348.0,0.34388256,0.1162025470720858,2472.0,11583.86165523529,12816.645854234695,11583.86165523529,1231.768848657608,0.4228217601776123,0.0 -14000,0.55344796,1.4323987,,,,,,,,,,,,,, -14100,0.43238023,1.4430988,,,,,,,,,,,,,, -14200,0.5293006,1.4640253,,,,,,,,,,,,,, -14300,0.40371466,1.3929391,,,,,,,,,,,,,, -14400,0.520168,1.3960152,,,,,,,,,,,,,, -14500,0.47448528,1.4769963,,,,,,,,,,,,,, -14600,0.4213053,1.46404,,,,,,,,,,,,,, -14700,0.49888927,1.4118149,,,,,,,,,,,,,, -14800,0.59397817,1.408821,,,,,,,,,,,,,, -14900,0.6973314,1.4205842,,,,,,,,,,,,,, -15000,0.6546131,1.459527,,,,,,,,,,,,,, -15100,0.45406187,1.4015075,,,,,,,,,,,,,, -15200,0.47790855,1.3860483,,,,,,,,,,,,,, -15300,0.53847367,1.4408785,,,,,,,,,,,,,, -15400,0.50829625,1.4686712,,,,,,,,,,,,,, -15500,0.4102264,1.384661,,,,,,,,,,,,,, -15600,0.6400864,1.4122933,,,,,,,,,,,,,, -15700,0.6948603,1.475569,,,,,,,,,,,,,, -15738,,,0.24273707,0.0902063638153099,0.5420629,0.1638973903472778,5348.0,0.32644108,0.1121808543050393,2472.0,13023.903942584991,14395.793190717695,13023.903942584991,1370.736781835556,0.48415207862854,0.0 -15800,0.60495347,1.4891187,,,,,,,,,,,,,, -15900,0.5160583,1.4888501,,,,,,,,,,,,,, -16000,0.51828074,1.4116024,,,,,,,,,,,,,, -16100,0.4812649,1.4766034,,,,,,,,,,,,,, -16200,0.49546027,1.4172058,,,,,,,,,,,,,, -16300,0.5592415,1.4612253,,,,,,,,,,,,,, -16400,0.51252663,1.3984375,,,,,,,,,,,,,, -16500,0.5348491,1.3753872,,,,,,,,,,,,,, -16600,0.644839,1.4034455,,,,,,,,,,,,,, -16700,0.50460786,1.4062039,,,,,,,,,,,,,, -16800,0.638582,1.443363,,,,,,,,,,,,,, -16900,0.44388276,1.3691796,,,,,,,,,,,,,, -17000,0.4631363,1.4460979,,,,,,,,,,,,,, -17100,0.56005895,1.4054108,,,,,,,,,,,,,, -17200,0.51495826,1.372159,,,,,,,,,,,,,, -17300,0.6238552,1.447281,,,,,,,,,,,,,, -17400,0.49322298,1.3626883,,,,,,,,,,,,,, -17455,,,0.24323381,0.0948961478438544,0.53616065,0.1624298830821514,5348.0,0.31930014,0.1086060162898868,2472.0,14464.083364963531,15967.672088384628,14464.083364963531,1502.307198524475,0.5404567718505859,0.0 -17500,0.7085062,1.4399006,,,,,,,,,,,,,, -17600,0.55515546,1.369484,,,,,,,,,,,,,, -17700,0.5134875,1.3855152,,,,,,,,,,,,,, -17800,0.49535236,1.3996118,,,,,,,,,,,,,, -17900,0.51610327,1.3882433,,,,,,,,,,,,,, -18000,0.55583024,1.4107603,,,,,,,,,,,,,, -18100,0.53421086,1.4292067,,,,,,,,,,,,,, -18200,0.5564167,1.4250641,,,,,,,,,,,,,, -18300,0.532336,1.3862201,,,,,,,,,,,,,, -18400,0.51969415,1.3744303,,,,,,,,,,,,,, -18500,0.47491175,1.3212425,,,,,,,,,,,,,, -18600,0.54648674,1.38249,,,,,,,,,,,,,, -18700,0.5163671,1.3601042,,,,,,,,,,,,,, -18800,0.513987,1.3918468,,,,,,,,,,,,,, -18900,0.4441589,1.3658202,,,,,,,,,,,,,, -19000,0.47474802,1.3744924,,,,,,,,,,,,,, -19100,0.58016,1.3832626,,,,,,,,,,,,,, -19158,,,0.25264207,0.0925830931542346,0.5239434,0.1574867007154098,5348.0,0.31532496,0.1049296203765766,2472.0,15904.700827360151,17542.432443857193,15904.700827360151,1636.3237302303314,0.5938196182250977,0.0 -19200,0.59706616,1.4131126,,,,,,,,,,,,,, -19300,0.50094146,1.3407378,,,,,,,,,,,,,, -19400,0.64811486,1.4278564,,,,,,,,,,,,,, -19500,0.5869367,1.3697529,,,,,,,,,,,,,, -19600,0.5281713,1.31494,,,,,,,,,,,,,, -19700,0.5745431,1.4008265,,,,,,,,,,,,,, -19800,0.5088556,1.3400468,,,,,,,,,,,,,, -19900,0.47311524,1.366344,,,,,,,,,,,,,, -20000,0.45713606,1.363234,,,,,,,,,,,,,, -20100,0.4836165,1.3743472,,,,,,,,,,,,,, -20200,0.5368255,1.4220256,,,,,,,,,,,,,, -20300,0.5323174,1.3393629,,,,,,,,,,,,,, -20400,0.5007142,1.336783,,,,,,,,,,,,,, -20500,0.5841394,1.3735942,,,,,,,,,,,,,, -20600,0.49994993,1.3734523,,,,,,,,,,,,,, -20700,0.49108145,1.3217092,,,,,,,,,,,,,, -20800,0.43745708,1.3309946,,,,,,,,,,,,,, -20900,0.54295176,1.3949003,,,,,,,,,,,,,, -20916,,,0.24772409,0.0907958952172422,0.5120749,0.1534800196954922,5348.0,0.3009962,0.1009282391891617,2472.0,17345.316667318344,19118.36318206787,17345.316667318344,1771.5046956539154,0.6519536972045898,0.0 -21000,0.5312882,1.3398668,,,,,,,,,,,,,, -21100,0.5317217,1.3277729,,,,,,,,,,,,,, -21200,0.47746527,1.3392876,,,,,,,,,,,,,, -21300,0.46740034,1.3182269,,,,,,,,,,,,,, -21400,0.4768095,1.3639855,,,,,,,,,,,,,, -21500,0.5460165,1.3762953,,,,,,,,,,,,,, -21600,0.6225601,1.3667809,,,,,,,,,,,,,, -21700,0.520254,1.3874066,,,,,,,,,,,,,, -21800,0.4977649,1.3345543,,,,,,,,,,,,,, -21900,0.5825542,1.3276596,,,,,,,,,,,,,, -22000,0.5403352,1.3191082,,,,,,,,,,,,,, -22100,0.46802807,1.3236189,,,,,,,,,,,,,, -22200,0.44526136,1.3839455,,,,,,,,,,,,,, -22300,0.5629272,1.3801638,,,,,,,,,,,,,, -22400,0.5033186,1.3768299,,,,,,,,,,,,,, -22500,0.43618447,1.3463935,,,,,,,,,,,,,, -22600,0.5476524,1.3224736,,,,,,,,,,,,,, -22635,,,0.22900142,0.0860479473196831,0.5003814,0.1506029330835996,5348.0,0.29497865,0.0990189507037962,2472.0,18785.592749118805,20690.474380731583,18785.592749118805,1903.2082245349884,0.7093734741210938,0.0 -22700,0.48407385,1.3147857,,,,,,,,,,,,,, -22800,0.5391249,1.3424357,,,,,,,,,,,,,, -22900,0.5210363,1.3144118,,,,,,,,,,,,,, -23000,0.51803684,1.3533342,,,,,,,,,,,,,, -23100,0.44334862,1.3094169,,,,,,,,,,,,,, -23200,0.61350536,1.3870369,,,,,,,,,,,,,, -23300,0.5060705,1.3144908,,,,,,,,,,,,,, -23400,0.52719134,1.3513899,,,,,,,,,,,,,, -23500,0.53730994,1.3789004,,,,,,,,,,,,,, -23600,0.43998778,1.354782,,,,,,,,,,,,,, -23700,0.4720217,1.29813,,,,,,,,,,,,,, -23800,0.46860558,1.3324872,,,,,,,,,,,,,, -23900,0.6576236,1.3825123,,,,,,,,,,,,,, -24000,0.43386337,1.328856,,,,,,,,,,,,,, -24100,0.54900724,1.3573346,,,,,,,,,,,,,, -24200,0.47810292,1.3160962,,,,,,,,,,,,,, -24300,0.5160125,1.3410238,,,,,,,,,,,,,, -24368,,,0.21313213,0.079963895105781,0.48328742,0.1474651708390859,5348.0,0.28867492,0.0962971990331688,2472.0,20225.857773780823,22265.06233000756,20225.857773780823,2037.400043010712,0.7659990787506104,0.0 -24400,0.49547306,1.3769611,,,,,,,,,,,,,, -24500,0.6220373,1.3495952,,,,,,,,,,,,,, -24600,0.57096934,1.2973062,,,,,,,,,,,,,, -24700,0.60041374,1.371113,,,,,,,,,,,,,, -24800,0.5905934,1.329374,,,,,,,,,,,,,, -24900,0.68105686,1.3285413,,,,,,,,,,,,,, -25000,0.4132523,1.2963439,,,,,,,,,,,,,, -25100,0.57847035,1.3034935,,,,,,,,,,,,,, -25200,0.506343,1.3396032,,,,,,,,,,,,,, -25300,0.50638866,1.2965842,,,,,,,,,,,,,, -25400,0.5068908,1.3086226,,,,,,,,,,,,,, -25500,0.459648,1.2965955,,,,,,,,,,,,,, -25600,0.6544444,1.3127043,,,,,,,,,,,,,, -25700,0.5217365,1.3900855,,,,,,,,,,,,,, -25800,0.48186556,1.225316,,,,,,,,,,,,,, -25900,0.53769684,1.3406929,,,,,,,,,,,,,, -26000,0.5256762,1.2458938,,,,,,,,,,,,,, -26100,0.6385783,1.3279524,,,,,,,,,,,,,, -26122,,,0.2005794,0.0753592947129554,0.47255135,0.1411896463500584,5348.0,0.27718583,0.093006723132858,2472.0,21666.32645058632,23840.28493332863,21666.32645058632,2172.017728328705,0.8262643814086914,0.0 -26200,0.5920971,1.2971725,,,,,,,,,,,,,, -26300,0.7403679,1.3523008,,,,,,,,,,,,,, -26400,0.64027876,1.3224863,,,,,,,,,,,,,, -26500,0.47114483,1.3014729,,,,,,,,,,,,,, -26600,0.5167574,1.2894834,,,,,,,,,,,,,, -26700,0.53980637,1.3336407,,,,,,,,,,,,,, -26800,0.6072627,1.3109342,,,,,,,,,,,,,, -26900,0.5083011,1.3135467,,,,,,,,,,,,,, -27000,0.51451564,1.2356708,,,,,,,,,,,,,, -27100,0.5984484,1.363182,,,,,,,,,,,,,, -27200,0.672499,1.2537355,,,,,,,,,,,,,, -27300,0.5364679,1.2759907,,,,,,,,,,,,,, -27400,0.5477739,1.2877457,,,,,,,,,,,,,, -27500,0.46029145,1.2759439,,,,,,,,,,,,,, -27600,0.5188593,1.3021922,,,,,,,,,,,,,, -27700,0.47038928,1.2703056,,,,,,,,,,,,,, -27800,0.5797614,1.2770361,,,,,,,,,,,,,, -27836,,,0.19883163,0.0731044992891552,0.46602303,0.1379360282688241,5348.0,0.2700468,0.0896553124936526,2472.0,23106.63402581215,25414.62991690636,23106.63402581215,2305.922488927841,0.885556697845459,0.0 -27900,0.65839577,1.2958915,,,,,,,,,,,,,, -28000,0.5282226,1.2514662,,,,,,,,,,,,,, -28100,0.50305456,1.313605,,,,,,,,,,,,,, -28200,0.5016173,1.2588546,,,,,,,,,,,,,, -28300,0.535055,1.3151495,,,,,,,,,,,,,, -28400,0.5706073,1.2464302,,,,,,,,,,,,,, -28500,0.5719452,1.2805222,,,,,,,,,,,,,, -28600,0.49181265,1.2839155,,,,,,,,,,,,,, -28700,0.66089016,1.3034527,,,,,,,,,,,,,, -28800,0.552605,1.287313,,,,,,,,,,,,,, -28900,0.7557163,1.2384388,,,,,,,,,,,,,, -29000,0.4933393,1.29766,,,,,,,,,,,,,, -29100,0.6087339,1.3545284,,,,,,,,,,,,,, -29200,0.550644,1.2760099,,,,,,,,,,,,,, -29300,0.4668103,1.2703134,,,,,,,,,,,,,, -29400,0.576842,1.2960753,,,,,,,,,,,,,, -29500,0.4770444,1.2740852,,,,,,,,,,,,,, -29580,,,0.20785119,0.0773260097070549,0.4579564,0.1361112988404761,5348.0,0.26341733,0.0888631608880222,2472.0,24546.826768159863,26987.76125073433,24546.826768159863,2438.7332191467285,0.9390220642089844,0.0 -29600,0.5116694,1.2843987,,,,,,,,,,,,,, -29700,0.5949318,1.2901005,,,,,,,,,,,,,, -29800,0.72734046,1.2199229,,,,,,,,,,,,,, -29900,0.55322427,1.2526722,,,,,,,,,,,,,, -30000,0.9295303,1.2796878,,,,,,,,,,,,,, -30100,0.72102547,1.2628983,,,,,,,,,,,,,, -30200,0.5227761,1.272335,,,,,,,,,,,,,, -30300,0.46907744,1.285481,,,,,,,,,,,,,, -30400,0.62351114,1.2465875,,,,,,,,,,,,,, -30500,0.54219097,1.2781146,,,,,,,,,,,,,, -30600,0.5191525,1.2590599,,,,,,,,,,,,,, -30700,0.5617512,1.3170187,,,,,,,,,,,,,, -30800,0.503858,1.3054297,,,,,,,,,,,,,, -30900,0.63518906,1.2815939,,,,,,,,,,,,,, -31000,0.7386463,1.2164775,,,,,,,,,,,,,, -31100,0.4740039,1.2309572,,,,,,,,,,,,,, -31200,0.5770629,1.2788547,,,,,,,,,,,,,, -31300,0.5277377,1.2592804,,,,,,,,,,,,,, -31326,,,0.18994193,0.0693275601061318,0.4489827,0.1332149029224635,5348.0,0.25740403,0.0864663944914996,2472.0,25986.92082810402,28560.95162129402,25986.92082810402,2571.6975836753845,0.9967617988586426,0.0 -31400,0.4827551,1.2480774,,,,,,,,,,,,,, -31500,0.65578866,1.2635287,,,,,,,,,,,,,, -31600,0.48438698,1.2984756,,,,,,,,,,,,,, -31700,0.57489824,1.2420317,,,,,,,,,,,,,, -31800,0.58697623,1.2082771,,,,,,,,,,,,,, -31900,0.546671,1.2473512,,,,,,,,,,,,,, -32000,0.47187486,1.2867563,,,,,,,,,,,,,, -32100,0.7177878,1.1920209,,,,,,,,,,,,,, -32200,0.5525111,1.2150694,,,,,,,,,,,,,, -32300,0.47051895,1.2661675,,,,,,,,,,,,,, -32400,0.6096362,1.2764137,,,,,,,,,,,,,, -32500,0.4444394,1.2325068,,,,,,,,,,,,,, -32600,0.49144715,1.2033513,,,,,,,,,,,,,, -32700,0.49141058,1.2331114,,,,,,,,,,,,,, -32800,0.57265633,1.218123,,,,,,,,,,,,,, -32900,0.5620572,1.2963058,,,,,,,,,,,,,, -33000,0.8353186,1.2349093,,,,,,,,,,,,,, -33048,,,0.20247611,0.0736891852323688,0.4474757,0.1316894677389768,5348.0,0.25247926,0.0856539313062376,2472.0,27426.860374689106,30135.734421014786,27426.860374689106,2706.4022014141083,1.061279058456421,0.0 -33100,0.50481194,1.2733117,,,,,,,,,,,,,, -33200,0.4222852,1.2108123,,,,,,,,,,,,,, -33300,0.5623525,1.2955589,,,,,,,,,,,,,, -33400,0.5818646,1.2136883,,,,,,,,,,,,,, -33500,0.7025109,1.2037406,,,,,,,,,,,,,, -33600,0.5264549,1.2338204,,,,,,,,,,,,,, -33700,0.6212981,1.271501,,,,,,,,,,,,,, -33800,0.5432763,1.2360815,,,,,,,,,,,,,, -33900,0.57243246,1.2263334,,,,,,,,,,,,,, -34000,0.5045817,1.1956393,,,,,,,,,,,,,, -34100,0.6066714,1.2226933,,,,,,,,,,,,,, -34200,0.4693825,1.2422116,,,,,,,,,,,,,, -34300,0.5556511,1.2272683,,,,,,,,,,,,,, -34400,0.51523304,1.2523764,,,,,,,,,,,,,, -34500,0.57257754,1.2711222,,,,,,,,,,,,,, -34600,0.5475732,1.1692196,,,,,,,,,,,,,, -34700,0.67289674,1.2777542,,,,,,,,,,,,,, -34778,,,0.19859408,0.0700789667188848,0.44250348,0.1293433870453865,5348.0,0.25044668,0.0848820912802388,2472.0,28867.33429455757,31709.335352897644,28867.33429455757,2839.3952023983,1.1202399730682373,0.0 -34800,0.54875004,1.2693745,,,,,,,,,,,,,, -34900,0.60525006,1.1977907,,,,,,,,,,,,,, -35000,0.66254187,1.2775897,,,,,,,,,,,,,, -35100,0.50973845,1.2219461,,,,,,,,,,,,,, -35200,0.4853522,1.2060559,,,,,,,,,,,,,, -35300,0.69735014,1.1840985,,,,,,,,,,,,,, -35400,0.5105356,1.274924,,,,,,,,,,,,,, -35500,0.49488,1.2349733,,,,,,,,,,,,,, -35600,0.6942523,1.2346755,,,,,,,,,,,,,, -35700,0.75787216,1.244092,,,,,,,,,,,,,, -35800,0.464746,1.1530535,,,,,,,,,,,,,, -35900,0.575272,1.2510009,,,,,,,,,,,,,, -36000,0.53742045,1.2713864,,,,,,,,,,,,,, -36100,0.56937474,1.2282627,,,,,,,,,,,,,, -36200,0.4963139,1.2347664,,,,,,,,,,,,,, -36300,0.5630982,1.2173288,,,,,,,,,,,,,, -36400,0.57464635,1.1602206,,,,,,,,,,,,,, -36500,0.65090036,1.2628975,,,,,,,,,,,,,, -36520,,,0.1440633,0.0544417453337843,0.42724466,0.1259546038213116,5348.0,0.24829671,0.0823837669855584,2472.0,30307.208152771,33287.83837580681,30307.208152771,2977.8900032043457,1.179589033126831,0.0 -36600,0.5149529,1.2508528,,,,,,,,,,,,,, -36700,0.55479634,1.2384484,,,,,,,,,,,,,, -36800,0.6563373,1.174265,,,,,,,,,,,,,, -36900,0.5389844,1.1816524,,,,,,,,,,,,,, -37000,0.5572661,1.2474716,,,,,,,,,,,,,, -37100,0.9307299,1.169766,,,,,,,,,,,,,, -37200,0.6434611,1.2126968,,,,,,,,,,,,,, -37300,0.88076764,1.1694038,,,,,,,,,,,,,, -37400,0.49846837,1.2856368,,,,,,,,,,,,,, -37500,0.6368635,1.2084167,,,,,,,,,,,,,, -37600,0.6391499,1.2166299,,,,,,,,,,,,,, -37700,0.59211904,1.1759093,,,,,,,,,,,,,, -37800,0.45546925,1.1924378,,,,,,,,,,,,,, -37900,0.6637939,1.2222612,,,,,,,,,,,,,, -38000,0.5613109,1.183649,,,,,,,,,,,,,, -38100,0.6445896,1.2430056,,,,,,,,,,,,,, -38200,0.57873875,1.1973428,,,,,,,,,,,,,, -38245,,,0.1679655,0.0616090473274059,0.41607982,0.1227878776176178,5348.0,0.24066494,0.0790729795056161,2472.0,31747.13672208786,34860.952849149704,31747.13672208786,3110.9470710754395,1.235142946243286,0.0 -38300,0.6491798,1.1730534,,,,,,,,,,,,,, -38400,0.46697354,1.1535307,,,,,,,,,,,,,, -38500,0.6001922,1.1746098,,,,,,,,,,,,,, -38600,0.6301047,1.1674976,,,,,,,,,,,,,, -38700,0.5381231,1.2471297,,,,,,,,,,,,,, -38800,0.5411197,1.2295256,,,,,,,,,,,,,, -38900,0.7329119,1.1722893,,,,,,,,,,,,,, -39000,0.5918905,1.159451,,,,,,,,,,,,,, -39100,0.5056328,1.1731544,,,,,,,,,,,,,, -39200,0.512301,1.123189,,,,,,,,,,,,,, -39300,0.5898733,1.2213457,,,,,,,,,,,,,, -39400,0.5402269,1.166597,,,,,,,,,,,,,, -39500,0.5759264,1.1805096,,,,,,,,,,,,,, -39600,0.57252425,1.1695576,,,,,,,,,,,,,, -39700,0.51343507,1.1649994,,,,,,,,,,,,,, -39800,0.5192453,1.1760851,,,,,,,,,,,,,, -39900,0.62285024,1.2274557,,,,,,,,,,,,,, -39983,,,0.20621349,0.0749445387632008,0.40232724,0.1190708361895015,5348.0,0.23006636,0.0759449962423577,2472.0,33187.74602842331,36433.55738687515,33187.74602842331,3242.813079357147,1.2900941371917725,0.0 -40000,0.6280156,1.2054797,,,,,,,,,,,,,, -40100,0.46410072,1.1587113,,,,,,,,,,,,,, -40200,0.59246963,1.1933959,,,,,,,,,,,,,, -40300,0.58421916,1.1724792,,,,,,,,,,,,,, -40400,0.5694514,1.1203735,,,,,,,,,,,,,, -40500,0.56457746,1.2127506,,,,,,,,,,,,,, -40600,0.5567085,1.1944348,,,,,,,,,,,,,, -40700,0.6243162,1.1842446,,,,,,,,,,,,,, -40800,0.55893385,1.2016892,,,,,,,,,,,,,, -40900,0.5259906,1.1878482,,,,,,,,,,,,,, -41000,0.53877074,1.1767123,,,,,,,,,,,,,, -41100,0.7776817,1.1201736,,,,,,,,,,,,,, -41200,0.52873725,1.188749,,,,,,,,,,,,,, -41300,0.5833874,1.1237748,,,,,,,,,,,,,, -41400,0.46832874,1.125867,,,,,,,,,,,,,, -41500,0.6897716,1.1275166,,,,,,,,,,,,,, -41600,0.591734,1.12138,,,,,,,,,,,,,, -41700,0.5202558,1.2301269,,,,,,,,,,,,,, -41717,,,0.21006845,0.07618645808468,0.40071264,0.1179605510875966,5348.0,0.22757748,0.0745028740885178,2472.0,34628.43957090378,38004.75534749031,34628.43957090378,3373.170571565628,1.361530303955078,0.0 -41800,0.64716166,1.1405823,,,,,,,,,,,,,, -41900,0.6451994,1.1877897,,,,,,,,,,,,,, -42000,0.58801454,1.1777503,,,,,,,,,,,,,, -42100,0.6369366,1.1607709,,,,,,,,,,,,,, -42200,0.6662062,1.1847539,,,,,,,,,,,,,, -42300,0.5359542,1.1472524,,,,,,,,,,,,,, -42400,0.6831229,1.1831654,,,,,,,,,,,,,, -42500,0.62824285,1.1537933,,,,,,,,,,,,,, -42600,0.5753578,1.1677647,,,,,,,,,,,,,, -42700,0.5152304,1.1984986,,,,,,,,,,,,,, -42800,0.7299755,1.1558279,,,,,,,,,,,,,, -42900,0.6099189,1.208598,,,,,,,,,,,,,, -43000,0.5474599,1.1757913,,,,,,,,,,,,,, -43100,0.6326805,1.1231667,,,,,,,,,,,,,, -43200,0.48364502,1.173375,,,,,,,,,,,,,, -43300,0.6589574,1.078591,,,,,,,,,,,,,, -43400,0.6571178,1.1622216,,,,,,,,,,,,,, -43454,,,0.23857479,0.0877853887956198,0.39650992,0.1165413170877704,5348.0,0.22053853,0.0726951435013101,2472.0,36068.67969703674,39580.1800005436,36068.67969703674,3508.219400167465,1.4211549758911133,0.0 -43500,0.5493897,1.2067224,,,,,,,,,,,,,, -43600,0.68679553,1.1684759,,,,,,,,,,,,,, -43700,0.8353188,1.1542729,,,,,,,,,,,,,, -43800,0.5578307,1.1650878,,,,,,,,,,,,,, -43900,0.6246343,1.1978415,,,,,,,,,,,,,, -44000,0.5856543,1.1898352,,,,,,,,,,,,,, -44100,0.56699693,1.1613891,,,,,,,,,,,,,, -44200,0.67427737,1.1690471,,,,,,,,,,,,,, -44300,0.6305688,1.1594037,,,,,,,,,,,,,, -44400,0.54899853,1.2197641,,,,,,,,,,,,,, -44500,0.5264306,1.152784,,,,,,,,,,,,,, -44600,0.559471,1.1316621,,,,,,,,,,,,,, -44700,0.6321211,1.1503481,,,,,,,,,,,,,, -44800,0.5342835,1.1432769,,,,,,,,,,,,,, -44900,0.6190756,1.1558164,,,,,,,,,,,,,, -45000,0.49937087,1.1301775,,,,,,,,,,,,,, -45100,0.55285746,1.1571971,,,,,,,,,,,,,, -45175,,,0.20972961,0.0750666768718096,0.38620195,0.1124284348841924,5348.0,0.21159218,0.0723701582272053,2472.0,37509.16482591629,41152.39415502548,37509.16482591629,3639.807451248169,1.488295316696167,0.0 -45200,0.625495,1.1821231,,,,,,,,,,,,,, -45300,0.56095976,1.1992358,,,,,,,,,,,,,, -45400,0.5634025,1.1347777,,,,,,,,,,,,,, -45500,0.7572363,1.1326745,,,,,,,,,,,,,, -45600,0.50567365,1.1001979,,,,,,,,,,,,,, -45700,0.58193535,1.0899986,,,,,,,,,,,,,, -45800,0.65119237,1.152145,,,,,,,,,,,,,, -45900,0.6226317,1.1211438,,,,,,,,,,,,,, -46000,0.5393624,1.1500789,,,,,,,,,,,,,, -46100,0.5867923,1.1510957,,,,,,,,,,,,,, -46200,0.72911805,1.1792811,,,,,,,,,,,,,, -46300,0.51677644,1.1075925,,,,,,,,,,,,,, -46400,0.67271334,1.107747,,,,,,,,,,,,,, -46500,0.5100553,1.1197488,,,,,,,,,,,,,, -46600,0.5358007,1.1206118,,,,,,,,,,,,,, -46700,0.8446669,1.0953856,,,,,,,,,,,,,, -46800,0.616476,1.1584767,,,,,,,,,,,,,, -46900,0.6314906,1.1506572,,,,,,,,,,,,,, -46908,,,0.18874624,0.0710643225368265,0.38043752,0.1116077893740888,5348.0,0.21320014,0.0701561960473666,2472.0,38949.38499760628,42729.00663161278,38949.38499760628,3776.064968347549,1.5486819744110107,0.0 -47000,0.56533754,1.1105927,,,,,,,,,,,,,, -47100,0.6157799,1.1513963,,,,,,,,,,,,,, -47200,0.71113783,1.1439877,,,,,,,,,,,,,, -47300,0.54710615,1.1381797,,,,,,,,,,,,,, -47400,0.46094167,1.0353385,,,,,,,,,,,,,, -47500,0.6432339,1.0882077,,,,,,,,,,,,,, -47600,0.6310731,1.1555157,,,,,,,,,,,,,, -47700,0.61435896,1.0983604,,,,,,,,,,,,,, -47800,0.5646136,1.1037236,,,,,,,,,,,,,, -47900,0.5756928,1.1292322,,,,,,,,,,,,,, -48000,0.54314417,1.154268,,,,,,,,,,,,,, -48100,0.48691648,1.062645,,,,,,,,,,,,,, -48200,0.6743855,1.1202527,,,,,,,,,,,,,, -48300,0.5960751,1.0973643,,,,,,,,,,,,,, -48400,0.65986514,1.1411439,,,,,,,,,,,,,, -48500,0.7364598,1.1004785,,,,,,,,,,,,,, -48600,0.6882841,1.0909315,,,,,,,,,,,,,, -48643,,,0.16027203,0.0604193629076988,0.37589723,0.1079872944765729,5348.0,0.20846424,0.0671094591026344,2472.0,40389.63473272324,44303.433913230896,40389.63473272324,3910.092271089554,1.622194766998291,0.0 -48700,0.5839285,1.1442138,,,,,,,,,,,,,, -48800,0.6431374,1.073479,,,,,,,,,,,,,, -48900,0.64405334,1.1001887,,,,,,,,,,,,,, -49000,0.627175,1.0879263,,,,,,,,,,,,,, -49100,0.609642,1.1493844,,,,,,,,,,,,,, -49200,0.60432,1.041551,,,,,,,,,,,,,, -49300,0.6103499,1.1489624,,,,,,,,,,,,,, -49400,0.6716178,1.0955439,,,,,,,,,,,,,, -49500,0.57325304,1.1000665,,,,,,,,,,,,,, -49600,0.6682005,1.131187,,,,,,,,,,,,,, -49700,0.57345337,1.110543,,,,,,,,,,,,,, -49800,0.6535083,1.1130947,,,,,,,,,,,,,, -49900,0.56107354,1.0865538,,,,,,,,,,,,,, -50000,0.6764444,1.1076868,,,,,,,,,,,,,, -50100,0.80641747,1.0702938,,,,,,,,,,,,,, -50200,0.7177496,1.0886766,,,,,,,,,,,,,, -50300,0.73381543,1.1288362,,,,,,,,,,,,,, -50360,,,0.17267397,0.0656883564061761,0.3635543,0.105708796354403,5348.0,0.19980165,0.0668657201470558,2472.0,41829.74850130081,45874.87132453919,41829.74850130081,4041.277694702149,1.6876039505004885,0.0 -50400,0.7579217,1.149648,,,,,,,,,,,,,, -50500,0.65234756,1.1044189,,,,,,,,,,,,,, -50600,0.6206755,1.1055954,,,,,,,,,,,,,, -50700,0.607331,1.0761815,,,,,,,,,,,,,, -50800,0.61428726,1.055532,,,,,,,,,,,,,, -50900,0.6506785,1.0686233,,,,,,,,,,,,,, -51000,0.5492914,1.1270282,,,,,,,,,,,,,, -51100,0.55879927,1.1203351,,,,,,,,,,,,,, -51200,0.58095294,1.0936692,,,,,,,,,,,,,, -51300,0.79732937,1.1452111,,,,,,,,,,,,,, -51400,0.67691153,1.1212327,,,,,,,,,,,,,, -51500,0.7267576,1.0341736,,,,,,,,,,,,,, -51600,0.68803304,1.0512123,,,,,,,,,,,,,, -51700,0.60172784,1.0786625,,,,,,,,,,,,,, -51800,0.54644805,1.0584836,,,,,,,,,,,,,, -51900,0.6166912,1.1046897,,,,,,,,,,,,,, -52000,0.72474587,1.0720358,,,,,,,,,,,,,, -52092,,,0.15158781,0.0584736299604878,0.36148843,0.1051391718238605,5348.0,0.19757481,0.0647330042857433,2472.0,43270.4296848774,47450.25897097588,43270.4296848774,4175.834066867828,1.761640548706055,0.0 -52100,0.608979,1.0374256,,,,,,,,,,,,,, -52200,0.55690527,1.0520875,,,,,,,,,,,,,, -52300,0.56056994,1.0670037,,,,,,,,,,,,,, -52400,0.66425294,1.0774862,,,,,,,,,,,,,, -52500,0.61147326,1.0890833,,,,,,,,,,,,,, -52600,0.764953,1.0912915,,,,,,,,,,,,,, -52700,0.8015644,1.0582236,,,,,,,,,,,,,, -52800,0.90639573,1.0401952,,,,,,,,,,,,,, -52900,0.5019143,1.0418836,,,,,,,,,,,,,, -53000,0.6681548,1.0204531,,,,,,,,,,,,,, -53100,0.6648225,1.1047627,,,,,,,,,,,,,, -53200,0.6693293,1.1022547,,,,,,,,,,,,,, -53300,0.62930894,1.1028001,,,,,,,,,,,,,, -53400,0.6437141,1.0760132,,,,,,,,,,,,,, -53500,0.610744,1.0776013,,,,,,,,,,,,,, -53600,0.73118526,1.0346432,,,,,,,,,,,,,, -53700,0.59406966,1.0855697,,,,,,,,,,,,,, -53800,0.71464473,1.0645704,,,,,,,,,,,,,, -53833,,,0.15212244,0.0581043404680024,0.35791597,0.1030441121098313,5348.0,0.19231722,0.0629862084374301,2472.0,44711.04065990448,49024.11247205734,44711.04065990448,4308.934076786041,1.827185869216919,0.0 -53900,0.655743,1.0280384,,,,,,,,,,,,,, -54000,0.536724,1.018985,,,,,,,,,,,,,, -54100,0.65869474,1.0624824,,,,,,,,,,,,,, -54200,0.5226171,1.0366756,,,,,,,,,,,,,, -54300,0.6475698,1.0522225,,,,,,,,,,,,,, -54400,0.55941975,1.0403451,,,,,,,,,,,,,, -54500,0.5573558,1.1148932,,,,,,,,,,,,,, -54600,0.66160446,1.0967851,,,,,,,,,,,,,, -54700,0.7492956,1.0266825,,,,,,,,,,,,,, -54800,0.55799717,1.0393142,,,,,,,,,,,,,, -54900,0.632278,1.0047663,,,,,,,,,,,,,, -55000,0.5855488,1.0406253,,,,,,,,,,,,,, -55100,0.53168863,1.0130786,,,,,,,,,,,,,, -55200,0.79564095,1.0749809,,,,,,,,,,,,,, -55300,0.94360465,1.0680805,,,,,,,,,,,,,, -55400,0.67483294,1.0524943,,,,,,,,,,,,,, -55500,0.5873751,1.0649288,,,,,,,,,,,,,, -55564,,,0.14071093,0.0529579020013802,0.34808445,0.0988153740695328,5348.0,0.18797308,0.0615847094428533,2472.0,46150.963654994965,50598.37662386894,46150.963654994965,4443.142770767212,1.886054515838623,0.0 -55600,0.58133715,1.057905,,,,,,,,,,,,,, -55700,0.5771081,0.98381466,,,,,,,,,,,,,, -55800,0.6344809,0.9986657,,,,,,,,,,,,,, -55900,0.8120318,1.0624954,,,,,,,,,,,,,, -56000,0.59790945,1.013671,,,,,,,,,,,,,, -56100,0.6728932,1.0779954,,,,,,,,,,,,,, -56200,0.5887111,1.0750722,,,,,,,,,,,,,, -56300,0.5120002,1.02704,,,,,,,,,,,,,, -56400,0.78779143,1.0252483,,,,,,,,,,,,,, -56500,0.55593437,1.036054,,,,,,,,,,,,,, -56600,0.9422399,1.0507749,,,,,,,,,,,,,, -56700,0.53302354,0.9934954,,,,,,,,,,,,,, -56800,0.54121166,1.014221,,,,,,,,,,,,,, -56900,0.56824756,0.99395865,,,,,,,,,,,,,, -57000,0.63386804,1.0186055,,,,,,,,,,,,,, -57100,0.6369895,1.0110173,,,,,,,,,,,,,, -57200,0.63784724,0.9876456,,,,,,,,,,,,,, -57296,,,0.1408747,0.0539587697308929,0.34242028,0.0970968458248452,5348.0,0.1845956,0.0598988483334348,2472.0,47591.68971157074,52170.102900505066,47591.68971157074,4574.003474712372,1.948566198348999,0.0 -57300,0.5548129,0.9963459,,,,,,,,,,,,,, -57400,0.6556538,1.0471346,,,,,,,,,,,,,, -57500,0.6650335,1.0405653,,,,,,,,,,,,,, -57600,0.5426243,1.052416,,,,,,,,,,,,,, -57700,0.60524344,0.98690164,,,,,,,,,,,,,, -57800,0.73559207,1.0827825,,,,,,,,,,,,,, -57900,0.52136517,1.0421337,,,,,,,,,,,,,, -58000,0.63761586,1.001399,,,,,,,,,,,,,, -58100,0.66859466,0.9862999,,,,,,,,,,,,,, -58200,0.6489198,0.9919563,,,,,,,,,,,,,, -58300,0.74065703,1.0178907,,,,,,,,,,,,,, -58400,0.68245673,1.049523,,,,,,,,,,,,,, -58500,0.5731254,1.0181997,,,,,,,,,,,,,, -58600,0.62018555,0.99230903,,,,,,,,,,,,,, -58700,0.731165,1.0467073,,,,,,,,,,,,,, -58800,0.5406857,0.98020196,,,,,,,,,,,,,, -58900,0.58722305,1.0024648,,,,,,,,,,,,,, -59000,0.69558597,0.9493656,,,,,,,,,,,,,, -59032,,,0.14205861,0.0535074439930733,0.33538797,0.0947990383965552,5348.0,0.17841998,0.0572583429813336,2472.0,49031.99387717247,53744.072734594345,49031.99387717247,4707.52090382576,2.02084755897522,0.0 -59100,0.7394939,1.010333,,,,,,,,,,,,,, -59200,0.82937515,1.0259312,,,,,,,,,,,,,, -59300,0.60817206,1.0156617,,,,,,,,,,,,,, -59400,0.5768841,1.0058446,,,,,,,,,,,,,, -59500,0.542243,1.0275339,,,,,,,,,,,,,, -59600,0.7053528,1.0823845,,,,,,,,,,,,,, -59700,0.62314624,1.0067947,,,,,,,,,,,,,, -59800,0.70643926,1.005586,,,,,,,,,,,,,, -59900,0.6483693,1.0131152,,,,,,,,,,,,,, -60000,0.6376113,0.9737294,,,,,,,,,,,,,, -60100,0.69399583,0.9882033,,,,,,,,,,,,,, -60200,0.5848854,1.0181905,,,,,,,,,,,,,, -60300,0.754152,1.02543,,,,,,,,,,,,,, -60400,0.6412352,1.0219214,,,,,,,,,,,,,, -60500,0.74494606,1.000416,,,,,,,,,,,,,, -60600,0.691224,0.99039507,,,,,,,,,,,,,, -60700,0.6303218,0.988517,,,,,,,,,,,,,, -60739,,,0.11626529,0.0448752437958168,0.33002877,0.0928777624376068,5348.0,0.17275436,0.0567302419109134,2472.0,50472.62706637383,55318.258915901184,50472.62706637383,4840.933924913406,2.087460994720459,0.0 -60800,0.68747824,0.9973276,,,,,,,,,,,,,, -60900,0.6420849,1.0133605,,,,,,,,,,,,,, -61000,0.8503272,0.956329,,,,,,,,,,,,,, -61100,0.69241965,0.9936916,,,,,,,,,,,,,, -61200,0.6551744,0.9744969,,,,,,,,,,,,,, -61300,0.5834271,1.0315688,,,,,,,,,,,,,, -61400,0.6384439,1.0157036,,,,,,,,,,,,,, -61500,0.6317679,1.000237,,,,,,,,,,,,,, -61600,0.67413706,0.9694558,,,,,,,,,,,,,, -61700,0.5589989,1.0326896,,,,,,,,,,,,,, -61800,0.5893753,0.9642123,,,,,,,,,,,,,, -61900,0.5874384,0.97738063,,,,,,,,,,,,,, -62000,0.8928422,1.0030388,,,,,,,,,,,,,, -62100,0.64674354,0.97819227,,,,,,,,,,,,,, -62200,0.62325466,0.9532412,,,,,,,,,,,,,, -62300,0.82269347,0.97956973,,,,,,,,,,,,,, -62400,0.66159004,0.97446525,,,,,,,,,,,,,, -62451,,,0.11678836,0.0445723609146091,0.3245294,0.0920571169275032,5348.0,0.17404887,0.0566286840127556,2472.0,51912.70821213722,56890.49936628342,51912.70821213722,4972.952255249023,2.15358567237854,0.0 -62500,0.6430712,0.98025304,,,,,,,,,,,,,, -62600,0.59128183,0.95647025,,,,,,,,,,,,,, -62700,0.6725031,0.98869973,,,,,,,,,,,,,, -62800,0.6411,0.9239395,,,,,,,,,,,,,, -62900,0.7167549,0.96700644,,,,,,,,,,,,,, -63000,0.56551033,0.9666772,,,,,,,,,,,,,, -63100,0.5800638,0.94565815,,,,,,,,,,,,,, -63200,0.6782655,0.9695365,,,,,,,,,,,,,, -63300,0.75252,0.94814634,,,,,,,,,,,,,, -63400,0.7170061,0.9739711,,,,,,,,,,,,,, -63500,0.71318614,0.99053997,,,,,,,,,,,,,, -63600,0.77650905,0.95081866,,,,,,,,,,,,,, -63700,0.8875788,0.91773236,,,,,,,,,,,,,, -63800,0.617372,0.9030109,,,,,,,,,,,,,, -63900,0.6055583,0.99643767,,,,,,,,,,,,,, -64000,0.7557642,0.94812524,,,,,,,,,,,,,, -64100,0.670771,0.9668833,,,,,,,,,,,,,, -64191,,,0.10589888,0.0406809355889836,0.31565297,0.0894117419890516,5348.0,0.16827966,0.0546381492088639,2472.0,53352.64188218117,58465.91937327385,53352.64188218117,5108.296470880508,2.219554662704468,0.0 -64200,0.71934927,0.9204125,,,,,,,,,,,,,, -64300,0.57938623,1.0229969,,,,,,,,,,,,,, -64400,0.619439,0.9576057,,,,,,,,,,,,,, -64500,0.67492014,0.9586272,,,,,,,,,,,,,, -64600,0.9153351,0.97162354,,,,,,,,,,,,,, -64700,0.70190316,0.99561423,,,,,,,,,,,,,, -64800,0.5736418,0.98882943,,,,,,,,,,,,,, -64900,0.5674015,0.9333688,,,,,,,,,,,,,, -65000,0.6894053,0.9659688,,,,,,,,,,,,,, -65100,0.55766094,0.9715955,,,,,,,,,,,,,, -65200,0.66857636,0.9280065,,,,,,,,,,,,,, -65300,0.6660478,0.92008364,,,,,,,,,,,,,, -65400,0.64458203,0.94775695,,,,,,,,,,,,,, -65500,0.6469706,0.98338556,,,,,,,,,,,,,, -65600,0.70313245,0.9474989,,,,,,,,,,,,,, -65700,1.0395057,0.9274683,,,,,,,,,,,,,, -65800,0.6740102,0.9206466,,,,,,,,,,,,,, -65900,0.69325554,0.93543893,,,,,,,,,,,,,, -65914,,,0.10621779,0.0402648644647767,0.3113982,0.0871235892138216,5348.0,0.16650042,0.0526476144049722,2472.0,54793.7216861248,60037.51570367813,54793.7216861248,5238.6785979270935,2.280007123947144,0.0 -66000,0.93366313,0.9665635,,,,,,,,,,,,,, -66100,0.69798803,0.91707504,,,,,,,,,,,,,, -66200,0.6059518,0.9100972,,,,,,,,,,,,,, -66300,0.8547708,0.94612134,,,,,,,,,,,,,, -66400,0.6810994,0.93377584,,,,,,,,,,,,,, -66500,0.6926766,0.9212759,,,,,,,,,,,,,, -66600,0.7198333,0.9420716,,,,,,,,,,,,,, -66700,0.6106641,0.94295114,,,,,,,,,,,,,, -66800,0.60277987,0.90612173,,,,,,,,,,,,,, -66900,0.61874324,0.93809617,,,,,,,,,,,,,, -67000,0.6755345,0.9422291,,,,,,,,,,,,,, -67100,0.6268143,0.9141365,,,,,,,,,,,,,, -67200,0.6375849,0.90699196,,,,,,,,,,,,,, -67300,0.62497264,0.93814373,,,,,,,,,,,,,, -67400,0.73512536,0.96497077,,,,,,,,,,,,,, -67500,0.65130854,0.9075119,,,,,,,,,,,,,, -67600,0.6599641,0.92216474,,,,,,,,,,,,,, -67624,,,0.09754225,0.0369401399789807,0.30889,0.0859939948057966,5348.0,0.16369732,0.0521398249141835,2472.0,56233.85720348358,61610.60500574112,56233.85720348358,5371.498435974121,2.340143918991089,0.0 -67700,0.71132094,0.92946106,,,,,,,,,,,,,, -67800,0.6578474,0.96524554,,,,,,,,,,,,,, -67900,0.62970734,0.9202323,,,,,,,,,,,,,, -68000,0.88304174,0.9489007,,,,,,,,,,,,,, -68100,0.5647319,0.9736722,,,,,,,,,,,,,, -68200,0.6086224,0.9711433,,,,,,,,,,,,,, -68300,0.6415088,0.9020826,,,,,,,,,,,,,, -68400,0.6425572,0.89415026,,,,,,,,,,,,,, -68500,0.7443059,0.92777693,,,,,,,,,,,,,, -68600,0.5806994,0.950006,,,,,,,,,,,,,, -68700,0.6640678,0.9332361,,,,,,,,,,,,,, -68800,0.68726546,0.906251,,,,,,,,,,,,,, -68900,0.8425317,0.92658335,,,,,,,,,,,,,, -69000,0.6349368,0.8842879,,,,,,,,,,,,,, -69100,0.67900497,0.9432667,,,,,,,,,,,,,, -69200,0.72775644,0.9222003,,,,,,,,,,,,,, -69300,0.9598658,0.921856,,,,,,,,,,,,,, -69373,,,0.09445868,0.0359186655302569,0.30673182,0.0855595354180947,5348.0,0.16329591,0.0520179554363942,2472.0,57674.4818482399,63183.08140993118,57674.4818482399,5503.210609436035,2.404115915298462,0.0 -69373,,,,,,,,,,,57674.4818482399,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 32a922a94..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,27 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -191.6783373355865,0.0,42.44247460365296,1,0,42.44247460365296,31.068468,2472,2.0823634554059267,234.12087845802307,32.005703,2.172323678237511,31.066788,5348,2.115073809822644 -316.6514196395874,0.0433142185211181,1482.4016528129578,1821,0,1482.4016528129578,1.3049777,2472,0.3606727195173969,1799.1643204689026,1.3006566,0.3580616259191366,1.7522995,5348,0.4345076609672031 -446.47476387023926,0.0939624309539794,2922.9053342342377,3644,0,2922.9053342342377,0.5849962,2472,0.1863181199601893,3369.6125876903534,0.56066525,0.1831393335641797,0.90998024,5348,0.2569972098052656 -574.8677513599396,0.1399412155151367,4362.876003265381,5443,0,4362.876003265381,0.5009056,2472,0.1613348770133853,4938.094379425049,0.4335046,0.1502759460442462,0.8066129,5348,0.229462139278025 -703.8194465637207,0.1875872611999511,5803.261379480362,7227,0,5803.261379480362,0.4495168,2472,0.1439684764284118,6507.550093650818,0.4039872,0.1359266512373153,0.7400715,5348,0.2126051150351912 -835.4457139968872,0.2387313842773437,7243.5706782341,9035,0,7243.5706782341,0.42103618,2472,0.1364735035443706,8079.611067295074,0.37986395,0.1296886211920241,0.7127948,5348,0.2048910472402174 -965.1938455104828,0.2876284122467041,8684.085877656937,10824,0,8684.085877656937,0.40143517,2472,0.1293035159344342,9649.997616052628,0.38293672,0.1269345055507839,0.67859113,5348,0.1951398476495747 -1098.6871774196625,0.3397469520568847,10124.552185058594,12589,0,10124.552185058594,0.38833746,2472,0.1252005768488615,11224.081290483477,0.34332344,0.1176429923638142,0.656048,5348,0.1870878669974994 -1230.286295413971,0.3859894275665283,11564.96788263321,14344,0,11564.96788263321,0.37290213,2472,0.1205086019539739,12796.21423316002,0.29857165,0.1054889622715099,0.63960975,5348,0.1832839336918428 -1362.066598653793,0.4408023357391357,13004.995213031769,16105,0,13004.995213031769,0.3568221,2472,0.1155525765238762,14368.149621248243,0.27393517,0.0952391085764142,0.61575806,5348,0.1773077034476766 -1492.6827099323273,0.4890756607055664,14445.443444490433,17906,0,14445.443444490433,0.34615928,2472,0.1115511953364613,15939.336789369583,0.28437436,0.1015193036837617,0.6062099,5348,0.1754443554070884 -1624.5698716640472,0.5372681617736816,15885.84837603569,19647,0,15885.84837603569,0.3346423,2472,0.1081388499583612,17511.750654459,0.28456303,0.0969512678812237,0.58785826,5348,0.1692653774486613 -1756.6709995269775,0.5862374305725098,17326.075987815857,21394,0,17326.075987815857,0.31810853,2472,0.1022281802855808,19084.2019238472,0.2806753,0.0961141741630833,0.56532633,5348,0.163018816918814 -1888.72429060936,0.6344974040985107,18766.0640540123,23163,0,18766.0640540123,0.31591764,2472,0.1013954055206873,20656.36541342736,0.2647326,0.0910690400246939,0.56283754,5348,0.1627195226739527 -2020.7151553630829,0.6808972358703613,20206.35933160782,24924,0,20206.35933160782,0.30580786,2472,0.097678386448114,22228.771274089813,0.2338939,0.0815316076179215,0.5402023,5348,0.1555171514911611 -2151.0077567100525,0.7294008731842041,21646.578894615173,26668,0,21646.578894615173,0.29551646,2472,0.0941847947514878,23799.4051322937,0.22001642,0.0781554031143222,0.5292033,5348,0.1534220917771319 -2282.118963956833,0.7779040336608887,23086.48792457581,28420,0,23086.48792457581,0.27854708,2472,0.0901427904048097,25370.54607105255,0.20773691,0.0718822554745272,0.50856733,5348,0.1475520627166262 -2410.6302189826965,0.8339059352874756,24526.86054444313,30184,0,24526.86054444313,0.27274805,2472,0.0890256535250746,26939.56102442741,0.22040235,0.0768168226729069,0.49543273,5348,0.1430143757784064 -2539.895192861557,0.8861689567565918,25966.865243196487,31904,0,25966.865243196487,0.2672006,2472,0.0856336197266061,28508.955932617188,0.20300266,0.0685516470518017,0.48683342,5348,0.1398862681869527 -2672.469190120697,0.9372451305389404,27407.19617295265,33640,0,27407.19617295265,0.2545436,2472,0.0820790932910852,30081.98685359955,0.20574053,0.0707560012123142,0.4757007,5348,0.1369126350444597 -2804.6382009983063,0.9974474906921388,28847.49475240708,35391,0,28847.49475240708,0.24884571,2472,0.0786464363333536,31654.591391325,0.19533391,0.0648997730106461,0.46448025,5348,0.1330893924326829 -2937.5611956119537,1.0491628646850586,30287.84063434601,37110,0,30287.84063434601,0.24322133,2472,0.0775699226128816,33227.98648023605,0.14400123,0.0516486782899254,0.44723344,5348,0.1284648136169226 -3069.1378898620605,1.1044816970825195,31728.33105635643,38860,0,31728.33105635643,0.23379442,2472,0.0750309751589381,34800.18440055847,0.16303992,0.0565880340436316,0.4371549,5348,0.1246126070459658 -3200.66814160347,1.1607980728149414,33168.68277978897,40624,0,33168.68277978897,0.22834969,2472,0.0737716572217821,36372.19922232628,0.2076175,0.0723138590715056,0.42986232,5348,0.1229326974135184 -3333.876192331314,1.2174742221832275,34609.292607069016,42351,0,34609.292607069016,0.2234378,2472,0.0713342676659963,37946.1487801075,0.21227044,0.0734329982180229,0.42083701,5348,0.1201618119852863 -3462.570736169815,1.2742955684661865,36049.41221165657,44083,0,36049.41221165657,0.22179714,2472,0.07092803607336542,39515.09550356865,0.24737003,0.08589010886160361,0.41725254,5348,0.11932185716906263 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index 5cfd867e1..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,469 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,19.991858,33.774593,,,,,,,,,,,,,, -1,,,32.005703,2.172323678237511,31.066788,2.115073809822644,5348.0,31.068468,2.0823634554059267,2472.0,42.44247460365296,234.12087845802307,42.44247460365296,191.6783373355865,0.0,0.0 -100,0.9082452,6.164198,,,,,,,,,,,,,, -200,0.33543566,5.846587,,,,,,,,,,,,,, -300,0.9883584,5.647816,,,,,,,,,,,,,, -400,1.9251047,5.4283485,,,,,,,,,,,,,, -500,1.1511276,4.733867,,,,,,,,,,,,,, -600,2.7435932,3.9491882,,,,,,,,,,,,,, -700,2.186598,3.464498,,,,,,,,,,,,,, -800,2.2722094,3.1959493,,,,,,,,,,,,,, -900,2.4117289,3.019333,,,,,,,,,,,,,, -1000,2.7357838,2.7839174,,,,,,,,,,,,,, -1100,2.3036857,2.7623212,,,,,,,,,,,,,, -1200,1.7342998,2.6403558,,,,,,,,,,,,,, -1300,2.2394135,2.5086155,,,,,,,,,,,,,, -1400,2.7764091,2.5174248,,,,,,,,,,,,,, -1500,1.8454039,2.4005992,,,,,,,,,,,,,, -1600,2.3825634,2.3699405,,,,,,,,,,,,,, -1700,1.8233855,2.315738,,,,,,,,,,,,,, -1800,2.828814,2.2272384,,,,,,,,,,,,,, -1821,,,1.3006566,0.3580616259191366,1.7522995,0.4345076609672031,5348.0,1.3049777,0.3606727195173969,2472.0,1482.4016528129578,1799.1643204689026,1482.4016528129578,316.6514196395874,0.0433142185211181,0.0 -1900,2.225754,2.272867,,,,,,,,,,,,,, -2000,1.7917188,2.1810586,,,,,,,,,,,,,, -2100,2.2316089,2.1595087,,,,,,,,,,,,,, -2200,2.089547,2.1596065,,,,,,,,,,,,,, -2300,2.7346392,2.0998096,,,,,,,,,,,,,, -2400,2.2061768,2.1071541,,,,,,,,,,,,,, -2500,3.3295739,2.0923092,,,,,,,,,,,,,, -2600,2.5892675,2.0377185,,,,,,,,,,,,,, -2700,3.2997856,2.063645,,,,,,,,,,,,,, -2800,1.4939276,2.0175173,,,,,,,,,,,,,, -2900,2.953889,1.997903,,,,,,,,,,,,,, -3000,2.819003,1.9985291,,,,,,,,,,,,,, -3100,2.4186413,1.9455067,,,,,,,,,,,,,, -3200,1.87422,1.8696897,,,,,,,,,,,,,, -3300,2.7304707,1.8920614,,,,,,,,,,,,,, -3400,4.0163217,1.9278733,,,,,,,,,,,,,, -3500,2.2755392,1.8951614,,,,,,,,,,,,,, -3600,2.2125168,1.8977249,,,,,,,,,,,,,, -3644,,,0.56066525,0.1831393335641797,0.90998024,0.2569972098052656,5348.0,0.5849962,0.1863181199601893,2472.0,2922.9053342342377,3369.6125876903534,2922.9053342342377,446.47476387023926,0.0939624309539794,0.0 -3700,2.1218457,1.8739824,,,,,,,,,,,,,, -3800,3.3883557,1.9998417,,,,,,,,,,,,,, -3900,2.4711332,1.9035062,,,,,,,,,,,,,, -4000,2.685063,1.9432645,,,,,,,,,,,,,, -4100,3.0942137,1.8586843,,,,,,,,,,,,,, -4200,2.29755,1.821755,,,,,,,,,,,,,, -4300,3.919476,1.8218356,,,,,,,,,,,,,, -4400,2.2974184,1.8381853,,,,,,,,,,,,,, -4500,1.9713702,1.8335077,,,,,,,,,,,,,, -4600,2.036154,1.7767929,,,,,,,,,,,,,, -4700,4.5755863,1.8559334,,,,,,,,,,,,,, -4800,2.6002297,1.8144374,,,,,,,,,,,,,, -4900,2.2559054,1.8005224,,,,,,,,,,,,,, -5000,2.5913043,1.8071939,,,,,,,,,,,,,, -5100,2.6065354,1.8002546,,,,,,,,,,,,,, -5200,3.1877604,1.7654544,,,,,,,,,,,,,, -5300,3.387835,1.7559222,,,,,,,,,,,,,, -5400,2.7733443,1.7514927,,,,,,,,,,,,,, -5443,,,0.4335046,0.1502759460442462,0.8066129,0.229462139278025,5348.0,0.5009056,0.1613348770133853,2472.0,4362.876003265381,4938.094379425049,4362.876003265381,574.8677513599396,0.1399412155151367,0.0 -5500,3.003824,1.7829133,,,,,,,,,,,,,, -5600,3.2661095,1.7577012,,,,,,,,,,,,,, -5700,2.0377774,1.7221172,,,,,,,,,,,,,, -5800,2.836892,1.7490319,,,,,,,,,,,,,, -5900,4.8627367,1.7488983,,,,,,,,,,,,,, -6000,2.093739,1.7714353,,,,,,,,,,,,,, -6100,2.113293,1.688824,,,,,,,,,,,,,, -6200,3.0594127,1.6846083,,,,,,,,,,,,,, -6300,2.6557274,1.6833794,,,,,,,,,,,,,, -6400,3.2356431,1.7042718,,,,,,,,,,,,,, -6500,3.0530696,1.7147754,,,,,,,,,,,,,, -6600,1.804057,1.6817064,,,,,,,,,,,,,, -6700,2.9494991,1.7328795,,,,,,,,,,,,,, -6800,2.2605658,1.7433816,,,,,,,,,,,,,, -6900,5.5332317,1.6543336,,,,,,,,,,,,,, -7000,1.4978858,1.6987108,,,,,,,,,,,,,, -7100,2.2703493,1.6909338,,,,,,,,,,,,,, -7200,2.359143,1.7077701,,,,,,,,,,,,,, -7227,,,0.4039872,0.1359266512373153,0.7400715,0.2126051150351912,5348.0,0.4495168,0.1439684764284118,2472.0,5803.261379480362,6507.550093650818,5803.261379480362,703.8194465637207,0.1875872611999511,0.0 -7300,2.4253683,1.5964854,,,,,,,,,,,,,, -7400,2.3631797,1.6020396,,,,,,,,,,,,,, -7500,3.1190257,1.6917379,,,,,,,,,,,,,, -7600,3.3541284,1.7251533,,,,,,,,,,,,,, -7700,3.175207,1.6928694,,,,,,,,,,,,,, -7800,4.0971165,1.6738842,,,,,,,,,,,,,, -7900,3.0579453,1.6445894,,,,,,,,,,,,,, -8000,2.4961672,1.630954,,,,,,,,,,,,,, -8100,2.2262194,1.6906871,,,,,,,,,,,,,, -8200,2.2799377,1.6740772,,,,,,,,,,,,,, -8300,4.4701743,1.660851,,,,,,,,,,,,,, -8400,1.9943767,1.6688783,,,,,,,,,,,,,, -8500,2.8651166,1.650242,,,,,,,,,,,,,, -8600,3.07853,1.6496104,,,,,,,,,,,,,, -8700,2.6677234,1.6600268,,,,,,,,,,,,,, -8800,2.6919236,1.6429585,,,,,,,,,,,,,, -8900,3.3196225,1.7031404,,,,,,,,,,,,,, -9000,3.1349573,1.6578548,,,,,,,,,,,,,, -9035,,,0.37986395,0.1296886211920241,0.7127948,0.2048910472402174,5348.0,0.42103618,0.1364735035443706,2472.0,7243.5706782341,8079.611067295074,7243.5706782341,835.4457139968872,0.2387313842773437,0.0 -9100,3.2442605,1.685649,,,,,,,,,,,,,, -9200,2.9802256,1.5926528,,,,,,,,,,,,,, -9300,3.0709825,1.6530075,,,,,,,,,,,,,, -9400,3.37339,1.5768086,,,,,,,,,,,,,, -9500,2.928099,1.5685035,,,,,,,,,,,,,, -9600,2.6417983,1.6447464,,,,,,,,,,,,,, -9700,3.6992824,1.6306473,,,,,,,,,,,,,, -9800,2.5401285,1.6181717,,,,,,,,,,,,,, -9900,2.487824,1.5965618,,,,,,,,,,,,,, -10000,2.2272506,1.5555376,,,,,,,,,,,,,, -10100,3.0986722,1.5731899,,,,,,,,,,,,,, -10200,2.6513495,1.5267713,,,,,,,,,,,,,, -10300,3.3402605,1.6148905,,,,,,,,,,,,,, -10400,3.1902275,1.5833887,,,,,,,,,,,,,, -10500,2.262671,1.64496,,,,,,,,,,,,,, -10600,3.2855723,1.6054822,,,,,,,,,,,,,, -10700,4.137402,1.5681862,,,,,,,,,,,,,, -10800,4.84354,1.5760401,,,,,,,,,,,,,, -10824,,,0.38293672,0.1269345055507839,0.67859113,0.1951398476495747,5348.0,0.40143517,0.1293035159344342,2472.0,8684.085877656937,9649.997616052628,8684.085877656937,965.1938455104828,0.2876284122467041,0.0 -10900,2.373023,1.584818,,,,,,,,,,,,,, -11000,2.1881168,1.5728613,,,,,,,,,,,,,, -11100,2.8335576,1.6367853,,,,,,,,,,,,,, -11200,2.9641798,1.6400619,,,,,,,,,,,,,, -11300,2.482748,1.5580486,,,,,,,,,,,,,, -11400,2.8605266,1.5942793,,,,,,,,,,,,,, -11500,2.464129,1.6177729,,,,,,,,,,,,,, -11600,2.9543831,1.5547372,,,,,,,,,,,,,, -11700,2.3285744,1.5350676,,,,,,,,,,,,,, -11800,3.7037182,1.5644411,,,,,,,,,,,,,, -11900,2.2099578,1.5835232,,,,,,,,,,,,,, -12000,3.1966226,1.6090388,,,,,,,,,,,,,, -12100,2.1290503,1.5275882,,,,,,,,,,,,,, -12200,3.897675,1.6169661,,,,,,,,,,,,,, -12300,2.2565193,1.5211283,,,,,,,,,,,,,, -12400,2.9096887,1.5808086,,,,,,,,,,,,,, -12500,2.643979,1.5539178,,,,,,,,,,,,,, -12589,,,0.34332344,0.1176429923638142,0.656048,0.1870878669974994,5348.0,0.38833746,0.1252005768488615,2472.0,10124.552185058594,11224.081290483477,10124.552185058594,1098.6871774196625,0.3397469520568847,0.0 -12600,2.4521189,1.546554,,,,,,,,,,,,,, -12700,3.0161145,1.6347529,,,,,,,,,,,,,, -12800,2.4733067,1.5108846,,,,,,,,,,,,,, -12900,3.6275408,1.5984312,,,,,,,,,,,,,, -13000,2.0929234,1.5839174,,,,,,,,,,,,,, -13100,2.9357576,1.6175383,,,,,,,,,,,,,, -13200,2.7560937,1.5875461,,,,,,,,,,,,,, -13300,3.7324698,1.6059301,,,,,,,,,,,,,, -13400,2.0254738,1.5332811,,,,,,,,,,,,,, -13500,2.823652,1.5283989,,,,,,,,,,,,,, -13600,2.328813,1.5500745,,,,,,,,,,,,,, -13700,2.2030735,1.5317178,,,,,,,,,,,,,, -13800,2.1804943,1.5451959,,,,,,,,,,,,,, -13900,3.3317947,1.5114912,,,,,,,,,,,,,, -14000,2.2407975,1.5494763,,,,,,,,,,,,,, -14100,2.9290137,1.5910349,,,,,,,,,,,,,, -14200,1.9253833,1.5904144,,,,,,,,,,,,,, -14300,1.7162199,1.549633,,,,,,,,,,,,,, -14344,,,0.29857165,0.1054889622715099,0.63960975,0.1832839336918428,5348.0,0.37290213,0.1205086019539739,2472.0,11564.96788263321,12796.21423316002,11564.96788263321,1230.286295413971,0.3859894275665283,0.0 -14400,3.21197,1.5376611,,,,,,,,,,,,,, -14500,2.0382972,1.4432547,,,,,,,,,,,,,, -14600,5.032814,1.5575823,,,,,,,,,,,,,, -14700,2.0114923,1.5707426,,,,,,,,,,,,,, -14800,3.1793268,1.5403594,,,,,,,,,,,,,, -14900,3.3303483,1.5138853,,,,,,,,,,,,,, -15000,3.6995573,1.5827821,,,,,,,,,,,,,, -15100,4.124869,1.5910358,,,,,,,,,,,,,, -15200,2.6032557,1.5031323,,,,,,,,,,,,,, -15300,1.8811086,1.5145067,,,,,,,,,,,,,, -15400,2.500126,1.5581172,,,,,,,,,,,,,, -15500,2.093428,1.4794494,,,,,,,,,,,,,, -15600,4.7293973,1.4818288,,,,,,,,,,,,,, -15700,2.2247026,1.5543022,,,,,,,,,,,,,, -15800,2.4709933,1.5053114,,,,,,,,,,,,,, -15900,2.352931,1.5876492,,,,,,,,,,,,,, -16000,2.3261847,1.5117232,,,,,,,,,,,,,, -16100,2.7347803,1.5233939,,,,,,,,,,,,,, -16105,,,0.27393517,0.0952391085764142,0.61575806,0.1773077034476766,5348.0,0.3568221,0.1155525765238762,2472.0,13004.995213031769,14368.149621248243,13004.995213031769,1362.066598653793,0.4408023357391357,0.0 -16200,2.0151873,1.4803951,,,,,,,,,,,,,, -16300,2.148716,1.4473037,,,,,,,,,,,,,, -16400,3.5016584,1.6390206,,,,,,,,,,,,,, -16500,3.075194,1.4584639,,,,,,,,,,,,,, -16600,1.9488653,1.4882182,,,,,,,,,,,,,, -16700,2.864691,1.5240295,,,,,,,,,,,,,, -16800,2.007202,1.4259598,,,,,,,,,,,,,, -16900,3.0186303,1.4580349,,,,,,,,,,,,,, -17000,2.8910704,1.4708108,,,,,,,,,,,,,, -17100,2.2815785,1.499687,,,,,,,,,,,,,, -17200,2.192338,1.4950817,,,,,,,,,,,,,, -17300,1.7953372,1.407181,,,,,,,,,,,,,, -17400,4.8612614,1.4645976,,,,,,,,,,,,,, -17500,2.7653196,1.5585489,,,,,,,,,,,,,, -17600,4.245565,1.5391842,,,,,,,,,,,,,, -17700,2.9115837,1.4637057,,,,,,,,,,,,,, -17800,5.136131,1.5415984,,,,,,,,,,,,,, -17900,2.8606536,1.5572541,,,,,,,,,,,,,, -17906,,,0.28437436,0.1015193036837617,0.6062099,0.1754443554070884,5348.0,0.34615928,0.1115511953364613,2472.0,14445.443444490433,15939.336789369583,14445.443444490433,1492.6827099323273,0.4890756607055664,0.0 -18000,2.501084,1.501807,,,,,,,,,,,,,, -18100,2.8020504,1.4015796,,,,,,,,,,,,,, -18200,2.684688,1.5150203,,,,,,,,,,,,,, -18300,2.2871978,1.5104111,,,,,,,,,,,,,, -18400,1.7678272,1.5187502,,,,,,,,,,,,,, -18500,2.119644,1.5281469,,,,,,,,,,,,,, -18600,4.2636814,1.5090206,,,,,,,,,,,,,, -18700,2.5780387,1.4223253,,,,,,,,,,,,,, -18800,2.385536,1.4379649,,,,,,,,,,,,,, -18900,2.7403114,1.4955786,,,,,,,,,,,,,, -19000,1.6593009,1.4789437,,,,,,,,,,,,,, -19100,4.3849473,1.4904677,,,,,,,,,,,,,, -19200,4.373647,1.4742441,,,,,,,,,,,,,, -19300,2.766737,1.5394163,,,,,,,,,,,,,, -19400,2.3102024,1.5041841,,,,,,,,,,,,,, -19500,2.557215,1.4906001,,,,,,,,,,,,,, -19600,3.6890435,1.469616,,,,,,,,,,,,,, -19647,,,0.28456303,0.0969512678812237,0.58785826,0.1692653774486613,5348.0,0.3346423,0.1081388499583612,2472.0,15885.84837603569,17511.750654459,15885.84837603569,1624.5698716640472,0.5372681617736816,0.0 -19700,1.6337763,1.4344873,,,,,,,,,,,,,, -19800,3.0464659,1.4453685,,,,,,,,,,,,,, -19900,2.2587786,1.4377238,,,,,,,,,,,,,, -20000,3.450837,1.3519201,,,,,,,,,,,,,, -20100,3.4173758,1.4574629,,,,,,,,,,,,,, -20200,4.0646935,1.4877969,,,,,,,,,,,,,, -20300,2.5382352,1.4681693,,,,,,,,,,,,,, -20400,2.7182775,1.4599981,,,,,,,,,,,,,, -20500,2.274433,1.4624113,,,,,,,,,,,,,, -20600,2.260424,1.4668009,,,,,,,,,,,,,, -20700,3.2031271,1.4328808,,,,,,,,,,,,,, -20800,3.0352404,1.4406356,,,,,,,,,,,,,, -20900,3.2295058,1.445712,,,,,,,,,,,,,, -21000,3.6542106,1.4962673,,,,,,,,,,,,,, -21100,2.7095835,1.4788568,,,,,,,,,,,,,, -21200,1.9078507,1.4135063,,,,,,,,,,,,,, -21300,3.094591,1.4002422,,,,,,,,,,,,,, -21394,,,0.2806753,0.0961141741630833,0.56532633,0.163018816918814,5348.0,0.31810853,0.1022281802855808,2472.0,17326.075987815857,19084.2019238472,17326.075987815857,1756.6709995269775,0.5862374305725098,0.0 -21400,2.6510074,1.4509907,,,,,,,,,,,,,, -21500,3.1426222,1.516236,,,,,,,,,,,,,, -21600,2.085263,1.4402728,,,,,,,,,,,,,, -21700,2.1222737,1.4115407,,,,,,,,,,,,,, -21800,2.4867153,1.4564475,,,,,,,,,,,,,, -21900,3.7332327,1.4565462,,,,,,,,,,,,,, -22000,2.129159,1.3824265,,,,,,,,,,,,,, -22100,1.545189,1.4440523,,,,,,,,,,,,,, -22200,2.715506,1.3656582,,,,,,,,,,,,,, -22300,1.5031898,1.4080256,,,,,,,,,,,,,, -22400,2.7027237,1.4353819,,,,,,,,,,,,,, -22500,2.867678,1.4095342,,,,,,,,,,,,,, -22600,3.3347428,1.4285157,,,,,,,,,,,,,, -22700,3.3195803,1.4342194,,,,,,,,,,,,,, -22800,3.895644,1.4037045,,,,,,,,,,,,,, -22900,2.0212915,1.3564706,,,,,,,,,,,,,, -23000,3.7464194,1.3833573,,,,,,,,,,,,,, -23100,1.8804177,1.3723844,,,,,,,,,,,,,, -23163,,,0.2647326,0.0910690400246939,0.56283754,0.1627195226739527,5348.0,0.31591764,0.1013954055206873,2472.0,18766.0640540123,20656.36541342736,18766.0640540123,1888.72429060936,0.6344974040985107,0.0 -23200,4.551311,1.4613025,,,,,,,,,,,,,, -23300,2.4525924,1.4043332,,,,,,,,,,,,,, -23400,2.3286052,1.3704343,,,,,,,,,,,,,, -23500,2.9727087,1.500351,,,,,,,,,,,,,, -23600,2.5454385,1.4271219,,,,,,,,,,,,,, -23700,3.637351,1.3560418,,,,,,,,,,,,,, -23800,2.5862994,1.3837497,,,,,,,,,,,,,, -23900,2.2930298,1.3480401,,,,,,,,,,,,,, -24000,3.705514,1.4066627,,,,,,,,,,,,,, -24100,2.0181067,1.4206805,,,,,,,,,,,,,, -24200,3.7069242,1.3813343,,,,,,,,,,,,,, -24300,2.6927588,1.3933588,,,,,,,,,,,,,, -24400,2.449769,1.4310685,,,,,,,,,,,,,, -24500,2.3841095,1.3646616,,,,,,,,,,,,,, -24600,2.0490708,1.409848,,,,,,,,,,,,,, -24700,3.1264699,1.4597347,,,,,,,,,,,,,, -24800,2.4877543,1.4108247,,,,,,,,,,,,,, -24900,2.2504358,1.4600744,,,,,,,,,,,,,, -24924,,,0.2338939,0.0815316076179215,0.5402023,0.1555171514911611,5348.0,0.30580786,0.097678386448114,2472.0,20206.35933160782,22228.771274089813,20206.35933160782,2020.7151553630829,0.6808972358703613,0.0 -25000,1.8831965,1.3533999,,,,,,,,,,,,,, -25100,3.736529,1.393298,,,,,,,,,,,,,, -25200,2.3165517,1.4343859,,,,,,,,,,,,,, -25300,3.33392,1.4156398,,,,,,,,,,,,,, -25400,4.1478076,1.3891842,,,,,,,,,,,,,, -25500,6.4486737,1.3785814,,,,,,,,,,,,,, -25600,2.7309384,1.3946344,,,,,,,,,,,,,, -25700,3.3563392,1.37096,,,,,,,,,,,,,, -25800,3.0237813,1.4423755,,,,,,,,,,,,,, -25900,3.011014,1.337375,,,,,,,,,,,,,, -26000,2.5794895,1.3175777,,,,,,,,,,,,,, -26100,2.8209667,1.3508556,,,,,,,,,,,,,, -26200,3.3976326,1.3981808,,,,,,,,,,,,,, -26300,4.4618425,1.3460088,,,,,,,,,,,,,, -26400,3.0873387,1.3590196,,,,,,,,,,,,,, -26500,4.6938663,1.3518325,,,,,,,,,,,,,, -26600,2.1879077,1.3807398,,,,,,,,,,,,,, -26668,,,0.22001642,0.0781554031143222,0.5292033,0.1534220917771319,5348.0,0.29551646,0.0941847947514878,2472.0,21646.578894615173,23799.4051322937,21646.578894615173,2151.0077567100525,0.7294008731842041,0.0 -26700,3.182042,1.3828756,,,,,,,,,,,,,, -26800,4.237483,1.3372861,,,,,,,,,,,,,, -26900,2.0113544,1.3264303,,,,,,,,,,,,,, -27000,1.8516582,1.3862945,,,,,,,,,,,,,, -27100,2.8473763,1.2895576,,,,,,,,,,,,,, -27200,2.3245764,1.3778678,,,,,,,,,,,,,, -27300,2.2293937,1.2802364,,,,,,,,,,,,,, -27400,1.7017376,1.3390619,,,,,,,,,,,,,, -27500,1.8646631,1.3010287,,,,,,,,,,,,,, -27600,2.8425405,1.3179356,,,,,,,,,,,,,, -27700,2.145865,1.3709621,,,,,,,,,,,,,, -27800,2.1969411,1.3292384,,,,,,,,,,,,,, -27900,2.9382167,1.3429215,,,,,,,,,,,,,, -28000,1.9252187,1.3293896,,,,,,,,,,,,,, -28100,1.9555492,1.3549591,,,,,,,,,,,,,, -28200,2.3951478,1.4001087,,,,,,,,,,,,,, -28300,3.434749,1.3119395,,,,,,,,,,,,,, -28400,2.62774,1.4352995,,,,,,,,,,,,,, -28420,,,0.20773691,0.0718822554745272,0.50856733,0.1475520627166262,5348.0,0.27854708,0.0901427904048097,2472.0,23086.48792457581,25370.54607105255,23086.48792457581,2282.118963956833,0.7779040336608887,0.0 -28500,1.8543633,1.3395884,,,,,,,,,,,,,, -28600,2.8555567,1.3852385,,,,,,,,,,,,,, -28700,2.2979584,1.371102,,,,,,,,,,,,,, -28800,2.4670079,1.3118824,,,,,,,,,,,,,, -28900,1.5455207,1.3001081,,,,,,,,,,,,,, -29000,2.9298189,1.3369056,,,,,,,,,,,,,, -29100,2.482613,1.3457476,,,,,,,,,,,,,, -29200,2.2245471,1.2968233,,,,,,,,,,,,,, -29300,2.8339417,1.3556176,,,,,,,,,,,,,, -29400,3.7654061,1.4176855,,,,,,,,,,,,,, -29500,3.063106,1.3589212,,,,,,,,,,,,,, -29600,5.687146,1.4030781,,,,,,,,,,,,,, -29700,3.690635,1.31459,,,,,,,,,,,,,, -29800,3.7136636,1.3373191,,,,,,,,,,,,,, -29900,2.5480416,1.2713724,,,,,,,,,,,,,, -30000,1.816117,1.3789457,,,,,,,,,,,,,, -30100,2.7537692,1.3398607,,,,,,,,,,,,,, -30184,,,0.22040235,0.0768168226729069,0.49543273,0.1430143757784064,5348.0,0.27274805,0.0890256535250746,2472.0,24526.86054444313,26939.56102442741,24526.86054444313,2410.6302189826965,0.8339059352874756,0.0 -30200,2.7505958,1.2559882,,,,,,,,,,,,,, -30300,2.0002184,1.3590719,,,,,,,,,,,,,, -30400,6.7474227,1.3511405,,,,,,,,,,,,,, -30500,1.9427916,1.3021036,,,,,,,,,,,,,, -30600,2.0816555,1.3717866,,,,,,,,,,,,,, -30700,2.8604414,1.3538991,,,,,,,,,,,,,, -30800,2.5535161,1.3321354,,,,,,,,,,,,,, -30900,3.3637028,1.3594837,,,,,,,,,,,,,, -31000,3.0292473,1.2819322,,,,,,,,,,,,,, -31100,2.6924753,1.3173203,,,,,,,,,,,,,, -31200,2.287367,1.2713777,,,,,,,,,,,,,, -31300,2.8345826,1.307239,,,,,,,,,,,,,, -31400,2.2976813,1.3309908,,,,,,,,,,,,,, -31500,2.557169,1.2556319,,,,,,,,,,,,,, -31600,2.5101404,1.3141379,,,,,,,,,,,,,, -31700,3.2509606,1.322413,,,,,,,,,,,,,, -31800,1.9082366,1.3239644,,,,,,,,,,,,,, -31900,2.266796,1.2879016,,,,,,,,,,,,,, -31904,,,0.20300266,0.0685516470518017,0.48683342,0.1398862681869527,5348.0,0.2672006,0.0856336197266061,2472.0,25966.865243196487,28508.955932617188,25966.865243196487,2539.895192861557,0.8861689567565918,0.0 -32000,2.6241922,1.2491109,,,,,,,,,,,,,, -32100,3.2856503,1.3364389,,,,,,,,,,,,,, -32200,2.434042,1.3433913,,,,,,,,,,,,,, -32300,1.9463781,1.2767222,,,,,,,,,,,,,, -32400,1.8245208,1.319152,,,,,,,,,,,,,, -32500,2.208315,1.2130029,,,,,,,,,,,,,, -32600,2.9692123,1.3137251,,,,,,,,,,,,,, -32700,4.0562754,1.3115942,,,,,,,,,,,,,, -32800,2.284492,1.2767185,,,,,,,,,,,,,, -32900,1.4933473,1.2278875,,,,,,,,,,,,,, -33000,2.4017258,1.2725042,,,,,,,,,,,,,, -33100,3.262292,1.2017846,,,,,,,,,,,,,, -33200,2.877843,1.2594602,,,,,,,,,,,,,, -33300,1.9653988,1.2730387,,,,,,,,,,,,,, -33400,2.6892378,1.3763359,,,,,,,,,,,,,, -33500,1.3679991,1.2831322,,,,,,,,,,,,,, -33600,2.8965514,1.2562586,,,,,,,,,,,,,, -33640,,,0.20574053,0.0707560012123142,0.4757007,0.1369126350444597,5348.0,0.2545436,0.0820790932910852,2472.0,27407.19617295265,30081.98685359955,27407.19617295265,2672.469190120697,0.9372451305389404,0.0 -33700,2.7170467,1.2608083,,,,,,,,,,,,,, -33800,1.6449902,1.3320572,,,,,,,,,,,,,, -33900,2.9614184,1.3520551,,,,,,,,,,,,,, -34000,3.6265528,1.2996823,,,,,,,,,,,,,, -34100,2.2396922,1.2555137,,,,,,,,,,,,,, -34200,2.246049,1.2878386,,,,,,,,,,,,,, -34300,2.8288825,1.2675456,,,,,,,,,,,,,, -34400,2.5564756,1.2538316,,,,,,,,,,,,,, -34500,2.3829966,1.2686018,,,,,,,,,,,,,, -34600,2.232759,1.2496353,,,,,,,,,,,,,, -34700,3.0453684,1.2462445,,,,,,,,,,,,,, -34800,3.0178196,1.2348716,,,,,,,,,,,,,, -34900,2.6957312,1.2376446,,,,,,,,,,,,,, -35000,4.853954,1.2515428,,,,,,,,,,,,,, -35100,2.8952289,1.2362652,,,,,,,,,,,,,, -35200,2.4209454,1.2372925,,,,,,,,,,,,,, -35300,2.2013075,1.2261374,,,,,,,,,,,,,, -35391,,,0.19533391,0.0648997730106461,0.46448025,0.1330893924326829,5348.0,0.24884571,0.0786464363333536,2472.0,28847.49475240708,31654.591391325,28847.49475240708,2804.6382009983063,0.9974474906921388,0.0 -35400,2.5145648,1.2558354,,,,,,,,,,,,,, -35500,1.7582747,1.2085491,,,,,,,,,,,,,, -35600,1.9120318,1.1945559,,,,,,,,,,,,,, -35700,2.6185167,1.2450083,,,,,,,,,,,,,, -35800,2.035741,1.2291231,,,,,,,,,,,,,, -35900,2.5745902,1.2061256,,,,,,,,,,,,,, -36000,2.0702965,1.2366433,,,,,,,,,,,,,, -36100,2.4707837,1.246154,,,,,,,,,,,,,, -36200,2.9830875,1.2523835,,,,,,,,,,,,,, -36300,1.4942846,1.2120479,,,,,,,,,,,,,, -36400,2.5950212,1.2232757,,,,,,,,,,,,,, -36500,2.5739229,1.2106092,,,,,,,,,,,,,, -36600,1.7883939,1.2085931,,,,,,,,,,,,,, -36700,2.450215,1.2085329,,,,,,,,,,,,,, -36800,2.3672614,1.2459695,,,,,,,,,,,,,, -36900,2.6091764,1.1536827,,,,,,,,,,,,,, -37000,2.425118,1.1920459,,,,,,,,,,,,,, -37100,3.867893,1.1608526,,,,,,,,,,,,,, -37110,,,0.14400123,0.0516486782899254,0.44723344,0.1284648136169226,5348.0,0.24322133,0.0775699226128816,2472.0,30287.84063434601,33227.98648023605,30287.84063434601,2937.5611956119537,1.0491628646850586,0.0 -37200,1.7716035,1.2210816,,,,,,,,,,,,,, -37300,3.1231182,1.2466925,,,,,,,,,,,,,, -37400,2.764072,1.1558101,,,,,,,,,,,,,, -37500,2.235732,1.1956965,,,,,,,,,,,,,, -37600,2.007476,1.1420536,,,,,,,,,,,,,, -37700,2.1606138,1.1793932,,,,,,,,,,,,,, -37800,10.750224,1.166093,,,,,,,,,,,,,, -37900,3.4210937,1.184927,,,,,,,,,,,,,, -38000,2.7320683,1.1824318,,,,,,,,,,,,,, -38100,2.420515,1.2550658,,,,,,,,,,,,,, -38200,2.5982952,1.1537169,,,,,,,,,,,,,, -38300,3.2498624,1.2324393,,,,,,,,,,,,,, -38400,2.082541,1.225241,,,,,,,,,,,,,, -38500,2.3627179,1.2259382,,,,,,,,,,,,,, -38600,2.4826424,1.1253288,,,,,,,,,,,,,, -38700,2.2436929,1.1708552,,,,,,,,,,,,,, -38800,2.5595546,1.2251122,,,,,,,,,,,,,, -38860,,,0.16303992,0.0565880340436316,0.4371549,0.1246126070459658,5348.0,0.23379442,0.0750309751589381,2472.0,31728.33105635643,34800.18440055847,31728.33105635643,3069.1378898620605,1.1044816970825195,0.0 -38900,2.3966408,1.2315464,,,,,,,,,,,,,, -39000,3.3385649,1.1560061,,,,,,,,,,,,,, -39100,2.1486144,1.2057866,,,,,,,,,,,,,, -39200,1.9094125,1.1668768,,,,,,,,,,,,,, -39300,2.2093086,1.2281501,,,,,,,,,,,,,, -39400,1.8948805,1.1722857,,,,,,,,,,,,,, -39500,1.9521077,1.1708887,,,,,,,,,,,,,, -39600,3.368233,1.1838526,,,,,,,,,,,,,, -39700,4.7174497,1.205313,,,,,,,,,,,,,, -39800,5.2799487,1.1557173,,,,,,,,,,,,,, -39900,2.2273526,1.1573375,,,,,,,,,,,,,, -40000,2.3264356,1.1787727,,,,,,,,,,,,,, -40100,1.8009177,1.100873,,,,,,,,,,,,,, -40200,3.2176228,1.1734262,,,,,,,,,,,,,, -40300,3.0648568,1.1771961,,,,,,,,,,,,,, -40400,3.6675045,1.2090843,,,,,,,,,,,,,, -40500,2.2999372,1.154262,,,,,,,,,,,,,, -40600,2.5086422,1.1846391,,,,,,,,,,,,,, -40624,,,0.2076175,0.0723138590715056,0.42986232,0.1229326974135184,5348.0,0.22834969,0.0737716572217821,2472.0,33168.68277978897,36372.19922232628,33168.68277978897,3200.66814160347,1.1607980728149414,0.0 -40700,1.975345,1.1574491,,,,,,,,,,,,,, -40800,2.5593503,1.1848423,,,,,,,,,,,,,, -40900,1.77797,1.1736274,,,,,,,,,,,,,, -41000,2.1737518,1.1847079,,,,,,,,,,,,,, -41100,2.1940615,1.1628128,,,,,,,,,,,,,, -41200,2.139935,1.1699717,,,,,,,,,,,,,, -41300,2.8435783,1.1629149,,,,,,,,,,,,,, -41400,3.1130993,1.2047886,,,,,,,,,,,,,, -41500,1.9618412,1.1841507,,,,,,,,,,,,,, -41600,5.707369,1.1864771,,,,,,,,,,,,,, -41700,2.2404814,1.2225387,,,,,,,,,,,,,, -41800,2.0419688,1.1947885,,,,,,,,,,,,,, -41900,3.3580017,1.1257432,,,,,,,,,,,,,, -42000,1.8635546,1.1469682,,,,,,,,,,,,,, -42100,2.4305744,1.2101212,,,,,,,,,,,,,, -42200,1.6908728,1.0895189,,,,,,,,,,,,,, -42300,4.774452,1.1257739,,,,,,,,,,,,,, -42351,,,0.21227044,0.0734329982180229,0.42083701,0.1201618119852863,5348.0,0.2234378,0.0713342676659963,2472.0,34609.292607069016,37946.1487801075,34609.292607069016,3333.876192331314,1.2174742221832275,0.0 -42400,3.1387794,1.1863132,,,,,,,,,,,,,, -42500,3.7361257,1.1493176,,,,,,,,,,,,,, -42600,2.5085268,1.1158314,,,,,,,,,,,,,, -42700,2.5659063,1.1571989,,,,,,,,,,,,,, -42800,6.7479835,1.1671039,,,,,,,,,,,,,, -42900,2.918103,1.1620234,,,,,,,,,,,,,, -43000,1.9810848,1.1317432,,,,,,,,,,,,,, -43100,2.6875188,1.1413049,,,,,,,,,,,,,, -43200,3.3181112,1.2007653,,,,,,,,,,,,,, -43300,3.2854865,1.2014487,,,,,,,,,,,,,, -43400,2.6793475,1.1289324,,,,,,,,,,,,,, -43500,5.991156,1.1862592,,,,,,,,,,,,,, -43600,2.5776212,1.0902369,,,,,,,,,,,,,, -43700,2.5785656,1.1075383,,,,,,,,,,,,,, -43800,2.0208058,1.1488574,,,,,,,,,,,,,, -43900,2.3970857,1.0982437,,,,,,,,,,,,,, -44000,3.1561563,1.1598554,,,,,,,,,,,,,, -44083,,,0.24737003,0.0858901088616036,0.41725254,0.1193218571690626,5348.0,0.22179714,0.0709280360733654,2472.0,36049.41221165657,39515.09550356865,36049.41221165657,3462.570736169815,1.2742955684661863,0.0 -44083,,,,,,,,,,,36049.41221165657,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 8c3daf8dd..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,232 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -536.8605499267578,0.0,19.18818688392639,1,0,19.18818688392639,0.413312017917633,0.7984466552734375,0.0260650279516998,43793,556.0487859249115,0.4159256517887115,0.7988808155059814,0.0215982155287161,0.4129833579063415,0.7986935973167419,0.0252096170502471,43793 -659.2226083278656,0.0308315753936767,259.19083619117737,744,0,259.19083619117737,0.983284831047058,0.0636374056339263,0.0572574487730643,43793,918.4641320705414,0.9869332909584044,0.0506426654756069,0.0564247931206779,0.9842758178710938,0.060369711369276,0.0555193829226929,43793 -782.5733218193054,0.0585832595825195,499.22732520103455,1481,0,499.22732520103455,0.9836386442184448,0.0592108108103275,0.1004557028126341,43793,1281.8988349437714,0.9874383211135864,0.0463490262627601,0.1024011862839557,0.9846505522727966,0.0559681691229343,0.1030645214502391,43793 -905.2488882541656,0.0856134891510009,739.374596118927,2219,0,739.374596118927,0.9837620854377748,0.0573773942887783,0.1340239544379468,43793,1644.7690522670746,0.9876528978347778,0.044420201331377,0.1416577263864256,0.984782874584198,0.0541465990245342,0.137276791969974,43793 -1029.594246149063,0.1141455173492431,979.5938329696656,2972,0,979.5938329696656,0.9842371940612792,0.0549632683396339,0.1510426967436176,43793,2009.3818798065183,0.988052487373352,0.042341161519289,0.1634633712465514,0.9852432012557985,0.0518907941877841,0.1526518339536486,43793 -1156.464278936386,0.1400582790374755,1219.685434341431,3731,0,1219.685434341431,0.9843096137046814,0.053251676261425,0.1643560816542293,43793,2376.389413833618,0.9881816506385804,0.0408075787127018,0.1910989147009367,0.9852667450904846,0.0502415895462036,0.1666665480608038,43793 -1282.059383392334,0.1661477088928222,1459.7971086502075,4491,0,1459.7971086502075,0.9847784042358398,0.051520898938179,0.1892872621373282,43793,2742.1424379348755,0.9886085391044616,0.0388112030923366,0.2143681979482724,0.985697865486145,0.0487300790846347,0.1857914809779446,43793 -1412.4098732471466,0.1936049461364746,1699.9812216758728,5251,0,1699.9812216758728,0.9847156405448914,0.0506240203976631,0.2009949060040942,43793,3112.7249789237976,0.9887627959251404,0.0380763970315456,0.2471778039737912,0.9855229258537292,0.0480495318770408,0.1963845763391835,43793 -1539.7698109149933,0.2290825843811035,1940.042201757431,6007,0,1940.042201757431,0.98514062166214,0.0493680611252784,0.2163814131888277,43793,3480.2014927864075,0.9892771244049072,0.0365316681563854,0.2702459337040197,0.986055076122284,0.0466743931174278,0.2205379475830717,43793 -1669.380942583084,0.2561860084533691,2180.241662979126,6754,0,2180.241662979126,0.985238790512085,0.0494339019060134,0.2141966573978142,43793,3850.060249567032,0.9893851280212402,0.0358582735061645,0.2830660872122692,0.9861866235733032,0.0466683804988861,0.2184325413342826,43793 -1792.50821018219,0.2849104404449463,2420.270122051239,7502,0,2420.270122051239,0.985331416130066,0.0481740869581699,0.2225995118593554,43793,4213.264726400375,0.9895324110984802,0.035286109894514,0.3089379075079039,0.9861618280410768,0.0457406304776668,0.2240028748716719,43793 -1919.977798461914,0.3132412433624267,2660.5069646835327,8252,0,2660.5069646835327,0.9855239391326904,0.0481030829250812,0.2293570996321657,43793,4581.020742177963,0.9896116852760316,0.0348451845347881,0.2959481314296294,0.9864756464958192,0.0453048162162303,0.2330827604520835,43793 -2050.3256623744965,0.3406836986541748,2900.5411689281464,9003,0,2900.5411689281464,0.9855251908302308,0.0483233146369457,0.2416666839001448,43793,4951.450447320938,0.98968768119812,0.0344913601875305,0.3023390847822264,0.9864464402198792,0.0456723272800445,0.2355386945639433,43793 -2176.679932117462,0.3683798313140869,3140.5600502491,9738,0,3140.5600502491,0.985731601715088,0.0476286262273788,0.2443877526841293,43793,5317.872577667236,0.9900109171867372,0.0331563651561737,0.3391733973208491,0.9865714311599731,0.044988140463829,0.2418029850919489,43793 -2307.231298685074,0.3975164890289306,3380.5997710227966,10481,0,3380.5997710227966,0.9857054352760316,0.0473709106445312,0.2480712251497775,43793,5688.5130405426025,0.9901197552680968,0.0329639092087745,0.3427990933394376,0.9865767359733582,0.044519018381834,0.2523031039881308,43793 -2437.435226202011,0.4268772602081299,3620.5970821380615,11228,0,3620.5970821380615,0.9857454895973206,0.0474274680018425,0.2429146183176039,43793,6058.763890743256,0.990300476551056,0.0322925746440887,0.3670847887656891,0.9866266250610352,0.0446537360548973,0.256547329570424,43793 -2573.549907445908,0.4588387012481689,3860.692731380463,11964,0,3860.692731380463,0.9857821464538574,0.0473833791911602,0.2490345214654525,43793,6435.028518915176,0.990463137626648,0.0314462892711162,0.3924278482920513,0.9866583347320556,0.0446167550981044,0.2597770575960549,43793 -2703.654602050781,0.4913444519042969,4100.820775270462,12707,0,4100.820775270462,0.9858554005622864,0.0474821887910366,0.2588720984792856,43793,6805.316474914551,0.9905829429626464,0.0309941451996564,0.387986523878819,0.9866992831230164,0.0445774421095848,0.262996735421005,43793 -2834.7084777355194,0.521092414855957,4340.906958580017,13452,0,4340.906958580017,0.9859215617179872,0.0470420643687248,0.2574354219744418,43793,7176.506344079971,0.9906689524650574,0.0307073928415775,0.3913719672774121,0.9868081212043762,0.0442623235285282,0.26061023076219,43793 -2961.4146535396576,0.5505294799804688,4581.159161567688,14200,0,4581.159161567688,0.985874354839325,0.0469127632677555,0.2512387055784661,43793,7543.514000415802,0.9906706213951112,0.0308509096503257,0.3912459024790902,0.9867467880249025,0.0441011153161525,0.2619010691977937,43793 -3092.9744811058044,0.5795705318450928,4821.198109149933,14944,0,4821.198109149933,0.9859371185302734,0.0473070107400417,0.2592420343565292,43793,7915.161383390427,0.9905703663825988,0.0310817845165729,0.382999974042977,0.9868085384368896,0.0445431284606456,0.2633013852647012,43793 -3222.9942405223846,0.6097488403320312,5061.430529117584,15692,0,5061.430529117584,0.9859729409217834,0.0467361696064472,0.2624294844696725,43793,8285.463474035263,0.9907108545303344,0.0305028110742568,0.4101220713336749,0.9868288040161132,0.0439677983522415,0.2711070559332986,43793 -3352.3659710884094,0.6399135589599609,5301.552769899368,16434,0,5301.552769899368,0.9860348105430604,0.0468795970082283,0.2668301827293936,43793,8655.007350206375,0.9906653761863708,0.0305659603327512,0.3961345079374617,0.9868064522743224,0.0440406166017055,0.2708324881862166,43793 -3482.0657093524933,0.6692507266998291,5541.53219461441,17176,0,5541.53219461441,0.9859800934791564,0.0467806570231914,0.2602629097743209,43793,9024.735594034197,0.9908429384231568,0.0300751198083162,0.4131038732595066,0.9868007898330688,0.0441732369363307,0.2631906991121361,43793 -3608.844582557678,0.699115514755249,5781.594381570816,17933,0,5781.594381570816,0.9859012961387634,0.0467691197991371,0.262596531655433,43793,9391.62649512291,0.991029977798462,0.0294215101748704,0.4207735413460807,0.9868117570877076,0.0441651605069637,0.2655381044547623,43793 -3738.4994037151337,0.730036735534668,6021.693260192871,18678,0,6021.693260192871,0.9860984086990356,0.0463713221251964,0.2668474315698864,43793,9761.431322336197,0.9913448691368104,0.0282798688858747,0.4553588251086291,0.9869843125343324,0.043611004948616,0.2739963602729493,43793 -3868.2067382335663,0.7608804702758789,6261.680698633194,19421,0,6261.680698633194,0.98604154586792,0.046822752803564,0.2666428469243027,43793,10131.17840242386,0.991253674030304,0.0285461358726024,0.4504025211918963,0.9868957996368408,0.0441204234957695,0.2729475198973159,43793 -3997.909258365631,0.7910916805267334,6501.814088821411,20164,0,6501.814088821411,0.9860731363296508,0.0467290207743644,0.2667028073713977,43793,10501.064566135406,0.9912835359573364,0.0284360982477664,0.4628734931838974,0.9869039058685304,0.043829821050167,0.2754548308109774,43793 -4127.003994941711,0.8224525451660156,6741.928815841675,20909,0,6741.928815841675,0.986042857170105,0.046838354319334,0.2644472106800913,43793,10870.3255007267,0.9910857677459716,0.0290487967431545,0.4316780481261128,0.9869911670684814,0.0437979027628898,0.2797274694723569,43793 -4266.083292245865,0.8539795875549316,6981.9416534900665,21660,0,6981.9416534900665,0.986139714717865,0.0466521978378295,0.2727992682591666,43793,11249.46938920021,0.9911643266677856,0.028825219720602,0.4205208436652018,0.9870212078094482,0.0438480079174041,0.2758581640517665,43793 -4397.061862707138,0.8849740028381348,7222.009767055511,22400,0,7222.009767055511,0.9860116839408876,0.0466559007763862,0.2678142261886165,43793,11620.568714141846,0.9911457896232604,0.029067164286971,0.4441765352018306,0.9868693947792052,0.0439830757677555,0.2746977620566058,43793 -4530.485562801361,0.9154069423675536,7462.221218585968,23147,0,7462.221218585968,0.9860929846763612,0.0467063896358013,0.2686252559516493,43793,11994.254369735718,0.9912724494934082,0.0284826979041099,0.4513491212236842,0.987023651599884,0.0437239073216915,0.2754543168780456,43793 -4662.562074661255,0.9465370178222656,7702.266625642776,23892,0,7702.266625642776,0.9861068725585938,0.0471601597964763,0.2613879065969777,43793,12366.427788496016,0.9913005828857422,0.0282000079751014,0.4627503649671614,0.9869924187660216,0.0442375540733337,0.2778445772522299,43793 -4791.539767742157,0.978119134902954,7942.415654420853,24632,0,7942.415654420853,0.9860866665840148,0.046415239572525,0.2697986160231079,43793,12735.605735778809,0.991439700126648,0.0278518870472908,0.4635890547062042,0.9869059324264526,0.043697815388441,0.2797329279295364,43793 -4920.541341781616,1.0089633464813232,8182.555495977402,25374,0,8182.555495977402,0.9860929846763612,0.0466537848114967,0.2697822094920301,43793,13104.798302650452,0.9915762543678284,0.0273897033184766,0.4749107520915142,0.9869404435157776,0.0440693721175193,0.2789975196120251,43793 -5047.630271434784,1.0395777225494385,8422.817781925201,26125,0,8422.817781925201,0.9861578345298768,0.0468244887888431,0.275521530104557,43793,13472.200419664385,0.9917351603507996,0.0267021115869283,0.4959739328984074,0.986988365650177,0.0440122708678245,0.2815963343967564,43793 -5181.35660982132,1.0719294548034668,8662.943137168884,26870,0,8662.943137168884,0.9860352277755736,0.0468922369182109,0.2654404572829221,43793,13846.104840278624,0.9916290044784546,0.0272741466760635,0.4856420312623262,0.9868893027305604,0.043773666024208,0.2773456521371067,43793 -5312.733390331268,1.108165979385376,8903.035054683685,27618,0,8903.035054683685,0.9862993359565736,0.046916589140892,0.2668597689788753,43793,14217.63026714325,0.9916507601737976,0.0270869210362434,0.4838185757181377,0.9870354533195496,0.0440697260200977,0.2777947569709336,43793 -5439.122165679932,1.1402628421783447,9143.177698135376,28366,0,9143.177698135376,0.9861291646957396,0.0469421707093715,0.2653502959951228,43793,14584.213775396349,0.99137681722641,0.0279942359775304,0.46102037043348,0.9870212078094482,0.0440024808049202,0.2821440113552839,43793 -5569.727295160294,1.1746866703033447,9383.163092136385,29113,0,9383.163092136385,0.9860820174217224,0.0466917492449283,0.2662801861727709,43793,14954.858783721924,0.9914528131484984,0.0275321006774902,0.4668566314086093,0.9869303107261658,0.0438966490328311,0.2802756131313504,43793 -5694.410034656525,1.206719160079956,9623.224762439728,29860,0,9623.224762439728,0.9861574172973632,0.0468879751861095,0.2757145467342193,43793,15319.655102968216,0.9916426539421082,0.0269223246723413,0.4755547289723321,0.9869992733001708,0.0440030097961425,0.2888798783758334,43793 -5823.471135616303,1.2388403415679932,9863.198410272598,30604,0,9863.198410272598,0.9861317276954652,0.0468833744525909,0.2673260861461388,43793,15688.741545438766,0.9915139079093932,0.0273679699748754,0.4902585092633399,0.987023651599884,0.0440075658261776,0.2788562240367521,43793 -5949.193348169327,1.2714645862579346,10103.280684947968,31352,0,10103.280684947968,0.9862100481987,0.0471400320529937,0.2710934963817139,43793,16054.598301887512,0.9917362928390504,0.026557233184576,0.5076482448586728,0.987015962600708,0.0445297770202159,0.2808674786219359,43793 -6076.021708488464,1.3138582706451416,10343.537009716034,32100,0,10343.537009716034,0.9861077070236206,0.0469487346708774,0.2723113894305002,43793,16421.74523949623,0.9919912815093994,0.0256998073309659,0.516720698352363,0.9870447516441344,0.0440525151789188,0.2836175444213666,43793 -6203.211605548859,1.3478012084960938,10583.608781576157,32846,0,10583.608781576157,0.985975444316864,0.0468091182410717,0.2697275506495091,43793,16789.060408830643,0.991959512233734,0.0257570836693048,0.5138833451896644,0.9868990182876588,0.0438992008566856,0.2828000828133626,43793 -6330.386728048325,1.3803482055664062,10823.852623224258,33587,0,10823.852623224258,0.9861405491828918,0.0469157211482524,0.2739759901800147,43793,17156.531865119934,0.9921208620071412,0.0254428181797266,0.5243534403747736,0.9870321750640868,0.043906345963478,0.2832817246832959,43793 -6461.535669565201,1.4151837825775146,11063.997322559357,34331,0,11063.997322559357,0.9860954880714417,0.0470939092338085,0.2688336477584092,43793,17527.88009405136,0.9919935464859008,0.0258319564163684,0.5051137293483496,0.9869655966758728,0.0441421791911125,0.2822380868318115,43793 -6586.641371488571,1.4497976303100586,11304.021178007126,35077,0,11304.021178007126,0.9861544370651244,0.0474284365773201,0.2655076803593601,43793,17893.06462287903,0.9918292164802552,0.0262144785374403,0.5070033636209796,0.9870935082435608,0.0444223061203956,0.2835313514000049,43793 -6713.556245326996,1.4856345653533936,11544.161909103394,35819,0,11544.161909103394,0.986088752746582,0.0473106503486633,0.266058645154373,43793,18260.17632985115,0.9920508861541748,0.0255311019718647,0.5223396912909026,0.9869850873947144,0.0442952550947666,0.2810178823058246,43793 -6844.728674411774,1.5204980373382568,11784.26586842537,36565,0,11784.26586842537,0.9860820174217224,0.0470707193017005,0.266341244878793,43793,18631.508261442184,0.9919530153274536,0.0257911756634712,0.5123207397683507,0.9869518280029296,0.044221568852663,0.2859384338148794,43793 -6974.566897153854,1.554847240447998,12024.3671708107,37298,0,12024.3671708107,0.9861974120140076,0.0475418157875537,0.2678350340250758,43793,19001.5023636818,0.9920364618301392,0.0253801997750997,0.5294398886495806,0.9870569705963136,0.0445403046905994,0.2838524985635823,43793 -7103.1003510952,1.5899608135223389,12264.32112288475,38029,0,12264.32112288475,0.9860537648200988,0.0468669906258583,0.2705610278003447,43793,19370.045646190643,0.9923050403594972,0.0245301537215709,0.5316224932311173,0.9869834780693054,0.044165726751089,0.2815878288103547,43793 -7228.229280233383,1.6254029273986816,12504.334551811218,38775,0,12504.334551811218,0.9861186742782592,0.0474701225757598,0.2700197825705046,43793,19735.243750810623,0.9925279021263124,0.0239597801119089,0.5613465445746542,0.9869436621665956,0.0446382723748683,0.2795682889672867,43793 -7351.084131479263,1.660273790359497,12744.566393852234,39518,0,12744.566393852234,0.986042857170105,0.0473844185471534,0.2716201668149071,43793,20098.38708305359,0.992580771446228,0.0237706210464239,0.5658957285509353,0.986880362033844,0.0443928875029087,0.2854602876735767,43793 -7476.147330522537,1.694267749786377,12984.634531497955,40264,0,12984.634531497955,0.9861733913421632,0.0474028438329696,0.2700235207289608,43793,20463.572825193405,0.9925384521484376,0.0238226205110549,0.5491199180507131,0.9870293140411376,0.0445915944874286,0.2869749485756071,43793 -7602.085725069046,1.7297906875610352,13224.755261421204,41010,0,13224.755261421204,0.9862361550331116,0.0475085414946079,0.2737203631907097,43793,20829.687710762024,0.992478609085083,0.0241543073207139,0.5465455027988128,0.9870561361312866,0.0445823147892952,0.2870653774717982,43793 -7728.96719622612,1.7673423290252686,13464.745002746582,41754,0,13464.745002746582,0.9862349033355712,0.0477738939225673,0.2742440255073568,43793,21196.616488933563,0.9922721982002258,0.0246417224407196,0.5371658929798946,0.9870626330375672,0.0446793921291828,0.2881509323311192,43793 -7852.276547431946,1.8038551807403564,13704.911063671112,42494,0,13704.911063671112,0.9861738085746764,0.0473472066223621,0.2703982284928044,43793,21560.149626255035,0.9924113154411316,0.0241913646459579,0.5530880035861376,0.9869915843009948,0.0445187799632549,0.2845378522712065,43793 -7982.220856189728,1.8395116329193115,13944.901382684708,43234,0,13944.901382684708,0.9861283302307128,0.0478745140135288,0.2756288716884729,43793,21930.13982820511,0.9924986958503724,0.0238905660808086,0.5642352845260664,0.987065851688385,0.0448200851678848,0.2887220173444159,43793 -8104.878615617752,1.876089572906494,14184.867455482485,43970,0,14184.867455482485,0.9862905144691468,0.0479585081338882,0.274493591062196,43793,22292.821218252186,0.9926086664199828,0.0232233218848705,0.5818679028796002,0.9871900677680968,0.0448637790977954,0.2902499637499417,43793 -8227.839316368103,1.911909580230713,14425.079031467438,44713,0,14425.079031467438,0.9861140251159668,0.0478437356650829,0.269938147158632,43793,22656.048730373383,0.9928114414215088,0.0229176878929138,0.5719944676468065,0.987003743648529,0.0447695814073085,0.286275793689501,43793 -8352.581824541092,1.949923276901245,14665.148543596268,45458,0,14665.148543596268,0.9861544370651244,0.0478927008807659,0.2740243349431455,43793,23020.918548822403,0.9930723309516908,0.0220610983669757,0.6057093471035695,0.9870406985282898,0.0448703877627849,0.2880019608799084,43793 -8473.786965847015,1.9878568649291992,14905.407657384872,46195,0,14905.407657384872,0.9862251877784728,0.048274990171194,0.275243783480796,43793,23382.441463947296,0.9932297468185424,0.0215552020817995,0.611245669579988,0.9870281219482422,0.0453068502247333,0.287117571203055,43793 -8599.17787861824,2.023483991622925,15145.571384191511,46939,0,15145.571384191511,0.9862580895423888,0.0486461706459522,0.2688644851086066,43793,23748.05126833916,0.9930238723754884,0.022141970694065,0.5996005868878975,0.9870882034301758,0.0453628972172737,0.2826583293956821,43793 -8725.551683664322,2.058670997619629,15385.59983444214,47672,0,15385.59983444214,0.9862008094787598,0.0481557138264179,0.2730225395600637,43793,24114.510402441025,0.992883801460266,0.0226908214390277,0.5954002678011658,0.987058162689209,0.045141864567995,0.2881418197880501,43793 -8848.880156755447,2.0949044227600098,15625.717020273209,48417,0,15625.717020273209,0.9862513542175292,0.0486955344676971,0.2771422885173468,43793,24478.012163877487,0.9927613139152528,0.0228866655379533,0.564050308760248,0.987027108669281,0.0457050092518329,0.28652231033362,43793 -8972.782843828201,2.131263732910156,15865.82262635231,49148,0,15865.82262635231,0.9862251877784728,0.0485099367797374,0.2718043034407581,43793,24842.079083919525,0.9930188655853271,0.0221155732870101,0.5984401713121497,0.987084150314331,0.0453145541250705,0.2855918960402324,43793 -9102.306144714355,2.167288064956665,16105.879657030106,49889,0,16105.879657030106,0.9862938523292542,0.0486649572849273,0.2738028980149617,43793,25211.7152671814,0.9930760860443116,0.0218061376363039,0.6098008796744399,0.9870622158050536,0.045785091817379,0.2812710950757494,43793 -9224.947851657867,2.207081794738769,16345.824395656586,50609,0,16345.824395656586,0.9862997531890868,0.0491118878126144,0.2682895887359769,43793,25574.36565876007,0.9931766986846924,0.0214318484067916,0.619487995667,0.9871689677238464,0.045978058129549,0.28589317875376,43793 -9347.113098144531,2.242493867874145,16585.909583091736,51353,0,16585.909583091736,0.9862328171730042,0.0489258803427219,0.2725946008014016,43793,25936.671803712845,0.9935702085494996,0.0202808193862438,0.6379793391083107,0.9870991706848145,0.0458444878458976,0.2865038906963764,43793 -9470.78262424469,2.279003143310547,16825.856951475143,52098,0,16825.856951475143,0.9862193465232848,0.0490687452256679,0.2715121900772505,43793,26300.3449792862,0.9937245845794678,0.0197700150310993,0.6558632894885597,0.987107276916504,0.0459720119833946,0.2865892628409184,43793 -9596.453342676165,2.315351963043213,17066.066660165787,52846,0,17066.066660165787,0.9862968325614928,0.0496241226792335,0.273131802214466,43793,26666.281783103943,0.9934922456741332,0.0203552469611167,0.6389656490329827,0.987166166305542,0.0464010424911975,0.2904116007109317,43793 -9718.271196126938,2.3540172576904297,17306.208050727844,53594,0,17306.208050727844,0.9861013889312744,0.0491682216525077,0.2707588206225723,43793,27028.29993700981,0.9935109615325928,0.0204487219452857,0.635550748244853,0.9870272874832152,0.0459951683878898,0.2882264457127399,43793 -9840.98232102394,2.391303777694702,17546.322751760483,54339,0,17546.322751760483,0.9861831068992616,0.049438040703535,0.2739239046998224,43793,27391.182716608047,0.9935351014137268,0.0203988421708345,0.6431458559111352,0.9870265126228333,0.0463195703923702,0.2899425669337316,43793 -9963.41772222519,2.4342310428619385,17786.44678425789,55079,0,17786.44678425789,0.9861974120140076,0.0495461821556091,0.2750850588168503,43793,27753.804602861404,0.9935994744300842,0.0202458370476961,0.6338022190807497,0.9870439767837524,0.0462508201599121,0.2895553029434483,43793 -10085.61788034439,2.470750331878662,18026.465641736984,55829,0,18026.465641736984,0.9862349033355712,0.0499169677495956,0.2747479620975098,43793,28116.08030819893,0.9936383962631226,0.0199243109673261,0.6444634510671647,0.9870975613594056,0.0466697067022323,0.289663565541211,43793 -10204.472261428831,2.5078279972076416,18266.70391392708,56574,0,18266.70391392708,0.986233651638031,0.0503293424844741,0.2715353789860384,43793,28475.22965478897,0.9936703443527222,0.0196411423385143,0.6531613462400097,0.9871182441711426,0.0468352176249027,0.2888919640760225,43793 -10327.702929973602,2.546010971069336,18506.681567668915,57318,0,18506.681567668915,0.986137628555298,0.050077386200428,0.2762526456314265,43793,28838.496221780777,0.9939448237419128,0.0190205872058868,0.6744130159403038,0.9870589971542358,0.0467969439923763,0.2875065035808982,43793 -10452.352368354796,2.582829475402832,18746.95504975319,58059,0,18746.95504975319,0.9861649870872498,0.0505991503596305,0.2743655426824452,43793,29203.476695775986,0.9942131638526917,0.0180310141295194,0.6876187034913789,0.9870078563690186,0.0471584536135196,0.2862036634371643,43793 -10573.129535675049,2.621100902557373,18986.97296071053,58805,0,18986.97296071053,0.9862053990364076,0.0509169474244117,0.2733854418500244,43793,29564.330701112747,0.9944328665733336,0.01756127551198,0.7030557313454966,0.9870532751083374,0.047436848282814,0.2879422701973064,43793 -10693.31673192978,2.6599795818328857,19227.02808666229,59552,0,19227.02808666229,0.9861890077590942,0.0505956448614597,0.2751660225322885,43793,29924.63180088997,0.9943233728408812,0.0178030133247375,0.694210838560239,0.9870098829269408,0.0471440367400646,0.2937513754272167,43793 -10815.961034536362,2.698820114135742,19467.04409813881,60296,0,19467.04409813881,0.986295998096466,0.051175232976675,0.2740927208224178,43793,30287.35133075714,0.9941422343254088,0.0181688442826271,0.6758369738154911,0.9870870113372804,0.047808714210987,0.2878571093468118,43793 -10937.600918531418,2.740966796875,19707.040642499924,61040,0,19707.040642499924,0.9861696362495422,0.0514022409915924,0.2712695979829206,43793,30649.05088710785,0.9941512942314148,0.0182098597288131,0.686735293228531,0.987099587917328,0.0478898957371711,0.286116871264883,43793 -11056.847058057783,2.7800092697143555,19947.08567380905,61786,0,19947.08567380905,0.986199915409088,0.051761258393526,0.2712998483111147,43793,31008.40128135681,0.9941979050636292,0.0179943554103374,0.6911118193222282,0.9871572256088256,0.0478870645165443,0.290800598389781,43793 -11180.55793595314,2.8193583488464355,20187.041157007217,62531,0,20187.041157007217,0.9861767888069152,0.0520378574728965,0.2690783502432456,43793,31372.12839603424,0.994190752506256,0.0179026070982217,0.6973760724944296,0.9870695471763612,0.0483664646744728,0.2877178300183725,43793 -11298.31682562828,2.868788719177246,20427.09281229973,63276,0,20427.09281229973,0.9861236810684204,0.0519551709294319,0.2732008237432165,43793,31730.00851416588,0.9944777488708496,0.017259057611227,0.7072899931696965,0.987057328224182,0.0482454113662242,0.2879580300275564,43793 -11418.96899819374,2.907371282577514,20667.08352851868,64020,0,20667.08352851868,0.986163318157196,0.0521516539156436,0.2728042778335163,43793,32090.70943880081,0.9946504831314088,0.0165974479168653,0.7181443084231507,0.987084984779358,0.0485220737755298,0.2862159947309958,43793 -11537.486523628237,2.947197914123535,20907.040336608887,64754,0,20907.040336608887,0.9861885905265808,0.0523666962981224,0.270310425672684,43793,32449.245572328568,0.9948357939720154,0.0161562655121088,0.7323550397195834,0.98714017868042,0.0485176891088485,0.2874680887460105,43793 -11653.92560315132,2.98614239692688,21147.292327165604,65501,0,21147.292327165604,0.9861925840377808,0.0526473931968212,0.2703020664572869,43793,32805.996050834656,0.9949519634246826,0.0156491510570049,0.7398820880929526,0.9871227145195008,0.0487340465188026,0.2906366602949252,43793 -11773.509294509888,3.0265464782714844,21387.325922966003,66248,0,21387.325922966003,0.9861291646957396,0.0527142025530338,0.2725590229995492,43793,33165.67395567894,0.99501234292984,0.0157033149152994,0.7419741470556183,0.9870402812957764,0.0490181297063827,0.2915888285229315,43793 -11893.024748325348,3.0661911964416504,21627.543115854263,66996,0,21627.543115854263,0.9861043095588684,0.0529823116958141,0.2693005014060972,43793,33525.466379880905,0.9948969483375548,0.0158338658511638,0.7381022885433084,0.9870325922966005,0.0491348057985305,0.2917242559636582,43793 -12011.118778467178,3.105543851852417,21867.74240756035,67749,0,21867.74240756035,0.986127495765686,0.0531786978244781,0.2718304619033597,43793,33883.81890010834,0.994721531867981,0.0163781829178333,0.7252492006625263,0.9870800971984864,0.0493294596672058,0.2926122424504049,43793 -12138.936106681824,3.1450343132019043,22107.820801496506,68488,0,22107.820801496506,0.986100137233734,0.0535429641604423,0.273014472514247,43793,34251.776836156845,0.994666337966919,0.01639249548316,0.7239857813140791,0.9870427250862122,0.0495910719037056,0.29098764881197,43793 -12258.297290086746,3.189326524734497,22347.77027964592,69225,0,22347.77027964592,0.9860045313835144,0.0530650578439235,0.2730619442161265,43793,34611.156307697296,0.994940996170044,0.0157144218683242,0.7364670876265623,0.986950159072876,0.0492487885057926,0.2891234404148812,43793 -12372.406210422516,3.2301290035247803,22587.750163316727,69963,0,22587.750163316727,0.986139714717865,0.0534122213721275,0.2720989782136767,43793,34965.306415081024,0.9950496554374696,0.0153866959735751,0.7498461206729499,0.9870471954345704,0.0496566444635391,0.2911408890764547,43793 -12493.917872667313,3.2706449031829834,22827.77950668335,70711,0,22827.77950668335,0.98613041639328,0.0536233186721801,0.2714608792304772,43793,35326.90785813332,0.9950863122940063,0.0152974929660558,0.7440555332907148,0.9869948625564576,0.0498314648866653,0.290006136839753,43793 -12609.725140094755,3.3116016387939453,23067.749517202377,71455,0,23067.749517202377,0.9860959053039552,0.0535416230559349,0.2728977546716898,43793,35682.74603819847,0.9954333305358888,0.0143726943060755,0.7739653157648199,0.9869473576545716,0.0498649254441261,0.2913538274236733,43793 -12726.47549533844,3.352621078491211,23307.969571113583,72207,0,23307.969571113583,0.9861700534820556,0.0537815801799297,0.2736166316268893,43793,36039.77791452408,0.9954251646995544,0.0143637405708432,0.768106455239543,0.9870321750640868,0.0499582774937152,0.2905641657020273,43793 -12843.777476787567,3.3956403732299805,23548.14874601364,72960,0,23548.14874601364,0.9861670732498168,0.053872849792242,0.2733393715619848,43793,36397.32256865501,0.9953681230545044,0.0144671695306897,0.7619073822725733,0.9870269298553468,0.0500656254589557,0.2910683527730367,43793 -12959.944328784944,3.435800552368164,23788.296051740646,73704,0,23788.296051740646,0.986119508743286,0.0539797991514205,0.2723943428460345,43793,36753.69769072533,0.9953656792640686,0.0145191131159663,0.7680448154889454,0.9869895577430724,0.0500662177801132,0.2890999776953153,43793 -13073.09792804718,3.483970880508423,24028.29723644257,74442,0,24028.29723644257,0.9861443638801576,0.0540723912417888,0.2734716980950172,43793,37106.92096066475,0.99525386095047,0.0147152254357934,0.7530949013829958,0.987002968788147,0.0501992814242839,0.2914377994324131,43793 -13191.766512155533,3.5248465538024902,24268.272846221924,75192,0,24268.272846221924,0.9861211776733398,0.0539165139198303,0.2740543698815366,43793,37465.62667059898,0.9953786730766296,0.0144358789548277,0.771100602272772,0.9869931936264038,0.0501015745103359,0.2919096212544834,43793 -13307.42241358757,3.565934658050537,24508.311578273773,75940,0,24508.311578273773,0.9861279129981996,0.053962018340826,0.2749305524965749,43793,37821.3829460144,0.9953892230987548,0.0143926069140434,0.7691021492142511,0.9870362281799316,0.0501075685024261,0.2927785560003524,43793 -13424.56529855728,3.6069235801696777,24748.563071012497,76686,0,24748.563071012497,0.9861447811126708,0.0540152974426746,0.2739643168130566,43793,38178.83873414993,0.9954409003257751,0.0143494829535484,0.7656371137133122,0.9870646595954896,0.0501628257334232,0.2936275545673407,43793 -13542.111910581589,3.654918909072876,24988.685559511185,77423,0,24988.685559511185,0.9861460328102112,0.0540119931101799,0.2738614273246971,43793,38536.57685160637,0.9954899549484252,0.0141655439510941,0.7759545638285559,0.9870675206184388,0.0501727983355522,0.2925905235182048,43793 -13656.379624128342,4.030130386352539,25228.494074106216,78169,0,25228.494074106216,0.986154854297638,0.0539759583771228,0.2742019631959193,43793,38891.04864430428,0.9955411553382874,0.0139900390058755,0.7764028795076405,0.987064242362976,0.0501210130751132,0.2925099990148256,43793 -13778.132029294968,4.071470975875855,25468.65472769737,78911,0,25468.65472769737,0.9861447811126708,0.0540265403687953,0.274390207559422,43793,39253.02361893654,0.9954591989517212,0.0141547610983252,0.7724652479037462,0.987065851688385,0.0501732900738716,0.2921166804633121,43793 -13895.783916950226,4.113170385360718,25708.90886187553,79651,0,25708.90886187553,0.9861464500427246,0.0540345124900341,0.2745096935947982,43793,39610.99244427681,0.9955033659934998,0.014123479835689,0.7787619339885247,0.9870723485946656,0.0501797534525394,0.2923345979855717,43793 -14010.557025671003,4.155007600784302,25949.06646060944,80398,0,25949.06646060944,0.9861477017402648,0.0540336780250072,0.2744157901547382,43793,39965.985496521,0.9954550862312316,0.0142388418316841,0.7613149448795697,0.987072765827179,0.0501789674162864,0.292532548861496,43793 -14125.44719004631,4.197999954223633,26189.072664499283,81145,0,26189.072664499283,0.9861477017402648,0.0540336780250072,0.274459226207241,43793,40320.94511008263,0.9954413771629332,0.0142920305952429,0.7743069602250183,0.987072765827179,0.0501789785921573,0.2925131864548106,43793 -14242.58826804161,4.240504264831543,26429.078844308853,81891,0,26429.078844308853,0.9861477017402648,0.0540336780250072,0.2745432332629763,43793,40678.15526819229,0.995533525943756,0.0140317762270569,0.7743152053587273,0.987072765827179,0.0501789674162864,0.2923898254512626,43793 -14360.468386650084,4.283483266830444,26669.20731639862,82635,0,26669.20731639862,0.9861477017402648,0.0540336780250072,0.2743955574094778,43793,41036.22692847252,0.9954736232757568,0.0142103852704167,0.7793815673647857,0.987072765827179,0.0501789674162864,0.2923741737832562,43793 -14473.543465852736,4.326739549636841,26909.310037374496,83379,0,26909.310037374496,0.9861477017402648,0.0540336780250072,0.2744025722481929,43793,41389.468089580536,0.9954969882965088,0.0140584139153361,0.7734911544122332,0.987072765827179,0.0501789711415767,0.29237454060272,43793 -14588.418242692947,4.370355844497681,27149.2596681118,84115,0,27149.2596681118,0.9861477017402648,0.0540336780250072,0.2743914294477163,43793,41744.3564248085,0.9954795241355896,0.0141797987744212,0.7700450569577544,0.987072765827179,0.0501789674162864,0.2924245796731328,43793 -14705.068259239197,4.415584087371826,27389.41187477112,84868,0,27389.41187477112,0.9861477017402648,0.0540336780250072,0.2744031934173734,43793,42101.22405195236,0.9954484105110168,0.0142580112442374,0.770631365526065,0.987072765827179,0.050178974866867,0.2925296154257013,43793 -14820.498401403427,4.458751201629639,27629.48045659065,85616,0,27629.48045659065,0.9861477017402648,0.0540336780250072,0.274370387334361,43793,42456.78595900536,0.9955559968948364,0.0140151465311646,0.7786308148010082,0.987072765827179,0.0501789711415767,0.2923086632501189,43793 -14932.916816949844,4.501690149307251,27869.674690246586,86357,0,27869.674690246586,0.9861477017402648,0.0540336780250072,0.2745166189528792,43793,42809.461366176605,0.9954648613929749,0.0142206866294145,0.7778873312937167,0.987072765827179,0.0501789711415767,0.2925448356793363,43793 -15051.508164167404,4.545614242553711,28109.69511294365,87093,0,28109.69511294365,0.9861477017402648,0.0540336780250072,0.2744101209278129,43793,43168.13737511635,0.9954904317855836,0.0141125125810503,0.7748273898665948,0.987072765827179,0.0501789674162864,0.2924114938285133,43793 -15170.293683290482,4.589420795440674,28349.8712515831,87824,0,28349.8712515831,0.9861477017402648,0.0540336780250072,0.2743490402265405,43793,43527.16404628754,0.9954723715782166,0.0142516875639557,0.7732605363318497,0.987072765827179,0.0501789674162864,0.2925361165228815,43793 -15288.689393758774,4.63522481918335,28589.81309151649,88544,0,28589.81309151649,0.9861477017402648,0.0540336780250072,0.2745361670129824,43793,43885.57111406326,0.9954808950424194,0.0141577227041125,0.7625540593532746,0.987072765827179,0.0501789711415767,0.2924557376319083,43793 -15402.21783900261,4.680805921554565,28829.955509662628,89284,0,28829.955509662628,0.9861477017402648,0.0540336780250072,0.2744330019098556,43793,44239.30815052986,0.9955047965049744,0.0140820499509572,0.7791833701354678,0.987072765827179,0.0501789785921573,0.2923814411188966,43793 -15518.440243244171,4.724467754364014,29070.022459983826,90030,0,29070.022459983826,0.9861477017402648,0.0540336854755878,0.2744070792908892,43793,44595.66125655174,0.9955043792724608,0.0141207063570618,0.7797977344418124,0.987072765827179,0.0501789785921573,0.2923577434288198,43793 -15638.720978021622,4.768067836761475,29310.049928426743,90777,0,29310.049928426743,0.9861477017402648,0.0540336854755878,0.2745075266508673,43793,44956.033213377,0.9954771399497986,0.0141914384439587,0.7664620953486616,0.987072765827179,0.0501789711415767,0.2925620521718061,43793 -15752.73408961296,4.812305450439453,29550.18574547768,91515,0,29550.18574547768,0.9861477017402648,0.0540336780250072,0.2745610067990969,43793,45310.246342897415,0.9954972863197328,0.0141352312639355,0.7745224158954835,0.987072765827179,0.0501789674162864,0.2923887581809313,43793 -15866.069275140762,4.857031583786011,29790.211117506027,92253,0,29790.211117506027,0.9861477017402648,0.0540336780250072,0.2745345744716574,43793,45663.67285728455,0.995453119277954,0.0142373200505971,0.7638528805363961,0.987072765827179,0.0501789785921573,0.292432335709977,43793 -15981.527184009552,4.903205156326294,30030.20610809326,92984,0,30030.20610809326,0.9861477017402648,0.0540336854755878,0.2743704135473845,43793,46019.193653821945,0.9955169558525084,0.0140610234811902,0.7800021202800156,0.987072765827179,0.0501789711415767,0.2924081659905708,43793 -16097.984763383864,4.948692083358765,30270.435495376587,93720,0,30270.435495376587,0.9861477017402648,0.0540336780250072,0.2744685360980869,43793,46375.94743990898,0.995534121990204,0.0141052734106779,0.7757578925328368,0.987072765827179,0.0501789711415767,0.2923997285759094,43793 -16213.29602575302,4.994197368621826,30510.429839372635,94469,0,30510.429839372635,0.9861477017402648,0.0540336780250072,0.2743349141264332,43793,46731.31958055496,0.9954125881195068,0.0142912846058607,0.7713674171966016,0.987072765827179,0.0501789674162864,0.292367312775709,43793 -16328.961030006409,5.038869380950928,30750.42320585251,95213,0,30750.42320585251,0.9861477017402648,0.0540336780250072,0.2743825265260335,43793,47087.043660879135,0.995521366596222,0.0140772210434079,0.7769381347581835,0.987072765827179,0.0501789785921573,0.2924208216760564,43793 -16441.328189611435,5.085657596588135,30990.67506980896,95959,0,30990.67506980896,0.9861477017402648,0.0540336780250072,0.2743611788337915,43793,47439.72964859009,0.9954837560653688,0.0141481598839163,0.7681165605917525,0.987072765827179,0.050178974866867,0.2925440796695244,43793 -16556.407190322876,5.13048243522644,31230.70459461212,96699,0,31230.70459461212,0.9861477017402648,0.0540336780250072,0.2743702992049815,43793,47794.90297555924,0.9954470992088318,0.0142582403495907,0.7757481490303596,0.987072765827179,0.0501789711415767,0.2923798986293627,43793 -16670.566215515137,5.1765077114105225,31470.8629090786,97442,0,31470.8629090786,0.9861477017402648,0.0540336780250072,0.2743770625028734,43793,48149.28699564934,0.9955335855484008,0.0140630081295967,0.7739189992361805,0.987072765827179,0.0501789674162864,0.2924558861565163,43793 -16783.64903306961,5.22965407371521,31710.853369951248,98189,0,31710.853369951248,0.9861477017402648,0.0540336780250072,0.2744425142668042,43793,48502.43397283554,0.9955210089683532,0.0140525046736001,0.7769833804928025,0.987072765827179,0.0501789711415767,0.2925534718224171,43793 -16899.202238321304,5.274975776672363,31950.86460542679,98932,0,31950.86460542679,0.9861477017402648,0.0540336780250072,0.2744893408034077,43793,48858.06380486488,0.9954484701156616,0.014266662299633,0.7728061638280789,0.987072765827179,0.0501789711415767,0.2925407066451959,43793 -17013.020210027695,5.321348190307617,32191.01792359352,99673,0,32191.01792359352,0.9861477017402648,0.0540336780250072,0.2744765086034905,43793,49212.104393959045,0.9955064058303832,0.0140399364754557,0.769985629437685,0.987072765827179,0.0501789674162864,0.2923742836898387,43793 -17127.36168217659,5.368618726730347,32431.11732816696,100419,0,32431.11732816696,0.9861477017402648,0.0540336780250072,0.2743564112248008,43793,49566.61224031448,0.995432198047638,0.0143799893558025,0.7721081716111647,0.987072765827179,0.0501789711415767,0.292357912723474,43793 -17236.50082373619,5.4143126010894775,32671.16210055352,101162,0,32671.16210055352,0.9861477017402648,0.0540336780250072,0.274346909907928,43793,49915.86226391792,0.9955034255981444,0.0141010275110602,0.7755480809717403,0.987072765827179,0.0501789674162864,0.2924520208366916,43793 -17344.714772701263,5.459990978240967,32911.343249082565,101901,0,32911.343249082565,0.9861477017402648,0.0540336780250072,0.2743002376678368,43793,50264.32465529442,0.9954982399940492,0.0141067765653133,0.7797891037776852,0.987072765827179,0.050178974866867,0.2924009146730034,43793 -17454.36855649948,5.506225824356079,33151.44903755188,102658,0,33151.44903755188,0.9861477017402648,0.0540336780250072,0.2744376055525118,43793,50614.150799274445,0.9954644441604614,0.0141842234879732,0.7669780497456298,0.987072765827179,0.050178974866867,0.2925304588886242,43793 -17566.41631412506,5.552358627319336,33391.576731443405,103408,0,33391.576731443405,0.9861477017402648,0.0540336780250072,0.2744152863768196,43793,50966.39300918579,0.9954658150672911,0.0142328403890132,0.7757999081072839,0.987072765827179,0.050178974866867,0.2925922920047358,43793 -17678.056384563446,5.598142623901367,33631.52317357063,104161,0,33631.52317357063,0.9861477017402648,0.0540336780250072,0.2744654155294517,43793,51318.0455186367,0.9954906105995178,0.0141477445140481,0.7707154556254288,0.987072765827179,0.0501789785921573,0.2923405477629439,43793 -17786.38423061371,5.643944025039673,33871.74913954735,104914,0,33871.74913954735,0.9861477017402648,0.0540336780250072,0.2743916661481683,43793,51666.66525363922,0.9955100417137146,0.0141041558235883,0.7736689829168021,0.987072765827179,0.050178974866867,0.2924022043200743,43793 -17891.954362392426,5.690004587173462,34111.82308912277,105653,0,34111.82308912277,0.9861477017402648,0.0540336780250072,0.2744481433858795,43793,52012.376638650894,0.9955135583877563,0.0141152972355484,0.7804668605179181,0.987072765827179,0.050178974866867,0.2926694313673195,43793 -18005.360337734222,5.737663269042969,34351.85087966919,106399,0,34351.85087966919,0.9861477017402648,0.0540336780250072,0.274423879669587,43793,52365.87801599503,0.9954696297645568,0.0141687374562025,0.7746599640365843,0.987072765827179,0.0501789674162864,0.292687754129531,43793 -18119.052418470383,5.792099714279175,34591.82379245758,107146,0,34591.82379245758,0.9861477017402648,0.0540336780250072,0.2743853711222883,43793,52719.61870455742,0.9955094456672668,0.0140214832499623,0.7750776297005886,0.987072765827179,0.050178974866867,0.2924795025938732,43793 -18224.53417825699,5.845448970794678,34831.85524511337,107877,0,34831.85524511337,0.9861477017402648,0.0540336780250072,0.2744068619439111,43793,53065.20850133896,0.9954585433006288,0.0142492037266492,0.7664326035292928,0.987072765827179,0.050178974866867,0.2925241672229261,43793 -18333.98607826233,5.893911361694336,35071.80486369133,108610,0,35071.80486369133,0.9861477017402648,0.0540336780250072,0.2744291479385146,43793,53414.68152284622,0.995460331439972,0.0142254335805773,0.7700719906062522,0.987072765827179,0.0501789674162864,0.29244259418498,43793 -18444.12341618538,5.941523551940918,35311.79198503494,109360,0,35311.79198503494,0.9861477017402648,0.0540336780250072,0.2743291088378314,43793,53764.87401008606,0.9955194592475892,0.0141528416424989,0.7760508584807527,0.987072765827179,0.0501789711415767,0.2923588761610667,43793 -18559.434452056885,5.98870325088501,35551.902698516846,110108,0,35551.902698516846,0.9861477017402648,0.0540336780250072,0.2744111213048259,43793,54120.36315464973,0.9955502152442932,0.0139934895560145,0.7846037522654938,0.987072765827179,0.0501789785921573,0.2923811446565916,43793 -18671.87842154503,6.035488843917847,35791.86517548561,110858,0,35791.86517548561,0.9861477017402648,0.0540336854755878,0.2744625788209498,43793,54472.83672046661,0.9954262971878052,0.0142672462388873,0.7683804375932366,0.987072765827179,0.0501789711415767,0.2923924791678979,43793 -18784.168533802032,6.103245735168457,36031.905665159225,111605,0,36031.905665159225,0.9861477017402648,0.0540336780250072,0.2744341457188285,43793,54825.25518202782,0.9955164790153505,0.0140737434849143,0.7725581660335283,0.987072765827179,0.0501789674162864,0.292436956294154,43793 -18892.454761505127,6.150546789169312,36271.98492002487,112356,0,36271.98492002487,0.9861477017402648,0.0540336780250072,0.2744849630053591,43793,55173.6886715889,0.9954671859741212,0.0142094586044549,0.7682750228101535,0.987072765827179,0.0501789674162864,0.2924133826804372,43793 -18996.58028960228,6.198683500289917,36512.089584350586,113099,0,36512.089584350586,0.9861477017402648,0.0540336854755878,0.2743955276193887,43793,55517.98745632172,0.9955037236213684,0.0141445789486169,0.7743666140906535,0.987072765827179,0.050178974866867,0.2924845483570632,43793 -19104.04130196572,6.246399402618408,36752.29545927048,113842,0,36752.29545927048,0.9861477017402648,0.0540336780250072,0.2743417720933081,43793,55865.72206258774,0.9955094456672668,0.0141254626214504,0.7793063321714918,0.987072765827179,0.0501789674162864,0.2925966404311597,43793 -19218.23925757408,6.296609163284302,36992.5061249733,114587,0,36992.5061249733,0.9861477017402648,0.0540336854755878,0.2744716788874261,43793,56220.20129728317,0.9954116940498352,0.0142963584512472,0.7696221359037616,0.987072765827179,0.050178974866867,0.2923634887193423,43793 -19325.64443707466,6.345672607421875,37232.54695749283,115340,0,37232.54695749283,0.9861477017402648,0.0540336780250072,0.2744870863371934,43793,56567.71599459648,0.9955100417137146,0.0141028864309191,0.7763613971947132,0.987072765827179,0.050178974866867,0.292477968978769,43793 -19433.797921419144,6.394840955734253,37472.57381772995,116085,0,37472.57381772995,0.9861477017402648,0.0540336854755878,0.2744327779240913,43793,56915.96519422531,0.995463728904724,0.0142246037721633,0.7704293389531462,0.987072765827179,0.0501789711415767,0.2923070868083733,43793 -19541.73554468155,6.446910858154297,37712.61404252052,116839,0,37712.61404252052,0.9861477017402648,0.0540336780250072,0.2744755774496301,43793,57264.01523518562,0.9954558610916138,0.0142057770863175,0.7722315712608714,0.987072765827179,0.0501789674162864,0.2924067980397405,43793 -19647.851148843765,6.495280981063843,37952.76475858688,117592,0,37952.76475858688,0.9861477017402648,0.0540336780250072,0.2744644973891466,43793,57610.35025882721,0.995553195476532,0.0139929689466953,0.7718386775875689,0.987072765827179,0.0501789785921573,0.2923722358786466,43793 -19765.65750479698,6.543762683868408,38192.82896399498,118339,0,38192.82896399498,0.9861477017402648,0.0540336780250072,0.2745096845449642,43793,57968.28904867172,0.9954729080200196,0.014159295707941,0.7737373209301305,0.987072765827179,0.050178974866867,0.2923521487118571,43793 -19874.38309621811,6.593130588531494,38432.975707530975,119088,0,38432.975707530975,0.9861477017402648,0.0540336854755878,0.2745536327412705,43793,58317.23108887672,0.9955180287361144,0.0141503792256116,0.7726563510216342,0.987072765827179,0.0501789711415767,0.2924237165053917,43793 -19986.17211055756,6.641467571258545,38673.135501623154,119839,0,38673.135501623154,0.9861477017402648,0.0540336780250072,0.2744036483974371,43793,58669.247754096985,0.995469093322754,0.0141744092106819,0.7726981784844295,0.987072765827179,0.0501789785921573,0.2923949008598069,43793 -20092.66683936119,6.689886331558228,38913.38436675072,120584,0,38913.38436675072,0.9861477017402648,0.0540336780250072,0.2744314207490373,43793,59016.06122899056,0.9954341650009156,0.0142730260267853,0.7724481400829802,0.987072765827179,0.050178974866867,0.2924784708646529,43793 -20207.542983531952,6.741236925125122,39153.452647686005,121324,0,39153.452647686005,0.9861477017402648,0.0540336780250072,0.2744400738644023,43793,59371.07761955261,0.9955268502235411,0.0140812043100595,0.7726060218973945,0.987072765827179,0.0501789674162864,0.2924184000733607,43793 -20320.10085272789,6.794872522354126,39393.46521615982,122076,0,39393.46521615982,0.9861477017402648,0.0540336780250072,0.274348726685141,43793,59723.72233343125,0.9955320358276368,0.0140745975077152,0.7766342497186973,0.987072765827179,0.0501789674162864,0.2924523927774579,43793 -20430.971091747284,6.845506191253662,39633.420177698135,122824,0,39633.420177698135,0.9861477017402648,0.0540336780250072,0.2744166570766237,43793,60074.61845064163,0.9954253435134888,0.014281541109085,0.7695000327559407,0.987072765827179,0.0501789785921573,0.2924960667603323,43793 -20540.36001110077,6.896385669708252,39873.37377142906,123573,0,39873.37377142906,0.9861477017402648,0.0540336780250072,0.2744540328933385,43793,60424.03245639801,0.9954928755760192,0.0141173610463738,0.7755888247643258,0.987072765827179,0.0501789711415767,0.2924262429932207,43793 -20650.623522281647,6.954053163528442,40113.52621340752,124305,0,40113.52621340752,0.9861477017402648,0.0540336780250072,0.2743404442294432,43793,60774.52795505524,0.9954410195350648,0.0142628876492381,0.7707968891314905,0.987072765827179,0.0501789785921573,0.2924575249004623,43793 -20759.99427127838,7.006211042404175,40353.63005948067,125052,0,40353.63005948067,0.9861477017402648,0.0540336780250072,0.2744823761778551,43793,61124.07531356812,0.9955234527587892,0.0140299526974558,0.7771478198568447,0.987072765827179,0.050178974866867,0.2923724415569426,43793 -20869.096217632294,7.056287527084351,40593.73137521744,125798,0,40593.73137521744,0.9861477017402648,0.0540336780250072,0.2744143988411106,43793,61473.34885716438,0.9955294132232666,0.0140703143551945,0.7733585278918658,0.987072765827179,0.0501789785921573,0.2923784314654618,43793 -20976.6785402298,7.106051445007324,40833.93912935257,126555,0,40833.93912935257,0.9861477017402648,0.0540336780250072,0.2743924223147777,43793,61821.20895910263,0.9954558610916138,0.0142387766391038,0.7651018401588172,0.987072765827179,0.0501789711415767,0.2926135407799796,43793 -21086.576821565628,7.158508539199829,41074.18945026398,127310,0,41074.18945026398,0.9861477017402648,0.0540336780250072,0.2743393079338933,43793,62171.43013072014,0.9954926371574402,0.0141503317281603,0.7795920817475045,0.987072765827179,0.0501789711415767,0.2925470066934093,43793 -21195.692762613297,7.215036869049072,41314.4131102562,128049,0,41314.4131102562,0.9861477017402648,0.0540336780250072,0.2744622774438503,43793,62520.848984479904,0.9954724311828612,0.0141659248620271,0.7580778686618597,0.987072765827179,0.0501789711415767,0.2922960154879298,43793 -21306.24901342392,7.266671419143677,41554.57661104202,128800,0,41554.57661104202,0.9861477017402648,0.0540336780250072,0.2744522066223132,43793,62871.640630960464,0.9954509735107422,0.0143068879842758,0.772381038299001,0.987072765827179,0.0501789711415767,0.2925018996634979,43793 -21421.27315235138,7.316658735275267,41794.73320269585,129536,0,41794.73320269585,0.9861477017402648,0.0540336854755878,0.2744407621982706,43793,63226.893221616745,0.9955395460128784,0.0140335224568843,0.7769832884068464,0.987072765827179,0.050178974866867,0.2925394755638397,43793 -21530.62320971489,7.376136064529419,42034.90148067474,130272,0,42034.90148067474,0.9861477017402648,0.0540336780250072,0.2744175331022145,43793,63576.49175739288,0.9954925775527954,0.0141486516222357,0.7728318242745001,0.987072765827179,0.050178974866867,0.2924767408212498,43793 -21641.708025217056,7.427554130554199,42274.94419336319,131001,0,42274.94419336319,0.9861477017402648,0.0540336780250072,0.2744591741016137,43793,63927.694029569626,0.9954299330711364,0.0142009994015097,0.7792628467588556,0.987072765827179,0.0501789711415767,0.2925354472598478,43793 -21753.172875642776,7.480916500091553,42515.0906047821,131735,0,42515.0906047821,0.9861477017402648,0.0540336780250072,0.2744910385094342,43793,64279.381809711456,0.995569348335266,0.0139863276854157,0.7674853021467718,0.987072765827179,0.0501789711415767,0.2923617745264467,43793 -21864.389171123505,7.533536434173584,42755.26333451271,132474,0,42755.26333451271,0.9861477017402648,0.0540336780250072,0.2744854342397442,43793,64630.84366893768,0.9954354166984558,0.0142763331532478,0.7732393560158943,0.987072765827179,0.0501789785921573,0.2925081247424754,43793 -21974.587879419327,7.593179702758789,42995.351095438,133192,0,42995.351095438,0.9861477017402648,0.0540336780250072,0.2745147941177334,43793,64981.21354317665,0.9955052733421326,0.0141240824013948,0.7706519939588254,0.987072765827179,0.0501789711415767,0.292544236251272,43793 -22080.42444038391,7.647066116333008,43235.31395435333,133937,0,43235.31395435333,0.9861477017402648,0.0540336780250072,0.2744378650368807,43793,65327.087550640106,0.9955182671546936,0.0140765700489282,0.7777763457581376,0.987072765827179,0.050178974866867,0.292485678695066,43793 -22188.27728843689,7.699393749237059,43475.29788017273,134684,0,43475.29788017273,0.9861477017402648,0.0540336854755878,0.2744699717030999,43793,65674.99710512161,0.9954232573509216,0.0143252816051244,0.7749691994204116,0.987072765827179,0.050178974866867,0.292514974609005,43793 -22292.08743953705,7.750726938247681,43715.30280971527,135432,0,43715.30280971527,0.9861477017402648,0.0540336780250072,0.2743817458990037,43793,66018.88409805298,0.9955505132675172,0.0139947505667805,0.771572487802147,0.987072765827179,0.0501789674162864,0.2924638056018725,43793 -22403.13285589218,7.805722713470459,43955.23675751686,136175,0,43955.23675751686,0.9861477017402648,0.0540336780250072,0.2743165139314618,43793,66369.940782547,0.9954281449317932,0.0143410833552479,0.7698603206108076,0.987072765827179,0.0501789674162864,0.2923284004565294,43793 -22508.251557588577,7.863274812698364,44195.45515322685,136917,0,44195.45515322685,0.9861477017402648,0.0540336780250072,0.2744015337243864,43793,66715.35626888275,0.9954986572265624,0.0141018470749259,0.7790549032148038,0.987072765827179,0.0501789785921573,0.2925348605695119,43793 -22616.19847869873,7.916009187698364,44435.672504901886,137657,0,44435.672504901886,0.9861477017402648,0.0540336780250072,0.2743812715418015,43793,67063.59295868874,0.9955148100852966,0.0141217578202486,0.7732282225552818,0.987072765827179,0.050178974866867,0.2923614648157295,43793 -22724.865947008133,7.969079494476318,44675.83380961418,138404,0,44675.83380961418,0.9861477017402648,0.0540336780250072,0.2746092836619395,43793,67412.49493050575,0.9954774975776672,0.0141524942591786,0.7739244373319915,0.987072765827179,0.0501789711415767,0.2923292070342875,43793 -22832.310029745106,8.023373365402222,44915.87771701813,139153,0,44915.87771701813,0.9861477017402648,0.0540336854755878,0.2744020805623824,43793,67760.05751276016,0.9954907298088074,0.0141388094052672,0.7785713528209731,0.987072765827179,0.0501789674162864,0.2926596141909245,43793 -22938.924989938736,8.08329463005066,45156.082431316376,139900,0,45156.082431316376,0.9861477017402648,0.0540336854755878,0.2744188849967219,43793,68106.95738077164,0.995424747467041,0.0142870005220174,0.761383769673432,0.987072765827179,0.0501789785921573,0.292531407052199,43793 -23049.42386555672,8.13773775100708,45396.19762325287,140642,0,45396.19762325287,0.9861477017402648,0.0540336854755878,0.2745654491453426,43793,68457.6458311081,0.9955102205276488,0.0141165740787982,0.7750042961359505,0.987072765827179,0.0501789674162864,0.2924117972188245,43793 -23157.95317864418,8.191657781600952,45636.25564050674,141381,0,45636.25564050674,0.9861477017402648,0.0540336780250072,0.2744659269564402,43793,68806.30915427208,0.9955143928527832,0.0140935350209474,0.777542621227472,0.987072765827179,0.0501789674162864,0.2923862031675611,43793 -23266.372206687927,8.247857809066772,45876.27267360687,142126,0,45876.27267360687,0.9861477017402648,0.0540336780250072,0.2745318743089971,43793,69154.82185125351,0.9954879283905028,0.0141484467312693,0.7753354730267371,0.987072765827179,0.0501789711415767,0.2923758806923691,43793 -23373.897718191147,8.308566331863403,46116.45699834824,142855,0,46116.45699834824,0.9861477017402648,0.0540336780250072,0.2744458325069142,43793,69502.61528301239,0.9955083727836608,0.0140606602653861,0.783613736914508,0.987072765827179,0.0501789674162864,0.2923354873325368,43793 -23486.36240339279,8.371394395828247,46356.47380423546,143584,0,46356.47380423546,0.9861477017402648,0.0540336780250072,0.2744276065726812,43793,69855.18458938599,0.9954497814178468,0.0142854899168014,0.7609779749806512,0.987072765827179,0.0501789711415767,0.2925092009419555,43793 -23593.684818267822,8.426125526428223,46596.55307674408,144327,0,46596.55307674408,0.9861477017402648,0.0540336854755878,0.2744851973090345,43793,70202.66167020798,0.9954796433448792,0.0142038194462656,0.7661755426876338,0.987072765827179,0.0501789674162864,0.2926325091584312,43793 -23699.498154878616,8.482213497161865,46836.53562474251,145070,0,46836.53562474251,0.9861477017402648,0.0540336780250072,0.274352949729243,43793,70548.53416538239,0.9954817295074464,0.0141157330945134,0.7795299761374892,0.987072765827179,0.050178974866867,0.2923431012403145,43793 -23806.48648405075,8.537409782409668,47076.47433376312,145811,0,47076.47433376312,0.9861477017402648,0.0540336780250072,0.2744340913738719,43793,70895.5379679203,0.995508313179016,0.0141398878768086,0.7773677847036007,0.987072765827179,0.0501789674162864,0.2924520914694519,43793 -23914.076204538345,8.591312885284424,47316.4525346756,146559,0,47316.4525346756,0.9861477017402648,0.0540336780250072,0.2745671273673323,43793,71243.1797504425,0.9955046772956848,0.0141176972538232,0.7730506080338739,0.987072765827179,0.0501789674162864,0.2924591081368707,43793 -24018.88875675201,8.645967245101929,47556.43048453331,147295,0,47556.43048453331,0.9861477017402648,0.0540336780250072,0.2743590851789299,43793,71588.04568099976,0.9955081939697266,0.0140927387401461,0.7726374976251774,0.987072765827179,0.0501789711415767,0.2924923169390194,43793 -24124.432327747345,8.701493501663208,47796.37060403824,148044,0,47796.37060403824,0.9861477017402648,0.0540336780250072,0.2744173503944239,43793,71933.60494160652,0.995439887046814,0.0142427571117877,0.7694661257596855,0.987072765827179,0.050178974866867,0.292422802303003,43793 -24230.023329257965,8.754802942276001,48036.4431014061,148793,0,48036.4431014061,0.9861477017402648,0.0540336780250072,0.2743672697535215,43793,72279.34216976166,0.9954816102981568,0.0141702918335795,0.7750624725733591,0.987072765827179,0.0501789674162864,0.2924158822165662,43793 -24339.857694149017,8.81046986579895,48276.60923433304,149541,0,48276.60923433304,0.9861477017402648,0.0540336780250072,0.2745146309739536,43793,72629.41816806793,0.9955478310585022,0.0140143623575568,0.7779941175755147,0.987072765827179,0.050178974866867,0.2925357045022977,43793 -24444.07306456566,8.864438533782959,48516.63416361809,150297,0,48516.63416361809,0.9861477017402648,0.0540336780250072,0.2744235536170402,43793,72973.73292207718,0.9954404234886168,0.0142840798944234,0.7704192706591224,0.987072765827179,0.0501789674162864,0.2923980627222845,43793 -24551.4905269146,8.925953388214111,48756.61280536652,151056,0,48756.61280536652,0.9861477017402648,0.0540336780250072,0.2744903066878821,43793,73321.21061086655,0.9955169558525084,0.0140691706910729,0.7767232637434311,0.987072765827179,0.0501789674162864,0.2923393462303533,43793 -24659.799003601074,9.30262327194214,48996.37939405441,151802,0,48996.37939405441,0.9861477017402648,0.0540336780250072,0.2743454136781733,43793,73669.68244862556,0.9954544305801392,0.0141996936872601,0.7650808939684499,0.987072765827179,0.0501789711415767,0.2923891979268853,43793 -24766.657678842545,9.358002424240112,49236.43760156632,152546,0,49236.43760156632,0.9861477017402648,0.0540336780250072,0.274417204869546,43793,74016.67475938797,0.9955030083656312,0.0141865722835063,0.7775761097654365,0.987072765827179,0.050178974866867,0.2923806819872739,43793 -24867.64109492302,9.417367696762083,49476.56862664223,153288,0,49476.56862664223,0.9861477017402648,0.0540336780250072,0.2743833817906149,43793,74357.86896109581,0.9954869747161864,0.014114367775619,0.775209529543543,0.987072765827179,0.050178974866867,0.2923222150566751,43793 -24974.34710907936,9.47364616394043,49716.54668498039,154025,0,49716.54668498039,0.9861477017402648,0.0540336780250072,0.2744389354879151,43793,74704.63004040718,0.9954938888549804,0.0141519317403435,0.7746333298750403,0.987072765827179,0.0501789674162864,0.2925510130236306,43793 -25078.46404027939,9.529872179031372,49956.64419960976,154776,0,49956.64419960976,0.9861477017402648,0.0540336780250072,0.2744790929915433,43793,75048.92160797119,0.995477020740509,0.0141919394955039,0.772443339172785,0.987072765827179,0.0501789674162864,0.292435496395953,43793 -25187.63874030113,9.584931373596191,50196.8704969883,155522,0,50196.8704969883,0.9861477017402648,0.0540336780250072,0.274395693634054,43793,75398.39845585823,0.9955002665519714,0.014126512221992,0.76504731206366,0.987072765827179,0.050178974866867,0.2924108013284655,43793 -25296.34480142593,9.640064239501951,50436.93785953522,156274,0,50436.93785953522,0.9861477017402648,0.0540336780250072,0.2744178720782878,43793,75747.2478158474,0.9954464435577391,0.0142757603898644,0.7749395935942144,0.987072765827179,0.0501789674162864,0.2924427218899189,43793 -25406.3370552063,9.698690176010132,50677.16159534454,157024,0,50677.16159534454,0.9861477017402648,0.0540336780250072,0.2743299048412865,43793,76097.5432009697,0.9955170154571532,0.0140560949221253,0.773829621089907,0.987072765827179,0.050178974866867,0.2922885314680591,43793 -25511.156789064407,9.754071712493896,50917.29948115349,157769,0,50917.29948115349,0.9861477017402648,0.0540336854755878,0.2744481542418377,43793,76442.57691502571,0.9955052137374878,0.0141193345189094,0.7832026154449414,0.987072765827179,0.050178974866867,0.2924538217306947,43793 -25618.3601474762,9.810832738876345,51157.23203110695,158521,0,51157.23203110695,0.9861477017402648,0.0540336780250072,0.2744134394539258,43793,76789.78963327408,0.995479941368103,0.0141672519966959,0.7708454919205815,0.987072765827179,0.0501789674162864,0.2924374967382563,43793 -25725.11342215538,9.86826467514038,51397.40039587021,159271,0,51397.40039587021,0.9861477017402648,0.0540336780250072,0.2745207814921742,43793,77136.78889465332,0.9954665899276732,0.0142047116532921,0.7754148010822717,0.987072765827179,0.0501789711415767,0.2925999843117496,43793 -25829.893330335617,9.92480731010437,51637.54203295708,160026,0,51637.54203295708,0.9861477017402648,0.0540336780250072,0.2745612762808954,43793,77481.7868654728,0.9954699277877808,0.0141995521262288,0.7704376398776334,0.987072765827179,0.050178974866867,0.2923189401479742,43793 -25930.87596058845,9.9811851978302,51877.696434021,160775,0,51877.696434021,0.9861477017402648,0.0540336780250072,0.2744706562424572,43793,77823.0005209446,0.9954940676689148,0.0141418268904089,0.7768644225591141,0.987072765827179,0.0501789674162864,0.2923581265098338,43793 -26039.03204393387,10.039782524108888,52117.85686635971,161528,0,52117.85686635971,0.9861477017402648,0.0540336780250072,0.2743982161334868,43793,78171.39575123787,0.995476245880127,0.0142111340537667,0.7777987133311279,0.987072765827179,0.0501789711415767,0.2924305651890328,43793 -26143.434961795807,10.106394290924072,52358.030159950256,162275,0,52358.030159950256,0.9861477017402648,0.0540336780250072,0.2744069678953243,43793,78516.05895709991,0.9954986572265624,0.0141007341444492,0.7723155867022389,0.987072765827179,0.0501789785921573,0.2925567746610274,43793 -26246.93022799492,10.163278579711914,52598.15224838257,163022,0,52598.15224838257,0.9861477017402648,0.0540336780250072,0.2744230971769069,43793,78859.75465726852,0.9955146312713624,0.0140972472727298,0.7758655269731343,0.987072765827179,0.0501789674162864,0.292350433179705,43793 -26349.03185725212,10.222217321395874,52838.39267086983,163771,0,52838.39267086983,0.9861477017402648,0.0540336780250072,0.2744978733779977,43793,79202.17622852325,0.99542498588562,0.0142718916758894,0.7642837725594658,0.987072765827179,0.0501789785921573,0.2923946279295678,43793 -26455.47572350502,10.280250310897827,53078.33946681023,164511,0,53078.33946681023,0.9861477017402648,0.0540336780250072,0.2743908303667785,43793,79548.64507198334,0.9954708814620972,0.0141994431614875,0.7770340807670562,0.987072765827179,0.0501789674162864,0.2923320008925272,43793 -26558.87539362908,10.338693141937256,53318.45733141899,165264,0,53318.45733141899,0.9861477017402648,0.0540336780250072,0.2744477263950919,43793,79892.24152255058,0.9955559968948364,0.014011014252901,0.7774745704495394,0.987072765827179,0.0501789674162864,0.2923164205091282,43793 -26662.17934703827,10.398459672927856,53558.51423501968,166013,0,53558.51423501968,0.9861477017402648,0.0540336780250072,0.2744653453568398,43793,80235.681828022,0.9955283999443054,0.0141153950244188,0.7682561890850204,0.987072765827179,0.0501789785921573,0.2923437326480599,43793 -26771.942001581192,10.45738959312439,53798.596177339554,166758,0,53798.596177339554,0.9861477017402648,0.0540336780250072,0.2744245503170944,43793,80585.60542726517,0.9954188466072084,0.0142263481393456,0.7785996373453546,0.987072765827179,0.0501789711415767,0.2923470375904951,43793 -26880.91184401512,10.514888763427734,54038.70073080063,167501,0,54038.70073080063,0.9861477017402648,0.0540336780250072,0.274327476345126,43793,80934.76012897491,0.9954667687416076,0.0141899297013878,0.7694678128577732,0.987072765827179,0.0501789785921573,0.2924166168514924,43793 -26992.4874894619,10.573052167892456,54278.76337981224,168246,0,54278.76337981224,0.9861477017402648,0.0540336780250072,0.2744346043114226,43793,81286.47692489624,0.9954833388328552,0.014200116507709,0.7711502631017008,0.987072765827179,0.0501789711415767,0.2924447621749889,43793 -27097.03732419014,10.639933347702026,54518.72223806381,168989,0,54518.72223806381,0.9861477017402648,0.0540336780250072,0.2744506152585214,43793,81631.07444930077,0.9955009818077089,0.0141321178525686,0.7773532093488216,0.987072765827179,0.050178974866867,0.2924251409088802,43793 -27202.15098762512,10.69777512550354,54758.81324410439,169737,0,54758.81324410439,0.9861477017402648,0.0540336780250072,0.2743306396429991,43793,81976.3571498394,0.9955348372459412,0.0140738934278488,0.7799641597540772,0.987072765827179,0.0501789711415767,0.2924950547097051,43793 -27307.94575953484,10.756629467010498,54998.77826976776,170490,0,54998.77826976776,0.9861477017402648,0.0540336780250072,0.2744199995516733,43793,82322.19646334648,0.995444118976593,0.0141743142157793,0.7645382818223485,0.987072765827179,0.0501789674162864,0.2925678030672734,43793 -27419.886949539185,10.8176851272583,55238.96902322769,171230,0,55238.96902322769,0.9861477017402649,0.05403367802500725,0.2744220064947314,43793,82674.4102742672,0.995492696762085,0.014181965962052345,0.7743298517611776,0.987072765827179,0.050178974866867065,0.29255196553221835,43793 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index ac7bca8ee..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1952 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.8176769,0.7969516,,,,,,,,,,,,,,,,, -1,,,0.4159256517887115,0.7988808155059814,0.0215982155287161,0.4129833579063415,0.7986935973167419,0.0252096170502471,43793.0,0.413312017917633,0.7984466552734375,0.0260650279516998,43793.0,19.18818688392639,556.0487859249115,19.18818688392639,536.8605499267578,0.0,0.0 -100,0.3168575,0.29237947,,,,,,,,,,,,,,,,, -200,0.10401216,0.11525812,,,,,,,,,,,,,,,,, -300,0.03403454,0.07123977,,,,,,,,,,,,,,,,, -400,0.020361101,0.062121026,,,,,,,,,,,,,,,,, -500,0.037369,0.055113003,,,,,,,,,,,,,,,,, -600,0.08303836,0.051122222,,,,,,,,,,,,,,,,, -700,0.022644823,0.051207032,,,,,,,,,,,,,,,,, -744,,,0.9869332909584044,0.0506426654756069,0.0564247931206779,0.9842758178710938,0.060369711369276,0.0555193829226929,43793.0,0.983284831047058,0.0636374056339263,0.0572574487730643,43793.0,259.19083619117737,918.4641320705414,259.19083619117737,659.2226083278656,0.0308315753936767,0.0 -800,0.0139924325,0.047376785,,,,,,,,,,,,,,,,, -900,0.024515338,0.053743392,,,,,,,,,,,,,,,,, -1000,0.029711735,0.050741546,,,,,,,,,,,,,,,,, -1100,0.018164707,0.04895498,,,,,,,,,,,,,,,,, -1200,0.025821256,0.04787698,,,,,,,,,,,,,,,,, -1300,0.043173265,0.050634526,,,,,,,,,,,,,,,,, -1400,0.028970411,0.047887404,,,,,,,,,,,,,,,,, -1481,,,0.9874383211135864,0.0463490262627601,0.1024011862839557,0.9846505522727966,0.0559681691229343,0.1030645214502391,43793.0,0.9836386442184448,0.0592108108103275,0.1004557028126341,43793.0,499.22732520103455,1281.8988349437714,499.22732520103455,782.5733218193054,0.0585832595825195,0.0 -1500,0.018785696,0.048534907,,,,,,,,,,,,,,,,, -1600,0.024491739,0.045746952,,,,,,,,,,,,,,,,, -1700,0.0373838,0.044479538,,,,,,,,,,,,,,,,, -1800,0.016535696,0.04489483,,,,,,,,,,,,,,,,, -1900,0.015485273,0.04706202,,,,,,,,,,,,,,,,, -2000,0.018703071,0.04769511,,,,,,,,,,,,,,,,, -2100,0.011753742,0.039850794,,,,,,,,,,,,,,,,, -2200,0.010969804,0.03860945,,,,,,,,,,,,,,,,, -2219,,,0.9876528978347778,0.044420201331377,0.1416577263864256,0.984782874584198,0.0541465990245342,0.137276791969974,43793.0,0.9837620854377748,0.0573773942887783,0.1340239544379468,43793.0,739.374596118927,1644.7690522670746,739.374596118927,905.2488882541656,0.0856134891510009,0.0 -2300,0.01930325,0.04200928,,,,,,,,,,,,,,,,, -2400,0.011556304,0.0405158,,,,,,,,,,,,,,,,, -2500,0.012202272,0.04269539,,,,,,,,,,,,,,,,, -2600,0.018665249,0.04210348,,,,,,,,,,,,,,,,, -2700,0.018503105,0.04549755,,,,,,,,,,,,,,,,, -2800,0.019552594,0.04496784,,,,,,,,,,,,,,,,, -2900,0.01136059,0.04445969,,,,,,,,,,,,,,,,, -2972,,,0.988052487373352,0.042341161519289,0.1634633712465514,0.9852432012557985,0.0518907941877841,0.1526518339536486,43793.0,0.9842371940612792,0.0549632683396339,0.1510426967436176,43793.0,979.5938329696656,2009.3818798065183,979.5938329696656,1029.594246149063,0.1141455173492431,0.0 -3000,0.01611384,0.042463884,,,,,,,,,,,,,,,,, -3100,0.010604904,0.04193078,,,,,,,,,,,,,,,,, -3200,0.009723875,0.043793134,,,,,,,,,,,,,,,,, -3300,0.011847022,0.042650662,,,,,,,,,,,,,,,,, -3400,0.0126790395,0.044060968,,,,,,,,,,,,,,,,, -3500,0.011518176,0.038687497,,,,,,,,,,,,,,,,, -3600,0.011132275,0.038447365,,,,,,,,,,,,,,,,, -3700,0.010035537,0.040817223,,,,,,,,,,,,,,,,, -3731,,,0.9881816506385804,0.0408075787127018,0.1910989147009367,0.9852667450904846,0.0502415895462036,0.1666665480608038,43793.0,0.9843096137046814,0.053251676261425,0.1643560816542293,43793.0,1219.685434341431,2376.389413833618,1219.685434341431,1156.464278936386,0.1400582790374755,0.0 -3800,0.012210219,0.04310893,,,,,,,,,,,,,,,,, -3900,0.022705162,0.04519741,,,,,,,,,,,,,,,,, -4000,0.014875881,0.039580505,,,,,,,,,,,,,,,,, -4100,0.012138567,0.040149,,,,,,,,,,,,,,,,, -4200,0.0144183105,0.046758167,,,,,,,,,,,,,,,,, -4300,0.01572337,0.03715817,,,,,,,,,,,,,,,,, -4400,0.018405171,0.040916324,,,,,,,,,,,,,,,,, -4491,,,0.9886085391044616,0.0388112030923366,0.2143681979482724,0.985697865486145,0.0487300790846347,0.1857914809779446,43793.0,0.9847784042358398,0.051520898938179,0.1892872621373282,43793.0,1459.7971086502075,2742.1424379348755,1459.7971086502075,1282.059383392334,0.1661477088928222,0.0 -4500,0.008281799,0.035406534,,,,,,,,,,,,,,,,, -4600,0.014743689,0.041961107,,,,,,,,,,,,,,,,, -4700,0.012739566,0.03716656,,,,,,,,,,,,,,,,, -4800,0.01110523,0.040268857,,,,,,,,,,,,,,,,, -4900,0.020470103,0.03568149,,,,,,,,,,,,,,,,, -5000,0.01310933,0.03701922,,,,,,,,,,,,,,,,, -5100,0.012750776,0.037506875,,,,,,,,,,,,,,,,, -5200,0.018333282,0.039426636,,,,,,,,,,,,,,,,, -5251,,,0.9887627959251404,0.0380763970315456,0.2471778039737912,0.9855229258537292,0.0480495318770408,0.1963845763391835,43793.0,0.9847156405448914,0.0506240203976631,0.2009949060040942,43793.0,1699.9812216758728,3112.7249789237976,1699.9812216758728,1412.4098732471466,0.1936049461364746,0.0 -5300,0.014067859,0.03661409,,,,,,,,,,,,,,,,, -5400,0.0099365255,0.036215413,,,,,,,,,,,,,,,,, -5500,0.011091538,0.042244934,,,,,,,,,,,,,,,,, -5600,0.014835398,0.039553184,,,,,,,,,,,,,,,,, -5700,0.014413313,0.03922562,,,,,,,,,,,,,,,,, -5800,0.009702098,0.03961841,,,,,,,,,,,,,,,,, -5900,0.013627697,0.042018533,,,,,,,,,,,,,,,,, -6000,0.01095878,0.034761123,,,,,,,,,,,,,,,,, -6007,,,0.9892771244049072,0.0365316681563854,0.2702459337040197,0.986055076122284,0.0466743931174278,0.2205379475830717,43793.0,0.98514062166214,0.0493680611252784,0.2163814131888277,43793.0,1940.042201757431,3480.2014927864075,1940.042201757431,1539.7698109149933,0.2290825843811035,0.0 -6100,0.013831774,0.036212273,,,,,,,,,,,,,,,,, -6200,0.012907647,0.036838464,,,,,,,,,,,,,,,,, -6300,0.017017445,0.03931703,,,,,,,,,,,,,,,,, -6400,0.016929518,0.04010206,,,,,,,,,,,,,,,,, -6500,0.014949971,0.037511684,,,,,,,,,,,,,,,,, -6600,0.017933326,0.039361306,,,,,,,,,,,,,,,,, -6700,0.013554266,0.03296085,,,,,,,,,,,,,,,,, -6754,,,0.9893851280212402,0.0358582735061645,0.2830660872122692,0.9861866235733032,0.0466683804988861,0.2184325413342826,43793.0,0.985238790512085,0.0494339019060134,0.2141966573978142,43793.0,2180.241662979126,3850.060249567032,2180.241662979126,1669.380942583084,0.2561860084533691,0.0 -6800,0.01486736,0.037658636,,,,,,,,,,,,,,,,, -6900,0.016852781,0.037515003,,,,,,,,,,,,,,,,, -7000,0.017276393,0.03955469,,,,,,,,,,,,,,,,, -7100,0.012672594,0.039834276,,,,,,,,,,,,,,,,, -7200,0.0153131215,0.036409516,,,,,,,,,,,,,,,,, -7300,0.02619241,0.03761406,,,,,,,,,,,,,,,,, -7400,0.014376548,0.040863138,,,,,,,,,,,,,,,,, -7500,0.01879261,0.038109686,,,,,,,,,,,,,,,,, -7502,,,0.9895324110984802,0.035286109894514,0.3089379075079039,0.9861618280410768,0.0457406304776668,0.2240028748716719,43793.0,0.985331416130066,0.0481740869581699,0.2225995118593554,43793.0,2420.270122051239,4213.264726400375,2420.270122051239,1792.50821018219,0.2849104404449463,0.0 -7600,0.013739595,0.040074937,,,,,,,,,,,,,,,,, -7700,0.013353367,0.037552822,,,,,,,,,,,,,,,,, -7800,0.017298928,0.040222332,,,,,,,,,,,,,,,,, -7900,0.0155312475,0.033853054,,,,,,,,,,,,,,,,, -8000,0.014315587,0.032689817,,,,,,,,,,,,,,,,, -8100,0.024891354,0.040100794,,,,,,,,,,,,,,,,, -8200,0.028276818,0.034511216,,,,,,,,,,,,,,,,, -8252,,,0.9896116852760316,0.0348451845347881,0.2959481314296294,0.9864756464958192,0.0453048162162303,0.2330827604520835,43793.0,0.9855239391326904,0.0481030829250812,0.2293570996321657,43793.0,2660.5069646835327,4581.020742177963,2660.5069646835327,1919.977798461914,0.3132412433624267,0.0 -8300,0.01965271,0.033564147,,,,,,,,,,,,,,,,, -8400,0.015834732,0.03719676,,,,,,,,,,,,,,,,, -8500,0.02305593,0.034542263,,,,,,,,,,,,,,,,, -8600,0.029696444,0.036788158,,,,,,,,,,,,,,,,, -8700,0.016745338,0.032593336,,,,,,,,,,,,,,,,, -8800,0.017136909,0.034063414,,,,,,,,,,,,,,,,, -8900,0.017627915,0.03522176,,,,,,,,,,,,,,,,, -9000,0.020407274,0.03790874,,,,,,,,,,,,,,,,, -9003,,,0.98968768119812,0.0344913601875305,0.3023390847822264,0.9864464402198792,0.0456723272800445,0.2355386945639433,43793.0,0.9855251908302308,0.0483233146369457,0.2416666839001448,43793.0,2900.5411689281464,4951.450447320938,2900.5411689281464,2050.3256623744965,0.3406836986541748,0.0 -9100,0.02068756,0.038706467,,,,,,,,,,,,,,,,, -9200,0.015938025,0.03256019,,,,,,,,,,,,,,,,, -9300,0.025839647,0.03518712,,,,,,,,,,,,,,,,, -9400,0.018991362,0.03559945,,,,,,,,,,,,,,,,, -9500,0.015270545,0.03389803,,,,,,,,,,,,,,,,, -9600,0.02472681,0.030358678,,,,,,,,,,,,,,,,, -9700,0.021257445,0.034914654,,,,,,,,,,,,,,,,, -9738,,,0.9900109171867372,0.0331563651561737,0.3391733973208491,0.9865714311599731,0.044988140463829,0.2418029850919489,43793.0,0.985731601715088,0.0476286262273788,0.2443877526841293,43793.0,3140.5600502491,5317.872577667236,3140.5600502491,2176.679932117462,0.3683798313140869,0.0 -9800,0.019524839,0.03294609,,,,,,,,,,,,,,,,, -9900,0.019977313,0.0294885,,,,,,,,,,,,,,,,, -10000,0.019072972,0.0324784,,,,,,,,,,,,,,,,, -10100,0.022847777,0.041384894,,,,,,,,,,,,,,,,, -10200,0.022803681,0.037258092,,,,,,,,,,,,,,,,, -10300,0.022653274,0.036092002,,,,,,,,,,,,,,,,, -10400,0.023600206,0.03364772,,,,,,,,,,,,,,,,, -10481,,,0.9901197552680968,0.0329639092087745,0.3427990933394376,0.9865767359733582,0.044519018381834,0.2523031039881308,43793.0,0.9857054352760316,0.0473709106445312,0.2480712251497775,43793.0,3380.5997710227966,5688.5130405426025,3380.5997710227966,2307.231298685074,0.3975164890289306,0.0 -10500,0.023924677,0.035022847,,,,,,,,,,,,,,,,, -10600,0.026141677,0.03790025,,,,,,,,,,,,,,,,, -10700,0.023289667,0.034919538,,,,,,,,,,,,,,,,, -10800,0.028320719,0.033973623,,,,,,,,,,,,,,,,, -10900,0.020257168,0.03257068,,,,,,,,,,,,,,,,, -11000,0.024104854,0.032346833,,,,,,,,,,,,,,,,, -11100,0.025206491,0.036526658,,,,,,,,,,,,,,,,, -11200,0.020701196,0.034059707,,,,,,,,,,,,,,,,, -11228,,,0.990300476551056,0.0322925746440887,0.3670847887656891,0.9866266250610352,0.0446537360548973,0.256547329570424,43793.0,0.9857454895973206,0.0474274680018425,0.2429146183176039,43793.0,3620.5970821380615,6058.763890743256,3620.5970821380615,2437.435226202011,0.4268772602081299,0.0 -11300,0.022960972,0.034766648,,,,,,,,,,,,,,,,, -11400,0.024858708,0.0319999,,,,,,,,,,,,,,,,, -11500,0.025958257,0.032953024,,,,,,,,,,,,,,,,, -11600,0.024222884,0.029755129,,,,,,,,,,,,,,,,, -11700,0.029143134,0.035057977,,,,,,,,,,,,,,,,, -11800,0.023344561,0.036948495,,,,,,,,,,,,,,,,, -11900,0.024155714,0.032385916,,,,,,,,,,,,,,,,, -11964,,,0.990463137626648,0.0314462892711162,0.3924278482920513,0.9866583347320556,0.0446167550981044,0.2597770575960549,43793.0,0.9857821464538574,0.0473833791911602,0.2490345214654525,43793.0,3860.692731380463,6435.028518915176,3860.692731380463,2573.549907445908,0.4588387012481689,0.0 -12000,0.02818826,0.034065668,,,,,,,,,,,,,,,,, -12100,0.029530292,0.03539952,,,,,,,,,,,,,,,,, -12200,0.028059237,0.033380758,,,,,,,,,,,,,,,,, -12300,0.02638173,0.030423323,,,,,,,,,,,,,,,,, -12400,0.020731688,0.03161689,,,,,,,,,,,,,,,,, -12500,0.03384843,0.03405277,,,,,,,,,,,,,,,,, -12600,0.026957547,0.034945693,,,,,,,,,,,,,,,,, -12700,0.027243815,0.033796,,,,,,,,,,,,,,,,, -12707,,,0.9905829429626464,0.0309941451996564,0.387986523878819,0.9866992831230164,0.0445774421095848,0.262996735421005,43793.0,0.9858554005622864,0.0474821887910366,0.2588720984792856,43793.0,4100.820775270462,6805.316474914551,4100.820775270462,2703.654602050781,0.4913444519042969,0.0 -12800,0.029122077,0.0347348,,,,,,,,,,,,,,,,, -12900,0.02750661,0.03418482,,,,,,,,,,,,,,,,, -13000,0.025699781,0.030459309,,,,,,,,,,,,,,,,, -13100,0.0237146,0.031065324,,,,,,,,,,,,,,,,, -13200,0.030174803,0.03432681,,,,,,,,,,,,,,,,, -13300,0.030297052,0.034337036,,,,,,,,,,,,,,,,, -13400,0.02560462,0.029890554,,,,,,,,,,,,,,,,, -13452,,,0.9906689524650574,0.0307073928415775,0.3913719672774121,0.9868081212043762,0.0442623235285282,0.26061023076219,43793.0,0.9859215617179872,0.0470420643687248,0.2574354219744418,43793.0,4340.906958580017,7176.506344079971,4340.906958580017,2834.7084777355194,0.521092414855957,0.0 -13500,0.024407381,0.0359592,,,,,,,,,,,,,,,,, -13600,0.028844496,0.030375702,,,,,,,,,,,,,,,,, -13700,0.02714839,0.029929316,,,,,,,,,,,,,,,,, -13800,0.02307728,0.027971068,,,,,,,,,,,,,,,,, -13900,0.045609668,0.035800554,,,,,,,,,,,,,,,,, -14000,0.0400745,0.033561826,,,,,,,,,,,,,,,,, -14100,0.03752301,0.03317969,,,,,,,,,,,,,,,,, -14200,,,0.9906706213951112,0.0308509096503257,0.3912459024790902,0.9867467880249025,0.0441011153161525,0.2619010691977937,43793.0,0.985874354839325,0.0469127632677555,0.2512387055784661,43793.0,4581.159161567688,7543.514000415802,4581.159161567688,2961.4146535396576,0.5505294799804688,0.0 -14200,0.035758678,0.028997002,,,,,,,,,,,,,,,,, -14300,0.027526636,0.03476118,,,,,,,,,,,,,,,,, -14400,0.04043193,0.035932,,,,,,,,,,,,,,,,, -14500,0.038111046,0.03134549,,,,,,,,,,,,,,,,, -14600,0.025339566,0.030273572,,,,,,,,,,,,,,,,, -14700,0.04530024,0.033331573,,,,,,,,,,,,,,,,, -14800,0.038109265,0.031596307,,,,,,,,,,,,,,,,, -14900,0.03147125,0.03162441,,,,,,,,,,,,,,,,, -14944,,,0.9905703663825988,0.0310817845165729,0.382999974042977,0.9868085384368896,0.0445431284606456,0.2633013852647012,43793.0,0.9859371185302734,0.0473070107400417,0.2592420343565292,43793.0,4821.198109149933,7915.161383390427,4821.198109149933,3092.9744811058044,0.5795705318450928,0.0 -15000,0.025234092,0.029617716,,,,,,,,,,,,,,,,, -15100,0.04897748,0.032243524,,,,,,,,,,,,,,,,, -15200,0.027312154,0.027449315,,,,,,,,,,,,,,,,, -15300,0.03587427,0.03576749,,,,,,,,,,,,,,,,, -15400,0.04182959,0.0390797,,,,,,,,,,,,,,,,, -15500,0.03348238,0.034253392,,,,,,,,,,,,,,,,, -15600,0.043918546,0.032738622,,,,,,,,,,,,,,,,, -15692,,,0.9907108545303344,0.0305028110742568,0.4101220713336749,0.9868288040161132,0.0439677983522415,0.2711070559332986,43793.0,0.9859729409217834,0.0467361696064472,0.2624294844696725,43793.0,5061.430529117584,8285.463474035263,5061.430529117584,3222.9942405223846,0.6097488403320312,0.0 -15700,0.0407534,0.03615333,,,,,,,,,,,,,,,,, -15800,0.03618036,0.032723293,,,,,,,,,,,,,,,,, -15900,0.033987947,0.03282333,,,,,,,,,,,,,,,,, -16000,0.03169537,0.03123046,,,,,,,,,,,,,,,,, -16100,0.034546778,0.031791,,,,,,,,,,,,,,,,, -16200,0.047112186,0.031619333,,,,,,,,,,,,,,,,, -16300,0.033032432,0.028423263,,,,,,,,,,,,,,,,, -16400,0.03488565,0.032316983,,,,,,,,,,,,,,,,, -16434,,,0.9906653761863708,0.0305659603327512,0.3961345079374617,0.9868064522743224,0.0440406166017055,0.2708324881862166,43793.0,0.9860348105430604,0.0468795970082283,0.2668301827293936,43793.0,5301.552769899368,8655.007350206375,5301.552769899368,3352.3659710884094,0.6399135589599609,0.0 -16500,0.046775367,0.034555446,,,,,,,,,,,,,,,,, -16600,0.032008067,0.030005142,,,,,,,,,,,,,,,,, -16700,0.059359934,0.034042817,,,,,,,,,,,,,,,,, -16800,0.043836124,0.03316181,,,,,,,,,,,,,,,,, -16900,0.037312668,0.029903445,,,,,,,,,,,,,,,,, -17000,0.040749744,0.034838274,,,,,,,,,,,,,,,,, -17100,0.03789676,0.03350928,,,,,,,,,,,,,,,,, -17176,,,0.9908429384231568,0.0300751198083162,0.4131038732595066,0.9868007898330688,0.0441732369363307,0.2631906991121361,43793.0,0.9859800934791564,0.0467806570231914,0.2602629097743209,43793.0,5541.53219461441,9024.735594034197,5541.53219461441,3482.0657093524933,0.6692507266998291,0.0 -17200,0.03894709,0.03029865,,,,,,,,,,,,,,,,, -17300,0.040823314,0.03228401,,,,,,,,,,,,,,,,, -17400,0.0398039,0.033460896,,,,,,,,,,,,,,,,, -17500,0.03805678,0.033025034,,,,,,,,,,,,,,,,, -17600,0.03557171,0.033387195,,,,,,,,,,,,,,,,, -17700,0.035844907,0.03132279,,,,,,,,,,,,,,,,, -17800,0.038398925,0.030123176,,,,,,,,,,,,,,,,, -17900,0.039511923,0.03283869,,,,,,,,,,,,,,,,, -17933,,,0.991029977798462,0.0294215101748704,0.4207735413460807,0.9868117570877076,0.0441651605069637,0.2655381044547623,43793.0,0.9859012961387634,0.0467691197991371,0.262596531655433,43793.0,5781.594381570816,9391.62649512291,5781.594381570816,3608.844582557678,0.699115514755249,0.0 -18000,0.049304694,0.030149931,,,,,,,,,,,,,,,,, -18100,0.04516165,0.030357683,,,,,,,,,,,,,,,,, -18200,0.049517684,0.0359323,,,,,,,,,,,,,,,,, -18300,0.04379739,0.03225406,,,,,,,,,,,,,,,,, -18400,0.037698284,0.03214487,,,,,,,,,,,,,,,,, -18500,0.052256223,0.031275537,,,,,,,,,,,,,,,,, -18600,0.041978803,0.031226622,,,,,,,,,,,,,,,,, -18678,,,0.9913448691368104,0.0282798688858747,0.4553588251086291,0.9869843125343324,0.043611004948616,0.2739963602729493,43793.0,0.9860984086990356,0.0463713221251964,0.2668474315698864,43793.0,6021.693260192871,9761.431322336197,6021.693260192871,3738.4994037151337,0.730036735534668,0.0 -18700,0.0502999,0.029645693,,,,,,,,,,,,,,,,, -18800,0.047739822,0.032028534,,,,,,,,,,,,,,,,, -18900,0.04670657,0.034308128,,,,,,,,,,,,,,,,, -19000,0.03756248,0.032601874,,,,,,,,,,,,,,,,, -19100,0.048903342,0.033962026,,,,,,,,,,,,,,,,, -19200,0.042873785,0.033317562,,,,,,,,,,,,,,,,, -19300,0.056835536,0.03506462,,,,,,,,,,,,,,,,, -19400,0.037951905,0.028144516,,,,,,,,,,,,,,,,, -19421,,,0.991253674030304,0.0285461358726024,0.4504025211918963,0.9868957996368408,0.0441204234957695,0.2729475198973159,43793.0,0.98604154586792,0.046822752803564,0.2666428469243027,43793.0,6261.680698633194,10131.17840242386,6261.680698633194,3868.2067382335663,0.7608804702758789,0.0 -19500,0.050806217,0.03220125,,,,,,,,,,,,,,,,, -19600,0.044048347,0.032321155,,,,,,,,,,,,,,,,, -19700,0.040591735,0.03355936,,,,,,,,,,,,,,,,, -19800,0.061510675,0.029903904,,,,,,,,,,,,,,,,, -19900,0.04169246,0.033058245,,,,,,,,,,,,,,,,, -20000,0.043271948,0.03064201,,,,,,,,,,,,,,,,, -20100,0.04276509,0.029229991,,,,,,,,,,,,,,,,, -20164,,,0.9912835359573364,0.0284360982477664,0.4628734931838974,0.9869039058685304,0.043829821050167,0.2754548308109774,43793.0,0.9860731363296508,0.0467290207743644,0.2667028073713977,43793.0,6501.814088821411,10501.064566135406,6501.814088821411,3997.909258365631,0.7910916805267334,0.0 -20200,0.041663617,0.030681202,,,,,,,,,,,,,,,,, -20300,0.05267294,0.031390008,,,,,,,,,,,,,,,,, -20400,0.04755023,0.028186206,,,,,,,,,,,,,,,,, -20500,0.054109424,0.029588299,,,,,,,,,,,,,,,,, -20600,0.09327198,0.03722686,,,,,,,,,,,,,,,,, -20700,0.046937924,0.029031089,,,,,,,,,,,,,,,,, -20800,0.04123195,0.030859878,,,,,,,,,,,,,,,,, -20900,0.060158968,0.032906823,,,,,,,,,,,,,,,,, -20909,,,0.9910857677459716,0.0290487967431545,0.4316780481261128,0.9869911670684814,0.0437979027628898,0.2797274694723569,43793.0,0.986042857170105,0.046838354319334,0.2644472106800913,43793.0,6741.928815841675,10870.3255007267,6741.928815841675,4127.003994941711,0.8224525451660156,0.0 -21000,0.035194922,0.029218078,,,,,,,,,,,,,,,,, -21100,0.049579322,0.029235108,,,,,,,,,,,,,,,,, -21200,0.04565212,0.032757793,,,,,,,,,,,,,,,,, -21300,0.0517585,0.030544154,,,,,,,,,,,,,,,,, -21400,0.045764018,0.030701792,,,,,,,,,,,,,,,,, -21500,0.044506934,0.02877115,,,,,,,,,,,,,,,,, -21600,0.046032045,0.03186293,,,,,,,,,,,,,,,,, -21660,,,0.9911643266677856,0.028825219720602,0.4205208436652018,0.9870212078094482,0.0438480079174041,0.2758581640517665,43793.0,0.986139714717865,0.0466521978378295,0.2727992682591666,43793.0,6981.9416534900665,11249.46938920021,6981.9416534900665,4266.083292245865,0.8539795875549316,0.0 -21700,0.05366627,0.029842429,,,,,,,,,,,,,,,,, -21800,0.05594112,0.031117413,,,,,,,,,,,,,,,,, -21900,0.047559377,0.030072402,,,,,,,,,,,,,,,,, -22000,0.049005058,0.028472167,,,,,,,,,,,,,,,,, -22100,0.04571907,0.029381914,,,,,,,,,,,,,,,,, -22200,0.04762279,0.030790277,,,,,,,,,,,,,,,,, -22300,0.046329178,0.028992208,,,,,,,,,,,,,,,,, -22400,,,0.9911457896232604,0.029067164286971,0.4441765352018306,0.9868693947792052,0.0439830757677555,0.2746977620566058,43793.0,0.9860116839408876,0.0466559007763862,0.2678142261886165,43793.0,7222.009767055511,11620.568714141846,7222.009767055511,4397.061862707138,0.8849740028381348,0.0 -22400,0.054605797,0.03387639,,,,,,,,,,,,,,,,, -22500,0.04569575,0.032049898,,,,,,,,,,,,,,,,, -22600,0.046044797,0.030652186,,,,,,,,,,,,,,,,, -22700,0.04452632,0.027679987,,,,,,,,,,,,,,,,, -22800,0.04731273,0.032458365,,,,,,,,,,,,,,,,, -22900,0.043598503,0.027751502,,,,,,,,,,,,,,,,, -23000,0.043208167,0.029942732,,,,,,,,,,,,,,,,, -23100,0.05017168,0.02996815,,,,,,,,,,,,,,,,, -23147,,,0.9912724494934082,0.0284826979041099,0.4513491212236842,0.987023651599884,0.0437239073216915,0.2754543168780456,43793.0,0.9860929846763612,0.0467063896358013,0.2686252559516493,43793.0,7462.221218585968,11994.254369735718,7462.221218585968,4530.485562801361,0.9154069423675536,0.0 -23200,0.04903403,0.026875542,,,,,,,,,,,,,,,,, -23300,0.038237154,0.030307619,,,,,,,,,,,,,,,,, -23400,0.05118476,0.030255476,,,,,,,,,,,,,,,,, -23500,0.047354773,0.030913675,,,,,,,,,,,,,,,,, -23600,0.052345812,0.031781543,,,,,,,,,,,,,,,,, -23700,0.047921292,0.029703682,,,,,,,,,,,,,,,,, -23800,0.047498107,0.03284348,,,,,,,,,,,,,,,,, -23892,,,0.9913005828857422,0.0282000079751014,0.4627503649671614,0.9869924187660216,0.0442375540733337,0.2778445772522299,43793.0,0.9861068725585938,0.0471601597964763,0.2613879065969777,43793.0,7702.266625642776,12366.427788496016,7702.266625642776,4662.562074661255,0.9465370178222656,0.0 -23900,0.069757804,0.028244702,,,,,,,,,,,,,,,,, -24000,0.0533246,0.028122416,,,,,,,,,,,,,,,,, -24100,0.05317066,0.028903415,,,,,,,,,,,,,,,,, -24200,0.04951519,0.03002383,,,,,,,,,,,,,,,,, -24300,0.047279634,0.027739095,,,,,,,,,,,,,,,,, -24400,0.05000483,0.030490551,,,,,,,,,,,,,,,,, -24500,0.047197316,0.028736202,,,,,,,,,,,,,,,,, -24600,0.044602785,0.034464896,,,,,,,,,,,,,,,,, -24632,,,0.991439700126648,0.0278518870472908,0.4635890547062042,0.9869059324264526,0.043697815388441,0.2797329279295364,43793.0,0.9860866665840148,0.046415239572525,0.2697986160231079,43793.0,7942.415654420853,12735.605735778809,7942.415654420853,4791.539767742157,0.978119134902954,0.0 -24700,0.055775438,0.03335333,,,,,,,,,,,,,,,,, -24800,0.05207021,0.032296717,,,,,,,,,,,,,,,,, -24900,0.053960264,0.03393924,,,,,,,,,,,,,,,,, -25000,0.054156378,0.030326804,,,,,,,,,,,,,,,,, -25100,0.079191685,0.031008951,,,,,,,,,,,,,,,,, -25200,0.045773577,0.030135576,,,,,,,,,,,,,,,,, -25300,0.059902646,0.03084736,,,,,,,,,,,,,,,,, -25374,,,0.9915762543678284,0.0273897033184766,0.4749107520915142,0.9869404435157776,0.0440693721175193,0.2789975196120251,43793.0,0.9860929846763612,0.0466537848114967,0.2697822094920301,43793.0,8182.555495977402,13104.798302650452,8182.555495977402,4920.541341781616,1.0089633464813232,0.0 -25400,0.051552612,0.029031143,,,,,,,,,,,,,,,,, -25500,0.047239304,0.029820167,,,,,,,,,,,,,,,,, -25600,0.046704747,0.028456783,,,,,,,,,,,,,,,,, -25700,0.05533983,0.028028851,,,,,,,,,,,,,,,,, -25800,0.056611136,0.03267659,,,,,,,,,,,,,,,,, -25900,0.047657017,0.031858496,,,,,,,,,,,,,,,,, -26000,0.05058276,0.028003175,,,,,,,,,,,,,,,,, -26100,0.05712215,0.033265322,,,,,,,,,,,,,,,,, -26125,,,0.9917351603507996,0.0267021115869283,0.4959739328984074,0.986988365650177,0.0440122708678245,0.2815963343967564,43793.0,0.9861578345298768,0.0468244887888431,0.275521530104557,43793.0,8422.817781925201,13472.200419664385,8422.817781925201,5047.630271434784,1.0395777225494385,0.0 -26200,0.05267439,0.028528986,,,,,,,,,,,,,,,,, -26300,0.04866396,0.031414237,,,,,,,,,,,,,,,,, -26400,0.04979815,0.030099321,,,,,,,,,,,,,,,,, -26500,0.07714839,0.035105456,,,,,,,,,,,,,,,,, -26600,0.05412589,0.02883169,,,,,,,,,,,,,,,,, -26700,0.06968984,0.029165018,,,,,,,,,,,,,,,,, -26800,0.05031562,0.028631331,,,,,,,,,,,,,,,,, -26870,,,0.9916290044784546,0.0272741466760635,0.4856420312623262,0.9868893027305604,0.043773666024208,0.2773456521371067,43793.0,0.9860352277755736,0.0468922369182109,0.2654404572829221,43793.0,8662.943137168884,13846.104840278624,8662.943137168884,5181.35660982132,1.0719294548034668,0.0 -26900,0.05628506,0.027504263,,,,,,,,,,,,,,,,, -27000,0.04941592,0.027956951,,,,,,,,,,,,,,,,, -27100,0.051853877,0.027244983,,,,,,,,,,,,,,,,, -27200,0.058689307,0.031055124,,,,,,,,,,,,,,,,, -27300,0.05254655,0.030377932,,,,,,,,,,,,,,,,, -27400,0.059979696,0.03444662,,,,,,,,,,,,,,,,, -27500,0.06269112,0.030048594,,,,,,,,,,,,,,,,, -27600,0.046253987,0.02626603,,,,,,,,,,,,,,,,, -27618,,,0.9916507601737976,0.0270869210362434,0.4838185757181377,0.9870354533195496,0.0440697260200977,0.2777947569709336,43793.0,0.9862993359565736,0.046916589140892,0.2668597689788753,43793.0,8903.035054683685,14217.63026714325,8903.035054683685,5312.733390331268,1.108165979385376,0.0 -27700,0.055670638,0.029338237,,,,,,,,,,,,,,,,, -27800,0.09418677,0.032405835,,,,,,,,,,,,,,,,, -27900,0.0475609,0.02826774,,,,,,,,,,,,,,,,, -28000,0.06313946,0.031007078,,,,,,,,,,,,,,,,, -28100,0.05684075,0.029615166,,,,,,,,,,,,,,,,, -28200,0.056590665,0.028215377,,,,,,,,,,,,,,,,, -28300,0.05531239,0.02894216,,,,,,,,,,,,,,,,, -28366,,,0.99137681722641,0.0279942359775304,0.46102037043348,0.9870212078094482,0.0440024808049202,0.2821440113552839,43793.0,0.9861291646957396,0.0469421707093715,0.2653502959951228,43793.0,9143.177698135376,14584.213775396349,9143.177698135376,5439.122165679932,1.1402628421783447,0.0 -28400,0.06111458,0.034270305,,,,,,,,,,,,,,,,, -28500,0.052806877,0.032727506,,,,,,,,,,,,,,,,, -28600,0.07358466,0.029025655,,,,,,,,,,,,,,,,, -28700,0.047955576,0.028569495,,,,,,,,,,,,,,,,, -28800,0.05612074,0.028986795,,,,,,,,,,,,,,,,, -28900,0.052071355,0.030539855,,,,,,,,,,,,,,,,, -29000,0.052019693,0.027952919,,,,,,,,,,,,,,,,, -29100,0.049740385,0.03196732,,,,,,,,,,,,,,,,, -29113,,,0.9914528131484984,0.0275321006774902,0.4668566314086093,0.9869303107261658,0.0438966490328311,0.2802756131313504,43793.0,0.9860820174217224,0.0466917492449283,0.2662801861727709,43793.0,9383.163092136385,14954.858783721924,9383.163092136385,5569.727295160294,1.1746866703033447,0.0 -29200,0.057682026,0.03175514,,,,,,,,,,,,,,,,, -29300,0.05682658,0.030491574,,,,,,,,,,,,,,,,, -29400,0.050819326,0.029821279,,,,,,,,,,,,,,,,, -29500,0.058707677,0.02922763,,,,,,,,,,,,,,,,, -29600,0.048275735,0.023345247,,,,,,,,,,,,,,,,, -29700,0.049382288,0.029471215,,,,,,,,,,,,,,,,, -29800,0.05069393,0.02812141,,,,,,,,,,,,,,,,, -29860,,,0.9916426539421082,0.0269223246723413,0.4755547289723321,0.9869992733001708,0.0440030097961425,0.2888798783758334,43793.0,0.9861574172973632,0.0468879751861095,0.2757145467342193,43793.0,9623.224762439728,15319.655102968216,9623.224762439728,5694.410034656525,1.206719160079956,0.0 -29900,0.049963344,0.025494816,,,,,,,,,,,,,,,,, -30000,0.060175423,0.03206782,,,,,,,,,,,,,,,,, -30100,0.053237744,0.026744923,,,,,,,,,,,,,,,,, -30200,0.060931012,0.03112358,,,,,,,,,,,,,,,,, -30300,0.077839546,0.0316995,,,,,,,,,,,,,,,,, -30400,0.051900193,0.025355881,,,,,,,,,,,,,,,,, -30500,0.051938094,0.027225202,,,,,,,,,,,,,,,,, -30600,0.058524087,0.027921772,,,,,,,,,,,,,,,,, -30604,,,0.9915139079093932,0.0273679699748754,0.4902585092633399,0.987023651599884,0.0440075658261776,0.2788562240367521,43793.0,0.9861317276954652,0.0468833744525909,0.2673260861461388,43793.0,9863.198410272598,15688.741545438766,9863.198410272598,5823.471135616303,1.2388403415679932,0.0 -30700,0.05308541,0.0277981,,,,,,,,,,,,,,,,, -30800,0.06388877,0.029726781,,,,,,,,,,,,,,,,, -30900,0.061501388,0.031136937,,,,,,,,,,,,,,,,, -31000,0.062356953,0.025748404,,,,,,,,,,,,,,,,, -31100,0.047986094,0.029161489,,,,,,,,,,,,,,,,, -31200,0.061370097,0.02968523,,,,,,,,,,,,,,,,, -31300,0.05891055,0.029486032,,,,,,,,,,,,,,,,, -31352,,,0.9917362928390504,0.026557233184576,0.5076482448586728,0.987015962600708,0.0445297770202159,0.2808674786219359,43793.0,0.9862100481987,0.0471400320529937,0.2710934963817139,43793.0,10103.280684947968,16054.598301887512,10103.280684947968,5949.193348169327,1.2714645862579346,0.0 -31400,0.065860964,0.030721065,,,,,,,,,,,,,,,,, -31500,0.05455309,0.026686067,,,,,,,,,,,,,,,,, -31600,0.051664405,0.030536901,,,,,,,,,,,,,,,,, -31700,0.07251937,0.030205896,,,,,,,,,,,,,,,,, -31800,0.059162304,0.030598938,,,,,,,,,,,,,,,,, -31900,0.06371659,0.025986249,,,,,,,,,,,,,,,,, -32000,0.066165455,0.03416238,,,,,,,,,,,,,,,,, -32100,,,0.9919912815093994,0.0256998073309659,0.516720698352363,0.9870447516441344,0.0440525151789188,0.2836175444213666,43793.0,0.9861077070236206,0.0469487346708774,0.2723113894305002,43793.0,10343.537009716034,16421.74523949623,10343.537009716034,6076.021708488464,1.3138582706451416,0.0 -32100,0.06874028,0.029380945,,,,,,,,,,,,,,,,, -32200,0.06650916,0.029091796,,,,,,,,,,,,,,,,, -32300,0.06683005,0.027537428,,,,,,,,,,,,,,,,, -32400,0.0528851,0.029255832,,,,,,,,,,,,,,,,, -32500,0.055980295,0.029793318,,,,,,,,,,,,,,,,, -32600,0.07686371,0.029972369,,,,,,,,,,,,,,,,, -32700,0.05665341,0.027759736,,,,,,,,,,,,,,,,, -32800,0.06757954,0.030744117,,,,,,,,,,,,,,,,, -32846,,,0.991959512233734,0.0257570836693048,0.5138833451896644,0.9868990182876588,0.0438992008566856,0.2828000828133626,43793.0,0.985975444316864,0.0468091182410717,0.2697275506495091,43793.0,10583.608781576157,16789.060408830643,10583.608781576157,6203.211605548859,1.3478012084960938,0.0 -32900,0.06060958,0.025451353,,,,,,,,,,,,,,,,, -33000,0.07435762,0.032413885,,,,,,,,,,,,,,,,, -33100,0.06673037,0.029371753,,,,,,,,,,,,,,,,, -33200,0.053547535,0.025647897,,,,,,,,,,,,,,,,, -33300,0.062216032,0.029115735,,,,,,,,,,,,,,,,, -33400,0.06147701,0.026518902,,,,,,,,,,,,,,,,, -33500,0.061792962,0.026424099,,,,,,,,,,,,,,,,, -33587,,,0.9921208620071412,0.0254428181797266,0.5243534403747736,0.9870321750640868,0.043906345963478,0.2832817246832959,43793.0,0.9861405491828918,0.0469157211482524,0.2739759901800147,43793.0,10823.852623224258,17156.531865119934,10823.852623224258,6330.386728048325,1.3803482055664062,0.0 -33600,0.06867617,0.028735928,,,,,,,,,,,,,,,,, -33700,0.05898191,0.030985799,,,,,,,,,,,,,,,,, -33800,0.054745067,0.027796576,,,,,,,,,,,,,,,,, -33900,0.066671684,0.032330416,,,,,,,,,,,,,,,,, -34000,0.064625524,0.028445413,,,,,,,,,,,,,,,,, -34100,0.06465747,0.03036586,,,,,,,,,,,,,,,,, -34200,0.060848027,0.028401162,,,,,,,,,,,,,,,,, -34300,0.06836264,0.030786064,,,,,,,,,,,,,,,,, -34331,,,0.9919935464859008,0.0258319564163684,0.5051137293483496,0.9869655966758728,0.0441421791911125,0.2822380868318115,43793.0,0.9860954880714417,0.0470939092338085,0.2688336477584092,43793.0,11063.997322559357,17527.88009405136,11063.997322559357,6461.535669565201,1.4151837825775146,0.0 -34400,0.057842784,0.027749784,,,,,,,,,,,,,,,,, -34500,0.061373934,0.027796289,,,,,,,,,,,,,,,,, -34600,0.06786202,0.030719506,,,,,,,,,,,,,,,,, -34700,0.066461846,0.028614314,,,,,,,,,,,,,,,,, -34800,0.06364783,0.028192712,,,,,,,,,,,,,,,,, -34900,0.046874702,0.024676898,,,,,,,,,,,,,,,,, -35000,0.06066823,0.027983142,,,,,,,,,,,,,,,,, -35077,,,0.9918292164802552,0.0262144785374403,0.5070033636209796,0.9870935082435608,0.0444223061203956,0.2835313514000049,43793.0,0.9861544370651244,0.0474284365773201,0.2655076803593601,43793.0,11304.021178007126,17893.06462287903,11304.021178007126,6586.641371488571,1.4497976303100586,0.0 -35100,0.05807991,0.02762211,,,,,,,,,,,,,,,,, -35200,0.05370012,0.02533677,,,,,,,,,,,,,,,,, -35300,0.06769648,0.029557765,,,,,,,,,,,,,,,,, -35400,0.06115276,0.03150616,,,,,,,,,,,,,,,,, -35500,0.063386746,0.024980098,,,,,,,,,,,,,,,,, -35600,0.06875066,0.024036558,,,,,,,,,,,,,,,,, -35700,0.06298054,0.03035415,,,,,,,,,,,,,,,,, -35800,0.0666112,0.028607162,,,,,,,,,,,,,,,,, -35819,,,0.9920508861541748,0.0255311019718647,0.5223396912909026,0.9869850873947144,0.0442952550947666,0.2810178823058246,43793.0,0.986088752746582,0.0473106503486633,0.266058645154373,43793.0,11544.161909103394,18260.17632985115,11544.161909103394,6713.556245326996,1.4856345653533936,0.0 -35900,0.06920605,0.03133941,,,,,,,,,,,,,,,,, -36000,0.07770276,0.030419037,,,,,,,,,,,,,,,,, -36100,0.06794023,0.028775947,,,,,,,,,,,,,,,,, -36200,0.06715701,0.029713325,,,,,,,,,,,,,,,,, -36300,0.075102195,0.026552116,,,,,,,,,,,,,,,,, -36400,0.070522785,0.026652515,,,,,,,,,,,,,,,,, -36500,0.056693297,0.024023626,,,,,,,,,,,,,,,,, -36565,,,0.9919530153274536,0.0257911756634712,0.5123207397683507,0.9869518280029296,0.044221568852663,0.2859384338148794,43793.0,0.9860820174217224,0.0470707193017005,0.266341244878793,43793.0,11784.26586842537,18631.508261442184,11784.26586842537,6844.728674411774,1.5204980373382568,0.0 -36600,0.06321321,0.027681313,,,,,,,,,,,,,,,,, -36700,0.063328505,0.023683017,,,,,,,,,,,,,,,,, -36800,0.05896513,0.027767697,,,,,,,,,,,,,,,,, -36900,0.06350888,0.030008925,,,,,,,,,,,,,,,,, -37000,0.06561772,0.025555132,,,,,,,,,,,,,,,,, -37100,0.066323295,0.02612306,,,,,,,,,,,,,,,,, -37200,0.06697184,0.02791466,,,,,,,,,,,,,,,,, -37298,,,0.9920364618301392,0.0253801997750997,0.5294398886495806,0.9870569705963136,0.0445403046905994,0.2838524985635823,43793.0,0.9861974120140076,0.0475418157875537,0.2678350340250758,43793.0,12024.3671708107,19001.5023636818,12024.3671708107,6974.566897153854,1.554847240447998,0.0 -37300,0.06366145,0.026182532,,,,,,,,,,,,,,,,, -37400,0.06394638,0.032157052,,,,,,,,,,,,,,,,, -37500,0.078720935,0.027906483,,,,,,,,,,,,,,,,, -37600,0.08002505,0.02885582,,,,,,,,,,,,,,,,, -37700,0.06131587,0.025164312,,,,,,,,,,,,,,,,, -37800,0.06710809,0.028075835,,,,,,,,,,,,,,,,, -37900,0.05386283,0.023056973,,,,,,,,,,,,,,,,, -38000,0.078463376,0.026484577,,,,,,,,,,,,,,,,, -38029,,,0.9923050403594972,0.0245301537215709,0.5316224932311173,0.9869834780693054,0.044165726751089,0.2815878288103547,43793.0,0.9860537648200988,0.0468669906258583,0.2705610278003447,43793.0,12264.32112288475,19370.045646190643,12264.32112288475,7103.1003510952,1.5899608135223389,0.0 -38100,0.07817018,0.029645959,,,,,,,,,,,,,,,,, -38200,0.07695909,0.02660068,,,,,,,,,,,,,,,,, -38300,0.073003486,0.02585329,,,,,,,,,,,,,,,,, -38400,0.07398999,0.029160352,,,,,,,,,,,,,,,,, -38500,0.076345995,0.02764676,,,,,,,,,,,,,,,,, -38600,0.06485041,0.024232881,,,,,,,,,,,,,,,,, -38700,0.062242705,0.02553788,,,,,,,,,,,,,,,,, -38775,,,0.9925279021263124,0.0239597801119089,0.5613465445746542,0.9869436621665956,0.0446382723748683,0.2795682889672867,43793.0,0.9861186742782592,0.0474701225757598,0.2700197825705046,43793.0,12504.334551811218,19735.243750810623,12504.334551811218,7228.229280233383,1.6254029273986816,0.0 -38800,0.069132656,0.027454317,,,,,,,,,,,,,,,,, -38900,0.067611635,0.025389288,,,,,,,,,,,,,,,,, -39000,0.07286063,0.027707942,,,,,,,,,,,,,,,,, -39100,0.08320916,0.026127331,,,,,,,,,,,,,,,,, -39200,0.08578292,0.024245871,,,,,,,,,,,,,,,,, -39300,0.07114834,0.027087064,,,,,,,,,,,,,,,,, -39400,0.06547792,0.026951667,,,,,,,,,,,,,,,,, -39500,0.08146758,0.027845774,,,,,,,,,,,,,,,,, -39518,,,0.992580771446228,0.0237706210464239,0.5658957285509353,0.986880362033844,0.0443928875029087,0.2854602876735767,43793.0,0.986042857170105,0.0473844185471534,0.2716201668149071,43793.0,12744.566393852234,20098.38708305359,12744.566393852234,7351.084131479263,1.660273790359497,0.0 -39600,0.07502373,0.028158952,,,,,,,,,,,,,,,,, -39700,0.05649,0.024884256,,,,,,,,,,,,,,,,, -39800,0.063641265,0.025686078,,,,,,,,,,,,,,,,, -39900,0.0709741,0.02807892,,,,,,,,,,,,,,,,, -40000,0.07144256,0.02841439,,,,,,,,,,,,,,,,, -40100,0.060846727,0.026658906,,,,,,,,,,,,,,,,, -40200,0.064880036,0.024371281,,,,,,,,,,,,,,,,, -40264,,,0.9925384521484376,0.0238226205110549,0.5491199180507131,0.9870293140411376,0.0445915944874286,0.2869749485756071,43793.0,0.9861733913421632,0.0474028438329696,0.2700235207289608,43793.0,12984.634531497955,20463.572825193405,12984.634531497955,7476.147330522537,1.694267749786377,0.0 -40300,0.07224259,0.027186535,,,,,,,,,,,,,,,,, -40400,0.06599569,0.028165944,,,,,,,,,,,,,,,,, -40500,0.068825565,0.025336709,,,,,,,,,,,,,,,,, -40600,0.061231937,0.025690302,,,,,,,,,,,,,,,,, -40700,0.064902514,0.02771631,,,,,,,,,,,,,,,,, -40800,0.06666272,0.028946213,,,,,,,,,,,,,,,,, -40900,0.07491027,0.027917996,,,,,,,,,,,,,,,,, -41000,0.067019224,0.026397882,,,,,,,,,,,,,,,,, -41010,,,0.992478609085083,0.0241543073207139,0.5465455027988128,0.9870561361312866,0.0445823147892952,0.2870653774717982,43793.0,0.9862361550331116,0.0475085414946079,0.2737203631907097,43793.0,13224.755261421204,20829.687710762024,13224.755261421204,7602.085725069046,1.7297906875610352,0.0 -41100,0.07035059,0.025496677,,,,,,,,,,,,,,,,, -41200,0.06615538,0.026151849,,,,,,,,,,,,,,,,, -41300,0.073582135,0.02901931,,,,,,,,,,,,,,,,, -41400,0.07923896,0.029628009,,,,,,,,,,,,,,,,, -41500,0.079320684,0.02804347,,,,,,,,,,,,,,,,, -41600,0.07119914,0.026457347,,,,,,,,,,,,,,,,, -41700,0.07916245,0.028623657,,,,,,,,,,,,,,,,, -41754,,,0.9922721982002258,0.0246417224407196,0.5371658929798946,0.9870626330375672,0.0446793921291828,0.2881509323311192,43793.0,0.9862349033355712,0.0477738939225673,0.2742440255073568,43793.0,13464.745002746582,21196.616488933563,13464.745002746582,7728.96719622612,1.7673423290252686,0.0 -41800,0.06738397,0.026966928,,,,,,,,,,,,,,,,, -41900,0.07285489,0.027451003,,,,,,,,,,,,,,,,, -42000,0.08183742,0.027411543,,,,,,,,,,,,,,,,, -42100,0.075441286,0.028311968,,,,,,,,,,,,,,,,, -42200,0.07299705,0.027564345,,,,,,,,,,,,,,,,, -42300,0.08200138,0.029697325,,,,,,,,,,,,,,,,, -42400,0.091227874,0.02631655,,,,,,,,,,,,,,,,, -42494,,,0.9924113154411316,0.0241913646459579,0.5530880035861376,0.9869915843009948,0.0445187799632549,0.2845378522712065,43793.0,0.9861738085746764,0.0473472066223621,0.2703982284928044,43793.0,13704.911063671112,21560.149626255035,13704.911063671112,7852.276547431946,1.8038551807403564,0.0 -42500,0.07049434,0.025113115,,,,,,,,,,,,,,,,, -42600,0.08139328,0.025586054,,,,,,,,,,,,,,,,, -42700,0.07411033,0.026299035,,,,,,,,,,,,,,,,, -42800,0.06211443,0.023930114,,,,,,,,,,,,,,,,, -42900,0.0864832,0.029571306,,,,,,,,,,,,,,,,, -43000,0.07791138,0.023030315,,,,,,,,,,,,,,,,, -43100,0.06610049,0.023234483,,,,,,,,,,,,,,,,, -43200,0.08197271,0.028781693,,,,,,,,,,,,,,,,, -43234,,,0.9924986958503724,0.0238905660808086,0.5642352845260664,0.987065851688385,0.0448200851678848,0.2887220173444159,43793.0,0.9861283302307128,0.0478745140135288,0.2756288716884729,43793.0,13944.901382684708,21930.13982820511,13944.901382684708,7982.220856189728,1.8395116329193115,0.0 -43300,0.07861281,0.025056228,,,,,,,,,,,,,,,,, -43400,0.088114925,0.029068548,,,,,,,,,,,,,,,,, -43500,0.08281549,0.026988683,,,,,,,,,,,,,,,,, -43600,0.07358697,0.023969047,,,,,,,,,,,,,,,,, -43700,0.06759662,0.025627384,,,,,,,,,,,,,,,,, -43800,0.06949776,0.02496233,,,,,,,,,,,,,,,,, -43900,0.06837179,0.024507007,,,,,,,,,,,,,,,,, -43970,,,0.9926086664199828,0.0232233218848705,0.5818679028796002,0.9871900677680968,0.0448637790977954,0.2902499637499417,43793.0,0.9862905144691468,0.0479585081338882,0.274493591062196,43793.0,14184.867455482485,22292.821218252186,14184.867455482485,8104.878615617752,1.876089572906494,0.0 -44000,0.07292281,0.026861887,,,,,,,,,,,,,,,,, -44100,0.07734775,0.022435062,,,,,,,,,,,,,,,,, -44200,0.07677669,0.026636746,,,,,,,,,,,,,,,,, -44300,0.07743709,0.027392138,,,,,,,,,,,,,,,,, -44400,0.082306355,0.02570017,,,,,,,,,,,,,,,,, -44500,0.08223782,0.027059145,,,,,,,,,,,,,,,,, -44600,0.07902405,0.025296304,,,,,,,,,,,,,,,,, -44700,0.092308074,0.024356226,,,,,,,,,,,,,,,,, -44713,,,0.9928114414215088,0.0229176878929138,0.5719944676468065,0.987003743648529,0.0447695814073085,0.286275793689501,43793.0,0.9861140251159668,0.0478437356650829,0.269938147158632,43793.0,14425.079031467438,22656.048730373383,14425.079031467438,8227.839316368103,1.911909580230713,0.0 -44800,0.08459542,0.026033651,,,,,,,,,,,,,,,,, -44900,0.0682643,0.024373015,,,,,,,,,,,,,,,,, -45000,0.07247297,0.024931539,,,,,,,,,,,,,,,,, -45100,0.095450096,0.026321854,,,,,,,,,,,,,,,,, -45200,0.082765244,0.025865944,,,,,,,,,,,,,,,,, -45300,0.08371166,0.02322332,,,,,,,,,,,,,,,,, -45400,0.07842001,0.02467111,,,,,,,,,,,,,,,,, -45458,,,0.9930723309516908,0.0220610983669757,0.6057093471035695,0.9870406985282898,0.0448703877627849,0.2880019608799084,43793.0,0.9861544370651244,0.0478927008807659,0.2740243349431455,43793.0,14665.148543596268,23020.918548822403,14665.148543596268,8352.581824541092,1.949923276901245,0.0 -45500,0.07558875,0.023219611,,,,,,,,,,,,,,,,, -45600,0.07435297,0.02338655,,,,,,,,,,,,,,,,, -45700,0.10957728,0.021728583,,,,,,,,,,,,,,,,, -45800,0.07206492,0.023721438,,,,,,,,,,,,,,,,, -45900,0.097551756,0.026282534,,,,,,,,,,,,,,,,, -46000,0.074940786,0.024996337,,,,,,,,,,,,,,,,, -46100,0.08102949,0.027820703,,,,,,,,,,,,,,,,, -46195,,,0.9932297468185424,0.0215552020817995,0.611245669579988,0.9870281219482422,0.0453068502247333,0.287117571203055,43793.0,0.9862251877784728,0.048274990171194,0.275243783480796,43793.0,14905.407657384872,23382.441463947296,14905.407657384872,8473.786965847015,1.9878568649291992,0.0 -46200,0.0850639,0.026954945,,,,,,,,,,,,,,,,, -46300,0.08212864,0.029070748,,,,,,,,,,,,,,,,, -46400,0.077719174,0.023995029,,,,,,,,,,,,,,,,, -46500,0.08130593,0.026371555,,,,,,,,,,,,,,,,, -46600,0.07527494,0.027357085,,,,,,,,,,,,,,,,, -46700,0.07111631,0.024779636,,,,,,,,,,,,,,,,, -46800,0.09253953,0.025174046,,,,,,,,,,,,,,,,, -46900,0.07587799,0.022719508,,,,,,,,,,,,,,,,, -46939,,,0.9930238723754884,0.022141970694065,0.5996005868878975,0.9870882034301758,0.0453628972172737,0.2826583293956821,43793.0,0.9862580895423888,0.0486461706459522,0.2688644851086066,43793.0,15145.571384191511,23748.05126833916,15145.571384191511,8599.17787861824,2.023483991622925,0.0 -47000,0.071669124,0.024289753,,,,,,,,,,,,,,,,, -47100,0.08205002,0.026219293,,,,,,,,,,,,,,,,, -47200,0.08865095,0.026746908,,,,,,,,,,,,,,,,, -47300,0.07349734,0.026249025,,,,,,,,,,,,,,,,, -47400,0.09522366,0.02761569,,,,,,,,,,,,,,,,, -47500,0.076685704,0.02386331,,,,,,,,,,,,,,,,, -47600,0.086133786,0.023981327,,,,,,,,,,,,,,,,, -47672,,,0.992883801460266,0.0226908214390277,0.5954002678011658,0.987058162689209,0.045141864567995,0.2881418197880501,43793.0,0.9862008094787598,0.0481557138264179,0.2730225395600637,43793.0,15385.59983444214,24114.510402441025,15385.59983444214,8725.551683664322,2.058670997619629,0.0 -47700,0.08410013,0.02399295,,,,,,,,,,,,,,,,, -47800,0.11088078,0.024101255,,,,,,,,,,,,,,,,, -47900,0.07985622,0.021901798,,,,,,,,,,,,,,,,, -48000,0.101833634,0.025015837,,,,,,,,,,,,,,,,, -48100,0.084486656,0.027085675,,,,,,,,,,,,,,,,, -48200,0.094425894,0.02579208,,,,,,,,,,,,,,,,, -48300,0.0909859,0.026867101,,,,,,,,,,,,,,,,, -48400,0.112521775,0.023721937,,,,,,,,,,,,,,,,, -48417,,,0.9927613139152528,0.0228866655379533,0.564050308760248,0.987027108669281,0.0457050092518329,0.28652231033362,43793.0,0.9862513542175292,0.0486955344676971,0.2771422885173468,43793.0,15625.717020273209,24478.012163877487,15625.717020273209,8848.880156755447,2.0949044227600098,0.0 -48500,0.07629792,0.024488898,,,,,,,,,,,,,,,,, -48600,0.08203674,0.023600189,,,,,,,,,,,,,,,,, -48700,0.08471239,0.024529116,,,,,,,,,,,,,,,,, -48800,0.098488964,0.027091878,,,,,,,,,,,,,,,,, -48900,0.07127196,0.021382123,,,,,,,,,,,,,,,,, -49000,0.08782932,0.024130357,,,,,,,,,,,,,,,,, -49100,0.10530227,0.02755731,,,,,,,,,,,,,,,,, -49148,,,0.9930188655853271,0.0221155732870101,0.5984401713121497,0.987084150314331,0.0453145541250705,0.2855918960402324,43793.0,0.9862251877784728,0.0485099367797374,0.2718043034407581,43793.0,15865.82262635231,24842.079083919525,15865.82262635231,8972.782843828201,2.131263732910156,0.0 -49200,0.07618132,0.02391167,,,,,,,,,,,,,,,,, -49300,0.08280318,0.023603205,,,,,,,,,,,,,,,,, -49400,0.08855917,0.022559501,,,,,,,,,,,,,,,,, -49500,0.07486877,0.023426307,,,,,,,,,,,,,,,,, -49600,0.0799574,0.023919474,,,,,,,,,,,,,,,,, -49700,0.09715218,0.025513474,,,,,,,,,,,,,,,,, -49800,0.09428531,0.02537545,,,,,,,,,,,,,,,,, -49889,,,0.9930760860443116,0.0218061376363039,0.6098008796744399,0.9870622158050536,0.045785091817379,0.2812710950757494,43793.0,0.9862938523292542,0.0486649572849273,0.2738028980149617,43793.0,16105.879657030106,25211.7152671814,16105.879657030106,9102.306144714355,2.167288064956665,0.0 -49900,0.08295376,0.023808686,,,,,,,,,,,,,,,,, -50000,0.081407264,0.022564495,,,,,,,,,,,,,,,,, -50100,0.08149726,0.02264637,,,,,,,,,,,,,,,,, -50200,0.085288174,0.024365637,,,,,,,,,,,,,,,,, -50300,0.07830995,0.022380065,,,,,,,,,,,,,,,,, -50400,0.099819936,0.025819376,,,,,,,,,,,,,,,,, -50500,0.092097715,0.024995092,,,,,,,,,,,,,,,,, -50600,0.09719544,0.024230901,,,,,,,,,,,,,,,,, -50609,,,0.9931766986846924,0.0214318484067916,0.619487995667,0.9871689677238464,0.045978058129549,0.28589317875376,43793.0,0.9862997531890868,0.0491118878126144,0.2682895887359769,43793.0,16345.824395656586,25574.36565876007,16345.824395656586,9224.947851657867,2.207081794738769,0.0 -50700,0.083908856,0.022732154,,,,,,,,,,,,,,,,, -50800,0.09188956,0.027424285,,,,,,,,,,,,,,,,, -50900,0.092713386,0.024244677,,,,,,,,,,,,,,,,, -51000,0.07693863,0.02139965,,,,,,,,,,,,,,,,, -51100,0.10177314,0.026528997,,,,,,,,,,,,,,,,, -51200,0.09711543,0.024208033,,,,,,,,,,,,,,,,, -51300,0.086534835,0.023415316,,,,,,,,,,,,,,,,, -51353,,,0.9935702085494996,0.0202808193862438,0.6379793391083107,0.9870991706848145,0.0458444878458976,0.2865038906963764,43793.0,0.9862328171730042,0.0489258803427219,0.2725946008014016,43793.0,16585.909583091736,25936.671803712845,16585.909583091736,9347.113098144531,2.242493867874145,0.0 -51400,0.08672523,0.023996627,,,,,,,,,,,,,,,,, -51500,0.08810893,0.020928519,,,,,,,,,,,,,,,,, -51600,0.085708566,0.023832513,,,,,,,,,,,,,,,,, -51700,0.097293966,0.025037926,,,,,,,,,,,,,,,,, -51800,0.092697285,0.023711337,,,,,,,,,,,,,,,,, -51900,0.08957911,0.023215998,,,,,,,,,,,,,,,,, -52000,0.08374943,0.021001292,,,,,,,,,,,,,,,,, -52098,,,0.9937245845794678,0.0197700150310993,0.6558632894885597,0.987107276916504,0.0459720119833946,0.2865892628409184,43793.0,0.9862193465232848,0.0490687452256679,0.2715121900772505,43793.0,16825.856951475143,26300.3449792862,16825.856951475143,9470.78262424469,2.279003143310547,0.0 -52100,0.09763848,0.022735856,,,,,,,,,,,,,,,,, -52200,0.10535734,0.026988883,,,,,,,,,,,,,,,,, -52300,0.08439214,0.022559965,,,,,,,,,,,,,,,,, -52400,0.09112004,0.022859681,,,,,,,,,,,,,,,,, -52500,0.108638145,0.02307461,,,,,,,,,,,,,,,,, -52600,0.09211181,0.024482436,,,,,,,,,,,,,,,,, -52700,0.10719623,0.024624472,,,,,,,,,,,,,,,,, -52800,0.09919486,0.02341131,,,,,,,,,,,,,,,,, -52846,,,0.9934922456741332,0.0203552469611167,0.6389656490329827,0.987166166305542,0.0464010424911975,0.2904116007109317,43793.0,0.9862968325614928,0.0496241226792335,0.273131802214466,43793.0,17066.066660165787,26666.281783103943,17066.066660165787,9596.453342676165,2.315351963043213,0.0 -52900,0.09995056,0.025991121,,,,,,,,,,,,,,,,, -53000,0.09205584,0.023364114,,,,,,,,,,,,,,,,, -53100,0.08410144,0.023276702,,,,,,,,,,,,,,,,, -53200,0.10439929,0.027538113,,,,,,,,,,,,,,,,, -53300,0.08578727,0.021663489,,,,,,,,,,,,,,,,, -53400,0.09211927,0.021842314,,,,,,,,,,,,,,,,, -53500,0.10251301,0.02340223,,,,,,,,,,,,,,,,, -53594,,,0.9935109615325928,0.0204487219452857,0.635550748244853,0.9870272874832152,0.0459951683878898,0.2882264457127399,43793.0,0.9861013889312744,0.0491682216525077,0.2707588206225723,43793.0,17306.208050727844,27028.29993700981,17306.208050727844,9718.271196126938,2.3540172576904297,0.0 -53600,0.10142356,0.022104267,,,,,,,,,,,,,,,,, -53700,0.11279858,0.02579198,,,,,,,,,,,,,,,,, -53800,0.103145376,0.023724137,,,,,,,,,,,,,,,,, -53900,0.10687806,0.024033017,,,,,,,,,,,,,,,,, -54000,0.10815765,0.022830151,,,,,,,,,,,,,,,,, -54100,0.09565223,0.02324772,,,,,,,,,,,,,,,,, -54200,0.121302925,0.025048466,,,,,,,,,,,,,,,,, -54300,0.09578578,0.022204336,,,,,,,,,,,,,,,,, -54339,,,0.9935351014137268,0.0203988421708345,0.6431458559111352,0.9870265126228333,0.0463195703923702,0.2899425669337316,43793.0,0.9861831068992616,0.049438040703535,0.2739239046998224,43793.0,17546.322751760483,27391.182716608047,17546.322751760483,9840.98232102394,2.391303777694702,0.0 -54400,0.08409344,0.01836743,,,,,,,,,,,,,,,,, -54500,0.09236652,0.020791577,,,,,,,,,,,,,,,,, -54600,0.119049914,0.024551345,,,,,,,,,,,,,,,,, -54700,0.09364949,0.021514717,,,,,,,,,,,,,,,,, -54800,0.10169178,0.024334187,,,,,,,,,,,,,,,,, -54900,0.095463,0.022452313,,,,,,,,,,,,,,,,, -55000,0.099303536,0.021275608,,,,,,,,,,,,,,,,, -55079,,,0.9935994744300842,0.0202458370476961,0.6338022190807497,0.9870439767837524,0.0462508201599121,0.2895553029434483,43793.0,0.9861974120140076,0.0495461821556091,0.2750850588168503,43793.0,17786.44678425789,27753.804602861404,17786.44678425789,9963.41772222519,2.4342310428619385,0.0 -55100,0.09678691,0.020425502,,,,,,,,,,,,,,,,, -55200,0.103520036,0.025013218,,,,,,,,,,,,,,,,, -55300,0.095481776,0.021313203,,,,,,,,,,,,,,,,, -55400,0.11373997,0.023100374,,,,,,,,,,,,,,,,, -55500,0.11111234,0.024035342,,,,,,,,,,,,,,,,, -55600,0.10598068,0.021685982,,,,,,,,,,,,,,,,, -55700,0.10973117,0.021719929,,,,,,,,,,,,,,,,, -55800,0.09111812,0.020667845,,,,,,,,,,,,,,,,, -55829,,,0.9936383962631226,0.0199243109673261,0.6444634510671647,0.9870975613594056,0.0466697067022323,0.289663565541211,43793.0,0.9862349033355712,0.0499169677495956,0.2747479620975098,43793.0,18026.465641736984,28116.08030819893,18026.465641736984,10085.61788034439,2.470750331878662,0.0 -55900,0.12521061,0.024416627,,,,,,,,,,,,,,,,, -56000,0.1120859,0.022255892,,,,,,,,,,,,,,,,, -56100,0.10228838,0.020770341,,,,,,,,,,,,,,,,, -56200,0.11609132,0.024600765,,,,,,,,,,,,,,,,, -56300,0.10294663,0.02360297,,,,,,,,,,,,,,,,, -56400,0.098140724,0.023435134,,,,,,,,,,,,,,,,, -56500,0.095478885,0.021079483,,,,,,,,,,,,,,,,, -56574,,,0.9936703443527222,0.0196411423385143,0.6531613462400097,0.9871182441711426,0.0468352176249027,0.2888919640760225,43793.0,0.986233651638031,0.0503293424844741,0.2715353789860384,43793.0,18266.70391392708,28475.22965478897,18266.70391392708,10204.472261428831,2.5078279972076416,0.0 -56600,0.10426061,0.022786513,,,,,,,,,,,,,,,,, -56700,0.10591562,0.019037385,,,,,,,,,,,,,,,,, -56800,0.09864757,0.020581286,,,,,,,,,,,,,,,,, -56900,0.11319164,0.0247006,,,,,,,,,,,,,,,,, -57000,0.114411384,0.024700884,,,,,,,,,,,,,,,,, -57100,0.11687836,0.02224426,,,,,,,,,,,,,,,,, -57200,0.09935024,0.020377003,,,,,,,,,,,,,,,,, -57300,0.0931564,0.01956639,,,,,,,,,,,,,,,,, -57318,,,0.9939448237419128,0.0190205872058868,0.6744130159403038,0.9870589971542358,0.0467969439923763,0.2875065035808982,43793.0,0.986137628555298,0.050077386200428,0.2762526456314265,43793.0,18506.681567668915,28838.496221780777,18506.681567668915,10327.702929973602,2.546010971069336,0.0 -57400,0.102756515,0.020408878,,,,,,,,,,,,,,,,, -57500,0.103941366,0.021345679,,,,,,,,,,,,,,,,, -57600,0.10512424,0.021407025,,,,,,,,,,,,,,,,, -57700,0.08979923,0.020003578,,,,,,,,,,,,,,,,, -57800,0.10078456,0.0220847,,,,,,,,,,,,,,,,, -57900,0.10357853,0.021759022,,,,,,,,,,,,,,,,, -58000,0.10085164,0.020120291,,,,,,,,,,,,,,,,, -58059,,,0.9942131638526917,0.0180310141295194,0.6876187034913789,0.9870078563690186,0.0471584536135196,0.2862036634371643,43793.0,0.9861649870872498,0.0505991503596305,0.2743655426824452,43793.0,18746.95504975319,29203.476695775986,18746.95504975319,10452.352368354796,2.582829475402832,0.0 -58100,0.10596733,0.022179874,,,,,,,,,,,,,,,,, -58200,0.097248174,0.019010687,,,,,,,,,,,,,,,,, -58300,0.12495612,0.024422375,,,,,,,,,,,,,,,,, -58400,0.10783815,0.024497623,,,,,,,,,,,,,,,,, -58500,0.120527826,0.023962926,,,,,,,,,,,,,,,,, -58600,0.11060624,0.020318741,,,,,,,,,,,,,,,,, -58700,0.120951444,0.023601186,,,,,,,,,,,,,,,,, -58800,0.10811012,0.024006065,,,,,,,,,,,,,,,,, -58805,,,0.9944328665733336,0.01756127551198,0.7030557313454966,0.9870532751083374,0.047436848282814,0.2879422701973064,43793.0,0.9862053990364076,0.0509169474244117,0.2733854418500244,43793.0,18986.97296071053,29564.330701112747,18986.97296071053,10573.129535675049,2.621100902557373,0.0 -58900,0.11052621,0.022975562,,,,,,,,,,,,,,,,, -59000,0.108533755,0.017885795,,,,,,,,,,,,,,,,, -59100,0.10664021,0.022169461,,,,,,,,,,,,,,,,, -59200,0.103129,0.020113599,,,,,,,,,,,,,,,,, -59300,0.11132897,0.022495192,,,,,,,,,,,,,,,,, -59400,0.112345524,0.022266377,,,,,,,,,,,,,,,,, -59500,0.10666191,0.020084739,,,,,,,,,,,,,,,,, -59552,,,0.9943233728408812,0.0178030133247375,0.694210838560239,0.9870098829269408,0.0471440367400646,0.2937513754272167,43793.0,0.9861890077590942,0.0505956448614597,0.2751660225322885,43793.0,19227.02808666229,29924.63180088997,19227.02808666229,10693.31673192978,2.6599795818328857,0.0 -59600,0.10125925,0.01989703,,,,,,,,,,,,,,,,, -59700,0.110657245,0.021171654,,,,,,,,,,,,,,,,, -59800,0.11774025,0.021808214,,,,,,,,,,,,,,,,, -59900,0.11290964,0.022486076,,,,,,,,,,,,,,,,, -60000,0.10970931,0.01994119,,,,,,,,,,,,,,,,, -60100,0.13170528,0.024689522,,,,,,,,,,,,,,,,, -60200,0.12087208,0.0203538,,,,,,,,,,,,,,,,, -60296,,,0.9941422343254088,0.0181688442826271,0.6758369738154911,0.9870870113372804,0.047808714210987,0.2878571093468118,43793.0,0.986295998096466,0.051175232976675,0.2740927208224178,43793.0,19467.04409813881,30287.35133075714,19467.04409813881,10815.961034536362,2.698820114135742,0.0 -60300,0.116390705,0.023128562,,,,,,,,,,,,,,,,, -60400,0.109224,0.020900523,,,,,,,,,,,,,,,,, -60500,0.11119852,0.022408653,,,,,,,,,,,,,,,,, -60600,0.13467138,0.02499345,,,,,,,,,,,,,,,,, -60700,0.109312765,0.018148476,,,,,,,,,,,,,,,,, -60800,0.12886138,0.022881715,,,,,,,,,,,,,,,,, -60900,0.11248747,0.022677075,,,,,,,,,,,,,,,,, -61000,0.13691962,0.021700509,,,,,,,,,,,,,,,,, -61040,,,0.9941512942314148,0.0182098597288131,0.686735293228531,0.987099587917328,0.0478898957371711,0.286116871264883,43793.0,0.9861696362495422,0.0514022409915924,0.2712695979829206,43793.0,19707.040642499924,30649.05088710785,19707.040642499924,10937.600918531418,2.740966796875,0.0 -61100,0.14064172,0.021411816,,,,,,,,,,,,,,,,, -61200,0.13602275,0.018651947,,,,,,,,,,,,,,,,, -61300,0.11635673,0.022450827,,,,,,,,,,,,,,,,, -61400,0.12006754,0.020467622,,,,,,,,,,,,,,,,, -61500,0.1310047,0.021741299,,,,,,,,,,,,,,,,, -61600,0.129843,0.021244846,,,,,,,,,,,,,,,,, -61700,0.11199544,0.01873979,,,,,,,,,,,,,,,,, -61786,,,0.9941979050636292,0.0179943554103374,0.6911118193222282,0.9871572256088256,0.0478870645165443,0.290800598389781,43793.0,0.986199915409088,0.051761258393526,0.2712998483111147,43793.0,19947.08567380905,31008.40128135681,19947.08567380905,11056.847058057783,2.7800092697143555,0.0 -61800,0.122816294,0.02117665,,,,,,,,,,,,,,,,, -61900,0.11888595,0.022624837,,,,,,,,,,,,,,,,, -62000,0.11489656,0.019821491,,,,,,,,,,,,,,,,, -62100,0.13231127,0.021765858,,,,,,,,,,,,,,,,, -62200,0.1259446,0.021980084,,,,,,,,,,,,,,,,, -62300,0.13413809,0.024453014,,,,,,,,,,,,,,,,, -62400,0.1262504,0.020423945,,,,,,,,,,,,,,,,, -62500,0.1214489,0.019538915,,,,,,,,,,,,,,,,, -62531,,,0.994190752506256,0.0179026070982217,0.6973760724944296,0.9870695471763612,0.0483664646744728,0.2877178300183725,43793.0,0.9861767888069152,0.0520378574728965,0.2690783502432456,43793.0,20187.041157007217,31372.12839603424,20187.041157007217,11180.55793595314,2.8193583488464355,0.0 -62600,0.14017902,0.021357125,,,,,,,,,,,,,,,,, -62700,0.12497441,0.017023481,,,,,,,,,,,,,,,,, -62800,0.12775414,0.020754775,,,,,,,,,,,,,,,,, -62900,0.11139925,0.021091998,,,,,,,,,,,,,,,,, -63000,0.123682074,0.019431489,,,,,,,,,,,,,,,,, -63100,0.11208666,0.02080516,,,,,,,,,,,,,,,,, -63200,0.11396168,0.019460209,,,,,,,,,,,,,,,,, -63276,,,0.9944777488708496,0.017259057611227,0.7072899931696965,0.987057328224182,0.0482454113662242,0.2879580300275564,43793.0,0.9861236810684204,0.0519551709294319,0.2732008237432165,43793.0,20427.09281229973,31730.00851416588,20427.09281229973,11298.31682562828,2.868788719177246,0.0 -63300,0.12414718,0.020359218,,,,,,,,,,,,,,,,, -63400,0.13111529,0.022240905,,,,,,,,,,,,,,,,, -63500,0.1336208,0.021278648,,,,,,,,,,,,,,,,, -63600,0.11715946,0.017790068,,,,,,,,,,,,,,,,, -63700,0.11950948,0.020632382,,,,,,,,,,,,,,,,, -63800,0.14889643,0.02043825,,,,,,,,,,,,,,,,, -63900,0.12767684,0.020321123,,,,,,,,,,,,,,,,, -64000,0.12684552,0.019211946,,,,,,,,,,,,,,,,, -64020,,,0.9946504831314088,0.0165974479168653,0.7181443084231507,0.987084984779358,0.0485220737755298,0.2862159947309958,43793.0,0.986163318157196,0.0521516539156436,0.2728042778335163,43793.0,20667.08352851868,32090.70943880081,20667.08352851868,11418.96899819374,2.907371282577514,0.0 -64100,0.13635415,0.021588031,,,,,,,,,,,,,,,,, -64200,0.13507015,0.020973701,,,,,,,,,,,,,,,,, -64300,0.13100314,0.020371629,,,,,,,,,,,,,,,,, -64400,0.13511227,0.02240751,,,,,,,,,,,,,,,,, -64500,0.13613586,0.022729676,,,,,,,,,,,,,,,,, -64600,0.12214905,0.020341754,,,,,,,,,,,,,,,,, -64700,0.13060793,0.017923225,,,,,,,,,,,,,,,,, -64754,,,0.9948357939720154,0.0161562655121088,0.7323550397195834,0.98714017868042,0.0485176891088485,0.2874680887460105,43793.0,0.9861885905265808,0.0523666962981224,0.270310425672684,43793.0,20907.040336608887,32449.245572328568,20907.040336608887,11537.486523628237,2.947197914123535,0.0 -64800,0.12862301,0.017981585,,,,,,,,,,,,,,,,, -64900,0.13274197,0.020498428,,,,,,,,,,,,,,,,, -65000,0.13063584,0.023360798,,,,,,,,,,,,,,,,, -65100,0.13258116,0.019612812,,,,,,,,,,,,,,,,, -65200,0.14262916,0.021292482,,,,,,,,,,,,,,,,, -65300,0.1111699,0.01857695,,,,,,,,,,,,,,,,, -65400,0.1379221,0.021573568,,,,,,,,,,,,,,,,, -65500,0.13607647,0.02121482,,,,,,,,,,,,,,,,, -65501,,,0.9949519634246826,0.0156491510570049,0.7398820880929526,0.9871227145195008,0.0487340465188026,0.2906366602949252,43793.0,0.9861925840377808,0.0526473931968212,0.2703020664572869,43793.0,21147.292327165604,32805.996050834656,21147.292327165604,11653.92560315132,2.98614239692688,0.0 -65600,0.11667198,0.020339334,,,,,,,,,,,,,,,,, -65700,0.13195911,0.021063283,,,,,,,,,,,,,,,,, -65800,0.13465248,0.020326281,,,,,,,,,,,,,,,,, -65900,0.14982292,0.020122347,,,,,,,,,,,,,,,,, -66000,0.12957208,0.023054387,,,,,,,,,,,,,,,,, -66100,0.12899415,0.020759286,,,,,,,,,,,,,,,,, -66200,0.13846104,0.018096708,,,,,,,,,,,,,,,,, -66248,,,0.99501234292984,0.0157033149152994,0.7419741470556183,0.9870402812957764,0.0490181297063827,0.2915888285229315,43793.0,0.9861291646957396,0.0527142025530338,0.2725590229995492,43793.0,21387.325922966003,33165.67395567894,21387.325922966003,11773.509294509888,3.0265464782714844,0.0 -66300,0.1268106,0.017395943,,,,,,,,,,,,,,,,, -66400,0.15141197,0.020562695,,,,,,,,,,,,,,,,, -66500,0.12275041,0.01885338,,,,,,,,,,,,,,,,, -66600,0.13127132,0.019157182,,,,,,,,,,,,,,,,, -66700,0.14962137,0.01832652,,,,,,,,,,,,,,,,, -66800,0.12778556,0.019277299,,,,,,,,,,,,,,,,, -66900,0.14138855,0.019670313,,,,,,,,,,,,,,,,, -66996,,,0.9948969483375548,0.0158338658511638,0.7381022885433084,0.9870325922966005,0.0491348057985305,0.2917242559636582,43793.0,0.9861043095588684,0.0529823116958141,0.2693005014060972,43793.0,21627.543115854263,33525.466379880905,21627.543115854263,11893.024748325348,3.0661911964416504,0.0 -67000,0.13917027,0.018574227,,,,,,,,,,,,,,,,, -67100,0.12406645,0.017837338,,,,,,,,,,,,,,,,, -67200,0.13494706,0.018791744,,,,,,,,,,,,,,,,, -67300,0.15160085,0.021985404,,,,,,,,,,,,,,,,, -67400,0.1366314,0.020484127,,,,,,,,,,,,,,,,, -67500,0.14735788,0.019122746,,,,,,,,,,,,,,,,, -67600,0.13026573,0.019179583,,,,,,,,,,,,,,,,, -67700,0.12217398,0.019521281,,,,,,,,,,,,,,,,, -67749,,,0.994721531867981,0.0163781829178333,0.7252492006625263,0.9870800971984864,0.0493294596672058,0.2926122424504049,43793.0,0.986127495765686,0.0531786978244781,0.2718304619033597,43793.0,21867.74240756035,33883.81890010834,21867.74240756035,12011.118778467178,3.105543851852417,0.0 -67800,0.15288025,0.020852027,,,,,,,,,,,,,,,,, -67900,0.13841955,0.019890472,,,,,,,,,,,,,,,,, -68000,0.13044132,0.019414457,,,,,,,,,,,,,,,,, -68100,0.1325488,0.017745733,,,,,,,,,,,,,,,,, -68200,0.14267649,0.020922335,,,,,,,,,,,,,,,,, -68300,0.15125546,0.021744628,,,,,,,,,,,,,,,,, -68400,0.12987745,0.018791521,,,,,,,,,,,,,,,,, -68488,,,0.994666337966919,0.01639249548316,0.7239857813140791,0.9870427250862122,0.0495910719037056,0.29098764881197,43793.0,0.986100137233734,0.0535429641604423,0.273014472514247,43793.0,22107.820801496506,34251.776836156845,22107.820801496506,12138.936106681824,3.1450343132019043,0.0 -68500,0.12954393,0.016791297,,,,,,,,,,,,,,,,, -68600,0.14381869,0.018560497,,,,,,,,,,,,,,,,, -68700,0.1491381,0.021690333,,,,,,,,,,,,,,,,, -68800,0.12742831,0.018924665,,,,,,,,,,,,,,,,, -68900,0.13080864,0.01843456,,,,,,,,,,,,,,,,, -69000,0.13104738,0.020214118,,,,,,,,,,,,,,,,, -69100,0.14534192,0.021527763,,,,,,,,,,,,,,,,, -69200,0.12356303,0.019358,,,,,,,,,,,,,,,,, -69225,,,0.994940996170044,0.0157144218683242,0.7364670876265623,0.986950159072876,0.0492487885057926,0.2891234404148812,43793.0,0.9860045313835144,0.0530650578439235,0.2730619442161265,43793.0,22347.77027964592,34611.156307697296,22347.77027964592,12258.297290086746,3.189326524734497,0.0 -69300,0.152144,0.018386131,,,,,,,,,,,,,,,,, -69400,0.14366387,0.021343747,,,,,,,,,,,,,,,,, -69500,0.12533084,0.018735703,,,,,,,,,,,,,,,,, -69600,0.14351383,0.02041522,,,,,,,,,,,,,,,,, -69700,0.12117692,0.018888898,,,,,,,,,,,,,,,,, -69800,0.1356779,0.018491235,,,,,,,,,,,,,,,,, -69900,0.13741113,0.019081812,,,,,,,,,,,,,,,,, -69963,,,0.9950496554374696,0.0153866959735751,0.7498461206729499,0.9870471954345704,0.0496566444635391,0.2911408890764547,43793.0,0.986139714717865,0.0534122213721275,0.2720989782136767,43793.0,22587.750163316727,34965.306415081024,22587.750163316727,12372.406210422516,3.2301290035247803,0.0 -70000,0.14530224,0.019872393,,,,,,,,,,,,,,,,, -70100,0.14747046,0.019115822,,,,,,,,,,,,,,,,, -70200,0.13306369,0.017217463,,,,,,,,,,,,,,,,, -70300,0.13368137,0.02035009,,,,,,,,,,,,,,,,, -70400,0.12699793,0.018928105,,,,,,,,,,,,,,,,, -70500,0.13732868,0.018781105,,,,,,,,,,,,,,,,, -70600,0.14754032,0.020935751,,,,,,,,,,,,,,,,, -70700,0.13262942,0.016924001,,,,,,,,,,,,,,,,, -70711,,,0.9950863122940063,0.0152974929660558,0.7440555332907148,0.9869948625564576,0.0498314648866653,0.290006136839753,43793.0,0.98613041639328,0.0536233186721801,0.2714608792304772,43793.0,22827.77950668335,35326.90785813332,22827.77950668335,12493.917872667313,3.2706449031829834,0.0 -70800,0.14540258,0.020560434,,,,,,,,,,,,,,,,, -70900,0.13525352,0.018065644,,,,,,,,,,,,,,,,, -71000,0.13695896,0.018406898,,,,,,,,,,,,,,,,, -71100,0.13342966,0.016286276,,,,,,,,,,,,,,,,, -71200,0.14279279,0.01888695,,,,,,,,,,,,,,,,, -71300,0.13694121,0.019657983,,,,,,,,,,,,,,,,, -71400,0.12624599,0.017098129,,,,,,,,,,,,,,,,, -71455,,,0.9954333305358888,0.0143726943060755,0.7739653157648199,0.9869473576545716,0.0498649254441261,0.2913538274236733,43793.0,0.9860959053039552,0.0535416230559349,0.2728977546716898,43793.0,23067.749517202377,35682.74603819847,23067.749517202377,12609.725140094755,3.3116016387939453,0.0 -71500,0.1363999,0.01633637,,,,,,,,,,,,,,,,, -71600,0.13560134,0.018311758,,,,,,,,,,,,,,,,, -71700,0.14280896,0.020970874,,,,,,,,,,,,,,,,, -71800,0.14512001,0.018384915,,,,,,,,,,,,,,,,, -71900,0.15791588,0.019094536,,,,,,,,,,,,,,,,, -72000,0.14952444,0.018848408,,,,,,,,,,,,,,,,, -72100,0.1592024,0.018368984,,,,,,,,,,,,,,,,, -72200,0.13633433,0.019172046,,,,,,,,,,,,,,,,, -72207,,,0.9954251646995544,0.0143637405708432,0.768106455239543,0.9870321750640868,0.0499582774937152,0.2905641657020273,43793.0,0.9861700534820556,0.0537815801799297,0.2736166316268893,43793.0,23307.969571113583,36039.77791452408,23307.969571113583,12726.47549533844,3.352621078491211,0.0 -72300,0.13665381,0.01846304,,,,,,,,,,,,,,,,, -72400,0.13378617,0.017352644,,,,,,,,,,,,,,,,, -72500,0.13120714,0.01745682,,,,,,,,,,,,,,,,, -72600,0.14087789,0.01908941,,,,,,,,,,,,,,,,, -72700,0.15369056,0.020221494,,,,,,,,,,,,,,,,, -72800,0.13732637,0.01811364,,,,,,,,,,,,,,,,, -72900,0.1574997,0.020953795,,,,,,,,,,,,,,,,, -72960,,,0.9953681230545044,0.0144671695306897,0.7619073822725733,0.9870269298553468,0.0500656254589557,0.2910683527730367,43793.0,0.9861670732498168,0.053872849792242,0.2733393715619848,43793.0,23548.14874601364,36397.32256865501,23548.14874601364,12843.777476787567,3.3956403732299805,0.0 -73000,0.14768894,0.016911766,,,,,,,,,,,,,,,,, -73100,0.14315036,0.018873982,,,,,,,,,,,,,,,,, -73200,0.16591875,0.019527761,,,,,,,,,,,,,,,,, -73300,0.14032799,0.019486086,,,,,,,,,,,,,,,,, -73400,0.14082111,0.017752498,,,,,,,,,,,,,,,,, -73500,0.13321875,0.0175315,,,,,,,,,,,,,,,,, -73600,0.14018032,0.017747615,,,,,,,,,,,,,,,,, -73700,0.12689777,0.01628037,,,,,,,,,,,,,,,,, -73704,,,0.9953656792640686,0.0145191131159663,0.7680448154889454,0.9869895577430724,0.0500662177801132,0.2890999776953153,43793.0,0.986119508743286,0.0539797991514205,0.2723943428460345,43793.0,23788.296051740646,36753.69769072533,23788.296051740646,12959.944328784944,3.435800552368164,0.0 -73800,0.12777895,0.017561123,,,,,,,,,,,,,,,,, -73900,0.1422443,0.01740172,,,,,,,,,,,,,,,,, -74000,0.1576788,0.019242229,,,,,,,,,,,,,,,,, -74100,0.13478284,0.01871192,,,,,,,,,,,,,,,,, -74200,0.13830596,0.017530976,,,,,,,,,,,,,,,,, -74300,0.13169943,0.017503353,,,,,,,,,,,,,,,,, -74400,0.14322387,0.016321829,,,,,,,,,,,,,,,,, -74442,,,0.99525386095047,0.0147152254357934,0.7530949013829958,0.987002968788147,0.0501992814242839,0.2914377994324131,43793.0,0.9861443638801576,0.0540723912417888,0.2734716980950172,43793.0,24028.29723644257,37106.92096066475,24028.29723644257,13073.09792804718,3.483970880508423,0.0 -74500,0.14070095,0.018308701,,,,,,,,,,,,,,,,, -74600,0.14131853,0.018456178,,,,,,,,,,,,,,,,, -74700,0.13720131,0.017402021,,,,,,,,,,,,,,,,, -74800,0.15692186,0.0187614,,,,,,,,,,,,,,,,, -74900,0.13655192,0.019285018,,,,,,,,,,,,,,,,, -75000,0.12904571,0.018000284,,,,,,,,,,,,,,,,, -75100,0.1458842,0.019190948,,,,,,,,,,,,,,,,, -75192,,,0.9953786730766296,0.0144358789548277,0.771100602272772,0.9869931936264038,0.0501015745103359,0.2919096212544834,43793.0,0.9861211776733398,0.0539165139198303,0.2740543698815366,43793.0,24268.272846221924,37465.62667059898,24268.272846221924,13191.766512155533,3.5248465538024902,0.0 -75200,0.14930074,0.018959168,,,,,,,,,,,,,,,,, -75300,0.14083798,0.017069096,,,,,,,,,,,,,,,,, -75400,0.14285225,0.018123716,,,,,,,,,,,,,,,,, -75500,0.15296122,0.018539371,,,,,,,,,,,,,,,,, -75600,0.1558476,0.019499648,,,,,,,,,,,,,,,,, -75700,0.14861673,0.01733074,,,,,,,,,,,,,,,,, -75800,0.14104794,0.020288039,,,,,,,,,,,,,,,,, -75900,0.14019766,0.019160297,,,,,,,,,,,,,,,,, -75940,,,0.9953892230987548,0.0143926069140434,0.7691021492142511,0.9870362281799316,0.0501075685024261,0.2927785560003524,43793.0,0.9861279129981996,0.053962018340826,0.2749305524965749,43793.0,24508.311578273773,37821.3829460144,24508.311578273773,13307.42241358757,3.565934658050537,0.0 -76000,0.15471822,0.019326592,,,,,,,,,,,,,,,,, -76100,0.13398036,0.018961325,,,,,,,,,,,,,,,,, -76200,0.16730681,0.020730413,,,,,,,,,,,,,,,,, -76300,0.14244121,0.017672163,,,,,,,,,,,,,,,,, -76400,0.15502563,0.019802557,,,,,,,,,,,,,,,,, -76500,0.11834796,0.016810952,,,,,,,,,,,,,,,,, -76600,0.14262961,0.019446896,,,,,,,,,,,,,,,,, -76686,,,0.9954409003257751,0.0143494829535484,0.7656371137133122,0.9870646595954896,0.0501628257334232,0.2936275545673407,43793.0,0.9861447811126708,0.0540152974426746,0.2739643168130566,43793.0,24748.563071012497,38178.83873414993,24748.563071012497,13424.56529855728,3.6069235801696777,0.0 -76700,0.13962625,0.017192952,,,,,,,,,,,,,,,,, -76800,0.13662525,0.018197814,,,,,,,,,,,,,,,,, -76900,0.13070396,0.017917521,,,,,,,,,,,,,,,,, -77000,0.15556103,0.02013955,,,,,,,,,,,,,,,,, -77100,0.14046249,0.01843713,,,,,,,,,,,,,,,,, -77200,0.14249033,0.01704828,,,,,,,,,,,,,,,,, -77300,0.14020956,0.016938917,,,,,,,,,,,,,,,,, -77400,0.14083914,0.017562028,,,,,,,,,,,,,,,,, -77423,,,0.9954899549484252,0.0141655439510941,0.7759545638285559,0.9870675206184388,0.0501727983355522,0.2925905235182048,43793.0,0.9861460328102112,0.0540119931101799,0.2738614273246971,43793.0,24988.685559511185,38536.57685160637,24988.685559511185,13542.111910581589,3.654918909072876,0.0 -77500,0.14456502,0.018825542,,,,,,,,,,,,,,,,, -77600,0.13812928,0.015650714,,,,,,,,,,,,,,,,, -77700,0.1415255,0.018062763,,,,,,,,,,,,,,,,, -77800,0.13383245,0.016953975,,,,,,,,,,,,,,,,, -77900,0.14081948,0.0170274,,,,,,,,,,,,,,,,, -78000,0.14173265,0.016674444,,,,,,,,,,,,,,,,, -78100,0.13426317,0.018092489,,,,,,,,,,,,,,,,, -78169,,,0.9955411553382874,0.0139900390058755,0.7764028795076405,0.987064242362976,0.0501210130751132,0.2925099990148256,43793.0,0.986154854297638,0.0539759583771228,0.2742019631959193,43793.0,25228.494074106216,38891.04864430428,25228.494074106216,13656.379624128342,4.030130386352539,0.0 -78200,0.14038436,0.016975692,,,,,,,,,,,,,,,,, -78300,0.15115686,0.020205734,,,,,,,,,,,,,,,,, -78400,0.13269433,0.018138682,,,,,,,,,,,,,,,,, -78500,0.14862387,0.020132324,,,,,,,,,,,,,,,,, -78600,0.13927744,0.018405551,,,,,,,,,,,,,,,,, -78700,0.13906379,0.017683672,,,,,,,,,,,,,,,,, -78800,0.15652215,0.021503748,,,,,,,,,,,,,,,,, -78900,0.13752908,0.017135935,,,,,,,,,,,,,,,,, -78911,,,0.9954591989517212,0.0141547610983252,0.7724652479037462,0.987065851688385,0.0501732900738716,0.2921166804633121,43793.0,0.9861447811126708,0.0540265403687953,0.274390207559422,43793.0,25468.65472769737,39253.02361893654,25468.65472769737,13778.132029294968,4.071470975875855,0.0 -79000,0.14954966,0.018305205,,,,,,,,,,,,,,,,, -79100,0.15401953,0.018409627,,,,,,,,,,,,,,,,, -79200,0.15185027,0.016676573,,,,,,,,,,,,,,,,, -79300,0.1423225,0.018063206,,,,,,,,,,,,,,,,, -79400,0.14715841,0.01765792,,,,,,,,,,,,,,,,, -79500,0.14561412,0.016878847,,,,,,,,,,,,,,,,, -79600,0.15261802,0.01908579,,,,,,,,,,,,,,,,, -79651,,,0.9955033659934998,0.014123479835689,0.7787619339885247,0.9870723485946656,0.0501797534525394,0.2923345979855717,43793.0,0.9861464500427246,0.0540345124900341,0.2745096935947982,43793.0,25708.90886187553,39610.99244427681,25708.90886187553,13895.783916950226,4.113170385360718,0.0 -79700,0.13912247,0.017296264,,,,,,,,,,,,,,,,, -79800,0.13110724,0.016497344,,,,,,,,,,,,,,,,, -79900,0.14952403,0.019266456,,,,,,,,,,,,,,,,, -80000,0.14717297,0.018847952,,,,,,,,,,,,,,,,, -80100,0.15100664,0.0193047,,,,,,,,,,,,,,,,, -80200,0.11924582,0.016990038,,,,,,,,,,,,,,,,, -80300,0.13424222,0.0170749,,,,,,,,,,,,,,,,, -80398,,,0.9954550862312316,0.0142388418316841,0.7613149448795697,0.987072765827179,0.0501789674162864,0.292532548861496,43793.0,0.9861477017402648,0.0540336780250072,0.2744157901547382,43793.0,25949.06646060944,39965.985496521,25949.06646060944,14010.557025671003,4.155007600784302,0.0 -80400,0.15393333,0.017191082,,,,,,,,,,,,,,,,, -80500,0.12462115,0.016189815,,,,,,,,,,,,,,,,, -80600,0.14536713,0.017334856,,,,,,,,,,,,,,,,, -80700,0.13560045,0.017515337,,,,,,,,,,,,,,,,, -80800,0.14001934,0.017866734,,,,,,,,,,,,,,,,, -80900,0.15375344,0.019082861,,,,,,,,,,,,,,,,, -81000,0.13231367,0.018883143,,,,,,,,,,,,,,,,, -81100,0.12804896,0.015832152,,,,,,,,,,,,,,,,, -81145,,,0.9954413771629332,0.0142920305952429,0.7743069602250183,0.987072765827179,0.0501789785921573,0.2925131864548106,43793.0,0.9861477017402648,0.0540336780250072,0.274459226207241,43793.0,26189.072664499283,40320.94511008263,26189.072664499283,14125.44719004631,4.197999954223633,0.0 -81200,0.16126822,0.02002139,,,,,,,,,,,,,,,,, -81300,0.14277901,0.017395988,,,,,,,,,,,,,,,,, -81400,0.13700573,0.015978402,,,,,,,,,,,,,,,,, -81500,0.14345796,0.01806118,,,,,,,,,,,,,,,,, -81600,0.16086392,0.019981677,,,,,,,,,,,,,,,,, -81700,0.14309908,0.016949072,,,,,,,,,,,,,,,,, -81800,0.1421959,0.018005071,,,,,,,,,,,,,,,,, -81891,,,0.995533525943756,0.0140317762270569,0.7743152053587273,0.987072765827179,0.0501789674162864,0.2923898254512626,43793.0,0.9861477017402648,0.0540336780250072,0.2745432332629763,43793.0,26429.078844308853,40678.15526819229,26429.078844308853,14242.58826804161,4.240504264831543,0.0 -81900,0.14737572,0.018060258,,,,,,,,,,,,,,,,, -82000,0.15377125,0.019294914,,,,,,,,,,,,,,,,, -82100,0.14478694,0.01823675,,,,,,,,,,,,,,,,, -82200,0.15805481,0.019428674,,,,,,,,,,,,,,,,, -82300,0.16020946,0.021500215,,,,,,,,,,,,,,,,, -82400,0.15763721,0.020298682,,,,,,,,,,,,,,,,, -82500,0.14878051,0.019082872,,,,,,,,,,,,,,,,, -82600,0.16206187,0.019432468,,,,,,,,,,,,,,,,, -82635,,,0.9954736232757568,0.0142103852704167,0.7793815673647857,0.987072765827179,0.0501789674162864,0.2923741737832562,43793.0,0.9861477017402648,0.0540336780250072,0.2743955574094778,43793.0,26669.20731639862,41036.22692847252,26669.20731639862,14360.468386650084,4.283483266830444,0.0 -82700,0.13654594,0.016610408,,,,,,,,,,,,,,,,, -82800,0.14031945,0.018659802,,,,,,,,,,,,,,,,, -82900,0.13319689,0.018728973,,,,,,,,,,,,,,,,, -83000,0.14570433,0.020213867,,,,,,,,,,,,,,,,, -83100,0.15407929,0.018632496,,,,,,,,,,,,,,,,, -83200,0.14358819,0.018473322,,,,,,,,,,,,,,,,, -83300,0.1362091,0.017885517,,,,,,,,,,,,,,,,, -83379,,,0.9954969882965088,0.0140584139153361,0.7734911544122332,0.987072765827179,0.0501789711415767,0.29237454060272,43793.0,0.9861477017402648,0.0540336780250072,0.2744025722481929,43793.0,26909.310037374496,41389.468089580536,26909.310037374496,14473.543465852736,4.326739549636841,0.0 -83400,0.14788112,0.018228352,,,,,,,,,,,,,,,,, -83500,0.13877702,0.015530188,,,,,,,,,,,,,,,,, -83600,0.14285187,0.018884504,,,,,,,,,,,,,,,,, -83700,0.14652774,0.019366816,,,,,,,,,,,,,,,,, -83800,0.13301624,0.017070265,,,,,,,,,,,,,,,,, -83900,0.14656307,0.018394843,,,,,,,,,,,,,,,,, -84000,0.15035737,0.018894449,,,,,,,,,,,,,,,,, -84100,0.14501964,0.019049902,,,,,,,,,,,,,,,,, -84115,,,0.9954795241355896,0.0141797987744212,0.7700450569577544,0.987072765827179,0.0501789674162864,0.2924245796731328,43793.0,0.9861477017402648,0.0540336780250072,0.2743914294477163,43793.0,27149.2596681118,41744.3564248085,27149.2596681118,14588.418242692947,4.370355844497681,0.0 -84200,0.14916457,0.018890157,,,,,,,,,,,,,,,,, -84300,0.12410744,0.016821902,,,,,,,,,,,,,,,,, -84400,0.13950054,0.01839316,,,,,,,,,,,,,,,,, -84500,0.13976054,0.016789097,,,,,,,,,,,,,,,,, -84600,0.14723398,0.01916076,,,,,,,,,,,,,,,,, -84700,0.15153973,0.01948905,,,,,,,,,,,,,,,,, -84800,0.13856453,0.018201431,,,,,,,,,,,,,,,,, -84868,,,0.9954484105110168,0.0142580112442374,0.770631365526065,0.987072765827179,0.050178974866867,0.2925296154257013,43793.0,0.9861477017402648,0.0540336780250072,0.2744031934173734,43793.0,27389.41187477112,42101.22405195236,27389.41187477112,14705.068259239197,4.415584087371826,0.0 -84900,0.13738926,0.017063702,,,,,,,,,,,,,,,,, -85000,0.15565182,0.018850906,,,,,,,,,,,,,,,,, -85100,0.1510855,0.017703705,,,,,,,,,,,,,,,,, -85200,0.13749649,0.018039234,,,,,,,,,,,,,,,,, -85300,0.13392246,0.017599102,,,,,,,,,,,,,,,,, -85400,0.1346236,0.017822636,,,,,,,,,,,,,,,,, -85500,0.14773734,0.016660295,,,,,,,,,,,,,,,,, -85600,0.12090877,0.017803479,,,,,,,,,,,,,,,,, -85616,,,0.9955559968948364,0.0140151465311646,0.7786308148010082,0.987072765827179,0.0501789711415767,0.2923086632501189,43793.0,0.9861477017402648,0.0540336780250072,0.274370387334361,43793.0,27629.48045659065,42456.78595900536,27629.48045659065,14820.498401403427,4.458751201629639,0.0 -85700,0.1344577,0.018095128,,,,,,,,,,,,,,,,, -85800,0.14226264,0.018213117,,,,,,,,,,,,,,,,, -85900,0.14724651,0.016213566,,,,,,,,,,,,,,,,, -86000,0.122044355,0.015246981,,,,,,,,,,,,,,,,, -86100,0.14665636,0.020395001,,,,,,,,,,,,,,,,, -86200,0.13001424,0.017135419,,,,,,,,,,,,,,,,, -86300,0.15263541,0.017523913,,,,,,,,,,,,,,,,, -86357,,,0.9954648613929749,0.0142206866294145,0.7778873312937167,0.987072765827179,0.0501789711415767,0.2925448356793363,43793.0,0.9861477017402648,0.0540336780250072,0.2745166189528792,43793.0,27869.674690246586,42809.461366176605,27869.674690246586,14932.916816949844,4.501690149307251,0.0 -86400,0.16196549,0.019590724,,,,,,,,,,,,,,,,, -86500,0.1452833,0.016117023,,,,,,,,,,,,,,,,, -86600,0.13723084,0.017832931,,,,,,,,,,,,,,,,, -86700,0.13988623,0.018662974,,,,,,,,,,,,,,,,, -86800,0.16086084,0.019595591,,,,,,,,,,,,,,,,, -86900,0.14493307,0.018007187,,,,,,,,,,,,,,,,, -87000,0.1323348,0.01765319,,,,,,,,,,,,,,,,, -87093,,,0.9954904317855836,0.0141125125810503,0.7748273898665948,0.987072765827179,0.0501789674162864,0.2924114938285133,43793.0,0.9861477017402648,0.0540336780250072,0.2744101209278129,43793.0,28109.69511294365,43168.13737511635,28109.69511294365,15051.508164167404,4.545614242553711,0.0 -87100,0.14154328,0.017741581,,,,,,,,,,,,,,,,, -87200,0.12544684,0.014831545,,,,,,,,,,,,,,,,, -87300,0.14346884,0.016523354,,,,,,,,,,,,,,,,, -87400,0.15152821,0.019013496,,,,,,,,,,,,,,,,, -87500,0.15579511,0.017085722,,,,,,,,,,,,,,,,, -87600,0.1432028,0.016821709,,,,,,,,,,,,,,,,, -87700,0.14050817,0.017076202,,,,,,,,,,,,,,,,, -87800,0.14570025,0.01708046,,,,,,,,,,,,,,,,, -87824,,,0.9954723715782166,0.0142516875639557,0.7732605363318497,0.987072765827179,0.0501789674162864,0.2925361165228815,43793.0,0.9861477017402648,0.0540336780250072,0.2743490402265405,43793.0,28349.8712515831,43527.16404628754,28349.8712515831,15170.293683290482,4.589420795440674,0.0 -87900,0.14398399,0.019501962,,,,,,,,,,,,,,,,, -88000,0.1389249,0.01948734,,,,,,,,,,,,,,,,, -88100,0.13509655,0.01762256,,,,,,,,,,,,,,,,, -88200,0.14788713,0.02081368,,,,,,,,,,,,,,,,, -88300,0.1480062,0.019754587,,,,,,,,,,,,,,,,, -88400,0.13876997,0.017009852,,,,,,,,,,,,,,,,, -88500,0.14667715,0.019182464,,,,,,,,,,,,,,,,, -88544,,,0.9954808950424194,0.0141577227041125,0.7625540593532746,0.987072765827179,0.0501789711415767,0.2924557376319083,43793.0,0.9861477017402648,0.0540336780250072,0.2745361670129824,43793.0,28589.81309151649,43885.57111406326,28589.81309151649,15288.689393758774,4.63522481918335,0.0 -88600,0.16147442,0.020008741,,,,,,,,,,,,,,,,, -88700,0.13169253,0.016946124,,,,,,,,,,,,,,,,, -88800,0.13189207,0.018265912,,,,,,,,,,,,,,,,, -88900,0.15468128,0.019851051,,,,,,,,,,,,,,,,, -89000,0.14262706,0.018478915,,,,,,,,,,,,,,,,, -89100,0.15769944,0.019612977,,,,,,,,,,,,,,,,, -89200,0.14536738,0.018814981,,,,,,,,,,,,,,,,, -89284,,,0.9955047965049744,0.0140820499509572,0.7791833701354678,0.987072765827179,0.0501789785921573,0.2923814411188966,43793.0,0.9861477017402648,0.0540336780250072,0.2744330019098556,43793.0,28829.955509662628,44239.30815052986,28829.955509662628,15402.21783900261,4.680805921554565,0.0 -89300,0.16386154,0.021284852,,,,,,,,,,,,,,,,, -89400,0.15415716,0.020367272,,,,,,,,,,,,,,,,, -89500,0.15425731,0.019092092,,,,,,,,,,,,,,,,, -89600,0.14205162,0.019638596,,,,,,,,,,,,,,,,, -89700,0.1612786,0.019561557,,,,,,,,,,,,,,,,, -89800,0.1594194,0.017955417,,,,,,,,,,,,,,,,, -89900,0.15045156,0.019133328,,,,,,,,,,,,,,,,, -90000,0.13623445,0.019427886,,,,,,,,,,,,,,,,, -90030,,,0.9955043792724608,0.0141207063570618,0.7797977344418124,0.987072765827179,0.0501789785921573,0.2923577434288198,43793.0,0.9861477017402648,0.0540336854755878,0.2744070792908892,43793.0,29070.022459983826,44595.66125655174,29070.022459983826,15518.440243244171,4.724467754364014,0.0 -90100,0.17170224,0.021645054,,,,,,,,,,,,,,,,, -90200,0.12986979,0.017471913,,,,,,,,,,,,,,,,, -90300,0.13835435,0.018076979,,,,,,,,,,,,,,,,, -90400,0.13994062,0.01667764,,,,,,,,,,,,,,,,, -90500,0.14380465,0.019703262,,,,,,,,,,,,,,,,, -90600,0.14492936,0.017793857,,,,,,,,,,,,,,,,, -90700,0.13605917,0.017153386,,,,,,,,,,,,,,,,, -90777,,,0.9954771399497986,0.0141914384439587,0.7664620953486616,0.987072765827179,0.0501789711415767,0.2925620521718061,43793.0,0.9861477017402648,0.0540336854755878,0.2745075266508673,43793.0,29310.049928426743,44956.033213377,29310.049928426743,15638.720978021622,4.768067836761475,0.0 -90800,0.14363761,0.017663123,,,,,,,,,,,,,,,,, -90900,0.12130281,0.01572994,,,,,,,,,,,,,,,,, -91000,0.14552628,0.019600589,,,,,,,,,,,,,,,,, -91100,0.12394264,0.017691534,,,,,,,,,,,,,,,,, -91200,0.1606355,0.020412365,,,,,,,,,,,,,,,,, -91300,0.15677804,0.020182827,,,,,,,,,,,,,,,,, -91400,0.13518527,0.015815465,,,,,,,,,,,,,,,,, -91500,0.14870237,0.020165466,,,,,,,,,,,,,,,,, -91515,,,0.9954972863197328,0.0141352312639355,0.7745224158954835,0.987072765827179,0.0501789674162864,0.2923887581809313,43793.0,0.9861477017402648,0.0540336780250072,0.2745610067990969,43793.0,29550.18574547768,45310.246342897415,29550.18574547768,15752.73408961296,4.812305450439453,0.0 -91600,0.14625901,0.019389253,,,,,,,,,,,,,,,,, -91700,0.1441788,0.016655523,,,,,,,,,,,,,,,,, -91800,0.13913777,0.015906116,,,,,,,,,,,,,,,,, -91900,0.13312934,0.0154664535,,,,,,,,,,,,,,,,, -92000,0.13869017,0.017357072,,,,,,,,,,,,,,,,, -92100,0.15105596,0.019308614,,,,,,,,,,,,,,,,, -92200,0.13529344,0.015485339,,,,,,,,,,,,,,,,, -92253,,,0.995453119277954,0.0142373200505971,0.7638528805363961,0.987072765827179,0.0501789785921573,0.292432335709977,43793.0,0.9861477017402648,0.0540336780250072,0.2745345744716574,43793.0,29790.211117506027,45663.67285728455,29790.211117506027,15866.069275140762,4.857031583786011,0.0 -92300,0.15239665,0.01805636,,,,,,,,,,,,,,,,, -92400,0.14741924,0.018786877,,,,,,,,,,,,,,,,, -92500,0.13404265,0.016817173,,,,,,,,,,,,,,,,, -92600,0.13663855,0.019945335,,,,,,,,,,,,,,,,, -92700,0.15771402,0.01851985,,,,,,,,,,,,,,,,, -92800,0.13053277,0.017975748,,,,,,,,,,,,,,,,, -92900,0.1303035,0.016295716,,,,,,,,,,,,,,,,, -92984,,,0.9955169558525084,0.0140610234811902,0.7800021202800156,0.987072765827179,0.0501789711415767,0.2924081659905708,43793.0,0.9861477017402648,0.0540336854755878,0.2743704135473845,43793.0,30030.20610809326,46019.193653821945,30030.20610809326,15981.527184009552,4.903205156326294,0.0 -93000,0.13957521,0.019433849,,,,,,,,,,,,,,,,, -93100,0.14164513,0.01835955,,,,,,,,,,,,,,,,, -93200,0.1621069,0.020152124,,,,,,,,,,,,,,,,, -93300,0.14049494,0.016884705,,,,,,,,,,,,,,,,, -93400,0.13621363,0.017394142,,,,,,,,,,,,,,,,, -93500,0.15385169,0.018922461,,,,,,,,,,,,,,,,, -93600,0.1380801,0.01818835,,,,,,,,,,,,,,,,, -93700,0.13197114,0.018519867,,,,,,,,,,,,,,,,, -93720,,,0.995534121990204,0.0141052734106779,0.7757578925328368,0.987072765827179,0.0501789711415767,0.2923997285759094,43793.0,0.9861477017402648,0.0540336780250072,0.2744685360980869,43793.0,30270.435495376587,46375.94743990898,30270.435495376587,16097.984763383864,4.948692083358765,0.0 -93800,0.1338121,0.01785982,,,,,,,,,,,,,,,,, -93900,0.14122641,0.018486306,,,,,,,,,,,,,,,,, -94000,0.1489643,0.021374913,,,,,,,,,,,,,,,,, -94100,0.15831266,0.01888195,,,,,,,,,,,,,,,,, -94200,0.14134066,0.017107088,,,,,,,,,,,,,,,,, -94300,0.14602411,0.019616866,,,,,,,,,,,,,,,,, -94400,0.13792711,0.01702304,,,,,,,,,,,,,,,,, -94469,,,0.9954125881195068,0.0142912846058607,0.7713674171966016,0.987072765827179,0.0501789674162864,0.292367312775709,43793.0,0.9861477017402648,0.0540336780250072,0.2743349141264332,43793.0,30510.429839372635,46731.31958055496,30510.429839372635,16213.29602575302,4.994197368621826,0.0 -94500,0.13169657,0.01855428,,,,,,,,,,,,,,,,, -94600,0.13391751,0.018202279,,,,,,,,,,,,,,,,, -94700,0.15824819,0.018658422,,,,,,,,,,,,,,,,, -94800,0.14999758,0.018414538,,,,,,,,,,,,,,,,, -94900,0.14479625,0.016089227,,,,,,,,,,,,,,,,, -95000,0.14160267,0.018929262,,,,,,,,,,,,,,,,, -95100,0.16010928,0.01863003,,,,,,,,,,,,,,,,, -95200,0.15489502,0.018467728,,,,,,,,,,,,,,,,, -95213,,,0.995521366596222,0.0140772210434079,0.7769381347581835,0.987072765827179,0.0501789785921573,0.2924208216760564,43793.0,0.9861477017402648,0.0540336780250072,0.2743825265260335,43793.0,30750.42320585251,47087.043660879135,30750.42320585251,16328.961030006409,5.038869380950928,0.0 -95300,0.13961205,0.018149229,,,,,,,,,,,,,,,,, -95400,0.14812697,0.017853135,,,,,,,,,,,,,,,,, -95500,0.14216295,0.016944237,,,,,,,,,,,,,,,,, -95600,0.14832355,0.016689943,,,,,,,,,,,,,,,,, -95700,0.1480797,0.018454066,,,,,,,,,,,,,,,,, -95800,0.13383663,0.016754176,,,,,,,,,,,,,,,,, -95900,0.15500203,0.017445108,,,,,,,,,,,,,,,,, -95959,,,0.9954837560653688,0.0141481598839163,0.7681165605917525,0.987072765827179,0.050178974866867,0.2925440796695244,43793.0,0.9861477017402648,0.0540336780250072,0.2743611788337915,43793.0,30990.67506980896,47439.72964859009,30990.67506980896,16441.328189611435,5.085657596588135,0.0 -96000,0.12990278,0.017680448,,,,,,,,,,,,,,,,, -96100,0.13586468,0.01641757,,,,,,,,,,,,,,,,, -96200,0.14051004,0.016979827,,,,,,,,,,,,,,,,, -96300,0.13770989,0.01593651,,,,,,,,,,,,,,,,, -96400,0.13589649,0.017429344,,,,,,,,,,,,,,,,, -96500,0.138593,0.01897364,,,,,,,,,,,,,,,,, -96600,0.13153166,0.018248646,,,,,,,,,,,,,,,,, -96699,,,0.9954470992088318,0.0142582403495907,0.7757481490303596,0.987072765827179,0.0501789711415767,0.2923798986293627,43793.0,0.9861477017402648,0.0540336780250072,0.2743702992049815,43793.0,31230.70459461212,47794.90297555924,31230.70459461212,16556.407190322876,5.13048243522644,0.0 -96700,0.14059636,0.016801061,,,,,,,,,,,,,,,,, -96800,0.14566933,0.018351391,,,,,,,,,,,,,,,,, -96900,0.13368829,0.018509462,,,,,,,,,,,,,,,,, -97000,0.1479159,0.017962417,,,,,,,,,,,,,,,,, -97100,0.14435904,0.017918587,,,,,,,,,,,,,,,,, -97200,0.1359856,0.016867368,,,,,,,,,,,,,,,,, -97300,0.14479373,0.01693359,,,,,,,,,,,,,,,,, -97400,0.13986789,0.016904602,,,,,,,,,,,,,,,,, -97442,,,0.9955335855484008,0.0140630081295967,0.7739189992361805,0.987072765827179,0.0501789674162864,0.2924558861565163,43793.0,0.9861477017402648,0.0540336780250072,0.2743770625028734,43793.0,31470.8629090786,48149.28699564934,31470.8629090786,16670.566215515137,5.1765077114105225,0.0 -97500,0.15042627,0.017660394,,,,,,,,,,,,,,,,, -97600,0.14653656,0.018336117,,,,,,,,,,,,,,,,, -97700,0.1479207,0.02025889,,,,,,,,,,,,,,,,, -97800,0.15063812,0.018218454,,,,,,,,,,,,,,,,, -97900,0.14673056,0.017777817,,,,,,,,,,,,,,,,, -98000,0.15348989,0.01976349,,,,,,,,,,,,,,,,, -98100,0.13259973,0.017590791,,,,,,,,,,,,,,,,, -98189,,,0.9955210089683532,0.0140525046736001,0.7769833804928025,0.987072765827179,0.0501789711415767,0.2925534718224171,43793.0,0.9861477017402648,0.0540336780250072,0.2744425142668042,43793.0,31710.853369951248,48502.43397283554,31710.853369951248,16783.64903306961,5.22965407371521,0.0 -98200,0.1316431,0.016996238,,,,,,,,,,,,,,,,, -98300,0.13334832,0.018084656,,,,,,,,,,,,,,,,, -98400,0.12410064,0.018034248,,,,,,,,,,,,,,,,, -98500,0.15000033,0.017117266,,,,,,,,,,,,,,,,, -98600,0.12999998,0.017295476,,,,,,,,,,,,,,,,, -98700,0.14213957,0.018460056,,,,,,,,,,,,,,,,, -98800,0.13963988,0.01705068,,,,,,,,,,,,,,,,, -98900,0.14724429,0.015806127,,,,,,,,,,,,,,,,, -98932,,,0.9954484701156616,0.014266662299633,0.7728061638280789,0.987072765827179,0.0501789711415767,0.2925407066451959,43793.0,0.9861477017402648,0.0540336780250072,0.2744893408034077,43793.0,31950.86460542679,48858.06380486488,31950.86460542679,16899.202238321304,5.274975776672363,0.0 -99000,0.12843347,0.016266035,,,,,,,,,,,,,,,,, -99100,0.1468353,0.018424897,,,,,,,,,,,,,,,,, -99200,0.12965819,0.01591966,,,,,,,,,,,,,,,,, -99300,0.15776792,0.019770402,,,,,,,,,,,,,,,,, -99400,0.14269155,0.016835988,,,,,,,,,,,,,,,,, -99500,0.14341179,0.01776184,,,,,,,,,,,,,,,,, -99600,0.1501945,0.017451068,,,,,,,,,,,,,,,,, -99673,,,0.9955064058303832,0.0140399364754557,0.769985629437685,0.987072765827179,0.0501789674162864,0.2923742836898387,43793.0,0.9861477017402648,0.0540336780250072,0.2744765086034905,43793.0,32191.01792359352,49212.104393959045,32191.01792359352,17013.020210027695,5.321348190307617,0.0 -99700,0.13695946,0.017167624,,,,,,,,,,,,,,,,, -99800,0.14559072,0.015594683,,,,,,,,,,,,,,,,, -99900,0.13787495,0.018444879,,,,,,,,,,,,,,,,, -100000,0.13407932,0.016905695,,,,,,,,,,,,,,,,, -100100,0.14841828,0.018448683,,,,,,,,,,,,,,,,, -100200,0.1558104,0.019832091,,,,,,,,,,,,,,,,, -100300,0.1585656,0.01652671,,,,,,,,,,,,,,,,, -100400,0.17010829,0.019395152,,,,,,,,,,,,,,,,, -100419,,,0.995432198047638,0.0143799893558025,0.7721081716111647,0.987072765827179,0.0501789711415767,0.292357912723474,43793.0,0.9861477017402648,0.0540336780250072,0.2743564112248008,43793.0,32431.11732816696,49566.61224031448,32431.11732816696,17127.36168217659,5.368618726730347,0.0 -100500,0.13641255,0.019313553,,,,,,,,,,,,,,,,, -100600,0.12950994,0.016538884,,,,,,,,,,,,,,,,, -100700,0.14846984,0.018035302,,,,,,,,,,,,,,,,, -100800,0.1364816,0.017878784,,,,,,,,,,,,,,,,, -100900,0.13398825,0.017538577,,,,,,,,,,,,,,,,, -101000,0.13081634,0.017090255,,,,,,,,,,,,,,,,, -101100,0.12369883,0.015115108,,,,,,,,,,,,,,,,, -101162,,,0.9955034255981444,0.0141010275110602,0.7755480809717403,0.987072765827179,0.0501789674162864,0.2924520208366916,43793.0,0.9861477017402648,0.0540336780250072,0.274346909907928,43793.0,32671.16210055352,49915.86226391792,32671.16210055352,17236.50082373619,5.4143126010894775,0.0 -101200,0.13172393,0.017406013,,,,,,,,,,,,,,,,, -101300,0.1586245,0.017883066,,,,,,,,,,,,,,,,, -101400,0.1540023,0.018647233,,,,,,,,,,,,,,,,, -101500,0.1334777,0.016975498,,,,,,,,,,,,,,,,, -101600,0.12963408,0.015752055,,,,,,,,,,,,,,,,, -101700,0.13102256,0.0164399,,,,,,,,,,,,,,,,, -101800,0.14935197,0.020112373,,,,,,,,,,,,,,,,, -101900,0.13129123,0.019012,,,,,,,,,,,,,,,,, -101901,,,0.9954982399940492,0.0141067765653133,0.7797891037776852,0.987072765827179,0.050178974866867,0.2924009146730034,43793.0,0.9861477017402648,0.0540336780250072,0.2743002376678368,43793.0,32911.343249082565,50264.32465529442,32911.343249082565,17344.714772701263,5.459990978240967,0.0 -102000,0.13991196,0.018072695,,,,,,,,,,,,,,,,, -102100,0.15993701,0.021136502,,,,,,,,,,,,,,,,, -102200,0.14761077,0.019023998,,,,,,,,,,,,,,,,, -102300,0.13372333,0.016999599,,,,,,,,,,,,,,,,, -102400,0.16848017,0.018331014,,,,,,,,,,,,,,,,, -102500,0.15546387,0.02005935,,,,,,,,,,,,,,,,, -102600,0.140195,0.017757451,,,,,,,,,,,,,,,,, -102658,,,0.9954644441604614,0.0141842234879732,0.7669780497456298,0.987072765827179,0.050178974866867,0.2925304588886242,43793.0,0.9861477017402648,0.0540336780250072,0.2744376055525118,43793.0,33151.44903755188,50614.150799274445,33151.44903755188,17454.36855649948,5.506225824356079,0.0 -102700,0.13108414,0.018170005,,,,,,,,,,,,,,,,, -102800,0.14075387,0.016544057,,,,,,,,,,,,,,,,, -102900,0.12891413,0.015389734,,,,,,,,,,,,,,,,, -103000,0.14314963,0.017745063,,,,,,,,,,,,,,,,, -103100,0.13991034,0.017057506,,,,,,,,,,,,,,,,, -103200,0.1298571,0.016664807,,,,,,,,,,,,,,,,, -103300,0.14079793,0.018107781,,,,,,,,,,,,,,,,, -103400,0.14446023,0.017610027,,,,,,,,,,,,,,,,, -103408,,,0.9954658150672911,0.0142328403890132,0.7757999081072839,0.987072765827179,0.050178974866867,0.2925922920047358,43793.0,0.9861477017402648,0.0540336780250072,0.2744152863768196,43793.0,33391.576731443405,50966.39300918579,33391.576731443405,17566.41631412506,5.552358627319336,0.0 -103500,0.14519417,0.016678926,,,,,,,,,,,,,,,,, -103600,0.14116241,0.018586095,,,,,,,,,,,,,,,,, -103700,0.13637117,0.016692149,,,,,,,,,,,,,,,,, -103800,0.14666298,0.018836124,,,,,,,,,,,,,,,,, -103900,0.13402404,0.015834536,,,,,,,,,,,,,,,,, -104000,0.14643855,0.018074691,,,,,,,,,,,,,,,,, -104100,0.13994466,0.017504254,,,,,,,,,,,,,,,,, -104161,,,0.9954906105995178,0.0141477445140481,0.7707154556254288,0.987072765827179,0.0501789785921573,0.2923405477629439,43793.0,0.9861477017402648,0.0540336780250072,0.2744654155294517,43793.0,33631.52317357063,51318.0455186367,33631.52317357063,17678.056384563446,5.598142623901367,0.0 -104200,0.13166136,0.0181156,,,,,,,,,,,,,,,,, -104300,0.13408506,0.01611427,,,,,,,,,,,,,,,,, -104400,0.13535054,0.016767351,,,,,,,,,,,,,,,,, -104500,0.12356727,0.016541053,,,,,,,,,,,,,,,,, -104600,0.13070594,0.017569426,,,,,,,,,,,,,,,,, -104700,0.14319971,0.017554685,,,,,,,,,,,,,,,,, -104800,0.1450733,0.019022955,,,,,,,,,,,,,,,,, -104900,0.13174576,0.01698249,,,,,,,,,,,,,,,,, -104914,,,0.9955100417137146,0.0141041558235883,0.7736689829168021,0.987072765827179,0.050178974866867,0.2924022043200743,43793.0,0.9861477017402648,0.0540336780250072,0.2743916661481683,43793.0,33871.74913954735,51666.66525363922,33871.74913954735,17786.38423061371,5.643944025039673,0.0 -105000,0.16235742,0.02059097,,,,,,,,,,,,,,,,, -105100,0.13685338,0.01841452,,,,,,,,,,,,,,,,, -105200,0.13615364,0.01736799,,,,,,,,,,,,,,,,, -105300,0.16914105,0.019992748,,,,,,,,,,,,,,,,, -105400,0.14638261,0.020032037,,,,,,,,,,,,,,,,, -105500,0.14605789,0.018377526,,,,,,,,,,,,,,,,, -105600,0.16434589,0.020557886,,,,,,,,,,,,,,,,, -105653,,,0.9955135583877563,0.0141152972355484,0.7804668605179181,0.987072765827179,0.050178974866867,0.2926694313673195,43793.0,0.9861477017402648,0.0540336780250072,0.2744481433858795,43793.0,34111.82308912277,52012.376638650894,34111.82308912277,17891.954362392426,5.690004587173462,0.0 -105700,0.14974143,0.018876899,,,,,,,,,,,,,,,,, -105800,0.13141133,0.017379995,,,,,,,,,,,,,,,,, -105900,0.13601613,0.017922092,,,,,,,,,,,,,,,,, -106000,0.14280593,0.01771641,,,,,,,,,,,,,,,,, -106100,0.13410005,0.017401474,,,,,,,,,,,,,,,,, -106200,0.13045318,0.01726536,,,,,,,,,,,,,,,,, -106300,0.13218148,0.018802078,,,,,,,,,,,,,,,,, -106399,,,0.9954696297645568,0.0141687374562025,0.7746599640365843,0.987072765827179,0.0501789674162864,0.292687754129531,43793.0,0.9861477017402648,0.0540336780250072,0.274423879669587,43793.0,34351.85087966919,52365.87801599503,34351.85087966919,18005.360337734222,5.737663269042969,0.0 -106400,0.14230572,0.020354502,,,,,,,,,,,,,,,,, -106500,0.14162368,0.018213466,,,,,,,,,,,,,,,,, -106600,0.14434032,0.01837524,,,,,,,,,,,,,,,,, -106700,0.15181558,0.02039214,,,,,,,,,,,,,,,,, -106800,0.14092757,0.016430061,,,,,,,,,,,,,,,,, -106900,0.15805824,0.020181868,,,,,,,,,,,,,,,,, -107000,0.14686944,0.01894391,,,,,,,,,,,,,,,,, -107100,0.13477647,0.017682718,,,,,,,,,,,,,,,,, -107146,,,0.9955094456672668,0.0140214832499623,0.7750776297005886,0.987072765827179,0.050178974866867,0.2924795025938732,43793.0,0.9861477017402648,0.0540336780250072,0.2743853711222883,43793.0,34591.82379245758,52719.61870455742,34591.82379245758,18119.052418470383,5.792099714279175,0.0 -107200,0.14713506,0.016707588,,,,,,,,,,,,,,,,, -107300,0.1717283,0.021992875,,,,,,,,,,,,,,,,, -107400,0.14283417,0.018779708,,,,,,,,,,,,,,,,, -107500,0.14318232,0.017527435,,,,,,,,,,,,,,,,, -107600,0.13776837,0.017797822,,,,,,,,,,,,,,,,, -107700,0.14018881,0.016333072,,,,,,,,,,,,,,,,, -107800,0.14272831,0.018714283,,,,,,,,,,,,,,,,, -107877,,,0.9954585433006288,0.0142492037266492,0.7664326035292928,0.987072765827179,0.050178974866867,0.2925241672229261,43793.0,0.9861477017402648,0.0540336780250072,0.2744068619439111,43793.0,34831.85524511337,53065.20850133896,34831.85524511337,18224.53417825699,5.845448970794678,0.0 -107900,0.14749108,0.018818283,,,,,,,,,,,,,,,,, -108000,0.14081305,0.018438747,,,,,,,,,,,,,,,,, -108100,0.14793727,0.019210013,,,,,,,,,,,,,,,,, -108200,0.14860897,0.020076025,,,,,,,,,,,,,,,,, -108300,0.11873852,0.016659958,,,,,,,,,,,,,,,,, -108400,0.11959091,0.013908123,,,,,,,,,,,,,,,,, -108500,0.14893977,0.019804033,,,,,,,,,,,,,,,,, -108600,0.14425328,0.018550117,,,,,,,,,,,,,,,,, -108610,,,0.995460331439972,0.0142254335805773,0.7700719906062522,0.987072765827179,0.0501789674162864,0.29244259418498,43793.0,0.9861477017402648,0.0540336780250072,0.2744291479385146,43793.0,35071.80486369133,53414.68152284622,35071.80486369133,18333.98607826233,5.893911361694336,0.0 -108700,0.1485647,0.019615138,,,,,,,,,,,,,,,,, -108800,0.16346742,0.018844068,,,,,,,,,,,,,,,,, -108900,0.15232219,0.019865986,,,,,,,,,,,,,,,,, -109000,0.12101499,0.017166523,,,,,,,,,,,,,,,,, -109100,0.1471961,0.017382583,,,,,,,,,,,,,,,,, -109200,0.15056323,0.018382309,,,,,,,,,,,,,,,,, -109300,0.14837313,0.018906575,,,,,,,,,,,,,,,,, -109360,,,0.9955194592475892,0.0141528416424989,0.7760508584807527,0.987072765827179,0.0501789711415767,0.2923588761610667,43793.0,0.9861477017402648,0.0540336780250072,0.2743291088378314,43793.0,35311.79198503494,53764.87401008606,35311.79198503494,18444.12341618538,5.941523551940918,0.0 -109400,0.1336179,0.018218251,,,,,,,,,,,,,,,,, -109500,0.14110154,0.016728476,,,,,,,,,,,,,,,,, -109600,0.130263,0.015295203,,,,,,,,,,,,,,,,, -109700,0.15243337,0.01792004,,,,,,,,,,,,,,,,, -109800,0.165488,0.021002145,,,,,,,,,,,,,,,,, -109900,0.13727716,0.018931545,,,,,,,,,,,,,,,,, -110000,0.14816347,0.017084692,,,,,,,,,,,,,,,,, -110100,0.1432365,0.01579953,,,,,,,,,,,,,,,,, -110108,,,0.9955502152442932,0.0139934895560145,0.7846037522654938,0.987072765827179,0.0501789785921573,0.2923811446565916,43793.0,0.9861477017402648,0.0540336780250072,0.2744111213048259,43793.0,35551.902698516846,54120.36315464973,35551.902698516846,18559.434452056885,5.98870325088501,0.0 -110200,0.1421744,0.01657833,,,,,,,,,,,,,,,,, -110300,0.14074236,0.017535515,,,,,,,,,,,,,,,,, -110400,0.13587344,0.016911512,,,,,,,,,,,,,,,,, -110500,0.15627119,0.019038286,,,,,,,,,,,,,,,,, -110600,0.15482615,0.018992858,,,,,,,,,,,,,,,,, -110700,0.14356846,0.018981557,,,,,,,,,,,,,,,,, -110800,0.13658386,0.017944057,,,,,,,,,,,,,,,,, -110858,,,0.9954262971878052,0.0142672462388873,0.7683804375932366,0.987072765827179,0.0501789711415767,0.2923924791678979,43793.0,0.9861477017402648,0.0540336854755878,0.2744625788209498,43793.0,35791.86517548561,54472.83672046661,35791.86517548561,18671.87842154503,6.035488843917847,0.0 -110900,0.15676619,0.018383343,,,,,,,,,,,,,,,,, -111000,0.1316365,0.017232843,,,,,,,,,,,,,,,,, -111100,0.15085967,0.01959409,,,,,,,,,,,,,,,,, -111200,0.14096338,0.018103618,,,,,,,,,,,,,,,,, -111300,0.15260153,0.016769826,,,,,,,,,,,,,,,,, -111400,0.14455007,0.01870212,,,,,,,,,,,,,,,,, -111500,0.1471691,0.019453513,,,,,,,,,,,,,,,,, -111600,0.16089943,0.018846447,,,,,,,,,,,,,,,,, -111605,,,0.9955164790153505,0.0140737434849143,0.7725581660335283,0.987072765827179,0.0501789674162864,0.292436956294154,43793.0,0.9861477017402648,0.0540336780250072,0.2744341457188285,43793.0,36031.905665159225,54825.25518202782,36031.905665159225,18784.168533802032,6.103245735168457,0.0 -111700,0.1624481,0.020753684,,,,,,,,,,,,,,,,, -111800,0.14377981,0.018283604,,,,,,,,,,,,,,,,, -111900,0.16087587,0.021329507,,,,,,,,,,,,,,,,, -112000,0.15874492,0.01983241,,,,,,,,,,,,,,,,, -112100,0.14528938,0.017025145,,,,,,,,,,,,,,,,, -112200,0.15027745,0.017963495,,,,,,,,,,,,,,,,, -112300,0.14372317,0.018698357,,,,,,,,,,,,,,,,, -112356,,,0.9954671859741212,0.0142094586044549,0.7682750228101535,0.987072765827179,0.0501789674162864,0.2924133826804372,43793.0,0.9861477017402648,0.0540336780250072,0.2744849630053591,43793.0,36271.98492002487,55173.6886715889,36271.98492002487,18892.454761505127,6.150546789169312,0.0 -112400,0.14980687,0.020068623,,,,,,,,,,,,,,,,, -112500,0.13127425,0.018048258,,,,,,,,,,,,,,,,, -112600,0.16520907,0.020710992,,,,,,,,,,,,,,,,, -112700,0.14089143,0.018581154,,,,,,,,,,,,,,,,, -112800,0.14093278,0.01962547,,,,,,,,,,,,,,,,, -112900,0.14788277,0.01910468,,,,,,,,,,,,,,,,, -113000,0.15014416,0.018814268,,,,,,,,,,,,,,,,, -113099,,,0.9955037236213684,0.0141445789486169,0.7743666140906535,0.987072765827179,0.050178974866867,0.2924845483570632,43793.0,0.9861477017402648,0.0540336854755878,0.2743955276193887,43793.0,36512.089584350586,55517.98745632172,36512.089584350586,18996.58028960228,6.198683500289917,0.0 -113100,0.14610972,0.01871732,,,,,,,,,,,,,,,,, -113200,0.1358287,0.018424857,,,,,,,,,,,,,,,,, -113300,0.13265081,0.016815923,,,,,,,,,,,,,,,,, -113400,0.1352569,0.016297197,,,,,,,,,,,,,,,,, -113500,0.15183134,0.017233752,,,,,,,,,,,,,,,,, -113600,0.16049191,0.019686477,,,,,,,,,,,,,,,,, -113700,0.1366177,0.017924435,,,,,,,,,,,,,,,,, -113800,0.15077569,0.01919757,,,,,,,,,,,,,,,,, -113842,,,0.9955094456672668,0.0141254626214504,0.7793063321714918,0.987072765827179,0.0501789674162864,0.2925966404311597,43793.0,0.9861477017402648,0.0540336780250072,0.2743417720933081,43793.0,36752.29545927048,55865.72206258774,36752.29545927048,19104.04130196572,6.246399402618408,0.0 -113900,0.14912048,0.017785653,,,,,,,,,,,,,,,,, -114000,0.14450951,0.019713797,,,,,,,,,,,,,,,,, -114100,0.14396696,0.017522616,,,,,,,,,,,,,,,,, -114200,0.15452415,0.017540762,,,,,,,,,,,,,,,,, -114300,0.14811175,0.018496415,,,,,,,,,,,,,,,,, -114400,0.1366082,0.016698238,,,,,,,,,,,,,,,,, -114500,0.13071813,0.017849788,,,,,,,,,,,,,,,,, -114587,,,0.9954116940498352,0.0142963584512472,0.7696221359037616,0.987072765827179,0.050178974866867,0.2923634887193423,43793.0,0.9861477017402648,0.0540336854755878,0.2744716788874261,43793.0,36992.5061249733,56220.20129728317,36992.5061249733,19218.23925757408,6.296609163284302,0.0 -114600,0.16015485,0.019769033,,,,,,,,,,,,,,,,, -114700,0.17260203,0.019045254,,,,,,,,,,,,,,,,, -114800,0.13527587,0.017649176,,,,,,,,,,,,,,,,, -114900,0.15982428,0.01864596,,,,,,,,,,,,,,,,, -115000,0.1477473,0.01934205,,,,,,,,,,,,,,,,, -115100,0.12993088,0.017204078,,,,,,,,,,,,,,,,, -115200,0.14953254,0.017569503,,,,,,,,,,,,,,,,, -115300,0.14827918,0.017906634,,,,,,,,,,,,,,,,, -115340,,,0.9955100417137146,0.0141028864309191,0.7763613971947132,0.987072765827179,0.050178974866867,0.292477968978769,43793.0,0.9861477017402648,0.0540336780250072,0.2744870863371934,43793.0,37232.54695749283,56567.71599459648,37232.54695749283,19325.64443707466,6.345672607421875,0.0 -115400,0.13815325,0.017097134,,,,,,,,,,,,,,,,, -115500,0.17825358,0.020659667,,,,,,,,,,,,,,,,, -115600,0.14446439,0.017374683,,,,,,,,,,,,,,,,, -115700,0.13385807,0.018845825,,,,,,,,,,,,,,,,, -115800,0.14960042,0.018064082,,,,,,,,,,,,,,,,, -115900,0.14485435,0.017971933,,,,,,,,,,,,,,,,, -116000,0.121817745,0.015883012,,,,,,,,,,,,,,,,, -116085,,,0.995463728904724,0.0142246037721633,0.7704293389531462,0.987072765827179,0.0501789711415767,0.2923070868083733,43793.0,0.9861477017402648,0.0540336854755878,0.2744327779240913,43793.0,37472.57381772995,56915.96519422531,37472.57381772995,19433.797921419144,6.394840955734253,0.0 -116100,0.13141301,0.01725365,,,,,,,,,,,,,,,,, -116200,0.15021706,0.018285187,,,,,,,,,,,,,,,,, -116300,0.13934803,0.017331671,,,,,,,,,,,,,,,,, -116400,0.13479824,0.01703065,,,,,,,,,,,,,,,,, -116500,0.16253257,0.019115759,,,,,,,,,,,,,,,,, -116600,0.15285018,0.01729196,,,,,,,,,,,,,,,,, -116700,0.15394497,0.019779852,,,,,,,,,,,,,,,,, -116800,0.14940101,0.016930375,,,,,,,,,,,,,,,,, -116839,,,0.9954558610916138,0.0142057770863175,0.7722315712608714,0.987072765827179,0.0501789674162864,0.2924067980397405,43793.0,0.9861477017402648,0.0540336780250072,0.2744755774496301,43793.0,37712.61404252052,57264.01523518562,37712.61404252052,19541.73554468155,6.446910858154297,0.0 -116900,0.16526921,0.017846156,,,,,,,,,,,,,,,,, -117000,0.12996602,0.015498627,,,,,,,,,,,,,,,,, -117100,0.1480419,0.018305188,,,,,,,,,,,,,,,,, -117200,0.14986826,0.01827683,,,,,,,,,,,,,,,,, -117300,0.16629776,0.019599631,,,,,,,,,,,,,,,,, -117400,0.13426098,0.020140568,,,,,,,,,,,,,,,,, -117500,0.13457523,0.016840758,,,,,,,,,,,,,,,,, -117592,,,0.995553195476532,0.0139929689466953,0.7718386775875689,0.987072765827179,0.0501789785921573,0.2923722358786466,43793.0,0.9861477017402648,0.0540336780250072,0.2744644973891466,43793.0,37952.76475858688,57610.35025882721,37952.76475858688,19647.851148843765,6.495280981063843,0.0 -117600,0.16218334,0.018474124,,,,,,,,,,,,,,,,, -117700,0.14612925,0.019143801,,,,,,,,,,,,,,,,, -117800,0.12910433,0.019185923,,,,,,,,,,,,,,,,, -117900,0.16717264,0.020583497,,,,,,,,,,,,,,,,, -118000,0.13543549,0.018182633,,,,,,,,,,,,,,,,, -118100,0.1392397,0.018653568,,,,,,,,,,,,,,,,, -118200,0.14528564,0.01982086,,,,,,,,,,,,,,,,, -118300,0.14086577,0.013777951,,,,,,,,,,,,,,,,, -118339,,,0.9954729080200196,0.014159295707941,0.7737373209301305,0.987072765827179,0.050178974866867,0.2923521487118571,43793.0,0.9861477017402648,0.0540336780250072,0.2745096845449642,43793.0,38192.82896399498,57968.28904867172,38192.82896399498,19765.65750479698,6.543762683868408,0.0 -118400,0.15386434,0.020383658,,,,,,,,,,,,,,,,, -118500,0.11809155,0.016306989,,,,,,,,,,,,,,,,, -118600,0.14871784,0.019950705,,,,,,,,,,,,,,,,, -118700,0.14114685,0.017562373,,,,,,,,,,,,,,,,, -118800,0.1281083,0.014929369,,,,,,,,,,,,,,,,, -118900,0.14706278,0.017947106,,,,,,,,,,,,,,,,, -119000,0.15563072,0.019426636,,,,,,,,,,,,,,,,, -119088,,,0.9955180287361144,0.0141503792256116,0.7726563510216342,0.987072765827179,0.0501789711415767,0.2924237165053917,43793.0,0.9861477017402648,0.0540336854755878,0.2745536327412705,43793.0,38432.975707530975,58317.23108887672,38432.975707530975,19874.38309621811,6.593130588531494,0.0 -119100,0.16187663,0.01895899,,,,,,,,,,,,,,,,, -119200,0.15083691,0.019313283,,,,,,,,,,,,,,,,, -119300,0.13009283,0.01675663,,,,,,,,,,,,,,,,, -119400,0.13043733,0.01880396,,,,,,,,,,,,,,,,, -119500,0.15884475,0.017271074,,,,,,,,,,,,,,,,, -119600,0.16354726,0.019258002,,,,,,,,,,,,,,,,, -119700,0.1293185,0.016483946,,,,,,,,,,,,,,,,, -119800,0.13521095,0.017107422,,,,,,,,,,,,,,,,, -119839,,,0.995469093322754,0.0141744092106819,0.7726981784844295,0.987072765827179,0.0501789785921573,0.2923949008598069,43793.0,0.9861477017402648,0.0540336780250072,0.2744036483974371,43793.0,38673.135501623154,58669.247754096985,38673.135501623154,19986.17211055756,6.641467571258545,0.0 -119900,0.1476868,0.016616154,,,,,,,,,,,,,,,,, -120000,0.12999341,0.0172647,,,,,,,,,,,,,,,,, -120100,0.14024448,0.01748616,,,,,,,,,,,,,,,,, -120200,0.14638045,0.01831525,,,,,,,,,,,,,,,,, -120300,0.14595702,0.017769452,,,,,,,,,,,,,,,,, -120400,0.14771621,0.018343877,,,,,,,,,,,,,,,,, -120500,0.14390178,0.018202808,,,,,,,,,,,,,,,,, -120584,,,0.9954341650009156,0.0142730260267853,0.7724481400829802,0.987072765827179,0.050178974866867,0.2924784708646529,43793.0,0.9861477017402648,0.0540336780250072,0.2744314207490373,43793.0,38913.38436675072,59016.06122899056,38913.38436675072,20092.66683936119,6.689886331558228,0.0 -120600,0.13564484,0.016706316,,,,,,,,,,,,,,,,, -120700,0.13869898,0.01761441,,,,,,,,,,,,,,,,, -120800,0.1317209,0.0157029,,,,,,,,,,,,,,,,, -120900,0.13539281,0.017835712,,,,,,,,,,,,,,,,, -121000,0.13425903,0.017215518,,,,,,,,,,,,,,,,, -121100,0.15541472,0.017190356,,,,,,,,,,,,,,,,, -121200,0.13917467,0.015606354,,,,,,,,,,,,,,,,, -121300,0.1388968,0.01773714,,,,,,,,,,,,,,,,, -121324,,,0.9955268502235411,0.0140812043100595,0.7726060218973945,0.987072765827179,0.0501789674162864,0.2924184000733607,43793.0,0.9861477017402648,0.0540336780250072,0.2744400738644023,43793.0,39153.452647686005,59371.07761955261,39153.452647686005,20207.542983531952,6.741236925125122,0.0 -121400,0.1305922,0.017267678,,,,,,,,,,,,,,,,, -121500,0.16215372,0.019803373,,,,,,,,,,,,,,,,, -121600,0.14306904,0.01827157,,,,,,,,,,,,,,,,, -121700,0.13059705,0.016051171,,,,,,,,,,,,,,,,, -121800,0.14642797,0.018361226,,,,,,,,,,,,,,,,, -121900,0.12680763,0.016724432,,,,,,,,,,,,,,,,, -122000,0.13513193,0.016697748,,,,,,,,,,,,,,,,, -122076,,,0.9955320358276368,0.0140745975077152,0.7766342497186973,0.987072765827179,0.0501789674162864,0.2924523927774579,43793.0,0.9861477017402648,0.0540336780250072,0.274348726685141,43793.0,39393.46521615982,59723.72233343125,39393.46521615982,20320.10085272789,6.794872522354126,0.0 -122100,0.13327433,0.017434496,,,,,,,,,,,,,,,,, -122200,0.14226638,0.01834894,,,,,,,,,,,,,,,,, -122300,0.13516288,0.016813733,,,,,,,,,,,,,,,,, -122400,0.14695793,0.018793387,,,,,,,,,,,,,,,,, -122500,0.14586844,0.018937837,,,,,,,,,,,,,,,,, -122600,0.1251253,0.015666649,,,,,,,,,,,,,,,,, -122700,0.14129572,0.018540489,,,,,,,,,,,,,,,,, -122800,0.14088134,0.017704766,,,,,,,,,,,,,,,,, -122824,,,0.9954253435134888,0.014281541109085,0.7695000327559407,0.987072765827179,0.0501789785921573,0.2924960667603323,43793.0,0.9861477017402648,0.0540336780250072,0.2744166570766237,43793.0,39633.420177698135,60074.61845064163,39633.420177698135,20430.971091747284,6.845506191253662,0.0 -122900,0.15035836,0.018358788,,,,,,,,,,,,,,,,, -123000,0.1462159,0.018876696,,,,,,,,,,,,,,,,, -123100,0.13354714,0.018703954,,,,,,,,,,,,,,,,, -123200,0.13970488,0.016892549,,,,,,,,,,,,,,,,, -123300,0.14650206,0.018166956,,,,,,,,,,,,,,,,, -123400,0.13912614,0.01571114,,,,,,,,,,,,,,,,, -123500,0.13071753,0.016329223,,,,,,,,,,,,,,,,, -123573,,,0.9954928755760192,0.0141173610463738,0.7755888247643258,0.987072765827179,0.0501789711415767,0.2924262429932207,43793.0,0.9861477017402648,0.0540336780250072,0.2744540328933385,43793.0,39873.37377142906,60424.03245639801,39873.37377142906,20540.36001110077,6.896385669708252,0.0 -123600,0.1567752,0.018785682,,,,,,,,,,,,,,,,, -123700,0.14806838,0.017910449,,,,,,,,,,,,,,,,, -123800,0.14060695,0.017243382,,,,,,,,,,,,,,,,, -123900,0.13285588,0.017692354,,,,,,,,,,,,,,,,, -124000,0.13849092,0.015083955,,,,,,,,,,,,,,,,, -124100,0.14897086,0.018662974,,,,,,,,,,,,,,,,, -124200,0.14592178,0.018438445,,,,,,,,,,,,,,,,, -124300,0.12906154,0.018143883,,,,,,,,,,,,,,,,, -124305,,,0.9954410195350648,0.0142628876492381,0.7707968891314905,0.987072765827179,0.0501789785921573,0.2924575249004623,43793.0,0.9861477017402648,0.0540336780250072,0.2743404442294432,43793.0,40113.52621340752,60774.52795505524,40113.52621340752,20650.623522281647,6.954053163528442,0.0 -124400,0.129151,0.019348402,,,,,,,,,,,,,,,,, -124500,0.14757687,0.01931312,,,,,,,,,,,,,,,,, -124600,0.14569741,0.018315785,,,,,,,,,,,,,,,,, -124700,0.14013085,0.01958368,,,,,,,,,,,,,,,,, -124800,0.14898834,0.01858747,,,,,,,,,,,,,,,,, -124900,0.13551122,0.016584119,,,,,,,,,,,,,,,,, -125000,0.13388737,0.017456766,,,,,,,,,,,,,,,,, -125052,,,0.9955234527587892,0.0140299526974558,0.7771478198568447,0.987072765827179,0.050178974866867,0.2923724415569426,43793.0,0.9861477017402648,0.0540336780250072,0.2744823761778551,43793.0,40353.63005948067,61124.07531356812,40353.63005948067,20759.99427127838,7.006211042404175,0.0 -125100,0.13492374,0.018521594,,,,,,,,,,,,,,,,, -125200,0.15756153,0.017557966,,,,,,,,,,,,,,,,, -125300,0.14322035,0.01713299,,,,,,,,,,,,,,,,, -125400,0.1354129,0.017054075,,,,,,,,,,,,,,,,, -125500,0.16241239,0.02067897,,,,,,,,,,,,,,,,, -125600,0.14683354,0.016851032,,,,,,,,,,,,,,,,, -125700,0.14446579,0.017930774,,,,,,,,,,,,,,,,, -125798,,,0.9955294132232666,0.0140703143551945,0.7733585278918658,0.987072765827179,0.0501789785921573,0.2923784314654618,43793.0,0.9861477017402648,0.0540336780250072,0.2744143988411106,43793.0,40593.73137521744,61473.34885716438,40593.73137521744,20869.096217632294,7.056287527084351,0.0 -125800,0.13398448,0.016947081,,,,,,,,,,,,,,,,, -125900,0.13290615,0.017572714,,,,,,,,,,,,,,,,, -126000,0.12671655,0.01651889,,,,,,,,,,,,,,,,, -126100,0.13626842,0.01724281,,,,,,,,,,,,,,,,, -126200,0.14557381,0.017606422,,,,,,,,,,,,,,,,, -126300,0.122062124,0.016348489,,,,,,,,,,,,,,,,, -126400,0.13522395,0.019369751,,,,,,,,,,,,,,,,, -126500,0.14777865,0.018968744,,,,,,,,,,,,,,,,, -126555,,,0.9954558610916138,0.0142387766391038,0.7651018401588172,0.987072765827179,0.0501789711415767,0.2926135407799796,43793.0,0.9861477017402648,0.0540336780250072,0.2743924223147777,43793.0,40833.93912935257,61821.20895910263,40833.93912935257,20976.6785402298,7.106051445007324,0.0 -126600,0.15238722,0.018440193,,,,,,,,,,,,,,,,, -126700,0.12857698,0.015098455,,,,,,,,,,,,,,,,, -126800,0.14159924,0.019116314,,,,,,,,,,,,,,,,, -126900,0.15061648,0.017705001,,,,,,,,,,,,,,,,, -127000,0.1364397,0.017168008,,,,,,,,,,,,,,,,, -127100,0.13531563,0.017797474,,,,,,,,,,,,,,,,, -127200,0.16261712,0.01875873,,,,,,,,,,,,,,,,, -127300,0.14129858,0.017551811,,,,,,,,,,,,,,,,, -127310,,,0.9954926371574402,0.0141503317281603,0.7795920817475045,0.987072765827179,0.0501789711415767,0.2925470066934093,43793.0,0.9861477017402648,0.0540336780250072,0.2743393079338933,43793.0,41074.18945026398,62171.43013072014,41074.18945026398,21086.576821565628,7.158508539199829,0.0 -127400,0.13913126,0.017893791,,,,,,,,,,,,,,,,, -127500,0.13674821,0.016866708,,,,,,,,,,,,,,,,, -127600,0.13164954,0.016937165,,,,,,,,,,,,,,,,, -127700,0.15219934,0.020237062,,,,,,,,,,,,,,,,, -127800,0.14377512,0.018267771,,,,,,,,,,,,,,,,, -127900,0.14276469,0.017717738,,,,,,,,,,,,,,,,, -128000,0.13648677,0.019198861,,,,,,,,,,,,,,,,, -128049,,,0.9954724311828612,0.0141659248620271,0.7580778686618597,0.987072765827179,0.0501789711415767,0.2922960154879298,43793.0,0.9861477017402648,0.0540336780250072,0.2744622774438503,43793.0,41314.4131102562,62520.848984479904,41314.4131102562,21195.692762613297,7.215036869049072,0.0 -128100,0.14428681,0.017925266,,,,,,,,,,,,,,,,, -128200,0.1433548,0.017388644,,,,,,,,,,,,,,,,, -128300,0.1410761,0.01785601,,,,,,,,,,,,,,,,, -128400,0.15008298,0.018684285,,,,,,,,,,,,,,,,, -128500,0.14333732,0.017163951,,,,,,,,,,,,,,,,, -128600,0.13592243,0.015978828,,,,,,,,,,,,,,,,, -128700,0.15501787,0.019109042,,,,,,,,,,,,,,,,, -128800,,,0.9954509735107422,0.0143068879842758,0.772381038299001,0.987072765827179,0.0501789711415767,0.2925018996634979,43793.0,0.9861477017402648,0.0540336780250072,0.2744522066223132,43793.0,41554.57661104202,62871.640630960464,41554.57661104202,21306.24901342392,7.266671419143677,0.0 -128800,0.13094494,0.01697002,,,,,,,,,,,,,,,,, -128900,0.1527095,0.018576467,,,,,,,,,,,,,,,,, -129000,0.1567837,0.02110279,,,,,,,,,,,,,,,,, -129100,0.15219276,0.020388804,,,,,,,,,,,,,,,,, -129200,0.14337158,0.018345334,,,,,,,,,,,,,,,,, -129300,0.14183153,0.01884359,,,,,,,,,,,,,,,,, -129400,0.1375003,0.018193536,,,,,,,,,,,,,,,,, -129500,0.14078926,0.01928569,,,,,,,,,,,,,,,,, -129536,,,0.9955395460128784,0.0140335224568843,0.7769832884068464,0.987072765827179,0.050178974866867,0.2925394755638397,43793.0,0.9861477017402648,0.0540336854755878,0.2744407621982706,43793.0,41794.73320269585,63226.893221616745,41794.73320269585,21421.27315235138,7.316658735275267,0.0 -129600,0.13135518,0.01840482,,,,,,,,,,,,,,,,, -129700,0.1644923,0.019614065,,,,,,,,,,,,,,,,, -129800,0.15363577,0.019335175,,,,,,,,,,,,,,,,, -129900,0.14914964,0.01994319,,,,,,,,,,,,,,,,, -130000,0.14185432,0.018574899,,,,,,,,,,,,,,,,, -130100,0.1508664,0.019496925,,,,,,,,,,,,,,,,, -130200,0.13167718,0.019881617,,,,,,,,,,,,,,,,, -130272,,,0.9954925775527954,0.0141486516222357,0.7728318242745001,0.987072765827179,0.050178974866867,0.2924767408212498,43793.0,0.9861477017402648,0.0540336780250072,0.2744175331022145,43793.0,42034.90148067474,63576.49175739288,42034.90148067474,21530.62320971489,7.376136064529419,0.0 -130300,0.14925714,0.018767808,,,,,,,,,,,,,,,,, -130400,0.1539567,0.017934531,,,,,,,,,,,,,,,,, -130500,0.14859778,0.018932145,,,,,,,,,,,,,,,,, -130600,0.139432,0.01876184,,,,,,,,,,,,,,,,, -130700,0.16508655,0.019424802,,,,,,,,,,,,,,,,, -130800,0.15534875,0.019661475,,,,,,,,,,,,,,,,, -130900,0.14503571,0.018004432,,,,,,,,,,,,,,,,, -131000,0.1444511,0.018348522,,,,,,,,,,,,,,,,, -131001,,,0.9954299330711364,0.0142009994015097,0.7792628467588556,0.987072765827179,0.0501789711415767,0.2925354472598478,43793.0,0.9861477017402648,0.0540336780250072,0.2744591741016137,43793.0,42274.94419336319,63927.694029569626,42274.94419336319,21641.708025217056,7.427554130554199,0.0 -131100,0.14893362,0.017760813,,,,,,,,,,,,,,,,, -131200,0.14889991,0.01897219,,,,,,,,,,,,,,,,, -131300,0.13189773,0.017540023,,,,,,,,,,,,,,,,, -131400,0.14970146,0.01839949,,,,,,,,,,,,,,,,, -131500,0.12991692,0.017282926,,,,,,,,,,,,,,,,, -131600,0.14673446,0.018001849,,,,,,,,,,,,,,,,, -131700,0.13573052,0.019062938,,,,,,,,,,,,,,,,, -131735,,,0.995569348335266,0.0139863276854157,0.7674853021467718,0.987072765827179,0.0501789711415767,0.2923617745264467,43793.0,0.9861477017402648,0.0540336780250072,0.2744910385094342,43793.0,42515.0906047821,64279.381809711456,42515.0906047821,21753.172875642776,7.480916500091553,0.0 -131800,0.14945772,0.019580835,,,,,,,,,,,,,,,,, -131900,0.14056282,0.017549314,,,,,,,,,,,,,,,,, -132000,0.13227408,0.017461756,,,,,,,,,,,,,,,,, -132100,0.15093844,0.019637756,,,,,,,,,,,,,,,,, -132200,0.15186083,0.018838102,,,,,,,,,,,,,,,,, -132300,0.12448923,0.016128264,,,,,,,,,,,,,,,,, -132400,0.13425776,0.018250162,,,,,,,,,,,,,,,,, -132474,,,0.9954354166984558,0.0142763331532478,0.7732393560158943,0.987072765827179,0.0501789785921573,0.2925081247424754,43793.0,0.9861477017402648,0.0540336780250072,0.2744854342397442,43793.0,42755.26333451271,64630.84366893768,42755.26333451271,21864.389171123505,7.533536434173584,0.0 -132500,0.14845203,0.01794841,,,,,,,,,,,,,,,,, -132600,0.14450273,0.017600236,,,,,,,,,,,,,,,,, -132700,0.1392965,0.016452631,,,,,,,,,,,,,,,,, -132800,0.14149524,0.018443814,,,,,,,,,,,,,,,,, -132900,0.14012298,0.015830712,,,,,,,,,,,,,,,,, -133000,0.13220885,0.016888436,,,,,,,,,,,,,,,,, -133100,0.13748172,0.018491106,,,,,,,,,,,,,,,,, -133192,,,0.9955052733421326,0.0141240824013948,0.7706519939588254,0.987072765827179,0.0501789711415767,0.292544236251272,43793.0,0.9861477017402648,0.0540336780250072,0.2745147941177334,43793.0,42995.351095438,64981.21354317665,42995.351095438,21974.587879419327,7.593179702758789,0.0 -133200,0.15293229,0.02025319,,,,,,,,,,,,,,,,, -133300,0.14179978,0.016112475,,,,,,,,,,,,,,,,, -133400,0.15401913,0.01725519,,,,,,,,,,,,,,,,, -133500,0.13311778,0.017573446,,,,,,,,,,,,,,,,, -133600,0.13746427,0.017690858,,,,,,,,,,,,,,,,, -133700,0.14148,0.01994015,,,,,,,,,,,,,,,,, -133800,0.13575105,0.01579946,,,,,,,,,,,,,,,,, -133900,0.15034066,0.016516507,,,,,,,,,,,,,,,,, -133937,,,0.9955182671546936,0.0140765700489282,0.7777763457581376,0.987072765827179,0.050178974866867,0.292485678695066,43793.0,0.9861477017402648,0.0540336780250072,0.2744378650368807,43793.0,43235.31395435333,65327.087550640106,43235.31395435333,22080.42444038391,7.647066116333008,0.0 -134000,0.13276222,0.01745574,,,,,,,,,,,,,,,,, -134100,0.13792293,0.01670523,,,,,,,,,,,,,,,,, -134200,0.1675681,0.019207519,,,,,,,,,,,,,,,,, -134300,0.14348382,0.016877085,,,,,,,,,,,,,,,,, -134400,0.15113686,0.017341021,,,,,,,,,,,,,,,,, -134500,0.15055743,0.017021324,,,,,,,,,,,,,,,,, -134600,0.1461796,0.017092125,,,,,,,,,,,,,,,,, -134684,,,0.9954232573509216,0.0143252816051244,0.7749691994204116,0.987072765827179,0.050178974866867,0.292514974609005,43793.0,0.9861477017402648,0.0540336854755878,0.2744699717030999,43793.0,43475.29788017273,65674.99710512161,43475.29788017273,22188.27728843689,7.699393749237059,0.0 -134700,0.14209382,0.016872294,,,,,,,,,,,,,,,,, -134800,0.14956935,0.020023521,,,,,,,,,,,,,,,,, -134900,0.13981454,0.017701639,,,,,,,,,,,,,,,,, -135000,0.14235923,0.017332353,,,,,,,,,,,,,,,,, -135100,0.1369133,0.0187672,,,,,,,,,,,,,,,,, -135200,0.13393971,0.018474095,,,,,,,,,,,,,,,,, -135300,0.13751142,0.01737098,,,,,,,,,,,,,,,,, -135400,0.1343622,0.01696878,,,,,,,,,,,,,,,,, -135432,,,0.9955505132675172,0.0139947505667805,0.771572487802147,0.987072765827179,0.0501789674162864,0.2924638056018725,43793.0,0.9861477017402648,0.0540336780250072,0.2743817458990037,43793.0,43715.30280971527,66018.88409805298,43715.30280971527,22292.08743953705,7.750726938247681,0.0 -135500,0.13051234,0.01746786,,,,,,,,,,,,,,,,, -135600,0.1375457,0.016076826,,,,,,,,,,,,,,,,, -135700,0.15139563,0.016804852,,,,,,,,,,,,,,,,, -135800,0.15267198,0.01990327,,,,,,,,,,,,,,,,, -135900,0.16021219,0.019523056,,,,,,,,,,,,,,,,, -136000,0.12679277,0.01595209,,,,,,,,,,,,,,,,, -136100,0.13159923,0.016191253,,,,,,,,,,,,,,,,, -136175,,,0.9954281449317932,0.0143410833552479,0.7698603206108076,0.987072765827179,0.0501789674162864,0.2923284004565294,43793.0,0.9861477017402648,0.0540336780250072,0.2743165139314618,43793.0,43955.23675751686,66369.940782547,43955.23675751686,22403.13285589218,7.805722713470459,0.0 -136200,0.13597395,0.015825054,,,,,,,,,,,,,,,,, -136300,0.15552889,0.018245483,,,,,,,,,,,,,,,,, -136400,0.13819014,0.018062392,,,,,,,,,,,,,,,,, -136500,0.15741311,0.019158056,,,,,,,,,,,,,,,,, -136600,0.15281352,0.018241484,,,,,,,,,,,,,,,,, -136700,0.1559273,0.018815046,,,,,,,,,,,,,,,,, -136800,0.13665117,0.019141106,,,,,,,,,,,,,,,,, -136900,0.1472493,0.01754702,,,,,,,,,,,,,,,,, -136917,,,0.9954986572265624,0.0141018470749259,0.7790549032148038,0.987072765827179,0.0501789785921573,0.2925348605695119,43793.0,0.9861477017402648,0.0540336780250072,0.2744015337243864,43793.0,44195.45515322685,66715.35626888275,44195.45515322685,22508.251557588577,7.863274812698364,0.0 -137000,0.14662674,0.016680012,,,,,,,,,,,,,,,,, -137100,0.1345799,0.01639764,,,,,,,,,,,,,,,,, -137200,0.13794972,0.019060286,,,,,,,,,,,,,,,,, -137300,0.16483594,0.019355033,,,,,,,,,,,,,,,,, -137400,0.13774669,0.018698482,,,,,,,,,,,,,,,,, -137500,0.13074872,0.017288338,,,,,,,,,,,,,,,,, -137600,0.14151415,0.019434445,,,,,,,,,,,,,,,,, -137657,,,0.9955148100852966,0.0141217578202486,0.7732282225552818,0.987072765827179,0.050178974866867,0.2923614648157295,43793.0,0.9861477017402648,0.0540336780250072,0.2743812715418015,43793.0,44435.672504901886,67063.59295868874,44435.672504901886,22616.19847869873,7.916009187698364,0.0 -137700,0.13497534,0.019435186,,,,,,,,,,,,,,,,, -137800,0.13741271,0.016974073,,,,,,,,,,,,,,,,, -137900,0.16853535,0.018719677,,,,,,,,,,,,,,,,, -138000,0.14422858,0.016541759,,,,,,,,,,,,,,,,, -138100,0.14445953,0.018661162,,,,,,,,,,,,,,,,, -138200,0.13653727,0.016885336,,,,,,,,,,,,,,,,, -138300,0.1481238,0.018675663,,,,,,,,,,,,,,,,, -138400,0.1674593,0.019025715,,,,,,,,,,,,,,,,, -138404,,,0.9954774975776672,0.0141524942591786,0.7739244373319915,0.987072765827179,0.0501789711415767,0.2923292070342875,43793.0,0.9861477017402648,0.0540336780250072,0.2746092836619395,43793.0,44675.83380961418,67412.49493050575,44675.83380961418,22724.865947008133,7.969079494476318,0.0 -138500,0.13610205,0.018434376,,,,,,,,,,,,,,,,, -138600,0.14499696,0.01807395,,,,,,,,,,,,,,,,, -138700,0.12629497,0.015909955,,,,,,,,,,,,,,,,, -138800,0.12947808,0.017079,,,,,,,,,,,,,,,,, -138900,0.17032696,0.018156148,,,,,,,,,,,,,,,,, -139000,0.14741087,0.016930528,,,,,,,,,,,,,,,,, -139100,0.14404678,0.01863153,,,,,,,,,,,,,,,,, -139153,,,0.9954907298088074,0.0141388094052672,0.7785713528209731,0.987072765827179,0.0501789674162864,0.2926596141909245,43793.0,0.9861477017402648,0.0540336854755878,0.2744020805623824,43793.0,44915.87771701813,67760.05751276016,44915.87771701813,22832.310029745106,8.023373365402222,0.0 -139200,0.15480721,0.019187769,,,,,,,,,,,,,,,,, -139300,0.13911943,0.018075237,,,,,,,,,,,,,,,,, -139400,0.1415992,0.020140622,,,,,,,,,,,,,,,,, -139500,0.15043549,0.017481884,,,,,,,,,,,,,,,,, -139600,0.1317317,0.017924309,,,,,,,,,,,,,,,,, -139700,0.14602745,0.017407686,,,,,,,,,,,,,,,,, -139800,0.1398054,0.01780977,,,,,,,,,,,,,,,,, -139900,,,0.995424747467041,0.0142870005220174,0.761383769673432,0.987072765827179,0.0501789785921573,0.292531407052199,43793.0,0.9861477017402648,0.0540336854755878,0.2744188849967219,43793.0,45156.082431316376,68106.95738077164,45156.082431316376,22938.924989938736,8.08329463005066,0.0 -139900,0.16506635,0.01790273,,,,,,,,,,,,,,,,, -140000,0.16487436,0.01908436,,,,,,,,,,,,,,,,, -140100,0.15000598,0.019612132,,,,,,,,,,,,,,,,, -140200,0.1459256,0.015334929,,,,,,,,,,,,,,,,, -140300,0.14933878,0.019034639,,,,,,,,,,,,,,,,, -140400,0.14724289,0.018182138,,,,,,,,,,,,,,,,, -140500,0.14939865,0.017188188,,,,,,,,,,,,,,,,, -140600,0.15448493,0.016224932,,,,,,,,,,,,,,,,, -140642,,,0.9955102205276488,0.0141165740787982,0.7750042961359505,0.987072765827179,0.0501789674162864,0.2924117972188245,43793.0,0.9861477017402648,0.0540336854755878,0.2745654491453426,43793.0,45396.19762325287,68457.6458311081,45396.19762325287,23049.42386555672,8.13773775100708,0.0 -140700,0.13658729,0.0174812,,,,,,,,,,,,,,,,, -140800,0.14074488,0.016831705,,,,,,,,,,,,,,,,, -140900,0.12576118,0.016719058,,,,,,,,,,,,,,,,, -141000,0.14928606,0.019559165,,,,,,,,,,,,,,,,, -141100,0.16058624,0.019143224,,,,,,,,,,,,,,,,, -141200,0.15076683,0.017856058,,,,,,,,,,,,,,,,, -141300,0.15679102,0.01911471,,,,,,,,,,,,,,,,, -141381,,,0.9955143928527832,0.0140935350209474,0.777542621227472,0.987072765827179,0.0501789674162864,0.2923862031675611,43793.0,0.9861477017402648,0.0540336780250072,0.2744659269564402,43793.0,45636.25564050674,68806.30915427208,45636.25564050674,23157.95317864418,8.191657781600952,0.0 -141400,0.14077848,0.017902724,,,,,,,,,,,,,,,,, -141500,0.13953859,0.017492963,,,,,,,,,,,,,,,,, -141600,0.14795478,0.018900855,,,,,,,,,,,,,,,,, -141700,0.16470103,0.018465172,,,,,,,,,,,,,,,,, -141800,0.13593675,0.0171329,,,,,,,,,,,,,,,,, -141900,0.1334979,0.01785055,,,,,,,,,,,,,,,,, -142000,0.13887198,0.018458653,,,,,,,,,,,,,,,,, -142100,0.15085691,0.0185017,,,,,,,,,,,,,,,,, -142126,,,0.9954879283905028,0.0141484467312693,0.7753354730267371,0.987072765827179,0.0501789711415767,0.2923758806923691,43793.0,0.9861477017402648,0.0540336780250072,0.2745318743089971,43793.0,45876.27267360687,69154.82185125351,45876.27267360687,23266.372206687927,8.247857809066772,0.0 -142200,0.14447217,0.017913751,,,,,,,,,,,,,,,,, -142300,0.1462425,0.01732273,,,,,,,,,,,,,,,,, -142400,0.14067721,0.01730667,,,,,,,,,,,,,,,,, -142500,0.16706155,0.020413436,,,,,,,,,,,,,,,,, -142600,0.12919201,0.016118763,,,,,,,,,,,,,,,,, -142700,0.1312646,0.016725885,,,,,,,,,,,,,,,,, -142800,0.17325115,0.01946842,,,,,,,,,,,,,,,,, -142855,,,0.9955083727836608,0.0140606602653861,0.783613736914508,0.987072765827179,0.0501789674162864,0.2923354873325368,43793.0,0.9861477017402648,0.0540336780250072,0.2744458325069142,43793.0,46116.45699834824,69502.61528301239,46116.45699834824,23373.897718191147,8.308566331863403,0.0 -142900,0.13790983,0.017311381,,,,,,,,,,,,,,,,, -143000,0.15080777,0.01798394,,,,,,,,,,,,,,,,, -143100,0.16598701,0.019479929,,,,,,,,,,,,,,,,, -143200,0.14257517,0.016129984,,,,,,,,,,,,,,,,, -143300,0.16273917,0.018992566,,,,,,,,,,,,,,,,, -143400,0.14084086,0.017788852,,,,,,,,,,,,,,,,, -143500,0.13788103,0.017427646,,,,,,,,,,,,,,,,, -143584,,,0.9954497814178468,0.0142854899168014,0.7609779749806512,0.987072765827179,0.0501789711415767,0.2925092009419555,43793.0,0.9861477017402648,0.0540336780250072,0.2744276065726812,43793.0,46356.47380423546,69855.18458938599,46356.47380423546,23486.36240339279,8.371394395828247,0.0 -143600,0.14990908,0.016888455,,,,,,,,,,,,,,,,, -143700,0.12625717,0.016298112,,,,,,,,,,,,,,,,, -143800,0.14350471,0.016474815,,,,,,,,,,,,,,,,, -143900,0.14126977,0.01581933,,,,,,,,,,,,,,,,, -144000,0.1439014,0.019577892,,,,,,,,,,,,,,,,, -144100,0.13398702,0.01656342,,,,,,,,,,,,,,,,, -144200,0.14903215,0.019120814,,,,,,,,,,,,,,,,, -144300,0.13542871,0.017542591,,,,,,,,,,,,,,,,, -144327,,,0.9954796433448792,0.0142038194462656,0.7661755426876338,0.987072765827179,0.0501789674162864,0.2926325091584312,43793.0,0.9861477017402648,0.0540336854755878,0.2744851973090345,43793.0,46596.55307674408,70202.66167020798,46596.55307674408,23593.684818267822,8.426125526428223,0.0 -144400,0.15595873,0.017731553,,,,,,,,,,,,,,,,, -144500,0.14055961,0.018673465,,,,,,,,,,,,,,,,, -144600,0.13804883,0.01954748,,,,,,,,,,,,,,,,, -144700,0.14039648,0.019408353,,,,,,,,,,,,,,,,, -144800,0.15403682,0.01931453,,,,,,,,,,,,,,,,, -144900,0.12693065,0.016361617,,,,,,,,,,,,,,,,, -145000,0.14532591,0.01849503,,,,,,,,,,,,,,,,, -145070,,,0.9954817295074464,0.0141157330945134,0.7795299761374892,0.987072765827179,0.050178974866867,0.2923431012403145,43793.0,0.9861477017402648,0.0540336780250072,0.274352949729243,43793.0,46836.53562474251,70548.53416538239,46836.53562474251,23699.498154878616,8.482213497161865,0.0 -145100,0.13770567,0.017832477,,,,,,,,,,,,,,,,, -145200,0.13068016,0.016840445,,,,,,,,,,,,,,,,, -145300,0.1454404,0.017556481,,,,,,,,,,,,,,,,, -145400,0.15104574,0.020425035,,,,,,,,,,,,,,,,, -145500,0.1376386,0.018135594,,,,,,,,,,,,,,,,, -145600,0.13780367,0.019323992,,,,,,,,,,,,,,,,, -145700,0.14225231,0.018602038,,,,,,,,,,,,,,,,, -145800,0.13211957,0.018182836,,,,,,,,,,,,,,,,, -145811,,,0.995508313179016,0.0141398878768086,0.7773677847036007,0.987072765827179,0.0501789674162864,0.2924520914694519,43793.0,0.9861477017402648,0.0540336780250072,0.2744340913738719,43793.0,47076.47433376312,70895.5379679203,47076.47433376312,23806.48648405075,8.537409782409668,0.0 -145900,0.13126628,0.017377313,,,,,,,,,,,,,,,,, -146000,0.13870254,0.01933213,,,,,,,,,,,,,,,,, -146100,0.14360751,0.017724145,,,,,,,,,,,,,,,,, -146200,0.15664828,0.01740961,,,,,,,,,,,,,,,,, -146300,0.14650969,0.01783688,,,,,,,,,,,,,,,,, -146400,0.1420598,0.016898269,,,,,,,,,,,,,,,,, -146500,0.1463751,0.018225452,,,,,,,,,,,,,,,,, -146559,,,0.9955046772956848,0.0141176972538232,0.7730506080338739,0.987072765827179,0.0501789674162864,0.2924591081368707,43793.0,0.9861477017402648,0.0540336780250072,0.2745671273673323,43793.0,47316.4525346756,71243.1797504425,47316.4525346756,23914.076204538345,8.591312885284424,0.0 -146600,0.14732088,0.019344043,,,,,,,,,,,,,,,,, -146700,0.17058966,0.020161083,,,,,,,,,,,,,,,,, -146800,0.13305771,0.017095119,,,,,,,,,,,,,,,,, -146900,0.14559929,0.017786473,,,,,,,,,,,,,,,,, -147000,0.12534514,0.015793445,,,,,,,,,,,,,,,,, -147100,0.14421917,0.017318595,,,,,,,,,,,,,,,,, -147200,0.13866709,0.019422986,,,,,,,,,,,,,,,,, -147295,,,0.9955081939697266,0.0140927387401461,0.7726374976251774,0.987072765827179,0.0501789711415767,0.2924923169390194,43793.0,0.9861477017402648,0.0540336780250072,0.2743590851789299,43793.0,47556.43048453331,71588.04568099976,47556.43048453331,24018.88875675201,8.645967245101929,0.0 -147300,0.12744293,0.017039713,,,,,,,,,,,,,,,,, -147400,0.12749681,0.016510122,,,,,,,,,,,,,,,,, -147500,0.12684458,0.017773598,,,,,,,,,,,,,,,,, -147600,0.14778475,0.016320206,,,,,,,,,,,,,,,,, -147700,0.14955932,0.018883105,,,,,,,,,,,,,,,,, -147800,0.1453272,0.019977074,,,,,,,,,,,,,,,,, -147900,0.14022641,0.019430786,,,,,,,,,,,,,,,,, -148000,0.13608927,0.01640372,,,,,,,,,,,,,,,,, -148044,,,0.995439887046814,0.0142427571117877,0.7694661257596855,0.987072765827179,0.050178974866867,0.292422802303003,43793.0,0.9861477017402648,0.0540336780250072,0.2744173503944239,43793.0,47796.37060403824,71933.60494160652,47796.37060403824,24124.432327747345,8.701493501663208,0.0 -148100,0.14549428,0.016214773,,,,,,,,,,,,,,,,, -148200,0.15424047,0.018347127,,,,,,,,,,,,,,,,, -148300,0.15612403,0.020530092,,,,,,,,,,,,,,,,, -148400,0.13831231,0.019101454,,,,,,,,,,,,,,,,, -148500,0.15073763,0.018254941,,,,,,,,,,,,,,,,, -148600,0.14593454,0.019556792,,,,,,,,,,,,,,,,, -148700,0.12720492,0.017717427,,,,,,,,,,,,,,,,, -148793,,,0.9954816102981568,0.0141702918335795,0.7750624725733591,0.987072765827179,0.0501789674162864,0.2924158822165662,43793.0,0.9861477017402648,0.0540336780250072,0.2743672697535215,43793.0,48036.4431014061,72279.34216976166,48036.4431014061,24230.023329257965,8.754802942276001,0.0 -148800,0.12878947,0.016489761,,,,,,,,,,,,,,,,, -148900,0.15114403,0.018692538,,,,,,,,,,,,,,,,, -149000,0.12977563,0.016645908,,,,,,,,,,,,,,,,, -149100,0.1398961,0.01597106,,,,,,,,,,,,,,,,, -149200,0.15796517,0.01853598,,,,,,,,,,,,,,,,, -149300,0.14029548,0.018494451,,,,,,,,,,,,,,,,, -149400,0.13937229,0.019214494,,,,,,,,,,,,,,,,, -149500,0.1432459,0.018980512,,,,,,,,,,,,,,,,, -149541,,,0.9955478310585022,0.0140143623575568,0.7779941175755147,0.987072765827179,0.050178974866867,0.2925357045022977,43793.0,0.9861477017402648,0.0540336780250072,0.2745146309739536,43793.0,48276.60923433304,72629.41816806793,48276.60923433304,24339.857694149017,8.81046986579895,0.0 -149600,0.1392305,0.016097587,,,,,,,,,,,,,,,,, -149700,0.13193446,0.01700657,,,,,,,,,,,,,,,,, -149800,0.14742154,0.018268611,,,,,,,,,,,,,,,,, -149900,0.13097051,0.017057857,,,,,,,,,,,,,,,,, -150000,0.15656252,0.015991498,,,,,,,,,,,,,,,,, -150100,0.12629922,0.017074151,,,,,,,,,,,,,,,,, -150200,0.14064659,0.018630749,,,,,,,,,,,,,,,,, -150297,,,0.9954404234886168,0.0142840798944234,0.7704192706591224,0.987072765827179,0.0501789674162864,0.2923980627222845,43793.0,0.9861477017402648,0.0540336780250072,0.2744235536170402,43793.0,48516.63416361809,72973.73292207718,48516.63416361809,24444.07306456566,8.864438533782959,0.0 -150300,0.14944912,0.017098868,,,,,,,,,,,,,,,,, -150400,0.14030594,0.01923281,,,,,,,,,,,,,,,,, -150500,0.14221767,0.01868914,,,,,,,,,,,,,,,,, -150600,0.13421798,0.016342571,,,,,,,,,,,,,,,,, -150700,0.13724925,0.01843592,,,,,,,,,,,,,,,,, -150800,0.13990948,0.018333992,,,,,,,,,,,,,,,,, -150900,0.13752158,0.018002592,,,,,,,,,,,,,,,,, -151000,0.15216236,0.01942762,,,,,,,,,,,,,,,,, -151056,,,0.9955169558525084,0.0140691706910729,0.7767232637434311,0.987072765827179,0.0501789674162864,0.2923393462303533,43793.0,0.9861477017402648,0.0540336780250072,0.2744903066878821,43793.0,48756.61280536652,73321.21061086655,48756.61280536652,24551.4905269146,8.925953388214111,0.0 -151100,0.13528608,0.015680386,,,,,,,,,,,,,,,,, -151200,0.14790452,0.016443817,,,,,,,,,,,,,,,,, -151300,0.15704858,0.019210795,,,,,,,,,,,,,,,,, -151400,0.14714053,0.017157795,,,,,,,,,,,,,,,,, -151500,0.1341467,0.017533828,,,,,,,,,,,,,,,,, -151600,0.1539319,0.01905136,,,,,,,,,,,,,,,,, -151700,0.14001052,0.017310774,,,,,,,,,,,,,,,,, -151800,0.15123913,0.018915066,,,,,,,,,,,,,,,,, -151802,,,0.9954544305801392,0.0141996936872601,0.7650808939684499,0.987072765827179,0.0501789711415767,0.2923891979268853,43793.0,0.9861477017402648,0.0540336780250072,0.2743454136781733,43793.0,48996.37939405441,73669.68244862556,48996.37939405441,24659.799003601074,9.30262327194214,0.0 -151900,0.13146485,0.015692133,,,,,,,,,,,,,,,,, -152000,0.13238235,0.014761347,,,,,,,,,,,,,,,,, -152100,0.1376267,0.01732048,,,,,,,,,,,,,,,,, -152200,0.16019526,0.020626066,,,,,,,,,,,,,,,,, -152300,0.14289804,0.01774872,,,,,,,,,,,,,,,,, -152400,0.15191932,0.016954284,,,,,,,,,,,,,,,,, -152500,0.15466236,0.019710701,,,,,,,,,,,,,,,,, -152546,,,0.9955030083656312,0.0141865722835063,0.7775761097654365,0.987072765827179,0.050178974866867,0.2923806819872739,43793.0,0.9861477017402648,0.0540336780250072,0.274417204869546,43793.0,49236.43760156632,74016.67475938797,49236.43760156632,24766.657678842545,9.358002424240112,0.0 -152600,0.12608604,0.016756563,,,,,,,,,,,,,,,,, -152700,0.15532623,0.018442418,,,,,,,,,,,,,,,,, -152800,0.14273848,0.017219003,,,,,,,,,,,,,,,,, -152900,0.14471354,0.018855937,,,,,,,,,,,,,,,,, -153000,0.14987658,0.017703936,,,,,,,,,,,,,,,,, -153100,0.13056095,0.016333176,,,,,,,,,,,,,,,,, -153200,0.13868167,0.019151475,,,,,,,,,,,,,,,,, -153288,,,0.9954869747161864,0.014114367775619,0.775209529543543,0.987072765827179,0.050178974866867,0.2923222150566751,43793.0,0.9861477017402648,0.0540336780250072,0.2743833817906149,43793.0,49476.56862664223,74357.86896109581,49476.56862664223,24867.64109492302,9.417367696762083,0.0 -153300,0.1306509,0.016906379,,,,,,,,,,,,,,,,, -153400,0.14706081,0.016950095,,,,,,,,,,,,,,,,, -153500,0.14214267,0.016893016,,,,,,,,,,,,,,,,, -153600,0.1388044,0.017283961,,,,,,,,,,,,,,,,, -153700,0.14824717,0.019676076,,,,,,,,,,,,,,,,, -153800,0.17214844,0.01852766,,,,,,,,,,,,,,,,, -153900,0.13887776,0.018651063,,,,,,,,,,,,,,,,, -154000,0.13809374,0.015978701,,,,,,,,,,,,,,,,, -154025,,,0.9954938888549804,0.0141519317403435,0.7746333298750403,0.987072765827179,0.0501789674162864,0.2925510130236306,43793.0,0.9861477017402648,0.0540336780250072,0.2744389354879151,43793.0,49716.54668498039,74704.63004040718,49716.54668498039,24974.34710907936,9.47364616394043,0.0 -154100,0.15410711,0.018298209,,,,,,,,,,,,,,,,, -154200,0.15219025,0.018698141,,,,,,,,,,,,,,,,, -154300,0.17064033,0.018322652,,,,,,,,,,,,,,,,, -154400,0.14987509,0.019041797,,,,,,,,,,,,,,,,, -154500,0.15836707,0.019062689,,,,,,,,,,,,,,,,, -154600,0.15394321,0.019466361,,,,,,,,,,,,,,,,, -154700,0.15765356,0.01793006,,,,,,,,,,,,,,,,, -154776,,,0.995477020740509,0.0141919394955039,0.772443339172785,0.987072765827179,0.0501789674162864,0.292435496395953,43793.0,0.9861477017402648,0.0540336780250072,0.2744790929915433,43793.0,49956.64419960976,75048.92160797119,49956.64419960976,25078.46404027939,9.529872179031372,0.0 -154800,0.1446615,0.019848015,,,,,,,,,,,,,,,,, -154900,0.12738839,0.015091581,,,,,,,,,,,,,,,,, -155000,0.17226848,0.019008461,,,,,,,,,,,,,,,,, -155100,0.14686677,0.021547232,,,,,,,,,,,,,,,,, -155200,0.13896686,0.017820483,,,,,,,,,,,,,,,,, -155300,0.16136412,0.019285373,,,,,,,,,,,,,,,,, -155400,0.14620996,0.018315068,,,,,,,,,,,,,,,,, -155500,0.14259769,0.016698983,,,,,,,,,,,,,,,,, -155522,,,0.9955002665519714,0.014126512221992,0.76504731206366,0.987072765827179,0.050178974866867,0.2924108013284655,43793.0,0.9861477017402648,0.0540336780250072,0.274395693634054,43793.0,50196.8704969883,75398.39845585823,50196.8704969883,25187.63874030113,9.584931373596191,0.0 -155600,0.15661456,0.019241253,,,,,,,,,,,,,,,,, -155700,0.15235399,0.017997196,,,,,,,,,,,,,,,,, -155800,0.14544217,0.018489733,,,,,,,,,,,,,,,,, -155900,0.13615747,0.01701069,,,,,,,,,,,,,,,,, -156000,0.12684761,0.016966749,,,,,,,,,,,,,,,,, -156100,0.13177426,0.016588833,,,,,,,,,,,,,,,,, -156200,0.13786936,0.017800832,,,,,,,,,,,,,,,,, -156274,,,0.9954464435577391,0.0142757603898644,0.7749395935942144,0.987072765827179,0.0501789674162864,0.2924427218899189,43793.0,0.9861477017402648,0.0540336780250072,0.2744178720782878,43793.0,50436.93785953522,75747.2478158474,50436.93785953522,25296.34480142593,9.640064239501951,0.0 -156300,0.13601209,0.01830248,,,,,,,,,,,,,,,,, -156400,0.15048665,0.017980047,,,,,,,,,,,,,,,,, -156500,0.1459599,0.018393913,,,,,,,,,,,,,,,,, -156600,0.13306512,0.01752437,,,,,,,,,,,,,,,,, -156700,0.13469748,0.016215818,,,,,,,,,,,,,,,,, -156800,0.14822835,0.018589905,,,,,,,,,,,,,,,,, -156900,0.15421322,0.017912576,,,,,,,,,,,,,,,,, -157000,0.14383692,0.01771478,,,,,,,,,,,,,,,,, -157024,,,0.9955170154571532,0.0140560949221253,0.773829621089907,0.987072765827179,0.050178974866867,0.2922885314680591,43793.0,0.9861477017402648,0.0540336780250072,0.2743299048412865,43793.0,50677.16159534454,76097.5432009697,50677.16159534454,25406.3370552063,9.698690176010132,0.0 -157100,0.14251786,0.017846566,,,,,,,,,,,,,,,,, -157200,0.13958621,0.017271884,,,,,,,,,,,,,,,,, -157300,0.13596435,0.016173907,,,,,,,,,,,,,,,,, -157400,0.14172973,0.020161022,,,,,,,,,,,,,,,,, -157500,0.14492454,0.019303622,,,,,,,,,,,,,,,,, -157600,0.14690523,0.019412154,,,,,,,,,,,,,,,,, -157700,0.16781174,0.01886439,,,,,,,,,,,,,,,,, -157769,,,0.9955052137374878,0.0141193345189094,0.7832026154449414,0.987072765827179,0.050178974866867,0.2924538217306947,43793.0,0.9861477017402648,0.0540336854755878,0.2744481542418377,43793.0,50917.29948115349,76442.57691502571,50917.29948115349,25511.156789064407,9.754071712493896,0.0 -157800,0.13348024,0.016815064,,,,,,,,,,,,,,,,, -157900,0.14576577,0.018338736,,,,,,,,,,,,,,,,, -158000,0.14733927,0.017580513,,,,,,,,,,,,,,,,, -158100,0.15971397,0.022279598,,,,,,,,,,,,,,,,, -158200,0.14449377,0.018523315,,,,,,,,,,,,,,,,, -158300,0.14213657,0.019359382,,,,,,,,,,,,,,,,, -158400,0.13674325,0.017953645,,,,,,,,,,,,,,,,, -158500,0.14471748,0.019576946,,,,,,,,,,,,,,,,, -158521,,,0.995479941368103,0.0141672519966959,0.7708454919205815,0.987072765827179,0.0501789674162864,0.2924374967382563,43793.0,0.9861477017402648,0.0540336780250072,0.2744134394539258,43793.0,51157.23203110695,76789.78963327408,51157.23203110695,25618.3601474762,9.810832738876345,0.0 -158600,0.14256528,0.017048405,,,,,,,,,,,,,,,,, -158700,0.14128956,0.01700986,,,,,,,,,,,,,,,,, -158800,0.13470802,0.017522454,,,,,,,,,,,,,,,,, -158900,0.14867546,0.01737101,,,,,,,,,,,,,,,,, -159000,0.13845815,0.01889269,,,,,,,,,,,,,,,,, -159100,0.13834885,0.017649427,,,,,,,,,,,,,,,,, -159200,0.12809062,0.017005544,,,,,,,,,,,,,,,,, -159271,,,0.9954665899276732,0.0142047116532921,0.7754148010822717,0.987072765827179,0.0501789711415767,0.2925999843117496,43793.0,0.9861477017402648,0.0540336780250072,0.2745207814921742,43793.0,51397.40039587021,77136.78889465332,51397.40039587021,25725.11342215538,9.86826467514038,0.0 -159300,0.12948237,0.015831301,,,,,,,,,,,,,,,,, -159400,0.14169602,0.019589638,,,,,,,,,,,,,,,,, -159500,0.12295759,0.015540566,,,,,,,,,,,,,,,,, -159600,0.15249713,0.019750824,,,,,,,,,,,,,,,,, -159700,0.13887435,0.018318877,,,,,,,,,,,,,,,,, -159800,0.14249887,0.018170975,,,,,,,,,,,,,,,,, -159900,0.1390212,0.017973306,,,,,,,,,,,,,,,,, -160000,0.14366825,0.018137923,,,,,,,,,,,,,,,,, -160026,,,0.9954699277877808,0.0141995521262288,0.7704376398776334,0.987072765827179,0.050178974866867,0.2923189401479742,43793.0,0.9861477017402648,0.0540336780250072,0.2745612762808954,43793.0,51637.54203295708,77481.7868654728,51637.54203295708,25829.893330335617,9.92480731010437,0.0 -160100,0.14404672,0.01921911,,,,,,,,,,,,,,,,, -160200,0.1227415,0.016597576,,,,,,,,,,,,,,,,, -160300,0.13119082,0.01796293,,,,,,,,,,,,,,,,, -160400,0.137171,0.019702451,,,,,,,,,,,,,,,,, -160500,0.14895546,0.018110707,,,,,,,,,,,,,,,,, -160600,0.1489426,0.01755267,,,,,,,,,,,,,,,,, -160700,0.14358374,0.01802274,,,,,,,,,,,,,,,,, -160775,,,0.9954940676689148,0.0141418268904089,0.7768644225591141,0.987072765827179,0.0501789674162864,0.2923581265098338,43793.0,0.9861477017402648,0.0540336780250072,0.2744706562424572,43793.0,51877.696434021,77823.0005209446,51877.696434021,25930.87596058845,9.9811851978302,0.0 -160800,0.16417037,0.020670319,,,,,,,,,,,,,,,,, -160900,0.13827676,0.017775955,,,,,,,,,,,,,,,,, -161000,0.14512089,0.019421188,,,,,,,,,,,,,,,,, -161100,0.1283302,0.017042013,,,,,,,,,,,,,,,,, -161200,0.13886605,0.017886916,,,,,,,,,,,,,,,,, -161300,0.1466437,0.019088786,,,,,,,,,,,,,,,,, -161400,0.1550301,0.018510066,,,,,,,,,,,,,,,,, -161500,0.13930504,0.01710191,,,,,,,,,,,,,,,,, -161528,,,0.995476245880127,0.0142111340537667,0.7777987133311279,0.987072765827179,0.0501789711415767,0.2924305651890328,43793.0,0.9861477017402648,0.0540336780250072,0.2743982161334868,43793.0,52117.85686635971,78171.39575123787,52117.85686635971,26039.03204393387,10.039782524108888,0.0 -161600,0.15160349,0.019948218,,,,,,,,,,,,,,,,, -161700,0.13563883,0.016953595,,,,,,,,,,,,,,,,, -161800,0.13966078,0.01686495,,,,,,,,,,,,,,,,, -161900,0.13747318,0.017650357,,,,,,,,,,,,,,,,, -162000,0.15582702,0.020284526,,,,,,,,,,,,,,,,, -162100,0.13195553,0.016200671,,,,,,,,,,,,,,,,, -162200,0.1541688,0.02044864,,,,,,,,,,,,,,,,, -162275,,,0.9954986572265624,0.0141007341444492,0.7723155867022389,0.987072765827179,0.0501789785921573,0.2925567746610274,43793.0,0.9861477017402648,0.0540336780250072,0.2744069678953243,43793.0,52358.030159950256,78516.05895709991,52358.030159950256,26143.434961795807,10.106394290924072,0.0 -162300,0.13944277,0.01628599,,,,,,,,,,,,,,,,, -162400,0.14086555,0.018684072,,,,,,,,,,,,,,,,, -162500,0.14885643,0.017822405,,,,,,,,,,,,,,,,, -162600,0.15064931,0.019762544,,,,,,,,,,,,,,,,, -162700,0.14305991,0.018090643,,,,,,,,,,,,,,,,, -162800,0.121986814,0.016329389,,,,,,,,,,,,,,,,, -162900,0.15183784,0.0184654,,,,,,,,,,,,,,,,, -163000,0.13014513,0.017191611,,,,,,,,,,,,,,,,, -163022,,,0.9955146312713624,0.0140972472727298,0.7758655269731343,0.987072765827179,0.0501789674162864,0.292350433179705,43793.0,0.9861477017402648,0.0540336780250072,0.2744230971769069,43793.0,52598.15224838257,78859.75465726852,52598.15224838257,26246.93022799492,10.163278579711914,0.0 -163100,0.13845563,0.01907948,,,,,,,,,,,,,,,,, -163200,0.12302119,0.016007075,,,,,,,,,,,,,,,,, -163300,0.14196071,0.018463604,,,,,,,,,,,,,,,,, -163400,0.14022921,0.017110704,,,,,,,,,,,,,,,,, -163500,0.13948582,0.018504843,,,,,,,,,,,,,,,,, -163600,0.13784908,0.01633228,,,,,,,,,,,,,,,,, -163700,0.1303437,0.016502513,,,,,,,,,,,,,,,,, -163771,,,0.99542498588562,0.0142718916758894,0.7642837725594658,0.987072765827179,0.0501789785921573,0.2923946279295678,43793.0,0.9861477017402648,0.0540336780250072,0.2744978733779977,43793.0,52838.39267086983,79202.17622852325,52838.39267086983,26349.03185725212,10.222217321395874,0.0 -163800,0.1390667,0.017609252,,,,,,,,,,,,,,,,, -163900,0.1387825,0.016657805,,,,,,,,,,,,,,,,, -164000,0.12423589,0.016880395,,,,,,,,,,,,,,,,, -164100,0.14445339,0.018242022,,,,,,,,,,,,,,,,, -164200,0.13517883,0.018069874,,,,,,,,,,,,,,,,, -164300,0.16033632,0.019507395,,,,,,,,,,,,,,,,, -164400,0.14551474,0.019201303,,,,,,,,,,,,,,,,, -164500,0.14924432,0.0212589,,,,,,,,,,,,,,,,, -164511,,,0.9954708814620972,0.0141994431614875,0.7770340807670562,0.987072765827179,0.0501789674162864,0.2923320008925272,43793.0,0.9861477017402648,0.0540336780250072,0.2743908303667785,43793.0,53078.33946681023,79548.64507198334,53078.33946681023,26455.47572350502,10.280250310897827,0.0 -164600,0.13918032,0.016911611,,,,,,,,,,,,,,,,, -164700,0.14383875,0.019043414,,,,,,,,,,,,,,,,, -164800,0.144427,0.0181076,,,,,,,,,,,,,,,,, -164900,0.15070502,0.019175736,,,,,,,,,,,,,,,,, -165000,0.12640032,0.016118713,,,,,,,,,,,,,,,,, -165100,0.13108835,0.018956732,,,,,,,,,,,,,,,,, -165200,0.13702336,0.0176014,,,,,,,,,,,,,,,,, -165264,,,0.9955559968948364,0.014011014252901,0.7774745704495394,0.987072765827179,0.0501789674162864,0.2923164205091282,43793.0,0.9861477017402648,0.0540336780250072,0.2744477263950919,43793.0,53318.45733141899,79892.24152255058,53318.45733141899,26558.87539362908,10.338693141937256,0.0 -165300,0.14515759,0.017404543,,,,,,,,,,,,,,,,, -165400,0.1441062,0.01642059,,,,,,,,,,,,,,,,, -165500,0.14308208,0.018976543,,,,,,,,,,,,,,,,, -165600,0.15840197,0.019006038,,,,,,,,,,,,,,,,, -165700,0.13186133,0.01561467,,,,,,,,,,,,,,,,, -165800,0.13928162,0.01846745,,,,,,,,,,,,,,,,, -165900,0.13016436,0.017820185,,,,,,,,,,,,,,,,, -166000,0.13219266,0.017083172,,,,,,,,,,,,,,,,, -166013,,,0.9955283999443054,0.0141153950244188,0.7682561890850204,0.987072765827179,0.0501789785921573,0.2923437326480599,43793.0,0.9861477017402648,0.0540336780250072,0.2744653453568398,43793.0,53558.51423501968,80235.681828022,53558.51423501968,26662.17934703827,10.398459672927856,0.0 -166100,0.13347381,0.017798569,,,,,,,,,,,,,,,,, -166200,0.13744125,0.018646985,,,,,,,,,,,,,,,,, -166300,0.14672501,0.018420408,,,,,,,,,,,,,,,,, -166400,0.15817304,0.019736428,,,,,,,,,,,,,,,,, -166500,0.14073409,0.01827013,,,,,,,,,,,,,,,,, -166600,0.1291914,0.015842361,,,,,,,,,,,,,,,,, -166700,0.16740008,0.018169288,,,,,,,,,,,,,,,,, -166758,,,0.9954188466072084,0.0142263481393456,0.7785996373453546,0.987072765827179,0.0501789711415767,0.2923470375904951,43793.0,0.9861477017402648,0.0540336780250072,0.2744245503170944,43793.0,53798.596177339554,80585.60542726517,53798.596177339554,26771.942001581192,10.45738959312439,0.0 -166800,0.13880974,0.017275732,,,,,,,,,,,,,,,,, -166900,0.14987004,0.019827094,,,,,,,,,,,,,,,,, -167000,0.13071406,0.017388035,,,,,,,,,,,,,,,,, -167100,0.13455248,0.018482924,,,,,,,,,,,,,,,,, -167200,0.14274764,0.019233815,,,,,,,,,,,,,,,,, -167300,0.14230351,0.017807689,,,,,,,,,,,,,,,,, -167400,0.1569553,0.018462554,,,,,,,,,,,,,,,,, -167500,0.13210724,0.015797187,,,,,,,,,,,,,,,,, -167501,,,0.9954667687416076,0.0141899297013878,0.7694678128577732,0.987072765827179,0.0501789785921573,0.2924166168514924,43793.0,0.9861477017402648,0.0540336780250072,0.274327476345126,43793.0,54038.70073080063,80934.76012897491,54038.70073080063,26880.91184401512,10.514888763427734,0.0 -167600,0.13234837,0.017375782,,,,,,,,,,,,,,,,, -167700,0.1376874,0.015801823,,,,,,,,,,,,,,,,, -167800,0.15742159,0.019645248,,,,,,,,,,,,,,,,, -167900,0.11739582,0.015200795,,,,,,,,,,,,,,,,, -168000,0.1555211,0.019638997,,,,,,,,,,,,,,,,, -168100,0.1420578,0.017766163,,,,,,,,,,,,,,,,, -168200,0.12845054,0.017152566,,,,,,,,,,,,,,,,, -168246,,,0.9954833388328552,0.014200116507709,0.7711502631017008,0.987072765827179,0.0501789711415767,0.2924447621749889,43793.0,0.9861477017402648,0.0540336780250072,0.2744346043114226,43793.0,54278.76337981224,81286.47692489624,54278.76337981224,26992.4874894619,10.573052167892456,0.0 -168300,0.13853596,0.018641835,,,,,,,,,,,,,,,,, -168400,0.16743661,0.018889839,,,,,,,,,,,,,,,,, -168500,0.14247008,0.019529369,,,,,,,,,,,,,,,,, -168600,0.1423743,0.020955902,,,,,,,,,,,,,,,,, -168700,0.12626071,0.01572132,,,,,,,,,,,,,,,,, -168800,0.14231262,0.017793046,,,,,,,,,,,,,,,,, -168900,0.14963947,0.018429097,,,,,,,,,,,,,,,,, -168989,,,0.9955009818077089,0.0141321178525686,0.7773532093488216,0.987072765827179,0.050178974866867,0.2924251409088802,43793.0,0.9861477017402648,0.0540336780250072,0.2744506152585214,43793.0,54518.72223806381,81631.07444930077,54518.72223806381,27097.03732419014,10.639933347702026,0.0 -169000,0.1442754,0.01746415,,,,,,,,,,,,,,,,, -169100,0.12728874,0.01629059,,,,,,,,,,,,,,,,, -169200,0.124909334,0.016302248,,,,,,,,,,,,,,,,, -169300,0.13873705,0.018636838,,,,,,,,,,,,,,,,, -169400,0.12919977,0.018090898,,,,,,,,,,,,,,,,, -169500,0.15168045,0.01900044,,,,,,,,,,,,,,,,, -169600,0.14025144,0.017364694,,,,,,,,,,,,,,,,, -169700,0.14454165,0.018592639,,,,,,,,,,,,,,,,, -169737,,,0.9955348372459412,0.0140738934278488,0.7799641597540772,0.987072765827179,0.0501789711415767,0.2924950547097051,43793.0,0.9861477017402648,0.0540336780250072,0.2743306396429991,43793.0,54758.81324410439,81976.3571498394,54758.81324410439,27202.15098762512,10.69777512550354,0.0 -169800,0.13786317,0.019005995,,,,,,,,,,,,,,,,, -169900,0.15721123,0.018342374,,,,,,,,,,,,,,,,, -170000,0.15373735,0.020182962,,,,,,,,,,,,,,,,, -170100,0.1338808,0.01615665,,,,,,,,,,,,,,,,, -170200,0.13304225,0.016164605,,,,,,,,,,,,,,,,, -170300,0.14378422,0.01789728,,,,,,,,,,,,,,,,, -170400,0.14518343,0.021049006,,,,,,,,,,,,,,,,, -170490,,,0.995444118976593,0.0141743142157793,0.7645382818223485,0.987072765827179,0.0501789674162864,0.2925678030672734,43793.0,0.9861477017402648,0.0540336780250072,0.2744199995516733,43793.0,54998.77826976776,82322.19646334648,54998.77826976776,27307.94575953484,10.756629467010498,0.0 -170500,0.14735976,0.018162478,,,,,,,,,,,,,,,,, -170600,0.1495097,0.018798787,,,,,,,,,,,,,,,,, -170700,0.1482905,0.018818239,,,,,,,,,,,,,,,,, -170800,0.15420373,0.018568516,,,,,,,,,,,,,,,,, -170900,0.15057196,0.019269066,,,,,,,,,,,,,,,,, -171000,0.1340115,0.01800012,,,,,,,,,,,,,,,,, -171100,0.13571845,0.017245296,,,,,,,,,,,,,,,,, -171200,0.15441653,0.019579347,,,,,,,,,,,,,,,,, -171230,,,0.995492696762085,0.0141819659620523,0.7743298517611776,0.987072765827179,0.050178974866867,0.2925519655322183,43793.0,0.9861477017402648,0.0540336780250072,0.2744220064947314,43793.0,55238.96902322769,82674.4102742672,55238.96902322769,27419.886949539185,10.8176851272583,0.0 -171300,0.12910175,0.014955925,,,,,,,,,,,,,,,,, -171400,0.1581382,0.021028824,,,,,,,,,,,,,,,,, -171500,0.14784896,0.018499628,,,,,,,,,,,,,,,,, -171600,0.15393564,0.018130695,,,,,,,,,,,,,,,,, -171700,0.14574595,0.018085811,,,,,,,,,,,,,,,,, -171800,0.13185874,0.01659189,,,,,,,,,,,,,,,,, -171816,,,,,,,,,,,,,,55431.25052642822,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index e62084b2a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,48 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -892.5795657634735,0.0,38.32567739486694,1,0,38.32567739486694,0.0007088489946909,0.0,11.12126636505127,3003,930.9052810668944,0.0006190601852722,0.0,11.111475944519045,0.0004835649742744,0.0,11.119153022766112,3000 -1357.4461088180542,0.0298850536346435,878.5678820610046,2331,0,878.5678820610046,0.5068154335021973,16.345875563990155,2.946174144744873,3003,2236.116005897522,0.5102968215942383,21.65943524518221,2.873313665390014,0.5102850198745728,17.68833473762671,2.8761160373687744,3000 -1803.7466485500336,0.0580279827117919,1718.7574768066406,4662,0,1718.7574768066406,0.5927139520645142,22.215112739406536,2.1307547092437744,3003,3522.705801486969,0.5730522871017456,27.33075637415592,2.268893480300904,0.5891309380531311,23.586874302922432,2.1536061763763428,3000 -2254.35815167427,0.082970380783081,2558.678544282913,6993,0,2558.678544282913,0.6202428936958313,23.74940006571612,1.8929826021194456,3003,4813.3322784900665,0.606548547744751,29.43806395732164,1.9968069791793823,0.6165453791618347,25.35945736848895,1.9260854721069336,3000 -2709.2392904758453,0.1067221164703369,3398.8785541057587,9326,0,3398.8785541057587,0.6364766955375671,24.992818880390896,1.7680728435516355,3003,6108.510417699814,0.6131796836853027,29.64469360437432,1.9451521635055544,0.6309779286384583,26.356215868632788,1.825148224830628,3000 -3165.6758439540863,0.131770372390747,4238.8062624931335,11659,0,4238.8062624931335,0.6424728631973267,25.21233941447844,1.711891770362854,3003,7404.970848798752,0.616400420665741,29.70226124134817,1.916278600692749,0.6364831328392029,26.524825522377707,1.7651243209838867,3000 -3660.879511117935,0.1590526103973388,5078.8299243450165,13992,0,5078.8299243450165,0.6517343521118164,25.849126741883687,1.6489189863204956,3003,8740.295587062836,0.6251095533370972,29.71905861466232,1.8352748155593872,0.6442325711250305,27.30846081869481,1.712920308113098,3000 -4142.569768667221,0.1852412223815918,5918.779824733734,16324,0,5918.779824733734,0.6576956510543823,26.30162304760944,1.614374041557312,3003,10062.03084230423,0.6289481520652771,30.1445867527318,1.8222870826721191,0.6479398608207703,27.3637250719459,1.68116557598114,3000 -4637.421412944794,0.2112081050872802,6758.940435886383,18657,0,6758.940435886383,0.660449743270874,26.63128190515756,1.6002322435379028,3003,11397.138581991196,0.6277276873588562,30.65094670772645,1.8295557498931885,0.6513124108314514,27.49430518378227,1.6622154712677002,3000 -5109.4391322135925,0.2413890361785888,7599.17161488533,20990,0,7599.17161488533,0.6596363186836243,26.523374816812822,1.5901591777801514,3003,12709.490337610245,0.6352289915084839,30.791125978727298,1.7688525915145874,0.6497625708580017,27.406212220377967,1.6578360795974731,3000 -5611.919378757477,0.2727680206298828,8439.236397266388,23323,0,8439.236397266388,0.6660391688346863,27.10138290676962,1.55837881565094,3003,14052.139616012571,0.6351427435874939,31.04816736219118,1.7711073160171509,0.6542758345603943,28.171750929722943,1.6312884092330933,3000 -6112.617798089981,0.3002674579620361,9279.262301683426,25657,0,9279.262301683426,0.6674452424049377,27.24795073109318,1.546763896942139,3003,15392.960274457932,0.6489943861961365,31.861470460471523,1.6708744764328003,0.6577103734016418,28.386727715138804,1.6123778820037842,3000 -6766.884717226028,0.3277285099029541,10119.33438873291,27990,0,10119.33438873291,0.6686421632766724,27.659953390141943,1.5359697341918943,3003,16887.398057222366,0.6425632238388062,31.24810712567143,1.7168598175048828,0.658665120601654,28.33344174410828,1.6051126718521118,3000 -7224.911325931549,0.3553922176361084,10959.488750457764,30324,0,10959.488750457764,0.6704317331314087,27.53926569000312,1.5235344171524048,3003,18185.67586708069,0.6387715935707092,31.180361760347587,1.7513316869735718,0.6582807302474976,28.01698388066988,1.5954887866973877,3000 -7715.065255880356,0.382869005203247,11799.395884513857,32657,0,11799.395884513857,0.6709546446800232,27.46059266106052,1.5159692764282229,3003,19515.8352162838,0.6421418190002441,31.688443196843643,1.7130132913589478,0.6594462394714355,28.124145456230487,1.5872998237609863,3000 -8221.539778709412,0.4115808010101318,12639.648458957672,34991,0,12639.648458957672,0.672105073928833,27.63131568906135,1.5036916732788086,3003,20862.66091775894,0.6410354971885681,31.29086602755576,1.7259587049484253,0.6630048155784607,28.3995127800219,1.5775558948516846,3000 -8793.286875247955,0.4394073486328125,13479.89543390274,37325,0,13479.89543390274,0.6757538914680481,27.793051245605607,1.500253200531006,3003,22274.75265264511,0.6429249048233032,31.46545863659349,1.7336983680725098,0.663240373134613,28.164983907427093,1.577974796295166,3000 -9451.717700958252,0.4671125411987304,14320.017933368685,39659,0,14320.017933368685,0.6762884259223938,28.09599251642844,1.486888408660889,3003,23773.4026260376,0.6447225213050842,31.25804832940152,1.689799427986145,0.6650258302688599,28.47577944514998,1.5642013549804688,3000 -10083.92147922516,0.4971275329589844,15160.147383213043,41993,0,15160.147383213043,0.6767416596412659,28.159225092401268,1.482648491859436,3003,25245.83581018448,0.6418246626853943,31.54752667866807,1.7203962802886963,0.6655341982841492,28.69866005713504,1.562987208366394,3000 -10619.568261384964,0.5311579704284668,16000.051938295364,44326,0,16000.051938295364,0.6781128644943237,27.77103557395157,1.4767955541610718,3003,26621.491897821423,0.6613112092018127,32.61597261217031,1.5802650451660156,0.6657201647758484,28.666405029152983,1.5555511713027954,3000 -11198.965550661089,0.561424970626831,16839.980953216553,46660,0,16839.980953216553,0.6805066466331482,28.084528100338435,1.4682908058166504,3003,28040.916669368744,0.6497166752815247,31.55529780918572,1.657746195793152,0.6670964956283569,28.635666652429748,1.5500373840332031,3000 -11758.793085098268,0.5909426212310791,17679.945336341858,48994,0,17679.945336341858,0.6798210740089417,28.0331932471544,1.4563076496124268,3003,29440.80671453476,0.6467783451080322,31.408422055872663,1.6856553554534912,0.6677412390708923,28.786372928465703,1.5390740633010864,3000 -12309.529057741163,0.6208062171936035,18520.05733203888,51328,0,18520.05733203888,0.6797513365745544,28.20761255202798,1.4512735605239868,3003,30831.75518250465,0.6549112796783447,31.91548361763397,1.6199612617492676,0.6689687371253967,28.929565175389858,1.5352182388305664,3000 -12857.206429243088,0.6506195068359375,19360.178364753723,53662,0,19360.178364753723,0.681552529335022,28.29054019765348,1.4458444118499756,3003,32219.65393900872,0.6495839357376099,31.52180095518269,1.6710714101791382,0.6688695549964905,28.870555564318444,1.532265305519104,3000 -13384.084159612656,0.6815006732940674,20200.181646585464,55996,0,20200.181646585464,0.6832491159439087,28.576496861826488,1.438015103340149,3003,33586.63511300087,0.6558064818382263,32.056675923308504,1.641991376876831,0.6720685362815857,29.18964414297396,1.5172662734985352,3000 -14075.882135391235,0.7136006355285645,21040.09161424637,58329,0,21040.09161424637,0.6833652853965759,28.48284404893331,1.4308712482452393,3003,35118.44615530968,0.657656729221344,31.937285743635364,1.6101428270339966,0.6714237928390503,29.03327787917762,1.5166760683059692,3000 -14703.68877029419,0.7454655170440674,21880.31508302689,60663,0,21880.31508302689,0.6852594614028931,28.83500179481324,1.4206849336624146,3003,36586.57926249504,0.6537819504737854,32.49978187618908,1.6467565298080444,0.671584963798523,28.973790934716995,1.510255217552185,3000 -15315.602014303207,0.77724289894104,22720.35317778588,62997,0,22720.35317778588,0.686967670917511,28.856663573822804,1.4123915433883667,3003,38038.63135719299,0.6693201661109924,33.4739314800073,1.52785062789917,0.6742631793022156,29.266614300038707,1.5014219284057615,3000 -16020.533456087112,0.8095309734344482,23560.48533654213,65331,0,23560.48533654213,0.6884666681289673,28.83885231298425,1.4014354944229126,3003,39583.79589796066,0.6595335006713867,32.48676522191736,1.6017496585845947,0.6758750677108765,29.29447900717668,1.486812710762024,3000 -16616.940055847168,0.8407487869262695,24400.5163128376,67665,0,24400.5163128376,0.6875835657119751,28.650583418270205,1.4012601375579834,3003,41020.33675789833,0.660723090171814,32.36851512432437,1.6009280681610107,0.6756022572517395,29.302940848683274,1.4917354583740234,3000 -17352.784069776535,0.9213240146636964,25240.725229740143,70000,0,25240.725229740143,0.6928824782371521,29.57294526056814,1.385662317276001,3003,42596.53853440285,0.6662484407424927,32.96686890315351,1.547742247581482,0.678317666053772,29.641713014224447,1.4796814918518066,3000 -18074.45157289505,0.9541702270507812,26080.97846698761,72334,0,26080.97846698761,0.6930567622184753,29.155925471336744,1.3839657306671145,3003,44158.56359124184,0.658568263053894,32.689614051614335,1.5997875928878784,0.6778836846351624,29.733602133158183,1.4754083156585691,3000 -18649.47722673416,0.9877212047576904,26921.192920207977,74669,0,26921.192920207977,0.6938470005989075,29.201825259701923,1.376531720161438,3003,45573.90559220314,0.6603477597236633,32.826282796088854,1.6007717847824097,0.6780697107315063,29.297419152581835,1.4693963527679443,3000 -19375.05325198173,1.0212042331695557,27761.27295565605,77003,0,27761.27295565605,0.6957759857177734,29.476644371476635,1.3691195249557495,3003,47139.663432359695,0.6674932241439819,33.33707098666702,1.5427851676940918,0.6805866956710815,29.74598294178544,1.4604946374893188,3000 -19987.100786447525,1.0545799732208252,28601.181200027462,79336,0,28601.181200027462,0.6976584792137146,29.926070662587613,1.358900547027588,3003,48591.722954034805,0.662295937538147,33.02561788436131,1.577651858329773,0.6813926696777344,29.73820499849484,1.4510258436203003,3000 -20549.10893559456,1.0872220993041992,29441.39288806916,81670,0,29441.39288806916,0.6980535984039307,29.6608945179934,1.3542295694351196,3003,49994.04566836357,0.6846832633018494,34.21251206089079,1.4385541677474976,0.682719349861145,29.936081717855267,1.4486734867095947,3000 -21285.675103902817,1.1223838329315186,30281.515100955963,84004,0,30281.515100955963,0.6986462473869324,30.022813565275428,1.3453952074050903,3003,51570.83880257607,0.6690401434898376,33.571556752165826,1.5280447006225586,0.6831657290458679,29.646118721990334,1.441229224205017,3000 -21849.073429584503,1.157557249069214,31121.409340381622,86337,0,31121.409340381622,0.7010284066200256,29.734592836322268,1.3384041786193848,3003,52974.23804974556,0.672156810760498,33.86852339264609,1.5186889171600342,0.6863647103309631,30.026562760041603,1.4359986782073977,3000 -22428.9524166584,1.192662000656128,31961.54389023781,88671,0,31961.54389023781,0.7016443014144897,30.041883007413933,1.331700563430786,3003,54394.35652112961,0.6767730116844177,33.87990809492461,1.4741313457489014,0.6857447624206543,30.211913446848925,1.4289525747299194,3000 -23022.13383102417,1.2275655269622805,32801.646438360214,91004,0,32801.646438360214,0.7026785612106323,30.104204020163536,1.3232601881027222,3003,55827.74574255943,0.6733483672142029,33.843482795286214,1.5103007555007937,0.6872202157974243,30.112663837010107,1.4207197427749634,3000 -23623.600753068924,1.2687926292419434,33641.57714486122,93337,0,33641.57714486122,0.7031898498535156,30.082529087863985,1.3187764883041382,3003,57269.25446343422,0.6737951040267944,33.77046943780147,1.5091018676757812,0.6864390969276428,30.335093032326945,1.415480136871338,3000 -24330.59509205818,1.3097121715545654,34481.814543008804,95671,0,34481.814543008804,0.7049445509910583,30.18251317817529,1.3141802549362185,3003,58816.59733939171,0.6822972297668457,34.47859926606245,1.4458683729171753,0.689315676689148,30.457518693412105,1.405459761619568,3000 -24921.81315469742,1.3472654819488523,35321.9128882885,98004,0,35321.9128882885,0.7058508992195129,30.45435340686677,1.3002111911773682,3003,60248.02245783806,0.678022563457489,34.01514409048839,1.476861596107483,0.6899852156639099,30.62139618709256,1.4042116403579712,3000 -25543.32467985153,1.3858873844146729,36162.06133246422,100338,0,36162.06133246422,0.7059787511825562,30.23059396378059,1.2977460622787476,3003,61709.79067540169,0.6938577890396118,34.990644862138524,1.3822377920150757,0.6900968551635742,30.46087354581043,1.396977782249451,3000 -26230.07666492462,1.4218950271606443,37002.12159109116,102671,0,37002.12159109116,0.7083958387374878,30.438368768962743,1.2918590307235718,3003,63236.7092871666,0.6866313219070435,34.47752292320814,1.427037477493286,0.6905555725097656,30.45557605030123,1.392742156982422,3000 -26824.891982793808,1.4596679210662842,37842.12282657623,105004,0,37842.12282657623,0.7087908983230591,30.58542091418264,1.2866441011428833,3003,64671.632581949234,0.6834654808044434,34.62209129947804,1.4519506692886353,0.691969096660614,30.79525270570812,1.38778555393219,3000 -27443.19491481781,1.4974017143249512,38682.26754260063,107337,0,38682.26754260063,0.7097437977790833,30.80200748225378,1.2780189514160156,3003,66130.1880209446,0.690151035785675,35.204680267477826,1.4052026271820068,0.6935065984725952,30.966569494395618,1.38334059715271,3000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index f75c6152a..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1123 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.9713173,11.115017,,,,,,,,,,,,,,,,, -1,,,0.0006190601852722,11.111475944519045,0.0,0.0004835649742744,11.119153022766112,0.0,3000.0,0.0007088489946909,11.12126636505127,0.0,3003.0,38.32567739486694,930.9052810668944,38.32567739486694,892.5795657634735,0.0,0.0 -100,0.17473924,8.255207,,,,,,,,,,,,,,,,, -200,0.8806711,7.4936314,,,,,,,,,,,,,,,,, -300,0.31875873,6.870786,,,,,,,,,,,,,,,,, -400,0.34042042,6.261581,,,,,,,,,,,,,,,,, -500,0.31365836,5.8279996,,,,,,,,,,,,,,,,, -600,0.40159178,5.5768223,,,,,,,,,,,,,,,,, -700,0.38505277,5.2881346,,,,,,,,,,,,,,,,, -800,0.5153511,5.066572,,,,,,,,,,,,,,,,, -900,0.45454448,4.8235617,,,,,,,,,,,,,,,,, -1000,0.59630394,4.577527,,,,,,,,,,,,,,,,, -1100,0.46690905,4.2728453,,,,,,,,,,,,,,,,, -1200,0.4874353,3.9878235,,,,,,,,,,,,,,,,, -1300,0.54745984,4.049309,,,,,,,,,,,,,,,,, -1400,0.49286005,3.7997725,,,,,,,,,,,,,,,,, -1500,0.5363073,3.5792592,,,,,,,,,,,,,,,,, -1600,0.52047044,3.5046682,,,,,,,,,,,,,,,,, -1700,0.5932169,3.4855154,,,,,,,,,,,,,,,,, -1800,0.5544502,3.3314762,,,,,,,,,,,,,,,,, -1900,0.4438033,3.2393394,,,,,,,,,,,,,,,,, -2000,0.42903876,3.1940663,,,,,,,,,,,,,,,,, -2100,0.3666499,3.0677881,,,,,,,,,,,,,,,,, -2200,0.54077494,3.1059391,,,,,,,,,,,,,,,,, -2300,0.3610672,3.032393,,,,,,,,,,,,,,,,, -2331,,,0.5102968215942383,2.873313665390014,21.65943524518221,0.5102850198745728,2.8761160373687744,17.68833473762671,3000.0,0.5068154335021973,2.946174144744873,16.345875563990155,3003.0,878.5678820610046,2236.116005897522,878.5678820610046,1357.4461088180542,0.0298850536346435,0.0 -2400,0.34228384,2.9053328,,,,,,,,,,,,,,,,, -2500,0.42004362,2.8995502,,,,,,,,,,,,,,,,, -2600,0.3739925,2.8158472,,,,,,,,,,,,,,,,, -2700,0.35428214,2.8294425,,,,,,,,,,,,,,,,, -2800,0.29457167,2.6937113,,,,,,,,,,,,,,,,, -2900,0.2411242,2.7431707,,,,,,,,,,,,,,,,, -3000,0.29837978,2.755699,,,,,,,,,,,,,,,,, -3100,0.26000154,2.6501033,,,,,,,,,,,,,,,,, -3200,0.23959847,2.637304,,,,,,,,,,,,,,,,, -3300,0.50428855,2.6760948,,,,,,,,,,,,,,,,, -3400,0.22397828,2.608281,,,,,,,,,,,,,,,,, -3500,0.22319758,2.5026805,,,,,,,,,,,,,,,,, -3600,0.21842165,2.6162453,,,,,,,,,,,,,,,,, -3700,0.1988437,2.4896111,,,,,,,,,,,,,,,,, -3800,0.22006623,2.4724452,,,,,,,,,,,,,,,,, -3900,0.21548983,2.4288337,,,,,,,,,,,,,,,,, -4000,0.18380864,2.5786965,,,,,,,,,,,,,,,,, -4100,0.18509515,2.4093735,,,,,,,,,,,,,,,,, -4200,0.19215357,2.4627354,,,,,,,,,,,,,,,,, -4300,0.19404598,2.4369612,,,,,,,,,,,,,,,,, -4400,0.1819127,2.4079003,,,,,,,,,,,,,,,,, -4500,0.17951286,2.251172,,,,,,,,,,,,,,,,, -4600,0.17488627,2.3183262,,,,,,,,,,,,,,,,, -4662,,,0.5730522871017456,2.268893480300904,27.33075637415592,0.5891309380531311,2.1536061763763428,23.586874302922432,3000.0,0.5927139520645142,2.1307547092437744,22.215112739406536,3003.0,1718.7574768066406,3522.705801486969,1718.7574768066406,1803.7466485500336,0.0580279827117919,0.0 -4700,0.17042081,2.3270905,,,,,,,,,,,,,,,,, -4800,0.17860267,2.3114023,,,,,,,,,,,,,,,,, -4900,0.16864996,2.293741,,,,,,,,,,,,,,,,, -5000,0.16054949,2.2948463,,,,,,,,,,,,,,,,, -5100,0.21337654,2.35377,,,,,,,,,,,,,,,,, -5200,0.16260512,2.2262452,,,,,,,,,,,,,,,,, -5300,0.14996392,2.2892742,,,,,,,,,,,,,,,,, -5400,0.15732391,2.2836387,,,,,,,,,,,,,,,,, -5500,0.15543893,2.220042,,,,,,,,,,,,,,,,, -5600,0.21213095,2.2747297,,,,,,,,,,,,,,,,, -5700,0.15153879,2.235819,,,,,,,,,,,,,,,,, -5800,0.17221999,2.2379944,,,,,,,,,,,,,,,,, -5900,0.15531985,2.2114155,,,,,,,,,,,,,,,,, -6000,0.14976972,2.2021937,,,,,,,,,,,,,,,,, -6100,0.18325368,2.216001,,,,,,,,,,,,,,,,, -6200,0.17777312,2.227609,,,,,,,,,,,,,,,,, -6300,0.15728767,2.149391,,,,,,,,,,,,,,,,, -6400,0.1393967,2.21835,,,,,,,,,,,,,,,,, -6500,0.15935975,2.2160597,,,,,,,,,,,,,,,,, -6600,0.16024014,2.2830477,,,,,,,,,,,,,,,,, -6700,0.16052064,2.1406052,,,,,,,,,,,,,,,,, -6800,0.1521301,2.1652737,,,,,,,,,,,,,,,,, -6900,0.14633965,2.1458857,,,,,,,,,,,,,,,,, -6993,,,0.606548547744751,1.9968069791793823,29.43806395732164,0.6165453791618347,1.9260854721069336,25.35945736848895,3000.0,0.6202428936958313,1.8929826021194456,23.74940006571612,3003.0,2558.678544282913,4813.3322784900665,2558.678544282913,2254.35815167427,0.082970380783081,0.0 -7000,0.14471635,2.0966883,,,,,,,,,,,,,,,,, -7100,0.17276631,2.1198835,,,,,,,,,,,,,,,,, -7200,0.15143253,2.180427,,,,,,,,,,,,,,,,, -7300,0.3148695,2.0454776,,,,,,,,,,,,,,,,, -7400,0.16660337,2.0636861,,,,,,,,,,,,,,,,, -7500,0.164577,2.003624,,,,,,,,,,,,,,,,, -7600,0.23859541,2.1124446,,,,,,,,,,,,,,,,, -7700,0.15180601,2.1801503,,,,,,,,,,,,,,,,, -7800,0.19357365,2.044172,,,,,,,,,,,,,,,,, -7900,0.25401545,2.0676105,,,,,,,,,,,,,,,,, -8000,0.16076586,2.0332484,,,,,,,,,,,,,,,,, -8100,0.16683133,2.1173325,,,,,,,,,,,,,,,,, -8200,0.17695367,2.0610003,,,,,,,,,,,,,,,,, -8300,0.1938218,2.0577452,,,,,,,,,,,,,,,,, -8400,0.17446107,2.125331,,,,,,,,,,,,,,,,, -8500,0.15022396,2.0805404,,,,,,,,,,,,,,,,, -8600,0.17071487,2.0678072,,,,,,,,,,,,,,,,, -8700,0.17201287,1.9627209,,,,,,,,,,,,,,,,, -8800,0.1616931,1.9637395,,,,,,,,,,,,,,,,, -8900,0.1594353,2.0428927,,,,,,,,,,,,,,,,, -9000,0.1582642,2.0401256,,,,,,,,,,,,,,,,, -9100,0.28326863,2.0732388,,,,,,,,,,,,,,,,, -9200,0.24306133,2.084116,,,,,,,,,,,,,,,,, -9300,0.20982194,2.020942,,,,,,,,,,,,,,,,, -9326,,,0.6131796836853027,1.9451521635055544,29.64469360437432,0.6309779286384583,1.825148224830628,26.356215868632788,3000.0,0.6364766955375671,1.7680728435516355,24.992818880390896,3003.0,3398.8785541057587,6108.510417699814,3398.8785541057587,2709.2392904758453,0.1067221164703369,0.0 -9400,0.172759,2.0091794,,,,,,,,,,,,,,,,, -9500,0.15586133,2.0213687,,,,,,,,,,,,,,,,, -9600,0.16982785,2.0197778,,,,,,,,,,,,,,,,, -9700,0.15531962,1.9895656,,,,,,,,,,,,,,,,, -9800,0.17986558,2.0409698,,,,,,,,,,,,,,,,, -9900,0.20827651,1.9868845,,,,,,,,,,,,,,,,, -10000,0.17502284,2.0528147,,,,,,,,,,,,,,,,, -10100,4.981084,3.964525,,,,,,,,,,,,,,,,, -10200,0.24335872,2.0348513,,,,,,,,,,,,,,,,, -10300,0.15544477,2.0507982,,,,,,,,,,,,,,,,, -10400,0.16959801,1.9611727,,,,,,,,,,,,,,,,, -10500,0.15140453,2.0697708,,,,,,,,,,,,,,,,, -10600,0.16318044,1.9581211,,,,,,,,,,,,,,,,, -10700,0.15956515,2.0247028,,,,,,,,,,,,,,,,, -10800,0.16649106,2.0673292,,,,,,,,,,,,,,,,, -10900,0.17335331,1.9626733,,,,,,,,,,,,,,,,, -11000,0.18436486,1.9955807,,,,,,,,,,,,,,,,, -11100,0.17096415,1.9544176,,,,,,,,,,,,,,,,, -11200,0.17691232,1.9381021,,,,,,,,,,,,,,,,, -11300,0.19830078,2.0314724,,,,,,,,,,,,,,,,, -11400,0.17323253,1.9289682,,,,,,,,,,,,,,,,, -11500,0.19249134,2.0210695,,,,,,,,,,,,,,,,, -11600,0.1619088,1.9258035,,,,,,,,,,,,,,,,, -11659,,,0.616400420665741,1.916278600692749,29.70226124134817,0.6364831328392029,1.7651243209838867,26.524825522377707,3000.0,0.6424728631973267,1.711891770362854,25.21233941447844,3003.0,4238.8062624931335,7404.970848798752,4238.8062624931335,3165.6758439540863,0.131770372390747,0.0 -11700,0.15551822,2.0172186,,,,,,,,,,,,,,,,, -11800,0.15035258,1.9220946,,,,,,,,,,,,,,,,, -11900,0.19101475,1.9764144,,,,,,,,,,,,,,,,, -12000,0.19950907,2.002119,,,,,,,,,,,,,,,,, -12100,0.18404987,1.8656105,,,,,,,,,,,,,,,,, -12200,0.20200957,1.998382,,,,,,,,,,,,,,,,, -12300,0.15605406,1.9384573,,,,,,,,,,,,,,,,, -12400,0.2225751,1.8988553,,,,,,,,,,,,,,,,, -12500,0.28939274,2.0021958,,,,,,,,,,,,,,,,, -12600,0.18149228,1.9986844,,,,,,,,,,,,,,,,, -12700,0.23209113,1.9137418,,,,,,,,,,,,,,,,, -12800,0.16450234,1.9688109,,,,,,,,,,,,,,,,, -12900,0.19863652,2.0150495,,,,,,,,,,,,,,,,, -13000,0.1680789,1.9485627,,,,,,,,,,,,,,,,, -13100,0.170493,1.9159217,,,,,,,,,,,,,,,,, -13200,0.17794888,1.97642,,,,,,,,,,,,,,,,, -13300,0.23389274,1.9369067,,,,,,,,,,,,,,,,, -13400,0.30086166,1.9596134,,,,,,,,,,,,,,,,, -13500,0.16770302,1.9618492,,,,,,,,,,,,,,,,, -13600,0.20008028,1.8864143,,,,,,,,,,,,,,,,, -13700,0.19207948,1.9161822,,,,,,,,,,,,,,,,, -13800,0.18045959,2.0110877,,,,,,,,,,,,,,,,, -13900,0.19927453,1.9252924,,,,,,,,,,,,,,,,, -13992,,,0.6251095533370972,1.8352748155593872,29.71905861466232,0.6442325711250305,1.712920308113098,27.30846081869481,3000.0,0.6517343521118164,1.6489189863204956,25.849126741883687,3003.0,5078.8299243450165,8740.295587062836,5078.8299243450165,3660.879511117935,0.1590526103973388,0.0 -14000,0.19693334,1.8317704,,,,,,,,,,,,,,,,, -14100,0.2568029,1.953296,,,,,,,,,,,,,,,,, -14200,0.20867847,2.011967,,,,,,,,,,,,,,,,, -14300,0.16443926,1.9504131,,,,,,,,,,,,,,,,, -14400,0.24800502,1.8499006,,,,,,,,,,,,,,,,, -14500,0.18601862,1.8599497,,,,,,,,,,,,,,,,, -14600,0.19527082,1.8775204,,,,,,,,,,,,,,,,, -14700,0.21222596,1.904131,,,,,,,,,,,,,,,,, -14800,0.20777676,2.0337894,,,,,,,,,,,,,,,,, -14900,0.19153905,1.929952,,,,,,,,,,,,,,,,, -15000,0.17847307,1.9830004,,,,,,,,,,,,,,,,, -15100,0.18983454,1.9188327,,,,,,,,,,,,,,,,, -15200,0.19374157,1.9239619,,,,,,,,,,,,,,,,, -15300,0.18288334,1.8778839,,,,,,,,,,,,,,,,, -15400,0.22373174,1.8775619,,,,,,,,,,,,,,,,, -15500,0.24434206,1.8744698,,,,,,,,,,,,,,,,, -15600,0.20686805,1.7934934,,,,,,,,,,,,,,,,, -15700,0.18759443,1.8985028,,,,,,,,,,,,,,,,, -15800,0.26230195,1.806245,,,,,,,,,,,,,,,,, -15900,0.20532262,1.9705746,,,,,,,,,,,,,,,,, -16000,0.20743465,1.926262,,,,,,,,,,,,,,,,, -16100,0.1993552,1.9064881,,,,,,,,,,,,,,,,, -16200,0.18888435,1.8348403,,,,,,,,,,,,,,,,, -16300,0.2243833,1.9329529,,,,,,,,,,,,,,,,, -16324,,,0.6289481520652771,1.8222870826721191,30.1445867527318,0.6479398608207703,1.68116557598114,27.3637250719459,3000.0,0.6576956510543823,1.614374041557312,26.30162304760944,3003.0,5918.779824733734,10062.03084230423,5918.779824733734,4142.569768667221,0.1852412223815918,0.0 -16400,0.25757328,1.9323071,,,,,,,,,,,,,,,,, -16500,0.17982936,1.8971622,,,,,,,,,,,,,,,,, -16600,0.16904405,1.8361574,,,,,,,,,,,,,,,,, -16700,0.18432315,1.9253763,,,,,,,,,,,,,,,,, -16800,0.2107669,1.9255271,,,,,,,,,,,,,,,,, -16900,0.21101595,1.9487609,,,,,,,,,,,,,,,,, -17000,0.19383436,1.8927586,,,,,,,,,,,,,,,,, -17100,0.19650863,1.9962692,,,,,,,,,,,,,,,,, -17200,0.1933038,1.8586919,,,,,,,,,,,,,,,,, -17300,0.18127963,1.881075,,,,,,,,,,,,,,,,, -17400,0.18531561,1.9253354,,,,,,,,,,,,,,,,, -17500,0.18550833,1.8516834,,,,,,,,,,,,,,,,, -17600,1.7489785,1.9978205,,,,,,,,,,,,,,,,, -17700,0.20618159,1.8747588,,,,,,,,,,,,,,,,, -17800,0.1761357,1.8989137,,,,,,,,,,,,,,,,, -17900,0.19284534,1.927687,,,,,,,,,,,,,,,,, -18000,0.21118686,1.9733245,,,,,,,,,,,,,,,,, -18100,0.18943328,1.8489501,,,,,,,,,,,,,,,,, -18200,0.7105435,1.9229343,,,,,,,,,,,,,,,,, -18300,0.18089227,1.8650063,,,,,,,,,,,,,,,,, -18400,0.19099578,1.9607583,,,,,,,,,,,,,,,,, -18500,0.2017297,1.8884281,,,,,,,,,,,,,,,,, -18600,0.22619368,1.8479207,,,,,,,,,,,,,,,,, -18657,,,0.6277276873588562,1.8295557498931885,30.65094670772645,0.6513124108314514,1.6622154712677002,27.49430518378227,3000.0,0.660449743270874,1.6002322435379028,26.63128190515756,3003.0,6758.940435886383,11397.138581991196,6758.940435886383,4637.421412944794,0.2112081050872802,0.0 -18700,0.18722859,1.8153722,,,,,,,,,,,,,,,,, -18800,0.19304277,1.9072475,,,,,,,,,,,,,,,,, -18900,0.18872808,1.839164,,,,,,,,,,,,,,,,, -19000,0.16861452,1.8323835,,,,,,,,,,,,,,,,, -19100,0.22462769,1.8858104,,,,,,,,,,,,,,,,, -19200,0.17812702,1.8458576,,,,,,,,,,,,,,,,, -19300,0.18946573,1.815844,,,,,,,,,,,,,,,,, -19400,0.24319828,1.8536979,,,,,,,,,,,,,,,,, -19500,0.21954252,1.8650502,,,,,,,,,,,,,,,,, -19600,0.20304748,1.8227241,,,,,,,,,,,,,,,,, -19700,0.1885431,1.8427675,,,,,,,,,,,,,,,,, -19800,0.21480802,1.7751422,,,,,,,,,,,,,,,,, -19900,0.27448353,1.8422419,,,,,,,,,,,,,,,,, -20000,0.17144382,1.8597713,,,,,,,,,,,,,,,,, -20100,0.21100374,1.8715526,,,,,,,,,,,,,,,,, -20200,0.21448335,1.8968955,,,,,,,,,,,,,,,,, -20300,0.20913039,1.8289964,,,,,,,,,,,,,,,,, -20400,0.17574555,1.9039952,,,,,,,,,,,,,,,,, -20500,0.19155984,1.837343,,,,,,,,,,,,,,,,, -20600,0.1717649,1.8607748,,,,,,,,,,,,,,,,, -20700,0.2146694,1.8639511,,,,,,,,,,,,,,,,, -20800,0.27211577,1.8572704,,,,,,,,,,,,,,,,, -20900,0.22497858,1.8717637,,,,,,,,,,,,,,,,, -20990,,,0.6352289915084839,1.7688525915145874,30.791125978727298,0.6497625708580017,1.6578360795974731,27.406212220377967,3000.0,0.6596363186836243,1.5901591777801514,26.523374816812822,3003.0,7599.17161488533,12709.490337610245,7599.17161488533,5109.4391322135925,0.2413890361785888,0.0 -21000,0.31140807,1.8961323,,,,,,,,,,,,,,,,, -21100,0.27674147,1.8132002,,,,,,,,,,,,,,,,, -21200,0.20883854,1.915576,,,,,,,,,,,,,,,,, -21300,0.23532048,1.8775198,,,,,,,,,,,,,,,,, -21400,0.24034284,1.8529726,,,,,,,,,,,,,,,,, -21500,0.17399882,1.9147576,,,,,,,,,,,,,,,,, -21600,0.22028278,1.905907,,,,,,,,,,,,,,,,, -21700,0.21811889,1.8280978,,,,,,,,,,,,,,,,, -21800,0.19132319,1.7924894,,,,,,,,,,,,,,,,, -21900,0.18338087,1.879493,,,,,,,,,,,,,,,,, -22000,0.2957231,2.0249894,,,,,,,,,,,,,,,,, -22100,0.19564432,1.8629696,,,,,,,,,,,,,,,,, -22200,0.18070433,1.8616799,,,,,,,,,,,,,,,,, -22300,0.18838084,1.8527138,,,,,,,,,,,,,,,,, -22400,0.20149681,1.9144565,,,,,,,,,,,,,,,,, -22500,0.1860712,1.8978624,,,,,,,,,,,,,,,,, -22600,0.20886782,1.8156431,,,,,,,,,,,,,,,,, -22700,0.19421637,1.9117318,,,,,,,,,,,,,,,,, -22800,0.22258464,1.8250461,,,,,,,,,,,,,,,,, -22900,0.20304836,1.9395107,,,,,,,,,,,,,,,,, -23000,0.20535201,1.91726,,,,,,,,,,,,,,,,, -23100,0.2286224,1.7829386,,,,,,,,,,,,,,,,, -23200,0.19759767,1.8663706,,,,,,,,,,,,,,,,, -23300,0.30250037,1.857072,,,,,,,,,,,,,,,,, -23323,,,0.6351427435874939,1.7711073160171509,31.04816736219118,0.6542758345603943,1.6312884092330933,28.171750929722943,3000.0,0.6660391688346863,1.55837881565094,27.10138290676962,3003.0,8439.236397266388,14052.139616012571,8439.236397266388,5611.919378757477,0.2727680206298828,0.0 -23400,0.2155641,1.9189196,,,,,,,,,,,,,,,,, -23500,0.18387982,1.8524342,,,,,,,,,,,,,,,,, -23600,0.22626129,1.9236093,,,,,,,,,,,,,,,,, -23700,0.1805997,1.8480364,,,,,,,,,,,,,,,,, -23800,2.1887124,1.777966,,,,,,,,,,,,,,,,, -23900,0.24507381,1.82269,,,,,,,,,,,,,,,,, -24000,0.23430793,1.9573556,,,,,,,,,,,,,,,,, -24100,0.18577476,1.8264321,,,,,,,,,,,,,,,,, -24200,0.19428875,1.8561317,,,,,,,,,,,,,,,,, -24300,0.245874,1.8312839,,,,,,,,,,,,,,,,, -24400,0.30490142,1.8141141,,,,,,,,,,,,,,,,, -24500,0.20336041,1.8272516,,,,,,,,,,,,,,,,, -24600,0.16889271,1.8589426,,,,,,,,,,,,,,,,, -24700,0.19956371,1.8096793,,,,,,,,,,,,,,,,, -24800,0.1875841,1.8167051,,,,,,,,,,,,,,,,, -24900,0.2164951,1.8297437,,,,,,,,,,,,,,,,, -25000,0.24455078,1.8187674,,,,,,,,,,,,,,,,, -25100,0.26356575,1.8144016,,,,,,,,,,,,,,,,, -25200,0.27643952,1.8584265,,,,,,,,,,,,,,,,, -25300,0.19039543,1.9073098,,,,,,,,,,,,,,,,, -25400,0.2386372,1.8310384,,,,,,,,,,,,,,,,, -25500,0.24257904,1.711161,,,,,,,,,,,,,,,,, -25600,0.18481408,1.8162484,,,,,,,,,,,,,,,,, -25657,,,0.6489943861961365,1.6708744764328003,31.861470460471523,0.6577103734016418,1.6123778820037842,28.386727715138804,3000.0,0.6674452424049377,1.546763896942139,27.24795073109318,3003.0,9279.262301683426,15392.960274457932,9279.262301683426,6112.617798089981,0.3002674579620361,0.0 -25700,0.2035568,1.8255029,,,,,,,,,,,,,,,,, -25800,0.1989109,1.8996799,,,,,,,,,,,,,,,,, -25900,0.18261923,1.7907939,,,,,,,,,,,,,,,,, -26000,0.17054884,1.8406658,,,,,,,,,,,,,,,,, -26100,0.19124126,1.8778504,,,,,,,,,,,,,,,,, -26200,0.21289401,1.8789495,,,,,,,,,,,,,,,,, -26300,0.22046681,1.7879068,,,,,,,,,,,,,,,,, -26400,0.19104625,1.8210348,,,,,,,,,,,,,,,,, -26500,0.2590201,1.7630143,,,,,,,,,,,,,,,,, -26600,0.18422313,1.8061539,,,,,,,,,,,,,,,,, -26700,0.17770757,1.8244674,,,,,,,,,,,,,,,,, -26800,0.24572021,1.8995249,,,,,,,,,,,,,,,,, -26900,0.22659977,1.8153743,,,,,,,,,,,,,,,,, -27000,0.22680518,1.8587348,,,,,,,,,,,,,,,,, -27100,0.21587548,1.8081816,,,,,,,,,,,,,,,,, -27200,0.18687391,1.8463241,,,,,,,,,,,,,,,,, -27300,0.19547449,1.8353814,,,,,,,,,,,,,,,,, -27400,0.18171693,1.7648154,,,,,,,,,,,,,,,,, -27500,0.19299158,1.9240812,,,,,,,,,,,,,,,,, -27600,0.19839622,1.7867613,,,,,,,,,,,,,,,,, -27700,0.19486296,1.8039073,,,,,,,,,,,,,,,,, -27800,0.20381546,1.8281859,,,,,,,,,,,,,,,,, -27900,0.17523883,1.8389333,,,,,,,,,,,,,,,,, -27990,,,0.6425632238388062,1.7168598175048828,31.24810712567143,0.658665120601654,1.6051126718521118,28.33344174410828,3000.0,0.6686421632766724,1.5359697341918943,27.659953390141943,3003.0,10119.33438873291,16887.398057222366,10119.33438873291,6766.884717226028,0.3277285099029541,0.0 -28000,0.19459477,1.8768247,,,,,,,,,,,,,,,,, -28100,0.25026327,1.7764727,,,,,,,,,,,,,,,,, -28200,0.19850148,1.8940057,,,,,,,,,,,,,,,,, -28300,0.28127873,1.7997112,,,,,,,,,,,,,,,,, -28400,0.18462531,1.8107337,,,,,,,,,,,,,,,,, -28500,0.19414425,1.8119473,,,,,,,,,,,,,,,,, -28600,0.18643376,1.8101258,,,,,,,,,,,,,,,,, -28700,0.23299639,1.7980099,,,,,,,,,,,,,,,,, -28800,0.27052107,1.8975945,,,,,,,,,,,,,,,,, -28900,0.20043503,1.7986404,,,,,,,,,,,,,,,,, -29000,0.20914543,1.8436302,,,,,,,,,,,,,,,,, -29100,0.18819457,1.7860409,,,,,,,,,,,,,,,,, -29200,0.22114208,1.7944746,,,,,,,,,,,,,,,,, -29300,0.20875068,1.8607128,,,,,,,,,,,,,,,,, -29400,0.25478977,1.8642449,,,,,,,,,,,,,,,,, -29500,0.19597168,1.8615934,,,,,,,,,,,,,,,,, -29600,0.19326283,1.7378935,,,,,,,,,,,,,,,,, -29700,0.17873994,1.8022034,,,,,,,,,,,,,,,,, -29800,0.18972282,1.843089,,,,,,,,,,,,,,,,, -29900,0.25149685,1.8216726,,,,,,,,,,,,,,,,, -30000,0.18486784,1.8316246,,,,,,,,,,,,,,,,, -30100,0.1988331,1.8506503,,,,,,,,,,,,,,,,, -30200,0.2141706,1.834121,,,,,,,,,,,,,,,,, -30300,0.19795203,1.8716407,,,,,,,,,,,,,,,,, -30324,,,0.6387715935707092,1.7513316869735718,31.180361760347587,0.6582807302474976,1.5954887866973877,28.01698388066988,3000.0,0.6704317331314087,1.5235344171524048,27.53926569000312,3003.0,10959.488750457764,18185.67586708069,10959.488750457764,7224.911325931549,0.3553922176361084,0.0 -30400,0.2079333,1.8130714,,,,,,,,,,,,,,,,, -30500,0.17918803,1.7579237,,,,,,,,,,,,,,,,, -30600,0.18341815,1.7784622,,,,,,,,,,,,,,,,, -30700,0.20426776,1.768447,,,,,,,,,,,,,,,,, -30800,0.20012334,1.8205208,,,,,,,,,,,,,,,,, -30900,0.20458946,1.8491184,,,,,,,,,,,,,,,,, -31000,0.2174366,1.773028,,,,,,,,,,,,,,,,, -31100,0.19990648,1.8939322,,,,,,,,,,,,,,,,, -31200,0.21049985,1.8093827,,,,,,,,,,,,,,,,, -31300,0.20931497,1.83663,,,,,,,,,,,,,,,,, -31400,0.25209776,1.8836008,,,,,,,,,,,,,,,,, -31500,0.23094444,1.7224823,,,,,,,,,,,,,,,,, -31600,0.18171689,1.8102604,,,,,,,,,,,,,,,,, -31700,0.21632922,1.8803883,,,,,,,,,,,,,,,,, -31800,0.1919839,1.7872695,,,,,,,,,,,,,,,,, -31900,0.2270005,1.818978,,,,,,,,,,,,,,,,, -32000,0.17696737,1.8322166,,,,,,,,,,,,,,,,, -32100,0.2605443,1.799409,,,,,,,,,,,,,,,,, -32200,0.19377089,1.7086157,,,,,,,,,,,,,,,,, -32300,0.19875744,1.805546,,,,,,,,,,,,,,,,, -32400,0.19208655,1.8165058,,,,,,,,,,,,,,,,, -32500,0.22211868,1.7964932,,,,,,,,,,,,,,,,, -32600,0.18842697,1.8484635,,,,,,,,,,,,,,,,, -32657,,,0.6421418190002441,1.7130132913589478,31.688443196843643,0.6594462394714355,1.5872998237609863,28.124145456230487,3000.0,0.6709546446800232,1.5159692764282229,27.46059266106052,3003.0,11799.395884513857,19515.8352162838,11799.395884513857,7715.065255880356,0.382869005203247,0.0 -32700,0.19347125,1.8101605,,,,,,,,,,,,,,,,, -32800,0.2303539,1.8458018,,,,,,,,,,,,,,,,, -32900,0.20986551,1.810738,,,,,,,,,,,,,,,,, -33000,0.18807039,1.8130279,,,,,,,,,,,,,,,,, -33100,0.22270499,1.8245099,,,,,,,,,,,,,,,,, -33200,0.20414455,1.8565454,,,,,,,,,,,,,,,,, -33300,0.19222552,1.79296,,,,,,,,,,,,,,,,, -33400,0.2704769,1.8556502,,,,,,,,,,,,,,,,, -33500,0.18852168,1.7591143,,,,,,,,,,,,,,,,, -33600,0.18767025,1.8212298,,,,,,,,,,,,,,,,, -33700,0.19545823,1.8577206,,,,,,,,,,,,,,,,, -33800,0.19129537,1.7071083,,,,,,,,,,,,,,,,, -33900,0.21653482,1.8717899,,,,,,,,,,,,,,,,, -34000,0.21051128,1.8043497,,,,,,,,,,,,,,,,, -34100,0.21475594,1.838852,,,,,,,,,,,,,,,,, -34200,0.19421554,1.7435415,,,,,,,,,,,,,,,,, -34300,0.22085072,1.8030804,,,,,,,,,,,,,,,,, -34400,0.19596319,1.7467331,,,,,,,,,,,,,,,,, -34500,0.20819743,1.7201196,,,,,,,,,,,,,,,,, -34600,0.21319987,1.8609661,,,,,,,,,,,,,,,,, -34700,0.18701857,1.8207049,,,,,,,,,,,,,,,,, -34800,0.21256284,1.7809134,,,,,,,,,,,,,,,,, -34900,0.1995407,1.8281538,,,,,,,,,,,,,,,,, -34991,,,0.6410354971885681,1.7259587049484253,31.29086602755576,0.6630048155784607,1.5775558948516846,28.3995127800219,3000.0,0.672105073928833,1.5036916732788086,27.63131568906135,3003.0,12639.648458957672,20862.66091775894,12639.648458957672,8221.539778709412,0.4115808010101318,0.0 -35000,0.1776244,1.7316053,,,,,,,,,,,,,,,,, -35100,0.21733871,1.8552227,,,,,,,,,,,,,,,,, -35200,0.2337081,1.8787923,,,,,,,,,,,,,,,,, -35300,0.23659556,1.7844725,,,,,,,,,,,,,,,,, -35400,0.21215816,1.7892318,,,,,,,,,,,,,,,,, -35500,0.2135389,1.8181604,,,,,,,,,,,,,,,,, -35600,0.23310773,1.8334736,,,,,,,,,,,,,,,,, -35700,0.18114163,1.7237179,,,,,,,,,,,,,,,,, -35800,0.1864036,1.7045022,,,,,,,,,,,,,,,,, -35900,0.19067457,1.8057519,,,,,,,,,,,,,,,,, -36000,0.23338486,1.7941339,,,,,,,,,,,,,,,,, -36100,0.20980594,1.7690876,,,,,,,,,,,,,,,,, -36200,0.20328464,1.8128538,,,,,,,,,,,,,,,,, -36300,0.22500083,1.8042376,,,,,,,,,,,,,,,,, -36400,0.20059177,1.7770014,,,,,,,,,,,,,,,,, -36500,0.1880629,1.7791083,,,,,,,,,,,,,,,,, -36600,0.22608921,1.8174896,,,,,,,,,,,,,,,,, -36700,0.2035952,1.6798184,,,,,,,,,,,,,,,,, -36800,0.20450516,1.7906154,,,,,,,,,,,,,,,,, -36900,0.18461931,1.8037603,,,,,,,,,,,,,,,,, -37000,0.19203705,1.7963277,,,,,,,,,,,,,,,,, -37100,0.24309686,1.9218985,,,,,,,,,,,,,,,,, -37200,0.29672235,1.845238,,,,,,,,,,,,,,,,, -37300,0.2898587,1.8142959,,,,,,,,,,,,,,,,, -37325,,,0.6429249048233032,1.7336983680725098,31.46545863659349,0.663240373134613,1.577974796295166,28.164983907427093,3000.0,0.6757538914680481,1.500253200531006,27.793051245605607,3003.0,13479.89543390274,22274.75265264511,13479.89543390274,8793.286875247955,0.4394073486328125,0.0 -37400,0.20379126,1.852975,,,,,,,,,,,,,,,,, -37500,0.23631014,1.7381812,,,,,,,,,,,,,,,,, -37600,0.21452136,1.7982118,,,,,,,,,,,,,,,,, -37700,0.20490952,1.7447597,,,,,,,,,,,,,,,,, -37800,0.18414831,1.7487091,,,,,,,,,,,,,,,,, -37900,0.20000602,1.7166407,,,,,,,,,,,,,,,,, -38000,0.18756852,1.7220386,,,,,,,,,,,,,,,,, -38100,0.19800016,1.7580043,,,,,,,,,,,,,,,,, -38200,0.24085715,1.737911,,,,,,,,,,,,,,,,, -38300,0.21707867,1.7429448,,,,,,,,,,,,,,,,, -38400,0.20108378,1.8470951,,,,,,,,,,,,,,,,, -38500,0.24289452,1.7823639,,,,,,,,,,,,,,,,, -38600,0.22028524,1.8020976,,,,,,,,,,,,,,,,, -38700,0.18852578,1.7998472,,,,,,,,,,,,,,,,, -38800,0.19974871,1.7902029,,,,,,,,,,,,,,,,, -38900,0.20721911,1.8405546,,,,,,,,,,,,,,,,, -39000,0.18648197,1.7223809,,,,,,,,,,,,,,,,, -39100,0.5988109,1.792685,,,,,,,,,,,,,,,,, -39200,0.26912054,1.7401481,,,,,,,,,,,,,,,,, -39300,0.22703913,1.8223866,,,,,,,,,,,,,,,,, -39400,0.21919602,1.7095286,,,,,,,,,,,,,,,,, -39500,0.196798,1.815152,,,,,,,,,,,,,,,,, -39600,0.18015781,1.7433258,,,,,,,,,,,,,,,,, -39659,,,0.6447225213050842,1.689799427986145,31.25804832940152,0.6650258302688599,1.5642013549804688,28.47577944514998,3000.0,0.6762884259223938,1.486888408660889,28.09599251642844,3003.0,14320.017933368685,23773.4026260376,14320.017933368685,9451.717700958252,0.4671125411987304,0.0 -39700,0.20212561,1.7972459,,,,,,,,,,,,,,,,, -39800,0.23531452,1.702481,,,,,,,,,,,,,,,,, -39900,0.2111476,1.7705197,,,,,,,,,,,,,,,,, -40000,0.2176353,1.7500026,,,,,,,,,,,,,,,,, -40100,0.22401783,1.7959149,,,,,,,,,,,,,,,,, -40200,0.17293239,1.7043232,,,,,,,,,,,,,,,,, -40300,0.20296367,1.7997437,,,,,,,,,,,,,,,,, -40400,0.18560876,1.7464708,,,,,,,,,,,,,,,,, -40500,0.20196891,1.8570788,,,,,,,,,,,,,,,,, -40600,0.20635037,1.7581835,,,,,,,,,,,,,,,,, -40700,0.20008986,1.7453364,,,,,,,,,,,,,,,,, -40800,0.20323953,1.72241,,,,,,,,,,,,,,,,, -40900,0.2132849,1.7298756,,,,,,,,,,,,,,,,, -41000,0.19501086,1.7854663,,,,,,,,,,,,,,,,, -41100,0.1834688,1.704275,,,,,,,,,,,,,,,,, -41200,0.20958242,1.7471132,,,,,,,,,,,,,,,,, -41300,0.22093439,1.7454455,,,,,,,,,,,,,,,,, -41400,0.19898714,1.7937275,,,,,,,,,,,,,,,,, -41500,0.25655586,1.7666236,,,,,,,,,,,,,,,,, -41600,0.19333649,1.7590345,,,,,,,,,,,,,,,,, -41700,0.2073328,1.8310603,,,,,,,,,,,,,,,,, -41800,0.18592006,1.7401949,,,,,,,,,,,,,,,,, -41900,0.19466819,1.7277269,,,,,,,,,,,,,,,,, -41993,,,0.6418246626853943,1.7203962802886963,31.54752667866807,0.6655341982841492,1.562987208366394,28.69866005713504,3000.0,0.6767416596412659,1.482648491859436,28.159225092401268,3003.0,15160.147383213043,25245.83581018448,15160.147383213043,10083.92147922516,0.4971275329589844,0.0 -42000,0.19381274,1.7503896,,,,,,,,,,,,,,,,, -42100,0.19362721,1.7062141,,,,,,,,,,,,,,,,, -42200,0.40338442,1.8048745,,,,,,,,,,,,,,,,, -42300,0.19179635,1.7397668,,,,,,,,,,,,,,,,, -42400,0.19182654,1.7747022,,,,,,,,,,,,,,,,, -42500,0.19539638,1.8696625,,,,,,,,,,,,,,,,, -42600,0.2060106,1.810288,,,,,,,,,,,,,,,,, -42700,0.17525503,1.6914135,,,,,,,,,,,,,,,,, -42800,0.19994393,1.7153904,,,,,,,,,,,,,,,,, -42900,0.31669807,1.6834192,,,,,,,,,,,,,,,,, -43000,0.21075685,1.749048,,,,,,,,,,,,,,,,, -43100,0.20944726,1.7799782,,,,,,,,,,,,,,,,, -43200,0.19300032,1.7821765,,,,,,,,,,,,,,,,, -43300,0.20802154,1.7744246,,,,,,,,,,,,,,,,, -43400,0.25765684,1.7102824,,,,,,,,,,,,,,,,, -43500,0.19467081,1.6679678,,,,,,,,,,,,,,,,, -43600,0.22950046,1.7440052,,,,,,,,,,,,,,,,, -43700,0.19146454,1.729677,,,,,,,,,,,,,,,,, -43800,0.20355564,1.7742015,,,,,,,,,,,,,,,,, -43900,0.1882386,1.8098435,,,,,,,,,,,,,,,,, -44000,0.24008273,1.8377239,,,,,,,,,,,,,,,,, -44100,0.20135076,1.7998956,,,,,,,,,,,,,,,,, -44200,0.19521487,1.8118151,,,,,,,,,,,,,,,,, -44300,0.2208027,1.7843872,,,,,,,,,,,,,,,,, -44326,,,0.6613112092018127,1.5802650451660156,32.61597261217031,0.6657201647758484,1.5555511713027954,28.666405029152983,3000.0,0.6781128644943237,1.4767955541610718,27.77103557395157,3003.0,16000.051938295364,26621.491897821423,16000.051938295364,10619.568261384964,0.5311579704284668,0.0 -44400,0.19831614,1.675891,,,,,,,,,,,,,,,,, -44500,0.2229347,1.7722609,,,,,,,,,,,,,,,,, -44600,0.19900392,1.7697769,,,,,,,,,,,,,,,,, -44700,0.19193685,1.7532399,,,,,,,,,,,,,,,,, -44800,0.2536553,1.7387309,,,,,,,,,,,,,,,,, -44900,0.1980207,1.7975396,,,,,,,,,,,,,,,,, -45000,0.18713589,1.7156285,,,,,,,,,,,,,,,,, -45100,0.19846193,1.7332214,,,,,,,,,,,,,,,,, -45200,0.19866818,1.7663381,,,,,,,,,,,,,,,,, -45300,0.19110672,1.6816546,,,,,,,,,,,,,,,,, -45400,0.20462734,1.7355615,,,,,,,,,,,,,,,,, -45500,0.21307042,1.7936702,,,,,,,,,,,,,,,,, -45600,0.21102136,1.7711831,,,,,,,,,,,,,,,,, -45700,0.24607761,1.7252754,,,,,,,,,,,,,,,,, -45800,0.22401872,1.733166,,,,,,,,,,,,,,,,, -45900,0.19118555,1.8071319,,,,,,,,,,,,,,,,, -46000,0.20828804,1.8132963,,,,,,,,,,,,,,,,, -46100,0.20359533,1.838276,,,,,,,,,,,,,,,,, -46200,0.20068853,1.7317183,,,,,,,,,,,,,,,,, -46300,0.2019463,1.8098342,,,,,,,,,,,,,,,,, -46400,0.25196305,1.8524649,,,,,,,,,,,,,,,,, -46500,0.24895352,1.7386467,,,,,,,,,,,,,,,,, -46600,0.1953341,1.7898921,,,,,,,,,,,,,,,,, -46660,,,0.6497166752815247,1.657746195793152,31.55529780918572,0.6670964956283569,1.5500373840332031,28.635666652429748,3000.0,0.6805066466331482,1.4682908058166504,28.084528100338435,3003.0,16839.980953216553,28040.916669368744,16839.980953216553,11198.965550661089,0.561424970626831,0.0 -46700,0.21657766,1.7551367,,,,,,,,,,,,,,,,, -46800,0.20035976,1.7469963,,,,,,,,,,,,,,,,, -46900,0.20839484,1.6835217,,,,,,,,,,,,,,,,, -47000,0.17554352,1.7111472,,,,,,,,,,,,,,,,, -47100,0.20257872,1.8000926,,,,,,,,,,,,,,,,, -47200,0.21781823,1.7561005,,,,,,,,,,,,,,,,, -47300,0.20067832,1.7488686,,,,,,,,,,,,,,,,, -47400,0.1930265,1.6663857,,,,,,,,,,,,,,,,, -47500,0.2082527,1.8037786,,,,,,,,,,,,,,,,, -47600,0.19505344,1.7561113,,,,,,,,,,,,,,,,, -47700,0.21046762,1.7656189,,,,,,,,,,,,,,,,, -47800,0.24713391,1.7696338,,,,,,,,,,,,,,,,, -47900,0.20303406,1.7555599,,,,,,,,,,,,,,,,, -48000,0.24796066,1.7624729,,,,,,,,,,,,,,,,, -48100,0.3457181,1.6550243,,,,,,,,,,,,,,,,, -48200,0.20059927,1.7632409,,,,,,,,,,,,,,,,, -48300,0.2547533,1.7929838,,,,,,,,,,,,,,,,, -48400,0.19821031,1.8096173,,,,,,,,,,,,,,,,, -48500,0.25926235,1.711405,,,,,,,,,,,,,,,,, -48600,0.24538974,1.7089015,,,,,,,,,,,,,,,,, -48700,0.19419433,1.7348932,,,,,,,,,,,,,,,,, -48800,0.1940067,1.7374278,,,,,,,,,,,,,,,,, -48900,0.21441142,1.7465808,,,,,,,,,,,,,,,,, -48994,,,0.6467783451080322,1.6856553554534912,31.408422055872663,0.6677412390708923,1.5390740633010864,28.786372928465703,3000.0,0.6798210740089417,1.4563076496124268,28.0331932471544,3003.0,17679.945336341858,29440.80671453476,17679.945336341858,11758.793085098268,0.5909426212310791,0.0 -49000,0.22686312,1.8160182,,,,,,,,,,,,,,,,, -49100,0.20732105,1.694582,,,,,,,,,,,,,,,,, -49200,0.19376218,1.7902967,,,,,,,,,,,,,,,,, -49300,0.20480658,1.7103953,,,,,,,,,,,,,,,,, -49400,0.18792057,1.7637491,,,,,,,,,,,,,,,,, -49500,0.20406042,1.7028546,,,,,,,,,,,,,,,,, -49600,0.1927487,1.7673469,,,,,,,,,,,,,,,,, -49700,0.33368546,1.78866,,,,,,,,,,,,,,,,, -49800,0.20713344,1.7195009,,,,,,,,,,,,,,,,, -49900,0.18020315,1.6713399,,,,,,,,,,,,,,,,, -50000,0.20332956,1.7879667,,,,,,,,,,,,,,,,, -50100,0.19375964,1.68478,,,,,,,,,,,,,,,,, -50200,0.20315798,1.7413447,,,,,,,,,,,,,,,,, -50300,0.22082195,1.7411746,,,,,,,,,,,,,,,,, -50400,0.21219283,1.8623999,,,,,,,,,,,,,,,,, -50500,0.22966476,1.7925389,,,,,,,,,,,,,,,,, -50600,0.1908321,1.7302598,,,,,,,,,,,,,,,,, -50700,0.19918069,1.7135651,,,,,,,,,,,,,,,,, -50800,0.20684564,1.7008744,,,,,,,,,,,,,,,,, -50900,0.462346,1.7751234,,,,,,,,,,,,,,,,, -51000,0.19232589,1.687653,,,,,,,,,,,,,,,,, -51100,0.23281546,1.6723309,,,,,,,,,,,,,,,,, -51200,0.19100225,1.7567573,,,,,,,,,,,,,,,,, -51300,0.21145679,1.7795403,,,,,,,,,,,,,,,,, -51328,,,0.6549112796783447,1.6199612617492676,31.91548361763397,0.6689687371253967,1.5352182388305664,28.929565175389858,3000.0,0.6797513365745544,1.4512735605239868,28.20761255202798,3003.0,18520.05733203888,30831.75518250465,18520.05733203888,12309.529057741163,0.6208062171936035,0.0 -51400,0.21911849,1.7276309,,,,,,,,,,,,,,,,, -51500,0.27681568,1.7571266,,,,,,,,,,,,,,,,, -51600,0.20082551,1.7322396,,,,,,,,,,,,,,,,, -51700,0.21444461,1.764328,,,,,,,,,,,,,,,,, -51800,0.18916836,1.7326115,,,,,,,,,,,,,,,,, -51900,0.20167035,1.7880826,,,,,,,,,,,,,,,,, -52000,0.22693652,1.8379058,,,,,,,,,,,,,,,,, -52100,0.18559624,1.7395221,,,,,,,,,,,,,,,,, -52200,0.20929144,1.6514618,,,,,,,,,,,,,,,,, -52300,0.23912266,1.6952761,,,,,,,,,,,,,,,,, -52400,0.23681203,1.7734236,,,,,,,,,,,,,,,,, -52500,0.2178991,1.7122469,,,,,,,,,,,,,,,,, -52600,0.19694975,1.726062,,,,,,,,,,,,,,,,, -52700,0.23791133,1.7230343,,,,,,,,,,,,,,,,, -52800,0.20994256,1.7866383,,,,,,,,,,,,,,,,, -52900,0.20396084,1.72796,,,,,,,,,,,,,,,,, -53000,0.2316943,1.71877,,,,,,,,,,,,,,,,, -53100,0.26387247,1.7335824,,,,,,,,,,,,,,,,, -53200,0.21070638,1.8265728,,,,,,,,,,,,,,,,, -53300,0.23122093,1.7183937,,,,,,,,,,,,,,,,, -53400,0.19907172,1.7582952,,,,,,,,,,,,,,,,, -53500,0.22269553,1.765029,,,,,,,,,,,,,,,,, -53600,0.198187,1.7264159,,,,,,,,,,,,,,,,, -53662,,,0.6495839357376099,1.6710714101791382,31.52180095518269,0.6688695549964905,1.532265305519104,28.870555564318444,3000.0,0.681552529335022,1.4458444118499756,28.29054019765348,3003.0,19360.178364753723,32219.65393900872,19360.178364753723,12857.206429243088,0.6506195068359375,0.0 -53700,0.19163078,1.6537789,,,,,,,,,,,,,,,,, -53800,0.19032885,1.7446024,,,,,,,,,,,,,,,,, -53900,0.24298114,1.7106214,,,,,,,,,,,,,,,,, -54000,0.20092808,1.7181052,,,,,,,,,,,,,,,,, -54100,0.20574223,1.7189047,,,,,,,,,,,,,,,,, -54200,0.19250436,1.7040952,,,,,,,,,,,,,,,,, -54300,0.21405755,1.8076409,,,,,,,,,,,,,,,,, -54400,0.20831375,1.645997,,,,,,,,,,,,,,,,, -54500,0.20170052,1.6812583,,,,,,,,,,,,,,,,, -54600,0.24132629,1.8154379,,,,,,,,,,,,,,,,, -54700,0.20934637,1.8292392,,,,,,,,,,,,,,,,, -54800,0.19823067,1.7318093,,,,,,,,,,,,,,,,, -54900,0.2034641,1.7525228,,,,,,,,,,,,,,,,, -55000,0.19388844,1.7358962,,,,,,,,,,,,,,,,, -55100,0.24565889,1.6857253,,,,,,,,,,,,,,,,, -55200,0.20572254,1.7173994,,,,,,,,,,,,,,,,, -55300,0.18146257,1.6389107,,,,,,,,,,,,,,,,, -55400,0.19752447,1.660335,,,,,,,,,,,,,,,,, -55500,0.2503912,1.7801665,,,,,,,,,,,,,,,,, -55600,0.2182678,1.6659462,,,,,,,,,,,,,,,,, -55700,0.22639228,1.7377183,,,,,,,,,,,,,,,,, -55800,0.21645494,1.7282739,,,,,,,,,,,,,,,,, -55900,0.21701549,1.7293036,,,,,,,,,,,,,,,,, -55996,,,0.6558064818382263,1.641991376876831,32.056675923308504,0.6720685362815857,1.5172662734985352,29.18964414297396,3000.0,0.6832491159439087,1.438015103340149,28.576496861826488,3003.0,20200.181646585464,33586.63511300087,20200.181646585464,13384.084159612656,0.6815006732940674,0.0 -56000,0.21650808,1.6895767,,,,,,,,,,,,,,,,, -56100,0.19064672,1.6576474,,,,,,,,,,,,,,,,, -56200,0.19278905,1.7525897,,,,,,,,,,,,,,,,, -56300,0.21510758,1.7325882,,,,,,,,,,,,,,,,, -56400,0.20088287,1.7467115,,,,,,,,,,,,,,,,, -56500,0.19387086,1.70103,,,,,,,,,,,,,,,,, -56600,0.24970742,1.8569771,,,,,,,,,,,,,,,,, -56700,0.19886692,1.713883,,,,,,,,,,,,,,,,, -56800,0.20450453,1.7103924,,,,,,,,,,,,,,,,, -56900,0.20713831,1.7495756,,,,,,,,,,,,,,,,, -57000,0.19796441,1.6718343,,,,,,,,,,,,,,,,, -57100,0.20769572,1.6437376,,,,,,,,,,,,,,,,, -57200,0.19630446,1.6932434,,,,,,,,,,,,,,,,, -57300,0.20141812,1.7655083,,,,,,,,,,,,,,,,, -57400,0.20089784,1.7653536,,,,,,,,,,,,,,,,, -57500,0.21963467,1.7711339,,,,,,,,,,,,,,,,, -57600,0.2564995,1.7322265,,,,,,,,,,,,,,,,, -57700,0.20238458,1.7184488,,,,,,,,,,,,,,,,, -57800,0.19845062,1.7829384,,,,,,,,,,,,,,,,, -57900,0.20772515,1.708682,,,,,,,,,,,,,,,,, -58000,0.19994152,1.6609416,,,,,,,,,,,,,,,,, -58100,0.23493944,1.7869763,,,,,,,,,,,,,,,,, -58200,0.19358563,1.7552935,,,,,,,,,,,,,,,,, -58300,0.1966758,1.642648,,,,,,,,,,,,,,,,, -58329,,,0.657656729221344,1.6101428270339966,31.937285743635364,0.6714237928390503,1.5166760683059692,29.03327787917762,3000.0,0.6833652853965759,1.4308712482452393,28.48284404893331,3003.0,21040.09161424637,35118.44615530968,21040.09161424637,14075.882135391235,0.7136006355285645,0.0 -58400,0.20538001,1.7452692,,,,,,,,,,,,,,,,, -58500,0.1841519,1.6546183,,,,,,,,,,,,,,,,, -58600,0.2041155,1.7190183,,,,,,,,,,,,,,,,, -58700,0.21702252,1.7096442,,,,,,,,,,,,,,,,, -58800,0.20218621,1.755277,,,,,,,,,,,,,,,,, -58900,0.20914184,1.6869732,,,,,,,,,,,,,,,,, -59000,0.21465518,1.6949493,,,,,,,,,,,,,,,,, -59100,0.19815926,1.7388471,,,,,,,,,,,,,,,,, -59200,0.18773092,1.6973825,,,,,,,,,,,,,,,,, -59300,0.19827527,1.6124164,,,,,,,,,,,,,,,,, -59400,0.2346311,1.7041756,,,,,,,,,,,,,,,,, -59500,0.21570338,1.7064419,,,,,,,,,,,,,,,,, -59600,0.19447166,1.7518451,,,,,,,,,,,,,,,,, -59700,0.21349941,1.718272,,,,,,,,,,,,,,,,, -59800,0.20319554,1.6727691,,,,,,,,,,,,,,,,, -59900,0.20422211,1.6570945,,,,,,,,,,,,,,,,, -60000,0.19290754,1.7280865,,,,,,,,,,,,,,,,, -60100,0.19725558,1.7060689,,,,,,,,,,,,,,,,, -60200,0.19258143,1.6492927,,,,,,,,,,,,,,,,, -60300,0.19356234,1.6449285,,,,,,,,,,,,,,,,, -60400,0.2142983,1.8019872,,,,,,,,,,,,,,,,, -60500,0.21573968,1.7558037,,,,,,,,,,,,,,,,, -60600,0.24844252,1.6910998,,,,,,,,,,,,,,,,, -60663,,,0.6537819504737854,1.6467565298080444,32.49978187618908,0.671584963798523,1.510255217552185,28.973790934716995,3000.0,0.6852594614028931,1.4206849336624146,28.83500179481324,3003.0,21880.31508302689,36586.57926249504,21880.31508302689,14703.68877029419,0.7454655170440674,0.0 -60700,0.19390014,1.6546819,,,,,,,,,,,,,,,,, -60800,0.1917094,1.6721036,,,,,,,,,,,,,,,,, -60900,0.4858384,1.7257973,,,,,,,,,,,,,,,,, -61000,0.23248562,1.7321537,,,,,,,,,,,,,,,,, -61100,0.21269694,1.7812613,,,,,,,,,,,,,,,,, -61200,0.20341372,1.7201533,,,,,,,,,,,,,,,,, -61300,0.2030524,1.6878328,,,,,,,,,,,,,,,,, -61400,0.20582892,1.7080857,,,,,,,,,,,,,,,,, -61500,0.19003862,1.7506477,,,,,,,,,,,,,,,,, -61600,0.2232614,1.7539034,,,,,,,,,,,,,,,,, -61700,0.22337745,1.7458775,,,,,,,,,,,,,,,,, -61800,0.20151585,1.6939324,,,,,,,,,,,,,,,,, -61900,0.19370687,1.6156538,,,,,,,,,,,,,,,,, -62000,0.20262837,1.7256384,,,,,,,,,,,,,,,,, -62100,0.19298522,1.7745552,,,,,,,,,,,,,,,,, -62200,0.19233337,1.7134681,,,,,,,,,,,,,,,,, -62300,0.18619987,1.6678121,,,,,,,,,,,,,,,,, -62400,0.2529917,1.6042604,,,,,,,,,,,,,,,,, -62500,0.2911831,1.6685748,,,,,,,,,,,,,,,,, -62600,0.21216425,1.7068274,,,,,,,,,,,,,,,,, -62700,0.24439147,1.7004822,,,,,,,,,,,,,,,,, -62800,0.19615316,1.6415782,,,,,,,,,,,,,,,,, -62900,0.20183668,1.7277582,,,,,,,,,,,,,,,,, -62997,,,0.6693201661109924,1.52785062789917,33.4739314800073,0.6742631793022156,1.5014219284057615,29.266614300038707,3000.0,0.686967670917511,1.4123915433883667,28.856663573822804,3003.0,22720.35317778588,38038.63135719299,22720.35317778588,15315.602014303207,0.77724289894104,0.0 -63000,0.19685827,1.6832246,,,,,,,,,,,,,,,,, -63100,0.19750123,1.704665,,,,,,,,,,,,,,,,, -63200,0.22932296,1.764256,,,,,,,,,,,,,,,,, -63300,0.19440983,1.6226279,,,,,,,,,,,,,,,,, -63400,0.21687701,1.7197118,,,,,,,,,,,,,,,,, -63500,0.22867215,1.7170043,,,,,,,,,,,,,,,,, -63600,0.19507246,1.5944791,,,,,,,,,,,,,,,,, -63700,0.20130235,1.6579334,,,,,,,,,,,,,,,,, -63800,0.20489472,1.7110283,,,,,,,,,,,,,,,,, -63900,0.24199644,1.6390986,,,,,,,,,,,,,,,,, -64000,0.22010563,1.732783,,,,,,,,,,,,,,,,, -64100,0.20917244,1.7530159,,,,,,,,,,,,,,,,, -64200,0.21840242,1.7147336,,,,,,,,,,,,,,,,, -64300,0.19128858,1.7330816,,,,,,,,,,,,,,,,, -64400,0.2097538,1.7307285,,,,,,,,,,,,,,,,, -64500,0.20841688,1.7037644,,,,,,,,,,,,,,,,, -64600,0.18882373,1.6698599,,,,,,,,,,,,,,,,, -64700,0.20545942,1.7427459,,,,,,,,,,,,,,,,, -64800,0.20608544,1.6831774,,,,,,,,,,,,,,,,, -64900,0.21123558,1.6841234,,,,,,,,,,,,,,,,, -65000,0.23065464,1.6777493,,,,,,,,,,,,,,,,, -65100,0.20706594,1.7133656,,,,,,,,,,,,,,,,, -65200,0.19565357,1.6612821,,,,,,,,,,,,,,,,, -65300,0.2437721,1.5787414,,,,,,,,,,,,,,,,, -65331,,,0.6595335006713867,1.6017496585845947,32.48676522191736,0.6758750677108765,1.486812710762024,29.29447900717668,3000.0,0.6884666681289673,1.4014354944229126,28.83885231298425,3003.0,23560.48533654213,39583.79589796066,23560.48533654213,16020.533456087112,0.8095309734344482,0.0 -65400,0.19227076,1.6843119,,,,,,,,,,,,,,,,, -65500,0.20212981,1.7161942,,,,,,,,,,,,,,,,, -65600,0.20098662,1.6745763,,,,,,,,,,,,,,,,, -65700,0.20256689,1.772194,,,,,,,,,,,,,,,,, -65800,0.20799536,1.668416,,,,,,,,,,,,,,,,, -65900,0.18814866,1.643743,,,,,,,,,,,,,,,,, -66000,0.19714119,1.7467321,,,,,,,,,,,,,,,,, -66100,0.27896544,1.7286702,,,,,,,,,,,,,,,,, -66200,0.1956546,1.7105349,,,,,,,,,,,,,,,,, -66300,0.20359664,1.7347738,,,,,,,,,,,,,,,,, -66400,0.6461173,1.69675,,,,,,,,,,,,,,,,, -66500,0.18962209,1.6865226,,,,,,,,,,,,,,,,, -66600,0.20730856,1.6593758,,,,,,,,,,,,,,,,, -66700,0.21224093,1.6258028,,,,,,,,,,,,,,,,, -66800,0.20179251,1.6372153,,,,,,,,,,,,,,,,, -66900,0.19857028,1.7027559,,,,,,,,,,,,,,,,, -67000,0.18741077,1.6464406,,,,,,,,,,,,,,,,, -67100,0.1913452,1.7383019,,,,,,,,,,,,,,,,, -67200,0.20019041,1.664397,,,,,,,,,,,,,,,,, -67300,0.22798094,1.7131358,,,,,,,,,,,,,,,,, -67400,0.21134509,1.7209846,,,,,,,,,,,,,,,,, -67500,0.19242632,1.6083263,,,,,,,,,,,,,,,,, -67600,0.20295325,1.6075135,,,,,,,,,,,,,,,,, -67665,,,0.660723090171814,1.6009280681610107,32.36851512432437,0.6756022572517395,1.4917354583740234,29.302940848683274,3000.0,0.6875835657119751,1.4012601375579834,28.650583418270205,3003.0,24400.5163128376,41020.33675789833,24400.5163128376,16616.940055847168,0.8407487869262695,0.0 -67700,0.22945139,1.7304924,,,,,,,,,,,,,,,,, -67800,0.21423994,1.7221515,,,,,,,,,,,,,,,,, -67900,0.19176528,1.7222434,,,,,,,,,,,,,,,,, -68000,0.21902412,1.6841414,,,,,,,,,,,,,,,,, -68100,0.21574898,1.6477444,,,,,,,,,,,,,,,,, -68200,0.19673434,1.6639663,,,,,,,,,,,,,,,,, -68300,0.19573058,1.6363121,,,,,,,,,,,,,,,,, -68400,0.20144317,1.5859286,,,,,,,,,,,,,,,,, -68500,0.18379338,1.6646086,,,,,,,,,,,,,,,,, -68600,0.1992089,1.6367828,,,,,,,,,,,,,,,,, -68700,0.24362142,1.6089724,,,,,,,,,,,,,,,,, -68800,0.19949973,1.6813473,,,,,,,,,,,,,,,,, -68900,0.18124993,1.6771299,,,,,,,,,,,,,,,,, -69000,0.1905084,1.6497074,,,,,,,,,,,,,,,,, -69100,0.20648743,1.7201711,,,,,,,,,,,,,,,,, -69200,0.195961,1.654612,,,,,,,,,,,,,,,,, -69300,0.20179419,1.7119974,,,,,,,,,,,,,,,,, -69400,0.18928687,1.6338388,,,,,,,,,,,,,,,,, -69500,0.23729737,1.619195,,,,,,,,,,,,,,,,, -69600,0.205611,1.6062114,,,,,,,,,,,,,,,,, -69700,0.20862377,1.682496,,,,,,,,,,,,,,,,, -69800,0.17657779,1.6467599,,,,,,,,,,,,,,,,, -69900,0.19079405,1.6471766,,,,,,,,,,,,,,,,, -70000,,,0.6662484407424927,1.547742247581482,32.96686890315351,0.678317666053772,1.4796814918518066,29.641713014224447,3000.0,0.6928824782371521,1.385662317276001,29.57294526056814,3003.0,25240.725229740143,42596.53853440285,25240.725229740143,17352.784069776535,0.9213240146636964,0.0 -70000,0.21556824,1.5549458,,,,,,,,,,,,,,,,, -70100,0.21746704,1.7389603,,,,,,,,,,,,,,,,, -70200,0.23071076,1.7176245,,,,,,,,,,,,,,,,, -70300,0.20965928,1.707148,,,,,,,,,,,,,,,,, -70400,0.18611775,1.7416915,,,,,,,,,,,,,,,,, -70500,0.2035906,1.6249727,,,,,,,,,,,,,,,,, -70600,0.2015103,1.6240562,,,,,,,,,,,,,,,,, -70700,0.20850569,1.6644707,,,,,,,,,,,,,,,,, -70800,0.20247489,1.7629281,,,,,,,,,,,,,,,,, -70900,0.19507179,1.6352718,,,,,,,,,,,,,,,,, -71000,0.22807418,1.6925986,,,,,,,,,,,,,,,,, -71100,0.2154074,1.6184189,,,,,,,,,,,,,,,,, -71200,0.19775496,1.6203872,,,,,,,,,,,,,,,,, -71300,0.22051698,1.7174128,,,,,,,,,,,,,,,,, -71400,0.20955989,1.6524919,,,,,,,,,,,,,,,,, -71500,0.23116954,1.6527066,,,,,,,,,,,,,,,,, -71600,0.19182841,1.6540537,,,,,,,,,,,,,,,,, -71700,0.18913728,1.6047798,,,,,,,,,,,,,,,,, -71800,0.20924272,1.7330906,,,,,,,,,,,,,,,,, -71900,0.2569217,1.6817276,,,,,,,,,,,,,,,,, -72000,0.18794693,1.5753479,,,,,,,,,,,,,,,,, -72100,0.25667077,1.645905,,,,,,,,,,,,,,,,, -72200,0.2183198,1.7090179,,,,,,,,,,,,,,,,, -72300,0.19319023,1.6233793,,,,,,,,,,,,,,,,, -72334,,,0.658568263053894,1.5997875928878784,32.689614051614335,0.6778836846351624,1.4754083156585691,29.733602133158183,3000.0,0.6930567622184753,1.3839657306671145,29.155925471336744,3003.0,26080.97846698761,44158.56359124184,26080.97846698761,18074.45157289505,0.9541702270507812,0.0 -72400,0.20938368,1.6990712,,,,,,,,,,,,,,,,, -72500,0.20580187,1.6975482,,,,,,,,,,,,,,,,, -72600,0.21267384,1.5856266,,,,,,,,,,,,,,,,, -72700,0.20832476,1.6809556,,,,,,,,,,,,,,,,, -72800,0.37821817,1.6860062,,,,,,,,,,,,,,,,, -72900,0.21109132,1.5730759,,,,,,,,,,,,,,,,, -73000,0.22471385,1.677334,,,,,,,,,,,,,,,,, -73100,0.20211314,1.6503462,,,,,,,,,,,,,,,,, -73200,2.2274878,1.6524963,,,,,,,,,,,,,,,,, -73300,0.218862,1.6132684,,,,,,,,,,,,,,,,, -73400,0.1886788,1.6748973,,,,,,,,,,,,,,,,, -73500,0.19716197,1.5829494,,,,,,,,,,,,,,,,, -73600,0.22104634,1.6842623,,,,,,,,,,,,,,,,, -73700,0.21066633,1.6568316,,,,,,,,,,,,,,,,, -73800,0.18880065,1.6368583,,,,,,,,,,,,,,,,, -73900,0.21412653,1.6493835,,,,,,,,,,,,,,,,, -74000,0.22228417,1.6066223,,,,,,,,,,,,,,,,, -74100,0.20054737,1.6660293,,,,,,,,,,,,,,,,, -74200,0.23227237,1.6402011,,,,,,,,,,,,,,,,, -74300,0.20482895,1.7031138,,,,,,,,,,,,,,,,, -74400,0.21646613,1.6949079,,,,,,,,,,,,,,,,, -74500,0.19075702,1.650552,,,,,,,,,,,,,,,,, -74600,0.18760473,1.6162723,,,,,,,,,,,,,,,,, -74669,,,0.6603477597236633,1.6007717847824097,32.826282796088854,0.6780697107315063,1.4693963527679443,29.297419152581835,3000.0,0.6938470005989075,1.376531720161438,29.201825259701923,3003.0,26921.192920207977,45573.90559220314,26921.192920207977,18649.47722673416,0.9877212047576904,0.0 -74700,0.21503544,1.6812298,,,,,,,,,,,,,,,,, -74800,0.20784016,1.63735,,,,,,,,,,,,,,,,, -74900,0.20694537,1.7013803,,,,,,,,,,,,,,,,, -75000,0.24042733,1.6652282,,,,,,,,,,,,,,,,, -75100,0.21578206,1.6601274,,,,,,,,,,,,,,,,, -75200,0.20338951,1.7328359,,,,,,,,,,,,,,,,, -75300,0.20991172,1.6431507,,,,,,,,,,,,,,,,, -75400,0.20081437,1.6204582,,,,,,,,,,,,,,,,, -75500,0.20836574,1.5787007,,,,,,,,,,,,,,,,, -75600,0.20725696,1.6658133,,,,,,,,,,,,,,,,, -75700,0.22471651,1.664297,,,,,,,,,,,,,,,,, -75800,0.20611157,1.7503937,,,,,,,,,,,,,,,,, -75900,0.20308949,1.6674608,,,,,,,,,,,,,,,,, -76000,0.20890416,1.6896521,,,,,,,,,,,,,,,,, -76100,0.20197438,1.584172,,,,,,,,,,,,,,,,, -76200,0.19162494,1.6110272,,,,,,,,,,,,,,,,, -76300,1.6520226,1.6599637,,,,,,,,,,,,,,,,, -76400,0.20842159,1.717979,,,,,,,,,,,,,,,,, -76500,0.21332109,1.6082516,,,,,,,,,,,,,,,,, -76600,0.20435716,1.555303,,,,,,,,,,,,,,,,, -76700,0.20793924,1.680952,,,,,,,,,,,,,,,,, -76800,0.20043832,1.6644126,,,,,,,,,,,,,,,,, -76900,0.20630041,1.5830743,,,,,,,,,,,,,,,,, -77000,0.20021419,1.6079557,,,,,,,,,,,,,,,,, -77003,,,0.6674932241439819,1.5427851676940918,33.33707098666702,0.6805866956710815,1.4604946374893188,29.74598294178544,3000.0,0.6957759857177734,1.3691195249557495,29.476644371476635,3003.0,27761.27295565605,47139.663432359695,27761.27295565605,19375.05325198173,1.0212042331695557,0.0 -77100,0.20632438,1.6602312,,,,,,,,,,,,,,,,, -77200,0.19978859,1.6306367,,,,,,,,,,,,,,,,, -77300,0.21407317,1.7225745,,,,,,,,,,,,,,,,, -77400,0.22500408,1.7163438,,,,,,,,,,,,,,,,, -77500,0.19166104,1.5980011,,,,,,,,,,,,,,,,, -77600,0.20645654,1.6554643,,,,,,,,,,,,,,,,, -77700,0.21241577,1.6240406,,,,,,,,,,,,,,,,, -77800,0.22218096,1.6895828,,,,,,,,,,,,,,,,, -77900,0.19944236,1.6781187,,,,,,,,,,,,,,,,, -78000,0.22930804,1.7709063,,,,,,,,,,,,,,,,, -78100,0.18784021,1.5952744,,,,,,,,,,,,,,,,, -78200,0.197708,1.5759711,,,,,,,,,,,,,,,,, -78300,0.21157522,1.6016725,,,,,,,,,,,,,,,,, -78400,0.19640008,1.5857996,,,,,,,,,,,,,,,,, -78500,0.2017298,1.6571856,,,,,,,,,,,,,,,,, -78600,0.20870686,1.6315912,,,,,,,,,,,,,,,,, -78700,0.20982558,1.5927339,,,,,,,,,,,,,,,,, -78800,0.21423362,1.634443,,,,,,,,,,,,,,,,, -78900,0.20455588,1.566166,,,,,,,,,,,,,,,,, -79000,0.210989,1.6162539,,,,,,,,,,,,,,,,, -79100,0.20474929,1.5946052,,,,,,,,,,,,,,,,, -79200,0.20547856,1.6429874,,,,,,,,,,,,,,,,, -79300,0.19485767,1.6061498,,,,,,,,,,,,,,,,, -79336,,,0.662295937538147,1.577651858329773,33.02561788436131,0.6813926696777344,1.4510258436203003,29.73820499849484,3000.0,0.6976584792137146,1.358900547027588,29.926070662587613,3003.0,28601.181200027462,48591.722954034805,28601.181200027462,19987.100786447525,1.0545799732208252,0.0 -79400,0.20355368,1.6221442,,,,,,,,,,,,,,,,, -79500,0.20387946,1.7401208,,,,,,,,,,,,,,,,, -79600,0.20910658,1.6815171,,,,,,,,,,,,,,,,, -79700,0.20343521,1.6101067,,,,,,,,,,,,,,,,, -79800,0.19950236,1.648387,,,,,,,,,,,,,,,,, -79900,0.214672,1.6097484,,,,,,,,,,,,,,,,, -80000,0.21787195,1.6157836,,,,,,,,,,,,,,,,, -80100,0.2321907,1.6703098,,,,,,,,,,,,,,,,, -80200,0.21054396,1.5731988,,,,,,,,,,,,,,,,, -80300,0.20888227,1.6927232,,,,,,,,,,,,,,,,, -80400,0.2027073,1.6739054,,,,,,,,,,,,,,,,, -80500,0.19643046,1.6445869,,,,,,,,,,,,,,,,, -80600,0.18992352,1.5709302,,,,,,,,,,,,,,,,, -80700,0.20613395,1.5975449,,,,,,,,,,,,,,,,, -80800,0.20285396,1.7257867,,,,,,,,,,,,,,,,, -80900,0.20460841,1.6356677,,,,,,,,,,,,,,,,, -81000,0.21308126,1.6467094,,,,,,,,,,,,,,,,, -81100,0.22504279,1.7052888,,,,,,,,,,,,,,,,, -81200,0.1957268,1.633105,,,,,,,,,,,,,,,,, -81300,0.19674379,1.6229017,,,,,,,,,,,,,,,,, -81400,0.20115247,1.612132,,,,,,,,,,,,,,,,, -81500,0.192216,1.5028579,,,,,,,,,,,,,,,,, -81600,0.21021706,1.659703,,,,,,,,,,,,,,,,, -81670,,,0.6846832633018494,1.4385541677474976,34.21251206089079,0.682719349861145,1.4486734867095947,29.936081717855267,3000.0,0.6980535984039307,1.3542295694351196,29.6608945179934,3003.0,29441.39288806916,49994.04566836357,29441.39288806916,20549.10893559456,1.0872220993041992,0.0 -81700,0.19258484,1.5782459,,,,,,,,,,,,,,,,, -81800,0.2213706,1.593697,,,,,,,,,,,,,,,,, -81900,0.20561282,1.7405338,,,,,,,,,,,,,,,,, -82000,0.210686,1.572326,,,,,,,,,,,,,,,,, -82100,0.20095304,1.6068523,,,,,,,,,,,,,,,,, -82200,0.19418353,1.6225109,,,,,,,,,,,,,,,,, -82300,0.20873998,1.6811389,,,,,,,,,,,,,,,,, -82400,0.20766427,1.6719636,,,,,,,,,,,,,,,,, -82500,0.205071,1.6150172,,,,,,,,,,,,,,,,, -82600,0.2040242,1.6177703,,,,,,,,,,,,,,,,, -82700,0.19983406,1.6338612,,,,,,,,,,,,,,,,, -82800,0.22665368,1.6713411,,,,,,,,,,,,,,,,, -82900,0.2057699,1.6251361,,,,,,,,,,,,,,,,, -83000,0.19332117,1.5901561,,,,,,,,,,,,,,,,, -83100,0.19393341,1.5886897,,,,,,,,,,,,,,,,, -83200,0.20636997,1.6212474,,,,,,,,,,,,,,,,, -83300,0.20819959,1.6306952,,,,,,,,,,,,,,,,, -83400,0.20017494,1.6137617,,,,,,,,,,,,,,,,, -83500,0.21168232,1.6530714,,,,,,,,,,,,,,,,, -83600,0.22248773,1.5864513,,,,,,,,,,,,,,,,, -83700,0.19554693,1.5926169,,,,,,,,,,,,,,,,, -83800,0.20106645,1.5719458,,,,,,,,,,,,,,,,, -83900,0.20662045,1.6336291,,,,,,,,,,,,,,,,, -84000,0.20187269,1.5839779,,,,,,,,,,,,,,,,, -84004,,,0.6690401434898376,1.5280447006225586,33.571556752165826,0.6831657290458679,1.441229224205017,29.646118721990334,3000.0,0.6986462473869324,1.3453952074050903,30.022813565275428,3003.0,30281.515100955963,51570.83880257607,30281.515100955963,21285.675103902817,1.1223838329315186,0.0 -84100,0.21925132,1.663669,,,,,,,,,,,,,,,,, -84200,0.21381712,1.6257868,,,,,,,,,,,,,,,,, -84300,0.20086046,1.7067482,,,,,,,,,,,,,,,,, -84400,0.20461544,1.705485,,,,,,,,,,,,,,,,, -84500,0.19991976,1.621071,,,,,,,,,,,,,,,,, -84600,0.20538157,1.5713954,,,,,,,,,,,,,,,,, -84700,0.19656289,1.5526763,,,,,,,,,,,,,,,,, -84800,0.21977998,1.5872654,,,,,,,,,,,,,,,,, -84900,0.19636948,1.5890378,,,,,,,,,,,,,,,,, -85000,0.20033476,1.5164284,,,,,,,,,,,,,,,,, -85100,0.20947891,1.6267065,,,,,,,,,,,,,,,,, -85200,0.18987523,1.5795511,,,,,,,,,,,,,,,,, -85300,0.2136882,1.6507128,,,,,,,,,,,,,,,,, -85400,0.21336395,1.5197017,,,,,,,,,,,,,,,,, -85500,0.20048557,1.6094508,,,,,,,,,,,,,,,,, -85600,0.20855106,1.6047364,,,,,,,,,,,,,,,,, -85700,0.20685303,1.5847188,,,,,,,,,,,,,,,,, -85800,0.21031597,1.6984771,,,,,,,,,,,,,,,,, -85900,0.23680107,1.5760896,,,,,,,,,,,,,,,,, -86000,0.20988007,1.6478838,,,,,,,,,,,,,,,,, -86100,0.223944,1.6509023,,,,,,,,,,,,,,,,, -86200,0.19936638,1.5296801,,,,,,,,,,,,,,,,, -86300,0.23136282,1.7152083,,,,,,,,,,,,,,,,, -86337,,,0.672156810760498,1.5186889171600342,33.86852339264609,0.6863647103309631,1.4359986782073977,30.026562760041603,3000.0,0.7010284066200256,1.3384041786193848,29.734592836322268,3003.0,31121.409340381622,52974.23804974556,31121.409340381622,21849.073429584503,1.157557249069214,0.0 -86400,0.20580892,1.5207224,,,,,,,,,,,,,,,,, -86500,0.2108345,1.6206559,,,,,,,,,,,,,,,,, -86600,0.20157972,1.5813838,,,,,,,,,,,,,,,,, -86700,0.20253131,1.582914,,,,,,,,,,,,,,,,, -86800,0.20987794,1.5686404,,,,,,,,,,,,,,,,, -86900,0.22821604,1.6760283,,,,,,,,,,,,,,,,, -87000,0.22034876,1.6468847,,,,,,,,,,,,,,,,, -87100,0.22134452,1.5901321,,,,,,,,,,,,,,,,, -87200,0.2106796,1.5933545,,,,,,,,,,,,,,,,, -87300,0.21056291,1.5494218,,,,,,,,,,,,,,,,, -87400,0.23926397,1.6698118,,,,,,,,,,,,,,,,, -87500,0.19832574,1.5837224,,,,,,,,,,,,,,,,, -87600,0.20949602,1.6151944,,,,,,,,,,,,,,,,, -87700,0.20888036,1.6289333,,,,,,,,,,,,,,,,, -87800,0.22281101,1.5622706,,,,,,,,,,,,,,,,, -87900,0.21431006,1.6352987,,,,,,,,,,,,,,,,, -88000,0.2143855,1.6168013,,,,,,,,,,,,,,,,, -88100,0.20546627,1.5518992,,,,,,,,,,,,,,,,, -88200,0.21001397,1.6191137,,,,,,,,,,,,,,,,, -88300,0.22850487,1.5977961,,,,,,,,,,,,,,,,, -88400,0.2273232,1.6336588,,,,,,,,,,,,,,,,, -88500,0.21014172,1.6928412,,,,,,,,,,,,,,,,, -88600,0.22497678,1.584052,,,,,,,,,,,,,,,,, -88671,,,0.6767730116844177,1.4741313457489014,33.87990809492461,0.6857447624206543,1.4289525747299194,30.211913446848925,3000.0,0.7016443014144897,1.331700563430786,30.041883007413933,3003.0,31961.54389023781,54394.35652112961,31961.54389023781,22428.9524166584,1.192662000656128,0.0 -88700,0.22137628,1.6156228,,,,,,,,,,,,,,,,, -88800,0.1969481,1.5598906,,,,,,,,,,,,,,,,, -88900,0.20834723,1.5853844,,,,,,,,,,,,,,,,, -89000,0.2117751,1.5711881,,,,,,,,,,,,,,,,, -89100,0.20999235,1.571131,,,,,,,,,,,,,,,,, -89200,0.20322554,1.5878251,,,,,,,,,,,,,,,,, -89300,0.21495904,1.591465,,,,,,,,,,,,,,,,, -89400,0.21875769,1.5342035,,,,,,,,,,,,,,,,, -89500,0.20909104,1.5907816,,,,,,,,,,,,,,,,, -89600,0.21576919,1.5648406,,,,,,,,,,,,,,,,, -89700,0.21351169,1.6231799,,,,,,,,,,,,,,,,, -89800,0.19991675,1.5348333,,,,,,,,,,,,,,,,, -89900,0.23630628,1.5535395,,,,,,,,,,,,,,,,, -90000,0.21041662,1.6081027,,,,,,,,,,,,,,,,, -90100,0.22391236,1.5917822,,,,,,,,,,,,,,,,, -90200,0.2012637,1.5767342,,,,,,,,,,,,,,,,, -90300,0.21131471,1.6092196,,,,,,,,,,,,,,,,, -90400,0.2193479,1.6319451,,,,,,,,,,,,,,,,, -90500,0.20597272,1.5530757,,,,,,,,,,,,,,,,, -90600,0.2201607,1.5821797,,,,,,,,,,,,,,,,, -90700,0.20896247,1.5635562,,,,,,,,,,,,,,,,, -90800,0.20035566,1.628992,,,,,,,,,,,,,,,,, -90900,0.22043365,1.5760993,,,,,,,,,,,,,,,,, -91000,0.20110318,1.5142752,,,,,,,,,,,,,,,,, -91004,,,0.6733483672142029,1.5103007555007937,33.843482795286214,0.6872202157974243,1.4207197427749634,30.112663837010107,3000.0,0.7026785612106323,1.3232601881027222,30.104204020163536,3003.0,32801.646438360214,55827.74574255943,32801.646438360214,23022.13383102417,1.2275655269622805,0.0 -91100,0.21799016,1.6215407,,,,,,,,,,,,,,,,, -91200,0.2204283,1.6204597,,,,,,,,,,,,,,,,, -91300,0.19767456,1.5449668,,,,,,,,,,,,,,,,, -91400,0.2201235,1.5111873,,,,,,,,,,,,,,,,, -91500,0.21047828,1.5723361,,,,,,,,,,,,,,,,, -91600,0.21816468,1.5927063,,,,,,,,,,,,,,,,, -91700,0.22170743,1.6158522,,,,,,,,,,,,,,,,, -91800,0.22668439,1.653882,,,,,,,,,,,,,,,,, -91900,0.20314215,1.5421761,,,,,,,,,,,,,,,,, -92000,0.20878229,1.5904182,,,,,,,,,,,,,,,,, -92100,0.21421042,1.5471822,,,,,,,,,,,,,,,,, -92200,0.22022507,1.6554846,,,,,,,,,,,,,,,,, -92300,0.21261849,1.647447,,,,,,,,,,,,,,,,, -92400,0.2160942,1.5523174,,,,,,,,,,,,,,,,, -92500,0.2112793,1.578304,,,,,,,,,,,,,,,,, -92600,0.22527356,1.5623218,,,,,,,,,,,,,,,,, -92700,0.2081897,1.6330308,,,,,,,,,,,,,,,,, -92800,0.21527645,1.6201539,,,,,,,,,,,,,,,,, -92900,0.21806963,1.6229255,,,,,,,,,,,,,,,,, -93000,0.22760437,1.6727043,,,,,,,,,,,,,,,,, -93100,0.21782638,1.5683799,,,,,,,,,,,,,,,,, -93200,0.21389306,1.5777599,,,,,,,,,,,,,,,,, -93300,0.21185766,1.5472398,,,,,,,,,,,,,,,,, -93337,,,0.6737951040267944,1.5091018676757812,33.77046943780147,0.6864390969276428,1.415480136871338,30.335093032326945,3000.0,0.7031898498535156,1.3187764883041382,30.082529087863985,3003.0,33641.57714486122,57269.25446343422,33641.57714486122,23623.600753068924,1.2687926292419434,0.0 -93400,0.21654578,1.5715995,,,,,,,,,,,,,,,,, -93500,0.21907744,1.6110824,,,,,,,,,,,,,,,,, -93600,0.21737061,1.5726354,,,,,,,,,,,,,,,,, -93700,0.2103356,1.5274869,,,,,,,,,,,,,,,,, -93800,0.2188544,1.5539593,,,,,,,,,,,,,,,,, -93900,0.20345414,1.6179563,,,,,,,,,,,,,,,,, -94000,0.21100613,1.5911111,,,,,,,,,,,,,,,,, -94100,0.20930445,1.5099487,,,,,,,,,,,,,,,,, -94200,0.24663055,1.6730658,,,,,,,,,,,,,,,,, -94300,0.2271795,1.6051397,,,,,,,,,,,,,,,,, -94400,0.22367318,1.507001,,,,,,,,,,,,,,,,, -94500,0.20434225,1.527546,,,,,,,,,,,,,,,,, -94600,0.20934641,1.5734171,,,,,,,,,,,,,,,,, -94700,0.21638037,1.6014374,,,,,,,,,,,,,,,,, -94800,0.21990421,1.5681257,,,,,,,,,,,,,,,,, -94900,0.21130514,1.5528479,,,,,,,,,,,,,,,,, -95000,0.21922155,1.5712918,,,,,,,,,,,,,,,,, -95100,0.2266321,1.5423625,,,,,,,,,,,,,,,,, -95200,0.22074342,1.5757847,,,,,,,,,,,,,,,,, -95300,0.22003776,1.5189369,,,,,,,,,,,,,,,,, -95400,0.20270947,1.4752728,,,,,,,,,,,,,,,,, -95500,0.2081944,1.5820726,,,,,,,,,,,,,,,,, -95600,0.22342433,1.5340972,,,,,,,,,,,,,,,,, -95671,,,0.6822972297668457,1.4458683729171753,34.47859926606245,0.689315676689148,1.405459761619568,30.457518693412105,3000.0,0.7049445509910583,1.3141802549362185,30.18251317817529,3003.0,34481.814543008804,58816.59733939171,34481.814543008804,24330.59509205818,1.3097121715545654,0.0 -95700,0.21957633,1.7152787,,,,,,,,,,,,,,,,, -95800,0.24289969,1.6216984,,,,,,,,,,,,,,,,, -95900,0.20320548,1.5292737,,,,,,,,,,,,,,,,, -96000,0.22142413,1.5136787,,,,,,,,,,,,,,,,, -96100,0.2097156,1.5341071,,,,,,,,,,,,,,,,, -96200,0.21947937,1.568105,,,,,,,,,,,,,,,,, -96300,0.21402024,1.5732598,,,,,,,,,,,,,,,,, -96400,0.20824902,1.5123844,,,,,,,,,,,,,,,,, -96500,0.21267642,1.5890846,,,,,,,,,,,,,,,,, -96600,0.2200436,1.5329832,,,,,,,,,,,,,,,,, -96700,0.21397331,1.5884901,,,,,,,,,,,,,,,,, -96800,0.21137637,1.6197212,,,,,,,,,,,,,,,,, -96900,0.2159009,1.6854817,,,,,,,,,,,,,,,,, -97000,0.21360362,1.5240917,,,,,,,,,,,,,,,,, -97100,0.21945794,1.5143167,,,,,,,,,,,,,,,,, -97200,0.20427798,1.5473393,,,,,,,,,,,,,,,,, -97300,0.24198014,1.5963973,,,,,,,,,,,,,,,,, -97400,0.22131677,1.5461377,,,,,,,,,,,,,,,,, -97500,0.20284894,1.4698045,,,,,,,,,,,,,,,,, -97600,0.21722248,1.6071595,,,,,,,,,,,,,,,,, -97700,0.22825792,1.5231296,,,,,,,,,,,,,,,,, -97800,0.21426274,1.5955368,,,,,,,,,,,,,,,,, -97900,0.21157935,1.6088804,,,,,,,,,,,,,,,,, -98000,0.22270744,1.5755743,,,,,,,,,,,,,,,,, -98004,,,0.678022563457489,1.476861596107483,34.01514409048839,0.6899852156639099,1.4042116403579712,30.62139618709256,3000.0,0.7058508992195129,1.3002111911773682,30.45435340686677,3003.0,35321.9128882885,60248.02245783806,35321.9128882885,24921.81315469742,1.3472654819488523,0.0 -98100,0.22478853,1.5533073,,,,,,,,,,,,,,,,, -98200,0.21382768,1.5792863,,,,,,,,,,,,,,,,, -98300,0.2122191,1.5916418,,,,,,,,,,,,,,,,, -98400,0.21255623,1.4868002,,,,,,,,,,,,,,,,, -98500,0.22329311,1.5519303,,,,,,,,,,,,,,,,, -98600,0.22820047,1.594466,,,,,,,,,,,,,,,,, -98700,0.20884787,1.5368582,,,,,,,,,,,,,,,,, -98800,0.23606564,1.561765,,,,,,,,,,,,,,,,, -98900,0.21645637,1.5108132,,,,,,,,,,,,,,,,, -99000,0.21637519,1.5651348,,,,,,,,,,,,,,,,, -99100,0.21471691,1.5403372,,,,,,,,,,,,,,,,, -99200,0.21254727,1.5286113,,,,,,,,,,,,,,,,, -99300,0.21676187,1.5048925,,,,,,,,,,,,,,,,, -99400,0.21630664,1.5595415,,,,,,,,,,,,,,,,, -99500,0.20656957,1.4716909,,,,,,,,,,,,,,,,, -99600,0.23000534,1.6068919,,,,,,,,,,,,,,,,, -99700,0.20398392,1.5064015,,,,,,,,,,,,,,,,, -99800,0.22445817,1.5300833,,,,,,,,,,,,,,,,, -99900,0.21731997,1.435993,,,,,,,,,,,,,,,,, -100000,0.22255637,1.5208713,,,,,,,,,,,,,,,,, -100100,0.23138319,1.5470345,,,,,,,,,,,,,,,,, -100200,0.22806089,1.57852,,,,,,,,,,,,,,,,, -100300,0.22107857,1.5265067,,,,,,,,,,,,,,,,, -100338,,,0.6938577890396118,1.3822377920150757,34.990644862138524,0.6900968551635742,1.396977782249451,30.46087354581043,3000.0,0.7059787511825562,1.2977460622787476,30.23059396378059,3003.0,36162.06133246422,61709.79067540169,36162.06133246422,25543.32467985153,1.3858873844146729,0.0 -100400,0.21795796,1.4675933,,,,,,,,,,,,,,,,, -100500,0.23578443,1.5104849,,,,,,,,,,,,,,,,, -100600,0.22670233,1.5122379,,,,,,,,,,,,,,,,, -100700,0.216258,1.5212963,,,,,,,,,,,,,,,,, -100800,0.23745224,1.5365107,,,,,,,,,,,,,,,,, -100900,0.22526136,1.5385283,,,,,,,,,,,,,,,,, -101000,0.22058676,1.5780637,,,,,,,,,,,,,,,,, -101100,0.2295718,1.4505199,,,,,,,,,,,,,,,,, -101200,0.2215235,1.568218,,,,,,,,,,,,,,,,, -101300,0.2105085,1.5475203,,,,,,,,,,,,,,,,, -101400,0.22693795,1.5760165,,,,,,,,,,,,,,,,, -101500,0.2578059,1.4931408,,,,,,,,,,,,,,,,, -101600,0.21631564,1.5410908,,,,,,,,,,,,,,,,, -101700,0.21159877,1.450574,,,,,,,,,,,,,,,,, -101800,0.21416819,1.5836942,,,,,,,,,,,,,,,,, -101900,0.24723361,1.5910081,,,,,,,,,,,,,,,,, -102000,0.21386996,1.499084,,,,,,,,,,,,,,,,, -102100,0.21920845,1.5037096,,,,,,,,,,,,,,,,, -102200,0.23339193,1.6429245,,,,,,,,,,,,,,,,, -102300,0.22595853,1.5367618,,,,,,,,,,,,,,,,, -102400,0.21180786,1.5254929,,,,,,,,,,,,,,,,, -102500,0.2175806,1.4802028,,,,,,,,,,,,,,,,, -102600,0.22225094,1.450412,,,,,,,,,,,,,,,,, -102671,,,0.6866313219070435,1.427037477493286,34.47752292320814,0.6905555725097656,1.392742156982422,30.45557605030123,3000.0,0.7083958387374878,1.2918590307235718,30.438368768962743,3003.0,37002.12159109116,63236.7092871666,37002.12159109116,26230.07666492462,1.4218950271606443,0.0 -102700,0.23925772,1.531092,,,,,,,,,,,,,,,,, -102800,0.21983325,1.4653555,,,,,,,,,,,,,,,,, -102900,0.22168139,1.5038863,,,,,,,,,,,,,,,,, -103000,0.2133462,1.5155389,,,,,,,,,,,,,,,,, -103100,0.21880156,1.4990838,,,,,,,,,,,,,,,,, -103200,0.21459965,1.4876188,,,,,,,,,,,,,,,,, -103300,0.22545776,1.5621324,,,,,,,,,,,,,,,,, -103400,0.248157,1.6102167,,,,,,,,,,,,,,,,, -103500,0.21812785,1.4622407,,,,,,,,,,,,,,,,, -103600,0.21135108,1.4761113,,,,,,,,,,,,,,,,, -103700,0.24685459,1.5568366,,,,,,,,,,,,,,,,, -103800,0.23043138,1.6189691,,,,,,,,,,,,,,,,, -103900,0.22353518,1.5071753,,,,,,,,,,,,,,,,, -104000,0.22285694,1.5114998,,,,,,,,,,,,,,,,, -104100,0.229673,1.5226326,,,,,,,,,,,,,,,,, -104200,0.23288675,1.6148704,,,,,,,,,,,,,,,,, -104300,0.23261677,1.5527513,,,,,,,,,,,,,,,,, -104400,0.21243796,1.503592,,,,,,,,,,,,,,,,, -104500,0.22410768,1.5143812,,,,,,,,,,,,,,,,, -104600,0.22795792,1.4590014,,,,,,,,,,,,,,,,, -104700,0.21074209,1.4661554,,,,,,,,,,,,,,,,, -104800,0.2201956,1.4743583,,,,,,,,,,,,,,,,, -104900,0.22316669,1.5817013,,,,,,,,,,,,,,,,, -105000,0.2396234,1.4991632,,,,,,,,,,,,,,,,, -105004,,,0.6834654808044434,1.4519506692886353,34.62209129947804,0.691969096660614,1.38778555393219,30.79525270570812,3000.0,0.7087908983230591,1.2866441011428833,30.58542091418264,3003.0,37842.12282657623,64671.632581949234,37842.12282657623,26824.891982793808,1.4596679210662842,0.0 -105100,0.22389099,1.5167772,,,,,,,,,,,,,,,,, -105200,0.22514138,1.5826709,,,,,,,,,,,,,,,,, -105300,0.23493624,1.5369958,,,,,,,,,,,,,,,,, -105400,0.2273184,1.4381214,,,,,,,,,,,,,,,,, -105500,0.22583689,1.4771099,,,,,,,,,,,,,,,,, -105600,0.21591581,1.4800727,,,,,,,,,,,,,,,,, -105700,0.21665739,1.5501066,,,,,,,,,,,,,,,,, -105800,0.22634406,1.5593485,,,,,,,,,,,,,,,,, -105900,0.23133698,1.5502727,,,,,,,,,,,,,,,,, -106000,0.23946795,1.5755082,,,,,,,,,,,,,,,,, -106100,0.23736295,1.5918866,,,,,,,,,,,,,,,,, -106200,0.2266118,1.5674075,,,,,,,,,,,,,,,,, -106300,0.21921204,1.4634259,,,,,,,,,,,,,,,,, -106400,0.22903699,1.5322044,,,,,,,,,,,,,,,,, -106500,0.20813093,1.4397353,,,,,,,,,,,,,,,,, -106600,0.22489074,1.4923587,,,,,,,,,,,,,,,,, -106700,0.21966197,1.5228341,,,,,,,,,,,,,,,,, -106800,0.21886349,1.5008315,,,,,,,,,,,,,,,,, -106900,0.23543215,1.5066402,,,,,,,,,,,,,,,,, -107000,0.2358519,1.5223156,,,,,,,,,,,,,,,,, -107100,0.21757889,1.5303015,,,,,,,,,,,,,,,,, -107200,0.23468323,1.5344026,,,,,,,,,,,,,,,,, -107300,0.2304095,1.5121151,,,,,,,,,,,,,,,,, -107337,,,0.690151035785675,1.4052026271820068,35.204680267477826,0.6935065984725952,1.38334059715271,30.966569494395618,3000.0,0.7097437977790833,1.2780189514160156,30.80200748225378,3003.0,38682.26754260063,66130.1880209446,38682.26754260063,27443.19491481781,1.4974017143249512,0.0 -107337,,,,,,,,,,,,,,38682.26754260063,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 4573835e1..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,50 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -780.9242124557495,0.0,20.41549754142761,1,0,20.41549754142761,0.7073875930098684,95000000,801.339747428894,0.7075171364740755,0.7065163343942286,83274637 -1412.1124548912048,0.0275628566741943,140.66957092285156,183,0,140.66957092285156,0.1318383534745066,95000000,1552.8158721923828,0.1271871388099103,0.128920479360426,83274637 -2034.498508453369,0.0482709407806396,261.05939412117004,357,0,261.05939412117004,0.1304574425370065,95000000,2295.6186287403107,0.1301833171303729,0.1277769713631602,83274637 -2647.331680059433,0.069040298461914,381.33195638656616,537,0,381.33195638656616,0.1297183601459704,95000000,3028.751186609268,0.1262307041705404,0.1273617779289393,83274637 -3249.598855018616,0.0938754081726074,501.7907288074493,714,0,501.7907288074493,0.1287761107319079,95000000,3751.5082404613495,0.1260030443469683,0.1265429995104162,83274637 -3834.333093643189,0.1173727512359619,621.9414341449738,889,0,621.9414341449738,0.1293202573704769,95000000,4456.42281794548,0.1266442178016376,0.1269329799259305,83274637 -4378.849170207977,0.1401965618133545,742.3252635002136,1069,0,742.3252635002136,0.1301694638363487,95000000,5121.351807594299,0.1271526435003528,0.1278378501184768,83274637 -4911.195806980133,0.1622538566589355,862.3552062511444,1254,0,862.3552062511444,0.1280532586759868,95000000,5773.757102012634,0.1242560849718327,0.125565447074868,83274637 -5455.546610832214,0.1836435794830322,982.542487859726,1434,0,982.542487859726,0.1281122845497533,95000000,6438.3228328228,0.1224136979629596,0.1257100016173366,83274637 -5993.718648195267,0.2044479846954345,1102.916160583496,1609,0,1102.916160583496,0.1278192471011513,95000000,7096.895583152771,0.1236820129455073,0.1255223536084451,83274637 -6435.375310897827,0.225440502166748,1223.486125946045,1785,0,1223.486125946045,0.1279375389494243,95000000,7659.149473905563,0.1258179464219313,0.1255218474179224,83274637 -6883.846434593201,0.2464847564697265,1343.8530678749084,1960,0,1343.8530678749084,0.1277369629214638,95000000,8228.014498233795,0.1241794411186714,0.1253777853219207,83274637 -7429.795899629593,0.2674703598022461,1464.0568370819092,2141,0,1464.0568370819092,0.1275420262438322,95000000,8894.195020198822,0.124650988493513,0.1250960292481521,83274637 -7971.723219394684,0.2913937568664551,1584.1165263652802,2323,0,1584.1165263652802,0.1281815174136513,95000000,9556.212258338928,0.1263659366716941,0.1258524653687848,83274637 -8444.770316362381,0.3136165142059326,1704.4217224121094,2504,0,1704.4217224121094,0.1272554857319079,95000000,10149.593060016632,0.1238533474106646,0.1248805963113579,83274637 -8936.420797586441,0.3353207111358642,1824.6317229270933,2687,0,1824.6317229270933,0.1270345026932565,95000000,10761.48148536682,0.1220445775854512,0.1246625911546021,83274637 -9472.194302797318,0.3567848205566406,1944.9548075199127,2868,0,1944.9548075199127,0.1273176975534539,95000000,11417.6058242321,0.1214646557756002,0.1249482107328564,83274637 -10002.52248263359,0.3787112236022949,2065.510046243668,3043,0,2065.510046243668,0.1271934990028783,95000000,12068.517102956772,0.12254764501638,0.124845981801514,83274637 -10540.870343446732,0.3999860286712646,2185.699591398239,3222,0,2185.699591398239,0.1269839272203947,95000000,12727.081996917725,0.1253731692760433,0.1247920694588914,83274637 -11097.06042098999,0.4221045970916748,2306.06552362442,3400,0,2306.06552362442,0.1270944154913651,95000000,13403.666176319122,0.1233362405086463,0.1247953959308181,83274637 -11633.53394317627,0.4463634490966797,2426.2269320487976,3578,0,2426.2269320487976,0.1270553401521381,95000000,14060.331575393677,0.123385631496895,0.1247588130247568,83274637 -12150.217435836792,0.4680953025817871,2546.5644059181213,3758,0,2546.5644059181213,0.12705738671875,95000000,14697.38048863411,0.1234821093930575,0.124734425605144,83274637 -12683.30470943451,0.4939754009246826,2666.656278610229,3941,0,2666.656278610229,0.1272678745271381,95000000,15350.591654062271,0.12291515066798,0.124811235773253,83274637 -13213.176526546478,0.5162239074707031,2786.872770547867,4121,0,2786.872770547867,0.1268679931537829,95000000,16000.70842075348,0.1219303092960291,0.1244560978054758,83274637 -13770.963485240936,0.538172721862793,2907.630640745163,4300,0,2907.630640745163,0.1269428061163651,95000000,16679.281347990036,0.1210334593980754,0.1244571152936341,83274637 -14329.13000202179,0.5602178573608398,3027.781261205673,4478,0,3027.781261205673,0.1268535073499177,95000000,17357.626393795013,0.1224394360404906,0.1244980308549427,83274637 -14874.075819253922,0.5834667682647705,3147.901193380356,4653,0,3147.901193380356,0.1268224839432565,95000000,18022.7213242054,0.1207981519166778,0.1244231657932092,83274637 -15435.684971809387,0.6086556911468506,3268.092779636383,4830,0,3268.092779636383,0.1269407832648026,95000000,18704.55334186554,0.1231741636710346,0.1244902192766388,83274637 -15985.845247745514,0.6302371025085449,3388.683220624924,5007,0,3388.683220624924,0.1266421239514802,95000000,19375.33169746399,0.1245397203045446,0.1243035075815423,83274637 -16534.02060341835,0.6558964252471924,3508.74847984314,5185,0,3508.74847984314,0.1267856268503289,95000000,20043.603875160217,0.1211312281776149,0.1243408592224259,83274637 -17077.79897761345,0.6779756546020508,3628.8643069267273,5364,0,3628.8643069267273,0.1267167002158717,95000000,20707.52645087242,0.1234937511126762,0.1243307710620822,83274637 -17617.085867881775,0.7020981311798096,3749.122021913528,5544,0,3749.122021913528,0.1266714173108552,95000000,21367.10143470764,0.1238525861042475,0.1242785482239253,83274637 -18177.610891342163,0.7237498760223389,3870.0109403133392,5726,0,3870.0109403133392,0.1268647591180098,95000000,22048.54324698448,0.1218163539544216,0.1244422784799541,83274637 -18728.214420557026,0.7506492137908936,3990.370181083679,5907,0,3990.370181083679,0.1267648204872533,95000000,22719.53905010224,0.1214116878877833,0.1243632209652411,83274637 -19273.00504016876,0.777338981628418,4110.908145189285,6086,0,4110.908145189285,0.1265344971731085,95000000,23384.900443792343,0.1218489267564607,0.1240728441670698,83274637 -19816.35657382012,0.8003125190734863,4231.136974811554,6262,0,4231.136974811554,0.1262730346114309,95000000,24048.50984811783,0.1221851150757111,0.1239943200612698,83274637 -20364.616135120392,0.8278322219848633,4351.401541948319,6438,0,4351.401541948319,0.1263620978207237,95000000,24717.06759095192,0.1212437743051621,0.1240503758843689,83274637 -20914.94091773033,0.8502209186553955,4471.631459951401,6613,0,4471.631459951401,0.1262541393914473,95000000,25387.65100812912,0.1240159423447817,0.1238895284264343,83274637 -21458.45505642891,0.8717246055603027,4591.935922861099,6794,0,4591.935922861099,0.1262963168893914,95000000,26051.4972178936,0.1209843989883391,0.1239240574639659,83274637 -21990.9538500309,0.898688554763794,4712.108925819397,6972,0,4712.108925819397,0.126148935598273,95000000,26704.202216148376,0.1233273534889116,0.123803586259741,83274637 -22539.277502298355,0.9209742546081544,4832.4965353012085,7157,0,4832.4965353012085,0.1263409345908717,95000000,27372.942261219025,0.120873235622948,0.1240408316051579,83274637 -23088.10276460648,0.9437015056610109,4952.603621482849,7335,0,4952.603621482849,0.1262884487972861,95000000,28041.903530597687,0.1243092920748317,0.1238773545528653,83274637 -23624.24707388878,0.9672322273254396,5072.642145395279,7514,0,5072.642145395279,0.1261816726870888,95000000,28698.11584186554,0.1212602298613051,0.1238637028004528,83274637 -24162.876575231552,0.9906661510467528,5193.1017372608185,7704,0,5193.1017372608185,0.1263341030838815,95000000,29357.23519897461,0.1217776046093529,0.1238912718399649,83274637 -24697.59641933441,1.0166869163513184,5313.213019609451,7883,0,5313.213019609451,0.1261886244757401,95000000,30012.0986905098,0.1231056984467139,0.1238406321892012,83274637 -25231.72269630432,1.04118013381958,5433.411878347397,8068,0,5433.411878347397,0.1262927160773026,95000000,30666.45467376709,0.1232462159398966,0.123903799118707,83274637 -25781.360393047333,1.0681800842285156,5553.558212995529,8245,0,5553.558212995529,0.1260936042557565,95000000,31336.271854400635,0.1224056126420977,0.1237633038631558,83274637 -26313.62245297432,1.0924384593963623,5673.848515033722,8422,0,5673.848515033722,0.1261293250411184,95000000,31988.85461974144,0.122629171401356,0.1237926062368335,83274637 -26858.971853733063,1.1167902946472168,5794.33474445343,8597,0,5794.33474445343,0.12600051003289472,95000000,32654.720746278763,0.12221383043336419,0.1236824945735541,83274637 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 2414a8227..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,137 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.4073896,0.70808744,,,,,,,,,,, -1,,,0.7075171364740755,0.7065163343942286,83274637.0,0.7073875930098684,95000000.0,20.41549754142761,801.339747428894,20.41549754142761,780.9242124557495,0.0,0.0 -100,0.10371736,0.12664557,,,,,,,,,,, -183,,,0.1271871388099103,0.128920479360426,83274637.0,0.1318383534745066,95000000.0,140.66957092285156,1552.8158721923828,140.66957092285156,1412.1124548912048,0.0275628566741943,0.0 -200,0.05960973,0.12568417,,,,,,,,,,, -300,0.010365392,0.1220967,,,,,,,,,,, -357,,,0.1301833171303729,0.1277769713631602,83274637.0,0.1304574425370065,95000000.0,261.05939412117004,2295.6186287403107,261.05939412117004,2034.498508453369,0.0482709407806396,0.0 -400,0.044213112,0.12639198,,,,,,,,,,, -500,0.02943876,0.13026261,,,,,,,,,,, -537,,,0.1262307041705404,0.1273617779289393,83274637.0,0.1297183601459704,95000000.0,381.33195638656616,3028.751186609268,381.33195638656616,2647.331680059433,0.069040298461914,0.0 -600,0.0057862224,0.12046848,,,,,,,,,,, -700,0.018406529,0.12634963,,,,,,,,,,, -714,,,0.1260030443469683,0.1265429995104162,83274637.0,0.1287761107319079,95000000.0,501.7907288074493,3751.5082404613495,501.7907288074493,3249.598855018616,0.0938754081726074,0.0 -800,0.03489539,0.12097598,,,,,,,,,,, -889,,,0.1266442178016376,0.1269329799259305,83274637.0,0.1293202573704769,95000000.0,621.9414341449738,4456.42281794548,621.9414341449738,3834.333093643189,0.1173727512359619,0.0 -900,0.0126120625,0.12574846,,,,,,,,,,, -1000,0.03339705,0.13155031,,,,,,,,,,, -1069,,,0.1271526435003528,0.1278378501184768,83274637.0,0.1301694638363487,95000000.0,742.3252635002136,5121.351807594299,742.3252635002136,4378.849170207977,0.1401965618133545,0.0 -1100,0.012081907,0.12793505,,,,,,,,,,, -1200,0.009121256,0.119793296,,,,,,,,,,, -1254,,,0.1242560849718327,0.125565447074868,83274637.0,0.1280532586759868,95000000.0,862.3552062511444,5773.757102012634,862.3552062511444,4911.195806980133,0.1622538566589355,0.0 -1300,0.03893512,0.13605224,,,,,,,,,,, -1400,0.04171889,0.13143443,,,,,,,,,,, -1434,,,0.1224136979629596,0.1257100016173366,83274637.0,0.1281122845497533,95000000.0,982.542487859726,6438.3228328228,982.542487859726,5455.546610832214,0.1836435794830322,0.0 -1500,0.03444402,0.12839966,,,,,,,,,,, -1600,0.013148673,0.12607083,,,,,,,,,,, -1609,,,0.1236820129455073,0.1255223536084451,83274637.0,0.1278192471011513,95000000.0,1102.916160583496,7096.895583152771,1102.916160583496,5993.718648195267,0.2044479846954345,0.0 -1700,0.024152048,0.12251246,,,,,,,,,,, -1785,,,0.1258179464219313,0.1255218474179224,83274637.0,0.1279375389494243,95000000.0,1223.486125946045,7659.149473905563,1223.486125946045,6435.375310897827,0.225440502166748,0.0 -1800,0.049747232,0.12641695,,,,,,,,,,, -1900,0.025512276,0.12431541,,,,,,,,,,, -1960,,,0.1241794411186714,0.1253777853219207,83274637.0,0.1277369629214638,95000000.0,1343.8530678749084,8228.014498233795,1343.8530678749084,6883.846434593201,0.2464847564697265,0.0 -2000,0.050720345,0.124862745,,,,,,,,,,, -2100,0.011547484,0.11715244,,,,,,,,,,, -2141,,,0.124650988493513,0.1250960292481521,83274637.0,0.1275420262438322,95000000.0,1464.0568370819092,8894.195020198822,1464.0568370819092,7429.795899629593,0.2674703598022461,0.0 -2200,0.0061434656,0.123105675,,,,,,,,,,, -2300,0.010326746,0.124776885,,,,,,,,,,, -2323,,,0.1263659366716941,0.1258524653687848,83274637.0,0.1281815174136513,95000000.0,1584.1165263652802,9556.212258338928,1584.1165263652802,7971.723219394684,0.2913937568664551,0.0 -2400,0.019478347,0.12167991,,,,,,,,,,, -2500,0.037597124,0.12351765,,,,,,,,,,, -2504,,,0.1238533474106646,0.1248805963113579,83274637.0,0.1272554857319079,95000000.0,1704.4217224121094,10149.593060016632,1704.4217224121094,8444.770316362381,0.3136165142059326,0.0 -2600,0.03035399,0.1382988,,,,,,,,,,, -2687,,,0.1220445775854512,0.1246625911546021,83274637.0,0.1270345026932565,95000000.0,1824.6317229270933,10761.48148536682,1824.6317229270933,8936.420797586441,0.3353207111358642,0.0 -2700,0.0175574,0.12324913,,,,,,,,,,, -2800,0.005698187,0.12901172,,,,,,,,,,, -2868,,,0.1214646557756002,0.1249482107328564,83274637.0,0.1273176975534539,95000000.0,1944.9548075199127,11417.6058242321,1944.9548075199127,9472.194302797318,0.3567848205566406,0.0 -2900,0.0075903023,0.11817081,,,,,,,,,,, -3000,0.013829896,0.1268838,,,,,,,,,,, -3043,,,0.12254764501638,0.124845981801514,83274637.0,0.1271934990028783,95000000.0,2065.510046243668,12068.517102956772,2065.510046243668,10002.52248263359,0.3787112236022949,0.0 -3100,0.009573645,0.13454546,,,,,,,,,,, -3200,0.025800468,0.12423524,,,,,,,,,,, -3222,,,0.1253731692760433,0.1247920694588914,83274637.0,0.1269839272203947,95000000.0,2185.699591398239,12727.081996917725,2185.699591398239,10540.870343446732,0.3999860286712646,0.0 -3300,0.010809552,0.11931986,,,,,,,,,,, -3400,,,0.1233362405086463,0.1247953959308181,83274637.0,0.1270944154913651,95000000.0,2306.06552362442,13403.666176319122,2306.06552362442,11097.06042098999,0.4221045970916748,0.0 -3400,0.00998079,0.12477943,,,,,,,,,,, -3500,0.014630846,0.12651917,,,,,,,,,,, -3578,,,0.123385631496895,0.1247588130247568,83274637.0,0.1270553401521381,95000000.0,2426.2269320487976,14060.331575393677,2426.2269320487976,11633.53394317627,0.4463634490966797,0.0 -3600,0.008018392,0.12197944,,,,,,,,,,, -3700,0.010332964,0.12469475,,,,,,,,,,, -3758,,,0.1234821093930575,0.124734425605144,83274637.0,0.12705738671875,95000000.0,2546.5644059181213,14697.38048863411,2546.5644059181213,12150.217435836792,0.4680953025817871,0.0 -3800,0.03475191,0.12731709,,,,,,,,,,, -3900,0.024368942,0.11535076,,,,,,,,,,, -3941,,,0.12291515066798,0.124811235773253,83274637.0,0.1272678745271381,95000000.0,2666.656278610229,15350.591654062271,2666.656278610229,12683.30470943451,0.4939754009246826,0.0 -4000,0.025217265,0.12937003,,,,,,,,,,, -4100,0.03938184,0.11982174,,,,,,,,,,, -4121,,,0.1219303092960291,0.1244560978054758,83274637.0,0.1268679931537829,95000000.0,2786.872770547867,16000.70842075348,2786.872770547867,13213.176526546478,0.5162239074707031,0.0 -4200,0.013316785,0.12111738,,,,,,,,,,, -4300,,,0.1210334593980754,0.1244571152936341,83274637.0,0.1269428061163651,95000000.0,2907.630640745163,16679.281347990036,2907.630640745163,13770.963485240936,0.538172721862793,0.0 -4300,0.0088409,0.12212373,,,,,,,,,,, -4400,0.021682113,0.12620604,,,,,,,,,,, -4478,,,0.1224394360404906,0.1244980308549427,83274637.0,0.1268535073499177,95000000.0,3027.781261205673,17357.626393795013,3027.781261205673,14329.13000202179,0.5602178573608398,0.0 -4500,0.005755223,0.13047132,,,,,,,,,,, -4600,0.012221657,0.11363322,,,,,,,,,,, -4653,,,0.1207981519166778,0.1244231657932092,83274637.0,0.1268224839432565,95000000.0,3147.901193380356,18022.7213242054,3147.901193380356,14874.075819253922,0.5834667682647705,0.0 -4700,0.0071486565,0.13246146,,,,,,,,,,, -4800,0.01860291,0.12056944,,,,,,,,,,, -4830,,,0.1231741636710346,0.1244902192766388,83274637.0,0.1269407832648026,95000000.0,3268.092779636383,18704.55334186554,3268.092779636383,15435.684971809387,0.6086556911468506,0.0 -4900,0.0067063975,0.12049677,,,,,,,,,,, -5000,0.011213513,0.11802733,,,,,,,,,,, -5007,,,0.1245397203045446,0.1243035075815423,83274637.0,0.1266421239514802,95000000.0,3388.683220624924,19375.33169746399,3388.683220624924,15985.845247745514,0.6302371025085449,0.0 -5100,0.009387051,0.12144977,,,,,,,,,,, -5185,,,0.1211312281776149,0.1243408592224259,83274637.0,0.1267856268503289,95000000.0,3508.74847984314,20043.603875160217,3508.74847984314,16534.02060341835,0.6558964252471924,0.0 -5200,0.026457766,0.12013461,,,,,,,,,,, -5300,0.012354995,0.1271243,,,,,,,,,,, -5364,,,0.1234937511126762,0.1243307710620822,83274637.0,0.1267167002158717,95000000.0,3628.8643069267273,20707.52645087242,3628.8643069267273,17077.79897761345,0.6779756546020508,0.0 -5400,0.018708335,0.1290006,,,,,,,,,,, -5500,0.007065444,0.1309199,,,,,,,,,,, -5544,,,0.1238525861042475,0.1242785482239253,83274637.0,0.1266714173108552,95000000.0,3749.122021913528,21367.10143470764,3749.122021913528,17617.085867881775,0.7020981311798096,0.0 -5600,0.0053921198,0.124913655,,,,,,,,,,, -5700,0.0062048803,0.1250586,,,,,,,,,,, -5726,,,0.1218163539544216,0.1244422784799541,83274637.0,0.1268647591180098,95000000.0,3870.0109403133392,22048.54324698448,3870.0109403133392,18177.610891342163,0.7237498760223389,0.0 -5800,0.0062707225,0.12336374,,,,,,,,,,, -5900,0.02586439,0.12583293,,,,,,,,,,, -5907,,,0.1214116878877833,0.1243632209652411,83274637.0,0.1267648204872533,95000000.0,3990.370181083679,22719.53905010224,3990.370181083679,18728.214420557026,0.7506492137908936,0.0 -6000,0.005126988,0.11462946,,,,,,,,,,, -6086,,,0.1218489267564607,0.1240728441670698,83274637.0,0.1265344971731085,95000000.0,4110.908145189285,23384.900443792343,4110.908145189285,19273.00504016876,0.777338981628418,0.0 -6100,0.01673371,0.13291736,,,,,,,,,,, -6200,0.0065839007,0.12734157,,,,,,,,,,, -6262,,,0.1221851150757111,0.1239943200612698,83274637.0,0.1262730346114309,95000000.0,4231.136974811554,24048.50984811783,4231.136974811554,19816.35657382012,0.8003125190734863,0.0 -6300,0.008832012,0.13061729,,,,,,,,,,, -6400,0.0065378207,0.12288116,,,,,,,,,,, -6438,,,0.1212437743051621,0.1240503758843689,83274637.0,0.1263620978207237,95000000.0,4351.401541948319,24717.06759095192,4351.401541948319,20364.616135120392,0.8278322219848633,0.0 -6500,0.006868472,0.11863662,,,,,,,,,,, -6600,0.011202815,0.12116524,,,,,,,,,,, -6613,,,0.1240159423447817,0.1238895284264343,83274637.0,0.1262541393914473,95000000.0,4471.631459951401,25387.65100812912,4471.631459951401,20914.94091773033,0.8502209186553955,0.0 -6700,0.011017051,0.121496685,,,,,,,,,,, -6794,,,0.1209843989883391,0.1239240574639659,83274637.0,0.1262963168893914,95000000.0,4591.935922861099,26051.4972178936,4591.935922861099,21458.45505642891,0.8717246055603027,0.0 -6800,0.010330223,0.11961944,,,,,,,,,,, -6900,0.005197394,0.11879745,,,,,,,,,,, -6972,,,0.1233273534889116,0.123803586259741,83274637.0,0.126148935598273,95000000.0,4712.108925819397,26704.202216148376,4712.108925819397,21990.9538500309,0.898688554763794,0.0 -7000,0.0053390833,0.11515863,,,,,,,,,,, -7100,0.0067385673,0.12391241,,,,,,,,,,, -7157,,,0.120873235622948,0.1240408316051579,83274637.0,0.1263409345908717,95000000.0,4832.4965353012085,27372.942261219025,4832.4965353012085,22539.277502298355,0.9209742546081544,0.0 -7200,0.014595989,0.12249733,,,,,,,,,,, -7300,0.007355165,0.12834284,,,,,,,,,,, -7335,,,0.1243092920748317,0.1238773545528653,83274637.0,0.1262884487972861,95000000.0,4952.603621482849,28041.903530597687,4952.603621482849,23088.10276460648,0.9437015056610109,0.0 -7400,0.012669442,0.12844905,,,,,,,,,,, -7500,0.008529365,0.12680563,,,,,,,,,,, -7514,,,0.1212602298613051,0.1238637028004528,83274637.0,0.1261816726870888,95000000.0,5072.642145395279,28698.11584186554,5072.642145395279,23624.24707388878,0.9672322273254396,0.0 -7600,0.011260989,0.12503873,,,,,,,,,,, -7700,0.011189284,0.11840763,,,,,,,,,,, -7704,,,0.1217776046093529,0.1238912718399649,83274637.0,0.1263341030838815,95000000.0,5193.1017372608185,29357.23519897461,5193.1017372608185,24162.876575231552,0.9906661510467528,0.0 -7800,0.021421434,0.117444076,,,,,,,,,,, -7883,,,0.1231056984467139,0.1238406321892012,83274637.0,0.1261886244757401,95000000.0,5313.213019609451,30012.0986905098,5313.213019609451,24697.59641933441,1.0166869163513184,0.0 -7900,0.010091254,0.119762056,,,,,,,,,,, -8000,0.007877991,0.11609673,,,,,,,,,,, -8068,,,0.1232462159398966,0.123903799118707,83274637.0,0.1262927160773026,95000000.0,5433.411878347397,30666.45467376709,5433.411878347397,25231.72269630432,1.04118013381958,0.0 -8100,0.007374876,0.12410062,,,,,,,,,,, -8200,0.0052226246,0.11777951,,,,,,,,,,, -8245,,,0.1224056126420977,0.1237633038631558,83274637.0,0.1260936042557565,95000000.0,5553.558212995529,31336.271854400635,5553.558212995529,25781.360393047333,1.0681800842285156,0.0 -8300,0.008379875,0.115572095,,,,,,,,,,, -8400,0.013205564,0.12279713,,,,,,,,,,, -8422,,,0.122629171401356,0.1237926062368335,83274637.0,0.1261293250411184,95000000.0,5673.848515033722,31988.85461974144,5673.848515033722,26313.62245297432,1.0924384593963623,0.0 -8500,0.0084199775,0.1229531,,,,,,,,,,, -8597,,,0.1222138304333641,0.1236824945735541,83274637.0,0.1260005100328947,95000000.0,5794.33474445343,32654.720746278763,5794.33474445343,26858.971853733063,1.1167902946472168,0.0 -8597,,,,,,,,5794.33474445343,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index b701fbe23..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,20 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -192.9060909748077,0.0,55.87105417251587,1,0,55.87105417251587,1.1173639412000838,3581,0.2186063006416242,248.77755641937256,1.120415415082659,0.1986473628452846,1.1201521530537777,3554,0.196472271010657 -197.26366686820984,0.0277307033538818,136.2571702003479,342,0,136.2571702003479,0.3142570830097214,3581,0.7130424544252653,333.5596721172333,0.291083744594029,0.7166670390537807,0.3119649033769168,3554,0.6954936476725169 -201.28857111930847,0.0647187232971191,216.53292417526245,598,0,216.53292417526245,0.3039552485666538,3581,0.7237157834971027,417.9049689769745,0.2812513794217791,0.727783203125,0.301936109346423,3554,0.7064778462647721 -205.31412816047668,0.0955805778503418,296.5067882537842,857,0,296.5067882537842,0.2993538011488585,3581,0.7282441455337196,501.9431965351105,0.2770792245864868,0.7322325706481934,0.2974483251705824,3554,0.7111753207213702 -209.34389734268188,0.1279308795928955,376.4983882904053,1167,0,376.4983882904053,0.2972662317722878,3581,0.7305277228122382,586.0067405700684,0.2743423836571829,0.7352104187011719,0.295615999632896,3554,0.7133506040728756 -213.37495374679563,0.1561641693115234,456.6143398284912,1518,0,456.6143398284912,0.294206463202056,3581,0.7327835521895071,670.193336725235,0.2720884595598493,0.7365893636431012,0.2925280397637345,3554,0.7156296162774338 -217.40148663520813,0.1792163848876953,536.6856021881104,1868,0,536.6856021881104,0.2959254354579726,3581,0.7319947482154077,754.3252778053284,0.274249792098999,0.7352070127214704,0.2943375930399726,3554,0.7147479210264842 -221.4318311214447,0.2015833854675293,616.8693685531616,2213,0,616.8693685531616,0.2936719240871963,3581,0.7346562969055431,838.5729329586029,0.2714700869151524,0.7384180341448102,0.2921553715333955,3554,0.7174992086381542 -225.4610998630524,0.2239062786102295,697.048171043396,2565,0,697.048171043396,0.2934958919470818,3581,0.7338752650708601,922.8145685195924,0.2710255725043161,0.7383323396955218,0.291980406375299,3554,0.7165508796206739 -229.491126537323,0.250126838684082,777.1508128643036,2914,0,777.1508128643036,0.2938542625685039,3581,0.7338329955407009,1006.9848296642303,0.2712868962969099,0.7385595185416085,0.2922135558635516,3554,0.7167994167003728 -233.51832556724548,0.2741034030914306,857.2717974185944,3263,0,857.2717974185944,0.293154190519408,3581,0.7363712808267593,1091.1679165363312,0.2709583214351109,0.740112168448312,0.291687836054006,3554,0.7193421474922622 -237.5518708229065,0.2962205410003662,937.3469772338868,3613,0,937.3469772338868,0.2902872937246928,3581,0.7388256406424533,1175.3094856739044,0.2678266423089163,0.7433418546404157,0.2887829129589899,3554,0.7216860074036298 -241.58291339874268,0.3197650909423828,1017.3912198543547,3962,0,1017.3912198543547,0.2908454560527785,3581,0.7370695461943242,1259.4191937446594,0.2685698270797729,0.741295405796596,0.2893246729587261,3554,0.7198487702017093 -245.614670753479,0.3463306427001953,1097.4083876609802,4312,0,1097.4083876609802,0.290724067506894,3581,0.7388518886571488,1343.5054585933683,0.2683579410825457,0.7433052062988281,0.2891037167628025,3554,0.7217283232801069 -249.64495658874512,0.36983323097229,1177.465470790863,4660,0,1177.465470790863,0.2899952249066252,3581,0.7398959460520804,1427.6272311210632,0.2671342747552054,0.7444005693708148,0.2884154311713034,3554,0.72267816357889 -253.6764862537384,0.3961493968963623,1257.4558582305908,5008,0,1257.4558582305908,0.2896336840704412,3581,0.738936155010821,1511.6863014698029,0.2669192041669573,0.7439136505126953,0.2881101523481464,3554,0.7216956246482836 -257.7066810131073,0.4196066856384277,1337.4680411815643,5360,0,1337.4680411815643,0.2888436529164339,3581,0.7394307766903448,1595.7629663944244,0.2663644892828805,0.7439980506896973,0.2874547714722847,3554,0.7221791659705613 -261.7404539585113,0.4434881210327148,1417.6038916110992,5710,0,1417.6038916110992,0.2892508721158545,3581,0.7395519947945756,1679.9671909809113,0.2666572162083217,0.7441534314836774,0.2878030702916168,3554,0.7222601569094682 -265.7671477794647,0.46773791313171387,1497.594055891037,6061,0,1497.594055891037,0.2888793433987538,3581,0.7412142781564158,1764.0194532871246,0.26613690171922955,0.7459335327148438,0.2874238760738341,3554,0.7240817317415237 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 03b3efd6c..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,82 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.588453,1.1324757,,,,,,,,,,,,,, -1,,,0.1986473628452846,1.120415415082659,0.196472271010657,1.1201521530537777,3554.0,0.2186063006416242,1.1173639412000838,3581.0,55.87105417251587,248.77755641937256,55.87105417251587,192.9060909748077,0.0,0.0 -100,0.21883173,0.23392078,,,,,,,,,,,,,, -200,0.17892344,0.29985616,,,,,,,,,,,,,, -300,0.1930982,0.3390982,,,,,,,,,,,,,, -342,,,0.7166670390537807,0.291083744594029,0.6954936476725169,0.3119649033769168,3554.0,0.7130424544252653,0.3142570830097214,3581.0,136.2571702003479,333.5596721172333,136.2571702003479,197.26366686820984,0.0277307033538818,0.0 -400,0.60579085,0.32581896,,,,,,,,,,,,,, -500,0.4643697,0.26526892,,,,,,,,,,,,,, -598,,,0.727783203125,0.2812513794217791,0.7064778462647721,0.301936109346423,3554.0,0.7237157834971027,0.3039552485666538,3581.0,216.53292417526245,417.9049689769745,216.53292417526245,201.28857111930847,0.0647187232971191,0.0 -600,0.10041086,0.32968396,,,,,,,,,,,,,, -700,0.35001785,0.2589087,,,,,,,,,,,,,, -800,0.12324807,0.3526528,,,,,,,,,,,,,, -857,,,0.7322325706481934,0.2770792245864868,0.7111753207213702,0.2974483251705824,3554.0,0.7282441455337196,0.2993538011488585,3581.0,296.5067882537842,501.9431965351105,296.5067882537842,205.31412816047668,0.0955805778503418,0.0 -900,0.14609452,0.32041368,,,,,,,,,,,,,, -1000,0.21213472,0.30569106,,,,,,,,,,,,,, -1100,0.22185993,0.27031824,,,,,,,,,,,,,, -1167,,,0.7352104187011719,0.2743423836571829,0.7133506040728756,0.295615999632896,3554.0,0.7305277228122382,0.2972662317722878,3581.0,376.4983882904053,586.0067405700684,376.4983882904053,209.34389734268188,0.1279308795928955,0.0 -1200,0.119448185,0.2897671,,,,,,,,,,,,,, -1300,0.30554202,0.2494826,,,,,,,,,,,,,, -1400,0.2108942,0.33948675,,,,,,,,,,,,,, -1500,0.42780977,0.29525688,,,,,,,,,,,,,, -1518,,,0.7365893636431012,0.2720884595598493,0.7156296162774338,0.2925280397637345,3554.0,0.7327835521895071,0.294206463202056,3581.0,456.6143398284912,670.193336725235,456.6143398284912,213.37495374679563,0.1561641693115234,0.0 -1600,0.34755382,0.33390707,,,,,,,,,,,,,, -1700,0.04429244,0.46484616,,,,,,,,,,,,,, -1800,0.08863003,0.37838164,,,,,,,,,,,,,, -1868,,,0.7352070127214704,0.274249792098999,0.7147479210264842,0.2943375930399726,3554.0,0.7319947482154077,0.2959254354579726,3581.0,536.6856021881104,754.3252778053284,536.6856021881104,217.40148663520813,0.1792163848876953,0.0 -1900,0.059442464,0.2836342,,,,,,,,,,,,,, -2000,0.18554017,0.31605756,,,,,,,,,,,,,, -2100,0.061351396,0.28339216,,,,,,,,,,,,,, -2200,0.04539597,0.3050909,,,,,,,,,,,,,, -2213,,,0.7384180341448102,0.2714700869151524,0.7174992086381542,0.2921553715333955,3554.0,0.7346562969055431,0.2936719240871963,3581.0,616.8693685531616,838.5729329586029,616.8693685531616,221.4318311214447,0.2015833854675293,0.0 -2300,0.11329267,0.28503564,,,,,,,,,,,,,, -2400,0.1262973,0.2601017,,,,,,,,,,,,,, -2500,0.095980935,0.24639049,,,,,,,,,,,,,, -2565,,,0.7383323396955218,0.2710255725043161,0.7165508796206739,0.291980406375299,3554.0,0.7338752650708601,0.2934958919470818,3581.0,697.048171043396,922.8145685195924,697.048171043396,225.4610998630524,0.2239062786102295,0.0 -2600,0.09259839,0.29258057,,,,,,,,,,,,,, -2700,0.06782509,0.32886493,,,,,,,,,,,,,, -2800,0.14498349,0.2512075,,,,,,,,,,,,,, -2900,0.104618706,0.2386779,,,,,,,,,,,,,, -2914,,,0.7385595185416085,0.2712868962969099,0.7167994167003728,0.2922135558635516,3554.0,0.7338329955407009,0.2938542625685039,3581.0,777.1508128643036,1006.9848296642303,777.1508128643036,229.491126537323,0.250126838684082,0.0 -3000,0.06842083,0.2971101,,,,,,,,,,,,,, -3100,0.070959836,0.23243582,,,,,,,,,,,,,, -3200,0.08253879,0.24771014,,,,,,,,,,,,,, -3263,,,0.740112168448312,0.2709583214351109,0.7193421474922622,0.291687836054006,3554.0,0.7363712808267593,0.293154190519408,3581.0,857.2717974185944,1091.1679165363312,857.2717974185944,233.51832556724548,0.2741034030914306,0.0 -3300,0.3086398,0.29279256,,,,,,,,,,,,,, -3400,0.06330595,0.27320582,,,,,,,,,,,,,, -3500,0.076649144,0.35211748,,,,,,,,,,,,,, -3600,0.15413111,0.3307109,,,,,,,,,,,,,, -3613,,,0.7433418546404157,0.2678266423089163,0.7216860074036298,0.2887829129589899,3554.0,0.7388256406424533,0.2902872937246928,3581.0,937.3469772338868,1175.3094856739044,937.3469772338868,237.5518708229065,0.2962205410003662,0.0 -3700,0.110814504,0.27827772,,,,,,,,,,,,,, -3800,0.05030259,0.29488188,,,,,,,,,,,,,, -3900,0.07571727,0.31337655,,,,,,,,,,,,,, -3962,,,0.741295405796596,0.2685698270797729,0.7198487702017093,0.2893246729587261,3554.0,0.7370695461943242,0.2908454560527785,3581.0,1017.3912198543547,1259.4191937446594,1017.3912198543547,241.58291339874268,0.3197650909423828,0.0 -4000,0.2250599,0.2209763,,,,,,,,,,,,,, -4100,0.1908394,0.27058634,,,,,,,,,,,,,, -4200,0.09590618,0.23519532,,,,,,,,,,,,,, -4300,0.11774375,0.30664387,,,,,,,,,,,,,, -4312,,,0.7433052062988281,0.2683579410825457,0.7217283232801069,0.2891037167628025,3554.0,0.7388518886571488,0.290724067506894,3581.0,1097.4083876609802,1343.5054585933683,1097.4083876609802,245.614670753479,0.3463306427001953,0.0 -4400,0.086698756,0.3276042,,,,,,,,,,,,,, -4500,0.11501223,0.21601894,,,,,,,,,,,,,, -4600,0.03737293,0.25340077,,,,,,,,,,,,,, -4660,,,0.7444005693708148,0.2671342747552054,0.72267816357889,0.2884154311713034,3554.0,0.7398959460520804,0.2899952249066252,3581.0,1177.465470790863,1427.6272311210632,1177.465470790863,249.64495658874512,0.36983323097229,0.0 -4700,0.0693346,0.3398607,,,,,,,,,,,,,, -4800,0.15143667,0.22684279,,,,,,,,,,,,,, -4900,0.06786488,0.30162227,,,,,,,,,,,,,, -5000,0.073945805,0.25515842,,,,,,,,,,,,,, -5008,,,0.7439136505126953,0.2669192041669573,0.7216956246482836,0.2881101523481464,3554.0,0.738936155010821,0.2896336840704412,3581.0,1257.4558582305908,1511.6863014698029,1257.4558582305908,253.6764862537384,0.3961493968963623,0.0 -5100,0.090524144,0.24579582,,,,,,,,,,,,,, -5200,0.16392428,0.24430388,,,,,,,,,,,,,, -5300,0.16353379,0.26602754,,,,,,,,,,,,,, -5360,,,0.7439980506896973,0.2663644892828805,0.7221791659705613,0.2874547714722847,3554.0,0.7394307766903448,0.2888436529164339,3581.0,1337.4680411815643,1595.7629663944244,1337.4680411815643,257.7066810131073,0.4196066856384277,0.0 -5400,0.11574315,0.34114978,,,,,,,,,,,,,, -5500,0.28286842,0.21828038,,,,,,,,,,,,,, -5600,0.069128506,0.25325805,,,,,,,,,,,,,, -5700,0.059760667,0.31716117,,,,,,,,,,,,,, -5710,,,0.7441534314836774,0.2666572162083217,0.7222601569094682,0.2878030702916168,3554.0,0.7395519947945756,0.2892508721158545,3581.0,1417.6038916110992,1679.9671909809113,1417.6038916110992,261.7404539585113,0.4434881210327148,0.0 -5800,0.18202074,0.39349112,,,,,,,,,,,,,, -5900,0.07162934,0.24536075,,,,,,,,,,,,,, -6000,0.34127748,0.2136616,,,,,,,,,,,,,, -6061,,,0.7459335327148438,0.2661369017192295,0.7240817317415237,0.2874238760738341,3554.0,0.7412142781564158,0.2888793433987538,3581.0,1497.594055891037,1764.0194532871246,1497.594055891037,265.7671477794647,0.4677379131317138,0.0 -6061,,,,,,,,,,,1497.594055891037,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 01dc3598d..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,372 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -37.36424398422241,0.0,52.38156604766846,1,0,52.38156604766846,0.0009000000427477,6.913362979888916,10000,89.74589896202087,0.0012356505030766,6.911959171295166,0.0010799999581649,6.912962913513184,50000 -55.82190465927124,0.0253584384918212,562.5784339904785,1509,0,562.5784339904785,0.125900000333786,4.813449382781982,10000,618.4830918312073,0.1795679181814193,4.181913375854492,0.1612999886274337,4.331301212310791,50000 -74.00124216079712,0.0526726245880126,1072.4957585334778,3016,0,1072.4957585334778,0.2266000062227249,3.981245994567871,10000,1146.6672384738922,0.3254344761371612,3.1697702407836914,0.2942200005054474,3.366985559463501,50000 -92.29433751106262,0.0788233280181884,1582.6927065849304,4525,0,1582.6927065849304,0.3230000138282776,3.316889524459839,10000,1675.2411172389984,0.4550183117389679,2.387629985809326,0.4248600006103515,2.5835626125335693,50000 -110.38658118247986,0.106652021408081,2092.691743612289,6035,0,2092.691743612289,0.3857000172138214,2.9400079250335693,10000,2203.4167931079865,0.5336614847183228,1.9763846397399905,0.4961600005626678,2.1787710189819336,50000 -128.57271671295166,0.1335184574127197,2602.7501418590546,7545,0,2602.7501418590546,0.415800005197525,2.75813889503479,10000,2731.746652841568,0.5792012214660645,1.7531155347824097,0.5401399731636047,1.974346160888672,50000 -146.66174578666687,0.1640608310699463,3112.906903028488,9057,0,3112.906903028488,0.4338000118732452,2.645254135131836,10000,3260.0810463428497,0.6060865521430969,1.620163083076477,0.5582199692726135,1.88018000125885,50000 -165.1640853881836,0.1944692134857177,3622.962684154512,10568,0,3622.962684154512,0.4500000178813934,2.5940630435943604,10000,3788.7260398864746,0.6356824040412903,1.4726837873458862,0.5694800019264221,1.8395023345947263,50000 -185.43234658241272,0.2300171852111816,4133.085363149643,12080,0,4133.085363149643,0.4550000131130218,2.581552028656006,10000,4319.210289001465,0.6251793503761292,1.5146089792251587,0.5725199580192566,1.817832112312317,50000 -205.85814785957336,0.2600185871124267,4643.2337856292725,13592,0,4643.2337856292725,0.4630000293254852,2.468202829360962,10000,4849.872251987457,0.6494539380073547,1.422621726989746,0.5925599932670593,1.707416534423828,50000 -226.37191343307487,0.2892529964447021,5153.15531373024,15104,0,5153.15531373024,0.4723000228404999,2.4483578205108643,10000,5380.393999576569,0.6466438174247742,1.4184733629226685,0.5945199728012085,1.7023662328720093,50000 -246.3868336677552,0.3350415229797363,5663.251859664917,16616,0,5663.251859664917,0.4771000146865845,2.402414321899414,10000,5910.609344005585,0.6574457883834839,1.3805900812149048,0.6076599955558777,1.6330012083053589,50000 -270.5326895713806,0.364840030670166,6173.184433221817,18128,0,6173.184433221817,0.4830000102519989,2.4177982807159424,10000,6444.77542424202,0.6847695708274841,1.253440260887146,0.6041199564933777,1.6582300662994385,50000 -294.9985795021057,0.404205322265625,6683.125174283981,19640,0,6683.125174283981,0.4882000088691711,2.403407573699951,10000,6979.279679298401,0.6862244606018066,1.2184010744094849,0.6089800000190735,1.6390219926834106,50000 -319.0235517024994,0.4361577033996582,7193.361950397492,21154,0,7193.361950397492,0.4700000286102295,2.498812675476074,10000,7513.631420373917,0.6550542116165161,1.37153160572052,0.5969399809837341,1.7031642198562622,50000 -343.1838707923889,0.4693658351898193,7703.452437400818,22667,0,7703.452437400818,0.4869000315666199,2.3287904262542725,10000,8047.973657846451,0.6710379123687744,1.3040506839752195,0.6122399568557739,1.612977147102356,50000 -365.3846924304962,0.4949944019317627,8213.370545864105,24180,0,8213.370545864105,0.5056000351905823,2.2998135089874268,10000,8580.17731165886,0.6864636540412903,1.2356112003326416,0.6240800023078918,1.556673288345337,50000 -389.8584134578705,0.5229532718658447,8723.534445524216,25693,0,8723.534445524216,0.4832000136375427,2.368893623352051,10000,9114.901677131653,0.6590800285339355,1.3417572975158691,0.6096000075340271,1.6214866638183594,50000 -413.8295383453369,0.5503711700439453,9233.515226125715,27205,0,9233.515226125715,0.49590003490448,2.290503978729248,10000,9648.93862605095,0.7247090339660645,1.0631380081176758,0.6251999735832214,1.5556570291519165,50000 -438.5450539588928,0.5776462554931641,9743.726099729538,28718,0,9743.726099729538,0.4818000197410583,2.3851914405822754,10000,10183.951251506804,0.6881178021430969,1.2257320880889893,0.6116399765014648,1.6099690198898315,50000 -469.4787917137146,0.6060595512390137,10253.642538785934,30231,0,10253.642538785934,0.4962000250816345,2.30889105796814,10000,10724.887073040009,0.69140625,1.2178308963775637,0.6235600113868713,1.5551875829696655,50000 -494.9138953685761,0.6445157527923584,10763.705124855042,31744,0,10763.705124855042,0.4924000203609466,2.371462106704712,10000,11260.48165845871,0.6746252775192261,1.277440309524536,0.6144999861717224,1.6090742349624634,50000 -519.7346889972687,0.6725142002105713,11273.895986318588,33257,0,11273.895986318588,0.5009000301361084,2.308980941772461,10000,11795.578994750977,0.6889747977256775,1.2175207138061523,0.6301999688148499,1.5399216413497925,50000 -544.9519906044006,0.7002818584442139,11783.841839790344,34770,0,11783.841839790344,0.5055000185966492,2.2644104957580566,10000,12330.82792854309,0.6912468075752258,1.2342867851257324,0.6308000087738037,1.5375585556030271,50000 -570.0424513816833,0.7297089099884033,12293.991918563845,36284,0,12293.991918563845,0.4988000094890594,2.274627685546875,10000,12866.157500267029,0.7382612824440002,1.002962589263916,0.6293799877166748,1.5438278913497925,50000 -592.3089263439178,0.7659671306610107,12803.992875099182,37797,0,12803.992875099182,0.5065000057220459,2.259315729141236,10000,13398.519066810608,0.7053372263908386,1.138272404670715,0.6288599967956543,1.5252172946929932,50000 -613.0463311672211,0.7976090908050537,13313.982473373411,39311,0,13313.982473373411,0.5081000328063965,2.251711368560791,10000,13929.336554050446,0.6939771771430969,1.196521520614624,0.6310999989509583,1.5323960781097412,50000 -631.3619570732117,0.8261539936065674,13824.217654943466,40825,0,13824.217654943466,0.5074000358581543,2.224247932434082,10000,14457.974243879318,0.7027463316917419,1.174934148788452,0.6353999972343445,1.4998513460159302,50000 -650.3616156578064,0.861600399017334,14334.139183282852,42339,0,14334.139183282852,0.5037000179290771,2.296604633331299,10000,14986.990688800812,0.6916852593421936,1.219604253768921,0.6343799829483032,1.524438977241516,50000 -671.3503079414368,0.8955569267272949,14844.313416957855,43853,0,14844.313416957855,0.4993000328540802,2.297249317169189,10000,15518.24677681923,0.687898576259613,1.2292335033416748,0.631060004234314,1.523308038711548,50000 -690.5901215076447,0.9257237911224364,15354.307039022446,45367,0,15354.307039022446,0.5103000402450562,2.2528679370880127,10000,16047.569056749344,0.7377431392669678,0.9952830076217652,0.6425999999046326,1.4854003190994265,50000 -708.3193356990814,0.9613451957702636,15864.34359741211,46881,0,15864.34359741211,0.5138000249862671,2.251056432723999,10000,16575.43159174919,0.7153220772743225,1.084122657775879,0.6393199563026428,1.4918371438980105,50000 -726.0973536968231,0.9971561431884766,16374.491770744324,48395,0,16374.491770744324,0.5082000494003296,2.254747152328491,10000,17103.453681230545,0.702168345451355,1.1595250368118286,0.6326000094413757,1.5194488763809204,50000 -744.994348526001,1.0333445072174072,16884.507195949554,49908,0,16884.507195949554,0.5149000287055969,2.219771146774292,10000,17632.462087869644,0.7033841013908386,1.1538665294647217,0.6426599621772766,1.476432204246521,50000 -763.3682973384857,1.068270206451416,17394.70106601715,51422,0,17394.70106601715,0.5091000199317932,2.231141805648804,10000,18161.12567853928,0.7069913744926453,1.1432219743728638,0.6446399688720703,1.4551054239273071,50000 -781.212443113327,1.1037437915802002,17904.66800928116,52935,0,17904.66800928116,0.5148000121116638,2.219569206237793,10000,18689.03200531006,0.696707546710968,1.1853846311569214,0.6376000046730042,1.4835251569747925,50000 -798.8383796215057,1.1416583061218262,18414.68908238411,54449,0,18414.68908238411,0.5045000314712524,2.2475028038024902,10000,19216.77681660652,0.7293726205825806,1.033711075782776,0.6370999813079834,1.4978829622268677,50000 -816.3726131916046,1.1773545742034912,18924.828382968903,55962,0,18924.828382968903,0.5263000130653381,2.128363609313965,10000,19744.545511245728,0.7280173897743225,1.0381181240081787,0.6498199701309204,1.4298878908157349,50000 -834.9931175708771,1.2265896797180176,19434.806651115417,57476,0,19434.806651115417,0.5209000110626221,2.188610553741455,10000,20273.25348353386,0.7173150181770325,1.0886608362197876,0.6471399664878845,1.4522713422775269,50000 -856.2516252994537,1.2634735107421875,19944.79716658592,58989,0,19944.79716658592,0.526900053024292,2.157479286193848,10000,20804.59820485115,0.7143853306770325,1.1113851070404053,0.6462999582290649,1.4568002223968506,50000 -873.8442769050598,1.2946317195892334,20454.971937417984,60504,0,20454.971937417984,0.5327000021934509,2.102976560592652,10000,21332.455352783203,0.7247688174247742,1.0654748678207395,0.6577799916267395,1.3966306447982788,50000 -891.4071831703186,1.3386225700378418,20964.876480817795,62017,0,20964.876480817795,0.5214000344276428,2.1913771629333496,10000,21860.027040958405,0.7164580821990967,1.0905795097351074,0.6516599655151367,1.424889326095581,50000 -911.6735122203828,1.3813552856445312,21474.96597290039,63531,0,21474.96597290039,0.525700032711029,2.16554856300354,10000,22390.485930919647,0.7502989172935486,0.9399452805519104,0.6513599753379822,1.4273583889007568,50000 -930.4063663482666,1.414006233215332,21985.078320980072,65045,0,21985.078320980072,0.5302000045776367,2.1050965785980225,10000,22919.42360758781,0.7389987111091614,0.9976285696029664,0.6594399809837341,1.386866569519043,50000 -948.8673043251038,1.452296018600464,22495.0784740448,66559,0,22495.0784740448,0.5330000519752502,2.112121820449829,10000,23447.982553720474,0.7310267686843872,1.0304648876190186,0.6547399759292603,1.406494140625,50000 -966.5051369667052,1.4912090301513672,23005.171627759933,68072,0,23005.171627759933,0.5356000065803528,2.10113525390625,10000,23975.81109070778,0.7293127775192261,1.0425559282302856,0.6594399809837341,1.407362461090088,50000 -983.824857711792,1.5330331325531006,23515.16109251976,69586,0,23515.16109251976,0.5320000052452087,2.1139161586761475,10000,24503.22127485276,0.7391382455825806,1.0020445585250854,0.6672199964523315,1.3593355417251587,50000 -1002.3198018074036,1.5721306800842283,24025.32474899292,71100,0,24025.32474899292,0.5295000076293945,2.145573139190674,10000,25031.97860717773,0.7205436825752258,1.0698813199996948,0.6531199812889099,1.4194097518920898,50000 -1019.9946706295012,1.6126031875610352,24535.45041823387,72614,0,24535.45041823387,0.5308000445365906,2.152554750442505,10000,25559.87971568108,0.74320387840271,0.9673858880996704,0.6480799913406372,1.4243541955947876,50000 -1037.6091718673706,1.6538634300231934,25045.385497808456,74128,0,25045.385497808456,0.5333000421524048,2.12503719329834,10000,26087.53072452545,0.7434629797935486,0.981106460094452,0.6617599725723267,1.38577401638031,50000 -1058.160115480423,1.6960022449493408,25555.467805862427,75642,0,25555.467805862427,0.5353000164031982,2.1104061603546143,10000,26618.26466965676,0.7388990521430969,0.9926463961601256,0.6643999814987183,1.3834563493728638,50000 -1075.3721516132357,1.735579490661621,26065.52658700943,77156,0,26065.52658700943,0.5426000356674194,2.0834763050079346,10000,27145.634598970413,0.7443000674247742,0.968315601348877,0.6729399561882019,1.3414160013198853,50000 -1092.6724405288696,2.6665658950805664,26574.551361083984,78668,0,26574.551361083984,0.5336000323295593,2.1245603561401367,10000,27672.950965881348,0.73046875,1.0355678796768188,0.6589999794960022,1.397053837776184,50000 -1109.7548336982727,2.7132747173309326,27084.58042907715,80182,0,27084.58042907715,0.5433000326156616,2.059293031692505,10000,28200.16864085197,0.7581114172935486,0.9292783141136168,0.6686399579048157,1.3589168787002563,50000 -1127.0345587730408,2.7547857761383057,27594.486018419266,81696,0,27594.486018419266,0.5402000546455383,2.0668623447418213,10000,28727.456157445908,0.76761794090271,0.8746867179870605,0.671459972858429,1.3250740766525269,50000 -1144.5072071552277,2.79805588722229,28104.702069997787,83210,0,28104.702069997787,0.5398000478744507,2.0859215259552,10000,29255.24874305725,0.7473094463348389,0.9471119046211244,0.6668599843978882,1.3621351718902588,50000 -1161.6754019260406,2.844470977783203,28614.92431378365,84725,0,28614.92431378365,0.5400000214576721,2.061533212661743,10000,29782.745623350143,0.7489636540412903,0.9444343447685242,0.6720799803733826,1.3304868936538696,50000 -1178.9313924312592,2.8873541355133057,29125.15648293495,86240,0,29125.15648293495,0.5558000206947327,2.029335021972656,10000,30310.33651709557,0.7512555718421936,0.9321123957633972,0.6727399826049805,1.3387607336044312,50000 -1196.0384995937347,2.9288535118103027,29635.361619234085,87755,0,29635.361619234085,0.5366000533103943,2.072619676589966,10000,30837.75000357628,0.7402941584587097,0.9757063388824464,0.6700800061225891,1.3509361743927002,50000 -1213.3233938217163,2.9756851196289062,30145.517723798752,89269,0,30145.517723798752,0.5472000241279602,2.040836334228516,10000,31365.29650616645,0.7931481003761292,0.7678533792495728,0.6738199591636658,1.330244064331055,50000 -1230.7311687469482,3.023667097091675,30655.582379579544,90784,0,30655.582379579544,0.5485000014305115,2.0378308296203613,10000,31892.87816643715,0.769929826259613,0.8533276915550232,0.6751399636268616,1.321486234664917,50000 -1248.036145210266,3.087367534637451,31165.769117355347,92298,0,31165.769117355347,0.5489000082015991,2.0553572177886963,10000,32420.492695093155,0.7577128410339355,0.9045006632804872,0.671779990196228,1.3457863330841064,50000 -1265.370517730713,3.132223129272461,31675.75886678696,93812,0,31675.75886678696,0.5547000169754028,2.0494730472564697,10000,32947.921608924866,0.7662428021430969,0.8773167729377747,0.6835599541664124,1.2988988161087036,50000 -1282.691505908966,3.175578117370605,32185.968641519547,95326,0,32185.968641519547,0.5611000061035156,1.9772965908050537,10000,33475.55606675148,0.7669802308082581,0.8797918558120728,0.6861199736595154,1.2844245433807373,50000 -1300.1406605243685,3.221791982650757,32696.044674634933,96840,0,32696.044674634933,0.558899998664856,1.9927387237548828,10000,34003.18845510483,0.7622169852256775,0.8912380337715149,0.6829599738121033,1.2990450859069824,50000 -1317.5182647705078,3.265420913696289,33206.26758170128,98355,0,33206.26758170128,0.5583000183105469,1.9982116222381592,10000,34530.89180493355,0.8137754797935486,0.6855299472808838,0.6881600022315979,1.265963435173035,50000 -1334.5619978904724,3.310832023620605,33716.18281817436,99869,0,33716.18281817436,0.5529000163078308,2.0021021366119385,10000,35057.95794534683,0.7844786047935486,0.7990682721138,0.684939980506897,1.276252269744873,50000 -1351.8213539123535,3.355522871017456,34226.19294190407,101383,0,34226.19294190407,0.5521000027656555,2.012157440185547,10000,35585.33166742325,0.7774832248687744,0.8314793705940247,0.6846199631690979,1.2771825790405271,50000 -1369.1167182922363,3.406175136566162,34736.09021759033,102897,0,34736.09021759033,0.5592000484466553,1.97683584690094,10000,36112.63495731354,0.7782804369926453,0.8227626085281372,0.6896799802780151,1.256480097770691,50000 -1386.708063840866,3.452512502670288,35246.263035058975,104412,0,35246.263035058975,0.5574000477790833,2.013371229171753,10000,36640.50486254692,0.7736367583274841,0.8423909544944763,0.6892600059509277,1.259985089302063,50000 -1404.1752064228058,3.4970178604125977,35756.16144490242,105926,0,35756.16144490242,0.5586000084877014,1.974216103553772,10000,37167.974967479706,0.7770846486091614,0.8247220516204834,0.6956799626350403,1.2360202074050903,50000 -1421.6605622768402,3.54105544090271,36266.39616584778,107440,0,36266.39616584778,0.5689000487327576,1.9630481004714968,10000,37695.79827213287,0.8256935477256775,0.6410678625106812,0.6988799571990967,1.2236398458480835,50000 -1438.8392596244812,3.589161157608032,36776.41024065018,108954,0,36776.41024065018,0.5565000176429749,2.0258400440216064,10000,38223.09888339043,0.8023955225944519,0.722143292427063,0.694159984588623,1.2538851499557495,50000 -1456.2212126255035,3.6329078674316406,37286.4473862648,110468,0,37286.4473862648,0.5703000426292419,1.931430697441101,10000,38750.62214708328,0.8027144074440002,0.7184869050979614,0.7019400000572205,1.2076020240783691,50000 -1474.6009676456451,3.678417444229126,37796.50906038284,111983,0,37796.50906038284,0.5706000328063965,1.928413152694702,10000,39279.16947221756,0.7995057106018066,0.7365074753761292,0.7022199630737305,1.210964918136597,50000 -1491.822915315628,3.724119186401367,38306.47283291817,113497,0,38306.47283291817,0.5621000528335571,2.024306297302246,10000,39806.460503578186,0.7829639315605164,0.7980664968490601,0.6932199597358704,1.2627981901168823,50000 -1509.1600093841553,3.776813507080078,38816.370770692825,115010,0,38816.370770692825,0.5746999979019165,1.9470356702804563,10000,40333.80831313133,0.7958585619926453,0.7441632151603699,0.7008000016212463,1.2077741622924805,50000 -1526.1814465522766,3.824695587158203,39326.3773624897,116524,0,39326.3773624897,0.5791000127792358,1.8854345083236688,10000,40860.94369840622,0.8346619606018066,0.6010602116584778,0.7102000117301941,1.169505596160889,50000 -1543.270350933075,3.8735198974609375,39836.38484477997,118038,0,39836.38484477997,0.575700044631958,1.944663524627685,10000,41388.14964914322,0.8148317933082581,0.6677629947662354,0.7066999673843384,1.2082501649856567,50000 -1560.5260910987854,3.9229581356048575,40346.46527194977,119552,0,40346.46527194977,0.5781000256538391,1.9262524843215945,10000,41915.59468245506,0.8146523833274841,0.6767027378082275,0.705020010471344,1.2016663551330566,50000 -1577.9214849472046,3.97292423248291,40856.44110560417,121066,0,40856.44110560417,0.5789000391960144,1.921277642250061,10000,42443.076384067535,0.8086734414100647,0.6844852566719055,0.7085199952125549,1.1973555088043213,50000 -1595.1068725585938,4.020076274871826,41366.51425933838,122580,0,41366.51425933838,0.572100043296814,1.985614538192749,10000,42970.44410133362,0.7980110049247742,0.7274233102798462,0.6998400092124939,1.2325913906097412,50000 -1612.4037322998047,4.06778359413147,41876.44450306893,124095,0,41876.44450306893,0.586400032043457,1.8914276361465447,10000,43497.77945423126,0.8214086294174194,0.6441144943237305,0.7152599692344666,1.1662580966949463,50000 -1629.8644435405731,4.1161048412323,42386.406173706055,125608,0,42386.406173706055,0.5817000269889832,1.921865940093994,10000,44025.30989646912,0.8460817933082581,0.5478847026824951,0.7125200033187866,1.174476981163025,50000 -1647.4425375461578,4.171696662902832,42896.569796323776,127123,0,42896.569796323776,0.5838000178337097,1.920723557472229,10000,44553.16710352898,0.8364357352256775,0.592322587966919,0.7131800055503845,1.1688021421432495,50000 -1664.9784235954285,4.570464134216309,43406.31649875641,128636,0,43406.31649875641,0.5868000388145447,1.8865386247634888,10000,45080.90883421898,0.8303770422935486,0.6017466187477112,0.7099599838256836,1.1903725862503052,50000 -1682.1797287464142,4.620298624038696,43916.40103673935,130150,0,43916.40103673935,0.5943000316619873,1.8531285524368288,10000,45608.30464339256,0.8363161683082581,0.5795503854751587,0.7214199900627136,1.1396958827972412,50000 -1699.2731821537018,4.671350479125977,44426.34621119499,131663,0,44426.34621119499,0.5958000421524048,1.8866512775421145,10000,46135.45411705971,0.84086012840271,0.5709583163261414,0.7206400036811829,1.1557066440582275,50000 -1716.7603754997251,4.733880996704102,44936.3878133297,133177,0,44936.3878133297,0.5940999984741211,1.8570551872253416,10000,46663.10573005676,0.8496691584587097,0.540483295917511,0.7239800095558167,1.138985514640808,50000 -1734.3433163166046,4.786406517028809,45446.60967588425,134691,0,45446.60967588425,0.5897000432014465,1.8905266523361208,10000,47191.0237467289,0.865652859210968,0.4808640778064728,0.7212600111961365,1.1499614715576172,50000 -1751.4541499614716,4.838630199432373,45956.50417351723,136205,0,45956.50417351723,0.6017000079154968,1.838789939880371,10000,47718.14137125015,0.8681640625,0.465312123298645,0.726859986782074,1.1234965324401855,50000 -1768.5149869918823,4.89011025428772,46466.70485925674,137719,0,46466.70485925674,0.5908000469207764,1.8979063034057613,10000,48245.51445031166,0.8539540767669678,0.5145111083984375,0.7205399870872498,1.144445538520813,50000 -1785.825518131256,4.946522235870361,46976.7206556797,139234,0,46976.7206556797,0.5954000353813171,1.8895883560180664,10000,48772.95825433731,0.8557477593421936,0.4996830224990845,0.7240999937057495,1.1292551755905151,50000 -1803.1059713363647,4.9977428913116455,47486.69790649414,140747,0,47486.69790649414,0.6037000417709351,1.8393598794937127,10000,49300.32685422897,0.8660913109779358,0.4651607573032379,0.728659987449646,1.1148227453231812,50000 -1820.3083732128143,5.054003477096558,47996.81061291695,142262,0,47996.81061291695,0.5940999984741211,1.8927258253097528,10000,49827.75827026367,0.8884326815605164,0.3945820927619934,0.7259399890899658,1.1340726613998413,50000 -1837.5629630088808,5.1136510372161865,48506.87425112724,143776,0,48506.87425112724,0.6053000092506409,1.8732010126113887,10000,50355.19622254372,0.8937739133834839,0.3656793534755707,0.7310000061988831,1.1196391582489014,50000 -1855.067155122757,5.179184913635254,49016.82561182976,145291,0,49016.82561182976,0.6067000031471252,1.829805374145508,10000,50882.77686429024,0.8908242583274841,0.3806366622447967,0.7371199727058411,1.0951299667358398,50000 -1872.092421293259,5.233065366744995,49526.766122579575,146804,0,49526.766122579575,0.6051000356674194,1.860052227973938,10000,51409.85617089272,0.8893893361091614,0.3860467076301574,0.7328799962997437,1.1125147342681885,50000 -1889.482700824737,5.284041404724121,50036.77022147179,148319,0,50036.77022147179,0.612500011920929,1.8237643241882324,10000,51937.361750125885,0.8917610049247742,0.3722779452800751,0.7381599545478821,1.0851082801818848,50000 -1906.845090150833,5.336570739746094,50546.90058851242,149833,0,50546.90058851242,0.6081000566482544,1.878450512886048,10000,52464.96715068817,0.8892498016357422,0.373854249715805,0.7322799563407898,1.1266218423843384,50000 -1923.9228394031525,5.388533115386963,51057.0911295414,151348,0,51057.0911295414,0.6128000020980835,1.836662530899048,10000,52992.34717011452,0.9268175959587096,0.266787976026535,0.7402799725532532,1.0880799293518066,50000 -1942.4130220413208,5.445278167724609,51567.03753566742,152861,0,51567.03753566742,0.6173000335693359,1.8380128145217896,10000,53520.90021848679,0.9162946343421936,0.2905495464801788,0.739799976348877,1.094793677330017,50000 -1959.7907156944275,5.500972747802734,52077.10262107849,154375,0,52077.10262107849,0.6141000390052795,1.8534924983978271,10000,54048.459849357605,0.9135442972183228,0.3000372946262359,0.7407799959182739,1.092759132385254,50000 -1977.1947882175448,5.554906845092773,52587.03049516678,155888,0,52587.03049516678,0.6178000569343567,1.8527776002883911,10000,54575.90499091149,0.9125877022743224,0.2956430912017822,0.7436400055885315,1.0900652408599854,50000 -1994.400959968567,5.611454725265503,53096.94408249855,157402,0,53096.94408249855,0.619100034236908,1.8253930807113647,10000,55103.1409778595,0.9200215339660645,0.2833128273487091,0.7455399632453918,1.0706331729888916,50000 -2011.4396004676817,5.665586709976196,53606.83699655533,158915,0,53606.83699655533,0.6178000569343567,1.8164268732070925,10000,55630.18832850456,0.9239875674247742,0.2668489813804626,0.7450999617576599,1.073038935661316,50000 -2028.765647888184,5.735450029373169,54116.81638598442,160429,0,54116.81638598442,0.6201000213623047,1.828052401542664,10000,56157.62402033806,0.942602038383484,0.2054823786020279,0.7473999857902527,1.0706602334976196,50000 -2046.1212282180784,5.791545629501343,54626.75292801857,161943,0,54626.75292801857,0.6215000152587891,1.846395254135132,10000,56685.03358387947,0.9362045526504515,0.2241264581680297,0.7438799738883972,1.0907021760940552,50000 -2063.5502858161926,5.850171327590942,55136.65639901161,163457,0,55136.65639901161,0.6205000281333923,1.84682035446167,10000,57212.48588275909,0.9362045526504515,0.2207538038492202,0.746999979019165,1.0740346908569336,50000 -2080.6110968589783,5.909976959228516,55646.5770778656,164971,0,55646.5770778656,0.6273000240325928,1.8529764413833616,10000,57739.58699655533,0.9383171200752258,0.2114021331071853,0.7476199865341187,1.0842297077178955,50000 -2097.937203168869,5.965330123901367,56156.48152041435,166484,0,56156.48152041435,0.6234000325202942,1.8334190845489504,10000,58266.93318080902,0.9404296875,0.2076198607683181,0.7492600083351135,1.065981149673462,50000 -2115.0789551734924,6.01855206489563,56666.4612839222,167997,0,56666.4612839222,0.6291000247001648,1.8146344423294067,10000,58794.1668009758,0.9448142051696776,0.196561262011528,0.7514399886131287,1.0683664083480835,50000 -2132.4454357624054,6.0728843212127686,57176.4802467823,169510,0,57176.4802467823,0.624500036239624,1.82466459274292,10000,59321.66878390312,0.9545400142669678,0.1681820601224899,0.7520399689674377,1.063841462135315,50000 -2149.9335446357727,6.129309177398682,57686.42001056671,171023,0,57686.42001056671,0.6247000098228455,1.8314189910888672,10000,59849.21348261833,0.9538225531578064,0.1663007885217666,0.7504199743270874,1.0649524927139282,50000 -2167.4458363056183,6.185018539428711,58196.485368967056,172538,0,58196.485368967056,0.6278000473976135,1.8282188177108765,10000,60376.90635275841,0.9530851244926452,0.1686789840459823,0.7537999749183655,1.0632176399230957,50000 -2184.832007408142,6.240122318267822,58706.56109046936,174052,0,58706.56109046936,0.6285000443458557,1.8208074569702148,10000,60904.483194589615,0.954699456691742,0.1657977998256683,0.7526999711990356,1.06158185005188,50000 -2201.833286046982,6.310078859329224,59216.54894709587,175566,0,59216.54894709587,0.626800000667572,1.83081316947937,10000,61431.60178184509,0.9553571343421936,0.1650108247995376,0.7545799612998962,1.0558218955993652,50000 -2219.350727558136,6.369282007217407,59726.47098231316,177079,0,59726.47098231316,0.6285000443458557,1.8204129934310915,10000,61959.16131854057,0.9563536047935486,0.1589114516973495,0.7540599703788757,1.0573482513427734,50000 -2236.7588534355164,6.427386999130249,60236.56385469437,178592,0,60236.56385469437,0.6297000050544739,1.8235666751861568,10000,62486.78073191643,0.960379421710968,0.1490341424942016,0.7551800012588501,1.055389404296875,50000 -2254.0810379981995,6.485968351364136,60746.7347638607,180106,0,60746.7347638607,0.6306000351905823,1.820485591888428,10000,63014.39211153984,0.9593032598495485,0.1496942341327667,0.7559199929237366,1.0523293018341064,50000 -2271.384888648987,6.543102502822876,61256.86021757126,181621,0,61256.86021757126,0.628600001335144,1.819870948791504,10000,63541.93805527687,0.960339605808258,0.1473979651927948,0.7554599642753601,1.0532605648040771,50000 -2288.3898487091064,6.603374719619751,61766.89761161804,183134,0,61766.89761161804,0.6314000487327576,1.8186100721359253,10000,64069.10019540787,0.9602000713348388,0.1495323926210403,0.7563799619674683,1.051788330078125,50000 -2305.942921876908,6.662899971008301,62276.97330927849,184648,0,62276.97330927849,0.6309000253677368,1.817617654800415,10000,64596.847626686096,0.9607780575752258,0.1484294682741165,0.7558199763298035,1.0527596473693848,50000 -2322.951239347458,6.7222740650177,62786.884974718094,186161,0,62786.884974718094,0.6305000185966492,1.8158655166625977,10000,65123.88558316231,0.9616150856018066,0.1467247605323791,0.7559599876403809,1.051039218902588,50000 -2340.3564281463623,6.781188726425171,63296.81457424164,187675,0,63296.81457424164,0.6314000487327576,1.819074869155884,10000,65651.33988261223,0.9621930718421936,0.1422573626041412,0.7559999823570251,1.0514984130859375,50000 -2357.5116155147552,6.8434693813323975,63806.97748732567,189189,0,63806.97748732567,0.6312000155448914,1.81675398349762,10000,66178.77985548973,0.9604990482330322,0.1454890668392181,0.7562599778175354,1.0513103008270264,50000 -2374.5045187473297,6.90625524520874,64316.961555957794,190703,0,64316.961555957794,0.6309000253677368,1.817474722862244,10000,66705.87880730629,0.959582269191742,0.1502355635166168,0.7557399868965149,1.0518090724945068,50000 -2392.830799102783,6.966517210006714,64826.8895072937,192217,0,64826.8895072937,0.6314000487327576,1.8174049854278564,10000,67234.2537624836,0.9605388641357422,0.147024780511856,0.7558599710464478,1.0514979362487793,50000 -2410.100454807281,7.022372007369995,65337.07450246811,193731,0,65337.07450246811,0.6310000419616699,1.8178002834320068,10000,67761.82452487946,0.9588249325752258,0.1494628936052322,0.7556399703025818,1.0511360168457031,50000 -2427.3492562770844,7.086922645568848,65846.99104523659,195244,0,65846.99104523659,0.631600022315979,1.81663978099823,10000,68289.11477446556,0.9619937539100648,0.1456216126680374,0.755840003490448,1.052322268486023,50000 -2444.515809059143,7.146530151367188,66356.99051713943,196758,0,66356.99051713943,0.6321000456809998,1.8178486824035645,10000,68816.40245962143,0.9612762928009032,0.1447423547506332,0.7559399604797363,1.052221417427063,50000 -2461.702263355255,7.207445859909058,66866.98143553734,198271,0,66866.98143553734,0.6309000253677368,1.8155436515808103,10000,69343.70125079155,0.9612962007522584,0.1464032083749771,0.7554999589920044,1.0522739887237549,50000 -2479.025264263153,7.268761396408081,67376.91571760178,199785,0,67376.91571760178,0.6308000087738037,1.8178911209106443,10000,69871.08100628853,0.9602997303009032,0.1467967927455902,0.7558199763298035,1.0520784854888916,50000 -2496.34637260437,7.337639093399048,67887.04694342613,201299,0,67887.04694342613,0.6306000351905823,1.818956971168518,10000,70398.6608979702,0.9599011540412904,0.1499934792518615,0.7560399770736694,1.052374243736267,50000 -2513.746617078781,7.402098655700684,68397.03761839867,202813,0,68397.03761839867,0.6317000389099121,1.817797303199768,10000,70926.17774367332,0.961734652519226,0.1433680802583694,0.755840003490448,1.0518240928649902,50000 -2530.812096595764,7.462094306945801,68907.05328130722,204327,0,68907.05328130722,0.6305000185966492,1.817637801170349,10000,71453.3792629242,0.9606783986091614,0.1463976204395294,0.756060004234314,1.0515127182006836,50000 -2548.2866756916046,7.5248517990112305,69417.21089410782,205842,0,69417.21089410782,0.6310000419616699,1.8176188468933103,10000,71981.13394188881,0.9607381820678712,0.1457848995923996,0.7559799551963806,1.0517550706863403,50000 -2565.68039393425,7.586683750152588,69927.16663312912,207356,0,69927.16663312912,0.631600022315979,1.818718791007996,10000,72508.60455965996,0.9599609375,0.1481670588254928,0.7555399537086487,1.0524675846099854,50000 -2582.988062143326,7.651637077331543,70437.09859132767,208870,0,70437.09859132767,0.6312000155448914,1.816641926765442,10000,73035.96860289574,0.9607780575752258,0.1466470956802368,0.7554000020027161,1.0518457889556885,50000 -2600.2966318130493,7.717857122421265,70947.238874197,210384,0,70947.238874197,0.6304000020027161,1.817249059677124,10000,73563.5444612503,0.9600605964660645,0.1470133066177368,0.7558599710464478,1.0524232387542725,50000 -2617.373607635498,7.779970407485962,71457.24890565872,211897,0,71457.24890565872,0.6318000555038452,1.8178008794784544,10000,74090.75350880623,0.9603196382522584,0.1500230282545089,0.7558599710464478,1.052908420562744,50000 -2634.5299496650696,7.84785270690918,71967.18639993668,213411,0,71967.18639993668,0.6317000389099121,1.818710446357727,10000,74617.97519087791,0.9616748690605164,0.1448988020420074,0.7557399868965149,1.0528243780136108,50000 -2651.8291053771973,7.914536952972412,72477.23326277733,214925,0,72477.23326277733,0.6321000456809998,1.8189986944198608,10000,75145.44805765152,0.9598811864852904,0.1491194367408752,0.7555800080299377,1.0532875061035156,50000 -2668.91423535347,7.982542037963867,72987.39051795006,216439,0,72987.39051795006,0.6322000026702881,1.8198258876800537,10000,75672.81736660004,0.9616549611091614,0.1474548429250717,0.7560200095176697,1.0531450510025024,50000 -2686.1898329257965,8.048727750778198,73497.54736804962,217953,0,73497.54736804962,0.6304000020027161,1.8178809881210327,10000,76200.37576818466,0.96000075340271,0.1477039009332656,0.7557799816131592,1.052709460258484,50000 -2703.3161799907684,8.112717628479004,74007.73700547218,219467,0,74007.73700547218,0.6317000389099121,1.8171485662460327,10000,76727.81662344933,0.9601203799247742,0.1478992104530334,0.7565199732780457,1.0514085292816162,50000 -2720.639880657196,8.177234411239624,74517.93527269363,220981,0,74517.93527269363,0.6320000290870667,1.8170801401138303,10000,77255.46332716942,0.960339605808258,0.1483343243598938,0.7559799551963806,1.0523303747177124,50000 -2738.0824706554413,8.243817806243896,75028.07389211655,222495,0,75028.07389211655,0.6313000321388245,1.8161752223968504,10000,77783.17099523544,0.960758090019226,0.1472418159246444,0.7555599808692932,1.0518176555633545,50000 -2755.275614976883,8.306404113769531,75538.01856184006,224009,0,75538.01856184006,0.6301000118255615,1.8154196739196773,10000,78310.43208026886,0.9613958597183228,0.1484321802854538,0.7560399770736694,1.0512542724609375,50000 -2772.6055793762207,8.376575708389282,76047.95550060272,225522,0,76047.95550060272,0.6317000389099121,1.818994641304016,10000,78837.83005452156,0.961336076259613,0.1458679139614105,0.7559599876403809,1.0527108907699585,50000 -2789.5629415512085,8.442211389541626,76557.94317746162,227037,0,76557.94317746162,0.631100058555603,1.8159196376800537,10000,79364.89956021309,0.961933970451355,0.1421834528446197,0.7557599544525146,1.0518858432769775,50000 -2807.102013349533,8.51409387588501,77068.1039595604,228550,0,77068.1039595604,0.6301000118255615,1.8176794052124023,10000,79892.73308634758,0.9608577489852904,0.1471737623214721,0.7555599808692932,1.052388072013855,50000 -2825.0334997177124,8.582934141159058,77578.0581843853,230064,0,77578.0581843853,0.6319000124931335,1.8184075355529783,10000,80420.75643348694,0.9589245915412904,0.1493618786334991,0.7556999921798706,1.051660180091858,50000 -2842.399694442749,8.647887229919434,78088.10909414291,231577,0,78088.10909414291,0.6313000321388245,1.8174564838409424,10000,80948.29805445671,0.960718274116516,0.1464399844408035,0.755840003490448,1.0522643327713013,50000 -2859.735461950302,8.714589595794678,78598.0595126152,233091,0,78598.0595126152,0.6313000321388245,1.818180799484253,10000,81475.70970320702,0.9599409699440002,0.1503974199295044,0.7559799551963806,1.051724553108215,50000 -2876.846007347107,8.78090786933899,79107.94815778732,234604,0,79107.94815778732,0.6304000020027161,1.8177772760391235,10000,82002.83469963074,0.9614955186843872,0.1457913517951965,0.7561999559402466,1.0515480041503906,50000 -2894.1628913879395,8.845736265182495,79617.99672365189,236118,0,79617.99672365189,0.6310000419616699,1.817879557609558,10000,82530.32462906837,0.9614357352256776,0.1441568434238433,0.7552199959754944,1.052870512008667,50000 -2911.3059771060944,8.911030530929565,80127.95319676399,237632,0,80127.95319676399,0.631100058555603,1.817185401916504,10000,83057.54863166809,0.9603993892669678,0.1467776894569397,0.7555599808692932,1.0517122745513916,50000 -2928.5791189670563,8.975881099700928,80637.87059664726,239146,0,80637.87059664726,0.6321000456809998,1.817415714263916,10000,83584.8637034893,0.9608178734779358,0.1466523557901382,0.7553799748420715,1.0525615215301514,50000 -2945.8968670368195,9.044187307357788,81148.00268936157,240660,0,81148.00268936157,0.6310000419616699,1.8181267976760864,10000,84112.44211959839,0.960558831691742,0.1472288072109222,0.7558199763298035,1.052385687828064,50000 -2963.219983100891,9.121662378311155,81657.94163036346,242174,0,81657.94163036346,0.6317000389099121,1.8185575008392327,10000,84639.84378743172,0.961933970451355,0.144880786538124,0.7554799914360046,1.0521867275238037,50000 -2980.5571575164795,9.191231966018677,82167.82444095612,243687,0,82167.82444095612,0.6317000389099121,1.818875312805176,10000,85167.19301080704,0.960160195827484,0.1469377130270004,0.7560999989509583,1.0522353649139404,50000 -2997.735734939575,9.263187170028688,82677.70756721497,245201,0,82677.70756721497,0.631600022315979,1.817969560623169,10000,85694.38528704643,0.9598612785339355,0.1483822166919708,0.7560799717903137,1.0516401529312134,50000 -3014.9422421455383,9.33088445663452,83187.67843437195,246714,0,83187.67843437195,0.6306000351905823,1.8190034627914429,10000,86221.69176626205,0.9618144035339355,0.145142525434494,0.7558199763298035,1.0535144805908203,50000 -3032.2274081707,9.39405369758606,83697.82090711594,248229,0,83697.82090711594,0.6314000487327576,1.8160229921340945,10000,86749.24345612526,0.9604790806770324,0.1466723680496215,0.7559199929237366,1.051486253738403,50000 -3049.45293712616,9.46442985534668,84207.96512365341,249743,0,84207.96512365341,0.6313000321388245,1.8176978826522827,10000,87276.74302601814,0.9604392051696776,0.1476988494396209,0.7559199929237366,1.0521618127822876,50000 -3066.6506621837616,9.539054870605469,84718.10183787346,251259,0,84718.10183787346,0.6307000517845154,1.816618800163269,10000,87804.21118330956,0.960758090019226,0.1481978744268417,0.7556799650192261,1.0509867668151855,50000 -3083.81050491333,9.60982871055603,85228.26144003868,252773,0,85228.26144003868,0.6312000155448914,1.8171991109848025,10000,88331.66128325462,0.9607381820678712,0.1479832530021667,0.7556999921798706,1.0520581007003784,50000 -3100.79234957695,9.67766046524048,85738.3051803112,254287,0,85738.3051803112,0.6307000517845154,1.8190534114837649,10000,88858.81435346603,0.9598612785339355,0.1501981317996978,0.7560399770736694,1.052698850631714,50000 -3118.1718657016754,9.745689868927002,86248.5129327774,255801,0,86248.5129327774,0.6309000253677368,1.8166823387146,10000,89386.53042554855,0.9608378410339355,0.1464589387178421,0.7556399703025818,1.0526121854782104,50000 -3135.312481403351,9.82122540473938,86758.58062577248,257316,0,86758.58062577248,0.6323000192642212,1.817854166030884,10000,89913.8745894432,0.9597616195678712,0.1502071619033813,0.756119966506958,1.0523533821105957,50000 -3152.405182123184,9.933086156845093,87268.71463608742,258830,0,87268.71463608742,0.6314000487327576,1.817800521850586,10000,90441.27367305756,0.9610371589660645,0.1454240679740905,0.7561799883842468,1.0518678426742554,50000 -3169.6641025543213,10.010907649993896,87778.6434044838,260343,0,87778.6434044838,0.6314000487327576,1.8175643682479856,10000,90968.59873461723,0.9601203799247742,0.1482060104608535,0.7556799650192261,1.05309796333313,50000 -3186.8991684913635,10.086401224136353,88288.82423329353,261857,0,88288.82423329353,0.6306000351905823,1.8180071115493768,10000,91496.1502726078,0.9605787396430968,0.1496013551950454,0.7554999589920044,1.051580786705017,50000 -3204.0929505825043,10.161179780960085,88798.76346611977,263370,0,88798.76346611977,0.6309000253677368,1.8187938928604128,10000,92023.4181470871,0.9611965417861938,0.1485585421323776,0.7565799951553345,1.0513132810592651,50000 -3221.382915019989,10.232295989990234,89308.81522655487,264883,0,89308.81522655487,0.6309000253677368,1.818920373916626,10000,92550.89099025726,0.9622528553009032,0.1400790512561798,0.755620002746582,1.0535967350006104,50000 -3238.949809551239,10.304664373397827,89818.75279092789,266396,0,89818.75279092789,0.6300000548362732,1.8159278631210327,10000,93078.5286242962,0.960758090019226,0.1475231349468231,0.7558599710464478,1.050364375114441,50000 -3256.099337339401,10.37904691696167,90328.8978767395,267910,0,90328.8978767395,0.6308000087738037,1.817447304725647,10000,93605.95786976814,0.9594228267669678,0.147891879081726,0.7556399703025818,1.0531243085861206,50000 -3274.080150365829,10.45369815826416,90838.84168171884,269424,0,90838.84168171884,0.6317000389099121,1.8160980939865112,10000,94134.0173215866,0.9600805044174194,0.1480940580368042,0.7558799982070923,1.0508418083190918,50000 -3291.6022622585297,10.529751300811768,91348.74285554886,270937,0,91348.74285554886,0.6314000487327576,1.818058729171753,10000,94661.57633805276,0.9594626426696776,0.1487428247928619,0.755899965763092,1.0522432327270508,50000 -3308.995859146118,10.616759300231934,91858.74562954904,272451,0,91858.74562954904,0.6312000155448914,1.81630539894104,10000,95189.11938214302,0.9612962007522584,0.1474275290966034,0.7554799914360046,1.051373839378357,50000 -3326.2125630378723,10.689738750457764,92368.68969726562,273964,0,92368.68969726562,0.6303000450134277,1.8180094957351685,10000,95716.41310477255,0.9615553021430968,0.1442984044551849,0.7552399635314941,1.0525203943252563,50000 -3343.2138271331787,10.763943195343018,92878.66392302512,275478,0,92878.66392302512,0.6309000253677368,1.8165348768234253,10000,96243.52274870872,0.9615951776504515,0.144477903842926,0.7556399703025818,1.0514625310897827,50000 -3360.4372668266296,10.837567806243896,93388.64336013794,276992,0,93388.64336013794,0.6308000087738037,1.818804621696472,10000,96770.8598151207,0.9596420526504515,0.1479740142822265,0.7559399604797363,1.0522382259368896,50000 -3377.6312334537506,10.91645884513855,93898.81333494186,278506,0,93898.81333494186,0.631100058555603,1.817908525466919,10000,97298.36285185814,0.9606584906578064,0.1490184068679809,0.7557799816131592,1.0525351762771606,50000 -3394.649137496948,10.992409467697144,94408.975086689,280020,0,94408.975086689,0.6305000185966492,1.818747520446777,10000,97825.67763566972,0.9610969424247742,0.1430348008871078,0.7557399868965149,1.0523576736450195,50000 -3411.765170574188,11.085949182510376,94919.01567602158,281535,0,94919.01567602158,0.631600022315979,1.8173832893371584,10000,98352.9867374897,0.9610371589660645,0.1468787193298339,0.7556399703025818,1.0521724224090576,50000 -3428.685005664825,11.162020683288574,95428.90130352974,283049,0,95428.90130352974,0.6310000419616699,1.817343831062317,10000,98879.92723703384,0.9612165093421936,0.1450749486684799,0.7563599944114685,1.0516657829284668,50000 -3445.661366701126,11.237503051757812,95938.8274178505,284562,0,95938.8274178505,0.6309000253677368,1.819740653038025,10000,99406.96535897256,0.9595025181770324,0.1488256752490997,0.7557399868965149,1.0531953573226929,50000 -3463.0189859867096,11.311057567596436,96449.02365016936,286076,0,96449.02365016936,0.6315000057220459,1.818746566772461,10000,99934.65275216104,0.9608378410339355,0.1471054852008819,0.756060004234314,1.052351713180542,50000 -3480.0775611400604,11.383795976638794,96958.94147777556,287589,0,96958.94147777556,0.6321000456809998,1.8166067600250244,10000,100461.76079773904,0.9601203799247742,0.1463282853364944,0.7557399868965149,1.051419377326965,50000 -3497.1005821228027,11.457651138305664,97468.86407732964,289102,0,97468.86407732964,0.6308000087738037,1.817582368850708,10000,100988.83882284164,0.9602997303009032,0.1500391662120819,0.756119966506958,1.0513322353363037,50000 -3514.145917892456,11.534794092178345,97978.9630844593,290616,0,97978.9630844593,0.6315000057220459,1.817838311195373,10000,101516.11999130248,0.960718274116516,0.1463023722171783,0.7559799551963806,1.052631974220276,50000 -3531.315213680268,11.611164569854736,98489.15401434898,292130,0,98489.15401434898,0.6310000419616699,1.8188873529434204,10000,102043.6172683239,0.9606186151504515,0.1470670998096466,0.7553799748420715,1.0529109239578247,50000 -3548.455144643784,11.689434289932253,98999.07892155647,293643,0,98999.07892155647,0.6307000517845154,1.81728196144104,10000,102570.82220602036,0.9607381820678712,0.1502828449010849,0.755620002746582,1.0518311262130735,50000 -3565.4192507267,11.765220880508425,99509.08157086372,295157,0,99509.08157086372,0.6307000517845154,1.817894458770752,10000,103097.92510271072,0.9604591727256776,0.146732673048973,0.7555599808692932,1.052149534225464,50000 -3582.733262300492,11.8424334526062,100019.2240486145,296671,0,100019.2240486145,0.6307000517845154,1.817719578742981,10000,103625.51902842522,0.9592434167861938,0.149974912405014,0.7557799816131592,1.052897572517395,50000 -3599.947383403778,11.925029277801514,100529.0994002819,298184,0,100529.0994002819,0.6317000389099121,1.816598892211914,10000,104152.75135087968,0.9615553021430968,0.1456803530454635,0.7562999725341797,1.0515629053115845,50000 -3617.3545808792114,12.00115966796875,101039.07216191292,299698,0,101039.07216191292,0.6315000057220459,1.8183579444885247,10000,104680.2663693428,0.961355984210968,0.1456865221261978,0.7557399868965149,1.0521283149719238,50000 -3635.0120573043823,12.07826280593872,101549.00704622269,301211,0,101549.00704622269,0.6318000555038452,1.81587815284729,10000,105207.99576115608,0.960598647594452,0.1500077545642852,0.7556599974632263,1.0508867502212524,50000 -3652.1334154605865,12.158013582229614,102058.9766690731,302724,0,102058.9766690731,0.6313000321388245,1.8171393871307373,10000,105735.22607922554,0.9607979655265808,0.1483609229326248,0.7554000020027161,1.0511231422424316,50000 -3669.119397878647,12.233417749404907,102569.1431913376,304237,0,102569.1431913376,0.6307000517845154,1.817326307296753,10000,106262.51443743706,0.9625318646430968,0.1406024247407913,0.7562400102615356,1.0528939962387085,50000 -3686.292558431626,12.312384128570557,103079.07321691512,305750,0,103079.07321691512,0.6307000517845154,1.8191286325454712,10000,106789.75725913048,0.960339605808258,0.1479188501834869,0.7556999921798706,1.0525130033493042,50000 -3704.487591743469,12.390041828155518,103589.23261809348,307264,0,103589.23261809348,0.6309000253677368,1.8189918994903564,10000,107318.24916386604,0.9594826102256776,0.1492210626602172,0.7557599544525146,1.0515639781951904,50000 -3721.745894193649,12.466761350631714,104099.11788463593,308777,0,104099.11788463593,0.6315000057220459,1.8163002729415887,10000,107845.5293803215,0.960379421710968,0.1470292061567306,0.7558799982070923,1.051123023033142,50000 -3738.711463928223,12.54440450668335,104609.262250185,310291,0,104609.262250185,0.6304000020027161,1.81773841381073,10000,108372.77680826189,0.959741711616516,0.148784339427948,0.7563199996948242,1.0511804819107056,50000 -3755.8963787555695,12.624755382537842,105119.16359496117,311805,0,105119.16359496117,0.631100058555603,1.81870186328888,10000,108900.00142145155,0.9624720811843872,0.1441033333539962,0.756119966506958,1.0523790121078491,50000 -3772.8295102119446,12.706658840179443,105629.2938606739,313319,0,105629.2938606739,0.6314000487327576,1.8169692754745483,10000,109427.20877575874,0.9609375,0.1451739966869354,0.7560399770736694,1.0512357950210571,50000 -3790.351084470749,12.800882577896118,106139.17324519156,314833,0,106139.17324519156,0.632900059223175,1.8191454410552976,10000,109954.76312589644,0.9600605964660645,0.1482522189617157,0.755620002746582,1.052326798439026,50000 -3807.474318265915,12.86822748184204,106649.0578083992,316345,0,106649.0578083992,0.631600022315979,1.8172281980514529,10000,110481.89925575256,0.960379421710968,0.1486100107431411,0.7559799551963806,1.0517127513885498,50000 -3824.4763877391815,12.949540615081789,107159.1848680973,317860,0,107159.1848680973,0.6304000020027161,1.817755937576294,10000,111009.16904592514,0.9601004123687744,0.1470013409852981,0.7554199695587158,1.0530093908309937,50000 -3841.6000397205353,13.036304712295532,107669.21872401236,319373,0,107669.21872401236,0.6319000124931335,1.8153104782104488,10000,111536.4730143547,0.9612563848495485,0.1443883031606674,0.7559799551963806,1.051807880401611,50000 -3858.7794704437256,13.124491453170776,108179.18863844872,320887,0,108179.18863844872,0.6313000321388245,1.815914750099182,10000,112063.77143478394,0.961136758327484,0.1468666344881057,0.7561799883842468,1.0516546964645386,50000 -3876.076628923416,13.206662654876707,108689.0499472618,322400,0,108689.0499472618,0.6304000020027161,1.81884515285492,10000,112591.07172703744,0.9613958597183228,0.1446283608675003,0.7553799748420715,1.0537073612213137,50000 -3893.3367924690247,13.290921926498411,109199.06824493408,323914,0,109199.06824493408,0.6319000124931335,1.8185746669769287,10000,113118.49463415146,0.9598811864852904,0.1484560519456863,0.7558199763298035,1.0515974760055542,50000 -3910.58306145668,13.375164985656738,109708.97253608704,325427,0,109708.97253608704,0.6305000185966492,1.8173801898956297,10000,113645.7896540165,0.9601402878761292,0.1478699147701263,0.7559799551963806,1.0521526336669922,50000 -3927.7408850193024,13.466815948486328,110218.91332650185,326941,0,110218.91332650185,0.631600022315979,1.8180609941482544,10000,114173.03961229324,0.9608777165412904,0.1470252126455307,0.7557599544525146,1.0519651174545288,50000 -3944.9255831241608,13.548635721206663,110728.96795773506,328454,0,110728.96795773506,0.6323000192642212,1.8172355890274048,10000,114700.41973114014,0.9614756107330322,0.1458468437194824,0.7557599544525146,1.051677942276001,50000 -3962.061777353287,13.631106853485107,111238.9862473011,329968,0,111238.9862473011,0.6317000389099121,1.816851019859314,10000,115227.71682953836,0.9604192972183228,0.1479916721582412,0.7558199763298035,1.0517678260803225,50000 -3978.993365049362,13.714733600616457,111749.04176855087,331481,0,111749.04176855087,0.631100058555603,1.8166096210479736,10000,115754.84624505044,0.9598811864852904,0.1524875611066818,0.7560999989509583,1.0509434938430786,50000 -3996.122615098953,13.801108837127686,112258.91333699226,332995,0,112258.91333699226,0.6313000321388245,1.818378210067749,10000,116281.9922375679,0.961734652519226,0.14434514939785,0.7558000087738037,1.0530338287353516,50000 -4013.174238204956,13.886026620864868,112768.79986357687,334508,0,112768.79986357687,0.631100058555603,1.819270372390747,10000,116809.0741918087,0.9585060477256776,0.1500162780284881,0.7559999823570251,1.052400827407837,50000 -4030.716554164888,13.973962306976318,113278.91282248496,336022,0,113278.91282248496,0.631600022315979,1.8194116353988647,10000,117336.87652873991,0.961336076259613,0.1454212963581085,0.7561999559402466,1.0526785850524902,50000 -4047.8868165016174,14.062806129455566,113789.0877354145,337536,0,113789.0877354145,0.6317000389099121,1.8193473815917969,10000,117864.37035274506,0.9607979655265808,0.1482814997434616,0.7560799717903137,1.0519975423812866,50000 -4064.900999069214,14.147011280059814,114299.129904747,339051,0,114299.129904747,0.6312000155448914,1.817227244377136,10000,118391.57015228271,0.959741711616516,0.1503542214632034,0.7552399635314941,1.052941083908081,50000 -4081.9577882289886,14.231061697006226,114809.09057998656,340565,0,114809.09057998656,0.6317000389099121,1.819175243377685,10000,118918.73136019708,0.9613958597183228,0.1452462524175644,0.7559999823570251,1.053504467010498,50000 -4099.052008152008,14.317641258239746,115319.19979858398,342080,0,115319.19979858398,0.6309000253677368,1.8193016052246087,10000,119446.07986211775,0.9623724222183228,0.1430144309997558,0.7557199597358704,1.0525835752487185,50000 -4116.148473501205,14.404023885726929,115829.36331558228,343594,0,115829.36331558228,0.631600022315979,1.8172719478607176,10000,119973.48504543304,0.960598647594452,0.1461528092622757,0.7560799717903137,1.0523217916488647,50000 -4133.013184070587,14.49116039276123,116339.2823779583,345108,0,116339.2823779583,0.6304000020027161,1.8181381225585933,10000,120500.41694951056,0.960339605808258,0.1466981768608093,0.755899965763092,1.0524567365646362,50000 -4151.086514234543,14.57667088508606,116849.27911138536,346622,0,116849.27911138536,0.6312000155448914,1.817665100097656,10000,121028.6317665577,0.9592235088348388,0.148650661110878,0.7559799551963806,1.051628589630127,50000 -4168.160834312439,14.661510467529297,117359.37034726144,348136,0,117359.37034726144,0.6324000358581543,1.816881895065308,10000,121555.94112229349,0.9605388641357422,0.1495785564184188,0.7561599612236023,1.0520284175872805,50000 -4185.293612003326,14.747785091400146,117869.27389907835,349650,0,117869.27389907835,0.6306000351905823,1.8176124095916748,10000,122083.12540221214,0.960758090019226,0.1472361534833908,0.7560799717903137,1.0515118837356567,50000 -4202.453873872757,14.835796117782593,118379.28754210472,351164,0,118379.28754210472,0.6312000155448914,1.8179231882095337,10000,122610.44742536543,0.9616150856018066,0.145010843873024,0.756339967250824,1.0514135360717771,50000 -4219.443992614746,14.937951803207396,118889.12491345406,352677,0,118889.12491345406,0.6313000321388245,1.817135453224182,10000,123137.43913602828,0.961355984210968,0.1440159380435943,0.7555800080299377,1.051655650138855,50000 -4236.364371299744,15.025042533874512,119399.18874073029,354191,0,119399.18874073029,0.6313000321388245,1.8169622421264648,10000,123664.57041716576,0.9601203799247742,0.1484006345272064,0.7557599544525146,1.0514042377471924,50000 -4253.577853441238,15.125241994857788,119909.20295524596,355704,0,119909.20295524596,0.6308000087738037,1.818961262702942,10000,124191.95948266984,0.9602997303009032,0.1497524827718734,0.7556799650192261,1.0527878999710083,50000 -4270.738157272339,15.209940195083618,120419.31431365012,357218,0,120419.31431365012,0.6312000155448914,1.820838928222656,10000,124719.37481164932,0.9614157676696776,0.1427268236875534,0.7556999921798706,1.0533047914505005,50000 -4287.819470405579,15.2974956035614,120929.3461754322,358732,0,120929.3461754322,0.6318000555038452,1.8178112506866453,10000,125246.63512897491,0.96097731590271,0.1458709090948104,0.7553600072860718,1.052885890007019,50000 -4304.847638845444,15.390103816986084,121439.44315695764,360247,0,121439.44315695764,0.6318000555038452,1.817411184310913,10000,125773.91371035576,0.9602997303009032,0.1472483873367309,0.7559599876403809,1.0523381233215332,50000 -4322.003366947174,15.480299949645996,121949.60582256316,361761,0,121949.60582256316,0.6315000057220459,1.817147135734558,10000,126301.38193583488,0.9618144035339355,0.1445937156677246,0.7554599642753601,1.0517603158950806,50000 -4339.118178367615,15.568543672561646,122459.51914787292,363275,0,122459.51914787292,0.6308000087738037,1.8169174194335933,10000,126828.55815386772,0.9597616195678712,0.1499248892068863,0.7556399703025818,1.0522288084030151,50000 -4356.393066167831,15.656342029571531,122969.4705851078,364788,0,122969.4705851078,0.6325000524520874,1.817543625831604,10000,127355.93138194084,0.961136758327484,0.145282357931137,0.7560999989509583,1.051038384437561,50000 -4373.334760904312,15.750553607940674,123479.34378123283,366301,0,123479.34378123283,0.6317000389099121,1.818870186805725,10000,127882.89841532709,0.9606584906578064,0.1480656564235687,0.7558199763298035,1.0525013208389282,50000 -4390.496916055679,15.836224555969238,123989.37150287628,367814,0,123989.37150287628,0.6315000057220459,1.8194811344146729,10000,128410.23392796516,0.961156725883484,0.1470474749803543,0.7556599974632263,1.0533629655838013,50000 -4407.583589076996,15.926433086395264,124499.26441526412,369328,0,124499.26441526412,0.6312000155448914,1.8174362182617188,10000,128937.36512541772,0.9606983065605164,0.1482247412204742,0.7558599710464478,1.0521869659423828,50000 -4424.501345396042,16.016053199768066,125009.21130037308,370841,0,125009.21130037308,0.631600022315979,1.8188133239746087,10000,129464.3785161972,0.960598647594452,0.1493550240993499,0.7554399967193604,1.0532011985778809,50000 -4441.642533063889,16.10594081878662,125519.2497870922,372355,0,125519.2497870922,0.6315000057220459,1.8178412914276123,10000,129991.7078435421,0.9604392051696776,0.1477468013763427,0.7557599544525146,1.0522220134735107,50000 -4458.787220478058,16.19475793838501,126029.41074037552,373869,0,126029.41074037552,0.6319000124931335,1.8184047937393188,10000,130519.1618218422,0.9604192972183228,0.1477959901094436,0.7557399868965149,1.0529639720916748,50000 -4475.823565721512,16.284391403198242,126539.46411824226,375383,0,126539.46411824226,0.6308000087738037,1.81686007976532,10000,131046.4018945694,0.9611168503761292,0.1462110728025436,0.7556599974632263,1.0519813299179075,50000 -4492.737801551819,16.379032611846924,127049.39342689514,376897,0,127049.39342689514,0.6315000057220459,1.8168940544128416,10000,131573.39998173714,0.9594427347183228,0.1501639932394027,0.755620002746582,1.0519487857818604,50000 -4509.983863592148,16.47271466255188,127559.34850096704,378410,0,127559.34850096704,0.6308000087738037,1.817546963691712,10000,132100.75407028198,0.9618940949440002,0.1476005017757415,0.7558799982070923,1.0524741411209106,50000 -4527.163170099258,16.564187049865723,128069.45318174362,379924,0,128069.45318174362,0.631100058555603,1.817813754081726,10000,132628.18920063972,0.9614756107330322,0.1453332155942917,0.7559399604797363,1.0524545907974243,50000 -4544.638174533844,16.653773546218872,128579.37960290907,381437,0,128579.37960290907,0.6314000487327576,1.818215489387512,10000,133155.73970913887,0.9616150856018066,0.1441227793693542,0.7558000087738037,1.053081750869751,50000 -4561.747064352036,16.74677801132202,129089.33254384996,382950,0,129089.33254384996,0.6314000487327576,1.8172036409378047,10000,133682.95392680168,0.9603196382522584,0.1468247771263122,0.7559199929237366,1.0521663427352903,50000 -4579.573089122772,16.839839935302734,129599.38512706757,384464,0,129599.38512706757,0.6313000321388245,1.818960666656494,10000,134210.98501205444,0.9594826102256776,0.1478909403085708,0.7556799650192261,1.0529996156692505,50000 -4596.726794719696,16.932087182998657,130109.36382508278,385977,0,130109.36382508278,0.6315000057220459,1.816987991333008,10000,134738.26835227013,0.9602199792861938,0.1464436054229736,0.756119966506958,1.0509934425354004,50000 -4613.694131135941,17.029442310333252,130619.21668457983,387491,0,130619.21668457983,0.6314000487327576,1.8155715465545648,10000,135265.24450850487,0.9599609375,0.1485553085803985,0.755299985408783,1.051853060722351,50000 -4630.728091716766,17.116252183914185,131129.21193361282,389004,0,131129.21193361282,0.6309000253677368,1.8187190294265747,10000,135792.42049121857,0.9614357352256776,0.1465912759304046,0.7559399604797363,1.0521849393844604,50000 -4647.803809642792,17.206831455230713,131639.15394306183,390518,0,131639.15394306183,0.6307000517845154,1.8176332712173464,10000,136319.58885860443,0.960339605808258,0.1467641741037368,0.7556999921798706,1.0530712604522705,50000 -4664.913568973541,17.301775693893433,132149.21634721756,392031,0,132149.21634721756,0.6306000351905823,1.817515850067139,10000,136846.91794514656,0.9616150856018066,0.1448460966348648,0.755899965763092,1.0519707202911377,50000 -4682.0501046180725,17.397064208984375,132659.2893781662,393545,0,132659.2893781662,0.6306000351905823,1.8179852962493896,10000,137374.28175091743,0.9604192972183228,0.1471053957939148,0.7557199597358704,1.051891207695007,50000 -4698.965122699738,17.549309730529785,133169.1497850418,395058,0,133169.1497850418,0.6321000456809998,1.8184858560562127,10000,137901.2680785656,0.9599609375,0.148743599653244,0.7558599710464478,1.0518076419830322,50000 -4716.026647567749,17.64861249923706,133679.27163481712,396573,0,133679.27163481712,0.6306000351905823,1.8179006576538088,10000,138428.61062049866,0.9614357352256776,0.1444485336542129,0.7559399604797363,1.051970601081848,50000 -4733.000705003738,17.743713855743408,134189.4073984623,398087,0,134189.4073984623,0.631100058555603,1.8180168867111208,10000,138955.87549710274,0.96195387840271,0.1440633684396743,0.7553600072860718,1.0524848699569702,50000 -4750.216578006744,17.843615293502808,134699.39883041382,399601,0,134699.39883041382,0.6314000487327576,1.8171703815460205,10000,139483.24274683,0.9598014950752258,0.1474271565675735,0.7560799717903137,1.0510292053222656,50000 -4767.408142089844,17.938778400421143,135209.38359832764,401115,0,135209.38359832764,0.631600022315979,1.8182260990142824,10000,140010.5748746395,0.9605388641357422,0.1472002863883972,0.7559199929237366,1.0520803928375244,50000 -4784.297115325928,18.03217339515686,135719.27980732918,402628,0,135719.27980732918,0.6314000487327576,1.8170477151870728,10000,140537.51276612282,0.9599609375,0.1483696550130844,0.7554799914360046,1.0520527362823486,50000 -4801.340654373169,18.129441499710083,136229.12922286987,404141,0,136229.12922286987,0.6302000284194946,1.81659734249115,10000,141064.5624115467,0.9599409699440002,0.1486997753381729,0.7559199929237366,1.05121648311615,50000 -4818.499400377274,18.232341051101685,136739.0093715191,405655,0,136739.0093715191,0.6307000517845154,1.8162307739257808,10000,141591.76547813416,0.9614157676696776,0.1460102349519729,0.7556399703025818,1.051106333732605,50000 -4835.725882053375,18.32414746284485,137248.94324493408,407168,0,137248.94324493408,0.6312000155448914,1.818498373031616,10000,142119.07744646072,0.9608577489852904,0.1470111906528473,0.7559199929237366,1.0524506568908691,50000 -4852.763338327408,18.42547035217285,137759.00602436066,408682,0,137759.00602436066,0.6312000155448914,1.8195046186447144,10000,142646.33941054344,0.9594826102256776,0.1504420191049575,0.7554999589920044,1.0528854131698608,50000 -4869.61491060257,18.52701497077942,138268.93678069115,410196,0,138268.93678069115,0.6308000087738037,1.8148283958435056,10000,143173.28214502335,0.9610570669174194,0.1465514004230499,0.7554199695587158,1.051986813545227,50000 -4886.722577571869,18.632123947143555,138778.83112430573,411711,0,138778.83112430573,0.631100058555603,1.816946029663086,10000,143700.44857931137,0.9595025181770324,0.1496334373950958,0.755899965763092,1.0508095026016235,50000 -4903.879125356674,18.72833561897278,139288.81606531143,413225,0,139288.81606531143,0.6319000124931335,1.8175504207611084,10000,144227.74542737007,0.9604790806770324,0.1473581790924072,0.7558199763298035,1.0527368783950806,50000 -4921.023787975311,18.827582120895386,139798.95591545105,414740,0,139798.95591545105,0.6313000321388245,1.817002415657044,10000,144755.18977284431,0.961136758327484,0.1464726626873016,0.7557199597358704,1.051748752593994,50000 -4938.317877531052,18.94173526763916,140309.01599621773,416253,0,140309.01599621773,0.6322000026702881,1.8180079460144043,10000,145282.7172703743,0.96000075340271,0.1497869342565536,0.7555999755859375,1.0523784160614014,50000 -4955.200018882752,19.039517164230347,140818.96789312363,417767,0,140818.96789312363,0.6313000321388245,1.8175561428070068,10000,145809.70890045166,0.9615154266357422,0.1464042663574218,0.7555999755859375,1.051769137382507,50000 -4972.296702861786,19.138569831848145,141328.88626480105,419280,0,141328.88626480105,0.6313000321388245,1.815688967704773,10000,146336.88316321373,0.962113320827484,0.1444258242845535,0.7554799914360046,1.0518677234649658,50000 -4989.195043087006,19.23827767372132,141838.8102095127,420794,0,141838.8102095127,0.6318000555038452,1.81821358203888,10000,146863.86543941498,0.9602399468421936,0.146013543009758,0.7559999823570251,1.051907658576965,50000 -5006.32697892189,19.34300827980041,142348.95408391953,422308,0,142348.95408391953,0.6314000487327576,1.8166799545288088,10000,147391.3078136444,0.9610171914100648,0.1461603045463562,0.7557399868965149,1.0520353317260742,50000 -5024.113070011139,19.43835592269897,142859.0488152504,423822,0,142859.0488152504,0.6323000192642212,1.8173928260803225,10000,147919.34276366234,0.9603196382522584,0.1475877910852432,0.7555599808692932,1.0516517162322998,50000 -5041.205612659454,19.53954839706421,143369.13305354118,425335,0,143369.13305354118,0.631100058555603,1.8176788091659544,10000,148446.680259943,0.9591637253761292,0.150650754570961,0.755899965763092,1.0516971349716189,50000 -5058.131942987442,19.6390221118927,143879.30010700226,426848,0,143879.30010700226,0.6313000321388245,1.8194468021392824,10000,148973.93353700638,0.9608976244926452,0.1458160281181335,0.7559799551963806,1.0528355836868286,50000 -5075.299085140228,19.744468927383423,144389.14986610413,428360,0,144389.14986610413,0.6307000517845154,1.816765069961548,10000,149501.11479592323,0.9611766338348388,0.146694615483284,0.7557199597358704,1.051979660987854,50000 -5092.433423757553,19.84219336509705,144899.22899580002,429874,0,144899.22899580002,0.6309000253677368,1.8176268339157104,10000,150028.48589587212,0.961535394191742,0.1424656212329864,0.7558000087738037,1.051815152168274,50000 -5109.480449914932,19.94320559501648,145409.23196220398,431387,0,145409.23196220398,0.6319000124931335,1.817805528640747,10000,150555.69521546364,0.961156725883484,0.1466477066278457,0.7562800049781799,1.0512664318084717,50000 -5126.57053732872,20.042986392974854,145919.59135246277,432901,0,145919.59135246277,0.6319000124931335,1.81902277469635,10000,151083.30557370186,0.9598612785339355,0.1482953429222107,0.7561999559402466,1.0520941019058228,50000 -5143.729845046997,20.137884616851807,146429.5111811161,434414,0,146429.5111811161,0.6318000555038452,1.8176249265670776,10000,151610.5396182537,0.9611766338348388,0.1458516418933868,0.7561399936676025,1.0522480010986328,50000 -5160.749233961105,20.236326932907104,146939.64353322983,435928,0,146939.64353322983,0.631100058555603,1.816450834274292,10000,152137.84930348396,0.9609175324440002,0.1470854878425598,0.7560200095176697,1.0511912107467651,50000 -5177.542870759964,20.333744525909424,147449.54594254494,437441,0,147449.54594254494,0.631100058555603,1.8216536045074463,10000,152664.7027964592,0.9604591727256776,0.1464346051216125,0.7558599710464478,1.0538361072540283,50000 -5194.418187379837,20.43113422393799,147959.53197169304,438955,0,147959.53197169304,0.6310000419616699,1.817052960395813,10000,153191.72272014618,0.9606983065605164,0.1456460505723953,0.7556999921798706,1.0515245199203491,50000 -5211.586438894272,20.53074598312378,148469.47192668915,440468,0,148469.47192668915,0.6320000290870667,1.8179895877838133,10000,153718.99074602127,0.9604192972183228,0.1487842202186584,0.7554599642753601,1.0531373023986816,50000 -5228.576928853989,20.633300065994263,148979.646399498,441982,0,148979.646399498,0.6314000487327576,1.8177663087844849,10000,154246.31707787514,0.9610371589660645,0.1458581238985061,0.7560999989509583,1.0520803928375244,50000 -5245.627908229828,20.737919569015503,149489.79416179657,443496,0,149489.79416179657,0.6315000057220459,1.8164036273956297,10000,154773.68194508553,0.960180163383484,0.1478644162416458,0.7556999921798706,1.0514676570892334,50000 -5262.597207069397,20.838828086853027,149999.8376750946,445010,0,149999.8376750946,0.6312000155448914,1.817966341972351,10000,155300.85563635826,0.9609375,0.1475970149040222,0.7556999921798706,1.0521252155303955,50000 -5279.649640798569,20.945537328720093,150509.80757832527,446523,0,150509.80757832527,0.6313000321388245,1.8180619478225708,10000,155828.0451927185,0.9608378410339355,0.1482971906661987,0.7558599710464478,1.0520795583724976,50000 -5296.409289121628,21.0596821308136,151019.7969942093,448037,0,151019.7969942093,0.6306000351905823,1.818080186843872,10000,156354.96886587143,0.961316168308258,0.1471826732158661,0.7558599710464478,1.0523767471313477,50000 -5313.485778331757,21.15876936912537,151529.86161398888,449550,0,151529.86161398888,0.6308000087738037,1.816544771194458,10000,156882.26883530617,0.960339605808258,0.1485464125871658,0.7556599974632263,1.050680160522461,50000 -5330.426630020142,21.256157636642456,152039.7231528759,451063,0,152039.7231528759,0.631600022315979,1.817432284355164,10000,157409.22890758514,0.959622085094452,0.1482224911451339,0.7559399604797363,1.0514514446258545,50000 -5347.50834107399,21.35846757888794,152549.7496676445,452577,0,152549.7496676445,0.631600022315979,1.817237973213196,10000,157936.49866342545,0.9602598547935486,0.1482243835926056,0.7552199959754944,1.0523943901062012,50000 -5364.562657833099,21.45997190475464,153059.6380867958,454090,0,153059.6380867958,0.6305000185966492,1.818716526031494,10000,158463.60263371468,0.9610171914100648,0.1476788073778152,0.7560399770736694,1.052739143371582,50000 -5381.410010099411,21.56362199783325,153569.5043578148,455603,0,153569.5043578148,0.6310000419616699,1.818840265274048,10000,158990.47890138626,0.960598647594452,0.1480612307786941,0.7557599544525146,1.05259907245636,50000 -5398.56386756897,21.66671848297119,154079.6365056038,457118,0,154079.6365056038,0.6318000555038452,1.817526817321777,10000,159517.92832422256,0.9614157676696776,0.146322026848793,0.7562999725341797,1.0514073371887207,50000 -5415.647324562073,21.76831364631653,154589.79145264626,458632,0,154589.79145264626,0.6308000087738037,1.818753242492676,10000,160045.3290655613,0.9614955186843872,0.1443223059177398,0.7561999559402466,1.0519403219223022,50000 -5432.796221256256,21.89662194252014,155099.9243299961,460147,0,155099.9243299961,0.6310000419616699,1.8186067342758176,10000,160572.79992508888,0.9606783986091614,0.1461041569709777,0.7561799883842468,1.051734209060669,50000 -5450.276737928391,22.00607419013977,155609.88244462013,461660,0,155609.88244462013,0.6314000487327576,1.8191816806793213,10000,161100.4083454609,0.9598014950752258,0.1478946059942245,0.7555399537086487,1.0528876781463623,50000 -5467.160728693008,22.11421036720276,156119.75332951546,463174,0,156119.75332951546,0.6308000087738037,1.818250298500061,10000,161627.33212256432,0.9601004123687744,0.1464466303586959,0.7556599974632263,1.052324891090393,50000 -5484.03133893013,22.21840262413025,156629.86915493011,464687,0,156629.86915493011,0.631600022315979,1.8184653520584104,10000,162154.48467326164,0.9594228267669678,0.1503064781427383,0.7555800080299377,1.0526986122131348,50000 -5501.084171056747,22.32434344291687,157139.771348238,466202,0,157139.771348238,0.6315000057220459,1.817851424217224,10000,162681.6063232422,0.962292730808258,0.1432087272405624,0.7555199861526489,1.0532876253128052,50000 -5518.072999954224,22.43128538131714,157649.81596922874,467716,0,157649.81596922874,0.6317000389099121,1.816857933998108,10000,163208.80686163902,0.960957407951355,0.1454687416553497,0.7559199929237366,1.0522892475128174,50000 -5535.045620441437,22.535234928131104,158159.90993785858,469230,0,158159.90993785858,0.6313000321388245,1.8159784078598025,10000,163736.03716516495,0.9608577489852904,0.146379217505455,0.7557199597358704,1.0511075258255005,50000 -5552.206401586533,22.64659857749939,158670.03724741936,470744,0,158670.03724741936,0.631100058555603,1.817716717720032,10000,164263.49721503258,0.9602598547935486,0.1471463739871978,0.7562199831008911,1.051742672920227,50000 -5569.197404384613,22.748106718063354,159179.998026371,472258,0,159179.998026371,0.6324000358581543,1.818382143974304,10000,164790.60931801796,0.9595224857330322,0.1507411003112793,0.7557199597358704,1.0522160530090332,50000 -5586.219715356827,22.854050636291504,159690.03354144096,473772,0,159690.03354144096,0.631100058555603,1.817054986953736,10000,165317.8332746029,0.9618542790412904,0.1431065946817398,0.7560200095176697,1.0513827800750732,50000 -5603.152618169785,22.95527720451355,160199.90467762947,475286,0,160199.90467762947,0.6313000321388245,1.8169790506362915,10000,165844.797778368,0.9608976244926452,0.1456860154867172,0.7561599612236023,1.051494836807251,50000 -5620.208205223084,23.058412313461304,160710.0825135708,476800,0,160710.0825135708,0.6307000517845154,1.8172602653503416,10000,166372.19257593155,0.9608577489852904,0.1460122913122177,0.7555800080299377,1.0520609617233276,50000 -5637.344567537308,23.16335129737854,161220.12222194672,478314,0,161220.12222194672,0.6310000419616699,1.819790363311768,10000,166899.53250861168,0.9600406289100648,0.1492892503738403,0.7554000020027161,1.0533448457717896,50000 -5654.447158336639,23.269815683364868,161729.96871495247,479827,0,161729.96871495247,0.6304000020027161,1.8189347982406616,10000,167426.6479022503,0.9615553021430968,0.1434130817651748,0.755620002746582,1.0522792339324951,50000 -5671.34614443779,23.37634825706482,162239.8997218609,481342,0,162239.8997218609,0.6318000555038452,1.81726336479187,10000,167953.64355373383,0.9592434167861938,0.1493992060422897,0.7554799914360046,1.052570343017578,50000 -5688.821419477463,23.48125386238098,162750.039809227,482856,0,162750.039809227,0.631600022315979,1.818324089050293,10000,168481.42416000366,0.9616549611091614,0.1477699726819992,0.7559999823570251,1.0522754192352295,50000 -5705.831904172897,23.58875036239624,163259.96401071548,484371,0,163259.96401071548,0.6307000517845154,1.8176476955413816,10000,169008.52561068535,0.960180163383484,0.1472534388303756,0.7554799914360046,1.0523960590362549,50000 -5722.962559461594,23.70534729957581,163769.99699783325,485885,0,163769.99699783325,0.6325000524520874,1.8172670602798464,10000,169535.86717271805,0.9604790806770324,0.1493615359067917,0.7560799717903137,1.051154851913452,50000 -5739.9425756931305,23.87461256980896,164279.91232395172,487398,0,164279.91232395172,0.631100058555603,1.816259503364563,10000,170062.991219759,0.961316168308258,0.1479690223932266,0.7558799982070923,1.0512391328811646,50000 -5756.892340660095,23.98096513748169,164789.9333343506,488912,0,164789.9333343506,0.6306000351905823,1.817075252532959,10000,170590.12753081322,0.959980845451355,0.1480510532855987,0.7559999823570251,1.0511592626571655,50000 -5773.955453395844,24.091283559799194,165299.89897084236,490426,0,165299.89897084236,0.6318000555038452,1.8199012279510496,10000,171117.3261051178,0.9606783986091614,0.1471062153577804,0.7559399604797363,1.0521148443222046,50000 -5791.007390499115,24.209453582763672,165809.91813397408,491940,0,165809.91813397408,0.6307000517845154,1.8181631565094,10000,171644.57608389854,0.9601004123687744,0.1487551927566528,0.7557799816131592,1.0525541305541992,50000 -5808.1975655555725,24.3194260597229,166319.89604902267,493454,0,166319.89604902267,0.6310000419616699,1.8182610273361208,10000,172171.9135248661,0.9607381820678712,0.1470497101545334,0.7558000087738037,1.0523558855056765,50000 -5825.202464342117,24.42889618873596,166830.0151219368,494968,0,166830.0151219368,0.6318000555038452,1.818455576896668,10000,172699.20630049706,0.961575210094452,0.1471329629421234,0.7560799717903137,1.0521728992462158,50000 -5842.022374391556,24.539676189422607,167340.15970230105,496483,0,167340.15970230105,0.6313000321388245,1.8172544240951536,10000,173226.34223484993,0.961316168308258,0.1451499611139297,0.7558799982070923,1.0518877506256104,50000 -5859.004626750946,24.64591932296753,167850.05058646202,497996,0,167850.05058646202,0.6314000487327576,1.8174744844436648,10000,173753.38080906868,0.959781527519226,0.1461669504642486,0.7559999823570251,1.0516399145126345,50000 -5876.005375862122,24.75995421409607,168359.94944262505,499510,0,168359.94944262505,0.6314000487327576,1.8173248767852783,10000,174280.4545059204,0.9612762928009032,0.1455503702163696,0.7559399604797363,1.0524622201919556,50000 -5893.762127637863,24.870811223983765,168869.80708909035,501023,0,168869.80708909035,0.6304000020027161,1.8189656734466555,10000,174808.23965120316,0.9595224857330322,0.1492321044206619,0.7557399868965149,1.052715539932251,50000 -5910.607840776444,24.980355262756348,169379.7861609459,502537,0,169379.7861609459,0.6312000155448914,1.8172820806503296,10000,175335.23373770714,0.9601004123687744,0.1491653770208358,0.755840003490448,1.051175236701965,50000 -5927.514526128769,25.08798623085022,169889.70877027512,504051,0,169889.70877027512,0.6308000087738037,1.8169641494750977,10000,175862.2305700779,0.960758090019226,0.1461090445518493,0.7560399770736694,1.0522167682647705,50000 -5944.358859062195,25.19463801383972,170399.8674080372,505566,0,170399.8674080372,0.6320000290870667,1.818225264549256,10000,176389.4013133049,0.9611766338348388,0.1460204273462295,0.7558199763298035,1.052260160446167,50000 -5961.896793842316,25.31332540512085,170909.96117901802,507080,0,170909.96117901802,0.6301000118255615,1.8179951906204224,10000,176917.21332883835,0.9620934128761292,0.1431600302457809,0.7554999589920044,1.0530349016189575,50000 -5978.876355171204,25.40578317642212,171419.9574327469,508594,0,171419.9574327469,0.6305000185966492,1.8188080787658687,10000,177444.34196448326,0.9612762928009032,0.1459125131368637,0.7559399604797363,1.0524475574493408,50000 -5995.808697462082,25.51852774620056,171929.94116711617,510108,0,171929.94116711617,0.6309000253677368,1.81772255897522,10000,177971.4310863018,0.9589046239852904,0.1507239043712616,0.7554999589920044,1.0527056455612185,50000 -6012.799216747284,25.63161373138428,172439.7688140869,511622,0,172439.7688140869,0.6308000087738037,1.8169662952423096,10000,178498.42212200165,0.9614157676696776,0.1448870450258255,0.7563799619674683,1.051182508468628,50000 -6029.798710823059,25.749578952789307,172949.87129354477,513136,0,172949.87129354477,0.6307000517845154,1.818769216537476,10000,179025.70138788223,0.9617745280265808,0.1449913829565048,0.7557199597358704,1.0533603429794312,50000 -6047.356439828873,25.85953950881958,173459.8376841545,514650,0,173459.8376841545,0.6312000155448914,1.817663311958313,10000,179553.39589834213,0.9599210619926452,0.1473044008016586,0.7555199861526489,1.0519886016845703,50000 -6064.35414147377,25.97298693656921,173969.71770739555,516163,0,173969.71770739555,0.6307000517845154,1.8186434507369995,10000,180080.4456114769,0.9608976244926452,0.1467035561800003,0.7555800080299377,1.052548050880432,50000 -6081.307334423065,26.08626389503479,174479.7213087082,517677,0,174479.7213087082,0.6319000124931335,1.816991925239563,10000,180607.57706213,0.960598647594452,0.1472920030355453,0.7556399703025818,1.0521091222763062,50000 -6098.161105394363,26.196840047836304,174989.78865027428,519191,0,174989.78865027428,0.6306000351905823,1.8196641206741333,10000,181134.6708905697,0.9604790806770324,0.1465890109539032,0.7556799650192261,1.0526906251907349,50000 -6115.153409481049,26.310802698135376,175499.8068125248,520704,0,175499.8068125248,0.6309000253677368,1.8196110725402832,10000,181661.85493016243,0.96000075340271,0.1486792713403701,0.7560399770736694,1.0523520708084106,50000 -6132.192498922348,26.426344871521,176009.86840701103,522217,0,176009.86840701103,0.6314000487327576,1.819387555122376,10000,182189.1304523945,0.961734652519226,0.1452644020318985,0.7561399936676025,1.0519838333129885,50000 -6149.103352069855,26.54132080078125,176519.7387931347,523730,0,176519.7387931347,0.6317000389099121,1.817821979522705,10000,182716.08656191823,0.9601004123687744,0.1496659517288208,0.7558799982070923,1.052033305168152,50000 -6165.975840806961,26.663962841033936,177029.82167744637,525244,0,177029.82167744637,0.6317000389099121,1.8184934854507449,10000,183243.22400975227,0.9605189561843872,0.1485978960990905,0.7560399770736694,1.0526304244995115,50000 -6182.947371721268,26.78052973747253,177539.7880001068,526758,0,177539.7880001068,0.6315000057220459,1.816197752952576,10000,183770.3384861946,0.9609375,0.1478055715560913,0.7556599974632263,1.051428198814392,50000 -6200.699266672134,26.892802953720093,178049.65263915062,528271,0,178049.65263915062,0.6306000351905823,1.8174481391906736,10000,184298.1275918484,0.9596420526504515,0.147839218378067,0.7558000087738037,1.051953673362732,50000 -6217.64315867424,27.24454164505005,178559.2515347004,529784,0,178559.2515347004,0.6313000321388245,1.8184610605239868,10000,184825.0801999569,0.9607381820678712,0.1477023512125015,0.7559999823570251,1.051684856414795,50000 -6234.597435712814,27.356879711151123,179069.09046268463,531298,0,179069.09046268463,0.6313000321388245,1.8184059858322144,10000,185352.04632163048,0.960180163383484,0.1469533294439315,0.7556999921798706,1.0518393516540527,50000 -6251.563860416412,27.472768306732178,179579.1422381401,532812,0,179579.1422381401,0.6306000351905823,1.8165922164916992,10000,185879.2403757572,0.960558831691742,0.1497821360826492,0.7557599544525146,1.05041766166687,50000 -6268.447098493576,27.58768367767334,180089.1278910637,534326,0,180089.1278910637,0.6319000124931335,1.81664514541626,10000,186406.2837367057,0.9616748690605164,0.1470065712928772,0.756119966506958,1.0519754886627195,50000 -6285.296785116196,27.708674669265747,180599.0215339661,535840,0,180599.0215339661,0.6308000087738037,1.8189215660095213,10000,186933.2061035633,0.9614157676696776,0.1429337263107299,0.7553600072860718,1.0528799295425415,50000 -6302.178875923157,27.830414295196533,181108.84668302536,537353,0,181108.84668302536,0.6315000057220459,1.8187016248703003,10000,187460.0957119465,0.9607979655265808,0.1454615592956543,0.7557599544525146,1.0528595447540283,50000 -6319.269691467285,27.94835305213928,181618.950300932,538867,0,181618.950300932,0.6307000517845154,1.8169374465942385,10000,187987.47008562088,0.9594626426696776,0.1487793922424316,0.7560799717903137,1.0516453981399536,50000 -6336.292678594589,28.065372705459595,182128.7737035752,540380,0,182128.7737035752,0.6309000253677368,1.815916895866394,10000,188514.49407696724,0.9597616195678712,0.1483667641878128,0.7553600072860718,1.051335334777832,50000 -6353.330547809601,28.18113732337952,182638.67547369003,541894,0,182638.67547369003,0.6304000020027161,1.817833662033081,10000,189041.6077091694,0.960598647594452,0.148234486579895,0.7557799816131592,1.0529454946517944,50000 -6370.11282992363,28.30603194236756,183148.73386335373,543407,0,183148.73386335373,0.6305000185966492,1.8182512521743768,10000,189568.6337320805,0.9614955186843872,0.1472430229187011,0.7562199831008911,1.0523183345794678,50000 -6386.995321273804,28.42442011833191,183658.73257374763,544921,0,183658.73257374763,0.6317000389099121,1.8176825046539309,10000,190095.69326162327,0.9620336294174194,0.1422350853681564,0.7558199763298035,1.051290512084961,50000 -6404.198199510574,28.544098138809204,184168.75291776657,546435,0,184168.75291776657,0.6313000321388245,1.817774534225464,10000,190623.0963089466,0.960957407951355,0.1463946253061294,0.7555399537086487,1.0523061752319336,50000 -6421.181804418564,28.66173243522644,184678.612347126,547948,0,184678.612347126,0.6314000487327576,1.817501425743103,10000,191150.11498069763,0.959741711616516,0.1487929821014404,0.7556399703025818,1.0521609783172607,50000 -6438.668433189392,28.779770612716675,185188.4840466976,549461,0,185188.4840466976,0.6309000253677368,1.8153493404388428,10000,191677.650382042,0.960180163383484,0.148353636264801,0.7560799717903137,1.051113843917847,50000 -6455.479493379593,28.882810354232788,185698.51256275177,550975,0,185698.51256275177,0.6317000389099121,1.8175721168518064,10000,192204.65172839165,0.961336076259613,0.143163800239563,0.7555399537086487,1.0519856214523315,50000 -6472.452427625656,29.003294467926025,186208.4559469223,552489,0,186208.4559469223,0.6310000419616699,1.8172557353973389,10000,192731.74961352348,0.9610570669174194,0.1454282253980636,0.755840003490448,1.051978588104248,50000 -6489.370160579681,29.12525010108948,186718.4262833596,554003,0,186718.4262833596,0.6323000192642212,1.8183035850524905,10000,193258.82066512108,0.9595423936843872,0.149004265666008,0.7553600072860718,1.052505373954773,50000 -6506.138321399689,29.2472882270813,187228.433716774,555517,0,187228.433716774,0.6315000057220459,1.8194420337677,10000,193785.7789597512,0.9606584906578064,0.1476895958185196,0.7559399604797363,1.0525918006896973,50000 -6523.260720252991,29.36760807037353,187738.5424454212,557032,0,187738.5424454212,0.631100058555603,1.818196177482605,10000,194313.19152712825,0.9607381820678712,0.1459540575742721,0.755840003490448,1.05264151096344,50000 -6540.03351855278,29.49054718017578,188248.5768508911,558546,0,188248.5768508911,0.6312000155448914,1.817338943481445,10000,194840.18134617803,0.9609972834587096,0.1453019976615905,0.7559199929237366,1.0516624450683594,50000 -6556.845715045929,29.64098310470581,188737.60403227806,559998,0,188737.60403227806,0.6314000487327576,1.8173184394836426,10000,195346.22756004333,0.9608577489852905,0.14729043841362,0.7560799717903137,1.0520939826965332,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index c912da49c..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5973 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.66207147,6.922584,,,,,,,,,,,,,, -1,,,0.0012356505030766,6.911959171295166,0.0010799999581649,6.912962913513184,50000.0,0.0009000000427477,6.913362979888916,10000.0,52.38156604766846,89.74589896202087,52.38156604766846,37.36424398422241,0.0,0.0 -100,0.69254017,6.8154993,,,,,,,,,,,,,, -200,0.7990751,6.5770936,,,,,,,,,,,,,, -300,1.1705164,6.286743,,,,,,,,,,,,,, -400,2.4057217,6.004098,,,,,,,,,,,,,, -500,1.9332082,5.8112855,,,,,,,,,,,,,, -600,3.8959818,5.6845217,,,,,,,,,,,,,, -700,3.6223047,5.467581,,,,,,,,,,,,,, -800,2.6068127,5.245099,,,,,,,,,,,,,, -900,2.5341005,5.197075,,,,,,,,,,,,,, -1000,3.7051384,4.9999747,,,,,,,,,,,,,, -1100,3.7126775,4.9583025,,,,,,,,,,,,,, -1200,3.8295438,4.9110575,,,,,,,,,,,,,, -1300,5.943663,4.669181,,,,,,,,,,,,,, -1400,5.8358006,4.745398,,,,,,,,,,,,,, -1500,5.8630567,4.4491873,,,,,,,,,,,,,, -1509,,,0.1795679181814193,4.181913375854492,0.1612999886274337,4.331301212310791,50000.0,0.125900000333786,4.813449382781982,10000.0,562.5784339904785,618.4830918312073,562.5784339904785,55.82190465927124,0.0253584384918212,0.0 -1600,6.101464,4.438639,,,,,,,,,,,,,, -1700,3.557665,4.258892,,,,,,,,,,,,,, -1800,4.5065613,4.312246,,,,,,,,,,,,,, -1900,4.8925395,4.157575,,,,,,,,,,,,,, -2000,3.4145432,4.0632167,,,,,,,,,,,,,, -2100,4.113147,3.990882,,,,,,,,,,,,,, -2200,3.9068391,3.9475734,,,,,,,,,,,,,, -2300,3.836286,3.944209,,,,,,,,,,,,,, -2400,4.069746,3.7279787,,,,,,,,,,,,,, -2500,4.288114,3.6872907,,,,,,,,,,,,,, -2600,5.069545,3.5555208,,,,,,,,,,,,,, -2700,4.258487,3.5984027,,,,,,,,,,,,,, -2800,2.954413,3.4953132,,,,,,,,,,,,,, -2900,4.6718936,3.5627503,,,,,,,,,,,,,, -3000,3.8165069,3.473096,,,,,,,,,,,,,, -3016,,,0.3254344761371612,3.1697702407836914,0.2942200005054474,3.366985559463501,50000.0,0.2266000062227249,3.981245994567871,10000.0,1072.4957585334778,1146.6672384738922,1072.4957585334778,74.00124216079712,0.0526726245880126,0.0 -3100,4.2475452,3.3443065,,,,,,,,,,,,,, -3200,4.0096173,3.3332062,,,,,,,,,,,,,, -3300,3.4979174,3.1886494,,,,,,,,,,,,,, -3400,3.4969428,3.2402215,,,,,,,,,,,,,, -3500,3.473182,3.20892,,,,,,,,,,,,,, -3600,2.9827325,3.0065894,,,,,,,,,,,,,, -3700,3.7398036,3.0332682,,,,,,,,,,,,,, -3800,2.6491342,3.1060638,,,,,,,,,,,,,, -3900,2.747104,3.0167928,,,,,,,,,,,,,, -4000,2.3151426,2.858445,,,,,,,,,,,,,, -4100,3.1112273,2.7797184,,,,,,,,,,,,,, -4200,2.7393227,2.8465621,,,,,,,,,,,,,, -4300,3.6094809,2.8371267,,,,,,,,,,,,,, -4400,2.3329377,2.8877702,,,,,,,,,,,,,, -4500,3.1944196,2.8549283,,,,,,,,,,,,,, -4525,,,0.4550183117389679,2.387629985809326,0.4248600006103515,2.5835626125335693,50000.0,0.3230000138282776,3.316889524459839,10000.0,1582.6927065849304,1675.2411172389984,1582.6927065849304,92.29433751106262,0.0788233280181884,0.0 -4600,1.8694888,2.7803478,,,,,,,,,,,,,, -4700,2.3922777,2.5430574,,,,,,,,,,,,,, -4800,2.3386552,2.742394,,,,,,,,,,,,,, -4900,3.5571904,2.5219731,,,,,,,,,,,,,, -5000,1.7361699,2.6280346,,,,,,,,,,,,,, -5100,1.9416105,2.5491054,,,,,,,,,,,,,, -5200,2.3213668,2.6383216,,,,,,,,,,,,,, -5300,2.7856803,2.6097808,,,,,,,,,,,,,, -5400,2.8256104,2.6751134,,,,,,,,,,,,,, -5500,2.1742384,2.5392764,,,,,,,,,,,,,, -5600,2.8756008,2.5851774,,,,,,,,,,,,,, -5700,2.6065004,2.389205,,,,,,,,,,,,,, -5800,2.4691858,2.634041,,,,,,,,,,,,,, -5900,2.080083,2.416181,,,,,,,,,,,,,, -6000,2.2594602,2.460028,,,,,,,,,,,,,, -6035,,,0.5336614847183228,1.9763846397399905,0.4961600005626678,2.1787710189819336,50000.0,0.3857000172138214,2.9400079250335693,10000.0,2092.691743612289,2203.4167931079865,2092.691743612289,110.38658118247986,0.106652021408081,0.0 -6100,2.3696659,2.3907566,,,,,,,,,,,,,, -6200,2.3514805,2.4973478,,,,,,,,,,,,,, -6300,1.6102458,2.3579404,,,,,,,,,,,,,, -6400,1.978359,2.3411133,,,,,,,,,,,,,, -6500,2.424568,2.4395695,,,,,,,,,,,,,, -6600,1.8677903,2.2537084,,,,,,,,,,,,,, -6700,2.2172797,2.3451688,,,,,,,,,,,,,, -6800,1.9781168,2.2951667,,,,,,,,,,,,,, -6900,2.026842,2.3875906,,,,,,,,,,,,,, -7000,1.8117713,2.387484,,,,,,,,,,,,,, -7100,2.1692445,2.3757029,,,,,,,,,,,,,, -7200,1.5760081,2.2756603,,,,,,,,,,,,,, -7300,1.5512161,2.3828175,,,,,,,,,,,,,, -7400,1.8084679,2.2109356,,,,,,,,,,,,,, -7500,2.3865075,2.2670176,,,,,,,,,,,,,, -7545,,,0.5792012214660645,1.7531155347824097,0.5401399731636047,1.974346160888672,50000.0,0.415800005197525,2.75813889503479,10000.0,2602.7501418590546,2731.746652841568,2602.7501418590546,128.57271671295166,0.1335184574127197,0.0 -7600,2.2205083,2.2340646,,,,,,,,,,,,,, -7700,1.314515,2.3011096,,,,,,,,,,,,,, -7800,1.8160374,2.2865424,,,,,,,,,,,,,, -7900,1.8501717,2.1758752,,,,,,,,,,,,,, -8000,1.9277327,2.0507894,,,,,,,,,,,,,, -8100,1.8478299,2.2830498,,,,,,,,,,,,,, -8200,1.8696295,2.2012413,,,,,,,,,,,,,, -8300,1.6535131,2.1108832,,,,,,,,,,,,,, -8400,1.7048663,2.148905,,,,,,,,,,,,,, -8500,1.6650699,2.3166318,,,,,,,,,,,,,, -8600,1.8943534,2.0560482,,,,,,,,,,,,,, -8700,1.4563164,2.1254373,,,,,,,,,,,,,, -8800,1.4190296,2.1183066,,,,,,,,,,,,,, -8900,1.7317125,2.135808,,,,,,,,,,,,,, -9000,1.9203371,2.2643282,,,,,,,,,,,,,, -9057,,,0.6060865521430969,1.620163083076477,0.5582199692726135,1.88018000125885,50000.0,0.4338000118732452,2.645254135131836,10000.0,3112.906903028488,3260.0810463428497,3112.906903028488,146.66174578666687,0.1640608310699463,0.0 -9100,1.6585405,2.083142,,,,,,,,,,,,,, -9200,1.7573003,1.8674521,,,,,,,,,,,,,, -9300,1.971437,2.0393755,,,,,,,,,,,,,, -9400,1.6478745,1.9677211,,,,,,,,,,,,,, -9500,1.5748075,2.1361833,,,,,,,,,,,,,, -9600,1.4806904,2.0240576,,,,,,,,,,,,,, -9700,2.1133761,2.077426,,,,,,,,,,,,,, -9800,1.7837452,2.1245072,,,,,,,,,,,,,, -9900,1.6334752,2.1006641,,,,,,,,,,,,,, -10000,2.2313676,2.208455,,,,,,,,,,,,,, -10100,1.5344788,2.0854394,,,,,,,,,,,,,, -10200,1.6624863,1.9759958,,,,,,,,,,,,,, -10300,1.4022485,2.2261615,,,,,,,,,,,,,, -10400,1.6689811,2.1942375,,,,,,,,,,,,,, -10500,2.0092409,2.0191593,,,,,,,,,,,,,, -10568,,,0.6356824040412903,1.4726837873458862,0.5694800019264221,1.8395023345947263,50000.0,0.4500000178813934,2.5940630435943604,10000.0,3622.962684154512,3788.7260398864746,3622.962684154512,165.1640853881836,0.1944692134857177,0.0 -10600,1.5442457,1.9435908,,,,,,,,,,,,,, -10700,1.3953707,1.9922948,,,,,,,,,,,,,, -10800,1.7219038,2.0485938,,,,,,,,,,,,,, -10900,2.3047473,2.0653868,,,,,,,,,,,,,, -11000,1.7203887,2.0498176,,,,,,,,,,,,,, -11100,1.7942984,2.0949433,,,,,,,,,,,,,, -11200,1.5128193,1.9994537,,,,,,,,,,,,,, -11300,1.8164665,2.002229,,,,,,,,,,,,,, -11400,1.4969506,1.9663324,,,,,,,,,,,,,, -11500,1.5715098,2.1068082,,,,,,,,,,,,,, -11600,1.4875927,1.9200509,,,,,,,,,,,,,, -11700,2.4138272,2.0018387,,,,,,,,,,,,,, -11800,2.216327,2.0088034,,,,,,,,,,,,,, -11900,1.8924671,1.9328161,,,,,,,,,,,,,, -12000,1.6246012,1.905952,,,,,,,,,,,,,, -12080,,,0.6251793503761292,1.5146089792251587,0.5725199580192566,1.817832112312317,50000.0,0.4550000131130218,2.581552028656006,10000.0,4133.085363149643,4319.210289001465,4133.085363149643,185.43234658241272,0.2300171852111816,0.0 -12100,2.0744534,2.0693512,,,,,,,,,,,,,, -12200,1.829069,2.093494,,,,,,,,,,,,,, -12300,1.5557197,1.9505879,,,,,,,,,,,,,, -12400,1.6902716,1.9426807,,,,,,,,,,,,,, -12500,1.5573629,2.0126348,,,,,,,,,,,,,, -12600,1.5117148,1.9615806,,,,,,,,,,,,,, -12700,1.4087284,1.981394,,,,,,,,,,,,,, -12800,1.602603,1.9823232,,,,,,,,,,,,,, -12900,1.6018709,1.9222583,,,,,,,,,,,,,, -13000,1.7268279,1.9579251,,,,,,,,,,,,,, -13100,1.4103944,1.9468129,,,,,,,,,,,,,, -13200,1.6245478,1.9835942,,,,,,,,,,,,,, -13300,1.634112,1.9105879,,,,,,,,,,,,,, -13400,1.2370417,1.9066645,,,,,,,,,,,,,, -13500,1.6139666,1.8757197,,,,,,,,,,,,,, -13592,,,0.6494539380073547,1.422621726989746,0.5925599932670593,1.707416534423828,50000.0,0.4630000293254852,2.468202829360962,10000.0,4643.2337856292725,4849.872251987457,4643.2337856292725,205.85814785957336,0.2600185871124267,0.0 -13600,1.3885477,1.9373071,,,,,,,,,,,,,, -13700,1.3273448,1.8733658,,,,,,,,,,,,,, -13800,1.4475154,1.7762351,,,,,,,,,,,,,, -13900,1.7190765,1.9416933,,,,,,,,,,,,,, -14000,1.3878489,1.8915418,,,,,,,,,,,,,, -14100,1.5666207,1.9203783,,,,,,,,,,,,,, -14200,1.5271802,1.9070024,,,,,,,,,,,,,, -14300,1.5660264,1.9748474,,,,,,,,,,,,,, -14400,1.5081269,1.8939722,,,,,,,,,,,,,, -14500,1.5316168,1.8470682,,,,,,,,,,,,,, -14600,1.6878557,1.8981879,,,,,,,,,,,,,, -14700,1.5303993,1.8434191,,,,,,,,,,,,,, -14800,1.7918739,1.7937241,,,,,,,,,,,,,, -14900,1.5603442,1.9086992,,,,,,,,,,,,,, -15000,2.263939,1.8817725,,,,,,,,,,,,,, -15100,1.7147789,1.8921965,,,,,,,,,,,,,, -15104,,,0.6466438174247742,1.4184733629226685,0.5945199728012085,1.7023662328720093,50000.0,0.4723000228404999,2.4483578205108643,10000.0,5153.15531373024,5380.393999576569,5153.15531373024,226.37191343307487,0.2892529964447021,0.0 -15200,2.4344518,1.9752328,,,,,,,,,,,,,, -15300,1.8212185,1.8752356,,,,,,,,,,,,,, -15400,1.4886922,1.956113,,,,,,,,,,,,,, -15500,1.60745,1.8833802,,,,,,,,,,,,,, -15600,1.9487675,1.8911732,,,,,,,,,,,,,, -15700,1.5084922,1.8160218,,,,,,,,,,,,,, -15800,1.6441765,1.9404199,,,,,,,,,,,,,, -15900,1.5614221,1.8219796,,,,,,,,,,,,,, -16000,1.5444615,1.787076,,,,,,,,,,,,,, -16100,1.5549736,1.8057911,,,,,,,,,,,,,, -16200,1.5927039,1.9324898,,,,,,,,,,,,,, -16300,1.7414619,1.8798313,,,,,,,,,,,,,, -16400,1.6637174,1.8852229,,,,,,,,,,,,,, -16500,1.5648297,1.8082647,,,,,,,,,,,,,, -16600,1.4952465,1.8218192,,,,,,,,,,,,,, -16616,,,0.6574457883834839,1.3805900812149048,0.6076599955558777,1.6330012083053589,50000.0,0.4771000146865845,2.402414321899414,10000.0,5663.251859664917,5910.609344005585,5663.251859664917,246.3868336677552,0.3350415229797363,0.0 -16700,1.6077244,1.7354614,,,,,,,,,,,,,, -16800,1.4323287,1.8023924,,,,,,,,,,,,,, -16900,1.422332,1.8682688,,,,,,,,,,,,,, -17000,1.8735108,1.8591539,,,,,,,,,,,,,, -17100,1.6815695,1.9117645,,,,,,,,,,,,,, -17200,1.5338359,1.8419704,,,,,,,,,,,,,, -17300,1.5154155,1.8643007,,,,,,,,,,,,,, -17400,1.410162,1.7777464,,,,,,,,,,,,,, -17500,1.6006955,1.9022243,,,,,,,,,,,,,, -17600,1.6384866,1.8783951,,,,,,,,,,,,,, -17700,1.5980527,1.9847994,,,,,,,,,,,,,, -17800,1.530743,1.90148,,,,,,,,,,,,,, -17900,1.4515642,1.8730209,,,,,,,,,,,,,, -18000,1.6381059,1.7587787,,,,,,,,,,,,,, -18100,1.7469307,1.8166062,,,,,,,,,,,,,, -18128,,,0.6847695708274841,1.253440260887146,0.6041199564933777,1.6582300662994385,50000.0,0.4830000102519989,2.4177982807159424,10000.0,6173.184433221817,6444.77542424202,6173.184433221817,270.5326895713806,0.364840030670166,0.0 -18200,1.7865776,1.8345217,,,,,,,,,,,,,, -18300,1.3550035,1.7252566,,,,,,,,,,,,,, -18400,1.7307405,1.9855539,,,,,,,,,,,,,, -18500,1.5898515,1.7616402,,,,,,,,,,,,,, -18600,1.769251,1.8084497,,,,,,,,,,,,,, -18700,1.7575086,1.9363391,,,,,,,,,,,,,, -18800,1.5892385,1.8625325,,,,,,,,,,,,,, -18900,1.5253928,1.8968184,,,,,,,,,,,,,, -19000,1.4814235,1.7537439,,,,,,,,,,,,,, -19100,1.5201697,1.5621157,,,,,,,,,,,,,, -19200,1.7342861,1.8044436,,,,,,,,,,,,,, -19300,1.9129524,1.8291397,,,,,,,,,,,,,, -19400,1.582755,1.8249006,,,,,,,,,,,,,, -19500,1.8869369,1.8827649,,,,,,,,,,,,,, -19600,1.7310817,1.7750777,,,,,,,,,,,,,, -19640,,,0.6862244606018066,1.2184010744094849,0.6089800000190735,1.6390219926834106,50000.0,0.4882000088691711,2.403407573699951,10000.0,6683.125174283981,6979.279679298401,6683.125174283981,294.9985795021057,0.404205322265625,0.0 -19700,1.9254893,1.7070328,,,,,,,,,,,,,, -19800,1.6713064,1.8130748,,,,,,,,,,,,,, -19900,1.6010287,1.9186953,,,,,,,,,,,,,, -20000,1.8686383,1.7836716,,,,,,,,,,,,,, -20100,1.6340289,1.8705729,,,,,,,,,,,,,, -20200,1.6677495,1.8937995,,,,,,,,,,,,,, -20300,1.8009079,1.7794117,,,,,,,,,,,,,, -20400,1.5412914,1.7734712,,,,,,,,,,,,,, -20500,1.7813277,1.7111675,,,,,,,,,,,,,, -20600,1.622365,1.7863874,,,,,,,,,,,,,, -20700,1.6264894,1.6779234,,,,,,,,,,,,,, -20800,1.6881653,1.8822786,,,,,,,,,,,,,, -20900,1.9715352,1.8669021,,,,,,,,,,,,,, -21000,1.8347888,1.8609543,,,,,,,,,,,,,, -21100,1.6909553,1.8232273,,,,,,,,,,,,,, -21154,,,0.6550542116165161,1.37153160572052,0.5969399809837341,1.7031642198562622,50000.0,0.4700000286102295,2.498812675476074,10000.0,7193.361950397492,7513.631420373917,7193.361950397492,319.0235517024994,0.4361577033996582,0.0 -21200,1.6147835,1.8241978,,,,,,,,,,,,,, -21300,1.5547374,1.6570282,,,,,,,,,,,,,, -21400,1.6188867,1.8027211,,,,,,,,,,,,,, -21500,1.7606976,1.7610878,,,,,,,,,,,,,, -21600,1.6958638,1.7417771,,,,,,,,,,,,,, -21700,1.6299133,1.8664074,,,,,,,,,,,,,, -21800,1.643528,1.7790616,,,,,,,,,,,,,, -21900,1.701682,1.7504649,,,,,,,,,,,,,, -22000,1.6077346,1.8309565,,,,,,,,,,,,,, -22100,1.5726554,1.8166993,,,,,,,,,,,,,, -22200,1.6819582,1.7847759,,,,,,,,,,,,,, -22300,1.8475362,1.8187294,,,,,,,,,,,,,, -22400,1.4998423,1.7828966,,,,,,,,,,,,,, -22500,1.9703578,1.7650914,,,,,,,,,,,,,, -22600,1.744322,1.7570374,,,,,,,,,,,,,, -22667,,,0.6710379123687744,1.3040506839752195,0.6122399568557739,1.612977147102356,50000.0,0.4869000315666199,2.3287904262542725,10000.0,7703.452437400818,8047.973657846451,7703.452437400818,343.1838707923889,0.4693658351898193,0.0 -22700,1.6415763,1.758215,,,,,,,,,,,,,, -22800,1.7392107,1.7878218,,,,,,,,,,,,,, -22900,1.8708726,1.7823013,,,,,,,,,,,,,, -23000,1.6649425,1.6949991,,,,,,,,,,,,,, -23100,1.6889848,1.6772743,,,,,,,,,,,,,, -23200,1.7187599,1.8261073,,,,,,,,,,,,,, -23300,1.6597631,1.7577337,,,,,,,,,,,,,, -23400,1.8255867,1.7568953,,,,,,,,,,,,,, -23500,1.5471114,1.8507861,,,,,,,,,,,,,, -23600,1.661714,1.7927265,,,,,,,,,,,,,, -23700,1.664071,1.7378991,,,,,,,,,,,,,, -23800,1.5543909,1.7644842,,,,,,,,,,,,,, -23900,1.6978736,1.8388684,,,,,,,,,,,,,, -24000,1.5793328,1.7958548,,,,,,,,,,,,,, -24100,1.7433339,1.8746105,,,,,,,,,,,,,, -24180,,,0.6864636540412903,1.2356112003326416,0.6240800023078918,1.556673288345337,50000.0,0.5056000351905823,2.2998135089874268,10000.0,8213.370545864105,8580.17731165886,8213.370545864105,365.3846924304962,0.4949944019317627,0.0 -24200,1.6362047,1.7512709,,,,,,,,,,,,,, -24300,1.4731603,1.6123999,,,,,,,,,,,,,, -24400,1.5946724,1.6675758,,,,,,,,,,,,,, -24500,1.9733478,1.8123364,,,,,,,,,,,,,, -24600,1.7838562,1.6937761,,,,,,,,,,,,,, -24700,1.7962767,1.7498477,,,,,,,,,,,,,, -24800,1.9333053,1.690477,,,,,,,,,,,,,, -24900,3.0759165,1.7490263,,,,,,,,,,,,,, -25000,1.5882602,1.7515358,,,,,,,,,,,,,, -25100,1.6683745,1.7197737,,,,,,,,,,,,,, -25200,1.661822,1.8452111,,,,,,,,,,,,,, -25300,1.946386,1.8149532,,,,,,,,,,,,,, -25400,1.8164821,1.7300712,,,,,,,,,,,,,, -25500,1.6159796,1.6597453,,,,,,,,,,,,,, -25600,1.7131199,1.6721947,,,,,,,,,,,,,, -25693,,,0.6590800285339355,1.3417572975158691,0.6096000075340271,1.6214866638183594,50000.0,0.4832000136375427,2.368893623352051,10000.0,8723.534445524216,9114.901677131653,8723.534445524216,389.8584134578705,0.5229532718658447,0.0 -25700,1.7092046,1.7963095,,,,,,,,,,,,,, -25800,1.5306559,1.8093121,,,,,,,,,,,,,, -25900,1.7991489,1.6492687,,,,,,,,,,,,,, -26000,1.7992271,1.6666288,,,,,,,,,,,,,, -26100,1.7391658,1.6482435,,,,,,,,,,,,,, -26200,1.6634297,1.6654483,,,,,,,,,,,,,, -26300,1.9197202,1.8180778,,,,,,,,,,,,,, -26400,1.8828467,1.6764752,,,,,,,,,,,,,, -26500,1.6188014,1.9128705,,,,,,,,,,,,,, -26600,1.6775634,1.7462239,,,,,,,,,,,,,, -26700,1.870505,1.7574397,,,,,,,,,,,,,, -26800,1.9833919,1.7961587,,,,,,,,,,,,,, -26900,1.8474822,1.7472243,,,,,,,,,,,,,, -27000,1.6737914,1.7760328,,,,,,,,,,,,,, -27100,2.1414528,1.8446581,,,,,,,,,,,,,, -27200,1.7536675,1.807035,,,,,,,,,,,,,, -27205,,,0.7247090339660645,1.0631380081176758,0.6251999735832214,1.5556570291519165,50000.0,0.49590003490448,2.290503978729248,10000.0,9233.515226125715,9648.93862605095,9233.515226125715,413.8295383453369,0.5503711700439453,0.0 -27300,1.6632384,1.7041137,,,,,,,,,,,,,, -27400,1.8879865,1.6060812,,,,,,,,,,,,,, -27500,2.4208739,1.8213053,,,,,,,,,,,,,, -27600,1.848526,1.8159248,,,,,,,,,,,,,, -27700,1.7929734,1.7791762,,,,,,,,,,,,,, -27800,1.6233969,1.6831247,,,,,,,,,,,,,, -27900,1.6096083,1.68277,,,,,,,,,,,,,, -28000,1.7078328,1.8288246,,,,,,,,,,,,,, -28100,1.6526729,1.7582898,,,,,,,,,,,,,, -28200,1.7849393,1.701761,,,,,,,,,,,,,, -28300,1.715652,1.8171496,,,,,,,,,,,,,, -28400,2.1794362,1.8236572,,,,,,,,,,,,,, -28500,1.7122,1.7076123,,,,,,,,,,,,,, -28600,2.1266706,1.6688493,,,,,,,,,,,,,, -28700,1.6977936,1.8224399,,,,,,,,,,,,,, -28718,,,0.6881178021430969,1.2257320880889893,0.6116399765014648,1.6099690198898315,50000.0,0.4818000197410583,2.3851914405822754,10000.0,9743.726099729538,10183.951251506804,9743.726099729538,438.5450539588928,0.5776462554931641,0.0 -28800,1.9301894,1.7438018,,,,,,,,,,,,,, -28900,1.9368728,1.7479682,,,,,,,,,,,,,, -29000,1.645376,1.6877879,,,,,,,,,,,,,, -29100,1.7427735,1.7694455,,,,,,,,,,,,,, -29200,1.8470742,1.7981681,,,,,,,,,,,,,, -29300,1.7625341,1.7832849,,,,,,,,,,,,,, -29400,1.6557577,1.6993498,,,,,,,,,,,,,, -29500,1.9911366,1.7510933,,,,,,,,,,,,,, -29600,1.7920705,1.7896116,,,,,,,,,,,,,, -29700,1.7723284,1.7097527,,,,,,,,,,,,,, -29800,1.6819301,1.6741612,,,,,,,,,,,,,, -29900,1.8501315,1.740439,,,,,,,,,,,,,, -30000,1.5451043,1.6378725,,,,,,,,,,,,,, -30100,1.8889463,1.677754,,,,,,,,,,,,,, -30200,1.6319487,1.550773,,,,,,,,,,,,,, -30231,,,0.69140625,1.2178308963775637,0.6235600113868713,1.5551875829696655,50000.0,0.4962000250816345,2.30889105796814,10000.0,10253.642538785934,10724.887073040009,10253.642538785934,469.4787917137146,0.6060595512390137,0.0 -30300,1.7775139,1.7427425,,,,,,,,,,,,,, -30400,1.9305066,1.7735354,,,,,,,,,,,,,, -30500,1.7106091,1.6997967,,,,,,,,,,,,,, -30600,1.9209496,1.7093272,,,,,,,,,,,,,, -30700,1.7200736,1.7255807,,,,,,,,,,,,,, -30800,1.9536926,1.8067852,,,,,,,,,,,,,, -30900,1.7986218,1.6891694,,,,,,,,,,,,,, -31000,1.8113177,1.7416563,,,,,,,,,,,,,, -31100,1.6603675,1.6797113,,,,,,,,,,,,,, -31200,1.8286161,1.8894982,,,,,,,,,,,,,, -31300,1.7704059,1.7132586,,,,,,,,,,,,,, -31400,1.8420517,1.6631413,,,,,,,,,,,,,, -31500,1.7877463,1.6524758,,,,,,,,,,,,,, -31600,1.9695526,1.661059,,,,,,,,,,,,,, -31700,1.8094757,1.6663795,,,,,,,,,,,,,, -31744,,,0.6746252775192261,1.277440309524536,0.6144999861717224,1.6090742349624634,50000.0,0.4924000203609466,2.371462106704712,10000.0,10763.705124855042,11260.48165845871,10763.705124855042,494.9138953685761,0.6445157527923584,0.0 -31800,1.5690203,1.5562472,,,,,,,,,,,,,, -31900,1.6859044,1.7996023,,,,,,,,,,,,,, -32000,1.7909403,1.7624246,,,,,,,,,,,,,, -32100,1.7755542,1.6800619,,,,,,,,,,,,,, -32200,1.6184202,1.7091848,,,,,,,,,,,,,, -32300,1.9211539,1.7030914,,,,,,,,,,,,,, -32400,1.6786556,1.7718942,,,,,,,,,,,,,, -32500,1.7910172,1.7440947,,,,,,,,,,,,,, -32600,1.7315025,1.718528,,,,,,,,,,,,,, -32700,2.0392017,1.7452519,,,,,,,,,,,,,, -32800,2.046247,1.7838413,,,,,,,,,,,,,, -32900,1.6154962,1.7304194,,,,,,,,,,,,,, -33000,1.8100866,1.6736925,,,,,,,,,,,,,, -33100,1.8312669,1.6338999,,,,,,,,,,,,,, -33200,2.119631,1.7084752,,,,,,,,,,,,,, -33257,,,0.6889747977256775,1.2175207138061523,0.6301999688148499,1.5399216413497925,50000.0,0.5009000301361084,2.308980941772461,10000.0,11273.895986318588,11795.578994750977,11273.895986318588,519.7346889972687,0.6725142002105713,0.0 -33300,1.8889303,1.7353569,,,,,,,,,,,,,, -33400,1.6849626,1.7047944,,,,,,,,,,,,,, -33500,1.6499513,1.6601869,,,,,,,,,,,,,, -33600,1.6932236,1.7486501,,,,,,,,,,,,,, -33700,2.0045106,1.6967164,,,,,,,,,,,,,, -33800,1.5811207,1.6137464,,,,,,,,,,,,,, -33900,1.6118753,1.5997355,,,,,,,,,,,,,, -34000,1.7651417,1.7666082,,,,,,,,,,,,,, -34100,1.7898118,1.7673178,,,,,,,,,,,,,, -34200,1.818182,1.7386318,,,,,,,,,,,,,, -34300,2.07983,1.5874617,,,,,,,,,,,,,, -34400,1.6483399,1.5351917,,,,,,,,,,,,,, -34500,1.9562587,1.7987256,,,,,,,,,,,,,, -34600,1.7405467,1.6415558,,,,,,,,,,,,,, -34700,2.0063937,1.6596308,,,,,,,,,,,,,, -34770,,,0.6912468075752258,1.2342867851257324,0.6308000087738037,1.5375585556030271,50000.0,0.5055000185966492,2.2644104957580566,10000.0,11783.841839790344,12330.82792854309,11783.841839790344,544.9519906044006,0.7002818584442139,0.0 -34800,1.7315328,1.7440223,,,,,,,,,,,,,, -34900,1.6876512,1.6045309,,,,,,,,,,,,,, -35000,1.6352832,1.7236608,,,,,,,,,,,,,, -35100,1.7420237,1.6947806,,,,,,,,,,,,,, -35200,2.1121743,1.7533038,,,,,,,,,,,,,, -35300,1.7999699,1.7215891,,,,,,,,,,,,,, -35400,1.7959293,1.6433084,,,,,,,,,,,,,, -35500,1.6416912,1.7018228,,,,,,,,,,,,,, -35600,1.9127235,1.5862087,,,,,,,,,,,,,, -35700,1.7546163,1.7238108,,,,,,,,,,,,,, -35800,1.9113019,1.7229443,,,,,,,,,,,,,, -35900,1.9540359,1.7406565,,,,,,,,,,,,,, -36000,1.6835867,1.5980557,,,,,,,,,,,,,, -36100,1.8680452,1.6691747,,,,,,,,,,,,,, -36200,1.6064032,1.662628,,,,,,,,,,,,,, -36284,,,0.7382612824440002,1.002962589263916,0.6293799877166748,1.5438278913497925,50000.0,0.4988000094890594,2.274627685546875,10000.0,12293.991918563845,12866.157500267029,12293.991918563845,570.0424513816833,0.7297089099884033,0.0 -36300,1.669192,1.7307786,,,,,,,,,,,,,, -36400,2.1437588,1.69293,,,,,,,,,,,,,, -36500,1.9637936,1.7366012,,,,,,,,,,,,,, -36600,1.6034343,1.6271431,,,,,,,,,,,,,, -36700,1.5797594,1.6314896,,,,,,,,,,,,,, -36800,2.111886,1.6737071,,,,,,,,,,,,,, -36900,1.5988281,1.6916478,,,,,,,,,,,,,, -37000,1.996287,1.7170875,,,,,,,,,,,,,, -37100,1.7919449,1.6142318,,,,,,,,,,,,,, -37200,1.806186,1.6459385,,,,,,,,,,,,,, -37300,1.9395099,1.6853143,,,,,,,,,,,,,, -37400,1.7886956,1.6707717,,,,,,,,,,,,,, -37500,1.7340345,1.5858514,,,,,,,,,,,,,, -37600,1.6129466,1.6200736,,,,,,,,,,,,,, -37700,1.858018,1.6257609,,,,,,,,,,,,,, -37797,,,0.7053372263908386,1.138272404670715,0.6288599967956543,1.5252172946929932,50000.0,0.5065000057220459,2.259315729141236,10000.0,12803.992875099182,13398.519066810608,12803.992875099182,592.3089263439178,0.7659671306610107,0.0 -37800,1.8179059,1.9112729,,,,,,,,,,,,,, -37900,2.070345,1.795608,,,,,,,,,,,,,, -38000,1.84298,1.6478102,,,,,,,,,,,,,, -38100,1.7240342,1.56861,,,,,,,,,,,,,, -38200,1.6472508,1.6442431,,,,,,,,,,,,,, -38300,1.8840874,1.7022481,,,,,,,,,,,,,, -38400,1.7643762,1.7831466,,,,,,,,,,,,,, -38500,1.7089481,1.6247162,,,,,,,,,,,,,, -38600,1.8000805,1.6018305,,,,,,,,,,,,,, -38700,1.7471225,1.6817808,,,,,,,,,,,,,, -38800,1.8043563,1.6336643,,,,,,,,,,,,,, -38900,1.8423793,1.8024585,,,,,,,,,,,,,, -39000,1.9688551,1.6918204,,,,,,,,,,,,,, -39100,2.1881614,1.7293241,,,,,,,,,,,,,, -39200,1.7297869,1.6251935,,,,,,,,,,,,,, -39300,1.9433979,1.6717246,,,,,,,,,,,,,, -39311,,,0.6939771771430969,1.196521520614624,0.6310999989509583,1.5323960781097412,50000.0,0.5081000328063965,2.251711368560791,10000.0,13313.982473373411,13929.336554050446,13313.982473373411,613.0463311672211,0.7976090908050537,0.0 -39400,1.9420614,1.642654,,,,,,,,,,,,,, -39500,1.9144282,1.6485133,,,,,,,,,,,,,, -39600,1.5714935,1.5133929,,,,,,,,,,,,,, -39700,1.7427922,1.5999371,,,,,,,,,,,,,, -39800,1.9300032,1.5190897,,,,,,,,,,,,,, -39900,1.7195084,1.64235,,,,,,,,,,,,,, -40000,1.8131696,1.7175002,,,,,,,,,,,,,, -40100,2.0681634,1.6839865,,,,,,,,,,,,,, -40200,1.8583759,1.6716721,,,,,,,,,,,,,, -40300,1.9245945,1.647153,,,,,,,,,,,,,, -40400,1.795036,1.6910162,,,,,,,,,,,,,, -40500,2.0774395,1.6744609,,,,,,,,,,,,,, -40600,1.7616277,1.6372689,,,,,,,,,,,,,, -40700,1.8157532,1.5979093,,,,,,,,,,,,,, -40800,1.7028618,1.618985,,,,,,,,,,,,,, -40825,,,0.7027463316917419,1.174934148788452,0.6353999972343445,1.4998513460159302,50000.0,0.5074000358581543,2.224247932434082,10000.0,13824.217654943466,14457.974243879318,13824.217654943466,631.3619570732117,0.8261539936065674,0.0 -40900,1.9177234,1.6531024,,,,,,,,,,,,,, -41000,1.8636732,1.6241128,,,,,,,,,,,,,, -41100,1.7289406,1.6746798,,,,,,,,,,,,,, -41200,1.5950232,1.5167325,,,,,,,,,,,,,, -41300,1.7819276,1.6110039,,,,,,,,,,,,,, -41400,2.0596592,1.6234688,,,,,,,,,,,,,, -41500,2.0170386,1.7102287,,,,,,,,,,,,,, -41600,1.7450283,1.6076106,,,,,,,,,,,,,, -41700,1.7843752,1.749226,,,,,,,,,,,,,, -41800,1.7597697,1.6286288,,,,,,,,,,,,,, -41900,1.8270282,1.5985856,,,,,,,,,,,,,, -42000,1.9678686,1.708885,,,,,,,,,,,,,, -42100,1.7249777,1.6770004,,,,,,,,,,,,,, -42200,1.9026892,1.6091053,,,,,,,,,,,,,, -42300,1.7655714,1.5867363,,,,,,,,,,,,,, -42339,,,0.6916852593421936,1.219604253768921,0.6343799829483032,1.524438977241516,50000.0,0.5037000179290771,2.296604633331299,10000.0,14334.139183282852,14986.990688800812,14334.139183282852,650.3616156578064,0.861600399017334,0.0 -42400,1.8022246,1.7103976,,,,,,,,,,,,,, -42500,1.8027142,1.5993816,,,,,,,,,,,,,, -42600,1.7115868,1.6665106,,,,,,,,,,,,,, -42700,1.898166,1.6495798,,,,,,,,,,,,,, -42800,1.706465,1.6857991,,,,,,,,,,,,,, -42900,1.8572551,1.690938,,,,,,,,,,,,,, -43000,1.8375357,1.6627665,,,,,,,,,,,,,, -43100,1.8694293,1.6385052,,,,,,,,,,,,,, -43200,1.7564642,1.7111684,,,,,,,,,,,,,, -43300,1.8458353,1.527071,,,,,,,,,,,,,, -43400,1.9415607,1.7493443,,,,,,,,,,,,,, -43500,1.7124918,1.5028121,,,,,,,,,,,,,, -43600,1.8088253,1.6562848,,,,,,,,,,,,,, -43700,1.7971184,1.5912191,,,,,,,,,,,,,, -43800,2.1054053,1.6439745,,,,,,,,,,,,,, -43853,,,0.687898576259613,1.2292335033416748,0.631060004234314,1.523308038711548,50000.0,0.4993000328540802,2.297249317169189,10000.0,14844.313416957855,15518.24677681923,14844.313416957855,671.3503079414368,0.8955569267272949,0.0 -43900,1.6633942,1.5615373,,,,,,,,,,,,,, -44000,1.9396732,1.6674027,,,,,,,,,,,,,, -44100,1.9820274,1.6353456,,,,,,,,,,,,,, -44200,1.7118373,1.4095033,,,,,,,,,,,,,, -44300,1.7645737,1.5312052,,,,,,,,,,,,,, -44400,1.9098017,1.6016273,,,,,,,,,,,,,, -44500,1.9698118,1.6493666,,,,,,,,,,,,,, -44600,1.7261196,1.5401231,,,,,,,,,,,,,, -44700,1.7943131,1.5260454,,,,,,,,,,,,,, -44800,1.9496692,1.6745375,,,,,,,,,,,,,, -44900,1.6781745,1.714818,,,,,,,,,,,,,, -45000,1.9524102,1.7066851,,,,,,,,,,,,,, -45100,1.7499081,1.6103867,,,,,,,,,,,,,, -45200,1.6589776,1.5251887,,,,,,,,,,,,,, -45300,1.8137925,1.6860503,,,,,,,,,,,,,, -45367,,,0.7377431392669678,0.9952830076217652,0.6425999999046326,1.4854003190994265,50000.0,0.5103000402450562,2.2528679370880127,10000.0,15354.307039022446,16047.569056749344,15354.307039022446,690.5901215076447,0.9257237911224364,0.0 -45400,1.8072822,1.7647308,,,,,,,,,,,,,, -45500,1.7376457,1.5752667,,,,,,,,,,,,,, -45600,2.0929368,1.7084794,,,,,,,,,,,,,, -45700,1.7108518,1.5616333,,,,,,,,,,,,,, -45800,1.8412097,1.6479698,,,,,,,,,,,,,, -45900,1.8414034,1.5577942,,,,,,,,,,,,,, -46000,1.7188845,1.5090754,,,,,,,,,,,,,, -46100,1.9176855,1.531621,,,,,,,,,,,,,, -46200,1.7646648,1.6596208,,,,,,,,,,,,,, -46300,1.8261205,1.6440427,,,,,,,,,,,,,, -46400,1.8860267,1.5642514,,,,,,,,,,,,,, -46500,1.7911955,1.5230337,,,,,,,,,,,,,, -46600,1.7643588,1.5546079,,,,,,,,,,,,,, -46700,2.224818,1.7174026,,,,,,,,,,,,,, -46800,2.027948,1.6495056,,,,,,,,,,,,,, -46881,,,0.7153220772743225,1.084122657775879,0.6393199563026428,1.4918371438980105,50000.0,0.5138000249862671,2.251056432723999,10000.0,15864.34359741211,16575.43159174919,15864.34359741211,708.3193356990814,0.9613451957702636,0.0 -46900,2.037182,1.674497,,,,,,,,,,,,,, -47000,1.8395898,1.6086125,,,,,,,,,,,,,, -47100,1.6956636,1.6406492,,,,,,,,,,,,,, -47200,1.7446284,1.568398,,,,,,,,,,,,,, -47300,1.7831733,1.5963304,,,,,,,,,,,,,, -47400,1.7720287,1.6349988,,,,,,,,,,,,,, -47500,1.7304708,1.61538,,,,,,,,,,,,,, -47600,1.8963495,1.7117263,,,,,,,,,,,,,, -47700,1.7215266,1.6312943,,,,,,,,,,,,,, -47800,1.9180332,1.7353221,,,,,,,,,,,,,, -47900,1.7437007,1.6112647,,,,,,,,,,,,,, -48000,1.9097031,1.5427195,,,,,,,,,,,,,, -48100,1.7729822,1.6761011,,,,,,,,,,,,,, -48200,1.7849709,1.6311111,,,,,,,,,,,,,, -48300,2.3301404,1.7608223,,,,,,,,,,,,,, -48395,,,0.702168345451355,1.1595250368118286,0.6326000094413757,1.5194488763809204,50000.0,0.5082000494003296,2.254747152328491,10000.0,16374.491770744324,17103.453681230545,16374.491770744324,726.0973536968231,0.9971561431884766,0.0 -48400,1.8506496,1.5040852,,,,,,,,,,,,,, -48500,1.842258,1.6092229,,,,,,,,,,,,,, -48600,1.754366,1.4724259,,,,,,,,,,,,,, -48700,1.9223102,1.6901186,,,,,,,,,,,,,, -48800,1.7436087,1.566951,,,,,,,,,,,,,, -48900,1.8315524,1.5452057,,,,,,,,,,,,,, -49000,1.9109776,1.6367509,,,,,,,,,,,,,, -49100,1.8577127,1.6766222,,,,,,,,,,,,,, -49200,1.861055,1.6075779,,,,,,,,,,,,,, -49300,1.6310251,1.5526203,,,,,,,,,,,,,, -49400,1.8984646,1.6315345,,,,,,,,,,,,,, -49500,1.8676273,1.6040344,,,,,,,,,,,,,, -49600,2.0901203,1.6289675,,,,,,,,,,,,,, -49700,1.8936996,1.6959966,,,,,,,,,,,,,, -49800,1.9307303,1.5960709,,,,,,,,,,,,,, -49900,1.9561877,1.613664,,,,,,,,,,,,,, -49908,,,0.7033841013908386,1.1538665294647217,0.6426599621772766,1.476432204246521,50000.0,0.5149000287055969,2.219771146774292,10000.0,16884.507195949554,17632.462087869644,16884.507195949554,744.994348526001,1.0333445072174072,0.0 -50000,1.8201147,1.6326486,,,,,,,,,,,,,, -50100,1.8144157,1.5122905,,,,,,,,,,,,,, -50200,2.0321555,1.6396341,,,,,,,,,,,,,, -50300,1.7593563,1.503942,,,,,,,,,,,,,, -50400,1.8399953,1.6546555,,,,,,,,,,,,,, -50500,2.063157,1.6360195,,,,,,,,,,,,,, -50600,2.1517498,1.5625364,,,,,,,,,,,,,, -50700,1.7782059,1.6293504,,,,,,,,,,,,,, -50800,1.8737606,1.6626003,,,,,,,,,,,,,, -50900,1.8128393,1.6836706,,,,,,,,,,,,,, -51000,1.9382845,1.5451615,,,,,,,,,,,,,, -51100,1.7853475,1.604451,,,,,,,,,,,,,, -51200,1.8647053,1.6091925,,,,,,,,,,,,,, -51300,1.961022,1.5225687,,,,,,,,,,,,,, -51400,1.9671159,1.5501986,,,,,,,,,,,,,, -51422,,,0.7069913744926453,1.1432219743728638,0.6446399688720703,1.4551054239273071,50000.0,0.5091000199317932,2.231141805648804,10000.0,17394.70106601715,18161.12567853928,17394.70106601715,763.3682973384857,1.068270206451416,0.0 -51500,1.7840298,1.6916378,,,,,,,,,,,,,, -51600,1.7471213,1.5196643,,,,,,,,,,,,,, -51700,1.784851,1.5665019,,,,,,,,,,,,,, -51800,1.975046,1.6922326,,,,,,,,,,,,,, -51900,2.206518,1.6276752,,,,,,,,,,,,,, -52000,1.9777554,1.5885341,,,,,,,,,,,,,, -52100,1.8210564,1.5662884,,,,,,,,,,,,,, -52200,1.9979073,1.6762123,,,,,,,,,,,,,, -52300,2.0591154,1.6518795,,,,,,,,,,,,,, -52400,1.8059764,1.5154252,,,,,,,,,,,,,, -52500,1.8441205,1.5435044,,,,,,,,,,,,,, -52600,2.0011482,1.6351871,,,,,,,,,,,,,, -52700,1.9395318,1.6101598,,,,,,,,,,,,,, -52800,1.6542082,1.7586071,,,,,,,,,,,,,, -52900,1.6918546,1.5049154,,,,,,,,,,,,,, -52935,,,0.696707546710968,1.1853846311569214,0.6376000046730042,1.4835251569747925,50000.0,0.5148000121116638,2.219569206237793,10000.0,17904.66800928116,18689.03200531006,17904.66800928116,781.212443113327,1.1037437915802002,0.0 -53000,2.3973956,1.6765339,,,,,,,,,,,,,, -53100,2.0380313,1.6164019,,,,,,,,,,,,,, -53200,2.017361,1.5222021,,,,,,,,,,,,,, -53300,2.080743,1.625632,,,,,,,,,,,,,, -53400,1.903425,1.6651053,,,,,,,,,,,,,, -53500,2.1886241,1.6017721,,,,,,,,,,,,,, -53600,2.1437733,1.6527607,,,,,,,,,,,,,, -53700,1.8147242,1.5725499,,,,,,,,,,,,,, -53800,1.9230226,1.587649,,,,,,,,,,,,,, -53900,1.7024443,1.4653502,,,,,,,,,,,,,, -54000,1.7323463,1.5352571,,,,,,,,,,,,,, -54100,2.0315075,1.6448927,,,,,,,,,,,,,, -54200,1.7012182,1.5511312,,,,,,,,,,,,,, -54300,1.7681062,1.5958827,,,,,,,,,,,,,, -54400,1.823703,1.635513,,,,,,,,,,,,,, -54449,,,0.7293726205825806,1.033711075782776,0.6370999813079834,1.4978829622268677,50000.0,0.5045000314712524,2.2475028038024902,10000.0,18414.68908238411,19216.77681660652,18414.68908238411,798.8383796215057,1.1416583061218262,0.0 -54500,1.9301919,1.5083666,,,,,,,,,,,,,, -54600,1.8912126,1.5249684,,,,,,,,,,,,,, -54700,1.9124497,1.5557861,,,,,,,,,,,,,, -54800,1.691279,1.6419204,,,,,,,,,,,,,, -54900,1.9646642,1.7522244,,,,,,,,,,,,,, -55000,1.936414,1.6629493,,,,,,,,,,,,,, -55100,1.7984312,1.5665572,,,,,,,,,,,,,, -55200,1.8926408,1.5468404,,,,,,,,,,,,,, -55300,1.9322115,1.6092861,,,,,,,,,,,,,, -55400,1.8774182,1.5386064,,,,,,,,,,,,,, -55500,1.6320914,1.470083,,,,,,,,,,,,,, -55600,1.9639038,1.6175181,,,,,,,,,,,,,, -55700,1.8137088,1.5016667,,,,,,,,,,,,,, -55800,1.8581752,1.6466268,,,,,,,,,,,,,, -55900,2.0356882,1.6386635,,,,,,,,,,,,,, -55962,,,0.7280173897743225,1.0381181240081787,0.6498199701309204,1.4298878908157349,50000.0,0.5263000130653381,2.128363609313965,10000.0,18924.828382968903,19744.545511245728,18924.828382968903,816.3726131916046,1.1773545742034912,0.0 -56000,1.9740293,1.6052623,,,,,,,,,,,,,, -56100,1.8945374,1.6580094,,,,,,,,,,,,,, -56200,2.0858395,1.5283846,,,,,,,,,,,,,, -56300,1.71426,1.4118675,,,,,,,,,,,,,, -56400,1.7789646,1.4946123,,,,,,,,,,,,,, -56500,1.9349717,1.5058843,,,,,,,,,,,,,, -56600,2.4128435,1.6261916,,,,,,,,,,,,,, -56700,1.692794,1.5223006,,,,,,,,,,,,,, -56800,1.6597037,1.5123963,,,,,,,,,,,,,, -56900,1.8585466,1.5240387,,,,,,,,,,,,,, -57000,1.8597873,1.6233093,,,,,,,,,,,,,, -57100,1.8428437,1.5527515,,,,,,,,,,,,,, -57200,1.9746844,1.4678699,,,,,,,,,,,,,, -57300,1.7975721,1.5960727,,,,,,,,,,,,,, -57400,2.0780797,1.6352643,,,,,,,,,,,,,, -57476,,,0.7173150181770325,1.0886608362197876,0.6471399664878845,1.4522713422775269,50000.0,0.5209000110626221,2.188610553741455,10000.0,19434.806651115417,20273.25348353386,19434.806651115417,834.9931175708771,1.2265896797180176,0.0 -57500,1.6916355,1.5354998,,,,,,,,,,,,,, -57600,1.7749473,1.6360111,,,,,,,,,,,,,, -57700,1.9226035,1.6187545,,,,,,,,,,,,,, -57800,1.716158,1.4681065,,,,,,,,,,,,,, -57900,1.8078401,1.5779502,,,,,,,,,,,,,, -58000,1.9195731,1.5595474,,,,,,,,,,,,,, -58100,1.7670248,1.4919074,,,,,,,,,,,,,, -58200,1.884494,1.5280154,,,,,,,,,,,,,, -58300,1.7532206,1.5911833,,,,,,,,,,,,,, -58400,1.8757855,1.591252,,,,,,,,,,,,,, -58500,1.8567898,1.4582464,,,,,,,,,,,,,, -58600,1.8986828,1.4586266,,,,,,,,,,,,,, -58700,1.7998582,1.5445619,,,,,,,,,,,,,, -58800,2.1175892,1.559702,,,,,,,,,,,,,, -58900,1.9422857,1.4097217,,,,,,,,,,,,,, -58989,,,0.7143853306770325,1.1113851070404053,0.6462999582290649,1.4568002223968506,50000.0,0.526900053024292,2.157479286193848,10000.0,19944.79716658592,20804.59820485115,19944.79716658592,856.2516252994537,1.2634735107421875,0.0 -59000,1.99486,1.6293705,,,,,,,,,,,,,, -59100,1.8231423,1.5826097,,,,,,,,,,,,,, -59200,1.9371215,1.4896657,,,,,,,,,,,,,, -59300,1.8006245,1.5888466,,,,,,,,,,,,,, -59400,1.8061665,1.5518532,,,,,,,,,,,,,, -59500,2.0002873,1.6716357,,,,,,,,,,,,,, -59600,1.8181603,1.5446322,,,,,,,,,,,,,, -59700,1.9087605,1.5735905,,,,,,,,,,,,,, -59800,1.9738361,1.6422416,,,,,,,,,,,,,, -59900,1.9143001,1.6505692,,,,,,,,,,,,,, -60000,1.8369737,1.5498043,,,,,,,,,,,,,, -60100,1.9391737,1.5694988,,,,,,,,,,,,,, -60200,2.1191964,1.6406265,,,,,,,,,,,,,, -60300,1.8270149,1.5825565,,,,,,,,,,,,,, -60400,1.8252943,1.4272356,,,,,,,,,,,,,, -60500,2.1781225,1.5367099,,,,,,,,,,,,,, -60504,,,0.7247688174247742,1.0654748678207395,0.6577799916267395,1.3966306447982788,50000.0,0.5327000021934509,2.102976560592652,10000.0,20454.971937417984,21332.455352783203,20454.971937417984,873.8442769050598,1.2946317195892334,0.0 -60600,1.7418793,1.4766576,,,,,,,,,,,,,, -60700,1.9755032,1.6365385,,,,,,,,,,,,,, -60800,2.2673783,1.5185363,,,,,,,,,,,,,, -60900,2.0249603,1.7033817,,,,,,,,,,,,,, -61000,1.862848,1.5129389,,,,,,,,,,,,,, -61100,2.0058749,1.6733452,,,,,,,,,,,,,, -61200,1.8668975,1.5285059,,,,,,,,,,,,,, -61300,2.0028965,1.5629375,,,,,,,,,,,,,, -61400,1.9128162,1.5297052,,,,,,,,,,,,,, -61500,1.8312836,1.4794422,,,,,,,,,,,,,, -61600,1.8678476,1.4721444,,,,,,,,,,,,,, -61700,2.0178716,1.5313723,,,,,,,,,,,,,, -61800,2.1204762,1.5497723,,,,,,,,,,,,,, -61900,1.9766058,1.5984731,,,,,,,,,,,,,, -62000,1.8487706,1.6407526,,,,,,,,,,,,,, -62017,,,0.7164580821990967,1.0905795097351074,0.6516599655151367,1.424889326095581,50000.0,0.5214000344276428,2.1913771629333496,10000.0,20964.876480817795,21860.027040958405,20964.876480817795,891.4071831703186,1.3386225700378418,0.0 -62100,1.8534138,1.5217323,,,,,,,,,,,,,, -62200,1.9871573,1.4641635,,,,,,,,,,,,,, -62300,1.8776401,1.4782333,,,,,,,,,,,,,, -62400,1.9800828,1.6636125,,,,,,,,,,,,,, -62500,1.9719697,1.6528699,,,,,,,,,,,,,, -62600,2.1087818,1.5984443,,,,,,,,,,,,,, -62700,2.1662514,1.4866221,,,,,,,,,,,,,, -62800,1.990852,1.629673,,,,,,,,,,,,,, -62900,1.8611765,1.5049912,,,,,,,,,,,,,, -63000,2.2439113,1.5536,,,,,,,,,,,,,, -63100,1.8280427,1.5694597,,,,,,,,,,,,,, -63200,1.9254073,1.5132878,,,,,,,,,,,,,, -63300,1.8904876,1.5131264,,,,,,,,,,,,,, -63400,1.9298735,1.5445962,,,,,,,,,,,,,, -63500,2.1513135,1.558237,,,,,,,,,,,,,, -63531,,,0.7502989172935486,0.9399452805519104,0.6513599753379822,1.4273583889007568,50000.0,0.525700032711029,2.16554856300354,10000.0,21474.96597290039,22390.485930919647,21474.96597290039,911.6735122203828,1.3813552856445312,0.0 -63600,1.857038,1.5724082,,,,,,,,,,,,,, -63700,1.9667712,1.511977,,,,,,,,,,,,,, -63800,2.1036944,1.5436567,,,,,,,,,,,,,, -63900,1.800435,1.565966,,,,,,,,,,,,,, -64000,2.1544123,1.662038,,,,,,,,,,,,,, -64100,2.246263,1.6358675,,,,,,,,,,,,,, -64200,1.9092121,1.5577652,,,,,,,,,,,,,, -64300,2.024662,1.5247133,,,,,,,,,,,,,, -64400,2.0671928,1.5459421,,,,,,,,,,,,,, -64500,2.0000174,1.5638467,,,,,,,,,,,,,, -64600,1.8662071,1.5207999,,,,,,,,,,,,,, -64700,1.9781251,1.5334996,,,,,,,,,,,,,, -64800,2.0573559,1.6160679,,,,,,,,,,,,,, -64900,1.9285734,1.6695054,,,,,,,,,,,,,, -65000,1.7863734,1.5076125,,,,,,,,,,,,,, -65045,,,0.7389987111091614,0.9976285696029664,0.6594399809837341,1.386866569519043,50000.0,0.5302000045776367,2.1050965785980225,10000.0,21985.078320980072,22919.42360758781,21985.078320980072,930.4063663482666,1.414006233215332,0.0 -65100,1.875779,1.4615251,,,,,,,,,,,,,, -65200,2.1623886,1.5477947,,,,,,,,,,,,,, -65300,2.0985818,1.5639219,,,,,,,,,,,,,, -65400,2.212347,1.6537063,,,,,,,,,,,,,, -65500,1.8703312,1.5272478,,,,,,,,,,,,,, -65600,1.8657889,1.5239357,,,,,,,,,,,,,, -65700,2.0266824,1.4943974,,,,,,,,,,,,,, -65800,1.9795632,1.4166472,,,,,,,,,,,,,, -65900,1.9495717,1.5182194,,,,,,,,,,,,,, -66000,1.8973329,1.5884231,,,,,,,,,,,,,, -66100,1.9480011,1.4724939,,,,,,,,,,,,,, -66200,2.148431,1.5444436,,,,,,,,,,,,,, -66300,2.0599005,1.5605477,,,,,,,,,,,,,, -66400,1.9700115,1.5369406,,,,,,,,,,,,,, -66500,1.8211391,1.4980831,,,,,,,,,,,,,, -66559,,,0.7310267686843872,1.0304648876190186,0.6547399759292603,1.406494140625,50000.0,0.5330000519752502,2.112121820449829,10000.0,22495.0784740448,23447.982553720474,22495.0784740448,948.8673043251038,1.452296018600464,0.0 -66600,1.8304976,1.483492,,,,,,,,,,,,,, -66700,2.1109698,1.7218764,,,,,,,,,,,,,, -66800,1.9390965,1.5765655,,,,,,,,,,,,,, -66900,2.162518,1.4952347,,,,,,,,,,,,,, -67000,2.0094628,1.5279119,,,,,,,,,,,,,, -67100,1.7303542,1.4569867,,,,,,,,,,,,,, -67200,1.8720161,1.4094543,,,,,,,,,,,,,, -67300,1.9042621,1.537802,,,,,,,,,,,,,, -67400,1.8212197,1.5318515,,,,,,,,,,,,,, -67500,2.0388513,1.5637562,,,,,,,,,,,,,, -67600,2.261746,1.550018,,,,,,,,,,,,,, -67700,2.0212998,1.5157627,,,,,,,,,,,,,, -67800,2.0990715,1.6158111,,,,,,,,,,,,,, -67900,2.0737936,1.5760231,,,,,,,,,,,,,, -68000,2.0178585,1.4707041,,,,,,,,,,,,,, -68072,,,0.7293127775192261,1.0425559282302856,0.6594399809837341,1.407362461090088,50000.0,0.5356000065803528,2.10113525390625,10000.0,23005.171627759933,23975.81109070778,23005.171627759933,966.5051369667052,1.4912090301513672,0.0 -68100,1.8208382,1.5427754,,,,,,,,,,,,,, -68200,2.498908,1.5247756,,,,,,,,,,,,,, -68300,2.0449426,1.5693364,,,,,,,,,,,,,, -68400,2.1050668,1.4353383,,,,,,,,,,,,,, -68500,1.9832382,1.5865532,,,,,,,,,,,,,, -68600,1.9714917,1.525596,,,,,,,,,,,,,, -68700,2.0046122,1.5850085,,,,,,,,,,,,,, -68800,1.884903,1.4070792,,,,,,,,,,,,,, -68900,1.8485377,1.4433006,,,,,,,,,,,,,, -69000,2.1144538,1.4825357,,,,,,,,,,,,,, -69100,2.274894,1.5220091,,,,,,,,,,,,,, -69200,2.0241463,1.5236821,,,,,,,,,,,,,, -69300,2.0516555,1.6383992,,,,,,,,,,,,,, -69400,1.8379772,1.480201,,,,,,,,,,,,,, -69500,2.2454774,1.4745643,,,,,,,,,,,,,, -69586,,,0.7391382455825806,1.0020445585250854,0.6672199964523315,1.3593355417251587,50000.0,0.5320000052452087,2.1139161586761475,10000.0,23515.16109251976,24503.22127485276,23515.16109251976,983.824857711792,1.5330331325531006,0.0 -69600,2.2346575,1.6837485,,,,,,,,,,,,,, -69700,1.9695246,1.5972627,,,,,,,,,,,,,, -69800,2.0036268,1.5571495,,,,,,,,,,,,,, -69900,2.2001169,1.5319285,,,,,,,,,,,,,, -70000,1.8767494,1.5896008,,,,,,,,,,,,,, -70100,2.0123153,1.6179545,,,,,,,,,,,,,, -70200,1.8373909,1.3581951,,,,,,,,,,,,,, -70300,2.1316185,1.512336,,,,,,,,,,,,,, -70400,2.0092075,1.6247896,,,,,,,,,,,,,, -70500,2.050058,1.4562485,,,,,,,,,,,,,, -70600,1.9665891,1.5271674,,,,,,,,,,,,,, -70700,2.0510178,1.4686238,,,,,,,,,,,,,, -70800,2.0503957,1.4969003,,,,,,,,,,,,,, -70900,2.0400264,1.5256284,,,,,,,,,,,,,, -71000,2.0691504,1.5299896,,,,,,,,,,,,,, -71100,,,0.7205436825752258,1.0698813199996948,0.6531199812889099,1.4194097518920898,50000.0,0.5295000076293945,2.145573139190674,10000.0,24025.32474899292,25031.97860717773,24025.32474899292,1002.3198018074036,1.5721306800842283,0.0 -71100,1.9931139,1.4337703,,,,,,,,,,,,,, -71200,1.9618247,1.591291,,,,,,,,,,,,,, -71300,1.9400725,1.4679306,,,,,,,,,,,,,, -71400,1.9564625,1.5202827,,,,,,,,,,,,,, -71500,1.9407331,1.5035233,,,,,,,,,,,,,, -71600,1.8106464,1.5516351,,,,,,,,,,,,,, -71700,1.9108664,1.416775,,,,,,,,,,,,,, -71800,2.0671716,1.5341548,,,,,,,,,,,,,, -71900,2.0394022,1.6159877,,,,,,,,,,,,,, -72000,1.9980899,1.5187784,,,,,,,,,,,,,, -72100,2.027695,1.4285033,,,,,,,,,,,,,, -72200,1.9022782,1.5113297,,,,,,,,,,,,,, -72300,2.063795,1.5139785,,,,,,,,,,,,,, -72400,1.9747783,1.4818094,,,,,,,,,,,,,, -72500,2.0213094,1.5393693,,,,,,,,,,,,,, -72600,2.0725064,1.5295265,,,,,,,,,,,,,, -72614,,,0.74320387840271,0.9673858880996704,0.6480799913406372,1.4243541955947876,50000.0,0.5308000445365906,2.152554750442505,10000.0,24535.45041823387,25559.87971568108,24535.45041823387,1019.9946706295012,1.6126031875610352,0.0 -72700,2.0928786,1.5601448,,,,,,,,,,,,,, -72800,2.09312,1.4384173,,,,,,,,,,,,,, -72900,1.9254584,1.4715339,,,,,,,,,,,,,, -73000,2.2102625,1.4258003,,,,,,,,,,,,,, -73100,2.0142496,1.4387021,,,,,,,,,,,,,, -73200,1.8407111,1.4242948,,,,,,,,,,,,,, -73300,2.040233,1.5271949,,,,,,,,,,,,,, -73400,2.229516,1.4919684,,,,,,,,,,,,,, -73500,2.1704395,1.593165,,,,,,,,,,,,,, -73600,1.8984712,1.4782703,,,,,,,,,,,,,, -73700,2.055394,1.5013121,,,,,,,,,,,,,, -73800,2.1079974,1.520557,,,,,,,,,,,,,, -73900,2.0385013,1.4497404,,,,,,,,,,,,,, -74000,1.9554805,1.5170059,,,,,,,,,,,,,, -74100,1.9733479,1.5056195,,,,,,,,,,,,,, -74128,,,0.7434629797935486,0.981106460094452,0.6617599725723267,1.38577401638031,50000.0,0.5333000421524048,2.12503719329834,10000.0,25045.385497808456,26087.53072452545,25045.385497808456,1037.6091718673706,1.6538634300231934,0.0 -74200,1.9032617,1.4713101,,,,,,,,,,,,,, -74300,2.1944308,1.459111,,,,,,,,,,,,,, -74400,2.0465019,1.4100105,,,,,,,,,,,,,, -74500,1.923436,1.455223,,,,,,,,,,,,,, -74600,1.8563212,1.4413823,,,,,,,,,,,,,, -74700,2.0398965,1.5892466,,,,,,,,,,,,,, -74800,2.0043113,1.4822826,,,,,,,,,,,,,, -74900,1.919532,1.5599023,,,,,,,,,,,,,, -75000,1.9153323,1.4753278,,,,,,,,,,,,,, -75100,2.3873808,1.6084288,,,,,,,,,,,,,, -75200,2.0502596,1.4719386,,,,,,,,,,,,,, -75300,2.0519655,1.4700882,,,,,,,,,,,,,, -75400,2.017061,1.4475462,,,,,,,,,,,,,, -75500,1.9694452,1.46814,,,,,,,,,,,,,, -75600,2.3372262,1.4292648,,,,,,,,,,,,,, -75642,,,0.7388990521430969,0.9926463961601256,0.6643999814987183,1.3834563493728638,50000.0,0.5353000164031982,2.1104061603546143,10000.0,25555.467805862427,26618.26466965676,25555.467805862427,1058.160115480423,1.6960022449493408,0.0 -75700,2.104867,1.5157969,,,,,,,,,,,,,, -75800,2.0125332,1.5138319,,,,,,,,,,,,,, -75900,2.2077942,1.5523338,,,,,,,,,,,,,, -76000,2.0238583,1.5415564,,,,,,,,,,,,,, -76100,2.0596883,1.5331424,,,,,,,,,,,,,, -76200,2.1904829,1.5136522,,,,,,,,,,,,,, -76300,2.194654,1.6232408,,,,,,,,,,,,,, -76400,2.0871105,1.6117271,,,,,,,,,,,,,, -76500,1.9436724,1.3854948,,,,,,,,,,,,,, -76600,2.290595,1.4145598,,,,,,,,,,,,,, -76700,2.0538435,1.4931273,,,,,,,,,,,,,, -76800,2.2521267,1.4230852,,,,,,,,,,,,,, -76900,2.056574,1.4474747,,,,,,,,,,,,,, -77000,2.0132341,1.4475178,,,,,,,,,,,,,, -77100,1.9558128,1.4162755,,,,,,,,,,,,,, -77156,,,0.7443000674247742,0.968315601348877,0.6729399561882019,1.3414160013198853,50000.0,0.5426000356674194,2.0834763050079346,10000.0,26065.52658700943,27145.634598970413,26065.52658700943,1075.3721516132357,1.735579490661621,0.0 -77200,2.1481209,1.527552,,,,,,,,,,,,,, -77300,2.2306354,1.4415122,,,,,,,,,,,,,, -77400,2.376843,1.410822,,,,,,,,,,,,,, -77500,2.3283918,1.3721277,,,,,,,,,,,,,, -77600,2.1159635,1.3895277,,,,,,,,,,,,,, -77700,2.1867602,1.3373061,,,,,,,,,,,,,, -77800,2.1544192,1.4095945,,,,,,,,,,,,,, -77900,2.0544348,1.416661,,,,,,,,,,,,,, -78000,2.1059387,1.4448578,,,,,,,,,,,,,, -78100,2.1407697,1.408125,,,,,,,,,,,,,, -78200,1.9727588,1.3664699,,,,,,,,,,,,,, -78300,2.1869957,1.510399,,,,,,,,,,,,,, -78400,1.9751035,1.4457735,,,,,,,,,,,,,, -78500,1.9573622,1.448496,,,,,,,,,,,,,, -78600,2.025116,1.4242526,,,,,,,,,,,,,, -78668,,,0.73046875,1.0355678796768188,0.6589999794960022,1.397053837776184,50000.0,0.5336000323295593,2.1245603561401367,10000.0,26574.551361083984,27672.950965881348,26574.551361083984,1092.6724405288696,2.6665658950805664,0.0 -78700,2.135676,1.3672196,,,,,,,,,,,,,, -78800,2.2068748,1.5255573,,,,,,,,,,,,,, -78900,2.2345061,1.4152124,,,,,,,,,,,,,, -79000,2.2520397,1.500529,,,,,,,,,,,,,, -79100,2.3986568,1.4686286,,,,,,,,,,,,,, -79200,2.1356254,1.478882,,,,,,,,,,,,,, -79300,1.9561731,1.5074577,,,,,,,,,,,,,, -79400,2.2504938,1.4624574,,,,,,,,,,,,,, -79500,2.1058476,1.5314391,,,,,,,,,,,,,, -79600,2.0141547,1.4686986,,,,,,,,,,,,,, -79700,2.2857933,1.4395438,,,,,,,,,,,,,, -79800,2.1209981,1.4764497,,,,,,,,,,,,,, -79900,1.9642675,1.4706872,,,,,,,,,,,,,, -80000,1.9232919,1.2831436,,,,,,,,,,,,,, -80100,2.351009,1.4756017,,,,,,,,,,,,,, -80182,,,0.7581114172935486,0.9292783141136168,0.6686399579048157,1.3589168787002563,50000.0,0.5433000326156616,2.059293031692505,10000.0,27084.58042907715,28200.16864085197,27084.58042907715,1109.7548336982727,2.7132747173309326,0.0 -80200,2.0369046,1.4095653,,,,,,,,,,,,,, -80300,2.0573614,1.3771274,,,,,,,,,,,,,, -80400,1.9922385,1.341828,,,,,,,,,,,,,, -80500,2.341182,1.484437,,,,,,,,,,,,,, -80600,2.4134243,1.4437003,,,,,,,,,,,,,, -80700,2.0941923,1.4094311,,,,,,,,,,,,,, -80800,2.1897361,1.5213153,,,,,,,,,,,,,, -80900,2.0999618,1.4852941,,,,,,,,,,,,,, -81000,2.127419,1.3644588,,,,,,,,,,,,,, -81100,1.9684435,1.577363,,,,,,,,,,,,,, -81200,2.0146363,1.4039782,,,,,,,,,,,,,, -81300,1.9659154,1.3184724,,,,,,,,,,,,,, -81400,2.242287,1.3421844,,,,,,,,,,,,,, -81500,2.1706676,1.571527,,,,,,,,,,,,,, -81600,2.1581006,1.4483055,,,,,,,,,,,,,, -81696,,,0.76761794090271,0.8746867179870605,0.671459972858429,1.3250740766525269,50000.0,0.5402000546455383,2.0668623447418213,10000.0,27594.486018419266,28727.456157445908,27594.486018419266,1127.0345587730408,2.7547857761383057,0.0 -81700,2.049392,1.434154,,,,,,,,,,,,,, -81800,2.134156,1.4983265,,,,,,,,,,,,,, -81900,2.2804143,1.5361508,,,,,,,,,,,,,, -82000,2.023632,1.4156746,,,,,,,,,,,,,, -82100,2.025024,1.4217556,,,,,,,,,,,,,, -82200,2.2712865,1.4444443,,,,,,,,,,,,,, -82300,2.3309515,1.4667622,,,,,,,,,,,,,, -82400,2.3630452,1.3919684,,,,,,,,,,,,,, -82500,2.2104223,1.4279902,,,,,,,,,,,,,, -82600,2.0763407,1.4368417,,,,,,,,,,,,,, -82700,2.117654,1.4747657,,,,,,,,,,,,,, -82800,2.2832398,1.4296263,,,,,,,,,,,,,, -82900,2.2518106,1.4964722,,,,,,,,,,,,,, -83000,2.1774194,1.4006081,,,,,,,,,,,,,, -83100,2.1174536,1.3953248,,,,,,,,,,,,,, -83200,2.190843,1.4709802,,,,,,,,,,,,,, -83210,,,0.7473094463348389,0.9471119046211244,0.6668599843978882,1.3621351718902588,50000.0,0.5398000478744507,2.0859215259552,10000.0,28104.702069997787,29255.24874305725,28104.702069997787,1144.5072071552277,2.79805588722229,0.0 -83300,2.3128424,1.455548,,,,,,,,,,,,,, -83400,2.406908,1.4720082,,,,,,,,,,,,,, -83500,2.2077389,1.53479,,,,,,,,,,,,,, -83600,2.105739,1.476969,,,,,,,,,,,,,, -83700,2.314646,1.346971,,,,,,,,,,,,,, -83800,2.395298,1.512671,,,,,,,,,,,,,, -83900,2.2054193,1.4203435,,,,,,,,,,,,,, -84000,2.2528615,1.526524,,,,,,,,,,,,,, -84100,2.0610027,1.4052639,,,,,,,,,,,,,, -84200,2.252342,1.3859304,,,,,,,,,,,,,, -84300,2.1020277,1.3538427,,,,,,,,,,,,,, -84400,2.2685287,1.4538006,,,,,,,,,,,,,, -84500,2.0080044,1.2745451,,,,,,,,,,,,,, -84600,2.1864574,1.3840504,,,,,,,,,,,,,, -84700,2.1827009,1.4926143,,,,,,,,,,,,,, -84725,,,0.7489636540412903,0.9444343447685242,0.6720799803733826,1.3304868936538696,50000.0,0.5400000214576721,2.061533212661743,10000.0,28614.92431378365,29782.745623350143,28614.92431378365,1161.6754019260406,2.844470977783203,0.0 -84800,2.101828,1.5215038,,,,,,,,,,,,,, -84900,2.315202,1.4964023,,,,,,,,,,,,,, -85000,2.3239112,1.401328,,,,,,,,,,,,,, -85100,2.196707,1.5269538,,,,,,,,,,,,,, -85200,2.3498192,1.4108127,,,,,,,,,,,,,, -85300,2.4077744,1.3991829,,,,,,,,,,,,,, -85400,2.0062084,1.3014984,,,,,,,,,,,,,, -85500,2.245371,1.4855697,,,,,,,,,,,,,, -85600,2.2592945,1.49755,,,,,,,,,,,,,, -85700,2.1955678,1.3225688,,,,,,,,,,,,,, -85800,2.1656525,1.3906989,,,,,,,,,,,,,, -85900,2.1359398,1.4878894,,,,,,,,,,,,,, -86000,2.6563544,1.3803177,,,,,,,,,,,,,, -86100,2.2509246,1.3647176,,,,,,,,,,,,,, -86200,2.1561015,1.3724338,,,,,,,,,,,,,, -86240,,,0.7512555718421936,0.9321123957633972,0.6727399826049805,1.3387607336044312,50000.0,0.5558000206947327,2.029335021972656,10000.0,29125.15648293495,30310.33651709557,29125.15648293495,1178.9313924312592,2.8873541355133057,0.0 -86300,2.1685417,1.3699139,,,,,,,,,,,,,, -86400,2.0824726,1.4390646,,,,,,,,,,,,,, -86500,2.1339269,1.4525416,,,,,,,,,,,,,, -86600,2.1679502,1.4751256,,,,,,,,,,,,,, -86700,2.3339875,1.4457748,,,,,,,,,,,,,, -86800,2.1524584,1.4539424,,,,,,,,,,,,,, -86900,2.537323,1.4568737,,,,,,,,,,,,,, -87000,2.1725624,1.2884845,,,,,,,,,,,,,, -87100,2.2392845,1.3857025,,,,,,,,,,,,,, -87200,2.2436836,1.3782881,,,,,,,,,,,,,, -87300,2.2302768,1.4575341,,,,,,,,,,,,,, -87400,2.157749,1.4498727,,,,,,,,,,,,,, -87500,2.19416,1.4171063,,,,,,,,,,,,,, -87600,2.4837956,1.5132331,,,,,,,,,,,,,, -87700,2.0750937,1.4207854,,,,,,,,,,,,,, -87755,,,0.7402941584587097,0.9757063388824464,0.6700800061225891,1.3509361743927002,50000.0,0.5366000533103943,2.072619676589966,10000.0,29635.361619234085,30837.75000357628,29635.361619234085,1196.0384995937347,2.9288535118103027,0.0 -87800,2.1173923,1.4670005,,,,,,,,,,,,,, -87900,2.2570224,1.3767776,,,,,,,,,,,,,, -88000,2.2059662,1.5313158,,,,,,,,,,,,,, -88100,2.079566,1.3089924,,,,,,,,,,,,,, -88200,1.9984959,1.3999729,,,,,,,,,,,,,, -88300,2.2526178,1.4324391,,,,,,,,,,,,,, -88400,2.223975,1.4137863,,,,,,,,,,,,,, -88500,2.4318087,1.4042919,,,,,,,,,,,,,, -88600,2.152555,1.4012256,,,,,,,,,,,,,, -88700,2.1764228,1.2765418,,,,,,,,,,,,,, -88800,2.34663,1.3769975,,,,,,,,,,,,,, -88900,2.3328009,1.4314563,,,,,,,,,,,,,, -89000,2.6179886,1.4113526,,,,,,,,,,,,,, -89100,2.3088958,1.4502759,,,,,,,,,,,,,, -89200,2.5838933,1.4791424,,,,,,,,,,,,,, -89269,,,0.7931481003761292,0.7678533792495728,0.6738199591636658,1.330244064331055,50000.0,0.5472000241279602,2.040836334228516,10000.0,30145.517723798752,31365.29650616645,30145.517723798752,1213.3233938217163,2.9756851196289062,0.0 -89300,2.8990438,1.4463174,,,,,,,,,,,,,, -89400,2.271936,1.4126527,,,,,,,,,,,,,, -89500,2.443792,1.44557,,,,,,,,,,,,,, -89600,2.226111,1.3887353,,,,,,,,,,,,,, -89700,2.3352191,1.3301002,,,,,,,,,,,,,, -89800,2.3211234,1.5002167,,,,,,,,,,,,,, -89900,2.283571,1.4906423,,,,,,,,,,,,,, -90000,2.3112898,1.4588789,,,,,,,,,,,,,, -90100,2.200658,1.379343,,,,,,,,,,,,,, -90200,2.3652396,1.5037595,,,,,,,,,,,,,, -90300,2.5678408,1.3435022,,,,,,,,,,,,,, -90400,2.1984396,1.3895985,,,,,,,,,,,,,, -90500,2.5752718,1.4009905,,,,,,,,,,,,,, -90600,2.426498,1.4233762,,,,,,,,,,,,,, -90700,2.126038,1.3759512,,,,,,,,,,,,,, -90784,,,0.769929826259613,0.8533276915550232,0.6751399636268616,1.321486234664917,50000.0,0.5485000014305115,2.0378308296203613,10000.0,30655.582379579544,31892.87816643715,30655.582379579544,1230.7311687469482,3.023667097091675,0.0 -90800,2.675788,1.3911564,,,,,,,,,,,,,, -90900,2.3532598,1.4241124,,,,,,,,,,,,,, -91000,2.0975752,1.4051323,,,,,,,,,,,,,, -91100,2.2522306,1.5557449,,,,,,,,,,,,,, -91200,2.126824,1.3204682,,,,,,,,,,,,,, -91300,2.3525176,1.429421,,,,,,,,,,,,,, -91400,2.2081058,1.3895926,,,,,,,,,,,,,, -91500,2.3541825,1.4668412,,,,,,,,,,,,,, -91600,2.2887335,1.3573713,,,,,,,,,,,,,, -91700,2.3540895,1.380978,,,,,,,,,,,,,, -91800,2.973274,1.3606731,,,,,,,,,,,,,, -91900,2.3951218,1.3812044,,,,,,,,,,,,,, -92000,2.1753852,1.4300183,,,,,,,,,,,,,, -92100,2.2160645,1.335898,,,,,,,,,,,,,, -92200,2.2017395,1.3687451,,,,,,,,,,,,,, -92298,,,0.7577128410339355,0.9045006632804872,0.671779990196228,1.3457863330841064,50000.0,0.5489000082015991,2.0553572177886963,10000.0,31165.769117355347,32420.492695093155,31165.769117355347,1248.036145210266,3.087367534637451,0.0 -92300,2.2962222,1.322729,,,,,,,,,,,,,, -92400,2.323863,1.4041102,,,,,,,,,,,,,, -92500,2.2686155,1.3889123,,,,,,,,,,,,,, -92600,2.6406872,1.4334072,,,,,,,,,,,,,, -92700,2.3387358,1.4521497,,,,,,,,,,,,,, -92800,2.2066312,1.4501824,,,,,,,,,,,,,, -92900,2.3923273,1.3993797,,,,,,,,,,,,,, -93000,2.2399457,1.3781179,,,,,,,,,,,,,, -93100,2.4631248,1.369472,,,,,,,,,,,,,, -93200,2.14276,1.3370639,,,,,,,,,,,,,, -93300,2.276577,1.3153694,,,,,,,,,,,,,, -93400,2.3865988,1.427659,,,,,,,,,,,,,, -93500,2.1354797,1.3872126,,,,,,,,,,,,,, -93600,2.2921438,1.3334279,,,,,,,,,,,,,, -93700,2.3176098,1.3899083,,,,,,,,,,,,,, -93800,2.279005,1.4329603,,,,,,,,,,,,,, -93812,,,0.7662428021430969,0.8773167729377747,0.6835599541664124,1.2988988161087036,50000.0,0.5547000169754028,2.0494730472564697,10000.0,31675.75886678696,32947.921608924866,31675.75886678696,1265.370517730713,3.132223129272461,0.0 -93900,2.4270203,1.3289089,,,,,,,,,,,,,, -94000,2.4094062,1.4261994,,,,,,,,,,,,,, -94100,2.7297409,1.4786807,,,,,,,,,,,,,, -94200,2.7051983,1.3534083,,,,,,,,,,,,,, -94300,2.4656084,1.4060965,,,,,,,,,,,,,, -94400,2.4145083,1.3203834,,,,,,,,,,,,,, -94500,2.2521272,1.3633327,,,,,,,,,,,,,, -94600,2.636621,1.4207429,,,,,,,,,,,,,, -94700,2.5119708,1.4133086,,,,,,,,,,,,,, -94800,2.218242,1.4429193,,,,,,,,,,,,,, -94900,2.3718016,1.410555,,,,,,,,,,,,,, -95000,2.5902598,1.4598465,,,,,,,,,,,,,, -95100,2.401587,1.4799699,,,,,,,,,,,,,, -95200,2.5639865,1.4702288,,,,,,,,,,,,,, -95300,2.4262304,1.4235477,,,,,,,,,,,,,, -95326,,,0.7669802308082581,0.8797918558120728,0.6861199736595154,1.2844245433807373,50000.0,0.5611000061035156,1.9772965908050537,10000.0,32185.968641519547,33475.55606675148,32185.968641519547,1282.691505908966,3.175578117370605,0.0 -95400,2.777555,1.451374,,,,,,,,,,,,,, -95500,2.6666458,1.423334,,,,,,,,,,,,,, -95600,2.335596,1.3786466,,,,,,,,,,,,,, -95700,2.4460766,1.4067843,,,,,,,,,,,,,, -95800,2.541167,1.4708983,,,,,,,,,,,,,, -95900,2.3896065,1.3855103,,,,,,,,,,,,,, -96000,2.2672055,1.3660548,,,,,,,,,,,,,, -96100,2.2409518,1.2267138,,,,,,,,,,,,,, -96200,2.352219,1.3873361,,,,,,,,,,,,,, -96300,2.0759563,1.200659,,,,,,,,,,,,,, -96400,2.4561877,1.4399323,,,,,,,,,,,,,, -96500,2.0966883,1.3351992,,,,,,,,,,,,,, -96600,2.2492533,1.370289,,,,,,,,,,,,,, -96700,2.3317256,1.3175418,,,,,,,,,,,,,, -96800,2.3626904,1.3835472,,,,,,,,,,,,,, -96840,,,0.7622169852256775,0.8912380337715149,0.6829599738121033,1.2990450859069824,50000.0,0.558899998664856,1.9927387237548828,10000.0,32696.044674634933,34003.18845510483,32696.044674634933,1300.1406605243685,3.221791982650757,0.0 -96900,2.3465972,1.3917747,,,,,,,,,,,,,, -97000,2.2254963,1.364439,,,,,,,,,,,,,, -97100,2.384023,1.3892163,,,,,,,,,,,,,, -97200,2.621212,1.3671365,,,,,,,,,,,,,, -97300,2.5437603,1.4315642,,,,,,,,,,,,,, -97400,2.3373082,1.3692665,,,,,,,,,,,,,, -97500,2.2907946,1.3027837,,,,,,,,,,,,,, -97600,2.5013387,1.3812901,,,,,,,,,,,,,, -97700,2.370924,1.3531812,,,,,,,,,,,,,, -97800,2.828957,1.4604884,,,,,,,,,,,,,, -97900,2.307449,1.3780653,,,,,,,,,,,,,, -98000,2.232033,1.3303667,,,,,,,,,,,,,, -98100,2.5633287,1.5050607,,,,,,,,,,,,,, -98200,2.483733,1.4667518,,,,,,,,,,,,,, -98300,2.4224505,1.2645938,,,,,,,,,,,,,, -98355,,,0.8137754797935486,0.6855299472808838,0.6881600022315979,1.265963435173035,50000.0,0.5583000183105469,1.9982116222381592,10000.0,33206.26758170128,34530.89180493355,33206.26758170128,1317.5182647705078,3.265420913696289,0.0 -98400,2.6979196,1.3588699,,,,,,,,,,,,,, -98500,2.5075138,1.339499,,,,,,,,,,,,,, -98600,2.3295608,1.2747048,,,,,,,,,,,,,, -98700,2.3478942,1.3722588,,,,,,,,,,,,,, -98800,2.5052981,1.3708618,,,,,,,,,,,,,, -98900,2.7061648,1.316895,,,,,,,,,,,,,, -99000,2.4548185,1.3538585,,,,,,,,,,,,,, -99100,2.5663264,1.4544865,,,,,,,,,,,,,, -99200,2.3401175,1.3504978,,,,,,,,,,,,,, -99300,2.4675968,1.3612972,,,,,,,,,,,,,, -99400,2.3120096,1.3029433,,,,,,,,,,,,,, -99500,2.4639432,1.3384243,,,,,,,,,,,,,, -99600,2.6280522,1.3335739,,,,,,,,,,,,,, -99700,2.6577394,1.3924874,,,,,,,,,,,,,, -99800,2.543382,1.3385657,,,,,,,,,,,,,, -99869,,,0.7844786047935486,0.7990682721138,0.684939980506897,1.276252269744873,50000.0,0.5529000163078308,2.0021021366119385,10000.0,33716.18281817436,35057.95794534683,33716.18281817436,1334.5619978904724,3.310832023620605,0.0 -99900,2.4881544,1.433815,,,,,,,,,,,,,, -100000,2.3823285,1.2871813,,,,,,,,,,,,,, -100100,2.5521562,1.2952279,,,,,,,,,,,,,, -100200,2.4268732,1.3489618,,,,,,,,,,,,,, -100300,2.3960466,1.3506638,,,,,,,,,,,,,, -100400,2.463308,1.363009,,,,,,,,,,,,,, -100500,2.780321,1.3360168,,,,,,,,,,,,,, -100600,2.7475471,1.2825239,,,,,,,,,,,,,, -100700,2.8551605,1.3708878,,,,,,,,,,,,,, -100800,2.4895322,1.462895,,,,,,,,,,,,,, -100900,2.5966384,1.3781109,,,,,,,,,,,,,, -101000,2.5419753,1.3743519,,,,,,,,,,,,,, -101100,2.3112257,1.2802453,,,,,,,,,,,,,, -101200,2.3792589,1.3530731,,,,,,,,,,,,,, -101300,2.5487525,1.3729625,,,,,,,,,,,,,, -101383,,,0.7774832248687744,0.8314793705940247,0.6846199631690979,1.2771825790405271,50000.0,0.5521000027656555,2.012157440185547,10000.0,34226.19294190407,35585.33166742325,34226.19294190407,1351.8213539123535,3.355522871017456,0.0 -101400,2.496,1.304698,,,,,,,,,,,,,, -101500,2.5899704,1.4204648,,,,,,,,,,,,,, -101600,2.3066015,1.3074151,,,,,,,,,,,,,, -101700,2.4498663,1.3354394,,,,,,,,,,,,,, -101800,2.2931197,1.328164,,,,,,,,,,,,,, -101900,2.5829217,1.34705,,,,,,,,,,,,,, -102000,2.4089894,1.3478514,,,,,,,,,,,,,, -102100,2.5305471,1.3635192,,,,,,,,,,,,,, -102200,2.6516273,1.4558035,,,,,,,,,,,,,, -102300,2.52915,1.3831699,,,,,,,,,,,,,, -102400,2.4062905,1.3080785,,,,,,,,,,,,,, -102500,2.4072537,1.3345813,,,,,,,,,,,,,, -102600,2.6753585,1.4099001,,,,,,,,,,,,,, -102700,2.4178562,1.2901926,,,,,,,,,,,,,, -102800,2.5254724,1.3018035,,,,,,,,,,,,,, -102897,,,0.7782804369926453,0.8227626085281372,0.6896799802780151,1.256480097770691,50000.0,0.5592000484466553,1.97683584690094,10000.0,34736.09021759033,36112.63495731354,34736.09021759033,1369.1167182922363,3.406175136566162,0.0 -102900,2.4913218,1.4034784,,,,,,,,,,,,,, -103000,2.280598,1.2959447,,,,,,,,,,,,,, -103100,2.492571,1.2744042,,,,,,,,,,,,,, -103200,2.4140754,1.2313697,,,,,,,,,,,,,, -103300,2.4220176,1.3224268,,,,,,,,,,,,,, -103400,2.4650047,1.2777088,,,,,,,,,,,,,, -103500,2.7481232,1.4082618,,,,,,,,,,,,,, -103600,2.3318093,1.3136754,,,,,,,,,,,,,, -103700,2.4305658,1.299431,,,,,,,,,,,,,, -103800,2.3425043,1.2398582,,,,,,,,,,,,,, -103900,2.4879317,1.4212005,,,,,,,,,,,,,, -104000,2.4875104,1.2954104,,,,,,,,,,,,,, -104100,2.4105604,1.4376416,,,,,,,,,,,,,, -104200,2.5856938,1.2648668,,,,,,,,,,,,,, -104300,2.5156028,1.4107897,,,,,,,,,,,,,, -104400,2.4055724,1.3943365,,,,,,,,,,,,,, -104412,,,0.7736367583274841,0.8423909544944763,0.6892600059509277,1.259985089302063,50000.0,0.5574000477790833,2.013371229171753,10000.0,35246.263035058975,36640.50486254692,35246.263035058975,1386.708063840866,3.452512502670288,0.0 -104500,2.467,1.4372835,,,,,,,,,,,,,, -104600,2.4909098,1.3671259,,,,,,,,,,,,,, -104700,2.545515,1.3142703,,,,,,,,,,,,,, -104800,2.3051941,1.1902645,,,,,,,,,,,,,, -104900,2.7298493,1.3233659,,,,,,,,,,,,,, -105000,2.5555594,1.379729,,,,,,,,,,,,,, -105100,2.7066307,1.3842143,,,,,,,,,,,,,, -105200,2.520391,1.3429139,,,,,,,,,,,,,, -105300,2.6509976,1.2511129,,,,,,,,,,,,,, -105400,2.7888613,1.3522816,,,,,,,,,,,,,, -105500,2.5359123,1.3212581,,,,,,,,,,,,,, -105600,2.5828238,1.2461984,,,,,,,,,,,,,, -105700,2.5871177,1.2509011,,,,,,,,,,,,,, -105800,2.5460923,1.2940838,,,,,,,,,,,,,, -105900,2.8007615,1.2629926,,,,,,,,,,,,,, -105926,,,0.7770846486091614,0.8247220516204834,0.6956799626350403,1.2360202074050903,50000.0,0.5586000084877014,1.974216103553772,10000.0,35756.16144490242,37167.974967479706,35756.16144490242,1404.1752064228058,3.4970178604125977,0.0 -106000,2.470044,1.1909308,,,,,,,,,,,,,, -106100,2.7783005,1.2719915,,,,,,,,,,,,,, -106200,2.6730142,1.3813035,,,,,,,,,,,,,, -106300,2.6668074,1.2796284,,,,,,,,,,,,,, -106400,2.4087932,1.2681797,,,,,,,,,,,,,, -106500,2.4712532,1.2380804,,,,,,,,,,,,,, -106600,2.627408,1.3732681,,,,,,,,,,,,,, -106700,2.8257174,1.3167307,,,,,,,,,,,,,, -106800,2.6840968,1.2757843,,,,,,,,,,,,,, -106900,2.5657125,1.3317624,,,,,,,,,,,,,, -107000,2.4814744,1.1839854,,,,,,,,,,,,,, -107100,2.6367226,1.2661533,,,,,,,,,,,,,, -107200,2.4334147,1.2335924,,,,,,,,,,,,,, -107300,2.6544545,1.3109022,,,,,,,,,,,,,, -107400,2.5481873,1.3180357,,,,,,,,,,,,,, -107440,,,0.8256935477256775,0.6410678625106812,0.6988799571990967,1.2236398458480835,50000.0,0.5689000487327576,1.9630481004714968,10000.0,36266.39616584778,37695.79827213287,36266.39616584778,1421.6605622768402,3.54105544090271,0.0 -107500,2.644929,1.218121,,,,,,,,,,,,,, -107600,2.8670335,1.3637527,,,,,,,,,,,,,, -107700,2.6813636,1.2814353,,,,,,,,,,,,,, -107800,2.6434717,1.3470262,,,,,,,,,,,,,, -107900,2.5223017,1.2177322,,,,,,,,,,,,,, -108000,2.6008308,1.2536734,,,,,,,,,,,,,, -108100,2.4481122,1.3030937,,,,,,,,,,,,,, -108200,2.6290104,1.3473328,,,,,,,,,,,,,, -108300,2.4974062,1.2277354,,,,,,,,,,,,,, -108400,2.539953,1.2375078,,,,,,,,,,,,,, -108500,2.7117474,1.336659,,,,,,,,,,,,,, -108600,2.7046087,1.3875229,,,,,,,,,,,,,, -108700,2.8986478,1.2716029,,,,,,,,,,,,,, -108800,2.6044836,1.2882079,,,,,,,,,,,,,, -108900,2.7417717,1.2801111,,,,,,,,,,,,,, -108954,,,0.8023955225944519,0.722143292427063,0.694159984588623,1.2538851499557495,50000.0,0.5565000176429749,2.0258400440216064,10000.0,36776.41024065018,38223.09888339043,36776.41024065018,1438.8392596244812,3.589161157608032,0.0 -109000,2.592159,1.2563552,,,,,,,,,,,,,, -109100,2.933463,1.3363209,,,,,,,,,,,,,, -109200,2.633628,1.3509011,,,,,,,,,,,,,, -109300,2.621658,1.2562711,,,,,,,,,,,,,, -109400,3.009023,1.3368925,,,,,,,,,,,,,, -109500,2.5371797,1.2258514,,,,,,,,,,,,,, -109600,2.5116892,1.2685907,,,,,,,,,,,,,, -109700,2.5360968,1.260415,,,,,,,,,,,,,, -109800,2.640709,1.2872651,,,,,,,,,,,,,, -109900,2.7178926,1.1629571,,,,,,,,,,,,,, -110000,2.6605368,1.2894858,,,,,,,,,,,,,, -110100,2.880323,1.2380885,,,,,,,,,,,,,, -110200,2.7448902,1.1905088,,,,,,,,,,,,,, -110300,2.484837,1.2514932,,,,,,,,,,,,,, -110400,2.5784469,1.2971213,,,,,,,,,,,,,, -110468,,,0.8027144074440002,0.7184869050979614,0.7019400000572205,1.2076020240783691,50000.0,0.5703000426292419,1.931430697441101,10000.0,37286.4473862648,38750.62214708328,37286.4473862648,1456.2212126255035,3.6329078674316406,0.0 -110500,2.5638127,1.0990622,,,,,,,,,,,,,, -110600,2.6092772,1.2382301,,,,,,,,,,,,,, -110700,2.9663548,1.3254158,,,,,,,,,,,,,, -110800,2.7922611,1.3055862,,,,,,,,,,,,,, -110900,2.588947,1.2492123,,,,,,,,,,,,,, -111000,2.6358395,1.2542912,,,,,,,,,,,,,, -111100,2.6612225,1.2898198,,,,,,,,,,,,,, -111200,2.7736435,1.3080521,,,,,,,,,,,,,, -111300,2.47783,1.270727,,,,,,,,,,,,,, -111400,2.704867,1.1992697,,,,,,,,,,,,,, -111500,2.8832605,1.3000762,,,,,,,,,,,,,, -111600,2.9316516,1.2695274,,,,,,,,,,,,,, -111700,2.6320615,1.2454388,,,,,,,,,,,,,, -111800,2.7234542,1.33762,,,,,,,,,,,,,, -111900,2.91536,1.2770202,,,,,,,,,,,,,, -111983,,,0.7995057106018066,0.7365074753761292,0.7022199630737305,1.210964918136597,50000.0,0.5706000328063965,1.928413152694702,10000.0,37796.50906038284,39279.16947221756,37796.50906038284,1474.6009676456451,3.678417444229126,0.0 -112000,2.5682538,1.3276615,,,,,,,,,,,,,, -112100,2.505719,1.2285218,,,,,,,,,,,,,, -112200,2.784109,1.3109056,,,,,,,,,,,,,, -112300,2.6159577,1.2412972,,,,,,,,,,,,,, -112400,2.5840893,1.2058144,,,,,,,,,,,,,, -112500,2.751802,1.305421,,,,,,,,,,,,,, -112600,2.7715428,1.3858122,,,,,,,,,,,,,, -112700,2.8267086,1.3606842,,,,,,,,,,,,,, -112800,2.5080142,1.2083986,,,,,,,,,,,,,, -112900,2.9173572,1.239831,,,,,,,,,,,,,, -113000,2.7129276,1.2215868,,,,,,,,,,,,,, -113100,2.9890156,1.2991427,,,,,,,,,,,,,, -113200,2.96338,1.382329,,,,,,,,,,,,,, -113300,2.8884523,1.2092634,,,,,,,,,,,,,, -113400,2.5680244,1.2050138,,,,,,,,,,,,,, -113497,,,0.7829639315605164,0.7980664968490601,0.6932199597358704,1.2627981901168823,50000.0,0.5621000528335571,2.024306297302246,10000.0,38306.47283291817,39806.460503578186,38306.47283291817,1491.822915315628,3.724119186401367,0.0 -113500,2.7324476,1.2307346,,,,,,,,,,,,,, -113600,2.8456826,1.3424379,,,,,,,,,,,,,, -113700,2.8376749,1.326754,,,,,,,,,,,,,, -113800,2.7869384,1.208296,,,,,,,,,,,,,, -113900,2.7761412,1.2530693,,,,,,,,,,,,,, -114000,2.688757,1.216286,,,,,,,,,,,,,, -114100,2.8953538,1.2721362,,,,,,,,,,,,,, -114200,2.87883,1.2612636,,,,,,,,,,,,,, -114300,2.9154508,1.321843,,,,,,,,,,,,,, -114400,3.1065042,1.2196473,,,,,,,,,,,,,, -114500,2.6226535,1.1440059,,,,,,,,,,,,,, -114600,2.5592835,1.2527554,,,,,,,,,,,,,, -114700,2.8646772,1.2944121,,,,,,,,,,,,,, -114800,2.78412,1.14323,,,,,,,,,,,,,, -114900,2.984613,1.1514452,,,,,,,,,,,,,, -115000,2.762882,1.2222321,,,,,,,,,,,,,, -115010,,,0.7958585619926453,0.7441632151603699,0.7008000016212463,1.2077741622924805,50000.0,0.5746999979019165,1.9470356702804563,10000.0,38816.370770692825,40333.80831313133,38816.370770692825,1509.1600093841553,3.776813507080078,0.0 -115100,2.9035423,1.1949773,,,,,,,,,,,,,, -115200,3.0005343,1.2758842,,,,,,,,,,,,,, -115300,2.6454167,1.1804079,,,,,,,,,,,,,, -115400,3.1057646,1.268641,,,,,,,,,,,,,, -115500,2.6268466,1.1366001,,,,,,,,,,,,,, -115600,2.559392,1.1630881,,,,,,,,,,,,,, -115700,2.859623,1.2619144,,,,,,,,,,,,,, -115800,2.5348163,1.1549432,,,,,,,,,,,,,, -115900,2.7132323,1.165163,,,,,,,,,,,,,, -116000,2.882586,1.1953708,,,,,,,,,,,,,, -116100,2.702732,1.2684164,,,,,,,,,,,,,, -116200,2.6993308,1.1958936,,,,,,,,,,,,,, -116300,2.7808669,1.2446872,,,,,,,,,,,,,, -116400,2.8436942,1.2437652,,,,,,,,,,,,,, -116500,2.9034069,1.2087352,,,,,,,,,,,,,, -116524,,,0.8346619606018066,0.6010602116584778,0.7102000117301941,1.169505596160889,50000.0,0.5791000127792358,1.8854345083236688,10000.0,39326.3773624897,40860.94369840622,39326.3773624897,1526.1814465522766,3.824695587158203,0.0 -116600,2.757213,1.3096447,,,,,,,,,,,,,, -116700,3.3063269,1.1925983,,,,,,,,,,,,,, -116800,2.9559216,1.3079782,,,,,,,,,,,,,, -116900,3.4281278,1.2746077,,,,,,,,,,,,,, -117000,2.9866104,1.2830957,,,,,,,,,,,,,, -117100,2.8085861,1.1470577,,,,,,,,,,,,,, -117200,3.0966117,1.3028009,,,,,,,,,,,,,, -117300,2.8195596,1.1361561,,,,,,,,,,,,,, -117400,3.2908452,1.2831867,,,,,,,,,,,,,, -117500,2.9008286,1.2660022,,,,,,,,,,,,,, -117600,3.0002778,1.1738667,,,,,,,,,,,,,, -117700,2.8757892,1.2073379,,,,,,,,,,,,,, -117800,2.724242,1.1932809,,,,,,,,,,,,,, -117900,2.972437,1.3020655,,,,,,,,,,,,,, -118000,2.6217391,1.2057443,,,,,,,,,,,,,, -118038,,,0.8148317933082581,0.6677629947662354,0.7066999673843384,1.2082501649856567,50000.0,0.575700044631958,1.944663524627685,10000.0,39836.38484477997,41388.14964914322,39836.38484477997,1543.270350933075,3.8735198974609375,0.0 -118100,2.9425392,1.2101531,,,,,,,,,,,,,, -118200,2.7964857,1.1751286,,,,,,,,,,,,,, -118300,3.0848598,1.2138982,,,,,,,,,,,,,, -118400,2.8265362,1.175022,,,,,,,,,,,,,, -118500,3.205746,1.3498195,,,,,,,,,,,,,, -118600,2.9728703,1.267663,,,,,,,,,,,,,, -118700,2.5634363,1.1326724,,,,,,,,,,,,,, -118800,2.6156986,1.1539807,,,,,,,,,,,,,, -118900,2.9556842,1.261128,,,,,,,,,,,,,, -119000,2.788519,1.1846914,,,,,,,,,,,,,, -119100,3.2823434,1.242909,,,,,,,,,,,,,, -119200,2.7870083,1.238959,,,,,,,,,,,,,, -119300,3.2250433,1.1351432,,,,,,,,,,,,,, -119400,3.2970178,1.2424161,,,,,,,,,,,,,, -119500,2.9830296,1.1697077,,,,,,,,,,,,,, -119552,,,0.8146523833274841,0.6767027378082275,0.705020010471344,1.2016663551330566,50000.0,0.5781000256538391,1.9262524843215945,10000.0,40346.46527194977,41915.59468245506,40346.46527194977,1560.5260910987854,3.9229581356048575,0.0 -119600,2.9185147,1.200438,,,,,,,,,,,,,, -119700,3.1506693,1.214604,,,,,,,,,,,,,, -119800,2.7515883,1.215592,,,,,,,,,,,,,, -119900,2.8923097,1.2204549,,,,,,,,,,,,,, -120000,2.8527794,1.1507622,,,,,,,,,,,,,, -120100,2.8372002,1.1554649,,,,,,,,,,,,,, -120200,2.773191,1.0892277,,,,,,,,,,,,,, -120300,3.105024,1.3265232,,,,,,,,,,,,,, -120400,3.1372445,1.1587839,,,,,,,,,,,,,, -120500,2.9019701,1.1276615,,,,,,,,,,,,,, -120600,3.0598576,1.1265393,,,,,,,,,,,,,, -120700,2.8525515,1.2183182,,,,,,,,,,,,,, -120800,3.146739,1.1689705,,,,,,,,,,,,,, -120900,2.8984141,1.1967634,,,,,,,,,,,,,, -121000,2.8810587,1.1501778,,,,,,,,,,,,,, -121066,,,0.8086734414100647,0.6844852566719055,0.7085199952125549,1.1973555088043213,50000.0,0.5789000391960144,1.921277642250061,10000.0,40856.44110560417,42443.076384067535,40856.44110560417,1577.9214849472046,3.97292423248291,0.0 -121100,2.9809287,1.1151515,,,,,,,,,,,,,, -121200,2.930118,1.1316057,,,,,,,,,,,,,, -121300,3.002337,1.1149883,,,,,,,,,,,,,, -121400,2.9980597,1.1685767,,,,,,,,,,,,,, -121500,2.9651647,1.221574,,,,,,,,,,,,,, -121600,2.8629706,1.1412617,,,,,,,,,,,,,, -121700,2.889013,1.218888,,,,,,,,,,,,,, -121800,3.1333358,1.2134211,,,,,,,,,,,,,, -121900,2.8488114,1.1995654,,,,,,,,,,,,,, -122000,2.8215547,1.1406615,,,,,,,,,,,,,, -122100,3.00134,1.157407,,,,,,,,,,,,,, -122200,3.0989184,1.1434505,,,,,,,,,,,,,, -122300,3.1189435,1.114991,,,,,,,,,,,,,, -122400,3.2014325,1.1703928,,,,,,,,,,,,,, -122500,3.0239832,1.173683,,,,,,,,,,,,,, -122580,,,0.7980110049247742,0.7274233102798462,0.6998400092124939,1.2325913906097412,50000.0,0.572100043296814,1.985614538192749,10000.0,41366.51425933838,42970.44410133362,41366.51425933838,1595.1068725585938,4.020076274871826,0.0 -122600,3.4899693,1.2688701,,,,,,,,,,,,,, -122700,3.4134755,1.2720667,,,,,,,,,,,,,, -122800,2.9917493,1.2145144,,,,,,,,,,,,,, -122900,2.8963203,1.0361099,,,,,,,,,,,,,, -123000,3.1123297,1.1589437,,,,,,,,,,,,,, -123100,3.392168,1.1770575,,,,,,,,,,,,,, -123200,3.11071,1.1385014,,,,,,,,,,,,,, -123300,2.8416715,1.103271,,,,,,,,,,,,,, -123400,2.8077784,1.1689211,,,,,,,,,,,,,, -123500,2.976352,1.1101806,,,,,,,,,,,,,, -123600,3.252992,1.1952158,,,,,,,,,,,,,, -123700,3.134993,1.1214008,,,,,,,,,,,,,, -123800,3.0114903,1.144372,,,,,,,,,,,,,, -123900,3.317553,1.1879508,,,,,,,,,,,,,, -124000,2.9028802,1.2051625,,,,,,,,,,,,,, -124095,,,0.8214086294174194,0.6441144943237305,0.7152599692344666,1.1662580966949463,50000.0,0.586400032043457,1.8914276361465447,10000.0,41876.44450306893,43497.77945423126,41876.44450306893,1612.4037322998047,4.06778359413147,0.0 -124100,2.8726826,1.113298,,,,,,,,,,,,,, -124200,3.0872817,1.217813,,,,,,,,,,,,,, -124300,3.2625697,1.2076823,,,,,,,,,,,,,, -124400,3.0993507,1.216898,,,,,,,,,,,,,, -124500,3.0728388,1.106192,,,,,,,,,,,,,, -124600,3.0292728,1.1041689,,,,,,,,,,,,,, -124700,3.1790502,1.1861188,,,,,,,,,,,,,, -124800,2.8602724,1.1649047,,,,,,,,,,,,,, -124900,3.1812801,1.2273529,,,,,,,,,,,,,, -125000,3.188684,1.0669913,,,,,,,,,,,,,, -125100,3.3074808,1.1190087,,,,,,,,,,,,,, -125200,3.477601,1.2106478,,,,,,,,,,,,,, -125300,3.2071753,1.1266353,,,,,,,,,,,,,, -125400,3.086962,1.1048307,,,,,,,,,,,,,, -125500,2.8175569,1.130296,,,,,,,,,,,,,, -125600,3.1747985,1.2644775,,,,,,,,,,,,,, -125608,,,0.8460817933082581,0.5478847026824951,0.7125200033187866,1.174476981163025,50000.0,0.5817000269889832,1.921865940093994,10000.0,42386.406173706055,44025.30989646912,42386.406173706055,1629.8644435405731,4.1161048412323,0.0 -125700,2.9521756,1.1186593,,,,,,,,,,,,,, -125800,3.1039805,1.2399101,,,,,,,,,,,,,, -125900,3.4888234,1.1421626,,,,,,,,,,,,,, -126000,3.06913,1.1896273,,,,,,,,,,,,,, -126100,3.143818,1.1272866,,,,,,,,,,,,,, -126200,3.0756779,1.0855814,,,,,,,,,,,,,, -126300,3.1681325,1.108827,,,,,,,,,,,,,, -126400,3.2720213,1.144151,,,,,,,,,,,,,, -126500,3.2868485,1.1661172,,,,,,,,,,,,,, -126600,3.1431901,1.1238573,,,,,,,,,,,,,, -126700,3.2959445,1.0705718,,,,,,,,,,,,,, -126800,3.3759313,1.2090651,,,,,,,,,,,,,, -126900,3.0362983,1.11762,,,,,,,,,,,,,, -127000,3.139363,1.1560605,,,,,,,,,,,,,, -127100,3.1032639,1.1837327,,,,,,,,,,,,,, -127123,,,0.8364357352256775,0.592322587966919,0.7131800055503845,1.1688021421432495,50000.0,0.5838000178337097,1.920723557472229,10000.0,42896.569796323776,44553.16710352898,42896.569796323776,1647.4425375461578,4.171696662902832,0.0 -127200,3.3060882,1.1563137,,,,,,,,,,,,,, -127300,2.9570777,1.0806317,,,,,,,,,,,,,, -127400,2.7896085,1.0312941,,,,,,,,,,,,,, -127500,3.2798705,1.185676,,,,,,,,,,,,,, -127600,3.1546962,1.1765871,,,,,,,,,,,,,, -127700,3.0501904,1.0518837,,,,,,,,,,,,,, -127800,3.252041,1.1335349,,,,,,,,,,,,,, -127900,3.104937,1.1091473,,,,,,,,,,,,,, -128000,3.0582712,1.115388,,,,,,,,,,,,,, -128100,3.3631487,1.1617122,,,,,,,,,,,,,, -128200,3.1361759,1.1402023,,,,,,,,,,,,,, -128300,3.1349463,1.0336586,,,,,,,,,,,,,, -128400,2.9464786,1.0269146,,,,,,,,,,,,,, -128500,3.1348517,1.1832883,,,,,,,,,,,,,, -128600,3.2702267,1.0734013,,,,,,,,,,,,,, -128636,,,0.8303770422935486,0.6017466187477112,0.7099599838256836,1.1903725862503052,50000.0,0.5868000388145447,1.8865386247634888,10000.0,43406.31649875641,45080.90883421898,43406.31649875641,1664.9784235954285,4.570464134216309,0.0 -128700,2.9628227,1.0633341,,,,,,,,,,,,,, -128800,3.0053113,1.1248769,,,,,,,,,,,,,, -128900,3.5242608,1.1498806,,,,,,,,,,,,,, -129000,3.0605075,1.1228083,,,,,,,,,,,,,, -129100,3.083902,1.0516024,,,,,,,,,,,,,, -129200,3.0655022,1.0087202,,,,,,,,,,,,,, -129300,3.1860914,1.1945889,,,,,,,,,,,,,, -129400,3.528592,1.1942527,,,,,,,,,,,,,, -129500,3.3256192,1.2142706,,,,,,,,,,,,,, -129600,3.407537,1.1260608,,,,,,,,,,,,,, -129700,3.0750484,1.061551,,,,,,,,,,,,,, -129800,3.2538774,1.0701957,,,,,,,,,,,,,, -129900,3.567215,1.0267485,,,,,,,,,,,,,, -130000,3.6824746,1.1091805,,,,,,,,,,,,,, -130100,3.7434802,1.1709845,,,,,,,,,,,,,, -130150,,,0.8363161683082581,0.5795503854751587,0.7214199900627136,1.1396958827972412,50000.0,0.5943000316619873,1.8531285524368288,10000.0,43916.40103673935,45608.30464339256,43916.40103673935,1682.1797287464142,4.620298624038696,0.0 -130200,3.0300379,1.0161475,,,,,,,,,,,,,, -130300,3.361105,1.123551,,,,,,,,,,,,,, -130400,3.5213609,1.1631261,,,,,,,,,,,,,, -130500,3.200855,1.0708889,,,,,,,,,,,,,, -130600,3.273159,1.0626738,,,,,,,,,,,,,, -130700,3.4203684,1.1350673,,,,,,,,,,,,,, -130800,3.1981764,1.0858439,,,,,,,,,,,,,, -130900,3.487266,1.0468302,,,,,,,,,,,,,, -131000,3.6586452,1.0827817,,,,,,,,,,,,,, -131100,3.3208036,1.128709,,,,,,,,,,,,,, -131200,3.1969624,1.0468273,,,,,,,,,,,,,, -131300,3.0794342,1.0900571,,,,,,,,,,,,,, -131400,3.0595574,1.0398815,,,,,,,,,,,,,, -131500,3.4731278,1.1885133,,,,,,,,,,,,,, -131600,3.4111457,1.1788005,,,,,,,,,,,,,, -131663,,,0.84086012840271,0.5709583163261414,0.7206400036811829,1.1557066440582275,50000.0,0.5958000421524048,1.8866512775421145,10000.0,44426.34621119499,46135.45411705971,44426.34621119499,1699.2731821537018,4.671350479125977,0.0 -131700,3.2685325,1.0498973,,,,,,,,,,,,,, -131800,3.293396,1.0596733,,,,,,,,,,,,,, -131900,3.4529362,1.1243291,,,,,,,,,,,,,, -132000,2.9912622,1.0015917,,,,,,,,,,,,,, -132100,3.2509947,1.0726159,,,,,,,,,,,,,, -132200,3.3053684,1.0737591,,,,,,,,,,,,,, -132300,3.2675183,1.066617,,,,,,,,,,,,,, -132400,3.2631576,1.0482117,,,,,,,,,,,,,, -132500,3.3583655,1.1498253,,,,,,,,,,,,,, -132600,3.3030443,1.0519722,,,,,,,,,,,,,, -132700,3.3290765,1.0247724,,,,,,,,,,,,,, -132800,3.4487338,1.0410888,,,,,,,,,,,,,, -132900,3.2863305,0.9947306,,,,,,,,,,,,,, -133000,3.2373402,1.0072786,,,,,,,,,,,,,, -133100,3.4381936,1.089608,,,,,,,,,,,,,, -133177,,,0.8496691584587097,0.540483295917511,0.7239800095558167,1.138985514640808,50000.0,0.5940999984741211,1.8570551872253416,10000.0,44936.3878133297,46663.10573005676,44936.3878133297,1716.7603754997251,4.733880996704102,0.0 -133200,3.107808,1.00643,,,,,,,,,,,,,, -133300,3.3685327,1.0923043,,,,,,,,,,,,,, -133400,3.7515366,1.0013491,,,,,,,,,,,,,, -133500,3.6610346,1.023969,,,,,,,,,,,,,, -133600,3.4605966,1.0721644,,,,,,,,,,,,,, -133700,3.334034,1.0446283,,,,,,,,,,,,,, -133800,3.1881018,1.0394413,,,,,,,,,,,,,, -133900,3.2700307,1.0379674,,,,,,,,,,,,,, -134000,3.4464304,1.0311706,,,,,,,,,,,,,, -134100,3.4550521,1.0890113,,,,,,,,,,,,,, -134200,3.2351441,0.9524033,,,,,,,,,,,,,, -134300,3.589895,1.1173398,,,,,,,,,,,,,, -134400,3.6912198,1.2012662,,,,,,,,,,,,,, -134500,3.215547,1.0668683,,,,,,,,,,,,,, -134600,3.6240015,1.1436224,,,,,,,,,,,,,, -134691,,,0.865652859210968,0.4808640778064728,0.7212600111961365,1.1499614715576172,50000.0,0.5897000432014465,1.8905266523361208,10000.0,45446.60967588425,47191.0237467289,45446.60967588425,1734.3433163166046,4.786406517028809,0.0 -134700,3.2944443,1.015649,,,,,,,,,,,,,, -134800,3.4399917,1.0379752,,,,,,,,,,,,,, -134900,3.4517798,0.9609412,,,,,,,,,,,,,, -135000,3.35881,1.081047,,,,,,,,,,,,,, -135100,3.484192,1.115219,,,,,,,,,,,,,, -135200,3.3716583,1.054397,,,,,,,,,,,,,, -135300,3.4883232,1.1372893,,,,,,,,,,,,,, -135400,3.9469476,1.1324745,,,,,,,,,,,,,, -135500,3.2446802,1.1332862,,,,,,,,,,,,,, -135600,3.3103828,0.9679579,,,,,,,,,,,,,, -135700,3.4188013,1.0282321,,,,,,,,,,,,,, -135800,3.2162435,0.9845844,,,,,,,,,,,,,, -135900,3.4004095,1.0315987,,,,,,,,,,,,,, -136000,3.3831837,1.0336065,,,,,,,,,,,,,, -136100,3.1695719,1.0530843,,,,,,,,,,,,,, -136200,3.2736,0.9685028,,,,,,,,,,,,,, -136205,,,0.8681640625,0.465312123298645,0.726859986782074,1.1234965324401855,50000.0,0.6017000079154968,1.838789939880371,10000.0,45956.50417351723,47718.14137125015,45956.50417351723,1751.4541499614716,4.838630199432373,0.0 -136300,3.619979,1.043246,,,,,,,,,,,,,, -136400,3.5329454,1.1396397,,,,,,,,,,,,,, -136500,3.4888418,0.99616927,,,,,,,,,,,,,, -136600,3.442667,1.0377715,,,,,,,,,,,,,, -136700,3.5862148,1.1207073,,,,,,,,,,,,,, -136800,3.5437634,1.0718842,,,,,,,,,,,,,, -136900,3.4348135,1.0218564,,,,,,,,,,,,,, -137000,3.43281,0.99126023,,,,,,,,,,,,,, -137100,3.3151057,1.003716,,,,,,,,,,,,,, -137200,3.8658192,1.0610193,,,,,,,,,,,,,, -137300,3.7941146,1.0456989,,,,,,,,,,,,,, -137400,3.2687929,0.9591846,,,,,,,,,,,,,, -137500,3.8827522,0.9734585,,,,,,,,,,,,,, -137600,3.4887345,0.9953536,,,,,,,,,,,,,, -137700,3.7915232,0.984026,,,,,,,,,,,,,, -137719,,,0.8539540767669678,0.5145111083984375,0.7205399870872498,1.144445538520813,50000.0,0.5908000469207764,1.8979063034057613,10000.0,46466.70485925674,48245.51445031166,46466.70485925674,1768.5149869918823,4.89011025428772,0.0 -137800,3.6676905,1.0883194,,,,,,,,,,,,,, -137900,3.5799704,0.9761667,,,,,,,,,,,,,, -138000,3.4828558,1.0581698,,,,,,,,,,,,,, -138100,3.340985,0.98502254,,,,,,,,,,,,,, -138200,3.605531,0.98396903,,,,,,,,,,,,,, -138300,3.7582023,0.9972882,,,,,,,,,,,,,, -138400,3.9607143,1.0526291,,,,,,,,,,,,,, -138500,3.6682746,1.0244125,,,,,,,,,,,,,, -138600,3.4884918,1.043238,,,,,,,,,,,,,, -138700,3.7608018,1.0848613,,,,,,,,,,,,,, -138800,3.9729066,1.0054601,,,,,,,,,,,,,, -138900,3.485857,0.97582054,,,,,,,,,,,,,, -139000,3.7286294,1.0180593,,,,,,,,,,,,,, -139100,3.4188924,0.94385636,,,,,,,,,,,,,, -139200,3.282746,0.96633583,,,,,,,,,,,,,, -139234,,,0.8557477593421936,0.4996830224990845,0.7240999937057495,1.1292551755905151,50000.0,0.5954000353813171,1.8895883560180664,10000.0,46976.7206556797,48772.95825433731,46976.7206556797,1785.825518131256,4.946522235870361,0.0 -139300,3.7454462,1.0896542,,,,,,,,,,,,,, -139400,3.4475422,0.9741636,,,,,,,,,,,,,, -139500,3.54708,0.94728494,,,,,,,,,,,,,, -139600,3.0802796,0.89403796,,,,,,,,,,,,,, -139700,3.5101714,0.9421474,,,,,,,,,,,,,, -139800,3.2978222,0.8786624,,,,,,,,,,,,,, -139900,3.566775,0.9778966,,,,,,,,,,,,,, -140000,3.4488409,0.9569422,,,,,,,,,,,,,, -140100,3.4571514,1.0107632,,,,,,,,,,,,,, -140200,3.549121,0.9607138,,,,,,,,,,,,,, -140300,3.825577,1.0531331,,,,,,,,,,,,,, -140400,3.7257836,0.9333314,,,,,,,,,,,,,, -140500,3.5800755,1.0082624,,,,,,,,,,,,,, -140600,3.3998609,0.9093392,,,,,,,,,,,,,, -140700,3.6200538,0.9900625,,,,,,,,,,,,,, -140747,,,0.8660913109779358,0.4651607573032379,0.728659987449646,1.1148227453231812,50000.0,0.6037000417709351,1.8393598794937127,10000.0,47486.69790649414,49300.32685422897,47486.69790649414,1803.1059713363647,4.9977428913116455,0.0 -140800,3.7150218,1.0832459,,,,,,,,,,,,,, -140900,3.422982,0.96709096,,,,,,,,,,,,,, -141000,3.5386028,1.028298,,,,,,,,,,,,,, -141100,3.5364404,0.9862727,,,,,,,,,,,,,, -141200,3.545416,0.9251716,,,,,,,,,,,,,, -141300,3.5585926,0.99754655,,,,,,,,,,,,,, -141400,3.5532315,0.9282603,,,,,,,,,,,,,, -141500,3.6159346,0.9603428,,,,,,,,,,,,,, -141600,3.8825493,1.0474716,,,,,,,,,,,,,, -141700,3.6757514,0.9005006,,,,,,,,,,,,,, -141800,3.829502,0.99019015,,,,,,,,,,,,,, -141900,3.4484797,0.9872245,,,,,,,,,,,,,, -142000,3.899898,1.0252724,,,,,,,,,,,,,, -142100,3.5439508,0.9061158,,,,,,,,,,,,,, -142200,3.7694237,0.9642528,,,,,,,,,,,,,, -142262,,,0.8884326815605164,0.3945820927619934,0.7259399890899658,1.1340726613998413,50000.0,0.5940999984741211,1.8927258253097528,10000.0,47996.81061291695,49827.75827026367,47996.81061291695,1820.3083732128143,5.054003477096558,0.0 -142300,3.5708385,0.8760075,,,,,,,,,,,,,, -142400,3.51856,0.9250535,,,,,,,,,,,,,, -142500,3.8227324,0.91081077,,,,,,,,,,,,,, -142600,3.900176,0.92723536,,,,,,,,,,,,,, -142700,3.5944984,0.9549922,,,,,,,,,,,,,, -142800,3.8839214,0.926347,,,,,,,,,,,,,, -142900,3.5751326,0.8854482,,,,,,,,,,,,,, -143000,3.7647219,0.9878917,,,,,,,,,,,,,, -143100,3.9165218,1.0275791,,,,,,,,,,,,,, -143200,3.9739077,1.0190946,,,,,,,,,,,,,, -143300,3.7186408,0.9149275,,,,,,,,,,,,,, -143400,3.6240132,0.9751595,,,,,,,,,,,,,, -143500,3.50329,0.8822511,,,,,,,,,,,,,, -143600,3.5586214,0.88981533,,,,,,,,,,,,,, -143700,3.6040637,0.9813193,,,,,,,,,,,,,, -143776,,,0.8937739133834839,0.3656793534755707,0.7310000061988831,1.1196391582489014,50000.0,0.6053000092506409,1.8732010126113887,10000.0,48506.87425112724,50355.19622254372,48506.87425112724,1837.5629630088808,5.1136510372161865,0.0 -143800,3.9075408,1.0249124,,,,,,,,,,,,,, -143900,3.8502824,0.990538,,,,,,,,,,,,,, -144000,3.7165217,0.9571426,,,,,,,,,,,,,, -144100,3.5785909,0.84445727,,,,,,,,,,,,,, -144200,3.7002296,0.92428225,,,,,,,,,,,,,, -144300,3.5212078,0.9398646,,,,,,,,,,,,,, -144400,3.7344582,1.0601916,,,,,,,,,,,,,, -144500,3.572917,0.9830892,,,,,,,,,,,,,, -144600,3.4156632,0.8207447,,,,,,,,,,,,,, -144700,3.5043488,0.9881698,,,,,,,,,,,,,, -144800,4.1037197,0.96374345,,,,,,,,,,,,,, -144900,3.9216752,0.99349445,,,,,,,,,,,,,, -145000,3.922122,0.95691025,,,,,,,,,,,,,, -145100,3.7339401,0.8748793,,,,,,,,,,,,,, -145200,3.7625785,0.9543165,,,,,,,,,,,,,, -145291,,,0.8908242583274841,0.3806366622447967,0.7371199727058411,1.0951299667358398,50000.0,0.6067000031471252,1.829805374145508,10000.0,49016.82561182976,50882.77686429024,49016.82561182976,1855.067155122757,5.179184913635254,0.0 -145300,3.533262,0.9531967,,,,,,,,,,,,,, -145400,3.969205,0.98881125,,,,,,,,,,,,,, -145500,3.8041425,0.9352626,,,,,,,,,,,,,, -145600,3.744365,0.92528677,,,,,,,,,,,,,, -145700,3.7790916,1.0156362,,,,,,,,,,,,,, -145800,3.849015,0.93997294,,,,,,,,,,,,,, -145900,4.0364466,0.96554863,,,,,,,,,,,,,, -146000,3.6646419,0.9115803,,,,,,,,,,,,,, -146100,3.9810991,0.9536233,,,,,,,,,,,,,, -146200,3.8111587,0.8774664,,,,,,,,,,,,,, -146300,3.7744725,0.8502704,,,,,,,,,,,,,, -146400,3.8385892,0.9662405,,,,,,,,,,,,,, -146500,3.7598135,0.85809577,,,,,,,,,,,,,, -146600,4.027506,0.87554985,,,,,,,,,,,,,, -146700,3.8301275,0.9154335,,,,,,,,,,,,,, -146800,3.6990523,0.8160475,,,,,,,,,,,,,, -146804,,,0.8893893361091614,0.3860467076301574,0.7328799962997437,1.1125147342681885,50000.0,0.6051000356674194,1.860052227973938,10000.0,49526.766122579575,51409.85617089272,49526.766122579575,1872.092421293259,5.233065366744995,0.0 -146900,3.8208115,0.91791654,,,,,,,,,,,,,, -147000,3.803008,0.8922565,,,,,,,,,,,,,, -147100,3.4148893,0.7702805,,,,,,,,,,,,,, -147200,3.7944727,0.9358325,,,,,,,,,,,,,, -147300,3.9971406,0.8959679,,,,,,,,,,,,,, -147400,4.190254,1.029418,,,,,,,,,,,,,, -147500,4.3537073,0.8785051,,,,,,,,,,,,,, -147600,3.8932211,0.9157839,,,,,,,,,,,,,, -147700,3.740111,0.91609526,,,,,,,,,,,,,, -147800,3.894318,0.8969374,,,,,,,,,,,,,, -147900,3.4689887,0.864835,,,,,,,,,,,,,, -148000,3.7083673,0.877751,,,,,,,,,,,,,, -148100,4.1032,0.81318295,,,,,,,,,,,,,, -148200,3.977811,0.8841732,,,,,,,,,,,,,, -148300,3.8820784,0.9327118,,,,,,,,,,,,,, -148319,,,0.8917610049247742,0.3722779452800751,0.7381599545478821,1.0851082801818848,50000.0,0.612500011920929,1.8237643241882324,10000.0,50036.77022147179,51937.361750125885,50036.77022147179,1889.482700824737,5.284041404724121,0.0 -148400,3.8862445,0.8693163,,,,,,,,,,,,,, -148500,3.7246325,0.8828523,,,,,,,,,,,,,, -148600,3.7997317,0.8863732,,,,,,,,,,,,,, -148700,4.058839,0.87550133,,,,,,,,,,,,,, -148800,4.032426,0.88453627,,,,,,,,,,,,,, -148900,4.063645,0.9041749,,,,,,,,,,,,,, -149000,4.063761,0.90282124,,,,,,,,,,,,,, -149100,3.977926,0.95849335,,,,,,,,,,,,,, -149200,3.9302092,0.7626756,,,,,,,,,,,,,, -149300,3.7231328,0.89038205,,,,,,,,,,,,,, -149400,3.7114804,0.81402516,,,,,,,,,,,,,, -149500,3.9562862,0.94345635,,,,,,,,,,,,,, -149600,3.7844815,0.90390015,,,,,,,,,,,,,, -149700,3.7760134,0.868084,,,,,,,,,,,,,, -149800,3.8681636,0.84985554,,,,,,,,,,,,,, -149833,,,0.8892498016357422,0.373854249715805,0.7322799563407898,1.1266218423843384,50000.0,0.6081000566482544,1.878450512886048,10000.0,50546.90058851242,52464.96715068817,50546.90058851242,1906.845090150833,5.336570739746094,0.0 -149900,3.6931925,0.8391227,,,,,,,,,,,,,, -150000,3.7846959,0.8350997,,,,,,,,,,,,,, -150100,3.7684937,0.8297154,,,,,,,,,,,,,, -150200,4.011065,0.84074104,,,,,,,,,,,,,, -150300,3.8542778,0.8950372,,,,,,,,,,,,,, -150400,3.8954933,0.9201086,,,,,,,,,,,,,, -150500,4.162461,0.86196804,,,,,,,,,,,,,, -150600,3.6672728,0.7878375,,,,,,,,,,,,,, -150700,4.1699386,0.95814604,,,,,,,,,,,,,, -150800,4.0188046,0.873541,,,,,,,,,,,,,, -150900,4.0264244,0.9718706,,,,,,,,,,,,,, -151000,3.895191,0.83674604,,,,,,,,,,,,,, -151100,3.7710109,0.8644315,,,,,,,,,,,,,, -151200,3.6907618,0.8355075,,,,,,,,,,,,,, -151300,3.9502776,0.8947087,,,,,,,,,,,,,, -151348,,,0.9268175959587096,0.266787976026535,0.7402799725532532,1.0880799293518066,50000.0,0.6128000020980835,1.836662530899048,10000.0,51057.0911295414,52992.34717011452,51057.0911295414,1923.9228394031525,5.388533115386963,0.0 -151400,3.9201066,0.87353414,,,,,,,,,,,,,, -151500,4.1045427,0.81383884,,,,,,,,,,,,,, -151600,3.9720454,0.77510536,,,,,,,,,,,,,, -151700,3.9318638,0.8232882,,,,,,,,,,,,,, -151800,3.708657,0.7921961,,,,,,,,,,,,,, -151900,3.8168685,0.86993,,,,,,,,,,,,,, -152000,4.2156143,0.85303646,,,,,,,,,,,,,, -152100,4.0635185,0.88572484,,,,,,,,,,,,,, -152200,3.9959617,0.85318583,,,,,,,,,,,,,, -152300,4.236202,0.85026735,,,,,,,,,,,,,, -152400,3.9373019,0.83967113,,,,,,,,,,,,,, -152500,3.9774687,0.8816419,,,,,,,,,,,,,, -152600,3.810291,0.8993467,,,,,,,,,,,,,, -152700,3.91067,0.8394109,,,,,,,,,,,,,, -152800,4.345389,0.8908224,,,,,,,,,,,,,, -152861,,,0.9162946343421936,0.2905495464801788,0.739799976348877,1.094793677330017,50000.0,0.6173000335693359,1.8380128145217896,10000.0,51567.03753566742,53520.90021848679,51567.03753566742,1942.4130220413208,5.445278167724609,0.0 -152900,4.2305136,0.86766714,,,,,,,,,,,,,, -153000,4.2276626,0.8759673,,,,,,,,,,,,,, -153100,4.1576385,0.8095771,,,,,,,,,,,,,, -153200,4.4342437,0.90956306,,,,,,,,,,,,,, -153300,4.0917454,0.81268495,,,,,,,,,,,,,, -153400,4.339093,0.8558084,,,,,,,,,,,,,, -153500,4.1129994,0.85142267,,,,,,,,,,,,,, -153600,4.214735,0.79952383,,,,,,,,,,,,,, -153700,3.9357748,0.8314271,,,,,,,,,,,,,, -153800,3.832729,0.8222127,,,,,,,,,,,,,, -153900,3.8549929,0.8373894,,,,,,,,,,,,,, -154000,4.0878997,0.8343759,,,,,,,,,,,,,, -154100,3.9497464,0.7335143,,,,,,,,,,,,,, -154200,4.2738905,0.79956627,,,,,,,,,,,,,, -154300,4.178541,0.76372695,,,,,,,,,,,,,, -154375,,,0.9135442972183228,0.3000372946262359,0.7407799959182739,1.092759132385254,50000.0,0.6141000390052795,1.8534924983978271,10000.0,52077.10262107849,54048.459849357605,52077.10262107849,1959.7907156944275,5.500972747802734,0.0 -154400,4.0546403,0.8002833,,,,,,,,,,,,,, -154500,4.085106,0.752486,,,,,,,,,,,,,, -154600,3.7610898,0.869109,,,,,,,,,,,,,, -154700,4.0098834,0.78476024,,,,,,,,,,,,,, -154800,3.8848193,0.8467289,,,,,,,,,,,,,, -154900,3.8235676,0.77673477,,,,,,,,,,,,,, -155000,4.013033,0.8052259,,,,,,,,,,,,,, -155100,4.13111,0.85335577,,,,,,,,,,,,,, -155200,4.2711825,0.85554445,,,,,,,,,,,,,, -155300,4.3064947,0.85086423,,,,,,,,,,,,,, -155400,4.1046987,0.8158824,,,,,,,,,,,,,, -155500,3.9362717,0.8262087,,,,,,,,,,,,,, -155600,4.5546737,0.8514685,,,,,,,,,,,,,, -155700,4.040656,0.76669097,,,,,,,,,,,,,, -155800,4.255257,0.8345404,,,,,,,,,,,,,, -155888,,,0.9125877022743224,0.2956430912017822,0.7436400055885315,1.0900652408599854,50000.0,0.6178000569343567,1.8527776002883911,10000.0,52587.03049516678,54575.90499091149,52587.03049516678,1977.1947882175448,5.554906845092773,0.0 -155900,3.8595395,0.78250134,,,,,,,,,,,,,, -156000,3.9925709,0.7291719,,,,,,,,,,,,,, -156100,4.0298367,0.7820931,,,,,,,,,,,,,, -156200,4.1962996,0.8001549,,,,,,,,,,,,,, -156300,4.42711,0.8142983,,,,,,,,,,,,,, -156400,4.4909606,0.78825146,,,,,,,,,,,,,, -156500,4.4155293,0.8178661,,,,,,,,,,,,,, -156600,4.337352,0.79247004,,,,,,,,,,,,,, -156700,4.26531,0.8297251,,,,,,,,,,,,,, -156800,4.173423,0.8383088,,,,,,,,,,,,,, -156900,4.080791,0.8399373,,,,,,,,,,,,,, -157000,4.014217,0.78351116,,,,,,,,,,,,,, -157100,4.189879,0.7372515,,,,,,,,,,,,,, -157200,4.004935,0.75359464,,,,,,,,,,,,,, -157300,4.2315884,0.76046884,,,,,,,,,,,,,, -157400,4.0689297,0.78580105,,,,,,,,,,,,,, -157402,,,0.9200215339660645,0.2833128273487091,0.7455399632453918,1.0706331729888916,50000.0,0.619100034236908,1.8253930807113647,10000.0,53096.94408249855,55103.1409778595,53096.94408249855,1994.400959968567,5.611454725265503,0.0 -157500,3.9624166,0.7220713,,,,,,,,,,,,,, -157600,3.9680617,0.73582816,,,,,,,,,,,,,, -157700,3.7955716,0.764853,,,,,,,,,,,,,, -157800,4.3781447,0.7767361,,,,,,,,,,,,,, -157900,4.2951927,0.7298152,,,,,,,,,,,,,, -158000,4.440749,0.8423799,,,,,,,,,,,,,, -158100,4.1000214,0.7301574,,,,,,,,,,,,,, -158200,3.8202503,0.74335325,,,,,,,,,,,,,, -158300,4.3077054,0.750247,,,,,,,,,,,,,, -158400,4.4376216,0.86721665,,,,,,,,,,,,,, -158500,4.7486806,0.82733965,,,,,,,,,,,,,, -158600,4.369786,0.76334226,,,,,,,,,,,,,, -158700,4.073071,0.76973987,,,,,,,,,,,,,, -158800,4.3704567,0.7274323,,,,,,,,,,,,,, -158900,4.0196824,0.7480794,,,,,,,,,,,,,, -158915,,,0.9239875674247742,0.2668489813804626,0.7450999617576599,1.073038935661316,50000.0,0.6178000569343567,1.8164268732070925,10000.0,53606.83699655533,55630.18832850456,53606.83699655533,2011.4396004676817,5.665586709976196,0.0 -159000,4.3049135,0.73121494,,,,,,,,,,,,,, -159100,4.1048083,0.79018724,,,,,,,,,,,,,, -159200,4.5302825,0.8743308,,,,,,,,,,,,,, -159300,4.314031,0.7130072,,,,,,,,,,,,,, -159400,4.663579,0.7001609,,,,,,,,,,,,,, -159500,4.180832,0.82286686,,,,,,,,,,,,,, -159600,4.3076105,0.73341835,,,,,,,,,,,,,, -159700,3.9967046,0.7137164,,,,,,,,,,,,,, -159800,4.2759247,0.7718171,,,,,,,,,,,,,, -159900,4.411359,0.8279505,,,,,,,,,,,,,, -160000,4.3647623,0.80262125,,,,,,,,,,,,,, -160100,4.0158434,0.7600763,,,,,,,,,,,,,, -160200,4.077153,0.7410298,,,,,,,,,,,,,, -160300,4.7640944,0.7215316,,,,,,,,,,,,,, -160400,4.55135,0.81691664,,,,,,,,,,,,,, -160429,,,0.942602038383484,0.2054823786020279,0.7473999857902527,1.0706602334976196,50000.0,0.6201000213623047,1.828052401542664,10000.0,54116.81638598442,56157.62402033806,54116.81638598442,2028.765647888184,5.735450029373169,0.0 -160500,4.485734,0.7913158,,,,,,,,,,,,,, -160600,4.3072896,0.78630817,,,,,,,,,,,,,, -160700,4.218607,0.7820712,,,,,,,,,,,,,, -160800,4.2371187,0.76469004,,,,,,,,,,,,,, -160900,4.1486106,0.71504337,,,,,,,,,,,,,, -161000,4.083906,0.7466338,,,,,,,,,,,,,, -161100,4.336825,0.8051998,,,,,,,,,,,,,, -161200,4.1688585,0.7241235,,,,,,,,,,,,,, -161300,4.045768,0.69535184,,,,,,,,,,,,,, -161400,3.931513,0.6611631,,,,,,,,,,,,,, -161500,4.395826,0.66415584,,,,,,,,,,,,,, -161600,4.7300925,0.76095057,,,,,,,,,,,,,, -161700,4.349275,0.7987793,,,,,,,,,,,,,, -161800,4.3130574,0.7905912,,,,,,,,,,,,,, -161900,4.379681,0.7298622,,,,,,,,,,,,,, -161943,,,0.9362045526504515,0.2241264581680297,0.7438799738883972,1.0907021760940552,50000.0,0.6215000152587891,1.846395254135132,10000.0,54626.75292801857,56685.03358387947,54626.75292801857,2046.1212282180784,5.791545629501343,0.0 -162000,4.6245656,0.78303736,,,,,,,,,,,,,, -162100,5.1953263,0.8324193,,,,,,,,,,,,,, -162200,3.9058435,0.63851106,,,,,,,,,,,,,, -162300,4.509398,0.7381172,,,,,,,,,,,,,, -162400,4.4564605,0.746641,,,,,,,,,,,,,, -162500,4.1846266,0.695174,,,,,,,,,,,,,, -162600,4.525407,0.68634295,,,,,,,,,,,,,, -162700,4.119607,0.7524564,,,,,,,,,,,,,, -162800,4.3243694,0.7007359,,,,,,,,,,,,,, -162900,4.535618,0.698277,,,,,,,,,,,,,, -163000,4.1568108,0.6968718,,,,,,,,,,,,,, -163100,4.6628838,0.7287216,,,,,,,,,,,,,, -163200,4.2543893,0.6917697,,,,,,,,,,,,,, -163300,4.297998,0.72545135,,,,,,,,,,,,,, -163400,4.357316,0.7144801,,,,,,,,,,,,,, -163457,,,0.9362045526504515,0.2207538038492202,0.746999979019165,1.0740346908569336,50000.0,0.6205000281333923,1.84682035446167,10000.0,55136.65639901161,57212.48588275909,55136.65639901161,2063.5502858161926,5.850171327590942,0.0 -163500,4.2066264,0.76310277,,,,,,,,,,,,,, -163600,4.5128365,0.7352701,,,,,,,,,,,,,, -163700,4.350167,0.7396043,,,,,,,,,,,,,, -163800,4.3159943,0.69502294,,,,,,,,,,,,,, -163900,4.356904,0.6746991,,,,,,,,,,,,,, -164000,4.464397,0.72571725,,,,,,,,,,,,,, -164100,4.194095,0.67935234,,,,,,,,,,,,,, -164200,4.506375,0.7486457,,,,,,,,,,,,,, -164300,4.383276,0.72201824,,,,,,,,,,,,,, -164400,4.3582497,0.69401455,,,,,,,,,,,,,, -164500,4.592247,0.7409365,,,,,,,,,,,,,, -164600,4.288156,0.7560295,,,,,,,,,,,,,, -164700,4.7353897,0.78456426,,,,,,,,,,,,,, -164800,4.4251766,0.74486125,,,,,,,,,,,,,, -164900,4.2920685,0.6897533,,,,,,,,,,,,,, -164971,,,0.9383171200752258,0.2114021331071853,0.7476199865341187,1.0842297077178955,50000.0,0.6273000240325928,1.8529764413833616,10000.0,55646.5770778656,57739.58699655533,55646.5770778656,2080.6110968589783,5.909976959228516,0.0 -165000,4.37265,0.65469915,,,,,,,,,,,,,, -165100,4.552031,0.6825026,,,,,,,,,,,,,, -165200,4.260956,0.6649674,,,,,,,,,,,,,, -165300,4.5172477,0.769904,,,,,,,,,,,,,, -165400,4.288953,0.6436551,,,,,,,,,,,,,, -165500,4.8444705,0.7740662,,,,,,,,,,,,,, -165600,4.5216947,0.7804223,,,,,,,,,,,,,, -165700,4.3515053,0.70021176,,,,,,,,,,,,,, -165800,4.4356046,0.7228617,,,,,,,,,,,,,, -165900,4.854373,0.75793725,,,,,,,,,,,,,, -166000,4.759956,0.741943,,,,,,,,,,,,,, -166100,4.7907968,0.79593813,,,,,,,,,,,,,, -166200,4.691207,0.67534965,,,,,,,,,,,,,, -166300,4.4448953,0.6721408,,,,,,,,,,,,,, -166400,4.0980906,0.63676316,,,,,,,,,,,,,, -166484,,,0.9404296875,0.2076198607683181,0.7492600083351135,1.065981149673462,50000.0,0.6234000325202942,1.8334190845489504,10000.0,56156.48152041435,58266.93318080902,56156.48152041435,2097.937203168869,5.965330123901367,0.0 -166500,4.5750155,0.7159438,,,,,,,,,,,,,, -166600,4.894761,0.7399752,,,,,,,,,,,,,, -166700,4.4940777,0.7556665,,,,,,,,,,,,,, -166800,4.595414,0.751377,,,,,,,,,,,,,, -166900,4.991505,0.6668937,,,,,,,,,,,,,, -167000,4.3625226,0.6987263,,,,,,,,,,,,,, -167100,4.405895,0.7048157,,,,,,,,,,,,,, -167200,4.320405,0.7193645,,,,,,,,,,,,,, -167300,4.48147,0.75646996,,,,,,,,,,,,,, -167400,4.440904,0.69832313,,,,,,,,,,,,,, -167500,4.8470883,0.707312,,,,,,,,,,,,,, -167600,4.559648,0.72340167,,,,,,,,,,,,,, -167700,4.5185814,0.6933055,,,,,,,,,,,,,, -167800,4.79161,0.7519591,,,,,,,,,,,,,, -167900,4.178425,0.68954253,,,,,,,,,,,,,, -167997,,,0.9448142051696776,0.196561262011528,0.7514399886131287,1.0683664083480835,50000.0,0.6291000247001648,1.8146344423294067,10000.0,56666.4612839222,58794.1668009758,56666.4612839222,2115.0789551734924,6.01855206489563,0.0 -168000,4.5799513,0.7459593,,,,,,,,,,,,,, -168100,4.2917223,0.6340128,,,,,,,,,,,,,, -168200,4.881664,0.78922576,,,,,,,,,,,,,, -168300,4.3227377,0.64775556,,,,,,,,,,,,,, -168400,4.631701,0.7404095,,,,,,,,,,,,,, -168500,4.705895,0.72589105,,,,,,,,,,,,,, -168600,4.254189,0.6304496,,,,,,,,,,,,,, -168700,4.635173,0.67239165,,,,,,,,,,,,,, -168800,4.2668295,0.65736747,,,,,,,,,,,,,, -168900,4.4103,0.6607013,,,,,,,,,,,,,, -169000,4.600969,0.7515376,,,,,,,,,,,,,, -169100,4.344185,0.63106555,,,,,,,,,,,,,, -169200,4.651277,0.7113869,,,,,,,,,,,,,, -169300,4.737876,0.6945988,,,,,,,,,,,,,, -169400,4.7177234,0.6636736,,,,,,,,,,,,,, -169500,4.2961063,0.7855038,,,,,,,,,,,,,, -169510,,,0.9545400142669678,0.1681820601224899,0.7520399689674377,1.063841462135315,50000.0,0.624500036239624,1.82466459274292,10000.0,57176.4802467823,59321.66878390312,57176.4802467823,2132.4454357624054,6.0728843212127686,0.0 -169600,5.1237955,0.7393209,,,,,,,,,,,,,, -169700,4.4055862,0.65365773,,,,,,,,,,,,,, -169800,4.2173376,0.684389,,,,,,,,,,,,,, -169900,4.364335,0.68952155,,,,,,,,,,,,,, -170000,4.5948067,0.658117,,,,,,,,,,,,,, -170100,4.7703714,0.6802174,,,,,,,,,,,,,, -170200,4.5123343,0.7237357,,,,,,,,,,,,,, -170300,4.679849,0.6969324,,,,,,,,,,,,,, -170400,4.498918,0.713013,,,,,,,,,,,,,, -170500,4.1735325,0.626133,,,,,,,,,,,,,, -170600,4.8708525,0.67327625,,,,,,,,,,,,,, -170700,5.038247,0.6823282,,,,,,,,,,,,,, -170800,4.1721697,0.63970923,,,,,,,,,,,,,, -170900,4.657188,0.68004954,,,,,,,,,,,,,, -171000,4.4479837,0.6834901,,,,,,,,,,,,,, -171023,,,0.9538225531578064,0.1663007885217666,0.7504199743270874,1.0649524927139282,50000.0,0.6247000098228455,1.8314189910888672,10000.0,57686.42001056671,59849.21348261833,57686.42001056671,2149.9335446357727,6.129309177398682,0.0 -171100,4.374112,0.6702075,,,,,,,,,,,,,, -171200,4.377666,0.665776,,,,,,,,,,,,,, -171300,4.2822742,0.5546037,,,,,,,,,,,,,, -171400,4.585334,0.67362475,,,,,,,,,,,,,, -171500,4.1630454,0.63821745,,,,,,,,,,,,,, -171600,4.4254293,0.68405807,,,,,,,,,,,,,, -171700,4.174723,0.6341227,,,,,,,,,,,,,, -171800,4.498382,0.70753944,,,,,,,,,,,,,, -171900,4.2896476,0.6461693,,,,,,,,,,,,,, -172000,4.7524595,0.71170396,,,,,,,,,,,,,, -172100,4.5983405,0.6654294,,,,,,,,,,,,,, -172200,4.2477903,0.6185039,,,,,,,,,,,,,, -172300,4.5302343,0.6767216,,,,,,,,,,,,,, -172400,4.3069744,0.59179586,,,,,,,,,,,,,, -172500,4.406542,0.6664046,,,,,,,,,,,,,, -172538,,,0.9530851244926452,0.1686789840459823,0.7537999749183655,1.0632176399230957,50000.0,0.6278000473976135,1.8282188177108765,10000.0,58196.485368967056,60376.90635275841,58196.485368967056,2167.4458363056183,6.185018539428711,0.0 -172600,4.344375,0.66998935,,,,,,,,,,,,,, -172700,4.3058205,0.5967114,,,,,,,,,,,,,, -172800,4.877533,0.69288886,,,,,,,,,,,,,, -172900,4.7680736,0.7085184,,,,,,,,,,,,,, -173000,4.1533356,0.6124505,,,,,,,,,,,,,, -173100,4.437892,0.6104889,,,,,,,,,,,,,, -173200,4.2447157,0.62941486,,,,,,,,,,,,,, -173300,4.289845,0.6232691,,,,,,,,,,,,,, -173400,4.6762834,0.6355779,,,,,,,,,,,,,, -173500,4.9690824,0.6921691,,,,,,,,,,,,,, -173600,4.0371346,0.5923623,,,,,,,,,,,,,, -173700,4.680202,0.69658077,,,,,,,,,,,,,, -173800,4.40815,0.64718336,,,,,,,,,,,,,, -173900,4.4809904,0.585557,,,,,,,,,,,,,, -174000,4.3937883,0.6808124,,,,,,,,,,,,,, -174052,,,0.954699456691742,0.1657977998256683,0.7526999711990356,1.06158185005188,50000.0,0.6285000443458557,1.8208074569702148,10000.0,58706.56109046936,60904.483194589615,58706.56109046936,2184.832007408142,6.240122318267822,0.0 -174100,4.7661233,0.6417544,,,,,,,,,,,,,, -174200,4.4393353,0.63263905,,,,,,,,,,,,,, -174300,4.639011,0.67244655,,,,,,,,,,,,,, -174400,4.7555833,0.7115967,,,,,,,,,,,,,, -174500,4.6363482,0.69750315,,,,,,,,,,,,,, -174600,4.280757,0.68049204,,,,,,,,,,,,,, -174700,4.633371,0.67940193,,,,,,,,,,,,,, -174800,4.265478,0.64236283,,,,,,,,,,,,,, -174900,4.4672766,0.7084117,,,,,,,,,,,,,, -175000,3.9609485,0.60280764,,,,,,,,,,,,,, -175100,4.4425035,0.6944447,,,,,,,,,,,,,, -175200,4.192116,0.6115682,,,,,,,,,,,,,, -175300,4.464,0.678328,,,,,,,,,,,,,, -175400,4.346492,0.63787067,,,,,,,,,,,,,, -175500,4.457603,0.530367,,,,,,,,,,,,,, -175566,,,0.9553571343421936,0.1650108247995376,0.7545799612998962,1.0558218955993652,50000.0,0.626800000667572,1.83081316947937,10000.0,59216.54894709587,61431.60178184509,59216.54894709587,2201.833286046982,6.310078859329224,0.0 -175600,4.252496,0.6302827,,,,,,,,,,,,,, -175700,4.278589,0.68991494,,,,,,,,,,,,,, -175800,5.26289,0.7212171,,,,,,,,,,,,,, -175900,4.224169,0.614999,,,,,,,,,,,,,, -176000,4.561906,0.6133278,,,,,,,,,,,,,, -176100,4.5647035,0.6357329,,,,,,,,,,,,,, -176200,4.9361987,0.71508634,,,,,,,,,,,,,, -176300,4.376088,0.6372026,,,,,,,,,,,,,, -176400,4.4711494,0.6827975,,,,,,,,,,,,,, -176500,4.3421626,0.6090759,,,,,,,,,,,,,, -176600,4.1028934,0.542558,,,,,,,,,,,,,, -176700,5.0305657,0.7125821,,,,,,,,,,,,,, -176800,4.7038183,0.654156,,,,,,,,,,,,,, -176900,4.8014603,0.68576187,,,,,,,,,,,,,, -177000,4.896034,0.66497,,,,,,,,,,,,,, -177079,,,0.9563536047935486,0.1589114516973495,0.7540599703788757,1.0573482513427734,50000.0,0.6285000443458557,1.8204129934310915,10000.0,59726.47098231316,61959.16131854057,59726.47098231316,2219.350727558136,6.369282007217407,0.0 -177100,4.362602,0.65997225,,,,,,,,,,,,,, -177200,4.463179,0.5653122,,,,,,,,,,,,,, -177300,4.3476577,0.60998964,,,,,,,,,,,,,, -177400,4.728348,0.6966402,,,,,,,,,,,,,, -177500,4.5184274,0.6663141,,,,,,,,,,,,,, -177600,4.4129047,0.5966449,,,,,,,,,,,,,, -177700,4.387815,0.61059666,,,,,,,,,,,,,, -177800,4.5789275,0.6307533,,,,,,,,,,,,,, -177900,4.2496624,0.6510132,,,,,,,,,,,,,, -178000,4.1292834,0.5832737,,,,,,,,,,,,,, -178100,4.649521,0.5836836,,,,,,,,,,,,,, -178200,4.4151506,0.66772354,,,,,,,,,,,,,, -178300,4.340933,0.631552,,,,,,,,,,,,,, -178400,4.983863,0.65490746,,,,,,,,,,,,,, -178500,4.4988856,0.6783788,,,,,,,,,,,,,, -178592,,,0.960379421710968,0.1490341424942016,0.7551800012588501,1.055389404296875,50000.0,0.6297000050544739,1.8235666751861568,10000.0,60236.56385469437,62486.78073191643,60236.56385469437,2236.7588534355164,6.427386999130249,0.0 -178600,4.3597426,0.5732934,,,,,,,,,,,,,, -178700,4.347933,0.6437632,,,,,,,,,,,,,, -178800,4.908305,0.63222796,,,,,,,,,,,,,, -178900,4.6594696,0.6637997,,,,,,,,,,,,,, -179000,4.636111,0.672692,,,,,,,,,,,,,, -179100,4.256509,0.62943316,,,,,,,,,,,,,, -179200,4.6686163,0.6518233,,,,,,,,,,,,,, -179300,4.3678823,0.6312701,,,,,,,,,,,,,, -179400,4.8395386,0.58743376,,,,,,,,,,,,,, -179500,4.716472,0.64742464,,,,,,,,,,,,,, -179600,5.3154335,0.67605096,,,,,,,,,,,,,, -179700,4.52261,0.6868834,,,,,,,,,,,,,, -179800,4.25501,0.66896844,,,,,,,,,,,,,, -179900,4.322626,0.57703054,,,,,,,,,,,,,, -180000,4.6915474,0.6336182,,,,,,,,,,,,,, -180100,4.56326,0.62733144,,,,,,,,,,,,,, -180106,,,0.9593032598495485,0.1496942341327667,0.7559199929237366,1.0523293018341064,50000.0,0.6306000351905823,1.820485591888428,10000.0,60746.7347638607,63014.39211153984,60746.7347638607,2254.0810379981995,6.485968351364136,0.0 -180200,4.6617785,0.58026516,,,,,,,,,,,,,, -180300,4.946722,0.5980511,,,,,,,,,,,,,, -180400,4.9567776,0.68899333,,,,,,,,,,,,,, -180500,4.312995,0.6326426,,,,,,,,,,,,,, -180600,4.2681694,0.5877323,,,,,,,,,,,,,, -180700,4.300387,0.59125334,,,,,,,,,,,,,, -180800,4.8221803,0.6284864,,,,,,,,,,,,,, -180900,4.198215,0.60845757,,,,,,,,,,,,,, -181000,4.866104,0.77189183,,,,,,,,,,,,,, -181100,4.677946,0.6276928,,,,,,,,,,,,,, -181200,4.667403,0.7102981,,,,,,,,,,,,,, -181300,4.2873387,0.5962132,,,,,,,,,,,,,, -181400,4.4993253,0.6171588,,,,,,,,,,,,,, -181500,4.9209676,0.670215,,,,,,,,,,,,,, -181600,4.1401596,0.5418716,,,,,,,,,,,,,, -181621,,,0.960339605808258,0.1473979651927948,0.7554599642753601,1.0532605648040771,50000.0,0.628600001335144,1.819870948791504,10000.0,61256.86021757126,63541.93805527687,61256.86021757126,2271.384888648987,6.543102502822876,0.0 -181700,4.3492236,0.56885785,,,,,,,,,,,,,, -181800,4.710809,0.651402,,,,,,,,,,,,,, -181900,4.7262993,0.5678813,,,,,,,,,,,,,, -182000,4.581998,0.64942956,,,,,,,,,,,,,, -182100,4.4295597,0.6055257,,,,,,,,,,,,,, -182200,4.392995,0.6075486,,,,,,,,,,,,,, -182300,4.927229,0.65129066,,,,,,,,,,,,,, -182400,4.3254037,0.5866271,,,,,,,,,,,,,, -182500,4.281679,0.55921763,,,,,,,,,,,,,, -182600,4.5822363,0.64804643,,,,,,,,,,,,,, -182700,4.415492,0.6479162,,,,,,,,,,,,,, -182800,4.253137,0.5877995,,,,,,,,,,,,,, -182900,4.672301,0.6530883,,,,,,,,,,,,,, -183000,4.343534,0.5616595,,,,,,,,,,,,,, -183100,4.0172267,0.52249986,,,,,,,,,,,,,, -183134,,,0.9602000713348388,0.1495323926210403,0.7563799619674683,1.051788330078125,50000.0,0.6314000487327576,1.8186100721359253,10000.0,61766.89761161804,64069.10019540787,61766.89761161804,2288.3898487091064,6.603374719619751,0.0 -183200,4.5218415,0.6503204,,,,,,,,,,,,,, -183300,5.1081657,0.6195874,,,,,,,,,,,,,, -183400,4.2474284,0.57744724,,,,,,,,,,,,,, -183500,4.399865,0.60494,,,,,,,,,,,,,, -183600,4.746179,0.6303631,,,,,,,,,,,,,, -183700,4.60167,0.65568894,,,,,,,,,,,,,, -183800,4.2240357,0.62288266,,,,,,,,,,,,,, -183900,4.4945054,0.6153232,,,,,,,,,,,,,, -184000,4.6381173,0.6701479,,,,,,,,,,,,,, -184100,5.3087215,0.6435524,,,,,,,,,,,,,, -184200,4.388788,0.63565826,,,,,,,,,,,,,, -184300,4.292475,0.6325019,,,,,,,,,,,,,, -184400,4.1256123,0.65450585,,,,,,,,,,,,,, -184500,4.385014,0.644576,,,,,,,,,,,,,, -184600,4.8551874,0.65351933,,,,,,,,,,,,,, -184648,,,0.9607780575752258,0.1484294682741165,0.7558199763298035,1.0527596473693848,50000.0,0.6309000253677368,1.817617654800415,10000.0,62276.97330927849,64596.847626686096,62276.97330927849,2305.942921876908,6.662899971008301,0.0 -184700,4.450358,0.6802891,,,,,,,,,,,,,, -184800,4.154827,0.574233,,,,,,,,,,,,,, -184900,4.397022,0.63837904,,,,,,,,,,,,,, -185000,4.276034,0.5743959,,,,,,,,,,,,,, -185100,4.769325,0.6920908,,,,,,,,,,,,,, -185200,4.4702024,0.6157455,,,,,,,,,,,,,, -185300,4.8105764,0.70017076,,,,,,,,,,,,,, -185400,4.3598228,0.5912058,,,,,,,,,,,,,, -185500,4.5387106,0.6674344,,,,,,,,,,,,,, -185600,4.5308256,0.6021672,,,,,,,,,,,,,, -185700,4.9059386,0.69271755,,,,,,,,,,,,,, -185800,4.406944,0.56467885,,,,,,,,,,,,,, -185900,4.8904996,0.68361956,,,,,,,,,,,,,, -186000,4.5670657,0.65024334,,,,,,,,,,,,,, -186100,4.0332375,0.61054844,,,,,,,,,,,,,, -186161,,,0.9616150856018066,0.1467247605323791,0.7559599876403809,1.051039218902588,50000.0,0.6305000185966492,1.8158655166625977,10000.0,62786.884974718094,65123.88558316231,62786.884974718094,2322.951239347458,6.7222740650177,0.0 -186200,4.718158,0.63479304,,,,,,,,,,,,,, -186300,4.337064,0.5949993,,,,,,,,,,,,,, -186400,4.8009696,0.60295504,,,,,,,,,,,,,, -186500,4.4135065,0.5512732,,,,,,,,,,,,,, -186600,4.715968,0.68768966,,,,,,,,,,,,,, -186700,4.2558913,0.58391964,,,,,,,,,,,,,, -186800,4.410172,0.5860855,,,,,,,,,,,,,, -186900,4.5103617,0.6536653,,,,,,,,,,,,,, -187000,4.4525228,0.651981,,,,,,,,,,,,,, -187100,4.370107,0.56636745,,,,,,,,,,,,,, -187200,4.146069,0.5487865,,,,,,,,,,,,,, -187300,4.2874117,0.5873717,,,,,,,,,,,,,, -187400,4.6180663,0.62994856,,,,,,,,,,,,,, -187500,4.355523,0.6664023,,,,,,,,,,,,,, -187600,4.2264357,0.5867808,,,,,,,,,,,,,, -187675,,,0.9621930718421936,0.1422573626041412,0.7559999823570251,1.0514984130859375,50000.0,0.6314000487327576,1.819074869155884,10000.0,63296.81457424164,65651.33988261223,63296.81457424164,2340.3564281463623,6.781188726425171,0.0 -187700,4.323901,0.6542245,,,,,,,,,,,,,, -187800,4.1840973,0.61181283,,,,,,,,,,,,,, -187900,4.848255,0.6614631,,,,,,,,,,,,,, -188000,4.376596,0.6082891,,,,,,,,,,,,,, -188100,4.5070076,0.6646696,,,,,,,,,,,,,, -188200,5.348641,0.69295686,,,,,,,,,,,,,, -188300,4.0269337,0.59987515,,,,,,,,,,,,,, -188400,4.5917974,0.67515457,,,,,,,,,,,,,, -188500,4.753397,0.6608656,,,,,,,,,,,,,, -188600,4.834062,0.6855824,,,,,,,,,,,,,, -188700,4.5043745,0.62095594,,,,,,,,,,,,,, -188800,3.977112,0.58464444,,,,,,,,,,,,,, -188900,4.470325,0.6579125,,,,,,,,,,,,,, -189000,4.4597383,0.679613,,,,,,,,,,,,,, -189100,4.698012,0.6315497,,,,,,,,,,,,,, -189189,,,0.9604990482330322,0.1454890668392181,0.7562599778175354,1.0513103008270264,50000.0,0.6312000155448914,1.81675398349762,10000.0,63806.97748732567,66178.77985548973,63806.97748732567,2357.5116155147552,6.8434693813323975,0.0 -189200,4.1888704,0.61708534,,,,,,,,,,,,,, -189300,4.927593,0.6872197,,,,,,,,,,,,,, -189400,4.532572,0.71531844,,,,,,,,,,,,,, -189500,4.6184487,0.64858925,,,,,,,,,,,,,, -189600,4.50931,0.68835384,,,,,,,,,,,,,, -189700,3.9714622,0.541839,,,,,,,,,,,,,, -189800,4.691594,0.65720016,,,,,,,,,,,,,, -189900,4.2553835,0.59947777,,,,,,,,,,,,,, -190000,4.6105075,0.6991527,,,,,,,,,,,,,, -190100,4.3811684,0.61876416,,,,,,,,,,,,,, -190200,4.247535,0.5843258,,,,,,,,,,,,,, -190300,5.2682176,0.7150373,,,,,,,,,,,,,, -190400,4.6071687,0.64855963,,,,,,,,,,,,,, -190500,4.5114417,0.6177423,,,,,,,,,,,,,, -190600,4.39171,0.5681516,,,,,,,,,,,,,, -190700,4.144727,0.6205673,,,,,,,,,,,,,, -190703,,,0.959582269191742,0.1502355635166168,0.7557399868965149,1.0518090724945068,50000.0,0.6309000253677368,1.817474722862244,10000.0,64316.961555957794,66705.87880730629,64316.961555957794,2374.5045187473297,6.90625524520874,0.0 -190800,4.407321,0.6757575,,,,,,,,,,,,,, -190900,4.314751,0.61611927,,,,,,,,,,,,,, -191000,4.6378884,0.7055823,,,,,,,,,,,,,, -191100,4.518358,0.6268792,,,,,,,,,,,,,, -191200,4.3409867,0.58919126,,,,,,,,,,,,,, -191300,4.947157,0.67946297,,,,,,,,,,,,,, -191400,4.187372,0.5373529,,,,,,,,,,,,,, -191500,4.8286295,0.6572821,,,,,,,,,,,,,, -191600,4.3234015,0.6006297,,,,,,,,,,,,,, -191700,4.7762794,0.666203,,,,,,,,,,,,,, -191800,4.6109834,0.66184753,,,,,,,,,,,,,, -191900,4.3418465,0.61184,,,,,,,,,,,,,, -192000,4.1208735,0.66599554,,,,,,,,,,,,,, -192100,4.280243,0.661693,,,,,,,,,,,,,, -192200,4.4053955,0.64587885,,,,,,,,,,,,,, -192217,,,0.9605388641357422,0.147024780511856,0.7558599710464478,1.0514979362487793,50000.0,0.6314000487327576,1.8174049854278564,10000.0,64826.8895072937,67234.2537624836,64826.8895072937,2392.830799102783,6.966517210006714,0.0 -192300,4.602369,0.6593805,,,,,,,,,,,,,, -192400,4.3707924,0.6864532,,,,,,,,,,,,,, -192500,4.921909,0.7007407,,,,,,,,,,,,,, -192600,4.5946207,0.6098064,,,,,,,,,,,,,, -192700,4.593228,0.591251,,,,,,,,,,,,,, -192800,4.4508214,0.6006152,,,,,,,,,,,,,, -192900,4.549335,0.6389731,,,,,,,,,,,,,, -193000,4.5814815,0.6385119,,,,,,,,,,,,,, -193100,4.3449903,0.58094865,,,,,,,,,,,,,, -193200,4.6437073,0.6619239,,,,,,,,,,,,,, -193300,4.848735,0.6368054,,,,,,,,,,,,,, -193400,4.0974283,0.6321593,,,,,,,,,,,,,, -193500,4.589124,0.65067,,,,,,,,,,,,,, -193600,4.342237,0.6109286,,,,,,,,,,,,,, -193700,4.52645,0.6520222,,,,,,,,,,,,,, -193731,,,0.9588249325752258,0.1494628936052322,0.7556399703025818,1.0511360168457031,50000.0,0.6310000419616699,1.8178002834320068,10000.0,65337.07450246811,67761.82452487946,65337.07450246811,2410.100454807281,7.022372007369995,0.0 -193800,4.3045697,0.63423663,,,,,,,,,,,,,, -193900,4.5865006,0.6117646,,,,,,,,,,,,,, -194000,4.7329125,0.6353514,,,,,,,,,,,,,, -194100,4.511963,0.60656625,,,,,,,,,,,,,, -194200,4.5519423,0.63141847,,,,,,,,,,,,,, -194300,4.657322,0.57405114,,,,,,,,,,,,,, -194400,5.0072393,0.68513465,,,,,,,,,,,,,, -194500,4.5102134,0.6426147,,,,,,,,,,,,,, -194600,4.3313274,0.56213164,,,,,,,,,,,,,, -194700,4.7351046,0.6454389,,,,,,,,,,,,,, -194800,4.2580304,0.5848702,,,,,,,,,,,,,, -194900,4.762464,0.68776745,,,,,,,,,,,,,, -195000,4.1932197,0.56807286,,,,,,,,,,,,,, -195100,4.00611,0.566828,,,,,,,,,,,,,, -195200,4.478521,0.640939,,,,,,,,,,,,,, -195244,,,0.9619937539100648,0.1456216126680374,0.755840003490448,1.052322268486023,50000.0,0.631600022315979,1.81663978099823,10000.0,65846.99104523659,68289.11477446556,65846.99104523659,2427.3492562770844,7.086922645568848,0.0 -195300,4.2689905,0.6002744,,,,,,,,,,,,,, -195400,4.5243597,0.62763715,,,,,,,,,,,,,, -195500,4.7659135,0.6182594,,,,,,,,,,,,,, -195600,4.5505676,0.64458686,,,,,,,,,,,,,, -195700,4.423288,0.6220914,,,,,,,,,,,,,, -195800,5.1242266,0.67970353,,,,,,,,,,,,,, -195900,4.745989,0.65021926,,,,,,,,,,,,,, -196000,4.6038084,0.6000842,,,,,,,,,,,,,, -196100,4.727138,0.68118614,,,,,,,,,,,,,, -196200,4.499799,0.5732752,,,,,,,,,,,,,, -196300,4.394832,0.5747693,,,,,,,,,,,,,, -196400,4.4478574,0.6577487,,,,,,,,,,,,,, -196500,4.5763545,0.62532485,,,,,,,,,,,,,, -196600,4.530419,0.58535874,,,,,,,,,,,,,, -196700,4.3720255,0.5883762,,,,,,,,,,,,,, -196758,,,0.9612762928009032,0.1447423547506332,0.7559399604797363,1.052221417427063,50000.0,0.6321000456809998,1.8178486824035645,10000.0,66356.99051713943,68816.40245962143,66356.99051713943,2444.515809059143,7.146530151367188,0.0 -196800,4.5376124,0.6552692,,,,,,,,,,,,,, -196900,4.9865503,0.65984005,,,,,,,,,,,,,, -197000,4.5836024,0.6083884,,,,,,,,,,,,,, -197100,4.2601814,0.61908895,,,,,,,,,,,,,, -197200,4.477246,0.61278737,,,,,,,,,,,,,, -197300,4.488321,0.6045104,,,,,,,,,,,,,, -197400,4.638897,0.62543833,,,,,,,,,,,,,, -197500,4.2553625,0.55055916,,,,,,,,,,,,,, -197600,4.870101,0.62337136,,,,,,,,,,,,,, -197700,4.7019434,0.70270354,,,,,,,,,,,,,, -197800,4.5105476,0.62467355,,,,,,,,,,,,,, -197900,4.0587754,0.5393084,,,,,,,,,,,,,, -198000,4.499621,0.59332055,,,,,,,,,,,,,, -198100,4.5058465,0.6307616,,,,,,,,,,,,,, -198200,4.6211386,0.57026136,,,,,,,,,,,,,, -198271,,,0.9612962007522584,0.1464032083749771,0.7554999589920044,1.0522739887237549,50000.0,0.6309000253677368,1.8155436515808103,10000.0,66866.98143553734,69343.70125079155,66866.98143553734,2461.702263355255,7.207445859909058,0.0 -198300,4.250319,0.64464766,,,,,,,,,,,,,, -198400,4.3949447,0.6365795,,,,,,,,,,,,,, -198500,5.1008263,0.5982586,,,,,,,,,,,,,, -198600,4.733415,0.650913,,,,,,,,,,,,,, -198700,4.4601526,0.6370789,,,,,,,,,,,,,, -198800,4.471212,0.6280682,,,,,,,,,,,,,, -198900,4.2801595,0.61922854,,,,,,,,,,,,,, -199000,5.0844345,0.58338994,,,,,,,,,,,,,, -199100,4.5657535,0.63576835,,,,,,,,,,,,,, -199200,4.9142017,0.68303424,,,,,,,,,,,,,, -199300,4.3443174,0.6039412,,,,,,,,,,,,,, -199400,4.8134108,0.6513914,,,,,,,,,,,,,, -199500,5.096917,0.70395494,,,,,,,,,,,,,, -199600,4.603658,0.6121318,,,,,,,,,,,,,, -199700,4.49214,0.6557799,,,,,,,,,,,,,, -199785,,,0.9602997303009032,0.1467967927455902,0.7558199763298035,1.0520784854888916,50000.0,0.6308000087738037,1.8178911209106443,10000.0,67376.91571760178,69871.08100628853,67376.91571760178,2479.025264263153,7.268761396408081,0.0 -199800,4.2636137,0.61926407,,,,,,,,,,,,,, -199900,4.5468283,0.6283367,,,,,,,,,,,,,, -200000,4.293816,0.60931104,,,,,,,,,,,,,, -200100,4.5301647,0.6342331,,,,,,,,,,,,,, -200200,4.336999,0.58916813,,,,,,,,,,,,,, -200300,4.919884,0.7525588,,,,,,,,,,,,,, -200400,4.426195,0.603009,,,,,,,,,,,,,, -200500,4.777413,0.6021727,,,,,,,,,,,,,, -200600,4.4691734,0.6808677,,,,,,,,,,,,,, -200700,5.003759,0.6690955,,,,,,,,,,,,,, -200800,5.113779,0.65980935,,,,,,,,,,,,,, -200900,4.3606234,0.5560156,,,,,,,,,,,,,, -201000,4.674529,0.6244675,,,,,,,,,,,,,, -201100,4.9270387,0.6291953,,,,,,,,,,,,,, -201200,4.4755254,0.6452975,,,,,,,,,,,,,, -201299,,,0.9599011540412904,0.1499934792518615,0.7560399770736694,1.052374243736267,50000.0,0.6306000351905823,1.818956971168518,10000.0,67887.04694342613,70398.6608979702,67887.04694342613,2496.34637260437,7.337639093399048,0.0 -201300,4.5325203,0.61733097,,,,,,,,,,,,,, -201400,4.2980943,0.5867258,,,,,,,,,,,,,, -201500,4.3710475,0.48216975,,,,,,,,,,,,,, -201600,4.5024343,0.65000397,,,,,,,,,,,,,, -201700,4.6398444,0.62587947,,,,,,,,,,,,,, -201800,4.339041,0.61570376,,,,,,,,,,,,,, -201900,4.535589,0.619143,,,,,,,,,,,,,, -202000,4.7713637,0.6259201,,,,,,,,,,,,,, -202100,4.2605205,0.6360048,,,,,,,,,,,,,, -202200,4.7636285,0.61190045,,,,,,,,,,,,,, -202300,4.5344567,0.5924441,,,,,,,,,,,,,, -202400,4.488909,0.6376547,,,,,,,,,,,,,, -202500,4.4622154,0.68066084,,,,,,,,,,,,,, -202600,4.2395954,0.60308707,,,,,,,,,,,,,, -202700,4.4385643,0.67389554,,,,,,,,,,,,,, -202800,4.372342,0.57690424,,,,,,,,,,,,,, -202813,,,0.961734652519226,0.1433680802583694,0.755840003490448,1.0518240928649902,50000.0,0.6317000389099121,1.817797303199768,10000.0,68397.03761839867,70926.17774367332,68397.03761839867,2513.746617078781,7.402098655700684,0.0 -202900,4.2729015,0.6196813,,,,,,,,,,,,,, -203000,4.9386787,0.62507653,,,,,,,,,,,,,, -203100,4.649635,0.60478455,,,,,,,,,,,,,, -203200,4.4755206,0.5740283,,,,,,,,,,,,,, -203300,4.675663,0.65623426,,,,,,,,,,,,,, -203400,4.232133,0.611269,,,,,,,,,,,,,, -203500,4.640519,0.62677395,,,,,,,,,,,,,, -203600,4.643239,0.7159072,,,,,,,,,,,,,, -203700,4.7296715,0.6210504,,,,,,,,,,,,,, -203800,4.302073,0.69421726,,,,,,,,,,,,,, -203900,4.716772,0.66342866,,,,,,,,,,,,,, -204000,4.3920765,0.5915301,,,,,,,,,,,,,, -204100,4.5278726,0.58828104,,,,,,,,,,,,,, -204200,4.254539,0.620329,,,,,,,,,,,,,, -204300,4.218272,0.5937482,,,,,,,,,,,,,, -204327,,,0.9606783986091614,0.1463976204395294,0.756060004234314,1.0515127182006836,50000.0,0.6305000185966492,1.817637801170349,10000.0,68907.05328130722,71453.3792629242,68907.05328130722,2530.812096595764,7.462094306945801,0.0 -204400,4.420406,0.60273015,,,,,,,,,,,,,, -204500,4.173833,0.55579215,,,,,,,,,,,,,, -204600,4.6668735,0.73797035,,,,,,,,,,,,,, -204700,4.813253,0.5938429,,,,,,,,,,,,,, -204800,4.161334,0.56741303,,,,,,,,,,,,,, -204900,4.3658733,0.5734281,,,,,,,,,,,,,, -205000,4.425082,0.66791344,,,,,,,,,,,,,, -205100,5.188934,0.66434383,,,,,,,,,,,,,, -205200,4.7077346,0.6351893,,,,,,,,,,,,,, -205300,4.61661,0.5799888,,,,,,,,,,,,,, -205400,4.404164,0.54536533,,,,,,,,,,,,,, -205500,4.4020753,0.54189944,,,,,,,,,,,,,, -205600,4.3550205,0.61218923,,,,,,,,,,,,,, -205700,5.079147,0.6813421,,,,,,,,,,,,,, -205800,4.429893,0.63062245,,,,,,,,,,,,,, -205842,,,0.9607381820678712,0.1457848995923996,0.7559799551963806,1.0517550706863403,50000.0,0.6310000419616699,1.8176188468933103,10000.0,69417.21089410782,71981.13394188881,69417.21089410782,2548.2866756916046,7.5248517990112305,0.0 -205900,4.5171733,0.59211516,,,,,,,,,,,,,, -206000,4.2595162,0.5530776,,,,,,,,,,,,,, -206100,4.3881054,0.6170776,,,,,,,,,,,,,, -206200,4.648655,0.66083103,,,,,,,,,,,,,, -206300,4.766916,0.61964643,,,,,,,,,,,,,, -206400,4.7623215,0.66037077,,,,,,,,,,,,,, -206500,4.2131267,0.5898006,,,,,,,,,,,,,, -206600,4.1268153,0.6230212,,,,,,,,,,,,,, -206700,3.9800782,0.55731803,,,,,,,,,,,,,, -206800,4.564338,0.60568535,,,,,,,,,,,,,, -206900,4.9058194,0.66314334,,,,,,,,,,,,,, -207000,4.4242935,0.6046594,,,,,,,,,,,,,, -207100,4.330722,0.5812298,,,,,,,,,,,,,, -207200,4.529637,0.61308175,,,,,,,,,,,,,, -207300,4.356989,0.65473217,,,,,,,,,,,,,, -207356,,,0.9599609375,0.1481670588254928,0.7555399537086487,1.0524675846099854,50000.0,0.631600022315979,1.818718791007996,10000.0,69927.16663312912,72508.60455965996,69927.16663312912,2565.68039393425,7.586683750152588,0.0 -207400,4.6405163,0.6315711,,,,,,,,,,,,,, -207500,4.7152433,0.7028556,,,,,,,,,,,,,, -207600,4.5323496,0.63615257,,,,,,,,,,,,,, -207700,4.4791007,0.6106385,,,,,,,,,,,,,, -207800,4.8943954,0.705624,,,,,,,,,,,,,, -207900,4.377776,0.64007866,,,,,,,,,,,,,, -208000,4.823435,0.7131837,,,,,,,,,,,,,, -208100,4.3942432,0.6268133,,,,,,,,,,,,,, -208200,4.2532563,0.6142528,,,,,,,,,,,,,, -208300,4.6009817,0.5807005,,,,,,,,,,,,,, -208400,4.240388,0.558445,,,,,,,,,,,,,, -208500,4.528721,0.63873225,,,,,,,,,,,,,, -208600,4.4931335,0.6423952,,,,,,,,,,,,,, -208700,4.129532,0.59912413,,,,,,,,,,,,,, -208800,4.5722938,0.6474679,,,,,,,,,,,,,, -208870,,,0.9607780575752258,0.1466470956802368,0.7554000020027161,1.0518457889556885,50000.0,0.6312000155448914,1.816641926765442,10000.0,70437.09859132767,73035.96860289574,70437.09859132767,2582.988062143326,7.651637077331543,0.0 -208900,4.209197,0.554408,,,,,,,,,,,,,, -209000,4.5747666,0.66650903,,,,,,,,,,,,,, -209100,4.3884206,0.5991552,,,,,,,,,,,,,, -209200,4.656006,0.7002963,,,,,,,,,,,,,, -209300,4.232173,0.5888763,,,,,,,,,,,,,, -209400,4.409561,0.62322474,,,,,,,,,,,,,, -209500,4.705662,0.6192753,,,,,,,,,,,,,, -209600,4.298347,0.6138879,,,,,,,,,,,,,, -209700,4.6174765,0.62131494,,,,,,,,,,,,,, -209800,4.5428395,0.6557811,,,,,,,,,,,,,, -209900,3.8388367,0.5440005,,,,,,,,,,,,,, -210000,5.2397475,0.67354184,,,,,,,,,,,,,, -210100,3.9999382,0.5626698,,,,,,,,,,,,,, -210200,4.097153,0.5627866,,,,,,,,,,,,,, -210300,4.2750645,0.6149477,,,,,,,,,,,,,, -210384,,,0.9600605964660645,0.1470133066177368,0.7558599710464478,1.0524232387542725,50000.0,0.6304000020027161,1.817249059677124,10000.0,70947.238874197,73563.5444612503,70947.238874197,2600.2966318130493,7.717857122421265,0.0 -210400,4.6461234,0.62876284,,,,,,,,,,,,,, -210500,4.5682106,0.5770451,,,,,,,,,,,,,, -210600,4.72777,0.6405009,,,,,,,,,,,,,, -210700,4.4883733,0.6680352,,,,,,,,,,,,,, -210800,4.1954336,0.59334916,,,,,,,,,,,,,, -210900,4.332213,0.5926064,,,,,,,,,,,,,, -211000,4.6528873,0.64641726,,,,,,,,,,,,,, -211100,4.136421,0.59925485,,,,,,,,,,,,,, -211200,4.1577992,0.57560146,,,,,,,,,,,,,, -211300,4.4732757,0.6782193,,,,,,,,,,,,,, -211400,4.6707177,0.61889637,,,,,,,,,,,,,, -211500,4.296369,0.61636376,,,,,,,,,,,,,, -211600,4.4766297,0.6033015,,,,,,,,,,,,,, -211700,4.352643,0.6566079,,,,,,,,,,,,,, -211800,4.484067,0.5617228,,,,,,,,,,,,,, -211897,,,0.9603196382522584,0.1500230282545089,0.7558599710464478,1.052908420562744,50000.0,0.6318000555038452,1.8178008794784544,10000.0,71457.24890565872,74090.75350880623,71457.24890565872,2617.373607635498,7.779970407485962,0.0 -211900,4.1581445,0.52764386,,,,,,,,,,,,,, -212000,4.525587,0.6415602,,,,,,,,,,,,,, -212100,4.2465096,0.59093934,,,,,,,,,,,,,, -212200,4.8245497,0.6336615,,,,,,,,,,,,,, -212300,4.431711,0.6409742,,,,,,,,,,,,,, -212400,4.5358024,0.57823616,,,,,,,,,,,,,, -212500,4.7161674,0.64797944,,,,,,,,,,,,,, -212600,4.5736237,0.649299,,,,,,,,,,,,,, -212700,4.802623,0.74590635,,,,,,,,,,,,,, -212800,4.2408533,0.604658,,,,,,,,,,,,,, -212900,4.510796,0.55446935,,,,,,,,,,,,,, -213000,4.4036946,0.5679911,,,,,,,,,,,,,, -213100,5.0660844,0.6370227,,,,,,,,,,,,,, -213200,4.3612413,0.64066935,,,,,,,,,,,,,, -213300,4.322378,0.54933643,,,,,,,,,,,,,, -213400,4.627474,0.6435018,,,,,,,,,,,,,, -213411,,,0.9616748690605164,0.1448988020420074,0.7557399868965149,1.0528243780136108,50000.0,0.6317000389099121,1.818710446357727,10000.0,71967.18639993668,74617.97519087791,71967.18639993668,2634.5299496650696,7.84785270690918,0.0 -213500,4.5297356,0.5986686,,,,,,,,,,,,,, -213600,4.691134,0.5775406,,,,,,,,,,,,,, -213700,4.6111274,0.66876173,,,,,,,,,,,,,, -213800,4.536544,0.65868104,,,,,,,,,,,,,, -213900,4.3723783,0.6331667,,,,,,,,,,,,,, -214000,4.3989615,0.54966986,,,,,,,,,,,,,, -214100,4.665392,0.55835694,,,,,,,,,,,,,, -214200,4.456552,0.7010035,,,,,,,,,,,,,, -214300,4.500699,0.58425236,,,,,,,,,,,,,, -214400,4.81014,0.60468316,,,,,,,,,,,,,, -214500,4.5387487,0.62536955,,,,,,,,,,,,,, -214600,4.5375986,0.58204824,,,,,,,,,,,,,, -214700,4.4405456,0.64835805,,,,,,,,,,,,,, -214800,4.378691,0.6962668,,,,,,,,,,,,,, -214900,4.8509693,0.65951496,,,,,,,,,,,,,, -214925,,,0.9598811864852904,0.1491194367408752,0.7555800080299377,1.0532875061035156,50000.0,0.6321000456809998,1.8189986944198608,10000.0,72477.23326277733,75145.44805765152,72477.23326277733,2651.8291053771973,7.914536952972412,0.0 -215000,4.096466,0.5721863,,,,,,,,,,,,,, -215100,4.994827,0.6138191,,,,,,,,,,,,,, -215200,4.5777435,0.6784333,,,,,,,,,,,,,, -215300,4.4311223,0.5950625,,,,,,,,,,,,,, -215400,4.329328,0.6696845,,,,,,,,,,,,,, -215500,4.4404707,0.6510733,,,,,,,,,,,,,, -215600,4.429065,0.56659156,,,,,,,,,,,,,, -215700,4.5184097,0.68908143,,,,,,,,,,,,,, -215800,4.468437,0.6005968,,,,,,,,,,,,,, -215900,5.06646,0.5805352,,,,,,,,,,,,,, -216000,4.3371515,0.58345777,,,,,,,,,,,,,, -216100,4.292995,0.6464627,,,,,,,,,,,,,, -216200,4.453646,0.6375546,,,,,,,,,,,,,, -216300,4.1980443,0.68613064,,,,,,,,,,,,,, -216400,4.6944537,0.6306625,,,,,,,,,,,,,, -216439,,,0.9616549611091614,0.1474548429250717,0.7560200095176697,1.0531450510025024,50000.0,0.6322000026702881,1.8198258876800537,10000.0,72987.39051795006,75672.81736660004,72987.39051795006,2668.91423535347,7.982542037963867,0.0 -216500,4.3665934,0.56554645,,,,,,,,,,,,,, -216600,4.4275575,0.62988484,,,,,,,,,,,,,, -216700,4.6718326,0.64600843,,,,,,,,,,,,,, -216800,4.4917374,0.65830886,,,,,,,,,,,,,, -216900,4.426863,0.5740624,,,,,,,,,,,,,, -217000,4.855569,0.6249179,,,,,,,,,,,,,, -217100,4.8768463,0.66497076,,,,,,,,,,,,,, -217200,4.578203,0.6381679,,,,,,,,,,,,,, -217300,4.6375833,0.5707834,,,,,,,,,,,,,, -217400,4.5054417,0.56490356,,,,,,,,,,,,,, -217500,4.1280684,0.588476,,,,,,,,,,,,,, -217600,4.6210485,0.6850149,,,,,,,,,,,,,, -217700,4.4042864,0.58456445,,,,,,,,,,,,,, -217800,4.1814384,0.6077337,,,,,,,,,,,,,, -217900,4.371237,0.60487163,,,,,,,,,,,,,, -217953,,,0.96000075340271,0.1477039009332656,0.7557799816131592,1.052709460258484,50000.0,0.6304000020027161,1.8178809881210327,10000.0,73497.54736804962,76200.37576818466,73497.54736804962,2686.1898329257965,8.048727750778198,0.0 -218000,4.182722,0.6296881,,,,,,,,,,,,,, -218100,4.3788342,0.655061,,,,,,,,,,,,,, -218200,4.978625,0.5951763,,,,,,,,,,,,,, -218300,4.762043,0.67056346,,,,,,,,,,,,,, -218400,4.099931,0.58820784,,,,,,,,,,,,,, -218500,4.1719284,0.578014,,,,,,,,,,,,,, -218600,4.619797,0.72411656,,,,,,,,,,,,,, -218700,4.7215853,0.6480624,,,,,,,,,,,,,, -218800,4.235709,0.64042366,,,,,,,,,,,,,, -218900,4.336837,0.6031594,,,,,,,,,,,,,, -219000,4.40335,0.68520606,,,,,,,,,,,,,, -219100,4.2204,0.62400556,,,,,,,,,,,,,, -219200,4.618384,0.58320993,,,,,,,,,,,,,, -219300,4.3523493,0.6176754,,,,,,,,,,,,,, -219400,4.6564727,0.6116068,,,,,,,,,,,,,, -219467,,,0.9601203799247742,0.1478992104530334,0.7565199732780457,1.0514085292816162,50000.0,0.6317000389099121,1.8171485662460327,10000.0,74007.73700547218,76727.81662344933,74007.73700547218,2703.3161799907684,8.112717628479004,0.0 -219500,4.3248744,0.5713479,,,,,,,,,,,,,, -219600,5.1221466,0.6375085,,,,,,,,,,,,,, -219700,5.0043063,0.6997049,,,,,,,,,,,,,, -219800,4.4824066,0.70535636,,,,,,,,,,,,,, -219900,4.3977804,0.57871574,,,,,,,,,,,,,, -220000,4.2715526,0.6093314,,,,,,,,,,,,,, -220100,5.0550513,0.6474732,,,,,,,,,,,,,, -220200,4.8218374,0.657566,,,,,,,,,,,,,, -220300,3.8155591,0.56581295,,,,,,,,,,,,,, -220400,4.226726,0.60332334,,,,,,,,,,,,,, -220500,4.4568553,0.63701266,,,,,,,,,,,,,, -220600,4.6894703,0.6303346,,,,,,,,,,,,,, -220700,5.168103,0.657687,,,,,,,,,,,,,, -220800,4.170993,0.6442701,,,,,,,,,,,,,, -220900,4.558489,0.6147606,,,,,,,,,,,,,, -220981,,,0.960339605808258,0.1483343243598938,0.7559799551963806,1.0523303747177124,50000.0,0.6320000290870667,1.8170801401138303,10000.0,74517.93527269363,77255.46332716942,74517.93527269363,2720.639880657196,8.177234411239624,0.0 -221000,4.412785,0.5449824,,,,,,,,,,,,,, -221100,4.6359262,0.6175783,,,,,,,,,,,,,, -221200,4.7740283,0.63249296,,,,,,,,,,,,,, -221300,4.4104767,0.5490294,,,,,,,,,,,,,, -221400,4.5401335,0.5703547,,,,,,,,,,,,,, -221500,4.797807,0.59484255,,,,,,,,,,,,,, -221600,4.4182234,0.6414049,,,,,,,,,,,,,, -221700,4.5303974,0.7083403,,,,,,,,,,,,,, -221800,4.7800064,0.61704004,,,,,,,,,,,,,, -221900,4.2851887,0.6188413,,,,,,,,,,,,,, -222000,4.3467007,0.60353476,,,,,,,,,,,,,, -222100,4.258903,0.55905694,,,,,,,,,,,,,, -222200,4.374768,0.64874923,,,,,,,,,,,,,, -222300,4.486818,0.6837946,,,,,,,,,,,,,, -222400,4.8659215,0.6463154,,,,,,,,,,,,,, -222495,,,0.960758090019226,0.1472418159246444,0.7555599808692932,1.0518176555633545,50000.0,0.6313000321388245,1.8161752223968504,10000.0,75028.07389211655,77783.17099523544,75028.07389211655,2738.0824706554413,8.243817806243896,0.0 -222500,4.6466885,0.7024696,,,,,,,,,,,,,, -222600,4.5917253,0.56432796,,,,,,,,,,,,,, -222700,4.7261143,0.628899,,,,,,,,,,,,,, -222800,4.039986,0.59234047,,,,,,,,,,,,,, -222900,4.380472,0.60194296,,,,,,,,,,,,,, -223000,4.141355,0.6077609,,,,,,,,,,,,,, -223100,4.446396,0.6323623,,,,,,,,,,,,,, -223200,4.244762,0.64426255,,,,,,,,,,,,,, -223300,4.705624,0.6455796,,,,,,,,,,,,,, -223400,4.476873,0.58047634,,,,,,,,,,,,,, -223500,4.3626394,0.6300573,,,,,,,,,,,,,, -223600,4.55966,0.6324423,,,,,,,,,,,,,, -223700,5.0307145,0.6303135,,,,,,,,,,,,,, -223800,4.237191,0.5123777,,,,,,,,,,,,,, -223900,4.3190527,0.60317284,,,,,,,,,,,,,, -224000,4.4318676,0.6502392,,,,,,,,,,,,,, -224009,,,0.9613958597183228,0.1484321802854538,0.7560399770736694,1.0512542724609375,50000.0,0.6301000118255615,1.8154196739196773,10000.0,75538.01856184006,78310.43208026886,75538.01856184006,2755.275614976883,8.306404113769531,0.0 -224100,4.496915,0.5939848,,,,,,,,,,,,,, -224200,4.5113716,0.63540256,,,,,,,,,,,,,, -224300,4.6701884,0.6298571,,,,,,,,,,,,,, -224400,5.372514,0.70988923,,,,,,,,,,,,,, -224500,4.6214805,0.5891572,,,,,,,,,,,,,, -224600,4.5143785,0.74527067,,,,,,,,,,,,,, -224700,4.491213,0.66051763,,,,,,,,,,,,,, -224800,4.476956,0.6033429,,,,,,,,,,,,,, -224900,4.3199325,0.5916711,,,,,,,,,,,,,, -225000,4.2676086,0.6040919,,,,,,,,,,,,,, -225100,4.7798433,0.59879774,,,,,,,,,,,,,, -225200,4.628442,0.6202582,,,,,,,,,,,,,, -225300,4.677204,0.68304944,,,,,,,,,,,,,, -225400,4.399038,0.7006649,,,,,,,,,,,,,, -225500,4.9436927,0.6705645,,,,,,,,,,,,,, -225522,,,0.961336076259613,0.1458679139614105,0.7559599876403809,1.0527108907699585,50000.0,0.6317000389099121,1.818994641304016,10000.0,76047.95550060272,78837.83005452156,76047.95550060272,2772.6055793762207,8.376575708389282,0.0 -225600,4.548553,0.603305,,,,,,,,,,,,,, -225700,4.4080925,0.57991457,,,,,,,,,,,,,, -225800,4.5487437,0.6840519,,,,,,,,,,,,,, -225900,4.65257,0.6699264,,,,,,,,,,,,,, -226000,4.387965,0.6486884,,,,,,,,,,,,,, -226100,5.059096,0.64358723,,,,,,,,,,,,,, -226200,4.4938793,0.72707546,,,,,,,,,,,,,, -226300,4.680147,0.57328683,,,,,,,,,,,,,, -226400,4.4214315,0.5982083,,,,,,,,,,,,,, -226500,4.3424635,0.5753594,,,,,,,,,,,,,, -226600,4.745942,0.6558969,,,,,,,,,,,,,, -226700,4.8834596,0.68148607,,,,,,,,,,,,,, -226800,4.4900403,0.6634315,,,,,,,,,,,,,, -226900,4.577594,0.5731369,,,,,,,,,,,,,, -227000,4.689996,0.6422144,,,,,,,,,,,,,, -227037,,,0.961933970451355,0.1421834528446197,0.7557599544525146,1.0518858432769775,50000.0,0.631100058555603,1.8159196376800537,10000.0,76557.94317746162,79364.89956021309,76557.94317746162,2789.5629415512085,8.442211389541626,0.0 -227100,4.2884665,0.55374944,,,,,,,,,,,,,, -227200,4.448682,0.596398,,,,,,,,,,,,,, -227300,4.3972526,0.6596327,,,,,,,,,,,,,, -227400,4.6886563,0.65898186,,,,,,,,,,,,,, -227500,4.8925147,0.7290793,,,,,,,,,,,,,, -227600,4.6755347,0.6742955,,,,,,,,,,,,,, -227700,4.1860385,0.6173758,,,,,,,,,,,,,, -227800,4.4641504,0.6375763,,,,,,,,,,,,,, -227900,4.3449535,0.6116266,,,,,,,,,,,,,, -228000,4.9520144,0.60622406,,,,,,,,,,,,,, -228100,4.306178,0.61696243,,,,,,,,,,,,,, -228200,4.376973,0.701874,,,,,,,,,,,,,, -228300,4.342329,0.67181474,,,,,,,,,,,,,, -228400,4.597863,0.6985711,,,,,,,,,,,,,, -228500,4.768244,0.6398019,,,,,,,,,,,,,, -228550,,,0.9608577489852904,0.1471737623214721,0.7555599808692932,1.052388072013855,50000.0,0.6301000118255615,1.8176794052124023,10000.0,77068.1039595604,79892.73308634758,77068.1039595604,2807.102013349533,8.51409387588501,0.0 -228600,3.9708529,0.525518,,,,,,,,,,,,,, -228700,4.3857265,0.6602887,,,,,,,,,,,,,, -228800,4.4542675,0.6269078,,,,,,,,,,,,,, -228900,4.699432,0.66226155,,,,,,,,,,,,,, -229000,4.0627036,0.6096084,,,,,,,,,,,,,, -229100,3.8377995,0.5888537,,,,,,,,,,,,,, -229200,4.671726,0.7223176,,,,,,,,,,,,,, -229300,4.5802255,0.67390835,,,,,,,,,,,,,, -229400,4.624947,0.6638346,,,,,,,,,,,,,, -229500,4.3492775,0.6275055,,,,,,,,,,,,,, -229600,4.2270465,0.63916874,,,,,,,,,,,,,, -229700,4.0986657,0.58312047,,,,,,,,,,,,,, -229800,4.609236,0.6644203,,,,,,,,,,,,,, -229900,4.397245,0.63964605,,,,,,,,,,,,,, -230000,5.0533175,0.6420319,,,,,,,,,,,,,, -230064,,,0.9589245915412904,0.1493618786334991,0.7556999921798706,1.051660180091858,50000.0,0.6319000124931335,1.8184075355529783,10000.0,77578.0581843853,80420.75643348694,77578.0581843853,2825.0334997177124,8.582934141159058,0.0 -230100,5.0019264,0.65302235,,,,,,,,,,,,,, -230200,4.5506926,0.6336163,,,,,,,,,,,,,, -230300,4.2373886,0.62001646,,,,,,,,,,,,,, -230400,4.4129715,0.7040071,,,,,,,,,,,,,, -230500,4.4902196,0.62162596,,,,,,,,,,,,,, -230600,4.435966,0.72639096,,,,,,,,,,,,,, -230700,4.1078253,0.59501517,,,,,,,,,,,,,, -230800,4.0777454,0.5740602,,,,,,,,,,,,,, -230900,4.782672,0.5803318,,,,,,,,,,,,,, -231000,4.710946,0.6040043,,,,,,,,,,,,,, -231100,4.673714,0.67035,,,,,,,,,,,,,, -231200,4.755417,0.6550001,,,,,,,,,,,,,, -231300,4.568132,0.6483276,,,,,,,,,,,,,, -231400,4.2736115,0.6257989,,,,,,,,,,,,,, -231500,4.376998,0.63997906,,,,,,,,,,,,,, -231577,,,0.960718274116516,0.1464399844408035,0.755840003490448,1.0522643327713013,50000.0,0.6313000321388245,1.8174564838409424,10000.0,78088.10909414291,80948.29805445671,78088.10909414291,2842.399694442749,8.647887229919434,0.0 -231600,4.5237365,0.6260458,,,,,,,,,,,,,, -231700,4.2845635,0.5903933,,,,,,,,,,,,,, -231800,4.2740903,0.5894365,,,,,,,,,,,,,, -231900,4.2319036,0.55277544,,,,,,,,,,,,,, -232000,4.3899565,0.6937983,,,,,,,,,,,,,, -232100,4.392467,0.6073521,,,,,,,,,,,,,, -232200,4.3996944,0.5948369,,,,,,,,,,,,,, -232300,4.631749,0.62112665,,,,,,,,,,,,,, -232400,4.557539,0.5743025,,,,,,,,,,,,,, -232500,4.874803,0.5963057,,,,,,,,,,,,,, -232600,4.404402,0.66116196,,,,,,,,,,,,,, -232700,4.1201406,0.58971125,,,,,,,,,,,,,, -232800,4.8044624,0.65354824,,,,,,,,,,,,,, -232900,4.677943,0.6185189,,,,,,,,,,,,,, -233000,4.420475,0.6512099,,,,,,,,,,,,,, -233091,,,0.9599409699440002,0.1503974199295044,0.7559799551963806,1.051724553108215,50000.0,0.6313000321388245,1.818180799484253,10000.0,78598.0595126152,81475.70970320702,78598.0595126152,2859.735461950302,8.714589595794678,0.0 -233100,4.3885484,0.6650468,,,,,,,,,,,,,, -233200,4.453508,0.6168506,,,,,,,,,,,,,, -233300,4.367233,0.6402683,,,,,,,,,,,,,, -233400,4.8637424,0.6414891,,,,,,,,,,,,,, -233500,4.413956,0.5960999,,,,,,,,,,,,,, -233600,4.209099,0.61695004,,,,,,,,,,,,,, -233700,4.3496823,0.68079495,,,,,,,,,,,,,, -233800,4.1091533,0.51195186,,,,,,,,,,,,,, -233900,4.7417774,0.71770996,,,,,,,,,,,,,, -234000,4.903454,0.70041454,,,,,,,,,,,,,, -234100,4.5261526,0.5863002,,,,,,,,,,,,,, -234200,4.31206,0.5930064,,,,,,,,,,,,,, -234300,4.36843,0.6574366,,,,,,,,,,,,,, -234400,4.835496,0.6189483,,,,,,,,,,,,,, -234500,4.6258063,0.6695012,,,,,,,,,,,,,, -234600,5.417562,0.61210316,,,,,,,,,,,,,, -234604,,,0.9614955186843872,0.1457913517951965,0.7561999559402466,1.0515480041503906,50000.0,0.6304000020027161,1.8177772760391235,10000.0,79107.94815778732,82002.83469963074,79107.94815778732,2876.846007347107,8.78090786933899,0.0 -234700,3.939351,0.5642674,,,,,,,,,,,,,, -234800,4.209899,0.6443813,,,,,,,,,,,,,, -234900,4.1868663,0.54765457,,,,,,,,,,,,,, -235000,4.5114293,0.60261375,,,,,,,,,,,,,, -235100,4.4283442,0.65483755,,,,,,,,,,,,,, -235200,4.341696,0.5731822,,,,,,,,,,,,,, -235300,4.333563,0.5801132,,,,,,,,,,,,,, -235400,4.3630476,0.5856381,,,,,,,,,,,,,, -235500,4.854823,0.70666605,,,,,,,,,,,,,, -235600,4.35567,0.5584577,,,,,,,,,,,,,, -235700,4.480193,0.65031725,,,,,,,,,,,,,, -235800,4.135019,0.6008844,,,,,,,,,,,,,, -235900,4.889225,0.5768914,,,,,,,,,,,,,, -236000,4.409351,0.59421223,,,,,,,,,,,,,, -236100,4.844251,0.6363436,,,,,,,,,,,,,, -236118,,,0.9614357352256776,0.1441568434238433,0.7552199959754944,1.052870512008667,50000.0,0.6310000419616699,1.817879557609558,10000.0,79617.99672365189,82530.32462906837,79617.99672365189,2894.1628913879395,8.845736265182495,0.0 -236200,4.776307,0.5989125,,,,,,,,,,,,,, -236300,4.257646,0.5114255,,,,,,,,,,,,,, -236400,4.2984166,0.5719483,,,,,,,,,,,,,, -236500,4.4173093,0.6293351,,,,,,,,,,,,,, -236600,4.6156163,0.6345474,,,,,,,,,,,,,, -236700,4.700853,0.71231383,,,,,,,,,,,,,, -236800,4.667947,0.6676935,,,,,,,,,,,,,, -236900,4.4254966,0.647141,,,,,,,,,,,,,, -237000,4.938594,0.64006233,,,,,,,,,,,,,, -237100,4.405824,0.6295366,,,,,,,,,,,,,, -237200,4.3552465,0.60912496,,,,,,,,,,,,,, -237300,4.4658875,0.66441274,,,,,,,,,,,,,, -237400,4.3619,0.6143358,,,,,,,,,,,,,, -237500,4.547326,0.5796033,,,,,,,,,,,,,, -237600,3.9991858,0.5307859,,,,,,,,,,,,,, -237632,,,0.9603993892669678,0.1467776894569397,0.7555599808692932,1.0517122745513916,50000.0,0.631100058555603,1.817185401916504,10000.0,80127.95319676399,83057.54863166809,80127.95319676399,2911.3059771060944,8.911030530929565,0.0 -237700,5.264571,0.732343,,,,,,,,,,,,,, -237800,4.680475,0.64708114,,,,,,,,,,,,,, -237900,4.8273535,0.57470083,,,,,,,,,,,,,, -238000,4.1575475,0.58108115,,,,,,,,,,,,,, -238100,4.4464264,0.6385874,,,,,,,,,,,,,, -238200,4.4134808,0.617811,,,,,,,,,,,,,, -238300,4.5606275,0.6196193,,,,,,,,,,,,,, -238400,4.4850693,0.604991,,,,,,,,,,,,,, -238500,4.4147053,0.5608241,,,,,,,,,,,,,, -238600,4.2974052,0.55979896,,,,,,,,,,,,,, -238700,4.1890287,0.56483984,,,,,,,,,,,,,, -238800,4.582234,0.6159494,,,,,,,,,,,,,, -238900,4.75951,0.69968104,,,,,,,,,,,,,, -239000,5.0200677,0.68209785,,,,,,,,,,,,,, -239100,4.477352,0.6918118,,,,,,,,,,,,,, -239146,,,0.9608178734779358,0.1466523557901382,0.7553799748420715,1.0525615215301514,50000.0,0.6321000456809998,1.817415714263916,10000.0,80637.87059664726,83584.8637034893,80637.87059664726,2928.5791189670563,8.975881099700928,0.0 -239200,4.201587,0.6016035,,,,,,,,,,,,,, -239300,4.1528816,0.5613947,,,,,,,,,,,,,, -239400,4.642089,0.60163325,,,,,,,,,,,,,, -239500,4.453562,0.6784099,,,,,,,,,,,,,, -239600,4.6750193,0.63484055,,,,,,,,,,,,,, -239700,4.8693204,0.5890931,,,,,,,,,,,,,, -239800,4.5653744,0.6316491,,,,,,,,,,,,,, -239900,4.729978,0.6440875,,,,,,,,,,,,,, -240000,4.862992,0.7309645,,,,,,,,,,,,,, -240100,4.847428,0.6705847,,,,,,,,,,,,,, -240200,4.0607204,0.55342776,,,,,,,,,,,,,, -240300,4.1692615,0.5727165,,,,,,,,,,,,,, -240400,4.594462,0.6772483,,,,,,,,,,,,,, -240500,4.1025805,0.5916251,,,,,,,,,,,,,, -240600,4.421496,0.6778546,,,,,,,,,,,,,, -240660,,,0.960558831691742,0.1472288072109222,0.7558199763298035,1.052385687828064,50000.0,0.6310000419616699,1.8181267976760864,10000.0,81148.00268936157,84112.44211959839,81148.00268936157,2945.8968670368195,9.044187307357788,0.0 -240700,4.4734983,0.5982466,,,,,,,,,,,,,, -240800,4.245215,0.5914579,,,,,,,,,,,,,, -240900,4.2787786,0.6143332,,,,,,,,,,,,,, -241000,4.6335144,0.6277842,,,,,,,,,,,,,, -241100,4.455351,0.5992203,,,,,,,,,,,,,, -241200,4.2249856,0.5936459,,,,,,,,,,,,,, -241300,4.9235954,0.6105508,,,,,,,,,,,,,, -241400,3.878372,0.6031654,,,,,,,,,,,,,, -241500,4.644275,0.5544462,,,,,,,,,,,,,, -241600,4.3678665,0.64071715,,,,,,,,,,,,,, -241700,4.392087,0.6250642,,,,,,,,,,,,,, -241800,4.560273,0.59335595,,,,,,,,,,,,,, -241900,4.5076447,0.6437473,,,,,,,,,,,,,, -242000,4.426788,0.6526946,,,,,,,,,,,,,, -242100,4.170344,0.5546149,,,,,,,,,,,,,, -242174,,,0.961933970451355,0.144880786538124,0.7554799914360046,1.0521867275238037,50000.0,0.6317000389099121,1.8185575008392327,10000.0,81657.94163036346,84639.84378743172,81657.94163036346,2963.219983100891,9.121662378311155,0.0 -242200,4.4538755,0.6791041,,,,,,,,,,,,,, -242300,4.6322665,0.6360796,,,,,,,,,,,,,, -242400,4.1442976,0.577089,,,,,,,,,,,,,, -242500,4.4479904,0.5710578,,,,,,,,,,,,,, -242600,4.228148,0.59605217,,,,,,,,,,,,,, -242700,4.979773,0.6714929,,,,,,,,,,,,,, -242800,4.592065,0.63698494,,,,,,,,,,,,,, -242900,4.6398063,0.6559068,,,,,,,,,,,,,, -243000,4.4384704,0.64570993,,,,,,,,,,,,,, -243100,4.8114038,0.6494595,,,,,,,,,,,,,, -243200,3.9193263,0.5815185,,,,,,,,,,,,,, -243300,4.501537,0.66173536,,,,,,,,,,,,,, -243400,4.091409,0.59421265,,,,,,,,,,,,,, -243500,4.4260736,0.65957355,,,,,,,,,,,,,, -243600,4.314165,0.61193526,,,,,,,,,,,,,, -243687,,,0.960160195827484,0.1469377130270004,0.7560999989509583,1.0522353649139404,50000.0,0.6317000389099121,1.818875312805176,10000.0,82167.82444095612,85167.19301080704,82167.82444095612,2980.5571575164795,9.191231966018677,0.0 -243700,4.170955,0.5656494,,,,,,,,,,,,,, -243800,4.1909013,0.58362675,,,,,,,,,,,,,, -243900,4.9682393,0.61308384,,,,,,,,,,,,,, -244000,4.862035,0.61014086,,,,,,,,,,,,,, -244100,4.3194513,0.5756812,,,,,,,,,,,,,, -244200,4.521727,0.62378347,,,,,,,,,,,,,, -244300,4.8479276,0.6891199,,,,,,,,,,,,,, -244400,4.204747,0.55967987,,,,,,,,,,,,,, -244500,4.1121306,0.50983685,,,,,,,,,,,,,, -244600,4.205223,0.60642344,,,,,,,,,,,,,, -244700,4.321909,0.6041139,,,,,,,,,,,,,, -244800,4.7274337,0.65117794,,,,,,,,,,,,,, -244900,4.92849,0.6563529,,,,,,,,,,,,,, -245000,4.345095,0.6490747,,,,,,,,,,,,,, -245100,4.1091743,0.6094374,,,,,,,,,,,,,, -245200,4.616433,0.6090085,,,,,,,,,,,,,, -245201,,,0.9598612785339355,0.1483822166919708,0.7560799717903137,1.0516401529312134,50000.0,0.631600022315979,1.817969560623169,10000.0,82677.70756721497,85694.38528704643,82677.70756721497,2997.735734939575,9.263187170028688,0.0 -245300,4.4080763,0.6147988,,,,,,,,,,,,,, -245400,4.659362,0.6222862,,,,,,,,,,,,,, -245500,4.839382,0.7286789,,,,,,,,,,,,,, -245600,4.8120894,0.5975871,,,,,,,,,,,,,, -245700,4.067573,0.56049156,,,,,,,,,,,,,, -245800,4.2678175,0.61316144,,,,,,,,,,,,,, -245900,4.7061462,0.6491026,,,,,,,,,,,,,, -246000,4.245161,0.5852133,,,,,,,,,,,,,, -246100,4.1948223,0.6089886,,,,,,,,,,,,,, -246200,4.6235,0.6155099,,,,,,,,,,,,,, -246300,4.5831423,0.6083255,,,,,,,,,,,,,, -246400,4.4772882,0.6672805,,,,,,,,,,,,,, -246500,4.2219095,0.59191024,,,,,,,,,,,,,, -246600,4.5289855,0.56224,,,,,,,,,,,,,, -246700,4.1137753,0.62059796,,,,,,,,,,,,,, -246714,,,0.9618144035339355,0.145142525434494,0.7558199763298035,1.0535144805908203,50000.0,0.6306000351905823,1.8190034627914429,10000.0,83187.67843437195,86221.69176626205,83187.67843437195,3014.9422421455383,9.33088445663452,0.0 -246800,4.2835183,0.5887631,,,,,,,,,,,,,, -246900,4.265656,0.62909544,,,,,,,,,,,,,, -247000,4.7306113,0.62954104,,,,,,,,,,,,,, -247100,4.0429354,0.59911186,,,,,,,,,,,,,, -247200,4.7106194,0.67342883,,,,,,,,,,,,,, -247300,4.533004,0.5990418,,,,,,,,,,,,,, -247400,4.4482875,0.62786865,,,,,,,,,,,,,, -247500,4.323845,0.63441104,,,,,,,,,,,,,, -247600,4.8100486,0.65341824,,,,,,,,,,,,,, -247700,4.6193247,0.6602633,,,,,,,,,,,,,, -247800,5.1814194,0.60750043,,,,,,,,,,,,,, -247900,4.974359,0.6642748,,,,,,,,,,,,,, -248000,4.6813912,0.64073884,,,,,,,,,,,,,, -248100,4.790186,0.63201535,,,,,,,,,,,,,, -248200,4.036256,0.5676141,,,,,,,,,,,,,, -248229,,,0.9604790806770324,0.1466723680496215,0.7559199929237366,1.051486253738403,50000.0,0.6314000487327576,1.8160229921340945,10000.0,83697.82090711594,86749.24345612526,83697.82090711594,3032.2274081707,9.39405369758606,0.0 -248300,4.8158355,0.61099136,,,,,,,,,,,,,, -248400,4.6683154,0.6740192,,,,,,,,,,,,,, -248500,4.3519435,0.53838605,,,,,,,,,,,,,, -248600,4.5767064,0.5951804,,,,,,,,,,,,,, -248700,4.8929763,0.694238,,,,,,,,,,,,,, -248800,4.524377,0.7065574,,,,,,,,,,,,,, -248900,4.629757,0.73349077,,,,,,,,,,,,,, -249000,4.402221,0.57462925,,,,,,,,,,,,,, -249100,4.304179,0.63979214,,,,,,,,,,,,,, -249200,4.622239,0.61874115,,,,,,,,,,,,,, -249300,4.2089953,0.5658718,,,,,,,,,,,,,, -249400,4.584661,0.6242942,,,,,,,,,,,,,, -249500,4.555783,0.5970281,,,,,,,,,,,,,, -249600,4.3906193,0.6108685,,,,,,,,,,,,,, -249700,4.124313,0.57126534,,,,,,,,,,,,,, -249743,,,0.9604392051696776,0.1476988494396209,0.7559199929237366,1.0521618127822876,50000.0,0.6313000321388245,1.8176978826522827,10000.0,84207.96512365341,87276.74302601814,84207.96512365341,3049.45293712616,9.46442985534668,0.0 -249800,4.439882,0.6755525,,,,,,,,,,,,,, -249900,4.307506,0.6810243,,,,,,,,,,,,,, -250000,4.101353,0.61676526,,,,,,,,,,,,,, -250100,5.1464043,0.58587897,,,,,,,,,,,,,, -250200,4.5073447,0.5835401,,,,,,,,,,,,,, -250300,4.795079,0.6099619,,,,,,,,,,,,,, -250400,4.422824,0.6220045,,,,,,,,,,,,,, -250500,4.3071604,0.54520124,,,,,,,,,,,,,, -250600,4.5333347,0.655065,,,,,,,,,,,,,, -250700,4.2814937,0.59792835,,,,,,,,,,,,,, -250800,4.8929386,0.5908667,,,,,,,,,,,,,, -250900,4.482766,0.5794982,,,,,,,,,,,,,, -251000,4.4841695,0.5713082,,,,,,,,,,,,,, -251100,4.5052447,0.61635613,,,,,,,,,,,,,, -251200,4.9364853,0.6665506,,,,,,,,,,,,,, -251259,,,0.960758090019226,0.1481978744268417,0.7556799650192261,1.0509867668151855,50000.0,0.6307000517845154,1.816618800163269,10000.0,84718.10183787346,87804.21118330956,84718.10183787346,3066.6506621837616,9.539054870605469,0.0 -251300,4.388104,0.6450897,,,,,,,,,,,,,, -251400,4.3262167,0.59104705,,,,,,,,,,,,,, -251500,4.427993,0.587436,,,,,,,,,,,,,, -251600,4.7366767,0.5848872,,,,,,,,,,,,,, -251700,4.393644,0.58787274,,,,,,,,,,,,,, -251800,4.291038,0.61217266,,,,,,,,,,,,,, -251900,4.4393187,0.67258644,,,,,,,,,,,,,, -252000,4.299648,0.5945719,,,,,,,,,,,,,, -252100,4.376883,0.5967351,,,,,,,,,,,,,, -252200,4.4582725,0.64282984,,,,,,,,,,,,,, -252300,4.6980577,0.5709337,,,,,,,,,,,,,, -252400,4.4997272,0.56989,,,,,,,,,,,,,, -252500,4.348829,0.6212189,,,,,,,,,,,,,, -252600,3.948748,0.5831103,,,,,,,,,,,,,, -252700,4.585908,0.65102696,,,,,,,,,,,,,, -252773,,,0.9607381820678712,0.1479832530021667,0.7556999921798706,1.0520581007003784,50000.0,0.6312000155448914,1.8171991109848025,10000.0,85228.26144003868,88331.66128325462,85228.26144003868,3083.81050491333,9.60982871055603,0.0 -252800,4.389614,0.6078428,,,,,,,,,,,,,, -252900,4.570904,0.66306365,,,,,,,,,,,,,, -253000,4.478179,0.5866988,,,,,,,,,,,,,, -253100,4.2285085,0.55718356,,,,,,,,,,,,,, -253200,4.4231076,0.662444,,,,,,,,,,,,,, -253300,4.3316684,0.6117966,,,,,,,,,,,,,, -253400,4.95737,0.66471565,,,,,,,,,,,,,, -253500,4.40096,0.54943055,,,,,,,,,,,,,, -253600,4.357,0.54861057,,,,,,,,,,,,,, -253700,4.9036407,0.62832296,,,,,,,,,,,,,, -253800,4.431144,0.67466086,,,,,,,,,,,,,, -253900,4.099051,0.5652952,,,,,,,,,,,,,, -254000,4.451824,0.6871901,,,,,,,,,,,,,, -254100,4.479306,0.64505816,,,,,,,,,,,,,, -254200,4.26538,0.65823406,,,,,,,,,,,,,, -254287,,,0.9598612785339355,0.1501981317996978,0.7560399770736694,1.052698850631714,50000.0,0.6307000517845154,1.8190534114837649,10000.0,85738.3051803112,88858.81435346603,85738.3051803112,3100.79234957695,9.67766046524048,0.0 -254300,4.765515,0.6228856,,,,,,,,,,,,,, -254400,4.166416,0.6255739,,,,,,,,,,,,,, -254500,5.0328417,0.63266236,,,,,,,,,,,,,, -254600,4.748271,0.6009025,,,,,,,,,,,,,, -254700,4.291493,0.61549693,,,,,,,,,,,,,, -254800,4.338463,0.60571545,,,,,,,,,,,,,, -254900,4.8889484,0.65740114,,,,,,,,,,,,,, -255000,4.3317723,0.5733077,,,,,,,,,,,,,, -255100,4.3040175,0.64951456,,,,,,,,,,,,,, -255200,5.0800796,0.73458004,,,,,,,,,,,,,, -255300,4.433376,0.5928918,,,,,,,,,,,,,, -255400,4.622327,0.66437614,,,,,,,,,,,,,, -255500,4.432269,0.56032926,,,,,,,,,,,,,, -255600,4.2356853,0.6441472,,,,,,,,,,,,,, -255700,4.3158474,0.64375705,,,,,,,,,,,,,, -255800,4.8278604,0.6174737,,,,,,,,,,,,,, -255801,,,0.9608378410339355,0.1464589387178421,0.7556399703025818,1.0526121854782104,50000.0,0.6309000253677368,1.8166823387146,10000.0,86248.5129327774,89386.53042554855,86248.5129327774,3118.1718657016754,9.745689868927002,0.0 -255900,4.4516263,0.6833067,,,,,,,,,,,,,, -256000,4.428535,0.62204766,,,,,,,,,,,,,, -256100,4.501429,0.60350156,,,,,,,,,,,,,, -256200,4.2478724,0.6490531,,,,,,,,,,,,,, -256300,4.276666,0.6314833,,,,,,,,,,,,,, -256400,4.820614,0.64559025,,,,,,,,,,,,,, -256500,3.9969993,0.5631546,,,,,,,,,,,,,, -256600,4.521631,0.6426322,,,,,,,,,,,,,, -256700,4.600632,0.6025618,,,,,,,,,,,,,, -256800,4.7769475,0.6100809,,,,,,,,,,,,,, -256900,4.7797832,0.6603385,,,,,,,,,,,,,, -257000,4.357372,0.59041816,,,,,,,,,,,,,, -257100,4.5345783,0.66477495,,,,,,,,,,,,,, -257200,4.423103,0.66720265,,,,,,,,,,,,,, -257300,4.465609,0.6343079,,,,,,,,,,,,,, -257316,,,0.9597616195678712,0.1502071619033813,0.756119966506958,1.0523533821105957,50000.0,0.6323000192642212,1.817854166030884,10000.0,86758.58062577248,89913.8745894432,86758.58062577248,3135.312481403351,9.82122540473938,0.0 -257400,4.7894416,0.618361,,,,,,,,,,,,,, -257500,4.625269,0.6396855,,,,,,,,,,,,,, -257600,4.2336817,0.5944784,,,,,,,,,,,,,, -257700,4.477923,0.6677957,,,,,,,,,,,,,, -257800,5.4696317,0.6297363,,,,,,,,,,,,,, -257900,4.8141403,0.6124957,,,,,,,,,,,,,, -258000,4.8046002,0.6382521,,,,,,,,,,,,,, -258100,4.6848574,0.70869243,,,,,,,,,,,,,, -258200,4.7198377,0.6472869,,,,,,,,,,,,,, -258300,4.38638,0.6392268,,,,,,,,,,,,,, -258400,4.729742,0.6446248,,,,,,,,,,,,,, -258500,4.594122,0.6267159,,,,,,,,,,,,,, -258600,4.4560266,0.6352828,,,,,,,,,,,,,, -258700,4.200725,0.59812033,,,,,,,,,,,,,, -258800,4.3606896,0.64296925,,,,,,,,,,,,,, -258830,,,0.9610371589660645,0.1454240679740905,0.7561799883842468,1.0518678426742554,50000.0,0.6314000487327576,1.817800521850586,10000.0,87268.71463608742,90441.27367305756,87268.71463608742,3152.405182123184,9.933086156845093,0.0 -258900,4.455627,0.61412084,,,,,,,,,,,,,, -259000,4.448674,0.6768398,,,,,,,,,,,,,, -259100,4.0587854,0.5676929,,,,,,,,,,,,,, -259200,4.3249974,0.59573066,,,,,,,,,,,,,, -259300,4.6549606,0.648183,,,,,,,,,,,,,, -259400,4.2314816,0.5658607,,,,,,,,,,,,,, -259500,4.540354,0.5679064,,,,,,,,,,,,,, -259600,4.5062084,0.57376736,,,,,,,,,,,,,, -259700,4.657983,0.69851637,,,,,,,,,,,,,, -259800,4.3032765,0.6224812,,,,,,,,,,,,,, -259900,4.339612,0.6346149,,,,,,,,,,,,,, -260000,4.5352182,0.67606896,,,,,,,,,,,,,, -260100,4.677333,0.6823824,,,,,,,,,,,,,, -260200,4.5960937,0.61794734,,,,,,,,,,,,,, -260300,4.630956,0.5761979,,,,,,,,,,,,,, -260343,,,0.9601203799247742,0.1482060104608535,0.7556799650192261,1.05309796333313,50000.0,0.6314000487327576,1.8175643682479856,10000.0,87778.6434044838,90968.59873461723,87778.6434044838,3169.6641025543213,10.010907649993896,0.0 -260400,4.9604588,0.6543124,,,,,,,,,,,,,, -260500,4.524333,0.57703096,,,,,,,,,,,,,, -260600,4.539907,0.5904433,,,,,,,,,,,,,, -260700,4.765093,0.633505,,,,,,,,,,,,,, -260800,4.5537586,0.5920115,,,,,,,,,,,,,, -260900,4.003553,0.5789248,,,,,,,,,,,,,, -261000,4.941784,0.73938847,,,,,,,,,,,,,, -261100,4.7291436,0.5844543,,,,,,,,,,,,,, -261200,4.614021,0.6628654,,,,,,,,,,,,,, -261300,4.448276,0.6501663,,,,,,,,,,,,,, -261400,5.253317,0.64389396,,,,,,,,,,,,,, -261500,4.3181195,0.57425,,,,,,,,,,,,,, -261600,4.377935,0.5827391,,,,,,,,,,,,,, -261700,4.8372617,0.64458925,,,,,,,,,,,,,, -261800,4.8209386,0.6997824,,,,,,,,,,,,,, -261857,,,0.9605787396430968,0.1496013551950454,0.7554999589920044,1.051580786705017,50000.0,0.6306000351905823,1.8180071115493768,10000.0,88288.82423329353,91496.1502726078,88288.82423329353,3186.8991684913635,10.086401224136353,0.0 -261900,4.483499,0.6691846,,,,,,,,,,,,,, -262000,4.302455,0.66041726,,,,,,,,,,,,,, -262100,4.7004514,0.62783706,,,,,,,,,,,,,, -262200,4.077735,0.5701838,,,,,,,,,,,,,, -262300,4.589242,0.58259094,,,,,,,,,,,,,, -262400,4.5424623,0.61616063,,,,,,,,,,,,,, -262500,4.3011003,0.6189295,,,,,,,,,,,,,, -262600,4.8749237,0.6093415,,,,,,,,,,,,,, -262700,5.0512815,0.6531879,,,,,,,,,,,,,, -262800,4.3009486,0.619884,,,,,,,,,,,,,, -262900,4.306983,0.6330221,,,,,,,,,,,,,, -263000,4.8924336,0.60498655,,,,,,,,,,,,,, -263100,4.2353506,0.53970516,,,,,,,,,,,,,, -263200,4.548695,0.6188047,,,,,,,,,,,,,, -263300,4.255539,0.55136114,,,,,,,,,,,,,, -263370,,,0.9611965417861938,0.1485585421323776,0.7565799951553345,1.0513132810592651,50000.0,0.6309000253677368,1.8187938928604128,10000.0,88798.76346611977,92023.4181470871,88798.76346611977,3204.0929505825043,10.161179780960085,0.0 -263400,4.1561284,0.6902674,,,,,,,,,,,,,, -263500,4.2438374,0.58672756,,,,,,,,,,,,,, -263600,4.824165,0.6214268,,,,,,,,,,,,,, -263700,4.257269,0.6096204,,,,,,,,,,,,,, -263800,4.584009,0.6529292,,,,,,,,,,,,,, -263900,4.382585,0.5560259,,,,,,,,,,,,,, -264000,4.3153353,0.6390495,,,,,,,,,,,,,, -264100,4.542032,0.60769665,,,,,,,,,,,,,, -264200,4.427097,0.6287572,,,,,,,,,,,,,, -264300,4.672127,0.6127645,,,,,,,,,,,,,, -264400,4.8417473,0.6138924,,,,,,,,,,,,,, -264500,4.477886,0.6110883,,,,,,,,,,,,,, -264600,4.4784,0.6699115,,,,,,,,,,,,,, -264700,4.415178,0.6063664,,,,,,,,,,,,,, -264800,4.2331376,0.5912015,,,,,,,,,,,,,, -264883,,,0.9622528553009032,0.1400790512561798,0.755620002746582,1.0535967350006104,50000.0,0.6309000253677368,1.818920373916626,10000.0,89308.81522655487,92550.89099025726,89308.81522655487,3221.382915019989,10.232295989990234,0.0 -264900,4.233526,0.56827474,,,,,,,,,,,,,, -265000,5.0005364,0.595907,,,,,,,,,,,,,, -265100,4.5469685,0.70750034,,,,,,,,,,,,,, -265200,4.6606317,0.6266583,,,,,,,,,,,,,, -265300,4.2772975,0.6212607,,,,,,,,,,,,,, -265400,4.231669,0.5335703,,,,,,,,,,,,,, -265500,4.4242554,0.64202195,,,,,,,,,,,,,, -265600,4.399817,0.5685644,,,,,,,,,,,,,, -265700,3.999796,0.59897166,,,,,,,,,,,,,, -265800,4.377542,0.6076571,,,,,,,,,,,,,, -265900,4.4123034,0.58395654,,,,,,,,,,,,,, -266000,4.593283,0.6301324,,,,,,,,,,,,,, -266100,4.3599567,0.6325268,,,,,,,,,,,,,, -266200,4.4876394,0.59375554,,,,,,,,,,,,,, -266300,5.011476,0.701595,,,,,,,,,,,,,, -266396,,,0.960758090019226,0.1475231349468231,0.7558599710464478,1.050364375114441,50000.0,0.6300000548362732,1.8159278631210327,10000.0,89818.75279092789,93078.5286242962,89818.75279092789,3238.949809551239,10.304664373397827,0.0 -266400,4.6315002,0.6323192,,,,,,,,,,,,,, -266500,4.8646083,0.62488747,,,,,,,,,,,,,, -266600,4.182183,0.57196516,,,,,,,,,,,,,, -266700,4.443445,0.6111615,,,,,,,,,,,,,, -266800,4.162211,0.55126023,,,,,,,,,,,,,, -266900,4.2597775,0.55527437,,,,,,,,,,,,,, -267000,4.8721943,0.6314707,,,,,,,,,,,,,, -267100,4.685103,0.66698474,,,,,,,,,,,,,, -267200,4.3765836,0.59991026,,,,,,,,,,,,,, -267300,4.1342955,0.5762876,,,,,,,,,,,,,, -267400,4.817193,0.7179004,,,,,,,,,,,,,, -267500,4.5503516,0.6714363,,,,,,,,,,,,,, -267600,4.1905756,0.62377226,,,,,,,,,,,,,, -267700,4.2907104,0.5365354,,,,,,,,,,,,,, -267800,4.7575254,0.57880956,,,,,,,,,,,,,, -267900,4.1153913,0.56254876,,,,,,,,,,,,,, -267910,,,0.9594228267669678,0.147891879081726,0.7556399703025818,1.0531243085861206,50000.0,0.6308000087738037,1.817447304725647,10000.0,90328.8978767395,93605.95786976814,90328.8978767395,3256.099337339401,10.37904691696167,0.0 -268000,4.7389817,0.6340648,,,,,,,,,,,,,, -268100,4.048562,0.55232584,,,,,,,,,,,,,, -268200,4.1782413,0.63448995,,,,,,,,,,,,,, -268300,4.5056868,0.64021266,,,,,,,,,,,,,, -268400,4.9172425,0.73592615,,,,,,,,,,,,,, -268500,4.374122,0.6168075,,,,,,,,,,,,,, -268600,4.809294,0.6042669,,,,,,,,,,,,,, -268700,4.867777,0.6269654,,,,,,,,,,,,,, -268800,4.2166505,0.55727464,,,,,,,,,,,,,, -268900,4.3393097,0.5624184,,,,,,,,,,,,,, -269000,4.851721,0.65146023,,,,,,,,,,,,,, -269100,4.3208656,0.5920125,,,,,,,,,,,,,, -269200,4.466886,0.65820897,,,,,,,,,,,,,, -269300,5.2250047,0.6506844,,,,,,,,,,,,,, -269400,4.772135,0.65329933,,,,,,,,,,,,,, -269424,,,0.9600805044174194,0.1480940580368042,0.7558799982070923,1.0508418083190918,50000.0,0.6317000389099121,1.8160980939865112,10000.0,90838.84168171884,94134.0173215866,90838.84168171884,3274.080150365829,10.45369815826416,0.0 -269500,4.3284917,0.5859066,,,,,,,,,,,,,, -269600,4.7477436,0.6232256,,,,,,,,,,,,,, -269700,4.8080335,0.69565636,,,,,,,,,,,,,, -269800,4.18317,0.57734436,,,,,,,,,,,,,, -269900,4.1957493,0.56748164,,,,,,,,,,,,,, -270000,4.8519425,0.70019346,,,,,,,,,,,,,, -270100,4.2533784,0.6258117,,,,,,,,,,,,,, -270200,4.307322,0.5653448,,,,,,,,,,,,,, -270300,4.7061877,0.68068284,,,,,,,,,,,,,, -270400,3.9985373,0.5388707,,,,,,,,,,,,,, -270500,5.034802,0.70217484,,,,,,,,,,,,,, -270600,4.4493628,0.6342113,,,,,,,,,,,,,, -270700,4.5791764,0.6179942,,,,,,,,,,,,,, -270800,4.1740446,0.55047977,,,,,,,,,,,,,, -270900,4.702375,0.6850355,,,,,,,,,,,,,, -270937,,,0.9594626426696776,0.1487428247928619,0.755899965763092,1.0522432327270508,50000.0,0.6314000487327576,1.818058729171753,10000.0,91348.74285554886,94661.57633805276,91348.74285554886,3291.6022622585297,10.529751300811768,0.0 -271000,4.695891,0.59891677,,,,,,,,,,,,,, -271100,4.6630626,0.64702725,,,,,,,,,,,,,, -271200,4.376568,0.61185503,,,,,,,,,,,,,, -271300,4.5481462,0.56494766,,,,,,,,,,,,,, -271400,4.2403045,0.5512836,,,,,,,,,,,,,, -271500,4.212903,0.6117815,,,,,,,,,,,,,, -271600,4.222811,0.5862756,,,,,,,,,,,,,, -271700,4.520397,0.6553303,,,,,,,,,,,,,, -271800,4.6790066,0.64142114,,,,,,,,,,,,,, -271900,4.6510377,0.6597999,,,,,,,,,,,,,, -272000,4.905657,0.68964434,,,,,,,,,,,,,, -272100,4.289152,0.5987073,,,,,,,,,,,,,, -272200,4.103985,0.5959149,,,,,,,,,,,,,, -272300,4.784228,0.64163125,,,,,,,,,,,,,, -272400,4.278008,0.63904357,,,,,,,,,,,,,, -272451,,,0.9612962007522584,0.1474275290966034,0.7554799914360046,1.051373839378357,50000.0,0.6312000155448914,1.81630539894104,10000.0,91858.74562954904,95189.11938214302,91858.74562954904,3308.995859146118,10.616759300231934,0.0 -272500,4.4977565,0.6077561,,,,,,,,,,,,,, -272600,4.5151167,0.5884512,,,,,,,,,,,,,, -272700,4.548337,0.56455344,,,,,,,,,,,,,, -272800,4.522832,0.5865002,,,,,,,,,,,,,, -272900,4.4551606,0.64801484,,,,,,,,,,,,,, -273000,4.4741173,0.5709029,,,,,,,,,,,,,, -273100,4.614359,0.6462585,,,,,,,,,,,,,, -273200,3.9734995,0.67385024,,,,,,,,,,,,,, -273300,4.638182,0.59461343,,,,,,,,,,,,,, -273400,5.020692,0.60264593,,,,,,,,,,,,,, -273500,4.935055,0.6477285,,,,,,,,,,,,,, -273600,4.9683867,0.6303624,,,,,,,,,,,,,, -273700,4.301891,0.5927773,,,,,,,,,,,,,, -273800,5.2439036,0.6284214,,,,,,,,,,,,,, -273900,3.9668372,0.534777,,,,,,,,,,,,,, -273964,,,0.9615553021430968,0.1442984044551849,0.7552399635314941,1.0525203943252563,50000.0,0.6303000450134277,1.8180094957351685,10000.0,92368.68969726562,95716.41310477255,92368.68969726562,3326.2125630378723,10.689738750457764,0.0 -274000,4.369855,0.61199784,,,,,,,,,,,,,, -274100,4.732359,0.62464416,,,,,,,,,,,,,, -274200,4.6262164,0.6708645,,,,,,,,,,,,,, -274300,4.701886,0.60285795,,,,,,,,,,,,,, -274400,4.970097,0.61245465,,,,,,,,,,,,,, -274500,4.4048786,0.58707416,,,,,,,,,,,,,, -274600,4.3750253,0.57970357,,,,,,,,,,,,,, -274700,4.498321,0.6948907,,,,,,,,,,,,,, -274800,4.9770875,0.69456494,,,,,,,,,,,,,, -274900,4.4784017,0.6805949,,,,,,,,,,,,,, -275000,4.408657,0.67213494,,,,,,,,,,,,,, -275100,4.4516144,0.5996934,,,,,,,,,,,,,, -275200,4.759075,0.72135854,,,,,,,,,,,,,, -275300,4.807642,0.7034366,,,,,,,,,,,,,, -275400,4.2123313,0.62886333,,,,,,,,,,,,,, -275478,,,0.9615951776504515,0.144477903842926,0.7556399703025818,1.0514625310897827,50000.0,0.6309000253677368,1.8165348768234253,10000.0,92878.66392302512,96243.52274870872,92878.66392302512,3343.2138271331787,10.763943195343018,0.0 -275500,4.3931804,0.55263186,,,,,,,,,,,,,, -275600,4.423573,0.5824885,,,,,,,,,,,,,, -275700,4.594389,0.60970694,,,,,,,,,,,,,, -275800,5.232286,0.66150516,,,,,,,,,,,,,, -275900,4.587055,0.6060238,,,,,,,,,,,,,, -276000,4.6162467,0.6009942,,,,,,,,,,,,,, -276100,4.4786854,0.6070853,,,,,,,,,,,,,, -276200,4.576892,0.60068256,,,,,,,,,,,,,, -276300,4.1261535,0.5720613,,,,,,,,,,,,,, -276400,4.3255415,0.5836388,,,,,,,,,,,,,, -276500,4.188165,0.5825951,,,,,,,,,,,,,, -276600,4.601267,0.6557536,,,,,,,,,,,,,, -276700,4.46902,0.6820193,,,,,,,,,,,,,, -276800,4.4465704,0.66714394,,,,,,,,,,,,,, -276900,4.8016543,0.65458626,,,,,,,,,,,,,, -276992,,,0.9596420526504515,0.1479740142822265,0.7559399604797363,1.0522382259368896,50000.0,0.6308000087738037,1.818804621696472,10000.0,93388.64336013794,96770.8598151207,93388.64336013794,3360.4372668266296,10.837567806243896,0.0 -277000,4.5065517,0.6044827,,,,,,,,,,,,,, -277100,4.706592,0.61076605,,,,,,,,,,,,,, -277200,4.7280226,0.5818873,,,,,,,,,,,,,, -277300,4.6435075,0.6671845,,,,,,,,,,,,,, -277400,4.6640406,0.64895946,,,,,,,,,,,,,, -277500,4.517046,0.6124703,,,,,,,,,,,,,, -277600,4.486846,0.6097766,,,,,,,,,,,,,, -277700,4.8439655,0.67847836,,,,,,,,,,,,,, -277800,4.475453,0.56921375,,,,,,,,,,,,,, -277900,4.6360073,0.5913922,,,,,,,,,,,,,, -278000,4.818794,0.65301675,,,,,,,,,,,,,, -278100,4.2254515,0.58793694,,,,,,,,,,,,,, -278200,4.27132,0.57052124,,,,,,,,,,,,,, -278300,4.658746,0.59193593,,,,,,,,,,,,,, -278400,4.6486006,0.5973066,,,,,,,,,,,,,, -278500,5.023202,0.6611918,,,,,,,,,,,,,, -278506,,,0.9606584906578064,0.1490184068679809,0.7557799816131592,1.0525351762771606,50000.0,0.631100058555603,1.817908525466919,10000.0,93898.81333494186,97298.36285185814,93898.81333494186,3377.6312334537506,10.91645884513855,0.0 -278600,4.506215,0.60345256,,,,,,,,,,,,,, -278700,4.6044645,0.6932471,,,,,,,,,,,,,, -278800,4.8792844,0.63873136,,,,,,,,,,,,,, -278900,4.5063176,0.66486096,,,,,,,,,,,,,, -279000,4.3607,0.6202538,,,,,,,,,,,,,, -279100,4.414947,0.5964748,,,,,,,,,,,,,, -279200,4.9534383,0.62885916,,,,,,,,,,,,,, -279300,4.495405,0.6001396,,,,,,,,,,,,,, -279400,4.5352497,0.6861149,,,,,,,,,,,,,, -279500,5.4041696,0.6162847,,,,,,,,,,,,,, -279600,4.1573734,0.59094423,,,,,,,,,,,,,, -279700,4.8406634,0.5509104,,,,,,,,,,,,,, -279800,5.4894943,0.7057295,,,,,,,,,,,,,, -279900,4.3924465,0.68155956,,,,,,,,,,,,,, -280000,4.4212403,0.6399017,,,,,,,,,,,,,, -280020,,,0.9610969424247742,0.1430348008871078,0.7557399868965149,1.0523576736450195,50000.0,0.6305000185966492,1.818747520446777,10000.0,94408.975086689,97825.67763566972,94408.975086689,3394.649137496948,10.992409467697144,0.0 -280100,4.45735,0.55911845,,,,,,,,,,,,,, -280200,4.511625,0.6185173,,,,,,,,,,,,,, -280300,4.505354,0.56990063,,,,,,,,,,,,,, -280400,4.7624335,0.643677,,,,,,,,,,,,,, -280500,4.442142,0.62435687,,,,,,,,,,,,,, -280600,4.7237635,0.67667854,,,,,,,,,,,,,, -280700,4.693181,0.6172205,,,,,,,,,,,,,, -280800,4.1796002,0.6007859,,,,,,,,,,,,,, -280900,4.6651864,0.7022033,,,,,,,,,,,,,, -281000,4.2633343,0.56010824,,,,,,,,,,,,,, -281100,4.4038367,0.65130043,,,,,,,,,,,,,, -281200,4.5738606,0.5843147,,,,,,,,,,,,,, -281300,4.604364,0.65313643,,,,,,,,,,,,,, -281400,4.621283,0.6211665,,,,,,,,,,,,,, -281500,4.6561184,0.6558073,,,,,,,,,,,,,, -281535,,,0.9610371589660645,0.1468787193298339,0.7556399703025818,1.0521724224090576,50000.0,0.631600022315979,1.8173832893371584,10000.0,94919.01567602158,98352.9867374897,94919.01567602158,3411.765170574188,11.085949182510376,0.0 -281600,4.79351,0.5794234,,,,,,,,,,,,,, -281700,4.3490586,0.5563315,,,,,,,,,,,,,, -281800,4.3618054,0.5613734,,,,,,,,,,,,,, -281900,4.340258,0.607026,,,,,,,,,,,,,, -282000,4.6567984,0.6467135,,,,,,,,,,,,,, -282100,4.266601,0.6084314,,,,,,,,,,,,,, -282200,4.5133243,0.62606186,,,,,,,,,,,,,, -282300,4.5826497,0.59408045,,,,,,,,,,,,,, -282400,4.4522457,0.57958955,,,,,,,,,,,,,, -282500,4.549961,0.68064237,,,,,,,,,,,,,, -282600,4.590282,0.6272688,,,,,,,,,,,,,, -282700,4.487967,0.64046067,,,,,,,,,,,,,, -282800,4.5488877,0.62669444,,,,,,,,,,,,,, -282900,4.8170652,0.6006156,,,,,,,,,,,,,, -283000,4.3436337,0.60633945,,,,,,,,,,,,,, -283049,,,0.9612165093421936,0.1450749486684799,0.7563599944114685,1.0516657829284668,50000.0,0.6310000419616699,1.817343831062317,10000.0,95428.90130352974,98879.92723703384,95428.90130352974,3428.685005664825,11.162020683288574,0.0 -283100,4.407634,0.65213317,,,,,,,,,,,,,, -283200,4.203375,0.56661016,,,,,,,,,,,,,, -283300,4.5044603,0.6355791,,,,,,,,,,,,,, -283400,5.179168,0.65877545,,,,,,,,,,,,,, -283500,4.228406,0.614922,,,,,,,,,,,,,, -283600,4.4552355,0.6812734,,,,,,,,,,,,,, -283700,4.082505,0.57532424,,,,,,,,,,,,,, -283800,4.438084,0.621827,,,,,,,,,,,,,, -283900,4.190702,0.6026651,,,,,,,,,,,,,, -284000,4.8964124,0.74067056,,,,,,,,,,,,,, -284100,4.5389414,0.6653537,,,,,,,,,,,,,, -284200,4.390622,0.5800236,,,,,,,,,,,,,, -284300,4.5832896,0.6313495,,,,,,,,,,,,,, -284400,4.751253,0.5929909,,,,,,,,,,,,,, -284500,4.7314363,0.6636732,,,,,,,,,,,,,, -284562,,,0.9595025181770324,0.1488256752490997,0.7557399868965149,1.0531953573226929,50000.0,0.6309000253677368,1.819740653038025,10000.0,95938.8274178505,99406.96535897256,95938.8274178505,3445.661366701126,11.237503051757812,0.0 -284600,4.5603824,0.6441517,,,,,,,,,,,,,, -284700,4.6516075,0.6394258,,,,,,,,,,,,,, -284800,4.5353317,0.66254276,,,,,,,,,,,,,, -284900,4.1113267,0.5901623,,,,,,,,,,,,,, -285000,4.070928,0.5506524,,,,,,,,,,,,,, -285100,4.36726,0.6003261,,,,,,,,,,,,,, -285200,5.082311,0.6913628,,,,,,,,,,,,,, -285300,4.96408,0.6341479,,,,,,,,,,,,,, -285400,4.4147396,0.5614654,,,,,,,,,,,,,, -285500,4.2673044,0.6300404,,,,,,,,,,,,,, -285600,4.471286,0.65998834,,,,,,,,,,,,,, -285700,4.3466415,0.6436608,,,,,,,,,,,,,, -285800,5.039437,0.6388803,,,,,,,,,,,,,, -285900,4.229244,0.61101377,,,,,,,,,,,,,, -286000,5.016897,0.6289549,,,,,,,,,,,,,, -286076,,,0.9608378410339355,0.1471054852008819,0.756060004234314,1.052351713180542,50000.0,0.6315000057220459,1.818746566772461,10000.0,96449.02365016936,99934.65275216104,96449.02365016936,3463.0189859867096,11.311057567596436,0.0 -286100,4.3131986,0.66299355,,,,,,,,,,,,,, -286200,4.6045437,0.6532847,,,,,,,,,,,,,, -286300,4.4489026,0.6408531,,,,,,,,,,,,,, -286400,4.4998794,0.6096001,,,,,,,,,,,,,, -286500,4.3292284,0.6309221,,,,,,,,,,,,,, -286600,4.3824778,0.57099086,,,,,,,,,,,,,, -286700,4.547746,0.58656955,,,,,,,,,,,,,, -286800,4.2849708,0.57617605,,,,,,,,,,,,,, -286900,4.1241755,0.5739304,,,,,,,,,,,,,, -287000,4.0789742,0.5898245,,,,,,,,,,,,,, -287100,4.543665,0.6845323,,,,,,,,,,,,,, -287200,4.5753055,0.6740972,,,,,,,,,,,,,, -287300,4.533455,0.6151165,,,,,,,,,,,,,, -287400,4.3563147,0.5439515,,,,,,,,,,,,,, -287500,4.7731915,0.65431654,,,,,,,,,,,,,, -287589,,,0.9601203799247742,0.1463282853364944,0.7557399868965149,1.051419377326965,50000.0,0.6321000456809998,1.8166067600250244,10000.0,96958.94147777556,100461.76079773904,96958.94147777556,3480.0775611400604,11.383795976638794,0.0 -287600,4.2222114,0.62667215,,,,,,,,,,,,,, -287700,4.2866063,0.62631845,,,,,,,,,,,,,, -287800,3.8520358,0.4914003,,,,,,,,,,,,,, -287900,4.1992855,0.6177614,,,,,,,,,,,,,, -288000,4.563071,0.67452186,,,,,,,,,,,,,, -288100,4.895811,0.655279,,,,,,,,,,,,,, -288200,4.50215,0.6078627,,,,,,,,,,,,,, -288300,4.2805715,0.6155616,,,,,,,,,,,,,, -288400,4.6029253,0.6328474,,,,,,,,,,,,,, -288500,5.0908103,0.62313044,,,,,,,,,,,,,, -288600,4.6349363,0.63677526,,,,,,,,,,,,,, -288700,4.1756124,0.5345588,,,,,,,,,,,,,, -288800,4.241048,0.6391407,,,,,,,,,,,,,, -288900,4.5345435,0.65997076,,,,,,,,,,,,,, -289000,4.4486375,0.61442536,,,,,,,,,,,,,, -289100,5.484093,0.62796795,,,,,,,,,,,,,, -289102,,,0.9602997303009032,0.1500391662120819,0.756119966506958,1.0513322353363037,50000.0,0.6308000087738037,1.817582368850708,10000.0,97468.86407732964,100988.83882284164,97468.86407732964,3497.1005821228027,11.457651138305664,0.0 -289200,4.2942224,0.6276809,,,,,,,,,,,,,, -289300,4.5530367,0.5744672,,,,,,,,,,,,,, -289400,4.2410455,0.5879262,,,,,,,,,,,,,, -289500,4.6033764,0.6216248,,,,,,,,,,,,,, -289600,4.696575,0.692418,,,,,,,,,,,,,, -289700,4.3617873,0.5734036,,,,,,,,,,,,,, -289800,4.370366,0.63907975,,,,,,,,,,,,,, -289900,4.519296,0.5829044,,,,,,,,,,,,,, -290000,4.2947974,0.64453524,,,,,,,,,,,,,, -290100,4.3906,0.65363014,,,,,,,,,,,,,, -290200,5.2914767,0.6968213,,,,,,,,,,,,,, -290300,4.4837203,0.6483698,,,,,,,,,,,,,, -290400,4.371257,0.5782188,,,,,,,,,,,,,, -290500,4.163514,0.6081665,,,,,,,,,,,,,, -290600,4.2747183,0.5900271,,,,,,,,,,,,,, -290616,,,0.960718274116516,0.1463023722171783,0.7559799551963806,1.052631974220276,50000.0,0.6315000057220459,1.817838311195373,10000.0,97978.9630844593,101516.11999130248,97978.9630844593,3514.145917892456,11.534794092178345,0.0 -290700,4.609388,0.635296,,,,,,,,,,,,,, -290800,3.978283,0.57051975,,,,,,,,,,,,,, -290900,5.0132546,0.6668862,,,,,,,,,,,,,, -291000,4.651898,0.6461979,,,,,,,,,,,,,, -291100,4.608521,0.5991782,,,,,,,,,,,,,, -291200,4.632224,0.6719491,,,,,,,,,,,,,, -291300,4.58961,0.647853,,,,,,,,,,,,,, -291400,5.1686096,0.5778533,,,,,,,,,,,,,, -291500,4.4065623,0.7344679,,,,,,,,,,,,,, -291600,4.5407524,0.5974035,,,,,,,,,,,,,, -291700,4.7859473,0.6754509,,,,,,,,,,,,,, -291800,5.0639353,0.6223247,,,,,,,,,,,,,, -291900,4.4239473,0.5748992,,,,,,,,,,,,,, -292000,4.4466653,0.6356468,,,,,,,,,,,,,, -292100,4.2544284,0.54331625,,,,,,,,,,,,,, -292130,,,0.9606186151504515,0.1470670998096466,0.7553799748420715,1.0529109239578247,50000.0,0.6310000419616699,1.8188873529434204,10000.0,98489.15401434898,102043.6172683239,98489.15401434898,3531.315213680268,11.611164569854736,0.0 -292200,4.6870036,0.67391235,,,,,,,,,,,,,, -292300,4.204915,0.5954584,,,,,,,,,,,,,, -292400,4.7423377,0.63323617,,,,,,,,,,,,,, -292500,4.341732,0.58911735,,,,,,,,,,,,,, -292600,4.784162,0.61194825,,,,,,,,,,,,,, -292700,4.208705,0.632003,,,,,,,,,,,,,, -292800,4.4894767,0.6667663,,,,,,,,,,,,,, -292900,4.7472353,0.67964774,,,,,,,,,,,,,, -293000,4.828377,0.67823434,,,,,,,,,,,,,, -293100,4.562287,0.67497104,,,,,,,,,,,,,, -293200,4.480071,0.6191953,,,,,,,,,,,,,, -293300,4.1436377,0.6023885,,,,,,,,,,,,,, -293400,4.4146147,0.5779557,,,,,,,,,,,,,, -293500,4.40854,0.58857924,,,,,,,,,,,,,, -293600,5.0586023,0.698867,,,,,,,,,,,,,, -293643,,,0.9607381820678712,0.1502828449010849,0.755620002746582,1.0518311262130735,50000.0,0.6307000517845154,1.81728196144104,10000.0,98999.07892155647,102570.82220602036,98999.07892155647,3548.455144643784,11.689434289932253,0.0 -293700,4.528419,0.60492706,,,,,,,,,,,,,, -293800,4.195769,0.54382604,,,,,,,,,,,,,, -293900,4.828252,0.66440517,,,,,,,,,,,,,, -294000,4.5516915,0.6243312,,,,,,,,,,,,,, -294100,4.3463635,0.582476,,,,,,,,,,,,,, -294200,4.289494,0.63649786,,,,,,,,,,,,,, -294300,4.153334,0.6378417,,,,,,,,,,,,,, -294400,4.2348204,0.5701741,,,,,,,,,,,,,, -294500,5.020291,0.67749983,,,,,,,,,,,,,, -294600,4.2771635,0.57833135,,,,,,,,,,,,,, -294700,4.5957026,0.63383925,,,,,,,,,,,,,, -294800,4.6093717,0.6133177,,,,,,,,,,,,,, -294900,4.749825,0.5873505,,,,,,,,,,,,,, -295000,4.2360454,0.562353,,,,,,,,,,,,,, -295100,4.556743,0.63388395,,,,,,,,,,,,,, -295157,,,0.9604591727256776,0.146732673048973,0.7555599808692932,1.052149534225464,50000.0,0.6307000517845154,1.817894458770752,10000.0,99509.08157086372,103097.92510271072,99509.08157086372,3565.4192507267,11.765220880508425,0.0 -295200,4.2948623,0.6662048,,,,,,,,,,,,,, -295300,4.1939945,0.63019353,,,,,,,,,,,,,, -295400,4.9679575,0.63659155,,,,,,,,,,,,,, -295500,4.4177794,0.6389823,,,,,,,,,,,,,, -295600,4.462454,0.65512675,,,,,,,,,,,,,, -295700,5.350708,0.6514796,,,,,,,,,,,,,, -295800,4.2578807,0.6844565,,,,,,,,,,,,,, -295900,4.6040144,0.666144,,,,,,,,,,,,,, -296000,4.3279033,0.67978513,,,,,,,,,,,,,, -296100,4.220956,0.6112453,,,,,,,,,,,,,, -296200,4.197399,0.6454787,,,,,,,,,,,,,, -296300,4.229037,0.5636695,,,,,,,,,,,,,, -296400,4.577201,0.63977945,,,,,,,,,,,,,, -296500,4.2819324,0.61776614,,,,,,,,,,,,,, -296600,5.124946,0.63785976,,,,,,,,,,,,,, -296671,,,0.9592434167861938,0.149974912405014,0.7557799816131592,1.052897572517395,50000.0,0.6307000517845154,1.817719578742981,10000.0,100019.2240486145,103625.51902842522,100019.2240486145,3582.733262300492,11.8424334526062,0.0 -296700,4.4469776,0.56472665,,,,,,,,,,,,,, -296800,4.167867,0.5911673,,,,,,,,,,,,,, -296900,4.634879,0.625416,,,,,,,,,,,,,, -297000,4.126211,0.53784865,,,,,,,,,,,,,, -297100,4.3034496,0.56857777,,,,,,,,,,,,,, -297200,4.814385,0.71186405,,,,,,,,,,,,,, -297300,4.2618876,0.65569705,,,,,,,,,,,,,, -297400,4.2745094,0.60912806,,,,,,,,,,,,,, -297500,4.148583,0.64111817,,,,,,,,,,,,,, -297600,4.8718257,0.69665504,,,,,,,,,,,,,, -297700,4.5900693,0.6625325,,,,,,,,,,,,,, -297800,4.6165657,0.6243412,,,,,,,,,,,,,, -297900,4.5334272,0.70135933,,,,,,,,,,,,,, -298000,4.3284726,0.5857985,,,,,,,,,,,,,, -298100,5.287767,0.7149426,,,,,,,,,,,,,, -298184,,,0.9615553021430968,0.1456803530454635,0.7562999725341797,1.0515629053115845,50000.0,0.6317000389099121,1.816598892211914,10000.0,100529.0994002819,104152.75135087968,100529.0994002819,3599.947383403778,11.925029277801514,0.0 -298200,4.1666737,0.5966412,,,,,,,,,,,,,, -298300,4.455921,0.63238984,,,,,,,,,,,,,, -298400,3.9834611,0.4986324,,,,,,,,,,,,,, -298500,4.0697556,0.5624083,,,,,,,,,,,,,, -298600,4.2053795,0.62372965,,,,,,,,,,,,,, -298700,4.4933214,0.6157187,,,,,,,,,,,,,, -298800,4.07339,0.51832306,,,,,,,,,,,,,, -298900,4.4627156,0.5807616,,,,,,,,,,,,,, -299000,4.961487,0.63579637,,,,,,,,,,,,,, -299100,4.620666,0.6712439,,,,,,,,,,,,,, -299200,5.019363,0.6267391,,,,,,,,,,,,,, -299300,4.6051216,0.61324006,,,,,,,,,,,,,, -299400,4.911959,0.6539602,,,,,,,,,,,,,, -299500,4.190373,0.6328993,,,,,,,,,,,,,, -299600,4.5447826,0.7232301,,,,,,,,,,,,,, -299698,,,0.961355984210968,0.1456865221261978,0.7557399868965149,1.0521283149719238,50000.0,0.6315000057220459,1.8183579444885247,10000.0,101039.07216191292,104680.2663693428,101039.07216191292,3617.3545808792114,12.00115966796875,0.0 -299700,4.583619,0.6661193,,,,,,,,,,,,,, -299800,4.242491,0.5488106,,,,,,,,,,,,,, -299900,4.222062,0.542327,,,,,,,,,,,,,, -300000,5.3004665,0.63647485,,,,,,,,,,,,,, -300100,4.352529,0.6840137,,,,,,,,,,,,,, -300200,4.26046,0.616114,,,,,,,,,,,,,, -300300,4.2939367,0.6125495,,,,,,,,,,,,,, -300400,4.3428,0.58342,,,,,,,,,,,,,, -300500,4.4143558,0.64670914,,,,,,,,,,,,,, -300600,4.4119997,0.5953216,,,,,,,,,,,,,, -300700,4.3159485,0.62278473,,,,,,,,,,,,,, -300800,4.555267,0.6949471,,,,,,,,,,,,,, -300900,4.1730256,0.6500132,,,,,,,,,,,,,, -301000,4.573067,0.5959004,,,,,,,,,,,,,, -301100,4.2528563,0.591542,,,,,,,,,,,,,, -301200,4.9293375,0.6312057,,,,,,,,,,,,,, -301211,,,0.960598647594452,0.1500077545642852,0.7556599974632263,1.0508867502212524,50000.0,0.6318000555038452,1.81587815284729,10000.0,101549.00704622269,105207.99576115608,101549.00704622269,3635.0120573043823,12.07826280593872,0.0 -301300,4.46069,0.62379503,,,,,,,,,,,,,, -301400,4.9321713,0.64453685,,,,,,,,,,,,,, -301500,4.746837,0.7007832,,,,,,,,,,,,,, -301600,4.554791,0.57419205,,,,,,,,,,,,,, -301700,4.407082,0.55492806,,,,,,,,,,,,,, -301800,4.1625266,0.603933,,,,,,,,,,,,,, -301900,4.5379477,0.6700654,,,,,,,,,,,,,, -302000,4.605447,0.6229752,,,,,,,,,,,,,, -302100,4.2339954,0.5699052,,,,,,,,,,,,,, -302200,4.31396,0.5772625,,,,,,,,,,,,,, -302300,4.112638,0.58316755,,,,,,,,,,,,,, -302400,4.3839383,0.5743818,,,,,,,,,,,,,, -302500,4.9545956,0.6762836,,,,,,,,,,,,,, -302600,4.360377,0.6010077,,,,,,,,,,,,,, -302700,4.0473385,0.5749762,,,,,,,,,,,,,, -302724,,,0.9607979655265808,0.1483609229326248,0.7554000020027161,1.0511231422424316,50000.0,0.6313000321388245,1.8171393871307373,10000.0,102058.9766690731,105735.22607922554,102058.9766690731,3652.1334154605865,12.158013582229614,0.0 -302800,4.5829444,0.5632968,,,,,,,,,,,,,, -302900,4.2930555,0.62102634,,,,,,,,,,,,,, -303000,4.0178466,0.6065843,,,,,,,,,,,,,, -303100,4.433018,0.56007564,,,,,,,,,,,,,, -303200,4.7965465,0.65626526,,,,,,,,,,,,,, -303300,4.262812,0.6064811,,,,,,,,,,,,,, -303400,4.423054,0.6933537,,,,,,,,,,,,,, -303500,4.319593,0.5884266,,,,,,,,,,,,,, -303600,4.1267905,0.55901706,,,,,,,,,,,,,, -303700,4.8353186,0.61701775,,,,,,,,,,,,,, -303800,4.5670753,0.62278205,,,,,,,,,,,,,, -303900,4.675076,0.5889431,,,,,,,,,,,,,, -304000,4.1869726,0.5825081,,,,,,,,,,,,,, -304100,4.588363,0.6365529,,,,,,,,,,,,,, -304200,4.5882335,0.5989614,,,,,,,,,,,,,, -304237,,,0.9625318646430968,0.1406024247407913,0.7562400102615356,1.0528939962387085,50000.0,0.6307000517845154,1.817326307296753,10000.0,102569.1431913376,106262.51443743706,102569.1431913376,3669.119397878647,12.233417749404907,0.0 -304300,4.1552258,0.6260383,,,,,,,,,,,,,, -304400,4.3305097,0.5923028,,,,,,,,,,,,,, -304500,4.4318147,0.62988514,,,,,,,,,,,,,, -304600,4.7626195,0.654547,,,,,,,,,,,,,, -304700,4.451286,0.6091508,,,,,,,,,,,,,, -304800,4.3759704,0.57650447,,,,,,,,,,,,,, -304900,4.223849,0.61058116,,,,,,,,,,,,,, -305000,4.557668,0.63781005,,,,,,,,,,,,,, -305100,4.570991,0.65014535,,,,,,,,,,,,,, -305200,4.318886,0.64006233,,,,,,,,,,,,,, -305300,4.573568,0.58594745,,,,,,,,,,,,,, -305400,4.1088934,0.6249448,,,,,,,,,,,,,, -305500,4.1493235,0.62302893,,,,,,,,,,,,,, -305600,4.7051244,0.61356735,,,,,,,,,,,,,, -305700,4.3160295,0.66226244,,,,,,,,,,,,,, -305750,,,0.960339605808258,0.1479188501834869,0.7556999921798706,1.0525130033493042,50000.0,0.6307000517845154,1.8191286325454712,10000.0,103079.07321691512,106789.75725913048,103079.07321691512,3686.292558431626,12.312384128570557,0.0 -305800,4.3387966,0.5664243,,,,,,,,,,,,,, -305900,4.5272813,0.6319864,,,,,,,,,,,,,, -306000,4.514192,0.55684197,,,,,,,,,,,,,, -306100,4.4067845,0.66719687,,,,,,,,,,,,,, -306200,4.593915,0.5276981,,,,,,,,,,,,,, -306300,4.450912,0.60921663,,,,,,,,,,,,,, -306400,4.5549073,0.6621712,,,,,,,,,,,,,, -306500,4.6995044,0.6383749,,,,,,,,,,,,,, -306600,4.2390966,0.5603076,,,,,,,,,,,,,, -306700,5.1030283,0.7189934,,,,,,,,,,,,,, -306800,4.3227296,0.66725564,,,,,,,,,,,,,, -306900,4.4731207,0.64003384,,,,,,,,,,,,,, -307000,4.520141,0.6201783,,,,,,,,,,,,,, -307100,4.1840897,0.58514196,,,,,,,,,,,,,, -307200,4.56504,0.6094186,,,,,,,,,,,,,, -307264,,,0.9594826102256776,0.1492210626602172,0.7557599544525146,1.0515639781951904,50000.0,0.6309000253677368,1.8189918994903564,10000.0,103589.23261809348,107318.24916386604,103589.23261809348,3704.487591743469,12.390041828155518,0.0 -307300,4.601661,0.60575116,,,,,,,,,,,,,, -307400,4.142818,0.59447145,,,,,,,,,,,,,, -307500,4.2062635,0.5828183,,,,,,,,,,,,,, -307600,4.5125937,0.645793,,,,,,,,,,,,,, -307700,4.731214,0.63966364,,,,,,,,,,,,,, -307800,4.5186267,0.69238526,,,,,,,,,,,,,, -307900,4.604223,0.61846113,,,,,,,,,,,,,, -308000,4.2761397,0.59804714,,,,,,,,,,,,,, -308100,4.532309,0.6765108,,,,,,,,,,,,,, -308200,4.5187945,0.62112963,,,,,,,,,,,,,, -308300,4.3174143,0.62249607,,,,,,,,,,,,,, -308400,4.551288,0.64800626,,,,,,,,,,,,,, -308500,4.4200835,0.50047493,,,,,,,,,,,,,, -308600,5.0187006,0.69014794,,,,,,,,,,,,,, -308700,4.075599,0.62732536,,,,,,,,,,,,,, -308777,,,0.960379421710968,0.1470292061567306,0.7558799982070923,1.051123023033142,50000.0,0.6315000057220459,1.8163002729415887,10000.0,104099.11788463593,107845.5293803215,104099.11788463593,3721.745894193649,12.466761350631714,0.0 -308800,4.705901,0.6358043,,,,,,,,,,,,,, -308900,4.4518285,0.6260919,,,,,,,,,,,,,, -309000,4.237408,0.5924839,,,,,,,,,,,,,, -309100,4.5748158,0.78205484,,,,,,,,,,,,,, -309200,4.3364353,0.611545,,,,,,,,,,,,,, -309300,4.482123,0.6035838,,,,,,,,,,,,,, -309400,4.6466017,0.64530027,,,,,,,,,,,,,, -309500,5.063292,0.64861876,,,,,,,,,,,,,, -309600,4.9495053,0.6279213,,,,,,,,,,,,,, -309700,4.4457717,0.56463563,,,,,,,,,,,,,, -309800,4.7671995,0.5964001,,,,,,,,,,,,,, -309900,4.372184,0.62111914,,,,,,,,,,,,,, -310000,3.9492095,0.55930847,,,,,,,,,,,,,, -310100,4.9575014,0.6657295,,,,,,,,,,,,,, -310200,4.183927,0.5942714,,,,,,,,,,,,,, -310291,,,0.959741711616516,0.148784339427948,0.7563199996948242,1.0511804819107056,50000.0,0.6304000020027161,1.81773841381073,10000.0,104609.262250185,108372.77680826189,104609.262250185,3738.711463928223,12.54440450668335,0.0 -310300,4.6801667,0.6606847,,,,,,,,,,,,,, -310400,4.2798057,0.616736,,,,,,,,,,,,,, -310500,4.5793705,0.6341634,,,,,,,,,,,,,, -310600,4.3876257,0.6233275,,,,,,,,,,,,,, -310700,5.058077,0.640546,,,,,,,,,,,,,, -310800,4.397342,0.6084366,,,,,,,,,,,,,, -310900,4.5701685,0.6503985,,,,,,,,,,,,,, -311000,4.4496255,0.6228329,,,,,,,,,,,,,, -311100,5.0726953,0.6105308,,,,,,,,,,,,,, -311200,4.4857154,0.5766894,,,,,,,,,,,,,, -311300,4.9226575,0.6939508,,,,,,,,,,,,,, -311400,4.665937,0.64513695,,,,,,,,,,,,,, -311500,5.0240245,0.57475626,,,,,,,,,,,,,, -311600,4.2378078,0.6614217,,,,,,,,,,,,,, -311700,4.493013,0.6186883,,,,,,,,,,,,,, -311800,4.232206,0.5953108,,,,,,,,,,,,,, -311805,,,0.9624720811843872,0.1441033333539962,0.756119966506958,1.0523790121078491,50000.0,0.631100058555603,1.81870186328888,10000.0,105119.16359496117,108900.00142145155,105119.16359496117,3755.8963787555695,12.624755382537842,0.0 -311900,5.1975193,0.61553425,,,,,,,,,,,,,, -312000,4.099677,0.55225,,,,,,,,,,,,,, -312100,4.9336553,0.68857735,,,,,,,,,,,,,, -312200,4.744914,0.6809157,,,,,,,,,,,,,, -312300,4.8637114,0.6931895,,,,,,,,,,,,,, -312400,4.4011683,0.6464274,,,,,,,,,,,,,, -312500,4.5725603,0.61711234,,,,,,,,,,,,,, -312600,4.5231333,0.66847354,,,,,,,,,,,,,, -312700,4.0106263,0.5532283,,,,,,,,,,,,,, -312800,4.554525,0.67227083,,,,,,,,,,,,,, -312900,4.3269343,0.5808821,,,,,,,,,,,,,, -313000,4.6917343,0.66241956,,,,,,,,,,,,,, -313100,4.492673,0.6299576,,,,,,,,,,,,,, -313200,4.7814946,0.7251899,,,,,,,,,,,,,, -313300,4.610501,0.5877483,,,,,,,,,,,,,, -313319,,,0.9609375,0.1451739966869354,0.7560399770736694,1.0512357950210571,50000.0,0.6314000487327576,1.8169692754745483,10000.0,105629.2938606739,109427.20877575874,105629.2938606739,3772.8295102119446,12.706658840179443,0.0 -313400,4.372318,0.6475497,,,,,,,,,,,,,, -313500,4.4859977,0.6047346,,,,,,,,,,,,,, -313600,4.4474573,0.64574015,,,,,,,,,,,,,, -313700,4.5644817,0.6285411,,,,,,,,,,,,,, -313800,4.6063514,0.56897396,,,,,,,,,,,,,, -313900,4.541167,0.64114606,,,,,,,,,,,,,, -314000,4.243269,0.6000864,,,,,,,,,,,,,, -314100,4.35798,0.567119,,,,,,,,,,,,,, -314200,4.4781456,0.66362107,,,,,,,,,,,,,, -314300,4.548386,0.6185519,,,,,,,,,,,,,, -314400,4.185141,0.57638454,,,,,,,,,,,,,, -314500,4.706252,0.704102,,,,,,,,,,,,,, -314600,4.264661,0.5921306,,,,,,,,,,,,,, -314700,4.5090904,0.66477937,,,,,,,,,,,,,, -314800,4.9510508,0.64195746,,,,,,,,,,,,,, -314833,,,0.9600605964660645,0.1482522189617157,0.755620002746582,1.052326798439026,50000.0,0.632900059223175,1.8191454410552976,10000.0,106139.17324519156,109954.76312589644,106139.17324519156,3790.351084470749,12.800882577896118,0.0 -314900,4.216503,0.5641762,,,,,,,,,,,,,, -315000,4.311788,0.59230465,,,,,,,,,,,,,, -315100,4.7335563,0.6004219,,,,,,,,,,,,,, -315200,4.399158,0.5503322,,,,,,,,,,,,,, -315300,4.1805987,0.59448946,,,,,,,,,,,,,, -315400,4.472718,0.6376617,,,,,,,,,,,,,, -315500,4.2614336,0.6406235,,,,,,,,,,,,,, -315600,4.5357003,0.66307235,,,,,,,,,,,,,, -315700,4.7827287,0.67211914,,,,,,,,,,,,,, -315800,4.45404,0.6456374,,,,,,,,,,,,,, -315900,4.507747,0.60452193,,,,,,,,,,,,,, -316000,5.3862133,0.6631975,,,,,,,,,,,,,, -316100,4.4544334,0.6511495,,,,,,,,,,,,,, -316200,4.726234,0.6446409,,,,,,,,,,,,,, -316300,4.27748,0.66056883,,,,,,,,,,,,,, -316345,,,0.960379421710968,0.1486100107431411,0.7559799551963806,1.0517127513885498,50000.0,0.631600022315979,1.8172281980514529,10000.0,106649.0578083992,110481.89925575256,106649.0578083992,3807.474318265915,12.86822748184204,0.0 -316400,4.600551,0.58679724,,,,,,,,,,,,,, -316500,4.3216677,0.7016207,,,,,,,,,,,,,, -316600,4.386678,0.61502534,,,,,,,,,,,,,, -316700,4.570438,0.6311833,,,,,,,,,,,,,, -316800,4.7376575,0.63873374,,,,,,,,,,,,,, -316900,4.1340933,0.539829,,,,,,,,,,,,,, -317000,4.2295723,0.62152946,,,,,,,,,,,,,, -317100,4.4311156,0.6281679,,,,,,,,,,,,,, -317200,4.1911163,0.53474075,,,,,,,,,,,,,, -317300,4.5311136,0.5840813,,,,,,,,,,,,,, -317400,4.564125,0.66041523,,,,,,,,,,,,,, -317500,4.1984997,0.5513346,,,,,,,,,,,,,, -317600,4.7603607,0.630265,,,,,,,,,,,,,, -317700,4.467542,0.6368825,,,,,,,,,,,,,, -317800,4.8363423,0.6493018,,,,,,,,,,,,,, -317860,,,0.9601004123687744,0.1470013409852981,0.7554199695587158,1.0530093908309937,50000.0,0.6304000020027161,1.817755937576294,10000.0,107159.1848680973,111009.16904592514,107159.1848680973,3824.4763877391815,12.949540615081789,0.0 -317900,4.200447,0.5713023,,,,,,,,,,,,,, -318000,4.18629,0.60567755,,,,,,,,,,,,,, -318100,4.321908,0.6651321,,,,,,,,,,,,,, -318200,4.627618,0.65489775,,,,,,,,,,,,,, -318300,4.4868736,0.5956242,,,,,,,,,,,,,, -318400,4.5094595,0.6386293,,,,,,,,,,,,,, -318500,4.205821,0.6625081,,,,,,,,,,,,,, -318600,4.440231,0.5685091,,,,,,,,,,,,,, -318700,4.2463326,0.584438,,,,,,,,,,,,,, -318800,4.259385,0.5605327,,,,,,,,,,,,,, -318900,4.379684,0.5946797,,,,,,,,,,,,,, -319000,4.27158,0.6358859,,,,,,,,,,,,,, -319100,4.8218193,0.6435803,,,,,,,,,,,,,, -319200,4.44404,0.5875933,,,,,,,,,,,,,, -319300,4.5816693,0.65644544,,,,,,,,,,,,,, -319373,,,0.9612563848495485,0.1443883031606674,0.7559799551963806,1.051807880401611,50000.0,0.6319000124931335,1.8153104782104488,10000.0,107669.21872401236,111536.4730143547,107669.21872401236,3841.6000397205353,13.036304712295532,0.0 -319400,3.9899297,0.58942866,,,,,,,,,,,,,, -319500,4.4092093,0.6088623,,,,,,,,,,,,,, -319600,4.468054,0.56376654,,,,,,,,,,,,,, -319700,4.3933067,0.6362697,,,,,,,,,,,,,, -319800,4.531147,0.60095966,,,,,,,,,,,,,, -319900,4.2073236,0.54908824,,,,,,,,,,,,,, -320000,5.126796,0.6512261,,,,,,,,,,,,,, -320100,4.4227333,0.6149793,,,,,,,,,,,,,, -320200,4.2900724,0.6301802,,,,,,,,,,,,,, -320300,4.322738,0.6457754,,,,,,,,,,,,,, -320400,4.462427,0.584033,,,,,,,,,,,,,, -320500,4.6007547,0.6689232,,,,,,,,,,,,,, -320600,4.7505317,0.5441195,,,,,,,,,,,,,, -320700,4.856023,0.6820939,,,,,,,,,,,,,, -320800,4.4019375,0.6173891,,,,,,,,,,,,,, -320887,,,0.961136758327484,0.1468666344881057,0.7561799883842468,1.0516546964645386,50000.0,0.6313000321388245,1.815914750099182,10000.0,108179.18863844872,112063.77143478394,108179.18863844872,3858.7794704437256,13.124491453170776,0.0 -320900,4.675238,0.5518259,,,,,,,,,,,,,, -321000,4.7015305,0.68726087,,,,,,,,,,,,,, -321100,4.8884697,0.64813066,,,,,,,,,,,,,, -321200,3.928102,0.6028736,,,,,,,,,,,,,, -321300,5.2567606,0.64932716,,,,,,,,,,,,,, -321400,4.0997124,0.58992386,,,,,,,,,,,,,, -321500,4.144302,0.59553254,,,,,,,,,,,,,, -321600,4.4947658,0.584587,,,,,,,,,,,,,, -321700,4.452639,0.59060264,,,,,,,,,,,,,, -321800,4.4021754,0.6028544,,,,,,,,,,,,,, -321900,4.6385374,0.6121687,,,,,,,,,,,,,, -322000,5.2365246,0.6169034,,,,,,,,,,,,,, -322100,4.2859435,0.5921302,,,,,,,,,,,,,, -322200,4.619876,0.6456554,,,,,,,,,,,,,, -322300,4.63932,0.6789065,,,,,,,,,,,,,, -322400,,,0.9613958597183228,0.1446283608675003,0.7553799748420715,1.0537073612213137,50000.0,0.6304000020027161,1.81884515285492,10000.0,108689.0499472618,112591.07172703744,108689.0499472618,3876.076628923416,13.206662654876707,0.0 -322400,5.9020367,0.6418987,,,,,,,,,,,,,, -322500,4.4800854,0.6161021,,,,,,,,,,,,,, -322600,4.3881283,0.58313537,,,,,,,,,,,,,, -322700,4.6090546,0.6178009,,,,,,,,,,,,,, -322800,4.3138933,0.63097274,,,,,,,,,,,,,, -322900,4.625309,0.5523479,,,,,,,,,,,,,, -323000,4.6809077,0.66035944,,,,,,,,,,,,,, -323100,4.284291,0.6159489,,,,,,,,,,,,,, -323200,4.4563346,0.6076459,,,,,,,,,,,,,, -323300,4.540272,0.69502264,,,,,,,,,,,,,, -323400,4.6611266,0.6161412,,,,,,,,,,,,,, -323500,4.4654937,0.6515271,,,,,,,,,,,,,, -323600,4.8924484,0.6797767,,,,,,,,,,,,,, -323700,4.412392,0.6528169,,,,,,,,,,,,,, -323800,4.9155927,0.6100299,,,,,,,,,,,,,, -323900,4.787303,0.5944653,,,,,,,,,,,,,, -323914,,,0.9598811864852904,0.1484560519456863,0.7558199763298035,1.0515974760055542,50000.0,0.6319000124931335,1.8185746669769287,10000.0,109199.06824493408,113118.49463415146,109199.06824493408,3893.3367924690247,13.290921926498411,0.0 -324000,4.115952,0.62034565,,,,,,,,,,,,,, -324100,5.4249115,0.65414023,,,,,,,,,,,,,, -324200,4.4332542,0.6583185,,,,,,,,,,,,,, -324300,4.9749312,0.5569357,,,,,,,,,,,,,, -324400,4.1702337,0.6313865,,,,,,,,,,,,,, -324500,4.86767,0.69003266,,,,,,,,,,,,,, -324600,4.773454,0.54579794,,,,,,,,,,,,,, -324700,4.4611335,0.6198374,,,,,,,,,,,,,, -324800,5.19481,0.6573258,,,,,,,,,,,,,, -324900,4.9994597,0.67803276,,,,,,,,,,,,,, -325000,4.452598,0.6218575,,,,,,,,,,,,,, -325100,4.3831005,0.65302426,,,,,,,,,,,,,, -325200,4.8352532,0.58717585,,,,,,,,,,,,,, -325300,4.6782665,0.66915727,,,,,,,,,,,,,, -325400,4.299223,0.6227925,,,,,,,,,,,,,, -325427,,,0.9601402878761292,0.1478699147701263,0.7559799551963806,1.0521526336669922,50000.0,0.6305000185966492,1.8173801898956297,10000.0,109708.97253608704,113645.7896540165,109708.97253608704,3910.58306145668,13.375164985656738,0.0 -325500,4.7809954,0.6113139,,,,,,,,,,,,,, -325600,4.68552,0.6836819,,,,,,,,,,,,,, -325700,4.8746014,0.62178123,,,,,,,,,,,,,, -325800,4.626134,0.66206634,,,,,,,,,,,,,, -325900,4.8352885,0.5610196,,,,,,,,,,,,,, -326000,4.4116797,0.6413045,,,,,,,,,,,,,, -326100,4.3153844,0.6627371,,,,,,,,,,,,,, -326200,4.2525897,0.62327,,,,,,,,,,,,,, -326300,4.5027876,0.6199813,,,,,,,,,,,,,, -326400,4.7692385,0.6516816,,,,,,,,,,,,,, -326500,4.4695816,0.6404843,,,,,,,,,,,,,, -326600,4.336636,0.5761662,,,,,,,,,,,,,, -326700,4.9259453,0.6385721,,,,,,,,,,,,,, -326800,5.197036,0.676854,,,,,,,,,,,,,, -326900,5.1280055,0.66736656,,,,,,,,,,,,,, -326941,,,0.9608777165412904,0.1470252126455307,0.7557599544525146,1.0519651174545288,50000.0,0.631600022315979,1.8180609941482544,10000.0,110218.91332650185,114173.03961229324,110218.91332650185,3927.7408850193024,13.466815948486328,0.0 -327000,5.0779595,0.63657135,,,,,,,,,,,,,, -327100,4.4966817,0.6048405,,,,,,,,,,,,,, -327200,4.349212,0.64798754,,,,,,,,,,,,,, -327300,4.1091094,0.64095736,,,,,,,,,,,,,, -327400,4.5914736,0.64454585,,,,,,,,,,,,,, -327500,4.3499327,0.6704939,,,,,,,,,,,,,, -327600,4.589875,0.6543857,,,,,,,,,,,,,, -327700,4.6852574,0.678491,,,,,,,,,,,,,, -327800,4.625118,0.66632,,,,,,,,,,,,,, -327900,4.5644355,0.59473175,,,,,,,,,,,,,, -328000,4.387823,0.6422719,,,,,,,,,,,,,, -328100,4.2251725,0.6214511,,,,,,,,,,,,,, -328200,4.4278116,0.6357044,,,,,,,,,,,,,, -328300,4.5226297,0.67535985,,,,,,,,,,,,,, -328400,4.566714,0.6397175,,,,,,,,,,,,,, -328454,,,0.9614756107330322,0.1458468437194824,0.7557599544525146,1.051677942276001,50000.0,0.6323000192642212,1.8172355890274048,10000.0,110728.96795773506,114700.41973114014,110728.96795773506,3944.9255831241608,13.548635721206663,0.0 -328500,4.50407,0.59774303,,,,,,,,,,,,,, -328600,4.5593677,0.60989946,,,,,,,,,,,,,, -328700,4.556239,0.57298106,,,,,,,,,,,,,, -328800,5.046918,0.6060866,,,,,,,,,,,,,, -328900,4.072428,0.5603622,,,,,,,,,,,,,, -329000,4.5729613,0.67969054,,,,,,,,,,,,,, -329100,5.7865686,0.6908707,,,,,,,,,,,,,, -329200,4.1830354,0.5869713,,,,,,,,,,,,,, -329300,4.8051105,0.7178527,,,,,,,,,,,,,, -329400,4.3789797,0.64739114,,,,,,,,,,,,,, -329500,4.544914,0.6438356,,,,,,,,,,,,,, -329600,4.648969,0.6955427,,,,,,,,,,,,,, -329700,3.983659,0.5797929,,,,,,,,,,,,,, -329800,4.90184,0.6533626,,,,,,,,,,,,,, -329900,4.325839,0.6675972,,,,,,,,,,,,,, -329968,,,0.9604192972183228,0.1479916721582412,0.7558199763298035,1.0517678260803225,50000.0,0.6317000389099121,1.816851019859314,10000.0,111238.9862473011,115227.71682953836,111238.9862473011,3962.061777353287,13.631106853485107,0.0 -330000,4.097106,0.5945237,,,,,,,,,,,,,, -330100,4.3697615,0.6083987,,,,,,,,,,,,,, -330200,4.8835263,0.6999708,,,,,,,,,,,,,, -330300,4.293478,0.6120289,,,,,,,,,,,,,, -330400,4.43554,0.64787954,,,,,,,,,,,,,, -330500,4.608046,0.631684,,,,,,,,,,,,,, -330600,4.374659,0.62188864,,,,,,,,,,,,,, -330700,4.8442044,0.6387215,,,,,,,,,,,,,, -330800,3.9629846,0.5445788,,,,,,,,,,,,,, -330900,4.138249,0.56586045,,,,,,,,,,,,,, -331000,4.377921,0.64463615,,,,,,,,,,,,,, -331100,4.516192,0.6730547,,,,,,,,,,,,,, -331200,4.3297086,0.62440926,,,,,,,,,,,,,, -331300,4.915101,0.6378514,,,,,,,,,,,,,, -331400,4.3192015,0.62306786,,,,,,,,,,,,,, -331481,,,0.9598811864852904,0.1524875611066818,0.7560999989509583,1.0509434938430786,50000.0,0.631100058555603,1.8166096210479736,10000.0,111749.04176855087,115754.84624505044,111749.04176855087,3978.993365049362,13.714733600616457,0.0 -331500,4.3651795,0.5456847,,,,,,,,,,,,,, -331600,4.1592193,0.6300912,,,,,,,,,,,,,, -331700,4.5399966,0.5600271,,,,,,,,,,,,,, -331800,5.2068486,0.6164319,,,,,,,,,,,,,, -331900,4.376365,0.62497896,,,,,,,,,,,,,, -332000,4.8403845,0.6608111,,,,,,,,,,,,,, -332100,3.822904,0.56949043,,,,,,,,,,,,,, -332200,4.878244,0.59342456,,,,,,,,,,,,,, -332300,4.631088,0.65942216,,,,,,,,,,,,,, -332400,4.4857283,0.5945222,,,,,,,,,,,,,, -332500,4.40137,0.6492908,,,,,,,,,,,,,, -332600,4.4668612,0.6445619,,,,,,,,,,,,,, -332700,4.02983,0.59728694,,,,,,,,,,,,,, -332800,4.6478457,0.59816587,,,,,,,,,,,,,, -332900,4.2233734,0.5782392,,,,,,,,,,,,,, -332995,,,0.961734652519226,0.14434514939785,0.7558000087738037,1.0530338287353516,50000.0,0.6313000321388245,1.818378210067749,10000.0,112258.91333699226,116281.9922375679,112258.91333699226,3996.122615098953,13.801108837127686,0.0 -333000,4.608274,0.601368,,,,,,,,,,,,,, -333100,4.5100565,0.6502781,,,,,,,,,,,,,, -333200,4.989475,0.62513596,,,,,,,,,,,,,, -333300,4.744852,0.6149423,,,,,,,,,,,,,, -333400,4.5882535,0.636926,,,,,,,,,,,,,, -333500,4.4626327,0.59426767,,,,,,,,,,,,,, -333600,5.0275307,0.6943463,,,,,,,,,,,,,, -333700,4.6289315,0.6698309,,,,,,,,,,,,,, -333800,4.345716,0.703808,,,,,,,,,,,,,, -333900,4.704836,0.5653527,,,,,,,,,,,,,, -334000,4.2152505,0.5725741,,,,,,,,,,,,,, -334100,4.4096236,0.60354275,,,,,,,,,,,,,, -334200,4.2485666,0.6394341,,,,,,,,,,,,,, -334300,3.944453,0.55323404,,,,,,,,,,,,,, -334400,4.546339,0.58569616,,,,,,,,,,,,,, -334500,4.3269734,0.6206793,,,,,,,,,,,,,, -334508,,,0.9585060477256776,0.1500162780284881,0.7559999823570251,1.052400827407837,50000.0,0.631100058555603,1.819270372390747,10000.0,112768.79986357687,116809.0741918087,112768.79986357687,4013.174238204956,13.886026620864868,0.0 -334600,4.2649746,0.55645853,,,,,,,,,,,,,, -334700,4.3888626,0.5967486,,,,,,,,,,,,,, -334800,4.5210285,0.63277316,,,,,,,,,,,,,, -334900,4.420468,0.61741257,,,,,,,,,,,,,, -335000,4.802305,0.6349118,,,,,,,,,,,,,, -335100,4.6210804,0.57606184,,,,,,,,,,,,,, -335200,4.4229155,0.6657717,,,,,,,,,,,,,, -335300,4.526055,0.6537304,,,,,,,,,,,,,, -335400,4.2143927,0.5412573,,,,,,,,,,,,,, -335500,4.6951385,0.68064183,,,,,,,,,,,,,, -335600,4.2669997,0.61526555,,,,,,,,,,,,,, -335700,4.806359,0.6906777,,,,,,,,,,,,,, -335800,4.503487,0.54362524,,,,,,,,,,,,,, -335900,4.3845677,0.64456743,,,,,,,,,,,,,, -336000,4.729875,0.6036001,,,,,,,,,,,,,, -336022,,,0.961336076259613,0.1454212963581085,0.7561999559402466,1.0526785850524902,50000.0,0.631600022315979,1.8194116353988647,10000.0,113278.91282248496,117336.87652873991,113278.91282248496,4030.716554164888,13.973962306976318,0.0 -336100,4.503303,0.6210336,,,,,,,,,,,,,, -336200,4.8528647,0.6207074,,,,,,,,,,,,,, -336300,4.695908,0.64664507,,,,,,,,,,,,,, -336400,4.613738,0.574592,,,,,,,,,,,,,, -336500,4.6812925,0.6491972,,,,,,,,,,,,,, -336600,4.301653,0.64741504,,,,,,,,,,,,,, -336700,4.862434,0.5613106,,,,,,,,,,,,,, -336800,4.6849604,0.64838475,,,,,,,,,,,,,, -336900,4.553692,0.60928625,,,,,,,,,,,,,, -337000,4.325275,0.5827045,,,,,,,,,,,,,, -337100,4.500874,0.62887233,,,,,,,,,,,,,, -337200,4.839661,0.6105009,,,,,,,,,,,,,, -337300,4.1747503,0.567337,,,,,,,,,,,,,, -337400,4.237851,0.59189314,,,,,,,,,,,,,, -337500,4.2515306,0.62697154,,,,,,,,,,,,,, -337536,,,0.9607979655265808,0.1482814997434616,0.7560799717903137,1.0519975423812866,50000.0,0.6317000389099121,1.8193473815917969,10000.0,113789.0877354145,117864.37035274506,113789.0877354145,4047.8868165016174,14.062806129455566,0.0 -337600,4.9238863,0.7171486,,,,,,,,,,,,,, -337700,4.7360454,0.65698445,,,,,,,,,,,,,, -337800,4.7405086,0.6281147,,,,,,,,,,,,,, -337900,4.1544905,0.7107104,,,,,,,,,,,,,, -338000,4.534075,0.63474065,,,,,,,,,,,,,, -338100,4.5648093,0.63378596,,,,,,,,,,,,,, -338200,4.833314,0.6475404,,,,,,,,,,,,,, -338300,4.2931275,0.58452725,,,,,,,,,,,,,, -338400,4.4076123,0.65344226,,,,,,,,,,,,,, -338500,4.35559,0.62654346,,,,,,,,,,,,,, -338600,4.3609514,0.6627526,,,,,,,,,,,,,, -338700,5.121256,0.6713496,,,,,,,,,,,,,, -338800,4.6139693,0.5536266,,,,,,,,,,,,,, -338900,4.247253,0.5799903,,,,,,,,,,,,,, -339000,4.6314936,0.6424623,,,,,,,,,,,,,, -339051,,,0.959741711616516,0.1503542214632034,0.7552399635314941,1.052941083908081,50000.0,0.6312000155448914,1.817227244377136,10000.0,114299.129904747,118391.57015228271,114299.129904747,4064.900999069214,14.147011280059814,0.0 -339100,4.495675,0.5818885,,,,,,,,,,,,,, -339200,4.425459,0.6521021,,,,,,,,,,,,,, -339300,4.6196446,0.68524325,,,,,,,,,,,,,, -339400,4.8169723,0.6585742,,,,,,,,,,,,,, -339500,4.7990565,0.6198142,,,,,,,,,,,,,, -339600,4.285749,0.62885845,,,,,,,,,,,,,, -339700,4.778104,0.71906984,,,,,,,,,,,,,, -339800,4.190309,0.62107635,,,,,,,,,,,,,, -339900,4.4414763,0.6020829,,,,,,,,,,,,,, -340000,4.837119,0.65156364,,,,,,,,,,,,,, -340100,4.4381943,0.61820596,,,,,,,,,,,,,, -340200,4.531872,0.6852404,,,,,,,,,,,,,, -340300,3.9273276,0.5477781,,,,,,,,,,,,,, -340400,4.5936837,0.66634375,,,,,,,,,,,,,, -340500,4.061492,0.5704487,,,,,,,,,,,,,, -340565,,,0.9613958597183228,0.1452462524175644,0.7559999823570251,1.053504467010498,50000.0,0.6317000389099121,1.819175243377685,10000.0,114809.09057998656,118918.73136019708,114809.09057998656,4081.9577882289886,14.231061697006226,0.0 -340600,4.068377,0.5308907,,,,,,,,,,,,,, -340700,4.467339,0.6338573,,,,,,,,,,,,,, -340800,4.5989146,0.61644876,,,,,,,,,,,,,, -340900,5.4989195,0.6509631,,,,,,,,,,,,,, -341000,4.152306,0.5852532,,,,,,,,,,,,,, -341100,4.6413655,0.6045556,,,,,,,,,,,,,, -341200,4.678745,0.64048517,,,,,,,,,,,,,, -341300,4.416329,0.64953536,,,,,,,,,,,,,, -341400,4.406301,0.649104,,,,,,,,,,,,,, -341500,4.5499144,0.6052009,,,,,,,,,,,,,, -341600,4.6415176,0.6861665,,,,,,,,,,,,,, -341700,4.7949295,0.6232833,,,,,,,,,,,,,, -341800,4.203076,0.5922803,,,,,,,,,,,,,, -341900,4.333049,0.6024476,,,,,,,,,,,,,, -342000,4.5891643,0.6590691,,,,,,,,,,,,,, -342080,,,0.9623724222183228,0.1430144309997558,0.7557199597358704,1.0525835752487185,50000.0,0.6309000253677368,1.8193016052246087,10000.0,115319.19979858398,119446.07986211775,115319.19979858398,4099.052008152008,14.317641258239746,0.0 -342100,4.4634223,0.5916314,,,,,,,,,,,,,, -342200,4.2816324,0.60054195,,,,,,,,,,,,,, -342300,4.048176,0.6022324,,,,,,,,,,,,,, -342400,4.5477285,0.660259,,,,,,,,,,,,,, -342500,4.695623,0.6070962,,,,,,,,,,,,,, -342600,4.509483,0.55255115,,,,,,,,,,,,,, -342700,4.206652,0.5947214,,,,,,,,,,,,,, -342800,4.3785734,0.6306911,,,,,,,,,,,,,, -342900,4.199922,0.5722889,,,,,,,,,,,,,, -343000,4.2457223,0.6168661,,,,,,,,,,,,,, -343100,4.619743,0.5820798,,,,,,,,,,,,,, -343200,4.3691525,0.6269768,,,,,,,,,,,,,, -343300,4.6521482,0.6471872,,,,,,,,,,,,,, -343400,4.2626457,0.5788396,,,,,,,,,,,,,, -343500,4.120019,0.567471,,,,,,,,,,,,,, -343594,,,0.960598647594452,0.1461528092622757,0.7560799717903137,1.0523217916488647,50000.0,0.631600022315979,1.8172719478607176,10000.0,115829.36331558228,119973.48504543304,115829.36331558228,4116.148473501205,14.404023885726929,0.0 -343600,4.2090306,0.5993399,,,,,,,,,,,,,, -343700,4.155104,0.6153204,,,,,,,,,,,,,, -343800,3.9972777,0.57964927,,,,,,,,,,,,,, -343900,5.0817676,0.68133664,,,,,,,,,,,,,, -344000,4.7481904,0.62412584,,,,,,,,,,,,,, -344100,4.2674603,0.5747607,,,,,,,,,,,,,, -344200,4.7151012,0.72387636,,,,,,,,,,,,,, -344300,4.3423867,0.6511879,,,,,,,,,,,,,, -344400,4.5954776,0.5881167,,,,,,,,,,,,,, -344500,4.381102,0.6108333,,,,,,,,,,,,,, -344600,5.213326,0.6628892,,,,,,,,,,,,,, -344700,4.326616,0.56846607,,,,,,,,,,,,,, -344800,4.2746267,0.6244609,,,,,,,,,,,,,, -344900,4.1884656,0.6089575,,,,,,,,,,,,,, -345000,4.8301005,0.69732195,,,,,,,,,,,,,, -345100,4.1291394,0.6568706,,,,,,,,,,,,,, -345108,,,0.960339605808258,0.1466981768608093,0.755899965763092,1.0524567365646362,50000.0,0.6304000020027161,1.8181381225585933,10000.0,116339.2823779583,120500.41694951056,116339.2823779583,4133.013184070587,14.49116039276123,0.0 -345200,4.0636387,0.5801676,,,,,,,,,,,,,, -345300,5.009925,0.6105383,,,,,,,,,,,,,, -345400,4.6447315,0.6487771,,,,,,,,,,,,,, -345500,5.1595078,0.7002398,,,,,,,,,,,,,, -345600,4.5257,0.65481126,,,,,,,,,,,,,, -345700,4.3004527,0.6536833,,,,,,,,,,,,,, -345800,4.333614,0.5845978,,,,,,,,,,,,,, -345900,4.7079654,0.6449787,,,,,,,,,,,,,, -346000,4.4246893,0.6463121,,,,,,,,,,,,,, -346100,4.5194187,0.64073336,,,,,,,,,,,,,, -346200,4.4784894,0.61888736,,,,,,,,,,,,,, -346300,4.2903624,0.57046616,,,,,,,,,,,,,, -346400,4.3606586,0.64093745,,,,,,,,,,,,,, -346500,5.03543,0.6856103,,,,,,,,,,,,,, -346600,4.474609,0.6554174,,,,,,,,,,,,,, -346622,,,0.9592235088348388,0.148650661110878,0.7559799551963806,1.051628589630127,50000.0,0.6312000155448914,1.817665100097656,10000.0,116849.27911138536,121028.6317665577,116849.27911138536,4151.086514234543,14.57667088508606,0.0 -346700,4.3016157,0.6363816,,,,,,,,,,,,,, -346800,4.485559,0.7400448,,,,,,,,,,,,,, -346900,4.71285,0.6257773,,,,,,,,,,,,,, -347000,5.0802817,0.6722127,,,,,,,,,,,,,, -347100,4.744091,0.63833994,,,,,,,,,,,,,, -347200,4.4228387,0.62511575,,,,,,,,,,,,,, -347300,5.572302,0.72032696,,,,,,,,,,,,,, -347400,4.207958,0.55343235,,,,,,,,,,,,,, -347500,4.4102283,0.62477064,,,,,,,,,,,,,, -347600,4.732698,0.6363331,,,,,,,,,,,,,, -347700,4.693924,0.6522193,,,,,,,,,,,,,, -347800,4.5447893,0.6381389,,,,,,,,,,,,,, -347900,4.4710245,0.6408597,,,,,,,,,,,,,, -348000,4.752021,0.61426973,,,,,,,,,,,,,, -348100,4.567299,0.6222316,,,,,,,,,,,,,, -348136,,,0.9605388641357422,0.1495785564184188,0.7561599612236023,1.0520284175872805,50000.0,0.6324000358581543,1.816881895065308,10000.0,117359.37034726144,121555.94112229349,117359.37034726144,4168.160834312439,14.661510467529297,0.0 -348200,4.574088,0.65040547,,,,,,,,,,,,,, -348300,4.743966,0.6549862,,,,,,,,,,,,,, -348400,5.181365,0.6126242,,,,,,,,,,,,,, -348500,4.5198236,0.5652525,,,,,,,,,,,,,, -348600,4.363583,0.6100288,,,,,,,,,,,,,, -348700,4.3804836,0.66375726,,,,,,,,,,,,,, -348800,4.529029,0.5626559,,,,,,,,,,,,,, -348900,4.4629793,0.5911262,,,,,,,,,,,,,, -349000,4.6044655,0.6957942,,,,,,,,,,,,,, -349100,4.578485,0.6160536,,,,,,,,,,,,,, -349200,4.99035,0.6123216,,,,,,,,,,,,,, -349300,4.1948566,0.5846878,,,,,,,,,,,,,, -349400,4.7021837,0.694515,,,,,,,,,,,,,, -349500,4.2550883,0.6184231,,,,,,,,,,,,,, -349600,4.793116,0.6254562,,,,,,,,,,,,,, -349650,,,0.960758090019226,0.1472361534833908,0.7560799717903137,1.0515118837356567,50000.0,0.6306000351905823,1.8176124095916748,10000.0,117869.27389907835,122083.12540221214,117869.27389907835,4185.293612003326,14.747785091400146,0.0 -349700,4.9125,0.6470175,,,,,,,,,,,,,, -349800,4.278381,0.6546926,,,,,,,,,,,,,, -349900,4.8416786,0.6288228,,,,,,,,,,,,,, -350000,4.1302743,0.6458725,,,,,,,,,,,,,, -350100,4.641996,0.620305,,,,,,,,,,,,,, -350200,4.520259,0.619248,,,,,,,,,,,,,, -350300,4.312088,0.6006336,,,,,,,,,,,,,, -350400,4.680745,0.6377726,,,,,,,,,,,,,, -350500,4.5253077,0.6857057,,,,,,,,,,,,,, -350600,4.3952794,0.58537865,,,,,,,,,,,,,, -350700,4.6697726,0.6981542,,,,,,,,,,,,,, -350800,4.8473153,0.6026775,,,,,,,,,,,,,, -350900,4.432761,0.52816844,,,,,,,,,,,,,, -351000,5.0002875,0.5903079,,,,,,,,,,,,,, -351100,4.7685547,0.6919505,,,,,,,,,,,,,, -351164,,,0.9616150856018066,0.145010843873024,0.756339967250824,1.0514135360717771,50000.0,0.6312000155448914,1.8179231882095337,10000.0,118379.28754210472,122610.44742536543,118379.28754210472,4202.453873872757,14.835796117782593,0.0 -351200,4.4130197,0.63758504,,,,,,,,,,,,,, -351300,4.2031903,0.62161076,,,,,,,,,,,,,, -351400,3.922212,0.56279916,,,,,,,,,,,,,, -351500,4.5056458,0.5983549,,,,,,,,,,,,,, -351600,4.6020937,0.6497598,,,,,,,,,,,,,, -351700,4.588605,0.58802015,,,,,,,,,,,,,, -351800,5.3919406,0.6422459,,,,,,,,,,,,,, -351900,4.5408564,0.6144575,,,,,,,,,,,,,, -352000,4.5631804,0.6352493,,,,,,,,,,,,,, -352100,4.6931314,0.6009042,,,,,,,,,,,,,, -352200,4.414574,0.6128886,,,,,,,,,,,,,, -352300,4.3402357,0.61278015,,,,,,,,,,,,,, -352400,4.032361,0.62242293,,,,,,,,,,,,,, -352500,4.450099,0.64035857,,,,,,,,,,,,,, -352600,4.6248264,0.6768857,,,,,,,,,,,,,, -352677,,,0.961355984210968,0.1440159380435943,0.7555800080299377,1.051655650138855,50000.0,0.6313000321388245,1.817135453224182,10000.0,118889.12491345406,123137.43913602828,118889.12491345406,4219.443992614746,14.937951803207396,0.0 -352700,4.387011,0.61712706,,,,,,,,,,,,,, -352800,5.0520597,0.63876295,,,,,,,,,,,,,, -352900,4.697355,0.64493775,,,,,,,,,,,,,, -353000,4.3702545,0.5848667,,,,,,,,,,,,,, -353100,5.2307477,0.66676795,,,,,,,,,,,,,, -353200,4.2932024,0.65384877,,,,,,,,,,,,,, -353300,4.439137,0.64981395,,,,,,,,,,,,,, -353400,4.389227,0.6757413,,,,,,,,,,,,,, -353500,4.7315173,0.61641073,,,,,,,,,,,,,, -353600,4.747996,0.6503727,,,,,,,,,,,,,, -353700,4.724851,0.628573,,,,,,,,,,,,,, -353800,4.366166,0.5864589,,,,,,,,,,,,,, -353900,4.3831296,0.6276216,,,,,,,,,,,,,, -354000,4.397693,0.61747646,,,,,,,,,,,,,, -354100,4.360209,0.53432995,,,,,,,,,,,,,, -354191,,,0.9601203799247742,0.1484006345272064,0.7557599544525146,1.0514042377471924,50000.0,0.6313000321388245,1.8169622421264648,10000.0,119399.18874073029,123664.57041716576,119399.18874073029,4236.364371299744,15.025042533874512,0.0 -354200,5.1380086,0.6669148,,,,,,,,,,,,,, -354300,4.6983147,0.6160045,,,,,,,,,,,,,, -354400,4.3402095,0.65342534,,,,,,,,,,,,,, -354500,4.7996554,0.6138205,,,,,,,,,,,,,, -354600,4.1966424,0.61774564,,,,,,,,,,,,,, -354700,4.6484995,0.6426804,,,,,,,,,,,,,, -354800,4.293424,0.62564814,,,,,,,,,,,,,, -354900,4.9241457,0.5889009,,,,,,,,,,,,,, -355000,4.694577,0.5925147,,,,,,,,,,,,,, -355100,4.478398,0.5987237,,,,,,,,,,,,,, -355200,4.670333,0.71771127,,,,,,,,,,,,,, -355300,4.5651765,0.6139531,,,,,,,,,,,,,, -355400,4.162049,0.5210642,,,,,,,,,,,,,, -355500,4.2769885,0.60141593,,,,,,,,,,,,,, -355600,5.107187,0.6304356,,,,,,,,,,,,,, -355700,4.961459,0.63332796,,,,,,,,,,,,,, -355704,,,0.9602997303009032,0.1497524827718734,0.7556799650192261,1.0527878999710083,50000.0,0.6308000087738037,1.818961262702942,10000.0,119909.20295524596,124191.95948266984,119909.20295524596,4253.577853441238,15.125241994857788,0.0 -355800,4.0668087,0.54986906,,,,,,,,,,,,,, -355900,4.2564163,0.6400212,,,,,,,,,,,,,, -356000,4.382546,0.6149629,,,,,,,,,,,,,, -356100,4.4682255,0.6048401,,,,,,,,,,,,,, -356200,4.874779,0.6526054,,,,,,,,,,,,,, -356300,4.9039087,0.7398531,,,,,,,,,,,,,, -356400,5.6596036,0.70011276,,,,,,,,,,,,,, -356500,4.633612,0.6035838,,,,,,,,,,,,,, -356600,4.301538,0.6196423,,,,,,,,,,,,,, -356700,4.7625704,0.5553297,,,,,,,,,,,,,, -356800,4.3487115,0.6408736,,,,,,,,,,,,,, -356900,4.109542,0.64214104,,,,,,,,,,,,,, -357000,4.4731646,0.62423235,,,,,,,,,,,,,, -357100,4.25398,0.6078418,,,,,,,,,,,,,, -357200,5.049681,0.6536448,,,,,,,,,,,,,, -357218,,,0.9614157676696776,0.1427268236875534,0.7556999921798706,1.0533047914505005,50000.0,0.6312000155448914,1.820838928222656,10000.0,120419.31431365012,124719.37481164932,120419.31431365012,4270.738157272339,15.209940195083618,0.0 -357300,4.603366,0.6570854,,,,,,,,,,,,,, -357400,4.299489,0.5785404,,,,,,,,,,,,,, -357500,4.420279,0.57866216,,,,,,,,,,,,,, -357600,4.69461,0.6146318,,,,,,,,,,,,,, -357700,4.222199,0.60440487,,,,,,,,,,,,,, -357800,4.3245273,0.5880522,,,,,,,,,,,,,, -357900,4.1909103,0.6749267,,,,,,,,,,,,,, -358000,4.982807,0.68438315,,,,,,,,,,,,,, -358100,4.686878,0.6493144,,,,,,,,,,,,,, -358200,4.6506476,0.60346365,,,,,,,,,,,,,, -358300,4.4151816,0.6329633,,,,,,,,,,,,,, -358400,4.071156,0.5721752,,,,,,,,,,,,,, -358500,4.776544,0.6756002,,,,,,,,,,,,,, -358600,4.602329,0.66817546,,,,,,,,,,,,,, -358700,4.841377,0.60913444,,,,,,,,,,,,,, -358732,,,0.96097731590271,0.1458709090948104,0.7553600072860718,1.052885890007019,50000.0,0.6318000555038452,1.8178112506866453,10000.0,120929.3461754322,125246.63512897491,120929.3461754322,4287.819470405579,15.2974956035614,0.0 -358800,4.9902534,0.7534279,,,,,,,,,,,,,, -358900,4.9844394,0.7302676,,,,,,,,,,,,,, -359000,4.4615073,0.5596358,,,,,,,,,,,,,, -359100,4.0598497,0.5836621,,,,,,,,,,,,,, -359200,4.590667,0.61391014,,,,,,,,,,,,,, -359300,4.668153,0.68855244,,,,,,,,,,,,,, -359400,4.180534,0.5825091,,,,,,,,,,,,,, -359500,4.407434,0.65337455,,,,,,,,,,,,,, -359600,4.206876,0.61284083,,,,,,,,,,,,,, -359700,4.4979534,0.5834129,,,,,,,,,,,,,, -359800,4.3888574,0.63515896,,,,,,,,,,,,,, -359900,4.646316,0.6111742,,,,,,,,,,,,,, -360000,4.823476,0.6263283,,,,,,,,,,,,,, -360100,4.4617925,0.59138674,,,,,,,,,,,,,, -360200,5.2288666,0.7896461,,,,,,,,,,,,,, -360247,,,0.9602997303009032,0.1472483873367309,0.7559599876403809,1.0523381233215332,50000.0,0.6318000555038452,1.817411184310913,10000.0,121439.44315695764,125773.91371035576,121439.44315695764,4304.847638845444,15.390103816986084,0.0 -360300,4.269696,0.5740687,,,,,,,,,,,,,, -360400,4.2395062,0.57181233,,,,,,,,,,,,,, -360500,4.582274,0.61329776,,,,,,,,,,,,,, -360600,4.363187,0.56956196,,,,,,,,,,,,,, -360700,4.251656,0.6045358,,,,,,,,,,,,,, -360800,4.461681,0.68440545,,,,,,,,,,,,,, -360900,3.8667212,0.53213584,,,,,,,,,,,,,, -361000,4.381277,0.61750686,,,,,,,,,,,,,, -361100,4.4069247,0.664567,,,,,,,,,,,,,, -361200,4.6326504,0.6459955,,,,,,,,,,,,,, -361300,4.4994774,0.6533044,,,,,,,,,,,,,, -361400,4.0763397,0.5830457,,,,,,,,,,,,,, -361500,4.4853125,0.6250173,,,,,,,,,,,,,, -361600,5.1175942,0.72082806,,,,,,,,,,,,,, -361700,4.55173,0.656227,,,,,,,,,,,,,, -361761,,,0.9618144035339355,0.1445937156677246,0.7554599642753601,1.0517603158950806,50000.0,0.6315000057220459,1.817147135734558,10000.0,121949.60582256316,126301.38193583488,121949.60582256316,4322.003366947174,15.480299949645996,0.0 -361800,4.8467817,0.6305414,,,,,,,,,,,,,, -361900,4.5397573,0.663279,,,,,,,,,,,,,, -362000,4.5822134,0.60979426,,,,,,,,,,,,,, -362100,4.3168716,0.58502686,,,,,,,,,,,,,, -362200,4.740304,0.60287106,,,,,,,,,,,,,, -362300,4.0371513,0.62149405,,,,,,,,,,,,,, -362400,4.563976,0.5734764,,,,,,,,,,,,,, -362500,4.3103065,0.6189778,,,,,,,,,,,,,, -362600,4.6197343,0.6054007,,,,,,,,,,,,,, -362700,4.2571225,0.5113306,,,,,,,,,,,,,, -362800,5.5472064,0.6416209,,,,,,,,,,,,,, -362900,4.6786065,0.6662841,,,,,,,,,,,,,, -363000,4.3945007,0.5943691,,,,,,,,,,,,,, -363100,4.3137164,0.60448974,,,,,,,,,,,,,, -363200,4.5589914,0.5601856,,,,,,,,,,,,,, -363275,,,0.9597616195678712,0.1499248892068863,0.7556399703025818,1.0522288084030151,50000.0,0.6308000087738037,1.8169174194335933,10000.0,122459.51914787292,126828.55815386772,122459.51914787292,4339.118178367615,15.568543672561646,0.0 -363300,4.5001483,0.5975904,,,,,,,,,,,,,, -363400,4.647363,0.63884133,,,,,,,,,,,,,, -363500,4.115897,0.58577776,,,,,,,,,,,,,, -363600,4.5841618,0.5861712,,,,,,,,,,,,,, -363700,4.1921883,0.51719224,,,,,,,,,,,,,, -363800,4.5746856,0.5879195,,,,,,,,,,,,,, -363900,4.3717194,0.59678984,,,,,,,,,,,,,, -364000,4.1979537,0.58444,,,,,,,,,,,,,, -364100,4.321239,0.5911689,,,,,,,,,,,,,, -364200,4.5068526,0.58163893,,,,,,,,,,,,,, -364300,4.4114285,0.62815917,,,,,,,,,,,,,, -364400,4.385881,0.62582093,,,,,,,,,,,,,, -364500,5.1260357,0.7028507,,,,,,,,,,,,,, -364600,4.8112082,0.6104536,,,,,,,,,,,,,, -364700,4.7754083,0.6498987,,,,,,,,,,,,,, -364788,,,0.961136758327484,0.145282357931137,0.7560999989509583,1.051038384437561,50000.0,0.6325000524520874,1.817543625831604,10000.0,122969.4705851078,127355.93138194084,122969.4705851078,4356.393066167831,15.656342029571531,0.0 -364800,4.175828,0.6141583,,,,,,,,,,,,,, -364900,4.2873373,0.594908,,,,,,,,,,,,,, -365000,4.742359,0.65017813,,,,,,,,,,,,,, -365100,4.389525,0.598464,,,,,,,,,,,,,, -365200,4.20629,0.6347927,,,,,,,,,,,,,, -365300,4.426644,0.64825094,,,,,,,,,,,,,, -365400,4.872705,0.66212934,,,,,,,,,,,,,, -365500,5.051168,0.6107669,,,,,,,,,,,,,, -365600,4.9151964,0.6139441,,,,,,,,,,,,,, -365700,4.833306,0.7361661,,,,,,,,,,,,,, -365800,4.787036,0.66499555,,,,,,,,,,,,,, -365900,4.3461103,0.5608482,,,,,,,,,,,,,, -366000,4.2923164,0.6191391,,,,,,,,,,,,,, -366100,4.896265,0.6421724,,,,,,,,,,,,,, -366200,4.2814074,0.66703784,,,,,,,,,,,,,, -366300,4.3534546,0.62476236,,,,,,,,,,,,,, -366301,,,0.9606584906578064,0.1480656564235687,0.7558199763298035,1.0525013208389282,50000.0,0.6317000389099121,1.818870186805725,10000.0,123479.34378123283,127882.89841532709,123479.34378123283,4373.334760904312,15.750553607940674,0.0 -366400,4.6373215,0.64719176,,,,,,,,,,,,,, -366500,4.778758,0.6952645,,,,,,,,,,,,,, -366600,4.5304728,0.65168464,,,,,,,,,,,,,, -366700,4.706754,0.59171045,,,,,,,,,,,,,, -366800,4.290656,0.6403935,,,,,,,,,,,,,, -366900,4.2470837,0.6468748,,,,,,,,,,,,,, -367000,4.23569,0.5645299,,,,,,,,,,,,,, -367100,4.678057,0.70849967,,,,,,,,,,,,,, -367200,4.9196324,0.590908,,,,,,,,,,,,,, -367300,4.207514,0.57475907,,,,,,,,,,,,,, -367400,4.5579653,0.592278,,,,,,,,,,,,,, -367500,4.164173,0.59707785,,,,,,,,,,,,,, -367600,4.789642,0.59092635,,,,,,,,,,,,,, -367700,4.8861814,0.6761022,,,,,,,,,,,,,, -367800,4.5445175,0.60660714,,,,,,,,,,,,,, -367814,,,0.961156725883484,0.1470474749803543,0.7556599974632263,1.0533629655838013,50000.0,0.6315000057220459,1.8194811344146729,10000.0,123989.37150287628,128410.23392796516,123989.37150287628,4390.496916055679,15.836224555969238,0.0 -367900,4.78314,0.5845348,,,,,,,,,,,,,, -368000,4.6098585,0.66956115,,,,,,,,,,,,,, -368100,4.7583475,0.6777728,,,,,,,,,,,,,, -368200,4.3834286,0.60773975,,,,,,,,,,,,,, -368300,4.632511,0.7332162,,,,,,,,,,,,,, -368400,4.517886,0.61807257,,,,,,,,,,,,,, -368500,4.411028,0.6231298,,,,,,,,,,,,,, -368600,4.235429,0.57180053,,,,,,,,,,,,,, -368700,4.669776,0.6229739,,,,,,,,,,,,,, -368800,4.4820037,0.684518,,,,,,,,,,,,,, -368900,4.4350033,0.5650982,,,,,,,,,,,,,, -369000,4.2084317,0.65136236,,,,,,,,,,,,,, -369100,4.2686152,0.6498075,,,,,,,,,,,,,, -369200,5.0461164,0.6827357,,,,,,,,,,,,,, -369300,5.1483593,0.6499362,,,,,,,,,,,,,, -369328,,,0.9606983065605164,0.1482247412204742,0.7558599710464478,1.0521869659423828,50000.0,0.6312000155448914,1.8174362182617188,10000.0,124499.26441526412,128937.36512541772,124499.26441526412,4407.583589076996,15.926433086395264,0.0 -369400,4.091107,0.5936096,,,,,,,,,,,,,, -369500,4.716769,0.6513841,,,,,,,,,,,,,, -369600,4.4986115,0.63768685,,,,,,,,,,,,,, -369700,4.2500367,0.5655186,,,,,,,,,,,,,, -369800,4.803268,0.65911394,,,,,,,,,,,,,, -369900,4.952558,0.7046741,,,,,,,,,,,,,, -370000,4.061024,0.54766315,,,,,,,,,,,,,, -370100,4.682063,0.6579547,,,,,,,,,,,,,, -370200,4.615841,0.62717694,,,,,,,,,,,,,, -370300,4.2037444,0.61323625,,,,,,,,,,,,,, -370400,4.7939577,0.66779506,,,,,,,,,,,,,, -370500,4.567541,0.5776358,,,,,,,,,,,,,, -370600,4.6223044,0.5998663,,,,,,,,,,,,,, -370700,4.7132025,0.60849,,,,,,,,,,,,,, -370800,4.599249,0.6173185,,,,,,,,,,,,,, -370841,,,0.960598647594452,0.1493550240993499,0.7554399967193604,1.0532011985778809,50000.0,0.631600022315979,1.8188133239746087,10000.0,125009.21130037308,129464.3785161972,125009.21130037308,4424.501345396042,16.016053199768066,0.0 -370900,4.3039184,0.6387355,,,,,,,,,,,,,, -371000,4.3091736,0.593212,,,,,,,,,,,,,, -371100,4.6764708,0.64740944,,,,,,,,,,,,,, -371200,4.6033297,0.67041904,,,,,,,,,,,,,, -371300,4.530157,0.6509589,,,,,,,,,,,,,, -371400,4.108808,0.63753176,,,,,,,,,,,,,, -371500,4.6990886,0.67224747,,,,,,,,,,,,,, -371600,4.3658996,0.59219766,,,,,,,,,,,,,, -371700,4.5386963,0.60586,,,,,,,,,,,,,, -371800,5.077948,0.7494654,,,,,,,,,,,,,, -371900,4.1896977,0.6739585,,,,,,,,,,,,,, -372000,4.3312607,0.53692895,,,,,,,,,,,,,, -372100,5.0624037,0.63704187,,,,,,,,,,,,,, -372200,4.4247174,0.6208646,,,,,,,,,,,,,, -372300,4.579866,0.7042197,,,,,,,,,,,,,, -372355,,,0.9604392051696776,0.1477468013763427,0.7557599544525146,1.0522220134735107,50000.0,0.6315000057220459,1.8178412914276123,10000.0,125519.2497870922,129991.7078435421,125519.2497870922,4441.642533063889,16.10594081878662,0.0 -372400,4.282008,0.6160937,,,,,,,,,,,,,, -372500,5.2479854,0.64006007,,,,,,,,,,,,,, -372600,4.3853483,0.6252965,,,,,,,,,,,,,, -372700,4.266339,0.6096042,,,,,,,,,,,,,, -372800,4.28185,0.666155,,,,,,,,,,,,,, -372900,4.1935983,0.5630377,,,,,,,,,,,,,, -373000,4.561752,0.55735487,,,,,,,,,,,,,, -373100,4.649432,0.6353209,,,,,,,,,,,,,, -373200,4.5629,0.6020683,,,,,,,,,,,,,, -373300,4.3287663,0.6032418,,,,,,,,,,,,,, -373400,4.0797195,0.59853846,,,,,,,,,,,,,, -373500,4.578959,0.6670556,,,,,,,,,,,,,, -373600,4.369323,0.64156073,,,,,,,,,,,,,, -373700,4.4889054,0.6533167,,,,,,,,,,,,,, -373800,4.5875163,0.63089085,,,,,,,,,,,,,, -373869,,,0.9604192972183228,0.1477959901094436,0.7557399868965149,1.0529639720916748,50000.0,0.6319000124931335,1.8184047937393188,10000.0,126029.41074037552,130519.1618218422,126029.41074037552,4458.787220478058,16.19475793838501,0.0 -373900,4.1298633,0.6301831,,,,,,,,,,,,,, -374000,4.3993554,0.6125219,,,,,,,,,,,,,, -374100,4.5745726,0.65848523,,,,,,,,,,,,,, -374200,4.86332,0.69157296,,,,,,,,,,,,,, -374300,4.580569,0.62112176,,,,,,,,,,,,,, -374400,4.279178,0.62130225,,,,,,,,,,,,,, -374500,4.750145,0.61054647,,,,,,,,,,,,,, -374600,4.3533854,0.56474376,,,,,,,,,,,,,, -374700,4.624574,0.6073507,,,,,,,,,,,,,, -374800,5.360254,0.6326947,,,,,,,,,,,,,, -374900,4.385492,0.5832604,,,,,,,,,,,,,, -375000,4.5667877,0.5677008,,,,,,,,,,,,,, -375100,4.049531,0.56974137,,,,,,,,,,,,,, -375200,4.5970473,0.6584456,,,,,,,,,,,,,, -375300,4.2259774,0.50455564,,,,,,,,,,,,,, -375383,,,0.9611168503761292,0.1462110728025436,0.7556599974632263,1.0519813299179075,50000.0,0.6308000087738037,1.81686007976532,10000.0,126539.46411824226,131046.4018945694,126539.46411824226,4475.823565721512,16.284391403198242,0.0 -375400,4.8062353,0.6238475,,,,,,,,,,,,,, -375500,4.251686,0.5440837,,,,,,,,,,,,,, -375600,4.3892417,0.6054585,,,,,,,,,,,,,, -375700,4.78397,0.71950835,,,,,,,,,,,,,, -375800,5.1079316,0.6488588,,,,,,,,,,,,,, -375900,4.5489335,0.5857085,,,,,,,,,,,,,, -376000,4.376383,0.6018623,,,,,,,,,,,,,, -376100,4.499627,0.6246369,,,,,,,,,,,,,, -376200,4.340876,0.6503048,,,,,,,,,,,,,, -376300,3.9214323,0.5709747,,,,,,,,,,,,,, -376400,4.3015018,0.6278455,,,,,,,,,,,,,, -376500,4.801832,0.63970727,,,,,,,,,,,,,, -376600,4.935433,0.642768,,,,,,,,,,,,,, -376700,4.274131,0.62305677,,,,,,,,,,,,,, -376800,4.3673587,0.62585497,,,,,,,,,,,,,, -376897,,,0.9594427347183228,0.1501639932394027,0.755620002746582,1.0519487857818604,50000.0,0.6315000057220459,1.8168940544128416,10000.0,127049.39342689514,131573.39998173714,127049.39342689514,4492.737801551819,16.379032611846924,0.0 -376900,4.8900857,0.6590952,,,,,,,,,,,,,, -377000,5.2344947,0.6416319,,,,,,,,,,,,,, -377100,4.2661686,0.64697146,,,,,,,,,,,,,, -377200,4.5290136,0.6387837,,,,,,,,,,,,,, -377300,4.872069,0.6475478,,,,,,,,,,,,,, -377400,4.3373246,0.66499114,,,,,,,,,,,,,, -377500,4.219407,0.5980644,,,,,,,,,,,,,, -377600,4.421224,0.6246793,,,,,,,,,,,,,, -377700,4.3901153,0.61075467,,,,,,,,,,,,,, -377800,4.541929,0.5882745,,,,,,,,,,,,,, -377900,4.5249543,0.645176,,,,,,,,,,,,,, -378000,4.6162996,0.5857425,,,,,,,,,,,,,, -378100,4.181924,0.59574306,,,,,,,,,,,,,, -378200,4.16853,0.5610045,,,,,,,,,,,,,, -378300,4.853369,0.6395771,,,,,,,,,,,,,, -378400,4.4626765,0.6048749,,,,,,,,,,,,,, -378410,,,0.9618940949440002,0.1476005017757415,0.7558799982070923,1.0524741411209106,50000.0,0.6308000087738037,1.817546963691712,10000.0,127559.34850096704,132100.75407028198,127559.34850096704,4509.983863592148,16.47271466255188,0.0 -378500,4.820912,0.6434446,,,,,,,,,,,,,, -378600,4.504298,0.65661854,,,,,,,,,,,,,, -378700,4.588249,0.6277628,,,,,,,,,,,,,, -378800,4.472194,0.53953224,,,,,,,,,,,,,, -378900,4.8156085,0.60649794,,,,,,,,,,,,,, -379000,4.7402916,0.669646,,,,,,,,,,,,,, -379100,4.508124,0.6197599,,,,,,,,,,,,,, -379200,4.3009787,0.6472499,,,,,,,,,,,,,, -379300,4.407843,0.66105264,,,,,,,,,,,,,, -379400,4.1479335,0.5511519,,,,,,,,,,,,,, -379500,5.245935,0.69489455,,,,,,,,,,,,,, -379600,4.3165345,0.6183338,,,,,,,,,,,,,, -379700,4.2605376,0.5639949,,,,,,,,,,,,,, -379800,4.2966604,0.6354319,,,,,,,,,,,,,, -379900,4.789747,0.6027315,,,,,,,,,,,,,, -379924,,,0.9614756107330322,0.1453332155942917,0.7559399604797363,1.0524545907974243,50000.0,0.631100058555603,1.817813754081726,10000.0,128069.45318174362,132628.18920063972,128069.45318174362,4527.163170099258,16.564187049865723,0.0 -380000,5.0150547,0.58233774,,,,,,,,,,,,,, -380100,4.280371,0.5914391,,,,,,,,,,,,,, -380200,4.22744,0.58211267,,,,,,,,,,,,,, -380300,4.687815,0.5882517,,,,,,,,,,,,,, -380400,4.2020664,0.60632235,,,,,,,,,,,,,, -380500,4.3180923,0.6054633,,,,,,,,,,,,,, -380600,4.8140063,0.6400689,,,,,,,,,,,,,, -380700,3.92703,0.5800799,,,,,,,,,,,,,, -380800,5.065974,0.65140194,,,,,,,,,,,,,, -380900,4.365758,0.59906745,,,,,,,,,,,,,, -381000,4.9264884,0.69322085,,,,,,,,,,,,,, -381100,4.1728263,0.6330457,,,,,,,,,,,,,, -381200,4.226581,0.60097086,,,,,,,,,,,,,, -381300,4.4418793,0.7108494,,,,,,,,,,,,,, -381400,4.4765663,0.6421093,,,,,,,,,,,,,, -381437,,,0.9616150856018066,0.1441227793693542,0.7558000087738037,1.053081750869751,50000.0,0.6314000487327576,1.818215489387512,10000.0,128579.37960290907,133155.73970913887,128579.37960290907,4544.638174533844,16.653773546218872,0.0 -381500,4.52081,0.6781526,,,,,,,,,,,,,, -381600,4.7144074,0.628755,,,,,,,,,,,,,, -381700,4.7636366,0.5287151,,,,,,,,,,,,,, -381800,4.877449,0.6439816,,,,,,,,,,,,,, -381900,4.6075315,0.6546762,,,,,,,,,,,,,, -382000,4.711158,0.6775568,,,,,,,,,,,,,, -382100,5.576517,0.57030773,,,,,,,,,,,,,, -382200,4.679326,0.5976949,,,,,,,,,,,,,, -382300,4.402195,0.5624449,,,,,,,,,,,,,, -382400,4.348926,0.6077288,,,,,,,,,,,,,, -382500,4.980016,0.6535674,,,,,,,,,,,,,, -382600,4.5714183,0.6310115,,,,,,,,,,,,,, -382700,4.641953,0.7196847,,,,,,,,,,,,,, -382800,4.3145604,0.6325468,,,,,,,,,,,,,, -382900,4.110818,0.5399781,,,,,,,,,,,,,, -382950,,,0.9603196382522584,0.1468247771263122,0.7559199929237366,1.0521663427352903,50000.0,0.6314000487327576,1.8172036409378047,10000.0,129089.33254384996,133682.95392680168,129089.33254384996,4561.747064352036,16.74677801132202,0.0 -383000,4.432203,0.57993203,,,,,,,,,,,,,, -383100,4.6904488,0.6410209,,,,,,,,,,,,,, -383200,4.8578134,0.60461783,,,,,,,,,,,,,, -383300,4.5933785,0.5874123,,,,,,,,,,,,,, -383400,4.76509,0.5974415,,,,,,,,,,,,,, -383500,4.468945,0.654745,,,,,,,,,,,,,, -383600,4.9653444,0.6008175,,,,,,,,,,,,,, -383700,4.0295887,0.59671855,,,,,,,,,,,,,, -383800,5.056215,0.6611554,,,,,,,,,,,,,, -383900,5.8715973,0.7279477,,,,,,,,,,,,,, -384000,4.3033276,0.62142235,,,,,,,,,,,,,, -384100,4.2534947,0.6077456,,,,,,,,,,,,,, -384200,4.191648,0.62077916,,,,,,,,,,,,,, -384300,4.50602,0.61695075,,,,,,,,,,,,,, -384400,4.297485,0.65500015,,,,,,,,,,,,,, -384464,,,0.9594826102256776,0.1478909403085708,0.7556799650192261,1.0529996156692505,50000.0,0.6313000321388245,1.818960666656494,10000.0,129599.38512706757,134210.98501205444,129599.38512706757,4579.573089122772,16.839839935302734,0.0 -384500,4.466037,0.591638,,,,,,,,,,,,,, -384600,4.80435,0.573701,,,,,,,,,,,,,, -384700,4.4594965,0.6011984,,,,,,,,,,,,,, -384800,4.416163,0.5944408,,,,,,,,,,,,,, -384900,4.636705,0.6700674,,,,,,,,,,,,,, -385000,4.45641,0.6891324,,,,,,,,,,,,,, -385100,4.764821,0.6769935,,,,,,,,,,,,,, -385200,4.5519404,0.6212136,,,,,,,,,,,,,, -385300,4.658679,0.6453381,,,,,,,,,,,,,, -385400,4.4975553,0.5613407,,,,,,,,,,,,,, -385500,4.768493,0.67819077,,,,,,,,,,,,,, -385600,4.457168,0.65263283,,,,,,,,,,,,,, -385700,4.6361012,0.5853584,,,,,,,,,,,,,, -385800,4.4653134,0.6503067,,,,,,,,,,,,,, -385900,4.6862783,0.6572208,,,,,,,,,,,,,, -385977,,,0.9602199792861938,0.1464436054229736,0.756119966506958,1.0509934425354004,50000.0,0.6315000057220459,1.816987991333008,10000.0,130109.36382508278,134738.26835227013,130109.36382508278,4596.726794719696,16.932087182998657,0.0 -386000,4.386393,0.59405637,,,,,,,,,,,,,, -386100,4.0106206,0.5687508,,,,,,,,,,,,,, -386200,4.401906,0.60905296,,,,,,,,,,,,,, -386300,4.3667984,0.6323567,,,,,,,,,,,,,, -386400,4.4741735,0.66065764,,,,,,,,,,,,,, -386500,4.5949435,0.6996158,,,,,,,,,,,,,, -386600,4.6241035,0.6545144,,,,,,,,,,,,,, -386700,4.3729672,0.6001649,,,,,,,,,,,,,, -386800,5.088051,0.68524516,,,,,,,,,,,,,, -386900,5.043332,0.56789964,,,,,,,,,,,,,, -387000,4.075071,0.6032512,,,,,,,,,,,,,, -387100,4.6631346,0.6749928,,,,,,,,,,,,,, -387200,4.6036315,0.68504786,,,,,,,,,,,,,, -387300,4.563134,0.66284287,,,,,,,,,,,,,, -387400,4.142016,0.58341414,,,,,,,,,,,,,, -387491,,,0.9599609375,0.1485553085803985,0.755299985408783,1.051853060722351,50000.0,0.6314000487327576,1.8155715465545648,10000.0,130619.21668457983,135265.24450850487,130619.21668457983,4613.694131135941,17.029442310333252,0.0 -387500,4.5390477,0.5629601,,,,,,,,,,,,,, -387600,4.8447614,0.67597735,,,,,,,,,,,,,, -387700,4.6287665,0.5927287,,,,,,,,,,,,,, -387800,4.127177,0.61489403,,,,,,,,,,,,,, -387900,4.5079503,0.68700415,,,,,,,,,,,,,, -388000,5.407993,0.61502886,,,,,,,,,,,,,, -388100,4.637833,0.5980596,,,,,,,,,,,,,, -388200,4.133747,0.55122954,,,,,,,,,,,,,, -388300,4.226997,0.53973293,,,,,,,,,,,,,, -388400,4.2063417,0.6012529,,,,,,,,,,,,,, -388500,4.362642,0.65333515,,,,,,,,,,,,,, -388600,4.620891,0.65046597,,,,,,,,,,,,,, -388700,4.5949154,0.6366795,,,,,,,,,,,,,, -388800,4.554953,0.6636933,,,,,,,,,,,,,, -388900,4.0668387,0.5788456,,,,,,,,,,,,,, -389000,4.4883575,0.6568545,,,,,,,,,,,,,, -389004,,,0.9614357352256776,0.1465912759304046,0.7559399604797363,1.0521849393844604,50000.0,0.6309000253677368,1.8187190294265747,10000.0,131129.21193361282,135792.42049121857,131129.21193361282,4630.728091716766,17.116252183914185,0.0 -389100,4.9797306,0.66648364,,,,,,,,,,,,,, -389200,4.458507,0.62170523,,,,,,,,,,,,,, -389300,4.3926754,0.6705748,,,,,,,,,,,,,, -389400,4.6393347,0.62739015,,,,,,,,,,,,,, -389500,4.596137,0.6464463,,,,,,,,,,,,,, -389600,4.743596,0.61172307,,,,,,,,,,,,,, -389700,5.2225833,0.6421763,,,,,,,,,,,,,, -389800,4.814622,0.6704045,,,,,,,,,,,,,, -389900,4.6414037,0.69266903,,,,,,,,,,,,,, -390000,4.3528843,0.6271985,,,,,,,,,,,,,, -390100,4.6422925,0.6294189,,,,,,,,,,,,,, -390200,4.5285654,0.58689415,,,,,,,,,,,,,, -390300,4.9442854,0.71499,,,,,,,,,,,,,, -390400,4.7699685,0.6147478,,,,,,,,,,,,,, -390500,4.770911,0.67030126,,,,,,,,,,,,,, -390518,,,0.960339605808258,0.1467641741037368,0.7556999921798706,1.0530712604522705,50000.0,0.6307000517845154,1.8176332712173464,10000.0,131639.15394306183,136319.58885860443,131639.15394306183,4647.803809642792,17.206831455230713,0.0 -390600,4.5765724,0.6396363,,,,,,,,,,,,,, -390700,4.533062,0.61424536,,,,,,,,,,,,,, -390800,4.3795543,0.562622,,,,,,,,,,,,,, -390900,4.133683,0.60491943,,,,,,,,,,,,,, -391000,4.3326077,0.6165566,,,,,,,,,,,,,, -391100,4.191318,0.6000198,,,,,,,,,,,,,, -391200,4.6580515,0.6538914,,,,,,,,,,,,,, -391300,5.1229963,0.68953264,,,,,,,,,,,,,, -391400,4.753824,0.71312636,,,,,,,,,,,,,, -391500,4.4430237,0.60361123,,,,,,,,,,,,,, -391600,4.419028,0.68650395,,,,,,,,,,,,,, -391700,4.0224476,0.5517308,,,,,,,,,,,,,, -391800,4.377753,0.6319509,,,,,,,,,,,,,, -391900,4.420415,0.60937333,,,,,,,,,,,,,, -392000,4.4043403,0.63497895,,,,,,,,,,,,,, -392031,,,0.9616150856018066,0.1448460966348648,0.755899965763092,1.0519707202911377,50000.0,0.6306000351905823,1.817515850067139,10000.0,132149.21634721756,136846.91794514656,132149.21634721756,4664.913568973541,17.301775693893433,0.0 -392100,4.204764,0.6113795,,,,,,,,,,,,,, -392200,4.401213,0.54316777,,,,,,,,,,,,,, -392300,4.3354144,0.6431787,,,,,,,,,,,,,, -392400,4.311906,0.6141043,,,,,,,,,,,,,, -392500,4.3192353,0.60742676,,,,,,,,,,,,,, -392600,5.40448,0.60068476,,,,,,,,,,,,,, -392700,4.5998077,0.5708784,,,,,,,,,,,,,, -392800,5.222716,0.6107636,,,,,,,,,,,,,, -392900,4.7357216,0.615907,,,,,,,,,,,,,, -393000,4.819175,0.6877528,,,,,,,,,,,,,, -393100,4.1635113,0.6095553,,,,,,,,,,,,,, -393200,4.336519,0.619852,,,,,,,,,,,,,, -393300,4.4453506,0.6039027,,,,,,,,,,,,,, -393400,4.769906,0.610265,,,,,,,,,,,,,, -393500,4.5925674,0.6257535,,,,,,,,,,,,,, -393545,,,0.9604192972183228,0.1471053957939148,0.7557199597358704,1.051891207695007,50000.0,0.6306000351905823,1.8179852962493896,10000.0,132659.2893781662,137374.28175091743,132659.2893781662,4682.0501046180725,17.397064208984375,0.0 -393600,4.2825813,0.65740955,,,,,,,,,,,,,, -393700,4.4277034,0.5884571,,,,,,,,,,,,,, -393800,4.352005,0.5898588,,,,,,,,,,,,,, -393900,4.325663,0.63505656,,,,,,,,,,,,,, -394000,4.4699063,0.66087276,,,,,,,,,,,,,, -394100,4.4994464,0.58395,,,,,,,,,,,,,, -394200,4.4893355,0.6241231,,,,,,,,,,,,,, -394300,4.3294773,0.5975829,,,,,,,,,,,,,, -394400,4.4223075,0.62867266,,,,,,,,,,,,,, -394500,4.9532747,0.62390566,,,,,,,,,,,,,, -394600,5.008732,0.604328,,,,,,,,,,,,,, -394700,5.255074,0.7055676,,,,,,,,,,,,,, -394800,4.5098004,0.58542633,,,,,,,,,,,,,, -394900,4.765107,0.68286484,,,,,,,,,,,,,, -395000,4.379465,0.65300393,,,,,,,,,,,,,, -395058,,,0.9599609375,0.148743599653244,0.7558599710464478,1.0518076419830322,50000.0,0.6321000456809998,1.8184858560562127,10000.0,133169.1497850418,137901.2680785656,133169.1497850418,4698.965122699738,17.549309730529785,0.0 -395100,4.644092,0.58682823,,,,,,,,,,,,,, -395200,4.4188814,0.6108481,,,,,,,,,,,,,, -395300,4.6493244,0.6109662,,,,,,,,,,,,,, -395400,4.5361085,0.66499484,,,,,,,,,,,,,, -395500,4.6264634,0.64974064,,,,,,,,,,,,,, -395600,4.3414884,0.59945303,,,,,,,,,,,,,, -395700,4.4000835,0.5703416,,,,,,,,,,,,,, -395800,4.4616485,0.6217153,,,,,,,,,,,,,, -395900,4.474955,0.66544116,,,,,,,,,,,,,, -396000,4.8170958,0.6286867,,,,,,,,,,,,,, -396100,4.1266384,0.5363017,,,,,,,,,,,,,, -396200,4.393011,0.6136427,,,,,,,,,,,,,, -396300,4.192964,0.57963943,,,,,,,,,,,,,, -396400,4.91102,0.66872364,,,,,,,,,,,,,, -396500,4.8256435,0.69143766,,,,,,,,,,,,,, -396573,,,0.9614357352256776,0.1444485336542129,0.7559399604797363,1.051970601081848,50000.0,0.6306000351905823,1.8179006576538088,10000.0,133679.27163481712,138428.61062049866,133679.27163481712,4716.026647567749,17.64861249923706,0.0 -396600,4.643987,0.6338516,,,,,,,,,,,,,, -396700,4.866178,0.59984964,,,,,,,,,,,,,, -396800,4.3614664,0.5973768,,,,,,,,,,,,,, -396900,4.718803,0.59275484,,,,,,,,,,,,,, -397000,4.6022625,0.6107786,,,,,,,,,,,,,, -397100,4.7795415,0.5788014,,,,,,,,,,,,,, -397200,4.5109854,0.6494663,,,,,,,,,,,,,, -397300,4.579932,0.66542774,,,,,,,,,,,,,, -397400,4.3881817,0.64972657,,,,,,,,,,,,,, -397500,4.526755,0.644752,,,,,,,,,,,,,, -397600,4.697292,0.60795754,,,,,,,,,,,,,, -397700,4.627472,0.6021818,,,,,,,,,,,,,, -397800,4.1417747,0.541667,,,,,,,,,,,,,, -397900,4.438681,0.6778086,,,,,,,,,,,,,, -398000,4.621259,0.6420744,,,,,,,,,,,,,, -398087,,,0.96195387840271,0.1440633684396743,0.7553600072860718,1.0524848699569702,50000.0,0.631100058555603,1.8180168867111208,10000.0,134189.4073984623,138955.87549710274,134189.4073984623,4733.000705003738,17.743713855743408,0.0 -398100,4.6020236,0.66423523,,,,,,,,,,,,,, -398200,4.6718473,0.6359018,,,,,,,,,,,,,, -398300,4.85419,0.6935128,,,,,,,,,,,,,, -398400,4.524845,0.70234436,,,,,,,,,,,,,, -398500,4.5338635,0.6144493,,,,,,,,,,,,,, -398600,4.2537804,0.5510577,,,,,,,,,,,,,, -398700,4.5369363,0.6538033,,,,,,,,,,,,,, -398800,5.03958,0.6375773,,,,,,,,,,,,,, -398900,4.7003465,0.6266085,,,,,,,,,,,,,, -399000,4.187131,0.5203129,,,,,,,,,,,,,, -399100,4.4551077,0.6502819,,,,,,,,,,,,,, -399200,4.797348,0.6499947,,,,,,,,,,,,,, -399300,4.6336856,0.6940682,,,,,,,,,,,,,, -399400,4.259829,0.5861019,,,,,,,,,,,,,, -399500,4.591734,0.59232855,,,,,,,,,,,,,, -399600,4.7883673,0.6055463,,,,,,,,,,,,,, -399601,,,0.9598014950752258,0.1474271565675735,0.7560799717903137,1.0510292053222656,50000.0,0.6314000487327576,1.8171703815460205,10000.0,134699.39883041382,139483.24274683,134699.39883041382,4750.216578006744,17.843615293502808,0.0 -399700,4.883698,0.638579,,,,,,,,,,,,,, -399800,5.036426,0.71361154,,,,,,,,,,,,,, -399900,4.3128486,0.6203447,,,,,,,,,,,,,, -400000,4.4658623,0.61043197,,,,,,,,,,,,,, -400100,4.257697,0.61380786,,,,,,,,,,,,,, -400200,4.363799,0.64152837,,,,,,,,,,,,,, -400300,4.862094,0.69365263,,,,,,,,,,,,,, -400400,4.548156,0.6074817,,,,,,,,,,,,,, -400500,4.738153,0.65198594,,,,,,,,,,,,,, -400600,4.2347684,0.6036457,,,,,,,,,,,,,, -400700,4.307309,0.6119847,,,,,,,,,,,,,, -400800,4.3774514,0.60676074,,,,,,,,,,,,,, -400900,4.2848835,0.60437626,,,,,,,,,,,,,, -401000,4.540687,0.61245465,,,,,,,,,,,,,, -401100,5.236801,0.670453,,,,,,,,,,,,,, -401115,,,0.9605388641357422,0.1472002863883972,0.7559199929237366,1.0520803928375244,50000.0,0.631600022315979,1.8182260990142824,10000.0,135209.38359832764,140010.5748746395,135209.38359832764,4767.408142089844,17.938778400421143,0.0 -401200,4.8119636,0.62005556,,,,,,,,,,,,,, -401300,5.014972,0.6539988,,,,,,,,,,,,,, -401400,4.578541,0.57327294,,,,,,,,,,,,,, -401500,4.781127,0.6167995,,,,,,,,,,,,,, -401600,4.5923347,0.539584,,,,,,,,,,,,,, -401700,4.698409,0.57983804,,,,,,,,,,,,,, -401800,4.305097,0.5524904,,,,,,,,,,,,,, -401900,4.5727196,0.671991,,,,,,,,,,,,,, -402000,4.8373938,0.6355725,,,,,,,,,,,,,, -402100,4.63814,0.63711256,,,,,,,,,,,,,, -402200,4.364575,0.6472839,,,,,,,,,,,,,, -402300,4.765073,0.6223908,,,,,,,,,,,,,, -402400,4.948992,0.61069787,,,,,,,,,,,,,, -402500,4.451801,0.55001235,,,,,,,,,,,,,, -402600,4.704302,0.69487435,,,,,,,,,,,,,, -402628,,,0.9599609375,0.1483696550130844,0.7554799914360046,1.0520527362823486,50000.0,0.6314000487327576,1.8170477151870728,10000.0,135719.27980732918,140537.51276612282,135719.27980732918,4784.297115325928,18.03217339515686,0.0 -402700,5.1369357,0.7212262,,,,,,,,,,,,,, -402800,4.474039,0.5746337,,,,,,,,,,,,,, -402900,4.417942,0.6165306,,,,,,,,,,,,,, -403000,4.3875566,0.580315,,,,,,,,,,,,,, -403100,4.0619264,0.5350507,,,,,,,,,,,,,, -403200,4.135522,0.55452144,,,,,,,,,,,,,, -403300,4.630893,0.71283054,,,,,,,,,,,,,, -403400,4.398993,0.64650583,,,,,,,,,,,,,, -403500,4.158065,0.61222196,,,,,,,,,,,,,, -403600,4.492546,0.63621736,,,,,,,,,,,,,, -403700,4.78792,0.670965,,,,,,,,,,,,,, -403800,4.4397936,0.6602183,,,,,,,,,,,,,, -403900,4.390744,0.633145,,,,,,,,,,,,,, -404000,4.4236145,0.59748065,,,,,,,,,,,,,, -404100,4.4947615,0.70089996,,,,,,,,,,,,,, -404141,,,0.9599409699440002,0.1486997753381729,0.7559199929237366,1.05121648311615,50000.0,0.6302000284194946,1.81659734249115,10000.0,136229.12922286987,141064.5624115467,136229.12922286987,4801.340654373169,18.129441499710083,0.0 -404200,4.2732224,0.56750596,,,,,,,,,,,,,, -404300,4.6895437,0.62419665,,,,,,,,,,,,,, -404400,4.4403896,0.6310645,,,,,,,,,,,,,, -404500,4.6452174,0.64776707,,,,,,,,,,,,,, -404600,4.5213737,0.59960717,,,,,,,,,,,,,, -404700,4.6327095,0.6176275,,,,,,,,,,,,,, -404800,4.467976,0.6046853,,,,,,,,,,,,,, -404900,4.6326227,0.687683,,,,,,,,,,,,,, -405000,4.2145452,0.5706698,,,,,,,,,,,,,, -405100,4.4206877,0.6189451,,,,,,,,,,,,,, -405200,4.495203,0.6076529,,,,,,,,,,,,,, -405300,4.7037597,0.59575623,,,,,,,,,,,,,, -405400,4.944071,0.61435676,,,,,,,,,,,,,, -405500,4.569481,0.6049011,,,,,,,,,,,,,, -405600,5.0052886,0.63013196,,,,,,,,,,,,,, -405655,,,0.9614157676696776,0.1460102349519729,0.7556399703025818,1.051106333732605,50000.0,0.6307000517845154,1.8162307739257808,10000.0,136739.0093715191,141591.76547813416,136739.0093715191,4818.499400377274,18.232341051101685,0.0 -405700,5.0069633,0.63987225,,,,,,,,,,,,,, -405800,4.4884906,0.63612974,,,,,,,,,,,,,, -405900,5.197708,0.62485576,,,,,,,,,,,,,, -406000,4.2708917,0.5734276,,,,,,,,,,,,,, -406100,4.2313914,0.55658317,,,,,,,,,,,,,, -406200,4.499109,0.64589906,,,,,,,,,,,,,, -406300,4.7456317,0.5762509,,,,,,,,,,,,,, -406400,4.57131,0.54795116,,,,,,,,,,,,,, -406500,4.9531794,0.6231538,,,,,,,,,,,,,, -406600,4.4883766,0.6209779,,,,,,,,,,,,,, -406700,5.2704115,0.62121856,,,,,,,,,,,,,, -406800,4.8563848,0.71304446,,,,,,,,,,,,,, -406900,4.302393,0.6403972,,,,,,,,,,,,,, -407000,4.476726,0.6305302,,,,,,,,,,,,,, -407100,4.894838,0.6459763,,,,,,,,,,,,,, -407168,,,0.9608577489852904,0.1470111906528473,0.7559199929237366,1.0524506568908691,50000.0,0.6312000155448914,1.818498373031616,10000.0,137248.94324493408,142119.07744646072,137248.94324493408,4835.725882053375,18.32414746284485,0.0 -407200,4.657719,0.68029076,,,,,,,,,,,,,, -407300,4.6434264,0.63438773,,,,,,,,,,,,,, -407400,4.07312,0.5367527,,,,,,,,,,,,,, -407500,4.0872693,0.53314406,,,,,,,,,,,,,, -407600,4.88946,0.6530886,,,,,,,,,,,,,, -407700,4.42715,0.679975,,,,,,,,,,,,,, -407800,4.2650065,0.65601146,,,,,,,,,,,,,, -407900,4.7465186,0.6642177,,,,,,,,,,,,,, -408000,4.1514215,0.58785987,,,,,,,,,,,,,, -408100,4.7260456,0.62058663,,,,,,,,,,,,,, -408200,4.612854,0.5356369,,,,,,,,,,,,,, -408300,4.098774,0.58966464,,,,,,,,,,,,,, -408400,4.5772486,0.567242,,,,,,,,,,,,,, -408500,4.522389,0.6447182,,,,,,,,,,,,,, -408600,4.599257,0.5453825,,,,,,,,,,,,,, -408682,,,0.9594826102256776,0.1504420191049575,0.7554999589920044,1.0528854131698608,50000.0,0.6312000155448914,1.8195046186447144,10000.0,137759.00602436066,142646.33941054344,137759.00602436066,4852.763338327408,18.42547035217285,0.0 -408700,5.4425406,0.7427336,,,,,,,,,,,,,, -408800,4.387985,0.5807141,,,,,,,,,,,,,, -408900,4.5151973,0.6583988,,,,,,,,,,,,,, -409000,4.7736673,0.57941186,,,,,,,,,,,,,, -409100,4.4209714,0.57629925,,,,,,,,,,,,,, -409200,4.8853655,0.61863804,,,,,,,,,,,,,, -409300,4.3819985,0.64195085,,,,,,,,,,,,,, -409400,4.1898603,0.5666193,,,,,,,,,,,,,, -409500,4.375108,0.6275365,,,,,,,,,,,,,, -409600,4.505657,0.65046644,,,,,,,,,,,,,, -409700,4.657544,0.63500017,,,,,,,,,,,,,, -409800,4.442681,0.67402226,,,,,,,,,,,,,, -409900,4.3241134,0.65155756,,,,,,,,,,,,,, -410000,4.485007,0.63544357,,,,,,,,,,,,,, -410100,4.7436724,0.66688704,,,,,,,,,,,,,, -410196,,,0.9610570669174194,0.1465514004230499,0.7554199695587158,1.051986813545227,50000.0,0.6308000087738037,1.8148283958435056,10000.0,138268.93678069115,143173.28214502335,138268.93678069115,4869.61491060257,18.52701497077942,0.0 -410200,4.2595334,0.5803395,,,,,,,,,,,,,, -410300,4.62955,0.61671984,,,,,,,,,,,,,, -410400,4.388919,0.6416536,,,,,,,,,,,,,, -410500,4.469271,0.64334893,,,,,,,,,,,,,, -410600,4.591734,0.6181938,,,,,,,,,,,,,, -410700,4.7712026,0.62780595,,,,,,,,,,,,,, -410800,4.664549,0.69413614,,,,,,,,,,,,,, -410900,4.4586964,0.60259557,,,,,,,,,,,,,, -411000,4.4184775,0.63838714,,,,,,,,,,,,,, -411100,4.9877143,0.6394168,,,,,,,,,,,,,, -411200,4.857761,0.69131494,,,,,,,,,,,,,, -411300,4.654141,0.6256013,,,,,,,,,,,,,, -411400,4.7590933,0.6375397,,,,,,,,,,,,,, -411500,4.314515,0.58544,,,,,,,,,,,,,, -411600,4.5000296,0.58233595,,,,,,,,,,,,,, -411700,4.4213605,0.60084265,,,,,,,,,,,,,, -411711,,,0.9595025181770324,0.1496334373950958,0.755899965763092,1.0508095026016235,50000.0,0.631100058555603,1.816946029663086,10000.0,138778.83112430573,143700.44857931137,138778.83112430573,4886.722577571869,18.632123947143555,0.0 -411800,4.1278887,0.6293696,,,,,,,,,,,,,, -411900,4.887199,0.61845624,,,,,,,,,,,,,, -412000,3.8379421,0.544949,,,,,,,,,,,,,, -412100,4.2377996,0.57569504,,,,,,,,,,,,,, -412200,4.56279,0.5756514,,,,,,,,,,,,,, -412300,4.5918517,0.6721432,,,,,,,,,,,,,, -412400,4.7452574,0.6135855,,,,,,,,,,,,,, -412500,4.3440733,0.61230636,,,,,,,,,,,,,, -412600,4.187968,0.59596354,,,,,,,,,,,,,, -412700,4.2293005,0.6160584,,,,,,,,,,,,,, -412800,4.598212,0.64713234,,,,,,,,,,,,,, -412900,5.1109085,0.6350648,,,,,,,,,,,,,, -413000,4.305187,0.6141265,,,,,,,,,,,,,, -413100,4.948175,0.7093998,,,,,,,,,,,,,, -413200,4.9008765,0.69821656,,,,,,,,,,,,,, -413225,,,0.9604790806770324,0.1473581790924072,0.7558199763298035,1.0527368783950806,50000.0,0.6319000124931335,1.8175504207611084,10000.0,139288.81606531143,144227.74542737007,139288.81606531143,4903.879125356674,18.72833561897278,0.0 -413300,5.01799,0.6514573,,,,,,,,,,,,,, -413400,4.3995194,0.65103394,,,,,,,,,,,,,, -413500,4.2211585,0.6227782,,,,,,,,,,,,,, -413600,4.2769628,0.5938496,,,,,,,,,,,,,, -413700,5.1420918,0.5691597,,,,,,,,,,,,,, -413800,4.443182,0.60094047,,,,,,,,,,,,,, -413900,4.401647,0.65991664,,,,,,,,,,,,,, -414000,4.5146885,0.60674936,,,,,,,,,,,,,, -414100,4.5467334,0.59588003,,,,,,,,,,,,,, -414200,4.463272,0.6558311,,,,,,,,,,,,,, -414300,5.091716,0.70600533,,,,,,,,,,,,,, -414400,4.307278,0.605582,,,,,,,,,,,,,, -414500,4.2529945,0.6554918,,,,,,,,,,,,,, -414600,4.94253,0.6258935,,,,,,,,,,,,,, -414700,4.8000193,0.6469761,,,,,,,,,,,,,, -414740,,,0.961136758327484,0.1464726626873016,0.7557199597358704,1.051748752593994,50000.0,0.6313000321388245,1.817002415657044,10000.0,139798.95591545105,144755.18977284431,139798.95591545105,4921.023787975311,18.827582120895386,0.0 -414800,4.4601755,0.5511429,,,,,,,,,,,,,, -414900,4.4933925,0.6400144,,,,,,,,,,,,,, -415000,4.538265,0.6831911,,,,,,,,,,,,,, -415100,4.3274183,0.66542,,,,,,,,,,,,,, -415200,4.3716893,0.6382795,,,,,,,,,,,,,, -415300,4.6773815,0.57927394,,,,,,,,,,,,,, -415400,4.173824,0.58302784,,,,,,,,,,,,,, -415500,4.1980906,0.61038846,,,,,,,,,,,,,, -415600,4.346868,0.64069176,,,,,,,,,,,,,, -415700,4.4454517,0.68416965,,,,,,,,,,,,,, -415800,5.273282,0.6277069,,,,,,,,,,,,,, -415900,5.098278,0.6984654,,,,,,,,,,,,,, -416000,4.4199367,0.6633056,,,,,,,,,,,,,, -416100,4.4895,0.680522,,,,,,,,,,,,,, -416200,5.225049,0.61760956,,,,,,,,,,,,,, -416253,,,0.96000075340271,0.1497869342565536,0.7555999755859375,1.0523784160614014,50000.0,0.6322000026702881,1.8180079460144043,10000.0,140309.01599621773,145282.7172703743,140309.01599621773,4938.317877531052,18.94173526763916,0.0 -416300,4.148273,0.60198456,,,,,,,,,,,,,, -416400,4.5218596,0.57620054,,,,,,,,,,,,,, -416500,4.277299,0.54988515,,,,,,,,,,,,,, -416600,5.210309,0.627687,,,,,,,,,,,,,, -416700,4.5373077,0.60643244,,,,,,,,,,,,,, -416800,4.396895,0.66020954,,,,,,,,,,,,,, -416900,4.3707924,0.6692716,,,,,,,,,,,,,, -417000,4.7067575,0.65087223,,,,,,,,,,,,,, -417100,4.0527167,0.62633026,,,,,,,,,,,,,, -417200,4.788268,0.671704,,,,,,,,,,,,,, -417300,4.6664724,0.6396928,,,,,,,,,,,,,, -417400,4.404596,0.59459615,,,,,,,,,,,,,, -417500,4.6280513,0.66251266,,,,,,,,,,,,,, -417600,4.448788,0.61817765,,,,,,,,,,,,,, -417700,5.1777453,0.70707726,,,,,,,,,,,,,, -417767,,,0.9615154266357422,0.1464042663574218,0.7555999755859375,1.051769137382507,50000.0,0.6313000321388245,1.8175561428070068,10000.0,140818.96789312363,145809.70890045166,140818.96789312363,4955.200018882752,19.039517164230347,0.0 -417800,4.37429,0.6117806,,,,,,,,,,,,,, -417900,4.676993,0.65586746,,,,,,,,,,,,,, -418000,4.509298,0.6507877,,,,,,,,,,,,,, -418100,4.34735,0.6555858,,,,,,,,,,,,,, -418200,4.344485,0.6102893,,,,,,,,,,,,,, -418300,4.564213,0.6368718,,,,,,,,,,,,,, -418400,4.7053933,0.64942575,,,,,,,,,,,,,, -418500,4.3983397,0.68191826,,,,,,,,,,,,,, -418600,4.3774037,0.5610858,,,,,,,,,,,,,, -418700,4.493548,0.59872806,,,,,,,,,,,,,, -418800,4.706328,0.6813506,,,,,,,,,,,,,, -418900,4.3761497,0.6521991,,,,,,,,,,,,,, -419000,4.645007,0.6347838,,,,,,,,,,,,,, -419100,4.308516,0.5818838,,,,,,,,,,,,,, -419200,4.7241154,0.6667047,,,,,,,,,,,,,, -419280,,,0.962113320827484,0.1444258242845535,0.7554799914360046,1.0518677234649658,50000.0,0.6313000321388245,1.815688967704773,10000.0,141328.88626480105,146336.88316321373,141328.88626480105,4972.296702861786,19.138569831848145,0.0 -419300,4.1708884,0.57549167,,,,,,,,,,,,,, -419400,4.684954,0.6593752,,,,,,,,,,,,,, -419500,4.570733,0.6090329,,,,,,,,,,,,,, -419600,4.3806725,0.631138,,,,,,,,,,,,,, -419700,4.3468637,0.607089,,,,,,,,,,,,,, -419800,4.305147,0.60478425,,,,,,,,,,,,,, -419900,4.292126,0.63672787,,,,,,,,,,,,,, -420000,4.1755424,0.5947494,,,,,,,,,,,,,, -420100,4.378901,0.59111995,,,,,,,,,,,,,, -420200,4.225561,0.6294399,,,,,,,,,,,,,, -420300,3.9821963,0.57841027,,,,,,,,,,,,,, -420400,4.432326,0.61041254,,,,,,,,,,,,,, -420500,4.143542,0.59478,,,,,,,,,,,,,, -420600,4.5356827,0.7021164,,,,,,,,,,,,,, -420700,4.8749604,0.6941446,,,,,,,,,,,,,, -420794,,,0.9602399468421936,0.146013543009758,0.7559999823570251,1.051907658576965,50000.0,0.6318000555038452,1.81821358203888,10000.0,141838.8102095127,146863.86543941498,141838.8102095127,4989.195043087006,19.23827767372132,0.0 -420800,4.1807985,0.53756344,,,,,,,,,,,,,, -420900,4.5276046,0.65998346,,,,,,,,,,,,,, -421000,4.1690826,0.67087704,,,,,,,,,,,,,, -421100,4.5765476,0.61395097,,,,,,,,,,,,,, -421200,4.8409195,0.6793508,,,,,,,,,,,,,, -421300,4.6556478,0.652091,,,,,,,,,,,,,, -421400,4.3035483,0.62356913,,,,,,,,,,,,,, -421500,4.125626,0.56890225,,,,,,,,,,,,,, -421600,4.7711015,0.628404,,,,,,,,,,,,,, -421700,4.1411476,0.6051687,,,,,,,,,,,,,, -421800,4.5618525,0.59545547,,,,,,,,,,,,,, -421900,5.042066,0.6960572,,,,,,,,,,,,,, -422000,4.9301815,0.70297873,,,,,,,,,,,,,, -422100,4.67393,0.6062256,,,,,,,,,,,,,, -422200,4.598694,0.560469,,,,,,,,,,,,,, -422300,4.4538455,0.65067005,,,,,,,,,,,,,, -422308,,,0.9610171914100648,0.1461603045463562,0.7557399868965149,1.0520353317260742,50000.0,0.6314000487327576,1.8166799545288088,10000.0,142348.95408391953,147391.3078136444,142348.95408391953,5006.32697892189,19.34300827980041,0.0 -422400,5.0821033,0.6846231,,,,,,,,,,,,,, -422500,4.846364,0.6774606,,,,,,,,,,,,,, -422600,4.259693,0.6036024,,,,,,,,,,,,,, -422700,4.671299,0.6548061,,,,,,,,,,,,,, -422800,4.2480826,0.6021693,,,,,,,,,,,,,, -422900,4.6025734,0.5589109,,,,,,,,,,,,,, -423000,4.661126,0.64907223,,,,,,,,,,,,,, -423100,4.426711,0.6658944,,,,,,,,,,,,,, -423200,4.431097,0.6443629,,,,,,,,,,,,,, -423300,5.0002537,0.63151896,,,,,,,,,,,,,, -423400,5.0788493,0.6663623,,,,,,,,,,,,,, -423500,4.443709,0.5936847,,,,,,,,,,,,,, -423600,4.253348,0.554676,,,,,,,,,,,,,, -423700,4.389667,0.60913813,,,,,,,,,,,,,, -423800,4.421277,0.6475811,,,,,,,,,,,,,, -423822,,,0.9603196382522584,0.1475877910852432,0.7555599808692932,1.0516517162322998,50000.0,0.6323000192642212,1.8173928260803225,10000.0,142859.0488152504,147919.34276366234,142859.0488152504,5024.113070011139,19.43835592269897,0.0 -423900,4.9554653,0.7181811,,,,,,,,,,,,,, -424000,4.796106,0.6091487,,,,,,,,,,,,,, -424100,4.0270786,0.59510106,,,,,,,,,,,,,, -424200,4.0329,0.5591854,,,,,,,,,,,,,, -424300,4.6941724,0.6254921,,,,,,,,,,,,,, -424400,4.1183586,0.58736783,,,,,,,,,,,,,, -424500,4.3791523,0.59722424,,,,,,,,,,,,,, -424600,4.1752667,0.631795,,,,,,,,,,,,,, -424700,4.992809,0.6409842,,,,,,,,,,,,,, -424800,5.025099,0.63678265,,,,,,,,,,,,,, -424900,4.71872,0.6423216,,,,,,,,,,,,,, -425000,4.446101,0.62037086,,,,,,,,,,,,,, -425100,4.1391697,0.54845166,,,,,,,,,,,,,, -425200,4.4327006,0.6354727,,,,,,,,,,,,,, -425300,4.822593,0.5887524,,,,,,,,,,,,,, -425335,,,0.9591637253761292,0.150650754570961,0.755899965763092,1.0516971349716189,50000.0,0.631100058555603,1.8176788091659544,10000.0,143369.13305354118,148446.680259943,143369.13305354118,5041.205612659454,19.53954839706421,0.0 -425400,4.5299335,0.6395637,,,,,,,,,,,,,, -425500,4.63863,0.6191766,,,,,,,,,,,,,, -425600,5.049384,0.64088565,,,,,,,,,,,,,, -425700,3.9970772,0.5269244,,,,,,,,,,,,,, -425800,4.3782845,0.53032976,,,,,,,,,,,,,, -425900,4.7372904,0.5670796,,,,,,,,,,,,,, -426000,4.2822328,0.6648544,,,,,,,,,,,,,, -426100,4.303397,0.61374515,,,,,,,,,,,,,, -426200,4.329372,0.6159725,,,,,,,,,,,,,, -426300,4.4442677,0.6337,,,,,,,,,,,,,, -426400,4.421771,0.5859522,,,,,,,,,,,,,, -426500,5.159841,0.6849223,,,,,,,,,,,,,, -426600,4.675652,0.61785954,,,,,,,,,,,,,, -426700,4.67925,0.6747776,,,,,,,,,,,,,, -426800,4.5324316,0.62118244,,,,,,,,,,,,,, -426848,,,0.9608976244926452,0.1458160281181335,0.7559799551963806,1.0528355836868286,50000.0,0.6313000321388245,1.8194468021392824,10000.0,143879.30010700226,148973.93353700638,143879.30010700226,5058.131942987442,19.6390221118927,0.0 -426900,4.1058903,0.65299004,,,,,,,,,,,,,, -427000,4.8183174,0.61720884,,,,,,,,,,,,,, -427100,4.663816,0.65307724,,,,,,,,,,,,,, -427200,4.147375,0.61524934,,,,,,,,,,,,,, -427300,4.4094195,0.6005219,,,,,,,,,,,,,, -427400,4.554811,0.6488156,,,,,,,,,,,,,, -427500,4.4597735,0.57391286,,,,,,,,,,,,,, -427600,4.3543596,0.5697244,,,,,,,,,,,,,, -427700,4.2911634,0.608759,,,,,,,,,,,,,, -427800,4.845448,0.6159359,,,,,,,,,,,,,, -427900,4.024133,0.5739781,,,,,,,,,,,,,, -428000,4.497734,0.6809458,,,,,,,,,,,,,, -428100,4.8759212,0.69450283,,,,,,,,,,,,,, -428200,4.614322,0.65761983,,,,,,,,,,,,,, -428300,4.485961,0.6576877,,,,,,,,,,,,,, -428360,,,0.9611766338348388,0.146694615483284,0.7557199597358704,1.051979660987854,50000.0,0.6307000517845154,1.816765069961548,10000.0,144389.14986610413,149501.11479592323,144389.14986610413,5075.299085140228,19.744468927383423,0.0 -428400,4.4829917,0.6426871,,,,,,,,,,,,,, -428500,4.1669693,0.63091034,,,,,,,,,,,,,, -428600,4.4197364,0.6267132,,,,,,,,,,,,,, -428700,4.2337794,0.5573081,,,,,,,,,,,,,, -428800,4.7924247,0.7054112,,,,,,,,,,,,,, -428900,4.8626423,0.61990505,,,,,,,,,,,,,, -429000,4.5250845,0.5906703,,,,,,,,,,,,,, -429100,4.820677,0.6326847,,,,,,,,,,,,,, -429200,3.9209054,0.5606295,,,,,,,,,,,,,, -429300,4.478366,0.65476733,,,,,,,,,,,,,, -429400,4.3627787,0.64907706,,,,,,,,,,,,,, -429500,4.64354,0.61669546,,,,,,,,,,,,,, -429600,4.5541005,0.61704576,,,,,,,,,,,,,, -429700,4.0622835,0.5578201,,,,,,,,,,,,,, -429800,4.874934,0.67111003,,,,,,,,,,,,,, -429874,,,0.961535394191742,0.1424656212329864,0.7558000087738037,1.051815152168274,50000.0,0.6309000253677368,1.8176268339157104,10000.0,144899.22899580002,150028.48589587212,144899.22899580002,5092.433423757553,19.84219336509705,0.0 -429900,4.584016,0.63268304,,,,,,,,,,,,,, -430000,4.7154202,0.64779437,,,,,,,,,,,,,, -430100,4.596051,0.63634706,,,,,,,,,,,,,, -430200,4.4503813,0.64698416,,,,,,,,,,,,,, -430300,4.948707,0.68246883,,,,,,,,,,,,,, -430400,4.5674767,0.68299025,,,,,,,,,,,,,, -430500,4.071753,0.60441136,,,,,,,,,,,,,, -430600,4.719699,0.5840989,,,,,,,,,,,,,, -430700,4.4513626,0.67730546,,,,,,,,,,,,,, -430800,4.2027736,0.5577418,,,,,,,,,,,,,, -430900,4.474594,0.674691,,,,,,,,,,,,,, -431000,4.490468,0.5761507,,,,,,,,,,,,,, -431100,4.361077,0.65773815,,,,,,,,,,,,,, -431200,4.741372,0.5837367,,,,,,,,,,,,,, -431300,4.914823,0.6682082,,,,,,,,,,,,,, -431387,,,0.961156725883484,0.1466477066278457,0.7562800049781799,1.0512664318084717,50000.0,0.6319000124931335,1.817805528640747,10000.0,145409.23196220398,150555.69521546364,145409.23196220398,5109.480449914932,19.94320559501648,0.0 -431400,4.4495964,0.6128171,,,,,,,,,,,,,, -431500,4.456759,0.5812264,,,,,,,,,,,,,, -431600,4.9641848,0.64135927,,,,,,,,,,,,,, -431700,4.7344184,0.6030691,,,,,,,,,,,,,, -431800,4.116085,0.57967204,,,,,,,,,,,,,, -431900,4.2000585,0.64537305,,,,,,,,,,,,,, -432000,4.39983,0.6192631,,,,,,,,,,,,,, -432100,4.6064577,0.62718403,,,,,,,,,,,,,, -432200,4.1546116,0.62573004,,,,,,,,,,,,,, -432300,4.2865806,0.5788152,,,,,,,,,,,,,, -432400,4.840375,0.6220659,,,,,,,,,,,,,, -432500,4.3784695,0.66861725,,,,,,,,,,,,,, -432600,4.322325,0.5845823,,,,,,,,,,,,,, -432700,4.2781463,0.5581898,,,,,,,,,,,,,, -432800,4.712484,0.6945603,,,,,,,,,,,,,, -432900,4.5263376,0.63558364,,,,,,,,,,,,,, -432901,,,0.9598612785339355,0.1482953429222107,0.7561999559402466,1.0520941019058228,50000.0,0.6319000124931335,1.81902277469635,10000.0,145919.59135246277,151083.30557370186,145919.59135246277,5126.57053732872,20.042986392974854,0.0 -433000,4.163345,0.57251334,,,,,,,,,,,,,, -433100,4.7995906,0.67090875,,,,,,,,,,,,,, -433200,4.7358475,0.6717577,,,,,,,,,,,,,, -433300,4.9673176,0.6801069,,,,,,,,,,,,,, -433400,4.462941,0.64663464,,,,,,,,,,,,,, -433500,4.20163,0.55340487,,,,,,,,,,,,,, -433600,4.3875957,0.60672826,,,,,,,,,,,,,, -433700,4.3355556,0.6514258,,,,,,,,,,,,,, -433800,4.60463,0.64459115,,,,,,,,,,,,,, -433900,4.122826,0.5529051,,,,,,,,,,,,,, -434000,4.4367275,0.6793407,,,,,,,,,,,,,, -434100,4.4448156,0.5760399,,,,,,,,,,,,,, -434200,4.6011424,0.6786828,,,,,,,,,,,,,, -434300,4.208434,0.50558156,,,,,,,,,,,,,, -434400,4.3353424,0.6352708,,,,,,,,,,,,,, -434414,,,0.9611766338348388,0.1458516418933868,0.7561399936676025,1.0522480010986328,50000.0,0.6318000555038452,1.8176249265670776,10000.0,146429.5111811161,151610.5396182537,146429.5111811161,5143.729845046997,20.137884616851807,0.0 -434500,4.7819295,0.62034374,,,,,,,,,,,,,, -434600,4.3070335,0.58937395,,,,,,,,,,,,,, -434700,4.2075195,0.57638043,,,,,,,,,,,,,, -434800,4.5263796,0.64383,,,,,,,,,,,,,, -434900,4.4480705,0.61502755,,,,,,,,,,,,,, -435000,4.764752,0.6790913,,,,,,,,,,,,,, -435100,4.504684,0.59626937,,,,,,,,,,,,,, -435200,4.3309774,0.6019627,,,,,,,,,,,,,, -435300,4.6476746,0.59611326,,,,,,,,,,,,,, -435400,4.5670714,0.7104579,,,,,,,,,,,,,, -435500,4.7963,0.65879816,,,,,,,,,,,,,, -435600,4.644412,0.65398,,,,,,,,,,,,,, -435700,4.6997876,0.621302,,,,,,,,,,,,,, -435800,4.538198,0.5487233,,,,,,,,,,,,,, -435900,4.1830125,0.60114497,,,,,,,,,,,,,, -435928,,,0.9609175324440002,0.1470854878425598,0.7560200095176697,1.0511912107467651,50000.0,0.631100058555603,1.816450834274292,10000.0,146939.64353322983,152137.84930348396,146939.64353322983,5160.749233961105,20.236326932907104,0.0 -436000,4.221586,0.5558396,,,,,,,,,,,,,, -436100,4.45309,0.58425,,,,,,,,,,,,,, -436200,4.261134,0.62224674,,,,,,,,,,,,,, -436300,4.432519,0.6401569,,,,,,,,,,,,,, -436400,4.56679,0.6479778,,,,,,,,,,,,,, -436500,4.5938153,0.66860425,,,,,,,,,,,,,, -436600,4.9880037,0.66943365,,,,,,,,,,,,,, -436700,4.687002,0.70585525,,,,,,,,,,,,,, -436800,5.239253,0.7518404,,,,,,,,,,,,,, -436900,4.416838,0.62023246,,,,,,,,,,,,,, -437000,4.406721,0.63069266,,,,,,,,,,,,,, -437100,4.2020984,0.6310457,,,,,,,,,,,,,, -437200,4.6721554,0.6367845,,,,,,,,,,,,,, -437300,4.6020107,0.6539777,,,,,,,,,,,,,, -437400,4.4551444,0.6704429,,,,,,,,,,,,,, -437441,,,0.9604591727256776,0.1464346051216125,0.7558599710464478,1.0538361072540283,50000.0,0.631100058555603,1.8216536045074463,10000.0,147449.54594254494,152664.7027964592,147449.54594254494,5177.542870759964,20.333744525909424,0.0 -437500,4.343057,0.6354706,,,,,,,,,,,,,, -437600,4.194628,0.5730752,,,,,,,,,,,,,, -437700,4.1889253,0.5924233,,,,,,,,,,,,,, -437800,5.0832376,0.70574415,,,,,,,,,,,,,, -437900,4.215305,0.6110933,,,,,,,,,,,,,, -438000,4.5017295,0.57569563,,,,,,,,,,,,,, -438100,4.0533524,0.5080804,,,,,,,,,,,,,, -438200,4.2495403,0.5800005,,,,,,,,,,,,,, -438300,4.341391,0.6046226,,,,,,,,,,,,,, -438400,4.7691827,0.5973947,,,,,,,,,,,,,, -438500,4.6248465,0.62663,,,,,,,,,,,,,, -438600,4.593476,0.6086437,,,,,,,,,,,,,, -438700,4.3898087,0.580895,,,,,,,,,,,,,, -438800,4.8509145,0.6848803,,,,,,,,,,,,,, -438900,4.353954,0.66001695,,,,,,,,,,,,,, -438955,,,0.9606983065605164,0.1456460505723953,0.7556999921798706,1.0515245199203491,50000.0,0.6310000419616699,1.817052960395813,10000.0,147959.53197169304,153191.72272014618,147959.53197169304,5194.418187379837,20.43113422393799,0.0 -439000,4.476102,0.6086607,,,,,,,,,,,,,, -439100,4.8711176,0.69864285,,,,,,,,,,,,,, -439200,4.413394,0.5785936,,,,,,,,,,,,,, -439300,4.0969887,0.5748171,,,,,,,,,,,,,, -439400,4.6581163,0.64883506,,,,,,,,,,,,,, -439500,4.6384153,0.63922167,,,,,,,,,,,,,, -439600,4.622554,0.6736823,,,,,,,,,,,,,, -439700,4.728083,0.6578579,,,,,,,,,,,,,, -439800,5.0549464,0.6281618,,,,,,,,,,,,,, -439900,4.6121607,0.64071465,,,,,,,,,,,,,, -440000,4.599387,0.66427654,,,,,,,,,,,,,, -440100,4.6259894,0.6344843,,,,,,,,,,,,,, -440200,4.6427364,0.67498237,,,,,,,,,,,,,, -440300,4.275525,0.55143607,,,,,,,,,,,,,, -440400,4.156614,0.6511433,,,,,,,,,,,,,, -440468,,,0.9604192972183228,0.1487842202186584,0.7554599642753601,1.0531373023986816,50000.0,0.6320000290870667,1.8179895877838133,10000.0,148469.47192668915,153718.99074602127,148469.47192668915,5211.586438894272,20.53074598312378,0.0 -440500,4.6042194,0.6786505,,,,,,,,,,,,,, -440600,4.4684706,0.6588709,,,,,,,,,,,,,, -440700,4.6488786,0.64376014,,,,,,,,,,,,,, -440800,4.665127,0.63288903,,,,,,,,,,,,,, -440900,4.640181,0.6697897,,,,,,,,,,,,,, -441000,4.1076546,0.59244734,,,,,,,,,,,,,, -441100,4.796588,0.6289946,,,,,,,,,,,,,, -441200,4.3803406,0.64811087,,,,,,,,,,,,,, -441300,4.5242505,0.657361,,,,,,,,,,,,,, -441400,4.807711,0.66450846,,,,,,,,,,,,,, -441500,4.279675,0.581315,,,,,,,,,,,,,, -441600,4.473799,0.6658606,,,,,,,,,,,,,, -441700,4.458708,0.6542407,,,,,,,,,,,,,, -441800,4.3926425,0.6035054,,,,,,,,,,,,,, -441900,4.557431,0.75297654,,,,,,,,,,,,,, -441982,,,0.9610371589660645,0.1458581238985061,0.7560999989509583,1.0520803928375244,50000.0,0.6314000487327576,1.8177663087844849,10000.0,148979.646399498,154246.31707787514,148979.646399498,5228.576928853989,20.633300065994263,0.0 -442000,4.252699,0.58459204,,,,,,,,,,,,,, -442100,4.453128,0.64807206,,,,,,,,,,,,,, -442200,4.353314,0.6644783,,,,,,,,,,,,,, -442300,4.4573464,0.6742633,,,,,,,,,,,,,, -442400,4.4473815,0.65080154,,,,,,,,,,,,,, -442500,4.9481053,0.6806265,,,,,,,,,,,,,, -442600,4.0628915,0.6113576,,,,,,,,,,,,,, -442700,4.707244,0.6659179,,,,,,,,,,,,,, -442800,4.7507453,0.5931916,,,,,,,,,,,,,, -442900,4.0146613,0.60846025,,,,,,,,,,,,,, -443000,4.551867,0.63409525,,,,,,,,,,,,,, -443100,4.287912,0.5917286,,,,,,,,,,,,,, -443200,4.7279706,0.61228436,,,,,,,,,,,,,, -443300,4.7870893,0.65684634,,,,,,,,,,,,,, -443400,4.3238354,0.6483834,,,,,,,,,,,,,, -443496,,,0.960180163383484,0.1478644162416458,0.7556999921798706,1.0514676570892334,50000.0,0.6315000057220459,1.8164036273956297,10000.0,149489.79416179657,154773.68194508553,149489.79416179657,5245.627908229828,20.737919569015503,0.0 -443500,4.61207,0.61003137,,,,,,,,,,,,,, -443600,4.258575,0.6548349,,,,,,,,,,,,,, -443700,4.7555494,0.61892587,,,,,,,,,,,,,, -443800,4.742627,0.64058125,,,,,,,,,,,,,, -443900,4.74522,0.60284394,,,,,,,,,,,,,, -444000,5.4107747,0.5819397,,,,,,,,,,,,,, -444100,4.6455393,0.65589595,,,,,,,,,,,,,, -444200,4.761913,0.72183853,,,,,,,,,,,,,, -444300,4.743219,0.63660336,,,,,,,,,,,,,, -444400,5.5113525,0.6127218,,,,,,,,,,,,,, -444500,4.60137,0.63255715,,,,,,,,,,,,,, -444600,4.578803,0.63419795,,,,,,,,,,,,,, -444700,4.036925,0.58638453,,,,,,,,,,,,,, -444800,4.248477,0.5504467,,,,,,,,,,,,,, -444900,4.5870337,0.6900714,,,,,,,,,,,,,, -445000,4.262877,0.61671805,,,,,,,,,,,,,, -445010,,,0.9609375,0.1475970149040222,0.7556999921798706,1.0521252155303955,50000.0,0.6312000155448914,1.817966341972351,10000.0,149999.8376750946,155300.85563635826,149999.8376750946,5262.597207069397,20.838828086853027,0.0 -445100,4.9548664,0.6430147,,,,,,,,,,,,,, -445200,4.0723214,0.55949694,,,,,,,,,,,,,, -445300,4.2818594,0.54562706,,,,,,,,,,,,,, -445400,4.490034,0.5867473,,,,,,,,,,,,,, -445500,4.648473,0.626114,,,,,,,,,,,,,, -445600,5.049037,0.69517493,,,,,,,,,,,,,, -445700,4.3660364,0.6020452,,,,,,,,,,,,,, -445800,4.0412526,0.56265557,,,,,,,,,,,,,, -445900,4.3670797,0.5732203,,,,,,,,,,,,,, -446000,4.540275,0.6539266,,,,,,,,,,,,,, -446100,4.223288,0.5684486,,,,,,,,,,,,,, -446200,4.283509,0.6542965,,,,,,,,,,,,,, -446300,4.7773104,0.700923,,,,,,,,,,,,,, -446400,4.7704453,0.57540786,,,,,,,,,,,,,, -446500,4.5631924,0.6599072,,,,,,,,,,,,,, -446523,,,0.9608378410339355,0.1482971906661987,0.7558599710464478,1.0520795583724976,50000.0,0.6313000321388245,1.8180619478225708,10000.0,150509.80757832527,155828.0451927185,150509.80757832527,5279.649640798569,20.945537328720093,0.0 -446600,4.6750355,0.62055194,,,,,,,,,,,,,, -446700,4.785571,0.57623947,,,,,,,,,,,,,, -446800,4.5171824,0.6308093,,,,,,,,,,,,,, -446900,4.9207697,0.6190389,,,,,,,,,,,,,, -447000,4.445839,0.59126335,,,,,,,,,,,,,, -447100,5.1071224,0.6493304,,,,,,,,,,,,,, -447200,4.5809884,0.55536926,,,,,,,,,,,,,, -447300,4.1039534,0.63252366,,,,,,,,,,,,,, -447400,4.650386,0.6394574,,,,,,,,,,,,,, -447500,4.5382123,0.6089227,,,,,,,,,,,,,, -447600,4.0366406,0.5188244,,,,,,,,,,,,,, -447700,5.2590556,0.7059097,,,,,,,,,,,,,, -447800,4.5461574,0.59399694,,,,,,,,,,,,,, -447900,4.6512856,0.66344297,,,,,,,,,,,,,, -448000,4.4081,0.59363323,,,,,,,,,,,,,, -448037,,,0.961316168308258,0.1471826732158661,0.7558599710464478,1.0523767471313477,50000.0,0.6306000351905823,1.818080186843872,10000.0,151019.7969942093,156354.96886587143,151019.7969942093,5296.409289121628,21.0596821308136,0.0 -448100,4.132063,0.66041654,,,,,,,,,,,,,, -448200,4.752772,0.70573777,,,,,,,,,,,,,, -448300,4.6758327,0.63906944,,,,,,,,,,,,,, -448400,4.90225,0.6584148,,,,,,,,,,,,,, -448500,4.6847615,0.6471571,,,,,,,,,,,,,, -448600,4.6541033,0.5473981,,,,,,,,,,,,,, -448700,4.452519,0.6201119,,,,,,,,,,,,,, -448800,4.750634,0.6689105,,,,,,,,,,,,,, -448900,4.617572,0.6169873,,,,,,,,,,,,,, -449000,4.491372,0.6500738,,,,,,,,,,,,,, -449100,4.199227,0.62142617,,,,,,,,,,,,,, -449200,4.4178934,0.6227229,,,,,,,,,,,,,, -449300,4.5905566,0.6473447,,,,,,,,,,,,,, -449400,4.1399493,0.56185675,,,,,,,,,,,,,, -449500,4.440641,0.64376926,,,,,,,,,,,,,, -449550,,,0.960339605808258,0.1485464125871658,0.7556599974632263,1.050680160522461,50000.0,0.6308000087738037,1.816544771194458,10000.0,151529.86161398888,156882.26883530617,151529.86161398888,5313.485778331757,21.15876936912537,0.0 -449600,4.322067,0.6120938,,,,,,,,,,,,,, -449700,4.3098884,0.58953005,,,,,,,,,,,,,, -449800,4.541608,0.6515713,,,,,,,,,,,,,, -449900,4.638589,0.5992982,,,,,,,,,,,,,, -450000,4.7717767,0.6144394,,,,,,,,,,,,,, -450100,4.0698986,0.620485,,,,,,,,,,,,,, -450200,4.3615746,0.62595004,,,,,,,,,,,,,, -450300,4.85921,0.7024527,,,,,,,,,,,,,, -450400,4.3048077,0.6456291,,,,,,,,,,,,,, -450500,4.8050556,0.6488811,,,,,,,,,,,,,, -450600,4.5697465,0.6260922,,,,,,,,,,,,,, -450700,4.920121,0.6567244,,,,,,,,,,,,,, -450800,4.9849606,0.5589084,,,,,,,,,,,,,, -450900,4.6491704,0.63699937,,,,,,,,,,,,,, -451000,5.0331426,0.6490216,,,,,,,,,,,,,, -451063,,,0.959622085094452,0.1482224911451339,0.7559399604797363,1.0514514446258545,50000.0,0.631600022315979,1.817432284355164,10000.0,152039.7231528759,157409.22890758514,152039.7231528759,5330.426630020142,21.256157636642456,0.0 -451100,4.394561,0.63029015,,,,,,,,,,,,,, -451200,4.3690934,0.65372074,,,,,,,,,,,,,, -451300,4.0509176,0.60957843,,,,,,,,,,,,,, -451400,4.3348317,0.61789215,,,,,,,,,,,,,, -451500,4.7035794,0.6670234,,,,,,,,,,,,,, -451600,4.7337832,0.58435136,,,,,,,,,,,,,, -451700,4.585277,0.618375,,,,,,,,,,,,,, -451800,4.0463376,0.5125153,,,,,,,,,,,,,, -451900,4.574614,0.6560912,,,,,,,,,,,,,, -452000,4.4394712,0.6341732,,,,,,,,,,,,,, -452100,4.634022,0.59563786,,,,,,,,,,,,,, -452200,4.4959445,0.6936158,,,,,,,,,,,,,, -452300,4.2486296,0.5915239,,,,,,,,,,,,,, -452400,4.6821384,0.6370293,,,,,,,,,,,,,, -452500,5.0745544,0.62112963,,,,,,,,,,,,,, -452577,,,0.9602598547935486,0.1482243835926056,0.7552199959754944,1.0523943901062012,50000.0,0.631600022315979,1.817237973213196,10000.0,152549.7496676445,157936.49866342545,152549.7496676445,5347.50834107399,21.35846757888794,0.0 -452600,4.820449,0.63181615,,,,,,,,,,,,,, -452700,4.561577,0.65289086,,,,,,,,,,,,,, -452800,4.600182,0.5829258,,,,,,,,,,,,,, -452900,4.4899635,0.6514708,,,,,,,,,,,,,, -453000,4.462801,0.6063249,,,,,,,,,,,,,, -453100,4.253273,0.6153121,,,,,,,,,,,,,, -453200,4.188565,0.5570726,,,,,,,,,,,,,, -453300,4.594119,0.6275826,,,,,,,,,,,,,, -453400,4.5806093,0.573287,,,,,,,,,,,,,, -453500,4.1345654,0.59162164,,,,,,,,,,,,,, -453600,4.3978615,0.660065,,,,,,,,,,,,,, -453700,4.5009613,0.6287026,,,,,,,,,,,,,, -453800,4.8184476,0.6426498,,,,,,,,,,,,,, -453900,4.900924,0.63160884,,,,,,,,,,,,,, -454000,4.31427,0.6114653,,,,,,,,,,,,,, -454090,,,0.9610171914100648,0.1476788073778152,0.7560399770736694,1.052739143371582,50000.0,0.6305000185966492,1.818716526031494,10000.0,153059.6380867958,158463.60263371468,153059.6380867958,5364.562657833099,21.45997190475464,0.0 -454100,5.0683827,0.69477135,,,,,,,,,,,,,, -454200,4.8353324,0.66011333,,,,,,,,,,,,,, -454300,4.9219413,0.7141595,,,,,,,,,,,,,, -454400,4.952548,0.6683508,,,,,,,,,,,,,, -454500,4.925128,0.6842681,,,,,,,,,,,,,, -454600,4.278722,0.6724676,,,,,,,,,,,,,, -454700,4.643089,0.621301,,,,,,,,,,,,,, -454800,4.628242,0.6727255,,,,,,,,,,,,,, -454900,4.057986,0.5251937,,,,,,,,,,,,,, -455000,4.3330693,0.62243295,,,,,,,,,,,,,, -455100,4.381782,0.66740835,,,,,,,,,,,,,, -455200,4.2108345,0.6139364,,,,,,,,,,,,,, -455300,4.664702,0.6427263,,,,,,,,,,,,,, -455400,5.0806885,0.686158,,,,,,,,,,,,,, -455500,4.1344647,0.65152085,,,,,,,,,,,,,, -455600,4.5187254,0.6727304,,,,,,,,,,,,,, -455603,,,0.960598647594452,0.1480612307786941,0.7557599544525146,1.05259907245636,50000.0,0.6310000419616699,1.818840265274048,10000.0,153569.5043578148,158990.47890138626,153569.5043578148,5381.410010099411,21.56362199783325,0.0 -455700,4.6252837,0.57751524,,,,,,,,,,,,,, -455800,4.709147,0.5767498,,,,,,,,,,,,,, -455900,4.262402,0.56575155,,,,,,,,,,,,,, -456000,4.0365286,0.59076977,,,,,,,,,,,,,, -456100,4.3625426,0.63173574,,,,,,,,,,,,,, -456200,4.1823936,0.60058033,,,,,,,,,,,,,, -456300,4.1542296,0.59760237,,,,,,,,,,,,,, -456400,4.206452,0.55736345,,,,,,,,,,,,,, -456500,4.2342935,0.6448126,,,,,,,,,,,,,, -456600,4.2732606,0.6237894,,,,,,,,,,,,,, -456700,4.604229,0.6457079,,,,,,,,,,,,,, -456800,4.1151648,0.5851135,,,,,,,,,,,,,, -456900,4.289604,0.6731737,,,,,,,,,,,,,, -457000,4.7703366,0.60838944,,,,,,,,,,,,,, -457100,4.3841434,0.5913775,,,,,,,,,,,,,, -457118,,,0.9614157676696776,0.146322026848793,0.7562999725341797,1.0514073371887207,50000.0,0.6318000555038452,1.817526817321777,10000.0,154079.6365056038,159517.92832422256,154079.6365056038,5398.56386756897,21.66671848297119,0.0 -457200,4.5331616,0.6283471,,,,,,,,,,,,,, -457300,4.4820237,0.62018746,,,,,,,,,,,,,, -457400,4.3787427,0.5874953,,,,,,,,,,,,,, -457500,4.540603,0.6273548,,,,,,,,,,,,,, -457600,4.5435324,0.6578032,,,,,,,,,,,,,, -457700,4.165054,0.57186306,,,,,,,,,,,,,, -457800,4.7752957,0.6652376,,,,,,,,,,,,,, -457900,4.147522,0.5690885,,,,,,,,,,,,,, -458000,4.6768093,0.6499919,,,,,,,,,,,,,, -458100,4.449574,0.62694997,,,,,,,,,,,,,, -458200,4.1221304,0.5733053,,,,,,,,,,,,,, -458300,4.8425794,0.6660287,,,,,,,,,,,,,, -458400,4.2719193,0.5719271,,,,,,,,,,,,,, -458500,4.639472,0.6701716,,,,,,,,,,,,,, -458600,4.3876004,0.61381614,,,,,,,,,,,,,, -458632,,,0.9614955186843872,0.1443223059177398,0.7561999559402466,1.0519403219223022,50000.0,0.6308000087738037,1.818753242492676,10000.0,154589.79145264626,160045.3290655613,154589.79145264626,5415.647324562073,21.76831364631653,0.0 -458700,4.705534,0.6874005,,,,,,,,,,,,,, -458800,4.82255,0.7186389,,,,,,,,,,,,,, -458900,4.9835258,0.6544926,,,,,,,,,,,,,, -459000,4.540142,0.6994031,,,,,,,,,,,,,, -459100,4.675516,0.64711916,,,,,,,,,,,,,, -459200,4.0866566,0.565881,,,,,,,,,,,,,, -459300,4.5643573,0.6556909,,,,,,,,,,,,,, -459400,4.5163827,0.668055,,,,,,,,,,,,,, -459500,4.5090437,0.59871405,,,,,,,,,,,,,, -459600,4.24141,0.6283089,,,,,,,,,,,,,, -459700,4.7560225,0.6094018,,,,,,,,,,,,,, -459800,4.8472433,0.5771663,,,,,,,,,,,,,, -459900,4.641816,0.63290334,,,,,,,,,,,,,, -460000,4.4335766,0.65896106,,,,,,,,,,,,,, -460100,4.278243,0.6490254,,,,,,,,,,,,,, -460147,,,0.9606783986091614,0.1461041569709777,0.7561799883842468,1.051734209060669,50000.0,0.6310000419616699,1.8186067342758176,10000.0,155099.9243299961,160572.79992508888,155099.9243299961,5432.796221256256,21.89662194252014,0.0 -460200,4.7468443,0.64982283,,,,,,,,,,,,,, -460300,4.1383333,0.5588026,,,,,,,,,,,,,, -460400,4.108269,0.53545743,,,,,,,,,,,,,, -460500,4.7649198,0.68677574,,,,,,,,,,,,,, -460600,4.447399,0.66975343,,,,,,,,,,,,,, -460700,4.5460663,0.6388525,,,,,,,,,,,,,, -460800,4.6698914,0.62176836,,,,,,,,,,,,,, -460900,4.448186,0.6505865,,,,,,,,,,,,,, -461000,4.687106,0.675189,,,,,,,,,,,,,, -461100,4.4852953,0.6362347,,,,,,,,,,,,,, -461200,4.5292273,0.64166474,,,,,,,,,,,,,, -461300,4.736012,0.6484477,,,,,,,,,,,,,, -461400,4.5956497,0.68991905,,,,,,,,,,,,,, -461500,5.116833,0.71075493,,,,,,,,,,,,,, -461600,4.241922,0.58959806,,,,,,,,,,,,,, -461660,,,0.9598014950752258,0.1478946059942245,0.7555399537086487,1.0528876781463623,50000.0,0.6314000487327576,1.8191816806793213,10000.0,155609.88244462013,161100.4083454609,155609.88244462013,5450.276737928391,22.00607419013977,0.0 -461700,4.1818595,0.5334648,,,,,,,,,,,,,, -461800,4.1474113,0.5884421,,,,,,,,,,,,,, -461900,4.5603385,0.5570297,,,,,,,,,,,,,, -462000,4.312496,0.60019153,,,,,,,,,,,,,, -462100,4.333494,0.63453156,,,,,,,,,,,,,, -462200,4.243779,0.5709623,,,,,,,,,,,,,, -462300,4.5525265,0.6497938,,,,,,,,,,,,,, -462400,4.649617,0.69317615,,,,,,,,,,,,,, -462500,4.59393,0.67770624,,,,,,,,,,,,,, -462600,4.0949984,0.60669535,,,,,,,,,,,,,, -462700,4.0062065,0.5617215,,,,,,,,,,,,,, -462800,4.4240136,0.58061033,,,,,,,,,,,,,, -462900,4.386546,0.65128917,,,,,,,,,,,,,, -463000,4.1172237,0.5603423,,,,,,,,,,,,,, -463100,4.718004,0.6083956,,,,,,,,,,,,,, -463174,,,0.9601004123687744,0.1464466303586959,0.7556599974632263,1.052324891090393,50000.0,0.6308000087738037,1.818250298500061,10000.0,156119.75332951546,161627.33212256432,156119.75332951546,5467.160728693008,22.11421036720276,0.0 -463200,4.695417,0.6645385,,,,,,,,,,,,,, -463300,4.512946,0.59332097,,,,,,,,,,,,,, -463400,4.167543,0.5491659,,,,,,,,,,,,,, -463500,4.4693403,0.614252,,,,,,,,,,,,,, -463600,4.581692,0.66147965,,,,,,,,,,,,,, -463700,4.2472515,0.53445506,,,,,,,,,,,,,, -463800,4.2404046,0.55398,,,,,,,,,,,,,, -463900,5.223947,0.68012846,,,,,,,,,,,,,, -464000,4.371213,0.55487126,,,,,,,,,,,,,, -464100,4.5568604,0.7013492,,,,,,,,,,,,,, -464200,4.834661,0.6636515,,,,,,,,,,,,,, -464300,4.595499,0.66542274,,,,,,,,,,,,,, -464400,4.5723715,0.71694076,,,,,,,,,,,,,, -464500,4.1305733,0.5643418,,,,,,,,,,,,,, -464600,4.275094,0.6177617,,,,,,,,,,,,,, -464687,,,0.9594228267669678,0.1503064781427383,0.7555800080299377,1.0526986122131348,50000.0,0.631600022315979,1.8184653520584104,10000.0,156629.86915493011,162154.48467326164,156629.86915493011,5484.03133893013,22.21840262413025,0.0 -464700,4.314864,0.62543315,,,,,,,,,,,,,, -464800,4.30705,0.67580515,,,,,,,,,,,,,, -464900,4.437059,0.6262197,,,,,,,,,,,,,, -465000,4.339929,0.64688075,,,,,,,,,,,,,, -465100,4.356972,0.62821877,,,,,,,,,,,,,, -465200,4.328011,0.56070995,,,,,,,,,,,,,, -465300,4.3054094,0.61362845,,,,,,,,,,,,,, -465400,4.4438543,0.6278969,,,,,,,,,,,,,, -465500,4.713173,0.67205703,,,,,,,,,,,,,, -465600,4.919758,0.7168754,,,,,,,,,,,,,, -465700,4.3582153,0.5997289,,,,,,,,,,,,,, -465800,4.5339494,0.6830931,,,,,,,,,,,,,, -465900,4.393386,0.62732095,,,,,,,,,,,,,, -466000,4.7960744,0.6831803,,,,,,,,,,,,,, -466100,4.5878487,0.6615964,,,,,,,,,,,,,, -466200,4.499541,0.56676507,,,,,,,,,,,,,, -466202,,,0.962292730808258,0.1432087272405624,0.7555199861526489,1.0532876253128052,50000.0,0.6315000057220459,1.817851424217224,10000.0,157139.771348238,162681.6063232422,157139.771348238,5501.084171056747,22.32434344291687,0.0 -466300,5.306,0.694237,,,,,,,,,,,,,, -466400,4.668536,0.65612775,,,,,,,,,,,,,, -466500,4.548494,0.58600646,,,,,,,,,,,,,, -466600,4.2679906,0.55794394,,,,,,,,,,,,,, -466700,4.5620503,0.610413,,,,,,,,,,,,,, -466800,4.41496,0.7214927,,,,,,,,,,,,,, -466900,4.6297646,0.6373838,,,,,,,,,,,,,, -467000,4.1761065,0.5464944,,,,,,,,,,,,,, -467100,4.8670206,0.6346197,,,,,,,,,,,,,, -467200,4.246218,0.6063242,,,,,,,,,,,,,, -467300,4.582034,0.6279607,,,,,,,,,,,,,, -467400,4.5896983,0.74103856,,,,,,,,,,,,,, -467500,4.348156,0.6331897,,,,,,,,,,,,,, -467600,4.4027534,0.63782394,,,,,,,,,,,,,, -467700,4.4861984,0.6178168,,,,,,,,,,,,,, -467716,,,0.960957407951355,0.1454687416553497,0.7559199929237366,1.0522892475128174,50000.0,0.6317000389099121,1.816857933998108,10000.0,157649.81596922874,163208.80686163902,157649.81596922874,5518.072999954224,22.43128538131714,0.0 -467800,4.8313656,0.66541237,,,,,,,,,,,,,, -467900,4.6292067,0.56306803,,,,,,,,,,,,,, -468000,4.7657747,0.64851,,,,,,,,,,,,,, -468100,4.744932,0.6461568,,,,,,,,,,,,,, -468200,4.413794,0.63517976,,,,,,,,,,,,,, -468300,4.0643973,0.60413146,,,,,,,,,,,,,, -468400,4.3805213,0.62296265,,,,,,,,,,,,,, -468500,4.6670294,0.630331,,,,,,,,,,,,,, -468600,4.505464,0.5952327,,,,,,,,,,,,,, -468700,4.178665,0.5926189,,,,,,,,,,,,,, -468800,4.661387,0.66478175,,,,,,,,,,,,,, -468900,4.2635245,0.6332548,,,,,,,,,,,,,, -469000,4.35104,0.6274085,,,,,,,,,,,,,, -469100,4.596929,0.7178442,,,,,,,,,,,,,, -469200,4.754068,0.64537734,,,,,,,,,,,,,, -469230,,,0.9608577489852904,0.146379217505455,0.7557199597358704,1.0511075258255005,50000.0,0.6313000321388245,1.8159784078598025,10000.0,158159.90993785858,163736.03716516495,158159.90993785858,5535.045620441437,22.535234928131104,0.0 -469300,4.0776663,0.58364916,,,,,,,,,,,,,, -469400,4.7250786,0.6847247,,,,,,,,,,,,,, -469500,5.12349,0.6652187,,,,,,,,,,,,,, -469600,4.5906982,0.63765246,,,,,,,,,,,,,, -469700,4.8583617,0.5937125,,,,,,,,,,,,,, -469800,4.4525657,0.67906195,,,,,,,,,,,,,, -469900,4.279833,0.54210293,,,,,,,,,,,,,, -470000,4.421992,0.6260104,,,,,,,,,,,,,, -470100,4.284515,0.6175889,,,,,,,,,,,,,, -470200,4.2820077,0.56520605,,,,,,,,,,,,,, -470300,4.822522,0.5940416,,,,,,,,,,,,,, -470400,4.7626705,0.6251104,,,,,,,,,,,,,, -470500,4.396917,0.6402465,,,,,,,,,,,,,, -470600,5.1633887,0.7193533,,,,,,,,,,,,,, -470700,4.425194,0.7270344,,,,,,,,,,,,,, -470744,,,0.9602598547935486,0.1471463739871978,0.7562199831008911,1.051742672920227,50000.0,0.631100058555603,1.817716717720032,10000.0,158670.03724741936,164263.49721503258,158670.03724741936,5552.206401586533,22.64659857749939,0.0 -470800,5.0062165,0.6504673,,,,,,,,,,,,,, -470900,4.9288926,0.67619246,,,,,,,,,,,,,, -471000,4.666561,0.64617527,,,,,,,,,,,,,, -471100,4.311086,0.7000112,,,,,,,,,,,,,, -471200,4.210423,0.59602034,,,,,,,,,,,,,, -471300,4.2236643,0.60493857,,,,,,,,,,,,,, -471400,4.349289,0.61618036,,,,,,,,,,,,,, -471500,4.083086,0.58695203,,,,,,,,,,,,,, -471600,4.6948504,0.6720184,,,,,,,,,,,,,, -471700,4.373357,0.591769,,,,,,,,,,,,,, -471800,4.2951756,0.576086,,,,,,,,,,,,,, -471900,4.51762,0.6819587,,,,,,,,,,,,,, -472000,4.560212,0.5908325,,,,,,,,,,,,,, -472100,4.7015038,0.6915472,,,,,,,,,,,,,, -472200,4.298614,0.60524184,,,,,,,,,,,,,, -472258,,,0.9595224857330322,0.1507411003112793,0.7557199597358704,1.0522160530090332,50000.0,0.6324000358581543,1.818382143974304,10000.0,159179.998026371,164790.60931801796,159179.998026371,5569.197404384613,22.748106718063354,0.0 -472300,5.2210884,0.6490095,,,,,,,,,,,,,, -472400,4.758487,0.69655097,,,,,,,,,,,,,, -472500,4.7287364,0.6808112,,,,,,,,,,,,,, -472600,4.823655,0.618078,,,,,,,,,,,,,, -472700,4.2309175,0.5829436,,,,,,,,,,,,,, -472800,4.3670464,0.6356433,,,,,,,,,,,,,, -472900,4.8187504,0.7042734,,,,,,,,,,,,,, -473000,4.5863833,0.6151053,,,,,,,,,,,,,, -473100,4.4807134,0.5581898,,,,,,,,,,,,,, -473200,4.5418715,0.57911533,,,,,,,,,,,,,, -473300,4.186321,0.5929471,,,,,,,,,,,,,, -473400,4.6989603,0.6158864,,,,,,,,,,,,,, -473500,4.531634,0.64988154,,,,,,,,,,,,,, -473600,5.2678823,0.5961142,,,,,,,,,,,,,, -473700,4.5234337,0.6142261,,,,,,,,,,,,,, -473772,,,0.9618542790412904,0.1431065946817398,0.7560200095176697,1.0513827800750732,50000.0,0.631100058555603,1.817054986953736,10000.0,159690.03354144096,165317.8332746029,159690.03354144096,5586.219715356827,22.854050636291504,0.0 -473800,4.470822,0.6139151,,,,,,,,,,,,,, -473900,4.411835,0.65160733,,,,,,,,,,,,,, -474000,4.475806,0.65252054,,,,,,,,,,,,,, -474100,4.639913,0.56828326,,,,,,,,,,,,,, -474200,5.269049,0.6000123,,,,,,,,,,,,,, -474300,4.434476,0.61283237,,,,,,,,,,,,,, -474400,4.429776,0.59699035,,,,,,,,,,,,,, -474500,4.354177,0.59306157,,,,,,,,,,,,,, -474600,4.1822953,0.5802151,,,,,,,,,,,,,, -474700,4.9142466,0.7037332,,,,,,,,,,,,,, -474800,4.612373,0.68171316,,,,,,,,,,,,,, -474900,4.112434,0.57001495,,,,,,,,,,,,,, -475000,4.4587407,0.6333604,,,,,,,,,,,,,, -475100,4.469659,0.64382356,,,,,,,,,,,,,, -475200,4.77536,0.61915284,,,,,,,,,,,,,, -475286,,,0.9608976244926452,0.1456860154867172,0.7561599612236023,1.051494836807251,50000.0,0.6313000321388245,1.8169790506362915,10000.0,160199.90467762947,165844.797778368,160199.90467762947,5603.152618169785,22.95527720451355,0.0 -475300,4.693511,0.5908392,,,,,,,,,,,,,, -475400,4.796514,0.6717345,,,,,,,,,,,,,, -475500,4.586147,0.5898753,,,,,,,,,,,,,, -475600,4.7731886,0.63058877,,,,,,,,,,,,,, -475700,4.2593713,0.58741266,,,,,,,,,,,,,, -475800,4.330352,0.60007644,,,,,,,,,,,,,, -475900,4.3057594,0.61237925,,,,,,,,,,,,,, -476000,4.49925,0.61573416,,,,,,,,,,,,,, -476100,4.6774282,0.61413753,,,,,,,,,,,,,, -476200,4.490505,0.61705923,,,,,,,,,,,,,, -476300,4.4313984,0.54818165,,,,,,,,,,,,,, -476400,4.302511,0.6553137,,,,,,,,,,,,,, -476500,5.024636,0.674431,,,,,,,,,,,,,, -476600,4.424009,0.5835403,,,,,,,,,,,,,, -476700,4.588826,0.6221323,,,,,,,,,,,,,, -476800,,,0.9608577489852904,0.1460122913122177,0.7555800080299377,1.0520609617233276,50000.0,0.6307000517845154,1.8172602653503416,10000.0,160710.0825135708,166372.19257593155,160710.0825135708,5620.208205223084,23.058412313461304,0.0 -476800,4.399397,0.5688158,,,,,,,,,,,,,, -476900,4.4074397,0.57906723,,,,,,,,,,,,,, -477000,4.226022,0.63122493,,,,,,,,,,,,,, -477100,4.8088403,0.6530325,,,,,,,,,,,,,, -477200,4.615044,0.62181723,,,,,,,,,,,,,, -477300,4.846245,0.69402623,,,,,,,,,,,,,, -477400,4.2308702,0.55076665,,,,,,,,,,,,,, -477500,4.308645,0.58033586,,,,,,,,,,,,,, -477600,4.631323,0.5888501,,,,,,,,,,,,,, -477700,4.3036227,0.56478065,,,,,,,,,,,,,, -477800,4.6895576,0.7015382,,,,,,,,,,,,,, -477900,4.260207,0.6175538,,,,,,,,,,,,,, -478000,4.6432567,0.6508823,,,,,,,,,,,,,, -478100,4.163599,0.6405585,,,,,,,,,,,,,, -478200,4.4793963,0.6071214,,,,,,,,,,,,,, -478300,4.3124375,0.66560477,,,,,,,,,,,,,, -478314,,,0.9600406289100648,0.1492892503738403,0.7554000020027161,1.0533448457717896,50000.0,0.6310000419616699,1.819790363311768,10000.0,161220.12222194672,166899.53250861168,161220.12222194672,5637.344567537308,23.16335129737854,0.0 -478400,4.6066947,0.62053645,,,,,,,,,,,,,, -478500,4.34228,0.6967187,,,,,,,,,,,,,, -478600,5.120399,0.6072108,,,,,,,,,,,,,, -478700,4.4503026,0.6227905,,,,,,,,,,,,,, -478800,4.5517116,0.66180843,,,,,,,,,,,,,, -478900,4.3647513,0.60903126,,,,,,,,,,,,,, -479000,4.7706027,0.60228986,,,,,,,,,,,,,, -479100,3.9905493,0.5650461,,,,,,,,,,,,,, -479200,4.6836734,0.66702896,,,,,,,,,,,,,, -479300,4.090514,0.6292099,,,,,,,,,,,,,, -479400,4.5787473,0.7157301,,,,,,,,,,,,,, -479500,4.4462514,0.57026106,,,,,,,,,,,,,, -479600,4.5980744,0.67364866,,,,,,,,,,,,,, -479700,4.331855,0.52787757,,,,,,,,,,,,,, -479800,5.2847533,0.62436515,,,,,,,,,,,,,, -479827,,,0.9615553021430968,0.1434130817651748,0.755620002746582,1.0522792339324951,50000.0,0.6304000020027161,1.8189347982406616,10000.0,161729.96871495247,167426.6479022503,161729.96871495247,5654.447158336639,23.269815683364868,0.0 -479900,4.651594,0.6738502,,,,,,,,,,,,,, -480000,4.483853,0.61158615,,,,,,,,,,,,,, -480100,4.937365,0.6631569,,,,,,,,,,,,,, -480200,4.635574,0.67203975,,,,,,,,,,,,,, -480300,4.5486197,0.6445265,,,,,,,,,,,,,, -480400,4.557753,0.6400571,,,,,,,,,,,,,, -480500,4.5909634,0.66192436,,,,,,,,,,,,,, -480600,4.6846843,0.61174965,,,,,,,,,,,,,, -480700,4.025608,0.51509035,,,,,,,,,,,,,, -480800,5.2064214,0.6670109,,,,,,,,,,,,,, -480900,4.6024623,0.68034774,,,,,,,,,,,,,, -481000,4.6273623,0.58877194,,,,,,,,,,,,,, -481100,4.52229,0.59835434,,,,,,,,,,,,,, -481200,4.8637705,0.6669262,,,,,,,,,,,,,, -481300,4.617269,0.6312505,,,,,,,,,,,,,, -481342,,,0.9592434167861938,0.1493992060422897,0.7554799914360046,1.052570343017578,50000.0,0.6318000555038452,1.81726336479187,10000.0,162239.8997218609,167953.64355373383,162239.8997218609,5671.34614443779,23.37634825706482,0.0 -481400,4.329079,0.60817164,,,,,,,,,,,,,, -481500,4.1822543,0.6776261,,,,,,,,,,,,,, -481600,5.1376257,0.66912884,,,,,,,,,,,,,, -481700,4.352134,0.5946742,,,,,,,,,,,,,, -481800,5.1103563,0.57112116,,,,,,,,,,,,,, -481900,4.7549834,0.6376563,,,,,,,,,,,,,, -482000,4.433591,0.62635785,,,,,,,,,,,,,, -482100,4.1623483,0.5503097,,,,,,,,,,,,,, -482200,4.238545,0.63588727,,,,,,,,,,,,,, -482300,4.517074,0.58487153,,,,,,,,,,,,,, -482400,4.347212,0.63214254,,,,,,,,,,,,,, -482500,4.72902,0.6266161,,,,,,,,,,,,,, -482600,4.5497117,0.60753334,,,,,,,,,,,,,, -482700,4.4009395,0.57012,,,,,,,,,,,,,, -482800,4.243382,0.526187,,,,,,,,,,,,,, -482856,,,0.9616549611091614,0.1477699726819992,0.7559999823570251,1.0522754192352295,50000.0,0.631600022315979,1.818324089050293,10000.0,162750.039809227,168481.42416000366,162750.039809227,5688.821419477463,23.48125386238098,0.0 -482900,4.933033,0.6242156,,,,,,,,,,,,,, -483000,4.3574347,0.61371744,,,,,,,,,,,,,, -483100,4.5971007,0.66005766,,,,,,,,,,,,,, -483200,4.385399,0.62964416,,,,,,,,,,,,,, -483300,4.0437107,0.56325865,,,,,,,,,,,,,, -483400,4.4972973,0.58959836,,,,,,,,,,,,,, -483500,4.439268,0.6203313,,,,,,,,,,,,,, -483600,4.5237107,0.6065707,,,,,,,,,,,,,, -483700,4.401481,0.6459951,,,,,,,,,,,,,, -483800,4.6709256,0.63985765,,,,,,,,,,,,,, -483900,4.363051,0.620615,,,,,,,,,,,,,, -484000,4.2600746,0.57679254,,,,,,,,,,,,,, -484100,4.572079,0.67374265,,,,,,,,,,,,,, -484200,4.4952903,0.60262513,,,,,,,,,,,,,, -484300,4.4408717,0.53001934,,,,,,,,,,,,,, -484371,,,0.960180163383484,0.1472534388303756,0.7554799914360046,1.0523960590362549,50000.0,0.6307000517845154,1.8176476955413816,10000.0,163259.96401071548,169008.52561068535,163259.96401071548,5705.831904172897,23.58875036239624,0.0 -484400,4.760165,0.65564,,,,,,,,,,,,,, -484500,4.7837353,0.72589606,,,,,,,,,,,,,, -484600,4.363959,0.5694989,,,,,,,,,,,,,, -484700,4.535098,0.6303527,,,,,,,,,,,,,, -484800,4.0136313,0.6123528,,,,,,,,,,,,,, -484900,4.118298,0.59348786,,,,,,,,,,,,,, -485000,5.3851876,0.6695934,,,,,,,,,,,,,, -485100,4.7759614,0.6109489,,,,,,,,,,,,,, -485200,4.438262,0.64367145,,,,,,,,,,,,,, -485300,4.2659235,0.6260627,,,,,,,,,,,,,, -485400,4.900657,0.6428524,,,,,,,,,,,,,, -485500,4.298609,0.59447896,,,,,,,,,,,,,, -485600,4.729836,0.6343228,,,,,,,,,,,,,, -485700,4.3886642,0.5767304,,,,,,,,,,,,,, -485800,4.914173,0.66473365,,,,,,,,,,,,,, -485885,,,0.9604790806770324,0.1493615359067917,0.7560799717903137,1.051154851913452,50000.0,0.6325000524520874,1.8172670602798464,10000.0,163769.99699783325,169535.86717271805,163769.99699783325,5722.962559461594,23.70534729957581,0.0 -485900,4.2742124,0.58234733,,,,,,,,,,,,,, -486000,4.4194784,0.6275468,,,,,,,,,,,,,, -486100,4.8193727,0.6418907,,,,,,,,,,,,,, -486200,4.806871,0.6354436,,,,,,,,,,,,,, -486300,4.486638,0.595515,,,,,,,,,,,,,, -486400,4.764167,0.6721338,,,,,,,,,,,,,, -486500,4.4995627,0.63683283,,,,,,,,,,,,,, -486600,4.5395055,0.64265496,,,,,,,,,,,,,, -486700,4.4658337,0.60717666,,,,,,,,,,,,,, -486800,4.538576,0.56937504,,,,,,,,,,,,,, -486900,4.5018554,0.62520474,,,,,,,,,,,,,, -487000,4.920142,0.6119408,,,,,,,,,,,,,, -487100,4.7988825,0.64088607,,,,,,,,,,,,,, -487200,4.692435,0.68879855,,,,,,,,,,,,,, -487300,4.442206,0.6665833,,,,,,,,,,,,,, -487398,,,0.961316168308258,0.1479690223932266,0.7558799982070923,1.0512391328811646,50000.0,0.631100058555603,1.816259503364563,10000.0,164279.91232395172,170062.991219759,164279.91232395172,5739.9425756931305,23.87461256980896,0.0 -487400,4.3249774,0.6025935,,,,,,,,,,,,,, -487500,4.3045616,0.5906047,,,,,,,,,,,,,, -487600,4.3167896,0.59733677,,,,,,,,,,,,,, -487700,4.709547,0.61846834,,,,,,,,,,,,,, -487800,4.4992948,0.6532253,,,,,,,,,,,,,, -487900,4.557075,0.650617,,,,,,,,,,,,,, -488000,4.5109086,0.6489061,,,,,,,,,,,,,, -488100,4.69579,0.6781323,,,,,,,,,,,,,, -488200,4.289351,0.6042473,,,,,,,,,,,,,, -488300,4.2359805,0.64510024,,,,,,,,,,,,,, -488400,4.347886,0.5768142,,,,,,,,,,,,,, -488500,4.6903954,0.72553074,,,,,,,,,,,,,, -488600,4.069993,0.5812393,,,,,,,,,,,,,, -488700,4.7413287,0.71508914,,,,,,,,,,,,,, -488800,4.566388,0.6020552,,,,,,,,,,,,,, -488900,4.1203375,0.5931647,,,,,,,,,,,,,, -488912,,,0.959980845451355,0.1480510532855987,0.7559999823570251,1.0511592626571655,50000.0,0.6306000351905823,1.817075252532959,10000.0,164789.9333343506,170590.12753081322,164789.9333343506,5756.892340660095,23.98096513748169,0.0 -489000,4.8251243,0.64527404,,,,,,,,,,,,,, -489100,4.508859,0.6786096,,,,,,,,,,,,,, -489200,4.3987975,0.68472683,,,,,,,,,,,,,, -489300,4.4682903,0.6265086,,,,,,,,,,,,,, -489400,3.9866168,0.54478693,,,,,,,,,,,,,, -489500,4.55808,0.61951846,,,,,,,,,,,,,, -489600,4.2995696,0.5729458,,,,,,,,,,,,,, -489700,4.3549385,0.58208483,,,,,,,,,,,,,, -489800,4.659635,0.61268157,,,,,,,,,,,,,, -489900,4.264014,0.6237724,,,,,,,,,,,,,, -490000,4.14905,0.5611507,,,,,,,,,,,,,, -490100,4.6629343,0.66625226,,,,,,,,,,,,,, -490200,4.565597,0.60048956,,,,,,,,,,,,,, -490300,4.7407,0.6076587,,,,,,,,,,,,,, -490400,4.920978,0.64454293,,,,,,,,,,,,,, -490426,,,0.9606783986091614,0.1471062153577804,0.7559399604797363,1.0521148443222046,50000.0,0.6318000555038452,1.8199012279510496,10000.0,165299.89897084236,171117.3261051178,165299.89897084236,5773.955453395844,24.091283559799194,0.0 -490500,4.2501707,0.55182487,,,,,,,,,,,,,, -490600,4.2169013,0.609279,,,,,,,,,,,,,, -490700,4.6404986,0.6731605,,,,,,,,,,,,,, -490800,4.867479,0.62074614,,,,,,,,,,,,,, -490900,4.358428,0.5345838,,,,,,,,,,,,,, -491000,4.675979,0.7085463,,,,,,,,,,,,,, -491100,4.515914,0.61638284,,,,,,,,,,,,,, -491200,4.5666757,0.6519047,,,,,,,,,,,,,, -491300,4.352965,0.6007609,,,,,,,,,,,,,, -491400,4.01847,0.65898025,,,,,,,,,,,,,, -491500,4.481408,0.69218385,,,,,,,,,,,,,, -491600,4.2966065,0.5593092,,,,,,,,,,,,,, -491700,4.781749,0.6876184,,,,,,,,,,,,,, -491800,4.4711027,0.55295,,,,,,,,,,,,,, -491900,4.391227,0.63926905,,,,,,,,,,,,,, -491940,,,0.9601004123687744,0.1487551927566528,0.7557799816131592,1.0525541305541992,50000.0,0.6307000517845154,1.8181631565094,10000.0,165809.91813397408,171644.57608389854,165809.91813397408,5791.007390499115,24.209453582763672,0.0 -492000,4.0168877,0.55834264,,,,,,,,,,,,,, -492100,4.5880275,0.66829914,,,,,,,,,,,,,, -492200,4.045303,0.5502541,,,,,,,,,,,,,, -492300,4.246262,0.6184602,,,,,,,,,,,,,, -492400,4.387776,0.637257,,,,,,,,,,,,,, -492500,4.935376,0.56847405,,,,,,,,,,,,,, -492600,4.4576178,0.6614807,,,,,,,,,,,,,, -492700,4.596197,0.6730777,,,,,,,,,,,,,, -492800,4.724611,0.6319592,,,,,,,,,,,,,, -492900,4.5866265,0.6470565,,,,,,,,,,,,,, -493000,4.8300443,0.65509987,,,,,,,,,,,,,, -493100,4.561226,0.6177729,,,,,,,,,,,,,, -493200,4.328481,0.5994024,,,,,,,,,,,,,, -493300,4.213383,0.61820006,,,,,,,,,,,,,, -493400,4.84382,0.61721814,,,,,,,,,,,,,, -493454,,,0.9607381820678712,0.1470497101545334,0.7558000087738037,1.0523558855056765,50000.0,0.6310000419616699,1.8182610273361208,10000.0,166319.89604902267,172171.9135248661,166319.89604902267,5808.1975655555725,24.3194260597229,0.0 -493500,4.7207804,0.63220096,,,,,,,,,,,,,, -493600,4.1651263,0.6116395,,,,,,,,,,,,,, -493700,4.4364314,0.63068974,,,,,,,,,,,,,, -493800,4.827646,0.59523094,,,,,,,,,,,,,, -493900,4.165939,0.5873873,,,,,,,,,,,,,, -494000,4.460764,0.5995009,,,,,,,,,,,,,, -494100,4.558571,0.58209676,,,,,,,,,,,,,, -494200,4.326722,0.7179225,,,,,,,,,,,,,, -494300,4.704471,0.6256545,,,,,,,,,,,,,, -494400,4.4868116,0.6402371,,,,,,,,,,,,,, -494500,4.312722,0.5394179,,,,,,,,,,,,,, -494600,4.522492,0.68348604,,,,,,,,,,,,,, -494700,4.5600677,0.70375085,,,,,,,,,,,,,, -494800,4.2327046,0.60114056,,,,,,,,,,,,,, -494900,4.9667273,0.62200135,,,,,,,,,,,,,, -494968,,,0.961575210094452,0.1471329629421234,0.7560799717903137,1.0521728992462158,50000.0,0.6318000555038452,1.818455576896668,10000.0,166830.0151219368,172699.20630049706,166830.0151219368,5825.202464342117,24.42889618873596,0.0 -495000,4.653445,0.65466917,,,,,,,,,,,,,, -495100,4.170272,0.5604169,,,,,,,,,,,,,, -495200,4.700253,0.62826496,,,,,,,,,,,,,, -495300,4.4697285,0.60598683,,,,,,,,,,,,,, -495400,4.7576036,0.6827654,,,,,,,,,,,,,, -495500,4.57534,0.73643965,,,,,,,,,,,,,, -495600,4.8109674,0.66823155,,,,,,,,,,,,,, -495700,4.5418286,0.627295,,,,,,,,,,,,,, -495800,4.046079,0.581704,,,,,,,,,,,,,, -495900,4.294716,0.638458,,,,,,,,,,,,,, -496000,4.1651206,0.65329814,,,,,,,,,,,,,, -496100,4.6448445,0.677246,,,,,,,,,,,,,, -496200,4.1335654,0.5960811,,,,,,,,,,,,,, -496300,4.6113796,0.67741257,,,,,,,,,,,,,, -496400,4.1373706,0.6722068,,,,,,,,,,,,,, -496483,,,0.961316168308258,0.1451499611139297,0.7558799982070923,1.0518877506256104,50000.0,0.6313000321388245,1.8172544240951536,10000.0,167340.15970230105,173226.34223484993,167340.15970230105,5842.022374391556,24.539676189422607,0.0 -496500,4.6055255,0.63424456,,,,,,,,,,,,,, -496600,4.2624,0.55116856,,,,,,,,,,,,,, -496700,4.91605,0.6147855,,,,,,,,,,,,,, -496800,4.3416996,0.58550036,,,,,,,,,,,,,, -496900,5.011578,0.60477215,,,,,,,,,,,,,, -497000,4.3075943,0.58127344,,,,,,,,,,,,,, -497100,4.7659063,0.60635614,,,,,,,,,,,,,, -497200,4.5458956,0.63491803,,,,,,,,,,,,,, -497300,4.53263,0.63716614,,,,,,,,,,,,,, -497400,4.697937,0.6382452,,,,,,,,,,,,,, -497500,4.3199115,0.59031975,,,,,,,,,,,,,, -497600,4.5605774,0.6082181,,,,,,,,,,,,,, -497700,4.223454,0.6000855,,,,,,,,,,,,,, -497800,4.4499993,0.6164842,,,,,,,,,,,,,, -497900,4.7470446,0.6224764,,,,,,,,,,,,,, -497996,,,0.959781527519226,0.1461669504642486,0.7559999823570251,1.0516399145126345,50000.0,0.6314000487327576,1.8174744844436648,10000.0,167850.05058646202,173753.38080906868,167850.05058646202,5859.004626750946,24.64591932296753,0.0 -498000,4.122295,0.55898947,,,,,,,,,,,,,, -498100,4.1155496,0.60102665,,,,,,,,,,,,,, -498200,4.4202223,0.58004564,,,,,,,,,,,,,, -498300,4.9758096,0.69987196,,,,,,,,,,,,,, -498400,4.610713,0.70522267,,,,,,,,,,,,,, -498500,4.424975,0.5237583,,,,,,,,,,,,,, -498600,4.7228103,0.6614841,,,,,,,,,,,,,, -498700,4.5825877,0.6220019,,,,,,,,,,,,,, -498800,4.5620427,0.62485343,,,,,,,,,,,,,, -498900,4.2645473,0.6232402,,,,,,,,,,,,,, -499000,4.500779,0.6124838,,,,,,,,,,,,,, -499100,4.6035953,0.65475535,,,,,,,,,,,,,, -499200,4.2012553,0.58650357,,,,,,,,,,,,,, -499300,4.068448,0.6072315,,,,,,,,,,,,,, -499400,4.50318,0.6109583,,,,,,,,,,,,,, -499500,4.4531865,0.6095779,,,,,,,,,,,,,, -499510,,,0.9612762928009032,0.1455503702163696,0.7559399604797363,1.0524622201919556,50000.0,0.6314000487327576,1.8173248767852783,10000.0,168359.94944262505,174280.4545059204,168359.94944262505,5876.005375862122,24.75995421409607,0.0 -499600,4.62426,0.7102594,,,,,,,,,,,,,, -499700,4.235594,0.58306026,,,,,,,,,,,,,, -499800,4.354459,0.5947184,,,,,,,,,,,,,, -499900,4.1947045,0.64360404,,,,,,,,,,,,,, -500000,4.2019014,0.66181356,,,,,,,,,,,,,, -500100,4.169038,0.6110786,,,,,,,,,,,,,, -500200,4.0495,0.55347085,,,,,,,,,,,,,, -500300,4.2859645,0.6224345,,,,,,,,,,,,,, -500400,4.145273,0.62533075,,,,,,,,,,,,,, -500500,5.1524267,0.6297842,,,,,,,,,,,,,, -500600,4.4653487,0.699865,,,,,,,,,,,,,, -500700,4.59125,0.6355528,,,,,,,,,,,,,, -500800,4.623864,0.61719924,,,,,,,,,,,,,, -500900,4.6443567,0.57848805,,,,,,,,,,,,,, -501000,4.4124846,0.6709636,,,,,,,,,,,,,, -501023,,,0.9595224857330322,0.1492321044206619,0.7557399868965149,1.052715539932251,50000.0,0.6304000020027161,1.8189656734466555,10000.0,168869.80708909035,174808.23965120316,168869.80708909035,5893.762127637863,24.870811223983765,0.0 -501100,4.5093856,0.6440718,,,,,,,,,,,,,, -501200,4.929437,0.6391582,,,,,,,,,,,,,, -501300,4.4838433,0.624926,,,,,,,,,,,,,, -501400,4.528539,0.59342164,,,,,,,,,,,,,, -501500,4.3960567,0.61186755,,,,,,,,,,,,,, -501600,4.4996243,0.6025074,,,,,,,,,,,,,, -501700,4.638381,0.73561376,,,,,,,,,,,,,, -501800,4.3440466,0.6168503,,,,,,,,,,,,,, -501900,4.3178453,0.6495461,,,,,,,,,,,,,, -502000,4.3288407,0.5615499,,,,,,,,,,,,,, -502100,4.1524444,0.5461407,,,,,,,,,,,,,, -502200,4.477563,0.5555299,,,,,,,,,,,,,, -502300,4.6584797,0.7022973,,,,,,,,,,,,,, -502400,4.3163137,0.61161953,,,,,,,,,,,,,, -502500,4.594288,0.6333165,,,,,,,,,,,,,, -502537,,,0.9601004123687744,0.1491653770208358,0.755840003490448,1.051175236701965,50000.0,0.6312000155448914,1.8172820806503296,10000.0,169379.7861609459,175335.23373770714,169379.7861609459,5910.607840776444,24.980355262756348,0.0 -502600,4.762227,0.6548552,,,,,,,,,,,,,, -502700,4.5265245,0.60777897,,,,,,,,,,,,,, -502800,4.8324714,0.64628786,,,,,,,,,,,,,, -502900,4.3723607,0.60722816,,,,,,,,,,,,,, -503000,4.4638853,0.62726116,,,,,,,,,,,,,, -503100,5.0500875,0.73960924,,,,,,,,,,,,,, -503200,4.296768,0.60998654,,,,,,,,,,,,,, -503300,4.6042886,0.6786799,,,,,,,,,,,,,, -503400,4.481007,0.5892382,,,,,,,,,,,,,, -503500,4.435634,0.5885346,,,,,,,,,,,,,, -503600,4.5829296,0.6521653,,,,,,,,,,,,,, -503700,4.9935613,0.6513947,,,,,,,,,,,,,, -503800,4.486771,0.60780275,,,,,,,,,,,,,, -503900,4.4951677,0.5455981,,,,,,,,,,,,,, -504000,4.4419723,0.6243311,,,,,,,,,,,,,, -504051,,,0.960758090019226,0.1461090445518493,0.7560399770736694,1.0522167682647705,50000.0,0.6308000087738037,1.8169641494750977,10000.0,169889.70877027512,175862.2305700779,169889.70877027512,5927.514526128769,25.08798623085022,0.0 -504100,4.827312,0.62241185,,,,,,,,,,,,,, -504200,4.2835445,0.64387983,,,,,,,,,,,,,, -504300,4.4953847,0.65964544,,,,,,,,,,,,,, -504400,4.6897783,0.5912537,,,,,,,,,,,,,, -504500,4.265568,0.552782,,,,,,,,,,,,,, -504600,4.6986074,0.666635,,,,,,,,,,,,,, -504700,4.842603,0.59545857,,,,,,,,,,,,,, -504800,4.509203,0.5976815,,,,,,,,,,,,,, -504900,4.8948836,0.68326896,,,,,,,,,,,,,, -505000,4.2344856,0.56061137,,,,,,,,,,,,,, -505100,4.4708633,0.62360317,,,,,,,,,,,,,, -505200,4.1935596,0.57768035,,,,,,,,,,,,,, -505300,4.3204885,0.5516613,,,,,,,,,,,,,, -505400,4.4214954,0.63260937,,,,,,,,,,,,,, -505500,4.809621,0.63652205,,,,,,,,,,,,,, -505566,,,0.9611766338348388,0.1460204273462295,0.7558199763298035,1.052260160446167,50000.0,0.6320000290870667,1.818225264549256,10000.0,170399.8674080372,176389.4013133049,170399.8674080372,5944.358859062195,25.19463801383972,0.0 -505600,4.3756447,0.5686266,,,,,,,,,,,,,, -505700,4.3043933,0.65515894,,,,,,,,,,,,,, -505800,4.375067,0.6486604,,,,,,,,,,,,,, -505900,5.2392097,0.6340671,,,,,,,,,,,,,, -506000,5.0673237,0.606063,,,,,,,,,,,,,, -506100,3.7732694,0.54469717,,,,,,,,,,,,,, -506200,4.1503267,0.53821445,,,,,,,,,,,,,, -506300,4.6083174,0.60595,,,,,,,,,,,,,, -506400,4.7069435,0.63712436,,,,,,,,,,,,,, -506500,4.258624,0.61115867,,,,,,,,,,,,,, -506600,4.7193747,0.6046823,,,,,,,,,,,,,, -506700,4.58428,0.55750054,,,,,,,,,,,,,, -506800,4.6755137,0.5918808,,,,,,,,,,,,,, -506900,4.4357057,0.5886559,,,,,,,,,,,,,, -507000,4.0283136,0.53558326,,,,,,,,,,,,,, -507080,,,0.9620934128761292,0.1431600302457809,0.7554999589920044,1.0530349016189575,50000.0,0.6301000118255615,1.8179951906204224,10000.0,170909.96117901802,176917.21332883835,170909.96117901802,5961.896793842316,25.31332540512085,0.0 -507100,4.483187,0.5957477,,,,,,,,,,,,,, -507200,4.6273856,0.5896429,,,,,,,,,,,,,, -507300,4.540632,0.5773488,,,,,,,,,,,,,, -507400,4.371722,0.66872954,,,,,,,,,,,,,, -507500,4.402944,0.6077099,,,,,,,,,,,,,, -507600,4.7838416,0.6478161,,,,,,,,,,,,,, -507700,4.1493907,0.58022606,,,,,,,,,,,,,, -507800,4.828892,0.6733383,,,,,,,,,,,,,, -507900,4.324519,0.6290426,,,,,,,,,,,,,, -508000,4.6844873,0.70519507,,,,,,,,,,,,,, -508100,4.3207374,0.5505259,,,,,,,,,,,,,, -508200,4.680224,0.6693849,,,,,,,,,,,,,, -508300,4.2667723,0.5407797,,,,,,,,,,,,,, -508400,4.0699015,0.5388973,,,,,,,,,,,,,, -508500,4.1388087,0.59186447,,,,,,,,,,,,,, -508594,,,0.9612762928009032,0.1459125131368637,0.7559399604797363,1.0524475574493408,50000.0,0.6305000185966492,1.8188080787658687,10000.0,171419.9574327469,177444.34196448326,171419.9574327469,5978.876355171204,25.40578317642212,0.0 -508600,4.670553,0.72820264,,,,,,,,,,,,,, -508700,4.5547442,0.60349524,,,,,,,,,,,,,, -508800,4.6319647,0.55803037,,,,,,,,,,,,,, -508900,4.1973906,0.57791275,,,,,,,,,,,,,, -509000,4.367719,0.68353707,,,,,,,,,,,,,, -509100,4.354681,0.61999804,,,,,,,,,,,,,, -509200,4.5432734,0.6630509,,,,,,,,,,,,,, -509300,4.492743,0.6835011,,,,,,,,,,,,,, -509400,4.2335873,0.5294221,,,,,,,,,,,,,, -509500,4.422803,0.60703933,,,,,,,,,,,,,, -509600,4.65951,0.67865646,,,,,,,,,,,,,, -509700,4.3814354,0.63754046,,,,,,,,,,,,,, -509800,4.5126085,0.6258767,,,,,,,,,,,,,, -509900,4.414112,0.57928944,,,,,,,,,,,,,, -510000,4.693563,0.6094817,,,,,,,,,,,,,, -510100,4.743976,0.68170404,,,,,,,,,,,,,, -510108,,,0.9589046239852904,0.1507239043712616,0.7554999589920044,1.0527056455612185,50000.0,0.6309000253677368,1.81772255897522,10000.0,171929.94116711617,177971.4310863018,171929.94116711617,5995.808697462082,25.51852774620056,0.0 -510200,5.0619903,0.71713686,,,,,,,,,,,,,, -510300,4.3343434,0.51366234,,,,,,,,,,,,,, -510400,4.4856286,0.64238936,,,,,,,,,,,,,, -510500,4.4096627,0.65271854,,,,,,,,,,,,,, -510600,4.604453,0.6318466,,,,,,,,,,,,,, -510700,4.434475,0.52650696,,,,,,,,,,,,,, -510800,4.37209,0.60636485,,,,,,,,,,,,,, -510900,4.3654027,0.56410867,,,,,,,,,,,,,, -511000,4.463192,0.59536964,,,,,,,,,,,,,, -511100,4.826702,0.6881205,,,,,,,,,,,,,, -511200,4.1280103,0.60168004,,,,,,,,,,,,,, -511300,4.469057,0.6021688,,,,,,,,,,,,,, -511400,4.719822,0.6738677,,,,,,,,,,,,,, -511500,4.4852552,0.61778957,,,,,,,,,,,,,, -511600,4.136077,0.5598187,,,,,,,,,,,,,, -511622,,,0.9614157676696776,0.1448870450258255,0.7563799619674683,1.051182508468628,50000.0,0.6308000087738037,1.8169662952423096,10000.0,172439.7688140869,178498.42212200165,172439.7688140869,6012.799216747284,25.63161373138428,0.0 -511700,4.702733,0.6188113,,,,,,,,,,,,,, -511800,4.4676566,0.648957,,,,,,,,,,,,,, -511900,3.8950942,0.6018677,,,,,,,,,,,,,, -512000,4.4725075,0.67949665,,,,,,,,,,,,,, -512100,4.305386,0.59417987,,,,,,,,,,,,,, -512200,4.2832217,0.6272803,,,,,,,,,,,,,, -512300,4.5431175,0.6492616,,,,,,,,,,,,,, -512400,4.0772195,0.578122,,,,,,,,,,,,,, -512500,4.376976,0.6509539,,,,,,,,,,,,,, -512600,4.5298624,0.6282049,,,,,,,,,,,,,, -512700,4.486134,0.66918176,,,,,,,,,,,,,, -512800,4.613148,0.6263405,,,,,,,,,,,,,, -512900,4.298248,0.61055815,,,,,,,,,,,,,, -513000,4.930253,0.68981636,,,,,,,,,,,,,, -513100,4.378836,0.56685555,,,,,,,,,,,,,, -513136,,,0.9617745280265808,0.1449913829565048,0.7557199597358704,1.0533603429794312,50000.0,0.6307000517845154,1.818769216537476,10000.0,172949.87129354477,179025.70138788223,172949.87129354477,6029.798710823059,25.749578952789307,0.0 -513200,4.3602586,0.62154895,,,,,,,,,,,,,, -513300,4.574991,0.628345,,,,,,,,,,,,,, -513400,3.9576535,0.5502986,,,,,,,,,,,,,, -513500,4.3401284,0.5794183,,,,,,,,,,,,,, -513600,4.4311485,0.6224252,,,,,,,,,,,,,, -513700,4.259812,0.6004636,,,,,,,,,,,,,, -513800,4.575557,0.68558097,,,,,,,,,,,,,, -513900,5.223268,0.6593175,,,,,,,,,,,,,, -514000,4.271974,0.61030823,,,,,,,,,,,,,, -514100,4.4747396,0.58705527,,,,,,,,,,,,,, -514200,4.9921184,0.6408255,,,,,,,,,,,,,, -514300,4.7542744,0.619581,,,,,,,,,,,,,, -514400,5.081372,0.63559407,,,,,,,,,,,,,, -514500,4.3421926,0.6770216,,,,,,,,,,,,,, -514600,4.3559847,0.5941094,,,,,,,,,,,,,, -514650,,,0.9599210619926452,0.1473044008016586,0.7555199861526489,1.0519886016845703,50000.0,0.6312000155448914,1.817663311958313,10000.0,173459.8376841545,179553.39589834213,173459.8376841545,6047.356439828873,25.85953950881958,0.0 -514700,4.3396935,0.6148159,,,,,,,,,,,,,, -514800,4.9485817,0.6499742,,,,,,,,,,,,,, -514900,4.8477197,0.57667726,,,,,,,,,,,,,, -515000,4.9701095,0.677131,,,,,,,,,,,,,, -515100,4.2506723,0.5642487,,,,,,,,,,,,,, -515200,4.8686066,0.6154714,,,,,,,,,,,,,, -515300,4.7092557,0.66299343,,,,,,,,,,,,,, -515400,4.0958223,0.60515106,,,,,,,,,,,,,, -515500,4.908013,0.6729272,,,,,,,,,,,,,, -515600,4.5714016,0.625965,,,,,,,,,,,,,, -515700,4.3049893,0.54594886,,,,,,,,,,,,,, -515800,4.1728773,0.5622983,,,,,,,,,,,,,, -515900,4.575998,0.6363777,,,,,,,,,,,,,, -516000,4.403699,0.6548744,,,,,,,,,,,,,, -516100,4.354911,0.5754123,,,,,,,,,,,,,, -516163,,,0.9608976244926452,0.1467035561800003,0.7555800080299377,1.052548050880432,50000.0,0.6307000517845154,1.8186434507369995,10000.0,173969.71770739555,180080.4456114769,173969.71770739555,6064.35414147377,25.97298693656921,0.0 -516200,4.4684453,0.58812726,,,,,,,,,,,,,, -516300,4.274467,0.64466894,,,,,,,,,,,,,, -516400,4.2770624,0.5818384,,,,,,,,,,,,,, -516500,4.585129,0.6139046,,,,,,,,,,,,,, -516600,4.0348096,0.5959485,,,,,,,,,,,,,, -516700,4.5325146,0.66700435,,,,,,,,,,,,,, -516800,5.01639,0.6409915,,,,,,,,,,,,,, -516900,4.433695,0.55004257,,,,,,,,,,,,,, -517000,4.41572,0.6786707,,,,,,,,,,,,,, -517100,4.314918,0.5979891,,,,,,,,,,,,,, -517200,4.3815527,0.6132618,,,,,,,,,,,,,, -517300,4.7880807,0.61605936,,,,,,,,,,,,,, -517400,4.967336,0.6457533,,,,,,,,,,,,,, -517500,4.2126,0.66357774,,,,,,,,,,,,,, -517600,5.122212,0.68668604,,,,,,,,,,,,,, -517677,,,0.960598647594452,0.1472920030355453,0.7556399703025818,1.0521091222763062,50000.0,0.6319000124931335,1.816991925239563,10000.0,174479.7213087082,180607.57706213,174479.7213087082,6081.307334423065,26.08626389503479,0.0 -517700,4.244962,0.5539775,,,,,,,,,,,,,, -517800,4.210845,0.57258725,,,,,,,,,,,,,, -517900,4.641289,0.61977994,,,,,,,,,,,,,, -518000,4.4518647,0.64199376,,,,,,,,,,,,,, -518100,4.2499056,0.5930762,,,,,,,,,,,,,, -518200,4.22228,0.5918196,,,,,,,,,,,,,, -518300,4.408295,0.61755216,,,,,,,,,,,,,, -518400,4.6947775,0.5929785,,,,,,,,,,,,,, -518500,4.5297637,0.64902234,,,,,,,,,,,,,, -518600,4.1025944,0.68010974,,,,,,,,,,,,,, -518700,4.014032,0.53269154,,,,,,,,,,,,,, -518800,4.6574097,0.6664865,,,,,,,,,,,,,, -518900,4.515151,0.59114796,,,,,,,,,,,,,, -519000,4.38637,0.63400453,,,,,,,,,,,,,, -519100,4.768811,0.67692614,,,,,,,,,,,,,, -519191,,,0.9604790806770324,0.1465890109539032,0.7556799650192261,1.0526906251907349,50000.0,0.6306000351905823,1.8196641206741333,10000.0,174989.78865027428,181134.6708905697,174989.78865027428,6098.161105394363,26.196840047836304,0.0 -519200,4.36904,0.62426054,,,,,,,,,,,,,, -519300,4.3147297,0.6568284,,,,,,,,,,,,,, -519400,4.2037363,0.6103152,,,,,,,,,,,,,, -519500,4.614819,0.59723294,,,,,,,,,,,,,, -519600,4.5538454,0.6604755,,,,,,,,,,,,,, -519700,4.9978967,0.62787914,,,,,,,,,,,,,, -519800,4.5864296,0.63602537,,,,,,,,,,,,,, -519900,4.1503077,0.56083786,,,,,,,,,,,,,, -520000,4.3657966,0.59675956,,,,,,,,,,,,,, -520100,5.8694835,0.7120958,,,,,,,,,,,,,, -520200,4.565841,0.59766775,,,,,,,,,,,,,, -520300,4.380802,0.6508782,,,,,,,,,,,,,, -520400,4.456154,0.5830103,,,,,,,,,,,,,, -520500,4.8175006,0.7296916,,,,,,,,,,,,,, -520600,4.504771,0.5808375,,,,,,,,,,,,,, -520700,4.2397523,0.6594426,,,,,,,,,,,,,, -520704,,,0.96000075340271,0.1486792713403701,0.7560399770736694,1.0523520708084106,50000.0,0.6309000253677368,1.8196110725402832,10000.0,175499.8068125248,181661.85493016243,175499.8068125248,6115.153409481049,26.310802698135376,0.0 -520800,4.6664014,0.67834365,,,,,,,,,,,,,, -520900,4.5768256,0.5908804,,,,,,,,,,,,,, -521000,4.9954557,0.67494893,,,,,,,,,,,,,, -521100,4.4447145,0.61808544,,,,,,,,,,,,,, -521200,5.5607266,0.6433763,,,,,,,,,,,,,, -521300,4.086005,0.51025516,,,,,,,,,,,,,, -521400,4.9390492,0.7577436,,,,,,,,,,,,,, -521500,4.8960276,0.6254701,,,,,,,,,,,,,, -521600,4.0315256,0.5918791,,,,,,,,,,,,,, -521700,4.764166,0.6634483,,,,,,,,,,,,,, -521800,4.4593725,0.692344,,,,,,,,,,,,,, -521900,4.6776705,0.71175194,,,,,,,,,,,,,, -522000,4.8423553,0.62466353,,,,,,,,,,,,,, -522100,4.339914,0.58442396,,,,,,,,,,,,,, -522200,4.658407,0.6066722,,,,,,,,,,,,,, -522217,,,0.961734652519226,0.1452644020318985,0.7561399936676025,1.0519838333129885,50000.0,0.6314000487327576,1.819387555122376,10000.0,176009.86840701103,182189.1304523945,176009.86840701103,6132.192498922348,26.426344871521,0.0 -522300,4.5273075,0.5428221,,,,,,,,,,,,,, -522400,4.547388,0.6069962,,,,,,,,,,,,,, -522500,5.2831817,0.6850542,,,,,,,,,,,,,, -522600,4.5698414,0.61067593,,,,,,,,,,,,,, -522700,5.12237,0.68222886,,,,,,,,,,,,,, -522800,4.5004587,0.60070425,,,,,,,,,,,,,, -522900,4.5624065,0.656141,,,,,,,,,,,,,, -523000,4.725458,0.6417715,,,,,,,,,,,,,, -523100,4.4147425,0.5881216,,,,,,,,,,,,,, -523200,4.5751715,0.6346812,,,,,,,,,,,,,, -523300,4.8168926,0.6697272,,,,,,,,,,,,,, -523400,4.4868584,0.64480066,,,,,,,,,,,,,, -523500,4.2869453,0.57059777,,,,,,,,,,,,,, -523600,4.2210035,0.59074366,,,,,,,,,,,,,, -523700,4.2315903,0.56337845,,,,,,,,,,,,,, -523730,,,0.9601004123687744,0.1496659517288208,0.7558799982070923,1.052033305168152,50000.0,0.6317000389099121,1.817821979522705,10000.0,176519.7387931347,182716.08656191823,176519.7387931347,6149.103352069855,26.54132080078125,0.0 -523800,4.106353,0.6328492,,,,,,,,,,,,,, -523900,4.147909,0.60882413,,,,,,,,,,,,,, -524000,4.8330927,0.61030096,,,,,,,,,,,,,, -524100,4.4124284,0.6163362,,,,,,,,,,,,,, -524200,4.444617,0.6329635,,,,,,,,,,,,,, -524300,5.2264347,0.6110259,,,,,,,,,,,,,, -524400,4.3135886,0.597393,,,,,,,,,,,,,, -524500,4.642392,0.69187933,,,,,,,,,,,,,, -524600,5.216236,0.6159723,,,,,,,,,,,,,, -524700,4.5051894,0.59137326,,,,,,,,,,,,,, -524800,4.3961,0.61278564,,,,,,,,,,,,,, -524900,4.524664,0.7184829,,,,,,,,,,,,,, -525000,4.7769446,0.64698887,,,,,,,,,,,,,, -525100,4.7868333,0.68080616,,,,,,,,,,,,,, -525200,4.40189,0.6549005,,,,,,,,,,,,,, -525244,,,0.9605189561843872,0.1485978960990905,0.7560399770736694,1.0526304244995115,50000.0,0.6317000389099121,1.8184934854507449,10000.0,177029.82167744637,183243.22400975227,177029.82167744637,6165.975840806961,26.663962841033936,0.0 -525300,4.6879163,0.6379329,,,,,,,,,,,,,, -525400,4.936479,0.649326,,,,,,,,,,,,,, -525500,4.3648987,0.5884164,,,,,,,,,,,,,, -525600,4.2683573,0.60598403,,,,,,,,,,,,,, -525700,5.098963,0.6480384,,,,,,,,,,,,,, -525800,4.568383,0.6490921,,,,,,,,,,,,,, -525900,5.227877,0.7572344,,,,,,,,,,,,,, -526000,5.1634893,0.66777664,,,,,,,,,,,,,, -526100,4.8710394,0.6367072,,,,,,,,,,,,,, -526200,4.328973,0.6127974,,,,,,,,,,,,,, -526300,4.0911317,0.6071802,,,,,,,,,,,,,, -526400,4.27586,0.6005213,,,,,,,,,,,,,, -526500,4.023513,0.56551737,,,,,,,,,,,,,, -526600,4.35203,0.6503943,,,,,,,,,,,,,, -526700,4.4156804,0.6209935,,,,,,,,,,,,,, -526758,,,0.9609375,0.1478055715560913,0.7556599974632263,1.051428198814392,50000.0,0.6315000057220459,1.816197752952576,10000.0,177539.7880001068,183770.3384861946,177539.7880001068,6182.947371721268,26.78052973747253,0.0 -526800,4.6821694,0.6118928,,,,,,,,,,,,,, -526900,4.7986274,0.6740671,,,,,,,,,,,,,, -527000,4.577278,0.65484416,,,,,,,,,,,,,, -527100,3.9260547,0.5608978,,,,,,,,,,,,,, -527200,4.5443935,0.5647625,,,,,,,,,,,,,, -527300,4.6147494,0.6445337,,,,,,,,,,,,,, -527400,4.6497254,0.61438304,,,,,,,,,,,,,, -527500,4.4679365,0.6389581,,,,,,,,,,,,,, -527600,5.3156157,0.68026704,,,,,,,,,,,,,, -527700,4.234062,0.5993125,,,,,,,,,,,,,, -527800,4.6824183,0.60741615,,,,,,,,,,,,,, -527900,4.767231,0.6293691,,,,,,,,,,,,,, -528000,4.485507,0.60170054,,,,,,,,,,,,,, -528100,4.4759293,0.6283363,,,,,,,,,,,,,, -528200,4.0998693,0.65811914,,,,,,,,,,,,,, -528271,,,0.9596420526504515,0.147839218378067,0.7558000087738037,1.051953673362732,50000.0,0.6306000351905823,1.8174481391906736,10000.0,178049.65263915062,184298.1275918484,178049.65263915062,6200.699266672134,26.892802953720093,0.0 -528300,4.7156563,0.6397011,,,,,,,,,,,,,, -528400,3.9730155,0.5243641,,,,,,,,,,,,,, -528500,4.493857,0.64845526,,,,,,,,,,,,,, -528600,4.604225,0.6498929,,,,,,,,,,,,,, -528700,4.8048563,0.7101117,,,,,,,,,,,,,, -528800,4.268247,0.63070375,,,,,,,,,,,,,, -528900,4.2664123,0.5787327,,,,,,,,,,,,,, -529000,4.8013,0.6600079,,,,,,,,,,,,,, -529100,4.525961,0.64828676,,,,,,,,,,,,,, -529200,4.2542887,0.61210823,,,,,,,,,,,,,, -529300,4.5616293,0.59240717,,,,,,,,,,,,,, -529400,4.5786743,0.63066065,,,,,,,,,,,,,, -529500,4.223559,0.53692085,,,,,,,,,,,,,, -529600,4.482385,0.5942928,,,,,,,,,,,,,, -529700,4.5532737,0.6206685,,,,,,,,,,,,,, -529784,,,0.9607381820678712,0.1477023512125015,0.7559999823570251,1.051684856414795,50000.0,0.6313000321388245,1.8184610605239868,10000.0,178559.2515347004,184825.0801999569,178559.2515347004,6217.64315867424,27.24454164505005,0.0 -529800,4.470291,0.62242424,,,,,,,,,,,,,, -529900,4.404224,0.643132,,,,,,,,,,,,,, -530000,4.54184,0.6466784,,,,,,,,,,,,,, -530100,4.905923,0.6234052,,,,,,,,,,,,,, -530200,4.5393195,0.62709635,,,,,,,,,,,,,, -530300,4.4587393,0.5626627,,,,,,,,,,,,,, -530400,4.052768,0.5105215,,,,,,,,,,,,,, -530500,4.2238145,0.57791483,,,,,,,,,,,,,, -530600,4.2380395,0.5515841,,,,,,,,,,,,,, -530700,4.503595,0.69515383,,,,,,,,,,,,,, -530800,4.5777197,0.585743,,,,,,,,,,,,,, -530900,4.3956795,0.66011167,,,,,,,,,,,,,, -531000,4.696571,0.6607218,,,,,,,,,,,,,, -531100,4.686051,0.73779917,,,,,,,,,,,,,, -531200,4.725381,0.6441458,,,,,,,,,,,,,, -531298,,,0.960180163383484,0.1469533294439315,0.7556999921798706,1.0518393516540527,50000.0,0.6313000321388245,1.8184059858322144,10000.0,179069.09046268463,185352.04632163048,179069.09046268463,6234.597435712814,27.356879711151123,0.0 -531300,4.730028,0.6155609,,,,,,,,,,,,,, -531400,4.479086,0.59217346,,,,,,,,,,,,,, -531500,4.213728,0.57368344,,,,,,,,,,,,,, -531600,4.319438,0.5436172,,,,,,,,,,,,,, -531700,4.5033674,0.6011393,,,,,,,,,,,,,, -531800,4.7619977,0.6487296,,,,,,,,,,,,,, -531900,4.508004,0.60568315,,,,,,,,,,,,,, -532000,4.3732705,0.6602973,,,,,,,,,,,,,, -532100,4.240374,0.6234909,,,,,,,,,,,,,, -532200,4.556755,0.6295224,,,,,,,,,,,,,, -532300,4.1727033,0.5787567,,,,,,,,,,,,,, -532400,5.6867666,0.64255786,,,,,,,,,,,,,, -532500,4.7377706,0.70039064,,,,,,,,,,,,,, -532600,4.7999606,0.5969505,,,,,,,,,,,,,, -532700,4.619041,0.57655233,,,,,,,,,,,,,, -532800,4.317469,0.621435,,,,,,,,,,,,,, -532812,,,0.960558831691742,0.1497821360826492,0.7557599544525146,1.05041766166687,50000.0,0.6306000351905823,1.8165922164916992,10000.0,179579.1422381401,185879.2403757572,179579.1422381401,6251.563860416412,27.472768306732178,0.0 -532900,4.4787593,0.5801593,,,,,,,,,,,,,, -533000,4.239667,0.62232,,,,,,,,,,,,,, -533100,4.4840746,0.58012235,,,,,,,,,,,,,, -533200,4.45508,0.609352,,,,,,,,,,,,,, -533300,4.283258,0.60578394,,,,,,,,,,,,,, -533400,4.66191,0.68742913,,,,,,,,,,,,,, -533500,4.454969,0.5829087,,,,,,,,,,,,,, -533600,3.874045,0.5512212,,,,,,,,,,,,,, -533700,4.251548,0.5851053,,,,,,,,,,,,,, -533800,4.1956015,0.5831458,,,,,,,,,,,,,, -533900,4.616402,0.5859914,,,,,,,,,,,,,, -534000,4.1052766,0.5750401,,,,,,,,,,,,,, -534100,4.2659345,0.65208304,,,,,,,,,,,,,, -534200,4.7691674,0.65497315,,,,,,,,,,,,,, -534300,4.4356446,0.62403244,,,,,,,,,,,,,, -534326,,,0.9616748690605164,0.1470065712928772,0.756119966506958,1.0519754886627195,50000.0,0.6319000124931335,1.81664514541626,10000.0,180089.1278910637,186406.2837367057,180089.1278910637,6268.447098493576,27.58768367767334,0.0 -534400,4.623845,0.633181,,,,,,,,,,,,,, -534500,4.8439536,0.6893459,,,,,,,,,,,,,, -534600,4.244871,0.63167095,,,,,,,,,,,,,, -534700,4.373151,0.6722026,,,,,,,,,,,,,, -534800,4.501256,0.59428364,,,,,,,,,,,,,, -534900,4.4112325,0.62891257,,,,,,,,,,,,,, -535000,4.2434454,0.60767376,,,,,,,,,,,,,, -535100,4.598426,0.691969,,,,,,,,,,,,,, -535200,4.2895484,0.5909346,,,,,,,,,,,,,, -535300,4.2380505,0.5682706,,,,,,,,,,,,,, -535400,3.8796587,0.53563845,,,,,,,,,,,,,, -535500,4.147005,0.58783036,,,,,,,,,,,,,, -535600,4.887353,0.7317844,,,,,,,,,,,,,, -535700,4.814695,0.6075106,,,,,,,,,,,,,, -535800,4.5115356,0.6755237,,,,,,,,,,,,,, -535840,,,0.9614157676696776,0.1429337263107299,0.7553600072860718,1.0528799295425415,50000.0,0.6308000087738037,1.8189215660095213,10000.0,180599.0215339661,186933.2061035633,180599.0215339661,6285.296785116196,27.708674669265747,0.0 -535900,4.714478,0.64107144,,,,,,,,,,,,,, -536000,4.122285,0.573814,,,,,,,,,,,,,, -536100,4.580887,0.65861464,,,,,,,,,,,,,, -536200,4.9730597,0.6667882,,,,,,,,,,,,,, -536300,4.6572967,0.613147,,,,,,,,,,,,,, -536400,5.3216095,0.6228341,,,,,,,,,,,,,, -536500,4.9087768,0.6705985,,,,,,,,,,,,,, -536600,4.119871,0.6063104,,,,,,,,,,,,,, -536700,4.258864,0.6151882,,,,,,,,,,,,,, -536800,4.557236,0.61626345,,,,,,,,,,,,,, -536900,4.45271,0.6062142,,,,,,,,,,,,,, -537000,4.407194,0.6732118,,,,,,,,,,,,,, -537100,5.1939297,0.63795996,,,,,,,,,,,,,, -537200,4.540516,0.5976023,,,,,,,,,,,,,, -537300,4.2849975,0.6359798,,,,,,,,,,,,,, -537353,,,0.9607979655265808,0.1454615592956543,0.7557599544525146,1.0528595447540283,50000.0,0.6315000057220459,1.8187016248703003,10000.0,181108.84668302536,187460.0957119465,181108.84668302536,6302.178875923157,27.830414295196533,0.0 -537400,3.9546876,0.6099177,,,,,,,,,,,,,, -537500,4.730268,0.61320996,,,,,,,,,,,,,, -537600,4.7263045,0.647374,,,,,,,,,,,,,, -537700,4.3090324,0.6435798,,,,,,,,,,,,,, -537800,4.409879,0.5903469,,,,,,,,,,,,,, -537900,4.691973,0.614697,,,,,,,,,,,,,, -538000,4.796443,0.59364974,,,,,,,,,,,,,, -538100,4.586904,0.58463067,,,,,,,,,,,,,, -538200,4.504967,0.5810671,,,,,,,,,,,,,, -538300,4.4591575,0.5686608,,,,,,,,,,,,,, -538400,4.558093,0.5921265,,,,,,,,,,,,,, -538500,4.507353,0.63279843,,,,,,,,,,,,,, -538600,4.753944,0.6553029,,,,,,,,,,,,,, -538700,4.360946,0.67894495,,,,,,,,,,,,,, -538800,4.400954,0.57268965,,,,,,,,,,,,,, -538867,,,0.9594626426696776,0.1487793922424316,0.7560799717903137,1.0516453981399536,50000.0,0.6307000517845154,1.8169374465942385,10000.0,181618.950300932,187987.47008562088,181618.950300932,6319.269691467285,27.94835305213928,0.0 -538900,4.6175404,0.6469688,,,,,,,,,,,,,, -539000,4.219622,0.6187361,,,,,,,,,,,,,, -539100,4.474717,0.5533558,,,,,,,,,,,,,, -539200,4.167549,0.56807494,,,,,,,,,,,,,, -539300,4.1017957,0.5269492,,,,,,,,,,,,,, -539400,4.374066,0.6079221,,,,,,,,,,,,,, -539500,4.391015,0.5446902,,,,,,,,,,,,,, -539600,4.376455,0.6886598,,,,,,,,,,,,,, -539700,4.8186045,0.72604054,,,,,,,,,,,,,, -539800,4.4205837,0.6005908,,,,,,,,,,,,,, -539900,4.377768,0.56531435,,,,,,,,,,,,,, -540000,4.4574747,0.6269591,,,,,,,,,,,,,, -540100,4.4705043,0.64009917,,,,,,,,,,,,,, -540200,4.37176,0.56878126,,,,,,,,,,,,,, -540300,4.5749907,0.6698719,,,,,,,,,,,,,, -540380,,,0.9597616195678712,0.1483667641878128,0.7553600072860718,1.051335334777832,50000.0,0.6309000253677368,1.815916895866394,10000.0,182128.7737035752,188514.49407696724,182128.7737035752,6336.292678594589,28.065372705459595,0.0 -540400,5.015983,0.622156,,,,,,,,,,,,,, -540500,4.6383977,0.60333616,,,,,,,,,,,,,, -540600,4.379765,0.5828706,,,,,,,,,,,,,, -540700,5.025282,0.6306096,,,,,,,,,,,,,, -540800,4.5927615,0.64864886,,,,,,,,,,,,,, -540900,4.779689,0.72253805,,,,,,,,,,,,,, -541000,4.7833977,0.6468285,,,,,,,,,,,,,, -541100,4.689289,0.66816133,,,,,,,,,,,,,, -541200,4.6560664,0.6256478,,,,,,,,,,,,,, -541300,5.0579805,0.7297251,,,,,,,,,,,,,, -541400,4.1830664,0.5913383,,,,,,,,,,,,,, -541500,4.4076085,0.56945,,,,,,,,,,,,,, -541600,4.267221,0.60119563,,,,,,,,,,,,,, -541700,4.2978897,0.65037495,,,,,,,,,,,,,, -541800,4.4065266,0.54919046,,,,,,,,,,,,,, -541894,,,0.960598647594452,0.148234486579895,0.7557799816131592,1.0529454946517944,50000.0,0.6304000020027161,1.817833662033081,10000.0,182638.67547369003,189041.6077091694,182638.67547369003,6353.330547809601,28.18113732337952,0.0 -541900,4.971882,0.6750028,,,,,,,,,,,,,, -542000,4.468947,0.5983907,,,,,,,,,,,,,, -542100,4.305967,0.6551008,,,,,,,,,,,,,, -542200,4.3374662,0.6513114,,,,,,,,,,,,,, -542300,4.479076,0.61703753,,,,,,,,,,,,,, -542400,4.4322824,0.5700704,,,,,,,,,,,,,, -542500,5.3237734,0.67047787,,,,,,,,,,,,,, -542600,4.762484,0.6808258,,,,,,,,,,,,,, -542700,4.4950194,0.68238145,,,,,,,,,,,,,, -542800,4.672702,0.6207764,,,,,,,,,,,,,, -542900,4.7139235,0.7259043,,,,,,,,,,,,,, -543000,4.1453714,0.59030885,,,,,,,,,,,,,, -543100,4.461192,0.66034746,,,,,,,,,,,,,, -543200,5.1347575,0.6725313,,,,,,,,,,,,,, -543300,4.5379443,0.61670846,,,,,,,,,,,,,, -543400,4.662022,0.67174125,,,,,,,,,,,,,, -543407,,,0.9614955186843872,0.1472430229187011,0.7562199831008911,1.0523183345794678,50000.0,0.6305000185966492,1.8182512521743768,10000.0,183148.73386335373,189568.6337320805,183148.73386335373,6370.11282992363,28.30603194236756,0.0 -543500,4.6513305,0.5993028,,,,,,,,,,,,,, -543600,4.1849747,0.6708003,,,,,,,,,,,,,, -543700,4.347841,0.57606065,,,,,,,,,,,,,, -543800,4.7422476,0.6167281,,,,,,,,,,,,,, -543900,4.4690256,0.71357787,,,,,,,,,,,,,, -544000,4.4838915,0.5888677,,,,,,,,,,,,,, -544100,4.508689,0.5878703,,,,,,,,,,,,,, -544200,4.5839543,0.6587876,,,,,,,,,,,,,, -544300,4.73442,0.6645018,,,,,,,,,,,,,, -544400,4.7183695,0.649167,,,,,,,,,,,,,, -544500,4.5561905,0.62693286,,,,,,,,,,,,,, -544600,4.563378,0.52137375,,,,,,,,,,,,,, -544700,4.8012376,0.69351524,,,,,,,,,,,,,, -544800,4.547712,0.6069286,,,,,,,,,,,,,, -544900,4.9529824,0.5946369,,,,,,,,,,,,,, -544921,,,0.9620336294174194,0.1422350853681564,0.7558199763298035,1.051290512084961,50000.0,0.6317000389099121,1.8176825046539309,10000.0,183658.73257374763,190095.69326162327,183658.73257374763,6386.995321273804,28.42442011833191,0.0 -545000,4.625051,0.60202545,,,,,,,,,,,,,, -545100,4.011793,0.5665091,,,,,,,,,,,,,, -545200,4.957487,0.6884339,,,,,,,,,,,,,, -545300,4.7166443,0.6361524,,,,,,,,,,,,,, -545400,4.2371273,0.65758926,,,,,,,,,,,,,, -545500,4.466611,0.5974861,,,,,,,,,,,,,, -545600,4.171911,0.61108005,,,,,,,,,,,,,, -545700,4.9526725,0.6515981,,,,,,,,,,,,,, -545800,4.5870786,0.6600336,,,,,,,,,,,,,, -545900,4.186086,0.57402337,,,,,,,,,,,,,, -546000,4.8143682,0.6360707,,,,,,,,,,,,,, -546100,5.145756,0.65365124,,,,,,,,,,,,,, -546200,4.4624934,0.6626451,,,,,,,,,,,,,, -546300,5.2427278,0.626619,,,,,,,,,,,,,, -546400,4.826271,0.6378416,,,,,,,,,,,,,, -546435,,,0.960957407951355,0.1463946253061294,0.7555399537086487,1.0523061752319336,50000.0,0.6313000321388245,1.817774534225464,10000.0,184168.75291776657,190623.0963089466,184168.75291776657,6404.198199510574,28.544098138809204,0.0 -546500,4.5967493,0.6146383,,,,,,,,,,,,,, -546600,4.2359624,0.5881685,,,,,,,,,,,,,, -546700,4.247297,0.6490616,,,,,,,,,,,,,, -546800,4.7688575,0.618722,,,,,,,,,,,,,, -546900,4.411064,0.5795529,,,,,,,,,,,,,, -547000,4.678926,0.6907485,,,,,,,,,,,,,, -547100,4.9131594,0.65209186,,,,,,,,,,,,,, -547200,4.458797,0.60017735,,,,,,,,,,,,,, -547300,4.631019,0.67420954,,,,,,,,,,,,,, -547400,4.021676,0.60695684,,,,,,,,,,,,,, -547500,4.3565493,0.63009363,,,,,,,,,,,,,, -547600,4.615712,0.609493,,,,,,,,,,,,,, -547700,4.2610016,0.6146711,,,,,,,,,,,,,, -547800,4.7909303,0.6605964,,,,,,,,,,,,,, -547900,4.464382,0.5517022,,,,,,,,,,,,,, -547948,,,0.959741711616516,0.1487929821014404,0.7556399703025818,1.0521609783172607,50000.0,0.6314000487327576,1.817501425743103,10000.0,184678.612347126,191150.11498069763,184678.612347126,6421.181804418564,28.66173243522644,0.0 -548000,4.954144,0.6235328,,,,,,,,,,,,,, -548100,4.3375993,0.56463975,,,,,,,,,,,,,, -548200,5.1077647,0.63908124,,,,,,,,,,,,,, -548300,4.702725,0.639555,,,,,,,,,,,,,, -548400,4.3603954,0.6039019,,,,,,,,,,,,,, -548500,4.474047,0.6628686,,,,,,,,,,,,,, -548600,4.7597384,0.61086637,,,,,,,,,,,,,, -548700,4.3231463,0.57069623,,,,,,,,,,,,,, -548800,4.831575,0.62542886,,,,,,,,,,,,,, -548900,4.3765726,0.6232651,,,,,,,,,,,,,, -549000,5.22312,0.69595635,,,,,,,,,,,,,, -549100,4.8303466,0.7193523,,,,,,,,,,,,,, -549200,4.4300275,0.6858921,,,,,,,,,,,,,, -549300,4.0522895,0.5756396,,,,,,,,,,,,,, -549400,4.389464,0.65932596,,,,,,,,,,,,,, -549461,,,0.960180163383484,0.148353636264801,0.7560799717903137,1.051113843917847,50000.0,0.6309000253677368,1.8153493404388428,10000.0,185188.4840466976,191677.650382042,185188.4840466976,6438.668433189392,28.779770612716675,0.0 -549500,4.3499894,0.5591596,,,,,,,,,,,,,, -549600,4.8898067,0.6497593,,,,,,,,,,,,,, -549700,4.664827,0.5951432,,,,,,,,,,,,,, -549800,4.6293206,0.6797183,,,,,,,,,,,,,, -549900,4.319425,0.60651267,,,,,,,,,,,,,, -550000,4.1387296,0.59659934,,,,,,,,,,,,,, -550100,4.7401643,0.5959443,,,,,,,,,,,,,, -550200,4.384994,0.6360025,,,,,,,,,,,,,, -550300,3.902297,0.5760207,,,,,,,,,,,,,, -550400,4.3446116,0.6020589,,,,,,,,,,,,,, -550500,4.5936065,0.58312345,,,,,,,,,,,,,, -550600,4.299427,0.60843474,,,,,,,,,,,,,, -550700,4.408938,0.64280033,,,,,,,,,,,,,, -550800,4.41224,0.59034663,,,,,,,,,,,,,, -550900,4.6140995,0.68485,,,,,,,,,,,,,, -550975,,,0.961336076259613,0.143163800239563,0.7555399537086487,1.0519856214523315,50000.0,0.6317000389099121,1.8175721168518064,10000.0,185698.51256275177,192204.65172839165,185698.51256275177,6455.479493379593,28.882810354232788,0.0 -551000,4.6916704,0.63481796,,,,,,,,,,,,,, -551100,4.175676,0.53142184,,,,,,,,,,,,,, -551200,4.4020195,0.56159806,,,,,,,,,,,,,, -551300,4.795047,0.61012024,,,,,,,,,,,,,, -551400,4.5698037,0.71316296,,,,,,,,,,,,,, -551500,4.4561906,0.6941699,,,,,,,,,,,,,, -551600,4.293152,0.537101,,,,,,,,,,,,,, -551700,4.432286,0.6755494,,,,,,,,,,,,,, -551800,4.555625,0.667114,,,,,,,,,,,,,, -551900,4.606584,0.61834675,,,,,,,,,,,,,, -552000,4.651066,0.60041827,,,,,,,,,,,,,, -552100,4.4950333,0.6012085,,,,,,,,,,,,,, -552200,4.7224236,0.6605201,,,,,,,,,,,,,, -552300,4.3763466,0.6227608,,,,,,,,,,,,,, -552400,4.1777263,0.56101835,,,,,,,,,,,,,, -552489,,,0.9610570669174194,0.1454282253980636,0.755840003490448,1.051978588104248,50000.0,0.6310000419616699,1.8172557353973389,10000.0,186208.4559469223,192731.74961352348,186208.4559469223,6472.452427625656,29.003294467926025,0.0 -552500,5.1048083,0.68631923,,,,,,,,,,,,,, -552600,4.7968254,0.6330514,,,,,,,,,,,,,, -552700,5.059511,0.58578503,,,,,,,,,,,,,, -552800,5.1426706,0.66816074,,,,,,,,,,,,,, -552900,4.42165,0.59457356,,,,,,,,,,,,,, -553000,4.9858794,0.6601765,,,,,,,,,,,,,, -553100,4.5905166,0.65521705,,,,,,,,,,,,,, -553200,4.633366,0.6329625,,,,,,,,,,,,,, -553300,4.3717694,0.5761159,,,,,,,,,,,,,, -553400,4.192553,0.5916271,,,,,,,,,,,,,, -553500,4.121309,0.6032021,,,,,,,,,,,,,, -553600,4.3291726,0.5914094,,,,,,,,,,,,,, -553700,4.919151,0.69823766,,,,,,,,,,,,,, -553800,4.4128017,0.6993767,,,,,,,,,,,,,, -553900,4.145484,0.5927861,,,,,,,,,,,,,, -554000,4.5867724,0.6386677,,,,,,,,,,,,,, -554003,,,0.9595423936843872,0.149004265666008,0.7553600072860718,1.052505373954773,50000.0,0.6323000192642212,1.8183035850524905,10000.0,186718.4262833596,193258.82066512108,186718.4262833596,6489.370160579681,29.12525010108948,0.0 -554100,4.27929,0.5535037,,,,,,,,,,,,,, -554200,4.5187035,0.6477629,,,,,,,,,,,,,, -554300,4.93928,0.6302994,,,,,,,,,,,,,, -554400,4.617255,0.64853084,,,,,,,,,,,,,, -554500,4.3474655,0.576797,,,,,,,,,,,,,, -554600,4.633889,0.6243163,,,,,,,,,,,,,, -554700,4.351146,0.63494307,,,,,,,,,,,,,, -554800,4.5685124,0.64898324,,,,,,,,,,,,,, -554900,4.7937164,0.655169,,,,,,,,,,,,,, -555000,4.2791133,0.57354665,,,,,,,,,,,,,, -555100,4.516512,0.5792131,,,,,,,,,,,,,, -555200,4.229927,0.5916331,,,,,,,,,,,,,, -555300,4.7579174,0.6583618,,,,,,,,,,,,,, -555400,4.1070952,0.5547508,,,,,,,,,,,,,, -555500,4.2635756,0.6040128,,,,,,,,,,,,,, -555517,,,0.9606584906578064,0.1476895958185196,0.7559399604797363,1.0525918006896973,50000.0,0.6315000057220459,1.8194420337677,10000.0,187228.433716774,193785.7789597512,187228.433716774,6506.138321399689,29.2472882270813,0.0 -555600,4.7651644,0.7195708,,,,,,,,,,,,,, -555700,4.5631185,0.64080846,,,,,,,,,,,,,, -555800,4.315597,0.6306851,,,,,,,,,,,,,, -555900,3.949629,0.5603839,,,,,,,,,,,,,, -556000,5.343704,0.6419586,,,,,,,,,,,,,, -556100,4.31619,0.6154568,,,,,,,,,,,,,, -556200,4.685599,0.701953,,,,,,,,,,,,,, -556300,4.390453,0.5876148,,,,,,,,,,,,,, -556400,4.473075,0.6775775,,,,,,,,,,,,,, -556500,4.9071817,0.7377391,,,,,,,,,,,,,, -556600,4.4097667,0.57659245,,,,,,,,,,,,,, -556700,5.0502234,0.5823113,,,,,,,,,,,,,, -556800,4.647565,0.6567766,,,,,,,,,,,,,, -556900,4.012143,0.5706056,,,,,,,,,,,,,, -557000,4.5443864,0.5881164,,,,,,,,,,,,,, -557032,,,0.9607381820678712,0.1459540575742721,0.755840003490448,1.05264151096344,50000.0,0.631100058555603,1.818196177482605,10000.0,187738.5424454212,194313.19152712825,187738.5424454212,6523.260720252991,29.36760807037353,0.0 -557100,4.5813675,0.60208255,,,,,,,,,,,,,, -557200,4.642659,0.66591775,,,,,,,,,,,,,, -557300,4.5817704,0.61110973,,,,,,,,,,,,,, -557400,4.7470093,0.62486386,,,,,,,,,,,,,, -557500,4.1901135,0.624751,,,,,,,,,,,,,, -557600,4.0898905,0.6074359,,,,,,,,,,,,,, -557700,4.1450276,0.61224365,,,,,,,,,,,,,, -557800,4.4538465,0.65815246,,,,,,,,,,,,,, -557900,4.6844387,0.69233996,,,,,,,,,,,,,, -558000,4.413052,0.6616186,,,,,,,,,,,,,, -558100,4.3804054,0.5475927,,,,,,,,,,,,,, -558200,4.440515,0.59733886,,,,,,,,,,,,,, -558300,4.0926275,0.5597036,,,,,,,,,,,,,, -558400,4.9222426,0.6028336,,,,,,,,,,,,,, -558500,4.619376,0.5965262,,,,,,,,,,,,,, -558546,,,0.9609972834587096,0.1453019976615905,0.7559199929237366,1.0516624450683594,50000.0,0.6312000155448914,1.817338943481445,10000.0,188248.5768508911,194840.18134617803,188248.5768508911,6540.03351855278,29.49054718017578,0.0 -558600,4.429727,0.63261014,,,,,,,,,,,,,, -558700,4.3603272,0.5837445,,,,,,,,,,,,,, -558800,4.287967,0.6021924,,,,,,,,,,,,,, -558900,4.4276686,0.578278,,,,,,,,,,,,,, -559000,4.4242783,0.6296928,,,,,,,,,,,,,, -559100,4.4627743,0.6345006,,,,,,,,,,,,,, -559200,4.685813,0.611565,,,,,,,,,,,,,, -559300,4.699687,0.6602609,,,,,,,,,,,,,, -559400,4.163329,0.61108583,,,,,,,,,,,,,, -559500,5.377302,0.6877669,,,,,,,,,,,,,, -559600,4.3340297,0.6794306,,,,,,,,,,,,,, -559700,4.5593987,0.5832784,,,,,,,,,,,,,, -559800,4.513135,0.5859892,,,,,,,,,,,,,, -559900,4.4907227,0.6625763,,,,,,,,,,,,,, -559998,,,0.9608577489852904,0.14729043841362,0.7560799717903137,1.0520939826965332,50000.0,0.6314000487327576,1.8173184394836424,10000.0,188737.6040322781,195346.2275600433,188737.6040322781,6556.845715045929,29.64098310470581,0.0 -559998,,,,,,,,,,,188737.60403227806,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index a26d5ca0d..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,555 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -40.32848334312439,0.0,42.50220847129822,1,0,42.50220847129822,0.0010000000474974,6.907756805419922,10000,82.83079028129578,0.0008789062267169,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -62.036508321762085,0.027024507522583,462.6079497337341,905,0,462.6079497337341,0.02730000205338,5.976207733154297,10000,524.7230081558228,0.0355468727648258,5.822093486785889,0.0365999974310398,5.84222936630249,50000 -83.49371838569641,0.0549170970916748,882.7354953289032,1862,0,882.7354953289032,0.0597000010311603,5.4772138595581055,10000,966.3888356685638,0.0798046886920929,5.190312385559082,0.0763199999928474,5.234590530395508,50000 -105.18109583854675,0.0841908454895019,1302.832392454147,2820,0,1302.832392454147,0.106900006532669,4.9579620361328125,10000,1408.2555124759674,0.1470703035593032,4.55432653427124,0.1360999941825866,4.649452209472656,50000 -126.77802324295044,0.113541841506958,1722.8942482471466,3779,0,1722.8942482471466,0.1464000046253204,4.582161903381348,10000,1849.997201681137,0.2031054645776748,4.117558002471924,0.1888599991798401,4.196497440338135,50000 -148.62538933753967,0.1412358283996582,2143.081571817398,4723,0,2143.081571817398,0.185800015926361,4.2058210372924805,10000,2292.112045764923,0.2610937356948852,3.6418685913085938,0.2449599951505661,3.743607997894287,50000 -170.35447335243225,0.1706206798553466,2563.1050510406494,5677,0,2563.1050510406494,0.2299000173807144,3.89674711227417,10000,2733.9481399059296,0.3202343583106994,3.211933135986328,0.2983799874782562,3.3576674461364746,50000 -192.0537114143372,0.2007627487182617,2983.2663242816925,6629,0,2983.2663242816925,0.2574000060558319,3.7208309173583984,10000,3175.8974702358246,0.365527331829071,2.978093147277832,0.3293399810791015,3.1722586154937744,50000 -218.76269936561584,0.2289977073669433,3403.3904716968536,7578,0,3403.3904716968536,0.2772000133991241,3.57218599319458,10000,3622.811475038528,0.392382800579071,2.8480987548828125,0.3620399832725525,2.98736834526062,50000 -245.2535297870636,0.2788910865783691,3823.300567626953,8528,0,3823.300567626953,0.2962000072002411,3.424993991851806,10000,4069.3166477680206,0.4204296767711639,2.6367781162261963,0.3862800002098083,2.817789554595948,50000 -270.96862721443176,0.3183200359344482,4243.458779096603,9478,0,4243.458779096603,0.3241000175476074,3.2569079399108887,10000,4515.28219127655,0.458300769329071,2.420060396194458,0.4150999784469604,2.6546177864074707,50000 -297.3756804466248,0.3520138263702392,4663.765017032623,10424,0,4663.765017032623,0.324500024318695,3.2115306854248047,10000,4962.0816378593445,0.4592382609844208,2.421598434448242,0.4288399815559387,2.58870530128479,50000 -325.7472426891327,0.3814010620117187,5083.707577466965,11373,0,5083.707577466965,0.3439000248908996,3.133396625518799,10000,5410.477880477905,0.4764843583106994,2.326941728591919,0.4420799911022186,2.498100519180298,50000 -353.4686424732208,0.4112565517425537,5503.754225730896,12322,0,5503.754225730896,0.3625000119209289,3.0175158977508545,10000,5858.329460859299,0.5114062428474426,2.1223716735839844,0.4688199758529663,2.3403637409210205,50000 -382.17194390296936,0.4417285919189453,5923.964722394943,13263,0,5923.964722394943,0.3637000024318695,2.990636348724365,10000,6307.325840473175,0.5409570336341858,2.0211377143859863,0.4729399979114532,2.332581043243408,50000 -413.3990566730499,0.4788401126861572,6344.242172241211,14209,0,6344.242172241211,0.3805000185966491,2.9285807609558105,10000,6758.920413732529,0.5172070264816284,2.0955166816711426,0.4853599965572357,2.272596597671509,50000 -446.1518096923828,0.5089349746704102,6764.433388948441,15153,0,6764.433388948441,0.3815000057220459,2.920339822769165,10000,7211.947317838669,0.53125,2.060498237609864,0.4918199777603149,2.2550272941589355,50000 -475.9171495437622,0.5414671897888184,7184.445158958435,16096,0,7184.445158958435,0.3912000060081482,2.838825225830078,10000,7661.808724164963,0.5531249642372131,1.9234681129455569,0.5004400014877319,2.1823103427886963,50000 -505.95252656936646,0.5746691226959229,7604.548624038696,17035,0,7604.548624038696,0.3888000249862671,2.8619396686553955,10000,8112.031835079193,0.5389062166213989,2.0137147903442383,0.5019400119781494,2.185546398162842,50000 -536.4071035385132,0.6097869873046875,8024.889421463013,17979,0,8024.889421463013,0.403300017118454,2.80226469039917,10000,8562.914777040482,0.5502538681030273,1.966177463531494,0.5148800015449524,2.137053966522217,50000 -567.2139377593994,0.6404788494110107,8445.13327217102,18922,0,8445.13327217102,0.4070000052452087,2.7275853157043457,10000,9014.0472304821,0.5723046660423279,1.7916992902755735,0.5268599987030029,2.030201435089112,50000 -597.48259973526,0.6718833446502686,8865.34806394577,19865,0,8865.34806394577,0.4144000113010406,2.693567276000977,10000,9464.613876342772,0.5950781106948853,1.7214692831039429,0.5264599919319153,2.04833984375,50000 -629.0911407470703,0.7059800624847412,9285.290144205092,20808,0,9285.290144205092,0.4191000163555145,2.663177728652954,10000,9916.250267505646,0.5799609422683716,1.7857638597488403,0.5396199822425842,1.978995442390442,50000 -659.7037007808685,0.7450652122497559,9705.551045656204,21750,0,9705.551045656204,0.4229000210762024,2.649247407913208,10000,10367.21435379982,0.5848437547683716,1.772911548614502,0.5408999919891357,1.9920196533203125,50000 -689.5395059585571,0.7798676490783691,10125.80424952507,22689,0,10125.80424952507,0.4222000241279602,2.663605213165283,10000,10817.389746904371,0.5935742259025574,1.716032862663269,0.5404399633407593,1.994184613227844,50000 -724.6498851776123,0.8144900798797607,10545.795673847198,23629,0,10545.795673847198,0.4326000213623047,2.574949026107788,10000,11272.577404022217,0.5935351252555847,1.7091354131698608,0.5555399656295776,1.896039366722107,50000 -756.9143693447113,0.8506224155426025,10966.144136667252,24578,0,10966.144136667252,0.4345000088214874,2.592747926712036,10000,11725.279448986052,0.5988476276397705,1.6850430965423584,0.5579400062561035,1.8976974487304688,50000 -788.4734981060028,0.8868091106414795,11386.355078458786,25524,0,11386.355078458786,0.4425000250339508,2.5231847763061523,10000,12177.137206077576,0.6112499833106995,1.6054291725158691,0.5604999661445618,1.8555995225906368,50000 -820.2013580799103,2.0886545181274414,11805.259394168854,26463,0,11805.259394168854,0.443200021982193,2.5797505378723145,10000,12629.023286819458,0.6207422018051147,1.6242880821228027,0.5597000122070312,1.911012291908264,50000 -851.1198291778564,2.133695363998413,12225.197360038756,27404,0,12225.197360038756,0.4491000175476074,2.521052837371826,10000,13079.97579050064,0.6065429449081421,1.6465208530426023,0.5616399645805359,1.8516168594360352,50000 -882.6207411289215,2.1664159297943115,12645.36327767372,28342,0,12645.36327767372,0.4531000256538391,2.515817880630493,10000,13531.726953268051,0.6184960603713989,1.6287354230880735,0.570580005645752,1.852982759475708,50000 -914.2631750106812,2.201692819595337,13065.477472305298,29282,0,13065.477472305298,0.4621000289916992,2.471017837524414,10000,13983.570016145706,0.6328710913658142,1.5523300170898438,0.5777599811553955,1.8129938840866089,50000 -945.3267643451692,2.239365339279175,13485.7399559021,30223,0,13485.7399559021,0.4615000188350677,2.451876163482666,10000,14434.98591184616,0.6181054711341858,1.5981839895248413,0.5783399939537048,1.7888963222503662,50000 -976.8150374889374,2.274978876113892,13906.012647390366,31166,0,13906.012647390366,0.4650000333786011,2.4483392238616943,10000,14886.83359527588,0.6274218559265137,1.5659538507461548,0.5806399583816528,1.7789556980133057,50000 -1007.635326385498,2.308807134628296,14326.030616283417,32107,0,14326.030616283417,0.4610000252723694,2.439715623855591,10000,15337.757307052612,0.635058581829071,1.529654622077942,0.5848599672317505,1.7706170082092283,50000 -1039.280338525772,2.342710018157959,14746.13856124878,33047,0,14746.13856124878,0.4671000242233276,2.421347141265869,10000,15789.595637083054,0.64013671875,1.4978110790252686,0.5824599862098694,1.7634974718093872,50000 -1069.6364603042605,2.3791699409484863,15166.186804294586,33990,0,15166.186804294586,0.4688000082969665,2.423509120941162,10000,16240.08767604828,0.6319726705551147,1.5471922159194946,0.5868399739265442,1.755309820175171,50000 -1099.3962256908417,2.414524793624878,15586.226889133452,34929,0,15586.226889133452,0.4708000123500824,2.3756675720214844,10000,16689.974547863007,0.641406238079071,1.4608179330825806,0.5906999707221985,1.7022159099578855,50000 -1132.126388311386,2.4569034576416016,16006.404333353044,35868,0,16006.404333353044,0.4737000167369842,2.368324279785156,10000,17142.975281000137,0.6634570360183716,1.371727705001831,0.5966399908065796,1.674397110939026,50000 -1163.6418359279633,2.496078968048096,16426.55730485916,36811,0,16426.55730485916,0.4733000099658966,2.3658699989318848,10000,17594.734039783478,0.6385351419448853,1.4934431314468384,0.5940399765968323,1.706323266029358,50000 -1193.5054569244385,2.531633138656616,16846.59597182274,37748,0,16846.59597182274,0.4742000102996826,2.3896632194519043,10000,18044.722935915,0.641796886920929,1.4904508590698242,0.5951799750328064,1.7146835327148438,50000 -1223.5478613376615,2.571526288986206,17266.64454650879,38688,0,17266.64454650879,0.4707000255584717,2.3531055450439453,10000,18494.905309677124,0.6561914086341858,1.4135422706604004,0.5982199907302856,1.678788661956787,50000 -1253.3240871429443,2.613891363143921,17686.574685573578,39627,0,17686.574685573578,0.4796000123023987,2.343630075454712,10000,18944.70481967926,0.6489648222923279,1.4219398498535156,0.6005799770355225,1.6512315273284912,50000 -1284.0473837852478,2.649874925613404,18106.86800003052,40566,0,18106.86800003052,0.4792000353336334,2.341360569000244,10000,19395.80882000923,0.6485155820846558,1.4499775171279907,0.6007800102233887,1.6757594347000122,50000 -1313.5189683437347,2.6855294704437256,18527.173591852188,41503,0,18527.173591852188,0.4783000349998474,2.380979061126709,10000,19845.67344045639,0.646484375,1.5099809169769287,0.6010199785232544,1.7225019931793213,50000 -1344.332921743393,2.720686674118042,18947.48678970337,42444,0,18947.48678970337,0.4881000220775604,2.3156774044036865,10000,20296.886627674103,0.6727929711341858,1.3531529903411863,0.6094599962234497,1.6596788167953491,50000 -1375.3200373649595,2.756638288497925,19367.46821308136,43385,0,19367.46821308136,0.4800000190734863,2.359612941741944,10000,20747.942059278488,0.6493554711341858,1.4348890781402588,0.6075999736785889,1.6461001634597778,50000 -1405.9976406097412,2.7973742485046387,19787.398421525955,44327,0,19787.398421525955,0.4851000308990478,2.3165159225463867,10000,21198.64275765419,0.6592382788658142,1.38425874710083,0.609499990940094,1.627719759941101,50000 -1436.3729367256165,2.8370449542999268,20207.447094917297,45265,0,20207.447094917297,0.492000013589859,2.267448663711548,10000,21649.15763068199,0.6744531393051147,1.3124046325683594,0.6171199679374695,1.5861340761184692,50000 -1465.5651831626892,2.880309820175171,20627.74993658065,46203,0,20627.74993658065,0.4882000088691711,2.291762351989746,10000,22098.74702978134,0.66162109375,1.3712722063064575,0.6122599840164185,1.6086606979370115,50000 -1497.9042127132416,2.924238920211792,21047.87811112404,47143,0,21047.87811112404,0.4901000261306762,2.2621548175811768,10000,22551.3086810112,0.6664453148841858,1.3697527647018433,0.616599977016449,1.600008845329285,50000 -1527.9312839508057,2.9619107246398926,21467.8795838356,48088,0,21467.8795838356,0.495600014925003,2.226857662200928,10000,23001.425425052643,0.66943359375,1.329981565475464,0.6164000034332275,1.5844672918319702,50000 -1558.8661715984344,2.999751567840576,21888.09333539009,49027,0,21888.09333539009,0.4955000281333923,2.276652336120605,10000,23452.66234397888,0.6886523365974426,1.2827799320220947,0.617680013179779,1.6056610345840454,50000 -1589.2278089523315,3.0361344814300537,22308.256228208546,49970,0,22308.256228208546,0.4980000257492065,2.2585461139678955,10000,23903.274853229523,0.6638867259025574,1.3901808261871338,0.6227799654006958,1.5859307050704956,50000 -1619.9439854621887,3.0944161415100098,22728.1577205658,50910,0,22728.1577205658,0.495600014925003,2.232228994369507,10000,24354.00265216828,0.6714843511581421,1.3220560550689695,0.6208400130271912,1.5570722818374634,50000 -1653.205320596695,3.13411545753479,23148.37905406952,51849,0,23148.37905406952,0.4967000186443329,2.263665199279785,10000,24807.57603955269,0.6763281226158142,1.3460313081741333,0.6196799874305725,1.6015217304229736,50000 -1686.3185930252075,3.17166519165039,23568.588448762894,52792,0,23568.588448762894,0.4921000301837921,2.28542423248291,10000,25260.98716020584,0.6678906083106995,1.386376976966858,0.6201199889183044,1.6033985614776611,50000 -1718.752513408661,3.2167985439300537,23988.578372478485,53734,0,23988.578372478485,0.4958000183105469,2.2348878383636475,10000,25713.50721931457,0.674609363079071,1.3201045989990234,0.625,1.555474042892456,50000 -1751.7384040355682,3.2596917152404785,24408.97371816635,54673,0,24408.97371816635,0.508400022983551,2.209318161010742,10000,26166.98223090172,0.6805273294448853,1.3037056922912598,0.6293399930000305,1.5446206331253052,50000 -1787.9053149223328,3.299623966217041,24829.13514328003,55613,0,24829.13514328003,0.5047000050544739,2.2150003910064697,10000,26623.4019677639,0.7066406011581421,1.207590937614441,0.630079984664917,1.5488102436065674,50000 -1821.799224853516,3.3368849754333496,25249.2022702694,56558,0,25249.2022702694,0.5024999976158142,2.2308924198150635,10000,27077.45137095452,0.6771875023841858,1.3591915369033811,0.629040002822876,1.5706725120544434,50000 -1854.3612713813784,3.379136562347412,25669.92558169365,57501,0,25669.92558169365,0.5076000094413757,2.20805025100708,10000,27530.83186006546,0.68505859375,1.295912742614746,0.6319800019264221,1.5335566997528076,50000 -1885.599454164505,3.4180257320404053,26090.231457710262,58440,0,26090.231457710262,0.5067000389099121,2.2019996643066406,10000,27982.466410398483,0.6932421922683716,1.2518590688705444,0.6332600116729736,1.53603196144104,50000 -1922.9465517997744,3.4586949348449707,26510.5277197361,59378,0,26510.5277197361,0.5088000297546387,2.177444696426392,10000,28440.20163846016,0.6830077767372131,1.290729284286499,0.6337800025939941,1.5220874547958374,50000 -1955.368063211441,3.498325824737549,26930.86291337013,60323,0,26930.86291337013,0.5124000310897827,2.165943622589112,10000,28893.049347400665,0.6819140315055847,1.2942769527435305,0.6343799829483032,1.5181857347488403,50000 -1986.939861536026,3.541537284851074,27351.005984783173,61265,0,27351.005984783173,0.5146000385284424,2.156124830245972,10000,29344.85937547684,0.6977343559265137,1.215801477432251,0.640500009059906,1.4811792373657229,50000 -2018.25333237648,3.5821726322174072,27771.32909488678,62208,0,27771.32909488678,0.5141000151634216,2.1523773670196533,10000,29796.588678121567,0.7203710675239563,1.1145329475402832,0.6396999955177307,1.4701439142227173,50000 -2051.9938457012177,3.6222991943359375,28191.540812969208,63150,0,28191.540812969208,0.5139000415802002,2.1568527221679688,10000,30250.63220858574,0.68994140625,1.255996823310852,0.6400200128555298,1.4859697818756104,50000 -2085.192571163177,3.666409969329834,28611.683968544006,64094,0,28611.683968544006,0.5193000435829163,2.129345655441284,10000,30704.06983423233,0.6942577958106995,1.2345657348632812,0.6388799548149109,1.49025559425354,50000 -2117.3324999809265,3.71160888671875,29031.91866993904,65041,0,29031.91866993904,0.5171000361442566,2.138280868530273,10000,31156.541711330414,0.7075585722923279,1.1857621669769287,0.6427599787712097,1.4769474267959597,50000 -2148.558340787888,3.76130747795105,29451.99482369423,65984,0,29451.99482369423,0.5164999961853027,2.131663084030152,10000,31607.9455947876,0.69580078125,1.236257791519165,0.645799994468689,1.4627673625946045,50000 -2180.7719078063965,3.802668809890747,29872.294987916943,66925,0,29872.294987916943,0.5211000442504883,2.114197492599488,10000,32060.55270266533,0.6991015672683716,1.2016446590423584,0.6476799845695496,1.4438791275024414,50000 -2213.23275232315,3.84781265258789,30292.256318330765,67866,0,30292.256318330765,0.5210000276565552,2.1125617027282715,10000,32513.07097506523,0.7053515315055847,1.1844969987869265,0.6496999859809875,1.450947642326355,50000 -2245.0266540050507,3.895355939865112,30712.568786621094,68811,0,30712.568786621094,0.524399995803833,2.091939687728882,10000,32965.27647805214,0.7294921875,1.0799065828323364,0.6500200033187866,1.4359275102615356,50000 -2276.8387970924377,3.9474518299102783,31132.49741792679,69748,0,31132.49741792679,0.5236000418663025,2.12225604057312,10000,33417.12102437019,0.6984961032867432,1.211103916168213,0.6511600017547607,1.4439362287521362,50000 -2307.9747524261475,3.996495962142944,31552.53966331482,70689,0,31552.53966331482,0.5211000442504883,2.1125056743621826,10000,33868.39933013916,0.7059765458106995,1.168419361114502,0.6503599882125854,1.433741331100464,50000 -2340.085106611252,4.0459089279174805,31972.546494960785,71632,0,31972.546494960785,0.5218000411987305,2.1029539108276367,10000,34320.61787605286,0.7135937213897705,1.1473480463027954,0.6534000039100647,1.4237979650497437,50000 -2372.071283340454,4.090409517288208,32392.69544196129,72572,0,32392.69544196129,0.5273000001907349,2.112567186355591,10000,34772.84811258316,0.6991015672683716,1.2266145944595337,0.6523399949073792,1.445721983909607,50000 -2403.665696620941,4.1330084800720215,32812.83591794968,73510,0,32812.83591794968,0.5275000333786011,2.077162504196167,10000,35224.67768549919,0.7064843773841858,1.1597418785095217,0.651919960975647,1.4081584215164185,50000 -2441.54381275177,4.176920413970947,33232.861943244934,74451,0,33232.861943244934,0.5303000211715698,2.106510162353516,10000,35682.67705178261,0.7114648222923279,1.1783130168914795,0.652899980545044,1.4420287609100342,50000 -2473.425219774246,4.219008684158325,33653.16945314407,75395,0,33653.16945314407,0.5238000154495239,2.119074583053589,10000,36134.95932555199,0.7361718416213989,1.0584875345230105,0.6551799774169922,1.4225327968597412,50000 -2506.250350475312,4.261711597442627,34073.50621318817,76333,0,34073.50621318817,0.5350000262260437,2.044471740722656,10000,36588.21505665779,0.7122265696525574,1.1508327722549438,0.658840000629425,1.3871424198150637,50000 -2543.693027973175,4.307955265045166,34493.52569794655,77270,0,34493.52569794655,0.5396000146865845,2.0465617179870605,10000,37045.77448058128,0.7195898294448853,1.1208205223083496,0.6632800102233887,1.391443133354187,50000 -2576.5816905498505,4.3477983474731445,34913.678639411926,78215,0,34913.678639411926,0.5385000109672546,2.0573971271514893,10000,37498.90714406967,0.7258203029632568,1.1131298542022705,0.6583600044250488,1.4111359119415283,50000 -2611.512553215027,4.392807483673096,35333.94675087929,79157,0,35333.94675087929,0.5360000133514404,2.062854766845703,10000,37954.20350050926,0.7126757502555847,1.168033480644226,0.6606799960136414,1.408635139465332,50000 -2648.128517389297,4.436190128326416,35753.97294163704,80099,0,35753.97294163704,0.5321000218391418,2.0679373741149902,10000,38410.941356658936,0.71546870470047,1.14302659034729,0.6597599983215332,1.3939063549041748,50000 -2680.9402170181274,4.473475933074951,36174.29339551926,81043,0,36174.29339551926,0.5390000343322754,2.057657480239868,10000,38864.1615319252,0.7226366996765137,1.1191169023513794,0.6619600057601929,1.3967500925064087,50000 -2714.103229045868,4.518529653549194,36594.43833613396,81982,0,36594.43833613396,0.5347000360488892,2.0538101196289062,10000,39317.56582832336,0.7477148175239563,1.0231945514678955,0.6630799770355225,1.392142415046692,50000 -2749.8208360672,4.561380863189697,37014.43953371048,82924,0,37014.43953371048,0.5391000509262085,2.0352330207824707,10000,39773.3790769577,0.7239843606948853,1.1104586124420166,0.6676200032234192,1.363040804862976,50000 -2781.979739665985,4.608523607254028,37434.735203027725,83865,0,37434.735203027725,0.5454000234603882,2.023963689804077,10000,40225.933086156845,0.7298241853713989,1.0960992574691772,0.6693199872970581,1.3602654933929443,50000 -2813.217336177826,4.663280725479126,37854.68165183067,84806,0,37854.68165183067,0.5470000505447388,2.004969358444214,10000,40677.22996091843,0.73451167345047,1.0335588455200195,0.6688599586486816,1.3441568613052368,50000 -2846.062223434448,4.709845066070557,38275.05656552315,85746,0,38275.05656552315,0.541700005531311,2.016709089279175,10000,41130.54695153237,0.7251952886581421,1.0878881216049194,0.667419970035553,1.332442045211792,50000 -2883.2515251636505,4.756951808929443,38695.34867835045,86688,0,38695.34867835045,0.5494000315666199,1.985035419464112,10000,41588.12759113312,0.7310351133346558,1.069846749305725,0.6744799613952637,1.320836067199707,50000 -2918.502597808838,4.796014308929443,39115.71409368515,87633,0,39115.71409368515,0.5503000020980835,1.9918979406356807,10000,42043.83431196213,0.7341992259025574,1.0544023513793943,0.6748200058937073,1.3344672918319702,50000 -2949.6885225772858,4.841732740402222,39535.89856672287,88575,0,39535.89856672287,0.5534000396728516,1.966261625289917,10000,42495.30137729645,0.7554491758346558,0.969889760017395,0.6746399998664856,1.3290221691131592,50000 -2981.5580892562866,4.89051628112793,39956.18599700928,89509,0,39956.18599700928,0.5515000224113464,1.975314736366272,10000,42947.557476997375,0.7327538728713989,1.073253512382507,0.676580011844635,1.3220385313034058,50000 -3015.738451242447,4.947018384933472,40376.54059243202,90446,0,40376.54059243202,0.5520000457763672,1.968473553657532,10000,43402.19970941544,0.7387109398841858,1.030994176864624,0.6740999817848206,1.3186286687850952,50000 -3050.9701120853424,4.993818759918213,40796.58492851257,91389,0,40796.58492851257,0.5552000403404236,1.9847722053527832,10000,43857.57451105118,0.7518945336341858,0.9972430467605592,0.6795799732208252,1.3187674283981323,50000 -3081.7949402332306,5.041177034378052,41216.692452430725,92329,0,41216.692452430725,0.5539000034332275,1.954208254814148,10000,44308.60667061806,0.7335546612739563,1.0436562299728394,0.6787799596786499,1.2994449138641355,50000 -3114.198682308197,5.118264436721802,41636.87300825119,93273,0,41636.87300825119,0.5590000152587891,1.941738963127136,10000,44761.32083892822,0.7421875,1.0251456499099731,0.6827200055122375,1.2972899675369265,50000 -3153.09051322937,5.165774345397949,42057.10592675209,94214,0,42057.10592675209,0.5564000010490417,1.9675493240356443,10000,45220.5447204113,0.7447265386581421,1.0316696166992188,0.6801599860191345,1.314698576927185,50000 -3185.913388967514,5.208800315856934,42477.334065675735,95160,0,42477.334065675735,0.5582000017166138,1.948542833328247,10000,45673.6993329525,0.7561132907867432,0.99065762758255,0.6864399909973145,1.299466252326965,50000 -3216.6718633174896,5.261148452758789,42897.442705631256,96101,0,42897.442705631256,0.5559000372886658,1.9323203563690183,10000,46124.67051315308,0.7415820360183716,1.0177706480026243,0.6850599646568298,1.278401017189026,50000 -3249.469914674759,5.307514429092407,43317.70691943169,97040,0,43317.70691943169,0.5623000264167786,1.909122347831726,10000,46577.83276820183,0.7523437142372131,0.9721384048461914,0.6881399750709534,1.2559822797775269,50000 -3283.58735871315,5.355186462402344,43738.02734136581,97981,0,43738.02734136581,0.5659000277519226,1.902812004089356,10000,47032.36952018738,0.7609765529632568,0.9368932247161864,0.6894800066947937,1.2643203735351562,50000 -3316.1046471595764,5.404359579086304,44158.18087887764,98923,0,44158.18087887764,0.5624000430107117,1.9388071298599243,10000,47485.14091467857,0.7438867092132568,0.993510365486145,0.6860799789428711,1.261129379272461,50000 -3348.923852443695,5.479732513427734,44578.10844540596,99867,0,44578.10844540596,0.5617000460624695,1.9162745475769043,10000,47938.01518249512,0.7521874904632568,0.9745285511016846,0.6863200068473816,1.2652863264083862,50000 -3381.551055908203,5.530679225921631,44998.0839908123,100809,0,44998.0839908123,0.5648000240325928,1.908398151397705,10000,48390.72035264969,0.7562499642372131,0.9649302363395692,0.6868999600410461,1.2645305395126345,50000 -3414.7790591716766,5.598333120346069,45418.158311128616,101754,0,45418.158311128616,0.5596000552177429,1.9227352142333984,10000,48844.14165306091,0.75830078125,0.965836763381958,0.6916399598121643,1.2648234367370603,50000 -3453.1070244312286,5.645540475845337,45838.346932172775,102695,0,45838.346932172775,0.5699000358581543,1.8935141563415527,10000,49302.75633502007,0.7544921636581421,0.9744990468025208,0.6915799975395203,1.252241611480713,50000 -3485.2028307914734,5.693195819854736,46258.4792034626,103641,0,46258.4792034626,0.5738000273704529,1.8720823526382449,10000,49755.083996772766,0.76039057970047,0.9307794570922852,0.6939799785614014,1.2324928045272827,50000 -3522.3583607673645,5.741758108139038,46678.42165637016,104581,0,46678.42165637016,0.5644000172615051,1.9141690731048584,10000,50212.28129172325,0.7684960961341858,0.922848641872406,0.6940000057220459,1.2535018920898438,50000 -3553.533600568772,5.791807174682617,47098.45361089706,105522,0,47098.45361089706,0.5731000304222107,1.8652210235595703,10000,50663.58996009827,0.7603710889816284,0.9424856901168824,0.6970199942588806,1.2145726680755615,50000 -3587.786563158036,5.8463099002838135,47518.67924666405,106459,0,47518.67924666405,0.5690000057220459,1.8761636018753047,10000,51118.17361497879,0.7631054520606995,0.9237319231033324,0.6990599632263184,1.2113115787506104,50000 -3621.135687589645,5.89500904083252,47938.64813065529,107399,0,47938.64813065529,0.5730000138282776,1.8663215637207031,10000,51571.591633319855,0.7676562070846558,0.9086039662361144,0.6971399784088135,1.2299054861068726,50000 -3653.203928470612,5.945183038711548,48358.62239551544,108342,0,48358.62239551544,0.5802000164985657,1.8436038494110107,10000,52023.73527216911,0.7684179544448853,0.91727477312088,0.7032999992370605,1.2101258039474487,50000 -3687.069760560989,5.995041608810425,48778.72830224037,109283,0,48778.72830224037,0.5815000534057617,1.822532296180725,10000,52477.80788874626,0.7674413919448853,0.9051960706710817,0.7060799598693848,1.1832246780395508,50000 -3719.524533748626,6.042582511901856,49198.84672832489,110223,0,49198.84672832489,0.5778000354766846,1.833355188369751,10000,52930.48039579392,0.77357417345047,0.8837026953697205,0.7024399638175964,1.1945924758911133,50000 -3754.3326518535614,6.095505475997925,49619.10298538208,111164,0,49619.10298538208,0.5790000557899475,1.8291984796524048,10000,53385.648253917694,0.7866796851158142,0.8273036479949951,0.7063199877738953,1.1831731796264648,50000 -3791.444357633591,6.148894309997559,50039.17558908463,112098,0,50039.17558908463,0.5746000409126282,1.8687307834625244,10000,53842.93714237213,0.7661523222923279,0.9525970816612244,0.7068399786949158,1.217895746231079,50000 -3824.517323732376,6.20680570602417,50459.45133137703,113041,0,50459.45133137703,0.5866000056266785,1.8236979246139529,10000,54296.39494585991,0.77685546875,0.892575740814209,0.708139955997467,1.1836878061294556,50000 -3863.1170892715454,6.262190341949463,50879.65846824646,113982,0,50879.65846824646,0.5873000025749207,1.8146487474441528,10000,54755.30799412727,0.78724604845047,0.8162120580673218,0.7078999876976013,1.162050485610962,50000 -3897.221100330353,6.308267831802368,51299.89177918434,114924,0,51299.89177918434,0.5843999981880188,1.8155479431152344,10000,55209.74253320694,0.7758007645606995,0.8787176609039307,0.7107399702072144,1.1633338928222656,50000 -3929.47137761116,6.3619771003723145,51720.268078804016,115864,0,51720.268078804016,0.5877000093460083,1.806920051574707,10000,55662.47427988053,0.7816015481948853,0.8619747757911682,0.7142999768257141,1.1624583005905151,50000 -3964.8775465488434,6.414790630340576,52140.46926808357,116803,0,52140.46926808357,0.5905000567436218,1.7739776372909546,10000,56118.18588399887,0.7881640195846558,0.8116686940193176,0.7134599685668945,1.1392488479614258,50000 -3997.784056663513,6.465289831161499,52560.747259140015,117743,0,52560.747259140015,0.5903000235557556,1.7920637130737305,10000,56571.47288775444,0.7998046875,0.7883418202400208,0.7152000069618225,1.1490259170532229,50000 -4029.862047433853,6.51940655708313,52980.7577688694,118679,0,52980.7577688694,0.5915000438690186,1.7688794136047363,10000,57023.66591095925,0.7859179377555847,0.825340986251831,0.7155599594116211,1.1319674253463743,50000 -4063.4930033683777,6.587791919708252,53401.04287528992,119619,0,53401.04287528992,0.5887000560760498,1.7725476026535034,10000,57477.70270061493,0.7870898246765137,0.8163532614707947,0.7176799774169922,1.1236371994018557,50000 -4103.463726997376,6.644713640213013,53820.95779657364,120558,0,53820.95779657364,0.5924000144004822,1.7938237190246582,10000,57937.69656896591,0.7944140434265137,0.8256067633628845,0.7125999927520752,1.1662746667861938,50000 -4137.593298435211,6.69199538230896,54240.93333983421,121504,0,54240.93333983421,0.5913000106811523,1.7923246622085571,10000,58391.9005625248,0.78466796875,0.866515576839447,0.7185399532318115,1.1521672010421753,50000 -4170.385221481323,6.743417024612427,54660.91382360458,122443,0,54660.91382360458,0.5937000513076782,1.7646859884262085,10000,58844.77633571625,0.7889843583106995,0.8201755285263062,0.7201600074768066,1.1214879751205444,50000 -4204.8427193164825,6.7984137535095215,55081.21804690361,123381,0,55081.21804690361,0.5998000502586365,1.7507137060165403,10000,59299.64342093468,0.8001562356948853,0.7793758511543274,0.7241399884223938,1.10870361328125,50000 -4248.118583202362,6.855582475662232,55501.22751235962,124320,0,55501.22751235962,0.5972000360488892,1.7637840509414673,10000,59763.037217378616,0.8098242282867432,0.7531871795654297,0.720579981803894,1.1280159950256348,50000 -4280.447121620178,6.898099422454834,55921.260788440704,125263,0,55921.260788440704,0.5982000231742859,1.7279510498046875,10000,60215.49275445938,0.7971484065055847,0.7829906344413757,0.7261599898338318,1.0861806869506836,50000 -4313.0392434597015,6.958737850189209,56341.26990056038,126199,0,56341.26990056038,0.6076000332832336,1.6989731788635254,10000,60668.204579114914,0.802539050579071,0.7490731477737427,0.7263999581336975,1.0830198526382446,50000 -4350.054116725922,7.012200832366943,56761.53534936905,127137,0,56761.53534936905,0.6055999994277954,1.7176285982131958,10000,61125.58907032013,0.8118359446525574,0.7276731133460999,0.7282399535179138,1.081851363182068,50000 -4382.287897109985,7.066992282867432,57181.63070321083,128076,0,57181.63070321083,0.603600025177002,1.7354732751846311,10000,61578.02355456352,0.7968358993530273,0.7937740087509155,0.7278199791908264,1.0928000211715698,50000 -4416.935405731201,7.12395167350769,57601.90099310875,129015,0,57601.90099310875,0.6086000204086304,1.7183754444122314,10000,62033.05067133904,0.8051171898841858,0.7775402665138245,0.7311999797821045,1.1002105474472046,50000 -4450.6000871658325,7.179595232009888,58022.18709683418,129956,0,58022.18709683418,0.6084000468254089,1.6949561834335327,10000,62487.10815668106,0.8118359446525574,0.7209511995315552,0.7298399806022644,1.0690104961395264,50000 -4482.868451356888,7.265031337738037,58442.412564754486,130898,0,58442.412564754486,0.6083000302314758,1.700702667236328,10000,62939.74114704132,0.8252343535423279,0.6869196891784668,0.7318999767303467,1.0726889371871948,50000 -4515.693987369537,7.319501161575317,58862.37831878662,131835,0,58862.37831878662,0.6165000200271606,1.6794337034225464,10000,63392.63826608658,0.8069726228713989,0.7426121234893799,0.7356399893760681,1.055567502975464,50000 -4552.977478265762,7.37714958190918,59282.7042388916,132771,0,59282.7042388916,0.6173000335693359,1.6698813438415527,10000,63850.35701370239,0.8157616853713989,0.7002706527709961,0.7349199652671814,1.0541627407073977,50000 -4591.900043010712,7.433608055114746,59702.83695149422,133714,0,59702.83695149422,0.6152000427246094,1.6732351779937744,10000,64309.51987743378,0.8212499618530273,0.6917605996131897,0.7364400029182434,1.0647858381271362,50000 -4622.005768299103,7.486406564712524,60122.92111968994,134657,0,60122.92111968994,0.6171000003814697,1.657669186592102,10000,64759.81437373161,0.8107812404632568,0.7149853110313416,0.7350999712944031,1.0497262477874756,50000 -4656.106626033783,7.542491436004639,60542.98550081253,135594,0,60542.98550081253,0.6195000410079956,1.6498830318450928,10000,65214.08685183525,0.8163085579872131,0.6858136653900146,0.7380599975585938,1.0333293676376345,50000 -4688.164028644562,7.599376678466797,60963.28740334511,136534,0,60963.28740334511,0.6193000078201294,1.6543912887573242,10000,65666.55383682251,0.8235937356948853,0.691146969795227,0.7392999529838562,1.0445665121078491,50000 -4726.867444038391,7.671452760696411,61383.44880151749,137478,0,61383.44880151749,0.6115000247955322,1.687080979347229,10000,66125.54321527481,0.8358203172683716,0.6423953771591187,0.7375800013542175,1.0486226081848145,50000 -4763.486090421677,7.724071741104126,61803.55562138557,138421,0,61803.55562138557,0.6202000379562378,1.641932487487793,10000,66582.37235975266,0.8221484422683716,0.6727483868598938,0.7418599724769592,1.0180381536483765,50000 -4796.360649824143,7.7763237953186035,62223.732880592346,139364,0,62223.732880592346,0.6260000467300415,1.6316074132919312,10000,67035.52803492546,0.8287304639816284,0.6483760476112366,0.7456600069999695,1.0039118528366089,50000 -4834.561587095261,7.831675052642822,62643.9823744297,140299,0,62643.9823744297,0.6199000477790833,1.6409794092178345,10000,67494.0847992897,0.8375195264816284,0.6271106004714966,0.7462999820709229,1.0144020318984983,50000 -4875.187557458878,7.8880674839019775,63063.95926761627,141239,0,63063.95926761627,0.6206000447273254,1.6216108798980713,10000,67954.7954685688,0.8244726657867432,0.6539803743362427,0.7459200024604797,1.0001908540725708,50000 -4911.257856369019,7.93324613571167,63484.03082895279,142181,0,63484.03082895279,0.6249000430107117,1.6253139972686768,10000,68411.03473544121,0.83216792345047,0.6442916989326477,0.7468400001525879,1.000510334968567,50000 -4948.755502462387,7.992448329925537,63904.20986747742,143119,0,63904.20986747742,0.6276000142097473,1.612836480140686,10000,68868.8211402893,0.8355468511581421,0.6156179904937744,0.7492799758911133,0.9899269938468932,50000 -4981.9539959430695,8.04038405418396,64324.40742135048,144060,0,64324.40742135048,0.6288000345230103,1.6108155250549316,10000,69322.31814837456,0.8414062261581421,0.6124166250228882,0.7498999834060669,0.9950244426727296,50000 -5019.346714496613,8.100827932357788,64744.35452175141,144994,0,64744.35452175141,0.6317000389099121,1.5940308570861816,10000,69779.76891851425,0.8344921469688416,0.6352013945579529,0.7530800104141235,0.9881432056427002,50000 -5052.624608039856,8.164747714996338,65164.53370523453,145924,0,65164.53370523453,0.6289000511169434,1.6003049612045288,10000,70233.34082078934,0.8385546803474426,0.615352988243103,0.7503199577331543,0.9843414425849916,50000 -5087.729380130768,8.226694107055664,65584.762103796,146860,0,65584.762103796,0.6330000162124634,1.5802263021469116,10000,70688.78717851639,0.8499999642372131,0.5691419243812561,0.7542200088500977,0.9688873291015624,50000 -5128.872689962387,8.282914638519287,66005.28381443024,147801,0,66005.28381443024,0.6347000598907471,1.5633114576339722,10000,71150.55936717987,0.841113269329071,0.5822334289550781,0.7569400072097778,0.9509395956993104,50000 -5159.098656654358,8.345009803771973,66425.58796477318,148745,0,66425.58796477318,0.6318000555038452,1.5700883865356443,10000,71601.20407652855,0.8462304472923279,0.5925246477127075,0.7569599747657776,0.9620269536972046,50000 -5200.234387397766,8.406283617019653,66845.85511136055,149680,0,66845.85511136055,0.6348000168800354,1.5691616535186768,10000,72062.71906900406,0.8490039110183716,0.5663567185401917,0.7562800049781799,0.9526709914207458,50000 -5237.093623876572,8.467820644378662,67265.89777302742,150622,0,67265.89777302742,0.6383000016212463,1.5757783651351929,10000,72519.73343348503,0.8469530940055847,0.5782976150512695,0.7572999596595764,0.9563175439834596,50000 -5272.33419585228,8.526651859283447,67685.99413371086,151562,0,67685.99413371086,0.6325000524520874,1.5823768377304075,10000,72975.18101358414,0.8485546708106995,0.580278217792511,0.7583000063896179,0.9555332660675048,50000 -5319.5845646858215,8.591355562210083,68106.0849275589,152498,0,68106.0849275589,0.6325000524520874,1.597623586654663,10000,73442.63851284981,0.8522460460662842,0.5873256325721741,0.7579799890518188,0.9749816656112672,50000 -5353.610310792923,8.64404821395874,68526.38285326958,153442,0,68526.38285326958,0.6397000551223755,1.548847198486328,10000,73897.06649947166,0.8603124618530273,0.5331978797912598,0.7604599595069885,0.9412803649902344,50000 -5389.820098400116,8.700559377670288,68946.4556479454,154381,0,68946.4556479454,0.6432000398635864,1.552872896194458,10000,74353.45694947243,0.8544335961341858,0.5505743622779846,0.7605400085449219,0.9401016235351562,50000 -5422.504199266434,8.762062788009644,69366.76836133003,155321,0,69366.76836133003,0.6396000385284424,1.5510258674621582,10000,74806.56618714333,0.8555663824081421,0.5521935224533081,0.7650399804115295,0.9392313957214355,50000 -5457.393801450729,8.822291851043701,69786.66360473633,156261,0,69786.66360473633,0.6413000226020813,1.5416245460510254,10000,75261.46242833138,0.8625780940055847,0.516830325126648,0.7641800045967102,0.9275818467140198,50000 -5493.874259471893,9.651360750198364,70206.14763140678,157195,0,70206.14763140678,0.6414000391960144,1.534990310668945,10000,75718.30789279938,0.8593164086341858,0.522832989692688,0.7650799751281738,0.9194480776786804,50000 -5529.660228967667,9.710200309753418,70626.3047413826,158135,0,70626.3047413826,0.6493000388145447,1.521514892578125,10000,76174.36131358147,0.8620507717132568,0.5150201916694641,0.7664799690246582,0.9146672487258912,50000 -5563.565862417221,9.763209104537964,71046.5109963417,159079,0,71046.5109963417,0.6468000411987305,1.5106713771820068,10000,76628.57782149315,0.8653515577316284,0.5030461549758911,0.7672199606895447,0.907681167125702,50000 -5602.491634130478,9.824456453323364,71466.58711266518,160018,0,71466.58711266518,0.6488000154495239,1.520171046257019,10000,77087.69198036194,0.8733007907867432,0.4837360680103302,0.7673999667167664,0.9166808128356934,50000 -5640.905419826508,9.887304306030272,71886.58734107018,160956,0,71886.58734107018,0.6488000154495239,1.511286735534668,10000,77546.2199819088,0.8666601181030273,0.5009443759918213,0.7683999538421631,0.9037065505981444,50000 -5683.460703372955,9.955764532089232,72306.52605938911,161899,0,72306.52605938911,0.6458000540733337,1.5094892978668213,10000,78008.83457326889,0.8688476085662842,0.490939736366272,0.7700600028038025,0.8974904417991638,50000 -5717.021797180176,10.020394086837769,72726.65507078171,162842,0,72726.65507078171,0.6525000333786011,1.496620774269104,10000,78462.64110207558,0.87611323595047,0.4695218503475189,0.7698599696159363,0.8996132016181946,50000 -5752.519358158112,10.0841383934021,73146.93170976639,163776,0,73146.93170976639,0.6538000106811523,1.4957234859466553,10000,78918.5292005539,0.86865234375,0.4866144359111786,0.7710599899291992,0.8960846662521362,50000 -5803.269119262695,10.14598798751831,73566.84558701515,164712,0,73566.84558701515,0.65010005235672,1.4990768432617188,10000,79389.30561876297,0.87123042345047,0.4835419654846191,0.7716599702835083,0.892865002155304,50000 -5838.162070274353,10.200010299682615,73986.8037891388,165655,0,73986.8037891388,0.6559000015258789,1.486812710762024,10000,79844.26286292076,0.8752539157867432,0.4673053920269012,0.7740199565887451,0.8827558159828186,50000 -5873.305253982544,10.267615795135498,74406.99995279312,166594,0,74406.99995279312,0.657200038433075,1.481296181678772,10000,80299.7220902443,0.8815820217132568,0.4435708522796631,0.7743200063705444,0.8817341923713684,50000 -5911.709374427795,10.32997465133667,74827.24753332138,167527,0,74827.24753332138,0.6550000309944153,1.4792810678482056,10000,80758.48889684677,0.8744335770606995,0.4724225103855133,0.774179995059967,0.8804149627685547,50000 -5959.091042995453,10.396169185638428,75247.50064873695,168464,0,75247.50064873695,0.6531000137329102,1.4713881015777588,10000,81226.24056196213,0.8765429258346558,0.4508543014526367,0.774899959564209,0.8729352951049805,50000 -5994.915839672089,10.452755689620972,75667.44517111778,169408,0,75667.44517111778,0.6557000279426575,1.4676754474639893,10000,81682.11791706085,0.8811718821525574,0.447146862745285,0.7753999829292297,0.8756856918334961,50000 -6032.74413061142,10.521770477294922,76087.41112089157,170347,0,76087.41112089157,0.6539000272750854,1.4720929861068726,10000,82140.03224730492,0.8786327838897705,0.4540967941284179,0.7759400010108948,0.8734537363052368,50000 -6076.168865442276,10.582777738571169,76507.5048623085,171285,0,76507.5048623085,0.6573000550270081,1.4642900228500366,10000,82603.66271305084,0.8799608945846558,0.4463948607444763,0.7760399580001831,0.8722897171974182,50000 -6112.465822458267,10.634551048278809,76927.65636372566,172228,0,76927.65636372566,0.6579000353813171,1.4670575857162476,10000,83060.2145652771,0.8814257383346558,0.4432843625545501,0.7773799896240234,0.8704707026481628,50000 -6153.237934350967,10.700610399246216,77347.83917474747,173169,0,77347.83917474747,0.6571000218391418,1.4592924118041992,10000,83521.28679227829,0.8839452862739563,0.4334542453289032,0.7780399918556213,0.8655481934547424,50000 -6187.839814186096,10.765312910079956,77768.06882762909,174111,0,77768.06882762909,0.6586000323295593,1.4545586109161377,10000,83976.2343738079,0.8828905820846558,0.4322830736637115,0.77947998046875,0.8609575033187866,50000 -6222.178782701492,10.833298206329346,78187.95839118958,175049,0,78187.95839118958,0.6575000286102295,1.4523377418518066,10000,84430.58243322372,0.8865429759025574,0.4204282760620117,0.7789599895477295,0.861438512802124,50000 -6260.123591899872,10.897441625595093,78608.33759880066,175986,0,78608.33759880066,0.6561000347137451,1.4516371488571167,10000,84889.0225493908,0.8853319883346558,0.4295352399349212,0.7797200083732605,0.859338641166687,50000 -6299.135026216507,10.9592125415802,79028.45848870277,176925,0,79028.45848870277,0.659000039100647,1.450667142868042,10000,85348.26768136024,0.8836327791213989,0.4284150898456573,0.7784199714660645,0.8580959439277649,50000 -6334.59276509285,11.045661687850952,79448.42927765846,177867,0,79448.42927765846,0.6620000600814819,1.443135380744934,10000,85803.83441829681,0.8875390291213989,0.4265033900737762,0.7801399827003479,0.8563417792320251,50000 -6376.347739696503,11.105734825134276,79868.60091280937,178807,0,79868.60091280937,0.6636000275611877,1.4399917125701904,10000,86265.87176012993,0.887499988079071,0.4158221781253814,0.7795199751853943,0.8523306250572205,50000 -6413.968862771988,11.165611028671265,80288.49507188797,179747,0,80288.49507188797,0.6633000373840332,1.4441020488739014,10000,86723.49782943726,0.8867382407188416,0.418932557106018,0.7800399661064148,0.8535423874855042,50000 -6448.626727581024,11.231953859329224,80708.39556407928,180683,0,80708.39556407928,0.6643000245094299,1.442840576171875,10000,87178.17326307297,0.8873242139816284,0.4149569869041443,0.7804799675941467,0.8523868322372437,50000 -6489.427065849304,11.297489881515505,81128.36767435074,181620,0,81128.36767435074,0.6632000207901001,1.4451278448104858,10000,87639.06216335297,0.8868749737739563,0.4176073670387268,0.780239999294281,0.8525054454803467,50000 -6534.940185070038,11.358872413635254,81548.34212756157,182561,0,81548.34212756157,0.6635000109672546,1.442088603973389,10000,88104.66285085678,0.8882421851158142,0.4123246967792511,0.7803599834442139,0.8496127724647522,50000 -6568.771864891052,11.416648387908936,81968.63080692291,183505,0,81968.63080692291,0.6627000570297241,1.4406883716583252,10000,88558.89284658432,0.8894140720367432,0.4123950004577636,0.780460000038147,0.8484692573547363,50000 -6607.41025018692,11.48314118385315,82388.93534779549,184444,0,82388.93534779549,0.6625000238418579,1.4413601160049438,10000,89017.95306396484,0.8900390267372131,0.4150246083736419,0.7807599902153015,0.8500364422798157,50000 -6653.44875907898,11.552966833114624,82809.04411435127,185384,0,82809.04411435127,0.6627000570297241,1.4396294355392456,10000,89484.22124147415,0.8905664086341858,0.4088824987411499,0.7807799577713013,0.8490588068962097,50000 -6692.481970310211,11.621200561523438,83228.92775535583,186329,0,83228.92775535583,0.6629000306129456,1.4399614334106443,10000,89943.25726413727,0.8901953101158142,0.414861798286438,0.7806599736213684,0.8491452932357788,50000 -6731.947980880737,11.689115285873411,83649.1949005127,187270,0,83649.1949005127,0.6628000140190125,1.439961552619934,10000,90403.1090900898,0.8873632550239563,0.4139579236507416,0.7806599736213684,0.8491368293762207,50000 -6769.103258132935,11.752736330032349,84069.26384472847,188208,0,84069.26384472847,0.6628000140190125,1.439961552619934,10000,90860.44817233086,0.8859374523162842,0.4201960265636444,0.7806599736213684,0.8491368293762207,50000 -6807.8059005737305,11.816109657287598,84489.2108707428,189144,0,84489.2108707428,0.6628000140190125,1.439961552619934,10000,91319.21151280405,0.8887304663658142,0.4120084643363952,0.7806599736213684,0.8491368293762207,50000 -6845.819953918457,11.881962776184082,84909.57313537598,190081,0,84909.57313537598,0.6628000140190125,1.439961552619934,10000,91777.704955101,0.8896288871765137,0.4153327941894531,0.7806599736213684,0.8491368293762207,50000 -6891.935022592545,11.94904851913452,85329.69141507149,191022,0,85329.69141507149,0.6628000140190125,1.439961552619934,10000,92244.0565958023,0.8895507454872131,0.4160298407077789,0.7806599736213684,0.8491368293762207,50000 -6931.216662168503,12.00302505493164,85749.73276805878,191967,0,85749.73276805878,0.6628000140190125,1.439961552619934,10000,92703.48462033272,0.8859570026397705,0.4194627404212951,0.7806599736213684,0.8491368293762207,50000 -6966.921736240387,12.076726198196411,86169.99474191666,192908,0,86169.99474191666,0.6628000140190125,1.439961552619934,10000,93159.5766465664,0.8864843845367432,0.4163427650928497,0.7806599736213684,0.8491368293762207,50000 -7008.095109939575,12.14166522026062,86590.10334300995,193846,0,86590.10334300995,0.6628000140190125,1.439961552619934,10000,93620.9745607376,0.8892577886581421,0.4128409922122955,0.7806599736213684,0.8491368293762207,50000 -7046.456674575806,12.20676326751709,87010.18246912956,194786,0,87010.18246912956,0.6628000140190125,1.439961552619934,10000,94079.54350543022,0.8887109160423279,0.4138174057006836,0.7806599736213684,0.8491368293762207,50000 -7084.988321781158,12.271981477737429,87430.38870239258,195726,0,87430.38870239258,0.6628000140190125,1.439961552619934,10000,94538.39776062964,0.8862695097923279,0.4229434728622436,0.7806599736213684,0.8491368293762207,50000 -7124.067133426666,12.340903997421265,87850.36334037781,196662,0,87850.36334037781,0.6628000140190125,1.439961552619934,10000,94997.57072806358,0.88929682970047,0.4122631251811981,0.7806599736213684,0.8491368293762207,50000 -7161.361990451813,12.406561136245728,88270.65979909897,197600,0,88270.65979909897,0.6628000140190125,1.439961552619934,10000,95455.27904057504,0.8868359327316284,0.4162808060646057,0.7806599736213684,0.8491368293762207,50000 -7199.420048713684,12.473467350006104,88690.6394803524,198538,0,88690.6394803524,0.6628000140190125,1.439961552619934,10000,95913.43540740012,0.8875390291213989,0.413641095161438,0.7806599736213684,0.8491368293762207,50000 -7237.657130002975,12.549538135528564,89110.54193615913,199480,0,89110.54193615913,0.6628000140190125,1.439961552619934,10000,96371.70179224014,0.8901171684265137,0.4157747030258178,0.7806599736213684,0.8491368293762207,50000 -7275.220329999924,12.615999937057495,89530.50885987282,200414,0,89530.50885987282,0.6628000140190125,1.439961552619934,10000,96829.34929418564,0.887499988079071,0.41785928606987,0.7806599736213684,0.8491368293762207,50000 -7314.43705201149,12.690476179122925,89950.4132950306,201349,0,89950.4132950306,0.6628000140190125,1.439961552619934,10000,97288.59665942192,0.8890624642372131,0.4149791598320007,0.7806599736213684,0.8491368293762207,50000 -7353.0320365428925,12.773582935333252,90370.42215585709,202291,0,90370.42215585709,0.6628000140190125,1.439961552619934,10000,97747.3352882862,0.8885741829872131,0.4138651192188263,0.7806599736213684,0.8491368293762207,50000 -7390.410850524902,12.851639747619627,90790.63092684746,203232,0,90790.63092684746,0.6628000140190125,1.439961552619934,10000,98205.05185079576,0.8862109184265137,0.4218650758266449,0.7806599736213684,0.8491368293762207,50000 -7427.037251472473,12.919341325759888,91210.55031061172,204166,0,91210.55031061172,0.6628000140190125,1.439961552619934,10000,98661.71626615524,0.888964831829071,0.4084383249282837,0.7806599736213684,0.8491368293762207,50000 -7468.865349769592,12.98732614517212,91630.6247985363,205106,0,91630.6247985363,0.6628000140190125,1.439961552619934,10000,99123.73776745796,0.8895702958106995,0.4111346006393432,0.7806599736213684,0.8491368293762207,50000 -7505.758831977844,13.057476282119753,92050.679759264,206045,0,92050.679759264,0.6628000140190125,1.439961552619934,10000,99580.80747056007,0.8857226371765137,0.4192066490650177,0.7806599736213684,0.8491368293762207,50000 -7542.178693294525,13.123713493347168,92470.65465569496,206981,0,92470.65465569496,0.6628000140190125,1.439961552619934,10000,100037.31982588768,0.8902148008346558,0.4150867462158203,0.7806599736213684,0.8491368293762207,50000 -7579.383923530579,13.19585919380188,92890.90583324432,207919,0,92890.90583324432,0.6628000140190125,1.439961552619934,10000,100494.89912700652,0.8915429711341858,0.4060378968715668,0.7806599736213684,0.8491368293762207,50000 -7618.397220611572,13.267715454101562,93311.23576164246,208859,0,93311.23576164246,0.6628000140190125,1.439961552619934,10000,100954.3665342331,0.8892382383346558,0.4131664931774139,0.7806599736213684,0.8491368293762207,50000 -7655.169943809509,13.32474970817566,93731.48620271684,209800,0,93731.48620271684,0.6628000140190125,1.439961552619934,10000,101411.4986450672,0.8897460699081421,0.4122765362262726,0.7806599736213684,0.8491368293762207,50000 -7702.275942802429,13.394086122512816,94151.36676955225,210740,0,94151.36676955225,0.6628000140190125,1.439961552619934,10000,101878.605271101,0.8892187476158142,0.4108535349369049,0.7806599736213684,0.8491368293762207,50000 -7735.611426591873,13.45100212097168,94571.33133649826,211682,0,94571.33133649826,0.6628000140190125,1.439961552619934,10000,102332.013286829,0.8870898485183716,0.419926643371582,0.7806599736213684,0.8491368293762207,50000 -7772.595160007477,13.521557331085203,94991.60111403464,212623,0,94991.60111403464,0.6628000140190125,1.439961552619934,10000,102789.38869023325,0.8870312571525574,0.416995108127594,0.7806599736213684,0.8491368293762207,50000 -7812.842822313309,13.592507123947144,95411.76716446877,213562,0,95411.76716446877,0.6628000140190125,1.439961552619934,10000,103249.92401885986,0.8886132836341858,0.4166717827320099,0.7806599736213684,0.8491368293762207,50000 -7850.67919921875,13.663443326950071,95831.76033210754,214501,0,95831.76033210754,0.6628000140190125,1.439961552619934,10000,103707.87656855585,0.888964831829071,0.411579966545105,0.7806599736213684,0.8491368293762207,50000 -7893.881835460663,13.775733470916748,96251.69516730309,215442,0,96251.69516730309,0.6628000140190125,1.439961552619934,10000,104171.17848610878,0.8849804401397705,0.4249436855316162,0.7806599736213684,0.8491368293762207,50000 -7934.622961759567,13.850823163986206,96671.78650546074,216380,0,96671.78650546074,0.6628000140190125,1.439961552619934,10000,104632.13685011864,0.8871093392372131,0.4179880023002624,0.7806599736213684,0.8491368293762207,50000 -7976.588205337524,13.924319505691528,97092.09571146964,217317,0,97092.09571146964,0.6628000140190125,1.439961552619934,10000,105094.53578543664,0.8890234231948853,0.4122736155986786,0.7806599736213684,0.8491368293762207,50000 -8010.768728733063,13.989553689956663,97512.15190458298,218258,0,97512.15190458298,0.6628000140190125,1.439961552619934,10000,105548.88929748537,0.8883007764816284,0.412197470664978,0.7806599736213684,0.8491368293762207,50000 -8049.906625032425,14.060953378677368,97932.20115685464,219194,0,97932.20115685464,0.6628000140190125,1.439961552619934,10000,106008.19831943512,0.8895312547683716,0.4137906432151794,0.7806599736213684,0.8491368293762207,50000 -8087.387703418732,14.133442878723145,98352.26095962524,220131,0,98352.26095962524,0.6628000140190125,1.439961552619934,10000,106465.86269831656,0.8866210579872131,0.419018805027008,0.7806599736213684,0.8491368293762207,50000 -8126.485192060471,14.248473167419434,98772.34060120584,221070,0,98772.34060120584,0.6628000140190125,1.439961552619934,10000,106925.2056400776,0.8873828053474426,0.4174685180187225,0.7806599736213684,0.8491368293762207,50000 -8173.137872457504,14.322559595108032,99192.3106303215,222009,0,99192.3106303215,0.6628000140190125,1.439961552619934,10000,107391.95266342165,0.8881054520606995,0.4179070889949798,0.7806599736213684,0.8491368293762207,50000 -8211.162187337875,14.383384227752686,99612.22961091997,222944,0,99612.22961091997,0.6628000140190125,1.439961552619934,10000,107850.00766944884,0.8896679282188416,0.4159463346004486,0.7806599736213684,0.8491368293762207,50000 -8250.954582214355,14.445920705795288,100032.17080593108,223884,0,100032.17080593108,0.6628000140190125,1.439961552619934,10000,108309.85451054572,0.8884570002555847,0.4101370871067047,0.7806599736213684,0.8491368293762207,50000 -8293.447283029556,14.529171466827393,100452.20421767236,224823,0,100452.20421767236,0.6628000140190125,1.439961552619934,10000,108772.51467132568,0.88734370470047,0.4182247519493103,0.7806599736213684,0.8491368293762207,50000 -8336.703224897385,14.600011348724363,100872.4828350544,225764,0,100872.4828350544,0.6628000140190125,1.439961552619934,10000,109236.17139697076,0.8885741829872131,0.4156999588012695,0.7806599736213684,0.8491368293762207,50000 -8375.506333589554,14.662203788757324,101292.43042826653,226704,0,101292.43042826653,0.6628000140190125,1.439961552619934,10000,109695.03538489342,0.88880854845047,0.4133599698543548,0.7806599736213684,0.8491368293762207,50000 -8422.173720359802,14.73522162437439,101712.58194470406,227643,0,101712.58194470406,0.6628000140190125,1.439961552619934,10000,110161.97889208794,0.8872656226158142,0.4160450398921966,0.7806599736213684,0.8491368293762207,50000 -8467.515223264694,14.81779718399048,102132.54553413393,228582,0,102132.54553413393,0.6628000140190125,1.439961552619934,10000,110627.42151880264,0.8897070288658142,0.4089111089706421,0.7806599736213684,0.8491368293762207,50000 -8506.700605392456,14.884428977966309,102552.66275954248,229521,0,102552.66275954248,0.6628000140190125,1.439961552619934,10000,111086.84193229675,0.8869531154632568,0.4167779386043548,0.7806599736213684,0.8491368293762207,50000 -8544.661287546158,14.990256071090698,102972.72747325896,230459,0,102972.72747325896,0.6628000140190125,1.439961552619934,10000,111545.024851799,0.8913085460662842,0.4091808199882507,0.7806599736213684,0.8491368293762207,50000 -8586.344784021378,15.06243109703064,103392.72380638124,231386,0,103392.72380638124,0.6628000140190125,1.439961552619934,10000,112006.82755184174,0.8886132836341858,0.41818568110466,0.7806599736213684,0.8491368293762207,50000 -8626.981405496597,15.136332988739014,103812.9854183197,232323,0,103812.9854183197,0.6628000140190125,1.439961552619934,10000,112467.85011029243,0.8895898461341858,0.4101483821868896,0.7806599736213684,0.8491368293762207,50000 -8663.754947662354,15.229785442352297,104233.0490720272,233264,0,104233.0490720272,0.6628000140190125,1.439961552619934,10000,112924.8316013813,0.8911718726158142,0.4089621007442474,0.7806599736213684,0.8491368293762207,50000 -8703.668316602707,15.30400776863098,104652.99808740616,234202,0,104652.99808740616,0.6628000140190125,1.439961552619934,10000,113384.81877231598,0.8874609470367432,0.4192294180393219,0.7806599736213684,0.8491368293762207,50000 -8742.130467414856,15.374670505523682,105072.99116683006,235137,0,105072.99116683006,0.6628000140190125,1.439961552619934,10000,113843.39495277403,0.8865624666213989,0.4175738096237182,0.7806599736213684,0.8491368293762207,50000 -8782.207236289978,15.449368953704834,105493.27949905396,236073,0,105493.27949905396,0.6628000140190125,1.439961552619934,10000,114303.88598370552,0.8865820169448853,0.418828547000885,0.7806599736213684,0.8491368293762207,50000 -8820.693273544312,15.526480674743652,105913.54170560835,237012,0,105913.54170560835,0.6628000140190125,1.439961552619934,10000,114762.76225996016,0.8889843821525574,0.4155234694480896,0.7806599736213684,0.8491368293762207,50000 -8866.161488056183,15.605794191360474,106333.91696500778,237948,0,106333.91696500778,0.6628000140190125,1.439961552619934,10000,115228.73643136024,0.8892773389816284,0.4160856008529663,0.7806599736213684,0.8491368293762207,50000 -8904.943639278412,15.677077531814575,106754.01340723038,238888,0,106754.01340723038,0.6628000140190125,1.439961552619934,10000,115687.73761057854,0.8872265219688416,0.4133851826190948,0.7806599736213684,0.8491368293762207,50000 -8941.621190309525,15.752530097961426,107174.28653025629,239829,0,107174.28653025629,0.6628000140190125,1.439961552619934,10000,116144.81491804124,0.8869531154632568,0.4212391078472137,0.7806599736213684,0.8491368293762207,50000 -8978.973022460938,15.82563543319702,107594.49706172945,240767,0,107594.49706172945,0.6628000140190125,1.439961552619934,10000,116602.501486063,0.8889062404632568,0.4103345870971679,0.7806599736213684,0.8491368293762207,50000 -9020.95052599907,15.900647640228271,108014.80342388152,241705,0,108014.80342388152,0.6628000140190125,1.439961552619934,10000,117064.91074442863,0.8889257907867432,0.4125271141529083,0.7806599736213684,0.8491368293762207,50000 -9059.374162912369,15.975814819335938,108434.92385172844,242644,0,108434.92385172844,0.6628000140190125,1.439961552619934,10000,117523.58127760889,0.8873046636581421,0.4170242846012115,0.7806599736213684,0.8491368293762207,50000 -9097.214118480682,16.05459976196289,108855.2262763977,243578,0,108855.2262763977,0.6628000140190125,1.439961552619934,10000,117981.85328555109,0.8865820169448853,0.4223847091197967,0.7806599736213684,0.8491368293762207,50000 -9137.2043967247,16.133418798446655,109275.20730948448,244515,0,109275.20730948448,0.6628000140190125,1.439961552619934,10000,118441.95490217207,0.8882226347923279,0.4138920903205871,0.7806599736213684,0.8491368293762207,50000 -9175.77440404892,16.207163333892822,109695.28633069992,245454,0,109695.28633069992,0.6628000140190125,1.439961552619934,10000,118900.72915911674,0.887499988079071,0.4170665740966797,0.7806599736213684,0.8491368293762207,50000 -9214.267620325089,16.286096334457397,110115.32331180573,246392,0,110115.32331180573,0.6628000140190125,1.439961552619934,10000,119359.3913846016,0.8889257907867432,0.4171882271766662,0.7806599736213684,0.8491368293762207,50000 -9257.57844209671,16.369798183441162,110535.2491569519,247331,0,110535.2491569519,0.6628000140190125,1.439961552619934,10000,119822.7629380226,0.8894921541213989,0.4114971458911896,0.7806599736213684,0.8491368293762207,50000 -9293.305986881256,16.446000337600708,110955.36974859238,248273,0,110955.36974859238,0.6628000140190125,1.439961552619934,10000,120278.73826503754,0.8887109160423279,0.4138701260089874,0.7806599736213684,0.8491368293762207,50000 -9335.543090820312,16.526723384857178,111375.37278437614,249210,0,111375.37278437614,0.6628000140190125,1.439961552619934,10000,120741.10971736908,0.8867577910423279,0.4228262901306152,0.7806599736213684,0.8491368293762207,50000 -9370.087344884872,16.606194257736206,111795.39901304244,250152,0,111795.39901304244,0.6628000140190125,1.439961552619934,10000,121195.81096291542,0.8893749713897705,0.4094995856285095,0.7806599736213684,0.8491368293762207,50000 -9414.86929488182,16.68982768058777,112215.63183784483,251090,0,112215.63183784483,0.6628000140190125,1.439961552619934,10000,121660.96054553986,0.8884179592132568,0.4091140627861023,0.7806599736213684,0.8491368293762207,50000 -9460.509674549105,16.75420641899109,112635.64821743964,252034,0,112635.64821743964,0.6628000140190125,1.439961552619934,10000,122126.73324251176,0.8879687190055847,0.4145607650279999,0.7806599736213684,0.8491368293762207,50000 -9504.239057064056,16.821215629577637,113055.63692212103,252976,0,113055.63692212103,0.6628000140190125,1.439961552619934,10000,122590.56933999062,0.8885937333106995,0.41547492146492,0.7806599736213684,0.8491368293762207,50000 -9542.089750528336,16.89837408065796,113475.78511810304,253915,0,113475.78511810304,0.6628000140190125,1.439961552619934,10000,123048.69641470908,0.8883593678474426,0.4171532988548279,0.7806599736213684,0.8491368293762207,50000 -9581.454507350922,16.978004217147827,113895.78473091124,254846,0,113895.78473091124,0.6628000140190125,1.439961552619934,10000,123508.19112324716,0.8913671970367432,0.4085223972797394,0.7806599736213684,0.8491368293762207,50000 -9626.326175928116,17.073423862457275,114315.99058461188,255785,0,114315.99058461188,0.6628000140190125,1.439961552619934,10000,123973.41496372224,0.8885351419448853,0.4148720502853393,0.7806599736213684,0.8491368293762207,50000 -9659.201932668686,17.14213752746582,114736.19265317915,256728,0,114736.19265317915,0.6628000140190125,1.439961552619934,10000,124426.6122317314,0.8897265195846558,0.4137465059757232,0.7806599736213684,0.8491368293762207,50000 -9699.450918912888,17.224257707595825,115156.30558776855,257655,0,115156.30558776855,0.6628000140190125,1.439961552619934,10000,124887.1065568924,0.8881640434265137,0.4150804877281189,0.7806599736213684,0.8491368293762207,50000 -9747.919601917269,17.31403613090515,115576.53962373734,258593,0,115576.53962373734,0.6628000140190125,1.439961552619934,10000,125355.9505121708,0.8871288895606995,0.4156031608581543,0.7806599736213684,0.8491368293762207,50000 -9787.654287099838,17.384129524230957,115996.80192518234,259536,0,115996.80192518234,0.6628000140190125,1.439961552619934,10000,125816.06837558746,0.8883788585662842,0.4144641458988189,0.7806599736213684,0.8491368293762207,50000 -9829.864854097366,17.466994762420654,116417.05393886566,260473,0,116417.05393886566,0.6628000140190125,1.439961552619934,10000,126278.66386461258,0.888671875,0.4139718115329742,0.7806599736213684,0.8491368293762207,50000 -9864.348781108856,17.530568838119507,116837.02300691605,261413,0,116837.02300691605,0.6628000140190125,1.439961552619934,10000,126733.23106122015,0.8889257907867432,0.4145142734050751,0.7806599736213684,0.8491368293762207,50000 -9912.99257516861,17.60601305961609,117257.40816497804,262350,0,117257.40816497804,0.6628000140190125,1.439961552619934,10000,127202.38626980782,0.8883984088897705,0.4164403676986694,0.7806599736213684,0.8491368293762207,50000 -9952.029757261276,17.68432903289795,117677.55997300148,263291,0,117677.55997300148,0.6628000140190125,1.439961552619934,10000,127661.7043390274,0.8852343559265137,0.4217980802059173,0.7806599736213684,0.8491368293762207,50000 -9995.02408361435,17.765547513961792,118097.5377099514,264227,0,118097.5377099514,0.6628000140190125,1.439961552619934,10000,128124.80843949318,0.8889843821525574,0.4135739505290985,0.7806599736213684,0.8491368293762207,50000 -10030.878251314163,17.853434562683105,118517.52534794807,265162,0,118517.52534794807,0.6628000140190125,1.439961552619934,10000,128580.78847575188,0.8879492282867432,0.4144449830055237,0.7806599736213684,0.8491368293762207,50000 -10069.000126123428,17.936197996139526,118937.56853723526,266097,0,118937.56853723526,0.6628000140190125,1.439961552619934,10000,129039.0865252018,0.8884179592132568,0.4173905849456787,0.7806599736213684,0.8491368293762207,50000 -10111.985719919205,18.02430295944214,119357.60832476616,267036,0,119357.60832476616,0.6628000140190125,1.439961552619934,10000,129502.25046014786,0.8875976204872131,0.4157011210918426,0.7806599736213684,0.8491368293762207,50000 -10150.179717302322,18.103163957595825,119777.81729054452,267975,0,119777.81729054452,0.6628000140190125,1.439961552619934,10000,129960.78221297264,0.8877733945846558,0.4151789844036102,0.7806599736213684,0.8491368293762207,50000 -10190.332374095917,18.181214332580566,120197.99805498125,268913,0,120197.99805498125,0.6628000140190125,1.439961552619934,10000,130421.24511313438,0.8890820145606995,0.411138653755188,0.7806599736213684,0.8491368293762207,50000 -10228.79074215889,18.266133069992065,120617.99323821068,269850,0,120617.99323821068,0.6628000140190125,1.439961552619934,10000,130879.83401083946,0.8871679306030273,0.4229635894298553,0.7806599736213684,0.8491368293762207,50000 -10269.55788254738,18.342522621154785,121037.9982905388,270790,0,121037.9982905388,0.6628000140190125,1.439961552619934,10000,131340.73326659203,0.88929682970047,0.4126971662044525,0.7806599736213684,0.8491368293762207,50000 -10308.489049196243,18.421640157699585,121457.97619390488,271730,0,121457.97619390488,0.6628000140190125,1.439961552619934,10000,131799.77235770226,0.8877539038658142,0.4173440933227539,0.7806599736213684,0.8491368293762207,50000 -10349.891362667084,18.50593400001526,121877.95221614838,272670,0,121877.95221614838,0.6628000140190125,1.439961552619934,10000,132261.28627204895,0.88832026720047,0.4156412780284881,0.7806599736213684,0.8491368293762207,50000 -10395.50066781044,18.59084153175354,122297.96140408516,273609,0,122297.96140408516,0.6628000140190125,1.439961552619934,10000,132727.0404522419,0.8891796469688416,0.4100504517555237,0.7806599736213684,0.8491368293762207,50000 -10430.909448862076,18.673195123672485,122717.90025830267,274550,0,122717.90025830267,0.6628000140190125,1.439961552619934,10000,133182.52132487297,0.8878515362739563,0.4165761172771454,0.7806599736213684,0.8491368293762207,50000 -10472.254363298416,18.7580795288086,123138.03974294662,275488,0,123138.03974294662,0.6628000140190125,1.439961552619934,10000,133644.14214658737,0.8887304663658142,0.4122629761695862,0.7806599736213684,0.8491368293762207,50000 -10512.6342856884,18.83539652824402,123558.17305517197,276426,0,123558.17305517197,0.6628000140190125,1.439961552619934,10000,134104.78236317635,0.8866796493530273,0.4174808859825134,0.7806599736213684,0.8491368293762207,50000 -10551.421706914902,18.918168783187863,123978.34770202637,277365,0,123978.34770202637,0.6628000140190125,1.439961552619934,10000,134563.87787127495,0.8886132836341858,0.4149703085422516,0.7806599736213684,0.8491368293762207,50000 -10590.24893975258,18.998623609542847,124398.55496621132,278304,0,124398.55496621132,0.6628000140190125,1.439961552619934,10000,135023.04332780838,0.8903124928474426,0.4152889549732208,0.7806599736213684,0.8491368293762207,50000 -10629.01440358162,19.098854541778564,124818.6222038269,279242,0,124818.6222038269,0.6628000140190125,1.439961552619934,10000,135482.02679228783,0.888671875,0.4086595475673675,0.7806599736213684,0.8491368293762207,50000 -10668.821629047394,19.18437695503235,125238.82615542412,280179,0,125238.82615542412,0.6628000140190125,1.439961552619934,10000,135942.1735329628,0.8921093344688416,0.4076516628265381,0.7806599736213684,0.8491368293762207,50000 -10708.152201652529,19.26614189147949,125659.04063010216,281119,0,125659.04063010216,0.6628000140190125,1.439961552619934,10000,136401.85151863098,0.8885937333106995,0.414132297039032,0.7806599736213684,0.8491368293762207,50000 -10753.6897149086,19.349141597747803,126079.1490097046,282060,0,126079.1490097046,0.6628000140190125,1.439961552619934,10000,136867.63150930405,0.8877929449081421,0.4160957634449005,0.7806599736213684,0.8491368293762207,50000 -10789.449345111849,19.43389129638672,126499.3281633854,282999,0,126499.3281633854,0.6628000140190125,1.439961552619934,10000,137323.70598578453,0.8861523270606995,0.418018102645874,0.7806599736213684,0.8491368293762207,50000 -10832.9416513443,19.51919531822205,126919.30169415474,283937,0,126919.30169415474,0.6628000140190125,1.439961552619934,10000,137787.30851221085,0.8886132836341858,0.4185100197792053,0.7806599736213684,0.8491368293762207,50000 -10873.36230826378,19.64874291419983,127339.16857624054,284875,0,127339.16857624054,0.6628000140190125,1.439961552619934,10000,138247.77598190308,0.8898632526397705,0.4148930013179779,0.7806599736213684,0.8491368293762207,50000 -10913.14732003212,19.732011795043945,127759.18264389038,285813,0,127759.18264389038,0.6628000140190125,1.439961552619934,10000,138707.7092897892,0.88734370470047,0.4140979647636413,0.7806599736213684,0.8491368293762207,50000 -10956.518753528597,19.81148219108581,128179.11961340904,286748,0,128179.11961340904,0.6628000140190125,1.439961552619934,10000,139171.1473546028,0.8852343559265137,0.4228600561618805,0.7806599736213684,0.8491368293762207,50000 -10993.224220991136,20.136163234710693,128598.82387661934,287688,0,128598.82387661934,0.6628000140190125,1.439961552619934,10000,139627.9332201481,0.8890038728713989,0.4117945730686188,0.7806599736213684,0.8491368293762207,50000 -11030.850098609924,20.22132968902588,129019.09246134758,288625,0,129019.09246134758,0.6628000140190125,1.439961552619934,10000,140085.9644765854,0.8883984088897705,0.4140118360519409,0.7806599736213684,0.8491368293762207,50000 -11068.954473495483,20.305765628814697,129439.0857374668,289560,0,129439.0857374668,0.6628000140190125,1.439961552619934,10000,140544.19653391838,0.88880854845047,0.4138639867305755,0.7806599736213684,0.8491368293762207,50000 -11111.227452993391,20.39596724510193,129859.1750626564,290496,0,129859.1750626564,0.6628000140190125,1.439961552619934,10000,141006.6998987198,0.8863085508346558,0.4227134883403778,0.7806599736213684,0.8491368293762207,50000 -11146.959941864014,20.481508493423465,130279.17562508585,291435,0,130279.17562508585,0.6628000140190125,1.439961552619934,10000,141462.5694692135,0.8886327743530273,0.4127470552921295,0.7806599736213684,0.8491368293762207,50000 -11187.23245549202,20.571951389312744,130699.34307909012,292366,0,130699.34307909012,0.6628000140190125,1.439961552619934,10000,141923.1501107216,0.8895702958106995,0.4100448489189148,0.7806599736213684,0.8491368293762207,50000 -11228.173330783844,20.65904426574707,131119.25686120987,293303,0,131119.25686120987,0.6628000140190125,1.439961552619934,10000,142384.14255547523,0.8876562118530273,0.4232404232025146,0.7806599736213684,0.8491368293762207,50000 -11265.79709982872,20.74518251419068,131539.37345027924,294193,0,131539.37345027924,0.6628000140190125,1.439961552619934,10000,142842.01654458046,0.8883788585662842,0.4139226377010345,0.7806599736213684,0.8491368293762207,50000 -11305.44756770134,20.83014798164368,131959.2889535427,295130,0,131959.2889535427,0.6628000140190125,1.439961552619934,10000,143301.71817946434,0.8861523270606995,0.4199804067611694,0.7806599736213684,0.8491368293762207,50000 -11341.51713204384,20.91266751289368,132379.5202343464,296067,0,132379.5202343464,0.6628000140190125,1.439961552619934,10000,143758.15282964706,0.8904101252555847,0.4130441844463348,0.7806599736213684,0.8491368293762207,50000 -11389.175944328308,21.0025532245636,132799.5991435051,297001,0,132799.5991435051,0.6628000140190125,1.439961552619934,10000,144226.03037834167,0.8873242139816284,0.4146493077278137,0.7806599736213684,0.8491368293762207,50000 -11426.878244876862,21.088128566741943,133219.8224363327,297943,0,133219.8224363327,0.6628000140190125,1.439961552619934,10000,144684.09260296822,0.8876367211341858,0.4151484072208404,0.7806599736213684,0.8491368293762207,50000 -11463.27908229828,21.181781768798828,133640.4175248146,298880,0,133640.4175248146,0.6628000140190125,1.439961552619934,10000,145141.23288106918,0.8898437023162842,0.4100130796432495,0.7806599736213684,0.8491368293762207,50000 -11504.816838264464,21.274759769439697,134060.4425957203,299814,0,134060.4425957203,0.6628000140190125,1.439961552619934,10000,145602.93848657608,0.8866406083106995,0.4186463952064514,0.7806599736213684,0.8491368293762207,50000 -11543.418461561205,21.360515594482425,134480.3460702896,300752,0,134480.3460702896,0.6628000140190125,1.439961552619934,10000,146061.57924103737,0.88978511095047,0.4111753106117248,0.7806599736213684,0.8491368293762207,50000 -11585.907123088837,21.44925570487976,134900.35364151,301689,0,134900.35364151,0.6628000140190125,1.439961552619934,10000,146524.2145433426,0.8907030820846558,0.4115512371063232,0.7806599736213684,0.8491368293762207,50000 -11620.165687322617,21.53703117370605,135320.4859445095,302630,0,135320.4859445095,0.6628000140190125,1.439961552619934,10000,146978.74358296394,0.8902929425239563,0.4095830023288727,0.7806599736213684,0.8491368293762207,50000 -11663.15535378456,21.625725030899048,135740.72599720955,303569,0,135740.72599720955,0.6628000140190125,1.439961552619934,10000,147442.11426973343,0.888476550579071,0.4151207208633423,0.7806599736213684,0.8491368293762207,50000 -11702.606310367584,21.71380090713501,136160.76442551613,304507,0,136160.76442551613,0.6628000140190125,1.439961552619934,10000,147901.74234127998,0.8879101276397705,0.4139824211597442,0.7806599736213684,0.8491368293762207,50000 -11741.822509288788,21.80000257492065,136580.63121008873,305445,0,136580.63121008873,0.6628000140190125,1.439961552619934,10000,148360.96184515953,0.8877343535423279,0.4162862598896026,0.7806599736213684,0.8491368293762207,50000 -11789.387017726898,21.89959049224853,137000.6361260414,306383,0,137000.6361260414,0.6628000140190125,1.439961552619934,10000,148828.68139123917,0.8881250023841858,0.4133361577987671,0.7806599736213684,0.8491368293762207,50000 -11832.368502616882,21.98433923721313,137420.53449702263,307323,0,137420.53449702263,0.6628000140190125,1.439961552619934,10000,149291.6964457035,0.8883593678474426,0.4201784133911133,0.7806599736213684,0.8491368293762207,50000 -11870.268971681597,22.07439494132996,137840.7928082943,308260,0,137840.7928082943,0.6628000140190125,1.439961552619934,10000,149749.99558973312,0.8879687190055847,0.4177784919738769,0.7806599736213684,0.8491368293762207,50000 -11906.708271741869,22.166040182113647,138260.85964107513,309197,0,138260.85964107513,0.6628000140190125,1.439961552619934,10000,150206.64368653297,0.88916015625,0.4091837108135223,0.7806599736213684,0.8491368293762207,50000 -11947.138439178469,22.304643392562863,138680.788346529,310133,0,138680.788346529,0.6628000140190125,1.439961552619934,10000,150667.1916770935,0.8869531154632568,0.4185750484466553,0.7806599736213684,0.8491368293762207,50000 -11986.4548933506,22.394673824310303,139100.7155430317,311071,0,139100.7155430317,0.6628000140190125,1.439961552619934,10000,151126.57604026794,0.8867968320846558,0.418944239616394,0.7806599736213684,0.8491368293762207,50000 -12027.735392093658,22.487955570220947,139520.6550180912,312008,0,139520.6550180912,0.6628000140190125,1.439961552619934,10000,151587.94029808044,0.88916015625,0.411105066537857,0.7806599736213684,0.8491368293762207,50000 -12067.290197372437,22.60226511955261,139940.78490543363,312948,0,139940.78490543363,0.6628000140190125,1.439961552619934,10000,152047.78997969627,0.8881444931030273,0.4167310893535614,0.7806599736213684,0.8491368293762207,50000 -12108.281812429428,22.693623065948486,140360.85139131546,313887,0,140360.85139131546,0.6628000140190125,1.439961552619934,10000,152508.9900314808,0.8859961032867432,0.4235863089561462,0.7806599736213684,0.8491368293762207,50000 -12146.96865439415,22.78712558746338,140781.04164910316,314824,0,140781.04164910316,0.6628000140190125,1.439961552619934,10000,152968.01148629189,0.8883398175239563,0.4152339696884155,0.7806599736213684,0.8491368293762207,50000 -12183.743371248243,22.927883863449097,141201.20494127274,315762,0,141201.20494127274,0.6628000140190125,1.439961552619934,10000,153425.1422805786,0.8909569978713989,0.4055262804031372,0.7806599736213684,0.8491368293762207,50000 -12221.055872440338,23.01867938041687,141621.46349978447,316696,0,141621.46349978447,0.6628000140190125,1.439961552619934,10000,153882.85493707657,0.8879296779632568,0.4203538298606872,0.7806599736213684,0.8491368293762207,50000 -12262.974166870115,23.10886240005493,142041.4583003521,317634,0,142041.4583003521,0.6628000140190125,1.439961552619934,10000,154344.91010689735,0.8866406083106995,0.4178559482097626,0.7806599736213684,0.8491368293762207,50000 -12299.447113752363,23.19686794281006,142461.56035637856,318572,0,142461.56035637856,0.6628000140190125,1.439961552619934,10000,154801.62366724014,0.8883398175239563,0.4167503118515014,0.7806599736213684,0.8491368293762207,50000 -12340.786552906036,23.28324294090271,142881.77575540543,319510,0,142881.77575540543,0.6628000140190125,1.439961552619934,10000,155263.3157105446,0.88916015625,0.4147959053516388,0.7806599736213684,0.8491368293762207,50000 -12379.863595485687,23.374863862991333,143302.01441574097,320448,0,143302.01441574097,0.6628000140190125,1.439961552619934,10000,155722.77356290817,0.8883007764816284,0.4113078117370605,0.7806599736213684,0.8491368293762207,50000 -12423.899383544922,23.473129272460938,143722.09818434715,321387,0,143722.09818434715,0.6628000140190125,1.439961552619934,10000,156187.04256916046,0.8866991996765137,0.4174054563045501,0.7806599736213684,0.8491368293762207,50000 -12464.060123443604,23.5680513381958,144142.33653116226,322323,0,144142.33653116226,0.6628000140190125,1.439961552619934,10000,156647.58727145195,0.8899218440055847,0.4096273779869079,0.7806599736213684,0.8491368293762207,50000 -12503.45373558998,23.660505771636963,144562.4763252735,323258,0,144562.4763252735,0.6628000140190125,1.439961552619934,10000,157107.26393318176,0.8880468606948853,0.4144402146339416,0.7806599736213684,0.8491368293762207,50000 -12541.001872062683,23.807493209838867,144982.35832333565,324194,0,144982.35832333565,0.6628000140190125,1.439961552619934,10000,157564.89127373695,0.8901171684265137,0.4119070470333099,0.7806599736213684,0.8491368293762207,50000 -12582.273058652878,23.902076482772827,145402.36496806145,325133,0,145402.36496806145,0.6628000140190125,1.439961552619934,10000,158026.31503486633,0.8872851133346558,0.4181750118732452,0.7806599736213684,0.8491368293762207,50000 -12623.30327129364,23.99218988418579,145822.30849671364,326071,0,145822.30849671364,0.6628000140190125,1.439961552619934,10000,158487.42919802666,0.8898437023162842,0.4119757115840912,0.7806599736213684,0.8491368293762207,50000 -12659.065303087234,24.135462999343872,146242.26469278336,327009,0,146242.26469278336,0.6628000140190125,1.439961552619934,10000,158943.34083485603,0.8896679282188416,0.4137422144412994,0.7806599736213684,0.8491368293762207,50000 -12697.31627702713,24.234971523284912,146662.29273629189,327945,0,146662.29273629189,0.6628000140190125,1.439961552619934,10000,159401.76953983307,0.8895702958106995,0.4116011261940002,0.7806599736213684,0.8491368293762207,50000 -12738.77527141571,24.327179431915283,147082.52355718613,328883,0,147082.52355718613,0.6628000140190125,1.439961552619934,10000,159863.60181355476,0.8875585794448853,0.4162088930606842,0.7806599736213684,0.8491368293762207,50000 -12778.015112400057,24.46321201324463,147502.60175275803,329818,0,147502.60175275803,0.6628000140190125,1.439961552619934,10000,160323.10658454895,0.8869140148162842,0.4170363843441009,0.7806599736213684,0.8491368293762207,50000 -12819.419175863266,24.56829261779785,147922.8135919571,330755,0,147922.8135919571,0.6628000140190125,1.439961552619934,10000,160784.8777372837,0.8896679282188416,0.4135584831237793,0.7806599736213684,0.8491368293762207,50000 -12867.414836645126,24.66355276107788,148343.09826993942,331694,0,148343.09826993942,0.6628000140190125,1.439961552619934,10000,161253.30429697037,0.8881250023841858,0.4180620610713959,0.7806599736213684,0.8491368293762207,50000 -12907.345923662186,24.7419536113739,148763.2022268772,332637,0,148763.2022268772,0.6628000140190125,1.439961552619934,10000,161713.47340416908,0.8884961009025574,0.4143311083316803,0.7806599736213684,0.8491368293762207,50000 -12950.252911806108,24.83434271812439,149183.26648783684,333573,0,149183.26648783684,0.6628000140190125,1.439961552619934,10000,162176.588023901,0.8854101300239563,0.4235457181930542,0.7806599736213684,0.8491368293762207,50000 -12992.80593752861,24.92528367042541,149603.49204540253,334508,0,149603.49204540253,0.6628000140190125,1.439961552619934,10000,162639.50809574127,0.8883984088897705,0.4125818610191345,0.7806599736213684,0.8491368293762207,50000 -13034.503975629808,24.99918556213379,150023.56773352623,335449,0,150023.56773352623,0.6628000140190125,1.439961552619934,10000,163101.40661740303,0.8907030820846558,0.4066536724567413,0.7806599736213684,0.8491368293762207,50000 -13079.27995443344,25.09358143806457,150443.80747699738,336388,0,150443.80747699738,0.6628000140190125,1.439961552619934,10000,163566.56737589836,0.8867382407188416,0.4205570816993713,0.7806599736213684,0.8491368293762207,50000 -13119.275871992111,25.190476655960083,150863.85061359406,337325,0,150863.85061359406,0.6628000140190125,1.439961552619934,10000,164026.75428032875,0.8875390291213989,0.4201224744319916,0.7806599736213684,0.8491368293762207,50000 -13171.61202454567,25.28453826904297,151283.9351758957,338258,0,151283.9351758957,0.6628000140190125,1.439961552619934,10000,164499.32343387604,0.8876367211341858,0.4145201444625854,0.7806599736213684,0.8491368293762207,50000 -13206.307997703552,25.36584448814392,151703.95897865295,339196,0,151703.95897865295,0.6628000140190125,1.439961552619934,10000,164954.1756284237,0.8885351419448853,0.4120396077632904,0.7806599736213684,0.8491368293762207,50000 -13246.64460015297,25.46287226676941,152124.17705917358,340129,0,152124.17705917358,0.6628000140190125,1.439961552619934,10000,165414.87696027756,0.8873828053474426,0.4202143251895904,0.7806599736213684,0.8491368293762207,50000 -13289.47941160202,25.565407514572144,152544.35502910614,341061,0,152544.35502910614,0.6628000140190125,1.439961552619934,10000,165878.0419712067,0.8905078172683716,0.4098507463932037,0.7806599736213684,0.8491368293762207,50000 -13322.339082717896,25.64190411567688,152964.3291182518,342003,0,152964.3291182518,0.6628000140190125,1.439961552619934,10000,166331.00338292122,0.8872265219688416,0.4186387658119201,0.7806599736213684,0.8491368293762207,50000 -13362.112790107729,25.7352511882782,153384.33371186256,342939,0,153384.33371186256,0.6628000140190125,1.439961552619934,10000,166790.92618012428,0.8864062428474426,0.4191470742225647,0.7806599736213684,0.8491368293762207,50000 -13412.558959245682,25.82959008216858,153804.50244569778,343876,0,153804.50244569778,0.6628000140190125,1.439961552619934,10000,167261.6859352589,0.8897656202316284,0.4129240214824676,0.7806599736213684,0.8491368293762207,50000 -13450.191165685654,25.908400774002075,154224.41515517235,344816,0,154224.41515517235,0.6628000140190125,1.439961552619934,10000,167719.3604798317,0.8880664110183716,0.4134820103645324,0.7806599736213684,0.8491368293762207,50000 -13497.83123230934,26.00806427001953,154644.2866268158,345753,0,154644.2866268158,0.6628000140190125,1.439961552619934,10000,168187.02220201492,0.8886132836341858,0.4115453362464905,0.7806599736213684,0.8491368293762207,50000 -13535.284557580948,26.102439403533936,155064.39183330536,346689,0,155064.39183330536,0.6628000140190125,1.439961552619934,10000,168644.72593283653,0.88720703125,0.4173647165298462,0.7806599736213684,0.8491368293762207,50000 -13579.531408786774,26.227181434631348,155484.39171671867,347627,0,155484.39171671867,0.6628000140190125,1.439961552619934,10000,169109.14815998077,0.8901757597923279,0.4099457561969757,0.7806599736213684,0.8491368293762207,50000 -13619.071528434752,26.375840425491333,155904.38059473038,348561,0,155904.38059473038,0.6628000140190125,1.439961552619934,10000,169568.8761703968,0.8900585770606995,0.4157879054546356,0.7806599736213684,0.8491368293762207,50000 -13669.715999364853,26.47144365310669,156324.56995940208,349496,0,156324.56995940208,0.6628000140190125,1.439961552619934,10000,170039.8563761711,0.8889062404632568,0.4083418846130371,0.7806599736213684,0.8491368293762207,50000 -13707.466798067093,26.550642251968384,156744.44487810135,350436,0,156744.44487810135,0.6628000140190125,1.439961552619934,10000,170497.61222934723,0.8913476467132568,0.4106284379959106,0.7806599736213684,0.8491368293762207,50000 -13749.024479150772,26.64775824546814,157164.40446853638,351364,0,157164.40446853638,0.6628000140190125,1.439961552619934,10000,170959.27654767036,0.8867577910423279,0.41962930560112,0.7806599736213684,0.8491368293762207,50000 -13791.469810724258,26.74502730369568,157584.49403834343,352296,0,157584.49403834343,0.6628000140190125,1.439961552619934,10000,171421.9588572979,0.8877148032188416,0.4151216149330139,0.7806599736213684,0.8491368293762207,50000 -13829.42780804634,26.84268879890442,158004.5953962803,353231,0,158004.5953962803,0.6628000140190125,1.439961552619934,10000,171880.16689515114,0.8877539038658142,0.4173828661441803,0.7806599736213684,0.8491368293762207,50000 -13871.05332994461,26.9394805431366,158424.50338053703,354164,0,158424.50338053703,0.6628000140190125,1.439961552619934,10000,172341.84744262695,0.8898046612739563,0.4126203954219818,0.7806599736213684,0.8491368293762207,50000 -13910.14951992035,27.044692754745483,158844.56557393074,355103,0,158844.56557393074,0.6628000140190125,1.439961552619934,10000,172801.1633708477,0.8887304663658142,0.4155606627464294,0.7806599736213684,0.8491368293762207,50000 -13953.177673339844,27.149667501449585,159264.79722499847,356040,0,159264.79722499847,0.6628000140190125,1.439961552619934,10000,173264.57972598076,0.8875976204872131,0.4178290069103241,0.7806599736213684,0.8491368293762207,50000 -13990.813126564026,27.24921751022339,159684.81939792633,356979,0,159684.81939792633,0.6628000140190125,1.439961552619934,10000,173722.38829994202,0.8841601610183716,0.422722190618515,0.7806599736213684,0.8491368293762207,50000 -14033.412657022476,27.349547863006592,160104.9553952217,357916,0,160104.9553952217,0.6628000140190125,1.439961552619934,10000,174185.2740430832,0.8896288871765137,0.4120044708251953,0.7806599736213684,0.8491368293762207,50000 -14069.0910923481,27.44993758201599,160524.85397338867,358856,0,160524.85397338867,0.6628000140190125,1.439961552619934,10000,174641.0027630329,0.8890820145606995,0.410955011844635,0.7806599736213684,0.8491368293762207,50000 -14108.684185504912,27.54713273048401,160944.96697402,359790,0,160944.96697402,0.6628000140190125,1.439961552619934,10000,175100.8567893505,0.8885937333106995,0.4169111549854278,0.7806599736213684,0.8491368293762207,50000 -14151.323654413223,27.643819570541385,161364.8423793316,360725,0,161364.8423793316,0.6628000140190125,1.439961552619934,10000,175563.51901459694,0.8864648342132568,0.421193391084671,0.7806599736213684,0.8491368293762207,50000 -14188.880197048187,27.73991370201111,161785.01636505127,361665,0,161785.01636505127,0.6628000140190125,1.439961552619934,10000,176021.39610123634,0.8865820169448853,0.4151625633239746,0.7806599736213684,0.8491368293762207,50000 -14234.16361951828,27.84222197532653,162205.02902841568,362603,0,162205.02902841568,0.6628000140190125,1.439961552619934,10000,176486.84476661682,0.8888476490974426,0.4111815989017486,0.7806599736213684,0.8491368293762207,50000 -14275.383254051208,27.941187858581543,162625.00325775146,363541,0,162625.00325775146,0.6628000140190125,1.439961552619934,10000,176948.18750071526,0.8886523246765137,0.4192703664302826,0.7806599736213684,0.8491368293762207,50000 -14311.135527849196,28.044018983840942,163045.1636610031,364479,0,163045.1636610031,0.6628000140190125,1.439961552619934,10000,177404.25304937363,0.8897656202316284,0.4115915596485138,0.7806599736213684,0.8491368293762207,50000 -14356.90769314766,28.14191770553589,163465.0584104061,365413,0,163465.0584104061,0.6628000140190125,1.439961552619934,10000,177870.06887936592,0.8880468606948853,0.4186495542526245,0.7806599736213684,0.8491368293762207,50000 -14391.303559541702,28.22307014465332,163885.0548875332,366350,0,163885.0548875332,0.6628000140190125,1.439961552619934,10000,178324.59226608276,0.8872460722923279,0.4188577532768249,0.7806599736213684,0.8491368293762207,50000 -14434.169000864027,28.325141668319706,164305.09626293182,367286,0,164305.09626293182,0.6628000140190125,1.439961552619934,10000,178787.65193104744,0.8896288871765137,0.4089963436126709,0.7806599736213684,0.8491368293762207,50000 -14474.787692546844,28.424390077590942,164725.26630687714,368223,0,164725.26630687714,0.6628000140190125,1.439961552619934,10000,179248.5905110836,0.8872460722923279,0.4134839475154876,0.7806599736213684,0.8491368293762207,50000 -14516.611475229263,28.584221839904785,165145.18168735504,369160,0,165145.18168735504,0.6628000140190125,1.439961552619934,10000,179710.54042887688,0.8898632526397705,0.4105189740657806,0.7806599736213684,0.8491368293762207,50000 -14557.888511419296,28.685779333114624,165565.3174443245,370098,0,165565.3174443245,0.6628000140190125,1.439961552619934,10000,180172.10485219955,0.8863281011581421,0.4179736673831939,0.7806599736213684,0.8491368293762207,50000 -14596.663992404938,28.78573179244995,165985.38984370232,371035,0,165985.38984370232,0.6628000140190125,1.439961552619934,10000,180631.1038618088,0.889453113079071,0.4162811636924743,0.7806599736213684,0.8491368293762207,50000 -14637.276457309725,28.905762434005737,166405.45948576927,371973,0,166405.45948576927,0.6628000140190125,1.439961552619934,10000,181091.9565012455,0.8895117044448853,0.4145427942276001,0.7806599736213684,0.8491368293762207,50000 -14675.369314670565,29.011127471923828,166825.61236143112,372908,0,166825.61236143112,0.6628000140190125,1.439961552619934,10000,181550.35774874687,0.889941394329071,0.4062135517597198,0.7806599736213684,0.8491368293762207,50000 -14714.035625219343,29.120630979537964,167245.81822776794,373842,0,167245.81822776794,0.6628000140190125,1.439961552619934,10000,182009.38965320587,0.888476550579071,0.4190575182437897,0.7806599736213684,0.8491368293762207,50000 -14750.559784173964,29.287435293197632,167665.64027285576,374778,0,167665.64027285576,0.6628000140190125,1.439961552619934,10000,182465.953375578,0.88880854845047,0.4142800569534302,0.7806599736213684,0.8491368293762207,50000 -14792.984617710114,29.391901969909668,168085.7812242508,375716,0,168085.7812242508,0.6628000140190125,1.439961552619934,10000,182928.67467594147,0.88720703125,0.4153881371021271,0.7806599736213684,0.8491368293762207,50000 -14830.284606933594,29.49255681037903,168506.1006731987,376654,0,168506.1006731987,0.6628000140190125,1.439961552619934,10000,183386.4454011917,0.8880664110183716,0.4154757261276245,0.7806599736213684,0.8491368293762207,50000 -14867.735316514969,29.594119548797607,168926.2371172905,377591,0,168926.2371172905,0.6628000140190125,1.439961552619934,10000,183844.1845271588,0.8890624642372131,0.4145684838294983,0.7806599736213684,0.8491368293762207,50000 -14909.759913921356,29.696993350982662,169346.13319802284,378529,0,169346.13319802284,0.6628000140190125,1.439961552619934,10000,184306.2586414814,0.8909765481948853,0.4109188318252563,0.7806599736213684,0.8491368293762207,50000 -14951.494970560074,29.796590089797974,169766.32176852226,379467,0,169766.32176852226,0.6628000140190125,1.439961552619934,10000,184768.3322675228,0.8854491710662842,0.4222442507743835,0.7806599736213684,0.8491368293762207,50000 -14988.764938354492,29.91712141036988,170186.26800489426,380409,0,170186.26800489426,0.6628000140190125,1.439961552619934,10000,185225.72126054764,0.8866991996765137,0.4186405539512634,0.7806599736213684,0.8491368293762207,50000 -15026.391194105148,30.01814985275269,170606.38489556313,381348,0,170606.38489556313,0.6628000140190125,1.439961552619934,10000,185683.616868496,0.8898242115974426,0.4092896282672882,0.7806599736213684,0.8491368293762207,50000 -15068.558099269869,30.119590759277344,171026.64174222946,382287,0,171026.64174222946,0.6628000140190125,1.439961552619934,10000,186146.192868948,0.8888476490974426,0.4133008122444153,0.7806599736213684,0.8491368293762207,50000 -15108.549597740172,30.223318576812744,171446.73694324493,383223,0,171446.73694324493,0.6628000140190125,1.439961552619934,10000,186606.4345138073,0.8874022960662842,0.4168292582035064,0.7806599736213684,0.8491368293762207,50000 -15150.392226219175,30.32602262496948,171866.6207382679,384161,0,171866.6207382679,0.6628000140190125,1.439961552619934,10000,187068.31393790245,0.8870702981948853,0.4222002029418945,0.7806599736213684,0.8491368293762207,50000 -15185.938012599943,30.430259704589844,172286.67732286453,385099,0,172286.67732286453,0.6628000140190125,1.439961552619934,10000,187524.0705358982,0.8860937356948853,0.4177397787570953,0.7806599736213684,0.8491368293762207,50000 -15226.47905087471,30.53153157234192,172706.67554998398,386036,0,172706.67554998398,0.6628000140190125,1.439961552619934,10000,187984.7664122581,0.89013671875,0.4076444208621979,0.7806599736213684,0.8491368293762207,50000 -15266.359908103945,30.63612031936645,173126.61281085014,386974,0,173126.61281085014,0.6628000140190125,1.439961552619934,10000,188444.7423121929,0.8879687190055847,0.4217220842838287,0.7806599736213684,0.8491368293762207,50000 -15307.598673582075,30.744555711746216,173546.7705936432,387911,0,173546.7705936432,0.6628000140190125,1.439961552619934,10000,188906.29854869845,0.888671875,0.4133156836032867,0.7806599736213684,0.8491368293762207,50000 -15344.09970164299,30.84595322608948,173967.05476927757,388848,0,173967.05476927757,0.6628000140190125,1.439961552619934,10000,189363.2365736961,0.8880078196525574,0.4169968068599701,0.7806599736213684,0.8491368293762207,50000 -15385.21784043312,30.950018405914307,174387.26860404015,389786,0,174387.26860404015,0.6628000140190125,1.439961552619934,10000,189824.72399759293,0.8884570002555847,0.4161718785762787,0.7806599736213684,0.8491368293762207,50000 -15428.580172777176,31.055320739746094,174807.295296669,390726,0,174807.295296669,0.6628000140190125,1.439961552619934,10000,190288.2693283558,0.8868359327316284,0.4168016612529754,0.7806599736213684,0.8491368293762207,50000 -15463.608940601349,31.21803855895996,175227.144854784,391665,0,175227.144854784,0.6628000140190125,1.439961552619934,10000,190743.3609380722,0.8889062404632568,0.4064289629459381,0.7806599736213684,0.8491368293762207,50000 -15504.066883087158,31.32296562194824,175647.2268936634,392599,0,175647.2268936634,0.6628000140190125,1.439961552619934,10000,191204.0561449528,0.8891796469688416,0.4163433015346527,0.7806599736213684,0.8491368293762207,50000 -15545.477265357971,31.42723274230957,176067.19288492203,393536,0,176067.19288492203,0.6628000140190125,1.439961552619934,10000,191665.5880215168,0.8873632550239563,0.4174995720386505,0.7806599736213684,0.8491368293762207,50000 -15585.209602117538,31.53172993659973,176487.19069099426,394473,0,176487.19069099426,0.6628000140190125,1.439961552619934,10000,192125.47338366508,0.8908984065055847,0.4083752334117889,0.7806599736213684,0.8491368293762207,50000 -15630.927192687988,31.637397050857544,176907.33445096016,395412,0,176907.33445096016,0.6628000140190125,1.439961552619934,10000,192591.4912652969,0.8875976204872131,0.4181486368179321,0.7806599736213684,0.8491368293762207,50000 -15677.581050395966,31.74113130569458,177327.337089777,396352,0,177327.337089777,0.6628000140190125,1.439961552619934,10000,193058.3018424511,0.8899609446525574,0.4101354479789734,0.7806599736213684,0.8491368293762207,50000 -15713.32974267006,31.83563208580017,177747.5464015007,397294,0,177747.5464015007,0.6628000140190125,1.439961552619934,10000,193514.40622234344,0.8896484375,0.4111970663070678,0.7806599736213684,0.8491368293762207,50000 -15750.907168149948,31.93990921974182,178167.48953986168,398230,0,178167.48953986168,0.6628000140190125,1.439961552619934,10000,193972.0816402436,0.8900781273841858,0.4136282205581665,0.7806599736213684,0.8491368293762207,50000 -15796.030497074127,32.04532527923584,178587.6754131317,399168,0,178587.6754131317,0.6628000140190125,1.439961552619934,10000,194437.5472960472,0.8863476514816284,0.4190600514411926,0.7806599736213684,0.8491368293762207,50000 -15829.235847234726,32.13245725631714,179007.58912825584,400108,0,179007.58912825584,0.6628000140190125,1.439961552619934,10000,194890.8040919304,0.8864257335662842,0.4217060506343841,0.7806599736213684,0.8491368293762207,50000 -15875.423902750015,32.232825756073,179427.76574611664,401046,0,179427.76574611664,0.6628000140190125,1.439961552619934,10000,195357.3195183277,0.8907226324081421,0.4068452715873718,0.7806599736213684,0.8491368293762207,50000 -15915.782636880876,32.33623695373535,179847.76591467857,401986,0,179847.76591467857,0.6628000140190125,1.439961552619934,10000,195817.83309102056,0.8880859017372131,0.4193184077739715,0.7806599736213684,0.8491368293762207,50000 -15954.291462898254,32.424152135849,180267.8737359047,402925,0,180267.8737359047,0.6628000140190125,1.439961552619934,10000,196276.5892522335,0.8873828053474426,0.4192410707473755,0.7806599736213684,0.8491368293762207,50000 -15989.69368314743,32.5312180519104,180687.81717062,403860,0,180687.81717062,0.6628000140190125,1.439961552619934,10000,196732.092625618,0.8870312571525574,0.4160933196544647,0.7806599736213684,0.8491368293762207,50000 -16034.969373941422,32.63423204421997,181108.0825717449,404797,0,181108.0825717449,0.6628000140190125,1.439961552619934,10000,197197.78607559204,0.8872265219688416,0.4159035086631775,0.7806599736213684,0.8491368293762207,50000 -16072.9281270504,32.72395944595337,181527.94424247745,405738,0,181527.94424247745,0.6628000140190125,1.439961552619934,10000,197655.74682331085,0.8909765481948853,0.4066831469535827,0.7806599736213684,0.8491368293762207,50000 -16109.006716966627,32.83241391181946,181948.04838109016,406678,0,181948.04838109016,0.6628000140190125,1.439961552619934,10000,198112.09012413025,0.8872851133346558,0.4161201119422912,0.7806599736213684,0.8491368293762207,50000 -16158.827764034271,32.94143462181091,182368.29708647728,407614,0,182368.29708647728,0.6628000140190125,1.439961552619934,10000,198582.32165956497,0.8865038752555847,0.4203358888626098,0.7806599736213684,0.8491368293762207,50000 -16198.477226018906,33.029731035232544,182788.3319592476,408554,0,182788.3319592476,0.6628000140190125,1.439961552619934,10000,199042.1483139992,0.8880078196525574,0.4149841964244842,0.7806599736213684,0.8491368293762207,50000 -16233.813900232317,33.13789200782776,183208.33271503448,409490,0,183208.33271503448,0.6628000140190125,1.439961552619934,10000,199497.6459169388,0.8875390291213989,0.4196300804615021,0.7806599736213684,0.8491368293762207,50000 -16280.1210501194,33.24795961380005,183628.6072833538,410422,0,183628.6072833538,0.6628000140190125,1.439961552619934,10000,199964.3894040585,0.8884961009025574,0.4138842225074768,0.7806599736213684,0.8491368293762207,50000 -16312.81937122345,33.35289120674133,184048.7924396992,411364,0,184048.7924396992,0.6628000140190125,1.439961552619934,10000,200417.43037319183,0.8887109160423279,0.4179222285747528,0.7806599736213684,0.8491368293762207,50000 -16354.667058467863,33.46069145202637,184468.7359404564,412301,0,184468.7359404564,0.6628000140190125,1.439961552619934,10000,200879.38162469864,0.8895702958106995,0.4115485548973083,0.7806599736213684,0.8491368293762207,50000 -16401.355442762375,33.5668420791626,184888.77264356613,413233,0,184888.77264356613,0.6628000140190125,1.439961552619934,10000,201346.26455664635,0.8875781297683716,0.4195011556148529,0.7806599736213684,0.8491368293762207,50000 -16434.469732284546,33.65598273277283,185308.6721343994,414171,0,185308.6721343994,0.6628000140190125,1.439961552619934,10000,201799.42095828056,0.8879101276397705,0.4134228527545929,0.7806599736213684,0.8491368293762207,50000 -16471.08065843582,33.76547908782959,185728.9335186481,415111,0,185728.9335186481,0.6628000140190125,1.439961552619934,10000,202256.4557170868,0.8889257907867432,0.4115974307060241,0.7806599736213684,0.8491368293762207,50000 -16522.232833862305,33.87578749656677,186149.0531988144,416049,0,186149.0531988144,0.6628000140190125,1.439961552619934,10000,202727.8894138336,0.888671875,0.4118475914001465,0.7806599736213684,0.8491368293762207,50000 -16556.713079452515,33.965121269226074,186568.91537976265,416992,0,186568.91537976265,0.6628000140190125,1.439961552619934,10000,203182.3723757267,0.8870507478713989,0.4178069829940796,0.7806599736213684,0.8491368293762207,50000 -16599.438136577606,34.98693084716797,186987.864846468,417927,0,186987.864846468,0.6628000140190125,1.439961552619934,10000,203645.1201398373,0.8895117044448853,0.4142775535583496,0.7806599736213684,0.8491368293762207,50000 -16638.66318511963,35.0987401008606,187407.9261534214,418865,0,187407.9261534214,0.6628000140190125,1.439961552619934,10000,204104.56997919083,0.8903124928474426,0.4141020476818084,0.7806599736213684,0.8491368293762207,50000 -16683.20975136757,35.21203064918518,187828.1980266571,419804,0,187828.1980266571,0.6628000140190125,1.439961552619934,10000,204569.553198576,0.8893945217132568,0.4078099727630615,0.7806599736213684,0.8491368293762207,50000 -16722.040591955185,35.3136100769043,188248.1066851616,420742,0,188248.1066851616,0.6628000140190125,1.439961552619934,10000,205028.44544243813,0.8888280987739563,0.4157975018024444,0.7806599736213684,0.8491368293762207,50000 -16759.408123254776,35.428178787231445,188668.3634223938,421679,0,188668.3634223938,0.6628000140190125,1.439961552619934,10000,205486.2354171276,0.8882812261581421,0.4160736501216888,0.7806599736213684,0.8491368293762207,50000 -16802.113570451736,35.536232471466064,189088.33290719983,422616,0,189088.33290719983,0.6628000140190125,1.439961552619934,10000,205949.0699858665,0.8886913657188416,0.4118205606937408,0.7806599736213684,0.8491368293762207,50000 -16839.58140516281,35.73130702972412,189508.17068457603,423554,0,189508.17068457603,0.6628000140190125,1.439961552619934,10000,206406.62183642387,0.8875585794448853,0.4185675084590912,0.7806599736213684,0.8491368293762207,50000 -16879.394988775253,35.84501767158508,189928.0055382252,424490,0,189928.0055382252,0.6628000140190125,1.439961552619934,10000,206866.4358620644,0.8890429735183716,0.4132682383060455,0.7806599736213684,0.8491368293762207,50000 -16920.03908252716,35.94810652732849,190347.9941241741,425429,0,190347.9941241741,0.6628000140190125,1.439961552619934,10000,207327.2237203121,0.88734370470047,0.4196708500385284,0.7806599736213684,0.8491368293762207,50000 -16961.210329294205,36.057634353637695,190768.1888158321,426367,0,190768.1888158321,0.6628000140190125,1.439961552619934,10000,207788.7522785664,0.8880664110183716,0.4123572409152984,0.7806599736213684,0.8491368293762207,50000 -17002.576770305634,36.17466354370117,191188.41430687904,427307,0,191188.41430687904,0.6628000140190125,1.439961552619934,10000,208250.51259469983,0.8867382407188416,0.4196271896362304,0.7806599736213684,0.8491368293762207,50000 -17042.562101125717,36.28608512878418,191608.5195260048,428245,0,191608.5195260048,0.6628000140190125,1.439961552619934,10000,208710.7665054798,0.8895898461341858,0.4129971861839294,0.7806599736213684,0.8491368293762207,50000 -17077.692586660385,36.40941309928894,192028.75000452995,429185,0,192028.75000452995,0.6628000140190125,1.439961552619934,10000,209166.30269360545,0.8884375095367432,0.4158594012260437,0.7806599736213684,0.8491368293762207,50000 -17115.80970931053,36.52241349220276,192449.02923035625,430124,0,192449.02923035625,0.6628000140190125,1.439961552619934,10000,209624.86515569687,0.8883788585662842,0.4134212136268616,0.7806599736213684,0.8491368293762207,50000 -17159.325850486755,36.63349270820618,192869.14423632625,431061,0,192869.14423632625,0.6628000140190125,1.439961552619934,10000,210088.6588923931,0.8847460746765137,0.4247710406780243,0.7806599736213684,0.8491368293762207,50000 -17196.563505887985,36.75905227661133,193288.98817515373,431998,0,193288.98817515373,0.6628000140190125,1.439961552619934,10000,210545.91840672493,0.8877539038658142,0.4151757359504699,0.7806599736213684,0.8491368293762207,50000 -17239.514502763748,36.87295937538147,193709.06731200207,432936,0,193709.06731200207,0.6628000140190125,1.439961552619934,10000,211009.1180469989,0.8878124952316284,0.4115523099899292,0.7806599736213684,0.8491368293762207,50000 -17278.63621211052,36.99358677864075,194129.1805028916,433872,0,194129.1805028916,0.6628000140190125,1.439961552619934,10000,211468.52519202232,0.8901171684265137,0.4164775013923645,0.7806599736213684,0.8491368293762207,50000 -17319.42714571953,37.13652777671814,194549.187220335,434810,0,194549.187220335,0.6628000140190125,1.439961552619934,10000,211929.5177807808,0.887988269329071,0.4155340194702148,0.7806599736213684,0.8491368293762207,50000 -17361.18970298767,37.25248050689697,194969.2794907093,435748,0,194969.2794907093,0.6628000140190125,1.439961552619934,10000,212391.54075169563,0.8893749713897705,0.4128265976905823,0.7806599736213684,0.8491368293762207,50000 -17403.448195934296,37.36638760566712,195389.2151684761,436685,0,195389.2151684761,0.6628000140190125,1.439961552619934,10000,212853.9011592865,0.8864257335662842,0.4213772416114807,0.7806599736213684,0.8491368293762207,50000 -17441.593303203583,37.49781918525696,195809.24481654167,437624,0,195809.24481654167,0.6628000140190125,1.439961552619934,10000,213312.25943803787,0.8902343511581421,0.4092316329479217,0.7806599736213684,0.8491368293762207,50000 -17483.80901670456,37.61236357688904,196229.3932518959,438562,0,196229.3932518959,0.6628000140190125,1.439961552619934,10000,213774.7895960808,0.8884179592132568,0.412599503993988,0.7806599736213684,0.8491368293762207,50000 -17529.58095383644,37.7523980140686,196649.24244713783,439498,0,196649.24244713783,0.6628000140190125,1.439961552619934,10000,214240.602208376,0.8879101276397705,0.4182517826557159,0.7806599736213684,0.8491368293762207,50000 -17570.387481451035,37.847246170043945,197069.23230481148,440435,0,197069.23230481148,0.6628000140190125,1.439961552619934,10000,214701.5451090336,0.8888476490974426,0.4082626700401306,0.7806599736213684,0.8491368293762207,50000 -17611.045613527298,37.9582302570343,197489.0875673294,441371,0,197489.0875673294,0.6628000140190125,1.439961552619934,10000,215162.2216293812,0.8873828053474426,0.4172675311565399,0.7806599736213684,0.8491368293762207,50000 -17651.756640195847,38.07303929328919,197909.2683000565,442309,0,197909.2683000565,0.6628000140190125,1.439961552619934,10000,215623.2797615528,0.8910546898841858,0.4143266081809997,0.7806599736213684,0.8491368293762207,50000 -17696.90135359764,38.18743085861206,198329.16502404213,443245,0,198329.16502404213,0.6628000140190125,1.439961552619934,10000,216088.4873046875,0.8901757597923279,0.4095847606658935,0.7806599736213684,0.8491368293762207,50000 -17735.406715393066,38.32096481323242,198749.2635633945,444182,0,198749.2635633945,0.6628000140190125,1.439961552619934,10000,216547.2765877247,0.8889062404632568,0.412271648645401,0.7806599736213684,0.8491368293762207,50000 -17776.054845571518,38.437227964401245,199169.37876105309,445119,0,199169.37876105309,0.6628000140190125,1.439961552619934,10000,217008.20818781853,0.8889062404632568,0.4158249497413635,0.7806599736213684,0.8491368293762207,50000 -17829.23697257042,38.55621099472046,199589.2116761208,446052,0,199589.2116761208,0.6628000140190125,1.439961552619934,10000,217481.3932626248,0.8854687213897705,0.4212657809257507,0.7806599736213684,0.8491368293762207,50000 -17865.259093284607,38.65262007713318,200009.1923904419,446989,0,200009.1923904419,0.6628000140190125,1.439961552619934,10000,217937.54481863976,0.8869921565055847,0.4161044359207153,0.7806599736213684,0.8491368293762207,50000 -17907.731934070587,38.77121067047119,200429.4089901448,447923,0,200429.4089901448,0.6628000140190125,1.439961552619934,10000,218400.40449666977,0.8915429711341858,0.4091703295707702,0.7806599736213684,0.8491368293762207,50000 -17942.856937885284,38.88737273216248,200849.61452794075,448858,0,200849.61452794075,0.6628000140190125,1.439961552619934,10000,218855.90346169472,0.8885351419448853,0.4149369299411773,0.7806599736213684,0.8491368293762207,50000 -17979.217833518982,38.99995350837708,201269.6993484497,449797,0,201269.6993484497,0.6628000140190125,1.439961552619934,10000,219312.5145866871,0.8884961009025574,0.4168313145637512,0.7806599736213684,0.8491368293762207,50000 -18026.383950948715,39.11453557014465,201689.80603957176,450734,0,201689.80603957176,0.6628000140190125,1.439961552619934,10000,219779.9533605576,0.8858007788658142,0.4214471280574798,0.7806599736213684,0.8491368293762207,50000 -18069.23517775536,39.230353116989136,202109.7711122036,451672,0,202109.7711122036,0.6628000140190125,1.439961552619934,10000,220242.93742847443,0.8877733945846558,0.4172125458717346,0.7806599736213684,0.8491368293762207,50000 -18102.98245286941,39.325342893600464,202529.7824888229,452613,0,202529.7824888229,0.6628000140190125,1.439961552619934,10000,220696.84306454656,0.888476550579071,0.4082569479942322,0.7806599736213684,0.8491368293762207,50000 -18142.578282356262,39.43860602378845,202949.7999596596,453549,0,202949.7999596596,0.6628000140190125,1.439961552619934,10000,221156.62088036537,0.8876367211341858,0.4193960428237915,0.7806599736213684,0.8491368293762207,50000 -18186.859826803207,39.55583477020264,203369.71106481552,454485,0,203369.71106481552,0.6628000140190125,1.439961552619934,10000,221620.98258256912,0.8869335651397705,0.4207581579685211,0.7806599736213684,0.8491368293762207,50000 -18226.327399015427,39.65032410621643,203789.9735286236,455426,0,203789.9735286236,0.6628000140190125,1.439961552619934,10000,222080.85911655423,0.88818359375,0.4161781072616577,0.7806599736213684,0.8491368293762207,50000 -18260.59964799881,39.7682089805603,204210.18108654025,456364,0,204210.18108654025,0.6628000140190125,1.439961552619934,10000,222535.50862407684,0.8884375095367432,0.4117911756038666,0.7806599736213684,0.8491368293762207,50000 -18312.688049077988,39.88325357437134,204630.0314304829,457298,0,204630.0314304829,0.6628000140190125,1.439961552619934,10000,223007.61329698563,0.889453113079071,0.4147363305091858,0.7806599736213684,0.8491368293762207,50000 -18347.80455422401,39.98117995262146,205050.12309336665,458238,0,205050.12309336665,0.6628000140190125,1.439961552619934,10000,223462.97123146057,0.8873828053474426,0.4160374402999878,0.7806599736213684,0.8491368293762207,50000 -18393.14288878441,40.12711071968079,205469.96904063225,459172,0,205469.96904063225,0.6628000140190125,1.439961552619934,10000,223928.3540511132,0.8875781297683716,0.4167408645153045,0.7806599736213684,0.8491368293762207,50000 -18433.56349658966,40.242703437805176,205889.9131486416,460112,0,205889.9131486416,0.6628000140190125,1.439961552619934,10000,224388.8865418434,0.8892577886581421,0.4150897562503814,0.7806599736213684,0.8491368293762207,50000 -18473.569224596024,40.36980032920837,206309.76017975807,461045,0,206309.76017975807,0.6628000140190125,1.439961552619934,10000,224848.9179956913,0.8885155916213989,0.4129526913166046,0.7806599736213684,0.8491368293762207,50000 -18512.961121320724,40.57051396369934,206729.5407309532,461977,0,206729.5407309532,0.6628000140190125,1.439961552619934,10000,225308.3424794674,0.8879492282867432,0.4129699170589447,0.7806599736213684,0.8491368293762207,50000 -18553.678634643555,40.68940997123718,207149.70380544665,462912,0,207149.70380544665,0.6628000140190125,1.439961552619934,10000,225769.39330291748,0.8885155916213989,0.4146435856819153,0.7806599736213684,0.8491368293762207,50000 -18599.97526431084,40.80809164047241,207569.53734230995,463848,0,207569.53734230995,0.6628000140190125,1.439961552619934,10000,226235.6934535504,0.8866796493530273,0.4163015186786651,0.7806599736213684,0.8491368293762207,50000 -18636.92086791992,40.91565990447998,207989.73219513893,464790,0,207989.73219513893,0.6628000140190125,1.439961552619934,10000,226692.9933104515,0.8889062404632568,0.4164812862873077,0.7806599736213684,0.8491368293762207,50000 -18680.181062698364,41.03439331054688,208409.66075110435,465726,0,208409.66075110435,0.6628000140190125,1.439961552619934,10000,227156.35308504105,0.8902929425239563,0.4111060202121734,0.7806599736213684,0.8491368293762207,50000 -18720.83535385132,41.144742250442505,208829.88141417503,466662,0,208829.88141417503,0.6628000140190125,1.439961552619934,10000,227617.39014077187,0.8910546898841858,0.4077297151088714,0.7806599736213684,0.8491368293762207,50000 -18757.53831171989,41.24268078804016,209249.93860125545,467603,0,209249.93860125545,0.6628000140190125,1.439961552619934,10000,228074.3007860184,0.8886913657188416,0.4169077575206756,0.7806599736213684,0.8491368293762207,50000 -18796.923159122467,41.36542820930481,209670.0353667736,468540,0,209670.0353667736,0.6628000140190125,1.439961552619934,10000,228533.95680093765,0.8895702958106995,0.4094632863998413,0.7806599736213684,0.8491368293762207,50000 -18844.599578857426,41.48731422424317,210090.10920333865,469474,0,210090.10920333865,0.6628000140190125,1.439961552619934,10000,229001.8804507256,0.8866015672683716,0.4216593205928802,0.7806599736213684,0.8491368293762207,50000 -18881.850703954697,41.58333134651184,210510.11679530144,470414,0,210510.11679530144,0.6628000140190125,1.439961552619934,10000,229459.28646302223,0.8875976204872131,0.415841668844223,0.7806599736213684,0.8491368293762207,50000 -18926.45775413513,41.70092463493347,210930.04602122307,471347,0,210930.04602122307,0.6628000140190125,1.439961552619934,10000,229923.9920327664,0.8912695050239563,0.4065933227539062,0.7806599736213684,0.8491368293762207,50000 -18965.783930778503,41.82049751281738,211350.18746185303,472275,0,211350.18746185303,0.6628000140190125,1.439961552619934,10000,230383.6294569969,0.8875390291213989,0.4188980460166931,0.7806599736213684,0.8491368293762207,50000 -19005.408094406128,41.94422078132629,211770.68524432185,473211,0,211770.68524432185,0.6628000140190125,1.439961552619934,10000,230843.92622327805,0.8882226347923279,0.415109634399414,0.7806599736213684,0.8491368293762207,50000 -19047.591923952103,42.06261587142944,212190.5941722393,474149,0,212190.5941722393,0.6628000140190125,1.439961552619934,10000,231306.1898348332,0.886523425579071,0.4200004339218139,0.7806599736213684,0.8491368293762207,50000 -19088.447747945786,42.17369341850281,212610.66497921944,475089,0,212610.66497921944,0.6628000140190125,1.439961552619934,10000,231767.27955651283,0.8861913681030273,0.42005056142807,0.7806599736213684,0.8491368293762207,50000 -19130.14138698578,42.2940833568573,213030.53802657127,476023,0,213030.53802657127,0.6628000140190125,1.439961552619934,10000,232229.01929712296,0.8905664086341858,0.4074235558509826,0.7806599736213684,0.8491368293762207,50000 -19168.720163583755,42.48556280136109,213450.59983038905,476960,0,213450.59983038905,0.6628000140190125,1.439961552619934,10000,232687.90293979645,0.8874413967132568,0.4190874397754669,0.7806599736213684,0.8491368293762207,50000 -19211.575483083725,42.624929428100586,213870.68713450432,477900,0,213870.68713450432,0.6628000140190125,1.439961552619934,10000,233151.0371210575,0.8867968320846558,0.4170536398887634,0.7806599736213684,0.8491368293762207,50000 -19249.38017272949,42.74868321418762,214290.8945174217,478839,0,214290.8945174217,0.6628000140190125,1.439961552619934,10000,233609.22505187988,0.8878905773162842,0.4166683554649353,0.7806599736213684,0.8491368293762207,50000 -19288.926438093185,42.94045972824097,214711.0296151638,479778,0,214711.0296151638,0.6628000140190125,1.439961552619934,10000,234069.1584665776,0.8868359327316284,0.4167793989181518,0.7806599736213684,0.8491368293762207,50000 -19328.59648346901,43.05870342254639,215131.2020647525,480716,0,215131.2020647525,0.6628000140190125,1.439961552619934,10000,234529.17214751244,0.8891991972923279,0.4213533401489258,0.7806599736213684,0.8491368293762207,50000 -19374.056756734848,43.17589569091797,215551.36306786537,481652,0,215551.36306786537,0.6628000140190125,1.439961552619934,10000,234994.9626107216,0.8895312547683716,0.4089459180831909,0.7806599736213684,0.8491368293762207,50000 -19410.689818143845,43.365275382995605,215971.16331911087,482594,0,215971.16331911087,0.6628000140190125,1.439961552619934,10000,235451.6388404369,0.8869726657867432,0.420018196105957,0.7806599736213684,0.8491368293762207,50000 -19450.36641383171,43.495421171188354,216391.1415436268,483532,0,216391.1415436268,0.6628000140190125,1.439961552619934,10000,235911.47638177872,0.8881444931030273,0.4166688323020935,0.7806599736213684,0.8491368293762207,50000 -19501.70230698585,43.616320848464966,216811.1978859901,484467,0,216811.1978859901,0.6628000140190125,1.439961552619934,10000,236383.04190945625,0.8903319835662842,0.4062535762786865,0.7806599736213684,0.8491368293762207,50000 -19540.132127523422,43.72937512397766,217231.43805646896,485408,0,217231.43805646896,0.6628000140190125,1.439961552619934,10000,236841.8782045841,0.8874218463897705,0.4174661040306091,0.7806599736213684,0.8491368293762207,50000 -19581.17331981659,43.85249710083008,217651.5862724781,486348,0,217651.5862724781,0.6628000140190125,1.439961552619934,10000,237303.24244117737,0.8887890577316284,0.4112873673439026,0.7806599736213684,0.8491368293762207,50000 -19621.407239437103,43.97446012496948,218071.7227330208,487285,0,218071.7227330208,0.6628000140190125,1.439961552619934,10000,237763.7866547108,0.8875781297683716,0.4143383204936981,0.7806599736213684,0.8491368293762207,50000 -19662.32736802101,44.102455615997314,218491.75671744347,488221,0,218491.75671744347,0.6628000140190125,1.439961552619934,10000,238224.9208495617,0.8894335627555847,0.4136330187320709,0.7806599736213684,0.8491368293762207,50000 -19705.860652446747,44.23263597488403,218911.8640100956,489157,0,218911.8640100956,0.6628000140190125,1.439961552619934,10000,238688.7435998917,0.88916015625,0.4151538014411926,0.7806599736213684,0.8491368293762207,50000 -19749.71036696434,44.36157035827637,219331.83431482315,490095,0,219331.83431482315,0.6628000140190125,1.439961552619934,10000,239152.7444009781,0.8889452815055847,0.4118638038635254,0.7806599736213684,0.8491368293762207,50000 -19789.958799123764,44.49367165565491,219752.01549005508,491034,0,219752.01549005508,0.6628000140190125,1.439961552619934,10000,239613.35805392265,0.8895898461341858,0.4138146638870239,0.7806599736213684,0.8491368293762207,50000 -19829.86556172371,44.63035583496094,220172.0487046241,491973,0,220172.0487046241,0.6628000140190125,1.439961552619934,10000,240073.4868443012,0.8890038728713989,0.4148271977901459,0.7806599736213684,0.8491368293762207,50000 -19870.639471769333,44.75104427337647,220592.05123972893,492910,0,220592.05123972893,0.6628000140190125,1.439961552619934,10000,240534.4372580052,0.8866601586341858,0.4159363508224487,0.7806599736213684,0.8491368293762207,50000 -19914.169314861298,44.8757050037384,221011.9910628796,493848,0,221011.9910628796,0.6628000140190125,1.439961552619934,10000,240998.0865285397,0.887499988079071,0.4184769093990326,0.7806599736213684,0.8491368293762207,50000 -19950.32503771782,44.97604584693909,221432.0510094165,494787,0,221432.0510094165,0.6628000140190125,1.439961552619934,10000,241454.4551520348,0.8898242115974426,0.4121768474578857,0.7806599736213684,0.8491368293762207,50000 -19995.958854675293,45.0961983203888,221852.1177001,495725,0,221852.1177001,0.6628000140190125,1.439961552619934,10000,241920.32781767845,0.89013671875,0.4112144410610199,0.7806599736213684,0.8491368293762207,50000 -20045.42670416832,45.213632106781006,222272.18267846107,496658,0,222272.18267846107,0.6628000140190125,1.439961552619934,10000,242390.0301771164,0.8859961032867432,0.4209813177585602,0.7806599736213684,0.8491368293762207,50000 -20081.30944299698,45.31370210647583,222692.19125318527,497598,0,222692.19125318527,0.6628000140190125,1.439961552619934,10000,242846.0733227729,0.8863281011581421,0.4229179620742798,0.7806599736213684,0.8491368293762207,50000 -20123.3711745739,45.43738150596619,223112.1540651321,498525,0,223112.1540651321,0.6628000140190125,1.439961552619934,10000,243308.27743959427,0.8882030844688416,0.4153804779052734,0.7806599736213684,0.8491368293762207,50000 -20166.64469361305,45.56512475013733,223532.21031570435,499461,0,223532.21031570435,0.6628000140190125,1.439961552619934,10000,243771.78656983376,0.8899218440055847,0.4086339771747589,0.7806599736213684,0.8491368293762207,50000 -20210.7332572937,45.69040513038635,223952.1179277897,500399,0,223952.1179277897,0.6628000140190125,1.439961552619934,10000,244235.95967531204,0.8868749737739563,0.4158554673194885,0.7806599736213684,0.8491368293762207,50000 -20255.60511660576,45.836000204086304,224372.26052856445,501340,0,224372.26052856445,0.6628000140190125,1.439961552619934,10000,244701.171677351,0.8868359327316284,0.4167474210262298,0.7806599736213684,0.8491368293762207,50000 -20294.00301527977,45.96119236946106,224792.3078968525,502278,0,224792.3078968525,0.6628000140190125,1.439961552619934,10000,245159.794072628,0.8883398175239563,0.4172844886779785,0.7806599736213684,0.8491368293762207,50000 -20340.79037499428,46.09040856361389,225212.55551242828,503214,0,225212.55551242828,0.6628000140190125,1.439961552619934,10000,245627.0100402832,0.8876562118530273,0.4177963137626648,0.7806599736213684,0.8491368293762207,50000 -20379.2294960022,46.2122004032135,225632.7815082073,504153,0,225632.7815082073,0.6628000140190125,1.439961552619934,10000,246085.8487613201,0.8895312547683716,0.4146497249603271,0.7806599736213684,0.8491368293762207,50000 -20425.49946808815,46.34702253341675,226052.95351052284,505090,0,226052.95351052284,0.6628000140190125,1.439961552619934,10000,246552.4772992134,0.8879687190055847,0.4165275990962982,0.7806599736213684,0.8491368293762207,50000 -20468.608820915226,46.46912431716919,226473.03325295448,506027,0,226473.03325295448,0.6628000140190125,1.439961552619934,10000,247015.84028172493,0.8889062404632568,0.4119411408901214,0.7806599736213684,0.8491368293762207,50000 -20512.156184911728,46.59465575218201,226892.98617458344,506965,0,226892.98617458344,0.6628000140190125,1.439961552619934,10000,247479.51724481583,0.8897460699081421,0.4139359295368194,0.7806599736213684,0.8491368293762207,50000 -20552.11368203163,46.699103116989136,227312.9258275032,507908,0,227312.9258275032,0.6628000140190125,1.439961552619934,10000,247939.57105445865,0.8889257907867432,0.411639004945755,0.7806599736213684,0.8491368293762207,50000 -20591.683132648468,46.82708263397217,227732.81487941745,508845,0,227732.81487941745,0.6628000140190125,1.439961552619934,10000,248399.2088296413,0.8868945240974426,0.4153727889060974,0.7806599736213684,0.8491368293762207,50000 -20637.40304684639,46.95222425460816,228152.68364787105,509782,0,228152.68364787105,0.6628000140190125,1.439961552619934,10000,248864.9745190144,0.8875781297683716,0.4177852272987366,0.7806599736213684,0.8491368293762207,50000 -20681.03459215164,47.079352617263794,228572.96571731567,510718,0,228572.96571731567,0.6628000140190125,1.439961552619934,10000,249329.0675106049,0.8883788585662842,0.4152357876300812,0.7806599736213684,0.8491368293762207,50000 -20723.862995624542,47.21655368804932,228993.1485798359,511658,0,228993.1485798359,0.6628000140190125,1.439961552619934,10000,249792.2692861557,0.8885741829872131,0.4134371876716614,0.7806599736213684,0.8491368293762207,50000 -20766.809602737427,47.343963384628296,229413.08163452148,512596,0,229413.08163452148,0.6628000140190125,1.439961552619934,10000,250255.3282535076,0.8901171684265137,0.4124380052089691,0.7806599736213684,0.8491368293762207,50000 -20806.49105668068,47.4794237613678,229833.1634716988,513533,0,229833.1634716988,0.6628000140190125,1.439961552619934,10000,250715.27859401703,0.8909960985183716,0.4092914760112762,0.7806599736213684,0.8491368293762207,50000 -20845.29662656784,47.680049657821655,230252.9332678318,514474,0,230252.9332678318,0.6628000140190125,1.439961552619934,10000,251174.10797166824,0.8879101276397705,0.412520170211792,0.7806599736213684,0.8491368293762207,50000 -20887.919131040573,47.809616804122925,230673.0644853115,515414,0,230673.0644853115,0.6628000140190125,1.439961552619934,10000,251637.04307699203,0.8891406059265137,0.4178435504436493,0.7806599736213684,0.8491368293762207,50000 -20935.35671448708,47.94902086257935,231093.2486963272,516348,0,231093.2486963272,0.6628000140190125,1.439961552619934,10000,252104.85655498505,0.8887499570846558,0.4133471250534057,0.7806599736213684,0.8491368293762207,50000 -20973.56746840477,48.052809715271,231513.2865588665,517285,0,231513.2865588665,0.6628000140190125,1.439961552619934,10000,252563.2615237236,0.8858202695846558,0.4180032908916473,0.7806599736213684,0.8491368293762207,50000 -21019.734656095505,48.206233501434326,231933.59404420853,518218,0,231933.59404420853,0.6628000140190125,1.439961552619934,10000,253029.94160294533,0.8895898461341858,0.4120994210243225,0.7806599736213684,0.8491368293762207,50000 -21065.571618556976,48.32579302787781,232353.68795967102,519154,0,232353.68795967102,0.6628000140190125,1.439961552619934,10000,253496.0437746048,0.8883984088897705,0.4173663556575775,0.7806599736213684,0.8491368293762207,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 0a17b7ea1..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5753 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.3829881,6.907756,,,,,,,,,,,,,, -1,,,0.0008789062267169,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,42.50220847129822,82.83079028129578,42.50220847129822,40.32848334312439,0.0,0.0 -100,0.50050557,6.8731804,,,,,,,,,,,,,, -200,0.7011339,6.756558,,,,,,,,,,,,,, -300,0.7614669,6.734746,,,,,,,,,,,,,, -400,1.1094402,6.7266083,,,,,,,,,,,,,, -500,1.3051015,6.511038,,,,,,,,,,,,,, -600,1.001551,6.361764,,,,,,,,,,,,,, -700,1.1283841,6.259181,,,,,,,,,,,,,, -800,1.2762573,6.1603713,,,,,,,,,,,,,, -900,1.3492512,6.122115,,,,,,,,,,,,,, -905,,,0.0355468727648258,5.822093486785889,0.0365999974310398,5.84222936630249,50000.0,0.02730000205338,5.976207733154297,10000.0,462.6079497337341,524.7230081558228,462.6079497337341,62.036508321762085,0.027024507522583,0.0 -1000,1.6092238,6.0644703,,,,,,,,,,,,,, -1100,1.365803,5.991452,,,,,,,,,,,,,, -1200,1.6429361,5.9959373,,,,,,,,,,,,,, -1300,1.173577,6.6827164,,,,,,,,,,,,,, -1400,0.8933575,6.1120257,,,,,,,,,,,,,, -1500,1.1369004,5.846339,,,,,,,,,,,,,, -1600,1.2557962,5.886666,,,,,,,,,,,,,, -1700,1.2336695,5.697193,,,,,,,,,,,,,, -1800,0.95477515,6.314009,,,,,,,,,,,,,, -1862,,,0.0798046886920929,5.190312385559082,0.0763199999928474,5.234590530395508,50000.0,0.0597000010311603,5.4772138595581055,10000.0,882.7354953289032,966.3888356685638,882.7354953289032,83.49371838569641,0.0549170970916748,0.0 -1900,1.1192596,6.286613,,,,,,,,,,,,,, -2000,1.2709221,5.6976256,,,,,,,,,,,,,, -2100,1.1933186,6.323117,,,,,,,,,,,,,, -2200,1.3534677,5.5366173,,,,,,,,,,,,,, -2300,0.9842322,5.5416684,,,,,,,,,,,,,, -2400,0.9216931,6.069952,,,,,,,,,,,,,, -2500,0.998191,5.3691225,,,,,,,,,,,,,, -2600,0.80935705,6.4812126,,,,,,,,,,,,,, -2700,0.8702112,5.277937,,,,,,,,,,,,,, -2800,1.0952698,5.2382927,,,,,,,,,,,,,, -2820,,,0.1470703035593032,4.55432653427124,0.1360999941825866,4.649452209472656,50000.0,0.106900006532669,4.9579620361328125,10000.0,1302.832392454147,1408.2555124759674,1302.832392454147,105.18109583854675,0.0841908454895019,0.0 -2900,0.783885,5.4661016,,,,,,,,,,,,,, -3000,1.074858,5.182512,,,,,,,,,,,,,, -3100,1.0267328,6.521857,,,,,,,,,,,,,, -3200,0.88254565,4.896588,,,,,,,,,,,,,, -3300,0.8524522,5.016838,,,,,,,,,,,,,, -3400,0.64274216,6.1943126,,,,,,,,,,,,,, -3500,0.8390931,4.9358177,,,,,,,,,,,,,, -3600,0.83637977,5.606818,,,,,,,,,,,,,, -3700,0.90795636,5.1876225,,,,,,,,,,,,,, -3779,,,0.2031054645776748,4.117558002471924,0.1888599991798401,4.196497440338135,50000.0,0.1464000046253204,4.582161903381348,10000.0,1722.8942482471466,1849.997201681137,1722.8942482471466,126.77802324295044,0.113541841506958,0.0 -3800,0.88243586,5.8390093,,,,,,,,,,,,,, -3900,0.9139004,5.1503386,,,,,,,,,,,,,, -4000,0.86456776,4.9017057,,,,,,,,,,,,,, -4100,0.97556055,5.1630898,,,,,,,,,,,,,, -4200,1.0696212,4.6611104,,,,,,,,,,,,,, -4300,0.86170816,6.2067575,,,,,,,,,,,,,, -4400,0.7388528,5.734652,,,,,,,,,,,,,, -4500,0.7750657,4.5995398,,,,,,,,,,,,,, -4600,0.993452,4.6225944,,,,,,,,,,,,,, -4700,0.5687587,6.2148595,,,,,,,,,,,,,, -4723,,,0.2610937356948852,3.6418685913085938,0.2449599951505661,3.743607997894287,50000.0,0.185800015926361,4.2058210372924805,10000.0,2143.081571817398,2292.112045764923,2143.081571817398,148.62538933753967,0.1412358283996582,0.0 -4800,0.9362884,4.3565083,,,,,,,,,,,,,, -4900,0.6798942,5.6679187,,,,,,,,,,,,,, -5000,1.000535,4.3617597,,,,,,,,,,,,,, -5100,0.6325434,5.791784,,,,,,,,,,,,,, -5200,0.7135805,5.7742696,,,,,,,,,,,,,, -5300,0.7562063,5.9695945,,,,,,,,,,,,,, -5400,0.99586236,4.2252746,,,,,,,,,,,,,, -5500,0.8128881,4.176561,,,,,,,,,,,,,, -5600,1.1626357,4.1528807,,,,,,,,,,,,,, -5677,,,0.3202343583106994,3.211933135986328,0.2983799874782562,3.3576674461364746,50000.0,0.2299000173807144,3.89674711227417,10000.0,2563.1050510406494,2733.9481399059296,2563.1050510406494,170.35447335243225,0.1706206798553466,0.0 -5700,0.69812286,5.889019,,,,,,,,,,,,,, -5800,0.94497365,4.1328197,,,,,,,,,,,,,, -5900,0.91545933,4.0936213,,,,,,,,,,,,,, -6000,0.8971524,4.121603,,,,,,,,,,,,,, -6100,1.1184582,4.2852135,,,,,,,,,,,,,, -6200,0.64565504,5.8675604,,,,,,,,,,,,,, -6300,0.9103006,4.3264318,,,,,,,,,,,,,, -6400,0.9596251,4.006879,,,,,,,,,,,,,, -6500,0.6881359,6.0498724,,,,,,,,,,,,,, -6600,0.84061056,4.731172,,,,,,,,,,,,,, -6629,,,0.365527331829071,2.978093147277832,0.3293399810791015,3.1722586154937744,50000.0,0.2574000060558319,3.7208309173583984,10000.0,2983.2663242816925,3175.8974702358246,2983.2663242816925,192.0537114143372,0.2007627487182617,0.0 -6700,0.99668646,3.9600565,,,,,,,,,,,,,, -6800,0.6026954,6.0062,,,,,,,,,,,,,, -6900,1.0008067,3.9587715,,,,,,,,,,,,,, -7000,0.9485717,3.7732623,,,,,,,,,,,,,, -7100,0.98802084,3.8694274,,,,,,,,,,,,,, -7200,0.7699709,4.3755794,,,,,,,,,,,,,, -7300,0.88545936,3.7429583,,,,,,,,,,,,,, -7400,0.8753279,3.857658,,,,,,,,,,,,,, -7500,0.73722947,4.141474,,,,,,,,,,,,,, -7578,,,0.392382800579071,2.8480987548828125,0.3620399832725525,2.98736834526062,50000.0,0.2772000133991241,3.57218599319458,10000.0,3403.3904716968536,3622.811475038528,3403.3904716968536,218.76269936561584,0.2289977073669433,0.0 -7600,0.796644,4.190182,,,,,,,,,,,,,, -7700,0.9593566,3.7241094,,,,,,,,,,,,,, -7800,0.7817545,4.963695,,,,,,,,,,,,,, -7900,0.9202884,3.7189507,,,,,,,,,,,,,, -8000,1.0093918,3.8692477,,,,,,,,,,,,,, -8100,1.0410165,3.5878541,,,,,,,,,,,,,, -8200,0.8887695,3.6135633,,,,,,,,,,,,,, -8300,1.0541978,6.040868,,,,,,,,,,,,,, -8400,1.013609,3.5600197,,,,,,,,,,,,,, -8500,0.93963116,3.5420337,,,,,,,,,,,,,, -8528,,,0.4204296767711639,2.6367781162261963,0.3862800002098083,2.817789554595948,50000.0,0.2962000072002411,3.424993991851806,10000.0,3823.300567626953,4069.3166477680206,3823.300567626953,245.2535297870636,0.2788910865783691,0.0 -8600,0.7064875,4.672017,,,,,,,,,,,,,, -8700,1.0102046,3.3985934,,,,,,,,,,,,,, -8800,1.054067,3.5133283,,,,,,,,,,,,,, -8900,1.0142907,3.5956545,,,,,,,,,,,,,, -9000,0.73427075,4.3433337,,,,,,,,,,,,,, -9100,0.9162382,3.7476597,,,,,,,,,,,,,, -9200,0.9465615,3.5482094,,,,,,,,,,,,,, -9300,0.85043097,4.27047,,,,,,,,,,,,,, -9400,0.9746764,5.8273005,,,,,,,,,,,,,, -9478,,,0.458300769329071,2.420060396194458,0.4150999784469604,2.6546177864074707,50000.0,0.3241000175476074,3.2569079399108887,10000.0,4243.458779096603,4515.28219127655,4243.458779096603,270.96862721443176,0.3183200359344482,0.0 -9500,1.0212882,3.7747734,,,,,,,,,,,,,, -9600,0.8982263,3.4602952,,,,,,,,,,,,,, -9700,0.8949557,5.3278594,,,,,,,,,,,,,, -9800,1.1058887,3.5454078,,,,,,,,,,,,,, -9900,1.0110756,3.5711837,,,,,,,,,,,,,, -10000,0.99140966,3.4089181,,,,,,,,,,,,,, -10100,0.8126949,4.36521,,,,,,,,,,,,,, -10200,1.0560309,3.432546,,,,,,,,,,,,,, -10300,0.896196,3.4363382,,,,,,,,,,,,,, -10400,0.71589065,5.764721,,,,,,,,,,,,,, -10424,,,0.4592382609844208,2.421598434448242,0.4288399815559387,2.58870530128479,50000.0,0.324500024318695,3.2115306854248047,10000.0,4663.765017032623,4962.0816378593445,4663.765017032623,297.3756804466248,0.3520138263702392,0.0 -10500,1.1756438,3.4766686,,,,,,,,,,,,,, -10600,0.74681085,5.3347516,,,,,,,,,,,,,, -10700,1.1938287,3.4376726,,,,,,,,,,,,,, -10800,0.94042444,3.6083355,,,,,,,,,,,,,, -10900,0.75049543,4.980508,,,,,,,,,,,,,, -11000,0.9185575,3.4495234,,,,,,,,,,,,,, -11100,0.66858387,5.7367,,,,,,,,,,,,,, -11200,0.8162118,5.7316546,,,,,,,,,,,,,, -11300,0.9737588,3.3842604,,,,,,,,,,,,,, -11373,,,0.4764843583106994,2.326941728591919,0.4420799911022186,2.498100519180298,50000.0,0.3439000248908996,3.133396625518799,10000.0,5083.707577466965,5410.477880477905,5083.707577466965,325.7472426891327,0.3814010620117187,0.0 -11400,0.78032184,4.5519223,,,,,,,,,,,,,, -11500,1.0563852,3.3186157,,,,,,,,,,,,,, -11600,0.9780825,3.7756882,,,,,,,,,,,,,, -11700,1.1205736,3.2384045,,,,,,,,,,,,,, -11800,0.9873153,3.6684015,,,,,,,,,,,,,, -11900,1.0056599,3.78822,,,,,,,,,,,,,, -12000,1.0313317,3.1448917,,,,,,,,,,,,,, -12100,0.7557285,4.243832,,,,,,,,,,,,,, -12200,0.8340323,5.6819553,,,,,,,,,,,,,, -12300,1.0248967,4.2695656,,,,,,,,,,,,,, -12322,,,0.5114062428474426,2.1223716735839844,0.4688199758529663,2.3403637409210205,50000.0,0.3625000119209289,3.0175158977508545,10000.0,5503.754225730896,5858.329460859299,5503.754225730896,353.4686424732208,0.4112565517425537,0.0 -12400,0.8500575,5.8334846,,,,,,,,,,,,,, -12500,0.96910167,3.2039227,,,,,,,,,,,,,, -12600,0.9695535,3.4879825,,,,,,,,,,,,,, -12700,1.0383891,3.3110886,,,,,,,,,,,,,, -12800,0.9294177,4.0471807,,,,,,,,,,,,,, -12900,1.0036093,3.2149982,,,,,,,,,,,,,, -13000,0.94935256,4.8886814,,,,,,,,,,,,,, -13100,1.0259821,3.4684243,,,,,,,,,,,,,, -13200,0.74081683,4.548575,,,,,,,,,,,,,, -13263,,,0.5409570336341858,2.0211377143859863,0.4729399979114532,2.332581043243408,50000.0,0.3637000024318695,2.990636348724365,10000.0,5923.964722394943,6307.325840473175,5923.964722394943,382.17194390296936,0.4417285919189453,0.0 -13300,0.94211686,5.3626375,,,,,,,,,,,,,, -13400,0.86008483,5.630124,,,,,,,,,,,,,, -13500,0.9999252,3.2578197,,,,,,,,,,,,,, -13600,0.98591524,3.725854,,,,,,,,,,,,,, -13700,0.9608008,3.0462863,,,,,,,,,,,,,, -13800,1.1355238,3.201807,,,,,,,,,,,,,, -13900,1.1327039,3.2076068,,,,,,,,,,,,,, -14000,0.8330927,4.6980276,,,,,,,,,,,,,, -14100,1.1543477,3.536521,,,,,,,,,,,,,, -14200,1.1899621,3.0327163,,,,,,,,,,,,,, -14209,,,0.5172070264816284,2.0955166816711426,0.4853599965572357,2.272596597671509,50000.0,0.3805000185966491,2.9285807609558105,10000.0,6344.242172241211,6758.920413732529,6344.242172241211,413.3990566730499,0.4788401126861572,0.0 -14300,1.0464213,4.415711,,,,,,,,,,,,,, -14400,0.96197414,3.2211702,,,,,,,,,,,,,, -14500,1.0072409,3.8211591,,,,,,,,,,,,,, -14600,0.9961533,3.0341344,,,,,,,,,,,,,, -14700,1.0294924,3.16466,,,,,,,,,,,,,, -14800,1.0619693,3.0075479,,,,,,,,,,,,,, -14900,1.0537615,3.1469069,,,,,,,,,,,,,, -15000,1.0520655,2.9853663,,,,,,,,,,,,,, -15100,1.1307681,3.122367,,,,,,,,,,,,,, -15153,,,0.53125,2.060498237609864,0.4918199777603149,2.2550272941589355,50000.0,0.3815000057220459,2.920339822769165,10000.0,6764.433388948441,7211.947317838669,6764.433388948441,446.1518096923828,0.5089349746704102,0.0 -15200,0.9107078,3.7909253,,,,,,,,,,,,,, -15300,1.187635,2.9394553,,,,,,,,,,,,,, -15400,0.95269924,3.3909795,,,,,,,,,,,,,, -15500,1.2207795,3.1388185,,,,,,,,,,,,,, -15600,1.1046252,3.1459656,,,,,,,,,,,,,, -15700,1.1353431,3.0305817,,,,,,,,,,,,,, -15800,0.91753423,3.679271,,,,,,,,,,,,,, -15900,0.85893106,4.0669813,,,,,,,,,,,,,, -16000,1.0998058,2.9637303,,,,,,,,,,,,,, -16096,,,0.5531249642372131,1.9234681129455569,0.5004400014877319,2.1823103427886963,50000.0,0.3912000060081482,2.838825225830078,10000.0,7184.445158958435,7661.808724164963,7184.445158958435,475.9171495437622,0.5414671897888184,0.0 -16100,0.78393215,4.5835466,,,,,,,,,,,,,, -16200,0.885137,4.5117445,,,,,,,,,,,,,, -16300,1.0721209,3.0678763,,,,,,,,,,,,,, -16400,1.3012875,3.0606184,,,,,,,,,,,,,, -16500,0.8986968,5.1805763,,,,,,,,,,,,,, -16600,0.9310176,3.350483,,,,,,,,,,,,,, -16700,1.0689147,3.123728,,,,,,,,,,,,,, -16800,0.84811693,4.711905,,,,,,,,,,,,,, -16900,1.0370092,3.6642036,,,,,,,,,,,,,, -17000,1.1332365,3.686166,,,,,,,,,,,,,, -17035,,,0.5389062166213989,2.0137147903442383,0.5019400119781494,2.185546398162842,50000.0,0.3888000249862671,2.8619396686553955,10000.0,7604.548624038696,8112.031835079193,7604.548624038696,505.95252656936646,0.5746691226959229,0.0 -17100,0.9933337,2.935113,,,,,,,,,,,,,, -17200,1.1242868,2.82991,,,,,,,,,,,,,, -17300,0.96819013,5.0939474,,,,,,,,,,,,,, -17400,1.060159,4.5295973,,,,,,,,,,,,,, -17500,0.9120472,3.7650173,,,,,,,,,,,,,, -17600,1.0623114,3.141625,,,,,,,,,,,,,, -17700,1.0648502,2.8700233,,,,,,,,,,,,,, -17800,1.028763,2.9434187,,,,,,,,,,,,,, -17900,1.0872755,2.9466376,,,,,,,,,,,,,, -17979,,,0.5502538681030273,1.966177463531494,0.5148800015449524,2.137053966522217,50000.0,0.403300017118454,2.80226469039917,10000.0,8024.889421463013,8562.914777040482,8024.889421463013,536.4071035385132,0.6097869873046875,0.0 -18000,0.79239655,5.4031677,,,,,,,,,,,,,, -18100,0.9999806,3.4063916,,,,,,,,,,,,,, -18200,0.8323512,4.555857,,,,,,,,,,,,,, -18300,1.0824318,2.9259915,,,,,,,,,,,,,, -18400,0.9402858,4.0354376,,,,,,,,,,,,,, -18500,1.2583852,3.026786,,,,,,,,,,,,,, -18600,0.9029886,3.9200585,,,,,,,,,,,,,, -18700,1.1743609,2.9478462,,,,,,,,,,,,,, -18800,1.0659065,2.9471323,,,,,,,,,,,,,, -18900,1.0451411,2.8466315,,,,,,,,,,,,,, -18922,,,0.5723046660423279,1.7916992902755735,0.5268599987030029,2.030201435089112,50000.0,0.4070000052452087,2.7275853157043457,10000.0,8445.13327217102,9014.0472304821,8445.13327217102,567.2139377593994,0.6404788494110107,0.0 -19000,0.9091652,4.279497,,,,,,,,,,,,,, -19100,1.0927963,2.9334004,,,,,,,,,,,,,, -19200,0.99393,5.1623197,,,,,,,,,,,,,, -19300,1.2597433,2.8156862,,,,,,,,,,,,,, -19400,0.9161643,4.711758,,,,,,,,,,,,,, -19500,1.0329449,2.9692926,,,,,,,,,,,,,, -19600,1.1896436,2.8519206,,,,,,,,,,,,,, -19700,1.0825356,3.0791366,,,,,,,,,,,,,, -19800,1.0654818,2.9214509,,,,,,,,,,,,,, -19865,,,0.5950781106948853,1.7214692831039429,0.5264599919319153,2.04833984375,50000.0,0.4144000113010406,2.693567276000977,10000.0,8865.34806394577,9464.613876342772,8865.34806394577,597.48259973526,0.6718833446502686,0.0 -19900,1.1936879,2.8366241,,,,,,,,,,,,,, -20000,0.99573296,2.8201563,,,,,,,,,,,,,, -20100,1.1838644,2.830533,,,,,,,,,,,,,, -20200,0.83266395,4.2408075,,,,,,,,,,,,,, -20300,1.1456271,2.7550414,,,,,,,,,,,,,, -20400,0.9828787,3.040155,,,,,,,,,,,,,, -20500,0.8297385,5.4743733,,,,,,,,,,,,,, -20600,1.1290078,3.0240295,,,,,,,,,,,,,, -20700,1.2210411,2.8737566,,,,,,,,,,,,,, -20800,1.0232242,3.0188742,,,,,,,,,,,,,, -20808,,,0.5799609422683716,1.7857638597488403,0.5396199822425842,1.978995442390442,50000.0,0.4191000163555145,2.663177728652954,10000.0,9285.290144205092,9916.250267505646,9285.290144205092,629.0911407470703,0.7059800624847412,0.0 -20900,1.3664186,2.8227286,,,,,,,,,,,,,, -21000,1.2410549,2.891054,,,,,,,,,,,,,, -21100,0.8910432,3.6503544,,,,,,,,,,,,,, -21200,1.1211836,2.9561214,,,,,,,,,,,,,, -21300,1.1049054,2.9549208,,,,,,,,,,,,,, -21400,1.0365546,5.385593,,,,,,,,,,,,,, -21500,1.1207856,2.9336367,,,,,,,,,,,,,, -21600,1.1674287,2.919639,,,,,,,,,,,,,, -21700,1.1192836,3.0004892,,,,,,,,,,,,,, -21750,,,0.5848437547683716,1.772911548614502,0.5408999919891357,1.9920196533203125,50000.0,0.4229000210762024,2.649247407913208,10000.0,9705.551045656204,10367.21435379982,9705.551045656204,659.7037007808685,0.7450652122497559,0.0 -21800,1.1823008,2.813076,,,,,,,,,,,,,, -21900,1.0511163,2.7666082,,,,,,,,,,,,,, -22000,1.2090263,2.8193524,,,,,,,,,,,,,, -22100,0.88092077,5.2549872,,,,,,,,,,,,,, -22200,1.061626,2.88771,,,,,,,,,,,,,, -22300,1.2235665,2.6486192,,,,,,,,,,,,,, -22400,0.93419373,5.50486,,,,,,,,,,,,,, -22500,0.8848524,4.8941817,,,,,,,,,,,,,, -22600,0.9679498,5.374671,,,,,,,,,,,,,, -22689,,,0.5935742259025574,1.716032862663269,0.5404399633407593,1.994184613227844,50000.0,0.4222000241279602,2.663605213165283,10000.0,10125.80424952507,10817.389746904371,10125.80424952507,689.5395059585571,0.7798676490783691,0.0 -22700,1.1728626,2.8096845,,,,,,,,,,,,,, -22800,1.1306227,3.0628548,,,,,,,,,,,,,, -22900,0.98873144,2.8158875,,,,,,,,,,,,,, -23000,0.91468304,5.382689,,,,,,,,,,,,,, -23100,1.0768762,2.6783826,,,,,,,,,,,,,, -23200,1.0601244,2.5301528,,,,,,,,,,,,,, -23300,1.1381627,2.6171646,,,,,,,,,,,,,, -23400,1.0307146,5.293744,,,,,,,,,,,,,, -23500,1.1570898,3.0235782,,,,,,,,,,,,,, -23600,1.2167488,2.6983094,,,,,,,,,,,,,, -23629,,,0.5935351252555847,1.7091354131698608,0.5555399656295776,1.896039366722107,50000.0,0.4326000213623047,2.574949026107788,10000.0,10545.795673847198,11272.577404022217,10545.795673847198,724.6498851776123,0.8144900798797607,0.0 -23700,1.1413803,3.4358718,,,,,,,,,,,,,, -23800,0.9294666,4.4746485,,,,,,,,,,,,,, -23900,1.02791,2.9751189,,,,,,,,,,,,,, -24000,0.82365155,4.5817404,,,,,,,,,,,,,, -24100,1.0748237,3.597318,,,,,,,,,,,,,, -24200,1.0552808,2.7665303,,,,,,,,,,,,,, -24300,1.0041763,2.862625,,,,,,,,,,,,,, -24400,1.2033324,2.8028412,,,,,,,,,,,,,, -24500,1.2044623,2.6278253,,,,,,,,,,,,,, -24578,,,0.5988476276397705,1.6850430965423584,0.5579400062561035,1.8976974487304688,50000.0,0.4345000088214874,2.592747926712036,10000.0,10966.144136667252,11725.279448986052,10966.144136667252,756.9143693447113,0.8506224155426025,0.0 -24600,0.9133099,4.082695,,,,,,,,,,,,,, -24700,1.0801687,3.241663,,,,,,,,,,,,,, -24800,1.0683583,2.9184713,,,,,,,,,,,,,, -24900,1.1960231,2.5939121,,,,,,,,,,,,,, -25000,0.9772938,3.0857031,,,,,,,,,,,,,, -25100,1.1446192,2.66128,,,,,,,,,,,,,, -25200,1.1322389,2.7735543,,,,,,,,,,,,,, -25300,0.84161913,4.332472,,,,,,,,,,,,,, -25400,1.0028356,5.20764,,,,,,,,,,,,,, -25500,0.9573015,5.116683,,,,,,,,,,,,,, -25524,,,0.6112499833106995,1.6054291725158691,0.5604999661445618,1.8555995225906368,50000.0,0.4425000250339508,2.5231847763061523,10000.0,11386.355078458786,12177.137206077576,11386.355078458786,788.4734981060028,0.8868091106414795,0.0 -25600,0.92559654,3.4167652,,,,,,,,,,,,,, -25700,1.1681172,2.7681208,,,,,,,,,,,,,, -25800,1.1995188,2.77099,,,,,,,,,,,,,, -25900,1.2317159,2.6852822,,,,,,,,,,,,,, -26000,0.9994194,4.5814233,,,,,,,,,,,,,, -26100,1.2335211,2.8524396,,,,,,,,,,,,,, -26200,1.027747,4.96856,,,,,,,,,,,,,, -26300,0.9959286,3.7410192,,,,,,,,,,,,,, -26400,1.0349431,4.629692,,,,,,,,,,,,,, -26463,,,0.6207422018051147,1.6242880821228027,0.5597000122070312,1.911012291908264,50000.0,0.443200021982193,2.5797505378723145,10000.0,11805.259394168854,12629.023286819458,11805.259394168854,820.2013580799103,2.0886545181274414,0.0 -26500,0.9973542,5.116875,,,,,,,,,,,,,, -26600,0.92844146,4.303156,,,,,,,,,,,,,, -26700,1.0963877,2.7300925,,,,,,,,,,,,,, -26800,0.9601931,4.7001657,,,,,,,,,,,,,, -26900,0.9016652,4.4632244,,,,,,,,,,,,,, -27000,1.1835771,2.496614,,,,,,,,,,,,,, -27100,0.98087263,5.225318,,,,,,,,,,,,,, -27200,1.1541324,2.6575959,,,,,,,,,,,,,, -27300,1.1587013,2.738226,,,,,,,,,,,,,, -27400,1.1714938,2.6132023,,,,,,,,,,,,,, -27404,,,0.6065429449081421,1.6465208530426023,0.5616399645805359,1.8516168594360352,50000.0,0.4491000175476074,2.521052837371826,10000.0,12225.197360038756,13079.97579050064,12225.197360038756,851.1198291778564,2.133695363998413,0.0 -27500,0.9477048,4.666383,,,,,,,,,,,,,, -27600,0.84995425,4.246585,,,,,,,,,,,,,, -27700,1.1710001,2.6450076,,,,,,,,,,,,,, -27800,1.2120583,2.7995887,,,,,,,,,,,,,, -27900,1.140701,2.6766937,,,,,,,,,,,,,, -28000,1.1362443,2.6624823,,,,,,,,,,,,,, -28100,0.9619182,4.6976004,,,,,,,,,,,,,, -28200,1.221827,2.5069473,,,,,,,,,,,,,, -28300,0.94069165,4.7562413,,,,,,,,,,,,,, -28342,,,0.6184960603713989,1.6287354230880735,0.570580005645752,1.852982759475708,50000.0,0.4531000256538391,2.515817880630493,10000.0,12645.36327767372,13531.726953268051,12645.36327767372,882.6207411289215,2.1664159297943115,0.0 -28400,0.8952307,5.0028234,,,,,,,,,,,,,, -28500,1.1626344,2.6468449,,,,,,,,,,,,,, -28600,0.86714506,4.542758,,,,,,,,,,,,,, -28700,1.1892405,2.6180372,,,,,,,,,,,,,, -28800,0.96313024,5.249552,,,,,,,,,,,,,, -28900,0.94907814,5.2755284,,,,,,,,,,,,,, -29000,1.0013108,3.7135372,,,,,,,,,,,,,, -29100,1.0446091,3.4392605,,,,,,,,,,,,,, -29200,1.1070113,2.8835328,,,,,,,,,,,,,, -29282,,,0.6328710913658142,1.5523300170898438,0.5777599811553955,1.8129938840866089,50000.0,0.4621000289916992,2.471017837524414,10000.0,13065.477472305298,13983.570016145706,13065.477472305298,914.2631750106812,2.201692819595337,0.0 -29300,1.0666409,2.8817768,,,,,,,,,,,,,, -29400,1.0254937,2.6062782,,,,,,,,,,,,,, -29500,1.0095981,5.109978,,,,,,,,,,,,,, -29600,1.0813272,2.5639646,,,,,,,,,,,,,, -29700,1.1023241,2.7643917,,,,,,,,,,,,,, -29800,0.9574523,5.1700296,,,,,,,,,,,,,, -29900,0.93739444,4.3148017,,,,,,,,,,,,,, -30000,1.2221377,2.3622391,,,,,,,,,,,,,, -30100,1.1049238,4.18997,,,,,,,,,,,,,, -30200,1.3163673,2.636567,,,,,,,,,,,,,, -30223,,,0.6181054711341858,1.5981839895248413,0.5783399939537048,1.7888963222503662,50000.0,0.4615000188350677,2.451876163482666,10000.0,13485.7399559021,14434.98591184616,13485.7399559021,945.3267643451692,2.239365339279175,0.0 -30300,1.0881784,5.207426,,,,,,,,,,,,,, -30400,1.2296853,2.6873703,,,,,,,,,,,,,, -30500,1.1008829,2.726382,,,,,,,,,,,,,, -30600,1.1559609,2.7878006,,,,,,,,,,,,,, -30700,1.0826811,2.6464567,,,,,,,,,,,,,, -30800,1.2187133,2.831564,,,,,,,,,,,,,, -30900,1.1783943,2.6084824,,,,,,,,,,,,,, -31000,1.1244892,2.6756744,,,,,,,,,,,,,, -31100,1.1409067,4.919545,,,,,,,,,,,,,, -31166,,,0.6274218559265137,1.5659538507461548,0.5806399583816528,1.7789556980133057,50000.0,0.4650000333786011,2.4483392238616943,10000.0,13906.012647390366,14886.83359527588,13906.012647390366,976.8150374889374,2.274978876113892,0.0 -31200,1.118262,2.6132765,,,,,,,,,,,,,, -31300,1.160359,2.6792555,,,,,,,,,,,,,, -31400,1.0414338,2.5488973,,,,,,,,,,,,,, -31500,1.0798862,2.457066,,,,,,,,,,,,,, -31600,1.1139154,2.5651417,,,,,,,,,,,,,, -31700,1.0603443,2.7387805,,,,,,,,,,,,,, -31800,1.0977366,2.6278172,,,,,,,,,,,,,, -31900,1.0545653,2.5117984,,,,,,,,,,,,,, -32000,1.1918291,2.7817628,,,,,,,,,,,,,, -32100,1.1470648,2.5748858,,,,,,,,,,,,,, -32107,,,0.635058581829071,1.529654622077942,0.5848599672317505,1.7706170082092283,50000.0,0.4610000252723694,2.439715623855591,10000.0,14326.030616283417,15337.757307052612,14326.030616283417,1007.635326385498,2.308807134628296,0.0 -32200,1.1016924,4.466598,,,,,,,,,,,,,, -32300,0.99234504,5.209224,,,,,,,,,,,,,, -32400,1.1725674,2.5571885,,,,,,,,,,,,,, -32500,1.082581,2.7320454,,,,,,,,,,,,,, -32600,0.94871175,3.8128462,,,,,,,,,,,,,, -32700,0.9294931,3.556066,,,,,,,,,,,,,, -32800,1.009766,3.3350506,,,,,,,,,,,,,, -32900,1.0785207,4.9932346,,,,,,,,,,,,,, -33000,1.105506,2.4549701,,,,,,,,,,,,,, -33047,,,0.64013671875,1.4978110790252686,0.5824599862098694,1.7634974718093872,50000.0,0.4671000242233276,2.421347141265869,10000.0,14746.13856124878,15789.595637083054,14746.13856124878,1039.280338525772,2.342710018157959,0.0 -33100,1.0474176,3.3312848,,,,,,,,,,,,,, -33200,1.1007093,4.3889203,,,,,,,,,,,,,, -33300,1.0715472,3.9217782,,,,,,,,,,,,,, -33400,1.2193127,2.7152934,,,,,,,,,,,,,, -33500,1.0960865,2.9787128,,,,,,,,,,,,,, -33600,1.2692356,2.681276,,,,,,,,,,,,,, -33700,1.2251058,2.5377045,,,,,,,,,,,,,, -33800,1.2261373,2.7336967,,,,,,,,,,,,,, -33900,1.0934765,3.2538297,,,,,,,,,,,,,, -33990,,,0.6319726705551147,1.5471922159194946,0.5868399739265442,1.755309820175171,50000.0,0.4688000082969665,2.423509120941162,10000.0,15166.186804294586,16240.08767604828,15166.186804294586,1069.6364603042605,2.3791699409484863,0.0 -34000,1.209971,2.5849164,,,,,,,,,,,,,, -34100,1.3891882,2.5886996,,,,,,,,,,,,,, -34200,0.924402,4.1592565,,,,,,,,,,,,,, -34300,1.1367116,4.6145887,,,,,,,,,,,,,, -34400,1.2024634,2.5303504,,,,,,,,,,,,,, -34500,1.190187,4.9573536,,,,,,,,,,,,,, -34600,0.9956888,4.115071,,,,,,,,,,,,,, -34700,1.1907717,2.8914173,,,,,,,,,,,,,, -34800,0.9890511,4.977897,,,,,,,,,,,,,, -34900,1.1600642,2.628548,,,,,,,,,,,,,, -34929,,,0.641406238079071,1.4608179330825806,0.5906999707221985,1.7022159099578855,50000.0,0.4708000123500824,2.3756675720214844,10000.0,15586.226889133452,16689.974547863007,15586.226889133452,1099.3962256908417,2.414524793624878,0.0 -35000,1.1414028,2.4747167,,,,,,,,,,,,,, -35100,0.95497465,5.0480385,,,,,,,,,,,,,, -35200,1.117827,2.4896705,,,,,,,,,,,,,, -35300,1.0653762,3.8300624,,,,,,,,,,,,,, -35400,0.9236283,3.6911497,,,,,,,,,,,,,, -35500,1.2268423,2.7763052,,,,,,,,,,,,,, -35600,1.1218162,2.5264492,,,,,,,,,,,,,, -35700,1.115996,2.8733892,,,,,,,,,,,,,, -35800,1.1939812,2.3652813,,,,,,,,,,,,,, -35868,,,0.6634570360183716,1.371727705001831,0.5966399908065796,1.674397110939026,50000.0,0.4737000167369842,2.368324279785156,10000.0,16006.404333353044,17142.975281000137,16006.404333353044,1132.126388311386,2.4569034576416016,0.0 -35900,1.1937296,2.6205702,,,,,,,,,,,,,, -36000,1.2399309,2.6027603,,,,,,,,,,,,,, -36100,1.0847381,2.9248734,,,,,,,,,,,,,, -36200,1.0816718,2.486824,,,,,,,,,,,,,, -36300,1.0299205,3.5606935,,,,,,,,,,,,,, -36400,0.9882291,2.9958286,,,,,,,,,,,,,, -36500,1.1370361,2.5631926,,,,,,,,,,,,,, -36600,1.1200895,3.5441585,,,,,,,,,,,,,, -36700,0.92286336,4.685476,,,,,,,,,,,,,, -36800,1.3152807,2.4341624,,,,,,,,,,,,,, -36811,,,0.6385351419448853,1.4934431314468384,0.5940399765968323,1.706323266029358,50000.0,0.4733000099658966,2.3658699989318848,10000.0,16426.55730485916,17594.734039783478,16426.55730485916,1163.6418359279633,2.496078968048096,0.0 -36900,1.0498633,2.7843566,,,,,,,,,,,,,, -37000,1.1812336,2.5294278,,,,,,,,,,,,,, -37100,1.1091005,3.7778194,,,,,,,,,,,,,, -37200,0.92150766,5.0467825,,,,,,,,,,,,,, -37300,1.1302667,2.7510314,,,,,,,,,,,,,, -37400,0.99131066,5.0415554,,,,,,,,,,,,,, -37500,1.0059477,4.6366763,,,,,,,,,,,,,, -37600,1.2465298,2.4605985,,,,,,,,,,,,,, -37700,1.1791116,2.4903944,,,,,,,,,,,,,, -37748,,,0.641796886920929,1.4904508590698242,0.5951799750328064,1.7146835327148438,50000.0,0.4742000102996826,2.3896632194519043,10000.0,16846.59597182274,18044.722935915,16846.59597182274,1193.5054569244385,2.531633138656616,0.0 -37800,1.2473729,2.4847913,,,,,,,,,,,,,, -37900,1.0126779,3.2651296,,,,,,,,,,,,,, -38000,1.0467628,3.9182386,,,,,,,,,,,,,, -38100,1.3120867,2.6312315,,,,,,,,,,,,,, -38200,1.1164501,4.9753,,,,,,,,,,,,,, -38300,1.0148412,4.7345753,,,,,,,,,,,,,, -38400,1.1862538,2.4878714,,,,,,,,,,,,,, -38500,1.1702504,2.5962307,,,,,,,,,,,,,, -38600,1.0161265,3.885394,,,,,,,,,,,,,, -38688,,,0.6561914086341858,1.4135422706604004,0.5982199907302856,1.678788661956787,50000.0,0.4707000255584717,2.3531055450439453,10000.0,17266.64454650879,18494.905309677124,17266.64454650879,1223.5478613376615,2.571526288986206,0.0 -38700,1.1369383,2.43978,,,,,,,,,,,,,, -38800,1.0246694,3.158002,,,,,,,,,,,,,, -38900,1.0533434,2.83867,,,,,,,,,,,,,, -39000,1.280211,2.6247756,,,,,,,,,,,,,, -39100,1.0746682,3.4844427,,,,,,,,,,,,,, -39200,1.317277,2.4637847,,,,,,,,,,,,,, -39300,1.2157617,2.4952462,,,,,,,,,,,,,, -39400,1.1753067,5.0685487,,,,,,,,,,,,,, -39500,1.0656922,2.3991907,,,,,,,,,,,,,, -39600,1.1953843,4.409199,,,,,,,,,,,,,, -39627,,,0.6489648222923279,1.4219398498535156,0.6005799770355225,1.6512315273284912,50000.0,0.4796000123023987,2.343630075454712,10000.0,17686.574685573578,18944.70481967926,17686.574685573578,1253.3240871429443,2.613891363143921,0.0 -39700,1.2175336,2.5287602,,,,,,,,,,,,,, -39800,1.1551827,2.4926763,,,,,,,,,,,,,, -39900,1.3392054,2.4890587,,,,,,,,,,,,,, -40000,0.999794,4.639233,,,,,,,,,,,,,, -40100,1.0951164,3.2446399,,,,,,,,,,,,,, -40200,1.194814,2.4439037,,,,,,,,,,,,,, -40300,1.2444295,2.8125541,,,,,,,,,,,,,, -40400,0.9122811,4.974767,,,,,,,,,,,,,, -40500,0.99259794,4.113358,,,,,,,,,,,,,, -40566,,,0.6485155820846558,1.4499775171279907,0.6007800102233887,1.6757594347000122,50000.0,0.4792000353336334,2.341360569000244,10000.0,18106.86800003052,19395.80882000923,18106.86800003052,1284.0473837852478,2.649874925613404,0.0 -40600,1.173548,2.4535167,,,,,,,,,,,,,, -40700,0.932569,4.4904213,,,,,,,,,,,,,, -40800,1.131448,2.5734754,,,,,,,,,,,,,, -40900,1.1103467,2.9624662,,,,,,,,,,,,,, -41000,1.1992073,2.476777,,,,,,,,,,,,,, -41100,1.1001315,5.012142,,,,,,,,,,,,,, -41200,1.2590581,2.5072422,,,,,,,,,,,,,, -41300,1.2071669,2.5254166,,,,,,,,,,,,,, -41400,1.1852052,2.5627398,,,,,,,,,,,,,, -41500,1.2384384,5.043008,,,,,,,,,,,,,, -41503,,,0.646484375,1.5099809169769287,0.6010199785232544,1.7225019931793213,50000.0,0.4783000349998474,2.380979061126709,10000.0,18527.173591852188,19845.67344045639,18527.173591852188,1313.5189683437347,2.6855294704437256,0.0 -41600,1.1965723,2.5627837,,,,,,,,,,,,,, -41700,1.1215936,4.0446925,,,,,,,,,,,,,, -41800,1.0184087,3.2705708,,,,,,,,,,,,,, -41900,1.0333095,3.137778,,,,,,,,,,,,,, -42000,1.0627927,4.272827,,,,,,,,,,,,,, -42100,1.1296576,2.522259,,,,,,,,,,,,,, -42200,1.1585095,2.4104283,,,,,,,,,,,,,, -42300,1.2338411,2.5220194,,,,,,,,,,,,,, -42400,1.264819,2.591824,,,,,,,,,,,,,, -42444,,,0.6727929711341858,1.3531529903411863,0.6094599962234497,1.6596788167953491,50000.0,0.4881000220775604,2.3156774044036865,10000.0,18947.48678970337,20296.886627674103,18947.48678970337,1344.332921743393,2.720686674118042,0.0 -42500,1.2263682,2.460172,,,,,,,,,,,,,, -42600,1.329856,2.4889083,,,,,,,,,,,,,, -42700,1.0429844,4.227149,,,,,,,,,,,,,, -42800,1.1610482,2.7007108,,,,,,,,,,,,,, -42900,0.96586657,3.9505925,,,,,,,,,,,,,, -43000,1.2222782,2.5354214,,,,,,,,,,,,,, -43100,1.3173896,2.5024762,,,,,,,,,,,,,, -43200,1.0899936,5.004522,,,,,,,,,,,,,, -43300,1.1737435,2.5011706,,,,,,,,,,,,,, -43385,,,0.6493554711341858,1.4348890781402588,0.6075999736785889,1.6461001634597778,50000.0,0.4800000190734863,2.359612941741944,10000.0,19367.46821308136,20747.942059278488,19367.46821308136,1375.3200373649595,2.756638288497925,0.0 -43400,0.9189088,4.2695193,,,,,,,,,,,,,, -43500,0.9696957,3.4991367,,,,,,,,,,,,,, -43600,1.2077761,2.4160686,,,,,,,,,,,,,, -43700,1.2449495,2.4974365,,,,,,,,,,,,,, -43800,1.1264566,4.0114937,,,,,,,,,,,,,, -43900,1.0117317,4.9916897,,,,,,,,,,,,,, -44000,0.98409534,4.9087377,,,,,,,,,,,,,, -44100,1.0471244,3.7597551,,,,,,,,,,,,,, -44200,1.0318099,3.8314693,,,,,,,,,,,,,, -44300,1.1545306,2.342933,,,,,,,,,,,,,, -44327,,,0.6592382788658142,1.38425874710083,0.609499990940094,1.627719759941101,50000.0,0.4851000308990478,2.3165159225463867,10000.0,19787.398421525955,21198.64275765419,19787.398421525955,1405.9976406097412,2.7973742485046387,0.0 -44400,1.1325862,2.6377344,,,,,,,,,,,,,, -44500,1.2636179,2.3403108,,,,,,,,,,,,,, -44600,1.1760198,2.3509777,,,,,,,,,,,,,, -44700,1.2522656,2.4390957,,,,,,,,,,,,,, -44800,1.1940786,2.4302971,,,,,,,,,,,,,, -44900,1.1579564,2.388966,,,,,,,,,,,,,, -45000,1.082678,2.3883934,,,,,,,,,,,,,, -45100,1.2478292,2.5816615,,,,,,,,,,,,,, -45200,1.3057708,2.4998083,,,,,,,,,,,,,, -45265,,,0.6744531393051147,1.3124046325683594,0.6171199679374695,1.5861340761184692,50000.0,0.492000013589859,2.267448663711548,10000.0,20207.447094917297,21649.15763068199,20207.447094917297,1436.3729367256165,2.8370449542999268,0.0 -45300,0.9588214,4.783758,,,,,,,,,,,,,, -45400,1.2886493,2.3951547,,,,,,,,,,,,,, -45500,1.2417059,2.3913548,,,,,,,,,,,,,, -45600,1.1912172,2.316819,,,,,,,,,,,,,, -45700,1.1907984,2.562667,,,,,,,,,,,,,, -45800,0.9583566,5.065547,,,,,,,,,,,,,, -45900,1.1435418,2.84066,,,,,,,,,,,,,, -46000,1.1849118,4.9957743,,,,,,,,,,,,,, -46100,0.93245447,3.4534059,,,,,,,,,,,,,, -46200,1.2423588,2.4287314,,,,,,,,,,,,,, -46203,,,0.66162109375,1.3712722063064575,0.6122599840164185,1.6086606979370115,50000.0,0.4882000088691711,2.291762351989746,10000.0,20627.74993658065,22098.74702978134,20627.74993658065,1465.5651831626892,2.880309820175171,0.0 -46300,1.1984767,2.4985394,,,,,,,,,,,,,, -46400,1.2385932,2.5160172,,,,,,,,,,,,,, -46500,1.3388736,2.451494,,,,,,,,,,,,,, -46600,1.082296,4.9762836,,,,,,,,,,,,,, -46700,1.3942517,2.3685827,,,,,,,,,,,,,, -46800,1.2976477,2.4494457,,,,,,,,,,,,,, -46900,1.1502163,2.4885046,,,,,,,,,,,,,, -47000,1.1686746,2.4246094,,,,,,,,,,,,,, -47100,1.183145,2.3666928,,,,,,,,,,,,,, -47143,,,0.6664453148841858,1.3697527647018433,0.616599977016449,1.600008845329285,50000.0,0.4901000261306762,2.2621548175811768,10000.0,21047.87811112404,22551.3086810112,21047.87811112404,1497.9042127132416,2.924238920211792,0.0 -47200,1.1406544,3.726485,,,,,,,,,,,,,, -47300,0.97048074,3.5486321,,,,,,,,,,,,,, -47400,1.1887918,2.4191134,,,,,,,,,,,,,, -47500,1.0210712,4.935714,,,,,,,,,,,,,, -47600,1.0514628,3.1232185,,,,,,,,,,,,,, -47700,1.1144224,4.9819655,,,,,,,,,,,,,, -47800,1.0440356,4.2561693,,,,,,,,,,,,,, -47900,1.0306722,4.9204617,,,,,,,,,,,,,, -48000,1.1554186,2.4934828,,,,,,,,,,,,,, -48088,,,0.66943359375,1.329981565475464,0.6164000034332275,1.5844672918319702,50000.0,0.495600014925003,2.226857662200928,10000.0,21467.8795838356,23001.425425052643,21467.8795838356,1527.9312839508057,2.9619107246398926,0.0 -48100,1.2657231,2.5807998,,,,,,,,,,,,,, -48200,1.1667085,2.4220235,,,,,,,,,,,,,, -48300,1.2278209,2.460508,,,,,,,,,,,,,, -48400,1.2293731,2.6867418,,,,,,,,,,,,,, -48500,1.2219847,2.4667003,,,,,,,,,,,,,, -48600,1.0076255,3.818991,,,,,,,,,,,,,, -48700,1.1782523,2.520298,,,,,,,,,,,,,, -48800,1.0947204,4.8137107,,,,,,,,,,,,,, -48900,1.198662,2.4577386,,,,,,,,,,,,,, -49000,1.1241189,2.5652053,,,,,,,,,,,,,, -49027,,,0.6886523365974426,1.2827799320220947,0.617680013179779,1.6056610345840454,50000.0,0.4955000281333923,2.276652336120605,10000.0,21888.09333539009,23452.66234397888,21888.09333539009,1558.8661715984344,2.999751567840576,0.0 -49100,1.2400311,2.4041386,,,,,,,,,,,,,, -49200,1.1283202,3.3540933,,,,,,,,,,,,,, -49300,1.1607045,2.5551353,,,,,,,,,,,,,, -49400,1.1971097,2.8008146,,,,,,,,,,,,,, -49500,1.0313094,3.1096663,,,,,,,,,,,,,, -49600,0.9475342,4.3197274,,,,,,,,,,,,,, -49700,1.3635719,2.3764334,,,,,,,,,,,,,, -49800,1.3205322,2.3575764,,,,,,,,,,,,,, -49900,1.1608558,2.792844,,,,,,,,,,,,,, -49970,,,0.6638867259025574,1.3901808261871338,0.6227799654006958,1.5859307050704956,50000.0,0.4980000257492065,2.2585461139678955,10000.0,22308.256228208546,23903.274853229523,22308.256228208546,1589.2278089523315,3.0361344814300537,0.0 -50000,1.4544939,2.3329237,,,,,,,,,,,,,, -50100,1.2666626,2.4067345,,,,,,,,,,,,,, -50200,1.2268184,2.3786314,,,,,,,,,,,,,, -50300,1.2462058,2.4780598,,,,,,,,,,,,,, -50400,1.1204511,4.955421,,,,,,,,,,,,,, -50500,1.1114486,2.4173727,,,,,,,,,,,,,, -50600,1.1115739,4.3527627,,,,,,,,,,,,,, -50700,1.0489304,3.2294958,,,,,,,,,,,,,, -50800,1.1645988,2.4283288,,,,,,,,,,,,,, -50900,1.269773,2.347127,,,,,,,,,,,,,, -50910,,,0.6714843511581421,1.3220560550689695,0.6208400130271912,1.5570722818374634,50000.0,0.495600014925003,2.232228994369507,10000.0,22728.1577205658,24354.00265216828,22728.1577205658,1619.9439854621887,3.0944161415100098,0.0 -51000,1.2714208,2.3525918,,,,,,,,,,,,,, -51100,1.2841147,4.8317494,,,,,,,,,,,,,, -51200,1.1140943,4.8513465,,,,,,,,,,,,,, -51300,1.0822337,2.666559,,,,,,,,,,,,,, -51400,1.2174717,2.9266894,,,,,,,,,,,,,, -51500,1.1968242,2.3702104,,,,,,,,,,,,,, -51600,1.237609,2.389927,,,,,,,,,,,,,, -51700,1.3606726,2.3360338,,,,,,,,,,,,,, -51800,1.2033302,2.3498259,,,,,,,,,,,,,, -51849,,,0.6763281226158142,1.3460313081741333,0.6196799874305725,1.6015217304229736,50000.0,0.4967000186443329,2.263665199279785,10000.0,23148.37905406952,24807.57603955269,23148.37905406952,1653.205320596695,3.13411545753479,0.0 -51900,1.3512791,2.3698895,,,,,,,,,,,,,, -52000,1.0409782,4.0283723,,,,,,,,,,,,,, -52100,1.2764239,2.2635374,,,,,,,,,,,,,, -52200,1.0921719,2.6485465,,,,,,,,,,,,,, -52300,1.3239195,2.2864347,,,,,,,,,,,,,, -52400,1.1199833,4.4486136,,,,,,,,,,,,,, -52500,1.1278262,3.473703,,,,,,,,,,,,,, -52600,1.2183933,3.0841093,,,,,,,,,,,,,, -52700,1.0981511,2.5324733,,,,,,,,,,,,,, -52792,,,0.6678906083106995,1.386376976966858,0.6201199889183044,1.6033985614776611,50000.0,0.4921000301837921,2.28542423248291,10000.0,23568.588448762894,25260.98716020584,23568.588448762894,1686.3185930252075,3.17166519165039,0.0 -52800,0.9725833,3.8673584,,,,,,,,,,,,,, -52900,1.0460966,3.8004084,,,,,,,,,,,,,, -53000,1.168337,3.4504037,,,,,,,,,,,,,, -53100,1.2243459,2.3821256,,,,,,,,,,,,,, -53200,1.0923504,3.2073395,,,,,,,,,,,,,, -53300,1.0311496,4.750217,,,,,,,,,,,,,, -53400,1.1461116,2.9830284,,,,,,,,,,,,,, -53500,1.2026489,2.451568,,,,,,,,,,,,,, -53600,1.0106709,4.287798,,,,,,,,,,,,,, -53700,1.0798241,3.0788176,,,,,,,,,,,,,, -53734,,,0.674609363079071,1.3201045989990234,0.625,1.555474042892456,50000.0,0.4958000183105469,2.2348878383636475,10000.0,23988.578372478485,25713.50721931457,23988.578372478485,1718.752513408661,3.2167985439300537,0.0 -53800,1.2921903,2.417718,,,,,,,,,,,,,, -53900,1.1634254,2.7239258,,,,,,,,,,,,,, -54000,1.0831163,4.38326,,,,,,,,,,,,,, -54100,1.2304537,2.307835,,,,,,,,,,,,,, -54200,1.1783421,3.657762,,,,,,,,,,,,,, -54300,1.1388359,2.8264475,,,,,,,,,,,,,, -54400,1.0444117,4.576563,,,,,,,,,,,,,, -54500,1.2683597,2.312909,,,,,,,,,,,,,, -54600,1.1896343,2.351604,,,,,,,,,,,,,, -54673,,,0.6805273294448853,1.3037056922912598,0.6293399930000305,1.5446206331253052,50000.0,0.508400022983551,2.209318161010742,10000.0,24408.97371816635,26166.98223090172,24408.97371816635,1751.7384040355682,3.2596917152404785,0.0 -54700,1.317333,2.334537,,,,,,,,,,,,,, -54800,1.2989662,2.3151321,,,,,,,,,,,,,, -54900,1.4831101,3.0075068,,,,,,,,,,,,,, -55000,1.294819,2.4427013,,,,,,,,,,,,,, -55100,1.1991687,2.3934164,,,,,,,,,,,,,, -55200,1.2118795,2.3424163,,,,,,,,,,,,,, -55300,1.2145517,2.7740457,,,,,,,,,,,,,, -55400,1.2841035,2.3310988,,,,,,,,,,,,,, -55500,1.2756233,2.3118703,,,,,,,,,,,,,, -55600,1.0199276,3.3773708,,,,,,,,,,,,,, -55613,,,0.7066406011581421,1.207590937614441,0.630079984664917,1.5488102436065674,50000.0,0.5047000050544739,2.2150003910064697,10000.0,24829.13514328003,26623.4019677639,24829.13514328003,1787.9053149223328,3.299623966217041,0.0 -55700,1.1903867,2.6361694,,,,,,,,,,,,,, -55800,1.1677171,2.2380939,,,,,,,,,,,,,, -55900,1.0723094,3.3399396,,,,,,,,,,,,,, -56000,1.0162199,4.125741,,,,,,,,,,,,,, -56100,1.2187042,2.690404,,,,,,,,,,,,,, -56200,1.3402894,2.2264378,,,,,,,,,,,,,, -56300,1.2531449,2.946475,,,,,,,,,,,,,, -56400,1.0781653,4.893105,,,,,,,,,,,,,, -56500,1.3191731,2.3338091,,,,,,,,,,,,,, -56558,,,0.6771875023841858,1.3591915369033811,0.629040002822876,1.5706725120544434,50000.0,0.5024999976158142,2.2308924198150635,10000.0,25249.2022702694,27077.45137095452,25249.2022702694,1821.799224853516,3.3368849754333496,0.0 -56600,1.2989583,2.3880098,,,,,,,,,,,,,, -56700,1.1327682,4.0189466,,,,,,,,,,,,,, -56800,1.338853,2.2822113,,,,,,,,,,,,,, -56900,1.3728489,2.2489378,,,,,,,,,,,,,, -57000,1.2494555,4.596592,,,,,,,,,,,,,, -57100,1.4066317,2.26334,,,,,,,,,,,,,, -57200,1.0992986,3.8535416,,,,,,,,,,,,,, -57300,1.2326931,2.6263628,,,,,,,,,,,,,, -57400,1.2249593,2.3323183,,,,,,,,,,,,,, -57500,1.1135709,3.194053,,,,,,,,,,,,,, -57501,,,0.68505859375,1.295912742614746,0.6319800019264221,1.5335566997528076,50000.0,0.5076000094413757,2.20805025100708,10000.0,25669.92558169365,27530.83186006546,25669.92558169365,1854.3612713813784,3.379136562347412,0.0 -57600,1.0801543,3.674181,,,,,,,,,,,,,, -57700,1.2104924,2.2890005,,,,,,,,,,,,,, -57800,1.1219008,2.998581,,,,,,,,,,,,,, -57900,1.2435645,2.3877447,,,,,,,,,,,,,, -58000,1.4199355,2.503895,,,,,,,,,,,,,, -58100,1.2493291,2.4666002,,,,,,,,,,,,,, -58200,1.3950847,5.0273285,,,,,,,,,,,,,, -58300,1.0936306,2.397486,,,,,,,,,,,,,, -58400,1.4104656,2.3646717,,,,,,,,,,,,,, -58440,,,0.6932421922683716,1.2518590688705444,0.6332600116729736,1.53603196144104,50000.0,0.5067000389099121,2.2019996643066406,10000.0,26090.231457710262,27982.466410398483,26090.231457710262,1885.599454164505,3.4180257320404053,0.0 -58500,1.2094276,2.3385277,,,,,,,,,,,,,, -58600,1.1724429,2.306349,,,,,,,,,,,,,, -58700,1.1527832,4.466196,,,,,,,,,,,,,, -58800,1.3277235,2.3777194,,,,,,,,,,,,,, -58900,1.2556854,2.3535202,,,,,,,,,,,,,, -59000,1.0884689,4.196599,,,,,,,,,,,,,, -59100,1.190785,2.3313215,,,,,,,,,,,,,, -59200,1.2247411,2.6019468,,,,,,,,,,,,,, -59300,1.1949208,2.808303,,,,,,,,,,,,,, -59378,,,0.6830077767372131,1.290729284286499,0.6337800025939941,1.5220874547958374,50000.0,0.5088000297546387,2.177444696426392,10000.0,26510.5277197361,28440.20163846016,26510.5277197361,1922.9465517997744,3.4586949348449707,0.0 -59400,1.2840989,2.2339745,,,,,,,,,,,,,, -59500,1.2594466,2.3313017,,,,,,,,,,,,,, -59600,1.2692282,2.3284078,,,,,,,,,,,,,, -59700,1.1927865,4.596383,,,,,,,,,,,,,, -59800,1.0702244,3.80047,,,,,,,,,,,,,, -59900,1.0522952,4.081515,,,,,,,,,,,,,, -60000,1.1403049,3.0443814,,,,,,,,,,,,,, -60100,1.154736,3.125485,,,,,,,,,,,,,, -60200,1.2553416,2.4119587,,,,,,,,,,,,,, -60300,1.1420164,3.905868,,,,,,,,,,,,,, -60323,,,0.6819140315055847,1.2942769527435305,0.6343799829483032,1.5181857347488403,50000.0,0.5124000310897827,2.165943622589112,10000.0,26930.86291337013,28893.049347400665,26930.86291337013,1955.368063211441,3.498325824737549,0.0 -60400,1.1712552,4.6602383,,,,,,,,,,,,,, -60500,1.1843797,4.5468755,,,,,,,,,,,,,, -60600,1.2197217,2.7424479,,,,,,,,,,,,,, -60700,1.2683922,2.2180285,,,,,,,,,,,,,, -60800,1.1947327,4.030742,,,,,,,,,,,,,, -60900,1.3688095,2.2811604,,,,,,,,,,,,,, -61000,1.377889,2.3411937,,,,,,,,,,,,,, -61100,1.0215119,3.9526284,,,,,,,,,,,,,, -61200,1.2879744,2.3541787,,,,,,,,,,,,,, -61265,,,0.6977343559265137,1.215801477432251,0.640500009059906,1.4811792373657229,50000.0,0.5146000385284424,2.156124830245972,10000.0,27351.005984783173,29344.85937547684,27351.005984783173,1986.939861536026,3.541537284851074,0.0 -61300,1.2721688,2.270133,,,,,,,,,,,,,, -61400,1.5030516,2.3551364,,,,,,,,,,,,,, -61500,1.2930299,2.0703921,,,,,,,,,,,,,, -61600,1.2369155,4.886231,,,,,,,,,,,,,, -61700,1.1268886,3.2842016,,,,,,,,,,,,,, -61800,1.31958,2.405396,,,,,,,,,,,,,, -61900,1.3416874,2.4294481,,,,,,,,,,,,,, -62000,1.2010825,2.9312537,,,,,,,,,,,,,, -62100,1.1989928,4.391283,,,,,,,,,,,,,, -62200,1.3049648,2.3976421,,,,,,,,,,,,,, -62208,,,0.7203710675239563,1.1145329475402832,0.6396999955177307,1.4701439142227173,50000.0,0.5141000151634216,2.1523773670196533,10000.0,27771.32909488678,29796.588678121567,27771.32909488678,2018.25333237648,3.5821726322174072,0.0 -62300,1.091301,4.1404257,,,,,,,,,,,,,, -62400,1.2419467,2.2618775,,,,,,,,,,,,,, -62500,1.2955,2.3200383,,,,,,,,,,,,,, -62600,1.2880207,2.2316327,,,,,,,,,,,,,, -62700,1.17352,4.3334975,,,,,,,,,,,,,, -62800,1.218602,2.4315953,,,,,,,,,,,,,, -62900,1.2488179,2.38586,,,,,,,,,,,,,, -63000,1.3436676,2.1620908,,,,,,,,,,,,,, -63100,1.196499,3.7868383,,,,,,,,,,,,,, -63150,,,0.68994140625,1.255996823310852,0.6400200128555298,1.4859697818756104,50000.0,0.5139000415802002,2.1568527221679688,10000.0,28191.540812969208,30250.63220858574,28191.540812969208,2051.9938457012177,3.6222991943359375,0.0 -63200,1.3284605,2.407707,,,,,,,,,,,,,, -63300,1.2763042,2.5286608,,,,,,,,,,,,,, -63400,1.3272989,2.3520355,,,,,,,,,,,,,, -63500,1.2505889,2.2702637,,,,,,,,,,,,,, -63600,1.2340809,2.271056,,,,,,,,,,,,,, -63700,1.480429,4.5819583,,,,,,,,,,,,,, -63800,1.3337857,2.1763709,,,,,,,,,,,,,, -63900,1.2285125,2.117383,,,,,,,,,,,,,, -64000,1.2881259,2.1378229,,,,,,,,,,,,,, -64094,,,0.6942577958106995,1.2345657348632812,0.6388799548149109,1.49025559425354,50000.0,0.5193000435829163,2.129345655441284,10000.0,28611.683968544006,30704.06983423233,28611.683968544006,2085.192571163177,3.666409969329834,0.0 -64100,1.2923342,2.1737337,,,,,,,,,,,,,, -64200,1.2742497,2.3966212,,,,,,,,,,,,,, -64300,1.2319771,2.2785783,,,,,,,,,,,,,, -64400,1.2342416,2.9716508,,,,,,,,,,,,,, -64500,1.2220532,2.6356168,,,,,,,,,,,,,, -64600,1.1876397,2.6214707,,,,,,,,,,,,,, -64700,1.3068358,2.2584167,,,,,,,,,,,,,, -64800,1.3880719,2.3878603,,,,,,,,,,,,,, -64900,1.3449755,2.1422596,,,,,,,,,,,,,, -65000,1.2177372,4.153278,,,,,,,,,,,,,, -65041,,,0.7075585722923279,1.1857621669769287,0.6427599787712097,1.4769474267959597,50000.0,0.5171000361442566,2.138280868530273,10000.0,29031.91866993904,31156.541711330414,29031.91866993904,2117.3324999809265,3.71160888671875,0.0 -65100,1.2278073,3.6877406,,,,,,,,,,,,,, -65200,1.2906765,2.1899233,,,,,,,,,,,,,, -65300,1.3531659,2.2521086,,,,,,,,,,,,,, -65400,1.321926,2.2198596,,,,,,,,,,,,,, -65500,1.0972809,3.2544656,,,,,,,,,,,,,, -65600,1.338283,2.2167969,,,,,,,,,,,,,, -65700,1.2469486,2.2147858,,,,,,,,,,,,,, -65800,1.1853685,4.8525424,,,,,,,,,,,,,, -65900,1.259171,2.1042013,,,,,,,,,,,,,, -65984,,,0.69580078125,1.236257791519165,0.645799994468689,1.4627673625946045,50000.0,0.5164999961853027,2.131663084030152,10000.0,29451.99482369423,31607.9455947876,29451.99482369423,2148.558340787888,3.76130747795105,0.0 -66000,1.2502732,3.0157912,,,,,,,,,,,,,, -66100,1.3692324,2.1922576,,,,,,,,,,,,,, -66200,1.71439,2.3882706,,,,,,,,,,,,,, -66300,1.1123472,4.363405,,,,,,,,,,,,,, -66400,1.3684131,2.5168216,,,,,,,,,,,,,, -66500,1.1905406,4.870682,,,,,,,,,,,,,, -66600,1.1669123,2.7964635,,,,,,,,,,,,,, -66700,1.3776236,2.3790743,,,,,,,,,,,,,, -66800,1.2666094,2.2362452,,,,,,,,,,,,,, -66900,1.1601307,2.743785,,,,,,,,,,,,,, -66925,,,0.6991015672683716,1.2016446590423584,0.6476799845695496,1.4438791275024414,50000.0,0.5211000442504883,2.114197492599488,10000.0,29872.294987916943,32060.55270266533,29872.294987916943,2180.7719078063965,3.802668809890747,0.0 -67000,1.3751603,2.1204,,,,,,,,,,,,,, -67100,1.3387494,2.3357122,,,,,,,,,,,,,, -67200,1.2041919,2.2582495,,,,,,,,,,,,,, -67300,1.2840967,2.1651158,,,,,,,,,,,,,, -67400,1.2820171,2.1556985,,,,,,,,,,,,,, -67500,1.1908987,2.785495,,,,,,,,,,,,,, -67600,1.3105776,2.1884036,,,,,,,,,,,,,, -67700,1.5607612,2.3324516,,,,,,,,,,,,,, -67800,1.306624,2.1619954,,,,,,,,,,,,,, -67866,,,0.7053515315055847,1.1844969987869265,0.6496999859809875,1.450947642326355,50000.0,0.5210000276565552,2.1125617027282715,10000.0,30292.256318330765,32513.07097506523,30292.256318330765,2213.23275232315,3.84781265258789,0.0 -67900,1.1694735,2.8911192,,,,,,,,,,,,,, -68000,1.0765276,3.4196463,,,,,,,,,,,,,, -68100,1.2315449,2.530428,,,,,,,,,,,,,, -68200,1.0920078,3.079386,,,,,,,,,,,,,, -68300,1.1903132,3.625862,,,,,,,,,,,,,, -68400,1.4057392,2.2015576,,,,,,,,,,,,,, -68500,1.0945611,3.7662666,,,,,,,,,,,,,, -68600,1.3325846,4.3933115,,,,,,,,,,,,,, -68700,1.1826454,2.0195618,,,,,,,,,,,,,, -68800,1.0750055,2.888604,,,,,,,,,,,,,, -68811,,,0.7294921875,1.0799065828323364,0.6500200033187866,1.4359275102615356,50000.0,0.524399995803833,2.091939687728882,10000.0,30712.568786621094,32965.27647805214,30712.568786621094,2245.0266540050507,3.895355939865112,0.0 -68900,1.1829293,4.009277,,,,,,,,,,,,,, -69000,1.3205281,2.3359902,,,,,,,,,,,,,, -69100,1.0747164,3.580195,,,,,,,,,,,,,, -69200,1.4403968,2.4673173,,,,,,,,,,,,,, -69300,1.1177158,4.420896,,,,,,,,,,,,,, -69400,1.2500517,2.1098344,,,,,,,,,,,,,, -69500,1.2629021,2.1402047,,,,,,,,,,,,,, -69600,1.3736098,2.2361863,,,,,,,,,,,,,, -69700,1.0952717,4.4120183,,,,,,,,,,,,,, -69748,,,0.6984961032867432,1.211103916168213,0.6511600017547607,1.4439362287521362,50000.0,0.5236000418663025,2.12225604057312,10000.0,31132.49741792679,33417.12102437019,31132.49741792679,2276.8387970924377,3.9474518299102783,0.0 -69800,1.300589,2.3143282,,,,,,,,,,,,,, -69900,1.1982253,4.605135,,,,,,,,,,,,,, -70000,1.2444888,2.9612985,,,,,,,,,,,,,, -70100,1.3872396,2.199483,,,,,,,,,,,,,, -70200,1.5002637,2.3390772,,,,,,,,,,,,,, -70300,1.1573976,3.018544,,,,,,,,,,,,,, -70400,1.2349452,2.3129847,,,,,,,,,,,,,, -70500,1.2436463,2.1499667,,,,,,,,,,,,,, -70600,1.1813712,3.5673006,,,,,,,,,,,,,, -70689,,,0.7059765458106995,1.168419361114502,0.6503599882125854,1.433741331100464,50000.0,0.5211000442504883,2.1125056743621826,10000.0,31552.53966331482,33868.39933013916,31552.53966331482,2307.9747524261475,3.996495962142944,0.0 -70700,1.1432612,3.9855695,,,,,,,,,,,,,, -70800,1.2557416,2.30458,,,,,,,,,,,,,, -70900,1.3060905,2.1836672,,,,,,,,,,,,,, -71000,1.2554349,2.9272308,,,,,,,,,,,,,, -71100,1.213944,4.107193,,,,,,,,,,,,,, -71200,1.1266063,3.5318298,,,,,,,,,,,,,, -71300,1.1908054,4.649292,,,,,,,,,,,,,, -71400,1.3415468,2.2568896,,,,,,,,,,,,,, -71500,1.4197731,2.0579107,,,,,,,,,,,,,, -71600,1.2756641,2.2201629,,,,,,,,,,,,,, -71632,,,0.7135937213897705,1.1473480463027954,0.6534000039100647,1.4237979650497437,50000.0,0.5218000411987305,2.1029539108276367,10000.0,31972.546494960785,34320.61787605286,31972.546494960785,2340.085106611252,4.0459089279174805,0.0 -71700,1.3008763,4.706443,,,,,,,,,,,,,, -71800,1.4476838,2.1694934,,,,,,,,,,,,,, -71900,1.3236618,3.6493788,,,,,,,,,,,,,, -72000,1.3473309,2.0613387,,,,,,,,,,,,,, -72100,1.3158675,2.2870574,,,,,,,,,,,,,, -72200,1.2498033,2.969697,,,,,,,,,,,,,, -72300,1.1994871,2.3298068,,,,,,,,,,,,,, -72400,1.2307168,4.2145557,,,,,,,,,,,,,, -72500,1.2265453,2.9040606,,,,,,,,,,,,,, -72572,,,0.6991015672683716,1.2266145944595337,0.6523399949073792,1.445721983909607,50000.0,0.5273000001907349,2.112567186355591,10000.0,32392.69544196129,34772.84811258316,32392.69544196129,2372.071283340454,4.090409517288208,0.0 -72600,1.1844993,4.255528,,,,,,,,,,,,,, -72700,1.2267207,2.800978,,,,,,,,,,,,,, -72800,1.3030224,2.170631,,,,,,,,,,,,,, -72900,1.3961792,2.143623,,,,,,,,,,,,,, -73000,1.3737253,2.0063012,,,,,,,,,,,,,, -73100,1.318911,2.5078478,,,,,,,,,,,,,, -73200,1.21925,4.756263,,,,,,,,,,,,,, -73300,1.4116788,2.2320192,,,,,,,,,,,,,, -73400,1.2601042,2.908347,,,,,,,,,,,,,, -73500,1.305652,2.1342168,,,,,,,,,,,,,, -73510,,,0.7064843773841858,1.1597418785095217,0.651919960975647,1.4081584215164185,50000.0,0.5275000333786011,2.077162504196167,10000.0,32812.83591794968,35224.67768549919,32812.83591794968,2403.665696620941,4.1330084800720215,0.0 -73600,1.1758837,3.107226,,,,,,,,,,,,,, -73700,1.2062168,2.6388006,,,,,,,,,,,,,, -73800,1.2204508,3.6795828,,,,,,,,,,,,,, -73900,1.3545686,4.296754,,,,,,,,,,,,,, -74000,1.104256,3.0735226,,,,,,,,,,,,,, -74100,1.3120258,2.2716215,,,,,,,,,,,,,, -74200,1.1956035,4.2271852,,,,,,,,,,,,,, -74300,1.3076817,2.3113816,,,,,,,,,,,,,, -74400,1.3583156,2.281589,,,,,,,,,,,,,, -74451,,,0.7114648222923279,1.1783130168914795,0.652899980545044,1.4420287609100342,50000.0,0.5303000211715698,2.106510162353516,10000.0,33232.861943244934,35682.67705178261,33232.861943244934,2441.54381275177,4.176920413970947,0.0 -74500,1.3271492,2.2066703,,,,,,,,,,,,,, -74600,1.5676255,2.1072872,,,,,,,,,,,,,, -74700,1.2419686,2.9346848,,,,,,,,,,,,,, -74800,1.264736,2.1616194,,,,,,,,,,,,,, -74900,1.2778242,2.477367,,,,,,,,,,,,,, -75000,1.2181064,3.1598234,,,,,,,,,,,,,, -75100,1.3585216,2.0664601,,,,,,,,,,,,,, -75200,1.122581,3.3836925,,,,,,,,,,,,,, -75300,1.3919898,2.1958907,,,,,,,,,,,,,, -75395,,,0.7361718416213989,1.0584875345230105,0.6551799774169922,1.4225327968597412,50000.0,0.5238000154495239,2.119074583053589,10000.0,33653.16945314407,36134.95932555199,33653.16945314407,2473.425219774246,4.219008684158325,0.0 -75400,1.3552016,4.4344325,,,,,,,,,,,,,, -75500,1.2284923,3.2314646,,,,,,,,,,,,,, -75600,1.1745117,2.9850671,,,,,,,,,,,,,, -75700,1.1510752,4.0383263,,,,,,,,,,,,,, -75800,1.2393647,2.8983917,,,,,,,,,,,,,, -75900,1.3906813,2.470906,,,,,,,,,,,,,, -76000,1.3893088,4.6475534,,,,,,,,,,,,,, -76100,1.3124379,2.1306581,,,,,,,,,,,,,, -76200,1.1843059,2.9275007,,,,,,,,,,,,,, -76300,1.3195266,2.19272,,,,,,,,,,,,,, -76333,,,0.7122265696525574,1.1508327722549438,0.658840000629425,1.3871424198150637,50000.0,0.5350000262260437,2.044471740722656,10000.0,34073.50621318817,36588.21505665779,34073.50621318817,2506.250350475312,4.261711597442627,0.0 -76400,1.2926842,2.1233554,,,,,,,,,,,,,, -76500,1.4365721,2.2946894,,,,,,,,,,,,,, -76600,1.3787141,2.3864655,,,,,,,,,,,,,, -76700,1.1757492,2.4795918,,,,,,,,,,,,,, -76800,1.4322959,2.1419125,,,,,,,,,,,,,, -76900,1.5731564,2.278079,,,,,,,,,,,,,, -77000,1.4697144,2.4538436,,,,,,,,,,,,,, -77100,1.2938809,2.7139142,,,,,,,,,,,,,, -77200,1.5798969,2.2141948,,,,,,,,,,,,,, -77270,,,0.7195898294448853,1.1208205223083496,0.6632800102233887,1.391443133354187,50000.0,0.5396000146865845,2.0465617179870605,10000.0,34493.52569794655,37045.77448058128,34493.52569794655,2543.693027973175,4.307955265045166,0.0 -77300,1.3992709,2.393,,,,,,,,,,,,,, -77400,1.246056,4.049859,,,,,,,,,,,,,, -77500,1.3395073,4.681263,,,,,,,,,,,,,, -77600,1.2653676,2.2970963,,,,,,,,,,,,,, -77700,1.3381805,2.298035,,,,,,,,,,,,,, -77800,1.2316828,4.4667015,,,,,,,,,,,,,, -77900,1.2585342,2.7008154,,,,,,,,,,,,,, -78000,1.224656,3.454385,,,,,,,,,,,,,, -78100,1.4577026,2.2087314,,,,,,,,,,,,,, -78200,1.2876439,2.477185,,,,,,,,,,,,,, -78215,,,0.7258203029632568,1.1131298542022705,0.6583600044250488,1.4111359119415283,50000.0,0.5385000109672546,2.0573971271514893,10000.0,34913.678639411926,37498.90714406967,34913.678639411926,2576.5816905498505,4.3477983474731445,0.0 -78300,1.3838999,2.661396,,,,,,,,,,,,,, -78400,1.3003746,2.279672,,,,,,,,,,,,,, -78500,1.1746438,3.0977533,,,,,,,,,,,,,, -78600,1.3303866,4.6935906,,,,,,,,,,,,,, -78700,1.3414224,2.6962576,,,,,,,,,,,,,, -78800,1.3440875,2.6120987,,,,,,,,,,,,,, -78900,1.3372424,4.597924,,,,,,,,,,,,,, -79000,1.4029995,2.4940958,,,,,,,,,,,,,, -79100,1.3151541,2.0490236,,,,,,,,,,,,,, -79157,,,0.7126757502555847,1.168033480644226,0.6606799960136414,1.408635139465332,50000.0,0.5360000133514404,2.062854766845703,10000.0,35333.94675087929,37954.20350050926,35333.94675087929,2611.512553215027,4.392807483673096,0.0 -79200,1.1449567,4.24038,,,,,,,,,,,,,, -79300,1.194365,4.163761,,,,,,,,,,,,,, -79400,1.1820631,3.9626555,,,,,,,,,,,,,, -79500,1.4428357,2.2163014,,,,,,,,,,,,,, -79600,1.2077267,4.212723,,,,,,,,,,,,,, -79700,1.3302443,2.1304188,,,,,,,,,,,,,, -79800,1.2874393,2.6311455,,,,,,,,,,,,,, -79900,1.1818882,2.6227484,,,,,,,,,,,,,, -80000,1.1917384,3.1330538,,,,,,,,,,,,,, -80099,,,0.71546870470047,1.14302659034729,0.6597599983215332,1.3939063549041748,50000.0,0.5321000218391418,2.0679373741149902,10000.0,35753.97294163704,38410.941356658936,35753.97294163704,2648.128517389297,4.436190128326416,0.0 -80100,1.3753915,3.4881456,,,,,,,,,,,,,, -80200,1.3953862,2.2511427,,,,,,,,,,,,,, -80300,1.3642819,2.2821527,,,,,,,,,,,,,, -80400,1.1736457,3.354788,,,,,,,,,,,,,, -80500,1.2763822,2.3055935,,,,,,,,,,,,,, -80600,1.3369627,2.104849,,,,,,,,,,,,,, -80700,1.3526398,2.0387278,,,,,,,,,,,,,, -80800,1.3537512,2.620309,,,,,,,,,,,,,, -80900,1.3188188,2.0065672,,,,,,,,,,,,,, -81000,1.1444768,3.5728931,,,,,,,,,,,,,, -81043,,,0.7226366996765137,1.1191169023513794,0.6619600057601929,1.3967500925064087,50000.0,0.5390000343322754,2.057657480239868,10000.0,36174.29339551926,38864.1615319252,36174.29339551926,2680.9402170181274,4.473475933074951,0.0 -81100,1.230282,2.3230145,,,,,,,,,,,,,, -81200,1.3999252,2.0203893,,,,,,,,,,,,,, -81300,1.3942986,2.0656445,,,,,,,,,,,,,, -81400,1.1558388,2.5865724,,,,,,,,,,,,,, -81500,1.4150059,2.1974223,,,,,,,,,,,,,, -81600,1.5289289,2.0376468,,,,,,,,,,,,,, -81700,1.3040941,2.5215507,,,,,,,,,,,,,, -81800,1.2629862,2.063742,,,,,,,,,,,,,, -81900,1.3510357,2.1756074,,,,,,,,,,,,,, -81982,,,0.7477148175239563,1.0231945514678955,0.6630799770355225,1.392142415046692,50000.0,0.5347000360488892,2.0538101196289062,10000.0,36594.43833613396,39317.56582832336,36594.43833613396,2714.103229045868,4.518529653549194,0.0 -82000,1.2658299,4.535966,,,,,,,,,,,,,, -82100,1.4115676,2.1305463,,,,,,,,,,,,,, -82200,1.2507502,2.2505527,,,,,,,,,,,,,, -82300,1.3088136,2.2452466,,,,,,,,,,,,,, -82400,1.2853955,2.32006,,,,,,,,,,,,,, -82500,1.4310796,2.057361,,,,,,,,,,,,,, -82600,1.3743155,1.9894061,,,,,,,,,,,,,, -82700,1.3541012,2.3253715,,,,,,,,,,,,,, -82800,1.2360421,1.9741635,,,,,,,,,,,,,, -82900,1.2980797,2.4025724,,,,,,,,,,,,,, -82924,,,0.7239843606948853,1.1104586124420166,0.6676200032234192,1.363040804862976,50000.0,0.5391000509262085,2.0352330207824707,10000.0,37014.43953371048,39773.3790769577,37014.43953371048,2749.8208360672,4.561380863189697,0.0 -83000,1.3777574,2.6517823,,,,,,,,,,,,,, -83100,1.3796419,2.1411853,,,,,,,,,,,,,, -83200,1.3736677,2.1122017,,,,,,,,,,,,,, -83300,1.4229853,1.9934021,,,,,,,,,,,,,, -83400,1.3258927,2.0952954,,,,,,,,,,,,,, -83500,1.1615751,4.5176697,,,,,,,,,,,,,, -83600,1.3644634,2.478475,,,,,,,,,,,,,, -83700,1.4036553,2.0412116,,,,,,,,,,,,,, -83800,1.3236654,3.0461547,,,,,,,,,,,,,, -83865,,,0.7298241853713989,1.0960992574691772,0.6693199872970581,1.3602654933929443,50000.0,0.5454000234603882,2.023963689804077,10000.0,37434.735203027725,40225.933086156845,37434.735203027725,2781.979739665985,4.608523607254028,0.0 -83900,1.3393073,2.1447263,,,,,,,,,,,,,, -84000,1.413009,2.30743,,,,,,,,,,,,,, -84100,1.3677955,2.0858383,,,,,,,,,,,,,, -84200,1.235588,3.14772,,,,,,,,,,,,,, -84300,1.3471049,2.617488,,,,,,,,,,,,,, -84400,1.390371,1.9697069,,,,,,,,,,,,,, -84500,1.2971548,4.0338197,,,,,,,,,,,,,, -84600,1.4051269,2.0873947,,,,,,,,,,,,,, -84700,1.4015574,3.167099,,,,,,,,,,,,,, -84800,1.4725026,2.2433627,,,,,,,,,,,,,, -84806,,,0.73451167345047,1.0335588455200195,0.6688599586486816,1.3441568613052368,50000.0,0.5470000505447388,2.004969358444214,10000.0,37854.68165183067,40677.22996091843,37854.68165183067,2813.217336177826,4.663280725479126,0.0 -84900,1.4652348,2.2140152,,,,,,,,,,,,,, -85000,1.2608519,4.1202855,,,,,,,,,,,,,, -85100,1.144923,3.6255713,,,,,,,,,,,,,, -85200,1.4124148,2.1160958,,,,,,,,,,,,,, -85300,1.2095101,2.943133,,,,,,,,,,,,,, -85400,1.381079,2.063579,,,,,,,,,,,,,, -85500,1.2468969,4.487299,,,,,,,,,,,,,, -85600,1.4211649,2.1418452,,,,,,,,,,,,,, -85700,1.3200619,1.9376307,,,,,,,,,,,,,, -85746,,,0.7251952886581421,1.0878881216049194,0.667419970035553,1.332442045211792,50000.0,0.541700005531311,2.016709089279175,10000.0,38275.05656552315,41130.54695153237,38275.05656552315,2846.062223434448,4.709845066070557,0.0 -85800,1.2647333,2.0854533,,,,,,,,,,,,,, -85900,1.3071842,2.3845885,,,,,,,,,,,,,, -86000,1.499263,4.360537,,,,,,,,,,,,,, -86100,1.5525417,2.5587888,,,,,,,,,,,,,, -86200,1.3978248,2.0320363,,,,,,,,,,,,,, -86300,1.4410923,2.553266,,,,,,,,,,,,,, -86400,1.2234759,2.8026133,,,,,,,,,,,,,, -86500,1.1217337,3.2288537,,,,,,,,,,,,,, -86600,1.6217942,2.099832,,,,,,,,,,,,,, -86688,,,0.7310351133346558,1.069846749305725,0.6744799613952637,1.320836067199707,50000.0,0.5494000315666199,1.985035419464112,10000.0,38695.34867835045,41588.12759113312,38695.34867835045,2883.2515251636505,4.756951808929443,0.0 -86700,1.3133488,2.5886686,,,,,,,,,,,,,, -86800,1.2392632,3.3389325,,,,,,,,,,,,,, -86900,1.2640495,2.7268252,,,,,,,,,,,,,, -87000,1.3981557,2.0137637,,,,,,,,,,,,,, -87100,1.4446678,2.1595614,,,,,,,,,,,,,, -87200,1.3860352,4.3372602,,,,,,,,,,,,,, -87300,1.5711714,2.0421982,,,,,,,,,,,,,, -87400,1.1992306,3.5509713,,,,,,,,,,,,,, -87500,1.3493954,2.0214665,,,,,,,,,,,,,, -87600,1.3273798,2.1901789,,,,,,,,,,,,,, -87633,,,0.7341992259025574,1.0544023513793943,0.6748200058937073,1.3344672918319702,50000.0,0.5503000020980835,1.9918979406356807,10000.0,39115.71409368515,42043.83431196213,39115.71409368515,2918.502597808838,4.796014308929443,0.0 -87700,1.3613039,3.9194024,,,,,,,,,,,,,, -87800,1.5004143,2.21563,,,,,,,,,,,,,, -87900,1.3389264,2.0231366,,,,,,,,,,,,,, -88000,1.5162123,4.280714,,,,,,,,,,,,,, -88100,1.4310368,2.0760071,,,,,,,,,,,,,, -88200,1.4423897,2.0787675,,,,,,,,,,,,,, -88300,1.383777,2.1764083,,,,,,,,,,,,,, -88400,1.4285631,2.1581242,,,,,,,,,,,,,, -88500,1.3216057,3.6090038,,,,,,,,,,,,,, -88575,,,0.7554491758346558,0.969889760017395,0.6746399998664856,1.3290221691131592,50000.0,0.5534000396728516,1.966261625289917,10000.0,39535.89856672287,42495.30137729645,39535.89856672287,2949.6885225772858,4.841732740402222,0.0 -88600,1.4884003,2.299627,,,,,,,,,,,,,, -88700,1.4601967,1.9879932,,,,,,,,,,,,,, -88800,1.602088,4.771986,,,,,,,,,,,,,, -88900,1.2738043,2.5190594,,,,,,,,,,,,,, -89000,1.469674,4.5335526,,,,,,,,,,,,,, -89100,1.3544636,2.0368915,,,,,,,,,,,,,, -89200,1.2800953,4.2977824,,,,,,,,,,,,,, -89300,1.455183,2.0421588,,,,,,,,,,,,,, -89400,1.5117997,2.1195123,,,,,,,,,,,,,, -89500,1.4159535,2.3692253,,,,,,,,,,,,,, -89509,,,0.7327538728713989,1.073253512382507,0.676580011844635,1.3220385313034058,50000.0,0.5515000224113464,1.975314736366272,10000.0,39956.18599700928,42947.557476997375,39956.18599700928,2981.5580892562866,4.89051628112793,0.0 -89600,1.4048063,2.159841,,,,,,,,,,,,,, -89700,1.4068931,3.7800713,,,,,,,,,,,,,, -89800,1.3915875,2.0601494,,,,,,,,,,,,,, -89900,1.4784746,1.9523938,,,,,,,,,,,,,, -90000,1.2788259,2.4236135,,,,,,,,,,,,,, -90100,1.4558185,2.0326746,,,,,,,,,,,,,, -90200,1.2744282,4.5457935,,,,,,,,,,,,,, -90300,1.4622859,2.0484307,,,,,,,,,,,,,, -90400,1.3106592,4.3893213,,,,,,,,,,,,,, -90446,,,0.7387109398841858,1.030994176864624,0.6740999817848206,1.3186286687850952,50000.0,0.5520000457763672,1.968473553657532,10000.0,40376.54059243202,43402.19970941544,40376.54059243202,3015.738451242447,4.947018384933472,0.0 -90500,1.4048046,1.949521,,,,,,,,,,,,,, -90600,1.4560075,2.1722386,,,,,,,,,,,,,, -90700,1.3221631,4.1651397,,,,,,,,,,,,,, -90800,1.2361009,3.4446044,,,,,,,,,,,,,, -90900,1.4387987,4.208662,,,,,,,,,,,,,, -91000,1.4077057,4.2952685,,,,,,,,,,,,,, -91100,1.4031587,2.1078777,,,,,,,,,,,,,, -91200,1.509679,1.9129367,,,,,,,,,,,,,, -91300,1.2517987,2.9788435,,,,,,,,,,,,,, -91389,,,0.7518945336341858,0.9972430467605592,0.6795799732208252,1.3187674283981323,50000.0,0.5552000403404236,1.9847722053527832,10000.0,40796.58492851257,43857.57451105118,40796.58492851257,3050.9701120853424,4.993818759918213,0.0 -91400,1.3966837,2.0609024,,,,,,,,,,,,,, -91500,1.3516003,2.1508842,,,,,,,,,,,,,, -91600,1.4633573,1.9383051,,,,,,,,,,,,,, -91700,1.3796778,2.3451023,,,,,,,,,,,,,, -91800,1.4082978,2.5368226,,,,,,,,,,,,,, -91900,1.456292,2.0717726,,,,,,,,,,,,,, -92000,1.4434254,2.0215898,,,,,,,,,,,,,, -92100,1.3691826,2.3206604,,,,,,,,,,,,,, -92200,1.367024,2.7567527,,,,,,,,,,,,,, -92300,1.4512798,2.0359504,,,,,,,,,,,,,, -92329,,,0.7335546612739563,1.0436562299728394,0.6787799596786499,1.2994449138641355,50000.0,0.5539000034332275,1.954208254814148,10000.0,41216.692452430725,44308.60667061806,41216.692452430725,3081.7949402332306,5.041177034378052,0.0 -92400,1.4178092,2.4085362,,,,,,,,,,,,,, -92500,1.2837272,3.0246453,,,,,,,,,,,,,, -92600,1.4296159,2.0709448,,,,,,,,,,,,,, -92700,1.4392757,4.353937,,,,,,,,,,,,,, -92800,1.3837689,2.6898727,,,,,,,,,,,,,, -92900,1.2627461,3.4683638,,,,,,,,,,,,,, -93000,1.5920657,2.1704733,,,,,,,,,,,,,, -93100,1.3678681,2.1366298,,,,,,,,,,,,,, -93200,1.4570155,2.0533578,,,,,,,,,,,,,, -93273,,,0.7421875,1.0251456499099731,0.6827200055122375,1.2972899675369265,50000.0,0.5590000152587891,1.941738963127136,10000.0,41636.87300825119,44761.32083892822,41636.87300825119,3114.198682308197,5.118264436721802,0.0 -93300,1.2989368,2.5105524,,,,,,,,,,,,,, -93400,1.5234628,2.2352235,,,,,,,,,,,,,, -93500,1.3693157,1.8952899,,,,,,,,,,,,,, -93600,1.5034097,1.9695967,,,,,,,,,,,,,, -93700,1.2471051,2.6081285,,,,,,,,,,,,,, -93800,1.5711291,2.1071048,,,,,,,,,,,,,, -93900,1.3996003,2.0531383,,,,,,,,,,,,,, -94000,1.1888452,3.7278323,,,,,,,,,,,,,, -94100,1.3254839,1.9766757,,,,,,,,,,,,,, -94200,1.3584268,4.4837646,,,,,,,,,,,,,, -94214,,,0.7447265386581421,1.0316696166992188,0.6801599860191345,1.314698576927185,50000.0,0.5564000010490417,1.9675493240356443,10000.0,42057.10592675209,45220.5447204113,42057.10592675209,3153.09051322937,5.165774345397949,0.0 -94300,1.3684001,2.9017866,,,,,,,,,,,,,, -94400,1.3985888,1.926034,,,,,,,,,,,,,, -94500,1.4968812,2.013299,,,,,,,,,,,,,, -94600,1.4736785,2.2940068,,,,,,,,,,,,,, -94700,1.5338291,2.119707,,,,,,,,,,,,,, -94800,1.3435879,3.7458346,,,,,,,,,,,,,, -94900,1.4366306,1.9616554,,,,,,,,,,,,,, -95000,1.2488577,3.318177,,,,,,,,,,,,,, -95100,1.3489952,2.929018,,,,,,,,,,,,,, -95160,,,0.7561132907867432,0.99065762758255,0.6864399909973145,1.299466252326965,50000.0,0.5582000017166138,1.948542833328247,10000.0,42477.334065675735,45673.6993329525,42477.334065675735,3185.913388967514,5.208800315856934,0.0 -95200,1.5102315,1.9750746,,,,,,,,,,,,,, -95300,1.5503436,2.0773916,,,,,,,,,,,,,, -95400,1.6931633,2.1340814,,,,,,,,,,,,,, -95500,1.4279488,3.9711592,,,,,,,,,,,,,, -95600,1.398384,3.8314264,,,,,,,,,,,,,, -95700,1.3837163,2.0263715,,,,,,,,,,,,,, -95800,1.4762746,3.9092698,,,,,,,,,,,,,, -95900,1.3948283,3.531043,,,,,,,,,,,,,, -96000,1.438886,2.0975604,,,,,,,,,,,,,, -96100,1.479334,1.9736054,,,,,,,,,,,,,, -96101,,,0.7415820360183716,1.0177706480026243,0.6850599646568298,1.278401017189026,50000.0,0.5559000372886658,1.9323203563690183,10000.0,42897.442705631256,46124.67051315308,42897.442705631256,3216.6718633174896,5.261148452758789,0.0 -96200,1.5519693,1.8533957,,,,,,,,,,,,,, -96300,1.5147305,2.8090599,,,,,,,,,,,,,, -96400,1.327529,3.2970002,,,,,,,,,,,,,, -96500,1.5525297,2.0103555,,,,,,,,,,,,,, -96600,1.5151515,2.230944,,,,,,,,,,,,,, -96700,1.3630174,3.6238995,,,,,,,,,,,,,, -96800,1.425744,2.0128129,,,,,,,,,,,,,, -96900,1.4269223,2.4982717,,,,,,,,,,,,,, -97000,1.5280539,2.012695,,,,,,,,,,,,,, -97040,,,0.7523437142372131,0.9721384048461914,0.6881399750709534,1.2559822797775269,50000.0,0.5623000264167786,1.909122347831726,10000.0,43317.70691943169,46577.83276820183,43317.70691943169,3249.469914674759,5.307514429092407,0.0 -97100,1.2997711,3.7776542,,,,,,,,,,,,,, -97200,1.6415662,2.0454943,,,,,,,,,,,,,, -97300,1.523611,1.985525,,,,,,,,,,,,,, -97400,1.3602417,1.9594694,,,,,,,,,,,,,, -97500,1.2843845,2.6263053,,,,,,,,,,,,,, -97600,1.406737,4.467454,,,,,,,,,,,,,, -97700,1.4033705,3.922847,,,,,,,,,,,,,, -97800,1.5668501,2.041391,,,,,,,,,,,,,, -97900,1.3857789,2.5231671,,,,,,,,,,,,,, -97981,,,0.7609765529632568,0.9368932247161864,0.6894800066947937,1.2643203735351562,50000.0,0.5659000277519226,1.902812004089356,10000.0,43738.02734136581,47032.36952018738,43738.02734136581,3283.58735871315,5.355186462402344,0.0 -98000,1.3785809,3.0924144,,,,,,,,,,,,,, -98100,1.6878247,1.9717491,,,,,,,,,,,,,, -98200,1.3265847,2.8327453,,,,,,,,,,,,,, -98300,1.4375175,2.1203122,,,,,,,,,,,,,, -98400,1.4115229,2.1640444,,,,,,,,,,,,,, -98500,1.4799687,1.9216591,,,,,,,,,,,,,, -98600,1.6666863,2.792562,,,,,,,,,,,,,, -98700,1.4297231,2.9944797,,,,,,,,,,,,,, -98800,1.5442222,1.9775598,,,,,,,,,,,,,, -98900,1.3665305,1.7916393,,,,,,,,,,,,,, -98923,,,0.7438867092132568,0.993510365486145,0.6860799789428711,1.261129379272461,50000.0,0.5624000430107117,1.9388071298599243,10000.0,44158.18087887764,47485.14091467857,44158.18087887764,3316.1046471595764,5.404359579086304,0.0 -99000,1.6001945,2.054056,,,,,,,,,,,,,, -99100,1.4071605,2.8814857,,,,,,,,,,,,,, -99200,1.6385834,1.8849773,,,,,,,,,,,,,, -99300,1.4684067,2.007504,,,,,,,,,,,,,, -99400,1.5365359,2.005153,,,,,,,,,,,,,, -99500,1.4656099,1.9585141,,,,,,,,,,,,,, -99600,1.3683331,2.1695702,,,,,,,,,,,,,, -99700,1.3437386,3.205709,,,,,,,,,,,,,, -99800,1.3157207,2.065089,,,,,,,,,,,,,, -99867,,,0.7521874904632568,0.9745285511016846,0.6863200068473816,1.2652863264083862,50000.0,0.5617000460624695,1.9162745475769043,10000.0,44578.10844540596,47938.01518249512,44578.10844540596,3348.923852443695,5.479732513427734,0.0 -99900,1.5378511,1.953209,,,,,,,,,,,,,, -100000,1.4766095,1.828566,,,,,,,,,,,,,, -100100,1.3128518,2.9208884,,,,,,,,,,,,,, -100200,1.5048786,4.3658223,,,,,,,,,,,,,, -100300,1.4044819,4.0783596,,,,,,,,,,,,,, -100400,1.3298395,2.7437272,,,,,,,,,,,,,, -100500,1.5204369,1.932441,,,,,,,,,,,,,, -100600,1.4601932,1.9899079,,,,,,,,,,,,,, -100700,1.2517619,3.2280865,,,,,,,,,,,,,, -100800,1.5550933,1.9550023,,,,,,,,,,,,,, -100809,,,0.7562499642372131,0.9649302363395692,0.6868999600410461,1.2645305395126345,50000.0,0.5648000240325928,1.908398151397705,10000.0,44998.0839908123,48390.72035264969,44998.0839908123,3381.551055908203,5.530679225921631,0.0 -100900,1.7123733,4.4429226,,,,,,,,,,,,,, -101000,1.3758663,3.219933,,,,,,,,,,,,,, -101100,1.4427062,2.0110369,,,,,,,,,,,,,, -101200,1.4526935,2.7510557,,,,,,,,,,,,,, -101300,1.5386453,1.7982558,,,,,,,,,,,,,, -101400,1.5401763,1.8968346,,,,,,,,,,,,,, -101500,1.4008493,1.9098864,,,,,,,,,,,,,, -101600,1.5076815,2.950335,,,,,,,,,,,,,, -101700,1.4639382,4.047254,,,,,,,,,,,,,, -101754,,,0.75830078125,0.965836763381958,0.6916399598121643,1.2648234367370603,50000.0,0.5596000552177429,1.9227352142333984,10000.0,45418.158311128616,48844.14165306091,45418.158311128616,3414.7790591716766,5.598333120346069,0.0 -101800,1.5230287,1.9011862,,,,,,,,,,,,,, -101900,1.3389351,1.9509594,,,,,,,,,,,,,, -102000,1.5682691,3.839995,,,,,,,,,,,,,, -102100,1.8016424,1.8802292,,,,,,,,,,,,,, -102200,1.5074332,1.9469504,,,,,,,,,,,,,, -102300,1.517509,2.1738007,,,,,,,,,,,,,, -102400,1.5449642,1.9004445,,,,,,,,,,,,,, -102500,1.4520046,1.9056956,,,,,,,,,,,,,, -102600,1.5058917,1.8906895,,,,,,,,,,,,,, -102695,,,0.7544921636581421,0.9744990468025208,0.6915799975395203,1.252241611480713,50000.0,0.5699000358581543,1.8935141563415527,10000.0,45838.346932172775,49302.75633502007,45838.346932172775,3453.1070244312286,5.645540475845337,0.0 -102700,1.4526228,1.9199978,,,,,,,,,,,,,, -102800,1.434959,2.8753366,,,,,,,,,,,,,, -102900,1.5656965,1.893378,,,,,,,,,,,,,, -103000,1.7393088,2.3519056,,,,,,,,,,,,,, -103100,1.4283473,2.0005581,,,,,,,,,,,,,, -103200,1.3818908,4.2473364,,,,,,,,,,,,,, -103300,1.3736557,2.6525996,,,,,,,,,,,,,, -103400,1.4545535,3.3880086,,,,,,,,,,,,,, -103500,1.3118495,3.205206,,,,,,,,,,,,,, -103600,1.4754299,1.8682215,,,,,,,,,,,,,, -103641,,,0.76039057970047,0.9307794570922852,0.6939799785614014,1.2324928045272827,50000.0,0.5738000273704529,1.8720823526382449,10000.0,46258.4792034626,49755.083996772766,46258.4792034626,3485.2028307914734,5.693195819854736,0.0 -103700,1.5332284,4.4810295,,,,,,,,,,,,,, -103800,1.5952857,1.9339354,,,,,,,,,,,,,, -103900,1.4590095,2.1769388,,,,,,,,,,,,,, -104000,1.488824,3.1573935,,,,,,,,,,,,,, -104100,1.6079475,1.9425336,,,,,,,,,,,,,, -104200,1.596848,1.8645921,,,,,,,,,,,,,, -104300,1.6658478,1.8216634,,,,,,,,,,,,,, -104400,1.5560658,1.932699,,,,,,,,,,,,,, -104500,1.5472758,4.225277,,,,,,,,,,,,,, -104581,,,0.7684960961341858,0.922848641872406,0.6940000057220459,1.2535018920898438,50000.0,0.5644000172615051,1.9141690731048584,10000.0,46678.42165637016,50212.28129172325,46678.42165637016,3522.3583607673645,5.741758108139038,0.0 -104600,1.5729545,3.956224,,,,,,,,,,,,,, -104700,1.4203076,3.6000586,,,,,,,,,,,,,, -104800,1.4037726,3.1057174,,,,,,,,,,,,,, -104900,1.5567774,3.8281689,,,,,,,,,,,,,, -105000,1.5779244,1.9239122,,,,,,,,,,,,,, -105100,1.3966739,3.500854,,,,,,,,,,,,,, -105200,1.5857956,1.8894706,,,,,,,,,,,,,, -105300,1.5970147,2.0066774,,,,,,,,,,,,,, -105400,1.5774928,4.239667,,,,,,,,,,,,,, -105500,1.5667701,2.0888443,,,,,,,,,,,,,, -105522,,,0.7603710889816284,0.9424856901168824,0.6970199942588806,1.2145726680755615,50000.0,0.5731000304222107,1.8652210235595703,10000.0,47098.45361089706,50663.58996009827,47098.45361089706,3553.533600568772,5.791807174682617,0.0 -105600,1.424687,2.1663702,,,,,,,,,,,,,, -105700,1.3760774,3.4641771,,,,,,,,,,,,,, -105800,1.4893858,3.3360648,,,,,,,,,,,,,, -105900,1.4906585,1.7747684,,,,,,,,,,,,,, -106000,1.5953563,1.9799328,,,,,,,,,,,,,, -106100,1.4359168,1.799452,,,,,,,,,,,,,, -106200,1.6608937,1.9079103,,,,,,,,,,,,,, -106300,1.5975827,1.9222242,,,,,,,,,,,,,, -106400,1.5613098,1.9285967,,,,,,,,,,,,,, -106459,,,0.7631054520606995,0.9237319231033324,0.6990599632263184,1.2113115787506104,50000.0,0.5690000057220459,1.8761636018753047,10000.0,47518.67924666405,51118.17361497879,47518.67924666405,3587.786563158036,5.8463099002838135,0.0 -106500,1.400387,2.684729,,,,,,,,,,,,,, -106600,1.5889143,1.8678427,,,,,,,,,,,,,, -106700,1.6500068,3.9895124,,,,,,,,,,,,,, -106800,1.5564264,1.865037,,,,,,,,,,,,,, -106900,1.5670161,1.8745677,,,,,,,,,,,,,, -107000,1.5333103,1.8470323,,,,,,,,,,,,,, -107100,1.6311542,4.2203436,,,,,,,,,,,,,, -107200,1.3667128,3.7365737,,,,,,,,,,,,,, -107300,1.3578843,2.5852995,,,,,,,,,,,,,, -107399,,,0.7676562070846558,0.9086039662361144,0.6971399784088135,1.2299054861068726,50000.0,0.5730000138282776,1.8663215637207031,10000.0,47938.64813065529,51571.591633319855,47938.64813065529,3621.135687589645,5.89500904083252,0.0 -107400,1.4788455,1.8489546,,,,,,,,,,,,,, -107500,1.5251296,1.9296749,,,,,,,,,,,,,, -107600,1.4841001,1.9268509,,,,,,,,,,,,,, -107700,1.5753232,2.0168984,,,,,,,,,,,,,, -107800,1.5337553,1.9041111,,,,,,,,,,,,,, -107900,1.6510369,1.9482244,,,,,,,,,,,,,, -108000,1.6656399,1.9981655,,,,,,,,,,,,,, -108100,1.6250595,1.9310138,,,,,,,,,,,,,, -108200,1.5499232,1.8336557,,,,,,,,,,,,,, -108300,1.4521551,3.6501372,,,,,,,,,,,,,, -108342,,,0.7684179544448853,0.91727477312088,0.7032999992370605,1.2101258039474487,50000.0,0.5802000164985657,1.8436038494110107,10000.0,48358.62239551544,52023.73527216911,48358.62239551544,3653.203928470612,5.945183038711548,0.0 -108400,1.5135496,3.0729163,,,,,,,,,,,,,, -108500,1.7193401,2.2093866,,,,,,,,,,,,,, -108600,1.4062567,2.1799347,,,,,,,,,,,,,, -108700,1.5657979,1.8836476,,,,,,,,,,,,,, -108800,1.4271272,2.5636454,,,,,,,,,,,,,, -108900,1.4197092,2.8985853,,,,,,,,,,,,,, -109000,1.4002309,2.959872,,,,,,,,,,,,,, -109100,1.5888661,1.8709422,,,,,,,,,,,,,, -109200,1.5411772,4.2624655,,,,,,,,,,,,,, -109283,,,0.7674413919448853,0.9051960706710817,0.7060799598693848,1.1832246780395508,50000.0,0.5815000534057617,1.822532296180725,10000.0,48778.72830224037,52477.80788874626,48778.72830224037,3687.069760560989,5.995041608810425,0.0 -109300,1.6714156,1.927991,,,,,,,,,,,,,, -109400,1.5621926,3.6155086,,,,,,,,,,,,,, -109500,1.4108788,2.4327147,,,,,,,,,,,,,, -109600,1.6124934,1.8598056,,,,,,,,,,,,,, -109700,1.503239,1.7919436,,,,,,,,,,,,,, -109800,1.7573909,1.8357625,,,,,,,,,,,,,, -109900,1.7835659,2.017867,,,,,,,,,,,,,, -110000,1.5023085,3.0244408,,,,,,,,,,,,,, -110100,1.5264643,3.7938368,,,,,,,,,,,,,, -110200,1.6842359,2.2322547,,,,,,,,,,,,,, -110223,,,0.77357417345047,0.8837026953697205,0.7024399638175964,1.1945924758911133,50000.0,0.5778000354766846,1.833355188369751,10000.0,49198.84672832489,52930.48039579392,49198.84672832489,3719.524533748626,6.042582511901856,0.0 -110300,1.5087765,2.0369015,,,,,,,,,,,,,, -110400,1.447079,3.7544544,,,,,,,,,,,,,, -110500,1.5289273,4.040761,,,,,,,,,,,,,, -110600,1.5439826,1.770029,,,,,,,,,,,,,, -110700,1.7664026,4.192935,,,,,,,,,,,,,, -110800,1.6181775,1.8702765,,,,,,,,,,,,,, -110900,1.6815573,4.3700986,,,,,,,,,,,,,, -111000,1.5523455,4.2974014,,,,,,,,,,,,,, -111100,1.5740521,3.6554039,,,,,,,,,,,,,, -111164,,,0.7866796851158142,0.8273036479949951,0.7063199877738953,1.1831731796264648,50000.0,0.5790000557899475,1.8291984796524048,10000.0,49619.10298538208,53385.648253917694,49619.10298538208,3754.3326518535614,6.095505475997925,0.0 -111200,1.5797335,4.3070545,,,,,,,,,,,,,, -111300,1.6994725,1.7585373,,,,,,,,,,,,,, -111400,1.6873722,2.002111,,,,,,,,,,,,,, -111500,1.5108728,3.5448904,,,,,,,,,,,,,, -111600,1.4437648,2.9406002,,,,,,,,,,,,,, -111700,1.628902,1.9036598,,,,,,,,,,,,,, -111800,1.5945262,1.8843588,,,,,,,,,,,,,, -111900,1.4189612,2.6545002,,,,,,,,,,,,,, -112000,1.5408946,3.3887367,,,,,,,,,,,,,, -112098,,,0.7661523222923279,0.9525970816612244,0.7068399786949158,1.217895746231079,50000.0,0.5746000409126282,1.8687307834625244,10000.0,50039.17558908463,53842.93714237213,50039.17558908463,3791.444357633591,6.148894309997559,0.0 -112100,1.5433568,4.137821,,,,,,,,,,,,,, -112200,1.7875526,1.7751025,,,,,,,,,,,,,, -112300,1.6992294,4.169225,,,,,,,,,,,,,, -112400,1.748599,1.8435305,,,,,,,,,,,,,, -112500,1.7016863,1.7625612,,,,,,,,,,,,,, -112600,1.6934847,1.7723496,,,,,,,,,,,,,, -112700,1.6490208,1.8975697,,,,,,,,,,,,,, -112800,1.4610436,2.9959943,,,,,,,,,,,,,, -112900,1.5843395,1.7602742,,,,,,,,,,,,,, -113000,1.414266,2.8694272,,,,,,,,,,,,,, -113041,,,0.77685546875,0.892575740814209,0.708139955997467,1.1836878061294556,50000.0,0.5866000056266785,1.8236979246139529,10000.0,50459.45133137703,54296.39494585991,50459.45133137703,3824.517323732376,6.20680570602417,0.0 -113100,1.5401605,3.801947,,,,,,,,,,,,,, -113200,1.5068012,3.7443027,,,,,,,,,,,,,, -113300,1.5983894,1.7393873,,,,,,,,,,,,,, -113400,1.5572051,3.4572961,,,,,,,,,,,,,, -113500,1.7290205,1.8329711,,,,,,,,,,,,,, -113600,1.4621155,2.3686736,,,,,,,,,,,,,, -113700,1.52354,2.3499289,,,,,,,,,,,,,, -113800,1.690401,4.185704,,,,,,,,,,,,,, -113900,1.6731174,4.298832,,,,,,,,,,,,,, -113982,,,0.78724604845047,0.8162120580673218,0.7078999876976013,1.162050485610962,50000.0,0.5873000025749207,1.8146487474441528,10000.0,50879.65846824646,54755.30799412727,50879.65846824646,3863.1170892715454,6.262190341949463,0.0 -114000,1.586001,2.4176073,,,,,,,,,,,,,, -114100,1.6336955,1.7124047,,,,,,,,,,,,,, -114200,1.5022185,2.1513777,,,,,,,,,,,,,, -114300,1.725867,1.7763795,,,,,,,,,,,,,, -114400,1.6709968,1.9494374,,,,,,,,,,,,,, -114500,1.5325161,3.1829407,,,,,,,,,,,,,, -114600,1.7382648,1.8854133,,,,,,,,,,,,,, -114700,1.7638075,1.720125,,,,,,,,,,,,,, -114800,1.6259711,1.659574,,,,,,,,,,,,,, -114900,1.6581753,4.0685387,,,,,,,,,,,,,, -114924,,,0.7758007645606995,0.8787176609039307,0.7107399702072144,1.1633338928222656,50000.0,0.5843999981880188,1.8155479431152344,10000.0,51299.89177918434,55209.74253320694,51299.89177918434,3897.221100330353,6.308267831802368,0.0 -115000,1.7826043,1.9003284,,,,,,,,,,,,,, -115100,1.7331783,1.802737,,,,,,,,,,,,,, -115200,1.4917527,3.2160926,,,,,,,,,,,,,, -115300,1.58561,3.670326,,,,,,,,,,,,,, -115400,1.6962625,1.7947905,,,,,,,,,,,,,, -115500,1.862955,4.23993,,,,,,,,,,,,,, -115600,1.4495026,3.267913,,,,,,,,,,,,,, -115700,1.7095295,2.3590317,,,,,,,,,,,,,, -115800,1.5324546,3.0931692,,,,,,,,,,,,,, -115864,,,0.7816015481948853,0.8619747757911682,0.7142999768257141,1.1624583005905151,50000.0,0.5877000093460083,1.806920051574707,10000.0,51720.268078804016,55662.47427988053,51720.268078804016,3929.47137761116,6.3619771003723145,0.0 -115900,1.789439,1.7880068,,,,,,,,,,,,,, -116000,1.70438,1.8690912,,,,,,,,,,,,,, -116100,1.6390431,2.3751132,,,,,,,,,,,,,, -116200,1.6088128,3.7393994,,,,,,,,,,,,,, -116300,1.6982276,1.9000762,,,,,,,,,,,,,, -116400,1.7225531,3.6723962,,,,,,,,,,,,,, -116500,1.6112709,2.2561321,,,,,,,,,,,,,, -116600,1.6675627,1.7950565,,,,,,,,,,,,,, -116700,1.4382428,2.269723,,,,,,,,,,,,,, -116800,1.706426,1.7113917,,,,,,,,,,,,,, -116803,,,0.7881640195846558,0.8116686940193176,0.7134599685668945,1.1392488479614258,50000.0,0.5905000567436218,1.7739776372909546,10000.0,52140.46926808357,56118.18588399887,52140.46926808357,3964.8775465488434,6.414790630340576,0.0 -116900,1.6922643,4.207901,,,,,,,,,,,,,, -117000,1.5870115,1.7247107,,,,,,,,,,,,,, -117100,1.539365,2.029422,,,,,,,,,,,,,, -117200,1.6937904,3.9821517,,,,,,,,,,,,,, -117300,1.560734,2.4910438,,,,,,,,,,,,,, -117400,1.6209396,1.9319962,,,,,,,,,,,,,, -117500,1.6961459,1.7570543,,,,,,,,,,,,,, -117600,1.5512027,2.0601404,,,,,,,,,,,,,, -117700,1.6523063,4.0671225,,,,,,,,,,,,,, -117743,,,0.7998046875,0.7883418202400208,0.7152000069618225,1.1490259170532229,50000.0,0.5903000235557556,1.7920637130737305,10000.0,52560.747259140015,56571.47288775444,52560.747259140015,3997.784056663513,6.465289831161499,0.0 -117800,1.5571452,2.1408389,,,,,,,,,,,,,, -117900,1.6055236,1.6903237,,,,,,,,,,,,,, -118000,1.5662403,2.5204751,,,,,,,,,,,,,, -118100,1.7062004,1.7422321,,,,,,,,,,,,,, -118200,1.7337551,1.7881844,,,,,,,,,,,,,, -118300,1.6482029,1.7016042,,,,,,,,,,,,,, -118400,1.5657765,2.7698917,,,,,,,,,,,,,, -118500,1.6183335,2.0186238,,,,,,,,,,,,,, -118600,1.7088039,1.726295,,,,,,,,,,,,,, -118679,,,0.7859179377555847,0.825340986251831,0.7155599594116211,1.1319674253463743,50000.0,0.5915000438690186,1.7688794136047363,10000.0,52980.7577688694,57023.66591095925,52980.7577688694,4029.862047433853,6.51940655708313,0.0 -118700,1.903292,1.7581491,,,,,,,,,,,,,, -118800,1.6510608,1.9921055,,,,,,,,,,,,,, -118900,1.4537841,2.5479453,,,,,,,,,,,,,, -119000,1.7849907,1.6505892,,,,,,,,,,,,,, -119100,1.5395222,2.1731114,,,,,,,,,,,,,, -119200,1.6415405,2.6105566,,,,,,,,,,,,,, -119300,1.6737747,1.783515,,,,,,,,,,,,,, -119400,1.7688041,1.7825199,,,,,,,,,,,,,, -119500,1.6274045,3.9525158,,,,,,,,,,,,,, -119600,1.6371186,1.9305812,,,,,,,,,,,,,, -119619,,,0.7870898246765137,0.8163532614707947,0.7176799774169922,1.1236371994018557,50000.0,0.5887000560760498,1.7725476026535034,10000.0,53401.04287528992,57477.70270061493,53401.04287528992,4063.4930033683777,6.587791919708252,0.0 -119700,1.7236392,2.1776376,,,,,,,,,,,,,, -119800,1.50931,2.3727386,,,,,,,,,,,,,, -119900,1.7278491,1.9017109,,,,,,,,,,,,,, -120000,1.6553416,3.561149,,,,,,,,,,,,,, -120100,1.7148914,1.7847455,,,,,,,,,,,,,, -120200,1.7658005,1.7213283,,,,,,,,,,,,,, -120300,1.8116151,1.7906648,,,,,,,,,,,,,, -120400,1.7343295,2.041704,,,,,,,,,,,,,, -120500,1.9120954,1.7531419,,,,,,,,,,,,,, -120558,,,0.7944140434265137,0.8256067633628845,0.7125999927520752,1.1662746667861938,50000.0,0.5924000144004822,1.7938237190246582,10000.0,53820.95779657364,57937.69656896591,53820.95779657364,4103.463726997376,6.644713640213013,0.0 -120600,1.7350316,2.4185505,,,,,,,,,,,,,, -120700,1.724499,1.789857,,,,,,,,,,,,,, -120800,1.5839775,2.9139469,,,,,,,,,,,,,, -120900,1.6808511,1.6541106,,,,,,,,,,,,,, -121000,1.8052438,1.9782479,,,,,,,,,,,,,, -121100,1.7009329,1.7000158,,,,,,,,,,,,,, -121200,1.9053537,1.8295805,,,,,,,,,,,,,, -121300,1.7770221,1.8376712,,,,,,,,,,,,,, -121400,1.6594592,1.6443661,,,,,,,,,,,,,, -121500,1.849317,1.7783267,,,,,,,,,,,,,, -121504,,,0.78466796875,0.866515576839447,0.7185399532318115,1.1521672010421753,50000.0,0.5913000106811523,1.7923246622085571,10000.0,54240.93333983421,58391.9005625248,54240.93333983421,4137.593298435211,6.69199538230896,0.0 -121600,1.7063042,3.8777595,,,,,,,,,,,,,, -121700,1.6362002,3.7986252,,,,,,,,,,,,,, -121800,1.9294525,4.039221,,,,,,,,,,,,,, -121900,1.5772029,2.576149,,,,,,,,,,,,,, -122000,1.8961365,4.0209503,,,,,,,,,,,,,, -122100,1.6772189,1.6976961,,,,,,,,,,,,,, -122200,1.7264436,1.7295821,,,,,,,,,,,,,, -122300,1.6397656,2.7587805,,,,,,,,,,,,,, -122400,1.7409593,1.725064,,,,,,,,,,,,,, -122443,,,0.7889843583106995,0.8201755285263062,0.7201600074768066,1.1214879751205444,50000.0,0.5937000513076782,1.7646859884262085,10000.0,54660.91382360458,58844.77633571625,54660.91382360458,4170.385221481323,6.743417024612427,0.0 -122500,1.8169512,1.8960502,,,,,,,,,,,,,, -122600,1.7185509,1.9523219,,,,,,,,,,,,,, -122700,1.6453134,2.3623717,,,,,,,,,,,,,, -122800,1.7307388,1.7352899,,,,,,,,,,,,,, -122900,1.6860145,1.6347378,,,,,,,,,,,,,, -123000,1.7842981,1.8321975,,,,,,,,,,,,,, -123100,1.7064117,3.3320882,,,,,,,,,,,,,, -123200,1.8396027,1.6625692,,,,,,,,,,,,,, -123300,1.6973097,2.5135179,,,,,,,,,,,,,, -123381,,,0.8001562356948853,0.7793758511543274,0.7241399884223938,1.10870361328125,50000.0,0.5998000502586365,1.7507137060165403,10000.0,55081.21804690361,59299.64342093468,55081.21804690361,4204.8427193164825,6.7984137535095215,0.0 -123400,1.8069435,1.5141249,,,,,,,,,,,,,, -123500,1.4792093,2.410326,,,,,,,,,,,,,, -123600,1.74337,3.0919785,,,,,,,,,,,,,, -123700,1.544077,2.4404354,,,,,,,,,,,,,, -123800,1.8015172,3.7637057,,,,,,,,,,,,,, -123900,1.6969503,2.280174,,,,,,,,,,,,,, -124000,1.8388788,2.4986084,,,,,,,,,,,,,, -124100,1.8121017,3.1351442,,,,,,,,,,,,,, -124200,1.8967943,1.5505099,,,,,,,,,,,,,, -124300,1.8350445,1.8896735,,,,,,,,,,,,,, -124320,,,0.8098242282867432,0.7531871795654297,0.720579981803894,1.1280159950256348,50000.0,0.5972000360488892,1.7637840509414673,10000.0,55501.22751235962,59763.037217378616,55501.22751235962,4248.118583202362,6.855582475662232,0.0 -124400,1.648699,2.9243693,,,,,,,,,,,,,, -124500,1.8767884,3.9982564,,,,,,,,,,,,,, -124600,1.7851505,3.2782526,,,,,,,,,,,,,, -124700,1.7240181,3.0518653,,,,,,,,,,,,,, -124800,1.8829144,2.8794522,,,,,,,,,,,,,, -124900,1.7089758,1.8883386,,,,,,,,,,,,,, -125000,1.836279,2.0298297,,,,,,,,,,,,,, -125100,1.6276493,2.2161455,,,,,,,,,,,,,, -125200,1.8170025,3.285377,,,,,,,,,,,,,, -125263,,,0.7971484065055847,0.7829906344413757,0.7261599898338318,1.0861806869506836,50000.0,0.5982000231742859,1.7279510498046875,10000.0,55921.260788440704,60215.49275445938,55921.260788440704,4280.447121620178,6.898099422454834,0.0 -125300,1.6223688,2.5961685,,,,,,,,,,,,,, -125400,1.776041,1.5932921,,,,,,,,,,,,,, -125500,1.9124886,1.7774608,,,,,,,,,,,,,, -125600,1.9408538,3.7847922,,,,,,,,,,,,,, -125700,1.6849477,1.6966372,,,,,,,,,,,,,, -125800,1.992713,1.6131756,,,,,,,,,,,,,, -125900,1.6236633,2.7154796,,,,,,,,,,,,,, -126000,1.8012586,1.802148,,,,,,,,,,,,,, -126100,1.7929991,1.5972213,,,,,,,,,,,,,, -126199,,,0.802539050579071,0.7490731477737427,0.7263999581336975,1.0830198526382446,50000.0,0.6076000332832336,1.6989731788635254,10000.0,56341.26990056038,60668.204579114914,56341.26990056038,4313.0392434597015,6.958737850189209,0.0 -126200,1.9356067,3.8871868,,,,,,,,,,,,,, -126300,1.8833009,1.7498927,,,,,,,,,,,,,, -126400,1.7543988,2.9998791,,,,,,,,,,,,,, -126500,1.784324,1.6565539,,,,,,,,,,,,,, -126600,1.782605,1.8723825,,,,,,,,,,,,,, -126700,1.7285594,2.834652,,,,,,,,,,,,,, -126800,1.8124917,1.4952443,,,,,,,,,,,,,, -126900,1.7705164,1.6467469,,,,,,,,,,,,,, -127000,1.7247286,2.02794,,,,,,,,,,,,,, -127100,1.7608329,2.9913383,,,,,,,,,,,,,, -127137,,,0.8118359446525574,0.7276731133460999,0.7282399535179138,1.081851363182068,50000.0,0.6055999994277954,1.7176285982131958,10000.0,56761.53534936905,61125.58907032013,56761.53534936905,4350.054116725922,7.012200832366943,0.0 -127200,1.9031651,1.5585557,,,,,,,,,,,,,, -127300,1.6969585,2.2242184,,,,,,,,,,,,,, -127400,1.8847336,3.075008,,,,,,,,,,,,,, -127500,1.7494557,1.781187,,,,,,,,,,,,,, -127600,1.6617498,3.2986298,,,,,,,,,,,,,, -127700,1.7162467,3.4835677,,,,,,,,,,,,,, -127800,1.9219048,1.6335918,,,,,,,,,,,,,, -127900,1.9214984,1.649625,,,,,,,,,,,,,, -128000,1.9565723,2.697335,,,,,,,,,,,,,, -128076,,,0.7968358993530273,0.7937740087509155,0.7278199791908264,1.0928000211715698,50000.0,0.603600025177002,1.7354732751846311,10000.0,57181.63070321083,61578.02355456352,57181.63070321083,4382.287897109985,7.066992282867432,0.0 -128100,2.0671394,2.4560664,,,,,,,,,,,,,, -128200,1.6946738,3.2308342,,,,,,,,,,,,,, -128300,1.882576,1.6265155,,,,,,,,,,,,,, -128400,2.2133636,3.1526663,,,,,,,,,,,,,, -128500,1.9370706,1.5583117,,,,,,,,,,,,,, -128600,1.8973253,1.5475186,,,,,,,,,,,,,, -128700,2.17205,1.6452795,,,,,,,,,,,,,, -128800,1.8700156,1.861767,,,,,,,,,,,,,, -128900,1.8530352,1.5665374,,,,,,,,,,,,,, -129000,1.9074844,2.9627829,,,,,,,,,,,,,, -129015,,,0.8051171898841858,0.7775402665138245,0.7311999797821045,1.1002105474472046,50000.0,0.6086000204086304,1.7183754444122314,10000.0,57601.90099310875,62033.05067133904,57601.90099310875,4416.935405731201,7.12395167350769,0.0 -129100,1.61197,2.3959174,,,,,,,,,,,,,, -129200,1.8759706,3.1727746,,,,,,,,,,,,,, -129300,1.9327592,3.0431216,,,,,,,,,,,,,, -129400,2.1208837,3.9761627,,,,,,,,,,,,,, -129500,1.9967858,3.7619028,,,,,,,,,,,,,, -129600,1.7972451,3.4298065,,,,,,,,,,,,,, -129700,1.9600159,1.6241959,,,,,,,,,,,,,, -129800,1.8068432,3.4576976,,,,,,,,,,,,,, -129900,1.8199896,3.5039442,,,,,,,,,,,,,, -129956,,,0.8118359446525574,0.7209511995315552,0.7298399806022644,1.0690104961395264,50000.0,0.6084000468254089,1.6949561834335327,10000.0,58022.18709683418,62487.10815668106,58022.18709683418,4450.6000871658325,7.179595232009888,0.0 -130000,1.7988799,3.37854,,,,,,,,,,,,,, -130100,1.7363364,1.918321,,,,,,,,,,,,,, -130200,1.7408713,2.6948698,,,,,,,,,,,,,, -130300,1.7391471,2.2634802,,,,,,,,,,,,,, -130400,1.8265183,1.6117904,,,,,,,,,,,,,, -130500,1.7453866,1.8682728,,,,,,,,,,,,,, -130600,2.0554316,1.7604659,,,,,,,,,,,,,, -130700,1.9974017,1.6452262,,,,,,,,,,,,,, -130800,1.9631444,1.727907,,,,,,,,,,,,,, -130898,,,0.8252343535423279,0.6869196891784668,0.7318999767303467,1.0726889371871948,50000.0,0.6083000302314758,1.700702667236328,10000.0,58442.412564754486,62939.74114704132,58442.412564754486,4482.868451356888,7.265031337738037,0.0 -130900,1.7966968,2.9494407,,,,,,,,,,,,,, -131000,1.8200307,1.6840984,,,,,,,,,,,,,, -131100,1.8876401,2.5041122,,,,,,,,,,,,,, -131200,1.9889399,1.6696737,,,,,,,,,,,,,, -131300,2.0887945,3.7841344,,,,,,,,,,,,,, -131400,1.8514072,1.633085,,,,,,,,,,,,,, -131500,1.8714706,1.7483816,,,,,,,,,,,,,, -131600,1.9973037,3.495506,,,,,,,,,,,,,, -131700,2.0382252,1.6500503,,,,,,,,,,,,,, -131800,1.8518218,1.5997565,,,,,,,,,,,,,, -131835,,,0.8069726228713989,0.7426121234893799,0.7356399893760681,1.055567502975464,50000.0,0.6165000200271606,1.6794337034225464,10000.0,58862.37831878662,63392.63826608658,58862.37831878662,4515.693987369537,7.319501161575317,0.0 -131900,1.9269629,1.6205728,,,,,,,,,,,,,, -132000,1.9955909,1.6561289,,,,,,,,,,,,,, -132100,1.9218115,1.6143934,,,,,,,,,,,,,, -132200,2.0541773,3.2615135,,,,,,,,,,,,,, -132300,1.8162062,1.7487082,,,,,,,,,,,,,, -132400,1.7207636,2.8154912,,,,,,,,,,,,,, -132500,1.9362327,1.6131977,,,,,,,,,,,,,, -132600,1.8679111,2.010462,,,,,,,,,,,,,, -132700,2.0561588,1.6003844,,,,,,,,,,,,,, -132771,,,0.8157616853713989,0.7002706527709961,0.7349199652671814,1.0541627407073977,50000.0,0.6173000335693359,1.6698813438415527,10000.0,59282.7042388916,63850.35701370239,59282.7042388916,4552.977478265762,7.37714958190918,0.0 -132800,1.9396183,1.5572383,,,,,,,,,,,,,, -132900,1.7039245,1.8610697,,,,,,,,,,,,,, -133000,1.6493167,2.3456805,,,,,,,,,,,,,, -133100,2.054145,1.4721634,,,,,,,,,,,,,, -133200,1.9474665,1.5437886,,,,,,,,,,,,,, -133300,1.9817331,1.7027845,,,,,,,,,,,,,, -133400,2.0172992,1.5790231,,,,,,,,,,,,,, -133500,1.8289788,2.8022728,,,,,,,,,,,,,, -133600,1.8277363,1.6226865,,,,,,,,,,,,,, -133700,1.9273534,2.5964556,,,,,,,,,,,,,, -133714,,,0.8212499618530273,0.6917605996131897,0.7364400029182434,1.0647858381271362,50000.0,0.6152000427246094,1.6732351779937744,10000.0,59702.83695149422,64309.51987743378,59702.83695149422,4591.900043010712,7.433608055114746,0.0 -133800,2.020055,1.6152966,,,,,,,,,,,,,, -133900,1.999946,1.9039792,,,,,,,,,,,,,, -134000,1.8790288,3.346527,,,,,,,,,,,,,, -134100,1.7702433,2.270504,,,,,,,,,,,,,, -134200,1.9522302,1.5353614,,,,,,,,,,,,,, -134300,1.909043,1.6119764,,,,,,,,,,,,,, -134400,2.21978,1.5955769,,,,,,,,,,,,,, -134500,1.9747254,1.494662,,,,,,,,,,,,,, -134600,1.9041642,1.557989,,,,,,,,,,,,,, -134657,,,0.8107812404632568,0.7149853110313416,0.7350999712944031,1.0497262477874756,50000.0,0.6171000003814697,1.657669186592102,10000.0,60122.92111968994,64759.81437373161,60122.92111968994,4622.005768299103,7.486406564712524,0.0 -134700,2.0176616,1.8083832,,,,,,,,,,,,,, -134800,1.8903509,1.6615704,,,,,,,,,,,,,, -134900,2.1625483,4.028393,,,,,,,,,,,,,, -135000,1.951255,1.5857437,,,,,,,,,,,,,, -135100,2.0393078,1.5029836,,,,,,,,,,,,,, -135200,1.8918457,1.6127087,,,,,,,,,,,,,, -135300,1.8994833,2.576543,,,,,,,,,,,,,, -135400,2.0448458,1.5852702,,,,,,,,,,,,,, -135500,1.7711323,2.864032,,,,,,,,,,,,,, -135594,,,0.8163085579872131,0.6858136653900146,0.7380599975585938,1.0333293676376345,50000.0,0.6195000410079956,1.6498830318450928,10000.0,60542.98550081253,65214.08685183525,60542.98550081253,4656.106626033783,7.542491436004639,0.0 -135600,1.949833,1.4287534,,,,,,,,,,,,,, -135700,2.153741,2.2421196,,,,,,,,,,,,,, -135800,1.9669663,2.6553302,,,,,,,,,,,,,, -135900,2.044184,1.5851909,,,,,,,,,,,,,, -136000,1.9221331,2.4449325,,,,,,,,,,,,,, -136100,2.117177,3.9381802,,,,,,,,,,,,,, -136200,2.0233994,1.4076557,,,,,,,,,,,,,, -136300,2.046786,1.59639,,,,,,,,,,,,,, -136400,1.9516921,1.5840062,,,,,,,,,,,,,, -136500,2.0295813,1.5008894,,,,,,,,,,,,,, -136534,,,0.8235937356948853,0.691146969795227,0.7392999529838562,1.0445665121078491,50000.0,0.6193000078201294,1.6543912887573242,10000.0,60963.28740334511,65666.55383682251,60963.28740334511,4688.164028644562,7.599376678466797,0.0 -136600,1.9394156,1.4505546,,,,,,,,,,,,,, -136700,2.140205,1.5580931,,,,,,,,,,,,,, -136800,2.2897134,3.604985,,,,,,,,,,,,,, -136900,1.8400509,2.0864296,,,,,,,,,,,,,, -137000,1.8949456,1.9509461,,,,,,,,,,,,,, -137100,2.0914176,1.5741339,,,,,,,,,,,,,, -137200,1.9704928,1.5961137,,,,,,,,,,,,,, -137300,2.1663094,3.3567448,,,,,,,,,,,,,, -137400,1.909117,1.6616968,,,,,,,,,,,,,, -137478,,,0.8358203172683716,0.6423953771591187,0.7375800013542175,1.0486226081848145,50000.0,0.6115000247955322,1.687080979347229,10000.0,61383.44880151749,66125.54321527481,61383.44880151749,4726.867444038391,7.671452760696411,0.0 -137500,2.065991,1.5132885,,,,,,,,,,,,,, -137600,2.2121594,3.4452024,,,,,,,,,,,,,, -137700,1.9793876,2.192174,,,,,,,,,,,,,, -137800,1.8363545,3.142572,,,,,,,,,,,,,, -137900,1.9683062,1.5829488,,,,,,,,,,,,,, -138000,1.807642,3.0455594,,,,,,,,,,,,,, -138100,2.0185387,3.2120626,,,,,,,,,,,,,, -138200,1.9671707,1.5377173,,,,,,,,,,,,,, -138300,2.2114818,1.5358391,,,,,,,,,,,,,, -138400,2.012448,1.5529782,,,,,,,,,,,,,, -138421,,,0.8221484422683716,0.6727483868598938,0.7418599724769592,1.0180381536483765,50000.0,0.6202000379562378,1.641932487487793,10000.0,61803.55562138557,66582.37235975266,61803.55562138557,4763.486090421677,7.724071741104126,0.0 -138500,2.1186626,1.866011,,,,,,,,,,,,,, -138600,2.0523252,1.7660222,,,,,,,,,,,,,, -138700,2.3905454,3.432992,,,,,,,,,,,,,, -138800,1.9240036,2.2172153,,,,,,,,,,,,,, -138900,2.153019,1.5816942,,,,,,,,,,,,,, -139000,2.2085657,3.6227403,,,,,,,,,,,,,, -139100,2.0006313,1.6218295,,,,,,,,,,,,,, -139200,2.0629957,2.9362338,,,,,,,,,,,,,, -139300,1.9613113,3.4406133,,,,,,,,,,,,,, -139364,,,0.8287304639816284,0.6483760476112366,0.7456600069999695,1.0039118528366089,50000.0,0.6260000467300415,1.6316074132919312,10000.0,62223.732880592346,67035.52803492546,62223.732880592346,4796.360649824143,7.7763237953186035,0.0 -139400,2.0397215,1.4861196,,,,,,,,,,,,,, -139500,2.0855727,1.623058,,,,,,,,,,,,,, -139600,2.0925024,1.5798023,,,,,,,,,,,,,, -139700,2.1662393,1.8627013,,,,,,,,,,,,,, -139800,2.1522708,1.6410577,,,,,,,,,,,,,, -139900,1.8734295,2.2748003,,,,,,,,,,,,,, -140000,2.0328324,1.4387721,,,,,,,,,,,,,, -140100,2.2394392,1.5035859,,,,,,,,,,,,,, -140200,1.9169699,2.080067,,,,,,,,,,,,,, -140299,,,0.8375195264816284,0.6271106004714966,0.7462999820709229,1.0144020318984983,50000.0,0.6199000477790833,1.6409794092178345,10000.0,62643.9823744297,67494.0847992897,62643.9823744297,4834.561587095261,7.831675052642822,0.0 -140300,2.0103362,1.5884004,,,,,,,,,,,,,, -140400,2.3369884,3.9347858,,,,,,,,,,,,,, -140500,1.8073715,2.1503243,,,,,,,,,,,,,, -140600,2.1836336,1.4675364,,,,,,,,,,,,,, -140700,1.9078524,2.0042644,,,,,,,,,,,,,, -140800,2.1987727,3.611151,,,,,,,,,,,,,, -140900,2.1026475,1.5071032,,,,,,,,,,,,,, -141000,2.223754,1.7817497,,,,,,,,,,,,,, -141100,2.093336,3.3449864,,,,,,,,,,,,,, -141200,2.1559336,1.5141287,,,,,,,,,,,,,, -141239,,,0.8244726657867432,0.6539803743362427,0.7459200024604797,1.0001908540725708,50000.0,0.6206000447273254,1.6216108798980713,10000.0,63063.95926761627,67954.7954685688,63063.95926761627,4875.187557458878,7.8880674839019775,0.0 -141300,2.0031786,1.5225172,,,,,,,,,,,,,, -141400,1.9266855,1.530598,,,,,,,,,,,,,, -141500,2.2786381,1.4651182,,,,,,,,,,,,,, -141600,1.9693297,2.3777618,,,,,,,,,,,,,, -141700,2.2628424,1.5736876,,,,,,,,,,,,,, -141800,1.8796175,2.9211874,,,,,,,,,,,,,, -141900,2.178299,3.4418077,,,,,,,,,,,,,, -142000,2.38241,3.622863,,,,,,,,,,,,,, -142100,2.0680697,1.3871946,,,,,,,,,,,,,, -142181,,,0.83216792345047,0.6442916989326477,0.7468400001525879,1.000510334968567,50000.0,0.6249000430107117,1.6253139972686768,10000.0,63484.03082895279,68411.03473544121,63484.03082895279,4911.257856369019,7.93324613571167,0.0 -142200,2.2704594,1.59148,,,,,,,,,,,,,, -142300,1.9790572,1.323565,,,,,,,,,,,,,, -142400,1.790339,2.7470133,,,,,,,,,,,,,, -142500,2.0938451,1.4311377,,,,,,,,,,,,,, -142600,2.0135312,2.3856895,,,,,,,,,,,,,, -142700,2.569591,3.7815444,,,,,,,,,,,,,, -142800,2.341023,1.4275258,,,,,,,,,,,,,, -142900,2.0691655,1.3915803,,,,,,,,,,,,,, -143000,1.9723225,1.8380147,,,,,,,,,,,,,, -143100,2.1540012,2.0979035,,,,,,,,,,,,,, -143119,,,0.8355468511581421,0.6156179904937744,0.7492799758911133,0.9899269938468932,50000.0,0.6276000142097473,1.612836480140686,10000.0,63904.20986747742,68868.8211402893,63904.20986747742,4948.755502462387,7.992448329925537,0.0 -143200,2.5379088,3.439807,,,,,,,,,,,,,, -143300,2.839794,3.7158654,,,,,,,,,,,,,, -143400,2.0849853,3.1687694,,,,,,,,,,,,,, -143500,2.0888453,1.5425295,,,,,,,,,,,,,, -143600,2.1360483,1.4629519,,,,,,,,,,,,,, -143700,2.207595,1.4053588,,,,,,,,,,,,,, -143800,2.1742623,1.6104975,,,,,,,,,,,,,, -143900,2.2769494,2.8185868,,,,,,,,,,,,,, -144000,2.3751037,3.5546513,,,,,,,,,,,,,, -144060,,,0.8414062261581421,0.6124166250228882,0.7498999834060669,0.9950244426727296,50000.0,0.6288000345230103,1.6108155250549316,10000.0,64324.40742135048,69322.31814837456,64324.40742135048,4981.9539959430695,8.04038405418396,0.0 -144100,2.1440554,1.4449723,,,,,,,,,,,,,, -144200,2.106941,2.9784899,,,,,,,,,,,,,, -144300,2.1510932,2.1398673,,,,,,,,,,,,,, -144400,2.2062314,3.4545512,,,,,,,,,,,,,, -144500,2.2891648,3.1604931,,,,,,,,,,,,,, -144600,2.272717,1.4926732,,,,,,,,,,,,,, -144700,2.1455572,1.3686535,,,,,,,,,,,,,, -144800,2.2382472,3.5003057,,,,,,,,,,,,,, -144900,2.0889966,1.4183768,,,,,,,,,,,,,, -144994,,,0.8344921469688416,0.6352013945579529,0.7530800104141235,0.9881432056427002,50000.0,0.6317000389099121,1.5940308570861816,10000.0,64744.35452175141,69779.76891851425,64744.35452175141,5019.346714496613,8.100827932357788,0.0 -145000,2.119445,1.4265219,,,,,,,,,,,,,, -145100,2.492371,3.7529347,,,,,,,,,,,,,, -145200,2.2943356,3.031607,,,,,,,,,,,,,, -145300,2.5228448,3.6568666,,,,,,,,,,,,,, -145400,2.1123135,1.463979,,,,,,,,,,,,,, -145500,2.275937,1.4229128,,,,,,,,,,,,,, -145600,2.110172,2.4899235,,,,,,,,,,,,,, -145700,1.973683,1.4509326,,,,,,,,,,,,,, -145800,2.0513644,1.4606953,,,,,,,,,,,,,, -145900,2.3044007,2.6462045,,,,,,,,,,,,,, -145924,,,0.8385546803474426,0.615352988243103,0.7503199577331543,0.9843414425849916,50000.0,0.6289000511169434,1.6003049612045288,10000.0,65164.53370523453,70233.34082078934,65164.53370523453,5052.624608039856,8.164747714996338,0.0 -146000,2.1553802,1.3717525,,,,,,,,,,,,,, -146100,2.4522328,3.4428587,,,,,,,,,,,,,, -146200,2.5545454,1.3782353,,,,,,,,,,,,,, -146300,2.3518968,1.492703,,,,,,,,,,,,,, -146400,2.3110645,1.8957642,,,,,,,,,,,,,, -146500,2.3215017,1.542544,,,,,,,,,,,,,, -146600,2.2324457,1.3801678,,,,,,,,,,,,,, -146700,2.350521,1.5086987,,,,,,,,,,,,,, -146800,2.563646,1.4111278,,,,,,,,,,,,,, -146860,,,0.8499999642372131,0.5691419243812561,0.7542200088500977,0.9688873291015624,50000.0,0.6330000162124634,1.5802263021469116,10000.0,65584.762103796,70688.78717851639,65584.762103796,5087.729380130768,8.226694107055664,0.0 -146900,2.242185,1.4641306,,,,,,,,,,,,,, -147000,2.2347455,1.4192193,,,,,,,,,,,,,, -147100,2.5392058,1.437362,,,,,,,,,,,,,, -147200,2.0628355,1.9636188,,,,,,,,,,,,,, -147300,2.1035838,1.3331153,,,,,,,,,,,,,, -147400,2.1306024,2.7778668,,,,,,,,,,,,,, -147500,2.4652162,1.4544599,,,,,,,,,,,,,, -147600,2.280755,1.3931122,,,,,,,,,,,,,, -147700,2.234718,2.4698591,,,,,,,,,,,,,, -147800,2.7749856,3.5154006,,,,,,,,,,,,,, -147801,,,0.841113269329071,0.5822334289550781,0.7569400072097778,0.9509395956993104,50000.0,0.6347000598907471,1.5633114576339722,10000.0,66005.28381443024,71150.55936717987,66005.28381443024,5128.872689962387,8.282914638519287,0.0 -147900,2.0857725,2.2742052,,,,,,,,,,,,,, -148000,2.172836,1.3986046,,,,,,,,,,,,,, -148100,2.1163287,1.4005079,,,,,,,,,,,,,, -148200,2.6931717,1.3175968,,,,,,,,,,,,,, -148300,2.1312664,1.3938673,,,,,,,,,,,,,, -148400,2.4118705,3.1531825,,,,,,,,,,,,,, -148500,2.127583,1.5107754,,,,,,,,,,,,,, -148600,2.0307767,1.3710409,,,,,,,,,,,,,, -148700,2.347665,3.3673022,,,,,,,,,,,,,, -148745,,,0.8462304472923279,0.5925246477127075,0.7569599747657776,0.9620269536972046,50000.0,0.6318000555038452,1.5700883865356443,10000.0,66425.58796477318,71601.20407652855,66425.58796477318,5159.098656654358,8.345009803771973,0.0 -148800,2.329965,1.3954521,,,,,,,,,,,,,, -148900,2.2298493,1.243214,,,,,,,,,,,,,, -149000,2.331474,1.4667926,,,,,,,,,,,,,, -149100,2.314158,1.7117947,,,,,,,,,,,,,, -149200,2.1397295,1.9922514,,,,,,,,,,,,,, -149300,2.3787668,3.282425,,,,,,,,,,,,,, -149400,2.3878198,2.264548,,,,,,,,,,,,,, -149500,2.3562734,1.4353768,,,,,,,,,,,,,, -149600,2.5229337,1.5411955,,,,,,,,,,,,,, -149680,,,0.8490039110183716,0.5663567185401917,0.7562800049781799,0.9526709914207458,50000.0,0.6348000168800354,1.5691616535186768,10000.0,66845.85511136055,72062.71906900406,66845.85511136055,5200.234387397766,8.406283617019653,0.0 -149700,2.5562234,3.4928508,,,,,,,,,,,,,, -149800,2.2389073,1.3622205,,,,,,,,,,,,,, -149900,2.5586607,3.4821994,,,,,,,,,,,,,, -150000,2.399906,3.4233649,,,,,,,,,,,,,, -150100,2.438125,3.442807,,,,,,,,,,,,,, -150200,2.7004166,3.6840792,,,,,,,,,,,,,, -150300,2.258072,1.5547242,,,,,,,,,,,,,, -150400,2.3058364,1.442549,,,,,,,,,,,,,, -150500,2.2319045,1.383446,,,,,,,,,,,,,, -150600,2.3627923,1.3616025,,,,,,,,,,,,,, -150622,,,0.8469530940055847,0.5782976150512695,0.7572999596595764,0.9563175439834596,50000.0,0.6383000016212463,1.5757783651351929,10000.0,67265.89777302742,72519.73343348503,67265.89777302742,5237.093623876572,8.467820644378662,0.0 -150700,2.2242968,2.1985252,,,,,,,,,,,,,, -150800,2.1308963,1.6946592,,,,,,,,,,,,,, -150900,2.4022906,2.9791248,,,,,,,,,,,,,, -151000,2.3231082,1.4909402,,,,,,,,,,,,,, -151100,2.4672601,1.2960887,,,,,,,,,,,,,, -151200,2.336573,1.4013684,,,,,,,,,,,,,, -151300,2.1972494,1.9725368,,,,,,,,,,,,,, -151400,2.1388059,2.70953,,,,,,,,,,,,,, -151500,2.1757,1.4724479,,,,,,,,,,,,,, -151562,,,0.8485546708106995,0.580278217792511,0.7583000063896179,0.9555332660675048,50000.0,0.6325000524520874,1.5823768377304075,10000.0,67685.99413371086,72975.18101358414,67685.99413371086,5272.33419585228,8.526651859283447,0.0 -151600,2.643734,1.4694463,,,,,,,,,,,,,, -151700,2.345798,2.9071274,,,,,,,,,,,,,, -151800,2.4895732,1.3730588,,,,,,,,,,,,,, -151900,2.4539242,2.1612997,,,,,,,,,,,,,, -152000,2.3698604,2.5387037,,,,,,,,,,,,,, -152100,2.6103854,2.8600297,,,,,,,,,,,,,, -152200,2.8213255,3.6418805,,,,,,,,,,,,,, -152300,2.3982763,1.2803007,,,,,,,,,,,,,, -152400,2.4159493,1.6667027,,,,,,,,,,,,,, -152498,,,0.8522460460662842,0.5873256325721741,0.7579799890518188,0.9749816656112672,50000.0,0.6325000524520874,1.597623586654663,10000.0,68106.0849275589,73442.63851284981,68106.0849275589,5319.5845646858215,8.591355562210083,0.0 -152500,2.8295386,1.4381244,,,,,,,,,,,,,, -152600,2.5899441,3.499722,,,,,,,,,,,,,, -152700,2.518698,1.4962611,,,,,,,,,,,,,, -152800,2.4166658,1.2759862,,,,,,,,,,,,,, -152900,2.3998919,1.4432597,,,,,,,,,,,,,, -153000,2.4860208,1.5054075,,,,,,,,,,,,,, -153100,2.3952503,1.473865,,,,,,,,,,,,,, -153200,2.263215,1.3504497,,,,,,,,,,,,,, -153300,2.4163635,1.2881175,,,,,,,,,,,,,, -153400,2.3443263,1.4056089,,,,,,,,,,,,,, -153442,,,0.8603124618530273,0.5331978797912598,0.7604599595069885,0.9412803649902344,50000.0,0.6397000551223755,1.548847198486328,10000.0,68526.38285326958,73897.06649947166,68526.38285326958,5353.610310792923,8.64404821395874,0.0 -153500,2.5390103,3.0414526,,,,,,,,,,,,,, -153600,2.2959876,2.2153177,,,,,,,,,,,,,, -153700,2.2937036,1.3490298,,,,,,,,,,,,,, -153800,2.4287035,1.3166381,,,,,,,,,,,,,, -153900,2.524218,2.647284,,,,,,,,,,,,,, -154000,2.6975863,1.3062422,,,,,,,,,,,,,, -154100,2.5905352,3.2172537,,,,,,,,,,,,,, -154200,2.3520122,1.4104971,,,,,,,,,,,,,, -154300,2.7815115,1.3305444,,,,,,,,,,,,,, -154381,,,0.8544335961341858,0.5505743622779846,0.7605400085449219,0.9401016235351562,50000.0,0.6432000398635864,1.552872896194458,10000.0,68946.4556479454,74353.45694947243,68946.4556479454,5389.820098400116,8.700559377670288,0.0 -154400,2.8684378,3.4808688,,,,,,,,,,,,,, -154500,2.5250697,3.295135,,,,,,,,,,,,,, -154600,2.5835211,1.8558364,,,,,,,,,,,,,, -154700,2.5598354,2.7902732,,,,,,,,,,,,,, -154800,2.414934,1.2489481,,,,,,,,,,,,,, -154900,2.8016949,2.9023504,,,,,,,,,,,,,, -155000,3.0067635,3.6695924,,,,,,,,,,,,,, -155100,2.3560514,1.3428591,,,,,,,,,,,,,, -155200,2.4530413,2.8535767,,,,,,,,,,,,,, -155300,2.1820922,2.0287702,,,,,,,,,,,,,, -155321,,,0.8555663824081421,0.5521935224533081,0.7650399804115295,0.9392313957214355,50000.0,0.6396000385284424,1.5510258674621582,10000.0,69366.76836133003,74806.56618714333,69366.76836133003,5422.504199266434,8.762062788009644,0.0 -155400,2.496124,1.2725835,,,,,,,,,,,,,, -155500,2.4265227,1.4703577,,,,,,,,,,,,,, -155600,3.1983778,3.6238437,,,,,,,,,,,,,, -155700,2.445474,1.3335125,,,,,,,,,,,,,, -155800,2.591101,1.3599311,,,,,,,,,,,,,, -155900,2.733759,3.121834,,,,,,,,,,,,,, -156000,2.3634827,1.8964467,,,,,,,,,,,,,, -156100,2.6176174,1.9063768,,,,,,,,,,,,,, -156200,2.4210417,1.6391029,,,,,,,,,,,,,, -156261,,,0.8625780940055847,0.516830325126648,0.7641800045967102,0.9275818467140198,50000.0,0.6413000226020813,1.5416245460510254,10000.0,69786.66360473633,75261.46242833138,69786.66360473633,5457.393801450729,8.822291851043701,0.0 -156300,3.0118172,3.6001153,,,,,,,,,,,,,, -156400,2.5309303,1.2920877,,,,,,,,,,,,,, -156500,2.5050988,1.3222766,,,,,,,,,,,,,, -156600,2.4481091,2.6376033,,,,,,,,,,,,,, -156700,2.619907,1.3850591,,,,,,,,,,,,,, -156800,2.504949,1.333097,,,,,,,,,,,,,, -156900,2.3342435,1.1791883,,,,,,,,,,,,,, -157000,3.0326202,3.3459382,,,,,,,,,,,,,, -157100,2.4894633,1.2087091,,,,,,,,,,,,,, -157195,,,0.8593164086341858,0.522832989692688,0.7650799751281738,0.9194480776786804,50000.0,0.6414000391960144,1.534990310668945,10000.0,70206.14763140678,75718.30789279938,70206.14763140678,5493.874259471893,9.651360750198364,0.0 -157200,2.7023406,3.2394855,,,,,,,,,,,,,, -157300,2.565319,1.3792894,,,,,,,,,,,,,, -157400,2.445211,1.309789,,,,,,,,,,,,,, -157500,2.5206342,1.3289369,,,,,,,,,,,,,, -157600,2.5531588,1.4591436,,,,,,,,,,,,,, -157700,2.7024584,3.290537,,,,,,,,,,,,,, -157800,4.08641,2.069746,,,,,,,,,,,,,, -157900,2.6360822,1.390083,,,,,,,,,,,,,, -158000,2.7519498,2.0432806,,,,,,,,,,,,,, -158100,2.551977,2.18289,,,,,,,,,,,,,, -158135,,,0.8620507717132568,0.5150201916694641,0.7664799690246582,0.9146672487258912,50000.0,0.6493000388145447,1.521514892578125,10000.0,70626.3047413826,76174.36131358147,70626.3047413826,5529.660228967667,9.710200309753418,0.0 -158200,2.4323752,2.4714994,,,,,,,,,,,,,, -158300,2.734751,2.7720604,,,,,,,,,,,,,, -158400,2.6875076,1.3524544,,,,,,,,,,,,,, -158500,2.4841194,1.1659174,,,,,,,,,,,,,, -158600,2.5810125,2.2912745,,,,,,,,,,,,,, -158700,2.5436866,1.2041522,,,,,,,,,,,,,, -158800,2.60428,2.8833137,,,,,,,,,,,,,, -158900,2.5729556,1.3247582,,,,,,,,,,,,,, -159000,2.9263854,1.409551,,,,,,,,,,,,,, -159079,,,0.8653515577316284,0.5030461549758911,0.7672199606895447,0.907681167125702,50000.0,0.6468000411987305,1.5106713771820068,10000.0,71046.5109963417,76628.57782149315,71046.5109963417,5563.565862417221,9.763209104537964,0.0 -159100,2.5242949,1.5648453,,,,,,,,,,,,,, -159200,2.6455889,1.3535105,,,,,,,,,,,,,, -159300,2.5551481,1.4058043,,,,,,,,,,,,,, -159400,2.9892426,3.2760623,,,,,,,,,,,,,, -159500,2.3709967,1.1794608,,,,,,,,,,,,,, -159600,2.5940332,1.2490147,,,,,,,,,,,,,, -159700,3.18644,3.3003633,,,,,,,,,,,,,, -159800,2.5511527,1.1940513,,,,,,,,,,,,,, -159900,2.696582,1.227773,,,,,,,,,,,,,, -160000,3.3536527,3.5114412,,,,,,,,,,,,,, -160018,,,0.8733007907867432,0.4837360680103302,0.7673999667167664,0.9166808128356934,50000.0,0.6488000154495239,1.520171046257019,10000.0,71466.58711266518,77087.69198036194,71466.58711266518,5602.491634130478,9.824456453323364,0.0 -160100,2.6038754,2.7542126,,,,,,,,,,,,,, -160200,2.6496034,1.1482239,,,,,,,,,,,,,, -160300,2.5052097,1.2338816,,,,,,,,,,,,,, -160400,2.8201025,2.232076,,,,,,,,,,,,,, -160500,2.5583754,1.26958,,,,,,,,,,,,,, -160600,2.5610044,1.5089171,,,,,,,,,,,,,, -160700,2.6631033,2.250008,,,,,,,,,,,,,, -160800,2.6135838,1.1635574,,,,,,,,,,,,,, -160900,2.5623977,1.2743642,,,,,,,,,,,,,, -160956,,,0.8666601181030273,0.5009443759918213,0.7683999538421631,0.9037065505981444,50000.0,0.6488000154495239,1.511286735534668,10000.0,71886.58734107018,77546.2199819088,71886.58734107018,5640.905419826508,9.887304306030272,0.0 -161000,3.2084475,3.5157332,,,,,,,,,,,,,, -161100,2.972439,3.5000134,,,,,,,,,,,,,, -161200,2.6024277,1.7176697,,,,,,,,,,,,,, -161300,2.6591089,1.2376703,,,,,,,,,,,,,, -161400,2.6349025,2.0719898,,,,,,,,,,,,,, -161500,2.8688295,1.2316936,,,,,,,,,,,,,, -161600,2.9122796,3.2027938,,,,,,,,,,,,,, -161700,3.3445785,3.6303105,,,,,,,,,,,,,, -161800,2.6724,1.2743021,,,,,,,,,,,,,, -161899,,,0.8688476085662842,0.490939736366272,0.7700600028038025,0.8974904417991638,50000.0,0.6458000540733337,1.5094892978668213,10000.0,72306.52605938911,78008.83457326889,72306.52605938911,5683.460703372955,9.955764532089232,0.0 -161900,2.440002,1.1694478,,,,,,,,,,,,,, -162000,2.8166249,1.3204566,,,,,,,,,,,,,, -162100,2.7702785,1.2348522,,,,,,,,,,,,,, -162200,3.01853,3.2888017,,,,,,,,,,,,,, -162300,3.5996714,2.908308,,,,,,,,,,,,,, -162400,3.2443209,3.4522722,,,,,,,,,,,,,, -162500,2.495216,1.8219621,,,,,,,,,,,,,, -162600,2.619769,1.2441797,,,,,,,,,,,,,, -162700,2.8070166,1.3437188,,,,,,,,,,,,,, -162800,2.5654624,2.0387871,,,,,,,,,,,,,, -162842,,,0.87611323595047,0.4695218503475189,0.7698599696159363,0.8996132016181946,50000.0,0.6525000333786011,1.496620774269104,10000.0,72726.65507078171,78462.64110207558,72726.65507078171,5717.021797180176,10.020394086837769,0.0 -162900,2.852491,1.9786215,,,,,,,,,,,,,, -163000,2.6767132,1.1605699,,,,,,,,,,,,,, -163100,2.7146845,2.947115,,,,,,,,,,,,,, -163200,2.648529,1.7359037,,,,,,,,,,,,,, -163300,2.625711,1.1991549,,,,,,,,,,,,,, -163400,2.628367,1.2206016,,,,,,,,,,,,,, -163500,2.8025095,2.761747,,,,,,,,,,,,,, -163600,2.8025794,1.393886,,,,,,,,,,,,,, -163700,2.9965417,3.32939,,,,,,,,,,,,,, -163776,,,0.86865234375,0.4866144359111786,0.7710599899291992,0.8960846662521362,50000.0,0.6538000106811523,1.4957234859466553,10000.0,73146.93170976639,78918.5292005539,73146.93170976639,5752.519358158112,10.0841383934021,0.0 -163800,2.5775673,1.3801565,,,,,,,,,,,,,, -163900,2.6017497,1.671844,,,,,,,,,,,,,, -164000,2.7536788,1.320948,,,,,,,,,,,,,, -164100,2.4807553,1.5934206,,,,,,,,,,,,,, -164200,3.007754,2.4810953,,,,,,,,,,,,,, -164300,2.6312964,1.1948316,,,,,,,,,,,,,, -164400,2.6591744,1.8049449,,,,,,,,,,,,,, -164500,2.961967,1.3348737,,,,,,,,,,,,,, -164600,2.493341,2.2431896,,,,,,,,,,,,,, -164700,2.8539312,1.1458254,,,,,,,,,,,,,, -164712,,,0.87123042345047,0.4835419654846191,0.7716599702835083,0.892865002155304,50000.0,0.65010005235672,1.4990768432617188,10000.0,73566.84558701515,79389.30561876297,73566.84558701515,5803.269119262695,10.14598798751831,0.0 -164800,2.8966386,1.8039371,,,,,,,,,,,,,, -164900,2.7894428,1.4160074,,,,,,,,,,,,,, -165000,2.8802423,1.3309023,,,,,,,,,,,,,, -165100,2.8730905,1.1728402,,,,,,,,,,,,,, -165200,2.8755336,1.1595954,,,,,,,,,,,,,, -165300,2.9701316,3.2950397,,,,,,,,,,,,,, -165400,2.8220234,1.1593885,,,,,,,,,,,,,, -165500,2.9944572,3.189343,,,,,,,,,,,,,, -165600,2.7294192,1.4518639,,,,,,,,,,,,,, -165655,,,0.8752539157867432,0.4673053920269012,0.7740199565887451,0.8827558159828186,50000.0,0.6559000015258789,1.486812710762024,10000.0,73986.8037891388,79844.26286292076,73986.8037891388,5838.162070274353,10.200010299682615,0.0 -165700,3.1369627,3.282774,,,,,,,,,,,,,, -165800,2.9832768,1.2736264,,,,,,,,,,,,,, -165900,2.8773594,3.0090413,,,,,,,,,,,,,, -166000,2.8549697,2.4573236,,,,,,,,,,,,,, -166100,2.488021,2.460232,,,,,,,,,,,,,, -166200,2.640208,1.1152399,,,,,,,,,,,,,, -166300,2.7141218,1.184065,,,,,,,,,,,,,, -166400,2.7467299,1.3370669,,,,,,,,,,,,,, -166500,2.5708952,1.3857033,,,,,,,,,,,,,, -166594,,,0.8815820217132568,0.4435708522796631,0.7743200063705444,0.8817341923713684,50000.0,0.657200038433075,1.481296181678772,10000.0,74406.99995279312,80299.7220902443,74406.99995279312,5873.305253982544,10.267615795135498,0.0 -166600,3.237938,3.3644452,,,,,,,,,,,,,, -166700,2.729724,1.4896344,,,,,,,,,,,,,, -166800,2.9417872,2.4616342,,,,,,,,,,,,,, -166900,2.6966043,2.3649695,,,,,,,,,,,,,, -167000,3.0584295,1.3876966,,,,,,,,,,,,,, -167100,2.746995,2.853943,,,,,,,,,,,,,, -167200,2.8511527,1.2719709,,,,,,,,,,,,,, -167300,2.7580855,1.6882598,,,,,,,,,,,,,, -167400,2.797728,1.986675,,,,,,,,,,,,,, -167500,2.7792912,1.2262504,,,,,,,,,,,,,, -167527,,,0.8744335770606995,0.4724225103855133,0.774179995059967,0.8804149627685547,50000.0,0.6550000309944153,1.4792810678482056,10000.0,74827.24753332138,80758.48889684677,74827.24753332138,5911.709374427795,10.32997465133667,0.0 -167600,2.793585,1.1734831,,,,,,,,,,,,,, -167700,2.8803363,1.2958516,,,,,,,,,,,,,, -167800,2.8371155,1.1673671,,,,,,,,,,,,,, -167900,2.7533846,1.2397175,,,,,,,,,,,,,, -168000,2.7169945,1.2590183,,,,,,,,,,,,,, -168100,2.7471058,2.075017,,,,,,,,,,,,,, -168200,2.7060413,1.2016587,,,,,,,,,,,,,, -168300,2.6977172,1.5364659,,,,,,,,,,,,,, -168400,3.9408739,3.3959246,,,,,,,,,,,,,, -168464,,,0.8765429258346558,0.4508543014526367,0.774899959564209,0.8729352951049805,50000.0,0.6531000137329102,1.4713881015777588,10000.0,75247.50064873695,81226.24056196213,75247.50064873695,5959.091042995453,10.396169185638428,0.0 -168500,3.5374289,3.3474326,,,,,,,,,,,,,, -168600,2.6127658,1.429042,,,,,,,,,,,,,, -168700,2.6819715,1.4204414,,,,,,,,,,,,,, -168800,2.7197547,1.6842947,,,,,,,,,,,,,, -168900,2.743603,2.0729804,,,,,,,,,,,,,, -169000,2.834061,1.1652886,,,,,,,,,,,,,, -169100,2.7869956,2.6470842,,,,,,,,,,,,,, -169200,2.796931,1.3220723,,,,,,,,,,,,,, -169300,3.1963108,1.7621442,,,,,,,,,,,,,, -169400,2.679299,1.2002141,,,,,,,,,,,,,, -169408,,,0.8811718821525574,0.447146862745285,0.7753999829292297,0.8756856918334961,50000.0,0.6557000279426575,1.4676754474639893,10000.0,75667.44517111778,81682.11791706085,75667.44517111778,5994.915839672089,10.452755689620972,0.0 -169500,2.6100798,1.1462433,,,,,,,,,,,,,, -169600,2.6125066,2.0853012,,,,,,,,,,,,,, -169700,2.8459246,2.3738906,,,,,,,,,,,,,, -169800,2.660936,2.1468627,,,,,,,,,,,,,, -169900,2.793656,2.4587889,,,,,,,,,,,,,, -170000,2.915458,1.330625,,,,,,,,,,,,,, -170100,3.3458085,1.1019909,,,,,,,,,,,,,, -170200,2.8162389,1.9674771,,,,,,,,,,,,,, -170300,3.751861,3.168029,,,,,,,,,,,,,, -170347,,,0.8786327838897705,0.4540967941284179,0.7759400010108948,0.8734537363052368,50000.0,0.6539000272750854,1.4720929861068726,10000.0,76087.41112089157,82140.03224730492,76087.41112089157,6032.74413061142,10.521770477294922,0.0 -170400,3.061125,1.2607875,,,,,,,,,,,,,, -170500,2.905686,1.6824306,,,,,,,,,,,,,, -170600,3.207247,1.1936412,,,,,,,,,,,,,, -170700,2.781521,1.1689541,,,,,,,,,,,,,, -170800,3.1748369,1.209539,,,,,,,,,,,,,, -170900,2.8917553,1.168796,,,,,,,,,,,,,, -171000,2.8332388,1.9337993,,,,,,,,,,,,,, -171100,2.9706173,1.2471042,,,,,,,,,,,,,, -171200,3.3358078,1.9976509,,,,,,,,,,,,,, -171285,,,0.8799608945846558,0.4463948607444763,0.7760399580001831,0.8722897171974182,50000.0,0.6573000550270081,1.4642900228500366,10000.0,76507.5048623085,82603.66271305084,76507.5048623085,6076.168865442276,10.582777738571169,0.0 -171300,2.776768,2.4298441,,,,,,,,,,,,,, -171400,2.8101628,1.5876515,,,,,,,,,,,,,, -171500,3.1645849,3.09275,,,,,,,,,,,,,, -171600,3.4213822,3.1883376,,,,,,,,,,,,,, -171700,4.4171824,2.996677,,,,,,,,,,,,,, -171800,2.8510518,1.5686085,,,,,,,,,,,,,, -171900,3.0039454,3.097305,,,,,,,,,,,,,, -172000,2.9398985,1.1564229,,,,,,,,,,,,,, -172100,3.4431853,1.1742432,,,,,,,,,,,,,, -172200,3.0637004,2.7927573,,,,,,,,,,,,,, -172228,,,0.8814257383346558,0.4432843625545501,0.7773799896240234,0.8704707026481628,50000.0,0.6579000353813171,1.4670575857162476,10000.0,76927.65636372566,83060.2145652771,76927.65636372566,6112.465822458267,10.634551048278809,0.0 -172300,2.9165082,2.0371532,,,,,,,,,,,,,, -172400,2.8862462,1.1634407,,,,,,,,,,,,,, -172500,3.2970552,2.4505014,,,,,,,,,,,,,, -172600,2.9244003,1.580526,,,,,,,,,,,,,, -172700,4.3415403,1.1030159,,,,,,,,,,,,,, -172800,2.9730806,1.4768238,,,,,,,,,,,,,, -172900,2.7816834,1.262501,,,,,,,,,,,,,, -173000,3.143342,1.3413105,,,,,,,,,,,,,, -173100,3.5910776,3.3052137,,,,,,,,,,,,,, -173169,,,0.8839452862739563,0.4334542453289032,0.7780399918556213,0.8655481934547424,50000.0,0.6571000218391418,1.4592924118041992,10000.0,77347.83917474747,83521.28679227829,77347.83917474747,6153.237934350967,10.700610399246216,0.0 -173200,2.818338,1.1513385,,,,,,,,,,,,,, -173300,2.708131,2.1376376,,,,,,,,,,,,,, -173400,2.9240224,1.141177,,,,,,,,,,,,,, -173500,3.0630395,1.2188463,,,,,,,,,,,,,, -173600,2.9029102,1.6560336,,,,,,,,,,,,,, -173700,2.9175136,1.2096401,,,,,,,,,,,,,, -173800,2.8616729,1.1131625,,,,,,,,,,,,,, -173900,3.021442,1.1415937,,,,,,,,,,,,,, -174000,2.8638535,1.6518923,,,,,,,,,,,,,, -174100,2.9245656,1.1206301,,,,,,,,,,,,,, -174111,,,0.8828905820846558,0.4322830736637115,0.77947998046875,0.8609575033187866,50000.0,0.6586000323295593,1.4545586109161377,10000.0,77768.06882762909,83976.2343738079,77768.06882762909,6187.839814186096,10.765312910079956,0.0 -174200,3.1190147,1.1588297,,,,,,,,,,,,,, -174300,3.1598291,1.3831877,,,,,,,,,,,,,, -174400,3.0466204,1.8600252,,,,,,,,,,,,,, -174500,2.9166539,1.171005,,,,,,,,,,,,,, -174600,3.0246766,1.0835109,,,,,,,,,,,,,, -174700,3.249906,3.1925247,,,,,,,,,,,,,, -174800,3.0240307,2.4907699,,,,,,,,,,,,,, -174900,2.903502,2.136894,,,,,,,,,,,,,, -175000,3.0359213,2.6350062,,,,,,,,,,,,,, -175049,,,0.8865429759025574,0.4204282760620117,0.7789599895477295,0.861438512802124,50000.0,0.6575000286102295,1.4523377418518066,10000.0,78187.95839118958,84430.58243322372,78187.95839118958,6222.178782701492,10.833298206329346,0.0 -175100,3.0937433,2.5286572,,,,,,,,,,,,,, -175200,2.825262,2.2297773,,,,,,,,,,,,,, -175300,2.8977895,1.1122203,,,,,,,,,,,,,, -175400,3.0501373,1.2354996,,,,,,,,,,,,,, -175500,3.2955208,2.9356256,,,,,,,,,,,,,, -175600,3.0194678,1.1358976,,,,,,,,,,,,,, -175700,3.9720562,3.0456164,,,,,,,,,,,,,, -175800,2.9568074,1.0960928,,,,,,,,,,,,,, -175900,3.0916076,1.1008049,,,,,,,,,,,,,, -175986,,,0.8853319883346558,0.4295352399349212,0.7797200083732605,0.859338641166687,50000.0,0.6561000347137451,1.4516371488571167,10000.0,78608.33759880066,84889.0225493908,78608.33759880066,6260.123591899872,10.897441625595093,0.0 -176000,3.1490211,1.2649004,,,,,,,,,,,,,, -176100,3.4100642,3.0777173,,,,,,,,,,,,,, -176200,3.3505623,1.2334403,,,,,,,,,,,,,, -176300,3.5485237,3.2628465,,,,,,,,,,,,,, -176400,3.0906377,1.2471387,,,,,,,,,,,,,, -176500,2.7886424,1.7175223,,,,,,,,,,,,,, -176600,3.6223476,3.069983,,,,,,,,,,,,,, -176700,3.109065,1.0864257,,,,,,,,,,,,,, -176800,3.1632342,2.6955514,,,,,,,,,,,,,, -176900,2.7485409,1.0550374,,,,,,,,,,,,,, -176925,,,0.8836327791213989,0.4284150898456573,0.7784199714660645,0.8580959439277649,50000.0,0.659000039100647,1.450667142868042,10000.0,79028.45848870277,85348.26768136024,79028.45848870277,6299.135026216507,10.9592125415802,0.0 -177000,3.3627174,1.2640747,,,,,,,,,,,,,, -177100,3.0604868,1.0992159,,,,,,,,,,,,,, -177200,3.3415768,3.0218303,,,,,,,,,,,,,, -177300,2.844752,1.4716161,,,,,,,,,,,,,, -177400,2.856916,1.2066506,,,,,,,,,,,,,, -177500,3.0413325,1.0837762,,,,,,,,,,,,,, -177600,2.969204,1.2330239,,,,,,,,,,,,,, -177700,2.7931798,1.332426,,,,,,,,,,,,,, -177800,3.6791458,3.244904,,,,,,,,,,,,,, -177867,,,0.8875390291213989,0.4265033900737762,0.7801399827003479,0.8563417792320251,50000.0,0.6620000600814819,1.443135380744934,10000.0,79448.42927765846,85803.83441829681,79448.42927765846,6334.59276509285,11.045661687850952,0.0 -177900,3.5411892,3.0460763,,,,,,,,,,,,,, -178000,3.1621249,2.8076258,,,,,,,,,,,,,, -178100,2.9672856,1.9009115,,,,,,,,,,,,,, -178200,3.1413987,1.0885822,,,,,,,,,,,,,, -178300,3.0645988,1.2244365,,,,,,,,,,,,,, -178400,3.109416,2.4203594,,,,,,,,,,,,,, -178500,3.080897,1.2141184,,,,,,,,,,,,,, -178600,3.5007713,3.038754,,,,,,,,,,,,,, -178700,2.9679012,1.3786033,,,,,,,,,,,,,, -178800,2.8690016,1.1541524,,,,,,,,,,,,,, -178807,,,0.887499988079071,0.4158221781253814,0.7795199751853943,0.8523306250572205,50000.0,0.6636000275611877,1.4399917125701904,10000.0,79868.60091280937,86265.87176012993,79868.60091280937,6376.347739696503,11.105734825134276,0.0 -178900,3.1233308,2.8122776,,,,,,,,,,,,,, -179000,2.9575937,1.3084302,,,,,,,,,,,,,, -179100,3.1626554,2.7207098,,,,,,,,,,,,,, -179200,3.3755946,2.8971655,,,,,,,,,,,,,, -179300,3.6525242,1.1055223,,,,,,,,,,,,,, -179400,2.9976833,1.5769029,,,,,,,,,,,,,, -179500,3.0443382,1.3336155,,,,,,,,,,,,,, -179600,3.004953,1.3487506,,,,,,,,,,,,,, -179700,3.0014157,1.0892911,,,,,,,,,,,,,, -179747,,,0.8867382407188416,0.418932557106018,0.7800399661064148,0.8535423874855042,50000.0,0.6633000373840332,1.4441020488739014,10000.0,80288.49507188797,86723.49782943726,80288.49507188797,6413.968862771988,11.165611028671265,0.0 -179800,3.1887653,1.1274544,,,,,,,,,,,,,, -179900,2.9586418,2.0700042,,,,,,,,,,,,,, -180000,3.039202,1.148008,,,,,,,,,,,,,, -180100,3.227223,1.1554884,,,,,,,,,,,,,, -180200,2.881383,1.1038357,,,,,,,,,,,,,, -180300,3.6876044,3.1923857,,,,,,,,,,,,,, -180400,2.9495108,2.6068294,,,,,,,,,,,,,, -180500,3.6913083,3.0987551,,,,,,,,,,,,,, -180600,3.2536423,1.1877806,,,,,,,,,,,,,, -180683,,,0.8873242139816284,0.4149569869041443,0.7804799675941467,0.8523868322372437,50000.0,0.6643000245094299,1.442840576171875,10000.0,80708.39556407928,87178.17326307297,80708.39556407928,6448.626727581024,11.231953859329224,0.0 -180700,3.037573,1.1999733,,,,,,,,,,,,,, -180800,3.408649,2.8162408,,,,,,,,,,,,,, -180900,3.073231,1.0661665,,,,,,,,,,,,,, -181000,3.0587058,1.2086046,,,,,,,,,,,,,, -181100,3.0910382,1.265652,,,,,,,,,,,,,, -181200,2.827514,1.5468951,,,,,,,,,,,,,, -181300,3.0234427,1.9608779,,,,,,,,,,,,,, -181400,2.77805,1.0980228,,,,,,,,,,,,,, -181500,3.145242,1.23674,,,,,,,,,,,,,, -181600,3.0653276,1.5634362,,,,,,,,,,,,,, -181620,,,0.8868749737739563,0.4176073670387268,0.780239999294281,0.8525054454803467,50000.0,0.6632000207901001,1.4451278448104858,10000.0,81128.36767435074,87639.06216335297,81128.36767435074,6489.427065849304,11.297489881515505,0.0 -181700,3.490148,1.1585943,,,,,,,,,,,,,, -181800,3.3034608,3.056583,,,,,,,,,,,,,, -181900,3.1460948,1.781172,,,,,,,,,,,,,, -182000,3.2656255,2.3456101,,,,,,,,,,,,,, -182100,3.053419,1.1901007,,,,,,,,,,,,,, -182200,3.9652543,3.2907271,,,,,,,,,,,,,, -182300,2.9389577,1.2819177,,,,,,,,,,,,,, -182400,2.9576387,1.086108,,,,,,,,,,,,,, -182500,3.143308,2.626892,,,,,,,,,,,,,, -182561,,,0.8882421851158142,0.4123246967792511,0.7803599834442139,0.8496127724647522,50000.0,0.6635000109672546,1.442088603973389,10000.0,81548.34212756157,88104.66285085678,81548.34212756157,6534.940185070038,11.358872413635254,0.0 -182600,3.1115477,2.3207974,,,,,,,,,,,,,, -182700,2.9059255,1.3333725,,,,,,,,,,,,,, -182800,3.1906,2.7413037,,,,,,,,,,,,,, -182900,3.529472,3.0650918,,,,,,,,,,,,,, -183000,2.8869557,1.3101537,,,,,,,,,,,,,, -183100,3.025034,1.1121844,,,,,,,,,,,,,, -183200,3.1616647,1.0322268,,,,,,,,,,,,,, -183300,3.181043,1.2318206,,,,,,,,,,,,,, -183400,3.0180259,2.1319866,,,,,,,,,,,,,, -183500,2.8936975,2.3532035,,,,,,,,,,,,,, -183505,,,0.8894140720367432,0.4123950004577636,0.780460000038147,0.8484692573547363,50000.0,0.6627000570297241,1.4406883716583252,10000.0,81968.63080692291,88558.89284658432,81968.63080692291,6568.771864891052,11.416648387908936,0.0 -183600,2.9442422,2.2898433,,,,,,,,,,,,,, -183700,2.9352024,1.0439538,,,,,,,,,,,,,, -183800,3.1212487,1.2265114,,,,,,,,,,,,,, -183900,3.1547256,1.1129171,,,,,,,,,,,,,, -184000,3.115825,1.1596876,,,,,,,,,,,,,, -184100,3.036908,1.0556918,,,,,,,,,,,,,, -184200,3.0131526,1.7762891,,,,,,,,,,,,,, -184300,3.3211143,1.1141984,,,,,,,,,,,,,, -184400,3.2341168,2.7416441,,,,,,,,,,,,,, -184444,,,0.8900390267372131,0.4150246083736419,0.7807599902153015,0.8500364422798157,50000.0,0.6625000238418579,1.4413601160049438,10000.0,82388.93534779549,89017.95306396484,82388.93534779549,6607.41025018692,11.48314118385315,0.0 -184500,2.96188,1.5245594,,,,,,,,,,,,,, -184600,5.7358475,1.6666253,,,,,,,,,,,,,, -184700,2.9867337,0.97916466,,,,,,,,,,,,,, -184800,2.8656583,1.1107256,,,,,,,,,,,,,, -184900,3.5969498,1.0982233,,,,,,,,,,,,,, -185000,3.0991607,1.1973243,,,,,,,,,,,,,, -185100,3.0439112,1.2127995,,,,,,,,,,,,,, -185200,3.2201602,1.3667257,,,,,,,,,,,,,, -185300,2.9635315,1.5422816,,,,,,,,,,,,,, -185384,,,0.8905664086341858,0.4088824987411499,0.7807799577713013,0.8490588068962097,50000.0,0.6627000570297241,1.4396294355392456,10000.0,82809.04411435127,89484.22124147415,82809.04411435127,6653.44875907898,11.552966833114624,0.0 -185400,2.9925752,1.1422751,,,,,,,,,,,,,, -185500,3.0513911,2.8777926,,,,,,,,,,,,,, -185600,2.9689133,1.2819118,,,,,,,,,,,,,, -185700,2.8898835,1.1976646,,,,,,,,,,,,,, -185800,2.9400358,1.1476033,,,,,,,,,,,,,, -185900,3.3397532,1.6139828,,,,,,,,,,,,,, -186000,3.4220746,1.0669838,,,,,,,,,,,,,, -186100,2.8789327,1.9340913,,,,,,,,,,,,,, -186200,3.0489223,1.0667425,,,,,,,,,,,,,, -186300,3.1507142,2.0357149,,,,,,,,,,,,,, -186329,,,0.8901953101158142,0.414861798286438,0.7806599736213684,0.8491452932357788,50000.0,0.6629000306129456,1.4399614334106443,10000.0,83228.92775535583,89943.25726413727,83228.92775535583,6692.481970310211,11.621200561523438,0.0 -186400,3.3601732,2.9225254,,,,,,,,,,,,,, -186500,2.8493116,1.3021463,,,,,,,,,,,,,, -186600,2.891518,1.0850154,,,,,,,,,,,,,, -186700,3.099249,1.1892256,,,,,,,,,,,,,, -186800,2.8426933,2.4989867,,,,,,,,,,,,,, -186900,2.9919312,1.5244908,,,,,,,,,,,,,, -187000,2.9749482,1.151371,,,,,,,,,,,,,, -187100,3.1677842,1.0621125,,,,,,,,,,,,,, -187200,2.987171,1.3924813,,,,,,,,,,,,,, -187270,,,0.8873632550239563,0.4139579236507416,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,83649.1949005127,90403.1090900898,83649.1949005127,6731.947980880737,11.689115285873411,0.0 -187300,3.5384815,1.1358454,,,,,,,,,,,,,, -187400,3.3042023,2.911106,,,,,,,,,,,,,, -187500,4.0454698,2.9984965,,,,,,,,,,,,,, -187600,2.8199098,1.60714,,,,,,,,,,,,,, -187700,3.6873634,3.199069,,,,,,,,,,,,,, -187800,3.806324,1.1089474,,,,,,,,,,,,,, -187900,2.9163597,2.1702862,,,,,,,,,,,,,, -188000,3.3197997,1.2849996,,,,,,,,,,,,,, -188100,2.9427383,1.1093649,,,,,,,,,,,,,, -188200,3.4633605,2.7160506,,,,,,,,,,,,,, -188208,,,0.8859374523162842,0.4201960265636444,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,84069.26384472847,90860.44817233086,84069.26384472847,6769.103258132935,11.752736330032349,0.0 -188300,3.1610513,1.0047504,,,,,,,,,,,,,, -188400,3.0756702,1.0585798,,,,,,,,,,,,,, -188500,2.8869078,1.8873111,,,,,,,,,,,,,, -188600,3.0614667,2.4024758,,,,,,,,,,,,,, -188700,4.327185,3.2375584,,,,,,,,,,,,,, -188800,3.238203,2.805103,,,,,,,,,,,,,, -188900,3.076681,1.3291339,,,,,,,,,,,,,, -189000,3.3368752,2.3605075,,,,,,,,,,,,,, -189100,3.0541258,1.1388843,,,,,,,,,,,,,, -189144,,,0.8887304663658142,0.4120084643363952,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,84489.2108707428,91319.21151280405,84489.2108707428,6807.8059005737305,11.816109657287598,0.0 -189200,2.982068,2.5299616,,,,,,,,,,,,,, -189300,3.0881314,1.1395782,,,,,,,,,,,,,, -189400,2.9324224,1.058649,,,,,,,,,,,,,, -189500,3.0467389,1.0743027,,,,,,,,,,,,,, -189600,3.2431931,2.2594285,,,,,,,,,,,,,, -189700,3.076687,1.863095,,,,,,,,,,,,,, -189800,3.1608653,1.1181083,,,,,,,,,,,,,, -189900,3.4144151,3.003046,,,,,,,,,,,,,, -190000,2.914043,1.1448399,,,,,,,,,,,,,, -190081,,,0.8896288871765137,0.4153327941894531,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,84909.57313537598,91777.704955101,84909.57313537598,6845.819953918457,11.881962776184082,0.0 -190100,2.9071922,1.2851459,,,,,,,,,,,,,, -190200,2.971489,1.015407,,,,,,,,,,,,,, -190300,2.9328032,2.4954982,,,,,,,,,,,,,, -190400,2.9203925,1.5462943,,,,,,,,,,,,,, -190500,3.27342,2.5500202,,,,,,,,,,,,,, -190600,3.1906154,1.4909571,,,,,,,,,,,,,, -190700,2.83662,1.0611308,,,,,,,,,,,,,, -190800,3.055812,1.1739887,,,,,,,,,,,,,, -190900,3.8720896,3.059412,,,,,,,,,,,,,, -191000,3.14684,1.1697963,,,,,,,,,,,,,, -191022,,,0.8895507454872131,0.4160298407077789,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,85329.69141507149,92244.0565958023,85329.69141507149,6891.935022592545,11.94904851913452,0.0 -191100,3.8051057,3.2972748,,,,,,,,,,,,,, -191200,3.4225197,2.4519043,,,,,,,,,,,,,, -191300,3.1334553,1.1372535,,,,,,,,,,,,,, -191400,2.933196,1.0583451,,,,,,,,,,,,,, -191500,3.6692896,3.2422068,,,,,,,,,,,,,, -191600,2.9506195,1.7499826,,,,,,,,,,,,,, -191700,3.0809166,2.565973,,,,,,,,,,,,,, -191800,3.1156366,2.6331692,,,,,,,,,,,,,, -191900,3.094975,1.078001,,,,,,,,,,,,,, -191967,,,0.8859570026397705,0.4194627404212951,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,85749.73276805878,92703.48462033272,85749.73276805878,6931.216662168503,12.00302505493164,0.0 -192000,3.08438,1.1496215,,,,,,,,,,,,,, -192100,3.3914735,2.8107998,,,,,,,,,,,,,, -192200,2.7629824,1.5180209,,,,,,,,,,,,,, -192300,3.0801969,1.0526897,,,,,,,,,,,,,, -192400,3.0489986,1.5820524,,,,,,,,,,,,,, -192500,3.4968545,2.8710582,,,,,,,,,,,,,, -192600,3.0553186,1.1492223,,,,,,,,,,,,,, -192700,3.0108995,1.1402746,,,,,,,,,,,,,, -192800,3.1226575,2.4409194,,,,,,,,,,,,,, -192900,3.0899842,1.1755384,,,,,,,,,,,,,, -192908,,,0.8864843845367432,0.4163427650928497,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,86169.99474191666,93159.5766465664,86169.99474191666,6966.921736240387,12.076726198196411,0.0 -193000,2.8757346,1.1606567,,,,,,,,,,,,,, -193100,3.0012603,1.709896,,,,,,,,,,,,,, -193200,4.0257297,3.2935553,,,,,,,,,,,,,, -193300,3.90046,3.149088,,,,,,,,,,,,,, -193400,3.75844,3.2044182,,,,,,,,,,,,,, -193500,3.2754393,2.0556579,,,,,,,,,,,,,, -193600,3.1110663,2.5640657,,,,,,,,,,,,,, -193700,3.381636,1.1532168,,,,,,,,,,,,,, -193800,3.7389119,3.2670865,,,,,,,,,,,,,, -193846,,,0.8892577886581421,0.4128409922122955,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,86590.10334300995,93620.9745607376,86590.10334300995,7008.095109939575,12.14166522026062,0.0 -193900,3.3523867,2.6628928,,,,,,,,,,,,,, -194000,2.9359334,2.044131,,,,,,,,,,,,,, -194100,3.6402051,3.1968663,,,,,,,,,,,,,, -194200,3.2055628,1.1479287,,,,,,,,,,,,,, -194300,3.0863388,1.0843372,,,,,,,,,,,,,, -194400,2.9911134,2.5994468,,,,,,,,,,,,,, -194500,3.2727005,2.9972095,,,,,,,,,,,,,, -194600,3.0834339,1.3247414,,,,,,,,,,,,,, -194700,4.023517,2.7534218,,,,,,,,,,,,,, -194786,,,0.8887109160423279,0.4138174057006836,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,87010.18246912956,94079.54350543022,87010.18246912956,7046.456674575806,12.20676326751709,0.0 -194800,2.9114676,1.1279002,,,,,,,,,,,,,, -194900,3.1914868,1.9168441,,,,,,,,,,,,,, -195000,3.7473269,3.2598338,,,,,,,,,,,,,, -195100,2.8449674,1.0797664,,,,,,,,,,,,,, -195200,3.3656988,1.1611853,,,,,,,,,,,,,, -195300,3.3311803,2.8960059,,,,,,,,,,,,,, -195400,3.447837,1.1889215,,,,,,,,,,,,,, -195500,3.8701203,3.2203069,,,,,,,,,,,,,, -195600,2.972696,2.4341342,,,,,,,,,,,,,, -195700,2.9886618,2.5568976,,,,,,,,,,,,,, -195726,,,0.8862695097923279,0.4229434728622436,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,87430.38870239258,94538.39776062964,87430.38870239258,7084.988321781158,12.271981477737429,0.0 -195800,2.9995258,2.097566,,,,,,,,,,,,,, -195900,3.8046603,3.163416,,,,,,,,,,,,,, -196000,2.8152425,1.5875801,,,,,,,,,,,,,, -196100,3.2074893,1.164891,,,,,,,,,,,,,, -196200,4.0131025,3.1864257,,,,,,,,,,,,,, -196300,3.1373694,1.8703332,,,,,,,,,,,,,, -196400,3.4240282,2.819121,,,,,,,,,,,,,, -196500,3.1183338,1.2431426,,,,,,,,,,,,,, -196600,2.9086652,1.0673126,,,,,,,,,,,,,, -196662,,,0.88929682970047,0.4122631251811981,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,87850.36334037781,94997.57072806358,87850.36334037781,7124.067133426666,12.340903997421265,0.0 -196700,3.2780607,2.6956365,,,,,,,,,,,,,, -196800,2.5652735,1.4757215,,,,,,,,,,,,,, -196900,3.2695117,1.8156207,,,,,,,,,,,,,, -197000,3.1401048,1.1373682,,,,,,,,,,,,,, -197100,2.8575988,1.4273156,,,,,,,,,,,,,, -197200,4.101486,3.1583564,,,,,,,,,,,,,, -197300,2.8931184,1.1272874,,,,,,,,,,,,,, -197400,4.814777,2.68884,,,,,,,,,,,,,, -197500,3.1413805,1.5412242,,,,,,,,,,,,,, -197600,,,0.8868359327316284,0.4162808060646057,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,88270.65979909897,95455.27904057504,88270.65979909897,7161.361990451813,12.406561136245728,0.0 -197600,2.9374468,2.3076508,,,,,,,,,,,,,, -197700,3.0523226,2.5843294,,,,,,,,,,,,,, -197800,3.0800474,1.1557522,,,,,,,,,,,,,, -197900,3.277904,2.4898076,,,,,,,,,,,,,, -198000,3.2244618,1.2638209,,,,,,,,,,,,,, -198100,2.7645035,1.0506269,,,,,,,,,,,,,, -198200,2.9861686,2.2660222,,,,,,,,,,,,,, -198300,3.230733,1.2434982,,,,,,,,,,,,,, -198400,3.136782,1.876186,,,,,,,,,,,,,, -198500,2.8168366,1.4403845,,,,,,,,,,,,,, -198538,,,0.8875390291213989,0.413641095161438,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,88690.6394803524,95913.43540740012,88690.6394803524,7199.420048713684,12.473467350006104,0.0 -198600,2.8920364,1.1039579,,,,,,,,,,,,,, -198700,3.1699913,1.2037721,,,,,,,,,,,,,, -198800,3.6474311,3.0149388,,,,,,,,,,,,,, -198900,3.3429902,2.7108974,,,,,,,,,,,,,, -199000,3.0105112,1.6635686,,,,,,,,,,,,,, -199100,3.2227113,2.1593504,,,,,,,,,,,,,, -199200,2.9858112,2.1259024,,,,,,,,,,,,,, -199300,3.151515,1.6775676,,,,,,,,,,,,,, -199400,3.1852787,1.1594046,,,,,,,,,,,,,, -199480,,,0.8901171684265137,0.4157747030258178,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,89110.54193615913,96371.70179224014,89110.54193615913,7237.657130002975,12.549538135528564,0.0 -199500,2.9637012,2.5835502,,,,,,,,,,,,,, -199600,3.9362826,3.1575654,,,,,,,,,,,,,, -199700,2.9768727,1.1390526,,,,,,,,,,,,,, -199800,3.0526617,1.0454952,,,,,,,,,,,,,, -199900,3.1905282,2.4606552,,,,,,,,,,,,,, -200000,2.870059,1.0049433,,,,,,,,,,,,,, -200100,3.1269841,2.644538,,,,,,,,,,,,,, -200200,2.9987183,1.4734771,,,,,,,,,,,,,, -200300,3.3908033,2.800553,,,,,,,,,,,,,, -200400,3.1393204,1.7337879,,,,,,,,,,,,,, -200414,,,0.887499988079071,0.41785928606987,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,89530.50885987282,96829.34929418564,89530.50885987282,7275.220329999924,12.615999937057495,0.0 -200500,3.2659895,2.0050611,,,,,,,,,,,,,, -200600,3.364232,2.699779,,,,,,,,,,,,,, -200700,3.2379177,1.0067921,,,,,,,,,,,,,, -200800,3.1499617,1.07245,,,,,,,,,,,,,, -200900,3.1998456,1.153264,,,,,,,,,,,,,, -201000,3.2168689,1.0681337,,,,,,,,,,,,,, -201100,3.078139,2.5670283,,,,,,,,,,,,,, -201200,3.9565117,1.9962745,,,,,,,,,,,,,, -201300,3.0065038,1.1755261,,,,,,,,,,,,,, -201349,,,0.8890624642372131,0.4149791598320007,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,89950.4132950306,97288.59665942192,89950.4132950306,7314.43705201149,12.690476179122925,0.0 -201400,2.953278,2.180505,,,,,,,,,,,,,, -201500,2.8959243,1.8050761,,,,,,,,,,,,,, -201600,3.0309708,1.1141424,,,,,,,,,,,,,, -201700,3.2744377,1.1811731,,,,,,,,,,,,,, -201800,3.1049857,1.6594267,,,,,,,,,,,,,, -201900,2.942989,1.1819193,,,,,,,,,,,,,, -202000,3.1400232,2.4124794,,,,,,,,,,,,,, -202100,4.069628,3.3129315,,,,,,,,,,,,,, -202200,3.3211052,1.1333265,,,,,,,,,,,,,, -202291,,,0.8885741829872131,0.4138651192188263,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,90370.42215585709,97747.3352882862,90370.42215585709,7353.0320365428925,12.773582935333252,0.0 -202300,3.179422,1.1627599,,,,,,,,,,,,,, -202400,3.0467331,1.1393661,,,,,,,,,,,,,, -202500,3.1988196,2.5082324,,,,,,,,,,,,,, -202600,2.9103844,1.1398916,,,,,,,,,,,,,, -202700,3.6143324,3.2635396,,,,,,,,,,,,,, -202800,3.0742843,1.1549438,,,,,,,,,,,,,, -202900,3.4993663,3.0978317,,,,,,,,,,,,,, -203000,3.3713808,1.1324718,,,,,,,,,,,,,, -203100,3.0401127,1.0725342,,,,,,,,,,,,,, -203200,3.4899204,3.0885248,,,,,,,,,,,,,, -203232,,,0.8862109184265137,0.4218650758266449,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,90790.63092684746,98205.05185079576,90790.63092684746,7390.410850524902,12.851639747619627,0.0 -203300,2.9798665,1.0525959,,,,,,,,,,,,,, -203400,3.207829,1.9314368,,,,,,,,,,,,,, -203500,3.0777135,1.2431227,,,,,,,,,,,,,, -203600,3.125156,1.2217212,,,,,,,,,,,,,, -203700,3.2115593,1.1729531,,,,,,,,,,,,,, -203800,3.1065505,1.1935499,,,,,,,,,,,,,, -203900,2.9967148,1.0537403,,,,,,,,,,,,,, -204000,3.0794835,1.050073,,,,,,,,,,,,,, -204100,3.1933768,1.1922998,,,,,,,,,,,,,, -204166,,,0.888964831829071,0.4084383249282837,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,91210.55031061172,98661.71626615524,91210.55031061172,7427.037251472473,12.919341325759888,0.0 -204200,3.6017222,2.885682,,,,,,,,,,,,,, -204300,3.5371518,1.1305796,,,,,,,,,,,,,, -204400,3.0406046,1.1903818,,,,,,,,,,,,,, -204500,3.4737406,2.903759,,,,,,,,,,,,,, -204600,3.000071,1.9308448,,,,,,,,,,,,,, -204700,3.0021715,1.1702393,,,,,,,,,,,,,, -204800,2.9913888,2.3160853,,,,,,,,,,,,,, -204900,3.2780442,2.3117962,,,,,,,,,,,,,, -205000,2.7487724,1.6317234,,,,,,,,,,,,,, -205100,2.9014614,1.2175323,,,,,,,,,,,,,, -205106,,,0.8895702958106995,0.4111346006393432,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,91630.6247985363,99123.73776745796,91630.6247985363,7468.865349769592,12.98732614517212,0.0 -205200,3.0902407,1.3491442,,,,,,,,,,,,,, -205300,3.2161465,1.1972984,,,,,,,,,,,,,, -205400,3.2335947,1.2723513,,,,,,,,,,,,,, -205500,3.1524513,2.7643847,,,,,,,,,,,,,, -205600,2.6548157,1.9822986,,,,,,,,,,,,,, -205700,3.5994716,3.1361613,,,,,,,,,,,,,, -205800,3.0541244,1.1585689,,,,,,,,,,,,,, -205900,3.059768,1.3490492,,,,,,,,,,,,,, -206000,2.9863596,1.0742443,,,,,,,,,,,,,, -206045,,,0.8857226371765137,0.4192066490650177,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,92050.679759264,99580.80747056007,92050.679759264,7505.758831977844,13.057476282119753,0.0 -206100,3.0368266,2.4643292,,,,,,,,,,,,,, -206200,3.2046797,2.6119034,,,,,,,,,,,,,, -206300,3.0469284,2.647079,,,,,,,,,,,,,, -206400,3.2104423,1.1485353,,,,,,,,,,,,,, -206500,2.8165612,1.1240934,,,,,,,,,,,,,, -206600,3.7995071,3.0643826,,,,,,,,,,,,,, -206700,3.1713364,2.5769439,,,,,,,,,,,,,, -206800,3.2050958,2.907307,,,,,,,,,,,,,, -206900,3.0283756,1.4074236,,,,,,,,,,,,,, -206981,,,0.8902148008346558,0.4150867462158203,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,92470.65465569496,100037.31982588768,92470.65465569496,7542.178693294525,13.123713493347168,0.0 -207000,2.9363482,1.14561,,,,,,,,,,,,,, -207100,3.521755,2.8133316,,,,,,,,,,,,,, -207200,2.964222,1.2398432,,,,,,,,,,,,,, -207300,3.2727015,2.9887688,,,,,,,,,,,,,, -207400,3.4269645,3.019917,,,,,,,,,,,,,, -207500,3.3587818,3.1591432,,,,,,,,,,,,,, -207600,3.3345134,1.1167896,,,,,,,,,,,,,, -207700,2.9185686,1.0784249,,,,,,,,,,,,,, -207800,3.1057663,1.0564632,,,,,,,,,,,,,, -207900,3.0535967,1.108793,,,,,,,,,,,,,, -207919,,,0.8915429711341858,0.4060378968715668,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,92890.90583324432,100494.89912700652,92890.90583324432,7579.383923530579,13.19585919380188,0.0 -208000,2.7381096,1.5946294,,,,,,,,,,,,,, -208100,2.9456422,1.5518525,,,,,,,,,,,,,, -208200,3.0398967,2.040819,,,,,,,,,,,,,, -208300,2.9399946,1.0838609,,,,,,,,,,,,,, -208400,3.1816754,1.2153913,,,,,,,,,,,,,, -208500,3.2660842,1.157458,,,,,,,,,,,,,, -208600,3.3086214,1.1828903,,,,,,,,,,,,,, -208700,4.5350056,3.21142,,,,,,,,,,,,,, -208800,2.9620037,1.1471375,,,,,,,,,,,,,, -208859,,,0.8892382383346558,0.4131664931774139,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,93311.23576164246,100954.3665342331,93311.23576164246,7618.397220611572,13.267715454101562,0.0 -208900,3.1889222,1.3403809,,,,,,,,,,,,,, -209000,3.1750417,1.1453559,,,,,,,,,,,,,, -209100,4.4603324,3.2217932,,,,,,,,,,,,,, -209200,3.007024,1.489311,,,,,,,,,,,,,, -209300,3.1619737,1.1658992,,,,,,,,,,,,,, -209400,3.12697,1.1607958,,,,,,,,,,,,,, -209500,3.5356715,3.1712244,,,,,,,,,,,,,, -209600,2.9852219,2.4398646,,,,,,,,,,,,,, -209700,3.525648,3.08248,,,,,,,,,,,,,, -209800,,,0.8897460699081421,0.4122765362262726,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,93731.48620271684,101411.4986450672,93731.48620271684,7655.169943809509,13.32474970817566,0.0 -209800,2.9510567,1.1546085,,,,,,,,,,,,,, -209900,3.1416063,1.2045884,,,,,,,,,,,,,, -210000,3.1918566,2.8420408,,,,,,,,,,,,,, -210100,2.9404545,1.292068,,,,,,,,,,,,,, -210200,3.1417134,1.1153467,,,,,,,,,,,,,, -210300,2.940138,1.2066046,,,,,,,,,,,,,, -210400,2.9476469,2.3383265,,,,,,,,,,,,,, -210500,3.2042747,1.8820758,,,,,,,,,,,,,, -210600,3.1025357,1.0632331,,,,,,,,,,,,,, -210700,3.4500988,1.201832,,,,,,,,,,,,,, -210740,,,0.8892187476158142,0.4108535349369049,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,94151.36676955225,101878.605271101,94151.36676955225,7702.275942802429,13.394086122512816,0.0 -210800,3.5784209,3.1433005,,,,,,,,,,,,,, -210900,3.571644,3.0797944,,,,,,,,,,,,,, -211000,3.2556217,1.1633768,,,,,,,,,,,,,, -211100,3.1987207,2.2572322,,,,,,,,,,,,,, -211200,2.8204513,1.1974075,,,,,,,,,,,,,, -211300,3.030236,1.2768945,,,,,,,,,,,,,, -211400,2.8491645,1.8976362,,,,,,,,,,,,,, -211500,3.5618114,2.980112,,,,,,,,,,,,,, -211600,2.964555,2.1238556,,,,,,,,,,,,,, -211682,,,0.8870898485183716,0.419926643371582,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,94571.33133649826,102332.013286829,94571.33133649826,7735.611426591873,13.45100212097168,0.0 -211700,2.9961488,1.1587692,,,,,,,,,,,,,, -211800,3.2308242,1.0571239,,,,,,,,,,,,,, -211900,3.5680983,3.2462144,,,,,,,,,,,,,, -212000,3.2422507,1.1908746,,,,,,,,,,,,,, -212100,3.2543154,1.6915413,,,,,,,,,,,,,, -212200,3.029578,1.1320623,,,,,,,,,,,,,, -212300,3.712277,2.5587869,,,,,,,,,,,,,, -212400,3.3253899,2.2218046,,,,,,,,,,,,,, -212500,3.0092623,1.0652537,,,,,,,,,,,,,, -212600,3.1401181,1.1967647,,,,,,,,,,,,,, -212623,,,0.8870312571525574,0.416995108127594,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,94991.60111403464,102789.38869023325,94991.60111403464,7772.595160007477,13.521557331085203,0.0 -212700,2.9754982,1.0731102,,,,,,,,,,,,,, -212800,3.0842566,1.1291702,,,,,,,,,,,,,, -212900,3.268807,1.15128,,,,,,,,,,,,,, -213000,2.9540079,1.435235,,,,,,,,,,,,,, -213100,2.924819,1.2709378,,,,,,,,,,,,,, -213200,2.9937236,2.2577865,,,,,,,,,,,,,, -213300,2.8631449,1.9468104,,,,,,,,,,,,,, -213400,3.0061438,1.0582057,,,,,,,,,,,,,, -213500,3.7379217,3.2032497,,,,,,,,,,,,,, -213562,,,0.8886132836341858,0.4166717827320099,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,95411.76716446877,103249.92401885986,95411.76716446877,7812.842822313309,13.592507123947144,0.0 -213600,3.1974828,2.3059797,,,,,,,,,,,,,, -213700,3.250837,2.1663058,,,,,,,,,,,,,, -213800,3.0426657,1.0689936,,,,,,,,,,,,,, -213900,3.6398659,3.2621462,,,,,,,,,,,,,, -214000,3.0632613,2.5496492,,,,,,,,,,,,,, -214100,3.2638273,1.1848803,,,,,,,,,,,,,, -214200,4.5070376,3.2561374,,,,,,,,,,,,,, -214300,3.2694683,1.1404306,,,,,,,,,,,,,, -214400,3.4640665,3.1417434,,,,,,,,,,,,,, -214500,3.0626352,2.4541502,,,,,,,,,,,,,, -214501,,,0.888964831829071,0.411579966545105,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,95831.76033210754,103707.87656855585,95831.76033210754,7850.67919921875,13.663443326950071,0.0 -214600,3.4910445,3.1932695,,,,,,,,,,,,,, -214700,3.8625576,3.2466204,,,,,,,,,,,,,, -214800,3.6275487,3.1857393,,,,,,,,,,,,,, -214900,3.0907733,1.4152144,,,,,,,,,,,,,, -215000,2.951234,1.0719141,,,,,,,,,,,,,, -215100,3.3903613,2.85749,,,,,,,,,,,,,, -215200,2.927522,2.347895,,,,,,,,,,,,,, -215300,3.0836647,1.4668391,,,,,,,,,,,,,, -215400,3.115028,1.0493851,,,,,,,,,,,,,, -215442,,,0.8849804401397705,0.4249436855316162,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,96251.69516730309,104171.17848610878,96251.69516730309,7893.881835460663,13.775733470916748,0.0 -215500,3.2083902,1.1815717,,,,,,,,,,,,,, -215600,3.0844603,2.4274535,,,,,,,,,,,,,, -215700,3.9938934,1.085368,,,,,,,,,,,,,, -215800,3.077297,1.5749543,,,,,,,,,,,,,, -215900,2.9316974,2.0305717,,,,,,,,,,,,,, -216000,3.2309513,1.118283,,,,,,,,,,,,,, -216100,3.2914493,1.2659377,,,,,,,,,,,,,, -216200,3.1928413,2.7837203,,,,,,,,,,,,,, -216300,3.0994954,1.1027867,,,,,,,,,,,,,, -216380,,,0.8871093392372131,0.4179880023002624,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,96671.78650546074,104632.13685011864,96671.78650546074,7934.622961759567,13.850823163986206,0.0 -216400,2.9021049,1.0578692,,,,,,,,,,,,,, -216500,2.9828546,1.1912674,,,,,,,,,,,,,, -216600,2.832006,1.0323621,,,,,,,,,,,,,, -216700,3.3735878,2.8550243,,,,,,,,,,,,,, -216800,2.9350467,1.2079688,,,,,,,,,,,,,, -216900,2.9710433,1.0706265,,,,,,,,,,,,,, -217000,3.1667564,1.0711306,,,,,,,,,,,,,, -217100,3.1783993,2.711852,,,,,,,,,,,,,, -217200,2.9793782,1.447572,,,,,,,,,,,,,, -217300,3.0711794,1.0866518,,,,,,,,,,,,,, -217317,,,0.8890234231948853,0.4122736155986786,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,97092.09571146964,105094.53578543664,97092.09571146964,7976.588205337524,13.924319505691528,0.0 -217400,3.2875004,2.3116853,,,,,,,,,,,,,, -217500,2.956056,1.2627172,,,,,,,,,,,,,, -217600,3.274674,2.3934765,,,,,,,,,,,,,, -217700,3.3359127,1.2043769,,,,,,,,,,,,,, -217800,2.926298,1.1337425,,,,,,,,,,,,,, -217900,3.178215,1.1379752,,,,,,,,,,,,,, -218000,2.9854105,1.1938318,,,,,,,,,,,,,, -218100,3.2464178,1.2469096,,,,,,,,,,,,,, -218200,3.4613655,2.8753147,,,,,,,,,,,,,, -218258,,,0.8883007764816284,0.412197470664978,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,97512.15190458298,105548.88929748537,97512.15190458298,8010.768728733063,13.989553689956663,0.0 -218300,3.1445646,1.2353729,,,,,,,,,,,,,, -218400,2.83156,1.1695442,,,,,,,,,,,,,, -218500,3.1279078,1.1652703,,,,,,,,,,,,,, -218600,3.1757112,1.1719551,,,,,,,,,,,,,, -218700,2.792595,1.4722316,,,,,,,,,,,,,, -218800,3.071394,1.9577237,,,,,,,,,,,,,, -218900,4.1165733,3.2364094,,,,,,,,,,,,,, -219000,3.2519162,1.1362131,,,,,,,,,,,,,, -219100,3.1881802,1.7405186,,,,,,,,,,,,,, -219194,,,0.8895312547683716,0.4137906432151794,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,97932.20115685464,106008.19831943512,97932.20115685464,8049.906625032425,14.060953378677368,0.0 -219200,3.1971505,1.1703168,,,,,,,,,,,,,, -219300,3.1075163,1.2653737,,,,,,,,,,,,,, -219400,3.7684429,2.0455036,,,,,,,,,,,,,, -219500,4.0059834,3.1485093,,,,,,,,,,,,,, -219600,3.7817435,3.1696558,,,,,,,,,,,,,, -219700,3.6668591,3.24325,,,,,,,,,,,,,, -219800,3.9811974,2.9057338,,,,,,,,,,,,,, -219900,3.0100477,1.1234752,,,,,,,,,,,,,, -220000,3.2576232,2.6356015,,,,,,,,,,,,,, -220100,5.6613336,1.1329,,,,,,,,,,,,,, -220131,,,0.8866210579872131,0.419018805027008,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,98352.26095962524,106465.86269831656,98352.26095962524,8087.387703418732,14.133442878723145,0.0 -220200,3.2500172,1.1500585,,,,,,,,,,,,,, -220300,5.7484984,1.0352886,,,,,,,,,,,,,, -220400,3.0398304,1.1184425,,,,,,,,,,,,,, -220500,3.2915387,2.5118334,,,,,,,,,,,,,, -220600,2.8329682,1.7470767,,,,,,,,,,,,,, -220700,3.0100038,1.1845348,,,,,,,,,,,,,, -220800,3.0435028,1.1261083,,,,,,,,,,,,,, -220900,3.143213,1.2132827,,,,,,,,,,,,,, -221000,2.996414,1.5901859,,,,,,,,,,,,,, -221070,,,0.8873828053474426,0.4174685180187225,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,98772.34060120584,106925.2056400776,98772.34060120584,8126.485192060471,14.248473167419434,0.0 -221100,3.954736,3.1678755,,,,,,,,,,,,,, -221200,3.0350664,1.1143345,,,,,,,,,,,,,, -221300,2.818035,1.214185,,,,,,,,,,,,,, -221400,3.0175574,2.3136368,,,,,,,,,,,,,, -221500,3.0046518,1.6000648,,,,,,,,,,,,,, -221600,2.984742,1.137238,,,,,,,,,,,,,, -221700,3.43631,2.6730976,,,,,,,,,,,,,, -221800,2.9094388,1.0629412,,,,,,,,,,,,,, -221900,3.182771,1.15955,,,,,,,,,,,,,, -222000,4.5363183,3.1640947,,,,,,,,,,,,,, -222009,,,0.8881054520606995,0.4179070889949798,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,99192.3106303215,107391.95266342165,99192.3106303215,8173.137872457504,14.322559595108032,0.0 -222100,3.0701253,1.8486029,,,,,,,,,,,,,, -222200,3.1825264,1.1128411,,,,,,,,,,,,,, -222300,3.3424892,1.194702,,,,,,,,,,,,,, -222400,3.6432257,3.1626341,,,,,,,,,,,,,, -222500,2.9256992,1.1287732,,,,,,,,,,,,,, -222600,3.247434,1.3214207,,,,,,,,,,,,,, -222700,2.7935214,1.0564653,,,,,,,,,,,,,, -222800,3.04521,1.115558,,,,,,,,,,,,,, -222900,3.2089925,1.2186799,,,,,,,,,,,,,, -222944,,,0.8896679282188416,0.4159463346004486,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,99612.22961091997,107850.00766944884,99612.22961091997,8211.162187337875,14.383384227752686,0.0 -223000,3.573873,3.067035,,,,,,,,,,,,,, -223100,2.9816322,1.2269994,,,,,,,,,,,,,, -223200,2.9444556,1.80912,,,,,,,,,,,,,, -223300,2.9313393,1.1159794,,,,,,,,,,,,,, -223400,3.1003098,1.9479539,,,,,,,,,,,,,, -223500,2.7507372,1.8508632,,,,,,,,,,,,,, -223600,3.912623,3.2227385,,,,,,,,,,,,,, -223700,3.1247523,1.2978518,,,,,,,,,,,,,, -223800,3.0951815,1.4763288,,,,,,,,,,,,,, -223884,,,0.8884570002555847,0.4101370871067047,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,100032.17080593108,108309.85451054572,100032.17080593108,8250.954582214355,14.445920705795288,0.0 -223900,3.5822344,3.3056643,,,,,,,,,,,,,, -224000,3.296783,1.1149211,,,,,,,,,,,,,, -224100,3.392916,2.9041297,,,,,,,,,,,,,, -224200,2.9132442,1.1657647,,,,,,,,,,,,,, -224300,3.1366045,1.0710881,,,,,,,,,,,,,, -224400,3.869364,3.17728,,,,,,,,,,,,,, -224500,3.211026,1.0986962,,,,,,,,,,,,,, -224600,3.245706,1.1606717,,,,,,,,,,,,,, -224700,3.0888436,2.303436,,,,,,,,,,,,,, -224800,3.7927344,3.3010602,,,,,,,,,,,,,, -224823,,,0.88734370470047,0.4182247519493103,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,100452.20421767236,108772.51467132568,100452.20421767236,8293.447283029556,14.529171466827393,0.0 -224900,3.7464101,3.1834326,,,,,,,,,,,,,, -225000,3.493208,3.0613954,,,,,,,,,,,,,, -225100,3.7076967,2.8527298,,,,,,,,,,,,,, -225200,2.8084025,1.9796386,,,,,,,,,,,,,, -225300,2.9564896,1.6982003,,,,,,,,,,,,,, -225400,2.9004166,1.3980162,,,,,,,,,,,,,, -225500,3.5964575,1.1585159,,,,,,,,,,,,,, -225600,2.9145913,2.3954597,,,,,,,,,,,,,, -225700,2.868242,1.021501,,,,,,,,,,,,,, -225764,,,0.8885741829872131,0.4156999588012695,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,100872.4828350544,109236.17139697076,100872.4828350544,8336.703224897385,14.600011348724363,0.0 -225800,3.0654361,1.0623735,,,,,,,,,,,,,, -225900,3.1520922,1.3514699,,,,,,,,,,,,,, -226000,2.8548317,1.6668303,,,,,,,,,,,,,, -226100,3.5365841,3.105379,,,,,,,,,,,,,, -226200,3.048684,1.1099527,,,,,,,,,,,,,, -226300,3.5713603,1.1552206,,,,,,,,,,,,,, -226400,3.7846677,3.231202,,,,,,,,,,,,,, -226500,3.2550745,2.8482828,,,,,,,,,,,,,, -226600,2.9744556,1.1887146,,,,,,,,,,,,,, -226700,3.123794,2.7477007,,,,,,,,,,,,,, -226704,,,0.88880854845047,0.4133599698543548,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,101292.43042826653,109695.03538489342,101292.43042826653,8375.506333589554,14.662203788757324,0.0 -226800,3.0246243,2.5028613,,,,,,,,,,,,,, -226900,4.157565,3.2737021,,,,,,,,,,,,,, -227000,3.0370429,1.1233275,,,,,,,,,,,,,, -227100,3.1268044,1.1586081,,,,,,,,,,,,,, -227200,3.0948713,1.0993669,,,,,,,,,,,,,, -227300,2.8775523,1.0160522,,,,,,,,,,,,,, -227400,2.9106503,2.1748416,,,,,,,,,,,,,, -227500,2.880013,1.2580205,,,,,,,,,,,,,, -227600,3.044893,1.1128993,,,,,,,,,,,,,, -227643,,,0.8872656226158142,0.4160450398921966,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,101712.58194470406,110161.97889208794,101712.58194470406,8422.173720359802,14.73522162437439,0.0 -227700,3.4433446,1.1282004,,,,,,,,,,,,,, -227800,3.072941,1.1860813,,,,,,,,,,,,,, -227900,2.9608421,1.0032959,,,,,,,,,,,,,, -228000,4.822114,2.9098592,,,,,,,,,,,,,, -228100,3.2455468,2.3821387,,,,,,,,,,,,,, -228200,3.3009253,2.8469377,,,,,,,,,,,,,, -228300,3.024016,1.1569432,,,,,,,,,,,,,, -228400,3.3387961,1.4350343,,,,,,,,,,,,,, -228500,3.065603,1.1629742,,,,,,,,,,,,,, -228582,,,0.8897070288658142,0.4089111089706421,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,102132.54553413393,110627.42151880264,102132.54553413393,8467.515223264694,14.81779718399048,0.0 -228600,2.9891124,1.0869188,,,,,,,,,,,,,, -228700,3.0477548,2.1680698,,,,,,,,,,,,,, -228800,2.8809505,1.2519828,,,,,,,,,,,,,, -228900,3.159465,2.581233,,,,,,,,,,,,,, -229000,3.089278,1.1454424,,,,,,,,,,,,,, -229100,2.8631065,1.156218,,,,,,,,,,,,,, -229200,3.2242768,1.1663696,,,,,,,,,,,,,, -229300,3.097178,2.3998585,,,,,,,,,,,,,, -229400,3.3072696,1.3064474,,,,,,,,,,,,,, -229500,3.2900789,1.4520231,,,,,,,,,,,,,, -229521,,,0.8869531154632568,0.4167779386043548,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,102552.66275954248,111086.84193229675,102552.66275954248,8506.700605392456,14.884428977966309,0.0 -229600,3.7616575,3.2259784,,,,,,,,,,,,,, -229700,3.298965,1.412731,,,,,,,,,,,,,, -229800,3.1202266,1.0714409,,,,,,,,,,,,,, -229900,3.0865068,1.1421652,,,,,,,,,,,,,, -230000,3.0392156,1.1225804,,,,,,,,,,,,,, -230100,3.211534,2.2036388,,,,,,,,,,,,,, -230200,3.4327931,3.0662618,,,,,,,,,,,,,, -230300,3.3354404,1.7113823,,,,,,,,,,,,,, -230400,3.945065,1.4168807,,,,,,,,,,,,,, -230459,,,0.8913085460662842,0.4091808199882507,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,102972.72747325896,111545.024851799,102972.72747325896,8544.661287546158,14.990256071090698,0.0 -230500,3.331139,2.7789872,,,,,,,,,,,,,, -230600,2.8734,1.2883875,,,,,,,,,,,,,, -230700,3.203162,1.7832696,,,,,,,,,,,,,, -230800,3.1211634,1.2402767,,,,,,,,,,,,,, -230900,2.8607905,1.6569908,,,,,,,,,,,,,, -231000,2.9910223,1.0241885,,,,,,,,,,,,,, -231100,3.1736386,1.2134647,,,,,,,,,,,,,, -231200,3.2030265,2.5276098,,,,,,,,,,,,,, -231300,3.137538,1.1088895,,,,,,,,,,,,,, -231386,,,0.8886132836341858,0.41818568110466,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,103392.72380638124,112006.82755184174,103392.72380638124,8586.344784021378,15.06243109703064,0.0 -231400,3.91465,3.2473462,,,,,,,,,,,,,, -231500,3.1144085,1.1068307,,,,,,,,,,,,,, -231600,3.27145,2.91298,,,,,,,,,,,,,, -231700,3.710618,2.5778577,,,,,,,,,,,,,, -231800,3.3485646,1.9344888,,,,,,,,,,,,,, -231900,3.6829562,3.2360992,,,,,,,,,,,,,, -232000,3.2060018,1.2139463,,,,,,,,,,,,,, -232100,2.951781,1.0867583,,,,,,,,,,,,,, -232200,3.123196,1.3622773,,,,,,,,,,,,,, -232300,2.996321,1.9106735,,,,,,,,,,,,,, -232323,,,0.8895898461341858,0.4101483821868896,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,103812.9854183197,112467.85011029243,103812.9854183197,8626.981405496597,15.136332988739014,0.0 -232400,2.7630775,1.6223212,,,,,,,,,,,,,, -232500,3.1140308,1.1270461,,,,,,,,,,,,,, -232600,2.9272645,2.5516653,,,,,,,,,,,,,, -232700,2.9496214,1.1486764,,,,,,,,,,,,,, -232800,3.9377172,3.2185326,,,,,,,,,,,,,, -232900,2.8269913,1.4405949,,,,,,,,,,,,,, -233000,3.1393619,1.1517657,,,,,,,,,,,,,, -233100,3.2213795,1.1226138,,,,,,,,,,,,,, -233200,2.9039857,1.9998869,,,,,,,,,,,,,, -233264,,,0.8911718726158142,0.4089621007442474,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,104233.0490720272,112924.8316013813,104233.0490720272,8663.754947662354,15.229785442352297,0.0 -233300,3.1871479,1.0775794,,,,,,,,,,,,,, -233400,2.8911033,1.5965272,,,,,,,,,,,,,, -233500,3.1681876,2.2457638,,,,,,,,,,,,,, -233600,3.2082496,2.903433,,,,,,,,,,,,,, -233700,3.3859081,1.2189398,,,,,,,,,,,,,, -233800,2.9460175,1.2206048,,,,,,,,,,,,,, -233900,3.0660176,1.1200523,,,,,,,,,,,,,, -234000,3.176663,1.9803396,,,,,,,,,,,,,, -234100,3.05209,1.0607361,,,,,,,,,,,,,, -234200,3.8156655,3.2011225,,,,,,,,,,,,,, -234202,,,0.8874609470367432,0.4192294180393219,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,104652.99808740616,113384.81877231598,104652.99808740616,8703.668316602707,15.30400776863098,0.0 -234300,3.2114432,1.1048052,,,,,,,,,,,,,, -234400,3.0746126,1.1430652,,,,,,,,,,,,,, -234500,2.8287709,2.342237,,,,,,,,,,,,,, -234600,2.9124444,1.4688759,,,,,,,,,,,,,, -234700,3.6236825,3.2622287,,,,,,,,,,,,,, -234800,2.981055,1.5740552,,,,,,,,,,,,,, -234900,3.142106,1.0291198,,,,,,,,,,,,,, -235000,3.098982,2.7175684,,,,,,,,,,,,,, -235100,3.6454077,3.2151647,,,,,,,,,,,,,, -235137,,,0.8865624666213989,0.4175738096237182,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,105072.99116683006,113843.39495277403,105072.99116683006,8742.130467414856,15.374670505523682,0.0 -235200,3.1614435,1.8442949,,,,,,,,,,,,,, -235300,3.024923,2.7620537,,,,,,,,,,,,,, -235400,3.0661905,1.1557858,,,,,,,,,,,,,, -235500,2.8696196,1.897882,,,,,,,,,,,,,, -235600,2.9326801,1.3636171,,,,,,,,,,,,,, -235700,3.8070347,3.0582194,,,,,,,,,,,,,, -235800,2.952505,1.0808935,,,,,,,,,,,,,, -235900,3.1137712,2.5134432,,,,,,,,,,,,,, -236000,3.1974864,1.0868914,,,,,,,,,,,,,, -236073,,,0.8865820169448853,0.418828547000885,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,105493.27949905396,114303.88598370552,105493.27949905396,8782.207236289978,15.449368953704834,0.0 -236100,3.0109663,1.1799998,,,,,,,,,,,,,, -236200,2.9455523,1.101951,,,,,,,,,,,,,, -236300,4.4115667,1.2339951,,,,,,,,,,,,,, -236400,3.1333187,2.2972505,,,,,,,,,,,,,, -236500,3.2068543,2.9127653,,,,,,,,,,,,,, -236600,3.3506525,3.0082126,,,,,,,,,,,,,, -236700,3.5061793,1.1470273,,,,,,,,,,,,,, -236800,3.104468,2.6400971,,,,,,,,,,,,,, -236900,3.2994826,1.2921406,,,,,,,,,,,,,, -237000,3.2340648,1.4066756,,,,,,,,,,,,,, -237012,,,0.8889843821525574,0.4155234694480896,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,105913.54170560835,114762.76225996016,105913.54170560835,8820.693273544312,15.526480674743652,0.0 -237100,2.8870912,1.822934,,,,,,,,,,,,,, -237200,2.9843256,1.8371723,,,,,,,,,,,,,, -237300,4.740239,3.2386813,,,,,,,,,,,,,, -237400,3.313515,2.9093516,,,,,,,,,,,,,, -237500,3.0624254,1.7055577,,,,,,,,,,,,,, -237600,3.0135822,2.6191983,,,,,,,,,,,,,, -237700,3.2543042,1.2726868,,,,,,,,,,,,,, -237800,3.0995708,1.1169745,,,,,,,,,,,,,, -237900,2.9104795,2.3516724,,,,,,,,,,,,,, -237948,,,0.8892773389816284,0.4160856008529663,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,106333.91696500778,115228.73643136024,106333.91696500778,8866.161488056183,15.605794191360474,0.0 -238000,2.813104,1.2274942,,,,,,,,,,,,,, -238100,2.9681952,1.897635,,,,,,,,,,,,,, -238200,2.935226,1.2045134,,,,,,,,,,,,,, -238300,3.2353725,1.1957519,,,,,,,,,,,,,, -238400,3.0516162,1.0559456,,,,,,,,,,,,,, -238500,3.315273,1.1097505,,,,,,,,,,,,,, -238600,3.0238924,1.3198375,,,,,,,,,,,,,, -238700,3.0381842,2.4995298,,,,,,,,,,,,,, -238800,3.1008353,1.1423991,,,,,,,,,,,,,, -238888,,,0.8872265219688416,0.4133851826190948,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,106754.01340723038,115687.73761057854,106754.01340723038,8904.943639278412,15.677077531814575,0.0 -238900,3.0661926,1.0510062,,,,,,,,,,,,,, -239000,2.914305,1.10709,,,,,,,,,,,,,, -239100,2.7868998,1.0406312,,,,,,,,,,,,,, -239200,3.2280562,1.1107984,,,,,,,,,,,,,, -239300,3.619951,1.0732855,,,,,,,,,,,,,, -239400,3.1705499,2.300397,,,,,,,,,,,,,, -239500,3.0283134,1.5359015,,,,,,,,,,,,,, -239600,3.0892732,2.5838919,,,,,,,,,,,,,, -239700,2.8706832,1.2059791,,,,,,,,,,,,,, -239800,3.2395315,1.165259,,,,,,,,,,,,,, -239829,,,0.8869531154632568,0.4212391078472137,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,107174.28653025629,116144.81491804124,107174.28653025629,8941.621190309525,15.752530097961426,0.0 -239900,3.0842779,1.0258672,,,,,,,,,,,,,, -240000,3.2396307,2.6543274,,,,,,,,,,,,,, -240100,3.682383,3.1399426,,,,,,,,,,,,,, -240200,2.966341,2.301499,,,,,,,,,,,,,, -240300,3.0080576,1.1300013,,,,,,,,,,,,,, -240400,3.1893904,1.7664914,,,,,,,,,,,,,, -240500,3.0194366,1.9358389,,,,,,,,,,,,,, -240600,3.5988595,3.2461762,,,,,,,,,,,,,, -240700,3.1691668,1.1948503,,,,,,,,,,,,,, -240767,,,0.8889062404632568,0.4103345870971679,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,107594.49706172945,116602.501486063,107594.49706172945,8978.973022460938,15.82563543319702,0.0 -240800,3.2555056,1.133436,,,,,,,,,,,,,, -240900,3.233027,1.1995354,,,,,,,,,,,,,, -241000,3.0224779,1.3266081,,,,,,,,,,,,,, -241100,2.8142004,1.2225246,,,,,,,,,,,,,, -241200,3.6700096,2.8063653,,,,,,,,,,,,,, -241300,2.943671,1.1795408,,,,,,,,,,,,,, -241400,3.1450145,2.5699205,,,,,,,,,,,,,, -241500,2.9932222,1.1438657,,,,,,,,,,,,,, -241600,4.2990437,3.253614,,,,,,,,,,,,,, -241700,3.0086436,1.3117517,,,,,,,,,,,,,, -241705,,,0.8889257907867432,0.4125271141529083,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,108014.80342388152,117064.91074442863,108014.80342388152,9020.95052599907,15.900647640228271,0.0 -241800,3.1428988,1.2061126,,,,,,,,,,,,,, -241900,3.000632,1.4046355,,,,,,,,,,,,,, -242000,3.6207995,3.1890144,,,,,,,,,,,,,, -242100,2.8835158,2.1245892,,,,,,,,,,,,,, -242200,2.96141,1.4004501,,,,,,,,,,,,,, -242300,3.2619085,1.8249643,,,,,,,,,,,,,, -242400,3.1568599,1.3525264,,,,,,,,,,,,,, -242500,3.0420222,1.0822879,,,,,,,,,,,,,, -242600,3.2515082,1.8057528,,,,,,,,,,,,,, -242644,,,0.8873046636581421,0.4170242846012115,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,108434.92385172844,117523.58127760889,108434.92385172844,9059.374162912369,15.975814819335938,0.0 -242700,2.8472815,1.0913856,,,,,,,,,,,,,, -242800,2.8940828,1.1364383,,,,,,,,,,,,,, -242900,3.0203438,1.7626423,,,,,,,,,,,,,, -243000,2.8620503,1.8634443,,,,,,,,,,,,,, -243100,3.4387996,2.9493082,,,,,,,,,,,,,, -243200,3.7626932,2.2205634,,,,,,,,,,,,,, -243300,3.4810102,2.9981852,,,,,,,,,,,,,, -243400,3.1715336,1.0691713,,,,,,,,,,,,,, -243500,2.9680402,1.3562448,,,,,,,,,,,,,, -243578,,,0.8865820169448853,0.4223847091197967,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,108855.2262763977,117981.85328555109,108855.2262763977,9097.214118480682,16.05459976196289,0.0 -243600,3.0775967,1.1189734,,,,,,,,,,,,,, -243700,2.8647623,1.1121235,,,,,,,,,,,,,, -243800,3.4331367,3.1524231,,,,,,,,,,,,,, -243900,3.0620954,1.0830326,,,,,,,,,,,,,, -244000,3.618924,3.2335346,,,,,,,,,,,,,, -244100,3.1453938,1.2648809,,,,,,,,,,,,,, -244200,2.9992266,1.1212931,,,,,,,,,,,,,, -244300,3.2314122,1.0877101,,,,,,,,,,,,,, -244400,3.3333278,2.8274822,,,,,,,,,,,,,, -244500,3.1245327,1.1245636,,,,,,,,,,,,,, -244515,,,0.8882226347923279,0.4138920903205871,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,109275.20730948448,118441.95490217207,109275.20730948448,9137.2043967247,16.133418798446655,0.0 -244600,3.1978736,1.1348009,,,,,,,,,,,,,, -244700,4.3299537,2.9737182,,,,,,,,,,,,,, -244800,3.3793085,3.180986,,,,,,,,,,,,,, -244900,3.0530882,1.859439,,,,,,,,,,,,,, -245000,2.8865178,2.4051204,,,,,,,,,,,,,, -245100,3.066691,1.16677,,,,,,,,,,,,,, -245200,3.2390537,1.0814444,,,,,,,,,,,,,, -245300,2.7458737,1.151235,,,,,,,,,,,,,, -245400,2.92403,1.0783767,,,,,,,,,,,,,, -245454,,,0.887499988079071,0.4170665740966797,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,109695.28633069992,118900.72915911674,109695.28633069992,9175.77440404892,16.207163333892822,0.0 -245500,2.971754,1.2126374,,,,,,,,,,,,,, -245600,2.9021895,1.1233952,,,,,,,,,,,,,, -245700,3.5740898,3.134584,,,,,,,,,,,,,, -245800,3.8332329,3.3098722,,,,,,,,,,,,,, -245900,3.247685,2.350773,,,,,,,,,,,,,, -246000,3.040199,1.1795807,,,,,,,,,,,,,, -246100,3.0471482,1.1033216,,,,,,,,,,,,,, -246200,2.9870136,1.3108916,,,,,,,,,,,,,, -246300,3.188154,1.1641308,,,,,,,,,,,,,, -246392,,,0.8889257907867432,0.4171882271766662,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,110115.32331180573,119359.3913846016,110115.32331180573,9214.267620325089,16.286096334457397,0.0 -246400,3.7224371,1.1884873,,,,,,,,,,,,,, -246500,2.9883022,2.1567745,,,,,,,,,,,,,, -246600,3.0857427,2.114093,,,,,,,,,,,,,, -246700,2.8209767,1.9973801,,,,,,,,,,,,,, -246800,3.3368855,1.0801159,,,,,,,,,,,,,, -246900,2.8865652,1.9772497,,,,,,,,,,,,,, -247000,4.0704055,3.1513164,,,,,,,,,,,,,, -247100,3.0230734,1.1938635,,,,,,,,,,,,,, -247200,3.8678122,2.8875582,,,,,,,,,,,,,, -247300,3.1353796,1.1469592,,,,,,,,,,,,,, -247331,,,0.8894921541213989,0.4114971458911896,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,110535.2491569519,119822.7629380226,110535.2491569519,9257.57844209671,16.369798183441162,0.0 -247400,3.7519433,3.2778254,,,,,,,,,,,,,, -247500,3.4786742,1.1307591,,,,,,,,,,,,,, -247600,2.6595597,1.5177357,,,,,,,,,,,,,, -247700,2.917115,1.2226155,,,,,,,,,,,,,, -247800,3.3936272,2.7823048,,,,,,,,,,,,,, -247900,3.1855483,1.6175494,,,,,,,,,,,,,, -248000,3.21118,1.14331,,,,,,,,,,,,,, -248100,2.9889276,1.9979267,,,,,,,,,,,,,, -248200,2.7934554,1.023144,,,,,,,,,,,,,, -248273,,,0.8887109160423279,0.4138701260089874,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,110955.36974859238,120278.73826503754,110955.36974859238,9293.305986881256,16.446000337600708,0.0 -248300,2.9913418,1.0665536,,,,,,,,,,,,,, -248400,2.9381719,1.9898932,,,,,,,,,,,,,, -248500,2.783067,1.5617307,,,,,,,,,,,,,, -248600,2.8788195,1.6084371,,,,,,,,,,,,,, -248700,2.9045331,2.0640597,,,,,,,,,,,,,, -248800,3.2320883,1.1313779,,,,,,,,,,,,,, -248900,2.7846456,1.1105552,,,,,,,,,,,,,, -249000,3.26594,1.1834762,,,,,,,,,,,,,, -249100,3.2052839,2.7237692,,,,,,,,,,,,,, -249200,3.0667186,1.2830352,,,,,,,,,,,,,, -249210,,,0.8867577910423279,0.4228262901306152,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,111375.37278437614,120741.10971736908,111375.37278437614,9335.543090820312,16.526723384857178,0.0 -249300,3.4067051,1.2794166,,,,,,,,,,,,,, -249400,3.5011034,1.1548163,,,,,,,,,,,,,, -249500,3.1818087,2.8460839,,,,,,,,,,,,,, -249600,3.1916392,1.297792,,,,,,,,,,,,,, -249700,3.0570219,2.104302,,,,,,,,,,,,,, -249800,3.1567366,1.0994532,,,,,,,,,,,,,, -249900,2.8559783,1.0949211,,,,,,,,,,,,,, -250000,2.8956978,1.1777235,,,,,,,,,,,,,, -250100,3.0766273,1.210463,,,,,,,,,,,,,, -250152,,,0.8893749713897705,0.4094995856285095,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,111795.39901304244,121195.81096291542,111795.39901304244,9370.087344884872,16.606194257736206,0.0 -250200,4.212173,3.174505,,,,,,,,,,,,,, -250300,3.0623815,1.430499,,,,,,,,,,,,,, -250400,3.1341395,1.061484,,,,,,,,,,,,,, -250500,3.2327452,1.4225985,,,,,,,,,,,,,, -250600,2.7744591,1.1275539,,,,,,,,,,,,,, -250700,3.0344515,2.423112,,,,,,,,,,,,,, -250800,2.921076,1.0475225,,,,,,,,,,,,,, -250900,2.8609452,1.329716,,,,,,,,,,,,,, -251000,2.9928546,1.3273367,,,,,,,,,,,,,, -251090,,,0.8884179592132568,0.4091140627861023,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,112215.63183784483,121660.96054553986,112215.63183784483,9414.86929488182,16.68982768058777,0.0 -251100,2.8401845,1.081412,,,,,,,,,,,,,, -251200,3.1974983,2.5156343,,,,,,,,,,,,,, -251300,3.1176443,1.8414778,,,,,,,,,,,,,, -251400,2.9650753,1.0692487,,,,,,,,,,,,,, -251500,3.0998774,1.1782273,,,,,,,,,,,,,, -251600,3.034027,1.1541965,,,,,,,,,,,,,, -251700,3.092823,2.3475347,,,,,,,,,,,,,, -251800,3.0106251,1.2177753,,,,,,,,,,,,,, -251900,3.1960552,2.5082068,,,,,,,,,,,,,, -252000,2.9825823,2.157835,,,,,,,,,,,,,, -252034,,,0.8879687190055847,0.4145607650279999,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,112635.64821743964,122126.73324251176,112635.64821743964,9460.509674549105,16.75420641899109,0.0 -252100,2.8432019,1.1440924,,,,,,,,,,,,,, -252200,3.0720139,1.0591022,,,,,,,,,,,,,, -252300,2.912516,1.0519559,,,,,,,,,,,,,, -252400,3.1237833,1.7079607,,,,,,,,,,,,,, -252500,3.0083256,1.1195709,,,,,,,,,,,,,, -252600,3.6493642,3.1930192,,,,,,,,,,,,,, -252700,3.5632796,1.7825378,,,,,,,,,,,,,, -252800,3.9972553,3.0298936,,,,,,,,,,,,,, -252900,3.203027,1.0873845,,,,,,,,,,,,,, -252976,,,0.8885937333106995,0.41547492146492,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,113055.63692212103,122590.56933999062,113055.63692212103,9504.239057064056,16.821215629577637,0.0 -253000,3.4401407,3.061526,,,,,,,,,,,,,, -253100,3.0993772,1.0463569,,,,,,,,,,,,,, -253200,3.153684,2.8852878,,,,,,,,,,,,,, -253300,3.1978393,1.6791464,,,,,,,,,,,,,, -253400,3.0349534,1.8813057,,,,,,,,,,,,,, -253500,4.0491705,1.0774244,,,,,,,,,,,,,, -253600,3.0945258,1.1317621,,,,,,,,,,,,,, -253700,3.1420093,1.7047341,,,,,,,,,,,,,, -253800,3.1718051,1.0891601,,,,,,,,,,,,,, -253900,3.933628,3.264521,,,,,,,,,,,,,, -253915,,,0.8883593678474426,0.4171532988548279,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,113475.78511810304,123048.69641470908,113475.78511810304,9542.089750528336,16.89837408065796,0.0 -254000,3.1118245,1.7995989,,,,,,,,,,,,,, -254100,3.0615613,1.1255193,,,,,,,,,,,,,, -254200,3.0215418,1.0550884,,,,,,,,,,,,,, -254300,3.4287136,3.0228925,,,,,,,,,,,,,, -254400,3.1754835,1.1329575,,,,,,,,,,,,,, -254500,2.9253297,1.1664815,,,,,,,,,,,,,, -254600,3.167535,1.0944514,,,,,,,,,,,,,, -254700,3.2496579,2.7172763,,,,,,,,,,,,,, -254800,3.264071,1.6468261,,,,,,,,,,,,,, -254846,,,0.8913671970367432,0.4085223972797394,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,113895.78473091124,123508.19112324716,113895.78473091124,9581.454507350922,16.978004217147827,0.0 -254900,2.9885569,1.171525,,,,,,,,,,,,,, -255000,3.0606847,1.1897398,,,,,,,,,,,,,, -255100,3.4564726,1.143481,,,,,,,,,,,,,, -255200,3.1352296,2.7463553,,,,,,,,,,,,,, -255300,2.957893,1.4253211,,,,,,,,,,,,,, -255400,3.042448,1.1281317,,,,,,,,,,,,,, -255500,3.1116695,1.1394233,,,,,,,,,,,,,, -255600,3.2080166,1.4948919,,,,,,,,,,,,,, -255700,2.948496,1.9980242,,,,,,,,,,,,,, -255785,,,0.8885351419448853,0.4148720502853393,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,114315.99058461188,123973.41496372224,114315.99058461188,9626.326175928116,17.073423862457275,0.0 -255800,3.0195804,1.1422428,,,,,,,,,,,,,, -255900,2.874692,1.2022166,,,,,,,,,,,,,, -256000,3.173679,1.221292,,,,,,,,,,,,,, -256100,3.242113,1.2745316,,,,,,,,,,,,,, -256200,3.480969,2.5073576,,,,,,,,,,,,,, -256300,2.9203894,1.4104884,,,,,,,,,,,,,, -256400,2.9985034,1.1386946,,,,,,,,,,,,,, -256500,3.3740792,2.960947,,,,,,,,,,,,,, -256600,3.2073202,2.0726523,,,,,,,,,,,,,, -256700,3.1562006,1.1177851,,,,,,,,,,,,,, -256728,,,0.8897265195846558,0.4137465059757232,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,114736.19265317915,124426.6122317314,114736.19265317915,9659.201932668686,17.14213752746582,0.0 -256800,3.3091817,2.0613782,,,,,,,,,,,,,, -256900,3.265509,1.1780938,,,,,,,,,,,,,, -257000,2.8848991,1.1743715,,,,,,,,,,,,,, -257100,3.2175746,2.1119637,,,,,,,,,,,,,, -257200,2.9696162,1.4498253,,,,,,,,,,,,,, -257300,3.263485,1.0986094,,,,,,,,,,,,,, -257400,2.9303598,1.3948075,,,,,,,,,,,,,, -257500,3.3966932,1.0303485,,,,,,,,,,,,,, -257600,3.0273015,1.3447691,,,,,,,,,,,,,, -257655,,,0.8881640434265137,0.4150804877281189,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,115156.30558776855,124887.1065568924,115156.30558776855,9699.450918912888,17.224257707595825,0.0 -257700,3.316737,1.2762445,,,,,,,,,,,,,, -257800,7.0507975,3.0669844,,,,,,,,,,,,,, -257900,2.8507311,1.6189154,,,,,,,,,,,,,, -258000,3.4354181,1.1703244,,,,,,,,,,,,,, -258100,2.8141763,1.1013566,,,,,,,,,,,,,, -258200,3.0169442,1.475668,,,,,,,,,,,,,, -258300,2.7310877,1.9708222,,,,,,,,,,,,,, -258400,3.0685875,1.2878116,,,,,,,,,,,,,, -258500,3.0607767,1.0794398,,,,,,,,,,,,,, -258593,,,0.8871288895606995,0.4156031608581543,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,115576.53962373734,125355.9505121708,115576.53962373734,9747.919601917269,17.31403613090515,0.0 -258600,2.9527025,1.3245445,,,,,,,,,,,,,, -258700,2.8524117,1.5599221,,,,,,,,,,,,,, -258800,3.108282,1.3287338,,,,,,,,,,,,,, -258900,2.9650264,1.2115754,,,,,,,,,,,,,, -259000,3.063222,1.2020751,,,,,,,,,,,,,, -259100,2.8864114,1.6523671,,,,,,,,,,,,,, -259200,3.201368,1.1391814,,,,,,,,,,,,,, -259300,3.0400887,1.0886436,,,,,,,,,,,,,, -259400,3.0907962,1.1792392,,,,,,,,,,,,,, -259500,3.0890665,1.0745026,,,,,,,,,,,,,, -259536,,,0.8883788585662842,0.4144641458988189,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,115996.80192518234,125816.06837558746,115996.80192518234,9787.654287099838,17.384129524230957,0.0 -259600,3.046397,1.2093682,,,,,,,,,,,,,, -259700,2.969711,1.0443474,,,,,,,,,,,,,, -259800,3.1011112,1.1615556,,,,,,,,,,,,,, -259900,3.0795171,1.1046033,,,,,,,,,,,,,, -260000,3.101325,1.3524662,,,,,,,,,,,,,, -260100,4.1179895,3.2320392,,,,,,,,,,,,,, -260200,2.9981937,1.1310436,,,,,,,,,,,,,, -260300,3.524949,3.1224308,,,,,,,,,,,,,, -260400,3.2477312,1.6311704,,,,,,,,,,,,,, -260473,,,0.888671875,0.4139718115329742,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,116417.05393886566,126278.66386461258,116417.05393886566,9829.864854097366,17.466994762420654,0.0 -260500,4.25787,3.264213,,,,,,,,,,,,,, -260600,3.0867128,1.1701,,,,,,,,,,,,,, -260700,3.0149179,1.1415204,,,,,,,,,,,,,, -260800,3.6792657,2.911159,,,,,,,,,,,,,, -260900,2.9948363,1.1150389,,,,,,,,,,,,,, -261000,3.3813715,1.5805793,,,,,,,,,,,,,, -261100,3.127759,2.42002,,,,,,,,,,,,,, -261200,2.8534386,1.0261918,,,,,,,,,,,,,, -261300,3.0348065,2.4637868,,,,,,,,,,,,,, -261400,3.1665854,1.1343479,,,,,,,,,,,,,, -261413,,,0.8889257907867432,0.4145142734050751,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,116837.02300691605,126733.23106122015,116837.02300691605,9864.348781108856,17.530568838119507,0.0 -261500,3.2550242,2.817955,,,,,,,,,,,,,, -261600,2.9564104,1.1264002,,,,,,,,,,,,,, -261700,3.9232156,3.3357859,,,,,,,,,,,,,, -261800,3.1807258,1.2077031,,,,,,,,,,,,,, -261900,2.979753,1.5665581,,,,,,,,,,,,,, -262000,3.1199045,1.1593006,,,,,,,,,,,,,, -262100,3.0313008,2.0534768,,,,,,,,,,,,,, -262200,3.1161513,2.4884183,,,,,,,,,,,,,, -262300,3.1377733,1.0560223,,,,,,,,,,,,,, -262350,,,0.8883984088897705,0.4164403676986694,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,117257.40816497804,127202.38626980782,117257.40816497804,9912.99257516861,17.60601305961609,0.0 -262400,3.149448,1.2194257,,,,,,,,,,,,,, -262500,2.9847775,1.0590128,,,,,,,,,,,,,, -262600,2.6766033,1.5067143,,,,,,,,,,,,,, -262700,2.9017048,1.1194204,,,,,,,,,,,,,, -262800,2.8929725,1.135876,,,,,,,,,,,,,, -262900,3.0530603,1.3272289,,,,,,,,,,,,,, -263000,3.1247075,1.0965277,,,,,,,,,,,,,, -263100,3.06709,1.114027,,,,,,,,,,,,,, -263200,3.1502538,2.0789788,,,,,,,,,,,,,, -263291,,,0.8852343559265137,0.4217980802059173,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,117677.55997300148,127661.7043390274,117677.55997300148,9952.029757261276,17.68432903289795,0.0 -263300,3.0023813,1.099759,,,,,,,,,,,,,, -263400,2.7668915,1.4849907,,,,,,,,,,,,,, -263500,3.931527,2.4062395,,,,,,,,,,,,,, -263600,3.270973,1.0655471,,,,,,,,,,,,,, -263700,2.9744308,1.1004738,,,,,,,,,,,,,, -263800,3.328913,1.1380956,,,,,,,,,,,,,, -263900,2.989404,1.144176,,,,,,,,,,,,,, -264000,2.8582618,2.1416252,,,,,,,,,,,,,, -264100,3.007635,1.1899492,,,,,,,,,,,,,, -264200,4.0760627,3.1721883,,,,,,,,,,,,,, -264227,,,0.8889843821525574,0.4135739505290985,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,118097.5377099514,128124.80843949318,118097.5377099514,9995.02408361435,17.765547513961792,0.0 -264300,2.8396628,1.5040482,,,,,,,,,,,,,, -264400,3.499853,2.035755,,,,,,,,,,,,,, -264500,3.5134647,2.9473238,,,,,,,,,,,,,, -264600,3.2794352,1.9345262,,,,,,,,,,,,,, -264700,3.0545096,1.1469125,,,,,,,,,,,,,, -264800,2.97523,1.1217304,,,,,,,,,,,,,, -264900,2.8321424,1.4655559,,,,,,,,,,,,,, -265000,2.9442039,1.1818123,,,,,,,,,,,,,, -265100,2.9018445,1.0708075,,,,,,,,,,,,,, -265162,,,0.8879492282867432,0.4144449830055237,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,118517.52534794807,128580.78847575188,118517.52534794807,10030.878251314163,17.853434562683105,0.0 -265200,3.3976235,1.2241404,,,,,,,,,,,,,, -265300,3.328179,1.1704737,,,,,,,,,,,,,, -265400,2.9131584,1.3044481,,,,,,,,,,,,,, -265500,2.9295232,1.6249605,,,,,,,,,,,,,, -265600,2.83995,1.431293,,,,,,,,,,,,,, -265700,3.2351916,2.7967503,,,,,,,,,,,,,, -265800,3.0559714,1.8680651,,,,,,,,,,,,,, -265900,3.1773617,1.0939476,,,,,,,,,,,,,, -266000,2.884602,1.6267067,,,,,,,,,,,,,, -266097,,,0.8884179592132568,0.4173905849456787,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,118937.56853723526,129039.0865252018,118937.56853723526,10069.000126123428,17.936197996139526,0.0 -266100,3.116917,2.6168199,,,,,,,,,,,,,, -266200,3.1928704,1.8773398,,,,,,,,,,,,,, -266300,3.0094502,1.131134,,,,,,,,,,,,,, -266400,4.5087495,2.4653018,,,,,,,,,,,,,, -266500,3.13257,1.0892735,,,,,,,,,,,,,, -266600,3.0784233,1.1299181,,,,,,,,,,,,,, -266700,3.373793,2.9161239,,,,,,,,,,,,,, -266800,3.7112286,3.105999,,,,,,,,,,,,,, -266900,2.8893673,2.2892256,,,,,,,,,,,,,, -267000,2.9859898,1.0881313,,,,,,,,,,,,,, -267036,,,0.8875976204872131,0.4157011210918426,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,119357.60832476616,129502.25046014786,119357.60832476616,10111.985719919205,18.02430295944214,0.0 -267100,2.8461115,1.1387128,,,,,,,,,,,,,, -267200,3.377846,1.1672635,,,,,,,,,,,,,, -267300,3.0078495,1.3506938,,,,,,,,,,,,,, -267400,2.9133408,2.234135,,,,,,,,,,,,,, -267500,2.9557521,1.0553226,,,,,,,,,,,,,, -267600,3.0822725,2.0949805,,,,,,,,,,,,,, -267700,3.0811434,2.187859,,,,,,,,,,,,,, -267800,3.1354346,1.0691252,,,,,,,,,,,,,, -267900,3.372491,1.2899604,,,,,,,,,,,,,, -267975,,,0.8877733945846558,0.4151789844036102,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,119777.81729054452,129960.78221297264,119777.81729054452,10150.179717302322,18.103163957595825,0.0 -268000,3.085818,1.6590173,,,,,,,,,,,,,, -268100,2.9894063,1.1122023,,,,,,,,,,,,,, -268200,2.8956127,1.2324842,,,,,,,,,,,,,, -268300,2.9824934,1.0708634,,,,,,,,,,,,,, -268400,3.0649428,1.3117603,,,,,,,,,,,,,, -268500,3.0849323,1.2180215,,,,,,,,,,,,,, -268600,3.8368227,3.1968658,,,,,,,,,,,,,, -268700,3.4727168,1.0449054,,,,,,,,,,,,,, -268800,2.8519669,1.3768872,,,,,,,,,,,,,, -268900,3.1096604,1.1282239,,,,,,,,,,,,,, -268913,,,0.8890820145606995,0.411138653755188,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,120197.99805498125,130421.24511313438,120197.99805498125,10190.332374095917,18.181214332580566,0.0 -269000,2.795878,1.3611567,,,,,,,,,,,,,, -269100,3.0686684,1.5134866,,,,,,,,,,,,,, -269200,3.4080482,2.8610857,,,,,,,,,,,,,, -269300,2.9574137,1.1732075,,,,,,,,,,,,,, -269400,2.869007,2.3968627,,,,,,,,,,,,,, -269500,3.497834,1.2294433,,,,,,,,,,,,,, -269600,3.238881,2.764506,,,,,,,,,,,,,, -269700,2.892144,1.1324174,,,,,,,,,,,,,, -269800,2.9201152,1.1289239,,,,,,,,,,,,,, -269850,,,0.8871679306030273,0.4229635894298553,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,120617.99323821068,130879.83401083946,120617.99323821068,10228.79074215889,18.266133069992065,0.0 -269900,3.0928488,1.1367341,,,,,,,,,,,,,, -270000,3.8037949,3.278934,,,,,,,,,,,,,, -270100,3.2873962,1.7493358,,,,,,,,,,,,,, -270200,2.9269495,2.6109838,,,,,,,,,,,,,, -270300,3.0764797,1.1446557,,,,,,,,,,,,,, -270400,2.7695937,1.6477628,,,,,,,,,,,,,, -270500,3.1534224,1.4057378,,,,,,,,,,,,,, -270600,2.887852,1.7921593,,,,,,,,,,,,,, -270700,3.049638,2.6868253,,,,,,,,,,,,,, -270790,,,0.88929682970047,0.4126971662044525,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,121037.9982905388,131340.73326659203,121037.9982905388,10269.55788254738,18.342522621154785,0.0 -270800,2.9372454,1.1251786,,,,,,,,,,,,,, -270900,4.6906524,3.2568247,,,,,,,,,,,,,, -271000,3.7752414,3.2797909,,,,,,,,,,,,,, -271100,2.5954144,2.2344308,,,,,,,,,,,,,, -271200,2.9871652,1.9156477,,,,,,,,,,,,,, -271300,3.1486623,1.8089947,,,,,,,,,,,,,, -271400,3.3067043,1.7996458,,,,,,,,,,,,,, -271500,3.1239336,2.827198,,,,,,,,,,,,,, -271600,3.336064,2.3179734,,,,,,,,,,,,,, -271700,3.189153,1.0813532,,,,,,,,,,,,,, -271730,,,0.8877539038658142,0.4173440933227539,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,121457.97619390488,131799.77235770226,121457.97619390488,10308.489049196243,18.421640157699585,0.0 -271800,5.7783165,3.2553577,,,,,,,,,,,,,, -271900,3.6776228,1.1986777,,,,,,,,,,,,,, -272000,2.999592,1.7659827,,,,,,,,,,,,,, -272100,3.5829232,3.0045962,,,,,,,,,,,,,, -272200,3.1575925,1.1803879,,,,,,,,,,,,,, -272300,2.8806715,1.3395386,,,,,,,,,,,,,, -272400,3.0472567,1.170997,,,,,,,,,,,,,, -272500,3.6263888,3.1884534,,,,,,,,,,,,,, -272600,2.9426563,1.1762933,,,,,,,,,,,,,, -272670,,,0.88832026720047,0.4156412780284881,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,121877.95221614838,132261.28627204895,121877.95221614838,10349.891362667084,18.50593400001526,0.0 -272700,3.2165945,1.6708236,,,,,,,,,,,,,, -272800,3.3709536,1.2307812,,,,,,,,,,,,,, -272900,3.1533425,2.6875627,,,,,,,,,,,,,, -273000,2.8404744,1.0101827,,,,,,,,,,,,,, -273100,3.5962756,3.0182335,,,,,,,,,,,,,, -273200,3.141604,1.0530957,,,,,,,,,,,,,, -273300,2.9234214,1.0402621,,,,,,,,,,,,,, -273400,3.0434515,1.6005583,,,,,,,,,,,,,, -273500,3.0620391,1.9782188,,,,,,,,,,,,,, -273600,3.6330514,3.0708213,,,,,,,,,,,,,, -273609,,,0.8891796469688416,0.4100504517555237,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,122297.96140408516,132727.0404522419,122297.96140408516,10395.50066781044,18.59084153175354,0.0 -273700,2.8773274,1.0578877,,,,,,,,,,,,,, -273800,3.8969805,3.1975148,,,,,,,,,,,,,, -273900,2.8762112,1.4455405,,,,,,,,,,,,,, -274000,3.1709397,1.2052932,,,,,,,,,,,,,, -274100,3.315572,2.9630346,,,,,,,,,,,,,, -274200,3.083721,1.438987,,,,,,,,,,,,,, -274300,3.3088348,1.2061584,,,,,,,,,,,,,, -274400,4.28173,3.243524,,,,,,,,,,,,,, -274500,3.1495245,1.2080495,,,,,,,,,,,,,, -274550,,,0.8878515362739563,0.4165761172771454,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,122717.90025830267,133182.52132487297,122717.90025830267,10430.909448862076,18.673195123672485,0.0 -274600,3.3091831,1.2850213,,,,,,,,,,,,,, -274700,3.1230216,1.0691838,,,,,,,,,,,,,, -274800,2.999371,0.991018,,,,,,,,,,,,,, -274900,4.222741,1.1869441,,,,,,,,,,,,,, -275000,2.868488,1.6591,,,,,,,,,,,,,, -275100,3.3450165,2.7171009,,,,,,,,,,,,,, -275200,3.2261336,1.3267523,,,,,,,,,,,,,, -275300,3.0515776,1.4269689,,,,,,,,,,,,,, -275400,3.0943975,1.1197723,,,,,,,,,,,,,, -275488,,,0.8887304663658142,0.4122629761695862,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,123138.03974294662,133644.14214658737,123138.03974294662,10472.254363298416,18.7580795288086,0.0 -275500,3.5837216,2.6208,,,,,,,,,,,,,, -275600,3.5351512,3.0302944,,,,,,,,,,,,,, -275700,3.21395,1.283924,,,,,,,,,,,,,, -275800,3.108281,1.1541864,,,,,,,,,,,,,, -275900,3.0148268,1.155051,,,,,,,,,,,,,, -276000,3.650617,3.2268593,,,,,,,,,,,,,, -276100,3.1171434,1.0488827,,,,,,,,,,,,,, -276200,2.979128,2.2128272,,,,,,,,,,,,,, -276300,3.23075,1.1254328,,,,,,,,,,,,,, -276400,3.6890132,3.0730264,,,,,,,,,,,,,, -276426,,,0.8866796493530273,0.4174808859825134,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,123558.17305517197,134104.78236317635,123558.17305517197,10512.6342856884,18.83539652824402,0.0 -276500,3.2944124,2.7863317,,,,,,,,,,,,,, -276600,3.178505,1.1635156,,,,,,,,,,,,,, -276700,3.2957618,1.1059154,,,,,,,,,,,,,, -276800,3.2621746,1.1509922,,,,,,,,,,,,,, -276900,2.7701108,1.4764346,,,,,,,,,,,,,, -277000,2.9090858,1.4087378,,,,,,,,,,,,,, -277100,3.0125663,1.1016756,,,,,,,,,,,,,, -277200,3.2387927,2.2535706,,,,,,,,,,,,,, -277300,3.5529675,3.020443,,,,,,,,,,,,,, -277365,,,0.8886132836341858,0.4149703085422516,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,123978.34770202637,134563.87787127495,123978.34770202637,10551.421706914902,18.918168783187863,0.0 -277400,2.8339684,1.9277558,,,,,,,,,,,,,, -277500,2.8396251,1.090298,,,,,,,,,,,,,, -277600,2.939074,1.1347193,,,,,,,,,,,,,, -277700,2.9368637,1.1183186,,,,,,,,,,,,,, -277800,3.0521505,1.0769309,,,,,,,,,,,,,, -277900,3.449124,1.2646148,,,,,,,,,,,,,, -278000,3.4605222,2.9151232,,,,,,,,,,,,,, -278100,3.1174738,1.958532,,,,,,,,,,,,,, -278200,2.9488578,1.6477432,,,,,,,,,,,,,, -278300,3.1794372,1.3655463,,,,,,,,,,,,,, -278304,,,0.8903124928474426,0.4152889549732208,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,124398.55496621132,135023.04332780838,124398.55496621132,10590.24893975258,18.998623609542847,0.0 -278400,5.2723126,2.8774786,,,,,,,,,,,,,, -278500,3.21984,1.4405544,,,,,,,,,,,,,, -278600,3.8057754,3.2628932,,,,,,,,,,,,,, -278700,2.9058583,1.6585989,,,,,,,,,,,,,, -278800,3.2041175,2.7458081,,,,,,,,,,,,,, -278900,3.239682,1.2128601,,,,,,,,,,,,,, -279000,3.2874422,2.82252,,,,,,,,,,,,,, -279100,3.1109312,1.1234487,,,,,,,,,,,,,, -279200,3.3342757,1.1240282,,,,,,,,,,,,,, -279242,,,0.888671875,0.4086595475673675,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,124818.6222038269,135482.02679228783,124818.6222038269,10629.01440358162,19.098854541778564,0.0 -279300,3.1022518,2.368684,,,,,,,,,,,,,, -279400,3.628507,3.1689715,,,,,,,,,,,,,, -279500,3.0946007,1.1587504,,,,,,,,,,,,,, -279600,3.3239725,2.7237093,,,,,,,,,,,,,, -279700,2.7380648,1.808909,,,,,,,,,,,,,, -279800,3.8135743,3.3017592,,,,,,,,,,,,,, -279900,2.8903744,1.1772544,,,,,,,,,,,,,, -280000,3.0745218,1.4463658,,,,,,,,,,,,,, -280100,3.6139407,3.1895146,,,,,,,,,,,,,, -280179,,,0.8921093344688416,0.4076516628265381,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,125238.82615542412,135942.1735329628,125238.82615542412,10668.821629047394,19.18437695503235,0.0 -280200,3.1938243,2.5306115,,,,,,,,,,,,,, -280300,3.2109196,2.4705577,,,,,,,,,,,,,, -280400,2.9603744,1.496396,,,,,,,,,,,,,, -280500,3.5247996,3.1922019,,,,,,,,,,,,,, -280600,3.0236952,2.013423,,,,,,,,,,,,,, -280700,2.9853027,1.4760435,,,,,,,,,,,,,, -280800,3.1947985,2.8536825,,,,,,,,,,,,,, -280900,2.903098,1.1670358,,,,,,,,,,,,,, -281000,2.8418214,1.0331529,,,,,,,,,,,,,, -281100,2.985789,1.1275696,,,,,,,,,,,,,, -281119,,,0.8885937333106995,0.414132297039032,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,125659.04063010216,136401.85151863098,125659.04063010216,10708.152201652529,19.26614189147949,0.0 -281200,3.065264,1.1349125,,,,,,,,,,,,,, -281300,3.2203543,1.2619793,,,,,,,,,,,,,, -281400,2.9539557,2.2921934,,,,,,,,,,,,,, -281500,3.102371,2.8125062,,,,,,,,,,,,,, -281600,3.232602,1.1018772,,,,,,,,,,,,,, -281700,3.2200496,1.101844,,,,,,,,,,,,,, -281800,3.8214242,3.1478047,,,,,,,,,,,,,, -281900,3.1550543,1.153121,,,,,,,,,,,,,, -282000,3.1581852,1.121014,,,,,,,,,,,,,, -282060,,,0.8877929449081421,0.4160957634449005,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,126079.1490097046,136867.63150930405,126079.1490097046,10753.6897149086,19.349141597747803,0.0 -282100,3.6387854,3.31242,,,,,,,,,,,,,, -282200,3.298072,1.1853982,,,,,,,,,,,,,, -282300,3.1095293,1.5023111,,,,,,,,,,,,,, -282400,3.3182867,2.957664,,,,,,,,,,,,,, -282500,3.1578066,1.0493475,,,,,,,,,,,,,, -282600,3.2866924,2.8580296,,,,,,,,,,,,,, -282700,3.033034,1.0286664,,,,,,,,,,,,,, -282800,3.3185954,2.8144255,,,,,,,,,,,,,, -282900,3.0501096,1.1770536,,,,,,,,,,,,,, -282999,,,0.8861523270606995,0.418018102645874,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,126499.3281633854,137323.70598578453,126499.3281633854,10789.449345111849,19.43389129638672,0.0 -283000,3.247205,1.0583373,,,,,,,,,,,,,, -283100,3.0887413,2.5454779,,,,,,,,,,,,,, -283200,2.8468802,1.2716736,,,,,,,,,,,,,, -283300,3.2462583,2.5408676,,,,,,,,,,,,,, -283400,3.0710988,1.1577454,,,,,,,,,,,,,, -283500,4.34151,3.2674484,,,,,,,,,,,,,, -283600,3.3074443,1.0934956,,,,,,,,,,,,,, -283700,3.2572672,2.5534883,,,,,,,,,,,,,, -283800,2.9526117,1.1242985,,,,,,,,,,,,,, -283900,3.1727138,1.1349405,,,,,,,,,,,,,, -283937,,,0.8886132836341858,0.4185100197792053,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,126919.30169415474,137787.30851221085,126919.30169415474,10832.9416513443,19.51919531822205,0.0 -284000,2.9610968,1.1415155,,,,,,,,,,,,,, -284100,3.202625,2.7634637,,,,,,,,,,,,,, -284200,2.9925363,1.4208305,,,,,,,,,,,,,, -284300,5.565218,3.1840632,,,,,,,,,,,,,, -284400,3.2389407,1.2301638,,,,,,,,,,,,,, -284500,3.3358607,1.2939695,,,,,,,,,,,,,, -284600,3.2353,1.1897218,,,,,,,,,,,,,, -284700,2.8083851,1.7402893,,,,,,,,,,,,,, -284800,3.2363462,1.1277614,,,,,,,,,,,,,, -284875,,,0.8898632526397705,0.4148930013179779,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,127339.16857624054,138247.77598190308,127339.16857624054,10873.36230826378,19.64874291419983,0.0 -284900,3.1130981,1.0657381,,,,,,,,,,,,,, -285000,2.874247,1.6277382,,,,,,,,,,,,,, -285100,3.1100564,2.3223085,,,,,,,,,,,,,, -285200,3.078546,1.1463192,,,,,,,,,,,,,, -285300,3.4178944,2.8974144,,,,,,,,,,,,,, -285400,3.208068,2.740339,,,,,,,,,,,,,, -285500,3.2399485,1.2064471,,,,,,,,,,,,,, -285600,3.2666712,1.4427164,,,,,,,,,,,,,, -285700,3.0493643,1.330774,,,,,,,,,,,,,, -285800,3.0572977,2.7237496,,,,,,,,,,,,,, -285813,,,0.88734370470047,0.4140979647636413,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,127759.18264389038,138707.7092897892,127759.18264389038,10913.14732003212,19.732011795043945,0.0 -285900,2.9468079,1.0946175,,,,,,,,,,,,,, -286000,3.089469,1.174923,,,,,,,,,,,,,, -286100,3.2833908,1.1955144,,,,,,,,,,,,,, -286200,3.2976928,2.4872327,,,,,,,,,,,,,, -286300,2.837921,1.7323234,,,,,,,,,,,,,, -286400,2.7939315,1.567637,,,,,,,,,,,,,, -286500,3.9571724,1.1022882,,,,,,,,,,,,,, -286600,3.1344595,1.1373173,,,,,,,,,,,,,, -286700,2.9983985,2.1824439,,,,,,,,,,,,,, -286748,,,0.8852343559265137,0.4228600561618805,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,128179.11961340904,139171.1473546028,128179.11961340904,10956.518753528597,19.81148219108581,0.0 -286800,2.8039126,1.1019305,,,,,,,,,,,,,, -286900,3.1064484,1.176116,,,,,,,,,,,,,, -287000,3.40905,3.028326,,,,,,,,,,,,,, -287100,3.0276275,1.8220156,,,,,,,,,,,,,, -287200,3.4697976,3.0284448,,,,,,,,,,,,,, -287300,2.9831972,1.1378579,,,,,,,,,,,,,, -287400,3.5537279,3.0744798,,,,,,,,,,,,,, -287500,3.1156566,1.120508,,,,,,,,,,,,,, -287600,3.132274,1.1175029,,,,,,,,,,,,,, -287688,,,0.8890038728713989,0.4117945730686188,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,128598.82387661934,139627.9332201481,128598.82387661934,10993.224220991136,20.136163234710693,0.0 -287700,3.0820956,2.0494828,,,,,,,,,,,,,, -287800,2.7698197,1.2320294,,,,,,,,,,,,,, -287900,3.020871,1.3649769,,,,,,,,,,,,,, -288000,3.416866,2.9195092,,,,,,,,,,,,,, -288100,2.8966477,1.9181019,,,,,,,,,,,,,, -288200,3.2852974,2.8279486,,,,,,,,,,,,,, -288300,2.9713874,1.1595358,,,,,,,,,,,,,, -288400,2.9635763,1.3212043,,,,,,,,,,,,,, -288500,3.1678343,2.549502,,,,,,,,,,,,,, -288600,3.0726593,1.1495036,,,,,,,,,,,,,, -288625,,,0.8883984088897705,0.4140118360519409,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,129019.09246134758,140085.9644765854,129019.09246134758,11030.850098609924,20.22132968902588,0.0 -288700,3.2771063,1.1723188,,,,,,,,,,,,,, -288800,3.077905,1.4951591,,,,,,,,,,,,,, -288900,3.3297122,2.7271283,,,,,,,,,,,,,, -289000,3.2627504,1.1057667,,,,,,,,,,,,,, -289100,3.6043909,3.2525811,,,,,,,,,,,,,, -289200,3.0064764,1.1498387,,,,,,,,,,,,,, -289300,3.330635,1.2373403,,,,,,,,,,,,,, -289400,3.1804223,2.3456006,,,,,,,,,,,,,, -289500,2.8690066,1.6958388,,,,,,,,,,,,,, -289560,,,0.88880854845047,0.4138639867305755,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,129439.0857374668,140544.19653391838,129439.0857374668,11068.954473495483,20.305765628814697,0.0 -289600,3.3866289,2.8672767,,,,,,,,,,,,,, -289700,3.2430809,1.1595246,,,,,,,,,,,,,, -289800,3.4291725,2.4472277,,,,,,,,,,,,,, -289900,2.9727502,1.0810947,,,,,,,,,,,,,, -290000,2.9804254,2.4099677,,,,,,,,,,,,,, -290100,3.0574412,1.1732478,,,,,,,,,,,,,, -290200,3.3571353,3.0277221,,,,,,,,,,,,,, -290300,3.0013838,1.1244227,,,,,,,,,,,,,, -290400,3.0811324,1.0679563,,,,,,,,,,,,,, -290496,,,0.8863085508346558,0.4227134883403778,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,129859.1750626564,141006.6998987198,129859.1750626564,11111.227452993391,20.39596724510193,0.0 -290500,2.951291,1.1006775,,,,,,,,,,,,,, -290600,2.926182,1.3344244,,,,,,,,,,,,,, -290700,3.2263024,1.2130553,,,,,,,,,,,,,, -290800,3.0298104,1.5508125,,,,,,,,,,,,,, -290900,2.8482842,1.1384314,,,,,,,,,,,,,, -291000,3.2821946,1.0655364,,,,,,,,,,,,,, -291100,3.1308002,2.4411993,,,,,,,,,,,,,, -291200,2.9399242,1.1781948,,,,,,,,,,,,,, -291300,3.4043798,3.074119,,,,,,,,,,,,,, -291400,3.1631045,1.0827599,,,,,,,,,,,,,, -291435,,,0.8886327743530273,0.4127470552921295,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,130279.17562508585,141462.5694692135,130279.17562508585,11146.959941864014,20.481508493423465,0.0 -291500,3.1492572,1.1665807,,,,,,,,,,,,,, -291600,3.1628385,1.2280152,,,,,,,,,,,,,, -291700,3.0845106,1.1698456,,,,,,,,,,,,,, -291800,3.3915453,1.0681388,,,,,,,,,,,,,, -291900,3.691321,3.0330765,,,,,,,,,,,,,, -292000,2.8563032,1.4288877,,,,,,,,,,,,,, -292100,3.1754067,1.1173005,,,,,,,,,,,,,, -292200,3.0291157,2.387587,,,,,,,,,,,,,, -292300,3.9687102,1.594321,,,,,,,,,,,,,, -292366,,,0.8895702958106995,0.4100448489189148,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,130699.34307909012,141923.1501107216,130699.34307909012,11187.23245549202,20.571951389312744,0.0 -292400,4.2048874,2.6248739,,,,,,,,,,,,,, -292500,3.0437026,1.4150323,,,,,,,,,,,,,, -292600,3.1219394,1.1222233,,,,,,,,,,,,,, -292700,3.315447,1.7317787,,,,,,,,,,,,,, -292800,3.1109521,1.1039184,,,,,,,,,,,,,, -292900,3.11404,1.4015046,,,,,,,,,,,,,, -293000,2.9625103,0.99427557,,,,,,,,,,,,,, -293100,2.9382174,1.0375509,,,,,,,,,,,,,, -293200,2.846119,1.7124265,,,,,,,,,,,,,, -293300,3.696342,1.234245,,,,,,,,,,,,,, -293303,,,0.8876562118530273,0.4232404232025146,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,131119.25686120987,142384.14255547523,131119.25686120987,11228.173330783844,20.65904426574707,0.0 -293400,3.677713,3.17165,,,,,,,,,,,,,, -293500,3.0498507,1.5380094,,,,,,,,,,,,,, -293600,3.107314,2.0472116,,,,,,,,,,,,,, -293700,3.2018263,2.859625,,,,,,,,,,,,,, -293800,3.0593,2.8755987,,,,,,,,,,,,,, -293900,3.2207823,1.1768348,,,,,,,,,,,,,, -294000,3.5469656,2.9736207,,,,,,,,,,,,,, -294100,3.3009179,1.171873,,,,,,,,,,,,,, -294193,,,0.8883788585662842,0.4139226377010345,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,131539.37345027924,142842.01654458046,131539.37345027924,11265.79709982872,20.74518251419068,0.0 -294200,2.8908236,1.1241193,,,,,,,,,,,,,, -294300,2.7941918,1.4051607,,,,,,,,,,,,,, -294400,3.7166326,3.2901,,,,,,,,,,,,,, -294500,3.457006,1.1759739,,,,,,,,,,,,,, -294600,3.109312,1.2030007,,,,,,,,,,,,,, -294700,3.0060577,1.4858967,,,,,,,,,,,,,, -294800,3.145183,1.8969331,,,,,,,,,,,,,, -294900,2.9557087,1.5189122,,,,,,,,,,,,,, -295000,3.0845826,1.0753194,,,,,,,,,,,,,, -295100,2.9747987,2.1158972,,,,,,,,,,,,,, -295130,,,0.8861523270606995,0.4199804067611694,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,131959.2889535427,143301.71817946434,131959.2889535427,11305.44756770134,20.83014798164368,0.0 -295200,3.5195074,2.9781725,,,,,,,,,,,,,, -295300,3.04096,1.0876571,,,,,,,,,,,,,, -295400,3.2312434,1.1837435,,,,,,,,,,,,,, -295500,2.9525564,1.8302563,,,,,,,,,,,,,, -295600,3.2503145,2.6677294,,,,,,,,,,,,,, -295700,3.0410132,1.0933139,,,,,,,,,,,,,, -295800,3.3090093,1.9454817,,,,,,,,,,,,,, -295900,3.2854424,2.701559,,,,,,,,,,,,,, -296000,3.1907747,2.7282648,,,,,,,,,,,,,, -296067,,,0.8904101252555847,0.4130441844463348,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,132379.5202343464,143758.15282964706,132379.5202343464,11341.51713204384,20.91266751289368,0.0 -296100,3.154044,1.282498,,,,,,,,,,,,,, -296200,2.9395404,1.132391,,,,,,,,,,,,,, -296300,3.1375349,1.1812727,,,,,,,,,,,,,, -296400,2.8086922,1.173582,,,,,,,,,,,,,, -296500,3.1002076,1.4250678,,,,,,,,,,,,,, -296600,3.0682282,1.0842508,,,,,,,,,,,,,, -296700,3.4421608,2.9147966,,,,,,,,,,,,,, -296800,3.0965962,2.7644868,,,,,,,,,,,,,, -296900,3.267691,1.1303638,,,,,,,,,,,,,, -297000,2.7145414,1.1747644,,,,,,,,,,,,,, -297001,,,0.8873242139816284,0.4146493077278137,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,132799.5991435051,144226.03037834167,132799.5991435051,11389.175944328308,21.0025532245636,0.0 -297100,3.4716806,2.9647648,,,,,,,,,,,,,, -297200,3.7462175,3.2322261,,,,,,,,,,,,,, -297300,3.23722,1.2159767,,,,,,,,,,,,,, -297400,3.493086,3.1908813,,,,,,,,,,,,,, -297500,8.496652,1.1302671,,,,,,,,,,,,,, -297600,2.9667091,1.1084872,,,,,,,,,,,,,, -297700,2.8969057,1.4790659,,,,,,,,,,,,,, -297800,3.0328062,1.1242355,,,,,,,,,,,,,, -297900,2.9513223,2.1186535,,,,,,,,,,,,,, -297943,,,0.8876367211341858,0.4151484072208404,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,133219.8224363327,144684.09260296822,133219.8224363327,11426.878244876862,21.088128566741943,0.0 -298000,3.1383765,1.2568959,,,,,,,,,,,,,, -298100,2.9982085,1.1316826,,,,,,,,,,,,,, -298200,2.9596045,1.1131756,,,,,,,,,,,,,, -298300,2.961022,1.367309,,,,,,,,,,,,,, -298400,3.0417705,1.1169115,,,,,,,,,,,,,, -298500,2.8217382,1.5243888,,,,,,,,,,,,,, -298600,3.010213,1.1002297,,,,,,,,,,,,,, -298700,5.3263373,2.958712,,,,,,,,,,,,,, -298800,3.469248,2.8522055,,,,,,,,,,,,,, -298880,,,0.8898437023162842,0.4100130796432495,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,133640.4175248146,145141.23288106918,133640.4175248146,11463.27908229828,21.181781768798828,0.0 -298900,2.7818582,1.0313357,,,,,,,,,,,,,, -299000,3.6244934,3.227142,,,,,,,,,,,,,, -299100,3.2907026,2.719142,,,,,,,,,,,,,, -299200,3.0286858,1.0342635,,,,,,,,,,,,,, -299300,2.8783355,1.8978302,,,,,,,,,,,,,, -299400,3.0599134,2.3848941,,,,,,,,,,,,,, -299500,2.9813824,2.0837188,,,,,,,,,,,,,, -299600,3.7654903,3.2749877,,,,,,,,,,,,,, -299700,3.0524745,2.5190825,,,,,,,,,,,,,, -299800,3.5982485,3.2483578,,,,,,,,,,,,,, -299814,,,0.8866406083106995,0.4186463952064514,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,134060.4425957203,145602.93848657608,134060.4425957203,11504.816838264464,21.274759769439697,0.0 -299900,2.9540398,1.1460843,,,,,,,,,,,,,, -300000,2.9765604,1.440873,,,,,,,,,,,,,, -300100,3.830631,3.1149082,,,,,,,,,,,,,, -300200,3.1661181,2.5835826,,,,,,,,,,,,,, -300300,3.341491,1.1333743,,,,,,,,,,,,,, -300400,2.8768232,1.3838465,,,,,,,,,,,,,, -300500,2.8208373,1.1490213,,,,,,,,,,,,,, -300600,3.6735275,2.9087145,,,,,,,,,,,,,, -300700,4.152589,3.2486482,,,,,,,,,,,,,, -300752,,,0.88978511095047,0.4111753106117248,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,134480.3460702896,146061.57924103737,134480.3460702896,11543.418461561205,21.360515594482425,0.0 -300800,3.0044053,1.2864034,,,,,,,,,,,,,, -300900,3.3143244,1.1084527,,,,,,,,,,,,,, -301000,3.30933,1.5564245,,,,,,,,,,,,,, -301100,4.1012063,1.4618468,,,,,,,,,,,,,, -301200,3.896327,3.2896712,,,,,,,,,,,,,, -301300,3.0321853,2.0567107,,,,,,,,,,,,,, -301400,3.0717537,1.8170507,,,,,,,,,,,,,, -301500,3.0305336,0.9997777,,,,,,,,,,,,,, -301600,3.1508934,1.1516919,,,,,,,,,,,,,, -301689,,,0.8907030820846558,0.4115512371063232,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,134900.35364151,146524.2145433426,134900.35364151,11585.907123088837,21.44925570487976,0.0 -301700,3.2927973,1.0865145,,,,,,,,,,,,,, -301800,2.6634545,1.3478179,,,,,,,,,,,,,, -301900,3.262729,1.156314,,,,,,,,,,,,,, -302000,3.1031642,1.0551071,,,,,,,,,,,,,, -302100,3.6556346,3.1810946,,,,,,,,,,,,,, -302200,3.142016,1.1311696,,,,,,,,,,,,,, -302300,3.2230299,2.04686,,,,,,,,,,,,,, -302400,2.87881,1.1225015,,,,,,,,,,,,,, -302500,3.142936,1.3606682,,,,,,,,,,,,,, -302600,3.4613714,2.995426,,,,,,,,,,,,,, -302630,,,0.8902929425239563,0.4095830023288727,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,135320.4859445095,146978.74358296394,135320.4859445095,11620.165687322617,21.53703117370605,0.0 -302700,3.0854447,1.126167,,,,,,,,,,,,,, -302800,3.1156557,1.1010675,,,,,,,,,,,,,, -302900,3.0791898,2.5447814,,,,,,,,,,,,,, -303000,3.3609097,1.4923711,,,,,,,,,,,,,, -303100,2.8344045,1.7195873,,,,,,,,,,,,,, -303200,3.0670514,2.690125,,,,,,,,,,,,,, -303300,3.0819764,2.4662018,,,,,,,,,,,,,, -303400,3.0028603,1.2548621,,,,,,,,,,,,,, -303500,2.8795104,1.2479024,,,,,,,,,,,,,, -303569,,,0.888476550579071,0.4151207208633423,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,135740.72599720955,147442.11426973343,135740.72599720955,11663.15535378456,21.625725030899048,0.0 -303600,3.0408065,1.1322969,,,,,,,,,,,,,, -303700,2.741643,1.8085644,,,,,,,,,,,,,, -303800,2.922286,1.1579943,,,,,,,,,,,,,, -303900,2.8595102,1.2782698,,,,,,,,,,,,,, -304000,2.9603553,1.592577,,,,,,,,,,,,,, -304100,3.1819363,1.2141465,,,,,,,,,,,,,, -304200,3.0025017,2.0322587,,,,,,,,,,,,,, -304300,2.8503747,1.6173558,,,,,,,,,,,,,, -304400,2.8163047,1.7224525,,,,,,,,,,,,,, -304500,3.365849,1.1494246,,,,,,,,,,,,,, -304507,,,0.8879101276397705,0.4139824211597442,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,136160.76442551613,147901.74234127998,136160.76442551613,11702.606310367584,21.71380090713501,0.0 -304600,2.9447918,1.0828283,,,,,,,,,,,,,, -304700,3.1110168,1.9967266,,,,,,,,,,,,,, -304800,2.982196,1.0412947,,,,,,,,,,,,,, -304900,3.0170963,2.1907191,,,,,,,,,,,,,, -305000,2.9154494,1.1715717,,,,,,,,,,,,,, -305100,3.4619832,2.3740928,,,,,,,,,,,,,, -305200,2.9889817,1.0973321,,,,,,,,,,,,,, -305300,2.997322,1.1661643,,,,,,,,,,,,,, -305400,3.1192813,1.1277516,,,,,,,,,,,,,, -305445,,,0.8877343535423279,0.4162862598896026,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,136580.63121008873,148360.96184515953,136580.63121008873,11741.822509288788,21.80000257492065,0.0 -305500,3.126439,1.871459,,,,,,,,,,,,,, -305600,3.3243096,2.6997159,,,,,,,,,,,,,, -305700,3.4648578,2.8673143,,,,,,,,,,,,,, -305800,3.4284923,2.9051445,,,,,,,,,,,,,, -305900,2.8537033,2.1395183,,,,,,,,,,,,,, -306000,3.1933331,2.6014595,,,,,,,,,,,,,, -306100,3.1417031,1.1409456,,,,,,,,,,,,,, -306200,3.1305552,1.0403394,,,,,,,,,,,,,, -306300,3.2238247,1.2271566,,,,,,,,,,,,,, -306383,,,0.8881250023841858,0.4133361577987671,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,137000.6361260414,148828.68139123917,137000.6361260414,11789.387017726898,21.89959049224853,0.0 -306400,3.0353436,1.1076939,,,,,,,,,,,,,, -306500,3.1756518,1.1380533,,,,,,,,,,,,,, -306600,2.9814556,1.2004826,,,,,,,,,,,,,, -306700,2.9043534,1.7823806,,,,,,,,,,,,,, -306800,3.2821808,1.3010511,,,,,,,,,,,,,, -306900,2.9908617,1.0688719,,,,,,,,,,,,,, -307000,3.2960544,1.1295857,,,,,,,,,,,,,, -307100,3.0006363,1.1310444,,,,,,,,,,,,,, -307200,2.996808,1.9062479,,,,,,,,,,,,,, -307300,2.9864175,2.2427168,,,,,,,,,,,,,, -307323,,,0.8883593678474426,0.4201784133911133,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,137420.53449702263,149291.6964457035,137420.53449702263,11832.368502616882,21.98433923721313,0.0 -307400,3.7192354,3.1636937,,,,,,,,,,,,,, -307500,3.0814805,1.1964815,,,,,,,,,,,,,, -307600,3.2233682,1.1171588,,,,,,,,,,,,,, -307700,3.9378102,3.1685636,,,,,,,,,,,,,, -307800,3.4959908,3.372768,,,,,,,,,,,,,, -307900,4.445774,2.8415961,,,,,,,,,,,,,, -308000,3.0271633,1.8335505,,,,,,,,,,,,,, -308100,2.9856212,1.1365393,,,,,,,,,,,,,, -308200,2.7682846,1.1144735,,,,,,,,,,,,,, -308260,,,0.8879687190055847,0.4177784919738769,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,137840.7928082943,149749.99558973312,137840.7928082943,11870.268971681597,22.07439494132996,0.0 -308300,3.0317693,1.1415694,,,,,,,,,,,,,, -308400,3.0501392,1.261562,,,,,,,,,,,,,, -308500,3.0808573,1.136617,,,,,,,,,,,,,, -308600,2.900381,1.3714703,,,,,,,,,,,,,, -308700,3.1520894,1.7863598,,,,,,,,,,,,,, -308800,2.9011414,1.0681553,,,,,,,,,,,,,, -308900,3.0160453,1.1678312,,,,,,,,,,,,,, -309000,2.7864356,1.9855933,,,,,,,,,,,,,, -309100,3.177511,2.6852734,,,,,,,,,,,,,, -309197,,,0.88916015625,0.4091837108135223,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,138260.85964107513,150206.64368653297,138260.85964107513,11906.708271741869,22.166040182113647,0.0 -309200,3.2471855,2.8577476,,,,,,,,,,,,,, -309300,3.0132482,1.6767519,,,,,,,,,,,,,, -309400,3.10707,1.3827844,,,,,,,,,,,,,, -309500,3.1820295,2.7252545,,,,,,,,,,,,,, -309600,3.0915504,1.065853,,,,,,,,,,,,,, -309700,2.855375,0.99182665,,,,,,,,,,,,,, -309800,3.0774236,1.1958331,,,,,,,,,,,,,, -309900,3.612322,2.969267,,,,,,,,,,,,,, -310000,2.9948187,1.0952125,,,,,,,,,,,,,, -310100,2.9379647,1.0313632,,,,,,,,,,,,,, -310133,,,0.8869531154632568,0.4185750484466553,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,138680.788346529,150667.1916770935,138680.788346529,11947.138439178469,22.304643392562863,0.0 -310200,3.8946328,3.033556,,,,,,,,,,,,,, -310300,3.0327365,2.1186388,,,,,,,,,,,,,, -310400,3.1653976,1.1413481,,,,,,,,,,,,,, -310500,3.33718,2.8192143,,,,,,,,,,,,,, -310600,3.5275404,3.098691,,,,,,,,,,,,,, -310700,3.2285733,1.7987008,,,,,,,,,,,,,, -310800,2.9333766,1.0665729,,,,,,,,,,,,,, -310900,3.2220793,1.585407,,,,,,,,,,,,,, -311000,3.2594426,1.2518629,,,,,,,,,,,,,, -311071,,,0.8867968320846558,0.418944239616394,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,139100.7155430317,151126.57604026794,139100.7155430317,11986.4548933506,22.394673824310303,0.0 -311100,2.9993753,1.179604,,,,,,,,,,,,,, -311200,3.103632,1.1659541,,,,,,,,,,,,,, -311300,2.915002,1.2243339,,,,,,,,,,,,,, -311400,2.8185012,2.4490252,,,,,,,,,,,,,, -311500,3.180683,2.427086,,,,,,,,,,,,,, -311600,3.7011316,3.1377225,,,,,,,,,,,,,, -311700,3.3554583,2.8749776,,,,,,,,,,,,,, -311800,2.9962044,1.5268903,,,,,,,,,,,,,, -311900,2.785241,1.1810086,,,,,,,,,,,,,, -312000,3.5716164,3.0280704,,,,,,,,,,,,,, -312008,,,0.88916015625,0.411105066537857,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,139520.6550180912,151587.94029808044,139520.6550180912,12027.735392093658,22.487955570220947,0.0 -312100,3.5032113,2.9323306,,,,,,,,,,,,,, -312200,3.158889,1.238324,,,,,,,,,,,,,, -312300,2.9881508,1.1413829,,,,,,,,,,,,,, -312400,3.0484722,2.6166992,,,,,,,,,,,,,, -312500,2.842118,1.767223,,,,,,,,,,,,,, -312600,2.756458,1.0945936,,,,,,,,,,,,,, -312700,4.1314645,2.4071815,,,,,,,,,,,,,, -312800,2.9707775,1.1178944,,,,,,,,,,,,,, -312900,3.1314914,1.9604902,,,,,,,,,,,,,, -312948,,,0.8881444931030273,0.4167310893535614,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,139940.78490543363,152047.78997969627,139940.78490543363,12067.290197372437,22.60226511955261,0.0 -313000,2.8804557,2.2705998,,,,,,,,,,,,,, -313100,2.8973715,1.1728144,,,,,,,,,,,,,, -313200,3.0411806,1.1069579,,,,,,,,,,,,,, -313300,3.008951,1.0914459,,,,,,,,,,,,,, -313400,3.191565,2.457162,,,,,,,,,,,,,, -313500,2.8678331,2.057839,,,,,,,,,,,,,, -313600,3.4172792,3.1079743,,,,,,,,,,,,,, -313700,3.0246928,1.2911475,,,,,,,,,,,,,, -313800,3.592732,1.1797731,,,,,,,,,,,,,, -313887,,,0.8859961032867432,0.4235863089561462,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,140360.85139131546,152508.9900314808,140360.85139131546,12108.281812429428,22.693623065948486,0.0 -313900,3.1606617,1.0244977,,,,,,,,,,,,,, -314000,3.148792,1.128093,,,,,,,,,,,,,, -314100,3.227081,1.1803839,,,,,,,,,,,,,, -314200,3.5481334,3.0894704,,,,,,,,,,,,,, -314300,3.2693276,1.3543729,,,,,,,,,,,,,, -314400,3.2515101,1.2624129,,,,,,,,,,,,,, -314500,3.0399683,2.1677508,,,,,,,,,,,,,, -314600,3.0323887,1.8675281,,,,,,,,,,,,,, -314700,2.9322746,1.2156957,,,,,,,,,,,,,, -314800,3.019297,2.6307607,,,,,,,,,,,,,, -314824,,,0.8883398175239563,0.4152339696884155,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,140781.04164910316,152968.01148629189,140781.04164910316,12146.96865439415,22.78712558746338,0.0 -314900,2.9820075,1.1431607,,,,,,,,,,,,,, -315000,2.9112172,1.7377313,,,,,,,,,,,,,, -315100,3.27831,1.1425476,,,,,,,,,,,,,, -315200,3.1717405,1.1406372,,,,,,,,,,,,,, -315300,3.5937178,3.1743596,,,,,,,,,,,,,, -315400,2.8640893,1.4373922,,,,,,,,,,,,,, -315500,3.155702,1.2653172,,,,,,,,,,,,,, -315600,3.0610774,1.1384062,,,,,,,,,,,,,, -315700,3.0516198,1.2242492,,,,,,,,,,,,,, -315762,,,0.8909569978713989,0.4055262804031372,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,141201.20494127274,153425.1422805786,141201.20494127274,12183.743371248243,22.927883863449097,0.0 -315800,3.237093,1.2036711,,,,,,,,,,,,,, -315900,3.0804188,2.0811872,,,,,,,,,,,,,, -316000,3.1753771,1.147008,,,,,,,,,,,,,, -316100,3.0013242,1.1046448,,,,,,,,,,,,,, -316200,2.8635309,1.2601143,,,,,,,,,,,,,, -316300,2.8578308,1.3532406,,,,,,,,,,,,,, -316400,3.032765,1.2376218,,,,,,,,,,,,,, -316500,2.9641852,1.1559066,,,,,,,,,,,,,, -316600,2.929187,1.0803909,,,,,,,,,,,,,, -316696,,,0.8879296779632568,0.4203538298606872,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,141621.46349978447,153882.85493707657,141621.46349978447,12221.055872440338,23.01867938041687,0.0 -316700,3.6728427,3.1284974,,,,,,,,,,,,,, -316800,3.0046494,1.5125531,,,,,,,,,,,,,, -316900,3.570581,2.4937139,,,,,,,,,,,,,, -317000,3.1290483,2.5154338,,,,,,,,,,,,,, -317100,3.7904336,2.9892068,,,,,,,,,,,,,, -317200,3.2393506,1.1571672,,,,,,,,,,,,,, -317300,3.2339227,1.1722523,,,,,,,,,,,,,, -317400,3.2448473,1.2008609,,,,,,,,,,,,,, -317500,2.8717616,1.7342721,,,,,,,,,,,,,, -317600,3.4233768,1.2445252,,,,,,,,,,,,,, -317634,,,0.8866406083106995,0.4178559482097626,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,142041.4583003521,154344.91010689735,142041.4583003521,12262.974166870115,23.10886240005493,0.0 -317700,3.0582714,2.3908482,,,,,,,,,,,,,, -317800,3.3522573,2.7557611,,,,,,,,,,,,,, -317900,3.7322142,3.254732,,,,,,,,,,,,,, -318000,2.7637925,1.9307178,,,,,,,,,,,,,, -318100,3.4695587,2.9716187,,,,,,,,,,,,,, -318200,2.925395,2.5568728,,,,,,,,,,,,,, -318300,2.859607,1.0348886,,,,,,,,,,,,,, -318400,2.8844793,1.0774584,,,,,,,,,,,,,, -318500,2.9434867,1.1005439,,,,,,,,,,,,,, -318572,,,0.8883398175239563,0.4167503118515014,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,142461.56035637856,154801.62366724014,142461.56035637856,12299.447113752363,23.19686794281006,0.0 -318600,2.9717638,2.3942416,,,,,,,,,,,,,, -318700,2.9300861,1.8115773,,,,,,,,,,,,,, -318800,2.9947746,1.0071996,,,,,,,,,,,,,, -318900,3.6447515,3.1825306,,,,,,,,,,,,,, -319000,3.2194688,1.0929017,,,,,,,,,,,,,, -319100,3.2633455,1.3130254,,,,,,,,,,,,,, -319200,2.8547978,1.6530381,,,,,,,,,,,,,, -319300,3.1656485,1.1804488,,,,,,,,,,,,,, -319400,2.7765017,1.0430952,,,,,,,,,,,,,, -319500,3.5484116,1.7966702,,,,,,,,,,,,,, -319510,,,0.88916015625,0.4147959053516388,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,142881.77575540543,155263.3157105446,142881.77575540543,12340.786552906036,23.28324294090271,0.0 -319600,2.7671883,1.0362124,,,,,,,,,,,,,, -319700,3.1400626,2.6601815,,,,,,,,,,,,,, -319800,3.6053643,3.1975832,,,,,,,,,,,,,, -319900,2.8754148,1.2039039,,,,,,,,,,,,,, -320000,2.9325528,2.179521,,,,,,,,,,,,,, -320100,3.1461694,1.0791321,,,,,,,,,,,,,, -320200,3.1792939,2.5536256,,,,,,,,,,,,,, -320300,3.270424,1.6965295,,,,,,,,,,,,,, -320400,3.1138406,2.492313,,,,,,,,,,,,,, -320448,,,0.8883007764816284,0.4113078117370605,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,143302.01441574097,155722.77356290817,143302.01441574097,12379.863595485687,23.374863862991333,0.0 -320500,3.4365451,1.2269163,,,,,,,,,,,,,, -320600,2.9971519,1.5972953,,,,,,,,,,,,,, -320700,3.1885676,1.2378111,,,,,,,,,,,,,, -320800,3.4344912,1.2097741,,,,,,,,,,,,,, -320900,3.207947,1.10816,,,,,,,,,,,,,, -321000,3.0420809,1.120141,,,,,,,,,,,,,, -321100,3.5560503,3.1447346,,,,,,,,,,,,,, -321200,2.9566743,1.086464,,,,,,,,,,,,,, -321300,2.9375877,1.1656334,,,,,,,,,,,,,, -321387,,,0.8866991996765137,0.4174054563045501,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,143722.09818434715,156187.04256916046,143722.09818434715,12423.899383544922,23.473129272460938,0.0 -321400,3.1839502,2.7055914,,,,,,,,,,,,,, -321500,2.8357425,1.1452225,,,,,,,,,,,,,, -321600,3.2963808,1.193113,,,,,,,,,,,,,, -321700,3.0112681,1.1900847,,,,,,,,,,,,,, -321800,3.2233546,2.3277624,,,,,,,,,,,,,, -321900,3.186084,1.25277,,,,,,,,,,,,,, -322000,3.1566935,1.2178948,,,,,,,,,,,,,, -322100,3.089993,1.1717801,,,,,,,,,,,,,, -322200,2.9253597,1.7268568,,,,,,,,,,,,,, -322300,3.102256,2.1592221,,,,,,,,,,,,,, -322323,,,0.8899218440055847,0.4096273779869079,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,144142.33653116226,156647.58727145195,144142.33653116226,12464.060123443604,23.5680513381958,0.0 -322400,3.0901387,1.1834542,,,,,,,,,,,,,, -322500,3.0949907,2.5575423,,,,,,,,,,,,,, -322600,3.1485627,1.2550468,,,,,,,,,,,,,, -322700,2.73655,1.0572673,,,,,,,,,,,,,, -322800,3.1245816,1.1924102,,,,,,,,,,,,,, -322900,2.8285933,1.5376221,,,,,,,,,,,,,, -323000,2.8431704,1.3289701,,,,,,,,,,,,,, -323100,2.8870685,1.4306173,,,,,,,,,,,,,, -323200,2.9101665,1.9242358,,,,,,,,,,,,,, -323258,,,0.8880468606948853,0.4144402146339416,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,144562.4763252735,157107.26393318176,144562.4763252735,12503.45373558998,23.660505771636963,0.0 -323300,3.7753558,3.0405443,,,,,,,,,,,,,, -323400,3.4386194,2.9371595,,,,,,,,,,,,,, -323500,3.5714757,3.1902256,,,,,,,,,,,,,, -323600,2.9550843,1.1403748,,,,,,,,,,,,,, -323700,3.0427365,2.4975944,,,,,,,,,,,,,, -323800,3.5666566,1.1317433,,,,,,,,,,,,,, -323900,3.1693876,1.1667637,,,,,,,,,,,,,, -324000,3.1567872,2.1916769,,,,,,,,,,,,,, -324100,3.188423,1.2006669,,,,,,,,,,,,,, -324194,,,0.8901171684265137,0.4119070470333099,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,144982.35832333565,157564.89127373695,144982.35832333565,12541.001872062683,23.807493209838867,0.0 -324200,3.3911505,2.378033,,,,,,,,,,,,,, -324300,2.942765,2.0027835,,,,,,,,,,,,,, -324400,3.8192,3.1838617,,,,,,,,,,,,,, -324500,3.0081012,1.3083931,,,,,,,,,,,,,, -324600,3.0329306,1.9700592,,,,,,,,,,,,,, -324700,3.4993193,3.0561323,,,,,,,,,,,,,, -324800,3.166623,1.2664444,,,,,,,,,,,,,, -324900,3.1347423,1.2482786,,,,,,,,,,,,,, -325000,2.9511175,1.1349508,,,,,,,,,,,,,, -325100,3.3274307,1.2249681,,,,,,,,,,,,,, -325133,,,0.8872851133346558,0.4181750118732452,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,145402.36496806145,158026.31503486633,145402.36496806145,12582.273058652878,23.902076482772827,0.0 -325200,3.044045,1.0060693,,,,,,,,,,,,,, -325300,3.2368467,1.1404256,,,,,,,,,,,,,, -325400,3.0355592,1.8559675,,,,,,,,,,,,,, -325500,2.9095197,2.076888,,,,,,,,,,,,,, -325600,3.7791593,1.2303034,,,,,,,,,,,,,, -325700,3.5755427,3.291948,,,,,,,,,,,,,, -325800,2.7881105,1.7067286,,,,,,,,,,,,,, -325900,2.8759367,2.094452,,,,,,,,,,,,,, -326000,3.0554485,1.0782089,,,,,,,,,,,,,, -326071,,,0.8898437023162842,0.4119757115840912,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,145822.30849671364,158487.42919802666,145822.30849671364,12623.30327129364,23.99218988418579,0.0 -326100,3.3552392,1.2198938,,,,,,,,,,,,,, -326200,2.9685886,1.0720608,,,,,,,,,,,,,, -326300,2.752567,1.8597884,,,,,,,,,,,,,, -326400,2.9881864,1.0702589,,,,,,,,,,,,,, -326500,2.901127,2.4317055,,,,,,,,,,,,,, -326600,3.0201442,2.0317802,,,,,,,,,,,,,, -326700,2.9068525,1.3070506,,,,,,,,,,,,,, -326800,2.9619415,1.8885239,,,,,,,,,,,,,, -326900,3.7171834,3.1704555,,,,,,,,,,,,,, -327000,3.070829,1.6006856,,,,,,,,,,,,,, -327009,,,0.8896679282188416,0.4137422144412994,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,146242.26469278336,158943.34083485603,146242.26469278336,12659.065303087234,24.135462999343872,0.0 -327100,4.18638,1.9188242,,,,,,,,,,,,,, -327200,3.1107647,1.6466124,,,,,,,,,,,,,, -327300,3.1756365,1.1245399,,,,,,,,,,,,,, -327400,2.9704301,1.1285248,,,,,,,,,,,,,, -327500,2.9884572,1.1400961,,,,,,,,,,,,,, -327600,2.985652,1.2780199,,,,,,,,,,,,,, -327700,2.9175014,1.0227481,,,,,,,,,,,,,, -327800,3.220762,2.4798217,,,,,,,,,,,,,, -327900,3.8443928,3.264593,,,,,,,,,,,,,, -327945,,,0.8895702958106995,0.4116011261940002,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,146662.29273629189,159401.76953983307,146662.29273629189,12697.31627702713,24.234971523284912,0.0 -328000,3.0417244,1.1134149,,,,,,,,,,,,,, -328100,2.762147,1.3968973,,,,,,,,,,,,,, -328200,3.608059,3.0912683,,,,,,,,,,,,,, -328300,3.183174,1.1984774,,,,,,,,,,,,,, -328400,3.178115,3.0422935,,,,,,,,,,,,,, -328500,2.9718072,2.0404832,,,,,,,,,,,,,, -328600,3.076084,1.2775993,,,,,,,,,,,,,, -328700,3.1444821,2.7152076,,,,,,,,,,,,,, -328800,2.9599342,1.1892298,,,,,,,,,,,,,, -328883,,,0.8875585794448853,0.4162088930606842,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,147082.52355718613,159863.60181355476,147082.52355718613,12738.77527141571,24.327179431915283,0.0 -328900,2.8575943,1.3822801,,,,,,,,,,,,,, -329000,3.0702825,1.1961399,,,,,,,,,,,,,, -329100,2.936276,1.1085477,,,,,,,,,,,,,, -329200,2.9409482,1.203556,,,,,,,,,,,,,, -329300,2.893072,1.1011407,,,,,,,,,,,,,, -329400,3.114762,1.9135464,,,,,,,,,,,,,, -329500,3.175071,1.2946532,,,,,,,,,,,,,, -329600,2.924536,2.1296437,,,,,,,,,,,,,, -329700,3.1065466,1.1632041,,,,,,,,,,,,,, -329800,3.2408278,1.2043668,,,,,,,,,,,,,, -329818,,,0.8869140148162842,0.4170363843441009,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,147502.60175275803,160323.10658454895,147502.60175275803,12778.015112400057,24.46321201324463,0.0 -329900,3.2283928,1.2324123,,,,,,,,,,,,,, -330000,3.2649455,2.2530563,,,,,,,,,,,,,, -330100,3.2869043,2.9188924,,,,,,,,,,,,,, -330200,3.0441368,1.7832818,,,,,,,,,,,,,, -330300,2.8712265,1.7636372,,,,,,,,,,,,,, -330400,2.9384325,1.1106722,,,,,,,,,,,,,, -330500,3.069658,2.7123518,,,,,,,,,,,,,, -330600,3.285195,1.1035864,,,,,,,,,,,,,, -330700,2.8961208,2.190947,,,,,,,,,,,,,, -330755,,,0.8896679282188416,0.4135584831237793,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,147922.8135919571,160784.8777372837,147922.8135919571,12819.419175863266,24.56829261779785,0.0 -330800,3.3859005,1.0828375,,,,,,,,,,,,,, -330900,3.0488005,1.3813289,,,,,,,,,,,,,, -331000,3.544587,1.0892087,,,,,,,,,,,,,, -331100,2.9721575,1.0050238,,,,,,,,,,,,,, -331200,3.1594205,1.1873763,,,,,,,,,,,,,, -331300,3.1827254,1.4214995,,,,,,,,,,,,,, -331400,3.0251865,1.1392314,,,,,,,,,,,,,, -331500,3.2536914,2.5007586,,,,,,,,,,,,,, -331600,3.2113175,1.1089503,,,,,,,,,,,,,, -331694,,,0.8881250023841858,0.4180620610713959,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,148343.09826993942,161253.30429697037,148343.09826993942,12867.414836645126,24.66355276107788,0.0 -331700,3.1060655,1.3558958,,,,,,,,,,,,,, -331800,3.1873107,1.1924008,,,,,,,,,,,,,, -331900,3.1900036,1.092103,,,,,,,,,,,,,, -332000,3.1246605,1.1832011,,,,,,,,,,,,,, -332100,3.2746487,2.1401663,,,,,,,,,,,,,, -332200,3.5556452,3.0203424,,,,,,,,,,,,,, -332300,3.2259302,1.1426353,,,,,,,,,,,,,, -332400,2.8026853,1.7811342,,,,,,,,,,,,,, -332500,3.0192454,1.0322146,,,,,,,,,,,,,, -332600,3.049891,1.4163141,,,,,,,,,,,,,, -332637,,,0.8884961009025574,0.4143311083316803,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,148763.2022268772,161713.47340416908,148763.2022268772,12907.345923662186,24.7419536113739,0.0 -332700,3.1580002,1.1324639,,,,,,,,,,,,,, -332800,3.6634603,3.264416,,,,,,,,,,,,,, -332900,3.8357642,3.2738492,,,,,,,,,,,,,, -333000,3.2775652,2.6927385,,,,,,,,,,,,,, -333100,3.2455842,1.5253793,,,,,,,,,,,,,, -333200,3.348369,1.0366093,,,,,,,,,,,,,, -333300,3.7943978,1.6746889,,,,,,,,,,,,,, -333400,3.1038125,1.0742543,,,,,,,,,,,,,, -333500,3.2617855,1.572107,,,,,,,,,,,,,, -333573,,,0.8854101300239563,0.4235457181930542,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,149183.26648783684,162176.588023901,149183.26648783684,12950.252911806108,24.83434271812439,0.0 -333600,4.0495877,1.6494727,,,,,,,,,,,,,, -333700,3.2653239,2.614425,,,,,,,,,,,,,, -333800,2.8919556,1.1639102,,,,,,,,,,,,,, -333900,3.1532161,1.2375518,,,,,,,,,,,,,, -334000,3.0055208,1.0922917,,,,,,,,,,,,,, -334100,2.8277922,1.097841,,,,,,,,,,,,,, -334200,4.015199,3.1102664,,,,,,,,,,,,,, -334300,3.0674813,1.0678881,,,,,,,,,,,,,, -334400,3.456224,2.9807973,,,,,,,,,,,,,, -334500,3.1266878,1.2105173,,,,,,,,,,,,,, -334508,,,0.8883984088897705,0.4125818610191345,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,149603.49204540253,162639.50809574127,149603.49204540253,12992.80593752861,24.92528367042541,0.0 -334600,3.0090091,1.1531031,,,,,,,,,,,,,, -334700,3.016317,1.422456,,,,,,,,,,,,,, -334800,3.1805236,1.1875654,,,,,,,,,,,,,, -334900,3.1044028,1.1534728,,,,,,,,,,,,,, -335000,3.542028,3.0662076,,,,,,,,,,,,,, -335100,2.9780881,1.1248999,,,,,,,,,,,,,, -335200,3.7714832,3.0858016,,,,,,,,,,,,,, -335300,2.9615939,1.5561007,,,,,,,,,,,,,, -335400,3.059416,1.26981,,,,,,,,,,,,,, -335449,,,0.8907030820846558,0.4066536724567413,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,150023.56773352623,163101.40661740303,150023.56773352623,13034.503975629808,24.99918556213379,0.0 -335500,3.186551,1.2529156,,,,,,,,,,,,,, -335600,3.0409896,1.1539605,,,,,,,,,,,,,, -335700,2.8722725,1.9985423,,,,,,,,,,,,,, -335800,3.2348711,1.2382983,,,,,,,,,,,,,, -335900,3.2246132,2.9273465,,,,,,,,,,,,,, -336000,3.457443,1.2030346,,,,,,,,,,,,,, -336100,2.979852,1.3009472,,,,,,,,,,,,,, -336200,3.70333,3.2681484,,,,,,,,,,,,,, -336300,3.1269667,2.3545318,,,,,,,,,,,,,, -336388,,,0.8867382407188416,0.4205570816993713,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,150443.80747699738,163566.56737589836,150443.80747699738,13079.27995443344,25.09358143806457,0.0 -336400,3.0038435,1.3543249,,,,,,,,,,,,,, -336500,2.9504957,2.4569564,,,,,,,,,,,,,, -336600,3.1421838,1.175842,,,,,,,,,,,,,, -336700,3.334078,1.9771645,,,,,,,,,,,,,, -336800,3.3300815,1.1139164,,,,,,,,,,,,,, -336900,3.032791,1.3978997,,,,,,,,,,,,,, -337000,3.3138342,1.1105222,,,,,,,,,,,,,, -337100,3.133735,1.0792208,,,,,,,,,,,,,, -337200,2.8876307,2.0607166,,,,,,,,,,,,,, -337300,3.006495,1.2073083,,,,,,,,,,,,,, -337325,,,0.8875390291213989,0.4201224744319916,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,150863.85061359406,164026.75428032875,150863.85061359406,13119.275871992111,25.190476655960083,0.0 -337400,2.932783,2.2089138,,,,,,,,,,,,,, -337500,2.9845223,2.2692995,,,,,,,,,,,,,, -337600,3.1630802,2.7795022,,,,,,,,,,,,,, -337700,3.5747082,3.1884918,,,,,,,,,,,,,, -337800,3.0328858,2.2421098,,,,,,,,,,,,,, -337900,3.0511224,2.4903452,,,,,,,,,,,,,, -338000,3.200693,2.6592343,,,,,,,,,,,,,, -338100,3.1572757,2.9884315,,,,,,,,,,,,,, -338200,3.0321386,1.7290654,,,,,,,,,,,,,, -338258,,,0.8876367211341858,0.4145201444625854,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,151283.9351758957,164499.32343387604,151283.9351758957,13171.61202454567,25.28453826904297,0.0 -338300,3.8855927,3.2406507,,,,,,,,,,,,,, -338400,2.9844327,1.1700099,,,,,,,,,,,,,, -338500,3.0758512,2.185141,,,,,,,,,,,,,, -338600,3.3669024,2.719482,,,,,,,,,,,,,, -338700,3.160124,1.1918837,,,,,,,,,,,,,, -338800,2.838663,1.1607603,,,,,,,,,,,,,, -338900,3.3870902,1.0960758,,,,,,,,,,,,,, -339000,2.9812028,1.3722801,,,,,,,,,,,,,, -339100,3.4006865,3.2082884,,,,,,,,,,,,,, -339196,,,0.8885351419448853,0.4120396077632904,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,151703.95897865295,164954.1756284237,151703.95897865295,13206.307997703552,25.36584448814392,0.0 -339200,3.1332512,1.1143279,,,,,,,,,,,,,, -339300,3.426017,3.0057106,,,,,,,,,,,,,, -339400,3.596992,3.2711112,,,,,,,,,,,,,, -339500,3.3248112,1.1477078,,,,,,,,,,,,,, -339600,3.258623,1.2632974,,,,,,,,,,,,,, -339700,3.14212,1.1137673,,,,,,,,,,,,,, -339800,5.1334977,3.3313894,,,,,,,,,,,,,, -339900,3.166703,2.4212773,,,,,,,,,,,,,, -340000,3.236977,1.1549021,,,,,,,,,,,,,, -340100,3.050777,1.1065677,,,,,,,,,,,,,, -340129,,,0.8873828053474426,0.4202143251895904,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,152124.17705917358,165414.87696027756,152124.17705917358,13246.64460015297,25.46287226676941,0.0 -340200,3.1479437,2.4184344,,,,,,,,,,,,,, -340300,3.2567582,1.1989102,,,,,,,,,,,,,, -340400,2.7183163,1.0008446,,,,,,,,,,,,,, -340500,3.132923,1.1266919,,,,,,,,,,,,,, -340600,3.0053773,1.171884,,,,,,,,,,,,,, -340700,3.188058,1.2135798,,,,,,,,,,,,,, -340800,3.0507264,2.469127,,,,,,,,,,,,,, -340900,2.7765918,1.262561,,,,,,,,,,,,,, -341000,3.138897,1.6753223,,,,,,,,,,,,,, -341061,,,0.8905078172683716,0.4098507463932037,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,152544.35502910614,165878.0419712067,152544.35502910614,13289.47941160202,25.565407514572144,0.0 -341100,3.7663662,1.2050562,,,,,,,,,,,,,, -341200,2.9654007,1.5742004,,,,,,,,,,,,,, -341300,2.8973293,1.109785,,,,,,,,,,,,,, -341400,3.3273935,2.835111,,,,,,,,,,,,,, -341500,3.210407,1.2127614,,,,,,,,,,,,,, -341600,3.1080644,1.1370101,,,,,,,,,,,,,, -341700,3.0315373,1.3640432,,,,,,,,,,,,,, -341800,3.6394742,2.9025621,,,,,,,,,,,,,, -341900,3.2896688,1.0694965,,,,,,,,,,,,,, -342000,3.10907,1.1330554,,,,,,,,,,,,,, -342003,,,0.8872265219688416,0.4186387658119201,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,152964.3291182518,166331.00338292122,152964.3291182518,13322.339082717896,25.64190411567688,0.0 -342100,3.0454042,1.1655483,,,,,,,,,,,,,, -342200,3.0685406,0.9806493,,,,,,,,,,,,,, -342300,2.925127,1.081952,,,,,,,,,,,,,, -342400,3.4346142,2.6091971,,,,,,,,,,,,,, -342500,3.2473402,1.1446538,,,,,,,,,,,,,, -342600,3.4464085,2.1488857,,,,,,,,,,,,,, -342700,3.0735512,2.553809,,,,,,,,,,,,,, -342800,2.8825738,2.2863967,,,,,,,,,,,,,, -342900,2.9786954,1.1238283,,,,,,,,,,,,,, -342939,,,0.8864062428474426,0.4191470742225647,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,153384.33371186256,166790.92618012428,153384.33371186256,13362.112790107729,25.7352511882782,0.0 -343000,3.0409086,2.5001965,,,,,,,,,,,,,, -343100,3.1496532,2.3375118,,,,,,,,,,,,,, -343200,3.5498924,1.1039698,,,,,,,,,,,,,, -343300,3.0329733,1.1623217,,,,,,,,,,,,,, -343400,3.1869407,2.8094273,,,,,,,,,,,,,, -343500,3.0654454,1.1168298,,,,,,,,,,,,,, -343600,2.7272341,1.5484462,,,,,,,,,,,,,, -343700,2.9235513,1.623579,,,,,,,,,,,,,, -343800,3.113558,1.1118195,,,,,,,,,,,,,, -343876,,,0.8897656202316284,0.4129240214824676,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,153804.50244569778,167261.6859352589,153804.50244569778,13412.558959245682,25.82959008216858,0.0 -343900,2.9926622,1.2947229,,,,,,,,,,,,,, -344000,3.3720794,1.0577865,,,,,,,,,,,,,, -344100,3.828551,2.884155,,,,,,,,,,,,,, -344200,3.0718038,2.4877052,,,,,,,,,,,,,, -344300,3.1352057,1.1357484,,,,,,,,,,,,,, -344400,3.1075351,1.0745553,,,,,,,,,,,,,, -344500,3.2292361,1.1053958,,,,,,,,,,,,,, -344600,3.0294008,2.6699882,,,,,,,,,,,,,, -344700,3.203592,1.0703591,,,,,,,,,,,,,, -344800,3.5911028,3.1424062,,,,,,,,,,,,,, -344816,,,0.8880664110183716,0.4134820103645324,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,154224.41515517235,167719.3604798317,154224.41515517235,13450.191165685654,25.908400774002075,0.0 -344900,3.3104432,1.1848626,,,,,,,,,,,,,, -345000,3.6079044,3.0694823,,,,,,,,,,,,,, -345100,2.9311821,1.4491451,,,,,,,,,,,,,, -345200,3.1179168,1.1577989,,,,,,,,,,,,,, -345300,4.4030814,1.1025367,,,,,,,,,,,,,, -345400,2.9861526,2.5249166,,,,,,,,,,,,,, -345500,3.222233,1.2079407,,,,,,,,,,,,,, -345600,2.9767127,1.0100331,,,,,,,,,,,,,, -345700,2.7833014,1.8773626,,,,,,,,,,,,,, -345753,,,0.8886132836341858,0.4115453362464905,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,154644.2866268158,168187.02220201492,154644.2866268158,13497.83123230934,26.00806427001953,0.0 -345800,2.926759,1.1444578,,,,,,,,,,,,,, -345900,2.9431856,1.1806506,,,,,,,,,,,,,, -346000,3.0174985,1.2049391,,,,,,,,,,,,,, -346100,3.0049615,1.1929438,,,,,,,,,,,,,, -346200,2.8594983,1.0264707,,,,,,,,,,,,,, -346300,3.397909,2.867676,,,,,,,,,,,,,, -346400,2.8248718,1.8965905,,,,,,,,,,,,,, -346500,3.11011,1.1300213,,,,,,,,,,,,,, -346600,3.797144,3.1990263,,,,,,,,,,,,,, -346689,,,0.88720703125,0.4173647165298462,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,155064.39183330536,168644.72593283653,155064.39183330536,13535.284557580948,26.102439403533936,0.0 -346700,3.0676599,1.148918,,,,,,,,,,,,,, -346800,3.3243854,1.1936759,,,,,,,,,,,,,, -346900,3.0594153,1.1906321,,,,,,,,,,,,,, -347000,3.4623737,3.0809894,,,,,,,,,,,,,, -347100,3.8159683,3.1786315,,,,,,,,,,,,,, -347200,3.1101723,1.740292,,,,,,,,,,,,,, -347300,2.951168,1.0542269,,,,,,,,,,,,,, -347400,2.8772533,1.7676084,,,,,,,,,,,,,, -347500,2.949822,1.0981312,,,,,,,,,,,,,, -347600,2.9983654,1.0736297,,,,,,,,,,,,,, -347627,,,0.8901757597923279,0.4099457561969757,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,155484.39171671867,169109.14815998077,155484.39171671867,13579.531408786774,26.227181434631348,0.0 -347700,3.4380877,2.88734,,,,,,,,,,,,,, -347800,3.9082177,3.1991558,,,,,,,,,,,,,, -347900,3.0735645,1.098833,,,,,,,,,,,,,, -348000,3.0423784,1.2432333,,,,,,,,,,,,,, -348100,3.856328,3.3160172,,,,,,,,,,,,,, -348200,3.0109985,1.1933044,,,,,,,,,,,,,, -348300,3.3131835,1.082632,,,,,,,,,,,,,, -348400,3.1592443,1.0684761,,,,,,,,,,,,,, -348500,3.1351278,1.1859112,,,,,,,,,,,,,, -348561,,,0.8900585770606995,0.4157879054546356,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,155904.38059473038,169568.8761703968,155904.38059473038,13619.071528434752,26.375840425491333,0.0 -348600,4.0376143,1.1851735,,,,,,,,,,,,,, -348700,3.0523195,1.121685,,,,,,,,,,,,,, -348800,3.0773685,1.1503277,,,,,,,,,,,,,, -348900,3.1753142,1.1615696,,,,,,,,,,,,,, -349000,3.0728912,1.5446999,,,,,,,,,,,,,, -349100,3.1452827,1.752725,,,,,,,,,,,,,, -349200,3.5534112,1.1419408,,,,,,,,,,,,,, -349300,2.747373,1.4738464,,,,,,,,,,,,,, -349400,3.0555694,1.6398764,,,,,,,,,,,,,, -349496,,,0.8889062404632568,0.4083418846130371,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,156324.56995940208,170039.8563761711,156324.56995940208,13669.715999364853,26.47144365310669,0.0 -349500,2.9866252,2.2027493,,,,,,,,,,,,,, -349600,2.9318585,1.1196125,,,,,,,,,,,,,, -349700,3.3592613,1.106545,,,,,,,,,,,,,, -349800,3.1223989,2.7984219,,,,,,,,,,,,,, -349900,2.9284208,1.7730845,,,,,,,,,,,,,, -350000,2.9597394,1.0597941,,,,,,,,,,,,,, -350100,2.8027244,1.772966,,,,,,,,,,,,,, -350200,3.233407,2.9811666,,,,,,,,,,,,,, -350300,3.1006243,1.1359817,,,,,,,,,,,,,, -350400,3.6389139,3.3016856,,,,,,,,,,,,,, -350436,,,0.8913476467132568,0.4106284379959106,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,156744.44487810135,170497.61222934723,156744.44487810135,13707.466798067093,26.550642251968384,0.0 -350500,3.5151777,2.9521685,,,,,,,,,,,,,, -350600,3.0162723,1.0367994,,,,,,,,,,,,,, -350700,3.0923822,2.4753003,,,,,,,,,,,,,, -350800,3.8229861,1.9415256,,,,,,,,,,,,,, -350900,3.4485042,2.9846437,,,,,,,,,,,,,, -351000,3.432207,2.0593514,,,,,,,,,,,,,, -351100,3.0601432,1.2887613,,,,,,,,,,,,,, -351200,2.9963822,2.053255,,,,,,,,,,,,,, -351300,2.9034815,2.203684,,,,,,,,,,,,,, -351364,,,0.8867577910423279,0.41962930560112,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,157164.40446853638,170959.27654767036,157164.40446853638,13749.024479150772,26.64775824546814,0.0 -351400,2.8746953,1.4838879,,,,,,,,,,,,,, -351500,3.8313682,1.5439918,,,,,,,,,,,,,, -351600,3.100472,1.0445601,,,,,,,,,,,,,, -351700,3.514973,3.0699458,,,,,,,,,,,,,, -351800,2.962362,1.0018101,,,,,,,,,,,,,, -351900,3.211246,2.1008434,,,,,,,,,,,,,, -352000,3.0596008,1.8906298,,,,,,,,,,,,,, -352100,3.026951,1.5072715,,,,,,,,,,,,,, -352200,3.0055299,1.5204945,,,,,,,,,,,,,, -352296,,,0.8877148032188416,0.4151216149330139,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,157584.49403834343,171421.9588572979,157584.49403834343,13791.469810724258,26.74502730369568,0.0 -352300,2.9190197,2.4000077,,,,,,,,,,,,,, -352400,3.119732,2.7013717,,,,,,,,,,,,,, -352500,2.8184068,2.149724,,,,,,,,,,,,,, -352600,2.9490724,1.3455291,,,,,,,,,,,,,, -352700,3.1361454,1.1994138,,,,,,,,,,,,,, -352800,3.7609546,2.8391213,,,,,,,,,,,,,, -352900,3.081633,1.1693379,,,,,,,,,,,,,, -353000,3.3485835,1.0763103,,,,,,,,,,,,,, -353100,2.9350176,1.0691683,,,,,,,,,,,,,, -353200,3.4477856,1.1414367,,,,,,,,,,,,,, -353231,,,0.8877539038658142,0.4173828661441803,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,158004.5953962803,171880.16689515114,158004.5953962803,13829.42780804634,26.84268879890442,0.0 -353300,3.042919,1.2796729,,,,,,,,,,,,,, -353400,3.0325146,1.100352,,,,,,,,,,,,,, -353500,3.0430167,1.1313328,,,,,,,,,,,,,, -353600,2.8889263,1.6657302,,,,,,,,,,,,,, -353700,2.930564,1.1342467,,,,,,,,,,,,,, -353800,3.2581105,1.1377475,,,,,,,,,,,,,, -353900,2.8559563,1.2697276,,,,,,,,,,,,,, -354000,4.924134,2.677733,,,,,,,,,,,,,, -354100,3.1380508,1.1091313,,,,,,,,,,,,,, -354164,,,0.8898046612739563,0.4126203954219818,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,158424.50338053703,172341.84744262695,158424.50338053703,13871.05332994461,26.9394805431366,0.0 -354200,3.4578166,1.4940403,,,,,,,,,,,,,, -354300,3.218994,2.888914,,,,,,,,,,,,,, -354400,2.803106,1.7837831,,,,,,,,,,,,,, -354500,3.4131796,1.2365692,,,,,,,,,,,,,, -354600,3.372957,3.2009847,,,,,,,,,,,,,, -354700,3.7610672,1.12773,,,,,,,,,,,,,, -354800,3.0057063,1.7137505,,,,,,,,,,,,,, -354900,3.5929487,2.8751626,,,,,,,,,,,,,, -355000,3.1584535,1.3889408,,,,,,,,,,,,,, -355100,3.160364,1.1608175,,,,,,,,,,,,,, -355103,,,0.8887304663658142,0.4155606627464294,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,158844.56557393074,172801.1633708477,158844.56557393074,13910.14951992035,27.044692754745483,0.0 -355200,2.8773985,1.9557447,,,,,,,,,,,,,, -355300,3.6991293,2.973865,,,,,,,,,,,,,, -355400,2.9923272,1.2154927,,,,,,,,,,,,,, -355500,3.3083644,2.9583766,,,,,,,,,,,,,, -355600,3.2703834,1.1626291,,,,,,,,,,,,,, -355700,3.1695642,1.1160982,,,,,,,,,,,,,, -355800,2.975396,1.1907811,,,,,,,,,,,,,, -355900,3.2422564,2.3979769,,,,,,,,,,,,,, -356000,2.88416,1.0968617,,,,,,,,,,,,,, -356040,,,0.8875976204872131,0.4178290069103241,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,159264.79722499847,173264.57972598076,159264.79722499847,13953.177673339844,27.149667501449585,0.0 -356100,3.0213182,1.1534891,,,,,,,,,,,,,, -356200,3.092416,2.7303965,,,,,,,,,,,,,, -356300,3.0496802,1.5462005,,,,,,,,,,,,,, -356400,3.3734045,1.040016,,,,,,,,,,,,,, -356500,3.3076987,1.8693537,,,,,,,,,,,,,, -356600,2.9019136,1.2583021,,,,,,,,,,,,,, -356700,2.9243824,1.5006243,,,,,,,,,,,,,, -356800,2.9989088,1.073185,,,,,,,,,,,,,, -356900,3.4815068,1.1148648,,,,,,,,,,,,,, -356979,,,0.8841601610183716,0.422722190618515,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,159684.81939792633,173722.38829994202,159684.81939792633,13990.813126564026,27.24921751022339,0.0 -357000,3.353866,1.1960022,,,,,,,,,,,,,, -357100,3.0560708,2.402461,,,,,,,,,,,,,, -357200,3.0460799,1.1642494,,,,,,,,,,,,,, -357300,4.7060394,3.1548638,,,,,,,,,,,,,, -357400,3.013184,1.1343179,,,,,,,,,,,,,, -357500,2.8257866,1.8920561,,,,,,,,,,,,,, -357600,2.9468262,1.1918013,,,,,,,,,,,,,, -357700,3.062262,1.1341786,,,,,,,,,,,,,, -357800,2.9653897,1.1072714,,,,,,,,,,,,,, -357900,3.0642478,1.0080316,,,,,,,,,,,,,, -357916,,,0.8896288871765137,0.4120044708251953,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,160104.9553952217,174185.2740430832,160104.9553952217,14033.412657022476,27.349547863006592,0.0 -358000,3.5739646,1.1753213,,,,,,,,,,,,,, -358100,3.0754719,1.0685827,,,,,,,,,,,,,, -358200,3.955069,3.2322412,,,,,,,,,,,,,, -358300,3.7591631,3.242272,,,,,,,,,,,,,, -358400,3.0774422,2.1875215,,,,,,,,,,,,,, -358500,3.1487865,1.0786555,,,,,,,,,,,,,, -358600,2.956841,1.0415308,,,,,,,,,,,,,, -358700,3.7062635,3.1621814,,,,,,,,,,,,,, -358800,3.00686,1.1541317,,,,,,,,,,,,,, -358856,,,0.8890820145606995,0.410955011844635,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,160524.85397338867,174641.0027630329,160524.85397338867,14069.0910923481,27.44993758201599,0.0 -358900,2.7188568,1.7086948,,,,,,,,,,,,,, -359000,3.3087504,1.2038645,,,,,,,,,,,,,, -359100,3.1127386,1.0963141,,,,,,,,,,,,,, -359200,3.1797051,1.2321447,,,,,,,,,,,,,, -359300,3.6668057,2.9870281,,,,,,,,,,,,,, -359400,3.86566,3.1673956,,,,,,,,,,,,,, -359500,3.1444986,1.2677472,,,,,,,,,,,,,, -359600,3.1714678,1.1520938,,,,,,,,,,,,,, -359700,3.1691701,1.166486,,,,,,,,,,,,,, -359790,,,0.8885937333106995,0.4169111549854278,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,160944.96697402,175100.8567893505,160944.96697402,14108.684185504912,27.54713273048401,0.0 -359800,3.5109594,3.1435132,,,,,,,,,,,,,, -359900,3.0223384,1.0952903,,,,,,,,,,,,,, -360000,3.9140317,3.128881,,,,,,,,,,,,,, -360100,3.9941514,3.2647643,,,,,,,,,,,,,, -360200,3.037244,2.6598203,,,,,,,,,,,,,, -360300,3.6666741,3.250835,,,,,,,,,,,,,, -360400,2.9551084,1.0736904,,,,,,,,,,,,,, -360500,3.9061236,2.742661,,,,,,,,,,,,,, -360600,2.8865442,1.8206537,,,,,,,,,,,,,, -360700,3.4836462,1.1480335,,,,,,,,,,,,,, -360725,,,0.8864648342132568,0.421193391084671,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,161364.8423793316,175563.51901459694,161364.8423793316,14151.323654413223,27.643819570541385,0.0 -360800,3.5512853,3.2914085,,,,,,,,,,,,,, -360900,2.7889626,1.7511251,,,,,,,,,,,,,, -361000,2.976335,1.3721113,,,,,,,,,,,,,, -361100,2.7877336,1.7498453,,,,,,,,,,,,,, -361200,2.9210308,1.0337738,,,,,,,,,,,,,, -361300,2.9245641,1.1429992,,,,,,,,,,,,,, -361400,3.0558612,1.2203559,,,,,,,,,,,,,, -361500,3.059008,2.4415996,,,,,,,,,,,,,, -361600,3.5067205,3.182288,,,,,,,,,,,,,, -361665,,,0.8865820169448853,0.4151625633239746,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,161785.01636505127,176021.39610123634,161785.01636505127,14188.880197048187,27.73991370201111,0.0 -361700,3.2351155,1.3004133,,,,,,,,,,,,,, -361800,3.0861595,1.5037694,,,,,,,,,,,,,, -361900,3.146086,2.7571669,,,,,,,,,,,,,, -362000,3.0466988,2.079226,,,,,,,,,,,,,, -362100,2.9082212,1.1061318,,,,,,,,,,,,,, -362200,2.8952203,1.1807114,,,,,,,,,,,,,, -362300,3.0911133,2.110353,,,,,,,,,,,,,, -362400,2.7194273,1.684909,,,,,,,,,,,,,, -362500,3.280912,2.5318918,,,,,,,,,,,,,, -362600,2.9264596,1.1865551,,,,,,,,,,,,,, -362603,,,0.8888476490974426,0.4111815989017486,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,162205.02902841568,176486.84476661682,162205.02902841568,14234.16361951828,27.84222197532653,0.0 -362700,3.455011,3.1757228,,,,,,,,,,,,,, -362800,3.1710744,1.0511339,,,,,,,,,,,,,, -362900,3.3063834,2.4341455,,,,,,,,,,,,,, -363000,2.8971348,1.6797662,,,,,,,,,,,,,, -363100,3.2003334,1.1157092,,,,,,,,,,,,,, -363200,2.8881364,2.0421891,,,,,,,,,,,,,, -363300,2.9584546,1.1324363,,,,,,,,,,,,,, -363400,2.8389711,1.3033543,,,,,,,,,,,,,, -363500,2.7811217,1.3051207,,,,,,,,,,,,,, -363541,,,0.8886523246765137,0.4192703664302826,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,162625.00325775146,176948.18750071526,162625.00325775146,14275.383254051208,27.941187858581543,0.0 -363600,3.0064454,2.5690567,,,,,,,,,,,,,, -363700,2.8050723,1.0169374,,,,,,,,,,,,,, -363800,2.788433,1.0856652,,,,,,,,,,,,,, -363900,2.9795492,1.1303596,,,,,,,,,,,,,, -364000,3.2037137,1.9885371,,,,,,,,,,,,,, -364100,2.931331,1.1520501,,,,,,,,,,,,,, -364200,3.7134264,3.282559,,,,,,,,,,,,,, -364300,4.4726014,3.1240602,,,,,,,,,,,,,, -364400,3.3962777,2.5212593,,,,,,,,,,,,,, -364479,,,0.8897656202316284,0.4115915596485138,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,163045.1636610031,177404.25304937363,163045.1636610031,14311.135527849196,28.044018983840942,0.0 -364500,3.3308733,2.576425,,,,,,,,,,,,,, -364600,3.0432384,1.3361111,,,,,,,,,,,,,, -364700,3.2375107,2.3931487,,,,,,,,,,,,,, -364800,3.064994,1.0589998,,,,,,,,,,,,,, -364900,2.7398882,1.1399533,,,,,,,,,,,,,, -365000,3.0238006,1.084314,,,,,,,,,,,,,, -365100,3.6963482,2.6600647,,,,,,,,,,,,,, -365200,2.9899383,1.35421,,,,,,,,,,,,,, -365300,3.5811262,3.120713,,,,,,,,,,,,,, -365400,4.052741,2.9027796,,,,,,,,,,,,,, -365413,,,0.8880468606948853,0.4186495542526245,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,163465.0584104061,177870.06887936592,163465.0584104061,14356.90769314766,28.14191770553589,0.0 -365500,2.9734228,1.0762146,,,,,,,,,,,,,, -365600,3.9644818,3.1705577,,,,,,,,,,,,,, -365700,2.8571858,1.900969,,,,,,,,,,,,,, -365800,3.3638072,2.9959235,,,,,,,,,,,,,, -365900,3.0763662,1.1192274,,,,,,,,,,,,,, -366000,2.8495047,1.4149994,,,,,,,,,,,,,, -366100,3.4792175,3.2770736,,,,,,,,,,,,,, -366200,3.6515038,3.2687945,,,,,,,,,,,,,, -366300,6.198979,1.0436075,,,,,,,,,,,,,, -366350,,,0.8872460722923279,0.4188577532768249,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,163885.0548875332,178324.59226608276,163885.0548875332,14391.303559541702,28.22307014465332,0.0 -366400,2.7920892,1.3110867,,,,,,,,,,,,,, -366500,3.0415747,1.384695,,,,,,,,,,,,,, -366600,3.0403547,1.070783,,,,,,,,,,,,,, -366700,3.150354,2.6365232,,,,,,,,,,,,,, -366800,2.898515,1.1467198,,,,,,,,,,,,,, -366900,2.8406682,1.5926988,,,,,,,,,,,,,, -367000,3.1193957,2.8622665,,,,,,,,,,,,,, -367100,2.9306724,1.1134971,,,,,,,,,,,,,, -367200,4.17503,3.0580118,,,,,,,,,,,,,, -367286,,,0.8896288871765137,0.4089963436126709,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,164305.09626293182,178787.65193104744,164305.09626293182,14434.169000864027,28.325141668319706,0.0 -367300,2.9872432,2.2926667,,,,,,,,,,,,,, -367400,2.8781965,1.1229678,,,,,,,,,,,,,, -367500,3.1744893,2.2448063,,,,,,,,,,,,,, -367600,3.1671827,1.0939268,,,,,,,,,,,,,, -367700,2.9177725,2.3182359,,,,,,,,,,,,,, -367800,2.9284418,1.1052887,,,,,,,,,,,,,, -367900,3.030803,1.071515,,,,,,,,,,,,,, -368000,3.0044177,1.1089249,,,,,,,,,,,,,, -368100,3.156814,1.0468588,,,,,,,,,,,,,, -368200,2.9716883,1.811765,,,,,,,,,,,,,, -368223,,,0.8872460722923279,0.4134839475154876,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,164725.26630687714,179248.5905110836,164725.26630687714,14474.787692546844,28.424390077590942,0.0 -368300,2.9564915,1.1849602,,,,,,,,,,,,,, -368400,3.1087182,1.2118106,,,,,,,,,,,,,, -368500,3.1053119,1.208585,,,,,,,,,,,,,, -368600,3.6633217,2.7096279,,,,,,,,,,,,,, -368700,3.1045876,1.9523844,,,,,,,,,,,,,, -368800,2.880786,1.0398772,,,,,,,,,,,,,, -368900,3.2332566,1.1042588,,,,,,,,,,,,,, -369000,9.0379095,1.1699659,,,,,,,,,,,,,, -369100,3.3683608,1.2111382,,,,,,,,,,,,,, -369160,,,0.8898632526397705,0.4105189740657806,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,165145.18168735504,179710.54042887688,165145.18168735504,14516.611475229263,28.584221839904785,0.0 -369200,2.9722764,2.0496216,,,,,,,,,,,,,, -369300,3.0285013,1.2737352,,,,,,,,,,,,,, -369400,4.040761,1.1526506,,,,,,,,,,,,,, -369500,3.8273473,3.2186413,,,,,,,,,,,,,, -369600,4.1110697,1.1093345,,,,,,,,,,,,,, -369700,3.0309122,1.6854682,,,,,,,,,,,,,, -369800,3.1151226,1.0845333,,,,,,,,,,,,,, -369900,2.7137647,1.5463469,,,,,,,,,,,,,, -370000,7.025284,2.938677,,,,,,,,,,,,,, -370098,,,0.8863281011581421,0.4179736673831939,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,165565.3174443245,180172.10485219955,165565.3174443245,14557.888511419296,28.685779333114624,0.0 -370100,5.0192194,3.2165189,,,,,,,,,,,,,, -370200,3.1992054,1.0579216,,,,,,,,,,,,,, -370300,2.8718705,1.0953271,,,,,,,,,,,,,, -370400,3.0164852,1.9910624,,,,,,,,,,,,,, -370500,2.9756927,2.3445241,,,,,,,,,,,,,, -370600,3.122218,2.2002258,,,,,,,,,,,,,, -370700,2.9342864,1.8994102,,,,,,,,,,,,,, -370800,3.8257737,1.3006215,,,,,,,,,,,,,, -370900,3.0188143,1.0559999,,,,,,,,,,,,,, -371000,3.1677616,1.2100135,,,,,,,,,,,,,, -371035,,,0.889453113079071,0.4162811636924743,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,165985.38984370232,180631.1038618088,165985.38984370232,14596.663992404938,28.78573179244995,0.0 -371100,3.007982,2.357689,,,,,,,,,,,,,, -371200,3.2812784,1.9821159,,,,,,,,,,,,,, -371300,2.9704819,1.0361942,,,,,,,,,,,,,, -371400,3.1686542,1.039959,,,,,,,,,,,,,, -371500,3.0486495,1.0324789,,,,,,,,,,,,,, -371600,3.0175817,1.176732,,,,,,,,,,,,,, -371700,3.6166294,3.0271673,,,,,,,,,,,,,, -371800,3.9468632,2.7571654,,,,,,,,,,,,,, -371900,2.951537,1.1749996,,,,,,,,,,,,,, -371973,,,0.8895117044448853,0.4145427942276001,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,166405.45948576927,181091.9565012455,166405.45948576927,14637.276457309725,28.905762434005737,0.0 -372000,3.3787541,3.1354547,,,,,,,,,,,,,, -372100,3.0305402,1.2932491,,,,,,,,,,,,,, -372200,3.3185935,2.6754966,,,,,,,,,,,,,, -372300,2.853705,2.0943317,,,,,,,,,,,,,, -372400,2.9127233,2.2110212,,,,,,,,,,,,,, -372500,2.8078482,0.98534435,,,,,,,,,,,,,, -372600,2.9878385,1.8852912,,,,,,,,,,,,,, -372700,2.9043329,1.7476964,,,,,,,,,,,,,, -372800,3.9993782,3.284289,,,,,,,,,,,,,, -372900,2.9040189,1.1109028,,,,,,,,,,,,,, -372908,,,0.889941394329071,0.4062135517597198,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,166825.61236143112,181550.35774874687,166825.61236143112,14675.369314670565,29.011127471923828,0.0 -373000,3.5210938,2.8945942,,,,,,,,,,,,,, -373100,3.1194828,1.431309,,,,,,,,,,,,,, -373200,3.2207525,1.1742853,,,,,,,,,,,,,, -373300,3.6542702,3.078456,,,,,,,,,,,,,, -373400,3.2192445,1.8402301,,,,,,,,,,,,,, -373500,3.0343938,1.2672157,,,,,,,,,,,,,, -373600,3.0586576,1.2917116,,,,,,,,,,,,,, -373700,2.9856336,1.1004977,,,,,,,,,,,,,, -373800,3.0523753,1.1423674,,,,,,,,,,,,,, -373842,,,0.888476550579071,0.4190575182437897,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,167245.81822776794,182009.38965320587,167245.81822776794,14714.035625219343,29.120630979537964,0.0 -373900,2.9746728,1.0559021,,,,,,,,,,,,,, -374000,3.8286226,3.0055976,,,,,,,,,,,,,, -374100,3.2581573,1.1131412,,,,,,,,,,,,,, -374200,3.190699,1.0605649,,,,,,,,,,,,,, -374300,2.9012792,1.6462177,,,,,,,,,,,,,, -374400,2.8751302,1.0892667,,,,,,,,,,,,,, -374500,3.0926175,2.50239,,,,,,,,,,,,,, -374600,2.985763,2.155958,,,,,,,,,,,,,, -374700,2.8772216,1.2516226,,,,,,,,,,,,,, -374778,,,0.88880854845047,0.4142800569534302,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,167665.64027285576,182465.953375578,167665.64027285576,14750.559784173964,29.287435293197632,0.0 -374800,3.3862994,2.9169855,,,,,,,,,,,,,, -374900,3.4179575,1.1139779,,,,,,,,,,,,,, -375000,3.1975963,1.0846032,,,,,,,,,,,,,, -375100,3.865611,3.0650046,,,,,,,,,,,,,, -375200,3.096566,1.7315474,,,,,,,,,,,,,, -375300,3.4838166,3.2288582,,,,,,,,,,,,,, -375400,3.833565,3.177733,,,,,,,,,,,,,, -375500,3.6113052,3.1593587,,,,,,,,,,,,,, -375600,2.8739142,1.96929,,,,,,,,,,,,,, -375700,2.8395326,2.255231,,,,,,,,,,,,,, -375716,,,0.88720703125,0.4153881371021271,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,168085.7812242508,182928.67467594147,168085.7812242508,14792.984617710114,29.391901969909668,0.0 -375800,3.8349953,3.0969725,,,,,,,,,,,,,, -375900,2.9807618,2.069037,,,,,,,,,,,,,, -376000,3.353071,1.9224197,,,,,,,,,,,,,, -376100,3.2147937,1.1800203,,,,,,,,,,,,,, -376200,2.878979,2.14296,,,,,,,,,,,,,, -376300,3.059682,1.5956216,,,,,,,,,,,,,, -376400,3.4979184,1.1235507,,,,,,,,,,,,,, -376500,2.961172,2.2054164,,,,,,,,,,,,,, -376600,3.2008982,1.5778481,,,,,,,,,,,,,, -376654,,,0.8880664110183716,0.4154757261276245,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,168506.1006731987,183386.4454011917,168506.1006731987,14830.284606933594,29.49255681037903,0.0 -376700,3.431082,1.0715549,,,,,,,,,,,,,, -376800,3.4693148,2.817864,,,,,,,,,,,,,, -376900,3.0210295,1.5705359,,,,,,,,,,,,,, -377000,3.1051445,2.4508944,,,,,,,,,,,,,, -377100,3.0179267,1.1213379,,,,,,,,,,,,,, -377200,3.2208493,1.0672133,,,,,,,,,,,,,, -377300,2.727508,1.374356,,,,,,,,,,,,,, -377400,2.9988732,1.1516694,,,,,,,,,,,,,, -377500,3.608161,1.197789,,,,,,,,,,,,,, -377591,,,0.8890624642372131,0.4145684838294983,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,168926.2371172905,183844.1845271588,168926.2371172905,14867.735316514969,29.594119548797607,0.0 -377600,2.7956626,1.2766123,,,,,,,,,,,,,, -377700,3.2094443,1.0580051,,,,,,,,,,,,,, -377800,3.199565,2.158277,,,,,,,,,,,,,, -377900,2.982883,2.5037327,,,,,,,,,,,,,, -378000,4.418545,3.206043,,,,,,,,,,,,,, -378100,3.1350398,1.1293298,,,,,,,,,,,,,, -378200,2.8840237,1.5963886,,,,,,,,,,,,,, -378300,2.931831,1.4303586,,,,,,,,,,,,,, -378400,3.2063046,2.801234,,,,,,,,,,,,,, -378500,3.0084894,2.4917872,,,,,,,,,,,,,, -378529,,,0.8909765481948853,0.4109188318252563,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,169346.13319802284,184306.2586414814,169346.13319802284,14909.759913921356,29.696993350982662,0.0 -378600,3.2172737,1.7114666,,,,,,,,,,,,,, -378700,3.0440412,1.0853164,,,,,,,,,,,,,, -378800,2.9879358,1.8687578,,,,,,,,,,,,,, -378900,3.6312637,3.2105153,,,,,,,,,,,,,, -379000,2.947373,1.8838093,,,,,,,,,,,,,, -379100,2.9773111,1.1385298,,,,,,,,,,,,,, -379200,3.4659681,1.0929219,,,,,,,,,,,,,, -379300,3.02321,1.390343,,,,,,,,,,,,,, -379400,2.9447653,1.6488609,,,,,,,,,,,,,, -379467,,,0.8854491710662842,0.4222442507743835,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,169766.32176852226,184768.3322675228,169766.32176852226,14951.494970560074,29.796590089797974,0.0 -379500,3.049673,1.1856785,,,,,,,,,,,,,, -379600,3.7037733,3.1672752,,,,,,,,,,,,,, -379700,3.106124,1.1217396,,,,,,,,,,,,,, -379800,2.96457,1.3946431,,,,,,,,,,,,,, -379900,3.4519973,2.889219,,,,,,,,,,,,,, -380000,3.4085822,1.1812636,,,,,,,,,,,,,, -380100,3.0338955,1.228575,,,,,,,,,,,,,, -380200,2.8865943,2.4760919,,,,,,,,,,,,,, -380300,3.0838506,2.0775986,,,,,,,,,,,,,, -380400,3.388857,1.1822941,,,,,,,,,,,,,, -380409,,,0.8866991996765137,0.4186405539512634,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,170186.26800489426,185225.72126054764,170186.26800489426,14988.764938354492,29.91712141036988,0.0 -380500,3.3011694,1.1099229,,,,,,,,,,,,,, -380600,3.285726,1.2990956,,,,,,,,,,,,,, -380700,3.1101243,1.1298852,,,,,,,,,,,,,, -380800,4.098719,2.9801152,,,,,,,,,,,,,, -380900,2.9598083,1.1335086,,,,,,,,,,,,,, -381000,3.0352864,2.4945846,,,,,,,,,,,,,, -381100,3.135482,1.3360193,,,,,,,,,,,,,, -381200,3.339565,1.1147054,,,,,,,,,,,,,, -381300,3.1738546,1.0779119,,,,,,,,,,,,,, -381348,,,0.8898242115974426,0.4092896282672882,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,170606.38489556313,185683.616868496,170606.38489556313,15026.391194105148,30.01814985275269,0.0 -381400,3.803233,3.32486,,,,,,,,,,,,,, -381500,3.0033436,2.421039,,,,,,,,,,,,,, -381600,3.3064656,1.1135345,,,,,,,,,,,,,, -381700,3.6979785,2.9355683,,,,,,,,,,,,,, -381800,3.265396,1.2848043,,,,,,,,,,,,,, -381900,3.1298857,1.2533485,,,,,,,,,,,,,, -382000,2.949629,1.6781803,,,,,,,,,,,,,, -382100,2.8440754,1.428926,,,,,,,,,,,,,, -382200,2.722984,1.4629452,,,,,,,,,,,,,, -382287,,,0.8888476490974426,0.4133008122444153,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,171026.64174222946,186146.192868948,171026.64174222946,15068.558099269869,30.119590759277344,0.0 -382300,2.919932,2.021846,,,,,,,,,,,,,, -382400,3.2992332,1.1150972,,,,,,,,,,,,,, -382500,3.0773964,1.1940275,,,,,,,,,,,,,, -382600,3.278281,1.4004533,,,,,,,,,,,,,, -382700,3.375947,2.935559,,,,,,,,,,,,,, -382800,3.7467568,2.7634888,,,,,,,,,,,,,, -382900,2.9745708,1.5203748,,,,,,,,,,,,,, -383000,3.0072067,1.4653403,,,,,,,,,,,,,, -383100,2.9920444,2.3221412,,,,,,,,,,,,,, -383200,3.1444213,1.459603,,,,,,,,,,,,,, -383223,,,0.8874022960662842,0.4168292582035064,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,171446.73694324493,186606.4345138073,171446.73694324493,15108.549597740172,30.223318576812744,0.0 -383300,3.1288605,1.3733398,,,,,,,,,,,,,, -383400,2.834061,1.0711536,,,,,,,,,,,,,, -383500,2.876318,1.3041723,,,,,,,,,,,,,, -383600,2.691443,1.6510607,,,,,,,,,,,,,, -383700,2.890381,1.5095562,,,,,,,,,,,,,, -383800,3.1471632,2.2570953,,,,,,,,,,,,,, -383900,3.2008076,1.2968568,,,,,,,,,,,,,, -384000,3.1053233,1.1359568,,,,,,,,,,,,,, -384100,2.9120991,1.2000096,,,,,,,,,,,,,, -384161,,,0.8870702981948853,0.4222002029418945,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,171866.6207382679,187068.31393790245,171866.6207382679,15150.392226219175,30.32602262496948,0.0 -384200,2.8839138,2.3902097,,,,,,,,,,,,,, -384300,3.9282477,3.3406296,,,,,,,,,,,,,, -384400,2.9608555,2.31055,,,,,,,,,,,,,, -384500,3.150492,1.2195014,,,,,,,,,,,,,, -384600,3.6478965,3.289446,,,,,,,,,,,,,, -384700,2.9938624,2.1923325,,,,,,,,,,,,,, -384800,2.8699894,1.0559089,,,,,,,,,,,,,, -384900,3.835156,3.0277843,,,,,,,,,,,,,, -385000,3.2332454,1.0794802,,,,,,,,,,,,,, -385099,,,0.8860937356948853,0.4177397787570953,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,172286.67732286453,187524.0705358982,172286.67732286453,15185.938012599943,30.430259704589844,0.0 -385100,2.959336,2.293151,,,,,,,,,,,,,, -385200,3.0302782,2.5758617,,,,,,,,,,,,,, -385300,2.900212,1.4087473,,,,,,,,,,,,,, -385400,2.971243,1.3043431,,,,,,,,,,,,,, -385500,2.8893867,1.2351699,,,,,,,,,,,,,, -385600,2.8510182,1.4890538,,,,,,,,,,,,,, -385700,3.1227365,1.4142964,,,,,,,,,,,,,, -385800,3.8468506,1.1530402,,,,,,,,,,,,,, -385900,3.1126795,1.1132671,,,,,,,,,,,,,, -386000,2.7533572,1.6184987,,,,,,,,,,,,,, -386036,,,0.89013671875,0.4076444208621979,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,172706.67554998398,187984.7664122581,172706.67554998398,15226.47905087471,30.53153157234192,0.0 -386100,2.9248838,1.7698101,,,,,,,,,,,,,, -386200,2.8941016,1.0508492,,,,,,,,,,,,,, -386300,3.0238297,1.2135146,,,,,,,,,,,,,, -386400,3.2508574,2.5473933,,,,,,,,,,,,,, -386500,2.7169876,1.0169288,,,,,,,,,,,,,, -386600,3.7532814,3.2008228,,,,,,,,,,,,,, -386700,3.8279474,3.2789185,,,,,,,,,,,,,, -386800,3.1807919,1.0576416,,,,,,,,,,,,,, -386900,3.1403892,1.0968702,,,,,,,,,,,,,, -386974,,,0.8879687190055847,0.4217220842838287,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,173126.61281085014,188444.7423121929,173126.61281085014,15266.359908103945,30.63612031936645,0.0 -387000,3.1850822,2.1927156,,,,,,,,,,,,,, -387100,3.3104937,2.549003,,,,,,,,,,,,,, -387200,2.9001582,1.2822201,,,,,,,,,,,,,, -387300,3.3002882,2.6473398,,,,,,,,,,,,,, -387400,3.0379531,1.1898084,,,,,,,,,,,,,, -387500,3.0387373,1.8697962,,,,,,,,,,,,,, -387600,3.2442944,1.1995302,,,,,,,,,,,,,, -387700,3.164939,1.0878747,,,,,,,,,,,,,, -387800,3.2127786,2.171216,,,,,,,,,,,,,, -387900,2.8594694,2.3408046,,,,,,,,,,,,,, -387911,,,0.888671875,0.4133156836032867,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,173546.7705936432,188906.29854869845,173546.7705936432,15307.598673582075,30.744555711746216,0.0 -388000,2.9349647,1.3235589,,,,,,,,,,,,,, -388100,3.124336,2.587081,,,,,,,,,,,,,, -388200,2.8645895,1.7410907,,,,,,,,,,,,,, -388300,2.8351152,1.6589144,,,,,,,,,,,,,, -388400,3.1237376,1.1446791,,,,,,,,,,,,,, -388500,3.1182578,1.4503424,,,,,,,,,,,,,, -388600,3.2052183,1.6398644,,,,,,,,,,,,,, -388700,2.8206496,1.057297,,,,,,,,,,,,,, -388800,3.221378,1.9236866,,,,,,,,,,,,,, -388848,,,0.8880078196525574,0.4169968068599701,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,173967.05476927757,189363.2365736961,173967.05476927757,15344.09970164299,30.84595322608948,0.0 -388900,2.8472857,1.0613317,,,,,,,,,,,,,, -389000,2.8422537,2.0130422,,,,,,,,,,,,,, -389100,2.960225,1.5536729,,,,,,,,,,,,,, -389200,3.3643122,1.1620656,,,,,,,,,,,,,, -389300,2.7833192,1.0244898,,,,,,,,,,,,,, -389400,3.329107,1.1763313,,,,,,,,,,,,,, -389500,3.0051224,1.8371122,,,,,,,,,,,,,, -389600,3.092527,2.450444,,,,,,,,,,,,,, -389700,3.002294,2.37083,,,,,,,,,,,,,, -389786,,,0.8884570002555847,0.4161718785762787,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,174387.26860404015,189824.72399759293,174387.26860404015,15385.21784043312,30.950018405914307,0.0 -389800,2.82807,1.9684931,,,,,,,,,,,,,, -389900,3.551896,1.2152505,,,,,,,,,,,,,, -390000,3.1596308,1.4533429,,,,,,,,,,,,,, -390100,3.5125697,1.7716041,,,,,,,,,,,,,, -390200,2.9835072,2.2309763,,,,,,,,,,,,,, -390300,3.0992303,1.3931113,,,,,,,,,,,,,, -390400,3.095453,1.1700335,,,,,,,,,,,,,, -390500,3.9304261,3.115529,,,,,,,,,,,,,, -390600,3.3743749,2.7366552,,,,,,,,,,,,,, -390700,2.9404361,1.1009403,,,,,,,,,,,,,, -390726,,,0.8868359327316284,0.4168016612529754,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,174807.295296669,190288.2693283558,174807.295296669,15428.580172777176,31.055320739746094,0.0 -390800,3.2739959,1.0440905,,,,,,,,,,,,,, -390900,3.7631965,3.2173603,,,,,,,,,,,,,, -391000,3.983547,2.4845726,,,,,,,,,,,,,, -391100,2.9946678,1.2382989,,,,,,,,,,,,,, -391200,3.5064595,2.968039,,,,,,,,,,,,,, -391300,2.8321197,1.9605827,,,,,,,,,,,,,, -391400,3.1236262,1.1738448,,,,,,,,,,,,,, -391500,3.1887712,2.651408,,,,,,,,,,,,,, -391600,3.1649368,1.1757681,,,,,,,,,,,,,, -391665,,,0.8889062404632568,0.4064289629459381,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,175227.144854784,190743.3609380722,175227.144854784,15463.608940601349,31.21803855895996,0.0 -391700,2.9339488,1.1292965,,,,,,,,,,,,,, -391800,3.1208472,1.1036885,,,,,,,,,,,,,, -391900,2.9383523,0.9990397,,,,,,,,,,,,,, -392000,3.0056372,1.1777787,,,,,,,,,,,,,, -392100,3.3884296,2.9952831,,,,,,,,,,,,,, -392200,3.3370051,1.0610135,,,,,,,,,,,,,, -392300,3.3578575,2.2017748,,,,,,,,,,,,,, -392400,3.1433554,1.2373674,,,,,,,,,,,,,, -392500,3.0789201,2.1465707,,,,,,,,,,,,,, -392599,,,0.8891796469688416,0.4163433015346527,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,175647.2268936634,191204.0561449528,175647.2268936634,15504.066883087158,31.32296562194824,0.0 -392600,3.386552,1.1202456,,,,,,,,,,,,,, -392700,2.8944488,1.5418377,,,,,,,,,,,,,, -392800,3.1249459,1.0894595,,,,,,,,,,,,,, -392900,3.2937574,2.9767897,,,,,,,,,,,,,, -393000,3.5254138,3.1702473,,,,,,,,,,,,,, -393100,3.1364932,1.1616076,,,,,,,,,,,,,, -393200,3.984108,3.2640145,,,,,,,,,,,,,, -393300,2.8555992,1.5144335,,,,,,,,,,,,,, -393400,3.1159966,1.0186857,,,,,,,,,,,,,, -393500,3.3339393,1.671248,,,,,,,,,,,,,, -393536,,,0.8873632550239563,0.4174995720386505,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,176067.19288492203,191665.5880215168,176067.19288492203,15545.477265357971,31.42723274230957,0.0 -393600,2.921742,1.4581655,,,,,,,,,,,,,, -393700,3.5300872,3.1868515,,,,,,,,,,,,,, -393800,4.5583735,3.2358592,,,,,,,,,,,,,, -393900,3.0777884,1.1620976,,,,,,,,,,,,,, -394000,2.9946258,2.2687824,,,,,,,,,,,,,, -394100,2.963561,1.0922928,,,,,,,,,,,,,, -394200,2.8667943,1.3726686,,,,,,,,,,,,,, -394300,2.8784778,2.4532766,,,,,,,,,,,,,, -394400,3.604515,1.1091702,,,,,,,,,,,,,, -394473,,,0.8908984065055847,0.4083752334117889,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,176487.19069099426,192125.47338366508,176487.19069099426,15585.209602117538,31.53172993659973,0.0 -394500,3.7378545,3.3186016,,,,,,,,,,,,,, -394600,3.0248418,1.3403392,,,,,,,,,,,,,, -394700,2.8237166,1.7089456,,,,,,,,,,,,,, -394800,3.2130113,1.2118337,,,,,,,,,,,,,, -394900,3.2373705,1.0990542,,,,,,,,,,,,,, -395000,3.5670776,2.8148746,,,,,,,,,,,,,, -395100,3.0177727,1.5105025,,,,,,,,,,,,,, -395200,2.720481,1.5875715,,,,,,,,,,,,,, -395300,3.7741983,3.1721892,,,,,,,,,,,,,, -395400,3.0091496,1.0960163,,,,,,,,,,,,,, -395412,,,0.8875976204872131,0.4181486368179321,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,176907.33445096016,192591.4912652969,176907.33445096016,15630.927192687988,31.637397050857544,0.0 -395500,3.0682485,1.1512967,,,,,,,,,,,,,, -395600,3.1900976,1.19057,,,,,,,,,,,,,, -395700,3.0993989,1.0677137,,,,,,,,,,,,,, -395800,3.1332974,2.2069654,,,,,,,,,,,,,, -395900,3.2082179,1.9957896,,,,,,,,,,,,,, -396000,3.2178907,1.462791,,,,,,,,,,,,,, -396100,2.9406219,1.3590578,,,,,,,,,,,,,, -396200,3.2980175,2.776585,,,,,,,,,,,,,, -396300,3.7159195,3.1140041,,,,,,,,,,,,,, -396352,,,0.8899609446525574,0.4101354479789734,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,177327.337089777,193058.3018424511,177327.337089777,15677.581050395966,31.74113130569458,0.0 -396400,3.280784,2.9330235,,,,,,,,,,,,,, -396500,3.9468997,3.316946,,,,,,,,,,,,,, -396600,3.551207,2.867386,,,,,,,,,,,,,, -396700,3.0984843,1.1386817,,,,,,,,,,,,,, -396800,2.9687035,1.2093886,,,,,,,,,,,,,, -396900,3.0579677,1.0372527,,,,,,,,,,,,,, -397000,3.4986498,1.6138655,,,,,,,,,,,,,, -397100,3.5532577,3.050375,,,,,,,,,,,,,, -397200,3.186187,1.1093837,,,,,,,,,,,,,, -397294,,,0.8896484375,0.4111970663070678,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,177747.5464015007,193514.40622234344,177747.5464015007,15713.32974267006,31.83563208580017,0.0 -397300,3.404993,3.0298376,,,,,,,,,,,,,, -397400,3.7652316,3.2271094,,,,,,,,,,,,,, -397500,2.9507213,1.6072109,,,,,,,,,,,,,, -397600,3.3823938,1.1148707,,,,,,,,,,,,,, -397700,3.0589533,2.660737,,,,,,,,,,,,,, -397800,2.9614222,2.468837,,,,,,,,,,,,,, -397900,3.035042,1.1928,,,,,,,,,,,,,, -398000,3.0308402,1.1428418,,,,,,,,,,,,,, -398100,3.1048515,1.1417586,,,,,,,,,,,,,, -398200,3.2806275,2.354225,,,,,,,,,,,,,, -398230,,,0.8900781273841858,0.4136282205581665,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,178167.48953986168,193972.0816402436,178167.48953986168,15750.907168149948,31.93990921974182,0.0 -398300,3.270946,1.091231,,,,,,,,,,,,,, -398400,3.4545605,1.1090304,,,,,,,,,,,,,, -398500,3.1967306,1.1841139,,,,,,,,,,,,,, -398600,2.836116,1.3472748,,,,,,,,,,,,,, -398700,3.0485563,1.8588936,,,,,,,,,,,,,, -398800,2.9614089,2.1805243,,,,,,,,,,,,,, -398900,2.982421,1.2200629,,,,,,,,,,,,,, -399000,3.1614938,1.7257924,,,,,,,,,,,,,, -399100,3.0498362,1.9857779,,,,,,,,,,,,,, -399168,,,0.8863476514816284,0.4190600514411926,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,178587.6754131317,194437.5472960472,178587.6754131317,15796.030497074127,32.04532527923584,0.0 -399200,2.9732327,2.372272,,,,,,,,,,,,,, -399300,2.9603915,1.0654482,,,,,,,,,,,,,, -399400,3.1578357,2.6828134,,,,,,,,,,,,,, -399500,3.5042992,1.2717965,,,,,,,,,,,,,, -399600,3.781507,3.1643229,,,,,,,,,,,,,, -399700,2.9000692,1.1080116,,,,,,,,,,,,,, -399800,3.5119858,2.0236256,,,,,,,,,,,,,, -399900,3.139921,2.4974766,,,,,,,,,,,,,, -400000,2.988967,1.1049054,,,,,,,,,,,,,, -400100,3.1208067,2.6999917,,,,,,,,,,,,,, -400108,,,0.8864257335662842,0.4217060506343841,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,179007.58912825584,194890.8040919304,179007.58912825584,15829.235847234726,32.13245725631714,0.0 -400200,3.0532937,2.464457,,,,,,,,,,,,,, -400300,3.7929885,3.015017,,,,,,,,,,,,,, -400400,4.0590463,3.1763163,,,,,,,,,,,,,, -400500,3.0507667,1.5682633,,,,,,,,,,,,,, -400600,2.8019166,1.1515357,,,,,,,,,,,,,, -400700,3.2565184,2.8407626,,,,,,,,,,,,,, -400800,2.987059,1.0933735,,,,,,,,,,,,,, -400900,3.0467103,1.1222948,,,,,,,,,,,,,, -401000,3.348254,2.751139,,,,,,,,,,,,,, -401046,,,0.8907226324081421,0.4068452715873718,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,179427.76574611664,195357.3195183277,179427.76574611664,15875.423902750015,32.232825756073,0.0 -401100,3.1777658,0.9661281,,,,,,,,,,,,,, -401200,3.7591047,2.9402719,,,,,,,,,,,,,, -401300,3.958502,3.0854738,,,,,,,,,,,,,, -401400,3.026158,1.2084926,,,,,,,,,,,,,, -401500,2.7671394,1.10005,,,,,,,,,,,,,, -401600,2.6672351,1.9186686,,,,,,,,,,,,,, -401700,3.2066987,1.3687948,,,,,,,,,,,,,, -401800,2.9692814,1.1297828,,,,,,,,,,,,,, -401900,3.6705801,3.2636495,,,,,,,,,,,,,, -401986,,,0.8880859017372131,0.4193184077739715,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,179847.76591467857,195817.83309102056,179847.76591467857,15915.782636880876,32.33623695373535,0.0 -402000,2.861549,1.5802435,,,,,,,,,,,,,, -402100,3.1506064,1.4509845,,,,,,,,,,,,,, -402200,3.527094,2.8662906,,,,,,,,,,,,,, -402300,3.0972476,1.1570632,,,,,,,,,,,,,, -402400,3.669064,3.1445942,,,,,,,,,,,,,, -402500,3.1895845,2.5465674,,,,,,,,,,,,,, -402600,3.1179483,1.2126245,,,,,,,,,,,,,, -402700,3.1341054,1.5639422,,,,,,,,,,,,,, -402800,3.081559,2.179213,,,,,,,,,,,,,, -402900,3.0371418,1.1992284,,,,,,,,,,,,,, -402925,,,0.8873828053474426,0.4192410707473755,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,180267.8737359047,196276.5892522335,180267.8737359047,15954.291462898254,32.424152135849,0.0 -403000,3.2689202,2.9196262,,,,,,,,,,,,,, -403100,3.8597212,3.1792893,,,,,,,,,,,,,, -403200,2.910522,1.1363244,,,,,,,,,,,,,, -403300,3.093718,1.1935909,,,,,,,,,,,,,, -403400,3.2082658,1.6159291,,,,,,,,,,,,,, -403500,3.0840013,2.9341807,,,,,,,,,,,,,, -403600,3.0251107,1.5487965,,,,,,,,,,,,,, -403700,2.9909897,1.1705967,,,,,,,,,,,,,, -403800,3.341786,2.818441,,,,,,,,,,,,,, -403860,,,0.8870312571525574,0.4160933196544647,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,180687.81717062,196732.092625618,180687.81717062,15989.69368314743,32.5312180519104,0.0 -403900,3.2162368,1.2258028,,,,,,,,,,,,,, -404000,3.4693298,1.8179734,,,,,,,,,,,,,, -404100,3.5307138,1.0846624,,,,,,,,,,,,,, -404200,3.0283718,1.7194586,,,,,,,,,,,,,, -404300,3.169389,1.1533537,,,,,,,,,,,,,, -404400,3.6773787,3.3004534,,,,,,,,,,,,,, -404500,3.2157013,2.3050866,,,,,,,,,,,,,, -404600,2.865509,1.1396332,,,,,,,,,,,,,, -404700,2.877425,1.0644766,,,,,,,,,,,,,, -404797,,,0.8872265219688416,0.4159035086631775,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,181108.0825717449,197197.78607559204,181108.0825717449,16034.969373941422,32.63423204421997,0.0 -404800,2.9057002,1.349522,,,,,,,,,,,,,, -404900,3.0314732,1.5231423,,,,,,,,,,,,,, -405000,2.9628901,1.7427477,,,,,,,,,,,,,, -405100,3.3868623,3.096696,,,,,,,,,,,,,, -405200,3.2255912,2.7809308,,,,,,,,,,,,,, -405300,3.494249,1.2908795,,,,,,,,,,,,,, -405400,3.207444,2.8602035,,,,,,,,,,,,,, -405500,3.4948235,1.113244,,,,,,,,,,,,,, -405600,3.0882926,1.3244847,,,,,,,,,,,,,, -405700,3.0112936,1.2990226,,,,,,,,,,,,,, -405738,,,0.8909765481948853,0.4066831469535827,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,181527.94424247745,197655.74682331085,181527.94424247745,16072.9281270504,32.72395944595337,0.0 -405800,2.8722644,1.05039,,,,,,,,,,,,,, -405900,2.9970841,1.9792377,,,,,,,,,,,,,, -406000,3.000839,1.1629033,,,,,,,,,,,,,, -406100,3.5025628,1.2324318,,,,,,,,,,,,,, -406200,2.8817337,2.2070148,,,,,,,,,,,,,, -406300,3.2099428,1.186214,,,,,,,,,,,,,, -406400,3.4916978,3.038515,,,,,,,,,,,,,, -406500,3.2064424,1.1109085,,,,,,,,,,,,,, -406600,3.2115877,1.0231172,,,,,,,,,,,,,, -406678,,,0.8872851133346558,0.4161201119422912,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,181948.04838109016,198112.09012413025,181948.04838109016,16109.006716966627,32.83241391181946,0.0 -406700,2.9693184,1.5139438,,,,,,,,,,,,,, -406800,3.1111214,1.1816021,,,,,,,,,,,,,, -406900,2.9274745,2.1976073,,,,,,,,,,,,,, -407000,3.9208415,3.2048647,,,,,,,,,,,,,, -407100,3.127694,1.6066426,,,,,,,,,,,,,, -407200,2.865715,2.5411417,,,,,,,,,,,,,, -407300,3.3056214,2.6141834,,,,,,,,,,,,,, -407400,3.7515213,3.238041,,,,,,,,,,,,,, -407500,2.9741833,1.550971,,,,,,,,,,,,,, -407600,3.125741,0.9662832,,,,,,,,,,,,,, -407614,,,0.8865038752555847,0.4203358888626098,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,182368.29708647728,198582.32165956497,182368.29708647728,16158.827764034271,32.94143462181091,0.0 -407700,3.091017,2.1997495,,,,,,,,,,,,,, -407800,2.9616122,1.1235301,,,,,,,,,,,,,, -407900,3.1095638,1.5236018,,,,,,,,,,,,,, -408000,3.0275016,1.6242211,,,,,,,,,,,,,, -408100,3.1679137,1.8541033,,,,,,,,,,,,,, -408200,2.9569578,1.2671723,,,,,,,,,,,,,, -408300,3.1540499,1.0466263,,,,,,,,,,,,,, -408400,2.9754386,2.6559234,,,,,,,,,,,,,, -408500,2.985269,1.3416924,,,,,,,,,,,,,, -408554,,,0.8880078196525574,0.4149841964244842,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,182788.3319592476,199042.1483139992,182788.3319592476,16198.477226018906,33.029731035232544,0.0 -408600,3.0022058,1.7065831,,,,,,,,,,,,,, -408700,3.28489,1.0936328,,,,,,,,,,,,,, -408800,3.533341,2.4629264,,,,,,,,,,,,,, -408900,2.9451356,1.1017052,,,,,,,,,,,,,, -409000,3.2128043,2.7421732,,,,,,,,,,,,,, -409100,3.1201062,1.6414518,,,,,,,,,,,,,, -409200,3.6619043,3.046065,,,,,,,,,,,,,, -409300,2.9147303,1.1263316,,,,,,,,,,,,,, -409400,3.093983,1.5263549,,,,,,,,,,,,,, -409490,,,0.8875390291213989,0.4196300804615021,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,183208.33271503448,199497.6459169388,183208.33271503448,16233.813900232317,33.13789200782776,0.0 -409500,3.4238002,1.0942822,,,,,,,,,,,,,, -409600,4.0833273,3.0279355,,,,,,,,,,,,,, -409700,2.925082,1.326205,,,,,,,,,,,,,, -409800,2.864161,1.4356234,,,,,,,,,,,,,, -409900,3.1604187,1.1057805,,,,,,,,,,,,,, -410000,2.9543006,2.3958006,,,,,,,,,,,,,, -410100,3.5442924,2.7781181,,,,,,,,,,,,,, -410200,3.0764408,1.1391134,,,,,,,,,,,,,, -410300,2.8083,1.0989137,,,,,,,,,,,,,, -410400,3.0110009,1.2695119,,,,,,,,,,,,,, -410422,,,0.8884961009025574,0.4138842225074768,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,183628.6072833538,199964.3894040585,183628.6072833538,16280.1210501194,33.24795961380005,0.0 -410500,5.474742,3.0469887,,,,,,,,,,,,,, -410600,3.211375,1.0992835,,,,,,,,,,,,,, -410700,2.9573061,2.0745187,,,,,,,,,,,,,, -410800,3.0206661,1.2185948,,,,,,,,,,,,,, -410900,3.107601,1.3665272,,,,,,,,,,,,,, -411000,3.0810616,1.1556959,,,,,,,,,,,,,, -411100,3.23734,1.9649304,,,,,,,,,,,,,, -411200,2.7532496,1.060068,,,,,,,,,,,,,, -411300,2.8361976,1.101532,,,,,,,,,,,,,, -411364,,,0.8887109160423279,0.4179222285747528,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,184048.7924396992,200417.43037319183,184048.7924396992,16312.81937122345,33.35289120674133,0.0 -411400,3.1232078,1.1408125,,,,,,,,,,,,,, -411500,2.8973827,1.2222135,,,,,,,,,,,,,, -411600,3.0520244,1.2383186,,,,,,,,,,,,,, -411700,3.3147502,3.0111399,,,,,,,,,,,,,, -411800,2.9711518,1.0927347,,,,,,,,,,,,,, -411900,3.0209975,1.8080648,,,,,,,,,,,,,, -412000,2.8487115,1.08937,,,,,,,,,,,,,, -412100,3.1667905,1.148832,,,,,,,,,,,,,, -412200,2.7966373,1.52064,,,,,,,,,,,,,, -412300,3.6451035,2.9018886,,,,,,,,,,,,,, -412301,,,0.8895702958106995,0.4115485548973083,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,184468.7359404564,200879.38162469864,184468.7359404564,16354.667058467863,33.46069145202637,0.0 -412400,3.4131985,2.3146768,,,,,,,,,,,,,, -412500,2.9514349,1.2094386,,,,,,,,,,,,,, -412600,3.5273464,3.0436532,,,,,,,,,,,,,, -412700,3.1307836,1.2220225,,,,,,,,,,,,,, -412800,2.870165,1.2431495,,,,,,,,,,,,,, -412900,2.9372127,2.1433983,,,,,,,,,,,,,, -413000,2.8523855,1.6001807,,,,,,,,,,,,,, -413100,2.9300082,1.1348624,,,,,,,,,,,,,, -413200,2.8560677,1.5903145,,,,,,,,,,,,,, -413233,,,0.8875781297683716,0.4195011556148529,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,184888.77264356613,201346.26455664635,184888.77264356613,16401.355442762375,33.5668420791626,0.0 -413300,3.038495,1.8244789,,,,,,,,,,,,,, -413400,3.0517428,1.1635342,,,,,,,,,,,,,, -413500,3.0136561,1.1781802,,,,,,,,,,,,,, -413600,2.9545135,1.2742338,,,,,,,,,,,,,, -413700,3.113003,2.5442314,,,,,,,,,,,,,, -413800,3.1328976,1.168116,,,,,,,,,,,,,, -413900,3.081623,1.0592384,,,,,,,,,,,,,, -414000,3.271976,2.0062773,,,,,,,,,,,,,, -414100,3.0614345,1.2617638,,,,,,,,,,,,,, -414171,,,0.8879101276397705,0.4134228527545929,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,185308.6721343994,201799.42095828056,185308.6721343994,16434.469732284546,33.65598273277283,0.0 -414200,3.3505695,3.0140066,,,,,,,,,,,,,, -414300,3.4982793,1.5734924,,,,,,,,,,,,,, -414400,3.636808,3.326562,,,,,,,,,,,,,, -414500,2.812426,1.1129168,,,,,,,,,,,,,, -414600,2.8827372,1.7214642,,,,,,,,,,,,,, -414700,2.8604414,2.536761,,,,,,,,,,,,,, -414800,3.0426888,1.3869929,,,,,,,,,,,,,, -414900,3.2807965,1.0689662,,,,,,,,,,,,,, -415000,2.9226897,1.1738603,,,,,,,,,,,,,, -415100,3.1665554,1.2926708,,,,,,,,,,,,,, -415111,,,0.8889257907867432,0.4115974307060241,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,185728.9335186481,202256.4557170868,185728.9335186481,16471.08065843582,33.76547908782959,0.0 -415200,2.967595,1.2038426,,,,,,,,,,,,,, -415300,3.6155012,3.210394,,,,,,,,,,,,,, -415400,3.3542154,1.1199056,,,,,,,,,,,,,, -415500,3.442896,1.1447518,,,,,,,,,,,,,, -415600,2.7269366,1.0312643,,,,,,,,,,,,,, -415700,3.039414,1.0811483,,,,,,,,,,,,,, -415800,2.9116604,1.2402029,,,,,,,,,,,,,, -415900,3.1677067,2.717458,,,,,,,,,,,,,, -416000,2.9577286,1.42185,,,,,,,,,,,,,, -416049,,,0.888671875,0.4118475914001465,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,186149.0531988144,202727.8894138336,186149.0531988144,16522.232833862305,33.87578749656677,0.0 -416100,2.959349,1.1346344,,,,,,,,,,,,,, -416200,3.2526917,1.3214569,,,,,,,,,,,,,, -416300,3.2059672,1.2528348,,,,,,,,,,,,,, -416400,2.93813,1.2887837,,,,,,,,,,,,,, -416500,3.113716,1.1092821,,,,,,,,,,,,,, -416600,3.1808171,2.5321646,,,,,,,,,,,,,, -416700,2.93098,1.4613519,,,,,,,,,,,,,, -416800,3.034982,2.1371305,,,,,,,,,,,,,, -416900,3.0588186,2.679819,,,,,,,,,,,,,, -416992,,,0.8870507478713989,0.4178069829940796,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,186568.91537976265,203182.3723757267,186568.91537976265,16556.713079452515,33.965121269226074,0.0 -417000,3.0449915,2.5469875,,,,,,,,,,,,,, -417100,2.89039,1.0805808,,,,,,,,,,,,,, -417200,3.7148683,3.1949167,,,,,,,,,,,,,, -417300,3.0633862,1.096578,,,,,,,,,,,,,, -417400,2.9702907,2.0099454,,,,,,,,,,,,,, -417500,3.3151138,1.0518535,,,,,,,,,,,,,, -417600,3.3229387,1.0878974,,,,,,,,,,,,,, -417700,3.0400007,2.8077,,,,,,,,,,,,,, -417800,2.9027517,1.8563579,,,,,,,,,,,,,, -417900,3.081225,1.4552493,,,,,,,,,,,,,, -417927,,,0.8895117044448853,0.4142775535583496,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,186987.864846468,203645.1201398373,186987.864846468,16599.438136577606,34.98693084716797,0.0 -418000,2.849013,1.1068226,,,,,,,,,,,,,, -418100,3.31347,1.1003301,,,,,,,,,,,,,, -418200,3.419466,3.0336647,,,,,,,,,,,,,, -418300,3.177878,1.1305703,,,,,,,,,,,,,, -418400,3.0744755,1.058879,,,,,,,,,,,,,, -418500,2.9959445,1.1250633,,,,,,,,,,,,,, -418600,2.9553566,1.2028017,,,,,,,,,,,,,, -418700,3.1107361,1.0538272,,,,,,,,,,,,,, -418800,3.1270351,1.1479208,,,,,,,,,,,,,, -418865,,,0.8903124928474426,0.4141020476818084,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,187407.9261534214,204104.56997919083,187407.9261534214,16638.66318511963,35.0987401008606,0.0 -418900,3.2095287,1.0618268,,,,,,,,,,,,,, -419000,2.7444913,1.6913306,,,,,,,,,,,,,, -419100,3.1303072,1.1120995,,,,,,,,,,,,,, -419200,3.3951504,1.1668739,,,,,,,,,,,,,, -419300,3.097132,1.0741973,,,,,,,,,,,,,, -419400,3.2364502,1.1783662,,,,,,,,,,,,,, -419500,3.2561607,1.6108093,,,,,,,,,,,,,, -419600,3.4987657,3.2337217,,,,,,,,,,,,,, -419700,3.0333185,2.2964964,,,,,,,,,,,,,, -419800,2.970919,1.1611336,,,,,,,,,,,,,, -419804,,,0.8893945217132568,0.4078099727630615,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,187828.1980266571,204569.553198576,187828.1980266571,16683.20975136757,35.21203064918518,0.0 -419900,3.1061022,1.0352362,,,,,,,,,,,,,, -420000,2.9639616,1.0167359,,,,,,,,,,,,,, -420100,3.165948,1.2037786,,,,,,,,,,,,,, -420200,3.3223717,1.650023,,,,,,,,,,,,,, -420300,2.8452642,1.9548719,,,,,,,,,,,,,, -420400,3.4267685,1.1679802,,,,,,,,,,,,,, -420500,2.9037774,1.568871,,,,,,,,,,,,,, -420600,4.4364886,3.2404115,,,,,,,,,,,,,, -420700,3.0510547,1.4885412,,,,,,,,,,,,,, -420742,,,0.8888280987739563,0.4157975018024444,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,188248.1066851616,205028.44544243813,188248.1066851616,16722.040591955185,35.3136100769043,0.0 -420800,2.8739896,1.7488978,,,,,,,,,,,,,, -420900,3.2296019,1.0667086,,,,,,,,,,,,,, -421000,3.3636227,1.1757207,,,,,,,,,,,,,, -421100,4.6923037,3.296658,,,,,,,,,,,,,, -421200,3.1975136,1.1588331,,,,,,,,,,,,,, -421300,3.251593,1.1223624,,,,,,,,,,,,,, -421400,3.01903,1.0676315,,,,,,,,,,,,,, -421500,2.723699,1.1255698,,,,,,,,,,,,,, -421600,2.9429207,1.0621611,,,,,,,,,,,,,, -421679,,,0.8882812261581421,0.4160736501216888,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,188668.3634223938,205486.2354171276,188668.3634223938,16759.408123254776,35.428178787231445,0.0 -421700,2.9916568,1.8162276,,,,,,,,,,,,,, -421800,3.5617023,2.6615844,,,,,,,,,,,,,, -421900,2.8934617,1.1057239,,,,,,,,,,,,,, -422000,3.090201,1.1714889,,,,,,,,,,,,,, -422100,3.7558067,2.1155527,,,,,,,,,,,,,, -422200,4.961368,3.141651,,,,,,,,,,,,,, -422300,2.9482105,1.1089058,,,,,,,,,,,,,, -422400,3.1721823,1.2013466,,,,,,,,,,,,,, -422500,2.9661655,1.8338535,,,,,,,,,,,,,, -422600,3.146806,2.5131938,,,,,,,,,,,,,, -422616,,,0.8886913657188416,0.4118205606937408,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,189088.33290719983,205949.0699858665,189088.33290719983,16802.113570451736,35.536232471466064,0.0 -422700,2.7339864,1.9624041,,,,,,,,,,,,,, -422800,3.5293257,2.6505888,,,,,,,,,,,,,, -422900,4.457302,1.2713643,,,,,,,,,,,,,, -423000,3.0396888,1.1994236,,,,,,,,,,,,,, -423100,3.6254506,1.1692151,,,,,,,,,,,,,, -423200,2.9260879,1.069175,,,,,,,,,,,,,, -423300,3.5618396,1.1549875,,,,,,,,,,,,,, -423400,2.8914945,1.1614089,,,,,,,,,,,,,, -423500,3.1430004,2.0670469,,,,,,,,,,,,,, -423554,,,0.8875585794448853,0.4185675084590912,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,189508.17068457603,206406.62183642387,189508.17068457603,16839.58140516281,35.73130702972412,0.0 -423600,3.0321982,1.1814383,,,,,,,,,,,,,, -423700,3.1603742,1.6430103,,,,,,,,,,,,,, -423800,3.0205545,1.3049722,,,,,,,,,,,,,, -423900,2.9615378,1.7878416,,,,,,,,,,,,,, -424000,3.7892299,1.6820524,,,,,,,,,,,,,, -424100,3.1317372,2.9496572,,,,,,,,,,,,,, -424200,2.9581175,1.667675,,,,,,,,,,,,,, -424300,3.1053433,1.1147403,,,,,,,,,,,,,, -424400,4.023704,2.974831,,,,,,,,,,,,,, -424490,,,0.8890429735183716,0.4132682383060455,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,189928.0055382252,206866.4358620644,189928.0055382252,16879.394988775253,35.84501767158508,0.0 -424500,2.995891,1.033731,,,,,,,,,,,,,, -424600,2.8296418,1.5070118,,,,,,,,,,,,,, -424700,2.815209,1.1591505,,,,,,,,,,,,,, -424800,2.9358485,1.1151339,,,,,,,,,,,,,, -424900,3.020254,1.2091602,,,,,,,,,,,,,, -425000,3.2167437,2.7843385,,,,,,,,,,,,,, -425100,3.2653003,1.2763135,,,,,,,,,,,,,, -425200,3.262702,1.0541407,,,,,,,,,,,,,, -425300,2.8065567,1.2340115,,,,,,,,,,,,,, -425400,2.95952,1.0059386,,,,,,,,,,,,,, -425429,,,0.88734370470047,0.4196708500385284,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,190347.9941241741,207327.2237203121,190347.9941241741,16920.03908252716,35.94810652732849,0.0 -425500,3.0398684,1.1553985,,,,,,,,,,,,,, -425600,3.0820951,1.059975,,,,,,,,,,,,,, -425700,2.9994874,1.1759491,,,,,,,,,,,,,, -425800,3.2353175,1.0398688,,,,,,,,,,,,,, -425900,3.0757706,1.0854846,,,,,,,,,,,,,, -426000,2.9483557,1.1750045,,,,,,,,,,,,,, -426100,2.9572744,2.0525076,,,,,,,,,,,,,, -426200,2.9507966,1.167001,,,,,,,,,,,,,, -426300,2.893777,1.3663833,,,,,,,,,,,,,, -426367,,,0.8880664110183716,0.4123572409152984,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,190768.1888158321,207788.7522785664,190768.1888158321,16961.210329294205,36.057634353637695,0.0 -426400,2.8617325,1.0852649,,,,,,,,,,,,,, -426500,3.0628595,1.7043638,,,,,,,,,,,,,, -426600,3.4418104,1.6667151,,,,,,,,,,,,,, -426700,2.9836686,1.2219138,,,,,,,,,,,,,, -426800,2.9334228,1.1766233,,,,,,,,,,,,,, -426900,3.071942,1.4269431,,,,,,,,,,,,,, -427000,3.1206298,1.0121709,,,,,,,,,,,,,, -427100,2.8575954,1.15188,,,,,,,,,,,,,, -427200,3.230208,2.5228128,,,,,,,,,,,,,, -427300,3.048654,1.0937052,,,,,,,,,,,,,, -427307,,,0.8867382407188416,0.4196271896362304,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,191188.41430687904,208250.51259469983,191188.41430687904,17002.576770305634,36.17466354370117,0.0 -427400,2.9915257,1.1484966,,,,,,,,,,,,,, -427500,3.3655145,1.499357,,,,,,,,,,,,,, -427600,2.9578152,2.231663,,,,,,,,,,,,,, -427700,2.94974,1.6135364,,,,,,,,,,,,,, -427800,3.110464,1.0272399,,,,,,,,,,,,,, -427900,2.9806254,1.5259511,,,,,,,,,,,,,, -428000,3.0503147,1.1073546,,,,,,,,,,,,,, -428100,2.8746524,1.109797,,,,,,,,,,,,,, -428200,3.5840356,3.2238147,,,,,,,,,,,,,, -428245,,,0.8895898461341858,0.4129971861839294,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,191608.5195260048,208710.7665054798,191608.5195260048,17042.562101125717,36.28608512878418,0.0 -428300,3.0860074,2.1723003,,,,,,,,,,,,,, -428400,3.105287,1.1641204,,,,,,,,,,,,,, -428500,3.1024928,1.1620804,,,,,,,,,,,,,, -428600,3.5827503,2.8377361,,,,,,,,,,,,,, -428700,3.1538062,2.6269512,,,,,,,,,,,,,, -428800,3.1050153,1.1745594,,,,,,,,,,,,,, -428900,3.3234017,3.198673,,,,,,,,,,,,,, -429000,3.63983,3.1229668,,,,,,,,,,,,,, -429100,3.1203249,1.2169123,,,,,,,,,,,,,, -429185,,,0.8884375095367432,0.4158594012260437,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,192028.75000452995,209166.30269360545,192028.75000452995,17077.692586660385,36.40941309928894,0.0 -429200,2.9071133,1.0951817,,,,,,,,,,,,,, -429300,2.8639708,1.7191219,,,,,,,,,,,,,, -429400,2.940025,1.0981152,,,,,,,,,,,,,, -429500,3.0690398,1.0696924,,,,,,,,,,,,,, -429600,3.2116024,2.9350364,,,,,,,,,,,,,, -429700,3.1598287,1.1559733,,,,,,,,,,,,,, -429800,3.3343256,1.482102,,,,,,,,,,,,,, -429900,2.900065,1.1930526,,,,,,,,,,,,,, -430000,3.6388698,3.0505984,,,,,,,,,,,,,, -430100,3.2277162,1.1878475,,,,,,,,,,,,,, -430124,,,0.8883788585662842,0.4134212136268616,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,192449.02923035625,209624.86515569687,192449.02923035625,17115.80970931053,36.52241349220276,0.0 -430200,3.3573968,3.094422,,,,,,,,,,,,,, -430300,3.0958586,2.7875183,,,,,,,,,,,,,, -430400,2.855194,1.0074446,,,,,,,,,,,,,, -430500,3.9148183,3.1827056,,,,,,,,,,,,,, -430600,3.0987709,1.6735228,,,,,,,,,,,,,, -430700,3.1045845,1.2190362,,,,,,,,,,,,,, -430800,3.1131148,1.2316455,,,,,,,,,,,,,, -430900,4.204878,3.266055,,,,,,,,,,,,,, -431000,3.160675,1.901921,,,,,,,,,,,,,, -431061,,,0.8847460746765137,0.4247710406780243,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,192869.14423632625,210088.6588923931,192869.14423632625,17159.325850486755,36.63349270820618,0.0 -431100,2.9178767,1.5352371,,,,,,,,,,,,,, -431200,3.8050737,3.2200656,,,,,,,,,,,,,, -431300,3.4839373,2.813351,,,,,,,,,,,,,, -431400,2.919619,1.8692579,,,,,,,,,,,,,, -431500,3.0311878,1.138014,,,,,,,,,,,,,, -431600,3.0899081,2.1670122,,,,,,,,,,,,,, -431700,3.0831866,1.085499,,,,,,,,,,,,,, -431800,3.292126,2.1764112,,,,,,,,,,,,,, -431900,2.9250402,0.9996413,,,,,,,,,,,,,, -431998,,,0.8877539038658142,0.4151757359504699,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,193288.98817515373,210545.91840672493,193288.98817515373,17196.563505887985,36.75905227661133,0.0 -432000,3.2551656,2.7605095,,,,,,,,,,,,,, -432100,2.962557,1.5678227,,,,,,,,,,,,,, -432200,3.4687707,3.1630833,,,,,,,,,,,,,, -432300,3.2383225,1.089876,,,,,,,,,,,,,, -432400,4.0302033,3.187078,,,,,,,,,,,,,, -432500,2.9566357,1.0240365,,,,,,,,,,,,,, -432600,2.810714,1.5974708,,,,,,,,,,,,,, -432700,2.9236386,2.4958766,,,,,,,,,,,,,, -432800,2.9712532,1.118051,,,,,,,,,,,,,, -432900,3.1419215,1.2265407,,,,,,,,,,,,,, -432936,,,0.8878124952316284,0.4115523099899292,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,193709.06731200207,211009.1180469989,193709.06731200207,17239.514502763748,36.87295937538147,0.0 -433000,2.9299316,1.1535195,,,,,,,,,,,,,, -433100,3.617191,1.1739099,,,,,,,,,,,,,, -433200,3.2699459,2.6089206,,,,,,,,,,,,,, -433300,4.0054474,3.178265,,,,,,,,,,,,,, -433400,2.9339182,1.1285341,,,,,,,,,,,,,, -433500,3.7790875,3.0759003,,,,,,,,,,,,,, -433600,3.3570013,2.22957,,,,,,,,,,,,,, -433700,3.0039666,1.1046234,,,,,,,,,,,,,, -433800,3.1345062,2.2012668,,,,,,,,,,,,,, -433872,,,0.8901171684265137,0.4164775013923645,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,194129.1805028916,211468.52519202232,194129.1805028916,17278.63621211052,36.99358677864075,0.0 -433900,3.0834358,1.2361012,,,,,,,,,,,,,, -434000,3.0584817,1.6525694,,,,,,,,,,,,,, -434100,4.1326704,3.2941232,,,,,,,,,,,,,, -434200,3.1799314,1.1989938,,,,,,,,,,,,,, -434300,3.3379662,2.7999253,,,,,,,,,,,,,, -434400,3.3780346,2.725847,,,,,,,,,,,,,, -434500,4.0306263,3.1783457,,,,,,,,,,,,,, -434600,2.9783025,1.091744,,,,,,,,,,,,,, -434700,2.943174,1.230193,,,,,,,,,,,,,, -434800,3.4711127,3.052094,,,,,,,,,,,,,, -434810,,,0.887988269329071,0.4155340194702148,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,194549.187220335,211929.5177807808,194549.187220335,17319.42714571953,37.13652777671814,0.0 -434900,2.895826,1.1745546,,,,,,,,,,,,,, -435000,3.0856934,1.1521648,,,,,,,,,,,,,, -435100,3.387633,1.1687778,,,,,,,,,,,,,, -435200,3.7757218,3.1826832,,,,,,,,,,,,,, -435300,3.5886145,3.2448273,,,,,,,,,,,,,, -435400,3.0887785,1.1853976,,,,,,,,,,,,,, -435500,2.843886,1.4641874,,,,,,,,,,,,,, -435600,2.7826574,1.7696218,,,,,,,,,,,,,, -435700,2.968821,1.9912171,,,,,,,,,,,,,, -435748,,,0.8893749713897705,0.4128265976905823,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,194969.2794907093,212391.54075169563,194969.2794907093,17361.18970298767,37.25248050689697,0.0 -435800,3.1533077,1.5172614,,,,,,,,,,,,,, -435900,3.139419,1.1512866,,,,,,,,,,,,,, -436000,2.890337,1.8105297,,,,,,,,,,,,,, -436100,3.0285916,1.1362754,,,,,,,,,,,,,, -436200,3.0305607,1.2097077,,,,,,,,,,,,,, -436300,3.2381945,1.1849966,,,,,,,,,,,,,, -436400,3.5756223,2.9710917,,,,,,,,,,,,,, -436500,2.9756434,1.0789825,,,,,,,,,,,,,, -436600,3.1671453,2.3798366,,,,,,,,,,,,,, -436685,,,0.8864257335662842,0.4213772416114807,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,195389.2151684761,212853.9011592865,195389.2151684761,17403.448195934296,37.36638760566712,0.0 -436700,2.8846095,2.2983596,,,,,,,,,,,,,, -436800,2.9804864,1.0733039,,,,,,,,,,,,,, -436900,3.2130787,1.3332107,,,,,,,,,,,,,, -437000,3.099879,3.053141,,,,,,,,,,,,,, -437100,3.0738611,1.3149316,,,,,,,,,,,,,, -437200,3.096909,1.1400695,,,,,,,,,,,,,, -437300,2.8950887,1.384547,,,,,,,,,,,,,, -437400,2.925916,1.0660411,,,,,,,,,,,,,, -437500,2.839733,1.0259033,,,,,,,,,,,,,, -437600,2.9886758,2.375023,,,,,,,,,,,,,, -437624,,,0.8902343511581421,0.4092316329479217,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,195809.24481654167,213312.25943803787,195809.24481654167,17441.593303203583,37.49781918525696,0.0 -437700,3.3559735,2.9414973,,,,,,,,,,,,,, -437800,2.82937,1.096978,,,,,,,,,,,,,, -437900,3.1512048,1.9121109,,,,,,,,,,,,,, -438000,3.0783849,1.4385132,,,,,,,,,,,,,, -438100,3.5364444,1.4862673,,,,,,,,,,,,,, -438200,3.0053253,1.2059467,,,,,,,,,,,,,, -438300,3.0065134,1.025933,,,,,,,,,,,,,, -438400,3.2994695,1.3187692,,,,,,,,,,,,,, -438500,2.928289,1.1540477,,,,,,,,,,,,,, -438562,,,0.8884179592132568,0.412599503993988,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,196229.3932518959,213774.7895960808,196229.3932518959,17483.80901670456,37.61236357688904,0.0 -438600,2.9476047,1.03178,,,,,,,,,,,,,, -438700,3.1550117,1.0849558,,,,,,,,,,,,,, -438800,3.126546,1.4481548,,,,,,,,,,,,,, -438900,3.2301176,2.8396316,,,,,,,,,,,,,, -439000,3.0326302,1.3023471,,,,,,,,,,,,,, -439100,3.3080506,1.5335655,,,,,,,,,,,,,, -439200,3.328319,2.7409315,,,,,,,,,,,,,, -439300,2.946236,1.1520357,,,,,,,,,,,,,, -439400,2.9162614,1.8363028,,,,,,,,,,,,,, -439498,,,0.8879101276397705,0.4182517826557159,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,196649.24244713783,214240.602208376,196649.24244713783,17529.58095383644,37.7523980140686,0.0 -439500,2.937109,2.5225518,,,,,,,,,,,,,, -439600,3.7661202,2.266366,,,,,,,,,,,,,, -439700,3.1274164,1.1826239,,,,,,,,,,,,,, -439800,3.500418,2.6628149,,,,,,,,,,,,,, -439900,2.779051,1.0565003,,,,,,,,,,,,,, -440000,3.014384,1.023944,,,,,,,,,,,,,, -440100,2.7584825,1.0315952,,,,,,,,,,,,,, -440200,3.7024434,3.3210504,,,,,,,,,,,,,, -440300,2.8743334,1.1142153,,,,,,,,,,,,,, -440400,3.3441603,1.1313307,,,,,,,,,,,,,, -440435,,,0.8888476490974426,0.4082626700401306,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,197069.23230481148,214701.5451090336,197069.23230481148,17570.387481451035,37.847246170043945,0.0 -440500,2.9869235,1.0930166,,,,,,,,,,,,,, -440600,3.0172906,1.0864458,,,,,,,,,,,,,, -440700,3.005518,2.6352036,,,,,,,,,,,,,, -440800,3.0153677,1.2798303,,,,,,,,,,,,,, -440900,2.9654343,1.6391462,,,,,,,,,,,,,, -441000,2.8725114,1.1011243,,,,,,,,,,,,,, -441100,3.3542688,2.4044647,,,,,,,,,,,,,, -441200,3.6360643,3.1586332,,,,,,,,,,,,,, -441300,3.328208,2.9860122,,,,,,,,,,,,,, -441371,,,0.8873828053474426,0.4172675311565399,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,197489.0875673294,215162.2216293812,197489.0875673294,17611.045613527298,37.9582302570343,0.0 -441400,3.2898443,1.1183647,,,,,,,,,,,,,, -441500,2.8113108,1.125675,,,,,,,,,,,,,, -441600,3.8831916,3.2457397,,,,,,,,,,,,,, -441700,3.8819127,3.2568011,,,,,,,,,,,,,, -441800,3.443464,2.948318,,,,,,,,,,,,,, -441900,3.2559798,1.4770714,,,,,,,,,,,,,, -442000,3.1615453,2.7358966,,,,,,,,,,,,,, -442100,3.7011154,3.074181,,,,,,,,,,,,,, -442200,3.048119,1.1024098,,,,,,,,,,,,,, -442300,2.9176276,2.1155782,,,,,,,,,,,,,, -442309,,,0.8910546898841858,0.4143266081809997,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,197909.2683000565,215623.2797615528,197909.2683000565,17651.756640195847,38.07303929328919,0.0 -442400,3.0765457,1.9407222,,,,,,,,,,,,,, -442500,4.428108,3.0234323,,,,,,,,,,,,,, -442600,2.887416,1.3131411,,,,,,,,,,,,,, -442700,2.9139216,1.1233593,,,,,,,,,,,,,, -442800,2.945584,1.1782063,,,,,,,,,,,,,, -442900,3.3483243,2.973915,,,,,,,,,,,,,, -443000,3.773606,3.0402193,,,,,,,,,,,,,, -443100,2.8346808,1.5111046,,,,,,,,,,,,,, -443200,2.6838076,1.8147318,,,,,,,,,,,,,, -443245,,,0.8901757597923279,0.4095847606658935,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,198329.16502404213,216088.4873046875,198329.16502404213,17696.90135359764,38.18743085861206,0.0 -443300,3.6803114,3.2058158,,,,,,,,,,,,,, -443400,2.9916406,1.0143316,,,,,,,,,,,,,, -443500,3.316243,1.7401471,,,,,,,,,,,,,, -443600,2.8490086,1.1570346,,,,,,,,,,,,,, -443700,2.9521427,1.1061494,,,,,,,,,,,,,, -443800,3.0176737,1.0627967,,,,,,,,,,,,,, -443900,3.5007474,1.1987329,,,,,,,,,,,,,, -444000,3.1207373,1.0349963,,,,,,,,,,,,,, -444100,2.9338865,1.0604969,,,,,,,,,,,,,, -444182,,,0.8889062404632568,0.412271648645401,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,198749.2635633945,216547.2765877247,198749.2635633945,17735.406715393066,38.32096481323242,0.0 -444200,3.5727446,3.0678682,,,,,,,,,,,,,, -444300,2.7774382,1.6131558,,,,,,,,,,,,,, -444400,3.5638146,3.079662,,,,,,,,,,,,,, -444500,3.6780539,3.0921512,,,,,,,,,,,,,, -444600,3.9282122,3.2801595,,,,,,,,,,,,,, -444700,2.8591275,1.2751386,,,,,,,,,,,,,, -444800,3.6645195,3.0425715,,,,,,,,,,,,,, -444900,3.794211,2.816374,,,,,,,,,,,,,, -445000,2.9324255,1.8874158,,,,,,,,,,,,,, -445100,3.5189688,2.764832,,,,,,,,,,,,,, -445119,,,0.8889062404632568,0.4158249497413635,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,199169.37876105309,217008.20818781853,199169.37876105309,17776.054845571518,38.437227964401245,0.0 -445200,2.9530926,1.9901035,,,,,,,,,,,,,, -445300,2.9399142,2.1454554,,,,,,,,,,,,,, -445400,2.8635535,1.0366881,,,,,,,,,,,,,, -445500,3.093083,1.1206492,,,,,,,,,,,,,, -445600,2.8946128,1.4177661,,,,,,,,,,,,,, -445700,2.8858266,1.0195193,,,,,,,,,,,,,, -445800,2.984521,1.8728282,,,,,,,,,,,,,, -445900,2.9638717,1.1904047,,,,,,,,,,,,,, -446000,3.0537229,1.0716379,,,,,,,,,,,,,, -446052,,,0.8854687213897705,0.4212657809257507,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,199589.2116761208,217481.3932626248,199589.2116761208,17829.23697257042,38.55621099472046,0.0 -446100,3.434569,1.5745543,,,,,,,,,,,,,, -446200,3.2234244,1.1481584,,,,,,,,,,,,,, -446300,2.954725,1.1809635,,,,,,,,,,,,,, -446400,3.1119845,1.1316392,,,,,,,,,,,,,, -446500,2.9979494,1.130204,,,,,,,,,,,,,, -446600,3.2292974,1.1990054,,,,,,,,,,,,,, -446700,3.4219148,1.3674114,,,,,,,,,,,,,, -446800,2.9572818,1.4413736,,,,,,,,,,,,,, -446900,3.6909366,3.0775323,,,,,,,,,,,,,, -446989,,,0.8869921565055847,0.4161044359207153,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,200009.1923904419,217937.54481863976,200009.1923904419,17865.259093284607,38.65262007713318,0.0 -447000,3.3909855,1.0805125,,,,,,,,,,,,,, -447100,3.4350529,1.0962616,,,,,,,,,,,,,, -447200,2.9771748,1.680041,,,,,,,,,,,,,, -447300,3.8285701,2.9961236,,,,,,,,,,,,,, -447400,2.5741603,1.6434512,,,,,,,,,,,,,, -447500,3.1082118,1.216414,,,,,,,,,,,,,, -447600,3.008466,1.0772486,,,,,,,,,,,,,, -447700,3.131375,1.1242579,,,,,,,,,,,,,, -447800,2.9910026,1.0691109,,,,,,,,,,,,,, -447900,3.7189023,3.133121,,,,,,,,,,,,,, -447923,,,0.8915429711341858,0.4091703295707702,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,200429.4089901448,218400.40449666977,200429.4089901448,17907.731934070587,38.77121067047119,0.0 -448000,3.3409503,1.496782,,,,,,,,,,,,,, -448100,3.124352,1.116092,,,,,,,,,,,,,, -448200,2.8496215,1.0801874,,,,,,,,,,,,,, -448300,3.6674948,3.2881496,,,,,,,,,,,,,, -448400,3.0412185,1.174566,,,,,,,,,,,,,, -448500,3.125553,1.1298819,,,,,,,,,,,,,, -448600,3.1474402,1.6418376,,,,,,,,,,,,,, -448700,2.8902705,2.0391197,,,,,,,,,,,,,, -448800,3.3684998,1.1060312,,,,,,,,,,,,,, -448858,,,0.8885351419448853,0.4149369299411773,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,200849.61452794075,218855.90346169472,200849.61452794075,17942.856937885284,38.88737273216248,0.0 -448900,2.9612365,2.119646,,,,,,,,,,,,,, -449000,3.00357,1.2175673,,,,,,,,,,,,,, -449100,2.8745737,1.7854657,,,,,,,,,,,,,, -449200,4.9500694,2.931779,,,,,,,,,,,,,, -449300,3.247342,3.0243742,,,,,,,,,,,,,, -449400,3.012359,1.189676,,,,,,,,,,,,,, -449500,3.1166565,1.053507,,,,,,,,,,,,,, -449600,3.333291,2.0182078,,,,,,,,,,,,,, -449700,2.903167,1.0619463,,,,,,,,,,,,,, -449797,,,0.8884961009025574,0.4168313145637512,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,201269.6993484497,219312.5145866871,201269.6993484497,17979.217833518982,38.99995350837708,0.0 -449800,3.1115534,1.4807274,,,,,,,,,,,,,, -449900,3.0070782,2.1328914,,,,,,,,,,,,,, -450000,3.042132,1.2108672,,,,,,,,,,,,,, -450100,2.9882,1.0851932,,,,,,,,,,,,,, -450200,3.0036364,2.0587876,,,,,,,,,,,,,, -450300,2.964734,1.0368682,,,,,,,,,,,,,, -450400,2.9397385,2.1717863,,,,,,,,,,,,,, -450500,3.2720716,2.4162328,,,,,,,,,,,,,, -450600,2.8513806,2.1701674,,,,,,,,,,,,,, -450700,3.0381813,1.3010962,,,,,,,,,,,,,, -450734,,,0.8858007788658142,0.4214471280574798,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,201689.80603957176,219779.9533605576,201689.80603957176,18026.383950948715,39.11453557014465,0.0 -450800,3.1514845,1.1161456,,,,,,,,,,,,,, -450900,2.9180453,1.076394,,,,,,,,,,,,,, -451000,3.01482,1.0564487,,,,,,,,,,,,,, -451100,2.942921,1.1231228,,,,,,,,,,,,,, -451200,3.2003772,2.4816332,,,,,,,,,,,,,, -451300,2.8145652,1.0586218,,,,,,,,,,,,,, -451400,3.0219882,1.6104141,,,,,,,,,,,,,, -451500,3.3202486,2.4809885,,,,,,,,,,,,,, -451600,3.2419546,2.501493,,,,,,,,,,,,,, -451672,,,0.8877733945846558,0.4172125458717346,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,202109.7711122036,220242.93742847443,202109.7711122036,18069.23517775536,39.230353116989136,0.0 -451700,3.2693381,1.1041907,,,,,,,,,,,,,, -451800,2.9409432,1.3813622,,,,,,,,,,,,,, -451900,3.1715498,1.1836472,,,,,,,,,,,,,, -452000,3.151347,1.1709455,,,,,,,,,,,,,, -452100,4.0647283,3.1660807,,,,,,,,,,,,,, -452200,4.0433674,3.155795,,,,,,,,,,,,,, -452300,3.19444,2.6811552,,,,,,,,,,,,,, -452400,2.7565322,1.7449622,,,,,,,,,,,,,, -452500,3.0440278,1.1137159,,,,,,,,,,,,,, -452600,2.9806597,1.0149713,,,,,,,,,,,,,, -452613,,,0.888476550579071,0.4082569479942322,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,202529.7824888229,220696.84306454656,202529.7824888229,18102.98245286941,39.325342893600464,0.0 -452700,2.930446,2.059188,,,,,,,,,,,,,, -452800,2.9088619,1.0665807,,,,,,,,,,,,,, -452900,3.2770178,1.1789768,,,,,,,,,,,,,, -453000,4.568801,3.0240908,,,,,,,,,,,,,, -453100,3.1298153,1.3937665,,,,,,,,,,,,,, -453200,3.5603237,3.1856012,,,,,,,,,,,,,, -453300,3.7721322,3.098979,,,,,,,,,,,,,, -453400,3.1619246,1.102406,,,,,,,,,,,,,, -453500,3.0869286,1.0874381,,,,,,,,,,,,,, -453549,,,0.8876367211341858,0.4193960428237915,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,202949.7999596596,221156.62088036537,202949.7999596596,18142.578282356262,39.43860602378845,0.0 -453600,3.0362833,2.0134845,,,,,,,,,,,,,, -453700,2.7335713,1.4989625,,,,,,,,,,,,,, -453800,3.051427,2.0167596,,,,,,,,,,,,,, -453900,3.2660685,3.0607462,,,,,,,,,,,,,, -454000,3.0196896,1.7586203,,,,,,,,,,,,,, -454100,3.4188,2.7364717,,,,,,,,,,,,,, -454200,3.432543,3.231869,,,,,,,,,,,,,, -454300,3.1058984,1.2266965,,,,,,,,,,,,,, -454400,3.221598,2.333202,,,,,,,,,,,,,, -454485,,,0.8869335651397705,0.4207581579685211,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,203369.71106481552,221620.98258256912,203369.71106481552,18186.859826803207,39.55583477020264,0.0 -454500,3.162331,0.97448087,,,,,,,,,,,,,, -454600,3.0683389,1.1365771,,,,,,,,,,,,,, -454700,3.18083,1.4743414,,,,,,,,,,,,,, -454800,3.0240922,1.4616202,,,,,,,,,,,,,, -454900,3.0806608,1.0783334,,,,,,,,,,,,,, -455000,2.8667886,1.01167,,,,,,,,,,,,,, -455100,3.2059019,2.9547687,,,,,,,,,,,,,, -455200,2.927729,1.1395949,,,,,,,,,,,,,, -455300,3.394762,3.0983415,,,,,,,,,,,,,, -455400,3.0149918,1.2038264,,,,,,,,,,,,,, -455426,,,0.88818359375,0.4161781072616577,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,203789.9735286236,222080.85911655423,203789.9735286236,18226.327399015427,39.65032410621643,0.0 -455500,2.9454231,1.1487298,,,,,,,,,,,,,, -455600,3.1391096,1.198492,,,,,,,,,,,,,, -455700,3.2773187,1.3236984,,,,,,,,,,,,,, -455800,3.7998285,2.7660475,,,,,,,,,,,,,, -455900,3.1215012,1.115525,,,,,,,,,,,,,, -456000,3.1944177,2.0408208,,,,,,,,,,,,,, -456100,3.2306085,1.1121442,,,,,,,,,,,,,, -456200,3.2200298,1.0917218,,,,,,,,,,,,,, -456300,2.997157,1.0484089,,,,,,,,,,,,,, -456364,,,0.8884375095367432,0.4117911756038666,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,204210.18108654025,222535.50862407684,204210.18108654025,18260.59964799881,39.7682089805603,0.0 -456400,2.9875174,1.1159496,,,,,,,,,,,,,, -456500,3.0896351,1.3724105,,,,,,,,,,,,,, -456600,3.2800813,1.1826745,,,,,,,,,,,,,, -456700,3.0096173,1.2521932,,,,,,,,,,,,,, -456800,3.1636705,1.1083359,,,,,,,,,,,,,, -456900,3.0555894,2.2624397,,,,,,,,,,,,,, -457000,3.014048,1.2028716,,,,,,,,,,,,,, -457100,3.7862515,3.0875313,,,,,,,,,,,,,, -457200,3.6215377,3.1840081,,,,,,,,,,,,,, -457298,,,0.889453113079071,0.4147363305091858,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,204630.0314304829,223007.61329698563,204630.0314304829,18312.688049077988,39.88325357437134,0.0 -457300,3.001927,1.380343,,,,,,,,,,,,,, -457400,2.9532588,1.1383626,,,,,,,,,,,,,, -457500,3.0085437,1.3529553,,,,,,,,,,,,,, -457600,2.9367902,1.1895382,,,,,,,,,,,,,, -457700,2.9552023,1.6456208,,,,,,,,,,,,,, -457800,3.704794,3.2361052,,,,,,,,,,,,,, -457900,3.1065984,1.1828177,,,,,,,,,,,,,, -458000,3.6540387,3.185328,,,,,,,,,,,,,, -458100,4.063094,3.1944056,,,,,,,,,,,,,, -458200,3.079134,1.0466444,,,,,,,,,,,,,, -458238,,,0.8873828053474426,0.4160374402999878,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,205050.12309336665,223462.97123146057,205050.12309336665,18347.80455422401,39.98117995262146,0.0 -458300,3.2099543,1.1590441,,,,,,,,,,,,,, -458400,3.039133,2.3697333,,,,,,,,,,,,,, -458500,2.9823525,1.1305895,,,,,,,,,,,,,, -458600,3.2469163,2.736565,,,,,,,,,,,,,, -458700,3.4840684,1.2154951,,,,,,,,,,,,,, -458800,3.0328743,1.1665183,,,,,,,,,,,,,, -458900,3.717184,3.0814457,,,,,,,,,,,,,, -459000,3.600579,1.152391,,,,,,,,,,,,,, -459100,2.8626237,1.3585179,,,,,,,,,,,,,, -459172,,,0.8875781297683716,0.4167408645153045,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,205469.96904063225,223928.3540511132,205469.96904063225,18393.14288878441,40.12711071968079,0.0 -459200,3.1533086,1.215821,,,,,,,,,,,,,, -459300,3.2410555,2.237739,,,,,,,,,,,,,, -459400,2.8721137,1.805479,,,,,,,,,,,,,, -459500,3.2731056,2.7220159,,,,,,,,,,,,,, -459600,3.3278842,1.1298488,,,,,,,,,,,,,, -459700,2.994156,1.092293,,,,,,,,,,,,,, -459800,3.066738,2.6212206,,,,,,,,,,,,,, -459900,2.9001157,1.5630088,,,,,,,,,,,,,, -460000,3.0589643,1.1944145,,,,,,,,,,,,,, -460100,2.920347,2.481814,,,,,,,,,,,,,, -460112,,,0.8892577886581421,0.4150897562503814,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,205889.9131486416,224388.8865418434,205889.9131486416,18433.56349658966,40.242703437805176,0.0 -460200,2.9969323,1.1169364,,,,,,,,,,,,,, -460300,3.601648,3.085233,,,,,,,,,,,,,, -460400,2.9195633,1.0756091,,,,,,,,,,,,,, -460500,3.5414312,3.0864768,,,,,,,,,,,,,, -460600,3.0529659,2.2709758,,,,,,,,,,,,,, -460700,3.942836,3.1497226,,,,,,,,,,,,,, -460800,3.2993865,1.5384696,,,,,,,,,,,,,, -460900,3.0420742,1.0925233,,,,,,,,,,,,,, -461000,2.9551287,1.0956196,,,,,,,,,,,,,, -461045,,,0.8885155916213989,0.4129526913166046,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,206309.76017975807,224848.9179956913,206309.76017975807,18473.569224596024,40.36980032920837,0.0 -461100,3.3384202,1.0516008,,,,,,,,,,,,,, -461200,3.744304,1.4559983,,,,,,,,,,,,,, -461300,2.893506,0.99585724,,,,,,,,,,,,,, -461400,3.2875512,2.9073906,,,,,,,,,,,,,, -461500,3.3957067,1.1712695,,,,,,,,,,,,,, -461600,3.2246375,2.9376476,,,,,,,,,,,,,, -461700,2.9403808,1.6895049,,,,,,,,,,,,,, -461800,3.2700248,2.9305139,,,,,,,,,,,,,, -461900,2.90119,1.1247007,,,,,,,,,,,,,, -461977,,,0.8879492282867432,0.4129699170589447,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,206729.5407309532,225308.3424794674,206729.5407309532,18512.961121320724,40.57051396369934,0.0 -462000,3.0048718,2.5213656,,,,,,,,,,,,,, -462100,3.610571,3.1894262,,,,,,,,,,,,,, -462200,3.2846482,2.1499493,,,,,,,,,,,,,, -462300,2.8902388,2.3811755,,,,,,,,,,,,,, -462400,3.1907635,1.4811361,,,,,,,,,,,,,, -462500,3.330183,2.624835,,,,,,,,,,,,,, -462600,2.9449217,1.6890062,,,,,,,,,,,,,, -462700,3.5954227,3.1557398,,,,,,,,,,,,,, -462800,3.1006014,1.1418493,,,,,,,,,,,,,, -462900,2.9789982,1.1152786,,,,,,,,,,,,,, -462912,,,0.8885155916213989,0.4146435856819153,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,207149.70380544665,225769.39330291748,207149.70380544665,18553.678634643555,40.68940997123718,0.0 -463000,2.9417615,1.1947734,,,,,,,,,,,,,, -463100,3.0991502,1.7815992,,,,,,,,,,,,,, -463200,3.1141534,2.594974,,,,,,,,,,,,,, -463300,3.2040286,1.3155663,,,,,,,,,,,,,, -463400,3.0699728,1.3171782,,,,,,,,,,,,,, -463500,3.0885153,1.1812049,,,,,,,,,,,,,, -463600,3.1959152,1.5356554,,,,,,,,,,,,,, -463700,3.3736422,2.8091705,,,,,,,,,,,,,, -463800,3.2860792,1.2320327,,,,,,,,,,,,,, -463848,,,0.8866796493530273,0.4163015186786651,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,207569.53734230995,226235.6934535504,207569.53734230995,18599.97526431084,40.80809164047241,0.0 -463900,3.2188296,1.2354248,,,,,,,,,,,,,, -464000,3.856332,3.3456736,,,,,,,,,,,,,, -464100,2.932773,1.3412023,,,,,,,,,,,,,, -464200,3.0258303,2.3327558,,,,,,,,,,,,,, -464300,3.023581,1.1287398,,,,,,,,,,,,,, -464400,2.98546,1.91286,,,,,,,,,,,,,, -464500,3.7162666,1.5523436,,,,,,,,,,,,,, -464600,2.8670526,1.269383,,,,,,,,,,,,,, -464700,3.180599,2.1979766,,,,,,,,,,,,,, -464790,,,0.8889062404632568,0.4164812862873077,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,207989.73219513893,226692.9933104515,207989.73219513893,18636.92086791992,40.91565990447998,0.0 -464800,5.5265183,2.8729205,,,,,,,,,,,,,, -464900,3.566081,1.1357588,,,,,,,,,,,,,, -465000,3.0389292,1.1971598,,,,,,,,,,,,,, -465100,3.5463245,1.6571934,,,,,,,,,,,,,, -465200,3.4256496,1.4190961,,,,,,,,,,,,,, -465300,3.7555602,1.1580218,,,,,,,,,,,,,, -465400,2.9049082,0.9875264,,,,,,,,,,,,,, -465500,3.3657212,1.5681305,,,,,,,,,,,,,, -465600,2.9236171,1.006307,,,,,,,,,,,,,, -465700,3.0062969,1.234618,,,,,,,,,,,,,, -465726,,,0.8902929425239563,0.4111060202121734,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,208409.66075110435,227156.35308504105,208409.66075110435,18680.181062698364,41.03439331054688,0.0 -465800,3.7213395,3.2632778,,,,,,,,,,,,,, -465900,3.7817404,3.2546735,,,,,,,,,,,,,, -466000,3.329653,1.2392242,,,,,,,,,,,,,, -466100,2.8900898,1.5286891,,,,,,,,,,,,,, -466200,3.220274,2.5902448,,,,,,,,,,,,,, -466300,2.9771416,2.6206422,,,,,,,,,,,,,, -466400,3.3927677,1.2491614,,,,,,,,,,,,,, -466500,2.961516,1.1078453,,,,,,,,,,,,,, -466600,3.2320802,1.1202703,,,,,,,,,,,,,, -466662,,,0.8910546898841858,0.4077297151088714,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,208829.88141417503,227617.39014077187,208829.88141417503,18720.83535385132,41.144742250442505,0.0 -466700,2.8357322,1.6910262,,,,,,,,,,,,,, -466800,3.1970928,2.6609488,,,,,,,,,,,,,, -466900,3.9500554,3.1920118,,,,,,,,,,,,,, -467000,3.9718027,3.2679944,,,,,,,,,,,,,, -467100,3.2123737,1.1855931,,,,,,,,,,,,,, -467200,3.1664386,1.3022071,,,,,,,,,,,,,, -467300,2.9405353,1.2680675,,,,,,,,,,,,,, -467400,3.2517905,1.198478,,,,,,,,,,,,,, -467500,3.393775,1.5651114,,,,,,,,,,,,,, -467600,2.8290098,1.5061162,,,,,,,,,,,,,, -467603,,,0.8886913657188416,0.4169077575206756,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,209249.93860125545,228074.3007860184,209249.93860125545,18757.53831171989,41.24268078804016,0.0 -467700,3.61032,2.5205264,,,,,,,,,,,,,, -467800,3.2732222,1.7155185,,,,,,,,,,,,,, -467900,3.6286585,1.97406,,,,,,,,,,,,,, -468000,2.7656407,1.2873564,,,,,,,,,,,,,, -468100,3.050225,1.1068773,,,,,,,,,,,,,, -468200,3.0269167,1.0890862,,,,,,,,,,,,,, -468300,2.8992903,1.1046989,,,,,,,,,,,,,, -468400,3.4925199,2.802744,,,,,,,,,,,,,, -468500,3.1092637,1.5577105,,,,,,,,,,,,,, -468540,,,0.8895702958106995,0.4094632863998413,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,209670.0353667736,228533.95680093765,209670.0353667736,18796.923159122467,41.36542820930481,0.0 -468600,3.7696056,3.1775076,,,,,,,,,,,,,, -468700,3.007102,2.2868404,,,,,,,,,,,,,, -468800,3.009436,2.2469761,,,,,,,,,,,,,, -468900,2.8444364,1.3589787,,,,,,,,,,,,,, -469000,3.1056323,1.0548633,,,,,,,,,,,,,, -469100,3.7913656,2.6974177,,,,,,,,,,,,,, -469200,3.616286,3.2563033,,,,,,,,,,,,,, -469300,3.167925,1.200834,,,,,,,,,,,,,, -469400,2.950361,1.0197074,,,,,,,,,,,,,, -469474,,,0.8866015672683716,0.4216593205928802,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,210090.10920333865,229001.8804507256,210090.10920333865,18844.599578857426,41.48731422424317,0.0 -469500,3.0017576,1.2015923,,,,,,,,,,,,,, -469600,3.2923043,2.8063197,,,,,,,,,,,,,, -469700,3.044295,1.4736538,,,,,,,,,,,,,, -469800,3.3917325,2.6719956,,,,,,,,,,,,,, -469900,3.9494445,3.229678,,,,,,,,,,,,,, -470000,3.661567,2.9251223,,,,,,,,,,,,,, -470100,3.026867,2.6495285,,,,,,,,,,,,,, -470200,3.011875,1.1499039,,,,,,,,,,,,,, -470300,2.9498422,1.2277013,,,,,,,,,,,,,, -470400,3.2985263,1.2168725,,,,,,,,,,,,,, -470414,,,0.8875976204872131,0.415841668844223,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,210510.11679530144,229459.28646302223,210510.11679530144,18881.850703954697,41.58333134651184,0.0 -470500,2.992113,2.2404518,,,,,,,,,,,,,, -470600,3.729623,3.2633767,,,,,,,,,,,,,, -470700,3.2186832,1.1270511,,,,,,,,,,,,,, -470800,3.554239,3.0724611,,,,,,,,,,,,,, -470900,3.0285156,2.1887758,,,,,,,,,,,,,, -471000,2.8184412,1.4185839,,,,,,,,,,,,,, -471100,3.926012,2.9034696,,,,,,,,,,,,,, -471200,3.1832604,1.2712904,,,,,,,,,,,,,, -471300,3.0389173,1.1429514,,,,,,,,,,,,,, -471347,,,0.8912695050239563,0.4065933227539062,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,210930.04602122307,229923.9920327664,210930.04602122307,18926.45775413513,41.70092463493347,0.0 -471400,3.0959458,1.1182659,,,,,,,,,,,,,, -471500,2.914866,1.2901459,,,,,,,,,,,,,, -471600,3.6203415,1.3103559,,,,,,,,,,,,,, -471700,3.2731767,1.4421326,,,,,,,,,,,,,, -471800,3.0425236,1.3766067,,,,,,,,,,,,,, -471900,3.444029,3.0964875,,,,,,,,,,,,,, -472000,3.0515444,1.106419,,,,,,,,,,,,,, -472100,2.8599613,1.6315436,,,,,,,,,,,,,, -472200,3.242489,1.0853889,,,,,,,,,,,,,, -472275,,,0.8875390291213989,0.4188980460166931,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,211350.18746185303,230383.6294569969,211350.18746185303,18965.783930778503,41.82049751281738,0.0 -472300,3.5624208,3.0022385,,,,,,,,,,,,,, -472400,3.065232,1.7633097,,,,,,,,,,,,,, -472500,2.854433,2.2449899,,,,,,,,,,,,,, -472600,3.0013075,1.9982117,,,,,,,,,,,,,, -472700,3.0361035,1.0765822,,,,,,,,,,,,,, -472800,2.9755476,1.07342,,,,,,,,,,,,,, -472900,3.901631,1.1783407,,,,,,,,,,,,,, -473000,3.996174,3.2164416,,,,,,,,,,,,,, -473100,3.3145304,1.2041919,,,,,,,,,,,,,, -473200,2.863185,1.5548968,,,,,,,,,,,,,, -473211,,,0.8882226347923279,0.415109634399414,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,211770.68524432185,230843.92622327805,211770.68524432185,19005.408094406128,41.94422078132629,0.0 -473300,3.616951,3.3017468,,,,,,,,,,,,,, -473400,3.145648,1.1983306,,,,,,,,,,,,,, -473500,2.8112247,1.2208927,,,,,,,,,,,,,, -473600,3.012564,2.7569094,,,,,,,,,,,,,, -473700,2.8942504,1.4492493,,,,,,,,,,,,,, -473800,3.2939832,1.0966833,,,,,,,,,,,,,, -473900,3.2228987,1.1038809,,,,,,,,,,,,,, -474000,3.1266346,1.4170768,,,,,,,,,,,,,, -474100,3.4449823,1.820404,,,,,,,,,,,,,, -474149,,,0.886523425579071,0.4200004339218139,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,212190.5941722393,231306.1898348332,212190.5941722393,19047.591923952103,42.06261587142944,0.0 -474200,3.639212,3.1413827,,,,,,,,,,,,,, -474300,2.957853,2.1104345,,,,,,,,,,,,,, -474400,3.113582,1.1707152,,,,,,,,,,,,,, -474500,2.996495,1.2412838,,,,,,,,,,,,,, -474600,3.2943296,3.0899818,,,,,,,,,,,,,, -474700,3.2130637,1.133312,,,,,,,,,,,,,, -474800,3.3155105,1.2371814,,,,,,,,,,,,,, -474900,2.936464,1.0115325,,,,,,,,,,,,,, -475000,2.8451822,1.8188854,,,,,,,,,,,,,, -475089,,,0.8861913681030273,0.42005056142807,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,212610.66497921944,231767.27955651283,212610.66497921944,19088.447747945786,42.17369341850281,0.0 -475100,3.323953,1.5264627,,,,,,,,,,,,,, -475200,2.8724306,1.2439204,,,,,,,,,,,,,, -475300,3.2463217,2.7079146,,,,,,,,,,,,,, -475400,3.8268516,3.4246497,,,,,,,,,,,,,, -475500,2.8548806,1.7310243,,,,,,,,,,,,,, -475600,3.0295293,1.0982542,,,,,,,,,,,,,, -475700,3.205671,1.1282101,,,,,,,,,,,,,, -475800,2.964145,1.01809,,,,,,,,,,,,,, -475900,3.0448287,1.3547542,,,,,,,,,,,,,, -476000,3.0202343,1.0048288,,,,,,,,,,,,,, -476023,,,0.8905664086341858,0.4074235558509826,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,213030.53802657127,232229.01929712296,213030.53802657127,19130.14138698578,42.2940833568573,0.0 -476100,3.16732,1.2483094,,,,,,,,,,,,,, -476200,2.9546235,1.0928231,,,,,,,,,,,,,, -476300,2.8924084,0.9985518,,,,,,,,,,,,,, -476400,2.9061403,1.06741,,,,,,,,,,,,,, -476500,2.9539082,1.666119,,,,,,,,,,,,,, -476600,3.5454352,3.2714636,,,,,,,,,,,,,, -476700,2.804616,1.354519,,,,,,,,,,,,,, -476800,2.830202,1.4869556,,,,,,,,,,,,,, -476900,3.0826666,1.6172992,,,,,,,,,,,,,, -476960,,,0.8874413967132568,0.4190874397754669,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,213450.59983038905,232687.90293979645,213450.59983038905,19168.720163583755,42.48556280136109,0.0 -477000,3.6459427,3.0169113,,,,,,,,,,,,,, -477100,3.085403,1.0895284,,,,,,,,,,,,,, -477200,2.9396982,1.1334674,,,,,,,,,,,,,, -477300,3.022143,2.0850542,,,,,,,,,,,,,, -477400,3.9143584,3.1187515,,,,,,,,,,,,,, -477500,3.6504893,3.1618185,,,,,,,,,,,,,, -477600,3.090395,1.0355247,,,,,,,,,,,,,, -477700,3.1330996,1.8494966,,,,,,,,,,,,,, -477800,3.0214508,1.037133,,,,,,,,,,,,,, -477900,,,0.8867968320846558,0.4170536398887634,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,213870.68713450432,233151.0371210575,213870.68713450432,19211.575483083725,42.624929428100586,0.0 -477900,3.785631,3.21927,,,,,,,,,,,,,, -478000,3.2782912,1.3005822,,,,,,,,,,,,,, -478100,3.0620248,0.98136926,,,,,,,,,,,,,, -478200,2.943033,2.2583716,,,,,,,,,,,,,, -478300,3.0134702,1.1655309,,,,,,,,,,,,,, -478400,3.2193842,2.814871,,,,,,,,,,,,,, -478500,3.7618697,3.094138,,,,,,,,,,,,,, -478600,3.1879113,2.741139,,,,,,,,,,,,,, -478700,2.9031675,2.6425915,,,,,,,,,,,,,, -478800,3.5395937,2.9700665,,,,,,,,,,,,,, -478839,,,0.8878905773162842,0.4166683554649353,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,214290.8945174217,233609.22505187988,214290.8945174217,19249.38017272949,42.74868321418762,0.0 -478900,2.8736777,2.3155103,,,,,,,,,,,,,, -479000,3.0702279,1.890769,,,,,,,,,,,,,, -479100,2.936126,0.98841643,,,,,,,,,,,,,, -479200,3.0762372,1.0985091,,,,,,,,,,,,,, -479300,3.3587024,1.1912304,,,,,,,,,,,,,, -479400,3.2819839,1.1482335,,,,,,,,,,,,,, -479500,3.3055353,2.336751,,,,,,,,,,,,,, -479600,2.905963,2.475103,,,,,,,,,,,,,, -479700,3.0371902,1.0733342,,,,,,,,,,,,,, -479778,,,0.8868359327316284,0.4167793989181518,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,214711.0296151638,234069.1584665776,214711.0296151638,19288.926438093185,42.94045972824097,0.0 -479800,3.6406882,3.3166184,,,,,,,,,,,,,, -479900,3.0185971,1.7052994,,,,,,,,,,,,,, -480000,2.9001908,1.1511165,,,,,,,,,,,,,, -480100,3.6750746,1.1510936,,,,,,,,,,,,,, -480200,3.0081587,1.9849188,,,,,,,,,,,,,, -480300,3.0147898,1.0792178,,,,,,,,,,,,,, -480400,3.0817106,2.4414668,,,,,,,,,,,,,, -480500,2.956988,1.6733902,,,,,,,,,,,,,, -480600,3.5315607,2.9217045,,,,,,,,,,,,,, -480700,3.1544178,2.7375548,,,,,,,,,,,,,, -480716,,,0.8891991972923279,0.4213533401489258,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,215131.2020647525,234529.17214751244,215131.2020647525,19328.59648346901,43.05870342254639,0.0 -480800,3.414647,1.157042,,,,,,,,,,,,,, -480900,3.2959032,1.1981226,,,,,,,,,,,,,, -481000,3.2657628,1.4430571,,,,,,,,,,,,,, -481100,3.2075388,1.460122,,,,,,,,,,,,,, -481200,2.9408023,1.9668972,,,,,,,,,,,,,, -481300,2.9436212,1.2150736,,,,,,,,,,,,,, -481400,3.2248535,1.2393237,,,,,,,,,,,,,, -481500,3.1920836,1.2355484,,,,,,,,,,,,,, -481600,2.7835333,1.0309315,,,,,,,,,,,,,, -481652,,,0.8895312547683716,0.4089459180831909,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,215551.36306786537,234994.9626107216,215551.36306786537,19374.056756734848,43.17589569091797,0.0 -481700,2.8988392,1.0796175,,,,,,,,,,,,,, -481800,2.9179492,1.0733154,,,,,,,,,,,,,, -481900,3.044942,1.2283685,,,,,,,,,,,,,, -482000,3.090999,1.0503255,,,,,,,,,,,,,, -482100,3.2248678,1.1857692,,,,,,,,,,,,,, -482200,2.9688904,1.2735571,,,,,,,,,,,,,, -482300,3.0273407,1.1664755,,,,,,,,,,,,,, -482400,3.1997132,1.8050033,,,,,,,,,,,,,, -482500,2.7602196,1.8898402,,,,,,,,,,,,,, -482594,,,0.8869726657867432,0.420018196105957,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,215971.16331911087,235451.6388404369,215971.16331911087,19410.689818143845,43.365275382995605,0.0 -482600,3.8183715,3.2981036,,,,,,,,,,,,,, -482700,2.9389014,1.2249178,,,,,,,,,,,,,, -482800,2.9358075,1.9850271,,,,,,,,,,,,,, -482900,3.380035,1.2444811,,,,,,,,,,,,,, -483000,3.6152096,2.4971337,,,,,,,,,,,,,, -483100,2.9334898,1.1472008,,,,,,,,,,,,,, -483200,2.9504893,1.1162364,,,,,,,,,,,,,, -483300,3.144175,1.3055313,,,,,,,,,,,,,, -483400,3.2820616,1.2630061,,,,,,,,,,,,,, -483500,3.6704264,2.959877,,,,,,,,,,,,,, -483532,,,0.8881444931030273,0.4166688323020935,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,216391.1415436268,235911.47638177872,216391.1415436268,19450.36641383171,43.495421171188354,0.0 -483600,3.1333642,2.726894,,,,,,,,,,,,,, -483700,2.9916682,2.680625,,,,,,,,,,,,,, -483800,3.3198438,1.14713,,,,,,,,,,,,,, -483900,3.0318925,1.1840003,,,,,,,,,,,,,, -484000,3.008775,1.1625094,,,,,,,,,,,,,, -484100,3.1329687,1.7737964,,,,,,,,,,,,,, -484200,3.1344988,1.1418205,,,,,,,,,,,,,, -484300,3.0279565,1.1002275,,,,,,,,,,,,,, -484400,3.6037705,2.7347426,,,,,,,,,,,,,, -484467,,,0.8903319835662842,0.4062535762786865,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,216811.1978859901,236383.04190945625,216811.1978859901,19501.70230698585,43.616320848464966,0.0 -484500,3.2543495,1.1399938,,,,,,,,,,,,,, -484600,3.1787622,1.0199487,,,,,,,,,,,,,, -484700,2.8159227,1.1782501,,,,,,,,,,,,,, -484800,3.2489085,1.3050342,,,,,,,,,,,,,, -484900,3.5005293,3.306995,,,,,,,,,,,,,, -485000,2.9936042,1.0249299,,,,,,,,,,,,,, -485100,2.9544487,1.2391722,,,,,,,,,,,,,, -485200,3.54355,3.1178198,,,,,,,,,,,,,, -485300,2.9926775,1.6329435,,,,,,,,,,,,,, -485400,3.2070072,1.1295091,,,,,,,,,,,,,, -485408,,,0.8874218463897705,0.4174661040306091,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,217231.43805646896,236841.8782045841,217231.43805646896,19540.132127523422,43.72937512397766,0.0 -485500,2.9562085,2.141774,,,,,,,,,,,,,, -485600,2.9694855,2.5998025,,,,,,,,,,,,,, -485700,3.075629,1.1248713,,,,,,,,,,,,,, -485800,2.9617667,1.0814736,,,,,,,,,,,,,, -485900,3.0251832,1.6188439,,,,,,,,,,,,,, -486000,3.8205223,1.1171683,,,,,,,,,,,,,, -486100,3.4540453,3.030916,,,,,,,,,,,,,, -486200,3.5195243,2.8548484,,,,,,,,,,,,,, -486300,2.9230287,1.1520534,,,,,,,,,,,,,, -486348,,,0.8887890577316284,0.4112873673439026,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,217651.5862724781,237303.24244117737,217651.5862724781,19581.17331981659,43.85249710083008,0.0 -486400,3.3718147,2.8443203,,,,,,,,,,,,,, -486500,3.2042205,2.1341205,,,,,,,,,,,,,, -486600,3.0763733,1.2737343,,,,,,,,,,,,,, -486700,3.05572,1.2863286,,,,,,,,,,,,,, -486800,3.2203684,2.8443503,,,,,,,,,,,,,, -486900,3.062411,1.1709728,,,,,,,,,,,,,, -487000,3.2028341,2.4005582,,,,,,,,,,,,,, -487100,2.892889,2.2741046,,,,,,,,,,,,,, -487200,3.0305672,1.1041471,,,,,,,,,,,,,, -487285,,,0.8875781297683716,0.4143383204936981,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,218071.7227330208,237763.7866547108,218071.7227330208,19621.407239437103,43.97446012496948,0.0 -487300,3.0428174,1.1417962,,,,,,,,,,,,,, -487400,2.8551848,1.3474739,,,,,,,,,,,,,, -487500,3.058063,1.0930462,,,,,,,,,,,,,, -487600,2.6874657,1.5190554,,,,,,,,,,,,,, -487700,4.115919,1.147806,,,,,,,,,,,,,, -487800,3.3250675,2.6928265,,,,,,,,,,,,,, -487900,3.1134145,1.2177703,,,,,,,,,,,,,, -488000,3.495151,3.0811968,,,,,,,,,,,,,, -488100,3.4691465,1.2339965,,,,,,,,,,,,,, -488200,3.176524,1.1431177,,,,,,,,,,,,,, -488221,,,0.8894335627555847,0.4136330187320709,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,218491.75671744347,238224.9208495617,218491.75671744347,19662.32736802101,44.102455615997314,0.0 -488300,2.8997028,1.3048344,,,,,,,,,,,,,, -488400,3.0992494,1.1655778,,,,,,,,,,,,,, -488500,3.3040535,2.399016,,,,,,,,,,,,,, -488600,3.206162,1.197401,,,,,,,,,,,,,, -488700,2.655359,1.2520808,,,,,,,,,,,,,, -488800,3.1757936,2.5846314,,,,,,,,,,,,,, -488900,3.0504067,2.749069,,,,,,,,,,,,,, -489000,3.1085968,2.7665386,,,,,,,,,,,,,, -489100,2.902918,1.0235925,,,,,,,,,,,,,, -489157,,,0.88916015625,0.4151538014411926,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,218911.8640100956,238688.7435998917,218911.8640100956,19705.860652446747,44.23263597488403,0.0 -489200,3.2697747,1.1439066,,,,,,,,,,,,,, -489300,3.9211075,3.1113248,,,,,,,,,,,,,, -489400,2.9435563,2.4582732,,,,,,,,,,,,,, -489500,3.1362443,1.1411036,,,,,,,,,,,,,, -489600,3.5138123,3.0743608,,,,,,,,,,,,,, -489700,3.0304513,1.8118451,,,,,,,,,,,,,, -489800,3.337851,2.6677816,,,,,,,,,,,,,, -489900,3.9819314,3.1507728,,,,,,,,,,,,,, -490000,3.4166052,1.3973687,,,,,,,,,,,,,, -490095,,,0.8889452815055847,0.4118638038635254,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,219331.83431482315,239152.7444009781,219331.83431482315,19749.71036696434,44.36157035827637,0.0 -490100,3.624231,3.185499,,,,,,,,,,,,,, -490200,2.950842,2.1602764,,,,,,,,,,,,,, -490300,2.9646018,1.1227021,,,,,,,,,,,,,, -490400,2.937424,1.1602889,,,,,,,,,,,,,, -490500,2.8099494,1.5689204,,,,,,,,,,,,,, -490600,3.1139648,1.0943928,,,,,,,,,,,,,, -490700,4.7541475,3.2858078,,,,,,,,,,,,,, -490800,2.7846909,1.0702494,,,,,,,,,,,,,, -490900,3.2291336,2.0993917,,,,,,,,,,,,,, -491000,3.1632204,1.4945054,,,,,,,,,,,,,, -491034,,,0.8895898461341858,0.4138146638870239,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,219752.01549005508,239613.35805392265,219752.01549005508,19789.958799123764,44.49367165565491,0.0 -491100,2.8742397,1.979314,,,,,,,,,,,,,, -491200,3.2003858,1.4569973,,,,,,,,,,,,,, -491300,3.2650118,1.1739655,,,,,,,,,,,,,, -491400,2.825679,2.0448654,,,,,,,,,,,,,, -491500,2.8038745,1.5566616,,,,,,,,,,,,,, -491600,2.9867356,1.252543,,,,,,,,,,,,,, -491700,3.0797648,2.7186303,,,,,,,,,,,,,, -491800,3.2813475,1.1639479,,,,,,,,,,,,,, -491900,2.8628235,1.6955616,,,,,,,,,,,,,, -491973,,,0.8890038728713989,0.4148271977901459,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,220172.0487046241,240073.4868443012,220172.0487046241,19829.86556172371,44.63035583496094,0.0 -492000,3.1220677,2.8177052,,,,,,,,,,,,,, -492100,3.052703,1.123394,,,,,,,,,,,,,, -492200,3.6256015,2.2783995,,,,,,,,,,,,,, -492300,2.950985,1.0910218,,,,,,,,,,,,,, -492400,3.9026833,2.8499472,,,,,,,,,,,,,, -492500,3.1614232,1.0564221,,,,,,,,,,,,,, -492600,8.434606,3.1887722,,,,,,,,,,,,,, -492700,2.9880965,1.4958572,,,,,,,,,,,,,, -492800,3.375207,1.1118321,,,,,,,,,,,,,, -492900,2.6745095,1.7243534,,,,,,,,,,,,,, -492910,,,0.8866601586341858,0.4159363508224487,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,220592.05123972893,240534.4372580052,220592.05123972893,19870.639471769333,44.75104427337647,0.0 -493000,3.4313776,2.179588,,,,,,,,,,,,,, -493100,2.9696536,1.2757332,,,,,,,,,,,,,, -493200,3.004079,1.411588,,,,,,,,,,,,,, -493300,3.68709,1.4080424,,,,,,,,,,,,,, -493400,2.8156888,1.9462215,,,,,,,,,,,,,, -493500,3.1095922,1.0374768,,,,,,,,,,,,,, -493600,3.2381418,2.6920643,,,,,,,,,,,,,, -493700,2.9209466,1.0205127,,,,,,,,,,,,,, -493800,3.360375,1.1969728,,,,,,,,,,,,,, -493848,,,0.887499988079071,0.4184769093990326,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,221011.9910628796,240998.0865285397,221011.9910628796,19914.169314861298,44.8757050037384,0.0 -493900,3.2573893,1.1544243,,,,,,,,,,,,,, -494000,3.7663562,3.1946487,,,,,,,,,,,,,, -494100,2.973532,2.60972,,,,,,,,,,,,,, -494200,3.189311,1.1260817,,,,,,,,,,,,,, -494300,2.9414902,1.066304,,,,,,,,,,,,,, -494400,3.6039782,3.271679,,,,,,,,,,,,,, -494500,3.3587425,2.093392,,,,,,,,,,,,,, -494600,2.8625088,1.2782775,,,,,,,,,,,,,, -494700,2.7391644,1.3167546,,,,,,,,,,,,,, -494787,,,0.8898242115974426,0.4121768474578857,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,221432.0510094165,241454.4551520348,221432.0510094165,19950.32503771782,44.97604584693909,0.0 -494800,3.2182055,2.2565746,,,,,,,,,,,,,, -494900,3.1910539,2.7055385,,,,,,,,,,,,,, -495000,3.0116298,2.1416643,,,,,,,,,,,,,, -495100,2.996767,1.1082802,,,,,,,,,,,,,, -495200,3.2162168,2.6667316,,,,,,,,,,,,,, -495300,3.2763116,1.171508,,,,,,,,,,,,,, -495400,3.9953613,3.2089117,,,,,,,,,,,,,, -495500,2.9224608,1.254003,,,,,,,,,,,,,, -495600,2.8734684,1.609138,,,,,,,,,,,,,, -495700,2.987833,2.030909,,,,,,,,,,,,,, -495725,,,0.89013671875,0.4112144410610199,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,221852.1177001,241920.32781767845,221852.1177001,19995.958854675293,45.0961983203888,0.0 -495800,3.072844,1.1636409,,,,,,,,,,,,,, -495900,2.7642703,1.177887,,,,,,,,,,,,,, -496000,3.2414331,1.1327223,,,,,,,,,,,,,, -496100,3.020061,2.0976245,,,,,,,,,,,,,, -496200,2.8891947,1.3933805,,,,,,,,,,,,,, -496300,3.9977076,2.6402655,,,,,,,,,,,,,, -496400,3.0213685,1.0512997,,,,,,,,,,,,,, -496500,2.8992953,2.2057116,,,,,,,,,,,,,, -496600,3.0711873,1.0757804,,,,,,,,,,,,,, -496658,,,0.8859961032867432,0.4209813177585602,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,222272.18267846107,242390.0301771164,222272.18267846107,20045.42670416832,45.213632106781006,0.0 -496700,2.8332808,1.2792222,,,,,,,,,,,,,, -496800,3.2565153,1.2215278,,,,,,,,,,,,,, -496900,3.0386188,1.1358814,,,,,,,,,,,,,, -497000,3.0966582,1.1639369,,,,,,,,,,,,,, -497100,3.1181986,1.1462204,,,,,,,,,,,,,, -497200,3.2667122,1.365776,,,,,,,,,,,,,, -497300,3.3447254,1.031334,,,,,,,,,,,,,, -497400,3.0582652,1.1136352,,,,,,,,,,,,,, -497500,3.0137799,2.232336,,,,,,,,,,,,,, -497598,,,0.8863281011581421,0.4229179620742798,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,222692.19125318527,242846.0733227729,222692.19125318527,20081.30944299698,45.31370210647583,0.0 -497600,3.030146,1.2072954,,,,,,,,,,,,,, -497700,2.8459427,1.077801,,,,,,,,,,,,,, -497800,2.9062152,1.0823567,,,,,,,,,,,,,, -497900,2.9424074,1.1809119,,,,,,,,,,,,,, -498000,2.9095497,2.2732105,,,,,,,,,,,,,, -498100,3.7005723,3.2274365,,,,,,,,,,,,,, -498200,2.9283304,1.1512527,,,,,,,,,,,,,, -498300,2.901792,0.96668696,,,,,,,,,,,,,, -498400,3.0047204,1.2722964,,,,,,,,,,,,,, -498500,2.924187,2.3041553,,,,,,,,,,,,,, -498525,,,0.8882030844688416,0.4153804779052734,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,223112.1540651321,243308.27743959427,223112.1540651321,20123.3711745739,45.43738150596619,0.0 -498600,3.1780608,2.7959228,,,,,,,,,,,,,, -498700,3.3283935,1.1186435,,,,,,,,,,,,,, -498800,3.006465,1.0764322,,,,,,,,,,,,,, -498900,3.8217893,3.1918607,,,,,,,,,,,,,, -499000,3.023346,1.0378664,,,,,,,,,,,,,, -499100,3.4310763,1.0924914,,,,,,,,,,,,,, -499200,3.2204583,1.3461583,,,,,,,,,,,,,, -499300,2.9388134,1.5718557,,,,,,,,,,,,,, -499400,3.723959,3.2072606,,,,,,,,,,,,,, -499461,,,0.8899218440055847,0.4086339771747589,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,223532.21031570435,243771.78656983376,223532.21031570435,20166.64469361305,45.56512475013733,0.0 -499500,3.2403562,1.0945868,,,,,,,,,,,,,, -499600,3.1046524,1.2459886,,,,,,,,,,,,,, -499700,3.6897852,3.2411485,,,,,,,,,,,,,, -499800,3.1113093,1.7206014,,,,,,,,,,,,,, -499900,2.9492226,2.4813824,,,,,,,,,,,,,, -500000,3.0086179,1.3926473,,,,,,,,,,,,,, -500100,2.880747,1.1253705,,,,,,,,,,,,,, -500200,2.810569,1.1031848,,,,,,,,,,,,,, -500300,3.0117834,1.1279538,,,,,,,,,,,,,, -500399,,,0.8868749737739563,0.4158554673194885,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,223952.1179277897,244235.95967531204,223952.1179277897,20210.7332572937,45.69040513038635,0.0 -500400,2.7928112,1.9023725,,,,,,,,,,,,,, -500500,3.4015124,2.9512572,,,,,,,,,,,,,, -500600,2.9920747,1.3255613,,,,,,,,,,,,,, -500700,3.2504082,2.614409,,,,,,,,,,,,,, -500800,2.8088763,1.1011007,,,,,,,,,,,,,, -500900,3.0930123,1.1986706,,,,,,,,,,,,,, -501000,3.175391,1.0247436,,,,,,,,,,,,,, -501100,3.5253906,3.2165885,,,,,,,,,,,,,, -501200,2.9985366,1.4731379,,,,,,,,,,,,,, -501300,3.0441823,2.5980256,,,,,,,,,,,,,, -501340,,,0.8868359327316284,0.4167474210262298,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,224372.26052856445,244701.171677351,224372.26052856445,20255.60511660576,45.836000204086304,0.0 -501400,3.051706,1.1287186,,,,,,,,,,,,,, -501500,3.3940852,1.1798155,,,,,,,,,,,,,, -501600,3.0074866,2.5329826,,,,,,,,,,,,,, -501700,3.3953776,1.0957092,,,,,,,,,,,,,, -501800,3.2715902,2.8806653,,,,,,,,,,,,,, -501900,3.752301,3.1060147,,,,,,,,,,,,,, -502000,3.690816,3.1481726,,,,,,,,,,,,,, -502100,2.9455633,1.2542102,,,,,,,,,,,,,, -502200,3.131658,1.1635672,,,,,,,,,,,,,, -502278,,,0.8883398175239563,0.4172844886779785,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,224792.3078968525,245159.794072628,224792.3078968525,20294.00301527977,45.96119236946106,0.0 -502300,3.1943505,1.2021449,,,,,,,,,,,,,, -502400,3.1310134,2.5378082,,,,,,,,,,,,,, -502500,4.676652,1.8630313,,,,,,,,,,,,,, -502600,3.6684537,1.2449645,,,,,,,,,,,,,, -502700,3.2553947,1.1733645,,,,,,,,,,,,,, -502800,3.0164835,1.1145877,,,,,,,,,,,,,, -502900,3.2423723,2.6712365,,,,,,,,,,,,,, -503000,3.1830344,1.2482029,,,,,,,,,,,,,, -503100,2.8314416,1.4451071,,,,,,,,,,,,,, -503200,3.2530196,2.9468045,,,,,,,,,,,,,, -503214,,,0.8876562118530273,0.4177963137626648,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,225212.55551242828,245627.0100402832,225212.55551242828,20340.79037499428,46.09040856361389,0.0 -503300,3.0734003,1.2658578,,,,,,,,,,,,,, -503400,3.5662694,1.1283144,,,,,,,,,,,,,, -503500,2.7011526,1.556415,,,,,,,,,,,,,, -503600,3.4209619,1.0533795,,,,,,,,,,,,,, -503700,3.2642517,1.1490886,,,,,,,,,,,,,, -503800,3.1864796,1.1694375,,,,,,,,,,,,,, -503900,3.2301,1.857646,,,,,,,,,,,,,, -504000,3.5029979,3.2847712,,,,,,,,,,,,,, -504100,2.9352553,1.1819898,,,,,,,,,,,,,, -504153,,,0.8895312547683716,0.4146497249603271,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,225632.7815082073,246085.8487613201,225632.7815082073,20379.2294960022,46.2122004032135,0.0 -504200,2.917912,1.4935865,,,,,,,,,,,,,, -504300,3.0089633,2.2725613,,,,,,,,,,,,,, -504400,3.0847766,1.0449979,,,,,,,,,,,,,, -504500,2.99647,1.2722543,,,,,,,,,,,,,, -504600,2.8463995,1.2149179,,,,,,,,,,,,,, -504700,2.988931,1.1062006,,,,,,,,,,,,,, -504800,2.857023,1.1430845,,,,,,,,,,,,,, -504900,2.9708247,1.9361461,,,,,,,,,,,,,, -505000,3.0325246,2.5683968,,,,,,,,,,,,,, -505090,,,0.8879687190055847,0.4165275990962982,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,226052.95351052284,246552.4772992134,226052.95351052284,20425.49946808815,46.34702253341675,0.0 -505100,2.9816427,1.1644212,,,,,,,,,,,,,, -505200,3.1164515,1.581457,,,,,,,,,,,,,, -505300,3.0387104,2.3683028,,,,,,,,,,,,,, -505400,3.002981,1.4675338,,,,,,,,,,,,,, -505500,2.9915347,2.2522502,,,,,,,,,,,,,, -505600,3.0580835,1.1528792,,,,,,,,,,,,,, -505700,3.1160252,2.5958688,,,,,,,,,,,,,, -505800,3.0483518,1.1532323,,,,,,,,,,,,,, -505900,3.0088692,1.057315,,,,,,,,,,,,,, -506000,3.6534417,3.1812778,,,,,,,,,,,,,, -506027,,,0.8889062404632568,0.4119411408901214,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,226473.03325295448,247015.84028172493,226473.03325295448,20468.608820915226,46.46912431716919,0.0 -506100,3.0360923,1.0905672,,,,,,,,,,,,,, -506200,2.957547,1.0346202,,,,,,,,,,,,,, -506300,3.5186398,2.7421393,,,,,,,,,,,,,, -506400,3.082098,1.4553387,,,,,,,,,,,,,, -506500,3.7614176,3.2452192,,,,,,,,,,,,,, -506600,3.454736,2.5186558,,,,,,,,,,,,,, -506700,3.4858727,1.1019119,,,,,,,,,,,,,, -506800,2.7601078,1.6066318,,,,,,,,,,,,,, -506900,3.5662723,3.1100638,,,,,,,,,,,,,, -506965,,,0.8897460699081421,0.4139359295368194,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,226892.98617458344,247479.51724481583,226892.98617458344,20512.156184911728,46.59465575218201,0.0 -507000,2.9329538,1.2053916,,,,,,,,,,,,,, -507100,3.1996446,1.807498,,,,,,,,,,,,,, -507200,3.0197525,1.3660533,,,,,,,,,,,,,, -507300,3.6735287,3.132046,,,,,,,,,,,,,, -507400,2.911122,0.9964087,,,,,,,,,,,,,, -507500,3.2924461,2.81664,,,,,,,,,,,,,, -507600,3.3033319,1.3680899,,,,,,,,,,,,,, -507700,3.426175,2.8049543,,,,,,,,,,,,,, -507800,3.277531,1.5492743,,,,,,,,,,,,,, -507900,3.4730144,1.2835208,,,,,,,,,,,,,, -507908,,,0.8889257907867432,0.411639004945755,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,227312.9258275032,247939.57105445865,227312.9258275032,20552.11368203163,46.699103116989136,0.0 -508000,3.0415907,1.1260751,,,,,,,,,,,,,, -508100,2.7503498,1.8825456,,,,,,,,,,,,,, -508200,2.9845533,1.3490016,,,,,,,,,,,,,, -508300,3.1070127,1.1866,,,,,,,,,,,,,, -508400,3.1446357,1.0875748,,,,,,,,,,,,,, -508500,3.05146,1.5654074,,,,,,,,,,,,,, -508600,3.1215088,1.1295998,,,,,,,,,,,,,, -508700,3.0838392,1.6091312,,,,,,,,,,,,,, -508800,3.2432218,1.1616746,,,,,,,,,,,,,, -508845,,,0.8868945240974426,0.4153727889060974,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,227732.81487941745,248399.2088296413,227732.81487941745,20591.683132648468,46.82708263397217,0.0 -508900,3.039949,1.7944078,,,,,,,,,,,,,, -509000,3.1625116,2.2558463,,,,,,,,,,,,,, -509100,3.5111666,3.062877,,,,,,,,,,,,,, -509200,3.0346456,1.0450976,,,,,,,,,,,,,, -509300,3.8400805,3.2607512,,,,,,,,,,,,,, -509400,3.9678328,3.2151644,,,,,,,,,,,,,, -509500,2.8429224,2.2487786,,,,,,,,,,,,,, -509600,3.1074722,1.1312747,,,,,,,,,,,,,, -509700,3.7650793,2.654534,,,,,,,,,,,,,, -509782,,,0.8875781297683716,0.4177852272987366,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,228152.68364787105,248864.9745190144,228152.68364787105,20637.40304684639,46.95222425460816,0.0 -509800,2.9895298,1.078449,,,,,,,,,,,,,, -509900,2.922477,1.3830658,,,,,,,,,,,,,, -510000,3.5375562,2.9863958,,,,,,,,,,,,,, -510100,2.934099,1.0556185,,,,,,,,,,,,,, -510200,3.0183842,1.099219,,,,,,,,,,,,,, -510300,3.795883,3.173604,,,,,,,,,,,,,, -510400,2.8085582,1.0824178,,,,,,,,,,,,,, -510500,3.1715522,1.0830251,,,,,,,,,,,,,, -510600,3.2366507,1.1008435,,,,,,,,,,,,,, -510700,3.060717,1.160976,,,,,,,,,,,,,, -510718,,,0.8883788585662842,0.4152357876300812,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,228572.96571731567,249329.0675106049,228572.96571731567,20681.03459215164,47.079352617263794,0.0 -510800,2.8717458,0.9687259,,,,,,,,,,,,,, -510900,3.500607,3.202481,,,,,,,,,,,,,, -511000,3.6323028,3.1955194,,,,,,,,,,,,,, -511100,2.9725106,1.1025662,,,,,,,,,,,,,, -511200,3.0518649,1.0065211,,,,,,,,,,,,,, -511300,3.0191536,1.1325039,,,,,,,,,,,,,, -511400,3.4389648,2.7439895,,,,,,,,,,,,,, -511500,3.6214614,3.0609627,,,,,,,,,,,,,, -511600,3.0457926,1.0904158,,,,,,,,,,,,,, -511658,,,0.8885741829872131,0.4134371876716614,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,228993.1485798359,249792.2692861557,228993.1485798359,20723.862995624542,47.21655368804932,0.0 -511700,3.2886539,1.2607484,,,,,,,,,,,,,, -511800,3.1406772,1.2071828,,,,,,,,,,,,,, -511900,3.084071,1.1983981,,,,,,,,,,,,,, -512000,2.962286,1.1035861,,,,,,,,,,,,,, -512100,3.243941,1.182414,,,,,,,,,,,,,, -512200,3.3697293,2.847979,,,,,,,,,,,,,, -512300,2.8017466,1.0448143,,,,,,,,,,,,,, -512400,3.8948145,3.1560605,,,,,,,,,,,,,, -512500,2.8712509,1.7697282,,,,,,,,,,,,,, -512596,,,0.8901171684265137,0.4124380052089691,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,229413.08163452148,250255.3282535076,229413.08163452148,20766.809602737427,47.343963384628296,0.0 -512600,3.5900245,2.9470923,,,,,,,,,,,,,, -512700,3.2404444,2.9416318,,,,,,,,,,,,,, -512800,3.5160358,1.1214108,,,,,,,,,,,,,, -512900,3.3326466,1.1941935,,,,,,,,,,,,,, -513000,3.562846,2.9914293,,,,,,,,,,,,,, -513100,3.5277154,3.282976,,,,,,,,,,,,,, -513200,3.2158248,1.2730362,,,,,,,,,,,,,, -513300,3.1111586,2.0288424,,,,,,,,,,,,,, -513400,3.0200014,2.0524123,,,,,,,,,,,,,, -513500,3.1807249,1.1669003,,,,,,,,,,,,,, -513533,,,0.8909960985183716,0.4092914760112762,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,229833.1634716988,250715.27859401703,229833.1634716988,20806.49105668068,47.4794237613678,0.0 -513600,3.3830805,2.7545938,,,,,,,,,,,,,, -513700,3.230981,2.3823552,,,,,,,,,,,,,, -513800,2.8809178,1.4108425,,,,,,,,,,,,,, -513900,2.8896053,1.3432199,,,,,,,,,,,,,, -514000,3.5935524,1.7651806,,,,,,,,,,,,,, -514100,3.1055007,1.097156,,,,,,,,,,,,,, -514200,2.9924622,1.0259681,,,,,,,,,,,,,, -514300,3.1838384,2.6539655,,,,,,,,,,,,,, -514400,3.5036716,3.0907135,,,,,,,,,,,,,, -514474,,,0.8879101276397705,0.412520170211792,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,230252.9332678318,251174.10797166824,230252.9332678318,20845.29662656784,47.680049657821655,0.0 -514500,2.8644347,2.4194326,,,,,,,,,,,,,, -514600,2.9240472,1.1562593,,,,,,,,,,,,,, -514700,3.0875013,1.1196573,,,,,,,,,,,,,, -514800,2.9228098,1.4854398,,,,,,,,,,,,,, -514900,2.9669034,1.05601,,,,,,,,,,,,,, -515000,3.2310054,1.2233607,,,,,,,,,,,,,, -515100,3.771729,1.2077732,,,,,,,,,,,,,, -515200,3.0925438,2.5246575,,,,,,,,,,,,,, -515300,3.125197,1.1292903,,,,,,,,,,,,,, -515400,3.0259259,2.4043307,,,,,,,,,,,,,, -515414,,,0.8891406059265137,0.4178435504436493,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,230673.0644853115,251637.04307699203,230673.0644853115,20887.919131040573,47.809616804122925,0.0 -515500,3.1616633,1.3402667,,,,,,,,,,,,,, -515600,3.0715737,2.070654,,,,,,,,,,,,,, -515700,3.0214174,1.7006582,,,,,,,,,,,,,, -515800,3.6676195,2.9697003,,,,,,,,,,,,,, -515900,2.7634213,1.0550698,,,,,,,,,,,,,, -516000,3.1905835,2.6316779,,,,,,,,,,,,,, -516100,2.6849554,1.7159021,,,,,,,,,,,,,, -516200,2.9439664,1.79727,,,,,,,,,,,,,, -516300,3.025053,1.1072283,,,,,,,,,,,,,, -516348,,,0.8887499570846558,0.4133471250534057,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,231093.2486963272,252104.85655498505,231093.2486963272,20935.35671448708,47.94902086257935,0.0 -516400,2.9738085,1.14414,,,,,,,,,,,,,, -516500,3.0385468,1.9458575,,,,,,,,,,,,,, -516600,3.3559103,1.127338,,,,,,,,,,,,,, -516700,3.3571804,2.8681233,,,,,,,,,,,,,, -516800,3.1219206,1.1851567,,,,,,,,,,,,,, -516900,3.28946,1.9555537,,,,,,,,,,,,,, -517000,3.7199786,3.2662005,,,,,,,,,,,,,, -517100,3.3759468,1.2242292,,,,,,,,,,,,,, -517200,2.955778,1.7814708,,,,,,,,,,,,,, -517285,,,0.8858202695846558,0.4180032908916473,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,231513.2865588665,252563.2615237236,231513.2865588665,20973.56746840477,48.052809715271,0.0 -517300,3.3250618,1.1966648,,,,,,,,,,,,,, -517400,3.349494,2.771266,,,,,,,,,,,,,, -517500,3.086061,1.192211,,,,,,,,,,,,,, -517600,2.983475,1.058537,,,,,,,,,,,,,, -517700,3.405978,3.0105553,,,,,,,,,,,,,, -517800,3.2767625,2.1803098,,,,,,,,,,,,,, -517900,3.7245815,3.3173552,,,,,,,,,,,,,, -518000,3.1841452,1.1378831,,,,,,,,,,,,,, -518100,3.8110092,2.895718,,,,,,,,,,,,,, -518200,2.9823241,2.2573996,,,,,,,,,,,,,, -518218,,,0.8895898461341858,0.4120994210243225,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,231933.59404420853,253029.94160294533,231933.59404420853,21019.734656095505,48.206233501434326,0.0 -518300,3.06811,2.1411834,,,,,,,,,,,,,, -518400,4.0160327,3.230462,,,,,,,,,,,,,, -518500,3.1006677,1.0685307,,,,,,,,,,,,,, -518600,3.700114,3.3363562,,,,,,,,,,,,,, -518700,3.6368015,2.6688712,,,,,,,,,,,,,, -518800,2.8178682,1.2985231,,,,,,,,,,,,,, -518900,3.0830786,1.2180505,,,,,,,,,,,,,, -519000,2.8211153,2.2912226,,,,,,,,,,,,,, -519100,3.0902214,1.183214,,,,,,,,,,,,,, -519154,,,0.8883984088897705,0.4173663556575775,0.7806599736213684,0.8491368293762207,50000.0,0.6628000140190125,1.439961552619934,10000.0,232353.68795967105,253496.0437746048,232353.68795967105,21065.571618556976,48.32579302787781,0.0 -519200,3.8540664,2.8467891,,,,,,,,,,,,,, -519300,2.9611804,1.186893,,,,,,,,,,,,,, -519400,2.8241296,1.3509206,,,,,,,,,,,,,, -519500,3.33643,1.2433592,,,,,,,,,,,,,, -519600,2.9264195,1.6422079,,,,,,,,,,,,,, -519620,,,,,,,,,,,232560.4075872898,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 9951cbc2f..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,41 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -158.93743777275083,0.0,60.57989621162415,1,0,60.57989621162415,30.594215,2472,0.9323421282473138,219.51739048957825,31.541626,0.9678279157903504,30.454973,5348,0.9188333317242244 -288.5353765487671,0.0399515628814697,1501.106069564819,1805,0,1501.106069564819,2.872518,2472,0.5765441878414884,1789.7498643398285,2.7822611,0.5832732590471875,3.2041247,5348,0.6159668652306979 -425.7475950717926,0.0936610698699951,2941.681283712387,3615,0,2941.681283712387,0.62356925,2472,0.2060609753620539,3367.66596531868,0.573987,0.198955863942679,0.90380496,5348,0.2695868773955608 -556.6074635982513,0.1409566402435302,4382.141820192337,5427,0,4382.141820192337,0.48505977,2472,0.1613348770133853,4939.108454227448,0.40507546,0.14771415468792,0.73548084,5348,0.2199619606669434 -687.1414444446564,0.190387487411499,5822.766416788101,7196,0,5822.766416788101,0.41708505,2472,0.1412873479170475,6510.388904333115,0.3568339,0.1271612437034632,0.65461516,5348,0.1976886760574258 -816.9003918170929,0.2399988174438476,7263.083811998367,8962,0,7263.083811998367,0.3855535,2472,0.1295066317307497,8080.588052272797,0.32957414,0.1200993715888582,0.61736715,5348,0.1856010504262529 -948.8532621860504,0.2881746292114258,8703.315732717514,10764,0,8703.315732717514,0.35787582,2472,0.1226006946560234,9652.896077871324,0.32769698,0.11632373260212,0.575464,5348,0.172837599080877 -1081.4634454250336,0.3373198509216308,10143.879881858826,12528,0,10143.879881858826,0.34133533,2472,0.1163853512887697,11226.192072629929,0.2873813,0.1042849850041477,0.5578073,5348,0.1677013236529345 -1215.3142149448397,0.4638903141021728,11583.694440364838,14276,0,11583.694440364838,0.32584527,2472,0.1094997257936749,12800.056248426436,0.24911912,0.094017005122338,0.5366055,5348,0.1617733666740685 -1347.545287847519,0.5180697441101074,13023.93999671936,16052,0,13023.93999671936,0.31493405,2472,0.106574858326732,14372.66071677208,0.23145117,0.0860567433991775,0.52654713,5348,0.1585294032458943 -1478.269473552704,0.5703423023223877,14464.140059709547,17842,0,14464.140059709547,0.30740193,2472,0.1013954055206873,15943.71213197708,0.22695147,0.0869537237151882,0.50633556,5348,0.1534993290016123 -1611.12948012352,0.6190557479858398,15904.540014266968,19608,0,15904.540014266968,0.29291183,2472,0.1003595149594784,17517.09511733055,0.23605815,0.0877019487559163,0.5006235,5348,0.1504870772468791 -1743.9933621883392,0.6681835651397705,17344.48076581955,21376,0,17344.48076581955,0.29092786,2472,0.1000751528446367,19090.02177286148,0.23558149,0.0873611733980531,0.49014068,5348,0.1476872278594668 -1878.37513589859,0.7173256874084473,18784.482964992523,23149,0,18784.482964992523,0.27992237,2472,0.0942254179107509,20664.528064489365,0.21837407,0.0810937339232431,0.4816064,5348,0.1449259970842947 -2010.588121175766,0.7683911323547363,20224.87201428413,24904,0,20224.87201428413,0.2789931,2472,0.0962159527146426,22237.253833293915,0.20431072,0.0772032431062877,0.47413144,5348,0.1437481294109696 -2143.106283664704,0.8185839653015137,21665.095484256744,26668,0,21665.095484256744,0.27178425,2472,0.0921739483679645,23810.11845588684,0.1921019,0.0740134067137901,0.4618266,5348,0.1380132654933045 -2276.688288450241,0.8722891807556152,23105.360609531403,28445,0,23105.360609531403,0.26862025,2472,0.0901631019844413,25384.09341740608,0.18997656,0.0708338873490094,0.45477387,5348,0.1356671847997142 -2410.2117116451263,0.9327528476715088,24545.301498651505,30247,0,24545.301498651505,0.25948068,2472,0.0879897629638657,26957.69291400909,0.20043005,0.0753759317039348,0.45052484,5348,0.1365747221873582 -2541.3817863464355,0.986194372177124,25985.380284786224,32006,0,25985.380284786224,0.25645247,2472,0.08581642394329,28529.067094802856,0.18588035,0.0679493648907481,0.44561952,5348,0.1330025005551425 -2676.7134261131287,1.0375988483428955,27425.35506367684,33791,0,27425.35506367684,0.24580121,2472,0.0820384701318221,30104.498718500137,0.19237903,0.0706327150188268,0.42201337,5348,0.1253367060254689 -2808.16770529747,1.0907437801361084,28865.51515221596,35567,0,28865.51515221596,0.23764282,2472,0.0807791521946661,31676.24119448661,0.18364745,0.0660773468887539,0.41895315,5348,0.1253077420662888 -2940.807469367981,1.1460282802581787,30305.79084634781,37355,0,30305.79084634781,0.23045272,2472,0.0778949078869863,33249.286796569824,0.13816252,0.052642907659666,0.41114187,5348,0.1220830879442347 -3074.425539016724,1.1972649097442627,31745.960538864136,39120,0,31745.960538864136,0.22464955,2472,0.0757825036053053,34823.198424339294,0.15625998,0.0587632784946463,0.40062064,5348,0.1193315118221226 -3204.458496570587,1.2498650550842283,33186.37688279152,40882,0,33186.37688279152,0.2234842,2472,0.0749700404200434,36393.77374911308,0.19048192,0.0709795349392231,0.39756453,5348,0.1176902208019154 -3337.72132229805,1.3047235012054443,34626.39605259895,42676,0,34626.39605259895,0.21795496,2472,0.0721264192716267,37967.18592643738,0.19422482,0.0713012873535142,0.39274555,5348,0.1154793052511658 -3466.470259666443,1.3679873943328855,36066.54161071777,44442,0,36066.54161071777,0.21539187,2472,0.0701561960473666,39536.21631217003,0.22243957,0.0831407830313949,0.3876792,5348,0.1129401314963746 -3594.2017362117767,1.4228923320770264,37506.617604494095,46203,0,37506.617604494095,0.2095781,2472,0.0694249791806308,41104.15241241455,0.19375552,0.071080783519479,0.38640672,5348,0.1129208221902546 -3727.454374790192,1.476792573928833,38946.74871778488,47980,0,38946.74871778488,0.20401005,2472,0.0674344443767391,42677.6640393734,0.17661671,0.0667275338759422,0.3707086,5348,0.1087596667213763 -3858.368305444717,1.5332348346710205,40386.87485766411,49754,0,40386.87485766411,0.19967617,2472,0.0666829159303719,44248.83420395851,0.14905035,0.0575971462441135,0.36766133,5348,0.1074466339052106 -3988.840337753296,1.5893762111663818,41827.40071630478,51528,0,41827.40071630478,0.19805373,2472,0.0648142506042695,45819.96135163307,0.1598807,0.0607569775209502,0.36701316,5348,0.1054191567626017 -4120.620878696442,1.7224667072296145,43267.91893672943,53319,0,43267.91893672943,0.19280446,2472,0.0630674547559563,47392.46687364578,0.14107403,0.0538996735956021,0.34760937,5348,0.1007269953754211 -4252.91517329216,1.7831799983978271,44708.85677075386,55101,0,44708.85677075386,0.18519415,2472,0.0599191599130664,48965.83365011215,0.13622159,0.0528374141615769,0.3458523,5348,0.0996746381918765 -4384.476100206375,1.8459181785583496,46149.09987139702,56883,0,46149.09987139702,0.18261853,2472,0.0596144862185932,50537.77483201027,0.12630334,0.0483699102829537,0.33833632,5348,0.0968940981105844 -4514.391450881958,1.9014570713043213,47589.387810468674,58649,0,47589.387810468674,0.17599674,2472,0.0569739808664919,52108.10684657097,0.12676607,0.0491741355035884,0.3321127,5348,0.0947411104781949 -4647.504854440689,1.9610748291015625,49029.39686584473,60416,0,49029.39686584473,0.17387721,2472,0.0561005829423354,53681.36341094971,0.12744635,0.0481000664350515,0.32808104,5348,0.0939494289272715 -4779.186099052429,2.0234742164611816,50469.78848719597,62193,0,50469.78848719597,0.16972645,2472,0.0559990250441776,55253.57203221321,0.10352248,0.0405317551057009,0.3215525,5348,0.0921246994989235 -4910.448292016983,2.0826873779296875,51909.92304944992,63948,0,51909.92304944992,0.16616057,2472,0.0551865618589157,56825.101155519485,0.10351148,0.0399570991739309,0.32219872,5348,0.0916902401112216 -5034.910115480423,2.148332595825196,53350.49229979515,65821,0,53350.49229979515,0.16144967,2472,0.0532975849531818,58390.26790380478,0.09428704,0.0364069544257465,0.31000638,5348,0.0878187242341446 -5157.780250310898,2.2042810916900635,54790.55012321472,67697,0,54790.55012321472,0.16130331,2472,0.0517742164808157,59953.31989455223,0.09369294,0.0362907587808946,0.31026995,5348,0.0872490997036021 -5280.730521202087,2.2564854621887207,56230.60710000992,69572,0,56230.60710000992,0.15887609,2472,0.051489854365974044,61516.4474709034,0.08811778,0.03356924040160104,0.3047844,5348,0.08574297382623555 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index 9ab3418da..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,738 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,35.0391,32.3574,,,,,,,,,,,,,, -1,,,31.541626,0.9678279157903504,30.454973,0.9188333317242244,5348.0,30.594215,0.9323421282473138,2472.0,60.57989621162415,219.51739048957825,60.57989621162415,158.93743777275083,0.0,0.0 -100,0.90261984,5.9673257,,,,,,,,,,,,,, -200,0.39289755,5.801491,,,,,,,,,,,,,, -300,0.6292333,5.802974,,,,,,,,,,,,,, -400,0.47856697,5.80481,,,,,,,,,,,,,, -500,0.6150746,5.8129344,,,,,,,,,,,,,, -600,0.34866035,5.765746,,,,,,,,,,,,,, -700,0.8883076,5.5634923,,,,,,,,,,,,,, -800,2.7178638,5.39502,,,,,,,,,,,,,, -900,1.0246578,4.404475,,,,,,,,,,,,,, -1000,0.8829328,3.6507185,,,,,,,,,,,,,, -1100,0.7096878,3.2496815,,,,,,,,,,,,,, -1200,0.65788025,3.0005493,,,,,,,,,,,,,, -1300,0.6646586,2.9036899,,,,,,,,,,,,,, -1400,0.9343111,2.7121646,,,,,,,,,,,,,, -1500,0.80265236,2.6337447,,,,,,,,,,,,,, -1600,0.7257459,2.5716074,,,,,,,,,,,,,, -1700,0.8438702,2.533526,,,,,,,,,,,,,, -1800,1.412154,2.4422588,,,,,,,,,,,,,, -1805,,,2.7822611,0.5832732590471875,3.2041247,0.6159668652306979,5348.0,2.872518,0.5765441878414884,2472.0,1501.106069564819,1789.7498643398285,1501.106069564819,288.5353765487671,0.0399515628814697,0.0 -1900,0.72955906,2.3283472,,,,,,,,,,,,,, -2000,0.68852323,2.288214,,,,,,,,,,,,,, -2100,0.6525975,2.1766024,,,,,,,,,,,,,, -2200,0.57381254,2.2013395,,,,,,,,,,,,,, -2300,0.68188936,2.195754,,,,,,,,,,,,,, -2400,0.6074868,2.0993152,,,,,,,,,,,,,, -2500,0.7256944,2.0578492,,,,,,,,,,,,,, -2600,1.1522695,2.0637636,,,,,,,,,,,,,, -2700,0.5996774,2.0321846,,,,,,,,,,,,,, -2800,0.5815629,1.9975827,,,,,,,,,,,,,, -2900,0.7947796,1.9204006,,,,,,,,,,,,,, -3000,0.63038576,1.9258126,,,,,,,,,,,,,, -3100,0.5764696,1.848467,,,,,,,,,,,,,, -3200,0.5169864,1.8290073,,,,,,,,,,,,,, -3300,0.517343,1.8209718,,,,,,,,,,,,,, -3400,0.7269616,1.8595284,,,,,,,,,,,,,, -3500,0.6401705,1.7694975,,,,,,,,,,,,,, -3600,0.71178657,1.7968482,,,,,,,,,,,,,, -3615,,,0.573987,0.198955863942679,0.90380496,0.2695868773955608,5348.0,0.62356925,0.2060609753620539,2472.0,2941.681283712387,3367.66596531868,2941.681283712387,425.7475950717926,0.0936610698699951,0.0 -3700,0.5626636,1.7985893,,,,,,,,,,,,,, -3800,0.59633845,1.8274095,,,,,,,,,,,,,, -3900,0.55560446,1.8344136,,,,,,,,,,,,,, -4000,0.6197833,1.7849123,,,,,,,,,,,,,, -4100,0.63902557,1.7632611,,,,,,,,,,,,,, -4200,0.6086935,1.7635255,,,,,,,,,,,,,, -4300,0.5183243,1.7294993,,,,,,,,,,,,,, -4400,0.5154741,1.7566836,,,,,,,,,,,,,, -4500,0.52446693,1.7736559,,,,,,,,,,,,,, -4600,0.5666603,1.744893,,,,,,,,,,,,,, -4700,0.6568821,1.7817446,,,,,,,,,,,,,, -4800,0.7083831,1.7947009,,,,,,,,,,,,,, -4900,0.63919806,1.7349167,,,,,,,,,,,,,, -5000,0.52866644,1.6766285,,,,,,,,,,,,,, -5100,0.62795097,1.7653191,,,,,,,,,,,,,, -5200,0.49025536,1.6409779,,,,,,,,,,,,,, -5300,0.5437234,1.6721116,,,,,,,,,,,,,, -5400,0.5868972,1.6563315,,,,,,,,,,,,,, -5427,,,0.40507546,0.14771415468792,0.73548084,0.2199619606669434,5348.0,0.48505977,0.1613348770133853,2472.0,4382.141820192337,4939.108454227448,4382.141820192337,556.6074635982513,0.1409566402435302,0.0 -5500,0.6307139,1.6525173,,,,,,,,,,,,,, -5600,0.56562036,1.6741759,,,,,,,,,,,,,, -5700,0.5412961,1.7186397,,,,,,,,,,,,,, -5800,0.50051117,1.6735343,,,,,,,,,,,,,, -5900,0.4813559,1.637441,,,,,,,,,,,,,, -6000,0.64782685,1.658875,,,,,,,,,,,,,, -6100,0.6229418,1.7063544,,,,,,,,,,,,,, -6200,0.49522153,1.654023,,,,,,,,,,,,,, -6300,0.7595674,1.5549834,,,,,,,,,,,,,, -6400,0.5709535,1.6411104,,,,,,,,,,,,,, -6500,0.56203073,1.6810172,,,,,,,,,,,,,, -6600,0.50837183,1.6417643,,,,,,,,,,,,,, -6700,0.6520308,1.6049472,,,,,,,,,,,,,, -6800,0.57410043,1.6147432,,,,,,,,,,,,,, -6900,0.52551293,1.609624,,,,,,,,,,,,,, -7000,0.5317403,1.5749134,,,,,,,,,,,,,, -7100,0.5577173,1.5334249,,,,,,,,,,,,,, -7196,,,0.3568339,0.1271612437034632,0.65461516,0.1976886760574258,5348.0,0.41708505,0.1412873479170475,2472.0,5822.766416788101,6510.388904333115,5822.766416788101,687.1414444446564,0.190387487411499,0.0 -7200,0.4598992,1.5999497,,,,,,,,,,,,,, -7300,0.7002594,1.5795475,,,,,,,,,,,,,, -7400,0.71155715,1.5786519,,,,,,,,,,,,,, -7500,0.6037944,1.5691444,,,,,,,,,,,,,, -7600,0.51607835,1.5281754,,,,,,,,,,,,,, -7700,0.74853045,1.5908661,,,,,,,,,,,,,, -7800,0.5228055,1.6241064,,,,,,,,,,,,,, -7900,0.50845385,1.5502208,,,,,,,,,,,,,, -8000,0.4714107,1.6005161,,,,,,,,,,,,,, -8100,0.63382435,1.529306,,,,,,,,,,,,,, -8200,0.450249,1.4914038,,,,,,,,,,,,,, -8300,0.49493584,1.5189923,,,,,,,,,,,,,, -8400,0.53632325,1.5305318,,,,,,,,,,,,,, -8500,0.51424944,1.5133977,,,,,,,,,,,,,, -8600,0.49478272,1.5030026,,,,,,,,,,,,,, -8700,0.48432812,1.54563,,,,,,,,,,,,,, -8800,0.5313651,1.5257353,,,,,,,,,,,,,, -8900,0.5026076,1.5651253,,,,,,,,,,,,,, -8962,,,0.32957414,0.1200993715888582,0.61736715,0.1856010504262529,5348.0,0.3855535,0.1295066317307497,2472.0,7263.083811998367,8080.588052272797,7263.083811998367,816.9003918170929,0.2399988174438476,0.0 -9000,0.50462073,1.5763263,,,,,,,,,,,,,, -9100,0.546774,1.5423356,,,,,,,,,,,,,, -9200,0.4769396,1.5086527,,,,,,,,,,,,,, -9300,0.4178957,1.4428713,,,,,,,,,,,,,, -9400,0.51789695,1.4764587,,,,,,,,,,,,,, -9500,0.50633454,1.5190789,,,,,,,,,,,,,, -9600,0.49086452,1.5231268,,,,,,,,,,,,,, -9700,0.54502743,1.5016356,,,,,,,,,,,,,, -9800,0.4692094,1.560956,,,,,,,,,,,,,, -9900,0.6252619,1.540856,,,,,,,,,,,,,, -10000,0.42768532,1.4991709,,,,,,,,,,,,,, -10100,0.54503304,1.5138031,,,,,,,,,,,,,, -10200,0.4920462,1.4902493,,,,,,,,,,,,,, -10300,0.60638446,1.4323716,,,,,,,,,,,,,, -10400,0.5238204,1.4967273,,,,,,,,,,,,,, -10500,0.565641,1.4034745,,,,,,,,,,,,,, -10600,0.53410316,1.4847871,,,,,,,,,,,,,, -10700,0.5591997,1.4872459,,,,,,,,,,,,,, -10764,,,0.32769698,0.11632373260212,0.575464,0.172837599080877,5348.0,0.35787582,0.1226006946560234,2472.0,8703.315732717514,9652.896077871324,8703.315732717514,948.8532621860504,0.2881746292114258,0.0 -10800,0.44563198,1.4633541,,,,,,,,,,,,,, -10900,0.44731635,1.4623724,,,,,,,,,,,,,, -11000,0.5256799,1.4847934,,,,,,,,,,,,,, -11100,0.5034297,1.4922506,,,,,,,,,,,,,, -11200,0.5256343,1.4922825,,,,,,,,,,,,,, -11300,0.48488697,1.5263855,,,,,,,,,,,,,, -11400,0.48055625,1.4227314,,,,,,,,,,,,,, -11500,0.5081802,1.447571,,,,,,,,,,,,,, -11600,0.4973648,1.4461315,,,,,,,,,,,,,, -11700,0.36918318,1.4595543,,,,,,,,,,,,,, -11800,0.48804894,1.4645734,,,,,,,,,,,,,, -11900,0.5024477,1.420369,,,,,,,,,,,,,, -12000,0.56013423,1.4707564,,,,,,,,,,,,,, -12100,0.5151975,1.40681,,,,,,,,,,,,,, -12200,0.43675363,1.4531443,,,,,,,,,,,,,, -12300,0.46655953,1.4241036,,,,,,,,,,,,,, -12400,0.547909,1.3893704,,,,,,,,,,,,,, -12500,0.39958677,1.3862242,,,,,,,,,,,,,, -12528,,,0.2873813,0.1042849850041477,0.5578073,0.1677013236529345,5348.0,0.34133533,0.1163853512887697,2472.0,10143.879881858826,11226.192072629929,10143.879881858826,1081.4634454250336,0.3373198509216308,0.0 -12600,0.43393376,1.4254278,,,,,,,,,,,,,, -12700,0.5145584,1.4222094,,,,,,,,,,,,,, -12800,0.5470498,1.4456222,,,,,,,,,,,,,, -12900,0.5159827,1.4303049,,,,,,,,,,,,,, -13000,0.48942313,1.4699172,,,,,,,,,,,,,, -13100,0.5429971,1.4918095,,,,,,,,,,,,,, -13200,0.6005113,1.4337747,,,,,,,,,,,,,, -13300,0.4631172,1.4807189,,,,,,,,,,,,,, -13400,0.46957117,1.3816596,,,,,,,,,,,,,, -13500,0.49809217,1.4789419,,,,,,,,,,,,,, -13600,0.5645972,1.4720798,,,,,,,,,,,,,, -13700,0.49624142,1.4398143,,,,,,,,,,,,,, -13800,0.46096766,1.3775426,,,,,,,,,,,,,, -13900,0.46327567,1.3998731,,,,,,,,,,,,,, -14000,0.516432,1.4406124,,,,,,,,,,,,,, -14100,0.47446606,1.3981278,,,,,,,,,,,,,, -14200,0.52053607,1.456081,,,,,,,,,,,,,, -14276,,,0.24911912,0.094017005122338,0.5366055,0.1617733666740685,5348.0,0.32584527,0.1094997257936749,2472.0,11583.694440364838,12800.056248426436,11583.694440364838,1215.3142149448397,0.4638903141021728,0.0 -14300,0.47650972,1.39717,,,,,,,,,,,,,, -14400,0.5897717,1.4308167,,,,,,,,,,,,,, -14500,0.47317594,1.4082462,,,,,,,,,,,,,, -14600,0.5021805,1.4380358,,,,,,,,,,,,,, -14700,0.56415164,1.3680736,,,,,,,,,,,,,, -14800,0.54324305,1.4195473,,,,,,,,,,,,,, -14900,0.47327894,1.4365588,,,,,,,,,,,,,, -15000,0.5938675,1.490636,,,,,,,,,,,,,, -15100,0.41734686,1.4042063,,,,,,,,,,,,,, -15200,0.60508746,1.4109313,,,,,,,,,,,,,, -15300,0.5440219,1.435786,,,,,,,,,,,,,, -15400,0.40008253,1.3955532,,,,,,,,,,,,,, -15500,0.5957125,1.3471062,,,,,,,,,,,,,, -15600,0.63641536,1.4024149,,,,,,,,,,,,,, -15700,0.4918714,1.396035,,,,,,,,,,,,,, -15800,0.41900283,1.3685601,,,,,,,,,,,,,, -15900,0.6791009,1.4088327,,,,,,,,,,,,,, -16000,0.47342655,1.395842,,,,,,,,,,,,,, -16052,,,0.23145117,0.0860567433991775,0.52654713,0.1585294032458943,5348.0,0.31493405,0.106574858326732,2472.0,13023.93999671936,14372.66071677208,13023.93999671936,1347.545287847519,0.5180697441101074,0.0 -16100,0.486349,1.4003732,,,,,,,,,,,,,, -16200,0.4608019,1.3513888,,,,,,,,,,,,,, -16300,0.54466206,1.4158913,,,,,,,,,,,,,, -16400,0.52187735,1.385195,,,,,,,,,,,,,, -16500,0.45318663,1.3586466,,,,,,,,,,,,,, -16600,0.54995394,1.3953688,,,,,,,,,,,,,, -16700,0.5000411,1.3607949,,,,,,,,,,,,,, -16800,0.5717773,1.3911924,,,,,,,,,,,,,, -16900,0.44618785,1.3411566,,,,,,,,,,,,,, -17000,0.43145144,1.3577403,,,,,,,,,,,,,, -17100,0.44062513,1.3423796,,,,,,,,,,,,,, -17200,0.48949298,1.3777264,,,,,,,,,,,,,, -17300,0.5217628,1.398453,,,,,,,,,,,,,, -17400,0.5607729,1.3784686,,,,,,,,,,,,,, -17500,0.4551724,1.3545879,,,,,,,,,,,,,, -17600,0.43053263,1.2896982,,,,,,,,,,,,,, -17700,0.46825945,1.3735826,,,,,,,,,,,,,, -17800,0.4640824,1.3492366,,,,,,,,,,,,,, -17842,,,0.22695147,0.0869537237151882,0.50633556,0.1534993290016123,5348.0,0.30740193,0.1013954055206873,2472.0,14464.140059709547,15943.71213197708,14464.140059709547,1478.269473552704,0.5703423023223877,0.0 -17900,0.44352123,1.3704306,,,,,,,,,,,,,, -18000,0.5901009,1.4178962,,,,,,,,,,,,,, -18100,0.52468556,1.4226365,,,,,,,,,,,,,, -18200,0.48887968,1.3676378,,,,,,,,,,,,,, -18300,0.6223121,1.3369474,,,,,,,,,,,,,, -18400,0.44663194,1.4186459,,,,,,,,,,,,,, -18500,0.66532815,1.3682716,,,,,,,,,,,,,, -18600,0.50933576,1.3622223,,,,,,,,,,,,,, -18700,0.4679883,1.3249733,,,,,,,,,,,,,, -18800,0.5146191,1.3588195,,,,,,,,,,,,,, -18900,0.59454095,1.3624076,,,,,,,,,,,,,, -19000,0.4553181,1.3429615,,,,,,,,,,,,,, -19100,0.49256825,1.3379052,,,,,,,,,,,,,, -19200,0.507948,1.4259784,,,,,,,,,,,,,, -19300,0.3784107,1.3190571,,,,,,,,,,,,,, -19400,0.4374479,1.3939848,,,,,,,,,,,,,, -19500,0.47657493,1.3833778,,,,,,,,,,,,,, -19600,0.50543934,1.3411268,,,,,,,,,,,,,, -19608,,,0.23605815,0.0877019487559163,0.5006235,0.1504870772468791,5348.0,0.29291183,0.1003595149594784,2472.0,15904.540014266968,17517.09511733055,15904.540014266968,1611.12948012352,0.6190557479858398,0.0 -19700,0.5813026,1.3311087,,,,,,,,,,,,,, -19800,0.48643237,1.2746724,,,,,,,,,,,,,, -19900,0.59990513,1.3928659,,,,,,,,,,,,,, -20000,0.49195036,1.365126,,,,,,,,,,,,,, -20100,0.51994264,1.3740405,,,,,,,,,,,,,, -20200,0.5018885,1.3793433,,,,,,,,,,,,,, -20300,0.46673095,1.3361421,,,,,,,,,,,,,, -20400,0.506895,1.3564228,,,,,,,,,,,,,, -20500,0.534094,1.3551466,,,,,,,,,,,,,, -20600,0.4933508,1.3943061,,,,,,,,,,,,,, -20700,0.4728099,1.3429475,,,,,,,,,,,,,, -20800,0.49439985,1.3274959,,,,,,,,,,,,,, -20900,0.48759302,1.3297808,,,,,,,,,,,,,, -21000,0.500553,1.2937101,,,,,,,,,,,,,, -21100,0.63941765,1.3632337,,,,,,,,,,,,,, -21200,0.47831836,1.3882457,,,,,,,,,,,,,, -21300,0.43526405,1.3377986,,,,,,,,,,,,,, -21376,,,0.23558149,0.0873611733980531,0.49014068,0.1476872278594668,5348.0,0.29092786,0.1000751528446367,2472.0,17344.48076581955,19090.02177286148,17344.48076581955,1743.9933621883392,0.6681835651397705,0.0 -21400,0.47366503,1.3706836,,,,,,,,,,,,,, -21500,0.5524784,1.3586389,,,,,,,,,,,,,, -21600,0.47183165,1.3071822,,,,,,,,,,,,,, -21700,0.7219111,1.311807,,,,,,,,,,,,,, -21800,0.5021097,1.2853891,,,,,,,,,,,,,, -21900,0.6028041,1.3361079,,,,,,,,,,,,,, -22000,0.42230898,1.3301069,,,,,,,,,,,,,, -22100,0.48685065,1.276782,,,,,,,,,,,,,, -22200,0.49496815,1.3296453,,,,,,,,,,,,,, -22300,0.5645669,1.371977,,,,,,,,,,,,,, -22400,0.7551897,1.3485572,,,,,,,,,,,,,, -22500,0.45109674,1.3209904,,,,,,,,,,,,,, -22600,0.50677925,1.3085755,,,,,,,,,,,,,, -22700,0.56083924,1.3814418,,,,,,,,,,,,,, -22800,0.5094942,1.3126668,,,,,,,,,,,,,, -22900,0.53715456,1.3196629,,,,,,,,,,,,,, -23000,0.4130628,1.3141044,,,,,,,,,,,,,, -23100,0.52964455,1.3691838,,,,,,,,,,,,,, -23149,,,0.21837407,0.0810937339232431,0.4816064,0.1449259970842947,5348.0,0.27992237,0.0942254179107509,2472.0,18784.482964992523,20664.528064489365,18784.482964992523,1878.37513589859,0.7173256874084473,0.0 -23200,0.64542687,1.3566935,,,,,,,,,,,,,, -23300,0.49045125,1.3141085,,,,,,,,,,,,,, -23400,0.45575887,1.3634398,,,,,,,,,,,,,, -23500,0.49560615,1.3000287,,,,,,,,,,,,,, -23600,0.50718105,1.3391699,,,,,,,,,,,,,, -23700,0.41803625,1.3532741,,,,,,,,,,,,,, -23800,0.5617854,1.3294024,,,,,,,,,,,,,, -23900,0.40744835,1.2800316,,,,,,,,,,,,,, -24000,0.5631581,1.3442453,,,,,,,,,,,,,, -24100,0.5329604,1.3578286,,,,,,,,,,,,,, -24200,0.516899,1.3671391,,,,,,,,,,,,,, -24300,0.4990704,1.3091183,,,,,,,,,,,,,, -24400,0.54110545,1.367752,,,,,,,,,,,,,, -24500,0.45545015,1.2817336,,,,,,,,,,,,,, -24600,0.5182217,1.3137305,,,,,,,,,,,,,, -24700,0.55321234,1.4023556,,,,,,,,,,,,,, -24800,0.6058503,1.3446056,,,,,,,,,,,,,, -24900,0.5304269,1.3547602,,,,,,,,,,,,,, -24904,,,0.20431072,0.0772032431062877,0.47413144,0.1437481294109696,5348.0,0.2789931,0.0962159527146426,2472.0,20224.87201428413,22237.253833293915,20224.87201428413,2010.588121175766,0.7683911323547363,0.0 -25000,0.66775054,1.3001324,,,,,,,,,,,,,, -25100,0.582858,1.2559516,,,,,,,,,,,,,, -25200,0.41100353,1.3264469,,,,,,,,,,,,,, -25300,0.54394007,1.3364944,,,,,,,,,,,,,, -25400,0.5656612,1.2967043,,,,,,,,,,,,,, -25500,0.41537482,1.2880788,,,,,,,,,,,,,, -25600,0.4426981,1.292354,,,,,,,,,,,,,, -25700,0.52072513,1.3077856,,,,,,,,,,,,,, -25800,0.46089,1.2918359,,,,,,,,,,,,,, -25900,0.4849932,1.2688578,,,,,,,,,,,,,, -26000,0.5196414,1.3497103,,,,,,,,,,,,,, -26100,0.42127022,1.2867258,,,,,,,,,,,,,, -26200,0.5626257,1.2777954,,,,,,,,,,,,,, -26300,0.5392251,1.3325874,,,,,,,,,,,,,, -26400,0.6675347,1.2966712,,,,,,,,,,,,,, -26500,0.4980313,1.2962472,,,,,,,,,,,,,, -26600,0.6054349,1.3355474,,,,,,,,,,,,,, -26668,,,0.1921019,0.0740134067137901,0.4618266,0.1380132654933045,5348.0,0.27178425,0.0921739483679645,2472.0,21665.095484256744,23810.11845588684,21665.095484256744,2143.106283664704,0.8185839653015137,0.0 -26700,0.63307476,1.3228002,,,,,,,,,,,,,, -26800,0.49996904,1.3393378,,,,,,,,,,,,,, -26900,0.47156215,1.2814575,,,,,,,,,,,,,, -27000,0.461485,1.2890474,,,,,,,,,,,,,, -27100,0.5450224,1.3554906,,,,,,,,,,,,,, -27200,0.626582,1.290676,,,,,,,,,,,,,, -27300,0.541295,1.2862971,,,,,,,,,,,,,, -27400,0.5821884,1.2865982,,,,,,,,,,,,,, -27500,0.4863534,1.3141629,,,,,,,,,,,,,, -27600,0.46635985,1.2730907,,,,,,,,,,,,,, -27700,0.64074326,1.290421,,,,,,,,,,,,,, -27800,0.46359533,1.2331051,,,,,,,,,,,,,, -27900,0.5544126,1.268906,,,,,,,,,,,,,, -28000,0.50903505,1.2829431,,,,,,,,,,,,,, -28100,0.46664637,1.19413,,,,,,,,,,,,,, -28200,0.5375349,1.3283883,,,,,,,,,,,,,, -28300,0.5758506,1.322685,,,,,,,,,,,,,, -28400,0.4775248,1.2388391,,,,,,,,,,,,,, -28445,,,0.18997656,0.0708338873490094,0.45477387,0.1356671847997142,5348.0,0.26862025,0.0901631019844413,2472.0,23105.360609531403,25384.09341740608,23105.360609531403,2276.688288450241,0.8722891807556152,0.0 -28500,0.49920592,1.2543881,,,,,,,,,,,,,, -28600,0.41871965,1.2732948,,,,,,,,,,,,,, -28700,0.5511052,1.3213941,,,,,,,,,,,,,, -28800,0.5450116,1.3337727,,,,,,,,,,,,,, -28900,0.46465987,1.2623376,,,,,,,,,,,,,, -29000,0.5569444,1.2169878,,,,,,,,,,,,,, -29100,0.48639798,1.2031507,,,,,,,,,,,,,, -29200,0.58714217,1.242501,,,,,,,,,,,,,, -29300,0.612474,1.2817732,,,,,,,,,,,,,, -29400,0.4780691,1.2670803,,,,,,,,,,,,,, -29500,0.6671291,1.2638391,,,,,,,,,,,,,, -29600,0.48385176,1.240161,,,,,,,,,,,,,, -29700,0.6690325,1.2948867,,,,,,,,,,,,,, -29800,0.6381995,1.328214,,,,,,,,,,,,,, -29900,0.5076513,1.2703861,,,,,,,,,,,,,, -30000,0.51802784,1.3140895,,,,,,,,,,,,,, -30100,0.54259485,1.3085457,,,,,,,,,,,,,, -30200,0.44435138,1.2768329,,,,,,,,,,,,,, -30247,,,0.20043005,0.0753759317039348,0.45052484,0.1365747221873582,5348.0,0.25948068,0.0879897629638657,2472.0,24545.301498651505,26957.69291400909,24545.301498651505,2410.2117116451263,0.9327528476715088,0.0 -30300,0.6284059,1.2912713,,,,,,,,,,,,,, -30400,0.6182849,1.2491194,,,,,,,,,,,,,, -30500,0.51684016,1.2886935,,,,,,,,,,,,,, -30600,0.56198174,1.2167364,,,,,,,,,,,,,, -30700,0.50622344,1.2288127,,,,,,,,,,,,,, -30800,0.535174,1.2544682,,,,,,,,,,,,,, -30900,0.4558795,1.2522033,,,,,,,,,,,,,, -31000,0.4856248,1.2718492,,,,,,,,,,,,,, -31100,0.4979454,1.2243358,,,,,,,,,,,,,, -31200,0.52292246,1.2646616,,,,,,,,,,,,,, -31300,0.5830027,1.2601582,,,,,,,,,,,,,, -31400,0.59823084,1.2616494,,,,,,,,,,,,,, -31500,0.5549123,1.2901422,,,,,,,,,,,,,, -31600,0.6336303,1.2374486,,,,,,,,,,,,,, -31700,0.4929051,1.2342625,,,,,,,,,,,,,, -31800,0.5257593,1.27535,,,,,,,,,,,,,, -31900,0.54690063,1.2747669,,,,,,,,,,,,,, -32000,0.49248588,1.2151388,,,,,,,,,,,,,, -32006,,,0.18588035,0.0679493648907481,0.44561952,0.1330025005551425,5348.0,0.25645247,0.08581642394329,2472.0,25985.380284786224,28529.067094802856,25985.380284786224,2541.3817863464355,0.986194372177124,0.0 -32100,0.5763523,1.2164196,,,,,,,,,,,,,, -32200,0.618727,1.1404662,,,,,,,,,,,,,, -32300,0.4986632,1.2107567,,,,,,,,,,,,,, -32400,0.49734217,1.2271445,,,,,,,,,,,,,, -32500,0.6062756,1.2972162,,,,,,,,,,,,,, -32600,0.49453944,1.3050284,,,,,,,,,,,,,, -32700,0.57726544,1.2673731,,,,,,,,,,,,,, -32800,0.576437,1.2212454,,,,,,,,,,,,,, -32900,0.5489051,1.2637308,,,,,,,,,,,,,, -33000,0.4872215,1.2046909,,,,,,,,,,,,,, -33100,0.44119543,1.2371641,,,,,,,,,,,,,, -33200,0.47640955,1.2344584,,,,,,,,,,,,,, -33300,0.51201314,1.2371776,,,,,,,,,,,,,, -33400,0.55890656,1.2342535,,,,,,,,,,,,,, -33500,0.48899454,1.201975,,,,,,,,,,,,,, -33600,0.47351092,1.210739,,,,,,,,,,,,,, -33700,0.60884774,1.1604315,,,,,,,,,,,,,, -33791,,,0.19237903,0.0706327150188268,0.42201337,0.1253367060254689,5348.0,0.24580121,0.0820384701318221,2472.0,27425.35506367684,30104.498718500137,27425.35506367684,2676.7134261131287,1.0375988483428955,0.0 -33800,0.5465655,1.2112024,,,,,,,,,,,,,, -33900,0.4741028,1.2270867,,,,,,,,,,,,,, -34000,0.51690805,1.2307491,,,,,,,,,,,,,, -34100,0.5882527,1.2358044,,,,,,,,,,,,,, -34200,0.5258826,1.2563947,,,,,,,,,,,,,, -34300,0.55592906,1.2331148,,,,,,,,,,,,,, -34400,0.5672493,1.1683902,,,,,,,,,,,,,, -34500,0.5665178,1.1815917,,,,,,,,,,,,,, -34600,0.5991157,1.2237946,,,,,,,,,,,,,, -34700,0.5338749,1.2092937,,,,,,,,,,,,,, -34800,0.5314232,1.2394661,,,,,,,,,,,,,, -34900,0.632355,1.305105,,,,,,,,,,,,,, -35000,0.51314366,1.2072313,,,,,,,,,,,,,, -35100,0.56266063,1.1349106,,,,,,,,,,,,,, -35200,0.6687916,1.2172222,,,,,,,,,,,,,, -35300,0.5925958,1.1757313,,,,,,,,,,,,,, -35400,0.5791871,1.2628098,,,,,,,,,,,,,, -35500,0.5581077,1.250383,,,,,,,,,,,,,, -35567,,,0.18364745,0.0660773468887539,0.41895315,0.1253077420662888,5348.0,0.23764282,0.0807791521946661,2472.0,28865.51515221596,31676.24119448661,28865.51515221596,2808.16770529747,1.0907437801361084,0.0 -35600,0.5508258,1.1985103,,,,,,,,,,,,,, -35700,0.5803658,1.1849062,,,,,,,,,,,,,, -35800,0.5560189,1.175164,,,,,,,,,,,,,, -35900,0.44506058,1.2208407,,,,,,,,,,,,,, -36000,0.5793242,1.2094177,,,,,,,,,,,,,, -36100,0.4408404,1.1789137,,,,,,,,,,,,,, -36200,0.4988403,1.1444516,,,,,,,,,,,,,, -36300,0.5025182,1.203804,,,,,,,,,,,,,, -36400,0.54891545,1.1968406,,,,,,,,,,,,,, -36500,0.50842756,1.1852386,,,,,,,,,,,,,, -36600,0.49469426,1.2002202,,,,,,,,,,,,,, -36700,0.5489891,1.2120304,,,,,,,,,,,,,, -36800,0.51799786,1.2342256,,,,,,,,,,,,,, -36900,0.56342626,1.186144,,,,,,,,,,,,,, -37000,0.54435194,1.1715841,,,,,,,,,,,,,, -37100,0.6446187,1.2295084,,,,,,,,,,,,,, -37200,0.50184816,1.1444685,,,,,,,,,,,,,, -37300,0.6351172,1.2205265,,,,,,,,,,,,,, -37355,,,0.13816252,0.052642907659666,0.41114187,0.1220830879442347,5348.0,0.23045272,0.0778949078869863,2472.0,30305.79084634781,33249.286796569824,30305.79084634781,2940.807469367981,1.1460282802581787,0.0 -37400,0.7489773,1.1706239,,,,,,,,,,,,,, -37500,0.5030597,1.203332,,,,,,,,,,,,,, -37600,0.60572577,1.2086467,,,,,,,,,,,,,, -37700,0.5482968,1.1948644,,,,,,,,,,,,,, -37800,0.53435946,1.2316219,,,,,,,,,,,,,, -37900,0.4736796,1.1794326,,,,,,,,,,,,,, -38000,0.61581737,1.1987201,,,,,,,,,,,,,, -38100,0.5631946,1.2330226,,,,,,,,,,,,,, -38200,0.5857913,1.1922829,,,,,,,,,,,,,, -38300,0.48320395,1.1675187,,,,,,,,,,,,,, -38400,0.60794157,1.2336085,,,,,,,,,,,,,, -38500,0.60728604,1.197512,,,,,,,,,,,,,, -38600,0.51737577,1.1651318,,,,,,,,,,,,,, -38700,0.53741395,1.1894889,,,,,,,,,,,,,, -38800,0.6113266,1.2366153,,,,,,,,,,,,,, -38900,0.46062842,1.1635767,,,,,,,,,,,,,, -39000,0.48599836,1.1884223,,,,,,,,,,,,,, -39100,0.6114853,1.2096583,,,,,,,,,,,,,, -39120,,,0.15625998,0.0587632784946463,0.40062064,0.1193315118221226,5348.0,0.22464955,0.0757825036053053,2472.0,31745.960538864136,34823.198424339294,31745.960538864136,3074.425539016724,1.1972649097442627,0.0 -39200,0.5585569,1.1333877,,,,,,,,,,,,,, -39300,0.5241933,1.2015088,,,,,,,,,,,,,, -39400,0.5984064,1.1741142,,,,,,,,,,,,,, -39500,0.47610766,1.1716917,,,,,,,,,,,,,, -39600,0.5426501,1.1811355,,,,,,,,,,,,,, -39700,0.7643702,1.24123,,,,,,,,,,,,,, -39800,0.54780126,1.2321879,,,,,,,,,,,,,, -39900,0.56600696,1.2006326,,,,,,,,,,,,,, -40000,0.58462805,1.1746674,,,,,,,,,,,,,, -40100,0.5850744,1.184752,,,,,,,,,,,,,, -40200,0.59955406,1.1792523,,,,,,,,,,,,,, -40300,0.59028137,1.1694157,,,,,,,,,,,,,, -40400,0.5699181,1.1557068,,,,,,,,,,,,,, -40500,0.60379213,1.1812831,,,,,,,,,,,,,, -40600,0.55865437,1.1461841,,,,,,,,,,,,,, -40700,0.5849607,1.1630148,,,,,,,,,,,,,, -40800,0.6021699,1.1364584,,,,,,,,,,,,,, -40882,,,0.19048192,0.0709795349392231,0.39756453,0.1176902208019154,5348.0,0.2234842,0.0749700404200434,2472.0,33186.37688279152,36393.77374911308,33186.37688279152,3204.458496570587,1.2498650550842283,0.0 -40900,0.6250841,1.1312335,,,,,,,,,,,,,, -41000,0.6015499,1.1900206,,,,,,,,,,,,,, -41100,0.6834896,1.1700279,,,,,,,,,,,,,, -41200,0.5517018,1.1795108,,,,,,,,,,,,,, -41300,0.5435861,1.152324,,,,,,,,,,,,,, -41400,0.5110045,1.1478802,,,,,,,,,,,,,, -41500,0.51086426,1.1247308,,,,,,,,,,,,,, -41600,0.584653,1.119794,,,,,,,,,,,,,, -41700,0.47455254,1.1793076,,,,,,,,,,,,,, -41800,0.5733223,1.1869863,,,,,,,,,,,,,, -41900,0.6333204,1.2019598,,,,,,,,,,,,,, -42000,0.5064792,1.1998297,,,,,,,,,,,,,, -42100,0.52117205,1.1325713,,,,,,,,,,,,,, -42200,0.5456372,1.1500157,,,,,,,,,,,,,, -42300,0.4751494,1.152253,,,,,,,,,,,,,, -42400,0.6157855,1.14187,,,,,,,,,,,,,, -42500,0.49425563,1.1031835,,,,,,,,,,,,,, -42600,0.51358974,1.131817,,,,,,,,,,,,,, -42676,,,0.19422482,0.0713012873535142,0.39274555,0.1154793052511658,5348.0,0.21795496,0.0721264192716267,2472.0,34626.39605259895,37967.18592643738,34626.39605259895,3337.72132229805,1.3047235012054443,0.0 -42700,0.46381253,1.148723,,,,,,,,,,,,,, -42800,0.6342555,1.1726483,,,,,,,,,,,,,, -42900,0.5712633,1.1693071,,,,,,,,,,,,,, -43000,0.6586926,1.1931026,,,,,,,,,,,,,, -43100,0.56641275,1.0904685,,,,,,,,,,,,,, -43200,0.57991517,1.2013402,,,,,,,,,,,,,, -43300,0.65388924,1.1592456,,,,,,,,,,,,,, -43400,0.48297176,1.1244898,,,,,,,,,,,,,, -43500,0.5431928,1.1072367,,,,,,,,,,,,,, -43600,0.57631713,1.1552111,,,,,,,,,,,,,, -43700,0.5326627,1.1303345,,,,,,,,,,,,,, -43800,0.5269786,1.1925569,,,,,,,,,,,,,, -43900,0.47253525,1.196385,,,,,,,,,,,,,, -44000,0.6822838,1.1787806,,,,,,,,,,,,,, -44100,0.5162412,1.1012863,,,,,,,,,,,,,, -44200,0.6812472,1.1347907,,,,,,,,,,,,,, -44300,0.5965022,1.174499,,,,,,,,,,,,,, -44400,0.5929595,1.1486813,,,,,,,,,,,,,, -44442,,,0.22243957,0.0831407830313949,0.3876792,0.1129401314963746,5348.0,0.21539187,0.0701561960473666,2472.0,36066.54161071777,39536.21631217003,36066.54161071777,3466.470259666443,1.3679873943328855,0.0 -44500,0.6021409,1.1200476,,,,,,,,,,,,,, -44600,0.5478141,1.135511,,,,,,,,,,,,,, -44700,0.69384915,1.113833,,,,,,,,,,,,,, -44800,0.6014143,1.1189512,,,,,,,,,,,,,, -44900,0.6008007,1.1165193,,,,,,,,,,,,,, -45000,0.6113231,1.0981566,,,,,,,,,,,,,, -45100,0.5630897,1.1085764,,,,,,,,,,,,,, -45200,0.5187151,1.0726749,,,,,,,,,,,,,, -45300,0.57765037,1.1404183,,,,,,,,,,,,,, -45400,0.5848035,1.137502,,,,,,,,,,,,,, -45500,0.6303235,1.1356165,,,,,,,,,,,,,, -45600,0.5652646,1.1006178,,,,,,,,,,,,,, -45700,0.5808573,1.105018,,,,,,,,,,,,,, -45800,0.4981319,1.0617934,,,,,,,,,,,,,, -45900,0.5805622,1.1026087,,,,,,,,,,,,,, -46000,0.5660747,1.0883112,,,,,,,,,,,,,, -46100,0.5629501,1.1790677,,,,,,,,,,,,,, -46200,0.67111856,1.1448433,,,,,,,,,,,,,, -46203,,,0.19375552,0.071080783519479,0.38640672,0.1129208221902546,5348.0,0.2095781,0.0694249791806308,2472.0,37506.617604494095,41104.15241241455,37506.617604494095,3594.2017362117767,1.4228923320770264,0.0 -46300,0.6664611,1.0909489,,,,,,,,,,,,,, -46400,0.5795465,1.0894104,,,,,,,,,,,,,, -46500,0.619648,1.0977464,,,,,,,,,,,,,, -46600,0.6029849,1.1316577,,,,,,,,,,,,,, -46700,0.6568118,1.089436,,,,,,,,,,,,,, -46800,0.5165538,1.1194761,,,,,,,,,,,,,, -46900,0.5754444,1.0987476,,,,,,,,,,,,,, -47000,0.62754834,1.0869424,,,,,,,,,,,,,, -47100,0.5613863,1.110381,,,,,,,,,,,,,, -47200,0.6238162,1.0803801,,,,,,,,,,,,,, -47300,0.55324435,1.111345,,,,,,,,,,,,,, -47400,0.52765244,1.142077,,,,,,,,,,,,,, -47500,0.5422642,1.0856522,,,,,,,,,,,,,, -47600,0.57134074,1.0981345,,,,,,,,,,,,,, -47700,0.534187,1.1516231,,,,,,,,,,,,,, -47800,0.5416086,1.0472497,,,,,,,,,,,,,, -47900,0.54836273,1.0915275,,,,,,,,,,,,,, -47980,,,0.17661671,0.0667275338759422,0.3707086,0.1087596667213763,5348.0,0.20401005,0.0674344443767391,2472.0,38946.74871778488,42677.6640393734,38946.74871778488,3727.454374790192,1.476792573928833,0.0 -48000,0.6154007,1.1465001,,,,,,,,,,,,,, -48100,0.56229734,1.0691013,,,,,,,,,,,,,, -48200,0.58801544,1.1373966,,,,,,,,,,,,,, -48300,0.56272364,1.1146603,,,,,,,,,,,,,, -48400,0.5677906,1.1279888,,,,,,,,,,,,,, -48500,0.6160929,1.0607249,,,,,,,,,,,,,, -48600,0.54072833,1.105387,,,,,,,,,,,,,, -48700,0.5353284,1.1254487,,,,,,,,,,,,,, -48800,0.5909579,1.0861483,,,,,,,,,,,,,, -48900,0.6254041,1.0630484,,,,,,,,,,,,,, -49000,0.7485456,1.1362209,,,,,,,,,,,,,, -49100,0.55638945,1.0987637,,,,,,,,,,,,,, -49200,0.61010236,1.088893,,,,,,,,,,,,,, -49300,0.5690575,1.0879905,,,,,,,,,,,,,, -49400,0.7075336,1.0508215,,,,,,,,,,,,,, -49500,0.51680547,1.013593,,,,,,,,,,,,,, -49600,0.64182127,1.1305392,,,,,,,,,,,,,, -49700,0.5636791,1.1026858,,,,,,,,,,,,,, -49754,,,0.14905035,0.0575971462441135,0.36766133,0.1074466339052106,5348.0,0.19967617,0.0666829159303719,2472.0,40386.87485766411,44248.83420395851,40386.87485766411,3858.368305444717,1.5332348346710205,0.0 -49800,0.6808211,1.1271101,,,,,,,,,,,,,, -49900,0.53762436,1.0750009,,,,,,,,,,,,,, -50000,0.5884365,1.1302727,,,,,,,,,,,,,, -50100,0.63889366,1.1120983,,,,,,,,,,,,,, -50200,0.63368356,1.0626665,,,,,,,,,,,,,, -50300,0.56951576,1.0813571,,,,,,,,,,,,,, -50400,0.5603117,1.1285783,,,,,,,,,,,,,, -50500,0.53671134,1.0642654,,,,,,,,,,,,,, -50600,0.65205985,1.0109121,,,,,,,,,,,,,, -50700,0.600548,1.1013438,,,,,,,,,,,,,, -50800,0.6274091,1.0931199,,,,,,,,,,,,,, -50900,0.66068375,1.0294853,,,,,,,,,,,,,, -51000,0.6875898,1.08988,,,,,,,,,,,,,, -51100,0.56767386,1.0659792,,,,,,,,,,,,,, -51200,0.6537766,1.0914912,,,,,,,,,,,,,, -51300,0.49562985,1.0051445,,,,,,,,,,,,,, -51400,0.55581117,1.0602944,,,,,,,,,,,,,, -51500,0.7215883,1.0163767,,,,,,,,,,,,,, -51528,,,0.1598807,0.0607569775209502,0.36701316,0.1054191567626017,5348.0,0.19805373,0.0648142506042695,2472.0,41827.40071630478,45819.96135163307,41827.40071630478,3988.840337753296,1.5893762111663818,0.0 -51600,0.57954663,1.0118408,,,,,,,,,,,,,, -51700,0.5377745,1.1005243,,,,,,,,,,,,,, -51800,0.6159988,1.0490335,,,,,,,,,,,,,, -51900,0.5209321,1.0562079,,,,,,,,,,,,,, -52000,0.5532094,1.0668492,,,,,,,,,,,,,, -52100,0.60270715,1.0726434,,,,,,,,,,,,,, -52200,0.6285144,1.0638775,,,,,,,,,,,,,, -52300,0.6771326,1.0505626,,,,,,,,,,,,,, -52400,0.60038817,1.035368,,,,,,,,,,,,,, -52500,0.7464413,1.1090076,,,,,,,,,,,,,, -52600,0.52483344,1.018371,,,,,,,,,,,,,, -52700,0.6526021,1.0452341,,,,,,,,,,,,,, -52800,0.55593854,1.0694797,,,,,,,,,,,,,, -52900,0.5279521,1.06044,,,,,,,,,,,,,, -53000,0.6186443,1.0545387,,,,,,,,,,,,,, -53100,0.6775495,1.0369475,,,,,,,,,,,,,, -53200,0.5589406,1.0399597,,,,,,,,,,,,,, -53300,0.59178966,1.0434091,,,,,,,,,,,,,, -53319,,,0.14107403,0.0538996735956021,0.34760937,0.1007269953754211,5348.0,0.19280446,0.0630674547559563,2472.0,43267.91893672943,47392.46687364578,43267.91893672943,4120.620878696442,1.7224667072296145,0.0 -53400,0.62115145,1.0286229,,,,,,,,,,,,,, -53500,0.60725784,1.0817184,,,,,,,,,,,,,, -53600,0.51328623,1.0071043,,,,,,,,,,,,,, -53700,0.72471243,1.046557,,,,,,,,,,,,,, -53800,0.5078791,0.9962895,,,,,,,,,,,,,, -53900,0.6871503,1.0211949,,,,,,,,,,,,,, -54000,0.6595378,1.0613698,,,,,,,,,,,,,, -54100,0.5208197,1.0654516,,,,,,,,,,,,,, -54200,0.65616393,1.0546762,,,,,,,,,,,,,, -54300,0.5767482,1.0152534,,,,,,,,,,,,,, -54400,0.5980614,0.9981025,,,,,,,,,,,,,, -54500,0.55606425,0.98967916,,,,,,,,,,,,,, -54600,0.49311343,1.0172322,,,,,,,,,,,,,, -54700,0.5529016,0.9698937,,,,,,,,,,,,,, -54800,0.6420246,1.0499974,,,,,,,,,,,,,, -54900,0.5979925,1.0631195,,,,,,,,,,,,,, -55000,0.6390564,1.0052819,,,,,,,,,,,,,, -55100,0.64514416,1.0529898,,,,,,,,,,,,,, -55101,,,0.13622159,0.0528374141615769,0.3458523,0.0996746381918765,5348.0,0.18519415,0.0599191599130664,2472.0,44708.85677075386,48965.83365011215,44708.85677075386,4252.91517329216,1.7831799983978271,0.0 -55200,0.57315,1.0250949,,,,,,,,,,,,,, -55300,0.59136325,1.0510508,,,,,,,,,,,,,, -55400,0.62453765,1.0497185,,,,,,,,,,,,,, -55500,0.5493146,1.0460662,,,,,,,,,,,,,, -55600,0.56702995,1.0526798,,,,,,,,,,,,,, -55700,0.69305,1.0113627,,,,,,,,,,,,,, -55800,0.5586433,1.054691,,,,,,,,,,,,,, -55900,0.64754766,1.0669049,,,,,,,,,,,,,, -56000,0.70409584,1.0510379,,,,,,,,,,,,,, -56100,0.66118085,1.0866843,,,,,,,,,,,,,, -56200,0.6882295,1.0518787,,,,,,,,,,,,,, -56300,0.67885953,1.0260322,,,,,,,,,,,,,, -56400,0.5974699,1.0411296,,,,,,,,,,,,,, -56500,0.6638485,1.0247726,,,,,,,,,,,,,, -56600,0.5836665,0.9972789,,,,,,,,,,,,,, -56700,0.6591459,1.0578701,,,,,,,,,,,,,, -56800,0.6956626,1.0195462,,,,,,,,,,,,,, -56883,,,0.12630334,0.0483699102829537,0.33833632,0.0968940981105844,5348.0,0.18261853,0.0596144862185932,2472.0,46149.09987139702,50537.77483201027,46149.09987139702,4384.476100206375,1.8459181785583496,0.0 -56900,0.62775534,1.0325822,,,,,,,,,,,,,, -57000,0.59137475,0.9927862,,,,,,,,,,,,,, -57100,0.5877631,1.0048018,,,,,,,,,,,,,, -57200,0.6204682,1.0155287,,,,,,,,,,,,,, -57300,0.56558174,0.9698789,,,,,,,,,,,,,, -57400,0.5953482,1.0443,,,,,,,,,,,,,, -57500,0.56286126,0.9961002,,,,,,,,,,,,,, -57600,0.63021296,1.0776569,,,,,,,,,,,,,, -57700,0.7546748,1.0914122,,,,,,,,,,,,,, -57800,0.6204457,1.0055684,,,,,,,,,,,,,, -57900,0.69119316,1.0228962,,,,,,,,,,,,,, -58000,0.65376794,0.9960771,,,,,,,,,,,,,, -58100,0.7552346,0.9896412,,,,,,,,,,,,,, -58200,0.5935526,1.0250795,,,,,,,,,,,,,, -58300,0.66829854,1.0235022,,,,,,,,,,,,,, -58400,0.68803376,0.9978731,,,,,,,,,,,,,, -58500,0.59797853,0.9959384,,,,,,,,,,,,,, -58600,0.63568383,1.0108079,,,,,,,,,,,,,, -58649,,,0.12676607,0.0491741355035884,0.3321127,0.0947411104781949,5348.0,0.17599674,0.0569739808664919,2472.0,47589.387810468674,52108.10684657097,47589.387810468674,4514.391450881958,1.9014570713043213,0.0 -58700,0.62947345,1.006619,,,,,,,,,,,,,, -58800,0.74346,0.9833596,,,,,,,,,,,,,, -58900,0.6635595,0.9999885,,,,,,,,,,,,,, -59000,0.77867615,1.0120867,,,,,,,,,,,,,, -59100,0.55669314,1.0094202,,,,,,,,,,,,,, -59200,0.67030984,0.98771244,,,,,,,,,,,,,, -59300,0.57880014,1.0123975,,,,,,,,,,,,,, -59400,0.6928806,1.0004741,,,,,,,,,,,,,, -59500,0.6420573,1.0160483,,,,,,,,,,,,,, -59600,0.7201677,1.0015199,,,,,,,,,,,,,, -59700,0.5932039,0.9691278,,,,,,,,,,,,,, -59800,0.53112936,0.977566,,,,,,,,,,,,,, -59900,0.6505195,0.97632885,,,,,,,,,,,,,, -60000,0.6481653,1.0151974,,,,,,,,,,,,,, -60100,0.6550266,0.9738458,,,,,,,,,,,,,, -60200,0.62049186,0.9734569,,,,,,,,,,,,,, -60300,0.6175703,0.9155226,,,,,,,,,,,,,, -60400,0.5498897,0.94629383,,,,,,,,,,,,,, -60416,,,0.12744635,0.0481000664350515,0.32808104,0.0939494289272715,5348.0,0.17387721,0.0561005829423354,2472.0,49029.39686584473,53681.36341094971,49029.39686584473,4647.504854440689,1.9610748291015625,0.0 -60500,0.6324128,0.95562327,,,,,,,,,,,,,, -60600,0.6823966,1.0279759,,,,,,,,,,,,,, -60700,0.6602088,0.97991574,,,,,,,,,,,,,, -60800,0.59246063,0.9855523,,,,,,,,,,,,,, -60900,0.6913582,0.95751697,,,,,,,,,,,,,, -61000,0.5495889,0.9874728,,,,,,,,,,,,,, -61100,0.7714375,0.9676332,,,,,,,,,,,,,, -61200,0.66772497,0.99337333,,,,,,,,,,,,,, -61300,0.6831551,0.95096517,,,,,,,,,,,,,, -61400,0.8174584,0.97852844,,,,,,,,,,,,,, -61500,0.70263684,1.0036026,,,,,,,,,,,,,, -61600,0.5684272,0.9566339,,,,,,,,,,,,,, -61700,0.7503206,1.0132942,,,,,,,,,,,,,, -61800,0.64524513,0.9855667,,,,,,,,,,,,,, -61900,0.6701717,0.98000395,,,,,,,,,,,,,, -62000,0.6711883,1.0266582,,,,,,,,,,,,,, -62100,0.61587846,0.9347966,,,,,,,,,,,,,, -62193,,,0.10352248,0.0405317551057009,0.3215525,0.0921246994989235,5348.0,0.16972645,0.0559990250441776,2472.0,50469.78848719597,55253.57203221321,50469.78848719597,4779.186099052429,2.0234742164611816,0.0 -62200,0.6117971,0.9688002,,,,,,,,,,,,,, -62300,0.7807374,0.99030393,,,,,,,,,,,,,, -62400,0.57058024,0.9849822,,,,,,,,,,,,,, -62500,0.58053523,0.91171265,,,,,,,,,,,,,, -62600,0.6542204,0.94098556,,,,,,,,,,,,,, -62700,0.59857935,0.9856379,,,,,,,,,,,,,, -62800,0.7040557,0.9710833,,,,,,,,,,,,,, -62900,0.6460042,0.9538558,,,,,,,,,,,,,, -63000,0.62336457,0.98126864,,,,,,,,,,,,,, -63100,0.7606099,0.98238546,,,,,,,,,,,,,, -63200,0.72558427,1.0050572,,,,,,,,,,,,,, -63300,0.6697291,0.9700068,,,,,,,,,,,,,, -63400,0.7268421,0.97581095,,,,,,,,,,,,,, -63500,0.59871167,1.026689,,,,,,,,,,,,,, -63600,0.64646906,0.99184775,,,,,,,,,,,,,, -63700,0.64671624,0.95789254,,,,,,,,,,,,,, -63800,0.62768304,0.96120393,,,,,,,,,,,,,, -63900,0.7222745,0.94388014,,,,,,,,,,,,,, -63948,,,0.10351148,0.0399570991739309,0.32219872,0.0916902401112216,5348.0,0.16616057,0.0551865618589157,2472.0,51909.92304944992,56825.101155519485,51909.92304944992,4910.448292016983,2.0826873779296875,0.0 -64000,0.6590724,0.9188413,,,,,,,,,,,,,, -64100,0.6246211,0.9409004,,,,,,,,,,,,,, -64200,0.57167304,0.9142903,,,,,,,,,,,,,, -64300,0.6185281,0.90121144,,,,,,,,,,,,,, -64400,0.69345003,0.92649186,,,,,,,,,,,,,, -64500,0.5704563,0.97125584,,,,,,,,,,,,,, -64600,0.57369924,0.97326994,,,,,,,,,,,,,, -64700,0.7215262,0.9728473,,,,,,,,,,,,,, -64800,0.661822,0.91457045,,,,,,,,,,,,,, -64900,0.5957151,0.88489294,,,,,,,,,,,,,, -65000,0.71980983,0.95648116,,,,,,,,,,,,,, -65100,0.8206072,0.9272307,,,,,,,,,,,,,, -65200,0.6472292,0.95136917,,,,,,,,,,,,,, -65300,0.5915044,0.9540869,,,,,,,,,,,,,, -65400,0.8619571,0.9645975,,,,,,,,,,,,,, -65500,0.62727755,0.91385823,,,,,,,,,,,,,, -65600,0.7100543,0.9263394,,,,,,,,,,,,,, -65700,0.63400865,0.9394425,,,,,,,,,,,,,, -65800,0.63633966,0.9637375,,,,,,,,,,,,,, -65821,,,0.09428704,0.0364069544257465,0.31000638,0.0878187242341446,5348.0,0.16144967,0.0532975849531818,2472.0,53350.49229979515,58390.26790380478,53350.49229979515,5034.910115480423,2.148332595825196,0.0 -65900,0.5789326,0.9148507,,,,,,,,,,,,,, -66000,0.71968544,0.9395333,,,,,,,,,,,,,, -66100,0.79945236,0.8522083,,,,,,,,,,,,,, -66200,0.57887334,0.9632663,,,,,,,,,,,,,, -66300,0.7287931,0.92844635,,,,,,,,,,,,,, -66400,0.62178355,0.9280552,,,,,,,,,,,,,, -66500,0.78319,0.9294194,,,,,,,,,,,,,, -66600,0.7549158,0.8888644,,,,,,,,,,,,,, -66700,0.9302994,0.9324725,,,,,,,,,,,,,, -66800,0.57746565,0.9222159,,,,,,,,,,,,,, -66900,0.95412624,0.9155583,,,,,,,,,,,,,, -67000,0.63590103,0.889611,,,,,,,,,,,,,, -67100,0.7500723,0.9095862,,,,,,,,,,,,,, -67200,0.67757034,0.8890794,,,,,,,,,,,,,, -67300,0.6960418,0.9852237,,,,,,,,,,,,,, -67400,0.6229418,0.90286106,,,,,,,,,,,,,, -67500,0.6554419,0.89843714,,,,,,,,,,,,,, -67600,0.655554,0.9218741,,,,,,,,,,,,,, -67697,,,0.09369294,0.0362907587808946,0.31026995,0.0872490997036021,5348.0,0.16130331,0.0517742164808157,2472.0,54790.55012321472,59953.31989455223,54790.55012321472,5157.780250310898,2.2042810916900635,0.0 -67700,0.53291196,0.92400056,,,,,,,,,,,,,, -67800,0.69930947,0.91239107,,,,,,,,,,,,,, -67900,0.7041295,0.8930249,,,,,,,,,,,,,, -68000,0.6613267,0.9537926,,,,,,,,,,,,,, -68100,0.6878477,0.90113187,,,,,,,,,,,,,, -68200,0.66467994,0.90901834,,,,,,,,,,,,,, -68300,0.5999041,0.9066174,,,,,,,,,,,,,, -68400,0.75685966,0.911563,,,,,,,,,,,,,, -68500,0.6812378,0.9143958,,,,,,,,,,,,,, -68600,0.6062454,0.9049475,,,,,,,,,,,,,, -68700,0.9494842,0.86363757,,,,,,,,,,,,,, -68800,0.62846404,0.9087306,,,,,,,,,,,,,, -68900,0.6752368,0.9336073,,,,,,,,,,,,,, -69000,0.6446173,0.89807737,,,,,,,,,,,,,, -69100,0.78863674,0.96701217,,,,,,,,,,,,,, -69200,0.6124928,0.86597687,,,,,,,,,,,,,, -69300,0.8038646,0.89988154,,,,,,,,,,,,,, -69400,0.6449562,0.9097595,,,,,,,,,,,,,, -69500,0.8004018,0.9186401,,,,,,,,,,,,,, -69572,,,0.08811778,0.033569240401601,0.3047844,0.0857429738262355,5348.0,0.15887609,0.051489854365974,2472.0,56230.60710000992,61516.4474709034,56230.60710000992,5280.730521202087,2.2564854621887207,0.0 -69572,,,,,,,,,,,56230.60710000992,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 3221f9c7c..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,84 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -205.8416879177093,0.0,41.73051595687866,1,0,41.73051595687866,29.561045,2472,2.7003635772754047,247.57225155830383,30.523008,2.7218780584303355,29.339128,5348,2.4762157621865857 -331.4041941165924,0.038832664489746,1481.632533788681,1858,0,1481.632533788681,1.2089342,2472,0.3347348323279101,1813.1430945396423,1.1975354,0.330683199110214,1.635861,5348,0.4088455931336107 -462.3708300590515,0.0890276432037353,2922.079998731613,3700,0,2922.079998731613,0.55478996,2472,0.1749233237868909,3384.678767681122,0.5168572,0.1696561276403057,0.8609228,5348,0.2439344642150284 -592.0327932834625,0.1338312625885009,4362.616629600525,5535,0,4362.616629600525,0.4684899,2472,0.1491073060751934,4954.994887113571,0.40652406,0.1399707976870081,0.7689421,5348,0.2186392731977176 -722.7132506370544,0.1826066970825195,5802.827211141586,7371,0,5802.827211141586,0.4458511,2472,0.1449231206710945,6526.008309602737,0.40096128,0.1349422446564525,0.74107677,5348,0.2118230881373277 -850.259199142456,0.2335467338562011,7242.870648384094,9182,0,7242.870648384094,0.41574928,2472,0.1366156846017915,8093.72025179863,0.3648077,0.1252038530953054,0.69743854,5348,0.2004402521795379 -982.551082611084,0.2804255485534668,8683.471186876297,10970,0,8683.471186876297,0.40698555,2472,0.1310300002031158,9666.73115491867,0.38352492,0.1273506142938688,0.6855383,5348,0.1966073549147011 -1112.6423828601835,0.3264663219451904,10124.027910232544,12775,0,10124.027910232544,0.3975374,2472,0.1268458147990169,11237.498494386671,0.35075814,0.1195360857635122,0.6756054,5348,0.1937495776089286 -1244.3145473003387,0.3751053810119629,11564.629160404204,14568,0,11564.629160404204,0.39017347,2472,0.1259114821359657,12809.892186164856,0.32302096,0.1124106191008606,0.6762659,5348,0.1907180165480753 -1375.6002094745636,0.4212281703948974,13004.699365615845,16370,0,13004.699365615845,0.3699334,2472,0.1198789429853959,14381.36817574501,0.2960854,0.1013837388079949,0.64389575,5348,0.1853596840997518 -1507.6907210350037,0.4677519798278808,14445.077192544935,18166,0,14445.077192544935,0.35234314,2472,0.1145166859626673,15953.95691728592,0.29231447,0.1036590924935779,0.61532944,5348,0.1773270127537967 -1640.5355043411255,0.5157749652862549,15885.665563583374,19969,0,15885.665563583374,0.3425359,2472,0.1100075152844636,17527.51263141632,0.3047939,0.1021550014542954,0.6024862,5348,0.1741120132848026 -1773.6003246307373,0.5648343563079834,17326.005275011063,21738,0,17326.005275011063,0.3305221,2472,0.1073873215119939,19101.038394212723,0.29166546,0.0989527000290183,0.5862438,5348,0.1679137260202554 -1905.6069421768188,0.6118886470794678,18766.441838026047,23537,0,18766.441838026047,0.32181597,2472,0.1047874393191558,20673.603135347366,0.27430362,0.0944130054532359,0.57207096,5348,0.1631636367147146 -2036.3903830051424,0.6608626842498779,20206.777960777283,25314,0,20206.777960777283,0.31350023,2472,0.1015782097373712,22244.84517145157,0.24832816,0.0862347451543431,0.554332,5348,0.1595624511233188 -2168.593649625778,0.711815357208252,21646.736126184464,27091,0,21646.736126184464,0.29808536,2472,0.096419068510958,23817.134423017505,0.22888745,0.0808210960118939,0.53358024,5348,0.153885515124014 -2300.7924242019653,0.7672219276428223,23087.10746240616,28847,0,23087.10746240616,0.2901405,2472,0.0931692157699104,25389.833054065704,0.21742319,0.0746847224637162,0.5180734,5348,0.148643038512411 -2431.742198228836,0.8124017715454102,24527.48431277275,30628,0,24527.48431277275,0.28189483,2472,0.0909755651697032,26961.279413223267,0.2305872,0.0798394435777431,0.5070502,5348,0.1471948405534047 -2561.5807065963745,0.8636722564697266,25967.66591262817,32427,0,25967.66591262817,0.27268323,2472,0.0883553713972335,28531.426872015,0.20866887,0.0707545890102496,0.4951226,5348,0.1422226942274829 -2692.9797925949097,0.9129633903503418,27408.01134252548,34173,0,27408.01134252548,0.25914782,2472,0.0833993459671358,30103.295015335083,0.21239454,0.0721840662868767,0.47541836,5348,0.1369222896975197 -2824.903109550476,0.9611873626708984,28848.267106056213,35938,0,28848.267106056213,0.2501728,2472,0.0803119858631405,31675.596952438354,0.20002264,0.0666794547992796,0.46640846,5348,0.1356961487588943 -2957.7188007831573,1.0137641429901123,30288.733597755432,37722,0,30288.733597755432,0.24341933,2472,0.0786870594926167,33249.00675344467,0.15005884,0.0538370078820864,0.4560848,5348,0.1316218851675565 -3089.12416434288,1.0664176940917969,31729.1388194561,39509,0,31729.1388194561,0.23879972,2472,0.0763715394146202,34820.944878816605,0.17008072,0.0592331735338218,0.44051743,5348,0.1278662251272 -3219.284531354904,1.1159639358520508,33169.0481672287,41278,0,33169.0481672287,0.23437671,2472,0.074807547782991,36391.13930130005,0.21692227,0.0749608110087165,0.43239108,5348,0.1240429825154233 -3350.8226585388184,1.16986346244812,34608.96385860443,43028,0,34608.96385860443,0.22969964,2472,0.0736904109032559,37962.72094845772,0.22430344,0.0763863059782276,0.4287209,5348,0.123058207903299 -3484.164263486862,1.2190589904785156,36049.72664427757,44794,0,36049.72664427757,0.22824307,2472,0.0731826214124672,39536.94931817055,0.2613986,0.0901661879719919,0.4249241,5348,0.1222375623931954 -3613.956609010696,1.2820138931274414,37489.63005280495,46564,0,37489.63005280495,0.226913,2472,0.0726951435013101,41106.78757286072,0.23482683,0.0778161778991928,0.42247275,5348,0.1216389739034727 -3747.8457283973694,1.3374078273773191,38930.12107872963,48322,0,38930.12107872963,0.22681805,2472,0.0728576361383624,42681.298971414566,0.21850312,0.0765535056172269,0.42252696,5348,0.1217934483524334 -3881.402559518814,1.3948400020599363,40370.472751140594,50079,0,40370.472751140594,0.2268369,2472,0.0729388824568886,44255.343381881714,0.19701824,0.0689132018836529,0.42258406,5348,0.1218513762707937 -4012.835418462753,1.4524447917938232,41810.91124749184,51869,0,41810.91124749184,0.22682315,2472,0.0728982592976255,45827.35194015503,0.22619057,0.0782449070478528,0.4225461,5348,0.1218513762707937 -4143.967495203018,1.5726919174194336,43250.74343371391,53618,0,43250.74343371391,0.2268347,2472,0.0729185708772571,47398.51116228104,0.21388717,0.0744126868235698,0.42257115,5348,0.1218127576585535 -4274.887751102448,1.6357917785644531,44691.26508259773,55401,0,44691.26508259773,0.22681281,2472,0.0728982592976255,48970.09406757355,0.21749085,0.0768450514759327,0.42252436,5348,0.1218417216177336 -4409.608050823212,1.689574956893921,46131.35928487778,57190,0,46131.35928487778,0.22681028,2472,0.072877947717994,50545.039157152176,0.22159734,0.0764886128364389,0.42251283,5348,0.1218127576585535 -4546.888776302338,1.7484843730926514,47571.82388663292,58964,0,47571.82388663292,0.22684751,2472,0.0729795056161517,52122.92017412186,0.23678349,0.0812758261506587,0.42260745,5348,0.1218610309238537 -4679.714070558548,1.80523419380188,49012.130105018616,60715,0,49012.130105018616,0.2268356,2472,0.0729185708772571,53696.184061050415,0.2431013,0.0816443219812892,0.42257443,5348,0.1218320669646736 -4810.312841653824,1.8586266040802,50452.07918572426,62476,0,50452.07918572426,0.22685315,2472,0.0729185708772571,55266.86240553856,0.21817155,0.076703208016679,0.42259833,5348,0.1218610309238537 -4941.022246599197,1.9132137298583984,51892.149191617966,64265,0,51892.149191617966,0.22683159,2472,0.0729185708772571,56837.77284407616,0.22722098,0.0794977408607548,0.4225785,5348,0.1218224123116135 -5072.738001346588,1.9674947261810305,53332.38691663742,66011,0,53332.38691663742,0.22683603,2472,0.0729185708772571,58409.85575699806,0.22163488,0.0767140561531937,0.42258117,5348,0.1218127576585535 -5203.679777383804,2.0264124870300293,54772.54679131508,67774,0,54772.54679131508,0.22684477,2472,0.0729591940365202,59981.093438625336,0.23132047,0.0787533378502251,0.4225996,5348,0.1218610309238537 -5335.775634050369,2.084322690963745,56212.43574166298,69539,0,56212.43574166298,0.2268291,2472,0.0729185708772571,61553.21344566345,0.2290587,0.0786990451912971,0.4225713,5348,0.1218224123116135 -5466.968840122223,2.1383402347564697,57652.62449741364,71324,0,57652.62449741364,0.22685763,2472,0.0729591940365202,63124.72708106041,0.2350017,0.0823583454169111,0.4226408,5348,0.1218513762707937 -5598.56943821907,2.190993070602417,59093.15683174133,73071,0,59093.15683174133,0.22680901,2472,0.0729185708772571,64696.9893181324,0.20977716,0.0732559520107647,0.42251086,5348,0.1218320669646736 -5730.616383552551,2.2462613582611084,60533.7385866642,74813,0,60533.7385866642,0.2268203,2472,0.0729388824568886,66269.74801373482,0.20217852,0.069316295956455,0.42254487,5348,0.1218127576585535 -5861.242010354996,2.3004186153411865,61974.3486623764,76605,0,61974.3486623764,0.22685282,2472,0.0729795056161517,67841.11692786217,0.24120624,0.084248732741005,0.422617,5348,0.1218803402299738 -5989.411997795105,2.428640842437744,63414.18882107735,78338,0,63414.18882107735,0.22684146,2472,0.0729185708772571,69409.33088636398,0.23317532,0.078230944254835,0.42259285,5348,0.1218513762707937 -6121.52056145668,2.482978343963623,64854.61954379082,80112,0,64854.61954379082,0.22684368,2472,0.0729591940365202,70982.00092959404,0.23077081,0.0776995343340543,0.42259255,5348,0.1218320669646736 -6255.759551286697,2.541866540908813,66294.92736124992,81892,0,66294.92736124992,0.22684759,2472,0.0729185708772571,72556.6824965477,0.21406859,0.0751324795760653,0.42261186,5348,0.1218513762707937 -6385.677375793457,2.598761558532715,67735.07097840309,83650,0,67735.07097840309,0.22683038,2472,0.0729185708772571,74126.87830376625,0.21813707,0.0758759012604678,0.4225675,5348,0.1218610309238537 -6518.407257318497,2.6557695865631104,69175.55707406998,85420,0,69175.55707406998,0.22683091,2472,0.0729185708772571,75700.22792291641,0.21527421,0.0751402136641586,0.42256075,5348,0.1218224123116135 -6669.279905557632,2.7173826694488525,70615.84012532234,87217,0,70615.84012532234,0.22682695,2472,0.0729185708772571,77291.52449822426,0.136696,0.0487799693929603,0.4225517,5348,0.1218031030054935 -6802.439275741577,2.773588180541992,72055.93694233894,89013,0,72055.93694233894,0.22683926,2472,0.0729388824568886,78864.91400146484,0.13775001,0.0477160208423163,0.42260024,5348,0.1218706855769137 -6934.860629796982,2.8354685306549072,73496.3667254448,90750,0,73496.3667254448,0.22682476,2472,0.0729388824568886,80437.90277504921,0.15148643,0.0530193715478738,0.42255753,5348,0.1218417216177336 -7067.9899690151215,2.8983287811279297,74936.79920172691,92494,0,74936.79920172691,0.22684576,2472,0.0729795056161517,82011.60424017906,0.13619804,0.0479902329075882,0.4226031,5348,0.1218899948830338 -7200.00839805603,2.9541354179382324,76377.0752260685,94269,0,76377.0752260685,0.22686507,2472,0.0729998171957833,83584.03257799149,0.14904647,0.0523268522319548,0.4226395,5348,0.1218513762707937 -7335.930129051208,3.0127499103546143,77817.67358899117,96036,0,77817.67358899117,0.2268396,2472,0.0729185708772571,85160.68888759613,0.13960445,0.0499481765629841,0.42258808,5348,0.1218417216177336 -7468.458901882172,3.0681488513946533,79257.58918118477,97761,0,79257.58918118477,0.22680737,2472,0.0728982592976255,86733.26485586166,0.16996482,0.0559007210919392,0.4225012,5348,0.1218127576585535 -7606.780797481537,3.13254976272583,80698.04615569115,99506,0,80698.04615569115,0.22684963,2472,0.0729388824568886,88312.18506860733,0.15189572,0.0539175148430873,0.42260715,5348,0.1218610309238537 -7743.091492176056,3.1927475929260254,82138.8430378437,101269,0,82138.8430378437,0.22683719,2472,0.0729388824568886,89889.42877030373,0.13218991,0.0481681788283444,0.42258465,5348,0.1218417216177336 -7874.434673547745,3.255798816680908,83579.43496155739,103000,0,83579.43496155739,0.22684465,2472,0.0729388824568886,91461.50216126442,0.13685067,0.0485069051633973,0.42259264,5348,0.1218610309238537 -8007.273331642151,3.319591999053955,85019.83351564407,104750,0,85019.83351564407,0.22683273,2472,0.0728982592976255,93034.88031411172,0.13595581,0.0497899612236105,0.42256644,5348,0.1218031030054935 -8142.552544355392,3.381679058074951,86459.7220981121,106513,0,86459.7220981121,0.22684817,2472,0.0729388824568886,94610.18890833856,0.15283024,0.052183996907615,0.42258832,5348,0.1218320669646736 -8274.74119591713,3.4438462257385254,87899.6213388443,108264,0,87899.6213388443,0.2268602,2472,0.0729795056161517,96182.41661047935,0.152285,0.0533659226724665,0.42263636,5348,0.1218513762707937 -8409.858960390091,3.506175756454468,89340.51832222939,110014,0,89340.51832222939,0.22682457,2472,0.0729185708772571,97758.5703458786,0.14916688,0.053241399287218,0.4225592,5348,0.1218706855769137 -8546.651809930801,3.56905198097229,90780.58996725082,111767,0,90780.58996725082,0.22682922,2472,0.0728982592976255,99335.57467722891,0.14817163,0.0517262157389796,0.4225653,5348,0.1218320669646736 -8681.74896121025,3.631675720214844,92221.01919841766,113547,0,92221.01919841766,0.22680725,2472,0.0728576361383624,100911.24274611472,0.1355164,0.0486579392737951,0.4225174,5348,0.1218224123116135 -8821.899176359177,3.6960880756378174,93661.36353373528,115301,0,93661.36353373528,0.22684036,2472,0.0729185708772571,102491.87964940073,0.14265695,0.0494268791872569,0.42258564,5348,0.1218417216177336 -8956.770744800568,3.7561357021331774,95101.26403808594,117052,0,95101.26403808594,0.22685356,2472,0.0729591940365202,104066.78872728348,0.15840484,0.0546797230036788,0.4226101,5348,0.1218320669646736 -9088.92824625969,3.823304891586304,96541.71780490877,118823,0,96541.71780490877,0.22683339,2472,0.0729388824568886,105639.54648947716,0.15105623,0.0514460345043607,0.42255726,5348,0.1218417216177336 -9227.363971710203,3.88323712348938,97982.20047354698,120570,0,97982.20047354698,0.22680296,2472,0.072877947717994,107218.6017394066,0.14774881,0.0528778085564789,0.42251122,5348,0.1217934483524334 -9361.547214984894,3.952324867248535,99422.66232037544,122319,0,99422.66232037544,0.22685014,2472,0.0729795056161517,108793.39494419098,0.18443608,0.0592126955763319,0.4226183,5348,0.1218706855769137 -9495.424234628676,4.015031814575195,100862.93128800392,124089,0,100862.93128800392,0.2268248,2472,0.0729185708772571,110367.68206977844,0.1285202,0.0454345017125746,0.42256793,5348,0.1218706855769137 -9627.709649086,4.083231449127197,102303.157143116,125833,0,102303.157143116,0.22681937,2472,0.0729185708772571,111940.33980464935,0.14595449,0.0509496202041957,0.42254525,5348,0.1218320669646736 -9759.584423303604,4.149665832519531,103743.35876011848,127576,0,103743.35876011848,0.22682925,2472,0.0729388824568886,113512.56096434592,0.20160694,0.0704816760832765,0.4225641,5348,0.1218513762707937 -9893.917643547058,4.215802431106567,105184.04212832452,129314,0,105184.04212832452,0.2268536,2472,0.0730404403550464,115087.72128272057,0.22234,0.0753358951751978,0.42261353,5348,0.1218224123116135 -10024.404949426653,4.290512323379517,106624.42895913124,131085,0,106624.42895913124,0.22685236,2472,0.0729388824568886,116658.7487475872,0.2571062,0.0883086597364024,0.42260465,5348,0.1218899948830338 -10156.198406934738,4.354108810424805,108064.59675383568,132820,0,108064.59675383568,0.22682151,2472,0.0729185708772571,118230.8504126072,0.23016225,0.0764056814180068,0.42252997,5348,0.1217934483524334 -10286.368607997894,4.4263763427734375,109504.54699015616,134567,0,109504.54699015616,0.22683118,2472,0.0729388824568886,119801.12107515337,0.2190213,0.0762029456223062,0.42256498,5348,0.1218706855769137 -10420.468059301376,4.494283437728882,110944.7632997036,136327,0,110944.7632997036,0.22684185,2472,0.0729185708772571,121375.58246660233,0.19412705,0.0682318335919715,0.4225796,5348,0.1218320669646736 -10552.179483890532,4.551663398742676,112384.63683605194,138066,0,112384.63683605194,0.22682582,2472,0.0728982592976255,122947.29821825027,0.2334239,0.0803974607883772,0.42255595,5348,0.1218224123116135 -10686.553840637209,4.618646621704102,113824.81503725052,139824,0,113824.81503725052,0.22685347,2472,0.0729795056161517,124521.99733042716,0.21257159,0.0737407763875521,0.42262492,5348,0.1218031030054935 -10818.959375619888,4.695364475250244,115265.31365394592,141595,0,115265.31365394592,0.22683278,2472,0.072877947717994,126095.0590775013,0.21481861,0.0760388121012688,0.42257047,5348,0.1218513762707937 -10951.057807445526,4.766013145446777,116705.55896472932,143357,0,116705.55896472932,0.22684102,2472,0.0729795056161517,127667.55155706406,0.22314472,0.0771731146606444,0.4225849,5348,0.1217934483524334 -11081.355797290802,4.835377931594849,117226.72539401054,144000,0,117226.72539401054,0.22682914,2472,0.07291857087725713,128319.11793255806,0.23591998,0.08150679031841282,0.42256364,5348,0.12179344835243346 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index a5e4a5c2e..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1525 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,31.242565,33.323677,,,,,,,,,,,,,, -1,,,30.523008,2.7218780584303355,29.339128,2.4762157621865857,5348.0,29.561045,2.7003635772754047,2472.0,41.73051595687866,247.57225155830383,41.73051595687866,205.8416879177093,0.0,0.0 -100,1.0791733,6.168745,,,,,,,,,,,,,, -200,0.3299723,5.8328047,,,,,,,,,,,,,, -300,0.49131104,5.675682,,,,,,,,,,,,,, -400,0.69914085,5.373202,,,,,,,,,,,,,, -500,1.3359932,4.6298037,,,,,,,,,,,,,, -600,3.4998085,3.8419955,,,,,,,,,,,,,, -700,1.5315919,3.4638245,,,,,,,,,,,,,, -800,2.7192078,3.1801436,,,,,,,,,,,,,, -900,2.556485,2.9534264,,,,,,,,,,,,,, -1000,2.2074409,2.8284512,,,,,,,,,,,,,, -1100,1.7519821,2.6512625,,,,,,,,,,,,,, -1200,1.7686067,2.5469396,,,,,,,,,,,,,, -1300,1.8007121,2.5457678,,,,,,,,,,,,,, -1400,2.7594514,2.4354973,,,,,,,,,,,,,, -1500,4.0272665,2.447012,,,,,,,,,,,,,, -1600,1.6311324,2.2825916,,,,,,,,,,,,,, -1700,2.2383723,2.1292498,,,,,,,,,,,,,, -1800,2.3260107,2.1467183,,,,,,,,,,,,,, -1858,,,1.1975354,0.330683199110214,1.635861,0.4088455931336107,5348.0,1.2089342,0.3347348323279101,2472.0,1481.632533788681,1813.1430945396423,1481.632533788681,331.4041941165924,0.038832664489746,0.0 -1900,2.0227504,2.1801512,,,,,,,,,,,,,, -2000,1.9590464,2.1440022,,,,,,,,,,,,,, -2100,1.7220047,2.0797708,,,,,,,,,,,,,, -2200,1.8286653,2.005908,,,,,,,,,,,,,, -2300,2.7493243,2.0187678,,,,,,,,,,,,,, -2400,2.6398897,1.9552001,,,,,,,,,,,,,, -2500,1.8622763,1.9188265,,,,,,,,,,,,,, -2600,2.4719315,1.9472363,,,,,,,,,,,,,, -2700,1.554062,1.8985202,,,,,,,,,,,,,, -2800,1.7594597,1.968958,,,,,,,,,,,,,, -2900,3.1460567,1.8978825,,,,,,,,,,,,,, -3000,2.7646177,1.8747287,,,,,,,,,,,,,, -3100,2.128961,1.8394226,,,,,,,,,,,,,, -3200,2.109468,1.8747879,,,,,,,,,,,,,, -3300,2.7753067,1.8103417,,,,,,,,,,,,,, -3400,1.9393684,1.9061972,,,,,,,,,,,,,, -3500,3.3827317,1.8138455,,,,,,,,,,,,,, -3600,2.612297,1.8271188,,,,,,,,,,,,,, -3700,,,0.5168572,0.1696561276403057,0.8609228,0.2439344642150284,5348.0,0.55478996,0.1749233237868909,2472.0,2922.079998731613,3384.678767681122,2922.079998731613,462.3708300590515,0.0890276432037353,0.0 -3700,2.249179,1.851718,,,,,,,,,,,,,, -3800,1.7018358,1.8291016,,,,,,,,,,,,,, -3900,2.1831033,1.774907,,,,,,,,,,,,,, -4000,3.0163345,1.821901,,,,,,,,,,,,,, -4100,2.531806,1.8300638,,,,,,,,,,,,,, -4200,1.7500739,1.7639256,,,,,,,,,,,,,, -4300,2.17205,1.7744392,,,,,,,,,,,,,, -4400,2.1448917,1.6746078,,,,,,,,,,,,,, -4500,1.7698121,1.727069,,,,,,,,,,,,,, -4600,2.1101694,1.6738285,,,,,,,,,,,,,, -4700,2.4830942,1.7706861,,,,,,,,,,,,,, -4800,1.8540716,1.7659227,,,,,,,,,,,,,, -4900,4.0543003,1.6977583,,,,,,,,,,,,,, -5000,2.9452178,1.7437613,,,,,,,,,,,,,, -5100,2.5342047,1.7917328,,,,,,,,,,,,,, -5200,3.0287428,1.7844783,,,,,,,,,,,,,, -5300,2.3881545,1.7553918,,,,,,,,,,,,,, -5400,2.2271392,1.7064393,,,,,,,,,,,,,, -5500,3.0420642,1.6817892,,,,,,,,,,,,,, -5535,,,0.40652406,0.1399707976870081,0.7689421,0.2186392731977176,5348.0,0.4684899,0.1491073060751934,2472.0,4362.616629600525,4954.994887113571,4362.616629600525,592.0327932834625,0.1338312625885009,0.0 -5600,4.7742534,1.7184879,,,,,,,,,,,,,, -5700,1.3731099,1.6641781,,,,,,,,,,,,,, -5800,2.2216938,1.7191014,,,,,,,,,,,,,, -5900,3.2030776,1.6918874,,,,,,,,,,,,,, -6000,5.528093,1.7448264,,,,,,,,,,,,,, -6100,3.4189835,1.7708511,,,,,,,,,,,,,, -6200,2.274796,1.668697,,,,,,,,,,,,,, -6300,2.4269617,1.7134167,,,,,,,,,,,,,, -6400,3.0959623,1.7225244,,,,,,,,,,,,,, -6500,2.9915953,1.6983188,,,,,,,,,,,,,, -6600,6.647798,1.6447695,,,,,,,,,,,,,, -6700,2.3531194,1.6328181,,,,,,,,,,,,,, -6800,3.131857,1.6946565,,,,,,,,,,,,,, -6900,2.6483507,1.7154244,,,,,,,,,,,,,, -7000,3.8839488,1.6784545,,,,,,,,,,,,,, -7100,1.8152436,1.7470757,,,,,,,,,,,,,, -7200,2.2496538,1.6895543,,,,,,,,,,,,,, -7300,2.8777688,1.6946967,,,,,,,,,,,,,, -7371,,,0.40096128,0.1349422446564525,0.74107677,0.2118230881373277,5348.0,0.4458511,0.1449231206710945,2472.0,5802.827211141586,6526.008309602737,5802.827211141586,722.7132506370544,0.1826066970825195,0.0 -7400,3.5686023,1.652493,,,,,,,,,,,,,, -7500,3.8087304,1.6971023,,,,,,,,,,,,,, -7600,2.9123175,1.596805,,,,,,,,,,,,,, -7700,2.6579373,1.6762118,,,,,,,,,,,,,, -7800,3.3336494,1.6944206,,,,,,,,,,,,,, -7900,3.0731995,1.6226747,,,,,,,,,,,,,, -8000,2.4559753,1.6976238,,,,,,,,,,,,,, -8100,2.328123,1.679304,,,,,,,,,,,,,, -8200,2.4807913,1.6579226,,,,,,,,,,,,,, -8300,2.3288434,1.6657262,,,,,,,,,,,,,, -8400,2.7541664,1.6264626,,,,,,,,,,,,,, -8500,2.9485102,1.6652688,,,,,,,,,,,,,, -8600,1.9538361,1.6881648,,,,,,,,,,,,,, -8700,3.529711,1.7773422,,,,,,,,,,,,,, -8800,2.6690402,1.6343075,,,,,,,,,,,,,, -8900,1.9475073,1.6664132,,,,,,,,,,,,,, -9000,2.7927241,1.693663,,,,,,,,,,,,,, -9100,3.4678257,1.612729,,,,,,,,,,,,,, -9182,,,0.3648077,0.1252038530953054,0.69743854,0.2004402521795379,5348.0,0.41574928,0.1366156846017915,2472.0,7242.870648384094,8093.72025179863,7242.870648384094,850.259199142456,0.2335467338562011,0.0 -9200,3.1801233,1.6890658,,,,,,,,,,,,,, -9300,2.2019937,1.618517,,,,,,,,,,,,,, -9400,2.8130057,1.6628723,,,,,,,,,,,,,, -9500,2.368885,1.6401602,,,,,,,,,,,,,, -9600,3.0440865,1.5995021,,,,,,,,,,,,,, -9700,1.8493098,1.6056955,,,,,,,,,,,,,, -9800,2.651144,1.683161,,,,,,,,,,,,,, -9900,4.166855,1.6734194,,,,,,,,,,,,,, -10000,3.20456,1.6163191,,,,,,,,,,,,,, -10100,1.7979784,1.6191764,,,,,,,,,,,,,, -10200,4.642097,1.6630516,,,,,,,,,,,,,, -10300,2.2036812,1.5910181,,,,,,,,,,,,,, -10400,2.5882854,1.6274569,,,,,,,,,,,,,, -10500,3.3079925,1.6059136,,,,,,,,,,,,,, -10600,2.366037,1.6503699,,,,,,,,,,,,,, -10700,3.0380108,1.6061053,,,,,,,,,,,,,, -10800,1.9515644,1.6799431,,,,,,,,,,,,,, -10900,3.9812856,1.5743136,,,,,,,,,,,,,, -10970,,,0.38352492,0.1273506142938688,0.6855383,0.1966073549147011,5348.0,0.40698555,0.1310300002031158,2472.0,8683.471186876297,9666.73115491867,8683.471186876297,982.551082611084,0.2804255485534668,0.0 -11000,3.010571,1.6324918,,,,,,,,,,,,,, -11100,2.2636511,1.6140724,,,,,,,,,,,,,, -11200,3.270447,1.6453645,,,,,,,,,,,,,, -11300,4.724914,1.6269257,,,,,,,,,,,,,, -11400,1.7718219,1.578492,,,,,,,,,,,,,, -11500,3.4951794,1.6273112,,,,,,,,,,,,,, -11600,3.0751102,1.5961422,,,,,,,,,,,,,, -11700,4.5559063,1.6009899,,,,,,,,,,,,,, -11800,4.865026,1.6647363,,,,,,,,,,,,,, -11900,2.2077894,1.5901731,,,,,,,,,,,,,, -12000,2.849159,1.6593332,,,,,,,,,,,,,, -12100,2.805188,1.6187911,,,,,,,,,,,,,, -12200,2.1917396,1.6000355,,,,,,,,,,,,,, -12300,3.0754669,1.5788095,,,,,,,,,,,,,, -12400,2.7599978,1.5797539,,,,,,,,,,,,,, -12500,6.6508713,1.5968608,,,,,,,,,,,,,, -12600,2.0794184,1.5587394,,,,,,,,,,,,,, -12700,2.0668676,1.6179935,,,,,,,,,,,,,, -12775,,,0.35075814,0.1195360857635122,0.6756054,0.1937495776089286,5348.0,0.3975374,0.1268458147990169,2472.0,10124.027910232544,11237.498494386671,10124.027910232544,1112.6423828601835,0.3264663219451904,0.0 -12800,4.224319,1.6069769,,,,,,,,,,,,,, -12900,2.3696296,1.5265487,,,,,,,,,,,,,, -13000,2.4544845,1.5685021,,,,,,,,,,,,,, -13100,3.9008436,1.6079493,,,,,,,,,,,,,, -13200,1.7632033,1.5701462,,,,,,,,,,,,,, -13300,2.0235198,1.6314238,,,,,,,,,,,,,, -13400,2.42302,1.5706071,,,,,,,,,,,,,, -13500,2.3198779,1.5903409,,,,,,,,,,,,,, -13600,2.6065993,1.5576247,,,,,,,,,,,,,, -13700,2.171973,1.6506147,,,,,,,,,,,,,, -13800,3.1307964,1.7002369,,,,,,,,,,,,,, -13900,2.6801124,1.5886798,,,,,,,,,,,,,, -14000,6.0700197,1.6369883,,,,,,,,,,,,,, -14100,5.0549426,1.6404419,,,,,,,,,,,,,, -14200,3.4853113,1.6087661,,,,,,,,,,,,,, -14300,3.2152493,1.6027389,,,,,,,,,,,,,, -14400,2.7972991,1.5714526,,,,,,,,,,,,,, -14500,3.642662,1.5207782,,,,,,,,,,,,,, -14568,,,0.32302096,0.1124106191008606,0.6762659,0.1907180165480753,5348.0,0.39017347,0.1259114821359657,2472.0,11564.629160404204,12809.892186164856,11564.629160404204,1244.3145473003387,0.3751053810119629,0.0 -14600,3.1155074,1.7122701,,,,,,,,,,,,,, -14700,2.4573548,1.6130284,,,,,,,,,,,,,, -14800,1.7008929,1.5342846,,,,,,,,,,,,,, -14900,4.0469975,1.5821036,,,,,,,,,,,,,, -15000,2.758504,1.5747862,,,,,,,,,,,,,, -15100,3.4671519,1.5536985,,,,,,,,,,,,,, -15200,3.834559,1.646061,,,,,,,,,,,,,, -15300,3.0832121,1.5852176,,,,,,,,,,,,,, -15400,2.5252283,1.5381199,,,,,,,,,,,,,, -15500,4.4540687,1.5173521,,,,,,,,,,,,,, -15600,3.923683,1.6804298,,,,,,,,,,,,,, -15700,2.3030932,1.5232792,,,,,,,,,,,,,, -15800,2.2291207,1.5694327,,,,,,,,,,,,,, -15900,3.369693,1.5365063,,,,,,,,,,,,,, -16000,1.9274191,1.557719,,,,,,,,,,,,,, -16100,3.4874213,1.6035378,,,,,,,,,,,,,, -16200,2.6815424,1.6405617,,,,,,,,,,,,,, -16300,1.8781509,1.5256722,,,,,,,,,,,,,, -16370,,,0.2960854,0.1013837388079949,0.64389575,0.1853596840997518,5348.0,0.3699334,0.1198789429853959,2472.0,13004.699365615845,14381.36817574501,13004.699365615845,1375.6002094745636,0.4212281703948974,0.0 -16400,1.9612805,1.5440996,,,,,,,,,,,,,, -16500,1.953936,1.5420744,,,,,,,,,,,,,, -16600,4.2378273,1.5627742,,,,,,,,,,,,,, -16700,2.3164105,1.500579,,,,,,,,,,,,,, -16800,2.3962092,1.5572548,,,,,,,,,,,,,, -16900,3.8714406,1.5321198,,,,,,,,,,,,,, -17000,2.225771,1.5562239,,,,,,,,,,,,,, -17100,2.1940665,1.5294456,,,,,,,,,,,,,, -17200,1.8088136,1.6275331,,,,,,,,,,,,,, -17300,3.3690808,1.5811323,,,,,,,,,,,,,, -17400,3.7575696,1.5115478,,,,,,,,,,,,,, -17500,3.1005542,1.4830606,,,,,,,,,,,,,, -17600,1.7904577,1.4849931,,,,,,,,,,,,,, -17700,2.1187365,1.533541,,,,,,,,,,,,,, -17800,3.6292837,1.5411814,,,,,,,,,,,,,, -17900,2.917622,1.512539,,,,,,,,,,,,,, -18000,2.8763998,1.6348284,,,,,,,,,,,,,, -18100,4.1870074,1.5074018,,,,,,,,,,,,,, -18166,,,0.29231447,0.1036590924935779,0.61532944,0.1773270127537967,5348.0,0.35234314,0.1145166859626673,2472.0,14445.077192544935,15953.95691728592,14445.077192544935,1507.6907210350037,0.4677519798278808,0.0 -18200,1.7242769,1.5480369,,,,,,,,,,,,,, -18300,3.1126032,1.5321407,,,,,,,,,,,,,, -18400,2.2179973,1.4467533,,,,,,,,,,,,,, -18500,1.7137692,1.4895431,,,,,,,,,,,,,, -18600,2.7183955,1.4945841,,,,,,,,,,,,,, -18700,3.967495,1.5277716,,,,,,,,,,,,,, -18800,2.455252,1.4854983,,,,,,,,,,,,,, -18900,1.8555727,1.5233126,,,,,,,,,,,,,, -19000,3.754945,1.4606022,,,,,,,,,,,,,, -19100,2.2997527,1.4906483,,,,,,,,,,,,,, -19200,3.1540518,1.5111452,,,,,,,,,,,,,, -19300,1.7851592,1.5553107,,,,,,,,,,,,,, -19400,4.227742,1.5618778,,,,,,,,,,,,,, -19500,1.187541,1.4886507,,,,,,,,,,,,,, -19600,4.580837,1.5446671,,,,,,,,,,,,,, -19700,3.1732166,1.4524859,,,,,,,,,,,,,, -19800,2.7480087,1.5109434,,,,,,,,,,,,,, -19900,4.898711,1.5018779,,,,,,,,,,,,,, -19969,,,0.3047939,0.1021550014542954,0.6024862,0.1741120132848026,5348.0,0.3425359,0.1100075152844636,2472.0,15885.665563583374,17527.51263141632,15885.665563583374,1640.5355043411255,0.5157749652862549,0.0 -20000,2.2157617,1.4226075,,,,,,,,,,,,,, -20100,3.704382,1.4517368,,,,,,,,,,,,,, -20200,4.0508013,1.5146946,,,,,,,,,,,,,, -20300,2.6037297,1.4870082,,,,,,,,,,,,,, -20400,2.1796987,1.4953748,,,,,,,,,,,,,, -20500,2.3166158,1.4630798,,,,,,,,,,,,,, -20600,2.744584,1.482545,,,,,,,,,,,,,, -20700,3.6152349,1.5422212,,,,,,,,,,,,,, -20800,3.2638135,1.5717998,,,,,,,,,,,,,, -20900,5.2695355,1.4525341,,,,,,,,,,,,,, -21000,2.0191584,1.4933275,,,,,,,,,,,,,, -21100,3.912553,1.5742103,,,,,,,,,,,,,, -21200,2.455533,1.5009588,,,,,,,,,,,,,, -21300,1.9242533,1.5314934,,,,,,,,,,,,,, -21400,2.677846,1.5341356,,,,,,,,,,,,,, -21500,3.192833,1.5105462,,,,,,,,,,,,,, -21600,4.420935,1.4969788,,,,,,,,,,,,,, -21700,1.9492556,1.4844321,,,,,,,,,,,,,, -21738,,,0.29166546,0.0989527000290183,0.5862438,0.1679137260202554,5348.0,0.3305221,0.1073873215119939,2472.0,17326.005275011063,19101.038394212723,17326.005275011063,1773.6003246307373,0.5648343563079834,0.0 -21800,1.8737339,1.4349769,,,,,,,,,,,,,, -21900,2.8671157,1.4293501,,,,,,,,,,,,,, -22000,5.849254,1.4766556,,,,,,,,,,,,,, -22100,4.8592815,1.4730444,,,,,,,,,,,,,, -22200,2.514993,1.4442208,,,,,,,,,,,,,, -22300,2.9011319,1.4609357,,,,,,,,,,,,,, -22400,4.671394,1.4734957,,,,,,,,,,,,,, -22500,2.3089578,1.3977188,,,,,,,,,,,,,, -22600,2.8140895,1.5338043,,,,,,,,,,,,,, -22700,2.9441254,1.3793505,,,,,,,,,,,,,, -22800,1.6151469,1.4477165,,,,,,,,,,,,,, -22900,2.201367,1.3952136,,,,,,,,,,,,,, -23000,2.665208,1.4419425,,,,,,,,,,,,,, -23100,2.6154003,1.4344747,,,,,,,,,,,,,, -23200,4.1826897,1.4877838,,,,,,,,,,,,,, -23300,1.5585954,1.4529415,,,,,,,,,,,,,, -23400,9.494918,1.5211116,,,,,,,,,,,,,, -23500,3.294544,1.449667,,,,,,,,,,,,,, -23537,,,0.27430362,0.0944130054532359,0.57207096,0.1631636367147146,5348.0,0.32181597,0.1047874393191558,2472.0,18766.441838026047,20673.603135347366,18766.441838026047,1905.6069421768188,0.6118886470794678,0.0 -23600,2.2260492,1.3684766,,,,,,,,,,,,,, -23700,2.7372549,1.3742781,,,,,,,,,,,,,, -23800,3.7480567,1.4603211,,,,,,,,,,,,,, -23900,3.3611028,1.3907399,,,,,,,,,,,,,, -24000,2.132446,1.3902751,,,,,,,,,,,,,, -24100,1.8088716,1.4568824,,,,,,,,,,,,,, -24200,2.558017,1.4423292,,,,,,,,,,,,,, -24300,5.4985633,1.452022,,,,,,,,,,,,,, -24400,1.4240648,1.3358854,,,,,,,,,,,,,, -24500,2.3602514,1.4183054,,,,,,,,,,,,,, -24600,1.6971518,1.469193,,,,,,,,,,,,,, -24700,2.6768298,1.4001759,,,,,,,,,,,,,, -24800,3.413507,1.4804302,,,,,,,,,,,,,, -24900,2.1265368,1.3368288,,,,,,,,,,,,,, -25000,2.0879326,1.4358295,,,,,,,,,,,,,, -25100,2.6981385,1.3547153,,,,,,,,,,,,,, -25200,2.2378228,1.4388269,,,,,,,,,,,,,, -25300,2.9173691,1.3955234,,,,,,,,,,,,,, -25314,,,0.24832816,0.0862347451543431,0.554332,0.1595624511233188,5348.0,0.31350023,0.1015782097373712,2472.0,20206.777960777283,22244.84517145157,20206.777960777283,2036.3903830051424,0.6608626842498779,0.0 -25400,3.4582756,1.4839863,,,,,,,,,,,,,, -25500,10.237592,1.3833052,,,,,,,,,,,,,, -25600,2.2062337,1.4261118,,,,,,,,,,,,,, -25700,2.5833898,1.4527925,,,,,,,,,,,,,, -25800,2.254484,1.3398405,,,,,,,,,,,,,, -25900,2.6220608,1.3368739,,,,,,,,,,,,,, -26000,2.2182953,1.3943399,,,,,,,,,,,,,, -26100,2.0699399,1.4582828,,,,,,,,,,,,,, -26200,2.7122622,1.3473395,,,,,,,,,,,,,, -26300,6.747374,1.3865296,,,,,,,,,,,,,, -26400,3.0362935,1.4183433,,,,,,,,,,,,,, -26500,1.9753366,1.4517194,,,,,,,,,,,,,, -26600,2.814994,1.3998872,,,,,,,,,,,,,, -26700,6.9330964,1.3721505,,,,,,,,,,,,,, -26800,3.036257,1.34554,,,,,,,,,,,,,, -26900,2.0414038,1.3399901,,,,,,,,,,,,,, -27000,1.6560849,1.3588697,,,,,,,,,,,,,, -27091,,,0.22888745,0.0808210960118939,0.53358024,0.153885515124014,5348.0,0.29808536,0.096419068510958,2472.0,21646.736126184464,23817.134423017505,21646.736126184464,2168.593649625778,0.711815357208252,0.0 -27100,5.018439,1.379112,,,,,,,,,,,,,, -27200,2.838041,1.3927093,,,,,,,,,,,,,, -27300,4.4577136,1.3482045,,,,,,,,,,,,,, -27400,4.357057,1.38092,,,,,,,,,,,,,, -27500,4.6543407,1.3723664,,,,,,,,,,,,,, -27600,2.1419144,1.4105022,,,,,,,,,,,,,, -27700,2.5577,1.3719792,,,,,,,,,,,,,, -27800,1.7286617,1.3608673,,,,,,,,,,,,,, -27900,2.4235528,1.3173388,,,,,,,,,,,,,, -28000,2.0870976,1.4015524,,,,,,,,,,,,,, -28100,1.6531895,1.4157834,,,,,,,,,,,,,, -28200,2.0992556,1.3355936,,,,,,,,,,,,,, -28300,2.0007489,1.352572,,,,,,,,,,,,,, -28400,3.3830416,1.3888535,,,,,,,,,,,,,, -28500,3.1562226,1.3729268,,,,,,,,,,,,,, -28600,2.7961552,1.3502296,,,,,,,,,,,,,, -28700,2.9767363,1.3541294,,,,,,,,,,,,,, -28800,2.4049506,1.3674611,,,,,,,,,,,,,, -28847,,,0.21742319,0.0746847224637162,0.5180734,0.148643038512411,5348.0,0.2901405,0.0931692157699104,2472.0,23087.10746240616,25389.833054065704,23087.10746240616,2300.7924242019653,0.7672219276428223,0.0 -28900,2.1883698,1.3765342,,,,,,,,,,,,,, -29000,1.8508164,1.3691014,,,,,,,,,,,,,, -29100,2.2292895,1.380626,,,,,,,,,,,,,, -29200,2.634272,1.3982683,,,,,,,,,,,,,, -29300,2.0048604,1.3647456,,,,,,,,,,,,,, -29400,2.5392363,1.3709216,,,,,,,,,,,,,, -29500,1.835464,1.3391036,,,,,,,,,,,,,, -29600,3.0834997,1.3924363,,,,,,,,,,,,,, -29700,4.1216884,1.4392048,,,,,,,,,,,,,, -29800,2.9195771,1.3761408,,,,,,,,,,,,,, -29900,5.3063345,1.3364338,,,,,,,,,,,,,, -30000,2.5805295,1.2901117,,,,,,,,,,,,,, -30100,2.0644822,1.3235661,,,,,,,,,,,,,, -30200,1.9993742,1.3474275,,,,,,,,,,,,,, -30300,3.3157551,1.3261554,,,,,,,,,,,,,, -30400,3.9792976,1.4194243,,,,,,,,,,,,,, -30500,2.0503073,1.331744,,,,,,,,,,,,,, -30600,5.026739,1.3210588,,,,,,,,,,,,,, -30628,,,0.2305872,0.0798394435777431,0.5070502,0.1471948405534047,5348.0,0.28189483,0.0909755651697032,2472.0,24527.48431277275,26961.279413223267,24527.48431277275,2431.742198228836,0.8124017715454102,0.0 -30700,3.9210231,1.37498,,,,,,,,,,,,,, -30800,2.6839173,1.3336145,,,,,,,,,,,,,, -30900,2.7781954,1.3070526,,,,,,,,,,,,,, -31000,3.5955129,1.4156748,,,,,,,,,,,,,, -31100,2.0107894,1.2828325,,,,,,,,,,,,,, -31200,2.6907134,1.2668085,,,,,,,,,,,,,, -31300,3.1881163,1.2967743,,,,,,,,,,,,,, -31400,2.5609434,1.3231795,,,,,,,,,,,,,, -31500,1.6868978,1.3376687,,,,,,,,,,,,,, -31600,3.5618317,1.326047,,,,,,,,,,,,,, -31700,1.7764068,1.3051203,,,,,,,,,,,,,, -31800,3.8656745,1.3152419,,,,,,,,,,,,,, -31900,3.490303,1.3132591,,,,,,,,,,,,,, -32000,5.2346005,1.2640686,,,,,,,,,,,,,, -32100,8.513938,1.3237596,,,,,,,,,,,,,, -32200,1.870399,1.3522264,,,,,,,,,,,,,, -32300,1.5213444,1.2905583,,,,,,,,,,,,,, -32400,3.2244134,1.298488,,,,,,,,,,,,,, -32427,,,0.20866887,0.0707545890102496,0.4951226,0.1422226942274829,5348.0,0.27268323,0.0883553713972335,2472.0,25967.66591262817,28531.426872015,25967.66591262817,2561.5807065963745,0.8636722564697266,0.0 -32500,2.2492063,1.3141037,,,,,,,,,,,,,, -32600,2.022898,1.2729625,,,,,,,,,,,,,, -32700,2.750699,1.3251216,,,,,,,,,,,,,, -32800,3.2717335,1.3717716,,,,,,,,,,,,,, -32900,2.3952148,1.3188732,,,,,,,,,,,,,, -33000,2.5356965,1.2972109,,,,,,,,,,,,,, -33100,1.5524023,1.2889854,,,,,,,,,,,,,, -33200,1.696713,1.2815287,,,,,,,,,,,,,, -33300,2.8857896,1.3231729,,,,,,,,,,,,,, -33400,2.423259,1.3302379,,,,,,,,,,,,,, -33500,1.6857424,1.279244,,,,,,,,,,,,,, -33600,2.4045746,1.3078536,,,,,,,,,,,,,, -33700,1.8491238,1.208808,,,,,,,,,,,,,, -33800,2.52126,1.3152084,,,,,,,,,,,,,, -33900,3.86889,1.3214546,,,,,,,,,,,,,, -34000,2.9470925,1.2687355,,,,,,,,,,,,,, -34100,5.214308,1.2805915,,,,,,,,,,,,,, -34173,,,0.21239454,0.0721840662868767,0.47541836,0.1369222896975197,5348.0,0.25914782,0.0833993459671358,2472.0,27408.01134252548,30103.295015335083,27408.01134252548,2692.9797925949097,0.9129633903503418,0.0 -34200,3.252927,1.237434,,,,,,,,,,,,,, -34300,2.300517,1.3085362,,,,,,,,,,,,,, -34400,1.9670256,1.2676315,,,,,,,,,,,,,, -34500,5.0336175,1.2972143,,,,,,,,,,,,,, -34600,3.5498044,1.2826614,,,,,,,,,,,,,, -34700,2.0048327,1.3038566,,,,,,,,,,,,,, -34800,2.1913335,1.2713896,,,,,,,,,,,,,, -34900,3.7771518,1.2922783,,,,,,,,,,,,,, -35000,5.121022,1.280906,,,,,,,,,,,,,, -35100,5.566214,1.2531986,,,,,,,,,,,,,, -35200,2.5523,1.2434317,,,,,,,,,,,,,, -35300,3.884917,1.2298523,,,,,,,,,,,,,, -35400,2.5031083,1.2354361,,,,,,,,,,,,,, -35500,2.0341086,1.3008721,,,,,,,,,,,,,, -35600,2.688349,1.2878748,,,,,,,,,,,,,, -35700,1.8631344,1.2895617,,,,,,,,,,,,,, -35800,2.5278823,1.2568163,,,,,,,,,,,,,, -35900,4.2613807,1.27098,,,,,,,,,,,,,, -35938,,,0.20002264,0.0666794547992796,0.46640846,0.1356961487588943,5348.0,0.2501728,0.0803119858631405,2472.0,28848.267106056213,31675.596952438354,28848.267106056213,2824.903109550476,0.9611873626708984,0.0 -36000,2.0855563,1.2209781,,,,,,,,,,,,,, -36100,2.277613,1.2391404,,,,,,,,,,,,,, -36200,2.9418259,1.2470715,,,,,,,,,,,,,, -36300,6.1638145,1.1888083,,,,,,,,,,,,,, -36400,2.5981278,1.2916989,,,,,,,,,,,,,, -36500,2.7353551,1.266248,,,,,,,,,,,,,, -36600,1.6357896,1.2587373,,,,,,,,,,,,,, -36700,3.1434395,1.3299625,,,,,,,,,,,,,, -36800,3.5770545,1.2622198,,,,,,,,,,,,,, -36900,1.7266582,1.2355268,,,,,,,,,,,,,, -37000,2.4963925,1.2483715,,,,,,,,,,,,,, -37100,6.5937324,1.2891723,,,,,,,,,,,,,, -37200,2.3034153,1.2051512,,,,,,,,,,,,,, -37300,4.276433,1.2739029,,,,,,,,,,,,,, -37400,6.642619,1.2051926,,,,,,,,,,,,,, -37500,3.4993474,1.2433226,,,,,,,,,,,,,, -37600,3.007049,1.258966,,,,,,,,,,,,,, -37700,1.6916414,1.2215627,,,,,,,,,,,,,, -37722,,,0.15005884,0.0538370078820864,0.4560848,0.1316218851675565,5348.0,0.24341933,0.0786870594926167,2472.0,30288.733597755432,33249.00675344467,30288.733597755432,2957.7188007831573,1.0137641429901123,0.0 -37800,2.2352738,1.2707323,,,,,,,,,,,,,, -37900,4.846207,1.1948335,,,,,,,,,,,,,, -38000,2.4782212,1.2534602,,,,,,,,,,,,,, -38100,2.3231022,1.2264944,,,,,,,,,,,,,, -38200,3.1351254,1.2162018,,,,,,,,,,,,,, -38300,2.6379375,1.2820777,,,,,,,,,,,,,, -38400,2.9155157,1.2210474,,,,,,,,,,,,,, -38500,2.2097201,1.1932077,,,,,,,,,,,,,, -38600,5.430786,1.2278643,,,,,,,,,,,,,, -38700,1.9949516,1.2386808,,,,,,,,,,,,,, -38800,2.4438336,1.2168628,,,,,,,,,,,,,, -38900,2.1433482,1.2215295,,,,,,,,,,,,,, -39000,2.189992,1.2167411,,,,,,,,,,,,,, -39100,2.3877618,1.2177896,,,,,,,,,,,,,, -39200,2.6079893,1.1899416,,,,,,,,,,,,,, -39300,1.6402022,1.2005407,,,,,,,,,,,,,, -39400,3.6609797,1.2059039,,,,,,,,,,,,,, -39500,2.9044676,1.1950566,,,,,,,,,,,,,, -39509,,,0.17008072,0.0592331735338218,0.44051743,0.1278662251272,5348.0,0.23879972,0.0763715394146202,2472.0,31729.1388194561,34820.944878816605,31729.1388194561,3089.12416434288,1.0664176940917969,0.0 -39600,2.1438193,1.2078524,,,,,,,,,,,,,, -39700,3.2984848,1.2403901,,,,,,,,,,,,,, -39800,3.5226316,1.2051817,,,,,,,,,,,,,, -39900,2.3160512,1.269947,,,,,,,,,,,,,, -40000,2.6796317,1.1776965,,,,,,,,,,,,,, -40100,5.896298,1.281092,,,,,,,,,,,,,, -40200,2.2972124,1.2148613,,,,,,,,,,,,,, -40300,3.8408937,1.1789999,,,,,,,,,,,,,, -40400,2.6197686,1.1747198,,,,,,,,,,,,,, -40500,2.8576992,1.2233156,,,,,,,,,,,,,, -40600,4.429656,1.2069782,,,,,,,,,,,,,, -40700,3.7871127,1.2011646,,,,,,,,,,,,,, -40800,2.7789724,1.1968545,,,,,,,,,,,,,, -40900,2.5055826,1.2014581,,,,,,,,,,,,,, -41000,4.639225,1.2589549,,,,,,,,,,,,,, -41100,2.0932946,1.175172,,,,,,,,,,,,,, -41200,3.606099,1.1759521,,,,,,,,,,,,,, -41278,,,0.21692227,0.0749608110087165,0.43239108,0.1240429825154233,5348.0,0.23437671,0.074807547782991,2472.0,33169.0481672287,36391.13930130005,33169.0481672287,3219.284531354904,1.1159639358520508,0.0 -41300,3.6761582,1.2324854,,,,,,,,,,,,,, -41400,2.0774283,1.1871362,,,,,,,,,,,,,, -41500,2.0553803,1.2002705,,,,,,,,,,,,,, -41600,1.5552367,1.1235679,,,,,,,,,,,,,, -41700,1.7579101,1.1852146,,,,,,,,,,,,,, -41800,1.7808543,1.1584316,,,,,,,,,,,,,, -41900,1.9944907,1.1762139,,,,,,,,,,,,,, -42000,1.9940381,1.2305039,,,,,,,,,,,,,, -42100,2.0984652,1.2191978,,,,,,,,,,,,,, -42200,1.9511794,1.1599733,,,,,,,,,,,,,, -42300,4.3487735,1.1906337,,,,,,,,,,,,,, -42400,2.5725913,1.1867044,,,,,,,,,,,,,, -42500,3.0103676,1.1534028,,,,,,,,,,,,,, -42600,2.9090593,1.2143291,,,,,,,,,,,,,, -42700,2.0068152,1.1527586,,,,,,,,,,,,,, -42800,4.141611,1.175392,,,,,,,,,,,,,, -42900,5.9889746,1.1455343,,,,,,,,,,,,,, -43000,1.8559194,1.1963103,,,,,,,,,,,,,, -43028,,,0.22430344,0.0763863059782276,0.4287209,0.123058207903299,5348.0,0.22969964,0.0736904109032559,2472.0,34608.96385860443,37962.72094845772,34608.96385860443,3350.8226585388184,1.16986346244812,0.0 -43100,9.347398,1.141843,,,,,,,,,,,,,, -43200,2.952062,1.147636,,,,,,,,,,,,,, -43300,1.9802566,1.2147852,,,,,,,,,,,,,, -43400,2.2905948,1.1146485,,,,,,,,,,,,,, -43500,1.7701993,1.1861928,,,,,,,,,,,,,, -43600,4.3219256,1.128306,,,,,,,,,,,,,, -43700,2.2062466,1.16573,,,,,,,,,,,,,, -43800,3.5725725,1.2044322,,,,,,,,,,,,,, -43900,3.1032784,1.1781485,,,,,,,,,,,,,, -44000,3.9848642,1.1399912,,,,,,,,,,,,,, -44100,3.9128952,1.2048758,,,,,,,,,,,,,, -44200,2.3122158,1.1601758,,,,,,,,,,,,,, -44300,8.234175,1.1623975,,,,,,,,,,,,,, -44400,3.13878,1.1482607,,,,,,,,,,,,,, -44500,3.157094,1.1595961,,,,,,,,,,,,,, -44600,3.3220623,1.1573782,,,,,,,,,,,,,, -44700,3.419161,1.12581,,,,,,,,,,,,,, -44794,,,0.2613986,0.0901661879719919,0.4249241,0.1222375623931954,5348.0,0.22824307,0.0731826214124672,2472.0,36049.72664427757,39536.94931817055,36049.72664427757,3484.164263486862,1.2190589904785156,0.0 -44800,1.7931713,1.1617846,,,,,,,,,,,,,, -44900,1.7942293,1.1012522,,,,,,,,,,,,,, -45000,2.1364493,1.1547511,,,,,,,,,,,,,, -45100,3.8112147,1.1678705,,,,,,,,,,,,,, -45200,6.1479683,1.1755713,,,,,,,,,,,,,, -45300,2.830638,1.1126196,,,,,,,,,,,,,, -45400,1.626453,1.1242585,,,,,,,,,,,,,, -45500,1.7778716,1.151994,,,,,,,,,,,,,, -45600,3.7093747,1.1625477,,,,,,,,,,,,,, -45700,2.364338,1.1151094,,,,,,,,,,,,,, -45800,4.812857,1.2120728,,,,,,,,,,,,,, -45900,3.4821835,1.1721866,,,,,,,,,,,,,, -46000,4.4931827,1.1564907,,,,,,,,,,,,,, -46100,3.5957367,1.1588371,,,,,,,,,,,,,, -46200,6.5694737,1.1716566,,,,,,,,,,,,,, -46300,2.9756002,1.1511474,,,,,,,,,,,,,, -46400,3.3776476,1.1855053,,,,,,,,,,,,,, -46500,1.5809561,1.0813704,,,,,,,,,,,,,, -46564,,,0.23482683,0.0778161778991928,0.42247275,0.1216389739034727,5348.0,0.226913,0.0726951435013101,2472.0,37489.63005280495,41106.78757286072,37489.63005280495,3613.956609010696,1.2820138931274414,0.0 -46600,2.5312173,1.1699206,,,,,,,,,,,,,, -46700,3.759405,1.1667657,,,,,,,,,,,,,, -46800,3.3040366,1.1825781,,,,,,,,,,,,,, -46900,2.2483091,1.164928,,,,,,,,,,,,,, -47000,3.1529021,1.1355599,,,,,,,,,,,,,, -47100,1.5340576,1.1007953,,,,,,,,,,,,,, -47200,3.2353375,1.1495849,,,,,,,,,,,,,, -47300,1.5537138,1.1087275,,,,,,,,,,,,,, -47400,1.3677036,1.110591,,,,,,,,,,,,,, -47500,2.0644357,1.147578,,,,,,,,,,,,,, -47600,3.2833693,1.1521354,,,,,,,,,,,,,, -47700,2.145599,1.1287023,,,,,,,,,,,,,, -47800,9.429043,1.1940804,,,,,,,,,,,,,, -47900,2.9377143,1.1242009,,,,,,,,,,,,,, -48000,2.5708025,1.143008,,,,,,,,,,,,,, -48100,1.375596,1.1502788,,,,,,,,,,,,,, -48200,2.5162866,1.2083972,,,,,,,,,,,,,, -48300,3.1899545,1.1264668,,,,,,,,,,,,,, -48322,,,0.21850312,0.0765535056172269,0.42252696,0.1217934483524334,5348.0,0.22681805,0.0728576361383624,2472.0,38930.12107872963,42681.298971414566,38930.12107872963,3747.8457283973694,1.3374078273773191,0.0 -48400,2.2012398,1.1098186,,,,,,,,,,,,,, -48500,3.4187167,1.1587291,,,,,,,,,,,,,, -48600,2.4667897,1.1859493,,,,,,,,,,,,,, -48700,4.8254747,1.2130209,,,,,,,,,,,,,, -48800,2.1052098,1.106041,,,,,,,,,,,,,, -48900,3.4375129,1.1414121,,,,,,,,,,,,,, -49000,3.3616028,1.1713916,,,,,,,,,,,,,, -49100,2.658111,1.1981124,,,,,,,,,,,,,, -49200,1.9430327,1.1625795,,,,,,,,,,,,,, -49300,2.4433174,1.1973493,,,,,,,,,,,,,, -49400,2.1948607,1.1870267,,,,,,,,,,,,,, -49500,4.882152,1.155718,,,,,,,,,,,,,, -49600,2.298471,1.1183537,,,,,,,,,,,,,, -49700,5.2575254,1.1546537,,,,,,,,,,,,,, -49800,4.2814784,1.1231865,,,,,,,,,,,,,, -49900,3.4586515,1.2152569,,,,,,,,,,,,,, -50000,3.4008996,1.1468763,,,,,,,,,,,,,, -50079,,,0.19701824,0.0689132018836529,0.42258406,0.1218513762707937,5348.0,0.2268369,0.0729388824568886,2472.0,40370.472751140594,44255.343381881714,40370.472751140594,3881.402559518814,1.3948400020599363,0.0 -50100,2.733703,1.1472267,,,,,,,,,,,,,, -50200,3.6354227,1.1159369,,,,,,,,,,,,,, -50300,2.4819696,1.1307584,,,,,,,,,,,,,, -50400,2.1037498,1.1452318,,,,,,,,,,,,,, -50500,5.674554,1.1755155,,,,,,,,,,,,,, -50600,2.4012651,1.1297774,,,,,,,,,,,,,, -50700,3.0166533,1.1388375,,,,,,,,,,,,,, -50800,2.730151,1.1748172,,,,,,,,,,,,,, -50900,2.9794497,1.168679,,,,,,,,,,,,,, -51000,4.057518,1.1348189,,,,,,,,,,,,,, -51100,3.9989173,1.1174805,,,,,,,,,,,,,, -51200,2.4257293,1.232544,,,,,,,,,,,,,, -51300,1.9292037,1.1242583,,,,,,,,,,,,,, -51400,5.2981977,1.1448026,,,,,,,,,,,,,, -51500,4.1391115,1.1428359,,,,,,,,,,,,,, -51600,5.702473,1.1447918,,,,,,,,,,,,,, -51700,2.405723,1.123324,,,,,,,,,,,,,, -51800,4.1800437,1.190918,,,,,,,,,,,,,, -51869,,,0.22619057,0.0782449070478528,0.4225461,0.1218513762707937,5348.0,0.22682315,0.0728982592976255,2472.0,41810.91124749184,45827.35194015503,41810.91124749184,4012.835418462753,1.4524447917938232,0.0 -51900,2.1477,1.1526926,,,,,,,,,,,,,, -52000,8.1860485,1.1880256,,,,,,,,,,,,,, -52100,1.9955735,1.1438161,,,,,,,,,,,,,, -52200,1.8715823,1.1724622,,,,,,,,,,,,,, -52300,2.006199,1.0771673,,,,,,,,,,,,,, -52400,2.9390562,1.1366509,,,,,,,,,,,,,, -52500,6.3436313,1.1387745,,,,,,,,,,,,,, -52600,1.6763619,1.2027599,,,,,,,,,,,,,, -52700,1.9659889,1.189013,,,,,,,,,,,,,, -52800,2.2014244,1.1245428,,,,,,,,,,,,,, -52900,2.9989903,1.1777728,,,,,,,,,,,,,, -53000,3.2983809,1.1153691,,,,,,,,,,,,,, -53100,2.0897758,1.1390421,,,,,,,,,,,,,, -53200,2.5380428,1.154625,,,,,,,,,,,,,, -53300,2.899253,1.1742717,,,,,,,,,,,,,, -53400,3.569539,1.1412053,,,,,,,,,,,,,, -53500,2.0203795,1.182277,,,,,,,,,,,,,, -53600,3.174188,1.1264837,,,,,,,,,,,,,, -53618,,,0.21388717,0.0744126868235698,0.42257115,0.1218127576585535,5348.0,0.2268347,0.0729185708772571,2472.0,43250.74343371391,47398.51116228104,43250.74343371391,4143.967495203018,1.5726919174194336,0.0 -53700,2.5912793,1.0802282,,,,,,,,,,,,,, -53800,2.0994432,1.1526407,,,,,,,,,,,,,, -53900,2.1073854,1.1869255,,,,,,,,,,,,,, -54000,7.3635674,1.1257236,,,,,,,,,,,,,, -54100,1.8330091,1.1394753,,,,,,,,,,,,,, -54200,9.830779,1.1129842,,,,,,,,,,,,,, -54300,3.0193262,1.1857193,,,,,,,,,,,,,, -54400,3.1978862,1.1613654,,,,,,,,,,,,,, -54500,3.455709,1.1355755,,,,,,,,,,,,,, -54600,2.8603172,1.1861067,,,,,,,,,,,,,, -54700,2.7625842,1.1860484,,,,,,,,,,,,,, -54800,5.677431,1.2099091,,,,,,,,,,,,,, -54900,3.8333993,1.076345,,,,,,,,,,,,,, -55000,3.318021,1.1266305,,,,,,,,,,,,,, -55100,2.1428902,1.182273,,,,,,,,,,,,,, -55200,3.0428145,1.1433933,,,,,,,,,,,,,, -55300,2.858376,1.1053675,,,,,,,,,,,,,, -55400,5.759934,1.1534581,,,,,,,,,,,,,, -55401,,,0.21749085,0.0768450514759327,0.42252436,0.1218417216177336,5348.0,0.22681281,0.0728982592976255,2472.0,44691.26508259773,48970.09406757355,44691.26508259773,4274.887751102448,1.6357917785644531,0.0 -55500,2.1320882,1.170768,,,,,,,,,,,,,, -55600,2.6095743,1.1201618,,,,,,,,,,,,,, -55700,2.2812479,1.1455312,,,,,,,,,,,,,, -55800,5.4286256,1.1363753,,,,,,,,,,,,,, -55900,4.562988,1.1331922,,,,,,,,,,,,,, -56000,2.1448805,1.2026981,,,,,,,,,,,,,, -56100,3.17115,1.1602777,,,,,,,,,,,,,, -56200,3.73157,1.2351124,,,,,,,,,,,,,, -56300,1.876375,1.170233,,,,,,,,,,,,,, -56400,2.4084947,1.1660596,,,,,,,,,,,,,, -56500,2.8457193,1.1341507,,,,,,,,,,,,,, -56600,3.2830698,1.1448203,,,,,,,,,,,,,, -56700,2.2661626,1.1413167,,,,,,,,,,,,,, -56800,2.3884525,1.1559525,,,,,,,,,,,,,, -56900,3.361487,1.1660213,,,,,,,,,,,,,, -57000,1.892412,1.1221881,,,,,,,,,,,,,, -57100,3.9046557,1.1946845,,,,,,,,,,,,,, -57190,,,0.22159734,0.0764886128364389,0.42251283,0.1218127576585535,5348.0,0.22681028,0.072877947717994,2472.0,46131.35928487778,50545.039157152176,46131.35928487778,4409.608050823212,1.689574956893921,0.0 -57200,3.026337,1.1572165,,,,,,,,,,,,,, -57300,3.4721272,1.1781429,,,,,,,,,,,,,, -57400,3.4079318,1.1310569,,,,,,,,,,,,,, -57500,1.8756973,1.157388,,,,,,,,,,,,,, -57600,7.216372,1.2042435,,,,,,,,,,,,,, -57700,3.4842346,1.180675,,,,,,,,,,,,,, -57800,4.0374465,1.140808,,,,,,,,,,,,,, -57900,3.874406,1.1560063,,,,,,,,,,,,,, -58000,4.3898716,1.1669785,,,,,,,,,,,,,, -58100,2.8787203,1.135456,,,,,,,,,,,,,, -58200,9.623161,1.1349812,,,,,,,,,,,,,, -58300,2.2994459,1.1820766,,,,,,,,,,,,,, -58400,7.7999997,1.1882815,,,,,,,,,,,,,, -58500,1.8300253,1.1345738,,,,,,,,,,,,,, -58600,3.6922503,1.1316732,,,,,,,,,,,,,, -58700,1.6026765,1.1346258,,,,,,,,,,,,,, -58800,2.034253,1.1616706,,,,,,,,,,,,,, -58900,1.6049023,1.177965,,,,,,,,,,,,,, -58964,,,0.23678349,0.0812758261506587,0.42260745,0.1218610309238537,5348.0,0.22684751,0.0729795056161517,2472.0,47571.82388663292,52122.92017412186,47571.82388663292,4546.888776302338,1.7484843730926514,0.0 -59000,2.4320796,1.1792858,,,,,,,,,,,,,, -59100,3.6296303,1.1687647,,,,,,,,,,,,,, -59200,2.655258,1.1355633,,,,,,,,,,,,,, -59300,1.5416077,1.1391968,,,,,,,,,,,,,, -59400,4.3885427,1.1574324,,,,,,,,,,,,,, -59500,2.2828236,1.188089,,,,,,,,,,,,,, -59600,7.057581,1.1341178,,,,,,,,,,,,,, -59700,1.6859819,1.1564615,,,,,,,,,,,,,, -59800,3.5557625,1.163299,,,,,,,,,,,,,, -59900,4.371491,1.205709,,,,,,,,,,,,,, -60000,4.863319,1.1290767,,,,,,,,,,,,,, -60100,3.207778,1.1672331,,,,,,,,,,,,,, -60200,2.0953195,1.1750972,,,,,,,,,,,,,, -60300,1.7341896,1.2252028,,,,,,,,,,,,,, -60400,2.1569643,1.1338902,,,,,,,,,,,,,, -60500,4.367263,1.1635665,,,,,,,,,,,,,, -60600,1.9014808,1.2163821,,,,,,,,,,,,,, -60700,3.1702085,1.2051802,,,,,,,,,,,,,, -60715,,,0.2431013,0.0816443219812892,0.42257443,0.1218320669646736,5348.0,0.2268356,0.0729185708772571,2472.0,49012.130105018616,53696.184061050415,49012.130105018616,4679.714070558548,1.80523419380188,0.0 -60800,2.3645024,1.1190033,,,,,,,,,,,,,, -60900,4.3348083,1.1823281,,,,,,,,,,,,,, -61000,2.943033,1.1624844,,,,,,,,,,,,,, -61100,3.0639234,1.0998844,,,,,,,,,,,,,, -61200,2.937921,1.1296096,,,,,,,,,,,,,, -61300,3.7053065,1.1221472,,,,,,,,,,,,,, -61400,2.3185744,1.161425,,,,,,,,,,,,,, -61500,2.0291195,1.1848115,,,,,,,,,,,,,, -61600,1.8809422,1.1661923,,,,,,,,,,,,,, -61700,1.6025213,1.1848338,,,,,,,,,,,,,, -61800,2.4720807,1.1083226,,,,,,,,,,,,,, -61900,4.808458,1.1584458,,,,,,,,,,,,,, -62000,3.2793834,1.1418455,,,,,,,,,,,,,, -62100,5.287083,1.1533198,,,,,,,,,,,,,, -62200,2.450522,1.1813105,,,,,,,,,,,,,, -62300,2.1994278,1.1173295,,,,,,,,,,,,,, -62400,2.8430257,1.145706,,,,,,,,,,,,,, -62476,,,0.21817155,0.076703208016679,0.42259833,0.1218610309238537,5348.0,0.22685315,0.0729185708772571,2472.0,50452.07918572426,55266.86240553856,50452.07918572426,4810.312841653824,1.8586266040802,0.0 -62500,3.8762467,1.185453,,,,,,,,,,,,,, -62600,4.1034484,1.0881213,,,,,,,,,,,,,, -62700,3.3193886,1.1324604,,,,,,,,,,,,,, -62800,1.6703168,1.188036,,,,,,,,,,,,,, -62900,4.083205,1.1382143,,,,,,,,,,,,,, -63000,2.4488344,1.1298726,,,,,,,,,,,,,, -63100,2.5898387,1.2247764,,,,,,,,,,,,,, -63200,2.7756996,1.1491699,,,,,,,,,,,,,, -63300,2.825078,1.2104884,,,,,,,,,,,,,, -63400,4.106704,1.1662264,,,,,,,,,,,,,, -63500,4.103425,1.1955718,,,,,,,,,,,,,, -63600,3.719975,1.1921256,,,,,,,,,,,,,, -63700,1.7445159,1.1699611,,,,,,,,,,,,,, -63800,4.4366117,1.1448071,,,,,,,,,,,,,, -63900,1.9847214,1.1511854,,,,,,,,,,,,,, -64000,2.465401,1.137477,,,,,,,,,,,,,, -64100,4.7981706,1.1611968,,,,,,,,,,,,,, -64200,2.6756163,1.1605272,,,,,,,,,,,,,, -64265,,,0.22722098,0.0794977408607548,0.4225785,0.1218224123116135,5348.0,0.22683159,0.0729185708772571,2472.0,51892.149191617966,56837.77284407616,51892.149191617966,4941.022246599197,1.9132137298583984,0.0 -64300,2.8110316,1.1782596,,,,,,,,,,,,,, -64400,1.8419876,1.1366131,,,,,,,,,,,,,, -64500,3.0068405,1.1542914,,,,,,,,,,,,,, -64600,2.6651897,1.1295164,,,,,,,,,,,,,, -64700,1.5717692,1.148285,,,,,,,,,,,,,, -64800,3.4087555,1.1356332,,,,,,,,,,,,,, -64900,5.2207994,1.122827,,,,,,,,,,,,,, -65000,1.813066,1.1533285,,,,,,,,,,,,,, -65100,2.4005497,1.1246667,,,,,,,,,,,,,, -65200,4.296851,1.1508813,,,,,,,,,,,,,, -65300,2.598357,1.1953092,,,,,,,,,,,,,, -65400,4.428428,1.2065583,,,,,,,,,,,,,, -65500,1.5967836,1.1101025,,,,,,,,,,,,,, -65600,3.5658538,1.1466534,,,,,,,,,,,,,, -65700,3.2606075,1.1210442,,,,,,,,,,,,,, -65800,3.165149,1.0951533,,,,,,,,,,,,,, -65900,3.8460686,1.1735529,,,,,,,,,,,,,, -66000,2.0311038,1.21788,,,,,,,,,,,,,, -66011,,,0.22163488,0.0767140561531937,0.42258117,0.1218127576585535,5348.0,0.22683603,0.0729185708772571,2472.0,53332.38691663742,58409.85575699806,53332.38691663742,5072.738001346588,1.9674947261810305,0.0 -66100,3.8465545,1.1760159,,,,,,,,,,,,,, -66200,3.6723337,1.1586971,,,,,,,,,,,,,, -66300,2.801886,1.2024193,,,,,,,,,,,,,, -66400,2.7665505,1.1607151,,,,,,,,,,,,,, -66500,4.2815027,1.1655546,,,,,,,,,,,,,, -66600,2.084383,1.1458493,,,,,,,,,,,,,, -66700,3.066862,1.1414406,,,,,,,,,,,,,, -66800,1.9154719,1.1506689,,,,,,,,,,,,,, -66900,6.4264994,1.1605676,,,,,,,,,,,,,, -67000,11.372747,1.1661901,,,,,,,,,,,,,, -67100,3.766624,1.166751,,,,,,,,,,,,,, -67200,2.3110774,1.1315175,,,,,,,,,,,,,, -67300,2.399348,1.1420774,,,,,,,,,,,,,, -67400,3.974199,1.1924882,,,,,,,,,,,,,, -67500,1.6888963,1.149469,,,,,,,,,,,,,, -67600,2.809201,1.153397,,,,,,,,,,,,,, -67700,3.6458616,1.1501405,,,,,,,,,,,,,, -67774,,,0.23132047,0.0787533378502251,0.4225996,0.1218610309238537,5348.0,0.22684477,0.0729591940365202,2472.0,54772.54679131508,59981.093438625336,54772.54679131508,5203.679777383804,2.0264124870300293,0.0 -67800,2.2264147,1.1470956,,,,,,,,,,,,,, -67900,3.0609467,1.1437808,,,,,,,,,,,,,, -68000,2.5651968,1.13522,,,,,,,,,,,,,, -68100,3.0055046,1.1821085,,,,,,,,,,,,,, -68200,2.3683357,1.194066,,,,,,,,,,,,,, -68300,2.717105,1.1614184,,,,,,,,,,,,,, -68400,4.515357,1.1520364,,,,,,,,,,,,,, -68500,1.9546171,1.1669754,,,,,,,,,,,,,, -68600,4.707743,1.2180188,,,,,,,,,,,,,, -68700,1.9136918,1.1474493,,,,,,,,,,,,,, -68800,3.05659,1.1505024,,,,,,,,,,,,,, -68900,1.6625074,1.1176311,,,,,,,,,,,,,, -69000,1.8064642,1.1262068,,,,,,,,,,,,,, -69100,1.3810722,1.145002,,,,,,,,,,,,,, -69200,2.0831168,1.1491636,,,,,,,,,,,,,, -69300,2.8253784,1.1518677,,,,,,,,,,,,,, -69400,2.602095,1.1307696,,,,,,,,,,,,,, -69500,1.6085724,1.1559668,,,,,,,,,,,,,, -69539,,,0.2290587,0.0786990451912971,0.4225713,0.1218224123116135,5348.0,0.2268291,0.0729185708772571,2472.0,56212.43574166298,61553.21344566345,56212.43574166298,5335.775634050369,2.084322690963745,0.0 -69600,2.6639552,1.1582067,,,,,,,,,,,,,, -69700,5.493202,1.1573026,,,,,,,,,,,,,, -69800,2.1604848,1.1260226,,,,,,,,,,,,,, -69900,1.5426872,1.0772452,,,,,,,,,,,,,, -70000,4.098806,1.2006718,,,,,,,,,,,,,, -70100,5.261758,1.19547,,,,,,,,,,,,,, -70200,3.976138,1.1773174,,,,,,,,,,,,,, -70300,2.6764421,1.1747305,,,,,,,,,,,,,, -70400,1.622616,1.169966,,,,,,,,,,,,,, -70500,5.735535,1.1192516,,,,,,,,,,,,,, -70600,2.057951,1.1572425,,,,,,,,,,,,,, -70700,5.5379915,1.1873928,,,,,,,,,,,,,, -70800,11.455913,1.1472019,,,,,,,,,,,,,, -70900,1.6669604,1.1930556,,,,,,,,,,,,,, -71000,2.3148243,1.176155,,,,,,,,,,,,,, -71100,1.6451184,1.1223189,,,,,,,,,,,,,, -71200,1.7283362,1.1251101,,,,,,,,,,,,,, -71300,2.2950983,1.166664,,,,,,,,,,,,,, -71324,,,0.2350017,0.0823583454169111,0.4226408,0.1218513762707937,5348.0,0.22685763,0.0729591940365202,2472.0,57652.62449741364,63124.72708106041,57652.62449741364,5466.968840122223,2.1383402347564697,0.0 -71400,1.6574458,1.128549,,,,,,,,,,,,,, -71500,1.7853694,1.0989736,,,,,,,,,,,,,, -71600,4.7188,1.1192973,,,,,,,,,,,,,, -71700,3.6418934,1.1396421,,,,,,,,,,,,,, -71800,3.518211,1.1635352,,,,,,,,,,,,,, -71900,2.201376,1.140188,,,,,,,,,,,,,, -72000,4.2583804,1.1510391,,,,,,,,,,,,,, -72100,1.5818093,1.1265041,,,,,,,,,,,,,, -72200,2.6235924,1.1893166,,,,,,,,,,,,,, -72300,2.5871215,1.1580632,,,,,,,,,,,,,, -72400,2.3904407,1.1694158,,,,,,,,,,,,,, -72500,2.006844,1.1416814,,,,,,,,,,,,,, -72600,2.2341576,1.1717298,,,,,,,,,,,,,, -72700,2.501568,1.135774,,,,,,,,,,,,,, -72800,2.9440908,1.140454,,,,,,,,,,,,,, -72900,2.1369054,1.152141,,,,,,,,,,,,,, -73000,7.444216,1.1248419,,,,,,,,,,,,,, -73071,,,0.20977716,0.0732559520107647,0.42251086,0.1218320669646736,5348.0,0.22680901,0.0729185708772571,2472.0,59093.15683174133,64696.9893181324,59093.15683174133,5598.56943821907,2.190993070602417,0.0 -73100,2.9500675,1.0996946,,,,,,,,,,,,,, -73200,2.530142,1.1403384,,,,,,,,,,,,,, -73300,4.7320504,1.1674907,,,,,,,,,,,,,, -73400,1.9608619,1.188479,,,,,,,,,,,,,, -73500,3.2526543,1.098228,,,,,,,,,,,,,, -73600,2.2163801,1.163637,,,,,,,,,,,,,, -73700,5.5060816,1.137398,,,,,,,,,,,,,, -73800,8.050046,1.2150488,,,,,,,,,,,,,, -73900,7.012409,1.1859423,,,,,,,,,,,,,, -74000,4.020607,1.1565896,,,,,,,,,,,,,, -74100,1.8750495,1.158937,,,,,,,,,,,,,, -74200,4.094372,1.1387184,,,,,,,,,,,,,, -74300,4.6546435,1.116991,,,,,,,,,,,,,, -74400,3.1694932,1.1668106,,,,,,,,,,,,,, -74500,3.5919771,1.1681691,,,,,,,,,,,,,, -74600,3.541694,1.1588627,,,,,,,,,,,,,, -74700,2.7963188,1.129328,,,,,,,,,,,,,, -74800,2.6347775,1.1362256,,,,,,,,,,,,,, -74813,,,0.20217852,0.069316295956455,0.42254487,0.1218127576585535,5348.0,0.2268203,0.0729388824568886,2472.0,60533.7385866642,66269.74801373482,60533.7385866642,5730.616383552551,2.2462613582611084,0.0 -74900,1.661087,1.0942813,,,,,,,,,,,,,, -75000,5.4327636,1.1578586,,,,,,,,,,,,,, -75100,2.4584596,1.1681949,,,,,,,,,,,,,, -75200,2.6093721,1.1566948,,,,,,,,,,,,,, -75300,2.1761203,1.1415414,,,,,,,,,,,,,, -75400,8.456076,1.1339155,,,,,,,,,,,,,, -75500,3.6607752,1.151275,,,,,,,,,,,,,, -75600,2.8060708,1.1680744,,,,,,,,,,,,,, -75700,1.5443738,1.1715018,,,,,,,,,,,,,, -75800,2.423806,1.1062504,,,,,,,,,,,,,, -75900,1.8362812,1.1098257,,,,,,,,,,,,,, -76000,3.3411021,1.1757265,,,,,,,,,,,,,, -76100,2.2902489,1.0964892,,,,,,,,,,,,,, -76200,2.5755079,1.1953784,,,,,,,,,,,,,, -76300,1.8998046,1.1689535,,,,,,,,,,,,,, -76400,2.3137777,1.1758516,,,,,,,,,,,,,, -76500,3.0325809,1.2204858,,,,,,,,,,,,,, -76600,2.7135653,1.2048643,,,,,,,,,,,,,, -76605,,,0.24120624,0.084248732741005,0.422617,0.1218803402299738,5348.0,0.22685282,0.0729795056161517,2472.0,61974.3486623764,67841.11692786217,61974.3486623764,5861.242010354996,2.3004186153411865,0.0 -76700,3.4126709,1.156068,,,,,,,,,,,,,, -76800,2.0485606,1.2463558,,,,,,,,,,,,,, -76900,2.2137518,1.1759053,,,,,,,,,,,,,, -77000,1.5410571,1.1218722,,,,,,,,,,,,,, -77100,2.469499,1.1300806,,,,,,,,,,,,,, -77200,2.6123302,1.1887102,,,,,,,,,,,,,, -77300,2.490705,1.1764748,,,,,,,,,,,,,, -77400,2.061014,1.2131212,,,,,,,,,,,,,, -77500,6.3871317,1.1282579,,,,,,,,,,,,,, -77600,2.0383239,1.095737,,,,,,,,,,,,,, -77700,1.9201989,1.1221446,,,,,,,,,,,,,, -77800,4.319937,1.1692773,,,,,,,,,,,,,, -77900,3.130614,1.200595,,,,,,,,,,,,,, -78000,3.4456248,1.1257546,,,,,,,,,,,,,, -78100,3.9050965,1.1645898,,,,,,,,,,,,,, -78200,2.1832123,1.1283056,,,,,,,,,,,,,, -78300,2.1330085,1.1701666,,,,,,,,,,,,,, -78338,,,0.23317532,0.078230944254835,0.42259285,0.1218513762707937,5348.0,0.22684146,0.0729185708772571,2472.0,63414.18882107735,69409.33088636398,63414.18882107735,5989.411997795105,2.428640842437744,0.0 -78400,2.861491,1.1843698,,,,,,,,,,,,,, -78500,1.785287,1.1651812,,,,,,,,,,,,,, -78600,4.4370713,1.1274483,,,,,,,,,,,,,, -78700,4.0641723,1.0785574,,,,,,,,,,,,,, -78800,2.465562,1.1515342,,,,,,,,,,,,,, -78900,1.7861025,1.1394467,,,,,,,,,,,,,, -79000,6.3195715,1.1901392,,,,,,,,,,,,,, -79100,4.26856,1.1873988,,,,,,,,,,,,,, -79200,2.7419815,1.1871104,,,,,,,,,,,,,, -79300,1.5051188,1.147802,,,,,,,,,,,,,, -79400,2.0060983,1.1759486,,,,,,,,,,,,,, -79500,3.6652818,1.1185038,,,,,,,,,,,,,, -79600,2.1981492,1.1789658,,,,,,,,,,,,,, -79700,2.3205054,1.1584933,,,,,,,,,,,,,, -79800,2.3059165,1.175881,,,,,,,,,,,,,, -79900,3.583966,1.1486514,,,,,,,,,,,,,, -80000,3.3786016,1.2092011,,,,,,,,,,,,,, -80100,1.8281058,1.1313671,,,,,,,,,,,,,, -80112,,,0.23077081,0.0776995343340543,0.42259255,0.1218320669646736,5348.0,0.22684368,0.0729591940365202,2472.0,64854.61954379082,70982.00092959404,64854.61954379082,6121.52056145668,2.482978343963623,0.0 -80200,5.7066875,1.1508584,,,,,,,,,,,,,, -80300,2.1046622,1.1600823,,,,,,,,,,,,,, -80400,3.0399365,1.1876858,,,,,,,,,,,,,, -80500,4.183365,1.1677464,,,,,,,,,,,,,, -80600,2.7265196,1.1112449,,,,,,,,,,,,,, -80700,1.8895817,1.1437721,,,,,,,,,,,,,, -80800,2.709419,1.1738886,,,,,,,,,,,,,, -80900,5.282439,1.1762484,,,,,,,,,,,,,, -81000,2.329435,1.1259757,,,,,,,,,,,,,, -81100,5.71637,1.1369393,,,,,,,,,,,,,, -81200,1.5942695,1.1186187,,,,,,,,,,,,,, -81300,2.9727805,1.1987011,,,,,,,,,,,,,, -81400,2.5629056,1.137767,,,,,,,,,,,,,, -81500,4.406431,1.1626908,,,,,,,,,,,,,, -81600,3.212599,1.0661888,,,,,,,,,,,,,, -81700,2.7830644,1.1633557,,,,,,,,,,,,,, -81800,1.9898992,1.1163708,,,,,,,,,,,,,, -81892,,,0.21406859,0.0751324795760653,0.42261186,0.1218513762707937,5348.0,0.22684759,0.0729185708772571,2472.0,66294.92736124992,72556.6824965477,66294.92736124992,6255.759551286697,2.541866540908813,0.0 -81900,2.2887905,1.1121538,,,,,,,,,,,,,, -82000,2.144781,1.187868,,,,,,,,,,,,,, -82100,5.8539248,1.1775739,,,,,,,,,,,,,, -82200,2.3721483,1.1910915,,,,,,,,,,,,,, -82300,5.603242,1.1609757,,,,,,,,,,,,,, -82400,2.895043,1.1481715,,,,,,,,,,,,,, -82500,2.3672972,1.1461395,,,,,,,,,,,,,, -82600,5.3026996,1.1453903,,,,,,,,,,,,,, -82700,2.7087638,1.2013178,,,,,,,,,,,,,, -82800,3.1674955,1.0746096,,,,,,,,,,,,,, -82900,3.4742193,1.1154667,,,,,,,,,,,,,, -83000,2.9888103,1.1313388,,,,,,,,,,,,,, -83100,2.3863003,1.2008613,,,,,,,,,,,,,, -83200,6.600432,1.1646405,,,,,,,,,,,,,, -83300,6.72966,1.1057359,,,,,,,,,,,,,, -83400,4.0266504,1.1087441,,,,,,,,,,,,,, -83500,3.3423572,1.1284616,,,,,,,,,,,,,, -83600,3.1675758,1.2183542,,,,,,,,,,,,,, -83650,,,0.21813707,0.0758759012604678,0.4225675,0.1218610309238537,5348.0,0.22683038,0.0729185708772571,2472.0,67735.07097840309,74126.87830376625,67735.07097840309,6385.677375793457,2.598761558532715,0.0 -83700,2.7389545,1.2091492,,,,,,,,,,,,,, -83800,4.100385,1.119129,,,,,,,,,,,,,, -83900,2.6905992,1.1777515,,,,,,,,,,,,,, -84000,2.9412162,1.1545372,,,,,,,,,,,,,, -84100,2.1624386,1.1533344,,,,,,,,,,,,,, -84200,3.3307056,1.0894976,,,,,,,,,,,,,, -84300,3.09934,1.134703,,,,,,,,,,,,,, -84400,2.0469506,1.0831896,,,,,,,,,,,,,, -84500,5.650386,1.1529256,,,,,,,,,,,,,, -84600,1.7153946,1.113279,,,,,,,,,,,,,, -84700,2.5810075,1.1379343,,,,,,,,,,,,,, -84800,3.6948347,1.1471403,,,,,,,,,,,,,, -84900,2.4485462,1.1909498,,,,,,,,,,,,,, -85000,1.8422965,1.1873783,,,,,,,,,,,,,, -85100,5.1912246,1.1660712,,,,,,,,,,,,,, -85200,2.8899245,1.2004291,,,,,,,,,,,,,, -85300,3.855381,1.1952778,,,,,,,,,,,,,, -85400,1.5647842,1.1376348,,,,,,,,,,,,,, -85420,,,0.21527421,0.0751402136641586,0.42256075,0.1218224123116135,5348.0,0.22683091,0.0729185708772571,2472.0,69175.55707406998,75700.22792291641,69175.55707406998,6518.407257318497,2.6557695865631104,0.0 -85500,4.1174493,1.1442491,,,,,,,,,,,,,, -85600,7.670775,1.1955199,,,,,,,,,,,,,, -85700,3.646655,1.138759,,,,,,,,,,,,,, -85800,1.6315488,1.1442755,,,,,,,,,,,,,, -85900,2.6324515,1.1963354,,,,,,,,,,,,,, -86000,3.070314,1.1450306,,,,,,,,,,,,,, -86100,5.1698146,1.1859496,,,,,,,,,,,,,, -86200,3.1029487,1.1888604,,,,,,,,,,,,,, -86300,2.1457648,1.144847,,,,,,,,,,,,,, -86400,4.024159,1.192348,,,,,,,,,,,,,, -86500,1.8775972,1.1328264,,,,,,,,,,,,,, -86600,3.9193134,1.1146997,,,,,,,,,,,,,, -86700,4.69184,1.1384772,,,,,,,,,,,,,, -86800,1.6943688,1.1636404,,,,,,,,,,,,,, -86900,5.47859,1.167232,,,,,,,,,,,,,, -87000,1.5306813,1.1267823,,,,,,,,,,,,,, -87100,5.6066246,1.1142367,,,,,,,,,,,,,, -87200,2.314454,1.1114217,,,,,,,,,,,,,, -87217,,,0.136696,0.0487799693929603,0.4225517,0.1218031030054935,5348.0,0.22682695,0.0729185708772571,2472.0,70615.84012532234,77291.52449822426,70615.84012532234,6669.279905557632,2.7173826694488525,0.0 -87300,2.2241642,1.1800947,,,,,,,,,,,,,, -87400,3.283999,1.1600019,,,,,,,,,,,,,, -87500,2.735423,1.1165482,,,,,,,,,,,,,, -87600,6.942434,1.1404109,,,,,,,,,,,,,, -87700,3.8956563,1.1975874,,,,,,,,,,,,,, -87800,2.6822278,1.1804283,,,,,,,,,,,,,, -87900,3.9165692,1.1589154,,,,,,,,,,,,,, -88000,2.7294014,1.1281604,,,,,,,,,,,,,, -88100,6.2588916,1.225789,,,,,,,,,,,,,, -88200,5.3534513,1.1481733,,,,,,,,,,,,,, -88300,4.0749135,1.1389128,,,,,,,,,,,,,, -88400,2.7832386,1.1672509,,,,,,,,,,,,,, -88500,1.9503073,1.1402569,,,,,,,,,,,,,, -88600,1.3671905,1.1303473,,,,,,,,,,,,,, -88700,3.49103,1.2194773,,,,,,,,,,,,,, -88800,1.9464087,1.17393,,,,,,,,,,,,,, -88900,1.8268968,1.168113,,,,,,,,,,,,,, -89000,2.7694135,1.2145847,,,,,,,,,,,,,, -89013,,,0.13775001,0.0477160208423163,0.42260024,0.1218706855769137,5348.0,0.22683926,0.0729388824568886,2472.0,72055.93694233894,78864.91400146484,72055.93694233894,6802.439275741577,2.773588180541992,0.0 -89100,3.3072116,1.1455505,,,,,,,,,,,,,, -89200,3.7347527,1.1733325,,,,,,,,,,,,,, -89300,4.9876075,1.0899858,,,,,,,,,,,,,, -89400,3.5229363,1.123938,,,,,,,,,,,,,, -89500,4.8223386,1.219023,,,,,,,,,,,,,, -89600,3.365842,1.1853883,,,,,,,,,,,,,, -89700,1.5274693,1.0897664,,,,,,,,,,,,,, -89800,4.5868583,1.1286054,,,,,,,,,,,,,, -89900,2.9307268,1.1631643,,,,,,,,,,,,,, -90000,10.968027,1.1943011,,,,,,,,,,,,,, -90100,2.4539,1.1709576,,,,,,,,,,,,,, -90200,3.1332123,1.1299149,,,,,,,,,,,,,, -90300,5.415153,1.1654563,,,,,,,,,,,,,, -90400,2.7120073,1.1731663,,,,,,,,,,,,,, -90500,1.6864175,1.1216942,,,,,,,,,,,,,, -90600,2.1942205,1.205962,,,,,,,,,,,,,, -90700,4.28577,1.1148736,,,,,,,,,,,,,, -90750,,,0.15148643,0.0530193715478738,0.42255753,0.1218417216177336,5348.0,0.22682476,0.0729388824568886,2472.0,73496.3667254448,80437.90277504921,73496.3667254448,6934.860629796982,2.8354685306549072,0.0 -90800,2.401743,1.1314414,,,,,,,,,,,,,, -90900,3.5502272,1.1164665,,,,,,,,,,,,,, -91000,1.6168481,1.1550785,,,,,,,,,,,,,, -91100,2.307747,1.1091194,,,,,,,,,,,,,, -91200,2.6636958,1.1599076,,,,,,,,,,,,,, -91300,6.573206,1.1244093,,,,,,,,,,,,,, -91400,4.4734282,1.1580354,,,,,,,,,,,,,, -91500,1.3732673,1.137033,,,,,,,,,,,,,, -91600,2.2666726,1.1284865,,,,,,,,,,,,,, -91700,2.2658625,1.172443,,,,,,,,,,,,,, -91800,2.2299287,1.1423424,,,,,,,,,,,,,, -91900,3.3261654,1.176129,,,,,,,,,,,,,, -92000,2.3381598,1.1351287,,,,,,,,,,,,,, -92100,2.377793,1.1753117,,,,,,,,,,,,,, -92200,2.426994,1.183823,,,,,,,,,,,,,, -92300,1.7665031,1.1733477,,,,,,,,,,,,,, -92400,2.1428945,1.1350541,,,,,,,,,,,,,, -92494,,,0.13619804,0.0479902329075882,0.4226031,0.1218899948830338,5348.0,0.22684576,0.0729795056161517,2472.0,74936.79920172691,82011.60424017906,74936.79920172691,7067.9899690151215,2.8983287811279297,0.0 -92500,3.0865161,1.2124238,,,,,,,,,,,,,, -92600,2.3972116,1.2073855,,,,,,,,,,,,,, -92700,1.7624876,1.121549,,,,,,,,,,,,,, -92800,2.056556,1.1031891,,,,,,,,,,,,,, -92900,4.9662952,1.1278404,,,,,,,,,,,,,, -93000,2.1404216,1.0883601,,,,,,,,,,,,,, -93100,2.6754444,1.1936346,,,,,,,,,,,,,, -93200,2.996388,1.1077064,,,,,,,,,,,,,, -93300,2.3656034,1.1177217,,,,,,,,,,,,,, -93400,3.126859,1.1478409,,,,,,,,,,,,,, -93500,3.80715,1.1462948,,,,,,,,,,,,,, -93600,3.065175,1.1452123,,,,,,,,,,,,,, -93700,3.4991546,1.1689688,,,,,,,,,,,,,, -93800,2.6311502,1.1730137,,,,,,,,,,,,,, -93900,2.1424088,1.1883706,,,,,,,,,,,,,, -94000,3.6003137,1.1320428,,,,,,,,,,,,,, -94100,1.774368,1.1861309,,,,,,,,,,,,,, -94200,3.4058444,1.1651675,,,,,,,,,,,,,, -94269,,,0.14904647,0.0523268522319548,0.4226395,0.1218513762707937,5348.0,0.22686507,0.0729998171957833,2472.0,76377.0752260685,83584.03257799149,76377.0752260685,7200.00839805603,2.9541354179382324,0.0 -94300,2.6274254,1.1201775,,,,,,,,,,,,,, -94400,3.0099988,1.1928235,,,,,,,,,,,,,, -94500,2.1757581,1.1141688,,,,,,,,,,,,,, -94600,5.5466723,1.1439453,,,,,,,,,,,,,, -94700,2.9091523,1.1657703,,,,,,,,,,,,,, -94800,2.1080472,1.1292355,,,,,,,,,,,,,, -94900,3.0326924,1.1425351,,,,,,,,,,,,,, -95000,6.9586453,1.1841828,,,,,,,,,,,,,, -95100,4.903489,1.1471946,,,,,,,,,,,,,, -95200,4.287972,1.1673174,,,,,,,,,,,,,, -95300,2.5179505,1.1800877,,,,,,,,,,,,,, -95400,2.953575,1.1432902,,,,,,,,,,,,,, -95500,2.7294292,1.1970813,,,,,,,,,,,,,, -95600,1.8775227,1.1564137,,,,,,,,,,,,,, -95700,2.7512069,1.1555033,,,,,,,,,,,,,, -95800,2.8788357,1.1458547,,,,,,,,,,,,,, -95900,4.478129,1.1177499,,,,,,,,,,,,,, -96000,3.888895,1.1346592,,,,,,,,,,,,,, -96036,,,0.13960445,0.0499481765629841,0.42258808,0.1218417216177336,5348.0,0.2268396,0.0729185708772571,2472.0,77817.67358899117,85160.68888759613,77817.67358899117,7335.930129051208,3.0127499103546143,0.0 -96100,1.6200507,1.1242427,,,,,,,,,,,,,, -96200,7.8106,1.1430416,,,,,,,,,,,,,, -96300,2.891074,1.1389915,,,,,,,,,,,,,, -96400,1.6962895,1.1478683,,,,,,,,,,,,,, -96500,4.5254054,1.1498083,,,,,,,,,,,,,, -96600,2.9743268,1.1359164,,,,,,,,,,,,,, -96700,6.352265,1.2036481,,,,,,,,,,,,,, -96800,3.4257445,1.10574,,,,,,,,,,,,,, -96900,2.881772,1.2181196,,,,,,,,,,,,,, -97000,2.4532638,1.1579932,,,,,,,,,,,,,, -97100,2.7962809,1.1817613,,,,,,,,,,,,,, -97200,3.4546309,1.1783237,,,,,,,,,,,,,, -97300,3.1495337,1.1537092,,,,,,,,,,,,,, -97400,1.6112589,1.1340891,,,,,,,,,,,,,, -97500,1.6052544,1.2005718,,,,,,,,,,,,,, -97600,2.844048,1.1698112,,,,,,,,,,,,,, -97700,2.7453632,1.1097511,,,,,,,,,,,,,, -97761,,,0.16996482,0.0559007210919392,0.4225012,0.1218127576585535,5348.0,0.22680737,0.0728982592976255,2472.0,79257.58918118477,86733.26485586166,79257.58918118477,7468.458901882172,3.0681488513946533,0.0 -97800,3.383107,1.1556945,,,,,,,,,,,,,, -97900,12.800312,1.1671044,,,,,,,,,,,,,, -98000,3.4454389,1.164911,,,,,,,,,,,,,, -98100,3.5877945,1.1328213,,,,,,,,,,,,,, -98200,2.0054593,1.1969296,,,,,,,,,,,,,, -98300,3.2891026,1.2089403,,,,,,,,,,,,,, -98400,2.4291708,1.1824331,,,,,,,,,,,,,, -98500,2.72912,1.1330618,,,,,,,,,,,,,, -98600,3.2493038,1.1540885,,,,,,,,,,,,,, -98700,4.668024,1.1574169,,,,,,,,,,,,,, -98800,2.6439745,1.1751626,,,,,,,,,,,,,, -98900,2.1330369,1.1554416,,,,,,,,,,,,,, -99000,2.7705073,1.0615009,,,,,,,,,,,,,, -99100,4.493867,1.1872189,,,,,,,,,,,,,, -99200,2.5785453,1.1489513,,,,,,,,,,,,,, -99300,3.2036872,1.1161952,,,,,,,,,,,,,, -99400,2.4192557,1.1600575,,,,,,,,,,,,,, -99500,2.1088197,1.145092,,,,,,,,,,,,,, -99506,,,0.15189572,0.0539175148430873,0.42260715,0.1218610309238537,5348.0,0.22684963,0.0729388824568886,2472.0,80698.04615569115,88312.18506860733,80698.04615569115,7606.780797481537,3.13254976272583,0.0 -99600,2.6955776,1.2066233,,,,,,,,,,,,,, -99700,5.394626,1.1692648,,,,,,,,,,,,,, -99800,2.6409857,1.0853333,,,,,,,,,,,,,, -99900,3.3653183,1.132433,,,,,,,,,,,,,, -100000,5.1419215,1.1606311,,,,,,,,,,,,,, -100100,3.7894366,1.1490878,,,,,,,,,,,,,, -100200,3.7870417,1.1387849,,,,,,,,,,,,,, -100300,4.221288,1.153067,,,,,,,,,,,,,, -100400,2.092639,1.1487803,,,,,,,,,,,,,, -100500,5.694563,1.1612058,,,,,,,,,,,,,, -100600,2.66132,1.1732816,,,,,,,,,,,,,, -100700,2.5394957,1.1488379,,,,,,,,,,,,,, -100800,1.8282334,1.147895,,,,,,,,,,,,,, -100900,1.6263864,1.1372412,,,,,,,,,,,,,, -101000,1.8355354,1.1278242,,,,,,,,,,,,,, -101100,2.1315436,1.1263753,,,,,,,,,,,,,, -101200,4.5733943,1.1874089,,,,,,,,,,,,,, -101269,,,0.13218991,0.0481681788283444,0.42258465,0.1218417216177336,5348.0,0.22683719,0.0729388824568886,2472.0,82138.8430378437,89889.42877030373,82138.8430378437,7743.091492176056,3.1927475929260254,0.0 -101300,2.829568,1.1697946,,,,,,,,,,,,,, -101400,2.182126,1.2228009,,,,,,,,,,,,,, -101500,3.4207978,1.2197589,,,,,,,,,,,,,, -101600,2.4422266,1.152526,,,,,,,,,,,,,, -101700,2.312132,1.1556691,,,,,,,,,,,,,, -101800,3.1800697,1.1429306,,,,,,,,,,,,,, -101900,6.4559336,1.1522698,,,,,,,,,,,,,, -102000,2.0486548,1.1680927,,,,,,,,,,,,,, -102100,5.2120447,1.1854913,,,,,,,,,,,,,, -102200,2.6232326,1.1148009,,,,,,,,,,,,,, -102300,3.6646278,1.1803876,,,,,,,,,,,,,, -102400,2.2453701,1.1389686,,,,,,,,,,,,,, -102500,2.3853567,1.2087687,,,,,,,,,,,,,, -102600,2.5736299,1.1295359,,,,,,,,,,,,,, -102700,4.621256,1.161886,,,,,,,,,,,,,, -102800,4.5666766,1.1101555,,,,,,,,,,,,,, -102900,2.3639543,1.171778,,,,,,,,,,,,,, -103000,,,0.13685067,0.0485069051633973,0.42259264,0.1218610309238537,5348.0,0.22684465,0.0729388824568886,2472.0,83579.43496155739,91461.50216126442,83579.43496155739,7874.434673547745,3.255798816680908,0.0 -103000,2.3260415,1.1778688,,,,,,,,,,,,,, -103100,5.0916896,1.1621658,,,,,,,,,,,,,, -103200,4.320662,1.1060914,,,,,,,,,,,,,, -103300,4.982412,1.1579885,,,,,,,,,,,,,, -103400,9.376364,1.1651663,,,,,,,,,,,,,, -103500,3.1155643,1.1141399,,,,,,,,,,,,,, -103600,4.230495,1.1936736,,,,,,,,,,,,,, -103700,3.2341287,1.156047,,,,,,,,,,,,,, -103800,1.9920431,1.1011547,,,,,,,,,,,,,, -103900,2.9896975,1.143449,,,,,,,,,,,,,, -104000,2.0653677,1.1796699,,,,,,,,,,,,,, -104100,4.2858977,1.13164,,,,,,,,,,,,,, -104200,3.2967982,1.1451678,,,,,,,,,,,,,, -104300,2.7329655,1.1564993,,,,,,,,,,,,,, -104400,3.8656788,1.1672711,,,,,,,,,,,,,, -104500,2.5794377,1.1723542,,,,,,,,,,,,,, -104600,2.823294,1.2426242,,,,,,,,,,,,,, -104700,1.7108494,1.1362618,,,,,,,,,,,,,, -104750,,,0.13595581,0.0497899612236105,0.42256644,0.1218031030054935,5348.0,0.22683273,0.0728982592976255,2472.0,85019.83351564407,93034.88031411172,85019.83351564407,8007.273331642151,3.319591999053955,0.0 -104800,3.7595108,1.142522,,,,,,,,,,,,,, -104900,2.9237194,1.1282376,,,,,,,,,,,,,, -105000,3.081039,1.1774254,,,,,,,,,,,,,, -105100,5.089209,1.1255928,,,,,,,,,,,,,, -105200,2.284865,1.1488671,,,,,,,,,,,,,, -105300,6.056953,1.1615773,,,,,,,,,,,,,, -105400,2.8978913,1.1357723,,,,,,,,,,,,,, -105500,5.2237477,1.1137058,,,,,,,,,,,,,, -105600,2.24554,1.1618893,,,,,,,,,,,,,, -105700,4.650197,1.1382935,,,,,,,,,,,,,, -105800,5.287374,1.1673561,,,,,,,,,,,,,, -105900,2.122711,1.1307986,,,,,,,,,,,,,, -106000,3.5407631,1.0512574,,,,,,,,,,,,,, -106100,3.520116,1.1441892,,,,,,,,,,,,,, -106200,5.2425885,1.1266817,,,,,,,,,,,,,, -106300,1.6191076,1.137721,,,,,,,,,,,,,, -106400,2.930772,1.1638834,,,,,,,,,,,,,, -106500,1.948358,1.1666546,,,,,,,,,,,,,, -106513,,,0.15283024,0.052183996907615,0.42258832,0.1218320669646736,5348.0,0.22684817,0.0729388824568886,2472.0,86459.7220981121,94610.18890833856,86459.7220981121,8142.552544355392,3.381679058074951,0.0 -106600,3.346158,1.1231413,,,,,,,,,,,,,, -106700,1.9053738,1.1577655,,,,,,,,,,,,,, -106800,2.4133897,1.1821119,,,,,,,,,,,,,, -106900,8.643159,1.1843053,,,,,,,,,,,,,, -107000,3.117695,1.1957438,,,,,,,,,,,,,, -107100,1.4558543,1.1622148,,,,,,,,,,,,,, -107200,7.0272746,1.1171515,,,,,,,,,,,,,, -107300,5.6738486,1.1459297,,,,,,,,,,,,,, -107400,2.9248333,1.1381775,,,,,,,,,,,,,, -107500,1.4805278,1.1194934,,,,,,,,,,,,,, -107600,5.4823194,1.1973656,,,,,,,,,,,,,, -107700,2.6891153,1.1610843,,,,,,,,,,,,,, -107800,3.7421908,1.1348195,,,,,,,,,,,,,, -107900,5.375263,1.1441102,,,,,,,,,,,,,, -108000,2.2381697,1.1530787,,,,,,,,,,,,,, -108100,4.4371934,1.1430746,,,,,,,,,,,,,, -108200,6.182856,1.1490057,,,,,,,,,,,,,, -108264,,,0.152285,0.0533659226724665,0.42263636,0.1218513762707937,5348.0,0.2268602,0.0729795056161517,2472.0,87899.6213388443,96182.41661047935,87899.6213388443,8274.74119591713,3.4438462257385254,0.0 -108300,2.111882,1.1481584,,,,,,,,,,,,,, -108400,3.6869464,1.152572,,,,,,,,,,,,,, -108500,2.916785,1.1382353,,,,,,,,,,,,,, -108600,1.3338451,1.1459811,,,,,,,,,,,,,, -108700,3.30588,1.1993265,,,,,,,,,,,,,, -108800,3.3697226,1.1379919,,,,,,,,,,,,,, -108900,4.598755,1.1203903,,,,,,,,,,,,,, -109000,1.9280128,1.1740067,,,,,,,,,,,,,, -109100,2.3262823,1.1415633,,,,,,,,,,,,,, -109200,3.4413097,1.1886208,,,,,,,,,,,,,, -109300,2.5797944,1.1424332,,,,,,,,,,,,,, -109400,4.807635,1.1603708,,,,,,,,,,,,,, -109500,3.6968224,1.1976553,,,,,,,,,,,,,, -109600,2.8551574,1.1582874,,,,,,,,,,,,,, -109700,4.0005813,1.1352972,,,,,,,,,,,,,, -109800,3.1419225,1.1567888,,,,,,,,,,,,,, -109900,7.225228,1.1094375,,,,,,,,,,,,,, -110000,1.8285533,1.1068989,,,,,,,,,,,,,, -110014,,,0.14916688,0.053241399287218,0.4225592,0.1218706855769137,5348.0,0.22682457,0.0729185708772571,2472.0,89340.51832222939,97758.5703458786,89340.51832222939,8409.858960390091,3.506175756454468,0.0 -110100,5.247522,1.1643366,,,,,,,,,,,,,, -110200,2.076978,1.1560266,,,,,,,,,,,,,, -110300,3.007047,1.1091504,,,,,,,,,,,,,, -110400,2.979195,1.1065308,,,,,,,,,,,,,, -110500,1.5724745,1.1663239,,,,,,,,,,,,,, -110600,4.8407927,1.1976929,,,,,,,,,,,,,, -110700,4.8459444,1.1329367,,,,,,,,,,,,,, -110800,2.125766,1.2326916,,,,,,,,,,,,,, -110900,4.2395964,1.1909688,,,,,,,,,,,,,, -111000,4.650437,1.163538,,,,,,,,,,,,,, -111100,2.5276887,1.1430894,,,,,,,,,,,,,, -111200,3.262731,1.1963418,,,,,,,,,,,,,, -111300,2.3186786,1.1636181,,,,,,,,,,,,,, -111400,3.6556897,1.1491288,,,,,,,,,,,,,, -111500,1.9108773,1.1059742,,,,,,,,,,,,,, -111600,4.492255,1.2091516,,,,,,,,,,,,,, -111700,3.8877563,1.1881303,,,,,,,,,,,,,, -111767,,,0.14817163,0.0517262157389796,0.4225653,0.1218320669646736,5348.0,0.22682922,0.0728982592976255,2472.0,90780.58996725082,99335.57467722891,90780.58996725082,8546.651809930801,3.56905198097229,0.0 -111800,2.3047843,1.1708543,,,,,,,,,,,,,, -111900,2.4726272,1.1809311,,,,,,,,,,,,,, -112000,3.740889,1.1793201,,,,,,,,,,,,,, -112100,2.493688,1.0943826,,,,,,,,,,,,,, -112200,3.208165,1.1603613,,,,,,,,,,,,,, -112300,6.2513247,1.1791434,,,,,,,,,,,,,, -112400,1.7323203,1.2379494,,,,,,,,,,,,,, -112500,1.8991902,1.1763614,,,,,,,,,,,,,, -112600,4.096601,1.1794689,,,,,,,,,,,,,, -112700,3.0155272,1.1686871,,,,,,,,,,,,,, -112800,3.1670413,1.1505581,,,,,,,,,,,,,, -112900,5.308205,1.1956704,,,,,,,,,,,,,, -113000,2.368467,1.2187212,,,,,,,,,,,,,, -113100,1.8699439,1.1338837,,,,,,,,,,,,,, -113200,4.1712065,1.1427172,,,,,,,,,,,,,, -113300,3.8250358,1.2229813,,,,,,,,,,,,,, -113400,1.6332468,1.1136543,,,,,,,,,,,,,, -113500,3.3255026,1.2223554,,,,,,,,,,,,,, -113547,,,0.1355164,0.0486579392737951,0.4225174,0.1218224123116135,5348.0,0.22680725,0.0728576361383624,2472.0,92221.01919841766,100911.24274611472,92221.01919841766,8681.74896121025,3.631675720214844,0.0 -113600,1.7182766,1.2140778,,,,,,,,,,,,,, -113700,3.3779101,1.1614548,,,,,,,,,,,,,, -113800,2.0802267,1.099401,,,,,,,,,,,,,, -113900,3.6870208,1.2101208,,,,,,,,,,,,,, -114000,4.883762,1.0859717,,,,,,,,,,,,,, -114100,1.9935399,1.2210466,,,,,,,,,,,,,, -114200,8.633291,1.1016649,,,,,,,,,,,,,, -114300,2.0627062,1.1525033,,,,,,,,,,,,,, -114400,2.7779233,1.134877,,,,,,,,,,,,,, -114500,3.8243127,1.2346189,,,,,,,,,,,,,, -114600,5.9027996,1.1449623,,,,,,,,,,,,,, -114700,2.0359304,1.1711558,,,,,,,,,,,,,, -114800,1.6952045,1.1763275,,,,,,,,,,,,,, -114900,2.7898054,1.1542519,,,,,,,,,,,,,, -115000,5.356079,1.1551273,,,,,,,,,,,,,, -115100,5.164953,1.1069624,,,,,,,,,,,,,, -115200,2.601174,1.2024581,,,,,,,,,,,,,, -115300,2.7700891,1.1159738,,,,,,,,,,,,,, -115301,,,0.14265695,0.0494268791872569,0.42258564,0.1218417216177336,5348.0,0.22684036,0.0729185708772571,2472.0,93661.36353373528,102491.87964940073,93661.36353373528,8821.899176359177,3.6960880756378174,0.0 -115400,3.625455,1.1606442,,,,,,,,,,,,,, -115500,4.300975,1.2487242,,,,,,,,,,,,,, -115600,1.9641197,1.1731546,,,,,,,,,,,,,, -115700,2.6556926,1.1506531,,,,,,,,,,,,,, -115800,2.8685155,1.1395508,,,,,,,,,,,,,, -115900,1.9719043,1.1610079,,,,,,,,,,,,,, -116000,8.256111,1.0976195,,,,,,,,,,,,,, -116100,2.2078211,1.1618524,,,,,,,,,,,,,, -116200,3.3023562,1.13527,,,,,,,,,,,,,, -116300,1.7720364,1.1634172,,,,,,,,,,,,,, -116400,3.22795,1.1650126,,,,,,,,,,,,,, -116500,6.130513,1.1671855,,,,,,,,,,,,,, -116600,6.807319,1.1957536,,,,,,,,,,,,,, -116700,4.791089,1.1530594,,,,,,,,,,,,,, -116800,1.6391166,1.1394237,,,,,,,,,,,,,, -116900,2.9518251,1.1344845,,,,,,,,,,,,,, -117000,4.427841,1.145528,,,,,,,,,,,,,, -117052,,,0.15840484,0.0546797230036788,0.4226101,0.1218320669646736,5348.0,0.22685356,0.0729591940365202,2472.0,95101.26403808594,104066.78872728348,95101.26403808594,8956.770744800568,3.7561357021331774,0.0 -117100,2.8129165,1.1648406,,,,,,,,,,,,,, -117200,2.454972,1.1372924,,,,,,,,,,,,,, -117300,1.9210258,1.1711487,,,,,,,,,,,,,, -117400,2.1183565,1.1329088,,,,,,,,,,,,,, -117500,2.744796,1.0752977,,,,,,,,,,,,,, -117600,2.3123689,1.1529844,,,,,,,,,,,,,, -117700,2.7843328,1.2330194,,,,,,,,,,,,,, -117800,3.7770054,1.151319,,,,,,,,,,,,,, -117900,3.0105693,1.0815831,,,,,,,,,,,,,, -118000,2.5846756,1.1656301,,,,,,,,,,,,,, -118100,3.5783648,1.1751951,,,,,,,,,,,,,, -118200,2.0044055,1.1408625,,,,,,,,,,,,,, -118300,3.4146369,1.1760299,,,,,,,,,,,,,, -118400,2.9025383,1.1236229,,,,,,,,,,,,,, -118500,2.6331043,1.2109797,,,,,,,,,,,,,, -118600,2.0347278,1.1448989,,,,,,,,,,,,,, -118700,3.6620426,1.1753983,,,,,,,,,,,,,, -118800,3.8103883,1.2078038,,,,,,,,,,,,,, -118823,,,0.15105623,0.0514460345043607,0.42255726,0.1218417216177336,5348.0,0.22683339,0.0729388824568886,2472.0,96541.71780490877,105639.54648947716,96541.71780490877,9088.92824625969,3.823304891586304,0.0 -118900,2.5082572,1.1547089,,,,,,,,,,,,,, -119000,1.9878681,1.2211192,,,,,,,,,,,,,, -119100,3.1254175,1.1734444,,,,,,,,,,,,,, -119200,3.7182536,1.0739076,,,,,,,,,,,,,, -119300,2.8778758,1.1355765,,,,,,,,,,,,,, -119400,2.8919353,1.1287571,,,,,,,,,,,,,, -119500,2.1076322,1.122299,,,,,,,,,,,,,, -119600,2.611626,1.1498835,,,,,,,,,,,,,, -119700,2.3413773,1.1282071,,,,,,,,,,,,,, -119800,2.2268069,1.1564573,,,,,,,,,,,,,, -119900,3.0404856,1.1242416,,,,,,,,,,,,,, -120000,3.0724204,1.1726425,,,,,,,,,,,,,, -120100,2.6670978,1.1441063,,,,,,,,,,,,,, -120200,2.9074326,1.1876702,,,,,,,,,,,,,, -120300,3.9089835,1.2023984,,,,,,,,,,,,,, -120400,2.2559109,1.1221529,,,,,,,,,,,,,, -120500,3.222695,1.1986434,,,,,,,,,,,,,, -120570,,,0.14774881,0.0528778085564789,0.42251122,0.1217934483524334,5348.0,0.22680296,0.072877947717994,2472.0,97982.20047354698,107218.6017394066,97982.20047354698,9227.363971710203,3.88323712348938,0.0 -120600,2.6778584,1.1060401,,,,,,,,,,,,,, -120700,3.8922718,1.1547285,,,,,,,,,,,,,, -120800,2.1426075,1.2018487,,,,,,,,,,,,,, -120900,2.0457969,1.1505588,,,,,,,,,,,,,, -121000,2.2690144,1.1704351,,,,,,,,,,,,,, -121100,2.4737282,1.1894127,,,,,,,,,,,,,, -121200,2.5699506,1.2208885,,,,,,,,,,,,,, -121300,4.2567067,1.156445,,,,,,,,,,,,,, -121400,1.989126,1.1500494,,,,,,,,,,,,,, -121500,5.6504374,1.1769387,,,,,,,,,,,,,, -121600,4.17373,1.1300082,,,,,,,,,,,,,, -121700,1.5891037,1.1844927,,,,,,,,,,,,,, -121800,1.8063486,1.131516,,,,,,,,,,,,,, -121900,2.1882188,1.1399267,,,,,,,,,,,,,, -122000,2.3772469,1.169152,,,,,,,,,,,,,, -122100,3.8820693,1.1517578,,,,,,,,,,,,,, -122200,3.4262786,1.1289524,,,,,,,,,,,,,, -122300,2.8119361,1.156782,,,,,,,,,,,,,, -122319,,,0.18443608,0.0592126955763319,0.4226183,0.1218706855769137,5348.0,0.22685014,0.0729795056161517,2472.0,99422.66232037544,108793.39494419098,99422.66232037544,9361.547214984894,3.952324867248535,0.0 -122400,1.7794571,1.1278995,,,,,,,,,,,,,, -122500,2.42683,1.1467897,,,,,,,,,,,,,, -122600,3.0452385,1.1606458,,,,,,,,,,,,,, -122700,3.0329442,1.1469618,,,,,,,,,,,,,, -122800,6.5178795,1.1768327,,,,,,,,,,,,,, -122900,4.231425,1.1226505,,,,,,,,,,,,,, -123000,2.0757644,1.1476507,,,,,,,,,,,,,, -123100,3.5004098,1.1734504,,,,,,,,,,,,,, -123200,2.0614095,1.2012116,,,,,,,,,,,,,, -123300,3.017204,1.1475685,,,,,,,,,,,,,, -123400,2.8634405,1.1098235,,,,,,,,,,,,,, -123500,4.033606,1.1385697,,,,,,,,,,,,,, -123600,2.54977,1.1586856,,,,,,,,,,,,,, -123700,3.376589,1.1391932,,,,,,,,,,,,,, -123800,2.114865,1.1618091,,,,,,,,,,,,,, -123900,2.5295584,1.1510416,,,,,,,,,,,,,, -124000,6.6340833,1.145986,,,,,,,,,,,,,, -124089,,,0.1285202,0.0454345017125746,0.42256793,0.1218706855769137,5348.0,0.2268248,0.0729185708772571,2472.0,100862.93128800392,110367.68206977844,100862.93128800392,9495.424234628676,4.015031814575195,0.0 -124100,3.7320514,1.1158038,,,,,,,,,,,,,, -124200,2.6049697,1.2093979,,,,,,,,,,,,,, -124300,2.198586,1.1343855,,,,,,,,,,,,,, -124400,3.6876678,1.1606597,,,,,,,,,,,,,, -124500,1.6769874,1.1752082,,,,,,,,,,,,,, -124600,2.8828683,1.1366612,,,,,,,,,,,,,, -124700,2.6975899,1.167653,,,,,,,,,,,,,, -124800,4.23299,1.1271738,,,,,,,,,,,,,, -124900,1.8401082,1.1255355,,,,,,,,,,,,,, -125000,2.3260138,1.1485819,,,,,,,,,,,,,, -125100,3.9761071,1.1622047,,,,,,,,,,,,,, -125200,3.7999258,1.1911346,,,,,,,,,,,,,, -125300,2.086828,1.1777793,,,,,,,,,,,,,, -125400,3.150264,1.1063467,,,,,,,,,,,,,, -125500,3.1369069,1.1500504,,,,,,,,,,,,,, -125600,2.6533003,1.2286203,,,,,,,,,,,,,, -125700,3.177484,1.1734923,,,,,,,,,,,,,, -125800,2.868203,1.1603924,,,,,,,,,,,,,, -125833,,,0.14595449,0.0509496202041957,0.42254525,0.1218320669646736,5348.0,0.22681937,0.0729185708772571,2472.0,102303.157143116,111940.33980464935,102303.157143116,9627.709649086,4.083231449127197,0.0 -125900,2.8267858,1.0611652,,,,,,,,,,,,,, -126000,5.004586,1.1641493,,,,,,,,,,,,,, -126100,3.8823245,1.1557198,,,,,,,,,,,,,, -126200,2.5515838,1.1449726,,,,,,,,,,,,,, -126300,2.5476942,1.1513444,,,,,,,,,,,,,, -126400,5.8763733,1.1252449,,,,,,,,,,,,,, -126500,4.6512322,1.1722556,,,,,,,,,,,,,, -126600,4.079885,1.137142,,,,,,,,,,,,,, -126700,2.7069576,1.1822985,,,,,,,,,,,,,, -126800,4.7280116,1.1364777,,,,,,,,,,,,,, -126900,2.324128,1.103496,,,,,,,,,,,,,, -127000,3.8898814,1.162513,,,,,,,,,,,,,, -127100,3.1101456,1.0860782,,,,,,,,,,,,,, -127200,2.743458,1.1892755,,,,,,,,,,,,,, -127300,1.9854251,1.1412177,,,,,,,,,,,,,, -127400,4.955169,1.1303291,,,,,,,,,,,,,, -127500,2.8104463,1.1590158,,,,,,,,,,,,,, -127576,,,0.20160694,0.0704816760832765,0.4225641,0.1218513762707937,5348.0,0.22682925,0.0729388824568886,2472.0,103743.35876011848,113512.56096434592,103743.35876011848,9759.584423303604,4.149665832519531,0.0 -127600,6.4096456,1.1789693,,,,,,,,,,,,,, -127700,3.1868658,1.1704777,,,,,,,,,,,,,, -127800,3.8985739,1.113231,,,,,,,,,,,,,, -127900,2.3766582,1.1665471,,,,,,,,,,,,,, -128000,2.842254,1.161374,,,,,,,,,,,,,, -128100,3.9764242,1.0855069,,,,,,,,,,,,,, -128200,2.0968792,1.2391864,,,,,,,,,,,,,, -128300,2.0042493,1.1709207,,,,,,,,,,,,,, -128400,3.0153434,1.1279666,,,,,,,,,,,,,, -128500,3.2357523,1.1885364,,,,,,,,,,,,,, -128600,2.6358728,1.1435626,,,,,,,,,,,,,, -128700,3.5144172,1.1768327,,,,,,,,,,,,,, -128800,1.7705295,1.1311408,,,,,,,,,,,,,, -128900,1.3612306,1.1251675,,,,,,,,,,,,,, -129000,3.0188382,1.1410224,,,,,,,,,,,,,, -129100,2.8856466,1.1154373,,,,,,,,,,,,,, -129200,3.5122478,1.1601015,,,,,,,,,,,,,, -129300,3.260222,1.1831396,,,,,,,,,,,,,, -129314,,,0.22234,0.0753358951751978,0.42261353,0.1218224123116135,5348.0,0.2268536,0.0730404403550464,2472.0,105184.04212832452,115087.72128272057,105184.04212832452,9893.917643547058,4.215802431106567,0.0 -129400,2.8219264,1.1308291,,,,,,,,,,,,,, -129500,8.028886,1.1564857,,,,,,,,,,,,,, -129600,4.3072553,1.0926921,,,,,,,,,,,,,, -129700,2.2414675,1.139505,,,,,,,,,,,,,, -129800,2.1414256,1.1801221,,,,,,,,,,,,,, -129900,6.654823,1.1972731,,,,,,,,,,,,,, -130000,2.6278799,1.1728975,,,,,,,,,,,,,, -130100,3.0851703,1.1709999,,,,,,,,,,,,,, -130200,2.2364519,1.1842988,,,,,,,,,,,,,, -130300,2.2822862,1.1535875,,,,,,,,,,,,,, -130400,3.0208893,1.089978,,,,,,,,,,,,,, -130500,2.4142327,1.1040478,,,,,,,,,,,,,, -130600,5.277356,1.1495541,,,,,,,,,,,,,, -130700,4.54691,1.1826488,,,,,,,,,,,,,, -130800,9.02719,1.1766787,,,,,,,,,,,,,, -130900,5.2631025,1.2016169,,,,,,,,,,,,,, -131000,2.5146453,1.1369448,,,,,,,,,,,,,, -131085,,,0.2571062,0.0883086597364024,0.42260465,0.1218899948830338,5348.0,0.22685236,0.0729388824568886,2472.0,106624.42895913124,116658.7487475872,106624.42895913124,10024.404949426653,4.290512323379517,0.0 -131100,1.7232729,1.136343,,,,,,,,,,,,,, -131200,6.5726366,1.1556123,,,,,,,,,,,,,, -131300,2.7315354,1.1208211,,,,,,,,,,,,,, -131400,3.0780988,1.162746,,,,,,,,,,,,,, -131500,4.2357345,1.0887434,,,,,,,,,,,,,, -131600,2.2486882,1.0640637,,,,,,,,,,,,,, -131700,3.471689,1.1911447,,,,,,,,,,,,,, -131800,3.736191,1.1902325,,,,,,,,,,,,,, -131900,2.829664,1.1272244,,,,,,,,,,,,,, -132000,1.7186525,1.1149472,,,,,,,,,,,,,, -132100,2.6405857,1.1491948,,,,,,,,,,,,,, -132200,4.095648,1.1399398,,,,,,,,,,,,,, -132300,5.708994,1.1681455,,,,,,,,,,,,,, -132400,5.769799,1.1767805,,,,,,,,,,,,,, -132500,6.2194123,1.120916,,,,,,,,,,,,,, -132600,4.163292,1.1856368,,,,,,,,,,,,,, -132700,1.838298,1.1105665,,,,,,,,,,,,,, -132800,1.4366361,1.1648165,,,,,,,,,,,,,, -132820,,,0.23016225,0.0764056814180068,0.42252997,0.1217934483524334,5348.0,0.22682151,0.0729185708772571,2472.0,108064.59675383568,118230.8504126072,108064.59675383568,10156.198406934738,4.354108810424805,0.0 -132900,2.8982036,1.2092621,,,,,,,,,,,,,, -133000,3.610946,1.1910418,,,,,,,,,,,,,, -133100,2.1348057,1.175962,,,,,,,,,,,,,, -133200,3.6830754,1.1528984,,,,,,,,,,,,,, -133300,3.9504845,1.1766123,,,,,,,,,,,,,, -133400,2.6086082,1.1842315,,,,,,,,,,,,,, -133500,2.8324378,1.1563377,,,,,,,,,,,,,, -133600,3.2042465,1.2010367,,,,,,,,,,,,,, -133700,2.2251918,1.146907,,,,,,,,,,,,,, -133800,2.5726616,1.2001288,,,,,,,,,,,,,, -133900,2.4287298,1.1939598,,,,,,,,,,,,,, -134000,2.0228176,1.1757935,,,,,,,,,,,,,, -134100,3.9536793,1.1476203,,,,,,,,,,,,,, -134200,5.416658,1.2034857,,,,,,,,,,,,,, -134300,3.3394527,1.1633832,,,,,,,,,,,,,, -134400,2.6219878,1.1661932,,,,,,,,,,,,,, -134500,5.5932026,1.1580133,,,,,,,,,,,,,, -134567,,,0.2190213,0.0762029456223062,0.42256498,0.1218706855769137,5348.0,0.22683118,0.0729388824568886,2472.0,109504.54699015616,119801.12107515337,109504.54699015616,10286.368607997894,4.4263763427734375,0.0 -134600,2.466067,1.0738919,,,,,,,,,,,,,, -134700,1.6252508,1.096957,,,,,,,,,,,,,, -134800,2.8394175,1.1353577,,,,,,,,,,,,,, -134900,2.8359263,1.183267,,,,,,,,,,,,,, -135000,4.052345,1.1479284,,,,,,,,,,,,,, -135100,6.9931884,1.1607931,,,,,,,,,,,,,, -135200,3.17681,1.1065457,,,,,,,,,,,,,, -135300,4.700856,1.2338312,,,,,,,,,,,,,, -135400,3.0093515,1.1929458,,,,,,,,,,,,,, -135500,1.8973569,1.1617093,,,,,,,,,,,,,, -135600,3.8919597,1.1610008,,,,,,,,,,,,,, -135700,12.42344,1.2103459,,,,,,,,,,,,,, -135800,4.002533,1.1595917,,,,,,,,,,,,,, -135900,4.615424,1.0977914,,,,,,,,,,,,,, -136000,3.338388,1.1975381,,,,,,,,,,,,,, -136100,4.5648317,1.195289,,,,,,,,,,,,,, -136200,4.904676,1.1784117,,,,,,,,,,,,,, -136300,3.5626287,1.1599141,,,,,,,,,,,,,, -136327,,,0.19412705,0.0682318335919715,0.4225796,0.1218320669646736,5348.0,0.22684185,0.0729185708772571,2472.0,110944.7632997036,121375.58246660233,110944.7632997036,10420.468059301376,4.494283437728882,0.0 -136400,2.4861236,1.1630052,,,,,,,,,,,,,, -136500,4.3493433,1.1813443,,,,,,,,,,,,,, -136600,2.1859055,1.2027222,,,,,,,,,,,,,, -136700,2.639346,1.162584,,,,,,,,,,,,,, -136800,1.7662046,1.1016763,,,,,,,,,,,,,, -136900,1.6804419,1.0961381,,,,,,,,,,,,,, -137000,4.17113,1.1205677,,,,,,,,,,,,,, -137100,2.3975549,1.1677115,,,,,,,,,,,,,, -137200,2.1032677,1.1468835,,,,,,,,,,,,,, -137300,2.1033366,1.1345109,,,,,,,,,,,,,, -137400,8.432221,1.0995446,,,,,,,,,,,,,, -137500,2.540468,1.1456074,,,,,,,,,,,,,, -137600,3.842726,1.1710317,,,,,,,,,,,,,, -137700,2.163284,1.1369455,,,,,,,,,,,,,, -137800,4.3035545,1.1307983,,,,,,,,,,,,,, -137900,4.961525,1.1473747,,,,,,,,,,,,,, -138000,4.0698004,1.1733099,,,,,,,,,,,,,, -138066,,,0.2334239,0.0803974607883772,0.42255595,0.1218224123116135,5348.0,0.22682582,0.0728982592976255,2472.0,112384.63683605194,122947.29821825027,112384.63683605194,10552.179483890532,4.551663398742676,0.0 -138100,2.1330442,1.1298224,,,,,,,,,,,,,, -138200,2.230385,1.1981742,,,,,,,,,,,,,, -138300,3.2823997,1.1099211,,,,,,,,,,,,,, -138400,3.647338,1.1628152,,,,,,,,,,,,,, -138500,3.052698,1.236324,,,,,,,,,,,,,, -138600,4.1128106,1.1693628,,,,,,,,,,,,,, -138700,2.9104652,1.1637002,,,,,,,,,,,,,, -138800,7.868423,1.1789044,,,,,,,,,,,,,, -138900,1.8306898,1.2041094,,,,,,,,,,,,,, -139000,3.764863,1.1641076,,,,,,,,,,,,,, -139100,4.932463,1.1256688,,,,,,,,,,,,,, -139200,3.8244638,1.2073065,,,,,,,,,,,,,, -139300,1.696536,1.132478,,,,,,,,,,,,,, -139400,1.8298848,1.1372083,,,,,,,,,,,,,, -139500,2.2105865,1.1268982,,,,,,,,,,,,,, -139600,2.4219453,1.2003117,,,,,,,,,,,,,, -139700,4.8333926,1.1136746,,,,,,,,,,,,,, -139800,2.961004,1.1054935,,,,,,,,,,,,,, -139824,,,0.21257159,0.0737407763875521,0.42262492,0.1218031030054935,5348.0,0.22685347,0.0729795056161517,2472.0,113824.81503725052,124521.99733042716,113824.81503725052,10686.553840637209,4.618646621704102,0.0 -139900,2.9438372,1.1354263,,,,,,,,,,,,,, -140000,2.8640711,1.1922129,,,,,,,,,,,,,, -140100,5.1046114,1.1354876,,,,,,,,,,,,,, -140200,2.6530697,1.1588109,,,,,,,,,,,,,, -140300,2.5531287,1.1937907,,,,,,,,,,,,,, -140400,3.185301,1.13469,,,,,,,,,,,,,, -140500,3.2271175,1.1477098,,,,,,,,,,,,,, -140600,2.9203343,1.1734588,,,,,,,,,,,,,, -140700,3.8618112,1.1359322,,,,,,,,,,,,,, -140800,1.9565932,1.1607844,,,,,,,,,,,,,, -140900,2.5093334,1.1529077,,,,,,,,,,,,,, -141000,6.3661532,1.1298563,,,,,,,,,,,,,, -141100,7.9607477,1.1594124,,,,,,,,,,,,,, -141200,1.7515253,1.1349875,,,,,,,,,,,,,, -141300,2.1811106,1.1894349,,,,,,,,,,,,,, -141400,3.0122871,1.171104,,,,,,,,,,,,,, -141500,1.5325648,1.1471009,,,,,,,,,,,,,, -141595,,,0.21481861,0.0760388121012688,0.42257047,0.1218513762707937,5348.0,0.22683278,0.072877947717994,2472.0,115265.31365394592,126095.0590775013,115265.31365394592,10818.959375619888,4.695364475250244,0.0 -141600,5.2602468,1.1408961,,,,,,,,,,,,,, -141700,4.8351135,1.1402581,,,,,,,,,,,,,, -141800,2.7688046,1.1318346,,,,,,,,,,,,,, -141900,5.5102787,1.1539878,,,,,,,,,,,,,, -142000,2.1013455,1.1352901,,,,,,,,,,,,,, -142100,2.5275025,1.1894578,,,,,,,,,,,,,, -142200,1.9319627,1.1500758,,,,,,,,,,,,,, -142300,2.0560758,1.1884121,,,,,,,,,,,,,, -142400,1.3500221,1.1830711,,,,,,,,,,,,,, -142500,2.992257,1.1783376,,,,,,,,,,,,,, -142600,3.4567971,1.1602105,,,,,,,,,,,,,, -142700,3.8246987,1.1800811,,,,,,,,,,,,,, -142800,4.0755386,1.1376092,,,,,,,,,,,,,, -142900,1.5940001,1.117886,,,,,,,,,,,,,, -143000,2.9487135,1.1788814,,,,,,,,,,,,,, -143100,2.4722152,1.1392791,,,,,,,,,,,,,, -143200,1.6529047,1.1733772,,,,,,,,,,,,,, -143300,3.2259624,1.1951874,,,,,,,,,,,,,, -143357,,,0.22314472,0.0771731146606444,0.4225849,0.1217934483524334,5348.0,0.22684102,0.0729795056161517,2472.0,116705.55896472932,127667.55155706406,116705.55896472932,10951.057807445526,4.766013145446777,0.0 -143400,1.7052652,1.1280556,,,,,,,,,,,,,, -143500,4.371748,1.1854366,,,,,,,,,,,,,, -143600,2.4879587,1.1452965,,,,,,,,,,,,,, -143700,3.7342656,1.1903675,,,,,,,,,,,,,, -143800,3.006974,1.1511446,,,,,,,,,,,,,, -143900,1.7821195,1.1688982,,,,,,,,,,,,,, -144000,,,0.23591998,0.0815067903184128,0.42256364,0.1217934483524334,5348.0,0.22682914,0.0729185708772571,2472.0,117226.72539401054,128319.11793255806,117226.72539401054,11081.355797290802,4.835377931594849,0.0 -144000,,,,,,,,,,,117226.72539401054,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 2b2829e4d..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,232 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -535.2156233787537,0.0,19.403101921081543,1,0,19.403101921081543,0.4646042585372925,0.7593051791191101,0.0285717469569702,43793,554.6187665462494,0.4642853140830993,0.7602338194847107,0.0228730991783322,0.4643697440624237,0.7589086294174194,0.0274652618143725,43793 -654.8684275150299,0.027963638305664,259.46107959747314,742,0,259.46107959747314,0.9831854701042176,0.0639450103044509,0.0520615109430797,43793,914.3788795471193,0.9867706298828124,0.0514782257378101,0.0528244412026438,0.9842023849487304,0.0607735253870487,0.0503369997300077,43793 -777.3883748054504,0.0557601451873779,499.4656026363373,1482,0,499.4656026363373,0.9834840893745422,0.0620399415493011,0.0689065527510964,43793,1276.9522080421448,0.9872096180915833,0.0489099249243736,0.0752923644838142,0.9844236373901368,0.0588754415512084,0.0671462059000805,43793 -897.3871071338654,0.0821943283081054,739.4543447494507,2224,0,739.4543447494507,0.9838589429855348,0.056995864957571,0.1230535037550857,43793,1636.9871473312378,0.9875986576080322,0.0446453876793384,0.1208860737385289,0.9848376512527466,0.0538689643144607,0.1229580905647887,43793 -1022.0600764751434,0.1128618717193603,979.6255490779876,2973,0,979.6255490779876,0.9842245578765868,0.0544905923306942,0.1464378353277409,43793,2001.8835163116453,0.987941324710846,0.0425007417798042,0.1518332776698262,0.9851689338684082,0.0516148544847965,0.1471634693879306,43793 -1148.3995158672333,0.1392774581909179,1219.710813999176,3713,0,1219.710813999176,0.9843947291374208,0.0526860170066356,0.1691754874100842,43793,2368.356336593628,0.9883215427398682,0.0405497141182422,0.1873888127236984,0.9853605031967164,0.0498643070459365,0.1723176411435333,43793 -1278.234468460083,0.168677806854248,1459.977115869522,4454,0,1459.977115869522,0.9845686554908752,0.0527504198253154,0.1856167206647187,43793,2738.5078999996185,0.9883311986923218,0.0401252694427967,0.2039251994853367,0.9854956865310668,0.0498076863586902,0.1814416540684904,43793 -1411.4457199573517,0.1960334777832031,1700.0847454071045,5185,0,1700.0847454071045,0.9848546385765076,0.0508074052631855,0.196022459140312,43793,3111.8785684108734,0.9888564944267272,0.0382429473102092,0.2338408167710781,0.9857197999954224,0.0480681993067264,0.1933692347295754,43793 -1540.9889826774595,0.2287201881408691,1940.293526172638,5918,0,1940.293526172638,0.9849650263786316,0.0509615391492843,0.2083791497336055,43793,3481.6874401569366,0.989027202129364,0.0372817702591419,0.2582431101463912,0.9858922958374025,0.0481997281312942,0.2058692487304153,43793 -1668.957560300827,0.2584164142608642,2180.472153902054,6652,0,2180.472153902054,0.9852867722511292,0.0489683225750923,0.2158038268663492,43793,3849.8858761787415,0.9893152713775636,0.0364081189036369,0.2626560365239537,0.9861143231391908,0.0463856682181358,0.2199867776480811,43793 -1803.142256975174,0.2857906818389892,2420.4435505867004,7394,0,2420.4435505867004,0.9853141903877258,0.0483731143176555,0.2241192075235631,43793,4224.092652320862,0.9894617795944214,0.0357305519282817,0.2777985762334495,0.9862133860588074,0.0458943471312522,0.2205825816389293,43793 -1933.096392393112,0.3145980834960937,2660.463634490967,8133,0,2660.463634490967,0.9854860305786132,0.0479992181062698,0.2323910319049115,43793,4594.116451025009,0.9896026849746704,0.0350330807268619,0.3017076079581593,0.9864204525947572,0.0454067513346672,0.2327738536063612,43793 -2062.624636888504,0.3434135913848877,2900.5610976219177,8876,0,2900.5610976219177,0.9855205416679382,0.0482263714075088,0.2397758718441605,43793,4963.792473316193,0.9896164536476136,0.0348296910524368,0.3129953349495222,0.9864743947982788,0.045508861541748,0.2362024521907661,43793 -2191.552747964859,0.3722200393676758,3140.5976366996765,9619,0,3140.5976366996765,0.9856747388839722,0.047865305095911,0.2418852485105137,43793,5332.80707859993,0.9899412989616394,0.0336496606469154,0.3282852467346069,0.9865624904632568,0.0450242049992084,0.2426041542082993,43793 -2320.8233363628387,0.4018385410308838,3380.571000099182,10355,0,3380.571000099182,0.9856966137886048,0.0476522482931613,0.2460464273669188,43793,5702.102422714233,0.9902228116989136,0.0326662361621856,0.3493272232299749,0.9866660237312316,0.0448374822735786,0.2521623679673248,43793 -2454.1698989868164,0.4301824569702148,3620.668796777725,11097,0,3620.668796777725,0.9858115911483764,0.0472144670784473,0.254014006212062,43793,6075.597155094147,0.9903243780136108,0.0321387909352779,0.3718836815702409,0.98664653301239,0.0445500202476978,0.2481188682180836,43793 -2584.4375660419464,0.4595632553100586,3860.690770864487,11843,0,3860.690770864487,0.9857800006866456,0.0471613258123397,0.2472116040143474,43793,6445.938730955124,0.9905675053596495,0.0313689559698104,0.3800155912255726,0.9866924285888672,0.0444240570068359,0.2527786570240431,43793 -2716.204756736756,0.4894959926605224,4100.80003118515,12581,0,4100.80003118515,0.9857972860336304,0.0474021434783935,0.2524061717930836,43793,6817.867141008377,0.9902974367141724,0.0321634970605373,0.356260952255887,0.9867252707481384,0.044447436928749,0.2620155015599733,43793 -2841.2963259220123,0.5184030532836914,4340.828438043594,13326,0,4340.828438043594,0.9858890771865844,0.0476423725485801,0.2547764526425899,43793,7183.03778886795,0.9902656674385072,0.0322396829724311,0.3608134814237923,0.986819863319397,0.0445113666355609,0.2681383014152059,43793 -2971.42048573494,0.5491487979888916,4580.7753620147705,14067,0,4580.7753620147705,0.9859737753868104,0.0468311235308647,0.2589879758092704,43793,7553.162268877029,0.9905223846435548,0.031429897993803,0.3760678922408069,0.9869217872619628,0.043913647532463,0.2748815194390102,43793 -3101.82884144783,0.5812263488769531,4820.819413661957,14798,0,4820.819413661957,0.9856927990913392,0.0475738048553466,0.2555561248051884,43793,7923.672146081924,0.9905580878257751,0.0311930775642395,0.3902971635582445,0.9866790175437928,0.0446098558604717,0.2687362654336637,43793 -3231.4354071617126,0.6112728118896484,5060.941811084747,15524,0,5060.941811084747,0.985888659954071,0.0469398722052574,0.2664834005963081,43793,8293.455752372742,0.9906007647514344,0.0306931287050247,0.4037977692179422,0.9867488145828248,0.0442212112247943,0.2687962872441771,43793 -3361.455674648285,0.6414785385131836,5301.218809843063,16263,0,5301.218809843063,0.9859308004379272,0.0465576462447643,0.2644002246557672,43793,8663.80548620224,0.9907249808311462,0.030366413295269,0.4092169985073081,0.9867720007896424,0.0439391694962978,0.268025885787124,43793 -3490.2316720485687,0.672264575958252,5541.230150699616,16999,0,5541.230150699616,0.9860049486160278,0.0471245720982551,0.2623276261824349,43793,9032.645967960358,0.9909940361976624,0.029345579445362,0.4297240381765557,0.9869027137756348,0.0440330728888511,0.2727817331032616,43793 -3622.917279958725,0.7036941051483154,5781.477691650391,17734,0,5781.477691650391,0.9858478307724,0.0472505912184715,0.2620539259508231,43793,9405.632671833038,0.9909370541572572,0.0297421440482139,0.4172895521300028,0.9867074489593506,0.0442817248404026,0.266266108306227,43793 -3752.345653772354,0.7333090305328369,6021.736469745636,18468,0,6021.736469745636,0.9860584139823914,0.0464429780840873,0.266278519375825,43793,9775.371485948564,0.9910364151000975,0.0295282341539859,0.423870057287632,0.986844837665558,0.0438341982662677,0.2783312649957931,43793 -3880.687066555023,0.7648305892944336,6261.852321386337,19203,0,6261.852321386337,0.9859741926193236,0.0465784110128879,0.267380460209644,43793,10143.883538722992,0.9909691214561462,0.0296655260026454,0.4261870144662457,0.9869157075881958,0.0437863618135452,0.278916097311652,43793 -4011.790348768234,0.7990889549255371,6502.033669233322,19936,0,6502.033669233322,0.9857467412948608,0.0468638837337493,0.2608550174941025,43793,10515.224586963654,0.9907914400100708,0.0301237683743238,0.4203923085594159,0.9866023063659668,0.0442087799310684,0.274155667844338,43793 -4139.134879112244,0.8307468891143799,6742.084539890289,20676,0,6742.084539890289,0.9859619736671448,0.0467089414596557,0.2625949482700543,43793,10882.673321723938,0.990973174571991,0.0296208467334508,0.4198908615641574,0.9867675304412842,0.0440637990832328,0.2745195357738439,43793 -4272.533679008484,0.861682653427124,6982.223841190338,21409,0,6982.223841190338,0.9859548211097716,0.0463682934641838,0.2663493902990147,43793,11256.264298200607,0.9911744594573976,0.0287772081792354,0.4454445866037658,0.986868977546692,0.0436342731118202,0.2769052961180017,43793 -4400.717868804932,0.8936870098114014,7222.290660858154,22141,0,7222.290660858154,0.986151933670044,0.0466170087456703,0.2707095616703951,43793,11624.571064710615,0.9913221597671508,0.0282847136259078,0.4633205340969653,0.9869778156280518,0.0438401438295841,0.2831026302198904,43793 -4531.240593194962,0.9250195026397704,7462.577006340027,22880,0,7462.577006340027,0.986132562160492,0.0468167290091514,0.2688434020944504,43793,11995.432676315308,0.9915106296539308,0.0275696720927953,0.4735089752188432,0.9869562983512878,0.0441528148949146,0.2790978605654699,43793 -4659.816511154175,0.9558253288269044,7702.591665744781,23626,0,7702.591665744781,0.98606139421463,0.0470444560050964,0.2641718235610194,43793,12364.075269460678,0.9915829300880432,0.0273137167096138,0.4782024616666336,0.9868547916412354,0.0442686900496482,0.2710756095154921,43793 -4790.036564588547,0.98685622215271,7942.569582223892,24373,0,7942.569582223892,0.9861839413642884,0.0465402007102966,0.2701791484926142,43793,12734.326212882996,0.991339385509491,0.0281135123223066,0.4599286637345375,0.9870126843452454,0.043771956115961,0.2795645558470025,43793 -4914.877838134766,1.0195541381835938,8182.74741435051,25118,0,8182.74741435051,0.9860343933105468,0.0470565073192119,0.2656527333506259,43793,13099.399793863297,0.9912641048431396,0.0285040661692619,0.4477710025916689,0.9869189262390136,0.0441875495016574,0.2854335369192523,43793 -5045.421402454376,1.0534796714782717,8422.905773878098,25853,0,8422.905773878098,0.98613041639328,0.0464157201349735,0.2715346448095588,43793,13470.157238006592,0.9912461638450624,0.0284663066267967,0.4541184393410252,0.986922562122345,0.0437526814639568,0.2801145594678734,43793 -5172.97108578682,1.0846607685089111,8662.952923297882,26590,0,8662.952923297882,0.986127495765686,0.046985313296318,0.268233498378407,43793,13837.806899309158,0.9915059804916382,0.0278779752552509,0.4528256546297896,0.9869843125343324,0.0440012738108634,0.2870869938781617,43793 -5300.935266256332,1.1179378032684326,8903.157333612442,27329,0,8903.157333612442,0.9861692190170288,0.046562697738409,0.2776822753222365,43793,14206.030374526978,0.991453230381012,0.0276226308196783,0.4781773408012578,0.9870342016220092,0.0437435507774353,0.2914553301203189,43793 -5430.845782995224,1.1518630981445312,9143.340360164642,28067,0,9143.340360164642,0.9861502647399902,0.0464490689337253,0.2764971134633485,43793,14576.181078672407,0.991605579853058,0.0272387899458408,0.4788237662347861,0.986989974975586,0.0438026152551174,0.2826635722662011,43793 -5559.781363964081,1.1833415031433103,9383.368975162506,28809,0,9383.368975162506,0.9860761165618896,0.0462912358343601,0.2733801524206768,43793,14945.198653936386,0.991728663444519,0.0268288888037204,0.4913344138715968,0.9869307279586792,0.0436968393623828,0.2770742675956774,43793 -5687.874805927277,1.215362310409546,9623.559691429138,29552,0,9623.559691429138,0.9861944913864136,0.0467113442718982,0.2743730009782902,43793,15313.5361597538,0.9918919205665588,0.026216072961688,0.5069955294043473,0.9869245886802672,0.0441036969423294,0.2867755501464883,43793 -5818.787388086319,1.248405933380127,9863.722933769226,30291,0,9863.722933769226,0.9861359000205994,0.0463078506290912,0.2734795702662339,43793,15684.667104959488,0.9920125007629396,0.0261006280779838,0.5134227238878066,0.9868564009666444,0.0438586100935936,0.2853673628324881,43793 -5952.775523900986,1.2820994853973389,10103.851176023483,31034,0,10103.851176023483,0.9862104654312134,0.0466993637382984,0.2768710218968238,43793,16058.839148044586,0.9918270707130432,0.0266173146665096,0.4979042791415171,0.9870216250419616,0.0439487509429454,0.2887786879950446,43793 -6079.255536794663,1.3162181377410889,10343.984036922457,31766,0,10343.984036922457,0.986176371574402,0.0468368269503116,0.2766105111049243,43793,16425.5092420578,0.9914273023605348,0.0275633074343204,0.4744032507199148,0.9870082139968872,0.0441398322582244,0.2862378750792766,43793 -6206.500435352325,1.350158452987671,10584.240829467772,32504,0,10584.240829467772,0.9860592484474182,0.0467476919293403,0.2741441695898787,43793,16793.06715464592,0.991688907146454,0.0269593149423599,0.4850293929760725,0.9868718385696412,0.0442156381905078,0.286878314432203,43793 -6336.887906312943,1.3831946849822998,10824.421991825104,33243,0,10824.421991825104,0.9863191246986388,0.0469299219548702,0.274600195758634,43793,17163.691017866135,0.9917634129524232,0.0265021454542875,0.5059145808603558,0.9871076941490172,0.0442456379532814,0.2891483056162617,43793 -6466.034151554108,1.4176712036132812,11064.680742740631,33978,0,11064.680742740631,0.9861114621162416,0.0466743633151054,0.2715926387695952,43793,17533.152592658997,0.9919225573539734,0.0260626841336488,0.5149891537773273,0.986989974975586,0.043880894780159,0.2891141455523475,43793 -6594.542794704437,1.4511137008666992,11304.925557374954,34716,0,11304.925557374954,0.986080765724182,0.0472131036221981,0.2721279982617621,43793,17901.961685180664,0.992056965827942,0.0256364122033119,0.5100176386259829,0.9869741201400756,0.044404212385416,0.284721986229574,43793 -6724.5836000442505,1.4851961135864258,11544.964567899704,35448,0,11544.964567899704,0.9861615896224976,0.0470213033258914,0.2708743413208484,43793,18272.09800696373,0.9924238920211792,0.024371912702918,0.5482731850412265,0.9869936108589172,0.044317964464426,0.2857833668199892,43793 -6853.508927345276,1.5202860832214355,11784.975558757782,36186,0,11784.975558757782,0.9862930178642272,0.047236256301403,0.2747061458046834,43793,18641.091474056244,0.99224191904068,0.0249270088970661,0.5376483588503354,0.9870427250862122,0.0444023720920085,0.2826900278886544,43793 -6980.73330283165,1.5558314323425293,12025.085746765137,36923,0,12025.085746765137,0.9862037301063538,0.0468671582639217,0.2786638609715531,43793,19008.48497748375,0.9921140670776368,0.0253934878855943,0.519401692550248,0.9870488047599792,0.0440136454999446,0.2858875844680839,43793 -7109.82846736908,1.5917348861694336,12265.234405755997,37667,0,12265.234405755997,0.9862450361251832,0.0476358421146869,0.27484689170873,43793,19377.787124872208,0.9920461177825928,0.0255558025091886,0.5183719989880958,0.9870699644088744,0.044741440564394,0.2828680161152879,43793 -7239.716079473495,1.6303064823150637,12505.260452270508,38403,0,12505.260452270508,0.9862176179885864,0.0471333190798759,0.2745749674200974,43793,19747.76282644272,0.99214369058609,0.0253321044147014,0.517584729671504,0.9869266152381896,0.0444921404123306,0.2914535596647501,43793 -7368.2566702365875,1.6696994304656982,12745.4219083786,39131,0,12745.4219083786,0.9861586689949036,0.0471546128392219,0.275559722777406,43793,20116.52850627899,0.9921480417251588,0.0249071978032588,0.5369891992399025,0.9869238138198853,0.04449999704957,0.2855483362441873,43793 -7491.874098777771,1.7054214477539062,12985.612907409668,39866,0,12985.612907409668,0.9861182570457458,0.046987347304821,0.2686304230613403,43793,20480.394863128666,0.9923309087753296,0.0244951397180557,0.534551486086762,0.9868633151054382,0.0443124733865261,0.2802457208483463,43793 -7630.08384180069,1.7412102222442627,13225.71517944336,40600,0,13225.71517944336,0.986295998096466,0.0472631752490997,0.2766268806489722,43793,20858.764546394348,0.9924687743186952,0.0240130554884672,0.5691049705818645,0.9870585799217224,0.0444880723953247,0.2891847364191049,43793 -7756.849696874618,1.7805137634277344,13465.765576124191,41341,0,13465.765576124191,0.986275315284729,0.0476203300058841,0.2830297049350724,43793,21225.642529010773,0.9926599860191344,0.0233439113944768,0.5782135633249319,0.9871352910995485,0.0446878522634506,0.2936668428755002,43793 -7883.743916749954,1.8158748149871824,13705.795667171478,42081,0,13705.795667171478,0.9861797094345092,0.0473838150501251,0.2803766791096357,43793,21592.62402153015,0.9928129315376282,0.0229587573558092,0.5814897734017909,0.987063467502594,0.0444525331258773,0.2931879108949782,43793 -8011.200119256973,1.8510291576385496,13945.906858921053,42818,0,13945.906858921053,0.9862193465232848,0.0474888384342193,0.2773864421082346,43793,21960.24821233749,0.9925578236579896,0.0238033216446638,0.564259724336508,0.9869980812072754,0.0447949543595314,0.2879867390854697,43793 -8138.873854398727,1.886252403259277,14185.98506641388,43548,0,14185.98506641388,0.9861186742782592,0.0477583967149257,0.2761628388744243,43793,22328.05681943893,0.9924206733703612,0.0242899972945451,0.5410177292215665,0.9870171546936036,0.0448606684803962,0.2942300050683708,43793 -8267.179068088531,1.923084735870361,14426.06815457344,44278,0,14426.06815457344,0.9861868619918824,0.047412171959877,0.2761934727232424,43793,22696.50433659553,0.9924456477165222,0.0241081770509481,0.5539474357205003,0.9869996905326844,0.0446444861590862,0.2899616419114914,43793 -8392.092687368393,1.9581577777862549,14666.161107301712,45009,0,14666.161107301712,0.986127495765686,0.0472488440573215,0.2769249759013684,43793,23061.56834554672,0.992694079875946,0.0233445260673761,0.5726172586558297,0.986934781074524,0.0444469712674617,0.2926345439056429,43793 -8515.576565265656,1.9948420524597168,14906.305232286451,45750,0,14906.305232286451,0.9862618446350098,0.0473022013902664,0.2788322799294261,43793,23425.255395650864,0.9927871227264404,0.022890530526638,0.5784728400986485,0.987106442451477,0.0445770323276519,0.2865539158901619,43793 -8642.272961139679,2.0315263271331787,15146.25148677826,46486,0,15146.25148677826,0.9862555265426636,0.048071514815092,0.2768047922216397,43793,23791.95691871643,0.9928334951400756,0.0226695798337459,0.5871963290355688,0.9870455861091614,0.045237760990858,0.2849342940890492,43793 -8766.850444555283,2.067582845687866,15386.214908361437,47225,0,15386.214908361437,0.9862689971923828,0.0480483807623386,0.2787716507698172,43793,24156.55635929108,0.9932576417922974,0.0212671887129545,0.6281929712934287,0.9870005249977112,0.0450255014002323,0.2873283203104504,43793 -8892.886050701141,2.1041386127471924,15626.345014333723,47963,0,15626.345014333723,0.9861649870872498,0.0481113791465759,0.2771047990863807,43793,24522.78030061721,0.993298590183258,0.0214150361716747,0.6106789888766017,0.9869879484176636,0.045169573277235,0.2882434698512126,43793 -9019.805342435837,2.1411499977111816,15866.32997751236,48705,0,15866.32997751236,0.986295998096466,0.0483037866652011,0.2816303732269372,43793,24889.743696928024,0.993052899837494,0.0219904389232397,0.602626412423219,0.987127959728241,0.0452293753623962,0.2974782090753343,43793 -9149.583971261978,2.1799392700195312,16106.437074422836,49434,0,16106.437074422836,0.9862993359565736,0.0482907220721244,0.278605704419052,43793,25259.69244766236,0.993026316165924,0.0220904797315597,0.5961216643344134,0.9870460033416748,0.0452532954514026,0.2912603318833082,43793 -9276.038954019548,2.2161362171173096,16346.489510059357,50173,0,16346.489510059357,0.9862332344055176,0.0483635328710079,0.2769694896537659,43793,25626.25821518898,0.9931430816650392,0.0218639783561229,0.5994783641295355,0.9869099855422974,0.0454811379313468,0.2840237806197731,43793 -9408.138950824738,2.25294828414917,16586.699061632156,50904,0,16586.699061632156,0.9863094687461852,0.0486667156219482,0.2779850187157393,43793,25998.626806259155,0.9931342601776124,0.0215610228478908,0.6124295443985085,0.9870536923408508,0.0457516275346279,0.2861033373103874,43793 -9540.113079071043,2.294236183166504,16826.849710941315,51633,0,16826.849710941315,0.9862887859344482,0.0488758012652397,0.2805226259149176,43793,26370.819693803787,0.9932966232299804,0.021137511357665,0.6139599699223799,0.9870638251304626,0.0457816421985626,0.2866013596010617,43793 -9663.73701095581,2.3359954357147217,17066.924193382263,52364,0,17066.924193382263,0.986327588558197,0.0485045313835144,0.2836710398323977,43793,26734.583471298218,0.9933950901031494,0.0207604877650737,0.6385449009632981,0.987082540988922,0.0456486195325851,0.2941234444125793,43793 -9789.492477416992,2.375693798065185,17307.14094877243,53103,0,17307.14094877243,0.986295998096466,0.0487239696085453,0.2818153417585818,43793,27100.617507457733,0.9938114285469056,0.0195509679615497,0.6596532125631671,0.987057328224182,0.0461726002395153,0.2965265591422013,43793 -9909.892055034636,2.4137353897094727,17547.083233356476,53835,0,17547.083233356476,0.9863587617874146,0.0489544048905372,0.284141920602494,43793,27461.019245386124,0.9938933253288268,0.0194165743887424,0.6604764193160376,0.9871150255203248,0.0463631376624107,0.2949741394119192,43793 -10039.527564287186,2.451258659362793,17787.043855190277,54579,0,17787.043855190277,0.9861123561859132,0.0491110160946846,0.2809566913061406,43793,27830.67459988594,0.9937782287597656,0.0197028685361146,0.6423692855399952,0.9869229793548584,0.046136025339365,0.293365370342748,43793 -10167.246821165085,2.491405248641968,18027.26170873642,55314,0,18027.26170873642,0.9862247705459596,0.0493720471858978,0.2805172059646496,43793,28198.673761606216,0.9934359788894652,0.0206140764057636,0.631309266735881,0.98701673746109,0.0463982000946998,0.291090867299716,43793 -10294.143009662628,2.530449390411377,18267.27084064484,56050,0,18267.27084064484,0.986237406730652,0.049305684864521,0.285898564507146,43793,28565.63999724388,0.9934125542640686,0.020549688488245,0.6211181522600252,0.9869664311408995,0.0465161129832267,0.291339921317668,43793 -10420.40105509758,2.569205045700073,18507.41755247116,56785,0,18507.41755247116,0.9862247705459596,0.0494795329868793,0.2794730220607643,43793,28932.106006860733,0.9937226176261902,0.0196337290108203,0.6670870496531485,0.9869379997253418,0.0466276854276657,0.2882335202718856,43793 -10549.771411418917,2.607391595840454,18747.6509013176,57508,0,18747.6509013176,0.986193597316742,0.0498026758432388,0.2824516621344256,43793,29301.772493600845,0.9938977956771852,0.0191189385950565,0.6645415284070957,0.9869725108146667,0.046925239264965,0.287619898859582,43793 -10673.395092010498,2.649985790252685,18987.619652748108,58235,0,18987.619652748108,0.9862008094787598,0.0501836277544498,0.2808995364660824,43793,29665.43332839012,0.9940282702445984,0.0187131371349096,0.6765143704229382,0.986976146697998,0.0473107621073722,0.2894394066239375,43793 -10796.089787244797,2.6901214122772217,19227.710742235184,58968,0,19227.710742235184,0.9862332344055176,0.0505859851837158,0.2774700604654444,43793,30028.28174901009,0.9943478107452391,0.0176516938954591,0.698494736500919,0.9870277047157288,0.0475666038691997,0.2892462198395104,43793 -10918.601605176926,2.7292253971099854,19467.699901103973,59706,0,19467.699901103973,0.9862951040267944,0.0503821186721324,0.2815228197907312,43793,30390.844262599945,0.9944453835487366,0.017467513680458,0.7007133967198509,0.9869737029075624,0.0474945120513439,0.2934684995339162,43793 -11040.682296514511,2.767962694168091,19707.881704092026,60443,0,19707.881704092026,0.9862176179885864,0.0507260598242282,0.2830889685645464,43793,30753.16907453537,0.9942583441734314,0.0180808901786804,0.6905366072922718,0.986946940422058,0.0475612133741378,0.2924110143159727,43793 -11166.49221110344,2.806538820266724,19947.89863538742,61177,0,19947.89863538742,0.9861498475074768,0.0508788302540779,0.2810613524262955,43793,31119.0564391613,0.994189202785492,0.0180929508060216,0.6996259855388032,0.9869688749313354,0.0477571599185466,0.2939358251755667,43793 -11289.23188996315,2.847842216491699,20187.945281505585,61894,0,20187.945281505585,0.9862441420555116,0.0512084066867828,0.2797026918756632,43793,31481.90607857704,0.9942132234573364,0.0179898701608181,0.6911047828460279,0.9869351387023926,0.0481023862957954,0.2883425479472828,43793 -11412.255279064178,2.88604474067688,20427.938327550888,62626,0,20427.938327550888,0.9862605929374696,0.0515501610934734,0.2782914654132266,43793,31844.98264527321,0.9943777322769164,0.0174479763954877,0.6899870812187572,0.9869586825370787,0.0483387596905231,0.2845790269248508,43793 -11528.714675426483,2.926175355911255,20667.90618634224,63360,0,20667.90618634224,0.986270308494568,0.0512690246105194,0.280880005569525,43793,32201.472920179367,0.9945623874664308,0.0169183146208524,0.7184558394793336,0.986990749835968,0.04825434461236,0.2925786355656576,43793 -11651.349862098694,2.965777635574341,20908.09387183189,64106,0,20908.09387183189,0.9863195419311525,0.0515687614679336,0.2840452353283872,43793,32564.357125520703,0.9946512579917908,0.0165587794035673,0.724740173698647,0.9870366454124452,0.0486161857843399,0.2894655643069859,43793 -11775.719557523727,3.005262613296509,21148.1995844841,64844,0,21148.1995844841,0.986154854297638,0.0515463948249816,0.2841144505434902,43793,32928.89493846893,0.995011568069458,0.0157795790582895,0.7456424173940478,0.9868746995925904,0.0486071705818176,0.2935998478437701,43793 -11892.819792985916,3.0460572242736816,21388.1816380024,65585,0,21388.1816380024,0.9861262440681458,0.0519878938794136,0.2851634495635804,43793,33286.03973245621,0.9950687885284424,0.0155374398455023,0.744460006485492,0.9869189262390136,0.0489300824701786,0.2965670504331467,43793 -12008.396959066393,3.0855839252471924,21628.180775880814,66332,0,21628.180775880814,0.986139714717865,0.0520018711686134,0.2864536287657171,43793,33641.677525281906,0.9950166940689088,0.0156867504119873,0.7417331948355668,0.9868669509887696,0.0488156154751777,0.2948186594970466,43793 -12129.832447767258,3.125771999359131,21868.17398071289,67063,0,21868.17398071289,0.98628968000412,0.0523262955248355,0.2879645668159487,43793,34003.17022848129,0.9946566820144652,0.0164927765727043,0.7092785997806119,0.9869843125343324,0.049311026930809,0.2938793254291189,43793 -12246.04151725769,3.166090250015259,22108.26478767395,67806,0,22108.26478767395,0.9861851930618286,0.0525597482919693,0.2859367983010787,43793,34359.531853199005,0.9946697354316713,0.0165935941040515,0.7233143226581569,0.9868795275688172,0.0493571013212204,0.2930838377370317,43793 -12361.262688159944,3.2080864906311035,22348.27229404449,68554,0,22348.27229404449,0.9861515164375304,0.0522307381033897,0.2888726151712065,43793,34714.82387185097,0.9949089884757996,0.0157757550477981,0.7360908466647706,0.9869266152381896,0.0492417514324188,0.2951497027189781,43793 -12479.956215381622,3.247627258300781,22588.225852251053,69283,0,22588.225852251053,0.986163318157196,0.0527291633188724,0.2851350869481603,43793,35073.536291360855,0.9949740171432496,0.0155983455479145,0.7509368956064921,0.986963152885437,0.0495180785655975,0.2943744306236122,43793 -12599.468941926956,3.289933919906616,22828.19718813896,70025,0,22828.19718813896,0.9862056374549866,0.0529153123497962,0.285105080197603,43793,35433.087446689606,0.9950136542320251,0.0153692392632365,0.7516144347200444,0.9869911670684814,0.049633577466011,0.2945740366852919,43793 -12721.406133413317,3.331706047058105,23068.196282863617,70763,0,23068.196282863617,0.9860811829566956,0.0528018586337566,0.2839924188185285,43793,35795.0873939991,0.9954207539558412,0.0145008927211165,0.7669779487185071,0.9868775010108948,0.0495731830596923,0.2936450999214748,43793 -12844.09109044075,3.37443470954895,23308.296140670776,71500,0,23308.296140670776,0.9862028956413268,0.0531973056495189,0.2830510085013739,43793,36157.941160440445,0.9954916834831238,0.0141904382035136,0.7654607218142788,0.9869996905326844,0.049864936619997,0.2918474328552706,43793 -12959.840638399124,3.4160678386688232,23548.290924072266,72242,0,23548.290924072266,0.9862096309661864,0.0534116253256797,0.284107939661829,43793,36513.74961900711,0.9953582286834716,0.0145191708579659,0.7666464923592602,0.9869623780250548,0.0500339493155479,0.292061711280013,43793 -13080.330332517624,3.4564311504364014,23788.352286815643,72992,0,23788.352286815643,0.9861733913421632,0.0535128824412822,0.283005966453115,43793,36874.36281085014,0.9953705072402954,0.0144443539902567,0.7663247210353313,0.9869558811187744,0.0500800125300884,0.2917908410573102,43793 -13198.867261886597,3.4985146522521973,24028.334567308422,73740,0,24028.334567308422,0.9862037301063538,0.0534953586757183,0.2853542408933036,43793,37232.946142435074,0.9952463507652284,0.0148539431393146,0.772484291095771,0.9869952201843262,0.0501833371818065,0.2928766906316999,43793 -13316.381780862808,3.541311264038086,24268.569067955017,74480,0,24268.569067955017,0.9862239360809326,0.0535146072506904,0.2863500061223415,43793,37590.75980067253,0.9953171610832214,0.0145237715914845,0.7563114128114683,0.9869773983955384,0.0501542091369628,0.2936811767705728,43793 -13435.755778074265,3.583484649658203,24508.743983984,75214,0,24508.743983984,0.9862167835235596,0.0534422062337398,0.2858810997305218,43793,37950.37283325195,0.9954608678817748,0.0142556382343173,0.7678145129889065,0.9869444966316224,0.0501078702509403,0.2945273355628647,43793 -13550.260266304016,3.625715970993042,24748.70914363861,75953,0,24748.70914363861,0.9862155318260192,0.0535604022443294,0.2856587078179153,43793,38304.90705013275,0.9954636096954346,0.0141442064195871,0.7713807343557014,0.9870074391365052,0.050181545317173,0.2951738603598899,43793 -13670.602442026138,3.667493104934693,24988.704931735992,76699,0,24988.704931735992,0.986198663711548,0.0535183511674404,0.2852305929566506,43793,38665.30813074112,0.995585560798645,0.0139017924666404,0.7811432594628471,0.9869680404663086,0.0502056255936622,0.2938870177098965,43793 -13792.465679645538,3.710915088653565,25228.811174869537,77440,0,25228.811174869537,0.9862180352211,0.0535914674401283,0.2857701660308041,43793,39027.34461021423,0.9955874085426332,0.0138820931315422,0.7873145623554283,0.986979842185974,0.0502523556351661,0.2945675573855502,43793 -13906.49803853035,3.752807378768921,25469.02388048172,78188,0,25469.02388048172,0.9862058162689208,0.0535954870283603,0.2853759011460834,43793,39381.65291333199,0.9955474734306335,0.0139575647190213,0.7654514914396318,0.986976146697998,0.0502530895173549,0.2955371902344945,43793 -14023.504931926727,3.79575777053833,25709.17568540573,78924,0,25709.17568540573,0.9861990809440612,0.0535637363791465,0.2856132740395028,43793,39738.876128435135,0.9955233931541444,0.0140358218923211,0.7749583686005181,0.9869713187217712,0.0502181053161621,0.2957871287319151,43793 -14138.14640378952,3.837240219116211,25949.15015339852,79667,0,25949.15015339852,0.9861965775489808,0.0535612180829048,0.2857440594827082,43793,40093.55513715744,0.9955575466156006,0.0140057615935802,0.7686545535718703,0.9869741201400756,0.050214298069477,0.2952958390202684,43793 -14259.4233148098,3.880352020263672,26189.18081855774,80396,0,26189.18081855774,0.9861940741539,0.0535595826804637,0.2856693514423585,43793,40454.92825961113,0.9955423474311828,0.0139799173921346,0.7752668434431129,0.9869745373725892,0.0502125211060047,0.2944926144825867,43793 -14380.997165203094,3.922349452972412,26429.1895840168,81136,0,26429.1895840168,0.9861940741539,0.0535595826804637,0.2857074698180062,43793,40816.57488918304,0.9955633878707886,0.0139617240056395,0.7864327946995338,0.9869745373725892,0.0502125211060047,0.2945485630598818,43793 -14503.09766292572,3.96537709236145,26669.127532482147,81870,0,26669.127532482147,0.9861940741539,0.0535595789551734,0.2857734941346637,43793,41178.67870616913,0.995542049407959,0.0140200974419713,0.7701309672792755,0.9869745373725892,0.0502125211060047,0.2944653538357855,43793 -14621.939532279968,4.011141300201416,26909.205247163773,82607,0,26909.205247163773,0.9861940741539,0.0535595789551734,0.2856746842297289,43793,41537.66540026665,0.9955546259880066,0.013937359675765,0.7792414263929057,0.9869743585586548,0.0502125211060047,0.2944700054731899,43793 -14742.293938875198,4.054401874542236,27149.281101465225,83349,0,27149.281101465225,0.9861940741539,0.0535595789551734,0.2857537976988816,43793,41898.16096329689,0.9955734610557556,0.0139914928004145,0.7649956167611187,0.9869745373725892,0.0502125173807144,0.294621958718315,43793 -14861.458558797836,4.097753524780273,27389.44238615036,84073,0,27389.44238615036,0.9861940741539,0.0535595789551734,0.2857942642529996,43793,42257.554856061935,0.9955219030380248,0.0140211423859,0.7767582004172939,0.9869745373725892,0.0502125211060047,0.2945195156428599,43793 -14977.927819490433,4.1424174308776855,27629.67704296112,84796,0,27629.67704296112,0.9861940741539,0.0535595826804637,0.2857006953799908,43793,42614.32751560211,0.9955556392669678,0.0139660816639661,0.7768871909447967,0.9869745373725892,0.0502125173807144,0.2944212728334547,43793 -15095.281085968018,4.185272216796875,27869.899612665176,85525,0,27869.899612665176,0.9861940741539,0.0535595826804637,0.285688490661133,43793,42971.96783590317,0.9955314993858336,0.0140366535633802,0.7768274028723495,0.9869745373725892,0.0502125211060047,0.294535595189946,43793 -15214.328017950058,4.231759786605835,28109.899499177933,86254,0,28109.899499177933,0.9861940741539,0.0535595789551734,0.2857650136919014,43793,43331.08562636376,0.9955706000328064,0.0139378253370523,0.7793773087752416,0.9869743585586548,0.0502125173807144,0.294717817910109,43793 -15335.024807214735,4.275558471679688,28349.91640353203,86979,0,28349.91640353203,0.9861940741539,0.0535595789551734,0.2858110313416989,43793,43691.868216753006,0.995578110218048,0.0139131750911474,0.7689155617309109,0.9869745373725892,0.0502125211060047,0.2945113082026619,43793 -15453.49022746086,4.321666479110718,28590.017910003666,87709,0,28590.017910003666,0.9861940741539,0.0535595789551734,0.2858338221202554,43793,44050.504854917526,0.9954716563224792,0.0142105771228671,0.7767942679630444,0.9869741201400756,0.0502125211060047,0.2945353564006339,43793 -15576.203959941864,4.365557670593262,28830.172302246094,88433,0,28830.172302246094,0.9861940741539,0.0535595789551734,0.2857710486792785,43793,44413.43903660774,0.9956030249595642,0.0138314524665474,0.7797903015231427,0.9869743585586548,0.0502125211060047,0.294608031294072,43793 -15693.201646327972,4.409271240234375,29070.12420630455,89165,0,29070.12420630455,0.9861940741539,0.0535595789551734,0.2857071438595264,43793,44770.4546122551,0.995522916316986,0.0140581410378217,0.7773444141831463,0.9869743585586548,0.0502125211060047,0.2945774283025879,43793 -15816.148104906082,4.453778028488159,29310.1939008236,89897,0,29310.1939008236,0.9861940741539,0.0535595789551734,0.2859636076934878,43793,45133.53694319725,0.9955548048019408,0.0139969484880566,0.7708878094406291,0.9869745373725892,0.0502125211060047,0.2944496519986352,43793 -15933.94653224945,4.502529859542847,29550.39145565033,90633,0,29550.39145565033,0.9861940741539,0.0535595789551734,0.2857039443897824,43793,45491.60465455055,0.9955064058303832,0.0140866888687014,0.7742396649486707,0.9869745373725892,0.0502125211060047,0.2945673298006516,43793 -16048.78459572792,4.54682207107544,29790.36308145523,91370,0,29790.36308145523,0.9861940741539,0.0535595789551734,0.2856685747868243,43793,45846.48057794571,0.9955812096595764,0.0139295449480414,0.7740129386936662,0.9869745373725892,0.0502125211060047,0.294608725721259,43793 -16163.1401386261,4.591657400131226,30030.31191968918,92104,0,30030.31191968918,0.9861940741539,0.0535595789551734,0.2859116347293922,43793,46200.85204720497,0.9955791234970092,0.0138949332758784,0.7753812512261653,0.9869743585586548,0.0502125211060047,0.2945517697474837,43793 -16278.03237748146,4.639986991882324,30270.42789888382,92839,0,30270.42789888382,0.9861940741539,0.0535595789551734,0.2857860262767945,43793,46555.93433070183,0.9955458641052246,0.0140043785795569,0.7818500732533857,0.9869745373725892,0.0502125211060047,0.2944538656084857,43793 -16402.670434951782,4.685810327529907,30510.66025996208,93570,0,30510.66025996208,0.9861940741539,0.0535595826804637,0.2858384018367464,43793,46920.87311458588,0.9955337047576904,0.0139936245977878,0.7719343342136304,0.9869745373725892,0.0502125173807144,0.2945561478444909,43793 -16515.649873495102,4.73239541053772,30750.85684776306,94307,0,30750.85684776306,0.9861940741539,0.0535595826804637,0.2857131922781328,43793,47274.11804151535,0.9955654740333556,0.0139470482245087,0.7784315567150579,0.9869745373725892,0.0502125211060047,0.294558602517844,43793 -16629.640988588333,4.777262449264526,30990.98055648804,95048,0,30990.98055648804,0.9861940741539,0.0535595826804637,0.2858839892161938,43793,47628.29944491386,0.9955480098724364,0.0140321403741836,0.7664942574340686,0.9869745373725892,0.0502125211060047,0.2945165792896103,43793 -16745.10850763321,4.822614192962647,31231.16896510124,95782,0,31231.16896510124,0.9861940741539,0.0535595826804637,0.2858748191382337,43793,47984.0223646164,0.9955334663391112,0.0140433926135301,0.7828414672130524,0.9869745373725892,0.0502125211060047,0.294550277044121,43793 -16866.686855316162,4.869102954864502,31471.37131333351,96521,0,31471.37131333351,0.9861940741539,0.0535595789551734,0.2857925861415153,43793,48345.87097764015,0.9955517649650574,0.0139227332547307,0.7761237102647274,0.9869745373725892,0.0502125173807144,0.2945430188510121,43793 -16980.036007642746,4.920137643814087,31711.512481689453,97252,0,31711.512481689453,0.9861940741539,0.0535595826804637,0.2856464496291582,43793,48699.4358150959,0.9955731630325316,0.0139761669561266,0.7844379211332362,0.9869743585586548,0.0502125211060047,0.2947258707045029,43793 -17093.668981790543,4.96671462059021,31951.521282196045,97989,0,31951.521282196045,0.9861940741539,0.0535595789551734,0.2857639257450596,43793,49053.14595079422,0.9955433011054992,0.0139860603958368,0.7620940939206702,0.9869745373725892,0.0502125211060047,0.2945639696575376,43793 -17209.38709139824,5.011778116226196,32191.674822330475,98734,0,32191.674822330475,0.9861940741539,0.0535595826804637,0.2856923450793034,43793,49409.0846850872,0.9955700635910034,0.0139950010925531,0.7774484806299122,0.9869745373725892,0.0502125211060047,0.2945178067125195,43793 -17324.055718421936,5.057526350021362,32431.794410705566,99472,0,32431.794410705566,0.9861940741539,0.0535595826804637,0.2858091196595225,43793,49763.94093894959,0.9954909086227416,0.0140851242467761,0.7770188283552418,0.9869743585586548,0.0502125211060047,0.2945293023106267,43793 -17445.024650096893,5.107828140258789,32671.83607316017,100212,0,32671.83607316017,0.9861940741539,0.0535595826804637,0.2857367479087632,43793,50125.02400946617,0.9955756664276124,0.0139365773648023,0.7749262622194402,0.9869743585586548,0.0502125211060047,0.2944989993784632,43793 -17559.29871249199,5.153444766998291,32911.79198694229,100950,0,32911.79198694229,0.9861940741539,0.0535595789551734,0.2858639553246325,43793,50479.322907447815,0.99558025598526,0.0139152826741337,0.7891355115019367,0.9869745373725892,0.0502125211060047,0.2946131160601966,43793 -17679.49751996994,5.205646276473999,33151.745717048645,101684,0,33151.745717048645,0.9861940741539,0.0535595789551734,0.2857212204469937,43793,50839.55113840103,0.9955013990402222,0.0140962926670908,0.7691005217213778,0.9869745373725892,0.0502125173807144,0.2944360238053917,43793 -17793.961818933487,5.252429485321045,33391.99066519737,102411,0,33391.99066519737,0.9861940741539,0.0535595826804637,0.2857138954593107,43793,51194.32950162888,0.9955734014511108,0.0139194196090102,0.7724062579147245,0.9869743585586548,0.0502125211060047,0.2945845273702825,43793 -17908.058881998062,5.303457260131836,33632.18942737579,103146,0,33632.18942737579,0.9861940741539,0.0535595826804637,0.2856564410760763,43793,51548.698831796646,0.9955319166183472,0.0140671050176024,0.7735482801034659,0.9869745373725892,0.0502125173807144,0.2945610976385742,43793 -18017.360858678818,5.349218845367432,33872.21159863472,103893,0,33872.21159863472,0.9861940741539,0.0535595826804637,0.2856985312715242,43793,51898.09083795548,0.9955902695655824,0.0138916801661252,0.7762755546021376,0.9869745373725892,0.0502125211060047,0.2945026049765142,43793 -18135.431131839752,5.395155668258667,34112.31607079506,104640,0,34112.31607079506,0.9861940741539,0.0535595826804637,0.2858397657044168,43793,52256.33351922035,0.9955270886421204,0.0139942103996872,0.7827216048150842,0.9869743585586548,0.0502125211060047,0.2945389773860817,43793 -18254.377835989,5.444571256637573,34352.42887854576,105384,0,34352.42887854576,0.9861940741539,0.0535595826804637,0.285678615368202,43793,52615.46451854706,0.9955507516860962,0.0140451323240995,0.7694572761821414,0.9869745373725892,0.0502125173807144,0.2945578506109468,43793 -18369.77789402008,5.496541738510132,34592.64105916023,106117,0,34592.64105916023,0.9861940741539,0.0535595826804637,0.2856916424811894,43793,52971.15380167961,0.995536208152771,0.0139791183173656,0.7786668022825228,0.9869743585586548,0.0502125173807144,0.2945464250325268,43793 -18486.26060438156,5.544113397598267,34832.609432935715,106851,0,34832.609432935715,0.9861940741539,0.0535595789551734,0.2856666989155942,43793,53327.67676925659,0.9955457448959352,0.0140038868412375,0.7663039849163852,0.9869745373725892,0.0502125173807144,0.2945365314046178,43793 -18598.39912390709,5.606086492538452,35072.64724302292,107596,0,35072.64724302292,0.9861940741539,0.0535595789551734,0.285848566255761,43793,53679.93759918213,0.9955405592918396,0.0139972921460866,0.7824891126363962,0.9869745373725892,0.0502125211060047,0.2945553946251877,43793 -18709.627541542053,5.654670476913452,35312.800602436066,108330,0,35312.800602436066,0.9861940741539,0.0535595826804637,0.2856731132602263,43793,54031.39209771156,0.9955700039863586,0.013952019624412,0.780402457785562,0.9869743585586548,0.0502125211060047,0.294552374588472,43793 -18819.788346767426,5.7019219398498535,35552.739119291306,109073,0,35552.739119291306,0.9861940741539,0.0535595789551734,0.2857106592316629,43793,54381.56074023247,0.9955811500549316,0.0138882901519536,0.7799591160022632,0.9869745373725892,0.0502125211060047,0.2944818855644533,43793 -18934.87672638893,5.749660015106201,35792.734656095505,109816,0,35792.734656095505,0.9861940741539,0.0535595826804637,0.2857457854232376,43793,54736.71452140808,0.9955028891563416,0.0141618028283119,0.7691131012284711,0.9869745373725892,0.0502125211060047,0.2945828852469386,43793 -19046.673018455505,5.798440217971802,36032.70607757568,110558,0,36032.70607757568,0.9861940741539,0.0535595826804637,0.2856821056185651,43793,55088.55318880081,0.995578408241272,0.0138934273272752,0.7705730961172595,0.9869743585586548,0.0502125173807144,0.294555251483837,43793 -19161.44490623474,5.846911668777466,36272.84723806381,111293,0,36272.84723806381,0.9861940741539,0.0535595789551734,0.2857235851747701,43793,55443.536256074905,0.9955031871795654,0.0141362929716706,0.7735100883177448,0.9869741201400756,0.0502125211060047,0.2944882404998152,43793 -19269.7726829052,5.894815921783447,36513.03610134125,112032,0,36513.03610134125,0.9861940741539,0.0535595826804637,0.2856990657290821,43793,55792.12286829949,0.9955770969390868,0.0138916047289967,0.7804736794685526,0.9869745373725892,0.0502125211060047,0.2945283101312911,43793 -19386.026316165924,5.942633867263794,36753.05175638199,112781,0,36753.05175638199,0.9861940741539,0.0535595789551734,0.2857793929004372,43793,56148.46195721626,0.99552184343338,0.0140519654378294,0.7811979990159192,0.9869745373725892,0.0502125211060047,0.2946104591043645,43793 -19503.940582990646,5.9964470863342285,36993.05020284653,113520,0,36993.05020284653,0.9861940741539,0.0535595826804637,0.2856983237546496,43793,56506.45134592056,0.9955678582191468,0.0139427185058593,0.7712682210840969,0.9869741201400756,0.0502125173807144,0.2946619508789923,43793 -19613.848816156387,6.050333261489868,37233.1000483036,114247,0,37233.1000483036,0.9861940741539,0.0535595789551734,0.2857388313861314,43793,56856.48850417137,0.9955666661262512,0.013914069160819,0.7792822577149576,0.9869745373725892,0.0502125211060047,0.2945101564142774,43793 -19727.10418367386,6.103281021118164,37473.09302973747,114989,0,37473.09302973747,0.9861940741539,0.0535595826804637,0.2858201532068772,43793,57209.81217265129,0.9955500960350036,0.0140851167961955,0.7752664714391566,0.9869743585586548,0.0502125173807144,0.2945460843140218,43793 -19841.389622211456,6.152727365493774,37713.152237176895,115723,0,37713.152237176895,0.9861940741539,0.0535595789551734,0.285702221607734,43793,57564.23022627831,0.9955061078071594,0.0140688307583332,0.7755653651667929,0.9869743585586548,0.0502125211060047,0.2945540190760538,43793 -19954.476145982742,6.206709146499634,37953.37311530113,116462,0,37953.37311530113,0.9861940741539,0.0535595826804637,0.2856991776358954,43793,57917.61430335045,0.9955783486366272,0.0138890761882066,0.7852735330342986,0.9869745373725892,0.0502125211060047,0.2944726678951031,43793 -20067.16177225113,6.261343479156494,38193.32923769951,117193,0,38193.32923769951,0.9861940741539,0.0535595789551734,0.2857992774557656,43793,58270.33309483528,0.995569348335266,0.0139385117217898,0.7685402304600731,0.9869745373725892,0.0502125211060047,0.2944315252732025,43793 -20179.09550833702,6.315385818481445,38433.5650575161,117935,0,38433.5650575161,0.9861940741539,0.0535595789551734,0.2858286062184439,43793,58622.57953858376,0.9955466389656068,0.0140097755938768,0.778482909552098,0.9869745373725892,0.0502125211060047,0.2945424307965914,43793 -20287.63827586174,6.3654186725616455,38673.81978392601,118679,0,38673.81978392601,0.9861940741539,0.0535595789551734,0.2856903053998522,43793,58971.44974565506,0.9955329895019532,0.0140114072710275,0.76781053698901,0.9869745373725892,0.0502125211060047,0.2944648560792302,43793 -20401.36284303665,6.414927959442139,38913.83612418175,119409,0,38913.83612418175,0.9861940741539,0.0535595826804637,0.285719341842257,43793,59325.26311969757,0.9955406188964844,0.013977606780827,0.7782823868286457,0.9869743585586548,0.0502125211060047,0.2945175274820147,43793 -20515.88113284111,6.468688011169434,39153.818912267685,120140,0,39153.818912267685,0.9861940741539,0.0535595826804637,0.2857392894248236,43793,59679.84038281441,0.9955416321754456,0.013997571542859,0.7779783333646086,0.9869745373725892,0.0502125211060047,0.2944633500683524,43793 -20630.25913882256,6.522387504577637,39394.06165885925,120873,0,39394.06165885925,0.9861940741539,0.0535595826804637,0.2858240013559314,43793,60034.53656172752,0.9955613017082214,0.0140098659321665,0.7866055579445876,0.9869743585586548,0.0502125173807144,0.2944610430847229,43793 -20740.979505062103,6.573132276535034,39634.28438973427,121614,0,39634.28438973427,0.9861940741539,0.0535595826804637,0.2857204561472292,43793,60385.552882909775,0.9955511689186096,0.0139482170343399,0.7662657186104772,0.9869745373725892,0.0502125211060047,0.2945873331954825,43793 -20858.42757821083,6.62934947013855,39874.48172211647,122354,0,39874.48172211647,0.9861940741539,0.0535595826804637,0.2858334033911579,43793,60743.27828383446,0.9955592155456544,0.0140119800344109,0.7753787037792779,0.9869745373725892,0.0502125211060047,0.2944693613267529,43793 -20966.82272863388,6.682506799697876,40114.56664967537,123095,0,40114.56664967537,0.9861940741539,0.0535595826804637,0.2857019667587944,43793,61091.83598303795,0.9955024123191832,0.0140598332509398,0.770876647929712,0.9869745373725892,0.0502125211060047,0.2944734527416136,43793 -21077.35178470612,6.739075183868408,40354.73410964012,123830,0,40354.73410964012,0.9861940741539,0.0535595789551734,0.2858282936163239,43793,61442.61308217049,0.9956189393997192,0.0138553949072957,0.7783691267641425,0.9869745373725892,0.0502125211060047,0.2944677065948684,43793 -21188.7815053463,6.7893126010894775,40594.8329679966,124560,0,40594.8329679966,0.9861940741539,0.0535595789551734,0.2859011175114634,43793,61794.2158946991,0.995537519454956,0.0140228625386953,0.7846410551431506,0.9869745373725892,0.0502125211060047,0.2946003063371421,43793 -21301.28376197815,6.841134786605835,40834.93838334084,125294,0,40834.93838334084,0.9861940741539,0.0535595826804637,0.2856668054003706,43793,62146.89764785767,0.9955279231071472,0.0140428263694047,0.7715511151186719,0.9869743585586548,0.0502125211060047,0.2945159621426331,43793 -21413.383462429047,6.894498348236084,41075.21634960175,126032,0,41075.21634960175,0.9861940741539,0.0535595789551734,0.2859947860418702,43793,62499.35304760933,0.9955497980117798,0.0139739634469151,0.7776480587757924,0.9869745373725892,0.0502125211060047,0.2945289512400195,43793 -21524.56535053253,6.948744535446167,41315.295952796936,126774,0,41315.295952796936,0.9861940741539,0.0535595826804637,0.2857411098029996,43793,62850.69105839729,0.9955505132675172,0.013991804793477,0.7683404890570733,0.9869741201400756,0.0502125173807144,0.2945988444665576,43793 -21632.11439609528,6.999794960021973,41555.32707571983,127512,0,41555.32707571983,0.9861940741539,0.0535595826804637,0.285670475255235,43793,63198.34479141235,0.995507538318634,0.014115703292191,0.780440909566623,0.9869745373725892,0.0502125173807144,0.294529530609065,43793 -21743.65156912804,7.051799297332764,41795.47059297562,128248,0,41795.47059297562,0.9861940741539,0.0535595826804637,0.2856927221154515,43793,63550.09992480278,0.9955496191978456,0.0139685394242405,0.7795067563109996,0.9869745373725892,0.0502125173807144,0.2945238855690372,43793 -21857.52256894112,7.103084325790405,42035.69354510307,128988,0,42035.69354510307,0.9861940741539,0.0535595826804637,0.2857917035503163,43793,63904.26799011231,0.9955980181694032,0.0138596631586551,0.78344074625687,0.9869745373725892,0.0502125211060047,0.2945503802955432,43793 -21971.176341056824,7.15567946434021,42275.78005862236,129726,0,42275.78005862236,0.9861940741539,0.0535595789551734,0.2856905130623254,43793,64258.08281850815,0.9955507516860962,0.0139851318672299,0.7687265903200935,0.9869745373725892,0.0502125211060047,0.2946306155417239,43793 -22082.26949906349,7.207721710205078,42515.92790675163,130455,0,42515.92790675163,0.9861940741539,0.0535595789551734,0.2857761141941721,43793,64609.39808940888,0.9955714344978333,0.0139600820839405,0.7693036488054488,0.9869745373725892,0.0502125173807144,0.2945816972273906,43793 -22194.663531780243,7.26179051399231,42755.94807124138,131190,0,42755.94807124138,0.9861940741539,0.0535595789551734,0.2857930735203836,43793,64961.888946056366,0.9954754114151,0.0141520276665687,0.7745251563418538,0.9869745373725892,0.0502125211060047,0.2944870852900064,43793 -22305.20131421089,7.314098358154297,42995.89313149452,131929,0,42995.89313149452,0.9861940741539,0.0535595826804637,0.2856906678947874,43793,65312.44583821297,0.9955991506576538,0.0138892102986574,0.7770598738446075,0.9869745373725892,0.0502125173807144,0.2946191638573021,43793 -22411.660974740986,7.368078708648682,43236.01311707497,132666,0,43236.01311707497,0.9861940741539,0.0535595789551734,0.285701665999011,43793,65659.10172605515,0.9955676794052124,0.0139704309403896,0.784020367065633,0.9869745373725892,0.0502125211060047,0.2944997865847881,43793 -22522.91188192368,7.421867370605469,43476.01622104645,133402,0,43476.01622104645,0.9861940741539,0.0535595826804637,0.2859014656423316,43793,66010.43513774872,0.9955294132232666,0.0139775266870856,0.7719156447251221,0.9869745373725892,0.0502125211060047,0.2944672347701114,43793 -22635.523801088333,7.478373527526855,43716.06605672836,134140,0,43716.06605672836,0.9861940741539,0.0535595789551734,0.2857146601910127,43793,66363.17952227592,0.995570719242096,0.0139369834214448,0.7698331831434866,0.9869743585586548,0.0502125211060047,0.2945582199786213,43793 -22746.583614349365,7.531560182571411,43956.02018260956,134871,0,43956.02018260956,0.9861940741539,0.0535595826804637,0.2858309724099989,43793,66714.26945352554,0.995523989200592,0.0140719972550868,0.772939178937591,0.9869743585586548,0.0502125211060047,0.2945825986045106,43793 -22854.359155654907,7.591469287872314,44196.05651593208,135606,0,44196.05651593208,0.9861940741539,0.0535595789551734,0.2857107305887505,43793,67062.164737463,0.9955481886863708,0.0139481537044048,0.7806334622776318,0.9869745373725892,0.0502125211060047,0.2945880907545196,43793 -22960.33766627312,7.645104646682739,44436.13486599922,136339,0,44436.13486599922,0.9861940741539,0.0535595826804637,0.2857003263771898,43793,67408.29759216309,0.995563507080078,0.0139814345166087,0.7787603347264325,0.9869745373725892,0.0502125211060047,0.2946117411358875,43793 -23069.12981247902,7.69776439666748,44676.229766607285,137078,0,44676.229766607285,0.9861940741539,0.0535595826804637,0.2856534864633366,43793,67757.26006889343,0.995528757572174,0.014043471775949,0.773235530751952,0.9869745373725892,0.0502125173807144,0.2945560652848558,43793 -23179.11955356598,7.754902839660644,44916.29040455818,137812,0,44916.29040455818,0.9861940741539,0.0535595826804637,0.2856929229490918,43793,68107.390021801,0.9955397248268129,0.0140498215332627,0.7716928762320991,0.9869743585586548,0.0502125211060047,0.2945598515327954,43793 -23287.19401454925,7.812101364135742,45156.33108019829,138549,0,45156.33108019829,0.9861940741539,0.0535595789551734,0.285655050187629,43793,68455.58481907845,0.9955530166625975,0.0139563847333192,0.7666460856938745,0.9869745373725892,0.0502125211060047,0.2945474761017078,43793 -23396.60202598572,7.865617513656616,45396.46606326103,139280,0,45396.46606326103,0.9861940741539,0.0535595789551734,0.2856873922115879,43793,68805.20327782631,0.9955349564552308,0.0139978528022766,0.7760984157232222,0.9869743585586548,0.0502125211060047,0.294489289603176,43793 -23505.35227966309,7.920868873596191,45636.407398462296,140024,0,45636.407398462296,0.9861940741539,0.0535595789551734,0.2858944262889938,43793,69153.9723265171,0.9955670833587646,0.0139693990349769,0.7815567202019054,0.9869745373725892,0.0502125211060047,0.2945874760881046,43793 -23615.431432724,7.97593355178833,45876.594853162766,140767,0,45876.594853162766,0.9861940741539,0.0535595826804637,0.2856803090264078,43793,69504.31644678116,0.995562732219696,0.0139715448021888,0.7835845042765484,0.9869741201400756,0.0502125211060047,0.2946952873297384,43793 -23724.393659591675,8.031497716903687,46116.76946258545,141514,0,46116.76946258545,0.9861940741539,0.0535595826804637,0.2857075935707077,43793,69853.53129506111,0.9955607056617736,0.0139527916908264,0.7693302481492503,0.9869745373725892,0.0502125211060047,0.2945097398624374,43793 -23836.30408072472,8.086358785629272,46356.98973464966,142256,0,46356.98973464966,0.9861940741539,0.0535595826804637,0.2857075416590519,43793,70205.73898673058,0.995522677898407,0.0140734603628516,0.7720646022590836,0.9869743585586548,0.0502125211060047,0.2945275033973265,43793 -23943.26594948769,8.141406297683716,46597.234624147415,142996,0,46597.234624147415,0.9861940741539,0.0535595789551734,0.2857092445829384,43793,70553.02288079262,0.9955246448516846,0.0140249570831656,0.7742682707114532,0.9869741201400756,0.0502125173807144,0.2945714240781431,43793 -24051.82504749298,8.196717262268066,46837.26299023628,143740,0,46837.26299023628,0.9861940741539,0.0535595826804637,0.285687980017125,43793,70901.68744587898,0.995592713356018,0.0139059126377105,0.7787287733952317,0.9869745373725892,0.0502125173807144,0.2947369056523955,43793 -24158.07507967949,8.25191354751587,47077.19783067703,144487,0,47077.19783067703,0.9861940741539,0.0535595789551734,0.28566755639923,43793,71247.94971942902,0.995563507080078,0.0140124568715691,0.7812311607734994,0.9869745373725892,0.0502125211060047,0.2945227253288941,43793 -24266.130458831787,8.306992053985596,47317.14991569519,145225,0,47317.14991569519,0.9861940741539,0.0535595826804637,0.285769349252281,43793,71596.03808999062,0.995533287525177,0.0139496708288788,0.7758802733144703,0.9869745373725892,0.0502125211060047,0.2944865694920178,43793 -24376.586597919464,8.366672277450562,47557.18814301491,145966,0,47557.18814301491,0.9861940741539,0.0535595826804637,0.2857172927144193,43793,71946.61416625977,0.995587170124054,0.0138977067545056,0.773782049719947,0.9869743585586548,0.0502125211060047,0.2946646276835428,43793 -24486.95604276657,8.422574520111084,47797.23997068405,146715,0,47797.23997068405,0.9861940741539,0.0535595826804637,0.2857200951461492,43793,72297.11353874207,0.9954803586006165,0.0141344061121344,0.7739438319568628,0.9869743585586548,0.0502125211060047,0.2945254938863308,43793 -24593.60262060165,8.478218793869019,48037.220688819885,147461,0,48037.220688819885,0.9861940741539,0.0535595789551734,0.2857585033554984,43793,72643.81870794296,0.9955896735191344,0.0138777727261185,0.7762476049009219,0.9869745373725892,0.0502125173807144,0.2944843963867238,43793 -24698.474319934845,8.533739566802979,48277.27271056175,148195,0,48277.27271056175,0.9861940741539,0.0535595826804637,0.2857212119481352,43793,72988.82006072998,0.9955464005470276,0.014040638692677,0.7835602272200127,0.9869745373725892,0.0502125173807144,0.294611981843638,43793 -24811.260496854786,8.589645624160767,48517.27303361893,148936,0,48517.27303361893,0.9861940741539,0.0535595789551734,0.2856568886471392,43793,73341.68425559998,0.9955344200134276,0.0140588777139782,0.7736069100158747,0.9869743585586548,0.0502125173807144,0.2945416036507968,43793 -24922.92630887032,8.645307302474976,48757.350826501846,149679,0,48757.350826501846,0.9861940741539,0.0535595789551734,0.2857804305900344,43793,73693.50581741333,0.995566189289093,0.0139274382963776,0.7797421371283534,0.9869743585586548,0.0502125211060047,0.2947302321803623,43793 -25033.61894416809,8.710498809814453,48997.344621658325,150424,0,48997.344621658325,0.9861940741539,0.0535595826804637,0.2856523374013939,43793,74044.27948951721,0.9955343008041382,0.0140413139015436,0.7644432246613764,0.9869745373725892,0.0502125211060047,0.2945536299897155,43793 -25143.97053194046,9.098667860031128,49236.9952814579,151164,0,49236.9952814579,0.9861940741539,0.0535595789551734,0.2857042056457721,43793,74394.6922750473,0.995585322380066,0.0138838086277246,0.7787436689713627,0.9869743585586548,0.0502125173807144,0.2945482853631458,43793 -25257.198073625565,9.155184268951416,49477.232617139816,151907,0,49477.232617139816,0.9861940741539,0.0535595826804637,0.2859263697756146,43793,74748.23588776588,0.9954984784126282,0.0141248051077127,0.7813303133972982,0.9869745373725892,0.0502125211060047,0.2945451697094664,43793 -25363.61196255684,9.216427087783812,49717.24135637283,152649,0,49717.24135637283,0.9861940741539,0.0535595789551734,0.2857255714175824,43793,75094.74559378624,0.99554842710495,0.0139965619891881,0.7735667334688098,0.9869745373725892,0.0502125211060047,0.2945991734176778,43793 -25472.193568706512,9.272395849227903,49957.17316937447,153386,0,49957.17316937447,0.9861940741539,0.0535595789551734,0.2856822054806806,43793,75443.33772158623,0.9955598711967468,0.0139503329992294,0.7727688636510217,0.9869743585586548,0.0502125211060047,0.2945004272365443,43793 -25581.23864412308,9.328509330749512,50197.12275767326,154125,0,50197.12275767326,0.9861940741539,0.0535595789551734,0.2857150265415266,43793,75792.4113547802,0.99556964635849,0.0139679256826639,0.7743029819107554,0.9869743585586548,0.0502125173807144,0.2944922403376407,43793 -25691.170045137405,9.385692358016968,50437.306334257126,154859,0,50437.306334257126,0.9861940741539,0.0535595789551734,0.2857401568499358,43793,76142.60612154007,0.995524525642395,0.0140228336676955,0.7734727200254108,0.9869745373725892,0.0502125211060047,0.2945479611746832,43793 -25799.21874904633,9.444878578186035,50677.31756472588,155601,0,50677.31756472588,0.9861940741539,0.0535595789551734,0.285683389303403,43793,76490.74878931046,0.9955626726150512,0.0139491278678178,0.7778168183954786,0.9869745373725892,0.0502125173807144,0.2944723845323849,43793 -25907.11387944221,9.500680446624756,50917.56117296219,156336,0,50917.56117296219,0.9861940741539,0.0535595789551734,0.2861027977699629,43793,76838.96542525291,0.995536208152771,0.014027833007276,0.7825999920584917,0.9869745373725892,0.0502125173807144,0.2946293654606327,43793 -26011.510071754456,9.56079602241516,51157.53374195099,157076,0,51157.53374195099,0.9861940741539,0.0535595826804637,0.2856674154938795,43793,77183.41665196419,0.9955697655677797,0.0139607917517423,0.7700218074027575,0.9869745373725892,0.0502125211060047,0.2946568657246006,43793 -26120.800348758698,9.617570638656616,51397.612380981445,157817,0,51397.612380981445,0.9861940741539,0.0535595826804637,0.2856970875343083,43793,77532.86435222626,0.9955599308013916,0.0139398034662008,0.7757960907398831,0.9869745373725892,0.0502125211060047,0.2945389451791146,43793 -26224.05740213394,9.681266069412231,51637.58630156517,158556,0,51637.58630156517,0.9861940741539,0.0535595789551734,0.2858380709121086,43793,77876.18381977081,0.9955105185508728,0.0140839051455259,0.7653077269785241,0.9869745373725892,0.0502125211060047,0.2946983592080036,43793 -26337.106037139893,9.740336656570436,51877.640649318695,159302,0,51877.640649318695,0.9861940741539,0.0535595789551734,0.2857684194818551,43793,78229.36774969101,0.995564877986908,0.0139866154640913,0.7803369469318409,0.9869745373725892,0.0502125211060047,0.2945265167472148,43793 -26442.78346991539,9.7996826171875,52117.86292815208,160043,0,52117.86292815208,0.9861940741539,0.0535595826804637,0.2856528216248714,43793,78575.34875035286,0.995552122592926,0.0139129171147942,0.7836719891879151,0.9869745373725892,0.0502125211060047,0.2946214386872627,43793 -26554.365653038025,9.857231855392456,52358.04911708832,160786,0,52358.04911708832,0.9861940741539,0.0535595789551734,0.2857982278245645,43793,78927.19647717476,0.9955381751060486,0.0140344044193625,0.7719031206961338,0.9869745373725892,0.0502125211060047,0.2946162831197981,43793 -26658.87480187416,9.923945665359495,52598.25149941445,161514,0,52598.25149941445,0.9861940741539,0.0535595826804637,0.2857322342018337,43793,79272.00039196014,0.9955840706825256,0.0139219872653484,0.7775919723837796,0.9869743585586548,0.0502125211060047,0.2946046044145738,43793 -26769.24395275116,9.981608629226685,52838.20043492317,162258,0,52838.20043492317,0.9861940741539,0.0535595826804637,0.2857639647881098,43793,79622.39813017845,0.99552983045578,0.0140566751360893,0.7652208045052136,0.9869745373725892,0.0502125173807144,0.2945383467173924,43793 -26875.5023086071,10.03822422027588,53078.12421751022,163008,0,53078.12421751022,0.9861940741539,0.0535595789551734,0.285866172733904,43793,79968.65912795067,0.995524287223816,0.0140065401792526,0.7796286507900729,0.9869741201400756,0.0502125173807144,0.2945911879708596,43793 -26981.566202640533,10.09718656539917,53318.33121609688,163755,0,53318.33121609688,0.9861940741539,0.0535595789551734,0.2857193972505057,43793,80315.01131868362,0.9955673217773438,0.0139431608840823,0.7821653090591817,0.9869745373725892,0.0502125173807144,0.2944987559590682,43793 -27083.28986978531,10.157109260559082,53558.39138150215,164493,0,53558.39138150215,0.9861940741539,0.0535595789551734,0.2856440025829405,43793,80656.87715959549,0.9955205917358398,0.0140897808596491,0.7764114194363383,0.9869745373725892,0.0502125211060047,0.2945493493391814,43793 -27186.615166664124,10.215405464172363,53798.36969470978,165239,0,53798.36969470978,0.9861940741539,0.0535595789551734,0.2857893411828576,43793,81000.26069259644,0.9955911040306092,0.0139289153739809,0.7734966359252158,0.9869745373725892,0.0502125211060047,0.2944849121350764,43793 -27296.31303691864,10.27896499633789,54038.34677696228,165985,0,54038.34677696228,0.9861940741539,0.0535595826804637,0.2857110507380125,43793,81350.02177357674,0.9955465197563172,0.0139400130137801,0.7729779314410089,0.9869743585586548,0.0502125211060047,0.2944748477438155,43793 -27406.18231940269,10.336283922195436,54278.52591109276,166730,0,54278.52591109276,0.9861940741539,0.0535595789551734,0.2857211817064314,43793,81700.15067434311,0.995519995689392,0.0140833482146263,0.7805190367217666,0.9869743585586548,0.0502125211060047,0.294567484941141,43793 -27515.27748608589,10.39532470703125,54518.58748054504,167476,0,54518.58748054504,0.9861940741539,0.0535595826804637,0.2857042919315566,43793,82049.38845849037,0.9955546855926514,0.0139647619798779,0.7759126234046201,0.9869745373725892,0.0502125211060047,0.2945054412881348,43793 -27621.426381587986,10.46361517906189,54758.71322131157,168219,0,54758.71322131157,0.9861940741539,0.0535595826804637,0.2856700453934852,43793,82395.75434350967,0.9955841302871704,0.0139201804995536,0.7750355602140353,0.9869743585586548,0.0502125173807144,0.2946438301036884,43793 -27727.70783424377,10.523662567138672,54998.81431245804,168963,0,54998.81431245804,0.9861940741539,0.0535595789551734,0.2857018252869073,43793,82742.2187845707,0.9955199360847472,0.0140633136034011,0.7707108704904337,0.9869745373725892,0.0502125211060047,0.2944998386340789,43793 -27830.284168481827,10.582506656646729,55238.7580242157,169708,0,55238.7580242157,0.9861940741539001,0.05355958268046379,0.285909555773725,43793,83084.81942462921,0.9955832958221436,0.013892631977796555,0.7774776130313928,0.9869745373725891,0.050212521106004715,0.2945532663143946,43793 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index 051b21791..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1937 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.9824046,0.76625013,,,,,,,,,,,,,,,,, -1,,,0.4642853140830993,0.7602338194847107,0.0228730991783322,0.4643697440624237,0.7589086294174194,0.0274652618143725,43793.0,0.4646042585372925,0.7593051791191101,0.0285717469569702,43793.0,19.403101921081543,554.6187665462494,19.403101921081543,535.2156233787537,0.0,0.0 -100,0.314991,0.28917146,,,,,,,,,,,,,,,,, -200,0.09732901,0.10826518,,,,,,,,,,,,,,,,, -300,0.032088835,0.06643194,,,,,,,,,,,,,,,,, -400,0.016926246,0.058216877,,,,,,,,,,,,,,,,, -500,0.012534139,0.057437506,,,,,,,,,,,,,,,,, -600,0.021659086,0.050116565,,,,,,,,,,,,,,,,, -700,0.02155055,0.052311048,,,,,,,,,,,,,,,,, -742,,,0.9867706298828124,0.0514782257378101,0.0528244412026438,0.9842023849487304,0.0607735253870487,0.0503369997300077,43793.0,0.9831854701042176,0.0639450103044509,0.0520615109430797,43793.0,259.46107959747314,914.3788795471193,259.46107959747314,654.8684275150299,0.027963638305664,0.0 -800,0.015347271,0.048240956,,,,,,,,,,,,,,,,, -900,0.02186266,0.05204561,,,,,,,,,,,,,,,,, -1000,0.012887548,0.04713726,,,,,,,,,,,,,,,,, -1100,0.028286481,0.04562433,,,,,,,,,,,,,,,,, -1200,0.025978586,0.042253368,,,,,,,,,,,,,,,,, -1300,0.014359783,0.04767495,,,,,,,,,,,,,,,,, -1400,0.02726626,0.05373914,,,,,,,,,,,,,,,,, -1482,,,0.9872096180915833,0.0489099249243736,0.0752923644838142,0.9844236373901368,0.0588754415512084,0.0671462059000805,43793.0,0.9834840893745422,0.0620399415493011,0.0689065527510964,43793.0,499.4656026363373,1276.9522080421448,499.4656026363373,777.3883748054504,0.0557601451873779,0.0 -1500,0.023898255,0.05198053,,,,,,,,,,,,,,,,, -1600,0.012898705,0.051150642,,,,,,,,,,,,,,,,, -1700,0.017734138,0.049501028,,,,,,,,,,,,,,,,, -1800,0.015510704,0.05032359,,,,,,,,,,,,,,,,, -1900,0.010517042,0.04347443,,,,,,,,,,,,,,,,, -2000,0.029933944,0.043224074,,,,,,,,,,,,,,,,, -2100,0.0155424345,0.036983766,,,,,,,,,,,,,,,,, -2200,0.010590265,0.041436207,,,,,,,,,,,,,,,,, -2224,,,0.9875986576080322,0.0446453876793384,0.1208860737385289,0.9848376512527466,0.0538689643144607,0.1229580905647887,43793.0,0.9838589429855348,0.056995864957571,0.1230535037550857,43793.0,739.4543447494507,1636.9871473312378,739.4543447494507,897.3871071338654,0.0821943283081054,0.0 -2300,0.011561129,0.046836767,,,,,,,,,,,,,,,,, -2400,0.016919097,0.041564293,,,,,,,,,,,,,,,,, -2500,0.010673466,0.04299379,,,,,,,,,,,,,,,,, -2600,0.025255956,0.044977132,,,,,,,,,,,,,,,,, -2700,0.011172421,0.042437527,,,,,,,,,,,,,,,,, -2800,0.010908821,0.04164363,,,,,,,,,,,,,,,,, -2900,0.018023461,0.044366166,,,,,,,,,,,,,,,,, -2973,,,0.987941324710846,0.0425007417798042,0.1518332776698262,0.9851689338684082,0.0516148544847965,0.1471634693879306,43793.0,0.9842245578765868,0.0544905923306942,0.1464378353277409,43793.0,979.6255490779876,2001.8835163116453,979.6255490779876,1022.0600764751434,0.1128618717193603,0.0 -3000,0.008985239,0.044962667,,,,,,,,,,,,,,,,, -3100,0.01268231,0.040760636,,,,,,,,,,,,,,,,, -3200,0.015028381,0.043703824,,,,,,,,,,,,,,,,, -3300,0.01633749,0.040187955,,,,,,,,,,,,,,,,, -3400,0.009634821,0.040251017,,,,,,,,,,,,,,,,, -3500,0.014798002,0.037203286,,,,,,,,,,,,,,,,, -3600,0.013163083,0.036291104,,,,,,,,,,,,,,,,, -3700,0.01273175,0.04323322,,,,,,,,,,,,,,,,, -3713,,,0.9883215427398682,0.0405497141182422,0.1873888127236984,0.9853605031967164,0.0498643070459365,0.1723176411435333,43793.0,0.9843947291374208,0.0526860170066356,0.1691754874100842,43793.0,1219.710813999176,2368.356336593628,1219.710813999176,1148.3995158672333,0.1392774581909179,0.0 -3800,0.013368669,0.04459127,,,,,,,,,,,,,,,,, -3900,0.014724942,0.045219902,,,,,,,,,,,,,,,,, -4000,0.010295332,0.038357828,,,,,,,,,,,,,,,,, -4100,0.011130561,0.04116841,,,,,,,,,,,,,,,,, -4200,0.011083492,0.039428987,,,,,,,,,,,,,,,,, -4300,0.012482562,0.039008163,,,,,,,,,,,,,,,,, -4400,0.017148841,0.035190806,,,,,,,,,,,,,,,,, -4454,,,0.9883311986923218,0.0401252694427967,0.2039251994853367,0.9854956865310668,0.0498076863586902,0.1814416540684904,43793.0,0.9845686554908752,0.0527504198253154,0.1856167206647187,43793.0,1459.977115869522,2738.5078999996185,1459.977115869522,1278.234468460083,0.168677806854248,0.0 -4500,0.00966679,0.040161252,,,,,,,,,,,,,,,,, -4600,0.014534406,0.04551229,,,,,,,,,,,,,,,,, -4700,0.019310797,0.041925136,,,,,,,,,,,,,,,,, -4800,0.010306585,0.040630046,,,,,,,,,,,,,,,,, -4900,0.012871132,0.03860388,,,,,,,,,,,,,,,,, -5000,0.011937917,0.035215832,,,,,,,,,,,,,,,,, -5100,0.012910268,0.04125928,,,,,,,,,,,,,,,,, -5185,,,0.9888564944267272,0.0382429473102092,0.2338408167710781,0.9857197999954224,0.0480681993067264,0.1933692347295754,43793.0,0.9848546385765076,0.0508074052631855,0.196022459140312,43793.0,1700.0847454071045,3111.8785684108734,1700.0847454071045,1411.4457199573517,0.1960334777832031,0.0 -5200,0.011062911,0.03689353,,,,,,,,,,,,,,,,, -5300,0.012535137,0.03858404,,,,,,,,,,,,,,,,, -5400,0.015708204,0.037724264,,,,,,,,,,,,,,,,, -5500,0.021110745,0.03866159,,,,,,,,,,,,,,,,, -5600,0.016147273,0.03665005,,,,,,,,,,,,,,,,, -5700,0.03485164,0.04346716,,,,,,,,,,,,,,,,, -5800,0.0137869995,0.036541175,,,,,,,,,,,,,,,,, -5900,0.013511454,0.035490774,,,,,,,,,,,,,,,,, -5918,,,0.989027202129364,0.0372817702591419,0.2582431101463912,0.9858922958374025,0.0481997281312942,0.2058692487304153,43793.0,0.9849650263786316,0.0509615391492843,0.2083791497336055,43793.0,1940.293526172638,3481.6874401569366,1940.293526172638,1540.9889826774595,0.2287201881408691,0.0 -6000,0.019328661,0.04353849,,,,,,,,,,,,,,,,, -6100,0.020046387,0.0395058,,,,,,,,,,,,,,,,, -6200,0.012970162,0.040356,,,,,,,,,,,,,,,,, -6300,0.0127830915,0.035103623,,,,,,,,,,,,,,,,, -6400,0.01223635,0.041413825,,,,,,,,,,,,,,,,, -6500,0.025366578,0.041214265,,,,,,,,,,,,,,,,, -6600,0.013189486,0.034861308,,,,,,,,,,,,,,,,, -6652,,,0.9893152713775636,0.0364081189036369,0.2626560365239537,0.9861143231391908,0.0463856682181358,0.2199867776480811,43793.0,0.9852867722511292,0.0489683225750923,0.2158038268663492,43793.0,2180.472153902054,3849.8858761787415,2180.472153902054,1668.957560300827,0.2584164142608642,0.0 -6700,0.010470161,0.036483426,,,,,,,,,,,,,,,,, -6800,0.01746526,0.035921216,,,,,,,,,,,,,,,,, -6900,0.018095978,0.0366991,,,,,,,,,,,,,,,,, -7000,0.017581522,0.041337986,,,,,,,,,,,,,,,,, -7100,0.014017883,0.036849435,,,,,,,,,,,,,,,,, -7200,0.014512017,0.035939325,,,,,,,,,,,,,,,,, -7300,0.019853914,0.039982688,,,,,,,,,,,,,,,,, -7394,,,0.9894617795944214,0.0357305519282817,0.2777985762334495,0.9862133860588074,0.0458943471312522,0.2205825816389293,43793.0,0.9853141903877258,0.0483731143176555,0.2241192075235631,43793.0,2420.4435505867004,4224.092652320862,2420.4435505867004,1803.142256975174,0.2857906818389892,0.0 -7400,0.015656812,0.035865225,,,,,,,,,,,,,,,,, -7500,0.020543603,0.03493536,,,,,,,,,,,,,,,,, -7600,0.017970582,0.035785634,,,,,,,,,,,,,,,,, -7700,0.018057274,0.03736016,,,,,,,,,,,,,,,,, -7800,0.018587802,0.038341563,,,,,,,,,,,,,,,,, -7900,0.014776029,0.034707725,,,,,,,,,,,,,,,,, -8000,0.014350379,0.032485157,,,,,,,,,,,,,,,,, -8100,0.014369988,0.033744264,,,,,,,,,,,,,,,,, -8133,,,0.9896026849746704,0.0350330807268619,0.3017076079581593,0.9864204525947572,0.0454067513346672,0.2327738536063612,43793.0,0.9854860305786132,0.0479992181062698,0.2323910319049115,43793.0,2660.463634490967,4594.116451025009,2660.463634490967,1933.096392393112,0.3145980834960937,0.0 -8200,0.012943296,0.037271094,,,,,,,,,,,,,,,,, -8300,0.025249239,0.034990456,,,,,,,,,,,,,,,,, -8400,0.014626197,0.03495389,,,,,,,,,,,,,,,,, -8500,0.017600097,0.03606479,,,,,,,,,,,,,,,,, -8600,0.01652015,0.032171216,,,,,,,,,,,,,,,,, -8700,0.016068758,0.035651088,,,,,,,,,,,,,,,,, -8800,0.01532392,0.035838563,,,,,,,,,,,,,,,,, -8876,,,0.9896164536476136,0.0348296910524368,0.3129953349495222,0.9864743947982788,0.045508861541748,0.2362024521907661,43793.0,0.9855205416679382,0.0482263714075088,0.2397758718441605,43793.0,2900.5610976219177,4963.792473316193,2900.5610976219177,2062.624636888504,0.3434135913848877,0.0 -8900,0.020609956,0.033472497,,,,,,,,,,,,,,,,, -9000,0.017557234,0.03114367,,,,,,,,,,,,,,,,, -9100,0.033647057,0.039769202,,,,,,,,,,,,,,,,, -9200,0.017032852,0.035406843,,,,,,,,,,,,,,,,, -9300,0.026481971,0.040862627,,,,,,,,,,,,,,,,, -9400,0.030525977,0.029125383,,,,,,,,,,,,,,,,, -9500,0.031758327,0.041319944,,,,,,,,,,,,,,,,, -9600,0.027349066,0.033441387,,,,,,,,,,,,,,,,, -9619,,,0.9899412989616394,0.0336496606469154,0.3282852467346069,0.9865624904632568,0.0450242049992084,0.2426041542082993,43793.0,0.9856747388839722,0.047865305095911,0.2418852485105137,43793.0,3140.5976366996765,5332.80707859993,3140.5976366996765,2191.552747964859,0.3722200393676758,0.0 -9700,0.023683896,0.03383169,,,,,,,,,,,,,,,,, -9800,0.017580666,0.035080973,,,,,,,,,,,,,,,,, -9900,0.018639568,0.035665143,,,,,,,,,,,,,,,,, -10000,0.016881432,0.03648394,,,,,,,,,,,,,,,,, -10100,0.017986204,0.035407882,,,,,,,,,,,,,,,,, -10200,0.018479919,0.03232595,,,,,,,,,,,,,,,,, -10300,0.023706663,0.033185154,,,,,,,,,,,,,,,,, -10355,,,0.9902228116989136,0.0326662361621856,0.3493272232299749,0.9866660237312316,0.0448374822735786,0.2521623679673248,43793.0,0.9856966137886048,0.0476522482931613,0.2460464273669188,43793.0,3380.571000099182,5702.102422714233,3380.571000099182,2320.8233363628387,0.4018385410308838,0.0 -10400,0.020892605,0.03686126,,,,,,,,,,,,,,,,, -10500,0.018660987,0.034339637,,,,,,,,,,,,,,,,, -10600,0.017984737,0.034905374,,,,,,,,,,,,,,,,, -10700,0.021278916,0.03382434,,,,,,,,,,,,,,,,, -10800,0.023522539,0.033841785,,,,,,,,,,,,,,,,, -10900,0.021958092,0.039004847,,,,,,,,,,,,,,,,, -11000,0.022021849,0.037274797,,,,,,,,,,,,,,,,, -11097,,,0.9903243780136108,0.0321387909352779,0.3718836815702409,0.98664653301239,0.0445500202476978,0.2481188682180836,43793.0,0.9858115911483764,0.0472144670784473,0.254014006212062,43793.0,3620.668796777725,6075.597155094147,3620.668796777725,2454.1698989868164,0.4301824569702148,0.0 -11100,0.01745833,0.030233232,,,,,,,,,,,,,,,,, -11200,0.032018427,0.03541099,,,,,,,,,,,,,,,,, -11300,0.026022043,0.034896087,,,,,,,,,,,,,,,,, -11400,0.02567056,0.03482175,,,,,,,,,,,,,,,,, -11500,0.021421582,0.033294063,,,,,,,,,,,,,,,,, -11600,0.028040502,0.034257963,,,,,,,,,,,,,,,,, -11700,0.021281661,0.03232843,,,,,,,,,,,,,,,,, -11800,0.026033035,0.035284914,,,,,,,,,,,,,,,,, -11843,,,0.9905675053596495,0.0313689559698104,0.3800155912255726,0.9866924285888672,0.0444240570068359,0.2527786570240431,43793.0,0.9857800006866456,0.0471613258123397,0.2472116040143474,43793.0,3860.690770864487,6445.938730955124,3860.690770864487,2584.4375660419464,0.4595632553100586,0.0 -11900,0.0246334,0.033661813,,,,,,,,,,,,,,,,, -12000,0.026391871,0.03488483,,,,,,,,,,,,,,,,, -12100,0.023475861,0.03317641,,,,,,,,,,,,,,,,, -12200,0.029869221,0.03435873,,,,,,,,,,,,,,,,, -12300,0.042298418,0.033087987,,,,,,,,,,,,,,,,, -12400,0.02202666,0.032678973,,,,,,,,,,,,,,,,, -12500,0.02370563,0.033706814,,,,,,,,,,,,,,,,, -12581,,,0.9902974367141724,0.0321634970605373,0.356260952255887,0.9867252707481384,0.044447436928749,0.2620155015599733,43793.0,0.9857972860336304,0.0474021434783935,0.2524061717930836,43793.0,4100.80003118515,6817.867141008377,4100.80003118515,2716.204756736756,0.4894959926605224,0.0 -12600,0.03643515,0.032585327,,,,,,,,,,,,,,,,, -12700,0.024072401,0.033188377,,,,,,,,,,,,,,,,, -12800,0.02619018,0.030587304,,,,,,,,,,,,,,,,, -12900,0.027186964,0.03195565,,,,,,,,,,,,,,,,, -13000,0.025895096,0.033178218,,,,,,,,,,,,,,,,, -13100,0.025311729,0.0348603,,,,,,,,,,,,,,,,, -13200,0.025248144,0.03226148,,,,,,,,,,,,,,,,, -13300,0.028357496,0.03247564,,,,,,,,,,,,,,,,, -13326,,,0.9902656674385072,0.0322396829724311,0.3608134814237923,0.986819863319397,0.0445113666355609,0.2681383014152059,43793.0,0.9858890771865844,0.0476423725485801,0.2547764526425899,43793.0,4340.828438043594,7183.03778886795,4340.828438043594,2841.2963259220123,0.5184030532836914,0.0 -13400,0.031646356,0.031136174,,,,,,,,,,,,,,,,, -13500,0.03571628,0.035353497,,,,,,,,,,,,,,,,, -13600,0.029671347,0.0326856,,,,,,,,,,,,,,,,, -13700,0.031119976,0.031787977,,,,,,,,,,,,,,,,, -13800,0.030016052,0.031743284,,,,,,,,,,,,,,,,, -13900,0.027742557,0.031695455,,,,,,,,,,,,,,,,, -14000,0.030141106,0.031120544,,,,,,,,,,,,,,,,, -14067,,,0.9905223846435548,0.031429897993803,0.3760678922408069,0.9869217872619628,0.043913647532463,0.2748815194390102,43793.0,0.9859737753868104,0.0468311235308647,0.2589879758092704,43793.0,4580.7753620147705,7553.162268877029,4580.7753620147705,2971.42048573494,0.5491487979888916,0.0 -14100,0.0352457,0.030099599,,,,,,,,,,,,,,,,, -14200,0.03690526,0.036393292,,,,,,,,,,,,,,,,, -14300,0.043144003,0.03258868,,,,,,,,,,,,,,,,, -14400,0.032358956,0.030420408,,,,,,,,,,,,,,,,, -14500,0.042465158,0.034356248,,,,,,,,,,,,,,,,, -14600,0.03407399,0.03187914,,,,,,,,,,,,,,,,, -14700,0.030499784,0.031559438,,,,,,,,,,,,,,,,, -14798,,,0.9905580878257751,0.0311930775642395,0.3902971635582445,0.9866790175437928,0.0446098558604717,0.2687362654336637,43793.0,0.9856927990913392,0.0475738048553466,0.2555561248051884,43793.0,4820.819413661957,7923.672146081924,4820.819413661957,3101.82884144783,0.5812263488769531,0.0 -14800,0.032494258,0.03312343,,,,,,,,,,,,,,,,, -14900,0.04492725,0.03269453,,,,,,,,,,,,,,,,, -15000,0.028313337,0.032182306,,,,,,,,,,,,,,,,, -15100,0.03251724,0.033554703,,,,,,,,,,,,,,,,, -15200,0.04019745,0.0319651,,,,,,,,,,,,,,,,, -15300,0.046386108,0.034151703,,,,,,,,,,,,,,,,, -15400,0.041479774,0.036342863,,,,,,,,,,,,,,,,, -15500,0.03338425,0.030652992,,,,,,,,,,,,,,,,, -15524,,,0.9906007647514344,0.0306931287050247,0.4037977692179422,0.9867488145828248,0.0442212112247943,0.2687962872441771,43793.0,0.985888659954071,0.0469398722052574,0.2664834005963081,43793.0,5060.941811084747,8293.455752372742,5060.941811084747,3231.4354071617126,0.6112728118896484,0.0 -15600,0.0326996,0.033550467,,,,,,,,,,,,,,,,, -15700,0.036895458,0.037474472,,,,,,,,,,,,,,,,, -15800,0.035633877,0.03392742,,,,,,,,,,,,,,,,, -15900,0.03482153,0.03385452,,,,,,,,,,,,,,,,, -16000,0.032284558,0.030964771,,,,,,,,,,,,,,,,, -16100,0.038209118,0.030665172,,,,,,,,,,,,,,,,, -16200,0.04712654,0.03168236,,,,,,,,,,,,,,,,, -16263,,,0.9907249808311462,0.030366413295269,0.4092169985073081,0.9867720007896424,0.0439391694962978,0.268025885787124,43793.0,0.9859308004379272,0.0465576462447643,0.2644002246557672,43793.0,5301.218809843063,8663.80548620224,5301.218809843063,3361.455674648285,0.6414785385131836,0.0 -16300,0.03346496,0.030345434,,,,,,,,,,,,,,,,, -16400,0.04431667,0.03226805,,,,,,,,,,,,,,,,, -16500,0.037316415,0.028523566,,,,,,,,,,,,,,,,, -16600,0.037283614,0.033630256,,,,,,,,,,,,,,,,, -16700,0.046272006,0.032518983,,,,,,,,,,,,,,,,, -16800,0.030467918,0.032113694,,,,,,,,,,,,,,,,, -16900,0.037954517,0.03103677,,,,,,,,,,,,,,,,, -16999,,,0.9909940361976624,0.029345579445362,0.4297240381765557,0.9869027137756348,0.0440330728888511,0.2727817331032616,43793.0,0.9860049486160278,0.0471245720982551,0.2623276261824349,43793.0,5541.230150699616,9032.645967960358,5541.230150699616,3490.2316720485687,0.672264575958252,0.0 -17000,0.033721317,0.029844966,,,,,,,,,,,,,,,,, -17100,0.043030497,0.034390066,,,,,,,,,,,,,,,,, -17200,0.036072485,0.029616486,,,,,,,,,,,,,,,,, -17300,0.03146981,0.02706877,,,,,,,,,,,,,,,,, -17400,0.04956205,0.03584069,,,,,,,,,,,,,,,,, -17500,0.03877261,0.03413793,,,,,,,,,,,,,,,,, -17600,0.033777323,0.031094465,,,,,,,,,,,,,,,,, -17700,0.046538938,0.034375243,,,,,,,,,,,,,,,,, -17734,,,0.9909370541572572,0.0297421440482139,0.4172895521300028,0.9867074489593506,0.0442817248404026,0.266266108306227,43793.0,0.9858478307724,0.0472505912184715,0.2620539259508231,43793.0,5781.477691650391,9405.632671833038,5781.477691650391,3622.917279958725,0.7036941051483154,0.0 -17800,0.03531398,0.03214938,,,,,,,,,,,,,,,,, -17900,0.04065321,0.03183623,,,,,,,,,,,,,,,,, -18000,0.039741285,0.031177932,,,,,,,,,,,,,,,,, -18100,0.03737909,0.030556053,,,,,,,,,,,,,,,,, -18200,0.03593776,0.029689135,,,,,,,,,,,,,,,,, -18300,0.036543436,0.029931724,,,,,,,,,,,,,,,,, -18400,0.03637902,0.032201823,,,,,,,,,,,,,,,,, -18468,,,0.9910364151000975,0.0295282341539859,0.423870057287632,0.986844837665558,0.0438341982662677,0.2783312649957931,43793.0,0.9860584139823914,0.0464429780840873,0.266278519375825,43793.0,6021.736469745636,9775.371485948564,6021.736469745636,3752.345653772354,0.7333090305328369,0.0 -18500,0.038715675,0.030649524,,,,,,,,,,,,,,,,, -18600,0.03682499,0.028064642,,,,,,,,,,,,,,,,, -18700,0.045741912,0.033022173,,,,,,,,,,,,,,,,, -18800,0.038841646,0.030690731,,,,,,,,,,,,,,,,, -18900,0.039065618,0.031970233,,,,,,,,,,,,,,,,, -19000,0.04155329,0.032455273,,,,,,,,,,,,,,,,, -19100,0.041407716,0.028167425,,,,,,,,,,,,,,,,, -19200,0.036623213,0.026606183,,,,,,,,,,,,,,,,, -19203,,,0.9909691214561462,0.0296655260026454,0.4261870144662457,0.9869157075881958,0.0437863618135452,0.278916097311652,43793.0,0.9859741926193236,0.0465784110128879,0.267380460209644,43793.0,6261.852321386337,10143.883538722992,6261.852321386337,3880.687066555023,0.7648305892944336,0.0 -19300,0.040753514,0.032397907,,,,,,,,,,,,,,,,, -19400,0.04258183,0.031105088,,,,,,,,,,,,,,,,, -19500,0.046374727,0.027981663,,,,,,,,,,,,,,,,, -19600,0.04485605,0.031279877,,,,,,,,,,,,,,,,, -19700,0.041678905,0.03144349,,,,,,,,,,,,,,,,, -19800,0.03964878,0.026451305,,,,,,,,,,,,,,,,, -19900,0.04384333,0.03688479,,,,,,,,,,,,,,,,, -19936,,,0.9907914400100708,0.0301237683743238,0.4203923085594159,0.9866023063659668,0.0442087799310684,0.274155667844338,43793.0,0.9857467412948608,0.0468638837337493,0.2608550174941025,43793.0,6502.033669233322,10515.224586963654,6502.033669233322,4011.790348768234,0.7990889549255371,0.0 -20000,0.041861366,0.032068066,,,,,,,,,,,,,,,,, -20100,0.043888383,0.031153396,,,,,,,,,,,,,,,,, -20200,0.043553546,0.033310365,,,,,,,,,,,,,,,,, -20300,0.062158816,0.033999857,,,,,,,,,,,,,,,,, -20400,0.04086955,0.03101339,,,,,,,,,,,,,,,,, -20500,0.04708827,0.028116984,,,,,,,,,,,,,,,,, -20600,0.049989052,0.032800447,,,,,,,,,,,,,,,,, -20676,,,0.990973174571991,0.0296208467334508,0.4198908615641574,0.9867675304412842,0.0440637990832328,0.2745195357738439,43793.0,0.9859619736671448,0.0467089414596557,0.2625949482700543,43793.0,6742.084539890289,10882.673321723938,6742.084539890289,4139.134879112244,0.8307468891143799,0.0 -20700,0.04581214,0.035033602,,,,,,,,,,,,,,,,, -20800,0.05024029,0.03346864,,,,,,,,,,,,,,,,, -20900,0.038730457,0.033803944,,,,,,,,,,,,,,,,, -21000,0.057677202,0.032850098,,,,,,,,,,,,,,,,, -21100,0.043705013,0.031170199,,,,,,,,,,,,,,,,, -21200,0.041648068,0.032261465,,,,,,,,,,,,,,,,, -21300,0.039181855,0.030550191,,,,,,,,,,,,,,,,, -21400,0.05339688,0.031064564,,,,,,,,,,,,,,,,, -21409,,,0.9911744594573976,0.0287772081792354,0.4454445866037658,0.986868977546692,0.0436342731118202,0.2769052961180017,43793.0,0.9859548211097716,0.0463682934641838,0.2663493902990147,43793.0,6982.223841190338,11256.264298200607,6982.223841190338,4272.533679008484,0.861682653427124,0.0 -21500,0.050479073,0.033930667,,,,,,,,,,,,,,,,, -21600,0.05017285,0.029724933,,,,,,,,,,,,,,,,, -21700,0.061681014,0.025899164,,,,,,,,,,,,,,,,, -21800,0.052617114,0.032466214,,,,,,,,,,,,,,,,, -21900,0.039537102,0.031425633,,,,,,,,,,,,,,,,, -22000,0.049629997,0.03039544,,,,,,,,,,,,,,,,, -22100,0.049083326,0.032960765,,,,,,,,,,,,,,,,, -22141,,,0.9913221597671508,0.0282847136259078,0.4633205340969653,0.9869778156280518,0.0438401438295841,0.2831026302198904,43793.0,0.986151933670044,0.0466170087456703,0.2707095616703951,43793.0,7222.290660858154,11624.571064710615,7222.290660858154,4400.717868804932,0.8936870098114014,0.0 -22200,0.04511278,0.027069073,,,,,,,,,,,,,,,,, -22300,0.047259655,0.032117106,,,,,,,,,,,,,,,,, -22400,0.04369318,0.030473646,,,,,,,,,,,,,,,,, -22500,0.05612438,0.031293474,,,,,,,,,,,,,,,,, -22600,0.055739485,0.03247654,,,,,,,,,,,,,,,,, -22700,0.06461701,0.035918437,,,,,,,,,,,,,,,,, -22800,0.04752778,0.033572022,,,,,,,,,,,,,,,,, -22880,,,0.9915106296539308,0.0275696720927953,0.4735089752188432,0.9869562983512878,0.0441528148949146,0.2790978605654699,43793.0,0.986132562160492,0.0468167290091514,0.2688434020944504,43793.0,7462.577006340027,11995.432676315308,7462.577006340027,4531.240593194962,0.9250195026397704,0.0 -22900,0.05634003,0.03413703,,,,,,,,,,,,,,,,, -23000,0.04313685,0.030192917,,,,,,,,,,,,,,,,, -23100,0.05183095,0.03347035,,,,,,,,,,,,,,,,, -23200,0.056471635,0.033015877,,,,,,,,,,,,,,,,, -23300,0.04953231,0.030658573,,,,,,,,,,,,,,,,, -23400,0.049407244,0.032884642,,,,,,,,,,,,,,,,, -23500,0.04274373,0.03022892,,,,,,,,,,,,,,,,, -23600,0.06020325,0.029612202,,,,,,,,,,,,,,,,, -23626,,,0.9915829300880432,0.0273137167096138,0.4782024616666336,0.9868547916412354,0.0442686900496482,0.2710756095154921,43793.0,0.98606139421463,0.0470444560050964,0.2641718235610194,43793.0,7702.591665744781,12364.075269460678,7702.591665744781,4659.816511154175,0.9558253288269044,0.0 -23700,0.07889659,0.029284999,,,,,,,,,,,,,,,,, -23800,0.05222179,0.03386362,,,,,,,,,,,,,,,,, -23900,0.047531627,0.027856065,,,,,,,,,,,,,,,,, -24000,0.054266673,0.03143671,,,,,,,,,,,,,,,,, -24100,0.050033737,0.030896049,,,,,,,,,,,,,,,,, -24200,0.05182895,0.029420981,,,,,,,,,,,,,,,,, -24300,0.04461867,0.029807935,,,,,,,,,,,,,,,,, -24373,,,0.991339385509491,0.0281135123223066,0.4599286637345375,0.9870126843452454,0.043771956115961,0.2795645558470025,43793.0,0.9861839413642884,0.0465402007102966,0.2701791484926142,43793.0,7942.569582223892,12734.326212882996,7942.569582223892,4790.036564588547,0.98685622215271,0.0 -24400,0.05324588,0.030487102,,,,,,,,,,,,,,,,, -24500,0.04293522,0.029390376,,,,,,,,,,,,,,,,, -24600,0.06757969,0.03399744,,,,,,,,,,,,,,,,, -24700,0.069305785,0.02889074,,,,,,,,,,,,,,,,, -24800,0.060048092,0.032458976,,,,,,,,,,,,,,,,, -24900,0.046357263,0.031574182,,,,,,,,,,,,,,,,, -25000,0.051625606,0.028051991,,,,,,,,,,,,,,,,, -25100,0.053638507,0.033033203,,,,,,,,,,,,,,,,, -25118,,,0.9912641048431396,0.0285040661692619,0.4477710025916689,0.9869189262390136,0.0441875495016574,0.2854335369192523,43793.0,0.9860343933105468,0.0470565073192119,0.2656527333506259,43793.0,8182.74741435051,13099.399793863297,8182.74741435051,4914.877838134766,1.0195541381835938,0.0 -25200,0.04154791,0.02854715,,,,,,,,,,,,,,,,, -25300,0.05669871,0.030824078,,,,,,,,,,,,,,,,, -25400,0.05408857,0.031874243,,,,,,,,,,,,,,,,, -25500,0.057594266,0.034308817,,,,,,,,,,,,,,,,, -25600,0.06146186,0.029872613,,,,,,,,,,,,,,,,, -25700,0.05647099,0.03198597,,,,,,,,,,,,,,,,, -25800,0.05234222,0.030160394,,,,,,,,,,,,,,,,, -25853,,,0.9912461638450624,0.0284663066267967,0.4541184393410252,0.986922562122345,0.0437526814639568,0.2801145594678734,43793.0,0.98613041639328,0.0464157201349735,0.2715346448095588,43793.0,8422.905773878098,13470.157238006592,8422.905773878098,5045.421402454376,1.0534796714782717,0.0 -25900,0.04911793,0.028514875,,,,,,,,,,,,,,,,, -26000,0.06045945,0.03246953,,,,,,,,,,,,,,,,, -26100,0.047045633,0.028957555,,,,,,,,,,,,,,,,, -26200,0.043927465,0.0321819,,,,,,,,,,,,,,,,, -26300,0.07447427,0.030599164,,,,,,,,,,,,,,,,, -26400,0.060950216,0.029298814,,,,,,,,,,,,,,,,, -26500,0.06599119,0.029082162,,,,,,,,,,,,,,,,, -26590,,,0.9915059804916382,0.0278779752552509,0.4528256546297896,0.9869843125343324,0.0440012738108634,0.2870869938781617,43793.0,0.986127495765686,0.046985313296318,0.268233498378407,43793.0,8662.952923297882,13837.806899309158,8662.952923297882,5172.97108578682,1.0846607685089111,0.0 -26600,0.04713173,0.027756205,,,,,,,,,,,,,,,,, -26700,0.057777748,0.03105976,,,,,,,,,,,,,,,,, -26800,0.06313961,0.029495269,,,,,,,,,,,,,,,,, -26900,0.05049265,0.030610906,,,,,,,,,,,,,,,,, -27000,0.07147639,0.029315246,,,,,,,,,,,,,,,,, -27100,0.048154082,0.028947283,,,,,,,,,,,,,,,,, -27200,0.049555197,0.029444877,,,,,,,,,,,,,,,,, -27300,0.054105897,0.028002728,,,,,,,,,,,,,,,,, -27329,,,0.991453230381012,0.0276226308196783,0.4781773408012578,0.9870342016220092,0.0437435507774353,0.2914553301203189,43793.0,0.9861692190170288,0.046562697738409,0.2776822753222365,43793.0,8903.157333612442,14206.030374526978,8903.157333612442,5300.935266256332,1.1179378032684326,0.0 -27400,0.069807,0.032684308,,,,,,,,,,,,,,,,, -27500,0.0618696,0.029370345,,,,,,,,,,,,,,,,, -27600,0.067076154,0.03148944,,,,,,,,,,,,,,,,, -27700,0.06569908,0.033325948,,,,,,,,,,,,,,,,, -27800,0.05168496,0.032609895,,,,,,,,,,,,,,,,, -27900,0.043661326,0.028915573,,,,,,,,,,,,,,,,, -28000,0.062655054,0.03125489,,,,,,,,,,,,,,,,, -28067,,,0.991605579853058,0.0272387899458408,0.4788237662347861,0.986989974975586,0.0438026152551174,0.2826635722662011,43793.0,0.9861502647399902,0.0464490689337253,0.2764971134633485,43793.0,9143.340360164642,14576.181078672407,9143.340360164642,5430.845782995224,1.1518630981445312,0.0 -28100,0.05431589,0.028948292,,,,,,,,,,,,,,,,, -28200,0.054720372,0.027950626,,,,,,,,,,,,,,,,, -28300,0.056146976,0.032339867,,,,,,,,,,,,,,,,, -28400,0.04920692,0.028737953,,,,,,,,,,,,,,,,, -28500,0.061244182,0.032931216,,,,,,,,,,,,,,,,, -28600,0.06192189,0.029635131,,,,,,,,,,,,,,,,, -28700,0.07071596,0.03151386,,,,,,,,,,,,,,,,, -28800,0.06111914,0.029457103,,,,,,,,,,,,,,,,, -28809,,,0.991728663444519,0.0268288888037204,0.4913344138715968,0.9869307279586792,0.0436968393623828,0.2770742675956774,43793.0,0.9860761165618896,0.0462912358343601,0.2733801524206768,43793.0,9383.368975162506,14945.198653936386,9383.368975162506,5559.781363964081,1.1833415031433103,0.0 -28900,0.042929463,0.02727397,,,,,,,,,,,,,,,,, -29000,0.057616625,0.028630301,,,,,,,,,,,,,,,,, -29100,0.05322476,0.030688632,,,,,,,,,,,,,,,,, -29200,0.07806919,0.03021566,,,,,,,,,,,,,,,,, -29300,0.07767598,0.031260464,,,,,,,,,,,,,,,,, -29400,0.06523963,0.030941594,,,,,,,,,,,,,,,,, -29500,0.060986783,0.030250313,,,,,,,,,,,,,,,,, -29552,,,0.9918919205665588,0.026216072961688,0.5069955294043473,0.9869245886802672,0.0441036969423294,0.2867755501464883,43793.0,0.9861944913864136,0.0467113442718982,0.2743730009782902,43793.0,9623.559691429138,15313.5361597538,9623.559691429138,5687.874805927277,1.215362310409546,0.0 -29600,0.0454881,0.029785987,,,,,,,,,,,,,,,,, -29700,0.0576964,0.029478047,,,,,,,,,,,,,,,,, -29800,0.06067107,0.029243076,,,,,,,,,,,,,,,,, -29900,0.053484187,0.030493373,,,,,,,,,,,,,,,,, -30000,0.0767123,0.033673108,,,,,,,,,,,,,,,,, -30100,0.050645262,0.026905965,,,,,,,,,,,,,,,,, -30200,0.087273225,0.03122147,,,,,,,,,,,,,,,,, -30291,,,0.9920125007629396,0.0261006280779838,0.5134227238878066,0.9868564009666444,0.0438586100935936,0.2853673628324881,43793.0,0.9861359000205994,0.0463078506290912,0.2734795702662339,43793.0,9863.722933769226,15684.667104959488,9863.722933769226,5818.787388086319,1.248405933380127,0.0 -30300,0.05355568,0.032274105,,,,,,,,,,,,,,,,, -30400,0.063354224,0.030427407,,,,,,,,,,,,,,,,, -30500,0.066721104,0.028071174,,,,,,,,,,,,,,,,, -30600,0.060849734,0.03013704,,,,,,,,,,,,,,,,, -30700,0.07726531,0.02774886,,,,,,,,,,,,,,,,, -30800,0.05516152,0.03030272,,,,,,,,,,,,,,,,, -30900,0.06251237,0.029689252,,,,,,,,,,,,,,,,, -31000,0.06276676,0.03221221,,,,,,,,,,,,,,,,, -31034,,,0.9918270707130432,0.0266173146665096,0.4979042791415171,0.9870216250419616,0.0439487509429454,0.2887786879950446,43793.0,0.9862104654312134,0.0466993637382984,0.2768710218968238,43793.0,10103.851176023483,16058.839148044586,10103.851176023483,5952.775523900986,1.2820994853973389,0.0 -31100,0.05943365,0.03087311,,,,,,,,,,,,,,,,, -31200,0.053101424,0.027977742,,,,,,,,,,,,,,,,, -31300,0.052259963,0.028668141,,,,,,,,,,,,,,,,, -31400,0.058642205,0.027674317,,,,,,,,,,,,,,,,, -31500,0.054356985,0.027930887,,,,,,,,,,,,,,,,, -31600,0.05713388,0.029919086,,,,,,,,,,,,,,,,, -31700,0.056872386,0.030011835,,,,,,,,,,,,,,,,, -31766,,,0.9914273023605348,0.0275633074343204,0.4744032507199148,0.9870082139968872,0.0441398322582244,0.2862378750792766,43793.0,0.986176371574402,0.0468368269503116,0.2766105111049243,43793.0,10343.984036922457,16425.5092420578,10343.984036922457,6079.255536794663,1.3162181377410889,0.0 -31800,0.065638065,0.030304497,,,,,,,,,,,,,,,,, -31900,0.062043536,0.029734777,,,,,,,,,,,,,,,,, -32000,0.068036065,0.030236749,,,,,,,,,,,,,,,,, -32100,0.065239035,0.029649606,,,,,,,,,,,,,,,,, -32200,0.057756376,0.025108334,,,,,,,,,,,,,,,,, -32300,0.06925821,0.028238123,,,,,,,,,,,,,,,,, -32400,0.05110345,0.026447559,,,,,,,,,,,,,,,,, -32500,0.05716855,0.02860995,,,,,,,,,,,,,,,,, -32504,,,0.991688907146454,0.0269593149423599,0.4850293929760725,0.9868718385696412,0.0442156381905078,0.286878314432203,43793.0,0.9860592484474182,0.0467476919293403,0.2741441695898787,43793.0,10584.240829467772,16793.06715464592,10584.240829467772,6206.500435352325,1.350158452987671,0.0 -32600,0.06854027,0.02790743,,,,,,,,,,,,,,,,, -32700,0.06504468,0.029110942,,,,,,,,,,,,,,,,, -32800,0.06391708,0.026661659,,,,,,,,,,,,,,,,, -32900,0.07774433,0.03192554,,,,,,,,,,,,,,,,, -33000,0.059152156,0.029717049,,,,,,,,,,,,,,,,, -33100,0.06600633,0.027678734,,,,,,,,,,,,,,,,, -33200,0.06374941,0.026697254,,,,,,,,,,,,,,,,, -33243,,,0.9917634129524232,0.0265021454542875,0.5059145808603558,0.9871076941490172,0.0442456379532814,0.2891483056162617,43793.0,0.9863191246986388,0.0469299219548702,0.274600195758634,43793.0,10824.421991825104,17163.691017866135,10824.421991825104,6336.887906312943,1.3831946849822998,0.0 -33300,0.070297524,0.0310947,,,,,,,,,,,,,,,,, -33400,0.059222914,0.028169947,,,,,,,,,,,,,,,,, -33500,0.051668543,0.027942328,,,,,,,,,,,,,,,,, -33600,0.07246989,0.027926002,,,,,,,,,,,,,,,,, -33700,0.070795774,0.031481143,,,,,,,,,,,,,,,,, -33800,0.050827395,0.028522244,,,,,,,,,,,,,,,,, -33900,0.057450842,0.027754316,,,,,,,,,,,,,,,,, -33978,,,0.9919225573539734,0.0260626841336488,0.5149891537773273,0.986989974975586,0.043880894780159,0.2891141455523475,43793.0,0.9861114621162416,0.0466743633151054,0.2715926387695952,43793.0,11064.680742740631,17533.152592658997,11064.680742740631,6466.034151554108,1.4176712036132812,0.0 -34000,0.06520274,0.026170332,,,,,,,,,,,,,,,,, -34100,0.06999079,0.029783132,,,,,,,,,,,,,,,,, -34200,0.058991235,0.026247242,,,,,,,,,,,,,,,,, -34300,0.063952506,0.028357573,,,,,,,,,,,,,,,,, -34400,0.05563746,0.025694093,,,,,,,,,,,,,,,,, -34500,0.069135346,0.026957687,,,,,,,,,,,,,,,,, -34600,0.06542651,0.03041651,,,,,,,,,,,,,,,,, -34700,0.060346447,0.02678372,,,,,,,,,,,,,,,,, -34716,,,0.992056965827942,0.0256364122033119,0.5100176386259829,0.9869741201400756,0.044404212385416,0.284721986229574,43793.0,0.986080765724182,0.0472131036221981,0.2721279982617621,43793.0,11304.925557374954,17901.961685180664,11304.925557374954,6594.542794704437,1.4511137008666992,0.0 -34800,0.06521028,0.026029786,,,,,,,,,,,,,,,,, -34900,0.06650758,0.027027812,,,,,,,,,,,,,,,,, -35000,0.059236955,0.027666079,,,,,,,,,,,,,,,,, -35100,0.06197196,0.029300166,,,,,,,,,,,,,,,,, -35200,0.0697493,0.032230992,,,,,,,,,,,,,,,,, -35300,0.05911708,0.024573207,,,,,,,,,,,,,,,,, -35400,0.05813659,0.028652266,,,,,,,,,,,,,,,,, -35448,,,0.9924238920211792,0.024371912702918,0.5482731850412265,0.9869936108589172,0.044317964464426,0.2857833668199892,43793.0,0.9861615896224976,0.0470213033258914,0.2708743413208484,43793.0,11544.964567899704,18272.09800696373,11544.964567899704,6724.5836000442505,1.4851961135864258,0.0 -35500,0.059318002,0.0281688,,,,,,,,,,,,,,,,, -35600,0.061957534,0.030575091,,,,,,,,,,,,,,,,, -35700,0.06735211,0.027068418,,,,,,,,,,,,,,,,, -35800,0.06673006,0.03245552,,,,,,,,,,,,,,,,, -35900,0.05447994,0.026756361,,,,,,,,,,,,,,,,, -36000,0.06389108,0.027130136,,,,,,,,,,,,,,,,, -36100,0.06147265,0.02926021,,,,,,,,,,,,,,,,, -36186,,,0.99224191904068,0.0249270088970661,0.5376483588503354,0.9870427250862122,0.0444023720920085,0.2826900278886544,43793.0,0.9862930178642272,0.047236256301403,0.2747061458046834,43793.0,11784.975558757782,18641.091474056244,11784.975558757782,6853.508927345276,1.5202860832214355,0.0 -36200,0.06669425,0.03055766,,,,,,,,,,,,,,,,, -36300,0.06979302,0.02864271,,,,,,,,,,,,,,,,, -36400,0.058252927,0.02556206,,,,,,,,,,,,,,,,, -36500,0.07132312,0.024274066,,,,,,,,,,,,,,,,, -36600,0.0617663,0.03009229,,,,,,,,,,,,,,,,, -36700,0.08717095,0.026164087,,,,,,,,,,,,,,,,, -36800,0.07625785,0.030327123,,,,,,,,,,,,,,,,, -36900,0.060537577,0.02962488,,,,,,,,,,,,,,,,, -36923,,,0.9921140670776368,0.0253934878855943,0.519401692550248,0.9870488047599792,0.0440136454999446,0.2858875844680839,43793.0,0.9862037301063538,0.0468671582639217,0.2786638609715531,43793.0,12025.085746765137,19008.48497748375,12025.085746765137,6980.73330283165,1.5558314323425293,0.0 -37000,0.05620133,0.02668372,,,,,,,,,,,,,,,,, -37100,0.055779964,0.026795687,,,,,,,,,,,,,,,,, -37200,0.06282192,0.028444348,,,,,,,,,,,,,,,,, -37300,0.06790789,0.02680991,,,,,,,,,,,,,,,,, -37400,0.07032935,0.02885621,,,,,,,,,,,,,,,,, -37500,0.0684124,0.0288374,,,,,,,,,,,,,,,,, -37600,0.06863019,0.026749002,,,,,,,,,,,,,,,,, -37667,,,0.9920461177825928,0.0255558025091886,0.5183719989880958,0.9870699644088744,0.044741440564394,0.2828680161152879,43793.0,0.9862450361251832,0.0476358421146869,0.27484689170873,43793.0,12265.234405755997,19377.787124872208,12265.234405755997,7109.82846736908,1.5917348861694336,0.0 -37700,0.066612825,0.028385203,,,,,,,,,,,,,,,,, -37800,0.057676695,0.027624866,,,,,,,,,,,,,,,,, -37900,0.06637178,0.027377017,,,,,,,,,,,,,,,,, -38000,0.07305621,0.027658878,,,,,,,,,,,,,,,,, -38100,0.06419569,0.026910769,,,,,,,,,,,,,,,,, -38200,0.063118465,0.025204357,,,,,,,,,,,,,,,,, -38300,0.066410646,0.026688885,,,,,,,,,,,,,,,,, -38400,0.092613965,0.027502077,,,,,,,,,,,,,,,,, -38403,,,0.99214369058609,0.0253321044147014,0.517584729671504,0.9869266152381896,0.0444921404123306,0.2914535596647501,43793.0,0.9862176179885864,0.0471333190798759,0.2745749674200974,43793.0,12505.260452270508,19747.76282644272,12505.260452270508,7239.716079473495,1.6303064823150637,0.0 -38500,0.072819345,0.029265376,,,,,,,,,,,,,,,,, -38600,0.060769968,0.029486243,,,,,,,,,,,,,,,,, -38700,0.058930486,0.027284319,,,,,,,,,,,,,,,,, -38800,0.06997967,0.027865969,,,,,,,,,,,,,,,,, -38900,0.06940637,0.025950063,,,,,,,,,,,,,,,,, -39000,0.067157164,0.025933333,,,,,,,,,,,,,,,,, -39100,0.071368895,0.026188912,,,,,,,,,,,,,,,,, -39131,,,0.9921480417251588,0.0249071978032588,0.5369891992399025,0.9869238138198853,0.04449999704957,0.2855483362441873,43793.0,0.9861586689949036,0.0471546128392219,0.275559722777406,43793.0,12745.4219083786,20116.52850627899,12745.4219083786,7368.2566702365875,1.6696994304656982,0.0 -39200,0.061301302,0.025207274,,,,,,,,,,,,,,,,, -39300,0.06857028,0.026034104,,,,,,,,,,,,,,,,, -39400,0.075741306,0.027394786,,,,,,,,,,,,,,,,, -39500,0.071380615,0.025600912,,,,,,,,,,,,,,,,, -39600,0.061905276,0.025860745,,,,,,,,,,,,,,,,, -39700,0.07124397,0.02739412,,,,,,,,,,,,,,,,, -39800,0.07613097,0.024224374,,,,,,,,,,,,,,,,, -39866,,,0.9923309087753296,0.0244951397180557,0.534551486086762,0.9868633151054382,0.0443124733865261,0.2802457208483463,43793.0,0.9861182570457458,0.046987347304821,0.2686304230613403,43793.0,12985.612907409668,20480.394863128666,12985.612907409668,7491.874098777771,1.7054214477539062,0.0 -39900,0.06913229,0.029057527,,,,,,,,,,,,,,,,, -40000,0.07955408,0.028216988,,,,,,,,,,,,,,,,, -40100,0.06940324,0.026099699,,,,,,,,,,,,,,,,, -40200,0.067765296,0.025994066,,,,,,,,,,,,,,,,, -40300,0.08044506,0.025858708,,,,,,,,,,,,,,,,, -40400,0.07471985,0.027889373,,,,,,,,,,,,,,,,, -40500,0.061830305,0.024387259,,,,,,,,,,,,,,,,, -40600,,,0.9924687743186952,0.0240130554884672,0.5691049705818645,0.9870585799217224,0.0444880723953247,0.2891847364191049,43793.0,0.986295998096466,0.0472631752490997,0.2766268806489722,43793.0,13225.71517944336,20858.764546394348,13225.71517944336,7630.08384180069,1.7412102222442627,0.0 -40600,0.075019635,0.027834162,,,,,,,,,,,,,,,,, -40700,0.069694765,0.026666677,,,,,,,,,,,,,,,,, -40800,0.0791004,0.030018575,,,,,,,,,,,,,,,,, -40900,0.08323638,0.028148456,,,,,,,,,,,,,,,,, -41000,0.06550694,0.028826792,,,,,,,,,,,,,,,,, -41100,0.073413335,0.031250335,,,,,,,,,,,,,,,,, -41200,0.07007933,0.024472468,,,,,,,,,,,,,,,,, -41300,0.06534086,0.026703563,,,,,,,,,,,,,,,,, -41341,,,0.9926599860191344,0.0233439113944768,0.5782135633249319,0.9871352910995485,0.0446878522634506,0.2936668428755002,43793.0,0.986275315284729,0.0476203300058841,0.2830297049350724,43793.0,13465.765576124191,21225.642529010773,13465.765576124191,7756.849696874618,1.7805137634277344,0.0 -41400,0.06531723,0.0248086,,,,,,,,,,,,,,,,, -41500,0.07522236,0.025515946,,,,,,,,,,,,,,,,, -41600,0.07217844,0.024750322,,,,,,,,,,,,,,,,, -41700,0.06738724,0.029361479,,,,,,,,,,,,,,,,, -41800,0.071672656,0.029499598,,,,,,,,,,,,,,,,, -41900,0.086036965,0.028814532,,,,,,,,,,,,,,,,, -42000,0.06155414,0.025082791,,,,,,,,,,,,,,,,, -42081,,,0.9928129315376282,0.0229587573558092,0.5814897734017909,0.987063467502594,0.0444525331258773,0.2931879108949782,43793.0,0.9861797094345092,0.0473838150501251,0.2803766791096357,43793.0,13705.795667171478,21592.62402153015,13705.795667171478,7883.743916749954,1.8158748149871824,0.0 -42100,0.061277222,0.022702217,,,,,,,,,,,,,,,,, -42200,0.074755065,0.026535098,,,,,,,,,,,,,,,,, -42300,0.06487838,0.02473848,,,,,,,,,,,,,,,,, -42400,0.096524425,0.027192835,,,,,,,,,,,,,,,,, -42500,0.07003982,0.025718432,,,,,,,,,,,,,,,,, -42600,0.06755435,0.024674416,,,,,,,,,,,,,,,,, -42700,0.07048829,0.0242733,,,,,,,,,,,,,,,,, -42800,0.077787526,0.027831992,,,,,,,,,,,,,,,,, -42818,,,0.9925578236579896,0.0238033216446638,0.564259724336508,0.9869980812072754,0.0447949543595314,0.2879867390854697,43793.0,0.9862193465232848,0.0474888384342193,0.2773864421082346,43793.0,13945.906858921053,21960.24821233749,13945.906858921053,8011.200119256973,1.8510291576385496,0.0 -42900,0.084239,0.026601419,,,,,,,,,,,,,,,,, -43000,0.06685628,0.0275535,,,,,,,,,,,,,,,,, -43100,0.08290914,0.027740248,,,,,,,,,,,,,,,,, -43200,0.06712801,0.024216322,,,,,,,,,,,,,,,,, -43300,0.08588975,0.026665071,,,,,,,,,,,,,,,,, -43400,0.07803671,0.026828682,,,,,,,,,,,,,,,,, -43500,0.07006119,0.024963917,,,,,,,,,,,,,,,,, -43548,,,0.9924206733703612,0.0242899972945451,0.5410177292215665,0.9870171546936036,0.0448606684803962,0.2942300050683708,43793.0,0.9861186742782592,0.0477583967149257,0.2761628388744243,43793.0,14185.98506641388,22328.05681943893,14185.98506641388,8138.873854398727,1.886252403259277,0.0 -43600,0.08429461,0.028876675,,,,,,,,,,,,,,,,, -43700,0.085876,0.027776012,,,,,,,,,,,,,,,,, -43800,0.086105615,0.029324891,,,,,,,,,,,,,,,,, -43900,0.08670068,0.028357854,,,,,,,,,,,,,,,,, -44000,0.08137703,0.0282711,,,,,,,,,,,,,,,,, -44100,0.081197985,0.026151799,,,,,,,,,,,,,,,,, -44200,0.0701401,0.024367701,,,,,,,,,,,,,,,,, -44278,,,0.9924456477165222,0.0241081770509481,0.5539474357205003,0.9869996905326844,0.0446444861590862,0.2899616419114914,43793.0,0.9861868619918824,0.047412171959877,0.2761934727232424,43793.0,14426.06815457344,22696.50433659553,14426.06815457344,8267.179068088531,1.923084735870361,0.0 -44300,0.08485538,0.027723994,,,,,,,,,,,,,,,,, -44400,0.07647214,0.027926026,,,,,,,,,,,,,,,,, -44500,0.08251731,0.025436457,,,,,,,,,,,,,,,,, -44600,0.0837256,0.024156991,,,,,,,,,,,,,,,,, -44700,0.080679715,0.025766637,,,,,,,,,,,,,,,,, -44800,0.073269084,0.021344084,,,,,,,,,,,,,,,,, -44900,0.071138345,0.025806041,,,,,,,,,,,,,,,,, -45000,0.07604582,0.027406782,,,,,,,,,,,,,,,,, -45009,,,0.992694079875946,0.0233445260673761,0.5726172586558297,0.986934781074524,0.0444469712674617,0.2926345439056429,43793.0,0.986127495765686,0.0472488440573215,0.2769249759013684,43793.0,14666.161107301712,23061.56834554672,14666.161107301712,8392.092687368393,1.9581577777862549,0.0 -45100,0.074854404,0.025415411,,,,,,,,,,,,,,,,, -45200,0.08298823,0.023249812,,,,,,,,,,,,,,,,, -45300,0.09898818,0.026765428,,,,,,,,,,,,,,,,, -45400,0.08452833,0.028950484,,,,,,,,,,,,,,,,, -45500,0.0724529,0.022442488,,,,,,,,,,,,,,,,, -45600,0.06930116,0.02647234,,,,,,,,,,,,,,,,, -45700,0.086777925,0.023431476,,,,,,,,,,,,,,,,, -45750,,,0.9927871227264404,0.022890530526638,0.5784728400986485,0.987106442451477,0.0445770323276519,0.2865539158901619,43793.0,0.9862618446350098,0.0473022013902664,0.2788322799294261,43793.0,14906.305232286451,23425.255395650864,14906.305232286451,8515.576565265656,1.9948420524597168,0.0 -45800,0.08130865,0.028871462,,,,,,,,,,,,,,,,, -45900,0.08083287,0.025558185,,,,,,,,,,,,,,,,, -46000,0.07972713,0.027087485,,,,,,,,,,,,,,,,, -46100,0.08635211,0.025505282,,,,,,,,,,,,,,,,, -46200,0.08648506,0.021937737,,,,,,,,,,,,,,,,, -46300,0.076278664,0.024970151,,,,,,,,,,,,,,,,, -46400,0.08533729,0.023916516,,,,,,,,,,,,,,,,, -46486,,,0.9928334951400756,0.0226695798337459,0.5871963290355688,0.9870455861091614,0.045237760990858,0.2849342940890492,43793.0,0.9862555265426636,0.048071514815092,0.2768047922216397,43793.0,15146.25148677826,23791.95691871643,15146.25148677826,8642.272961139679,2.0315263271331787,0.0 -46500,0.08898914,0.024366401,,,,,,,,,,,,,,,,, -46600,0.082704894,0.026384242,,,,,,,,,,,,,,,,, -46700,0.082267635,0.024641246,,,,,,,,,,,,,,,,, -46800,0.07797295,0.024271922,,,,,,,,,,,,,,,,, -46900,0.09463732,0.026217027,,,,,,,,,,,,,,,,, -47000,0.0940862,0.029926453,,,,,,,,,,,,,,,,, -47100,0.07482746,0.021082437,,,,,,,,,,,,,,,,, -47200,0.0846049,0.02555229,,,,,,,,,,,,,,,,, -47225,,,0.9932576417922974,0.0212671887129545,0.6281929712934287,0.9870005249977112,0.0450255014002323,0.2873283203104504,43793.0,0.9862689971923828,0.0480483807623386,0.2787716507698172,43793.0,15386.214908361437,24156.55635929108,15386.214908361437,8766.850444555283,2.067582845687866,0.0 -47300,0.08184617,0.024026046,,,,,,,,,,,,,,,,, -47400,0.07929701,0.026290087,,,,,,,,,,,,,,,,, -47500,0.07265731,0.020697808,,,,,,,,,,,,,,,,, -47600,0.082535,0.027294908,,,,,,,,,,,,,,,,, -47700,0.071690716,0.02241998,,,,,,,,,,,,,,,,, -47800,0.07574636,0.023148121,,,,,,,,,,,,,,,,, -47900,0.073467515,0.024372358,,,,,,,,,,,,,,,,, -47963,,,0.993298590183258,0.0214150361716747,0.6106789888766017,0.9869879484176636,0.045169573277235,0.2882434698512126,43793.0,0.9861649870872498,0.0481113791465759,0.2771047990863807,43793.0,15626.345014333723,24522.78030061721,15626.345014333723,8892.886050701141,2.1041386127471924,0.0 -48000,0.077046156,0.024279576,,,,,,,,,,,,,,,,, -48100,0.09296597,0.024541155,,,,,,,,,,,,,,,,, -48200,0.09783769,0.02727299,,,,,,,,,,,,,,,,, -48300,0.092774026,0.028467739,,,,,,,,,,,,,,,,, -48400,0.0722418,0.022389418,,,,,,,,,,,,,,,,, -48500,0.08673878,0.02398648,,,,,,,,,,,,,,,,, -48600,0.092683814,0.027491704,,,,,,,,,,,,,,,,, -48700,0.08063981,0.02318305,,,,,,,,,,,,,,,,, -48705,,,0.993052899837494,0.0219904389232397,0.602626412423219,0.987127959728241,0.0452293753623962,0.2974782090753343,43793.0,0.986295998096466,0.0483037866652011,0.2816303732269372,43793.0,15866.32997751236,24889.743696928024,15866.32997751236,9019.805342435837,2.1411499977111816,0.0 -48800,0.0961655,0.026068846,,,,,,,,,,,,,,,,, -48900,0.09454819,0.025817662,,,,,,,,,,,,,,,,, -49000,0.09414038,0.021675412,,,,,,,,,,,,,,,,, -49100,0.07924855,0.024950854,,,,,,,,,,,,,,,,, -49200,0.08594203,0.022697931,,,,,,,,,,,,,,,,, -49300,0.105500296,0.026042404,,,,,,,,,,,,,,,,, -49400,0.0792227,0.024027139,,,,,,,,,,,,,,,,, -49434,,,0.993026316165924,0.0220904797315597,0.5961216643344134,0.9870460033416748,0.0452532954514026,0.2912603318833082,43793.0,0.9862993359565736,0.0482907220721244,0.278605704419052,43793.0,16106.437074422836,25259.69244766236,16106.437074422836,9149.583971261978,2.1799392700195312,0.0 -49500,0.09044557,0.025999947,,,,,,,,,,,,,,,,, -49600,0.08657295,0.022971239,,,,,,,,,,,,,,,,, -49700,0.09708187,0.026302123,,,,,,,,,,,,,,,,, -49800,0.07950444,0.022673097,,,,,,,,,,,,,,,,, -49900,0.09159245,0.024036502,,,,,,,,,,,,,,,,, -50000,0.079581246,0.023194218,,,,,,,,,,,,,,,,, -50100,0.09128753,0.025100142,,,,,,,,,,,,,,,,, -50173,,,0.9931430816650392,0.0218639783561229,0.5994783641295355,0.9869099855422974,0.0454811379313468,0.2840237806197731,43793.0,0.9862332344055176,0.0483635328710079,0.2769694896537659,43793.0,16346.489510059357,25626.25821518898,16346.489510059357,9276.038954019548,2.2161362171173096,0.0 -50200,0.09535292,0.026285695,,,,,,,,,,,,,,,,, -50300,0.08676721,0.022324666,,,,,,,,,,,,,,,,, -50400,0.08365653,0.025391428,,,,,,,,,,,,,,,,, -50500,0.08833786,0.024572335,,,,,,,,,,,,,,,,, -50600,0.08377568,0.022137633,,,,,,,,,,,,,,,,, -50700,0.09529282,0.023100786,,,,,,,,,,,,,,,,, -50800,0.10010117,0.029470423,,,,,,,,,,,,,,,,, -50900,0.09277295,0.02464069,,,,,,,,,,,,,,,,, -50904,,,0.9931342601776124,0.0215610228478908,0.6124295443985085,0.9870536923408508,0.0457516275346279,0.2861033373103874,43793.0,0.9863094687461852,0.0486667156219482,0.2779850187157393,43793.0,16586.699061632156,25998.626806259155,16586.699061632156,9408.138950824738,2.25294828414917,0.0 -51000,0.09321107,0.02718656,,,,,,,,,,,,,,,,, -51100,0.10210202,0.02784814,,,,,,,,,,,,,,,,, -51200,0.08617439,0.022718081,,,,,,,,,,,,,,,,, -51300,0.0886748,0.023075279,,,,,,,,,,,,,,,,, -51400,0.09606153,0.025295861,,,,,,,,,,,,,,,,, -51500,0.11997732,0.027216382,,,,,,,,,,,,,,,,, -51600,0.08921207,0.024816912,,,,,,,,,,,,,,,,, -51633,,,0.9932966232299804,0.021137511357665,0.6139599699223799,0.9870638251304626,0.0457816421985626,0.2866013596010617,43793.0,0.9862887859344482,0.0488758012652397,0.2805226259149176,43793.0,16826.849710941315,26370.819693803787,16826.849710941315,9540.113079071043,2.294236183166504,0.0 -51700,0.08569606,0.023969198,,,,,,,,,,,,,,,,, -51800,0.102077834,0.02306104,,,,,,,,,,,,,,,,, -51900,0.07821127,0.021008128,,,,,,,,,,,,,,,,, -52000,0.099489614,0.02357075,,,,,,,,,,,,,,,,, -52100,0.09479122,0.027037134,,,,,,,,,,,,,,,,, -52200,0.10238353,0.024856316,,,,,,,,,,,,,,,,, -52300,0.106082685,0.020882117,,,,,,,,,,,,,,,,, -52364,,,0.9933950901031494,0.0207604877650737,0.6385449009632981,0.987082540988922,0.0456486195325851,0.2941234444125793,43793.0,0.986327588558197,0.0485045313835144,0.2836710398323977,43793.0,17066.924193382263,26734.583471298218,17066.924193382263,9663.73701095581,2.3359954357147217,0.0 -52400,0.09716347,0.022439959,,,,,,,,,,,,,,,,, -52500,0.080544464,0.0209107,,,,,,,,,,,,,,,,, -52600,0.08598761,0.023192864,,,,,,,,,,,,,,,,, -52700,0.09428667,0.024308207,,,,,,,,,,,,,,,,, -52800,0.10660031,0.02562153,,,,,,,,,,,,,,,,, -52900,0.093230866,0.022928176,,,,,,,,,,,,,,,,, -53000,0.09538186,0.024799274,,,,,,,,,,,,,,,,, -53100,0.0918431,0.022070885,,,,,,,,,,,,,,,,, -53103,,,0.9938114285469056,0.0195509679615497,0.6596532125631671,0.987057328224182,0.0461726002395153,0.2965265591422013,43793.0,0.986295998096466,0.0487239696085453,0.2818153417585818,43793.0,17307.14094877243,27100.617507457733,17307.14094877243,9789.492477416992,2.375693798065185,0.0 -53200,0.11742756,0.023868136,,,,,,,,,,,,,,,,, -53300,0.09982107,0.024886195,,,,,,,,,,,,,,,,, -53400,0.08878995,0.021221464,,,,,,,,,,,,,,,,, -53500,0.11183582,0.027806617,,,,,,,,,,,,,,,,, -53600,0.091490194,0.023001814,,,,,,,,,,,,,,,,, -53700,0.116319,0.023626719,,,,,,,,,,,,,,,,, -53800,0.092812315,0.02091888,,,,,,,,,,,,,,,,, -53835,,,0.9938933253288268,0.0194165743887424,0.6604764193160376,0.9871150255203248,0.0463631376624107,0.2949741394119192,43793.0,0.9863587617874146,0.0489544048905372,0.284141920602494,43793.0,17547.083233356476,27461.019245386124,17547.083233356476,9909.892055034636,2.4137353897094727,0.0 -53900,0.10626633,0.025798583,,,,,,,,,,,,,,,,, -54000,0.09016225,0.022159787,,,,,,,,,,,,,,,,, -54100,0.098318316,0.023440456,,,,,,,,,,,,,,,,, -54200,0.09458233,0.022580072,,,,,,,,,,,,,,,,, -54300,0.112131,0.026883857,,,,,,,,,,,,,,,,, -54400,0.10475137,0.02675199,,,,,,,,,,,,,,,,, -54500,0.097435914,0.022253774,,,,,,,,,,,,,,,,, -54579,,,0.9937782287597656,0.0197028685361146,0.6423692855399952,0.9869229793548584,0.046136025339365,0.293365370342748,43793.0,0.9861123561859132,0.0491110160946846,0.2809566913061406,43793.0,17787.043855190277,27830.67459988594,17787.043855190277,10039.527564287186,2.451258659362793,0.0 -54600,0.10997701,0.02556557,,,,,,,,,,,,,,,,, -54700,0.104108825,0.023863876,,,,,,,,,,,,,,,,, -54800,0.09801645,0.020075744,,,,,,,,,,,,,,,,, -54900,0.106139414,0.023585299,,,,,,,,,,,,,,,,, -55000,0.10628686,0.023083013,,,,,,,,,,,,,,,,, -55100,0.09535831,0.022381278,,,,,,,,,,,,,,,,, -55200,0.10598792,0.024015266,,,,,,,,,,,,,,,,, -55300,0.10375908,0.022237202,,,,,,,,,,,,,,,,, -55314,,,0.9934359788894652,0.0206140764057636,0.631309266735881,0.98701673746109,0.0463982000946998,0.291090867299716,43793.0,0.9862247705459596,0.0493720471858978,0.2805172059646496,43793.0,18027.26170873642,28198.673761606216,18027.26170873642,10167.246821165085,2.491405248641968,0.0 -55400,0.102579445,0.024888817,,,,,,,,,,,,,,,,, -55500,0.09647974,0.02113701,,,,,,,,,,,,,,,,, -55600,0.09583563,0.020750035,,,,,,,,,,,,,,,,, -55700,0.11260714,0.023440698,,,,,,,,,,,,,,,,, -55800,0.09294522,0.021241091,,,,,,,,,,,,,,,,, -55900,0.113567166,0.022595748,,,,,,,,,,,,,,,,, -56000,0.108483404,0.023899205,,,,,,,,,,,,,,,,, -56050,,,0.9934125542640686,0.020549688488245,0.6211181522600252,0.9869664311408995,0.0465161129832267,0.291339921317668,43793.0,0.986237406730652,0.049305684864521,0.285898564507146,43793.0,18267.27084064484,28565.63999724388,18267.27084064484,10294.143009662628,2.530449390411377,0.0 -56100,0.10462314,0.02314411,,,,,,,,,,,,,,,,, -56200,0.09996375,0.02373464,,,,,,,,,,,,,,,,, -56300,0.12573314,0.021212187,,,,,,,,,,,,,,,,, -56400,0.10217539,0.020864818,,,,,,,,,,,,,,,,, -56500,0.11700398,0.024226012,,,,,,,,,,,,,,,,, -56600,0.11811003,0.02484216,,,,,,,,,,,,,,,,, -56700,0.10154255,0.022534464,,,,,,,,,,,,,,,,, -56785,,,0.9937226176261902,0.0196337290108203,0.6670870496531485,0.9869379997253418,0.0466276854276657,0.2882335202718856,43793.0,0.9862247705459596,0.0494795329868793,0.2794730220607643,43793.0,18507.41755247116,28932.106006860733,18507.41755247116,10420.40105509758,2.569205045700073,0.0 -56800,0.10243585,0.02083179,,,,,,,,,,,,,,,,, -56900,0.10948161,0.022360519,,,,,,,,,,,,,,,,, -57000,0.09395267,0.020269366,,,,,,,,,,,,,,,,, -57100,0.09825102,0.020020654,,,,,,,,,,,,,,,,, -57200,0.11158745,0.021032574,,,,,,,,,,,,,,,,, -57300,0.09590701,0.022205176,,,,,,,,,,,,,,,,, -57400,0.11375321,0.024721224,,,,,,,,,,,,,,,,, -57500,0.1034335,0.02223088,,,,,,,,,,,,,,,,, -57508,,,0.9938977956771852,0.0191189385950565,0.6645415284070957,0.9869725108146667,0.046925239264965,0.287619898859582,43793.0,0.986193597316742,0.0498026758432388,0.2824516621344256,43793.0,18747.6509013176,29301.772493600845,18747.6509013176,10549.771411418917,2.607391595840454,0.0 -57600,0.119035006,0.026203243,,,,,,,,,,,,,,,,, -57700,0.11154269,0.022080116,,,,,,,,,,,,,,,,, -57800,0.1148235,0.02134611,,,,,,,,,,,,,,,,, -57900,0.10949265,0.021058198,,,,,,,,,,,,,,,,, -58000,0.11734668,0.024790503,,,,,,,,,,,,,,,,, -58100,0.10816476,0.019069668,,,,,,,,,,,,,,,,, -58200,0.11299927,0.022213371,,,,,,,,,,,,,,,,, -58235,,,0.9940282702445984,0.0187131371349096,0.6765143704229382,0.986976146697998,0.0473107621073722,0.2894394066239375,43793.0,0.9862008094787598,0.0501836277544498,0.2808995364660824,43793.0,18987.619652748108,29665.43332839012,18987.619652748108,10673.395092010498,2.649985790252685,0.0 -58300,0.11735209,0.022671439,,,,,,,,,,,,,,,,, -58400,0.12217051,0.023425857,,,,,,,,,,,,,,,,, -58500,0.09958725,0.021126864,,,,,,,,,,,,,,,,, -58600,0.12052098,0.019929672,,,,,,,,,,,,,,,,, -58700,0.10095748,0.021150284,,,,,,,,,,,,,,,,, -58800,0.10317917,0.020597255,,,,,,,,,,,,,,,,, -58900,0.12952583,0.022247791,,,,,,,,,,,,,,,,, -58968,,,0.9943478107452391,0.0176516938954591,0.698494736500919,0.9870277047157288,0.0475666038691997,0.2892462198395104,43793.0,0.9862332344055176,0.0505859851837158,0.2774700604654444,43793.0,19227.710742235184,30028.28174901009,19227.710742235184,10796.089787244797,2.6901214122772217,0.0 -59000,0.12262492,0.022895357,,,,,,,,,,,,,,,,, -59100,0.11741043,0.022237906,,,,,,,,,,,,,,,,, -59200,0.12362736,0.023299752,,,,,,,,,,,,,,,,, -59300,0.113336846,0.021858582,,,,,,,,,,,,,,,,, -59400,0.10633071,0.02168942,,,,,,,,,,,,,,,,, -59500,0.112859085,0.020838054,,,,,,,,,,,,,,,,, -59600,0.110275485,0.020784399,,,,,,,,,,,,,,,,, -59700,0.12853496,0.02308482,,,,,,,,,,,,,,,,, -59706,,,0.9944453835487366,0.017467513680458,0.7007133967198509,0.9869737029075624,0.0474945120513439,0.2934684995339162,43793.0,0.9862951040267944,0.0503821186721324,0.2815228197907312,43793.0,19467.699901103973,30390.844262599945,19467.699901103973,10918.601605176926,2.7292253971099854,0.0 -59800,0.11960296,0.023027945,,,,,,,,,,,,,,,,, -59900,0.11959003,0.021438789,,,,,,,,,,,,,,,,, -60000,0.10306567,0.020178478,,,,,,,,,,,,,,,,, -60100,0.11554953,0.02042554,,,,,,,,,,,,,,,,, -60200,0.11592453,0.025206465,,,,,,,,,,,,,,,,, -60300,0.12534358,0.019310137,,,,,,,,,,,,,,,,, -60400,0.13587393,0.022280268,,,,,,,,,,,,,,,,, -60443,,,0.9942583441734314,0.0180808901786804,0.6905366072922718,0.986946940422058,0.0475612133741378,0.2924110143159727,43793.0,0.9862176179885864,0.0507260598242282,0.2830889685645464,43793.0,19707.881704092026,30753.16907453537,19707.881704092026,11040.682296514511,2.767962694168091,0.0 -60500,0.1129574,0.019096987,,,,,,,,,,,,,,,,, -60600,0.11702671,0.021708546,,,,,,,,,,,,,,,,, -60700,0.12658513,0.021925127,,,,,,,,,,,,,,,,, -60800,0.111556455,0.021823868,,,,,,,,,,,,,,,,, -60900,0.12661739,0.023747953,,,,,,,,,,,,,,,,, -61000,0.11802291,0.019871116,,,,,,,,,,,,,,,,, -61100,0.11547307,0.020483553,,,,,,,,,,,,,,,,, -61177,,,0.994189202785492,0.0180929508060216,0.6996259855388032,0.9869688749313354,0.0477571599185466,0.2939358251755667,43793.0,0.9861498475074768,0.0508788302540779,0.2810613524262955,43793.0,19947.89863538742,31119.0564391613,19947.89863538742,11166.49221110344,2.806538820266724,0.0 -61200,0.12755059,0.022282934,,,,,,,,,,,,,,,,, -61300,0.123047225,0.020920064,,,,,,,,,,,,,,,,, -61400,0.11738613,0.02096385,,,,,,,,,,,,,,,,, -61500,0.117448494,0.021650296,,,,,,,,,,,,,,,,, -61600,0.13350931,0.02476245,,,,,,,,,,,,,,,,, -61700,0.10981919,0.02081909,,,,,,,,,,,,,,,,, -61800,0.12475481,0.020967845,,,,,,,,,,,,,,,,, -61894,,,0.9942132234573364,0.0179898701608181,0.6911047828460279,0.9869351387023926,0.0481023862957954,0.2883425479472828,43793.0,0.9862441420555116,0.0512084066867828,0.2797026918756632,43793.0,20187.945281505585,31481.90607857704,20187.945281505585,11289.23188996315,2.847842216491699,0.0 -61900,0.1307393,0.022089608,,,,,,,,,,,,,,,,, -62000,0.114900604,0.018505223,,,,,,,,,,,,,,,,, -62100,0.11984775,0.018726353,,,,,,,,,,,,,,,,, -62200,0.12777962,0.01881984,,,,,,,,,,,,,,,,, -62300,0.13360932,0.024057172,,,,,,,,,,,,,,,,, -62400,0.14183335,0.020644471,,,,,,,,,,,,,,,,, -62500,0.117492914,0.018907845,,,,,,,,,,,,,,,,, -62600,0.13236985,0.021055805,,,,,,,,,,,,,,,,, -62626,,,0.9943777322769164,0.0174479763954877,0.6899870812187572,0.9869586825370787,0.0483387596905231,0.2845790269248508,43793.0,0.9862605929374696,0.0515501610934734,0.2782914654132266,43793.0,20427.938327550888,31844.98264527321,20427.938327550888,11412.255279064178,2.88604474067688,0.0 -62700,0.118821084,0.02024108,,,,,,,,,,,,,,,,, -62800,0.11289329,0.02002908,,,,,,,,,,,,,,,,, -62900,0.12127545,0.02027133,,,,,,,,,,,,,,,,, -63000,0.10885595,0.017836198,,,,,,,,,,,,,,,,, -63100,0.11397603,0.019539386,,,,,,,,,,,,,,,,, -63200,0.116570584,0.01810517,,,,,,,,,,,,,,,,, -63300,0.14052495,0.021046337,,,,,,,,,,,,,,,,, -63360,,,0.9945623874664308,0.0169183146208524,0.7184558394793336,0.986990749835968,0.04825434461236,0.2925786355656576,43793.0,0.986270308494568,0.0512690246105194,0.280880005569525,43793.0,20667.90618634224,32201.472920179367,20667.90618634224,11528.714675426483,2.926175355911255,0.0 -63400,0.11279022,0.020052927,,,,,,,,,,,,,,,,, -63500,0.12040284,0.021025438,,,,,,,,,,,,,,,,, -63600,0.1181583,0.018666262,,,,,,,,,,,,,,,,, -63700,0.12965299,0.020889178,,,,,,,,,,,,,,,,, -63800,0.13717611,0.01724679,,,,,,,,,,,,,,,,, -63900,0.10773128,0.018119412,,,,,,,,,,,,,,,,, -64000,0.1267974,0.019815724,,,,,,,,,,,,,,,,, -64100,0.13021612,0.019738702,,,,,,,,,,,,,,,,, -64106,,,0.9946512579917908,0.0165587794035673,0.724740173698647,0.9870366454124452,0.0486161857843399,0.2894655643069859,43793.0,0.9863195419311525,0.0515687614679336,0.2840452353283872,43793.0,20908.09387183189,32564.357125520703,20908.09387183189,11651.349862098694,2.965777635574341,0.0 -64200,0.13092269,0.019336408,,,,,,,,,,,,,,,,, -64300,0.12909666,0.020850679,,,,,,,,,,,,,,,,, -64400,0.13531718,0.020481244,,,,,,,,,,,,,,,,, -64500,0.11275426,0.018843181,,,,,,,,,,,,,,,,, -64600,0.121767215,0.018375456,,,,,,,,,,,,,,,,, -64700,0.14576675,0.020575112,,,,,,,,,,,,,,,,, -64800,0.14039804,0.020370597,,,,,,,,,,,,,,,,, -64844,,,0.995011568069458,0.0157795790582895,0.7456424173940478,0.9868746995925904,0.0486071705818176,0.2935998478437701,43793.0,0.986154854297638,0.0515463948249816,0.2841144505434902,43793.0,21148.1995844841,32928.89493846893,21148.1995844841,11775.719557523727,3.005262613296509,0.0 -64900,0.12995993,0.01949235,,,,,,,,,,,,,,,,, -65000,0.14447258,0.01972218,,,,,,,,,,,,,,,,, -65100,0.10834151,0.019729104,,,,,,,,,,,,,,,,, -65200,0.13022067,0.01901422,,,,,,,,,,,,,,,,, -65300,0.1237203,0.019040147,,,,,,,,,,,,,,,,, -65400,0.15504055,0.018487336,,,,,,,,,,,,,,,,, -65500,0.13461868,0.019756336,,,,,,,,,,,,,,,,, -65585,,,0.9950687885284424,0.0155374398455023,0.744460006485492,0.9869189262390136,0.0489300824701786,0.2965670504331467,43793.0,0.9861262440681458,0.0519878938794136,0.2851634495635804,43793.0,21388.1816380024,33286.03973245621,21388.1816380024,11892.819792985916,3.0460572242736816,0.0 -65600,0.124155656,0.017226012,,,,,,,,,,,,,,,,, -65700,0.13781254,0.017675862,,,,,,,,,,,,,,,,, -65800,0.1319647,0.020137677,,,,,,,,,,,,,,,,, -65900,0.14243846,0.020676691,,,,,,,,,,,,,,,,, -66000,0.12940215,0.020708542,,,,,,,,,,,,,,,,, -66100,0.13647841,0.017760279,,,,,,,,,,,,,,,,, -66200,0.13745931,0.019103434,,,,,,,,,,,,,,,,, -66300,0.1350257,0.020273687,,,,,,,,,,,,,,,,, -66332,,,0.9950166940689088,0.0156867504119873,0.7417331948355668,0.9868669509887696,0.0488156154751777,0.2948186594970466,43793.0,0.986139714717865,0.0520018711686134,0.2864536287657171,43793.0,21628.180775880814,33641.677525281906,21628.180775880814,12008.396959066393,3.0855839252471924,0.0 -66400,0.11873222,0.020058164,,,,,,,,,,,,,,,,, -66500,0.14058971,0.020158743,,,,,,,,,,,,,,,,, -66600,0.14322834,0.019087639,,,,,,,,,,,,,,,,, -66700,0.1281942,0.016881187,,,,,,,,,,,,,,,,, -66800,0.1333901,0.022260029,,,,,,,,,,,,,,,,, -66900,0.14082427,0.019443983,,,,,,,,,,,,,,,,, -67000,0.13909003,0.020594824,,,,,,,,,,,,,,,,, -67063,,,0.9946566820144652,0.0164927765727043,0.7092785997806119,0.9869843125343324,0.049311026930809,0.2938793254291189,43793.0,0.98628968000412,0.0523262955248355,0.2879645668159487,43793.0,21868.17398071289,34003.17022848129,21868.17398071289,12129.832447767258,3.125771999359131,0.0 -67100,0.10710098,0.016809613,,,,,,,,,,,,,,,,, -67200,0.14376591,0.018820206,,,,,,,,,,,,,,,,, -67300,0.13742419,0.019898094,,,,,,,,,,,,,,,,, -67400,0.15510991,0.02126325,,,,,,,,,,,,,,,,, -67500,0.15113352,0.020932807,,,,,,,,,,,,,,,,, -67600,0.13931982,0.020807635,,,,,,,,,,,,,,,,, -67700,0.13369805,0.01803846,,,,,,,,,,,,,,,,, -67800,0.15810676,0.020744892,,,,,,,,,,,,,,,,, -67806,,,0.9946697354316713,0.0165935941040515,0.7233143226581569,0.9868795275688172,0.0493571013212204,0.2930838377370317,43793.0,0.9861851930618286,0.0525597482919693,0.2859367983010787,43793.0,22108.26478767395,34359.531853199005,22108.26478767395,12246.04151725769,3.166090250015259,0.0 -67900,0.13261878,0.019383285,,,,,,,,,,,,,,,,, -68000,0.15131268,0.021718666,,,,,,,,,,,,,,,,, -68100,0.13308518,0.018757228,,,,,,,,,,,,,,,,, -68200,0.12907924,0.018601434,,,,,,,,,,,,,,,,, -68300,0.14818212,0.021287533,,,,,,,,,,,,,,,,, -68400,0.13477863,0.02019414,,,,,,,,,,,,,,,,, -68500,0.12051352,0.018389951,,,,,,,,,,,,,,,,, -68554,,,0.9949089884757996,0.0157757550477981,0.7360908466647706,0.9869266152381896,0.0492417514324188,0.2951497027189781,43793.0,0.9861515164375304,0.0522307381033897,0.2888726151712065,43793.0,22348.27229404449,34714.82387185097,22348.27229404449,12361.262688159944,3.2080864906311035,0.0 -68600,0.13396569,0.017626856,,,,,,,,,,,,,,,,, -68700,0.14431697,0.018496742,,,,,,,,,,,,,,,,, -68800,0.13668622,0.020589238,,,,,,,,,,,,,,,,, -68900,0.14496937,0.019047923,,,,,,,,,,,,,,,,, -69000,0.14202622,0.018909667,,,,,,,,,,,,,,,,, -69100,0.14012383,0.018199643,,,,,,,,,,,,,,,,, -69200,0.12601867,0.018621618,,,,,,,,,,,,,,,,, -69283,,,0.9949740171432496,0.0155983455479145,0.7509368956064921,0.986963152885437,0.0495180785655975,0.2943744306236122,43793.0,0.986163318157196,0.0527291633188724,0.2851350869481603,43793.0,22588.225852251053,35073.536291360855,22588.225852251053,12479.956215381622,3.247627258300781,0.0 -69300,0.1301773,0.019233469,,,,,,,,,,,,,,,,, -69400,0.15032507,0.02242179,,,,,,,,,,,,,,,,, -69500,0.13805878,0.019547135,,,,,,,,,,,,,,,,, -69600,0.14848879,0.018930305,,,,,,,,,,,,,,,,, -69700,0.14040554,0.018304795,,,,,,,,,,,,,,,,, -69800,0.11716957,0.018464552,,,,,,,,,,,,,,,,, -69900,0.14707825,0.018799065,,,,,,,,,,,,,,,,, -70000,0.13900712,0.019186487,,,,,,,,,,,,,,,,, -70025,,,0.9950136542320251,0.0153692392632365,0.7516144347200444,0.9869911670684814,0.049633577466011,0.2945740366852919,43793.0,0.9862056374549866,0.0529153123497962,0.285105080197603,43793.0,22828.19718813896,35433.087446689606,22828.19718813896,12599.468941926956,3.289933919906616,0.0 -70100,0.14260387,0.016594233,,,,,,,,,,,,,,,,, -70200,0.13890329,0.020255143,,,,,,,,,,,,,,,,, -70300,0.14131713,0.019188493,,,,,,,,,,,,,,,,, -70400,0.13943264,0.018868335,,,,,,,,,,,,,,,,, -70500,0.14354676,0.02053698,,,,,,,,,,,,,,,,, -70600,0.15825021,0.018944217,,,,,,,,,,,,,,,,, -70700,0.1202802,0.016767044,,,,,,,,,,,,,,,,, -70763,,,0.9954207539558412,0.0145008927211165,0.7669779487185071,0.9868775010108948,0.0495731830596923,0.2936450999214748,43793.0,0.9860811829566956,0.0528018586337566,0.2839924188185285,43793.0,23068.196282863617,35795.0873939991,23068.196282863617,12721.406133413317,3.331706047058105,0.0 -70800,0.15198196,0.018630058,,,,,,,,,,,,,,,,, -70900,0.15348628,0.020304728,,,,,,,,,,,,,,,,, -71000,0.14836997,0.017941238,,,,,,,,,,,,,,,,, -71100,0.13743974,0.017597644,,,,,,,,,,,,,,,,, -71200,0.13783568,0.017694987,,,,,,,,,,,,,,,,, -71300,0.15169166,0.020025339,,,,,,,,,,,,,,,,, -71400,0.14291829,0.018632347,,,,,,,,,,,,,,,,, -71500,,,0.9954916834831238,0.0141904382035136,0.7654607218142788,0.9869996905326844,0.049864936619997,0.2918474328552706,43793.0,0.9862028956413268,0.0531973056495189,0.2830510085013739,43793.0,23308.296140670776,36157.941160440445,23308.296140670776,12844.09109044075,3.37443470954895,0.0 -71500,0.14054511,0.017160052,,,,,,,,,,,,,,,,, -71600,0.14013559,0.017184623,,,,,,,,,,,,,,,,, -71700,0.13290118,0.016521588,,,,,,,,,,,,,,,,, -71800,0.13497563,0.01814858,,,,,,,,,,,,,,,,, -71900,0.13396399,0.019353924,,,,,,,,,,,,,,,,, -72000,0.1484522,0.016700508,,,,,,,,,,,,,,,,, -72100,0.13784264,0.01720766,,,,,,,,,,,,,,,,, -72200,0.1297172,0.016437385,,,,,,,,,,,,,,,,, -72242,,,0.9953582286834716,0.0145191708579659,0.7666464923592602,0.9869623780250548,0.0500339493155479,0.292061711280013,43793.0,0.9862096309661864,0.0534116253256797,0.284107939661829,43793.0,23548.290924072266,36513.74961900711,23548.290924072266,12959.840638399124,3.4160678386688232,0.0 -72300,0.14640658,0.020127373,,,,,,,,,,,,,,,,, -72400,0.1494974,0.019132782,,,,,,,,,,,,,,,,, -72500,0.13328166,0.017190162,,,,,,,,,,,,,,,,, -72600,0.1330974,0.017090892,,,,,,,,,,,,,,,,, -72700,0.17912842,0.019987084,,,,,,,,,,,,,,,,, -72800,0.13046417,0.017462036,,,,,,,,,,,,,,,,, -72900,0.13446473,0.016336609,,,,,,,,,,,,,,,,, -72992,,,0.9953705072402954,0.0144443539902567,0.7663247210353313,0.9869558811187744,0.0500800125300884,0.2917908410573102,43793.0,0.9861733913421632,0.0535128824412822,0.283005966453115,43793.0,23788.352286815643,36874.36281085014,23788.352286815643,13080.330332517624,3.4564311504364014,0.0 -73000,0.14592525,0.01823895,,,,,,,,,,,,,,,,, -73100,0.14900236,0.018437337,,,,,,,,,,,,,,,,, -73200,0.13198762,0.018176049,,,,,,,,,,,,,,,,, -73300,0.15216625,0.018908206,,,,,,,,,,,,,,,,, -73400,0.14282231,0.018452905,,,,,,,,,,,,,,,,, -73500,0.13172865,0.01650979,,,,,,,,,,,,,,,,, -73600,0.15140021,0.018415421,,,,,,,,,,,,,,,,, -73700,0.15109909,0.019347899,,,,,,,,,,,,,,,,, -73740,,,0.9952463507652284,0.0148539431393146,0.772484291095771,0.9869952201843262,0.0501833371818065,0.2928766906316999,43793.0,0.9862037301063538,0.0534953586757183,0.2853542408933036,43793.0,24028.334567308422,37232.946142435074,24028.334567308422,13198.867261886597,3.4985146522521973,0.0 -73800,0.14945103,0.01841512,,,,,,,,,,,,,,,,, -73900,0.1541307,0.018460391,,,,,,,,,,,,,,,,, -74000,0.13184376,0.017258067,,,,,,,,,,,,,,,,, -74100,0.16165465,0.018730134,,,,,,,,,,,,,,,,, -74200,0.12967479,0.01688036,,,,,,,,,,,,,,,,, -74300,0.14722824,0.017247284,,,,,,,,,,,,,,,,, -74400,0.14588137,0.015571614,,,,,,,,,,,,,,,,, -74480,,,0.9953171610832214,0.0145237715914845,0.7563114128114683,0.9869773983955384,0.0501542091369628,0.2936811767705728,43793.0,0.9862239360809326,0.0535146072506904,0.2863500061223415,43793.0,24268.569067955017,37590.75980067253,24268.569067955017,13316.381780862808,3.541311264038086,0.0 -74500,0.1312085,0.016617296,,,,,,,,,,,,,,,,, -74600,0.1500206,0.017833477,,,,,,,,,,,,,,,,, -74700,0.15622042,0.017824918,,,,,,,,,,,,,,,,, -74800,0.14937805,0.017972894,,,,,,,,,,,,,,,,, -74900,0.13727368,0.018483376,,,,,,,,,,,,,,,,, -75000,0.11819465,0.015783023,,,,,,,,,,,,,,,,, -75100,0.15486977,0.0191096,,,,,,,,,,,,,,,,, -75200,0.15789188,0.018736947,,,,,,,,,,,,,,,,, -75214,,,0.9954608678817748,0.0142556382343173,0.7678145129889065,0.9869444966316224,0.0501078702509403,0.2945273355628647,43793.0,0.9862167835235596,0.0534422062337398,0.2858810997305218,43793.0,24508.743983984,37950.37283325195,24508.743983984,13435.755778074265,3.583484649658203,0.0 -75300,0.12591249,0.016777007,,,,,,,,,,,,,,,,, -75400,0.14203286,0.016170833,,,,,,,,,,,,,,,,, -75500,0.1590151,0.019202799,,,,,,,,,,,,,,,,, -75600,0.14337435,0.017773887,,,,,,,,,,,,,,,,, -75700,0.13946116,0.019864209,,,,,,,,,,,,,,,,, -75800,0.13100696,0.016408408,,,,,,,,,,,,,,,,, -75900,0.1461273,0.018585343,,,,,,,,,,,,,,,,, -75953,,,0.9954636096954346,0.0141442064195871,0.7713807343557014,0.9870074391365052,0.050181545317173,0.2951738603598899,43793.0,0.9862155318260192,0.0535604022443294,0.2856587078179153,43793.0,24748.70914363861,38304.90705013275,24748.70914363861,13550.260266304016,3.625715970993042,0.0 -76000,0.15630652,0.018632613,,,,,,,,,,,,,,,,, -76100,0.14250821,0.016969902,,,,,,,,,,,,,,,,, -76200,0.15200455,0.02000358,,,,,,,,,,,,,,,,, -76300,0.13670185,0.016797392,,,,,,,,,,,,,,,,, -76400,0.1375567,0.017538773,,,,,,,,,,,,,,,,, -76500,0.15799522,0.01679925,,,,,,,,,,,,,,,,, -76600,0.14037114,0.017735153,,,,,,,,,,,,,,,,, -76699,,,0.995585560798645,0.0139017924666404,0.7811432594628471,0.9869680404663086,0.0502056255936622,0.2938870177098965,43793.0,0.986198663711548,0.0535183511674404,0.2852305929566506,43793.0,24988.704931735992,38665.30813074112,24988.704931735992,13670.602442026138,3.667493104934693,0.0 -76700,0.13743952,0.018094175,,,,,,,,,,,,,,,,, -76800,0.13111934,0.01722831,,,,,,,,,,,,,,,,, -76900,0.1404098,0.017472195,,,,,,,,,,,,,,,,, -77000,0.14103898,0.017334918,,,,,,,,,,,,,,,,, -77100,0.15375121,0.019356674,,,,,,,,,,,,,,,,, -77200,0.14179029,0.016923185,,,,,,,,,,,,,,,,, -77300,0.14893691,0.01703209,,,,,,,,,,,,,,,,, -77400,0.15840264,0.018667,,,,,,,,,,,,,,,,, -77440,,,0.9955874085426332,0.0138820931315422,0.7873145623554283,0.986979842185974,0.0502523556351661,0.2945675573855502,43793.0,0.9862180352211,0.0535914674401283,0.2857701660308041,43793.0,25228.811174869537,39027.34461021423,25228.811174869537,13792.465679645538,3.710915088653565,0.0 -77500,0.13621001,0.01637854,,,,,,,,,,,,,,,,, -77600,0.13342553,0.01745285,,,,,,,,,,,,,,,,, -77700,0.1343514,0.017788926,,,,,,,,,,,,,,,,, -77800,0.13146804,0.015659686,,,,,,,,,,,,,,,,, -77900,0.1456361,0.019432975,,,,,,,,,,,,,,,,, -78000,0.14231159,0.018824942,,,,,,,,,,,,,,,,, -78100,0.12766984,0.016130337,,,,,,,,,,,,,,,,, -78188,,,0.9955474734306335,0.0139575647190213,0.7654514914396318,0.986976146697998,0.0502530895173549,0.2955371902344945,43793.0,0.9862058162689208,0.0535954870283603,0.2853759011460834,43793.0,25469.02388048172,39381.65291333199,25469.02388048172,13906.49803853035,3.752807378768921,0.0 -78200,0.13558482,0.015130255,,,,,,,,,,,,,,,,, -78300,0.15686713,0.018816518,,,,,,,,,,,,,,,,, -78400,0.14811449,0.020761326,,,,,,,,,,,,,,,,, -78500,0.13148147,0.015861603,,,,,,,,,,,,,,,,, -78600,0.1533851,0.02126668,,,,,,,,,,,,,,,,, -78700,0.13828799,0.0186363,,,,,,,,,,,,,,,,, -78800,0.14683305,0.018397955,,,,,,,,,,,,,,,,, -78900,0.13282064,0.015493508,,,,,,,,,,,,,,,,, -78924,,,0.9955233931541444,0.0140358218923211,0.7749583686005181,0.9869713187217712,0.0502181053161621,0.2957871287319151,43793.0,0.9861990809440612,0.0535637363791465,0.2856132740395028,43793.0,25709.17568540573,39738.876128435135,25709.17568540573,14023.504931926727,3.79575777053833,0.0 -79000,0.1451896,0.019188846,,,,,,,,,,,,,,,,, -79100,0.14429858,0.016425282,,,,,,,,,,,,,,,,, -79200,0.13879564,0.017011367,,,,,,,,,,,,,,,,, -79300,0.14230198,0.018169915,,,,,,,,,,,,,,,,, -79400,0.14567584,0.017372416,,,,,,,,,,,,,,,,, -79500,0.15016739,0.017638287,,,,,,,,,,,,,,,,, -79600,0.14535075,0.016516257,,,,,,,,,,,,,,,,, -79667,,,0.9955575466156006,0.0140057615935802,0.7686545535718703,0.9869741201400756,0.050214298069477,0.2952958390202684,43793.0,0.9861965775489808,0.0535612180829048,0.2857440594827082,43793.0,25949.15015339852,40093.55513715744,25949.15015339852,14138.14640378952,3.837240219116211,0.0 -79700,0.1476574,0.017093666,,,,,,,,,,,,,,,,, -79800,0.14075394,0.017698696,,,,,,,,,,,,,,,,, -79900,0.14836474,0.018259978,,,,,,,,,,,,,,,,, -80000,0.14798301,0.01957451,,,,,,,,,,,,,,,,, -80100,0.15050817,0.019333906,,,,,,,,,,,,,,,,, -80200,0.14057778,0.016954012,,,,,,,,,,,,,,,,, -80300,0.14974996,0.016115328,,,,,,,,,,,,,,,,, -80396,,,0.9955423474311828,0.0139799173921346,0.7752668434431129,0.9869745373725892,0.0502125211060047,0.2944926144825867,43793.0,0.9861940741539,0.0535595826804637,0.2856693514423585,43793.0,26189.18081855774,40454.92825961113,26189.18081855774,14259.4233148098,3.880352020263672,0.0 -80400,0.14583232,0.017587021,,,,,,,,,,,,,,,,, -80500,0.13773596,0.017431365,,,,,,,,,,,,,,,,, -80600,0.14311144,0.017912656,,,,,,,,,,,,,,,,, -80700,0.16680859,0.019752478,,,,,,,,,,,,,,,,, -80800,0.1580746,0.02115896,,,,,,,,,,,,,,,,, -80900,0.14647968,0.01825435,,,,,,,,,,,,,,,,, -81000,0.15228814,0.016912468,,,,,,,,,,,,,,,,, -81100,0.16179442,0.018863255,,,,,,,,,,,,,,,,, -81136,,,0.9955633878707886,0.0139617240056395,0.7864327946995338,0.9869745373725892,0.0502125211060047,0.2945485630598818,43793.0,0.9861940741539,0.0535595826804637,0.2857074698180062,43793.0,26429.1895840168,40816.57488918304,26429.1895840168,14380.997165203094,3.922349452972412,0.0 -81200,0.13772438,0.016109796,,,,,,,,,,,,,,,,, -81300,0.13212004,0.017049195,,,,,,,,,,,,,,,,, -81400,0.12638734,0.016240846,,,,,,,,,,,,,,,,, -81500,0.14400724,0.01699853,,,,,,,,,,,,,,,,, -81600,0.11729857,0.014967352,,,,,,,,,,,,,,,,, -81700,0.15596125,0.019915743,,,,,,,,,,,,,,,,, -81800,0.15404868,0.019361988,,,,,,,,,,,,,,,,, -81870,,,0.995542049407959,0.0140200974419713,0.7701309672792755,0.9869745373725892,0.0502125211060047,0.2944653538357855,43793.0,0.9861940741539,0.0535595789551734,0.2857734941346637,43793.0,26669.127532482147,41178.67870616913,26669.127532482147,14503.09766292572,3.96537709236145,0.0 -81900,0.14657184,0.019003531,,,,,,,,,,,,,,,,, -82000,0.12874587,0.017200885,,,,,,,,,,,,,,,,, -82100,0.14325185,0.017297411,,,,,,,,,,,,,,,,, -82200,0.14652051,0.02010483,,,,,,,,,,,,,,,,, -82300,0.118802145,0.014365148,,,,,,,,,,,,,,,,, -82400,0.1425289,0.018265719,,,,,,,,,,,,,,,,, -82500,0.13945533,0.016992934,,,,,,,,,,,,,,,,, -82600,0.14356598,0.018183239,,,,,,,,,,,,,,,,, -82607,,,0.9955546259880066,0.013937359675765,0.7792414263929057,0.9869743585586548,0.0502125211060047,0.2944700054731899,43793.0,0.9861940741539,0.0535595789551734,0.2856746842297289,43793.0,26909.205247163773,41537.66540026665,26909.205247163773,14621.939532279968,4.011141300201416,0.0 -82700,0.13683988,0.015533246,,,,,,,,,,,,,,,,, -82800,0.13180679,0.017712468,,,,,,,,,,,,,,,,, -82900,0.15622662,0.018430952,,,,,,,,,,,,,,,,, -83000,0.13172051,0.017302254,,,,,,,,,,,,,,,,, -83100,0.16217473,0.019174546,,,,,,,,,,,,,,,,, -83200,0.14208414,0.017773427,,,,,,,,,,,,,,,,, -83300,0.15998316,0.020075375,,,,,,,,,,,,,,,,, -83349,,,0.9955734610557556,0.0139914928004145,0.7649956167611187,0.9869745373725892,0.0502125173807144,0.294621958718315,43793.0,0.9861940741539,0.0535595789551734,0.2857537976988816,43793.0,27149.281101465225,41898.16096329689,27149.281101465225,14742.293938875198,4.054401874542236,0.0 -83400,0.1311655,0.015506164,,,,,,,,,,,,,,,,, -83500,0.16833255,0.018590873,,,,,,,,,,,,,,,,, -83600,0.13919558,0.018619291,,,,,,,,,,,,,,,,, -83700,0.14425825,0.018200576,,,,,,,,,,,,,,,,, -83800,0.1576691,0.01758894,,,,,,,,,,,,,,,,, -83900,0.14523195,0.018699503,,,,,,,,,,,,,,,,, -84000,0.15686886,0.019193325,,,,,,,,,,,,,,,,, -84073,,,0.9955219030380248,0.0140211423859,0.7767582004172939,0.9869745373725892,0.0502125211060047,0.2945195156428599,43793.0,0.9861940741539,0.0535595789551734,0.2857942642529996,43793.0,27389.44238615036,42257.554856061935,27389.44238615036,14861.458558797836,4.097753524780273,0.0 -84100,0.13438614,0.017294599,,,,,,,,,,,,,,,,, -84200,0.14521275,0.020264545,,,,,,,,,,,,,,,,, -84300,0.15598777,0.01882329,,,,,,,,,,,,,,,,, -84400,0.13535306,0.01671736,,,,,,,,,,,,,,,,, -84500,0.15140244,0.019931441,,,,,,,,,,,,,,,,, -84600,0.14362164,0.017940482,,,,,,,,,,,,,,,,, -84700,0.14856142,0.018386444,,,,,,,,,,,,,,,,, -84796,,,0.9955556392669678,0.0139660816639661,0.7768871909447967,0.9869745373725892,0.0502125173807144,0.2944212728334547,43793.0,0.9861940741539,0.0535595826804637,0.2857006953799908,43793.0,27629.67704296112,42614.32751560211,27629.67704296112,14977.927819490433,4.1424174308776855,0.0 -84800,0.15005685,0.01770863,,,,,,,,,,,,,,,,, -84900,0.13564397,0.017787702,,,,,,,,,,,,,,,,, -85000,0.14537825,0.01833497,,,,,,,,,,,,,,,,, -85100,0.13235475,0.017692545,,,,,,,,,,,,,,,,, -85200,0.16629906,0.019831628,,,,,,,,,,,,,,,,, -85300,0.13409859,0.015928226,,,,,,,,,,,,,,,,, -85400,0.136706,0.017647581,,,,,,,,,,,,,,,,, -85500,0.14449885,0.017681008,,,,,,,,,,,,,,,,, -85525,,,0.9955314993858336,0.0140366535633802,0.7768274028723495,0.9869745373725892,0.0502125211060047,0.294535595189946,43793.0,0.9861940741539,0.0535595826804637,0.285688490661133,43793.0,27869.899612665176,42971.96783590317,27869.899612665176,15095.281085968018,4.185272216796875,0.0 -85600,0.1537105,0.020117303,,,,,,,,,,,,,,,,, -85700,0.14625594,0.018392622,,,,,,,,,,,,,,,,, -85800,0.1195018,0.016149018,,,,,,,,,,,,,,,,, -85900,0.15963593,0.01897462,,,,,,,,,,,,,,,,, -86000,0.14114264,0.017556144,,,,,,,,,,,,,,,,, -86100,0.1607461,0.020575423,,,,,,,,,,,,,,,,, -86200,0.13870683,0.016742636,,,,,,,,,,,,,,,,, -86254,,,0.9955706000328064,0.0139378253370523,0.7793773087752416,0.9869743585586548,0.0502125173807144,0.294717817910109,43793.0,0.9861940741539,0.0535595789551734,0.2857650136919014,43793.0,28109.899499177933,43331.08562636376,28109.899499177933,15214.328017950058,4.231759786605835,0.0 -86300,0.145844,0.017584078,,,,,,,,,,,,,,,,, -86400,0.15649195,0.018441193,,,,,,,,,,,,,,,,, -86500,0.15150857,0.01919128,,,,,,,,,,,,,,,,, -86600,0.13507529,0.017491572,,,,,,,,,,,,,,,,, -86700,0.15201953,0.017581059,,,,,,,,,,,,,,,,, -86800,0.15121283,0.017323025,,,,,,,,,,,,,,,,, -86900,0.14249004,0.01858442,,,,,,,,,,,,,,,,, -86979,,,0.995578110218048,0.0139131750911474,0.7689155617309109,0.9869745373725892,0.0502125211060047,0.2945113082026619,43793.0,0.9861940741539,0.0535595789551734,0.2858110313416989,43793.0,28349.91640353203,43691.868216753006,28349.91640353203,15335.024807214735,4.275558471679688,0.0 -87000,0.1400985,0.017615663,,,,,,,,,,,,,,,,, -87100,0.13742305,0.018382803,,,,,,,,,,,,,,,,, -87200,0.14880744,0.018567346,,,,,,,,,,,,,,,,, -87300,0.13926902,0.01781492,,,,,,,,,,,,,,,,, -87400,0.14480421,0.018565774,,,,,,,,,,,,,,,,, -87500,0.14123617,0.017282795,,,,,,,,,,,,,,,,, -87600,0.13665657,0.015367751,,,,,,,,,,,,,,,,, -87700,0.16647129,0.019540789,,,,,,,,,,,,,,,,, -87709,,,0.9954716563224792,0.0142105771228671,0.7767942679630444,0.9869741201400756,0.0502125211060047,0.2945353564006339,43793.0,0.9861940741539,0.0535595789551734,0.2858338221202554,43793.0,28590.017910003666,44050.504854917526,28590.017910003666,15453.49022746086,4.321666479110718,0.0 -87800,0.15539731,0.018635402,,,,,,,,,,,,,,,,, -87900,0.13806348,0.016816568,,,,,,,,,,,,,,,,, -88000,0.15094575,0.017733397,,,,,,,,,,,,,,,,, -88100,0.13653822,0.01642503,,,,,,,,,,,,,,,,, -88200,0.15519586,0.018826734,,,,,,,,,,,,,,,,, -88300,0.14170225,0.017564697,,,,,,,,,,,,,,,,, -88400,0.15324993,0.018964827,,,,,,,,,,,,,,,,, -88433,,,0.9956030249595642,0.0138314524665474,0.7797903015231427,0.9869743585586548,0.0502125211060047,0.294608031294072,43793.0,0.9861940741539,0.0535595789551734,0.2857710486792785,43793.0,28830.172302246094,44413.43903660774,28830.172302246094,15576.203959941864,4.365557670593262,0.0 -88500,0.1260255,0.017083557,,,,,,,,,,,,,,,,, -88600,0.14110534,0.016656177,,,,,,,,,,,,,,,,, -88700,0.1418703,0.015583311,,,,,,,,,,,,,,,,, -88800,0.13789609,0.018062778,,,,,,,,,,,,,,,,, -88900,0.14126317,0.017444136,,,,,,,,,,,,,,,,, -89000,0.14587674,0.018408014,,,,,,,,,,,,,,,,, -89100,0.14148086,0.018468102,,,,,,,,,,,,,,,,, -89165,,,0.995522916316986,0.0140581410378217,0.7773444141831463,0.9869743585586548,0.0502125211060047,0.2945774283025879,43793.0,0.9861940741539,0.0535595789551734,0.2857071438595264,43793.0,29070.12420630455,44770.4546122551,29070.12420630455,15693.201646327972,4.409271240234375,0.0 -89200,0.14909618,0.017943438,,,,,,,,,,,,,,,,, -89300,0.1373836,0.018565055,,,,,,,,,,,,,,,,, -89400,0.13259897,0.016694564,,,,,,,,,,,,,,,,, -89500,0.1239311,0.01716455,,,,,,,,,,,,,,,,, -89600,0.13787684,0.016406849,,,,,,,,,,,,,,,,, -89700,0.1501494,0.017501267,,,,,,,,,,,,,,,,, -89800,0.14316325,0.01841158,,,,,,,,,,,,,,,,, -89897,,,0.9955548048019408,0.0139969484880566,0.7708878094406291,0.9869745373725892,0.0502125211060047,0.2944496519986352,43793.0,0.9861940741539,0.0535595789551734,0.2859636076934878,43793.0,29310.1939008236,45133.53694319725,29310.1939008236,15816.148104906082,4.453778028488159,0.0 -89900,0.14314035,0.017394004,,,,,,,,,,,,,,,,, -90000,0.13905682,0.017882876,,,,,,,,,,,,,,,,, -90100,0.13895868,0.016624097,,,,,,,,,,,,,,,,, -90200,0.1472208,0.019031549,,,,,,,,,,,,,,,,, -90300,0.13928136,0.017915923,,,,,,,,,,,,,,,,, -90400,0.14162408,0.017037911,,,,,,,,,,,,,,,,, -90500,0.15085235,0.016417226,,,,,,,,,,,,,,,,, -90600,0.1529367,0.018685754,,,,,,,,,,,,,,,,, -90633,,,0.9955064058303832,0.0140866888687014,0.7742396649486707,0.9869745373725892,0.0502125211060047,0.2945673298006516,43793.0,0.9861940741539,0.0535595789551734,0.2857039443897824,43793.0,29550.39145565033,45491.60465455055,29550.39145565033,15933.94653224945,4.502529859542847,0.0 -90700,0.12908117,0.016875254,,,,,,,,,,,,,,,,, -90800,0.14167856,0.017600222,,,,,,,,,,,,,,,,, -90900,0.14517169,0.01839392,,,,,,,,,,,,,,,,, -91000,0.16338086,0.017604643,,,,,,,,,,,,,,,,, -91100,0.14432766,0.018153168,,,,,,,,,,,,,,,,, -91200,0.13502102,0.016712658,,,,,,,,,,,,,,,,, -91300,0.15350655,0.01976416,,,,,,,,,,,,,,,,, -91370,,,0.9955812096595764,0.0139295449480414,0.7740129386936662,0.9869745373725892,0.0502125211060047,0.294608725721259,43793.0,0.9861940741539,0.0535595789551734,0.2856685747868243,43793.0,29790.36308145523,45846.48057794571,29790.36308145523,16048.78459572792,4.54682207107544,0.0 -91400,0.1624044,0.017840913,,,,,,,,,,,,,,,,, -91500,0.17155135,0.019217482,,,,,,,,,,,,,,,,, -91600,0.14225489,0.018508371,,,,,,,,,,,,,,,,, -91700,0.1568215,0.015327036,,,,,,,,,,,,,,,,, -91800,0.14253095,0.016701607,,,,,,,,,,,,,,,,, -91900,0.13944575,0.016080698,,,,,,,,,,,,,,,,, -92000,0.15369767,0.019362811,,,,,,,,,,,,,,,,, -92100,0.1375396,0.016872136,,,,,,,,,,,,,,,,, -92104,,,0.9955791234970092,0.0138949332758784,0.7753812512261653,0.9869743585586548,0.0502125211060047,0.2945517697474837,43793.0,0.9861940741539,0.0535595789551734,0.2859116347293922,43793.0,30030.31191968918,46200.85204720497,30030.31191968918,16163.1401386261,4.591657400131226,0.0 -92200,0.14223346,0.017003227,,,,,,,,,,,,,,,,, -92300,0.14579144,0.018494602,,,,,,,,,,,,,,,,, -92400,0.14507224,0.017198814,,,,,,,,,,,,,,,,, -92500,0.13744149,0.017482717,,,,,,,,,,,,,,,,, -92600,0.15527384,0.018458331,,,,,,,,,,,,,,,,, -92700,0.1394995,0.015987244,,,,,,,,,,,,,,,,, -92800,0.12884298,0.015778664,,,,,,,,,,,,,,,,, -92839,,,0.9955458641052246,0.0140043785795569,0.7818500732533857,0.9869745373725892,0.0502125211060047,0.2944538656084857,43793.0,0.9861940741539,0.0535595789551734,0.2857860262767945,43793.0,30270.42789888382,46555.93433070183,30270.42789888382,16278.03237748146,4.639986991882324,0.0 -92900,0.14212027,0.018176682,,,,,,,,,,,,,,,,, -93000,0.12945059,0.01895019,,,,,,,,,,,,,,,,, -93100,0.15295692,0.01799337,,,,,,,,,,,,,,,,, -93200,0.14463623,0.019155713,,,,,,,,,,,,,,,,, -93300,0.14283223,0.01590023,,,,,,,,,,,,,,,,, -93400,0.14594507,0.016171584,,,,,,,,,,,,,,,,, -93500,0.15488902,0.017418578,,,,,,,,,,,,,,,,, -93570,,,0.9955337047576904,0.0139936245977878,0.7719343342136304,0.9869745373725892,0.0502125173807144,0.2945561478444909,43793.0,0.9861940741539,0.0535595826804637,0.2858384018367464,43793.0,30510.66025996208,46920.87311458588,30510.66025996208,16402.670434951782,4.685810327529907,0.0 -93600,0.13109511,0.017460275,,,,,,,,,,,,,,,,, -93700,0.13999993,0.016830808,,,,,,,,,,,,,,,,, -93800,0.15392639,0.018518692,,,,,,,,,,,,,,,,, -93900,0.13073497,0.016617749,,,,,,,,,,,,,,,,, -94000,0.1312016,0.017125104,,,,,,,,,,,,,,,,, -94100,0.14847666,0.017388785,,,,,,,,,,,,,,,,, -94200,0.14202711,0.020031631,,,,,,,,,,,,,,,,, -94300,0.14054012,0.016836202,,,,,,,,,,,,,,,,, -94307,,,0.9955654740333556,0.0139470482245087,0.7784315567150579,0.9869745373725892,0.0502125211060047,0.294558602517844,43793.0,0.9861940741539,0.0535595826804637,0.2857131922781328,43793.0,30750.85684776306,47274.11804151535,30750.85684776306,16515.649873495102,4.73239541053772,0.0 -94400,0.14206092,0.018390456,,,,,,,,,,,,,,,,, -94500,0.15609841,0.019489983,,,,,,,,,,,,,,,,, -94600,0.14552324,0.016669244,,,,,,,,,,,,,,,,, -94700,0.15260151,0.018345155,,,,,,,,,,,,,,,,, -94800,0.14714743,0.01941952,,,,,,,,,,,,,,,,, -94900,0.13541347,0.017634591,,,,,,,,,,,,,,,,, -95000,0.14975272,0.018066453,,,,,,,,,,,,,,,,, -95048,,,0.9955480098724364,0.0140321403741836,0.7664942574340686,0.9869745373725892,0.0502125211060047,0.2945165792896103,43793.0,0.9861940741539,0.0535595826804637,0.2858839892161938,43793.0,30990.98055648804,47628.29944491386,30990.98055648804,16629.640988588333,4.777262449264526,0.0 -95100,0.14378366,0.020402135,,,,,,,,,,,,,,,,, -95200,0.12835376,0.01716117,,,,,,,,,,,,,,,,, -95300,0.1349557,0.017713718,,,,,,,,,,,,,,,,, -95400,0.15125792,0.020537117,,,,,,,,,,,,,,,,, -95500,0.13723023,0.017587794,,,,,,,,,,,,,,,,, -95600,0.1459495,0.018650075,,,,,,,,,,,,,,,,, -95700,0.13085896,0.016742313,,,,,,,,,,,,,,,,, -95782,,,0.9955334663391112,0.0140433926135301,0.7828414672130524,0.9869745373725892,0.0502125211060047,0.294550277044121,43793.0,0.9861940741539,0.0535595826804637,0.2858748191382337,43793.0,31231.16896510124,47984.0223646164,31231.16896510124,16745.10850763321,4.822614192962647,0.0 -95800,0.14881234,0.018829169,,,,,,,,,,,,,,,,, -95900,0.14267442,0.01612914,,,,,,,,,,,,,,,,, -96000,0.14770007,0.018972395,,,,,,,,,,,,,,,,, -96100,0.13697712,0.017781049,,,,,,,,,,,,,,,,, -96200,0.14665182,0.018364016,,,,,,,,,,,,,,,,, -96300,0.14849596,0.019175615,,,,,,,,,,,,,,,,, -96400,0.15685974,0.020692475,,,,,,,,,,,,,,,,, -96500,0.16730441,0.020333694,,,,,,,,,,,,,,,,, -96521,,,0.9955517649650574,0.0139227332547307,0.7761237102647274,0.9869745373725892,0.0502125173807144,0.2945430188510121,43793.0,0.9861940741539,0.0535595789551734,0.2857925861415153,43793.0,31471.37131333351,48345.87097764015,31471.37131333351,16866.686855316162,4.869102954864502,0.0 -96600,0.14679521,0.018156016,,,,,,,,,,,,,,,,, -96700,0.12939134,0.017217252,,,,,,,,,,,,,,,,, -96800,0.14795442,0.019014893,,,,,,,,,,,,,,,,, -96900,0.15581673,0.016674222,,,,,,,,,,,,,,,,, -97000,0.16157544,0.017596385,,,,,,,,,,,,,,,,, -97100,0.13961023,0.018161608,,,,,,,,,,,,,,,,, -97200,0.13874386,0.017317116,,,,,,,,,,,,,,,,, -97252,,,0.9955731630325316,0.0139761669561266,0.7844379211332362,0.9869743585586548,0.0502125211060047,0.2947258707045029,43793.0,0.9861940741539,0.0535595826804637,0.2856464496291582,43793.0,31711.512481689453,48699.4358150959,31711.512481689453,16980.036007642746,4.920137643814087,0.0 -97300,0.12736979,0.016792191,,,,,,,,,,,,,,,,, -97400,0.1456626,0.01877001,,,,,,,,,,,,,,,,, -97500,0.13804297,0.018882034,,,,,,,,,,,,,,,,, -97600,0.16365601,0.017603949,,,,,,,,,,,,,,,,, -97700,0.15426317,0.019262053,,,,,,,,,,,,,,,,, -97800,0.13203892,0.016993318,,,,,,,,,,,,,,,,, -97900,0.15078056,0.019216591,,,,,,,,,,,,,,,,, -97989,,,0.9955433011054992,0.0139860603958368,0.7620940939206702,0.9869745373725892,0.0502125211060047,0.2945639696575376,43793.0,0.9861940741539,0.0535595789551734,0.2857639257450596,43793.0,31951.521282196045,49053.14595079422,31951.521282196045,17093.668981790543,4.96671462059021,0.0 -98000,0.12598369,0.015687142,,,,,,,,,,,,,,,,, -98100,0.14456774,0.015958294,,,,,,,,,,,,,,,,, -98200,0.14228229,0.01678196,,,,,,,,,,,,,,,,, -98300,0.1524563,0.01990494,,,,,,,,,,,,,,,,, -98400,0.14483011,0.017330403,,,,,,,,,,,,,,,,, -98500,0.13794464,0.018699322,,,,,,,,,,,,,,,,, -98600,0.14944148,0.019263698,,,,,,,,,,,,,,,,, -98700,0.15647553,0.01897283,,,,,,,,,,,,,,,,, -98734,,,0.9955700635910034,0.0139950010925531,0.7774484806299122,0.9869745373725892,0.0502125211060047,0.2945178067125195,43793.0,0.9861940741539,0.0535595826804637,0.2856923450793034,43793.0,32191.674822330475,49409.0846850872,32191.674822330475,17209.38709139824,5.011778116226196,0.0 -98800,0.14060417,0.018716829,,,,,,,,,,,,,,,,, -98900,0.1333416,0.015220787,,,,,,,,,,,,,,,,, -99000,0.15379316,0.017379766,,,,,,,,,,,,,,,,, -99100,0.13949384,0.016277222,,,,,,,,,,,,,,,,, -99200,0.13880853,0.01903445,,,,,,,,,,,,,,,,, -99300,0.14574689,0.016558466,,,,,,,,,,,,,,,,, -99400,0.12272079,0.015702724,,,,,,,,,,,,,,,,, -99472,,,0.9954909086227416,0.0140851242467761,0.7770188283552418,0.9869743585586548,0.0502125211060047,0.2945293023106267,43793.0,0.9861940741539,0.0535595826804637,0.2858091196595225,43793.0,32431.794410705566,49763.94093894959,32431.794410705566,17324.055718421936,5.057526350021362,0.0 -99500,0.1754588,0.01833049,,,,,,,,,,,,,,,,, -99600,0.14992994,0.018173996,,,,,,,,,,,,,,,,, -99700,0.14017747,0.015469441,,,,,,,,,,,,,,,,, -99800,0.14005758,0.016500615,,,,,,,,,,,,,,,,, -99900,0.145254,0.017699568,,,,,,,,,,,,,,,,, -100000,0.15319079,0.017473537,,,,,,,,,,,,,,,,, -100100,0.14009912,0.015441859,,,,,,,,,,,,,,,,, -100200,0.13872467,0.017294688,,,,,,,,,,,,,,,,, -100212,,,0.9955756664276124,0.0139365773648023,0.7749262622194402,0.9869743585586548,0.0502125211060047,0.2944989993784632,43793.0,0.9861940741539,0.0535595826804637,0.2857367479087632,43793.0,32671.83607316017,50125.02400946617,32671.83607316017,17445.024650096893,5.107828140258789,0.0 -100300,0.14809082,0.019456077,,,,,,,,,,,,,,,,, -100400,0.16963628,0.020043885,,,,,,,,,,,,,,,,, -100500,0.13519597,0.016959941,,,,,,,,,,,,,,,,, -100600,0.14833073,0.018491926,,,,,,,,,,,,,,,,, -100700,0.18981344,0.017218692,,,,,,,,,,,,,,,,, -100800,0.14224903,0.016561752,,,,,,,,,,,,,,,,, -100900,0.14590335,0.017638396,,,,,,,,,,,,,,,,, -100950,,,0.99558025598526,0.0139152826741337,0.7891355115019367,0.9869745373725892,0.0502125211060047,0.2946131160601966,43793.0,0.9861940741539,0.0535595789551734,0.2858639553246325,43793.0,32911.79198694229,50479.322907447815,32911.79198694229,17559.29871249199,5.153444766998291,0.0 -101000,0.15250978,0.019468427,,,,,,,,,,,,,,,,, -101100,0.13687739,0.01764376,,,,,,,,,,,,,,,,, -101200,0.124157846,0.015767327,,,,,,,,,,,,,,,,, -101300,0.15137571,0.018319475,,,,,,,,,,,,,,,,, -101400,0.14284182,0.016759513,,,,,,,,,,,,,,,,, -101500,0.1451126,0.017014403,,,,,,,,,,,,,,,,, -101600,0.14237887,0.018728686,,,,,,,,,,,,,,,,, -101684,,,0.9955013990402222,0.0140962926670908,0.7691005217213778,0.9869745373725892,0.0502125173807144,0.2944360238053917,43793.0,0.9861940741539,0.0535595789551734,0.2857212204469937,43793.0,33151.745717048645,50839.55113840103,33151.745717048645,17679.49751996994,5.205646276473999,0.0 -101700,0.13960083,0.017166663,,,,,,,,,,,,,,,,, -101800,0.13777198,0.01657645,,,,,,,,,,,,,,,,, -101900,0.14943406,0.019177122,,,,,,,,,,,,,,,,, -102000,0.14773573,0.01985594,,,,,,,,,,,,,,,,, -102100,0.14023559,0.016320024,,,,,,,,,,,,,,,,, -102200,0.14110205,0.018134093,,,,,,,,,,,,,,,,, -102300,0.1451368,0.019184593,,,,,,,,,,,,,,,,, -102400,0.15224993,0.01935825,,,,,,,,,,,,,,,,, -102411,,,0.9955734014511108,0.0139194196090102,0.7724062579147245,0.9869743585586548,0.0502125211060047,0.2945845273702825,43793.0,0.9861940741539,0.0535595826804637,0.2857138954593107,43793.0,33391.99066519737,51194.32950162888,33391.99066519737,17793.961818933487,5.252429485321045,0.0 -102500,0.1502745,0.019263491,,,,,,,,,,,,,,,,, -102600,0.13567734,0.017489132,,,,,,,,,,,,,,,,, -102700,0.16278464,0.020663228,,,,,,,,,,,,,,,,, -102800,0.14337981,0.019684695,,,,,,,,,,,,,,,,, -102900,0.13653369,0.019016529,,,,,,,,,,,,,,,,, -103000,0.15090346,0.019279154,,,,,,,,,,,,,,,,, -103100,0.17452446,0.0190255,,,,,,,,,,,,,,,,, -103146,,,0.9955319166183472,0.0140671050176024,0.7735482801034659,0.9869745373725892,0.0502125173807144,0.2945610976385742,43793.0,0.9861940741539,0.0535595826804637,0.2856564410760763,43793.0,33632.18942737579,51548.698831796646,33632.18942737579,17908.058881998062,5.303457260131836,0.0 -103200,0.14888358,0.019725164,,,,,,,,,,,,,,,,, -103300,0.13708028,0.017182106,,,,,,,,,,,,,,,,, -103400,0.14522006,0.019045684,,,,,,,,,,,,,,,,, -103500,0.15666695,0.018662537,,,,,,,,,,,,,,,,, -103600,0.14746714,0.017738612,,,,,,,,,,,,,,,,, -103700,0.13426006,0.017798975,,,,,,,,,,,,,,,,, -103800,0.1611024,0.020216106,,,,,,,,,,,,,,,,, -103893,,,0.9955902695655824,0.0138916801661252,0.7762755546021376,0.9869745373725892,0.0502125211060047,0.2945026049765142,43793.0,0.9861940741539,0.0535595826804637,0.2856985312715242,43793.0,33872.21159863472,51898.09083795548,33872.21159863472,18017.360858678818,5.349218845367432,0.0 -103900,0.14941546,0.019116573,,,,,,,,,,,,,,,,, -104000,0.13817795,0.01750579,,,,,,,,,,,,,,,,, -104100,0.1418086,0.017709244,,,,,,,,,,,,,,,,, -104200,0.13999157,0.017776527,,,,,,,,,,,,,,,,, -104300,0.14444964,0.016345335,,,,,,,,,,,,,,,,, -104400,0.14695269,0.017301904,,,,,,,,,,,,,,,,, -104500,0.15647973,0.020194888,,,,,,,,,,,,,,,,, -104600,0.13779044,0.01781628,,,,,,,,,,,,,,,,, -104640,,,0.9955270886421204,0.0139942103996872,0.7827216048150842,0.9869743585586548,0.0502125211060047,0.2945389773860817,43793.0,0.9861940741539,0.0535595826804637,0.2858397657044168,43793.0,34112.31607079506,52256.33351922035,34112.31607079506,18135.431131839752,5.395155668258667,0.0 -104700,0.15134655,0.01882622,,,,,,,,,,,,,,,,, -104800,0.14656219,0.018103806,,,,,,,,,,,,,,,,, -104900,0.14620614,0.0173529,,,,,,,,,,,,,,,,, -105000,0.15671614,0.019984536,,,,,,,,,,,,,,,,, -105100,0.16658415,0.021855725,,,,,,,,,,,,,,,,, -105200,0.13962047,0.016618138,,,,,,,,,,,,,,,,, -105300,0.14912753,0.019589959,,,,,,,,,,,,,,,,, -105384,,,0.9955507516860962,0.0140451323240995,0.7694572761821414,0.9869745373725892,0.0502125173807144,0.2945578506109468,43793.0,0.9861940741539,0.0535595826804637,0.285678615368202,43793.0,34352.42887854576,52615.46451854706,34352.42887854576,18254.377835989,5.444571256637573,0.0 -105400,0.14286,0.018374817,,,,,,,,,,,,,,,,, -105500,0.14953803,0.01999061,,,,,,,,,,,,,,,,, -105600,0.1306563,0.017384613,,,,,,,,,,,,,,,,, -105700,0.1267954,0.01522032,,,,,,,,,,,,,,,,, -105800,0.14707386,0.018749192,,,,,,,,,,,,,,,,, -105900,0.14037593,0.016980885,,,,,,,,,,,,,,,,, -106000,0.1653398,0.01766578,,,,,,,,,,,,,,,,, -106100,0.14208747,0.017318193,,,,,,,,,,,,,,,,, -106117,,,0.995536208152771,0.0139791183173656,0.7786668022825228,0.9869743585586548,0.0502125173807144,0.2945464250325268,43793.0,0.9861940741539,0.0535595826804637,0.2856916424811894,43793.0,34592.64105916023,52971.15380167961,34592.64105916023,18369.77789402008,5.496541738510132,0.0 -106200,0.14418179,0.017675877,,,,,,,,,,,,,,,,, -106300,0.13915828,0.017886342,,,,,,,,,,,,,,,,, -106400,0.14593057,0.01886383,,,,,,,,,,,,,,,,, -106500,0.16071436,0.019421695,,,,,,,,,,,,,,,,, -106600,0.1675491,0.018110016,,,,,,,,,,,,,,,,, -106700,0.13875124,0.015698215,,,,,,,,,,,,,,,,, -106800,0.14879577,0.019005973,,,,,,,,,,,,,,,,, -106851,,,0.9955457448959352,0.0140038868412375,0.7663039849163852,0.9869745373725892,0.0502125173807144,0.2945365314046178,43793.0,0.9861940741539,0.0535595789551734,0.2856666989155942,43793.0,34832.609432935715,53327.67676925659,34832.609432935715,18486.26060438156,5.544113397598267,0.0 -106900,0.14734232,0.0163514,,,,,,,,,,,,,,,,, -107000,0.13884781,0.019069318,,,,,,,,,,,,,,,,, -107100,0.13595703,0.017410962,,,,,,,,,,,,,,,,, -107200,0.14657468,0.018520037,,,,,,,,,,,,,,,,, -107300,0.14704189,0.01872991,,,,,,,,,,,,,,,,, -107400,0.15531379,0.019088786,,,,,,,,,,,,,,,,, -107500,0.1396258,0.015279119,,,,,,,,,,,,,,,,, -107596,,,0.9955405592918396,0.0139972921460866,0.7824891126363962,0.9869745373725892,0.0502125211060047,0.2945553946251877,43793.0,0.9861940741539,0.0535595789551734,0.285848566255761,43793.0,35072.64724302292,53679.93759918213,35072.64724302292,18598.39912390709,5.606086492538452,0.0 -107600,0.15257649,0.020021025,,,,,,,,,,,,,,,,, -107700,0.1489027,0.019988788,,,,,,,,,,,,,,,,, -107800,0.14991301,0.019881466,,,,,,,,,,,,,,,,, -107900,0.15964925,0.019356918,,,,,,,,,,,,,,,,, -108000,0.13646798,0.01726598,,,,,,,,,,,,,,,,, -108100,0.17057167,0.017232358,,,,,,,,,,,,,,,,, -108200,0.16055207,0.020375794,,,,,,,,,,,,,,,,, -108300,0.13724191,0.01752752,,,,,,,,,,,,,,,,, -108330,,,0.9955700039863586,0.013952019624412,0.780402457785562,0.9869743585586548,0.0502125211060047,0.294552374588472,43793.0,0.9861940741539,0.0535595826804637,0.2856731132602263,43793.0,35312.800602436066,54031.39209771156,35312.800602436066,18709.627541542053,5.654670476913452,0.0 -108400,0.1591808,0.018332416,,,,,,,,,,,,,,,,, -108500,0.14394763,0.018782323,,,,,,,,,,,,,,,,, -108600,0.1403857,0.016053945,,,,,,,,,,,,,,,,, -108700,0.13957402,0.018548915,,,,,,,,,,,,,,,,, -108800,0.15228371,0.018368548,,,,,,,,,,,,,,,,, -108900,0.1483225,0.020892981,,,,,,,,,,,,,,,,, -109000,0.14384213,0.01841516,,,,,,,,,,,,,,,,, -109073,,,0.9955811500549316,0.0138882901519536,0.7799591160022632,0.9869745373725892,0.0502125211060047,0.2944818855644533,43793.0,0.9861940741539,0.0535595789551734,0.2857106592316629,43793.0,35552.739119291306,54381.56074023247,35552.739119291306,18819.788346767426,5.7019219398498535,0.0 -109100,0.13996992,0.017081335,,,,,,,,,,,,,,,,, -109200,0.1382089,0.017348638,,,,,,,,,,,,,,,,, -109300,0.14230281,0.01826126,,,,,,,,,,,,,,,,, -109400,0.14793895,0.017995888,,,,,,,,,,,,,,,,, -109500,0.14538221,0.017181586,,,,,,,,,,,,,,,,, -109600,0.12930152,0.01698414,,,,,,,,,,,,,,,,, -109700,0.16156657,0.01967996,,,,,,,,,,,,,,,,, -109800,0.1448494,0.0177432,,,,,,,,,,,,,,,,, -109816,,,0.9955028891563416,0.0141618028283119,0.7691131012284711,0.9869745373725892,0.0502125211060047,0.2945828852469386,43793.0,0.9861940741539,0.0535595826804637,0.2857457854232376,43793.0,35792.734656095505,54736.71452140808,35792.734656095505,18934.87672638893,5.749660015106201,0.0 -109900,0.14746383,0.018125907,,,,,,,,,,,,,,,,, -110000,0.139366,0.0192119,,,,,,,,,,,,,,,,, -110100,0.14451236,0.017782554,,,,,,,,,,,,,,,,, -110200,0.138328,0.016629675,,,,,,,,,,,,,,,,, -110300,0.16349731,0.019313741,,,,,,,,,,,,,,,,, -110400,0.14538531,0.016771164,,,,,,,,,,,,,,,,, -110500,0.15573712,0.019568708,,,,,,,,,,,,,,,,, -110558,,,0.995578408241272,0.0138934273272752,0.7705730961172595,0.9869743585586548,0.0502125173807144,0.294555251483837,43793.0,0.9861940741539,0.0535595826804637,0.2856821056185651,43793.0,36032.70607757568,55088.55318880081,36032.70607757568,19046.673018455505,5.798440217971802,0.0 -110600,0.1474163,0.017187567,,,,,,,,,,,,,,,,, -110700,0.14412344,0.016649773,,,,,,,,,,,,,,,,, -110800,0.15437534,0.019908192,,,,,,,,,,,,,,,,, -110900,0.13553092,0.01620957,,,,,,,,,,,,,,,,, -111000,0.16388623,0.020000242,,,,,,,,,,,,,,,,, -111100,0.14728764,0.016624408,,,,,,,,,,,,,,,,, -111200,0.1469081,0.016825292,,,,,,,,,,,,,,,,, -111293,,,0.9955031871795654,0.0141362929716706,0.7735100883177448,0.9869741201400756,0.0502125211060047,0.2944882404998152,43793.0,0.9861940741539,0.0535595789551734,0.2857235851747701,43793.0,36272.84723806381,55443.536256074905,36272.84723806381,19161.44490623474,5.846911668777466,0.0 -111300,0.1416254,0.019124012,,,,,,,,,,,,,,,,, -111400,0.14105779,0.016848497,,,,,,,,,,,,,,,,, -111500,0.15492032,0.01844863,,,,,,,,,,,,,,,,, -111600,0.13393788,0.017480599,,,,,,,,,,,,,,,,, -111700,0.14316426,0.017032053,,,,,,,,,,,,,,,,, -111800,0.15288869,0.018905768,,,,,,,,,,,,,,,,, -111900,0.14179371,0.017633526,,,,,,,,,,,,,,,,, -112000,0.14678085,0.019519359,,,,,,,,,,,,,,,,, -112032,,,0.9955770969390868,0.0138916047289967,0.7804736794685526,0.9869745373725892,0.0502125211060047,0.2945283101312911,43793.0,0.9861940741539,0.0535595826804637,0.2856990657290821,43793.0,36513.03610134125,55792.12286829949,36513.03610134125,19269.7726829052,5.894815921783447,0.0 -112100,0.14066839,0.019270986,,,,,,,,,,,,,,,,, -112200,0.14029856,0.0175893,,,,,,,,,,,,,,,,, -112300,0.16887417,0.019101977,,,,,,,,,,,,,,,,, -112400,0.13766971,0.019005898,,,,,,,,,,,,,,,,, -112500,0.13384101,0.0175664,,,,,,,,,,,,,,,,, -112600,0.1295339,0.017042603,,,,,,,,,,,,,,,,, -112700,0.15489705,0.020423561,,,,,,,,,,,,,,,,, -112781,,,0.99552184343338,0.0140519654378294,0.7811979990159192,0.9869745373725892,0.0502125211060047,0.2946104591043645,43793.0,0.9861940741539,0.0535595789551734,0.2857793929004372,43793.0,36753.05175638199,56148.46195721626,36753.05175638199,19386.026316165924,5.942633867263794,0.0 -112800,0.1464018,0.016223298,,,,,,,,,,,,,,,,, -112900,0.16851208,0.021090563,,,,,,,,,,,,,,,,, -113000,0.14109837,0.01664162,,,,,,,,,,,,,,,,, -113100,0.1633365,0.017642392,,,,,,,,,,,,,,,,, -113200,0.13264644,0.015722783,,,,,,,,,,,,,,,,, -113300,0.15225375,0.019743932,,,,,,,,,,,,,,,,, -113400,0.15308866,0.018366914,,,,,,,,,,,,,,,,, -113500,0.14992481,0.018647054,,,,,,,,,,,,,,,,, -113520,,,0.9955678582191468,0.0139427185058593,0.7712682210840969,0.9869741201400756,0.0502125173807144,0.2946619508789923,43793.0,0.9861940741539,0.0535595826804637,0.2856983237546496,43793.0,36993.05020284653,56506.45134592056,36993.05020284653,19503.940582990646,5.9964470863342285,0.0 -113600,0.15792044,0.018563708,,,,,,,,,,,,,,,,, -113700,0.15322769,0.018958848,,,,,,,,,,,,,,,,, -113800,0.1531473,0.018561305,,,,,,,,,,,,,,,,, -113900,0.16720067,0.018846026,,,,,,,,,,,,,,,,, -114000,0.14627305,0.018081732,,,,,,,,,,,,,,,,, -114100,0.14404626,0.017725622,,,,,,,,,,,,,,,,, -114200,0.14029218,0.016726613,,,,,,,,,,,,,,,,, -114247,,,0.9955666661262512,0.013914069160819,0.7792822577149576,0.9869745373725892,0.0502125211060047,0.2945101564142774,43793.0,0.9861940741539,0.0535595789551734,0.2857388313861314,43793.0,37233.1000483036,56856.48850417137,37233.1000483036,19613.848816156387,6.050333261489868,0.0 -114300,0.14730304,0.017572396,,,,,,,,,,,,,,,,, -114400,0.13551606,0.017305512,,,,,,,,,,,,,,,,, -114500,0.1538344,0.01920894,,,,,,,,,,,,,,,,, -114600,0.13593176,0.016674504,,,,,,,,,,,,,,,,, -114700,0.13456784,0.017033363,,,,,,,,,,,,,,,,, -114800,0.17048599,0.019218853,,,,,,,,,,,,,,,,, -114900,0.1380572,0.01882583,,,,,,,,,,,,,,,,, -114989,,,0.9955500960350036,0.0140851167961955,0.7752664714391566,0.9869743585586548,0.0502125173807144,0.2945460843140218,43793.0,0.9861940741539,0.0535595826804637,0.2858201532068772,43793.0,37473.09302973747,57209.81217265129,37473.09302973747,19727.10418367386,6.103281021118164,0.0 -115000,0.13781103,0.015568863,,,,,,,,,,,,,,,,, -115100,0.122345984,0.016433777,,,,,,,,,,,,,,,,, -115200,0.13405572,0.015401387,,,,,,,,,,,,,,,,, -115300,0.14850658,0.018183125,,,,,,,,,,,,,,,,, -115400,0.158618,0.017443594,,,,,,,,,,,,,,,,, -115500,0.14972006,0.017642934,,,,,,,,,,,,,,,,, -115600,0.1547424,0.018800503,,,,,,,,,,,,,,,,, -115700,0.15370697,0.019886049,,,,,,,,,,,,,,,,, -115723,,,0.9955061078071594,0.0140688307583332,0.7755653651667929,0.9869743585586548,0.0502125211060047,0.2945540190760538,43793.0,0.9861940741539,0.0535595789551734,0.285702221607734,43793.0,37713.152237176895,57564.23022627831,37713.152237176895,19841.389622211456,6.152727365493774,0.0 -115800,0.12987494,0.016094688,,,,,,,,,,,,,,,,, -115900,0.13602494,0.01571185,,,,,,,,,,,,,,,,, -116000,0.1393588,0.018834766,,,,,,,,,,,,,,,,, -116100,0.13234852,0.016446505,,,,,,,,,,,,,,,,, -116200,0.14414689,0.01920492,,,,,,,,,,,,,,,,, -116300,0.1402466,0.017712845,,,,,,,,,,,,,,,,, -116400,0.15321937,0.017076382,,,,,,,,,,,,,,,,, -116462,,,0.9955783486366272,0.0138890761882066,0.7852735330342986,0.9869745373725892,0.0502125211060047,0.2944726678951031,43793.0,0.9861940741539,0.0535595826804637,0.2856991776358954,43793.0,37953.37311530113,57917.61430335045,37953.37311530113,19954.476145982742,6.206709146499634,0.0 -116500,0.14973848,0.019115722,,,,,,,,,,,,,,,,, -116600,0.17447802,0.018386204,,,,,,,,,,,,,,,,, -116700,0.1327301,0.017142441,,,,,,,,,,,,,,,,, -116800,0.13715236,0.016061908,,,,,,,,,,,,,,,,, -116900,0.1448869,0.020179318,,,,,,,,,,,,,,,,, -117000,0.1493989,0.017998643,,,,,,,,,,,,,,,,, -117100,0.13664155,0.015738225,,,,,,,,,,,,,,,,, -117193,,,0.995569348335266,0.0139385117217898,0.7685402304600731,0.9869745373725892,0.0502125211060047,0.2944315252732025,43793.0,0.9861940741539,0.0535595789551734,0.2857992774557656,43793.0,38193.32923769951,58270.33309483528,38193.32923769951,20067.16177225113,6.261343479156494,0.0 -117200,0.14349665,0.020136544,,,,,,,,,,,,,,,,, -117300,0.13861103,0.016744448,,,,,,,,,,,,,,,,, -117400,0.15389359,0.020386305,,,,,,,,,,,,,,,,, -117500,0.17305781,0.01918975,,,,,,,,,,,,,,,,, -117600,0.14132254,0.017694172,,,,,,,,,,,,,,,,, -117700,0.14425132,0.01720466,,,,,,,,,,,,,,,,, -117800,0.14939585,0.017920554,,,,,,,,,,,,,,,,, -117900,0.1270295,0.015888073,,,,,,,,,,,,,,,,, -117935,,,0.9955466389656068,0.0140097755938768,0.778482909552098,0.9869745373725892,0.0502125211060047,0.2945424307965914,43793.0,0.9861940741539,0.0535595789551734,0.2858286062184439,43793.0,38433.5650575161,58622.57953858376,38433.5650575161,20179.09550833702,6.315385818481445,0.0 -118000,0.161215,0.020353029,,,,,,,,,,,,,,,,, -118100,0.13552335,0.018099342,,,,,,,,,,,,,,,,, -118200,0.13763662,0.018335667,,,,,,,,,,,,,,,,, -118300,0.15325314,0.017935907,,,,,,,,,,,,,,,,, -118400,0.13177875,0.017961426,,,,,,,,,,,,,,,,, -118500,0.13419887,0.018295757,,,,,,,,,,,,,,,,, -118600,0.13957103,0.015754916,,,,,,,,,,,,,,,,, -118679,,,0.9955329895019532,0.0140114072710275,0.76781053698901,0.9869745373725892,0.0502125211060047,0.2944648560792302,43793.0,0.9861940741539,0.0535595789551734,0.2856903053998522,43793.0,38673.81978392601,58971.44974565506,38673.81978392601,20287.63827586174,6.3654186725616455,0.0 -118700,0.14391242,0.01796652,,,,,,,,,,,,,,,,, -118800,0.12690377,0.01662938,,,,,,,,,,,,,,,,, -118900,0.1539389,0.017340263,,,,,,,,,,,,,,,,, -119000,0.14110005,0.017156336,,,,,,,,,,,,,,,,, -119100,0.14041309,0.016860036,,,,,,,,,,,,,,,,, -119200,0.14125405,0.016975993,,,,,,,,,,,,,,,,, -119300,0.1522196,0.02081292,,,,,,,,,,,,,,,,, -119400,0.15345252,0.017600033,,,,,,,,,,,,,,,,, -119409,,,0.9955406188964844,0.013977606780827,0.7782823868286457,0.9869743585586548,0.0502125211060047,0.2945175274820147,43793.0,0.9861940741539,0.0535595826804637,0.285719341842257,43793.0,38913.83612418175,59325.26311969757,38913.83612418175,20401.36284303665,6.414927959442139,0.0 -119500,0.12804838,0.015732694,,,,,,,,,,,,,,,,, -119600,0.15555866,0.020538857,,,,,,,,,,,,,,,,, -119700,0.13019061,0.0152604375,,,,,,,,,,,,,,,,, -119800,0.14105612,0.016329858,,,,,,,,,,,,,,,,, -119900,0.14199904,0.016274095,,,,,,,,,,,,,,,,, -120000,0.14695445,0.017953651,,,,,,,,,,,,,,,,, -120100,0.13644813,0.017803196,,,,,,,,,,,,,,,,, -120140,,,0.9955416321754456,0.013997571542859,0.7779783333646086,0.9869745373725892,0.0502125211060047,0.2944633500683524,43793.0,0.9861940741539,0.0535595826804637,0.2857392894248236,43793.0,39153.818912267685,59679.84038281441,39153.818912267685,20515.88113284111,6.468688011169434,0.0 -120200,0.15081944,0.01855357,,,,,,,,,,,,,,,,, -120300,0.14212222,0.015977496,,,,,,,,,,,,,,,,, -120400,0.14511076,0.019142173,,,,,,,,,,,,,,,,, -120500,0.15517008,0.018263368,,,,,,,,,,,,,,,,, -120600,0.13757953,0.01775219,,,,,,,,,,,,,,,,, -120700,0.14011474,0.017790547,,,,,,,,,,,,,,,,, -120800,0.14185067,0.01806161,,,,,,,,,,,,,,,,, -120873,,,0.9955613017082214,0.0140098659321665,0.7866055579445876,0.9869743585586548,0.0502125173807144,0.2944610430847229,43793.0,0.9861940741539,0.0535595826804637,0.2858240013559314,43793.0,39394.06165885925,60034.53656172752,39394.06165885925,20630.25913882256,6.522387504577637,0.0 -120900,0.15918715,0.01735747,,,,,,,,,,,,,,,,, -121000,0.14279036,0.018272078,,,,,,,,,,,,,,,,, -121100,0.15914576,0.015958585,,,,,,,,,,,,,,,,, -121200,0.15996361,0.018705914,,,,,,,,,,,,,,,,, -121300,0.17259498,0.019826913,,,,,,,,,,,,,,,,, -121400,0.13656212,0.015138436,,,,,,,,,,,,,,,,, -121500,0.14551781,0.017146131,,,,,,,,,,,,,,,,, -121600,0.13173965,0.016545966,,,,,,,,,,,,,,,,, -121614,,,0.9955511689186096,0.0139482170343399,0.7662657186104772,0.9869745373725892,0.0502125211060047,0.2945873331954825,43793.0,0.9861940741539,0.0535595826804637,0.2857204561472292,43793.0,39634.28438973427,60385.552882909775,39634.28438973427,20740.979505062103,6.573132276535034,0.0 -121700,0.1496614,0.018386079,,,,,,,,,,,,,,,,, -121800,0.13340788,0.017572213,,,,,,,,,,,,,,,,, -121900,0.16137718,0.019935582,,,,,,,,,,,,,,,,, -122000,0.16607514,0.01708198,,,,,,,,,,,,,,,,, -122100,0.16296521,0.019674927,,,,,,,,,,,,,,,,, -122200,0.1468366,0.019336332,,,,,,,,,,,,,,,,, -122300,0.13810931,0.01721041,,,,,,,,,,,,,,,,, -122354,,,0.9955592155456544,0.0140119800344109,0.7753787037792779,0.9869745373725892,0.0502125211060047,0.2944693613267529,43793.0,0.9861940741539,0.0535595826804637,0.2858334033911579,43793.0,39874.48172211647,60743.27828383446,39874.48172211647,20858.42757821083,6.62934947013855,0.0 -122400,0.14126001,0.01865283,,,,,,,,,,,,,,,,, -122500,0.12426939,0.016005961,,,,,,,,,,,,,,,,, -122600,0.14234234,0.018386165,,,,,,,,,,,,,,,,, -122700,0.14036252,0.018600643,,,,,,,,,,,,,,,,, -122800,0.13802376,0.016906029,,,,,,,,,,,,,,,,, -122900,0.13135125,0.017345482,,,,,,,,,,,,,,,,, -123000,0.14128311,0.020506663,,,,,,,,,,,,,,,,, -123095,,,0.9955024123191832,0.0140598332509398,0.770876647929712,0.9869745373725892,0.0502125211060047,0.2944734527416136,43793.0,0.9861940741539,0.0535595826804637,0.2857019667587944,43793.0,40114.56664967537,61091.83598303795,40114.56664967537,20966.82272863388,6.682506799697876,0.0 -123100,0.1497943,0.020341406,,,,,,,,,,,,,,,,, -123200,0.14150552,0.017585017,,,,,,,,,,,,,,,,, -123300,0.1511124,0.018506618,,,,,,,,,,,,,,,,, -123400,0.14391941,0.017775787,,,,,,,,,,,,,,,,, -123500,0.12282373,0.016921729,,,,,,,,,,,,,,,,, -123600,0.15559694,0.017701043,,,,,,,,,,,,,,,,, -123700,0.14777558,0.017873406,,,,,,,,,,,,,,,,, -123800,0.13666967,0.017968947,,,,,,,,,,,,,,,,, -123830,,,0.9956189393997192,0.0138553949072957,0.7783691267641425,0.9869745373725892,0.0502125211060047,0.2944677065948684,43793.0,0.9861940741539,0.0535595789551734,0.2858282936163239,43793.0,40354.73410964012,61442.61308217049,40354.73410964012,21077.35178470612,6.739075183868408,0.0 -123900,0.13702762,0.01631403,,,,,,,,,,,,,,,,, -124000,0.13522722,0.01538913,,,,,,,,,,,,,,,,, -124100,0.14713576,0.018528286,,,,,,,,,,,,,,,,, -124200,0.14176638,0.01771572,,,,,,,,,,,,,,,,, -124300,0.16000299,0.018393172,,,,,,,,,,,,,,,,, -124400,0.11326318,0.015084649,,,,,,,,,,,,,,,,, -124500,0.16390625,0.017729582,,,,,,,,,,,,,,,,, -124560,,,0.995537519454956,0.0140228625386953,0.7846410551431506,0.9869745373725892,0.0502125211060047,0.2946003063371421,43793.0,0.9861940741539,0.0535595789551734,0.2859011175114634,43793.0,40594.8329679966,61794.2158946991,40594.8329679966,21188.7815053463,6.7893126010894775,0.0 -124600,0.13382511,0.016790012,,,,,,,,,,,,,,,,, -124700,0.15523182,0.020174941,,,,,,,,,,,,,,,,, -124800,0.1372984,0.01681881,,,,,,,,,,,,,,,,, -124900,0.124984846,0.014853881,,,,,,,,,,,,,,,,, -125000,0.1502439,0.018698068,,,,,,,,,,,,,,,,, -125100,0.14385085,0.019254815,,,,,,,,,,,,,,,,, -125200,0.15172112,0.01954072,,,,,,,,,,,,,,,,, -125294,,,0.9955279231071472,0.0140428263694047,0.7715511151186719,0.9869743585586548,0.0502125211060047,0.2945159621426331,43793.0,0.9861940741539,0.0535595826804637,0.2856668054003706,43793.0,40834.93838334084,62146.89764785767,40834.93838334084,21301.28376197815,6.841134786605835,0.0 -125300,0.14947718,0.020135676,,,,,,,,,,,,,,,,, -125400,0.13454223,0.016943049,,,,,,,,,,,,,,,,, -125500,0.15157145,0.019745585,,,,,,,,,,,,,,,,, -125600,0.15491979,0.017451135,,,,,,,,,,,,,,,,, -125700,0.13738985,0.017250977,,,,,,,,,,,,,,,,, -125800,0.1513107,0.017246032,,,,,,,,,,,,,,,,, -125900,0.15274586,0.017974542,,,,,,,,,,,,,,,,, -126000,0.14443673,0.01868768,,,,,,,,,,,,,,,,, -126032,,,0.9955497980117798,0.0139739634469151,0.7776480587757924,0.9869745373725892,0.0502125211060047,0.2945289512400195,43793.0,0.9861940741539,0.0535595789551734,0.2859947860418702,43793.0,41075.21634960175,62499.35304760933,41075.21634960175,21413.383462429047,6.894498348236084,0.0 -126100,0.14298607,0.018142965,,,,,,,,,,,,,,,,, -126200,0.15826121,0.018040486,,,,,,,,,,,,,,,,, -126300,0.14475779,0.017237956,,,,,,,,,,,,,,,,, -126400,0.14038482,0.017817521,,,,,,,,,,,,,,,,, -126500,0.13981183,0.017265394,,,,,,,,,,,,,,,,, -126600,0.15167566,0.019680602,,,,,,,,,,,,,,,,, -126700,0.13820322,0.01706926,,,,,,,,,,,,,,,,, -126774,,,0.9955505132675172,0.013991804793477,0.7683404890570733,0.9869741201400756,0.0502125173807144,0.2945988444665576,43793.0,0.9861940741539,0.0535595826804637,0.2857411098029996,43793.0,41315.295952796936,62850.69105839729,41315.295952796936,21524.56535053253,6.948744535446167,0.0 -126800,0.13846213,0.017140696,,,,,,,,,,,,,,,,, -126900,0.16671708,0.01912458,,,,,,,,,,,,,,,,, -127000,0.13922106,0.018402293,,,,,,,,,,,,,,,,, -127100,0.15425979,0.020632481,,,,,,,,,,,,,,,,, -127200,0.14162147,0.016203832,,,,,,,,,,,,,,,,, -127300,0.15329409,0.018602733,,,,,,,,,,,,,,,,, -127400,0.1597312,0.017924082,,,,,,,,,,,,,,,,, -127500,0.14004831,0.017881472,,,,,,,,,,,,,,,,, -127512,,,0.995507538318634,0.014115703292191,0.780440909566623,0.9869745373725892,0.0502125173807144,0.294529530609065,43793.0,0.9861940741539,0.0535595826804637,0.285670475255235,43793.0,41555.32707571983,63198.34479141235,41555.32707571983,21632.11439609528,6.999794960021973,0.0 -127600,0.13606039,0.017721364,,,,,,,,,,,,,,,,, -127700,0.1652216,0.019352173,,,,,,,,,,,,,,,,, -127800,0.12970461,0.016464543,,,,,,,,,,,,,,,,, -127900,0.139176,0.018561108,,,,,,,,,,,,,,,,, -128000,0.143627,0.018382475,,,,,,,,,,,,,,,,, -128100,0.1644041,0.019175071,,,,,,,,,,,,,,,,, -128200,0.12547047,0.01606797,,,,,,,,,,,,,,,,, -128248,,,0.9955496191978456,0.0139685394242405,0.7795067563109996,0.9869745373725892,0.0502125173807144,0.2945238855690372,43793.0,0.9861940741539,0.0535595826804637,0.2856927221154515,43793.0,41795.47059297562,63550.09992480278,41795.47059297562,21743.65156912804,7.051799297332764,0.0 -128300,0.14834805,0.019429512,,,,,,,,,,,,,,,,, -128400,0.13937329,0.018683102,,,,,,,,,,,,,,,,, -128500,0.13803186,0.017782632,,,,,,,,,,,,,,,,, -128600,0.13352965,0.017308675,,,,,,,,,,,,,,,,, -128700,0.13912895,0.017091857,,,,,,,,,,,,,,,,, -128800,0.14889735,0.019094592,,,,,,,,,,,,,,,,, -128900,0.14205356,0.016472494,,,,,,,,,,,,,,,,, -128988,,,0.9955980181694032,0.0138596631586551,0.78344074625687,0.9869745373725892,0.0502125211060047,0.2945503802955432,43793.0,0.9861940741539,0.0535595826804637,0.2857917035503163,43793.0,42035.69354510307,63904.26799011231,42035.69354510307,21857.52256894112,7.103084325790405,0.0 -129000,0.1458124,0.017605655,,,,,,,,,,,,,,,,, -129100,0.125265,0.016837575,,,,,,,,,,,,,,,,, -129200,0.13426593,0.016181834,,,,,,,,,,,,,,,,, -129300,0.16103873,0.01951931,,,,,,,,,,,,,,,,, -129400,0.14267251,0.016533766,,,,,,,,,,,,,,,,, -129500,0.1554635,0.019704267,,,,,,,,,,,,,,,,, -129600,0.12762032,0.016669389,,,,,,,,,,,,,,,,, -129700,0.13982874,0.017278949,,,,,,,,,,,,,,,,, -129726,,,0.9955507516860962,0.0139851318672299,0.7687265903200935,0.9869745373725892,0.0502125211060047,0.2946306155417239,43793.0,0.9861940741539,0.0535595789551734,0.2856905130623254,43793.0,42275.78005862236,64258.08281850815,42275.78005862236,21971.176341056824,7.15567946434021,0.0 -129800,0.16034803,0.017264958,,,,,,,,,,,,,,,,, -129900,0.13649513,0.018284515,,,,,,,,,,,,,,,,, -130000,0.14338042,0.01767905,,,,,,,,,,,,,,,,, -130100,0.14278631,0.01904288,,,,,,,,,,,,,,,,, -130200,0.14891392,0.01918241,,,,,,,,,,,,,,,,, -130300,0.15608098,0.019307446,,,,,,,,,,,,,,,,, -130400,0.14070965,0.016783668,,,,,,,,,,,,,,,,, -130455,,,0.9955714344978333,0.0139600820839405,0.7693036488054488,0.9869745373725892,0.0502125173807144,0.2945816972273906,43793.0,0.9861940741539,0.0535595789551734,0.2857761141941721,43793.0,42515.92790675163,64609.39808940888,42515.92790675163,22082.26949906349,7.207721710205078,0.0 -130500,0.1333434,0.017348811,,,,,,,,,,,,,,,,, -130600,0.14103085,0.017364174,,,,,,,,,,,,,,,,, -130700,0.14318106,0.017838132,,,,,,,,,,,,,,,,, -130800,0.1323791,0.014805499,,,,,,,,,,,,,,,,, -130900,0.15498789,0.018862471,,,,,,,,,,,,,,,,, -131000,0.14004688,0.019311924,,,,,,,,,,,,,,,,, -131100,0.14530979,0.018763894,,,,,,,,,,,,,,,,, -131190,,,0.9954754114151,0.0141520276665687,0.7745251563418538,0.9869745373725892,0.0502125211060047,0.2944870852900064,43793.0,0.9861940741539,0.0535595789551734,0.2857930735203836,43793.0,42755.94807124138,64961.888946056366,42755.94807124138,22194.663531780243,7.26179051399231,0.0 -131200,0.1351535,0.018246802,,,,,,,,,,,,,,,,, -131300,0.1508286,0.017534042,,,,,,,,,,,,,,,,, -131400,0.14977778,0.016947601,,,,,,,,,,,,,,,,, -131500,0.16576076,0.01696807,,,,,,,,,,,,,,,,, -131600,0.1304554,0.016847763,,,,,,,,,,,,,,,,, -131700,0.14377661,0.01774879,,,,,,,,,,,,,,,,, -131800,0.1457782,0.019301875,,,,,,,,,,,,,,,,, -131900,0.15902929,0.016991436,,,,,,,,,,,,,,,,, -131929,,,0.9955991506576538,0.0138892102986574,0.7770598738446075,0.9869745373725892,0.0502125173807144,0.2946191638573021,43793.0,0.9861940741539,0.0535595826804637,0.2856906678947874,43793.0,42995.89313149452,65312.44583821297,42995.89313149452,22305.20131421089,7.314098358154297,0.0 -132000,0.13954106,0.017883752,,,,,,,,,,,,,,,,, -132100,0.14656113,0.018468961,,,,,,,,,,,,,,,,, -132200,0.13938999,0.017170383,,,,,,,,,,,,,,,,, -132300,0.1468479,0.018745692,,,,,,,,,,,,,,,,, -132400,0.12387041,0.01392976,,,,,,,,,,,,,,,,, -132500,0.13821915,0.017555676,,,,,,,,,,,,,,,,, -132600,0.13807604,0.01862372,,,,,,,,,,,,,,,,, -132666,,,0.9955676794052124,0.0139704309403896,0.784020367065633,0.9869745373725892,0.0502125211060047,0.2944997865847881,43793.0,0.9861940741539,0.0535595789551734,0.285701665999011,43793.0,43236.01311707497,65659.10172605515,43236.01311707497,22411.660974740986,7.368078708648682,0.0 -132700,0.13924107,0.018352812,,,,,,,,,,,,,,,,, -132800,0.1293888,0.016988449,,,,,,,,,,,,,,,,, -132900,0.12805068,0.0165959,,,,,,,,,,,,,,,,, -133000,0.14476763,0.01877807,,,,,,,,,,,,,,,,, -133100,0.15846798,0.019229444,,,,,,,,,,,,,,,,, -133200,0.14055188,0.018514633,,,,,,,,,,,,,,,,, -133300,0.14421003,0.016589869,,,,,,,,,,,,,,,,, -133400,0.1418537,0.018418238,,,,,,,,,,,,,,,,, -133402,,,0.9955294132232666,0.0139775266870856,0.7719156447251221,0.9869745373725892,0.0502125211060047,0.2944672347701114,43793.0,0.9861940741539,0.0535595826804637,0.2859014656423316,43793.0,43476.01622104645,66010.43513774872,43476.01622104645,22522.91188192368,7.421867370605469,0.0 -133500,0.14307974,0.017191634,,,,,,,,,,,,,,,,, -133600,0.14160953,0.017382782,,,,,,,,,,,,,,,,, -133700,0.14853954,0.01818427,,,,,,,,,,,,,,,,, -133800,0.14648259,0.021243865,,,,,,,,,,,,,,,,, -133900,0.14381304,0.017334035,,,,,,,,,,,,,,,,, -134000,0.13956504,0.018103113,,,,,,,,,,,,,,,,, -134100,0.12692091,0.016924504,,,,,,,,,,,,,,,,, -134140,,,0.995570719242096,0.0139369834214448,0.7698331831434866,0.9869743585586548,0.0502125211060047,0.2945582199786213,43793.0,0.9861940741539,0.0535595789551734,0.2857146601910127,43793.0,43716.06605672836,66363.17952227592,43716.06605672836,22635.523801088333,7.478373527526855,0.0 -134200,0.13297848,0.019037422,,,,,,,,,,,,,,,,, -134300,0.14986055,0.018671352,,,,,,,,,,,,,,,,, -134400,0.15681729,0.019483196,,,,,,,,,,,,,,,,, -134500,0.14315385,0.016628329,,,,,,,,,,,,,,,,, -134600,0.14809553,0.016474893,,,,,,,,,,,,,,,,, -134700,0.1635349,0.019668492,,,,,,,,,,,,,,,,, -134800,0.14356871,0.016294703,,,,,,,,,,,,,,,,, -134871,,,0.995523989200592,0.0140719972550868,0.772939178937591,0.9869743585586548,0.0502125211060047,0.2945825986045106,43793.0,0.9861940741539,0.0535595826804637,0.2858309724099989,43793.0,43956.02018260956,66714.26945352554,43956.02018260956,22746.583614349365,7.531560182571411,0.0 -134900,0.14484678,0.019611072,,,,,,,,,,,,,,,,, -135000,0.15487939,0.019006232,,,,,,,,,,,,,,,,, -135100,0.13205701,0.016264675,,,,,,,,,,,,,,,,, -135200,0.14213538,0.018717062,,,,,,,,,,,,,,,,, -135300,0.1712299,0.019943047,,,,,,,,,,,,,,,,, -135400,0.14025582,0.01845155,,,,,,,,,,,,,,,,, -135500,0.17022735,0.018690271,,,,,,,,,,,,,,,,, -135600,0.15535215,0.019541053,,,,,,,,,,,,,,,,, -135606,,,0.9955481886863708,0.0139481537044048,0.7806334622776318,0.9869745373725892,0.0502125211060047,0.2945880907545196,43793.0,0.9861940741539,0.0535595789551734,0.2857107305887505,43793.0,44196.05651593208,67062.164737463,44196.05651593208,22854.359155654907,7.591469287872314,0.0 -135700,0.15186125,0.0196142,,,,,,,,,,,,,,,,, -135800,0.13215454,0.017315784,,,,,,,,,,,,,,,,, -135900,0.1438589,0.016344931,,,,,,,,,,,,,,,,, -136000,0.12891334,0.017380508,,,,,,,,,,,,,,,,, -136100,0.14358138,0.017794155,,,,,,,,,,,,,,,,, -136200,0.14280702,0.01775536,,,,,,,,,,,,,,,,, -136300,0.15082799,0.017733766,,,,,,,,,,,,,,,,, -136339,,,0.995563507080078,0.0139814345166087,0.7787603347264325,0.9869745373725892,0.0502125211060047,0.2946117411358875,43793.0,0.9861940741539,0.0535595826804637,0.2857003263771898,43793.0,44436.13486599922,67408.29759216309,44436.13486599922,22960.33766627312,7.645104646682739,0.0 -136400,0.13739854,0.016263451,,,,,,,,,,,,,,,,, -136500,0.155388,0.019571334,,,,,,,,,,,,,,,,, -136600,0.13804014,0.017787535,,,,,,,,,,,,,,,,, -136700,0.14213836,0.01671203,,,,,,,,,,,,,,,,, -136800,0.14495386,0.01852103,,,,,,,,,,,,,,,,, -136900,0.15315169,0.018673811,,,,,,,,,,,,,,,,, -137000,0.15930323,0.020005882,,,,,,,,,,,,,,,,, -137078,,,0.995528757572174,0.014043471775949,0.773235530751952,0.9869745373725892,0.0502125173807144,0.2945560652848558,43793.0,0.9861940741539,0.0535595826804637,0.2856534864633366,43793.0,44676.229766607285,67757.26006889343,44676.229766607285,23069.12981247902,7.69776439666748,0.0 -137100,0.15520203,0.01901588,,,,,,,,,,,,,,,,, -137200,0.13333087,0.017357955,,,,,,,,,,,,,,,,, -137300,0.16287833,0.02005299,,,,,,,,,,,,,,,,, -137400,0.13434312,0.017039787,,,,,,,,,,,,,,,,, -137500,0.12567313,0.016633209,,,,,,,,,,,,,,,,, -137600,0.1296713,0.017732592,,,,,,,,,,,,,,,,, -137700,0.14758271,0.019840853,,,,,,,,,,,,,,,,, -137800,0.12890069,0.01577066,,,,,,,,,,,,,,,,, -137812,,,0.9955397248268129,0.0140498215332627,0.7716928762320991,0.9869743585586548,0.0502125211060047,0.2945598515327954,43793.0,0.9861940741539,0.0535595826804637,0.2856929229490918,43793.0,44916.29040455818,68107.390021801,44916.29040455818,23179.11955356598,7.754902839660644,0.0 -137900,0.14410454,0.017531954,,,,,,,,,,,,,,,,, -138000,0.1546133,0.020895267,,,,,,,,,,,,,,,,, -138100,0.15875621,0.017200958,,,,,,,,,,,,,,,,, -138200,0.15070815,0.017891085,,,,,,,,,,,,,,,,, -138300,0.1395698,0.018803855,,,,,,,,,,,,,,,,, -138400,0.14709505,0.019845722,,,,,,,,,,,,,,,,, -138500,0.15423955,0.019257719,,,,,,,,,,,,,,,,, -138549,,,0.9955530166625975,0.0139563847333192,0.7666460856938745,0.9869745373725892,0.0502125211060047,0.2945474761017078,43793.0,0.9861940741539,0.0535595789551734,0.285655050187629,43793.0,45156.33108019829,68455.58481907845,45156.33108019829,23287.19401454925,7.812101364135742,0.0 -138600,0.13795854,0.01812768,,,,,,,,,,,,,,,,, -138700,0.13614762,0.017363463,,,,,,,,,,,,,,,,, -138800,0.14967397,0.019844305,,,,,,,,,,,,,,,,, -138900,0.15772553,0.020267004,,,,,,,,,,,,,,,,, -139000,0.15360463,0.018063873,,,,,,,,,,,,,,,,, -139100,0.13750644,0.016991487,,,,,,,,,,,,,,,,, -139200,0.13809852,0.01674855,,,,,,,,,,,,,,,,, -139280,,,0.9955349564552308,0.0139978528022766,0.7760984157232222,0.9869743585586548,0.0502125211060047,0.294489289603176,43793.0,0.9861940741539,0.0535595789551734,0.2856873922115879,43793.0,45396.46606326103,68805.20327782631,45396.46606326103,23396.60202598572,7.865617513656616,0.0 -139300,0.14427061,0.017246231,,,,,,,,,,,,,,,,, -139400,0.14259802,0.019018278,,,,,,,,,,,,,,,,, -139500,0.14786522,0.017579634,,,,,,,,,,,,,,,,, -139600,0.1426498,0.018559642,,,,,,,,,,,,,,,,, -139700,0.15537913,0.017777592,,,,,,,,,,,,,,,,, -139800,0.15711676,0.018123833,,,,,,,,,,,,,,,,, -139900,0.14865518,0.019011192,,,,,,,,,,,,,,,,, -140000,0.13465002,0.019033737,,,,,,,,,,,,,,,,, -140024,,,0.9955670833587646,0.0139693990349769,0.7815567202019054,0.9869745373725892,0.0502125211060047,0.2945874760881046,43793.0,0.9861940741539,0.0535595789551734,0.2858944262889938,43793.0,45636.407398462296,69153.9723265171,45636.407398462296,23505.35227966309,7.920868873596191,0.0 -140100,0.15792644,0.01939877,,,,,,,,,,,,,,,,, -140200,0.15374687,0.018072931,,,,,,,,,,,,,,,,, -140300,0.14405802,0.017394327,,,,,,,,,,,,,,,,, -140400,0.1468616,0.01751742,,,,,,,,,,,,,,,,, -140500,0.15275495,0.017962594,,,,,,,,,,,,,,,,, -140600,0.13952726,0.017177168,,,,,,,,,,,,,,,,, -140700,0.14934689,0.016965201,,,,,,,,,,,,,,,,, -140767,,,0.995562732219696,0.0139715448021888,0.7835845042765484,0.9869741201400756,0.0502125211060047,0.2946952873297384,43793.0,0.9861940741539,0.0535595826804637,0.2856803090264078,43793.0,45876.594853162766,69504.31644678116,45876.594853162766,23615.431432724,7.97593355178833,0.0 -140800,0.1474202,0.017972786,,,,,,,,,,,,,,,,, -140900,0.14897455,0.016854513,,,,,,,,,,,,,,,,, -141000,0.1479136,0.017625572,,,,,,,,,,,,,,,,, -141100,0.14416802,0.017312672,,,,,,,,,,,,,,,,, -141200,0.14984736,0.019746933,,,,,,,,,,,,,,,,, -141300,0.15230162,0.017704625,,,,,,,,,,,,,,,,, -141400,0.14115414,0.017560843,,,,,,,,,,,,,,,,, -141500,0.12311139,0.015829442,,,,,,,,,,,,,,,,, -141514,,,0.9955607056617736,0.0139527916908264,0.7693302481492503,0.9869745373725892,0.0502125211060047,0.2945097398624374,43793.0,0.9861940741539,0.0535595826804637,0.2857075935707077,43793.0,46116.76946258545,69853.53129506111,46116.76946258545,23724.393659591675,8.031497716903687,0.0 -141600,0.13536762,0.016834848,,,,,,,,,,,,,,,,, -141700,0.13407317,0.016599867,,,,,,,,,,,,,,,,, -141800,0.14784077,0.017139677,,,,,,,,,,,,,,,,, -141900,0.1305982,0.016802596,,,,,,,,,,,,,,,,, -142000,0.12682374,0.017709292,,,,,,,,,,,,,,,,, -142100,0.14092657,0.018065784,,,,,,,,,,,,,,,,, -142200,0.12854455,0.016773857,,,,,,,,,,,,,,,,, -142256,,,0.995522677898407,0.0140734603628516,0.7720646022590836,0.9869743585586548,0.0502125211060047,0.2945275033973265,43793.0,0.9861940741539,0.0535595826804637,0.2857075416590519,43793.0,46356.98973464966,70205.73898673058,46356.98973464966,23836.30408072472,8.086358785629272,0.0 -142300,0.13584812,0.016932292,,,,,,,,,,,,,,,,, -142400,0.13767086,0.017765023,,,,,,,,,,,,,,,,, -142500,0.144001,0.018435929,,,,,,,,,,,,,,,,, -142600,0.13664389,0.016995834,,,,,,,,,,,,,,,,, -142700,0.14226292,0.016919933,,,,,,,,,,,,,,,,, -142800,0.12951645,0.015607616,,,,,,,,,,,,,,,,, -142900,0.16407253,0.019548628,,,,,,,,,,,,,,,,, -142996,,,0.9955246448516846,0.0140249570831656,0.7742682707114532,0.9869741201400756,0.0502125173807144,0.2945714240781431,43793.0,0.9861940741539,0.0535595789551734,0.2857092445829384,43793.0,46597.234624147415,70553.02288079262,46597.234624147415,23943.26594948769,8.141406297683716,0.0 -143000,0.15838821,0.019921381,,,,,,,,,,,,,,,,, -143100,0.124174744,0.017719263,,,,,,,,,,,,,,,,, -143200,0.16026893,0.017452996,,,,,,,,,,,,,,,,, -143300,0.13189991,0.017048651,,,,,,,,,,,,,,,,, -143400,0.13750692,0.017100874,,,,,,,,,,,,,,,,, -143500,0.14019068,0.018020866,,,,,,,,,,,,,,,,, -143600,0.13832583,0.01794882,,,,,,,,,,,,,,,,, -143700,0.1466524,0.018727528,,,,,,,,,,,,,,,,, -143740,,,0.995592713356018,0.0139059126377105,0.7787287733952317,0.9869745373725892,0.0502125173807144,0.2947369056523955,43793.0,0.9861940741539,0.0535595826804637,0.285687980017125,43793.0,46837.26299023628,70901.68744587898,46837.26299023628,24051.82504749298,8.196717262268066,0.0 -143800,0.14829364,0.017193314,,,,,,,,,,,,,,,,, -143900,0.15119317,0.018367196,,,,,,,,,,,,,,,,, -144000,0.16505404,0.017787725,,,,,,,,,,,,,,,,, -144100,0.13538276,0.016551418,,,,,,,,,,,,,,,,, -144200,0.1405784,0.017761463,,,,,,,,,,,,,,,,, -144300,0.15249385,0.0198647,,,,,,,,,,,,,,,,, -144400,0.14454947,0.018675327,,,,,,,,,,,,,,,,, -144487,,,0.995563507080078,0.0140124568715691,0.7812311607734994,0.9869745373725892,0.0502125211060047,0.2945227253288941,43793.0,0.9861940741539,0.0535595789551734,0.28566755639923,43793.0,47077.19783067703,71247.94971942902,47077.19783067703,24158.07507967949,8.25191354751587,0.0 -144500,0.14417666,0.01789529,,,,,,,,,,,,,,,,, -144600,0.16941988,0.02015075,,,,,,,,,,,,,,,,, -144700,0.13949467,0.017824644,,,,,,,,,,,,,,,,, -144800,0.15903845,0.018564036,,,,,,,,,,,,,,,,, -144900,0.15301846,0.021104587,,,,,,,,,,,,,,,,, -145000,0.14244586,0.017298931,,,,,,,,,,,,,,,,, -145100,0.1243279,0.016306052,,,,,,,,,,,,,,,,, -145200,0.15061525,0.018775124,,,,,,,,,,,,,,,,, -145225,,,0.995533287525177,0.0139496708288788,0.7758802733144703,0.9869745373725892,0.0502125211060047,0.2944865694920178,43793.0,0.9861940741539,0.0535595826804637,0.285769349252281,43793.0,47317.14991569519,71596.03808999062,47317.14991569519,24266.130458831787,8.306992053985596,0.0 -145300,0.15680757,0.019104023,,,,,,,,,,,,,,,,, -145400,0.14224775,0.016966663,,,,,,,,,,,,,,,,, -145500,0.15808214,0.018496336,,,,,,,,,,,,,,,,, -145600,0.14558244,0.018967988,,,,,,,,,,,,,,,,, -145700,0.16735263,0.018078662,,,,,,,,,,,,,,,,, -145800,0.14205328,0.01751818,,,,,,,,,,,,,,,,, -145900,0.15034664,0.017503284,,,,,,,,,,,,,,,,, -145966,,,0.995587170124054,0.0138977067545056,0.773782049719947,0.9869743585586548,0.0502125211060047,0.2946646276835428,43793.0,0.9861940741539,0.0535595826804637,0.2857172927144193,43793.0,47557.18814301491,71946.61416625977,47557.18814301491,24376.586597919464,8.366672277450562,0.0 -146000,0.15496466,0.019757807,,,,,,,,,,,,,,,,, -146100,0.13776924,0.018413462,,,,,,,,,,,,,,,,, -146200,0.14136411,0.01707984,,,,,,,,,,,,,,,,, -146300,0.13795626,0.018841065,,,,,,,,,,,,,,,,, -146400,0.14174777,0.017846042,,,,,,,,,,,,,,,,, -146500,0.13902628,0.017832661,,,,,,,,,,,,,,,,, -146600,0.14813179,0.018048756,,,,,,,,,,,,,,,,, -146700,0.14750877,0.019203493,,,,,,,,,,,,,,,,, -146715,,,0.9954803586006165,0.0141344061121344,0.7739438319568628,0.9869743585586548,0.0502125211060047,0.2945254938863308,43793.0,0.9861940741539,0.0535595826804637,0.2857200951461492,43793.0,47797.23997068405,72297.11353874207,47797.23997068405,24486.95604276657,8.422574520111084,0.0 -146800,0.13764745,0.018140294,,,,,,,,,,,,,,,,, -146900,0.15527251,0.018791128,,,,,,,,,,,,,,,,, -147000,0.13577409,0.017047534,,,,,,,,,,,,,,,,, -147100,0.14960144,0.018056327,,,,,,,,,,,,,,,,, -147200,0.13112697,0.017031122,,,,,,,,,,,,,,,,, -147300,0.13286252,0.01668095,,,,,,,,,,,,,,,,, -147400,0.15160513,0.019726545,,,,,,,,,,,,,,,,, -147461,,,0.9955896735191344,0.0138777727261185,0.7762476049009219,0.9869745373725892,0.0502125173807144,0.2944843963867238,43793.0,0.9861940741539,0.0535595789551734,0.2857585033554984,43793.0,48037.220688819885,72643.81870794296,48037.220688819885,24593.60262060165,8.478218793869019,0.0 -147500,0.14237423,0.01909892,,,,,,,,,,,,,,,,, -147600,0.123631105,0.015878864,,,,,,,,,,,,,,,,, -147700,0.14466557,0.019198015,,,,,,,,,,,,,,,,, -147800,0.13315673,0.01640932,,,,,,,,,,,,,,,,, -147900,0.1390152,0.017762125,,,,,,,,,,,,,,,,, -148000,0.17036659,0.019734783,,,,,,,,,,,,,,,,, -148100,0.13575067,0.017398668,,,,,,,,,,,,,,,,, -148195,,,0.9955464005470276,0.014040638692677,0.7835602272200127,0.9869745373725892,0.0502125173807144,0.294611981843638,43793.0,0.9861940741539,0.0535595826804637,0.2857212119481352,43793.0,48277.27271056175,72988.82006072998,48277.27271056175,24698.474319934845,8.533739566802979,0.0 -148200,0.13357545,0.016326433,,,,,,,,,,,,,,,,, -148300,0.15077363,0.01786652,,,,,,,,,,,,,,,,, -148400,0.15209481,0.019163013,,,,,,,,,,,,,,,,, -148500,0.13449636,0.017202081,,,,,,,,,,,,,,,,, -148600,0.1524835,0.018475397,,,,,,,,,,,,,,,,, -148700,0.14208397,0.018422317,,,,,,,,,,,,,,,,, -148800,0.13009495,0.01681321,,,,,,,,,,,,,,,,, -148900,0.1387028,0.0181958,,,,,,,,,,,,,,,,, -148936,,,0.9955344200134276,0.0140588777139782,0.7736069100158747,0.9869743585586548,0.0502125173807144,0.2945416036507968,43793.0,0.9861940741539,0.0535595789551734,0.2856568886471392,43793.0,48517.27303361893,73341.68425559998,48517.27303361893,24811.260496854786,8.589645624160767,0.0 -149000,0.15123953,0.018213375,,,,,,,,,,,,,,,,, -149100,0.15855871,0.019980885,,,,,,,,,,,,,,,,, -149200,0.13979451,0.017768038,,,,,,,,,,,,,,,,, -149300,0.13918635,0.016619889,,,,,,,,,,,,,,,,, -149400,0.152347,0.019428868,,,,,,,,,,,,,,,,, -149500,0.15906578,0.0182948,,,,,,,,,,,,,,,,, -149600,0.13799922,0.019009372,,,,,,,,,,,,,,,,, -149679,,,0.995566189289093,0.0139274382963776,0.7797421371283534,0.9869743585586548,0.0502125211060047,0.2947302321803623,43793.0,0.9861940741539,0.0535595789551734,0.2857804305900344,43793.0,48757.350826501846,73693.50581741333,48757.350826501846,24922.92630887032,8.645307302474976,0.0 -149700,0.1493793,0.01921954,,,,,,,,,,,,,,,,, -149800,0.1349365,0.016127778,,,,,,,,,,,,,,,,, -149900,0.12115305,0.014665741,,,,,,,,,,,,,,,,, -150000,0.1421436,0.018366719,,,,,,,,,,,,,,,,, -150100,0.14335863,0.017314913,,,,,,,,,,,,,,,,, -150200,0.15285799,0.018664088,,,,,,,,,,,,,,,,, -150300,0.15591577,0.018899972,,,,,,,,,,,,,,,,, -150400,0.16722906,0.018229753,,,,,,,,,,,,,,,,, -150424,,,0.9955343008041382,0.0140413139015436,0.7644432246613764,0.9869745373725892,0.0502125211060047,0.2945536299897155,43793.0,0.9861940741539,0.0535595826804637,0.2856523374013939,43793.0,48997.344621658325,74044.27948951721,48997.344621658325,25033.61894416809,8.710498809814453,0.0 -150500,0.14528629,0.017862223,,,,,,,,,,,,,,,,, -150600,0.15239334,0.01887107,,,,,,,,,,,,,,,,, -150700,0.13383381,0.018236808,,,,,,,,,,,,,,,,, -150800,0.16024585,0.018129302,,,,,,,,,,,,,,,,, -150900,0.14685275,0.019083401,,,,,,,,,,,,,,,,, -151000,0.14689176,0.018602371,,,,,,,,,,,,,,,,, -151100,0.15462752,0.017175617,,,,,,,,,,,,,,,,, -151164,,,0.995585322380066,0.0138838086277246,0.7787436689713627,0.9869743585586548,0.0502125173807144,0.2945482853631458,43793.0,0.9861940741539,0.0535595789551734,0.2857042056457721,43793.0,49236.9952814579,74394.6922750473,49236.9952814579,25143.97053194046,9.098667860031128,0.0 -151200,0.145715,0.01790339,,,,,,,,,,,,,,,,, -151300,0.13458033,0.01676276,,,,,,,,,,,,,,,,, -151400,0.16159822,0.019651981,,,,,,,,,,,,,,,,, -151500,0.13531502,0.015301344,,,,,,,,,,,,,,,,, -151600,0.14956148,0.019014234,,,,,,,,,,,,,,,,, -151700,0.13759357,0.016472843,,,,,,,,,,,,,,,,, -151800,0.14514972,0.01749127,,,,,,,,,,,,,,,,, -151900,0.1326823,0.01718137,,,,,,,,,,,,,,,,, -151907,,,0.9954984784126282,0.0141248051077127,0.7813303133972982,0.9869745373725892,0.0502125211060047,0.2945451697094664,43793.0,0.9861940741539,0.0535595826804637,0.2859263697756146,43793.0,49477.232617139816,74748.23588776588,49477.232617139816,25257.198073625565,9.155184268951416,0.0 -152000,0.145502,0.01793167,,,,,,,,,,,,,,,,, -152100,0.16476898,0.016852885,,,,,,,,,,,,,,,,, -152200,0.1584004,0.019255448,,,,,,,,,,,,,,,,, -152300,0.13277249,0.014998494,,,,,,,,,,,,,,,,, -152400,0.15083815,0.017146096,,,,,,,,,,,,,,,,, -152500,0.15969376,0.020503411,,,,,,,,,,,,,,,,, -152600,0.16500075,0.018389508,,,,,,,,,,,,,,,,, -152649,,,0.99554842710495,0.0139965619891881,0.7735667334688098,0.9869745373725892,0.0502125211060047,0.2945991734176778,43793.0,0.9861940741539,0.0535595789551734,0.2857255714175824,43793.0,49717.24135637283,75094.74559378624,49717.24135637283,25363.61196255684,9.216427087783812,0.0 -152700,0.14803396,0.017021812,,,,,,,,,,,,,,,,, -152800,0.12826025,0.016276088,,,,,,,,,,,,,,,,, -152900,0.14510168,0.018873397,,,,,,,,,,,,,,,,, -153000,0.13153662,0.017480236,,,,,,,,,,,,,,,,, -153100,0.149667,0.018099012,,,,,,,,,,,,,,,,, -153200,0.14547011,0.018229648,,,,,,,,,,,,,,,,, -153300,0.15210615,0.016359948,,,,,,,,,,,,,,,,, -153386,,,0.9955598711967468,0.0139503329992294,0.7727688636510217,0.9869743585586548,0.0502125211060047,0.2945004272365443,43793.0,0.9861940741539,0.0535595789551734,0.2856822054806806,43793.0,49957.17316937447,75443.33772158623,49957.17316937447,25472.193568706512,9.272395849227903,0.0 -153400,0.13728023,0.017392378,,,,,,,,,,,,,,,,, -153500,0.14412662,0.020590123,,,,,,,,,,,,,,,,, -153600,0.13042569,0.014913874,,,,,,,,,,,,,,,,, -153700,0.14823318,0.018845523,,,,,,,,,,,,,,,,, -153800,0.13291153,0.017745782,,,,,,,,,,,,,,,,, -153900,0.1495569,0.018218687,,,,,,,,,,,,,,,,, -154000,0.14375351,0.018636107,,,,,,,,,,,,,,,,, -154100,0.13445073,0.016824415,,,,,,,,,,,,,,,,, -154125,,,0.99556964635849,0.0139679256826639,0.7743029819107554,0.9869743585586548,0.0502125173807144,0.2944922403376407,43793.0,0.9861940741539,0.0535595789551734,0.2857150265415266,43793.0,50197.12275767326,75792.4113547802,50197.12275767326,25581.23864412308,9.328509330749512,0.0 -154200,0.15015806,0.018741865,,,,,,,,,,,,,,,,, -154300,0.13307332,0.01623445,,,,,,,,,,,,,,,,, -154400,0.15661347,0.017972885,,,,,,,,,,,,,,,,, -154500,0.13956672,0.017806005,,,,,,,,,,,,,,,,, -154600,0.1345548,0.01776774,,,,,,,,,,,,,,,,, -154700,0.13263103,0.016828911,,,,,,,,,,,,,,,,, -154800,0.13602729,0.01791702,,,,,,,,,,,,,,,,, -154859,,,0.995524525642395,0.0140228336676955,0.7734727200254108,0.9869745373725892,0.0502125211060047,0.2945479611746832,43793.0,0.9861940741539,0.0535595789551734,0.2857401568499358,43793.0,50437.306334257126,76142.60612154007,50437.306334257126,25691.170045137405,9.385692358016968,0.0 -154900,0.13104284,0.015451803,,,,,,,,,,,,,,,,, -155000,0.14399202,0.017019244,,,,,,,,,,,,,,,,, -155100,0.14127128,0.018174596,,,,,,,,,,,,,,,,, -155200,0.14175785,0.018475754,,,,,,,,,,,,,,,,, -155300,0.16808416,0.019721711,,,,,,,,,,,,,,,,, -155400,0.14663461,0.017942937,,,,,,,,,,,,,,,,, -155500,0.13855763,0.017650632,,,,,,,,,,,,,,,,, -155600,0.14639293,0.019960765,,,,,,,,,,,,,,,,, -155601,,,0.9955626726150512,0.0139491278678178,0.7778168183954786,0.9869745373725892,0.0502125173807144,0.2944723845323849,43793.0,0.9861940741539,0.0535595789551734,0.285683389303403,43793.0,50677.31756472588,76490.74878931046,50677.31756472588,25799.21874904633,9.444878578186035,0.0 -155700,0.15150651,0.017554477,,,,,,,,,,,,,,,,, -155800,0.12268051,0.01575659,,,,,,,,,,,,,,,,, -155900,0.14521396,0.017084476,,,,,,,,,,,,,,,,, -156000,0.13440824,0.017136082,,,,,,,,,,,,,,,,, -156100,0.13746743,0.019363578,,,,,,,,,,,,,,,,, -156200,0.12734401,0.015266695,,,,,,,,,,,,,,,,, -156300,0.1466153,0.019482797,,,,,,,,,,,,,,,,, -156336,,,0.995536208152771,0.014027833007276,0.7825999920584917,0.9869745373725892,0.0502125173807144,0.2946293654606327,43793.0,0.9861940741539,0.0535595789551734,0.2861027977699629,43793.0,50917.56117296219,76838.96542525291,50917.56117296219,25907.11387944221,9.500680446624756,0.0 -156400,0.15644006,0.017396204,,,,,,,,,,,,,,,,, -156500,0.15273485,0.017576037,,,,,,,,,,,,,,,,, -156600,0.15034287,0.017670972,,,,,,,,,,,,,,,,, -156700,0.1287132,0.016113294,,,,,,,,,,,,,,,,, -156800,0.14623426,0.017219743,,,,,,,,,,,,,,,,, -156900,0.14600138,0.019193029,,,,,,,,,,,,,,,,, -157000,0.16229284,0.019779244,,,,,,,,,,,,,,,,, -157076,,,0.9955697655677797,0.0139607917517423,0.7700218074027575,0.9869745373725892,0.0502125211060047,0.2946568657246006,43793.0,0.9861940741539,0.0535595826804637,0.2856674154938795,43793.0,51157.53374195099,77183.41665196419,51157.53374195099,26011.510071754456,9.56079602241516,0.0 -157100,0.13872017,0.016989423,,,,,,,,,,,,,,,,, -157200,0.16514073,0.02096401,,,,,,,,,,,,,,,,, -157300,0.1321354,0.017084735,,,,,,,,,,,,,,,,, -157400,0.12967403,0.016554436,,,,,,,,,,,,,,,,, -157500,0.13676909,0.018128712,,,,,,,,,,,,,,,,, -157600,0.12592457,0.016439667,,,,,,,,,,,,,,,,, -157700,0.14729752,0.018775623,,,,,,,,,,,,,,,,, -157800,0.13812621,0.01729218,,,,,,,,,,,,,,,,, -157817,,,0.9955599308013916,0.0139398034662008,0.7757960907398831,0.9869745373725892,0.0502125211060047,0.2945389451791146,43793.0,0.9861940741539,0.0535595826804637,0.2856970875343083,43793.0,51397.612380981445,77532.86435222626,51397.612380981445,26120.800348758698,9.617570638656616,0.0 -157900,0.15168686,0.017150898,,,,,,,,,,,,,,,,, -158000,0.14276859,0.019404031,,,,,,,,,,,,,,,,, -158100,0.14762306,0.01758699,,,,,,,,,,,,,,,,, -158200,0.12561212,0.016733509,,,,,,,,,,,,,,,,, -158300,0.14354053,0.015861435,,,,,,,,,,,,,,,,, -158400,0.15055113,0.019618385,,,,,,,,,,,,,,,,, -158500,0.12874301,0.016319906,,,,,,,,,,,,,,,,, -158556,,,0.9955105185508728,0.0140839051455259,0.7653077269785241,0.9869745373725892,0.0502125211060047,0.2946983592080036,43793.0,0.9861940741539,0.0535595789551734,0.2858380709121086,43793.0,51637.58630156517,77876.18381977081,51637.58630156517,26224.05740213394,9.681266069412231,0.0 -158600,0.16402113,0.01860966,,,,,,,,,,,,,,,,, -158700,0.1687615,0.01778257,,,,,,,,,,,,,,,,, -158800,0.13168833,0.017608428,,,,,,,,,,,,,,,,, -158900,0.14308085,0.01729381,,,,,,,,,,,,,,,,, -159000,0.14762904,0.01865794,,,,,,,,,,,,,,,,, -159100,0.14875022,0.019304343,,,,,,,,,,,,,,,,, -159200,0.13984688,0.01774389,,,,,,,,,,,,,,,,, -159300,0.16231263,0.01967843,,,,,,,,,,,,,,,,, -159302,,,0.995564877986908,0.0139866154640913,0.7803369469318409,0.9869745373725892,0.0502125211060047,0.2945265167472148,43793.0,0.9861940741539,0.0535595789551734,0.2857684194818551,43793.0,51877.640649318695,78229.36774969101,51877.640649318695,26337.106037139893,9.740336656570436,0.0 -159400,0.13758977,0.018236555,,,,,,,,,,,,,,,,, -159500,0.13316791,0.018660996,,,,,,,,,,,,,,,,, -159600,0.13489105,0.016567633,,,,,,,,,,,,,,,,, -159700,0.130858,0.01659503,,,,,,,,,,,,,,,,, -159800,0.15262778,0.018002441,,,,,,,,,,,,,,,,, -159900,0.14464723,0.019440673,,,,,,,,,,,,,,,,, -160000,0.13325122,0.01399759,,,,,,,,,,,,,,,,, -160043,,,0.995552122592926,0.0139129171147942,0.7836719891879151,0.9869745373725892,0.0502125211060047,0.2946214386872627,43793.0,0.9861940741539,0.0535595826804637,0.2856528216248714,43793.0,52117.86292815208,78575.34875035286,52117.86292815208,26442.78346991539,9.7996826171875,0.0 -160100,0.13297728,0.015863229,,,,,,,,,,,,,,,,, -160200,0.14641571,0.016812136,,,,,,,,,,,,,,,,, -160300,0.16539162,0.019609934,,,,,,,,,,,,,,,,, -160400,0.16654037,0.021004567,,,,,,,,,,,,,,,,, -160500,0.15234214,0.019117897,,,,,,,,,,,,,,,,, -160600,0.16489354,0.01868255,,,,,,,,,,,,,,,,, -160700,0.14113103,0.01984748,,,,,,,,,,,,,,,,, -160786,,,0.9955381751060486,0.0140344044193625,0.7719031206961338,0.9869745373725892,0.0502125211060047,0.2946162831197981,43793.0,0.9861940741539,0.0535595789551734,0.2857982278245645,43793.0,52358.04911708832,78927.19647717476,52358.04911708832,26554.365653038025,9.857231855392456,0.0 -160800,0.14186162,0.016849123,,,,,,,,,,,,,,,,, -160900,0.14082092,0.016278017,,,,,,,,,,,,,,,,, -161000,0.13153663,0.018375874,,,,,,,,,,,,,,,,, -161100,0.14569052,0.01803853,,,,,,,,,,,,,,,,, -161200,0.13816938,0.017902257,,,,,,,,,,,,,,,,, -161300,0.16148505,0.019475037,,,,,,,,,,,,,,,,, -161400,0.14533752,0.017940218,,,,,,,,,,,,,,,,, -161500,0.1347114,0.017138997,,,,,,,,,,,,,,,,, -161514,,,0.9955840706825256,0.0139219872653484,0.7775919723837796,0.9869743585586548,0.0502125211060047,0.2946046044145738,43793.0,0.9861940741539,0.0535595826804637,0.2857322342018337,43793.0,52598.25149941445,79272.00039196014,52598.25149941445,26658.87480187416,9.923945665359495,0.0 -161600,0.13141374,0.01804367,,,,,,,,,,,,,,,,, -161700,0.124156736,0.016974904,,,,,,,,,,,,,,,,, -161800,0.15645419,0.018927347,,,,,,,,,,,,,,,,, -161900,0.14318518,0.018391015,,,,,,,,,,,,,,,,, -162000,0.13565554,0.017578678,,,,,,,,,,,,,,,,, -162100,0.15286127,0.020285187,,,,,,,,,,,,,,,,, -162200,0.1459691,0.020342246,,,,,,,,,,,,,,,,, -162258,,,0.99552983045578,0.0140566751360893,0.7652208045052136,0.9869745373725892,0.0502125173807144,0.2945383467173924,43793.0,0.9861940741539,0.0535595826804637,0.2857639647881098,43793.0,52838.20043492317,79622.39813017845,52838.20043492317,26769.24395275116,9.981608629226685,0.0 -162300,0.13680686,0.018202066,,,,,,,,,,,,,,,,, -162400,0.13173269,0.017716544,,,,,,,,,,,,,,,,, -162500,0.13753694,0.017258693,,,,,,,,,,,,,,,,, -162600,0.14520977,0.01910264,,,,,,,,,,,,,,,,, -162700,0.14759926,0.01863762,,,,,,,,,,,,,,,,, -162800,0.13945206,0.017275762,,,,,,,,,,,,,,,,, -162900,0.15508547,0.019006329,,,,,,,,,,,,,,,,, -163000,0.13144343,0.017520713,,,,,,,,,,,,,,,,, -163008,,,0.995524287223816,0.0140065401792526,0.7796286507900729,0.9869741201400756,0.0502125173807144,0.2945911879708596,43793.0,0.9861940741539,0.0535595789551734,0.285866172733904,43793.0,53078.12421751022,79968.65912795067,53078.12421751022,26875.5023086071,10.03822422027588,0.0 -163100,0.132291,0.018842919,,,,,,,,,,,,,,,,, -163200,0.13411419,0.018209036,,,,,,,,,,,,,,,,, -163300,0.15195401,0.018211983,,,,,,,,,,,,,,,,, -163400,0.14041239,0.017741598,,,,,,,,,,,,,,,,, -163500,0.15764183,0.018453764,,,,,,,,,,,,,,,,, -163600,0.12734894,0.017425232,,,,,,,,,,,,,,,,, -163700,0.1441953,0.017371017,,,,,,,,,,,,,,,,, -163755,,,0.9955673217773438,0.0139431608840823,0.7821653090591817,0.9869745373725892,0.0502125173807144,0.2944987559590682,43793.0,0.9861940741539,0.0535595789551734,0.2857193972505057,43793.0,53318.33121609688,80315.01131868362,53318.33121609688,26981.566202640533,10.09718656539917,0.0 -163800,0.12270523,0.016433323,,,,,,,,,,,,,,,,, -163900,0.15763533,0.018727936,,,,,,,,,,,,,,,,, -164000,0.144435,0.018165732,,,,,,,,,,,,,,,,, -164100,0.1466623,0.016439183,,,,,,,,,,,,,,,,, -164200,0.14135025,0.017361907,,,,,,,,,,,,,,,,, -164300,0.13998136,0.016992396,,,,,,,,,,,,,,,,, -164400,0.14158757,0.0186859,,,,,,,,,,,,,,,,, -164493,,,0.9955205917358398,0.0140897808596491,0.7764114194363383,0.9869745373725892,0.0502125211060047,0.2945493493391814,43793.0,0.9861940741539,0.0535595789551734,0.2856440025829405,43793.0,53558.39138150215,80656.87715959549,53558.39138150215,27083.28986978531,10.157109260559082,0.0 -164500,0.1459157,0.015775189,,,,,,,,,,,,,,,,, -164600,0.14525695,0.01761613,,,,,,,,,,,,,,,,, -164700,0.16086559,0.018728549,,,,,,,,,,,,,,,,, -164800,0.14478643,0.0186463,,,,,,,,,,,,,,,,, -164900,0.157214,0.020834867,,,,,,,,,,,,,,,,, -165000,0.13949427,0.017370844,,,,,,,,,,,,,,,,, -165100,0.16769753,0.018533083,,,,,,,,,,,,,,,,, -165200,0.12684625,0.01666175,,,,,,,,,,,,,,,,, -165239,,,0.9955911040306092,0.0139289153739809,0.7734966359252158,0.9869745373725892,0.0502125211060047,0.2944849121350764,43793.0,0.9861940741539,0.0535595789551734,0.2857893411828576,43793.0,53798.36969470978,81000.26069259644,53798.36969470978,27186.615166664124,10.215405464172363,0.0 -165300,0.13812576,0.017780947,,,,,,,,,,,,,,,,, -165400,0.16048397,0.019247722,,,,,,,,,,,,,,,,, -165500,0.12552032,0.016684933,,,,,,,,,,,,,,,,, -165600,0.1302246,0.016953932,,,,,,,,,,,,,,,,, -165700,0.13730621,0.018345052,,,,,,,,,,,,,,,,, -165800,0.14435206,0.018612918,,,,,,,,,,,,,,,,, -165900,0.13919479,0.018165246,,,,,,,,,,,,,,,,, -165985,,,0.9955465197563172,0.0139400130137801,0.7729779314410089,0.9869743585586548,0.0502125211060047,0.2944748477438155,43793.0,0.9861940741539,0.0535595826804637,0.2857110507380125,43793.0,54038.34677696228,81350.02177357674,54038.34677696228,27296.31303691864,10.27896499633789,0.0 -166000,0.1547507,0.01888855,,,,,,,,,,,,,,,,, -166100,0.16371456,0.019370798,,,,,,,,,,,,,,,,, -166200,0.13825347,0.016394049,,,,,,,,,,,,,,,,, -166300,0.14639075,0.019092146,,,,,,,,,,,,,,,,, -166400,0.13542116,0.016710624,,,,,,,,,,,,,,,,, -166500,0.16062324,0.019122595,,,,,,,,,,,,,,,,, -166600,0.14524092,0.017936882,,,,,,,,,,,,,,,,, -166700,0.15487625,0.017941503,,,,,,,,,,,,,,,,, -166730,,,0.995519995689392,0.0140833482146263,0.7805190367217666,0.9869743585586548,0.0502125211060047,0.294567484941141,43793.0,0.9861940741539,0.0535595789551734,0.2857211817064314,43793.0,54278.52591109276,81700.15067434311,54278.52591109276,27406.18231940269,10.336283922195436,0.0 -166800,0.13579656,0.019047236,,,,,,,,,,,,,,,,, -166900,0.15677315,0.018493373,,,,,,,,,,,,,,,,, -167000,0.1496376,0.018754948,,,,,,,,,,,,,,,,, -167100,0.15526073,0.01866085,,,,,,,,,,,,,,,,, -167200,0.12850465,0.018398251,,,,,,,,,,,,,,,,, -167300,0.1422759,0.017295057,,,,,,,,,,,,,,,,, -167400,0.1279102,0.017998518,,,,,,,,,,,,,,,,, -167476,,,0.9955546855926514,0.0139647619798779,0.7759126234046201,0.9869745373725892,0.0502125211060047,0.2945054412881348,43793.0,0.9861940741539,0.0535595826804637,0.2857042919315566,43793.0,54518.58748054504,82049.38845849037,54518.58748054504,27515.27748608589,10.39532470703125,0.0 -167500,0.14773224,0.01663568,,,,,,,,,,,,,,,,, -167600,0.15296318,0.019239765,,,,,,,,,,,,,,,,, -167700,0.13335964,0.017240474,,,,,,,,,,,,,,,,, -167800,0.14777125,0.015719168,,,,,,,,,,,,,,,,, -167900,0.1437579,0.01824929,,,,,,,,,,,,,,,,, -168000,0.15333885,0.018066756,,,,,,,,,,,,,,,,, -168100,0.14762619,0.018108634,,,,,,,,,,,,,,,,, -168200,0.14408866,0.016420813,,,,,,,,,,,,,,,,, -168219,,,0.9955841302871704,0.0139201804995536,0.7750355602140353,0.9869743585586548,0.0502125173807144,0.2946438301036884,43793.0,0.9861940741539,0.0535595826804637,0.2856700453934852,43793.0,54758.71322131157,82395.75434350967,54758.71322131157,27621.426381587986,10.46361517906189,0.0 -168300,0.13351348,0.0172403,,,,,,,,,,,,,,,,, -168400,0.14083874,0.018488783,,,,,,,,,,,,,,,,, -168500,0.13553125,0.016563956,,,,,,,,,,,,,,,,, -168600,0.14114015,0.016466787,,,,,,,,,,,,,,,,, -168700,0.1312498,0.016983712,,,,,,,,,,,,,,,,, -168800,0.13549076,0.01753296,,,,,,,,,,,,,,,,, -168900,0.15696119,0.019556668,,,,,,,,,,,,,,,,, -168963,,,0.9955199360847472,0.0140633136034011,0.7707108704904337,0.9869745373725892,0.0502125211060047,0.2944998386340789,43793.0,0.9861940741539,0.0535595789551734,0.2857018252869073,43793.0,54998.81431245804,82742.2187845707,54998.81431245804,27727.70783424377,10.523662567138672,0.0 -169000,0.13758798,0.018030118,,,,,,,,,,,,,,,,, -169100,0.14595115,0.01735609,,,,,,,,,,,,,,,,, -169200,0.13715759,0.017213056,,,,,,,,,,,,,,,,, -169300,0.13539594,0.017463906,,,,,,,,,,,,,,,,, -169400,0.13291031,0.016381321,,,,,,,,,,,,,,,,, -169500,0.16708654,0.019572562,,,,,,,,,,,,,,,,, -169600,0.15226054,0.01751651,,,,,,,,,,,,,,,,, -169700,0.14242466,0.018018642,,,,,,,,,,,,,,,,, -169708,,,0.9955832958221436,0.0138926319777965,0.7774776130313928,0.9869745373725892,0.0502125211060047,0.2945532663143946,43793.0,0.9861940741539,0.0535595826804637,0.285909555773725,43793.0,55238.7580242157,83084.81942462921,55238.7580242157,27830.284168481827,10.582506656646729,0.0 -169800,0.13729435,0.018522348,,,,,,,,,,,,,,,,, -169900,0.1285533,0.01661445,,,,,,,,,,,,,,,,, -170000,0.14149743,0.018966533,,,,,,,,,,,,,,,,, -170100,0.13544881,0.016157953,,,,,,,,,,,,,,,,, -170200,0.14912309,0.01808447,,,,,,,,,,,,,,,,, -170300,0.14637332,0.019403817,,,,,,,,,,,,,,,,, -170306,,,,,,,,,,,,,,55431.08271622658,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index b6387df54..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,172 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -875.231207370758,0.0,37.56307125091553,1,0,37.56307125091553,0.0007088489946909,2.4277196510573813e-10,11.209013938903809,3003,912.794314622879,0.0004393419076222,4.8358820565900894e-11,11.199871063232422,0.0004835649742744,5.65026423809594e-10,11.21283721923828,3000 -1341.4127733707428,0.030313491821289,877.7042384147644,2362,0,877.7042384147644,0.5118703246116638,17.134560690987676,2.911105871200561,3003,2219.220136165619,0.5115175843238831,22.22694865293249,2.875618696212769,0.5132980346679688,18.48454646661626,2.8582029342651367,3000 -1901.0871896743768,0.0534377098083496,1717.85214304924,4724,0,1717.85214304924,0.5931090712547302,21.86833526566419,2.126178741455078,3003,3619.138933420181,0.5785503387451172,27.20758975158053,2.2474136352539062,0.5907180309295654,23.473832158007244,2.1485509872436523,3000 -2338.508580446244,0.0780870914459228,2558.087086677552,7087,0,2558.087086677552,0.6229620575904846,24.41283791906644,1.8805855512619016,3003,4896.891449689865,0.6039042472839355,29.45638549400617,2.021613597869873,0.6163345575332642,25.38209760851271,1.92962908744812,3000 -2790.7796771526337,0.1033430099487304,3398.317975282669,9450,0,3398.317975282669,0.6379989981651306,25.231059478606504,1.7546569108963013,3003,6189.491263628006,0.6103935837745667,29.46675346875256,1.9617842435836792,0.6299116015434265,26.3658578673267,1.8166937828063965,3000 -3252.4657728672028,0.130044937133789,4238.32371544838,11814,0,4238.32371544838,0.6472139954566956,25.66111560996473,1.68716299533844,3003,7491.281229496002,0.6192044615745544,29.88688007428108,1.901554822921753,0.6387893557548523,26.820170193034105,1.749039649963379,3000 -3787.736862421036,0.1561300754547119,5078.380287885666,14176,0,5078.380287885666,0.6528964042663574,26.293182704377568,1.6480978727340698,3003,8866.706029176712,0.6295682191848755,30.651995885245505,1.819696068763733,0.6438357830047607,27.24148470012765,1.710865139961243,3000 -4254.722022771835,0.1817579269409179,5918.573413133621,16538,0,5918.573413133621,0.6548022031784058,26.148766488248864,1.6183347702026367,3003,10173.98059129715,0.6271182894706726,30.539107747251062,1.8289114236831665,0.6476299166679382,27.438749015513427,1.6841576099395752,3000 -4736.290683746338,0.207653522491455,6759.098012447357,18901,0,6759.098012447357,0.6612864136695862,26.80233486357677,1.5909876823425293,3003,11496.171914815905,0.6572504639625549,32.22472034445667,1.6087989807128906,0.651436448097229,27.6112143291353,1.6518349647521973,3000 -5311.4401948452,0.2353913784027099,7599.073549985886,21263,0,7599.073549985886,0.6627622246742249,26.717003601278886,1.5679388046264648,3003,12911.396076202393,0.6349077224731445,30.705749056075486,1.7626363039016724,0.6522423624992371,27.61326392444762,1.6453105211257937,3000 -5869.200685739517,0.2620282173156738,8439.251418828964,23626,0,8439.251418828964,0.6658416390419006,27.20188573853869,1.549932837486267,3003,14309.434327602386,0.6350238919258118,30.85238259132561,1.7771228551864624,0.6539162397384644,27.74533933170228,1.631062388420105,3000 -6335.238257408142,0.2900032997131347,9279.454516649246,25989,0,9279.454516649246,0.6658997535705566,27.17006367961328,1.5409635305404663,3003,15615.774932384493,0.6467454433441162,31.536803516208945,1.6869769096374512,0.6557761430740356,27.694900356711823,1.6167680025100708,3000 -6827.682823181152,0.3175230026245117,10119.372649908066,28352,0,10119.372649908066,0.6687235236167908,27.420597130706373,1.530341863632202,3003,16948.23571062088,0.6393489837646484,30.885567339140408,1.732313871383667,0.6574996113777161,27.800652263048256,1.6025561094284058,3000 -7330.172876596451,0.345379114151001,10959.414959907532,30714,0,10959.414959907532,0.6724769473075867,27.439883614878845,1.5186891555786133,3003,18290.868687152863,0.6419368982315063,31.349138689209163,1.733702540397644,0.6595702171325684,28.190826470172745,1.598743200302124,3000 -7815.368814945221,0.3722670078277588,11799.583815813065,33077,0,11799.583815813065,0.6725350022315979,27.469167647169623,1.505925536155701,3003,19616.332770109177,0.6454497575759888,31.41565148825551,1.700178146362305,0.6618516445159912,28.273850035586964,1.5838629007339478,3000 -8494.342839956284,0.3995921611785888,12639.751176595688,35440,0,12639.751176595688,0.6744408011436462,27.711011086782264,1.50215744972229,3003,21135.57282590866,0.6440091729164124,31.4159123954736,1.7076191902160645,0.6626452207565308,27.312578619277826,1.580370545387268,3000 -9009.498941421509,0.4266171455383301,13479.938393354416,37803,0,13479.938393354416,0.6750218272209167,27.95044166549037,1.493459939956665,3003,22491.015778541565,0.662269115447998,32.45727219264936,1.5625323057174685,0.6636123657226562,28.33616940026225,1.5751655101776123,3000 -9629.737220525742,0.4572136402130127,14319.857029676436,40165,0,14319.857029676436,0.6757073998451233,27.78759015814581,1.4875746965408323,3003,23951.275722503666,0.6452246308326721,31.45640492988856,1.6897084712982178,0.6639595031738281,28.45596245736317,1.5691026449203491,3000 -10120.314247369766,0.4865961074829101,15159.80151104927,42526,0,15159.80151104927,0.6788914203643799,28.226920449596552,1.47401762008667,3003,25281.899538993835,0.642837405204773,31.78450162676224,1.7198340892791748,0.6651126146316528,28.53575327431008,1.5587149858474731,3000 -10656.51029253006,0.519707202911377,15999.772643089294,44888,0,15999.772643089294,0.6773343086242676,28.39207283070988,1.4703198671340942,3003,26658.1717505455,0.6535823345184326,32.10639596896463,1.6201032400131226,0.6656457781791687,28.288459980167566,1.5512747764587402,3000 -11275.247904539108,0.5494298934936523,16839.723083019257,47249,0,16839.723083019257,0.6796002984046936,28.36357916679935,1.4612714052200315,3003,28116.96264028549,0.6474329829216003,31.87093874194392,1.6818699836730957,0.6667864918708801,28.37202242681093,1.546050190925598,3000 -11770.484377622604,0.5856420993804932,17679.671919107437,49611,0,17679.671919107437,0.6794027090072632,28.18185372945468,1.4521712064743042,3003,29452.255368232727,0.6456196904182434,31.93500844978868,1.6879490613937378,0.6677908301353455,28.72653963720832,1.5375560522079468,3000 -12551.531423330309,0.6143500804901123,18519.571103811264,51972,0,18519.571103811264,0.682877242565155,28.490701505346525,1.4454905986785889,3003,31073.305172920227,0.6551151871681213,32.32626977076003,1.6218953132629397,0.6678900122642517,28.88246091612095,1.531269073486328,3000 -13199.815757513046,0.6437675952911377,19359.637244701385,54334,0,19359.637244701385,0.682540237903595,28.32718456710534,1.4420441389083862,3003,32561.75861167908,0.6510263681411743,32.088404891666734,1.655605435371399,0.6696755290031433,28.627795726474464,1.5259546041488647,3000 -13789.664745807648,0.678861141204834,20199.781319618225,56696,0,20199.781319618225,0.6850502490997314,28.62387632662793,1.4306808710098269,3003,33991.85821199417,0.6673276424407959,33.360906251543774,1.5445960760116575,0.6722297072410583,29.03924901341143,1.5170300006866455,3000 -14389.392583847046,0.7105739116668701,21039.77918243408,59058,0,21039.77918243408,0.6850270628929138,28.680309752241467,1.4248459339141846,3003,35431.68813467026,0.6547285318374634,32.10986175858725,1.627714991569519,0.6729241013526917,29.169344862532743,1.5107427835464478,3000 -14981.200040340424,0.7415766716003418,21879.856380939484,61421,0,21879.856380939484,0.6873162388801575,28.77019337846428,1.4158756732940674,3003,36863.674137830734,0.6533963084220886,32.42845141366208,1.6542311906814575,0.6724777221679688,29.03658871071236,1.5102542638778689,3000 -15511.243689775469,0.7795286178588867,22719.75696492195,63782,0,22719.75696492195,0.6861774325370789,28.776615374007704,1.407378315925598,3003,38233.72824931145,0.6634188890457153,32.64206969909681,1.5673255920410156,0.6739531755447388,29.020543803915928,1.4961543083190918,3000 -16046.860683441162,0.8157198429107666,23559.911987304688,66144,0,23559.911987304688,0.6896519660949707,28.88593770250436,1.402874231338501,3003,39609.60906815529,0.6572242975234985,32.916716181700515,1.6180578470230105,0.6756022572517395,29.525804694230857,1.4923619031906128,3000 -16597.32043647766,0.848196268081665,24399.88455271721,68505,0,24399.88455271721,0.691499650478363,28.92139389859646,1.3894480466842651,3003,41000.14904499054,0.6578312516212463,32.58494738936097,1.619662880897522,0.6772885322570801,29.29081053214644,1.481269598007202,3000 -17273.362685203552,0.881028413772583,25239.83142542839,70867,0,25239.83142542839,0.6918947100639343,29.064057246440694,1.386744737625122,3003,42516.241681814194,0.6645039319992065,32.747214248427234,1.563105225563049,0.677238941192627,29.26196255330381,1.4762758016586304,3000 -17925.62251186371,0.9149608612060548,26080.004014968872,73229,0,26080.004014968872,0.6896170973777771,29.036617066834708,1.3845664262771606,3003,44008.78098034859,0.6643652319908142,32.75438032641624,1.574614405632019,0.6768794059753418,29.36006734681592,1.4778293371200562,3000 -18660.99428129196,0.9479324817657472,26920.101365327835,75591,0,26920.101365327835,0.6935797333717346,29.24812489131629,1.3710495233535769,3003,45584.35468482971,0.6765106320381165,33.98489511197334,1.4892783164978027,0.6804627180099487,29.06213421059305,1.4630177021026611,3000 -19492.98771739006,0.9874765872955322,27760.15807056427,77952,0,27760.15807056427,0.6959154009819031,29.535897402171667,1.3664885759353638,3003,47256.51785254479,0.6673470735549927,32.9429847730829,1.5419315099716189,0.680648684501648,28.36214799696568,1.4586224555969238,3000 -20028.57369875908,1.0211431980133057,28600.16921401024,80314,0,28600.16921401024,0.6945790648460388,29.35732571126601,1.360152244567871,3003,48632.22001385689,0.6680836081504822,32.917216136696155,1.5501035451889038,0.6810950636863708,29.69538787912534,1.452638030052185,3000 -20665.47934579849,1.05684494972229,29440.110209703445,82676,0,29440.110209703445,0.6989832520484924,29.86139070986005,1.3479254245758057,3003,50109.174132585526,0.6728290319442749,33.626086977802274,1.5096951723098757,0.6835501194000244,29.914303811257952,1.4427237510681152,3000 -21165.99251294136,1.0917017459869385,30280.33723807335,85039,0,30280.33723807335,0.6989135146141052,29.94726672862276,1.3375647068023682,3003,51450.02022433281,0.6680791974067688,33.650963677786145,1.537145972251892,0.6841328740119934,29.67939265029403,1.439337968826294,3000 -21775.885035037994,1.134784698486328,31120.81729412079,87401,0,31120.81729412079,0.7013421654701233,29.809605628979703,1.3306312561035156,3003,52900.50894832611,0.6687669157981873,33.083875804779154,1.5372868776321411,0.6854843497276306,30.13844842941912,1.4303168058395386,3000 -22333.71812582016,1.1713433265686035,31960.73653626442,89763,0,31960.73653626442,0.7024112939834595,30.16269517032829,1.3254374265670776,3003,54298.36872267723,0.6764796376228333,33.59193421065701,1.4874475002288818,0.686129093170166,30.10132998243684,1.422110080718994,3000 -23042.47913169861,1.2089464664459229,32800.644548892975,92124,0,32800.644548892975,0.7033408880233765,30.099988416489992,1.3194371461868286,3003,55847.14839839935,0.6730408668518066,33.473575376051215,1.5171024799346924,0.6871086359024048,29.50734041863528,1.4207348823547363,3000 -23572.74237060547,1.244916915893555,33640.74346327782,94486,0,33640.74346327782,0.7031317353248596,30.01753382188169,1.31481671333313,3003,57217.61907982826,0.6843194961547852,34.57505063779952,1.438670635223389,0.6879270076751709,30.146262777719443,1.4136199951171875,3000 -24280.65658521652,1.280604362487793,34480.9733145237,96848,0,34480.9733145237,0.7035849094390869,30.46707014258092,1.308138728141785,3003,58765.871055841446,0.678167462348938,33.81288736021676,1.4757825136184692,0.6871830224990845,30.26572122548636,1.4069184064865112,3000 -24843.572281837463,1.3168067932128906,35321.14249563217,99210,0,35321.14249563217,0.7043286561965942,30.263596635691624,1.3000746965408323,3003,60169.06457781792,0.6783736944198608,34.20753458293841,1.4781726598739624,0.6882989406585693,30.174071090892355,1.4027429819107056,3000 -25554.55659222603,1.3547017574310305,36161.22002506256,101572,0,36161.22002506256,0.7056533694267273,30.105600273315,1.2988643646240234,3003,61720.23471450806,0.6858225464820862,34.8028409637508,1.430709719657898,0.689452052116394,29.648971887837902,1.4000771045684814,3000 -26206.101365804672,1.3905627727508545,37001.32061576843,103935,0,37001.32061576843,0.7094184160232544,30.68149962599856,1.2878787517547607,3003,63211.98622179032,0.6802985072135925,34.38016920874783,1.4587417840957642,0.6917335391044617,30.13557535263701,1.3879741430282593,3000 -26817.12211108208,1.4279468059539795,37841.38295006752,106297,0,37841.38295006752,0.7088606357574463,30.694700225262544,1.2818655967712402,3003,64663.177958250046,0.6838071942329407,34.89865767951385,1.439373016357422,0.6919938921928406,30.601454222110416,1.38360857963562,3000 -27514.03621053696,1.465364694595337,38681.60069704056,108660,0,38681.60069704056,0.7099994421005249,30.653324968205062,1.2747658491134644,3003,66200.41820144653,0.6914603114128113,34.7940850778574,1.391718506813049,0.6927502155303955,30.70223529853148,1.3769105672836304,3000 -28090.752275705338,1.504340410232544,39521.62386965752,111022,0,39521.62386965752,0.7101040482521057,30.655470179904714,1.273871898651123,3003,67617.26778793335,0.6853946447372437,34.54282435835048,1.4290424585342407,0.6935685873031616,30.51653350310114,1.3749322891235352,3000 -28774.43212223053,1.543736219406128,40361.59369277954,113384,0,40361.59369277954,0.7107896208763123,30.63459049836028,1.2676246166229248,3003,69141.02885961533,0.6964467763900757,35.61580584857655,1.3651326894760132,0.6939033269882202,30.671236136957305,1.3725615739822388,3000 -29351.802276611328,1.5817480087280271,41201.82719826698,115747,0,41201.82719826698,0.7119516730308533,30.89395735573037,1.2626426219940186,3003,70558.74024915695,0.6904262900352478,35.42563523204688,1.390766739845276,0.693990170955658,30.734451759132792,1.3671605587005615,3000 -29920.225604772568,1.624889850616455,42041.80287861824,118109,0,42041.80287861824,0.7115681767463684,30.788337111092257,1.2619701623916626,3003,71967.25450706482,0.6913244128227234,34.82535315227721,1.3941733837127686,0.6944116950035095,30.55812479814081,1.3669185638427734,3000 -30519.5025408268,1.671435832977295,42881.7198240757,120471,0,42881.7198240757,0.7127418518066406,31.113236211947505,1.2584396600723269,3003,73406.56572580338,0.6960253119468689,35.065851404224986,1.363867998123169,0.6955400109291077,30.73799429326848,1.363057017326355,3000 -31178.701949357983,1.7121977806091309,43721.67870378494,122831,0,43721.67870378494,0.7138341665267944,31.07882520707422,1.256298542022705,3003,74905.83988952637,0.6932611465454102,35.56988353162563,1.3812857866287231,0.6957756280899048,30.736138291930217,1.3611180782318115,3000 -31770.429005146027,1.752969741821289,44561.77640080452,125193,0,44561.77640080452,0.7141479253768921,31.1441452580602,1.2538695335388184,3003,76337.77584266663,0.6958445310592651,35.29065202228168,1.3700186014175415,0.6958004236221313,30.81504363371959,1.3598471879959106,3000 -32368.1502096653,1.7922303676605225,45402.009125709534,127555,0,45402.009125709534,0.7144616842269897,31.0004361617364,1.2528407573699951,3003,77775.84038686752,0.6929237842559814,35.22217671734824,1.3827592134475708,0.6960979700088501,30.751015518353736,1.3584469556808472,3000 -32978.59933829308,1.833895206451416,46242.11846613884,129917,0,46242.11846613884,0.7143106460571289,30.975489200283697,1.2522014379501345,3003,79226.51109528542,0.6957194805145264,35.71389968451469,1.3678160905838013,0.6961227655410767,30.797391608123867,1.358590602874756,3000 -33585.4789557457,1.875113725662232,47082.26250171661,132279,0,47082.26250171661,0.7146244049072266,31.02888880102673,1.2520089149475098,3003,80673.64637875557,0.6954941153526306,35.85878042865687,1.3714572191238403,0.6958624124526978,30.77876783788234,1.3585927486419678,3000 -34185.63210296631,1.917802095413208,47922.3007247448,134640,0,47922.3007247448,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,82113.95312094688,0.6974707841873169,35.34378432455704,1.3609025478363037,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -34787.02144932747,1.959381103515625,48762.44181466103,137001,0,48762.44181466103,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,83555.59596848488,0.699103832244873,35.55390017405729,1.3539986610412598,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -35395.68215799332,2.0005979537963867,49602.64884185791,139363,0,49602.64884185791,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,85004.57704162598,0.6987770199775696,35.64343763817623,1.35377836227417,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -36002.202165842056,2.0436899662017822,50442.791553497314,141725,0,50442.791553497314,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,86451.35402941704,0.6979433298110962,35.16172339456949,1.3597429990768433,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -36608.51600170136,2.087331771850586,51283.00962305069,144087,0,51283.00962305069,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,87898.00025892258,0.6965329051017761,35.24575844551815,1.362047791481018,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -37216.56715321541,2.130098581314087,52123.04042816162,146449,0,52123.04042816162,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,89346.19786715508,0.6960597634315491,35.78783044286703,1.3612914085388184,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -37813.96529150009,2.1822049617767334,52963.1450676918,148811,0,52963.1450676918,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,90783.82259011269,0.6943530440330505,35.51247185279916,1.3751060962677002,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -38417.81984090805,2.22391128540039,53803.25436472893,151174,0,53803.25436472893,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,92227.89867305756,0.6944592595100403,35.85757723474082,1.378005027770996,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -39021.20381522179,2.266636848449707,54643.3279042244,153534,0,54643.3279042244,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,93671.47524333,0.6963521242141724,35.2620739329758,1.3660151958465576,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -39619.74911165237,2.3116683959960938,55483.48570561409,155896,0,55483.48570561409,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,95110.29526019096,0.6948073506355286,35.64392910309191,1.376955270767212,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -40224.372921705246,2.366755485534668,56323.527406692505,158257,0,56323.527406692505,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,96555.0881664753,0.6981350183486938,35.56427327468007,1.3524571657180786,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -40832.7593934536,2.4104344844818115,57163.6197514534,160619,0,57163.6197514534,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,98003.6810424328,0.6961396336555481,35.31053510832538,1.366602063179016,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -41437.19096302986,2.4548990726470947,58003.75202512741,162981,0,58003.75202512741,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,99448.36136484146,0.6966844201087952,35.52774077610771,1.3696157932281494,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -42045.46137666702,2.499814748764038,58843.80774831772,165342,0,58843.80774831772,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,100896.804625988,0.6953707337379456,35.801689686697244,1.377099871635437,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -42654.33971261978,2.5470104217529297,59683.72746658325,167701,0,59683.72746658325,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,102345.72543001176,0.6936147212982178,35.36641362085959,1.382083773612976,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -43241.46246051788,2.592487573623657,60523.73269152641,170063,0,60523.73269152641,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,103772.97212171556,0.6958931684494019,35.463748492037126,1.3654128313064575,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -43844.98872923851,2.6387388706207275,61363.77395796776,172424,0,61363.77395796776,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,105216.657037735,0.6973316669464111,35.595653273466674,1.3603744506835938,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -44447.63086080551,2.6856791973114014,62203.7920062542,174785,0,62203.7920062542,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,106659.43720054626,0.6968789100646973,35.25066549945658,1.3602083921432495,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -45048.39179825783,2.743337154388428,63043.82459115982,177146,0,63043.82459115982,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,108100.36066675186,0.6965775489807129,35.620679245105535,1.3650387525558472,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -45655.30615091324,2.792304277420044,63884.05560970306,179508,0,63884.05560970306,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,109547.62645864488,0.6962001919746399,35.35056694507423,1.3616358041763306,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -46262.27606844902,2.840443849563598,64724.220274209976,181869,0,64724.220274209976,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,110994.8827548027,0.6943336129188538,35.76284631446216,1.380504846572876,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -46864.49913692474,2.888885021209717,65564.20058608055,184230,0,65564.20058608055,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,112437.20860123634,0.6938932538032532,35.454537768174426,1.381688952445984,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -47464.00195264816,2.9393835067749023,66404.26647567749,186591,0,66404.26647567749,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,113876.89898991583,0.6969566345214844,35.72708821302105,1.3639436960220337,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -48076.644453287125,2.9872334003448486,67244.17678833008,188952,0,67244.17678833008,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,115329.57243132593,0.6942804455757141,35.36090369464056,1.3843278884887695,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -48673.42203640938,3.0373005867004395,68084.2696146965,191313,0,68084.2696146965,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,116766.5665898323,0.6944150328636169,35.722617039369105,1.3720909357070925,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -49281.62001180649,3.0965020656585693,68924.22136425972,193674,0,68924.22136425972,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,118214.84686422348,0.6968308687210083,35.36122057156541,1.3575750589370728,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -49898.25337815285,3.146656990051269,69764.31381583214,196035,0,69764.31381583214,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,119671.69554686546,0.6957708597183228,35.55747837773363,1.3637819290161133,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -50501.69980740547,3.2037465572357178,70604.34129023552,198396,0,70604.34129023552,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,121115.29785180092,0.6959868669509888,35.303194536835115,1.3676646947860718,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -51104.10656738281,3.2538113594055176,71444.29108786583,200758,0,71444.29108786583,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,122557.77553725244,0.6991065740585327,35.527593314348735,1.352362036705017,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -51697.62382078171,3.304860591888428,72284.34855413437,203119,0,72284.34855413437,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,123991.47338604929,0.6929601430892944,35.516672633384594,1.3850677013397217,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -52302.75385856629,3.355266809463501,73124.40326523781,205480,0,73124.40326523781,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,125436.77930998802,0.6939802765846252,35.385792217309884,1.375113010406494,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -52920.74584150314,3.409360647201538,73964.36746478081,207841,0,73964.36746478081,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,126894.86238741876,0.697274386882782,35.38380765178742,1.3620195388793943,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -53527.97238135338,3.470571994781494,74804.41276717186,210201,0,74804.41276717186,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,128342.26640844344,0.6986714601516724,35.170028935962286,1.350701928138733,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -54130.06509780884,3.5222976207733154,75644.38429164886,212562,0,75644.38429164886,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,129784.45477199554,0.699073314666748,35.37286067844723,1.3556746244430542,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -54745.98873233795,3.574280500411988,76484.49631166458,214923,0,76484.49631166458,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,131240.6146156788,0.6939623951911926,35.34666481423096,1.3788866996765137,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -55360.05347490311,3.6353468894958496,77324.43882918358,217284,0,77324.43882918358,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,132694.75473570824,0.6991239786148071,35.3880923790759,1.345736384391785,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -55964.84424567223,3.688246011734009,78164.45926618576,219645,0,78164.45926618576,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,134139.69259738922,0.6964263319969177,35.6649330863915,1.3674417734146118,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -56581.09203457832,3.739297866821289,79004.4024219513,222006,0,79004.4024219513,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,135596.00564146042,0.6972445249557495,35.77868590202388,1.3616267442703247,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -57194.87903165817,3.7929728031158447,79844.48212599754,224367,0,79844.48212599754,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,137049.9975218773,0.6952289342880249,35.30621007605715,1.3711453676223757,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -57798.66555047035,3.8469254970550537,80684.43466353416,226727,0,80684.43466353416,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,138493.86475038528,0.6972992420196533,35.395366730208174,1.3629869222640991,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -58390.79001760483,3.900733709335327,81524.6353867054,229088,0,81524.6353867054,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,139926.31528425217,0.6930255889892578,35.532078761547865,1.3843538761138916,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -59005.66600751877,3.956840753555298,82364.74400091171,231448,0,82364.74400091171,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,141381.4314494133,0.697636067867279,35.64368150155568,1.3580584526062012,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -59604.526320934296,4.011416673660278,83204.71192860603,233808,0,83204.71192860603,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,142820.386374712,0.6989806294441223,35.61818829777432,1.349443793296814,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -60196.6534178257,4.07750678062439,84044.6024339199,236168,0,84044.6024339199,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,144252.54349827766,0.6970456838607788,35.49304734732505,1.3616652488708496,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -60794.35615229607,4.133155822753906,84884.70688819885,238529,0,84884.70688819885,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,145690.4778895378,0.6947537064552307,35.53998794997194,1.375749588012695,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -61401.30293893814,4.189852237701416,85724.79488253593,240890,0,85724.79488253593,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,147137.6408853531,0.6979270577430725,35.702276531386154,1.3634157180786133,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -62015.22755002976,4.245082378387451,86564.79521155357,243250,0,86564.79521155357,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,148591.69333863258,0.6946918368339539,35.802417390103905,1.3715825080871582,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -62615.89183759689,4.302200794219971,87404.78988575935,245610,0,87404.78988575935,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,150032.48102235794,0.6926801204681396,35.56962577221613,1.3849468231201172,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -63219.3861579895,4.359274387359619,88244.95991444588,247970,0,88244.95991444588,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,151476.27511763573,0.6970303654670715,35.03329818639915,1.3621553182601929,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -63823.59740900993,4.417787790298462,89084.85697770119,250331,0,89084.85697770119,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,152920.51184105873,0.6966925263404846,35.4255910535756,1.3629604578018188,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -64430.72727441788,4.476908206939697,89924.99598526955,252691,0,89924.99598526955,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,154367.91337037086,0.693827211856842,35.42538705946164,1.379385232925415,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -65023.09778881073,4.534471273422241,90765.08199381828,255051,0,90765.08199381828,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,155800.5014474392,0.6940208673477173,35.44952550740223,1.374923586845398,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -65632.37321901321,4.591864347457886,91605.0241189003,257412,0,91605.0241189003,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,157249.84708356857,0.6943745017051697,35.384300906912,1.3740166425704956,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -66232.347230196,4.648017883300781,92445.21916270256,259773,0,92445.21916270256,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,158690.14498519895,0.6928779482841492,35.38147552282423,1.3875993490219116,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -66834.46612238884,4.7053093910217285,93285.35460090636,262134,0,93285.35460090636,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,160132.52825331688,0.6953383088111877,35.612914454909664,1.3716511726379397,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -67438.93201184273,4.778682708740234,94125.36615252496,264494,0,94125.36615252496,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,161577.15155768394,0.6960185170173645,35.742916565992346,1.37083899974823,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -68050.00629425049,4.836698532104492,94965.49518156052,266855,0,94965.49518156052,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,163028.48417282104,0.6939523220062256,35.54363592145518,1.3873682022094729,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -68652.84652853012,4.907514810562134,95805.45388770103,269215,0,95805.45388770103,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,164471.42721128464,0.6959852576255798,35.7707856211828,1.371500015258789,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -69258.74201393127,4.968581914901733,96645.50051903725,271576,0,96645.50051903725,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,165917.50253725052,0.694656491279602,35.713971528505866,1.378628492355347,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -69856.77969145775,5.073653221130371,97485.64288377762,273937,0,97485.64288377762,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,167355.8598549366,0.6959747076034546,35.55752862460547,1.3692622184753418,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -70456.73684310913,5.134132146835327,98325.51252913476,276296,0,98325.51252913476,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,168795.82131123543,0.6965546607971191,35.67617262040431,1.3616505861282349,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -71057.9582953453,5.194442510604858,99165.73145341872,278656,0,99165.73145341872,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,170237.39524626732,0.6967638731002808,35.39564349576458,1.3642253875732422,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -71652.07466721535,5.26610803604126,100005.89224982262,281016,0,100005.89224982262,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,171671.8169312477,0.6934937834739685,35.466408250843514,1.3861702680587769,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -72242.53067946434,5.33665919303894,100845.91227436066,283376,0,100845.91227436066,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,173102.43675160408,0.6972944736480713,35.66149794669415,1.3638888597488403,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -72841.66510772705,5.397650003433228,101685.87846279144,285737,0,101685.87846279144,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,174541.66915917397,0.6911959052085876,36.23396595613439,1.395319581031799,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -73444.53052544594,5.460918426513672,102525.98865199088,288099,0,102525.98865199088,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,175984.77796077728,0.6965842843055725,35.65843571065414,1.3646329641342163,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -74059.78676986694,5.522222518920898,103365.85811543465,290459,0,103365.85811543465,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,177440.03746771812,0.696201503276825,35.46740520559859,1.3672454357147217,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -74656.53228163719,5.5849597454071045,104205.91893053056,292820,0,104205.91893053056,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,178876.97866034508,0.695793867111206,35.59809594260861,1.368704319000244,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -75251.69767832756,5.647981643676758,105045.99169325829,295180,0,105045.99169325829,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,180312.35103321075,0.6961166858673096,35.2200033761678,1.3674596548080444,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -75853.98563599586,5.709059000015259,105885.86472034454,297540,0,105885.86472034454,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,181754.6443226337,0.6968651413917542,35.314651608415545,1.370174765586853,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -76456.5664255619,5.785603046417236,106725.71911907196,299900,0,106725.71911907196,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,183197.226900816,0.6927401423454285,35.545994958747805,1.3830639123916626,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -77054.19925522804,5.850619554519653,107565.61857748032,302260,0,107565.61857748032,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,184634.89792728424,0.6952672600746155,35.55978230786836,1.3717122077941897,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -77651.99628734589,5.915649890899658,108405.63617539406,304620,0,108405.63617539406,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,186072.84896922112,0.696880578994751,34.97261335988815,1.362842679023743,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -78251.87782359123,5.980376243591309,109245.5084617138,306979,0,109245.5084617138,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,187512.7414045334,0.6959228515625,35.62716315697208,1.370876431465149,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -78865.54597783089,6.046472072601318,110085.39775514604,309339,0,110085.39775514604,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,188966.4353232384,0.6993377804756165,35.87252412013725,1.3491086959838867,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -79461.58521485329,6.110290288925171,110925.37493872644,311699,0,110925.37493872644,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,190402.58885407448,0.6996234059333801,35.82919492644673,1.3519960641860962,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -80065.39432311058,6.174302101135254,111765.5486676693,314060,0,111765.5486676693,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,191846.70589494705,0.6977344155311584,35.35065572173739,1.3559484481811523,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -80676.55667924881,6.251188278198242,112605.41333723068,316420,0,112605.41333723068,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,193297.88295459747,0.6962667107582092,35.69826971741045,1.3646796941757202,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -81277.62100315094,6.328147649765015,113445.30772137642,318780,0,113445.30772137642,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,194738.9902229309,0.6989819407463074,35.58819065628888,1.3501715660095217,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -81881.28680968285,6.394540309906006,114285.18418955804,321140,0,114285.18418955804,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,196182.6717071533,0.6966158151626587,35.62989032088873,1.3639127016067505,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -82497.62356734276,6.459453105926514,115125.39649248125,323501,0,115125.39649248125,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,197639.3575992584,0.6984859704971313,35.67098677731126,1.3505821228027344,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -83107.43023467064,6.526806592941284,115965.46018266678,325862,0,115965.46018266678,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,199089.3668017388,0.6952317357063293,35.218722071193795,1.3720507621765137,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -83716.94771122932,6.594550371170044,116805.6553747654,328223,0,116805.6553747654,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,200539.21746993065,0.6982913613319397,35.58612960215996,1.3581576347351074,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -84320.30842804909,6.6605446338653564,117645.53719234468,330584,0,117645.53719234468,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,201982.5979168415,0.6947457790374756,35.72883824812817,1.3779659271240234,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -84918.95106649399,6.728163719177246,118485.5244398117,332945,0,118485.5244398117,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,203421.36900758743,0.6956863403320312,35.02622983467754,1.3694723844528198,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -85521.2403678894,6.795787334442139,119325.42063117027,335305,0,119325.42063117027,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,204863.69330143929,0.695704460144043,35.480177783502256,1.3688232898712158,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -86120.65092658997,6.865016460418701,120165.28549027444,337665,0,120165.28549027444,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,206303.11276698112,0.6960183382034302,35.55197694210756,1.3696424961090088,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -86741.76490736008,6.934324264526367,121005.14982128143,340025,0,121005.14982128143,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,207764.23371696472,0.6936178803443909,35.268918831625335,1.38359534740448,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -87348.51953816414,7.001046895980835,121845.14006876944,342385,0,121845.14006876944,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,209211.1178805828,0.6969581842422485,35.45381204623339,1.3596528768539429,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -87956.62901735306,7.070030212402344,122685.2105679512,344746,0,122685.2105679512,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,210659.43799066544,0.694857656955719,35.20722528657935,1.3747631311416626,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -88567.57115674019,7.1397809982299805,123525.30701708794,347107,0,123525.30701708794,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,212110.61843252185,0.6944318413734436,35.670044646723696,1.3756641149520874,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -89163.2274723053,7.209471464157104,124365.3011341095,349467,0,124365.3011341095,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,213546.4123835564,0.6982336044311523,35.713977052649426,1.3613815307617188,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -89769.20688509941,7.279955148696899,125205.46908950806,351827,0,125205.46908950806,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,214992.70389819145,0.6962584853172302,35.62077978803877,1.368923902511597,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -90371.01800084114,7.350381135940552,126045.5473601818,354189,0,126045.5473601818,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,216434.73520970345,0.6944153308868408,35.38775796358296,1.3725401163101196,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -90981.5489873886,7.420124530792236,126885.76126217842,356551,0,126885.76126217842,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,217885.62106704712,0.6954219341278076,35.53209505714748,1.3780286312103271,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -91577.32687735558,7.504851579666138,127725.61231327055,358911,0,127725.61231327055,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,219321.40892481804,0.6971335411071777,35.45657819681476,1.3615541458129885,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -92187.9036836624,7.576892137527466,128565.52987861632,361271,0,128565.52987861632,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,220772.04701256752,0.69556725025177,35.47585553964436,1.368157982826233,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -92778.26169657709,7.64824366569519,129405.45047450066,363631,0,129405.45047450066,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,222202.46889948845,0.6939976811408997,35.58371781025625,1.379857063293457,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -93371.30241632462,7.723176717758179,130245.50266051292,365992,0,130245.50266051292,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,223635.70874118805,0.6946594715118408,35.4348874325702,1.3834761381149292,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -93975.18136429788,7.796278238296509,131085.41303038597,368351,0,131085.41303038597,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,225079.64515781403,0.6929248571395874,35.70822070105537,1.3840899467468262,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -94576.20677280426,7.871663808822632,131925.42628622055,370711,0,131925.42628622055,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,226520.8309454918,0.6935230493545532,35.63322774950154,1.3815573453903198,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -95171.61591076852,7.943910598754883,132765.4077756405,373071,0,132765.4077756405,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,227956.367408514,0.693621814250946,35.99297319457071,1.3835389614105225,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -95771.58613538742,8.019490957260132,133605.52108120918,375431,0,133605.52108120918,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,229396.5992236137,0.6953145861625671,35.52966497028867,1.371348857879639,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -96367.69637274742,8.102221727371216,134445.5572388172,377790,0,134445.5572388172,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,230832.9020171165,0.6948238611221313,35.914049220026286,1.373692512512207,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -96965.33767175674,8.175270795822144,135285.55623698235,380149,0,135285.55623698235,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,232270.6904451847,0.6977683305740356,35.69085706221976,1.3623616695404053,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -97565.20195937157,8.250718832015991,136125.55805802345,382509,0,136125.55805802345,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,233710.7053785324,0.6954215168952942,35.408766270120154,1.3679215908050537,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -98171.5267598629,8.339428424835205,136965.75103139877,384870,0,136965.75103139877,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,235157.38303089145,0.6941781640052795,35.4485266232941,1.3764313459396362,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -98780.79803609848,8.41462779045105,137805.83716368675,387230,0,137805.83716368675,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,236606.8883280754,0.6959114074707031,35.79706608134528,1.3697524070739746,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -99370.363945961,8.490809679031372,138645.8393995762,389591,0,138645.8393995762,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,238036.6040613652,0.6969903111457825,35.59087751015178,1.3658790588378906,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -99969.31581687929,8.585108041763306,139485.86350560188,391951,0,139485.86350560188,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,239475.74658942223,0.6969181299209595,35.6003938792746,1.3643102645874023,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -100582.15226006508,8.660670042037964,140325.8711452484,394311,0,140325.8711452484,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,240928.74007606503,0.6953207850456238,35.67714170119371,1.3696646690368652,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -101183.69587278366,8.738826513290405,141165.98493552208,396671,0,141165.98493552208,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,242370.5478367805,0.6944743990898132,35.42201402751068,1.3710854053497314,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -101786.26503157616,8.816017627716064,142006.1650776863,399032,0,142006.1650776863,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,243813.446236372,0.6948761343955994,35.58164949728439,1.3741616010665894,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 -102384.04146027565,8.893699645996094,142350.1185748577,399999,0,142350.1185748577,0.7146011590957642,31.016732013552563,1.2520129680633545,3003,244755.2840101719,0.6971225738525391,35.5268341500284,1.3681684732437134,0.6957632303237915,30.80970163535293,1.3586411476135254,3000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 5248ad0e7..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,4173 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.213356,11.182037,,,,,,,,,,,,,,,,, -1,,,0.0004393419076222,11.199871063232422,4.8358820565900894e-11,0.0004835649742744,11.21283721923828,5.65026423809594e-10,3000.0,0.0007088489946909,11.209013938903809,2.4277196510573813e-10,3003.0,37.56307125091553,912.794314622879,37.56307125091553,875.231207370758,0.0,0.0 -100,0.14693731,8.213344,,,,,,,,,,,,,,,,, -200,0.3308044,7.509773,,,,,,,,,,,,,,,,, -300,0.5236129,6.851391,,,,,,,,,,,,,,,,, -400,0.44578367,6.305877,,,,,,,,,,,,,,,,, -500,0.52414525,5.9173756,,,,,,,,,,,,,,,,, -600,0.38458338,5.5770607,,,,,,,,,,,,,,,,, -700,0.6974736,5.3608584,,,,,,,,,,,,,,,,, -800,0.6120168,5.059975,,,,,,,,,,,,,,,,, -900,0.58503926,4.827595,,,,,,,,,,,,,,,,, -1000,0.580886,4.5902267,,,,,,,,,,,,,,,,, -1100,0.51060444,4.3634744,,,,,,,,,,,,,,,,, -1200,0.47316077,3.9739826,,,,,,,,,,,,,,,,, -1300,0.5378529,3.9052978,,,,,,,,,,,,,,,,, -1400,0.5965867,3.8373806,,,,,,,,,,,,,,,,, -1500,0.45145413,3.572915,,,,,,,,,,,,,,,,, -1600,0.49245533,3.443701,,,,,,,,,,,,,,,,, -1700,0.45678315,3.3697798,,,,,,,,,,,,,,,,, -1800,0.39821684,3.261728,,,,,,,,,,,,,,,,, -1900,0.42257917,3.2454245,,,,,,,,,,,,,,,,, -2000,0.5098632,3.2186077,,,,,,,,,,,,,,,,, -2100,0.39229417,3.0240798,,,,,,,,,,,,,,,,, -2200,0.41104838,3.0622473,,,,,,,,,,,,,,,,, -2300,0.36739907,3.0404456,,,,,,,,,,,,,,,,, -2362,,,0.5115175843238831,2.875618696212769,22.22694865293249,0.5132980346679688,2.8582029342651367,18.48454646661626,3000.0,0.5118703246116638,2.911105871200561,17.134560690987676,3003.0,877.7042384147644,2219.220136165619,877.7042384147644,1341.4127733707428,0.030313491821289,0.0 -2400,0.4015044,2.9275467,,,,,,,,,,,,,,,,, -2500,0.32818666,2.9162471,,,,,,,,,,,,,,,,, -2600,0.27551764,2.9021835,,,,,,,,,,,,,,,,, -2700,0.39139497,2.8334308,,,,,,,,,,,,,,,,, -2800,0.3412123,2.8206105,,,,,,,,,,,,,,,,, -2900,0.38241228,2.8377082,,,,,,,,,,,,,,,,, -3000,0.24998762,2.720911,,,,,,,,,,,,,,,,, -3100,0.28286573,2.6616883,,,,,,,,,,,,,,,,, -3200,0.38171792,2.6036963,,,,,,,,,,,,,,,,, -3300,0.21966009,2.5814257,,,,,,,,,,,,,,,,, -3400,0.24228534,2.6005118,,,,,,,,,,,,,,,,, -3500,0.19342835,2.4536471,,,,,,,,,,,,,,,,, -3600,0.22635147,2.5157945,,,,,,,,,,,,,,,,, -3700,0.25250697,2.512128,,,,,,,,,,,,,,,,, -3800,0.2181461,2.4963143,,,,,,,,,,,,,,,,, -3900,0.19793867,2.523616,,,,,,,,,,,,,,,,, -4000,0.17151879,2.5724604,,,,,,,,,,,,,,,,, -4100,0.19049057,2.401033,,,,,,,,,,,,,,,,, -4200,0.1918675,2.3807912,,,,,,,,,,,,,,,,, -4300,0.1770359,2.3606863,,,,,,,,,,,,,,,,, -4400,0.17523235,2.3947759,,,,,,,,,,,,,,,,, -4500,0.17206927,2.2635236,,,,,,,,,,,,,,,,, -4600,0.15057737,2.3321402,,,,,,,,,,,,,,,,, -4700,0.17554437,2.3017306,,,,,,,,,,,,,,,,, -4724,,,0.5785503387451172,2.2474136352539062,27.20758975158053,0.5907180309295654,2.1485509872436523,23.473832158007244,3000.0,0.5931090712547302,2.126178741455078,21.86833526566419,3003.0,1717.85214304924,3619.138933420181,1717.85214304924,1901.0871896743768,0.0534377098083496,0.0 -4800,0.17055821,2.3320284,,,,,,,,,,,,,,,,, -4900,0.16951783,2.3209202,,,,,,,,,,,,,,,,, -5000,0.15011258,2.248833,,,,,,,,,,,,,,,,, -5100,0.24967873,2.299984,,,,,,,,,,,,,,,,, -5200,0.15872857,2.2466493,,,,,,,,,,,,,,,,, -5300,0.15323399,2.2679071,,,,,,,,,,,,,,,,, -5400,0.16055706,2.2178032,,,,,,,,,,,,,,,,, -5500,0.20222929,2.3677273,,,,,,,,,,,,,,,,, -5600,0.17138265,2.2439172,,,,,,,,,,,,,,,,, -5700,0.20206453,2.3376725,,,,,,,,,,,,,,,,, -5800,0.16177407,2.1390314,,,,,,,,,,,,,,,,, -5900,0.17332532,2.192301,,,,,,,,,,,,,,,,, -6000,0.17457327,2.3280134,,,,,,,,,,,,,,,,, -6100,0.17616509,2.2061253,,,,,,,,,,,,,,,,, -6200,0.16329455,2.205389,,,,,,,,,,,,,,,,, -6300,0.13918546,2.1672442,,,,,,,,,,,,,,,,, -6400,0.1616085,2.1557803,,,,,,,,,,,,,,,,, -6500,0.20364232,2.180103,,,,,,,,,,,,,,,,, -6600,0.17310186,2.223889,,,,,,,,,,,,,,,,, -6700,0.2161158,2.1202698,,,,,,,,,,,,,,,,, -6800,0.14696479,2.1739635,,,,,,,,,,,,,,,,, -6900,0.16645156,2.1262388,,,,,,,,,,,,,,,,, -7000,0.21592905,2.15931,,,,,,,,,,,,,,,,, -7087,,,0.6039042472839355,2.021613597869873,29.45638549400617,0.6163345575332642,1.92962908744812,25.38209760851271,3000.0,0.6229620575904846,1.8805855512619016,24.41283791906644,3003.0,2558.087086677552,4896.891449689865,2558.087086677552,2338.508580446244,0.0780870914459228,0.0 -7100,0.14433536,2.1321495,,,,,,,,,,,,,,,,, -7200,0.16732867,2.1866086,,,,,,,,,,,,,,,,, -7300,0.15301861,2.0481124,,,,,,,,,,,,,,,,, -7400,0.15198383,2.1141803,,,,,,,,,,,,,,,,, -7500,0.16384356,2.0375922,,,,,,,,,,,,,,,,, -7600,0.16112939,2.1664407,,,,,,,,,,,,,,,,, -7700,0.19167738,2.1162634,,,,,,,,,,,,,,,,, -7800,0.17820911,2.0077562,,,,,,,,,,,,,,,,, -7900,0.17402951,2.0809054,,,,,,,,,,,,,,,,, -8000,0.23954074,2.1382868,,,,,,,,,,,,,,,,, -8100,0.1644447,2.074287,,,,,,,,,,,,,,,,, -8200,0.1770788,2.1647322,,,,,,,,,,,,,,,,, -8300,0.17437102,2.0921602,,,,,,,,,,,,,,,,, -8400,0.17391573,2.0843744,,,,,,,,,,,,,,,,, -8500,0.15053686,2.0653703,,,,,,,,,,,,,,,,, -8600,0.17241324,1.9338801,,,,,,,,,,,,,,,,, -8700,0.20749907,2.0957072,,,,,,,,,,,,,,,,, -8800,0.1878675,1.9957606,,,,,,,,,,,,,,,,, -8900,0.19806162,2.0873227,,,,,,,,,,,,,,,,, -9000,0.16026773,2.0440967,,,,,,,,,,,,,,,,, -9100,0.14621471,2.0887108,,,,,,,,,,,,,,,,, -9200,0.1493739,2.1283975,,,,,,,,,,,,,,,,, -9300,0.15271778,1.9844043,,,,,,,,,,,,,,,,, -9400,0.15412582,2.0074396,,,,,,,,,,,,,,,,, -9450,,,0.6103935837745667,1.9617842435836792,29.46675346875256,0.6299116015434265,1.8166937828063965,26.3658578673267,3000.0,0.6379989981651306,1.7546569108963013,25.231059478606504,3003.0,3398.317975282669,6189.491263628006,3398.317975282669,2790.7796771526337,0.1033430099487304,0.0 -9500,0.1681224,2.0586605,,,,,,,,,,,,,,,,, -9600,0.16703922,2.00917,,,,,,,,,,,,,,,,, -9700,0.22244903,1.9465091,,,,,,,,,,,,,,,,, -9800,0.15180221,1.9943324,,,,,,,,,,,,,,,,, -9900,0.18706626,2.026512,,,,,,,,,,,,,,,,, -10000,0.15038015,2.0003083,,,,,,,,,,,,,,,,, -10100,0.14634877,2.050979,,,,,,,,,,,,,,,,, -10200,0.19344024,1.9942619,,,,,,,,,,,,,,,,, -10300,0.15391438,1.9268517,,,,,,,,,,,,,,,,, -10400,0.15159766,2.0032847,,,,,,,,,,,,,,,,, -10500,0.21446058,2.0200691,,,,,,,,,,,,,,,,, -10600,0.2050654,1.9656118,,,,,,,,,,,,,,,,, -10700,0.16286685,2.0240798,,,,,,,,,,,,,,,,, -10800,0.16168678,2.0707083,,,,,,,,,,,,,,,,, -10900,0.17049249,2.042998,,,,,,,,,,,,,,,,, -11000,0.16083913,2.059155,,,,,,,,,,,,,,,,, -11100,0.18689409,1.9794635,,,,,,,,,,,,,,,,, -11200,0.31008148,1.9307448,,,,,,,,,,,,,,,,, -11300,0.16759764,2.03914,,,,,,,,,,,,,,,,, -11400,0.18001127,1.9416867,,,,,,,,,,,,,,,,, -11500,0.17704807,1.9631844,,,,,,,,,,,,,,,,, -11600,0.18839425,1.996102,,,,,,,,,,,,,,,,, -11700,0.1809571,1.9306686,,,,,,,,,,,,,,,,, -11800,0.15492542,1.9069272,,,,,,,,,,,,,,,,, -11814,,,0.6192044615745544,1.901554822921753,29.88688007428108,0.6387893557548523,1.749039649963379,26.820170193034105,3000.0,0.6472139954566956,1.68716299533844,25.66111560996473,3003.0,4238.32371544838,7491.281229496002,4238.32371544838,3252.4657728672028,0.130044937133789,0.0 -11900,0.30526748,2.0075524,,,,,,,,,,,,,,,,, -12000,0.20215201,1.9700801,,,,,,,,,,,,,,,,, -12100,0.25665227,1.9728769,,,,,,,,,,,,,,,,, -12200,0.18230799,1.9587096,,,,,,,,,,,,,,,,, -12300,0.2256921,2.0240896,,,,,,,,,,,,,,,,, -12400,0.16637217,1.8623705,,,,,,,,,,,,,,,,, -12500,0.20569064,2.067415,,,,,,,,,,,,,,,,, -12600,0.2204573,1.947272,,,,,,,,,,,,,,,,, -12700,0.17333552,1.9605399,,,,,,,,,,,,,,,,, -12800,0.21828708,1.9394834,,,,,,,,,,,,,,,,, -12900,0.26568595,1.9247277,,,,,,,,,,,,,,,,, -13000,0.25904965,1.9953908,,,,,,,,,,,,,,,,, -13100,0.17643598,1.9057902,,,,,,,,,,,,,,,,, -13200,0.18929596,1.9816093,,,,,,,,,,,,,,,,, -13300,0.18781276,1.9065386,,,,,,,,,,,,,,,,, -13400,0.19039223,1.8965623,,,,,,,,,,,,,,,,, -13500,0.1572901,1.9757812,,,,,,,,,,,,,,,,, -13600,0.16981903,1.9115247,,,,,,,,,,,,,,,,, -13700,0.18655059,1.9722239,,,,,,,,,,,,,,,,, -13800,0.19182302,1.8525635,,,,,,,,,,,,,,,,, -13900,0.19975221,1.9581745,,,,,,,,,,,,,,,,, -14000,0.19370823,1.98123,,,,,,,,,,,,,,,,, -14100,0.18766604,1.9078395,,,,,,,,,,,,,,,,, -14176,,,0.6295682191848755,1.819696068763733,30.651995885245505,0.6438357830047607,1.710865139961243,27.24148470012765,3000.0,0.6528964042663574,1.6480978727340698,26.293182704377568,3003.0,5078.380287885666,8866.706029176712,5078.380287885666,3787.736862421036,0.1561300754547119,0.0 -14200,0.21544836,1.9235126,,,,,,,,,,,,,,,,, -14300,0.23084043,1.9033622,,,,,,,,,,,,,,,,, -14400,0.19857608,1.9858909,,,,,,,,,,,,,,,,, -14500,0.19787608,1.9980928,,,,,,,,,,,,,,,,, -14600,0.18810761,1.9903506,,,,,,,,,,,,,,,,, -14700,0.15759508,1.9240509,,,,,,,,,,,,,,,,, -14800,0.18183172,1.9188054,,,,,,,,,,,,,,,,, -14900,0.25845745,1.9747562,,,,,,,,,,,,,,,,, -15000,0.18162069,1.8969551,,,,,,,,,,,,,,,,, -15100,0.18987554,1.9105936,,,,,,,,,,,,,,,,, -15200,0.22309344,1.9144714,,,,,,,,,,,,,,,,, -15300,0.17953466,1.9824002,,,,,,,,,,,,,,,,, -15400,0.184812,1.8914713,,,,,,,,,,,,,,,,, -15500,0.1836413,1.8195907,,,,,,,,,,,,,,,,, -15600,0.18071926,1.9809167,,,,,,,,,,,,,,,,, -15700,0.21625064,1.873389,,,,,,,,,,,,,,,,, -15800,0.24638654,1.9125706,,,,,,,,,,,,,,,,, -15900,0.16383418,1.9377257,,,,,,,,,,,,,,,,, -16000,0.15620105,1.9520093,,,,,,,,,,,,,,,,, -16100,0.1869788,1.9283726,,,,,,,,,,,,,,,,, -16200,0.25324425,1.9120023,,,,,,,,,,,,,,,,, -16300,0.24903613,1.9716882,,,,,,,,,,,,,,,,, -16400,0.21966262,1.9125446,,,,,,,,,,,,,,,,, -16500,0.22532998,1.9087907,,,,,,,,,,,,,,,,, -16538,,,0.6271182894706726,1.8289114236831665,30.539107747251062,0.6476299166679382,1.6841576099395752,27.438749015513427,3000.0,0.6548022031784058,1.6183347702026367,26.148766488248864,3003.0,5918.573413133621,10173.98059129715,5918.573413133621,4254.722022771835,0.1817579269409179,0.0 -16600,0.21413328,1.8647472,,,,,,,,,,,,,,,,, -16700,0.19346274,1.8804988,,,,,,,,,,,,,,,,, -16800,0.17886287,1.8128256,,,,,,,,,,,,,,,,, -16900,0.17716044,1.8690698,,,,,,,,,,,,,,,,, -17000,0.23046364,1.9321219,,,,,,,,,,,,,,,,, -17100,0.20217843,1.9770832,,,,,,,,,,,,,,,,, -17200,0.19554064,1.8784038,,,,,,,,,,,,,,,,, -17300,0.1845184,1.9207172,,,,,,,,,,,,,,,,, -17400,0.18460937,1.902699,,,,,,,,,,,,,,,,, -17500,0.17517634,1.9014817,,,,,,,,,,,,,,,,, -17600,0.18588786,1.935046,,,,,,,,,,,,,,,,, -17700,0.1897626,1.8800188,,,,,,,,,,,,,,,,, -17800,0.18667452,1.8795738,,,,,,,,,,,,,,,,, -17900,0.2372645,1.9325886,,,,,,,,,,,,,,,,, -18000,0.19429001,1.7970836,,,,,,,,,,,,,,,,, -18100,0.20613569,1.8667059,,,,,,,,,,,,,,,,, -18200,0.19756295,1.9221121,,,,,,,,,,,,,,,,, -18300,0.24390294,1.8722297,,,,,,,,,,,,,,,,, -18400,0.19496104,1.9087383,,,,,,,,,,,,,,,,, -18500,0.21120723,1.8618526,,,,,,,,,,,,,,,,, -18600,0.20070262,1.8672738,,,,,,,,,,,,,,,,, -18700,0.18743382,1.7707385,,,,,,,,,,,,,,,,, -18800,0.1792774,1.8199931,,,,,,,,,,,,,,,,, -18900,0.20843945,1.866677,,,,,,,,,,,,,,,,, -18901,,,0.6572504639625549,1.6087989807128906,32.22472034445667,0.651436448097229,1.6518349647521973,27.6112143291353,3000.0,0.6612864136695862,1.5909876823425293,26.80233486357677,3003.0,6759.098012447357,11496.171914815905,6759.098012447357,4736.290683746338,0.207653522491455,0.0 -19000,0.19638608,1.8240572,,,,,,,,,,,,,,,,, -19100,0.16740961,1.7921987,,,,,,,,,,,,,,,,, -19200,0.24412943,1.9038028,,,,,,,,,,,,,,,,, -19300,0.18421437,1.8327869,,,,,,,,,,,,,,,,, -19400,0.17896067,1.8454554,,,,,,,,,,,,,,,,, -19500,0.18607603,1.8829699,,,,,,,,,,,,,,,,, -19600,0.23161633,1.8870678,,,,,,,,,,,,,,,,, -19700,0.21330301,1.8527013,,,,,,,,,,,,,,,,, -19800,0.16851233,1.8884197,,,,,,,,,,,,,,,,, -19900,0.16705291,1.9087243,,,,,,,,,,,,,,,,, -20000,0.1877993,1.9265275,,,,,,,,,,,,,,,,, -20100,0.17922884,1.7634851,,,,,,,,,,,,,,,,, -20200,0.18942547,1.8566331,,,,,,,,,,,,,,,,, -20300,0.21072689,1.8759878,,,,,,,,,,,,,,,,, -20400,0.17689423,1.8912635,,,,,,,,,,,,,,,,, -20500,0.2203175,1.8563032,,,,,,,,,,,,,,,,, -20600,0.2591536,1.7606268,,,,,,,,,,,,,,,,, -20700,0.20333755,1.8625963,,,,,,,,,,,,,,,,, -20800,0.24110709,1.8522544,,,,,,,,,,,,,,,,, -20900,0.24649833,1.7564386,,,,,,,,,,,,,,,,, -21000,0.24872954,1.859168,,,,,,,,,,,,,,,,, -21100,0.20110252,1.8300899,,,,,,,,,,,,,,,,, -21200,0.17301174,1.9324801,,,,,,,,,,,,,,,,, -21263,,,0.6349077224731445,1.7626363039016724,30.705749056075486,0.6522423624992371,1.6453105211257937,27.61326392444762,3000.0,0.6627622246742249,1.5679388046264648,26.717003601278886,3003.0,7599.073549985886,12911.396076202393,7599.073549985886,5311.4401948452,0.2353913784027099,0.0 -21300,0.20865104,1.9244696,,,,,,,,,,,,,,,,, -21400,0.1809737,1.8760909,,,,,,,,,,,,,,,,, -21500,0.1802005,1.918411,,,,,,,,,,,,,,,,, -21600,0.22481415,1.8953816,,,,,,,,,,,,,,,,, -21700,0.19545321,1.7588176,,,,,,,,,,,,,,,,, -21800,0.1925381,1.7799636,,,,,,,,,,,,,,,,, -21900,0.18750548,1.6998714,,,,,,,,,,,,,,,,, -22000,0.16878428,1.8135641,,,,,,,,,,,,,,,,, -22100,0.29350442,1.8506607,,,,,,,,,,,,,,,,, -22200,0.23637752,1.8655927,,,,,,,,,,,,,,,,, -22300,0.20366216,1.7481809,,,,,,,,,,,,,,,,, -22400,0.20747435,1.9581847,,,,,,,,,,,,,,,,, -22500,0.19751011,1.8321671,,,,,,,,,,,,,,,,, -22600,0.18899567,1.7833397,,,,,,,,,,,,,,,,, -22700,0.19714788,1.9799137,,,,,,,,,,,,,,,,, -22800,0.24176292,1.8153994,,,,,,,,,,,,,,,,, -22900,0.21341656,1.8989424,,,,,,,,,,,,,,,,, -23000,0.26955867,1.8866036,,,,,,,,,,,,,,,,, -23100,0.18145132,1.7739064,,,,,,,,,,,,,,,,, -23200,0.20101467,1.8239391,,,,,,,,,,,,,,,,, -23300,0.19274461,1.8760021,,,,,,,,,,,,,,,,, -23400,0.20009403,1.8852415,,,,,,,,,,,,,,,,, -23500,0.17446469,1.7297028,,,,,,,,,,,,,,,,, -23600,0.2663972,1.9107267,,,,,,,,,,,,,,,,, -23626,,,0.6350238919258118,1.7771228551864624,30.85238259132561,0.6539162397384644,1.631062388420105,27.74533933170228,3000.0,0.6658416390419006,1.549932837486267,27.20188573853869,3003.0,8439.251418828964,14309.434327602386,8439.251418828964,5869.200685739517,0.2620282173156738,0.0 -23700,0.20949633,1.8733325,,,,,,,,,,,,,,,,, -23800,0.16916691,1.827492,,,,,,,,,,,,,,,,, -23900,0.17600188,1.8264408,,,,,,,,,,,,,,,,, -24000,0.2829011,1.859068,,,,,,,,,,,,,,,,, -24100,0.1976435,1.8499583,,,,,,,,,,,,,,,,, -24200,0.205842,1.8641057,,,,,,,,,,,,,,,,, -24300,0.20429255,1.8120432,,,,,,,,,,,,,,,,, -24400,0.18604216,1.7867688,,,,,,,,,,,,,,,,, -24500,0.25972915,1.8206583,,,,,,,,,,,,,,,,, -24600,0.19497555,1.869129,,,,,,,,,,,,,,,,, -24700,0.19470312,1.8040868,,,,,,,,,,,,,,,,, -24800,0.20205097,1.8860799,,,,,,,,,,,,,,,,, -24900,0.19528082,1.8644596,,,,,,,,,,,,,,,,, -25000,0.30543295,1.8677151,,,,,,,,,,,,,,,,, -25100,0.24897629,1.822686,,,,,,,,,,,,,,,,, -25200,0.17853636,1.8685455,,,,,,,,,,,,,,,,, -25300,0.2030966,1.8766562,,,,,,,,,,,,,,,,, -25400,0.18864793,1.8259633,,,,,,,,,,,,,,,,, -25500,0.26111802,1.8969522,,,,,,,,,,,,,,,,, -25600,0.20809825,1.7363968,,,,,,,,,,,,,,,,, -25700,0.26769483,1.8382332,,,,,,,,,,,,,,,,, -25800,0.20423391,1.8141156,,,,,,,,,,,,,,,,, -25900,0.1958267,1.8072983,,,,,,,,,,,,,,,,, -25989,,,0.6467454433441162,1.6869769096374512,31.536803516208945,0.6557761430740356,1.6167680025100708,27.694900356711823,3000.0,0.6658997535705566,1.5409635305404663,27.17006367961328,3003.0,9279.454516649246,15615.774932384493,9279.454516649246,6335.238257408142,0.2900032997131347,0.0 -26000,0.19434403,1.7941062,,,,,,,,,,,,,,,,, -26100,0.21882692,1.7546625,,,,,,,,,,,,,,,,, -26200,0.41933686,1.8440275,,,,,,,,,,,,,,,,, -26300,0.19193569,1.7767903,,,,,,,,,,,,,,,,, -26400,0.17818746,1.7885096,,,,,,,,,,,,,,,,, -26500,0.18453008,1.8589162,,,,,,,,,,,,,,,,, -26600,0.2567891,1.7946622,,,,,,,,,,,,,,,,, -26700,0.2480307,1.8270228,,,,,,,,,,,,,,,,, -26800,0.19719915,1.8276459,,,,,,,,,,,,,,,,, -26900,0.17942661,1.775542,,,,,,,,,,,,,,,,, -27000,0.25563464,1.8841646,,,,,,,,,,,,,,,,, -27100,0.19145522,1.7650303,,,,,,,,,,,,,,,,, -27200,0.18623674,1.8293853,,,,,,,,,,,,,,,,, -27300,0.23191656,1.7396117,,,,,,,,,,,,,,,,, -27400,0.2262443,1.8460664,,,,,,,,,,,,,,,,, -27500,0.19981427,1.8131276,,,,,,,,,,,,,,,,, -27600,0.21534486,1.8534061,,,,,,,,,,,,,,,,, -27700,0.19610538,1.8386625,,,,,,,,,,,,,,,,, -27800,0.27149698,1.908733,,,,,,,,,,,,,,,,, -27900,0.22875346,1.8318797,,,,,,,,,,,,,,,,, -28000,0.1730571,1.8448005,,,,,,,,,,,,,,,,, -28100,0.1868874,1.8408455,,,,,,,,,,,,,,,,, -28200,0.19338524,1.7847254,,,,,,,,,,,,,,,,, -28300,0.21665277,1.7778476,,,,,,,,,,,,,,,,, -28352,,,0.6393489837646484,1.732313871383667,30.885567339140408,0.6574996113777161,1.6025561094284058,27.800652263048256,3000.0,0.6687235236167908,1.530341863632202,27.420597130706373,3003.0,10119.372649908066,16948.23571062088,10119.372649908066,6827.682823181152,0.3175230026245117,0.0 -28400,0.22344843,1.8493433,,,,,,,,,,,,,,,,, -28500,0.1988718,1.8261284,,,,,,,,,,,,,,,,, -28600,0.20520991,1.8751662,,,,,,,,,,,,,,,,, -28700,0.23019966,1.8225127,,,,,,,,,,,,,,,,, -28800,0.19663107,1.7913467,,,,,,,,,,,,,,,,, -28900,0.1976533,1.863365,,,,,,,,,,,,,,,,, -29000,0.20214501,1.7736777,,,,,,,,,,,,,,,,, -29100,0.2160524,1.79157,,,,,,,,,,,,,,,,, -29200,0.20460723,1.8293072,,,,,,,,,,,,,,,,, -29300,0.21463332,1.8546588,,,,,,,,,,,,,,,,, -29400,0.21019039,1.7866784,,,,,,,,,,,,,,,,, -29500,0.311461,1.8032498,,,,,,,,,,,,,,,,, -29600,0.21724851,1.8590201,,,,,,,,,,,,,,,,, -29700,0.21431977,1.7453578,,,,,,,,,,,,,,,,, -29800,0.1816703,1.7329141,,,,,,,,,,,,,,,,, -29900,0.21449177,1.7124467,,,,,,,,,,,,,,,,, -30000,0.19205046,1.8256192,,,,,,,,,,,,,,,,, -30100,0.21856251,1.8234752,,,,,,,,,,,,,,,,, -30200,0.19389422,1.8365132,,,,,,,,,,,,,,,,, -30300,0.24220799,1.7865342,,,,,,,,,,,,,,,,, -30400,0.20950909,1.7501222,,,,,,,,,,,,,,,,, -30500,0.20833196,1.8683816,,,,,,,,,,,,,,,,, -30600,0.22210774,1.8570976,,,,,,,,,,,,,,,,, -30700,0.18737848,1.8148806,,,,,,,,,,,,,,,,, -30714,,,0.6419368982315063,1.733702540397644,31.349138689209163,0.6595702171325684,1.598743200302124,28.190826470172745,3000.0,0.6724769473075867,1.5186891555786133,27.439883614878845,3003.0,10959.414959907532,18290.868687152863,10959.414959907532,7330.172876596451,0.345379114151001,0.0 -30800,0.22069065,1.7338904,,,,,,,,,,,,,,,,, -30900,0.19092987,1.735092,,,,,,,,,,,,,,,,, -31000,0.22239912,1.8037956,,,,,,,,,,,,,,,,, -31100,0.1853248,1.7429748,,,,,,,,,,,,,,,,, -31200,0.22306588,1.7492208,,,,,,,,,,,,,,,,, -31300,0.20859602,1.8205051,,,,,,,,,,,,,,,,, -31400,0.20198822,1.8484741,,,,,,,,,,,,,,,,, -31500,0.19186708,1.8045678,,,,,,,,,,,,,,,,, -31600,0.1983203,1.7605498,,,,,,,,,,,,,,,,, -31700,0.19041286,1.8429193,,,,,,,,,,,,,,,,, -31800,0.23844106,1.7901479,,,,,,,,,,,,,,,,, -31900,0.19813544,1.8455904,,,,,,,,,,,,,,,,, -32000,0.19584353,1.8224949,,,,,,,,,,,,,,,,, -32100,0.20768929,1.824089,,,,,,,,,,,,,,,,, -32200,0.19694708,1.7865285,,,,,,,,,,,,,,,,, -32300,0.17236176,1.7656639,,,,,,,,,,,,,,,,, -32400,0.19329958,1.8197603,,,,,,,,,,,,,,,,, -32500,0.18645671,1.7649027,,,,,,,,,,,,,,,,, -32600,0.19972149,1.7914134,,,,,,,,,,,,,,,,, -32700,0.20724382,1.7499985,,,,,,,,,,,,,,,,, -32800,0.24175061,1.8178195,,,,,,,,,,,,,,,,, -32900,0.2875634,1.8481989,,,,,,,,,,,,,,,,, -33000,0.2155122,1.7685751,,,,,,,,,,,,,,,,, -33077,,,0.6454497575759888,1.700178146362305,31.41565148825551,0.6618516445159912,1.5838629007339478,28.273850035586964,3000.0,0.6725350022315979,1.505925536155701,27.469167647169623,3003.0,11799.583815813065,19616.332770109177,11799.583815813065,7815.368814945221,0.3722670078277588,0.0 -33100,0.1907089,1.7868291,,,,,,,,,,,,,,,,, -33200,0.19084172,1.7532258,,,,,,,,,,,,,,,,, -33300,0.20644984,1.742588,,,,,,,,,,,,,,,,, -33400,0.3029311,1.7504658,,,,,,,,,,,,,,,,, -33500,0.18459287,1.754449,,,,,,,,,,,,,,,,, -33600,0.2277761,1.7262766,,,,,,,,,,,,,,,,, -33700,0.23964176,1.7669687,,,,,,,,,,,,,,,,, -33800,0.20038484,1.7553205,,,,,,,,,,,,,,,,, -33900,0.17166354,1.7795044,,,,,,,,,,,,,,,,, -34000,0.22510377,1.8066005,,,,,,,,,,,,,,,,, -34100,0.20101772,1.7589127,,,,,,,,,,,,,,,,, -34200,0.18335952,1.8795146,,,,,,,,,,,,,,,,, -34300,0.19453436,1.7470331,,,,,,,,,,,,,,,,, -34400,0.18791647,1.7984636,,,,,,,,,,,,,,,,, -34500,0.19770305,1.7876269,,,,,,,,,,,,,,,,, -34600,0.28804928,1.8141437,,,,,,,,,,,,,,,,, -34700,0.19583414,1.7778687,,,,,,,,,,,,,,,,, -34800,0.24424931,1.8087721,,,,,,,,,,,,,,,,, -34900,0.18602838,1.7566571,,,,,,,,,,,,,,,,, -35000,0.17987894,1.7133913,,,,,,,,,,,,,,,,, -35100,0.32659143,1.7687746,,,,,,,,,,,,,,,,, -35200,0.20946893,1.8504438,,,,,,,,,,,,,,,,, -35300,0.23544952,1.8303082,,,,,,,,,,,,,,,,, -35400,0.20046094,1.795834,,,,,,,,,,,,,,,,, -35440,,,0.6440091729164124,1.7076191902160645,31.4159123954736,0.6626452207565308,1.580370545387268,27.312578619277826,3000.0,0.6744408011436462,1.50215744972229,27.711011086782264,3003.0,12639.751176595688,21135.57282590866,12639.751176595688,8494.342839956284,0.3995921611785888,0.0 -35500,0.20097739,1.8768219,,,,,,,,,,,,,,,,, -35600,0.21835156,1.767568,,,,,,,,,,,,,,,,, -35700,0.18503198,1.8207791,,,,,,,,,,,,,,,,, -35800,0.19951963,1.7209284,,,,,,,,,,,,,,,,, -35900,0.17904766,1.7659247,,,,,,,,,,,,,,,,, -36000,0.18362097,1.7785528,,,,,,,,,,,,,,,,, -36100,0.19569771,1.7600029,,,,,,,,,,,,,,,,, -36200,0.19672498,1.817204,,,,,,,,,,,,,,,,, -36300,0.20879813,1.7471675,,,,,,,,,,,,,,,,, -36400,0.2035357,1.8003501,,,,,,,,,,,,,,,,, -36500,0.20149031,1.8632914,,,,,,,,,,,,,,,,, -36600,0.21553575,1.7687907,,,,,,,,,,,,,,,,, -36700,0.20768124,1.7981333,,,,,,,,,,,,,,,,, -36800,0.41505742,1.8682371,,,,,,,,,,,,,,,,, -36900,0.199374,1.7403718,,,,,,,,,,,,,,,,, -37000,0.191836,1.767317,,,,,,,,,,,,,,,,, -37100,0.20288402,1.7301965,,,,,,,,,,,,,,,,, -37200,0.20201491,1.8175861,,,,,,,,,,,,,,,,, -37300,0.2182278,1.7705373,,,,,,,,,,,,,,,,, -37400,0.19013086,1.7462964,,,,,,,,,,,,,,,,, -37500,0.20747714,1.8125744,,,,,,,,,,,,,,,,, -37600,0.21318847,1.8376014,,,,,,,,,,,,,,,,, -37700,0.21703793,1.8352225,,,,,,,,,,,,,,,,, -37800,0.19477704,1.6893262,,,,,,,,,,,,,,,,, -37803,,,0.662269115447998,1.5625323057174685,32.45727219264936,0.6636123657226562,1.5751655101776123,28.33616940026225,3000.0,0.6750218272209167,1.493459939956665,27.95044166549037,3003.0,13479.938393354416,22491.015778541565,13479.938393354416,9009.498941421509,0.4266171455383301,0.0 -37900,0.18856853,1.8041677,,,,,,,,,,,,,,,,, -38000,0.19060639,1.7707165,,,,,,,,,,,,,,,,, -38100,0.19638899,1.748215,,,,,,,,,,,,,,,,, -38200,0.19133574,1.8108028,,,,,,,,,,,,,,,,, -38300,0.1931358,1.6895262,,,,,,,,,,,,,,,,, -38400,0.18967861,1.6982206,,,,,,,,,,,,,,,,, -38500,0.22098711,1.7819502,,,,,,,,,,,,,,,,, -38600,0.20038597,1.7455992,,,,,,,,,,,,,,,,, -38700,0.20639405,1.7647486,,,,,,,,,,,,,,,,, -38800,0.3893761,1.8359071,,,,,,,,,,,,,,,,, -38900,0.21647656,1.8215173,,,,,,,,,,,,,,,,, -39000,0.19171321,1.8133851,,,,,,,,,,,,,,,,, -39100,0.20005195,1.7164762,,,,,,,,,,,,,,,,, -39200,0.20797409,1.8022268,,,,,,,,,,,,,,,,, -39300,0.1982771,1.8283846,,,,,,,,,,,,,,,,, -39400,0.1918973,1.7635756,,,,,,,,,,,,,,,,, -39500,0.19476789,1.8282055,,,,,,,,,,,,,,,,, -39600,0.20399743,1.7083162,,,,,,,,,,,,,,,,, -39700,0.17669287,1.7375162,,,,,,,,,,,,,,,,, -39800,0.20899265,1.7088557,,,,,,,,,,,,,,,,, -39900,0.22248575,1.7638242,,,,,,,,,,,,,,,,, -40000,0.20016122,1.7317182,,,,,,,,,,,,,,,,, -40100,0.7485992,1.77092,,,,,,,,,,,,,,,,, -40165,,,0.6452246308326721,1.6897084712982178,31.45640492988856,0.6639595031738281,1.5691026449203491,28.45596245736317,3000.0,0.6757073998451233,1.4875746965408323,27.78759015814581,3003.0,14319.857029676436,23951.275722503666,14319.857029676436,9629.737220525742,0.4572136402130127,0.0 -40200,0.19345783,1.7636278,,,,,,,,,,,,,,,,, -40300,0.2137912,1.7479601,,,,,,,,,,,,,,,,, -40400,0.32459995,1.8191763,,,,,,,,,,,,,,,,, -40500,0.19333236,1.764458,,,,,,,,,,,,,,,,, -40600,0.20172535,1.7144713,,,,,,,,,,,,,,,,, -40700,0.19200696,1.7285162,,,,,,,,,,,,,,,,, -40800,0.19425155,1.8156747,,,,,,,,,,,,,,,,, -40900,0.20082867,1.8377724,,,,,,,,,,,,,,,,, -41000,0.22255613,1.727784,,,,,,,,,,,,,,,,, -41100,0.18277606,1.7932602,,,,,,,,,,,,,,,,, -41200,0.22568485,1.69988,,,,,,,,,,,,,,,,, -41300,0.22975877,1.858322,,,,,,,,,,,,,,,,, -41400,0.22038522,1.8520814,,,,,,,,,,,,,,,,, -41500,0.18778561,1.7803437,,,,,,,,,,,,,,,,, -41600,0.23337248,1.7940885,,,,,,,,,,,,,,,,, -41700,0.19765516,1.7560246,,,,,,,,,,,,,,,,, -41800,0.19164282,1.7814088,,,,,,,,,,,,,,,,, -41900,0.21973138,1.7711375,,,,,,,,,,,,,,,,, -42000,0.21952118,1.7562965,,,,,,,,,,,,,,,,, -42100,0.19198261,1.73349,,,,,,,,,,,,,,,,, -42200,0.18821502,1.7597694,,,,,,,,,,,,,,,,, -42300,0.21335478,1.75507,,,,,,,,,,,,,,,,, -42400,0.20265064,1.8574685,,,,,,,,,,,,,,,,, -42500,0.19358395,1.738821,,,,,,,,,,,,,,,,, -42526,,,0.642837405204773,1.7198340892791748,31.78450162676224,0.6651126146316528,1.5587149858474731,28.53575327431008,3000.0,0.6788914203643799,1.47401762008667,28.226920449596552,3003.0,15159.80151104927,25281.899538993835,15159.80151104927,10120.314247369766,0.4865961074829101,0.0 -42600,0.20439407,1.725049,,,,,,,,,,,,,,,,, -42700,0.18459098,1.7067597,,,,,,,,,,,,,,,,, -42800,0.22121423,1.7329556,,,,,,,,,,,,,,,,, -42900,0.20766903,1.8210264,,,,,,,,,,,,,,,,, -43000,0.19116586,1.7515693,,,,,,,,,,,,,,,,, -43100,0.19242558,1.7917054,,,,,,,,,,,,,,,,, -43200,0.19048233,1.7243661,,,,,,,,,,,,,,,,, -43300,0.1963575,1.7286499,,,,,,,,,,,,,,,,, -43400,0.20928356,1.8219948,,,,,,,,,,,,,,,,, -43500,0.23926081,1.8306772,,,,,,,,,,,,,,,,, -43600,0.23921922,1.7507621,,,,,,,,,,,,,,,,, -43700,0.19910006,1.7599642,,,,,,,,,,,,,,,,, -43800,0.19686052,1.8072895,,,,,,,,,,,,,,,,, -43900,0.20597719,1.7614366,,,,,,,,,,,,,,,,, -44000,0.1933793,1.7800686,,,,,,,,,,,,,,,,, -44100,0.21110153,1.7562561,,,,,,,,,,,,,,,,, -44200,0.2398427,1.7889745,,,,,,,,,,,,,,,,, -44300,0.19984329,1.7849157,,,,,,,,,,,,,,,,, -44400,0.23420586,1.6992555,,,,,,,,,,,,,,,,, -44500,0.19883789,1.8034892,,,,,,,,,,,,,,,,, -44600,0.18198875,1.6622787,,,,,,,,,,,,,,,,, -44700,0.21197586,1.7794077,,,,,,,,,,,,,,,,, -44800,0.18446752,1.6469657,,,,,,,,,,,,,,,,, -44888,,,0.6535823345184326,1.6201032400131226,32.10639596896463,0.6656457781791687,1.5512747764587402,28.288459980167566,3000.0,0.6773343086242676,1.4703198671340942,28.39207283070988,3003.0,15999.772643089294,26658.1717505455,15999.772643089294,10656.51029253006,0.519707202911377,0.0 -44900,0.20466638,1.6947523,,,,,,,,,,,,,,,,, -45000,0.1974258,1.7282304,,,,,,,,,,,,,,,,, -45100,0.21575315,1.7356274,,,,,,,,,,,,,,,,, -45200,0.205589,1.7654428,,,,,,,,,,,,,,,,, -45300,0.19399808,1.7687489,,,,,,,,,,,,,,,,, -45400,0.20689175,1.6570064,,,,,,,,,,,,,,,,, -45500,0.1911278,1.7840478,,,,,,,,,,,,,,,,, -45600,0.18330489,1.8219706,,,,,,,,,,,,,,,,, -45700,0.18836972,1.769389,,,,,,,,,,,,,,,,, -45800,0.21065278,1.7051104,,,,,,,,,,,,,,,,, -45900,0.18895014,1.7175059,,,,,,,,,,,,,,,,, -46000,0.3303895,1.7766769,,,,,,,,,,,,,,,,, -46100,0.18359722,1.6927412,,,,,,,,,,,,,,,,, -46200,0.26171172,1.7159423,,,,,,,,,,,,,,,,, -46300,0.21395388,1.8162068,,,,,,,,,,,,,,,,, -46400,0.18628234,1.7115272,,,,,,,,,,,,,,,,, -46500,0.21990451,1.7930387,,,,,,,,,,,,,,,,, -46600,0.20933454,1.6658384,,,,,,,,,,,,,,,,, -46700,0.19452204,1.8314514,,,,,,,,,,,,,,,,, -46800,0.22578287,1.7336279,,,,,,,,,,,,,,,,, -46900,0.21462545,1.7162459,,,,,,,,,,,,,,,,, -47000,0.1800359,1.6969826,,,,,,,,,,,,,,,,, -47100,0.18704228,1.7542486,,,,,,,,,,,,,,,,, -47200,0.24369,1.7511955,,,,,,,,,,,,,,,,, -47249,,,0.6474329829216003,1.6818699836730957,31.87093874194392,0.6667864918708801,1.546050190925598,28.37202242681093,3000.0,0.6796002984046936,1.4612714052200315,28.36357916679935,3003.0,16839.723083019257,28116.96264028549,16839.723083019257,11275.247904539108,0.5494298934936523,0.0 -47300,0.20478149,1.73279,,,,,,,,,,,,,,,,, -47400,0.20074035,1.7772787,,,,,,,,,,,,,,,,, -47500,0.18836564,1.6847464,,,,,,,,,,,,,,,,, -47600,0.18274829,1.7659516,,,,,,,,,,,,,,,,, -47700,0.20913582,1.7834488,,,,,,,,,,,,,,,,, -47800,0.48316926,1.697327,,,,,,,,,,,,,,,,, -47900,0.22054437,1.7519228,,,,,,,,,,,,,,,,, -48000,0.24831459,1.8037722,,,,,,,,,,,,,,,,, -48100,0.22984174,1.8305093,,,,,,,,,,,,,,,,, -48200,0.18945469,1.8123541,,,,,,,,,,,,,,,,, -48300,0.1919415,1.6720654,,,,,,,,,,,,,,,,, -48400,0.20026757,1.6599838,,,,,,,,,,,,,,,,, -48500,0.1861385,1.8028886,,,,,,,,,,,,,,,,, -48600,0.20150378,1.6982228,,,,,,,,,,,,,,,,, -48700,0.18817598,1.683529,,,,,,,,,,,,,,,,, -48800,0.21129568,1.7145307,,,,,,,,,,,,,,,,, -48900,0.20757487,1.8477324,,,,,,,,,,,,,,,,, -49000,0.19226885,1.7049165,,,,,,,,,,,,,,,,, -49100,0.24307406,1.7444998,,,,,,,,,,,,,,,,, -49200,0.20011017,1.6813848,,,,,,,,,,,,,,,,, -49300,0.21436015,1.7408415,,,,,,,,,,,,,,,,, -49400,0.18554637,1.6829652,,,,,,,,,,,,,,,,, -49500,0.23057045,1.7731601,,,,,,,,,,,,,,,,, -49600,0.20504315,1.7688377,,,,,,,,,,,,,,,,, -49611,,,0.6456196904182434,1.6879490613937378,31.93500844978868,0.6677908301353455,1.5375560522079468,28.72653963720832,3000.0,0.6794027090072632,1.4521712064743042,28.18185372945468,3003.0,17679.671919107437,29452.255368232727,17679.671919107437,11770.484377622604,0.5856420993804932,0.0 -49700,0.192741,1.6691866,,,,,,,,,,,,,,,,, -49800,0.20332462,1.6655701,,,,,,,,,,,,,,,,, -49900,0.2064584,1.7407675,,,,,,,,,,,,,,,,, -50000,0.2046409,1.7042774,,,,,,,,,,,,,,,,, -50100,0.19241403,1.7135878,,,,,,,,,,,,,,,,, -50200,0.20136398,1.7568436,,,,,,,,,,,,,,,,, -50300,0.20241553,1.754424,,,,,,,,,,,,,,,,, -50400,0.20175402,1.7909784,,,,,,,,,,,,,,,,, -50500,0.21774165,1.7431846,,,,,,,,,,,,,,,,, -50600,0.20224185,1.7847579,,,,,,,,,,,,,,,,, -50700,0.21063256,1.7659413,,,,,,,,,,,,,,,,, -50800,0.21760646,1.7234209,,,,,,,,,,,,,,,,, -50900,0.1962269,1.7528216,,,,,,,,,,,,,,,,, -51000,0.20603354,1.6469603,,,,,,,,,,,,,,,,, -51100,0.2246051,1.7516223,,,,,,,,,,,,,,,,, -51200,0.24165839,1.7678033,,,,,,,,,,,,,,,,, -51300,0.21671899,1.727873,,,,,,,,,,,,,,,,, -51400,0.21499582,1.7276963,,,,,,,,,,,,,,,,, -51500,0.21950331,1.661127,,,,,,,,,,,,,,,,, -51600,0.20327418,1.7609588,,,,,,,,,,,,,,,,, -51700,0.19282533,1.668,,,,,,,,,,,,,,,,, -51800,0.18908267,1.6819329,,,,,,,,,,,,,,,,, -51900,0.54089123,1.6656439,,,,,,,,,,,,,,,,, -51972,,,0.6551151871681213,1.6218953132629397,32.32626977076003,0.6678900122642517,1.531269073486328,28.88246091612095,3000.0,0.682877242565155,1.4454905986785889,28.490701505346525,3003.0,18519.571103811264,31073.305172920227,18519.571103811264,12551.531423330309,0.6143500804901123,0.0 -52000,0.1927435,1.8004353,,,,,,,,,,,,,,,,, -52100,0.2159095,1.8278495,,,,,,,,,,,,,,,,, -52200,0.19045886,1.6923846,,,,,,,,,,,,,,,,, -52300,0.20031674,1.740315,,,,,,,,,,,,,,,,, -52400,0.21353029,1.7324066,,,,,,,,,,,,,,,,, -52500,0.19986694,1.6657295,,,,,,,,,,,,,,,,, -52600,0.22676048,1.7307143,,,,,,,,,,,,,,,,, -52700,0.23197888,1.7590326,,,,,,,,,,,,,,,,, -52800,0.20091641,1.6837487,,,,,,,,,,,,,,,,, -52900,0.23222066,1.7300531,,,,,,,,,,,,,,,,, -53000,0.19876212,1.7672265,,,,,,,,,,,,,,,,, -53100,0.20083635,1.7691694,,,,,,,,,,,,,,,,, -53200,0.28227225,1.7651986,,,,,,,,,,,,,,,,, -53300,0.20371647,1.6251799,,,,,,,,,,,,,,,,, -53400,0.1977289,1.7407821,,,,,,,,,,,,,,,,, -53500,0.18311396,1.7033651,,,,,,,,,,,,,,,,, -53600,0.21769214,1.7119887,,,,,,,,,,,,,,,,, -53700,0.18148015,1.6375996,,,,,,,,,,,,,,,,, -53800,0.20995732,1.7780828,,,,,,,,,,,,,,,,, -53900,0.18713918,1.6738471,,,,,,,,,,,,,,,,, -54000,0.19526339,1.7510359,,,,,,,,,,,,,,,,, -54100,0.20597392,1.7567523,,,,,,,,,,,,,,,,, -54200,1.3072433,1.747323,,,,,,,,,,,,,,,,, -54300,0.21493311,1.7223781,,,,,,,,,,,,,,,,, -54334,,,0.6510263681411743,1.655605435371399,32.088404891666734,0.6696755290031433,1.5259546041488647,28.627795726474464,3000.0,0.682540237903595,1.4420441389083862,28.32718456710534,3003.0,19359.637244701385,32561.75861167908,19359.637244701385,13199.815757513046,0.6437675952911377,0.0 -54400,0.19406265,1.6206547,,,,,,,,,,,,,,,,, -54500,0.1818265,1.7168818,,,,,,,,,,,,,,,,, -54600,0.20338862,1.715709,,,,,,,,,,,,,,,,, -54700,0.19850273,1.7366902,,,,,,,,,,,,,,,,, -54800,0.18963042,1.7738761,,,,,,,,,,,,,,,,, -54900,0.23816916,1.7178224,,,,,,,,,,,,,,,,, -55000,0.18491326,1.7806659,,,,,,,,,,,,,,,,, -55100,0.22375138,1.6954465,,,,,,,,,,,,,,,,, -55200,0.21695806,1.6718214,,,,,,,,,,,,,,,,, -55300,0.20697592,1.7306225,,,,,,,,,,,,,,,,, -55400,0.21763548,1.7062132,,,,,,,,,,,,,,,,, -55500,0.18413341,1.6724291,,,,,,,,,,,,,,,,, -55600,0.2151566,1.6789973,,,,,,,,,,,,,,,,, -55700,0.20243017,1.6430985,,,,,,,,,,,,,,,,, -55800,0.21323021,1.7704531,,,,,,,,,,,,,,,,, -55900,0.21361828,1.7020664,,,,,,,,,,,,,,,,, -56000,0.22365205,1.7004164,,,,,,,,,,,,,,,,, -56100,0.22928809,1.8000197,,,,,,,,,,,,,,,,, -56200,0.20157705,1.7929988,,,,,,,,,,,,,,,,, -56300,0.1950041,1.7521718,,,,,,,,,,,,,,,,, -56400,0.22785911,1.8124222,,,,,,,,,,,,,,,,, -56500,0.20531976,1.6621417,,,,,,,,,,,,,,,,, -56600,0.21916902,1.7512912,,,,,,,,,,,,,,,,, -56696,,,0.6673276424407959,1.5445960760116575,33.360906251543774,0.6722297072410583,1.5170300006866455,29.03924901341143,3000.0,0.6850502490997314,1.4306808710098269,28.62387632662793,3003.0,20199.781319618225,33991.85821199417,20199.781319618225,13789.664745807648,0.678861141204834,0.0 -56700,0.20138562,1.7189102,,,,,,,,,,,,,,,,, -56800,0.20877482,1.7399092,,,,,,,,,,,,,,,,, -56900,0.18846567,1.64458,,,,,,,,,,,,,,,,, -57000,0.20218086,1.7536032,,,,,,,,,,,,,,,,, -57100,0.19016708,1.7165127,,,,,,,,,,,,,,,,, -57200,0.19582836,1.7408922,,,,,,,,,,,,,,,,, -57300,0.21623524,1.6650465,,,,,,,,,,,,,,,,, -57400,0.19393438,1.6683263,,,,,,,,,,,,,,,,, -57500,0.18499054,1.7170824,,,,,,,,,,,,,,,,, -57600,0.19581525,1.6967804,,,,,,,,,,,,,,,,, -57700,0.23063633,1.6278821,,,,,,,,,,,,,,,,, -57800,0.20592019,1.7157409,,,,,,,,,,,,,,,,, -57900,0.18801105,1.6342846,,,,,,,,,,,,,,,,, -58000,0.18547079,1.6504928,,,,,,,,,,,,,,,,, -58100,0.20759897,1.7043715,,,,,,,,,,,,,,,,, -58200,0.19271338,1.7157036,,,,,,,,,,,,,,,,, -58300,0.18886924,1.6479961,,,,,,,,,,,,,,,,, -58400,0.20001763,1.7477924,,,,,,,,,,,,,,,,, -58500,0.18975027,1.6651202,,,,,,,,,,,,,,,,, -58600,0.2006337,1.7579484,,,,,,,,,,,,,,,,, -58700,0.22353904,1.7043123,,,,,,,,,,,,,,,,, -58800,0.23056892,1.7075385,,,,,,,,,,,,,,,,, -58900,0.19446984,1.7052196,,,,,,,,,,,,,,,,, -59000,0.21579088,1.6816021,,,,,,,,,,,,,,,,, -59058,,,0.6547285318374634,1.627714991569519,32.10986175858725,0.6729241013526917,1.5107427835464478,29.169344862532743,3000.0,0.6850270628929138,1.4248459339141846,28.680309752241467,3003.0,21039.77918243408,35431.68813467026,21039.77918243408,14389.392583847046,0.7105739116668701,0.0 -59100,0.19297446,1.6476742,,,,,,,,,,,,,,,,, -59200,0.2111786,1.7666695,,,,,,,,,,,,,,,,, -59300,0.18270974,1.7315822,,,,,,,,,,,,,,,,, -59400,0.19980028,1.7118795,,,,,,,,,,,,,,,,, -59500,0.19166818,1.6714061,,,,,,,,,,,,,,,,, -59600,0.24800117,1.7188679,,,,,,,,,,,,,,,,, -59700,0.19953659,1.6723955,,,,,,,,,,,,,,,,, -59800,0.19574434,1.6375581,,,,,,,,,,,,,,,,, -59900,0.19110599,1.6833441,,,,,,,,,,,,,,,,, -60000,0.19656506,1.6951519,,,,,,,,,,,,,,,,, -60100,0.20742227,1.7109358,,,,,,,,,,,,,,,,, -60200,0.18849213,1.8109375,,,,,,,,,,,,,,,,, -60300,0.19573581,1.6615734,,,,,,,,,,,,,,,,, -60400,0.20064154,1.7880993,,,,,,,,,,,,,,,,, -60500,0.19938782,1.6936647,,,,,,,,,,,,,,,,, -60600,0.18159491,1.7166432,,,,,,,,,,,,,,,,, -60700,0.18881941,1.7253131,,,,,,,,,,,,,,,,, -60800,0.18920337,1.7284482,,,,,,,,,,,,,,,,, -60900,0.21588613,1.7675507,,,,,,,,,,,,,,,,, -61000,0.19667888,1.695377,,,,,,,,,,,,,,,,, -61100,0.22845301,1.7632748,,,,,,,,,,,,,,,,, -61200,0.21583417,1.701868,,,,,,,,,,,,,,,,, -61300,0.19135933,1.7314916,,,,,,,,,,,,,,,,, -61400,0.19309738,1.6889689,,,,,,,,,,,,,,,,, -61421,,,0.6533963084220886,1.6542311906814575,32.42845141366208,0.6724777221679688,1.5102542638778689,29.03658871071236,3000.0,0.6873162388801575,1.4158756732940674,28.77019337846428,3003.0,21879.856380939484,36863.674137830734,21879.856380939484,14981.200040340424,0.7415766716003418,0.0 -61500,0.2059674,1.703971,,,,,,,,,,,,,,,,, -61600,0.2109697,1.6678385,,,,,,,,,,,,,,,,, -61700,0.34152755,1.6975936,,,,,,,,,,,,,,,,, -61800,0.19807175,1.6466949,,,,,,,,,,,,,,,,, -61900,0.20869362,1.663367,,,,,,,,,,,,,,,,, -62000,0.25377902,1.8336556,,,,,,,,,,,,,,,,, -62100,0.19077346,1.6024846,,,,,,,,,,,,,,,,, -62200,0.19746697,1.6733944,,,,,,,,,,,,,,,,, -62300,0.18292943,1.6619549,,,,,,,,,,,,,,,,, -62400,0.21991843,1.7024813,,,,,,,,,,,,,,,,, -62500,0.18773559,1.7984512,,,,,,,,,,,,,,,,, -62600,0.19558957,1.6384051,,,,,,,,,,,,,,,,, -62700,0.1974283,1.7614068,,,,,,,,,,,,,,,,, -62800,0.2159225,1.7726703,,,,,,,,,,,,,,,,, -62900,0.20071758,1.7815512,,,,,,,,,,,,,,,,, -63000,0.25564367,1.7411549,,,,,,,,,,,,,,,,, -63100,0.19431245,1.6448731,,,,,,,,,,,,,,,,, -63200,0.20537457,1.663485,,,,,,,,,,,,,,,,, -63300,0.19801132,1.6659126,,,,,,,,,,,,,,,,, -63400,0.20752679,1.7741827,,,,,,,,,,,,,,,,, -63500,0.19507204,1.7119248,,,,,,,,,,,,,,,,, -63600,0.1917943,1.6855128,,,,,,,,,,,,,,,,, -63700,0.1964379,1.6876254,,,,,,,,,,,,,,,,, -63782,,,0.6634188890457153,1.5673255920410156,32.64206969909681,0.6739531755447388,1.4961543083190918,29.020543803915928,3000.0,0.6861774325370789,1.407378315925598,28.776615374007704,3003.0,22719.75696492195,38233.72824931145,22719.75696492195,15511.243689775469,0.7795286178588867,0.0 -63800,0.2126392,1.7564974,,,,,,,,,,,,,,,,, -63900,0.20558487,1.7749752,,,,,,,,,,,,,,,,, -64000,0.20883976,1.695978,,,,,,,,,,,,,,,,, -64100,0.24423994,1.6641299,,,,,,,,,,,,,,,,, -64200,0.20246282,1.7033997,,,,,,,,,,,,,,,,, -64300,0.1900191,1.7190498,,,,,,,,,,,,,,,,, -64400,0.20082137,1.676141,,,,,,,,,,,,,,,,, -64500,0.21182927,1.7138973,,,,,,,,,,,,,,,,, -64600,0.21004812,1.7168919,,,,,,,,,,,,,,,,, -64700,0.19884542,1.7143751,,,,,,,,,,,,,,,,, -64800,0.22821163,1.6068008,,,,,,,,,,,,,,,,, -64900,0.21832666,1.6341587,,,,,,,,,,,,,,,,, -65000,0.18216918,1.5889231,,,,,,,,,,,,,,,,, -65100,0.21422681,1.7516141,,,,,,,,,,,,,,,,, -65200,0.2122329,1.7463388,,,,,,,,,,,,,,,,, -65300,0.21899273,1.7144396,,,,,,,,,,,,,,,,, -65400,0.18838957,1.6764839,,,,,,,,,,,,,,,,, -65500,0.20016295,1.7022913,,,,,,,,,,,,,,,,, -65600,0.20888704,1.707392,,,,,,,,,,,,,,,,, -65700,0.19471169,1.602615,,,,,,,,,,,,,,,,, -65800,0.20974231,1.6211923,,,,,,,,,,,,,,,,, -65900,0.20252748,1.6474755,,,,,,,,,,,,,,,,, -66000,0.19624095,1.7716484,,,,,,,,,,,,,,,,, -66100,0.18323037,1.6329311,,,,,,,,,,,,,,,,, -66144,,,0.6572242975234985,1.6180578470230105,32.916716181700515,0.6756022572517395,1.4923619031906128,29.525804694230857,3000.0,0.6896519660949707,1.402874231338501,28.88593770250436,3003.0,23559.911987304688,39609.60906815529,23559.911987304688,16046.860683441162,0.8157198429107666,0.0 -66200,0.21067707,1.5893829,,,,,,,,,,,,,,,,, -66300,0.2376978,1.7977654,,,,,,,,,,,,,,,,, -66400,0.20887257,1.6526189,,,,,,,,,,,,,,,,, -66500,0.22451879,1.6563461,,,,,,,,,,,,,,,,, -66600,0.19064802,1.7113636,,,,,,,,,,,,,,,,, -66700,0.22601949,1.7005049,,,,,,,,,,,,,,,,, -66800,0.20758945,1.7011652,,,,,,,,,,,,,,,,, -66900,0.19396763,1.6280181,,,,,,,,,,,,,,,,, -67000,0.19592139,1.7091305,,,,,,,,,,,,,,,,, -67100,0.21242519,1.7284741,,,,,,,,,,,,,,,,, -67200,0.19724466,1.638931,,,,,,,,,,,,,,,,, -67300,0.21817029,1.6920508,,,,,,,,,,,,,,,,, -67400,0.21213181,1.6976051,,,,,,,,,,,,,,,,, -67500,0.19145763,1.6921201,,,,,,,,,,,,,,,,, -67600,0.69706386,1.6660848,,,,,,,,,,,,,,,,, -67700,0.19657135,1.6687993,,,,,,,,,,,,,,,,, -67800,0.2410309,1.7558842,,,,,,,,,,,,,,,,, -67900,0.22087386,1.6827316,,,,,,,,,,,,,,,,, -68000,0.19539298,1.6000324,,,,,,,,,,,,,,,,, -68100,0.20800544,1.7498726,,,,,,,,,,,,,,,,, -68200,0.2011662,1.6985902,,,,,,,,,,,,,,,,, -68300,0.22105302,1.7202069,,,,,,,,,,,,,,,,, -68400,0.19827813,1.6066831,,,,,,,,,,,,,,,,, -68500,0.20674047,1.6676275,,,,,,,,,,,,,,,,, -68505,,,0.6578312516212463,1.619662880897522,32.58494738936097,0.6772885322570801,1.481269598007202,29.29081053214644,3000.0,0.691499650478363,1.3894480466842651,28.92139389859646,3003.0,24399.88455271721,41000.14904499054,24399.88455271721,16597.32043647766,0.848196268081665,0.0 -68600,0.2284855,1.7312651,,,,,,,,,,,,,,,,, -68700,0.1912245,1.6208584,,,,,,,,,,,,,,,,, -68800,0.19500156,1.6672947,,,,,,,,,,,,,,,,, -68900,0.19802566,1.6952995,,,,,,,,,,,,,,,,, -69000,0.20579083,1.6973088,,,,,,,,,,,,,,,,, -69100,0.19872564,1.6369468,,,,,,,,,,,,,,,,, -69200,0.21406367,1.637305,,,,,,,,,,,,,,,,, -69300,0.21180326,1.6519312,,,,,,,,,,,,,,,,, -69400,0.18992989,1.700933,,,,,,,,,,,,,,,,, -69500,0.2131118,1.7179788,,,,,,,,,,,,,,,,, -69600,0.21163775,1.622847,,,,,,,,,,,,,,,,, -69700,0.19609562,1.6007937,,,,,,,,,,,,,,,,, -69800,0.20001458,1.6480563,,,,,,,,,,,,,,,,, -69900,0.21430774,1.6999253,,,,,,,,,,,,,,,,, -70000,0.18794265,1.6130472,,,,,,,,,,,,,,,,, -70100,0.24114226,1.7227051,,,,,,,,,,,,,,,,, -70200,0.64855516,1.7049525,,,,,,,,,,,,,,,,, -70300,0.20430838,1.6516044,,,,,,,,,,,,,,,,, -70400,0.19565694,1.7017449,,,,,,,,,,,,,,,,, -70500,0.20082341,1.7057325,,,,,,,,,,,,,,,,, -70600,0.22263475,1.7167926,,,,,,,,,,,,,,,,, -70700,0.20285909,1.6901855,,,,,,,,,,,,,,,,, -70800,0.19909246,1.6155146,,,,,,,,,,,,,,,,, -70867,,,0.6645039319992065,1.563105225563049,32.747214248427234,0.677238941192627,1.4762758016586304,29.26196255330381,3000.0,0.6918947100639343,1.386744737625122,29.064057246440694,3003.0,25239.83142542839,42516.241681814194,25239.83142542839,17273.362685203552,0.881028413772583,0.0 -70900,0.1979535,1.6882404,,,,,,,,,,,,,,,,, -71000,0.20519902,1.697546,,,,,,,,,,,,,,,,, -71100,0.19599721,1.6228652,,,,,,,,,,,,,,,,, -71200,0.19764532,1.6865327,,,,,,,,,,,,,,,,, -71300,0.18688582,1.6826001,,,,,,,,,,,,,,,,, -71400,0.20463973,1.6576469,,,,,,,,,,,,,,,,, -71500,0.19532433,1.6474935,,,,,,,,,,,,,,,,, -71600,0.19819036,1.6148555,,,,,,,,,,,,,,,,, -71700,0.19681329,1.6537206,,,,,,,,,,,,,,,,, -71800,0.20994195,1.7329789,,,,,,,,,,,,,,,,, -71900,0.19422127,1.6945632,,,,,,,,,,,,,,,,, -72000,0.19685234,1.6145275,,,,,,,,,,,,,,,,, -72100,0.19887353,1.61117,,,,,,,,,,,,,,,,, -72200,0.21932638,1.7253608,,,,,,,,,,,,,,,,, -72300,0.2160085,1.5721731,,,,,,,,,,,,,,,,, -72400,0.21201761,1.6627874,,,,,,,,,,,,,,,,, -72500,0.20322965,1.6982247,,,,,,,,,,,,,,,,, -72600,0.219697,1.6577584,,,,,,,,,,,,,,,,, -72700,0.24586107,1.7066642,,,,,,,,,,,,,,,,, -72800,0.20420532,1.7759433,,,,,,,,,,,,,,,,, -72900,0.20453814,1.5743543,,,,,,,,,,,,,,,,, -73000,0.19963355,1.6264299,,,,,,,,,,,,,,,,, -73100,0.1964194,1.5983244,,,,,,,,,,,,,,,,, -73200,0.21255724,1.6495565,,,,,,,,,,,,,,,,, -73229,,,0.6643652319908142,1.574614405632019,32.75438032641624,0.6768794059753418,1.4778293371200562,29.36006734681592,3000.0,0.6896170973777771,1.3845664262771606,29.036617066834708,3003.0,26080.004014968872,44008.78098034859,26080.004014968872,17925.62251186371,0.9149608612060548,0.0 -73300,0.20786282,1.5895399,,,,,,,,,,,,,,,,, -73400,0.23954621,1.6346297,,,,,,,,,,,,,,,,, -73500,0.18726511,1.636819,,,,,,,,,,,,,,,,, -73600,0.21468551,1.6601459,,,,,,,,,,,,,,,,, -73700,0.19306092,1.650134,,,,,,,,,,,,,,,,, -73800,0.20744756,1.6290811,,,,,,,,,,,,,,,,, -73900,0.20971356,1.7046363,,,,,,,,,,,,,,,,, -74000,0.21058407,1.6089815,,,,,,,,,,,,,,,,, -74100,0.21701416,1.6396065,,,,,,,,,,,,,,,,, -74200,0.19897662,1.6441885,,,,,,,,,,,,,,,,, -74300,0.19546014,1.7684789,,,,,,,,,,,,,,,,, -74400,0.20903987,1.650969,,,,,,,,,,,,,,,,, -74500,0.2242132,1.7259665,,,,,,,,,,,,,,,,, -74600,0.1997694,1.6118879,,,,,,,,,,,,,,,,, -74700,0.20774895,1.6275907,,,,,,,,,,,,,,,,, -74800,0.22842927,1.6923602,,,,,,,,,,,,,,,,, -74900,0.20439069,1.6319435,,,,,,,,,,,,,,,,, -75000,0.20723423,1.6288412,,,,,,,,,,,,,,,,, -75100,0.2112298,1.7201947,,,,,,,,,,,,,,,,, -75200,0.20476998,1.645267,,,,,,,,,,,,,,,,, -75300,0.1899895,1.6349459,,,,,,,,,,,,,,,,, -75400,0.20897198,1.6456319,,,,,,,,,,,,,,,,, -75500,0.18782543,1.5757755,,,,,,,,,,,,,,,,, -75591,,,0.6765106320381165,1.4892783164978027,33.98489511197334,0.6804627180099487,1.4630177021026611,29.06213421059305,3000.0,0.6935797333717346,1.3710495233535769,29.24812489131629,3003.0,26920.101365327835,45584.35468482971,26920.101365327835,18660.99428129196,0.9479324817657472,0.0 -75600,0.19634694,1.6841936,,,,,,,,,,,,,,,,, -75700,0.25627774,1.6546422,,,,,,,,,,,,,,,,, -75800,0.20663811,1.7012757,,,,,,,,,,,,,,,,, -75900,0.20560345,1.6772101,,,,,,,,,,,,,,,,, -76000,0.2004127,1.7131599,,,,,,,,,,,,,,,,, -76100,0.20944463,1.5690457,,,,,,,,,,,,,,,,, -76200,0.21146904,1.6439023,,,,,,,,,,,,,,,,, -76300,0.25045478,1.6739901,,,,,,,,,,,,,,,,, -76400,0.20705949,1.6611156,,,,,,,,,,,,,,,,, -76500,0.19271997,1.700352,,,,,,,,,,,,,,,,, -76600,0.26404995,1.6957512,,,,,,,,,,,,,,,,, -76700,0.2015699,1.6634429,,,,,,,,,,,,,,,,, -76800,0.20671774,1.7482991,,,,,,,,,,,,,,,,, -76900,0.19422178,1.66328,,,,,,,,,,,,,,,,, -77000,0.22032742,1.5890758,,,,,,,,,,,,,,,,, -77100,0.20468074,1.6193095,,,,,,,,,,,,,,,,, -77200,0.19801408,1.606203,,,,,,,,,,,,,,,,, -77300,0.197589,1.6549737,,,,,,,,,,,,,,,,, -77400,0.20971344,1.6330099,,,,,,,,,,,,,,,,, -77500,0.20013487,1.6515448,,,,,,,,,,,,,,,,, -77600,0.20028092,1.6407939,,,,,,,,,,,,,,,,, -77700,0.20886949,1.664612,,,,,,,,,,,,,,,,, -77800,0.20921429,1.6126795,,,,,,,,,,,,,,,,, -77900,0.19855644,1.641648,,,,,,,,,,,,,,,,, -77952,,,0.6673470735549927,1.5419315099716189,32.9429847730829,0.680648684501648,1.4586224555969238,28.36214799696568,3000.0,0.6959154009819031,1.3664885759353638,29.535897402171667,3003.0,27760.15807056427,47256.51785254479,27760.15807056427,19492.98771739006,0.9874765872955322,0.0 -78000,0.21165206,1.7406894,,,,,,,,,,,,,,,,, -78100,0.21183579,1.6372471,,,,,,,,,,,,,,,,, -78200,0.22201128,1.7079223,,,,,,,,,,,,,,,,, -78300,0.19804783,1.611301,,,,,,,,,,,,,,,,, -78400,0.23021829,1.5925268,,,,,,,,,,,,,,,,, -78500,0.19514222,1.6213865,,,,,,,,,,,,,,,,, -78600,0.22214955,1.7187431,,,,,,,,,,,,,,,,, -78700,0.18904926,1.5968838,,,,,,,,,,,,,,,,, -78800,0.21468262,1.6448398,,,,,,,,,,,,,,,,, -78900,0.20438577,1.6068183,,,,,,,,,,,,,,,,, -79000,0.24066673,1.6083602,,,,,,,,,,,,,,,,, -79100,0.20362443,1.6557724,,,,,,,,,,,,,,,,, -79200,0.199625,1.602887,,,,,,,,,,,,,,,,, -79300,0.2289076,1.602509,,,,,,,,,,,,,,,,, -79400,0.20068109,1.6182042,,,,,,,,,,,,,,,,, -79500,0.21614729,1.6223404,,,,,,,,,,,,,,,,, -79600,0.20838064,1.6360795,,,,,,,,,,,,,,,,, -79700,0.20493303,1.768866,,,,,,,,,,,,,,,,, -79800,0.21869895,1.6207885,,,,,,,,,,,,,,,,, -79900,0.20537207,1.5432136,,,,,,,,,,,,,,,,, -80000,0.19670436,1.5569971,,,,,,,,,,,,,,,,, -80100,0.3617703,1.6587687,,,,,,,,,,,,,,,,, -80200,0.19226849,1.6119152,,,,,,,,,,,,,,,,, -80300,0.20274056,1.7236468,,,,,,,,,,,,,,,,, -80314,,,0.6680836081504822,1.5501035451889038,32.917216136696155,0.6810950636863708,1.452638030052185,29.69538787912534,3000.0,0.6945790648460388,1.360152244567871,29.35732571126601,3003.0,28600.16921401024,48632.22001385689,28600.16921401024,20028.57369875908,1.0211431980133057,0.0 -80400,0.20557688,1.6286845,,,,,,,,,,,,,,,,, -80500,0.20901507,1.6462253,,,,,,,,,,,,,,,,, -80600,0.18895222,1.5850022,,,,,,,,,,,,,,,,, -80700,0.2085394,1.6120025,,,,,,,,,,,,,,,,, -80800,0.21916318,1.7220603,,,,,,,,,,,,,,,,, -80900,0.21353948,1.6370232,,,,,,,,,,,,,,,,, -81000,0.2009991,1.661464,,,,,,,,,,,,,,,,, -81100,0.20472404,1.6854033,,,,,,,,,,,,,,,,, -81200,0.18975085,1.5641402,,,,,,,,,,,,,,,,, -81300,0.2322299,1.5898771,,,,,,,,,,,,,,,,, -81400,0.20912553,1.6180906,,,,,,,,,,,,,,,,, -81500,0.19616401,1.6323683,,,,,,,,,,,,,,,,, -81600,0.21921629,1.636319,,,,,,,,,,,,,,,,, -81700,0.21353064,1.642074,,,,,,,,,,,,,,,,, -81800,0.19711359,1.5911859,,,,,,,,,,,,,,,,, -81900,0.21363156,1.6920102,,,,,,,,,,,,,,,,, -82000,0.21531387,1.6055858,,,,,,,,,,,,,,,,, -82100,0.19820802,1.6172156,,,,,,,,,,,,,,,,, -82200,0.19759062,1.6528763,,,,,,,,,,,,,,,,, -82300,0.20029894,1.7067875,,,,,,,,,,,,,,,,, -82400,0.20891476,1.5953239,,,,,,,,,,,,,,,,, -82500,0.20697495,1.7111915,,,,,,,,,,,,,,,,, -82600,0.20709954,1.606265,,,,,,,,,,,,,,,,, -82676,,,0.6728290319442749,1.5096951723098757,33.626086977802274,0.6835501194000244,1.4427237510681152,29.914303811257952,3000.0,0.6989832520484924,1.3479254245758057,29.86139070986005,3003.0,29440.110209703445,50109.174132585526,29440.110209703445,20665.47934579849,1.05684494972229,0.0 -82700,0.21371014,1.6566834,,,,,,,,,,,,,,,,, -82800,0.2580029,1.6264136,,,,,,,,,,,,,,,,, -82900,0.20716691,1.5715499,,,,,,,,,,,,,,,,, -83000,0.19978201,1.5968982,,,,,,,,,,,,,,,,, -83100,0.21758223,1.6565586,,,,,,,,,,,,,,,,, -83200,0.19079286,1.5907308,,,,,,,,,,,,,,,,, -83300,0.2594124,1.664854,,,,,,,,,,,,,,,,, -83400,0.21189973,1.6475416,,,,,,,,,,,,,,,,, -83500,0.19680835,1.6423423,,,,,,,,,,,,,,,,, -83600,0.25887707,1.6601249,,,,,,,,,,,,,,,,, -83700,0.20664044,1.6461045,,,,,,,,,,,,,,,,, -83800,0.19903964,1.6787025,,,,,,,,,,,,,,,,, -83900,0.20923764,1.6680617,,,,,,,,,,,,,,,,, -84000,0.21084425,1.5901653,,,,,,,,,,,,,,,,, -84100,0.20575288,1.5958503,,,,,,,,,,,,,,,,, -84200,0.21553726,1.6257375,,,,,,,,,,,,,,,,, -84300,0.21578394,1.6026756,,,,,,,,,,,,,,,,, -84400,0.21700637,1.6403568,,,,,,,,,,,,,,,,, -84500,0.2110423,1.6072338,,,,,,,,,,,,,,,,, -84600,0.21396194,1.5729935,,,,,,,,,,,,,,,,, -84700,0.19432315,1.6060139,,,,,,,,,,,,,,,,, -84800,0.21213986,1.617723,,,,,,,,,,,,,,,,, -84900,0.20552525,1.6234612,,,,,,,,,,,,,,,,, -85000,0.19970146,1.603718,,,,,,,,,,,,,,,,, -85039,,,0.6680791974067688,1.537145972251892,33.650963677786145,0.6841328740119934,1.439337968826294,29.67939265029403,3000.0,0.6989135146141052,1.3375647068023682,29.94726672862276,3003.0,30280.33723807335,51450.02022433281,30280.33723807335,21165.99251294136,1.0917017459869385,0.0 -85100,0.2069972,1.5037359,,,,,,,,,,,,,,,,, -85200,0.21238373,1.6576388,,,,,,,,,,,,,,,,, -85300,0.20735449,1.6779258,,,,,,,,,,,,,,,,, -85400,0.21483721,1.5904562,,,,,,,,,,,,,,,,, -85500,0.2092296,1.5408584,,,,,,,,,,,,,,,,, -85600,0.21581933,1.6453532,,,,,,,,,,,,,,,,, -85700,0.20536993,1.5784646,,,,,,,,,,,,,,,,, -85800,0.20583555,1.6191157,,,,,,,,,,,,,,,,, -85900,0.26558086,1.5965576,,,,,,,,,,,,,,,,, -86000,0.21569696,1.621495,,,,,,,,,,,,,,,,, -86100,0.21728504,1.636649,,,,,,,,,,,,,,,,, -86200,0.20271115,1.5743676,,,,,,,,,,,,,,,,, -86300,0.21020553,1.6138784,,,,,,,,,,,,,,,,, -86400,0.22203171,1.5929347,,,,,,,,,,,,,,,,, -86500,0.21244158,1.5977646,,,,,,,,,,,,,,,,, -86600,0.20073317,1.5836282,,,,,,,,,,,,,,,,, -86700,0.19891495,1.6214529,,,,,,,,,,,,,,,,, -86800,0.2203209,1.5994736,,,,,,,,,,,,,,,,, -86900,0.23732752,1.6891508,,,,,,,,,,,,,,,,, -87000,0.21110772,1.6209484,,,,,,,,,,,,,,,,, -87100,0.21026301,1.6203195,,,,,,,,,,,,,,,,, -87200,0.22620282,1.5686957,,,,,,,,,,,,,,,,, -87300,0.20367585,1.5610617,,,,,,,,,,,,,,,,, -87400,0.21561266,1.6010624,,,,,,,,,,,,,,,,, -87401,,,0.6687669157981873,1.5372868776321411,33.083875804779154,0.6854843497276306,1.4303168058395386,30.13844842941912,3000.0,0.7013421654701233,1.3306312561035156,29.809605628979703,3003.0,31120.81729412079,52900.50894832611,31120.81729412079,21775.885035037994,1.134784698486328,0.0 -87500,0.21346433,1.5753082,,,,,,,,,,,,,,,,, -87600,0.22027451,1.6838413,,,,,,,,,,,,,,,,, -87700,0.2094441,1.6173725,,,,,,,,,,,,,,,,, -87800,0.20822895,1.5265856,,,,,,,,,,,,,,,,, -87900,0.20835367,1.5794982,,,,,,,,,,,,,,,,, -88000,0.2059425,1.565209,,,,,,,,,,,,,,,,, -88100,0.21532652,1.632283,,,,,,,,,,,,,,,,, -88200,0.20527881,1.5989615,,,,,,,,,,,,,,,,, -88300,0.21372913,1.5450568,,,,,,,,,,,,,,,,, -88400,0.23791063,1.6045638,,,,,,,,,,,,,,,,, -88500,0.19426124,1.546443,,,,,,,,,,,,,,,,, -88600,0.20466794,1.5893667,,,,,,,,,,,,,,,,, -88700,0.21666229,1.608302,,,,,,,,,,,,,,,,, -88800,0.20801651,1.6108229,,,,,,,,,,,,,,,,, -88900,0.20966737,1.4983554,,,,,,,,,,,,,,,,, -89000,0.19910476,1.5269026,,,,,,,,,,,,,,,,, -89100,0.20764512,1.5505828,,,,,,,,,,,,,,,,, -89200,0.21096149,1.5529962,,,,,,,,,,,,,,,,, -89300,0.20427798,1.539712,,,,,,,,,,,,,,,,, -89400,0.21905519,1.5302994,,,,,,,,,,,,,,,,, -89500,0.22801387,1.632408,,,,,,,,,,,,,,,,, -89600,0.2076763,1.6107813,,,,,,,,,,,,,,,,, -89700,0.22551207,1.6327603,,,,,,,,,,,,,,,,, -89763,,,0.6764796376228333,1.4874475002288818,33.59193421065701,0.686129093170166,1.422110080718994,30.10132998243684,3000.0,0.7024112939834595,1.3254374265670776,30.16269517032829,3003.0,31960.73653626442,54298.36872267723,31960.73653626442,22333.71812582016,1.1713433265686035,0.0 -89800,0.21156184,1.5382904,,,,,,,,,,,,,,,,, -89900,0.20726593,1.5853453,,,,,,,,,,,,,,,,, -90000,0.21516316,1.5953364,,,,,,,,,,,,,,,,, -90100,0.22062887,1.5855277,,,,,,,,,,,,,,,,, -90200,0.20976268,1.5925416,,,,,,,,,,,,,,,,, -90300,0.20735602,1.673403,,,,,,,,,,,,,,,,, -90400,0.21992755,1.562702,,,,,,,,,,,,,,,,, -90500,0.23676455,1.5495455,,,,,,,,,,,,,,,,, -90600,0.2122785,1.5939897,,,,,,,,,,,,,,,,, -90700,0.21353123,1.5444473,,,,,,,,,,,,,,,,, -90800,0.20356467,1.5914787,,,,,,,,,,,,,,,,, -90900,0.21635193,1.5455763,,,,,,,,,,,,,,,,, -91000,0.21139078,1.5731688,,,,,,,,,,,,,,,,, -91100,0.2101344,1.5409077,,,,,,,,,,,,,,,,, -91200,0.22701946,1.680758,,,,,,,,,,,,,,,,, -91300,0.21126656,1.5611217,,,,,,,,,,,,,,,,, -91400,0.20998275,1.5640218,,,,,,,,,,,,,,,,, -91500,0.21193619,1.5751278,,,,,,,,,,,,,,,,, -91600,0.19491266,1.5779033,,,,,,,,,,,,,,,,, -91700,0.20925473,1.5819677,,,,,,,,,,,,,,,,, -91800,0.3344961,1.6840985,,,,,,,,,,,,,,,,, -91900,0.21054514,1.6662843,,,,,,,,,,,,,,,,, -92000,0.21528311,1.5791906,,,,,,,,,,,,,,,,, -92100,0.21458688,1.581951,,,,,,,,,,,,,,,,, -92124,,,0.6730408668518066,1.5171024799346924,33.473575376051215,0.6871086359024048,1.4207348823547363,29.50734041863528,3000.0,0.7033408880233765,1.3194371461868286,30.099988416489992,3003.0,32800.644548892975,55847.14839839935,32800.644548892975,23042.47913169861,1.2089464664459229,0.0 -92200,0.20124051,1.5795603,,,,,,,,,,,,,,,,, -92300,0.21123485,1.5798708,,,,,,,,,,,,,,,,, -92400,0.24349353,1.6224531,,,,,,,,,,,,,,,,, -92500,0.21852307,1.5964513,,,,,,,,,,,,,,,,, -92600,0.21059532,1.5718062,,,,,,,,,,,,,,,,, -92700,0.20860083,1.5830379,,,,,,,,,,,,,,,,, -92800,0.20792736,1.5349896,,,,,,,,,,,,,,,,, -92900,0.23217799,1.6800944,,,,,,,,,,,,,,,,, -93000,0.21385445,1.6233522,,,,,,,,,,,,,,,,, -93100,0.22661346,1.5689007,,,,,,,,,,,,,,,,, -93200,0.19838311,1.5156014,,,,,,,,,,,,,,,,, -93300,0.22813804,1.6165192,,,,,,,,,,,,,,,,, -93400,0.21723573,1.5344752,,,,,,,,,,,,,,,,, -93500,0.20953704,1.5761249,,,,,,,,,,,,,,,,, -93600,0.22119445,1.5616268,,,,,,,,,,,,,,,,, -93700,0.2182374,1.5293176,,,,,,,,,,,,,,,,, -93800,0.20414142,1.5064778,,,,,,,,,,,,,,,,, -93900,0.22842926,1.6228235,,,,,,,,,,,,,,,,, -94000,0.20498598,1.5410651,,,,,,,,,,,,,,,,, -94100,0.20225638,1.5220128,,,,,,,,,,,,,,,,, -94200,0.21577655,1.5878849,,,,,,,,,,,,,,,,, -94300,0.21476224,1.6017175,,,,,,,,,,,,,,,,, -94400,0.22845104,1.6529617,,,,,,,,,,,,,,,,, -94486,,,0.6843194961547852,1.438670635223389,34.57505063779952,0.6879270076751709,1.4136199951171875,30.146262777719443,3000.0,0.7031317353248596,1.31481671333313,30.01753382188169,3003.0,33640.74346327782,57217.61907982826,33640.74346327782,23572.74237060547,1.244916915893555,0.0 -94500,0.721567,1.6152607,,,,,,,,,,,,,,,,, -94600,0.2103993,1.5759757,,,,,,,,,,,,,,,,, -94700,0.22857893,1.6763198,,,,,,,,,,,,,,,,, -94800,0.20272379,1.5579463,,,,,,,,,,,,,,,,, -94900,0.205561,1.5141456,,,,,,,,,,,,,,,,, -95000,0.21406828,1.5178882,,,,,,,,,,,,,,,,, -95100,0.21053207,1.5713878,,,,,,,,,,,,,,,,, -95200,0.22409265,1.6541669,,,,,,,,,,,,,,,,, -95300,0.20878196,1.501338,,,,,,,,,,,,,,,,, -95400,0.20207965,1.49655,,,,,,,,,,,,,,,,, -95500,0.21593182,1.585461,,,,,,,,,,,,,,,,, -95600,0.21268667,1.4942173,,,,,,,,,,,,,,,,, -95700,0.23327444,1.5861413,,,,,,,,,,,,,,,,, -95800,0.22549598,1.5637708,,,,,,,,,,,,,,,,, -95900,0.23355028,1.5914685,,,,,,,,,,,,,,,,, -96000,0.2060044,1.5099068,,,,,,,,,,,,,,,,, -96100,0.2136193,1.5733593,,,,,,,,,,,,,,,,, -96200,0.21351272,1.6278356,,,,,,,,,,,,,,,,, -96300,0.23070809,1.6125871,,,,,,,,,,,,,,,,, -96400,0.21224768,1.5470324,,,,,,,,,,,,,,,,, -96500,1.3002383,1.5109318,,,,,,,,,,,,,,,,, -96600,0.21474671,1.4908894,,,,,,,,,,,,,,,,, -96700,0.21858934,1.5725381,,,,,,,,,,,,,,,,, -96800,0.20655482,1.5553235,,,,,,,,,,,,,,,,, -96848,,,0.678167462348938,1.4757825136184692,33.81288736021676,0.6871830224990845,1.4069184064865112,30.26572122548636,3000.0,0.7035849094390869,1.308138728141785,30.46707014258092,3003.0,34480.9733145237,58765.871055841446,34480.9733145237,24280.65658521652,1.280604362487793,0.0 -96900,0.20813546,1.570195,,,,,,,,,,,,,,,,, -97000,0.22113812,1.6142986,,,,,,,,,,,,,,,,, -97100,0.20885862,1.5075648,,,,,,,,,,,,,,,,, -97200,0.2174076,1.5346416,,,,,,,,,,,,,,,,, -97300,0.21564944,1.5469658,,,,,,,,,,,,,,,,, -97400,0.22253415,1.5775502,,,,,,,,,,,,,,,,, -97500,0.222079,1.5850008,,,,,,,,,,,,,,,,, -97600,0.22100548,1.5505562,,,,,,,,,,,,,,,,, -97700,0.2110198,1.510314,,,,,,,,,,,,,,,,, -97800,0.21808745,1.5346622,,,,,,,,,,,,,,,,, -97900,0.20436564,1.552351,,,,,,,,,,,,,,,,, -98000,0.21663441,1.5147629,,,,,,,,,,,,,,,,, -98100,0.20611426,1.5439166,,,,,,,,,,,,,,,,, -98200,0.22178897,1.6364579,,,,,,,,,,,,,,,,, -98300,0.20843334,1.5026354,,,,,,,,,,,,,,,,, -98400,0.21571611,1.4572445,,,,,,,,,,,,,,,,, -98500,0.21349421,1.4768376,,,,,,,,,,,,,,,,, -98600,0.21902914,1.534658,,,,,,,,,,,,,,,,, -98700,0.23246782,1.6101738,,,,,,,,,,,,,,,,, -98800,0.22083442,1.6150076,,,,,,,,,,,,,,,,, -98900,0.2218919,1.5118955,,,,,,,,,,,,,,,,, -99000,0.20098284,1.548745,,,,,,,,,,,,,,,,, -99100,0.21271525,1.5911652,,,,,,,,,,,,,,,,, -99200,0.24412087,1.5448097,,,,,,,,,,,,,,,,, -99210,,,0.6783736944198608,1.4781726598739624,34.20753458293841,0.6882989406585693,1.4027429819107056,30.174071090892355,3000.0,0.7043286561965942,1.3000746965408323,30.263596635691624,3003.0,35321.14249563217,60169.06457781792,35321.14249563217,24843.572281837463,1.3168067932128906,0.0 -99300,0.21492623,1.5431147,,,,,,,,,,,,,,,,, -99400,0.21975864,1.5743511,,,,,,,,,,,,,,,,, -99500,0.22099759,1.5244957,,,,,,,,,,,,,,,,, -99600,0.22198702,1.5530305,,,,,,,,,,,,,,,,, -99700,0.24478358,1.5867131,,,,,,,,,,,,,,,,, -99800,0.22573262,1.5380284,,,,,,,,,,,,,,,,, -99900,0.21203315,1.49649,,,,,,,,,,,,,,,,, -100000,0.22153586,1.5682425,,,,,,,,,,,,,,,,, -100100,0.22315094,1.5543555,,,,,,,,,,,,,,,,, -100200,0.21282524,1.5080354,,,,,,,,,,,,,,,,, -100300,0.20558299,1.5656636,,,,,,,,,,,,,,,,, -100400,0.22326204,1.4966149,,,,,,,,,,,,,,,,, -100500,0.21570423,1.5581499,,,,,,,,,,,,,,,,, -100600,0.21349543,1.528642,,,,,,,,,,,,,,,,, -100700,0.21948148,1.5456253,,,,,,,,,,,,,,,,, -100800,0.23157477,1.5963752,,,,,,,,,,,,,,,,, -100900,0.21971053,1.623326,,,,,,,,,,,,,,,,, -101000,0.22654408,1.4992685,,,,,,,,,,,,,,,,, -101100,0.21816593,1.526302,,,,,,,,,,,,,,,,, -101200,0.21251322,1.4905162,,,,,,,,,,,,,,,,, -101300,0.22118622,1.4962423,,,,,,,,,,,,,,,,, -101400,0.2208311,1.5532156,,,,,,,,,,,,,,,,, -101500,0.343696,1.5531309,,,,,,,,,,,,,,,,, -101572,,,0.6858225464820862,1.430709719657898,34.8028409637508,0.689452052116394,1.4000771045684814,29.648971887837902,3000.0,0.7056533694267273,1.2988643646240234,30.105600273315,3003.0,36161.22002506256,61720.23471450806,36161.22002506256,25554.55659222603,1.3547017574310305,0.0 -101600,0.22315353,1.5578297,,,,,,,,,,,,,,,,, -101700,0.2269089,1.5513359,,,,,,,,,,,,,,,,, -101800,0.2313862,1.5166728,,,,,,,,,,,,,,,,, -101900,0.22493032,1.6190742,,,,,,,,,,,,,,,,, -102000,0.21291324,1.5254927,,,,,,,,,,,,,,,,, -102100,0.224767,1.5432256,,,,,,,,,,,,,,,,, -102200,0.23779444,1.5674652,,,,,,,,,,,,,,,,, -102300,0.23642744,1.5342239,,,,,,,,,,,,,,,,, -102400,0.21811184,1.5266043,,,,,,,,,,,,,,,,, -102500,0.21322922,1.562666,,,,,,,,,,,,,,,,, -102600,0.21598898,1.4619888,,,,,,,,,,,,,,,,, -102700,0.2391627,1.580834,,,,,,,,,,,,,,,,, -102800,0.22295572,1.4984437,,,,,,,,,,,,,,,,, -102900,0.21098948,1.4889276,,,,,,,,,,,,,,,,, -103000,0.21198313,1.4944794,,,,,,,,,,,,,,,,, -103100,0.22275837,1.5234125,,,,,,,,,,,,,,,,, -103200,0.22238542,1.5994738,,,,,,,,,,,,,,,,, -103300,0.21007279,1.555881,,,,,,,,,,,,,,,,, -103400,0.23132831,1.5387877,,,,,,,,,,,,,,,,, -103500,0.21649152,1.4898807,,,,,,,,,,,,,,,,, -103600,0.21832186,1.5947251,,,,,,,,,,,,,,,,, -103700,0.25408396,1.4567499,,,,,,,,,,,,,,,,, -103800,0.2262082,1.5584079,,,,,,,,,,,,,,,,, -103900,0.22416233,1.5149513,,,,,,,,,,,,,,,,, -103935,,,0.6802985072135925,1.4587417840957642,34.38016920874783,0.6917335391044617,1.3879741430282593,30.13557535263701,3000.0,0.7094184160232544,1.2878787517547607,30.68149962599856,3003.0,37001.32061576843,63211.98622179032,37001.32061576843,26206.101365804672,1.3905627727508545,0.0 -104000,0.2349285,1.5758711,,,,,,,,,,,,,,,,, -104100,0.22180232,1.520181,,,,,,,,,,,,,,,,, -104200,0.22762874,1.5799662,,,,,,,,,,,,,,,,, -104300,0.2155918,1.4688333,,,,,,,,,,,,,,,,, -104400,0.23452008,1.6289253,,,,,,,,,,,,,,,,, -104500,0.22571953,1.5186286,,,,,,,,,,,,,,,,, -104600,0.22807707,1.5283666,,,,,,,,,,,,,,,,, -104700,0.2312115,1.4864365,,,,,,,,,,,,,,,,, -104800,0.24471723,1.5041492,,,,,,,,,,,,,,,,, -104900,0.21774533,1.5610244,,,,,,,,,,,,,,,,, -105000,0.30585772,1.5281541,,,,,,,,,,,,,,,,, -105100,0.24398966,1.5407925,,,,,,,,,,,,,,,,, -105200,0.22385833,1.5920554,,,,,,,,,,,,,,,,, -105300,0.22246969,1.5521666,,,,,,,,,,,,,,,,, -105400,0.2235204,1.4852425,,,,,,,,,,,,,,,,, -105500,0.22446299,1.4702483,,,,,,,,,,,,,,,,, -105600,0.2416683,1.5133251,,,,,,,,,,,,,,,,, -105700,0.2185447,1.4911071,,,,,,,,,,,,,,,,, -105800,0.24637875,1.5810926,,,,,,,,,,,,,,,,, -105900,0.23325123,1.5348504,,,,,,,,,,,,,,,,, -106000,0.21648537,1.5914459,,,,,,,,,,,,,,,,, -106100,0.2334472,1.4901512,,,,,,,,,,,,,,,,, -106200,0.36084053,1.4858465,,,,,,,,,,,,,,,,, -106297,,,0.6838071942329407,1.439373016357422,34.89865767951385,0.6919938921928406,1.38360857963562,30.601454222110416,3000.0,0.7088606357574463,1.2818655967712402,30.694700225262544,3003.0,37841.38295006752,64663.177958250046,37841.38295006752,26817.12211108208,1.4279468059539795,0.0 -106300,0.22113538,1.5941638,,,,,,,,,,,,,,,,, -106400,0.21906927,1.5601145,,,,,,,,,,,,,,,,, -106500,0.22157086,1.504419,,,,,,,,,,,,,,,,, -106600,0.21980064,1.466708,,,,,,,,,,,,,,,,, -106700,0.24577057,1.5295495,,,,,,,,,,,,,,,,, -106800,0.22597624,1.4904872,,,,,,,,,,,,,,,,, -106900,0.22674155,1.4885722,,,,,,,,,,,,,,,,, -107000,0.21022747,1.4440371,,,,,,,,,,,,,,,,, -107100,0.23854828,1.5409575,,,,,,,,,,,,,,,,, -107200,0.22625303,1.5564396,,,,,,,,,,,,,,,,, -107300,0.21584882,1.4586172,,,,,,,,,,,,,,,,, -107400,0.24590296,1.529305,,,,,,,,,,,,,,,,, -107500,0.23575446,1.5024536,,,,,,,,,,,,,,,,, -107600,0.22873284,1.5460812,,,,,,,,,,,,,,,,, -107700,0.22534542,1.4456397,,,,,,,,,,,,,,,,, -107800,0.23608735,1.4900155,,,,,,,,,,,,,,,,, -107900,0.23098911,1.5384171,,,,,,,,,,,,,,,,, -108000,0.22537312,1.5009285,,,,,,,,,,,,,,,,, -108100,0.24229707,1.5982338,,,,,,,,,,,,,,,,, -108200,0.23470291,1.478081,,,,,,,,,,,,,,,,, -108300,0.22493058,1.5152208,,,,,,,,,,,,,,,,, -108400,0.23455545,1.476647,,,,,,,,,,,,,,,,, -108500,0.23039198,1.507063,,,,,,,,,,,,,,,,, -108600,0.22715524,1.4621925,,,,,,,,,,,,,,,,, -108660,,,0.6914603114128113,1.391718506813049,34.7940850778574,0.6927502155303955,1.3769105672836304,30.70223529853148,3000.0,0.7099994421005249,1.2747658491134644,30.653324968205062,3003.0,38681.60069704056,66200.41820144653,38681.60069704056,27514.03621053696,1.465364694595337,0.0 -108700,0.23282216,1.4837059,,,,,,,,,,,,,,,,, -108800,0.22180259,1.4867476,,,,,,,,,,,,,,,,, -108900,0.22634077,1.5012217,,,,,,,,,,,,,,,,, -109000,0.2509845,1.5220821,,,,,,,,,,,,,,,,, -109100,0.22820555,1.5156182,,,,,,,,,,,,,,,,, -109200,0.2426946,1.4990407,,,,,,,,,,,,,,,,, -109300,0.24886468,1.4802077,,,,,,,,,,,,,,,,, -109400,0.23415159,1.5302598,,,,,,,,,,,,,,,,, -109500,0.22355235,1.4535208,,,,,,,,,,,,,,,,, -109600,0.22858472,1.5122831,,,,,,,,,,,,,,,,, -109700,0.24085341,1.4676297,,,,,,,,,,,,,,,,, -109800,0.22623475,1.4785231,,,,,,,,,,,,,,,,, -109900,0.23859097,1.5386803,,,,,,,,,,,,,,,,, -110000,0.23796663,1.4932351,,,,,,,,,,,,,,,,, -110100,0.23282027,1.4557297,,,,,,,,,,,,,,,,, -110200,0.23126678,1.4802297,,,,,,,,,,,,,,,,, -110300,0.23873916,1.4831827,,,,,,,,,,,,,,,,, -110400,0.23753734,1.5212355,,,,,,,,,,,,,,,,, -110500,0.26623377,1.5809377,,,,,,,,,,,,,,,,, -110600,0.22509637,1.5149541,,,,,,,,,,,,,,,,, -110700,0.23359983,1.520751,,,,,,,,,,,,,,,,, -110800,0.2391404,1.5150322,,,,,,,,,,,,,,,,, -110900,0.2301281,1.5007386,,,,,,,,,,,,,,,,, -111000,0.23040679,1.4974523,,,,,,,,,,,,,,,,, -111022,,,0.6853946447372437,1.4290424585342407,34.54282435835048,0.6935685873031616,1.3749322891235352,30.51653350310114,3000.0,0.7101040482521057,1.273871898651123,30.655470179904714,3003.0,39521.62386965752,67617.26778793335,39521.62386965752,28090.752275705338,1.504340410232544,0.0 -111100,0.2234286,1.4591668,,,,,,,,,,,,,,,,, -111200,0.23402455,1.4870582,,,,,,,,,,,,,,,,, -111300,0.22843035,1.5298567,,,,,,,,,,,,,,,,, -111400,0.2384389,1.5515696,,,,,,,,,,,,,,,,, -111500,0.22495395,1.4738221,,,,,,,,,,,,,,,,, -111600,0.22288561,1.4941564,,,,,,,,,,,,,,,,, -111700,0.24906689,1.4790726,,,,,,,,,,,,,,,,, -111800,0.23419641,1.5169537,,,,,,,,,,,,,,,,, -111900,0.25423667,1.5699031,,,,,,,,,,,,,,,,, -112000,0.22099328,1.4687704,,,,,,,,,,,,,,,,, -112100,0.2237939,1.366251,,,,,,,,,,,,,,,,, -112200,0.23575161,1.4789903,,,,,,,,,,,,,,,,, -112300,0.25086638,1.4320436,,,,,,,,,,,,,,,,, -112400,0.22438262,1.4916217,,,,,,,,,,,,,,,,, -112500,0.24650854,1.4991447,,,,,,,,,,,,,,,,, -112600,0.22829697,1.4794616,,,,,,,,,,,,,,,,, -112700,0.22438318,1.4373194,,,,,,,,,,,,,,,,, -112800,0.22561143,1.488913,,,,,,,,,,,,,,,,, -112900,0.2346086,1.5189484,,,,,,,,,,,,,,,,, -113000,0.22854254,1.482535,,,,,,,,,,,,,,,,, -113100,0.23121765,1.4638692,,,,,,,,,,,,,,,,, -113200,0.23222053,1.474443,,,,,,,,,,,,,,,,, -113300,0.23024492,1.4940722,,,,,,,,,,,,,,,,, -113384,,,0.6964467763900757,1.3651326894760132,35.61580584857655,0.6939033269882202,1.3725615739822388,30.671236136957305,3000.0,0.7107896208763123,1.2676246166229248,30.63459049836028,3003.0,40361.59369277954,69141.02885961533,40361.59369277954,28774.43212223053,1.543736219406128,0.0 -113400,0.22681145,1.4891405,,,,,,,,,,,,,,,,, -113500,0.23002492,1.4186186,,,,,,,,,,,,,,,,, -113600,0.23200326,1.5089966,,,,,,,,,,,,,,,,, -113700,0.23021168,1.505221,,,,,,,,,,,,,,,,, -113800,0.23512655,1.438237,,,,,,,,,,,,,,,,, -113900,0.2421852,1.4977379,,,,,,,,,,,,,,,,, -114000,0.23119691,1.5240321,,,,,,,,,,,,,,,,, -114100,0.32906795,1.4728034,,,,,,,,,,,,,,,,, -114200,0.22779417,1.4769307,,,,,,,,,,,,,,,,, -114300,0.23207098,1.485596,,,,,,,,,,,,,,,,, -114400,0.24313171,1.5175579,,,,,,,,,,,,,,,,, -114500,0.23547177,1.5270963,,,,,,,,,,,,,,,,, -114600,0.23236468,1.5099493,,,,,,,,,,,,,,,,, -114700,0.23044853,1.5099123,,,,,,,,,,,,,,,,, -114800,0.23150676,1.4816334,,,,,,,,,,,,,,,,, -114900,0.23356013,1.4864415,,,,,,,,,,,,,,,,, -115000,0.2280528,1.4453877,,,,,,,,,,,,,,,,, -115100,0.23228067,1.4999377,,,,,,,,,,,,,,,,, -115200,0.23476619,1.4538455,,,,,,,,,,,,,,,,, -115300,0.23969857,1.4689236,,,,,,,,,,,,,,,,, -115400,0.23300426,1.4702272,,,,,,,,,,,,,,,,, -115500,0.22205646,1.4820555,,,,,,,,,,,,,,,,, -115600,0.25003058,1.4972864,,,,,,,,,,,,,,,,, -115700,0.22739604,1.5130286,,,,,,,,,,,,,,,,, -115747,,,0.6904262900352478,1.390766739845276,35.42563523204688,0.693990170955658,1.3671605587005615,30.734451759132792,3000.0,0.7119516730308533,1.2626426219940186,30.89395735573037,3003.0,41201.82719826698,70558.74024915695,41201.82719826698,29351.802276611328,1.5817480087280271,0.0 -115800,0.22606505,1.4464666,,,,,,,,,,,,,,,,, -115900,0.24239014,1.5002048,,,,,,,,,,,,,,,,, -116000,0.243534,1.5244869,,,,,,,,,,,,,,,,, -116100,0.23350768,1.4254448,,,,,,,,,,,,,,,,, -116200,0.23013282,1.4332926,,,,,,,,,,,,,,,,, -116300,0.22757319,1.4848429,,,,,,,,,,,,,,,,, -116400,0.23348187,1.4358485,,,,,,,,,,,,,,,,, -116500,0.22734793,1.4868916,,,,,,,,,,,,,,,,, -116600,0.23740526,1.4299123,,,,,,,,,,,,,,,,, -116700,0.2331327,1.4429779,,,,,,,,,,,,,,,,, -116800,0.23672555,1.52238,,,,,,,,,,,,,,,,, -116900,0.24350573,1.4576113,,,,,,,,,,,,,,,,, -117000,0.2377791,1.5238694,,,,,,,,,,,,,,,,, -117100,0.25404716,1.4848973,,,,,,,,,,,,,,,,, -117200,0.23737931,1.465835,,,,,,,,,,,,,,,,, -117300,0.22579446,1.4346449,,,,,,,,,,,,,,,,, -117400,0.23559749,1.4728034,,,,,,,,,,,,,,,,, -117500,0.23868862,1.53174,,,,,,,,,,,,,,,,, -117600,0.22060934,1.4554898,,,,,,,,,,,,,,,,, -117700,0.23358531,1.4976655,,,,,,,,,,,,,,,,, -117800,0.25643563,1.5842147,,,,,,,,,,,,,,,,, -117900,0.22655846,1.4406116,,,,,,,,,,,,,,,,, -118000,0.25008708,1.5074805,,,,,,,,,,,,,,,,, -118100,0.22209805,1.4129452,,,,,,,,,,,,,,,,, -118109,,,0.6913244128227234,1.3941733837127686,34.82535315227721,0.6944116950035095,1.3669185638427734,30.55812479814081,3000.0,0.7115681767463684,1.2619701623916626,30.788337111092257,3003.0,42041.80287861824,71967.25450706482,42041.80287861824,29920.225604772568,1.624889850616455,0.0 -118200,0.23136832,1.4678903,,,,,,,,,,,,,,,,, -118300,0.22825843,1.4062827,,,,,,,,,,,,,,,,, -118400,0.24111494,1.5461024,,,,,,,,,,,,,,,,, -118500,0.23307034,1.4911028,,,,,,,,,,,,,,,,, -118600,0.24104847,1.5034182,,,,,,,,,,,,,,,,, -118700,0.22718747,1.5485086,,,,,,,,,,,,,,,,, -118800,0.24670312,1.4497399,,,,,,,,,,,,,,,,, -118900,0.22876327,1.4014176,,,,,,,,,,,,,,,,, -119000,0.23287216,1.504046,,,,,,,,,,,,,,,,, -119100,0.22824122,1.4212862,,,,,,,,,,,,,,,,, -119200,0.23429069,1.5024507,,,,,,,,,,,,,,,,, -119300,0.24445023,1.4779432,,,,,,,,,,,,,,,,, -119400,0.2246054,1.4373745,,,,,,,,,,,,,,,,, -119500,0.2320252,1.5070645,,,,,,,,,,,,,,,,, -119600,0.23866417,1.4806913,,,,,,,,,,,,,,,,, -119700,0.22913498,1.4760864,,,,,,,,,,,,,,,,, -119800,0.23673019,1.490068,,,,,,,,,,,,,,,,, -119900,0.22795382,1.4290881,,,,,,,,,,,,,,,,, -120000,0.23017502,1.4274172,,,,,,,,,,,,,,,,, -120100,0.23131299,1.4985756,,,,,,,,,,,,,,,,, -120200,0.2358612,1.4673417,,,,,,,,,,,,,,,,, -120300,0.2370404,1.4468852,,,,,,,,,,,,,,,,, -120400,0.23750019,1.4692787,,,,,,,,,,,,,,,,, -120471,,,0.6960253119468689,1.363867998123169,35.065851404224986,0.6955400109291077,1.363057017326355,30.73799429326848,3000.0,0.7127418518066406,1.2584396600723269,31.113236211947505,3003.0,42881.7198240757,73406.56572580338,42881.7198240757,30519.5025408268,1.671435832977295,0.0 -120500,0.23746154,1.4600384,,,,,,,,,,,,,,,,, -120600,0.24565522,1.4709165,,,,,,,,,,,,,,,,, -120700,0.23494287,1.4091028,,,,,,,,,,,,,,,,, -120800,0.23732424,1.4176223,,,,,,,,,,,,,,,,, -120900,0.23657368,1.4362026,,,,,,,,,,,,,,,,, -121000,0.24717595,1.4526265,,,,,,,,,,,,,,,,, -121100,0.24549204,1.5353942,,,,,,,,,,,,,,,,, -121200,0.23395814,1.4233493,,,,,,,,,,,,,,,,, -121300,0.23404644,1.5495276,,,,,,,,,,,,,,,,, -121400,0.23619044,1.3706111,,,,,,,,,,,,,,,,, -121500,0.2408427,1.4873506,,,,,,,,,,,,,,,,, -121600,0.24730779,1.4949172,,,,,,,,,,,,,,,,, -121700,0.24319215,1.4857991,,,,,,,,,,,,,,,,, -121800,0.23659995,1.4696033,,,,,,,,,,,,,,,,, -121900,0.23104008,1.5004083,,,,,,,,,,,,,,,,, -122000,0.24038425,1.4898016,,,,,,,,,,,,,,,,, -122100,0.23689151,1.4681151,,,,,,,,,,,,,,,,, -122200,0.23708135,1.4229367,,,,,,,,,,,,,,,,, -122300,0.24617542,1.4783571,,,,,,,,,,,,,,,,, -122400,0.23180293,1.4564193,,,,,,,,,,,,,,,,, -122500,0.2408812,1.4946549,,,,,,,,,,,,,,,,, -122600,0.24141009,1.5072118,,,,,,,,,,,,,,,,, -122700,0.24649592,1.4883599,,,,,,,,,,,,,,,,, -122800,0.23029493,1.4663763,,,,,,,,,,,,,,,,, -122831,,,0.6932611465454102,1.3812857866287231,35.56988353162563,0.6957756280899048,1.3611180782318115,30.736138291930217,3000.0,0.7138341665267944,1.256298542022705,31.07882520707422,3003.0,43721.67870378494,74905.83988952637,43721.67870378494,31178.701949357983,1.7121977806091309,0.0 -122900,0.2311972,1.4150618,,,,,,,,,,,,,,,,, -123000,0.24705216,1.5076549,,,,,,,,,,,,,,,,, -123100,0.24572437,1.597052,,,,,,,,,,,,,,,,, -123200,0.2439847,1.4691907,,,,,,,,,,,,,,,,, -123300,0.24724777,1.5109112,,,,,,,,,,,,,,,,, -123400,0.23169069,1.4359998,,,,,,,,,,,,,,,,, -123500,0.23904796,1.4550091,,,,,,,,,,,,,,,,, -123600,0.22445036,1.4032277,,,,,,,,,,,,,,,,, -123700,0.23681274,1.5837697,,,,,,,,,,,,,,,,, -123800,0.22432435,1.405101,,,,,,,,,,,,,,,,, -123900,0.23371917,1.4262079,,,,,,,,,,,,,,,,, -124000,0.2305169,1.4975777,,,,,,,,,,,,,,,,, -124100,0.2439977,1.470324,,,,,,,,,,,,,,,,, -124200,0.23344526,1.4432509,,,,,,,,,,,,,,,,, -124300,0.2363453,1.4195455,,,,,,,,,,,,,,,,, -124400,0.22426651,1.4661303,,,,,,,,,,,,,,,,, -124500,0.241193,1.4137104,,,,,,,,,,,,,,,,, -124600,0.23210788,1.494341,,,,,,,,,,,,,,,,, -124700,0.23049684,1.5014836,,,,,,,,,,,,,,,,, -124800,0.2335606,1.4424314,,,,,,,,,,,,,,,,, -124900,0.24137841,1.4394515,,,,,,,,,,,,,,,,, -125000,0.24092986,1.5043496,,,,,,,,,,,,,,,,, -125100,0.23130922,1.5319948,,,,,,,,,,,,,,,,, -125193,,,0.6958445310592651,1.3700186014175415,35.29065202228168,0.6958004236221313,1.3598471879959106,30.81504363371959,3000.0,0.7141479253768921,1.2538695335388184,31.1441452580602,3003.0,44561.77640080452,76337.77584266663,44561.77640080452,31770.429005146027,1.752969741821289,0.0 -125200,0.22849047,1.4010724,,,,,,,,,,,,,,,,, -125300,0.23643155,1.4288639,,,,,,,,,,,,,,,,, -125400,0.24031606,1.515537,,,,,,,,,,,,,,,,, -125500,0.23662122,1.468918,,,,,,,,,,,,,,,,, -125600,0.24032313,1.4692005,,,,,,,,,,,,,,,,, -125700,0.2549727,1.492936,,,,,,,,,,,,,,,,, -125800,0.22647515,1.3986429,,,,,,,,,,,,,,,,, -125900,0.23950168,1.46793,,,,,,,,,,,,,,,,, -126000,0.22916627,1.430713,,,,,,,,,,,,,,,,, -126100,0.24006844,1.4881533,,,,,,,,,,,,,,,,, -126200,0.22871736,1.3409963,,,,,,,,,,,,,,,,, -126300,0.237061,1.4969213,,,,,,,,,,,,,,,,, -126400,0.23919338,1.5173988,,,,,,,,,,,,,,,,, -126500,0.23704387,1.5228356,,,,,,,,,,,,,,,,, -126600,0.23546265,1.4827495,,,,,,,,,,,,,,,,, -126700,0.24381053,1.4728416,,,,,,,,,,,,,,,,, -126800,0.23957539,1.483898,,,,,,,,,,,,,,,,, -126900,0.2400916,1.5400343,,,,,,,,,,,,,,,,, -127000,0.2313361,1.4736603,,,,,,,,,,,,,,,,, -127100,0.23736171,1.5203791,,,,,,,,,,,,,,,,, -127200,0.2405434,1.4713753,,,,,,,,,,,,,,,,, -127300,0.23505558,1.4711702,,,,,,,,,,,,,,,,, -127400,0.24463385,1.5150732,,,,,,,,,,,,,,,,, -127500,0.22921765,1.3915589,,,,,,,,,,,,,,,,, -127555,,,0.6929237842559814,1.3827592134475708,35.22217671734824,0.6960979700088501,1.3584469556808472,30.751015518353736,3000.0,0.7144616842269897,1.2528407573699951,31.0004361617364,3003.0,45402.009125709534,77775.84038686752,45402.009125709534,32368.1502096653,1.7922303676605225,0.0 -127600,0.2331055,1.4338868,,,,,,,,,,,,,,,,, -127700,0.24194288,1.5323397,,,,,,,,,,,,,,,,, -127800,0.2369142,1.4644974,,,,,,,,,,,,,,,,, -127900,0.2508464,1.554457,,,,,,,,,,,,,,,,, -128000,0.24251883,1.5095778,,,,,,,,,,,,,,,,, -128100,0.24198093,1.4900044,,,,,,,,,,,,,,,,, -128200,0.23944488,1.4671018,,,,,,,,,,,,,,,,, -128300,0.2341799,1.3802826,,,,,,,,,,,,,,,,, -128400,0.24250649,1.3947428,,,,,,,,,,,,,,,,, -128500,0.22144441,1.4091556,,,,,,,,,,,,,,,,, -128600,0.23227558,1.4196444,,,,,,,,,,,,,,,,, -128700,0.24094722,1.4926127,,,,,,,,,,,,,,,,, -128800,0.23945177,1.4441934,,,,,,,,,,,,,,,,, -128900,0.22924934,1.3897295,,,,,,,,,,,,,,,,, -129000,0.23135006,1.3181901,,,,,,,,,,,,,,,,, -129100,0.24972157,1.5004259,,,,,,,,,,,,,,,,, -129200,0.2343895,1.3991728,,,,,,,,,,,,,,,,, -129300,0.23762506,1.5171869,,,,,,,,,,,,,,,,, -129400,0.23657463,1.4993267,,,,,,,,,,,,,,,,, -129500,0.22408561,1.4136395,,,,,,,,,,,,,,,,, -129600,0.24048822,1.3988554,,,,,,,,,,,,,,,,, -129700,0.2338449,1.4353259,,,,,,,,,,,,,,,,, -129800,0.22516325,1.4545171,,,,,,,,,,,,,,,,, -129900,0.23381019,1.4767154,,,,,,,,,,,,,,,,, -129917,,,0.6957194805145264,1.3678160905838013,35.71389968451469,0.6961227655410767,1.358590602874756,30.797391608123867,3000.0,0.7143106460571289,1.2522014379501345,30.975489200283697,3003.0,46242.11846613884,79226.51109528542,46242.11846613884,32978.59933829308,1.833895206451416,0.0 -130000,0.23028512,1.4673835,,,,,,,,,,,,,,,,, -130100,0.23769957,1.4278789,,,,,,,,,,,,,,,,, -130200,0.23186563,1.4193608,,,,,,,,,,,,,,,,, -130300,0.23356983,1.4637252,,,,,,,,,,,,,,,,, -130400,0.22751242,1.4869903,,,,,,,,,,,,,,,,, -130500,0.23958859,1.4815952,,,,,,,,,,,,,,,,, -130600,0.24755783,1.4789808,,,,,,,,,,,,,,,,, -130700,0.22744088,1.4562109,,,,,,,,,,,,,,,,, -130800,0.2355144,1.4913445,,,,,,,,,,,,,,,,, -130900,0.24421471,1.4414499,,,,,,,,,,,,,,,,, -131000,0.23296553,1.4602255,,,,,,,,,,,,,,,,, -131100,0.22424902,1.4432985,,,,,,,,,,,,,,,,, -131200,0.2404748,1.48649,,,,,,,,,,,,,,,,, -131300,0.22615692,1.3563311,,,,,,,,,,,,,,,,, -131400,0.22750305,1.4287871,,,,,,,,,,,,,,,,, -131500,0.23021553,1.4750785,,,,,,,,,,,,,,,,, -131600,0.22670917,1.4104269,,,,,,,,,,,,,,,,, -131700,0.23346306,1.4525849,,,,,,,,,,,,,,,,, -131800,0.22554655,1.4243629,,,,,,,,,,,,,,,,, -131900,0.2311558,1.405234,,,,,,,,,,,,,,,,, -132000,0.24032243,1.4208217,,,,,,,,,,,,,,,,, -132100,0.23867401,1.4121937,,,,,,,,,,,,,,,,, -132200,0.23088033,1.4460725,,,,,,,,,,,,,,,,, -132279,,,0.6954941153526306,1.3714572191238403,35.85878042865687,0.6958624124526978,1.3585927486419678,30.77876783788234,3000.0,0.7146244049072266,1.2520089149475098,31.02888880102673,3003.0,47082.26250171661,80673.64637875557,47082.26250171661,33585.4789557457,1.875113725662232,0.0 -132300,0.22789946,1.4823356,,,,,,,,,,,,,,,,, -132400,0.23025426,1.4258014,,,,,,,,,,,,,,,,, -132500,0.23043463,1.4151182,,,,,,,,,,,,,,,,, -132600,0.23419812,1.4760935,,,,,,,,,,,,,,,,, -132700,0.23496953,1.4070057,,,,,,,,,,,,,,,,, -132800,0.22765397,1.485385,,,,,,,,,,,,,,,,, -132900,0.23146339,1.4643279,,,,,,,,,,,,,,,,, -133000,0.22674775,1.4429533,,,,,,,,,,,,,,,,, -133100,0.22943136,1.4506781,,,,,,,,,,,,,,,,, -133200,0.23506975,1.4183539,,,,,,,,,,,,,,,,, -133300,0.23490553,1.4895782,,,,,,,,,,,,,,,,, -133400,0.22157054,1.4058797,,,,,,,,,,,,,,,,, -133500,0.23197976,1.4240859,,,,,,,,,,,,,,,,, -133600,0.23521766,1.4618121,,,,,,,,,,,,,,,,, -133700,0.23830205,1.4791915,,,,,,,,,,,,,,,,, -133800,0.24157806,1.470298,,,,,,,,,,,,,,,,, -133900,0.23959173,1.4350759,,,,,,,,,,,,,,,,, -134000,0.23885486,1.4705302,,,,,,,,,,,,,,,,, -134100,0.24393122,1.4842094,,,,,,,,,,,,,,,,, -134200,0.22819868,1.4256196,,,,,,,,,,,,,,,,, -134300,0.24628048,1.470411,,,,,,,,,,,,,,,,, -134400,0.22339274,1.4216518,,,,,,,,,,,,,,,,, -134500,0.2425539,1.5059005,,,,,,,,,,,,,,,,, -134600,0.22604014,1.3959557,,,,,,,,,,,,,,,,, -134640,,,0.6974707841873169,1.3609025478363037,35.34378432455704,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,47922.3007247448,82113.95312094688,47922.3007247448,34185.63210296631,1.917802095413208,0.0 -134700,0.23610154,1.4155068,,,,,,,,,,,,,,,,, -134800,0.23604387,1.4750644,,,,,,,,,,,,,,,,, -134900,0.23371051,1.3504077,,,,,,,,,,,,,,,,, -135000,0.23648535,1.472813,,,,,,,,,,,,,,,,, -135100,0.234127,1.3967708,,,,,,,,,,,,,,,,, -135200,0.23273334,1.4550071,,,,,,,,,,,,,,,,, -135300,0.23935911,1.4588574,,,,,,,,,,,,,,,,, -135400,0.23418003,1.3990648,,,,,,,,,,,,,,,,, -135500,0.22554223,1.4558876,,,,,,,,,,,,,,,,, -135600,0.24370477,1.553964,,,,,,,,,,,,,,,,, -135700,0.2276344,1.4130557,,,,,,,,,,,,,,,,, -135800,0.23836337,1.4591951,,,,,,,,,,,,,,,,, -135900,0.23035896,1.3869281,,,,,,,,,,,,,,,,, -136000,0.23111267,1.4648817,,,,,,,,,,,,,,,,, -136100,0.2395458,1.4825397,,,,,,,,,,,,,,,,, -136200,0.22857593,1.444356,,,,,,,,,,,,,,,,, -136300,0.23558305,1.4220972,,,,,,,,,,,,,,,,, -136400,0.23996338,1.3971788,,,,,,,,,,,,,,,,, -136500,0.23545593,1.4541095,,,,,,,,,,,,,,,,, -136600,0.23446319,1.515414,,,,,,,,,,,,,,,,, -136700,0.23562752,1.5149688,,,,,,,,,,,,,,,,, -136800,0.23784134,1.4312444,,,,,,,,,,,,,,,,, -136900,0.23904882,1.4621929,,,,,,,,,,,,,,,,, -137000,0.23837827,1.3983635,,,,,,,,,,,,,,,,, -137001,,,0.699103832244873,1.3539986610412598,35.55390017405729,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,48762.44181466103,83555.59596848488,48762.44181466103,34787.02144932747,1.959381103515625,0.0 -137100,0.24231201,1.463751,,,,,,,,,,,,,,,,, -137200,0.23971409,1.505186,,,,,,,,,,,,,,,,, -137300,0.23809025,1.4875652,,,,,,,,,,,,,,,,, -137400,0.23419556,1.4606123,,,,,,,,,,,,,,,,, -137500,0.22779918,1.3869562,,,,,,,,,,,,,,,,, -137600,0.23970369,1.5119511,,,,,,,,,,,,,,,,, -137700,0.2232015,1.4467808,,,,,,,,,,,,,,,,, -137800,0.23859613,1.3696516,,,,,,,,,,,,,,,,, -137900,0.23614116,1.5095052,,,,,,,,,,,,,,,,, -138000,0.22110334,1.3575664,,,,,,,,,,,,,,,,, -138100,0.22833724,1.4144288,,,,,,,,,,,,,,,,, -138200,0.23980477,1.4196854,,,,,,,,,,,,,,,,, -138300,0.23028296,1.4832957,,,,,,,,,,,,,,,,, -138400,0.23128064,1.4097208,,,,,,,,,,,,,,,,, -138500,0.24369207,1.4225674,,,,,,,,,,,,,,,,, -138600,0.22605991,1.4500209,,,,,,,,,,,,,,,,, -138700,0.2362571,1.4668305,,,,,,,,,,,,,,,,, -138800,0.23377514,1.379303,,,,,,,,,,,,,,,,, -138900,0.23060207,1.4724017,,,,,,,,,,,,,,,,, -139000,0.22902144,1.3728913,,,,,,,,,,,,,,,,, -139100,0.2437336,1.4553678,,,,,,,,,,,,,,,,, -139200,0.24123721,1.454802,,,,,,,,,,,,,,,,, -139300,0.22556272,1.4357606,,,,,,,,,,,,,,,,, -139363,,,0.6987770199775696,1.35377836227417,35.64343763817623,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,49602.64884185791,85004.57704162598,49602.64884185791,35395.68215799332,2.0005979537963867,0.0 -139400,0.22550276,1.3854382,,,,,,,,,,,,,,,,, -139500,0.23177655,1.4726105,,,,,,,,,,,,,,,,, -139600,0.23761496,1.4251944,,,,,,,,,,,,,,,,, -139700,0.23363303,1.4758035,,,,,,,,,,,,,,,,, -139800,0.24110712,1.4726602,,,,,,,,,,,,,,,,, -139900,0.24194263,1.4423099,,,,,,,,,,,,,,,,, -140000,0.2433337,1.4281142,,,,,,,,,,,,,,,,, -140100,0.22632723,1.3649738,,,,,,,,,,,,,,,,, -140200,0.22828838,1.437688,,,,,,,,,,,,,,,,, -140300,0.2464664,1.43575,,,,,,,,,,,,,,,,, -140400,0.22715054,1.3970956,,,,,,,,,,,,,,,,, -140500,0.23071149,1.4611919,,,,,,,,,,,,,,,,, -140600,0.23048922,1.4851383,,,,,,,,,,,,,,,,, -140700,0.23912477,1.4464512,,,,,,,,,,,,,,,,, -140800,0.23587954,1.4584165,,,,,,,,,,,,,,,,, -140900,0.22580673,1.4638891,,,,,,,,,,,,,,,,, -141000,0.22852896,1.4171561,,,,,,,,,,,,,,,,, -141100,0.23467295,1.4591577,,,,,,,,,,,,,,,,, -141200,0.24261491,1.4283135,,,,,,,,,,,,,,,,, -141300,0.23151225,1.3841059,,,,,,,,,,,,,,,,, -141400,0.22838353,1.3587725,,,,,,,,,,,,,,,,, -141500,0.2394401,1.4727523,,,,,,,,,,,,,,,,, -141600,0.24292596,1.4747238,,,,,,,,,,,,,,,,, -141700,0.23929308,1.4295639,,,,,,,,,,,,,,,,, -141725,,,0.6979433298110962,1.3597429990768433,35.16172339456949,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,50442.791553497314,86451.35402941704,50442.791553497314,36002.202165842056,2.0436899662017822,0.0 -141800,0.23993209,1.4820056,,,,,,,,,,,,,,,,, -141900,0.23710999,1.5053397,,,,,,,,,,,,,,,,, -142000,0.23613466,1.4557287,,,,,,,,,,,,,,,,, -142100,0.24031761,1.4061512,,,,,,,,,,,,,,,,, -142200,0.23406239,1.4403151,,,,,,,,,,,,,,,,, -142300,0.24129699,1.4476147,,,,,,,,,,,,,,,,, -142400,0.24371438,1.4917737,,,,,,,,,,,,,,,,, -142500,0.23845333,1.4831496,,,,,,,,,,,,,,,,, -142600,0.23949355,1.4691473,,,,,,,,,,,,,,,,, -142700,0.22450802,1.3441784,,,,,,,,,,,,,,,,, -142800,0.23312438,1.4725153,,,,,,,,,,,,,,,,, -142900,0.22665945,1.4324657,,,,,,,,,,,,,,,,, -143000,0.2336571,1.4879164,,,,,,,,,,,,,,,,, -143100,0.22821067,1.4292358,,,,,,,,,,,,,,,,, -143200,0.23665982,1.4600174,,,,,,,,,,,,,,,,, -143300,0.2453371,1.5362735,,,,,,,,,,,,,,,,, -143400,0.22918062,1.4267871,,,,,,,,,,,,,,,,, -143500,0.24205925,1.4786613,,,,,,,,,,,,,,,,, -143600,0.23670709,1.4601063,,,,,,,,,,,,,,,,, -143700,0.23363917,1.4035256,,,,,,,,,,,,,,,,, -143800,0.22888897,1.4012388,,,,,,,,,,,,,,,,, -143900,0.22786243,1.4804327,,,,,,,,,,,,,,,,, -144000,0.24357842,1.481263,,,,,,,,,,,,,,,,, -144087,,,0.6965329051017761,1.362047791481018,35.24575844551815,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,51283.00962305069,87898.00025892258,51283.00962305069,36608.51600170136,2.087331771850586,0.0 -144100,0.22385542,1.3424296,,,,,,,,,,,,,,,,, -144200,0.2257061,1.4238081,,,,,,,,,,,,,,,,, -144300,0.23167533,1.431931,,,,,,,,,,,,,,,,, -144400,0.23544718,1.4831997,,,,,,,,,,,,,,,,, -144500,0.23797719,1.3836386,,,,,,,,,,,,,,,,, -144600,0.2193126,1.4160947,,,,,,,,,,,,,,,,, -144700,0.22684136,1.3772664,,,,,,,,,,,,,,,,, -144800,0.23386863,1.4611238,,,,,,,,,,,,,,,,, -144900,0.24067138,1.470294,,,,,,,,,,,,,,,,, -145000,0.23375389,1.4243739,,,,,,,,,,,,,,,,, -145100,0.22675689,1.4702648,,,,,,,,,,,,,,,,, -145200,0.23478381,1.405789,,,,,,,,,,,,,,,,, -145300,0.23868653,1.4710073,,,,,,,,,,,,,,,,, -145400,0.25641498,1.4698955,,,,,,,,,,,,,,,,, -145500,0.22520751,1.4751751,,,,,,,,,,,,,,,,, -145600,0.23439151,1.3439168,,,,,,,,,,,,,,,,, -145700,0.22730929,1.4107409,,,,,,,,,,,,,,,,, -145800,0.24267727,1.4621919,,,,,,,,,,,,,,,,, -145900,0.23473333,1.4659865,,,,,,,,,,,,,,,,, -146000,0.22932127,1.4450259,,,,,,,,,,,,,,,,, -146100,0.24213889,1.4323566,,,,,,,,,,,,,,,,, -146200,0.23465686,1.542525,,,,,,,,,,,,,,,,, -146300,0.23464984,1.4297163,,,,,,,,,,,,,,,,, -146400,0.23817217,1.478746,,,,,,,,,,,,,,,,, -146449,,,0.6960597634315491,1.3612914085388184,35.78783044286703,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,52123.04042816162,89346.19786715508,52123.04042816162,37216.56715321541,2.130098581314087,0.0 -146500,0.23410182,1.4910964,,,,,,,,,,,,,,,,, -146600,0.23265925,1.4681088,,,,,,,,,,,,,,,,, -146700,0.23534839,1.4598402,,,,,,,,,,,,,,,,, -146800,0.23800094,1.4870626,,,,,,,,,,,,,,,,, -146900,0.23577847,1.4359978,,,,,,,,,,,,,,,,, -147000,0.22039261,1.4405262,,,,,,,,,,,,,,,,, -147100,0.22955762,1.3964825,,,,,,,,,,,,,,,,, -147200,0.23522994,1.5037773,,,,,,,,,,,,,,,,, -147300,0.23244889,1.4744838,,,,,,,,,,,,,,,,, -147400,0.22325175,1.4276116,,,,,,,,,,,,,,,,, -147500,0.23594598,1.4452065,,,,,,,,,,,,,,,,, -147600,0.23642497,1.4162363,,,,,,,,,,,,,,,,, -147700,0.24365404,1.3949546,,,,,,,,,,,,,,,,, -147800,0.22901155,1.4637531,,,,,,,,,,,,,,,,, -147900,0.23694196,1.3957019,,,,,,,,,,,,,,,,, -148000,0.24410549,1.4693478,,,,,,,,,,,,,,,,, -148100,0.23685011,1.4122201,,,,,,,,,,,,,,,,, -148200,0.24133426,1.4805915,,,,,,,,,,,,,,,,, -148300,0.2496069,1.3760244,,,,,,,,,,,,,,,,, -148400,0.2435077,1.4806844,,,,,,,,,,,,,,,,, -148500,0.22956935,1.3548775,,,,,,,,,,,,,,,,, -148600,0.2446796,1.517226,,,,,,,,,,,,,,,,, -148700,0.22770128,1.4445133,,,,,,,,,,,,,,,,, -148800,0.24237789,1.475769,,,,,,,,,,,,,,,,, -148811,,,0.6943530440330505,1.3751060962677002,35.51247185279916,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,52963.1450676918,90783.82259011269,52963.1450676918,37813.96529150009,2.1822049617767334,0.0 -148900,0.2403388,1.4672377,,,,,,,,,,,,,,,,, -149000,0.24046294,1.4482868,,,,,,,,,,,,,,,,, -149100,0.23736967,1.4554929,,,,,,,,,,,,,,,,, -149200,0.23499693,1.4159825,,,,,,,,,,,,,,,,, -149300,0.24192268,1.422455,,,,,,,,,,,,,,,,, -149400,0.22752811,1.4299598,,,,,,,,,,,,,,,,, -149500,0.23381965,1.4479152,,,,,,,,,,,,,,,,, -149600,0.23628956,1.4686083,,,,,,,,,,,,,,,,, -149700,0.22932744,1.4574827,,,,,,,,,,,,,,,,, -149800,0.22260876,1.4167486,,,,,,,,,,,,,,,,, -149900,0.23298088,1.4606749,,,,,,,,,,,,,,,,, -150000,0.24022084,1.4798805,,,,,,,,,,,,,,,,, -150100,0.23163716,1.4005616,,,,,,,,,,,,,,,,, -150200,0.23562479,1.4348989,,,,,,,,,,,,,,,,, -150300,0.22973828,1.3862137,,,,,,,,,,,,,,,,, -150400,0.2274607,1.3946292,,,,,,,,,,,,,,,,, -150500,0.23682454,1.4498824,,,,,,,,,,,,,,,,, -150600,0.24383911,1.4518776,,,,,,,,,,,,,,,,, -150700,1.2616507,1.4680034,,,,,,,,,,,,,,,,, -150800,0.24035068,1.4052651,,,,,,,,,,,,,,,,, -150900,0.23814826,1.487246,,,,,,,,,,,,,,,,, -151000,0.23716986,1.459313,,,,,,,,,,,,,,,,, -151100,0.23844396,1.4995086,,,,,,,,,,,,,,,,, -151174,,,0.6944592595100403,1.378005027770996,35.85757723474082,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,53803.25436472893,92227.89867305756,53803.25436472893,38417.81984090805,2.22391128540039,0.0 -151200,0.23262656,1.4532558,,,,,,,,,,,,,,,,, -151300,0.23153453,1.3664217,,,,,,,,,,,,,,,,, -151400,0.23133,1.4797348,,,,,,,,,,,,,,,,, -151500,0.23979597,1.4587964,,,,,,,,,,,,,,,,, -151600,0.23826785,1.3933927,,,,,,,,,,,,,,,,, -151700,0.23608707,1.4127469,,,,,,,,,,,,,,,,, -151800,0.24690092,1.4980731,,,,,,,,,,,,,,,,, -151900,0.23785163,1.5031891,,,,,,,,,,,,,,,,, -152000,0.2565459,1.4563075,,,,,,,,,,,,,,,,, -152100,0.23799603,1.4770606,,,,,,,,,,,,,,,,, -152200,0.23250318,1.4644027,,,,,,,,,,,,,,,,, -152300,0.23648776,1.3805543,,,,,,,,,,,,,,,,, -152400,0.23243771,1.4266931,,,,,,,,,,,,,,,,, -152500,0.24261396,1.50395,,,,,,,,,,,,,,,,, -152600,0.23618343,1.4051437,,,,,,,,,,,,,,,,, -152700,0.22773594,1.3550066,,,,,,,,,,,,,,,,, -152800,0.22830008,1.4739895,,,,,,,,,,,,,,,,, -152900,0.24404778,1.4448246,,,,,,,,,,,,,,,,, -153000,0.23234338,1.4582859,,,,,,,,,,,,,,,,, -153100,0.2321694,1.4333998,,,,,,,,,,,,,,,,, -153200,0.2508782,1.3641977,,,,,,,,,,,,,,,,, -153300,0.23434904,1.5048782,,,,,,,,,,,,,,,,, -153400,0.2305031,1.4393246,,,,,,,,,,,,,,,,, -153500,0.23537838,1.4366264,,,,,,,,,,,,,,,,, -153534,,,0.6963521242141724,1.3660151958465576,35.2620739329758,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,54643.3279042244,93671.47524333,54643.3279042244,39021.20381522179,2.266636848449707,0.0 -153600,0.22536533,1.4293649,,,,,,,,,,,,,,,,, -153700,0.23440285,1.3947748,,,,,,,,,,,,,,,,, -153800,0.23156564,1.4003091,,,,,,,,,,,,,,,,, -153900,0.33990327,1.4319928,,,,,,,,,,,,,,,,, -154000,0.23196754,1.4559076,,,,,,,,,,,,,,,,, -154100,0.231927,1.4202231,,,,,,,,,,,,,,,,, -154200,0.24290453,1.5454153,,,,,,,,,,,,,,,,, -154300,0.2270926,1.4242214,,,,,,,,,,,,,,,,, -154400,0.2346143,1.4511654,,,,,,,,,,,,,,,,, -154500,0.23714426,1.400259,,,,,,,,,,,,,,,,, -154600,0.23833771,1.4520494,,,,,,,,,,,,,,,,, -154700,0.23571952,1.3378706,,,,,,,,,,,,,,,,, -154800,0.22844689,1.4758062,,,,,,,,,,,,,,,,, -154900,0.22468756,1.4330194,,,,,,,,,,,,,,,,, -155000,0.22949201,1.413148,,,,,,,,,,,,,,,,, -155100,0.24230264,1.4859709,,,,,,,,,,,,,,,,, -155200,0.24681884,1.4519612,,,,,,,,,,,,,,,,, -155300,0.23962608,1.4717891,,,,,,,,,,,,,,,,, -155400,0.24629277,1.5132277,,,,,,,,,,,,,,,,, -155500,0.2297796,1.4343072,,,,,,,,,,,,,,,,, -155600,0.2314951,1.4677502,,,,,,,,,,,,,,,,, -155700,0.22868776,1.4400382,,,,,,,,,,,,,,,,, -155800,0.23089628,1.4321828,,,,,,,,,,,,,,,,, -155896,,,0.6948073506355286,1.376955270767212,35.64392910309191,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,55483.48570561409,95110.29526019096,55483.48570561409,39619.74911165237,2.3116683959960938,0.0 -155900,0.232655,1.5033631,,,,,,,,,,,,,,,,, -156000,0.2380059,1.4262489,,,,,,,,,,,,,,,,, -156100,0.2405479,1.5524706,,,,,,,,,,,,,,,,, -156200,0.2384081,1.5175575,,,,,,,,,,,,,,,,, -156300,0.2191068,1.31455,,,,,,,,,,,,,,,,, -156400,0.22288455,1.357799,,,,,,,,,,,,,,,,, -156500,0.22899461,1.4410932,,,,,,,,,,,,,,,,, -156600,0.235325,1.3671553,,,,,,,,,,,,,,,,, -156700,0.23683394,1.4059763,,,,,,,,,,,,,,,,, -156800,0.26783764,1.4332136,,,,,,,,,,,,,,,,, -156900,0.22920224,1.4600313,,,,,,,,,,,,,,,,, -157000,0.2365836,1.485589,,,,,,,,,,,,,,,,, -157100,0.23321196,1.4442343,,,,,,,,,,,,,,,,, -157200,0.23575003,1.4621336,,,,,,,,,,,,,,,,, -157300,0.23370734,1.3723906,,,,,,,,,,,,,,,,, -157400,0.23979689,1.4687549,,,,,,,,,,,,,,,,, -157500,0.22231355,1.4552988,,,,,,,,,,,,,,,,, -157600,0.24037571,1.4749986,,,,,,,,,,,,,,,,, -157700,0.23632094,1.4410781,,,,,,,,,,,,,,,,, -157800,0.22813617,1.4196998,,,,,,,,,,,,,,,,, -157900,0.23200966,1.4747841,,,,,,,,,,,,,,,,, -158000,0.23144852,1.4302512,,,,,,,,,,,,,,,,, -158100,0.23328717,1.4337878,,,,,,,,,,,,,,,,, -158200,0.23545505,1.4532267,,,,,,,,,,,,,,,,, -158257,,,0.6981350183486938,1.3524571657180786,35.56427327468007,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,56323.527406692505,96555.0881664753,56323.527406692505,40224.372921705246,2.366755485534668,0.0 -158300,0.22938716,1.4887067,,,,,,,,,,,,,,,,, -158400,0.28404623,1.4366713,,,,,,,,,,,,,,,,, -158500,0.22774164,1.35535,,,,,,,,,,,,,,,,, -158600,0.233813,1.4588519,,,,,,,,,,,,,,,,, -158700,0.23407827,1.4687778,,,,,,,,,,,,,,,,, -158800,0.24683504,1.5267782,,,,,,,,,,,,,,,,, -158900,0.23295598,1.4068749,,,,,,,,,,,,,,,,, -159000,0.23724262,1.4978576,,,,,,,,,,,,,,,,, -159100,0.23267747,1.4147946,,,,,,,,,,,,,,,,, -159200,0.2431369,1.43773,,,,,,,,,,,,,,,,, -159300,0.25091365,1.5202882,,,,,,,,,,,,,,,,, -159400,0.23404759,1.4206178,,,,,,,,,,,,,,,,, -159500,0.22491655,1.370069,,,,,,,,,,,,,,,,, -159600,0.22700433,1.5156274,,,,,,,,,,,,,,,,, -159700,0.22453308,1.411734,,,,,,,,,,,,,,,,, -159800,0.22544856,1.3980759,,,,,,,,,,,,,,,,, -159900,0.22904971,1.4351133,,,,,,,,,,,,,,,,, -160000,0.22927018,1.4289176,,,,,,,,,,,,,,,,, -160100,0.22325304,1.4104346,,,,,,,,,,,,,,,,, -160200,0.24419184,1.5300643,,,,,,,,,,,,,,,,, -160300,0.23552513,1.4325881,,,,,,,,,,,,,,,,, -160400,0.23450068,1.4261076,,,,,,,,,,,,,,,,, -160500,0.22585063,1.4431415,,,,,,,,,,,,,,,,, -160600,0.23967132,1.5118527,,,,,,,,,,,,,,,,, -160619,,,0.6961396336555481,1.366602063179016,35.31053510832538,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,57163.6197514534,98003.6810424328,57163.6197514534,40832.7593934536,2.4104344844818115,0.0 -160700,0.22994865,1.3845038,,,,,,,,,,,,,,,,, -160800,0.243916,1.4996148,,,,,,,,,,,,,,,,, -160900,0.23420274,1.5578734,,,,,,,,,,,,,,,,, -161000,0.22868802,1.4097058,,,,,,,,,,,,,,,,, -161100,0.22459537,1.4230642,,,,,,,,,,,,,,,,, -161200,0.24006866,1.3807726,,,,,,,,,,,,,,,,, -161300,0.23495856,1.4653955,,,,,,,,,,,,,,,,, -161400,0.23324136,1.5033689,,,,,,,,,,,,,,,,, -161500,0.22999993,1.4280674,,,,,,,,,,,,,,,,, -161600,0.23711444,1.4738419,,,,,,,,,,,,,,,,, -161700,0.2299877,1.3814704,,,,,,,,,,,,,,,,, -161800,0.30102292,1.4903473,,,,,,,,,,,,,,,,, -161900,0.26549107,1.5724709,,,,,,,,,,,,,,,,, -162000,0.23528919,1.4077028,,,,,,,,,,,,,,,,, -162100,0.24575248,1.4494184,,,,,,,,,,,,,,,,, -162200,0.23210914,1.4329857,,,,,,,,,,,,,,,,, -162300,0.22732775,1.4397154,,,,,,,,,,,,,,,,, -162400,0.2335496,1.462871,,,,,,,,,,,,,,,,, -162500,0.24796297,1.5141312,,,,,,,,,,,,,,,,, -162600,0.24116395,1.5098387,,,,,,,,,,,,,,,,, -162700,0.26035616,1.4675118,,,,,,,,,,,,,,,,, -162800,0.23487744,1.465596,,,,,,,,,,,,,,,,, -162900,0.22562441,1.4620153,,,,,,,,,,,,,,,,, -162981,,,0.6966844201087952,1.3696157932281494,35.52774077610771,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,58003.75202512741,99448.36136484146,58003.75202512741,41437.19096302986,2.4548990726470947,0.0 -163000,0.2415829,1.4587905,,,,,,,,,,,,,,,,, -163100,0.23220791,1.511659,,,,,,,,,,,,,,,,, -163200,0.23922417,1.5153079,,,,,,,,,,,,,,,,, -163300,0.22719106,1.4900169,,,,,,,,,,,,,,,,, -163400,0.2332575,1.4919789,,,,,,,,,,,,,,,,, -163500,0.23083529,1.4981374,,,,,,,,,,,,,,,,, -163600,0.23824081,1.4925485,,,,,,,,,,,,,,,,, -163700,0.2270336,1.4359546,,,,,,,,,,,,,,,,, -163800,0.22574373,1.4014766,,,,,,,,,,,,,,,,, -163900,0.23744267,1.47752,,,,,,,,,,,,,,,,, -164000,0.24046436,1.4514506,,,,,,,,,,,,,,,,, -164100,0.22535846,1.4226589,,,,,,,,,,,,,,,,, -164200,0.23207934,1.4213597,,,,,,,,,,,,,,,,, -164300,0.2457222,1.4014086,,,,,,,,,,,,,,,,, -164400,0.22939606,1.4379995,,,,,,,,,,,,,,,,, -164500,0.23161124,1.4804238,,,,,,,,,,,,,,,,, -164600,0.24935898,1.4673132,,,,,,,,,,,,,,,,, -164700,0.2316131,1.4522091,,,,,,,,,,,,,,,,, -164800,0.22611853,1.4538338,,,,,,,,,,,,,,,,, -164900,0.22933783,1.4782795,,,,,,,,,,,,,,,,, -165000,0.22767738,1.3719437,,,,,,,,,,,,,,,,, -165100,0.25056732,1.5040375,,,,,,,,,,,,,,,,, -165200,0.23413876,1.4502927,,,,,,,,,,,,,,,,, -165300,0.21803159,1.4098762,,,,,,,,,,,,,,,,, -165342,,,0.6953707337379456,1.377099871635437,35.801689686697244,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,58843.80774831772,100896.804625988,58843.80774831772,42045.46137666702,2.499814748764038,0.0 -165400,0.23561336,1.5066624,,,,,,,,,,,,,,,,, -165500,0.24601811,1.4186487,,,,,,,,,,,,,,,,, -165600,0.23644575,1.4787766,,,,,,,,,,,,,,,,, -165700,0.23326823,1.47366,,,,,,,,,,,,,,,,, -165800,0.23575626,1.4872434,,,,,,,,,,,,,,,,, -165900,0.23690103,1.3881825,,,,,,,,,,,,,,,,, -166000,0.24042784,1.3971089,,,,,,,,,,,,,,,,, -166100,0.2352782,1.4571589,,,,,,,,,,,,,,,,, -166200,0.22806737,1.3762144,,,,,,,,,,,,,,,,, -166300,0.22672288,1.3324358,,,,,,,,,,,,,,,,, -166400,0.2408327,1.4370382,,,,,,,,,,,,,,,,, -166500,0.23056646,1.4513078,,,,,,,,,,,,,,,,, -166600,0.23159492,1.4290804,,,,,,,,,,,,,,,,, -166700,0.23906761,1.4272746,,,,,,,,,,,,,,,,, -166800,0.23834175,1.4667108,,,,,,,,,,,,,,,,, -166900,0.22826973,1.4617575,,,,,,,,,,,,,,,,, -167000,0.23084448,1.4207339,,,,,,,,,,,,,,,,, -167100,0.24351019,1.543507,,,,,,,,,,,,,,,,, -167200,0.24111176,1.4492291,,,,,,,,,,,,,,,,, -167300,0.22458884,1.3706199,,,,,,,,,,,,,,,,, -167400,0.22520134,1.4378179,,,,,,,,,,,,,,,,, -167500,0.25162968,1.5024154,,,,,,,,,,,,,,,,, -167600,0.22842787,1.4347334,,,,,,,,,,,,,,,,, -167700,0.22669601,1.4289186,,,,,,,,,,,,,,,,, -167701,,,0.6936147212982178,1.382083773612976,35.36641362085959,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,59683.72746658325,102345.72543001176,59683.72746658325,42654.33971261978,2.5470104217529297,0.0 -167800,0.23909803,1.3978031,,,,,,,,,,,,,,,,, -167900,0.23152123,1.4698738,,,,,,,,,,,,,,,,, -168000,0.22426769,1.393983,,,,,,,,,,,,,,,,, -168100,0.23126566,1.3809096,,,,,,,,,,,,,,,,, -168200,0.22553338,1.4261657,,,,,,,,,,,,,,,,, -168300,0.22421806,1.4654269,,,,,,,,,,,,,,,,, -168400,0.23743503,1.4932846,,,,,,,,,,,,,,,,, -168500,0.23951797,1.4461981,,,,,,,,,,,,,,,,, -168600,0.22592087,1.4159621,,,,,,,,,,,,,,,,, -168700,0.24133706,1.4650226,,,,,,,,,,,,,,,,, -168800,0.24121225,1.4391731,,,,,,,,,,,,,,,,, -168900,0.24223627,1.4127725,,,,,,,,,,,,,,,,, -169000,0.22991996,1.4771427,,,,,,,,,,,,,,,,, -169100,0.23056377,1.4146047,,,,,,,,,,,,,,,,, -169200,0.25573307,1.4710972,,,,,,,,,,,,,,,,, -169300,0.22252877,1.4546671,,,,,,,,,,,,,,,,, -169400,0.22718416,1.4458855,,,,,,,,,,,,,,,,, -169500,0.23785147,1.4287249,,,,,,,,,,,,,,,,, -169600,0.23360886,1.3535284,,,,,,,,,,,,,,,,, -169700,0.23958832,1.438653,,,,,,,,,,,,,,,,, -169800,0.24112613,1.4784902,,,,,,,,,,,,,,,,, -169900,0.23316579,1.4021881,,,,,,,,,,,,,,,,, -170000,0.23802963,1.4097543,,,,,,,,,,,,,,,,, -170063,,,0.6958931684494019,1.3654128313064575,35.463748492037126,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,60523.73269152641,103772.97212171556,60523.73269152641,43241.46246051788,2.592487573623657,0.0 -170100,0.23462592,1.348446,,,,,,,,,,,,,,,,, -170200,0.23838453,1.4362488,,,,,,,,,,,,,,,,, -170300,0.23046963,1.4903586,,,,,,,,,,,,,,,,, -170400,0.23803365,1.5674279,,,,,,,,,,,,,,,,, -170500,0.23301958,1.383116,,,,,,,,,,,,,,,,, -170600,0.237534,1.5209042,,,,,,,,,,,,,,,,, -170700,0.22675596,1.4134911,,,,,,,,,,,,,,,,, -170800,0.24302882,1.4353392,,,,,,,,,,,,,,,,, -170900,0.23330897,1.4103718,,,,,,,,,,,,,,,,, -171000,0.23829713,1.482592,,,,,,,,,,,,,,,,, -171100,0.22667891,1.4268311,,,,,,,,,,,,,,,,, -171200,0.23880887,1.4665549,,,,,,,,,,,,,,,,, -171300,0.23213276,1.4249713,,,,,,,,,,,,,,,,, -171400,0.23987281,1.402273,,,,,,,,,,,,,,,,, -171500,0.24029115,1.5137452,,,,,,,,,,,,,,,,, -171600,0.2367386,1.3939668,,,,,,,,,,,,,,,,, -171700,0.23710695,1.4017386,,,,,,,,,,,,,,,,, -171800,0.2344188,1.4403015,,,,,,,,,,,,,,,,, -171900,0.23672257,1.4314381,,,,,,,,,,,,,,,,, -172000,0.22557457,1.4259259,,,,,,,,,,,,,,,,, -172100,0.23454003,1.4229969,,,,,,,,,,,,,,,,, -172200,0.24613327,1.5239208,,,,,,,,,,,,,,,,, -172300,0.23645702,1.4137508,,,,,,,,,,,,,,,,, -172400,0.22810942,1.4362087,,,,,,,,,,,,,,,,, -172424,,,0.6973316669464111,1.3603744506835938,35.595653273466674,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,61363.77395796776,105216.657037735,61363.77395796776,43844.98872923851,2.6387388706207275,0.0 -172500,0.2268728,1.4213247,,,,,,,,,,,,,,,,, -172600,0.22886899,1.4304137,,,,,,,,,,,,,,,,, -172700,0.24329598,1.4013915,,,,,,,,,,,,,,,,, -172800,0.2341754,1.477306,,,,,,,,,,,,,,,,, -172900,0.22871858,1.3295863,,,,,,,,,,,,,,,,, -173000,0.24635991,1.4668245,,,,,,,,,,,,,,,,, -173100,0.23809008,1.4686375,,,,,,,,,,,,,,,,, -173200,0.23774059,1.4111509,,,,,,,,,,,,,,,,, -173300,0.22803561,1.4190122,,,,,,,,,,,,,,,,, -173400,0.23310184,1.4645408,,,,,,,,,,,,,,,,, -173500,0.23648657,1.3782997,,,,,,,,,,,,,,,,, -173600,0.2401171,1.4182503,,,,,,,,,,,,,,,,, -173700,0.23228633,1.4409767,,,,,,,,,,,,,,,,, -173800,0.23532297,1.492862,,,,,,,,,,,,,,,,, -173900,0.23708664,1.4877744,,,,,,,,,,,,,,,,, -174000,0.23619837,1.4255378,,,,,,,,,,,,,,,,, -174100,0.2375434,1.4324384,,,,,,,,,,,,,,,,, -174200,0.23566231,1.4610864,,,,,,,,,,,,,,,,, -174300,0.23036225,1.4284317,,,,,,,,,,,,,,,,, -174400,0.23294207,1.4003597,,,,,,,,,,,,,,,,, -174500,0.23112392,1.4281429,,,,,,,,,,,,,,,,, -174600,0.2260592,1.3642341,,,,,,,,,,,,,,,,, -174700,0.23265032,1.4337263,,,,,,,,,,,,,,,,, -174785,,,0.6968789100646973,1.3602083921432495,35.25066549945658,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,62203.7920062542,106659.43720054626,62203.7920062542,44447.63086080551,2.6856791973114014,0.0 -174800,0.23055498,1.4164392,,,,,,,,,,,,,,,,, -174900,0.24350575,1.50203,,,,,,,,,,,,,,,,, -175000,0.22588134,1.4322444,,,,,,,,,,,,,,,,, -175100,0.24291259,1.4480902,,,,,,,,,,,,,,,,, -175200,0.2258071,1.4119208,,,,,,,,,,,,,,,,, -175300,0.24820118,1.458278,,,,,,,,,,,,,,,,, -175400,0.23418182,1.4189545,,,,,,,,,,,,,,,,, -175500,0.23512217,1.444463,,,,,,,,,,,,,,,,, -175600,0.22889148,1.3595586,,,,,,,,,,,,,,,,, -175700,0.25024888,1.4876457,,,,,,,,,,,,,,,,, -175800,0.22676717,1.3924037,,,,,,,,,,,,,,,,, -175900,0.23225333,1.446848,,,,,,,,,,,,,,,,, -176000,0.2326803,1.4783605,,,,,,,,,,,,,,,,, -176100,0.24983422,1.4648305,,,,,,,,,,,,,,,,, -176200,0.23885837,1.4723421,,,,,,,,,,,,,,,,, -176300,0.23822433,1.4297178,,,,,,,,,,,,,,,,, -176400,0.22346257,1.3647922,,,,,,,,,,,,,,,,, -176500,0.23131981,1.4702002,,,,,,,,,,,,,,,,, -176600,0.23836425,1.43635,,,,,,,,,,,,,,,,, -176700,0.23478006,1.3998936,,,,,,,,,,,,,,,,, -176800,0.22547476,1.3534316,,,,,,,,,,,,,,,,, -176900,0.24464214,1.3958187,,,,,,,,,,,,,,,,, -177000,0.24030712,1.4560974,,,,,,,,,,,,,,,,, -177100,0.23398435,1.4881452,,,,,,,,,,,,,,,,, -177146,,,0.6965775489807129,1.3650387525558472,35.620679245105535,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,63043.82459115982,108100.36066675186,63043.82459115982,45048.39179825783,2.743337154388428,0.0 -177200,0.23514059,1.3946677,,,,,,,,,,,,,,,,, -177300,0.23197925,1.4596598,,,,,,,,,,,,,,,,, -177400,0.24169774,1.4751178,,,,,,,,,,,,,,,,, -177500,0.23659205,1.3998024,,,,,,,,,,,,,,,,, -177600,0.24137597,1.4631028,,,,,,,,,,,,,,,,, -177700,0.22984305,1.4447517,,,,,,,,,,,,,,,,, -177800,0.22886516,1.4631037,,,,,,,,,,,,,,,,, -177900,0.23872991,1.4863919,,,,,,,,,,,,,,,,, -178000,0.24179737,1.4996935,,,,,,,,,,,,,,,,, -178100,0.24073343,1.4193457,,,,,,,,,,,,,,,,, -178200,0.24018277,1.5335186,,,,,,,,,,,,,,,,, -178300,0.2350356,1.4427841,,,,,,,,,,,,,,,,, -178400,0.22089675,1.4515015,,,,,,,,,,,,,,,,, -178500,0.2417795,1.4735657,,,,,,,,,,,,,,,,, -178600,0.22113052,1.3533878,,,,,,,,,,,,,,,,, -178700,0.23297985,1.4339224,,,,,,,,,,,,,,,,, -178800,0.2372208,1.4305732,,,,,,,,,,,,,,,,, -178900,0.24039488,1.5056913,,,,,,,,,,,,,,,,, -179000,0.23308513,1.4856781,,,,,,,,,,,,,,,,, -179100,0.23239584,1.4479997,,,,,,,,,,,,,,,,, -179200,0.24088307,1.4015114,,,,,,,,,,,,,,,,, -179300,0.23983504,1.4606974,,,,,,,,,,,,,,,,, -179400,0.240735,1.4295704,,,,,,,,,,,,,,,,, -179500,0.24557984,1.5015692,,,,,,,,,,,,,,,,, -179508,,,0.6962001919746399,1.3616358041763306,35.35056694507423,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,63884.05560970306,109547.62645864488,63884.05560970306,45655.30615091324,2.792304277420044,0.0 -179600,0.22959569,1.4218388,,,,,,,,,,,,,,,,, -179700,0.22748856,1.4392486,,,,,,,,,,,,,,,,, -179800,0.23955779,1.4298415,,,,,,,,,,,,,,,,, -179900,0.23357163,1.4741211,,,,,,,,,,,,,,,,, -180000,0.23029651,1.38702,,,,,,,,,,,,,,,,, -180100,0.22750373,1.4337946,,,,,,,,,,,,,,,,, -180200,0.23738018,1.4510567,,,,,,,,,,,,,,,,, -180300,0.23342158,1.4717829,,,,,,,,,,,,,,,,, -180400,0.23370554,1.4852611,,,,,,,,,,,,,,,,, -180500,0.23434195,1.3710835,,,,,,,,,,,,,,,,, -180600,0.26854673,1.4352375,,,,,,,,,,,,,,,,, -180700,0.23477764,1.4617771,,,,,,,,,,,,,,,,, -180800,0.24472287,1.4384487,,,,,,,,,,,,,,,,, -180900,0.24530086,1.4789526,,,,,,,,,,,,,,,,, -181000,0.23774587,1.444513,,,,,,,,,,,,,,,,, -181100,0.24079016,1.4268757,,,,,,,,,,,,,,,,, -181200,0.23534846,1.3739823,,,,,,,,,,,,,,,,, -181300,0.23139739,1.3731358,,,,,,,,,,,,,,,,, -181400,0.2323265,1.4468297,,,,,,,,,,,,,,,,, -181500,0.23385806,1.4521825,,,,,,,,,,,,,,,,, -181600,0.2287542,1.4604554,,,,,,,,,,,,,,,,, -181700,0.23375128,1.4170345,,,,,,,,,,,,,,,,, -181800,0.24009259,1.4751042,,,,,,,,,,,,,,,,, -181869,,,0.6943336129188538,1.380504846572876,35.76284631446216,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,64724.220274209976,110994.8827548027,64724.220274209976,46262.27606844902,2.840443849563598,0.0 -181900,0.24003892,1.4765943,,,,,,,,,,,,,,,,, -182000,0.23424362,1.4182817,,,,,,,,,,,,,,,,, -182100,0.23522757,1.422954,,,,,,,,,,,,,,,,, -182200,0.24652718,1.4708686,,,,,,,,,,,,,,,,, -182300,0.22850636,1.4085213,,,,,,,,,,,,,,,,, -182400,0.23525913,1.4799168,,,,,,,,,,,,,,,,, -182500,0.23118629,1.457831,,,,,,,,,,,,,,,,, -182600,0.24410343,1.475202,,,,,,,,,,,,,,,,, -182700,0.231229,1.3934791,,,,,,,,,,,,,,,,, -182800,0.23918444,1.4679332,,,,,,,,,,,,,,,,, -182900,0.23579961,1.424501,,,,,,,,,,,,,,,,, -183000,0.2343551,1.4398156,,,,,,,,,,,,,,,,, -183100,0.23886621,1.4103198,,,,,,,,,,,,,,,,, -183200,0.23769377,1.4144684,,,,,,,,,,,,,,,,, -183300,0.22225738,1.4322337,,,,,,,,,,,,,,,,, -183400,0.2281166,1.4327482,,,,,,,,,,,,,,,,, -183500,0.2385486,1.4612112,,,,,,,,,,,,,,,,, -183600,0.23204128,1.423599,,,,,,,,,,,,,,,,, -183700,0.22732103,1.4745761,,,,,,,,,,,,,,,,, -183800,0.23407112,1.4078716,,,,,,,,,,,,,,,,, -183900,0.2343379,1.3925431,,,,,,,,,,,,,,,,, -184000,0.22569928,1.4482816,,,,,,,,,,,,,,,,, -184100,0.23463744,1.4237036,,,,,,,,,,,,,,,,, -184200,0.24127209,1.4410945,,,,,,,,,,,,,,,,, -184230,,,0.6938932538032532,1.381688952445984,35.454537768174426,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,65564.20058608055,112437.20860123634,65564.20058608055,46864.49913692474,2.888885021209717,0.0 -184300,0.2255944,1.4027528,,,,,,,,,,,,,,,,, -184400,0.23003773,1.4104019,,,,,,,,,,,,,,,,, -184500,0.22363456,1.3917891,,,,,,,,,,,,,,,,, -184600,0.23220131,1.4697956,,,,,,,,,,,,,,,,, -184700,0.2417714,1.5403227,,,,,,,,,,,,,,,,, -184800,0.23111275,1.4269565,,,,,,,,,,,,,,,,, -184900,0.23773079,1.3995925,,,,,,,,,,,,,,,,, -185000,0.23171331,1.4203993,,,,,,,,,,,,,,,,, -185100,0.23008439,1.3754617,,,,,,,,,,,,,,,,, -185200,0.23845692,1.4918885,,,,,,,,,,,,,,,,, -185300,0.22104378,1.4430945,,,,,,,,,,,,,,,,, -185400,0.23751676,1.4410845,,,,,,,,,,,,,,,,, -185500,0.24149421,1.4505514,,,,,,,,,,,,,,,,, -185600,0.24988227,1.5047897,,,,,,,,,,,,,,,,, -185700,0.24788152,1.4389958,,,,,,,,,,,,,,,,, -185800,0.2313294,1.4019576,,,,,,,,,,,,,,,,, -185900,0.22733136,1.4220551,,,,,,,,,,,,,,,,, -186000,0.22206856,1.4109002,,,,,,,,,,,,,,,,, -186100,0.24457192,1.4537433,,,,,,,,,,,,,,,,, -186200,0.23023255,1.417816,,,,,,,,,,,,,,,,, -186300,0.23688848,1.4537696,,,,,,,,,,,,,,,,, -186400,0.2345523,1.4272782,,,,,,,,,,,,,,,,, -186500,0.23493218,1.4187822,,,,,,,,,,,,,,,,, -186591,,,0.6969566345214844,1.3639436960220337,35.72708821302105,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,66404.26647567749,113876.89898991583,66404.26647567749,47464.00195264816,2.9393835067749023,0.0 -186600,0.24218877,1.4304265,,,,,,,,,,,,,,,,, -186700,0.2361681,1.4215523,,,,,,,,,,,,,,,,, -186800,0.2259521,1.3750994,,,,,,,,,,,,,,,,, -186900,0.2277874,1.422498,,,,,,,,,,,,,,,,, -187000,0.22684775,1.4519147,,,,,,,,,,,,,,,,, -187100,0.23364927,1.4447078,,,,,,,,,,,,,,,,, -187200,0.230962,1.4375656,,,,,,,,,,,,,,,,, -187300,0.22622378,1.4725722,,,,,,,,,,,,,,,,, -187400,0.25114885,1.4642681,,,,,,,,,,,,,,,,, -187500,0.23994914,1.4780399,,,,,,,,,,,,,,,,, -187600,0.23130417,1.4018731,,,,,,,,,,,,,,,,, -187700,0.24854024,1.4222788,,,,,,,,,,,,,,,,, -187800,0.27555928,1.3877096,,,,,,,,,,,,,,,,, -187900,0.22441724,1.3944656,,,,,,,,,,,,,,,,, -188000,0.24041997,1.4803439,,,,,,,,,,,,,,,,, -188100,0.23089801,1.4343228,,,,,,,,,,,,,,,,, -188200,0.2299284,1.5267454,,,,,,,,,,,,,,,,, -188300,0.23564388,1.3932729,,,,,,,,,,,,,,,,, -188400,0.23272474,1.4139533,,,,,,,,,,,,,,,,, -188500,0.22993506,1.505305,,,,,,,,,,,,,,,,, -188600,0.24220052,1.4526938,,,,,,,,,,,,,,,,, -188700,0.23315041,1.376294,,,,,,,,,,,,,,,,, -188800,0.22959463,1.4108841,,,,,,,,,,,,,,,,, -188900,0.2266168,1.4267377,,,,,,,,,,,,,,,,, -188952,,,0.6942804455757141,1.3843278884887695,35.36090369464056,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,67244.17678833008,115329.57243132593,67244.17678833008,48076.644453287125,2.9872334003448486,0.0 -189000,0.23062931,1.4496315,,,,,,,,,,,,,,,,, -189100,0.22720574,1.4397119,,,,,,,,,,,,,,,,, -189200,0.24155658,1.4256138,,,,,,,,,,,,,,,,, -189300,0.23304498,1.4577397,,,,,,,,,,,,,,,,, -189400,0.24540451,1.4103826,,,,,,,,,,,,,,,,, -189500,0.23614816,1.464951,,,,,,,,,,,,,,,,, -189600,0.23405798,1.5040517,,,,,,,,,,,,,,,,, -189700,0.23362729,1.4035733,,,,,,,,,,,,,,,,, -189800,0.23523216,1.4100066,,,,,,,,,,,,,,,,, -189900,0.25100383,1.489952,,,,,,,,,,,,,,,,, -190000,0.23658809,1.3842334,,,,,,,,,,,,,,,,, -190100,0.22741543,1.4484038,,,,,,,,,,,,,,,,, -190200,0.22247262,1.4277653,,,,,,,,,,,,,,,,, -190300,0.23026933,1.4807227,,,,,,,,,,,,,,,,, -190400,0.22703247,1.4772182,,,,,,,,,,,,,,,,, -190500,0.24901596,1.501201,,,,,,,,,,,,,,,,, -190600,0.23574965,1.4830301,,,,,,,,,,,,,,,,, -190700,0.22880249,1.3481646,,,,,,,,,,,,,,,,, -190800,0.24111292,1.4719399,,,,,,,,,,,,,,,,, -190900,0.2258992,1.4111258,,,,,,,,,,,,,,,,, -191000,0.23823681,1.4639947,,,,,,,,,,,,,,,,, -191100,0.23311041,1.4735988,,,,,,,,,,,,,,,,, -191200,0.23354904,1.4422935,,,,,,,,,,,,,,,,, -191300,0.23571314,1.4131732,,,,,,,,,,,,,,,,, -191313,,,0.6944150328636169,1.3720909357070925,35.722617039369105,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,68084.2696146965,116766.5665898323,68084.2696146965,48673.42203640938,3.0373005867004395,0.0 -191400,0.23741797,1.395995,,,,,,,,,,,,,,,,, -191500,0.23396784,1.492919,,,,,,,,,,,,,,,,, -191600,0.23807521,1.3958426,,,,,,,,,,,,,,,,, -191700,0.23384014,1.4381409,,,,,,,,,,,,,,,,, -191800,0.22692761,1.4244827,,,,,,,,,,,,,,,,, -191900,0.23079729,1.4317251,,,,,,,,,,,,,,,,, -192000,0.23343061,1.4092739,,,,,,,,,,,,,,,,, -192100,0.23483166,1.4510727,,,,,,,,,,,,,,,,, -192200,0.23056331,1.4145304,,,,,,,,,,,,,,,,, -192300,0.22748889,1.4427278,,,,,,,,,,,,,,,,, -192400,0.23339906,1.4076643,,,,,,,,,,,,,,,,, -192500,0.23060514,1.3645701,,,,,,,,,,,,,,,,, -192600,0.22296984,1.4199157,,,,,,,,,,,,,,,,, -192700,0.23295829,1.4860415,,,,,,,,,,,,,,,,, -192800,0.22872807,1.4456341,,,,,,,,,,,,,,,,, -192900,0.23296697,1.411842,,,,,,,,,,,,,,,,, -193000,0.24201463,1.4438046,,,,,,,,,,,,,,,,, -193100,0.24867657,1.524319,,,,,,,,,,,,,,,,, -193200,0.24192624,1.4654577,,,,,,,,,,,,,,,,, -193300,0.23408224,1.4568613,,,,,,,,,,,,,,,,, -193400,0.23746249,1.4386064,,,,,,,,,,,,,,,,, -193500,0.23824632,1.4779565,,,,,,,,,,,,,,,,, -193600,0.2243903,1.3458133,,,,,,,,,,,,,,,,, -193674,,,0.6968308687210083,1.3575750589370728,35.36122057156541,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,68924.22136425972,118214.84686422348,68924.22136425972,49281.62001180649,3.0965020656585693,0.0 -193700,0.22780691,1.4283886,,,,,,,,,,,,,,,,, -193800,0.22854881,1.4610314,,,,,,,,,,,,,,,,, -193900,0.23817536,1.4062228,,,,,,,,,,,,,,,,, -194000,0.23384097,1.46642,,,,,,,,,,,,,,,,, -194100,0.23907207,1.5027202,,,,,,,,,,,,,,,,, -194200,0.23383512,1.4673964,,,,,,,,,,,,,,,,, -194300,0.23549838,1.389231,,,,,,,,,,,,,,,,, -194400,0.2393206,1.4767954,,,,,,,,,,,,,,,,, -194500,0.24062431,1.3861194,,,,,,,,,,,,,,,,, -194600,0.23570639,1.4201002,,,,,,,,,,,,,,,,, -194700,0.23514946,1.3597375,,,,,,,,,,,,,,,,, -194800,0.2409083,1.4697654,,,,,,,,,,,,,,,,, -194900,0.23730057,1.4887233,,,,,,,,,,,,,,,,, -195000,0.23056215,1.4886767,,,,,,,,,,,,,,,,, -195100,0.23009327,1.441791,,,,,,,,,,,,,,,,, -195200,0.2304866,1.4165568,,,,,,,,,,,,,,,,, -195300,0.23707457,1.4761912,,,,,,,,,,,,,,,,, -195400,0.2475825,1.450971,,,,,,,,,,,,,,,,, -195500,0.2348376,1.4193268,,,,,,,,,,,,,,,,, -195600,0.2338976,1.4773088,,,,,,,,,,,,,,,,, -195700,0.23932827,1.3828695,,,,,,,,,,,,,,,,, -195800,0.22806142,1.4165543,,,,,,,,,,,,,,,,, -195900,0.24638899,1.3811643,,,,,,,,,,,,,,,,, -196000,0.2303065,1.3672031,,,,,,,,,,,,,,,,, -196035,,,0.6957708597183228,1.3637819290161133,35.55747837773363,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,69764.31381583214,119671.69554686546,69764.31381583214,49898.25337815285,3.146656990051269,0.0 -196100,0.23148768,1.5096027,,,,,,,,,,,,,,,,, -196200,0.22785805,1.4619176,,,,,,,,,,,,,,,,, -196300,0.23456281,1.4106424,,,,,,,,,,,,,,,,, -196400,0.23408191,1.4472973,,,,,,,,,,,,,,,,, -196500,0.23645243,1.4905403,,,,,,,,,,,,,,,,, -196600,0.24043907,1.4481922,,,,,,,,,,,,,,,,, -196700,0.23910008,1.4185404,,,,,,,,,,,,,,,,, -196800,0.25000268,1.4478252,,,,,,,,,,,,,,,,, -196900,0.24129592,1.476827,,,,,,,,,,,,,,,,, -197000,0.22990488,1.441065,,,,,,,,,,,,,,,,, -197100,0.23479444,1.4031432,,,,,,,,,,,,,,,,, -197200,0.23470862,1.4830673,,,,,,,,,,,,,,,,, -197300,0.23873734,1.4245899,,,,,,,,,,,,,,,,, -197400,0.2324153,1.433018,,,,,,,,,,,,,,,,, -197500,0.24426822,1.4522338,,,,,,,,,,,,,,,,, -197600,0.23021251,1.4118229,,,,,,,,,,,,,,,,, -197700,0.22964782,1.4006218,,,,,,,,,,,,,,,,, -197800,0.22963855,1.4135473,,,,,,,,,,,,,,,,, -197900,0.22529453,1.4381485,,,,,,,,,,,,,,,,, -198000,0.23198394,1.4596524,,,,,,,,,,,,,,,,, -198100,0.23651855,1.4638782,,,,,,,,,,,,,,,,, -198200,0.23813432,1.4885685,,,,,,,,,,,,,,,,, -198300,0.23915158,1.4072846,,,,,,,,,,,,,,,,, -198396,,,0.6959868669509888,1.3676646947860718,35.303194536835115,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,70604.34129023552,121115.29785180092,70604.34129023552,50501.69980740547,3.2037465572357178,0.0 -198400,0.22829902,1.3921349,,,,,,,,,,,,,,,,, -198500,0.24912417,1.4920175,,,,,,,,,,,,,,,,, -198600,0.2331819,1.4670478,,,,,,,,,,,,,,,,, -198700,0.23035842,1.4266428,,,,,,,,,,,,,,,,, -198800,0.23756868,1.4572073,,,,,,,,,,,,,,,,, -198900,0.2538877,1.4672937,,,,,,,,,,,,,,,,, -199000,0.23426665,1.4505148,,,,,,,,,,,,,,,,, -199100,0.23888576,1.5342381,,,,,,,,,,,,,,,,, -199200,0.22990534,1.4108608,,,,,,,,,,,,,,,,, -199300,0.24860132,1.410046,,,,,,,,,,,,,,,,, -199400,0.23107415,1.4439013,,,,,,,,,,,,,,,,, -199500,0.23172002,1.442891,,,,,,,,,,,,,,,,, -199600,0.23322871,1.4444504,,,,,,,,,,,,,,,,, -199700,0.24139348,1.4805957,,,,,,,,,,,,,,,,, -199800,0.22809875,1.4024215,,,,,,,,,,,,,,,,, -199900,0.22968768,1.4583094,,,,,,,,,,,,,,,,, -200000,0.23204795,1.3872207,,,,,,,,,,,,,,,,, -200100,0.24167755,1.4890505,,,,,,,,,,,,,,,,, -200200,0.23989351,1.4042485,,,,,,,,,,,,,,,,, -200300,0.236759,1.500211,,,,,,,,,,,,,,,,, -200400,0.22527659,1.4136877,,,,,,,,,,,,,,,,, -200500,0.2352291,1.4787297,,,,,,,,,,,,,,,,, -200600,0.23340915,1.4643981,,,,,,,,,,,,,,,,, -200700,0.2358442,1.4644839,,,,,,,,,,,,,,,,, -200758,,,0.6991065740585327,1.352362036705017,35.527593314348735,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,71444.29108786583,122557.77553725244,71444.29108786583,51104.10656738281,3.2538113594055176,0.0 -200800,0.23810983,1.4438688,,,,,,,,,,,,,,,,, -200900,0.22922634,1.395809,,,,,,,,,,,,,,,,, -201000,0.23519678,1.4830614,,,,,,,,,,,,,,,,, -201100,0.23544937,1.4561449,,,,,,,,,,,,,,,,, -201200,0.22845171,1.440433,,,,,,,,,,,,,,,,, -201300,0.23981853,1.5254735,,,,,,,,,,,,,,,,, -201400,0.23640844,1.4994645,,,,,,,,,,,,,,,,, -201500,0.24309596,1.4492042,,,,,,,,,,,,,,,,, -201600,0.22925463,1.3736354,,,,,,,,,,,,,,,,, -201700,0.22951612,1.3925725,,,,,,,,,,,,,,,,, -201800,0.23433897,1.464012,,,,,,,,,,,,,,,,, -201900,0.22589363,1.460584,,,,,,,,,,,,,,,,, -202000,0.23078051,1.4306102,,,,,,,,,,,,,,,,, -202100,0.23539412,1.4991968,,,,,,,,,,,,,,,,, -202200,0.22707742,1.4836987,,,,,,,,,,,,,,,,, -202300,0.23091727,1.3723916,,,,,,,,,,,,,,,,, -202400,0.2318559,1.4536351,,,,,,,,,,,,,,,,, -202500,0.24338976,1.4934008,,,,,,,,,,,,,,,,, -202600,0.24173252,1.4333997,,,,,,,,,,,,,,,,, -202700,0.23569664,1.4656833,,,,,,,,,,,,,,,,, -202800,0.22613458,1.4518449,,,,,,,,,,,,,,,,, -202900,0.23630203,1.4982204,,,,,,,,,,,,,,,,, -203000,0.23448953,1.4845592,,,,,,,,,,,,,,,,, -203100,0.2366466,1.3813219,,,,,,,,,,,,,,,,, -203119,,,0.6929601430892944,1.3850677013397217,35.516672633384594,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,72284.34855413437,123991.47338604929,72284.34855413437,51697.62382078171,3.304860591888428,0.0 -203200,0.22999105,1.4069179,,,,,,,,,,,,,,,,, -203300,0.24062039,1.5146759,,,,,,,,,,,,,,,,, -203400,0.22959968,1.414034,,,,,,,,,,,,,,,,, -203500,0.24191174,1.5122122,,,,,,,,,,,,,,,,, -203600,0.23506738,1.4272738,,,,,,,,,,,,,,,,, -203700,0.22791566,1.3889511,,,,,,,,,,,,,,,,, -203800,0.22741137,1.404022,,,,,,,,,,,,,,,,, -203900,0.2307871,1.4333506,,,,,,,,,,,,,,,,, -204000,0.234019,1.4220477,,,,,,,,,,,,,,,,, -204100,0.23432904,1.5150985,,,,,,,,,,,,,,,,, -204200,0.23086473,1.4603496,,,,,,,,,,,,,,,,, -204300,0.22603509,1.4183644,,,,,,,,,,,,,,,,, -204400,0.2282592,1.4996607,,,,,,,,,,,,,,,,, -204500,0.24166113,1.4064065,,,,,,,,,,,,,,,,, -204600,0.22570057,1.3941412,,,,,,,,,,,,,,,,, -204700,0.23948984,1.4533979,,,,,,,,,,,,,,,,, -204800,0.23084232,1.4230419,,,,,,,,,,,,,,,,, -204900,0.23469616,1.4792938,,,,,,,,,,,,,,,,, -205000,0.2416927,1.4379663,,,,,,,,,,,,,,,,, -205100,0.23067851,1.4826442,,,,,,,,,,,,,,,,, -205200,0.23269033,1.3794571,,,,,,,,,,,,,,,,, -205300,0.23731817,1.422807,,,,,,,,,,,,,,,,, -205400,0.22137578,1.4035453,,,,,,,,,,,,,,,,, -205480,,,0.6939802765846252,1.375113010406494,35.385792217309884,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,73124.40326523781,125436.77930998802,73124.40326523781,52302.75385856629,3.355266809463501,0.0 -205500,0.22958495,1.365436,,,,,,,,,,,,,,,,, -205600,0.22862987,1.4590379,,,,,,,,,,,,,,,,, -205700,0.2249802,1.4061319,,,,,,,,,,,,,,,,, -205800,0.23493223,1.423831,,,,,,,,,,,,,,,,, -205900,0.23392078,1.4581286,,,,,,,,,,,,,,,,, -206000,0.23481728,1.4426215,,,,,,,,,,,,,,,,, -206100,0.23585321,1.4597327,,,,,,,,,,,,,,,,, -206200,0.24014685,1.4395045,,,,,,,,,,,,,,,,, -206300,0.24580516,1.5053854,,,,,,,,,,,,,,,,, -206400,0.22316647,1.4401169,,,,,,,,,,,,,,,,, -206500,0.23893037,1.4576706,,,,,,,,,,,,,,,,, -206600,0.2250714,1.3344321,,,,,,,,,,,,,,,,, -206700,0.24229312,1.61172,,,,,,,,,,,,,,,,, -206800,0.23124409,1.4991442,,,,,,,,,,,,,,,,, -206900,0.22724886,1.4408027,,,,,,,,,,,,,,,,, -207000,0.2295841,1.3718338,,,,,,,,,,,,,,,,, -207100,0.23077804,1.4231178,,,,,,,,,,,,,,,,, -207200,0.23451057,1.418887,,,,,,,,,,,,,,,,, -207300,0.23139682,1.406026,,,,,,,,,,,,,,,,, -207400,0.2336762,1.4761567,,,,,,,,,,,,,,,,, -207500,0.23756175,1.4569345,,,,,,,,,,,,,,,,, -207600,0.24056786,1.4843416,,,,,,,,,,,,,,,,, -207700,0.22935022,1.4802153,,,,,,,,,,,,,,,,, -207800,0.23698296,1.4237937,,,,,,,,,,,,,,,,, -207841,,,0.697274386882782,1.3620195388793943,35.38380765178742,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,73964.36746478081,126894.86238741876,73964.36746478081,52920.74584150314,3.409360647201538,0.0 -207900,0.2486549,1.5117838,,,,,,,,,,,,,,,,, -208000,0.23461588,1.4297336,,,,,,,,,,,,,,,,, -208100,0.22642334,1.3853927,,,,,,,,,,,,,,,,, -208200,0.23005097,1.4095613,,,,,,,,,,,,,,,,, -208300,0.2316791,1.4542204,,,,,,,,,,,,,,,,, -208400,0.23779948,1.4420984,,,,,,,,,,,,,,,,, -208500,0.23504092,1.4901152,,,,,,,,,,,,,,,,, -208600,0.23196827,1.4539768,,,,,,,,,,,,,,,,, -208700,0.23551126,1.4626402,,,,,,,,,,,,,,,,, -208800,0.22171919,1.416148,,,,,,,,,,,,,,,,, -208900,0.23024769,1.3817506,,,,,,,,,,,,,,,,, -209000,0.23674011,1.529468,,,,,,,,,,,,,,,,, -209100,0.23016731,1.409391,,,,,,,,,,,,,,,,, -209200,0.23242803,1.4151725,,,,,,,,,,,,,,,,, -209300,0.24298042,1.5025216,,,,,,,,,,,,,,,,, -209400,0.22986045,1.4911144,,,,,,,,,,,,,,,,, -209500,0.22858281,1.4202242,,,,,,,,,,,,,,,,, -209600,0.2327627,1.4669584,,,,,,,,,,,,,,,,, -209700,0.24313815,1.5382462,,,,,,,,,,,,,,,,, -209800,0.22513,1.3977685,,,,,,,,,,,,,,,,, -209900,0.22688659,1.4119228,,,,,,,,,,,,,,,,, -210000,0.22820047,1.4578089,,,,,,,,,,,,,,,,, -210100,0.22865371,1.4656626,,,,,,,,,,,,,,,,, -210200,0.22233348,1.391779,,,,,,,,,,,,,,,,, -210201,,,0.6986714601516724,1.350701928138733,35.170028935962286,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,74804.41276717186,128342.26640844344,74804.41276717186,53527.97238135338,3.470571994781494,0.0 -210300,0.23334613,1.4571346,,,,,,,,,,,,,,,,, -210400,0.23162036,1.4110097,,,,,,,,,,,,,,,,, -210500,0.2430535,1.4766914,,,,,,,,,,,,,,,,, -210600,0.23480292,1.4311014,,,,,,,,,,,,,,,,, -210700,0.22639033,1.3819019,,,,,,,,,,,,,,,,, -210800,0.22876437,1.430964,,,,,,,,,,,,,,,,, -210900,0.23969796,1.4015521,,,,,,,,,,,,,,,,, -211000,0.22973457,1.4268126,,,,,,,,,,,,,,,,, -211100,0.23644502,1.409221,,,,,,,,,,,,,,,,, -211200,0.23139007,1.462119,,,,,,,,,,,,,,,,, -211300,0.22809243,1.4227078,,,,,,,,,,,,,,,,, -211400,0.22548337,1.4354692,,,,,,,,,,,,,,,,, -211500,0.2437212,1.4980439,,,,,,,,,,,,,,,,, -211600,0.23628925,1.3935497,,,,,,,,,,,,,,,,, -211700,0.23551004,1.4983343,,,,,,,,,,,,,,,,, -211800,0.23917633,1.4523904,,,,,,,,,,,,,,,,, -211900,0.23937187,1.4424121,,,,,,,,,,,,,,,,, -212000,0.24487513,1.4473839,,,,,,,,,,,,,,,,, -212100,0.22843352,1.4842948,,,,,,,,,,,,,,,,, -212200,0.22822751,1.457789,,,,,,,,,,,,,,,,, -212300,0.23387167,1.5133936,,,,,,,,,,,,,,,,, -212400,0.23256533,1.4686594,,,,,,,,,,,,,,,,, -212500,0.22442321,1.4191333,,,,,,,,,,,,,,,,, -212562,,,0.699073314666748,1.3556746244430542,35.37286067844723,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,75644.38429164886,129784.45477199554,75644.38429164886,54130.06509780884,3.5222976207733154,0.0 -212600,0.22957486,1.4356236,,,,,,,,,,,,,,,,, -212700,0.23540975,1.4180344,,,,,,,,,,,,,,,,, -212800,0.23219794,1.4452488,,,,,,,,,,,,,,,,, -212900,0.23118679,1.3875054,,,,,,,,,,,,,,,,, -213000,0.22772077,1.4006389,,,,,,,,,,,,,,,,, -213100,0.23201689,1.4615132,,,,,,,,,,,,,,,,, -213200,0.23544939,1.4528826,,,,,,,,,,,,,,,,, -213300,0.22949831,1.4702612,,,,,,,,,,,,,,,,, -213400,0.23158143,1.4104351,,,,,,,,,,,,,,,,, -213500,0.23633532,1.4827427,,,,,,,,,,,,,,,,, -213600,0.24583508,1.4405919,,,,,,,,,,,,,,,,, -213700,0.24114566,1.404161,,,,,,,,,,,,,,,,, -213800,0.2343709,1.5815762,,,,,,,,,,,,,,,,, -213900,0.24524167,1.4583782,,,,,,,,,,,,,,,,, -214000,0.2325585,1.4268005,,,,,,,,,,,,,,,,, -214100,0.23889107,1.440891,,,,,,,,,,,,,,,,, -214200,0.23199369,1.45427,,,,,,,,,,,,,,,,, -214300,0.22391975,1.4298573,,,,,,,,,,,,,,,,, -214400,0.24258615,1.4424298,,,,,,,,,,,,,,,,, -214500,0.22866307,1.468837,,,,,,,,,,,,,,,,, -214600,0.22709896,1.4950126,,,,,,,,,,,,,,,,, -214700,0.23770891,1.4195769,,,,,,,,,,,,,,,,, -214800,0.23276633,1.4368513,,,,,,,,,,,,,,,,, -214900,0.22752172,1.4079521,,,,,,,,,,,,,,,,, -214923,,,0.6939623951911926,1.3788866996765137,35.34666481423096,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,76484.49631166458,131240.6146156788,76484.49631166458,54745.98873233795,3.574280500411988,0.0 -215000,0.23334453,1.4306065,,,,,,,,,,,,,,,,, -215100,0.24140872,1.4533588,,,,,,,,,,,,,,,,, -215200,0.23593439,1.424506,,,,,,,,,,,,,,,,, -215300,0.2351276,1.430715,,,,,,,,,,,,,,,,, -215400,0.2377888,1.4230279,,,,,,,,,,,,,,,,, -215500,0.2193585,1.39516,,,,,,,,,,,,,,,,, -215600,0.23121008,1.3986887,,,,,,,,,,,,,,,,, -215700,0.29304886,1.4220097,,,,,,,,,,,,,,,,, -215800,0.24090418,1.4777797,,,,,,,,,,,,,,,,, -215900,0.23168184,1.4272234,,,,,,,,,,,,,,,,, -216000,0.23316374,1.4916352,,,,,,,,,,,,,,,,, -216100,0.23356003,1.4127281,,,,,,,,,,,,,,,,, -216200,0.24219708,1.4807364,,,,,,,,,,,,,,,,, -216300,0.23090313,1.4638001,,,,,,,,,,,,,,,,, -216400,0.23591536,1.4412553,,,,,,,,,,,,,,,,, -216500,0.22976999,1.41334,,,,,,,,,,,,,,,,, -216600,0.2278619,1.473458,,,,,,,,,,,,,,,,, -216700,0.23208086,1.4340414,,,,,,,,,,,,,,,,, -216800,0.2459832,1.4802351,,,,,,,,,,,,,,,,, -216900,0.25056922,1.412727,,,,,,,,,,,,,,,,, -217000,0.23491834,1.3909407,,,,,,,,,,,,,,,,, -217100,0.23859225,1.4458907,,,,,,,,,,,,,,,,, -217200,0.24165682,1.461662,,,,,,,,,,,,,,,,, -217284,,,0.6991239786148071,1.345736384391785,35.3880923790759,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,77324.43882918358,132694.75473570824,77324.43882918358,55360.05347490311,3.6353468894958496,0.0 -217300,0.23152679,1.4196216,,,,,,,,,,,,,,,,, -217400,0.23524639,1.4852773,,,,,,,,,,,,,,,,, -217500,0.24496469,1.5107638,,,,,,,,,,,,,,,,, -217600,0.2403349,1.4805903,,,,,,,,,,,,,,,,, -217700,0.22710866,1.3248984,,,,,,,,,,,,,,,,, -217800,0.24903114,1.5118774,,,,,,,,,,,,,,,,, -217900,0.22670177,1.4847724,,,,,,,,,,,,,,,,, -218000,0.23777758,1.4690224,,,,,,,,,,,,,,,,, -218100,0.23056234,1.4241116,,,,,,,,,,,,,,,,, -218200,0.22865954,1.4403185,,,,,,,,,,,,,,,,, -218300,0.22846733,1.4504662,,,,,,,,,,,,,,,,, -218400,0.2384275,1.4234271,,,,,,,,,,,,,,,,, -218500,0.25210217,1.4488004,,,,,,,,,,,,,,,,, -218600,0.23570365,1.4668494,,,,,,,,,,,,,,,,, -218700,0.23480566,1.4249316,,,,,,,,,,,,,,,,, -218800,0.23739114,1.4614917,,,,,,,,,,,,,,,,, -218900,0.23261447,1.3793513,,,,,,,,,,,,,,,,, -219000,0.23732048,1.4156543,,,,,,,,,,,,,,,,, -219100,0.23089913,1.3920492,,,,,,,,,,,,,,,,, -219200,0.22917287,1.4108714,,,,,,,,,,,,,,,,, -219300,0.23968996,1.3748084,,,,,,,,,,,,,,,,, -219400,0.22772543,1.3944067,,,,,,,,,,,,,,,,, -219500,0.23254834,1.4710753,,,,,,,,,,,,,,,,, -219600,0.2424344,1.4128942,,,,,,,,,,,,,,,,, -219645,,,0.6964263319969177,1.3674417734146118,35.6649330863915,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,78164.45926618576,134139.69259738922,78164.45926618576,55964.84424567223,3.688246011734009,0.0 -219700,0.23664393,1.4063635,,,,,,,,,,,,,,,,, -219800,0.22920817,1.4476731,,,,,,,,,,,,,,,,, -219900,0.24169834,1.4788817,,,,,,,,,,,,,,,,, -220000,0.23917091,1.4587053,,,,,,,,,,,,,,,,, -220100,0.23493028,1.4348768,,,,,,,,,,,,,,,,, -220200,0.22788836,1.4814469,,,,,,,,,,,,,,,,, -220300,0.2608722,1.563593,,,,,,,,,,,,,,,,, -220400,0.22402565,1.3920976,,,,,,,,,,,,,,,,, -220500,0.23234944,1.4723718,,,,,,,,,,,,,,,,, -220600,0.22780536,1.4207286,,,,,,,,,,,,,,,,, -220700,0.24332073,1.402858,,,,,,,,,,,,,,,,, -220800,0.22929394,1.4080136,,,,,,,,,,,,,,,,, -220900,0.23499703,1.4673663,,,,,,,,,,,,,,,,, -221000,0.22087413,1.437294,,,,,,,,,,,,,,,,, -221100,0.23001447,1.4649655,,,,,,,,,,,,,,,,, -221200,0.3897489,1.4790665,,,,,,,,,,,,,,,,, -221300,0.23226777,1.4422903,,,,,,,,,,,,,,,,, -221400,0.23936762,1.4792477,,,,,,,,,,,,,,,,, -221500,0.23402555,1.4843235,,,,,,,,,,,,,,,,, -221600,0.22867674,1.3418225,,,,,,,,,,,,,,,,, -221700,0.22663923,1.3496912,,,,,,,,,,,,,,,,, -221800,0.23990026,1.5060185,,,,,,,,,,,,,,,,, -221900,0.23615496,1.4533316,,,,,,,,,,,,,,,,, -222000,0.22953708,1.4290472,,,,,,,,,,,,,,,,, -222006,,,0.6972445249557495,1.3616267442703247,35.77868590202388,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,79004.4024219513,135596.00564146042,79004.4024219513,56581.09203457832,3.739297866821289,0.0 -222100,0.24186137,1.4708856,,,,,,,,,,,,,,,,, -222200,0.23783103,1.3831927,,,,,,,,,,,,,,,,, -222300,0.23105866,1.4189256,,,,,,,,,,,,,,,,, -222400,0.24655332,1.421772,,,,,,,,,,,,,,,,, -222500,0.22403006,1.3739983,,,,,,,,,,,,,,,,, -222600,0.23821859,1.457111,,,,,,,,,,,,,,,,, -222700,0.24287228,1.4129754,,,,,,,,,,,,,,,,, -222800,0.22985674,1.4145521,,,,,,,,,,,,,,,,, -222900,0.24399552,1.5113791,,,,,,,,,,,,,,,,, -223000,0.23256135,1.3973774,,,,,,,,,,,,,,,,, -223100,0.23368824,1.4366477,,,,,,,,,,,,,,,,, -223200,0.2498726,1.3745803,,,,,,,,,,,,,,,,, -223300,0.23282845,1.4575372,,,,,,,,,,,,,,,,, -223400,0.23050395,1.4248884,,,,,,,,,,,,,,,,, -223500,0.23096277,1.4692785,,,,,,,,,,,,,,,,, -223600,0.23361823,1.3911699,,,,,,,,,,,,,,,,, -223700,0.24275537,1.4782063,,,,,,,,,,,,,,,,, -223800,0.23385645,1.4244834,,,,,,,,,,,,,,,,, -223900,0.24087624,1.4191107,,,,,,,,,,,,,,,,, -224000,0.23592575,1.3874387,,,,,,,,,,,,,,,,, -224100,0.23126034,1.3759379,,,,,,,,,,,,,,,,, -224200,0.23490928,1.4831353,,,,,,,,,,,,,,,,, -224300,0.2393013,1.4801278,,,,,,,,,,,,,,,,, -224367,,,0.6952289342880249,1.3711453676223757,35.30621007605715,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,79844.48212599754,137049.9975218773,79844.48212599754,57194.87903165817,3.7929728031158447,0.0 -224400,0.23992829,1.4556807,,,,,,,,,,,,,,,,, -224500,0.23183964,1.385579,,,,,,,,,,,,,,,,, -224600,0.23739016,1.4912106,,,,,,,,,,,,,,,,, -224700,0.22784837,1.3878397,,,,,,,,,,,,,,,,, -224800,0.23451419,1.4798809,,,,,,,,,,,,,,,,, -224900,0.23165324,1.3847662,,,,,,,,,,,,,,,,, -225000,0.22970656,1.4479303,,,,,,,,,,,,,,,,, -225100,0.2360133,1.4584008,,,,,,,,,,,,,,,,, -225200,0.23260584,1.4808995,,,,,,,,,,,,,,,,, -225300,0.23206235,1.5016555,,,,,,,,,,,,,,,,, -225400,0.22645475,1.4908036,,,,,,,,,,,,,,,,, -225500,0.23000988,1.4830931,,,,,,,,,,,,,,,,, -225600,0.23125778,1.4642677,,,,,,,,,,,,,,,,, -225700,0.230211,1.4520065,,,,,,,,,,,,,,,,, -225800,0.230869,1.4806856,,,,,,,,,,,,,,,,, -225900,0.23728077,1.459201,,,,,,,,,,,,,,,,, -226000,0.23556946,1.4590597,,,,,,,,,,,,,,,,, -226100,0.24537304,1.5456059,,,,,,,,,,,,,,,,, -226200,0.23554963,1.376447,,,,,,,,,,,,,,,,, -226300,0.2282359,1.45421,,,,,,,,,,,,,,,,, -226400,0.23946229,1.3958063,,,,,,,,,,,,,,,,, -226500,0.23605365,1.5244133,,,,,,,,,,,,,,,,, -226600,0.23114179,1.4419806,,,,,,,,,,,,,,,,, -226700,0.22862808,1.3688861,,,,,,,,,,,,,,,,, -226727,,,0.6972992420196533,1.3629869222640991,35.395366730208174,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,80684.43466353416,138493.86475038528,80684.43466353416,57798.66555047035,3.8469254970550537,0.0 -226800,0.2391657,1.4154648,,,,,,,,,,,,,,,,, -226900,0.24788655,1.5401618,,,,,,,,,,,,,,,,, -227000,0.2383752,1.4180839,,,,,,,,,,,,,,,,, -227100,0.2407218,1.4215604,,,,,,,,,,,,,,,,, -227200,0.23752657,1.413902,,,,,,,,,,,,,,,,, -227300,0.22046782,1.3736657,,,,,,,,,,,,,,,,, -227400,0.23767416,1.4746385,,,,,,,,,,,,,,,,, -227500,0.23800568,1.445739,,,,,,,,,,,,,,,,, -227600,0.22920959,1.4730752,,,,,,,,,,,,,,,,, -227700,0.22829331,1.4438254,,,,,,,,,,,,,,,,, -227800,0.23372212,1.4204347,,,,,,,,,,,,,,,,, -227900,0.23951253,1.4347206,,,,,,,,,,,,,,,,, -228000,0.2368885,1.3893068,,,,,,,,,,,,,,,,, -228100,0.60918033,1.54777,,,,,,,,,,,,,,,,, -228200,0.24755086,1.459109,,,,,,,,,,,,,,,,, -228300,0.23502512,1.4158021,,,,,,,,,,,,,,,,, -228400,0.23359518,1.4120793,,,,,,,,,,,,,,,,, -228500,0.22730394,1.3881838,,,,,,,,,,,,,,,,, -228600,0.23855421,1.394188,,,,,,,,,,,,,,,,, -228700,0.23397742,1.4540523,,,,,,,,,,,,,,,,, -228800,0.23388869,1.3903598,,,,,,,,,,,,,,,,, -228900,0.22776563,1.399101,,,,,,,,,,,,,,,,, -229000,0.2385504,1.4738663,,,,,,,,,,,,,,,,, -229088,,,0.6930255889892578,1.3843538761138916,35.532078761547865,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,81524.6353867054,139926.31528425217,81524.6353867054,58390.79001760483,3.900733709335327,0.0 -229100,0.23331338,1.4217597,,,,,,,,,,,,,,,,, -229200,0.22688557,1.4258586,,,,,,,,,,,,,,,,, -229300,0.23189527,1.3859708,,,,,,,,,,,,,,,,, -229400,0.242593,1.4931674,,,,,,,,,,,,,,,,, -229500,0.23908818,1.4058789,,,,,,,,,,,,,,,,, -229600,0.23350978,1.4891527,,,,,,,,,,,,,,,,, -229700,0.23133929,1.463737,,,,,,,,,,,,,,,,, -229800,0.23630384,1.4893287,,,,,,,,,,,,,,,,, -229900,0.23473321,1.5027332,,,,,,,,,,,,,,,,, -230000,0.22679321,1.4378313,,,,,,,,,,,,,,,,, -230100,0.2409244,1.4894841,,,,,,,,,,,,,,,,, -230200,0.23733662,1.4881463,,,,,,,,,,,,,,,,, -230300,0.2288715,1.4570785,,,,,,,,,,,,,,,,, -230400,0.23895977,1.4422271,,,,,,,,,,,,,,,,, -230500,0.23159881,1.4451488,,,,,,,,,,,,,,,,, -230600,0.24202968,1.4530299,,,,,,,,,,,,,,,,, -230700,0.24422272,1.4360448,,,,,,,,,,,,,,,,, -230800,0.23981328,1.4822807,,,,,,,,,,,,,,,,, -230900,0.23518407,1.5265254,,,,,,,,,,,,,,,,, -231000,0.22682212,1.4116788,,,,,,,,,,,,,,,,, -231100,0.23378006,1.5208826,,,,,,,,,,,,,,,,, -231200,0.24443762,1.4256905,,,,,,,,,,,,,,,,, -231300,0.22795868,1.3932528,,,,,,,,,,,,,,,,, -231400,0.2446822,1.4963568,,,,,,,,,,,,,,,,, -231448,,,0.697636067867279,1.3580584526062012,35.64368150155568,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,82364.74400091171,141381.4314494133,82364.74400091171,59005.66600751877,3.956840753555298,0.0 -231500,0.2232237,1.3681165,,,,,,,,,,,,,,,,, -231600,0.22910686,1.4186426,,,,,,,,,,,,,,,,, -231700,0.23810336,1.4665542,,,,,,,,,,,,,,,,, -231800,0.23370701,1.3387771,,,,,,,,,,,,,,,,, -231900,0.2323878,1.4134192,,,,,,,,,,,,,,,,, -232000,0.24288318,1.5185285,,,,,,,,,,,,,,,,, -232100,0.24240573,1.4199557,,,,,,,,,,,,,,,,, -232200,0.22577573,1.3910166,,,,,,,,,,,,,,,,, -232300,0.23199303,1.4388672,,,,,,,,,,,,,,,,, -232400,0.23187844,1.4235426,,,,,,,,,,,,,,,,, -232500,0.24412295,1.4862063,,,,,,,,,,,,,,,,, -232600,0.22441474,1.4063383,,,,,,,,,,,,,,,,, -232700,0.23449863,1.4257259,,,,,,,,,,,,,,,,, -232800,0.24407157,1.4936703,,,,,,,,,,,,,,,,, -232900,0.2391806,1.4283401,,,,,,,,,,,,,,,,, -233000,0.22768785,1.4087384,,,,,,,,,,,,,,,,, -233100,0.22917055,1.4197553,,,,,,,,,,,,,,,,, -233200,0.2293994,1.4478258,,,,,,,,,,,,,,,,, -233300,0.24361351,1.4856719,,,,,,,,,,,,,,,,, -233400,0.6242251,1.4764584,,,,,,,,,,,,,,,,, -233500,0.24078137,1.4346164,,,,,,,,,,,,,,,,, -233600,0.24529631,1.4703608,,,,,,,,,,,,,,,,, -233700,0.2376649,1.4883438,,,,,,,,,,,,,,,,, -233800,0.24027485,1.462633,,,,,,,,,,,,,,,,, -233808,,,0.6989806294441223,1.349443793296814,35.61818829777432,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,83204.71192860603,142820.386374712,83204.71192860603,59604.526320934296,4.011416673660278,0.0 -233900,0.24493854,1.3981689,,,,,,,,,,,,,,,,, -234000,0.23373087,1.4587971,,,,,,,,,,,,,,,,, -234100,0.23430203,1.4744486,,,,,,,,,,,,,,,,, -234200,0.24013029,1.4848262,,,,,,,,,,,,,,,,, -234300,0.23282129,1.456734,,,,,,,,,,,,,,,,, -234400,0.23876889,1.447566,,,,,,,,,,,,,,,,, -234500,0.22668733,1.4269516,,,,,,,,,,,,,,,,, -234600,0.22838119,1.4175589,,,,,,,,,,,,,,,,, -234700,0.22131898,1.4389759,,,,,,,,,,,,,,,,, -234800,0.23507097,1.4391652,,,,,,,,,,,,,,,,, -234900,0.22325537,1.4019324,,,,,,,,,,,,,,,,, -235000,0.23821747,1.4675539,,,,,,,,,,,,,,,,, -235100,0.23467258,1.4354167,,,,,,,,,,,,,,,,, -235200,0.24225262,1.4524698,,,,,,,,,,,,,,,,, -235300,0.23807248,1.4774953,,,,,,,,,,,,,,,,, -235400,0.23142023,1.4404782,,,,,,,,,,,,,,,,, -235500,0.23012283,1.4493339,,,,,,,,,,,,,,,,, -235600,0.22550309,1.3846793,,,,,,,,,,,,,,,,, -235700,0.27488586,1.4951303,,,,,,,,,,,,,,,,, -235800,0.23979075,1.4650972,,,,,,,,,,,,,,,,, -235900,0.24344838,1.5246496,,,,,,,,,,,,,,,,, -236000,0.24131644,1.5506648,,,,,,,,,,,,,,,,, -236100,0.23362505,1.3917112,,,,,,,,,,,,,,,,, -236168,,,0.6970456838607788,1.3616652488708496,35.49304734732505,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,84044.6024339199,144252.54349827766,84044.6024339199,60196.6534178257,4.07750678062439,0.0 -236200,0.23080799,1.4128028,,,,,,,,,,,,,,,,, -236300,0.24624723,1.4643015,,,,,,,,,,,,,,,,, -236400,0.23455547,1.4253564,,,,,,,,,,,,,,,,, -236500,0.24874814,1.4597306,,,,,,,,,,,,,,,,, -236600,0.23348196,1.4354075,,,,,,,,,,,,,,,,, -236700,0.23124205,1.5028408,,,,,,,,,,,,,,,,, -236800,0.22077397,1.3992112,,,,,,,,,,,,,,,,, -236900,0.2244218,1.4389013,,,,,,,,,,,,,,,,, -237000,0.23371166,1.4388999,,,,,,,,,,,,,,,,, -237100,0.2366401,1.5222383,,,,,,,,,,,,,,,,, -237200,0.23153621,1.4521484,,,,,,,,,,,,,,,,, -237300,0.23316608,1.4579521,,,,,,,,,,,,,,,,, -237400,0.23130935,1.4897012,,,,,,,,,,,,,,,,, -237500,0.24061367,1.4692025,,,,,,,,,,,,,,,,, -237600,0.23773159,1.4189441,,,,,,,,,,,,,,,,, -237700,0.23529314,1.401953,,,,,,,,,,,,,,,,, -237800,0.23876671,1.4211141,,,,,,,,,,,,,,,,, -237900,0.23147851,1.3931005,,,,,,,,,,,,,,,,, -238000,0.22884732,1.443158,,,,,,,,,,,,,,,,, -238100,0.23521969,1.4579096,,,,,,,,,,,,,,,,, -238200,0.40202594,1.4823763,,,,,,,,,,,,,,,,, -238300,0.23199202,1.4620963,,,,,,,,,,,,,,,,, -238400,0.23812647,1.4505004,,,,,,,,,,,,,,,,, -238500,0.22948892,1.4644393,,,,,,,,,,,,,,,,, -238529,,,0.6947537064552307,1.375749588012695,35.53998794997194,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,84884.70688819885,145690.4778895378,84884.70688819885,60794.35615229607,4.133155822753906,0.0 -238600,0.23581982,1.4859662,,,,,,,,,,,,,,,,, -238700,0.2282609,1.4222698,,,,,,,,,,,,,,,,, -238800,0.22947323,1.4093472,,,,,,,,,,,,,,,,, -238900,0.2464179,1.473459,,,,,,,,,,,,,,,,, -239000,0.2431673,1.4761424,,,,,,,,,,,,,,,,, -239100,0.23563561,1.4287742,,,,,,,,,,,,,,,,, -239200,0.22438408,1.4372349,,,,,,,,,,,,,,,,, -239300,0.23865853,1.4767758,,,,,,,,,,,,,,,,, -239400,0.24218886,1.4524207,,,,,,,,,,,,,,,,, -239500,0.23887515,1.5336806,,,,,,,,,,,,,,,,, -239600,0.23527461,1.4786185,,,,,,,,,,,,,,,,, -239700,0.24346858,1.4434226,,,,,,,,,,,,,,,,, -239800,0.22237095,1.3915201,,,,,,,,,,,,,,,,, -239900,0.23278672,1.4339443,,,,,,,,,,,,,,,,, -240000,0.23312585,1.398683,,,,,,,,,,,,,,,,, -240100,0.23245451,1.509311,,,,,,,,,,,,,,,,, -240200,0.24568903,1.4459726,,,,,,,,,,,,,,,,, -240300,0.23379573,1.4449601,,,,,,,,,,,,,,,,, -240400,0.24510495,1.4771342,,,,,,,,,,,,,,,,, -240500,0.232203,1.3848403,,,,,,,,,,,,,,,,, -240600,0.23589818,1.4785007,,,,,,,,,,,,,,,,, -240700,0.2261785,1.4792676,,,,,,,,,,,,,,,,, -240800,0.23724104,1.402343,,,,,,,,,,,,,,,,, -240890,,,0.6979270577430725,1.3634157180786133,35.702276531386154,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,85724.79488253593,147137.6408853531,85724.79488253593,61401.30293893814,4.189852237701416,0.0 -240900,0.23646595,1.5246472,,,,,,,,,,,,,,,,, -241000,0.2415188,1.5010538,,,,,,,,,,,,,,,,, -241100,0.22808428,1.4193819,,,,,,,,,,,,,,,,, -241200,0.22626822,1.3390763,,,,,,,,,,,,,,,,, -241300,0.24247849,1.5234469,,,,,,,,,,,,,,,,, -241400,0.23648256,1.423761,,,,,,,,,,,,,,,,, -241500,0.23178476,1.3801718,,,,,,,,,,,,,,,,, -241600,0.24031767,1.4691241,,,,,,,,,,,,,,,,, -241700,0.23259549,1.4031061,,,,,,,,,,,,,,,,, -241800,0.24011758,1.4698044,,,,,,,,,,,,,,,,, -241900,0.23160249,1.5029147,,,,,,,,,,,,,,,,, -242000,0.23308085,1.4628369,,,,,,,,,,,,,,,,, -242100,0.23412889,1.4790604,,,,,,,,,,,,,,,,, -242200,0.23492461,1.4298402,,,,,,,,,,,,,,,,, -242300,0.24687326,1.502136,,,,,,,,,,,,,,,,, -242400,0.2260181,1.41354,,,,,,,,,,,,,,,,, -242500,0.23244387,1.471546,,,,,,,,,,,,,,,,, -242600,0.23568384,1.4967283,,,,,,,,,,,,,,,,, -242700,0.2294602,1.4744891,,,,,,,,,,,,,,,,, -242800,0.23371251,1.428269,,,,,,,,,,,,,,,,, -242900,0.22555156,1.4604075,,,,,,,,,,,,,,,,, -243000,0.23496656,1.4491075,,,,,,,,,,,,,,,,, -243100,0.23342265,1.4164906,,,,,,,,,,,,,,,,, -243200,0.23498538,1.4361948,,,,,,,,,,,,,,,,, -243250,,,0.6946918368339539,1.3715825080871582,35.802417390103905,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,86564.79521155357,148591.69333863258,86564.79521155357,62015.22755002976,4.245082378387451,0.0 -243300,0.23519969,1.3754504,,,,,,,,,,,,,,,,, -243400,0.24629769,1.4577391,,,,,,,,,,,,,,,,, -243500,0.23905167,1.4745136,,,,,,,,,,,,,,,,, -243600,0.23605672,1.4744201,,,,,,,,,,,,,,,,, -243700,0.23198046,1.4155042,,,,,,,,,,,,,,,,, -243800,0.23188405,1.485281,,,,,,,,,,,,,,,,, -243900,0.22574534,1.3914421,,,,,,,,,,,,,,,,, -244000,0.22280072,1.4663212,,,,,,,,,,,,,,,,, -244100,0.24105772,1.4614801,,,,,,,,,,,,,,,,, -244200,0.23675017,1.5181811,,,,,,,,,,,,,,,,, -244300,0.2338137,1.4411191,,,,,,,,,,,,,,,,, -244400,0.24304777,1.388889,,,,,,,,,,,,,,,,, -244500,0.23084232,1.4301195,,,,,,,,,,,,,,,,, -244600,0.23142283,1.4593472,,,,,,,,,,,,,,,,, -244700,0.24150181,1.4657369,,,,,,,,,,,,,,,,, -244800,0.23425177,1.384927,,,,,,,,,,,,,,,,, -244900,0.23228233,1.4309657,,,,,,,,,,,,,,,,, -245000,0.23602225,1.4450905,,,,,,,,,,,,,,,,, -245100,0.2344849,1.4523851,,,,,,,,,,,,,,,,, -245200,0.23902261,1.4186248,,,,,,,,,,,,,,,,, -245300,0.23179382,1.4523325,,,,,,,,,,,,,,,,, -245400,0.21692953,1.4129455,,,,,,,,,,,,,,,,, -245500,0.23366709,1.3598267,,,,,,,,,,,,,,,,, -245600,0.23897567,1.4442097,,,,,,,,,,,,,,,,, -245610,,,0.6926801204681396,1.3849468231201172,35.56962577221613,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,87404.78988575935,150032.48102235794,87404.78988575935,62615.89183759689,4.302200794219971,0.0 -245700,0.23105867,1.4228715,,,,,,,,,,,,,,,,, -245800,0.23757514,1.4468695,,,,,,,,,,,,,,,,, -245900,0.23607929,1.3854564,,,,,,,,,,,,,,,,, -246000,0.23542967,1.3975414,,,,,,,,,,,,,,,,, -246100,0.23097123,1.430139,,,,,,,,,,,,,,,,, -246200,0.23417519,1.4457612,,,,,,,,,,,,,,,,, -246300,0.23456953,1.4515976,,,,,,,,,,,,,,,,, -246400,0.24199805,1.5008701,,,,,,,,,,,,,,,,, -246500,0.24815597,1.4808742,,,,,,,,,,,,,,,,, -246600,0.23766696,1.4563795,,,,,,,,,,,,,,,,, -246700,0.23557056,1.5056397,,,,,,,,,,,,,,,,, -246800,0.24226777,1.3992466,,,,,,,,,,,,,,,,, -246900,0.24158643,1.4924837,,,,,,,,,,,,,,,,, -247000,0.22717525,1.4347799,,,,,,,,,,,,,,,,, -247100,0.23548277,1.457344,,,,,,,,,,,,,,,,, -247200,0.22840884,1.3542186,,,,,,,,,,,,,,,,, -247300,0.23952541,1.4634913,,,,,,,,,,,,,,,,, -247400,0.23959863,1.4594189,,,,,,,,,,,,,,,,, -247500,0.23394242,1.4574027,,,,,,,,,,,,,,,,, -247600,0.25273317,1.4763515,,,,,,,,,,,,,,,,, -247700,0.23889455,1.5040644,,,,,,,,,,,,,,,,, -247800,0.22528662,1.3816174,,,,,,,,,,,,,,,,, -247900,0.23047395,1.465442,,,,,,,,,,,,,,,,, -247970,,,0.6970303654670715,1.3621553182601929,35.03329818639915,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,88244.95991444588,151476.27511763573,88244.95991444588,63219.3861579895,4.359274387359619,0.0 -248000,0.23166016,1.4375057,,,,,,,,,,,,,,,,, -248100,0.25191393,1.485822,,,,,,,,,,,,,,,,, -248200,0.24430913,1.4738216,,,,,,,,,,,,,,,,, -248300,0.39738774,1.4257448,,,,,,,,,,,,,,,,, -248400,0.23541658,1.4384161,,,,,,,,,,,,,,,,, -248500,0.23286639,1.3835735,,,,,,,,,,,,,,,,, -248600,0.2344259,1.4812465,,,,,,,,,,,,,,,,, -248700,0.23652542,1.4503984,,,,,,,,,,,,,,,,, -248800,0.23147936,1.484819,,,,,,,,,,,,,,,,, -248900,0.23252976,1.397744,,,,,,,,,,,,,,,,, -249000,0.23007642,1.4425188,,,,,,,,,,,,,,,,, -249100,0.2344515,1.4173361,,,,,,,,,,,,,,,,, -249200,0.22914173,1.451653,,,,,,,,,,,,,,,,, -249300,0.23676242,1.4026464,,,,,,,,,,,,,,,,, -249400,0.23069559,1.4273108,,,,,,,,,,,,,,,,, -249500,0.23504308,1.3759476,,,,,,,,,,,,,,,,, -249600,0.22729616,1.4392712,,,,,,,,,,,,,,,,, -249700,0.2356836,1.4225769,,,,,,,,,,,,,,,,, -249800,0.2408882,1.5118226,,,,,,,,,,,,,,,,, -249900,0.2304397,1.4234955,,,,,,,,,,,,,,,,, -250000,0.23385638,1.510709,,,,,,,,,,,,,,,,, -250100,0.23024416,1.4345174,,,,,,,,,,,,,,,,, -250200,0.23653263,1.5719163,,,,,,,,,,,,,,,,, -250300,0.22790441,1.4547601,,,,,,,,,,,,,,,,, -250331,,,0.6966925263404846,1.3629604578018188,35.4255910535756,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,89084.85697770119,152920.51184105873,89084.85697770119,63823.59740900993,4.417787790298462,0.0 -250400,0.23433013,1.3558424,,,,,,,,,,,,,,,,, -250500,0.2382319,1.4244633,,,,,,,,,,,,,,,,, -250600,0.23747039,1.5210141,,,,,,,,,,,,,,,,, -250700,0.220861,1.3992862,,,,,,,,,,,,,,,,, -250800,0.23576286,1.4072139,,,,,,,,,,,,,,,,, -250900,0.24410218,1.4750072,,,,,,,,,,,,,,,,, -251000,0.22767279,1.4285264,,,,,,,,,,,,,,,,, -251100,0.23543651,1.4763716,,,,,,,,,,,,,,,,, -251200,0.23079981,1.3473434,,,,,,,,,,,,,,,,, -251300,0.23191957,1.3997816,,,,,,,,,,,,,,,,, -251400,0.23084345,1.400492,,,,,,,,,,,,,,,,, -251500,0.24016377,1.5047116,,,,,,,,,,,,,,,,, -251600,0.24474074,1.4146665,,,,,,,,,,,,,,,,, -251700,0.24070902,1.4623914,,,,,,,,,,,,,,,,, -251800,0.22639441,1.4475205,,,,,,,,,,,,,,,,, -251900,0.24414468,1.4384173,,,,,,,,,,,,,,,,, -252000,0.23241538,1.4674879,,,,,,,,,,,,,,,,, -252100,0.23890682,1.5076182,,,,,,,,,,,,,,,,, -252200,0.22931896,1.3990426,,,,,,,,,,,,,,,,, -252300,0.2630262,1.4577764,,,,,,,,,,,,,,,,, -252400,0.24682245,1.4767704,,,,,,,,,,,,,,,,, -252500,0.2367046,1.4252422,,,,,,,,,,,,,,,,, -252600,0.23679987,1.4305052,,,,,,,,,,,,,,,,, -252691,,,0.693827211856842,1.379385232925415,35.42538705946164,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,89924.99598526955,154367.91337037086,89924.99598526955,64430.72727441788,4.476908206939697,0.0 -252700,0.2316499,1.4569529,,,,,,,,,,,,,,,,, -252800,0.24153855,1.4317499,,,,,,,,,,,,,,,,, -252900,0.2386865,1.4426328,,,,,,,,,,,,,,,,, -253000,0.23822066,1.4161906,,,,,,,,,,,,,,,,, -253100,0.23507681,1.4783652,,,,,,,,,,,,,,,,, -253200,0.24168018,1.47632,,,,,,,,,,,,,,,,, -253300,0.23638214,1.3870431,,,,,,,,,,,,,,,,, -253400,0.23716114,1.3915683,,,,,,,,,,,,,,,,, -253500,0.23780978,1.4342035,,,,,,,,,,,,,,,,, -253600,0.23584466,1.4625394,,,,,,,,,,,,,,,,, -253700,0.23462608,1.4199568,,,,,,,,,,,,,,,,, -253800,0.22986592,1.4089569,,,,,,,,,,,,,,,,, -253900,0.23068525,1.3745753,,,,,,,,,,,,,,,,, -254000,0.23798963,1.4378107,,,,,,,,,,,,,,,,, -254100,0.23356822,1.4483668,,,,,,,,,,,,,,,,, -254200,0.236348,1.4819837,,,,,,,,,,,,,,,,, -254300,0.2338027,1.4968154,,,,,,,,,,,,,,,,, -254400,0.23672433,1.4813795,,,,,,,,,,,,,,,,, -254500,0.23298453,1.506512,,,,,,,,,,,,,,,,, -254600,0.22879757,1.4063377,,,,,,,,,,,,,,,,, -254700,0.23837511,1.4384556,,,,,,,,,,,,,,,,, -254800,0.23180152,1.4521806,,,,,,,,,,,,,,,,, -254900,0.23288508,1.436938,,,,,,,,,,,,,,,,, -255000,0.23666066,1.3983761,,,,,,,,,,,,,,,,, -255051,,,0.6940208673477173,1.374923586845398,35.44952550740223,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,90765.08199381828,155800.5014474392,90765.08199381828,65023.09778881073,4.534471273422241,0.0 -255100,0.23297252,1.4461012,,,,,,,,,,,,,,,,, -255200,0.23195563,1.4163481,,,,,,,,,,,,,,,,, -255300,0.24182408,1.4949565,,,,,,,,,,,,,,,,, -255400,0.23061872,1.4571469,,,,,,,,,,,,,,,,, -255500,0.22550747,1.3359739,,,,,,,,,,,,,,,,, -255600,0.23737001,1.4713382,,,,,,,,,,,,,,,,, -255700,0.23472412,1.433596,,,,,,,,,,,,,,,,, -255800,0.23106268,1.4677967,,,,,,,,,,,,,,,,, -255900,0.24309005,1.4311032,,,,,,,,,,,,,,,,, -256000,0.24162333,1.4651684,,,,,,,,,,,,,,,,, -256100,0.25041866,1.4684383,,,,,,,,,,,,,,,,, -256200,0.23380716,1.442396,,,,,,,,,,,,,,,,, -256300,0.22434591,1.4603097,,,,,,,,,,,,,,,,, -256400,0.23372938,1.4893223,,,,,,,,,,,,,,,,, -256500,0.24131885,1.4617326,,,,,,,,,,,,,,,,, -256600,0.23170145,1.4415329,,,,,,,,,,,,,,,,, -256700,0.230802,1.4735265,,,,,,,,,,,,,,,,, -256800,0.23294702,1.3929757,,,,,,,,,,,,,,,,, -256900,0.23059958,1.3940864,,,,,,,,,,,,,,,,, -257000,0.22783208,1.3963768,,,,,,,,,,,,,,,,, -257100,0.24092264,1.5047956,,,,,,,,,,,,,,,,, -257200,0.22806592,1.4701047,,,,,,,,,,,,,,,,, -257300,0.22212031,1.4299536,,,,,,,,,,,,,,,,, -257400,0.2376727,1.415614,,,,,,,,,,,,,,,,, -257412,,,0.6943745017051697,1.3740166425704956,35.384300906912,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,91605.0241189003,157249.84708356857,91605.0241189003,65632.37321901321,4.591864347457886,0.0 -257500,0.22859867,1.4172145,,,,,,,,,,,,,,,,, -257600,0.22882658,1.4301383,,,,,,,,,,,,,,,,, -257700,0.2443164,1.5138881,,,,,,,,,,,,,,,,, -257800,0.221744,1.3582333,,,,,,,,,,,,,,,,, -257900,0.22988456,1.4468249,,,,,,,,,,,,,,,,, -258000,0.22656438,1.3920602,,,,,,,,,,,,,,,,, -258100,0.244669,1.4943423,,,,,,,,,,,,,,,,, -258200,0.23619162,1.5521678,,,,,,,,,,,,,,,,, -258300,0.2307408,1.4030045,,,,,,,,,,,,,,,,, -258400,0.23431088,1.4659178,,,,,,,,,,,,,,,,, -258500,0.23169446,1.4082788,,,,,,,,,,,,,,,,, -258600,0.23455633,1.4701722,,,,,,,,,,,,,,,,, -258700,0.24521355,1.3971044,,,,,,,,,,,,,,,,, -258800,0.22692381,1.4119978,,,,,,,,,,,,,,,,, -258900,0.23168676,1.3855003,,,,,,,,,,,,,,,,, -259000,0.2254258,1.4346097,,,,,,,,,,,,,,,,, -259100,0.22656065,1.4487255,,,,,,,,,,,,,,,,, -259200,0.23816167,1.421488,,,,,,,,,,,,,,,,, -259300,0.23750609,1.4915862,,,,,,,,,,,,,,,,, -259400,0.23244663,1.4730806,,,,,,,,,,,,,,,,, -259500,0.24152374,1.4819689,,,,,,,,,,,,,,,,, -259600,0.22470687,1.4016553,,,,,,,,,,,,,,,,, -259700,0.22722363,1.4813591,,,,,,,,,,,,,,,,, -259773,,,0.6928779482841492,1.3875993490219116,35.38147552282423,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,92445.21916270256,158690.14498519895,92445.21916270256,66232.347230196,4.648017883300781,0.0 -259800,0.23128477,1.4371096,,,,,,,,,,,,,,,,, -259900,0.23594853,1.4068944,,,,,,,,,,,,,,,,, -260000,0.23462722,1.4552765,,,,,,,,,,,,,,,,, -260100,0.23026155,1.3832699,,,,,,,,,,,,,,,,, -260200,0.22764146,1.4260374,,,,,,,,,,,,,,,,, -260300,0.24421705,1.5008501,,,,,,,,,,,,,,,,, -260400,0.23169833,1.3752803,,,,,,,,,,,,,,,,, -260500,0.2345515,1.4496229,,,,,,,,,,,,,,,,, -260600,0.2460415,1.4422128,,,,,,,,,,,,,,,,, -260700,0.23715821,1.43638,,,,,,,,,,,,,,,,, -260800,0.24031572,1.465343,,,,,,,,,,,,,,,,, -260900,0.23496713,1.4391551,,,,,,,,,,,,,,,,, -261000,0.23511064,1.4478179,,,,,,,,,,,,,,,,, -261100,0.226011,1.4309126,,,,,,,,,,,,,,,,, -261200,0.23552395,1.5304732,,,,,,,,,,,,,,,,, -261300,0.24598786,1.5407174,,,,,,,,,,,,,,,,, -261400,0.22724952,1.4780108,,,,,,,,,,,,,,,,, -261500,0.2387095,1.4890594,,,,,,,,,,,,,,,,, -261600,0.23100086,1.5253525,,,,,,,,,,,,,,,,, -261700,0.23585185,1.3961166,,,,,,,,,,,,,,,,, -261800,0.2335184,1.4596953,,,,,,,,,,,,,,,,, -261900,0.23494531,1.4386683,,,,,,,,,,,,,,,,, -262000,0.25588983,1.479976,,,,,,,,,,,,,,,,, -262100,0.23613507,1.3849415,,,,,,,,,,,,,,,,, -262134,,,0.6953383088111877,1.3716511726379397,35.612914454909664,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,93285.35460090636,160132.52825331688,93285.35460090636,66834.46612238884,4.7053093910217285,0.0 -262200,0.23319231,1.4655983,,,,,,,,,,,,,,,,, -262300,0.23697384,1.4743325,,,,,,,,,,,,,,,,, -262400,0.2273124,1.4241434,,,,,,,,,,,,,,,,, -262500,0.23977941,1.4415425,,,,,,,,,,,,,,,,, -262600,0.23171446,1.406067,,,,,,,,,,,,,,,,, -262700,0.23866224,1.5122182,,,,,,,,,,,,,,,,, -262800,0.2347206,1.4445091,,,,,,,,,,,,,,,,, -262900,0.2401183,1.3930601,,,,,,,,,,,,,,,,, -263000,0.2367214,1.4301991,,,,,,,,,,,,,,,,, -263100,0.22794642,1.4600906,,,,,,,,,,,,,,,,, -263200,0.23456207,1.4474881,,,,,,,,,,,,,,,,, -263300,0.24542938,1.4205996,,,,,,,,,,,,,,,,, -263400,0.22499503,1.432895,,,,,,,,,,,,,,,,, -263500,0.23235479,1.4123298,,,,,,,,,,,,,,,,, -263600,0.23765144,1.4859258,,,,,,,,,,,,,,,,, -263700,0.2372469,1.4397192,,,,,,,,,,,,,,,,, -263800,0.2362643,1.4500884,,,,,,,,,,,,,,,,, -263900,0.24840374,1.4486705,,,,,,,,,,,,,,,,, -264000,0.2354585,1.411179,,,,,,,,,,,,,,,,, -264100,0.23606335,1.4099324,,,,,,,,,,,,,,,,, -264200,0.23366164,1.4957381,,,,,,,,,,,,,,,,, -264300,0.23599608,1.4534196,,,,,,,,,,,,,,,,, -264400,0.22972338,1.4907248,,,,,,,,,,,,,,,,, -264494,,,0.6960185170173645,1.37083899974823,35.742916565992346,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,94125.36615252496,161577.15155768394,94125.36615252496,67438.93201184273,4.778682708740234,0.0 -264500,0.22896862,1.4960722,,,,,,,,,,,,,,,,, -264600,0.2302884,1.4488232,,,,,,,,,,,,,,,,, -264700,0.2321489,1.3844931,,,,,,,,,,,,,,,,, -264800,0.2417921,1.561751,,,,,,,,,,,,,,,,, -264900,0.2314853,1.4845676,,,,,,,,,,,,,,,,, -265000,0.23635162,1.5008597,,,,,,,,,,,,,,,,, -265100,0.22980817,1.4073305,,,,,,,,,,,,,,,,, -265200,0.23574509,1.4469982,,,,,,,,,,,,,,,,, -265300,0.2306944,1.4529601,,,,,,,,,,,,,,,,, -265400,0.23871222,1.4278309,,,,,,,,,,,,,,,,, -265500,0.24006523,1.4028679,,,,,,,,,,,,,,,,, -265600,0.23449756,1.4399426,,,,,,,,,,,,,,,,, -265700,0.24033076,1.5091666,,,,,,,,,,,,,,,,, -265800,0.22568905,1.433831,,,,,,,,,,,,,,,,, -265900,0.24471653,1.4381863,,,,,,,,,,,,,,,,, -266000,0.22837684,1.4182552,,,,,,,,,,,,,,,,, -266100,0.2339936,1.4975071,,,,,,,,,,,,,,,,, -266200,0.23460175,1.4343598,,,,,,,,,,,,,,,,, -266300,0.23270905,1.3997291,,,,,,,,,,,,,,,,, -266400,0.23048173,1.4266757,,,,,,,,,,,,,,,,, -266500,0.22464684,1.4180341,,,,,,,,,,,,,,,,, -266600,0.22940852,1.4362351,,,,,,,,,,,,,,,,, -266700,0.2239609,1.4302602,,,,,,,,,,,,,,,,, -266800,0.22837454,1.4190718,,,,,,,,,,,,,,,,, -266855,,,0.6939523220062256,1.3873682022094729,35.54363592145518,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,94965.49518156052,163028.48417282104,94965.49518156052,68050.00629425049,4.836698532104492,0.0 -266900,0.24041407,1.3796425,,,,,,,,,,,,,,,,, -267000,0.2510971,1.4821559,,,,,,,,,,,,,,,,, -267100,0.23469427,1.5073274,,,,,,,,,,,,,,,,, -267200,0.23560129,1.4480095,,,,,,,,,,,,,,,,, -267300,0.23634452,1.4297817,,,,,,,,,,,,,,,,, -267400,0.23287396,1.4934815,,,,,,,,,,,,,,,,, -267500,0.23349245,1.4378225,,,,,,,,,,,,,,,,, -267600,0.23109828,1.4447831,,,,,,,,,,,,,,,,, -267700,0.23786417,1.4568439,,,,,,,,,,,,,,,,, -267800,0.24115586,1.4846894,,,,,,,,,,,,,,,,, -267900,0.23825403,1.3900104,,,,,,,,,,,,,,,,, -268000,0.24065815,1.4887972,,,,,,,,,,,,,,,,, -268100,0.2354404,1.441788,,,,,,,,,,,,,,,,, -268200,0.24381898,1.4973696,,,,,,,,,,,,,,,,, -268300,0.24227037,1.3965,,,,,,,,,,,,,,,,, -268400,0.23404062,1.4782416,,,,,,,,,,,,,,,,, -268500,0.23370484,1.4562659,,,,,,,,,,,,,,,,, -268600,0.22635336,1.4299475,,,,,,,,,,,,,,,,, -268700,0.24251847,1.4944658,,,,,,,,,,,,,,,,, -268800,0.24527095,1.4795706,,,,,,,,,,,,,,,,, -268900,0.23838434,1.4965681,,,,,,,,,,,,,,,,, -269000,0.22533427,1.4061742,,,,,,,,,,,,,,,,, -269100,0.23846865,1.4946921,,,,,,,,,,,,,,,,, -269200,0.24032924,1.4869974,,,,,,,,,,,,,,,,, -269215,,,0.6959852576255798,1.371500015258789,35.7707856211828,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,95805.45388770103,164471.42721128464,95805.45388770103,68652.84652853012,4.907514810562134,0.0 -269300,0.23404278,1.4609339,,,,,,,,,,,,,,,,, -269400,0.23201632,1.4667555,,,,,,,,,,,,,,,,, -269500,0.24675733,1.5282991,,,,,,,,,,,,,,,,, -269600,0.24059166,1.4905001,,,,,,,,,,,,,,,,, -269700,0.23349068,1.4444612,,,,,,,,,,,,,,,,, -269800,0.22951148,1.4320718,,,,,,,,,,,,,,,,, -269900,0.22966054,1.3704776,,,,,,,,,,,,,,,,, -270000,0.23988165,1.462429,,,,,,,,,,,,,,,,, -270100,0.23003657,1.4695704,,,,,,,,,,,,,,,,, -270200,0.24568859,1.4447341,,,,,,,,,,,,,,,,, -270300,0.24467753,1.4948293,,,,,,,,,,,,,,,,, -270400,0.24340114,1.4208131,,,,,,,,,,,,,,,,, -270500,0.23246874,1.4171293,,,,,,,,,,,,,,,,, -270600,0.23812225,1.4700332,,,,,,,,,,,,,,,,, -270700,0.2390999,1.5891176,,,,,,,,,,,,,,,,, -270800,0.23473497,1.3935237,,,,,,,,,,,,,,,,, -270900,0.23427212,1.4015285,,,,,,,,,,,,,,,,, -271000,0.23164837,1.4524604,,,,,,,,,,,,,,,,, -271100,0.23479457,1.4447404,,,,,,,,,,,,,,,,, -271200,0.23743846,1.457788,,,,,,,,,,,,,,,,, -271300,0.23507765,1.5175494,,,,,,,,,,,,,,,,, -271400,0.23192538,1.466131,,,,,,,,,,,,,,,,, -271500,0.24058452,1.4500268,,,,,,,,,,,,,,,,, -271576,,,0.694656491279602,1.378628492355347,35.713971528505866,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,96645.50051903725,165917.50253725052,96645.50051903725,69258.74201393127,4.968581914901733,0.0 -271600,0.2404892,1.5465529,,,,,,,,,,,,,,,,, -271700,0.23364767,1.418584,,,,,,,,,,,,,,,,, -271800,0.23567183,1.4793135,,,,,,,,,,,,,,,,, -271900,0.22920282,1.504075,,,,,,,,,,,,,,,,, -272000,0.23410249,1.4599462,,,,,,,,,,,,,,,,, -272100,0.23568878,1.4722859,,,,,,,,,,,,,,,,, -272200,0.2396008,1.4578876,,,,,,,,,,,,,,,,, -272300,0.24263942,1.5182581,,,,,,,,,,,,,,,,, -272400,0.2576999,1.3974653,,,,,,,,,,,,,,,,, -272500,0.22936833,1.4504019,,,,,,,,,,,,,,,,, -272600,0.23960085,1.4353491,,,,,,,,,,,,,,,,, -272700,0.24638635,1.4884874,,,,,,,,,,,,,,,,, -272800,0.23523217,1.400153,,,,,,,,,,,,,,,,, -272900,0.22970006,1.3765746,,,,,,,,,,,,,,,,, -273000,0.2362239,1.4031678,,,,,,,,,,,,,,,,, -273100,0.24246557,1.5128713,,,,,,,,,,,,,,,,, -273200,0.23401175,1.3900579,,,,,,,,,,,,,,,,, -273300,0.2320202,1.4115368,,,,,,,,,,,,,,,,, -273400,0.23193955,1.3748436,,,,,,,,,,,,,,,,, -273500,0.23527509,1.4630946,,,,,,,,,,,,,,,,, -273600,0.24509314,1.4327836,,,,,,,,,,,,,,,,, -273700,0.22821532,1.4223237,,,,,,,,,,,,,,,,, -273800,0.24145852,1.4822513,,,,,,,,,,,,,,,,, -273900,0.23212376,1.4410238,,,,,,,,,,,,,,,,, -273937,,,0.6959747076034546,1.3692622184753418,35.55752862460547,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,97485.64288377762,167355.8598549366,97485.64288377762,69856.77969145775,5.073653221130371,0.0 -274000,0.23404476,1.5007101,,,,,,,,,,,,,,,,, -274100,0.22469315,1.4331645,,,,,,,,,,,,,,,,, -274200,0.24305913,1.46741,,,,,,,,,,,,,,,,, -274300,0.25048998,1.5100625,,,,,,,,,,,,,,,,, -274400,0.2341701,1.388776,,,,,,,,,,,,,,,,, -274500,0.24495792,1.482355,,,,,,,,,,,,,,,,, -274600,0.23557849,1.4039495,,,,,,,,,,,,,,,,, -274700,0.23195459,1.3990608,,,,,,,,,,,,,,,,, -274800,0.23006071,1.3771584,,,,,,,,,,,,,,,,, -274900,0.23431566,1.4274393,,,,,,,,,,,,,,,,, -275000,0.241017,1.4233227,,,,,,,,,,,,,,,,, -275100,0.229468,1.452627,,,,,,,,,,,,,,,,, -275200,0.23945476,1.5208323,,,,,,,,,,,,,,,,, -275300,0.23459943,1.4486476,,,,,,,,,,,,,,,,, -275400,0.239389,1.4561355,,,,,,,,,,,,,,,,, -275500,0.23257025,1.5192728,,,,,,,,,,,,,,,,, -275600,0.25091732,1.4618722,,,,,,,,,,,,,,,,, -275700,0.23435235,1.4491982,,,,,,,,,,,,,,,,, -275800,0.23007073,1.4433341,,,,,,,,,,,,,,,,, -275900,0.23750609,1.4679438,,,,,,,,,,,,,,,,, -276000,0.24128903,1.4857107,,,,,,,,,,,,,,,,, -276100,0.2347678,1.429745,,,,,,,,,,,,,,,,, -276200,0.23332107,1.4766402,,,,,,,,,,,,,,,,, -276296,,,0.6965546607971191,1.3616505861282349,35.67617262040431,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,98325.51252913476,168795.82131123543,98325.51252913476,70456.73684310913,5.134132146835327,0.0 -276300,0.23568308,1.4189771,,,,,,,,,,,,,,,,, -276400,0.23899989,1.5050133,,,,,,,,,,,,,,,,, -276500,0.22392423,1.369263,,,,,,,,,,,,,,,,, -276600,0.23188876,1.3925141,,,,,,,,,,,,,,,,, -276700,0.23439382,1.4929835,,,,,,,,,,,,,,,,, -276800,0.23531638,1.4179851,,,,,,,,,,,,,,,,, -276900,0.22855316,1.3163604,,,,,,,,,,,,,,,,, -277000,0.2416456,1.5163319,,,,,,,,,,,,,,,,, -277100,0.22188926,1.3804463,,,,,,,,,,,,,,,,, -277200,0.23532481,1.4430541,,,,,,,,,,,,,,,,, -277300,0.23378246,1.446554,,,,,,,,,,,,,,,,, -277400,0.24075732,1.4752508,,,,,,,,,,,,,,,,, -277500,0.22412826,1.4086052,,,,,,,,,,,,,,,,, -277600,0.24204814,1.492322,,,,,,,,,,,,,,,,, -277700,0.23307067,1.4413195,,,,,,,,,,,,,,,,, -277800,0.23540701,1.4257946,,,,,,,,,,,,,,,,, -277900,0.22771709,1.3720891,,,,,,,,,,,,,,,,, -278000,0.23732749,1.4384032,,,,,,,,,,,,,,,,, -278100,0.23258363,1.4295297,,,,,,,,,,,,,,,,, -278200,0.22496204,1.386271,,,,,,,,,,,,,,,,, -278300,0.23177184,1.4813854,,,,,,,,,,,,,,,,, -278400,0.23768324,1.4680251,,,,,,,,,,,,,,,,, -278500,0.23161389,1.4444643,,,,,,,,,,,,,,,,, -278600,0.2250178,1.3836416,,,,,,,,,,,,,,,,, -278656,,,0.6967638731002808,1.3642253875732422,35.39564349576458,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,99165.73145341872,170237.39524626732,99165.73145341872,71057.9582953453,5.194442510604858,0.0 -278700,0.23692703,1.4410703,,,,,,,,,,,,,,,,, -278800,0.24219371,1.4457824,,,,,,,,,,,,,,,,, -278900,0.24193838,1.4940907,,,,,,,,,,,,,,,,, -279000,0.2827122,1.5082474,,,,,,,,,,,,,,,,, -279100,0.23484729,1.4079576,,,,,,,,,,,,,,,,, -279200,0.22629291,1.4563956,,,,,,,,,,,,,,,,, -279300,0.23962751,1.4226046,,,,,,,,,,,,,,,,, -279400,0.23552406,1.4332899,,,,,,,,,,,,,,,,, -279500,0.23759101,1.4599009,,,,,,,,,,,,,,,,, -279600,0.22736204,1.4386716,,,,,,,,,,,,,,,,, -279700,0.23170495,1.458005,,,,,,,,,,,,,,,,, -279800,0.23904839,1.5559165,,,,,,,,,,,,,,,,, -279900,0.2394042,1.410014,,,,,,,,,,,,,,,,, -280000,0.22758266,1.4357262,,,,,,,,,,,,,,,,, -280100,0.24164568,1.4628602,,,,,,,,,,,,,,,,, -280200,0.23529878,1.4533656,,,,,,,,,,,,,,,,, -280300,0.24208592,1.4070371,,,,,,,,,,,,,,,,, -280400,0.23830977,1.406809,,,,,,,,,,,,,,,,, -280500,0.23146518,1.4231441,,,,,,,,,,,,,,,,, -280600,0.23753962,1.4699655,,,,,,,,,,,,,,,,, -280700,0.23434035,1.4330039,,,,,,,,,,,,,,,,, -280800,0.22694683,1.4354049,,,,,,,,,,,,,,,,, -280900,0.23298115,1.5149413,,,,,,,,,,,,,,,,, -281000,0.24067664,1.4202234,,,,,,,,,,,,,,,,, -281016,,,0.6934937834739685,1.3861702680587769,35.466408250843514,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,100005.89224982262,171671.8169312477,100005.89224982262,71652.07466721535,5.26610803604126,0.0 -281100,0.23700604,1.4404545,,,,,,,,,,,,,,,,, -281200,0.21778993,1.3617759,,,,,,,,,,,,,,,,, -281300,0.23395719,1.432896,,,,,,,,,,,,,,,,, -281400,0.23140277,1.4420646,,,,,,,,,,,,,,,,, -281500,0.23073284,1.4301997,,,,,,,,,,,,,,,,, -281600,0.23024443,1.4151794,,,,,,,,,,,,,,,,, -281700,0.23335849,1.4614407,,,,,,,,,,,,,,,,, -281800,0.22457938,1.4279152,,,,,,,,,,,,,,,,, -281900,0.23628698,1.4341335,,,,,,,,,,,,,,,,, -282000,0.23942399,1.4312192,,,,,,,,,,,,,,,,, -282100,0.23608549,1.4240739,,,,,,,,,,,,,,,,, -282200,0.24007851,1.499626,,,,,,,,,,,,,,,,, -282300,0.23009908,1.4526109,,,,,,,,,,,,,,,,, -282400,0.24047826,1.483007,,,,,,,,,,,,,,,,, -282500,0.24363412,1.4354794,,,,,,,,,,,,,,,,, -282600,0.24666795,1.398293,,,,,,,,,,,,,,,,, -282700,0.24094768,1.4964244,,,,,,,,,,,,,,,,, -282800,0.23553827,1.4263957,,,,,,,,,,,,,,,,, -282900,0.22952728,1.3666738,,,,,,,,,,,,,,,,, -283000,0.23041989,1.4575855,,,,,,,,,,,,,,,,, -283100,0.22778772,1.4313736,,,,,,,,,,,,,,,,, -283200,0.23534442,1.4114046,,,,,,,,,,,,,,,,, -283300,0.2329454,1.4435248,,,,,,,,,,,,,,,,, -283376,,,0.6972944736480713,1.3638888597488403,35.66149794669415,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,100845.91227436066,173102.43675160408,100845.91227436066,72242.53067946434,5.33665919303894,0.0 -283400,0.23608579,1.5112201,,,,,,,,,,,,,,,,, -283500,0.23124613,1.4522941,,,,,,,,,,,,,,,,, -283600,0.2342399,1.4309694,,,,,,,,,,,,,,,,, -283700,0.24380122,1.4503679,,,,,,,,,,,,,,,,, -283800,0.2324773,1.4468381,,,,,,,,,,,,,,,,, -283900,0.23419173,1.4380515,,,,,,,,,,,,,,,,, -284000,0.23245327,1.4893894,,,,,,,,,,,,,,,,, -284100,0.24029095,1.4650978,,,,,,,,,,,,,,,,, -284200,0.24258712,1.4358358,,,,,,,,,,,,,,,,, -284300,0.23076732,1.4171318,,,,,,,,,,,,,,,,, -284400,0.22742891,1.4474137,,,,,,,,,,,,,,,,, -284500,0.23016746,1.4049585,,,,,,,,,,,,,,,,, -284600,0.23942068,1.4919041,,,,,,,,,,,,,,,,, -284700,0.24247804,1.423901,,,,,,,,,,,,,,,,, -284800,0.24017616,1.4689113,,,,,,,,,,,,,,,,, -284900,0.22629441,1.4081471,,,,,,,,,,,,,,,,, -285000,0.24101843,1.4015503,,,,,,,,,,,,,,,,, -285100,0.23804885,1.4821497,,,,,,,,,,,,,,,,, -285200,0.24135043,1.3992642,,,,,,,,,,,,,,,,, -285300,0.22915058,1.368172,,,,,,,,,,,,,,,,, -285400,0.23848537,1.3548867,,,,,,,,,,,,,,,,, -285500,0.22815946,1.4056265,,,,,,,,,,,,,,,,, -285600,0.23892514,1.5134635,,,,,,,,,,,,,,,,, -285700,0.22728674,1.3780893,,,,,,,,,,,,,,,,, -285737,,,0.6911959052085876,1.395319581031799,36.23396595613439,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,101685.87846279144,174541.66915917397,101685.87846279144,72841.66510772705,5.397650003433228,0.0 -285800,0.24009682,1.4417435,,,,,,,,,,,,,,,,, -285900,0.23789527,1.4523648,,,,,,,,,,,,,,,,, -286000,0.23375638,1.3728726,,,,,,,,,,,,,,,,, -286100,0.22570698,1.4492882,,,,,,,,,,,,,,,,, -286200,0.23673284,1.4469943,,,,,,,,,,,,,,,,, -286300,0.23483703,1.466137,,,,,,,,,,,,,,,,, -286400,0.23484802,1.3796055,,,,,,,,,,,,,,,,, -286500,0.23308623,1.3940696,,,,,,,,,,,,,,,,, -286600,0.24207728,1.4558395,,,,,,,,,,,,,,,,, -286700,0.24768068,1.4997717,,,,,,,,,,,,,,,,, -286800,0.23761386,1.4735004,,,,,,,,,,,,,,,,, -286900,0.23560615,1.4033563,,,,,,,,,,,,,,,,, -287000,0.2276717,1.5042346,,,,,,,,,,,,,,,,, -287100,0.24171217,1.4783255,,,,,,,,,,,,,,,,, -287200,0.23693797,1.3950611,,,,,,,,,,,,,,,,, -287300,0.23823468,1.49733,,,,,,,,,,,,,,,,, -287400,0.23352613,1.4757856,,,,,,,,,,,,,,,,, -287500,0.2371769,1.4864892,,,,,,,,,,,,,,,,, -287600,0.24001744,1.5201981,,,,,,,,,,,,,,,,, -287700,0.24252853,1.4708701,,,,,,,,,,,,,,,,, -287800,0.22441617,1.4357824,,,,,,,,,,,,,,,,, -287900,0.23097788,1.3517889,,,,,,,,,,,,,,,,, -288000,0.23051874,1.4439484,,,,,,,,,,,,,,,,, -288099,,,0.6965842843055725,1.3646329641342163,35.65843571065414,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,102525.98865199088,175984.77796077728,102525.98865199088,73444.53052544594,5.460918426513672,0.0 -288100,0.22851975,1.3654071,,,,,,,,,,,,,,,,, -288200,0.23458378,1.462949,,,,,,,,,,,,,,,,, -288300,0.32829377,1.4805877,,,,,,,,,,,,,,,,, -288400,0.23483042,1.4735794,,,,,,,,,,,,,,,,, -288500,0.23744184,1.4242897,,,,,,,,,,,,,,,,, -288600,0.23444085,1.4498962,,,,,,,,,,,,,,,,, -288700,0.2425305,1.4248,,,,,,,,,,,,,,,,, -288800,0.23983671,1.4846737,,,,,,,,,,,,,,,,, -288900,0.2396491,1.4262272,,,,,,,,,,,,,,,,, -289000,0.23942925,1.41876,,,,,,,,,,,,,,,,, -289100,0.23947902,1.3983573,,,,,,,,,,,,,,,,, -289200,0.22973555,1.4727678,,,,,,,,,,,,,,,,, -289300,0.23329042,1.4118587,,,,,,,,,,,,,,,,, -289400,0.23925853,1.4563665,,,,,,,,,,,,,,,,, -289500,0.24999976,1.4660532,,,,,,,,,,,,,,,,, -289600,0.23193628,1.4637421,,,,,,,,,,,,,,,,, -289700,0.23763451,1.4616855,,,,,,,,,,,,,,,,, -289800,0.23675622,1.4611697,,,,,,,,,,,,,,,,, -289900,0.23512186,1.4574671,,,,,,,,,,,,,,,,, -290000,0.23659404,1.3612561,,,,,,,,,,,,,,,,, -290100,0.24460854,1.4282205,,,,,,,,,,,,,,,,, -290200,0.2317071,1.4838678,,,,,,,,,,,,,,,,, -290300,0.23613183,1.3482449,,,,,,,,,,,,,,,,, -290400,0.2534873,1.441875,,,,,,,,,,,,,,,,, -290459,,,0.696201503276825,1.3672454357147217,35.46740520559859,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,103365.85811543465,177440.03746771812,103365.85811543465,74059.78676986694,5.522222518920898,0.0 -290500,0.22666349,1.4138317,,,,,,,,,,,,,,,,, -290600,0.24260898,1.4144332,,,,,,,,,,,,,,,,, -290700,0.2366554,1.5352817,,,,,,,,,,,,,,,,, -290800,0.23892534,1.4127467,,,,,,,,,,,,,,,,, -290900,0.23590584,1.3994236,,,,,,,,,,,,,,,,, -291000,0.22527252,1.4338169,,,,,,,,,,,,,,,,, -291100,0.2325609,1.4109125,,,,,,,,,,,,,,,,, -291200,0.2280634,1.4663061,,,,,,,,,,,,,,,,, -291300,0.23806137,1.4867369,,,,,,,,,,,,,,,,, -291400,0.22489552,1.3799704,,,,,,,,,,,,,,,,, -291500,0.22714798,1.4448833,,,,,,,,,,,,,,,,, -291600,0.23733069,1.4533688,,,,,,,,,,,,,,,,, -291700,0.23691444,1.4971192,,,,,,,,,,,,,,,,, -291800,0.23311624,1.3957415,,,,,,,,,,,,,,,,, -291900,0.23301269,1.3766278,,,,,,,,,,,,,,,,, -292000,0.23609981,1.4500344,,,,,,,,,,,,,,,,, -292100,0.23103915,1.4631488,,,,,,,,,,,,,,,,, -292200,0.23922797,1.4078672,,,,,,,,,,,,,,,,, -292300,0.22508499,1.4703693,,,,,,,,,,,,,,,,, -292400,0.23423362,1.4081397,,,,,,,,,,,,,,,,, -292500,0.24188785,1.4651268,,,,,,,,,,,,,,,,, -292600,0.23349863,1.4615579,,,,,,,,,,,,,,,,, -292700,0.24707057,1.4578502,,,,,,,,,,,,,,,,, -292800,0.2286215,1.4433765,,,,,,,,,,,,,,,,, -292820,,,0.695793867111206,1.368704319000244,35.59809594260861,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,104205.91893053056,178876.97866034508,104205.91893053056,74656.53228163719,5.5849597454071045,0.0 -292900,0.2349117,1.444319,,,,,,,,,,,,,,,,, -293000,0.23225825,1.4349992,,,,,,,,,,,,,,,,, -293100,0.22849008,1.4405949,,,,,,,,,,,,,,,,, -293200,0.23674646,1.4590385,,,,,,,,,,,,,,,,, -293300,0.23973104,1.4996985,,,,,,,,,,,,,,,,, -293400,0.22927204,1.4725178,,,,,,,,,,,,,,,,, -293500,0.22953866,1.4671317,,,,,,,,,,,,,,,,, -293600,0.22485825,1.4347135,,,,,,,,,,,,,,,,, -293700,0.22777204,1.4146806,,,,,,,,,,,,,,,,, -293800,0.2239167,1.3293701,,,,,,,,,,,,,,,,, -293900,0.24278207,1.4091924,,,,,,,,,,,,,,,,, -294000,0.22692515,1.4730005,,,,,,,,,,,,,,,,, -294100,0.2261136,1.4833761,,,,,,,,,,,,,,,,, -294200,0.23158939,1.4236509,,,,,,,,,,,,,,,,, -294300,0.22852954,1.3793935,,,,,,,,,,,,,,,,, -294400,0.23253283,1.4327673,,,,,,,,,,,,,,,,, -294500,0.23972468,1.4500896,,,,,,,,,,,,,,,,, -294600,0.2384171,1.4117476,,,,,,,,,,,,,,,,, -294700,0.23622271,1.477572,,,,,,,,,,,,,,,,, -294800,0.23625776,1.4347473,,,,,,,,,,,,,,,,, -294900,0.23127507,1.4403353,,,,,,,,,,,,,,,,, -295000,0.23072496,1.420369,,,,,,,,,,,,,,,,, -295100,0.23259284,1.5016809,,,,,,,,,,,,,,,,, -295180,,,0.6961166858673096,1.3674596548080444,35.2200033761678,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,105045.99169325829,180312.35103321075,105045.99169325829,75251.69767832756,5.647981643676758,0.0 -295200,0.233769,1.4703867,,,,,,,,,,,,,,,,, -295300,0.24359407,1.5229595,,,,,,,,,,,,,,,,, -295400,0.24659277,1.4286637,,,,,,,,,,,,,,,,, -295500,0.22958028,1.4441283,,,,,,,,,,,,,,,,, -295600,0.23534527,1.4296906,,,,,,,,,,,,,,,,, -295700,0.23749585,1.445266,,,,,,,,,,,,,,,,, -295800,0.2408117,1.4624,,,,,,,,,,,,,,,,, -295900,0.24260156,1.4007027,,,,,,,,,,,,,,,,, -296000,0.2277363,1.4758016,,,,,,,,,,,,,,,,, -296100,0.24332757,1.5234041,,,,,,,,,,,,,,,,, -296200,0.23855084,1.5192871,,,,,,,,,,,,,,,,, -296300,0.22266358,1.4153614,,,,,,,,,,,,,,,,, -296400,0.2255638,1.4305943,,,,,,,,,,,,,,,,, -296500,0.23807171,1.3929045,,,,,,,,,,,,,,,,, -296600,0.22869837,1.386533,,,,,,,,,,,,,,,,, -296700,0.22461915,1.445673,,,,,,,,,,,,,,,,, -296800,0.23765439,1.5103487,,,,,,,,,,,,,,,,, -296900,0.23265907,1.4546667,,,,,,,,,,,,,,,,, -297000,0.23681623,1.4490969,,,,,,,,,,,,,,,,, -297100,0.23324135,1.407977,,,,,,,,,,,,,,,,, -297200,0.23831798,1.4365258,,,,,,,,,,,,,,,,, -297300,0.23379073,1.4035411,,,,,,,,,,,,,,,,, -297400,0.23589003,1.4640849,,,,,,,,,,,,,,,,, -297500,0.23486036,1.5373946,,,,,,,,,,,,,,,,, -297540,,,0.6968651413917542,1.370174765586853,35.314651608415545,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,105885.86472034454,181754.6443226337,105885.86472034454,75853.98563599586,5.709059000015259,0.0 -297600,0.2431402,1.4624597,,,,,,,,,,,,,,,,, -297700,0.23555139,1.4792691,,,,,,,,,,,,,,,,, -297800,0.23267859,1.4102906,,,,,,,,,,,,,,,,, -297900,0.2311428,1.4382598,,,,,,,,,,,,,,,,, -298000,0.23870523,1.4061066,,,,,,,,,,,,,,,,, -298100,0.22567253,1.4109885,,,,,,,,,,,,,,,,, -298200,0.2372241,1.4587182,,,,,,,,,,,,,,,,, -298300,0.22988023,1.4202983,,,,,,,,,,,,,,,,, -298400,0.23556635,1.3846662,,,,,,,,,,,,,,,,, -298500,0.23217714,1.4197063,,,,,,,,,,,,,,,,, -298600,0.2334041,1.4743481,,,,,,,,,,,,,,,,, -298700,0.22500142,1.4745083,,,,,,,,,,,,,,,,, -298800,0.23501827,1.4355526,,,,,,,,,,,,,,,,, -298900,0.2447416,1.376084,,,,,,,,,,,,,,,,, -299000,0.22928116,1.4528089,,,,,,,,,,,,,,,,, -299100,0.23261485,1.4481876,,,,,,,,,,,,,,,,, -299200,0.2405741,1.5073347,,,,,,,,,,,,,,,,, -299300,0.23893268,1.4847274,,,,,,,,,,,,,,,,, -299400,0.23040529,1.4606162,,,,,,,,,,,,,,,,, -299500,0.24108411,1.46984,,,,,,,,,,,,,,,,, -299600,0.23529616,1.5244421,,,,,,,,,,,,,,,,, -299700,0.23567125,1.4844902,,,,,,,,,,,,,,,,, -299800,0.23945247,1.4228033,,,,,,,,,,,,,,,,, -299900,,,0.6927401423454285,1.3830639123916626,35.545994958747805,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,106725.71911907196,183197.226900816,106725.71911907196,76456.5664255619,5.785603046417236,0.0 -299900,0.22417447,1.3916979,,,,,,,,,,,,,,,,, -300000,0.2303661,1.4427097,,,,,,,,,,,,,,,,, -300100,0.24060291,1.3995179,,,,,,,,,,,,,,,,, -300200,0.2269068,1.4032706,,,,,,,,,,,,,,,,, -300300,0.2475867,1.5177866,,,,,,,,,,,,,,,,, -300400,0.21881211,1.4076573,,,,,,,,,,,,,,,,, -300500,0.23532473,1.4293774,,,,,,,,,,,,,,,,, -300600,0.22704682,1.3637685,,,,,,,,,,,,,,,,, -300700,0.22974609,1.476879,,,,,,,,,,,,,,,,, -300800,0.2274438,1.4231776,,,,,,,,,,,,,,,,, -300900,0.2298013,1.4688324,,,,,,,,,,,,,,,,, -301000,0.24829745,1.5211608,,,,,,,,,,,,,,,,, -301100,0.22915822,1.4546475,,,,,,,,,,,,,,,,, -301200,0.23244148,1.4712509,,,,,,,,,,,,,,,,, -301300,0.22978629,1.4792305,,,,,,,,,,,,,,,,, -301400,0.2338946,1.4447868,,,,,,,,,,,,,,,,, -301500,0.23974022,1.4599274,,,,,,,,,,,,,,,,, -301600,0.23470199,1.4345385,,,,,,,,,,,,,,,,, -301700,0.23110941,1.4626794,,,,,,,,,,,,,,,,, -301800,0.24455458,1.5076907,,,,,,,,,,,,,,,,, -301900,0.24571018,1.4270641,,,,,,,,,,,,,,,,, -302000,0.22647202,1.3502795,,,,,,,,,,,,,,,,, -302100,0.23571843,1.4805908,,,,,,,,,,,,,,,,, -302200,0.2332906,1.4825757,,,,,,,,,,,,,,,,, -302260,,,0.6952672600746155,1.3717122077941897,35.55978230786836,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,107565.61857748032,184634.89792728424,107565.61857748032,77054.19925522804,5.850619554519653,0.0 -302300,0.23017114,1.4847567,,,,,,,,,,,,,,,,, -302400,0.24560715,1.4099927,,,,,,,,,,,,,,,,, -302500,0.2385364,1.4446836,,,,,,,,,,,,,,,,, -302600,0.24080572,1.4856685,,,,,,,,,,,,,,,,, -302700,0.23663843,1.4641212,,,,,,,,,,,,,,,,, -302800,0.2356582,1.5079182,,,,,,,,,,,,,,,,, -302900,0.2373397,1.518177,,,,,,,,,,,,,,,,, -303000,0.23060559,1.3468237,,,,,,,,,,,,,,,,, -303100,0.2335418,1.503443,,,,,,,,,,,,,,,,, -303200,0.6184576,1.420219,,,,,,,,,,,,,,,,, -303300,0.23938987,1.5270867,,,,,,,,,,,,,,,,, -303400,0.24006142,1.4582658,,,,,,,,,,,,,,,,, -303500,0.23213486,1.4786572,,,,,,,,,,,,,,,,, -303600,0.22402668,1.3963346,,,,,,,,,,,,,,,,, -303700,0.24034415,1.4242631,,,,,,,,,,,,,,,,, -303800,0.23614493,1.3914475,,,,,,,,,,,,,,,,, -303900,0.24473551,1.3783371,,,,,,,,,,,,,,,,, -304000,0.23560148,1.4434677,,,,,,,,,,,,,,,,, -304100,0.22703223,1.4537919,,,,,,,,,,,,,,,,, -304200,0.23664409,1.3968904,,,,,,,,,,,,,,,,, -304300,0.22656664,1.4308062,,,,,,,,,,,,,,,,, -304400,0.23096104,1.4242293,,,,,,,,,,,,,,,,, -304500,0.23631968,1.4510374,,,,,,,,,,,,,,,,, -304600,0.24363483,1.4701087,,,,,,,,,,,,,,,,, -304620,,,0.696880578994751,1.362842679023743,34.97261335988815,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,108405.63617539406,186072.84896922112,108405.63617539406,77651.99628734589,5.915649890899658,0.0 -304700,0.23232386,1.4668115,,,,,,,,,,,,,,,,, -304800,0.2314631,1.4346038,,,,,,,,,,,,,,,,, -304900,0.22480516,1.4506207,,,,,,,,,,,,,,,,, -305000,0.23492946,1.4766858,,,,,,,,,,,,,,,,, -305100,0.24234149,1.4544382,,,,,,,,,,,,,,,,, -305200,0.23529606,1.4570671,,,,,,,,,,,,,,,,, -305300,0.2288333,1.4470931,,,,,,,,,,,,,,,,, -305400,0.24241664,1.4838041,,,,,,,,,,,,,,,,, -305500,0.23620468,1.4516898,,,,,,,,,,,,,,,,, -305600,0.23962134,1.4204645,,,,,,,,,,,,,,,,, -305700,0.23248024,1.410428,,,,,,,,,,,,,,,,, -305800,0.22566102,1.4056729,,,,,,,,,,,,,,,,, -305900,0.2380202,1.431199,,,,,,,,,,,,,,,,, -306000,0.2302773,1.4998088,,,,,,,,,,,,,,,,, -306100,0.22700746,1.4151926,,,,,,,,,,,,,,,,, -306200,0.22259353,1.4292481,,,,,,,,,,,,,,,,, -306300,0.22509396,1.4419453,,,,,,,,,,,,,,,,, -306400,0.23002638,1.3526696,,,,,,,,,,,,,,,,, -306500,0.23770984,1.5144581,,,,,,,,,,,,,,,,, -306600,0.23921104,1.5167022,,,,,,,,,,,,,,,,, -306700,0.238492,1.3661087,,,,,,,,,,,,,,,,, -306800,0.23672202,1.4913236,,,,,,,,,,,,,,,,, -306900,0.22265296,1.3964876,,,,,,,,,,,,,,,,, -306979,,,0.6959228515625,1.370876431465149,35.62716315697208,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,109245.5084617138,187512.7414045334,109245.5084617138,78251.87782359123,5.980376243591309,0.0 -307000,0.23015267,1.4118195,,,,,,,,,,,,,,,,, -307100,0.23459364,1.3989744,,,,,,,,,,,,,,,,, -307200,0.23420273,1.4571178,,,,,,,,,,,,,,,,, -307300,0.23345438,1.4765158,,,,,,,,,,,,,,,,, -307400,0.23282188,1.4224494,,,,,,,,,,,,,,,,, -307500,0.24017297,1.439677,,,,,,,,,,,,,,,,, -307600,0.23656876,1.4527895,,,,,,,,,,,,,,,,, -307700,0.23912333,1.472115,,,,,,,,,,,,,,,,, -307800,0.232918,1.4638103,,,,,,,,,,,,,,,,, -307900,0.23974988,1.3990781,,,,,,,,,,,,,,,,, -308000,0.2433227,1.4727106,,,,,,,,,,,,,,,,, -308100,0.23504326,1.4775575,,,,,,,,,,,,,,,,, -308200,0.22514817,1.363766,,,,,,,,,,,,,,,,, -308300,0.23101807,1.4051386,,,,,,,,,,,,,,,,, -308400,0.22695413,1.4178522,,,,,,,,,,,,,,,,, -308500,0.24462786,1.4997741,,,,,,,,,,,,,,,,, -308600,0.23380779,1.4222122,,,,,,,,,,,,,,,,, -308700,0.23255025,1.3982972,,,,,,,,,,,,,,,,, -308800,0.2343562,1.4537311,,,,,,,,,,,,,,,,, -308900,0.23342422,1.5302888,,,,,,,,,,,,,,,,, -309000,0.2402056,1.4767948,,,,,,,,,,,,,,,,, -309100,0.23434348,1.464673,,,,,,,,,,,,,,,,, -309200,0.237226,1.46858,,,,,,,,,,,,,,,,, -309300,0.24300279,1.4433727,,,,,,,,,,,,,,,,, -309339,,,0.6993377804756165,1.3491086959838867,35.87252412013725,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,110085.39775514604,188966.4353232384,110085.39775514604,78865.54597783089,6.046472072601318,0.0 -309400,0.23249386,1.4609971,,,,,,,,,,,,,,,,, -309500,0.23964825,1.4899657,,,,,,,,,,,,,,,,, -309600,0.22319342,1.3597779,,,,,,,,,,,,,,,,, -309700,0.23363663,1.5234938,,,,,,,,,,,,,,,,, -309800,0.22847693,1.3641672,,,,,,,,,,,,,,,,, -309900,0.22808862,1.5155385,,,,,,,,,,,,,,,,, -310000,0.22685668,1.3795745,,,,,,,,,,,,,,,,, -310100,0.23363844,1.4606184,,,,,,,,,,,,,,,,, -310200,0.23993774,1.4215902,,,,,,,,,,,,,,,,, -310300,0.23827748,1.3862327,,,,,,,,,,,,,,,,, -310400,0.24313389,1.397842,,,,,,,,,,,,,,,,, -310500,0.23553343,1.4192139,,,,,,,,,,,,,,,,, -310600,0.24039844,1.4496382,,,,,,,,,,,,,,,,, -310700,0.22366378,1.4551388,,,,,,,,,,,,,,,,, -310800,0.22468032,1.4097809,,,,,,,,,,,,,,,,, -310900,0.2385322,1.4942194,,,,,,,,,,,,,,,,, -311000,0.24420309,1.4636413,,,,,,,,,,,,,,,,, -311100,0.23951678,1.4041754,,,,,,,,,,,,,,,,, -311200,0.24003287,1.5138534,,,,,,,,,,,,,,,,, -311300,0.24052502,1.4394537,,,,,,,,,,,,,,,,, -311400,0.23267092,1.4408324,,,,,,,,,,,,,,,,, -311500,0.23886184,1.4923776,,,,,,,,,,,,,,,,, -311600,0.23947772,1.5009723,,,,,,,,,,,,,,,,, -311699,,,0.6996234059333801,1.3519960641860962,35.82919492644673,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,110925.37493872644,190402.58885407448,110925.37493872644,79461.58521485329,6.110290288925171,0.0 -311700,0.22530985,1.381182,,,,,,,,,,,,,,,,, -311800,0.23373005,1.5382146,,,,,,,,,,,,,,,,, -311900,0.22837304,1.496507,,,,,,,,,,,,,,,,, -312000,0.24080947,1.4953432,,,,,,,,,,,,,,,,, -312100,0.23223257,1.4560362,,,,,,,,,,,,,,,,, -312200,0.2391126,1.4871205,,,,,,,,,,,,,,,,, -312300,0.23816514,1.48389,,,,,,,,,,,,,,,,, -312400,0.24447435,1.4270755,,,,,,,,,,,,,,,,, -312500,0.23333304,1.5186881,,,,,,,,,,,,,,,,, -312600,0.2286596,1.4867853,,,,,,,,,,,,,,,,, -312700,0.2311534,1.4534116,,,,,,,,,,,,,,,,, -312800,0.22229776,1.4001845,,,,,,,,,,,,,,,,, -312900,0.24655922,1.5131687,,,,,,,,,,,,,,,,, -313000,0.23291335,1.3959786,,,,,,,,,,,,,,,,, -313100,0.24680687,1.4546782,,,,,,,,,,,,,,,,, -313200,0.2274435,1.4645953,,,,,,,,,,,,,,,,, -313300,0.24550834,1.5288146,,,,,,,,,,,,,,,,, -313400,0.22801147,1.4644364,,,,,,,,,,,,,,,,, -313500,0.23952526,1.5426081,,,,,,,,,,,,,,,,, -313600,0.23840806,1.4345305,,,,,,,,,,,,,,,,, -313700,0.23179357,1.4276602,,,,,,,,,,,,,,,,, -313800,0.2255298,1.4251903,,,,,,,,,,,,,,,,, -313900,0.24676995,1.523268,,,,,,,,,,,,,,,,, -314000,0.24064715,1.4490254,,,,,,,,,,,,,,,,, -314060,,,0.6977344155311584,1.3559484481811523,35.35065572173739,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,111765.5486676693,191846.70589494705,111765.5486676693,80065.39432311058,6.174302101135254,0.0 -314100,0.23728895,1.4348656,,,,,,,,,,,,,,,,, -314200,0.22230071,1.3828493,,,,,,,,,,,,,,,,, -314300,0.23582414,1.3922927,,,,,,,,,,,,,,,,, -314400,0.24825574,1.4085056,,,,,,,,,,,,,,,,, -314500,0.22272317,1.405268,,,,,,,,,,,,,,,,, -314600,0.24011336,1.5368451,,,,,,,,,,,,,,,,, -314700,0.23722357,1.422201,,,,,,,,,,,,,,,,, -314800,0.23416182,1.4220369,,,,,,,,,,,,,,,,, -314900,0.223649,1.3845873,,,,,,,,,,,,,,,,, -315000,0.23435445,1.4869747,,,,,,,,,,,,,,,,, -315100,0.23203489,1.4558531,,,,,,,,,,,,,,,,, -315200,0.23192088,1.407519,,,,,,,,,,,,,,,,, -315300,0.2316202,1.4609524,,,,,,,,,,,,,,,,, -315400,0.23112631,1.4569755,,,,,,,,,,,,,,,,, -315500,0.23632693,1.421239,,,,,,,,,,,,,,,,, -315600,0.23650585,1.4882587,,,,,,,,,,,,,,,,, -315700,0.23221487,1.4038162,,,,,,,,,,,,,,,,, -315800,0.23013137,1.4453288,,,,,,,,,,,,,,,,, -315900,0.22198898,1.4068098,,,,,,,,,,,,,,,,, -316000,0.23316972,1.4191376,,,,,,,,,,,,,,,,, -316100,0.23934968,1.4692864,,,,,,,,,,,,,,,,, -316200,0.2469441,1.4651252,,,,,,,,,,,,,,,,, -316300,0.23696019,1.4169985,,,,,,,,,,,,,,,,, -316400,0.22354771,1.3973557,,,,,,,,,,,,,,,,, -316420,,,0.6962667107582092,1.3646796941757202,35.69826971741045,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,112605.41333723068,193297.88295459747,112605.41333723068,80676.55667924881,6.251188278198242,0.0 -316500,0.22834511,1.4580735,,,,,,,,,,,,,,,,, -316600,0.23705044,1.4355291,,,,,,,,,,,,,,,,, -316700,0.23498955,1.4437685,,,,,,,,,,,,,,,,, -316800,0.22913492,1.4731333,,,,,,,,,,,,,,,,, -316900,0.22735943,1.4461552,,,,,,,,,,,,,,,,, -317000,0.23010822,1.4405755,,,,,,,,,,,,,,,,, -317100,0.23578146,1.5278789,,,,,,,,,,,,,,,,, -317200,0.23985487,1.5349088,,,,,,,,,,,,,,,,, -317300,0.23134738,1.3528455,,,,,,,,,,,,,,,,, -317400,0.23493809,1.4321283,,,,,,,,,,,,,,,,, -317500,0.23370975,1.4229715,,,,,,,,,,,,,,,,, -317600,0.23567443,1.4411415,,,,,,,,,,,,,,,,, -317700,0.2365677,1.4391016,,,,,,,,,,,,,,,,, -317800,0.22953536,1.3731698,,,,,,,,,,,,,,,,, -317900,0.2341359,1.3747809,,,,,,,,,,,,,,,,, -318000,0.23031126,1.4687091,,,,,,,,,,,,,,,,, -318100,0.23175465,1.4172813,,,,,,,,,,,,,,,,, -318200,0.23474163,1.5190371,,,,,,,,,,,,,,,,, -318300,0.22995387,1.4446206,,,,,,,,,,,,,,,,, -318400,0.23838185,1.4268419,,,,,,,,,,,,,,,,, -318500,0.23663877,1.5385551,,,,,,,,,,,,,,,,, -318600,0.23726982,1.4056689,,,,,,,,,,,,,,,,, -318700,0.22880667,1.4661494,,,,,,,,,,,,,,,,, -318780,,,0.6989819407463074,1.3501715660095217,35.58819065628888,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,113445.30772137642,194738.9902229309,113445.30772137642,81277.62100315094,6.328147649765015,0.0 -318800,0.23962888,1.4062934,,,,,,,,,,,,,,,,, -318900,0.23811477,1.4181331,,,,,,,,,,,,,,,,, -319000,0.2466456,1.478653,,,,,,,,,,,,,,,,, -319100,0.24411023,1.4450871,,,,,,,,,,,,,,,,, -319200,0.24257813,1.4405894,,,,,,,,,,,,,,,,, -319300,0.23324896,1.3929316,,,,,,,,,,,,,,,,, -319400,0.23143773,1.3949726,,,,,,,,,,,,,,,,, -319500,0.23600686,1.4101425,,,,,,,,,,,,,,,,, -319600,0.2472896,1.5011578,,,,,,,,,,,,,,,,, -319700,0.23465772,1.453622,,,,,,,,,,,,,,,,, -319800,0.23490222,1.4813454,,,,,,,,,,,,,,,,, -319900,0.23040293,1.3999369,,,,,,,,,,,,,,,,, -320000,0.24379326,1.4398946,,,,,,,,,,,,,,,,, -320100,0.23072052,1.423917,,,,,,,,,,,,,,,,, -320200,0.23709512,1.4419155,,,,,,,,,,,,,,,,, -320300,0.22647159,1.4569628,,,,,,,,,,,,,,,,, -320400,0.23736693,1.4569935,,,,,,,,,,,,,,,,, -320500,0.23074943,1.4367819,,,,,,,,,,,,,,,,, -320600,0.2288223,1.3772008,,,,,,,,,,,,,,,,, -320700,0.2302674,1.4533554,,,,,,,,,,,,,,,,, -320800,0.238002,1.4148456,,,,,,,,,,,,,,,,, -320900,0.22594294,1.3994731,,,,,,,,,,,,,,,,, -321000,0.24153769,1.417275,,,,,,,,,,,,,,,,, -321100,0.23191571,1.4213873,,,,,,,,,,,,,,,,, -321140,,,0.6966158151626587,1.3639127016067505,35.62989032088873,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,114285.18418955804,196182.6717071533,114285.18418955804,81881.28680968285,6.394540309906006,0.0 -321200,0.23004337,1.4473748,,,,,,,,,,,,,,,,, -321300,0.23826148,1.4539274,,,,,,,,,,,,,,,,, -321400,0.22053002,1.3431613,,,,,,,,,,,,,,,,, -321500,0.22975132,1.4233794,,,,,,,,,,,,,,,,, -321600,0.23708385,1.3892924,,,,,,,,,,,,,,,,, -321700,0.24638017,1.4230894,,,,,,,,,,,,,,,,, -321800,0.22939526,1.4586825,,,,,,,,,,,,,,,,, -321900,0.22562663,1.3680536,,,,,,,,,,,,,,,,, -322000,0.23209135,1.4102461,,,,,,,,,,,,,,,,, -322100,0.23248482,1.4332409,,,,,,,,,,,,,,,,, -322200,0.24386784,1.493192,,,,,,,,,,,,,,,,, -322300,0.23851068,1.3541218,,,,,,,,,,,,,,,,, -322400,0.23141947,1.4560403,,,,,,,,,,,,,,,,, -322500,0.22989523,1.4562967,,,,,,,,,,,,,,,,, -322600,0.23105769,1.3911036,,,,,,,,,,,,,,,,, -322700,0.24810937,1.5101147,,,,,,,,,,,,,,,,, -322800,0.24569674,1.448844,,,,,,,,,,,,,,,,, -322900,0.23799075,1.4994636,,,,,,,,,,,,,,,,, -323000,0.2389106,1.4669546,,,,,,,,,,,,,,,,, -323100,0.23771292,1.4588549,,,,,,,,,,,,,,,,, -323200,0.23348168,1.3953831,,,,,,,,,,,,,,,,, -323300,0.23641363,1.4672227,,,,,,,,,,,,,,,,, -323400,0.22932921,1.4254965,,,,,,,,,,,,,,,,, -323500,0.2421274,1.4552943,,,,,,,,,,,,,,,,, -323501,,,0.6984859704971313,1.3505821228027344,35.67098677731126,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,115125.39649248125,197639.3575992584,115125.39649248125,82497.62356734276,6.459453105926514,0.0 -323600,0.23037763,1.4443464,,,,,,,,,,,,,,,,, -323700,0.22953813,1.4122374,,,,,,,,,,,,,,,,, -323800,0.24729721,1.448583,,,,,,,,,,,,,,,,, -323900,0.23772521,1.4635607,,,,,,,,,,,,,,,,, -324000,0.22964251,1.4710703,,,,,,,,,,,,,,,,, -324100,0.24309407,1.5106846,,,,,,,,,,,,,,,,, -324200,0.23846975,1.4191773,,,,,,,,,,,,,,,,, -324300,0.22723766,1.3372238,,,,,,,,,,,,,,,,, -324400,0.23368993,1.4401817,,,,,,,,,,,,,,,,, -324500,0.24099608,1.4701933,,,,,,,,,,,,,,,,, -324600,0.22332674,1.4204704,,,,,,,,,,,,,,,,, -324700,0.23053923,1.4692535,,,,,,,,,,,,,,,,, -324800,0.24016783,1.4527029,,,,,,,,,,,,,,,,, -324900,0.23865345,1.4844667,,,,,,,,,,,,,,,,, -325000,0.22751632,1.3871995,,,,,,,,,,,,,,,,, -325100,0.24386607,1.4326117,,,,,,,,,,,,,,,,, -325200,0.232468,1.5064251,,,,,,,,,,,,,,,,, -325300,0.22790785,1.3485804,,,,,,,,,,,,,,,,, -325400,0.22812973,1.4397516,,,,,,,,,,,,,,,,, -325500,0.2312821,1.442711,,,,,,,,,,,,,,,,, -325600,0.23593666,1.4362538,,,,,,,,,,,,,,,,, -325700,0.23314483,1.4105221,,,,,,,,,,,,,,,,, -325800,0.24402726,1.5015562,,,,,,,,,,,,,,,,, -325862,,,0.6952317357063293,1.3720507621765137,35.218722071193795,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,115965.46018266678,199089.3668017388,115965.46018266678,83107.43023467064,6.526806592941284,0.0 -325900,0.24619824,1.4492861,,,,,,,,,,,,,,,,, -326000,0.2333632,1.4503379,,,,,,,,,,,,,,,,, -326100,0.23202835,1.4740682,,,,,,,,,,,,,,,,, -326200,0.22610915,1.4129969,,,,,,,,,,,,,,,,, -326300,0.23756312,1.4264318,,,,,,,,,,,,,,,,, -326400,0.2291553,1.4121234,,,,,,,,,,,,,,,,, -326500,0.23667514,1.4779872,,,,,,,,,,,,,,,,, -326600,0.23470885,1.4365754,,,,,,,,,,,,,,,,, -326700,0.2355013,1.4403296,,,,,,,,,,,,,,,,, -326800,0.23362434,1.4461688,,,,,,,,,,,,,,,,, -326900,0.23855338,1.4424928,,,,,,,,,,,,,,,,, -327000,0.23474604,1.4664063,,,,,,,,,,,,,,,,, -327100,0.23415537,1.457292,,,,,,,,,,,,,,,,, -327200,0.22575277,1.4527636,,,,,,,,,,,,,,,,, -327300,0.23640223,1.3965153,,,,,,,,,,,,,,,,, -327400,0.23836707,1.4348705,,,,,,,,,,,,,,,,, -327500,0.23234098,1.4333879,,,,,,,,,,,,,,,,, -327600,0.23316988,1.4438684,,,,,,,,,,,,,,,,, -327700,0.2310566,1.4009029,,,,,,,,,,,,,,,,, -327800,0.23544654,1.463812,,,,,,,,,,,,,,,,, -327900,0.23744969,1.4449728,,,,,,,,,,,,,,,,, -328000,0.23576409,1.4502144,,,,,,,,,,,,,,,,, -328100,0.22596166,1.445301,,,,,,,,,,,,,,,,, -328200,0.22318937,1.3473153,,,,,,,,,,,,,,,,, -328223,,,0.6982913613319397,1.3581576347351074,35.58612960215996,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,116805.6553747654,200539.21746993065,116805.6553747654,83716.94771122932,6.594550371170044,0.0 -328300,0.2392639,1.5292768,,,,,,,,,,,,,,,,, -328400,0.23719922,1.4527742,,,,,,,,,,,,,,,,, -328500,0.22868113,1.4157813,,,,,,,,,,,,,,,,, -328600,0.22251624,1.3567374,,,,,,,,,,,,,,,,, -328700,0.23597467,1.4412059,,,,,,,,,,,,,,,,, -328800,0.23411244,1.4781464,,,,,,,,,,,,,,,,, -328900,0.23165521,1.5133816,,,,,,,,,,,,,,,,, -329000,0.240073,1.4240024,,,,,,,,,,,,,,,,, -329100,0.2278508,1.3974658,,,,,,,,,,,,,,,,, -329200,0.22290908,1.4070969,,,,,,,,,,,,,,,,, -329300,0.24401231,1.5059894,,,,,,,,,,,,,,,,, -329400,0.24933228,1.4623334,,,,,,,,,,,,,,,,, -329500,0.23546843,1.4679394,,,,,,,,,,,,,,,,, -329600,0.2350974,1.4642587,,,,,,,,,,,,,,,,, -329700,0.22617914,1.4552255,,,,,,,,,,,,,,,,, -329800,0.23744927,1.4544573,,,,,,,,,,,,,,,,, -329900,0.23513135,1.4118323,,,,,,,,,,,,,,,,, -330000,0.23951928,1.4299481,,,,,,,,,,,,,,,,, -330100,0.2347021,1.4384462,,,,,,,,,,,,,,,,, -330200,0.23784067,1.4299842,,,,,,,,,,,,,,,,, -330300,0.23835588,1.484966,,,,,,,,,,,,,,,,, -330400,0.22895475,1.4375192,,,,,,,,,,,,,,,,, -330500,0.23105575,1.4433428,,,,,,,,,,,,,,,,, -330584,,,0.6947457790374756,1.3779659271240234,35.72883824812817,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,117645.53719234468,201982.5979168415,117645.53719234468,84320.30842804909,6.6605446338653564,0.0 -330600,0.23316188,1.3553151,,,,,,,,,,,,,,,,, -330700,0.23678593,1.432746,,,,,,,,,,,,,,,,, -330800,0.23770043,1.4986786,,,,,,,,,,,,,,,,, -330900,0.22381535,1.4524691,,,,,,,,,,,,,,,,, -331000,0.2317111,1.4049466,,,,,,,,,,,,,,,,, -331100,0.23553081,1.4448599,,,,,,,,,,,,,,,,, -331200,0.23167646,1.4760814,,,,,,,,,,,,,,,,, -331300,0.22909418,1.4346399,,,,,,,,,,,,,,,,, -331400,0.24035758,1.5269902,,,,,,,,,,,,,,,,, -331500,0.24420743,1.4702302,,,,,,,,,,,,,,,,, -331600,0.5629349,1.4324117,,,,,,,,,,,,,,,,, -331700,0.24283805,1.5382377,,,,,,,,,,,,,,,,, -331800,0.23439227,1.4078733,,,,,,,,,,,,,,,,, -331900,0.23642877,1.4125861,,,,,,,,,,,,,,,,, -332000,0.24234095,1.4540657,,,,,,,,,,,,,,,,, -332100,0.23547985,1.4427449,,,,,,,,,,,,,,,,, -332200,0.23518734,1.4425473,,,,,,,,,,,,,,,,, -332300,0.23384361,1.4938565,,,,,,,,,,,,,,,,, -332400,0.23355277,1.4495053,,,,,,,,,,,,,,,,, -332500,0.24708171,1.4589691,,,,,,,,,,,,,,,,, -332600,0.23711373,1.4571533,,,,,,,,,,,,,,,,, -332700,0.23839045,1.4413881,,,,,,,,,,,,,,,,, -332800,0.2292375,1.4530188,,,,,,,,,,,,,,,,, -332900,0.23053882,1.478805,,,,,,,,,,,,,,,,, -332945,,,0.6956863403320312,1.3694723844528198,35.02622983467754,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,118485.5244398117,203421.36900758743,118485.5244398117,84918.95106649399,6.728163719177246,0.0 -333000,0.22170827,1.426188,,,,,,,,,,,,,,,,, -333100,0.23957358,1.4608659,,,,,,,,,,,,,,,,, -333200,0.23597106,1.400024,,,,,,,,,,,,,,,,, -333300,0.2415485,1.5015554,,,,,,,,,,,,,,,,, -333400,0.23296344,1.4625348,,,,,,,,,,,,,,,,, -333500,0.22870545,1.466836,,,,,,,,,,,,,,,,, -333600,0.23787145,1.4378729,,,,,,,,,,,,,,,,, -333700,0.23083824,1.4209392,,,,,,,,,,,,,,,,, -333800,0.2404347,1.4231526,,,,,,,,,,,,,,,,, -333900,0.23116368,1.4279317,,,,,,,,,,,,,,,,, -334000,0.23777111,1.4891678,,,,,,,,,,,,,,,,, -334100,0.22381838,1.3701912,,,,,,,,,,,,,,,,, -334200,0.23702523,1.439253,,,,,,,,,,,,,,,,, -334300,0.22885281,1.4793606,,,,,,,,,,,,,,,,, -334400,0.23591723,1.454129,,,,,,,,,,,,,,,,, -334500,0.2335907,1.4179827,,,,,,,,,,,,,,,,, -334600,0.2322904,1.5070996,,,,,,,,,,,,,,,,, -334700,0.2432192,1.527443,,,,,,,,,,,,,,,,, -334800,0.2312207,1.4600409,,,,,,,,,,,,,,,,, -334900,0.24958625,1.4169378,,,,,,,,,,,,,,,,, -335000,0.22781388,1.3754112,,,,,,,,,,,,,,,,, -335100,0.2322661,1.4403923,,,,,,,,,,,,,,,,, -335200,0.22877234,1.4643707,,,,,,,,,,,,,,,,, -335300,0.22935522,1.388038,,,,,,,,,,,,,,,,, -335305,,,0.695704460144043,1.3688232898712158,35.480177783502256,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,119325.42063117027,204863.69330143929,119325.42063117027,85521.2403678894,6.795787334442139,0.0 -335400,0.24077383,1.4941062,,,,,,,,,,,,,,,,, -335500,0.23724759,1.3720812,,,,,,,,,,,,,,,,, -335600,0.22996543,1.3949575,,,,,,,,,,,,,,,,, -335700,0.23533575,1.4643352,,,,,,,,,,,,,,,,, -335800,0.23865744,1.5082225,,,,,,,,,,,,,,,,, -335900,0.23112045,1.4532137,,,,,,,,,,,,,,,,, -336000,0.24566296,1.4741793,,,,,,,,,,,,,,,,, -336100,0.22845298,1.4067991,,,,,,,,,,,,,,,,, -336200,0.23013987,1.3731751,,,,,,,,,,,,,,,,, -336300,0.22834963,1.4295716,,,,,,,,,,,,,,,,, -336400,0.24196368,1.4595164,,,,,,,,,,,,,,,,, -336500,0.22769669,1.3762883,,,,,,,,,,,,,,,,, -336600,0.22932938,1.3646382,,,,,,,,,,,,,,,,, -336700,0.22814475,1.4126003,,,,,,,,,,,,,,,,, -336800,0.23333652,1.3885036,,,,,,,,,,,,,,,,, -336900,0.23017822,1.4891597,,,,,,,,,,,,,,,,, -337000,0.23113737,1.4405313,,,,,,,,,,,,,,,,, -337100,0.2267934,1.3958211,,,,,,,,,,,,,,,,, -337200,0.22364214,1.3629758,,,,,,,,,,,,,,,,, -337300,0.23610836,1.5148278,,,,,,,,,,,,,,,,, -337400,0.23865321,1.5147076,,,,,,,,,,,,,,,,, -337500,0.24143964,1.4126929,,,,,,,,,,,,,,,,, -337600,0.23812973,1.4996071,,,,,,,,,,,,,,,,, -337665,,,0.6960183382034302,1.3696424961090088,35.55197694210756,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,120165.28549027444,206303.11276698112,120165.28549027444,86120.65092658997,6.865016460418701,0.0 -337700,0.23742366,1.531357,,,,,,,,,,,,,,,,, -337800,0.23567316,1.422108,,,,,,,,,,,,,,,,, -337900,0.22893573,1.4298664,,,,,,,,,,,,,,,,, -338000,0.24325512,1.4319304,,,,,,,,,,,,,,,,, -338100,0.23253375,1.4227582,,,,,,,,,,,,,,,,, -338200,0.23847999,1.4874808,,,,,,,,,,,,,,,,, -338300,0.22991535,1.4381784,,,,,,,,,,,,,,,,, -338400,0.23764127,1.469989,,,,,,,,,,,,,,,,, -338500,0.23428252,1.4336367,,,,,,,,,,,,,,,,, -338600,0.23981352,1.4716283,,,,,,,,,,,,,,,,, -338700,0.21839197,1.4278455,,,,,,,,,,,,,,,,, -338800,0.23081553,1.4386352,,,,,,,,,,,,,,,,, -338900,0.23028654,1.3983572,,,,,,,,,,,,,,,,, -339000,0.23608679,1.4032353,,,,,,,,,,,,,,,,, -339100,0.23445314,1.4349011,,,,,,,,,,,,,,,,, -339200,0.22688851,1.3894471,,,,,,,,,,,,,,,,, -339300,0.24532989,1.5522624,,,,,,,,,,,,,,,,, -339400,0.23818162,1.459324,,,,,,,,,,,,,,,,, -339500,0.22681563,1.5287205,,,,,,,,,,,,,,,,, -339600,0.23758183,1.366743,,,,,,,,,,,,,,,,, -339700,0.23632194,1.4428132,,,,,,,,,,,,,,,,, -339800,0.22947782,1.478982,,,,,,,,,,,,,,,,, -339900,0.23653458,1.4874738,,,,,,,,,,,,,,,,, -340000,0.23189147,1.3883455,,,,,,,,,,,,,,,,, -340025,,,0.6936178803443909,1.38359534740448,35.268918831625335,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,121005.14982128143,207764.23371696472,121005.14982128143,86741.76490736008,6.934324264526367,0.0 -340100,0.23487258,1.3631544,,,,,,,,,,,,,,,,, -340200,0.23645462,1.4769307,,,,,,,,,,,,,,,,, -340300,0.22465687,1.3829374,,,,,,,,,,,,,,,,, -340400,0.23876181,1.4089092,,,,,,,,,,,,,,,,, -340500,0.23693442,1.4583476,,,,,,,,,,,,,,,,, -340600,0.23412842,1.4506493,,,,,,,,,,,,,,,,, -340700,0.23517765,1.4347663,,,,,,,,,,,,,,,,, -340800,0.23919708,1.422812,,,,,,,,,,,,,,,,, -340900,0.22935317,1.4221188,,,,,,,,,,,,,,,,, -341000,0.22815868,1.4366828,,,,,,,,,,,,,,,,, -341100,0.21731877,1.3926543,,,,,,,,,,,,,,,,, -341200,0.24339868,1.4908369,,,,,,,,,,,,,,,,, -341300,0.24022324,1.3924718,,,,,,,,,,,,,,,,, -341400,0.23305938,1.4513297,,,,,,,,,,,,,,,,, -341500,0.25000077,1.4634135,,,,,,,,,,,,,,,,, -341600,0.22991434,1.4630464,,,,,,,,,,,,,,,,, -341700,0.23615928,1.4299666,,,,,,,,,,,,,,,,, -341800,0.22988737,1.4202096,,,,,,,,,,,,,,,,, -341900,0.23155473,1.4151338,,,,,,,,,,,,,,,,, -342000,0.23358563,1.4856756,,,,,,,,,,,,,,,,, -342100,0.23621117,1.4380513,,,,,,,,,,,,,,,,, -342200,0.23340155,1.4753369,,,,,,,,,,,,,,,,, -342300,0.2321353,1.3900614,,,,,,,,,,,,,,,,, -342385,,,0.6969581842422485,1.3596528768539429,35.45381204623339,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,121845.14006876944,209211.1178805828,121845.14006876944,87348.51953816414,7.001046895980835,0.0 -342400,0.23901008,1.4307301,,,,,,,,,,,,,,,,, -342500,0.23001239,1.360381,,,,,,,,,,,,,,,,, -342600,0.23595327,1.5556265,,,,,,,,,,,,,,,,, -342700,0.23280235,1.4641988,,,,,,,,,,,,,,,,, -342800,0.23490153,1.4575162,,,,,,,,,,,,,,,,, -342900,0.23641233,1.4840066,,,,,,,,,,,,,,,,, -343000,0.23092952,1.4253061,,,,,,,,,,,,,,,,, -343100,0.24417572,1.428488,,,,,,,,,,,,,,,,, -343200,0.24609521,1.4576346,,,,,,,,,,,,,,,,, -343300,0.24656546,1.4191338,,,,,,,,,,,,,,,,, -343400,0.22839363,1.3892953,,,,,,,,,,,,,,,,, -343500,0.23052056,1.408777,,,,,,,,,,,,,,,,, -343600,0.22517031,1.4249575,,,,,,,,,,,,,,,,, -343700,0.24116044,1.4151587,,,,,,,,,,,,,,,,, -343800,0.2311087,1.4179449,,,,,,,,,,,,,,,,, -343900,0.22643867,1.3791465,,,,,,,,,,,,,,,,, -344000,0.22992316,1.4822413,,,,,,,,,,,,,,,,, -344100,0.22173275,1.4124401,,,,,,,,,,,,,,,,, -344200,0.23147686,1.4134331,,,,,,,,,,,,,,,,, -344300,0.23156512,1.4773984,,,,,,,,,,,,,,,,, -344400,0.22587638,1.525025,,,,,,,,,,,,,,,,, -344500,0.23104325,1.4976782,,,,,,,,,,,,,,,,, -344600,0.23854014,1.3603022,,,,,,,,,,,,,,,,, -344700,0.2312509,1.4362915,,,,,,,,,,,,,,,,, -344746,,,0.694857656955719,1.3747631311416626,35.20722528657935,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,122685.2105679512,210659.43799066544,122685.2105679512,87956.62901735306,7.070030212402344,0.0 -344800,0.23911293,1.4725118,,,,,,,,,,,,,,,,, -344900,0.22946982,1.4724903,,,,,,,,,,,,,,,,, -345000,0.23489653,1.4549619,,,,,,,,,,,,,,,,, -345100,0.2280906,1.3864504,,,,,,,,,,,,,,,,, -345200,0.22510397,1.4311593,,,,,,,,,,,,,,,,, -345300,0.24444577,1.4809595,,,,,,,,,,,,,,,,, -345400,0.23006411,1.4730283,,,,,,,,,,,,,,,,, -345500,0.24896435,1.5093721,,,,,,,,,,,,,,,,, -345600,0.23538864,1.4190456,,,,,,,,,,,,,,,,, -345700,0.23313019,1.4580747,,,,,,,,,,,,,,,,, -345800,0.23822543,1.5297792,,,,,,,,,,,,,,,,, -345900,0.22990905,1.4202023,,,,,,,,,,,,,,,,, -346000,0.23694177,1.4950212,,,,,,,,,,,,,,,,, -346100,0.23837954,1.4387724,,,,,,,,,,,,,,,,, -346200,0.22535993,1.4503059,,,,,,,,,,,,,,,,, -346300,0.23577519,1.4525449,,,,,,,,,,,,,,,,, -346400,0.24156533,1.4313886,,,,,,,,,,,,,,,,, -346500,0.22682592,1.3997604,,,,,,,,,,,,,,,,, -346600,0.24599083,1.478707,,,,,,,,,,,,,,,,, -346700,0.23115768,1.4876001,,,,,,,,,,,,,,,,, -346800,0.23304985,1.4550844,,,,,,,,,,,,,,,,, -346900,0.2424303,1.4710523,,,,,,,,,,,,,,,,, -347000,0.23416586,1.4278258,,,,,,,,,,,,,,,,, -347100,0.23320553,1.4322891,,,,,,,,,,,,,,,,, -347107,,,0.6944318413734436,1.3756641149520874,35.670044646723696,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,123525.30701708794,212110.61843252185,123525.30701708794,88567.57115674019,7.1397809982299805,0.0 -347200,0.23544687,1.4491469,,,,,,,,,,,,,,,,, -347300,0.22595724,1.4357103,,,,,,,,,,,,,,,,, -347400,0.23684686,1.4684471,,,,,,,,,,,,,,,,, -347500,0.23192684,1.474275,,,,,,,,,,,,,,,,, -347600,0.24178421,1.4189761,,,,,,,,,,,,,,,,, -347700,0.23844735,1.5057776,,,,,,,,,,,,,,,,, -347800,0.23018734,1.3696058,,,,,,,,,,,,,,,,, -347900,0.23861793,1.4890673,,,,,,,,,,,,,,,,, -348000,0.24408609,1.475205,,,,,,,,,,,,,,,,, -348100,0.23440945,1.4816027,,,,,,,,,,,,,,,,, -348200,0.23155834,1.4773383,,,,,,,,,,,,,,,,, -348300,0.22562936,1.4202962,,,,,,,,,,,,,,,,, -348400,0.23186253,1.4281783,,,,,,,,,,,,,,,,, -348500,0.23351316,1.3982216,,,,,,,,,,,,,,,,, -348600,0.23999684,1.458382,,,,,,,,,,,,,,,,, -348700,0.2488431,1.506206,,,,,,,,,,,,,,,,, -348800,0.23014781,1.4432238,,,,,,,,,,,,,,,,, -348900,0.23971303,1.4652618,,,,,,,,,,,,,,,,, -349000,0.22783814,1.4589955,,,,,,,,,,,,,,,,, -349100,0.22633539,1.3520242,,,,,,,,,,,,,,,,, -349200,0.23504373,1.386704,,,,,,,,,,,,,,,,, -349300,0.24924718,1.4927323,,,,,,,,,,,,,,,,, -349400,0.22704278,1.4386138,,,,,,,,,,,,,,,,, -349467,,,0.6982336044311523,1.3613815307617188,35.713977052649426,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,124365.3011341095,213546.4123835564,124365.3011341095,89163.2274723053,7.209471464157104,0.0 -349500,0.22601612,1.446612,,,,,,,,,,,,,,,,, -349600,0.22312813,1.4118396,,,,,,,,,,,,,,,,, -349700,0.23332791,1.4271032,,,,,,,,,,,,,,,,, -349800,0.22920758,1.4119835,,,,,,,,,,,,,,,,, -349900,0.23759502,1.5169953,,,,,,,,,,,,,,,,, -350000,0.23659293,1.4056455,,,,,,,,,,,,,,,,, -350100,0.23297444,1.4574264,,,,,,,,,,,,,,,,, -350200,0.23906827,1.3832041,,,,,,,,,,,,,,,,, -350300,0.24490228,1.5125809,,,,,,,,,,,,,,,,, -350400,0.23948918,1.4773053,,,,,,,,,,,,,,,,, -350500,0.2444974,1.4825836,,,,,,,,,,,,,,,,, -350600,0.23775871,1.4364222,,,,,,,,,,,,,,,,, -350700,0.23022228,1.4764408,,,,,,,,,,,,,,,,, -350800,0.2357885,1.4764652,,,,,,,,,,,,,,,,, -350900,0.2296909,1.4332467,,,,,,,,,,,,,,,,, -351000,0.29068613,1.4245292,,,,,,,,,,,,,,,,, -351100,0.23639916,1.4086802,,,,,,,,,,,,,,,,, -351200,0.24188693,1.5082649,,,,,,,,,,,,,,,,, -351300,0.22759981,1.4213644,,,,,,,,,,,,,,,,, -351400,0.23232286,1.4839578,,,,,,,,,,,,,,,,, -351500,0.23765713,1.4795499,,,,,,,,,,,,,,,,, -351600,0.23242012,1.4232537,,,,,,,,,,,,,,,,, -351700,0.23550855,1.483981,,,,,,,,,,,,,,,,, -351800,0.22356768,1.4369867,,,,,,,,,,,,,,,,, -351827,,,0.6962584853172302,1.368923902511597,35.62077978803877,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,125205.46908950806,214992.70389819145,125205.46908950806,89769.20688509941,7.279955148696899,0.0 -351900,0.23242572,1.5367345,,,,,,,,,,,,,,,,, -352000,0.22537687,1.421376,,,,,,,,,,,,,,,,, -352100,0.23071964,1.4279377,,,,,,,,,,,,,,,,, -352200,0.23609439,1.5567544,,,,,,,,,,,,,,,,, -352300,0.23610954,1.4207361,,,,,,,,,,,,,,,,, -352400,0.23334539,1.3890837,,,,,,,,,,,,,,,,, -352500,0.23339811,1.5464778,,,,,,,,,,,,,,,,, -352600,0.2264305,1.4102885,,,,,,,,,,,,,,,,, -352700,0.25570998,1.4999692,,,,,,,,,,,,,,,,, -352800,0.2359521,1.4829004,,,,,,,,,,,,,,,,, -352900,0.24441855,1.4254084,,,,,,,,,,,,,,,,, -353000,0.2441127,1.4812517,,,,,,,,,,,,,,,,, -353100,0.23723318,1.3810183,,,,,,,,,,,,,,,,, -353200,0.23665026,1.500757,,,,,,,,,,,,,,,,, -353300,0.22589223,1.3724617,,,,,,,,,,,,,,,,, -353400,0.2338509,1.4311045,,,,,,,,,,,,,,,,, -353500,0.23818998,1.4448544,,,,,,,,,,,,,,,,, -353600,0.22681706,1.3973781,,,,,,,,,,,,,,,,, -353700,0.22460467,1.3936462,,,,,,,,,,,,,,,,, -353800,0.22955844,1.396112,,,,,,,,,,,,,,,,, -353900,0.23018256,1.4688343,,,,,,,,,,,,,,,,, -354000,0.23873071,1.5000563,,,,,,,,,,,,,,,,, -354100,0.23058808,1.450989,,,,,,,,,,,,,,,,, -354189,,,0.6944153308868408,1.3725401163101196,35.38775796358296,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,126045.5473601818,216434.73520970345,126045.5473601818,90371.01800084114,7.350381135940552,0.0 -354200,0.2295302,1.4228948,,,,,,,,,,,,,,,,, -354300,0.2322017,1.420087,,,,,,,,,,,,,,,,, -354400,0.23028302,1.4048193,,,,,,,,,,,,,,,,, -354500,0.23851517,1.4332179,,,,,,,,,,,,,,,,, -354600,0.22873333,1.417037,,,,,,,,,,,,,,,,, -354700,0.23125494,1.506763,,,,,,,,,,,,,,,,, -354800,0.23200975,1.44923,,,,,,,,,,,,,,,,, -354900,0.2323024,1.4229898,,,,,,,,,,,,,,,,, -355000,0.23085022,1.4809425,,,,,,,,,,,,,,,,, -355100,0.23607974,1.4455734,,,,,,,,,,,,,,,,, -355200,0.23820193,1.4481984,,,,,,,,,,,,,,,,, -355300,0.22555462,1.4441661,,,,,,,,,,,,,,,,, -355400,0.23960537,1.4280754,,,,,,,,,,,,,,,,, -355500,0.24318771,1.4949398,,,,,,,,,,,,,,,,, -355600,0.24646203,1.4662751,,,,,,,,,,,,,,,,, -355700,0.23698758,1.4042187,,,,,,,,,,,,,,,,, -355800,0.24599308,1.5072498,,,,,,,,,,,,,,,,, -355900,0.22625656,1.4348953,,,,,,,,,,,,,,,,, -356000,0.23650192,1.3922288,,,,,,,,,,,,,,,,, -356100,0.22428232,1.4407905,,,,,,,,,,,,,,,,, -356200,0.24427086,1.4409336,,,,,,,,,,,,,,,,, -356300,0.2396202,1.4299581,,,,,,,,,,,,,,,,, -356400,0.22766235,1.4548749,,,,,,,,,,,,,,,,, -356500,0.24151601,1.4886684,,,,,,,,,,,,,,,,, -356551,,,0.6954219341278076,1.3780286312103271,35.53209505714748,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,126885.76126217842,217885.62106704712,126885.76126217842,90981.5489873886,7.420124530792236,0.0 -356600,0.2342542,1.4451466,,,,,,,,,,,,,,,,, -356700,0.2285995,1.4665203,,,,,,,,,,,,,,,,, -356800,0.24095194,1.4072028,,,,,,,,,,,,,,,,, -356900,0.23243858,1.3430451,,,,,,,,,,,,,,,,, -357000,0.22505832,1.4695743,,,,,,,,,,,,,,,,, -357100,0.2256765,1.3874531,,,,,,,,,,,,,,,,, -357200,0.23792501,1.4906763,,,,,,,,,,,,,,,,, -357300,0.23516272,1.4275295,,,,,,,,,,,,,,,,, -357400,0.23920785,1.4095812,,,,,,,,,,,,,,,,, -357500,0.23344873,1.5002617,,,,,,,,,,,,,,,,, -357600,0.23525849,1.5370175,,,,,,,,,,,,,,,,, -357700,0.24284239,1.469933,,,,,,,,,,,,,,,,, -357800,0.23252292,1.503307,,,,,,,,,,,,,,,,, -357900,0.22977021,1.4364898,,,,,,,,,,,,,,,,, -358000,0.23046623,1.4161788,,,,,,,,,,,,,,,,, -358100,0.23444396,1.4282387,,,,,,,,,,,,,,,,, -358200,0.24004182,1.5143223,,,,,,,,,,,,,,,,, -358300,0.23119967,1.4648789,,,,,,,,,,,,,,,,, -358400,0.23406135,1.4238654,,,,,,,,,,,,,,,,, -358500,0.23367712,1.436041,,,,,,,,,,,,,,,,, -358600,0.23964886,1.4833429,,,,,,,,,,,,,,,,, -358700,0.24067615,1.4609349,,,,,,,,,,,,,,,,, -358800,0.24055104,1.4559411,,,,,,,,,,,,,,,,, -358900,0.22742559,1.3733857,,,,,,,,,,,,,,,,, -358911,,,0.6971335411071777,1.3615541458129885,35.45657819681476,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,127725.61231327055,219321.40892481804,127725.61231327055,91577.32687735558,7.504851579666138,0.0 -359000,0.22998463,1.4290158,,,,,,,,,,,,,,,,, -359100,0.23376718,1.4752562,,,,,,,,,,,,,,,,, -359200,0.24284106,1.445189,,,,,,,,,,,,,,,,, -359300,0.24034223,1.3741705,,,,,,,,,,,,,,,,, -359400,0.22639073,1.3833654,,,,,,,,,,,,,,,,, -359500,0.23909384,1.3904682,,,,,,,,,,,,,,,,, -359600,0.22980025,1.4576339,,,,,,,,,,,,,,,,, -359700,0.23806725,1.4983494,,,,,,,,,,,,,,,,, -359800,0.2390673,1.4439111,,,,,,,,,,,,,,,,, -359900,0.23072287,1.4556445,,,,,,,,,,,,,,,,, -360000,0.2471586,1.4935629,,,,,,,,,,,,,,,,, -360100,0.23645474,1.4400412,,,,,,,,,,,,,,,,, -360200,0.2309895,1.4508228,,,,,,,,,,,,,,,,, -360300,0.2318424,1.4089266,,,,,,,,,,,,,,,,, -360400,0.2476028,1.5356627,,,,,,,,,,,,,,,,, -360500,0.23390435,1.3889136,,,,,,,,,,,,,,,,, -360600,0.23314789,1.4412662,,,,,,,,,,,,,,,,, -360700,0.23789434,1.3949578,,,,,,,,,,,,,,,,, -360800,0.23163524,1.4305906,,,,,,,,,,,,,,,,, -360900,0.23083942,1.4379126,,,,,,,,,,,,,,,,, -361000,0.23509166,1.4062937,,,,,,,,,,,,,,,,, -361100,0.2564675,1.4484637,,,,,,,,,,,,,,,,, -361200,0.23243694,1.4176005,,,,,,,,,,,,,,,,, -361271,,,0.69556725025177,1.368157982826233,35.47585553964436,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,128565.52987861632,220772.04701256752,128565.52987861632,92187.9036836624,7.576892137527466,0.0 -361300,0.23337832,1.3923944,,,,,,,,,,,,,,,,, -361400,0.2238456,1.4032929,,,,,,,,,,,,,,,,, -361500,0.24571553,1.4270765,,,,,,,,,,,,,,,,, -361600,0.23131512,1.443856,,,,,,,,,,,,,,,,, -361700,0.23240612,1.4577887,,,,,,,,,,,,,,,,, -361800,0.2389295,1.3754616,,,,,,,,,,,,,,,,, -361900,0.24215728,1.4679524,,,,,,,,,,,,,,,,, -362000,0.23501998,1.4035614,,,,,,,,,,,,,,,,, -362100,0.24332318,1.4184307,,,,,,,,,,,,,,,,, -362200,0.24100284,1.4328191,,,,,,,,,,,,,,,,, -362300,0.23781532,1.5004503,,,,,,,,,,,,,,,,, -362400,0.22392964,1.4695919,,,,,,,,,,,,,,,,, -362500,0.23433119,1.4787521,,,,,,,,,,,,,,,,, -362600,0.23782393,1.4282076,,,,,,,,,,,,,,,,, -362700,0.2369056,1.449991,,,,,,,,,,,,,,,,, -362800,0.2305578,1.4266949,,,,,,,,,,,,,,,,, -362900,0.24220634,1.5064827,,,,,,,,,,,,,,,,, -363000,0.24077335,1.3529977,,,,,,,,,,,,,,,,, -363100,0.23125689,1.4596478,,,,,,,,,,,,,,,,, -363200,0.2268647,1.4339621,,,,,,,,,,,,,,,,, -363300,0.2330082,1.4690921,,,,,,,,,,,,,,,,, -363400,0.22958304,1.3890221,,,,,,,,,,,,,,,,, -363500,0.22870362,1.4409808,,,,,,,,,,,,,,,,, -363600,0.2311554,1.4274186,,,,,,,,,,,,,,,,, -363631,,,0.6939976811408997,1.379857063293457,35.58371781025625,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,129405.45047450066,222202.46889948845,129405.45047450066,92778.26169657709,7.64824366569519,0.0 -363700,0.23189905,1.4684957,,,,,,,,,,,,,,,,, -363800,0.23120953,1.4340465,,,,,,,,,,,,,,,,, -363900,0.23496453,1.4960263,,,,,,,,,,,,,,,,, -364000,0.2472068,1.5090845,,,,,,,,,,,,,,,,, -364100,0.23088183,1.4261938,,,,,,,,,,,,,,,,, -364200,0.24136947,1.4983318,,,,,,,,,,,,,,,,, -364300,0.23724648,1.4256296,,,,,,,,,,,,,,,,, -364400,0.2311384,1.4108125,,,,,,,,,,,,,,,,, -364500,0.24024433,1.4086248,,,,,,,,,,,,,,,,, -364600,0.22668646,1.4074218,,,,,,,,,,,,,,,,, -364700,0.25055042,1.4726456,,,,,,,,,,,,,,,,, -364800,0.23085088,1.3675348,,,,,,,,,,,,,,,,, -364900,0.22488403,1.4536352,,,,,,,,,,,,,,,,, -365000,0.24079613,1.4058623,,,,,,,,,,,,,,,,, -365100,0.22626616,1.3979465,,,,,,,,,,,,,,,,, -365200,0.24399334,1.4016967,,,,,,,,,,,,,,,,, -365300,0.23215686,1.4173564,,,,,,,,,,,,,,,,, -365400,0.22998774,1.3942614,,,,,,,,,,,,,,,,, -365500,0.22888684,1.4384738,,,,,,,,,,,,,,,,, -365600,0.24007899,1.4543328,,,,,,,,,,,,,,,,, -365700,0.23171988,1.4190123,,,,,,,,,,,,,,,,, -365800,0.2291619,1.466124,,,,,,,,,,,,,,,,, -365900,0.2309661,1.4366828,,,,,,,,,,,,,,,,, -365992,,,0.6946594715118408,1.3834761381149292,35.4348874325702,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,130245.50266051292,223635.70874118805,130245.50266051292,93371.30241632462,7.723176717758179,0.0 -366000,0.23333326,1.3896359,,,,,,,,,,,,,,,,, -366100,0.22358884,1.4093144,,,,,,,,,,,,,,,,, -366200,0.22740799,1.3922577,,,,,,,,,,,,,,,,, -366300,0.23849966,1.3837078,,,,,,,,,,,,,,,,, -366400,0.23775922,1.489731,,,,,,,,,,,,,,,,, -366500,0.22598879,1.385951,,,,,,,,,,,,,,,,, -366600,0.23135896,1.4173417,,,,,,,,,,,,,,,,, -366700,0.24590707,1.4786658,,,,,,,,,,,,,,,,, -366800,0.23366655,1.3650681,,,,,,,,,,,,,,,,, -366900,0.23150243,1.4032992,,,,,,,,,,,,,,,,, -367000,0.23234808,1.4463929,,,,,,,,,,,,,,,,, -367100,0.24529597,1.5171324,,,,,,,,,,,,,,,,, -367200,0.2251085,1.3884753,,,,,,,,,,,,,,,,, -367300,0.24641359,1.4885426,,,,,,,,,,,,,,,,, -367400,0.22331735,1.3948268,,,,,,,,,,,,,,,,, -367500,0.23888928,1.534709,,,,,,,,,,,,,,,,, -367600,0.23265229,1.4973787,,,,,,,,,,,,,,,,, -367700,7.02581,1.5040368,,,,,,,,,,,,,,,,, -367800,0.2298357,1.4365867,,,,,,,,,,,,,,,,, -367900,0.23636125,1.4314232,,,,,,,,,,,,,,,,, -368000,0.23718646,1.4547865,,,,,,,,,,,,,,,,, -368100,0.24556468,1.4333895,,,,,,,,,,,,,,,,, -368200,0.23529126,1.3859898,,,,,,,,,,,,,,,,, -368300,0.23044425,1.3592463,,,,,,,,,,,,,,,,, -368351,,,0.6929248571395874,1.3840899467468262,35.70822070105537,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,131085.41303038597,225079.64515781403,131085.41303038597,93975.18136429788,7.796278238296509,0.0 -368400,0.23660527,1.405128,,,,,,,,,,,,,,,,, -368500,0.2405863,1.4205009,,,,,,,,,,,,,,,,, -368600,0.22753061,1.4939915,,,,,,,,,,,,,,,,, -368700,0.23868416,1.4313871,,,,,,,,,,,,,,,,, -368800,0.23879337,1.4803721,,,,,,,,,,,,,,,,, -368900,0.24081193,1.4671386,,,,,,,,,,,,,,,,, -369000,0.23558067,1.4839641,,,,,,,,,,,,,,,,, -369100,0.22902308,1.432517,,,,,,,,,,,,,,,,, -369200,0.23693117,1.4348154,,,,,,,,,,,,,,,,, -369300,0.23084095,1.427794,,,,,,,,,,,,,,,,, -369400,0.24410282,1.4485946,,,,,,,,,,,,,,,,, -369500,0.24235626,1.4036394,,,,,,,,,,,,,,,,, -369600,0.22756219,1.4352762,,,,,,,,,,,,,,,,, -369700,0.2331386,1.4311087,,,,,,,,,,,,,,,,, -369800,0.23283565,1.4372394,,,,,,,,,,,,,,,,, -369900,0.2380892,1.4641713,,,,,,,,,,,,,,,,, -370000,0.23200655,1.4667304,,,,,,,,,,,,,,,,, -370100,0.23277093,1.4360021,,,,,,,,,,,,,,,,, -370200,0.22776309,1.4151999,,,,,,,,,,,,,,,,, -370300,0.23121151,1.4641769,,,,,,,,,,,,,,,,, -370400,0.2420559,1.5013286,,,,,,,,,,,,,,,,, -370500,0.23978879,1.4597006,,,,,,,,,,,,,,,,, -370600,0.23633675,1.4832399,,,,,,,,,,,,,,,,, -370700,0.23274262,1.48704,,,,,,,,,,,,,,,,, -370711,,,0.6935230493545532,1.3815573453903198,35.63322774950154,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,131925.42628622055,226520.8309454918,131925.42628622055,94576.20677280426,7.871663808822632,0.0 -370800,0.23555315,1.473648,,,,,,,,,,,,,,,,, -370900,0.23238003,1.4741911,,,,,,,,,,,,,,,,, -371000,0.23314936,1.4404207,,,,,,,,,,,,,,,,, -371100,0.22648256,1.4476451,,,,,,,,,,,,,,,,, -371200,0.23380765,1.442782,,,,,,,,,,,,,,,,, -371300,0.23897319,1.4687743,,,,,,,,,,,,,,,,, -371400,0.23419228,1.4863856,,,,,,,,,,,,,,,,, -371500,0.2367019,1.4649161,,,,,,,,,,,,,,,,, -371600,0.22848013,1.3843939,,,,,,,,,,,,,,,,, -371700,0.23213378,1.455898,,,,,,,,,,,,,,,,, -371800,0.23376729,1.4317545,,,,,,,,,,,,,,,,, -371900,0.23768817,1.4534388,,,,,,,,,,,,,,,,, -372000,0.22966143,1.4704732,,,,,,,,,,,,,,,,, -372100,0.22464389,1.4524969,,,,,,,,,,,,,,,,, -372200,0.2373984,1.4729902,,,,,,,,,,,,,,,,, -372300,0.23016982,1.4140571,,,,,,,,,,,,,,,,, -372400,0.22967944,1.4347938,,,,,,,,,,,,,,,,, -372500,0.23704612,1.3701593,,,,,,,,,,,,,,,,, -372600,0.23542012,1.4406269,,,,,,,,,,,,,,,,, -372700,0.2323725,1.4438289,,,,,,,,,,,,,,,,, -372800,0.23606013,1.4305844,,,,,,,,,,,,,,,,, -372900,0.23897843,1.4678543,,,,,,,,,,,,,,,,, -373000,0.23164135,1.4180119,,,,,,,,,,,,,,,,, -373071,,,0.693621814250946,1.3835389614105225,35.99297319457071,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,132765.4077756405,227956.367408514,132765.4077756405,95171.61591076852,7.943910598754883,0.0 -373100,0.2357093,1.4797387,,,,,,,,,,,,,,,,, -373200,0.24084415,1.4861671,,,,,,,,,,,,,,,,, -373300,0.23083267,1.4689083,,,,,,,,,,,,,,,,, -373400,0.23625676,1.4526606,,,,,,,,,,,,,,,,, -373500,0.23489931,1.4723301,,,,,,,,,,,,,,,,, -373600,0.23836075,1.3802371,,,,,,,,,,,,,,,,, -373700,0.23512912,1.4717944,,,,,,,,,,,,,,,,, -373800,0.23083183,1.4308193,,,,,,,,,,,,,,,,, -373900,0.23506452,1.4301693,,,,,,,,,,,,,,,,, -374000,0.23086852,1.5042766,,,,,,,,,,,,,,,,, -374100,0.24955209,1.4469919,,,,,,,,,,,,,,,,, -374200,0.23720138,1.4493351,,,,,,,,,,,,,,,,, -374300,0.236911,1.3909794,,,,,,,,,,,,,,,,, -374400,0.2449918,1.4788347,,,,,,,,,,,,,,,,, -374500,0.23622392,1.4570996,,,,,,,,,,,,,,,,, -374600,0.22145225,1.4287648,,,,,,,,,,,,,,,,, -374700,0.23486136,1.4473565,,,,,,,,,,,,,,,,, -374800,0.23597531,1.4544765,,,,,,,,,,,,,,,,, -374900,0.22785316,1.4269627,,,,,,,,,,,,,,,,, -375000,0.23169725,1.4533063,,,,,,,,,,,,,,,,, -375100,0.23028818,1.4371443,,,,,,,,,,,,,,,,, -375200,0.22637881,1.3914248,,,,,,,,,,,,,,,,, -375300,0.23239803,1.3954501,,,,,,,,,,,,,,,,, -375400,0.22389783,1.381102,,,,,,,,,,,,,,,,, -375431,,,0.6953145861625671,1.371348857879639,35.52966497028867,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,133605.52108120918,229396.5992236137,133605.52108120918,95771.58613538742,8.019490957260132,0.0 -375500,0.24099496,1.4537041,,,,,,,,,,,,,,,,, -375600,0.2370575,1.4248592,,,,,,,,,,,,,,,,, -375700,0.23809958,1.4633138,,,,,,,,,,,,,,,,, -375800,0.23286113,1.468541,,,,,,,,,,,,,,,,, -375900,0.23599008,1.4804934,,,,,,,,,,,,,,,,, -376000,0.24359538,1.4722484,,,,,,,,,,,,,,,,, -376100,0.2288196,1.4293164,,,,,,,,,,,,,,,,, -376200,0.21922784,1.3716635,,,,,,,,,,,,,,,,, -376300,0.22994769,1.4924831,,,,,,,,,,,,,,,,, -376400,0.22481023,1.3651898,,,,,,,,,,,,,,,,, -376500,0.23993115,1.4715706,,,,,,,,,,,,,,,,, -376600,0.23549965,1.5237255,,,,,,,,,,,,,,,,, -376700,0.22746833,1.3657564,,,,,,,,,,,,,,,,, -376800,0.2366722,1.4181056,,,,,,,,,,,,,,,,, -376900,0.22395901,1.3884156,,,,,,,,,,,,,,,,, -377000,0.22703195,1.4572225,,,,,,,,,,,,,,,,, -377100,0.24789006,1.5019965,,,,,,,,,,,,,,,,, -377200,0.23056528,1.3771949,,,,,,,,,,,,,,,,, -377300,0.23651233,1.4366876,,,,,,,,,,,,,,,,, -377400,0.2252208,1.3768264,,,,,,,,,,,,,,,,, -377500,0.22460395,1.4010695,,,,,,,,,,,,,,,,, -377600,0.23357642,1.526438,,,,,,,,,,,,,,,,, -377700,0.24031346,1.3882018,,,,,,,,,,,,,,,,, -377790,,,0.6948238611221313,1.373692512512207,35.914049220026286,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,134445.5572388172,230832.9020171165,134445.5572388172,96367.69637274742,8.102221727371216,0.0 -377800,0.23662259,1.4811211,,,,,,,,,,,,,,,,, -377900,0.23900035,1.425441,,,,,,,,,,,,,,,,, -378000,0.22719137,1.4062533,,,,,,,,,,,,,,,,, -378100,0.2378177,1.458377,,,,,,,,,,,,,,,,, -378200,0.23859303,1.4454954,,,,,,,,,,,,,,,,, -378300,0.22586256,1.4358209,,,,,,,,,,,,,,,,, -378400,0.22283852,1.3521515,,,,,,,,,,,,,,,,, -378500,0.23425105,1.5310967,,,,,,,,,,,,,,,,, -378600,0.24624722,1.3802234,,,,,,,,,,,,,,,,, -378700,0.24057445,1.4553688,,,,,,,,,,,,,,,,, -378800,0.22686812,1.445955,,,,,,,,,,,,,,,,, -378900,0.23841304,1.4870707,,,,,,,,,,,,,,,,, -379000,0.23239304,1.4197483,,,,,,,,,,,,,,,,, -379100,0.23877762,1.5060824,,,,,,,,,,,,,,,,, -379200,0.23708443,1.3580809,,,,,,,,,,,,,,,,, -379300,0.2970945,1.4192564,,,,,,,,,,,,,,,,, -379400,0.23600763,1.4679381,,,,,,,,,,,,,,,,, -379500,0.23594163,1.4256134,,,,,,,,,,,,,,,,, -379600,0.22687359,1.4368566,,,,,,,,,,,,,,,,, -379700,0.23157513,1.5146648,,,,,,,,,,,,,,,,, -379800,0.23400426,1.4223893,,,,,,,,,,,,,,,,, -379900,0.23006408,1.4123065,,,,,,,,,,,,,,,,, -380000,0.23735523,1.436901,,,,,,,,,,,,,,,,, -380100,0.22995815,1.430829,,,,,,,,,,,,,,,,, -380149,,,0.6977683305740356,1.3623616695404053,35.69085706221976,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,135285.55623698235,232270.6904451847,135285.55623698235,96965.33767175674,8.175270795822144,0.0 -380200,0.23496522,1.4499427,,,,,,,,,,,,,,,,, -380300,0.23486069,1.4647418,,,,,,,,,,,,,,,,, -380400,0.22242951,1.3677744,,,,,,,,,,,,,,,,, -380500,0.23439011,1.4781187,,,,,,,,,,,,,,,,, -380600,0.2308334,1.3771235,,,,,,,,,,,,,,,,, -380700,0.22494236,1.3658721,,,,,,,,,,,,,,,,, -380800,0.23999381,1.5348114,,,,,,,,,,,,,,,,, -380900,0.23636545,1.4737422,,,,,,,,,,,,,,,,, -381000,0.23293324,1.4653604,,,,,,,,,,,,,,,,, -381100,0.23086824,1.3826774,,,,,,,,,,,,,,,,, -381200,0.22751999,1.4750749,,,,,,,,,,,,,,,,, -381300,0.23545298,1.4399959,,,,,,,,,,,,,,,,, -381400,0.23727387,1.5338527,,,,,,,,,,,,,,,,, -381500,0.23479314,1.4369104,,,,,,,,,,,,,,,,, -381600,0.23776002,1.403508,,,,,,,,,,,,,,,,, -381700,0.23356953,1.4221046,,,,,,,,,,,,,,,,, -381800,0.23868823,1.4656994,,,,,,,,,,,,,,,,, -381900,0.23400798,1.4399952,,,,,,,,,,,,,,,,, -382000,0.23463103,1.4940189,,,,,,,,,,,,,,,,, -382100,0.2328568,1.4849044,,,,,,,,,,,,,,,,, -382200,0.2265198,1.4350142,,,,,,,,,,,,,,,,, -382300,0.22853947,1.4154456,,,,,,,,,,,,,,,,, -382400,0.23145336,1.5228977,,,,,,,,,,,,,,,,, -382500,0.23448764,1.4447085,,,,,,,,,,,,,,,,, -382509,,,0.6954215168952942,1.3679215908050537,35.408766270120154,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,136125.55805802345,233710.7053785324,136125.55805802345,97565.20195937157,8.250718832015991,0.0 -382600,0.24455875,1.460822,,,,,,,,,,,,,,,,, -382700,0.23997462,1.441029,,,,,,,,,,,,,,,,, -382800,0.23201059,1.3990602,,,,,,,,,,,,,,,,, -382900,0.23405729,1.374296,,,,,,,,,,,,,,,,, -383000,0.24257925,1.3985898,,,,,,,,,,,,,,,,, -383100,0.23254664,1.3649975,,,,,,,,,,,,,,,,, -383200,0.24243574,1.4493712,,,,,,,,,,,,,,,,, -383300,0.23029731,1.4407806,,,,,,,,,,,,,,,,, -383400,0.23228197,1.4662144,,,,,,,,,,,,,,,,, -383500,0.2250548,1.3992712,,,,,,,,,,,,,,,,, -383600,0.24504298,1.4639531,,,,,,,,,,,,,,,,, -383700,0.2383806,1.4500309,,,,,,,,,,,,,,,,, -383800,0.23353268,1.4654387,,,,,,,,,,,,,,,,, -383900,0.23189394,1.4253317,,,,,,,,,,,,,,,,, -384000,0.23620465,1.4658641,,,,,,,,,,,,,,,,, -384100,0.2427004,1.5021858,,,,,,,,,,,,,,,,, -384200,0.23203331,1.4487514,,,,,,,,,,,,,,,,, -384300,0.23076354,1.4867295,,,,,,,,,,,,,,,,, -384400,0.23014243,1.4753362,,,,,,,,,,,,,,,,, -384500,0.2314805,1.4412606,,,,,,,,,,,,,,,,, -384600,0.23370111,1.4831566,,,,,,,,,,,,,,,,, -384700,0.22681701,1.4965687,,,,,,,,,,,,,,,,, -384800,0.23598816,1.3899164,,,,,,,,,,,,,,,,, -384870,,,0.6941781640052795,1.3764313459396362,35.4485266232941,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,136965.75103139877,235157.38303089145,136965.75103139877,98171.5267598629,8.339428424835205,0.0 -384900,0.24642059,1.4946753,,,,,,,,,,,,,,,,, -385000,0.23666361,1.5064116,,,,,,,,,,,,,,,,, -385100,0.22967651,1.4239187,,,,,,,,,,,,,,,,, -385200,0.23386133,1.3757807,,,,,,,,,,,,,,,,, -385300,0.2411933,1.43941,,,,,,,,,,,,,,,,, -385400,0.23711042,1.3841481,,,,,,,,,,,,,,,,, -385500,0.24018967,1.4850146,,,,,,,,,,,,,,,,, -385600,0.24520695,1.4128612,,,,,,,,,,,,,,,,, -385700,0.23011766,1.3886664,,,,,,,,,,,,,,,,, -385800,0.2363969,1.5090092,,,,,,,,,,,,,,,,, -385900,0.22732274,1.4772936,,,,,,,,,,,,,,,,, -386000,0.23019898,1.44942,,,,,,,,,,,,,,,,, -386100,0.2292202,1.4791957,,,,,,,,,,,,,,,,, -386200,0.24503231,1.4505392,,,,,,,,,,,,,,,,, -386300,0.23060718,1.3801819,,,,,,,,,,,,,,,,, -386400,0.23393348,1.4382018,,,,,,,,,,,,,,,,, -386500,0.22593664,1.4567567,,,,,,,,,,,,,,,,, -386600,0.28137323,1.4535177,,,,,,,,,,,,,,,,, -386700,0.23385848,1.4748353,,,,,,,,,,,,,,,,, -386800,0.23013572,1.389869,,,,,,,,,,,,,,,,, -386900,0.23010536,1.4476691,,,,,,,,,,,,,,,,, -387000,0.23359126,1.3571247,,,,,,,,,,,,,,,,, -387100,0.23245333,1.4647974,,,,,,,,,,,,,,,,, -387200,0.23335832,1.4709568,,,,,,,,,,,,,,,,, -387230,,,0.6959114074707031,1.3697524070739746,35.79706608134528,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,137805.83716368675,236606.8883280754,137805.83716368675,98780.79803609848,8.41462779045105,0.0 -387300,0.23476858,1.469441,,,,,,,,,,,,,,,,, -387400,0.22638875,1.4092392,,,,,,,,,,,,,,,,, -387500,0.22469783,1.3762525,,,,,,,,,,,,,,,,, -387600,0.23919623,1.5408843,,,,,,,,,,,,,,,,, -387700,0.23101114,1.4381964,,,,,,,,,,,,,,,,, -387800,0.24623802,1.450394,,,,,,,,,,,,,,,,, -387900,0.23819107,1.4430814,,,,,,,,,,,,,,,,, -388000,0.22706413,1.369671,,,,,,,,,,,,,,,,, -388100,0.22882222,1.4129504,,,,,,,,,,,,,,,,, -388200,0.24177824,1.4171962,,,,,,,,,,,,,,,,, -388300,0.23150074,1.4677193,,,,,,,,,,,,,,,,, -388400,0.23869167,1.4170446,,,,,,,,,,,,,,,,, -388500,0.24858491,1.4840848,,,,,,,,,,,,,,,,, -388600,0.22667839,1.4475625,,,,,,,,,,,,,,,,, -388700,0.24213336,1.483804,,,,,,,,,,,,,,,,, -388800,0.24715377,1.4657428,,,,,,,,,,,,,,,,, -388900,0.2256514,1.4445746,,,,,,,,,,,,,,,,, -389000,0.22428897,1.4048775,,,,,,,,,,,,,,,,, -389100,0.23181072,1.408919,,,,,,,,,,,,,,,,, -389200,0.2294903,1.3798175,,,,,,,,,,,,,,,,, -389300,0.24663065,1.4860953,,,,,,,,,,,,,,,,, -389400,0.23908351,1.4156998,,,,,,,,,,,,,,,,, -389500,0.22483787,1.3843918,,,,,,,,,,,,,,,,, -389591,,,0.6969903111457825,1.3658790588378906,35.59087751015178,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,138645.8393995762,238036.6040613652,138645.8393995762,99370.363945961,8.490809679031372,0.0 -389600,0.23853548,1.4299483,,,,,,,,,,,,,,,,, -389700,0.22673087,1.4162596,,,,,,,,,,,,,,,,, -389800,0.24184398,1.4287798,,,,,,,,,,,,,,,,, -389900,0.23725305,1.4550194,,,,,,,,,,,,,,,,, -390000,0.23442024,1.5137826,,,,,,,,,,,,,,,,, -390100,0.23304065,1.4437591,,,,,,,,,,,,,,,,, -390200,0.22248788,1.4200345,,,,,,,,,,,,,,,,, -390300,0.24348694,1.4987279,,,,,,,,,,,,,,,,, -390400,0.23899008,1.4894673,,,,,,,,,,,,,,,,, -390500,0.23110627,1.4932944,,,,,,,,,,,,,,,,, -390600,0.23290284,1.4166216,,,,,,,,,,,,,,,,, -390700,0.22933334,1.4089874,,,,,,,,,,,,,,,,, -390800,0.23533227,1.459277,,,,,,,,,,,,,,,,, -390900,0.23377572,1.4462847,,,,,,,,,,,,,,,,, -391000,0.2279314,1.4376241,,,,,,,,,,,,,,,,, -391100,0.2435657,1.4955201,,,,,,,,,,,,,,,,, -391200,0.2534434,1.5040467,,,,,,,,,,,,,,,,, -391300,0.23180948,1.4721346,,,,,,,,,,,,,,,,, -391400,0.23331392,1.4738802,,,,,,,,,,,,,,,,, -391500,0.25025305,1.444666,,,,,,,,,,,,,,,,, -391600,0.23748064,1.4752532,,,,,,,,,,,,,,,,, -391700,0.236627,1.441736,,,,,,,,,,,,,,,,, -391800,0.22557773,1.4072788,,,,,,,,,,,,,,,,, -391900,0.23822784,1.3912731,,,,,,,,,,,,,,,,, -391951,,,0.6969181299209595,1.3643102645874023,35.6003938792746,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,139485.86350560188,239475.74658942223,139485.86350560188,99969.31581687929,8.585108041763306,0.0 -392000,0.22665548,1.4037703,,,,,,,,,,,,,,,,, -392100,0.24208377,1.419348,,,,,,,,,,,,,,,,, -392200,0.2221541,1.3802003,,,,,,,,,,,,,,,,, -392300,0.22718185,1.399931,,,,,,,,,,,,,,,,, -392400,0.2396344,1.4059081,,,,,,,,,,,,,,,,, -392500,0.23203371,1.5200933,,,,,,,,,,,,,,,,, -392600,0.24368972,1.5956563,,,,,,,,,,,,,,,,, -392700,0.23725493,1.4114941,,,,,,,,,,,,,,,,, -392800,0.2376805,1.4188612,,,,,,,,,,,,,,,,, -392900,0.23529042,1.4808832,,,,,,,,,,,,,,,,, -393000,0.2445221,1.4529412,,,,,,,,,,,,,,,,, -393100,0.22298244,1.3307743,,,,,,,,,,,,,,,,, -393200,0.2427076,1.5006993,,,,,,,,,,,,,,,,, -393300,0.23231551,1.4837885,,,,,,,,,,,,,,,,, -393400,0.22681667,1.4177738,,,,,,,,,,,,,,,,, -393500,0.23610598,1.4076141,,,,,,,,,,,,,,,,, -393600,0.22925012,1.4276425,,,,,,,,,,,,,,,,, -393700,0.23604667,1.4320878,,,,,,,,,,,,,,,,, -393800,0.24523267,1.4443432,,,,,,,,,,,,,,,,, -393900,0.22377788,1.4419036,,,,,,,,,,,,,,,,, -394000,0.23706041,1.4296677,,,,,,,,,,,,,,,,, -394100,0.23982546,1.4057032,,,,,,,,,,,,,,,,, -394200,0.22751756,1.3882185,,,,,,,,,,,,,,,,, -394300,0.24444024,1.4874946,,,,,,,,,,,,,,,,, -394311,,,0.6953207850456238,1.3696646690368652,35.67714170119371,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,140325.8711452484,240928.74007606503,140325.8711452484,100582.15226006508,8.660670042037964,0.0 -394400,0.22224009,1.3392186,,,,,,,,,,,,,,,,, -394500,0.23260783,1.4243646,,,,,,,,,,,,,,,,, -394600,0.23563059,1.5111625,,,,,,,,,,,,,,,,, -394700,0.24075809,1.4920257,,,,,,,,,,,,,,,,, -394800,0.22244385,1.3634379,,,,,,,,,,,,,,,,, -394900,0.22917673,1.4336472,,,,,,,,,,,,,,,,, -395000,0.2238438,1.4148971,,,,,,,,,,,,,,,,, -395100,0.22865948,1.4278678,,,,,,,,,,,,,,,,, -395200,0.226789,1.3943657,,,,,,,,,,,,,,,,, -395300,0.23456636,1.403656,,,,,,,,,,,,,,,,, -395400,0.23618002,1.4225953,,,,,,,,,,,,,,,,, -395500,0.23381838,1.4708649,,,,,,,,,,,,,,,,, -395600,0.23340994,1.4125469,,,,,,,,,,,,,,,,, -395700,0.2382871,1.5506315,,,,,,,,,,,,,,,,, -395800,0.23781107,1.4400178,,,,,,,,,,,,,,,,, -395900,0.24446255,1.4996743,,,,,,,,,,,,,,,,, -396000,0.2383408,1.4789242,,,,,,,,,,,,,,,,, -396100,0.23348926,1.4225726,,,,,,,,,,,,,,,,, -396200,0.23419997,1.4704897,,,,,,,,,,,,,,,,, -396300,0.23250082,1.4788662,,,,,,,,,,,,,,,,, -396400,0.23260441,1.3872356,,,,,,,,,,,,,,,,, -396500,0.25064048,1.4166734,,,,,,,,,,,,,,,,, -396600,0.21902762,1.3789383,,,,,,,,,,,,,,,,, -396671,,,0.6944743990898132,1.3710854053497314,35.42201402751068,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,141165.98493552208,242370.5478367805,141165.98493552208,101183.69587278366,8.738826513290405,0.0 -396700,0.24088104,1.4560446,,,,,,,,,,,,,,,,, -396800,0.23617667,1.5142072,,,,,,,,,,,,,,,,, -396900,0.23364782,1.4739226,,,,,,,,,,,,,,,,, -397000,0.24385299,1.4472231,,,,,,,,,,,,,,,,, -397100,0.23545095,1.5487833,,,,,,,,,,,,,,,,, -397200,0.23502514,1.4641751,,,,,,,,,,,,,,,,, -397300,0.23908699,1.4242091,,,,,,,,,,,,,,,,, -397400,0.2387441,1.4701176,,,,,,,,,,,,,,,,, -397500,0.22368796,1.3860487,,,,,,,,,,,,,,,,, -397600,0.24748237,1.536487,,,,,,,,,,,,,,,,, -397700,0.22981933,1.3904365,,,,,,,,,,,,,,,,, -397800,0.22708826,1.4089729,,,,,,,,,,,,,,,,, -397900,0.23371337,1.4659765,,,,,,,,,,,,,,,,, -398000,0.22693756,1.4327615,,,,,,,,,,,,,,,,, -398100,0.23975426,1.4704865,,,,,,,,,,,,,,,,, -398200,0.23294392,1.4977692,,,,,,,,,,,,,,,,, -398300,0.24670538,1.514148,,,,,,,,,,,,,,,,, -398400,0.24165425,1.4649589,,,,,,,,,,,,,,,,, -398500,0.23176214,1.509063,,,,,,,,,,,,,,,,, -398600,0.22589767,1.4720343,,,,,,,,,,,,,,,,, -398700,0.23987263,1.4135159,,,,,,,,,,,,,,,,, -398800,0.25460184,1.504758,,,,,,,,,,,,,,,,, -398900,0.23693794,1.3985591,,,,,,,,,,,,,,,,, -399000,0.23465855,1.4788624,,,,,,,,,,,,,,,,, -399032,,,0.6948761343955994,1.3741616010665894,35.58164949728439,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,142006.1650776863,243813.446236372,142006.1650776863,101786.26503157616,8.816017627716064,0.0 -399100,0.23030342,1.527693,,,,,,,,,,,,,,,,, -399200,0.23536679,1.391553,,,,,,,,,,,,,,,,, -399300,0.22891545,1.4662359,,,,,,,,,,,,,,,,, -399400,0.2397458,1.4359708,,,,,,,,,,,,,,,,, -399500,0.23117192,1.4694085,,,,,,,,,,,,,,,,, -399600,0.23678803,1.4947379,,,,,,,,,,,,,,,,, -399700,0.23613863,1.4086409,,,,,,,,,,,,,,,,, -399800,0.22537634,1.3796016,,,,,,,,,,,,,,,,, -399900,0.2400734,1.408159,,,,,,,,,,,,,,,,, -399999,,,0.6971225738525391,1.3681684732437134,35.5268341500284,0.6957632303237915,1.3586411476135254,30.80970163535293,3000.0,0.7146011590957642,1.2520129680633545,31.016732013552563,3003.0,142350.1185748577,244755.2840101719,142350.1185748577,102384.04146027564,8.893699645996094,0.0 -399999,,,,,,,,,,,,,,142350.1185748577,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 729f00578..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,50 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,total_duration,train/loss,validation/loss,validation/num_examples -784.8753859996796,0.0,20.74029541015625,1,0,20.74029541015625,0.2684693901110197,95000000,805.6157176494598,0.2663397727616178,0.2667628934273754,83274637 -1440.9713680744171,0.027174949645996,141.08283019065857,179,0,141.08283019065857,0.1307766509457237,95000000,1582.0874617099762,0.1278894952023929,0.1284199545960081,83274637 -2068.264231443405,0.0493304729461669,261.6510796546936,359,0,261.6510796546936,0.1294778536389802,95000000,2329.977290391922,0.1244112567495812,0.1269361273205512,83274637 -2684.0567717552185,0.0703155994415283,382.3513696193695,537,0,382.3513696193695,0.1283980485094572,95000000,3066.497213602066,0.1242823892583449,0.1260890645911322,83274637 -3319.367733955384,0.0948197841644287,502.5978763103485,719,0,502.5978763103485,0.1284853827097039,95000000,3822.085498809816,0.1250295413763455,0.1259149528705128,83274637 -3918.4293892383575,0.117992877960205,623.3132708072662,894,0,623.3132708072662,0.1279691620476973,95000000,4541.891732931137,0.1248043727762294,0.1254994502771275,83274637 -4507.998911142349,0.1389319896697998,743.8181428909302,1075,0,743.8181428909302,0.1282013358963815,95000000,5251.993498563767,0.1242811144710336,0.1257489394535915,83274637 -5023.655802726746,0.1610760688781738,863.8242483139038,1250,0,863.8242483139038,0.1282463938219572,95000000,5887.684747219086,0.1256232065318515,0.1257673064314982,83274637 -5468.990014791489,0.182673692703247,983.8378643989564,1425,0,983.8378643989564,0.1277889363898026,95000000,6453.060389280319,0.1253687928565455,0.125561976542487,83274637 -5898.956813812256,0.2033555507659912,1103.9650797843933,1601,0,1103.9650797843933,0.1281505349814967,95000000,7003.181186199188,0.1255340838933703,0.1256407656669859,83274637 -6334.50224852562,0.2240393161773681,1224.746877908707,1777,0,1224.746877908707,0.1277212985505756,95000000,7559.535044431686,0.1241320770866466,0.1254365534382523,83274637 -6764.7794716358185,0.2446043491363525,1345.3115315437317,1956,0,1345.3115315437317,0.1277315736944901,95000000,8110.4038507938385,0.1249245411640255,0.1253922288937747,83274637 -7207.456950187683,0.2668092250823974,1465.5621354579926,2129,0,1465.5621354579926,0.128003693852796,95000000,8673.360074996948,0.1239133313907392,0.1254291004777774,83274637 -7656.000282049179,0.287703275680542,1585.732696056366,2309,0,1585.732696056366,0.1275379534950657,95000000,9242.101125717165,0.1242583309587255,0.1251300903017596,83274637 -8121.151620388031,0.3085925579071045,1705.8460566997528,2484,0,1705.8460566997528,0.1272851764288651,95000000,9827.39279603958,0.1252529286082435,0.1249653129058277,83274637 -8583.522858858109,0.329925537109375,1826.1175792217248,2664,0,1826.1175792217248,0.1271264977693256,95000000,10410.062943696976,0.123065294311294,0.1247274926261555,83274637 -9023.981614589691,0.3516182899475097,1946.225148439408,2837,0,1946.225148439408,0.1271563959292763,95000000,10970.657076835632,0.1240034895076316,0.1247779347999345,83274637 -9463.522585868835,0.3764607906341553,2066.3307497501373,3013,0,2066.3307497501373,0.1270789125411184,95000000,11530.33465719223,0.1244736597433967,0.1247752609947488,83274637 -9921.37606692314,0.3975801467895508,2186.563698530197,3190,0,2186.563698530197,0.1276185454050164,95000000,12108.448246002195,0.1233823119352261,0.1250878789937997,83274637 -10368.43229842186,0.4201464653015136,2307.043375492096,3363,0,2307.043375492096,0.1273633506887335,95000000,12676.013177394869,0.1243024425077363,0.1249246158304422,83274637 -10823.348504304886,0.4440338611602783,2427.7164137363434,3536,0,2427.7164137363434,0.127071047265625,95000000,13251.632189273834,0.123183436551184,0.1247622584671999,83274637 -11300.688010931017,0.465282678604126,2547.7152211666107,3713,0,2547.7152211666107,0.1270655221422697,95000000,13848.997977733612,0.1218390933724131,0.1247118196356187,83274637 -11765.948224782944,0.4866828918457031,2668.2507762908936,3896,0,2668.2507762908936,0.1272068880037006,95000000,14434.821523189545,0.1240636574444156,0.1248220995814915,83274637 -12230.369115829468,0.5120937824249268,2788.8390402793884,4077,0,2788.8390402793884,0.1271889521278782,95000000,15019.862214803696,0.1229546935701707,0.124763756475456,83274637 -12691.46737909317,0.5355861186981201,2909.3049216270447,4250,0,2909.3049216270447,0.126971111482319,95000000,15601.455958843231,0.1226657796182534,0.1245887371956699,83274637 -13142.118698596954,0.5581104755401611,3029.503550052643,4427,0,3029.503550052643,0.1267689109786184,95000000,16172.33487391472,0.1234977096670641,0.1245415374404735,83274637 -13610.447952508926,0.5820302963256836,3149.548124074936,4600,0,3149.548124074936,0.1265022864000822,95000000,16760.738788843155,0.1242832675599639,0.1241943125130907,83274637 -14095.855922222136,0.6045966148376465,3269.919935464859,4774,0,3269.919935464859,0.1267213037726151,95000000,17366.54714846611,0.1229618146500122,0.1243619039282032,83274637 -14574.3692009449,0.626798152923584,3390.042678594589,4955,0,3390.042678594589,0.1265116089432565,95000000,17965.211741685867,0.1232581764295603,0.124244658249803,83274637 -15036.332559347153,0.6554334163665771,3510.5426092147827,5133,0,3510.5426092147827,0.1266374898334704,95000000,18547.709943056107,0.1219072359022479,0.1242624324508088,83274637 -15505.252483844755,0.6820833683013916,3631.015019178392,5309,0,3631.015019178392,0.1264841745374177,95000000,19137.135103940964,0.1221964419145816,0.1241512517298446,83274637 -15990.066748857498,0.7038779258728027,3751.208186149597,5486,0,3751.208186149597,0.1265668650596217,95000000,19742.1704621315,0.1216318206803041,0.1241694000869676,83274637 -16448.77982020378,0.726233959197998,3871.4612600803375,5662,0,3871.4612600803375,0.1265098257401316,95000000,20321.16526722908,0.1220911874336266,0.124157719105492,83274637 -16902.430883169174,0.7541337013244629,3991.6688067913055,5834,0,3991.6688067913055,0.1264106695106908,95000000,20895.05779361725,0.1214350509873163,0.1240995046858006,83274637 -17359.184082746506,0.7772493362426758,4111.884814977646,6009,0,4111.884814977646,0.1266430611225329,95000000,21472.05645680428,0.1249175532105958,0.1242046093673042,83274637 -17846.472038030624,0.7994480133056641,4232.543785095215,6186,0,4232.543785095215,0.1264988046361019,95000000,22080.032007217407,0.1222002775109601,0.1240868618549655,83274637 -18303.52065062523,0.8268420696258545,4352.96163725853,6359,0,4352.96163725853,0.1264529834498355,95000000,22657.53194761276,0.12460690742628,0.1239930162581626,83274637 -18755.63269805908,0.8494350910186768,4473.102373123169,6534,0,4473.102373123169,0.1264160886821546,95000000,23229.813537836075,0.1224327812193887,0.1240451114243618,83274637 -19244.77996492386,0.8721582889556885,4593.279596328735,6706,0,4593.279596328735,0.1263598801706414,95000000,23839.16663122177,0.1237166764788657,0.1240260471994856,83274637 -19726.35587143898,0.9004464149475098,4713.245164394379,6886,0,4713.245164394379,0.1262856927425986,95000000,24440.743014335632,0.1237623976227247,0.1239329783470528,83274637 -20186.95633172989,0.923396110534668,4833.252229213715,7058,0,4833.252229213715,0.1263129463404605,95000000,25021.379301548004,0.1229029778836283,0.1239070159020458,83274637 -20674.39349412918,0.9461169242858888,4953.703994989395,7230,0,4953.703994989395,0.1263338623766447,95000000,25629.297053337097,0.1207994308041514,0.1239140811382529,83274637 -21162.2273747921,0.9694068431854248,5073.844031572342,7404,0,5073.844031572342,0.1263063353824013,95000000,26237.300297021862,0.1211525737981563,0.1239350570827071,83274637 -21640.932502031326,1.0001626014709473,5194.468368291855,7581,0,5194.468368291855,0.1260931680407072,95000000,26836.66686487198,0.1226242299092078,0.1237928679954235,83274637 -22125.35420370102,1.0256502628326416,5314.780197143555,7755,0,5314.780197143555,0.126230996535773,95000000,27441.43182086945,0.1204891691347525,0.1238589555954023,83274637 -22599.20179247856,1.0498626232147217,5435.426011562347,7931,0,5435.426011562347,0.1261580217927631,95000000,28035.95552253723,0.1226318214416691,0.1237735545950391,83274637 -23071.45604896545,1.0735130310058594,5555.516617774963,8107,0,5555.516617774963,0.1260702185958059,95000000,28628.330218553543,0.1211283176364201,0.1237530625246075,83274637 -23553.10405778885,1.096099615097046,5676.100703716278,8284,0,5676.100703716278,0.1262546892680921,95000000,29230.590974092484,0.1219690176629045,0.1238302158015748,83274637 -24024.937223672867,1.1230800151824951,5796.413721323013,8459,0,5796.413721323013,0.1259658468544408,95000000,29822.770414352417,0.11907991399086497,0.12366733509119425,83274637 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/measurements.csv deleted file mode 100644 index 781fdb67f..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/measurements.csv +++ /dev/null @@ -1,136 +0,0 @@ -global_step,grad_norm,loss,train/loss,validation/loss,validation/num_examples,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,2.1930985,0.26846048,,,,,,,,,,, -1,,,0.2663397727616178,0.2667628934273754,83274637.0,0.2684693901110197,95000000.0,20.74029541015625,805.6157176494598,20.74029541015625,784.8753859996796,0.0,0.0 -100,0.104930095,0.13651466,,,,,,,,,,, -179,,,0.1278894952023929,0.1284199545960081,83274637.0,0.1307766509457237,95000000.0,141.08283019065857,1582.0874617099762,141.08283019065857,1440.9713680744171,0.027174949645996,0.0 -200,0.06072267,0.12368205,,,,,,,,,,, -300,0.014612774,0.13404481,,,,,,,,,,, -359,,,0.1244112567495812,0.1269361273205512,83274637.0,0.1294778536389802,95000000.0,261.6510796546936,2329.977290391922,261.6510796546936,2068.264231443405,0.0493304729461669,0.0 -400,0.006130142,0.13142456,,,,,,,,,,, -500,0.036169864,0.13268113,,,,,,,,,,, -537,,,0.1242823892583449,0.1260890645911322,83274637.0,0.1283980485094572,95000000.0,382.3513696193695,3066.497213602066,382.3513696193695,2684.0567717552185,0.0703155994415283,0.0 -600,0.0085886475,0.11881731,,,,,,,,,,, -700,0.01574716,0.123432875,,,,,,,,,,, -719,,,0.1250295413763455,0.1259149528705128,83274637.0,0.1284853827097039,95000000.0,502.5978763103485,3822.085498809816,502.5978763103485,3319.367733955384,0.0948197841644287,0.0 -800,0.030275773,0.12273781,,,,,,,,,,, -894,,,0.1248043727762294,0.1254994502771275,83274637.0,0.1279691620476973,95000000.0,623.3132708072662,4541.891732931137,623.3132708072662,3918.4293892383575,0.117992877960205,0.0 -900,0.017587513,0.12058121,,,,,,,,,,, -1000,0.00907921,0.13148701,,,,,,,,,,, -1075,,,0.1242811144710336,0.1257489394535915,83274637.0,0.1282013358963815,95000000.0,743.8181428909302,5251.993498563767,743.8181428909302,4507.998911142349,0.1389319896697998,0.0 -1100,0.0041403514,0.117603384,,,,,,,,,,, -1200,0.0046649426,0.12518133,,,,,,,,,,, -1250,,,0.1256232065318515,0.1257673064314982,83274637.0,0.1282463938219572,95000000.0,863.8242483139038,5887.684747219086,863.8242483139038,5023.655802726746,0.1610760688781738,0.0 -1300,0.015407337,0.12389787,,,,,,,,,,, -1400,0.014607073,0.12413393,,,,,,,,,,, -1425,,,0.1253687928565455,0.125561976542487,83274637.0,0.1277889363898026,95000000.0,983.8378643989564,6453.060389280319,983.8378643989564,5468.990014791489,0.182673692703247,0.0 -1500,0.020727182,0.12389517,,,,,,,,,,, -1600,0.021942856,0.12210446,,,,,,,,,,, -1601,,,0.1255340838933703,0.1256407656669859,83274637.0,0.1281505349814967,95000000.0,1103.9650797843933,7003.181186199188,1103.9650797843933,5898.956813812256,0.2033555507659912,0.0 -1700,0.0041196756,0.12786102,,,,,,,,,,, -1777,,,0.1241320770866466,0.1254365534382523,83274637.0,0.1277212985505756,95000000.0,1224.746877908707,7559.535044431686,1224.746877908707,6334.50224852562,0.2240393161773681,0.0 -1800,0.010061002,0.12948005,,,,,,,,,,, -1900,0.018932506,0.122595906,,,,,,,,,,, -1956,,,0.1249245411640255,0.1253922288937747,83274637.0,0.1277315736944901,95000000.0,1345.3115315437317,8110.4038507938385,1345.3115315437317,6764.7794716358185,0.2446043491363525,0.0 -2000,0.019218286,0.1182971,,,,,,,,,,, -2100,0.0058660624,0.12448103,,,,,,,,,,, -2129,,,0.1239133313907392,0.1254291004777774,83274637.0,0.128003693852796,95000000.0,1465.5621354579926,8673.360074996948,1465.5621354579926,7207.456950187683,0.2668092250823974,0.0 -2200,0.0074782167,0.12251457,,,,,,,,,,, -2300,0.008873731,0.1172148,,,,,,,,,,, -2309,,,0.1242583309587255,0.1251300903017596,83274637.0,0.1275379534950657,95000000.0,1585.732696056366,9242.101125717165,1585.732696056366,7656.000282049179,0.287703275680542,0.0 -2400,0.01553479,0.12634256,,,,,,,,,,, -2484,,,0.1252529286082435,0.1249653129058277,83274637.0,0.1272851764288651,95000000.0,1705.8460566997528,9827.39279603958,1705.8460566997528,8121.151620388031,0.3085925579071045,0.0 -2500,0.021867786,0.12869057,,,,,,,,,,, -2600,0.022947427,0.12605636,,,,,,,,,,, -2664,,,0.123065294311294,0.1247274926261555,83274637.0,0.1271264977693256,95000000.0,1826.1175792217248,10410.062943696976,1826.1175792217248,8583.522858858109,0.329925537109375,0.0 -2700,0.024058932,0.122145124,,,,,,,,,,, -2800,0.019129712,0.11952773,,,,,,,,,,, -2837,,,0.1240034895076316,0.1247779347999345,83274637.0,0.1271563959292763,95000000.0,1946.225148439408,10970.657076835632,1946.225148439408,9023.981614589691,0.3516182899475097,0.0 -2900,0.010019595,0.12283029,,,,,,,,,,, -3000,0.004922577,0.12011524,,,,,,,,,,, -3013,,,0.1244736597433967,0.1247752609947488,83274637.0,0.1270789125411184,95000000.0,2066.3307497501373,11530.33465719223,2066.3307497501373,9463.522585868835,0.3764607906341553,0.0 -3100,0.008427466,0.12468648,,,,,,,,,,, -3190,,,0.1233823119352261,0.1250878789937997,83274637.0,0.1276185454050164,95000000.0,2186.563698530197,12108.448246002195,2186.563698530197,9921.37606692314,0.3975801467895508,0.0 -3200,0.015292694,0.12036454,,,,,,,,,,, -3300,0.007746675,0.124304414,,,,,,,,,,, -3363,,,0.1243024425077363,0.1249246158304422,83274637.0,0.1273633506887335,95000000.0,2307.043375492096,12676.013177394869,2307.043375492096,10368.43229842186,0.4201464653015136,0.0 -3400,0.018191785,0.11702757,,,,,,,,,,, -3500,0.009586415,0.12326091,,,,,,,,,,, -3536,,,0.123183436551184,0.1247622584671999,83274637.0,0.127071047265625,95000000.0,2427.7164137363434,13251.632189273834,2427.7164137363434,10823.348504304886,0.4440338611602783,0.0 -3600,0.009343405,0.13316886,,,,,,,,,,, -3700,0.0054497137,0.12513854,,,,,,,,,,, -3713,,,0.1218390933724131,0.1247118196356187,83274637.0,0.1270655221422697,95000000.0,2547.7152211666107,13848.997977733612,2547.7152211666107,11300.688010931017,0.465282678604126,0.0 -3800,0.0071838237,0.12695439,,,,,,,,,,, -3896,,,0.1240636574444156,0.1248220995814915,83274637.0,0.1272068880037006,95000000.0,2668.2507762908936,14434.821523189545,2668.2507762908936,11765.948224782944,0.4866828918457031,0.0 -3900,0.026726417,0.12132137,,,,,,,,,,, -4000,0.009785221,0.12559931,,,,,,,,,,, -4077,,,0.1229546935701707,0.124763756475456,83274637.0,0.1271889521278782,95000000.0,2788.8390402793884,15019.862214803696,2788.8390402793884,12230.369115829468,0.5120937824249268,0.0 -4100,0.0077357087,0.12698506,,,,,,,,,,, -4200,0.011404115,0.12219869,,,,,,,,,,, -4250,,,0.1226657796182534,0.1245887371956699,83274637.0,0.126971111482319,95000000.0,2909.3049216270447,15601.455958843231,2909.3049216270447,12691.46737909317,0.5355861186981201,0.0 -4300,0.015791442,0.13019061,,,,,,,,,,, -4400,0.012395812,0.12255643,,,,,,,,,,, -4427,,,0.1234977096670641,0.1245415374404735,83274637.0,0.1267689109786184,95000000.0,3029.503550052643,16172.33487391472,3029.503550052643,13142.118698596954,0.5581104755401611,0.0 -4500,0.0113351885,0.12226565,,,,,,,,,,, -4600,,,0.1242832675599639,0.1241943125130907,83274637.0,0.1265022864000822,95000000.0,3149.548124074936,16760.738788843155,3149.548124074936,13610.447952508926,0.5820302963256836,0.0 -4600,0.007830577,0.11811897,,,,,,,,,,, -4700,0.004468883,0.1164404,,,,,,,,,,, -4774,,,0.1229618146500122,0.1243619039282032,83274637.0,0.1267213037726151,95000000.0,3269.919935464859,17366.54714846611,3269.919935464859,14095.855922222136,0.6045966148376465,0.0 -4800,0.004924645,0.1343841,,,,,,,,,,, -4900,0.022185467,0.13152313,,,,,,,,,,, -4955,,,0.1232581764295603,0.124244658249803,83274637.0,0.1265116089432565,95000000.0,3390.042678594589,17965.211741685867,3390.042678594589,14574.3692009449,0.626798152923584,0.0 -5000,0.014331179,0.13095464,,,,,,,,,,, -5100,0.009507739,0.120239004,,,,,,,,,,, -5133,,,0.1219072359022479,0.1242624324508088,83274637.0,0.1266374898334704,95000000.0,3510.5426092147827,18547.709943056107,3510.5426092147827,15036.332559347153,0.6554334163665771,0.0 -5200,0.013233239,0.1289348,,,,,,,,,,, -5300,0.013229756,0.12450665,,,,,,,,,,, -5309,,,0.1221964419145816,0.1241512517298446,83274637.0,0.1264841745374177,95000000.0,3631.015019178392,19137.135103940964,3631.015019178392,15505.252483844755,0.6820833683013916,0.0 -5400,0.0059553254,0.121257536,,,,,,,,,,, -5486,,,0.1216318206803041,0.1241694000869676,83274637.0,0.1265668650596217,95000000.0,3751.208186149597,19742.1704621315,3751.208186149597,15990.066748857498,0.7038779258728027,0.0 -5500,0.006546553,0.12862343,,,,,,,,,,, -5600,0.00535587,0.11656903,,,,,,,,,,, -5662,,,0.1220911874336266,0.124157719105492,83274637.0,0.1265098257401316,95000000.0,3871.4612600803375,20321.16526722908,3871.4612600803375,16448.77982020378,0.726233959197998,0.0 -5700,0.006018935,0.12926865,,,,,,,,,,, -5800,0.005689731,0.121659696,,,,,,,,,,, -5834,,,0.1214350509873163,0.1240995046858006,83274637.0,0.1264106695106908,95000000.0,3991.6688067913055,20895.05779361725,3991.6688067913055,16902.430883169174,0.7541337013244629,0.0 -5900,0.009266991,0.12151732,,,,,,,,,,, -6000,0.014719007,0.120574884,,,,,,,,,,, -6009,,,0.1249175532105958,0.1242046093673042,83274637.0,0.1266430611225329,95000000.0,4111.884814977646,21472.05645680428,4111.884814977646,17359.184082746506,0.7772493362426758,0.0 -6100,0.008946834,0.12171655,,,,,,,,,,, -6186,,,0.1222002775109601,0.1240868618549655,83274637.0,0.1264988046361019,95000000.0,4232.543785095215,22080.032007217407,4232.543785095215,17846.472038030624,0.7994480133056641,0.0 -6200,0.00966573,0.11661857,,,,,,,,,,, -6300,0.009883191,0.1316916,,,,,,,,,,, -6359,,,0.12460690742628,0.1239930162581626,83274637.0,0.1264529834498355,95000000.0,4352.96163725853,22657.53194761276,4352.96163725853,18303.52065062523,0.8268420696258545,0.0 -6400,0.005149468,0.119738445,,,,,,,,,,, -6500,0.0059948494,0.12524307,,,,,,,,,,, -6534,,,0.1224327812193887,0.1240451114243618,83274637.0,0.1264160886821546,95000000.0,4473.102373123169,23229.813537836075,4473.102373123169,18755.63269805908,0.8494350910186768,0.0 -6600,0.0069100345,0.12103201,,,,,,,,,,, -6700,0.006991834,0.1279879,,,,,,,,,,, -6706,,,0.1237166764788657,0.1240260471994856,83274637.0,0.1263598801706414,95000000.0,4593.279596328735,23839.16663122177,4593.279596328735,19244.77996492386,0.8721582889556885,0.0 -6800,0.0053291437,0.12091319,,,,,,,,,,, -6886,,,0.1237623976227247,0.1239329783470528,83274637.0,0.1262856927425986,95000000.0,4713.245164394379,24440.743014335632,4713.245164394379,19726.35587143898,0.9004464149475098,0.0 -6900,0.0069548567,0.12198989,,,,,,,,,,, -7000,0.0053467005,0.116236754,,,,,,,,,,, -7058,,,0.1229029778836283,0.1239070159020458,83274637.0,0.1263129463404605,95000000.0,4833.252229213715,25021.379301548004,4833.252229213715,20186.95633172989,0.923396110534668,0.0 -7100,0.0060645747,0.11808313,,,,,,,,,,, -7200,0.006028386,0.113733605,,,,,,,,,,, -7230,,,0.1207994308041514,0.1239140811382529,83274637.0,0.1263338623766447,95000000.0,4953.703994989395,25629.297053337097,4953.703994989395,20674.39349412918,0.9461169242858888,0.0 -7300,0.0053902916,0.13075319,,,,,,,,,,, -7400,0.0055731656,0.12459305,,,,,,,,,,, -7404,,,0.1211525737981563,0.1239350570827071,83274637.0,0.1263063353824013,95000000.0,5073.844031572342,26237.300297021862,5073.844031572342,21162.2273747921,0.9694068431854248,0.0 -7500,0.006399163,0.12136135,,,,,,,,,,, -7581,,,0.1226242299092078,0.1237928679954235,83274637.0,0.1260931680407072,95000000.0,5194.468368291855,26836.66686487198,5194.468368291855,21640.932502031326,1.0001626014709473,0.0 -7600,0.0064975964,0.12738897,,,,,,,,,,, -7700,0.0070490106,0.12407769,,,,,,,,,,, -7755,,,0.1204891691347525,0.1238589555954023,83274637.0,0.126230996535773,95000000.0,5314.780197143555,27441.43182086945,5314.780197143555,22125.35420370102,1.0256502628326416,0.0 -7800,0.006615667,0.12868126,,,,,,,,,,, -7900,0.005217898,0.12460662,,,,,,,,,,, -7931,,,0.1226318214416691,0.1237735545950391,83274637.0,0.1261580217927631,95000000.0,5435.426011562347,28035.95552253723,5435.426011562347,22599.20179247856,1.0498626232147217,0.0 -8000,0.00674094,0.118906036,,,,,,,,,,, -8100,0.0056978674,0.12269659,,,,,,,,,,, -8107,,,0.1211283176364201,0.1237530625246075,83274637.0,0.1260702185958059,95000000.0,5555.516617774963,28628.330218553543,5555.516617774963,23071.45604896545,1.0735130310058594,0.0 -8200,0.010510831,0.11569103,,,,,,,,,,, -8284,,,0.1219690176629045,0.1238302158015748,83274637.0,0.1262546892680921,95000000.0,5676.100703716278,29230.590974092484,5676.100703716278,23553.10405778885,1.096099615097046,0.0 -8300,0.0061451956,0.12271461,,,,,,,,,,, -8400,0.0052087633,0.12407857,,,,,,,,,,, -8459,,,0.1190799139908649,0.1236673350911942,83274637.0,0.1259658468544408,95000000.0,5796.413721323013,29822.770414352417,5796.413721323013,24024.937223672867,1.1230800151824951,0.0 -8459,,,,,,,,5796.413721323013,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/eval_measurements.csv deleted file mode 100644 index dea8b4917..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,20 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/loss,test/num_examples,test/ssim,total_duration,train/loss,train/ssim,validation/loss,validation/num_examples,validation/ssim -203.35107350349423,0.0,54.99857354164124,1,0,54.99857354164124,0.9717777937159662,3581,0.2242922853471446,258.349974155426,0.9647607803344728,0.2082930632999965,0.9715372148899128,3554,0.2018833106732731 -207.78307437896729,0.0291476249694824,135.1870138645172,341,0,135.1870138645172,0.3156862362774015,3581,0.7129807545465652,343.01261258125305,0.2917663369859968,0.7175235067095075,0.3134440699519907,3554,0.6955450999314153 -211.7945501804352,0.0605149269104003,215.2247176170349,591,0,215.2247176170349,0.3052406172551312,3581,0.7222395542446244,427.10220170021057,0.2823376655578613,0.7263099806649345,0.302990090665887,3554,0.7053244638523495 -215.8074471950531,0.0999174118041992,295.54242730140686,842,0,295.54242730140686,0.2997143534234501,3581,0.7286815669942055,511.4811406135559,0.2768257004874093,0.7335096086774554,0.2976826768144168,3554,0.7114909037352279 -219.82274556159973,0.1299653053283691,375.6709289550781,1143,0,375.6709289550781,0.2982516572382889,3581,0.7312676441200083,595.6661460399628,0.274808577128819,0.7372263499668666,0.2963607178421321,3554,0.7141574222047341 -223.84068298339844,0.154393196105957,455.93432664871216,1491,0,455.93432664871216,0.2970154098344562,3581,0.7314674017383412,679.9849674701691,0.2738418238503592,0.7373614992414202,0.2952618445985157,3554,0.7141864800225098 -227.8558185100556,0.1814846992492675,535.9488203525543,1836,0,535.9488203525543,0.3004433664653728,3581,0.727346395254468,764.0546314716339,0.2777445146015712,0.7328919683183942,0.2984948531254396,3554,0.7103634191887662 -231.8789942264557,0.205850601196289,616.008457660675,2180,0,616.008457660675,0.2980956008600077,3581,0.7280031410351508,848.1748449802399,0.2749068907329014,0.7341750008719308,0.2964675379523951,3554,0.7107485898371553 -235.9019057750702,0.2352378368377685,696.0686941146851,2526,0,696.0686941146851,0.2919683597567893,3581,0.7374985137487783,932.3001706600188,0.2693015507289341,0.7426437650408063,0.2903870006726575,3554,0.720475677163935 -239.91925048828125,0.265221357345581,776.175074338913,2872,0,776.175074338913,0.2914311276637985,3581,0.7365833102441707,1016.46715259552,0.2687500715255737,0.7419191087995257,0.2899843472473797,3554,0.7192985264182963 -243.938280582428,0.2893610000610351,856.2194812297821,3217,0,856.2194812297821,0.2907195678472319,3581,0.7374621755881737,1100.5676794052124,0.2680133751460484,0.7428614071437291,0.2891478873936058,3554,0.7201804277530599 -247.95599675178528,0.3130578994750976,936.2317168712616,3561,0,936.2317168712616,0.2898982776939053,3581,0.7390436014294192,1184.6341347694397,0.2670303412846156,0.7445182800292969,0.2884467902583357,3554,0.7218931903313168 -251.97154355049133,0.3453776836395263,1016.3818063735962,3908,0,1016.3818063735962,0.2898596215268081,3581,0.7384366927883272,1268.844962835312,0.2670188290732248,0.7442232540675572,0.2884106225489765,3554,0.7212261657199635 -255.9899456501007,0.3698945045471191,1096.402687072754,4251,0,1096.402687072754,0.2894233249812378,3581,0.7399618728837964,1352.9219012260437,0.2666709763663156,0.7455113274710519,0.2879890781074142,3554,0.7227235707125774 -260.0105800628662,0.3952672481536865,1176.3708062171936,4598,0,1176.3708062171936,0.2893448877321279,3581,0.7400011426408475,1436.948877811432,0.2661698205130441,0.7461780820574079,0.2878877192182224,3554,0.7227209603175999 -264.0284032821655,0.4210107326507568,1256.493564605713,4946,0,1256.493564605713,0.2890736127958321,3581,0.7397029379232407,1521.128168106079,0.2660641840526035,0.7456933430262974,0.2876404186414075,3554,0.7224849943505557 -268.04534125328064,0.4468610286712646,1336.5503425598145,5290,0,1336.5503425598145,0.2906346538152751,3581,0.7390925522724099,1605.240535736084,0.2675383261271885,0.7448530878339495,0.289152009069886,3554,0.7219431313089125 -272.06396198272705,0.4738578796386719,1416.5432102680206,5638,0,1416.5432102680206,0.2894326310955389,3581,0.7393302842912245,1689.2920603752136,0.2661756787981306,0.7454932076590401,0.2879706507797112,3554,0.722028793480937 -276.080983877182,0.5024864673614502,1496.5398621559143,5983,0,1496.5398621559143,0.2891913538903239,3581,0.7408738039086498,1773.3470647335052,0.26584884098597933,0.7469315528869629,0.2877302540106658,3554,0.7237229398213281 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/measurements.csv deleted file mode 100644 index 23bcfe7bd..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/measurements.csv +++ /dev/null @@ -1,81 +0,0 @@ -global_step,grad_norm,loss,train/ssim,train/loss,validation/ssim,validation/loss,validation/num_examples,test/ssim,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,5.3472953,0.9628346,,,,,,,,,,,,,, -1,,,0.2082930632999965,0.9647607803344728,0.2018833106732731,0.9715372148899128,3554.0,0.2242922853471446,0.9717777937159662,3581.0,54.99857354164124,258.349974155426,54.99857354164124,203.35107350349423,0.0,0.0 -100,0.116618454,0.27477154,,,,,,,,,,,,,, -200,0.33860973,0.3046119,,,,,,,,,,,,,, -300,0.13344622,0.3497705,,,,,,,,,,,,,, -341,,,0.7175235067095075,0.2917663369859968,0.6955450999314153,0.3134440699519907,3554.0,0.7129807545465652,0.3156862362774015,3581.0,135.1870138645172,343.01261258125305,135.1870138645172,207.78307437896729,0.0291476249694824,0.0 -400,0.298167,0.33631724,,,,,,,,,,,,,, -500,0.5317048,0.21160781,,,,,,,,,,,,,, -591,,,0.7263099806649345,0.2823376655578613,0.7053244638523495,0.302990090665887,3554.0,0.7222395542446244,0.3052406172551312,3581.0,215.2247176170349,427.10220170021057,215.2247176170349,211.7945501804352,0.0605149269104003,0.0 -600,0.06274106,0.386313,,,,,,,,,,,,,, -700,0.23271002,0.22690618,,,,,,,,,,,,,, -800,0.11965273,0.25279245,,,,,,,,,,,,,, -842,,,0.7335096086774554,0.2768257004874093,0.7114909037352279,0.2976826768144168,3554.0,0.7286815669942055,0.2997143534234501,3581.0,295.54242730140686,511.4811406135559,295.54242730140686,215.8074471950531,0.0999174118041992,0.0 -900,0.29103836,0.30034697,,,,,,,,,,,,,, -1000,0.15712154,0.23177695,,,,,,,,,,,,,, -1100,0.36544657,0.3403817,,,,,,,,,,,,,, -1143,,,0.7372263499668666,0.274808577128819,0.7141574222047341,0.2963607178421321,3554.0,0.7312676441200083,0.2982516572382889,3581.0,375.6709289550781,595.6661460399628,375.6709289550781,219.82274556159973,0.1299653053283691,0.0 -1200,0.08185545,0.24584895,,,,,,,,,,,,,, -1300,0.4537895,0.1977385,,,,,,,,,,,,,, -1400,0.19653787,0.29722852,,,,,,,,,,,,,, -1491,,,0.7373614992414202,0.2738418238503592,0.7141864800225098,0.2952618445985157,3554.0,0.7314674017383412,0.2970154098344562,3581.0,455.93432664871216,679.9849674701691,455.93432664871216,223.84068298339844,0.154393196105957,0.0 -1500,0.289041,0.32327464,,,,,,,,,,,,,, -1600,0.11743245,0.2754063,,,,,,,,,,,,,, -1700,0.111924976,0.40991858,,,,,,,,,,,,,, -1800,0.24554604,0.26904437,,,,,,,,,,,,,, -1836,,,0.7328919683183942,0.2777445146015712,0.7103634191887662,0.2984948531254396,3554.0,0.727346395254468,0.3004433664653728,3581.0,535.9488203525543,764.0546314716339,535.9488203525543,227.8558185100556,0.1814846992492675,0.0 -1900,0.07067584,0.2949571,,,,,,,,,,,,,, -2000,0.088865705,0.32275692,,,,,,,,,,,,,, -2100,0.10906551,0.29049945,,,,,,,,,,,,,, -2180,,,0.7341750008719308,0.2749068907329014,0.7107485898371553,0.2964675379523951,3554.0,0.7280031410351508,0.2980956008600077,3581.0,616.008457660675,848.1748449802399,616.008457660675,231.8789942264557,0.205850601196289,0.0 -2200,0.19706152,0.21189171,,,,,,,,,,,,,, -2300,0.11969999,0.2825845,,,,,,,,,,,,,, -2400,0.16476458,0.24509615,,,,,,,,,,,,,, -2500,0.16823672,0.22409984,,,,,,,,,,,,,, -2526,,,0.7426437650408063,0.2693015507289341,0.720475677163935,0.2903870006726575,3554.0,0.7374985137487783,0.2919683597567893,3581.0,696.0686941146851,932.3001706600188,696.0686941146851,235.9019057750702,0.2352378368377685,0.0 -2600,0.10476116,0.2725075,,,,,,,,,,,,,, -2700,0.092618085,0.28791255,,,,,,,,,,,,,, -2800,0.081270546,0.35725233,,,,,,,,,,,,,, -2872,,,0.7419191087995257,0.2687500715255737,0.7192985264182963,0.2899843472473797,3554.0,0.7365833102441707,0.2914311276637985,3581.0,776.175074338913,1016.46715259552,776.175074338913,239.91925048828125,0.265221357345581,0.0 -2900,0.26875317,0.30162546,,,,,,,,,,,,,, -3000,0.3407403,0.2841944,,,,,,,,,,,,,, -3100,0.12142796,0.21849114,,,,,,,,,,,,,, -3200,0.2514092,0.2778743,,,,,,,,,,,,,, -3217,,,0.7428614071437291,0.2680133751460484,0.7201804277530599,0.2891478873936058,3554.0,0.7374621755881737,0.2907195678472319,3581.0,856.2194812297821,1100.5676794052124,856.2194812297821,243.938280582428,0.2893610000610351,0.0 -3300,0.109726645,0.34568653,,,,,,,,,,,,,, -3400,0.107273,0.32152534,,,,,,,,,,,,,, -3500,0.12563196,0.393957,,,,,,,,,,,,,, -3561,,,0.7445182800292969,0.2670303412846156,0.7218931903313168,0.2884467902583357,3554.0,0.7390436014294192,0.2898982776939053,3581.0,936.2317168712616,1184.6341347694397,936.2317168712616,247.95599675178528,0.3130578994750976,0.0 -3600,0.056832436,0.3053237,,,,,,,,,,,,,, -3700,0.32053238,0.26724526,,,,,,,,,,,,,, -3800,0.1804074,0.30236804,,,,,,,,,,,,,, -3900,0.19929487,0.2603258,,,,,,,,,,,,,, -3908,,,0.7442232540675572,0.2670188290732248,0.7212261657199635,0.2884106225489765,3554.0,0.7384366927883272,0.2898596215268081,3581.0,1016.3818063735962,1268.844962835312,1016.3818063735962,251.97154355049133,0.3453776836395263,0.0 -4000,0.21481428,0.23814705,,,,,,,,,,,,,, -4100,0.06664949,0.24539179,,,,,,,,,,,,,, -4200,0.086545855,0.24147768,,,,,,,,,,,,,, -4251,,,0.7455113274710519,0.2666709763663156,0.7227235707125774,0.2879890781074142,3554.0,0.7399618728837964,0.2894233249812378,3581.0,1096.402687072754,1352.9219012260437,1096.402687072754,255.9899456501007,0.3698945045471191,0.0 -4300,0.20748767,0.403835,,,,,,,,,,,,,, -4400,0.09701758,0.27993175,,,,,,,,,,,,,, -4500,0.033469684,0.21781811,,,,,,,,,,,,,, -4598,,,0.7461780820574079,0.2661698205130441,0.7227209603175999,0.2878877192182224,3554.0,0.7400011426408475,0.2893448877321279,3581.0,1176.3708062171936,1436.948877811432,1176.3708062171936,260.0105800628662,0.3952672481536865,0.0 -4600,0.14665861,0.24650055,,,,,,,,,,,,,, -4700,0.07161481,0.27612004,,,,,,,,,,,,,, -4800,0.1146897,0.26817727,,,,,,,,,,,,,, -4900,0.11013956,0.3150043,,,,,,,,,,,,,, -4946,,,0.7456933430262974,0.2660641840526035,0.7224849943505557,0.2876404186414075,3554.0,0.7397029379232407,0.2890736127958321,3581.0,1256.493564605713,1521.128168106079,1256.493564605713,264.0284032821655,0.4210107326507568,0.0 -5000,0.07522274,0.32662624,,,,,,,,,,,,,, -5100,0.090458475,0.21510835,,,,,,,,,,,,,, -5200,0.10262881,0.24532181,,,,,,,,,,,,,, -5290,,,0.7448530878339495,0.2675383261271885,0.7219431313089125,0.289152009069886,3554.0,0.7390925522724099,0.2906346538152751,3581.0,1336.5503425598145,1605.240535736084,1336.5503425598145,268.04534125328064,0.4468610286712646,0.0 -5300,0.20392582,0.25187206,,,,,,,,,,,,,, -5400,0.047241814,0.23349695,,,,,,,,,,,,,, -5500,0.11768691,0.22285004,,,,,,,,,,,,,, -5600,0.053304326,0.2312496,,,,,,,,,,,,,, -5638,,,0.7454932076590401,0.2661756787981306,0.722028793480937,0.2879706507797112,3554.0,0.7393302842912245,0.2894326310955389,3581.0,1416.5432102680206,1689.2920603752136,1416.5432102680206,272.06396198272705,0.4738578796386719,0.0 -5700,0.0987636,0.27925915,,,,,,,,,,,,,, -5800,0.09851787,0.2714108,,,,,,,,,,,,,, -5900,0.09927917,0.28622764,,,,,,,,,,,,,, -5983,,,0.7469315528869629,0.2658488409859793,0.7237229398213281,0.2877302540106658,3554.0,0.7408738039086498,0.2891913538903239,3581.0,1496.5398621559143,1773.3470647335052,1496.5398621559143,276.080983877182,0.5024864673614502,0.0 -5983,,,,,,,,,,,1496.5398621559143,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 96cb8d41b..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,372 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -36.58779454231262,0.0,49.55485677719116,1,0,49.55485677719116,0.0010999999940395,6.911952018737793,10000,86.1427435874939,0.0008968430920504,6.913041591644287,0.0010999999940395,6.912304878234863,50000 -54.492894887924194,0.0251517295837402,559.6508867740631,1513,0,559.6508867740631,0.1111000031232833,4.882546424865723,10000,614.2222878932953,0.169782355427742,4.28282356262207,0.150419995188713,4.425992488861084,50000 -72.50932359695435,0.0539841651916503,1069.7121422290802,3024,0,1069.7121422290802,0.2336000055074691,3.956745862960816,10000,1142.3826808929443,0.3342833220958709,3.0926454067230225,0.3082599937915802,3.276564359664917,50000 -90.5704882144928,0.0806784629821777,1579.652539730072,4537,0,1579.652539730072,0.3188000023365021,3.3225812911987305,10000,1670.4646260738373,0.453523576259613,2.397743225097656,0.4207599759101867,2.597283363342285,50000 -108.53957986831664,0.1065936088562011,2089.819321870804,6052,0,2089.819321870804,0.3797000050544739,2.925008773803711,10000,2198.679522037506,0.5290178656578064,2.0105631351470947,0.4911199808120727,2.2131128311157227,50000 -126.48164319992064,0.1320304870605468,2600.062970161438,7568,0,2600.062970161438,0.4047000110149383,2.7944133281707764,10000,2726.9429802894592,0.5653499364852905,1.82849645614624,0.5282999873161316,2.031683921813965,50000 -144.2371587753296,0.1595776081085205,3110.2991478443146,9084,0,3110.2991478443146,0.4335000216960907,2.639951467514038,10000,3255.014739274978,0.6339883208274841,1.5135927200317385,0.5559399724006653,1.890720248222351,50000 -162.33961367607117,0.1865520477294922,3620.485506534576,10600,0,3620.485506534576,0.4480000138282776,2.5838935375213623,10000,3783.384105682373,0.6399075388908386,1.445544958114624,0.5717999935150146,1.8170902729034424,50000 -180.4450373649597,0.2277424335479736,4130.400605678558,12115,0,4130.400605678558,0.4598000347614288,2.4967310428619385,10000,4311.498994588852,0.6345862150192261,1.4771161079406738,0.5790799856185913,1.7761107683181765,50000 -199.68504357337952,0.2561922073364258,4640.448156118393,13631,0,4640.448156118393,0.456900030374527,2.5264737606048584,10000,4840.869259595871,0.6399673223495483,1.4503324031829834,0.5899400115013123,1.7300362586975098,50000 -220.5019817352295,0.2849607467651367,5150.681735038757,15148,0,5150.681735038757,0.4661000072956085,2.474865436553955,10000,5372.002974510193,0.6467235088348389,1.4315388202667236,0.5961399674415588,1.6884781122207642,50000 -241.3727121353149,0.3264124393463135,5660.88941025734,16665,0,5660.88941025734,0.4826000332832336,2.3955681324005127,10000,5903.179100036621,0.6599370241165161,1.3714268207550049,0.6070199608802795,1.642867922782898,50000 -262.90595984458923,0.3627340793609619,6171.083922147751,18182,0,6171.083922147751,0.4863000214099884,2.367225885391236,10000,6434.996114492416,0.7046794891357422,1.1431324481964111,0.6042199730873108,1.6550146341323853,50000 -282.60028171539307,0.4062559604644775,6681.065488100052,19698,0,6681.065488100052,0.4772000312805176,2.437773466110229,10000,6964.769221544266,0.6736288070678711,1.300462245941162,0.6003400087356567,1.6817909479141235,50000 -301.9392716884613,0.4372162818908691,7191.288413763046,21215,0,7191.288413763046,0.4891000092029571,2.3720510005950928,10000,7494.415567398071,0.6816206574440002,1.2650530338287354,0.6165199875831604,1.6044803857803345,50000 -321.9015429019928,0.4763326644897461,7701.532709360123,22732,0,7701.532709360123,0.4749000370502472,2.3722574710845947,10000,8024.71547794342,0.6628069281578064,1.3463976383209229,0.6113799810409546,1.631981372833252,50000 -344.1114830970764,0.5074851512908936,8211.788202762604,24249,0,8211.788202762604,0.4928000271320343,2.328645706176758,10000,8557.265342473984,0.6758809089660645,1.2752517461776731,0.6176199913024902,1.597264289855957,50000 -366.8413505554199,0.5325994491577148,8721.796803951263,25765,0,8721.796803951263,0.4853000342845917,2.364208459854126,10000,9090.08283829689,0.6642617583274841,1.3381074666976929,0.6139199733734131,1.6178455352783203,50000 -391.0030782222748,0.5623111724853516,9231.725385665894,27281,0,9231.725385665894,0.4900000095367431,2.3477091789245605,10000,9624.25598335266,0.7101402878761292,1.1184210777282717,0.6214799880981445,1.5755999088287354,50000 -414.5363411903381,0.5903370380401611,9741.918065309525,28798,0,9741.918065309525,0.4940000176429748,2.274938106536865,10000,10158.063725471497,0.6959701776504517,1.1897066831588743,0.6214599609375,1.5625724792480469,50000 -440.34696435928345,0.6227197647094727,10251.934691667557,30315,0,10251.934691667557,0.4736000299453735,2.4069299697875977,10000,10693.976108551024,0.6573461294174194,1.3810793161392212,0.590999960899353,1.707097411155701,50000 -464.6104917526245,0.6552050113677979,10762.190105199814,31832,0,10762.190105199814,0.4913000166416168,2.338101863861084,10000,11228.582449913025,0.6697026491165161,1.3070857524871826,0.6118599772453308,1.6050184965133667,50000 -488.4992530345917,0.6910066604614258,11272.451827764511,33349,0,11272.451827764511,0.4957000315189361,2.31124234199524,10000,11762.822406291962,0.6818000674247742,1.2354713678359983,0.6244399547576904,1.547203540802002,50000 -513.3233184814453,0.7194736003875732,11782.472059488297,34867,0,11782.472059488297,0.4843000173568725,2.4063782691955566,10000,12297.74937415123,0.6680484414100647,1.3182141780853271,0.6115800142288208,1.6301088333129885,50000 -537.8653752803802,0.7547996044158936,12292.574988365172,36384,0,12292.574988365172,0.496500015258789,2.2949490547180176,10000,12832.48405623436,0.7164580821990967,1.0982540845870972,0.6311799883842468,1.5274664163589478,50000 -558.8393158912659,0.7837421894073486,12802.7119576931,37900,0,12802.7119576931,0.5042999982833862,2.2614681720733643,10000,13363.678293943403,0.7069514989852905,1.1378610134124756,0.6352800130844116,1.50535249710083,50000 -581.3548767566681,0.8163936138153076,13312.825773000715,39417,0,13312.825773000715,0.4983000159263611,2.2956364154815674,10000,13896.394639730452,0.6930006146430969,1.2103509902954102,0.6251599788665771,1.5366623401641846,50000 -602.5862288475037,0.8450939655303955,13822.957331418993,40934,0,13822.957331418993,0.5011000037193298,2.2670748233795166,10000,14427.839611768724,0.6949737071990967,1.195453643798828,0.6284799575805664,1.5238358974456787,50000 -623.4006247520447,0.8766939640045166,14333.02351474762,42451,0,14333.02351474762,0.498600035905838,2.3081836700439453,10000,14958.806265830994,0.686543345451355,1.2287561893463137,0.6284199953079224,1.5451955795288086,50000 -641.7515881061554,0.9145431518554688,14842.94656443596,43967,0,14842.94656443596,0.5024000406265259,2.258817195892334,10000,15487.173084020616,0.704500138759613,1.1581501960754397,0.6357600092887878,1.5145957469940186,50000 -659.6599915027618,0.949812650680542,15353.166466712952,45484,0,15353.166466712952,0.5052000284194946,2.242684364318848,10000,16015.389912128448,0.720723032951355,1.0734890699386597,0.6385200023651123,1.4788291454315186,50000 -677.5013723373413,0.9880750179290771,15863.43516588211,47001,0,15863.43516588211,0.5174000263214111,2.193423271179199,10000,16543.592682123184,0.7143255472183228,1.107452392578125,0.6450999975204468,1.4719318151474,50000 -696.4210996627808,1.0227127075195312,16373.61761522293,48518,0,16373.61761522293,0.5217000246047974,2.1961843967437744,10000,17072.783661842346,0.7139468789100647,1.1055635213851929,0.6447799801826477,1.4599003791809082,50000 -714.1883132457733,1.057837963104248,16883.693967342377,50035,0,16883.693967342377,0.5103999972343445,2.2414681911468506,10000,17600.717061519623,0.703523576259613,1.156677484512329,0.6417999863624573,1.461700439453125,50000 -732.4035475254059,1.0941922664642334,17393.82016658783,51552,0,17393.82016658783,0.5249000191688538,2.136894941329956,10000,18129.149336099625,0.7147839665412903,1.111628174781799,0.6556400060653687,1.4142948389053345,50000 -750.0831353664398,1.131983757019043,17903.856494903564,53069,0,17903.856494903564,0.5270000100135803,2.1360621452331543,10000,18656.95786833763,0.7700493931770325,0.8872178196907043,0.6549599766731262,1.4167771339416504,50000 -770.7322072982788,1.1675801277160645,18414.08018231392,54587,0,18414.08018231392,0.5200000405311584,2.1638076305389404,10000,19187.921802043915,0.7238121628761292,1.0641529560089111,0.6489999890327454,1.4466161727905271,50000 -788.9550342559814,1.20243501663208,18924.0583422184,56103,0,18924.0583422184,0.5175000429153442,2.1802403926849365,10000,19716.214436769485,0.7228754758834839,1.0718671083450315,0.6495999693870544,1.4355239868164062,50000 -809.2250037193298,1.2412221431732178,19434.272382974625,57621,0,19434.272382974625,0.5300000309944153,2.1339540481567383,10000,20246.79317426681,0.7300701141357422,1.035422444343567,0.6582599878311157,1.3901373147964478,50000 -826.5554084777832,1.2777154445648191,19944.21973395348,59137,0,19944.21973395348,0.5121999979019165,2.223588466644287,10000,20774.16160988808,0.7110969424247742,1.1127877235412598,0.6474999785423279,1.4427329301834106,50000 -844.3818547725677,1.316425323486328,20454.302026748657,60655,0,20454.302026748657,0.5234000086784363,2.151703119277954,10000,21302.164501667023,0.7175143361091614,1.090225100517273,0.6553599834442139,1.408957600593567,50000 -862.1902160644531,1.356015682220459,20964.30534338951,62173,0,20964.30534338951,0.51910001039505,2.190683364868164,10000,21830.071031332016,0.7569156289100647,0.9074265360832214,0.6525200009346008,1.410874843597412,50000 -881.3823571205139,1.397477388381958,21474.428068876263,63691,0,21474.428068876263,0.5295000076293945,2.1336960792541504,10000,22359.483619451523,0.7365473508834839,1.001426100730896,0.6548999547958374,1.4198815822601318,50000 -899.5424103736877,1.4358513355255127,21984.534370660785,65209,0,21984.534370660785,0.5376999974250793,2.088006019592285,10000,22887.84405231476,0.7339963316917419,1.0098152160644531,0.6584599614143372,1.3867906332015991,50000 -916.836306333542,1.4768211841583252,22494.692160129547,66726,0,22494.692160129547,0.5291000008583069,2.114695310592652,10000,23415.391426324844,0.7247887253761292,1.0592175722122192,0.6513000130653381,1.425256371498108,50000 -934.1671011447906,1.5193002223968506,23004.89267778397,68243,0,23004.89267778397,0.531000018119812,2.121507167816162,10000,23943.01952242852,0.7330994606018066,1.0156515836715698,0.6629999876022339,1.369211196899414,50000 -952.1961376667024,1.5602679252624512,23514.99297809601,69760,0,23514.99297809601,0.5324000120162964,2.080422163009644,10000,24471.2444729805,0.7288743257522583,1.038947582244873,0.6620599627494812,1.3821091651916504,50000 -969.347799539566,1.6115102767944336,24025.17360162735,71277,0,24025.17360162735,0.5273000001907349,2.185938835144043,10000,24998.686302900314,0.7565967440605164,0.9106816649436952,0.6576399803161621,1.408387303352356,50000 -986.60405087471,1.6504340171813965,24535.09615302086,72794,0,24535.09615302086,0.5297999978065491,2.1230790615081787,10000,25525.9612326622,0.7460139989852905,0.9631520509719848,0.6609399914741516,1.3794595003128052,50000 -1004.921329498291,1.6894316673278809,25045.33080291748,74311,0,25045.33080291748,0.5426000356674194,2.07227635383606,10000,26054.60822916031,0.7478475570678711,0.971817135810852,0.6683200001716614,1.3521226644515991,50000 -1023.604052066803,1.732506513595581,25555.53547239304,75829,0,25555.53547239304,0.5388000011444092,2.0896012783050537,10000,26583.593682527546,0.7421077489852905,0.9846024513244628,0.6684799790382385,1.3483023643493652,50000 -1041.087366104126,1.7743992805480957,26065.73763155937,77346,0,26065.73763155937,0.5294000506401062,2.150460720062256,10000,27111.376794815063,0.7206632494926453,1.0757194757461548,0.6527599692344666,1.4330989122390747,50000 -1058.3813076019287,2.6917877197265625,26574.949008226395,78861,0,26574.949008226395,0.5441000461578369,2.095369815826416,10000,27638.854197502136,0.7466318607330322,0.96878182888031,0.6676799654960632,1.3544172048568726,50000 -1075.6019878387451,2.734912633895874,27084.90593457222,80378,0,27084.90593457222,0.5437000393867493,2.056852340698242,10000,28166.12890410424,0.7697305083274841,0.8728119730949402,0.6732000112533569,1.3282780647277832,50000 -1092.8077733516693,2.7800300121307373,27594.912347316746,81894,0,27594.912347316746,0.5425000190734863,2.098299026489258,10000,28693.44144463539,0.748066782951355,0.9496461153030396,0.6683799624443054,1.3535780906677246,50000 -1110.1627497673037,2.821474552154541,28104.9787709713,83411,0,28104.9787709713,0.5421000123023987,2.059842586517334,10000,29220.95869898796,0.7523118257522583,0.942414402961731,0.6723600029945374,1.330456018447876,50000 -1127.4880149364471,2.866657495498657,28615.038668632507,84928,0,28615.038668632507,0.5400000214576721,2.069480657577514,10000,29748.444899082184,0.7465720772743225,0.9657593369483948,0.6717000007629395,1.3397146463394165,50000 -1145.004815340042,2.921449422836304,29125.09669423104,86444,0,29125.09669423104,0.5557000041007996,2.0219106674194336,10000,30276.12886691093,0.7580317258834839,0.906674325466156,0.6821799874305725,1.300117254257202,50000 -1161.9832620620728,2.962131977081299,29635.17948579788,87961,0,29635.17948579788,0.5504000186920166,2.0292537212371826,10000,30803.28497552872,0.7918526530265808,0.7671975493431091,0.6766999959945679,1.3069652318954468,50000 -1179.1324508190155,3.006996393203736,30145.17422771454,89478,0,30145.17422771454,0.5547000169754028,2.012035369873047,10000,31330.530110120773,0.7750119566917419,0.8389298915863037,0.6815599799156189,1.3008902072906494,50000 -1196.418663978577,3.062368392944336,30655.103055000305,90996,0,30655.103055000305,0.5522000193595886,2.0209715366363525,10000,31857.856000185013,0.7703682780265808,0.8547177910804749,0.6854400038719177,1.282153844833374,50000 -1213.9319269657135,3.1073718070983887,31165.160599946976,92513,0,31165.160599946976,0.55840003490448,2.015486001968384,10000,32385.52663731575,0.7651067972183228,0.8746317625045776,0.6829400062561035,1.2891305685043335,50000 -1231.2555334568024,3.1558098793029785,31675.272496938705,94031,0,31675.272496938705,0.5626000165939331,1.9711854457855225,10000,32913.0654964447,0.7631736397743225,0.8782700896263123,0.6861799955368042,1.2679184675216677,50000 -1248.4793949127195,3.199093580245972,32185.48177075386,95548,0,32185.48177075386,0.5527000427246094,2.0591471195220947,10000,33440.59732437134,0.7469507455825806,0.953036606311798,0.6697399616241455,1.3450031280517578,50000 -1265.758617401123,3.2494664192199707,32695.579294204712,97065,0,32695.579294204712,0.5666000247001648,1.9498519897460933,10000,33968.07897615433,0.8135961294174194,0.6818169951438904,0.6930800080299377,1.2403595447540283,50000 -1282.8943948745728,3.295815944671631,33205.586966753006,98581,0,33205.586966753006,0.5488000512123108,2.041500806808472,10000,34495.32242035866,0.7823262214660645,0.8076447248458862,0.6801799535751343,1.2974464893341064,50000 -1299.8828177452087,3.344213485717773,33715.81254816055,100019,0,33715.81254816055,0.560200035572052,1.9564324617385864,10000,35022.6372282505,0.7859135866165161,0.7940120697021484,0.6896599531173706,1.251431226730347,50000 -1317.2859869003296,3.389284610748291,34225.73194885254,101536,0,34225.73194885254,0.5621000528335571,1.9738190174102783,10000,35550.06074118614,0.7773038744926453,0.815521776676178,0.6916799545288086,1.2453638315200806,50000 -1334.4074614048004,3.426310539245605,34735.81096434593,103053,0,34735.81096434593,0.5625,1.970668077468872,10000,36077.352365493774,0.7824656963348389,0.8034330010414124,0.6937800049781799,1.2375240325927734,50000 -1351.678982257843,3.4716062545776367,35245.739324092865,104570,0,35245.739324092865,0.5625,2.004796266555786,10000,36604.654005527496,0.7754902839660645,0.8326342701911926,0.6926199793815613,1.252827763557434,50000 -1369.0937926769257,3.5167410373687744,35755.65688419342,106087,0,35755.65688419342,0.5612000226974487,1.9858746528625488,10000,37132.086525440216,0.8234414458274841,0.6515009999275208,0.6938199996948242,1.2419836521148682,50000 -1386.1959176063538,3.563281774520874,36265.8586742878,107604,0,36265.8586742878,0.5699000358581543,1.948879361152649,10000,37659.49220252037,0.8073580861091614,0.7141561508178711,0.6963199973106384,1.2244770526885986,50000 -1403.2778811454773,3.612172365188599,36775.92725396156,109121,0,36775.92725396156,0.5733000040054321,1.927659273147583,10000,38186.74701857567,0.7997050285339355,0.7437049746513367,0.6982399821281433,1.220030665397644,50000 -1420.5989344120026,3.665136575698853,37285.923748254776,110638,0,37285.923748254776,0.5708000063896179,1.961884617805481,10000,38714.17147755623,0.7970942258834839,0.739464282989502,0.7014200091362,1.2130403518676758,50000 -1437.905579805374,3.713687658309937,37796.09580469132,112156,0,37796.09580469132,0.566100001335144,1.9826350212097168,10000,39241.75362706184,0.7895607352256775,0.7663513422012329,0.6975199580192566,1.2365078926086426,50000 -1455.1328389644625,3.761552095413208,38306.05176615715,113674,0,38306.05176615715,0.570900022983551,1.9535249471664429,10000,39769.040078401566,0.7941246628761292,0.7522966861724854,0.6987999677658081,1.2218286991119385,50000 -1472.192055463791,3.811495065689087,38816.13313245773,115192,0,38816.13313245773,0.5716000199317932,1.930112361907959,10000,40296.28511285782,0.832051157951355,0.6108251214027405,0.700659990310669,1.2239474058151243,50000 -1489.6076259613037,3.858402729034424,39326.16049218178,116710,0,39326.16049218178,0.5719000101089478,1.968600034713745,10000,40823.82935762405,0.8093311190605164,0.6926111578941345,0.7005800008773804,1.2249276638031006,50000 -1506.8389210700989,3.907601833343506,39836.12970209122,118228,0,39836.12970209122,0.5803000330924988,1.9361289739608765,10000,41351.13400554657,0.8143933415412903,0.6682717204093933,0.7066400051116943,1.1898187398910522,50000 -1523.9704895019531,3.954399585723877,40346.30208206177,119746,0,40346.30208206177,0.5770000219345093,1.9084508419036863,10000,41878.540466308594,0.8092713356018066,0.6910978555679321,0.7084599733352661,1.1896438598632812,50000 -1541.1607983112335,4.005815267562866,40856.50037384033,121263,0,40856.50037384033,0.5848000049591064,1.903845191001892,10000,42406.03638720512,0.8119817972183228,0.6661557555198669,0.7090799808502197,1.1816290616989136,50000 -1558.2389419078827,4.051869869232178,41366.60287451744,122780,0,41366.60287451744,0.5860000252723694,1.884270191192627,10000,42933.3178434372,0.8156090378761292,0.6675543189048767,0.7113800048828125,1.1680610179901123,50000 -1575.457461833954,4.100099802017212,41876.761585474014,124297,0,41876.761585474014,0.5910000205039978,1.8822606801986688,10000,43460.797714948654,0.8449258208274841,0.5565405488014221,0.710099995136261,1.1687265634536743,50000 -1592.5177104473114,4.146549224853516,42386.82617163658,125814,0,42386.82617163658,0.5781000256538391,1.9235954284667969,10000,43988.02368068695,0.8313934803009033,0.5971071124076843,0.714139997959137,1.1716444492340088,50000 -1609.7732861042025,4.200785160064697,42896.90095162392,127332,0,42896.90095162392,0.5875000357627869,1.8843411207199097,10000,44515.46164655685,0.8355787396430969,0.5848914980888367,0.7186799645423889,1.14486563205719,50000 -1627.1122126579285,4.968811750411987,43406.36694979668,128849,0,43406.36694979668,0.5900000333786011,1.9271339178085327,10000,45043.08939242363,0.8219467401504517,0.6358543634414673,0.7104799747467041,1.1965482234954834,50000 -1644.6134796142578,5.027925252914429,43916.54048705101,130366,0,43916.54048705101,0.5892000198364258,1.9023849964141848,10000,45570.87974739075,0.8296396732330322,0.599907636642456,0.718559980392456,1.1586591005325315,50000 -1662.0161790847778,5.078296184539795,44426.55810856819,131884,0,44426.55810856819,0.6004000306129456,1.8547765016555784,10000,46098.40629863739,0.8519012928009033,0.5267131328582764,0.7257999777793884,1.124706149101257,50000 -1679.2852370738983,5.126734733581543,44936.45907497406,133400,0,44936.45907497406,0.5948000550270081,1.873829960823059,10000,46625.67883968353,0.8651546239852905,0.4772387146949768,0.7245199680328369,1.1348223686218262,50000 -1696.497545480728,5.18248438835144,45446.67143511772,134918,0,45446.67143511772,0.5969000458717346,1.9024995565414429,10000,47153.21483302117,0.8583386540412903,0.5006170272827148,0.7210999727249146,1.1423414945602417,50000 -1713.8090782165527,5.246311902999878,45956.86107802391,136436,0,45956.86107802391,0.5994000434875488,1.866504430770874,10000,47680.83400058746,0.852937638759613,0.5173604488372803,0.7234799861907959,1.1310100555419922,50000 -1731.1785836219788,5.296331167221069,46466.94680929184,137953,0,46466.94680929184,0.6015000343322754,1.8601411581039429,10000,48208.39493966103,0.8542529940605164,0.5037813782691956,0.7277799844741821,1.1227219104766846,50000 -1748.1684920787811,5.347053289413452,46977.11908912659,139471,0,46977.11908912659,0.5992000102996826,1.8634774684906008,10000,48735.66280722618,0.8598333597183228,0.4867278039455414,0.7307999730110168,1.1039236783981323,50000 -1765.2932143211365,5.399057388305664,47487.33484148979,140988,0,47487.33484148979,0.6034000515937805,1.8561393022537231,10000,49263.10889315605,0.8991350531578064,0.3605747520923614,0.7303000092506409,1.114681601524353,50000 -1782.4610340595243,5.45065188407898,47997.32077693939,142506,0,47997.32077693939,0.6010000109672546,1.8697199821472168,10000,49790.36906766892,0.8854033350944519,0.3915910422801971,0.7321400046348572,1.103589653968811,50000 -1799.7154169082642,5.505875587463379,48507.233662605286,144023,0,48507.233662605286,0.6022000312805176,1.8817957639694207,10000,50317.64517068863,0.8768733739852905,0.4285891652107239,0.7296800017356873,1.1146568059921265,50000 -1816.9821512699127,5.555546760559082,49017.14409446716,145540,0,49017.14409446716,0.6028000116348267,1.8748246431350708,10000,50844.926412820816,0.8807597160339355,0.4134978055953979,0.7312399744987488,1.1134945154190063,50000 -1834.1997547149656,5.6061952114105225,49527.31033730507,147058,0,49527.31033730507,0.6084000468254089,1.8452149629592896,10000,51372.41536259651,0.8866389989852905,0.3935048282146454,0.7372199892997742,1.0860953330993652,50000 -1851.286039590836,5.658465147018433,50037.51057219505,148576,0,50037.51057219505,0.6083000302314758,1.8626391887664795,10000,51899.8086400032,0.886738657951355,0.3839896619319916,0.7354399561882019,1.1088933944702148,50000 -1868.521843194961,5.710692644119263,50547.72637343407,150094,0,50547.72637343407,0.6107000112533569,1.8475301265716555,10000,52427.36676931381,0.9194435477256776,0.2786987125873565,0.7404599785804749,1.088493824005127,50000 -1885.9064099788663,5.763494491577148,51057.75116443634,151612,0,51057.75116443634,0.6117000579833984,1.8623268604278564,10000,52954.8849272728,0.9063894748687744,0.3237541913986206,0.7383399605751038,1.0937960147857666,50000 -1903.927239894867,5.821335554122925,51567.71314716339,153129,0,51567.71314716339,0.6148000359535217,1.8307033777236936,10000,53482.98257446289,0.9104551672935486,0.3100117743015289,0.7428399920463562,1.072890043258667,50000 -1921.3516061306,5.878023862838745,52077.77292633057,154646,0,52077.77292633057,0.6117000579833984,1.857428789138794,10000,54010.57850217819,0.908023715019226,0.3052965700626373,0.7430599927902222,1.0862489938735962,50000 -1938.523027420044,5.9307475090026855,52587.91961193085,156163,0,52587.91961193085,0.616100013256073,1.843456506729126,10000,54538.00402379036,0.9146803021430968,0.2902111411094665,0.744879961013794,1.0719773769378662,50000 -1955.8037884235384,6.003688812255859,53097.9214026928,157680,0,53097.9214026928,0.619100034236908,1.838584065437317,10000,55065.41555118561,0.9210578799247742,0.2730793952941894,0.7454400062561035,1.0773038864135742,50000 -1973.0448276996613,6.057266712188721,53607.94852924347,159197,0,53607.94852924347,0.6219000220298767,1.8522708415985107,10000,55592.79269909859,0.9376793503761292,0.2213052958250045,0.7470600008964539,1.0691070556640625,50000 -1990.2089619636536,6.115020275115967,54118.00081586838,160713,0,54118.00081586838,0.6206000447273254,1.8397870063781736,10000,56120.12428307533,0.9351084232330322,0.2282746881246566,0.747439980506897,1.0624006986618042,50000 -2007.458515882492,6.172126054763794,54627.96944499016,162230,0,54627.96944499016,0.6203000545501709,1.8488606214523315,10000,56647.45462155342,0.9322983026504515,0.2332929223775863,0.7475199699401855,1.0625967979431152,50000 -2024.5153052806847,6.228562831878662,55137.90982818604,163746,0,55137.90982818604,0.6248000264167786,1.8343461751937864,10000,57174.564274311066,0.9357860088348388,0.2209810614585876,0.7492600083351135,1.0612397193908691,50000 -2041.534103155136,6.281407117843628,55647.922043800354,165263,0,55647.922043800354,0.6249000430107117,1.8307875394821167,10000,57701.702865600586,0.9393334984779358,0.211116150021553,0.7497999668121338,1.0586296319961548,50000 -2058.5210251808167,6.337911128997803,56157.90705728531,166780,0,56157.90705728531,0.6290000081062317,1.834994435310364,10000,58228.787316560745,0.9436383843421936,0.2011850476264953,0.751800000667572,1.0534090995788574,50000 -2075.694316625595,6.394445896148682,56668.11461615562,168298,0,56668.11461615562,0.6277000308036804,1.8354393243789675,10000,58756.28093838692,0.9491389989852904,0.1779803782701492,0.752079963684082,1.0596641302108765,50000 -2092.9129543304443,6.455153465270996,57178.27260637283,169815,0,57178.27260637283,0.6274000406265259,1.8301844596862795,10000,59283.77220726013,0.950215220451355,0.1770836114883422,0.7531399726867676,1.052578091621399,50000 -2110.01118683815,6.511188507080078,57688.4286673069,171333,0,57688.4286673069,0.627500057220459,1.8310190439224243,10000,59811.13905739784,0.9529256820678712,0.1717437803745269,0.7522000074386597,1.053192973136902,50000 -2127.100840330124,6.570757627487183,58198.459025383,172850,0,58198.459025383,0.6271000504493713,1.833840489387512,10000,60338.37391376495,0.9522879123687744,0.1757299751043319,0.7527799606323242,1.0505002737045288,50000 -2144.1066920757294,6.630687713623047,58708.66118121147,174368,0,58708.66118121147,0.6290000081062317,1.826395988464356,10000,60865.69796657562,0.9563336968421936,0.1650478243827819,0.7552399635314941,1.0452936887741089,50000 -2161.2147500514984,6.689452409744263,59218.857328653336,175886,0,59218.857328653336,0.631600022315979,1.824371576309204,10000,61393.11685776711,0.9587850570678712,0.1511284112930297,0.7549600005149841,1.0456360578536987,50000 -2178.15860247612,6.749515771865845,59728.900113105774,177404,0,59728.900113105774,0.6291000247001648,1.8281220197677608,10000,61920.21709442139,0.9591039419174194,0.1507154107093811,0.7559399604797363,1.0450918674468994,50000 -2195.6819846630096,7.62337064743042,60238.09110379219,178918,0,60238.09110379219,0.629300057888031,1.8209997415542605,10000,62447.85934686661,0.959004282951355,0.1506407111883163,0.756659984588623,1.040774703025818,50000 -2212.891495943069,7.698556661605835,60748.00309252739,180435,0,60748.00309252739,0.6309000253677368,1.8259391784667969,10000,62975.11134815216,0.9593430757522584,0.148685485124588,0.7560799717903137,1.0443583726882937,50000 -2229.9518847465515,7.7558934688568115,61258.025297403336,181954,0,61258.025297403336,0.6309000253677368,1.8252125978469849,10000,63502.30787181854,0.959203600883484,0.1488100439310073,0.7558799982070923,1.0424973964691162,50000 -2247.0036346912384,7.818108081817627,61767.9468460083,183470,0,61767.9468460083,0.6320000290870667,1.821579337120056,10000,64029.39771032333,0.9610969424247742,0.1481508761644363,0.7571799755096436,1.0403326749801636,50000 -2264.17701625824,7.874583959579468,62278.05148458481,184987,0,62278.05148458481,0.6317000389099121,1.8240126371383667,10000,64556.78928041458,0.959781527519226,0.1489413529634475,0.7575399875640869,1.041106343269348,50000 -2281.1940383911133,7.932438135147095,62787.94189047813,186503,0,62787.94189047813,0.6320000290870667,1.8241097927093504,10000,65083.8095471859,0.9605189561843872,0.145164668560028,0.7573599815368652,1.0415334701538086,50000 -2298.656086206436,7.991012096405029,63297.95154929161,188020,0,63297.95154929161,0.6317000389099121,1.8217973709106443,10000,65611.39433121681,0.9613759517669678,0.1438952833414077,0.7572599649429321,1.0408782958984375,50000 -2315.777137756348,8.041492938995361,63808.01502132416,189537,0,63808.01502132416,0.631100058555603,1.8226947784423828,10000,66138.68200182915,0.962332546710968,0.142117902636528,0.757420003414154,1.0418941974639893,50000 -2332.939744949341,8.108099699020386,64317.99254322052,191055,0,64317.99254322052,0.6321000456809998,1.822113394737244,10000,66665.94344830513,0.961734652519226,0.1454230993986129,0.7573399543762207,1.0415695905685425,50000 -2350.846947908401,8.166432857513428,64828.0078830719,192572,0,64828.0078830719,0.6324000358581543,1.8228449821472168,10000,67193.97899508476,0.959741711616516,0.1477141678333282,0.757319986820221,1.040766954421997,50000 -2367.845665693283,8.226392269134521,65338.185666799545,194089,0,65338.185666799545,0.6312000155448914,1.823163986206055,10000,67721.27209353447,0.9594626426696776,0.1482044607400894,0.7571600079536438,1.0418704748153689,50000 -2385.0194478034973,8.285529613494873,65848.35327005386,195606,0,65848.35327005386,0.6314000487327576,1.822136402130127,10000,68248.72687864304,0.9624919891357422,0.142418086528778,0.7576199769973755,1.0406326055526731,50000 -2401.994375705719,8.346335649490356,66358.55031299591,197123,0,66358.55031299591,0.6328000426292419,1.821480393409729,10000,68776.01489758492,0.9610969424247742,0.1462296843528747,0.7571399807929993,1.039520263671875,50000 -2419.071531057358,8.412399053573608,66868.47186207771,198640,0,66868.47186207771,0.6317000389099121,1.8206541538238523,10000,69303.13593435287,0.9618343114852904,0.1448915749788284,0.7569999694824219,1.041074275970459,50000 -2436.330323457718,8.475186109542847,67378.39338731766,200156,0,67378.39338731766,0.6319000124931335,1.8226794004440308,10000,69830.4341545105,0.960180163383484,0.1485601663589477,0.7572199702262878,1.0412520170211792,50000 -2453.5483391284943,8.535441637039185,67888.32476067543,201673,0,67888.32476067543,0.6309000253677368,1.823739767074585,10000,70357.69950819016,0.9604591727256776,0.1474022418260574,0.7573800086975098,1.0413967370986938,50000 -2471.0616297721863,8.596054077148438,68398.4266242981,203191,0,68398.4266242981,0.631100058555603,1.822848916053772,10000,70885.42971634865,0.9622727632522584,0.1433053612709045,0.7568999528884888,1.0419265031814575,50000 -2487.9953002929688,8.656528234481812,68908.5199649334,204708,0,68908.5199649334,0.6314000487327576,1.8237817287445068,10000,71412.57393550873,0.9606186151504515,0.1448125839233398,0.7568399906158447,1.0413434505462646,50000 -2504.9435436725616,8.719544410705566,69418.53223371506,206225,0,69418.53223371506,0.6313000321388245,1.822572112083435,10000,71939.65333914757,0.9596420526504515,0.1503311842679977,0.7572799921035767,1.041097640991211,50000 -2522.005439043045,8.781482458114624,69928.57495713234,207743,0,69928.57495713234,0.6319000124931335,1.8207805156707764,10000,72466.876288414,0.9592633843421936,0.1481289863586425,0.7573599815368652,1.0401948690414429,50000 -2539.191326379776,8.84352707862854,70438.65213561058,209260,0,70438.65213561058,0.6304000020027161,1.822595238685608,10000,72994.25761318207,0.9608577489852904,0.1483916938304901,0.7569199800491333,1.0417641401290894,50000 -2556.391278028488,8.910342454910278,70948.76120257378,210777,0,70948.76120257378,0.6321000456809998,1.8224416971206665,10000,73521.68989634514,0.9615951776504515,0.1449353992938995,0.7570399641990662,1.041481375694275,50000 -2573.594845056534,8.976428747177124,71458.87047481537,212295,0,71458.87047481537,0.6315000057220459,1.823248147964477,10000,74049.12346696854,0.9602199792861938,0.1497755348682403,0.7570199966430664,1.0418096780776978,50000 -2590.599480867386,9.042627096176147,71968.89846563339,213812,0,71968.89846563339,0.6319000124931335,1.823547601699829,10000,74576.27930808067,0.9613958597183228,0.1449228227138519,0.7572999596595764,1.0410807132720947,50000 -2607.659962415695,9.106708526611328,72479.10156702995,215329,0,72479.10156702995,0.631600022315979,1.822715163230896,10000,75103.66161513329,0.9591438174247742,0.1498548686504364,0.7572999596595764,1.0407342910766602,50000 -2625.0996906757355,9.168915510177612,72989.02272677422,216845,0,72989.02272677422,0.6318000555038452,1.823696732521057,10000,75631.14029312134,0.9608777165412904,0.1457101702690124,0.7572999596595764,1.041617512702942,50000 -2642.055506706238,9.225011825561523,73498.98733758926,218361,0,73498.98733758926,0.631600022315979,1.821586012840271,10000,76158.17280721664,0.960359513759613,0.1462864428758621,0.7573399543762207,1.04054057598114,50000 -2659.2388412952423,9.289936542510986,74009.09196305275,219878,0,74009.09196305275,0.6318000555038452,1.8223992586135864,10000,76685.58260345459,0.9594826102256776,0.1475639790296554,0.7575199604034424,1.040326476097107,50000 -2676.3662524223328,9.35293436050415,74519.10008716583,221396,0,74519.10008716583,0.6323000192642212,1.8241325616836548,10000,77212.83652758598,0.9596819281578064,0.1482045650482177,0.7574999928474426,1.0406891107559204,50000 -2693.648894548416,9.417216300964355,75029.03829264641,222912,0,75029.03829264641,0.6314000487327576,1.822832465171814,10000,77740.1799068451,0.9620137214660645,0.1440362632274627,0.7574999928474426,1.0408563613891602,50000 -2710.5937311649323,9.481563091278076,75538.97178125381,224427,0,75538.97178125381,0.6310000419616699,1.8239521980285645,10000,78267.17767620087,0.9602399468421936,0.1480806469917297,0.7569999694824219,1.0420169830322266,50000 -2727.824291229248,9.551263332366943,76049.03899121284,225944,0,76049.03899121284,0.6314000487327576,1.8221471309661863,10000,78794.60010194778,0.9607979655265808,0.1444706618785858,0.7567799687385559,1.04023540019989,50000 -2744.825801372528,9.621059656143188,76558.91853904724,227460,0,76558.91853904724,0.6312000155448914,1.820837140083313,10000,79321.60742902756,0.9618343114852904,0.1431108713150024,0.7572000026702881,1.040708303451538,50000 -2762.005347251892,10.524330615997314,77068.11691904068,228975,0,77068.11691904068,0.6309000253677368,1.8225457668304443,10000,79848.94569897652,0.9617745280265808,0.1452770978212356,0.7570799589157104,1.039811372756958,50000 -2779.7181718349457,10.597296953201294,77578.0630581379,230492,0,77578.0630581379,0.6309000253677368,1.8216196298599243,10000,80376.73307228088,0.960957407951355,0.1454789787530899,0.7569800019264221,1.040593504905701,50000 -2796.831855535507,10.66297459602356,78088.05710411072,232009,0,78088.05710411072,0.6317000389099121,1.8230193853378296,10000,80903.96130156517,0.96000075340271,0.1493954509496688,0.7570199966430664,1.041401743888855,50000 -2813.9345309734344,10.7299907207489,78597.98800444603,233526,0,78597.98800444603,0.6313000321388245,1.8219051361083984,10000,81431.11747145653,0.9608178734779358,0.1444471329450607,0.7569999694824219,1.0412191152572632,50000 -2830.9884712696075,10.809279441833496,79108.1516199112,235044,0,79108.1516199112,0.6317000389099121,1.8218598365783687,10000,81958.46935725212,0.961694836616516,0.1431909799575805,0.7575399875640869,1.041340708732605,50000 -2848.0664477348328,10.87845540046692,79618.13293385506,236561,0,79618.13293385506,0.6304000020027161,1.8224035501480105,10000,82485.65254235268,0.9612165093421936,0.1462259292602539,0.7571799755096436,1.041209697723389,50000 -2865.148155927658,10.948329448699951,80128.2179043293,238078,0,80128.2179043293,0.6309000253677368,1.8252544403076167,10000,83012.94344353676,0.9614157676696776,0.146426573395729,0.7571600079536438,1.0422078371047974,50000 -2882.1735339164734,11.02028512954712,80638.16833734512,239595,0,80638.16833734512,0.6324000358581543,1.8220387697219849,10000,83540.04543042183,0.9596819281578064,0.1480749398469925,0.7570199966430664,1.0406923294067385,50000 -2899.5223717689514,11.105278968811035,81148.19727015495,241112,0,81148.19727015495,0.6330000162124634,1.8216248750686648,10000,84067.56158614159,0.9610570669174194,0.1477317065000534,0.7570399641990662,1.0410159826278689,50000 -2916.6335439682007,11.176434993743896,81658.35938191414,242629,0,81658.35938191414,0.631100058555603,1.8233078718185425,10000,84594.96130204201,0.9610769748687744,0.1438143998384475,0.7569999694824219,1.0407848358154297,50000 -2933.7407212257385,11.245970726013184,82168.30734395981,244146,0,82168.30734395981,0.6317000389099121,1.8220953941345213,10000,85122.14113402367,0.9615951776504515,0.146777406334877,0.7572999596595764,1.0413265228271484,50000 -2950.919387578964,11.318059921264648,82678.47346019745,245662,0,82678.47346019745,0.631600022315979,1.821165919303894,10000,85649.61295294762,0.959741711616516,0.1484692096710205,0.7572599649429321,1.0395944118499756,50000 -2968.063045740128,11.407188892364502,83188.46709156036,247178,0,83188.46709156036,0.6318000555038452,1.8228518962860107,10000,86176.89586758614,0.9597217440605164,0.149254560470581,0.7569199800491333,1.0409460067749023,50000 -2985.100293636322,11.474862098693848,83698.41069197655,248695,0,83698.41069197655,0.6308000087738037,1.8244123458862305,10000,86704.00005698204,0.9600605964660645,0.1471786648035049,0.7571199536323547,1.0416258573532104,50000 -3002.2465505599976,11.541433811187744,84208.50983929634,250213,0,84208.50983929634,0.6315000057220459,1.821985483169556,10000,87231.36796617508,0.9616350531578064,0.1467591375112533,0.7573399543762207,1.0408724546432495,50000 -3019.4082946777344,11.609270811080933,84718.69186162949,251731,0,84718.69186162949,0.6312000155448914,1.8245079517364504,10000,87758.83411717415,0.960758090019226,0.1473046839237213,0.7570799589157104,1.0428922176361084,50000 -3036.6418414115906,11.676239252090454,85228.6314611435,253248,0,85228.6314611435,0.631600022315979,1.820750117301941,10000,88286.13070821762,0.9601203799247742,0.1469669789075851,0.7571199536323547,1.0400139093399048,50000 -3053.6353392601013,11.7507586479187,85738.69255518913,254765,0,85738.69255518913,0.6317000389099121,1.823580026626587,10000,88813.3157889843,0.9601004123687744,0.1476990282535553,0.7572399973869324,1.041129231452942,50000 -3070.663406610489,11.832293510437012,86248.77793478966,256283,0,86248.77793478966,0.6318000555038452,1.8220078945159912,10000,89340.56578230858,0.960957407951355,0.1457504779100418,0.7570199966430664,1.0411179065704346,50000 -3087.710409402848,11.901975393295288,86758.89575958252,257800,0,86758.89575958252,0.631600022315979,1.8211606740951536,10000,89867.85570526123,0.959741711616516,0.1476667374372482,0.7569800019264221,1.041002869606018,50000 -3104.96017575264,11.974144697189333,87268.98562383652,259317,0,87268.98562383652,0.6313000321388245,1.8210585117340088,10000,90395.32349467278,0.960379421710968,0.1454952508211136,0.7571600079536438,1.0412981510162354,50000 -3122.0596079826355,12.059104919433594,87779.135191679,260833,0,87779.135191679,0.6310000419616699,1.8252462148666384,10000,90922.71321320534,0.9606783986091614,0.146960511803627,0.7566999793052673,1.0429741144180298,50000 -3139.257059574127,12.131905794143677,88289.18821763992,262350,0,88289.18821763992,0.6309000253677368,1.82242488861084,10000,91450.09464883804,0.9612364172935486,0.1476792991161346,0.7573999762535095,1.0416964292526243,50000 -3156.3127546310425,12.205149173736572,88799.24826645851,263866,0,88799.24826645851,0.6320000290870667,1.8234227895736688,10000,91977.33986473083,0.9598214030265808,0.1464400887489318,0.7572999596595764,1.0404446125030518,50000 -3173.3309082984924,12.277857065200806,89309.30315113068,265383,0,89309.30315113068,0.6315000057220459,1.82268226146698,10000,92504.5416522026,0.9618542790412904,0.142778679728508,0.7571399807929993,1.040645956993103,50000 -3191.1166141033173,12.352385759353638,89819.25856900215,266900,0,89819.25856900215,0.6315000057220459,1.821916103363037,10000,93032.41458940506,0.9614357352256776,0.1437123864889145,0.7573399543762207,1.041559815406799,50000 -3208.188056945801,12.414271593093872,90329.31252336502,268417,0,90329.31252336502,0.6318000555038452,1.823684811592102,10000,93559.65687131882,0.9615951776504515,0.1457705348730087,0.7571199536323547,1.0407840013504028,50000 -3225.9346590042114,12.48537278175354,90839.26527810095,269933,0,90839.26527810095,0.631100058555603,1.8229761123657229,10000,94087.48363137244,0.9596420526504515,0.1478172987699508,0.7570199966430664,1.0419752597808838,50000 -3242.923921108246,12.55991506576538,91349.21398591997,271450,0,91349.21398591997,0.6315000057220459,1.824073314666748,10000,94614.55206918716,0.961136758327484,0.1462059915065765,0.7572999596595764,1.0409760475158691,50000 -3260.069388628006,12.63509225845337,91859.17581629752,272967,0,91859.17581629752,0.6314000487327576,1.822332143783569,10000,95141.79025554656,0.9608178734779358,0.1457305401563644,0.7572000026702881,1.041136622428894,50000 -3277.009635448456,12.7144455909729,92369.35075259209,274484,0,92369.35075259209,0.6320000290870667,1.8210731744766235,10000,95669.04021525384,0.961136758327484,0.1444941759109497,0.7572999596595764,1.0405818223953247,50000 -3294.059141635895,12.78620743751526,92879.23972916605,276000,0,92879.23972916605,0.6319000124931335,1.8237287998199463,10000,96196.10747170448,0.9621930718421936,0.1435670554637909,0.7574999928474426,1.0415067672729492,50000 -3310.9949176311493,12.863352537155151,93389.430683136,277518,0,93389.430683136,0.6314000487327576,1.8231300115585327,10000,96723.36788201332,0.9601004123687744,0.149187371134758,0.7570599913597107,1.0416173934936523,50000 -3328.229616165161,12.936316013336182,93899.51947426796,279034,0,93899.51947426796,0.631100058555603,1.8226639032363887,10000,97250.8225440979,0.9613759517669678,0.1479669213294983,0.7570399641990662,1.0405782461166382,50000 -3345.328625202179,13.0116868019104,94409.50527763368,280550,0,94409.50527763368,0.6315000057220459,1.822600483894348,10000,97778.03981542587,0.9606783986091614,0.1443613618612289,0.7573399543762207,1.042070388793945,50000 -3362.324405670166,13.08899974822998,94919.44342923164,282066,0,94919.44342923164,0.6318000555038452,1.8230668306350708,10000,98305.10606789587,0.960957407951355,0.1450634747743606,0.7574599981307983,1.0403265953063965,50000 -3379.3356223106384,13.161826610565186,95429.4895875454,283583,0,95429.4895875454,0.632900059223175,1.823802471160889,10000,98832.29233837128,0.9598811864852904,0.148601621389389,0.7568999528884888,1.0414470434188845,50000 -3396.418870925904,13.239661693572998,95939.8146879673,285101,0,95939.8146879673,0.6318000555038452,1.8214406967163088,10000,99359.83514142036,0.9600605964660645,0.1479718536138534,0.7570399641990662,1.0402021408081057,50000 -3413.324624300003,13.314711093902588,96449.9844045639,286618,0,96449.9844045639,0.6312000155448914,1.822487235069275,10000,99887.04127049446,0.959741711616516,0.1493349075317382,0.7571199536323547,1.040514349937439,50000 -3430.393528699875,13.393523693084717,96959.99516606332,288082,0,96959.99516606332,0.6320000290870667,1.8242065906524656,10000,100414.2544465065,0.961734652519226,0.1440905779600143,0.7569999694824219,1.042284607887268,50000 -3447.3284227848053,13.465181112289429,97469.9808113575,289599,0,97469.9808113575,0.6315000057220459,1.821026086807251,10000,100941.30323529243,0.9604192972183228,0.1508612632751464,0.7572799921035767,1.0402519702911377,50000 -3464.5442264080048,13.540786027908323,97979.8696899414,291116,0,97979.8696899414,0.6319000124931335,1.824628591537476,10000,101468.5384926796,0.9608178734779358,0.1459608823060989,0.7571199536323547,1.0409700870513916,50000 -3481.718322277069,13.615017414093018,98490.01154613496,292633,0,98490.01154613496,0.6313000321388245,1.8246649503707888,10000,101995.98503899574,0.9606385231018066,0.1455745548009872,0.7572000026702881,1.0424156188964844,50000 -3498.651032447815,13.689282655715942,99000.07468652724,294151,0,99000.07468652724,0.6307000517845154,1.823425531387329,10000,102523.11113667488,0.9596819281578064,0.1481228917837143,0.756659984588623,1.041306495666504,50000 -3515.6052017211914,13.764830350875854,99510.17113494872,295669,0,99510.17113494872,0.6312000155448914,1.8231909275054927,10000,103050.29315805437,0.9604591727256776,0.1468521803617477,0.7571399807929993,1.0419366359710691,50000 -3532.513886451721,13.849292993545532,100020.35486221312,297186,0,100020.35486221312,0.6318000555038452,1.8235479593276973,10000,103577.52541542052,0.9598811864852904,0.1468236297369003,0.756659984588623,1.0413254499435425,50000 -3549.5913729667664,13.923123359680176,100530.54751634598,298704,0,100530.54751634598,0.6312000155448914,1.8223090171813965,10000,104104.92567372322,0.9595623016357422,0.14792300760746,0.7567999958992004,1.040809154510498,50000 -3566.6976997852325,14.00077509880066,101040.60511946678,300221,0,101040.60511946678,0.6315000057220459,1.8250768184661863,10000,104632.22343468666,0.962332546710968,0.1443368941545486,0.7571199536323547,1.0422277450561523,50000 -3583.581182718277,14.081845045089722,101550.64585661888,301739,0,101550.64585661888,0.6314000487327576,1.822162628173828,10000,105159.28444576263,0.9604591727256776,0.1479778289794922,0.7568999528884888,1.0407817363739014,50000 -3600.6692349910736,14.160232305526732,102060.69908952712,303256,0,102060.69908952712,0.631600022315979,1.8222219944000244,10000,105686.56081461906,0.9602199792861938,0.1464538872241974,0.7568199634552002,1.0412921905517578,50000 -3617.607753753662,14.239030599594116,102570.83042025566,304773,0,102570.83042025566,0.6315000057220459,1.823304533958435,10000,106213.7646138668,0.96195387840271,0.1408708095550537,0.7572199702262878,1.0406047105789185,50000 -3634.594278812408,14.316243886947632,103080.9529056549,306290,0,103080.9529056549,0.631600022315979,1.824026107788086,10000,106741.00653576852,0.9622727632522584,0.1447004824876785,0.7573800086975098,1.0414577722549438,50000 -3652.3201220035553,14.395713567733765,103590.91629076004,307807,0,103590.91629076004,0.6320000290870667,1.8238807916641235,10000,107268.8312842846,0.9609375,0.1463319510221481,0.7570599913597107,1.0414164066314695,50000 -3669.4485788345337,14.476394414901732,104100.90484285356,309324,0,104100.90484285356,0.6309000253677368,1.8230949640274048,10000,107796.08420968056,0.960160195827484,0.1481978297233581,0.7568399906158447,1.04118549823761,50000 -3686.322530031204,14.564249038696287,104610.92230081558,310841,0,104610.92230081558,0.6326000094413757,1.8219730854034424,10000,108323.1176261902,0.9608178734779358,0.1436031311750412,0.7575199604034424,1.0405359268188477,50000 -3703.3795261383057,14.63990044593811,105120.93588781355,312358,0,105120.93588781355,0.6325000524520874,1.8220093250274656,10000,108850.32050657272,0.9614157676696776,0.1467447876930236,0.7572399973869324,1.0412235260009766,50000 -3720.4144394397736,14.726381301879885,105631.11181282996,313877,0,105631.11181282996,0.6322000026702881,1.8241868019104004,10000,109377.67593812944,0.9605388641357422,0.1448961198329925,0.757420003414154,1.0412286520004272,50000 -3737.495297670365,14.80655813217163,106141.03833842278,315393,0,106141.03833842278,0.631600022315979,1.8226666450500488,10000,109904.81872224808,0.9616350531578064,0.1467013955116272,0.7572599649429321,1.0401941537857056,50000 -3754.913206577301,14.88515329360962,106651.106498003,316910,0,106651.106498003,0.6317000389099121,1.825588345527649,10000,110432.43870258331,0.9612563848495485,0.1465180963277816,0.7571199536323547,1.0433276891708374,50000 -3772.040938138962,14.948791265487673,107161.31473088264,318427,0,107161.31473088264,0.6325000524520874,1.8240481615066528,10000,110959.89461374284,0.960957407951355,0.1460195183753967,0.7572199702262878,1.0418167114257812,50000 -3789.090869903565,15.027590036392212,107671.24874210358,319944,0,107671.24874210358,0.6315000057220459,1.8227121829986568,10000,111487.01357507706,0.96000075340271,0.1463822722434997,0.7570399641990662,1.0409986972808838,50000 -3806.1626420021057,15.110766649246216,108181.3416583538,321461,0,108181.3416583538,0.6323000192642212,1.8240190744400024,10000,112014.31840515137,0.9615951776504515,0.1443527191877365,0.7574599981307983,1.0413386821746826,50000 -3823.353661775589,15.195438861846924,108691.42335653304,322978,0,108691.42335653304,0.6319000124931335,1.8205403089523315,10000,112541.7327632904,0.958804965019226,0.1512446850538253,0.7572799921035767,1.0406066179275513,50000 -3840.5118465423584,15.287701606750488,109201.54509663582,324495,0,109201.54509663582,0.6313000321388245,1.82335364818573,10000,113069.16121602058,0.960180163383484,0.1485398411750793,0.7571799755096436,1.0408989191055298,50000 -3857.648682117462,15.368526697158812,109711.53610014915,326012,0,109711.53610014915,0.631600022315979,1.822590947151184,10000,113596.42569184303,0.960339605808258,0.1465994566679,0.7569999694824219,1.041593074798584,50000 -3874.6263246536255,15.447299480438232,110221.53441381454,327531,0,110221.53441381454,0.6326000094413757,1.823375225067139,10000,114123.5357427597,0.9618144035339355,0.1475189924240112,0.7569400072097778,1.041744828224182,50000 -3891.651116847992,15.529188394546509,110731.65055704115,329048,0,110731.65055704115,0.6314000487327576,1.823715090751648,10000,114650.8146300316,0.9612563848495485,0.1457004994153976,0.7572399973869324,1.0415037870407104,50000 -3908.7601075172415,15.613263130187988,111241.78953266144,330565,0,111241.78953266144,0.6319000124931335,1.822548747062683,10000,115178.20145368576,0.9599011540412904,0.146982803940773,0.7574999928474426,1.0409815311431885,50000 -3925.823562145233,15.69628357887268,111751.74325227736,332082,0,111751.74325227736,0.6318000555038452,1.822547912597656,10000,115705.35889673232,0.959382951259613,0.1487138122320175,0.7572599649429321,1.0413559675216677,50000 -3942.847847461701,15.78191375732422,112261.6212658882,333598,0,112261.6212658882,0.6317000389099121,1.822006106376648,10000,116232.40108180046,0.9609175324440002,0.1459372639656067,0.7568599581718445,1.0412836074829102,50000 -3959.836858987808,15.869224786758425,112771.67519831656,335115,0,112771.67519831656,0.631600022315979,1.8224438428878784,10000,116759.58655571938,0.9601004123687744,0.1463820040225982,0.7570399641990662,1.0413930416107178,50000 -3976.698829650879,15.951818227767944,113281.5459959507,336632,0,113281.5459959507,0.6310000419616699,1.8232208490371704,10000,117286.4574854374,0.9599011540412904,0.1479339897632599,0.757319986820221,1.0409128665924072,50000 -3993.6059629917145,16.03429889678955,113791.70448088646,338150,0,113791.70448088646,0.6317000389099121,1.8228672742843628,10000,117813.6619169712,0.9609375,0.1465308368206024,0.7573999762535095,1.0408369302749634,50000 -4010.716674566269,16.120765447616577,114301.828332901,339667,0,114301.828332901,0.6324000358581543,1.8227611780166624,10000,118341.0404188633,0.9609375,0.1465604901313781,0.757099986076355,1.040977120399475,50000 -4027.736142396927,16.20158553123474,114811.8994588852,341184,0,114811.8994588852,0.631600022315979,1.8225243091583248,10000,118868.26939034462,0.9606783986091614,0.1450160592794418,0.7569199800491333,1.0414011478424072,50000 -4044.732864379883,16.282942533493042,115322.14385986328,342701,0,115322.14385986328,0.6321000456809998,1.822193622589112,10000,119395.64589619637,0.9605787396430968,0.1452729403972625,0.7571799755096436,1.040116548538208,50000 -4061.798500061035,16.369632244110107,115832.13288211824,344218,0,115832.13288211824,0.631600022315979,1.8227174282073968,10000,119922.84298753738,0.9624322056770324,0.1427851915359497,0.7567399740219116,1.042166829109192,50000 -4078.7688570022574,16.454465627670288,116342.17555689812,345735,0,116342.17555689812,0.6322000026702881,1.822301864624024,10000,120449.99757409096,0.9613958597183228,0.1466252952814102,0.757099986076355,1.0420275926589966,50000 -4096.516710281372,16.538807153701782,116852.03779673576,347252,0,116852.03779673576,0.6323000192642212,1.82083797454834,10000,120977.74806761742,0.9598612785339355,0.1464956998825073,0.7571199536323547,1.039595127105713,50000 -4113.484414339066,16.631054878234863,117362.03670334816,348768,0,117362.03670334816,0.631600022315979,1.823034405708313,10000,121504.86248326302,0.9602598547935486,0.1485514193773269,0.7569199800491333,1.0419752597808838,50000 -4130.390887975693,16.715901851654053,117872.1169886589,350286,0,117872.1169886589,0.6313000321388245,1.8241838216781616,10000,122031.98891568184,0.961734652519226,0.1429440379142761,0.7569400072097778,1.0426743030548096,50000 -4147.371191740036,16.80562400817871,118382.07088541985,351803,0,118382.07088541985,0.6314000487327576,1.822781562805176,10000,122559.0689971447,0.9614955186843872,0.1441121399402618,0.7574399709701538,1.0409822463989258,50000 -4164.485343694687,16.895461559295654,118892.1052172184,353320,0,118892.1052172184,0.6314000487327576,1.821877241134644,10000,123086.36362028122,0.9614556431770324,0.1467177718877792,0.7574399709701538,1.0406513214111328,50000 -4181.410618782044,16.984424591064453,119402.13184118272,354837,0,119402.13184118272,0.6315000057220459,1.8235318660736084,10000,123613.46213245392,0.9608577489852904,0.146235704421997,0.7569599747657776,1.042207956314087,50000 -4198.244757652283,17.072447061538696,119912.02209234238,356354,0,119912.02209234238,0.6325000524520874,1.8233606815338133,10000,124140.33087706566,0.9602997303009032,0.1481306254863739,0.7575599551200867,1.041216492652893,50000 -4215.293922185898,17.160407543182373,120421.88165020944,357871,0,120421.88165020944,0.6318000555038452,1.8204973936080933,10000,124667.38286662102,0.9605388641357422,0.1478061825037002,0.7572000026702881,1.0402666330337524,50000 -4232.358246326447,17.2476224899292,120931.80952954292,359388,0,120931.80952954292,0.6321000456809998,1.822807669639588,10000,125194.51679444312,0.9615951776504515,0.1428749710321426,0.7572000026702881,1.0410480499267578,50000 -4249.258747339249,17.337990045547485,121441.90519046783,360905,0,121441.90519046783,0.6319000124931335,1.8209383487701416,10000,125721.65953350069,0.9600805044174194,0.1505532264709472,0.7570799589157104,1.0411326885223389,50000 -4266.226603746414,17.42392110824585,121951.9643895626,362421,0,121951.9643895626,0.6318000555038452,1.8226184844970703,10000,126248.82946372032,0.9601004123687744,0.1472820490598678,0.7572000026702881,1.0414822101593018,50000 -4283.07874751091,17.518263339996338,122462.1295595169,363938,0,122462.1295595169,0.6324000358581543,1.823939323425293,10000,126775.99683475494,0.9604990482330322,0.1480488926172256,0.7569999694824219,1.0426230430603027,50000 -4299.988889694214,17.606799125671387,122972.127712965,365455,0,122972.127712965,0.6309000253677368,1.822780728340149,10000,127303.0500793457,0.9614756107330322,0.1442434042692184,0.7568599581718445,1.0414214134216309,50000 -4316.926568746567,17.68900442123413,123482.26363253592,366972,0,123482.26363253592,0.6313000321388245,1.823107481002808,10000,127830.26229286194,0.960758090019226,0.1481673419475555,0.7574399709701538,1.0411081314086914,50000 -4333.906269550324,17.783448457717896,123992.4189324379,368489,0,123992.4189324379,0.631600022315979,1.8244590759277344,10000,128357.54884719849,0.9606983065605164,0.1471341699361801,0.7573999762535095,1.0416234731674194,50000 -4351.035741329193,17.875396251678467,124502.4652018547,370006,0,124502.4652018547,0.6321000456809998,1.8226678371429443,10000,128884.87219500542,0.960359513759613,0.148869127035141,0.7570399641990662,1.040926814079285,50000 -4367.901885032654,17.96213459968567,125012.54879522324,371523,0,125012.54879522324,0.6314000487327576,1.8242459297180176,10000,129411.96352028848,0.9600406289100648,0.1471908539533615,0.7571399807929993,1.0407812595367432,50000 -4384.982523202896,18.056538581848145,125522.60182905196,373039,0,125522.60182905196,0.6312000155448914,1.8227859735488887,10000,129939.24783706664,0.9602399468421936,0.1467160880565643,0.7579599618911743,1.0410587787628174,50000 -4401.814165115356,18.15566897392273,126032.49347376823,374556,0,126032.49347376823,0.6315000057220459,1.823155164718628,10000,130466.12568640707,0.9602000713348388,0.1454567462205886,0.7567200064659119,1.041068434715271,50000 -4418.809207677841,18.24462342262268,126542.64698553084,376073,0,126542.64698553084,0.6321000456809998,1.8223358392715447,10000,130993.41924715042,0.9597616195678712,0.1479842811822891,0.7572000026702881,1.0399235486984253,50000 -4435.703264951706,18.334900856018063,127052.61329627036,377591,0,127052.61329627036,0.631600022315979,1.8216814994812007,10000,131520.42503738403,0.9620535373687744,0.1455946564674377,0.7574799656867981,1.0399705171585083,50000 -4452.67462849617,18.42171573638916,127562.5142903328,379106,0,127562.5142903328,0.6313000321388245,1.822556495666504,10000,132047.43829274178,0.9601004123687744,0.1481819748878479,0.756879985332489,1.0412617921829224,50000 -4469.535237550736,18.51426196098328,128072.50441098212,380623,0,128072.50441098212,0.6308000087738037,1.8233776092529297,10000,132574.43628549576,0.960758090019226,0.1461323052644729,0.7569599747657776,1.0406873226165771,50000 -4486.295618534088,18.60328149795532,128582.37007761002,382139,0,128582.37007761002,0.6324000358581543,1.8214529752731323,10000,133101.20618534088,0.9616549611091614,0.1419612616300583,0.7573399543762207,1.040238618850708,50000 -4503.450261116028,18.69795846939087,129092.24616885184,383655,0,129092.24616885184,0.6321000456809998,1.824738621711731,10000,133628.38803076744,0.962511956691742,0.1423533260822296,0.7570799589157104,1.0416830778121948,50000 -4520.90341758728,18.79100012779236,129602.22697257996,385172,0,129602.22697257996,0.6325000524520874,1.8220690488815308,10000,134155.97039675713,0.9605189561843872,0.1476364582777023,0.757319986820221,1.0413920879364014,50000 -4537.72211432457,18.88157558441162,130112.18139767648,386689,0,130112.18139767648,0.6318000555038452,1.8210442066192627,10000,134682.88991069794,0.9601402878761292,0.1475104093551635,0.7570799589157104,1.0403294563293457,50000 -4554.750309705734,18.971620559692383,130622.19213318823,388206,0,130622.19213318823,0.6315000057220459,1.8237435817718504,10000,135210.0751516819,0.9606385231018066,0.1463115364313125,0.7573999762535095,1.0419403314590454,50000 -4571.768753767014,19.06741619110108,131132.10186076164,389723,0,131132.10186076164,0.6317000389099121,1.8231186866760247,10000,135737.15456604958,0.9618542790412904,0.1434044390916824,0.7575399875640869,1.0409173965454102,50000 -4588.780303239822,19.16048574447632,131642.28998470306,391240,0,131642.28998470306,0.6317000389099121,1.8221389055252075,10000,136264.5047211647,0.961734652519226,0.1444765478372573,0.7570799589157104,1.0409315824508667,50000 -4605.751079559326,19.26059985160828,132152.3566787243,392757,0,132152.3566787243,0.6325000524520874,1.8216640949249268,10000,136791.69770216942,0.960558831691742,0.1481156349182129,0.7575199604034424,1.040520191192627,50000 -4622.717230796814,19.35450482368469,132662.2222611904,394274,0,132662.2222611904,0.6319000124931335,1.8207794427871704,10000,137318.67782998085,0.9601203799247742,0.1482566446065902,0.7576000094413757,1.0401118993759155,50000 -4639.747121095657,19.45023393630981,133172.32787752151,395792,0,133172.32787752151,0.6313000321388245,1.8253854513168333,10000,137845.96401286125,0.9600805044174194,0.1480380147695541,0.7572199702262878,1.0424708127975464,50000 -4656.741573810577,19.547237634658813,133682.22514748573,397308,0,133682.22514748573,0.6315000057220459,1.821471929550171,10000,138373.00830101967,0.9609972834587096,0.1444547176361084,0.7571600079536438,1.039499282836914,50000 -4673.543560028076,19.645186185836792,134192.11761021614,398824,0,134192.11761021614,0.6321000456809998,1.8227674961090088,10000,138899.8560230732,0.9614756107330322,0.1464393436908722,0.757099986076355,1.0412877798080444,50000 -4690.554224252701,19.73924994468689,134702.17586684227,400342,0,134702.17586684227,0.6321000456809998,1.8230339288711548,10000,139427.07520484924,0.9597616195678712,0.146850436925888,0.7565400004386902,1.041675686836243,50000 -4707.534991264343,19.83228182792664,135212.1051683426,401858,0,135212.1051683426,0.6323000192642212,1.822576642036438,10000,139954.13414263725,0.9598811864852904,0.1492884010076522,0.7573800086975098,1.0411863327026367,50000 -4724.468670129776,19.93061709403992,135722.01432061195,403374,0,135722.01432061195,0.6320000290870667,1.822820782661438,10000,140481.13131904602,0.9605787396430968,0.1461165696382522,0.7571799755096436,1.0413422584533691,50000 -4741.439030885696,20.02464318275452,136231.96271348,404890,0,136231.96271348,0.6312000155448914,1.824077844619751,10000,141008.20151233673,0.9606983065605164,0.1485510170459747,0.7574799656867981,1.041606068611145,50000 -4758.309185504913,20.118609189987183,136741.83385276794,406407,0,136741.83385276794,0.6323000192642212,1.821937799453736,10000,141535.09117531776,0.9603196382522584,0.1486114412546157,0.75764000415802,1.0409525632858276,50000 -4775.223704099655,20.21147418022156,137251.82893824577,407924,0,137251.82893824577,0.6314000487327576,1.821706771850586,10000,142062.15008211136,0.960957407951355,0.1447140276432037,0.7571399807929993,1.0417368412017822,50000 -4792.087728738785,20.30940723419189,137761.7340619564,409441,0,137761.7340619564,0.6319000124931335,1.8224622011184688,10000,142589.07264232635,0.958765149116516,0.1495168954133987,0.7569999694824219,1.041053056716919,50000 -4808.927973031998,20.41361141204834,138271.87314987183,410958,0,138271.87314987183,0.6310000419616699,1.821075677871704,10000,143116.21062779427,0.9617944359779358,0.144712746143341,0.7569199800491333,1.0410289764404297,50000 -4825.856830596924,20.50609040260315,138781.94454431534,412475,0,138781.94454431534,0.6312000155448914,1.822407841682434,10000,143643.35971355438,0.958984375,0.1489603668451309,0.7571799755096436,1.041362762451172,50000 -4842.787178993225,20.600675106048584,139291.81418538094,413991,0,139291.81418538094,0.6321000456809998,1.8215856552124023,10000,144170.310328722,0.960339605808258,0.1461835503578186,0.7570799589157104,1.0407791137695312,50000 -4859.819278478622,20.697286128997803,139801.6758582592,415507,0,139801.6758582592,0.6307000517845154,1.823670744895935,10000,144697.35668969154,0.96097731590271,0.1466158181428909,0.7567999958992004,1.041944146156311,50000 -4876.771936416626,20.78969407081604,140311.5811650753,417024,0,140311.5811650753,0.6310000419616699,1.8238797187805176,10000,145224.3636994362,0.9611766338348388,0.1452847421169281,0.756879985332489,1.0428640842437744,50000 -4893.701037406921,20.88233780860901,140821.61366438866,418541,0,140821.61366438866,0.6319000124931335,1.8214082717895508,10000,145751.47378349304,0.9610371589660645,0.1459392458200454,0.7571600079536438,1.040714144706726,50000 -4910.61313700676,20.978939056396484,141331.67381334305,420058,0,141331.67381334305,0.6314000487327576,1.8214672803878784,10000,146278.59851312637,0.9612962007522584,0.1448786854743957,0.7575799822807312,1.0407829284667969,50000 -4927.719024181366,21.077617645263672,141841.59809207916,421575,0,141841.59809207916,0.6322000026702881,1.821758270263672,10000,146805.78198504448,0.9612364172935486,0.1430351883172989,0.7572599649429321,1.0402424335479736,50000 -4944.911732673645,21.177175760269165,142351.5038971901,423092,0,142351.5038971901,0.6313000321388245,1.8226169347763064,10000,147333.03734254837,0.9618741869926452,0.144457459449768,0.7572999596595764,1.040998458862305,50000 -4962.483725547791,21.25806474685669,142861.59340810776,424609,0,142861.59340810776,0.6313000321388245,1.8224475383758545,10000,147860.83490467072,0.9593032598495485,0.1488636881113052,0.7572599649429321,1.040607452392578,50000 -4979.48906993866,21.36016345024109,143371.71987581253,426126,0,143371.71987581253,0.631100058555603,1.8229377269744875,10000,148388.12489271164,0.9600605964660645,0.1482033878564834,0.7573599815368652,1.040932536125183,50000 -4996.394257545471,21.45875644683838,143881.7842924595,427644,0,143881.7842924595,0.631600022315979,1.820892095565796,10000,148915.24838876724,0.9620934128761292,0.1434959471225738,0.7572000026702881,1.0401928424835205,50000 -5013.312318086624,21.560152769088745,144391.64415311813,429160,0,144391.64415311813,0.6306000351905823,1.8239283561706543,10000,149442.18219542503,0.9617147445678712,0.1429807841777801,0.7571600079536438,1.041298508644104,50000 -5030.17825627327,21.67906594276428,144901.60076332092,430677,0,144901.60076332092,0.6312000155448914,1.8239145278930664,10000,149969.17936348915,0.9617745280265808,0.1452418863773346,0.757420003414154,1.0418684482574463,50000 -5047.105713605881,21.780495166778564,145411.59765625,432193,0,145411.59765625,0.6321000456809998,1.8241660594940183,10000,150496.2617239952,0.9602000713348388,0.1488232314586639,0.7574999928474426,1.041712999343872,50000 -5063.910180091858,21.878495693206787,145921.48075079918,433709,0,145921.48075079918,0.6320000290870667,1.823303461074829,10000,151023.10500741005,0.9606783986091614,0.1471301168203354,0.7573399543762207,1.041322112083435,50000 -5080.642971515656,21.97658133506775,146431.42701363564,435226,0,146431.42701363564,0.631600022315979,1.8207440376281736,10000,151549.93702721596,0.9610171914100648,0.1463170796632766,0.757099986076355,1.0406198501586914,50000 -5097.61579823494,22.0786395072937,146941.59220194817,436744,0,146941.59220194817,0.6321000456809998,1.82472825050354,10000,152077.231808424,0.9606584906578064,0.1450043320655822,0.75764000415802,1.0415468215942385,50000 -5114.640841245651,22.17793607711792,147451.58518648148,438261,0,147451.58518648148,0.6318000555038452,1.82334578037262,10000,152604.4050400257,0.9612165093421936,0.146070510149002,0.7572000026702881,1.0408618450164795,50000 -5131.485112667084,22.27350926399231,147961.7461400032,439778,0,147961.7461400032,0.6314000487327576,1.8219047784805296,10000,153131.56278252602,0.9588249325752258,0.1511629223823547,0.7573399543762207,1.0402222871780396,50000 -5148.434104681015,22.382426500320435,148471.76654148102,441295,0,148471.76654148102,0.6318000555038452,1.8222293853759768,10000,153658.69793319702,0.9607381820678712,0.1461882889270782,0.7571399807929993,1.042002558708191,50000 -5165.56057715416,22.4826500415802,148981.70437383652,442811,0,148981.70437383652,0.6321000456809998,1.820679783821106,10000,154185.9196381569,0.9609175324440002,0.1468439698219299,0.7570799589157104,1.040229320526123,50000 -5182.874908208847,22.5871741771698,149491.8409090042,444328,0,149491.8409090042,0.6317000389099121,1.8212313652038568,10000,154713.53189253807,0.9606186151504515,0.1501710265874862,0.7574399709701538,1.0412341356277466,50000 -5199.798217058182,22.6678466796875,150001.83459067345,445845,0,150001.83459067345,0.631100058555603,1.822283148765564,10000,155240.5869767666,0.9606983065605164,0.146799087524414,0.7567799687385559,1.0411818027496338,50000 -5216.711742401123,22.770523071289062,150511.73395705223,447362,0,150511.73395705223,0.6319000124931335,1.8231927156448364,10000,155767.55983424187,0.9602798223495485,0.1471463590860366,0.7571199536323547,1.0414865016937256,50000 -5233.558504581451,22.870090007781982,151021.81368637085,448879,0,151021.81368637085,0.6320000290870667,1.8224512338638303,10000,156294.64122962952,0.9606186151504515,0.1455328613519668,0.7567200064659119,1.0406856536865234,50000 -5250.33748459816,22.97496485710144,151531.68371200562,450395,0,151531.68371200562,0.6322000026702881,1.824686050415039,10000,156821.4516866207,0.9595423936843872,0.1479770094156265,0.7572000026702881,1.0416878461837769,50000 -5267.233897209168,23.07772135734558,152041.5577995777,451911,0,152041.5577995777,0.6318000555038452,1.8233704566955569,10000,157348.38074493408,0.959980845451355,0.1469773203134536,0.7573599815368652,1.0414307117462158,50000 -5284.3033618927,23.179197311401367,152551.6747918129,453428,0,152551.6747918129,0.632900059223175,1.82382333278656,10000,157875.72356677055,0.9600406289100648,0.1466209590435028,0.7575199604034424,1.0417447090148926,50000 -5301.271278142929,23.28038787841797,153061.76563978195,454945,0,153061.76563978195,0.6317000389099121,1.823285818099976,10000,158402.9397919178,0.9618343114852904,0.1466174721717834,0.7567399740219116,1.0415546894073486,50000 -5318.150739192963,23.382413625717163,153571.90072655678,456463,0,153571.90072655678,0.6308000087738037,1.8218352794647217,10000,158930.11351394653,0.96097731590271,0.1465593427419662,0.7573599815368652,1.0404845476150513,50000 -5335.127779006958,23.484861850738525,154081.8287653923,457980,0,154081.8287653923,0.6302000284194946,1.8235985040664675,10000,159457.1767117977,0.9604591727256776,0.1468010693788528,0.7571799755096436,1.0417757034301758,50000 -5352.105293512344,23.593278408050537,154591.9551639557,459497,0,154591.9551639557,0.6321000456809998,1.8234221935272217,10000,159984.44529652596,0.9614556431770324,0.1402759253978729,0.7569999694824219,1.0417706966400146,50000 -5368.956280946732,23.694770097732544,155101.99497246742,461014,0,155101.99497246742,0.6315000057220459,1.8218690156936648,10000,160511.493765831,0.962113320827484,0.1443096399307251,0.7572799921035767,1.0408600568771362,50000 -5386.340245485306,23.79537916183472,155611.86061382294,462531,0,155611.86061382294,0.6313000321388245,1.8227357864379885,10000,161038.8993780613,0.9612165093421936,0.1455664485692978,0.7572000026702881,1.0417051315307615,50000 -5403.175386667252,23.90141129493713,156121.84329080582,464047,0,156121.84329080582,0.6318000555038452,1.824604034423828,10000,161565.87942004204,0.9599210619926452,0.1485941559076309,0.757099986076355,1.0420114994049072,50000 -5420.25766825676,23.998153924942017,156631.992497921,465565,0,156631.992497921,0.6313000321388245,1.8224778175354004,10000,162093.26338124275,0.960359513759613,0.1463554352521896,0.7572799921035767,1.0423952341079712,50000 -5437.138303518295,24.08114218711853,157142.13254094124,467082,0,157142.13254094124,0.6315000057220459,1.8228623867034912,10000,162620.42335128784,0.9616350531578064,0.1432344317436218,0.7572799921035767,1.0408964157104492,50000 -5453.913333177567,24.18593978881836,157652.25593709946,468600,0,157652.25593709946,0.6325000524520874,1.8226418495178225,10000,163147.48210787773,0.9614157676696776,0.146223098039627,0.7571600079536438,1.04023540019989,50000 -5470.884748220444,24.29122257232666,158162.1717247963,470116,0,158162.1717247963,0.6308000087738037,1.8218297958374023,10000,163674.5309638977,0.9616549611091614,0.1469157487154007,0.7570599913597107,1.0405776500701904,50000 -5487.740997314453,24.3963086605072,158672.21157503128,471633,0,158672.21157503128,0.6321000456809998,1.8239001035690308,10000,164201.5887157917,0.960379421710968,0.1477285623550415,0.7573399543762207,1.0409228801727295,50000 -5504.653124094009,24.502280473709103,159182.24066758156,473150,0,159182.24066758156,0.6315000057220459,1.8239738941192627,10000,164728.6911456585,0.9608178734779358,0.145049899816513,0.7572599649429321,1.0415680408477783,50000 -5521.564659357071,24.605119466781616,159692.33936095238,474667,0,159692.33936095238,0.6312000155448914,1.8208348751068115,10000,165255.85982394218,0.9606385231018066,0.1474056243896484,0.7566199898719788,1.0405768156051636,50000 -5538.428807735443,24.70886778831482,160202.31282019615,476184,0,160202.31282019615,0.6312000155448914,1.823325634002685,10000,165782.85601115227,0.9614556431770324,0.1444203853607177,0.7572399973869324,1.0410040616989136,50000 -5555.441893815994,24.811065435409542,160712.27799630165,477701,0,160712.27799630165,0.6321000456809998,1.8233537673950195,10000,166309.99308776855,0.9598811864852904,0.148447573184967,0.7573399543762207,1.0417107343673706,50000 -5572.439893722534,24.91554069519043,161222.3278567791,479218,0,161222.3278567791,0.631600022315979,1.8240758180618288,10000,166837.20202589035,0.9594228267669678,0.1483974158763885,0.7570399641990662,1.0413779020309448,50000 -5589.2941699028015,25.02281379699707,161732.29912495613,480735,0,161732.29912495613,0.6304000020027161,1.824426889419556,10000,167364.19173169136,0.961136758327484,0.1473008692264557,0.7571600079536438,1.0423572063446045,50000 -5606.119745254517,25.12814474105835,162242.1644639969,482252,0,162242.1644639969,0.6324000358581543,1.8233898878097528,10000,167891.04339289665,0.961355984210968,0.1467293053865432,0.7570799589157104,1.0419999361038208,50000 -5623.075105905533,25.23688054084778,162752.32794451714,483769,0,162752.32794451714,0.6326000094413757,1.8239266872406008,10000,168418.32730579376,0.960160195827484,0.1499893218278885,0.7566999793052673,1.0416303873062134,50000 -5639.970458984375,25.34450364112854,163262.20376372337,485286,0,163262.20376372337,0.6328000426292419,1.823925495147705,10000,168945.26424503326,0.9605787396430968,0.1452038437128067,0.7571199536323547,1.0410614013671875,50000 -5656.945882558823,25.46109294891357,163772.209897995,486803,0,163772.209897995,0.6305000185966492,1.8246389627456665,10000,169472.41901779175,0.9608777165412904,0.1467392593622207,0.7570199966430664,1.04235577583313,50000 -5673.986569881439,25.623329162597656,164282.22924995422,488319,0,164282.22924995422,0.6313000321388245,1.822999119758606,10000,169999.69826436043,0.9600605964660645,0.1470848321914672,0.7570799589157104,1.0419479608535769,50000 -5690.992618322372,25.73021769523621,164792.2272515297,489837,0,164792.2272515297,0.6313000321388245,1.8239089250564573,10000,170526.86529254913,0.958984375,0.148769661784172,0.7572799921035767,1.041583776473999,50000 -5707.914986610413,25.83800363540649,165302.07374358177,491353,0,165302.07374358177,0.631600022315979,1.822745442390442,10000,171053.79780197144,0.9603196382522584,0.144847884774208,0.7572199702262878,1.0405925512313845,50000 -5724.814933538437,25.942726135253903,165812.0014333725,492869,0,165812.0014333725,0.6320000290870667,1.8230407238006592,10000,171580.78747987747,0.9604192972183228,0.1478891372680664,0.7571799755096436,1.0413697957992554,50000 -5741.729859828949,26.05689764022827,166322.11298918724,494387,0,166322.11298918724,0.6312000155448914,1.8221620321273804,10000,172107.98474049568,0.961535394191742,0.1452343314886093,0.7576199769973755,1.041755199432373,50000 -5758.526192903519,26.16771912574768,166832.02917313576,495904,0,166832.02917313576,0.6308000087738037,1.8224083185195925,10000,172634.86578130722,0.9598413109779358,0.1481824964284896,0.7570199966430664,1.041217803955078,50000 -5775.282235145569,26.27513027191162,167341.9730234146,497422,0,167341.9730234146,0.631600022315979,1.8234856128692627,10000,173161.72937083244,0.9609375,0.1445536315441131,0.7573999762535095,1.041479468345642,50000 -5792.087192296982,26.38953804969788,167852.10940527916,498940,0,167852.10940527916,0.6319000124931335,1.822174072265625,10000,173688.8420562744,0.9624322056770324,0.1402934044599533,0.7572799921035767,1.0417726039886477,50000 -5808.90158700943,26.495758056640625,168362.00852704048,500457,0,168362.00852704048,0.6327000260353088,1.8239601850509644,10000,174215.71746110916,0.962511956691742,0.1444217264652252,0.7576000094413757,1.041919469833374,50000 -5826.445637702942,26.60779643058777,168871.98159265518,501974,0,168871.98159265518,0.6307000517845154,1.8220081329345703,10000,174743.40242886543,0.9595623016357422,0.1485275775194168,0.7571799755096436,1.04130756855011,50000 -5843.382396221161,26.7176833152771,169381.893055439,503490,0,169381.893055439,0.6318000555038452,1.8225680589675903,10000,175270.41539549828,0.9605388641357422,0.1469699740409851,0.7571399807929993,1.0414334535598757,50000 -5860.224097967148,26.82484793663025,169891.74683499336,505007,0,169891.74683499336,0.631100058555603,1.8230360746383667,10000,175797.27426624298,0.9607979655265808,0.145893707871437,0.7574999928474426,1.0407086610794067,50000 -5877.193679094315,26.938070058822632,170401.87842082977,506524,0,170401.87842082977,0.6308000087738037,1.822520613670349,10000,176324.54439496994,0.9612165093421936,0.144617959856987,0.7572000026702881,1.0419772863388062,50000 -5894.027221918106,27.048126935958862,170911.74804997444,508040,0,170911.74804997444,0.6322000026702881,1.8211873769760127,10000,176851.41265058515,0.9621930718421936,0.1429073512554168,0.7571600079536438,1.04066002368927,50000 -5910.898756742477,27.1546368598938,171421.61796426773,509558,0,171421.61796426773,0.6310000419616699,1.824790477752685,10000,177378.3165552616,0.9604392051696776,0.148814707994461,0.7572999596595764,1.041728973388672,50000 -5927.751242399216,27.26954126358032,171931.7805171013,511075,0,171931.7805171013,0.6322000026702881,1.821224331855774,10000,177905.50328946114,0.960558831691742,0.1479910314083099,0.7572000026702881,1.0400274991989136,50000 -5944.638758897781,27.386696577072144,172441.7371351719,512592,0,172441.7371351719,0.6326000094413757,1.822581768035889,10000,178432.52074337006,0.9605787396430968,0.1462783068418502,0.7572399973869324,1.0417357683181765,50000 -5961.57798075676,27.50553941726685,172951.64312791824,514109,0,172951.64312791824,0.6318000555038452,1.8226438760757449,10000,178959.5411233902,0.9622528553009032,0.1430338323116302,0.7570399641990662,1.0404475927352903,50000 -5978.376556158066,27.6191668510437,173461.72400903702,515627,0,173461.72400903702,0.6319000124931335,1.8204569816589355,10000,179486.58996462822,0.9596619606018066,0.1483250558376312,0.757319986820221,1.0403424501419067,50000 -5995.200742959976,27.73391819000244,173971.74058961868,517144,0,173971.74058961868,0.6307000517845154,1.82349693775177,10000,180013.60146594048,0.9599210619926452,0.1483697742223739,0.7571799755096436,1.0417215824127195,50000 -6012.095266580582,27.856269359588623,174481.84917855263,518661,0,174481.84917855263,0.6325000524520874,1.823171734809876,10000,180540.7841868401,0.959004282951355,0.1498319208621978,0.7571600079536438,1.0412065982818604,50000 -6029.069609165192,27.970374822616577,174991.8268210888,520178,0,174991.8268210888,0.6322000026702881,1.822198867797852,10000,181067.9081850052,0.9616549611091614,0.1446880847215652,0.757319986820221,1.0408968925476074,50000 -6045.909396409988,28.082290172576904,175501.88168287277,521695,0,175501.88168287277,0.6320000290870667,1.8227356672286987,10000,181594.9708378315,0.9610371589660645,0.1484816223382949,0.757319986820221,1.0408750772476196,50000 -6062.849289655685,28.1988844871521,176011.82210946083,523212,0,176011.82210946083,0.6319000124931335,1.8241809606552124,10000,182122.0238277912,0.96097731590271,0.1454038321971893,0.7571600079536438,1.0417425632476809,50000 -6079.59999704361,28.31386494636536,176521.65455007553,524728,0,176521.65455007553,0.6315000057220459,1.823151707649231,10000,182648.7777543068,0.9594427347183228,0.1490860879421234,0.7573599815368652,1.0410236120224,50000 -6096.458354473114,28.42765522003174,177031.57284379005,526245,0,177031.57284379005,0.6308000087738037,1.8240805864334104,10000,183175.72368884087,0.9598214030265808,0.148780271410942,0.7570799589157104,1.0420541763305664,50000 -6113.270159482956,28.54345893859864,177541.49831581116,527761,0,177541.49831581116,0.6322000026702881,1.8236315250396729,10000,183702.63179159164,0.960558831691742,0.1445672214031219,0.7572599649429321,1.041681885719299,50000 -6130.148310184479,28.656593799591064,178051.53967666626,529277,0,178051.53967666626,0.6322000026702881,1.821385264396668,10000,184229.71953058243,0.9598811864852904,0.148500919342041,0.7572799921035767,1.0404560565948486,50000 -6146.977098941803,28.77240538597107,178561.46390724182,530794,0,178561.46390724182,0.6319000124931335,1.823911190032959,10000,184756.6442449093,0.9596819281578064,0.1456088274717331,0.7568999528884888,1.0421922206878662,50000 -6163.73255443573,28.88776683807373,179071.4066681862,532310,0,179071.4066681862,0.6317000389099121,1.8210970163345337,10000,185283.5139591694,0.9608976244926452,0.147421196103096,0.7572000026702881,1.0403904914855957,50000 -6180.650722265244,29.001569032669067,179581.4375398159,533827,0,179581.4375398159,0.6320000290870667,1.8223789930343628,10000,185810.63350367543,0.9612762928009032,0.1458973139524459,0.7568999528884888,1.0407488346099854,50000 -6197.373930454254,29.12678384780884,180091.54822564125,535344,0,180091.54822564125,0.6310000419616699,1.8225830793380733,10000,186337.6497502327,0.9600605964660645,0.1473893821239471,0.756879985332489,1.040579080581665,50000 -6214.216796398163,29.241114854812626,180601.6016540528,536861,0,180601.6016540528,0.6314000487327576,1.823699712753296,10000,186864.71676635745,0.9612762928009032,0.1430565118789672,0.757099986076355,1.0416611433029177,50000 -6231.133362054825,29.359850883483887,181111.55942821503,538378,0,181111.55942821503,0.6320000290870667,1.822485089302063,10000,187391.7662270069,0.9622727632522584,0.1430283784866333,0.7569199800491333,1.041762351989746,50000 -6248.389245271683,29.48081016540528,181621.6319413185,539895,0,181621.6319413185,0.6320000290870667,1.8241206407547,10000,187919.27256298065,0.9616350531578064,0.1455331742763519,0.7573999762535095,1.0420688390731812,50000 -6265.310503005981,29.59652018547058,182131.56746602056,541411,0,182131.56746602056,0.6317000389099121,1.8230903148651123,10000,188446.3013682365,0.959781527519226,0.1493196040391922,0.7568199634552002,1.0412883758544922,50000 -6282.299289464951,29.71865010261536,182641.5040605068,542928,0,182641.5040605068,0.6312000155448914,1.822980165481568,10000,188973.40545129776,0.9606983065605164,0.144793152809143,0.7574999928474426,1.0422248840332031,50000 -6299.116809844971,29.8223876953125,183151.4329688549,544445,0,183151.4329688549,0.6307000517845154,1.824191689491272,10000,189500.3107652664,0.961933970451355,0.1422090828418731,0.7573599815368652,1.0426530838012695,50000 -6315.990729570389,29.937760829925537,183661.53258037567,545962,0,183661.53258037567,0.6306000351905823,1.822177052497864,10000,190027.4561581612,0.9602798223495485,0.1469572782516479,0.7570199966430664,1.041181564331055,50000 -6333.247943878174,30.054824352264404,184171.5425419808,547478,0,184171.5425419808,0.6310000419616699,1.8227473497390747,10000,190554.8959350586,0.9627909660339355,0.1446569114923477,0.757099986076355,1.0409207344055176,50000 -6350.045515298843,30.17049288749695,184681.5361790657,548995,0,184681.5361790657,0.6325000524520874,1.8207401037216189,10000,191081.8586373329,0.9591637253761292,0.1491780281066894,0.7572599649429321,1.0409399271011353,50000 -6366.848225593567,30.291536569595337,185191.63391900063,550512,0,185191.63391900063,0.6324000358581543,1.8237340450286863,10000,191608.9362416268,0.9607381820678712,0.1471794098615646,0.7575199604034424,1.0413737297058103,50000 -6383.623881816864,30.41457509994507,185701.63338589668,552029,0,185701.63338589668,0.6317000389099121,1.8212229013442995,10000,192135.8907732964,0.9604990482330322,0.1474171131849289,0.7571600079536438,1.0401363372802734,50000 -6400.52764081955,30.53302764892578,186211.5637850761,553547,0,186211.5637850761,0.6324000358581543,1.822754144668579,10000,192662.90041589737,0.96097731590271,0.1455092877149582,0.7566999793052673,1.0411295890808103,50000 -6417.333734750748,30.65377688407898,186721.5231051445,555064,0,186721.5231051445,0.6317000389099121,1.8228280544281008,10000,193189.8435087204,0.9598413109779358,0.1486168652772903,0.7573399543762207,1.0403741598129272,50000 -6434.261531352997,30.77181887626648,187231.49054288864,556582,0,187231.49054288864,0.6324000358581543,1.822922110557556,10000,193716.9131946564,0.9602000713348388,0.1469131112098693,0.7574599981307983,1.0412509441375732,50000 -6451.153831481934,30.892110347747803,187741.4332535267,558099,0,187741.4332535267,0.6310000419616699,1.822549104690552,10000,194243.9229915142,0.9608777165412904,0.1464002430438995,0.7570799589157104,1.040795087814331,50000 -6467.947027206421,31.01043510437012,188251.5854208469,559617,0,188251.5854208469,0.6313000321388245,1.8238762617111208,10000,194771.04189276692,0.9605388641357422,0.1474508792161941,0.7572799921035767,1.0420900583267212,50000 -6484.6761956214905,31.13318109512329,188379.31034731865,559998,0,188379.31034731865,0.6313000321388245,1.8217592239379883,10000,194915.6333310604,0.9612165093421936,0.14852507412433624,0.7571199536323547,1.039791464805603,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/measurements.csv deleted file mode 100644 index 320c972ec..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5973 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.6668113,6.9157767,,,,,,,,,,,,,, -1,,,0.0008968430920504,6.913041591644287,0.0010999999940395,6.912304878234863,50000.0,0.0010999999940395,6.911952018737793,10000.0,49.55485677719116,86.1427435874939,49.55485677719116,36.58779454231262,0.0,0.0 -100,0.6799252,6.817208,,,,,,,,,,,,,, -200,0.80222875,6.5243945,,,,,,,,,,,,,, -300,0.9069234,6.2657576,,,,,,,,,,,,,, -400,1.4112072,6.002125,,,,,,,,,,,,,, -500,2.1854312,5.7719345,,,,,,,,,,,,,, -600,2.4046388,5.558484,,,,,,,,,,,,,, -700,3.5686607,5.3749485,,,,,,,,,,,,,, -800,5.0475125,5.243263,,,,,,,,,,,,,, -900,3.5679996,5.1957827,,,,,,,,,,,,,, -1000,4.771815,4.9604683,,,,,,,,,,,,,, -1100,7.6457577,5.027572,,,,,,,,,,,,,, -1200,5.49671,4.748393,,,,,,,,,,,,,, -1300,4.4414563,4.6469774,,,,,,,,,,,,,, -1400,3.5926852,4.6862793,,,,,,,,,,,,,, -1500,3.7630565,4.542828,,,,,,,,,,,,,, -1513,,,0.169782355427742,4.28282356262207,0.150419995188713,4.425992488861084,50000.0,0.1111000031232833,4.882546424865723,10000.0,559.6508867740631,614.2222878932953,559.6508867740631,54.492894887924194,0.0251517295837402,0.0 -1600,5.732926,4.4460773,,,,,,,,,,,,,, -1700,7.474423,4.290468,,,,,,,,,,,,,, -1800,5.3324494,4.323224,,,,,,,,,,,,,, -1900,5.218601,4.142845,,,,,,,,,,,,,, -2000,10.89493,4.113483,,,,,,,,,,,,,, -2100,4.4043837,3.914007,,,,,,,,,,,,,, -2200,3.3156579,3.8448687,,,,,,,,,,,,,, -2300,3.5173173,3.939761,,,,,,,,,,,,,, -2400,5.618124,3.860882,,,,,,,,,,,,,, -2500,4.5820265,3.7191825,,,,,,,,,,,,,, -2600,3.2019157,3.5607874,,,,,,,,,,,,,, -2700,2.9575589,3.5690079,,,,,,,,,,,,,, -2800,2.9765332,3.5832143,,,,,,,,,,,,,, -2900,3.6944153,3.4705608,,,,,,,,,,,,,, -3000,6.8415923,3.3804226,,,,,,,,,,,,,, -3024,,,0.3342833220958709,3.0926454067230225,0.3082599937915802,3.276564359664917,50000.0,0.2336000055074691,3.956745862960816,10000.0,1069.7121422290802,1142.3826808929443,1069.7121422290802,72.50932359695435,0.0539841651916503,0.0 -3100,7.291008,3.4801521,,,,,,,,,,,,,, -3200,3.917411,3.3326488,,,,,,,,,,,,,, -3300,6.3364363,3.276751,,,,,,,,,,,,,, -3400,4.5249434,3.2839718,,,,,,,,,,,,,, -3500,3.6049268,3.161642,,,,,,,,,,,,,, -3600,3.7915297,3.1550615,,,,,,,,,,,,,, -3700,2.251211,3.1007872,,,,,,,,,,,,,, -3800,2.9023046,3.0339766,,,,,,,,,,,,,, -3900,2.155222,3.0064123,,,,,,,,,,,,,, -4000,3.8478801,3.067628,,,,,,,,,,,,,, -4100,2.9297013,2.9386263,,,,,,,,,,,,,, -4200,4.412095,2.866229,,,,,,,,,,,,,, -4300,3.853321,3.0067067,,,,,,,,,,,,,, -4400,3.393866,2.909255,,,,,,,,,,,,,, -4500,2.417107,2.851259,,,,,,,,,,,,,, -4537,,,0.453523576259613,2.397743225097656,0.4207599759101867,2.597283363342285,50000.0,0.3188000023365021,3.3225812911987305,10000.0,1579.652539730072,1670.4646260738373,1579.652539730072,90.5704882144928,0.0806784629821777,0.0 -4600,3.1102524,2.9203403,,,,,,,,,,,,,, -4700,2.149988,2.6646333,,,,,,,,,,,,,, -4800,2.4036303,2.696097,,,,,,,,,,,,,, -4900,3.2313545,2.66534,,,,,,,,,,,,,, -5000,3.2927032,2.5670993,,,,,,,,,,,,,, -5100,2.068888,2.7191727,,,,,,,,,,,,,, -5200,2.5230231,2.4959416,,,,,,,,,,,,,, -5300,2.8334377,2.6024852,,,,,,,,,,,,,, -5400,4.05209,2.6410093,,,,,,,,,,,,,, -5500,1.9529189,2.5673018,,,,,,,,,,,,,, -5600,2.251946,2.522551,,,,,,,,,,,,,, -5700,2.3059292,2.4625382,,,,,,,,,,,,,, -5800,2.1264648,2.3602958,,,,,,,,,,,,,, -5900,2.36891,2.5464396,,,,,,,,,,,,,, -6000,1.6624055,2.5582054,,,,,,,,,,,,,, -6052,,,0.5290178656578064,2.0105631351470947,0.4911199808120727,2.2131128311157227,50000.0,0.3797000050544739,2.925008773803711,10000.0,2089.819321870804,2198.679522037506,2089.819321870804,108.53957986831664,0.1065936088562011,0.0 -6100,2.1090667,2.4190116,,,,,,,,,,,,,, -6200,2.3347816,2.415983,,,,,,,,,,,,,, -6300,2.202365,2.2987337,,,,,,,,,,,,,, -6400,2.6427526,2.46728,,,,,,,,,,,,,, -6500,2.2308173,2.4125743,,,,,,,,,,,,,, -6600,1.9141691,2.3754497,,,,,,,,,,,,,, -6700,2.559198,2.294919,,,,,,,,,,,,,, -6800,1.9778793,2.450092,,,,,,,,,,,,,, -6900,1.791404,2.4311104,,,,,,,,,,,,,, -7000,2.5579379,2.3791676,,,,,,,,,,,,,, -7100,2.5455084,2.3332822,,,,,,,,,,,,,, -7200,2.708599,2.3054771,,,,,,,,,,,,,, -7300,2.5543482,2.2574914,,,,,,,,,,,,,, -7400,1.9068978,2.2324593,,,,,,,,,,,,,, -7500,2.0982444,2.1856306,,,,,,,,,,,,,, -7568,,,0.5653499364852905,1.82849645614624,0.5282999873161316,2.031683921813965,50000.0,0.4047000110149383,2.7944133281707764,10000.0,2600.062970161438,2726.9429802894592,2600.062970161438,126.48164319992064,0.1320304870605468,0.0 -7600,1.8570837,2.196008,,,,,,,,,,,,,, -7700,2.3946784,2.2675636,,,,,,,,,,,,,, -7800,2.9010062,2.1764116,,,,,,,,,,,,,, -7900,2.2236502,2.2467725,,,,,,,,,,,,,, -8000,2.341791,2.111616,,,,,,,,,,,,,, -8100,1.7547429,2.311585,,,,,,,,,,,,,, -8200,1.2662264,1.9668697,,,,,,,,,,,,,, -8300,1.9721413,2.091108,,,,,,,,,,,,,, -8400,2.0571873,2.2470415,,,,,,,,,,,,,, -8500,1.6586331,2.2354317,,,,,,,,,,,,,, -8600,1.8006747,2.1644084,,,,,,,,,,,,,, -8700,2.0163503,2.23426,,,,,,,,,,,,,, -8800,2.2632132,2.061894,,,,,,,,,,,,,, -8900,1.6713908,2.0277812,,,,,,,,,,,,,, -9000,1.8626872,2.1254158,,,,,,,,,,,,,, -9084,,,0.6339883208274841,1.5135927200317385,0.5559399724006653,1.890720248222351,50000.0,0.4335000216960907,2.639951467514038,10000.0,3110.2991478443146,3255.014739274978,3110.2991478443146,144.2371587753296,0.1595776081085205,0.0 -9100,2.2623007,2.0744805,,,,,,,,,,,,,, -9200,2.071311,1.9791641,,,,,,,,,,,,,, -9300,1.549697,2.1568353,,,,,,,,,,,,,, -9400,2.3940058,2.1358793,,,,,,,,,,,,,, -9500,1.896126,1.9853066,,,,,,,,,,,,,, -9600,1.8664024,2.094857,,,,,,,,,,,,,, -9700,2.327799,2.0775583,,,,,,,,,,,,,, -9800,1.7810448,2.0321999,,,,,,,,,,,,,, -9900,1.6554773,2.032106,,,,,,,,,,,,,, -10000,1.6405897,2.2310474,,,,,,,,,,,,,, -10100,1.837667,1.9483721,,,,,,,,,,,,,, -10200,2.070261,2.119542,,,,,,,,,,,,,, -10300,1.6778626,2.0272374,,,,,,,,,,,,,, -10400,1.8824925,2.0297842,,,,,,,,,,,,,, -10500,1.6526276,2.0283086,,,,,,,,,,,,,, -10600,,,0.6399075388908386,1.445544958114624,0.5717999935150146,1.8170902729034424,50000.0,0.4480000138282776,2.5838935375213623,10000.0,3620.485506534576,3783.384105682373,3620.485506534576,162.33961367607117,0.1865520477294922,0.0 -10600,1.5975287,2.0782356,,,,,,,,,,,,,, -10700,2.527667,2.1841514,,,,,,,,,,,,,, -10800,1.4757549,1.9522359,,,,,,,,,,,,,, -10900,1.8792803,1.9048613,,,,,,,,,,,,,, -11000,1.9995407,1.977124,,,,,,,,,,,,,, -11100,1.8075242,1.9716312,,,,,,,,,,,,,, -11200,1.8987727,2.0760446,,,,,,,,,,,,,, -11300,1.4093418,1.965422,,,,,,,,,,,,,, -11400,1.6042897,1.9911354,,,,,,,,,,,,,, -11500,1.7817886,2.0179758,,,,,,,,,,,,,, -11600,1.6297907,1.9423823,,,,,,,,,,,,,, -11700,1.618125,1.9344375,,,,,,,,,,,,,, -11800,1.9074701,1.8840613,,,,,,,,,,,,,, -11900,1.7048103,2.0159557,,,,,,,,,,,,,, -12000,1.8096627,1.9541168,,,,,,,,,,,,,, -12100,1.5890388,1.9911829,,,,,,,,,,,,,, -12115,,,0.6345862150192261,1.4771161079406738,0.5790799856185913,1.7761107683181765,50000.0,0.4598000347614288,2.4967310428619385,10000.0,4130.400605678558,4311.498994588852,4130.400605678558,180.4450373649597,0.2277424335479736,0.0 -12200,1.6412361,1.9905027,,,,,,,,,,,,,, -12300,1.2899939,1.9649274,,,,,,,,,,,,,, -12400,1.7826252,1.9767041,,,,,,,,,,,,,, -12500,1.5792862,1.9498991,,,,,,,,,,,,,, -12600,1.5333277,1.9769937,,,,,,,,,,,,,, -12700,1.6856378,1.9985557,,,,,,,,,,,,,, -12800,1.4096117,1.9672817,,,,,,,,,,,,,, -12900,1.6159917,1.9841465,,,,,,,,,,,,,, -13000,2.0814202,1.9422407,,,,,,,,,,,,,, -13100,1.9608057,2.0134392,,,,,,,,,,,,,, -13200,1.2954626,1.9071125,,,,,,,,,,,,,, -13300,1.4917369,1.9662817,,,,,,,,,,,,,, -13400,1.6406592,1.9419608,,,,,,,,,,,,,, -13500,1.5759028,1.9771061,,,,,,,,,,,,,, -13600,1.6271662,1.8176777,,,,,,,,,,,,,, -13631,,,0.6399673223495483,1.4503324031829834,0.5899400115013123,1.7300362586975098,50000.0,0.456900030374527,2.5264737606048584,10000.0,4640.448156118393,4840.869259595871,4640.448156118393,199.68504357337952,0.2561922073364258,0.0 -13700,1.8010551,1.8433255,,,,,,,,,,,,,, -13800,1.4145597,1.9904523,,,,,,,,,,,,,, -13900,1.4619328,1.8514771,,,,,,,,,,,,,, -14000,1.6815977,2.0132194,,,,,,,,,,,,,, -14100,1.4894019,1.9737456,,,,,,,,,,,,,, -14200,1.9107157,2.028589,,,,,,,,,,,,,, -14300,2.0515735,1.9282236,,,,,,,,,,,,,, -14400,2.4327805,1.9628949,,,,,,,,,,,,,, -14500,1.7654635,1.9310327,,,,,,,,,,,,,, -14600,1.5753908,1.9669498,,,,,,,,,,,,,, -14700,2.0654871,1.9859991,,,,,,,,,,,,,, -14800,2.6681707,1.916557,,,,,,,,,,,,,, -14900,1.85254,1.7570326,,,,,,,,,,,,,, -15000,1.7128153,1.9206506,,,,,,,,,,,,,, -15100,2.1642692,1.9372872,,,,,,,,,,,,,, -15148,,,0.6467235088348389,1.4315388202667236,0.5961399674415588,1.6884781122207642,50000.0,0.4661000072956085,2.474865436553955,10000.0,5150.681735038757,5372.002974510193,5150.681735038757,220.5019817352295,0.2849607467651367,0.0 -15200,1.6618615,1.9629529,,,,,,,,,,,,,, -15300,1.9227247,2.0023355,,,,,,,,,,,,,, -15400,1.9158835,1.901935,,,,,,,,,,,,,, -15500,1.4703104,1.9400067,,,,,,,,,,,,,, -15600,2.2237768,1.9370188,,,,,,,,,,,,,, -15700,1.6459581,1.8437726,,,,,,,,,,,,,, -15800,1.5127032,2.0017614,,,,,,,,,,,,,, -15900,1.6495922,1.8471396,,,,,,,,,,,,,, -16000,2.1082072,1.9315789,,,,,,,,,,,,,, -16100,1.656379,1.8981826,,,,,,,,,,,,,, -16200,1.7143049,1.958442,,,,,,,,,,,,,, -16300,1.8262478,1.8820719,,,,,,,,,,,,,, -16400,1.5918139,1.8608751,,,,,,,,,,,,,, -16500,1.7118932,1.9893014,,,,,,,,,,,,,, -16600,1.5762382,1.859651,,,,,,,,,,,,,, -16665,,,0.6599370241165161,1.3714268207550049,0.6070199608802795,1.642867922782898,50000.0,0.4826000332832336,2.3955681324005127,10000.0,5660.88941025734,5903.179100036621,5660.88941025734,241.3727121353149,0.3264124393463135,0.0 -16700,1.8068117,1.9257784,,,,,,,,,,,,,, -16800,1.8831435,1.9248067,,,,,,,,,,,,,, -16900,1.5035481,1.7974445,,,,,,,,,,,,,, -17000,1.458289,1.9002141,,,,,,,,,,,,,, -17100,1.7305771,1.7290355,,,,,,,,,,,,,, -17200,1.7441052,1.8751031,,,,,,,,,,,,,, -17300,1.7912793,1.8899773,,,,,,,,,,,,,, -17400,1.7563343,1.9085951,,,,,,,,,,,,,, -17500,1.698189,1.7560899,,,,,,,,,,,,,, -17600,2.176125,1.9337616,,,,,,,,,,,,,, -17700,1.3311733,1.8161408,,,,,,,,,,,,,, -17800,1.7393786,1.8166769,,,,,,,,,,,,,, -17900,1.7217671,1.7956108,,,,,,,,,,,,,, -18000,1.7086753,1.8864752,,,,,,,,,,,,,, -18100,1.9022802,1.8843185,,,,,,,,,,,,,, -18182,,,0.7046794891357422,1.1431324481964111,0.6042199730873108,1.6550146341323853,50000.0,0.4863000214099884,2.367225885391236,10000.0,6171.083922147751,6434.996114492416,6171.083922147751,262.90595984458923,0.3627340793609619,0.0 -18200,2.1685534,1.8375653,,,,,,,,,,,,,, -18300,1.5566707,1.8875104,,,,,,,,,,,,,, -18400,1.5943874,1.8830671,,,,,,,,,,,,,, -18500,1.463018,1.8225036,,,,,,,,,,,,,, -18600,2.239867,1.8733987,,,,,,,,,,,,,, -18700,1.715677,1.7626892,,,,,,,,,,,,,, -18800,1.3673627,1.6936222,,,,,,,,,,,,,, -18900,1.555104,1.8391628,,,,,,,,,,,,,, -19000,1.6974452,1.8157148,,,,,,,,,,,,,, -19100,1.54227,1.7509067,,,,,,,,,,,,,, -19200,1.8438836,1.8535032,,,,,,,,,,,,,, -19300,1.8744987,1.8906437,,,,,,,,,,,,,, -19400,1.4420019,1.7676803,,,,,,,,,,,,,, -19500,1.3658884,1.6310805,,,,,,,,,,,,,, -19600,1.678997,1.7872475,,,,,,,,,,,,,, -19698,,,0.6736288070678711,1.300462245941162,0.6003400087356567,1.6817909479141235,50000.0,0.4772000312805176,2.437773466110229,10000.0,6681.065488100052,6964.769221544266,6681.065488100052,282.60028171539307,0.4062559604644775,0.0 -19700,1.5525714,1.7895907,,,,,,,,,,,,,, -19800,1.7353644,1.8293254,,,,,,,,,,,,,, -19900,1.540464,1.9306035,,,,,,,,,,,,,, -20000,1.870293,1.9130254,,,,,,,,,,,,,, -20100,1.7553586,1.7942469,,,,,,,,,,,,,, -20200,1.8241689,1.9434586,,,,,,,,,,,,,, -20300,1.6809436,1.907548,,,,,,,,,,,,,, -20400,1.4831297,1.8543713,,,,,,,,,,,,,, -20500,1.5255462,1.7199935,,,,,,,,,,,,,, -20600,1.7072787,1.6537707,,,,,,,,,,,,,, -20700,1.7478513,1.9095782,,,,,,,,,,,,,, -20800,1.9298248,1.7724004,,,,,,,,,,,,,, -20900,1.7037891,1.7932463,,,,,,,,,,,,,, -21000,1.7388955,1.899264,,,,,,,,,,,,,, -21100,1.8335215,1.7731166,,,,,,,,,,,,,, -21200,2.1350784,1.9859145,,,,,,,,,,,,,, -21215,,,0.6816206574440002,1.2650530338287354,0.6165199875831604,1.6044803857803345,50000.0,0.4891000092029571,2.3720510005950928,10000.0,7191.288413763046,7494.415567398071,7191.288413763046,301.9392716884613,0.4372162818908691,0.0 -21300,1.8037292,1.7867738,,,,,,,,,,,,,, -21400,1.6873113,1.7875111,,,,,,,,,,,,,, -21500,1.4981557,1.8735534,,,,,,,,,,,,,, -21600,1.579858,1.6909074,,,,,,,,,,,,,, -21700,1.7407752,1.8495895,,,,,,,,,,,,,, -21800,1.5930145,1.7910578,,,,,,,,,,,,,, -21900,1.7995454,1.8348608,,,,,,,,,,,,,, -22000,1.572252,1.6899741,,,,,,,,,,,,,, -22100,1.7708384,1.8663849,,,,,,,,,,,,,, -22200,1.8445807,1.7841823,,,,,,,,,,,,,, -22300,1.6448516,1.7118708,,,,,,,,,,,,,, -22400,1.7624073,1.7783425,,,,,,,,,,,,,, -22500,1.6762546,1.8453146,,,,,,,,,,,,,, -22600,1.5040615,1.7961222,,,,,,,,,,,,,, -22700,1.458537,1.6476328,,,,,,,,,,,,,, -22732,,,0.6628069281578064,1.3463976383209229,0.6113799810409546,1.631981372833252,50000.0,0.4749000370502472,2.3722574710845947,10000.0,7701.532709360123,8024.71547794342,7701.532709360123,321.9015429019928,0.4763326644897461,0.0 -22800,1.595104,1.7358943,,,,,,,,,,,,,, -22900,1.6795032,1.8786165,,,,,,,,,,,,,, -23000,1.5000479,1.6501759,,,,,,,,,,,,,, -23100,1.7797885,1.662245,,,,,,,,,,,,,, -23200,1.5585105,1.81601,,,,,,,,,,,,,, -23300,1.8020381,1.8236742,,,,,,,,,,,,,, -23400,1.9586577,1.7402273,,,,,,,,,,,,,, -23500,1.735269,1.8305702,,,,,,,,,,,,,, -23600,1.5092784,1.8411081,,,,,,,,,,,,,, -23700,1.9701844,1.8052398,,,,,,,,,,,,,, -23800,1.6783278,1.7495918,,,,,,,,,,,,,, -23900,1.7309442,1.7695279,,,,,,,,,,,,,, -24000,1.8070643,1.7321099,,,,,,,,,,,,,, -24100,1.7333848,1.8103907,,,,,,,,,,,,,, -24200,1.6317321,1.7808238,,,,,,,,,,,,,, -24249,,,0.6758809089660645,1.2752517461776731,0.6176199913024902,1.597264289855957,50000.0,0.4928000271320343,2.328645706176758,10000.0,8211.788202762604,8557.265342473984,8211.788202762604,344.1114830970764,0.5074851512908936,0.0 -24300,1.9605523,1.8941605,,,,,,,,,,,,,, -24400,1.6727972,1.7836399,,,,,,,,,,,,,, -24500,1.6013271,1.6258274,,,,,,,,,,,,,, -24600,1.859769,1.9223557,,,,,,,,,,,,,, -24700,1.9313059,1.8695645,,,,,,,,,,,,,, -24800,1.7088035,1.7302247,,,,,,,,,,,,,, -24900,1.769714,1.8808091,,,,,,,,,,,,,, -25000,2.3148289,1.6965628,,,,,,,,,,,,,, -25100,2.3877141,1.7573422,,,,,,,,,,,,,, -25200,1.6536714,1.657112,,,,,,,,,,,,,, -25300,1.660248,1.7651467,,,,,,,,,,,,,, -25400,1.8510048,1.7933674,,,,,,,,,,,,,, -25500,1.7106276,1.8326802,,,,,,,,,,,,,, -25600,1.8263229,1.765393,,,,,,,,,,,,,, -25700,1.6942605,1.6254469,,,,,,,,,,,,,, -25765,,,0.6642617583274841,1.3381074666976929,0.6139199733734131,1.6178455352783203,50000.0,0.4853000342845917,2.364208459854126,10000.0,8721.796803951263,9090.08283829689,8721.796803951263,366.8413505554199,0.5325994491577148,0.0 -25800,1.5865922,1.7541473,,,,,,,,,,,,,, -25900,1.5774947,1.8193418,,,,,,,,,,,,,, -26000,1.534009,1.7202084,,,,,,,,,,,,,, -26100,1.7130564,1.7761239,,,,,,,,,,,,,, -26200,1.6539956,1.6798902,,,,,,,,,,,,,, -26300,1.7858293,1.8012104,,,,,,,,,,,,,, -26400,1.6923171,1.6841927,,,,,,,,,,,,,, -26500,1.7754976,1.8123387,,,,,,,,,,,,,, -26600,1.5466486,1.7649161,,,,,,,,,,,,,, -26700,1.9402462,1.838582,,,,,,,,,,,,,, -26800,1.8426785,1.7310821,,,,,,,,,,,,,, -26900,1.6977751,1.8098032,,,,,,,,,,,,,, -27000,1.5840753,1.6916162,,,,,,,,,,,,,, -27100,1.7647182,1.7570457,,,,,,,,,,,,,, -27200,1.5776923,1.6713657,,,,,,,,,,,,,, -27281,,,0.7101402878761292,1.1184210777282717,0.6214799880981445,1.5755999088287354,50000.0,0.4900000095367431,2.3477091789245605,10000.0,9231.725385665894,9624.25598335266,9231.725385665894,391.0030782222748,0.5623111724853516,0.0 -27300,1.734788,1.7604704,,,,,,,,,,,,,, -27400,1.6853212,1.7282182,,,,,,,,,,,,,, -27500,1.8174452,1.6910973,,,,,,,,,,,,,, -27600,1.9433247,1.7832762,,,,,,,,,,,,,, -27700,2.267699,1.785442,,,,,,,,,,,,,, -27800,1.6667602,1.7880994,,,,,,,,,,,,,, -27900,1.9287354,1.7582538,,,,,,,,,,,,,, -28000,1.7078618,1.8075895,,,,,,,,,,,,,, -28100,1.7929354,1.8817105,,,,,,,,,,,,,, -28200,1.83794,1.7366747,,,,,,,,,,,,,, -28300,1.8135,1.7853553,,,,,,,,,,,,,, -28400,1.8343544,1.6929994,,,,,,,,,,,,,, -28500,2.0399358,1.7768564,,,,,,,,,,,,,, -28600,1.7221569,1.7091153,,,,,,,,,,,,,, -28700,1.6070489,1.7276057,,,,,,,,,,,,,, -28798,,,0.6959701776504517,1.1897066831588743,0.6214599609375,1.5625724792480469,50000.0,0.4940000176429748,2.274938106536865,10000.0,9741.918065309525,10158.063725471497,9741.918065309525,414.5363411903381,0.5903370380401611,0.0 -28800,2.1344147,1.7734965,,,,,,,,,,,,,, -28900,1.7095244,1.6039417,,,,,,,,,,,,,, -29000,1.6836379,1.790296,,,,,,,,,,,,,, -29100,1.571855,1.6735569,,,,,,,,,,,,,, -29200,1.7514563,1.6657302,,,,,,,,,,,,,, -29300,1.6332592,1.6551589,,,,,,,,,,,,,, -29400,2.1339684,1.7047483,,,,,,,,,,,,,, -29500,1.9677597,1.6536688,,,,,,,,,,,,,, -29600,1.6880908,1.6968837,,,,,,,,,,,,,, -29700,1.7667085,1.7115062,,,,,,,,,,,,,, -29800,1.710859,1.7255293,,,,,,,,,,,,,, -29900,1.7041382,1.7364813,,,,,,,,,,,,,, -30000,1.9731716,1.7547326,,,,,,,,,,,,,, -30100,1.7864063,1.7972324,,,,,,,,,,,,,, -30200,1.6504623,1.7114717,,,,,,,,,,,,,, -30300,1.7498084,1.7640107,,,,,,,,,,,,,, -30315,,,0.6573461294174194,1.3810793161392212,0.590999960899353,1.707097411155701,50000.0,0.4736000299453735,2.4069299697875977,10000.0,10251.934691667557,10693.976108551024,10251.934691667557,440.34696435928345,0.6227197647094727,0.0 -30400,2.2313774,1.7059928,,,,,,,,,,,,,, -30500,1.9360588,1.6984267,,,,,,,,,,,,,, -30600,1.6922876,1.7655644,,,,,,,,,,,,,, -30700,1.7129561,1.7512562,,,,,,,,,,,,,, -30800,1.9733384,1.8556508,,,,,,,,,,,,,, -30900,1.6933675,1.7062764,,,,,,,,,,,,,, -31000,1.8254805,1.6675462,,,,,,,,,,,,,, -31100,1.9578985,1.7183018,,,,,,,,,,,,,, -31200,1.6710659,1.6004531,,,,,,,,,,,,,, -31300,1.743253,1.6935722,,,,,,,,,,,,,, -31400,1.6866817,1.8083692,,,,,,,,,,,,,, -31500,2.2676861,1.7380669,,,,,,,,,,,,,, -31600,2.2746375,1.7212985,,,,,,,,,,,,,, -31700,1.5931741,1.6254225,,,,,,,,,,,,,, -31800,1.7117122,1.6538827,,,,,,,,,,,,,, -31832,,,0.6697026491165161,1.3070857524871826,0.6118599772453308,1.6050184965133667,50000.0,0.4913000166416168,2.338101863861084,10000.0,10762.190105199814,11228.582449913025,10762.190105199814,464.6104917526245,0.6552050113677979,0.0 -31900,1.8658226,1.7813299,,,,,,,,,,,,,, -32000,1.6741823,1.702939,,,,,,,,,,,,,, -32100,1.8006847,1.7548708,,,,,,,,,,,,,, -32200,1.9029711,1.7501832,,,,,,,,,,,,,, -32300,1.8870578,1.678078,,,,,,,,,,,,,, -32400,1.6829478,1.6595979,,,,,,,,,,,,,, -32500,1.7940332,1.5918443,,,,,,,,,,,,,, -32600,1.8702229,1.6289978,,,,,,,,,,,,,, -32700,1.8182648,1.7153841,,,,,,,,,,,,,, -32800,1.7175643,1.610055,,,,,,,,,,,,,, -32900,1.8273398,1.7427771,,,,,,,,,,,,,, -33000,1.629898,1.6640759,,,,,,,,,,,,,, -33100,1.5910838,1.6353443,,,,,,,,,,,,,, -33200,1.8211778,1.7797103,,,,,,,,,,,,,, -33300,1.8960547,1.6111296,,,,,,,,,,,,,, -33349,,,0.6818000674247742,1.2354713678359983,0.6244399547576904,1.547203540802002,50000.0,0.4957000315189361,2.31124234199524,10000.0,11272.451827764511,11762.822406291962,11272.451827764511,488.4992530345917,0.6910066604614258,0.0 -33400,1.7372428,1.6941274,,,,,,,,,,,,,, -33500,1.764777,1.7455626,,,,,,,,,,,,,, -33600,1.7360367,1.679395,,,,,,,,,,,,,, -33700,1.8762597,1.6680824,,,,,,,,,,,,,, -33800,1.8455753,1.7197555,,,,,,,,,,,,,, -33900,1.6025258,1.7551018,,,,,,,,,,,,,, -34000,1.7588485,1.7361058,,,,,,,,,,,,,, -34100,1.6664289,1.6937646,,,,,,,,,,,,,, -34200,1.6443242,1.7379202,,,,,,,,,,,,,, -34300,1.6617483,1.759363,,,,,,,,,,,,,, -34400,2.3674257,1.8088588,,,,,,,,,,,,,, -34500,1.7609475,1.729557,,,,,,,,,,,,,, -34600,1.9251716,1.6435992,,,,,,,,,,,,,, -34700,1.8197418,1.8146389,,,,,,,,,,,,,, -34800,1.7160251,1.6316501,,,,,,,,,,,,,, -34867,,,0.6680484414100647,1.3182141780853271,0.6115800142288208,1.6301088333129885,50000.0,0.4843000173568725,2.4063782691955566,10000.0,11782.472059488297,12297.74937415123,11782.472059488297,513.3233184814453,0.7194736003875732,0.0 -34900,1.7610554,1.7158891,,,,,,,,,,,,,, -35000,1.9859723,1.6946528,,,,,,,,,,,,,, -35100,1.8616865,1.6021401,,,,,,,,,,,,,, -35200,1.7502141,1.7165141,,,,,,,,,,,,,, -35300,1.7938536,1.6986065,,,,,,,,,,,,,, -35400,1.8817445,1.6641048,,,,,,,,,,,,,, -35500,1.8003682,1.6834245,,,,,,,,,,,,,, -35600,2.0733242,1.7539892,,,,,,,,,,,,,, -35700,1.8396317,1.6534841,,,,,,,,,,,,,, -35800,1.7609038,1.7489537,,,,,,,,,,,,,, -35900,1.6641114,1.6759804,,,,,,,,,,,,,, -36000,2.0440998,1.6988542,,,,,,,,,,,,,, -36100,1.758125,1.6888983,,,,,,,,,,,,,, -36200,1.6677961,1.5941818,,,,,,,,,,,,,, -36300,1.9764583,1.785636,,,,,,,,,,,,,, -36384,,,0.7164580821990967,1.0982540845870972,0.6311799883842468,1.5274664163589478,50000.0,0.496500015258789,2.2949490547180176,10000.0,12292.574988365172,12832.48405623436,12292.574988365172,537.8653752803802,0.7547996044158936,0.0 -36400,1.9078356,1.6796526,,,,,,,,,,,,,, -36500,2.0024025,1.7540649,,,,,,,,,,,,,, -36600,1.8004606,1.7179273,,,,,,,,,,,,,, -36700,1.9008892,1.773012,,,,,,,,,,,,,, -36800,1.658118,1.5508647,,,,,,,,,,,,,, -36900,1.858144,1.6215652,,,,,,,,,,,,,, -37000,1.7121404,1.6885934,,,,,,,,,,,,,, -37100,1.796142,1.7246494,,,,,,,,,,,,,, -37200,1.8257747,1.7222853,,,,,,,,,,,,,, -37300,1.8481629,1.6997302,,,,,,,,,,,,,, -37400,1.9693277,1.7349617,,,,,,,,,,,,,, -37500,2.0099943,1.7001268,,,,,,,,,,,,,, -37600,2.0651917,1.6584485,,,,,,,,,,,,,, -37700,1.6377022,1.6655191,,,,,,,,,,,,,, -37800,1.9057592,1.6712921,,,,,,,,,,,,,, -37900,,,0.7069514989852905,1.1378610134124756,0.6352800130844116,1.50535249710083,50000.0,0.5042999982833862,2.2614681720733643,10000.0,12802.7119576931,13363.678293943403,12802.7119576931,558.8393158912659,0.7837421894073486,0.0 -37900,2.2323196,1.7549124,,,,,,,,,,,,,, -38000,1.8464711,1.54574,,,,,,,,,,,,,, -38100,1.6763213,1.5759889,,,,,,,,,,,,,, -38200,1.8904843,1.6965244,,,,,,,,,,,,,, -38300,2.1491616,1.7155452,,,,,,,,,,,,,, -38400,1.7824,1.6963713,,,,,,,,,,,,,, -38500,1.7544643,1.5982835,,,,,,,,,,,,,, -38600,2.545989,1.636126,,,,,,,,,,,,,, -38700,2.5400367,1.7363037,,,,,,,,,,,,,, -38800,2.1927547,1.7155144,,,,,,,,,,,,,, -38900,1.8847001,1.7695048,,,,,,,,,,,,,, -39000,2.0214524,1.6949474,,,,,,,,,,,,,, -39100,1.8697021,1.6211407,,,,,,,,,,,,,, -39200,1.9669658,1.7305514,,,,,,,,,,,,,, -39300,2.0114994,1.817904,,,,,,,,,,,,,, -39400,1.9050335,1.69747,,,,,,,,,,,,,, -39417,,,0.6930006146430969,1.2103509902954102,0.6251599788665771,1.5366623401641846,50000.0,0.4983000159263611,2.2956364154815674,10000.0,13312.825773000715,13896.394639730452,13312.825773000715,581.3548767566681,0.8163936138153076,0.0 -39500,1.7683241,1.6552837,,,,,,,,,,,,,, -39600,1.767505,1.6234753,,,,,,,,,,,,,, -39700,1.6521009,1.6340387,,,,,,,,,,,,,, -39800,1.7434176,1.5939523,,,,,,,,,,,,,, -39900,1.7649748,1.6112803,,,,,,,,,,,,,, -40000,1.9006404,1.6116573,,,,,,,,,,,,,, -40100,1.9030021,1.7368758,,,,,,,,,,,,,, -40200,1.7790208,1.7012048,,,,,,,,,,,,,, -40300,1.884248,1.7233429,,,,,,,,,,,,,, -40400,1.9676911,1.7378565,,,,,,,,,,,,,, -40500,1.8748902,1.7169952,,,,,,,,,,,,,, -40600,1.7080197,1.5908085,,,,,,,,,,,,,, -40700,1.7304646,1.7212245,,,,,,,,,,,,,, -40800,1.8055917,1.7097766,,,,,,,,,,,,,, -40900,1.6534629,1.6584009,,,,,,,,,,,,,, -40934,,,0.6949737071990967,1.195453643798828,0.6284799575805664,1.5238358974456787,50000.0,0.5011000037193298,2.2670748233795166,10000.0,13822.957331418993,14427.839611768724,13822.957331418993,602.5862288475037,0.8450939655303955,0.0 -41000,1.8564745,1.6547328,,,,,,,,,,,,,, -41100,1.7430027,1.6979485,,,,,,,,,,,,,, -41200,1.7592543,1.5714202,,,,,,,,,,,,,, -41300,1.8111004,1.6000031,,,,,,,,,,,,,, -41400,2.0666997,1.5944443,,,,,,,,,,,,,, -41500,1.8550932,1.6909684,,,,,,,,,,,,,, -41600,1.7760911,1.6274228,,,,,,,,,,,,,, -41700,1.830263,1.5542265,,,,,,,,,,,,,, -41800,1.8669705,1.7408385,,,,,,,,,,,,,, -41900,1.8664227,1.70107,,,,,,,,,,,,,, -42000,1.8287113,1.7130746,,,,,,,,,,,,,, -42100,1.7746454,1.6847553,,,,,,,,,,,,,, -42200,2.76964,1.6243603,,,,,,,,,,,,,, -42300,1.7985673,1.6782734,,,,,,,,,,,,,, -42400,1.845911,1.6124525,,,,,,,,,,,,,, -42451,,,0.686543345451355,1.2287561893463137,0.6284199953079224,1.5451955795288086,50000.0,0.498600035905838,2.3081836700439453,10000.0,14333.02351474762,14958.806265830994,14333.02351474762,623.4006247520447,0.8766939640045166,0.0 -42500,1.6740159,1.6879041,,,,,,,,,,,,,, -42600,1.8509601,1.5386717,,,,,,,,,,,,,, -42700,1.8193498,1.75925,,,,,,,,,,,,,, -42800,2.6198137,1.8047365,,,,,,,,,,,,,, -42900,2.1138067,1.6573238,,,,,,,,,,,,,, -43000,1.6349967,1.6111095,,,,,,,,,,,,,, -43100,1.8398454,1.5672905,,,,,,,,,,,,,, -43200,1.7933662,1.7077233,,,,,,,,,,,,,, -43300,1.9973342,1.6371026,,,,,,,,,,,,,, -43400,1.8152591,1.6493589,,,,,,,,,,,,,, -43500,2.1651018,1.6336685,,,,,,,,,,,,,, -43600,1.7221206,1.5690126,,,,,,,,,,,,,, -43700,1.8498232,1.7811074,,,,,,,,,,,,,, -43800,1.8363831,1.6298915,,,,,,,,,,,,,, -43900,2.1197877,1.6677629,,,,,,,,,,,,,, -43967,,,0.704500138759613,1.1581501960754397,0.6357600092887878,1.5145957469940186,50000.0,0.5024000406265259,2.258817195892334,10000.0,14842.94656443596,15487.173084020616,14842.94656443596,641.7515881061554,0.9145431518554688,0.0 -44000,1.6904273,1.5561424,,,,,,,,,,,,,, -44100,1.6912203,1.6058227,,,,,,,,,,,,,, -44200,1.973731,1.6457267,,,,,,,,,,,,,, -44300,1.8920684,1.6942216,,,,,,,,,,,,,, -44400,1.6377639,1.6095302,,,,,,,,,,,,,, -44500,1.8893008,1.6337763,,,,,,,,,,,,,, -44600,1.789708,1.5702058,,,,,,,,,,,,,, -44700,1.7634934,1.613322,,,,,,,,,,,,,, -44800,1.8569605,1.6182532,,,,,,,,,,,,,, -44900,1.7801292,1.6115503,,,,,,,,,,,,,, -45000,1.9128144,1.615566,,,,,,,,,,,,,, -45100,1.9238306,1.7205602,,,,,,,,,,,,,, -45200,2.0897899,1.7261634,,,,,,,,,,,,,, -45300,1.8584334,1.5757067,,,,,,,,,,,,,, -45400,1.5947641,1.6826391,,,,,,,,,,,,,, -45484,,,0.720723032951355,1.0734890699386597,0.6385200023651123,1.4788291454315186,50000.0,0.5052000284194946,2.242684364318848,10000.0,15353.166466712952,16015.389912128448,15353.166466712952,659.6599915027618,0.949812650680542,0.0 -45500,1.7125812,1.5518162,,,,,,,,,,,,,, -45600,1.9667603,1.724621,,,,,,,,,,,,,, -45700,1.5997727,1.5568632,,,,,,,,,,,,,, -45800,1.7529773,1.6446512,,,,,,,,,,,,,, -45900,2.023939,1.6745594,,,,,,,,,,,,,, -46000,1.9126767,1.6573172,,,,,,,,,,,,,, -46100,1.8734854,1.6016748,,,,,,,,,,,,,, -46200,1.791972,1.7026261,,,,,,,,,,,,,, -46300,1.9215634,1.6645427,,,,,,,,,,,,,, -46400,1.8221675,1.622616,,,,,,,,,,,,,, -46500,2.0942998,1.6682872,,,,,,,,,,,,,, -46600,1.973148,1.658523,,,,,,,,,,,,,, -46700,2.084742,1.6772021,,,,,,,,,,,,,, -46800,1.7127335,1.543362,,,,,,,,,,,,,, -46900,1.8248779,1.7035819,,,,,,,,,,,,,, -47000,1.9065013,1.6433018,,,,,,,,,,,,,, -47001,,,0.7143255472183228,1.107452392578125,0.6450999975204468,1.4719318151474,50000.0,0.5174000263214111,2.193423271179199,10000.0,15863.43516588211,16543.592682123184,15863.43516588211,677.5013723373413,0.9880750179290771,0.0 -47100,1.9447541,1.6177924,,,,,,,,,,,,,, -47200,1.9450768,1.5795931,,,,,,,,,,,,,, -47300,1.7376102,1.6202344,,,,,,,,,,,,,, -47400,1.7236762,1.5840967,,,,,,,,,,,,,, -47500,2.133557,1.5685099,,,,,,,,,,,,,, -47600,1.8565146,1.5812409,,,,,,,,,,,,,, -47700,1.9337778,1.574109,,,,,,,,,,,,,, -47800,2.1809444,1.747504,,,,,,,,,,,,,, -47900,2.0332263,1.6043674,,,,,,,,,,,,,, -48000,1.6441115,1.5907248,,,,,,,,,,,,,, -48100,1.8505986,1.6364105,,,,,,,,,,,,,, -48200,1.9280499,1.6397755,,,,,,,,,,,,,, -48300,1.9010123,1.6388458,,,,,,,,,,,,,, -48400,1.6767504,1.6561203,,,,,,,,,,,,,, -48500,2.0299237,1.702664,,,,,,,,,,,,,, -48518,,,0.7139468789100647,1.1055635213851929,0.6447799801826477,1.4599003791809082,50000.0,0.5217000246047974,2.1961843967437744,10000.0,16373.61761522293,17072.783661842346,16373.61761522293,696.4210996627808,1.0227127075195312,0.0 -48600,1.977496,1.5277483,,,,,,,,,,,,,, -48700,1.9735936,1.6482391,,,,,,,,,,,,,, -48800,1.7689381,1.5603468,,,,,,,,,,,,,, -48900,1.9471352,1.5317435,,,,,,,,,,,,,, -49000,2.100959,1.5927781,,,,,,,,,,,,,, -49100,1.8612895,1.6058621,,,,,,,,,,,,,, -49200,1.9083991,1.6438755,,,,,,,,,,,,,, -49300,1.8780264,1.6745255,,,,,,,,,,,,,, -49400,1.7063997,1.6353924,,,,,,,,,,,,,, -49500,1.8208259,1.5257725,,,,,,,,,,,,,, -49600,2.0460572,1.5505443,,,,,,,,,,,,,, -49700,1.9648285,1.6297944,,,,,,,,,,,,,, -49800,1.9616839,1.5400944,,,,,,,,,,,,,, -49900,1.7373197,1.474889,,,,,,,,,,,,,, -50000,1.944046,1.6570127,,,,,,,,,,,,,, -50035,,,0.703523576259613,1.156677484512329,0.6417999863624573,1.461700439453125,50000.0,0.5103999972343445,2.2414681911468506,10000.0,16883.693967342377,17600.717061519623,16883.693967342377,714.1883132457733,1.057837963104248,0.0 -50100,1.8768593,1.6736395,,,,,,,,,,,,,, -50200,1.6710858,1.5376663,,,,,,,,,,,,,, -50300,2.0500822,1.5332053,,,,,,,,,,,,,, -50400,1.9986637,1.5147374,,,,,,,,,,,,,, -50500,1.785062,1.6044296,,,,,,,,,,,,,, -50600,2.1011205,1.721679,,,,,,,,,,,,,, -50700,1.9873937,1.6660248,,,,,,,,,,,,,, -50800,1.8999379,1.5591977,,,,,,,,,,,,,, -50900,2.0830686,1.6181288,,,,,,,,,,,,,, -51000,1.849365,1.5976017,,,,,,,,,,,,,, -51100,1.7101201,1.6051198,,,,,,,,,,,,,, -51200,2.0081255,1.6905936,,,,,,,,,,,,,, -51300,1.842646,1.7006626,,,,,,,,,,,,,, -51400,1.7418945,1.6892995,,,,,,,,,,,,,, -51500,2.0781822,1.5869036,,,,,,,,,,,,,, -51552,,,0.7147839665412903,1.111628174781799,0.6556400060653687,1.4142948389053345,50000.0,0.5249000191688538,2.136894941329956,10000.0,17393.82016658783,18129.149336099625,17393.82016658783,732.4035475254059,1.0941922664642334,0.0 -51600,2.0013423,1.648052,,,,,,,,,,,,,, -51700,1.856774,1.6223391,,,,,,,,,,,,,, -51800,2.040666,1.5684776,,,,,,,,,,,,,, -51900,2.1015131,1.6943896,,,,,,,,,,,,,, -52000,2.2144961,1.697811,,,,,,,,,,,,,, -52100,1.7645906,1.6352346,,,,,,,,,,,,,, -52200,1.7036493,1.6331788,,,,,,,,,,,,,, -52300,2.0438933,1.6819204,,,,,,,,,,,,,, -52400,1.949189,1.5532504,,,,,,,,,,,,,, -52500,1.8872795,1.5566721,,,,,,,,,,,,,, -52600,1.911113,1.7302637,,,,,,,,,,,,,, -52700,1.8173693,1.5296018,,,,,,,,,,,,,, -52800,1.8319979,1.6130803,,,,,,,,,,,,,, -52900,1.8576198,1.55537,,,,,,,,,,,,,, -53000,1.7384075,1.544701,,,,,,,,,,,,,, -53069,,,0.7700493931770325,0.8872178196907043,0.6549599766731262,1.4167771339416504,50000.0,0.5270000100135803,2.1360621452331543,10000.0,17903.856494903564,18656.95786833763,17903.856494903564,750.0831353664398,1.131983757019043,0.0 -53100,1.7090044,1.4792249,,,,,,,,,,,,,, -53200,1.677231,1.5296166,,,,,,,,,,,,,, -53300,1.9942272,1.6827368,,,,,,,,,,,,,, -53400,1.6774913,1.5033242,,,,,,,,,,,,,, -53500,1.8538901,1.6284269,,,,,,,,,,,,,, -53600,1.8482683,1.6466438,,,,,,,,,,,,,, -53700,2.1418421,1.5994592,,,,,,,,,,,,,, -53800,1.925883,1.6633666,,,,,,,,,,,,,, -53900,1.975417,1.5799638,,,,,,,,,,,,,, -54000,1.9614539,1.6688337,,,,,,,,,,,,,, -54100,1.8941721,1.5834984,,,,,,,,,,,,,, -54200,1.8378412,1.5357178,,,,,,,,,,,,,, -54300,1.7887076,1.4506137,,,,,,,,,,,,,, -54400,1.7755092,1.6190735,,,,,,,,,,,,,, -54500,1.9980559,1.6654416,,,,,,,,,,,,,, -54587,,,0.7238121628761292,1.0641529560089111,0.6489999890327454,1.4466161727905271,50000.0,0.5200000405311584,2.1638076305389404,10000.0,18414.08018231392,19187.921802043915,18414.08018231392,770.7322072982788,1.1675801277160645,0.0 -54600,2.0196447,1.7121365,,,,,,,,,,,,,, -54700,1.861163,1.6966925,,,,,,,,,,,,,, -54800,1.8243206,1.6526145,,,,,,,,,,,,,, -54900,1.8334178,1.5270813,,,,,,,,,,,,,, -55000,1.9167662,1.691954,,,,,,,,,,,,,, -55100,2.1103542,1.6670499,,,,,,,,,,,,,, -55200,1.6726948,1.6024101,,,,,,,,,,,,,, -55300,1.8842441,1.6480881,,,,,,,,,,,,,, -55400,1.9319588,1.6122782,,,,,,,,,,,,,, -55500,1.9487131,1.6667048,,,,,,,,,,,,,, -55600,1.8303928,1.6519755,,,,,,,,,,,,,, -55700,1.8760538,1.6947523,,,,,,,,,,,,,, -55800,2.004854,1.6029549,,,,,,,,,,,,,, -55900,1.834418,1.6439925,,,,,,,,,,,,,, -56000,2.0185473,1.5741923,,,,,,,,,,,,,, -56100,2.1092355,1.4868654,,,,,,,,,,,,,, -56103,,,0.7228754758834839,1.0718671083450315,0.6495999693870544,1.4355239868164062,50000.0,0.5175000429153442,2.1802403926849365,10000.0,18924.0583422184,19716.214436769485,18924.0583422184,788.9550342559814,1.20243501663208,0.0 -56200,1.8227594,1.6283966,,,,,,,,,,,,,, -56300,1.8275012,1.4912984,,,,,,,,,,,,,, -56400,1.8525463,1.6261607,,,,,,,,,,,,,, -56500,1.8804929,1.6125488,,,,,,,,,,,,,, -56600,2.0085056,1.4881511,,,,,,,,,,,,,, -56700,1.7845076,1.5917062,,,,,,,,,,,,,, -56800,1.8126146,1.6290177,,,,,,,,,,,,,, -56900,1.9367743,1.4910555,,,,,,,,,,,,,, -57000,1.6939901,1.4788486,,,,,,,,,,,,,, -57100,1.8693326,1.5974063,,,,,,,,,,,,,, -57200,1.9470948,1.6021931,,,,,,,,,,,,,, -57300,2.0172825,1.5303242,,,,,,,,,,,,,, -57400,1.9539366,1.5225623,,,,,,,,,,,,,, -57500,2.0553195,1.5441298,,,,,,,,,,,,,, -57600,2.1244357,1.5692762,,,,,,,,,,,,,, -57621,,,0.7300701141357422,1.035422444343567,0.6582599878311157,1.3901373147964478,50000.0,0.5300000309944153,2.1339540481567383,10000.0,19434.272382974625,20246.79317426681,19434.272382974625,809.2250037193298,1.2412221431732178,0.0 -57700,1.736495,1.4575033,,,,,,,,,,,,,, -57800,1.8780729,1.6313926,,,,,,,,,,,,,, -57900,1.7533398,1.659629,,,,,,,,,,,,,, -58000,1.8839709,1.5130454,,,,,,,,,,,,,, -58100,2.328613,1.6721113,,,,,,,,,,,,,, -58200,1.769896,1.6016126,,,,,,,,,,,,,, -58300,2.013812,1.4564334,,,,,,,,,,,,,, -58400,1.8175311,1.3697855,,,,,,,,,,,,,, -58500,1.8413347,1.6421409,,,,,,,,,,,,,, -58600,1.9268204,1.6128234,,,,,,,,,,,,,, -58700,2.1554933,1.4960157,,,,,,,,,,,,,, -58800,1.724191,1.4804019,,,,,,,,,,,,,, -58900,2.0511637,1.5542197,,,,,,,,,,,,,, -59000,1.9604058,1.5150925,,,,,,,,,,,,,, -59100,1.8429894,1.6590623,,,,,,,,,,,,,, -59137,,,0.7110969424247742,1.1127877235412598,0.6474999785423279,1.4427329301834106,50000.0,0.5121999979019165,2.223588466644287,10000.0,19944.21973395348,20774.16160988808,19944.21973395348,826.5554084777832,1.2777154445648191,0.0 -59200,1.9094234,1.5169247,,,,,,,,,,,,,, -59300,1.7574923,1.4775063,,,,,,,,,,,,,, -59400,1.9195293,1.6159501,,,,,,,,,,,,,, -59500,1.8851451,1.5136976,,,,,,,,,,,,,, -59600,1.8899567,1.6029485,,,,,,,,,,,,,, -59700,2.021863,1.5202107,,,,,,,,,,,,,, -59800,1.8856024,1.6170034,,,,,,,,,,,,,, -59900,2.1463537,1.4670619,,,,,,,,,,,,,, -60000,1.9445556,1.5623404,,,,,,,,,,,,,, -60100,1.7915295,1.5547601,,,,,,,,,,,,,, -60200,2.1001956,1.6302899,,,,,,,,,,,,,, -60300,2.16319,1.703775,,,,,,,,,,,,,, -60400,1.9551924,1.6390649,,,,,,,,,,,,,, -60500,1.9584517,1.6139828,,,,,,,,,,,,,, -60600,2.0960057,1.6062737,,,,,,,,,,,,,, -60655,,,0.7175143361091614,1.090225100517273,0.6553599834442139,1.408957600593567,50000.0,0.5234000086784363,2.151703119277954,10000.0,20454.302026748657,21302.164501667023,20454.302026748657,844.3818547725677,1.316425323486328,0.0 -60700,1.990917,1.5608599,,,,,,,,,,,,,, -60800,2.1513145,1.646625,,,,,,,,,,,,,, -60900,1.7494495,1.5870801,,,,,,,,,,,,,, -61000,1.7715518,1.5102668,,,,,,,,,,,,,, -61100,1.7726859,1.5650982,,,,,,,,,,,,,, -61200,1.9103838,1.5191886,,,,,,,,,,,,,, -61300,1.9116768,1.5526764,,,,,,,,,,,,,, -61400,2.1012082,1.7271875,,,,,,,,,,,,,, -61500,1.7512867,1.650458,,,,,,,,,,,,,, -61600,1.9158108,1.6508869,,,,,,,,,,,,,, -61700,1.8550408,1.5561874,,,,,,,,,,,,,, -61800,1.9178296,1.5829105,,,,,,,,,,,,,, -61900,1.7696751,1.5451382,,,,,,,,,,,,,, -62000,1.9979824,1.5434678,,,,,,,,,,,,,, -62100,1.8226243,1.5191438,,,,,,,,,,,,,, -62173,,,0.7569156289100647,0.9074265360832214,0.6525200009346008,1.410874843597412,50000.0,0.51910001039505,2.190683364868164,10000.0,20964.30534338951,21830.071031332016,20964.30534338951,862.1902160644531,1.356015682220459,0.0 -62200,2.329591,1.5323019,,,,,,,,,,,,,, -62300,2.176008,1.5020717,,,,,,,,,,,,,, -62400,2.267555,1.5060934,,,,,,,,,,,,,, -62500,1.8957745,1.5017271,,,,,,,,,,,,,, -62600,1.8266699,1.4977597,,,,,,,,,,,,,, -62700,2.0391533,1.5589588,,,,,,,,,,,,,, -62800,1.8321675,1.4750057,,,,,,,,,,,,,, -62900,1.9811825,1.5032338,,,,,,,,,,,,,, -63000,1.8624054,1.4778042,,,,,,,,,,,,,, -63100,2.2485092,1.5945382,,,,,,,,,,,,,, -63200,2.1723926,1.6059202,,,,,,,,,,,,,, -63300,2.0489192,1.513541,,,,,,,,,,,,,, -63400,2.0031002,1.6230711,,,,,,,,,,,,,, -63500,1.9744358,1.4759985,,,,,,,,,,,,,, -63600,1.9295912,1.543821,,,,,,,,,,,,,, -63691,,,0.7365473508834839,1.001426100730896,0.6548999547958374,1.4198815822601318,50000.0,0.5295000076293945,2.1336960792541504,10000.0,21474.428068876263,22359.483619451523,21474.428068876263,881.3823571205139,1.397477388381958,0.0 -63700,2.0818844,1.5539083,,,,,,,,,,,,,, -63800,1.8186576,1.4762651,,,,,,,,,,,,,, -63900,1.9343098,1.6413027,,,,,,,,,,,,,, -64000,2.0635853,1.6602279,,,,,,,,,,,,,, -64100,2.1431751,1.5470158,,,,,,,,,,,,,, -64200,1.9481876,1.6499481,,,,,,,,,,,,,, -64300,2.2164009,1.6294943,,,,,,,,,,,,,, -64400,1.9817357,1.5269263,,,,,,,,,,,,,, -64500,1.953206,1.5679314,,,,,,,,,,,,,, -64600,1.9202836,1.4890096,,,,,,,,,,,,,, -64700,2.0289645,1.5074735,,,,,,,,,,,,,, -64800,1.9765457,1.6104281,,,,,,,,,,,,,, -64900,2.4635925,1.4722733,,,,,,,,,,,,,, -65000,2.2005863,1.5531938,,,,,,,,,,,,,, -65100,1.9649897,1.5454949,,,,,,,,,,,,,, -65200,2.2217693,1.5410836,,,,,,,,,,,,,, -65209,,,0.7339963316917419,1.0098152160644531,0.6584599614143372,1.3867906332015991,50000.0,0.5376999974250793,2.088006019592285,10000.0,21984.534370660785,22887.84405231476,21984.534370660785,899.5424103736877,1.4358513355255127,0.0 -65300,1.7857863,1.4110432,,,,,,,,,,,,,, -65400,1.9499594,1.5696588,,,,,,,,,,,,,, -65500,2.263103,1.531949,,,,,,,,,,,,,, -65600,1.9559401,1.5063148,,,,,,,,,,,,,, -65700,2.0759745,1.5503172,,,,,,,,,,,,,, -65800,1.9837674,1.5236645,,,,,,,,,,,,,, -65900,2.1889353,1.4605805,,,,,,,,,,,,,, -66000,2.201473,1.6302295,,,,,,,,,,,,,, -66100,2.150632,1.4752047,,,,,,,,,,,,,, -66200,1.8094271,1.5225011,,,,,,,,,,,,,, -66300,1.9443303,1.554349,,,,,,,,,,,,,, -66400,1.7733405,1.4734867,,,,,,,,,,,,,, -66500,2.0631526,1.440453,,,,,,,,,,,,,, -66600,2.0531156,1.5761443,,,,,,,,,,,,,, -66700,1.8207864,1.4370403,,,,,,,,,,,,,, -66726,,,0.7247887253761292,1.0592175722122192,0.6513000130653381,1.425256371498108,50000.0,0.5291000008583069,2.114695310592652,10000.0,22494.692160129547,23415.391426324844,22494.692160129547,916.836306333542,1.4768211841583252,0.0 -66800,2.1212063,1.5621107,,,,,,,,,,,,,, -66900,2.3560297,1.5157537,,,,,,,,,,,,,, -67000,1.9620997,1.5581102,,,,,,,,,,,,,, -67100,2.3271108,1.6445565,,,,,,,,,,,,,, -67200,2.0677867,1.4726757,,,,,,,,,,,,,, -67300,1.9936439,1.5211067,,,,,,,,,,,,,, -67400,1.9467989,1.4859512,,,,,,,,,,,,,, -67500,2.2686725,1.4566237,,,,,,,,,,,,,, -67600,1.8482045,1.4033364,,,,,,,,,,,,,, -67700,2.2941093,1.595362,,,,,,,,,,,,,, -67800,1.995197,1.5278912,,,,,,,,,,,,,, -67900,2.1056776,1.5287232,,,,,,,,,,,,,, -68000,1.9928824,1.5368108,,,,,,,,,,,,,, -68100,2.117592,1.5337234,,,,,,,,,,,,,, -68200,2.16759,1.4793029,,,,,,,,,,,,,, -68243,,,0.7330994606018066,1.0156515836715698,0.6629999876022339,1.369211196899414,50000.0,0.531000018119812,2.121507167816162,10000.0,23004.89267778397,23943.01952242852,23004.89267778397,934.1671011447906,1.5193002223968506,0.0 -68300,1.9257045,1.48073,,,,,,,,,,,,,, -68400,2.1502542,1.4902499,,,,,,,,,,,,,, -68500,1.8955762,1.5374022,,,,,,,,,,,,,, -68600,1.9843279,1.534673,,,,,,,,,,,,,, -68700,2.0507183,1.5087484,,,,,,,,,,,,,, -68800,2.0319285,1.4966404,,,,,,,,,,,,,, -68900,2.029613,1.5163991,,,,,,,,,,,,,, -69000,2.0123286,1.568948,,,,,,,,,,,,,, -69100,2.0985937,1.5442232,,,,,,,,,,,,,, -69200,1.8847346,1.4841077,,,,,,,,,,,,,, -69300,1.9344815,1.3997524,,,,,,,,,,,,,, -69400,2.3773875,1.6642787,,,,,,,,,,,,,, -69500,2.259599,1.5320792,,,,,,,,,,,,,, -69600,2.1756742,1.5848398,,,,,,,,,,,,,, -69700,2.0322344,1.5558927,,,,,,,,,,,,,, -69760,,,0.7288743257522583,1.038947582244873,0.6620599627494812,1.3821091651916504,50000.0,0.5324000120162964,2.080422163009644,10000.0,23514.99297809601,24471.2444729805,23514.99297809601,952.1961376667024,1.5602679252624512,0.0 -69800,1.9875922,1.5666302,,,,,,,,,,,,,, -69900,2.0986714,1.5747585,,,,,,,,,,,,,, -70000,1.9189948,1.5716517,,,,,,,,,,,,,, -70100,1.9245819,1.4996785,,,,,,,,,,,,,, -70200,2.1008832,1.5447743,,,,,,,,,,,,,, -70300,2.0048985,1.5129459,,,,,,,,,,,,,, -70400,2.1058803,1.4676925,,,,,,,,,,,,,, -70500,2.1802602,1.6952586,,,,,,,,,,,,,, -70600,2.427061,1.6849688,,,,,,,,,,,,,, -70700,2.0944903,1.5340433,,,,,,,,,,,,,, -70800,1.8188277,1.4714818,,,,,,,,,,,,,, -70900,2.043485,1.517706,,,,,,,,,,,,,, -71000,2.1669345,1.6403358,,,,,,,,,,,,,, -71100,1.8951354,1.4090295,,,,,,,,,,,,,, -71200,2.386416,1.5763611,,,,,,,,,,,,,, -71277,,,0.7565967440605164,0.9106816649436952,0.6576399803161621,1.408387303352356,50000.0,0.5273000001907349,2.185938835144043,10000.0,24025.17360162735,24998.686302900314,24025.17360162735,969.347799539566,1.6115102767944336,0.0 -71300,2.0553074,1.6028615,,,,,,,,,,,,,, -71400,2.1181624,1.4468439,,,,,,,,,,,,,, -71500,2.289858,1.5650282,,,,,,,,,,,,,, -71600,1.979555,1.5228679,,,,,,,,,,,,,, -71700,2.0674663,1.4432381,,,,,,,,,,,,,, -71800,1.8980643,1.4292301,,,,,,,,,,,,,, -71900,2.2165952,1.5338221,,,,,,,,,,,,,, -72000,1.973462,1.5363885,,,,,,,,,,,,,, -72100,2.012824,1.44749,,,,,,,,,,,,,, -72200,2.0026207,1.5442082,,,,,,,,,,,,,, -72300,2.2676704,1.48673,,,,,,,,,,,,,, -72400,1.8877158,1.4314682,,,,,,,,,,,,,, -72500,1.9821826,1.4975984,,,,,,,,,,,,,, -72600,2.1169174,1.4900901,,,,,,,,,,,,,, -72700,2.2832146,1.5669953,,,,,,,,,,,,,, -72794,,,0.7460139989852905,0.9631520509719848,0.6609399914741516,1.3794595003128052,50000.0,0.5297999978065491,2.1230790615081787,10000.0,24535.09615302086,25525.9612326622,24535.09615302086,986.60405087471,1.6504340171813965,0.0 -72800,1.9701087,1.5315977,,,,,,,,,,,,,, -72900,2.0177767,1.4379352,,,,,,,,,,,,,, -73000,2.0732903,1.619544,,,,,,,,,,,,,, -73100,2.0037448,1.6943735,,,,,,,,,,,,,, -73200,2.2547495,1.5738423,,,,,,,,,,,,,, -73300,1.8122171,1.4322993,,,,,,,,,,,,,, -73400,1.8286358,1.3562123,,,,,,,,,,,,,, -73500,2.2192323,1.5501401,,,,,,,,,,,,,, -73600,2.0559385,1.4547396,,,,,,,,,,,,,, -73700,2.0788527,1.4750978,,,,,,,,,,,,,, -73800,2.0006773,1.4614868,,,,,,,,,,,,,, -73900,2.2164156,1.47377,,,,,,,,,,,,,, -74000,2.1333587,1.5714281,,,,,,,,,,,,,, -74100,2.2053778,1.5659426,,,,,,,,,,,,,, -74200,2.0062337,1.5292628,,,,,,,,,,,,,, -74300,2.0947876,1.405876,,,,,,,,,,,,,, -74311,,,0.7478475570678711,0.971817135810852,0.6683200001716614,1.3521226644515991,50000.0,0.5426000356674194,2.07227635383606,10000.0,25045.33080291748,26054.60822916031,25045.33080291748,1004.921329498291,1.6894316673278809,0.0 -74400,2.050686,1.6259532,,,,,,,,,,,,,, -74500,1.842123,1.3833911,,,,,,,,,,,,,, -74600,2.117752,1.5622698,,,,,,,,,,,,,, -74700,2.2257586,1.5067939,,,,,,,,,,,,,, -74800,1.9979129,1.3860475,,,,,,,,,,,,,, -74900,2.162665,1.5661142,,,,,,,,,,,,,, -75000,2.3035617,1.4863086,,,,,,,,,,,,,, -75100,1.8611317,1.493713,,,,,,,,,,,,,, -75200,2.1303825,1.6028647,,,,,,,,,,,,,, -75300,2.0306973,1.5502658,,,,,,,,,,,,,, -75400,2.2696211,1.4197552,,,,,,,,,,,,,, -75500,2.4003959,1.4843094,,,,,,,,,,,,,, -75600,2.078195,1.477791,,,,,,,,,,,,,, -75700,2.3408465,1.5594814,,,,,,,,,,,,,, -75800,2.2001102,1.5386026,,,,,,,,,,,,,, -75829,,,0.7421077489852905,0.9846024513244628,0.6684799790382385,1.3483023643493652,50000.0,0.5388000011444092,2.0896012783050537,10000.0,25555.53547239304,26583.593682527546,25555.53547239304,1023.604052066803,1.732506513595581,0.0 -75900,2.1728108,1.4058495,,,,,,,,,,,,,, -76000,2.0692096,1.4408534,,,,,,,,,,,,,, -76100,2.1365438,1.4216672,,,,,,,,,,,,,, -76200,2.0043166,1.4288726,,,,,,,,,,,,,, -76300,2.1379402,1.3614595,,,,,,,,,,,,,, -76400,2.1342967,1.5229892,,,,,,,,,,,,,, -76500,2.1676962,1.5579685,,,,,,,,,,,,,, -76600,2.2574542,1.543776,,,,,,,,,,,,,, -76700,1.9771243,1.5365188,,,,,,,,,,,,,, -76800,2.2357717,1.5888902,,,,,,,,,,,,,, -76900,2.038103,1.4876132,,,,,,,,,,,,,, -77000,2.3308756,1.4376316,,,,,,,,,,,,,, -77100,2.061639,1.4932717,,,,,,,,,,,,,, -77200,2.000136,1.4484829,,,,,,,,,,,,,, -77300,2.1939018,1.5224038,,,,,,,,,,,,,, -77346,,,0.7206632494926453,1.0757194757461548,0.6527599692344666,1.4330989122390747,50000.0,0.5294000506401062,2.150460720062256,10000.0,26065.73763155937,27111.376794815063,26065.73763155937,1041.087366104126,1.7743992805480957,0.0 -77400,2.2627635,1.5029536,,,,,,,,,,,,,, -77500,1.9921154,1.4746578,,,,,,,,,,,,,, -77600,2.0303347,1.4557469,,,,,,,,,,,,,, -77700,2.210848,1.4816301,,,,,,,,,,,,,, -77800,1.9872284,1.3644558,,,,,,,,,,,,,, -77900,2.1915767,1.453552,,,,,,,,,,,,,, -78000,2.3169217,1.5366725,,,,,,,,,,,,,, -78100,2.1922913,1.4368889,,,,,,,,,,,,,, -78200,2.2748077,1.5081915,,,,,,,,,,,,,, -78300,2.1302087,1.5169015,,,,,,,,,,,,,, -78400,2.0567544,1.5071331,,,,,,,,,,,,,, -78500,2.0611477,1.401055,,,,,,,,,,,,,, -78600,1.941357,1.4247361,,,,,,,,,,,,,, -78700,2.0794759,1.5124302,,,,,,,,,,,,,, -78800,2.307352,1.4737746,,,,,,,,,,,,,, -78861,,,0.7466318607330322,0.96878182888031,0.6676799654960632,1.3544172048568726,50000.0,0.5441000461578369,2.095369815826416,10000.0,26574.949008226395,27638.854197502136,26574.949008226395,1058.3813076019287,2.6917877197265625,0.0 -78900,2.1096218,1.3856807,,,,,,,,,,,,,, -79000,2.2297924,1.5330946,,,,,,,,,,,,,, -79100,2.2503648,1.545017,,,,,,,,,,,,,, -79200,2.0212653,1.3788358,,,,,,,,,,,,,, -79300,2.1240563,1.5303972,,,,,,,,,,,,,, -79400,1.9862115,1.5281492,,,,,,,,,,,,,, -79500,2.2235413,1.4393921,,,,,,,,,,,,,, -79600,2.285141,1.4615095,,,,,,,,,,,,,, -79700,2.0945804,1.4359474,,,,,,,,,,,,,, -79800,2.295864,1.4570429,,,,,,,,,,,,,, -79900,2.080252,1.4317579,,,,,,,,,,,,,, -80000,2.3837404,1.4775509,,,,,,,,,,,,,, -80100,2.0617282,1.4984769,,,,,,,,,,,,,, -80200,2.1634543,1.4454883,,,,,,,,,,,,,, -80300,2.2383392,1.4849792,,,,,,,,,,,,,, -80378,,,0.7697305083274841,0.8728119730949402,0.6732000112533569,1.3282780647277832,50000.0,0.5437000393867493,2.056852340698242,10000.0,27084.90593457222,28166.12890410424,27084.90593457222,1075.6019878387451,2.734912633895874,0.0 -80400,2.2761152,1.4820473,,,,,,,,,,,,,, -80500,2.5360434,1.4713484,,,,,,,,,,,,,, -80600,2.1705494,1.5229692,,,,,,,,,,,,,, -80700,1.9681745,1.5226977,,,,,,,,,,,,,, -80800,2.3075256,1.4349076,,,,,,,,,,,,,, -80900,2.2393763,1.5245421,,,,,,,,,,,,,, -81000,2.0690653,1.4926925,,,,,,,,,,,,,, -81100,2.15054,1.481035,,,,,,,,,,,,,, -81200,1.87538,1.4411931,,,,,,,,,,,,,, -81300,2.0227065,1.4094795,,,,,,,,,,,,,, -81400,2.345636,1.4157777,,,,,,,,,,,,,, -81500,2.2898445,1.4059491,,,,,,,,,,,,,, -81600,2.0697038,1.5092828,,,,,,,,,,,,,, -81700,1.9838916,1.44238,,,,,,,,,,,,,, -81800,2.2418792,1.4579418,,,,,,,,,,,,,, -81894,,,0.748066782951355,0.9496461153030396,0.6683799624443054,1.3535780906677246,50000.0,0.5425000190734863,2.098299026489258,10000.0,27594.912347316746,28693.44144463539,27594.912347316746,1092.8077733516693,2.7800300121307373,0.0 -81900,2.1055417,1.3728268,,,,,,,,,,,,,, -82000,2.1234035,1.4359764,,,,,,,,,,,,,, -82100,2.1479735,1.4604137,,,,,,,,,,,,,, -82200,2.2152321,1.48006,,,,,,,,,,,,,, -82300,2.36881,1.4922558,,,,,,,,,,,,,, -82400,2.2748072,1.4499795,,,,,,,,,,,,,, -82500,2.1863039,1.4125748,,,,,,,,,,,,,, -82600,2.2989461,1.5336716,,,,,,,,,,,,,, -82700,2.3403244,1.3201989,,,,,,,,,,,,,, -82800,2.1181445,1.3502818,,,,,,,,,,,,,, -82900,2.4692802,1.4266729,,,,,,,,,,,,,, -83000,2.104908,1.4569802,,,,,,,,,,,,,, -83100,2.227309,1.500735,,,,,,,,,,,,,, -83200,2.130532,1.40699,,,,,,,,,,,,,, -83300,2.1523907,1.3721945,,,,,,,,,,,,,, -83400,2.1147056,1.3333633,,,,,,,,,,,,,, -83411,,,0.7523118257522583,0.942414402961731,0.6723600029945374,1.330456018447876,50000.0,0.5421000123023987,2.059842586517334,10000.0,28104.9787709713,29220.95869898796,28104.9787709713,1110.1627497673037,2.821474552154541,0.0 -83500,2.2637546,1.3842505,,,,,,,,,,,,,, -83600,2.3139095,1.4675487,,,,,,,,,,,,,, -83700,2.1528754,1.3696342,,,,,,,,,,,,,, -83800,2.2126904,1.4494712,,,,,,,,,,,,,, -83900,2.4875474,1.4086387,,,,,,,,,,,,,, -84000,2.2080722,1.4675255,,,,,,,,,,,,,, -84100,1.9967802,1.3292593,,,,,,,,,,,,,, -84200,2.081078,1.4124383,,,,,,,,,,,,,, -84300,2.0984178,1.3430715,,,,,,,,,,,,,, -84400,2.1331484,1.4168439,,,,,,,,,,,,,, -84500,2.2478466,1.4750228,,,,,,,,,,,,,, -84600,2.119572,1.406665,,,,,,,,,,,,,, -84700,2.2820804,1.5242274,,,,,,,,,,,,,, -84800,2.3473241,1.47177,,,,,,,,,,,,,, -84900,2.2580798,1.4649098,,,,,,,,,,,,,, -84928,,,0.7465720772743225,0.9657593369483948,0.6717000007629395,1.3397146463394165,50000.0,0.5400000214576721,2.069480657577514,10000.0,28615.038668632507,29748.444899082184,28615.038668632507,1127.4880149364471,2.866657495498657,0.0 -85000,2.1734731,1.3933645,,,,,,,,,,,,,, -85100,2.4480097,1.4472475,,,,,,,,,,,,,, -85200,2.131366,1.398042,,,,,,,,,,,,,, -85300,2.1842446,1.4370056,,,,,,,,,,,,,, -85400,2.0582654,1.3738178,,,,,,,,,,,,,, -85500,2.1074984,1.4587003,,,,,,,,,,,,,, -85600,2.0665174,1.4396982,,,,,,,,,,,,,, -85700,2.1355286,1.3962262,,,,,,,,,,,,,, -85800,2.101926,1.3541741,,,,,,,,,,,,,, -85900,2.0327926,1.3348951,,,,,,,,,,,,,, -86000,2.1045918,1.3897604,,,,,,,,,,,,,, -86100,2.2785687,1.4164876,,,,,,,,,,,,,, -86200,2.0066667,1.3463495,,,,,,,,,,,,,, -86300,2.0558336,1.3635911,,,,,,,,,,,,,, -86400,2.1642034,1.4629204,,,,,,,,,,,,,, -86444,,,0.7580317258834839,0.906674325466156,0.6821799874305725,1.300117254257202,50000.0,0.5557000041007996,2.0219106674194336,10000.0,29125.09669423104,30276.12886691093,29125.09669423104,1145.004815340042,2.921449422836304,0.0 -86500,2.248385,1.5157834,,,,,,,,,,,,,, -86600,2.5030582,1.4291786,,,,,,,,,,,,,, -86700,2.3168247,1.3359349,,,,,,,,,,,,,, -86800,2.469908,1.4234678,,,,,,,,,,,,,, -86900,2.1723642,1.4208968,,,,,,,,,,,,,, -87000,2.3645868,1.4625344,,,,,,,,,,,,,, -87100,2.4917245,1.4965874,,,,,,,,,,,,,, -87200,2.1980977,1.3804647,,,,,,,,,,,,,, -87300,2.398783,1.4469504,,,,,,,,,,,,,, -87400,2.4606154,1.4197695,,,,,,,,,,,,,, -87500,2.1926155,1.4473717,,,,,,,,,,,,,, -87600,2.1917012,1.3613927,,,,,,,,,,,,,, -87700,2.5268588,1.5467576,,,,,,,,,,,,,, -87800,2.1426892,1.3309047,,,,,,,,,,,,,, -87900,2.411948,1.422444,,,,,,,,,,,,,, -87961,,,0.7918526530265808,0.7671975493431091,0.6766999959945679,1.3069652318954468,50000.0,0.5504000186920166,2.0292537212371826,10000.0,29635.17948579788,30803.28497552872,29635.17948579788,1161.9832620620728,2.962131977081299,0.0 -88000,2.0731778,1.4124087,,,,,,,,,,,,,, -88100,2.284741,1.4629693,,,,,,,,,,,,,, -88200,2.3923118,1.3716698,,,,,,,,,,,,,, -88300,2.4289403,1.3869045,,,,,,,,,,,,,, -88400,2.3433042,1.3996556,,,,,,,,,,,,,, -88500,2.2673469,1.4559028,,,,,,,,,,,,,, -88600,2.2352386,1.2644181,,,,,,,,,,,,,, -88700,2.2699986,1.4728359,,,,,,,,,,,,,, -88800,2.2285304,1.3914849,,,,,,,,,,,,,, -88900,2.2239754,1.31764,,,,,,,,,,,,,, -89000,2.4222207,1.4944656,,,,,,,,,,,,,, -89100,2.3320165,1.4516499,,,,,,,,,,,,,, -89200,2.4822052,1.4097958,,,,,,,,,,,,,, -89300,2.078219,1.3448045,,,,,,,,,,,,,, -89400,2.3163457,1.5871625,,,,,,,,,,,,,, -89478,,,0.7750119566917419,0.8389298915863037,0.6815599799156189,1.3008902072906494,50000.0,0.5547000169754028,2.012035369873047,10000.0,30145.17422771454,31330.530110120773,30145.17422771454,1179.1324508190155,3.006996393203736,0.0 -89500,2.1156085,1.351914,,,,,,,,,,,,,, -89600,2.3659513,1.443285,,,,,,,,,,,,,, -89700,2.5199196,1.3725547,,,,,,,,,,,,,, -89800,2.385296,1.3466368,,,,,,,,,,,,,, -89900,2.6578877,1.3860497,,,,,,,,,,,,,, -90000,2.2451015,1.4345227,,,,,,,,,,,,,, -90100,2.293374,1.440725,,,,,,,,,,,,,, -90200,2.337383,1.375439,,,,,,,,,,,,,, -90300,2.192209,1.3250117,,,,,,,,,,,,,, -90400,2.1433814,1.3357843,,,,,,,,,,,,,, -90500,2.3226655,1.4132587,,,,,,,,,,,,,, -90600,2.337862,1.4243441,,,,,,,,,,,,,, -90700,2.3522615,1.430203,,,,,,,,,,,,,, -90800,2.1811097,1.4063246,,,,,,,,,,,,,, -90900,2.3936582,1.4426986,,,,,,,,,,,,,, -90996,,,0.7703682780265808,0.8547177910804749,0.6854400038719177,1.282153844833374,50000.0,0.5522000193595886,2.0209715366363525,10000.0,30655.103055000305,31857.856000185013,30655.103055000305,1196.418663978577,3.062368392944336,0.0 -91000,2.0701215,1.3103248,,,,,,,,,,,,,, -91100,2.3215933,1.4536778,,,,,,,,,,,,,, -91200,2.2948856,1.3876445,,,,,,,,,,,,,, -91300,1.9996545,1.3292884,,,,,,,,,,,,,, -91400,2.1900396,1.3214555,,,,,,,,,,,,,, -91500,2.2449305,1.3927273,,,,,,,,,,,,,, -91600,2.4755278,1.4469372,,,,,,,,,,,,,, -91700,2.2075312,1.3981414,,,,,,,,,,,,,, -91800,2.4464304,1.3534734,,,,,,,,,,,,,, -91900,2.4100204,1.4634472,,,,,,,,,,,,,, -92000,2.3176517,1.4427841,,,,,,,,,,,,,, -92100,2.6083038,1.4201392,,,,,,,,,,,,,, -92200,2.2195773,1.4763889,,,,,,,,,,,,,, -92300,2.2140574,1.3122942,,,,,,,,,,,,,, -92400,2.23444,1.3319716,,,,,,,,,,,,,, -92500,2.3839111,1.4209223,,,,,,,,,,,,,, -92513,,,0.7651067972183228,0.8746317625045776,0.6829400062561035,1.2891305685043335,50000.0,0.55840003490448,2.015486001968384,10000.0,31165.160599946976,32385.52663731575,31165.160599946976,1213.9319269657135,3.1073718070983887,0.0 -92600,2.515493,1.4350173,,,,,,,,,,,,,, -92700,2.205848,1.3911445,,,,,,,,,,,,,, -92800,2.4502738,1.3700974,,,,,,,,,,,,,, -92900,2.5225716,1.3458006,,,,,,,,,,,,,, -93000,2.58463,1.3865829,,,,,,,,,,,,,, -93100,2.4920962,1.310427,,,,,,,,,,,,,, -93200,2.2961035,1.3603555,,,,,,,,,,,,,, -93300,2.3806257,1.4076239,,,,,,,,,,,,,, -93400,2.6238756,1.4553405,,,,,,,,,,,,,, -93500,2.2201548,1.3534648,,,,,,,,,,,,,, -93600,2.258578,1.4112315,,,,,,,,,,,,,, -93700,2.247549,1.437726,,,,,,,,,,,,,, -93800,2.4534812,1.4087858,,,,,,,,,,,,,, -93900,2.3812668,1.4370387,,,,,,,,,,,,,, -94000,2.435419,1.3638979,,,,,,,,,,,,,, -94031,,,0.7631736397743225,0.8782700896263123,0.6861799955368042,1.2679184675216677,50000.0,0.5626000165939331,1.9711854457855225,10000.0,31675.272496938705,32913.0654964447,31675.272496938705,1231.2555334568024,3.1558098793029785,0.0 -94100,2.3072228,1.3497962,,,,,,,,,,,,,, -94200,2.3215578,1.3745811,,,,,,,,,,,,,, -94300,2.1977,1.2391179,,,,,,,,,,,,,, -94400,2.230671,1.324321,,,,,,,,,,,,,, -94500,2.2478921,1.3345993,,,,,,,,,,,,,, -94600,2.4848633,1.4413441,,,,,,,,,,,,,, -94700,2.5882463,1.3346462,,,,,,,,,,,,,, -94800,2.6507287,1.4353569,,,,,,,,,,,,,, -94900,2.5707881,1.3726662,,,,,,,,,,,,,, -95000,2.2054987,1.3825601,,,,,,,,,,,,,, -95100,2.4946783,1.314779,,,,,,,,,,,,,, -95200,2.690876,1.4058207,,,,,,,,,,,,,, -95300,2.3613667,1.4198602,,,,,,,,,,,,,, -95400,2.3586318,1.4247347,,,,,,,,,,,,,, -95500,2.5822856,1.2956245,,,,,,,,,,,,,, -95548,,,0.7469507455825806,0.953036606311798,0.6697399616241455,1.3450031280517578,50000.0,0.5527000427246094,2.0591471195220947,10000.0,32185.48177075386,33440.59732437134,32185.48177075386,1248.4793949127195,3.199093580245972,0.0 -95600,2.5156069,1.4336638,,,,,,,,,,,,,, -95700,2.3655925,1.3181628,,,,,,,,,,,,,, -95800,2.0987158,1.2888455,,,,,,,,,,,,,, -95900,2.690039,1.4111594,,,,,,,,,,,,,, -96000,2.410089,1.4003975,,,,,,,,,,,,,, -96100,2.363389,1.3266366,,,,,,,,,,,,,, -96200,2.2893195,1.3594362,,,,,,,,,,,,,, -96300,2.4663205,1.3857585,,,,,,,,,,,,,, -96400,2.5341086,1.4070734,,,,,,,,,,,,,, -96500,2.5579188,1.3979063,,,,,,,,,,,,,, -96600,2.4611385,1.451898,,,,,,,,,,,,,, -96700,2.5879943,1.3834538,,,,,,,,,,,,,, -96800,2.4376948,1.4295248,,,,,,,,,,,,,, -96900,2.3039691,1.2244385,,,,,,,,,,,,,, -97000,2.3385682,1.3425403,,,,,,,,,,,,,, -97065,,,0.8135961294174194,0.6818169951438904,0.6930800080299377,1.2403595447540283,50000.0,0.5666000247001648,1.9498519897460933,10000.0,32695.579294204712,33968.07897615433,32695.579294204712,1265.758617401123,3.2494664192199707,0.0 -97100,2.4163055,1.3117628,,,,,,,,,,,,,, -97200,2.7470562,1.4285278,,,,,,,,,,,,,, -97300,2.5086424,1.3588061,,,,,,,,,,,,,, -97400,2.587404,1.4026253,,,,,,,,,,,,,, -97500,2.2276149,1.33272,,,,,,,,,,,,,, -97600,2.4668603,1.4092324,,,,,,,,,,,,,, -97700,2.6262414,1.4019892,,,,,,,,,,,,,, -97800,2.1711717,1.3272626,,,,,,,,,,,,,, -97900,2.569129,1.4985994,,,,,,,,,,,,,, -98000,2.3479245,1.3972831,,,,,,,,,,,,,, -98100,2.8268974,1.438774,,,,,,,,,,,,,, -98200,2.265086,1.2673006,,,,,,,,,,,,,, -98300,2.3955314,1.3343159,,,,,,,,,,,,,, -98400,2.332148,1.4250286,,,,,,,,,,,,,, -98500,2.4479504,1.3934057,,,,,,,,,,,,,, -98581,,,0.7823262214660645,0.8076447248458862,0.6801799535751343,1.2974464893341064,50000.0,0.5488000512123108,2.041500806808472,10000.0,33205.586966753006,34495.32242035866,33205.586966753006,1282.8943948745728,3.295815944671631,0.0 -98600,2.3881676,1.3044204,,,,,,,,,,,,,, -98700,2.7162242,1.39134,,,,,,,,,,,,,, -98800,2.5029504,1.3687847,,,,,,,,,,,,,, -98900,2.3790193,1.3144678,,,,,,,,,,,,,, -99000,2.3411505,1.2834373,,,,,,,,,,,,,, -99100,2.3950803,1.4080586,,,,,,,,,,,,,, -99200,2.5497181,1.4428368,,,,,,,,,,,,,, -99300,2.4439523,1.3431048,,,,,,,,,,,,,, -99400,2.620233,1.3997146,,,,,,,,,,,,,, -99500,2.7344313,1.3610818,,,,,,,,,,,,,, -99600,2.5158193,1.4424559,,,,,,,,,,,,,, -99700,2.4460266,1.3501863,,,,,,,,,,,,,, -99800,2.2831426,1.2592458,,,,,,,,,,,,,, -99900,2.2352586,1.2349238,,,,,,,,,,,,,, -100000,2.3865826,1.4170532,,,,,,,,,,,,,, -100019,,,0.7859135866165161,0.7940120697021484,0.6896599531173706,1.251431226730347,50000.0,0.560200035572052,1.9564324617385864,10000.0,33715.81254816055,35022.6372282505,33715.81254816055,1299.8828177452087,3.344213485717773,0.0 -100100,2.4783862,1.352956,,,,,,,,,,,,,, -100200,2.4455457,1.2719887,,,,,,,,,,,,,, -100300,2.6344593,1.4020761,,,,,,,,,,,,,, -100400,2.6028185,1.3273469,,,,,,,,,,,,,, -100500,2.684891,1.341434,,,,,,,,,,,,,, -100600,2.4527164,1.3062522,,,,,,,,,,,,,, -100700,2.3885636,1.3293717,,,,,,,,,,,,,, -100800,2.265867,1.2121489,,,,,,,,,,,,,, -100900,2.2268262,1.2859769,,,,,,,,,,,,,, -101000,2.6272357,1.3556992,,,,,,,,,,,,,, -101100,2.3092542,1.2948948,,,,,,,,,,,,,, -101200,2.6006718,1.3261777,,,,,,,,,,,,,, -101300,2.4335463,1.3515327,,,,,,,,,,,,,, -101400,2.5781178,1.3390186,,,,,,,,,,,,,, -101500,2.4174926,1.3571869,,,,,,,,,,,,,, -101536,,,0.7773038744926453,0.815521776676178,0.6916799545288086,1.2453638315200806,50000.0,0.5621000528335571,1.9738190174102783,10000.0,34225.73194885254,35550.06074118614,34225.73194885254,1317.2859869003296,3.389284610748291,0.0 -101600,2.5393407,1.4214734,,,,,,,,,,,,,, -101700,2.534737,1.2715577,,,,,,,,,,,,,, -101800,2.5111527,1.2867517,,,,,,,,,,,,,, -101900,2.5119317,1.243176,,,,,,,,,,,,,, -102000,2.5663598,1.3582615,,,,,,,,,,,,,, -102100,2.3824744,1.2505915,,,,,,,,,,,,,, -102200,2.8742335,1.3852972,,,,,,,,,,,,,, -102300,2.6564724,1.2978028,,,,,,,,,,,,,, -102400,2.4814324,1.2002412,,,,,,,,,,,,,, -102500,2.4634638,1.4192889,,,,,,,,,,,,,, -102600,2.5630612,1.4292891,,,,,,,,,,,,,, -102700,2.2928567,1.2610552,,,,,,,,,,,,,, -102800,2.6125753,1.2976848,,,,,,,,,,,,,, -102900,2.717786,1.4133353,,,,,,,,,,,,,, -103000,2.5301175,1.3504318,,,,,,,,,,,,,, -103053,,,0.7824656963348389,0.8034330010414124,0.6937800049781799,1.2375240325927734,50000.0,0.5625,1.970668077468872,10000.0,34735.81096434593,36077.352365493774,34735.81096434593,1334.4074614048004,3.426310539245605,0.0 -103100,2.5356882,1.2939878,,,,,,,,,,,,,, -103200,2.3438437,1.3314333,,,,,,,,,,,,,, -103300,2.5256135,1.3730243,,,,,,,,,,,,,, -103400,2.6601856,1.3877655,,,,,,,,,,,,,, -103500,2.606864,1.4678432,,,,,,,,,,,,,, -103600,2.486079,1.3483293,,,,,,,,,,,,,, -103700,2.501781,1.3932167,,,,,,,,,,,,,, -103800,2.4884496,1.2781028,,,,,,,,,,,,,, -103900,2.6290855,1.3621018,,,,,,,,,,,,,, -104000,2.5090387,1.4376309,,,,,,,,,,,,,, -104100,2.4348729,1.1935947,,,,,,,,,,,,,, -104200,2.5515924,1.3352947,,,,,,,,,,,,,, -104300,2.6003895,1.3457601,,,,,,,,,,,,,, -104400,2.7452033,1.3035836,,,,,,,,,,,,,, -104500,2.4770548,1.2917091,,,,,,,,,,,,,, -104570,,,0.7754902839660645,0.8326342701911926,0.6926199793815613,1.252827763557434,50000.0,0.5625,2.004796266555786,10000.0,35245.739324092865,36604.654005527496,35245.739324092865,1351.678982257843,3.4716062545776367,0.0 -104600,2.4927347,1.3067178,,,,,,,,,,,,,, -104700,2.6998959,1.3504516,,,,,,,,,,,,,, -104800,2.4991324,1.264399,,,,,,,,,,,,,, -104900,2.8593245,1.3063198,,,,,,,,,,,,,, -105000,2.9315472,1.431032,,,,,,,,,,,,,, -105100,2.5043218,1.3007189,,,,,,,,,,,,,, -105200,2.6592975,1.361056,,,,,,,,,,,,,, -105300,2.4350157,1.296281,,,,,,,,,,,,,, -105400,2.8731887,1.2794914,,,,,,,,,,,,,, -105500,2.5189214,1.2771918,,,,,,,,,,,,,, -105600,2.581263,1.2866126,,,,,,,,,,,,,, -105700,2.5950947,1.3083504,,,,,,,,,,,,,, -105800,2.5590193,1.3847907,,,,,,,,,,,,,, -105900,2.461475,1.2700341,,,,,,,,,,,,,, -106000,2.5181925,1.1786773,,,,,,,,,,,,,, -106087,,,0.8234414458274841,0.6515009999275208,0.6938199996948242,1.2419836521148682,50000.0,0.5612000226974487,1.9858746528625488,10000.0,35755.65688419342,37132.086525440216,35755.65688419342,1369.0937926769257,3.5167410373687744,0.0 -106100,2.4147491,1.3845968,,,,,,,,,,,,,, -106200,2.5199277,1.3294894,,,,,,,,,,,,,, -106300,2.940055,1.3056228,,,,,,,,,,,,,, -106400,2.526141,1.3371094,,,,,,,,,,,,,, -106500,2.603075,1.3070512,,,,,,,,,,,,,, -106600,2.7591257,1.3308569,,,,,,,,,,,,,, -106700,2.9213831,1.3019505,,,,,,,,,,,,,, -106800,2.4658985,1.254775,,,,,,,,,,,,,, -106900,2.75951,1.3454732,,,,,,,,,,,,,, -107000,2.5006585,1.2466993,,,,,,,,,,,,,, -107100,2.5347345,1.1989324,,,,,,,,,,,,,, -107200,2.594363,1.3871095,,,,,,,,,,,,,, -107300,2.9459324,1.3015527,,,,,,,,,,,,,, -107400,2.7404947,1.3766949,,,,,,,,,,,,,, -107500,2.7566152,1.279746,,,,,,,,,,,,,, -107600,2.8114612,1.336435,,,,,,,,,,,,,, -107604,,,0.8073580861091614,0.7141561508178711,0.6963199973106384,1.2244770526885986,50000.0,0.5699000358581543,1.948879361152649,10000.0,36265.8586742878,37659.49220252037,36265.8586742878,1386.1959176063538,3.563281774520874,0.0 -107700,2.6954088,1.2968258,,,,,,,,,,,,,, -107800,2.4789197,1.2256888,,,,,,,,,,,,,, -107900,2.7570708,1.3797055,,,,,,,,,,,,,, -108000,2.835101,1.4643643,,,,,,,,,,,,,, -108100,2.5140514,1.2331537,,,,,,,,,,,,,, -108200,2.7209842,1.3428785,,,,,,,,,,,,,, -108300,2.867735,1.2432175,,,,,,,,,,,,,, -108400,2.621839,1.1578258,,,,,,,,,,,,,, -108500,2.6546385,1.2370555,,,,,,,,,,,,,, -108600,2.7628534,1.2581495,,,,,,,,,,,,,, -108700,2.7876003,1.3515218,,,,,,,,,,,,,, -108800,2.589071,1.3396382,,,,,,,,,,,,,, -108900,2.8235629,1.2467487,,,,,,,,,,,,,, -109000,2.6411667,1.2023085,,,,,,,,,,,,,, -109100,2.8548577,1.4002275,,,,,,,,,,,,,, -109121,,,0.7997050285339355,0.7437049746513367,0.6982399821281433,1.220030665397644,50000.0,0.5733000040054321,1.927659273147583,10000.0,36775.92725396156,38186.74701857567,36775.92725396156,1403.2778811454773,3.612172365188599,0.0 -109200,2.5818567,1.201469,,,,,,,,,,,,,, -109300,2.6395738,1.2800142,,,,,,,,,,,,,, -109400,2.6332934,1.3185133,,,,,,,,,,,,,, -109500,2.5277815,1.2961981,,,,,,,,,,,,,, -109600,2.5046926,1.2073891,,,,,,,,,,,,,, -109700,2.73729,1.305125,,,,,,,,,,,,,, -109800,2.854362,1.2353778,,,,,,,,,,,,,, -109900,2.6078656,1.3489361,,,,,,,,,,,,,, -110000,2.5045679,1.2738552,,,,,,,,,,,,,, -110100,2.8665984,1.3469255,,,,,,,,,,,,,, -110200,2.6647904,1.3473067,,,,,,,,,,,,,, -110300,2.9923837,1.3346117,,,,,,,,,,,,,, -110400,2.827901,1.3451902,,,,,,,,,,,,,, -110500,2.7673461,1.2580956,,,,,,,,,,,,,, -110600,3.0686114,1.36079,,,,,,,,,,,,,, -110638,,,0.7970942258834839,0.739464282989502,0.7014200091362,1.2130403518676758,50000.0,0.5708000063896179,1.961884617805481,10000.0,37285.923748254776,38714.17147755623,37285.923748254776,1420.5989344120026,3.665136575698853,0.0 -110700,2.8010328,1.312061,,,,,,,,,,,,,, -110800,2.529397,1.235223,,,,,,,,,,,,,, -110900,3.071655,1.3040158,,,,,,,,,,,,,, -111000,2.8368952,1.2386326,,,,,,,,,,,,,, -111100,2.6017091,1.3362659,,,,,,,,,,,,,, -111200,2.6743565,1.2331321,,,,,,,,,,,,,, -111300,2.8865464,1.2549459,,,,,,,,,,,,,, -111400,2.5420847,1.2430667,,,,,,,,,,,,,, -111500,2.764319,1.3001224,,,,,,,,,,,,,, -111600,2.8134933,1.2558881,,,,,,,,,,,,,, -111700,2.536413,1.2146473,,,,,,,,,,,,,, -111800,2.5066814,1.2564037,,,,,,,,,,,,,, -111900,2.627796,1.2330424,,,,,,,,,,,,,, -112000,2.8733835,1.2873027,,,,,,,,,,,,,, -112100,2.9277549,1.1990077,,,,,,,,,,,,,, -112156,,,0.7895607352256775,0.7663513422012329,0.6975199580192566,1.2365078926086426,50000.0,0.566100001335144,1.9826350212097168,10000.0,37796.09580469132,39241.75362706184,37796.09580469132,1437.905579805374,3.713687658309937,0.0 -112200,2.5787427,1.2474439,,,,,,,,,,,,,, -112300,2.600818,1.1874659,,,,,,,,,,,,,, -112400,2.8366697,1.2832668,,,,,,,,,,,,,, -112500,2.5317183,1.2422147,,,,,,,,,,,,,, -112600,2.790511,1.2525439,,,,,,,,,,,,,, -112700,2.57974,1.2461952,,,,,,,,,,,,,, -112800,2.7145257,1.2008541,,,,,,,,,,,,,, -112900,2.680488,1.2885256,,,,,,,,,,,,,, -113000,2.830942,1.302236,,,,,,,,,,,,,, -113100,2.7039924,1.2510228,,,,,,,,,,,,,, -113200,2.69651,1.2887894,,,,,,,,,,,,,, -113300,2.7453084,1.199482,,,,,,,,,,,,,, -113400,2.6423488,1.2207797,,,,,,,,,,,,,, -113500,2.5345886,1.1537216,,,,,,,,,,,,,, -113600,2.812916,1.2790135,,,,,,,,,,,,,, -113674,,,0.7941246628761292,0.7522966861724854,0.6987999677658081,1.2218286991119385,50000.0,0.570900022983551,1.9535249471664429,10000.0,38306.05176615715,39769.040078401566,38306.05176615715,1455.1328389644625,3.761552095413208,0.0 -113700,2.8905466,1.2980307,,,,,,,,,,,,,, -113800,2.5836205,1.1534147,,,,,,,,,,,,,, -113900,2.7073417,1.3019099,,,,,,,,,,,,,, -114000,2.7752244,1.2401949,,,,,,,,,,,,,, -114100,2.8163795,1.2387742,,,,,,,,,,,,,, -114200,2.713364,1.2586578,,,,,,,,,,,,,, -114300,2.881036,1.3046377,,,,,,,,,,,,,, -114400,2.9532654,1.2706914,,,,,,,,,,,,,, -114500,2.5910332,1.2616003,,,,,,,,,,,,,, -114600,2.757263,1.296095,,,,,,,,,,,,,, -114700,3.083515,1.2437509,,,,,,,,,,,,,, -114800,2.604073,1.2069,,,,,,,,,,,,,, -114900,3.125823,1.2219353,,,,,,,,,,,,,, -115000,2.8560548,1.2647521,,,,,,,,,,,,,, -115100,2.5389383,1.2093914,,,,,,,,,,,,,, -115192,,,0.832051157951355,0.6108251214027405,0.700659990310669,1.2239474058151243,50000.0,0.5716000199317932,1.930112361907959,10000.0,38816.13313245773,40296.28511285782,38816.13313245773,1472.192055463791,3.811495065689087,0.0 -115200,2.945222,1.1899937,,,,,,,,,,,,,, -115300,2.70734,1.241557,,,,,,,,,,,,,, -115400,3.1086144,1.2548256,,,,,,,,,,,,,, -115500,2.747436,1.0503186,,,,,,,,,,,,,, -115600,2.8850892,1.2713459,,,,,,,,,,,,,, -115700,2.5917833,1.2232122,,,,,,,,,,,,,, -115800,2.865877,1.2704046,,,,,,,,,,,,,, -115900,2.8181856,1.2296842,,,,,,,,,,,,,, -116000,2.889771,1.2182312,,,,,,,,,,,,,, -116100,2.9520519,1.216613,,,,,,,,,,,,,, -116200,2.8278077,1.2336988,,,,,,,,,,,,,, -116300,3.1944265,1.2805882,,,,,,,,,,,,,, -116400,2.9809263,1.2419013,,,,,,,,,,,,,, -116500,2.7002246,1.1466858,,,,,,,,,,,,,, -116600,2.8750703,1.1680506,,,,,,,,,,,,,, -116700,2.9505587,1.2729586,,,,,,,,,,,,,, -116710,,,0.8093311190605164,0.6926111578941345,0.7005800008773804,1.2249276638031006,50000.0,0.5719000101089478,1.968600034713745,10000.0,39326.16049218178,40823.82935762405,39326.16049218178,1489.6076259613037,3.858402729034424,0.0 -116800,2.9226797,1.195563,,,,,,,,,,,,,, -116900,2.790192,1.316757,,,,,,,,,,,,,, -117000,2.961936,1.2145083,,,,,,,,,,,,,, -117100,2.8380642,1.2077651,,,,,,,,,,,,,, -117200,2.7536955,1.1184229,,,,,,,,,,,,,, -117300,2.8459573,1.2134596,,,,,,,,,,,,,, -117400,3.1578753,1.2006447,,,,,,,,,,,,,, -117500,3.0345018,1.2036273,,,,,,,,,,,,,, -117600,2.7812815,1.159484,,,,,,,,,,,,,, -117700,2.5432658,1.1651856,,,,,,,,,,,,,, -117800,2.9636657,1.3195498,,,,,,,,,,,,,, -117900,3.0884316,1.3321431,,,,,,,,,,,,,, -118000,3.0979245,1.2429708,,,,,,,,,,,,,, -118100,3.1502469,1.2805238,,,,,,,,,,,,,, -118200,2.837655,1.2574003,,,,,,,,,,,,,, -118228,,,0.8143933415412903,0.6682717204093933,0.7066400051116943,1.1898187398910522,50000.0,0.5803000330924988,1.9361289739608765,10000.0,39836.12970209122,41351.13400554657,39836.12970209122,1506.8389210700989,3.907601833343506,0.0 -118300,2.7588854,1.1176933,,,,,,,,,,,,,, -118400,2.891825,1.2390091,,,,,,,,,,,,,, -118500,3.422873,1.1890546,,,,,,,,,,,,,, -118600,3.179666,1.2837775,,,,,,,,,,,,,, -118700,2.7866907,1.1490527,,,,,,,,,,,,,, -118800,2.8030088,1.1789827,,,,,,,,,,,,,, -118900,3.1001816,1.2504083,,,,,,,,,,,,,, -119000,3.1133673,1.2202988,,,,,,,,,,,,,, -119100,3.0619178,1.1822958,,,,,,,,,,,,,, -119200,2.978108,1.1267915,,,,,,,,,,,,,, -119300,2.794662,1.1310709,,,,,,,,,,,,,, -119400,2.9143026,1.2570012,,,,,,,,,,,,,, -119500,2.9582467,1.1233667,,,,,,,,,,,,,, -119600,2.9418995,1.181492,,,,,,,,,,,,,, -119700,2.9243867,1.2043958,,,,,,,,,,,,,, -119746,,,0.8092713356018066,0.6910978555679321,0.7084599733352661,1.1896438598632812,50000.0,0.5770000219345093,1.9084508419036863,10000.0,40346.30208206177,41878.540466308594,40346.30208206177,1523.9704895019531,3.954399585723877,0.0 -119800,3.1459062,1.2597845,,,,,,,,,,,,,, -119900,3.163186,1.140075,,,,,,,,,,,,,, -120000,3.1635365,1.1861045,,,,,,,,,,,,,, -120100,2.765721,1.0310355,,,,,,,,,,,,,, -120200,2.6377048,1.1288906,,,,,,,,,,,,,, -120300,3.0298538,1.2343268,,,,,,,,,,,,,, -120400,3.2388687,1.2159758,,,,,,,,,,,,,, -120500,2.8876395,1.1996197,,,,,,,,,,,,,, -120600,3.1280847,1.1339594,,,,,,,,,,,,,, -120700,2.932101,1.1609093,,,,,,,,,,,,,, -120800,2.9682155,1.1924388,,,,,,,,,,,,,, -120900,2.8459682,1.1661433,,,,,,,,,,,,,, -121000,2.9773867,1.1106782,,,,,,,,,,,,,, -121100,2.962618,1.2497654,,,,,,,,,,,,,, -121200,2.8491905,1.040457,,,,,,,,,,,,,, -121263,,,0.8119817972183228,0.6661557555198669,0.7090799808502197,1.1816290616989136,50000.0,0.5848000049591064,1.903845191001892,10000.0,40856.50037384033,42406.03638720512,40856.50037384033,1541.1607983112335,4.005815267562866,0.0 -121300,2.9672306,1.0682243,,,,,,,,,,,,,, -121400,3.1056812,1.1303294,,,,,,,,,,,,,, -121500,2.8490307,1.0553086,,,,,,,,,,,,,, -121600,2.8695984,1.203583,,,,,,,,,,,,,, -121700,2.8613987,1.0857,,,,,,,,,,,,,, -121800,2.9707956,1.0877566,,,,,,,,,,,,,, -121900,3.1188197,1.2551615,,,,,,,,,,,,,, -122000,3.0085952,1.127041,,,,,,,,,,,,,, -122100,3.0962281,1.0936224,,,,,,,,,,,,,, -122200,3.0298645,1.1649636,,,,,,,,,,,,,, -122300,3.147533,1.1955686,,,,,,,,,,,,,, -122400,3.0687468,1.1027465,,,,,,,,,,,,,, -122500,2.7661846,1.1832991,,,,,,,,,,,,,, -122600,3.2426662,1.1018881,,,,,,,,,,,,,, -122700,3.3065279,1.1889815,,,,,,,,,,,,,, -122780,,,0.8156090378761292,0.6675543189048767,0.7113800048828125,1.1680610179901123,50000.0,0.5860000252723694,1.884270191192627,10000.0,41366.60287451744,42933.3178434372,41366.60287451744,1558.2389419078827,4.051869869232178,0.0 -122800,3.172288,1.2019272,,,,,,,,,,,,,, -122900,3.00425,1.1123931,,,,,,,,,,,,,, -123000,3.1207979,1.1990678,,,,,,,,,,,,,, -123100,3.0684865,1.1506007,,,,,,,,,,,,,, -123200,3.0560663,1.0733014,,,,,,,,,,,,,, -123300,2.9244404,1.2196933,,,,,,,,,,,,,, -123400,3.0130599,1.1551971,,,,,,,,,,,,,, -123500,2.9855895,1.1272497,,,,,,,,,,,,,, -123600,3.057453,1.1623504,,,,,,,,,,,,,, -123700,3.200393,1.2361544,,,,,,,,,,,,,, -123800,2.755235,1.131345,,,,,,,,,,,,,, -123900,2.8788476,1.0932324,,,,,,,,,,,,,, -124000,3.2512944,1.1534234,,,,,,,,,,,,,, -124100,3.2038662,1.2186029,,,,,,,,,,,,,, -124200,2.8957932,1.0662014,,,,,,,,,,,,,, -124297,,,0.8449258208274841,0.5565405488014221,0.710099995136261,1.1687265634536743,50000.0,0.5910000205039978,1.8822606801986688,10000.0,41876.761585474014,43460.797714948654,41876.761585474014,1575.457461833954,4.100099802017212,0.0 -124300,3.0399687,1.1082112,,,,,,,,,,,,,, -124400,2.8935194,1.1342223,,,,,,,,,,,,,, -124500,3.3444126,1.1841972,,,,,,,,,,,,,, -124600,3.3764768,1.2135652,,,,,,,,,,,,,, -124700,3.1320753,1.2354108,,,,,,,,,,,,,, -124800,3.127027,1.1079531,,,,,,,,,,,,,, -124900,3.0578625,1.1098844,,,,,,,,,,,,,, -125000,3.0710788,1.1765332,,,,,,,,,,,,,, -125100,3.110987,1.1603131,,,,,,,,,,,,,, -125200,2.9188104,1.2039797,,,,,,,,,,,,,, -125300,3.5223365,1.2350376,,,,,,,,,,,,,, -125400,2.8908958,1.0520312,,,,,,,,,,,,,, -125500,3.158979,1.2297456,,,,,,,,,,,,,, -125600,3.2888744,1.1591682,,,,,,,,,,,,,, -125700,2.9848745,1.0355673,,,,,,,,,,,,,, -125800,3.1325424,1.0625907,,,,,,,,,,,,,, -125814,,,0.8313934803009033,0.5971071124076843,0.714139997959137,1.1716444492340088,50000.0,0.5781000256538391,1.9235954284667969,10000.0,42386.82617163658,43988.02368068695,42386.82617163658,1592.5177104473114,4.146549224853516,0.0 -125900,3.0771413,1.1969862,,,,,,,,,,,,,, -126000,2.882644,1.0461664,,,,,,,,,,,,,, -126100,3.4342322,1.1317364,,,,,,,,,,,,,, -126200,3.35044,1.2087731,,,,,,,,,,,,,, -126300,3.1503847,1.1379352,,,,,,,,,,,,,, -126400,3.1594958,1.0570056,,,,,,,,,,,,,, -126500,3.3471973,1.2448593,,,,,,,,,,,,,, -126600,3.1226196,1.1323256,,,,,,,,,,,,,, -126700,3.192233,1.1772087,,,,,,,,,,,,,, -126800,3.1382701,1.191112,,,,,,,,,,,,,, -126900,3.136615,1.1775142,,,,,,,,,,,,,, -127000,3.5526848,1.2476201,,,,,,,,,,,,,, -127100,3.0620947,1.0612702,,,,,,,,,,,,,, -127200,3.0426826,1.1094693,,,,,,,,,,,,,, -127300,3.2037206,1.1339235,,,,,,,,,,,,,, -127332,,,0.8355787396430969,0.5848914980888367,0.7186799645423889,1.14486563205719,50000.0,0.5875000357627869,1.8843411207199097,10000.0,42896.90095162392,44515.46164655685,42896.90095162392,1609.7732861042025,4.200785160064697,0.0 -127400,2.9057496,1.0567192,,,,,,,,,,,,,, -127500,3.1729145,1.0650808,,,,,,,,,,,,,, -127600,3.022656,1.1322732,,,,,,,,,,,,,, -127700,3.473689,1.1354945,,,,,,,,,,,,,, -127800,3.177836,1.1038916,,,,,,,,,,,,,, -127900,3.3213933,1.0974574,,,,,,,,,,,,,, -128000,3.0362716,1.1065385,,,,,,,,,,,,,, -128100,2.887312,1.0337296,,,,,,,,,,,,,, -128200,2.8773124,1.083912,,,,,,,,,,,,,, -128300,3.0498703,1.0966114,,,,,,,,,,,,,, -128400,3.0913777,1.1243342,,,,,,,,,,,,,, -128500,3.0353234,1.147115,,,,,,,,,,,,,, -128600,3.1957278,1.1616565,,,,,,,,,,,,,, -128700,3.3793528,1.1923944,,,,,,,,,,,,,, -128800,3.482172,1.2101066,,,,,,,,,,,,,, -128849,,,0.8219467401504517,0.6358543634414673,0.7104799747467041,1.1965482234954834,50000.0,0.5900000333786011,1.9271339178085327,10000.0,43406.36694979668,45043.08939242363,43406.36694979668,1627.1122126579285,4.968811750411987,0.0 -128900,3.1140597,0.9785807,,,,,,,,,,,,,, -129000,3.2383,1.0712509,,,,,,,,,,,,,, -129100,2.977129,1.0771585,,,,,,,,,,,,,, -129200,3.123205,1.0640448,,,,,,,,,,,,,, -129300,2.9171448,1.0387623,,,,,,,,,,,,,, -129400,3.0962975,1.076998,,,,,,,,,,,,,, -129500,3.156703,1.1063048,,,,,,,,,,,,,, -129600,2.9727187,0.9952482,,,,,,,,,,,,,, -129700,3.28303,1.0712591,,,,,,,,,,,,,, -129800,3.1976213,1.1882625,,,,,,,,,,,,,, -129900,3.3214712,1.1830454,,,,,,,,,,,,,, -130000,3.216271,1.1505811,,,,,,,,,,,,,, -130100,3.0818949,1.1882474,,,,,,,,,,,,,, -130200,3.291301,1.1300855,,,,,,,,,,,,,, -130300,3.3106759,1.0210598,,,,,,,,,,,,,, -130366,,,0.8296396732330322,0.599907636642456,0.718559980392456,1.1586591005325315,50000.0,0.5892000198364258,1.9023849964141848,10000.0,43916.54048705101,45570.87974739075,43916.54048705101,1644.6134796142578,5.027925252914429,0.0 -130400,3.0192757,1.0233526,,,,,,,,,,,,,, -130500,3.394203,1.1989621,,,,,,,,,,,,,, -130600,3.5405245,1.208061,,,,,,,,,,,,,, -130700,3.503489,1.1103342,,,,,,,,,,,,,, -130800,3.358311,1.1134793,,,,,,,,,,,,,, -130900,3.1939108,1.0546674,,,,,,,,,,,,,, -131000,3.277489,1.0764377,,,,,,,,,,,,,, -131100,3.3985102,1.153173,,,,,,,,,,,,,, -131200,3.2220545,1.0659754,,,,,,,,,,,,,, -131300,3.3651795,1.16265,,,,,,,,,,,,,, -131400,3.3075438,1.1461923,,,,,,,,,,,,,, -131500,3.618037,1.15537,,,,,,,,,,,,,, -131600,3.244376,1.1148213,,,,,,,,,,,,,, -131700,3.3656826,1.1596975,,,,,,,,,,,,,, -131800,3.2581375,1.0793686,,,,,,,,,,,,,, -131884,,,0.8519012928009033,0.5267131328582764,0.7257999777793884,1.124706149101257,50000.0,0.6004000306129456,1.8547765016555784,10000.0,44426.55810856819,46098.40629863739,44426.55810856819,1662.0161790847778,5.078296184539795,0.0 -131900,3.3926444,1.1638222,,,,,,,,,,,,,, -132000,3.9968889,1.1432657,,,,,,,,,,,,,, -132100,3.2836027,1.1400807,,,,,,,,,,,,,, -132200,3.285174,1.0553412,,,,,,,,,,,,,, -132300,3.1743476,0.95605797,,,,,,,,,,,,,, -132400,3.3978105,1.069462,,,,,,,,,,,,,, -132500,3.532935,1.1388398,,,,,,,,,,,,,, -132600,3.2288463,0.95773715,,,,,,,,,,,,,, -132700,3.5430694,1.0155058,,,,,,,,,,,,,, -132800,3.310214,1.0251372,,,,,,,,,,,,,, -132900,3.1106865,1.0397401,,,,,,,,,,,,,, -133000,3.3843918,1.017797,,,,,,,,,,,,,, -133100,3.3023987,0.9806401,,,,,,,,,,,,,, -133200,3.5041802,1.0230091,,,,,,,,,,,,,, -133300,3.4727757,1.085093,,,,,,,,,,,,,, -133400,,,0.8651546239852905,0.4772387146949768,0.7245199680328369,1.1348223686218262,50000.0,0.5948000550270081,1.873829960823059,10000.0,44936.45907497406,46625.67883968353,44936.45907497406,1679.2852370738983,5.126734733581543,0.0 -133400,3.1377592,1.1082916,,,,,,,,,,,,,, -133500,3.3930302,1.1633488,,,,,,,,,,,,,, -133600,3.414935,1.0399112,,,,,,,,,,,,,, -133700,3.3465188,1.0645008,,,,,,,,,,,,,, -133800,3.2680128,1.0645801,,,,,,,,,,,,,, -133900,3.496638,1.0784091,,,,,,,,,,,,,, -134000,3.3534205,0.9922195,,,,,,,,,,,,,, -134100,3.1812775,1.0145814,,,,,,,,,,,,,, -134200,3.2863262,0.99763304,,,,,,,,,,,,,, -134300,3.3987145,1.0081478,,,,,,,,,,,,,, -134400,3.533263,1.0738693,,,,,,,,,,,,,, -134500,3.2301161,1.0059756,,,,,,,,,,,,,, -134600,3.493357,1.0494542,,,,,,,,,,,,,, -134700,3.5899732,1.0289345,,,,,,,,,,,,,, -134800,3.5106082,1.0692856,,,,,,,,,,,,,, -134900,3.6146417,1.0255355,,,,,,,,,,,,,, -134918,,,0.8583386540412903,0.5006170272827148,0.7210999727249146,1.1423414945602417,50000.0,0.5969000458717346,1.9024995565414429,10000.0,45446.67143511772,47153.21483302117,45446.67143511772,1696.497545480728,5.18248438835144,0.0 -135000,3.510096,1.1145991,,,,,,,,,,,,,, -135100,3.6553276,1.0496702,,,,,,,,,,,,,, -135200,3.6058476,1.0775807,,,,,,,,,,,,,, -135300,3.595434,1.1018565,,,,,,,,,,,,,, -135400,3.3461394,1.0562203,,,,,,,,,,,,,, -135500,3.4241557,1.1051788,,,,,,,,,,,,,, -135600,3.2707453,1.0004963,,,,,,,,,,,,,, -135700,3.3159087,0.93559694,,,,,,,,,,,,,, -135800,3.625864,1.1055498,,,,,,,,,,,,,, -135900,3.607774,0.9968999,,,,,,,,,,,,,, -136000,3.6045434,1.0019649,,,,,,,,,,,,,, -136100,3.55204,1.0347867,,,,,,,,,,,,,, -136200,3.4351478,0.99497443,,,,,,,,,,,,,, -136300,3.738869,1.0096531,,,,,,,,,,,,,, -136400,3.5097947,1.0433975,,,,,,,,,,,,,, -136436,,,0.852937638759613,0.5173604488372803,0.7234799861907959,1.1310100555419922,50000.0,0.5994000434875488,1.866504430770874,10000.0,45956.86107802391,47680.83400058746,45956.86107802391,1713.8090782165527,5.246311902999878,0.0 -136500,3.3442485,0.9433859,,,,,,,,,,,,,, -136600,3.8612678,1.1060735,,,,,,,,,,,,,, -136700,3.9755273,0.9947189,,,,,,,,,,,,,, -136800,3.3074322,0.9864184,,,,,,,,,,,,,, -136900,3.5165455,1.1036034,,,,,,,,,,,,,, -137000,3.3582764,1.009353,,,,,,,,,,,,,, -137100,3.5511334,1.117153,,,,,,,,,,,,,, -137200,3.7408042,1.0783601,,,,,,,,,,,,,, -137300,3.2486465,0.9852785,,,,,,,,,,,,,, -137400,3.837663,1.0848289,,,,,,,,,,,,,, -137500,3.5457978,1.0276296,,,,,,,,,,,,,, -137600,3.4704978,1.0301878,,,,,,,,,,,,,, -137700,3.6694636,0.9909062,,,,,,,,,,,,,, -137800,3.4731197,0.9945401,,,,,,,,,,,,,, -137900,4.063605,1.1308179,,,,,,,,,,,,,, -137953,,,0.8542529940605164,0.5037813782691956,0.7277799844741821,1.1227219104766846,50000.0,0.6015000343322754,1.8601411581039429,10000.0,46466.94680929184,48208.39493966103,46466.94680929184,1731.1785836219788,5.296331167221069,0.0 -138000,3.6429043,1.1233592,,,,,,,,,,,,,, -138100,3.6362617,0.94807565,,,,,,,,,,,,,, -138200,3.8083534,1.0750761,,,,,,,,,,,,,, -138300,3.720313,1.1063225,,,,,,,,,,,,,, -138400,3.423613,0.9937035,,,,,,,,,,,,,, -138500,3.5965822,0.9996985,,,,,,,,,,,,,, -138600,3.3970068,0.9287809,,,,,,,,,,,,,, -138700,3.5989132,0.95177674,,,,,,,,,,,,,, -138800,3.8158987,1.0435773,,,,,,,,,,,,,, -138900,3.6651673,0.97244126,,,,,,,,,,,,,, -139000,3.662939,1.0602144,,,,,,,,,,,,,, -139100,3.2242339,0.99661005,,,,,,,,,,,,,, -139200,3.584556,1.0001078,,,,,,,,,,,,,, -139300,3.8239388,0.9710008,,,,,,,,,,,,,, -139400,3.738756,0.97791284,,,,,,,,,,,,,, -139471,,,0.8598333597183228,0.4867278039455414,0.7307999730110168,1.1039236783981323,50000.0,0.5992000102996826,1.8634774684906008,10000.0,46977.11908912659,48735.66280722618,46977.11908912659,1748.1684920787811,5.347053289413452,0.0 -139500,3.2680306,0.87527895,,,,,,,,,,,,,, -139600,3.8273075,1.0220089,,,,,,,,,,,,,, -139700,3.4104626,0.91530144,,,,,,,,,,,,,, -139800,3.271408,0.9109895,,,,,,,,,,,,,, -139900,3.764086,0.98870176,,,,,,,,,,,,,, -140000,3.5422778,0.96498644,,,,,,,,,,,,,, -140100,3.456197,1.0008013,,,,,,,,,,,,,, -140200,3.6427846,1.0373361,,,,,,,,,,,,,, -140300,3.7631702,0.95399153,,,,,,,,,,,,,, -140400,3.7262864,1.0868728,,,,,,,,,,,,,, -140500,3.6543143,1.0392275,,,,,,,,,,,,,, -140600,3.4519432,1.0037284,,,,,,,,,,,,,, -140700,3.3705232,0.9300845,,,,,,,,,,,,,, -140800,3.5431154,0.930846,,,,,,,,,,,,,, -140900,3.8369937,1.0757034,,,,,,,,,,,,,, -140988,,,0.8991350531578064,0.3605747520923614,0.7303000092506409,1.114681601524353,50000.0,0.6034000515937805,1.8561393022537231,10000.0,47487.33484148979,49263.10889315605,47487.33484148979,1765.2932143211365,5.399057388305664,0.0 -141000,3.6520684,1.0948883,,,,,,,,,,,,,, -141100,3.5930276,1.0169914,,,,,,,,,,,,,, -141200,3.5641928,1.0432057,,,,,,,,,,,,,, -141300,3.640805,0.90727615,,,,,,,,,,,,,, -141400,3.394473,0.98842466,,,,,,,,,,,,,, -141500,3.53685,0.93890285,,,,,,,,,,,,,, -141600,3.580201,0.9496763,,,,,,,,,,,,,, -141700,3.818757,0.9819348,,,,,,,,,,,,,, -141800,3.5830421,1.0179399,,,,,,,,,,,,,, -141900,3.8968503,0.97403634,,,,,,,,,,,,,, -142000,3.6980734,0.99752873,,,,,,,,,,,,,, -142100,3.9626024,1.0756809,,,,,,,,,,,,,, -142200,3.8935556,0.9711319,,,,,,,,,,,,,, -142300,3.5814302,0.9422716,,,,,,,,,,,,,, -142400,3.6122763,0.918754,,,,,,,,,,,,,, -142500,3.7906845,0.91564244,,,,,,,,,,,,,, -142506,,,0.8854033350944519,0.3915910422801971,0.7321400046348572,1.103589653968811,50000.0,0.6010000109672546,1.8697199821472168,10000.0,47997.32077693939,49790.36906766892,47997.32077693939,1782.4610340595243,5.45065188407898,0.0 -142600,4.0858097,1.0457218,,,,,,,,,,,,,, -142700,3.7571542,0.96946317,,,,,,,,,,,,,, -142800,4.0137815,1.0820727,,,,,,,,,,,,,, -142900,4.044877,0.93839055,,,,,,,,,,,,,, -143000,4.220686,0.94684994,,,,,,,,,,,,,, -143100,3.7794666,0.93848914,,,,,,,,,,,,,, -143200,3.9018393,0.9735542,,,,,,,,,,,,,, -143300,3.5253575,0.9196376,,,,,,,,,,,,,, -143400,4.252893,0.93822175,,,,,,,,,,,,,, -143500,3.964755,0.9997482,,,,,,,,,,,,,, -143600,4.0456614,1.0352763,,,,,,,,,,,,,, -143700,3.9855528,0.9543198,,,,,,,,,,,,,, -143800,3.7114513,0.91632617,,,,,,,,,,,,,, -143900,3.9082847,1.001009,,,,,,,,,,,,,, -144000,3.4544795,0.8929803,,,,,,,,,,,,,, -144023,,,0.8768733739852905,0.4285891652107239,0.7296800017356873,1.1146568059921265,50000.0,0.6022000312805176,1.8817957639694207,10000.0,48507.233662605286,50317.64517068863,48507.233662605286,1799.7154169082642,5.505875587463379,0.0 -144100,3.9467392,1.0410682,,,,,,,,,,,,,, -144200,3.7796829,0.9483641,,,,,,,,,,,,,, -144300,3.7101479,0.96046644,,,,,,,,,,,,,, -144400,3.642596,0.9181637,,,,,,,,,,,,,, -144500,3.5575032,0.8554673,,,,,,,,,,,,,, -144600,3.9806929,0.97409886,,,,,,,,,,,,,, -144700,3.6714208,0.9543003,,,,,,,,,,,,,, -144800,3.798881,0.89897406,,,,,,,,,,,,,, -144900,3.809565,0.8849315,,,,,,,,,,,,,, -145000,3.9091284,0.9267357,,,,,,,,,,,,,, -145100,3.670456,0.95419794,,,,,,,,,,,,,, -145200,3.91226,1.0260324,,,,,,,,,,,,,, -145300,3.942429,1.0225354,,,,,,,,,,,,,, -145400,4.0471334,0.9160986,,,,,,,,,,,,,, -145500,3.7430904,0.91658926,,,,,,,,,,,,,, -145540,,,0.8807597160339355,0.4134978055953979,0.7312399744987488,1.1134945154190063,50000.0,0.6028000116348267,1.8748246431350708,10000.0,49017.14409446716,50844.926412820816,49017.14409446716,1816.9821512699127,5.555546760559082,0.0 -145600,3.6588912,0.94949174,,,,,,,,,,,,,, -145700,4.0221524,0.913053,,,,,,,,,,,,,, -145800,4.0416546,0.90816027,,,,,,,,,,,,,, -145900,4.197285,0.9920881,,,,,,,,,,,,,, -146000,3.5796516,0.86346686,,,,,,,,,,,,,, -146100,4.0027614,0.8890193,,,,,,,,,,,,,, -146200,3.6610641,0.95448476,,,,,,,,,,,,,, -146300,3.5958939,0.87980103,,,,,,,,,,,,,, -146400,3.950032,0.9106286,,,,,,,,,,,,,, -146500,4.089792,0.8912708,,,,,,,,,,,,,, -146600,3.9374654,0.9115675,,,,,,,,,,,,,, -146700,3.751023,0.8646635,,,,,,,,,,,,,, -146800,3.593941,0.8110945,,,,,,,,,,,,,, -146900,4.0648856,0.89468586,,,,,,,,,,,,,, -147000,3.5447233,0.9252224,,,,,,,,,,,,,, -147058,,,0.8866389989852905,0.3935048282146454,0.7372199892997742,1.0860953330993652,50000.0,0.6084000468254089,1.8452149629592896,10000.0,49527.31033730507,51372.41536259651,49527.31033730507,1834.1997547149656,5.6061952114105225,0.0 -147100,4.019,0.9366123,,,,,,,,,,,,,, -147200,3.673868,0.82786477,,,,,,,,,,,,,, -147300,3.5213401,0.8962561,,,,,,,,,,,,,, -147400,3.768005,0.9062819,,,,,,,,,,,,,, -147500,3.720676,0.82110053,,,,,,,,,,,,,, -147600,3.9977014,0.9393419,,,,,,,,,,,,,, -147700,3.788535,0.9360423,,,,,,,,,,,,,, -147800,4.064765,0.84673965,,,,,,,,,,,,,, -147900,4.1627045,0.86625636,,,,,,,,,,,,,, -148000,3.9381955,0.8383066,,,,,,,,,,,,,, -148100,3.972847,0.8434916,,,,,,,,,,,,,, -148200,4.2281623,0.8499971,,,,,,,,,,,,,, -148300,3.9505394,0.8846059,,,,,,,,,,,,,, -148400,4.105829,0.8234856,,,,,,,,,,,,,, -148500,3.9557083,0.8865319,,,,,,,,,,,,,, -148576,,,0.886738657951355,0.3839896619319916,0.7354399561882019,1.1088933944702148,50000.0,0.6083000302314758,1.8626391887664795,10000.0,50037.51057219505,51899.8086400032,50037.51057219505,1851.286039590836,5.658465147018433,0.0 -148600,4.066796,0.85815245,,,,,,,,,,,,,, -148700,3.663151,0.8985132,,,,,,,,,,,,,, -148800,3.81651,0.92701864,,,,,,,,,,,,,, -148900,3.7010326,0.86467576,,,,,,,,,,,,,, -149000,3.8195996,0.8801285,,,,,,,,,,,,,, -149100,3.567059,0.791683,,,,,,,,,,,,,, -149200,3.8070726,0.8271264,,,,,,,,,,,,,, -149300,3.5779998,0.7787211,,,,,,,,,,,,,, -149400,4.046014,0.90560704,,,,,,,,,,,,,, -149500,4.0223436,0.9086819,,,,,,,,,,,,,, -149600,4.301488,0.89114213,,,,,,,,,,,,,, -149700,3.7308948,0.87967134,,,,,,,,,,,,,, -149800,3.9975154,0.8611354,,,,,,,,,,,,,, -149900,3.6053088,0.8076214,,,,,,,,,,,,,, -150000,3.696187,0.9242706,,,,,,,,,,,,,, -150094,,,0.9194435477256776,0.2786987125873565,0.7404599785804749,1.088493824005127,50000.0,0.6107000112533569,1.8475301265716555,10000.0,50547.72637343407,52427.36676931381,50547.72637343407,1868.521843194961,5.710692644119263,0.0 -150100,4.074773,0.9491491,,,,,,,,,,,,,, -150200,3.8110518,0.85694444,,,,,,,,,,,,,, -150300,4.2576327,0.87195885,,,,,,,,,,,,,, -150400,4.4602766,0.98615056,,,,,,,,,,,,,, -150500,3.8889883,0.8416597,,,,,,,,,,,,,, -150600,4.362364,0.89697385,,,,,,,,,,,,,, -150700,4.0681267,0.8471923,,,,,,,,,,,,,, -150800,4.0638256,0.876997,,,,,,,,,,,,,, -150900,4.260385,0.9070478,,,,,,,,,,,,,, -151000,4.4293523,0.88431954,,,,,,,,,,,,,, -151100,3.993029,0.95801675,,,,,,,,,,,,,, -151200,3.7785046,0.8570747,,,,,,,,,,,,,, -151300,4.1036005,0.92120427,,,,,,,,,,,,,, -151400,4.294171,0.8588271,,,,,,,,,,,,,, -151500,3.9091122,0.8788755,,,,,,,,,,,,,, -151600,3.7342908,0.8197064,,,,,,,,,,,,,, -151612,,,0.9063894748687744,0.3237541913986206,0.7383399605751038,1.0937960147857666,50000.0,0.6117000579833984,1.8623268604278564,10000.0,51057.75116443634,52954.8849272728,51057.75116443634,1885.9064099788663,5.763494491577148,0.0 -151700,4.0162663,0.8665399,,,,,,,,,,,,,, -151800,3.971695,0.75497216,,,,,,,,,,,,,, -151900,4.2236967,0.80227864,,,,,,,,,,,,,, -152000,4.030259,0.8396295,,,,,,,,,,,,,, -152100,3.579048,0.75740033,,,,,,,,,,,,,, -152200,4.2870536,0.7860223,,,,,,,,,,,,,, -152300,3.756781,0.7199841,,,,,,,,,,,,,, -152400,4.047311,0.795624,,,,,,,,,,,,,, -152500,3.8748999,0.7688694,,,,,,,,,,,,,, -152600,4.4314647,0.98212063,,,,,,,,,,,,,, -152700,3.8693433,0.80258805,,,,,,,,,,,,,, -152800,4.2562137,0.8676407,,,,,,,,,,,,,, -152900,4.084877,0.8147486,,,,,,,,,,,,,, -153000,4.263281,0.91867644,,,,,,,,,,,,,, -153100,4.1503754,0.86077994,,,,,,,,,,,,,, -153129,,,0.9104551672935486,0.3100117743015289,0.7428399920463562,1.072890043258667,50000.0,0.6148000359535217,1.8307033777236936,10000.0,51567.71314716339,53482.98257446289,51567.71314716339,1903.927239894867,5.821335554122925,0.0 -153200,4.0906324,0.85908794,,,,,,,,,,,,,, -153300,4.288437,0.8840264,,,,,,,,,,,,,, -153400,3.8504179,0.8389749,,,,,,,,,,,,,, -153500,4.1986046,0.88712496,,,,,,,,,,,,,, -153600,4.129703,0.8429427,,,,,,,,,,,,,, -153700,3.7752035,0.7805183,,,,,,,,,,,,,, -153800,4.1766963,0.84204936,,,,,,,,,,,,,, -153900,4.6301394,0.8847167,,,,,,,,,,,,,, -154000,3.814305,0.8134435,,,,,,,,,,,,,, -154100,4.3181915,0.86766315,,,,,,,,,,,,,, -154200,3.9542785,0.8681412,,,,,,,,,,,,,, -154300,4.2241135,0.76568866,,,,,,,,,,,,,, -154400,3.9933558,0.84243554,,,,,,,,,,,,,, -154500,3.9594092,0.85121006,,,,,,,,,,,,,, -154600,4.016023,0.8246972,,,,,,,,,,,,,, -154646,,,0.908023715019226,0.3052965700626373,0.7430599927902222,1.0862489938735962,50000.0,0.6117000579833984,1.857428789138794,10000.0,52077.77292633057,54010.57850217819,52077.77292633057,1921.3516061306,5.878023862838745,0.0 -154700,3.6485896,0.7630458,,,,,,,,,,,,,, -154800,3.8515382,0.690879,,,,,,,,,,,,,, -154900,4.317773,0.7780123,,,,,,,,,,,,,, -155000,4.320065,0.8065795,,,,,,,,,,,,,, -155100,4.504295,0.8469467,,,,,,,,,,,,,, -155200,4.1299176,0.7838651,,,,,,,,,,,,,, -155300,4.514238,0.8798127,,,,,,,,,,,,,, -155400,4.1410413,0.8304487,,,,,,,,,,,,,, -155500,4.337402,0.8266859,,,,,,,,,,,,,, -155600,4.1167316,0.82026976,,,,,,,,,,,,,, -155700,3.9429421,0.8417724,,,,,,,,,,,,,, -155800,3.9688315,0.748713,,,,,,,,,,,,,, -155900,3.9306648,0.7933015,,,,,,,,,,,,,, -156000,3.9591062,0.7766211,,,,,,,,,,,,,, -156100,4.0453744,0.79449546,,,,,,,,,,,,,, -156163,,,0.9146803021430968,0.2902111411094665,0.744879961013794,1.0719773769378662,50000.0,0.616100013256073,1.843456506729126,10000.0,52587.91961193085,54538.00402379036,52587.91961193085,1938.523027420044,5.9307475090026855,0.0 -156200,4.220272,0.7565431,,,,,,,,,,,,,, -156300,4.293044,0.74740154,,,,,,,,,,,,,, -156400,4.203855,0.77038646,,,,,,,,,,,,,, -156500,4.5190196,0.8966929,,,,,,,,,,,,,, -156600,4.6643,0.8217509,,,,,,,,,,,,,, -156700,4.7513995,0.8385351,,,,,,,,,,,,,, -156800,4.4756994,0.8950538,,,,,,,,,,,,,, -156900,4.1056547,0.7467658,,,,,,,,,,,,,, -157000,4.634913,0.81444466,,,,,,,,,,,,,, -157100,4.2001348,0.7557552,,,,,,,,,,,,,, -157200,4.6267457,0.9069506,,,,,,,,,,,,,, -157300,4.071048,0.7718049,,,,,,,,,,,,,, -157400,4.644693,0.83771783,,,,,,,,,,,,,, -157500,4.2234874,0.8333271,,,,,,,,,,,,,, -157600,4.030279,0.72447926,,,,,,,,,,,,,, -157680,,,0.9210578799247742,0.2730793952941894,0.7454400062561035,1.0773038864135742,50000.0,0.619100034236908,1.838584065437317,10000.0,53097.9214026928,55065.41555118561,53097.9214026928,1955.8037884235384,6.003688812255859,0.0 -157700,4.290791,0.72096545,,,,,,,,,,,,,, -157800,4.4083085,0.9140335,,,,,,,,,,,,,, -157900,4.5719876,0.83903205,,,,,,,,,,,,,, -158000,4.2703695,0.75100315,,,,,,,,,,,,,, -158100,4.1391897,0.76261955,,,,,,,,,,,,,, -158200,4.3990226,0.8156285,,,,,,,,,,,,,, -158300,4.2267094,0.75013584,,,,,,,,,,,,,, -158400,4.72242,0.8706325,,,,,,,,,,,,,, -158500,4.129084,0.75552034,,,,,,,,,,,,,, -158600,4.558541,0.8315988,,,,,,,,,,,,,, -158700,4.368183,0.76709116,,,,,,,,,,,,,, -158800,4.0268517,0.7462472,,,,,,,,,,,,,, -158900,4.209974,0.7621022,,,,,,,,,,,,,, -159000,3.9387145,0.74271417,,,,,,,,,,,,,, -159100,4.420048,0.79073536,,,,,,,,,,,,,, -159197,,,0.9376793503761292,0.2213052958250045,0.7470600008964539,1.0691070556640625,50000.0,0.6219000220298767,1.8522708415985107,10000.0,53607.94852924347,55592.79269909859,53607.94852924347,1973.0448276996613,6.057266712188721,0.0 -159200,3.9761016,0.7783371,,,,,,,,,,,,,, -159300,4.487055,0.88292724,,,,,,,,,,,,,, -159400,4.5458407,0.76201576,,,,,,,,,,,,,, -159500,4.4388366,0.8273859,,,,,,,,,,,,,, -159600,4.4544916,0.80434376,,,,,,,,,,,,,, -159700,4.4196944,0.7196052,,,,,,,,,,,,,, -159800,4.1614876,0.7400329,,,,,,,,,,,,,, -159900,4.3803015,0.7170937,,,,,,,,,,,,,, -160000,3.8587692,0.67807156,,,,,,,,,,,,,, -160100,4.673785,0.7463829,,,,,,,,,,,,,, -160200,4.8036118,0.71714896,,,,,,,,,,,,,, -160300,4.3088503,0.7602238,,,,,,,,,,,,,, -160400,4.399337,0.78949976,,,,,,,,,,,,,, -160500,4.025043,0.7201965,,,,,,,,,,,,,, -160600,4.2922196,0.80166006,,,,,,,,,,,,,, -160700,3.6927004,0.66380394,,,,,,,,,,,,,, -160713,,,0.9351084232330322,0.2282746881246566,0.747439980506897,1.0624006986618042,50000.0,0.6206000447273254,1.8397870063781736,10000.0,54118.00081586838,56120.12428307533,54118.00081586838,1990.2089619636536,6.115020275115967,0.0 -160800,4.3255897,0.8164681,,,,,,,,,,,,,, -160900,4.2812223,0.7219722,,,,,,,,,,,,,, -161000,4.057491,0.7354361,,,,,,,,,,,,,, -161100,4.8480597,0.8501814,,,,,,,,,,,,,, -161200,4.18935,0.7001675,,,,,,,,,,,,,, -161300,3.993448,0.73010576,,,,,,,,,,,,,, -161400,4.1144648,0.7602694,,,,,,,,,,,,,, -161500,4.168143,0.7942252,,,,,,,,,,,,,, -161600,4.270178,0.8099345,,,,,,,,,,,,,, -161700,4.289897,0.78812855,,,,,,,,,,,,,, -161800,4.0209236,0.72069246,,,,,,,,,,,,,, -161900,4.887695,0.8541204,,,,,,,,,,,,,, -162000,4.037782,0.76483965,,,,,,,,,,,,,, -162100,4.3236256,0.73461384,,,,,,,,,,,,,, -162200,4.7196193,0.84261554,,,,,,,,,,,,,, -162230,,,0.9322983026504515,0.2332929223775863,0.7475199699401855,1.0625967979431152,50000.0,0.6203000545501709,1.8488606214523315,10000.0,54627.96944499016,56647.45462155342,54627.96944499016,2007.458515882492,6.172126054763794,0.0 -162300,4.2601113,0.7404227,,,,,,,,,,,,,, -162400,4.69387,0.7366117,,,,,,,,,,,,,, -162500,4.4376826,0.86208045,,,,,,,,,,,,,, -162600,4.2853136,0.68604845,,,,,,,,,,,,,, -162700,4.498309,0.7260393,,,,,,,,,,,,,, -162800,4.632325,0.75673074,,,,,,,,,,,,,, -162900,4.258388,0.7750785,,,,,,,,,,,,,, -163000,4.1841416,0.6990192,,,,,,,,,,,,,, -163100,4.520359,0.7476826,,,,,,,,,,,,,, -163200,4.7611113,0.7337019,,,,,,,,,,,,,, -163300,4.5598173,0.7543811,,,,,,,,,,,,,, -163400,4.734517,0.770963,,,,,,,,,,,,,, -163500,4.911901,0.75107133,,,,,,,,,,,,,, -163600,4.601017,0.8142113,,,,,,,,,,,,,, -163700,3.842851,0.66233087,,,,,,,,,,,,,, -163746,,,0.9357860088348388,0.2209810614585876,0.7492600083351135,1.0612397193908691,50000.0,0.6248000264167786,1.8343461751937864,10000.0,55137.90982818604,57174.564274311066,55137.90982818604,2024.5153052806847,6.228562831878662,0.0 -163800,4.6062365,0.7195723,,,,,,,,,,,,,, -163900,3.8448546,0.5683627,,,,,,,,,,,,,, -164000,4.1808143,0.68154556,,,,,,,,,,,,,, -164100,4.35594,0.7362904,,,,,,,,,,,,,, -164200,4.49243,0.6402412,,,,,,,,,,,,,, -164300,4.288255,0.711182,,,,,,,,,,,,,, -164400,4.0897193,0.7473372,,,,,,,,,,,,,, -164500,4.2316246,0.70407784,,,,,,,,,,,,,, -164600,4.089806,0.6820712,,,,,,,,,,,,,, -164700,4.579049,0.72147775,,,,,,,,,,,,,, -164800,4.471176,0.7952321,,,,,,,,,,,,,, -164900,4.595987,0.6487254,,,,,,,,,,,,,, -165000,4.040143,0.6431089,,,,,,,,,,,,,, -165100,4.437165,0.6940171,,,,,,,,,,,,,, -165200,4.5635614,0.7167182,,,,,,,,,,,,,, -165263,,,0.9393334984779358,0.211116150021553,0.7497999668121338,1.0586296319961548,50000.0,0.6249000430107117,1.8307875394821167,10000.0,55647.922043800354,57701.702865600586,55647.922043800354,2041.534103155136,6.281407117843628,0.0 -165300,4.1643505,0.6722709,,,,,,,,,,,,,, -165400,4.162836,0.6876708,,,,,,,,,,,,,, -165500,4.6081047,0.74734616,,,,,,,,,,,,,, -165600,4.5091577,0.7122121,,,,,,,,,,,,,, -165700,4.8156543,0.7923645,,,,,,,,,,,,,, -165800,4.5341773,0.71724284,,,,,,,,,,,,,, -165900,4.189958,0.6194656,,,,,,,,,,,,,, -166000,4.138402,0.7644287,,,,,,,,,,,,,, -166100,4.697748,0.79185784,,,,,,,,,,,,,, -166200,4.388157,0.6709378,,,,,,,,,,,,,, -166300,4.851523,0.7215663,,,,,,,,,,,,,, -166400,4.021765,0.61029464,,,,,,,,,,,,,, -166500,4.462897,0.66965157,,,,,,,,,,,,,, -166600,4.6736035,0.7526407,,,,,,,,,,,,,, -166700,4.182372,0.6708333,,,,,,,,,,,,,, -166780,,,0.9436383843421936,0.2011850476264953,0.751800000667572,1.0534090995788574,50000.0,0.6290000081062317,1.834994435310364,10000.0,56157.90705728531,58228.787316560745,56157.90705728531,2058.5210251808167,6.337911128997803,0.0 -166800,4.2236495,0.6587407,,,,,,,,,,,,,, -166900,4.448539,0.7194629,,,,,,,,,,,,,, -167000,4.1527104,0.681168,,,,,,,,,,,,,, -167100,4.7882886,0.7361663,,,,,,,,,,,,,, -167200,4.0968184,0.63178533,,,,,,,,,,,,,, -167300,4.1609683,0.68879956,,,,,,,,,,,,,, -167400,4.3682623,0.7100091,,,,,,,,,,,,,, -167500,4.286731,0.6915209,,,,,,,,,,,,,, -167600,4.299709,0.69510645,,,,,,,,,,,,,, -167700,3.9834518,0.6471009,,,,,,,,,,,,,, -167800,4.478327,0.63864994,,,,,,,,,,,,,, -167900,4.2105103,0.7052632,,,,,,,,,,,,,, -168000,4.4661126,0.74044657,,,,,,,,,,,,,, -168100,4.8650556,0.68798894,,,,,,,,,,,,,, -168200,4.6809607,0.72393143,,,,,,,,,,,,,, -168298,,,0.9491389989852904,0.1779803782701492,0.752079963684082,1.0596641302108765,50000.0,0.6277000308036804,1.8354393243789675,10000.0,56668.11461615562,58756.28093838692,56668.11461615562,2075.694316625595,6.394445896148682,0.0 -168300,4.3401556,0.6514183,,,,,,,,,,,,,, -168400,4.8021755,0.7340833,,,,,,,,,,,,,, -168500,4.3297806,0.69367385,,,,,,,,,,,,,, -168600,4.3335896,0.66394645,,,,,,,,,,,,,, -168700,4.7901816,0.5889177,,,,,,,,,,,,,, -168800,4.3000693,0.6949898,,,,,,,,,,,,,, -168900,4.538064,0.7371376,,,,,,,,,,,,,, -169000,4.043064,0.6471887,,,,,,,,,,,,,, -169100,4.357374,0.6212223,,,,,,,,,,,,,, -169200,4.754946,0.7751784,,,,,,,,,,,,,, -169300,4.4868016,0.6596628,,,,,,,,,,,,,, -169400,4.174869,0.70834076,,,,,,,,,,,,,, -169500,4.0019765,0.6559925,,,,,,,,,,,,,, -169600,4.449948,0.681581,,,,,,,,,,,,,, -169700,4.5308547,0.7895525,,,,,,,,,,,,,, -169800,4.542585,0.70600915,,,,,,,,,,,,,, -169815,,,0.950215220451355,0.1770836114883422,0.7531399726867676,1.052578091621399,50000.0,0.6274000406265259,1.8301844596862795,10000.0,57178.27260637283,59283.77220726013,57178.27260637283,2092.9129543304443,6.455153465270996,0.0 -169900,4.6131077,0.6555809,,,,,,,,,,,,,, -170000,4.6558084,0.6869389,,,,,,,,,,,,,, -170100,4.5744896,0.6968243,,,,,,,,,,,,,, -170200,4.5140967,0.67705095,,,,,,,,,,,,,, -170300,4.3156323,0.6551649,,,,,,,,,,,,,, -170400,4.692209,0.7225574,,,,,,,,,,,,,, -170500,4.1113067,0.69497955,,,,,,,,,,,,,, -170600,4.5986605,0.6923828,,,,,,,,,,,,,, -170700,4.5996494,0.6852815,,,,,,,,,,,,,, -170800,4.392227,0.66305876,,,,,,,,,,,,,, -170900,4.712546,0.7371415,,,,,,,,,,,,,, -171000,4.062395,0.64435256,,,,,,,,,,,,,, -171100,4.3087816,0.6329838,,,,,,,,,,,,,, -171200,4.6468773,0.6897555,,,,,,,,,,,,,, -171300,4.409356,0.6529071,,,,,,,,,,,,,, -171333,,,0.9529256820678712,0.1717437803745269,0.7522000074386597,1.053192973136902,50000.0,0.627500057220459,1.8310190439224243,10000.0,57688.4286673069,59811.13905739784,57688.4286673069,2110.01118683815,6.511188507080078,0.0 -171400,4.7460546,0.6813715,,,,,,,,,,,,,, -171500,4.604829,0.70448714,,,,,,,,,,,,,, -171600,4.2305675,0.58897877,,,,,,,,,,,,,, -171700,4.4354453,0.6852348,,,,,,,,,,,,,, -171800,4.5219464,0.6959601,,,,,,,,,,,,,, -171900,4.321881,0.6520505,,,,,,,,,,,,,, -172000,4.6511436,0.6966262,,,,,,,,,,,,,, -172100,4.440347,0.59660614,,,,,,,,,,,,,, -172200,4.5973763,0.63720894,,,,,,,,,,,,,, -172300,4.5706162,0.67216355,,,,,,,,,,,,,, -172400,4.8939795,0.7100487,,,,,,,,,,,,,, -172500,4.2770214,0.6072972,,,,,,,,,,,,,, -172600,4.835946,0.70503736,,,,,,,,,,,,,, -172700,4.5306826,0.7077723,,,,,,,,,,,,,, -172800,4.6132674,0.7066238,,,,,,,,,,,,,, -172850,,,0.9522879123687744,0.1757299751043319,0.7527799606323242,1.0505002737045288,50000.0,0.6271000504493713,1.833840489387512,10000.0,58198.459025383,60338.37391376495,58198.459025383,2127.100840330124,6.570757627487183,0.0 -172900,4.502608,0.6920659,,,,,,,,,,,,,, -173000,4.5301867,0.6551143,,,,,,,,,,,,,, -173100,4.887266,0.7785748,,,,,,,,,,,,,, -173200,4.2324076,0.6397925,,,,,,,,,,,,,, -173300,4.3553433,0.6306074,,,,,,,,,,,,,, -173400,4.5296674,0.6255231,,,,,,,,,,,,,, -173500,4.9691877,0.70191693,,,,,,,,,,,,,, -173600,4.2698913,0.6326979,,,,,,,,,,,,,, -173700,5.057598,0.61694604,,,,,,,,,,,,,, -173800,4.2811713,0.5988293,,,,,,,,,,,,,, -173900,5.195074,0.7614876,,,,,,,,,,,,,, -174000,4.047782,0.6142769,,,,,,,,,,,,,, -174100,4.3028297,0.6206827,,,,,,,,,,,,,, -174200,4.426362,0.60466087,,,,,,,,,,,,,, -174300,4.407124,0.61945677,,,,,,,,,,,,,, -174368,,,0.9563336968421936,0.1650478243827819,0.7552399635314941,1.0452936887741089,50000.0,0.6290000081062317,1.826395988464356,10000.0,58708.66118121147,60865.69796657562,58708.66118121147,2144.1066920757294,6.630687713623047,0.0 -174400,4.2731166,0.6211971,,,,,,,,,,,,,, -174500,4.557989,0.6124234,,,,,,,,,,,,,, -174600,4.484363,0.6939092,,,,,,,,,,,,,, -174700,4.745514,0.744457,,,,,,,,,,,,,, -174800,4.286799,0.630533,,,,,,,,,,,,,, -174900,4.1388593,0.63623995,,,,,,,,,,,,,, -175000,4.2345705,0.680834,,,,,,,,,,,,,, -175100,4.260091,0.6152204,,,,,,,,,,,,,, -175200,4.101163,0.68146783,,,,,,,,,,,,,, -175300,4.6487193,0.65014297,,,,,,,,,,,,,, -175400,4.239544,0.6616831,,,,,,,,,,,,,, -175500,4.240113,0.6465181,,,,,,,,,,,,,, -175600,4.5594764,0.70734537,,,,,,,,,,,,,, -175700,4.7113075,0.71322715,,,,,,,,,,,,,, -175800,4.405068,0.67322123,,,,,,,,,,,,,, -175886,,,0.9587850570678712,0.1511284112930297,0.7549600005149841,1.0456360578536987,50000.0,0.631600022315979,1.824371576309204,10000.0,59218.857328653336,61393.11685776711,59218.857328653336,2161.2147500514984,6.689452409744263,0.0 -175900,4.7660685,0.7104997,,,,,,,,,,,,,, -176000,4.757515,0.7074371,,,,,,,,,,,,,, -176100,4.858423,0.63683087,,,,,,,,,,,,,, -176200,4.4036694,0.5848618,,,,,,,,,,,,,, -176300,4.0667367,0.65520537,,,,,,,,,,,,,, -176400,4.8953123,0.6631379,,,,,,,,,,,,,, -176500,4.50811,0.67487264,,,,,,,,,,,,,, -176600,4.852586,0.6185458,,,,,,,,,,,,,, -176700,4.5779915,0.6649206,,,,,,,,,,,,,, -176800,4.375728,0.67553335,,,,,,,,,,,,,, -176900,4.5854864,0.7598066,,,,,,,,,,,,,, -177000,4.2162566,0.62282264,,,,,,,,,,,,,, -177100,4.150232,0.5859125,,,,,,,,,,,,,, -177200,4.864553,0.6448672,,,,,,,,,,,,,, -177300,4.4504023,0.6211497,,,,,,,,,,,,,, -177400,4.873786,0.68436265,,,,,,,,,,,,,, -177404,,,0.9591039419174194,0.1507154107093811,0.7559399604797363,1.0450918674468994,50000.0,0.6291000247001648,1.8281220197677608,10000.0,59728.900113105774,61920.21709442139,59728.900113105774,2178.15860247612,6.749515771865845,0.0 -177500,4.8605766,0.56377304,,,,,,,,,,,,,, -177600,4.589565,0.6647849,,,,,,,,,,,,,, -177700,4.4517994,0.60570145,,,,,,,,,,,,,, -177800,4.2347784,0.67353135,,,,,,,,,,,,,, -177900,4.2940845,0.6506055,,,,,,,,,,,,,, -178000,4.7209015,0.6439704,,,,,,,,,,,,,, -178100,4.3396854,0.56093043,,,,,,,,,,,,,, -178200,4.52435,0.6836329,,,,,,,,,,,,,, -178300,4.375125,0.6179328,,,,,,,,,,,,,, -178400,4.3189955,0.63736975,,,,,,,,,,,,,, -178500,5.0539856,0.6412232,,,,,,,,,,,,,, -178600,4.2650223,0.6784071,,,,,,,,,,,,,, -178700,4.576754,0.61714673,,,,,,,,,,,,,, -178800,4.8960085,0.6839867,,,,,,,,,,,,,, -178900,4.035432,0.5842989,,,,,,,,,,,,,, -178918,,,0.959004282951355,0.1506407111883163,0.756659984588623,1.040774703025818,50000.0,0.629300057888031,1.8209997415542605,10000.0,60238.09110379219,62447.85934686661,60238.09110379219,2195.6819846630096,7.62337064743042,0.0 -179000,4.425363,0.65802455,,,,,,,,,,,,,, -179100,4.643955,0.60157114,,,,,,,,,,,,,, -179200,4.2009573,0.57842666,,,,,,,,,,,,,, -179300,4.0924406,0.5513058,,,,,,,,,,,,,, -179400,4.8301835,0.63252527,,,,,,,,,,,,,, -179500,4.3730073,0.6294364,,,,,,,,,,,,,, -179600,4.731067,0.6289269,,,,,,,,,,,,,, -179700,4.3736973,0.63945794,,,,,,,,,,,,,, -179800,5.110949,0.74479926,,,,,,,,,,,,,, -179900,4.6753135,0.7084501,,,,,,,,,,,,,, -180000,4.7121243,0.63312393,,,,,,,,,,,,,, -180100,4.175458,0.5891641,,,,,,,,,,,,,, -180200,4.6415014,0.59964395,,,,,,,,,,,,,, -180300,4.3649445,0.5744001,,,,,,,,,,,,,, -180400,4.8514576,0.60969406,,,,,,,,,,,,,, -180435,,,0.9593430757522584,0.148685485124588,0.7560799717903137,1.0443583726882937,50000.0,0.6309000253677368,1.8259391784667969,10000.0,60748.00309252739,62975.11134815216,60748.00309252739,2212.891495943069,7.698556661605835,0.0 -180500,4.1985073,0.6588484,,,,,,,,,,,,,, -180600,4.3841596,0.61837465,,,,,,,,,,,,,, -180700,5.0073266,0.6350583,,,,,,,,,,,,,, -180800,4.83736,0.6681674,,,,,,,,,,,,,, -180900,4.6893864,0.67888343,,,,,,,,,,,,,, -181000,4.734248,0.6930294,,,,,,,,,,,,,, -181100,4.8015227,0.65209913,,,,,,,,,,,,,, -181200,3.7865443,0.55481136,,,,,,,,,,,,,, -181300,4.698577,0.65777045,,,,,,,,,,,,,, -181400,4.5093923,0.57420224,,,,,,,,,,,,,, -181500,4.2754474,0.5898079,,,,,,,,,,,,,, -181600,5.106584,0.66722345,,,,,,,,,,,,,, -181700,4.5435176,0.59604734,,,,,,,,,,,,,, -181800,4.499985,0.63774735,,,,,,,,,,,,,, -181900,4.7144566,0.61630523,,,,,,,,,,,,,, -181954,,,0.959203600883484,0.1488100439310073,0.7558799982070923,1.0424973964691162,50000.0,0.6309000253677368,1.8252125978469849,10000.0,61258.025297403336,63502.30787181854,61258.025297403336,2229.9518847465515,7.7558934688568115,0.0 -182000,4.3304796,0.55733067,,,,,,,,,,,,,, -182100,4.525339,0.61819893,,,,,,,,,,,,,, -182200,5.0236735,0.6759865,,,,,,,,,,,,,, -182300,4.662822,0.60986555,,,,,,,,,,,,,, -182400,4.870743,0.6561618,,,,,,,,,,,,,, -182500,4.62895,0.6385289,,,,,,,,,,,,,, -182600,4.478234,0.63919103,,,,,,,,,,,,,, -182700,4.220995,0.57250834,,,,,,,,,,,,,, -182800,4.9713984,0.7119671,,,,,,,,,,,,,, -182900,4.375274,0.5518862,,,,,,,,,,,,,, -183000,4.546805,0.65476507,,,,,,,,,,,,,, -183100,4.490529,0.5774345,,,,,,,,,,,,,, -183200,4.140155,0.59621155,,,,,,,,,,,,,, -183300,4.636961,0.60264033,,,,,,,,,,,,,, -183400,4.5432057,0.6054939,,,,,,,,,,,,,, -183470,,,0.9610969424247742,0.1481508761644363,0.7571799755096436,1.0403326749801636,50000.0,0.6320000290870667,1.821579337120056,10000.0,61767.9468460083,64029.39771032333,61767.9468460083,2247.0036346912384,7.818108081817627,0.0 -183500,4.1855674,0.61467946,,,,,,,,,,,,,, -183600,4.884379,0.6715783,,,,,,,,,,,,,, -183700,4.3753424,0.59620553,,,,,,,,,,,,,, -183800,4.6982884,0.6338199,,,,,,,,,,,,,, -183900,4.5138845,0.63400376,,,,,,,,,,,,,, -184000,4.434213,0.6273621,,,,,,,,,,,,,, -184100,4.315784,0.6349995,,,,,,,,,,,,,, -184200,5.369154,0.6698668,,,,,,,,,,,,,, -184300,4.7553763,0.6345991,,,,,,,,,,,,,, -184400,4.512571,0.65574193,,,,,,,,,,,,,, -184500,4.7348742,0.67964756,,,,,,,,,,,,,, -184600,4.681228,0.6542988,,,,,,,,,,,,,, -184700,4.3611465,0.6405308,,,,,,,,,,,,,, -184800,4.4664745,0.69896895,,,,,,,,,,,,,, -184900,4.3096657,0.5982201,,,,,,,,,,,,,, -184987,,,0.959781527519226,0.1489413529634475,0.7575399875640869,1.041106343269348,50000.0,0.6317000389099121,1.8240126371383667,10000.0,62278.05148458481,64556.78928041458,62278.05148458481,2264.17701625824,7.874583959579468,0.0 -185000,4.4210296,0.57993793,,,,,,,,,,,,,, -185100,4.206313,0.5985184,,,,,,,,,,,,,, -185200,4.322052,0.6400463,,,,,,,,,,,,,, -185300,4.163148,0.6319337,,,,,,,,,,,,,, -185400,4.685605,0.6871531,,,,,,,,,,,,,, -185500,3.932731,0.59971154,,,,,,,,,,,,,, -185600,4.367129,0.5849243,,,,,,,,,,,,,, -185700,4.1691656,0.59042966,,,,,,,,,,,,,, -185800,4.383889,0.630026,,,,,,,,,,,,,, -185900,4.4400887,0.692337,,,,,,,,,,,,,, -186000,5.064493,0.66471493,,,,,,,,,,,,,, -186100,4.234847,0.57429546,,,,,,,,,,,,,, -186200,4.5054383,0.6882973,,,,,,,,,,,,,, -186300,4.16887,0.5906083,,,,,,,,,,,,,, -186400,4.8683825,0.6406827,,,,,,,,,,,,,, -186500,4.6097655,0.6380262,,,,,,,,,,,,,, -186503,,,0.9605189561843872,0.145164668560028,0.7573599815368652,1.0415334701538086,50000.0,0.6320000290870667,1.8241097927093504,10000.0,62787.94189047813,65083.8095471859,62787.94189047813,2281.1940383911133,7.932438135147095,0.0 -186600,4.6663284,0.67199874,,,,,,,,,,,,,, -186700,4.199198,0.6076721,,,,,,,,,,,,,, -186800,4.528817,0.5608846,,,,,,,,,,,,,, -186900,4.031923,0.5614424,,,,,,,,,,,,,, -187000,4.789816,0.6834652,,,,,,,,,,,,,, -187100,4.27069,0.6375364,,,,,,,,,,,,,, -187200,4.9705005,0.65238744,,,,,,,,,,,,,, -187300,4.2939425,0.60685736,,,,,,,,,,,,,, -187400,4.4715824,0.6128049,,,,,,,,,,,,,, -187500,4.3854313,0.61363506,,,,,,,,,,,,,, -187600,4.579204,0.5934206,,,,,,,,,,,,,, -187700,4.7223787,0.60021067,,,,,,,,,,,,,, -187800,4.410288,0.58030474,,,,,,,,,,,,,, -187900,4.27362,0.6390951,,,,,,,,,,,,,, -188000,4.212713,0.584415,,,,,,,,,,,,,, -188020,,,0.9613759517669678,0.1438952833414077,0.7572599649429321,1.0408782958984375,50000.0,0.6317000389099121,1.8217973709106443,10000.0,63297.95154929161,65611.39433121681,63297.95154929161,2298.656086206436,7.991012096405029,0.0 -188100,4.879216,0.6458501,,,,,,,,,,,,,, -188200,4.618968,0.64796597,,,,,,,,,,,,,, -188300,4.7018914,0.6491405,,,,,,,,,,,,,, -188400,4.5305405,0.6446781,,,,,,,,,,,,,, -188500,4.8016577,0.6269203,,,,,,,,,,,,,, -188600,4.7806206,0.5955901,,,,,,,,,,,,,, -188700,4.541136,0.65210855,,,,,,,,,,,,,, -188800,4.624692,0.72659403,,,,,,,,,,,,,, -188900,4.5070786,0.62814337,,,,,,,,,,,,,, -189000,4.3013315,0.5733761,,,,,,,,,,,,,, -189100,4.2436934,0.59247625,,,,,,,,,,,,,, -189200,4.693028,0.66776973,,,,,,,,,,,,,, -189300,4.93634,0.6941242,,,,,,,,,,,,,, -189400,4.4911704,0.60142756,,,,,,,,,,,,,, -189500,5.008755,0.69638616,,,,,,,,,,,,,, -189537,,,0.962332546710968,0.142117902636528,0.757420003414154,1.0418941974639893,50000.0,0.631100058555603,1.8226947784423828,10000.0,63808.01502132416,66138.68200182915,63808.01502132416,2315.777137756348,8.041492938995361,0.0 -189600,4.26082,0.55628467,,,,,,,,,,,,,, -189700,4.443647,0.6333291,,,,,,,,,,,,,, -189800,4.4536595,0.5899526,,,,,,,,,,,,,, -189900,4.9096847,0.60986006,,,,,,,,,,,,,, -190000,4.558693,0.65151906,,,,,,,,,,,,,, -190100,4.2736998,0.62343913,,,,,,,,,,,,,, -190200,4.194493,0.57609546,,,,,,,,,,,,,, -190300,4.9309673,0.6464292,,,,,,,,,,,,,, -190400,4.503646,0.6213326,,,,,,,,,,,,,, -190500,4.4052215,0.6454773,,,,,,,,,,,,,, -190600,4.539397,0.6011065,,,,,,,,,,,,,, -190700,4.1967235,0.58412445,,,,,,,,,,,,,, -190800,4.3178916,0.6690694,,,,,,,,,,,,,, -190900,4.1490374,0.5251558,,,,,,,,,,,,,, -191000,4.5033946,0.6465806,,,,,,,,,,,,,, -191055,,,0.961734652519226,0.1454230993986129,0.7573399543762207,1.0415695905685425,50000.0,0.6321000456809998,1.822113394737244,10000.0,64317.99254322052,66665.94344830513,64317.99254322052,2332.939744949341,8.108099699020386,0.0 -191100,3.892423,0.5611868,,,,,,,,,,,,,, -191200,4.5022063,0.6776198,,,,,,,,,,,,,, -191300,4.4969754,0.6568963,,,,,,,,,,,,,, -191400,4.8698454,0.65328056,,,,,,,,,,,,,, -191500,4.7975636,0.6042254,,,,,,,,,,,,,, -191600,4.3732557,0.5767728,,,,,,,,,,,,,, -191700,4.3381796,0.6302128,,,,,,,,,,,,,, -191800,5.2706423,0.6417671,,,,,,,,,,,,,, -191900,4.610606,0.6080032,,,,,,,,,,,,,, -192000,4.3276134,0.5261687,,,,,,,,,,,,,, -192100,4.2682643,0.63333726,,,,,,,,,,,,,, -192200,4.728183,0.5795551,,,,,,,,,,,,,, -192300,3.9215195,0.57930076,,,,,,,,,,,,,, -192400,4.554726,0.6691797,,,,,,,,,,,,,, -192500,4.6622624,0.59482133,,,,,,,,,,,,,, -192572,,,0.959741711616516,0.1477141678333282,0.757319986820221,1.040766954421997,50000.0,0.6324000358581543,1.8228449821472168,10000.0,64828.0078830719,67193.97899508476,64828.0078830719,2350.846947908401,8.166432857513428,0.0 -192600,4.6783476,0.67895454,,,,,,,,,,,,,, -192700,4.5514193,0.59314996,,,,,,,,,,,,,, -192800,4.7547917,0.61472034,,,,,,,,,,,,,, -192900,4.747625,0.63749194,,,,,,,,,,,,,, -193000,4.397666,0.5950693,,,,,,,,,,,,,, -193100,3.906578,0.54820764,,,,,,,,,,,,,, -193200,4.875696,0.72236615,,,,,,,,,,,,,, -193300,4.5975137,0.6661197,,,,,,,,,,,,,, -193400,4.4194436,0.59562504,,,,,,,,,,,,,, -193500,4.6671205,0.6279517,,,,,,,,,,,,,, -193600,4.7790737,0.6632573,,,,,,,,,,,,,, -193700,4.1736574,0.6089387,,,,,,,,,,,,,, -193800,4.50408,0.6002544,,,,,,,,,,,,,, -193900,4.129501,0.6192486,,,,,,,,,,,,,, -194000,4.5874906,0.6377827,,,,,,,,,,,,,, -194089,,,0.9594626426696776,0.1482044607400894,0.7571600079536438,1.0418704748153689,50000.0,0.6312000155448914,1.823163986206055,10000.0,65338.185666799545,67721.27209353447,65338.185666799545,2367.845665693283,8.226392269134521,0.0 -194100,4.7001033,0.6507904,,,,,,,,,,,,,, -194200,4.2366414,0.6251653,,,,,,,,,,,,,, -194300,4.4657626,0.6034068,,,,,,,,,,,,,, -194400,4.862202,0.6724012,,,,,,,,,,,,,, -194500,4.35269,0.61835486,,,,,,,,,,,,,, -194600,4.8412004,0.5894598,,,,,,,,,,,,,, -194700,5.066756,0.76009244,,,,,,,,,,,,,, -194800,4.5321326,0.6405469,,,,,,,,,,,,,, -194900,4.722147,0.6618853,,,,,,,,,,,,,, -195000,4.6844263,0.6753896,,,,,,,,,,,,,, -195100,4.0926676,0.560301,,,,,,,,,,,,,, -195200,4.450497,0.59086174,,,,,,,,,,,,,, -195300,4.6033387,0.6273155,,,,,,,,,,,,,, -195400,4.654613,0.64101565,,,,,,,,,,,,,, -195500,5.1557913,0.6290318,,,,,,,,,,,,,, -195600,4.3702574,0.5634977,,,,,,,,,,,,,, -195606,,,0.9624919891357422,0.142418086528778,0.7576199769973755,1.0406326055526731,50000.0,0.6314000487327576,1.822136402130127,10000.0,65848.35327005386,68248.72687864304,65848.35327005386,2385.0194478034973,8.285529613494873,0.0 -195700,4.7031875,0.6765279,,,,,,,,,,,,,, -195800,4.335364,0.66240627,,,,,,,,,,,,,, -195900,4.6318192,0.6437136,,,,,,,,,,,,,, -196000,4.213555,0.6261779,,,,,,,,,,,,,, -196100,4.2052665,0.5962631,,,,,,,,,,,,,, -196200,4.744263,0.63981354,,,,,,,,,,,,,, -196300,4.3497896,0.6061397,,,,,,,,,,,,,, -196400,4.2027445,0.5718003,,,,,,,,,,,,,, -196500,4.9116993,0.6689329,,,,,,,,,,,,,, -196600,4.848743,0.6548104,,,,,,,,,,,,,, -196700,5.047309,0.641667,,,,,,,,,,,,,, -196800,4.7925553,0.5908514,,,,,,,,,,,,,, -196900,4.95547,0.58139944,,,,,,,,,,,,,, -197000,4.624053,0.6295532,,,,,,,,,,,,,, -197100,4.539216,0.6595527,,,,,,,,,,,,,, -197123,,,0.9610969424247742,0.1462296843528747,0.7571399807929993,1.039520263671875,50000.0,0.6328000426292419,1.821480393409729,10000.0,66358.55031299591,68776.01489758492,66358.55031299591,2401.994375705719,8.346335649490356,0.0 -197200,4.299297,0.6082067,,,,,,,,,,,,,, -197300,4.4712267,0.63594866,,,,,,,,,,,,,, -197400,4.0403724,0.57241523,,,,,,,,,,,,,, -197500,4.3568697,0.6322708,,,,,,,,,,,,,, -197600,4.9799714,0.65065706,,,,,,,,,,,,,, -197700,4.434179,0.60151356,,,,,,,,,,,,,, -197800,4.480258,0.62894523,,,,,,,,,,,,,, -197900,4.791383,0.6343419,,,,,,,,,,,,,, -198000,4.323898,0.5948052,,,,,,,,,,,,,, -198100,4.659166,0.58929914,,,,,,,,,,,,,, -198200,4.5217013,0.6169079,,,,,,,,,,,,,, -198300,4.4188213,0.66325665,,,,,,,,,,,,,, -198400,4.4589133,0.64216304,,,,,,,,,,,,,, -198500,4.337888,0.63163173,,,,,,,,,,,,,, -198600,4.8213334,0.64341223,,,,,,,,,,,,,, -198640,,,0.9618343114852904,0.1448915749788284,0.7569999694824219,1.041074275970459,50000.0,0.6317000389099121,1.8206541538238523,10000.0,66868.47186207771,69303.13593435287,66868.47186207771,2419.071531057358,8.412399053573608,0.0 -198700,4.9463496,0.68065464,,,,,,,,,,,,,, -198800,4.686394,0.5912543,,,,,,,,,,,,,, -198900,4.528978,0.7113125,,,,,,,,,,,,,, -199000,4.518031,0.61584175,,,,,,,,,,,,,, -199100,4.6193333,0.6010013,,,,,,,,,,,,,, -199200,4.7820425,0.6588197,,,,,,,,,,,,,, -199300,4.6458235,0.6837255,,,,,,,,,,,,,, -199400,5.0879345,0.6273448,,,,,,,,,,,,,, -199500,4.5435095,0.6062158,,,,,,,,,,,,,, -199600,4.6636963,0.68351245,,,,,,,,,,,,,, -199700,4.510126,0.62626207,,,,,,,,,,,,,, -199800,4.282905,0.60899025,,,,,,,,,,,,,, -199900,4.1212444,0.54870445,,,,,,,,,,,,,, -200000,4.5793743,0.6377307,,,,,,,,,,,,,, -200100,4.5322876,0.682896,,,,,,,,,,,,,, -200156,,,0.960180163383484,0.1485601663589477,0.7572199702262878,1.0412520170211792,50000.0,0.6319000124931335,1.8226794004440308,10000.0,67378.39338731766,69830.4341545105,67378.39338731766,2436.330323457718,8.475186109542847,0.0 -200200,4.4007235,0.58621156,,,,,,,,,,,,,, -200300,4.5576973,0.5874006,,,,,,,,,,,,,, -200400,4.112826,0.5313238,,,,,,,,,,,,,, -200500,3.9635637,0.5304555,,,,,,,,,,,,,, -200600,4.175587,0.52325666,,,,,,,,,,,,,, -200700,4.93319,0.683796,,,,,,,,,,,,,, -200800,4.8577137,0.6554727,,,,,,,,,,,,,, -200900,4.77677,0.6580844,,,,,,,,,,,,,, -201000,5.091707,0.6493957,,,,,,,,,,,,,, -201100,4.6738625,0.6243747,,,,,,,,,,,,,, -201200,4.9004807,0.64559853,,,,,,,,,,,,,, -201300,4.697363,0.6505157,,,,,,,,,,,,,, -201400,4.0409584,0.55032396,,,,,,,,,,,,,, -201500,4.5142035,0.59433365,,,,,,,,,,,,,, -201600,4.0762467,0.60179114,,,,,,,,,,,,,, -201673,,,0.9604591727256776,0.1474022418260574,0.7573800086975098,1.0413967370986938,50000.0,0.6309000253677368,1.823739767074585,10000.0,67888.32476067543,70357.69950819016,67888.32476067543,2453.5483391284943,8.535441637039185,0.0 -201700,4.282078,0.5627285,,,,,,,,,,,,,, -201800,4.479853,0.6284154,,,,,,,,,,,,,, -201900,4.62465,0.6635197,,,,,,,,,,,,,, -202000,4.895102,0.63271075,,,,,,,,,,,,,, -202100,4.193146,0.54389274,,,,,,,,,,,,,, -202200,4.426467,0.67057616,,,,,,,,,,,,,, -202300,4.669635,0.6248603,,,,,,,,,,,,,, -202400,4.224006,0.628241,,,,,,,,,,,,,, -202500,5.008212,0.5579015,,,,,,,,,,,,,, -202600,4.3953,0.6736574,,,,,,,,,,,,,, -202700,4.7546043,0.62101394,,,,,,,,,,,,,, -202800,4.609079,0.6332668,,,,,,,,,,,,,, -202900,4.5522175,0.647683,,,,,,,,,,,,,, -203000,4.3988056,0.6281798,,,,,,,,,,,,,, -203100,4.400607,0.65846014,,,,,,,,,,,,,, -203191,,,0.9622727632522584,0.1433053612709045,0.7568999528884888,1.0419265031814575,50000.0,0.631100058555603,1.822848916053772,10000.0,68398.4266242981,70885.42971634865,68398.4266242981,2471.0616297721863,8.596054077148438,0.0 -203200,4.735774,0.677492,,,,,,,,,,,,,, -203300,4.310593,0.6151114,,,,,,,,,,,,,, -203400,4.4098887,0.6191059,,,,,,,,,,,,,, -203500,4.974442,0.6405257,,,,,,,,,,,,,, -203600,4.5519633,0.61860645,,,,,,,,,,,,,, -203700,4.424906,0.5914094,,,,,,,,,,,,,, -203800,4.804458,0.5957984,,,,,,,,,,,,,, -203900,4.425355,0.56698114,,,,,,,,,,,,,, -204000,4.804056,0.65835744,,,,,,,,,,,,,, -204100,4.5608816,0.6432477,,,,,,,,,,,,,, -204200,4.454297,0.63998073,,,,,,,,,,,,,, -204300,4.3739133,0.68108577,,,,,,,,,,,,,, -204400,4.559651,0.63414055,,,,,,,,,,,,,, -204500,4.788189,0.6428062,,,,,,,,,,,,,, -204600,4.45365,0.58901286,,,,,,,,,,,,,, -204700,4.871194,0.6578305,,,,,,,,,,,,,, -204708,,,0.9606186151504515,0.1448125839233398,0.7568399906158447,1.0413434505462646,50000.0,0.6314000487327576,1.8237817287445068,10000.0,68908.5199649334,71412.57393550873,68908.5199649334,2487.9953002929688,8.656528234481812,0.0 -204800,4.215668,0.60531235,,,,,,,,,,,,,, -204900,4.753578,0.6646774,,,,,,,,,,,,,, -205000,4.4754744,0.6160806,,,,,,,,,,,,,, -205100,4.4602294,0.595085,,,,,,,,,,,,,, -205200,4.400261,0.5773034,,,,,,,,,,,,,, -205300,4.286229,0.6431641,,,,,,,,,,,,,, -205400,4.3652186,0.5964342,,,,,,,,,,,,,, -205500,4.549557,0.6349461,,,,,,,,,,,,,, -205600,4.2159,0.5497685,,,,,,,,,,,,,, -205700,4.6629515,0.53587985,,,,,,,,,,,,,, -205800,4.621572,0.62239474,,,,,,,,,,,,,, -205900,4.912394,0.65203995,,,,,,,,,,,,,, -206000,4.734636,0.59247553,,,,,,,,,,,,,, -206100,3.9802842,0.5875531,,,,,,,,,,,,,, -206200,4.192662,0.5394573,,,,,,,,,,,,,, -206225,,,0.9596420526504515,0.1503311842679977,0.7572799921035767,1.041097640991211,50000.0,0.6313000321388245,1.822572112083435,10000.0,69418.53223371506,71939.65333914757,69418.53223371506,2504.9435436725616,8.719544410705566,0.0 -206300,4.2891216,0.5665536,,,,,,,,,,,,,, -206400,4.288695,0.6256356,,,,,,,,,,,,,, -206500,5.0955453,0.66180867,,,,,,,,,,,,,, -206600,4.7851377,0.6181624,,,,,,,,,,,,,, -206700,4.3229823,0.5934951,,,,,,,,,,,,,, -206800,4.2951555,0.6222966,,,,,,,,,,,,,, -206900,4.6947675,0.64266324,,,,,,,,,,,,,, -207000,4.5770273,0.65525556,,,,,,,,,,,,,, -207100,4.148241,0.5381098,,,,,,,,,,,,,, -207200,5.2489147,0.6703303,,,,,,,,,,,,,, -207300,4.783016,0.62427425,,,,,,,,,,,,,, -207400,4.1579103,0.5718685,,,,,,,,,,,,,, -207500,4.235098,0.57689613,,,,,,,,,,,,,, -207600,4.0122094,0.5696584,,,,,,,,,,,,,, -207700,4.319887,0.5751227,,,,,,,,,,,,,, -207743,,,0.9592633843421936,0.1481289863586425,0.7573599815368652,1.0401948690414429,50000.0,0.6319000124931335,1.8207805156707764,10000.0,69928.57495713234,72466.876288414,69928.57495713234,2522.005439043045,8.781482458114624,0.0 -207800,4.8758287,0.6969498,,,,,,,,,,,,,, -207900,4.24181,0.5722173,,,,,,,,,,,,,, -208000,4.4423323,0.6435789,,,,,,,,,,,,,, -208100,4.8516326,0.72463024,,,,,,,,,,,,,, -208200,4.4439626,0.60277,,,,,,,,,,,,,, -208300,3.9837294,0.61224735,,,,,,,,,,,,,, -208400,4.308847,0.59805167,,,,,,,,,,,,,, -208500,4.3214827,0.6315831,,,,,,,,,,,,,, -208600,4.3528447,0.5824547,,,,,,,,,,,,,, -208700,4.476519,0.61874604,,,,,,,,,,,,,, -208800,4.2618704,0.63168097,,,,,,,,,,,,,, -208900,4.6538014,0.5590719,,,,,,,,,,,,,, -209000,4.223592,0.54332525,,,,,,,,,,,,,, -209100,3.9885273,0.62070096,,,,,,,,,,,,,, -209200,4.435013,0.69446874,,,,,,,,,,,,,, -209260,,,0.9608577489852904,0.1483916938304901,0.7569199800491333,1.0417641401290894,50000.0,0.6304000020027161,1.822595238685608,10000.0,70438.65213561058,72994.25761318207,70438.65213561058,2539.191326379776,8.84352707862854,0.0 -209300,4.120601,0.5917009,,,,,,,,,,,,,, -209400,4.351805,0.6482829,,,,,,,,,,,,,, -209500,4.5519257,0.61303866,,,,,,,,,,,,,, -209600,4.3534775,0.61336416,,,,,,,,,,,,,, -209700,4.554055,0.636921,,,,,,,,,,,,,, -209800,4.926812,0.66997576,,,,,,,,,,,,,, -209900,4.419607,0.61872435,,,,,,,,,,,,,, -210000,4.4327607,0.63342744,,,,,,,,,,,,,, -210100,4.1486435,0.5841863,,,,,,,,,,,,,, -210200,4.678026,0.65063936,,,,,,,,,,,,,, -210300,4.4769573,0.65134066,,,,,,,,,,,,,, -210400,4.7144094,0.59706557,,,,,,,,,,,,,, -210500,4.631381,0.5802082,,,,,,,,,,,,,, -210600,5.133986,0.60471,,,,,,,,,,,,,, -210700,4.8172765,0.70452964,,,,,,,,,,,,,, -210777,,,0.9615951776504515,0.1449353992938995,0.7570399641990662,1.041481375694275,50000.0,0.6321000456809998,1.8224416971206665,10000.0,70948.76120257378,73521.68989634514,70948.76120257378,2556.391278028488,8.910342454910278,0.0 -210800,4.290487,0.6496357,,,,,,,,,,,,,, -210900,4.570292,0.62053144,,,,,,,,,,,,,, -211000,4.312328,0.53524566,,,,,,,,,,,,,, -211100,4.5714226,0.6197061,,,,,,,,,,,,,, -211200,4.6391134,0.6626294,,,,,,,,,,,,,, -211300,4.2802496,0.6214503,,,,,,,,,,,,,, -211400,4.1601954,0.60454243,,,,,,,,,,,,,, -211500,4.328743,0.6167284,,,,,,,,,,,,,, -211600,4.5225744,0.61919034,,,,,,,,,,,,,, -211700,4.5959873,0.62213624,,,,,,,,,,,,,, -211800,4.597571,0.6077602,,,,,,,,,,,,,, -211900,4.4464116,0.59761727,,,,,,,,,,,,,, -212000,4.7653613,0.6810272,,,,,,,,,,,,,, -212100,4.5873055,0.63229656,,,,,,,,,,,,,, -212200,4.4990644,0.6351715,,,,,,,,,,,,,, -212295,,,0.9602199792861938,0.1497755348682403,0.7570199966430664,1.0418096780776978,50000.0,0.6315000057220459,1.823248147964477,10000.0,71458.87047481537,74049.12346696854,71458.87047481537,2573.594845056534,8.976428747177124,0.0 -212300,4.2948895,0.6304029,,,,,,,,,,,,,, -212400,4.5225525,0.60591054,,,,,,,,,,,,,, -212500,4.7431545,0.56768924,,,,,,,,,,,,,, -212600,4.4142466,0.6260162,,,,,,,,,,,,,, -212700,4.5072255,0.67984885,,,,,,,,,,,,,, -212800,4.2061896,0.590823,,,,,,,,,,,,,, -212900,4.4391117,0.6735979,,,,,,,,,,,,,, -213000,4.5920587,0.5892823,,,,,,,,,,,,,, -213100,4.50111,0.6545587,,,,,,,,,,,,,, -213200,4.359308,0.54968995,,,,,,,,,,,,,, -213300,4.454893,0.6304834,,,,,,,,,,,,,, -213400,4.7105784,0.6989598,,,,,,,,,,,,,, -213500,4.4726744,0.63815933,,,,,,,,,,,,,, -213600,4.6517386,0.6191428,,,,,,,,,,,,,, -213700,4.1714087,0.5768015,,,,,,,,,,,,,, -213800,4.726386,0.6145535,,,,,,,,,,,,,, -213812,,,0.9613958597183228,0.1449228227138519,0.7572999596595764,1.0410807132720947,50000.0,0.6319000124931335,1.823547601699829,10000.0,71968.89846563339,74576.27930808067,71968.89846563339,2590.599480867386,9.042627096176147,0.0 -213900,4.896106,0.5944199,,,,,,,,,,,,,, -214000,4.3393917,0.6912795,,,,,,,,,,,,,, -214100,4.507819,0.6506731,,,,,,,,,,,,,, -214200,4.732147,0.6513973,,,,,,,,,,,,,, -214300,4.2261724,0.6292443,,,,,,,,,,,,,, -214400,4.763628,0.71543026,,,,,,,,,,,,,, -214500,4.5037394,0.6027287,,,,,,,,,,,,,, -214600,4.7535777,0.6396347,,,,,,,,,,,,,, -214700,4.477721,0.6244873,,,,,,,,,,,,,, -214800,4.2810087,0.5917525,,,,,,,,,,,,,, -214900,5.076779,0.64855534,,,,,,,,,,,,,, -215000,4.7776055,0.6358886,,,,,,,,,,,,,, -215100,4.482683,0.62147415,,,,,,,,,,,,,, -215200,4.745725,0.6489587,,,,,,,,,,,,,, -215300,4.337218,0.6468245,,,,,,,,,,,,,, -215329,,,0.9591438174247742,0.1498548686504364,0.7572999596595764,1.0407342910766602,50000.0,0.631600022315979,1.822715163230896,10000.0,72479.10156702995,75103.66161513329,72479.10156702995,2607.659962415695,9.106708526611328,0.0 -215400,4.4357123,0.6058836,,,,,,,,,,,,,, -215500,4.4412394,0.6436901,,,,,,,,,,,,,, -215600,4.6921945,0.61959964,,,,,,,,,,,,,, -215700,4.4035783,0.6100963,,,,,,,,,,,,,, -215800,4.974468,0.69487715,,,,,,,,,,,,,, -215900,4.306151,0.61372197,,,,,,,,,,,,,, -216000,4.696133,0.6200445,,,,,,,,,,,,,, -216100,4.3686767,0.54148257,,,,,,,,,,,,,, -216200,4.30338,0.5655307,,,,,,,,,,,,,, -216300,4.1084843,0.5754133,,,,,,,,,,,,,, -216400,4.262724,0.5912658,,,,,,,,,,,,,, -216500,4.9484344,0.6541352,,,,,,,,,,,,,, -216600,4.468025,0.6722303,,,,,,,,,,,,,, -216700,4.306461,0.62692124,,,,,,,,,,,,,, -216800,4.6674504,0.61584425,,,,,,,,,,,,,, -216845,,,0.9608777165412904,0.1457101702690124,0.7572999596595764,1.041617512702942,50000.0,0.6318000555038452,1.823696732521057,10000.0,72989.02272677422,75631.14029312134,72989.02272677422,2625.0996906757355,9.168915510177612,0.0 -216900,4.5963216,0.672019,,,,,,,,,,,,,, -217000,4.4256196,0.6632449,,,,,,,,,,,,,, -217100,4.5926666,0.670515,,,,,,,,,,,,,, -217200,4.5639496,0.6372778,,,,,,,,,,,,,, -217300,4.8025923,0.64647985,,,,,,,,,,,,,, -217400,4.680748,0.67705864,,,,,,,,,,,,,, -217500,4.4475374,0.6178762,,,,,,,,,,,,,, -217600,4.0615497,0.5650183,,,,,,,,,,,,,, -217700,3.8934567,0.5772925,,,,,,,,,,,,,, -217800,4.2706857,0.6811073,,,,,,,,,,,,,, -217900,4.3772964,0.610158,,,,,,,,,,,,,, -218000,4.201943,0.63042396,,,,,,,,,,,,,, -218100,4.269636,0.5564247,,,,,,,,,,,,,, -218200,4.5096965,0.61666024,,,,,,,,,,,,,, -218300,4.7831893,0.6472691,,,,,,,,,,,,,, -218361,,,0.960359513759613,0.1462864428758621,0.7573399543762207,1.04054057598114,50000.0,0.631600022315979,1.821586012840271,10000.0,73498.98733758926,76158.17280721664,73498.98733758926,2642.055506706238,9.225011825561523,0.0 -218400,4.31543,0.58587897,,,,,,,,,,,,,, -218500,4.340555,0.642295,,,,,,,,,,,,,, -218600,4.72107,0.64996624,,,,,,,,,,,,,, -218700,4.605782,0.65462345,,,,,,,,,,,,,, -218800,4.5435386,0.6734615,,,,,,,,,,,,,, -218900,4.4411707,0.66301167,,,,,,,,,,,,,, -219000,4.86066,0.649474,,,,,,,,,,,,,, -219100,4.8927283,0.68405336,,,,,,,,,,,,,, -219200,4.870258,0.711846,,,,,,,,,,,,,, -219300,4.4931383,0.6130719,,,,,,,,,,,,,, -219400,4.402995,0.6299382,,,,,,,,,,,,,, -219500,4.8342605,0.58881915,,,,,,,,,,,,,, -219600,4.4639454,0.6644138,,,,,,,,,,,,,, -219700,4.4498696,0.5603474,,,,,,,,,,,,,, -219800,4.6567802,0.65125,,,,,,,,,,,,,, -219878,,,0.9594826102256776,0.1475639790296554,0.7575199604034424,1.040326476097107,50000.0,0.6318000555038452,1.8223992586135864,10000.0,74009.09196305275,76685.58260345459,74009.09196305275,2659.2388412952423,9.289936542510986,0.0 -219900,4.3523464,0.5880142,,,,,,,,,,,,,, -220000,4.3769646,0.62137765,,,,,,,,,,,,,, -220100,4.6198645,0.58657825,,,,,,,,,,,,,, -220200,3.98942,0.58578134,,,,,,,,,,,,,, -220300,4.243716,0.6260604,,,,,,,,,,,,,, -220400,4.8244014,0.60713106,,,,,,,,,,,,,, -220500,4.2443795,0.6372888,,,,,,,,,,,,,, -220600,4.7965164,0.63499564,,,,,,,,,,,,,, -220700,4.644982,0.6516372,,,,,,,,,,,,,, -220800,4.4551353,0.59825516,,,,,,,,,,,,,, -220900,4.751897,0.6676945,,,,,,,,,,,,,, -221000,4.2584257,0.580686,,,,,,,,,,,,,, -221100,4.518572,0.62147266,,,,,,,,,,,,,, -221200,4.50868,0.6225784,,,,,,,,,,,,,, -221300,4.4262185,0.675955,,,,,,,,,,,,,, -221396,,,0.9596819281578064,0.1482045650482177,0.7574999928474426,1.0406891107559204,50000.0,0.6323000192642212,1.8241325616836548,10000.0,74519.10008716583,77212.83652758598,74519.10008716583,2676.3662524223328,9.35293436050415,0.0 -221400,4.8373637,0.611593,,,,,,,,,,,,,, -221500,4.756993,0.6133634,,,,,,,,,,,,,, -221600,4.244405,0.5627209,,,,,,,,,,,,,, -221700,4.5175776,0.5966556,,,,,,,,,,,,,, -221800,4.7020197,0.57749104,,,,,,,,,,,,,, -221900,4.6322613,0.70134765,,,,,,,,,,,,,, -222000,4.4650855,0.6318945,,,,,,,,,,,,,, -222100,4.1371183,0.6451385,,,,,,,,,,,,,, -222200,5.010176,0.5916596,,,,,,,,,,,,,, -222300,4.240307,0.61588585,,,,,,,,,,,,,, -222400,4.6469784,0.6707193,,,,,,,,,,,,,, -222500,4.3224497,0.6402271,,,,,,,,,,,,,, -222600,4.281287,0.6274893,,,,,,,,,,,,,, -222700,4.781731,0.6498624,,,,,,,,,,,,,, -222800,4.504074,0.62408966,,,,,,,,,,,,,, -222900,4.6237516,0.6089347,,,,,,,,,,,,,, -222912,,,0.9620137214660645,0.1440362632274627,0.7574999928474426,1.0408563613891602,50000.0,0.6314000487327576,1.822832465171814,10000.0,75029.03829264641,77740.1799068451,75029.03829264641,2693.648894548416,9.417216300964355,0.0 -223000,5.0184054,0.63190603,,,,,,,,,,,,,, -223100,4.3567524,0.6632843,,,,,,,,,,,,,, -223200,5.0930586,0.66784,,,,,,,,,,,,,, -223300,4.308708,0.6440814,,,,,,,,,,,,,, -223400,4.4218726,0.667465,,,,,,,,,,,,,, -223500,4.539779,0.60915256,,,,,,,,,,,,,, -223600,3.9790046,0.62680197,,,,,,,,,,,,,, -223700,4.5654597,0.67745817,,,,,,,,,,,,,, -223800,4.074318,0.588227,,,,,,,,,,,,,, -223900,4.3525915,0.6031757,,,,,,,,,,,,,, -224000,4.5576143,0.7295887,,,,,,,,,,,,,, -224100,4.301772,0.56179464,,,,,,,,,,,,,, -224200,4.240553,0.541151,,,,,,,,,,,,,, -224300,4.6503124,0.64967775,,,,,,,,,,,,,, -224400,4.603604,0.62302595,,,,,,,,,,,,,, -224427,,,0.9602399468421936,0.1480806469917297,0.7569999694824219,1.0420169830322266,50000.0,0.6310000419616699,1.8239521980285645,10000.0,75538.97178125381,78267.17767620087,75538.97178125381,2710.5937311649323,9.481563091278076,0.0 -224500,4.457116,0.5871384,,,,,,,,,,,,,, -224600,4.1168323,0.5724366,,,,,,,,,,,,,, -224700,4.4464254,0.6018511,,,,,,,,,,,,,, -224800,4.338515,0.62511814,,,,,,,,,,,,,, -224900,4.5917907,0.62885463,,,,,,,,,,,,,, -225000,4.1783776,0.55376434,,,,,,,,,,,,,, -225100,3.9886901,0.5679734,,,,,,,,,,,,,, -225200,4.676844,0.6249552,,,,,,,,,,,,,, -225300,4.764778,0.64179957,,,,,,,,,,,,,, -225400,4.4648147,0.6173476,,,,,,,,,,,,,, -225500,4.097625,0.57544684,,,,,,,,,,,,,, -225600,4.8146563,0.65076077,,,,,,,,,,,,,, -225700,4.309973,0.5918802,,,,,,,,,,,,,, -225800,4.487364,0.63412017,,,,,,,,,,,,,, -225900,4.1850777,0.55680287,,,,,,,,,,,,,, -225944,,,0.9607979655265808,0.1444706618785858,0.7567799687385559,1.04023540019989,50000.0,0.6314000487327576,1.8221471309661863,10000.0,76049.03899121284,78794.60010194778,76049.03899121284,2727.824291229248,9.551263332366943,0.0 -226000,4.5956287,0.60636926,,,,,,,,,,,,,, -226100,4.3073745,0.6017643,,,,,,,,,,,,,, -226200,4.6030955,0.6343813,,,,,,,,,,,,,, -226300,4.2202234,0.59495527,,,,,,,,,,,,,, -226400,4.0638294,0.5882559,,,,,,,,,,,,,, -226500,4.7543235,0.6385944,,,,,,,,,,,,,, -226600,4.335073,0.59256315,,,,,,,,,,,,,, -226700,4.39229,0.63001597,,,,,,,,,,,,,, -226800,4.218918,0.59729236,,,,,,,,,,,,,, -226900,4.4855375,0.6485256,,,,,,,,,,,,,, -227000,4.7616305,0.7029316,,,,,,,,,,,,,, -227100,4.6780686,0.6485383,,,,,,,,,,,,,, -227200,4.0038357,0.55943656,,,,,,,,,,,,,, -227300,4.4174213,0.5884654,,,,,,,,,,,,,, -227400,4.2470255,0.5776822,,,,,,,,,,,,,, -227460,,,0.9618343114852904,0.1431108713150024,0.7572000026702881,1.040708303451538,50000.0,0.6312000155448914,1.820837140083313,10000.0,76558.91853904724,79321.60742902756,76558.91853904724,2744.825801372528,9.621059656143188,0.0 -227500,5.0130367,0.6836335,,,,,,,,,,,,,, -227600,4.6249084,0.5850415,,,,,,,,,,,,,, -227700,4.6715913,0.625125,,,,,,,,,,,,,, -227800,4.9910865,0.60016143,,,,,,,,,,,,,, -227900,4.4688754,0.62756354,,,,,,,,,,,,,, -228000,4.559444,0.59531873,,,,,,,,,,,,,, -228100,4.262249,0.62507695,,,,,,,,,,,,,, -228200,4.4098444,0.6743902,,,,,,,,,,,,,, -228300,4.933679,0.64564943,,,,,,,,,,,,,, -228400,4.720917,0.6919137,,,,,,,,,,,,,, -228500,4.3406777,0.6467316,,,,,,,,,,,,,, -228600,4.616956,0.63300115,,,,,,,,,,,,,, -228700,4.563541,0.59862155,,,,,,,,,,,,,, -228800,4.2389216,0.5650573,,,,,,,,,,,,,, -228900,4.826303,0.6480423,,,,,,,,,,,,,, -228975,,,0.9617745280265808,0.1452770978212356,0.7570799589157104,1.039811372756958,50000.0,0.6309000253677368,1.8225457668304443,10000.0,77068.11691904068,79848.94569897652,77068.11691904068,2762.005347251892,10.524330615997314,0.0 -229000,4.4464774,0.6439395,,,,,,,,,,,,,, -229100,4.615728,0.6792832,,,,,,,,,,,,,, -229200,4.379513,0.6044309,,,,,,,,,,,,,, -229300,4.546128,0.73553145,,,,,,,,,,,,,, -229400,4.434426,0.60108113,,,,,,,,,,,,,, -229500,4.5630374,0.6451858,,,,,,,,,,,,,, -229600,4.4618134,0.6490834,,,,,,,,,,,,,, -229700,4.419674,0.57439876,,,,,,,,,,,,,, -229800,4.1505294,0.6203182,,,,,,,,,,,,,, -229900,4.195569,0.58040273,,,,,,,,,,,,,, -230000,4.0836053,0.62318695,,,,,,,,,,,,,, -230100,4.4937196,0.6283779,,,,,,,,,,,,,, -230200,4.068549,0.5625097,,,,,,,,,,,,,, -230300,4.5915856,0.55178225,,,,,,,,,,,,,, -230400,4.259349,0.6297516,,,,,,,,,,,,,, -230492,,,0.960957407951355,0.1454789787530899,0.7569800019264221,1.040593504905701,50000.0,0.6309000253677368,1.8216196298599243,10000.0,77578.0630581379,80376.73307228088,77578.0630581379,2779.7181718349457,10.597296953201294,0.0 -230500,4.923334,0.7401649,,,,,,,,,,,,,, -230600,3.9729471,0.56640315,,,,,,,,,,,,,, -230700,5.142668,0.56728375,,,,,,,,,,,,,, -230800,4.6060796,0.6712599,,,,,,,,,,,,,, -230900,4.526015,0.6771248,,,,,,,,,,,,,, -231000,4.363704,0.6349213,,,,,,,,,,,,,, -231100,4.2500377,0.6626485,,,,,,,,,,,,,, -231200,4.55019,0.6063254,,,,,,,,,,,,,, -231300,4.342055,0.59033984,,,,,,,,,,,,,, -231400,5.11022,0.64800406,,,,,,,,,,,,,, -231500,3.9606402,0.5546146,,,,,,,,,,,,,, -231600,4.32413,0.6156296,,,,,,,,,,,,,, -231700,4.687556,0.60968333,,,,,,,,,,,,,, -231800,4.784286,0.6454338,,,,,,,,,,,,,, -231900,4.390735,0.5936465,,,,,,,,,,,,,, -232000,4.640419,0.62998575,,,,,,,,,,,,,, -232009,,,0.96000075340271,0.1493954509496688,0.7570199966430664,1.041401743888855,50000.0,0.6317000389099121,1.8230193853378296,10000.0,78088.05710411072,80903.96130156517,78088.05710411072,2796.831855535507,10.66297459602356,0.0 -232100,4.474043,0.65652275,,,,,,,,,,,,,, -232200,4.2603607,0.61944866,,,,,,,,,,,,,, -232300,4.646121,0.6575383,,,,,,,,,,,,,, -232400,4.502185,0.62727016,,,,,,,,,,,,,, -232500,4.9090652,0.62175417,,,,,,,,,,,,,, -232600,4.5110064,0.65652806,,,,,,,,,,,,,, -232700,5.026138,0.6385418,,,,,,,,,,,,,, -232800,4.760829,0.68823,,,,,,,,,,,,,, -232900,4.9435167,0.66423035,,,,,,,,,,,,,, -233000,5.145169,0.7584373,,,,,,,,,,,,,, -233100,4.6342425,0.63386744,,,,,,,,,,,,,, -233200,4.4649134,0.6180733,,,,,,,,,,,,,, -233300,4.4791503,0.6619035,,,,,,,,,,,,,, -233400,4.7742596,0.6380698,,,,,,,,,,,,,, -233500,4.386315,0.629089,,,,,,,,,,,,,, -233526,,,0.9608178734779358,0.1444471329450607,0.7569999694824219,1.0412191152572632,50000.0,0.6313000321388245,1.8219051361083984,10000.0,78597.98800444603,81431.11747145653,78597.98800444603,2813.9345309734344,10.7299907207489,0.0 -233600,4.4660387,0.61495376,,,,,,,,,,,,,, -233700,4.7868667,0.7023121,,,,,,,,,,,,,, -233800,4.3011665,0.62917936,,,,,,,,,,,,,, -233900,4.7796516,0.6624471,,,,,,,,,,,,,, -234000,4.423828,0.5253819,,,,,,,,,,,,,, -234100,4.728592,0.6492577,,,,,,,,,,,,,, -234200,4.195315,0.56111485,,,,,,,,,,,,,, -234300,4.5115414,0.6588013,,,,,,,,,,,,,, -234400,4.6368866,0.6279754,,,,,,,,,,,,,, -234500,4.678024,0.6143105,,,,,,,,,,,,,, -234600,4.969227,0.64644283,,,,,,,,,,,,,, -234700,4.731363,0.5867161,,,,,,,,,,,,,, -234800,5.096252,0.57043433,,,,,,,,,,,,,, -234900,4.5973268,0.659726,,,,,,,,,,,,,, -235000,4.401732,0.5831686,,,,,,,,,,,,,, -235044,,,0.961694836616516,0.1431909799575805,0.7575399875640869,1.041340708732605,50000.0,0.6317000389099121,1.8218598365783687,10000.0,79108.1516199112,81958.46935725212,79108.1516199112,2830.9884712696075,10.809279441833496,0.0 -235100,4.556328,0.6963678,,,,,,,,,,,,,, -235200,4.656925,0.7563099,,,,,,,,,,,,,, -235300,4.256582,0.6180798,,,,,,,,,,,,,, -235400,4.287563,0.6311615,,,,,,,,,,,,,, -235500,4.4138,0.6322622,,,,,,,,,,,,,, -235600,4.631585,0.6564149,,,,,,,,,,,,,, -235700,4.865595,0.64863545,,,,,,,,,,,,,, -235800,4.6251364,0.6307382,,,,,,,,,,,,,, -235900,4.5760956,0.56889945,,,,,,,,,,,,,, -236000,4.3278065,0.56543076,,,,,,,,,,,,,, -236100,4.5420356,0.6975179,,,,,,,,,,,,,, -236200,4.349685,0.65924615,,,,,,,,,,,,,, -236300,4.321076,0.59640217,,,,,,,,,,,,,, -236400,4.337492,0.600104,,,,,,,,,,,,,, -236500,4.524995,0.6393048,,,,,,,,,,,,,, -236561,,,0.9612165093421936,0.1462259292602539,0.7571799755096436,1.041209697723389,50000.0,0.6304000020027161,1.8224035501480105,10000.0,79618.13293385506,82485.65254235268,79618.13293385506,2848.0664477348328,10.87845540046692,0.0 -236600,4.8449364,0.61785644,,,,,,,,,,,,,, -236700,4.4379625,0.64693505,,,,,,,,,,,,,, -236800,4.302552,0.5385501,,,,,,,,,,,,,, -236900,4.6952257,0.6576471,,,,,,,,,,,,,, -237000,5.0597067,0.6803298,,,,,,,,,,,,,, -237100,4.077402,0.547018,,,,,,,,,,,,,, -237200,4.296309,0.6426212,,,,,,,,,,,,,, -237300,4.33892,0.62846327,,,,,,,,,,,,,, -237400,4.250385,0.6405552,,,,,,,,,,,,,, -237500,4.350699,0.6816019,,,,,,,,,,,,,, -237600,4.683302,0.65349823,,,,,,,,,,,,,, -237700,5.0607934,0.67109,,,,,,,,,,,,,, -237800,4.442963,0.6775811,,,,,,,,,,,,,, -237900,4.3399286,0.5625834,,,,,,,,,,,,,, -238000,4.735784,0.6980445,,,,,,,,,,,,,, -238078,,,0.9614157676696776,0.146426573395729,0.7571600079536438,1.0422078371047974,50000.0,0.6309000253677368,1.8252544403076167,10000.0,80128.2179043293,83012.94344353676,80128.2179043293,2865.148155927658,10.948329448699951,0.0 -238100,4.280961,0.5828118,,,,,,,,,,,,,, -238200,4.6106944,0.6112012,,,,,,,,,,,,,, -238300,4.407872,0.6302694,,,,,,,,,,,,,, -238400,4.3905845,0.6541438,,,,,,,,,,,,,, -238500,4.7858105,0.63697284,,,,,,,,,,,,,, -238600,4.9430876,0.70003897,,,,,,,,,,,,,, -238700,4.403387,0.63201857,,,,,,,,,,,,,, -238800,4.646498,0.67228574,,,,,,,,,,,,,, -238900,4.7520714,0.59985596,,,,,,,,,,,,,, -239000,4.401443,0.54999894,,,,,,,,,,,,,, -239100,4.6408873,0.6043664,,,,,,,,,,,,,, -239200,4.3783436,0.62706834,,,,,,,,,,,,,, -239300,4.670915,0.6019699,,,,,,,,,,,,,, -239400,4.426172,0.6268273,,,,,,,,,,,,,, -239500,4.3289084,0.6440972,,,,,,,,,,,,,, -239595,,,0.9596819281578064,0.1480749398469925,0.7570199966430664,1.0406923294067385,50000.0,0.6324000358581543,1.8220387697219849,10000.0,80638.16833734512,83540.04543042183,80638.16833734512,2882.1735339164734,11.02028512954712,0.0 -239600,4.5060782,0.60828006,,,,,,,,,,,,,, -239700,4.605516,0.6216358,,,,,,,,,,,,,, -239800,4.3952765,0.6453946,,,,,,,,,,,,,, -239900,4.31994,0.60282207,,,,,,,,,,,,,, -240000,4.5264263,0.66358167,,,,,,,,,,,,,, -240100,4.0897627,0.58946615,,,,,,,,,,,,,, -240200,4.482433,0.58351874,,,,,,,,,,,,,, -240300,4.8628993,0.67141855,,,,,,,,,,,,,, -240400,4.086551,0.5246415,,,,,,,,,,,,,, -240500,4.3192773,0.6159681,,,,,,,,,,,,,, -240600,4.43845,0.59935653,,,,,,,,,,,,,, -240700,4.8637767,0.6756174,,,,,,,,,,,,,, -240800,4.616441,0.6582831,,,,,,,,,,,,,, -240900,4.7940307,0.67708284,,,,,,,,,,,,,, -241000,4.497409,0.6783292,,,,,,,,,,,,,, -241100,4.6383486,0.59382135,,,,,,,,,,,,,, -241112,,,0.9610570669174194,0.1477317065000534,0.7570399641990662,1.0410159826278689,50000.0,0.6330000162124634,1.8216248750686648,10000.0,81148.19727015495,84067.56158614159,81148.19727015495,2899.5223717689514,11.105278968811035,0.0 -241200,4.544632,0.6936414,,,,,,,,,,,,,, -241300,4.6812115,0.6723404,,,,,,,,,,,,,, -241400,4.6821537,0.62928146,,,,,,,,,,,,,, -241500,4.86548,0.66527903,,,,,,,,,,,,,, -241600,4.4585347,0.58337474,,,,,,,,,,,,,, -241700,4.2782693,0.57485616,,,,,,,,,,,,,, -241800,4.109479,0.5762312,,,,,,,,,,,,,, -241900,4.662448,0.62339485,,,,,,,,,,,,,, -242000,4.5033126,0.60451317,,,,,,,,,,,,,, -242100,4.219345,0.58757955,,,,,,,,,,,,,, -242200,4.4421964,0.6367353,,,,,,,,,,,,,, -242300,4.80157,0.67642856,,,,,,,,,,,,,, -242400,4.4116006,0.5903596,,,,,,,,,,,,,, -242500,4.3944745,0.59357786,,,,,,,,,,,,,, -242600,4.857258,0.6337845,,,,,,,,,,,,,, -242629,,,0.9610769748687744,0.1438143998384475,0.7569999694824219,1.0407848358154297,50000.0,0.631100058555603,1.8233078718185425,10000.0,81658.35938191414,84594.96130204201,81658.35938191414,2916.6335439682007,11.176434993743896,0.0 -242700,4.5182376,0.61580265,,,,,,,,,,,,,, -242800,4.527641,0.6938289,,,,,,,,,,,,,, -242900,5.1629076,0.61510336,,,,,,,,,,,,,, -243000,4.2607903,0.5795629,,,,,,,,,,,,,, -243100,4.96249,0.6746355,,,,,,,,,,,,,, -243200,4.4784856,0.633761,,,,,,,,,,,,,, -243300,4.365242,0.65150964,,,,,,,,,,,,,, -243400,4.6835556,0.56241757,,,,,,,,,,,,,, -243500,5.1797624,0.7003645,,,,,,,,,,,,,, -243600,4.704652,0.6369436,,,,,,,,,,,,,, -243700,4.3782697,0.614138,,,,,,,,,,,,,, -243800,4.7679563,0.6240448,,,,,,,,,,,,,, -243900,4.116503,0.5566492,,,,,,,,,,,,,, -244000,4.7623014,0.6130209,,,,,,,,,,,,,, -244100,4.510298,0.63375866,,,,,,,,,,,,,, -244146,,,0.9615951776504515,0.146777406334877,0.7572999596595764,1.0413265228271484,50000.0,0.6317000389099121,1.8220953941345213,10000.0,82168.30734395981,85122.14113402367,82168.30734395981,2933.7407212257385,11.245970726013184,0.0 -244200,4.334161,0.63320565,,,,,,,,,,,,,, -244300,4.2868924,0.60217273,,,,,,,,,,,,,, -244400,4.6634455,0.59794265,,,,,,,,,,,,,, -244500,4.2861032,0.6195156,,,,,,,,,,,,,, -244600,4.36082,0.6123786,,,,,,,,,,,,,, -244700,4.6801972,0.7011214,,,,,,,,,,,,,, -244800,4.69221,0.6227794,,,,,,,,,,,,,, -244900,4.6934505,0.5991118,,,,,,,,,,,,,, -245000,4.355745,0.61267257,,,,,,,,,,,,,, -245100,4.469601,0.6049452,,,,,,,,,,,,,, -245200,4.5895104,0.66786104,,,,,,,,,,,,,, -245300,4.310331,0.5915674,,,,,,,,,,,,,, -245400,4.0855684,0.536917,,,,,,,,,,,,,, -245500,4.320195,0.6248989,,,,,,,,,,,,,, -245600,4.289942,0.55246335,,,,,,,,,,,,,, -245662,,,0.959741711616516,0.1484692096710205,0.7572599649429321,1.0395944118499756,50000.0,0.631600022315979,1.821165919303894,10000.0,82678.47346019745,85649.61295294762,82678.47346019745,2950.919387578964,11.318059921264648,0.0 -245700,5.058707,0.66404605,,,,,,,,,,,,,, -245800,4.485186,0.64478,,,,,,,,,,,,,, -245900,4.230867,0.6145517,,,,,,,,,,,,,, -246000,3.9769948,0.5793543,,,,,,,,,,,,,, -246100,4.5050397,0.66061807,,,,,,,,,,,,,, -246200,4.3620405,0.5838946,,,,,,,,,,,,,, -246300,4.438473,0.5783809,,,,,,,,,,,,,, -246400,4.398556,0.6300977,,,,,,,,,,,,,, -246500,4.2809806,0.62885183,,,,,,,,,,,,,, -246600,4.6438584,0.6746214,,,,,,,,,,,,,, -246700,4.846533,0.71809983,,,,,,,,,,,,,, -246800,4.1529493,0.5709847,,,,,,,,,,,,,, -246900,4.73285,0.6718421,,,,,,,,,,,,,, -247000,4.6339407,0.67125046,,,,,,,,,,,,,, -247100,4.8057632,0.6541755,,,,,,,,,,,,,, -247178,,,0.9597217440605164,0.149254560470581,0.7569199800491333,1.0409460067749023,50000.0,0.6318000555038452,1.8228518962860107,10000.0,83188.46709156036,86176.89586758614,83188.46709156036,2968.063045740128,11.407188892364502,0.0 -247200,4.4784236,0.61430895,,,,,,,,,,,,,, -247300,4.613808,0.6304402,,,,,,,,,,,,,, -247400,4.2954226,0.5930057,,,,,,,,,,,,,, -247500,4.1446056,0.6426015,,,,,,,,,,,,,, -247600,4.3379226,0.5760538,,,,,,,,,,,,,, -247700,4.2678466,0.6373763,,,,,,,,,,,,,, -247800,4.8141603,0.64345825,,,,,,,,,,,,,, -247900,5.122717,0.5828899,,,,,,,,,,,,,, -248000,4.760611,0.6504694,,,,,,,,,,,,,, -248100,4.713065,0.6879565,,,,,,,,,,,,,, -248200,4.215974,0.6376702,,,,,,,,,,,,,, -248300,4.394582,0.6074425,,,,,,,,,,,,,, -248400,4.393027,0.59803224,,,,,,,,,,,,,, -248500,4.5682216,0.6835991,,,,,,,,,,,,,, -248600,4.017327,0.5376855,,,,,,,,,,,,,, -248695,,,0.9600605964660645,0.1471786648035049,0.7571199536323547,1.0416258573532104,50000.0,0.6308000087738037,1.8244123458862305,10000.0,83698.41069197655,86704.00005698204,83698.41069197655,2985.100293636322,11.474862098693848,0.0 -248700,4.7455564,0.6590746,,,,,,,,,,,,,, -248800,4.653878,0.63651,,,,,,,,,,,,,, -248900,4.7308664,0.6497509,,,,,,,,,,,,,, -249000,4.2265253,0.62925565,,,,,,,,,,,,,, -249100,4.343795,0.5874201,,,,,,,,,,,,,, -249200,4.7919374,0.6714589,,,,,,,,,,,,,, -249300,4.869169,0.7035485,,,,,,,,,,,,,, -249400,4.0974813,0.56011975,,,,,,,,,,,,,, -249500,4.76395,0.6561395,,,,,,,,,,,,,, -249600,4.577124,0.61717004,,,,,,,,,,,,,, -249700,4.319049,0.65145946,,,,,,,,,,,,,, -249800,4.1992188,0.6480451,,,,,,,,,,,,,, -249900,4.7741218,0.5576995,,,,,,,,,,,,,, -250000,4.0020556,0.5581728,,,,,,,,,,,,,, -250100,5.2133427,0.7084141,,,,,,,,,,,,,, -250200,4.542383,0.63753885,,,,,,,,,,,,,, -250213,,,0.9616350531578064,0.1467591375112533,0.7573399543762207,1.0408724546432495,50000.0,0.6315000057220459,1.821985483169556,10000.0,84208.50983929634,87231.36796617508,84208.50983929634,3002.2465505599976,11.541433811187744,0.0 -250300,4.7924056,0.741422,,,,,,,,,,,,,, -250400,4.319328,0.61171734,,,,,,,,,,,,,, -250500,4.8627496,0.67848814,,,,,,,,,,,,,, -250600,4.6036634,0.6417154,,,,,,,,,,,,,, -250700,4.744961,0.58005524,,,,,,,,,,,,,, -250800,4.140538,0.59980774,,,,,,,,,,,,,, -250900,4.570807,0.61134076,,,,,,,,,,,,,, -251000,4.663176,0.6638579,,,,,,,,,,,,,, -251100,4.5281253,0.6132449,,,,,,,,,,,,,, -251200,5.168393,0.6043937,,,,,,,,,,,,,, -251300,4.530972,0.6677687,,,,,,,,,,,,,, -251400,4.7012944,0.5739659,,,,,,,,,,,,,, -251500,3.8242512,0.55655897,,,,,,,,,,,,,, -251600,4.3301244,0.66110164,,,,,,,,,,,,,, -251700,4.3415046,0.5430207,,,,,,,,,,,,,, -251731,,,0.960758090019226,0.1473046839237213,0.7570799589157104,1.0428922176361084,50000.0,0.6312000155448914,1.8245079517364504,10000.0,84718.69186162949,87758.83411717415,84718.69186162949,3019.4082946777344,11.609270811080933,0.0 -251800,4.3680162,0.63351566,,,,,,,,,,,,,, -251900,4.4384017,0.59157896,,,,,,,,,,,,,, -252000,4.1860676,0.587697,,,,,,,,,,,,,, -252100,4.8879805,0.66488487,,,,,,,,,,,,,, -252200,4.454052,0.5798288,,,,,,,,,,,,,, -252300,4.0120254,0.5798619,,,,,,,,,,,,,, -252400,4.880832,0.63580626,,,,,,,,,,,,,, -252500,4.177137,0.5013711,,,,,,,,,,,,,, -252600,4.1062512,0.53406346,,,,,,,,,,,,,, -252700,4.374492,0.5717453,,,,,,,,,,,,,, -252800,4.1845083,0.5254523,,,,,,,,,,,,,, -252900,4.419015,0.5857847,,,,,,,,,,,,,, -253000,4.307196,0.5413377,,,,,,,,,,,,,, -253100,4.832056,0.6428282,,,,,,,,,,,,,, -253200,4.047201,0.5774888,,,,,,,,,,,,,, -253248,,,0.9601203799247742,0.1469669789075851,0.7571199536323547,1.0400139093399048,50000.0,0.631600022315979,1.820750117301941,10000.0,85228.6314611435,88286.13070821762,85228.6314611435,3036.6418414115906,11.676239252090454,0.0 -253300,4.3405347,0.5672436,,,,,,,,,,,,,, -253400,4.461415,0.5916538,,,,,,,,,,,,,, -253500,4.5227346,0.6439361,,,,,,,,,,,,,, -253600,4.4911036,0.60382426,,,,,,,,,,,,,, -253700,4.259657,0.6342454,,,,,,,,,,,,,, -253800,4.1878133,0.663482,,,,,,,,,,,,,, -253900,4.1140327,0.5766865,,,,,,,,,,,,,, -254000,4.268078,0.64042246,,,,,,,,,,,,,, -254100,4.8916245,0.7264049,,,,,,,,,,,,,, -254200,4.3045616,0.6421057,,,,,,,,,,,,,, -254300,4.5558906,0.57759756,,,,,,,,,,,,,, -254400,5.1790476,0.622317,,,,,,,,,,,,,, -254500,4.1197934,0.591078,,,,,,,,,,,,,, -254600,4.6653686,0.72074604,,,,,,,,,,,,,, -254700,4.466533,0.63790196,,,,,,,,,,,,,, -254765,,,0.9601004123687744,0.1476990282535553,0.7572399973869324,1.041129231452942,50000.0,0.6317000389099121,1.823580026626587,10000.0,85738.69255518913,88813.3157889843,85738.69255518913,3053.6353392601013,11.7507586479187,0.0 -254800,5.0258226,0.7192407,,,,,,,,,,,,,, -254900,4.285431,0.60838777,,,,,,,,,,,,,, -255000,4.7028704,0.62916845,,,,,,,,,,,,,, -255100,4.4386353,0.5969453,,,,,,,,,,,,,, -255200,4.3634477,0.6021242,,,,,,,,,,,,,, -255300,4.9095483,0.6034563,,,,,,,,,,,,,, -255400,4.7028008,0.698411,,,,,,,,,,,,,, -255500,4.7738137,0.65372515,,,,,,,,,,,,,, -255600,4.053586,0.54193735,,,,,,,,,,,,,, -255700,4.397456,0.6664069,,,,,,,,,,,,,, -255800,4.5164776,0.57557315,,,,,,,,,,,,,, -255900,4.825992,0.661906,,,,,,,,,,,,,, -256000,4.153857,0.6059585,,,,,,,,,,,,,, -256100,4.5139923,0.6373603,,,,,,,,,,,,,, -256200,4.5387864,0.60281485,,,,,,,,,,,,,, -256283,,,0.960957407951355,0.1457504779100418,0.7570199966430664,1.0411179065704346,50000.0,0.6318000555038452,1.8220078945159912,10000.0,86248.77793478966,89340.56578230858,86248.77793478966,3070.663406610489,11.832293510437012,0.0 -256300,4.7685337,0.6208446,,,,,,,,,,,,,, -256400,4.1929936,0.62087995,,,,,,,,,,,,,, -256500,4.4704847,0.63097453,,,,,,,,,,,,,, -256600,5.107947,0.63881165,,,,,,,,,,,,,, -256700,4.4074016,0.6306287,,,,,,,,,,,,,, -256800,4.681278,0.6281581,,,,,,,,,,,,,, -256900,4.600891,0.583717,,,,,,,,,,,,,, -257000,4.2294087,0.5834245,,,,,,,,,,,,,, -257100,4.744231,0.5751705,,,,,,,,,,,,,, -257200,5.0342083,0.6134709,,,,,,,,,,,,,, -257300,4.0722704,0.63171434,,,,,,,,,,,,,, -257400,4.4660034,0.67083406,,,,,,,,,,,,,, -257500,4.420757,0.6129663,,,,,,,,,,,,,, -257600,4.5326905,0.6681241,,,,,,,,,,,,,, -257700,4.416361,0.6864517,,,,,,,,,,,,,, -257800,,,0.959741711616516,0.1476667374372482,0.7569800019264221,1.041002869606018,50000.0,0.631600022315979,1.8211606740951536,10000.0,86758.89575958252,89867.85570526123,86758.89575958252,3087.710409402848,11.901975393295288,0.0 -257800,4.5948534,0.6255454,,,,,,,,,,,,,, -257900,4.5126467,0.61707973,,,,,,,,,,,,,, -258000,4.3763914,0.62947553,,,,,,,,,,,,,, -258100,4.687632,0.6749808,,,,,,,,,,,,,, -258200,4.5456657,0.60274106,,,,,,,,,,,,,, -258300,4.9443827,0.6687354,,,,,,,,,,,,,, -258400,4.454011,0.5995578,,,,,,,,,,,,,, -258500,4.609959,0.5854384,,,,,,,,,,,,,, -258600,4.5658746,0.6744236,,,,,,,,,,,,,, -258700,4.159786,0.5670565,,,,,,,,,,,,,, -258800,4.668493,0.6273217,,,,,,,,,,,,,, -258900,4.240266,0.61493325,,,,,,,,,,,,,, -259000,4.821729,0.7109224,,,,,,,,,,,,,, -259100,5.0482097,0.7601927,,,,,,,,,,,,,, -259200,4.5309763,0.6409003,,,,,,,,,,,,,, -259300,4.5389724,0.6477598,,,,,,,,,,,,,, -259317,,,0.960379421710968,0.1454952508211136,0.7571600079536438,1.0412981510162354,50000.0,0.6313000321388245,1.8210585117340088,10000.0,87268.98562383652,90395.32349467278,87268.98562383652,3104.96017575264,11.974144697189333,0.0 -259400,4.3252697,0.6560235,,,,,,,,,,,,,, -259500,4.874684,0.62294924,,,,,,,,,,,,,, -259600,4.5633936,0.6565706,,,,,,,,,,,,,, -259700,4.343501,0.64235294,,,,,,,,,,,,,, -259800,4.6720405,0.6304008,,,,,,,,,,,,,, -259900,4.3321643,0.60412884,,,,,,,,,,,,,, -260000,4.5483246,0.61671674,,,,,,,,,,,,,, -260100,5.2414784,0.69968694,,,,,,,,,,,,,, -260200,4.412593,0.6568004,,,,,,,,,,,,,, -260300,4.638181,0.6767138,,,,,,,,,,,,,, -260400,4.452928,0.60665125,,,,,,,,,,,,,, -260500,4.342829,0.61848253,,,,,,,,,,,,,, -260600,4.632379,0.6932703,,,,,,,,,,,,,, -260700,4.3877153,0.57683825,,,,,,,,,,,,,, -260800,4.678448,0.63546187,,,,,,,,,,,,,, -260833,,,0.9606783986091614,0.146960511803627,0.7566999793052673,1.0429741144180298,50000.0,0.6310000419616699,1.8252462148666384,10000.0,87779.135191679,90922.71321320534,87779.135191679,3122.0596079826355,12.059104919433594,0.0 -260900,4.566634,0.67528784,,,,,,,,,,,,,, -261000,4.711669,0.6292558,,,,,,,,,,,,,, -261100,4.3484592,0.6168307,,,,,,,,,,,,,, -261200,4.441757,0.5589363,,,,,,,,,,,,,, -261300,4.371112,0.5993414,,,,,,,,,,,,,, -261400,4.148159,0.5858735,,,,,,,,,,,,,, -261500,4.678184,0.65378886,,,,,,,,,,,,,, -261600,4.5916605,0.6101477,,,,,,,,,,,,,, -261700,4.609288,0.6700425,,,,,,,,,,,,,, -261800,4.2194414,0.57952213,,,,,,,,,,,,,, -261900,4.6254945,0.6707807,,,,,,,,,,,,,, -262000,4.377817,0.61907685,,,,,,,,,,,,,, -262100,5.0522327,0.75717527,,,,,,,,,,,,,, -262200,4.3623195,0.59248704,,,,,,,,,,,,,, -262300,4.493387,0.5739714,,,,,,,,,,,,,, -262350,,,0.9612364172935486,0.1476792991161346,0.7573999762535095,1.0416964292526243,50000.0,0.6309000253677368,1.82242488861084,10000.0,88289.18821763992,91450.09464883804,88289.18821763992,3139.257059574127,12.131905794143677,0.0 -262400,4.846815,0.68638086,,,,,,,,,,,,,, -262500,4.7365727,0.61875683,,,,,,,,,,,,,, -262600,4.3413835,0.6149304,,,,,,,,,,,,,, -262700,4.695996,0.6301476,,,,,,,,,,,,,, -262800,4.93087,0.6710663,,,,,,,,,,,,,, -262900,4.4170265,0.5768087,,,,,,,,,,,,,, -263000,4.6598773,0.69180465,,,,,,,,,,,,,, -263100,4.465224,0.61440694,,,,,,,,,,,,,, -263200,4.481985,0.6837251,,,,,,,,,,,,,, -263300,4.455434,0.584639,,,,,,,,,,,,,, -263400,4.515547,0.60630625,,,,,,,,,,,,,, -263500,4.38457,0.6519994,,,,,,,,,,,,,, -263600,4.417903,0.64964855,,,,,,,,,,,,,, -263700,4.133533,0.60352445,,,,,,,,,,,,,, -263800,4.2466636,0.5856044,,,,,,,,,,,,,, -263866,,,0.9598214030265808,0.1464400887489318,0.7572999596595764,1.0404446125030518,50000.0,0.6320000290870667,1.8234227895736688,10000.0,88799.24826645851,91977.33986473083,88799.24826645851,3156.3127546310425,12.205149173736572,0.0 -263900,4.4135723,0.6695864,,,,,,,,,,,,,, -264000,4.51497,0.5909407,,,,,,,,,,,,,, -264100,4.1694136,0.5796403,,,,,,,,,,,,,, -264200,4.769203,0.6464772,,,,,,,,,,,,,, -264300,4.2455797,0.6310975,,,,,,,,,,,,,, -264400,4.3362846,0.5820513,,,,,,,,,,,,,, -264500,4.667832,0.64210314,,,,,,,,,,,,,, -264600,4.328364,0.5986449,,,,,,,,,,,,,, -264700,4.509759,0.6504575,,,,,,,,,,,,,, -264800,4.740233,0.6823835,,,,,,,,,,,,,, -264900,4.3391047,0.5754198,,,,,,,,,,,,,, -265000,4.556251,0.6426693,,,,,,,,,,,,,, -265100,4.404249,0.66700435,,,,,,,,,,,,,, -265200,4.3049946,0.5872316,,,,,,,,,,,,,, -265300,4.4970255,0.5779424,,,,,,,,,,,,,, -265383,,,0.9618542790412904,0.142778679728508,0.7571399807929993,1.040645956993103,50000.0,0.6315000057220459,1.82268226146698,10000.0,89309.30315113068,92504.5416522026,89309.30315113068,3173.3309082984924,12.277857065200806,0.0 -265400,4.694905,0.6289903,,,,,,,,,,,,,, -265500,4.2432823,0.60221994,,,,,,,,,,,,,, -265600,4.318166,0.6306585,,,,,,,,,,,,,, -265700,4.5168724,0.6057815,,,,,,,,,,,,,, -265800,4.344372,0.65162235,,,,,,,,,,,,,, -265900,4.512658,0.6265476,,,,,,,,,,,,,, -266000,4.533226,0.60877687,,,,,,,,,,,,,, -266100,4.9330935,0.62624186,,,,,,,,,,,,,, -266200,4.5981565,0.712189,,,,,,,,,,,,,, -266300,4.549894,0.6803212,,,,,,,,,,,,,, -266400,4.5540953,0.634784,,,,,,,,,,,,,, -266500,4.6621513,0.61866415,,,,,,,,,,,,,, -266600,4.280725,0.6237955,,,,,,,,,,,,,, -266700,4.867935,0.6717934,,,,,,,,,,,,,, -266800,4.300814,0.56613636,,,,,,,,,,,,,, -266900,,,0.9614357352256776,0.1437123864889145,0.7573399543762207,1.041559815406799,50000.0,0.6315000057220459,1.821916103363037,10000.0,89819.25856900215,93032.41458940506,89819.25856900215,3191.1166141033173,12.352385759353638,0.0 -266900,4.3688397,0.59310615,,,,,,,,,,,,,, -267000,4.8883076,0.6012227,,,,,,,,,,,,,, -267100,4.807161,0.6333314,,,,,,,,,,,,,, -267200,4.615008,0.57650137,,,,,,,,,,,,,, -267300,4.8496675,0.640601,,,,,,,,,,,,,, -267400,4.353,0.6467357,,,,,,,,,,,,,, -267500,4.32711,0.5326291,,,,,,,,,,,,,, -267600,4.225481,0.58687496,,,,,,,,,,,,,, -267700,4.6571198,0.6748502,,,,,,,,,,,,,, -267800,5.3364134,0.656184,,,,,,,,,,,,,, -267900,4.390798,0.57037544,,,,,,,,,,,,,, -268000,4.398345,0.6419658,,,,,,,,,,,,,, -268100,4.4687686,0.6282146,,,,,,,,,,,,,, -268200,4.6207633,0.57873845,,,,,,,,,,,,,, -268300,4.458981,0.56329566,,,,,,,,,,,,,, -268400,4.2868624,0.58415216,,,,,,,,,,,,,, -268417,,,0.9615951776504515,0.1457705348730087,0.7571199536323547,1.0407840013504028,50000.0,0.6318000555038452,1.823684811592102,10000.0,90329.31252336502,93559.65687131882,90329.31252336502,3208.188056945801,12.414271593093872,0.0 -268500,4.6761403,0.60878235,,,,,,,,,,,,,, -268600,4.680541,0.6323614,,,,,,,,,,,,,, -268700,5.1272836,0.7081769,,,,,,,,,,,,,, -268800,4.507134,0.60691196,,,,,,,,,,,,,, -268900,4.8583975,0.5930528,,,,,,,,,,,,,, -269000,4.7010326,0.6592575,,,,,,,,,,,,,, -269100,4.17897,0.6000784,,,,,,,,,,,,,, -269200,4.4855056,0.6424932,,,,,,,,,,,,,, -269300,4.5560045,0.5243193,,,,,,,,,,,,,, -269400,4.7173586,0.57166845,,,,,,,,,,,,,, -269500,4.3339224,0.6202657,,,,,,,,,,,,,, -269600,4.43631,0.635339,,,,,,,,,,,,,, -269700,4.1265416,0.60291445,,,,,,,,,,,,,, -269800,4.884595,0.6823488,,,,,,,,,,,,,, -269900,4.2542343,0.60293585,,,,,,,,,,,,,, -269933,,,0.9596420526504515,0.1478172987699508,0.7570199966430664,1.0419752597808838,50000.0,0.631100058555603,1.8229761123657229,10000.0,90839.26527810095,94087.48363137244,90839.26527810095,3225.9346590042114,12.48537278175354,0.0 -270000,4.257289,0.57996964,,,,,,,,,,,,,, -270100,3.902694,0.592221,,,,,,,,,,,,,, -270200,3.955181,0.53309494,,,,,,,,,,,,,, -270300,4.6418266,0.6089394,,,,,,,,,,,,,, -270400,4.4937434,0.62762,,,,,,,,,,,,,, -270500,4.772538,0.63463175,,,,,,,,,,,,,, -270600,4.718381,0.62794894,,,,,,,,,,,,,, -270700,4.3948345,0.6337806,,,,,,,,,,,,,, -270800,4.377087,0.6234745,,,,,,,,,,,,,, -270900,4.238493,0.6446061,,,,,,,,,,,,,, -271000,4.990394,0.6977266,,,,,,,,,,,,,, -271100,4.797806,0.60859466,,,,,,,,,,,,,, -271200,3.8769968,0.5337866,,,,,,,,,,,,,, -271300,4.1631403,0.53563106,,,,,,,,,,,,,, -271400,4.629562,0.60988563,,,,,,,,,,,,,, -271450,,,0.961136758327484,0.1462059915065765,0.7572999596595764,1.0409760475158691,50000.0,0.6315000057220459,1.824073314666748,10000.0,91349.21398591997,94614.55206918716,91349.21398591997,3242.923921108246,12.55991506576538,0.0 -271500,4.425532,0.64593756,,,,,,,,,,,,,, -271600,4.932107,0.5762605,,,,,,,,,,,,,, -271700,4.6428995,0.68996286,,,,,,,,,,,,,, -271800,4.7266345,0.6380698,,,,,,,,,,,,,, -271900,4.650315,0.68845254,,,,,,,,,,,,,, -272000,4.322901,0.62388873,,,,,,,,,,,,,, -272100,4.6725426,0.62986094,,,,,,,,,,,,,, -272200,4.6469393,0.5550766,,,,,,,,,,,,,, -272300,4.8899755,0.711272,,,,,,,,,,,,,, -272400,4.31999,0.59081066,,,,,,,,,,,,,, -272500,4.369447,0.60970426,,,,,,,,,,,,,, -272600,4.017638,0.54549533,,,,,,,,,,,,,, -272700,4.5113554,0.5316059,,,,,,,,,,,,,, -272800,4.3592615,0.65999466,,,,,,,,,,,,,, -272900,4.266152,0.5624326,,,,,,,,,,,,,, -272967,,,0.9608178734779358,0.1457305401563644,0.7572000026702881,1.041136622428894,50000.0,0.6314000487327576,1.822332143783569,10000.0,91859.17581629752,95141.79025554656,91859.17581629752,3260.069388628006,12.63509225845337,0.0 -273000,4.354788,0.6305707,,,,,,,,,,,,,, -273100,4.4489927,0.6512852,,,,,,,,,,,,,, -273200,5.187332,0.7632908,,,,,,,,,,,,,, -273300,4.567098,0.6798859,,,,,,,,,,,,,, -273400,5.2711844,0.7568141,,,,,,,,,,,,,, -273500,4.486469,0.6465404,,,,,,,,,,,,,, -273600,4.362031,0.6190506,,,,,,,,,,,,,, -273700,4.6858287,0.6652221,,,,,,,,,,,,,, -273800,4.5699406,0.7047539,,,,,,,,,,,,,, -273900,4.367308,0.6739783,,,,,,,,,,,,,, -274000,4.501531,0.70700145,,,,,,,,,,,,,, -274100,4.5668125,0.5882075,,,,,,,,,,,,,, -274200,4.625009,0.6658645,,,,,,,,,,,,,, -274300,4.0215535,0.5694177,,,,,,,,,,,,,, -274400,4.3582983,0.6122149,,,,,,,,,,,,,, -274484,,,0.961136758327484,0.1444941759109497,0.7572999596595764,1.0405818223953247,50000.0,0.6320000290870667,1.8210731744766235,10000.0,92369.35075259209,95669.04021525384,92369.35075259209,3277.009635448456,12.7144455909729,0.0 -274500,4.780895,0.5977955,,,,,,,,,,,,,, -274600,4.4914947,0.58515227,,,,,,,,,,,,,, -274700,4.594256,0.61432165,,,,,,,,,,,,,, -274800,4.1481695,0.54473007,,,,,,,,,,,,,, -274900,4.5846157,0.65145826,,,,,,,,,,,,,, -275000,4.540388,0.5719478,,,,,,,,,,,,,, -275100,4.2517443,0.5713937,,,,,,,,,,,,,, -275200,4.7657733,0.67412615,,,,,,,,,,,,,, -275300,4.3840427,0.6188328,,,,,,,,,,,,,, -275400,4.4790287,0.62213475,,,,,,,,,,,,,, -275500,4.4282517,0.6688597,,,,,,,,,,,,,, -275600,4.391599,0.6038954,,,,,,,,,,,,,, -275700,4.4978456,0.61978364,,,,,,,,,,,,,, -275800,4.857515,0.6164143,,,,,,,,,,,,,, -275900,4.2348647,0.5760058,,,,,,,,,,,,,, -276000,,,0.9621930718421936,0.1435670554637909,0.7574999928474426,1.0415067672729492,50000.0,0.6319000124931335,1.8237287998199463,10000.0,92879.23972916605,96196.10747170448,92879.23972916605,3294.059141635895,12.78620743751526,0.0 -276000,4.272056,0.64315903,,,,,,,,,,,,,, -276100,4.4605327,0.61777157,,,,,,,,,,,,,, -276200,4.6121016,0.6110611,,,,,,,,,,,,,, -276300,4.245401,0.6122165,,,,,,,,,,,,,, -276400,4.3487296,0.62806284,,,,,,,,,,,,,, -276500,4.4348593,0.63432044,,,,,,,,,,,,,, -276600,4.04576,0.5869312,,,,,,,,,,,,,, -276700,4.208967,0.5666737,,,,,,,,,,,,,, -276800,4.775727,0.61308306,,,,,,,,,,,,,, -276900,4.52089,0.62860394,,,,,,,,,,,,,, -277000,4.293697,0.5803778,,,,,,,,,,,,,, -277100,4.9905396,0.67576605,,,,,,,,,,,,,, -277200,4.3657007,0.58830917,,,,,,,,,,,,,, -277300,4.499963,0.5906208,,,,,,,,,,,,,, -277400,4.6201625,0.6548525,,,,,,,,,,,,,, -277500,4.272085,0.5927923,,,,,,,,,,,,,, -277518,,,0.9601004123687744,0.149187371134758,0.7570599913597107,1.0416173934936523,50000.0,0.6314000487327576,1.8231300115585327,10000.0,93389.430683136,96723.36788201332,93389.430683136,3310.9949176311493,12.863352537155151,0.0 -277600,4.1041555,0.545827,,,,,,,,,,,,,, -277700,4.480534,0.6023372,,,,,,,,,,,,,, -277800,4.7099333,0.613585,,,,,,,,,,,,,, -277900,4.103204,0.59098524,,,,,,,,,,,,,, -278000,5.1217737,0.6391941,,,,,,,,,,,,,, -278100,5.1987786,0.6586994,,,,,,,,,,,,,, -278200,4.366945,0.63550055,,,,,,,,,,,,,, -278300,4.4684377,0.644806,,,,,,,,,,,,,, -278400,4.3931394,0.62131727,,,,,,,,,,,,,, -278500,4.954209,0.6435856,,,,,,,,,,,,,, -278600,4.3494887,0.66522616,,,,,,,,,,,,,, -278700,4.4522367,0.6242411,,,,,,,,,,,,,, -278800,4.491398,0.6466415,,,,,,,,,,,,,, -278900,4.4492354,0.6243193,,,,,,,,,,,,,, -279000,4.327513,0.52746856,,,,,,,,,,,,,, -279034,,,0.9613759517669678,0.1479669213294983,0.7570399641990662,1.0405782461166382,50000.0,0.631100058555603,1.8226639032363887,10000.0,93899.51947426796,97250.8225440979,93899.51947426796,3328.229616165161,12.936316013336182,0.0 -279100,4.722801,0.6123683,,,,,,,,,,,,,, -279200,4.6082945,0.6224322,,,,,,,,,,,,,, -279300,5.1185718,0.7188616,,,,,,,,,,,,,, -279400,4.398264,0.6171086,,,,,,,,,,,,,, -279500,4.2909627,0.5992184,,,,,,,,,,,,,, -279600,4.099534,0.5496072,,,,,,,,,,,,,, -279700,4.5903316,0.6310128,,,,,,,,,,,,,, -279800,4.3778696,0.63305366,,,,,,,,,,,,,, -279900,4.7968326,0.6859135,,,,,,,,,,,,,, -280000,3.977267,0.6065779,,,,,,,,,,,,,, -280100,4.527409,0.6445472,,,,,,,,,,,,,, -280200,4.4851947,0.62201303,,,,,,,,,,,,,, -280300,4.7690473,0.612906,,,,,,,,,,,,,, -280400,4.257665,0.6010674,,,,,,,,,,,,,, -280500,4.49187,0.65142226,,,,,,,,,,,,,, -280550,,,0.9606783986091614,0.1443613618612289,0.7573399543762207,1.042070388793945,50000.0,0.6315000057220459,1.822600483894348,10000.0,94409.50527763368,97778.03981542587,94409.50527763368,3345.328625202179,13.0116868019104,0.0 -280600,4.2645464,0.60239017,,,,,,,,,,,,,, -280700,4.984092,0.65799665,,,,,,,,,,,,,, -280800,4.5602427,0.6883463,,,,,,,,,,,,,, -280900,4.504588,0.64745146,,,,,,,,,,,,,, -281000,4.2054353,0.60002124,,,,,,,,,,,,,, -281100,4.1536684,0.6633806,,,,,,,,,,,,,, -281200,4.7847257,0.6171995,,,,,,,,,,,,,, -281300,5.0191927,0.6363532,,,,,,,,,,,,,, -281400,4.384761,0.6116077,,,,,,,,,,,,,, -281500,4.2805796,0.6461712,,,,,,,,,,,,,, -281600,4.2976646,0.6151576,,,,,,,,,,,,,, -281700,4.183349,0.691553,,,,,,,,,,,,,, -281800,4.249899,0.56660944,,,,,,,,,,,,,, -281900,4.376796,0.6388219,,,,,,,,,,,,,, -282000,4.4832706,0.6152844,,,,,,,,,,,,,, -282066,,,0.960957407951355,0.1450634747743606,0.7574599981307983,1.0403265953063965,50000.0,0.6318000555038452,1.8230668306350708,10000.0,94919.44342923164,98305.10606789587,94919.44342923164,3362.324405670166,13.08899974822998,0.0 -282100,4.360941,0.6785008,,,,,,,,,,,,,, -282200,4.2928495,0.64517766,,,,,,,,,,,,,, -282300,4.274603,0.56692386,,,,,,,,,,,,,, -282400,4.082711,0.5600674,,,,,,,,,,,,,, -282500,4.383365,0.6011678,,,,,,,,,,,,,, -282600,4.4207807,0.5823466,,,,,,,,,,,,,, -282700,4.1412296,0.59249204,,,,,,,,,,,,,, -282800,4.3027306,0.53748727,,,,,,,,,,,,,, -282900,4.609584,0.6474706,,,,,,,,,,,,,, -283000,4.5604386,0.6668001,,,,,,,,,,,,,, -283100,4.403046,0.5479135,,,,,,,,,,,,,, -283200,4.831582,0.72469294,,,,,,,,,,,,,, -283300,4.72261,0.699599,,,,,,,,,,,,,, -283400,4.146054,0.6259229,,,,,,,,,,,,,, -283500,4.363699,0.6454786,,,,,,,,,,,,,, -283583,,,0.9598811864852904,0.148601621389389,0.7568999528884888,1.0414470434188845,50000.0,0.632900059223175,1.823802471160889,10000.0,95429.4895875454,98832.29233837128,95429.4895875454,3379.3356223106384,13.161826610565186,0.0 -283600,4.8877296,0.6407856,,,,,,,,,,,,,, -283700,4.187841,0.64411336,,,,,,,,,,,,,, -283800,4.6814547,0.66391504,,,,,,,,,,,,,, -283900,4.476821,0.5777679,,,,,,,,,,,,,, -284000,4.6711316,0.57272446,,,,,,,,,,,,,, -284100,4.5851007,0.65124506,,,,,,,,,,,,,, -284200,4.46321,0.6035546,,,,,,,,,,,,,, -284300,4.5282526,0.5604906,,,,,,,,,,,,,, -284400,4.876811,0.65926284,,,,,,,,,,,,,, -284500,4.276656,0.6349676,,,,,,,,,,,,,, -284600,4.5913506,0.6702664,,,,,,,,,,,,,, -284700,4.0959325,0.5579684,,,,,,,,,,,,,, -284800,4.611328,0.6349405,,,,,,,,,,,,,, -284900,4.7816954,0.61898464,,,,,,,,,,,,,, -285000,4.473758,0.67473733,,,,,,,,,,,,,, -285100,4.8821645,0.6514148,,,,,,,,,,,,,, -285101,,,0.9600605964660645,0.1479718536138534,0.7570399641990662,1.0402021408081057,50000.0,0.6318000555038452,1.8214406967163088,10000.0,95939.8146879673,99359.83514142036,95939.8146879673,3396.418870925904,13.239661693572998,0.0 -285200,4.5504446,0.6070452,,,,,,,,,,,,,, -285300,4.4847565,0.64516383,,,,,,,,,,,,,, -285400,4.397224,0.62422043,,,,,,,,,,,,,, -285500,4.7995563,0.5975641,,,,,,,,,,,,,, -285600,4.7007356,0.5972429,,,,,,,,,,,,,, -285700,4.2925687,0.63208884,,,,,,,,,,,,,, -285800,4.74987,0.6850922,,,,,,,,,,,,,, -285900,4.8018627,0.603497,,,,,,,,,,,,,, -286000,4.6689878,0.60145056,,,,,,,,,,,,,, -286100,4.6346617,0.62125885,,,,,,,,,,,,,, -286200,4.15344,0.5866696,,,,,,,,,,,,,, -286300,4.3471813,0.59548724,,,,,,,,,,,,,, -286400,4.665277,0.62343806,,,,,,,,,,,,,, -286500,4.6287336,0.64312804,,,,,,,,,,,,,, -286600,4.6711597,0.63015485,,,,,,,,,,,,,, -286618,,,0.959741711616516,0.1493349075317382,0.7571199536323547,1.040514349937439,50000.0,0.6312000155448914,1.822487235069275,10000.0,96449.9844045639,99887.04127049446,96449.9844045639,3413.324624300003,13.314711093902588,0.0 -286700,4.460421,0.6104735,,,,,,,,,,,,,, -286800,4.4643707,0.60445106,,,,,,,,,,,,,, -286900,4.465529,0.63839406,,,,,,,,,,,,,, -287000,4.606259,0.60208297,,,,,,,,,,,,,, -287100,4.74501,0.6087662,,,,,,,,,,,,,, -287200,4.739335,0.6594574,,,,,,,,,,,,,, -287300,4.274484,0.6123911,,,,,,,,,,,,,, -287400,4.4492006,0.64821523,,,,,,,,,,,,,, -287500,4.4073844,0.6827737,,,,,,,,,,,,,, -287600,4.244196,0.5326442,,,,,,,,,,,,,, -287700,4.3059726,0.6189966,,,,,,,,,,,,,, -287800,4.9894557,0.64966774,,,,,,,,,,,,,, -287900,4.3102994,0.5826346,,,,,,,,,,,,,, -288000,4.8017583,0.6540169,,,,,,,,,,,,,, -288082,,,0.961734652519226,0.1440905779600143,0.7569999694824219,1.042284607887268,50000.0,0.6320000290870667,1.8242065906524656,10000.0,96959.99516606332,100414.2544465065,96959.99516606332,3430.393528699875,13.393523693084717,0.0 -288100,4.52307,0.6543768,,,,,,,,,,,,,, -288200,4.4127145,0.6677849,,,,,,,,,,,,,, -288300,4.563246,0.679844,,,,,,,,,,,,,, -288400,4.327622,0.6101438,,,,,,,,,,,,,, -288500,5.0574374,0.6417318,,,,,,,,,,,,,, -288600,4.634549,0.6728674,,,,,,,,,,,,,, -288700,4.7232323,0.579876,,,,,,,,,,,,,, -288800,4.7770033,0.6554649,,,,,,,,,,,,,, -288900,4.4688196,0.6337195,,,,,,,,,,,,,, -289000,4.3142114,0.6401682,,,,,,,,,,,,,, -289100,4.1784577,0.61627716,,,,,,,,,,,,,, -289200,5.5350018,0.6017889,,,,,,,,,,,,,, -289300,4.3550205,0.6345681,,,,,,,,,,,,,, -289400,4.459951,0.6182231,,,,,,,,,,,,,, -289500,4.49533,0.59272826,,,,,,,,,,,,,, -289599,,,0.9604192972183228,0.1508612632751464,0.7572799921035767,1.0402519702911377,50000.0,0.6315000057220459,1.821026086807251,10000.0,97469.9808113575,100941.30323529243,97469.9808113575,3447.3284227848053,13.465181112289429,0.0 -289600,4.3100076,0.6094303,,,,,,,,,,,,,, -289700,4.6936765,0.6698834,,,,,,,,,,,,,, -289800,4.4292874,0.57897335,,,,,,,,,,,,,, -289900,4.736947,0.65256065,,,,,,,,,,,,,, -290000,4.1539974,0.5545437,,,,,,,,,,,,,, -290100,4.347906,0.66237956,,,,,,,,,,,,,, -290200,4.6246552,0.627356,,,,,,,,,,,,,, -290300,4.97132,0.6539147,,,,,,,,,,,,,, -290400,4.517142,0.6159641,,,,,,,,,,,,,, -290500,4.382098,0.5868929,,,,,,,,,,,,,, -290600,4.5110326,0.64214253,,,,,,,,,,,,,, -290700,4.237687,0.64560384,,,,,,,,,,,,,, -290800,4.080038,0.57499254,,,,,,,,,,,,,, -290900,4.4549894,0.6756885,,,,,,,,,,,,,, -291000,5.264901,0.6701166,,,,,,,,,,,,,, -291100,4.309361,0.61122006,,,,,,,,,,,,,, -291116,,,0.9608178734779358,0.1459608823060989,0.7571199536323547,1.0409700870513916,50000.0,0.6319000124931335,1.824628591537476,10000.0,97979.8696899414,101468.5384926796,97979.8696899414,3464.5442264080048,13.540786027908323,0.0 -291200,4.5922966,0.712583,,,,,,,,,,,,,, -291300,4.9769816,0.64422244,,,,,,,,,,,,,, -291400,4.5481873,0.62168455,,,,,,,,,,,,,, -291500,4.4899945,0.63176775,,,,,,,,,,,,,, -291600,4.8123136,0.6978685,,,,,,,,,,,,,, -291700,4.345153,0.6306063,,,,,,,,,,,,,, -291800,4.3585424,0.6031006,,,,,,,,,,,,,, -291900,4.6235275,0.5670094,,,,,,,,,,,,,, -292000,4.2885103,0.62125903,,,,,,,,,,,,,, -292100,4.856346,0.68913805,,,,,,,,,,,,,, -292200,4.6841326,0.7000805,,,,,,,,,,,,,, -292300,4.369129,0.57917535,,,,,,,,,,,,,, -292400,4.4516215,0.61811244,,,,,,,,,,,,,, -292500,4.488178,0.64169043,,,,,,,,,,,,,, -292600,4.317675,0.5963802,,,,,,,,,,,,,, -292633,,,0.9606385231018066,0.1455745548009872,0.7572000026702881,1.0424156188964844,50000.0,0.6313000321388245,1.8246649503707888,10000.0,98490.01154613496,101995.98503899574,98490.01154613496,3481.718322277069,13.615017414093018,0.0 -292700,4.383731,0.65517855,,,,,,,,,,,,,, -292800,3.7927003,0.5527155,,,,,,,,,,,,,, -292900,4.6983757,0.5962716,,,,,,,,,,,,,, -293000,4.1736875,0.6002621,,,,,,,,,,,,,, -293100,5.28659,0.6946008,,,,,,,,,,,,,, -293200,4.412502,0.5675232,,,,,,,,,,,,,, -293300,4.6806397,0.634187,,,,,,,,,,,,,, -293400,4.3059196,0.66459525,,,,,,,,,,,,,, -293500,4.6527977,0.6052754,,,,,,,,,,,,,, -293600,4.404393,0.6372927,,,,,,,,,,,,,, -293700,4.174637,0.6469797,,,,,,,,,,,,,, -293800,4.799461,0.6606156,,,,,,,,,,,,,, -293900,4.2263722,0.6047834,,,,,,,,,,,,,, -294000,4.209926,0.5652151,,,,,,,,,,,,,, -294100,4.7659483,0.6103839,,,,,,,,,,,,,, -294151,,,0.9596819281578064,0.1481228917837143,0.756659984588623,1.041306495666504,50000.0,0.6307000517845154,1.823425531387329,10000.0,99000.07468652724,102523.11113667488,99000.07468652724,3498.651032447815,13.689282655715942,0.0 -294200,4.31461,0.63886476,,,,,,,,,,,,,, -294300,4.473243,0.6672571,,,,,,,,,,,,,, -294400,4.3183374,0.59807646,,,,,,,,,,,,,, -294500,4.4156585,0.67939395,,,,,,,,,,,,,, -294600,4.463338,0.6127889,,,,,,,,,,,,,, -294700,4.2617364,0.6319139,,,,,,,,,,,,,, -294800,4.6656437,0.624621,,,,,,,,,,,,,, -294900,4.1432076,0.58708507,,,,,,,,,,,,,, -295000,4.053712,0.6224339,,,,,,,,,,,,,, -295100,4.4164705,0.6098289,,,,,,,,,,,,,, -295200,4.965634,0.67506975,,,,,,,,,,,,,, -295300,4.4564743,0.65507084,,,,,,,,,,,,,, -295400,4.882806,0.64847636,,,,,,,,,,,,,, -295500,4.177431,0.5780856,,,,,,,,,,,,,, -295600,4.5894623,0.63397294,,,,,,,,,,,,,, -295669,,,0.9604591727256776,0.1468521803617477,0.7571399807929993,1.0419366359710691,50000.0,0.6312000155448914,1.8231909275054927,10000.0,99510.17113494872,103050.29315805437,99510.17113494872,3515.6052017211914,13.764830350875854,0.0 -295700,4.518908,0.6594642,,,,,,,,,,,,,, -295800,4.641392,0.6481207,,,,,,,,,,,,,, -295900,4.2063446,0.6303827,,,,,,,,,,,,,, -296000,4.335888,0.5794263,,,,,,,,,,,,,, -296100,4.6912436,0.6029759,,,,,,,,,,,,,, -296200,4.582352,0.66408455,,,,,,,,,,,,,, -296300,4.204863,0.5728642,,,,,,,,,,,,,, -296400,4.9493814,0.69283104,,,,,,,,,,,,,, -296500,4.627189,0.66291094,,,,,,,,,,,,,, -296600,4.6970534,0.5912124,,,,,,,,,,,,,, -296700,4.218509,0.62711215,,,,,,,,,,,,,, -296800,4.625411,0.5638066,,,,,,,,,,,,,, -296900,4.3217483,0.5877233,,,,,,,,,,,,,, -297000,4.9887633,0.6669528,,,,,,,,,,,,,, -297100,4.454011,0.5795156,,,,,,,,,,,,,, -297186,,,0.9598811864852904,0.1468236297369003,0.756659984588623,1.0413254499435425,50000.0,0.6318000555038452,1.8235479593276973,10000.0,100020.35486221312,103577.52541542052,100020.35486221312,3532.513886451721,13.849292993545532,0.0 -297200,4.240065,0.5929358,,,,,,,,,,,,,, -297300,4.747665,0.6748234,,,,,,,,,,,,,, -297400,4.0968475,0.5924795,,,,,,,,,,,,,, -297500,4.7609944,0.5799161,,,,,,,,,,,,,, -297600,4.5216503,0.61966735,,,,,,,,,,,,,, -297700,4.2854342,0.61357355,,,,,,,,,,,,,, -297800,4.4038644,0.5766015,,,,,,,,,,,,,, -297900,4.491891,0.6453398,,,,,,,,,,,,,, -298000,4.6770015,0.61450624,,,,,,,,,,,,,, -298100,3.8965724,0.54277515,,,,,,,,,,,,,, -298200,4.01334,0.5372458,,,,,,,,,,,,,, -298300,5.170352,0.6926642,,,,,,,,,,,,,, -298400,4.483538,0.62119216,,,,,,,,,,,,,, -298500,4.0244865,0.5599896,,,,,,,,,,,,,, -298600,4.2514253,0.5665295,,,,,,,,,,,,,, -298700,4.656257,0.71483016,,,,,,,,,,,,,, -298704,,,0.9595623016357422,0.14792300760746,0.7567999958992004,1.040809154510498,50000.0,0.6312000155448914,1.8223090171813965,10000.0,100530.54751634598,104104.92567372322,100530.54751634598,3549.5913729667664,13.923123359680176,0.0 -298800,4.4461045,0.59859216,,,,,,,,,,,,,, -298900,4.2529526,0.56419075,,,,,,,,,,,,,, -299000,4.328788,0.5818704,,,,,,,,,,,,,, -299100,4.922895,0.6557029,,,,,,,,,,,,,, -299200,4.4971247,0.61884755,,,,,,,,,,,,,, -299300,4.089433,0.51284343,,,,,,,,,,,,,, -299400,4.561915,0.67308116,,,,,,,,,,,,,, -299500,4.84482,0.577348,,,,,,,,,,,,,, -299600,4.735956,0.61708784,,,,,,,,,,,,,, -299700,4.3991594,0.6173512,,,,,,,,,,,,,, -299800,4.982211,0.6396093,,,,,,,,,,,,,, -299900,4.9419093,0.57097447,,,,,,,,,,,,,, -300000,4.9279957,0.57418406,,,,,,,,,,,,,, -300100,4.3701377,0.5756695,,,,,,,,,,,,,, -300200,4.2750573,0.61508465,,,,,,,,,,,,,, -300221,,,0.962332546710968,0.1443368941545486,0.7571199536323547,1.0422277450561523,50000.0,0.6315000057220459,1.8250768184661863,10000.0,101040.60511946678,104632.22343468666,101040.60511946678,3566.6976997852325,14.00077509880066,0.0 -300300,4.400467,0.6009119,,,,,,,,,,,,,, -300400,4.3309026,0.6040938,,,,,,,,,,,,,, -300500,4.592393,0.62821,,,,,,,,,,,,,, -300600,5.087989,0.7127217,,,,,,,,,,,,,, -300700,4.4088373,0.62110037,,,,,,,,,,,,,, -300800,4.8325577,0.6643231,,,,,,,,,,,,,, -300900,4.1963487,0.55483747,,,,,,,,,,,,,, -301000,4.46324,0.67400336,,,,,,,,,,,,,, -301100,4.56166,0.60205877,,,,,,,,,,,,,, -301200,4.8209753,0.64824224,,,,,,,,,,,,,, -301300,4.4333487,0.651897,,,,,,,,,,,,,, -301400,5.2317405,0.5588074,,,,,,,,,,,,,, -301500,4.433771,0.61967087,,,,,,,,,,,,,, -301600,4.350694,0.58009505,,,,,,,,,,,,,, -301700,4.944287,0.64508206,,,,,,,,,,,,,, -301739,,,0.9604591727256776,0.1479778289794922,0.7568999528884888,1.0407817363739014,50000.0,0.6314000487327576,1.822162628173828,10000.0,101550.64585661888,105159.28444576263,101550.64585661888,3583.581182718277,14.081845045089722,0.0 -301800,4.185797,0.56030047,,,,,,,,,,,,,, -301900,4.715982,0.66002613,,,,,,,,,,,,,, -302000,4.16729,0.56784993,,,,,,,,,,,,,, -302100,4.6552196,0.64108515,,,,,,,,,,,,,, -302200,4.4357586,0.5648969,,,,,,,,,,,,,, -302300,4.3593917,0.63229847,,,,,,,,,,,,,, -302400,4.4551487,0.5827776,,,,,,,,,,,,,, -302500,4.7053127,0.6767956,,,,,,,,,,,,,, -302600,4.737825,0.62711006,,,,,,,,,,,,,, -302700,4.3277645,0.6149963,,,,,,,,,,,,,, -302800,4.7305417,0.5981481,,,,,,,,,,,,,, -302900,5.1298075,0.70242405,,,,,,,,,,,,,, -303000,4.5921946,0.67156976,,,,,,,,,,,,,, -303100,4.795401,0.6776171,,,,,,,,,,,,,, -303200,4.391209,0.61325884,,,,,,,,,,,,,, -303256,,,0.9602199792861938,0.1464538872241974,0.7568199634552002,1.0412921905517578,50000.0,0.631600022315979,1.8222219944000244,10000.0,102060.69908952712,105686.56081461906,102060.69908952712,3600.6692349910736,14.160232305526732,0.0 -303300,4.9884624,0.6639593,,,,,,,,,,,,,, -303400,4.3152585,0.6459048,,,,,,,,,,,,,, -303500,4.714026,0.5993379,,,,,,,,,,,,,, -303600,4.3192987,0.6105922,,,,,,,,,,,,,, -303700,4.6822686,0.62896943,,,,,,,,,,,,,, -303800,4.375538,0.5526422,,,,,,,,,,,,,, -303900,4.4709044,0.58752227,,,,,,,,,,,,,, -304000,4.2620654,0.5394096,,,,,,,,,,,,,, -304100,4.9587145,0.57563484,,,,,,,,,,,,,, -304200,4.3924017,0.59742486,,,,,,,,,,,,,, -304300,4.485226,0.6245436,,,,,,,,,,,,,, -304400,4.2544527,0.617481,,,,,,,,,,,,,, -304500,4.477414,0.6712513,,,,,,,,,,,,,, -304600,4.616389,0.6110285,,,,,,,,,,,,,, -304700,4.8465686,0.681075,,,,,,,,,,,,,, -304773,,,0.96195387840271,0.1408708095550537,0.7572199702262878,1.0406047105789185,50000.0,0.6315000057220459,1.823304533958435,10000.0,102570.83042025566,106213.7646138668,102570.83042025566,3617.607753753662,14.239030599594116,0.0 -304800,4.329779,0.6864122,,,,,,,,,,,,,, -304900,4.607424,0.5694585,,,,,,,,,,,,,, -305000,4.702937,0.6050296,,,,,,,,,,,,,, -305100,4.227147,0.6025355,,,,,,,,,,,,,, -305200,4.547293,0.6288681,,,,,,,,,,,,,, -305300,4.31634,0.62586045,,,,,,,,,,,,,, -305400,4.5327477,0.6432855,,,,,,,,,,,,,, -305500,4.869361,0.67867726,,,,,,,,,,,,,, -305600,4.606323,0.7117923,,,,,,,,,,,,,, -305700,4.704368,0.6748679,,,,,,,,,,,,,, -305800,4.3679247,0.58503413,,,,,,,,,,,,,, -305900,4.5456204,0.62913185,,,,,,,,,,,,,, -306000,4.2170615,0.6336999,,,,,,,,,,,,,, -306100,4.313579,0.5545478,,,,,,,,,,,,,, -306200,4.7007737,0.61637086,,,,,,,,,,,,,, -306290,,,0.9622727632522584,0.1447004824876785,0.7573800086975098,1.0414577722549438,50000.0,0.631600022315979,1.824026107788086,10000.0,103080.9529056549,106741.00653576852,103080.9529056549,3634.594278812408,14.316243886947632,0.0 -306300,4.1584716,0.5823867,,,,,,,,,,,,,, -306400,3.830085,0.4776523,,,,,,,,,,,,,, -306500,4.2927012,0.61784357,,,,,,,,,,,,,, -306600,4.3612022,0.65648717,,,,,,,,,,,,,, -306700,4.5253024,0.6053267,,,,,,,,,,,,,, -306800,4.755488,0.60060614,,,,,,,,,,,,,, -306900,4.580182,0.69174206,,,,,,,,,,,,,, -307000,4.562047,0.64359206,,,,,,,,,,,,,, -307100,4.4975924,0.65727043,,,,,,,,,,,,,, -307200,4.3235292,0.6253868,,,,,,,,,,,,,, -307300,4.8361926,0.65907305,,,,,,,,,,,,,, -307400,4.651861,0.6433247,,,,,,,,,,,,,, -307500,4.36392,0.5683106,,,,,,,,,,,,,, -307600,4.600511,0.62392545,,,,,,,,,,,,,, -307700,4.78488,0.64946187,,,,,,,,,,,,,, -307800,4.480117,0.62231195,,,,,,,,,,,,,, -307807,,,0.9609375,0.1463319510221481,0.7570599913597107,1.0414164066314695,50000.0,0.6320000290870667,1.8238807916641235,10000.0,103590.91629076004,107268.8312842846,103590.91629076004,3652.3201220035553,14.395713567733765,0.0 -307900,4.2968836,0.6027831,,,,,,,,,,,,,, -308000,4.453722,0.6461197,,,,,,,,,,,,,, -308100,4.1110406,0.5914723,,,,,,,,,,,,,, -308200,4.7546344,0.56032157,,,,,,,,,,,,,, -308300,4.8385487,0.6789472,,,,,,,,,,,,,, -308400,4.862965,0.76834923,,,,,,,,,,,,,, -308500,4.344105,0.60888743,,,,,,,,,,,,,, -308600,4.3016734,0.60027325,,,,,,,,,,,,,, -308700,4.2917266,0.67101026,,,,,,,,,,,,,, -308800,4.2969804,0.6539501,,,,,,,,,,,,,, -308900,4.8485894,0.59609973,,,,,,,,,,,,,, -309000,4.673972,0.6270223,,,,,,,,,,,,,, -309100,4.6672707,0.68734664,,,,,,,,,,,,,, -309200,3.9703765,0.5697378,,,,,,,,,,,,,, -309300,4.194552,0.554637,,,,,,,,,,,,,, -309324,,,0.960160195827484,0.1481978297233581,0.7568399906158447,1.04118549823761,50000.0,0.6309000253677368,1.8230949640274048,10000.0,104100.90484285356,107796.08420968056,104100.90484285356,3669.4485788345337,14.476394414901732,0.0 -309400,4.7488313,0.6416757,,,,,,,,,,,,,, -309500,4.0915737,0.6266879,,,,,,,,,,,,,, -309600,4.567521,0.6862208,,,,,,,,,,,,,, -309700,4.471831,0.60454535,,,,,,,,,,,,,, -309800,4.246161,0.56193846,,,,,,,,,,,,,, -309900,4.308305,0.59205437,,,,,,,,,,,,,, -310000,4.409742,0.66282314,,,,,,,,,,,,,, -310100,4.2092943,0.6321169,,,,,,,,,,,,,, -310200,4.142211,0.547432,,,,,,,,,,,,,, -310300,4.5333753,0.6142407,,,,,,,,,,,,,, -310400,4.28547,0.6453581,,,,,,,,,,,,,, -310500,4.8159,0.67570996,,,,,,,,,,,,,, -310600,4.799682,0.5633361,,,,,,,,,,,,,, -310700,4.4037867,0.6193093,,,,,,,,,,,,,, -310800,4.5158052,0.6461038,,,,,,,,,,,,,, -310841,,,0.9608178734779358,0.1436031311750412,0.7575199604034424,1.0405359268188477,50000.0,0.6326000094413757,1.8219730854034424,10000.0,104610.92230081558,108323.1176261902,104610.92230081558,3686.322530031204,14.564249038696287,0.0 -310900,4.255481,0.6044111,,,,,,,,,,,,,, -311000,4.960732,0.64731145,,,,,,,,,,,,,, -311100,4.4699306,0.6973269,,,,,,,,,,,,,, -311200,4.5750933,0.6531615,,,,,,,,,,,,,, -311300,4.464875,0.5957893,,,,,,,,,,,,,, -311400,4.49601,0.5939309,,,,,,,,,,,,,, -311500,4.473694,0.55887276,,,,,,,,,,,,,, -311600,4.6171103,0.60513985,,,,,,,,,,,,,, -311700,4.3924613,0.5934071,,,,,,,,,,,,,, -311800,4.5031786,0.63413084,,,,,,,,,,,,,, -311900,4.5853715,0.6761229,,,,,,,,,,,,,, -312000,4.99329,0.64669245,,,,,,,,,,,,,, -312100,4.412636,0.63787585,,,,,,,,,,,,,, -312200,4.717122,0.678422,,,,,,,,,,,,,, -312300,4.250633,0.5837457,,,,,,,,,,,,,, -312358,,,0.9614157676696776,0.1467447876930236,0.7572399973869324,1.0412235260009766,50000.0,0.6325000524520874,1.8220093250274656,10000.0,105120.93588781355,108850.32050657272,105120.93588781355,3703.3795261383057,14.63990044593811,0.0 -312400,4.656282,0.6651479,,,,,,,,,,,,,, -312500,4.7367105,0.6073022,,,,,,,,,,,,,, -312600,4.2262235,0.6081332,,,,,,,,,,,,,, -312700,4.5325174,0.66637063,,,,,,,,,,,,,, -312800,4.295686,0.6337195,,,,,,,,,,,,,, -312900,4.5640907,0.69657505,,,,,,,,,,,,,, -313000,4.330564,0.619728,,,,,,,,,,,,,, -313100,4.752455,0.61253613,,,,,,,,,,,,,, -313200,4.5715833,0.627604,,,,,,,,,,,,,, -313300,3.9561584,0.5584697,,,,,,,,,,,,,, -313400,4.5387964,0.6178639,,,,,,,,,,,,,, -313500,3.8604603,0.57677364,,,,,,,,,,,,,, -313600,4.99147,0.59701496,,,,,,,,,,,,,, -313700,4.2553782,0.53782696,,,,,,,,,,,,,, -313800,4.3986483,0.63905513,,,,,,,,,,,,,, -313877,,,0.9605388641357422,0.1448961198329925,0.757420003414154,1.0412286520004272,50000.0,0.6322000026702881,1.8241868019104004,10000.0,105631.11181282996,109377.67593812944,105631.11181282996,3720.4144394397736,14.726381301879885,0.0 -313900,4.464215,0.5585835,,,,,,,,,,,,,, -314000,4.365438,0.55574495,,,,,,,,,,,,,, -314100,4.020035,0.6257291,,,,,,,,,,,,,, -314200,4.358714,0.6601009,,,,,,,,,,,,,, -314300,4.8067794,0.6650661,,,,,,,,,,,,,, -314400,4.634603,0.62400776,,,,,,,,,,,,,, -314500,4.312606,0.60737246,,,,,,,,,,,,,, -314600,4.4973397,0.6282199,,,,,,,,,,,,,, -314700,4.434552,0.655145,,,,,,,,,,,,,, -314800,4.5834513,0.63413006,,,,,,,,,,,,,, -314900,4.1458187,0.59553766,,,,,,,,,,,,,, -315000,4.365317,0.63049126,,,,,,,,,,,,,, -315100,4.400467,0.66974515,,,,,,,,,,,,,, -315200,4.825378,0.6236957,,,,,,,,,,,,,, -315300,4.227205,0.5711557,,,,,,,,,,,,,, -315393,,,0.9616350531578064,0.1467013955116272,0.7572599649429321,1.0401941537857056,50000.0,0.631600022315979,1.8226666450500488,10000.0,106141.03833842278,109904.81872224808,106141.03833842278,3737.495297670365,14.80655813217163,0.0 -315400,4.251569,0.63615113,,,,,,,,,,,,,, -315500,4.9413238,0.63721496,,,,,,,,,,,,,, -315600,4.218557,0.62571186,,,,,,,,,,,,,, -315700,4.6135373,0.6470038,,,,,,,,,,,,,, -315800,4.505151,0.5505431,,,,,,,,,,,,,, -315900,4.4897933,0.6578803,,,,,,,,,,,,,, -316000,4.3247128,0.5725676,,,,,,,,,,,,,, -316100,4.6452217,0.635719,,,,,,,,,,,,,, -316200,4.596939,0.6797829,,,,,,,,,,,,,, -316300,4.3321013,0.6045113,,,,,,,,,,,,,, -316400,4.851968,0.6118475,,,,,,,,,,,,,, -316500,4.432519,0.61673087,,,,,,,,,,,,,, -316600,5.1180673,0.6562927,,,,,,,,,,,,,, -316700,4.552951,0.6796945,,,,,,,,,,,,,, -316800,4.330632,0.57870907,,,,,,,,,,,,,, -316900,4.6506844,0.65061176,,,,,,,,,,,,,, -316910,,,0.9612563848495485,0.1465180963277816,0.7571199536323547,1.0433276891708374,50000.0,0.6317000389099121,1.825588345527649,10000.0,106651.106498003,110432.43870258331,106651.106498003,3754.913206577301,14.88515329360962,0.0 -317000,4.229729,0.62564284,,,,,,,,,,,,,, -317100,4.4862785,0.6713549,,,,,,,,,,,,,, -317200,4.588925,0.6039037,,,,,,,,,,,,,, -317300,4.270941,0.6000222,,,,,,,,,,,,,, -317400,4.3858395,0.58651185,,,,,,,,,,,,,, -317500,4.5517015,0.644376,,,,,,,,,,,,,, -317600,4.1959815,0.639292,,,,,,,,,,,,,, -317700,4.0105066,0.5726164,,,,,,,,,,,,,, -317800,4.683788,0.6444601,,,,,,,,,,,,,, -317900,4.4779267,0.6648115,,,,,,,,,,,,,, -318000,4.4261703,0.53980225,,,,,,,,,,,,,, -318100,4.215163,0.6037238,,,,,,,,,,,,,, -318200,4.458616,0.6339087,,,,,,,,,,,,,, -318300,4.265127,0.5690608,,,,,,,,,,,,,, -318400,4.597567,0.63221335,,,,,,,,,,,,,, -318427,,,0.960957407951355,0.1460195183753967,0.7572199702262878,1.0418167114257812,50000.0,0.6325000524520874,1.8240481615066528,10000.0,107161.31473088264,110959.89461374284,107161.31473088264,3772.040938138962,14.948791265487673,0.0 -318500,4.318845,0.63670254,,,,,,,,,,,,,, -318600,4.517841,0.6551431,,,,,,,,,,,,,, -318700,4.571504,0.6365051,,,,,,,,,,,,,, -318800,4.4327154,0.6291889,,,,,,,,,,,,,, -318900,4.5509915,0.57410264,,,,,,,,,,,,,, -319000,4.8194766,0.70404804,,,,,,,,,,,,,, -319100,4.6716113,0.6059953,,,,,,,,,,,,,, -319200,5.4310365,0.70176286,,,,,,,,,,,,,, -319300,4.975668,0.6703031,,,,,,,,,,,,,, -319400,5.239523,0.7289964,,,,,,,,,,,,,, -319500,4.5735826,0.61094975,,,,,,,,,,,,,, -319600,4.724581,0.63952243,,,,,,,,,,,,,, -319700,4.3330526,0.61692137,,,,,,,,,,,,,, -319800,4.517135,0.6215327,,,,,,,,,,,,,, -319900,4.60987,0.6467781,,,,,,,,,,,,,, -319944,,,0.96000075340271,0.1463822722434997,0.7570399641990662,1.0409986972808838,50000.0,0.6315000057220459,1.8227121829986568,10000.0,107671.24874210358,111487.01357507706,107671.24874210358,3789.090869903565,15.027590036392212,0.0 -320000,4.5281076,0.5256746,,,,,,,,,,,,,, -320100,4.489449,0.6028345,,,,,,,,,,,,,, -320200,4.242745,0.60499245,,,,,,,,,,,,,, -320300,4.500786,0.6371304,,,,,,,,,,,,,, -320400,4.6211514,0.6947602,,,,,,,,,,,,,, -320500,4.2493896,0.6186733,,,,,,,,,,,,,, -320600,4.2799816,0.6283896,,,,,,,,,,,,,, -320700,4.6462903,0.6530066,,,,,,,,,,,,,, -320800,4.376717,0.61212635,,,,,,,,,,,,,, -320900,4.8166423,0.6377574,,,,,,,,,,,,,, -321000,4.4715285,0.6851958,,,,,,,,,,,,,, -321100,4.873677,0.6437312,,,,,,,,,,,,,, -321200,4.7927184,0.6077018,,,,,,,,,,,,,, -321300,4.8100724,0.6802919,,,,,,,,,,,,,, -321400,4.520856,0.7176985,,,,,,,,,,,,,, -321461,,,0.9615951776504515,0.1443527191877365,0.7574599981307983,1.0413386821746826,50000.0,0.6323000192642212,1.8240190744400024,10000.0,108181.3416583538,112014.31840515137,108181.3416583538,3806.1626420021057,15.110766649246216,0.0 -321500,4.6528354,0.6063557,,,,,,,,,,,,,, -321600,4.1782913,0.65551084,,,,,,,,,,,,,, -321700,4.5017085,0.6148959,,,,,,,,,,,,,, -321800,4.526335,0.68250513,,,,,,,,,,,,,, -321900,4.7880387,0.65105325,,,,,,,,,,,,,, -322000,4.658942,0.649696,,,,,,,,,,,,,, -322100,4.2052493,0.60649896,,,,,,,,,,,,,, -322200,4.5626755,0.7214327,,,,,,,,,,,,,, -322300,5.7883697,0.6643653,,,,,,,,,,,,,, -322400,5.33519,0.64739686,,,,,,,,,,,,,, -322500,4.600988,0.6264275,,,,,,,,,,,,,, -322600,4.6060586,0.6854613,,,,,,,,,,,,,, -322700,4.3837943,0.7075404,,,,,,,,,,,,,, -322800,4.720919,0.65528095,,,,,,,,,,,,,, -322900,4.1495347,0.59619236,,,,,,,,,,,,,, -322978,,,0.958804965019226,0.1512446850538253,0.7572799921035767,1.0406066179275513,50000.0,0.6319000124931335,1.8205403089523315,10000.0,108691.42335653304,112541.7327632904,108691.42335653304,3823.353661775589,15.195438861846924,0.0 -323000,4.7202735,0.62558585,,,,,,,,,,,,,, -323100,5.0740967,0.6820624,,,,,,,,,,,,,, -323200,4.6105714,0.58420205,,,,,,,,,,,,,, -323300,4.4912157,0.65555286,,,,,,,,,,,,,, -323400,4.7167864,0.57004106,,,,,,,,,,,,,, -323500,4.716205,0.62980366,,,,,,,,,,,,,, -323600,5.2335176,0.65900844,,,,,,,,,,,,,, -323700,4.5245295,0.61724794,,,,,,,,,,,,,, -323800,5.4069996,0.6598955,,,,,,,,,,,,,, -323900,4.283975,0.6532906,,,,,,,,,,,,,, -324000,4.157207,0.5031824,,,,,,,,,,,,,, -324100,4.107401,0.6322383,,,,,,,,,,,,,, -324200,4.493849,0.55432403,,,,,,,,,,,,,, -324300,4.5433197,0.62714344,,,,,,,,,,,,,, -324400,4.303158,0.5539236,,,,,,,,,,,,,, -324495,,,0.960180163383484,0.1485398411750793,0.7571799755096436,1.0408989191055298,50000.0,0.6313000321388245,1.82335364818573,10000.0,109201.54509663582,113069.16121602058,109201.54509663582,3840.5118465423584,15.287701606750488,0.0 -324500,4.6341777,0.6536884,,,,,,,,,,,,,, -324600,4.4125524,0.61321914,,,,,,,,,,,,,, -324700,4.366889,0.61530495,,,,,,,,,,,,,, -324800,4.71943,0.6180411,,,,,,,,,,,,,, -324900,4.3398986,0.6257301,,,,,,,,,,,,,, -325000,4.46503,0.6714942,,,,,,,,,,,,,, -325100,4.88379,0.59178025,,,,,,,,,,,,,, -325200,4.362131,0.5945141,,,,,,,,,,,,,, -325300,4.8060536,0.5652139,,,,,,,,,,,,,, -325400,4.9070044,0.5960924,,,,,,,,,,,,,, -325500,4.2689586,0.61221343,,,,,,,,,,,,,, -325600,4.837094,0.6513261,,,,,,,,,,,,,, -325700,4.3842773,0.60481524,,,,,,,,,,,,,, -325800,5.25556,0.80271244,,,,,,,,,,,,,, -325900,4.3508377,0.6155214,,,,,,,,,,,,,, -326000,4.923298,0.6841145,,,,,,,,,,,,,, -326012,,,0.960339605808258,0.1465994566679,0.7569999694824219,1.041593074798584,50000.0,0.631600022315979,1.822590947151184,10000.0,109711.53610014915,113596.42569184303,109711.53610014915,3857.648682117462,15.368526697158812,0.0 -326100,4.2576184,0.54581124,,,,,,,,,,,,,, -326200,4.636849,0.6679142,,,,,,,,,,,,,, -326300,4.269283,0.5561348,,,,,,,,,,,,,, -326400,4.6354403,0.6556875,,,,,,,,,,,,,, -326500,4.6555157,0.6559414,,,,,,,,,,,,,, -326600,5.1557584,0.66369057,,,,,,,,,,,,,, -326700,4.834534,0.7162287,,,,,,,,,,,,,, -326800,5.115108,0.6427832,,,,,,,,,,,,,, -326900,4.4436836,0.61444604,,,,,,,,,,,,,, -327000,4.5342803,0.64006,,,,,,,,,,,,,, -327100,4.778561,0.5842326,,,,,,,,,,,,,, -327200,4.3063545,0.5947977,,,,,,,,,,,,,, -327300,4.7808237,0.5758203,,,,,,,,,,,,,, -327400,4.7405124,0.6386186,,,,,,,,,,,,,, -327500,4.713986,0.63642585,,,,,,,,,,,,,, -327531,,,0.9618144035339355,0.1475189924240112,0.7569400072097778,1.041744828224182,50000.0,0.6326000094413757,1.823375225067139,10000.0,110221.53441381454,114123.5357427597,110221.53441381454,3874.6263246536255,15.447299480438232,0.0 -327600,4.7073092,0.65101486,,,,,,,,,,,,,, -327700,4.936965,0.629595,,,,,,,,,,,,,, -327800,4.2026625,0.5797676,,,,,,,,,,,,,, -327900,4.5904946,0.6596583,,,,,,,,,,,,,, -328000,4.4228196,0.62583864,,,,,,,,,,,,,, -328100,4.959203,0.6511131,,,,,,,,,,,,,, -328200,4.2562866,0.53859454,,,,,,,,,,,,,, -328300,4.7221293,0.6639862,,,,,,,,,,,,,, -328400,4.1563916,0.5621311,,,,,,,,,,,,,, -328500,4.085445,0.5630499,,,,,,,,,,,,,, -328600,4.563286,0.6227802,,,,,,,,,,,,,, -328700,4.077785,0.56174046,,,,,,,,,,,,,, -328800,4.7083097,0.70859015,,,,,,,,,,,,,, -328900,4.426764,0.57937807,,,,,,,,,,,,,, -329000,4.8338795,0.70442766,,,,,,,,,,,,,, -329048,,,0.9612563848495485,0.1457004994153976,0.7572399973869324,1.0415037870407104,50000.0,0.6314000487327576,1.823715090751648,10000.0,110731.65055704115,114650.8146300316,110731.65055704115,3891.651116847992,15.529188394546509,0.0 -329100,4.347101,0.6423398,,,,,,,,,,,,,, -329200,4.5871677,0.62536997,,,,,,,,,,,,,, -329300,4.530489,0.5792248,,,,,,,,,,,,,, -329400,4.6490836,0.69841033,,,,,,,,,,,,,, -329500,5.1250873,0.6958222,,,,,,,,,,,,,, -329600,4.3259087,0.5793769,,,,,,,,,,,,,, -329700,4.712937,0.625648,,,,,,,,,,,,,, -329800,4.601472,0.5872973,,,,,,,,,,,,,, -329900,4.265202,0.59171593,,,,,,,,,,,,,, -330000,3.932236,0.57173675,,,,,,,,,,,,,, -330100,4.793389,0.6026792,,,,,,,,,,,,,, -330200,4.705171,0.6073322,,,,,,,,,,,,,, -330300,4.2203264,0.57712996,,,,,,,,,,,,,, -330400,4.2518663,0.6201043,,,,,,,,,,,,,, -330500,4.418372,0.65013933,,,,,,,,,,,,,, -330565,,,0.9599011540412904,0.146982803940773,0.7574999928474426,1.0409815311431885,50000.0,0.6319000124931335,1.822548747062683,10000.0,111241.78953266144,115178.20145368576,111241.78953266144,3908.7601075172415,15.613263130187988,0.0 -330600,4.273956,0.66925454,,,,,,,,,,,,,, -330700,4.760569,0.68762875,,,,,,,,,,,,,, -330800,4.705071,0.7062677,,,,,,,,,,,,,, -330900,4.431866,0.6094159,,,,,,,,,,,,,, -331000,4.5496473,0.64992064,,,,,,,,,,,,,, -331100,4.131804,0.5634483,,,,,,,,,,,,,, -331200,4.3810816,0.57672596,,,,,,,,,,,,,, -331300,4.552208,0.56188,,,,,,,,,,,,,, -331400,4.423316,0.6599084,,,,,,,,,,,,,, -331500,4.6429067,0.6344005,,,,,,,,,,,,,, -331600,4.336004,0.5952324,,,,,,,,,,,,,, -331700,4.5465407,0.6709876,,,,,,,,,,,,,, -331800,4.613829,0.598119,,,,,,,,,,,,,, -331900,4.4962754,0.6321764,,,,,,,,,,,,,, -332000,4.875012,0.66304356,,,,,,,,,,,,,, -332082,,,0.959382951259613,0.1487138122320175,0.7572599649429321,1.0413559675216677,50000.0,0.6318000555038452,1.822547912597656,10000.0,111751.74325227736,115705.35889673232,111751.74325227736,3925.823562145233,15.69628357887268,0.0 -332100,4.9142914,0.6893823,,,,,,,,,,,,,, -332200,5.09755,0.63464,,,,,,,,,,,,,, -332300,4.635562,0.66936606,,,,,,,,,,,,,, -332400,4.527801,0.6870952,,,,,,,,,,,,,, -332500,4.298378,0.5998113,,,,,,,,,,,,,, -332600,4.392862,0.6647117,,,,,,,,,,,,,, -332700,4.4418755,0.5787611,,,,,,,,,,,,,, -332800,4.4691515,0.5689963,,,,,,,,,,,,,, -332900,4.3292894,0.68906164,,,,,,,,,,,,,, -333000,4.4444013,0.67371637,,,,,,,,,,,,,, -333100,4.2818904,0.57319236,,,,,,,,,,,,,, -333200,4.505911,0.5814352,,,,,,,,,,,,,, -333300,4.421156,0.63315666,,,,,,,,,,,,,, -333400,4.552232,0.65912163,,,,,,,,,,,,,, -333500,4.711899,0.6484673,,,,,,,,,,,,,, -333598,,,0.9609175324440002,0.1459372639656067,0.7568599581718445,1.0412836074829102,50000.0,0.6317000389099121,1.822006106376648,10000.0,112261.6212658882,116232.40108180046,112261.6212658882,3942.847847461701,15.78191375732422,0.0 -333600,4.825852,0.64961237,,,,,,,,,,,,,, -333700,4.0527806,0.5375286,,,,,,,,,,,,,, -333800,4.895872,0.6758468,,,,,,,,,,,,,, -333900,4.5076,0.60166913,,,,,,,,,,,,,, -334000,4.2049756,0.564508,,,,,,,,,,,,,, -334100,4.33106,0.57741594,,,,,,,,,,,,,, -334200,4.5061674,0.66122353,,,,,,,,,,,,,, -334300,4.4126287,0.6072766,,,,,,,,,,,,,, -334400,4.499595,0.6108626,,,,,,,,,,,,,, -334500,4.5629134,0.6641922,,,,,,,,,,,,,, -334600,4.0142336,0.60537237,,,,,,,,,,,,,, -334700,4.650762,0.67992556,,,,,,,,,,,,,, -334800,4.3196807,0.5790063,,,,,,,,,,,,,, -334900,4.923824,0.68418324,,,,,,,,,,,,,, -335000,4.411753,0.6310631,,,,,,,,,,,,,, -335100,4.2103543,0.572607,,,,,,,,,,,,,, -335115,,,0.9601004123687744,0.1463820040225982,0.7570399641990662,1.0413930416107178,50000.0,0.631600022315979,1.8224438428878784,10000.0,112771.67519831656,116759.58655571938,112771.67519831656,3959.836858987808,15.869224786758425,0.0 -335200,4.978477,0.6502355,,,,,,,,,,,,,, -335300,4.197804,0.61106515,,,,,,,,,,,,,, -335400,4.704436,0.629505,,,,,,,,,,,,,, -335500,4.3011026,0.61262923,,,,,,,,,,,,,, -335600,4.116719,0.65469295,,,,,,,,,,,,,, -335700,5.22369,0.76410073,,,,,,,,,,,,,, -335800,4.650571,0.66252047,,,,,,,,,,,,,, -335900,4.338119,0.5793412,,,,,,,,,,,,,, -336000,4.4068313,0.62263906,,,,,,,,,,,,,, -336100,4.3030577,0.59219176,,,,,,,,,,,,,, -336200,4.380521,0.6330083,,,,,,,,,,,,,, -336300,4.2294292,0.5961429,,,,,,,,,,,,,, -336400,4.80892,0.67338276,,,,,,,,,,,,,, -336500,4.4284067,0.636773,,,,,,,,,,,,,, -336600,4.626983,0.61560905,,,,,,,,,,,,,, -336632,,,0.9599011540412904,0.1479339897632599,0.757319986820221,1.0409128665924072,50000.0,0.6310000419616699,1.8232208490371704,10000.0,113281.5459959507,117286.4574854374,113281.5459959507,3976.698829650879,15.951818227767944,0.0 -336700,4.8239665,0.60208285,,,,,,,,,,,,,, -336800,4.4691854,0.6365512,,,,,,,,,,,,,, -336900,5.302677,0.6703283,,,,,,,,,,,,,, -337000,5.160009,0.7455525,,,,,,,,,,,,,, -337100,4.9034753,0.7142633,,,,,,,,,,,,,, -337200,4.40135,0.59388936,,,,,,,,,,,,,, -337300,4.3819027,0.62656164,,,,,,,,,,,,,, -337400,4.871097,0.651474,,,,,,,,,,,,,, -337500,4.503043,0.63014376,,,,,,,,,,,,,, -337600,4.450717,0.60064477,,,,,,,,,,,,,, -337700,4.6304555,0.62821233,,,,,,,,,,,,,, -337800,4.052022,0.52089405,,,,,,,,,,,,,, -337900,4.1818466,0.59515107,,,,,,,,,,,,,, -338000,4.4077444,0.56451154,,,,,,,,,,,,,, -338100,4.9212055,0.7367908,,,,,,,,,,,,,, -338150,,,0.9609375,0.1465308368206024,0.7573999762535095,1.0408369302749634,50000.0,0.6317000389099121,1.8228672742843628,10000.0,113791.70448088646,117813.6619169712,113791.70448088646,3993.6059629917145,16.03429889678955,0.0 -338200,4.357756,0.5540836,,,,,,,,,,,,,, -338300,4.495407,0.57873446,,,,,,,,,,,,,, -338400,4.5189834,0.5513613,,,,,,,,,,,,,, -338500,4.5689425,0.761294,,,,,,,,,,,,,, -338600,4.0365744,0.5466323,,,,,,,,,,,,,, -338700,3.913209,0.5005015,,,,,,,,,,,,,, -338800,4.6675076,0.7018222,,,,,,,,,,,,,, -338900,4.4863653,0.6753473,,,,,,,,,,,,,, -339000,4.9293537,0.6695278,,,,,,,,,,,,,, -339100,4.5883603,0.67672783,,,,,,,,,,,,,, -339200,4.4862256,0.6074939,,,,,,,,,,,,,, -339300,4.1641417,0.5861387,,,,,,,,,,,,,, -339400,4.438511,0.56066793,,,,,,,,,,,,,, -339500,4.7269716,0.58118355,,,,,,,,,,,,,, -339600,4.447492,0.57733387,,,,,,,,,,,,,, -339667,,,0.9609375,0.1465604901313781,0.757099986076355,1.040977120399475,50000.0,0.6324000358581543,1.8227611780166624,10000.0,114301.828332901,118341.0404188633,114301.828332901,4010.716674566269,16.120765447616577,0.0 -339700,4.490234,0.63431937,,,,,,,,,,,,,, -339800,4.3444624,0.6250873,,,,,,,,,,,,,, -339900,4.366108,0.55217,,,,,,,,,,,,,, -340000,4.6913414,0.63304925,,,,,,,,,,,,,, -340100,4.756427,0.62233835,,,,,,,,,,,,,, -340200,4.4616404,0.6384653,,,,,,,,,,,,,, -340300,4.1452675,0.62175125,,,,,,,,,,,,,, -340400,4.5108094,0.664074,,,,,,,,,,,,,, -340500,4.2267036,0.56667155,,,,,,,,,,,,,, -340600,4.691697,0.6235143,,,,,,,,,,,,,, -340700,4.907474,0.6349591,,,,,,,,,,,,,, -340800,4.2437973,0.65051603,,,,,,,,,,,,,, -340900,4.3664603,0.60983944,,,,,,,,,,,,,, -341000,4.419618,0.59415746,,,,,,,,,,,,,, -341100,4.851293,0.6503857,,,,,,,,,,,,,, -341184,,,0.9606783986091614,0.1450160592794418,0.7569199800491333,1.0414011478424072,50000.0,0.631600022315979,1.8225243091583248,10000.0,114811.8994588852,118868.26939034462,114811.8994588852,4027.736142396927,16.20158553123474,0.0 -341200,4.751906,0.70237565,,,,,,,,,,,,,, -341300,4.3316846,0.56859714,,,,,,,,,,,,,, -341400,4.38029,0.5242706,,,,,,,,,,,,,, -341500,4.702474,0.64934546,,,,,,,,,,,,,, -341600,4.224445,0.5752115,,,,,,,,,,,,,, -341700,4.3025336,0.5731468,,,,,,,,,,,,,, -341800,4.849775,0.6886303,,,,,,,,,,,,,, -341900,4.4959188,0.6318422,,,,,,,,,,,,,, -342000,4.645159,0.6220551,,,,,,,,,,,,,, -342100,4.476153,0.6453881,,,,,,,,,,,,,, -342200,4.835527,0.64850557,,,,,,,,,,,,,, -342300,5.1787705,0.6372406,,,,,,,,,,,,,, -342400,4.674793,0.70080364,,,,,,,,,,,,,, -342500,4.6283774,0.6684453,,,,,,,,,,,,,, -342600,4.892943,0.611838,,,,,,,,,,,,,, -342700,4.5865765,0.68288934,,,,,,,,,,,,,, -342701,,,0.9605787396430968,0.1452729403972625,0.7571799755096436,1.040116548538208,50000.0,0.6321000456809998,1.822193622589112,10000.0,115322.14385986328,119395.64589619637,115322.14385986328,4044.732864379883,16.282942533493042,0.0 -342800,4.96795,0.6775811,,,,,,,,,,,,,, -342900,4.6258383,0.6026139,,,,,,,,,,,,,, -343000,4.2217927,0.5848253,,,,,,,,,,,,,, -343100,5.0082545,0.6751802,,,,,,,,,,,,,, -343200,4.658103,0.6568863,,,,,,,,,,,,,, -343300,4.353635,0.60923135,,,,,,,,,,,,,, -343400,4.663941,0.66698384,,,,,,,,,,,,,, -343500,4.5165634,0.615316,,,,,,,,,,,,,, -343600,4.770011,0.6541871,,,,,,,,,,,,,, -343700,4.345613,0.60154843,,,,,,,,,,,,,, -343800,4.1853023,0.5712694,,,,,,,,,,,,,, -343900,4.3347507,0.6529085,,,,,,,,,,,,,, -344000,4.6563516,0.6592012,,,,,,,,,,,,,, -344100,4.7031875,0.5562576,,,,,,,,,,,,,, -344200,4.870548,0.60216385,,,,,,,,,,,,,, -344218,,,0.9624322056770324,0.1427851915359497,0.7567399740219116,1.042166829109192,50000.0,0.631600022315979,1.8227174282073968,10000.0,115832.13288211824,119922.84298753738,115832.13288211824,4061.798500061035,16.369632244110107,0.0 -344300,4.5683627,0.55257607,,,,,,,,,,,,,, -344400,4.411599,0.5950607,,,,,,,,,,,,,, -344500,4.2362533,0.56826204,,,,,,,,,,,,,, -344600,4.31698,0.640839,,,,,,,,,,,,,, -344700,4.686363,0.6330017,,,,,,,,,,,,,, -344800,4.792814,0.63755023,,,,,,,,,,,,,, -344900,4.6170793,0.6197775,,,,,,,,,,,,,, -345000,4.776718,0.6882555,,,,,,,,,,,,,, -345100,4.323151,0.5759828,,,,,,,,,,,,,, -345200,4.533967,0.6541156,,,,,,,,,,,,,, -345300,4.650346,0.63864696,,,,,,,,,,,,,, -345400,4.479839,0.6237676,,,,,,,,,,,,,, -345500,4.2092547,0.64210194,,,,,,,,,,,,,, -345600,4.368541,0.62020856,,,,,,,,,,,,,, -345700,4.2880073,0.6116239,,,,,,,,,,,,,, -345735,,,0.9613958597183228,0.1466252952814102,0.757099986076355,1.0420275926589966,50000.0,0.6322000026702881,1.822301864624024,10000.0,116342.17555689812,120449.99757409096,116342.17555689812,4078.7688570022574,16.454465627670288,0.0 -345800,4.2068405,0.60626733,,,,,,,,,,,,,, -345900,4.524471,0.6039846,,,,,,,,,,,,,, -346000,4.654238,0.6414224,,,,,,,,,,,,,, -346100,4.755448,0.6019813,,,,,,,,,,,,,, -346200,4.351586,0.61160564,,,,,,,,,,,,,, -346300,4.464815,0.6622525,,,,,,,,,,,,,, -346400,5.037142,0.6522373,,,,,,,,,,,,,, -346500,4.5337815,0.6933881,,,,,,,,,,,,,, -346600,5.0218797,0.6631282,,,,,,,,,,,,,, -346700,4.7041497,0.64172167,,,,,,,,,,,,,, -346800,4.398041,0.6141578,,,,,,,,,,,,,, -346900,4.4770746,0.5940723,,,,,,,,,,,,,, -347000,4.493217,0.7037083,,,,,,,,,,,,,, -347100,4.871757,0.6176109,,,,,,,,,,,,,, -347200,4.442299,0.60283583,,,,,,,,,,,,,, -347252,,,0.9598612785339355,0.1464956998825073,0.7571199536323547,1.039595127105713,50000.0,0.6323000192642212,1.82083797454834,10000.0,116852.03779673576,120977.74806761742,116852.03779673576,4096.516710281372,16.538807153701782,0.0 -347300,4.171249,0.5709584,,,,,,,,,,,,,, -347400,4.335225,0.6219101,,,,,,,,,,,,,, -347500,4.7625976,0.63453436,,,,,,,,,,,,,, -347600,4.3402762,0.60749424,,,,,,,,,,,,,, -347700,4.938501,0.5961038,,,,,,,,,,,,,, -347800,4.288336,0.61488044,,,,,,,,,,,,,, -347900,4.594553,0.6022564,,,,,,,,,,,,,, -348000,4.522883,0.6150359,,,,,,,,,,,,,, -348100,4.2988634,0.6095344,,,,,,,,,,,,,, -348200,4.6004786,0.64659053,,,,,,,,,,,,,, -348300,3.8935637,0.58161604,,,,,,,,,,,,,, -348400,4.841458,0.64222795,,,,,,,,,,,,,, -348500,4.2762575,0.53549296,,,,,,,,,,,,,, -348600,4.9112053,0.5968303,,,,,,,,,,,,,, -348700,4.334794,0.62044287,,,,,,,,,,,,,, -348768,,,0.9602598547935486,0.1485514193773269,0.7569199800491333,1.0419752597808838,50000.0,0.631600022315979,1.823034405708313,10000.0,117362.03670334816,121504.86248326302,117362.03670334816,4113.484414339066,16.631054878234863,0.0 -348800,4.2896056,0.66185224,,,,,,,,,,,,,, -348900,4.6783895,0.6179728,,,,,,,,,,,,,, -349000,4.8606806,0.62490505,,,,,,,,,,,,,, -349100,4.833875,0.63157403,,,,,,,,,,,,,, -349200,4.073233,0.59987056,,,,,,,,,,,,,, -349300,5.028108,0.61425775,,,,,,,,,,,,,, -349400,4.403845,0.6540469,,,,,,,,,,,,,, -349500,4.6840625,0.6453134,,,,,,,,,,,,,, -349600,4.356418,0.6226367,,,,,,,,,,,,,, -349700,5.0686193,0.6641646,,,,,,,,,,,,,, -349800,4.7710004,0.5817877,,,,,,,,,,,,,, -349900,4.188629,0.6069121,,,,,,,,,,,,,, -350000,4.870026,0.6729245,,,,,,,,,,,,,, -350100,4.1438894,0.5203123,,,,,,,,,,,,,, -350200,4.535044,0.6271207,,,,,,,,,,,,,, -350286,,,0.961734652519226,0.1429440379142761,0.7569400072097778,1.0426743030548096,50000.0,0.6313000321388245,1.8241838216781616,10000.0,117872.1169886589,122031.98891568184,117872.1169886589,4130.390887975693,16.715901851654053,0.0 -350300,4.69938,0.6422405,,,,,,,,,,,,,, -350400,4.7531943,0.6203144,,,,,,,,,,,,,, -350500,4.5294085,0.62951297,,,,,,,,,,,,,, -350600,4.960956,0.62392956,,,,,,,,,,,,,, -350700,4.765121,0.6637659,,,,,,,,,,,,,, -350800,4.5491796,0.5816473,,,,,,,,,,,,,, -350900,4.850893,0.66480124,,,,,,,,,,,,,, -351000,4.86017,0.63393545,,,,,,,,,,,,,, -351100,4.533431,0.6369876,,,,,,,,,,,,,, -351200,3.9771373,0.56461346,,,,,,,,,,,,,, -351300,4.305889,0.63302267,,,,,,,,,,,,,, -351400,4.543263,0.6114792,,,,,,,,,,,,,, -351500,4.278941,0.594007,,,,,,,,,,,,,, -351600,4.4860744,0.65227497,,,,,,,,,,,,,, -351700,4.2357855,0.5877904,,,,,,,,,,,,,, -351800,4.4430947,0.67034775,,,,,,,,,,,,,, -351803,,,0.9614955186843872,0.1441121399402618,0.7574399709701538,1.0409822463989258,50000.0,0.6314000487327576,1.822781562805176,10000.0,118382.07088541985,122559.0689971447,118382.07088541985,4147.371191740036,16.80562400817871,0.0 -351900,4.3633265,0.6335832,,,,,,,,,,,,,, -352000,4.395987,0.6511428,,,,,,,,,,,,,, -352100,4.4382744,0.6363277,,,,,,,,,,,,,, -352200,4.3681784,0.6070293,,,,,,,,,,,,,, -352300,4.889971,0.6414789,,,,,,,,,,,,,, -352400,4.5611615,0.6101405,,,,,,,,,,,,,, -352500,4.885732,0.65190095,,,,,,,,,,,,,, -352600,4.3841424,0.58648497,,,,,,,,,,,,,, -352700,4.4889874,0.59077436,,,,,,,,,,,,,, -352800,4.972261,0.6295319,,,,,,,,,,,,,, -352900,4.215471,0.632446,,,,,,,,,,,,,, -353000,4.505055,0.69230145,,,,,,,,,,,,,, -353100,4.4134345,0.67023075,,,,,,,,,,,,,, -353200,4.539142,0.6147026,,,,,,,,,,,,,, -353300,4.6045184,0.62017,,,,,,,,,,,,,, -353320,,,0.9614556431770324,0.1467177718877792,0.7574399709701538,1.0406513214111328,50000.0,0.6314000487327576,1.821877241134644,10000.0,118892.1052172184,123086.36362028122,118892.1052172184,4164.485343694687,16.895461559295654,0.0 -353400,4.5118456,0.6110254,,,,,,,,,,,,,, -353500,4.1760573,0.55879205,,,,,,,,,,,,,, -353600,4.7824903,0.65387005,,,,,,,,,,,,,, -353700,4.76049,0.66110563,,,,,,,,,,,,,, -353800,5.042887,0.69748646,,,,,,,,,,,,,, -353900,4.7598934,0.6252874,,,,,,,,,,,,,, -354000,4.404972,0.65205204,,,,,,,,,,,,,, -354100,4.586202,0.6140688,,,,,,,,,,,,,, -354200,4.290415,0.5809845,,,,,,,,,,,,,, -354300,4.9221005,0.65755683,,,,,,,,,,,,,, -354400,4.8062634,0.65829563,,,,,,,,,,,,,, -354500,4.7035394,0.67848897,,,,,,,,,,,,,, -354600,4.213825,0.5949304,,,,,,,,,,,,,, -354700,4.4244537,0.5863574,,,,,,,,,,,,,, -354800,4.639324,0.6735785,,,,,,,,,,,,,, -354837,,,0.9608577489852904,0.146235704421997,0.7569599747657776,1.042207956314087,50000.0,0.6315000057220459,1.8235318660736084,10000.0,119402.13184118272,123613.46213245392,119402.13184118272,4181.410618782044,16.984424591064453,0.0 -354900,4.619419,0.55241513,,,,,,,,,,,,,, -355000,4.4112697,0.58730733,,,,,,,,,,,,,, -355100,4.7784595,0.66785777,,,,,,,,,,,,,, -355200,4.806773,0.58899564,,,,,,,,,,,,,, -355300,4.8636875,0.6145062,,,,,,,,,,,,,, -355400,4.2046156,0.57088083,,,,,,,,,,,,,, -355500,4.4657993,0.6662137,,,,,,,,,,,,,, -355600,4.373284,0.5910507,,,,,,,,,,,,,, -355700,4.3981905,0.58320737,,,,,,,,,,,,,, -355800,4.3632197,0.57867104,,,,,,,,,,,,,, -355900,4.835887,0.6500525,,,,,,,,,,,,,, -356000,4.8418527,0.6764602,,,,,,,,,,,,,, -356100,4.689704,0.5802652,,,,,,,,,,,,,, -356200,4.2934437,0.5367022,,,,,,,,,,,,,, -356300,5.045532,0.66652024,,,,,,,,,,,,,, -356354,,,0.9602997303009032,0.1481306254863739,0.7575599551200867,1.041216492652893,50000.0,0.6325000524520874,1.8233606815338133,10000.0,119912.02209234238,124140.33087706566,119912.02209234238,4198.244757652283,17.072447061538696,0.0 -356400,4.7323794,0.5683443,,,,,,,,,,,,,, -356500,5.0464544,0.5888346,,,,,,,,,,,,,, -356600,4.4791684,0.6252127,,,,,,,,,,,,,, -356700,4.6835103,0.6063509,,,,,,,,,,,,,, -356800,4.3285613,0.62199795,,,,,,,,,,,,,, -356900,4.380378,0.6438398,,,,,,,,,,,,,, -357000,4.677824,0.6358334,,,,,,,,,,,,,, -357100,4.5645986,0.58981407,,,,,,,,,,,,,, -357200,4.272852,0.58719903,,,,,,,,,,,,,, -357300,4.405766,0.6211109,,,,,,,,,,,,,, -357400,4.361452,0.55263317,,,,,,,,,,,,,, -357500,4.2773824,0.6097053,,,,,,,,,,,,,, -357600,4.168529,0.6117792,,,,,,,,,,,,,, -357700,4.2105455,0.6006726,,,,,,,,,,,,,, -357800,4.3203382,0.62350976,,,,,,,,,,,,,, -357871,,,0.9605388641357422,0.1478061825037002,0.7572000026702881,1.0402666330337524,50000.0,0.6318000555038452,1.8204973936080933,10000.0,120421.88165020944,124667.38286662102,120421.88165020944,4215.293922185898,17.160407543182373,0.0 -357900,4.46663,0.667486,,,,,,,,,,,,,, -358000,4.353759,0.5928145,,,,,,,,,,,,,, -358100,4.0868998,0.56005627,,,,,,,,,,,,,, -358200,4.5556564,0.63830763,,,,,,,,,,,,,, -358300,4.7577477,0.63742733,,,,,,,,,,,,,, -358400,5.016017,0.6468506,,,,,,,,,,,,,, -358500,4.5660214,0.6383529,,,,,,,,,,,,,, -358600,4.481777,0.6174301,,,,,,,,,,,,,, -358700,4.7851253,0.6792671,,,,,,,,,,,,,, -358800,5.03504,0.7308649,,,,,,,,,,,,,, -358900,4.523811,0.6366883,,,,,,,,,,,,,, -359000,4.2623105,0.60631305,,,,,,,,,,,,,, -359100,4.2315335,0.6166744,,,,,,,,,,,,,, -359200,4.346892,0.57734674,,,,,,,,,,,,,, -359300,4.1823225,0.6189065,,,,,,,,,,,,,, -359388,,,0.9615951776504515,0.1428749710321426,0.7572000026702881,1.0410480499267578,50000.0,0.6321000456809998,1.822807669639588,10000.0,120931.80952954292,125194.51679444312,120931.80952954292,4232.358246326447,17.2476224899292,0.0 -359400,4.5323763,0.6419382,,,,,,,,,,,,,, -359500,4.1575484,0.610188,,,,,,,,,,,,,, -359600,4.264355,0.61853284,,,,,,,,,,,,,, -359700,4.818313,0.67988086,,,,,,,,,,,,,, -359800,4.530746,0.61306924,,,,,,,,,,,,,, -359900,4.472242,0.6436572,,,,,,,,,,,,,, -360000,4.6895046,0.6201012,,,,,,,,,,,,,, -360100,4.630631,0.58393425,,,,,,,,,,,,,, -360200,4.5397944,0.6148894,,,,,,,,,,,,,, -360300,4.265014,0.608763,,,,,,,,,,,,,, -360400,3.9177468,0.55989885,,,,,,,,,,,,,, -360500,5.0129685,0.6931766,,,,,,,,,,,,,, -360600,4.4489684,0.61754876,,,,,,,,,,,,,, -360700,5.125072,0.66984296,,,,,,,,,,,,,, -360800,4.3540444,0.63653284,,,,,,,,,,,,,, -360900,3.8897405,0.61181396,,,,,,,,,,,,,, -360905,,,0.9600805044174194,0.1505532264709472,0.7570799589157104,1.0411326885223389,50000.0,0.6319000124931335,1.8209383487701416,10000.0,121441.90519046783,125721.65953350069,121441.90519046783,4249.258747339249,17.337990045547485,0.0 -361000,4.3305516,0.65651596,,,,,,,,,,,,,, -361100,4.400028,0.68366283,,,,,,,,,,,,,, -361200,4.365859,0.61637974,,,,,,,,,,,,,, -361300,5.1896815,0.6475028,,,,,,,,,,,,,, -361400,4.6734486,0.6198004,,,,,,,,,,,,,, -361500,4.350051,0.58808374,,,,,,,,,,,,,, -361600,4.3047223,0.5562439,,,,,,,,,,,,,, -361700,5.005555,0.59464514,,,,,,,,,,,,,, -361800,4.2967257,0.55233854,,,,,,,,,,,,,, -361900,4.635618,0.63862723,,,,,,,,,,,,,, -362000,4.41037,0.6573347,,,,,,,,,,,,,, -362100,4.5332766,0.5651215,,,,,,,,,,,,,, -362200,4.244722,0.59678864,,,,,,,,,,,,,, -362300,4.429346,0.7128808,,,,,,,,,,,,,, -362400,5.3229885,0.6656412,,,,,,,,,,,,,, -362421,,,0.9601004123687744,0.1472820490598678,0.7572000026702881,1.0414822101593018,50000.0,0.6318000555038452,1.8226184844970703,10000.0,121951.9643895626,126248.82946372032,121951.9643895626,4266.226603746414,17.42392110824585,0.0 -362500,4.5538387,0.64789504,,,,,,,,,,,,,, -362600,4.59453,0.6532036,,,,,,,,,,,,,, -362700,4.6052117,0.61818117,,,,,,,,,,,,,, -362800,4.7835927,0.6036982,,,,,,,,,,,,,, -362900,4.6716313,0.649971,,,,,,,,,,,,,, -363000,4.5377893,0.63495654,,,,,,,,,,,,,, -363100,4.573391,0.6107859,,,,,,,,,,,,,, -363200,4.9450045,0.6245955,,,,,,,,,,,,,, -363300,4.580547,0.6501857,,,,,,,,,,,,,, -363400,4.8868904,0.6340623,,,,,,,,,,,,,, -363500,3.8560474,0.5249023,,,,,,,,,,,,,, -363600,4.729967,0.6142442,,,,,,,,,,,,,, -363700,4.6163716,0.6523628,,,,,,,,,,,,,, -363800,4.2542787,0.62095577,,,,,,,,,,,,,, -363900,4.4048414,0.61779195,,,,,,,,,,,,,, -363938,,,0.9604990482330322,0.1480488926172256,0.7569999694824219,1.0426230430603027,50000.0,0.6324000358581543,1.823939323425293,10000.0,122462.1295595169,126775.99683475494,122462.1295595169,4283.07874751091,17.518263339996338,0.0 -364000,4.752924,0.70324624,,,,,,,,,,,,,, -364100,4.5790954,0.683509,,,,,,,,,,,,,, -364200,5.118109,0.63316774,,,,,,,,,,,,,, -364300,4.4086328,0.6366161,,,,,,,,,,,,,, -364400,4.6056433,0.65573907,,,,,,,,,,,,,, -364500,4.72945,0.5986855,,,,,,,,,,,,,, -364600,4.4872403,0.63443935,,,,,,,,,,,,,, -364700,4.4271226,0.5941959,,,,,,,,,,,,,, -364800,4.3437657,0.65089196,,,,,,,,,,,,,, -364900,4.650759,0.7414531,,,,,,,,,,,,,, -365000,4.573879,0.588309,,,,,,,,,,,,,, -365100,4.462326,0.65094143,,,,,,,,,,,,,, -365200,4.212694,0.62568754,,,,,,,,,,,,,, -365300,4.496864,0.6583023,,,,,,,,,,,,,, -365400,4.81785,0.6449404,,,,,,,,,,,,,, -365455,,,0.9614756107330322,0.1442434042692184,0.7568599581718445,1.0414214134216309,50000.0,0.6309000253677368,1.822780728340149,10000.0,122972.127712965,127303.0500793457,122972.127712965,4299.988889694214,17.606799125671387,0.0 -365500,4.5118947,0.605001,,,,,,,,,,,,,, -365600,5.046798,0.68943805,,,,,,,,,,,,,, -365700,4.52579,0.6177137,,,,,,,,,,,,,, -365800,4.2851286,0.6028459,,,,,,,,,,,,,, -365900,4.4912715,0.56713504,,,,,,,,,,,,,, -366000,4.7030377,0.6628494,,,,,,,,,,,,,, -366100,4.824654,0.60126007,,,,,,,,,,,,,, -366200,4.859527,0.6406085,,,,,,,,,,,,,, -366300,4.2391915,0.6048421,,,,,,,,,,,,,, -366400,4.616652,0.6157515,,,,,,,,,,,,,, -366500,4.5872855,0.6165942,,,,,,,,,,,,,, -366600,4.5148916,0.729275,,,,,,,,,,,,,, -366700,4.278248,0.61246157,,,,,,,,,,,,,, -366800,4.136613,0.568491,,,,,,,,,,,,,, -366900,4.896882,0.6791549,,,,,,,,,,,,,, -366972,,,0.960758090019226,0.1481673419475555,0.7574399709701538,1.0411081314086914,50000.0,0.6313000321388245,1.823107481002808,10000.0,123482.26363253592,127830.26229286194,123482.26363253592,4316.926568746567,17.68900442123413,0.0 -367000,4.315459,0.6594577,,,,,,,,,,,,,, -367100,4.8079143,0.6786591,,,,,,,,,,,,,, -367200,4.374333,0.6156193,,,,,,,,,,,,,, -367300,4.199739,0.53450406,,,,,,,,,,,,,, -367400,4.3978686,0.6326462,,,,,,,,,,,,,, -367500,4.43926,0.62613374,,,,,,,,,,,,,, -367600,4.034327,0.54343665,,,,,,,,,,,,,, -367700,5.0082097,0.5999256,,,,,,,,,,,,,, -367800,4.1311717,0.6177944,,,,,,,,,,,,,, -367900,4.494947,0.6424342,,,,,,,,,,,,,, -368000,4.6659594,0.62418294,,,,,,,,,,,,,, -368100,4.493065,0.6005312,,,,,,,,,,,,,, -368200,4.7937827,0.7043301,,,,,,,,,,,,,, -368300,4.4434,0.59257376,,,,,,,,,,,,,, -368400,4.411141,0.65077543,,,,,,,,,,,,,, -368489,,,0.9606983065605164,0.1471341699361801,0.7573999762535095,1.0416234731674194,50000.0,0.631600022315979,1.8244590759277344,10000.0,123992.4189324379,128357.54884719849,123992.4189324379,4333.906269550324,17.783448457717896,0.0 -368500,4.8583927,0.60862863,,,,,,,,,,,,,, -368600,4.090558,0.6138004,,,,,,,,,,,,,, -368700,4.401532,0.56308466,,,,,,,,,,,,,, -368800,4.243534,0.58744395,,,,,,,,,,,,,, -368900,4.4837747,0.6280183,,,,,,,,,,,,,, -369000,3.941592,0.58888316,,,,,,,,,,,,,, -369100,4.0983353,0.5941891,,,,,,,,,,,,,, -369200,4.8635173,0.62352234,,,,,,,,,,,,,, -369300,4.644632,0.6480631,,,,,,,,,,,,,, -369400,4.480322,0.60288167,,,,,,,,,,,,,, -369500,4.3161263,0.6403157,,,,,,,,,,,,,, -369600,4.325293,0.6133744,,,,,,,,,,,,,, -369700,4.044986,0.5369192,,,,,,,,,,,,,, -369800,4.5902147,0.6543213,,,,,,,,,,,,,, -369900,4.3956265,0.6908363,,,,,,,,,,,,,, -370000,4.317554,0.5838169,,,,,,,,,,,,,, -370006,,,0.960359513759613,0.148869127035141,0.7570399641990662,1.040926814079285,50000.0,0.6321000456809998,1.8226678371429443,10000.0,124502.4652018547,128884.87219500542,124502.4652018547,4351.035741329193,17.875396251678467,0.0 -370100,4.4368825,0.62657857,,,,,,,,,,,,,, -370200,4.5426273,0.5876689,,,,,,,,,,,,,, -370300,4.657954,0.6728895,,,,,,,,,,,,,, -370400,4.3251076,0.6068217,,,,,,,,,,,,,, -370500,5.113296,0.6745418,,,,,,,,,,,,,, -370600,4.457405,0.62691325,,,,,,,,,,,,,, -370700,4.2493176,0.57566124,,,,,,,,,,,,,, -370800,4.308389,0.63374066,,,,,,,,,,,,,, -370900,4.5252695,0.6760158,,,,,,,,,,,,,, -371000,4.2948604,0.5561825,,,,,,,,,,,,,, -371100,4.4102488,0.6127819,,,,,,,,,,,,,, -371200,4.117798,0.5337275,,,,,,,,,,,,,, -371300,5.037225,0.713624,,,,,,,,,,,,,, -371400,4.3616347,0.6309586,,,,,,,,,,,,,, -371500,4.37374,0.5872676,,,,,,,,,,,,,, -371523,,,0.9600406289100648,0.1471908539533615,0.7571399807929993,1.0407812595367432,50000.0,0.6314000487327576,1.8242459297180176,10000.0,125012.54879522324,129411.96352028848,125012.54879522324,4367.901885032654,17.96213459968567,0.0 -371600,4.4552994,0.6224986,,,,,,,,,,,,,, -371700,4.486222,0.66424716,,,,,,,,,,,,,, -371800,4.3154655,0.62220407,,,,,,,,,,,,,, -371900,4.0823545,0.56586224,,,,,,,,,,,,,, -372000,4.212496,0.6078763,,,,,,,,,,,,,, -372100,5.5083833,0.6840428,,,,,,,,,,,,,, -372200,4.7180605,0.5868959,,,,,,,,,,,,,, -372300,4.144439,0.6213695,,,,,,,,,,,,,, -372400,4.1458006,0.56420904,,,,,,,,,,,,,, -372500,4.6699586,0.60627246,,,,,,,,,,,,,, -372600,4.4611864,0.599145,,,,,,,,,,,,,, -372700,4.1387143,0.51352394,,,,,,,,,,,,,, -372800,4.566201,0.5803032,,,,,,,,,,,,,, -372900,4.2977524,0.66824305,,,,,,,,,,,,,, -373000,5.791295,0.7223254,,,,,,,,,,,,,, -373039,,,0.9602399468421936,0.1467160880565643,0.7579599618911743,1.0410587787628174,50000.0,0.6312000155448914,1.8227859735488887,10000.0,125522.60182905196,129939.24783706664,125522.60182905196,4384.982523202896,18.056538581848145,0.0 -373100,4.5864706,0.66519773,,,,,,,,,,,,,, -373200,4.572773,0.65044683,,,,,,,,,,,,,, -373300,4.2448955,0.61065036,,,,,,,,,,,,,, -373400,3.977478,0.55899084,,,,,,,,,,,,,, -373500,4.393499,0.5686325,,,,,,,,,,,,,, -373600,4.934947,0.60066986,,,,,,,,,,,,,, -373700,4.585904,0.6763009,,,,,,,,,,,,,, -373800,4.75696,0.63301945,,,,,,,,,,,,,, -373900,5.0611854,0.6734755,,,,,,,,,,,,,, -374000,4.4344826,0.63095903,,,,,,,,,,,,,, -374100,4.3368325,0.5903427,,,,,,,,,,,,,, -374200,4.4855037,0.63946074,,,,,,,,,,,,,, -374300,4.440218,0.6873616,,,,,,,,,,,,,, -374400,4.306942,0.60045457,,,,,,,,,,,,,, -374500,4.488383,0.6621393,,,,,,,,,,,,,, -374556,,,0.9602000713348388,0.1454567462205886,0.7567200064659119,1.041068434715271,50000.0,0.6315000057220459,1.823155164718628,10000.0,126032.49347376823,130466.12568640707,126032.49347376823,4401.814165115356,18.15566897392273,0.0 -374600,4.599257,0.6216437,,,,,,,,,,,,,, -374700,4.402191,0.6137879,,,,,,,,,,,,,, -374800,4.5790944,0.72702104,,,,,,,,,,,,,, -374900,4.6814427,0.6851376,,,,,,,,,,,,,, -375000,4.153455,0.5627458,,,,,,,,,,,,,, -375100,4.6312017,0.6623045,,,,,,,,,,,,,, -375200,4.497569,0.6389957,,,,,,,,,,,,,, -375300,4.3900285,0.56813556,,,,,,,,,,,,,, -375400,4.2172112,0.6243378,,,,,,,,,,,,,, -375500,4.682629,0.56660885,,,,,,,,,,,,,, -375600,4.829431,0.6826647,,,,,,,,,,,,,, -375700,4.416032,0.62107176,,,,,,,,,,,,,, -375800,4.576771,0.6739544,,,,,,,,,,,,,, -375900,4.8640723,0.6640359,,,,,,,,,,,,,, -376000,4.307876,0.60780185,,,,,,,,,,,,,, -376073,,,0.9597616195678712,0.1479842811822891,0.7572000026702881,1.0399235486984253,50000.0,0.6321000456809998,1.8223358392715447,10000.0,126542.64698553084,130993.41924715042,126542.64698553084,4418.809207677841,18.24462342262268,0.0 -376100,4.450521,0.7095997,,,,,,,,,,,,,, -376200,4.4876657,0.5947043,,,,,,,,,,,,,, -376300,4.2453856,0.5869652,,,,,,,,,,,,,, -376400,4.3959475,0.6006209,,,,,,,,,,,,,, -376500,4.8584294,0.6775361,,,,,,,,,,,,,, -376600,4.83894,0.63306713,,,,,,,,,,,,,, -376700,4.5343223,0.70236343,,,,,,,,,,,,,, -376800,4.89268,0.6132185,,,,,,,,,,,,,, -376900,4.5185905,0.6334641,,,,,,,,,,,,,, -377000,4.879137,0.63605666,,,,,,,,,,,,,, -377100,4.383447,0.6479551,,,,,,,,,,,,,, -377200,4.8797145,0.65128756,,,,,,,,,,,,,, -377300,4.283892,0.6888142,,,,,,,,,,,,,, -377400,5.4704137,0.6349887,,,,,,,,,,,,,, -377500,4.166795,0.6114903,,,,,,,,,,,,,, -377591,,,0.9620535373687744,0.1455946564674377,0.7574799656867981,1.0399705171585083,50000.0,0.631600022315979,1.8216814994812007,10000.0,127052.61329627036,131520.42503738403,127052.61329627036,4435.703264951706,18.334900856018063,0.0 -377600,4.3329897,0.5466653,,,,,,,,,,,,,, -377700,4.502146,0.6287033,,,,,,,,,,,,,, -377800,4.5332274,0.61679333,,,,,,,,,,,,,, -377900,4.5719786,0.6548628,,,,,,,,,,,,,, -378000,4.380725,0.5975162,,,,,,,,,,,,,, -378100,4.0713725,0.6241452,,,,,,,,,,,,,, -378200,4.82547,0.60763884,,,,,,,,,,,,,, -378300,4.8143682,0.64681655,,,,,,,,,,,,,, -378400,4.519437,0.6587871,,,,,,,,,,,,,, -378500,4.7188487,0.6586994,,,,,,,,,,,,,, -378600,4.1275787,0.5924249,,,,,,,,,,,,,, -378700,4.2345705,0.6036354,,,,,,,,,,,,,, -378800,4.6379476,0.63242066,,,,,,,,,,,,,, -378900,4.4660935,0.6410572,,,,,,,,,,,,,, -379000,4.539423,0.6986047,,,,,,,,,,,,,, -379100,4.3956747,0.6065575,,,,,,,,,,,,,, -379106,,,0.9601004123687744,0.1481819748878479,0.756879985332489,1.0412617921829224,50000.0,0.6313000321388245,1.822556495666504,10000.0,127562.5142903328,132047.43829274178,127562.5142903328,4452.67462849617,18.42171573638916,0.0 -379200,4.522868,0.62946475,,,,,,,,,,,,,, -379300,4.4858875,0.7064362,,,,,,,,,,,,,, -379400,4.408085,0.62376446,,,,,,,,,,,,,, -379500,4.391274,0.68006694,,,,,,,,,,,,,, -379600,4.8360286,0.69840026,,,,,,,,,,,,,, -379700,4.293373,0.6162896,,,,,,,,,,,,,, -379800,4.5400324,0.6432084,,,,,,,,,,,,,, -379900,4.832267,0.6481722,,,,,,,,,,,,,, -380000,4.176577,0.6144922,,,,,,,,,,,,,, -380100,4.3984365,0.62717,,,,,,,,,,,,,, -380200,4.5639377,0.6217516,,,,,,,,,,,,,, -380300,4.5305533,0.6566669,,,,,,,,,,,,,, -380400,4.8152366,0.56956565,,,,,,,,,,,,,, -380500,4.604549,0.641684,,,,,,,,,,,,,, -380600,4.438777,0.6519508,,,,,,,,,,,,,, -380623,,,0.960758090019226,0.1461323052644729,0.7569599747657776,1.0406873226165771,50000.0,0.6308000087738037,1.8233776092529297,10000.0,128072.50441098212,132574.43628549576,128072.50441098212,4469.535237550736,18.51426196098328,0.0 -380700,4.3255067,0.63561696,,,,,,,,,,,,,, -380800,4.9689636,0.64109564,,,,,,,,,,,,,, -380900,4.721635,0.6721432,,,,,,,,,,,,,, -381000,4.500946,0.67026526,,,,,,,,,,,,,, -381100,4.08017,0.57037103,,,,,,,,,,,,,, -381200,4.4376736,0.598251,,,,,,,,,,,,,, -381300,4.690576,0.5791743,,,,,,,,,,,,,, -381400,4.3832626,0.6223623,,,,,,,,,,,,,, -381500,4.2855077,0.57195365,,,,,,,,,,,,,, -381600,4.454084,0.5991731,,,,,,,,,,,,,, -381700,4.1548386,0.5610816,,,,,,,,,,,,,, -381800,4.173177,0.57365966,,,,,,,,,,,,,, -381900,4.2142158,0.5962018,,,,,,,,,,,,,, -382000,4.8456163,0.62744564,,,,,,,,,,,,,, -382100,5.257357,0.57501566,,,,,,,,,,,,,, -382139,,,0.9616549611091614,0.1419612616300583,0.7573399543762207,1.040238618850708,50000.0,0.6324000358581543,1.8214529752731323,10000.0,128582.37007761002,133101.20618534088,128582.37007761002,4486.295618534088,18.60328149795532,0.0 -382200,5.0038157,0.6740377,,,,,,,,,,,,,, -382300,4.7619653,0.6858344,,,,,,,,,,,,,, -382400,4.6896605,0.70347095,,,,,,,,,,,,,, -382500,4.090133,0.58536476,,,,,,,,,,,,,, -382600,4.925391,0.6405778,,,,,,,,,,,,,, -382700,5.0712037,0.639723,,,,,,,,,,,,,, -382800,4.2112393,0.62614036,,,,,,,,,,,,,, -382900,4.501566,0.7385887,,,,,,,,,,,,,, -383000,4.41572,0.5988623,,,,,,,,,,,,,, -383100,4.934779,0.6799033,,,,,,,,,,,,,, -383200,4.4572725,0.6139265,,,,,,,,,,,,,, -383300,5.054582,0.68820226,,,,,,,,,,,,,, -383400,4.3300433,0.61963725,,,,,,,,,,,,,, -383500,5.226134,0.66398954,,,,,,,,,,,,,, -383600,4.4365883,0.5722884,,,,,,,,,,,,,, -383655,,,0.962511956691742,0.1423533260822296,0.7570799589157104,1.0416830778121948,50000.0,0.6321000456809998,1.824738621711731,10000.0,129092.24616885184,133628.38803076744,129092.24616885184,4503.450261116028,18.69795846939087,0.0 -383700,4.6574583,0.6571918,,,,,,,,,,,,,, -383800,4.398022,0.5907621,,,,,,,,,,,,,, -383900,4.6989913,0.67819315,,,,,,,,,,,,,, -384000,4.7285104,0.5945168,,,,,,,,,,,,,, -384100,4.6545305,0.61661494,,,,,,,,,,,,,, -384200,4.4232483,0.7119055,,,,,,,,,,,,,, -384300,4.8072276,0.7156792,,,,,,,,,,,,,, -384400,4.5346894,0.6006015,,,,,,,,,,,,,, -384500,4.443144,0.6327584,,,,,,,,,,,,,, -384600,4.9441085,0.64602846,,,,,,,,,,,,,, -384700,4.394858,0.6838817,,,,,,,,,,,,,, -384800,4.6139874,0.6577077,,,,,,,,,,,,,, -384900,4.5857635,0.6260532,,,,,,,,,,,,,, -385000,4.626877,0.69715124,,,,,,,,,,,,,, -385100,4.4273825,0.6399307,,,,,,,,,,,,,, -385172,,,0.9605189561843872,0.1476364582777023,0.757319986820221,1.0413920879364014,50000.0,0.6325000524520874,1.8220690488815308,10000.0,129602.22697257996,134155.97039675713,129602.22697257996,4520.90341758728,18.79100012779236,0.0 -385200,4.5768123,0.6885208,,,,,,,,,,,,,, -385300,4.618584,0.6079851,,,,,,,,,,,,,, -385400,4.2052355,0.55068815,,,,,,,,,,,,,, -385500,4.49597,0.66994596,,,,,,,,,,,,,, -385600,4.538528,0.59336215,,,,,,,,,,,,,, -385700,4.95981,0.69527745,,,,,,,,,,,,,, -385800,4.3069553,0.64644134,,,,,,,,,,,,,, -385900,4.838423,0.61343557,,,,,,,,,,,,,, -386000,4.2096977,0.63540673,,,,,,,,,,,,,, -386100,4.6183033,0.65915346,,,,,,,,,,,,,, -386200,5.26634,0.6396552,,,,,,,,,,,,,, -386300,4.1439543,0.58743376,,,,,,,,,,,,,, -386400,4.323474,0.63813734,,,,,,,,,,,,,, -386500,4.63545,0.648616,,,,,,,,,,,,,, -386600,4.154408,0.60435957,,,,,,,,,,,,,, -386689,,,0.9601402878761292,0.1475104093551635,0.7570799589157104,1.0403294563293457,50000.0,0.6318000555038452,1.8210442066192627,10000.0,130112.18139767648,134682.88991069794,130112.18139767648,4537.72211432457,18.88157558441162,0.0 -386700,4.337417,0.59789723,,,,,,,,,,,,,, -386800,4.555255,0.6326979,,,,,,,,,,,,,, -386900,4.485273,0.6136484,,,,,,,,,,,,,, -387000,4.6245794,0.61841434,,,,,,,,,,,,,, -387100,4.562454,0.6666015,,,,,,,,,,,,,, -387200,4.285933,0.626669,,,,,,,,,,,,,, -387300,4.499692,0.58805156,,,,,,,,,,,,,, -387400,4.2750454,0.6319996,,,,,,,,,,,,,, -387500,5.0122013,0.6234054,,,,,,,,,,,,,, -387600,4.788638,0.62211066,,,,,,,,,,,,,, -387700,3.857637,0.55259764,,,,,,,,,,,,,, -387800,5.073252,0.65936095,,,,,,,,,,,,,, -387900,5.0323434,0.63683474,,,,,,,,,,,,,, -388000,4.3072705,0.5616273,,,,,,,,,,,,,, -388100,4.2811494,0.6217321,,,,,,,,,,,,,, -388200,4.0588236,0.5581053,,,,,,,,,,,,,, -388206,,,0.9606385231018066,0.1463115364313125,0.7573999762535095,1.0419403314590454,50000.0,0.6315000057220459,1.8237435817718504,10000.0,130622.19213318823,135210.0751516819,130622.19213318823,4554.750309705734,18.971620559692383,0.0 -388300,4.4344015,0.6109204,,,,,,,,,,,,,, -388400,4.283226,0.586763,,,,,,,,,,,,,, -388500,4.228408,0.6250955,,,,,,,,,,,,,, -388600,4.262255,0.6226857,,,,,,,,,,,,,, -388700,4.167828,0.600053,,,,,,,,,,,,,, -388800,4.20775,0.554448,,,,,,,,,,,,,, -388900,4.05748,0.5472436,,,,,,,,,,,,,, -389000,4.4771085,0.6585233,,,,,,,,,,,,,, -389100,4.828597,0.652529,,,,,,,,,,,,,, -389200,4.330093,0.6027756,,,,,,,,,,,,,, -389300,4.7730136,0.6244598,,,,,,,,,,,,,, -389400,4.2124295,0.6451759,,,,,,,,,,,,,, -389500,4.6260934,0.6333563,,,,,,,,,,,,,, -389600,4.6677623,0.64655787,,,,,,,,,,,,,, -389700,4.698163,0.62351304,,,,,,,,,,,,,, -389723,,,0.9618542790412904,0.1434044390916824,0.7575399875640869,1.0409173965454102,50000.0,0.6317000389099121,1.8231186866760247,10000.0,131132.10186076164,135737.15456604958,131132.10186076164,4571.768753767014,19.06741619110108,0.0 -389800,4.628718,0.6053753,,,,,,,,,,,,,, -389900,4.2241874,0.60294294,,,,,,,,,,,,,, -390000,5.014693,0.6385417,,,,,,,,,,,,,, -390100,4.277373,0.5521721,,,,,,,,,,,,,, -390200,4.1614666,0.6084703,,,,,,,,,,,,,, -390300,5.0199256,0.6495204,,,,,,,,,,,,,, -390400,4.506463,0.6733204,,,,,,,,,,,,,, -390500,5.0124254,0.6360622,,,,,,,,,,,,,, -390600,4.597241,0.6081548,,,,,,,,,,,,,, -390700,4.742088,0.6537264,,,,,,,,,,,,,, -390800,4.6241274,0.6636894,,,,,,,,,,,,,, -390900,4.1273794,0.5932379,,,,,,,,,,,,,, -391000,4.849157,0.6229763,,,,,,,,,,,,,, -391100,4.068198,0.6382474,,,,,,,,,,,,,, -391200,4.444214,0.6940107,,,,,,,,,,,,,, -391240,,,0.961734652519226,0.1444765478372573,0.7570799589157104,1.0409315824508667,50000.0,0.6317000389099121,1.8221389055252075,10000.0,131642.28998470306,136264.5047211647,131642.28998470306,4588.780303239822,19.16048574447632,0.0 -391300,5.2622485,0.72063243,,,,,,,,,,,,,, -391400,5.5860972,0.68248355,,,,,,,,,,,,,, -391500,4.617274,0.6641795,,,,,,,,,,,,,, -391600,4.1108503,0.5675003,,,,,,,,,,,,,, -391700,4.1986413,0.5616381,,,,,,,,,,,,,, -391800,4.209088,0.57886714,,,,,,,,,,,,,, -391900,3.8299396,0.5609329,,,,,,,,,,,,,, -392000,4.3582034,0.6288401,,,,,,,,,,,,,, -392100,4.402908,0.6654306,,,,,,,,,,,,,, -392200,4.1013656,0.5917492,,,,,,,,,,,,,, -392300,4.2979527,0.6933437,,,,,,,,,,,,,, -392400,4.8131156,0.65679955,,,,,,,,,,,,,, -392500,4.2016907,0.54012144,,,,,,,,,,,,,, -392600,4.3381844,0.63686407,,,,,,,,,,,,,, -392700,4.5548153,0.6147005,,,,,,,,,,,,,, -392757,,,0.960558831691742,0.1481156349182129,0.7575199604034424,1.040520191192627,50000.0,0.6325000524520874,1.8216640949249268,10000.0,132152.3566787243,136791.69770216942,132152.3566787243,4605.751079559326,19.26059985160828,0.0 -392800,4.8243737,0.67299634,,,,,,,,,,,,,, -392900,4.9469666,0.64491606,,,,,,,,,,,,,, -393000,4.5438094,0.6525552,,,,,,,,,,,,,, -393100,4.483789,0.7286072,,,,,,,,,,,,,, -393200,4.5796027,0.6579404,,,,,,,,,,,,,, -393300,4.388486,0.68735504,,,,,,,,,,,,,, -393400,4.26287,0.56819737,,,,,,,,,,,,,, -393500,4.1708927,0.5551715,,,,,,,,,,,,,, -393600,4.818924,0.6389298,,,,,,,,,,,,,, -393700,4.3296103,0.64054036,,,,,,,,,,,,,, -393800,4.0744443,0.55822706,,,,,,,,,,,,,, -393900,4.911021,0.6965642,,,,,,,,,,,,,, -394000,4.478907,0.55650246,,,,,,,,,,,,,, -394100,4.24347,0.58632344,,,,,,,,,,,,,, -394200,4.0971637,0.63315797,,,,,,,,,,,,,, -394274,,,0.9601203799247742,0.1482566446065902,0.7576000094413757,1.0401118993759155,50000.0,0.6319000124931335,1.8207794427871704,10000.0,132662.2222611904,137318.67782998085,132662.2222611904,4622.717230796814,19.35450482368469,0.0 -394300,4.7278657,0.7046411,,,,,,,,,,,,,, -394400,4.629856,0.6254895,,,,,,,,,,,,,, -394500,4.800549,0.68317467,,,,,,,,,,,,,, -394600,4.9483085,0.61701554,,,,,,,,,,,,,, -394700,4.7390633,0.6335534,,,,,,,,,,,,,, -394800,4.53247,0.5681528,,,,,,,,,,,,,, -394900,4.733969,0.71829164,,,,,,,,,,,,,, -395000,4.5943995,0.60659593,,,,,,,,,,,,,, -395100,4.508483,0.67704964,,,,,,,,,,,,,, -395200,5.0692177,0.679639,,,,,,,,,,,,,, -395300,4.358581,0.60634995,,,,,,,,,,,,,, -395400,4.137524,0.582574,,,,,,,,,,,,,, -395500,4.664341,0.667737,,,,,,,,,,,,,, -395600,4.094206,0.54388016,,,,,,,,,,,,,, -395700,4.3742185,0.6109247,,,,,,,,,,,,,, -395792,,,0.9600805044174194,0.1480380147695541,0.7572199702262878,1.0424708127975464,50000.0,0.6313000321388245,1.8253854513168333,10000.0,133172.32787752151,137845.96401286125,133172.32787752151,4639.747121095657,19.45023393630981,0.0 -395800,4.9888606,0.74515843,,,,,,,,,,,,,, -395900,4.615099,0.64091337,,,,,,,,,,,,,, -396000,4.821796,0.6325322,,,,,,,,,,,,,, -396100,4.538992,0.61788315,,,,,,,,,,,,,, -396200,4.1685805,0.6068014,,,,,,,,,,,,,, -396300,4.700275,0.724684,,,,,,,,,,,,,, -396400,4.33158,0.6823514,,,,,,,,,,,,,, -396500,4.4544725,0.57903576,,,,,,,,,,,,,, -396600,4.19628,0.61793286,,,,,,,,,,,,,, -396700,4.955134,0.71720725,,,,,,,,,,,,,, -396800,4.7645054,0.6920551,,,,,,,,,,,,,, -396900,4.75876,0.5801294,,,,,,,,,,,,,, -397000,4.659948,0.5483172,,,,,,,,,,,,,, -397100,4.6145725,0.60048175,,,,,,,,,,,,,, -397200,4.576809,0.6100071,,,,,,,,,,,,,, -397300,4.868367,0.6973423,,,,,,,,,,,,,, -397308,,,0.9609972834587096,0.1444547176361084,0.7571600079536438,1.039499282836914,50000.0,0.6315000057220459,1.821471929550171,10000.0,133682.22514748573,138373.00830101967,133682.22514748573,4656.741573810577,19.547237634658813,0.0 -397400,4.849532,0.6663146,,,,,,,,,,,,,, -397500,4.5267334,0.6185535,,,,,,,,,,,,,, -397600,4.581249,0.67143625,,,,,,,,,,,,,, -397700,4.3559813,0.6849482,,,,,,,,,,,,,, -397800,4.593084,0.60775775,,,,,,,,,,,,,, -397900,4.1556334,0.56081456,,,,,,,,,,,,,, -398000,4.6769075,0.6058879,,,,,,,,,,,,,, -398100,5.0452495,0.71112233,,,,,,,,,,,,,, -398200,4.482971,0.6186341,,,,,,,,,,,,,, -398300,4.937167,0.67710185,,,,,,,,,,,,,, -398400,4.3627043,0.645105,,,,,,,,,,,,,, -398500,4.599324,0.6691511,,,,,,,,,,,,,, -398600,4.460217,0.65023375,,,,,,,,,,,,,, -398700,4.4482923,0.6339604,,,,,,,,,,,,,, -398800,5.53916,0.6171187,,,,,,,,,,,,,, -398824,,,0.9614756107330322,0.1464393436908722,0.757099986076355,1.0412877798080444,50000.0,0.6321000456809998,1.8227674961090088,10000.0,134192.11761021614,138899.8560230732,134192.11761021614,4673.543560028076,19.645186185836792,0.0 -398900,4.762585,0.61702937,,,,,,,,,,,,,, -399000,4.719602,0.6173166,,,,,,,,,,,,,, -399100,4.861806,0.5942258,,,,,,,,,,,,,, -399200,4.584905,0.67085326,,,,,,,,,,,,,, -399300,4.6117206,0.6314763,,,,,,,,,,,,,, -399400,4.8007145,0.5748669,,,,,,,,,,,,,, -399500,4.4167924,0.56072944,,,,,,,,,,,,,, -399600,4.1679964,0.6306199,,,,,,,,,,,,,, -399700,4.3905272,0.59180933,,,,,,,,,,,,,, -399800,4.454576,0.59597266,,,,,,,,,,,,,, -399900,4.6319227,0.69751394,,,,,,,,,,,,,, -400000,4.479931,0.60630995,,,,,,,,,,,,,, -400100,4.7060776,0.5864171,,,,,,,,,,,,,, -400200,3.9583528,0.55203545,,,,,,,,,,,,,, -400300,4.3891616,0.6660118,,,,,,,,,,,,,, -400342,,,0.9597616195678712,0.146850436925888,0.7565400004386902,1.041675686836243,50000.0,0.6321000456809998,1.8230339288711548,10000.0,134702.17586684227,139427.07520484924,134702.17586684227,4690.554224252701,19.73924994468689,0.0 -400400,4.609513,0.650505,,,,,,,,,,,,,, -400500,4.518191,0.5933119,,,,,,,,,,,,,, -400600,4.484685,0.5890745,,,,,,,,,,,,,, -400700,4.5514627,0.6031281,,,,,,,,,,,,,, -400800,4.395407,0.6315515,,,,,,,,,,,,,, -400900,4.8425713,0.6611134,,,,,,,,,,,,,, -401000,4.366607,0.6235882,,,,,,,,,,,,,, -401100,4.5072317,0.60519797,,,,,,,,,,,,,, -401200,4.4656887,0.60846484,,,,,,,,,,,,,, -401300,4.5148053,0.6888043,,,,,,,,,,,,,, -401400,5.289589,0.63960516,,,,,,,,,,,,,, -401500,4.183608,0.55190074,,,,,,,,,,,,,, -401600,4.342828,0.6234674,,,,,,,,,,,,,, -401700,4.604512,0.67007065,,,,,,,,,,,,,, -401800,4.591223,0.5542429,,,,,,,,,,,,,, -401858,,,0.9598811864852904,0.1492884010076522,0.7573800086975098,1.0411863327026367,50000.0,0.6323000192642212,1.822576642036438,10000.0,135212.1051683426,139954.13414263725,135212.1051683426,4707.534991264343,19.83228182792664,0.0 -401900,4.9228134,0.6736322,,,,,,,,,,,,,, -402000,4.3072324,0.65973,,,,,,,,,,,,,, -402100,4.3989644,0.68557644,,,,,,,,,,,,,, -402200,5.087876,0.6840099,,,,,,,,,,,,,, -402300,4.015583,0.57126856,,,,,,,,,,,,,, -402400,4.7856083,0.67489845,,,,,,,,,,,,,, -402500,4.490511,0.62892574,,,,,,,,,,,,,, -402600,4.71888,0.5927638,,,,,,,,,,,,,, -402700,4.5102315,0.6556836,,,,,,,,,,,,,, -402800,4.1024485,0.5485959,,,,,,,,,,,,,, -402900,4.7362084,0.66549397,,,,,,,,,,,,,, -403000,4.9396353,0.6272858,,,,,,,,,,,,,, -403100,4.594278,0.65395874,,,,,,,,,,,,,, -403200,4.350286,0.6596013,,,,,,,,,,,,,, -403300,4.4243455,0.62794805,,,,,,,,,,,,,, -403374,,,0.9605787396430968,0.1461165696382522,0.7571799755096436,1.0413422584533691,50000.0,0.6320000290870667,1.822820782661438,10000.0,135722.01432061195,140481.13131904602,135722.01432061195,4724.468670129776,19.93061709403992,0.0 -403400,4.6210775,0.7513386,,,,,,,,,,,,,, -403500,4.9470334,0.72274184,,,,,,,,,,,,,, -403600,4.4476132,0.60783374,,,,,,,,,,,,,, -403700,4.3332148,0.6077006,,,,,,,,,,,,,, -403800,4.364982,0.5954509,,,,,,,,,,,,,, -403900,4.3295956,0.56574714,,,,,,,,,,,,,, -404000,4.6302743,0.6151494,,,,,,,,,,,,,, -404100,4.152795,0.663234,,,,,,,,,,,,,, -404200,4.832422,0.6570678,,,,,,,,,,,,,, -404300,4.211018,0.59449387,,,,,,,,,,,,,, -404400,5.0593734,0.64428914,,,,,,,,,,,,,, -404500,4.3334446,0.6669374,,,,,,,,,,,,,, -404600,4.4466667,0.5508779,,,,,,,,,,,,,, -404700,4.581298,0.7004703,,,,,,,,,,,,,, -404800,4.129581,0.6327668,,,,,,,,,,,,,, -404890,,,0.9606983065605164,0.1485510170459747,0.7574799656867981,1.041606068611145,50000.0,0.6312000155448914,1.824077844619751,10000.0,136231.96271348,141008.20151233673,136231.96271348,4741.439030885696,20.02464318275452,0.0 -404900,4.7702985,0.684928,,,,,,,,,,,,,, -405000,4.560229,0.6310525,,,,,,,,,,,,,, -405100,4.264435,0.63682485,,,,,,,,,,,,,, -405200,4.190762,0.6179636,,,,,,,,,,,,,, -405300,4.5279245,0.6252593,,,,,,,,,,,,,, -405400,4.8674836,0.6651972,,,,,,,,,,,,,, -405500,4.531987,0.6776042,,,,,,,,,,,,,, -405600,4.3641925,0.59606284,,,,,,,,,,,,,, -405700,4.5390067,0.57655716,,,,,,,,,,,,,, -405800,4.414454,0.58898264,,,,,,,,,,,,,, -405900,4.2631598,0.6076148,,,,,,,,,,,,,, -406000,4.69759,0.6151281,,,,,,,,,,,,,, -406100,4.5341444,0.62199366,,,,,,,,,,,,,, -406200,4.4858384,0.65821946,,,,,,,,,,,,,, -406300,4.2143354,0.5991007,,,,,,,,,,,,,, -406400,4.5805373,0.6535996,,,,,,,,,,,,,, -406407,,,0.9603196382522584,0.1486114412546157,0.75764000415802,1.0409525632858276,50000.0,0.6323000192642212,1.821937799453736,10000.0,136741.83385276794,141535.09117531776,136741.83385276794,4758.309185504913,20.118609189987183,0.0 -406500,4.51691,0.6413732,,,,,,,,,,,,,, -406600,4.204929,0.5765129,,,,,,,,,,,,,, -406700,4.299323,0.5890434,,,,,,,,,,,,,, -406800,4.4797277,0.63142043,,,,,,,,,,,,,, -406900,4.6430297,0.60987896,,,,,,,,,,,,,, -407000,4.431691,0.6936404,,,,,,,,,,,,,, -407100,4.8285556,0.65235287,,,,,,,,,,,,,, -407200,4.3931613,0.6184243,,,,,,,,,,,,,, -407300,4.896209,0.64827234,,,,,,,,,,,,,, -407400,4.6013904,0.6681198,,,,,,,,,,,,,, -407500,3.9464495,0.57628584,,,,,,,,,,,,,, -407600,4.2110186,0.60720015,,,,,,,,,,,,,, -407700,4.6203203,0.61921424,,,,,,,,,,,,,, -407800,4.7418947,0.639111,,,,,,,,,,,,,, -407900,5.1671247,0.6848264,,,,,,,,,,,,,, -407924,,,0.960957407951355,0.1447140276432037,0.7571399807929993,1.0417368412017822,50000.0,0.6314000487327576,1.821706771850586,10000.0,137251.82893824577,142062.15008211136,137251.82893824577,4775.223704099655,20.21147418022156,0.0 -408000,4.1846485,0.5949844,,,,,,,,,,,,,, -408100,4.450111,0.67506206,,,,,,,,,,,,,, -408200,4.023621,0.5590037,,,,,,,,,,,,,, -408300,4.3275867,0.6186522,,,,,,,,,,,,,, -408400,4.395381,0.58932084,,,,,,,,,,,,,, -408500,4.492476,0.67162365,,,,,,,,,,,,,, -408600,4.673184,0.5945521,,,,,,,,,,,,,, -408700,4.588155,0.6124882,,,,,,,,,,,,,, -408800,4.4592853,0.63693136,,,,,,,,,,,,,, -408900,4.621242,0.6223616,,,,,,,,,,,,,, -409000,4.368267,0.6462534,,,,,,,,,,,,,, -409100,4.157698,0.58428735,,,,,,,,,,,,,, -409200,4.7899623,0.6391091,,,,,,,,,,,,,, -409300,4.412106,0.6026647,,,,,,,,,,,,,, -409400,4.9147477,0.6588248,,,,,,,,,,,,,, -409441,,,0.958765149116516,0.1495168954133987,0.7569999694824219,1.041053056716919,50000.0,0.6319000124931335,1.8224622011184688,10000.0,137761.7340619564,142589.07264232635,137761.7340619564,4792.087728738785,20.30940723419189,0.0 -409500,4.247847,0.57417965,,,,,,,,,,,,,, -409600,4.8497844,0.6472495,,,,,,,,,,,,,, -409700,4.5624285,0.60904044,,,,,,,,,,,,,, -409800,4.1537895,0.6339065,,,,,,,,,,,,,, -409900,4.7958646,0.60423446,,,,,,,,,,,,,, -410000,4.50879,0.5642568,,,,,,,,,,,,,, -410100,4.309628,0.6306088,,,,,,,,,,,,,, -410200,4.23554,0.57900816,,,,,,,,,,,,,, -410300,4.663922,0.59436864,,,,,,,,,,,,,, -410400,4.6959677,0.6866021,,,,,,,,,,,,,, -410500,4.6501193,0.6778502,,,,,,,,,,,,,, -410600,4.168088,0.6136488,,,,,,,,,,,,,, -410700,4.216153,0.510067,,,,,,,,,,,,,, -410800,4.3397303,0.64345753,,,,,,,,,,,,,, -410900,4.0648975,0.58868825,,,,,,,,,,,,,, -410958,,,0.9617944359779358,0.144712746143341,0.7569199800491333,1.0410289764404297,50000.0,0.6310000419616699,1.821075677871704,10000.0,138271.87314987183,143116.21062779427,138271.87314987183,4808.927973031998,20.41361141204834,0.0 -411000,4.2982235,0.6112543,,,,,,,,,,,,,, -411100,4.37964,0.59137166,,,,,,,,,,,,,, -411200,4.6524215,0.57515264,,,,,,,,,,,,,, -411300,4.464747,0.57452035,,,,,,,,,,,,,, -411400,4.9083924,0.65484333,,,,,,,,,,,,,, -411500,4.863248,0.6051508,,,,,,,,,,,,,, -411600,4.5607195,0.65053713,,,,,,,,,,,,,, -411700,4.487944,0.5828635,,,,,,,,,,,,,, -411800,4.5968366,0.6660889,,,,,,,,,,,,,, -411900,4.039423,0.58370596,,,,,,,,,,,,,, -412000,4.86629,0.6602708,,,,,,,,,,,,,, -412100,4.4271975,0.61707187,,,,,,,,,,,,,, -412200,4.572254,0.6608096,,,,,,,,,,,,,, -412300,4.5036387,0.64302635,,,,,,,,,,,,,, -412400,4.551492,0.59582317,,,,,,,,,,,,,, -412475,,,0.958984375,0.1489603668451309,0.7571799755096436,1.041362762451172,50000.0,0.6312000155448914,1.822407841682434,10000.0,138781.94454431534,143643.35971355438,138781.94454431534,4825.856830596924,20.50609040260315,0.0 -412500,4.37477,0.62592167,,,,,,,,,,,,,, -412600,4.368853,0.5610322,,,,,,,,,,,,,, -412700,4.3426204,0.62273616,,,,,,,,,,,,,, -412800,5.1319594,0.6167259,,,,,,,,,,,,,, -412900,4.413508,0.6643791,,,,,,,,,,,,,, -413000,4.26767,0.633487,,,,,,,,,,,,,, -413100,4.41669,0.58860546,,,,,,,,,,,,,, -413200,4.524589,0.61722577,,,,,,,,,,,,,, -413300,5.4458323,0.6505443,,,,,,,,,,,,,, -413400,4.4503345,0.65764517,,,,,,,,,,,,,, -413500,4.4886394,0.6374415,,,,,,,,,,,,,, -413600,4.147011,0.5670811,,,,,,,,,,,,,, -413700,4.268077,0.5629201,,,,,,,,,,,,,, -413800,5.0385075,0.60876954,,,,,,,,,,,,,, -413900,4.600482,0.5823126,,,,,,,,,,,,,, -413991,,,0.960339605808258,0.1461835503578186,0.7570799589157104,1.0407791137695312,50000.0,0.6321000456809998,1.8215856552124023,10000.0,139291.81418538094,144170.310328722,139291.81418538094,4842.787178993225,20.600675106048584,0.0 -414000,4.255646,0.59096223,,,,,,,,,,,,,, -414100,4.7281837,0.65592504,,,,,,,,,,,,,, -414200,4.2127085,0.6381501,,,,,,,,,,,,,, -414300,4.5517793,0.60936487,,,,,,,,,,,,,, -414400,5.2950864,0.6513019,,,,,,,,,,,,,, -414500,3.8540375,0.5508614,,,,,,,,,,,,,, -414600,4.698448,0.570802,,,,,,,,,,,,,, -414700,4.5655093,0.6557447,,,,,,,,,,,,,, -414800,4.4900517,0.6518036,,,,,,,,,,,,,, -414900,4.735964,0.73674357,,,,,,,,,,,,,, -415000,4.802421,0.6311388,,,,,,,,,,,,,, -415100,4.548944,0.66670626,,,,,,,,,,,,,, -415200,4.232837,0.5723645,,,,,,,,,,,,,, -415300,4.30621,0.65005225,,,,,,,,,,,,,, -415400,4.5640154,0.6236727,,,,,,,,,,,,,, -415500,4.8905563,0.6270118,,,,,,,,,,,,,, -415507,,,0.96097731590271,0.1466158181428909,0.7567999958992004,1.041944146156311,50000.0,0.6307000517845154,1.823670744895935,10000.0,139801.6758582592,144697.35668969154,139801.6758582592,4859.819278478622,20.697286128997803,0.0 -415600,4.3946433,0.610582,,,,,,,,,,,,,, -415700,4.7775016,0.63028693,,,,,,,,,,,,,, -415800,4.6366606,0.6591282,,,,,,,,,,,,,, -415900,4.395383,0.6277083,,,,,,,,,,,,,, -416000,4.272536,0.61936104,,,,,,,,,,,,,, -416100,4.0304666,0.57831436,,,,,,,,,,,,,, -416200,4.3042483,0.5511489,,,,,,,,,,,,,, -416300,4.0476003,0.5265819,,,,,,,,,,,,,, -416400,4.1692424,0.5526603,,,,,,,,,,,,,, -416500,4.461342,0.62710184,,,,,,,,,,,,,, -416600,4.465377,0.6230172,,,,,,,,,,,,,, -416700,4.654733,0.62066615,,,,,,,,,,,,,, -416800,4.287245,0.60926116,,,,,,,,,,,,,, -416900,4.7386365,0.66115904,,,,,,,,,,,,,, -417000,4.193386,0.62217647,,,,,,,,,,,,,, -417024,,,0.9611766338348388,0.1452847421169281,0.756879985332489,1.0428640842437744,50000.0,0.6310000419616699,1.8238797187805176,10000.0,140311.5811650753,145224.3636994362,140311.5811650753,4876.771936416626,20.78969407081604,0.0 -417100,4.745344,0.57066435,,,,,,,,,,,,,, -417200,4.743927,0.5967849,,,,,,,,,,,,,, -417300,4.4708304,0.6999546,,,,,,,,,,,,,, -417400,4.530297,0.58250165,,,,,,,,,,,,,, -417500,4.373036,0.5953406,,,,,,,,,,,,,, -417600,4.5698485,0.6147945,,,,,,,,,,,,,, -417700,4.5325313,0.61869,,,,,,,,,,,,,, -417800,4.4778876,0.68748385,,,,,,,,,,,,,, -417900,4.69171,0.62189615,,,,,,,,,,,,,, -418000,4.4644876,0.59330577,,,,,,,,,,,,,, -418100,4.231075,0.58124006,,,,,,,,,,,,,, -418200,4.4305983,0.59801173,,,,,,,,,,,,,, -418300,4.4507236,0.6042166,,,,,,,,,,,,,, -418400,4.345854,0.5915823,,,,,,,,,,,,,, -418500,4.1826005,0.63043916,,,,,,,,,,,,,, -418541,,,0.9610371589660645,0.1459392458200454,0.7571600079536438,1.040714144706726,50000.0,0.6319000124931335,1.8214082717895508,10000.0,140821.61366438866,145751.47378349304,140821.61366438866,4893.701037406921,20.88233780860901,0.0 -418600,4.6969047,0.59489816,,,,,,,,,,,,,, -418700,4.4057865,0.6221179,,,,,,,,,,,,,, -418800,4.52427,0.6001874,,,,,,,,,,,,,, -418900,4.3084884,0.6037224,,,,,,,,,,,,,, -419000,4.6347256,0.6571304,,,,,,,,,,,,,, -419100,4.493695,0.5997814,,,,,,,,,,,,,, -419200,4.620397,0.6150688,,,,,,,,,,,,,, -419300,4.6050725,0.6223743,,,,,,,,,,,,,, -419400,4.562789,0.60283184,,,,,,,,,,,,,, -419500,4.5874567,0.6448358,,,,,,,,,,,,,, -419600,4.393551,0.62053776,,,,,,,,,,,,,, -419700,4.877845,0.6192554,,,,,,,,,,,,,, -419800,4.9001107,0.6668511,,,,,,,,,,,,,, -419900,4.527775,0.64228886,,,,,,,,,,,,,, -420000,4.234362,0.5973176,,,,,,,,,,,,,, -420058,,,0.9612962007522584,0.1448786854743957,0.7575799822807312,1.0407829284667969,50000.0,0.6314000487327576,1.8214672803878784,10000.0,141331.67381334305,146278.59851312637,141331.67381334305,4910.61313700676,20.978939056396484,0.0 -420100,4.730857,0.5884279,,,,,,,,,,,,,, -420200,4.8024483,0.7212542,,,,,,,,,,,,,, -420300,4.4468346,0.6456722,,,,,,,,,,,,,, -420400,4.9445205,0.61088663,,,,,,,,,,,,,, -420500,4.660234,0.6263292,,,,,,,,,,,,,, -420600,4.7361217,0.64902663,,,,,,,,,,,,,, -420700,4.2592,0.63903135,,,,,,,,,,,,,, -420800,5.1106477,0.6743437,,,,,,,,,,,,,, -420900,4.0935125,0.57438624,,,,,,,,,,,,,, -421000,4.345672,0.5873807,,,,,,,,,,,,,, -421100,4.5016723,0.5950694,,,,,,,,,,,,,, -421200,4.6477585,0.63520515,,,,,,,,,,,,,, -421300,4.752366,0.73141503,,,,,,,,,,,,,, -421400,4.7301,0.6119151,,,,,,,,,,,,,, -421500,4.6145415,0.6563529,,,,,,,,,,,,,, -421575,,,0.9612364172935486,0.1430351883172989,0.7572599649429321,1.0402424335479736,50000.0,0.6322000026702881,1.821758270263672,10000.0,141841.59809207916,146805.78198504448,141841.59809207916,4927.719024181366,21.077617645263672,0.0 -421600,4.678323,0.65955615,,,,,,,,,,,,,, -421700,4.658396,0.59932804,,,,,,,,,,,,,, -421800,4.796326,0.6163999,,,,,,,,,,,,,, -421900,4.2727,0.55857605,,,,,,,,,,,,,, -422000,4.930672,0.59860283,,,,,,,,,,,,,, -422100,4.7146297,0.58876264,,,,,,,,,,,,,, -422200,4.5388246,0.55467165,,,,,,,,,,,,,, -422300,4.417491,0.61934036,,,,,,,,,,,,,, -422400,4.362234,0.56007075,,,,,,,,,,,,,, -422500,4.343874,0.65369403,,,,,,,,,,,,,, -422600,4.49147,0.6980683,,,,,,,,,,,,,, -422700,4.4829736,0.6064604,,,,,,,,,,,,,, -422800,4.3110666,0.60959285,,,,,,,,,,,,,, -422900,4.364718,0.6163865,,,,,,,,,,,,,, -423000,4.5212064,0.6611657,,,,,,,,,,,,,, -423092,,,0.9618741869926452,0.144457459449768,0.7572999596595764,1.040998458862305,50000.0,0.6313000321388245,1.8226169347763064,10000.0,142351.5038971901,147333.03734254837,142351.5038971901,4944.911732673645,21.177175760269165,0.0 -423100,4.6622653,0.5966091,,,,,,,,,,,,,, -423200,4.586384,0.58727586,,,,,,,,,,,,,, -423300,4.349407,0.6220121,,,,,,,,,,,,,, -423400,4.6187196,0.55370516,,,,,,,,,,,,,, -423500,4.3211956,0.58706987,,,,,,,,,,,,,, -423600,4.7723494,0.6014898,,,,,,,,,,,,,, -423700,4.1152463,0.5823077,,,,,,,,,,,,,, -423800,4.2664948,0.61896676,,,,,,,,,,,,,, -423900,4.3596926,0.5936841,,,,,,,,,,,,,, -424000,4.644859,0.61049044,,,,,,,,,,,,,, -424100,4.451297,0.6271492,,,,,,,,,,,,,, -424200,4.092847,0.57956433,,,,,,,,,,,,,, -424300,4.6198177,0.6690016,,,,,,,,,,,,,, -424400,4.6294847,0.6750184,,,,,,,,,,,,,, -424500,4.7761807,0.6443153,,,,,,,,,,,,,, -424600,4.4002666,0.6585167,,,,,,,,,,,,,, -424609,,,0.9593032598495485,0.1488636881113052,0.7572599649429321,1.040607452392578,50000.0,0.6313000321388245,1.8224475383758545,10000.0,142861.59340810776,147860.83490467072,142861.59340810776,4962.483725547791,21.25806474685669,0.0 -424700,4.6297874,0.6124752,,,,,,,,,,,,,, -424800,4.3669524,0.6207104,,,,,,,,,,,,,, -424900,4.2317452,0.67666817,,,,,,,,,,,,,, -425000,4.48672,0.57778615,,,,,,,,,,,,,, -425100,4.430984,0.65737414,,,,,,,,,,,,,, -425200,4.651087,0.63299924,,,,,,,,,,,,,, -425300,4.918462,0.5997845,,,,,,,,,,,,,, -425400,5.084062,0.7117087,,,,,,,,,,,,,, -425500,4.684404,0.72068655,,,,,,,,,,,,,, -425600,4.3798723,0.5663773,,,,,,,,,,,,,, -425700,4.421095,0.6691574,,,,,,,,,,,,,, -425800,4.5546546,0.6057282,,,,,,,,,,,,,, -425900,4.2703376,0.62376523,,,,,,,,,,,,,, -426000,4.4223466,0.61918205,,,,,,,,,,,,,, -426100,4.1924024,0.59464216,,,,,,,,,,,,,, -426126,,,0.9600605964660645,0.1482033878564834,0.7573599815368652,1.040932536125183,50000.0,0.631100058555603,1.8229377269744875,10000.0,143371.71987581253,148388.12489271164,143371.71987581253,4979.48906993866,21.36016345024109,0.0 -426200,4.054589,0.5992892,,,,,,,,,,,,,, -426300,4.2458057,0.56624365,,,,,,,,,,,,,, -426400,4.4877386,0.58010757,,,,,,,,,,,,,, -426500,4.575146,0.66606736,,,,,,,,,,,,,, -426600,4.5841455,0.6588802,,,,,,,,,,,,,, -426700,4.789443,0.6472494,,,,,,,,,,,,,, -426800,4.746759,0.6672479,,,,,,,,,,,,,, -426900,5.0286164,0.6812608,,,,,,,,,,,,,, -427000,4.2089267,0.630836,,,,,,,,,,,,,, -427100,4.5582523,0.60972667,,,,,,,,,,,,,, -427200,5.313002,0.65948755,,,,,,,,,,,,,, -427300,4.5928197,0.6826603,,,,,,,,,,,,,, -427400,4.4802146,0.5452044,,,,,,,,,,,,,, -427500,4.7361937,0.6302177,,,,,,,,,,,,,, -427600,5.304643,0.6630754,,,,,,,,,,,,,, -427644,,,0.9620934128761292,0.1434959471225738,0.7572000026702881,1.0401928424835205,50000.0,0.631600022315979,1.820892095565796,10000.0,143881.7842924595,148915.24838876724,143881.7842924595,4996.394257545471,21.45875644683838,0.0 -427700,4.393953,0.6352298,,,,,,,,,,,,,, -427800,4.627761,0.6550132,,,,,,,,,,,,,, -427900,4.7460155,0.6389985,,,,,,,,,,,,,, -428000,4.232754,0.60971355,,,,,,,,,,,,,, -428100,4.5223055,0.60341495,,,,,,,,,,,,,, -428200,4.4883265,0.6592149,,,,,,,,,,,,,, -428300,4.2097907,0.6201901,,,,,,,,,,,,,, -428400,4.487483,0.71570456,,,,,,,,,,,,,, -428500,4.234098,0.6014426,,,,,,,,,,,,,, -428600,4.1884036,0.6150666,,,,,,,,,,,,,, -428700,4.671952,0.6322017,,,,,,,,,,,,,, -428800,4.377851,0.60295653,,,,,,,,,,,,,, -428900,5.1100564,0.65558964,,,,,,,,,,,,,, -429000,4.5114074,0.59218013,,,,,,,,,,,,,, -429100,4.3851,0.57761437,,,,,,,,,,,,,, -429160,,,0.9617147445678712,0.1429807841777801,0.7571600079536438,1.041298508644104,50000.0,0.6306000351905823,1.8239283561706543,10000.0,144391.64415311813,149442.18219542503,144391.64415311813,5013.312318086624,21.560152769088745,0.0 -429200,4.2760825,0.5969403,,,,,,,,,,,,,, -429300,4.4267387,0.6021661,,,,,,,,,,,,,, -429400,4.567774,0.659736,,,,,,,,,,,,,, -429500,4.862941,0.6974322,,,,,,,,,,,,,, -429600,4.4072957,0.6047072,,,,,,,,,,,,,, -429700,4.1959286,0.5519318,,,,,,,,,,,,,, -429800,4.9958177,0.61571395,,,,,,,,,,,,,, -429900,5.1703267,0.73488575,,,,,,,,,,,,,, -430000,4.7952495,0.64352584,,,,,,,,,,,,,, -430100,4.442287,0.5908881,,,,,,,,,,,,,, -430200,4.172457,0.6378373,,,,,,,,,,,,,, -430300,4.850583,0.6467639,,,,,,,,,,,,,, -430400,4.1111126,0.61128515,,,,,,,,,,,,,, -430500,4.153054,0.55830675,,,,,,,,,,,,,, -430600,4.235987,0.6676136,,,,,,,,,,,,,, -430677,,,0.9617745280265808,0.1452418863773346,0.757420003414154,1.0418684482574463,50000.0,0.6312000155448914,1.8239145278930664,10000.0,144901.60076332092,149969.17936348915,144901.60076332092,5030.17825627327,21.67906594276428,0.0 -430700,4.166873,0.65171957,,,,,,,,,,,,,, -430800,4.576184,0.6432333,,,,,,,,,,,,,, -430900,4.451288,0.59915257,,,,,,,,,,,,,, -431000,5.022738,0.62257797,,,,,,,,,,,,,, -431100,4.569999,0.66059613,,,,,,,,,,,,,, -431200,4.550466,0.6327865,,,,,,,,,,,,,, -431300,4.656541,0.6434593,,,,,,,,,,,,,, -431400,4.6499214,0.6106981,,,,,,,,,,,,,, -431500,4.34157,0.5774154,,,,,,,,,,,,,, -431600,4.7472186,0.58798015,,,,,,,,,,,,,, -431700,4.7071595,0.68846136,,,,,,,,,,,,,, -431800,4.2169266,0.57050604,,,,,,,,,,,,,, -431900,4.8388643,0.6571437,,,,,,,,,,,,,, -432000,4.7556133,0.614593,,,,,,,,,,,,,, -432100,4.7189913,0.66156447,,,,,,,,,,,,,, -432193,,,0.9602000713348388,0.1488232314586639,0.7574999928474426,1.041712999343872,50000.0,0.6321000456809998,1.8241660594940183,10000.0,145411.59765625,150496.2617239952,145411.59765625,5047.105713605881,21.780495166778564,0.0 -432200,3.9558043,0.5577863,,,,,,,,,,,,,, -432300,4.6148467,0.627526,,,,,,,,,,,,,, -432400,4.001436,0.5980455,,,,,,,,,,,,,, -432500,4.8506675,0.6641365,,,,,,,,,,,,,, -432600,4.1052737,0.54819036,,,,,,,,,,,,,, -432700,4.655607,0.66121083,,,,,,,,,,,,,, -432800,4.403717,0.64604783,,,,,,,,,,,,,, -432900,4.898984,0.60907155,,,,,,,,,,,,,, -433000,4.3189316,0.6352313,,,,,,,,,,,,,, -433100,4.6837482,0.6307885,,,,,,,,,,,,,, -433200,4.7789173,0.6536467,,,,,,,,,,,,,, -433300,4.5014997,0.65763795,,,,,,,,,,,,,, -433400,4.944518,0.6955203,,,,,,,,,,,,,, -433500,4.44742,0.6677738,,,,,,,,,,,,,, -433600,4.6607714,0.55710906,,,,,,,,,,,,,, -433700,4.596912,0.60373616,,,,,,,,,,,,,, -433709,,,0.9606783986091614,0.1471301168203354,0.7573399543762207,1.041322112083435,50000.0,0.6320000290870667,1.823303461074829,10000.0,145921.48075079918,151023.10500741005,145921.48075079918,5063.910180091858,21.878495693206787,0.0 -433800,4.491026,0.628404,,,,,,,,,,,,,, -433900,4.783046,0.6457555,,,,,,,,,,,,,, -434000,4.562579,0.6297728,,,,,,,,,,,,,, -434100,4.540248,0.5929248,,,,,,,,,,,,,, -434200,4.0650496,0.55894285,,,,,,,,,,,,,, -434300,4.766739,0.6858487,,,,,,,,,,,,,, -434400,4.410127,0.60682493,,,,,,,,,,,,,, -434500,4.5220613,0.5776154,,,,,,,,,,,,,, -434600,4.2041254,0.60094726,,,,,,,,,,,,,, -434700,4.424799,0.6518456,,,,,,,,,,,,,, -434800,4.4683433,0.7158277,,,,,,,,,,,,,, -434900,4.950247,0.64858747,,,,,,,,,,,,,, -435000,4.5010185,0.5841818,,,,,,,,,,,,,, -435100,4.575094,0.6866969,,,,,,,,,,,,,, -435200,4.42239,0.6916858,,,,,,,,,,,,,, -435226,,,0.9610171914100648,0.1463170796632766,0.757099986076355,1.0406198501586914,50000.0,0.631600022315979,1.8207440376281736,10000.0,146431.42701363564,151549.93702721596,146431.42701363564,5080.642971515656,21.97658133506775,0.0 -435300,4.976009,0.6905776,,,,,,,,,,,,,, -435400,4.2867956,0.63647366,,,,,,,,,,,,,, -435500,4.5735188,0.6155864,,,,,,,,,,,,,, -435600,4.557377,0.68813014,,,,,,,,,,,,,, -435700,4.6927233,0.7176393,,,,,,,,,,,,,, -435800,4.3483963,0.6308132,,,,,,,,,,,,,, -435900,4.091697,0.5361412,,,,,,,,,,,,,, -436000,4.26512,0.5635659,,,,,,,,,,,,,, -436100,4.4624963,0.5672656,,,,,,,,,,,,,, -436200,4.274846,0.6414455,,,,,,,,,,,,,, -436300,4.5268173,0.62739354,,,,,,,,,,,,,, -436400,4.9089303,0.6407095,,,,,,,,,,,,,, -436500,4.3756676,0.5765597,,,,,,,,,,,,,, -436600,4.6182737,0.62354547,,,,,,,,,,,,,, -436700,4.319307,0.6060113,,,,,,,,,,,,,, -436744,,,0.9606584906578064,0.1450043320655822,0.75764000415802,1.0415468215942385,50000.0,0.6321000456809998,1.82472825050354,10000.0,146941.59220194817,152077.231808424,146941.59220194817,5097.61579823494,22.0786395072937,0.0 -436800,4.562329,0.58459085,,,,,,,,,,,,,, -436900,4.4341326,0.6727501,,,,,,,,,,,,,, -437000,4.2414618,0.58752036,,,,,,,,,,,,,, -437100,4.273663,0.5537606,,,,,,,,,,,,,, -437200,4.4349923,0.633644,,,,,,,,,,,,,, -437300,4.686999,0.62578994,,,,,,,,,,,,,, -437400,4.8744435,0.69395673,,,,,,,,,,,,,, -437500,4.6343293,0.69188833,,,,,,,,,,,,,, -437600,4.5923905,0.6449297,,,,,,,,,,,,,, -437700,5.2180486,0.7494047,,,,,,,,,,,,,, -437800,4.102545,0.56608015,,,,,,,,,,,,,, -437900,4.7483,0.68357897,,,,,,,,,,,,,, -438000,4.355911,0.60175365,,,,,,,,,,,,,, -438100,4.1841464,0.58312297,,,,,,,,,,,,,, -438200,4.4369407,0.6542791,,,,,,,,,,,,,, -438261,,,0.9612165093421936,0.146070510149002,0.7572000026702881,1.0408618450164795,50000.0,0.6318000555038452,1.82334578037262,10000.0,147451.58518648148,152604.4050400257,147451.58518648148,5114.640841245651,22.17793607711792,0.0 -438300,4.438785,0.62724125,,,,,,,,,,,,,, -438400,4.8500047,0.6839063,,,,,,,,,,,,,, -438500,4.484724,0.59517944,,,,,,,,,,,,,, -438600,4.4187927,0.65733343,,,,,,,,,,,,,, -438700,4.466139,0.64303344,,,,,,,,,,,,,, -438800,4.6541777,0.61692023,,,,,,,,,,,,,, -438900,4.4774666,0.6400498,,,,,,,,,,,,,, -439000,4.352039,0.59701866,,,,,,,,,,,,,, -439100,4.1364937,0.5601706,,,,,,,,,,,,,, -439200,4.541018,0.6466956,,,,,,,,,,,,,, -439300,4.8349705,0.63794357,,,,,,,,,,,,,, -439400,5.5391874,0.7249577,,,,,,,,,,,,,, -439500,4.562662,0.6431634,,,,,,,,,,,,,, -439600,4.781205,0.65249425,,,,,,,,,,,,,, -439700,4.449562,0.710874,,,,,,,,,,,,,, -439778,,,0.9588249325752258,0.1511629223823547,0.7573399543762207,1.0402222871780396,50000.0,0.6314000487327576,1.8219047784805296,10000.0,147961.7461400032,153131.56278252602,147961.7461400032,5131.485112667084,22.27350926399231,0.0 -439800,4.701836,0.6458102,,,,,,,,,,,,,, -439900,4.8598533,0.6413057,,,,,,,,,,,,,, -440000,4.1074777,0.5244527,,,,,,,,,,,,,, -440100,4.182213,0.6056877,,,,,,,,,,,,,, -440200,4.548223,0.65223175,,,,,,,,,,,,,, -440300,4.536102,0.6286961,,,,,,,,,,,,,, -440400,4.546876,0.6506,,,,,,,,,,,,,, -440500,4.344668,0.63427037,,,,,,,,,,,,,, -440600,4.4582787,0.56004524,,,,,,,,,,,,,, -440700,4.02357,0.6130692,,,,,,,,,,,,,, -440800,4.62115,0.6193095,,,,,,,,,,,,,, -440900,4.392924,0.5845658,,,,,,,,,,,,,, -441000,4.6099405,0.6347132,,,,,,,,,,,,,, -441100,4.799324,0.675564,,,,,,,,,,,,,, -441200,4.8047314,0.73578733,,,,,,,,,,,,,, -441295,,,0.9607381820678712,0.1461882889270782,0.7571399807929993,1.042002558708191,50000.0,0.6318000555038452,1.8222293853759768,10000.0,148471.76654148102,153658.69793319702,148471.76654148102,5148.434104681015,22.382426500320435,0.0 -441300,4.495053,0.61139786,,,,,,,,,,,,,, -441400,4.5456867,0.595194,,,,,,,,,,,,,, -441500,4.2110863,0.63539195,,,,,,,,,,,,,, -441600,4.876282,0.644879,,,,,,,,,,,,,, -441700,4.0402946,0.5672276,,,,,,,,,,,,,, -441800,4.6853437,0.62598646,,,,,,,,,,,,,, -441900,4.8873277,0.67607754,,,,,,,,,,,,,, -442000,4.483566,0.5854441,,,,,,,,,,,,,, -442100,4.344955,0.6275785,,,,,,,,,,,,,, -442200,4.6482334,0.6690568,,,,,,,,,,,,,, -442300,4.306985,0.5779139,,,,,,,,,,,,,, -442400,4.3397927,0.58737373,,,,,,,,,,,,,, -442500,4.5924478,0.62702364,,,,,,,,,,,,,, -442600,4.7093577,0.6482565,,,,,,,,,,,,,, -442700,4.7347074,0.61990935,,,,,,,,,,,,,, -442800,4.5040526,0.5271981,,,,,,,,,,,,,, -442811,,,0.9609175324440002,0.1468439698219299,0.7570799589157104,1.040229320526123,50000.0,0.6321000456809998,1.820679783821106,10000.0,148981.70437383652,154185.9196381569,148981.70437383652,5165.56057715416,22.4826500415802,0.0 -442900,4.6632514,0.5879688,,,,,,,,,,,,,, -443000,4.0552936,0.58475065,,,,,,,,,,,,,, -443100,4.3430214,0.61659557,,,,,,,,,,,,,, -443200,4.8751197,0.6339169,,,,,,,,,,,,,, -443300,4.9715586,0.6798878,,,,,,,,,,,,,, -443400,4.4322057,0.62540543,,,,,,,,,,,,,, -443500,4.285264,0.6515055,,,,,,,,,,,,,, -443600,4.454462,0.62514424,,,,,,,,,,,,,, -443700,4.1868553,0.56807256,,,,,,,,,,,,,, -443800,4.363619,0.6421728,,,,,,,,,,,,,, -443900,4.779927,0.6048907,,,,,,,,,,,,,, -444000,4.3755007,0.62210363,,,,,,,,,,,,,, -444100,4.277336,0.58261174,,,,,,,,,,,,,, -444200,4.498439,0.58001274,,,,,,,,,,,,,, -444300,4.5098867,0.6523295,,,,,,,,,,,,,, -444328,,,0.9606186151504515,0.1501710265874862,0.7574399709701538,1.0412341356277466,50000.0,0.6317000389099121,1.8212313652038568,10000.0,149491.8409090042,154713.53189253807,149491.8409090042,5182.874908208847,22.5871741771698,0.0 -444400,4.4489074,0.6554952,,,,,,,,,,,,,, -444500,4.2643676,0.60978055,,,,,,,,,,,,,, -444600,4.514159,0.6686417,,,,,,,,,,,,,, -444700,4.699594,0.7192618,,,,,,,,,,,,,, -444800,4.154553,0.61208284,,,,,,,,,,,,,, -444900,4.278243,0.6422086,,,,,,,,,,,,,, -445000,4.432141,0.60852426,,,,,,,,,,,,,, -445100,4.7536616,0.72372234,,,,,,,,,,,,,, -445200,4.277967,0.5688183,,,,,,,,,,,,,, -445300,4.465045,0.62060475,,,,,,,,,,,,,, -445400,4.6883955,0.66914415,,,,,,,,,,,,,, -445500,4.6553817,0.604268,,,,,,,,,,,,,, -445600,4.4082656,0.70707554,,,,,,,,,,,,,, -445700,4.7747064,0.632376,,,,,,,,,,,,,, -445800,5.0175076,0.6165716,,,,,,,,,,,,,, -445845,,,0.9606983065605164,0.146799087524414,0.7567799687385559,1.0411818027496338,50000.0,0.631100058555603,1.822283148765564,10000.0,150001.83459067345,155240.5869767666,150001.83459067345,5199.798217058182,22.6678466796875,0.0 -445900,4.6135674,0.6081554,,,,,,,,,,,,,, -446000,4.5514536,0.65818506,,,,,,,,,,,,,, -446100,4.266087,0.5913599,,,,,,,,,,,,,, -446200,4.11568,0.57798934,,,,,,,,,,,,,, -446300,4.795068,0.6523453,,,,,,,,,,,,,, -446400,4.466221,0.67906857,,,,,,,,,,,,,, -446500,4.262261,0.6017877,,,,,,,,,,,,,, -446600,4.883167,0.6651577,,,,,,,,,,,,,, -446700,4.598291,0.6461647,,,,,,,,,,,,,, -446800,4.528557,0.5605527,,,,,,,,,,,,,, -446900,4.4312296,0.66138136,,,,,,,,,,,,,, -447000,4.8982887,0.6396103,,,,,,,,,,,,,, -447100,4.926512,0.61651945,,,,,,,,,,,,,, -447200,4.7475047,0.7605455,,,,,,,,,,,,,, -447300,4.1515913,0.5730256,,,,,,,,,,,,,, -447362,,,0.9602798223495485,0.1471463590860366,0.7571199536323547,1.0414865016937256,50000.0,0.6319000124931335,1.8231927156448364,10000.0,150511.73395705223,155767.55983424187,150511.73395705223,5216.711742401123,22.770523071289062,0.0 -447400,3.9743578,0.6050289,,,,,,,,,,,,,, -447500,4.1157036,0.53846204,,,,,,,,,,,,,, -447600,4.4157505,0.6879825,,,,,,,,,,,,,, -447700,4.0858054,0.63232553,,,,,,,,,,,,,, -447800,4.3973455,0.6006573,,,,,,,,,,,,,, -447900,4.731963,0.65719557,,,,,,,,,,,,,, -448000,4.357629,0.5972489,,,,,,,,,,,,,, -448100,4.1520495,0.6214437,,,,,,,,,,,,,, -448200,4.656133,0.6393736,,,,,,,,,,,,,, -448300,4.5374093,0.5964993,,,,,,,,,,,,,, -448400,4.8052125,0.6705264,,,,,,,,,,,,,, -448500,4.4907665,0.6418556,,,,,,,,,,,,,, -448600,4.3715644,0.6301517,,,,,,,,,,,,,, -448700,4.5267673,0.5828995,,,,,,,,,,,,,, -448800,4.5410447,0.6825253,,,,,,,,,,,,,, -448879,,,0.9606186151504515,0.1455328613519668,0.7567200064659119,1.0406856536865234,50000.0,0.6320000290870667,1.8224512338638303,10000.0,151021.81368637085,156294.64122962952,151021.81368637085,5233.558504581451,22.870090007781982,0.0 -448900,4.764621,0.6847201,,,,,,,,,,,,,, -449000,4.2508545,0.5044165,,,,,,,,,,,,,, -449100,4.33056,0.6700059,,,,,,,,,,,,,, -449200,4.2333975,0.60128915,,,,,,,,,,,,,, -449300,4.764043,0.68410677,,,,,,,,,,,,,, -449400,4.402226,0.631398,,,,,,,,,,,,,, -449500,4.3465414,0.66272396,,,,,,,,,,,,,, -449600,4.528763,0.689497,,,,,,,,,,,,,, -449700,4.3865743,0.6209059,,,,,,,,,,,,,, -449800,4.2917724,0.6184381,,,,,,,,,,,,,, -449900,4.622889,0.6919598,,,,,,,,,,,,,, -450000,5.1258516,0.68911743,,,,,,,,,,,,,, -450100,4.5806713,0.56902397,,,,,,,,,,,,,, -450200,4.449842,0.62198037,,,,,,,,,,,,,, -450300,4.4648495,0.6463831,,,,,,,,,,,,,, -450395,,,0.9595423936843872,0.1479770094156265,0.7572000026702881,1.0416878461837769,50000.0,0.6322000026702881,1.824686050415039,10000.0,151531.68371200562,156821.4516866207,151531.68371200562,5250.33748459816,22.97496485710144,0.0 -450400,4.6958284,0.68644637,,,,,,,,,,,,,, -450500,4.4760942,0.56426054,,,,,,,,,,,,,, -450600,4.2978497,0.53068393,,,,,,,,,,,,,, -450700,4.5757465,0.66647905,,,,,,,,,,,,,, -450800,4.6183596,0.6311565,,,,,,,,,,,,,, -450900,4.639158,0.6382498,,,,,,,,,,,,,, -451000,4.9165015,0.66497236,,,,,,,,,,,,,, -451100,4.69906,0.6733192,,,,,,,,,,,,,, -451200,4.561418,0.647699,,,,,,,,,,,,,, -451300,4.0940475,0.60488844,,,,,,,,,,,,,, -451400,4.4676633,0.60364383,,,,,,,,,,,,,, -451500,5.022352,0.6491101,,,,,,,,,,,,,, -451600,4.31243,0.6084776,,,,,,,,,,,,,, -451700,4.3773956,0.6019007,,,,,,,,,,,,,, -451800,4.4333987,0.5848955,,,,,,,,,,,,,, -451900,4.2114296,0.5684486,,,,,,,,,,,,,, -451911,,,0.959980845451355,0.1469773203134536,0.7573599815368652,1.0414307117462158,50000.0,0.6318000555038452,1.8233704566955569,10000.0,152041.5577995777,157348.38074493408,152041.5577995777,5267.233897209168,23.07772135734558,0.0 -452000,4.175465,0.5629593,,,,,,,,,,,,,, -452100,5.0203285,0.63937503,,,,,,,,,,,,,, -452200,4.2544403,0.63111883,,,,,,,,,,,,,, -452300,4.654094,0.65138376,,,,,,,,,,,,,, -452400,4.520902,0.6163384,,,,,,,,,,,,,, -452500,4.142112,0.6646186,,,,,,,,,,,,,, -452600,4.5306544,0.6436468,,,,,,,,,,,,,, -452700,4.8676314,0.64857906,,,,,,,,,,,,,, -452800,4.475275,0.62266815,,,,,,,,,,,,,, -452900,4.396276,0.5597929,,,,,,,,,,,,,, -453000,4.4143023,0.6265366,,,,,,,,,,,,,, -453100,4.4803786,0.603218,,,,,,,,,,,,,, -453200,4.6058083,0.6638149,,,,,,,,,,,,,, -453300,4.1869035,0.63915956,,,,,,,,,,,,,, -453400,4.6348696,0.60296154,,,,,,,,,,,,,, -453428,,,0.9600406289100648,0.1466209590435028,0.7575199604034424,1.0417447090148926,50000.0,0.632900059223175,1.82382333278656,10000.0,152551.6747918129,157875.72356677055,152551.6747918129,5284.3033618927,23.179197311401367,0.0 -453500,4.8090305,0.68369806,,,,,,,,,,,,,, -453600,4.459547,0.62461555,,,,,,,,,,,,,, -453700,4.596842,0.5874046,,,,,,,,,,,,,, -453800,4.3413215,0.57224894,,,,,,,,,,,,,, -453900,4.5601783,0.63018787,,,,,,,,,,,,,, -454000,4.751325,0.6246292,,,,,,,,,,,,,, -454100,4.835701,0.6069903,,,,,,,,,,,,,, -454200,4.259754,0.6311074,,,,,,,,,,,,,, -454300,4.386686,0.5841975,,,,,,,,,,,,,, -454400,4.642432,0.58559906,,,,,,,,,,,,,, -454500,3.9695823,0.53517514,,,,,,,,,,,,,, -454600,4.332846,0.5811224,,,,,,,,,,,,,, -454700,4.354465,0.624226,,,,,,,,,,,,,, -454800,4.258581,0.58894104,,,,,,,,,,,,,, -454900,4.455944,0.6304296,,,,,,,,,,,,,, -454945,,,0.9618343114852904,0.1466174721717834,0.7567399740219116,1.0415546894073486,50000.0,0.6317000389099121,1.823285818099976,10000.0,153061.76563978195,158402.9397919178,153061.76563978195,5301.271278142929,23.28038787841797,0.0 -455000,4.728269,0.63205296,,,,,,,,,,,,,, -455100,4.4563513,0.6246298,,,,,,,,,,,,,, -455200,4.926256,0.64294785,,,,,,,,,,,,,, -455300,4.610992,0.6747356,,,,,,,,,,,,,, -455400,5.0600185,0.6648197,,,,,,,,,,,,,, -455500,4.8884745,0.65342593,,,,,,,,,,,,,, -455600,4.4534593,0.66047645,,,,,,,,,,,,,, -455700,4.3623786,0.58235085,,,,,,,,,,,,,, -455800,4.314467,0.59691304,,,,,,,,,,,,,, -455900,4.7216773,0.6611283,,,,,,,,,,,,,, -456000,4.390609,0.6566621,,,,,,,,,,,,,, -456100,4.517462,0.60853934,,,,,,,,,,,,,, -456200,4.5403805,0.57718515,,,,,,,,,,,,,, -456300,4.08769,0.59478194,,,,,,,,,,,,,, -456400,4.527193,0.5972167,,,,,,,,,,,,,, -456463,,,0.96097731590271,0.1465593427419662,0.7573599815368652,1.0404845476150513,50000.0,0.6308000087738037,1.8218352794647217,10000.0,153571.90072655678,158930.11351394653,153571.90072655678,5318.150739192963,23.382413625717163,0.0 -456500,4.771224,0.6552226,,,,,,,,,,,,,, -456600,4.2206616,0.55884117,,,,,,,,,,,,,, -456700,4.4313526,0.5729835,,,,,,,,,,,,,, -456800,4.491507,0.67318594,,,,,,,,,,,,,, -456900,5.284447,0.6759889,,,,,,,,,,,,,, -457000,4.494724,0.54318476,,,,,,,,,,,,,, -457100,4.850028,0.61341137,,,,,,,,,,,,,, -457200,4.353127,0.53768855,,,,,,,,,,,,,, -457300,4.2236867,0.5938619,,,,,,,,,,,,,, -457400,4.391285,0.5894464,,,,,,,,,,,,,, -457500,4.106419,0.5973471,,,,,,,,,,,,,, -457600,4.88969,0.59516394,,,,,,,,,,,,,, -457700,4.5137177,0.63182235,,,,,,,,,,,,,, -457800,5.0602546,0.63574135,,,,,,,,,,,,,, -457900,4.3350034,0.6325345,,,,,,,,,,,,,, -457980,,,0.9604591727256776,0.1468010693788528,0.7571799755096436,1.0417757034301758,50000.0,0.6302000284194946,1.8235985040664675,10000.0,154081.8287653923,159457.1767117977,154081.8287653923,5335.127779006958,23.484861850738525,0.0 -458000,4.7952323,0.60040295,,,,,,,,,,,,,, -458100,3.9925888,0.5395612,,,,,,,,,,,,,, -458200,4.337751,0.61407954,,,,,,,,,,,,,, -458300,4.4511356,0.5567742,,,,,,,,,,,,,, -458400,4.3955793,0.651388,,,,,,,,,,,,,, -458500,4.6387076,0.6665374,,,,,,,,,,,,,, -458600,4.7817025,0.62730294,,,,,,,,,,,,,, -458700,4.9738655,0.6066889,,,,,,,,,,,,,, -458800,4.43382,0.6142339,,,,,,,,,,,,,, -458900,4.4933987,0.6505612,,,,,,,,,,,,,, -459000,4.3978944,0.6460064,,,,,,,,,,,,,, -459100,4.4854198,0.5434037,,,,,,,,,,,,,, -459200,4.28901,0.661247,,,,,,,,,,,,,, -459300,4.38036,0.59128326,,,,,,,,,,,,,, -459400,4.279878,0.65636796,,,,,,,,,,,,,, -459497,,,0.9614556431770324,0.1402759253978729,0.7569999694824219,1.0417706966400146,50000.0,0.6321000456809998,1.8234221935272217,10000.0,154591.9551639557,159984.44529652596,154591.9551639557,5352.105293512344,23.593278408050537,0.0 -459500,5.1130085,0.6763602,,,,,,,,,,,,,, -459600,4.5824046,0.59447145,,,,,,,,,,,,,, -459700,4.460463,0.6103445,,,,,,,,,,,,,, -459800,4.6133704,0.6452725,,,,,,,,,,,,,, -459900,4.370387,0.60391587,,,,,,,,,,,,,, -460000,4.622271,0.6608434,,,,,,,,,,,,,, -460100,5.042873,0.6953511,,,,,,,,,,,,,, -460200,4.5362654,0.6485008,,,,,,,,,,,,,, -460300,4.247827,0.67317325,,,,,,,,,,,,,, -460400,4.631473,0.6216545,,,,,,,,,,,,,, -460500,4.4618526,0.6456499,,,,,,,,,,,,,, -460600,4.4838023,0.6507395,,,,,,,,,,,,,, -460700,4.5050044,0.64448744,,,,,,,,,,,,,, -460800,4.9646544,0.6326661,,,,,,,,,,,,,, -460900,4.7330103,0.60433424,,,,,,,,,,,,,, -461000,4.07609,0.5762525,,,,,,,,,,,,,, -461014,,,0.962113320827484,0.1443096399307251,0.7572799921035767,1.0408600568771362,50000.0,0.6315000057220459,1.8218690156936648,10000.0,155101.99497246742,160511.493765831,155101.99497246742,5368.956280946732,23.694770097732544,0.0 -461100,4.220053,0.60972553,,,,,,,,,,,,,, -461200,4.3369246,0.59177655,,,,,,,,,,,,,, -461300,4.600795,0.5701773,,,,,,,,,,,,,, -461400,4.5006332,0.6537877,,,,,,,,,,,,,, -461500,4.765866,0.7304076,,,,,,,,,,,,,, -461600,4.193658,0.5768589,,,,,,,,,,,,,, -461700,4.4532776,0.58509,,,,,,,,,,,,,, -461800,5.0202513,0.68684894,,,,,,,,,,,,,, -461900,4.800463,0.72240317,,,,,,,,,,,,,, -462000,4.699793,0.59156144,,,,,,,,,,,,,, -462100,4.739454,0.6783587,,,,,,,,,,,,,, -462200,4.732991,0.6599991,,,,,,,,,,,,,, -462300,4.4409213,0.60288906,,,,,,,,,,,,,, -462400,4.4079294,0.64381075,,,,,,,,,,,,,, -462500,4.0574317,0.6317414,,,,,,,,,,,,,, -462531,,,0.9612165093421936,0.1455664485692978,0.7572000026702881,1.0417051315307615,50000.0,0.6313000321388245,1.8227357864379885,10000.0,155611.86061382294,161038.8993780613,155611.86061382294,5386.340245485306,23.79537916183472,0.0 -462600,4.6394033,0.63832366,,,,,,,,,,,,,, -462700,4.380878,0.6227929,,,,,,,,,,,,,, -462800,4.5061274,0.6298896,,,,,,,,,,,,,, -462900,4.058756,0.53627884,,,,,,,,,,,,,, -463000,5.3232913,0.66667885,,,,,,,,,,,,,, -463100,4.8796625,0.6370941,,,,,,,,,,,,,, -463200,4.457771,0.67958426,,,,,,,,,,,,,, -463300,4.3755255,0.60526824,,,,,,,,,,,,,, -463400,4.5310597,0.6566696,,,,,,,,,,,,,, -463500,4.4275312,0.6215725,,,,,,,,,,,,,, -463600,4.394706,0.620887,,,,,,,,,,,,,, -463700,4.2607894,0.6217979,,,,,,,,,,,,,, -463800,4.3096256,0.6651481,,,,,,,,,,,,,, -463900,4.262106,0.51295537,,,,,,,,,,,,,, -464000,4.095574,0.5688583,,,,,,,,,,,,,, -464047,,,0.9599210619926452,0.1485941559076309,0.757099986076355,1.0420114994049072,50000.0,0.6318000555038452,1.824604034423828,10000.0,156121.84329080582,161565.87942004204,156121.84329080582,5403.175386667252,23.90141129493713,0.0 -464100,5.0515056,0.6679156,,,,,,,,,,,,,, -464200,4.8078117,0.6380962,,,,,,,,,,,,,, -464300,5.0223346,0.700095,,,,,,,,,,,,,, -464400,4.566243,0.6262926,,,,,,,,,,,,,, -464500,4.401873,0.65099376,,,,,,,,,,,,,, -464600,3.8745594,0.6048338,,,,,,,,,,,,,, -464700,3.8724415,0.48977363,,,,,,,,,,,,,, -464800,4.9424753,0.5842551,,,,,,,,,,,,,, -464900,4.864207,0.62958246,,,,,,,,,,,,,, -465000,4.425806,0.6125684,,,,,,,,,,,,,, -465100,5.143248,0.7035601,,,,,,,,,,,,,, -465200,4.502163,0.5657792,,,,,,,,,,,,,, -465300,4.830948,0.5766117,,,,,,,,,,,,,, -465400,4.3118496,0.5903506,,,,,,,,,,,,,, -465500,4.5547028,0.5848843,,,,,,,,,,,,,, -465565,,,0.960359513759613,0.1463554352521896,0.7572799921035767,1.0423952341079712,50000.0,0.6313000321388245,1.8224778175354004,10000.0,156631.992497921,162093.26338124275,156631.992497921,5420.25766825676,23.998153924942017,0.0 -465600,5.043512,0.682996,,,,,,,,,,,,,, -465700,4.265084,0.6049511,,,,,,,,,,,,,, -465800,4.8546886,0.6522085,,,,,,,,,,,,,, -465900,4.5377574,0.61708677,,,,,,,,,,,,,, -466000,4.4136734,0.63866115,,,,,,,,,,,,,, -466100,4.6092663,0.55946326,,,,,,,,,,,,,, -466200,4.567559,0.5861444,,,,,,,,,,,,,, -466300,4.388339,0.63169587,,,,,,,,,,,,,, -466400,4.6166477,0.54186666,,,,,,,,,,,,,, -466500,4.5684595,0.66930884,,,,,,,,,,,,,, -466600,4.439531,0.66746026,,,,,,,,,,,,,, -466700,4.711964,0.7055348,,,,,,,,,,,,,, -466800,4.358971,0.6195855,,,,,,,,,,,,,, -466900,4.113766,0.59199,,,,,,,,,,,,,, -467000,4.6565633,0.6234186,,,,,,,,,,,,,, -467082,,,0.9616350531578064,0.1432344317436218,0.7572799921035767,1.0408964157104492,50000.0,0.6315000057220459,1.8228623867034912,10000.0,157142.13254094124,162620.42335128784,157142.13254094124,5437.138303518295,24.08114218711853,0.0 -467100,4.7628593,0.6102118,,,,,,,,,,,,,, -467200,4.409712,0.69933146,,,,,,,,,,,,,, -467300,4.9155803,0.78070414,,,,,,,,,,,,,, -467400,4.476949,0.66655636,,,,,,,,,,,,,, -467500,5.058713,0.6316745,,,,,,,,,,,,,, -467600,4.340232,0.6259342,,,,,,,,,,,,,, -467700,3.9954212,0.4972077,,,,,,,,,,,,,, -467800,4.247135,0.6536913,,,,,,,,,,,,,, -467900,4.735193,0.6750642,,,,,,,,,,,,,, -468000,4.370177,0.6335391,,,,,,,,,,,,,, -468100,4.445069,0.67927504,,,,,,,,,,,,,, -468200,4.3853254,0.6198779,,,,,,,,,,,,,, -468300,4.437877,0.6022761,,,,,,,,,,,,,, -468400,4.604593,0.7050027,,,,,,,,,,,,,, -468500,4.2432165,0.51791424,,,,,,,,,,,,,, -468600,,,0.9614157676696776,0.146223098039627,0.7571600079536438,1.04023540019989,50000.0,0.6325000524520874,1.8226418495178225,10000.0,157652.25593709946,163147.48210787773,157652.25593709946,5453.913333177567,24.18593978881836,0.0 -468600,4.853472,0.715075,,,,,,,,,,,,,, -468700,4.6082497,0.64886475,,,,,,,,,,,,,, -468800,4.2651734,0.59363663,,,,,,,,,,,,,, -468900,4.5587516,0.67865366,,,,,,,,,,,,,, -469000,4.39212,0.5664227,,,,,,,,,,,,,, -469100,4.299207,0.6244228,,,,,,,,,,,,,, -469200,4.4370384,0.5986862,,,,,,,,,,,,,, -469300,4.5026793,0.60749996,,,,,,,,,,,,,, -469400,4.0559373,0.57398677,,,,,,,,,,,,,, -469500,4.8132334,0.6004555,,,,,,,,,,,,,, -469600,4.2789135,0.6118182,,,,,,,,,,,,,, -469700,4.403636,0.574396,,,,,,,,,,,,,, -469800,4.440074,0.5977814,,,,,,,,,,,,,, -469900,4.506206,0.647182,,,,,,,,,,,,,, -470000,4.545977,0.6331189,,,,,,,,,,,,,, -470100,4.885868,0.65467507,,,,,,,,,,,,,, -470116,,,0.9616549611091614,0.1469157487154007,0.7570599913597107,1.0405776500701904,50000.0,0.6308000087738037,1.8218297958374023,10000.0,158162.1717247963,163674.5309638977,158162.1717247963,5470.884748220444,24.29122257232666,0.0 -470200,4.928593,0.64661896,,,,,,,,,,,,,, -470300,4.6151347,0.63028616,,,,,,,,,,,,,, -470400,4.8417625,0.6446127,,,,,,,,,,,,,, -470500,4.4601197,0.6564477,,,,,,,,,,,,,, -470600,4.474305,0.6249712,,,,,,,,,,,,,, -470700,4.1481466,0.61485404,,,,,,,,,,,,,, -470800,4.279003,0.59933245,,,,,,,,,,,,,, -470900,4.0324616,0.5734185,,,,,,,,,,,,,, -471000,4.0561414,0.58308136,,,,,,,,,,,,,, -471100,4.3880754,0.58969253,,,,,,,,,,,,,, -471200,4.5945854,0.56415987,,,,,,,,,,,,,, -471300,4.2819467,0.600056,,,,,,,,,,,,,, -471400,4.738806,0.62645173,,,,,,,,,,,,,, -471500,4.1038623,0.57886016,,,,,,,,,,,,,, -471600,4.3390226,0.6085841,,,,,,,,,,,,,, -471633,,,0.960379421710968,0.1477285623550415,0.7573399543762207,1.0409228801727295,50000.0,0.6321000456809998,1.8239001035690308,10000.0,158672.21157503128,164201.5887157917,158672.21157503128,5487.740997314453,24.3963086605072,0.0 -471700,4.990852,0.6588354,,,,,,,,,,,,,, -471800,4.666316,0.62750053,,,,,,,,,,,,,, -471900,4.4171085,0.61474454,,,,,,,,,,,,,, -472000,4.3997526,0.5422334,,,,,,,,,,,,,, -472100,4.5497475,0.65326864,,,,,,,,,,,,,, -472200,4.4662566,0.61872077,,,,,,,,,,,,,, -472300,4.197724,0.54074776,,,,,,,,,,,,,, -472400,5.085746,0.65570223,,,,,,,,,,,,,, -472500,4.079817,0.58250475,,,,,,,,,,,,,, -472600,4.698841,0.6888617,,,,,,,,,,,,,, -472700,4.1945863,0.535406,,,,,,,,,,,,,, -472800,4.080839,0.58568275,,,,,,,,,,,,,, -472900,4.3576345,0.615099,,,,,,,,,,,,,, -473000,4.755591,0.6481366,,,,,,,,,,,,,, -473100,4.675109,0.63964,,,,,,,,,,,,,, -473150,,,0.9608178734779358,0.145049899816513,0.7572599649429321,1.0415680408477783,50000.0,0.6315000057220459,1.8239738941192627,10000.0,159182.24066758156,164728.6911456585,159182.24066758156,5504.653124094009,24.502280473709103,0.0 -473200,4.6086307,0.664628,,,,,,,,,,,,,, -473300,4.3861375,0.5793659,,,,,,,,,,,,,, -473400,4.567057,0.65131044,,,,,,,,,,,,,, -473500,4.4349957,0.62140363,,,,,,,,,,,,,, -473600,4.7731695,0.6986145,,,,,,,,,,,,,, -473700,4.4353347,0.58433574,,,,,,,,,,,,,, -473800,5.208858,0.6431364,,,,,,,,,,,,,, -473900,4.9057417,0.63041395,,,,,,,,,,,,,, -474000,4.9539537,0.6902926,,,,,,,,,,,,,, -474100,4.3717237,0.598124,,,,,,,,,,,,,, -474200,4.4022818,0.54556173,,,,,,,,,,,,,, -474300,4.64568,0.69525206,,,,,,,,,,,,,, -474400,4.5466003,0.6363281,,,,,,,,,,,,,, -474500,4.811621,0.6515287,,,,,,,,,,,,,, -474600,4.4713182,0.59813976,,,,,,,,,,,,,, -474667,,,0.9606385231018066,0.1474056243896484,0.7566199898719788,1.0405768156051636,50000.0,0.6312000155448914,1.8208348751068115,10000.0,159692.33936095238,165255.85982394218,159692.33936095238,5521.564659357071,24.605119466781616,0.0 -474700,4.746624,0.6652383,,,,,,,,,,,,,, -474800,5.224112,0.68316805,,,,,,,,,,,,,, -474900,4.8368974,0.697817,,,,,,,,,,,,,, -475000,4.5423207,0.5710685,,,,,,,,,,,,,, -475100,4.2655816,0.62666637,,,,,,,,,,,,,, -475200,4.3131537,0.60287946,,,,,,,,,,,,,, -475300,4.376384,0.59073097,,,,,,,,,,,,,, -475400,4.7182794,0.7029401,,,,,,,,,,,,,, -475500,4.66371,0.61650765,,,,,,,,,,,,,, -475600,4.546036,0.613294,,,,,,,,,,,,,, -475700,4.953682,0.62491393,,,,,,,,,,,,,, -475800,4.377304,0.6575002,,,,,,,,,,,,,, -475900,4.6907387,0.6174394,,,,,,,,,,,,,, -476000,4.3313837,0.65869033,,,,,,,,,,,,,, -476100,4.4592915,0.62768394,,,,,,,,,,,,,, -476184,,,0.9614556431770324,0.1444203853607177,0.7572399973869324,1.0410040616989136,50000.0,0.6312000155448914,1.823325634002685,10000.0,160202.31282019615,165782.85601115227,160202.31282019615,5538.428807735443,24.70886778831482,0.0 -476200,4.8504844,0.61431545,,,,,,,,,,,,,, -476300,4.4115005,0.59435976,,,,,,,,,,,,,, -476400,4.4806256,0.6830899,,,,,,,,,,,,,, -476500,4.2090826,0.58592904,,,,,,,,,,,,,, -476600,4.2659135,0.5916255,,,,,,,,,,,,,, -476700,4.522035,0.6222586,,,,,,,,,,,,,, -476800,5.1637077,0.65127325,,,,,,,,,,,,,, -476900,4.8718257,0.6748358,,,,,,,,,,,,,, -477000,4.571259,0.59684163,,,,,,,,,,,,,, -477100,4.8690014,0.62036556,,,,,,,,,,,,,, -477200,4.605016,0.63230526,,,,,,,,,,,,,, -477300,4.797464,0.67941296,,,,,,,,,,,,,, -477400,4.170589,0.6659562,,,,,,,,,,,,,, -477500,4.744338,0.6108447,,,,,,,,,,,,,, -477600,5.5334296,0.6130045,,,,,,,,,,,,,, -477700,5.035602,0.66179115,,,,,,,,,,,,,, -477701,,,0.9598811864852904,0.148447573184967,0.7573399543762207,1.0417107343673706,50000.0,0.6321000456809998,1.8233537673950195,10000.0,160712.27799630165,166309.99308776855,160712.27799630165,5555.441893815994,24.811065435409542,0.0 -477800,4.5991936,0.62737566,,,,,,,,,,,,,, -477900,4.677163,0.6391721,,,,,,,,,,,,,, -478000,4.42158,0.6376641,,,,,,,,,,,,,, -478100,4.9031305,0.6569344,,,,,,,,,,,,,, -478200,4.7005644,0.7012731,,,,,,,,,,,,,, -478300,4.246599,0.6108893,,,,,,,,,,,,,, -478400,4.3990483,0.6424118,,,,,,,,,,,,,, -478500,4.3694353,0.66651946,,,,,,,,,,,,,, -478600,4.614206,0.6847659,,,,,,,,,,,,,, -478700,4.8831315,0.6517186,,,,,,,,,,,,,, -478800,5.2154627,0.6712451,,,,,,,,,,,,,, -478900,4.7800426,0.68320495,,,,,,,,,,,,,, -479000,4.5173783,0.65570486,,,,,,,,,,,,,, -479100,5.1214485,0.6719166,,,,,,,,,,,,,, -479200,4.420805,0.6087929,,,,,,,,,,,,,, -479218,,,0.9594228267669678,0.1483974158763885,0.7570399641990662,1.0413779020309448,50000.0,0.631600022315979,1.8240758180618288,10000.0,161222.3278567791,166837.20202589035,161222.3278567791,5572.439893722534,24.91554069519043,0.0 -479300,4.058174,0.57953686,,,,,,,,,,,,,, -479400,4.3720775,0.6586171,,,,,,,,,,,,,, -479500,5.046793,0.72064817,,,,,,,,,,,,,, -479600,4.1074495,0.57766885,,,,,,,,,,,,,, -479700,4.7413917,0.6569415,,,,,,,,,,,,,, -479800,4.3391647,0.61200386,,,,,,,,,,,,,, -479900,4.795625,0.6466028,,,,,,,,,,,,,, -480000,3.9035037,0.56152284,,,,,,,,,,,,,, -480100,4.179852,0.553855,,,,,,,,,,,,,, -480200,5.006665,0.59668297,,,,,,,,,,,,,, -480300,5.2711563,0.647792,,,,,,,,,,,,,, -480400,4.9548674,0.63029265,,,,,,,,,,,,,, -480500,5.0981064,0.65633124,,,,,,,,,,,,,, -480600,4.1523156,0.5542177,,,,,,,,,,,,,, -480700,4.5736837,0.56411374,,,,,,,,,,,,,, -480735,,,0.961136758327484,0.1473008692264557,0.7571600079536438,1.0423572063446045,50000.0,0.6304000020027161,1.824426889419556,10000.0,161732.29912495613,167364.19173169136,161732.29912495613,5589.2941699028015,25.02281379699707,0.0 -480800,4.781408,0.6997988,,,,,,,,,,,,,, -480900,4.3578973,0.5705733,,,,,,,,,,,,,, -481000,4.573297,0.5961644,,,,,,,,,,,,,, -481100,4.7255344,0.6479759,,,,,,,,,,,,,, -481200,4.822593,0.65790206,,,,,,,,,,,,,, -481300,4.285772,0.6121316,,,,,,,,,,,,,, -481400,4.9163146,0.678022,,,,,,,,,,,,,, -481500,4.4645505,0.6950327,,,,,,,,,,,,,, -481600,4.9337816,0.64031285,,,,,,,,,,,,,, -481700,4.9838734,0.59259194,,,,,,,,,,,,,, -481800,4.3497725,0.5643738,,,,,,,,,,,,,, -481900,4.814975,0.5976133,,,,,,,,,,,,,, -482000,4.619614,0.6573092,,,,,,,,,,,,,, -482100,4.20481,0.5927389,,,,,,,,,,,,,, -482200,4.537266,0.62709343,,,,,,,,,,,,,, -482252,,,0.961355984210968,0.1467293053865432,0.7570799589157104,1.0419999361038208,50000.0,0.6324000358581543,1.8233898878097528,10000.0,162242.1644639969,167891.04339289665,162242.1644639969,5606.119745254517,25.12814474105835,0.0 -482300,4.4786024,0.5602957,,,,,,,,,,,,,, -482400,4.8693933,0.6458938,,,,,,,,,,,,,, -482500,3.930316,0.5497363,,,,,,,,,,,,,, -482600,4.611001,0.6283533,,,,,,,,,,,,,, -482700,4.1241913,0.56594265,,,,,,,,,,,,,, -482800,4.1320877,0.5811509,,,,,,,,,,,,,, -482900,4.5267305,0.64352685,,,,,,,,,,,,,, -483000,4.120275,0.5880207,,,,,,,,,,,,,, -483100,4.2270007,0.60921514,,,,,,,,,,,,,, -483200,4.1294026,0.5990984,,,,,,,,,,,,,, -483300,4.417418,0.5504203,,,,,,,,,,,,,, -483400,5.0401917,0.6871095,,,,,,,,,,,,,, -483500,4.501088,0.6601143,,,,,,,,,,,,,, -483600,4.09536,0.6084113,,,,,,,,,,,,,, -483700,4.5526514,0.60588366,,,,,,,,,,,,,, -483769,,,0.960160195827484,0.1499893218278885,0.7566999793052673,1.0416303873062134,50000.0,0.6326000094413757,1.8239266872406008,10000.0,162752.32794451714,168418.32730579376,162752.32794451714,5623.075105905533,25.23688054084778,0.0 -483800,4.3268204,0.5969467,,,,,,,,,,,,,, -483900,4.4910026,0.5921594,,,,,,,,,,,,,, -484000,4.3678126,0.620151,,,,,,,,,,,,,, -484100,4.2295113,0.61116487,,,,,,,,,,,,,, -484200,4.630763,0.6735246,,,,,,,,,,,,,, -484300,4.612898,0.6752237,,,,,,,,,,,,,, -484400,4.5041914,0.6355891,,,,,,,,,,,,,, -484500,4.725616,0.66460073,,,,,,,,,,,,,, -484600,4.846677,0.655715,,,,,,,,,,,,,, -484700,4.3637576,0.60567594,,,,,,,,,,,,,, -484800,5.134781,0.6719397,,,,,,,,,,,,,, -484900,4.5272126,0.6922562,,,,,,,,,,,,,, -485000,4.547089,0.63171506,,,,,,,,,,,,,, -485100,4.506106,0.67436826,,,,,,,,,,,,,, -485200,4.3067837,0.5663617,,,,,,,,,,,,,, -485286,,,0.9605787396430968,0.1452038437128067,0.7571199536323547,1.0410614013671875,50000.0,0.6328000426292419,1.823925495147705,10000.0,163262.20376372337,168945.26424503326,163262.20376372337,5639.970458984375,25.34450364112854,0.0 -485300,4.289194,0.6119518,,,,,,,,,,,,,, -485400,4.1524324,0.6275062,,,,,,,,,,,,,, -485500,4.504818,0.72568774,,,,,,,,,,,,,, -485600,4.6119514,0.6478163,,,,,,,,,,,,,, -485700,4.2814007,0.63630825,,,,,,,,,,,,,, -485800,4.6219115,0.58530414,,,,,,,,,,,,,, -485900,4.320076,0.6010492,,,,,,,,,,,,,, -486000,4.6350703,0.5838104,,,,,,,,,,,,,, -486100,4.59082,0.7032657,,,,,,,,,,,,,, -486200,4.335889,0.6060481,,,,,,,,,,,,,, -486300,4.525886,0.69270575,,,,,,,,,,,,,, -486400,4.298847,0.60807925,,,,,,,,,,,,,, -486500,4.137527,0.5430184,,,,,,,,,,,,,, -486600,4.397029,0.61295307,,,,,,,,,,,,,, -486700,4.636309,0.66323817,,,,,,,,,,,,,, -486800,4.462772,0.63359344,,,,,,,,,,,,,, -486803,,,0.9608777165412904,0.1467392593622207,0.7570199966430664,1.04235577583313,50000.0,0.6305000185966492,1.8246389627456665,10000.0,163772.209897995,169472.41901779175,163772.209897995,5656.945882558823,25.46109294891357,0.0 -486900,4.5294185,0.5740764,,,,,,,,,,,,,, -487000,4.0767145,0.5979098,,,,,,,,,,,,,, -487100,4.647384,0.60270536,,,,,,,,,,,,,, -487200,4.349058,0.60155535,,,,,,,,,,,,,, -487300,4.7502728,0.651081,,,,,,,,,,,,,, -487400,4.552668,0.6424098,,,,,,,,,,,,,, -487500,4.165813,0.6049246,,,,,,,,,,,,,, -487600,4.771596,0.65387714,,,,,,,,,,,,,, -487700,4.8978624,0.6243271,,,,,,,,,,,,,, -487800,4.489555,0.5591641,,,,,,,,,,,,,, -487900,4.5384107,0.6602362,,,,,,,,,,,,,, -488000,4.4839644,0.6760032,,,,,,,,,,,,,, -488100,4.643051,0.6790212,,,,,,,,,,,,,, -488200,4.8147774,0.62584716,,,,,,,,,,,,,, -488300,4.3014874,0.58514076,,,,,,,,,,,,,, -488319,,,0.9600605964660645,0.1470848321914672,0.7570799589157104,1.0419479608535769,50000.0,0.6313000321388245,1.822999119758606,10000.0,164282.22924995422,169999.69826436043,164282.22924995422,5673.986569881439,25.623329162597656,0.0 -488400,4.3943076,0.5660416,,,,,,,,,,,,,, -488500,4.5275126,0.6408858,,,,,,,,,,,,,, -488600,4.38474,0.6488714,,,,,,,,,,,,,, -488700,4.4882083,0.6970992,,,,,,,,,,,,,, -488800,4.733209,0.6813349,,,,,,,,,,,,,, -488900,4.532814,0.6093347,,,,,,,,,,,,,, -489000,4.1735353,0.508953,,,,,,,,,,,,,, -489100,4.97972,0.6719574,,,,,,,,,,,,,, -489200,4.3886614,0.6097966,,,,,,,,,,,,,, -489300,4.4809494,0.56315106,,,,,,,,,,,,,, -489400,4.3864126,0.65540004,,,,,,,,,,,,,, -489500,4.6812806,0.71162003,,,,,,,,,,,,,, -489600,4.622025,0.67233014,,,,,,,,,,,,,, -489700,4.516901,0.6132461,,,,,,,,,,,,,, -489800,4.5614824,0.6320194,,,,,,,,,,,,,, -489837,,,0.958984375,0.148769661784172,0.7572799921035767,1.041583776473999,50000.0,0.6313000321388245,1.8239089250564573,10000.0,164792.2272515297,170526.86529254913,164792.2272515297,5690.992618322372,25.73021769523621,0.0 -489900,4.858511,0.58972454,,,,,,,,,,,,,, -490000,4.404253,0.58933157,,,,,,,,,,,,,, -490100,4.852884,0.65320235,,,,,,,,,,,,,, -490200,4.1578393,0.5658183,,,,,,,,,,,,,, -490300,4.4564257,0.59557825,,,,,,,,,,,,,, -490400,4.3657804,0.60066164,,,,,,,,,,,,,, -490500,4.7095323,0.5882438,,,,,,,,,,,,,, -490600,4.648046,0.6481915,,,,,,,,,,,,,, -490700,4.544688,0.6649115,,,,,,,,,,,,,, -490800,4.2835126,0.58147526,,,,,,,,,,,,,, -490900,3.8327858,0.5151062,,,,,,,,,,,,,, -491000,4.9020996,0.6112831,,,,,,,,,,,,,, -491100,4.2191844,0.58198,,,,,,,,,,,,,, -491200,4.255231,0.62479496,,,,,,,,,,,,,, -491300,4.533297,0.62025225,,,,,,,,,,,,,, -491353,,,0.9603196382522584,0.144847884774208,0.7572199702262878,1.0405925512313845,50000.0,0.631600022315979,1.822745442390442,10000.0,165302.07374358177,171053.79780197144,165302.07374358177,5707.914986610413,25.83800363540649,0.0 -491400,4.3058047,0.6046258,,,,,,,,,,,,,, -491500,4.612462,0.6051252,,,,,,,,,,,,,, -491600,4.5354695,0.6125638,,,,,,,,,,,,,, -491700,5.2647843,0.74998176,,,,,,,,,,,,,, -491800,4.806642,0.6177288,,,,,,,,,,,,,, -491900,4.52749,0.61025536,,,,,,,,,,,,,, -492000,4.8215756,0.6276847,,,,,,,,,,,,,, -492100,4.631843,0.6151999,,,,,,,,,,,,,, -492200,4.4336915,0.602311,,,,,,,,,,,,,, -492300,4.9156384,0.6275697,,,,,,,,,,,,,, -492400,4.227675,0.5124366,,,,,,,,,,,,,, -492500,4.3349595,0.6279941,,,,,,,,,,,,,, -492600,4.094224,0.55834436,,,,,,,,,,,,,, -492700,4.4559946,0.60584223,,,,,,,,,,,,,, -492800,4.8376174,0.66191995,,,,,,,,,,,,,, -492869,,,0.9604192972183228,0.1478891372680664,0.7571799755096436,1.0413697957992554,50000.0,0.6320000290870667,1.8230407238006592,10000.0,165812.0014333725,171580.78747987747,165812.0014333725,5724.814933538437,25.942726135253903,0.0 -492900,4.431334,0.625857,,,,,,,,,,,,,, -493000,4.769281,0.57280046,,,,,,,,,,,,,, -493100,4.0303793,0.6233356,,,,,,,,,,,,,, -493200,4.51,0.5987247,,,,,,,,,,,,,, -493300,4.069623,0.5790268,,,,,,,,,,,,,, -493400,4.943199,0.61253345,,,,,,,,,,,,,, -493500,4.745415,0.6427975,,,,,,,,,,,,,, -493600,4.764192,0.67874867,,,,,,,,,,,,,, -493700,4.304006,0.637123,,,,,,,,,,,,,, -493800,4.567847,0.6372508,,,,,,,,,,,,,, -493900,4.3849525,0.55179244,,,,,,,,,,,,,, -494000,4.3169312,0.6052697,,,,,,,,,,,,,, -494100,4.331297,0.5719707,,,,,,,,,,,,,, -494200,5.0675435,0.6423183,,,,,,,,,,,,,, -494300,4.6796007,0.6317118,,,,,,,,,,,,,, -494387,,,0.961535394191742,0.1452343314886093,0.7576199769973755,1.041755199432373,50000.0,0.6312000155448914,1.8221620321273804,10000.0,166322.11298918724,172107.98474049568,166322.11298918724,5741.729859828949,26.05689764022827,0.0 -494400,4.4720564,0.630704,,,,,,,,,,,,,, -494500,5.136018,0.6082601,,,,,,,,,,,,,, -494600,4.78477,0.5413524,,,,,,,,,,,,,, -494700,4.591573,0.56961095,,,,,,,,,,,,,, -494800,4.7527847,0.63941187,,,,,,,,,,,,,, -494900,3.9777327,0.55828446,,,,,,,,,,,,,, -495000,4.785381,0.58954674,,,,,,,,,,,,,, -495100,4.809218,0.6899008,,,,,,,,,,,,,, -495200,4.5494256,0.596705,,,,,,,,,,,,,, -495300,4.148958,0.5629759,,,,,,,,,,,,,, -495400,4.6859446,0.6185317,,,,,,,,,,,,,, -495500,4.7498655,0.607046,,,,,,,,,,,,,, -495600,4.586253,0.63411343,,,,,,,,,,,,,, -495700,4.290696,0.6149479,,,,,,,,,,,,,, -495800,4.691958,0.6861392,,,,,,,,,,,,,, -495900,4.6119576,0.6134005,,,,,,,,,,,,,, -495904,,,0.9598413109779358,0.1481824964284896,0.7570199966430664,1.041217803955078,50000.0,0.6308000087738037,1.8224083185195925,10000.0,166832.02917313576,172634.86578130722,166832.02917313576,5758.526192903519,26.16771912574768,0.0 -496000,4.8068995,0.61486197,,,,,,,,,,,,,, -496100,5.0013356,0.707299,,,,,,,,,,,,,, -496200,4.5860085,0.62600714,,,,,,,,,,,,,, -496300,4.874531,0.6450001,,,,,,,,,,,,,, -496400,4.840782,0.6665878,,,,,,,,,,,,,, -496500,4.622998,0.6636532,,,,,,,,,,,,,, -496600,4.517449,0.6356748,,,,,,,,,,,,,, -496700,4.7833834,0.641303,,,,,,,,,,,,,, -496800,4.44349,0.64842904,,,,,,,,,,,,,, -496900,4.3619576,0.63766074,,,,,,,,,,,,,, -497000,4.8655157,0.7102927,,,,,,,,,,,,,, -497100,4.961804,0.6460866,,,,,,,,,,,,,, -497200,4.878756,0.636438,,,,,,,,,,,,,, -497300,4.5504208,0.6044899,,,,,,,,,,,,,, -497400,4.3358936,0.61347467,,,,,,,,,,,,,, -497422,,,0.9609375,0.1445536315441131,0.7573999762535095,1.041479468345642,50000.0,0.631600022315979,1.8234856128692627,10000.0,167341.9730234146,173161.72937083244,167341.9730234146,5775.282235145569,26.27513027191162,0.0 -497500,4.5910172,0.633555,,,,,,,,,,,,,, -497600,5.198859,0.6727281,,,,,,,,,,,,,, -497700,5.294628,0.6489875,,,,,,,,,,,,,, -497800,4.363924,0.6332,,,,,,,,,,,,,, -497900,4.522654,0.58490366,,,,,,,,,,,,,, -498000,4.080454,0.5759215,,,,,,,,,,,,,, -498100,4.7826815,0.58650684,,,,,,,,,,,,,, -498200,4.071689,0.64582914,,,,,,,,,,,,,, -498300,4.4138265,0.58503085,,,,,,,,,,,,,, -498400,5.0398703,0.7142418,,,,,,,,,,,,,, -498500,4.41324,0.621466,,,,,,,,,,,,,, -498600,4.440029,0.645473,,,,,,,,,,,,,, -498700,4.2505665,0.5237534,,,,,,,,,,,,,, -498800,4.7609086,0.6200069,,,,,,,,,,,,,, -498900,4.7252913,0.6049161,,,,,,,,,,,,,, -498940,,,0.9624322056770324,0.1402934044599533,0.7572799921035767,1.0417726039886477,50000.0,0.6319000124931335,1.822174072265625,10000.0,167852.10940527916,173688.8420562744,167852.10940527916,5792.087192296982,26.38953804969788,0.0 -499000,4.196179,0.6120734,,,,,,,,,,,,,, -499100,4.5531025,0.7060746,,,,,,,,,,,,,, -499200,4.702972,0.6485948,,,,,,,,,,,,,, -499300,4.426472,0.65956175,,,,,,,,,,,,,, -499400,4.453804,0.63817734,,,,,,,,,,,,,, -499500,4.644367,0.6006508,,,,,,,,,,,,,, -499600,4.369398,0.60880685,,,,,,,,,,,,,, -499700,4.572291,0.6306212,,,,,,,,,,,,,, -499800,4.3819942,0.6412814,,,,,,,,,,,,,, -499900,4.2526855,0.6594089,,,,,,,,,,,,,, -500000,4.988449,0.65201026,,,,,,,,,,,,,, -500100,3.9438527,0.62258226,,,,,,,,,,,,,, -500200,4.2870116,0.61535954,,,,,,,,,,,,,, -500300,4.1252265,0.6290223,,,,,,,,,,,,,, -500400,4.937342,0.629986,,,,,,,,,,,,,, -500457,,,0.962511956691742,0.1444217264652252,0.7576000094413757,1.041919469833374,50000.0,0.6327000260353088,1.8239601850509644,10000.0,168362.00852704048,174215.71746110916,168362.00852704048,5808.90158700943,26.495758056640625,0.0 -500500,4.645991,0.5998092,,,,,,,,,,,,,, -500600,4.0317173,0.56508315,,,,,,,,,,,,,, -500700,4.4753747,0.6396721,,,,,,,,,,,,,, -500800,4.3708315,0.5796211,,,,,,,,,,,,,, -500900,4.6045394,0.640413,,,,,,,,,,,,,, -501000,4.528598,0.6292993,,,,,,,,,,,,,, -501100,5.065487,0.6337638,,,,,,,,,,,,,, -501200,4.294887,0.66474944,,,,,,,,,,,,,, -501300,4.492872,0.58233845,,,,,,,,,,,,,, -501400,4.4256325,0.56689835,,,,,,,,,,,,,, -501500,4.8639956,0.6474956,,,,,,,,,,,,,, -501600,4.689638,0.70968926,,,,,,,,,,,,,, -501700,4.6262927,0.7042699,,,,,,,,,,,,,, -501800,4.59189,0.6822651,,,,,,,,,,,,,, -501900,4.476523,0.6402633,,,,,,,,,,,,,, -501974,,,0.9595623016357422,0.1485275775194168,0.7571799755096436,1.04130756855011,50000.0,0.6307000517845154,1.8220081329345703,10000.0,168871.98159265518,174743.40242886543,168871.98159265518,5826.445637702942,26.60779643058777,0.0 -502000,4.0371947,0.57727593,,,,,,,,,,,,,, -502100,4.1971054,0.62054574,,,,,,,,,,,,,, -502200,4.8680844,0.61763287,,,,,,,,,,,,,, -502300,4.951916,0.6531404,,,,,,,,,,,,,, -502400,4.35788,0.61190444,,,,,,,,,,,,,, -502500,4.0448966,0.60774314,,,,,,,,,,,,,, -502600,4.4389906,0.5877922,,,,,,,,,,,,,, -502700,4.1752276,0.5637508,,,,,,,,,,,,,, -502800,4.313107,0.58510745,,,,,,,,,,,,,, -502900,4.7953715,0.66825384,,,,,,,,,,,,,, -503000,4.5657773,0.6300798,,,,,,,,,,,,,, -503100,4.4608183,0.618524,,,,,,,,,,,,,, -503200,4.3080926,0.63156223,,,,,,,,,,,,,, -503300,4.529392,0.64192325,,,,,,,,,,,,,, -503400,4.512015,0.6559139,,,,,,,,,,,,,, -503490,,,0.9605388641357422,0.1469699740409851,0.7571399807929993,1.0414334535598757,50000.0,0.6318000555038452,1.8225680589675903,10000.0,169381.893055439,175270.41539549828,169381.893055439,5843.382396221161,26.7176833152771,0.0 -503500,4.462272,0.63425905,,,,,,,,,,,,,, -503600,4.6918483,0.6328279,,,,,,,,,,,,,, -503700,4.441492,0.64126027,,,,,,,,,,,,,, -503800,4.554354,0.6372869,,,,,,,,,,,,,, -503900,4.0673532,0.5601605,,,,,,,,,,,,,, -504000,4.691594,0.6349085,,,,,,,,,,,,,, -504100,4.850641,0.7110781,,,,,,,,,,,,,, -504200,4.5010233,0.603009,,,,,,,,,,,,,, -504300,5.409766,0.6926988,,,,,,,,,,,,,, -504400,4.4122853,0.57656634,,,,,,,,,,,,,, -504500,4.7347307,0.6341116,,,,,,,,,,,,,, -504600,4.529604,0.6540718,,,,,,,,,,,,,, -504700,4.612943,0.6500509,,,,,,,,,,,,,, -504800,4.714152,0.66060317,,,,,,,,,,,,,, -504900,4.2450213,0.6308922,,,,,,,,,,,,,, -505000,4.569392,0.6379311,,,,,,,,,,,,,, -505007,,,0.9607979655265808,0.145893707871437,0.7574999928474426,1.0407086610794067,50000.0,0.631100058555603,1.8230360746383667,10000.0,169891.74683499336,175797.27426624298,169891.74683499336,5860.224097967148,26.82484793663025,0.0 -505100,4.5664477,0.61124337,,,,,,,,,,,,,, -505200,4.4744067,0.6007065,,,,,,,,,,,,,, -505300,4.6190686,0.6482388,,,,,,,,,,,,,, -505400,4.38672,0.56887746,,,,,,,,,,,,,, -505500,4.4868317,0.57887447,,,,,,,,,,,,,, -505600,4.6852875,0.6230057,,,,,,,,,,,,,, -505700,4.429287,0.6606158,,,,,,,,,,,,,, -505800,4.287343,0.67120785,,,,,,,,,,,,,, -505900,4.936763,0.6407752,,,,,,,,,,,,,, -506000,4.5695853,0.6774559,,,,,,,,,,,,,, -506100,4.447115,0.5775546,,,,,,,,,,,,,, -506200,4.832435,0.668188,,,,,,,,,,,,,, -506300,4.274819,0.61672586,,,,,,,,,,,,,, -506400,4.6645064,0.6198033,,,,,,,,,,,,,, -506500,4.5946608,0.66458464,,,,,,,,,,,,,, -506524,,,0.9612165093421936,0.144617959856987,0.7572000026702881,1.0419772863388062,50000.0,0.6308000087738037,1.822520613670349,10000.0,170401.87842082977,176324.54439496994,170401.87842082977,5877.193679094315,26.938070058822632,0.0 -506600,4.198103,0.5620281,,,,,,,,,,,,,, -506700,4.400726,0.63378876,,,,,,,,,,,,,, -506800,4.43518,0.57989556,,,,,,,,,,,,,, -506900,4.884935,0.5745439,,,,,,,,,,,,,, -507000,4.5235915,0.6541533,,,,,,,,,,,,,, -507100,4.455868,0.669033,,,,,,,,,,,,,, -507200,4.76114,0.62467206,,,,,,,,,,,,,, -507300,4.6461554,0.6946596,,,,,,,,,,,,,, -507400,4.0966415,0.58346945,,,,,,,,,,,,,, -507500,4.257208,0.6668612,,,,,,,,,,,,,, -507600,4.1322327,0.6086919,,,,,,,,,,,,,, -507700,5.142344,0.7265978,,,,,,,,,,,,,, -507800,5.251101,0.6447264,,,,,,,,,,,,,, -507900,4.650415,0.58173794,,,,,,,,,,,,,, -508000,4.648533,0.63155264,,,,,,,,,,,,,, -508040,,,0.9621930718421936,0.1429073512554168,0.7571600079536438,1.04066002368927,50000.0,0.6322000026702881,1.8211873769760127,10000.0,170911.74804997444,176851.41265058515,170911.74804997444,5894.027221918106,27.048126935958862,0.0 -508100,4.478267,0.5861411,,,,,,,,,,,,,, -508200,4.4094872,0.5624354,,,,,,,,,,,,,, -508300,4.5345073,0.6028673,,,,,,,,,,,,,, -508400,4.3374496,0.58469445,,,,,,,,,,,,,, -508500,4.682677,0.6910105,,,,,,,,,,,,,, -508600,4.275753,0.61881256,,,,,,,,,,,,,, -508700,4.6932635,0.6714707,,,,,,,,,,,,,, -508800,4.5620775,0.6144046,,,,,,,,,,,,,, -508900,4.433331,0.6901371,,,,,,,,,,,,,, -509000,4.2621,0.6196506,,,,,,,,,,,,,, -509100,4.4663515,0.66059023,,,,,,,,,,,,,, -509200,4.5351896,0.6272759,,,,,,,,,,,,,, -509300,4.749971,0.6463543,,,,,,,,,,,,,, -509400,4.5281324,0.65674746,,,,,,,,,,,,,, -509500,4.6639624,0.6435718,,,,,,,,,,,,,, -509558,,,0.9604392051696776,0.148814707994461,0.7572999596595764,1.041728973388672,50000.0,0.6310000419616699,1.824790477752685,10000.0,171421.61796426773,177378.3165552616,171421.61796426773,5910.898756742477,27.1546368598938,0.0 -509600,4.256645,0.6391646,,,,,,,,,,,,,, -509700,4.5572405,0.6174885,,,,,,,,,,,,,, -509800,4.3262053,0.61186457,,,,,,,,,,,,,, -509900,4.568694,0.6095199,,,,,,,,,,,,,, -510000,4.695811,0.7400114,,,,,,,,,,,,,, -510100,4.2636137,0.56241035,,,,,,,,,,,,,, -510200,4.2635007,0.5659979,,,,,,,,,,,,,, -510300,4.3493223,0.6074095,,,,,,,,,,,,,, -510400,4.5518064,0.65679723,,,,,,,,,,,,,, -510500,4.3850126,0.6197254,,,,,,,,,,,,,, -510600,4.254719,0.56972384,,,,,,,,,,,,,, -510700,4.8960133,0.6201197,,,,,,,,,,,,,, -510800,4.533383,0.59196854,,,,,,,,,,,,,, -510900,4.752801,0.65511984,,,,,,,,,,,,,, -511000,4.157417,0.54991084,,,,,,,,,,,,,, -511075,,,0.960558831691742,0.1479910314083099,0.7572000026702881,1.0400274991989136,50000.0,0.6322000026702881,1.821224331855774,10000.0,171931.7805171013,177905.50328946114,171931.7805171013,5927.751242399216,27.26954126358032,0.0 -511100,4.4991856,0.64147186,,,,,,,,,,,,,, -511200,4.236922,0.6002455,,,,,,,,,,,,,, -511300,4.516514,0.61861837,,,,,,,,,,,,,, -511400,4.9775767,0.66232777,,,,,,,,,,,,,, -511500,4.190995,0.53827566,,,,,,,,,,,,,, -511600,4.565513,0.6436808,,,,,,,,,,,,,, -511700,4.423773,0.6287651,,,,,,,,,,,,,, -511800,5.1179385,0.64959776,,,,,,,,,,,,,, -511900,4.5478053,0.55822575,,,,,,,,,,,,,, -512000,4.6701097,0.65567017,,,,,,,,,,,,,, -512100,4.69882,0.6151976,,,,,,,,,,,,,, -512200,4.240812,0.6369301,,,,,,,,,,,,,, -512300,4.8140345,0.6717553,,,,,,,,,,,,,, -512400,4.275506,0.6360599,,,,,,,,,,,,,, -512500,4.4754972,0.6076947,,,,,,,,,,,,,, -512592,,,0.9605787396430968,0.1462783068418502,0.7572399973869324,1.0417357683181765,50000.0,0.6326000094413757,1.822581768035889,10000.0,172441.7371351719,178432.52074337006,172441.7371351719,5944.638758897781,27.386696577072144,0.0 -512600,4.7056403,0.73342276,,,,,,,,,,,,,, -512700,4.950279,0.691019,,,,,,,,,,,,,, -512800,4.655879,0.5846454,,,,,,,,,,,,,, -512900,5.2304864,0.7022755,,,,,,,,,,,,,, -513000,4.027273,0.57838124,,,,,,,,,,,,,, -513100,4.76773,0.648595,,,,,,,,,,,,,, -513200,4.187509,0.598391,,,,,,,,,,,,,, -513300,4.671516,0.61722827,,,,,,,,,,,,,, -513400,4.3773727,0.59078825,,,,,,,,,,,,,, -513500,4.019787,0.56128156,,,,,,,,,,,,,, -513600,4.2183337,0.60819936,,,,,,,,,,,,,, -513700,4.515681,0.61720234,,,,,,,,,,,,,, -513800,4.136966,0.60393846,,,,,,,,,,,,,, -513900,4.2173724,0.62648857,,,,,,,,,,,,,, -514000,4.224114,0.62087446,,,,,,,,,,,,,, -514100,4.356798,0.5860096,,,,,,,,,,,,,, -514109,,,0.9622528553009032,0.1430338323116302,0.7570399641990662,1.0404475927352903,50000.0,0.6318000555038452,1.8226438760757449,10000.0,172951.64312791824,178959.5411233902,172951.64312791824,5961.57798075676,27.50553941726685,0.0 -514200,4.3949604,0.65991443,,,,,,,,,,,,,, -514300,4.7408123,0.68782485,,,,,,,,,,,,,, -514400,4.7486296,0.6416152,,,,,,,,,,,,,, -514500,4.2924314,0.6126834,,,,,,,,,,,,,, -514600,3.9663095,0.57161343,,,,,,,,,,,,,, -514700,4.6312456,0.6648438,,,,,,,,,,,,,, -514800,4.197746,0.5400723,,,,,,,,,,,,,, -514900,4.258272,0.6128402,,,,,,,,,,,,,, -515000,4.503245,0.58823943,,,,,,,,,,,,,, -515100,4.292984,0.580925,,,,,,,,,,,,,, -515200,4.2432585,0.6057761,,,,,,,,,,,,,, -515300,4.4704022,0.64500815,,,,,,,,,,,,,, -515400,4.859903,0.66014194,,,,,,,,,,,,,, -515500,4.484415,0.5753179,,,,,,,,,,,,,, -515600,4.473571,0.64333737,,,,,,,,,,,,,, -515627,,,0.9596619606018066,0.1483250558376312,0.757319986820221,1.0403424501419067,50000.0,0.6319000124931335,1.8204569816589355,10000.0,173461.72400903702,179486.58996462822,173461.72400903702,5978.376556158066,27.6191668510437,0.0 -515700,4.4613843,0.6333646,,,,,,,,,,,,,, -515800,4.3855968,0.53576875,,,,,,,,,,,,,, -515900,5.329956,0.6437482,,,,,,,,,,,,,, -516000,4.969327,0.66115683,,,,,,,,,,,,,, -516100,4.498637,0.61482894,,,,,,,,,,,,,, -516200,4.6571913,0.59634024,,,,,,,,,,,,,, -516300,5.1576996,0.75883734,,,,,,,,,,,,,, -516400,5.0853157,0.7017213,,,,,,,,,,,,,, -516500,5.014886,0.61762863,,,,,,,,,,,,,, -516600,4.2790203,0.5804527,,,,,,,,,,,,,, -516700,4.351078,0.5903647,,,,,,,,,,,,,, -516800,4.684535,0.6765026,,,,,,,,,,,,,, -516900,4.990137,0.7189627,,,,,,,,,,,,,, -517000,4.6405115,0.6585871,,,,,,,,,,,,,, -517100,4.536096,0.7244304,,,,,,,,,,,,,, -517144,,,0.9599210619926452,0.1483697742223739,0.7571799755096436,1.0417215824127195,50000.0,0.6307000517845154,1.82349693775177,10000.0,173971.74058961868,180013.60146594048,173971.74058961868,5995.200742959976,27.73391819000244,0.0 -517200,4.448108,0.5947997,,,,,,,,,,,,,, -517300,4.4708514,0.57228047,,,,,,,,,,,,,, -517400,4.366251,0.5964542,,,,,,,,,,,,,, -517500,4.650386,0.6339677,,,,,,,,,,,,,, -517600,4.1142883,0.59502256,,,,,,,,,,,,,, -517700,4.8801146,0.6649856,,,,,,,,,,,,,, -517800,4.137894,0.61761856,,,,,,,,,,,,,, -517900,4.090462,0.5647037,,,,,,,,,,,,,, -518000,4.860233,0.65072334,,,,,,,,,,,,,, -518100,4.53622,0.6372873,,,,,,,,,,,,,, -518200,4.6678123,0.628682,,,,,,,,,,,,,, -518300,4.640459,0.59268713,,,,,,,,,,,,,, -518400,4.007656,0.55364466,,,,,,,,,,,,,, -518500,4.403571,0.61028767,,,,,,,,,,,,,, -518600,4.101605,0.57732344,,,,,,,,,,,,,, -518661,,,0.959004282951355,0.1498319208621978,0.7571600079536438,1.0412065982818604,50000.0,0.6325000524520874,1.823171734809876,10000.0,174481.84917855263,180540.7841868401,174481.84917855263,6012.095266580582,27.856269359588623,0.0 -518700,4.43365,0.6252092,,,,,,,,,,,,,, -518800,4.4165597,0.63288313,,,,,,,,,,,,,, -518900,4.472256,0.6118927,,,,,,,,,,,,,, -519000,4.5116825,0.5812086,,,,,,,,,,,,,, -519100,4.2976437,0.6672658,,,,,,,,,,,,,, -519200,4.126961,0.5494013,,,,,,,,,,,,,, -519300,4.817231,0.698947,,,,,,,,,,,,,, -519400,4.2276454,0.5583579,,,,,,,,,,,,,, -519500,4.776715,0.6726587,,,,,,,,,,,,,, -519600,4.506722,0.61836857,,,,,,,,,,,,,, -519700,4.307246,0.5673036,,,,,,,,,,,,,, -519800,4.5646434,0.6827826,,,,,,,,,,,,,, -519900,4.4902368,0.5960947,,,,,,,,,,,,,, -520000,4.5921164,0.59011024,,,,,,,,,,,,,, -520100,4.3909173,0.6195875,,,,,,,,,,,,,, -520178,,,0.9616549611091614,0.1446880847215652,0.757319986820221,1.0408968925476074,50000.0,0.6322000026702881,1.822198867797852,10000.0,174991.8268210888,181067.9081850052,174991.8268210888,6029.069609165192,27.970374822616577,0.0 -520200,4.2879963,0.57184803,,,,,,,,,,,,,, -520300,4.7003365,0.7017119,,,,,,,,,,,,,, -520400,4.6705337,0.6875582,,,,,,,,,,,,,, -520500,4.258575,0.5889237,,,,,,,,,,,,,, -520600,4.3386364,0.5482466,,,,,,,,,,,,,, -520700,4.6752434,0.6557538,,,,,,,,,,,,,, -520800,4.7602353,0.6750556,,,,,,,,,,,,,, -520900,4.473408,0.5952075,,,,,,,,,,,,,, -521000,4.063607,0.55187,,,,,,,,,,,,,, -521100,4.4564734,0.60199285,,,,,,,,,,,,,, -521200,6.0972204,0.76121944,,,,,,,,,,,,,, -521300,4.806921,0.64695156,,,,,,,,,,,,,, -521400,4.150382,0.6021436,,,,,,,,,,,,,, -521500,4.3341894,0.5633394,,,,,,,,,,,,,, -521600,4.5440726,0.6627541,,,,,,,,,,,,,, -521695,,,0.9610371589660645,0.1484816223382949,0.757319986820221,1.0408750772476196,50000.0,0.6320000290870667,1.8227356672286987,10000.0,175501.88168287277,181594.9708378315,175501.88168287277,6045.909396409988,28.082290172576904,0.0 -521700,4.4170384,0.6750797,,,,,,,,,,,,,, -521800,4.340252,0.6247957,,,,,,,,,,,,,, -521900,4.6202793,0.67231506,,,,,,,,,,,,,, -522000,4.5709257,0.6642804,,,,,,,,,,,,,, -522100,4.118915,0.62015456,,,,,,,,,,,,,, -522200,4.139726,0.5737469,,,,,,,,,,,,,, -522300,4.0018435,0.5763067,,,,,,,,,,,,,, -522400,4.102467,0.6095622,,,,,,,,,,,,,, -522500,4.5099583,0.63725674,,,,,,,,,,,,,, -522600,5.1012597,0.72486335,,,,,,,,,,,,,, -522700,4.421009,0.56284726,,,,,,,,,,,,,, -522800,4.722158,0.53944653,,,,,,,,,,,,,, -522900,4.5857363,0.5208051,,,,,,,,,,,,,, -523000,4.811489,0.6633529,,,,,,,,,,,,,, -523100,4.824871,0.70062804,,,,,,,,,,,,,, -523200,4.3595676,0.6720084,,,,,,,,,,,,,, -523212,,,0.96097731590271,0.1454038321971893,0.7571600079536438,1.0417425632476809,50000.0,0.6319000124931335,1.8241809606552124,10000.0,176011.82210946083,182122.0238277912,176011.82210946083,6062.849289655685,28.1988844871521,0.0 -523300,4.4171834,0.62370855,,,,,,,,,,,,,, -523400,4.272929,0.565905,,,,,,,,,,,,,, -523500,4.516043,0.65217155,,,,,,,,,,,,,, -523600,4.405228,0.5546339,,,,,,,,,,,,,, -523700,4.5157757,0.70220447,,,,,,,,,,,,,, -523800,4.4574637,0.5882828,,,,,,,,,,,,,, -523900,4.349988,0.5947121,,,,,,,,,,,,,, -524000,4.322577,0.59498,,,,,,,,,,,,,, -524100,4.421861,0.66190064,,,,,,,,,,,,,, -524200,4.3782244,0.6442658,,,,,,,,,,,,,, -524300,4.9735813,0.6119195,,,,,,,,,,,,,, -524400,4.799366,0.6557408,,,,,,,,,,,,,, -524500,4.4880233,0.59861493,,,,,,,,,,,,,, -524600,4.1382008,0.58266413,,,,,,,,,,,,,, -524700,4.348143,0.5818242,,,,,,,,,,,,,, -524728,,,0.9594427347183228,0.1490860879421234,0.7573599815368652,1.0410236120224,50000.0,0.6315000057220459,1.823151707649231,10000.0,176521.65455007553,182648.7777543068,176521.65455007553,6079.59999704361,28.31386494636536,0.0 -524800,4.627499,0.5660963,,,,,,,,,,,,,, -524900,4.596541,0.7086365,,,,,,,,,,,,,, -525000,4.481681,0.6545172,,,,,,,,,,,,,, -525100,4.568587,0.60556656,,,,,,,,,,,,,, -525200,4.6774797,0.6240761,,,,,,,,,,,,,, -525300,4.436353,0.62824035,,,,,,,,,,,,,, -525400,4.407913,0.61749744,,,,,,,,,,,,,, -525500,4.6984158,0.6668144,,,,,,,,,,,,,, -525600,4.197284,0.6254807,,,,,,,,,,,,,, -525700,4.244515,0.5542537,,,,,,,,,,,,,, -525800,4.2385883,0.62532187,,,,,,,,,,,,,, -525900,4.587581,0.6255306,,,,,,,,,,,,,, -526000,4.928352,0.6552195,,,,,,,,,,,,,, -526100,4.817806,0.6892064,,,,,,,,,,,,,, -526200,4.362935,0.5608584,,,,,,,,,,,,,, -526245,,,0.9598214030265808,0.148780271410942,0.7570799589157104,1.0420541763305664,50000.0,0.6308000087738037,1.8240805864334104,10000.0,177031.57284379005,183175.72368884087,177031.57284379005,6096.458354473114,28.42765522003174,0.0 -526300,5.0000806,0.7370669,,,,,,,,,,,,,, -526400,4.7400885,0.6638476,,,,,,,,,,,,,, -526500,4.452533,0.560291,,,,,,,,,,,,,, -526600,4.5731516,0.6008258,,,,,,,,,,,,,, -526700,4.0399904,0.5118965,,,,,,,,,,,,,, -526800,4.580624,0.65704465,,,,,,,,,,,,,, -526900,4.7539234,0.6070223,,,,,,,,,,,,,, -527000,4.231599,0.63048726,,,,,,,,,,,,,, -527100,4.6779,0.6292994,,,,,,,,,,,,,, -527200,4.858808,0.6905423,,,,,,,,,,,,,, -527300,4.4866643,0.6076126,,,,,,,,,,,,,, -527400,4.418503,0.5882697,,,,,,,,,,,,,, -527500,4.6755347,0.6437012,,,,,,,,,,,,,, -527600,4.17368,0.60612595,,,,,,,,,,,,,, -527700,4.756875,0.64175636,,,,,,,,,,,,,, -527761,,,0.960558831691742,0.1445672214031219,0.7572599649429321,1.041681885719299,50000.0,0.6322000026702881,1.8236315250396729,10000.0,177541.49831581116,183702.63179159164,177541.49831581116,6113.270159482956,28.54345893859864,0.0 -527800,4.53887,0.60632896,,,,,,,,,,,,,, -527900,4.4619527,0.5345046,,,,,,,,,,,,,, -528000,4.262069,0.624815,,,,,,,,,,,,,, -528100,4.501322,0.57325184,,,,,,,,,,,,,, -528200,4.5168843,0.5603139,,,,,,,,,,,,,, -528300,4.3903375,0.5454769,,,,,,,,,,,,,, -528400,4.2886357,0.55668235,,,,,,,,,,,,,, -528500,4.373012,0.6079615,,,,,,,,,,,,,, -528600,4.5807095,0.55625814,,,,,,,,,,,,,, -528700,4.6147723,0.6302788,,,,,,,,,,,,,, -528800,4.5460033,0.59612435,,,,,,,,,,,,,, -528900,4.9911084,0.6700486,,,,,,,,,,,,,, -529000,4.7327886,0.585582,,,,,,,,,,,,,, -529100,4.6744003,0.6273463,,,,,,,,,,,,,, -529200,4.489237,0.60543036,,,,,,,,,,,,,, -529277,,,0.9598811864852904,0.148500919342041,0.7572799921035767,1.0404560565948486,50000.0,0.6322000026702881,1.821385264396668,10000.0,178051.53967666626,184229.71953058243,178051.53967666626,6130.148310184479,28.656593799591064,0.0 -529300,4.5109243,0.64366263,,,,,,,,,,,,,, -529400,4.001048,0.6117358,,,,,,,,,,,,,, -529500,4.1337996,0.5979681,,,,,,,,,,,,,, -529600,4.873115,0.6177264,,,,,,,,,,,,,, -529700,4.977348,0.61496085,,,,,,,,,,,,,, -529800,4.203294,0.59946215,,,,,,,,,,,,,, -529900,4.4264555,0.60776865,,,,,,,,,,,,,, -530000,4.577377,0.59899247,,,,,,,,,,,,,, -530100,4.036475,0.51619995,,,,,,,,,,,,,, -530200,4.5876927,0.6186931,,,,,,,,,,,,,, -530300,4.4061627,0.58260083,,,,,,,,,,,,,, -530400,4.669397,0.6397038,,,,,,,,,,,,,, -530500,4.194365,0.60392094,,,,,,,,,,,,,, -530600,4.807929,0.59749883,,,,,,,,,,,,,, -530700,4.668174,0.622841,,,,,,,,,,,,,, -530794,,,0.9596819281578064,0.1456088274717331,0.7568999528884888,1.0421922206878662,50000.0,0.6319000124931335,1.823911190032959,10000.0,178561.46390724182,184756.6442449093,178561.46390724182,6146.977098941803,28.77240538597107,0.0 -530800,4.4645824,0.64952755,,,,,,,,,,,,,, -530900,4.711765,0.708045,,,,,,,,,,,,,, -531000,4.7783504,0.6781657,,,,,,,,,,,,,, -531100,4.329939,0.61782384,,,,,,,,,,,,,, -531200,4.6076055,0.66039664,,,,,,,,,,,,,, -531300,4.7455764,0.5866878,,,,,,,,,,,,,, -531400,4.9181304,0.6749514,,,,,,,,,,,,,, -531500,4.8022313,0.6140305,,,,,,,,,,,,,, -531600,4.2487283,0.599913,,,,,,,,,,,,,, -531700,4.6370187,0.62936646,,,,,,,,,,,,,, -531800,4.463498,0.62507576,,,,,,,,,,,,,, -531900,4.526908,0.6724347,,,,,,,,,,,,,, -532000,4.2278514,0.6400176,,,,,,,,,,,,,, -532100,4.146104,0.59712946,,,,,,,,,,,,,, -532200,4.5793624,0.68087554,,,,,,,,,,,,,, -532300,4.5956693,0.6135303,,,,,,,,,,,,,, -532310,,,0.9608976244926452,0.147421196103096,0.7572000026702881,1.0403904914855957,50000.0,0.6317000389099121,1.8210970163345337,10000.0,179071.4066681862,185283.5139591694,179071.4066681862,6163.73255443573,28.88776683807373,0.0 -532400,4.346623,0.5822042,,,,,,,,,,,,,, -532500,4.6065,0.6428834,,,,,,,,,,,,,, -532600,4.3131547,0.5930487,,,,,,,,,,,,,, -532700,4.238252,0.56759566,,,,,,,,,,,,,, -532800,5.433629,0.68195856,,,,,,,,,,,,,, -532900,5.166222,0.7049212,,,,,,,,,,,,,, -533000,4.1383486,0.5668064,,,,,,,,,,,,,, -533100,4.629847,0.63638306,,,,,,,,,,,,,, -533200,4.4904885,0.6649387,,,,,,,,,,,,,, -533300,4.456908,0.68416435,,,,,,,,,,,,,, -533400,4.7305503,0.6749909,,,,,,,,,,,,,, -533500,4.189426,0.57582724,,,,,,,,,,,,,, -533600,4.7087936,0.5976943,,,,,,,,,,,,,, -533700,4.6085777,0.6111351,,,,,,,,,,,,,, -533800,4.1301713,0.60377896,,,,,,,,,,,,,, -533827,,,0.9612762928009032,0.1458973139524459,0.7568999528884888,1.0407488346099854,50000.0,0.6320000290870667,1.8223789930343628,10000.0,179581.4375398159,185810.63350367543,179581.4375398159,6180.650722265244,29.001569032669067,0.0 -533900,4.913566,0.68597114,,,,,,,,,,,,,, -534000,4.6250443,0.623415,,,,,,,,,,,,,, -534100,4.354455,0.5983524,,,,,,,,,,,,,, -534200,4.9644775,0.6213882,,,,,,,,,,,,,, -534300,4.395528,0.57309043,,,,,,,,,,,,,, -534400,4.340623,0.58529216,,,,,,,,,,,,,, -534500,4.6652546,0.63824064,,,,,,,,,,,,,, -534600,4.969385,0.72313905,,,,,,,,,,,,,, -534700,4.681623,0.67952454,,,,,,,,,,,,,, -534800,4.4776206,0.6321517,,,,,,,,,,,,,, -534900,4.5126414,0.6356917,,,,,,,,,,,,,, -535000,4.542858,0.6399853,,,,,,,,,,,,,, -535100,4.458943,0.6591943,,,,,,,,,,,,,, -535200,4.726795,0.63216233,,,,,,,,,,,,,, -535300,5.007966,0.65508515,,,,,,,,,,,,,, -535344,,,0.9600605964660645,0.1473893821239471,0.756879985332489,1.040579080581665,50000.0,0.6310000419616699,1.8225830793380733,10000.0,180091.54822564125,186337.6497502327,180091.54822564125,6197.373930454254,29.12678384780884,0.0 -535400,4.625113,0.6199608,,,,,,,,,,,,,, -535500,4.5706267,0.6241497,,,,,,,,,,,,,, -535600,4.5515265,0.7023763,,,,,,,,,,,,,, -535700,4.490622,0.6486471,,,,,,,,,,,,,, -535800,4.3030343,0.61843514,,,,,,,,,,,,,, -535900,4.636393,0.6691708,,,,,,,,,,,,,, -536000,4.238076,0.6389089,,,,,,,,,,,,,, -536100,4.4468145,0.61380744,,,,,,,,,,,,,, -536200,4.46138,0.5943761,,,,,,,,,,,,,, -536300,4.163719,0.58041865,,,,,,,,,,,,,, -536400,4.407895,0.58833504,,,,,,,,,,,,,, -536500,4.4450116,0.7035824,,,,,,,,,,,,,, -536600,4.859099,0.6601654,,,,,,,,,,,,,, -536700,4.6911283,0.65286785,,,,,,,,,,,,,, -536800,4.538306,0.6318069,,,,,,,,,,,,,, -536861,,,0.9612762928009032,0.1430565118789672,0.757099986076355,1.0416611433029177,50000.0,0.6314000487327576,1.823699712753296,10000.0,180601.6016540528,186864.71676635745,180601.6016540528,6214.216796398163,29.241114854812626,0.0 -536900,4.858472,0.59055704,,,,,,,,,,,,,, -537000,4.5716643,0.6321756,,,,,,,,,,,,,, -537100,5.061957,0.60524684,,,,,,,,,,,,,, -537200,4.822976,0.6134722,,,,,,,,,,,,,, -537300,3.8919272,0.49235696,,,,,,,,,,,,,, -537400,4.6394677,0.6755496,,,,,,,,,,,,,, -537500,4.3782315,0.59938645,,,,,,,,,,,,,, -537600,4.0681624,0.51067126,,,,,,,,,,,,,, -537700,4.273539,0.55680585,,,,,,,,,,,,,, -537800,4.391616,0.62020123,,,,,,,,,,,,,, -537900,4.6813545,0.6665863,,,,,,,,,,,,,, -538000,4.2605505,0.6220511,,,,,,,,,,,,,, -538100,4.126879,0.56135035,,,,,,,,,,,,,, -538200,4.7181354,0.5998622,,,,,,,,,,,,,, -538300,4.431287,0.6436138,,,,,,,,,,,,,, -538378,,,0.9622727632522584,0.1430283784866333,0.7569199800491333,1.041762351989746,50000.0,0.6320000290870667,1.822485089302063,10000.0,181111.55942821503,187391.7662270069,181111.55942821503,6231.133362054825,29.359850883483887,0.0 -538400,4.344236,0.5595678,,,,,,,,,,,,,, -538500,4.5516477,0.66666555,,,,,,,,,,,,,, -538600,4.813311,0.67266405,,,,,,,,,,,,,, -538700,4.118789,0.5911714,,,,,,,,,,,,,, -538800,5.0941725,0.57323784,,,,,,,,,,,,,, -538900,4.6145535,0.65268224,,,,,,,,,,,,,, -539000,4.6691236,0.57949173,,,,,,,,,,,,,, -539100,4.3993516,0.63894045,,,,,,,,,,,,,, -539200,4.5311446,0.5596607,,,,,,,,,,,,,, -539300,4.5669827,0.5800929,,,,,,,,,,,,,, -539400,4.5294085,0.63476264,,,,,,,,,,,,,, -539500,3.8946273,0.5292933,,,,,,,,,,,,,, -539600,4.466211,0.6349802,,,,,,,,,,,,,, -539700,5.4753757,0.7501776,,,,,,,,,,,,,, -539800,4.3317747,0.54532284,,,,,,,,,,,,,, -539895,,,0.9616350531578064,0.1455331742763519,0.7573999762535095,1.0420688390731812,50000.0,0.6320000290870667,1.8241206407547,10000.0,181621.6319413185,187919.27256298065,181621.6319413185,6248.389245271683,29.48081016540528,0.0 -539900,4.8025537,0.6588092,,,,,,,,,,,,,, -540000,4.711647,0.659625,,,,,,,,,,,,,, -540100,4.501461,0.68230367,,,,,,,,,,,,,, -540200,4.628426,0.60538673,,,,,,,,,,,,,, -540300,4.5595164,0.5390388,,,,,,,,,,,,,, -540400,4.741459,0.6745974,,,,,,,,,,,,,, -540500,4.5222015,0.6359946,,,,,,,,,,,,,, -540600,4.4196367,0.5477747,,,,,,,,,,,,,, -540700,4.340611,0.6289014,,,,,,,,,,,,,, -540800,5.0921063,0.6028734,,,,,,,,,,,,,, -540900,4.5228763,0.60789335,,,,,,,,,,,,,, -541000,4.18636,0.5428693,,,,,,,,,,,,,, -541100,4.1260123,0.60174656,,,,,,,,,,,,,, -541200,4.569149,0.6090415,,,,,,,,,,,,,, -541300,4.6723,0.6674161,,,,,,,,,,,,,, -541400,4.2062035,0.608855,,,,,,,,,,,,,, -541411,,,0.959781527519226,0.1493196040391922,0.7568199634552002,1.0412883758544922,50000.0,0.6317000389099121,1.8230903148651123,10000.0,182131.56746602056,188446.3013682365,182131.56746602056,6265.310503005981,29.59652018547058,0.0 -541500,4.1086807,0.57945836,,,,,,,,,,,,,, -541600,4.368343,0.5771114,,,,,,,,,,,,,, -541700,4.4855237,0.6765558,,,,,,,,,,,,,, -541800,4.6224885,0.6224583,,,,,,,,,,,,,, -541900,4.471738,0.6380073,,,,,,,,,,,,,, -542000,4.3465223,0.62652105,,,,,,,,,,,,,, -542100,4.778003,0.70824325,,,,,,,,,,,,,, -542200,4.8462715,0.6273539,,,,,,,,,,,,,, -542300,4.647355,0.60055125,,,,,,,,,,,,,, -542400,4.8302274,0.5799064,,,,,,,,,,,,,, -542500,4.362499,0.6047681,,,,,,,,,,,,,, -542600,4.585731,0.65913445,,,,,,,,,,,,,, -542700,4.7065597,0.591703,,,,,,,,,,,,,, -542800,4.382001,0.58512,,,,,,,,,,,,,, -542900,4.752894,0.68249977,,,,,,,,,,,,,, -542928,,,0.9606983065605164,0.144793152809143,0.7574999928474426,1.0422248840332031,50000.0,0.6312000155448914,1.822980165481568,10000.0,182641.5040605068,188973.40545129776,182641.5040605068,6282.299289464951,29.71865010261536,0.0 -543000,5.189473,0.69940096,,,,,,,,,,,,,, -543100,4.396161,0.62607485,,,,,,,,,,,,,, -543200,4.1815753,0.5743538,,,,,,,,,,,,,, -543300,4.5042562,0.646296,,,,,,,,,,,,,, -543400,4.815489,0.66599774,,,,,,,,,,,,,, -543500,4.089507,0.5408284,,,,,,,,,,,,,, -543600,4.139967,0.54415566,,,,,,,,,,,,,, -543700,4.404998,0.62807286,,,,,,,,,,,,,, -543800,4.1254077,0.5841503,,,,,,,,,,,,,, -543900,4.2832823,0.54840374,,,,,,,,,,,,,, -544000,3.9502358,0.5377316,,,,,,,,,,,,,, -544100,3.8782206,0.54376054,,,,,,,,,,,,,, -544200,4.5641956,0.67403924,,,,,,,,,,,,,, -544300,4.546462,0.6217018,,,,,,,,,,,,,, -544400,4.50236,0.57484347,,,,,,,,,,,,,, -544445,,,0.961933970451355,0.1422090828418731,0.7573599815368652,1.0426530838012695,50000.0,0.6307000517845154,1.824191689491272,10000.0,183151.4329688549,189500.3107652664,183151.4329688549,6299.116809844971,29.8223876953125,0.0 -544500,4.4708924,0.6202867,,,,,,,,,,,,,, -544600,4.6515293,0.61022025,,,,,,,,,,,,,, -544700,4.568315,0.61066437,,,,,,,,,,,,,, -544800,4.8253217,0.6765125,,,,,,,,,,,,,, -544900,4.1682076,0.5934777,,,,,,,,,,,,,, -545000,4.131232,0.57696337,,,,,,,,,,,,,, -545100,4.772548,0.6726145,,,,,,,,,,,,,, -545200,4.462391,0.635061,,,,,,,,,,,,,, -545300,4.703428,0.6670656,,,,,,,,,,,,,, -545400,4.0888586,0.59510386,,,,,,,,,,,,,, -545500,4.27002,0.64027846,,,,,,,,,,,,,, -545600,4.983809,0.7069958,,,,,,,,,,,,,, -545700,4.3490686,0.5773057,,,,,,,,,,,,,, -545800,4.7453413,0.72517526,,,,,,,,,,,,,, -545900,4.317231,0.60477877,,,,,,,,,,,,,, -545962,,,0.9602798223495485,0.1469572782516479,0.7570199966430664,1.041181564331055,50000.0,0.6306000351905823,1.822177052497864,10000.0,183661.53258037567,190027.4561581612,183661.53258037567,6315.990729570389,29.937760829925537,0.0 -546000,5.3298817,0.67788845,,,,,,,,,,,,,, -546100,4.499755,0.53762656,,,,,,,,,,,,,, -546200,4.55717,0.61879355,,,,,,,,,,,,,, -546300,4.5227704,0.6269087,,,,,,,,,,,,,, -546400,4.220023,0.659848,,,,,,,,,,,,,, -546500,4.670002,0.6243868,,,,,,,,,,,,,, -546600,4.66505,0.5843559,,,,,,,,,,,,,, -546700,4.4788556,0.63245815,,,,,,,,,,,,,, -546800,4.193747,0.60227585,,,,,,,,,,,,,, -546900,4.2053266,0.57533866,,,,,,,,,,,,,, -547000,4.5055785,0.59073424,,,,,,,,,,,,,, -547100,4.6218495,0.66784084,,,,,,,,,,,,,, -547200,4.3326626,0.5740338,,,,,,,,,,,,,, -547300,4.342605,0.5824615,,,,,,,,,,,,,, -547400,4.8972654,0.6595719,,,,,,,,,,,,,, -547478,,,0.9627909660339355,0.1446569114923477,0.757099986076355,1.0409207344055176,50000.0,0.6310000419616699,1.8227473497390747,10000.0,184171.5425419808,190554.8959350586,184171.5425419808,6333.247943878174,30.054824352264404,0.0 -547500,4.342764,0.6175877,,,,,,,,,,,,,, -547600,4.3845367,0.6765442,,,,,,,,,,,,,, -547700,4.9565754,0.5999035,,,,,,,,,,,,,, -547800,4.7696767,0.68581635,,,,,,,,,,,,,, -547900,4.475074,0.64020014,,,,,,,,,,,,,, -548000,4.712407,0.6342634,,,,,,,,,,,,,, -548100,4.1338577,0.54924923,,,,,,,,,,,,,, -548200,4.323114,0.688346,,,,,,,,,,,,,, -548300,4.5126395,0.61578935,,,,,,,,,,,,,, -548400,4.3978004,0.572248,,,,,,,,,,,,,, -548500,4.422928,0.60467255,,,,,,,,,,,,,, -548600,4.580609,0.63317204,,,,,,,,,,,,,, -548700,4.214942,0.571851,,,,,,,,,,,,,, -548800,4.62639,0.7053053,,,,,,,,,,,,,, -548900,4.452764,0.64672637,,,,,,,,,,,,,, -548995,,,0.9591637253761292,0.1491780281066894,0.7572599649429321,1.0409399271011353,50000.0,0.6325000524520874,1.8207401037216189,10000.0,184681.5361790657,191081.8586373329,184681.5361790657,6350.045515298843,30.17049288749695,0.0 -549000,4.6608586,0.64854,,,,,,,,,,,,,, -549100,4.8331146,0.68773574,,,,,,,,,,,,,, -549200,4.1344,0.5794711,,,,,,,,,,,,,, -549300,4.108137,0.58084905,,,,,,,,,,,,,, -549400,4.4105363,0.5857817,,,,,,,,,,,,,, -549500,4.3533487,0.58021,,,,,,,,,,,,,, -549600,5.0822487,0.62266916,,,,,,,,,,,,,, -549700,4.485252,0.6503049,,,,,,,,,,,,,, -549800,4.335226,0.61467236,,,,,,,,,,,,,, -549900,4.7928596,0.6281945,,,,,,,,,,,,,, -550000,4.8841934,0.5756003,,,,,,,,,,,,,, -550100,4.8185887,0.63037074,,,,,,,,,,,,,, -550200,4.721419,0.6439537,,,,,,,,,,,,,, -550300,4.4453773,0.6576333,,,,,,,,,,,,,, -550400,4.125937,0.6189244,,,,,,,,,,,,,, -550500,4.95489,0.67484504,,,,,,,,,,,,,, -550512,,,0.9607381820678712,0.1471794098615646,0.7575199604034424,1.0413737297058103,50000.0,0.6324000358581543,1.8237340450286863,10000.0,185191.63391900063,191608.9362416268,185191.63391900063,6366.848225593567,30.291536569595337,0.0 -550600,4.4514766,0.62337655,,,,,,,,,,,,,, -550700,4.2367153,0.5676529,,,,,,,,,,,,,, -550800,4.7231455,0.62610936,,,,,,,,,,,,,, -550900,4.5053396,0.6345625,,,,,,,,,,,,,, -551000,4.695717,0.627087,,,,,,,,,,,,,, -551100,4.8031425,0.63925457,,,,,,,,,,,,,, -551200,3.8797812,0.5294072,,,,,,,,,,,,,, -551300,4.399789,0.5858592,,,,,,,,,,,,,, -551400,4.410242,0.6103607,,,,,,,,,,,,,, -551500,4.6884565,0.6016191,,,,,,,,,,,,,, -551600,4.363576,0.607345,,,,,,,,,,,,,, -551700,4.446941,0.5387511,,,,,,,,,,,,,, -551800,4.5345383,0.6254024,,,,,,,,,,,,,, -551900,4.3586125,0.61079645,,,,,,,,,,,,,, -552000,4.2595797,0.6373024,,,,,,,,,,,,,, -552029,,,0.9604990482330322,0.1474171131849289,0.7571600079536438,1.0401363372802734,50000.0,0.6317000389099121,1.8212229013442995,10000.0,185701.63338589668,192135.8907732964,185701.63338589668,6383.623881816864,30.41457509994507,0.0 -552100,4.2205367,0.61765766,,,,,,,,,,,,,, -552200,4.3190193,0.62187827,,,,,,,,,,,,,, -552300,4.2025323,0.61846614,,,,,,,,,,,,,, -552400,4.2299924,0.5720045,,,,,,,,,,,,,, -552500,4.8005857,0.58849096,,,,,,,,,,,,,, -552600,4.349651,0.6465478,,,,,,,,,,,,,, -552700,4.0812583,0.5812218,,,,,,,,,,,,,, -552800,4.748434,0.5934955,,,,,,,,,,,,,, -552900,4.282974,0.4887211,,,,,,,,,,,,,, -553000,4.348333,0.61134213,,,,,,,,,,,,,, -553100,4.300005,0.6067948,,,,,,,,,,,,,, -553200,4.4772973,0.58927506,,,,,,,,,,,,,, -553300,4.3957243,0.5407016,,,,,,,,,,,,,, -553400,4.0538907,0.59660536,,,,,,,,,,,,,, -553500,4.6785865,0.6459906,,,,,,,,,,,,,, -553547,,,0.96097731590271,0.1455092877149582,0.7566999793052673,1.0411295890808103,50000.0,0.6324000358581543,1.822754144668579,10000.0,186211.5637850761,192662.90041589737,186211.5637850761,6400.52764081955,30.53302764892578,0.0 -553600,4.300851,0.6074921,,,,,,,,,,,,,, -553700,4.450433,0.6236061,,,,,,,,,,,,,, -553800,4.9009724,0.63192904,,,,,,,,,,,,,, -553900,4.5441227,0.61599416,,,,,,,,,,,,,, -554000,4.6363816,0.59731084,,,,,,,,,,,,,, -554100,4.238369,0.6285956,,,,,,,,,,,,,, -554200,4.927597,0.60996807,,,,,,,,,,,,,, -554300,4.8892713,0.7034513,,,,,,,,,,,,,, -554400,4.5569773,0.6468381,,,,,,,,,,,,,, -554500,4.8766994,0.6486368,,,,,,,,,,,,,, -554600,4.4081154,0.6009535,,,,,,,,,,,,,, -554700,4.3712573,0.58750653,,,,,,,,,,,,,, -554800,4.4737854,0.64880747,,,,,,,,,,,,,, -554900,4.5189257,0.5815095,,,,,,,,,,,,,, -555000,4.493863,0.6994096,,,,,,,,,,,,,, -555064,,,0.9598413109779358,0.1486168652772903,0.7573399543762207,1.0403741598129272,50000.0,0.6317000389099121,1.8228280544281008,10000.0,186721.5231051445,193189.8435087204,186721.5231051445,6417.333734750748,30.65377688407898,0.0 -555100,4.456848,0.63455546,,,,,,,,,,,,,, -555200,4.7649727,0.6125149,,,,,,,,,,,,,, -555300,4.642277,0.66686743,,,,,,,,,,,,,, -555400,4.5219693,0.5963095,,,,,,,,,,,,,, -555500,4.446842,0.6437787,,,,,,,,,,,,,, -555600,4.6634817,0.69900113,,,,,,,,,,,,,, -555700,4.1357,0.62400305,,,,,,,,,,,,,, -555800,4.678378,0.65921533,,,,,,,,,,,,,, -555900,4.4156737,0.6501001,,,,,,,,,,,,,, -556000,4.6867623,0.60738444,,,,,,,,,,,,,, -556100,4.967515,0.7239038,,,,,,,,,,,,,, -556200,4.5801725,0.61151236,,,,,,,,,,,,,, -556300,4.5609803,0.58122385,,,,,,,,,,,,,, -556400,4.4186397,0.6216232,,,,,,,,,,,,,, -556500,4.4336343,0.6685669,,,,,,,,,,,,,, -556582,,,0.9602000713348388,0.1469131112098693,0.7574599981307983,1.0412509441375732,50000.0,0.6324000358581543,1.822922110557556,10000.0,187231.49054288864,193716.9131946564,187231.49054288864,6434.261531352997,30.77181887626648,0.0 -556600,4.5847664,0.6613666,,,,,,,,,,,,,, -556700,4.968695,0.60059077,,,,,,,,,,,,,, -556800,4.4280796,0.6194249,,,,,,,,,,,,,, -556900,4.4932203,0.6351977,,,,,,,,,,,,,, -557000,4.654406,0.6157898,,,,,,,,,,,,,, -557100,4.7171326,0.54745847,,,,,,,,,,,,,, -557200,4.4473166,0.66748357,,,,,,,,,,,,,, -557300,4.6083694,0.6237254,,,,,,,,,,,,,, -557400,4.256424,0.64028996,,,,,,,,,,,,,, -557500,4.37611,0.5571604,,,,,,,,,,,,,, -557600,4.4997964,0.5778031,,,,,,,,,,,,,, -557700,4.4423833,0.6034018,,,,,,,,,,,,,, -557800,4.332168,0.5408465,,,,,,,,,,,,,, -557900,4.258834,0.61746037,,,,,,,,,,,,,, -558000,4.446909,0.58675987,,,,,,,,,,,,,, -558099,,,0.9608777165412904,0.1464002430438995,0.7570799589157104,1.040795087814331,50000.0,0.6310000419616699,1.822549104690552,10000.0,187741.4332535267,194243.9229915142,187741.4332535267,6451.153831481934,30.892110347747803,0.0 -558100,4.5898066,0.5678364,,,,,,,,,,,,,, -558200,4.137187,0.59684116,,,,,,,,,,,,,, -558300,4.656715,0.7001816,,,,,,,,,,,,,, -558400,4.539121,0.72510207,,,,,,,,,,,,,, -558500,4.789489,0.70589197,,,,,,,,,,,,,, -558600,4.2880683,0.6599492,,,,,,,,,,,,,, -558700,4.2165465,0.5477944,,,,,,,,,,,,,, -558800,5.121651,0.63606477,,,,,,,,,,,,,, -558900,4.6341085,0.59748375,,,,,,,,,,,,,, -559000,4.2395973,0.6228389,,,,,,,,,,,,,, -559100,4.161889,0.54916644,,,,,,,,,,,,,, -559200,4.257158,0.62432337,,,,,,,,,,,,,, -559300,4.286799,0.5671488,,,,,,,,,,,,,, -559400,4.504225,0.6648484,,,,,,,,,,,,,, -559500,5.09365,0.62393713,,,,,,,,,,,,,, -559600,4.5134454,0.61984324,,,,,,,,,,,,,, -559617,,,0.9605388641357422,0.1474508792161941,0.7572799921035767,1.0420900583267212,50000.0,0.6313000321388245,1.8238762617111208,10000.0,188251.5854208469,194771.0418927669,188251.5854208469,6467.947027206421,31.01043510437012,0.0 -559700,4.7760077,0.575793,,,,,,,,,,,,,, -559800,4.617054,0.57454246,,,,,,,,,,,,,, -559900,4.4847293,0.5618497,,,,,,,,,,,,,, -559998,,,0.9612165093421936,0.1485250741243362,0.7571199536323547,1.039791464805603,50000.0,0.6313000321388245,1.8217592239379885,10000.0,188379.31034731865,194915.6333310604,188379.31034731865,6484.6761956214905,31.13318109512329,0.0 -559998,,,,,,,,,,,188379.31034731865,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/eval_measurements.csv deleted file mode 100644 index f868a8c46..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,555 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/num_examples,total_duration,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples -42.01347494125366,0.0,44.79174327850342,1,0,44.79174327850342,0.0010000000474974,6.907756805419922,10000,86.80535459518433,0.0009960937313735,6.907756328582764,0.0009999999310821,6.9077558517456055,50000 -63.51478552818298,0.0266435146331787,465.0541150569916,911,0,465.0541150569916,0.0240000002086162,6.074681758880615,10000,528.6460611820221,0.0321874991059303,5.927576541900635,0.0300399996340274,5.954702377319336,50000 -85.10202193260193,0.053812026977539,885.2891094684601,1870,0,885.2891094684601,0.0636000037193298,5.44067907333374,10000,970.5481154918672,0.0830664038658142,5.143368244171143,0.0786999985575676,5.1934967041015625,50000 -106.90083980560304,0.0830845832824707,1305.504187822342,2830,0,1305.504187822342,0.1094000041484832,4.899428367614746,10000,1412.643222808838,0.1516406238079071,4.460026264190674,0.1383199989795684,4.5624847412109375,50000 -129.67135000228882,0.1106207370758056,1725.7441012859344,3789,0,1725.7441012859344,0.1463000029325485,4.5577168464660645,10000,1855.7330009937289,0.2010351568460464,4.068780422210693,0.1882999986410141,4.150966644287109,50000 -152.84962821006775,0.1393589973449707,2145.8203287124634,4746,0,2145.8203287124634,0.1774000078439712,4.310952663421631,10000,2299.0677371025085,0.2476171851158142,3.7524003982543945,0.2291399985551834,3.865080118179321,50000 -174.51342248916626,0.1706171035766601,2565.7739176750183,5703,0,2565.7739176750183,0.230100005865097,3.9030966758728014,10000,2740.7684206962585,0.3254687488079071,3.208423376083374,0.2970999777317047,3.3675014972686768,50000 -196.54568314552307,0.2015841007232666,2985.7302856445312,6657,0,2985.7302856445312,0.2517000138759613,3.7387924194335938,10000,3182.8387632369995,0.3775781095027923,2.9110069274902344,0.329039990901947,3.177209615707397,50000 -218.7306222915649,0.2344107627868652,3405.709196329117,7611,0,3405.709196329117,0.2745999991893768,3.569232940673828,10000,3625.086359500885,0.3849999904632568,2.8280928134918213,0.3578599989414215,2.9794843196868896,50000 -243.68302154541016,0.2739176750183105,3825.88839006424,8554,0,3825.88839006424,0.2945000231266022,3.458648204803467,10000,4070.308994531632,0.4186523258686065,2.685232162475586,0.3840599954128265,2.860114574432373,50000 -272.23877477645874,0.3016986846923828,4245.836257696152,9502,0,4245.836257696152,0.3096000254154205,3.3151302337646484,10000,4518.890733718872,0.4482226371765136,2.503745555877685,0.4045199751853943,2.7109215259552,50000 -298.39640188217163,0.3389160633087158,4665.796049118042,10455,0,4665.796049118042,0.3285000026226043,3.205925226211548,10000,4965.096488714218,0.4553320109844208,2.4251720905303955,0.4303999841213226,2.5701441764831543,50000 -323.38474798202515,0.3693301677703857,5085.781652927399,11407,0,5085.781652927399,0.3394000232219696,3.1411850452423096,10000,5410.151679992676,0.4743359386920929,2.3213415145874023,0.4418199956417084,2.4998693466186523,50000 -351.4263126850128,0.3970980644226074,5506.2299082279205,12359,0,5506.2299082279205,0.3520000278949737,3.048286199569702,10000,5858.720902204514,0.4992382824420929,2.207763910293579,0.4544000029563904,2.4317727088928223,50000 -380.0723383426666,0.4321315288543701,5926.184532165527,13308,0,5926.184532165527,0.3541000187397003,3.0787501335144043,10000,6307.407933473587,0.5203320384025574,2.1063053607940674,0.4576999843120575,2.4315109252929688,50000 -407.3503110408783,0.4661853313446045,6346.236738204956,14256,0,6346.236738204956,0.3675000071525574,2.9688971042633057,10000,6754.822709321976,0.512402355670929,2.137631893157959,0.4780799746513366,2.320540189743042,50000 -433.8850626945496,0.4962055683135986,6766.212954044342,15204,0,6766.212954044342,0.3815000057220459,2.909775972366333,10000,7201.414595603943,0.5299218893051147,2.066849708557129,0.4887799918651581,2.277812957763672,50000 -462.300833940506,0.5308792591094971,7186.382284641266,16145,0,7186.382284641266,0.3874000310897827,2.8302128314971924,10000,7650.084501981735,0.5477734208106995,1.9316043853759768,0.5002399682998657,2.178302049636841,50000 -496.14386224746704,0.582331657409668,7606.686836242676,17090,0,7606.686836242676,0.3924000263214111,2.787764549255371,10000,8104.333922147751,0.5459765791893005,1.9445308446884155,0.5103999972343445,2.1310389041900635,50000 -532.077849149704,0.6126272678375244,8026.735034227371,18040,0,8026.735034227371,0.4018000066280365,2.779582262039185,10000,8560.396936416626,0.5554491877555847,1.92086660861969,0.5125399827957153,2.1145436763763428,50000 -569.0852816104889,0.6402697563171387,8446.924666404724,18989,0,8446.924666404724,0.4084000289440155,2.721212148666382,10000,9017.6736536026,0.5669531226158142,1.846354842185974,0.5207399725914001,2.081745147705078,50000 -605.6058855056763,0.6697046756744385,8867.289863586426,19938,0,8867.289863586426,0.40870001912117,2.7397501468658447,10000,9474.639935016632,0.5889257788658142,1.7574844360351562,0.5252599716186523,2.079762458801269,50000 -644.8831579685211,0.7056596279144287,9287.409376859663,20886,0,9287.409376859663,0.4234000146389007,2.6601674556732178,10000,9934.12448978424,0.5757421851158142,1.8028565645217896,0.5389999747276306,1.9935595989227293,50000 -679.1143245697021,0.7382750511169434,9707.7606818676,21830,0,9707.7606818676,0.415800005197525,2.6884138584136963,10000,10388.789666175842,0.5738281011581421,1.8270362615585327,0.5313599705696106,2.036630630493164,50000 -713.6164371967316,0.7727911472320557,10127.784230947496,22766,0,10127.784230947496,0.4213000237941742,2.65813946723938,10000,10843.401911258698,0.5947265625,1.7336976528167725,0.5411800146102905,1.994389057159424,50000 -748.1307110786438,0.8031463623046875,10547.981046676636,23713,0,10547.981046676636,0.43340003490448,2.610472679138184,10000,11298.194388628006,0.5895312428474426,1.7517651319503784,0.550059974193573,1.948309540748596,50000 -782.9098272323608,0.8321309089660645,10968.0647251606,24659,0,10968.0647251606,0.4394000172615051,2.570166826248169,10000,11753.13712143898,0.5965625047683716,1.7040070295333862,0.5518199801445007,1.9154651165008545,50000 -814.4276103973389,0.8629908561706543,11388.192962408066,25604,0,11388.192962408066,0.4369000196456909,2.60762095451355,10000,12204.864979743958,0.6005468368530273,1.726381778717041,0.550819993019104,1.963115096092224,50000 -849.3285005092621,1.937572956085205,11807.219497203829,26536,0,11807.219497203829,0.434000015258789,2.540038824081421,10000,12659.916893959044,0.6315820217132568,1.5329514741897583,0.5600399971008301,1.8702781200408936,50000 -882.5709488391876,1.9710197448730469,12227.3085603714,27477,0,12227.3085603714,0.4456000328063965,2.5280001163482666,10000,13113.331893920898,0.6069921851158142,1.6717681884765625,0.5662199854850769,1.8602055311203003,50000 -915.7219965457916,1.998776912689209,12647.270294189451,28419,0,12647.270294189451,0.4465000331401825,2.534619092941284,10000,13566.522515296936,0.6087499856948853,1.6305036544799805,0.5629400014877319,1.8604578971862795,50000 -950.0937445163728,2.026951313018799,13067.21803355217,29361,0,13067.21803355217,0.4513000249862671,2.5328025817871094,10000,14020.92025589943,0.6239452958106995,1.6023894548416138,0.5638799667358398,1.8720834255218504,50000 -984.6698379516602,2.058462142944336,13487.580243349075,30305,0,13487.580243349075,0.4594000279903412,2.4598233699798584,10000,14475.940197706224,0.6150586009025574,1.6110780239105225,0.5711399912834167,1.8235976696014404,50000 -1018.4912509918212,2.090975284576416,13907.861764431,31249,0,13907.861764431,0.4600000083446502,2.4539012908935547,10000,14930.126143455504,0.6220507621765137,1.5800942182540894,0.5763999819755554,1.7837181091308594,50000 -1052.9403524398804,2.1188464164733887,14328.242881059648,32191,0,14328.242881059648,0.4607000350952148,2.4564712047576904,10000,15385.033915281296,0.6280859112739563,1.5354851484298706,0.5817399621009827,1.7738856077194214,50000 -1088.0549836158752,2.1518115997314453,14748.527193307877,33134,0,14748.527193307877,0.4657000303268432,2.420597076416016,10000,15840.515821695328,0.6576367020606995,1.3888908624649048,0.5830000042915344,1.7361211776733398,50000 -1121.9403958320618,2.186378002166748,15168.447052955627,34074,0,15168.447052955627,0.4599000215530395,2.418983221054077,10000,16294.405767679214,0.6277148127555847,1.5349551439285278,0.5827400088310242,1.75555419921875,50000 -1155.1252624988556,2.215908527374268,15588.879315137863,35015,0,15588.879315137863,0.4638000130653381,2.43481183052063,10000,16748.10175061226,0.6364648342132568,1.5358233451843262,0.5851799845695496,1.7673289775848389,50000 -1188.193915605545,2.2564587593078613,16008.960027456284,35958,0,16008.960027456284,0.468500018119812,2.4204137325286865,10000,17201.34227848053,0.6478710770606995,1.4692319631576538,0.5889399647712708,1.74940288066864,50000 -1220.422137260437,2.285893678665161,16429.58711719513,36901,0,16429.58711719513,0.4686000347137451,2.3968710899353027,10000,17654.276985645294,0.63427734375,1.4953609704971311,0.5922600030899048,1.7099379301071167,50000 -1255.0818610191343,2.3198788166046143,16849.518322229385,37842,0,16849.518322229385,0.4766000211238861,2.344536066055298,10000,18108.95192480088,0.6421288847923279,1.446354627609253,0.5999400019645691,1.6714078187942505,50000 -1287.8897886276245,2.355440855026245,17269.525953292847,38784,0,17269.525953292847,0.4739000201225281,2.387554883956909,10000,18561.85310387612,0.6454101204872131,1.46431565284729,0.593239963054657,1.7185540199279783,50000 -1321.4097304344175,2.3911402225494385,17689.873439073563,39726,0,17689.873439073563,0.4764000177383423,2.3578004837036133,10000,19015.806552886963,0.6728320121765137,1.3413584232330322,0.5954399704933167,1.6972488164901731,50000 -1354.7683067321775,2.4212679862976074,18110.056242227554,40667,0,18110.056242227554,0.4778000116348266,2.346426010131836,10000,19469.427932024,0.6483203172683716,1.4345197677612305,0.6010800004005432,1.6640561819076538,50000 -1388.2648763656616,2.4519410133361816,18530.41775512696,41610,0,18530.41775512696,0.4812000095844269,2.312072277069092,10000,19923.3668589592,0.6526757478713989,1.4107017517089844,0.6041799783706665,1.6475238800048828,50000 -1420.8001787662506,2.481510877609253,18950.665768146515,42552,0,18950.665768146515,0.4817000329494476,2.320866584777832,10000,20376.23017168045,0.6664062142372131,1.3672758340835571,0.6040399670600891,1.6572825908660889,50000 -1455.2440390586853,2.5197227001190186,19370.83881545067,43492,0,19370.83881545067,0.482200026512146,2.3180975914001465,10000,20830.934995889664,0.6471484303474426,1.4733134508132937,0.6031799912452698,1.6737443208694458,50000 -1489.634738445282,2.552811861038208,19790.954607486725,44432,0,19790.954607486725,0.4879000186920166,2.2897379398345947,10000,21285.523493766785,0.6621484160423279,1.382601618766785,0.6111400127410889,1.6152799129486084,50000 -1522.9538979530334,2.5838775634765625,20211.304652929302,45376,0,20211.304652929302,0.4842000305652618,2.310580015182495,10000,21739.273277044296,0.6607617139816284,1.3963559865951538,0.6053199768066406,1.6574883460998535,50000 -1558.584661245346,2.618739366531372,20631.39994740486,46316,0,20631.39994740486,0.496800035238266,2.2552144527435303,10000,22195.084174633022,0.6971093416213989,1.220287799835205,0.6148999929428101,1.5943645238876345,50000 -1591.965892314911,2.6543779373168945,21051.45190000534,47252,0,21051.45190000534,0.4879000186920166,2.2941040992736816,10000,22648.6025223732,0.6590625047683716,1.4004347324371338,0.6109799742698669,1.6322163343429563,50000 -1625.8616099357605,2.691848039627075,21471.53701448441,48192,0,21471.53701448441,0.4850000143051147,2.307790756225586,10000,23102.67118740081,0.6631835699081421,1.3827235698699951,0.6143999695777893,1.61689555644989,50000 -1661.7927005290985,2.726196765899658,21891.7851896286,49133,0,21891.7851896286,0.490200012922287,2.286827802658081,10000,23558.934968948364,0.6769726276397705,1.3200808763504028,0.6147800087928772,1.6114550828933716,50000 -1695.9685862064362,2.759733200073242,22312.02948999405,50074,0,22312.02948999405,0.4976000189781189,2.2456860542297363,10000,24013.438791275024,0.6643164157867432,1.361477971076965,0.6191999912261963,1.578934907913208,50000 -1728.5848369598389,2.7923452854156494,22732.127825021744,51013,0,22732.127825021744,0.4964000284671783,2.2490456104278564,10000,24466.235282182693,0.6668554544448853,1.3606523275375366,0.6202799677848816,1.5822699069976809,50000 -1763.3116884231567,2.836010217666626,23152.065212011337,51952,0,23152.065212011337,0.4962000250816345,2.252509593963623,10000,24920.99333834648,0.6714843511581421,1.339386224746704,0.6193999648094177,1.5890146493911743,50000 -1796.709459066391,2.8703958988189697,23572.228355884552,52896,0,23572.228355884552,0.5056000351905823,2.2364532947540283,10000,25374.639178276066,0.6980078220367432,1.2068614959716797,0.6241999864578247,1.5524660348892212,50000 -1831.7415869235992,2.908019781112671,23992.23517799377,53833,0,23992.23517799377,0.4935000240802765,2.286892890930176,10000,25829.765172481537,0.6648241877555847,1.4142321348190308,0.6174799799919128,1.6313257217407229,50000 -1865.720906496048,2.943570375442505,24412.294757843018,54772,0,24412.294757843018,0.4990000128746032,2.2520761489868164,10000,26283.889540433884,0.6703906059265137,1.3580645322799685,0.6202799677848816,1.596779227256775,50000 -1901.078797578812,2.9773149490356445,24832.553787231445,55711,0,24832.553787231445,0.5027000308036804,2.2339158058166504,10000,26739.5900247097,0.6917187571525574,1.2428334951400757,0.6261199712753296,1.5562325716018677,50000 -1937.29074048996,3.00985050201416,25252.7198586464,56652,0,25252.7198586464,0.501300036907196,2.226841449737549,10000,27196.0502679348,0.6736913919448853,1.337292194366455,0.6258000135421753,1.5741112232208252,50000 -1970.1550877094269,3.059091567993164,25672.69250059128,57593,0,25672.69250059128,0.5081000328063965,2.189667224884033,10000,27648.986284017563,0.6827539205551147,1.2796573638916016,0.629040002822876,1.528598427772522,50000 -2004.6654317379,3.09879994392395,26092.81829667092,58530,0,26092.81829667092,0.5099000334739685,2.191647529602051,10000,28103.711234807968,0.6944921612739563,1.2365962266921997,0.6259999871253967,1.5299229621887207,50000 -2039.381314992905,3.1359636783599854,26512.790625810623,59471,0,26512.790625810623,0.5143000483512878,2.179502487182617,10000,28558.48644924164,0.7023242115974426,1.2232853174209597,0.637499988079071,1.5057755708694458,50000 -2072.8224205970764,3.1741859912872314,26932.726552009583,60412,0,26932.726552009583,0.510200023651123,2.1724071502685547,10000,29011.95117688179,0.682324230670929,1.2863421440124512,0.6322399973869324,1.51683247089386,50000 -2108.5590019226074,3.208047389984131,27352.794692516327,61352,0,27352.794692516327,0.5085000395774841,2.223173856735229,10000,29467.8389453888,0.6830468773841858,1.3047661781311035,0.6285399794578552,1.5588743686676023,50000 -2143.499714612961,3.2494382858276367,27772.926307678223,62292,0,27772.926307678223,0.5073000192642212,2.181467771530152,10000,29923.00162839889,0.703808605670929,1.203053593635559,0.6331799626350403,1.529718995094299,50000 -2178.395395040512,3.282611131668091,28192.87024521828,63229,0,28192.87024521828,0.5142000317573547,2.15094256401062,10000,30377.92234969139,0.6870507597923279,1.2610909938812256,0.6409800052642822,1.4879895448684692,50000 -2213.624122619629,3.317622423171997,28613.161401748657,64170,0,28613.161401748657,0.5209000110626221,2.133341789245605,10000,30833.526191473007,0.6971484422683716,1.215964436531067,0.6412799954414368,1.470037579536438,50000 -2247.422921895981,3.3557910919189453,29033.25125908852,65111,0,29033.25125908852,0.51910001039505,2.131485939025879,10000,31287.50227761269,0.7034765481948853,1.2014892101287842,0.6446999907493591,1.480310559272766,50000 -2282.7217609882355,3.39251971244812,29453.42460083961,66053,0,29453.42460083961,0.5210000276565552,2.150148630142212,10000,31743.060341358185,0.6955664157867432,1.2484641075134275,0.6377800107002258,1.503852128982544,50000 -2316.2620203495026,3.425584554672241,29873.4261200428,66995,0,29873.4261200428,0.5213000178337097,2.1135141849517822,10000,32196.684210062027,0.6943359375,1.2172327041625977,0.6418399810791016,1.4598710536956787,50000 -2351.888193130493,3.4705111980438232,30293.34878706932,67937,0,30293.34878706932,0.5241000056266785,2.116852045059204,10000,32652.327559947968,0.7025390267372131,1.1881595849990845,0.6430400013923645,1.46323823928833,50000 -2386.2355921268463,3.507035970687866,30713.28437423706,68878,0,30713.28437423706,0.5215000510215759,2.154394149780273,10000,33106.69675326347,0.7176562547683716,1.159710168838501,0.6420599818229675,1.499109983444214,50000 -2422.250189781189,3.54935359954834,31133.5650537014,69815,0,31133.5650537014,0.5277000069618225,2.1101725101470947,10000,33563.08368849754,0.6985741853713989,1.2329505681991575,0.6497399806976318,1.4623702764511108,50000 -2457.065773963928,3.5868256092071533,31553.61950206757,70754,0,31553.61950206757,0.5285000205039978,2.077733039855957,10000,34018.04113793373,0.7025781273841858,1.182090163230896,0.6479200124740601,1.435321807861328,50000 -2491.428635120392,3.620942354202272,31973.56262850761,71693,0,31973.56262850761,0.5290000438690186,2.096698760986328,10000,34472.430872917175,0.7128710746765137,1.1547907590866089,0.6517399549484253,1.451304316520691,50000 -2526.612436294556,3.659743547439575,32393.721923351288,72634,0,32393.721923351288,0.5267000198364258,2.096280097961426,10000,34927.862758398056,0.7041015625,1.1773505210876465,0.6512599587440491,1.4189000129699707,50000 -2558.8190701007843,3.701513767242432,32813.861067295074,73574,0,32813.861067295074,0.5236999988555908,2.107862949371338,10000,35380.29990744591,0.7027539014816284,1.1971248388290403,0.6504200100898743,1.4388371706008911,50000 -2593.974936962128,3.7460293769836426,33233.81961917877,74514,0,33233.81961917877,0.5236000418663025,2.103861093521118,10000,35835.50805258751,0.7067577838897705,1.1871877908706665,0.6526199579238892,1.4396685361862185,50000 -2627.8011043071747,3.7806742191314697,33653.806334257126,75456,0,33653.806334257126,0.5225000381469727,2.114449977874756,10000,36289.40531396866,0.7224413752555847,1.1160813570022583,0.648419976234436,1.4548654556274414,50000 -2661.0037302970886,3.822371244430542,34073.717185258865,76396,0,34073.717185258865,0.5279000401496887,2.0987911224365234,10000,36742.60909795761,0.70751953125,1.1857048273086548,0.6567800045013428,1.4257631301879885,50000 -2693.128019094467,3.871058940887451,34493.6725795269,77335,0,34493.6725795269,0.5332000255584717,2.056315660476685,10000,37194.7865831852,0.71533203125,1.1208672523498535,0.6605199575424194,1.3832670450210571,50000 -2727.1798944473267,3.917135238647461,34913.984763622284,78272,0,34913.984763622284,0.5371000170707703,2.056650161743164,10000,37649.24601197243,0.7160937190055847,1.1190756559371948,0.6571199893951416,1.4038711786270142,50000 -2760.8076634407043,3.956195592880249,35333.96521568298,79214,0,35333.96521568298,0.5348000526428223,2.051039934158325,10000,38102.94189476967,0.716015636920929,1.125848412513733,0.6620999574661255,1.3742302656173706,50000 -2793.860199928284,3.9993815422058105,35753.91080546379,80154,0,35753.91080546379,0.5333999991416931,2.053675651550293,10000,38556.032870054245,0.7089648246765137,1.1660743951797483,0.6569399833679199,1.411213755607605,50000 -2827.2101967334747,4.042644023895264,36173.98077607155,81093,0,36173.98077607155,0.5402000546455383,2.0339879989624023,10000,39009.54544401169,0.7257421612739563,1.083377242088318,0.6646599769592285,1.3715866804122925,50000 -2860.696093082428,4.080122947692871,36594.00684118271,82032,0,36594.00684118271,0.5360000133514404,2.0379531383514404,10000,39463.14464759827,0.7426952719688416,1.036612629890442,0.6611599922180176,1.392037034034729,50000 -2894.578936815262,4.116491079330444,37014.24144101143,82972,0,37014.24144101143,0.5356000065803528,2.0468437671661377,10000,39917.34832811356,0.7173437476158142,1.1371502876281738,0.6619399785995483,1.3823440074920654,50000 -2929.101461648941,4.154000282287598,37434.266563653946,83912,0,37434.266563653946,0.5403000116348267,2.028369903564453,10000,40371.982800245285,0.7253515720367432,1.1027836799621582,0.6668999791145325,1.3718284368515017,50000 -2961.1710698604584,4.19693660736084,37854.56867027283,84854,0,37854.56867027283,0.5380000472068787,2.027428150177002,10000,40824.446971178055,0.7344335913658142,1.063029170036316,0.6647999882698059,1.370414137840271,50000 -2994.7857854366302,4.246838569641113,38274.63127756119,85792,0,38274.63127756119,0.5444000363349915,2.021790266036988,10000,41278.22320103645,0.7220702767372131,1.1111823320388794,0.6682999730110168,1.3621902465820312,50000 -3028.2366137504578,4.2903008460998535,38694.55052232742,86732,0,38694.55052232742,0.5490000247955322,1.991302251815796,10000,41731.68585777283,0.72572261095047,1.0874943733215332,0.6696000099182129,1.3518980741500854,50000 -3060.714049100876,4.327760934829712,39114.70210146904,87674,0,39114.70210146904,0.5430000424385071,2.0079195499420166,10000,42184.402356147766,0.7313281297683716,1.0686558485031128,0.6709199547767639,1.350008249282837,50000 -3094.576204776764,4.378791570663452,39534.7329928875,88615,0,39534.7329928875,0.5525000095367432,1.9742674827575684,10000,42638.39556074143,0.7570117115974426,0.9418240785598756,0.6727799773216248,1.3221476078033447,50000 -3128.6608567237854,4.420300722122192,39954.71279430389,89555,0,39954.71279430389,0.5479000210762024,2.000091075897217,10000,43092.5505900383,0.7291601300239563,1.0930023193359375,0.675000011920929,1.347053289413452,50000 -3161.6687231063843,4.463085412979126,40374.94827008248,90496,0,40374.94827008248,0.5523000359535217,1.9852027893066408,10000,43545.88556671143,0.7349218726158142,1.0604491233825684,0.6757799983024597,1.330373764038086,50000 -3195.8801221847534,4.511308193206787,40795.17732858658,91433,0,40795.17732858658,0.5514000058174133,1.9640438556671145,10000,44000.42345237732,0.7516992092132568,0.9681676030158995,0.6752200126647949,1.3013813495635986,50000 -3227.666731834412,4.5512306690216064,41215.14496469498,92375,0,41215.14496469498,0.5496000051498413,1.9641692638397217,10000,44452.26744532585,0.7336523532867432,1.0666475296020508,0.6757599711418152,1.3290544748306274,50000 -3262.29194355011,4.605958700180054,41635.211717128754,93314,0,41635.211717128754,0.5519000291824341,1.9784936904907229,10000,44907.06348752976,0.7378124594688416,1.0565478801727295,0.677839994430542,1.3238400220870972,50000 -3296.409605741501,4.649278879165649,42055.45607948303,94257,0,42055.45607948303,0.5519000291824341,1.9706907272338867,10000,45361.51790237427,0.7453515529632568,1.0164759159088137,0.6785799860954285,1.315027952194214,50000 -3331.481355428696,4.691254615783691,42475.57057905197,95199,0,42475.57057905197,0.5520000457763672,1.978226661682129,10000,45816.79495024681,0.7621288895606995,0.9430046677589417,0.676539957523346,1.3137037754058838,50000 -3363.4759736061096,4.730247259140015,42895.6915576458,96140,0,42895.6915576458,0.556600034236908,1.9483904838562007,10000,46268.998499155045,0.7389257550239563,1.0446540117263794,0.680679976940155,1.3052889108657837,50000 -3397.667732000351,4.775901317596436,43315.840493917465,97078,0,43315.840493917465,0.5636000037193298,1.91632878780365,10000,46723.43449640274,0.7458202838897705,0.9984654784202576,0.6850000023841858,1.273817777633667,50000 -3433.5789165496826,4.816876649856567,43736.18922662735,98018,0,43736.18922662735,0.557200014591217,1.95401668548584,10000,47179.78457069397,0.7523632645606995,0.9861083626747132,0.6795200109481812,1.307973861694336,50000 -3467.1358416080475,4.856912851333618,44156.29066824913,98959,0,44156.29066824913,0.5526000261306763,1.950338363647461,10000,47633.53465199471,0.7362499833106995,1.0587328672409058,0.683899998664856,1.307403326034546,50000 -3502.182624578476,4.907997369766235,44576.2303814888,99898,0,44576.2303814888,0.5678000450134277,1.9063079357147217,10000,48088.621727228165,0.7546093463897705,0.9772475361824036,0.6888599991798401,1.2703135013580322,50000 -3535.5730526447296,4.947621583938599,44996.54544496536,100839,0,44996.54544496536,0.5581000447273254,1.9453516006469729,10000,48542.416513204575,0.7542577981948853,0.9589452147483826,0.6887199878692627,1.2671600580215454,50000 -3571.9376525878906,4.999526977539063,45416.76748251915,101780,0,45416.76748251915,0.5766000151634216,1.866647720336914,10000,48999.10462117195,0.771289050579071,0.90166836977005,0.6931999921798706,1.242873191833496,50000 -3607.411679744721,5.043881177902222,45836.98096561432,102719,0,45836.98096561432,0.5665000081062317,1.892961859703064,10000,49454.88519287109,0.7533984184265137,0.9679220914840698,0.6896799802780151,1.2444446086883545,50000 -3641.093369960785,5.086168050765991,46257.00506234169,103657,0,46257.00506234169,0.5705000162124634,1.8715754747390747,10000,49908.68248295784,0.7592577934265137,0.9390373229980468,0.6936799883842468,1.2363741397857666,50000 -3675.241497993469,5.125425815582275,46676.93702673912,104597,0,46676.93702673912,0.5694000124931335,1.8680320978164675,10000,50362.851440668106,0.77699214220047,0.8573799133300781,0.6948599815368652,1.222651720046997,50000 -3710.8573791980734,5.166118860244751,47097.02132034302,105536,0,47097.02132034302,0.5735000371932983,1.864123106002808,10000,50818.641756773,0.7588671445846558,0.9336987137794496,0.6984599828720093,1.214532732963562,50000 -3744.832273721695,5.208755970001221,47516.95656371117,106475,0,47516.95656371117,0.5738000273704529,1.854724287986756,10000,51272.64339399338,0.7658007740974426,0.917033851146698,0.6999399662017822,1.2070460319519043,50000 -3776.548688173294,5.253453969955444,47937.12469291687,107415,0,47937.12469291687,0.5769000053405762,1.873677492141724,10000,51724.62146115303,0.7672070264816284,0.9189236760139464,0.6965199708938599,1.229423761367798,50000 -3812.2145340442657,5.30767297744751,48357.32349872589,108285,0,48357.32349872589,0.5774000287055969,1.8574777841567995,10000,52180.58470964432,0.7604101300239563,0.9380630254745485,0.7011599540710449,1.219152331352234,50000 -3846.595866441727,5.356571435928345,48777.63230252266,109226,0,48777.63230252266,0.5776000022888184,1.840431928634644,10000,52635.37345814705,0.7692773342132568,0.918634593486786,0.7051799893379211,1.20063054561615,50000 -3881.148336172104,5.3964619636535645,49197.60548973084,110169,0,49197.60548973084,0.5714000463485718,1.9044029712677,10000,53089.98840522766,0.7645898461341858,0.961553394794464,0.6979799866676331,1.2633581161499023,50000 -3915.5525193214417,5.436882019042969,49617.58262228966,111109,0,49617.58262228966,0.5804000496864319,1.8377219438552856,10000,53544.45907402039,0.7912499904632568,0.8058376908302307,0.7072599530220032,1.1865533590316772,50000 -3949.753340244293,5.479412078857422,50037.971556425095,112049,0,50037.971556425095,0.5787000060081482,1.8381128311157229,10000,53999.140706539154,0.7679296731948853,0.9026365280151368,0.7026000022888184,1.195265769958496,50000 -3984.365024328232,5.520999431610107,50457.97093844414,112988,0,50457.97093844414,0.5800000429153442,1.8336036205291748,10000,54453.84243106842,0.774707019329071,0.8800515532493591,0.707319974899292,1.1850789785385132,50000 -4019.6268467903137,5.563291549682617,50877.95106720925,113928,0,50877.95106720925,0.5804000496864319,1.827962636947632,10000,54909.1758646965,0.7822265625,0.8601695895195007,0.707099974155426,1.1927919387817385,50000 -4053.151378154754,5.610148906707764,51298.191059827805,114868,0,51298.191059827805,0.5825000405311584,1.818506002426148,10000,55363.03579878807,0.7724804282188416,0.8807520866394043,0.7059599757194519,1.169217824935913,50000 -4086.5362679958334,5.664799213409424,51718.180342674255,115805,0,51718.180342674255,0.5877000093460083,1.7957428693771362,10000,55816.513998031616,0.78187495470047,0.8490674495697021,0.7106599807739258,1.1583318710327148,50000 -4122.379958868027,5.713786840438843,52138.19392943382,116744,0,52138.19392943382,0.5914000272750854,1.7908883094787598,10000,56272.46994638443,0.7876366972923279,0.8236556053161621,0.7130599617958069,1.1581529378890991,50000 -4156.973059415817,5.760483264923096,52558.495411634445,117685,0,52558.495411634445,0.5821000337600708,1.8017897605896,10000,56727.45986151695,0.8011718392372131,0.7712819576263428,0.7135599851608276,1.1575746536254885,50000 -4193.030814886093,5.806100606918335,52978.57805562019,118624,0,52978.57805562019,0.5915000438690186,1.8010717630386353,10000,57183.69558787346,0.7796484231948853,0.8606301546096802,0.7136399745941162,1.1567219495773315,50000 -4227.795293569565,5.846826076507568,53398.8308763504,119566,0,53398.8308763504,0.5922000408172607,1.7661399841308594,10000,57638.8034594059,0.7863476276397705,0.8422191143035889,0.7168399691581726,1.1445538997650146,50000 -4263.846667289734,5.894752502441406,53819.15520334244,120503,0,53819.15520334244,0.5898000001907349,1.7828012704849243,10000,58095.27610850334,0.7981640696525574,0.784917950630188,0.7168599963188171,1.13680100440979,50000 -4301.28515791893,5.94125771522522,54239.20383620262,121441,0,54239.20383620262,0.5955000519752502,1.764423966407776,10000,58552.858276844025,0.7860156297683716,0.8371008038520813,0.7204799652099609,1.1226909160614014,50000 -4335.413290023804,5.986897230148315,54659.44230914116,122374,0,54659.44230914116,0.5901000499725342,1.7791229486465454,10000,59007.31868457794,0.7886718511581421,0.8289256691932678,0.7191599607467651,1.1320635080337524,50000 -4370.508100986481,6.033722162246704,55079.53868961334,123313,0,55079.53868961334,0.5976999998092651,1.7560497522354126,10000,59462.60575866699,0.7957812547683716,0.767406165599823,0.7225199937820435,1.1092499494552612,50000 -4405.81605887413,6.080605506896973,55499.47381663322,124253,0,55499.47381663322,0.598300039768219,1.7344789505004885,10000,59917.94662761688,0.8089062571525574,0.7466806769371033,0.7249999642372131,1.1076456308364868,50000 -4441.247399568558,6.1333723068237305,55919.53238034248,125194,0,55919.53238034248,0.6028000116348267,1.7382676601409912,10000,60373.53852200508,0.7921679615974426,0.793973445892334,0.7214199900627136,1.105116844177246,50000 -4473.872217655182,6.182005167007446,56339.733157634735,126134,0,56339.733157634735,0.6052000522613525,1.73361337184906,10000,60826.4618768692,0.8014843463897705,0.7683876156806946,0.7262399792671204,1.1003386974334717,50000 -4508.729408502579,6.23290491104126,56759.89336299896,127073,0,56759.89336299896,0.5973000526428223,1.7537530660629272,10000,61281.57898569107,0.8082616925239563,0.7514926791191101,0.7242000102996826,1.1179012060165403,50000 -4543.34293794632,6.276918172836304,57180.19188570976,128013,0,57180.19188570976,0.6034000515937805,1.7308070659637451,10000,61736.58455133438,0.7968358993530273,0.7883461713790894,0.7251600027084351,1.1023790836334229,50000 -4577.940878629684,6.324890613555908,57600.443110466,128952,0,57600.443110466,0.6049000024795532,1.71925950050354,10000,62191.53056240082,0.8012109398841858,0.7621845602989197,0.7287200093269348,1.0881730318069458,50000 -4612.796590805054,6.371863126754761,58020.502247571945,129891,0,58020.502247571945,0.6082000136375427,1.7195762395858765,10000,62646.54111742973,0.8115820288658142,0.7281399965286255,0.7297999858856201,1.0883889198303225,50000 -4649.020402431488,6.425073385238648,58440.81789493561,130828,0,58440.81789493561,0.6073000431060791,1.7005314826965332,10000,63103.182911872864,0.8115429282188416,0.721081018447876,0.7346400022506714,1.0667911767959597,50000 -4685.490235805512,6.472413301467896,58860.99866771698,131768,0,58860.99866771698,0.6079000234603882,1.708517074584961,10000,63559.92962670326,0.8031835556030273,0.7564061284065247,0.7330399751663208,1.073096513748169,50000 -4721.690928220749,6.516724586486816,59281.00724768639,132706,0,59281.00724768639,0.6050000190734863,1.7248249053955078,10000,64016.23230576515,0.8125976324081421,0.7448723912239075,0.7306399941444397,1.090269684791565,50000 -4756.766691684723,6.560953378677368,59701.293223142624,133645,0,59701.293223142624,0.6089000105857849,1.680550456047058,10000,64471.68787431717,0.8241796493530273,0.6636848449707031,0.7350800037384033,1.0521515607833862,50000 -4792.405487298965,6.606303691864014,60121.38954138756,134584,0,60121.38954138756,0.6104000210762024,1.690237045288086,10000,64927.51717305184,0.8122656345367432,0.7251321077346802,0.7372199892997742,1.0536880493164062,50000 -4828.497610330582,6.658292770385742,60541.42193317413,135517,0,60541.42193317413,0.6121000051498413,1.66946280002594,10000,65383.74241781235,0.8183202743530273,0.7013428807258606,0.7394199967384338,1.044283747673035,50000 -4864.676825284958,6.702035903930664,60961.35953044891,136454,0,60961.35953044891,0.6099000573158264,1.6748216152191162,10000,65839.95121264458,0.8281640410423279,0.6682937741279602,0.7404800057411194,1.0448720455169678,50000 -4900.667014360428,6.751345872879028,61381.284125328064,137394,0,61381.284125328064,0.6151000261306763,1.670432448387146,10000,66295.96452879906,0.8191210627555847,0.6811794638633728,0.7382599711418152,1.0337536334991455,50000 -4935.666965246201,6.801737308502197,61801.64013171196,138333,0,61801.64013171196,0.6124000549316406,1.660211443901062,10000,66751.41964411736,0.8222265243530273,0.6762940287590027,0.7418199777603149,1.029226779937744,50000 -4970.997862577438,6.847434520721436,62221.607963085175,139273,0,62221.607963085175,0.6152000427246094,1.64912211894989,10000,67206.81400728226,0.8302929401397705,0.6423133015632629,0.7434799671173096,1.0174846649169922,50000 -5006.979150533676,6.8912672996521,62641.6125395298,140211,0,62641.6125395298,0.6163000464439392,1.6549649238586426,10000,67662.89258980751,0.8381249904632568,0.6069045662879944,0.7433399558067322,1.016687273979187,50000 -5039.779830694199,6.936065435409546,63061.63824224472,141147,0,63061.63824224472,0.6189000010490417,1.642565369606018,10000,68115.81379771233,0.82630854845047,0.6645239591598511,0.7433599829673767,1.016500473022461,50000 -5073.741352796555,6.995683908462524,63481.839210510254,142086,0,63481.839210510254,0.6199000477790833,1.6374740600585938,10000,68570.085460186,0.8313671946525574,0.6447664499282837,0.749019980430603,1.0095096826553345,50000 -5109.310643911362,7.047288179397583,63902.15845608711,143023,0,63902.15845608711,0.6215000152587891,1.634261131286621,10000,69026.07445836067,0.8359569907188416,0.6201549172401428,0.748699963092804,0.999498724937439,50000 -5144.327417612076,7.09255576133728,64322.31871080399,143960,0,64322.31871080399,0.6253000497817993,1.6217032670974731,10000,69481.3456788063,0.8331835865974426,0.6416350603103638,0.7472400069236755,1.003479242324829,50000 -5178.6849110126495,7.1408984661102295,64742.95211935043,144901,0,64742.95211935043,0.629300057888031,1.608303427696228,10000,69936.43508267403,0.8355273008346558,0.6195220947265625,0.750059962272644,0.98756343126297,50000 -5211.518479347229,7.18864107131958,65162.93309760094,145837,0,65162.93309760094,0.6294000148773193,1.6080838441848757,10000,70389.34637379646,0.8406445384025574,0.5987057685852051,0.752240002155304,0.9832723736763,50000 -5247.289430618286,7.246530294418335,65583.1298816204,146775,0,65583.1298816204,0.6300000548362732,1.6063361167907717,10000,70845.42119407654,0.8557812571525574,0.5449709296226501,0.7529000043869019,0.9782390594482422,50000 -5282.982690811157,7.298294305801392,66003.42449641228,147714,0,66003.42449641228,0.6282000541687012,1.6120684146881104,10000,71301.50976467133,0.8352148532867432,0.6282013654708862,0.7502999901771545,0.9910974502563475,50000 -5318.54127407074,7.344376087188721,66423.42095088959,148655,0,66423.42095088959,0.629300057888031,1.6120909452438354,10000,71757.16052174568,0.8431054353713989,0.6112679243087769,0.7550599575042725,0.9850051999092102,50000 -5353.161323547363,7.391125440597534,66843.46200037003,149590,0,66843.46200037003,0.631600022315979,1.5866049528121948,10000,72211.91704010963,0.8514257669448853,0.5645542740821838,0.7563799619674683,0.9604097604751588,50000 -5390.246699333191,7.440268039703369,67263.53252577782,150527,0,67263.53252577782,0.6347000598907471,1.5730087757110596,10000,72669.17042684555,0.8474218845367432,0.5692348480224609,0.7574599981307983,0.9543402791023254,50000 -5426.055099248886,7.485986948013306,67683.55815124512,151464,0,67683.55815124512,0.6391000151634216,1.577191710472107,10000,73125.09868621826,0.8489453196525574,0.5678457021713257,0.757099986076355,0.9635516405105592,50000 -5459.991923809052,7.542108774185181,68104.11196637154,152401,0,68104.11196637154,0.641800045967102,1.5591038465499878,10000,73579.6944000721,0.8550195097923279,0.5494584441184998,0.7616399526596069,0.9430665373802184,50000 -5495.893121242523,7.593789339065552,68524.06725406647,153339,0,68524.06725406647,0.641800045967102,1.5675500631332395,10000,74035.65158557892,0.8586718440055847,0.5393452048301697,0.7623199820518494,0.9501993656158448,50000 -5531.262712240219,7.649880170822143,68944.03726816177,154276,0,68944.03726816177,0.6407000422477722,1.5654865503311155,10000,74491.09753751755,0.8527148365974426,0.5654762387275696,0.7623599767684937,0.9489037990570068,50000 -5565.0824184417725,7.698629856109619,69364.2235994339,155216,0,69364.2235994339,0.6420000195503235,1.5497945547103882,10000,74945.20179843903,0.8570116758346558,0.52569180727005,0.7622999548912048,0.9294662475585938,50000 -5600.47674703598,7.753859281539917,69784.42750167847,156156,0,69784.42750167847,0.6428000330924988,1.5421534776687622,10000,75400.90425562859,0.8631835579872131,0.5074052214622498,0.7647199630737305,0.928766131401062,50000 -5635.178020000458,8.052583456039429,70204.14212560654,157093,0,70204.14212560654,0.6445000171661377,1.5387070178985596,10000,75855.6677236557,0.8556640148162842,0.5329411625862122,0.766539990901947,0.919713020324707,50000 -5671.48996591568,8.11124873161316,70624.15113449097,158031,0,70624.15113449097,0.6491000056266785,1.5364491939544678,10000,76312.09645199776,0.86146479845047,0.5181689858436584,0.7660399675369263,0.9252774715423584,50000 -5707.205604314804,8.161112070083618,71044.10197019577,158968,0,71044.10197019577,0.6464000344276428,1.5262484550476074,10000,76767.86173796654,0.8666796684265137,0.4957504868507385,0.767799973487854,0.9111968278884888,50000 -5741.168454885483,8.218778133392334,71464.28036999702,159905,0,71464.28036999702,0.6496000289916992,1.5200276374816897,10000,77222.11035394669,0.8684179782867432,0.4979733824729919,0.7691599726676941,0.9126954674720764,50000 -5775.471606492996,8.270018339157104,71884.30971646309,160840,0,71884.30971646309,0.6538000106811523,1.5107396841049194,10000,77676.5434346199,0.8680663704872131,0.4921901226043701,0.770039975643158,0.9016311764717102,50000 -5811.875743627548,8.321635723114014,72304.306691885,161775,0,72304.306691885,0.6493000388145447,1.5251152515411377,10000,78133.04519224167,0.870898425579071,0.4825400114059448,0.7693799734115601,0.9072303175926208,50000 -5846.686740875244,8.37223219871521,72724.41666007042,162714,0,72724.41666007042,0.6517000198364258,1.525587797164917,10000,78588.06591463089,0.8746874928474426,0.4815309941768646,0.7702800035476685,0.9123331904411316,50000 -5880.516179323196,8.430397033691406,73144.49038362503,163650,0,73144.49038362503,0.6526000499725342,1.5037193298339844,10000,79042.07584261894,0.8691992163658142,0.4835098087787628,0.7697199583053589,0.8991246819496155,50000 -5915.757030487061,8.490588188171387,73564.75876450539,164588,0,73564.75876450539,0.6550000309944153,1.5063070058822632,10000,79497.69409179688,0.87158203125,0.4755339622497558,0.7726399898529053,0.8931786417961121,50000 -5949.728416919708,8.550045013427734,73985.09216928482,165528,0,73985.09216928482,0.6545000076293945,1.494189739227295,10000,79952.10796093941,0.8754296898841858,0.4589492976665497,0.7731800079345703,0.8862924575805664,50000 -5983.164102315903,8.599197626113892,74405.08458447456,166466,0,74405.08458447456,0.6562000513076782,1.5012069940567017,10000,80405.63422679901,0.8744921684265137,0.477140724658966,0.7742799520492554,0.8912918567657471,50000 -6019.371679782867,8.662525653839111,74825.04020619392,167402,0,74825.04020619392,0.6580000519752502,1.495749831199646,10000,80861.90974783897,0.8738085627555847,0.4666889607906341,0.7739599943161011,0.8845837116241455,50000 -6054.917459487915,8.715810537338257,75245.46972846985,168340,0,75245.46972846985,0.6571000218391418,1.4930970668792725,10000,81317.98803067207,0.8773242235183716,0.4531766176223755,0.7756399512290955,0.8850293755531311,50000 -6089.09916472435,8.771555423736572,75665.5037753582,169278,0,75665.5037753582,0.6571000218391418,1.4991886615753174,10000,81772.30905842781,0.8797265291213989,0.4421044588088989,0.7756999731063843,0.8826619386672974,50000 -6124.984532356262,8.821457386016846,76085.7650001049,170213,0,76085.7650001049,0.6593000292778015,1.4936374425888062,10000,82228.5535402298,0.8807030916213989,0.4466629028320312,0.7764399647712708,0.8772845268249512,50000 -6158.859827756882,8.883781433105469,76505.8050429821,171152,0,76505.8050429821,0.6577000021934509,1.484135627746582,10000,82682.57957959175,0.8788085579872131,0.4484823048114776,0.7775599956512451,0.8697133660316467,50000 -6192.639292001724,8.942861557006836,76925.95561552048,172089,0,76925.95561552048,0.6615000367164612,1.4721359014511108,10000,83136.61755609512,0.8827148079872131,0.4331200420856476,0.7792999744415283,0.8637553453445435,50000 -6228.115994215012,8.996717929840088,77345.90822839737,173027,0,77345.90822839737,0.6621000170707703,1.4893403053283691,10000,83592.15035367012,0.8812695145606995,0.4429061114788055,0.7791399955749512,0.8685945868492126,50000 -6263.508965969086,9.05425500869751,77765.81756949425,173963,0,77765.81756949425,0.6599000096321106,1.4833369255065918,10000,84047.55881023407,0.8810155987739563,0.4393672943115234,0.7788999676704407,0.8695529699325562,50000 -6298.112339735031,9.10762882232666,78186.09090733528,174901,0,78186.09090733528,0.6640000343322754,1.4762301445007324,10000,84502.53801631927,0.8854491710662842,0.4333232641220093,0.7795999646186829,0.8688917756080627,50000 -6331.55254650116,9.166329622268677,78606.16627883911,175842,0,78606.16627883911,0.6651000380516052,1.4696331024169922,10000,84956.16199755669,0.88539057970047,0.4322720766067505,0.7794599533081055,0.866097629070282,50000 -6366.944556713104,9.227767944335938,79026.12602353096,176780,0,79026.12602353096,0.6630000472068787,1.4692275524139404,10000,85411.62379169464,0.8858593702316284,0.424483984708786,0.780019998550415,0.8592524528503418,50000 -6401.689318180084,9.282610654830933,79446.03746342659,177719,0,79446.03746342659,0.6668000221252441,1.4655801057815552,10000,85866.38427376747,0.8868359327316284,0.4214153289794922,0.7813000082969666,0.8596096634864807,50000 -6435.998902320862,9.340909719467165,79866.01587152481,178661,0,79866.01587152481,0.664900004863739,1.4604169130325315,10000,86320.78079080582,0.88685542345047,0.419357031583786,0.7822799682617188,0.8534312844276428,50000 -6470.306844234467,9.394191026687622,80285.93200969696,179599,0,80285.93200969696,0.6645000576972961,1.462255358695984,10000,86775.10715174675,0.8860741853713989,0.4188991189002991,0.781059980392456,0.8571061491966248,50000 -6506.475147724152,9.446362257003784,80705.8322134018,180537,0,80705.8322134018,0.6633000373840332,1.465488314628601,10000,87231.2764441967,0.8853710889816284,0.423235535621643,0.7822999954223633,0.856580913066864,50000 -6540.772976398468,9.49793577194214,81125.91492772102,181476,0,81125.91492772102,0.664900004863739,1.4613149166107178,10000,87685.75764489174,0.8878124952316284,0.4163694381713867,0.7823799848556519,0.8540942072868347,50000 -6576.370446205139,9.548835754394531,81545.87590956688,182415,0,81545.87590956688,0.6657000184059143,1.4587408304214478,10000,88141.41509056091,0.8870312571525574,0.4193125069141388,0.7821399569511414,0.8543209433555603,50000 -6610.12516784668,9.60527515411377,81965.95393848419,183355,0,81965.95393848419,0.6669000387191772,1.458864688873291,10000,88595.35318779945,0.8894140720367432,0.4123246669769287,0.7827000021934509,0.8526468276977539,50000 -6646.70134973526,9.67266058921814,82386.22452759743,184294,0,82386.22452759743,0.6668000221252441,1.4599723815917969,10000,89052.31648516655,0.8896093368530273,0.412265419960022,0.7827000021934509,0.8530824184417725,50000 -6682.224026918411,9.72684407234192,82806.58885216713,185231,0,82806.58885216713,0.6669000387191772,1.4578466415405271,10000,89508.3066072464,0.8886132836341858,0.4169757664203644,0.7825799584388733,0.8518304228782654,50000 -6716.146427869797,9.780983686447144,83226.80952954292,186169,0,83226.80952954292,0.6664000153541565,1.458232283592224,10000,89962.55336284637,0.8907030820846558,0.4041774272918701,0.7828999757766724,0.8518894910812378,50000 -6751.817445039749,9.848082780838013,83647.07682418823,187104,0,83647.07682418823,0.6663000583648682,1.4581419229507446,10000,90418.60726046562,0.8879296779632568,0.4138663411140442,0.7828800082206726,0.8518542051315308,50000 -6786.212738990784,9.911024570465088,84067.24569582939,188039,0,84067.24569582939,0.6663000583648682,1.4581419229507446,10000,90873.28254199028,0.88783198595047,0.4124012589454651,0.7828800082206726,0.8518542051315308,50000 -6822.406176567078,9.963230848312378,84487.2626209259,188976,0,84487.2626209259,0.6663000583648682,1.4581419229507446,10000,91329.5935049057,0.8874218463897705,0.4165248870849609,0.7828800082206726,0.8518542051315308,50000 -6857.440523386002,10.020652532577516,84907.21473288536,189911,0,84907.21473288536,0.6663000583648682,1.4581419229507446,10000,91784.68587970734,0.8871288895606995,0.4155504107475281,0.7828800082206726,0.8518542051315308,50000 -6892.907320976257,10.07563328742981,85327.35544729233,190846,0,85327.35544729233,0.6663000583648682,1.4581419229507446,10000,92240.39691138268,0.8899218440055847,0.4114306569099426,0.7828800082206726,0.8518542051315308,50000 -6926.948734521866,10.13442349433899,85747.62012791634,191785,0,85747.62012791634,0.6663000583648682,1.4581419229507446,10000,92694.81047177316,0.8871288895606995,0.4186196029186249,0.7828800082206726,0.8518542051315308,50000 -6963.765806436539,10.197827577590942,86167.71252465248,192721,0,86167.71252465248,0.6663000583648682,1.4581419229507446,10000,93151.8325972557,0.8883984088897705,0.4127451479434967,0.7828800082206726,0.8518542051315308,50000 -6998.387335062027,10.25277328491211,86587.71695780754,193659,0,86587.71695780754,0.6663000583648682,1.4581419229507446,10000,93606.5622651577,0.8858984112739563,0.4185971319675445,0.7828800082206726,0.8518542051315308,50000 -7033.558439016342,10.31568694114685,87007.7027232647,194596,0,87007.7027232647,0.6663000583648682,1.4581419229507446,10000,94061.83039855956,0.8871874809265137,0.416054368019104,0.7828800082206726,0.8518542051315308,50000 -7069.494187831879,10.3733651638031,87427.7664604187,195533,0,87427.7664604187,0.6663000583648682,1.4581419229507446,10000,94517.9357123375,0.8887109160423279,0.4140407741069793,0.7828800082206726,0.8518542051315308,50000 -7105.049699783325,10.437105655670166,87847.94662070274,196472,0,87847.94662070274,0.6663000583648682,1.4581419229507446,10000,94973.78384137154,0.8884179592132568,0.4153717756271362,0.7828800082206726,0.8518542051315308,50000 -7138.435284852982,10.492969989776611,88267.96205282211,197410,0,88267.96205282211,0.6663000583648682,1.4581419229507446,10000,95427.28940010072,0.8886327743530273,0.4180335402488708,0.7828800082206726,0.8518542051315308,50000 -7174.6018846035,10.557025671005247,88688.21113562584,198345,0,88688.21113562584,0.6663000583648682,1.4581419229507446,10000,95883.8189907074,0.8878905773162842,0.4148828089237213,0.7828800082206726,0.8518542051315308,50000 -7208.525338888168,10.615936279296877,89108.45638513565,199283,0,89108.45638513565,0.6663000583648682,1.4581419229507446,10000,96338.09540224075,0.8878905773162842,0.4175741970539093,0.7828800082206726,0.8518542051315308,50000 -7244.22820520401,10.671677350997925,89528.80938267708,200221,0,89528.80938267708,0.6663000583648682,1.4581419229507446,10000,96794.25735282898,0.8884570002555847,0.4154017865657806,0.7828800082206726,0.8518542051315308,50000 -7278.270314931869,10.727835655212402,89948.93990373611,201160,0,89948.93990373611,0.6663000583648682,1.4581419229507446,10000,97248.5354616642,0.8883593678474426,0.4166028499603271,0.7828800082206726,0.8518542051315308,50000 -7311.052936553955,10.78298568725586,90368.9737598896,202100,0,90368.9737598896,0.6663000583648682,1.4581419229507446,10000,97701.45683455469,0.8859374523162842,0.4209687709808349,0.7828800082206726,0.8518542051315308,50000 -7346.512250185013,10.849711179733276,90788.9855811596,203036,0,90788.9855811596,0.6663000583648682,1.4581419229507446,10000,98157.04360675812,0.88880854845047,0.4150761067867279,0.7828800082206726,0.8518542051315308,50000 -7381.7394988536835,10.903832912445068,91208.94326591492,203972,0,91208.94326591492,0.6663000583648682,1.4581419229507446,10000,98612.33208966257,0.8892382383346558,0.4137385487556457,0.7828800082206726,0.8518542051315308,50000 -7415.764829158783,10.96209716796875,91628.93712472916,204908,0,91628.93712472916,0.6663000583648682,1.4581419229507446,10000,99066.45861530304,0.8872265219688416,0.4107032716274261,0.7828800082206726,0.8518542051315308,50000 -7451.258689165115,11.021414756774902,92048.84640073776,205843,0,92048.84640073776,0.6663000583648682,1.4581419229507446,10000,99521.96976304054,0.8876953125,0.4176147878170013,0.7828800082206726,0.8518542051315308,50000 -7486.59016084671,11.081706762313845,92468.78777194025,206781,0,92468.78777194025,0.6663000583648682,1.4581419229507446,10000,99977.3512263298,0.8884570002555847,0.4156330525875091,0.7828800082206726,0.8518542051315308,50000 -7522.379315137863,11.140997648239136,92888.84211206436,207720,0,92888.84211206436,0.6663000583648682,1.4581419229507446,10000,100433.30316019058,0.8912109136581421,0.4088407754898071,0.7828800082206726,0.8518542051315308,50000 -7556.569756746292,11.206549644470217,93309.04252171516,208657,0,93309.04252171516,0.6663000583648682,1.4581419229507446,10000,100887.8080637455,0.8899999856948853,0.4095646440982818,0.7828800082206726,0.8518542051315308,50000 -7590.646590232849,11.269054651260376,93729.30348491669,209592,0,93729.30348491669,0.6663000583648682,1.4581419229507446,10000,101342.25642371178,0.88832026720047,0.4110678136348724,0.7828800082206726,0.8518542051315308,50000 -7626.880306720734,11.32774782180786,94149.20416498184,210529,0,94149.20416498184,0.6663000583648682,1.4581419229507446,10000,101798.49844503404,0.8885546922683716,0.4129042029380798,0.7828800082206726,0.8518542051315308,50000 -7661.169199705124,11.396864175796509,94569.50618243216,211471,0,94569.50618243216,0.6663000583648682,1.4581419229507446,10000,102253.2084736824,0.8863281011581421,0.4174293875694275,0.7828800082206726,0.8518542051315308,50000 -7696.322644948959,11.454381227493286,94989.62823104858,212411,0,94989.62823104858,0.6663000583648682,1.4581419229507446,10000,102708.59036684036,0.8870702981948853,0.4161613285541534,0.7828800082206726,0.8518542051315308,50000 -7730.858766078949,11.52181363105774,95409.86905503272,213345,0,95409.86905503272,0.6663000583648682,1.4581419229507446,10000,103163.48323106766,0.8892773389816284,0.4133086800575256,0.7828800082206726,0.8518542051315308,50000 -7766.759558439255,11.57755184173584,95829.8261590004,214281,0,95829.8261590004,0.6663000583648682,1.4581419229507446,10000,103619.44575953484,0.888964831829071,0.4124159216880798,0.7828800082206726,0.8518542051315308,50000 -7800.7112374305725,11.63347578048706,96249.92731976508,215221,0,96249.92731976508,0.6663000583648682,1.4581419229507446,10000,104073.60329914092,0.8868359327316284,0.4176099598407745,0.7828800082206726,0.8518542051315308,50000 -7834.320219755173,11.700467109680176,96669.82828593254,216159,0,96669.82828593254,0.6663000583648682,1.4581419229507446,10000,104527.22795033456,0.8875195384025574,0.4130787849426269,0.7828800082206726,0.8518542051315308,50000 -7870.795943975449,11.771633863449097,97089.97589826584,217094,0,97089.97589826584,0.6663000583648682,1.4581419229507446,10000,104983.97067832948,0.8885155916213989,0.4117570519447326,0.7828800082206726,0.8518542051315308,50000 -7906.584993600845,11.830165147781372,97509.97085809708,218032,0,97509.97085809708,0.6663000583648682,1.4581419229507446,10000,105439.86210155489,0.8877929449081421,0.4174558520317077,0.7828800082206726,0.8518542051315308,50000 -7941.163271188736,11.894726037979126,97929.96910381316,218971,0,97929.96910381316,0.6663000583648682,1.4581419229507446,10000,105894.55187416077,0.8862499594688416,0.4198924303054809,0.7828800082206726,0.8518542051315308,50000 -7976.790856599808,11.953049659729004,98350.23102784155,219909,0,98350.23102784155,0.6663000583648682,1.4581419229507446,10000,106350.5483431816,0.887499988079071,0.4200115203857422,0.7828800082206726,0.8518542051315308,50000 -8012.15415096283,12.012901067733765,98770.417617321,220798,0,98770.417617321,0.6663000583648682,1.4581419229507446,10000,106806.2047586441,0.8900781273841858,0.4076668322086334,0.7828800082206726,0.8518542051315308,50000 -8046.972583532333,12.071365118026732,99190.58779096603,221735,0,99190.58779096603,0.6663000583648682,1.4581419229507446,10000,107261.30042052268,0.8908789157867432,0.4115219414234161,0.7828800082206726,0.8518542051315308,50000 -8083.12747168541,12.127808094024658,99610.72120189668,222672,0,99610.72120189668,0.6663000583648682,1.4581419229507446,10000,107717.69447016716,0.8867382407188416,0.4187392294406891,0.7828800082206726,0.8518542051315308,50000 -8116.420456647873,12.186193466186523,100030.98379468918,223609,0,100030.98379468918,0.6663000583648682,1.4581419229507446,10000,108171.35675144196,0.8858398199081421,0.4228871762752533,0.7828800082206726,0.8518542051315308,50000 -8151.446369409561,12.254721403121948,100451.22117090224,224545,0,100451.22117090224,0.6663000583648682,1.4581419229507446,10000,108626.73807239532,0.8891015648841858,0.4168025851249695,0.7828800082206726,0.8518542051315308,50000 -8184.739677429199,12.3156476020813,100871.4775969982,225482,0,100871.4775969982,0.6663000583648682,1.4581419229507446,10000,109080.39751529694,0.8861327767372131,0.4200985133647918,0.7828800082206726,0.8518542051315308,50000 -8218.687573194504,12.384529113769531,101291.44068813324,226419,0,101291.44068813324,0.6663000583648682,1.4581419229507446,10000,109534.42616176604,0.8892577886581421,0.4118213951587677,0.7828800082206726,0.8518542051315308,50000 -8254.07694530487,12.444185018539429,101711.74678444862,227359,0,101711.74678444862,0.6663000583648682,1.4581419229507446,10000,109990.22974205016,0.8875781297683716,0.4156609177589416,0.7828800082206726,0.8518542051315308,50000 -8288.393639802933,12.505478382110596,102131.85267829896,228297,0,102131.85267829896,0.6663000583648682,1.4581419229507446,10000,110444.7630224228,0.8884570002555847,0.4122214913368225,0.7828800082206726,0.8518542051315308,50000 -8324.046884298325,12.562983989715576,102552.01909089088,229234,0,102552.01909089088,0.6663000583648682,1.4581419229507446,10000,110900.68841600418,0.8875585794448853,0.4155343770980835,0.7828800082206726,0.8518542051315308,50000 -8358.187504053116,12.623552083969116,102972.22385382652,230175,0,102972.22385382652,0.6663000583648682,1.4581419229507446,10000,111355.14338946342,0.8898828029632568,0.4132254719734192,0.7828800082206726,0.8518542051315308,50000 -8390.762695789337,12.692643880844116,103392.2831542492,231115,0,103392.2831542492,0.6663000583648682,1.4581419229507446,10000,111807.89689588548,0.88916015625,0.4119197726249695,0.7828800082206726,0.8518542051315308,50000 -8425.405321359634,12.775279998779297,103812.51510357855,232053,0,103812.51510357855,0.6663000583648682,1.4581419229507446,10000,112262.90275025368,0.889453113079071,0.4119241535663605,0.7828800082206726,0.8518542051315308,50000 -8460.428438425064,12.834226846694946,104232.78696107864,232993,0,104232.78696107864,0.6663000583648682,1.4581419229507446,10000,112718.30594062804,0.8890820145606995,0.4118773639202118,0.7828800082206726,0.8518542051315308,50000 -8496.338250160217,12.893026351928713,104652.87011408806,233934,0,104652.87011408806,0.6663000583648682,1.4581419229507446,10000,113174.4059624672,0.8887304663658142,0.409778743982315,0.7828800082206726,0.8518542051315308,50000 -8531.880004882812,12.95697784423828,105072.76435279846,234871,0,105072.76435279846,0.6663000583648682,1.4581419229507446,10000,113629.954102993,0.8856640458106995,0.4191117286682129,0.7828800082206726,0.8518542051315308,50000 -8566.35573387146,13.02313780784607,105492.69994354248,235811,0,105492.69994354248,0.6663000583648682,1.4581419229507446,10000,114084.48039150238,0.8889062404632568,0.409523993730545,0.7828800082206726,0.8518542051315308,50000 -8602.39948964119,13.09098768234253,105912.88513946532,236750,0,105912.88513946532,0.6663000583648682,1.4581419229507446,10000,114540.82594275476,0.88880854845047,0.4133784770965576,0.7828800082206726,0.8518542051315308,50000 -8636.400366544724,13.154456853866575,106333.18955159187,237691,0,106333.18955159187,0.6663000583648682,1.4581419229507446,10000,114995.24330091476,0.8860937356948853,0.4221966564655304,0.7828800082206726,0.8518542051315308,50000 -8672.216531038284,13.225939512252808,106753.54248261452,238632,0,106753.54248261452,0.6663000583648682,1.4581419229507446,10000,115451.53286147118,0.8895898461341858,0.4112511873245239,0.7828800082206726,0.8518542051315308,50000 -8705.033515691757,13.291870594024658,107173.82443404198,239572,0,107173.82443404198,0.6663000583648682,1.4581419229507446,10000,115904.74687862396,0.8858007788658142,0.4202886819839477,0.7828800082206726,0.8518542051315308,50000 -8741.489234685898,13.36388087272644,107593.75956201552,240510,0,107593.75956201552,0.6663000583648682,1.4581419229507446,10000,116361.25875759123,0.888671875,0.4127401113510132,0.7828800082206726,0.8518542051315308,50000 -8777.088340520859,13.4263174533844,108013.9641532898,241450,0,108013.9641532898,0.6663000583648682,1.4581419229507446,10000,116817.1728887558,0.8886327743530273,0.4101268351078033,0.7828800082206726,0.8518542051315308,50000 -8811.436619997025,13.484199047088625,108434.24830532074,242389,0,108434.24830532074,0.6663000583648682,1.4581419229507446,10000,117271.9125611782,0.8875585794448853,0.4174584746360779,0.7828800082206726,0.8518542051315308,50000 -8846.27702832222,13.549908876419067,108854.31929779051,243327,0,108854.31929779051,0.6663000583648682,1.4581419229507446,10000,117726.93820238112,0.8881250023841858,0.4179988503456116,0.7828800082206726,0.8518542051315308,50000 -8882.327449798584,13.610346794128418,109274.4168112278,244264,0,109274.4168112278,0.6663000583648682,1.4581419229507446,10000,118183.19608783722,0.8875976204872131,0.4127923548221588,0.7828800082206726,0.8518542051315308,50000 -8918.449665307999,13.66965365409851,109694.49145460127,245205,0,109694.49145460127,0.6663000583648682,1.4581419229507446,10000,118639.5016169548,0.8902148008346558,0.4121952056884765,0.7828800082206726,0.8518542051315308,50000 -8953.590856313705,13.729873657226562,110114.44577169418,246144,0,110114.44577169418,0.6663000583648682,1.4581419229507446,10000,119094.70700001717,0.886035144329071,0.4249056279659271,0.7828800082206726,0.8518542051315308,50000 -8987.530959367752,13.79277729988098,110534.5217063427,247082,0,110534.5217063427,0.6663000583648682,1.4581419229507446,10000,119548.83497023582,0.8880273103713989,0.4152797162532806,0.7828800082206726,0.8518542051315308,50000 -9023.016151428224,13.860776662826538,110954.5231757164,248019,0,110954.5231757164,0.6663000583648682,1.4581419229507446,10000,120004.43804454803,0.8878515362739563,0.4194375574588775,0.7828800082206726,0.8518542051315308,50000 -9058.499682426453,13.92180871963501,111374.45286393166,248961,0,111374.45286393166,0.6663000583648682,1.4581419229507446,10000,120459.96077299118,0.8873632550239563,0.4151614904403686,0.7828800082206726,0.8518542051315308,50000 -9093.047081708908,13.991700649261476,111795.12572550774,249901,0,111795.12572550774,0.6663000583648682,1.4581419229507446,10000,120915.29966640472,0.8882030844688416,0.4157791137695312,0.7828800082206726,0.8518542051315308,50000 -9128.832021713257,14.05441927909851,112215.37737846376,250839,0,112215.37737846376,0.6663000583648682,1.4581419229507446,10000,121371.44778513908,0.8891796469688416,0.4126760065555572,0.7828800082206726,0.8518542051315308,50000 -9163.606656312944,14.1160409450531,112635.57611656188,251777,0,112635.57611656188,0.6663000583648682,1.4581419229507446,10000,121826.532310009,0.8882421851158142,0.4113604128360748,0.7828800082206726,0.8518542051315308,50000 -9200.502378463743,14.178927898406982,113055.79975700378,252710,0,113055.79975700378,0.6663000583648682,1.4581419229507446,10000,122283.76314520836,0.8890429735183716,0.4132768213748932,0.7828800082206726,0.8518542051315308,50000 -9236.304715394974,14.241063117980955,113475.82202625276,253650,0,113475.82202625276,0.6663000583648682,1.4581419229507446,10000,122739.698867321,0.8868749737739563,0.4217991530895233,0.7828800082206726,0.8518542051315308,50000 -9271.973058700562,14.31032943725586,113895.70823192596,254589,0,113895.70823192596,0.6663000583648682,1.4581419229507446,10000,123195.3730111122,0.8882030844688416,0.4138852655887604,0.7828800082206726,0.8518542051315308,50000 -9309.296308279036,14.374987840652466,114315.62262010574,255527,0,114315.62262010574,0.6663000583648682,1.4581419229507446,10000,123652.72332525252,0.8911327719688416,0.4039798676967621,0.7828800082206726,0.8518542051315308,50000 -9345.163439273834,14.435834646224976,114735.56063556673,256465,0,114735.56063556673,0.6663000583648682,1.4581419229507446,10000,124108.63796424866,0.8897265195846558,0.412788063287735,0.7828800082206726,0.8518542051315308,50000 -9380.308512449265,14.500270128250122,115155.77038359642,257404,0,115155.77038359642,0.6663000583648682,1.4581419229507446,10000,124564.10697579384,0.8879101276397705,0.4117997884750366,0.7828800082206726,0.8518542051315308,50000 -9416.172207832336,14.567237615585327,115575.71538758278,258340,0,115575.71538758278,0.6663000583648682,1.4581419229507446,10000,125020.03127336502,0.8883788585662842,0.4111296832561493,0.7828800082206726,0.8518542051315308,50000 -9451.75629067421,14.63141632080078,115995.97527813911,259279,0,115995.97527813911,0.6663000583648682,1.4581419229507446,10000,125475.9886534214,0.8874804377555847,0.4159693717956543,0.7828800082206726,0.8518542051315308,50000 -9486.679986476898,14.693573951721191,116416.2382993698,260218,0,116416.2382993698,0.6663000583648682,1.4581419229507446,10000,125931.2862727642,0.888964831829071,0.4118813574314117,0.7828800082206726,0.8518542051315308,50000 -9522.27923154831,14.757773399353027,116836.5095319748,261156,0,116836.5095319748,0.6663000583648682,1.4581419229507446,10000,126387.26921463013,0.8884961009025574,0.4160282909870147,0.7828800082206726,0.8518542051315308,50000 -9557.40468263626,14.821744441986084,117256.55498623848,262095,0,117256.55498623848,0.6663000583648682,1.4581419229507446,10000,126842.553047657,0.8870702981948853,0.4149435758590698,0.7828800082206726,0.8518542051315308,50000 -9590.883177995682,14.8841450214386,117676.4786427021,263033,0,117676.4786427021,0.6663000583648682,1.4581419229507446,10000,127296.06613850594,0.8878515362739563,0.4189508855342865,0.7828800082206726,0.8518542051315308,50000 -9625.748121738434,14.949349880218506,118096.52826428412,263973,0,118096.52826428412,0.6663000583648682,1.4581419229507446,10000,127751.09358382224,0.8878124952316284,0.4134690463542938,0.7828800082206726,0.8518542051315308,50000 -9660.38653588295,15.022777557373049,118516.4486773014,264914,0,118516.4486773014,0.6663000583648682,1.4581419229507446,10000,128205.77392411232,0.8861327767372131,0.4168863594532013,0.7828800082206726,0.8518542051315308,50000 -9696.406280517578,15.097080707550049,118936.5146408081,265854,0,118936.5146408081,0.6663000583648682,1.4581419229507446,10000,128661.9830391407,0.8871288895606995,0.4185587167739868,0.7828800082206726,0.8518542051315308,50000 -9732.229608774183,15.162375926971436,119356.6550552845,266795,0,119356.6550552845,0.6663000583648682,1.4581419229507446,10000,129118.06162571908,0.8882030844688416,0.4178065359592438,0.7828800082206726,0.8518542051315308,50000 -9766.939259529114,15.23660135269165,119776.81697392464,267734,0,119776.81697392464,0.6663000583648682,1.4581419229507446,10000,129573.05659985542,0.887011706829071,0.4181829690933227,0.7828800082206726,0.8518542051315308,50000 -9801.692220449448,15.302525758743286,120196.85601568222,268674,0,120196.85601568222,0.6663000583648682,1.4581419229507446,10000,130027.96326994896,0.8915820121765137,0.4057655334472656,0.7828800082206726,0.8518542051315308,50000 -9837.649421453476,15.373100996017456,120616.76523089407,269613,0,120616.76523089407,0.6663000583648682,1.4581419229507446,10000,130483.94918727876,0.8875976204872131,0.4203981757164001,0.7828800082206726,0.8518542051315308,50000 -9872.810532808304,15.437132596969604,121037.0772664547,270553,0,121037.0772664547,0.6663000583648682,1.4581419229507446,10000,130939.53532242776,0.8871679306030273,0.4184862673282623,0.7828800082206726,0.8518542051315308,50000 -9909.046003103256,15.510629415512083,121457.0885810852,271495,0,121457.0885810852,0.6663000583648682,1.4581419229507446,10000,131395.90521931648,0.8880859017372131,0.4176699817180633,0.7828800082206726,0.8518542051315308,50000 -9944.696164131165,15.589151382446287,121877.40932011604,272435,0,121877.40932011604,0.6663000583648682,1.4581419229507446,10000,131852.0035393238,0.8871484398841858,0.4141007661819458,0.7828800082206726,0.8518542051315308,50000 -9978.283198833466,15.65884566307068,122297.49721193314,273376,0,122297.49721193314,0.6663000583648682,1.4581419229507446,10000,132305.79712724686,0.887499988079071,0.416747510433197,0.7828800082206726,0.8518542051315308,50000 -10014.80516576767,15.73838758468628,122717.44657683372,274315,0,122717.44657683372,0.6663000583648682,1.4581419229507446,10000,132762.39725995064,0.888671875,0.4159770011901855,0.7828800082206726,0.8518542051315308,50000 -10050.325754642488,15.802060842514038,123137.75361943243,275256,0,123137.75361943243,0.6663000583648682,1.4581419229507446,10000,133218.33816456795,0.8893945217132568,0.4109589755535126,0.7828800082206726,0.8518542051315308,50000 -10084.47655248642,15.868030309677124,123557.85822701454,276195,0,123557.85822701454,0.6663000583648682,1.4581419229507446,10000,133672.7080309391,0.88636714220047,0.4172734320163727,0.7828800082206726,0.8518542051315308,50000 -10117.911062717438,15.936080694198608,123978.16338539124,277135,0,123978.16338539124,0.6663000583648682,1.4581419229507446,10000,134126.56583857536,0.8903319835662842,0.4136055707931518,0.7828800082206726,0.8518542051315308,50000 -10151.590016841888,16.011402368545532,124398.14441418648,278075,0,124398.14441418648,0.6663000583648682,1.4581419229507446,10000,134580.3505373001,0.8899218440055847,0.4158057570457458,0.7828800082206726,0.8518542051315308,50000 -10184.596621990204,16.07777214050293,124818.47934174538,279015,0,124818.47934174538,0.6663000583648682,1.4581419229507446,10000,135033.80769109726,0.88734370470047,0.408898115158081,0.7828800082206726,0.8518542051315308,50000 -10217.464999198914,16.149641752243042,125238.88067746162,279957,0,125238.88067746162,0.6663000583648682,1.4581419229507446,10000,135487.19892835617,0.890429675579071,0.4099819958209991,0.7828800082206726,0.8518542051315308,50000 -10252.94473028183,16.230097770690918,125658.77555394173,280898,0,125658.77555394173,0.6663000583648682,1.4581419229507446,10000,135942.70322799683,0.8888476490974426,0.4090248644351959,0.7828800082206726,0.8518542051315308,50000 -10289.864123106005,16.29589319229126,126078.73848581314,281791,0,126078.73848581314,0.6663000583648682,1.4581419229507446,10000,136399.6976134777,0.8855664134025574,0.4181103110313415,0.7828800082206726,0.8518542051315308,50000 -10325.465045690536,16.3734393119812,126498.94877243042,282733,0,126498.94877243042,0.6663000583648682,1.4581419229507446,10000,136855.63595891,0.8891015648841858,0.4115650951862335,0.7828800082206726,0.8518542051315308,50000 -10359.358072280884,16.440247058868408,126919.20204520226,283671,0,126919.20204520226,0.6663000583648682,1.4581419229507446,10000,137309.89725899696,0.8883788585662842,0.4134488999843597,0.7828800082206726,0.8518542051315308,50000 -10395.601438760756,16.506649494171143,127339.28399133682,284611,0,127339.28399133682,0.6663000583648682,1.4581419229507446,10000,137766.33781647682,0.8886132836341858,0.4173832535743713,0.7828800082206726,0.8518542051315308,50000 -10429.8513276577,16.57482123374939,127759.24665808678,285552,0,127759.24665808678,0.6663000583648682,1.4581419229507446,10000,138220.6668958664,0.88783198595047,0.4149629175662994,0.7828800082206726,0.8518542051315308,50000 -10465.854510307312,16.641371726989746,128179.28594827652,286492,0,128179.28594827652,0.6663000583648682,1.4581419229507446,10000,138676.82459497452,0.8872656226158142,0.417816162109375,0.7828800082206726,0.8518542051315308,50000 -10499.648500919342,16.957338094711304,128598.95075941086,287433,0,128598.95075941086,0.6663000583648682,1.4581419229507446,10000,139130.64910507202,0.88832026720047,0.4100013673305511,0.7828800082206726,0.8518542051315308,50000 -10535.140134096146,17.03790783882141,129019.15416574478,288369,0,129019.15416574478,0.6663000583648682,1.4581419229507446,10000,139586.47339582443,0.8876171708106995,0.4147033095359802,0.7828800082206726,0.8518542051315308,50000 -10569.173550605774,17.106011629104614,129439.41557979584,289307,0,129439.41557979584,0.6663000583648682,1.4581419229507446,10000,140040.8846886158,0.887011706829071,0.4225472211837768,0.7828800082206726,0.8518542051315308,50000 -10602.943563699722,17.181469678878784,129859.462069273,290246,0,129859.462069273,0.6663000583648682,1.4581419229507446,10000,140494.82578778267,0.8878515362739563,0.4139612913131714,0.7828800082206726,0.8518542051315308,50000 -10636.189188480375,17.256900310516357,130279.55357980728,291184,0,130279.55357980728,0.6663000583648682,1.4581419229507446,10000,140948.2878472805,0.8871288895606995,0.4143074452877044,0.7828800082206726,0.8518542051315308,50000 -10671.554021835327,17.342145442962646,130699.70896553992,292121,0,130699.70896553992,0.6663000583648682,1.4581419229507446,10000,141403.94225287437,0.8910741806030273,0.4123877286911011,0.7828800082206726,0.8518542051315308,50000 -10707.190835475922,17.419295072555542,131119.73110198975,293060,0,131119.73110198975,0.6663000583648682,1.4581419229507446,10000,141859.72695946693,0.8864648342132568,0.422519326210022,0.7828800082206726,0.8518542051315308,50000 -10743.051213026049,17.48875403404236,131539.8347415924,293998,0,131539.8347415924,0.6663000583648682,1.4581419229507446,10000,142315.80940794945,0.8893359303474426,0.4128582775592804,0.7828800082206726,0.8518542051315308,50000 -10778.418419837952,17.556360244750977,131959.82090997696,294933,0,131959.82090997696,0.6663000583648682,1.4581419229507446,10000,142771.27851748466,0.8885351419448853,0.4156334996223449,0.7828800082206726,0.8518542051315308,50000 -10812.218878507614,17.62444758415222,132379.98327803612,295873,0,132379.98327803612,0.6663000583648682,1.4581419229507446,10000,143225.35876965523,0.8855273127555847,0.4239533543586731,0.7828800082206726,0.8518542051315308,50000 -10848.138242006302,17.70807909965515,132800.25978302956,296809,0,132800.25978302956,0.6663000583648682,1.4581419229507446,10000,143681.6864824295,0.8883007764816284,0.4127146899700165,0.7828800082206726,0.8518542051315308,50000 -10882.680842399595,17.776843547821045,133220.3181180954,297749,0,133220.3181180954,0.6663000583648682,1.4581419229507446,10000,144136.4060664177,0.8908202648162842,0.405164510011673,0.7828800082206726,0.8518542051315308,50000 -10916.362850904465,17.857324361801147,133640.38061594963,298691,0,133640.38061594963,0.6663000583648682,1.4581419229507446,10000,144590.28002810478,0.8869335651397705,0.4187902808189392,0.7828800082206726,0.8518542051315308,50000 -10949.912775278091,17.94206404685974,134060.37126111984,299630,0,134060.37126111984,0.6663000583648682,1.4581419229507446,10000,145043.9550564289,0.8880468606948853,0.4154566526412964,0.7828800082206726,0.8518542051315308,50000 -10984.052475690842,18.025188207626343,134480.29626059532,300572,0,134480.29626059532,0.6663000583648682,1.4581419229507446,10000,145498.15152049065,0.8863085508346558,0.4205401539802551,0.7828800082206726,0.8518542051315308,50000 -11018.582464933395,18.09760928153992,134900.26795220375,301509,0,134900.26795220375,0.6663000583648682,1.4581419229507446,10000,145952.77487421036,0.8903906345367432,0.4110255241394043,0.7828800082206726,0.8518542051315308,50000 -11055.576867341995,18.16519927978516,135320.3759112358,302451,0,135320.3759112358,0.6663000583648682,1.4581419229507446,10000,146409.99482226372,0.8894921541213989,0.4096398055553436,0.7828800082206726,0.8518542051315308,50000 -11090.334149360657,18.23756146430969,135740.50975704193,303392,0,135740.50975704193,0.6663000583648682,1.4581419229507446,10000,146865.0077443123,0.88832026720047,0.413810133934021,0.7828800082206726,0.8518542051315308,50000 -11126.382422208786,18.318288803100582,136160.44328808784,304333,0,136160.44328808784,0.6663000583648682,1.4581419229507446,10000,147321.12025094032,0.8908593654632568,0.4046581983566284,0.7828800082206726,0.8518542051315308,50000 -11161.775073289871,18.397958040237427,136580.41247677803,305272,0,136580.41247677803,0.6663000583648682,1.4581419229507446,10000,147776.61046028137,0.885546863079071,0.4197524785995483,0.7828800082206726,0.8518542051315308,50000 -11196.973927736282,18.467971086502075,137000.47409558296,306211,0,137000.47409558296,0.6663000583648682,1.4581419229507446,10000,148231.98989629743,0.8883788585662842,0.4131350219249725,0.7828800082206726,0.8518542051315308,50000 -11231.89693236351,18.539098262786865,137420.50700616837,307151,0,137420.50700616837,0.6663000583648682,1.4581419229507446,10000,148687.06598305702,0.8883984088897705,0.4133432805538177,0.7828800082206726,0.8518542051315308,50000 -11267.01406097412,18.608031034469604,137840.77836227417,308089,0,137840.77836227417,0.6663000583648682,1.4581419229507446,10000,149142.5724697113,0.8885155916213989,0.4139094352722168,0.7828800082206726,0.8518542051315308,50000 -11302.639257907867,18.680636405944824,138260.96691298485,309029,0,138260.96691298485,0.6663000583648682,1.4581419229507446,10000,149598.50726366043,0.8866991996765137,0.4185129106044769,0.7828800082206726,0.8518542051315308,50000 -11335.685423135756,18.76006007194519,138681.03453612328,309968,0,138681.03453612328,0.6663000583648682,1.4581419229507446,10000,150051.7492189407,0.8897656202316284,0.4106150865554809,0.7828800082206726,0.8518542051315308,50000 -11368.736935138702,18.845513105392456,139101.31498932838,310907,0,139101.31498932838,0.6663000583648682,1.4581419229507446,10000,150505.21512961388,0.8888671398162842,0.4129617512226105,0.7828800082206726,0.8518542051315308,50000 -11403.054658412932,18.928631067276,139521.58361434937,311842,0,139521.58361434937,0.6663000583648682,1.4581419229507446,10000,150959.9334256649,0.8851171731948853,0.4197098910808563,0.7828800082206726,0.8518542051315308,50000 -11437.10601568222,18.999311447143555,139941.93662166595,312780,0,139941.93662166595,0.6663000583648682,1.4581419229507446,10000,151414.4565434456,0.8889452815055847,0.4150924682617187,0.7828800082206726,0.8518542051315308,50000 -11470.306761026382,19.070598363876343,140362.2047946453,313722,0,140362.2047946453,0.6663000583648682,1.4581419229507446,10000,151868.0457689762,0.88734370470047,0.4192501008510589,0.7828800082206726,0.8518542051315308,50000 -11505.935393810272,19.155100107193,140782.08658885956,314662,0,140782.08658885956,0.6663000583648682,1.4581419229507446,10000,152323.69007754326,0.887011706829071,0.4187501668930053,0.7828800082206726,0.8518542051315308,50000 -11541.796007156372,19.238603115081787,141202.2172217369,315601,0,141202.2172217369,0.6663000583648682,1.4581419229507446,10000,152779.8137331009,0.8889257907867432,0.4117574989795685,0.7828800082206726,0.8518542051315308,50000 -11575.5288336277,19.32027697563172,141622.4211435318,316542,0,141622.4211435318,0.6663000583648682,1.4581419229507446,10000,153233.88162970543,0.8873632550239563,0.4215485751628876,0.7828800082206726,0.8518542051315308,50000 -11608.605984210968,19.408986806869507,142042.42588067055,317477,0,142042.42588067055,0.6663000583648682,1.4581419229507446,10000,153687.1012263298,0.88929682970047,0.4097913205623626,0.7828800082206726,0.8518542051315308,50000 -11643.827334403992,19.49708747863769,142462.3735461235,318414,0,142462.3735461235,0.6663000583648682,1.4581419229507446,10000,154142.40728020668,0.8886523246765137,0.4203369319438934,0.7828800082206726,0.8518542051315308,50000 -11678.085348844528,19.57066559791565,142882.58603477478,319357,0,142882.58603477478,0.6663000583648682,1.4581419229507446,10000,154597.0007481575,0.8871288895606995,0.4173811674118042,0.7828800082206726,0.8518542051315308,50000 -11715.436604499817,19.6518189907074,143302.8230383396,320299,0,143302.8230383396,0.6663000583648682,1.4581419229507446,10000,155054.7192568779,0.8873046636581421,0.4134297370910644,0.7828800082206726,0.8518542051315308,50000 -11750.764872550964,19.725261211395264,143722.71570396423,321239,0,143722.71570396423,0.6663000583648682,1.4581419229507446,10000,155510.06215786934,0.8871288895606995,0.4203206896781921,0.7828800082206726,0.8518542051315308,50000 -11783.672978162766,19.807963609695435,144142.85748529434,322176,0,144142.85748529434,0.6663000583648682,1.4581419229507446,10000,155963.2437081337,0.8892187476158142,0.4106161892414093,0.7828800082206726,0.8518542051315308,50000 -11819.779586553574,19.89247989654541,144562.77267432213,323113,0,144562.77267432213,0.6663000583648682,1.4581419229507446,10000,156419.39860010147,0.8874218463897705,0.414051741361618,0.7828800082206726,0.8518542051315308,50000 -11854.017767190931,19.96893191337585,144983.06563448906,324054,0,144983.06563448906,0.6663000583648682,1.4581419229507446,10000,156874.05456399918,0.8886327743530273,0.416003555059433,0.7828800082206726,0.8518542051315308,50000 -11890.295197963716,20.042044162750244,145403.02961182594,324993,0,145403.02961182594,0.6663000583648682,1.4581419229507446,10000,157330.41791248322,0.890429675579071,0.413343220949173,0.7828800082206726,0.8518542051315308,50000 -11927.290768623352,20.128324270248413,145823.25002264977,325932,0,145823.25002264977,0.6663000583648682,1.4581419229507446,10000,157787.76915454865,0.8891210556030273,0.4105064570903778,0.7828800082206726,0.8518542051315308,50000 -11962.746523618698,20.20204377174377,146243.3017590046,326870,0,146243.3017590046,0.6663000583648682,1.4581419229507446,10000,158243.3984901905,0.8895702958106995,0.4092340767383575,0.7828800082206726,0.8518542051315308,50000 -11998.719693899156,20.27439427375793,146663.34315609932,327807,0,146663.34315609932,0.6663000583648682,1.4581419229507446,10000,158699.53372955322,0.88720703125,0.4160758852958679,0.7828800082206726,0.8518542051315308,50000 -12033.599283218384,20.351003408432007,147083.42253422737,328744,0,147083.42253422737,0.6663000583648682,1.4581419229507446,10000,159154.6178812981,0.8877929449081421,0.4134314954280853,0.7828800082206726,0.8518542051315308,50000 -12068.009401798248,20.425606727600098,147503.57654500008,329682,0,147503.57654500008,0.6663000583648682,1.4581419229507446,10000,159609.30521917343,0.8869140148162842,0.4140605926513672,0.7828800082206726,0.8518542051315308,50000 -12103.958374977112,20.504555225372314,147923.52029657364,330619,0,147923.52029657364,0.6663000583648682,1.4581419229507446,10000,160065.32572340965,0.890429675579071,0.40680992603302,0.7828800082206726,0.8518542051315308,50000 -12137.546013832092,20.57728481292725,148343.6543688774,331560,0,148343.6543688774,0.6663000583648682,1.4581419229507446,10000,160519.1687822342,0.8872460722923279,0.4223819673061371,0.7828800082206726,0.8518542051315308,50000 -12173.454895019531,20.65666031837464,148763.71720457077,332497,0,148763.71720457077,0.6663000583648682,1.4581419229507446,10000,160975.2687857151,0.88783198595047,0.4149302840232849,0.7828800082206726,0.8518542051315308,50000 -12208.547565460203,20.73137640953064,149183.92890954018,333435,0,149183.92890954018,0.6663000583648682,1.4581419229507446,10000,161430.69685649872,0.8897656202316284,0.4081674516201019,0.7828800082206726,0.8518542051315308,50000 -12243.464797735214,20.8051335811615,149604.09540700912,334373,0,149604.09540700912,0.6663000583648682,1.4581419229507446,10000,161885.90264344215,0.8862499594688416,0.4207873046398163,0.7828800082206726,0.8518542051315308,50000 -12277.341688632963,20.890724897384644,150024.3045117855,335311,0,150024.3045117855,0.6663000583648682,1.4581419229507446,10000,162340.12366628647,0.8875390291213989,0.4163751006126404,0.7828800082206726,0.8518542051315308,50000 -12311.726724147797,20.969918489456177,150444.52754735947,336250,0,150444.52754735947,0.6663000583648682,1.4581419229507446,10000,162794.85982775688,0.8867382407188416,0.4137430191040039,0.7828800082206726,0.8518542051315308,50000 -12345.37182021141,21.04414367675781,150864.89817857742,337189,0,150864.89817857742,0.6663000583648682,1.4581419229507446,10000,163248.99874138832,0.888476550579071,0.4186500012874603,0.7828800082206726,0.8518542051315308,50000 -12381.196682929993,21.12299537658692,151284.9808397293,338129,0,151284.9808397293,0.6663000583648682,1.4581419229507446,10000,163705.03434443474,0.8878124952316284,0.4141222536563873,0.7828800082206726,0.8518542051315308,50000 -12416.785385370256,21.197551250457764,151705.25899219513,339070,0,151705.25899219513,0.6663000583648682,1.4581419229507446,10000,164161.02415442467,0.8898437023162842,0.4094454944133758,0.7828800082206726,0.8518542051315308,50000 -12451.752354383469,21.272735834121704,152125.4069724083,340011,0,152125.4069724083,0.6663000583648682,1.4581419229507446,10000,164616.2634806633,0.8881640434265137,0.419726699590683,0.7828800082206726,0.8518542051315308,50000 -12486.864793300629,21.35389828681945,152545.37711572647,340948,0,152545.37711572647,0.6663000583648682,1.4581419229507446,10000,165071.47660183907,0.8879492282867432,0.4159083962440491,0.7828800082206726,0.8518542051315308,50000 -12523.032780647278,21.442599773406982,152965.68938064575,341885,0,152965.68938064575,0.6663000583648682,1.4581419229507446,10000,165528.09414219856,0.8882616758346558,0.4191540777683258,0.7828800082206726,0.8518542051315308,50000 -12558.412452220917,21.51712751388549,153385.66485118866,342824,0,153385.66485118866,0.6663000583648682,1.4581419229507446,10000,165983.57289791107,0.8852733969688416,0.4213873147964477,0.7828800082206726,0.8518542051315308,50000 -12593.196396112442,21.595152139663696,153805.74320554733,343765,0,153805.74320554733,0.6663000583648682,1.4581419229507446,10000,166438.56208109856,0.8892382383346558,0.4125251770019531,0.7828800082206726,0.8518542051315308,50000 -12628.076207399368,21.68330931663513,154226.0605700016,344704,0,154226.0605700016,0.6663000583648682,1.4581419229507446,10000,166893.89604711533,0.8899218440055847,0.4108539223670959,0.7828800082206726,0.8518542051315308,50000 -12663.758479118347,21.75862622261048,154646.28801751137,345643,0,154646.28801751137,0.6663000583648682,1.4581419229507446,10000,167349.93050813675,0.8858007788658142,0.4192408919334411,0.7828800082206726,0.8518542051315308,50000 -12698.328409194946,21.834535837173465,155066.39670681953,346580,0,155066.39670681953,0.6663000583648682,1.4581419229507446,10000,167804.73359775543,0.8874218463897705,0.414071649312973,0.7828800082206726,0.8518542051315308,50000 -12732.518855333328,21.925792455673218,155486.5955555439,347519,0,155486.5955555439,0.6663000583648682,1.4581419229507446,10000,168259.2629210949,0.8876953125,0.4181354343891144,0.7828800082206726,0.8518542051315308,50000 -12766.790929794312,22.002236366271973,155906.79409885406,348456,0,155906.79409885406,0.6663000583648682,1.4581419229507446,10000,168713.85986709595,0.8907812237739563,0.4101113080978393,0.7828800082206726,0.8518542051315308,50000 -12802.994801282885,22.09450626373291,156326.7063846588,349395,0,156326.7063846588,0.6663000583648682,1.4581419229507446,10000,169170.11686944962,0.88978511095047,0.4093369245529175,0.7828800082206726,0.8518542051315308,50000 -12835.585213184357,22.179222583770752,156746.6764435768,350333,0,156746.6764435768,0.6663000583648682,1.4581419229507446,10000,169622.81052684784,0.8917187452316284,0.4031383097171783,0.7828800082206726,0.8518542051315308,50000 -12869.849185228348,22.272608995437626,157166.64605093002,351271,0,157166.64605093002,0.6663000583648682,1.4581419229507446,10000,170077.18683290482,0.88818359375,0.4155411422252655,0.7828800082206726,0.8518542051315308,50000 -12903.49128460884,22.348501205444336,157587.0714263916,352211,0,157587.0714263916,0.6663000583648682,1.4581419229507446,10000,170531.37905216217,0.88671875,0.4179138839244842,0.7828800082206726,0.8518542051315308,50000 -12939.064716339111,22.443691968917847,158007.24652957916,353149,0,158007.24652957916,0.6663000583648682,1.4581419229507446,10000,170987.2718143463,0.8885351419448853,0.4137209355831146,0.7828800082206726,0.8518542051315308,50000 -12975.078892946243,22.521716356277462,158427.2736890316,354090,0,158427.2736890316,0.6663000583648682,1.4581419229507446,10000,171443.43993115425,0.8870312571525574,0.4134266972541809,0.7828800082206726,0.8518542051315308,50000 -13010.76617193222,22.5986111164093,158847.3594338894,355032,0,158847.3594338894,0.6663000583648682,1.4581419229507446,10000,171899.33880877495,0.8900585770606995,0.410669595003128,0.7828800082206726,0.8518542051315308,50000 -13045.676797151566,22.67598962783813,159267.44840097427,355971,0,159267.44840097427,0.6663000583648682,1.4581419229507446,10000,172354.46488285065,0.8856835961341858,0.4218739569187164,0.7828800082206726,0.8518542051315308,50000 -13081.4606487751,22.755434274673465,159687.5464732647,356910,0,159687.5464732647,0.6663000583648682,1.4581419229507446,10000,172810.47546195984,0.8889843821525574,0.4116753935813904,0.7828800082206726,0.8518542051315308,50000 -13117.43409729004,22.83749270439148,160107.59187602997,357849,0,160107.59187602997,0.6663000583648682,1.4581419229507446,10000,173266.62514972687,0.8860546946525574,0.4163259863853454,0.7828800082206726,0.8518542051315308,50000 -13151.582559347153,22.91771149635315,160527.83763742447,358787,0,160527.83763742447,0.6663000583648682,1.4581419229507446,10000,173721.14839720726,0.8892577886581421,0.4093237519264221,0.7828800082206726,0.8518542051315308,50000 -13187.27852511406,23.012847900390625,160947.71241402626,359725,0,160947.71241402626,0.6663000583648682,1.4581419229507446,10000,174176.86294460297,0.8866406083106995,0.4209662973880768,0.7828800082206726,0.8518542051315308,50000 -13223.405816078186,23.090537786483765,161367.6559035778,360665,0,161367.6559035778,0.6663000583648682,1.4581419229507446,10000,174633.06001091003,0.8871093392372131,0.4208268225193023,0.7828800082206726,0.8518542051315308,50000 -13259.286784172058,23.16886854171753,161787.65626049042,361602,0,161787.65626049042,0.6663000583648682,1.4581419229507446,10000,175089.06940317154,0.8885937333106995,0.4136121273040771,0.7828800082206726,0.8518542051315308,50000 -13295.67337846756,23.257583141326904,162207.73133540154,362541,0,162207.73133540154,0.6663000583648682,1.4581419229507446,10000,175545.66818594933,0.8882616758346558,0.4161655008792877,0.7828800082206726,0.8518542051315308,50000 -13333.242425441742,23.33720898628235,162627.79280495644,363479,0,162627.79280495644,0.6663000583648682,1.4581419229507446,10000,176003.4283466339,0.8877929449081421,0.416262537240982,0.7828800082206726,0.8518542051315308,50000 -13366.86541056633,23.42582654953003,163047.926517725,364417,0,163047.926517725,0.6663000583648682,1.4581419229507446,10000,176457.32267189026,0.8885741829872131,0.415870189666748,0.7828800082206726,0.8518542051315308,50000 -13401.921688556671,23.520427465438843,163467.9629137516,365351,0,163467.9629137516,0.6663000583648682,1.4581419229507446,10000,176912.55845713615,0.887499988079071,0.4217265546321869,0.7828800082206726,0.8518542051315308,50000 -13436.286578655245,23.60895323753357,163887.9586417675,366291,0,163887.9586417675,0.6663000583648682,1.4581419229507446,10000,177367.0562813282,0.8881444931030273,0.4135344624519348,0.7828800082206726,0.8518542051315308,50000 -13470.023416519163,23.688948154449463,164307.8692791462,367230,0,164307.8692791462,0.6663000583648682,1.4581419229507446,10000,177820.83196425438,0.8866796493530273,0.4189541637897491,0.7828800082206726,0.8518542051315308,50000 -13504.752668857574,23.773596048355103,164728.0875532627,368170,0,164728.0875532627,0.6663000583648682,1.4581419229507446,10000,178275.9132180214,0.8882421851158142,0.4134348630905151,0.7828800082206726,0.8518542051315308,50000 -13539.072884321213,23.85645580291748,165148.242303133,369109,0,165148.242303133,0.6663000583648682,1.4581419229507446,10000,178730.52012515068,0.8899999856948853,0.4114833772182464,0.7828800082206726,0.8518542051315308,50000 -13574.1240670681,23.93623661994934,165568.15756726265,370046,0,165568.15756726265,0.6663000583648682,1.4581419229507446,10000,179185.61457157135,0.888476550579071,0.4112453460693359,0.7828800082206726,0.8518542051315308,50000 -13607.670770168304,24.01575541496277,165988.48790717125,370985,0,165988.48790717125,0.6663000583648682,1.4581419229507446,10000,179639.6193087101,0.8868359327316284,0.4194703996181488,0.7828800082206726,0.8518542051315308,50000 -13642.039429187776,24.09737181663513,166408.45634293556,371926,0,166408.45634293556,0.6663000583648682,1.4581419229507446,10000,180094.0869989395,0.8890234231948853,0.41407310962677,0.7828800082206726,0.8518542051315308,50000 -13676.183526039124,24.1772894859314,166828.33721232414,372867,0,166828.33721232414,0.6663000583648682,1.4581419229507446,10000,180548.2405500412,0.8907421827316284,0.4095943868160248,0.7828800082206726,0.8518542051315308,50000 -13710.648540973663,24.268741846084595,167248.5930762291,373806,0,167248.5930762291,0.6663000583648682,1.4581419229507446,10000,181003.10163211825,0.8892773389816284,0.4130665361881256,0.7828800082206726,0.8518542051315308,50000 -13745.768962621689,24.34765338897705,167668.823564291,374745,0,167668.823564291,0.6663000583648682,1.4581419229507446,10000,181458.58004879951,0.8873242139816284,0.4112350344657898,0.7828800082206726,0.8518542051315308,50000 -13780.542924642565,24.428040742874146,168088.80501580238,375685,0,168088.80501580238,0.6663000583648682,1.4581419229507446,10000,181913.4650197029,0.8878905773162842,0.4148047864437103,0.7828800082206726,0.8518542051315308,50000 -13816.431394577026,24.518927574157715,168509.0180311203,376625,0,168509.0180311203,0.6663000583648682,1.4581419229507446,10000,182369.70614933968,0.8874218463897705,0.4158387184143066,0.7828800082206726,0.8518542051315308,50000 -13850.50154018402,24.606639623641968,168929.21570897102,377565,0,168929.21570897102,0.6663000583648682,1.4581419229507446,10000,182824.11052680016,0.8898242115974426,0.407855361700058,0.7828800082206726,0.8518542051315308,50000 -13883.762291669846,24.693422079086304,169349.09867095947,378504,0,169349.09867095947,0.6663000583648682,1.4581419229507446,10000,183277.38987207413,0.8873632550239563,0.4182749986648559,0.7828800082206726,0.8518542051315308,50000 -13919.294860601423,24.78789639472961,169769.38675570488,379443,0,169769.38675570488,0.6663000583648682,1.4581419229507446,10000,183733.3545923233,0.88880854845047,0.4128335416316986,0.7828800082206726,0.8518542051315308,50000 -13953.581751823423,24.875013828277588,170189.518556118,380383,0,170189.518556118,0.6663000583648682,1.4581419229507446,10000,184187.91010832787,0.8867577910423279,0.4163850247859955,0.7828800082206726,0.8518542051315308,50000 -13987.508892059326,24.95596599578857,170609.48133444786,381324,0,170609.48133444786,0.6663000583648682,1.4581419229507446,10000,184641.93013739583,0.8875781297683716,0.4150454699993133,0.7828800082206726,0.8518542051315308,50000 -14021.341363430023,25.048585176467896,171029.39256739616,382265,0,171029.39256739616,0.6663000583648682,1.4581419229507446,10000,185095.8162283897,0.8875585794448853,0.4150723218917846,0.7828800082206726,0.8518542051315308,50000 -14056.252525568008,25.13417649269104,171449.58943510056,383205,0,171449.58943510056,0.6663000583648682,1.4581419229507446,10000,185551.0590927601,0.8885546922683716,0.4139377176761627,0.7828800082206726,0.8518542051315308,50000 -14090.210060834885,25.21590518951416,171869.548060894,384144,0,171869.548060894,0.6663000583648682,1.4581419229507446,10000,186005.1061720848,0.8873242139816284,0.4182506799697876,0.7828800082206726,0.8518542051315308,50000 -14125.25502538681,25.30760169029236,172289.46645498276,385083,0,172289.46645498276,0.6663000583648682,1.4581419229507446,10000,186460.21090960503,0.8873828053474426,0.4178923964500427,0.7828800082206726,0.8518542051315308,50000 -14158.521426916122,25.390140295028687,172709.62967181206,386025,0,172709.62967181206,0.6663000583648682,1.4581419229507446,10000,186913.7729110717,0.8904687166213989,0.4116628468036651,0.7828800082206726,0.8518542051315308,50000 -14193.428065538406,25.49014639854431,173129.71286416054,386964,0,173129.71286416054,0.6663000583648682,1.4581419229507446,10000,187368.9116373062,0.8865038752555847,0.4218724966049194,0.7828800082206726,0.8518542051315308,50000 -14227.59961605072,25.574037551879883,173549.81622862816,387900,0,173549.81622862816,0.6663000583648682,1.4581419229507446,10000,187823.31874752045,0.888476550579071,0.4137118458747864,0.7828800082206726,0.8518542051315308,50000 -14261.797946691511,25.674198389053345,173970.01164507866,388837,0,173970.01164507866,0.6663000583648682,1.4581419229507446,10000,188277.86225795743,0.8871679306030273,0.4193262457847595,0.7828800082206726,0.8518542051315308,50000 -14295.256955862043,25.76052308082581,174390.0960047245,389777,0,174390.0960047245,0.6663000583648682,1.4581419229507446,10000,188731.5424156189,0.8882616758346558,0.4169135689735412,0.7828800082206726,0.8518542051315308,50000 -14328.652228593826,25.845441102981567,174809.967348814,390717,0,174809.967348814,0.6663000583648682,1.4581419229507446,10000,189184.9430897236,0.8859570026397705,0.4196897149085998,0.7828800082206726,0.8518542051315308,50000 -14363.51612186432,25.941320419311523,175229.94012641907,391656,0,175229.94012641907,0.6663000583648682,1.4581419229507446,10000,189639.92478442192,0.8898046612739563,0.4087473452091217,0.7828800082206726,0.8518542051315308,50000 -14397.997807979584,26.028414011001587,175649.9788453579,392596,0,175649.9788453579,0.6663000583648682,1.4581419229507446,10000,190094.58134031296,0.8891991972923279,0.4111690521240234,0.7828800082206726,0.8518542051315308,50000 -14431.227613449097,26.12958836555481,176070.06787610054,393535,0,176070.06787610054,0.6663000583648682,1.4581419229507446,10000,190548.05064105988,0.8874413967132568,0.4186658263206482,0.7828800082206726,0.8518542051315308,50000 -14467.168281316755,26.231871128082275,176490.34965968132,394476,0,176490.34965968132,0.6663000583648682,1.4581419229507446,10000,191004.4239397049,0.8876562118530273,0.4162094295024872,0.7828800082206726,0.8518542051315308,50000 -14500.466915845873,26.326695442199707,176910.50852513313,395417,0,176910.50852513313,0.6663000583648682,1.4581419229507446,10000,191458.0256493092,0.8920117020606995,0.4101240336894989,0.7828800082206726,0.8518542051315308,50000 -14535.344248533249,26.43313837051392,177330.77657341957,396356,0,177330.77657341957,0.6663000583648682,1.4581419229507446,10000,191913.3264658451,0.88880854845047,0.409117728471756,0.7828800082206726,0.8518542051315308,50000 -14571.324562311172,26.5184805393219,177750.9477841854,397296,0,177750.9477841854,0.6663000583648682,1.4581419229507446,10000,192369.61221718788,0.8882226347923279,0.4123573303222656,0.7828800082206726,0.8518542051315308,50000 -14607.131270647047,26.60717272758484,178170.99898934364,398236,0,178170.99898934364,0.6663000583648682,1.4581419229507446,10000,192825.60821533203,0.88880854845047,0.4158269166946411,0.7828800082206726,0.8518542051315308,50000 -14642.470609903336,26.69651246070861,178591.1631834507,399173,0,178591.1631834507,0.6663000583648682,1.4581419229507446,10000,193281.24997234344,0.8885546922683716,0.4094387590885162,0.7828800082206726,0.8518542051315308,50000 -14677.70623278618,26.80361557006836,179011.04693436623,400113,0,179011.04693436623,0.6663000583648682,1.4581419229507446,10000,193736.52523708344,0.8860546946525574,0.4168408811092376,0.7828800082206726,0.8518542051315308,50000 -14711.077335596085,26.893108129501343,179431.1743233204,401052,0,179431.1743233204,0.6663000583648682,1.4581419229507446,10000,194190.1621646881,0.8895898461341858,0.4095662832260132,0.7828800082206726,0.8518542051315308,50000 -14746.11722946167,27.03999924659729,179851.25026535988,401993,0,179851.25026535988,0.6663000583648682,1.4581419229507446,10000,194645.4740064144,0.8876367211341858,0.4177990555763244,0.7828800082206726,0.8518542051315308,50000 -14779.690134763718,27.12631464004517,180271.34703087807,402930,0,180271.34703087807,0.6663000583648682,1.4581419229507446,10000,195099.27899551392,0.8872265219688416,0.4172936677932739,0.7828800082206726,0.8518542051315308,50000 -14813.849033117294,27.21081781387329,180691.3440322876,403869,0,180691.3440322876,0.6663000583648682,1.4581419229507446,10000,195553.56817293167,0.8881054520606995,0.4169188141822815,0.7828800082206726,0.8518542051315308,50000 -14849.860277414322,27.298006534576416,181111.3825807572,404806,0,181111.3825807572,0.6663000583648682,1.4581419229507446,10000,196009.753885746,0.8880273103713989,0.4125872850418091,0.7828800082206726,0.8518542051315308,50000 -14886.262979507446,27.38276648521424,181531.36791729927,405746,0,181531.36791729927,0.6663000583648682,1.4581419229507446,10000,196466.27555465687,0.8882226347923279,0.4126043915748596,0.7828800082206726,0.8518542051315308,50000 -14920.631911039352,27.48050045967102,181951.2723581791,406686,0,181951.2723581791,0.6663000583648682,1.4581419229507446,10000,196920.69525766373,0.8878124952316284,0.4145838916301727,0.7828800082206726,0.8518542051315308,50000 -14954.981954574583,27.58235263824463,182371.54948091507,407624,0,182371.54948091507,0.6663000583648682,1.4581419229507446,10000,197375.4734230041,0.8862109184265137,0.419201523065567,0.7828800082206726,0.8518542051315308,50000 -14990.611825942991,27.677670001983643,182791.8166847229,408562,0,182791.8166847229,0.6663000583648682,1.4581419229507446,10000,197831.5146150589,0.8884179592132568,0.4154849052429199,0.7828800082206726,0.8518542051315308,50000 -15025.966703891754,27.77146291732788,183211.9351406097,409503,0,183211.9351406097,0.6663000583648682,1.4581419229507446,10000,198287.1301677227,0.8890624642372131,0.412666767835617,0.7828800082206726,0.8518542051315308,50000 -15061.399684906006,27.8588445186615,183632.14395451543,410441,0,183632.14395451543,0.6663000583648682,1.4581419229507446,10000,198742.9077372551,0.8878710865974426,0.4203723073005676,0.7828800082206726,0.8518542051315308,50000 -15095.746675729752,27.959092378616333,184052.0600640773,411379,0,184052.0600640773,0.6663000583648682,1.4581419229507446,10000,199197.3196368217,0.8880273103713989,0.4152186810970306,0.7828800082206726,0.8518542051315308,50000 -15131.023107528688,28.0471260547638,184472.1326699257,412317,0,184472.1326699257,0.6663000583648682,1.4581419229507446,10000,199652.8047463894,0.8885546922683716,0.4164588153362274,0.7828800082206726,0.8518542051315308,50000 -15164.818674087524,28.145170211791992,184892.1034603119,413257,0,184892.1034603119,0.6663000583648682,1.4581419229507446,10000,200106.71964406967,0.8854687213897705,0.4227531850337982,0.7828800082206726,0.8518542051315308,50000 -15198.814550161362,28.235026121139526,185312.38032460213,414195,0,185312.38032460213,0.6663000583648682,1.4581419229507446,10000,200561.1308102608,0.8895312547683716,0.4116630852222442,0.7828800082206726,0.8518542051315308,50000 -15233.74088358879,28.325153589248657,185732.29559206963,415132,0,185732.29559206963,0.6663000583648682,1.4581419229507446,10000,201016.1106908321,0.8891991972923279,0.4130409955978393,0.7828800082206726,0.8518542051315308,50000 -15268.27907562256,28.421623706817627,186152.5260412693,416073,0,186152.5260412693,0.6663000583648682,1.4581419229507446,10000,201471.02529382703,0.8877929449081421,0.416156530380249,0.7828800082206726,0.8518542051315308,50000 -15302.907843351364,28.50839161872864,186572.4598255157,417016,0,186572.4598255157,0.6663000583648682,1.4581419229507446,10000,201925.7241203785,0.8877539038658142,0.4165347814559936,0.7828800082206726,0.8518542051315308,50000 -15339.094573259354,28.86305069923401,186992.08824324608,417955,0,186992.08824324608,0.6663000583648682,1.4581419229507446,10000,202381.94254612923,0.88832026720047,0.4133238494396209,0.7828800082206726,0.8518542051315308,50000 -15372.716371774672,28.951866388320923,187412.05957746503,418895,0,187412.05957746503,0.6663000583648682,1.4581419229507446,10000,202835.6733837128,0.8907421827316284,0.412113755941391,0.7828800082206726,0.8518542051315308,50000 -15407.61594223976,29.050409078598022,187832.055000782,419835,0,187832.055000782,0.6663000583648682,1.4581419229507446,10000,203290.7157907486,0.8883398175239563,0.4108054339885711,0.7828800082206726,0.8518542051315308,50000 -15442.47093296051,29.14796638488769,188252.1757707596,420776,0,188252.1757707596,0.6663000583648682,1.4581419229507446,10000,203745.8379309177,0.8899804353713989,0.4098820686340332,0.7828800082206726,0.8518542051315308,50000 -15478.182670593262,29.23967170715332,188672.410371542,421715,0,188672.410371542,0.6663000583648682,1.4581419229507446,10000,204201.9252877236,0.8877539038658142,0.412142664194107,0.7828800082206726,0.8518542051315308,50000 -15514.970543146132,29.327571868896484,189092.6740632057,422654,0,189092.6740632057,0.6663000583648682,1.4581419229507446,10000,204659.1133544445,0.8885155916213989,0.4149311780929565,0.7828800082206726,0.8518542051315308,50000 -15549.48907828331,29.4142906665802,189513.013206482,423593,0,189513.013206482,0.6663000583648682,1.4581419229507446,10000,205114.1068179608,0.88671875,0.4145254790782928,0.7828800082206726,0.8518542051315308,50000 -15584.46364927292,29.50173306465149,189933.1962766648,424535,0,189933.1962766648,0.6663000583648682,1.4581419229507446,10000,205569.40095114708,0.8887695074081421,0.4101097881793976,0.7828800082206726,0.8518542051315308,50000 -15621.176329135897,29.60228538513184,190353.0560581684,425473,0,190353.0560581684,0.6663000583648682,1.4581419229507446,10000,206026.12229013443,0.8880859017372131,0.4189074635505676,0.7828800082206726,0.8518542051315308,50000 -15656.826608181,29.69010305404663,190773.18273472783,426408,0,190773.18273472783,0.6663000583648682,1.4581419229507446,10000,206482.0358054638,0.8875585794448853,0.4154538512229919,0.7828800082206726,0.8518542051315308,50000 -15691.703318595886,29.78359007835388,191193.0832927227,427347,0,191193.0832927227,0.6663000583648682,1.4581419229507446,10000,206936.9553265572,0.8865429759025574,0.4164008200168609,0.7828800082206726,0.8518542051315308,50000 -15727.489582061768,29.8861780166626,191613.23265123367,428284,0,191613.23265123367,0.6663000583648682,1.4581419229507446,10000,207393.0426137448,0.8866796493530273,0.4184194803237915,0.7828800082206726,0.8518542051315308,50000 -15761.853075504305,29.976338148117065,192033.45323348045,429223,0,192033.45323348045,0.6663000583648682,1.4581419229507446,10000,207847.76594495773,0.8891991972923279,0.411479115486145,0.7828800082206726,0.8518542051315308,50000 -15796.966268777847,30.06338238716125,192453.38086915016,430160,0,192453.38086915016,0.6663000583648682,1.4581419229507446,10000,208302.9422523976,0.8868749737739563,0.4184442162513733,0.7828800082206726,0.8518542051315308,50000 -15831.376549959185,30.15148520469665,192873.32185602188,431097,0,192873.32185602188,0.6663000583648682,1.4581419229507446,10000,208757.430749178,0.8900390267372131,0.4117555618286133,0.7828800082206726,0.8518542051315308,50000 -15864.387818336489,30.24280261993408,193293.54080319405,432036,0,193293.54080319405,0.6663000583648682,1.4581419229507446,10000,209210.8012115956,0.8875781297683716,0.4167491197586059,0.7828800082206726,0.8518542051315308,50000 -15897.666623830795,30.3501980304718,193713.5750546456,432975,0,193713.5750546456,0.6663000583648682,1.4581419229507446,10000,209664.2701013088,0.88880854845047,0.4180719554424286,0.7828800082206726,0.8518542051315308,50000 -15932.76951098442,30.45403957366944,194133.56434631348,433913,0,194133.56434631348,0.6663000583648682,1.4581419229507446,10000,210119.51546907425,0.8876367211341858,0.4186112880706787,0.7828800082206726,0.8518542051315308,50000 -15966.977990627289,30.544241189956665,194553.46647405624,434852,0,194553.46647405624,0.6663000583648682,1.4581419229507446,10000,210573.7652647496,0.8868945240974426,0.4147317111492157,0.7828800082206726,0.8518542051315308,50000 -16001.56213068962,30.63562846183777,194973.41510796547,435791,0,194973.41510796547,0.6663000583648682,1.4581419229507446,10000,211028.4382355213,0.8885546922683716,0.4164343178272247,0.7828800082206726,0.8518542051315308,50000 -16037.476316690443,30.72657537460327,195393.3149046898,436729,0,195393.3149046898,0.6663000583648682,1.4581419229507446,10000,211484.39165091515,0.8863476514816284,0.4209548830986023,0.7828800082206726,0.8518542051315308,50000 -16072.608691453934,30.82048726081848,195813.20502996445,437668,0,195813.20502996445,0.6663000583648682,1.4581419229507446,10000,211939.55713367465,0.8892577886581421,0.4133823215961456,0.7828800082206726,0.8518542051315308,50000 -16106.368917942047,30.92542052268982,196233.2885670662,438606,0,196233.2885670662,0.6663000583648682,1.4581419229507446,10000,212393.555311203,0.8886327743530273,0.4106488823890686,0.7828800082206726,0.8518542051315308,50000 -16141.804993867874,31.039782285690308,196653.4536757469,439545,0,196653.4536757469,0.6663000583648682,1.4581419229507446,10000,212849.3200573921,0.8883788585662842,0.4163420498371124,0.7828800082206726,0.8518542051315308,50000 -16177.474328279495,31.147948503494263,197073.50745630264,440484,0,197073.50745630264,0.6663000583648682,1.4581419229507446,10000,213305.20050811768,0.8870507478713989,0.4158827364444732,0.7828800082206726,0.8518542051315308,50000 -16213.603892564774,31.240306854248047,197493.4744989872,441423,0,197493.4744989872,0.6663000583648682,1.4581419229507446,10000,213761.4383919239,0.8891406059265137,0.4135304689407348,0.7828800082206726,0.8518542051315308,50000 -16248.398986577988,31.336047172546387,197913.55271673205,442361,0,197913.55271673205,0.6663000583648682,1.4581419229507446,10000,214216.4568591118,0.8888866901397705,0.4150472581386566,0.7828800082206726,0.8518542051315308,50000 -16283.67305135727,31.426677703857425,198333.7465114593,443298,0,198333.7465114593,0.6663000583648682,1.4581419229507446,10000,214672.06289815903,0.8919921517372131,0.4033390283584595,0.7828800082206726,0.8518542051315308,50000 -16319.972463607788,31.51662302017212,198753.99509072304,444235,0,198753.99509072304,0.6663000583648682,1.4581419229507446,10000,215128.74952864647,0.8896874785423279,0.4114967882633209,0.7828800082206726,0.8518542051315308,50000 -16355.516724586489,31.621028661727905,199174.1624150276,445174,0,199174.1624150276,0.6663000583648682,1.4581419229507446,10000,215584.6144790649,0.8869726657867432,0.4166136384010315,0.7828800082206726,0.8518542051315308,50000 -16391.367474079132,31.726098775863647,199594.30109024048,446111,0,199594.30109024048,0.6663000583648682,1.4581419229507446,10000,216040.7579071521,0.8872851133346558,0.4126180410385132,0.7828800082206726,0.8518542051315308,50000 -16427.27664041519,31.828654527664185,200014.64024281505,447049,0,200014.64024281505,0.6663000583648682,1.4581419229507446,10000,216497.158163786,0.88734370470047,0.4156618416309356,0.7828800082206726,0.8518542051315308,50000 -16463.17485022545,31.92135524749756,200434.81145572665,447986,0,200434.81145572665,0.6663000583648682,1.4581419229507446,10000,216953.36850714684,0.8880078196525574,0.4107287228107452,0.7828800082206726,0.8518542051315308,50000 -16499.20669221878,32.01465153694153,200855.06634163857,448924,0,200855.06634163857,0.6663000583648682,1.4581419229507446,10000,217409.7971892357,0.8885155916213989,0.4190287590026855,0.7828800082206726,0.8518542051315308,50000 -16534.43468928337,32.109177350997925,201274.92774033544,449862,0,201274.92774033544,0.6663000583648682,1.4581419229507446,10000,217865.0299890041,0.8880664110183716,0.4124629497528076,0.7828800082206726,0.8518542051315308,50000 -16570.566035032272,32.20013737678528,201694.9643316269,450799,0,201694.9643316269,0.6663000583648682,1.4581419229507446,10000,218321.3376784325,0.8866210579872131,0.4189048409461975,0.7828800082206726,0.8518542051315308,50000 -16606.675546884537,32.308313846588135,202115.2426486016,451735,0,202115.2426486016,0.6663000583648682,1.4581419229507446,10000,218777.88199329376,0.8884375095367432,0.4119004011154175,0.7828800082206726,0.8518542051315308,50000 -16641.969081640244,32.402332067489624,202535.2054517269,452675,0,202535.2054517269,0.6663000583648682,1.4581419229507446,10000,219233.2810678482,0.8880664110183716,0.4146206080913543,0.7828800082206726,0.8518542051315308,50000 -16675.76238799095,32.494225025177,202955.50574207303,453614,0,202955.50574207303,0.6663000583648682,1.4581419229507446,10000,219687.5152556896,0.8863085508346558,0.4198028445243835,0.7828800082206726,0.8518542051315308,50000 -16712.418529748917,32.59436345100403,203375.631287098,454552,0,203375.631287098,0.6663000583648682,1.4581419229507446,10000,220144.4460659027,0.8874413967132568,0.418776273727417,0.7828800082206726,0.8518542051315308,50000 -16748.549387454987,32.69032335281372,203795.91160702705,455492,0,203795.91160702705,0.6663000583648682,1.4581419229507446,10000,220601.0011694432,0.8889843821525574,0.4140480756759643,0.7828800082206726,0.8518542051315308,50000 -16783.94956278801,32.78437304496765,204215.98255586624,456431,0,204215.98255586624,0.6663000583648682,1.4581419229507446,10000,221056.6145207882,0.8893945217132568,0.4115277528762817,0.7828800082206726,0.8518542051315308,50000 -16819.28258252144,32.87864851951599,204635.8847630024,457370,0,204635.8847630024,0.6663000583648682,1.4581419229507446,10000,221511.99255919456,0.8887109160423279,0.4161616861820221,0.7828800082206726,0.8518542051315308,50000 -16853.451112270355,32.97318458557129,205055.78770446777,458307,0,205055.78770446777,0.6663000583648682,1.4581419229507446,10000,221966.20752811432,0.8866796493530273,0.4212839007377624,0.7828800082206726,0.8518542051315308,50000 -16888.39018678665,33.064754486083984,205475.6713354588,459242,0,205475.6713354588,0.6663000583648682,1.4581419229507446,10000,222421.1706006527,0.8888671398162842,0.4136143922805786,0.7828800082206726,0.8518542051315308,50000 -16922.828406095505,33.157649517059326,205895.7057659626,460179,0,205895.7057659626,0.6663000583648682,1.4581419229507446,10000,222875.7852590084,0.8866991996765137,0.41916623711586,0.7828800082206726,0.8518542051315308,50000 -16958.004980802536,33.26661014556885,206315.9453783036,461118,0,206315.9453783036,0.6663000583648682,1.4581419229507446,10000,223331.3592851162,0.8881054520606995,0.4115708768367767,0.7828800082206726,0.8518542051315308,50000 -16994.971632242203,33.36212921142578,206736.0901172161,462057,0,206736.0901172161,0.6663000583648682,1.4581419229507446,10000,223788.6147623062,0.8887499570846558,0.4157285094261169,0.7828800082206726,0.8518542051315308,50000 -17030.81745314598,33.45702028274536,207156.234416008,462995,0,207156.234416008,0.6663000583648682,1.4581419229507446,10000,224244.74796533585,0.8881054520606995,0.4158811569213867,0.7828800082206726,0.8518542051315308,50000 -17066.472289562225,33.561460733413696,207576.18787121773,463932,0,207576.18787121773,0.6663000583648682,1.4581419229507446,10000,224700.5097444057,0.8861132860183716,0.4176024496555328,0.7828800082206726,0.8518542051315308,50000 -17100.941737413406,33.65341758728027,207996.31915855408,464867,0,207996.31915855408,0.6663000583648682,1.4581419229507446,10000,225155.25125455856,0.889941394329071,0.4139235317707062,0.7828800082206726,0.8518542051315308,50000 -17136.74747443199,33.74782729148865,208416.38221096992,465804,0,208416.38221096992,0.6663000583648682,1.4581419229507446,10000,225611.26416301727,0.8892577886581421,0.4131757915019989,0.7828800082206726,0.8518542051315308,50000 -17173.017083644867,33.8416645526886,208836.4301431179,466742,0,208836.4301431179,0.6663000583648682,1.4581419229507446,10000,226067.72537398327,0.8897460699081421,0.4074950218200683,0.7828800082206726,0.8518542051315308,50000 -17207.967556238174,33.94882917404175,209256.819983244,467682,0,209256.819983244,0.6663000583648682,1.4581419229507446,10000,226523.2222290039,0.88929682970047,0.4114511907100677,0.7828800082206726,0.8518542051315308,50000 -17242.25678539276,34.04836964607239,209676.87094807625,468619,0,209676.87094807625,0.6663000583648682,1.4581419229507446,10000,226977.7106547356,0.8893749713897705,0.4095381200313568,0.7828800082206726,0.8518542051315308,50000 -17277.135021448135,34.15190577507019,210097.1221292019,469559,0,210097.1221292019,0.6663000583648682,1.4581419229507446,10000,227432.9923608303,0.8883788585662842,0.4138410091400146,0.7828800082206726,0.8518542051315308,50000 -17313.09035563469,34.2598888874054,210517.18369412425,470494,0,210517.18369412425,0.6663000583648682,1.4581419229507446,10000,227889.16590118408,0.8877343535423279,0.4128125607967376,0.7828800082206726,0.8518542051315308,50000 -17349.413534879684,34.35630774497986,210937.3222939968,471435,0,210937.3222939968,0.6663000583648682,1.4581419229507446,10000,228345.77362632751,0.8878124952316284,0.4137357473373413,0.7828800082206726,0.8518542051315308,50000 -17384.368765592575,34.45077323913574,211357.5579998493,472376,0,211357.5579998493,0.6663000583648682,1.4581419229507446,10000,228801.10828518867,0.8859961032867432,0.421807587146759,0.7828800082206726,0.8518542051315308,50000 -17420.123772144318,34.5459508895874,211777.6955449581,473316,0,211777.6955449581,0.6663000583648682,1.4581419229507446,10000,229257.14534235,0.8886523246765137,0.4129462838172912,0.7828800082206726,0.8518542051315308,50000 -17453.64389872551,34.639634132385254,212197.9356815815,474255,0,212197.9356815815,0.6663000583648682,1.4581419229507446,10000,229711.0480697155,0.8865429759025574,0.4211982488632202,0.7828800082206726,0.8518542051315308,50000 -17489.702049016953,34.77641224861145,212617.8903603553,475192,0,212617.8903603553,0.6663000583648682,1.4581419229507446,10000,230167.2466096878,0.888671875,0.4101220667362213,0.7828800082206726,0.8518542051315308,50000 -17524.20827102661,34.870837688446045,213038.1022245884,476134,0,213038.1022245884,0.6663000583648682,1.4581419229507446,10000,230622.10814762115,0.8873632550239563,0.415277898311615,0.7828800082206726,0.8518542051315308,50000 -17559.069897413254,34.96437788009644,213458.12177467343,477074,0,213458.12177467343,0.6663000583648682,1.4581419229507446,10000,231077.1324160099,0.8871874809265137,0.4179157018661499,0.7828800082206726,0.8518542051315308,50000 -17593.989289283752,35.059608697891235,213878.38510799408,478014,0,213878.38510799408,0.6663000583648682,1.4581419229507446,10000,231532.45939064023,0.8897460699081421,0.4113948047161102,0.7828800082206726,0.8518542051315308,50000 -17630.104825258255,35.1567656993866,214298.27096557617,478951,0,214298.27096557617,0.6663000583648682,1.4581419229507446,10000,231988.60705900192,0.8887109160423279,0.4161091148853302,0.7828800082206726,0.8518542051315308,50000 -17664.77192544937,35.27648377418518,214718.205119133,479893,0,214718.205119133,0.6663000583648682,1.4581419229507446,10000,232443.3776421547,0.8892577886581421,0.4139748215675354,0.7828800082206726,0.8518542051315308,50000 -17698.574368476868,35.3733856678009,215138.26736354828,480834,0,215138.26736354828,0.6663000583648682,1.4581419229507446,10000,232897.3886890412,0.8868163824081421,0.4217853844165802,0.7828800082206726,0.8518542051315308,50000 -17734.160432100296,35.47701930999756,215558.27237558365,481774,0,215558.27237558365,0.6663000583648682,1.4581419229507446,10000,233353.13364720345,0.8877929449081421,0.4164779186248779,0.7828800082206726,0.8518542051315308,50000 -17770.049526929855,35.57323598861694,215978.2754702568,482712,0,215978.2754702568,0.6663000583648682,1.4581419229507446,10000,233809.1701474189,0.8885351419448853,0.4163164794445038,0.7828800082206726,0.8518542051315308,50000 -17803.500860214233,35.67725419998169,216398.384973526,483652,0,216398.384973526,0.6663000583648682,1.4581419229507446,10000,234262.88517141345,0.8861523270606995,0.4175904095172882,0.7828800082206726,0.8518542051315308,50000 -17838.617998600006,35.79676175117493,216818.5877084732,484589,0,216818.5877084732,0.6663000583648682,1.4581419229507446,10000,234718.37487339973,0.8881054520606995,0.4143379628658294,0.7828800082206726,0.8518542051315308,50000 -17874.568171977997,35.89333629608154,217238.5685465336,485529,0,217238.5685465336,0.6663000583648682,1.4581419229507446,10000,235174.451969862,0.8881640434265137,0.4148095548152923,0.7828800082206726,0.8518542051315308,50000 -17908.249662160873,35.98907494544983,217658.7013375759,486470,0,217658.7013375759,0.6663000583648682,1.4581419229507446,10000,235628.41182494164,0.88818359375,0.4164724946022033,0.7828800082206726,0.8518542051315308,50000 -17940.89183330536,36.08620810508728,218078.80738782883,487409,0,218078.80738782883,0.6663000583648682,1.4581419229507446,10000,236081.30646657944,0.8879492282867432,0.4153375029563904,0.7828800082206726,0.8518542051315308,50000 -17976.68732571602,36.20460081100464,218498.74874806404,488346,0,218498.74874806404,0.6663000583648682,1.4581419229507446,10000,236537.2109837532,0.8887695074081421,0.4143279790878296,0.7828800082206726,0.8518542051315308,50000 -18012.489437818527,36.3038694858551,218918.67112493515,489283,0,218918.67112493515,0.6663000583648682,1.4581419229507446,10000,236993.0835986137,0.8891015648841858,0.4146701693534851,0.7828800082206726,0.8518542051315308,50000 -18047.40689873696,36.40238857269287,219338.8282425404,490226,0,219338.8282425404,0.6663000583648682,1.4581419229507446,10000,237448.3058791161,0.8903515338897705,0.4075874984264374,0.7828800082206726,0.8518542051315308,50000 -18082.4962182045,36.51316452026367,219758.6913704872,491164,0,219758.6913704872,0.6663000583648682,1.4581419229507446,10000,237903.41774082184,0.8893554210662842,0.4099286496639251,0.7828800082206726,0.8518542051315308,50000 -18117.168116807938,36.611567735672,220178.8872082233,492103,0,220178.8872082233,0.6663000583648682,1.4581419229507446,10000,238358.4329917431,0.8878515362739563,0.4151861369609833,0.7828800082206726,0.8518542051315308,50000 -18151.76734852791,36.70686912536621,220598.7592358589,493045,0,220598.7592358589,0.6663000583648682,1.4581419229507446,10000,238813.04859733584,0.8883984088897705,0.414981484413147,0.7828800082206726,0.8518542051315308,50000 -18185.89629340172,36.804692029953,221018.6685461998,493983,0,221018.6685461998,0.6663000583648682,1.4581419229507446,10000,239267.2332293988,0.8869531154632568,0.4125227630138397,0.7828800082206726,0.8518542051315308,50000 -18221.183396816254,36.91400504112244,221438.6368880272,494921,0,221438.6368880272,0.6663000583648682,1.4581419229507446,10000,239722.64675951004,0.8883007764816284,0.4121655523777008,0.7828800082206726,0.8518542051315308,50000 -18257.582832574844,37.0140278339386,221858.705681324,495861,0,221858.705681324,0.6663000583648682,1.4581419229507446,10000,240179.26391124725,0.88832026720047,0.4151731729507446,0.7828800082206726,0.8518542051315308,50000 -18292.674980402,37.12426042556763,222278.61622595787,496799,0,222278.61622595787,0.6663000583648682,1.4581419229507446,10000,240634.42562317848,0.88783198595047,0.4155504107475281,0.7828800082206726,0.8518542051315308,50000 -18329.07387948036,37.22601556777954,222698.61618041992,497738,0,222698.61618041992,0.6663000583648682,1.4581419229507446,10000,241090.9749581813,0.8870507478713989,0.4186049699783325,0.7828800082206726,0.8518542051315308,50000 -18363.83420419693,37.32507681846619,223118.7795138359,498677,0,223118.7795138359,0.6663000583648682,1.4581419229507446,10000,241546.0461564064,0.8886913657188416,0.412122905254364,0.7828800082206726,0.8518542051315308,50000 -18399.62007832527,37.42348909378052,223538.7945408821,499615,0,223538.7945408821,0.6663000583648682,1.4581419229507446,10000,242001.99430322647,0.88734370470047,0.4127151072025299,0.7828800082206726,0.8518542051315308,50000 -18434.819314956665,37.53372430801392,223958.84918499,500556,0,223958.84918499,0.6663000583648682,1.4581419229507446,10000,242457.4077973365,0.8880664110183716,0.4174224436283111,0.7828800082206726,0.8518542051315308,50000 -18470.24125123024,37.64730954170227,224378.8450908661,501496,0,224378.8450908661,0.6663000583648682,1.4581419229507446,10000,242912.9882376194,0.8877148032188416,0.4193105101585388,0.7828800082206726,0.8518542051315308,50000 -18505.77114391327,37.74493646621704,224798.8704931736,502436,0,224798.8704931736,0.6663000583648682,1.4581419229507446,10000,243368.6910982132,0.8877733945846558,0.4132482409477234,0.7828800082206726,0.8518542051315308,50000 -18539.91388988495,37.84392428398132,225219.12637400627,503376,0,225219.12637400627,0.6663000583648682,1.4581419229507446,10000,243823.23779582977,0.8895312547683716,0.4103905856609344,0.7828800082206726,0.8518542051315308,50000 -18576.1207678318,37.96061396598816,225639.3281033039,504311,0,225639.3281033039,0.6663000583648682,1.4581419229507446,10000,244279.811917305,0.8867968320846558,0.422403335571289,0.7828800082206726,0.8518542051315308,50000 -18612.577392101288,38.06256675720215,226059.1844224929,505248,0,226059.1844224929,0.6663000583648682,1.4581419229507446,10000,244736.27666950223,0.888671875,0.4171231091022491,0.7828800082206726,0.8518542051315308,50000 -18648.43890714645,38.15953755378723,226479.16041707995,506186,0,226479.16041707995,0.6663000583648682,1.4581419229507446,10000,245192.26065707207,0.88832026720047,0.4186186492443084,0.7828800082206726,0.8518542051315308,50000 -18683.35116648674,38.262371301651,226899.0175216198,507127,0,226899.0175216198,0.6663000583648682,1.4581419229507446,10000,245647.18233394623,0.8861913681030273,0.4204094707965851,0.7828800082206726,0.8518542051315308,50000 -18719.78503012657,38.373663663864136,227319.15297055244,508065,0,227319.15297055244,0.6663000583648682,1.4581419229507446,10000,246103.9114441872,0.8902148008346558,0.4068226814270019,0.7828800082206726,0.8518542051315308,50000 -18753.6684820652,38.47547960281372,227739.37865519524,509005,0,227739.37865519524,0.6663000583648682,1.4581419229507446,10000,246558.1716852188,0.8861913681030273,0.420152872800827,0.7828800082206726,0.8518542051315308,50000 -18789.85410284996,38.58154392242432,228159.2598702908,509943,0,228159.2598702908,0.6663000583648682,1.4581419229507446,10000,247014.39329242703,0.8894726634025574,0.4081210494041443,0.7828800082206726,0.8518542051315308,50000 -18824.52872061729,38.695571184158325,228579.4013376236,510880,0,228579.4013376236,0.6663000583648682,1.4581419229507446,10000,247469.3725888729,0.8873046636581421,0.4174008965492248,0.7828800082206726,0.8518542051315308,50000 -18860.66403913498,38.80521607398987,228999.4161813259,511817,0,228999.4161813259,0.6663000583648682,1.4581419229507446,10000,247925.6817638874,0.8889257907867432,0.4128748476505279,0.7828800082206726,0.8518542051315308,50000 -18894.84908437729,38.91966819763184,229419.3546979428,512755,0,229419.3546979428,0.6663000583648682,1.4581419229507446,10000,248379.9687492848,0.8876953125,0.4195302724838257,0.7828800082206726,0.8518542051315308,50000 -18928.725402593613,39.0208375453949,229839.27992606163,513690,0,229839.27992606163,0.6663000583648682,1.4581419229507446,10000,248833.9214372635,0.8895312547683716,0.4111018180847168,0.7828800082206726,0.8518542051315308,50000 -18965.40698552132,39.12135338783264,230259.215269804,514623,0,230259.215269804,0.6663000583648682,1.4581419229507446,10000,249290.68759441376,0.8900585770606995,0.4098351299762726,0.7828800082206726,0.8518542051315308,50000 -19002.11077594757,39.22586536407471,230679.0721449852,515560,0,230679.0721449852,0.6663000583648682,1.4581419229507446,10000,249747.4019293785,0.8890234231948853,0.4100822210311889,0.7828800082206726,0.8518542051315308,50000 -19037.79772377014,39.339415073394775,231099.2383217812,516498,0,231099.2383217812,0.6663000583648682,1.4581419229507446,10000,250203.4171028137,0.8893945217132568,0.4079234898090362,0.7828800082206726,0.8518542051315308,50000 -19071.938607931137,39.44188857078552,231519.41894960403,517437,0,231519.41894960403,0.6663000583648682,1.4581419229507446,10000,250657.88970041275,0.8845898509025574,0.4218060672283172,0.7828800082206726,0.8518542051315308,50000 -19108.14038753509,39.543895959854126,231939.48673439023,518375,0,231939.48673439023,0.6663000583648682,1.4581419229507446,10000,251114.31041026115,0.8900195360183716,0.4088830351829529,0.7828800082206726,0.8518542051315308,50000 -19143.85590171814,39.65925049781799,232360.05365729332,519311,0,232360.05365729332,0.6663000583648682,1.4581419229507446,10000,251570.7573211193,0.8872656226158142,0.4180852770805359,0.7828800082206726,0.8518542051315308,50000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/measurements.csv deleted file mode 100644 index 820fb7134..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/measurements.csv +++ /dev/null @@ -1,5754 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,validation/accuracy,validation/loss,validation/num_examples,test/accuracy,test/loss,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,0.351232,6.9077563,,,,,,,,,,,,,, -1,,,0.0009960937313735,6.907756328582764,0.0009999999310821,6.9077558517456055,50000.0,0.0010000000474974,6.907756805419922,10000.0,44.79174327850342,86.80535459518433,44.79174327850342,42.01347494125366,0.0,0.0 -100,0.45824468,6.886222,,,,,,,,,,,,,, -200,0.8139546,6.7839723,,,,,,,,,,,,,, -300,1.3399607,6.6524596,,,,,,,,,,,,,, -400,1.4092907,6.608524,,,,,,,,,,,,,, -500,1.0150961,6.4417295,,,,,,,,,,,,,, -600,1.3024337,6.3289294,,,,,,,,,,,,,, -700,1.1906755,6.2654877,,,,,,,,,,,,,, -800,1.3676158,6.215804,,,,,,,,,,,,,, -900,2.029803,6.2277174,,,,,,,,,,,,,, -911,,,0.0321874991059303,5.927576541900635,0.0300399996340274,5.954702377319336,50000.0,0.0240000002086162,6.074681758880615,10000.0,465.0541150569916,528.6460611820221,465.0541150569916,63.51478552818298,0.0266435146331787,0.0 -1000,1.2887667,6.03848,,,,,,,,,,,,,, -1100,1.2446568,5.960166,,,,,,,,,,,,,, -1200,1.1817067,6.2947073,,,,,,,,,,,,,, -1300,1.3577219,5.963889,,,,,,,,,,,,,, -1400,1.3026958,5.8858576,,,,,,,,,,,,,, -1500,1.0981282,6.09473,,,,,,,,,,,,,, -1600,1.3342332,5.7796545,,,,,,,,,,,,,, -1700,0.976218,5.803779,,,,,,,,,,,,,, -1800,1.2013091,5.843648,,,,,,,,,,,,,, -1870,,,0.0830664038658142,5.143368244171143,0.0786999985575676,5.1934967041015625,50000.0,0.0636000037193298,5.44067907333374,10000.0,885.2891094684601,970.5481154918672,885.2891094684601,85.10202193260193,0.053812026977539,0.0 -1900,1.0409439,6.134364,,,,,,,,,,,,,, -2000,1.1345227,5.6502213,,,,,,,,,,,,,, -2100,1.0412841,5.5517087,,,,,,,,,,,,,, -2200,1.0325232,5.4317975,,,,,,,,,,,,,, -2300,1.3840597,5.604598,,,,,,,,,,,,,, -2400,0.9878423,5.5338626,,,,,,,,,,,,,, -2500,0.88261503,5.6312137,,,,,,,,,,,,,, -2600,1.1280584,5.4292603,,,,,,,,,,,,,, -2700,1.0806464,5.421567,,,,,,,,,,,,,, -2800,0.7673526,6.414533,,,,,,,,,,,,,, -2830,,,0.1516406238079071,4.460026264190674,0.1383199989795684,4.5624847412109375,50000.0,0.1094000041484832,4.899428367614746,10000.0,1305.504187822342,1412.643222808838,1305.504187822342,106.90083980560304,0.0830845832824707,0.0 -2900,1.0160127,5.573594,,,,,,,,,,,,,, -3000,1.5141512,5.135382,,,,,,,,,,,,,, -3100,0.91345197,5.2075505,,,,,,,,,,,,,, -3200,1.0291646,5.1249523,,,,,,,,,,,,,, -3300,1.4217256,5.1316257,,,,,,,,,,,,,, -3400,1.0347896,5.055929,,,,,,,,,,,,,, -3500,1.1369736,4.9819636,,,,,,,,,,,,,, -3600,1.7900652,4.9198947,,,,,,,,,,,,,, -3700,0.8612973,5.0536647,,,,,,,,,,,,,, -3789,,,0.2010351568460464,4.068780422210693,0.1882999986410141,4.150966644287109,50000.0,0.1463000029325485,4.5577168464660645,10000.0,1725.7441012859344,1855.7330009937289,1725.7441012859344,129.67135000228882,0.1106207370758056,0.0 -3800,1.3576674,4.825762,,,,,,,,,,,,,, -3900,1.1853228,4.878256,,,,,,,,,,,,,, -4000,0.9051676,5.9124975,,,,,,,,,,,,,, -4100,0.9636565,4.62098,,,,,,,,,,,,,, -4200,1.2333477,5.3936143,,,,,,,,,,,,,, -4300,0.6716493,6.0264416,,,,,,,,,,,,,, -4400,0.9180312,4.405513,,,,,,,,,,,,,, -4500,0.82416505,5.0109124,,,,,,,,,,,,,, -4600,0.80467826,6.183475,,,,,,,,,,,,,, -4700,0.9140983,4.8315396,,,,,,,,,,,,,, -4746,,,0.2476171851158142,3.7524003982543945,0.2291399985551834,3.865080118179321,50000.0,0.1774000078439712,4.310952663421631,10000.0,2145.8203287124634,2299.0677371025085,2145.8203287124634,152.84962821006775,0.1393589973449707,0.0 -4800,0.6941479,5.9070396,,,,,,,,,,,,,, -4900,1.0629138,4.461196,,,,,,,,,,,,,, -5000,1.0198977,4.870679,,,,,,,,,,,,,, -5100,0.8990658,4.1790686,,,,,,,,,,,,,, -5200,0.9494799,4.351896,,,,,,,,,,,,,, -5300,0.91953236,4.3661156,,,,,,,,,,,,,, -5400,1.1289669,4.400463,,,,,,,,,,,,,, -5500,0.83074343,4.234141,,,,,,,,,,,,,, -5600,0.9200742,4.5395246,,,,,,,,,,,,,, -5700,0.8809751,4.3950734,,,,,,,,,,,,,, -5703,,,0.3254687488079071,3.208423376083374,0.2970999777317047,3.3675014972686768,50000.0,0.230100005865097,3.9030966758728014,10000.0,2565.7739176750183,2740.7684206962585,2565.7739176750183,174.51342248916626,0.1706171035766601,0.0 -5800,1.0887792,4.6468315,,,,,,,,,,,,,, -5900,0.78883064,4.663852,,,,,,,,,,,,,, -6000,0.89440644,5.2262907,,,,,,,,,,,,,, -6100,0.93273395,4.05927,,,,,,,,,,,,,, -6200,0.69972014,6.0281973,,,,,,,,,,,,,, -6300,0.9230855,4.0392404,,,,,,,,,,,,,, -6400,0.876843,4.1805973,,,,,,,,,,,,,, -6500,0.68344927,6.1331244,,,,,,,,,,,,,, -6600,0.8928042,4.4135017,,,,,,,,,,,,,, -6657,,,0.3775781095027923,2.9110069274902344,0.329039990901947,3.177209615707397,50000.0,0.2517000138759613,3.7387924194335938,10000.0,2985.7302856445312,3182.8387632369995,2985.7302856445312,196.54568314552307,0.2015841007232666,0.0 -6700,0.8887903,4.0686116,,,,,,,,,,,,,, -6800,0.8489484,3.8390973,,,,,,,,,,,,,, -6900,0.9594965,3.8731458,,,,,,,,,,,,,, -7000,0.89232576,4.3187222,,,,,,,,,,,,,, -7100,0.610714,5.5376782,,,,,,,,,,,,,, -7200,0.76814663,4.2363124,,,,,,,,,,,,,, -7300,0.8823738,4.911396,,,,,,,,,,,,,, -7400,1.0056081,3.9277623,,,,,,,,,,,,,, -7500,0.7167194,5.18421,,,,,,,,,,,,,, -7600,0.8935183,3.9492412,,,,,,,,,,,,,, -7611,,,0.3849999904632568,2.8280928134918213,0.3578599989414215,2.9794843196868896,50000.0,0.2745999991893768,3.569232940673828,10000.0,3405.709196329117,3625.086359500885,3405.709196329117,218.7306222915649,0.2344107627868652,0.0 -7700,0.7505098,4.1993866,,,,,,,,,,,,,, -7800,0.9383474,4.344995,,,,,,,,,,,,,, -7900,0.9390154,3.7383807,,,,,,,,,,,,,, -8000,0.87480384,3.695923,,,,,,,,,,,,,, -8100,1.0022157,3.8782074,,,,,,,,,,,,,, -8200,1.0406163,3.6022916,,,,,,,,,,,,,, -8300,0.8493634,3.774659,,,,,,,,,,,,,, -8400,0.82045954,4.5457582,,,,,,,,,,,,,, -8500,0.8529036,3.6530836,,,,,,,,,,,,,, -8554,,,0.4186523258686065,2.685232162475586,0.3840599954128265,2.860114574432373,50000.0,0.2945000231266022,3.458648204803467,10000.0,3825.88839006424,4070.308994531632,3825.88839006424,243.68302154541016,0.2739176750183105,0.0 -8600,0.8271018,3.7337217,,,,,,,,,,,,,, -8700,1.0454578,3.5917602,,,,,,,,,,,,,, -8800,0.9205541,3.554917,,,,,,,,,,,,,, -8900,0.9524821,3.788969,,,,,,,,,,,,,, -9000,0.8813031,3.6217778,,,,,,,,,,,,,, -9100,0.8380908,3.6950166,,,,,,,,,,,,,, -9200,0.7871325,4.075155,,,,,,,,,,,,,, -9300,1.0064226,3.637169,,,,,,,,,,,,,, -9400,1.0896808,3.4886723,,,,,,,,,,,,,, -9500,0.62346065,5.068738,,,,,,,,,,,,,, -9502,,,0.4482226371765136,2.503745555877685,0.4045199751853943,2.7109215259552,50000.0,0.3096000254154205,3.3151302337646484,10000.0,4245.836257696152,4518.890733718872,4245.836257696152,272.23877477645874,0.3016986846923828,0.0 -9600,0.81211036,4.164614,,,,,,,,,,,,,, -9700,0.8505894,4.6227226,,,,,,,,,,,,,, -9800,0.9701638,3.508632,,,,,,,,,,,,,, -9900,1.1321651,3.4797351,,,,,,,,,,,,,, -10000,0.77170587,5.2010703,,,,,,,,,,,,,, -10100,1.1253779,3.5875554,,,,,,,,,,,,,, -10200,0.9841303,4.5803657,,,,,,,,,,,,,, -10300,0.82215744,5.7803407,,,,,,,,,,,,,, -10400,0.7828481,4.5171437,,,,,,,,,,,,,, -10455,,,0.4553320109844208,2.4251720905303955,0.4303999841213226,2.5701441764831543,50000.0,0.3285000026226043,3.205925226211548,10000.0,4665.796049118042,4965.096488714218,4665.796049118042,298.39640188217163,0.3389160633087158,0.0 -10500,0.9677511,3.3516827,,,,,,,,,,,,,, -10600,0.9503567,3.2927608,,,,,,,,,,,,,, -10700,0.87529266,3.8084483,,,,,,,,,,,,,, -10800,0.95547724,4.761781,,,,,,,,,,,,,, -10900,0.811633,5.815688,,,,,,,,,,,,,, -11000,0.86227286,5.3848023,,,,,,,,,,,,,, -11100,0.7228992,5.6018405,,,,,,,,,,,,,, -11200,1.0208005,3.3704953,,,,,,,,,,,,,, -11300,1.0402768,3.3261232,,,,,,,,,,,,,, -11400,0.9386585,3.4590185,,,,,,,,,,,,,, -11407,,,0.4743359386920929,2.3213415145874023,0.4418199956417084,2.4998693466186523,50000.0,0.3394000232219696,3.1411850452423096,10000.0,5085.781652927399,5410.151679992676,5085.781652927399,323.38474798202515,0.3693301677703857,0.0 -11500,0.99019957,3.2860265,,,,,,,,,,,,,, -11600,0.8564634,4.388264,,,,,,,,,,,,,, -11700,1.1574694,3.264718,,,,,,,,,,,,,, -11800,0.81270903,5.68434,,,,,,,,,,,,,, -11900,0.84099966,4.1845517,,,,,,,,,,,,,, -12000,1.0023679,3.267723,,,,,,,,,,,,,, -12100,0.90733516,5.7179794,,,,,,,,,,,,,, -12200,0.7574543,4.785717,,,,,,,,,,,,,, -12300,0.9541403,3.2708414,,,,,,,,,,,,,, -12359,,,0.4992382824420929,2.207763910293579,0.4544000029563904,2.4317727088928223,50000.0,0.3520000278949737,3.048286199569702,10000.0,5506.2299082279205,5858.720902204514,5506.2299082279205,351.4263126850128,0.3970980644226074,0.0 -12400,1.181187,3.3870213,,,,,,,,,,,,,, -12500,0.93603045,5.23799,,,,,,,,,,,,,, -12600,0.92240787,3.503267,,,,,,,,,,,,,, -12700,1.094862,3.164291,,,,,,,,,,,,,, -12800,1.0414716,3.360832,,,,,,,,,,,,,, -12900,1.0103263,3.3317666,,,,,,,,,,,,,, -13000,0.9623115,5.130972,,,,,,,,,,,,,, -13100,0.96901804,3.5176349,,,,,,,,,,,,,, -13200,1.0720718,3.0507748,,,,,,,,,,,,,, -13300,1.277037,3.0713184,,,,,,,,,,,,,, -13308,,,0.5203320384025574,2.1063053607940674,0.4576999843120575,2.4315109252929688,50000.0,0.3541000187397003,3.0787501335144043,10000.0,5926.184532165527,6307.407933473587,5926.184532165527,380.0723383426666,0.4321315288543701,0.0 -13400,0.9527362,3.2455547,,,,,,,,,,,,,, -13500,0.7650937,5.66236,,,,,,,,,,,,,, -13600,0.93637717,3.9869747,,,,,,,,,,,,,, -13700,0.97871876,3.194746,,,,,,,,,,,,,, -13800,1.0617368,3.1435885,,,,,,,,,,,,,, -13900,1.0460776,3.0475402,,,,,,,,,,,,,, -14000,0.78325534,5.349856,,,,,,,,,,,,,, -14100,0.7903653,4.665129,,,,,,,,,,,,,, -14200,1.0942931,3.332133,,,,,,,,,,,,,, -14256,,,0.512402355670929,2.137631893157959,0.4780799746513366,2.320540189743042,50000.0,0.3675000071525574,2.9688971042633057,10000.0,6346.236738204956,6754.822709321976,6346.236738204956,407.3503110408783,0.4661853313446045,0.0 -14300,1.1926527,3.1718373,,,,,,,,,,,,,, -14400,1.0587542,3.1207914,,,,,,,,,,,,,, -14500,1.0595596,3.246661,,,,,,,,,,,,,, -14600,1.0852033,3.152687,,,,,,,,,,,,,, -14700,1.040826,4.3691297,,,,,,,,,,,,,, -14800,0.8053281,5.526828,,,,,,,,,,,,,, -14900,1.0081726,3.056338,,,,,,,,,,,,,, -15000,0.7685831,5.180373,,,,,,,,,,,,,, -15100,1.1010159,3.0324965,,,,,,,,,,,,,, -15200,0.96303946,4.163135,,,,,,,,,,,,,, -15204,,,0.5299218893051147,2.066849708557129,0.4887799918651581,2.277812957763672,50000.0,0.3815000057220459,2.909775972366333,10000.0,6766.212954044342,7201.414595603943,6766.212954044342,433.8850626945496,0.4962055683135986,0.0 -15300,1.084974,3.0386677,,,,,,,,,,,,,, -15400,1.2538843,3.0330915,,,,,,,,,,,,,, -15500,1.2952076,3.0996926,,,,,,,,,,,,,, -15600,0.84543043,5.571461,,,,,,,,,,,,,, -15700,1.1316288,3.0174131,,,,,,,,,,,,,, -15800,1.0846944,3.1015573,,,,,,,,,,,,,, -15900,1.0724283,3.0411057,,,,,,,,,,,,,, -16000,1.0019436,3.0180302,,,,,,,,,,,,,, -16100,1.3307109,3.1682315,,,,,,,,,,,,,, -16145,,,0.5477734208106995,1.9316043853759768,0.5002399682998657,2.178302049636841,50000.0,0.3874000310897827,2.8302128314971924,10000.0,7186.382284641266,7650.084501981735,7186.382284641266,462.300833940506,0.5308792591094971,0.0 -16200,1.2247455,3.0345962,,,,,,,,,,,,,, -16300,1.0005416,5.2357435,,,,,,,,,,,,,, -16400,1.2919567,2.9126348,,,,,,,,,,,,,, -16500,0.7848887,5.3170595,,,,,,,,,,,,,, -16600,1.1668373,3.0195923,,,,,,,,,,,,,, -16700,0.94952476,5.637696,,,,,,,,,,,,,, -16800,1.2520833,2.897734,,,,,,,,,,,,,, -16900,1.3040524,3.0818932,,,,,,,,,,,,,, -17000,1.1730584,2.858278,,,,,,,,,,,,,, -17090,,,0.5459765791893005,1.9445308446884155,0.5103999972343445,2.1310389041900635,50000.0,0.3924000263214111,2.787764549255371,10000.0,7606.686836242676,8104.333922147751,7606.686836242676,496.14386224746704,0.582331657409668,0.0 -17100,0.876763,5.356686,,,,,,,,,,,,,, -17200,0.90920407,5.2552414,,,,,,,,,,,,,, -17300,1.0265157,3.7141702,,,,,,,,,,,,,, -17400,1.0526942,3.1149807,,,,,,,,,,,,,, -17500,0.7755595,5.165757,,,,,,,,,,,,,, -17600,0.94587994,4.7843366,,,,,,,,,,,,,, -17700,1.0036613,2.9267309,,,,,,,,,,,,,, -17800,1.0938253,2.8391778,,,,,,,,,,,,,, -17900,1.0572166,2.896333,,,,,,,,,,,,,, -18000,0.9278469,4.85469,,,,,,,,,,,,,, -18040,,,0.5554491877555847,1.92086660861969,0.5125399827957153,2.1145436763763428,50000.0,0.4018000066280365,2.779582262039185,10000.0,8026.735034227371,8560.396936416626,8026.735034227371,532.077849149704,0.6126272678375244,0.0 -18100,0.75416595,5.574152,,,,,,,,,,,,,, -18200,0.9608986,3.72296,,,,,,,,,,,,,, -18300,1.3154349,2.9051523,,,,,,,,,,,,,, -18400,0.8850592,3.7713199,,,,,,,,,,,,,, -18500,1.0292493,2.8713555,,,,,,,,,,,,,, -18600,1.025787,3.190726,,,,,,,,,,,,,, -18700,0.9171682,5.285834,,,,,,,,,,,,,, -18800,0.8397758,4.0543337,,,,,,,,,,,,,, -18900,0.8859443,4.4330034,,,,,,,,,,,,,, -18989,,,0.5669531226158142,1.846354842185974,0.5207399725914001,2.081745147705078,50000.0,0.4084000289440155,2.721212148666382,10000.0,8446.924666404724,9017.6736536026,8446.924666404724,569.0852816104889,0.6402697563171387,0.0 -19000,1.0650494,2.8297205,,,,,,,,,,,,,, -19100,1.1277411,2.8869684,,,,,,,,,,,,,, -19200,1.1485301,2.8087065,,,,,,,,,,,,,, -19300,0.84501487,5.347901,,,,,,,,,,,,,, -19400,1.0917933,2.8430924,,,,,,,,,,,,,, -19500,1.0314656,2.867301,,,,,,,,,,,,,, -19600,1.0196452,3.4318442,,,,,,,,,,,,,, -19700,1.0943323,3.5619414,,,,,,,,,,,,,, -19800,1.0979269,2.9641292,,,,,,,,,,,,,, -19900,1.2959311,2.8223226,,,,,,,,,,,,,, -19938,,,0.5889257788658142,1.7574844360351562,0.5252599716186523,2.079762458801269,50000.0,0.40870001912117,2.7397501468658447,10000.0,8867.289863586426,9474.639935016632,8867.289863586426,605.6058855056763,0.6697046756744385,0.0 -20000,0.9707948,5.3924847,,,,,,,,,,,,,, -20100,1.0591521,4.2808137,,,,,,,,,,,,,, -20200,1.1067181,2.7921834,,,,,,,,,,,,,, -20300,0.86980087,3.9980416,,,,,,,,,,,,,, -20400,1.1403264,2.8796403,,,,,,,,,,,,,, -20500,0.8588523,4.0402746,,,,,,,,,,,,,, -20600,1.0344343,2.8891978,,,,,,,,,,,,,, -20700,1.1142803,2.8559926,,,,,,,,,,,,,, -20800,1.0065117,2.7773972,,,,,,,,,,,,,, -20886,,,0.5757421851158142,1.8028565645217896,0.5389999747276306,1.9935595989227293,50000.0,0.4234000146389007,2.6601674556732178,10000.0,9287.409376859663,9934.12448978424,9287.409376859663,644.8831579685211,0.7056596279144287,0.0 -20900,1.1760795,2.8509798,,,,,,,,,,,,,, -21000,1.1743437,3.3699427,,,,,,,,,,,,,, -21100,1.0617968,3.1109066,,,,,,,,,,,,,, -21200,0.98123145,4.0962963,,,,,,,,,,,,,, -21300,1.0892074,2.7287195,,,,,,,,,,,,,, -21400,1.1931978,2.8471985,,,,,,,,,,,,,, -21500,1.111048,2.8697033,,,,,,,,,,,,,, -21600,1.0190048,3.130452,,,,,,,,,,,,,, -21700,1.0403032,5.089532,,,,,,,,,,,,,, -21800,0.8959572,4.5312033,,,,,,,,,,,,,, -21830,,,0.5738281011581421,1.8270362615585327,0.5313599705696106,2.036630630493164,50000.0,0.415800005197525,2.6884138584136963,10000.0,9707.7606818676,10388.789666175842,9707.7606818676,679.1143245697021,0.7382750511169434,0.0 -21900,0.99019724,3.6840563,,,,,,,,,,,,,, -22000,1.0046295,5.19704,,,,,,,,,,,,,, -22100,1.02468,2.732255,,,,,,,,,,,,,, -22200,1.2640014,2.8162584,,,,,,,,,,,,,, -22300,1.0757496,2.7689216,,,,,,,,,,,,,, -22400,0.92732525,5.0319204,,,,,,,,,,,,,, -22500,1.1012579,3.1568286,,,,,,,,,,,,,, -22600,1.0811476,2.8982182,,,,,,,,,,,,,, -22700,0.92563623,4.818336,,,,,,,,,,,,,, -22766,,,0.5947265625,1.7336976528167725,0.5411800146102905,1.994389057159424,50000.0,0.4213000237941742,2.65813946723938,10000.0,10127.784230947496,10843.401911258698,10127.784230947496,713.6164371967316,0.7727911472320557,0.0 -22800,1.1654433,2.7877476,,,,,,,,,,,,,, -22900,0.87175447,4.041813,,,,,,,,,,,,,, -23000,0.99267477,4.2190924,,,,,,,,,,,,,, -23100,1.1639187,2.8992894,,,,,,,,,,,,,, -23200,0.9577197,3.6080596,,,,,,,,,,,,,, -23300,0.87634265,5.3689675,,,,,,,,,,,,,, -23400,1.0905371,2.9179535,,,,,,,,,,,,,, -23500,1.1942971,2.6899943,,,,,,,,,,,,,, -23600,1.0481297,2.7248514,,,,,,,,,,,,,, -23700,0.93487597,3.2340004,,,,,,,,,,,,,, -23713,,,0.5895312428474426,1.7517651319503784,0.550059974193573,1.948309540748596,50000.0,0.43340003490448,2.610472679138184,10000.0,10547.981046676636,11298.194388628006,10547.981046676636,748.1307110786438,0.8031463623046875,0.0 -23800,1.1721234,3.1678338,,,,,,,,,,,,,, -23900,1.1269987,2.729352,,,,,,,,,,,,,, -24000,1.1317143,2.9204879,,,,,,,,,,,,,, -24100,1.0935683,3.2790377,,,,,,,,,,,,,, -24200,1.1126883,2.7008834,,,,,,,,,,,,,, -24300,1.0461335,3.1825511,,,,,,,,,,,,,, -24400,1.3056184,2.7859955,,,,,,,,,,,,,, -24500,1.0486789,2.7613418,,,,,,,,,,,,,, -24600,1.1996813,3.124483,,,,,,,,,,,,,, -24659,,,0.5965625047683716,1.7040070295333862,0.5518199801445007,1.9154651165008545,50000.0,0.4394000172615051,2.570166826248169,10000.0,10968.0647251606,11753.13712143898,10968.0647251606,782.9098272323608,0.8321309089660645,0.0 -24700,1.1241804,2.8349314,,,,,,,,,,,,,, -24800,0.99840194,3.0600483,,,,,,,,,,,,,, -24900,1.1801875,2.7395837,,,,,,,,,,,,,, -25000,1.183853,2.8156695,,,,,,,,,,,,,, -25100,1.2109889,2.7216692,,,,,,,,,,,,,, -25200,1.0338484,2.8871522,,,,,,,,,,,,,, -25300,1.1475971,3.857323,,,,,,,,,,,,,, -25400,0.93008333,4.6547985,,,,,,,,,,,,,, -25500,1.1540282,2.7264202,,,,,,,,,,,,,, -25600,1.0702028,2.824822,,,,,,,,,,,,,, -25604,,,0.6005468368530273,1.726381778717041,0.550819993019104,1.963115096092224,50000.0,0.4369000196456909,2.60762095451355,10000.0,11388.192962408066,12204.864979743958,11388.192962408066,814.4276103973389,0.8629908561706543,0.0 -25700,1.0472362,2.7051823,,,,,,,,,,,,,, -25800,1.122987,2.9478683,,,,,,,,,,,,,, -25900,1.1698979,2.7771995,,,,,,,,,,,,,, -26000,1.0442599,2.7006857,,,,,,,,,,,,,, -26100,0.9447905,4.2858915,,,,,,,,,,,,,, -26200,1.1820135,2.6908607,,,,,,,,,,,,,, -26300,1.1960574,2.7157922,,,,,,,,,,,,,, -26400,0.9917801,3.8879046,,,,,,,,,,,,,, -26500,1.1316849,2.699134,,,,,,,,,,,,,, -26536,,,0.6315820217132568,1.5329514741897583,0.5600399971008301,1.8702781200408936,50000.0,0.434000015258789,2.540038824081421,10000.0,11807.219497203829,12659.916893959044,11807.219497203829,849.3285005092621,1.937572956085205,0.0 -26600,1.2488729,2.7438068,,,,,,,,,,,,,, -26700,1.053577,3.590233,,,,,,,,,,,,,, -26800,1.0618206,2.8873918,,,,,,,,,,,,,, -26900,1.240304,4.5153623,,,,,,,,,,,,,, -27000,1.2270052,2.5135045,,,,,,,,,,,,,, -27100,0.96670663,5.262973,,,,,,,,,,,,,, -27200,1.1055048,3.024253,,,,,,,,,,,,,, -27300,0.9584817,5.3365116,,,,,,,,,,,,,, -27400,0.84448105,5.294055,,,,,,,,,,,,,, -27477,,,0.6069921851158142,1.6717681884765625,0.5662199854850769,1.8602055311203003,50000.0,0.4456000328063965,2.5280001163482666,10000.0,12227.3085603714,13113.331893920898,12227.3085603714,882.5709488391876,1.9710197448730469,0.0 -27500,1.2043041,2.662931,,,,,,,,,,,,,, -27600,0.9682417,4.952247,,,,,,,,,,,,,, -27700,0.9698236,4.5870614,,,,,,,,,,,,,, -27800,0.8837422,4.316947,,,,,,,,,,,,,, -27900,1.2139462,2.6778011,,,,,,,,,,,,,, -28000,0.85844815,5.1653566,,,,,,,,,,,,,, -28100,1.0841497,3.1484184,,,,,,,,,,,,,, -28200,1.0906574,3.193666,,,,,,,,,,,,,, -28300,1.0184692,4.9611216,,,,,,,,,,,,,, -28400,1.1103323,2.6909115,,,,,,,,,,,,,, -28419,,,0.6087499856948853,1.6305036544799805,0.5629400014877319,1.8604578971862795,50000.0,0.4465000331401825,2.534619092941284,10000.0,12647.270294189451,13566.522515296936,12647.270294189451,915.7219965457916,1.998776912689209,0.0 -28500,1.3244009,2.6449373,,,,,,,,,,,,,, -28600,1.0808989,4.4343534,,,,,,,,,,,,,, -28700,0.9038788,4.0472913,,,,,,,,,,,,,, -28800,1.1992387,2.94855,,,,,,,,,,,,,, -28900,1.0430955,3.5312788,,,,,,,,,,,,,, -29000,1.122551,2.715174,,,,,,,,,,,,,, -29100,1.1368351,2.730129,,,,,,,,,,,,,, -29200,1.2308093,2.665674,,,,,,,,,,,,,, -29300,0.85131764,4.9449954,,,,,,,,,,,,,, -29361,,,0.6239452958106995,1.6023894548416138,0.5638799667358398,1.8720834255218504,50000.0,0.4513000249862671,2.5328025817871094,10000.0,13067.21803355217,14020.92025589943,13067.21803355217,950.0937445163728,2.026951313018799,0.0 -29400,1.1681995,2.6022978,,,,,,,,,,,,,, -29500,1.2104706,5.2776313,,,,,,,,,,,,,, -29600,1.2298907,2.7602684,,,,,,,,,,,,,, -29700,1.0030171,3.8964205,,,,,,,,,,,,,, -29800,0.8768483,4.5116186,,,,,,,,,,,,,, -29900,0.978335,4.4992795,,,,,,,,,,,,,, -30000,1.1733588,2.5381303,,,,,,,,,,,,,, -30100,1.0740272,2.6475055,,,,,,,,,,,,,, -30200,1.201119,2.5473976,,,,,,,,,,,,,, -30300,0.8397721,4.827507,,,,,,,,,,,,,, -30305,,,0.6150586009025574,1.6110780239105225,0.5711399912834167,1.8235976696014404,50000.0,0.4594000279903412,2.4598233699798584,10000.0,13487.580243349075,14475.940197706224,13487.580243349075,984.6698379516602,2.058462142944336,0.0 -30400,1.1679912,2.629218,,,,,,,,,,,,,, -30500,1.2078726,2.6043262,,,,,,,,,,,,,, -30600,1.0291318,5.1183057,,,,,,,,,,,,,, -30700,1.2066168,2.640425,,,,,,,,,,,,,, -30800,1.1235983,2.5991824,,,,,,,,,,,,,, -30900,1.1725771,2.6864853,,,,,,,,,,,,,, -31000,1.1760788,2.5122461,,,,,,,,,,,,,, -31100,1.1115581,2.634134,,,,,,,,,,,,,, -31200,0.95334893,4.381222,,,,,,,,,,,,,, -31249,,,0.6220507621765137,1.5800942182540894,0.5763999819755554,1.7837181091308594,50000.0,0.4600000083446502,2.4539012908935547,10000.0,13907.861764431,14930.126143455504,13907.861764431,1018.4912509918212,2.090975284576416,0.0 -31300,1.0128924,3.5146036,,,,,,,,,,,,,, -31400,0.97042966,4.5449247,,,,,,,,,,,,,, -31500,1.2124223,2.4947145,,,,,,,,,,,,,, -31600,1.1416895,3.338046,,,,,,,,,,,,,, -31700,1.1058484,2.5993125,,,,,,,,,,,,,, -31800,1.087701,3.1039603,,,,,,,,,,,,,, -31900,1.1747209,2.6725192,,,,,,,,,,,,,, -32000,1.3010502,2.6106844,,,,,,,,,,,,,, -32100,0.98552674,4.0783367,,,,,,,,,,,,,, -32191,,,0.6280859112739563,1.5354851484298706,0.5817399621009827,1.7738856077194214,50000.0,0.4607000350952148,2.4564712047576904,10000.0,14328.242881059648,15385.033915281296,14328.242881059648,1052.9403524398804,2.1188464164733887,0.0 -32200,1.352933,5.0438714,,,,,,,,,,,,,, -32300,1.000107,4.782097,,,,,,,,,,,,,, -32400,1.2628568,2.64958,,,,,,,,,,,,,, -32500,1.2091166,2.5784857,,,,,,,,,,,,,, -32600,1.1950039,2.9037192,,,,,,,,,,,,,, -32700,1.1798416,5.335132,,,,,,,,,,,,,, -32800,1.1303132,5.2042947,,,,,,,,,,,,,, -32900,1.2521664,2.6272974,,,,,,,,,,,,,, -33000,1.2193091,3.192775,,,,,,,,,,,,,, -33100,1.0660864,2.8247051,,,,,,,,,,,,,, -33134,,,0.6576367020606995,1.3888908624649048,0.5830000042915344,1.7361211776733398,50000.0,0.4657000303268432,2.420597076416016,10000.0,14748.527193307877,15840.515821695328,14748.527193307877,1088.0549836158752,2.1518115997314453,0.0 -33200,1.098524,3.2916856,,,,,,,,,,,,,, -33300,1.0967256,2.6727881,,,,,,,,,,,,,, -33400,1.23064,2.445777,,,,,,,,,,,,,, -33500,1.1772153,2.9862351,,,,,,,,,,,,,, -33600,1.0775641,2.6428113,,,,,,,,,,,,,, -33700,1.0131466,2.7761703,,,,,,,,,,,,,, -33800,1.0030257,2.955282,,,,,,,,,,,,,, -33900,1.2136539,2.6305306,,,,,,,,,,,,,, -34000,1.2877572,2.6266484,,,,,,,,,,,,,, -34074,,,0.6277148127555847,1.5349551439285278,0.5827400088310242,1.75555419921875,50000.0,0.4599000215530395,2.418983221054077,10000.0,15168.447052955627,16294.405767679214,15168.447052955627,1121.9403958320618,2.186378002166748,0.0 -34100,1.1251646,2.5802534,,,,,,,,,,,,,, -34200,1.0998988,2.6555479,,,,,,,,,,,,,, -34300,1.345469,2.6637232,,,,,,,,,,,,,, -34400,1.1684539,2.5333457,,,,,,,,,,,,,, -34500,1.2307014,2.638231,,,,,,,,,,,,,, -34600,0.895512,5.118949,,,,,,,,,,,,,, -34700,0.9839985,5.1853175,,,,,,,,,,,,,, -34800,1.1681991,2.6267543,,,,,,,,,,,,,, -34900,1.033515,4.3690205,,,,,,,,,,,,,, -35000,1.0660589,4.4289727,,,,,,,,,,,,,, -35015,,,0.6364648342132568,1.5358233451843262,0.5851799845695496,1.7673289775848389,50000.0,0.4638000130653381,2.43481183052063,10000.0,15588.879315137863,16748.10175061226,15588.879315137863,1155.1252624988556,2.215908527374268,0.0 -35100,1.1335148,2.4106653,,,,,,,,,,,,,, -35200,1.1774931,2.5641923,,,,,,,,,,,,,, -35300,1.268554,2.744765,,,,,,,,,,,,,, -35400,1.1464746,2.9764283,,,,,,,,,,,,,, -35500,1.2040322,2.482096,,,,,,,,,,,,,, -35600,1.1929222,2.6110482,,,,,,,,,,,,,, -35700,1.00559,3.6499941,,,,,,,,,,,,,, -35800,1.0472887,2.5731323,,,,,,,,,,,,,, -35900,1.0694044,4.8054953,,,,,,,,,,,,,, -35958,,,0.6478710770606995,1.4692319631576538,0.5889399647712708,1.74940288066864,50000.0,0.468500018119812,2.4204137325286865,10000.0,16008.960027456284,17201.34227848053,16008.960027456284,1188.193915605545,2.2564587593078613,0.0 -36000,1.1059293,2.4100165,,,,,,,,,,,,,, -36100,1.0431554,5.152686,,,,,,,,,,,,,, -36200,1.085899,3.123555,,,,,,,,,,,,,, -36300,1.2854779,2.3862064,,,,,,,,,,,,,, -36400,1.0451481,4.1716547,,,,,,,,,,,,,, -36500,0.983623,5.019138,,,,,,,,,,,,,, -36600,1.0405027,4.2241306,,,,,,,,,,,,,, -36700,1.009165,4.455483,,,,,,,,,,,,,, -36800,1.0831629,2.6822944,,,,,,,,,,,,,, -36900,1.2383033,4.2329435,,,,,,,,,,,,,, -36901,,,0.63427734375,1.4953609704971311,0.5922600030899048,1.7099379301071167,50000.0,0.4686000347137451,2.3968710899353027,10000.0,16429.58711719513,17654.276985645294,16429.58711719513,1220.422137260437,2.285893678665161,0.0 -37000,1.1600279,2.7883978,,,,,,,,,,,,,, -37100,1.1633189,2.5985544,,,,,,,,,,,,,, -37200,1.1007558,2.5320957,,,,,,,,,,,,,, -37300,1.068122,3.676807,,,,,,,,,,,,,, -37400,1.2704473,5.032807,,,,,,,,,,,,,, -37500,1.2613842,5.2224555,,,,,,,,,,,,,, -37600,1.1931384,2.4973333,,,,,,,,,,,,,, -37700,1.2013935,2.5686994,,,,,,,,,,,,,, -37800,1.2286174,2.4274223,,,,,,,,,,,,,, -37842,,,0.6421288847923279,1.446354627609253,0.5999400019645691,1.6714078187942505,50000.0,0.4766000211238861,2.344536066055298,10000.0,16849.518322229385,18108.95192480088,16849.518322229385,1255.0818610191343,2.3198788166046143,0.0 -37900,1.1706032,4.9992785,,,,,,,,,,,,,, -38000,1.1668261,2.4492002,,,,,,,,,,,,,, -38100,1.3416696,2.4993155,,,,,,,,,,,,,, -38200,1.2315801,2.478405,,,,,,,,,,,,,, -38300,1.1808865,2.5315828,,,,,,,,,,,,,, -38400,1.0204046,4.870604,,,,,,,,,,,,,, -38500,1.2348379,2.4652905,,,,,,,,,,,,,, -38600,1.0512772,4.18228,,,,,,,,,,,,,, -38700,1.3433219,2.5483906,,,,,,,,,,,,,, -38784,,,0.6454101204872131,1.46431565284729,0.593239963054657,1.7185540199279783,50000.0,0.4739000201225281,2.387554883956909,10000.0,17269.525953292847,18561.85310387612,17269.525953292847,1287.8897886276245,2.355440855026245,0.0 -38800,1.0779628,4.4504323,,,,,,,,,,,,,, -38900,1.1686697,2.4845243,,,,,,,,,,,,,, -39000,1.3685051,2.5560908,,,,,,,,,,,,,, -39100,1.1840782,2.4961903,,,,,,,,,,,,,, -39200,1.0014517,4.645769,,,,,,,,,,,,,, -39300,0.95897,3.6703224,,,,,,,,,,,,,, -39400,1.0370907,3.7373421,,,,,,,,,,,,,, -39500,1.1896309,2.4524505,,,,,,,,,,,,,, -39600,1.1406941,2.484288,,,,,,,,,,,,,, -39700,0.95776504,5.053974,,,,,,,,,,,,,, -39726,,,0.6728320121765137,1.3413584232330322,0.5954399704933167,1.6972488164901731,50000.0,0.4764000177383423,2.3578004837036133,10000.0,17689.873439073563,19015.806552886963,17689.873439073563,1321.4097304344175,2.3911402225494385,0.0 -39800,1.0477304,4.1669135,,,,,,,,,,,,,, -39900,1.181483,4.965661,,,,,,,,,,,,,, -40000,0.98871464,4.3783674,,,,,,,,,,,,,, -40100,1.1016309,2.521298,,,,,,,,,,,,,, -40200,1.0252416,2.888705,,,,,,,,,,,,,, -40300,1.2693539,2.591473,,,,,,,,,,,,,, -40400,1.1515176,2.73325,,,,,,,,,,,,,, -40500,1.2559656,5.102531,,,,,,,,,,,,,, -40600,1.27724,2.4104095,,,,,,,,,,,,,, -40667,,,0.6483203172683716,1.4345197677612305,0.6010800004005432,1.6640561819076538,50000.0,0.4778000116348266,2.346426010131836,10000.0,18110.056242227554,19469.427932024,18110.056242227554,1354.7683067321775,2.4212679862976074,0.0 -40700,1.1419747,3.0097609,,,,,,,,,,,,,, -40800,0.9732138,3.6603682,,,,,,,,,,,,,, -40900,0.92279124,3.6907067,,,,,,,,,,,,,, -41000,1.3098086,2.4957936,,,,,,,,,,,,,, -41100,1.073932,4.6275845,,,,,,,,,,,,,, -41200,1.3146069,2.5130098,,,,,,,,,,,,,, -41300,1.0532693,4.9132686,,,,,,,,,,,,,, -41400,1.1441692,2.6079402,,,,,,,,,,,,,, -41500,1.2092804,2.4563375,,,,,,,,,,,,,, -41600,1.0735377,3.0799155,,,,,,,,,,,,,, -41610,,,0.6526757478713989,1.4107017517089844,0.6041799783706665,1.6475238800048828,50000.0,0.4812000095844269,2.312072277069092,10000.0,18530.41775512696,19923.3668589592,18530.41775512696,1388.2648763656616,2.4519410133361816,0.0 -41700,1.304446,2.5054226,,,,,,,,,,,,,, -41800,1.1408331,2.8199093,,,,,,,,,,,,,, -41900,1.1028377,4.76979,,,,,,,,,,,,,, -42000,1.1202487,2.7680223,,,,,,,,,,,,,, -42100,1.2567462,2.4869862,,,,,,,,,,,,,, -42200,1.1084057,2.6624296,,,,,,,,,,,,,, -42300,1.3010774,2.5666986,,,,,,,,,,,,,, -42400,1.3592293,5.0843983,,,,,,,,,,,,,, -42500,1.0009501,4.3093834,,,,,,,,,,,,,, -42552,,,0.6664062142372131,1.3672758340835571,0.6040399670600891,1.6572825908660889,50000.0,0.4817000329494476,2.320866584777832,10000.0,18950.665768146515,20376.23017168045,18950.665768146515,1420.8001787662506,2.481510877609253,0.0 -42600,1.2896577,2.4033618,,,,,,,,,,,,,, -42700,1.019552,3.0396461,,,,,,,,,,,,,, -42800,1.190691,2.4027534,,,,,,,,,,,,,, -42900,1.0578688,4.46898,,,,,,,,,,,,,, -43000,0.9978611,4.835045,,,,,,,,,,,,,, -43100,1.1419476,2.534042,,,,,,,,,,,,,, -43200,1.277061,2.4236424,,,,,,,,,,,,,, -43300,1.5168492,2.7574544,,,,,,,,,,,,,, -43400,1.2596955,2.6103485,,,,,,,,,,,,,, -43492,,,0.6471484303474426,1.4733134508132937,0.6031799912452698,1.6737443208694458,50000.0,0.482200026512146,2.3180975914001465,10000.0,19370.83881545067,20830.934995889664,19370.83881545067,1455.2440390586853,2.5197227001190186,0.0 -43500,1.2395407,2.6196313,,,,,,,,,,,,,, -43600,1.1094916,3.3555615,,,,,,,,,,,,,, -43700,1.1764739,2.6009364,,,,,,,,,,,,,, -43800,1.2204051,2.4404693,,,,,,,,,,,,,, -43900,1.0774468,4.916135,,,,,,,,,,,,,, -44000,1.0275375,4.25287,,,,,,,,,,,,,, -44100,1.0290893,5.022512,,,,,,,,,,,,,, -44200,1.309453,2.4564683,,,,,,,,,,,,,, -44300,1.0322589,3.2139766,,,,,,,,,,,,,, -44400,1.0163072,4.8777585,,,,,,,,,,,,,, -44432,,,0.6621484160423279,1.382601618766785,0.6111400127410889,1.6152799129486084,50000.0,0.4879000186920166,2.2897379398345947,10000.0,19790.954607486725,21285.523493766785,19790.954607486725,1489.634738445282,2.552811861038208,0.0 -44500,1.2633946,2.5654583,,,,,,,,,,,,,, -44600,1.2159681,2.446589,,,,,,,,,,,,,, -44700,1.0568007,3.6741774,,,,,,,,,,,,,, -44800,1.1776667,2.3075109,,,,,,,,,,,,,, -44900,1.2534746,2.4926984,,,,,,,,,,,,,, -45000,1.2499859,5.0767517,,,,,,,,,,,,,, -45100,1.0126314,2.9472795,,,,,,,,,,,,,, -45200,1.2102244,2.6389008,,,,,,,,,,,,,, -45300,1.1850342,2.6375585,,,,,,,,,,,,,, -45376,,,0.6607617139816284,1.3963559865951538,0.6053199768066406,1.6574883460998535,50000.0,0.4842000305652618,2.310580015182495,10000.0,20211.304652929302,21739.273277044296,20211.304652929302,1522.9538979530334,2.5838775634765625,0.0 -45400,1.2230749,2.4318604,,,,,,,,,,,,,, -45500,1.0734516,3.4590015,,,,,,,,,,,,,, -45600,1.0760534,3.81838,,,,,,,,,,,,,, -45700,1.1891406,2.5941339,,,,,,,,,,,,,, -45800,1.1443621,4.200316,,,,,,,,,,,,,, -45900,1.0903212,4.8982553,,,,,,,,,,,,,, -46000,1.3509846,4.359636,,,,,,,,,,,,,, -46100,1.2988149,2.4238696,,,,,,,,,,,,,, -46200,1.3317167,2.7814515,,,,,,,,,,,,,, -46300,1.1567907,2.6019351,,,,,,,,,,,,,, -46316,,,0.6971093416213989,1.220287799835205,0.6148999929428101,1.5943645238876345,50000.0,0.496800035238266,2.2552144527435303,10000.0,20631.39994740486,22195.084174633022,20631.39994740486,1558.584661245346,2.618739366531372,0.0 -46400,1.2018567,3.4436057,,,,,,,,,,,,,, -46500,1.16732,3.7226906,,,,,,,,,,,,,, -46600,1.2219408,2.5100706,,,,,,,,,,,,,, -46700,1.2503071,2.275514,,,,,,,,,,,,,, -46800,1.1464214,2.8527327,,,,,,,,,,,,,, -46900,1.1636444,2.3658345,,,,,,,,,,,,,, -47000,1.2450271,2.4782996,,,,,,,,,,,,,, -47100,1.1652677,2.7140377,,,,,,,,,,,,,, -47200,1.1001247,3.7804916,,,,,,,,,,,,,, -47252,,,0.6590625047683716,1.4004347324371338,0.6109799742698669,1.6322163343429563,50000.0,0.4879000186920166,2.2941040992736816,10000.0,21051.45190000534,22648.6025223732,21051.45190000534,1591.965892314911,2.6543779373168945,0.0 -47300,1.2272421,2.3559248,,,,,,,,,,,,,, -47400,1.1512204,2.4462938,,,,,,,,,,,,,, -47500,1.2492495,2.8410964,,,,,,,,,,,,,, -47600,1.1049813,4.7245216,,,,,,,,,,,,,, -47700,1.1373323,2.3694718,,,,,,,,,,,,,, -47800,1.1932064,2.6100743,,,,,,,,,,,,,, -47900,1.0914843,3.5826511,,,,,,,,,,,,,, -48000,1.2094473,2.515638,,,,,,,,,,,,,, -48100,1.3508095,2.4672928,,,,,,,,,,,,,, -48192,,,0.6631835699081421,1.3827235698699951,0.6143999695777893,1.61689555644989,50000.0,0.4850000143051147,2.307790756225586,10000.0,21471.53701448441,23102.67118740081,21471.53701448441,1625.8616099357605,2.691848039627075,0.0 -48200,1.0237219,4.312592,,,,,,,,,,,,,, -48300,1.3040456,4.6343007,,,,,,,,,,,,,, -48400,1.2038971,2.5811357,,,,,,,,,,,,,, -48500,1.1664407,3.0037494,,,,,,,,,,,,,, -48600,1.3758152,2.4793367,,,,,,,,,,,,,, -48700,1.1995158,2.841547,,,,,,,,,,,,,, -48800,1.0165734,4.9678497,,,,,,,,,,,,,, -48900,1.2242563,2.3409915,,,,,,,,,,,,,, -49000,1.0842606,4.74067,,,,,,,,,,,,,, -49100,1.2690574,2.5190587,,,,,,,,,,,,,, -49133,,,0.6769726276397705,1.3200808763504028,0.6147800087928772,1.6114550828933716,50000.0,0.490200012922287,2.286827802658081,10000.0,21891.7851896286,23558.934968948364,21891.7851896286,1661.7927005290985,2.726196765899658,0.0 -49200,1.1080663,4.6167097,,,,,,,,,,,,,, -49300,1.1822107,2.5709867,,,,,,,,,,,,,, -49400,1.2200154,2.4501743,,,,,,,,,,,,,, -49500,1.1966999,2.6158123,,,,,,,,,,,,,, -49600,1.4233031,2.5120928,,,,,,,,,,,,,, -49700,1.0864987,2.610198,,,,,,,,,,,,,, -49800,1.0491147,3.460217,,,,,,,,,,,,,, -49900,1.153302,4.0851746,,,,,,,,,,,,,, -50000,1.2370147,2.3714557,,,,,,,,,,,,,, -50074,,,0.6643164157867432,1.361477971076965,0.6191999912261963,1.578934907913208,50000.0,0.4976000189781189,2.2456860542297363,10000.0,22312.02948999405,24013.438791275024,22312.02948999405,1695.9685862064362,2.759733200073242,0.0 -50100,1.0547513,4.8828044,,,,,,,,,,,,,, -50200,1.0656585,4.934848,,,,,,,,,,,,,, -50300,0.9925436,3.9930716,,,,,,,,,,,,,, -50400,1.3371723,2.4898648,,,,,,,,,,,,,, -50500,1.3080117,5.0300274,,,,,,,,,,,,,, -50600,1.3557683,2.494644,,,,,,,,,,,,,, -50700,1.1918126,2.6600056,,,,,,,,,,,,,, -50800,1.2442615,2.3303523,,,,,,,,,,,,,, -50900,0.9773535,4.5000815,,,,,,,,,,,,,, -51000,1.1416252,4.0504146,,,,,,,,,,,,,, -51013,,,0.6668554544448853,1.3606523275375366,0.6202799677848816,1.5822699069976809,50000.0,0.4964000284671783,2.2490456104278564,10000.0,22732.127825021744,24466.235282182693,22732.127825021744,1728.5848369598389,2.7923452854156494,0.0 -51100,1.0774131,4.9953814,,,,,,,,,,,,,, -51200,1.202219,2.5656118,,,,,,,,,,,,,, -51300,0.9698232,4.6972237,,,,,,,,,,,,,, -51400,1.1153402,2.6481552,,,,,,,,,,,,,, -51500,1.2870092,2.4221017,,,,,,,,,,,,,, -51600,1.1063609,4.709109,,,,,,,,,,,,,, -51700,1.1670833,2.828135,,,,,,,,,,,,,, -51800,1.03865,3.0029871,,,,,,,,,,,,,, -51900,1.0917534,2.683463,,,,,,,,,,,,,, -51952,,,0.6714843511581421,1.339386224746704,0.6193999648094177,1.5890146493911743,50000.0,0.4962000250816345,2.252509593963623,10000.0,23152.065212011337,24920.99333834648,23152.065212011337,1763.3116884231567,2.836010217666626,0.0 -52000,1.252066,2.4645905,,,,,,,,,,,,,, -52100,1.3108264,2.3694696,,,,,,,,,,,,,, -52200,1.319441,2.3260186,,,,,,,,,,,,,, -52300,1.3227465,2.3764155,,,,,,,,,,,,,, -52400,1.2587447,2.2705054,,,,,,,,,,,,,, -52500,1.2833033,2.3903177,,,,,,,,,,,,,, -52600,1.2525089,2.2398367,,,,,,,,,,,,,, -52700,1.3780223,2.4185367,,,,,,,,,,,,,, -52800,1.2482066,2.386038,,,,,,,,,,,,,, -52896,,,0.6980078220367432,1.2068614959716797,0.6241999864578247,1.5524660348892212,50000.0,0.5056000351905823,2.2364532947540283,10000.0,23572.228355884552,25374.639178276066,23572.228355884552,1796.709459066391,2.8703958988189697,0.0 -52900,1.5441388,2.30464,,,,,,,,,,,,,, -53000,1.2813777,2.308567,,,,,,,,,,,,,, -53100,1.1757082,2.7356668,,,,,,,,,,,,,, -53200,1.2118801,3.4427435,,,,,,,,,,,,,, -53300,1.2980824,2.33638,,,,,,,,,,,,,, -53400,1.1720761,2.3330255,,,,,,,,,,,,,, -53500,1.3598549,2.470764,,,,,,,,,,,,,, -53600,1.3687183,2.3299932,,,,,,,,,,,,,, -53700,1.2279532,2.6116185,,,,,,,,,,,,,, -53800,1.2256317,4.4922724,,,,,,,,,,,,,, -53833,,,0.6648241877555847,1.4142321348190308,0.6174799799919128,1.6313257217407229,50000.0,0.4935000240802765,2.286892890930176,10000.0,23992.23517799377,25829.765172481537,23992.23517799377,1831.7415869235992,2.908019781112671,0.0 -53900,1.1648608,2.65153,,,,,,,,,,,,,, -54000,1.2067521,4.71054,,,,,,,,,,,,,, -54100,1.1255207,3.172582,,,,,,,,,,,,,, -54200,1.1783681,2.7105498,,,,,,,,,,,,,, -54300,1.2739352,2.2921042,,,,,,,,,,,,,, -54400,1.2703781,2.479144,,,,,,,,,,,,,, -54500,1.4145548,2.4077234,,,,,,,,,,,,,, -54600,1.0144997,4.7961235,,,,,,,,,,,,,, -54700,1.330967,2.674904,,,,,,,,,,,,,, -54772,,,0.6703906059265137,1.3580645322799685,0.6202799677848816,1.596779227256775,50000.0,0.4990000128746032,2.2520761489868164,10000.0,24412.294757843018,26283.889540433884,24412.294757843018,1865.720906496048,2.943570375442505,0.0 -54800,1.0918455,4.944497,,,,,,,,,,,,,, -54900,1.0868343,2.9010072,,,,,,,,,,,,,, -55000,1.3951454,2.3483553,,,,,,,,,,,,,, -55100,1.1981514,2.5896816,,,,,,,,,,,,,, -55200,1.228784,2.3270493,,,,,,,,,,,,,, -55300,1.1232839,4.1284056,,,,,,,,,,,,,, -55400,1.2487149,2.3562028,,,,,,,,,,,,,, -55500,1.053376,4.1705856,,,,,,,,,,,,,, -55600,1.1641445,2.822363,,,,,,,,,,,,,, -55700,1.3386253,2.3244328,,,,,,,,,,,,,, -55711,,,0.6917187571525574,1.2428334951400757,0.6261199712753296,1.5562325716018677,50000.0,0.5027000308036804,2.2339158058166504,10000.0,24832.553787231445,26739.5900247097,24832.553787231445,1901.078797578812,2.9773149490356445,0.0 -55800,1.2311212,2.5171518,,,,,,,,,,,,,, -55900,1.1385937,3.2499642,,,,,,,,,,,,,, -56000,1.2081504,2.8140044,,,,,,,,,,,,,, -56100,1.0612298,3.3179045,,,,,,,,,,,,,, -56200,1.1639783,4.1978874,,,,,,,,,,,,,, -56300,1.2657075,2.392644,,,,,,,,,,,,,, -56400,1.386801,2.3348012,,,,,,,,,,,,,, -56500,1.2747512,2.2880707,,,,,,,,,,,,,, -56600,1.2035094,2.4762726,,,,,,,,,,,,,, -56652,,,0.6736913919448853,1.337292194366455,0.6258000135421753,1.5741112232208252,50000.0,0.501300036907196,2.226841449737549,10000.0,25252.7198586464,27196.0502679348,25252.7198586464,1937.29074048996,3.00985050201416,0.0 -56700,1.0305549,4.8350396,,,,,,,,,,,,,, -56800,1.1800237,2.681948,,,,,,,,,,,,,, -56900,1.2667779,2.2767885,,,,,,,,,,,,,, -57000,1.3152215,2.288693,,,,,,,,,,,,,, -57100,1.1302087,3.0941596,,,,,,,,,,,,,, -57200,1.2181088,2.2965636,,,,,,,,,,,,,, -57300,1.2657735,2.3102918,,,,,,,,,,,,,, -57400,1.1584744,3.8445044,,,,,,,,,,,,,, -57500,1.1416522,4.059325,,,,,,,,,,,,,, -57593,,,0.6827539205551147,1.2796573638916016,0.629040002822876,1.528598427772522,50000.0,0.5081000328063965,2.189667224884033,10000.0,25672.69250059128,27648.986284017563,25672.69250059128,1970.1550877094269,3.059091567993164,0.0 -57600,1.359938,2.349773,,,,,,,,,,,,,, -57700,1.2289841,4.3364334,,,,,,,,,,,,,, -57800,1.157958,2.6061027,,,,,,,,,,,,,, -57900,1.2698531,2.8541565,,,,,,,,,,,,,, -58000,1.2934575,2.3624063,,,,,,,,,,,,,, -58100,1.2369188,2.3444655,,,,,,,,,,,,,, -58200,1.1472535,2.683591,,,,,,,,,,,,,, -58300,1.0996217,3.7288764,,,,,,,,,,,,,, -58400,1.2357514,2.4070582,,,,,,,,,,,,,, -58500,1.2603651,2.9830031,,,,,,,,,,,,,, -58530,,,0.6944921612739563,1.2365962266921997,0.6259999871253967,1.5299229621887207,50000.0,0.5099000334739685,2.191647529602051,10000.0,26092.81829667092,28103.711234807968,26092.81829667092,2004.6654317379,3.09879994392395,0.0 -58600,1.27683,4.698041,,,,,,,,,,,,,, -58700,1.1666119,2.4796956,,,,,,,,,,,,,, -58800,1.1459605,2.3932831,,,,,,,,,,,,,, -58900,1.304682,2.3903213,,,,,,,,,,,,,, -59000,1.2258536,3.8531845,,,,,,,,,,,,,, -59100,1.0978096,3.4257536,,,,,,,,,,,,,, -59200,1.1279428,3.8582864,,,,,,,,,,,,,, -59300,1.3199041,2.277669,,,,,,,,,,,,,, -59400,1.0158259,3.9499614,,,,,,,,,,,,,, -59471,,,0.7023242115974426,1.2232853174209597,0.637499988079071,1.5057755708694458,50000.0,0.5143000483512878,2.179502487182617,10000.0,26512.790625810623,28558.48644924164,26512.790625810623,2039.381314992905,3.1359636783599854,0.0 -59500,1.2450383,2.6752875,,,,,,,,,,,,,, -59600,1.3761438,2.1962018,,,,,,,,,,,,,, -59700,1.2047681,2.3907633,,,,,,,,,,,,,, -59800,1.19615,2.3150287,,,,,,,,,,,,,, -59900,1.3077639,2.3094642,,,,,,,,,,,,,, -60000,1.3523391,2.2762194,,,,,,,,,,,,,, -60100,1.1729072,3.0480845,,,,,,,,,,,,,, -60200,1.1504959,3.084028,,,,,,,,,,,,,, -60300,1.1563834,2.1958876,,,,,,,,,,,,,, -60400,1.1288987,4.989334,,,,,,,,,,,,,, -60412,,,0.682324230670929,1.2863421440124512,0.6322399973869324,1.51683247089386,50000.0,0.510200023651123,2.1724071502685547,10000.0,26932.726552009583,29011.95117688179,26932.726552009583,2072.8224205970764,3.1741859912872314,0.0 -60500,1.2993698,2.1102486,,,,,,,,,,,,,, -60600,1.1230376,3.6522958,,,,,,,,,,,,,, -60700,1.1267658,4.9290814,,,,,,,,,,,,,, -60800,1.1173539,3.899909,,,,,,,,,,,,,, -60900,1.2893744,2.2939577,,,,,,,,,,,,,, -61000,1.1461283,3.5838308,,,,,,,,,,,,,, -61100,1.2639308,2.379707,,,,,,,,,,,,,, -61200,1.0882877,3.018608,,,,,,,,,,,,,, -61300,1.253509,3.1228197,,,,,,,,,,,,,, -61352,,,0.6830468773841858,1.3047661781311035,0.6285399794578552,1.5588743686676023,50000.0,0.5085000395774841,2.223173856735229,10000.0,27352.794692516327,29467.8389453888,27352.794692516327,2108.5590019226074,3.208047389984131,0.0 -61400,1.0904073,3.2793303,,,,,,,,,,,,,, -61500,1.3558598,2.4100602,,,,,,,,,,,,,, -61600,1.165618,3.2571568,,,,,,,,,,,,,, -61700,1.302707,2.22743,,,,,,,,,,,,,, -61800,1.2089988,2.7532144,,,,,,,,,,,,,, -61900,1.3188748,2.601691,,,,,,,,,,,,,, -62000,1.2309318,3.2000873,,,,,,,,,,,,,, -62100,1.3640869,2.2532697,,,,,,,,,,,,,, -62200,1.2780781,2.3367684,,,,,,,,,,,,,, -62292,,,0.703808605670929,1.203053593635559,0.6331799626350403,1.529718995094299,50000.0,0.5073000192642212,2.181467771530152,10000.0,27772.926307678223,29923.00162839889,27772.926307678223,2143.499714612961,3.2494382858276367,0.0 -62300,1.161823,2.6597936,,,,,,,,,,,,,, -62400,1.23066,2.1165648,,,,,,,,,,,,,, -62500,1.0609596,4.772629,,,,,,,,,,,,,, -62600,1.287984,2.385432,,,,,,,,,,,,,, -62700,1.2451528,2.5865948,,,,,,,,,,,,,, -62800,1.310676,2.2187953,,,,,,,,,,,,,, -62900,1.3156896,2.1465478,,,,,,,,,,,,,, -63000,1.0599866,4.5033627,,,,,,,,,,,,,, -63100,1.1197122,3.2773044,,,,,,,,,,,,,, -63200,1.3259728,2.319866,,,,,,,,,,,,,, -63229,,,0.6870507597923279,1.2610909938812256,0.6409800052642822,1.4879895448684692,50000.0,0.5142000317573547,2.15094256401062,10000.0,28192.87024521828,30377.92234969139,28192.87024521828,2178.395395040512,3.282611131668091,0.0 -63300,1.3053119,2.4525023,,,,,,,,,,,,,, -63400,1.2239871,2.362348,,,,,,,,,,,,,, -63500,1.276645,2.118396,,,,,,,,,,,,,, -63600,1.3393807,2.3889513,,,,,,,,,,,,,, -63700,1.3365492,2.2223701,,,,,,,,,,,,,, -63800,1.3141613,2.2122407,,,,,,,,,,,,,, -63900,1.2573851,2.4611652,,,,,,,,,,,,,, -64000,1.4073006,2.3767889,,,,,,,,,,,,,, -64100,1.4442685,2.1976507,,,,,,,,,,,,,, -64170,,,0.6971484422683716,1.215964436531067,0.6412799954414368,1.470037579536438,50000.0,0.5209000110626221,2.133341789245605,10000.0,28613.161401748657,30833.526191473007,28613.161401748657,2213.624122619629,3.317622423171997,0.0 -64200,1.2686129,2.2288673,,,,,,,,,,,,,, -64300,1.0724808,3.980586,,,,,,,,,,,,,, -64400,1.3574404,2.2873378,,,,,,,,,,,,,, -64500,1.1882912,4.4801807,,,,,,,,,,,,,, -64600,1.1411276,2.797097,,,,,,,,,,,,,, -64700,1.227427,4.829632,,,,,,,,,,,,,, -64800,1.1344438,2.5675774,,,,,,,,,,,,,, -64900,1.3837434,2.2909777,,,,,,,,,,,,,, -65000,1.243666,3.1046414,,,,,,,,,,,,,, -65100,1.320261,2.118996,,,,,,,,,,,,,, -65111,,,0.7034765481948853,1.2014892101287842,0.6446999907493591,1.480310559272766,50000.0,0.51910001039505,2.131485939025879,10000.0,29033.25125908852,31287.50227761269,29033.25125908852,2247.422921895981,3.3557910919189453,0.0 -65200,1.2463481,2.4220965,,,,,,,,,,,,,, -65300,1.3228933,2.2547252,,,,,,,,,,,,,, -65400,1.272108,2.1910253,,,,,,,,,,,,,, -65500,1.2006453,2.1991394,,,,,,,,,,,,,, -65600,1.1826189,2.651113,,,,,,,,,,,,,, -65700,1.3542532,4.889928,,,,,,,,,,,,,, -65800,1.2152535,2.2025206,,,,,,,,,,,,,, -65900,1.266507,2.5097146,,,,,,,,,,,,,, -66000,1.4332106,2.239565,,,,,,,,,,,,,, -66053,,,0.6955664157867432,1.2484641075134275,0.6377800107002258,1.503852128982544,50000.0,0.5210000276565552,2.150148630142212,10000.0,29453.42460083961,31743.060341358185,29453.42460083961,2282.7217609882355,3.39251971244812,0.0 -66100,1.3919082,2.092138,,,,,,,,,,,,,, -66200,1.2888598,2.1713157,,,,,,,,,,,,,, -66300,1.2526768,2.3199112,,,,,,,,,,,,,, -66400,1.1897784,3.3398821,,,,,,,,,,,,,, -66500,1.2937164,2.2018344,,,,,,,,,,,,,, -66600,1.3071563,2.2040985,,,,,,,,,,,,,, -66700,1.458158,2.643737,,,,,,,,,,,,,, -66800,1.2986015,2.1831179,,,,,,,,,,,,,, -66900,1.1087207,2.7528424,,,,,,,,,,,,,, -66995,,,0.6943359375,1.2172327041625977,0.6418399810791016,1.4598710536956787,50000.0,0.5213000178337097,2.1135141849517822,10000.0,29873.4261200428,32196.684210062027,29873.4261200428,2316.2620203495026,3.425584554672241,0.0 -67000,1.198053,2.2169688,,,,,,,,,,,,,, -67100,1.1458625,4.6489697,,,,,,,,,,,,,, -67200,1.3012149,2.7987547,,,,,,,,,,,,,, -67300,1.3133982,2.2503185,,,,,,,,,,,,,, -67400,1.3215079,2.198235,,,,,,,,,,,,,, -67500,1.1796129,3.206238,,,,,,,,,,,,,, -67600,1.1555688,4.302013,,,,,,,,,,,,,, -67700,1.3221141,2.1463356,,,,,,,,,,,,,, -67800,1.0731947,4.749666,,,,,,,,,,,,,, -67900,1.2857565,2.1380985,,,,,,,,,,,,,, -67937,,,0.7025390267372131,1.1881595849990845,0.6430400013923645,1.46323823928833,50000.0,0.5241000056266785,2.116852045059204,10000.0,30293.34878706932,32652.327559947968,30293.34878706932,2351.888193130493,3.4705111980438232,0.0 -68000,1.3756593,2.432414,,,,,,,,,,,,,, -68100,1.0921283,4.7422,,,,,,,,,,,,,, -68200,1.3158171,2.450038,,,,,,,,,,,,,, -68300,1.385935,2.6054304,,,,,,,,,,,,,, -68400,1.2457268,3.3604279,,,,,,,,,,,,,, -68500,1.2262348,2.661912,,,,,,,,,,,,,, -68600,1.3134842,2.0203133,,,,,,,,,,,,,, -68700,1.1893009,4.778285,,,,,,,,,,,,,, -68800,1.2153724,3.699633,,,,,,,,,,,,,, -68878,,,0.7176562547683716,1.159710168838501,0.6420599818229675,1.499109983444214,50000.0,0.5215000510215759,2.154394149780273,10000.0,30713.28437423706,33106.69675326347,30713.28437423706,2386.2355921268463,3.507035970687866,0.0 -68900,1.4018074,2.378286,,,,,,,,,,,,,, -69000,1.1914502,4.765442,,,,,,,,,,,,,, -69100,1.1793317,2.529656,,,,,,,,,,,,,, -69200,1.347818,2.1302192,,,,,,,,,,,,,, -69300,1.4076312,2.4403489,,,,,,,,,,,,,, -69400,1.1555578,2.5237308,,,,,,,,,,,,,, -69500,1.1297066,3.7206576,,,,,,,,,,,,,, -69600,1.3234985,2.6670494,,,,,,,,,,,,,, -69700,1.3142525,2.1784034,,,,,,,,,,,,,, -69800,1.2889197,4.2335277,,,,,,,,,,,,,, -69815,,,0.6985741853713989,1.2329505681991575,0.6497399806976318,1.4623702764511108,50000.0,0.5277000069618225,2.1101725101470947,10000.0,31133.5650537014,33563.08368849754,31133.5650537014,2422.250189781189,3.54935359954834,0.0 -69900,1.1146495,2.905054,,,,,,,,,,,,,, -70000,1.337731,2.382214,,,,,,,,,,,,,, -70100,1.3326288,2.3881054,,,,,,,,,,,,,, -70200,1.3445884,4.6699886,,,,,,,,,,,,,, -70300,1.2497253,3.1696918,,,,,,,,,,,,,, -70400,1.3047585,2.169954,,,,,,,,,,,,,, -70500,1.0691667,4.742501,,,,,,,,,,,,,, -70600,1.2873785,4.8431587,,,,,,,,,,,,,, -70700,1.3300298,2.5026531,,,,,,,,,,,,,, -70754,,,0.7025781273841858,1.182090163230896,0.6479200124740601,1.435321807861328,50000.0,0.5285000205039978,2.077733039855957,10000.0,31553.61950206757,34018.04113793373,31553.61950206757,2457.065773963928,3.5868256092071533,0.0 -70800,1.1958615,2.5078259,,,,,,,,,,,,,, -70900,1.3755543,2.3386302,,,,,,,,,,,,,, -71000,1.2507498,3.629373,,,,,,,,,,,,,, -71100,1.1749729,2.463243,,,,,,,,,,,,,, -71200,1.416923,2.3883882,,,,,,,,,,,,,, -71300,1.1143214,4.58609,,,,,,,,,,,,,, -71400,1.1405443,3.0131235,,,,,,,,,,,,,, -71500,1.1685705,2.9815197,,,,,,,,,,,,,, -71600,1.3500965,2.31677,,,,,,,,,,,,,, -71693,,,0.7128710746765137,1.1547907590866089,0.6517399549484253,1.451304316520691,50000.0,0.5290000438690186,2.096698760986328,10000.0,31973.56262850761,34472.430872917175,31973.56262850761,2491.428635120392,3.620942354202272,0.0 -71700,1.1302543,3.75,,,,,,,,,,,,,, -71800,1.2929534,2.1610613,,,,,,,,,,,,,, -71900,1.3004391,2.201583,,,,,,,,,,,,,, -72000,1.1345282,2.8742948,,,,,,,,,,,,,, -72100,1.3437703,2.1609123,,,,,,,,,,,,,, -72200,1.2309055,3.1309826,,,,,,,,,,,,,, -72300,1.3654217,2.2419612,,,,,,,,,,,,,, -72400,1.36782,2.965707,,,,,,,,,,,,,, -72500,1.3029265,2.1460156,,,,,,,,,,,,,, -72600,1.226556,3.1228075,,,,,,,,,,,,,, -72634,,,0.7041015625,1.1773505210876465,0.6512599587440491,1.4189000129699707,50000.0,0.5267000198364258,2.096280097961426,10000.0,32393.721923351288,34927.862758398056,32393.721923351288,2526.612436294556,3.659743547439575,0.0 -72700,1.1611305,3.484056,,,,,,,,,,,,,, -72800,1.3441817,4.3855586,,,,,,,,,,,,,, -72900,1.3560289,2.3208861,,,,,,,,,,,,,, -73000,1.3186812,2.310885,,,,,,,,,,,,,, -73100,1.1917841,2.1339412,,,,,,,,,,,,,, -73200,1.2322767,2.8819084,,,,,,,,,,,,,, -73300,1.3655242,2.2577672,,,,,,,,,,,,,, -73400,1.2017249,2.620122,,,,,,,,,,,,,, -73500,1.3383378,2.101736,,,,,,,,,,,,,, -73574,,,0.7027539014816284,1.1971248388290403,0.6504200100898743,1.4388371706008911,50000.0,0.5236999988555908,2.107862949371338,10000.0,32813.861067295074,35380.29990744591,32813.861067295074,2558.8190701007843,3.701513767242432,0.0 -73600,1.0856522,3.7840948,,,,,,,,,,,,,, -73700,1.3057104,2.718178,,,,,,,,,,,,,, -73800,1.3042617,2.4323657,,,,,,,,,,,,,, -73900,1.3137368,2.212492,,,,,,,,,,,,,, -74000,1.2806035,4.6909337,,,,,,,,,,,,,, -74100,1.307269,2.3748684,,,,,,,,,,,,,, -74200,1.27448,3.65308,,,,,,,,,,,,,, -74300,1.3396329,2.132223,,,,,,,,,,,,,, -74400,1.4125732,1.9663199,,,,,,,,,,,,,, -74500,1.1902462,2.8499122,,,,,,,,,,,,,, -74514,,,0.7067577838897705,1.1871877908706665,0.6526199579238892,1.4396685361862185,50000.0,0.5236000418663025,2.103861093521118,10000.0,33233.81961917877,35835.50805258751,33233.81961917877,2593.974936962128,3.7460293769836426,0.0 -74600,1.1280997,3.535479,,,,,,,,,,,,,, -74700,1.3698634,2.1953292,,,,,,,,,,,,,, -74800,1.5091103,2.2673438,,,,,,,,,,,,,, -74900,1.2995017,2.104152,,,,,,,,,,,,,, -75000,1.1970897,2.981686,,,,,,,,,,,,,, -75100,1.2429435,2.148736,,,,,,,,,,,,,, -75200,1.126835,2.8838105,,,,,,,,,,,,,, -75300,1.478633,2.4182549,,,,,,,,,,,,,, -75400,1.2105504,4.540454,,,,,,,,,,,,,, -75456,,,0.7224413752555847,1.1160813570022583,0.648419976234436,1.4548654556274414,50000.0,0.5225000381469727,2.114449977874756,10000.0,33653.806334257126,36289.40531396866,33653.806334257126,2627.8011043071747,3.7806742191314697,0.0 -75500,1.1993897,3.4664154,,,,,,,,,,,,,, -75600,1.2696345,2.0373774,,,,,,,,,,,,,, -75700,1.3802654,2.3178596,,,,,,,,,,,,,, -75800,1.5172262,2.3193803,,,,,,,,,,,,,, -75900,1.4469935,2.130755,,,,,,,,,,,,,, -76000,1.3318989,3.2099435,,,,,,,,,,,,,, -76100,1.4902042,2.2970963,,,,,,,,,,,,,, -76200,1.214825,3.8126771,,,,,,,,,,,,,, -76300,1.3233231,2.112323,,,,,,,,,,,,,, -76396,,,0.70751953125,1.1857048273086548,0.6567800045013428,1.4257631301879885,50000.0,0.5279000401496887,2.0987911224365234,10000.0,34073.717185258865,36742.60909795761,34073.717185258865,2661.0037302970886,3.822371244430542,0.0 -76400,1.5146673,2.2837806,,,,,,,,,,,,,, -76500,1.1830686,4.0158625,,,,,,,,,,,,,, -76600,1.4200119,2.1689517,,,,,,,,,,,,,, -76700,1.4621992,2.049031,,,,,,,,,,,,,, -76800,1.2682985,2.411047,,,,,,,,,,,,,, -76900,1.3218365,2.3457427,,,,,,,,,,,,,, -77000,1.1693897,3.7229352,,,,,,,,,,,,,, -77100,1.324159,2.1985898,,,,,,,,,,,,,, -77200,1.4011024,2.2550125,,,,,,,,,,,,,, -77300,1.5193689,2.0435443,,,,,,,,,,,,,, -77335,,,0.71533203125,1.1208672523498535,0.6605199575424194,1.3832670450210571,50000.0,0.5332000255584717,2.056315660476685,10000.0,34493.6725795269,37194.7865831852,34493.6725795269,2693.128019094467,3.871058940887451,0.0 -77400,1.1645266,4.6817107,,,,,,,,,,,,,, -77500,1.3170472,3.055135,,,,,,,,,,,,,, -77600,1.3576792,2.1422896,,,,,,,,,,,,,, -77700,1.2495862,2.7092745,,,,,,,,,,,,,, -77800,1.244371,3.763774,,,,,,,,,,,,,, -77900,1.2970057,2.1341407,,,,,,,,,,,,,, -78000,1.1591467,4.2979774,,,,,,,,,,,,,, -78100,1.5180598,2.1782131,,,,,,,,,,,,,, -78200,1.215623,4.3296547,,,,,,,,,,,,,, -78272,,,0.7160937190055847,1.1190756559371948,0.6571199893951416,1.4038711786270142,50000.0,0.5371000170707703,2.056650161743164,10000.0,34913.984763622284,37649.24601197243,34913.984763622284,2727.1798944473267,3.917135238647461,0.0 -78300,1.1845206,3.0107677,,,,,,,,,,,,,, -78400,1.3587532,2.055797,,,,,,,,,,,,,, -78500,1.3309134,2.1635165,,,,,,,,,,,,,, -78600,1.3991319,2.2552915,,,,,,,,,,,,,, -78700,1.1578711,3.5344253,,,,,,,,,,,,,, -78800,1.3559378,2.099295,,,,,,,,,,,,,, -78900,1.4256296,2.422592,,,,,,,,,,,,,, -79000,1.3743931,2.0359967,,,,,,,,,,,,,, -79100,1.2524164,2.696086,,,,,,,,,,,,,, -79200,1.4670895,2.2326846,,,,,,,,,,,,,, -79214,,,0.716015636920929,1.125848412513733,0.6620999574661255,1.3742302656173706,50000.0,0.5348000526428223,2.051039934158325,10000.0,35333.96521568298,38102.94189476967,35333.96521568298,2760.8076634407043,3.956195592880249,0.0 -79300,1.3042599,2.1486096,,,,,,,,,,,,,, -79400,1.3627421,2.167209,,,,,,,,,,,,,, -79500,1.2250727,4.5855937,,,,,,,,,,,,,, -79600,1.1841778,3.4365456,,,,,,,,,,,,,, -79700,1.287109,2.7900767,,,,,,,,,,,,,, -79800,1.1705822,3.6960123,,,,,,,,,,,,,, -79900,1.2830875,2.0102618,,,,,,,,,,,,,, -80000,1.4965364,2.16699,,,,,,,,,,,,,, -80100,1.1775521,4.8314734,,,,,,,,,,,,,, -80154,,,0.7089648246765137,1.1660743951797483,0.6569399833679199,1.411213755607605,50000.0,0.5333999991416931,2.053675651550293,10000.0,35753.91080546379,38556.032870054245,35753.91080546379,2793.860199928284,3.9993815422058105,0.0 -80200,1.10866,4.5139685,,,,,,,,,,,,,, -80300,1.3226448,4.7154245,,,,,,,,,,,,,, -80400,1.2062801,2.609927,,,,,,,,,,,,,, -80500,1.3705642,2.3125477,,,,,,,,,,,,,, -80600,1.3947164,2.0741763,,,,,,,,,,,,,, -80700,1.4809037,2.154861,,,,,,,,,,,,,, -80800,1.4538194,2.0417018,,,,,,,,,,,,,, -80900,1.2378434,3.3878355,,,,,,,,,,,,,, -81000,1.4372137,3.9439552,,,,,,,,,,,,,, -81093,,,0.7257421612739563,1.083377242088318,0.6646599769592285,1.3715866804122925,50000.0,0.5402000546455383,2.0339879989624023,10000.0,36173.98077607155,39009.54544401169,36173.98077607155,2827.2101967334747,4.042644023895264,0.0 -81100,1.4605606,2.2217636,,,,,,,,,,,,,, -81200,1.3228698,2.022673,,,,,,,,,,,,,, -81300,1.3422942,2.086193,,,,,,,,,,,,,, -81400,1.3726785,2.2169878,,,,,,,,,,,,,, -81500,1.3532453,2.0830517,,,,,,,,,,,,,, -81600,1.2855673,3.4392536,,,,,,,,,,,,,, -81700,1.3871132,2.0990114,,,,,,,,,,,,,, -81800,1.1661626,3.7613075,,,,,,,,,,,,,, -81900,1.4404646,2.104529,,,,,,,,,,,,,, -82000,1.3740187,2.3502672,,,,,,,,,,,,,, -82032,,,0.7426952719688416,1.036612629890442,0.6611599922180176,1.392037034034729,50000.0,0.5360000133514404,2.0379531383514404,10000.0,36594.00684118271,39463.14464759827,36594.00684118271,2860.696093082428,4.080122947692871,0.0 -82100,1.3428667,2.1145813,,,,,,,,,,,,,, -82200,1.4468514,2.105384,,,,,,,,,,,,,, -82300,1.3209935,2.4315612,,,,,,,,,,,,,, -82400,1.3596826,2.109618,,,,,,,,,,,,,, -82500,1.1738392,3.2568326,,,,,,,,,,,,,, -82600,1.282995,2.4411783,,,,,,,,,,,,,, -82700,1.2912962,2.1987228,,,,,,,,,,,,,, -82800,1.4889306,2.1366625,,,,,,,,,,,,,, -82900,1.313228,2.8425298,,,,,,,,,,,,,, -82972,,,0.7173437476158142,1.1371502876281738,0.6619399785995483,1.3823440074920654,50000.0,0.5356000065803528,2.0468437671661377,10000.0,37014.24144101143,39917.34832811356,37014.24144101143,2894.578936815262,4.116491079330444,0.0 -83000,1.4453137,2.112279,,,,,,,,,,,,,, -83100,1.1845826,3.102195,,,,,,,,,,,,,, -83200,1.261439,3.0876951,,,,,,,,,,,,,, -83300,1.3205702,2.7687967,,,,,,,,,,,,,, -83400,1.1461644,3.4481547,,,,,,,,,,,,,, -83500,1.3209124,2.5374923,,,,,,,,,,,,,, -83600,1.3919514,2.0975711,,,,,,,,,,,,,, -83700,1.3206359,3.8705046,,,,,,,,,,,,,, -83800,1.3345932,4.5939236,,,,,,,,,,,,,, -83900,1.1603816,3.979187,,,,,,,,,,,,,, -83912,,,0.7253515720367432,1.1027836799621582,0.6668999791145325,1.3718284368515017,50000.0,0.5403000116348267,2.028369903564453,10000.0,37434.266563653946,40371.982800245285,37434.266563653946,2929.101461648941,4.154000282287598,0.0 -84000,1.4930434,2.0256486,,,,,,,,,,,,,, -84100,1.2842104,3.385057,,,,,,,,,,,,,, -84200,1.2699759,3.5948977,,,,,,,,,,,,,, -84300,1.3071061,3.6044204,,,,,,,,,,,,,, -84400,1.3908032,2.2656236,,,,,,,,,,,,,, -84500,1.1848536,3.4445078,,,,,,,,,,,,,, -84600,1.3728378,2.394649,,,,,,,,,,,,,, -84700,1.4374665,2.0910604,,,,,,,,,,,,,, -84800,1.3892804,2.6794724,,,,,,,,,,,,,, -84854,,,0.7344335913658142,1.063029170036316,0.6647999882698059,1.370414137840271,50000.0,0.5380000472068787,2.027428150177002,10000.0,37854.56867027283,40824.446971178055,37854.56867027283,2961.1710698604584,4.19693660736084,0.0 -84900,1.2917932,2.2410274,,,,,,,,,,,,,, -85000,1.357098,2.1799045,,,,,,,,,,,,,, -85100,1.2900479,2.9885302,,,,,,,,,,,,,, -85200,1.3018991,3.4037414,,,,,,,,,,,,,, -85300,1.373182,2.0873399,,,,,,,,,,,,,, -85400,1.3502532,2.1189146,,,,,,,,,,,,,, -85500,1.4478446,2.076553,,,,,,,,,,,,,, -85600,1.2687899,2.6620033,,,,,,,,,,,,,, -85700,1.2469602,3.5666876,,,,,,,,,,,,,, -85792,,,0.7220702767372131,1.1111823320388794,0.6682999730110168,1.3621902465820312,50000.0,0.5444000363349915,2.021790266036988,10000.0,38274.63127756119,41278.22320103645,38274.63127756119,2994.7857854366302,4.246838569641113,0.0 -85800,1.3234833,2.212727,,,,,,,,,,,,,, -85900,1.2807157,3.856675,,,,,,,,,,,,,, -86000,1.3721395,2.071147,,,,,,,,,,,,,, -86100,1.3238887,2.6021967,,,,,,,,,,,,,, -86200,1.3857912,2.0823817,,,,,,,,,,,,,, -86300,1.3948593,2.2507954,,,,,,,,,,,,,, -86400,1.3274537,1.9730589,,,,,,,,,,,,,, -86500,1.2418249,4.186506,,,,,,,,,,,,,, -86600,1.400384,2.0749738,,,,,,,,,,,,,, -86700,1.5097866,2.0045877,,,,,,,,,,,,,, -86732,,,0.72572261095047,1.0874943733215332,0.6696000099182129,1.3518980741500854,50000.0,0.5490000247955322,1.991302251815796,10000.0,38694.55052232742,41731.68585777283,38694.55052232742,3028.2366137504578,4.2903008460998535,0.0 -86800,1.3551431,4.295032,,,,,,,,,,,,,, -86900,1.3898373,2.7959454,,,,,,,,,,,,,, -87000,1.2674327,2.8372083,,,,,,,,,,,,,, -87100,1.5119504,4.6788244,,,,,,,,,,,,,, -87200,1.3843713,2.98864,,,,,,,,,,,,,, -87300,1.3593307,2.0126555,,,,,,,,,,,,,, -87400,1.3852999,2.5226355,,,,,,,,,,,,,, -87500,1.4175491,1.969137,,,,,,,,,,,,,, -87600,1.3585794,2.5321023,,,,,,,,,,,,,, -87674,,,0.7313281297683716,1.0686558485031128,0.6709199547767639,1.350008249282837,50000.0,0.5430000424385071,2.0079195499420166,10000.0,39114.70210146904,42184.402356147766,39114.70210146904,3060.714049100876,4.327760934829712,0.0 -87700,1.3267447,3.8047614,,,,,,,,,,,,,, -87800,1.4053136,2.0531797,,,,,,,,,,,,,, -87900,1.331123,2.2452278,,,,,,,,,,,,,, -88000,1.3284729,4.444114,,,,,,,,,,,,,, -88100,1.5032283,2.056191,,,,,,,,,,,,,, -88200,1.4534496,1.9639512,,,,,,,,,,,,,, -88300,1.4868736,2.0049174,,,,,,,,,,,,,, -88400,1.3625789,3.3474286,,,,,,,,,,,,,, -88500,1.5078489,2.3963177,,,,,,,,,,,,,, -88600,1.2856741,2.6808844,,,,,,,,,,,,,, -88615,,,0.7570117115974426,0.9418240785598756,0.6727799773216248,1.3221476078033447,50000.0,0.5525000095367432,1.9742674827575684,10000.0,39534.7329928875,42638.39556074143,39534.7329928875,3094.576204776764,4.378791570663452,0.0 -88700,1.2952737,4.335161,,,,,,,,,,,,,, -88800,1.4695967,1.9255979,,,,,,,,,,,,,, -88900,1.238203,2.9492643,,,,,,,,,,,,,, -89000,1.4410394,1.9356192,,,,,,,,,,,,,, -89100,1.6413021,2.131219,,,,,,,,,,,,,, -89200,1.4175576,2.4374108,,,,,,,,,,,,,, -89300,1.3143926,4.3956146,,,,,,,,,,,,,, -89400,1.461409,1.9834273,,,,,,,,,,,,,, -89500,1.170837,2.916673,,,,,,,,,,,,,, -89555,,,0.7291601300239563,1.0930023193359375,0.675000011920929,1.347053289413452,50000.0,0.5479000210762024,2.000091075897217,10000.0,39954.71279430389,43092.5505900383,39954.71279430389,3128.6608567237854,4.420300722122192,0.0 -89600,1.440991,2.0393867,,,,,,,,,,,,,, -89700,1.1721066,3.3417306,,,,,,,,,,,,,, -89800,1.3739501,4.602762,,,,,,,,,,,,,, -89900,1.255646,4.2145834,,,,,,,,,,,,,, -90000,1.605242,2.0802674,,,,,,,,,,,,,, -90100,1.2245026,3.168831,,,,,,,,,,,,,, -90200,1.4630399,2.0042505,,,,,,,,,,,,,, -90300,1.3546128,4.5150814,,,,,,,,,,,,,, -90400,1.2808508,4.3952622,,,,,,,,,,,,,, -90496,,,0.7349218726158142,1.0604491233825684,0.6757799983024597,1.330373764038086,50000.0,0.5523000359535217,1.9852027893066408,10000.0,40374.94827008248,43545.88556671143,40374.94827008248,3161.6687231063843,4.463085412979126,0.0 -90500,1.341811,2.7495022,,,,,,,,,,,,,, -90600,1.6186033,2.2569947,,,,,,,,,,,,,, -90700,1.4481957,2.0643702,,,,,,,,,,,,,, -90800,1.5022498,4.2247868,,,,,,,,,,,,,, -90900,1.4490428,2.056003,,,,,,,,,,,,,, -91000,1.3221247,3.4442954,,,,,,,,,,,,,, -91100,1.3563144,1.9664161,,,,,,,,,,,,,, -91200,1.5360309,2.0506341,,,,,,,,,,,,,, -91300,1.344781,2.0469956,,,,,,,,,,,,,, -91400,1.2541145,3.9283538,,,,,,,,,,,,,, -91433,,,0.7516992092132568,0.9681676030158995,0.6752200126647949,1.3013813495635986,50000.0,0.5514000058174133,1.9640438556671145,10000.0,40795.17732858658,44000.42345237732,40795.17732858658,3195.8801221847534,4.511308193206787,0.0 -91500,1.3857882,2.2549856,,,,,,,,,,,,,, -91600,1.2800034,3.8028185,,,,,,,,,,,,,, -91700,1.3995082,2.3686972,,,,,,,,,,,,,, -91800,1.5341535,2.1047258,,,,,,,,,,,,,, -91900,1.3871641,1.9793104,,,,,,,,,,,,,, -92000,1.4937398,2.0313601,,,,,,,,,,,,,, -92100,1.1902883,3.4207544,,,,,,,,,,,,,, -92200,1.665868,1.9275348,,,,,,,,,,,,,, -92300,1.5980619,1.9672598,,,,,,,,,,,,,, -92375,,,0.7336523532867432,1.0666475296020508,0.6757599711418152,1.3290544748306274,50000.0,0.5496000051498413,1.9641692638397217,10000.0,41215.14496469498,44452.26744532585,41215.14496469498,3227.666731834412,4.5512306690216064,0.0 -92400,1.2944636,2.6409442,,,,,,,,,,,,,, -92500,1.5682546,2.1239498,,,,,,,,,,,,,, -92600,1.4154576,2.2817738,,,,,,,,,,,,,, -92700,1.3580469,3.6472614,,,,,,,,,,,,,, -92800,1.2677585,1.9317482,,,,,,,,,,,,,, -92900,1.5547476,2.0948844,,,,,,,,,,,,,, -93000,1.3656263,3.3699327,,,,,,,,,,,,,, -93100,1.3946041,2.0182965,,,,,,,,,,,,,, -93200,1.400387,3.1068442,,,,,,,,,,,,,, -93300,1.3596265,2.5847747,,,,,,,,,,,,,, -93314,,,0.7378124594688416,1.0565478801727295,0.677839994430542,1.3238400220870972,50000.0,0.5519000291824341,1.9784936904907229,10000.0,41635.211717128754,44907.06348752976,41635.211717128754,3262.29194355011,4.605958700180054,0.0 -93400,1.4367683,4.039377,,,,,,,,,,,,,, -93500,1.4537743,2.0103319,,,,,,,,,,,,,, -93600,1.3230699,3.057311,,,,,,,,,,,,,, -93700,1.5914087,2.0629468,,,,,,,,,,,,,, -93800,1.5072181,1.9931355,,,,,,,,,,,,,, -93900,1.5054288,1.9646873,,,,,,,,,,,,,, -94000,1.5845722,3.3690004,,,,,,,,,,,,,, -94100,1.4754599,2.1481817,,,,,,,,,,,,,, -94200,1.356459,1.910528,,,,,,,,,,,,,, -94257,,,0.7453515529632568,1.0164759159088137,0.6785799860954285,1.315027952194214,50000.0,0.5519000291824341,1.9706907272338867,10000.0,42055.45607948303,45361.51790237427,42055.45607948303,3296.409605741501,4.649278879165649,0.0 -94300,1.3037212,3.089819,,,,,,,,,,,,,, -94400,1.4351858,3.9256206,,,,,,,,,,,,,, -94500,1.373838,2.3291388,,,,,,,,,,,,,, -94600,1.3087177,3.0991893,,,,,,,,,,,,,, -94700,1.5597932,2.024821,,,,,,,,,,,,,, -94800,1.4907459,1.9899794,,,,,,,,,,,,,, -94900,1.4400618,2.3495483,,,,,,,,,,,,,, -95000,1.431389,2.0399852,,,,,,,,,,,,,, -95100,1.2410414,3.458821,,,,,,,,,,,,,, -95199,,,0.7621288895606995,0.9430046677589417,0.676539957523346,1.3137037754058838,50000.0,0.5520000457763672,1.978226661682129,10000.0,42475.57057905197,45816.79495024681,42475.57057905197,3331.481355428696,4.691254615783691,0.0 -95200,1.5341105,1.9849694,,,,,,,,,,,,,, -95300,1.4263616,2.07928,,,,,,,,,,,,,, -95400,1.4125292,2.3191996,,,,,,,,,,,,,, -95500,1.2955163,2.598453,,,,,,,,,,,,,, -95600,1.3295335,2.984951,,,,,,,,,,,,,, -95700,1.3821639,2.0798163,,,,,,,,,,,,,, -95800,1.3020228,4.4619613,,,,,,,,,,,,,, -95900,1.5018963,1.8291423,,,,,,,,,,,,,, -96000,1.392724,4.0088177,,,,,,,,,,,,,, -96100,1.4227796,1.93754,,,,,,,,,,,,,, -96140,,,0.7389257550239563,1.0446540117263794,0.680679976940155,1.3052889108657837,50000.0,0.556600034236908,1.9483904838562007,10000.0,42895.6915576458,46268.998499155045,42895.6915576458,3363.4759736061096,4.730247259140015,0.0 -96200,1.7294432,4.4987555,,,,,,,,,,,,,, -96300,1.4831566,1.9649572,,,,,,,,,,,,,, -96400,1.3019984,3.7289512,,,,,,,,,,,,,, -96500,1.3277096,3.979923,,,,,,,,,,,,,, -96600,1.292872,2.8507242,,,,,,,,,,,,,, -96700,1.3216416,3.7374597,,,,,,,,,,,,,, -96800,1.5495145,2.049169,,,,,,,,,,,,,, -96900,1.5838486,2.1073341,,,,,,,,,,,,,, -97000,1.1738456,3.0574186,,,,,,,,,,,,,, -97078,,,0.7458202838897705,0.9984654784202576,0.6850000023841858,1.273817777633667,50000.0,0.5636000037193298,1.91632878780365,10000.0,43315.840493917465,46723.43449640274,43315.840493917465,3397.667732000351,4.775901317596436,0.0 -97100,1.6543118,1.969029,,,,,,,,,,,,,, -97200,1.2771022,2.5578852,,,,,,,,,,,,,, -97300,1.5701989,2.0094688,,,,,,,,,,,,,, -97400,1.3291044,3.3618383,,,,,,,,,,,,,, -97500,1.5016474,1.9938102,,,,,,,,,,,,,, -97600,1.4232401,1.9582877,,,,,,,,,,,,,, -97700,1.3124897,4.479345,,,,,,,,,,,,,, -97800,1.5174525,4.5786686,,,,,,,,,,,,,, -97900,1.6308084,1.8895884,,,,,,,,,,,,,, -98000,1.6438003,1.9527144,,,,,,,,,,,,,, -98018,,,0.7523632645606995,0.9861083626747132,0.6795200109481812,1.307973861694336,50000.0,0.557200014591217,1.95401668548584,10000.0,43736.18922662735,47179.78457069397,43736.18922662735,3433.5789165496826,4.816876649856567,0.0 -98100,1.4942657,1.99409,,,,,,,,,,,,,, -98200,1.5117888,1.956451,,,,,,,,,,,,,, -98300,1.7027178,1.9217957,,,,,,,,,,,,,, -98400,1.51098,2.0610085,,,,,,,,,,,,,, -98500,1.2930006,2.6409144,,,,,,,,,,,,,, -98600,1.3703531,3.6722806,,,,,,,,,,,,,, -98700,1.4084039,2.5553024,,,,,,,,,,,,,, -98800,1.414297,3.143781,,,,,,,,,,,,,, -98900,1.3131905,3.591545,,,,,,,,,,,,,, -98959,,,0.7362499833106995,1.0587328672409058,0.683899998664856,1.307403326034546,50000.0,0.5526000261306763,1.950338363647461,10000.0,44156.29066824913,47633.53465199471,44156.29066824913,3467.1358416080475,4.856912851333618,0.0 -99000,1.5394845,1.9633572,,,,,,,,,,,,,, -99100,1.3560449,4.294358,,,,,,,,,,,,,, -99200,1.3507757,2.9353566,,,,,,,,,,,,,, -99300,1.7110748,1.9280139,,,,,,,,,,,,,, -99400,1.4800668,1.8574227,,,,,,,,,,,,,, -99500,1.4817985,3.354962,,,,,,,,,,,,,, -99600,1.4608941,1.9035525,,,,,,,,,,,,,, -99700,1.3524038,2.3032165,,,,,,,,,,,,,, -99800,1.4149976,1.9822866,,,,,,,,,,,,,, -99898,,,0.7546093463897705,0.9772475361824036,0.6888599991798401,1.2703135013580322,50000.0,0.5678000450134277,1.9063079357147217,10000.0,44576.2303814888,48088.621727228165,44576.2303814888,3502.182624578476,4.907997369766235,0.0 -99900,1.3922822,2.726629,,,,,,,,,,,,,, -100000,1.5592881,1.9009922,,,,,,,,,,,,,, -100100,1.6423703,2.0712993,,,,,,,,,,,,,, -100200,1.489687,3.2491353,,,,,,,,,,,,,, -100300,1.4671541,2.1040459,,,,,,,,,,,,,, -100400,1.580574,1.8453641,,,,,,,,,,,,,, -100500,1.4946594,1.959491,,,,,,,,,,,,,, -100600,1.3885481,3.7880065,,,,,,,,,,,,,, -100700,1.5955219,1.8636204,,,,,,,,,,,,,, -100800,1.4945066,1.9373512,,,,,,,,,,,,,, -100839,,,0.7542577981948853,0.9589452147483826,0.6887199878692627,1.2671600580215454,50000.0,0.5581000447273254,1.9453516006469729,10000.0,44996.54544496536,48542.416513204575,44996.54544496536,3535.5730526447296,4.947621583938599,0.0 -100900,1.6175752,1.7166058,,,,,,,,,,,,,, -101000,1.4626932,2.2358499,,,,,,,,,,,,,, -101100,1.5173435,3.409636,,,,,,,,,,,,,, -101200,1.5492032,1.8758616,,,,,,,,,,,,,, -101300,1.4760085,2.1367993,,,,,,,,,,,,,, -101400,1.3358654,3.519519,,,,,,,,,,,,,, -101500,1.6482536,4.05075,,,,,,,,,,,,,, -101600,1.3296652,2.6015937,,,,,,,,,,,,,, -101700,1.5982039,1.8294272,,,,,,,,,,,,,, -101780,,,0.771289050579071,0.90166836977005,0.6931999921798706,1.242873191833496,50000.0,0.5766000151634216,1.866647720336914,10000.0,45416.76748251915,48999.10462117195,45416.76748251915,3571.9376525878906,4.999526977539063,0.0 -101800,1.2999083,3.4701157,,,,,,,,,,,,,, -101900,1.3308218,3.3139305,,,,,,,,,,,,,, -102000,1.3775531,2.8434935,,,,,,,,,,,,,, -102100,1.3865509,3.7110715,,,,,,,,,,,,,, -102200,1.4969712,2.0956624,,,,,,,,,,,,,, -102300,1.5618058,1.8739401,,,,,,,,,,,,,, -102400,1.53505,1.8746355,,,,,,,,,,,,,, -102500,1.6691258,2.101328,,,,,,,,,,,,,, -102600,1.3505889,3.3925955,,,,,,,,,,,,,, -102700,1.4490428,3.3216283,,,,,,,,,,,,,, -102719,,,0.7533984184265137,0.9679220914840698,0.6896799802780151,1.2444446086883545,50000.0,0.5665000081062317,1.892961859703064,10000.0,45836.98096561432,49454.88519287109,45836.98096561432,3607.411679744721,5.043881177902222,0.0 -102800,1.5639979,4.3851876,,,,,,,,,,,,,, -102900,1.4860355,2.1804316,,,,,,,,,,,,,, -103000,1.6236466,2.0033817,,,,,,,,,,,,,, -103100,1.4584148,1.9427311,,,,,,,,,,,,,, -103200,1.5207243,1.8424027,,,,,,,,,,,,,, -103300,1.5902461,1.7483678,,,,,,,,,,,,,, -103400,1.5333129,2.9405186,,,,,,,,,,,,,, -103500,1.4168593,4.2429075,,,,,,,,,,,,,, -103600,1.5946151,2.0253847,,,,,,,,,,,,,, -103657,,,0.7592577934265137,0.9390373229980468,0.6936799883842468,1.2363741397857666,50000.0,0.5705000162124634,1.8715754747390747,10000.0,46257.00506234169,49908.68248295784,46257.00506234169,3641.093369960785,5.086168050765991,0.0 -103700,1.5252962,2.4996862,,,,,,,,,,,,,, -103800,1.4866214,1.900766,,,,,,,,,,,,,, -103900,1.5926696,1.961413,,,,,,,,,,,,,, -104000,1.5471417,2.8735995,,,,,,,,,,,,,, -104100,1.5314584,4.330991,,,,,,,,,,,,,, -104200,1.4754808,2.2348387,,,,,,,,,,,,,, -104300,1.5338256,2.0634084,,,,,,,,,,,,,, -104400,1.3889235,3.1078095,,,,,,,,,,,,,, -104500,1.5176306,2.4319005,,,,,,,,,,,,,, -104597,,,0.77699214220047,0.8573799133300781,0.6948599815368652,1.222651720046997,50000.0,0.5694000124931335,1.8680320978164675,10000.0,46676.93702673912,50362.851440668106,46676.93702673912,3675.241497993469,5.125425815582275,0.0 -104600,2.049364,4.499935,,,,,,,,,,,,,, -104700,1.3873135,2.8763225,,,,,,,,,,,,,, -104800,1.549115,1.9380157,,,,,,,,,,,,,, -104900,1.3958896,3.9589474,,,,,,,,,,,,,, -105000,1.4118475,2.5588667,,,,,,,,,,,,,, -105100,1.5622796,3.7369373,,,,,,,,,,,,,, -105200,1.4454135,2.39992,,,,,,,,,,,,,, -105300,1.5506065,1.889842,,,,,,,,,,,,,, -105400,1.6088333,2.2179275,,,,,,,,,,,,,, -105500,1.5400411,1.9155742,,,,,,,,,,,,,, -105536,,,0.7588671445846558,0.9336987137794496,0.6984599828720093,1.214532732963562,50000.0,0.5735000371932983,1.864123106002808,10000.0,47097.02132034302,50818.641756773,47097.02132034302,3710.8573791980734,5.166118860244751,0.0 -105600,1.4249222,2.392458,,,,,,,,,,,,,, -105700,1.6274348,4.2405243,,,,,,,,,,,,,, -105800,1.4629662,1.8809923,,,,,,,,,,,,,, -105900,1.5472022,1.9969201,,,,,,,,,,,,,, -106000,1.6522114,1.8923374,,,,,,,,,,,,,, -106100,1.5668358,4.2722826,,,,,,,,,,,,,, -106200,1.6346601,1.8556023,,,,,,,,,,,,,, -106300,1.8714049,2.0284991,,,,,,,,,,,,,, -106400,1.589554,1.8767731,,,,,,,,,,,,,, -106475,,,0.7658007740974426,0.917033851146698,0.6999399662017822,1.2070460319519043,50000.0,0.5738000273704529,1.854724287986756,10000.0,47516.95656371117,51272.64339399338,47516.95656371117,3744.832273721695,5.208755970001221,0.0 -106500,1.5859164,1.8344822,,,,,,,,,,,,,, -106600,1.4622539,3.8667367,,,,,,,,,,,,,, -106700,1.4593434,2.448323,,,,,,,,,,,,,, -106800,1.5055879,2.6027522,,,,,,,,,,,,,, -106900,1.711754,1.9107138,,,,,,,,,,,,,, -107000,1.5016524,1.8337375,,,,,,,,,,,,,, -107100,1.6069152,4.306797,,,,,,,,,,,,,, -107200,1.4688375,3.3727207,,,,,,,,,,,,,, -107300,1.7778342,1.8911173,,,,,,,,,,,,,, -107400,1.419709,2.5628428,,,,,,,,,,,,,, -107415,,,0.7672070264816284,0.9189236760139464,0.6965199708938599,1.229423761367798,50000.0,0.5769000053405762,1.873677492141724,10000.0,47937.12469291687,51724.62146115303,47937.12469291687,3776.548688173294,5.253453969955444,0.0 -107500,1.6057978,1.8229609,,,,,,,,,,,,,, -107600,1.544873,3.95264,,,,,,,,,,,,,, -107700,1.5305641,1.8939373,,,,,,,,,,,,,, -107800,1.3719406,2.771516,,,,,,,,,,,,,, -107900,1.5885571,1.884184,,,,,,,,,,,,,, -108000,1.585402,1.8650701,,,,,,,,,,,,,, -108100,1.6086473,2.268419,,,,,,,,,,,,,, -108200,1.3857812,2.879686,,,,,,,,,,,,,, -108285,,,0.7604101300239563,0.9380630254745485,0.7011599540710449,1.219152331352234,50000.0,0.5774000287055969,1.8574777841567995,10000.0,48357.32349872589,52180.58470964432,48357.32349872589,3812.2145340442657,5.30767297744751,0.0 -108300,1.5368958,2.098843,,,,,,,,,,,,,, -108400,1.5655504,1.8688776,,,,,,,,,,,,,, -108500,1.5171769,4.207766,,,,,,,,,,,,,, -108600,1.6913805,1.826133,,,,,,,,,,,,,, -108700,1.7049563,1.8123087,,,,,,,,,,,,,, -108800,1.4208263,3.7577918,,,,,,,,,,,,,, -108900,1.5073848,3.8680656,,,,,,,,,,,,,, -109000,1.5730785,1.9143552,,,,,,,,,,,,,, -109100,1.5852093,2.0554183,,,,,,,,,,,,,, -109200,1.5168132,1.9746816,,,,,,,,,,,,,, -109226,,,0.7692773342132568,0.918634593486786,0.7051799893379211,1.20063054561615,50000.0,0.5776000022888184,1.840431928634644,10000.0,48777.63230252266,52635.37345814705,48777.63230252266,3846.595866441727,5.356571435928345,0.0 -109300,1.64152,4.0410013,,,,,,,,,,,,,, -109400,1.5626075,2.0369215,,,,,,,,,,,,,, -109500,1.4445105,2.64804,,,,,,,,,,,,,, -109600,1.4947305,2.7140155,,,,,,,,,,,,,, -109700,1.6736686,4.4063697,,,,,,,,,,,,,, -109800,1.5351541,4.0139375,,,,,,,,,,,,,, -109900,1.4398171,2.3515956,,,,,,,,,,,,,, -110000,1.4949465,2.6905212,,,,,,,,,,,,,, -110100,1.453783,2.1697848,,,,,,,,,,,,,, -110169,,,0.7645898461341858,0.961553394794464,0.6979799866676331,1.2633581161499023,50000.0,0.5714000463485718,1.9044029712677,10000.0,49197.60548973084,53089.98840522766,49197.60548973084,3881.148336172104,5.3964619636535645,0.0 -110200,1.5489955,4.037827,,,,,,,,,,,,,, -110300,1.6113135,1.7434294,,,,,,,,,,,,,, -110400,1.7022287,1.7480538,,,,,,,,,,,,,, -110500,1.4298414,3.4902406,,,,,,,,,,,,,, -110600,1.5876411,4.2013535,,,,,,,,,,,,,, -110700,1.6101018,2.0477486,,,,,,,,,,,,,, -110800,1.542499,2.7439408,,,,,,,,,,,,,, -110900,1.7255981,1.8621231,,,,,,,,,,,,,, -111000,1.9706558,1.8179,,,,,,,,,,,,,, -111100,1.7447405,1.9869119,,,,,,,,,,,,,, -111109,,,0.7912499904632568,0.8058376908302307,0.7072599530220032,1.1865533590316772,50000.0,0.5804000496864319,1.8377219438552856,10000.0,49617.58262228966,53544.45907402039,49617.58262228966,3915.5525193214417,5.436882019042969,0.0 -111200,1.508884,1.8412136,,,,,,,,,,,,,, -111300,1.7482213,1.8770169,,,,,,,,,,,,,, -111400,1.702011,1.8997445,,,,,,,,,,,,,, -111500,1.9014014,1.7918007,,,,,,,,,,,,,, -111600,1.673929,1.852558,,,,,,,,,,,,,, -111700,1.5744764,2.3289177,,,,,,,,,,,,,, -111800,1.5424465,1.6834359,,,,,,,,,,,,,, -111900,1.6875408,3.2655048,,,,,,,,,,,,,, -112000,1.7329718,3.7056525,,,,,,,,,,,,,, -112049,,,0.7679296731948853,0.9026365280151368,0.7026000022888184,1.195265769958496,50000.0,0.5787000060081482,1.8381128311157229,10000.0,50037.971556425095,53999.140706539154,50037.971556425095,3949.753340244293,5.479412078857422,0.0 -112100,1.5250288,3.9719837,,,,,,,,,,,,,, -112200,1.7731024,1.824527,,,,,,,,,,,,,, -112300,1.8541284,1.830108,,,,,,,,,,,,,, -112400,1.7359146,3.6289124,,,,,,,,,,,,,, -112500,1.7021472,1.9325018,,,,,,,,,,,,,, -112600,1.6340785,1.7894002,,,,,,,,,,,,,, -112700,1.748206,1.8442831,,,,,,,,,,,,,, -112800,1.7503731,3.4393528,,,,,,,,,,,,,, -112900,1.652384,1.7339091,,,,,,,,,,,,,, -112988,,,0.774707019329071,0.8800515532493591,0.707319974899292,1.1850789785385132,50000.0,0.5800000429153442,1.8336036205291748,10000.0,50457.97093844414,54453.84243106842,50457.97093844414,3984.365024328232,5.520999431610107,0.0 -113000,1.578292,2.4102924,,,,,,,,,,,,,, -113100,1.7363389,1.8584496,,,,,,,,,,,,,, -113200,1.5545968,3.3592072,,,,,,,,,,,,,, -113300,1.5717659,3.4825242,,,,,,,,,,,,,, -113400,1.5777175,1.7677047,,,,,,,,,,,,,, -113500,1.5970863,2.7966208,,,,,,,,,,,,,, -113600,1.7420478,1.7478099,,,,,,,,,,,,,, -113700,1.5136478,2.5385072,,,,,,,,,,,,,, -113800,1.7931974,1.8965464,,,,,,,,,,,,,, -113900,1.6181139,2.0719361,,,,,,,,,,,,,, -113928,,,0.7822265625,0.8601695895195007,0.707099974155426,1.1927919387817385,50000.0,0.5804000496864319,1.827962636947632,10000.0,50877.95106720925,54909.1758646965,50877.95106720925,4019.6268467903137,5.563291549682617,0.0 -114000,1.5656664,3.6818254,,,,,,,,,,,,,, -114100,1.8326117,1.8860564,,,,,,,,,,,,,, -114200,1.4985923,1.9182738,,,,,,,,,,,,,, -114300,1.7380615,4.21208,,,,,,,,,,,,,, -114400,1.7793572,4.273944,,,,,,,,,,,,,, -114500,1.613016,1.7947036,,,,,,,,,,,,,, -114600,1.6795373,4.011677,,,,,,,,,,,,,, -114700,1.5516384,3.4332395,,,,,,,,,,,,,, -114800,1.5847931,4.2295413,,,,,,,,,,,,,, -114868,,,0.7724804282188416,0.8807520866394043,0.7059599757194519,1.169217824935913,50000.0,0.5825000405311584,1.818506002426148,10000.0,51298.191059827805,55363.03579878807,51298.191059827805,4053.151378154754,5.610148906707764,0.0 -114900,1.7710881,4.1043215,,,,,,,,,,,,,, -115000,1.5329669,3.392726,,,,,,,,,,,,,, -115100,1.572897,1.695608,,,,,,,,,,,,,, -115200,1.5443575,2.9855504,,,,,,,,,,,,,, -115300,1.6224661,1.8586284,,,,,,,,,,,,,, -115400,1.5836323,1.8003273,,,,,,,,,,,,,, -115500,1.5946717,3.4981303,,,,,,,,,,,,,, -115600,1.6931182,4.2491846,,,,,,,,,,,,,, -115700,1.7189547,4.0293984,,,,,,,,,,,,,, -115800,1.6130106,3.8270154,,,,,,,,,,,,,, -115805,,,0.78187495470047,0.8490674495697021,0.7106599807739258,1.1583318710327148,50000.0,0.5877000093460083,1.7957428693771362,10000.0,51718.180342674255,55816.513998031616,51718.180342674255,4086.5362679958334,5.664799213409424,0.0 -115900,1.6417555,1.7599673,,,,,,,,,,,,,, -116000,1.4911853,2.1640787,,,,,,,,,,,,,, -116100,1.6356394,3.3100235,,,,,,,,,,,,,, -116200,1.7141557,1.7267345,,,,,,,,,,,,,, -116300,1.7606702,1.7218318,,,,,,,,,,,,,, -116400,1.5168979,2.0731807,,,,,,,,,,,,,, -116500,1.7822318,1.8566675,,,,,,,,,,,,,, -116600,1.7296687,1.7508205,,,,,,,,,,,,,, -116700,1.5342379,2.877153,,,,,,,,,,,,,, -116744,,,0.7876366972923279,0.8236556053161621,0.7130599617958069,1.1581529378890991,50000.0,0.5914000272750854,1.7908883094787598,10000.0,52138.19392943382,56272.46994638443,52138.19392943382,4122.379958868027,5.713786840438843,0.0 -116800,1.6713997,1.96786,,,,,,,,,,,,,, -116900,1.5030047,2.2739203,,,,,,,,,,,,,, -117000,1.6870302,3.2760859,,,,,,,,,,,,,, -117100,1.7475076,1.7055228,,,,,,,,,,,,,, -117200,1.6161306,1.8348008,,,,,,,,,,,,,, -117300,1.5240865,2.065654,,,,,,,,,,,,,, -117400,1.6695534,1.9621435,,,,,,,,,,,,,, -117500,1.5361521,3.3806922,,,,,,,,,,,,,, -117600,1.6265724,2.1486716,,,,,,,,,,,,,, -117685,,,0.8011718392372131,0.7712819576263428,0.7135599851608276,1.1575746536254885,50000.0,0.5821000337600708,1.8017897605896,10000.0,52558.495411634445,56727.45986151695,52558.495411634445,4156.973059415817,5.760483264923096,0.0 -117700,1.7674093,1.8384128,,,,,,,,,,,,,, -117800,1.5729072,3.9588761,,,,,,,,,,,,,, -117900,1.727557,1.9422522,,,,,,,,,,,,,, -118000,1.8035258,4.211975,,,,,,,,,,,,,, -118100,1.6627327,1.8252097,,,,,,,,,,,,,, -118200,1.7218019,1.7921543,,,,,,,,,,,,,, -118300,1.7977692,2.2839532,,,,,,,,,,,,,, -118400,1.8706001,1.740094,,,,,,,,,,,,,, -118500,1.7050017,1.7998841,,,,,,,,,,,,,, -118600,1.9979848,4.3082805,,,,,,,,,,,,,, -118624,,,0.7796484231948853,0.8606301546096802,0.7136399745941162,1.1567219495773315,50000.0,0.5915000438690186,1.8010717630386353,10000.0,52978.57805562019,57183.69558787346,52978.57805562019,4193.030814886093,5.806100606918335,0.0 -118700,1.777695,1.8836133,,,,,,,,,,,,,, -118800,1.5188153,3.1851568,,,,,,,,,,,,,, -118900,1.7760861,1.8378148,,,,,,,,,,,,,, -119000,1.6125665,4.0296273,,,,,,,,,,,,,, -119100,1.5571595,2.5224116,,,,,,,,,,,,,, -119200,1.7565304,3.7153478,,,,,,,,,,,,,, -119300,1.675628,1.7618134,,,,,,,,,,,,,, -119400,1.6180446,2.7404032,,,,,,,,,,,,,, -119500,1.7803844,1.941826,,,,,,,,,,,,,, -119566,,,0.7863476276397705,0.8422191143035889,0.7168399691581726,1.1445538997650146,50000.0,0.5922000408172607,1.7661399841308594,10000.0,53398.8308763504,57638.8034594059,53398.8308763504,4227.795293569565,5.846826076507568,0.0 -119600,1.7680888,1.6353577,,,,,,,,,,,,,, -119700,1.6409222,3.7666435,,,,,,,,,,,,,, -119800,1.5702143,2.969072,,,,,,,,,,,,,, -119900,1.9041834,1.7124435,,,,,,,,,,,,,, -120000,1.7410867,1.7062323,,,,,,,,,,,,,, -120100,1.6868591,1.7883296,,,,,,,,,,,,,, -120200,1.8976734,1.8767295,,,,,,,,,,,,,, -120300,1.806421,1.6094093,,,,,,,,,,,,,, -120400,1.7648735,1.679689,,,,,,,,,,,,,, -120500,1.86205,1.7435071,,,,,,,,,,,,,, -120503,,,0.7981640696525574,0.784917950630188,0.7168599963188171,1.13680100440979,50000.0,0.5898000001907349,1.7828012704849243,10000.0,53819.15520334244,58095.27610850334,53819.15520334244,4263.846667289734,5.894752502441406,0.0 -120600,1.6503105,1.7393322,,,,,,,,,,,,,, -120700,1.6320552,3.0102468,,,,,,,,,,,,,, -120800,1.6162748,2.397592,,,,,,,,,,,,,, -120900,1.8215734,1.6574345,,,,,,,,,,,,,, -121000,1.5330468,3.0481095,,,,,,,,,,,,,, -121100,1.7459856,1.6721301,,,,,,,,,,,,,, -121200,1.5653152,2.4882357,,,,,,,,,,,,,, -121300,1.7125721,1.7518814,,,,,,,,,,,,,, -121400,1.9462198,1.6631289,,,,,,,,,,,,,, -121441,,,0.7860156297683716,0.8371008038520813,0.7204799652099609,1.1226909160614014,50000.0,0.5955000519752502,1.764423966407776,10000.0,54239.20383620262,58552.858276844025,54239.20383620262,4301.28515791893,5.94125771522522,0.0 -121500,1.6140289,2.6805017,,,,,,,,,,,,,, -121600,1.7330476,2.2774408,,,,,,,,,,,,,, -121700,1.8689246,3.7025313,,,,,,,,,,,,,, -121800,1.7165937,2.9128518,,,,,,,,,,,,,, -121900,1.856157,1.6837177,,,,,,,,,,,,,, -122000,1.7629328,4.0927052,,,,,,,,,,,,,, -122100,1.7513587,1.624912,,,,,,,,,,,,,, -122200,1.7995235,1.7573652,,,,,,,,,,,,,, -122300,1.7489113,3.49914,,,,,,,,,,,,,, -122374,,,0.7886718511581421,0.8289256691932678,0.7191599607467651,1.1320635080337524,50000.0,0.5901000499725342,1.7791229486465454,10000.0,54659.44230914116,59007.31868457794,54659.44230914116,4335.413290023804,5.986897230148315,0.0 -122400,1.8361965,1.7518439,,,,,,,,,,,,,, -122500,1.723548,1.939967,,,,,,,,,,,,,, -122600,1.8808336,2.2149324,,,,,,,,,,,,,, -122700,1.7165781,1.811854,,,,,,,,,,,,,, -122800,1.8229693,1.850201,,,,,,,,,,,,,, -122900,1.8336742,1.714595,,,,,,,,,,,,,, -123000,1.8071723,1.6487355,,,,,,,,,,,,,, -123100,1.7543545,1.6392893,,,,,,,,,,,,,, -123200,1.5482665,3.0645397,,,,,,,,,,,,,, -123300,1.6977772,2.258478,,,,,,,,,,,,,, -123313,,,0.7957812547683716,0.767406165599823,0.7225199937820435,1.1092499494552612,50000.0,0.5976999998092651,1.7560497522354126,10000.0,55079.53868961334,59462.60575866699,55079.53868961334,4370.508100986481,6.033722162246704,0.0 -123400,1.8158644,3.4302425,,,,,,,,,,,,,, -123500,1.846017,2.3482199,,,,,,,,,,,,,, -123600,1.7591046,3.908605,,,,,,,,,,,,,, -123700,1.850043,1.6699524,,,,,,,,,,,,,, -123800,1.8835508,1.8041589,,,,,,,,,,,,,, -123900,1.8369874,1.6556506,,,,,,,,,,,,,, -124000,1.9301604,1.6239022,,,,,,,,,,,,,, -124100,1.7746665,3.6106253,,,,,,,,,,,,,, -124200,1.7360505,1.8130877,,,,,,,,,,,,,, -124253,,,0.8089062571525574,0.7466806769371033,0.7249999642372131,1.1076456308364868,50000.0,0.598300039768219,1.7344789505004885,10000.0,55499.47381663322,59917.94662761688,55499.47381663322,4405.81605887413,6.080605506896973,0.0 -124300,1.8115327,2.069624,,,,,,,,,,,,,, -124400,1.7958186,4.0685744,,,,,,,,,,,,,, -124500,1.7658908,2.0141027,,,,,,,,,,,,,, -124600,1.8598068,4.053053,,,,,,,,,,,,,, -124700,1.7600784,1.7553506,,,,,,,,,,,,,, -124800,1.8039242,1.7880266,,,,,,,,,,,,,, -124900,1.7992336,1.887616,,,,,,,,,,,,,, -125000,1.5692075,2.5779111,,,,,,,,,,,,,, -125100,1.7230841,3.8849094,,,,,,,,,,,,,, -125194,,,0.7921679615974426,0.793973445892334,0.7214199900627136,1.105116844177246,50000.0,0.6028000116348267,1.7382676601409912,10000.0,55919.53238034248,60373.53852200508,55919.53238034248,4441.247399568558,6.1333723068237305,0.0 -125200,1.7647834,3.9867206,,,,,,,,,,,,,, -125300,1.6672218,3.2684305,,,,,,,,,,,,,, -125400,1.890148,1.5867493,,,,,,,,,,,,,, -125500,1.7452729,2.0851398,,,,,,,,,,,,,, -125600,1.6127255,2.3173351,,,,,,,,,,,,,, -125700,1.8729267,1.7972329,,,,,,,,,,,,,, -125800,1.7784407,1.6648805,,,,,,,,,,,,,, -125900,1.878365,1.7758232,,,,,,,,,,,,,, -126000,1.8029284,1.8353285,,,,,,,,,,,,,, -126100,1.8069061,1.665701,,,,,,,,,,,,,, -126134,,,0.8014843463897705,0.7683876156806946,0.7262399792671204,1.1003386974334717,50000.0,0.6052000522613525,1.73361337184906,10000.0,56339.733157634735,60826.4618768692,56339.733157634735,4473.872217655182,6.182005167007446,0.0 -126200,1.8630191,1.603349,,,,,,,,,,,,,, -126300,1.6552104,2.9155543,,,,,,,,,,,,,, -126400,1.7427385,1.5872883,,,,,,,,,,,,,, -126500,1.8627117,1.7106371,,,,,,,,,,,,,, -126600,1.6752198,2.2849057,,,,,,,,,,,,,, -126700,1.7586267,1.5815303,,,,,,,,,,,,,, -126800,1.9269723,1.9515756,,,,,,,,,,,,,, -126900,1.6673533,2.8842096,,,,,,,,,,,,,, -127000,1.6484177,3.013495,,,,,,,,,,,,,, -127073,,,0.8082616925239563,0.7514926791191101,0.7242000102996826,1.1179012060165403,50000.0,0.5973000526428223,1.7537530660629272,10000.0,56759.89336299896,61281.57898569107,56759.89336299896,4508.729408502579,6.23290491104126,0.0 -127100,1.9476354,1.7786739,,,,,,,,,,,,,, -127200,1.7862214,1.6196557,,,,,,,,,,,,,, -127300,1.7903432,3.7587993,,,,,,,,,,,,,, -127400,2.290256,4.0333257,,,,,,,,,,,,,, -127500,1.683937,2.8446305,,,,,,,,,,,,,, -127600,1.8798571,1.6800203,,,,,,,,,,,,,, -127700,2.1363,1.6708181,,,,,,,,,,,,,, -127800,1.8842587,3.971035,,,,,,,,,,,,,, -127900,1.6923234,3.0614421,,,,,,,,,,,,,, -128000,1.9288242,1.670445,,,,,,,,,,,,,, -128013,,,0.7968358993530273,0.7883461713790894,0.7251600027084351,1.1023790836334229,50000.0,0.6034000515937805,1.7308070659637451,10000.0,57180.19188570976,61736.58455133438,57180.19188570976,4543.34293794632,6.276918172836304,0.0 -128100,1.7797453,1.9204111,,,,,,,,,,,,,, -128200,1.9189175,1.6122987,,,,,,,,,,,,,, -128300,1.9638413,1.6283139,,,,,,,,,,,,,, -128400,1.8742516,2.0384636,,,,,,,,,,,,,, -128500,1.7483308,2.7580366,,,,,,,,,,,,,, -128600,1.6854323,3.131191,,,,,,,,,,,,,, -128700,1.854387,2.5570066,,,,,,,,,,,,,, -128800,1.9181001,1.6267606,,,,,,,,,,,,,, -128900,2.1774273,4.104668,,,,,,,,,,,,,, -128952,,,0.8012109398841858,0.7621845602989197,0.7287200093269348,1.0881730318069458,50000.0,0.6049000024795532,1.71925950050354,10000.0,57600.443110466,62191.53056240082,57600.443110466,4577.940878629684,6.324890613555908,0.0 -129000,2.5586414,4.0402493,,,,,,,,,,,,,, -129100,1.9289632,1.4713554,,,,,,,,,,,,,, -129200,1.8633627,1.8099564,,,,,,,,,,,,,, -129300,1.5730519,2.4779336,,,,,,,,,,,,,, -129400,1.6590716,2.5846605,,,,,,,,,,,,,, -129500,2.1711056,1.5943605,,,,,,,,,,,,,, -129600,1.9532802,1.6400964,,,,,,,,,,,,,, -129700,1.6782078,1.9066076,,,,,,,,,,,,,, -129800,2.2722688,3.9006758,,,,,,,,,,,,,, -129891,,,0.8115820288658142,0.7281399965286255,0.7297999858856201,1.0883889198303225,50000.0,0.6082000136375427,1.7195762395858765,10000.0,58020.502247571945,62646.54111742973,58020.502247571945,4612.796590805054,6.371863126754761,0.0 -129900,1.7464402,1.5776023,,,,,,,,,,,,,, -130000,2.0073988,3.7155628,,,,,,,,,,,,,, -130100,1.703107,3.0034375,,,,,,,,,,,,,, -130200,1.7285663,1.8579433,,,,,,,,,,,,,, -130300,1.984752,1.6572852,,,,,,,,,,,,,, -130400,1.8507739,1.6749668,,,,,,,,,,,,,, -130500,1.9368402,1.6013055,,,,,,,,,,,,,, -130600,1.6821606,2.7887785,,,,,,,,,,,,,, -130700,1.847589,1.6950791,,,,,,,,,,,,,, -130800,1.9083441,1.6197575,,,,,,,,,,,,,, -130828,,,0.8115429282188416,0.721081018447876,0.7346400022506714,1.0667911767959597,50000.0,0.6073000431060791,1.7005314826965332,10000.0,58440.81789493561,63103.182911872864,58440.81789493561,4649.020402431488,6.425073385238648,0.0 -130900,2.0889122,1.617796,,,,,,,,,,,,,, -131000,1.8403901,3.6082206,,,,,,,,,,,,,, -131100,1.8832645,2.184811,,,,,,,,,,,,,, -131200,2.0190954,1.6659385,,,,,,,,,,,,,, -131300,1.6961554,1.901165,,,,,,,,,,,,,, -131400,2.253873,1.5993931,,,,,,,,,,,,,, -131500,2.052289,1.7235469,,,,,,,,,,,,,, -131600,2.0223033,1.7370083,,,,,,,,,,,,,, -131700,1.8872503,1.6367499,,,,,,,,,,,,,, -131768,,,0.8031835556030273,0.7564061284065247,0.7330399751663208,1.073096513748169,50000.0,0.6079000234603882,1.708517074584961,10000.0,58860.99866771698,63559.92962670326,58860.99866771698,4685.490235805512,6.472413301467896,0.0 -131800,1.8159975,2.4372725,,,,,,,,,,,,,, -131900,1.8050468,2.5283823,,,,,,,,,,,,,, -132000,1.824232,2.0756388,,,,,,,,,,,,,, -132100,1.8009998,2.7569897,,,,,,,,,,,,,, -132200,1.9131075,1.7466856,,,,,,,,,,,,,, -132300,2.0353746,1.5798898,,,,,,,,,,,,,, -132400,1.8759369,2.0473402,,,,,,,,,,,,,, -132500,1.8978107,2.9935715,,,,,,,,,,,,,, -132600,2.0029218,1.6064694,,,,,,,,,,,,,, -132700,2.2465794,1.5694382,,,,,,,,,,,,,, -132706,,,0.8125976324081421,0.7448723912239075,0.7306399941444397,1.090269684791565,50000.0,0.6050000190734863,1.7248249053955078,10000.0,59281.00724768639,64016.23230576515,59281.00724768639,4721.690928220749,6.516724586486816,0.0 -132800,2.019635,1.6802331,,,,,,,,,,,,,, -132900,1.7534196,2.1454034,,,,,,,,,,,,,, -133000,1.9777592,1.5729249,,,,,,,,,,,,,, -133100,1.8894695,1.4977238,,,,,,,,,,,,,, -133200,2.0438228,1.5923935,,,,,,,,,,,,,, -133300,1.966433,2.093321,,,,,,,,,,,,,, -133400,1.6394118,2.444507,,,,,,,,,,,,,, -133500,2.0739548,1.6748729,,,,,,,,,,,,,, -133600,1.9168497,1.5422873,,,,,,,,,,,,,, -133645,,,0.8241796493530273,0.6636848449707031,0.7350800037384033,1.0521515607833862,50000.0,0.6089000105857849,1.680550456047058,10000.0,59701.293223142624,64471.68787431717,59701.293223142624,4756.766691684723,6.560953378677368,0.0 -133700,2.0948312,1.6643834,,,,,,,,,,,,,, -133800,2.0370665,1.566322,,,,,,,,,,,,,, -133900,2.0423768,1.9296038,,,,,,,,,,,,,, -134000,1.7437906,2.2616415,,,,,,,,,,,,,, -134100,2.0022023,1.6434478,,,,,,,,,,,,,, -134200,1.962133,1.584209,,,,,,,,,,,,,, -134300,1.9799808,3.8188996,,,,,,,,,,,,,, -134400,2.1047602,1.6999714,,,,,,,,,,,,,, -134500,1.8005912,2.0337973,,,,,,,,,,,,,, -134584,,,0.8122656345367432,0.7251321077346802,0.7372199892997742,1.0536880493164062,50000.0,0.6104000210762024,1.690237045288086,10000.0,60121.38954138756,64927.51717305184,60121.38954138756,4792.405487298965,6.606303691864014,0.0 -134600,1.9937967,1.720707,,,,,,,,,,,,,, -134700,1.8948352,3.3367774,,,,,,,,,,,,,, -134800,1.9818885,2.1036687,,,,,,,,,,,,,, -134900,1.7559958,2.356816,,,,,,,,,,,,,, -135000,1.9415357,1.6501503,,,,,,,,,,,,,, -135100,1.8410846,1.8229897,,,,,,,,,,,,,, -135200,2.3509398,3.8117821,,,,,,,,,,,,,, -135300,2.0442884,1.5171295,,,,,,,,,,,,,, -135400,2.0644944,1.5261933,,,,,,,,,,,,,, -135500,1.9988009,1.7874836,,,,,,,,,,,,,, -135517,,,0.8183202743530273,0.7013428807258606,0.7394199967384338,1.044283747673035,50000.0,0.6121000051498413,1.66946280002594,10000.0,60541.42193317413,65383.74241781235,60541.42193317413,4828.497610330582,6.658292770385742,0.0 -135600,1.9515331,2.6494327,,,,,,,,,,,,,, -135700,1.8415753,1.4155626,,,,,,,,,,,,,, -135800,2.072001,1.472981,,,,,,,,,,,,,, -135900,2.054231,2.8924975,,,,,,,,,,,,,, -136000,2.1922393,1.6474216,,,,,,,,,,,,,, -136100,2.096766,1.4931316,,,,,,,,,,,,,, -136200,2.0475585,1.5673385,,,,,,,,,,,,,, -136300,2.2347412,2.076416,,,,,,,,,,,,,, -136400,2.0098205,1.5403723,,,,,,,,,,,,,, -136454,,,0.8281640410423279,0.6682937741279602,0.7404800057411194,1.0448720455169678,50000.0,0.6099000573158264,1.6748216152191162,10000.0,60961.35953044891,65839.95121264458,60961.35953044891,4864.676825284958,6.702035903930664,0.0 -136500,2.2017279,3.9613,,,,,,,,,,,,,, -136600,2.0202978,1.4585776,,,,,,,,,,,,,, -136700,1.9849044,1.4950594,,,,,,,,,,,,,, -136800,2.1340427,1.5274663,,,,,,,,,,,,,, -136900,2.006186,3.0571074,,,,,,,,,,,,,, -137000,2.2682464,1.5829997,,,,,,,,,,,,,, -137100,1.8680985,1.9839928,,,,,,,,,,,,,, -137200,2.1271422,2.7069824,,,,,,,,,,,,,, -137300,2.1917436,1.6089032,,,,,,,,,,,,,, -137394,,,0.8191210627555847,0.6811794638633728,0.7382599711418152,1.0337536334991455,50000.0,0.6151000261306763,1.670432448387146,10000.0,61381.284125328064,66295.96452879906,61381.284125328064,4900.667014360428,6.751345872879028,0.0 -137400,2.0418012,1.4550754,,,,,,,,,,,,,, -137500,2.2845385,1.9369895,,,,,,,,,,,,,, -137600,2.00602,1.6522546,,,,,,,,,,,,,, -137700,2.1262026,1.620418,,,,,,,,,,,,,, -137800,2.009575,2.5236077,,,,,,,,,,,,,, -137900,2.1213758,1.5659658,,,,,,,,,,,,,, -138000,2.118962,1.6382475,,,,,,,,,,,,,, -138100,2.2006571,3.4582317,,,,,,,,,,,,,, -138200,1.9861462,1.6088794,,,,,,,,,,,,,, -138300,2.279694,1.5439588,,,,,,,,,,,,,, -138333,,,0.8222265243530273,0.6762940287590027,0.7418199777603149,1.029226779937744,50000.0,0.6124000549316406,1.660211443901062,10000.0,61801.64013171196,66751.41964411736,61801.64013171196,4935.666965246201,6.801737308502197,0.0 -138400,2.196822,1.4580268,,,,,,,,,,,,,, -138500,2.1290383,1.4688545,,,,,,,,,,,,,, -138600,2.2680175,1.5293982,,,,,,,,,,,,,, -138700,1.9262346,2.4182382,,,,,,,,,,,,,, -138800,1.9310657,2.3953757,,,,,,,,,,,,,, -138900,2.1844635,1.6972775,,,,,,,,,,,,,, -139000,2.160443,3.3401241,,,,,,,,,,,,,, -139100,2.1103427,3.2546086,,,,,,,,,,,,,, -139200,2.1392055,1.513669,,,,,,,,,,,,,, -139273,,,0.8302929401397705,0.6423133015632629,0.7434799671173096,1.0174846649169922,50000.0,0.6152000427246094,1.64912211894989,10000.0,62221.607963085175,67206.81400728226,62221.607963085175,4970.997862577438,6.847434520721436,0.0 -139300,1.9679742,1.4557614,,,,,,,,,,,,,, -139400,2.1215851,2.1283422,,,,,,,,,,,,,, -139500,2.0508327,1.4217507,,,,,,,,,,,,,, -139600,2.063121,1.4543116,,,,,,,,,,,,,, -139700,2.0700321,1.7847795,,,,,,,,,,,,,, -139800,1.9478164,1.4885329,,,,,,,,,,,,,, -139900,2.1379156,1.5481951,,,,,,,,,,,,,, -140000,2.2572074,1.5036436,,,,,,,,,,,,,, -140100,2.745467,3.8725092,,,,,,,,,,,,,, -140200,2.4316075,3.7881289,,,,,,,,,,,,,, -140211,,,0.8381249904632568,0.6069045662879944,0.7433399558067322,1.016687273979187,50000.0,0.6163000464439392,1.6549649238586426,10000.0,62641.6125395298,67662.89258980751,62641.6125395298,5006.979150533676,6.8912672996521,0.0 -140300,2.1395867,1.4430205,,,,,,,,,,,,,, -140400,2.0737624,1.4546019,,,,,,,,,,,,,, -140500,2.2129467,3.4898882,,,,,,,,,,,,,, -140600,1.900992,2.7913647,,,,,,,,,,,,,, -140700,2.0938535,2.986779,,,,,,,,,,,,,, -140800,2.0077538,1.7221233,,,,,,,,,,,,,, -140900,2.1765912,1.7502474,,,,,,,,,,,,,, -141000,1.9652741,1.8886614,,,,,,,,,,,,,, -141100,1.9647851,2.0509317,,,,,,,,,,,,,, -141147,,,0.82630854845047,0.6645239591598511,0.7433599829673767,1.016500473022461,50000.0,0.6189000010490417,1.642565369606018,10000.0,63061.63824224472,68115.81379771233,63061.63824224472,5039.779830694199,6.936065435409546,0.0 -141200,2.2232337,2.5875661,,,,,,,,,,,,,, -141300,2.1104772,1.4336661,,,,,,,,,,,,,, -141400,2.4523056,3.739386,,,,,,,,,,,,,, -141500,2.1747713,1.7784299,,,,,,,,,,,,,, -141600,2.2550123,1.6275712,,,,,,,,,,,,,, -141700,2.16464,2.084764,,,,,,,,,,,,,, -141800,2.2207956,1.5231018,,,,,,,,,,,,,, -141900,2.396953,1.6858493,,,,,,,,,,,,,, -142000,2.3626826,3.6082735,,,,,,,,,,,,,, -142086,,,0.8313671946525574,0.6447664499282837,0.749019980430603,1.0095096826553345,50000.0,0.6199000477790833,1.6374740600585938,10000.0,63481.839210510254,68570.085460186,63481.839210510254,5073.741352796555,6.995683908462524,0.0 -142100,2.029079,2.6395276,,,,,,,,,,,,,, -142200,2.152048,1.6354257,,,,,,,,,,,,,, -142300,2.2282534,2.8891966,,,,,,,,,,,,,, -142400,2.0104513,2.3395863,,,,,,,,,,,,,, -142500,2.0985382,1.3873432,,,,,,,,,,,,,, -142600,2.181049,3.5799165,,,,,,,,,,,,,, -142700,2.7850919,3.672769,,,,,,,,,,,,,, -142800,2.0949042,2.0741053,,,,,,,,,,,,,, -142900,1.9622089,1.8492448,,,,,,,,,,,,,, -143000,2.2498252,2.9513438,,,,,,,,,,,,,, -143023,,,0.8359569907188416,0.6201549172401428,0.748699963092804,0.999498724937439,50000.0,0.6215000152587891,1.634261131286621,10000.0,63902.15845608711,69026.07445836067,63902.15845608711,5109.310643911362,7.047288179397583,0.0 -143100,2.2400794,1.4536569,,,,,,,,,,,,,, -143200,2.1064136,2.9987836,,,,,,,,,,,,,, -143300,2.3333032,1.5058528,,,,,,,,,,,,,, -143400,2.2235925,1.3690805,,,,,,,,,,,,,, -143500,2.2507539,1.4354547,,,,,,,,,,,,,, -143600,2.1828823,1.5129554,,,,,,,,,,,,,, -143700,2.121747,1.3791649,,,,,,,,,,,,,, -143800,2.3695352,1.4959042,,,,,,,,,,,,,, -143900,2.532089,3.59555,,,,,,,,,,,,,, -143960,,,0.8331835865974426,0.6416350603103638,0.7472400069236755,1.003479242324829,50000.0,0.6253000497817993,1.6217032670974731,10000.0,64322.31871080399,69481.3456788063,64322.31871080399,5144.327417612076,7.09255576133728,0.0 -144000,2.1376057,1.4331653,,,,,,,,,,,,,, -144100,2.190525,1.3369801,,,,,,,,,,,,,, -144200,2.2603402,1.5933851,,,,,,,,,,,,,, -144300,2.2762635,3.5503724,,,,,,,,,,,,,, -144400,2.1859832,1.7063917,,,,,,,,,,,,,, -144500,2.224522,1.2797412,,,,,,,,,,,,,, -144600,2.1476443,1.6352739,,,,,,,,,,,,,, -144700,2.1478846,2.464102,,,,,,,,,,,,,, -144800,2.1233802,1.5856873,,,,,,,,,,,,,, -144900,2.0567703,1.3066499,,,,,,,,,,,,,, -144901,,,0.8355273008346558,0.6195220947265625,0.750059962272644,0.98756343126297,50000.0,0.629300057888031,1.608303427696228,10000.0,64742.95211935043,69936.43508267403,64742.95211935043,5178.6849110126495,7.1408984661102295,0.0 -145000,2.2359238,1.4308833,,,,,,,,,,,,,, -145100,2.8584409,3.770212,,,,,,,,,,,,,, -145200,2.1540265,2.593894,,,,,,,,,,,,,, -145300,1.9740704,1.6328695,,,,,,,,,,,,,, -145400,2.353877,1.448728,,,,,,,,,,,,,, -145500,2.3197587,1.480491,,,,,,,,,,,,,, -145600,2.3483896,1.6882794,,,,,,,,,,,,,, -145700,2.1656647,3.2378485,,,,,,,,,,,,,, -145800,2.082872,1.3666898,,,,,,,,,,,,,, -145837,,,0.8406445384025574,0.5987057685852051,0.752240002155304,0.9832723736763,50000.0,0.6294000148773193,1.6080838441848757,10000.0,65162.93309760094,70389.34637379646,65162.93309760094,5211.518479347229,7.18864107131958,0.0 -145900,2.1293738,2.0260384,,,,,,,,,,,,,, -146000,2.5239537,3.5658908,,,,,,,,,,,,,, -146100,2.123895,1.3520896,,,,,,,,,,,,,, -146200,2.0412745,1.640286,,,,,,,,,,,,,, -146300,2.312068,1.3933817,,,,,,,,,,,,,, -146400,2.2356274,1.4195253,,,,,,,,,,,,,, -146500,2.7984843,3.7290893,,,,,,,,,,,,,, -146600,2.209069,2.6950727,,,,,,,,,,,,,, -146700,2.3399794,1.4192315,,,,,,,,,,,,,, -146775,,,0.8557812571525574,0.5449709296226501,0.7529000043869019,0.9782390594482422,50000.0,0.6300000548362732,1.6063361167907717,10000.0,65583.1298816204,70845.42119407654,65583.1298816204,5247.289430618286,7.246530294418335,0.0 -146800,2.0663645,2.498685,,,,,,,,,,,,,, -146900,2.3045394,1.6790403,,,,,,,,,,,,,, -147000,2.1607852,1.4729114,,,,,,,,,,,,,, -147100,2.344065,1.466707,,,,,,,,,,,,,, -147200,2.4751196,1.3972282,,,,,,,,,,,,,, -147300,2.2696486,1.4646304,,,,,,,,,,,,,, -147400,1.9236495,2.5391345,,,,,,,,,,,,,, -147500,2.2156396,2.1266093,,,,,,,,,,,,,, -147600,2.3439274,1.359234,,,,,,,,,,,,,, -147700,2.4368763,2.5115986,,,,,,,,,,,,,, -147714,,,0.8352148532867432,0.6282013654708862,0.7502999901771545,0.9910974502563475,50000.0,0.6282000541687012,1.6120684146881104,10000.0,66003.42449641228,71301.50976467133,66003.42449641228,5282.982690811157,7.298294305801392,0.0 -147800,2.4026005,3.3046827,,,,,,,,,,,,,, -147900,2.2946649,1.6036769,,,,,,,,,,,,,, -148000,2.0934157,2.3299413,,,,,,,,,,,,,, -148100,2.5555856,3.6568391,,,,,,,,,,,,,, -148200,2.3387516,1.4517994,,,,,,,,,,,,,, -148300,2.38704,2.1665072,,,,,,,,,,,,,, -148400,2.334286,1.5429184,,,,,,,,,,,,,, -148500,2.086576,2.1810195,,,,,,,,,,,,,, -148600,2.3259141,3.0103655,,,,,,,,,,,,,, -148655,,,0.8431054353713989,0.6112679243087769,0.7550599575042725,0.9850051999092102,50000.0,0.629300057888031,1.6120909452438354,10000.0,66423.42095088959,71757.16052174568,66423.42095088959,5318.54127407074,7.344376087188721,0.0 -148700,2.0983377,2.4613743,,,,,,,,,,,,,, -148800,2.30638,1.4102246,,,,,,,,,,,,,, -148900,2.3573732,1.368067,,,,,,,,,,,,,, -149000,2.543767,1.3825102,,,,,,,,,,,,,, -149100,2.4616797,1.4640528,,,,,,,,,,,,,, -149200,2.286562,1.3483732,,,,,,,,,,,,,, -149300,2.4794638,3.0835133,,,,,,,,,,,,,, -149400,2.3636053,1.2890854,,,,,,,,,,,,,, -149500,2.3014297,1.3842398,,,,,,,,,,,,,, -149590,,,0.8514257669448853,0.5645542740821838,0.7563799619674683,0.9604097604751588,50000.0,0.631600022315979,1.5866049528121948,10000.0,66843.46200037003,72211.91704010963,66843.46200037003,5353.161323547363,7.391125440597534,0.0 -149600,2.3286371,1.5385801,,,,,,,,,,,,,, -149700,2.2936072,1.2982265,,,,,,,,,,,,,, -149800,2.6514032,1.3326877,,,,,,,,,,,,,, -149900,2.2510164,1.2964611,,,,,,,,,,,,,, -150000,2.3225853,1.3453419,,,,,,,,,,,,,, -150100,2.2781117,1.434454,,,,,,,,,,,,,, -150200,2.645429,1.5818794,,,,,,,,,,,,,, -150300,3.0456502,3.7145936,,,,,,,,,,,,,, -150400,2.4523125,3.0118303,,,,,,,,,,,,,, -150500,2.2135856,2.022706,,,,,,,,,,,,,, -150527,,,0.8474218845367432,0.5692348480224609,0.7574599981307983,0.9543402791023254,50000.0,0.6347000598907471,1.5730087757110596,10000.0,67263.53252577782,72669.17042684555,67263.53252577782,5390.246699333191,7.440268039703369,0.0 -150600,2.1306822,2.6937654,,,,,,,,,,,,,, -150700,2.246041,2.4482718,,,,,,,,,,,,,, -150800,2.3207839,1.4776993,,,,,,,,,,,,,, -150900,2.2666924,1.4793922,,,,,,,,,,,,,, -151000,2.1715248,1.7887821,,,,,,,,,,,,,, -151100,2.422959,1.3660846,,,,,,,,,,,,,, -151200,2.348812,1.3710625,,,,,,,,,,,,,, -151300,2.5326648,2.8604636,,,,,,,,,,,,,, -151400,2.3960335,1.2391895,,,,,,,,,,,,,, -151464,,,0.8489453196525574,0.5678457021713257,0.757099986076355,0.9635516405105592,50000.0,0.6391000151634216,1.577191710472107,10000.0,67683.55815124512,73125.09868621826,67683.55815124512,5426.055099248886,7.485986948013306,0.0 -151500,2.9542081,3.6838748,,,,,,,,,,,,,, -151600,2.40964,1.485361,,,,,,,,,,,,,, -151700,2.7939851,3.5826716,,,,,,,,,,,,,, -151800,2.4667232,3.0248733,,,,,,,,,,,,,, -151900,2.442098,1.6909958,,,,,,,,,,,,,, -152000,2.1775784,2.265437,,,,,,,,,,,,,, -152100,2.3159199,1.3863893,,,,,,,,,,,,,, -152200,2.610268,1.444644,,,,,,,,,,,,,, -152300,2.7905047,3.6083565,,,,,,,,,,,,,, -152400,2.401302,1.8478345,,,,,,,,,,,,,, -152401,,,0.8550195097923279,0.5494584441184998,0.7616399526596069,0.9430665373802184,50000.0,0.641800045967102,1.5591038465499878,10000.0,68104.11196637154,73579.6944000721,68104.11196637154,5459.991923809052,7.542108774185181,0.0 -152500,3.209481,3.577369,,,,,,,,,,,,,, -152600,2.416373,1.3125186,,,,,,,,,,,,,, -152700,2.35429,1.2957528,,,,,,,,,,,,,, -152800,2.3962157,1.3597752,,,,,,,,,,,,,, -152900,2.4115858,2.3048432,,,,,,,,,,,,,, -153000,2.3544905,1.2515448,,,,,,,,,,,,,, -153100,2.351129,2.01583,,,,,,,,,,,,,, -153200,2.1065776,1.9862695,,,,,,,,,,,,,, -153300,2.400961,2.8347154,,,,,,,,,,,,,, -153339,,,0.8586718440055847,0.5393452048301697,0.7623199820518494,0.9501993656158448,50000.0,0.641800045967102,1.5675500631332395,10000.0,68524.06725406647,74035.65158557892,68524.06725406647,5495.893121242523,7.593789339065552,0.0 -153400,2.461277,1.2927892,,,,,,,,,,,,,, -153500,2.6664002,1.4090844,,,,,,,,,,,,,, -153600,2.3479838,1.4421954,,,,,,,,,,,,,, -153700,2.3238099,1.3748567,,,,,,,,,,,,,, -153800,2.4749475,1.3525109,,,,,,,,,,,,,, -153900,2.4810843,1.3981838,,,,,,,,,,,,,, -154000,2.3792098,1.4345855,,,,,,,,,,,,,, -154100,2.4088762,2.9879253,,,,,,,,,,,,,, -154200,2.827785,3.3386054,,,,,,,,,,,,,, -154276,,,0.8527148365974426,0.5654762387275696,0.7623599767684937,0.9489037990570068,50000.0,0.6407000422477722,1.5654865503311155,10000.0,68944.03726816177,74491.09753751755,68944.03726816177,5531.262712240219,7.649880170822143,0.0 -154300,2.3161147,1.9987733,,,,,,,,,,,,,, -154400,2.3490841,2.3800013,,,,,,,,,,,,,, -154500,2.3713164,1.9515972,,,,,,,,,,,,,, -154600,2.1416433,2.5932977,,,,,,,,,,,,,, -154700,2.3143446,1.2389182,,,,,,,,,,,,,, -154800,2.3779273,1.583895,,,,,,,,,,,,,, -154900,3.0350945,3.665869,,,,,,,,,,,,,, -155000,2.567277,3.1039276,,,,,,,,,,,,,, -155100,2.4617684,2.907457,,,,,,,,,,,,,, -155200,2.4986007,2.1528971,,,,,,,,,,,,,, -155216,,,0.8570116758346558,0.52569180727005,0.7622999548912048,0.9294662475585938,50000.0,0.6420000195503235,1.5497945547103882,10000.0,69364.2235994339,74945.20179843903,69364.2235994339,5565.0824184417725,7.698629856109619,0.0 -155300,2.5733857,1.8096876,,,,,,,,,,,,,, -155400,2.6771305,2.1988475,,,,,,,,,,,,,, -155500,2.9535835,3.3840837,,,,,,,,,,,,,, -155600,2.4007044,1.4650986,,,,,,,,,,,,,, -155700,2.7322075,2.5315413,,,,,,,,,,,,,, -155800,3.029001,3.4418104,,,,,,,,,,,,,, -155900,2.5766208,1.2536912,,,,,,,,,,,,,, -156000,2.4292436,1.6890866,,,,,,,,,,,,,, -156100,2.4495595,2.1462522,,,,,,,,,,,,,, -156156,,,0.8631835579872131,0.5074052214622498,0.7647199630737305,0.928766131401062,50000.0,0.6428000330924988,1.5421534776687622,10000.0,69784.42750167847,75400.90425562859,69784.42750167847,5600.47674703598,7.753859281539917,0.0 -156200,2.2672505,2.1158137,,,,,,,,,,,,,, -156300,2.66164,1.4153304,,,,,,,,,,,,,, -156400,3.228455,3.621188,,,,,,,,,,,,,, -156500,2.6897638,1.6582855,,,,,,,,,,,,,, -156600,2.3281658,1.9173049,,,,,,,,,,,,,, -156700,2.6023605,1.3866432,,,,,,,,,,,,,, -156800,2.4278066,1.5590872,,,,,,,,,,,,,, -156900,2.752866,1.4518001,,,,,,,,,,,,,, -157000,2.6988745,1.3054982,,,,,,,,,,,,,, -157093,,,0.8556640148162842,0.5329411625862122,0.766539990901947,0.919713020324707,50000.0,0.6445000171661377,1.5387070178985596,10000.0,70204.14212560654,75855.6677236557,70204.14212560654,5635.178020000458,8.052583456039429,0.0 -157100,2.9336622,3.3797245,,,,,,,,,,,,,, -157200,2.924194,3.4710793,,,,,,,,,,,,,, -157300,2.39789,2.137699,,,,,,,,,,,,,, -157400,2.6370046,1.2594944,,,,,,,,,,,,,, -157500,2.5921047,2.1293175,,,,,,,,,,,,,, -157600,2.704379,1.3757039,,,,,,,,,,,,,, -157700,2.6882043,1.9752383,,,,,,,,,,,,,, -157800,2.7709165,2.4318032,,,,,,,,,,,,,, -157900,2.5072253,3.096051,,,,,,,,,,,,,, -158000,2.8190897,1.319178,,,,,,,,,,,,,, -158031,,,0.86146479845047,0.5181689858436584,0.7660399675369263,0.9252774715423584,50000.0,0.6491000056266785,1.5364491939544678,10000.0,70624.15113449097,76312.09645199776,70624.15113449097,5671.48996591568,8.11124873161316,0.0 -158100,2.922969,1.3682463,,,,,,,,,,,,,, -158200,2.8389962,2.9600375,,,,,,,,,,,,,, -158300,2.429337,1.1563679,,,,,,,,,,,,,, -158400,2.6847951,1.2885072,,,,,,,,,,,,,, -158500,2.6027005,1.3905584,,,,,,,,,,,,,, -158600,2.8270795,1.3177391,,,,,,,,,,,,,, -158700,2.6752005,3.0783844,,,,,,,,,,,,,, -158800,2.6461565,1.3210905,,,,,,,,,,,,,, -158900,2.7803793,1.2900009,,,,,,,,,,,,,, -158968,,,0.8666796684265137,0.4957504868507385,0.767799973487854,0.9111968278884888,50000.0,0.6464000344276428,1.5262484550476074,10000.0,71044.10197019577,76767.86173796654,71044.10197019577,5707.205604314804,8.161112070083618,0.0 -159000,2.624455,1.3639605,,,,,,,,,,,,,, -159100,2.8718731,2.4762878,,,,,,,,,,,,,, -159200,2.6197443,1.2880632,,,,,,,,,,,,,, -159300,2.6618166,1.3433963,,,,,,,,,,,,,, -159400,2.5611796,1.2604408,,,,,,,,,,,,,, -159500,2.904225,2.8291764,,,,,,,,,,,,,, -159600,2.7474248,1.6632397,,,,,,,,,,,,,, -159700,2.5791564,1.214537,,,,,,,,,,,,,, -159800,3.163841,1.4358761,,,,,,,,,,,,,, -159900,2.523152,2.0449421,,,,,,,,,,,,,, -159905,,,0.8684179782867432,0.4979733824729919,0.7691599726676941,0.9126954674720764,50000.0,0.6496000289916992,1.5200276374816897,10000.0,71464.28036999702,77222.11035394669,71464.28036999702,5741.168454885483,8.218778133392334,0.0 -160000,2.830812,1.5563207,,,,,,,,,,,,,, -160100,2.676576,1.3006494,,,,,,,,,,,,,, -160200,2.4946756,1.2491282,,,,,,,,,,,,,, -160300,2.5769234,1.6605873,,,,,,,,,,,,,, -160400,2.7027578,1.3147131,,,,,,,,,,,,,, -160500,2.5928955,1.569529,,,,,,,,,,,,,, -160600,2.730912,1.2719637,,,,,,,,,,,,,, -160700,2.8619459,3.0481985,,,,,,,,,,,,,, -160800,2.730593,1.3180176,,,,,,,,,,,,,, -160840,,,0.8680663704872131,0.4921901226043701,0.770039975643158,0.9016311764717102,50000.0,0.6538000106811523,1.5107396841049194,10000.0,71884.30971646309,77676.5434346199,71884.30971646309,5775.471606492996,8.270018339157104,0.0 -160900,2.6453242,1.1791614,,,,,,,,,,,,,, -161000,2.7483337,1.267106,,,,,,,,,,,,,, -161100,2.7443192,1.3419664,,,,,,,,,,,,,, -161200,2.6823046,1.5307992,,,,,,,,,,,,,, -161300,2.723852,2.9120677,,,,,,,,,,,,,, -161400,2.7277246,1.5241361,,,,,,,,,,,,,, -161500,2.6241984,2.764259,,,,,,,,,,,,,, -161600,2.7966914,1.2760338,,,,,,,,,,,,,, -161700,2.9554791,2.0655198,,,,,,,,,,,,,, -161775,,,0.870898425579071,0.4825400114059448,0.7693799734115601,0.9072303175926208,50000.0,0.6493000388145447,1.5251152515411377,10000.0,72304.306691885,78133.04519224167,72304.306691885,5811.875743627548,8.321635723114014,0.0 -161800,2.5611796,1.151902,,,,,,,,,,,,,, -161900,2.8792417,3.1260097,,,,,,,,,,,,,, -162000,2.7299154,2.0965223,,,,,,,,,,,,,, -162100,2.6158202,1.3360047,,,,,,,,,,,,,, -162200,2.9624262,2.9878225,,,,,,,,,,,,,, -162300,2.7186878,1.5538809,,,,,,,,,,,,,, -162400,2.9920037,1.387908,,,,,,,,,,,,,, -162500,2.7036705,1.3425716,,,,,,,,,,,,,, -162600,2.8837347,2.4665866,,,,,,,,,,,,,, -162700,2.8930907,1.381351,,,,,,,,,,,,,, -162714,,,0.8746874928474426,0.4815309941768646,0.7702800035476685,0.9123331904411316,50000.0,0.6517000198364258,1.525587797164917,10000.0,72724.41666007042,78588.06591463089,72724.41666007042,5846.686740875244,8.37223219871521,0.0 -162800,2.7155004,1.2826729,,,,,,,,,,,,,, -162900,2.910904,1.2819309,,,,,,,,,,,,,, -163000,2.6864576,2.2829988,,,,,,,,,,,,,, -163100,2.5557766,2.0899646,,,,,,,,,,,,,, -163200,3.0319746,3.1769884,,,,,,,,,,,,,, -163300,2.5608714,1.2210146,,,,,,,,,,,,,, -163400,2.6415005,1.447493,,,,,,,,,,,,,, -163500,2.7259111,1.5131937,,,,,,,,,,,,,, -163600,3.2479453,2.9860992,,,,,,,,,,,,,, -163650,,,0.8691992163658142,0.4835098087787628,0.7697199583053589,0.8991246819496155,50000.0,0.6526000499725342,1.5037193298339844,10000.0,73144.49038362503,79042.07584261894,73144.49038362503,5880.516179323196,8.430397033691406,0.0 -163700,2.6349642,2.0131133,,,,,,,,,,,,,, -163800,2.782913,1.691781,,,,,,,,,,,,,, -163900,3.1732168,3.0871801,,,,,,,,,,,,,, -164000,2.9088318,1.4911222,,,,,,,,,,,,,, -164100,3.4141753,3.2436721,,,,,,,,,,,,,, -164200,2.8641863,2.7604342,,,,,,,,,,,,,, -164300,2.734243,1.6624506,,,,,,,,,,,,,, -164400,2.697947,2.2733817,,,,,,,,,,,,,, -164500,2.926083,2.2942595,,,,,,,,,,,,,, -164588,,,0.87158203125,0.4755339622497558,0.7726399898529053,0.8931786417961121,50000.0,0.6550000309944153,1.5063070058822632,10000.0,73564.75876450539,79497.69409179688,73564.75876450539,5915.757030487061,8.490588188171387,0.0 -164600,2.6904874,1.7481267,,,,,,,,,,,,,, -164700,2.9933789,1.2883778,,,,,,,,,,,,,, -164800,3.650325,3.3805163,,,,,,,,,,,,,, -164900,2.7141902,1.2091327,,,,,,,,,,,,,, -165000,3.0752606,3.1704597,,,,,,,,,,,,,, -165100,2.839789,1.2489794,,,,,,,,,,,,,, -165200,2.7915854,1.2500298,,,,,,,,,,,,,, -165300,3.6067297,3.502821,,,,,,,,,,,,,, -165400,3.4912808,3.3799362,,,,,,,,,,,,,, -165500,2.6921568,2.2356067,,,,,,,,,,,,,, -165528,,,0.8754296898841858,0.4589492976665497,0.7731800079345703,0.8862924575805664,50000.0,0.6545000076293945,1.494189739227295,10000.0,73985.09216928482,79952.10796093941,73985.09216928482,5949.728416919708,8.550045013427734,0.0 -165600,2.639269,1.1483934,,,,,,,,,,,,,, -165700,2.9418066,1.1619959,,,,,,,,,,,,,, -165800,2.8947866,1.1994027,,,,,,,,,,,,,, -165900,3.2990553,3.402536,,,,,,,,,,,,,, -166000,3.2220147,3.2549598,,,,,,,,,,,,,, -166100,3.3323276,3.4157329,,,,,,,,,,,,,, -166200,2.8626597,1.8858323,,,,,,,,,,,,,, -166300,3.0693786,3.030272,,,,,,,,,,,,,, -166400,2.714245,1.8558799,,,,,,,,,,,,,, -166466,,,0.8744921684265137,0.477140724658966,0.7742799520492554,0.8912918567657471,50000.0,0.6562000513076782,1.5012069940567017,10000.0,74405.08458447456,80405.63422679901,74405.08458447456,5983.164102315903,8.599197626113892,0.0 -166500,2.7839625,1.1790812,,,,,,,,,,,,,, -166600,2.7677827,1.1718663,,,,,,,,,,,,,, -166700,2.683881,1.6399672,,,,,,,,,,,,,, -166800,2.8914425,1.1568525,,,,,,,,,,,,,, -166900,2.961381,1.4048424,,,,,,,,,,,,,, -167000,2.9856327,1.254931,,,,,,,,,,,,,, -167100,2.691975,1.296561,,,,,,,,,,,,,, -167200,3.1297824,1.2265837,,,,,,,,,,,,,, -167300,2.8463173,1.303926,,,,,,,,,,,,,, -167400,3.1849015,3.2549596,,,,,,,,,,,,,, -167402,,,0.8738085627555847,0.4666889607906341,0.7739599943161011,0.8845837116241455,50000.0,0.6580000519752502,1.495749831199646,10000.0,74825.04020619392,80861.90974783897,74825.04020619392,6019.371679782867,8.662525653839111,0.0 -167500,2.8092098,1.1383522,,,,,,,,,,,,,, -167600,2.917011,1.0924096,,,,,,,,,,,,,, -167700,2.995393,2.32861,,,,,,,,,,,,,, -167800,3.050958,3.0851405,,,,,,,,,,,,,, -167900,2.9094543,1.1847903,,,,,,,,,,,,,, -168000,2.9709039,1.8977277,,,,,,,,,,,,,, -168100,3.0316339,1.1433356,,,,,,,,,,,,,, -168200,2.982878,1.2342594,,,,,,,,,,,,,, -168300,2.7642097,1.9188039,,,,,,,,,,,,,, -168340,,,0.8773242235183716,0.4531766176223755,0.7756399512290955,0.8850293755531311,50000.0,0.6571000218391418,1.4930970668792725,10000.0,75245.46972846985,81317.98803067207,75245.46972846985,6054.917459487915,8.715810537338257,0.0 -168400,2.6531582,1.4773738,,,,,,,,,,,,,, -168500,3.0718129,2.3683994,,,,,,,,,,,,,, -168600,2.802623,1.1213942,,,,,,,,,,,,,, -168700,3.3582902,3.250101,,,,,,,,,,,,,, -168800,2.7853563,1.4526653,,,,,,,,,,,,,, -168900,2.6699162,1.7172859,,,,,,,,,,,,,, -169000,3.042601,1.206547,,,,,,,,,,,,,, -169100,2.958594,1.327975,,,,,,,,,,,,,, -169200,2.8143215,1.0873377,,,,,,,,,,,,,, -169278,,,0.8797265291213989,0.4421044588088989,0.7756999731063843,0.8826619386672974,50000.0,0.6571000218391418,1.4991886615753174,10000.0,75665.5037753582,81772.30905842781,75665.5037753582,6089.09916472435,8.771555423736572,0.0 -169300,3.0676181,1.1325889,,,,,,,,,,,,,, -169400,3.0097377,1.2013166,,,,,,,,,,,,,, -169500,3.000807,1.1085427,,,,,,,,,,,,,, -169600,3.4652946,3.251634,,,,,,,,,,,,,, -169700,2.6969314,2.0033405,,,,,,,,,,,,,, -169800,3.644626,3.4067066,,,,,,,,,,,,,, -169900,3.0323222,1.5828588,,,,,,,,,,,,,, -170000,3.2252595,1.1139839,,,,,,,,,,,,,, -170100,3.0525188,2.3254573,,,,,,,,,,,,,, -170200,3.138103,1.2178869,,,,,,,,,,,,,, -170213,,,0.8807030916213989,0.4466629028320312,0.7764399647712708,0.8772845268249512,50000.0,0.6593000292778015,1.4936374425888062,10000.0,76085.7650001049,82228.5535402298,76085.7650001049,6124.984532356262,8.821457386016846,0.0 -170300,3.55297,3.259973,,,,,,,,,,,,,, -170400,3.2599247,3.200708,,,,,,,,,,,,,, -170500,3.1329117,1.21905,,,,,,,,,,,,,, -170600,3.6457715,3.2840517,,,,,,,,,,,,,, -170700,2.900744,1.1771172,,,,,,,,,,,,,, -170800,3.2957149,2.844229,,,,,,,,,,,,,, -170900,3.315576,2.395364,,,,,,,,,,,,,, -171000,3.43392,3.3076246,,,,,,,,,,,,,, -171100,3.053854,1.162047,,,,,,,,,,,,,, -171152,,,0.8788085579872131,0.4484823048114776,0.7775599956512451,0.8697133660316467,50000.0,0.6577000021934509,1.484135627746582,10000.0,76505.8050429821,82682.57957959175,76505.8050429821,6158.859827756882,8.883781433105469,0.0 -171200,3.340075,1.1943274,,,,,,,,,,,,,, -171300,3.0071254,1.2198604,,,,,,,,,,,,,, -171400,2.7482984,1.7267478,,,,,,,,,,,,,, -171500,2.8503883,1.1715927,,,,,,,,,,,,,, -171600,2.8688953,1.179734,,,,,,,,,,,,,, -171700,3.2252712,2.5391626,,,,,,,,,,,,,, -171800,3.3029718,1.1685688,,,,,,,,,,,,,, -171900,2.9145353,1.4677271,,,,,,,,,,,,,, -172000,2.898847,1.1165328,,,,,,,,,,,,,, -172089,,,0.8827148079872131,0.4331200420856476,0.7792999744415283,0.8637553453445435,50000.0,0.6615000367164612,1.4721359014511108,10000.0,76925.95561552048,83136.61755609512,76925.95561552048,6192.639292001724,8.942861557006836,0.0 -172100,3.352513,2.3369622,,,,,,,,,,,,,, -172200,2.813332,1.3802639,,,,,,,,,,,,,, -172300,2.921613,1.1329353,,,,,,,,,,,,,, -172400,3.4654973,1.1946732,,,,,,,,,,,,,, -172500,3.289209,3.057352,,,,,,,,,,,,,, -172600,3.143292,1.1712314,,,,,,,,,,,,,, -172700,3.1667483,1.2460486,,,,,,,,,,,,,, -172800,2.7582278,1.7554755,,,,,,,,,,,,,, -172900,2.9895246,1.1693599,,,,,,,,,,,,,, -173000,3.043132,1.1853211,,,,,,,,,,,,,, -173027,,,0.8812695145606995,0.4429061114788055,0.7791399955749512,0.8685945868492126,50000.0,0.6621000170707703,1.4893403053283691,10000.0,77345.90822839737,83592.15035367012,77345.90822839737,6228.115994215012,8.996717929840088,0.0 -173100,2.8463147,2.046536,,,,,,,,,,,,,, -173200,2.8454497,1.4097683,,,,,,,,,,,,,, -173300,2.84083,2.3115664,,,,,,,,,,,,,, -173400,2.980031,1.0615194,,,,,,,,,,,,,, -173500,2.8309846,1.7347624,,,,,,,,,,,,,, -173600,2.9238625,1.1313064,,,,,,,,,,,,,, -173700,3.1249578,2.243561,,,,,,,,,,,,,, -173800,3.7851734,3.3782868,,,,,,,,,,,,,, -173900,3.257924,1.1257595,,,,,,,,,,,,,, -173963,,,0.8810155987739563,0.4393672943115234,0.7788999676704407,0.8695529699325562,50000.0,0.6599000096321106,1.4833369255065918,10000.0,77765.81756949425,84047.55881023407,77765.81756949425,6263.508965969086,9.05425500869751,0.0 -174000,2.7403085,1.5272613,,,,,,,,,,,,,, -174100,2.908966,1.1704946,,,,,,,,,,,,,, -174200,3.0569088,1.5130751,,,,,,,,,,,,,, -174300,3.0875268,1.1923265,,,,,,,,,,,,,, -174400,3.0724993,2.5725088,,,,,,,,,,,,,, -174500,2.800641,1.1465696,,,,,,,,,,,,,, -174600,2.964182,2.5554388,,,,,,,,,,,,,, -174700,3.093918,1.1089635,,,,,,,,,,,,,, -174800,2.9014254,1.0719864,,,,,,,,,,,,,, -174900,2.8898625,1.3018175,,,,,,,,,,,,,, -174901,,,0.8854491710662842,0.4333232641220093,0.7795999646186829,0.8688917756080627,50000.0,0.6640000343322754,1.4762301445007324,10000.0,78186.09090733528,84502.53801631927,78186.09090733528,6298.112339735031,9.10762882232666,0.0 -175000,3.8508916,3.2776709,,,,,,,,,,,,,, -175100,2.7753286,1.3258398,,,,,,,,,,,,,, -175200,3.000411,1.2064356,,,,,,,,,,,,,, -175300,3.0486577,1.2435267,,,,,,,,,,,,,, -175400,2.8531587,2.329075,,,,,,,,,,,,,, -175500,3.0267417,1.1390384,,,,,,,,,,,,,, -175600,3.0434737,1.8876088,,,,,,,,,,,,,, -175700,3.1604285,1.0428956,,,,,,,,,,,,,, -175800,3.0207415,1.2249687,,,,,,,,,,,,,, -175842,,,0.88539057970047,0.4322720766067505,0.7794599533081055,0.866097629070282,50000.0,0.6651000380516052,1.4696331024169922,10000.0,78606.16627883911,84956.16199755669,78606.16627883911,6331.55254650116,9.166329622268677,0.0 -175900,2.8164773,2.1948798,,,,,,,,,,,,,, -176000,3.0420973,2.2031734,,,,,,,,,,,,,, -176100,3.0245817,1.3441017,,,,,,,,,,,,,, -176200,3.094906,1.1946404,,,,,,,,,,,,,, -176300,3.0359876,1.2967741,,,,,,,,,,,,,, -176400,3.2070005,1.1521883,,,,,,,,,,,,,, -176500,3.464896,2.4043753,,,,,,,,,,,,,, -176600,3.8976524,3.3396826,,,,,,,,,,,,,, -176700,3.179032,2.1546109,,,,,,,,,,,,,, -176780,,,0.8858593702316284,0.424483984708786,0.780019998550415,0.8592524528503418,50000.0,0.6630000472068787,1.4692275524139404,10000.0,79026.12602353096,85411.62379169464,79026.12602353096,6366.944556713104,9.227767944335938,0.0 -176800,2.986386,2.1673765,,,,,,,,,,,,,, -176900,3.6197069,1.1014209,,,,,,,,,,,,,, -177000,3.1732774,1.0994588,,,,,,,,,,,,,, -177100,2.914083,1.1224079,,,,,,,,,,,,,, -177200,3.205344,2.0009184,,,,,,,,,,,,,, -177300,3.1301122,1.1786803,,,,,,,,,,,,,, -177400,3.3240104,3.0020041,,,,,,,,,,,,,, -177500,3.6657076,3.1890602,,,,,,,,,,,,,, -177600,2.9325,1.0842446,,,,,,,,,,,,,, -177700,3.1051123,1.0596104,,,,,,,,,,,,,, -177719,,,0.8868359327316284,0.4214153289794922,0.7813000082969666,0.8596096634864807,50000.0,0.6668000221252441,1.4655801057815552,10000.0,79446.03746342659,85866.38427376747,79446.03746342659,6401.689318180084,9.282610654830933,0.0 -177800,4.205135,3.3286748,,,,,,,,,,,,,, -177900,2.7690294,2.0878334,,,,,,,,,,,,,, -178000,3.1020925,1.6817886,,,,,,,,,,,,,, -178100,3.3733206,3.003374,,,,,,,,,,,,,, -178200,2.9752104,1.2254071,,,,,,,,,,,,,, -178300,4.2292147,3.2210653,,,,,,,,,,,,,, -178400,2.8257155,1.4393133,,,,,,,,,,,,,, -178500,3.3493402,3.0686412,,,,,,,,,,,,,, -178600,3.0712934,1.0923015,,,,,,,,,,,,,, -178661,,,0.88685542345047,0.419357031583786,0.7822799682617188,0.8534312844276428,50000.0,0.664900004863739,1.4604169130325315,10000.0,79866.01587152481,86320.78079080582,79866.01587152481,6435.998902320862,9.340909719467165,0.0 -178700,3.0713406,1.1876053,,,,,,,,,,,,,, -178800,3.3673427,1.2615368,,,,,,,,,,,,,, -178900,3.2595615,1.4979907,,,,,,,,,,,,,, -179000,3.3951824,2.8810236,,,,,,,,,,,,,, -179100,3.0800498,2.440066,,,,,,,,,,,,,, -179200,3.1813254,1.8014431,,,,,,,,,,,,,, -179300,3.0256371,2.642499,,,,,,,,,,,,,, -179400,4.089035,3.2039404,,,,,,,,,,,,,, -179500,3.0440724,1.371665,,,,,,,,,,,,,, -179599,,,0.8860741853713989,0.4188991189002991,0.781059980392456,0.8571061491966248,50000.0,0.6645000576972961,1.462255358695984,10000.0,80285.93200969696,86775.10715174675,80285.93200969696,6470.306844234467,9.394191026687622,0.0 -179600,3.631121,3.031789,,,,,,,,,,,,,, -179700,3.0504558,2.688103,,,,,,,,,,,,,, -179800,3.7501104,3.225416,,,,,,,,,,,,,, -179900,3.1461623,1.118019,,,,,,,,,,,,,, -180000,3.5017254,3.152052,,,,,,,,,,,,,, -180100,3.122821,1.1539751,,,,,,,,,,,,,, -180200,3.034701,2.1043997,,,,,,,,,,,,,, -180300,2.9576604,1.0146317,,,,,,,,,,,,,, -180400,3.1174653,2.5710225,,,,,,,,,,,,,, -180500,3.0446703,2.4173331,,,,,,,,,,,,,, -180537,,,0.8853710889816284,0.423235535621643,0.7822999954223633,0.856580913066864,50000.0,0.6633000373840332,1.465488314628601,10000.0,80705.8322134018,87231.2764441967,80705.8322134018,6506.475147724152,9.446362257003784,0.0 -180600,2.7750378,1.5311878,,,,,,,,,,,,,, -180700,3.4508984,1.1438773,,,,,,,,,,,,,, -180800,3.0530202,1.1894182,,,,,,,,,,,,,, -180900,3.3499427,2.296261,,,,,,,,,,,,,, -181000,3.1692488,1.2638581,,,,,,,,,,,,,, -181100,3.8615925,3.1300914,,,,,,,,,,,,,, -181200,3.1012337,1.0959625,,,,,,,,,,,,,, -181300,3.3166857,2.341379,,,,,,,,,,,,,, -181400,4.2528257,3.0648844,,,,,,,,,,,,,, -181476,,,0.8878124952316284,0.4163694381713867,0.7823799848556519,0.8540942072868347,50000.0,0.664900004863739,1.4613149166107178,10000.0,81125.91492772102,87685.75764489174,81125.91492772102,6540.772976398468,9.49793577194214,0.0 -181500,3.3482404,1.341403,,,,,,,,,,,,,, -181600,4.037507,3.2198863,,,,,,,,,,,,,, -181700,3.1963177,1.2930341,,,,,,,,,,,,,, -181800,3.0738757,1.1374445,,,,,,,,,,,,,, -181900,2.950142,1.7911985,,,,,,,,,,,,,, -182000,2.9334872,2.298863,,,,,,,,,,,,,, -182100,2.9229193,1.0420374,,,,,,,,,,,,,, -182200,2.972814,1.6351235,,,,,,,,,,,,,, -182300,3.0023274,1.786974,,,,,,,,,,,,,, -182400,3.1018326,1.639235,,,,,,,,,,,,,, -182415,,,0.8870312571525574,0.4193125069141388,0.7821399569511414,0.8543209433555603,50000.0,0.6657000184059143,1.4587408304214478,10000.0,81545.87590956688,88141.41509056091,81545.87590956688,6576.370446205139,9.548835754394531,0.0 -182500,3.0187163,2.4755566,,,,,,,,,,,,,, -182600,2.9892647,2.461316,,,,,,,,,,,,,, -182700,3.2774804,1.0313131,,,,,,,,,,,,,, -182800,3.650651,3.1220312,,,,,,,,,,,,,, -182900,3.0891275,1.0873669,,,,,,,,,,,,,, -183000,3.063498,1.069315,,,,,,,,,,,,,, -183100,2.8723738,1.1953923,,,,,,,,,,,,,, -183200,3.214998,1.081306,,,,,,,,,,,,,, -183300,2.752398,1.4780865,,,,,,,,,,,,,, -183355,,,0.8894140720367432,0.4123246669769287,0.7827000021934509,0.8526468276977539,50000.0,0.6669000387191772,1.458864688873291,10000.0,81965.95393848419,88595.35318779945,81965.95393848419,6610.12516784668,9.60527515411377,0.0 -183400,3.4310186,1.2005565,,,,,,,,,,,,,, -183500,3.398281,2.6353254,,,,,,,,,,,,,, -183600,3.0230837,1.1553549,,,,,,,,,,,,,, -183700,2.9952013,1.5004227,,,,,,,,,,,,,, -183800,3.0228012,1.094628,,,,,,,,,,,,,, -183900,2.9467623,1.3283198,,,,,,,,,,,,,, -184000,2.7302938,1.1962602,,,,,,,,,,,,,, -184100,3.457515,1.2110578,,,,,,,,,,,,,, -184200,3.1475399,1.1028161,,,,,,,,,,,,,, -184294,,,0.8896093368530273,0.412265419960022,0.7827000021934509,0.8530824184417725,50000.0,0.6668000221252441,1.4599723815917969,10000.0,82386.22452759743,89052.31648516655,82386.22452759743,6646.70134973526,9.67266058921814,0.0 -184300,2.9316502,1.0211053,,,,,,,,,,,,,, -184400,3.1024258,1.4280344,,,,,,,,,,,,,, -184500,3.071774,2.2197385,,,,,,,,,,,,,, -184600,3.1891234,0.9903406,,,,,,,,,,,,,, -184700,3.2554333,1.4368362,,,,,,,,,,,,,, -184800,3.0202894,1.0794667,,,,,,,,,,,,,, -184900,3.4097965,1.2860554,,,,,,,,,,,,,, -185000,3.1254299,1.3384051,,,,,,,,,,,,,, -185100,3.0255806,1.0778022,,,,,,,,,,,,,, -185200,3.2147675,1.1675498,,,,,,,,,,,,,, -185231,,,0.8886132836341858,0.4169757664203644,0.7825799584388733,0.8518304228782654,50000.0,0.6669000387191772,1.4578466415405271,10000.0,82806.58885216713,89508.3066072464,82806.58885216713,6682.224026918411,9.72684407234192,0.0 -185300,3.3715947,2.5161612,,,,,,,,,,,,,, -185400,3.9973464,3.1905444,,,,,,,,,,,,,, -185500,3.5532355,1.0591751,,,,,,,,,,,,,, -185600,4.252786,3.2125263,,,,,,,,,,,,,, -185700,3.1432798,2.501294,,,,,,,,,,,,,, -185800,3.5175557,1.1720008,,,,,,,,,,,,,, -185900,3.2322001,1.0962508,,,,,,,,,,,,,, -186000,2.8203037,2.1399574,,,,,,,,,,,,,, -186100,3.0184507,1.1165879,,,,,,,,,,,,,, -186169,,,0.8907030820846558,0.4041774272918701,0.7828999757766724,0.8518894910812378,50000.0,0.6664000153541565,1.458232283592224,10000.0,83226.80952954292,89962.55336284637,83226.80952954292,6716.146427869797,9.780983686447144,0.0 -186200,3.154284,1.2065618,,,,,,,,,,,,,, -186300,2.9430952,1.5839026,,,,,,,,,,,,,, -186400,4.457152,1.1792524,,,,,,,,,,,,,, -186500,3.2229207,1.1937741,,,,,,,,,,,,,, -186600,3.2970545,2.1236897,,,,,,,,,,,,,, -186700,3.1113997,1.1071315,,,,,,,,,,,,,, -186800,3.9673934,3.2734995,,,,,,,,,,,,,, -186900,3.0127568,1.243994,,,,,,,,,,,,,, -187000,2.9990904,1.1210954,,,,,,,,,,,,,, -187100,3.2173553,2.7863748,,,,,,,,,,,,,, -187104,,,0.8879296779632568,0.4138663411140442,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,83647.07682418823,90418.60726046562,83647.07682418823,6751.817445039749,9.848082780838013,0.0 -187200,3.1599095,1.0132663,,,,,,,,,,,,,, -187300,2.9603949,1.1251347,,,,,,,,,,,,,, -187400,3.3683393,1.2750937,,,,,,,,,,,,,, -187500,2.7542136,1.4824337,,,,,,,,,,,,,, -187600,3.0851483,1.1897106,,,,,,,,,,,,,, -187700,3.0394275,1.1633295,,,,,,,,,,,,,, -187800,2.9603112,2.5371873,,,,,,,,,,,,,, -187900,3.2904785,1.3359399,,,,,,,,,,,,,, -188000,2.9135368,2.2973669,,,,,,,,,,,,,, -188039,,,0.88783198595047,0.4124012589454651,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,84067.24569582939,90873.28254199028,84067.24569582939,6786.212738990784,9.911024570465088,0.0 -188100,3.06233,1.3457325,,,,,,,,,,,,,, -188200,3.0617914,2.6926875,,,,,,,,,,,,,, -188300,2.957564,1.2380321,,,,,,,,,,,,,, -188400,3.2271497,1.1437441,,,,,,,,,,,,,, -188500,3.5275078,1.2632936,,,,,,,,,,,,,, -188600,3.399872,1.7620384,,,,,,,,,,,,,, -188700,2.8290985,1.1353054,,,,,,,,,,,,,, -188800,2.9419188,1.1324939,,,,,,,,,,,,,, -188900,2.8855855,1.3749187,,,,,,,,,,,,,, -188976,,,0.8874218463897705,0.4165248870849609,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,84487.2626209259,91329.5935049057,84487.2626209259,6822.406176567078,9.963230848312378,0.0 -189000,3.2284088,1.1339048,,,,,,,,,,,,,, -189100,2.8828628,1.5546328,,,,,,,,,,,,,, -189200,3.229025,1.0733827,,,,,,,,,,,,,, -189300,3.7379215,1.1659448,,,,,,,,,,,,,, -189400,2.9845285,1.8959632,,,,,,,,,,,,,, -189500,3.1491652,2.8097095,,,,,,,,,,,,,, -189600,3.2598834,2.6097674,,,,,,,,,,,,,, -189700,2.9970555,1.1108383,,,,,,,,,,,,,, -189800,3.0711398,2.348638,,,,,,,,,,,,,, -189900,3.0743647,1.0669873,,,,,,,,,,,,,, -189911,,,0.8871288895606995,0.4155504107475281,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,84907.21473288536,91784.68587970734,84907.21473288536,6857.440523386002,10.020652532577516,0.0 -190000,3.9768934,3.093774,,,,,,,,,,,,,, -190100,3.2062962,1.1734512,,,,,,,,,,,,,, -190200,3.319759,3.014738,,,,,,,,,,,,,, -190300,3.030507,1.5452707,,,,,,,,,,,,,, -190400,3.1372395,1.1944561,,,,,,,,,,,,,, -190500,3.1643312,1.6051909,,,,,,,,,,,,,, -190600,3.1663506,1.4129375,,,,,,,,,,,,,, -190700,3.1102405,1.09626,,,,,,,,,,,,,, -190800,3.045104,1.1099494,,,,,,,,,,,,,, -190846,,,0.8899218440055847,0.4114306569099426,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,85327.35544729233,92240.39691138268,85327.35544729233,6892.907320976257,10.07563328742981,0.0 -190900,3.3694866,2.6902466,,,,,,,,,,,,,, -191000,3.1888895,1.0533605,,,,,,,,,,,,,, -191100,3.0263686,1.6190891,,,,,,,,,,,,,, -191200,3.0678384,1.4229126,,,,,,,,,,,,,, -191300,3.0195696,2.4831939,,,,,,,,,,,,,, -191400,2.8865778,1.5848298,,,,,,,,,,,,,, -191500,3.250945,1.1651478,,,,,,,,,,,,,, -191600,3.0975823,1.3699552,,,,,,,,,,,,,, -191700,3.1471531,1.1964036,,,,,,,,,,,,,, -191785,,,0.8871288895606995,0.4186196029186249,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,85747.62012791634,92694.81047177316,85747.62012791634,6926.948734521866,10.13442349433899,0.0 -191800,3.3310142,1.1893398,,,,,,,,,,,,,, -191900,4.261106,3.2581294,,,,,,,,,,,,,, -192000,2.9197419,1.578495,,,,,,,,,,,,,, -192100,2.9109225,1.5196549,,,,,,,,,,,,,, -192200,3.2367463,1.1139323,,,,,,,,,,,,,, -192300,3.177889,1.1167948,,,,,,,,,,,,,, -192400,3.0916736,1.2084643,,,,,,,,,,,,,, -192500,3.3906963,3.081681,,,,,,,,,,,,,, -192600,3.721581,3.253096,,,,,,,,,,,,,, -192700,3.0160627,1.1247978,,,,,,,,,,,,,, -192721,,,0.8883984088897705,0.4127451479434967,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,86167.71252465248,93151.8325972557,86167.71252465248,6963.765806436539,10.197827577590942,0.0 -192800,2.8367362,1.4269797,,,,,,,,,,,,,, -192900,2.9740498,1.8380249,,,,,,,,,,,,,, -193000,3.045454,1.6195847,,,,,,,,,,,,,, -193100,3.7261024,3.1483169,,,,,,,,,,,,,, -193200,3.3016822,1.1366904,,,,,,,,,,,,,, -193300,3.4630919,1.5337797,,,,,,,,,,,,,, -193400,2.8176188,1.9833856,,,,,,,,,,,,,, -193500,3.1325996,2.5245295,,,,,,,,,,,,,, -193600,3.0926557,1.7653582,,,,,,,,,,,,,, -193659,,,0.8858984112739563,0.4185971319675445,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,86587.71695780754,93606.5622651577,86587.71695780754,6998.387335062027,10.25277328491211,0.0 -193700,3.0953674,2.0798345,,,,,,,,,,,,,, -193800,3.21438,1.1538453,,,,,,,,,,,,,, -193900,2.825119,1.1724617,,,,,,,,,,,,,, -194000,3.0093684,2.3143024,,,,,,,,,,,,,, -194100,2.9948583,1.0936873,,,,,,,,,,,,,, -194200,3.247604,1.1705892,,,,,,,,,,,,,, -194300,4.0188694,2.6627855,,,,,,,,,,,,,, -194400,2.8743432,2.2772884,,,,,,,,,,,,,, -194500,3.447074,1.0907753,,,,,,,,,,,,,, -194596,,,0.8871874809265137,0.416054368019104,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,87007.7027232647,94061.83039855956,87007.7027232647,7033.558439016342,10.31568694114685,0.0 -194600,3.1863422,1.0885335,,,,,,,,,,,,,, -194700,3.1550074,1.0992882,,,,,,,,,,,,,, -194800,3.0070076,1.0401163,,,,,,,,,,,,,, -194900,4.2656345,3.1354976,,,,,,,,,,,,,, -195000,2.8860176,1.4555403,,,,,,,,,,,,,, -195100,3.1008816,1.3029802,,,,,,,,,,,,,, -195200,2.9587889,1.1936747,,,,,,,,,,,,,, -195300,3.7951045,3.2593994,,,,,,,,,,,,,, -195400,3.5132716,3.0227299,,,,,,,,,,,,,, -195500,3.0449183,1.1654277,,,,,,,,,,,,,, -195533,,,0.8887109160423279,0.4140407741069793,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,87427.7664604187,94517.9357123375,87427.7664604187,7069.494187831879,10.3733651638031,0.0 -195600,3.195456,2.710259,,,,,,,,,,,,,, -195700,2.956019,1.0503783,,,,,,,,,,,,,, -195800,3.113743,1.0999926,,,,,,,,,,,,,, -195900,3.109979,1.0288968,,,,,,,,,,,,,, -196000,3.0810106,2.3928075,,,,,,,,,,,,,, -196100,3.1753254,1.143551,,,,,,,,,,,,,, -196200,3.0786302,1.34444,,,,,,,,,,,,,, -196300,3.4797661,1.1598245,,,,,,,,,,,,,, -196400,3.0695288,1.1335195,,,,,,,,,,,,,, -196472,,,0.8884179592132568,0.4153717756271362,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,87847.94662070274,94973.78384137154,87847.94662070274,7105.049699783325,10.437105655670166,0.0 -196500,3.3064263,1.1274893,,,,,,,,,,,,,, -196600,3.1653993,1.2093133,,,,,,,,,,,,,, -196700,2.978483,1.3938787,,,,,,,,,,,,,, -196800,3.05367,2.154399,,,,,,,,,,,,,, -196900,2.8761137,2.2087321,,,,,,,,,,,,,, -197000,3.6377406,2.7333155,,,,,,,,,,,,,, -197100,3.3775072,1.5181676,,,,,,,,,,,,,, -197200,2.8748806,1.828528,,,,,,,,,,,,,, -197300,3.1182053,2.3429666,,,,,,,,,,,,,, -197400,3.1988387,1.0979427,,,,,,,,,,,,,, -197410,,,0.8886327743530273,0.4180335402488708,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,88267.96205282211,95427.28940010072,88267.96205282211,7138.435284852982,10.492969989776611,0.0 -197500,3.4929886,3.1887865,,,,,,,,,,,,,, -197600,3.5188313,2.9138987,,,,,,,,,,,,,, -197700,3.1575892,1.2518668,,,,,,,,,,,,,, -197800,2.8107955,0.9919299,,,,,,,,,,,,,, -197900,2.8742263,1.7884495,,,,,,,,,,,,,, -198000,3.573994,1.1437912,,,,,,,,,,,,,, -198100,3.9162564,2.7037966,,,,,,,,,,,,,, -198200,3.0553882,1.042526,,,,,,,,,,,,,, -198300,3.1237445,1.1143742,,,,,,,,,,,,,, -198345,,,0.8878905773162842,0.4148828089237213,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,88688.21113562584,95883.8189907074,88688.21113562584,7174.6018846035,10.557025671005247,0.0 -198400,3.1033182,1.1480062,,,,,,,,,,,,,, -198500,2.806072,1.0928338,,,,,,,,,,,,,, -198600,2.928588,1.0416179,,,,,,,,,,,,,, -198700,3.3620913,1.294014,,,,,,,,,,,,,, -198800,2.8938842,1.3434944,,,,,,,,,,,,,, -198900,3.348047,2.540566,,,,,,,,,,,,,, -199000,3.28351,1.1619142,,,,,,,,,,,,,, -199100,3.1480615,2.774576,,,,,,,,,,,,,, -199200,3.2942061,1.2184533,,,,,,,,,,,,,, -199283,,,0.8878905773162842,0.4175741970539093,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,89108.45638513565,96338.09540224075,89108.45638513565,7208.525338888168,10.615936279296877,0.0 -199300,3.5377045,1.1879215,,,,,,,,,,,,,, -199400,3.0753288,2.3305142,,,,,,,,,,,,,, -199500,3.1877317,1.3381382,,,,,,,,,,,,,, -199600,3.1484258,2.5650792,,,,,,,,,,,,,, -199700,3.147683,1.1752708,,,,,,,,,,,,,, -199800,3.2921493,2.8154569,,,,,,,,,,,,,, -199900,2.940288,1.338605,,,,,,,,,,,,,, -200000,2.7070954,1.9173342,,,,,,,,,,,,,, -200100,3.0412805,1.2731352,,,,,,,,,,,,,, -200200,3.1757174,1.0602943,,,,,,,,,,,,,, -200221,,,0.8884570002555847,0.4154017865657806,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,89528.80938267708,96794.25735282898,89528.80938267708,7244.22820520401,10.671677350997925,0.0 -200300,2.8526678,1.8609815,,,,,,,,,,,,,, -200400,3.040089,1.074708,,,,,,,,,,,,,, -200500,2.7452295,1.8775549,,,,,,,,,,,,,, -200600,2.8720644,1.0310916,,,,,,,,,,,,,, -200700,3.1725328,1.883179,,,,,,,,,,,,,, -200800,3.1182673,2.574469,,,,,,,,,,,,,, -200900,3.1402166,1.0854728,,,,,,,,,,,,,, -201000,2.9594855,1.1598048,,,,,,,,,,,,,, -201100,2.9479392,1.1756439,,,,,,,,,,,,,, -201160,,,0.8883593678474426,0.4166028499603271,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,89948.93990373611,97248.5354616642,89948.93990373611,7278.270314931869,10.727835655212402,0.0 -201200,2.9446733,1.0859501,,,,,,,,,,,,,, -201300,3.2331235,1.1489401,,,,,,,,,,,,,, -201400,3.7420766,3.1642177,,,,,,,,,,,,,, -201500,3.369792,1.3963712,,,,,,,,,,,,,, -201600,3.086801,1.1535248,,,,,,,,,,,,,, -201700,3.0885344,1.417673,,,,,,,,,,,,,, -201800,3.1714966,1.9717014,,,,,,,,,,,,,, -201900,3.2813835,1.916582,,,,,,,,,,,,,, -202000,3.1269703,1.2919992,,,,,,,,,,,,,, -202100,,,0.8859374523162842,0.4209687709808349,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,90368.9737598896,97701.45683455469,90368.9737598896,7311.052936553955,10.78298568725586,0.0 -202100,2.8955731,0.98226196,,,,,,,,,,,,,, -202200,3.2724338,1.1967906,,,,,,,,,,,,,, -202300,3.2276616,1.1454141,,,,,,,,,,,,,, -202400,3.1001427,2.7232134,,,,,,,,,,,,,, -202500,2.8306003,2.2720795,,,,,,,,,,,,,, -202600,3.2769039,2.5364676,,,,,,,,,,,,,, -202700,3.0990021,2.7580638,,,,,,,,,,,,,, -202800,2.9928806,1.1469336,,,,,,,,,,,,,, -202900,3.1882615,2.5373838,,,,,,,,,,,,,, -203000,3.433638,2.848723,,,,,,,,,,,,,, -203036,,,0.88880854845047,0.4150761067867279,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,90788.9855811596,98157.04360675812,90788.9855811596,7346.512250185013,10.849711179733276,0.0 -203100,3.0301263,1.1585169,,,,,,,,,,,,,, -203200,2.7544072,1.6509372,,,,,,,,,,,,,, -203300,3.0508087,1.6026447,,,,,,,,,,,,,, -203400,5.058875,1.3273888,,,,,,,,,,,,,, -203500,3.2024183,1.1528765,,,,,,,,,,,,,, -203600,3.1759102,1.2833422,,,,,,,,,,,,,, -203700,3.0028875,1.2709708,,,,,,,,,,,,,, -203800,3.3446593,1.1609566,,,,,,,,,,,,,, -203900,3.4726553,3.175855,,,,,,,,,,,,,, -203972,,,0.8892382383346558,0.4137385487556457,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,91208.94326591492,98612.33208966257,91208.94326591492,7381.7394988536835,10.903832912445068,0.0 -204000,3.0356126,1.0555071,,,,,,,,,,,,,, -204100,3.0062978,1.146984,,,,,,,,,,,,,, -204200,2.765503,1.2415361,,,,,,,,,,,,,, -204300,2.9571593,1.7742375,,,,,,,,,,,,,, -204400,3.4861772,1.1087532,,,,,,,,,,,,,, -204500,3.9619071,3.1800182,,,,,,,,,,,,,, -204600,2.974337,1.0653465,,,,,,,,,,,,,, -204700,2.8247964,1.032768,,,,,,,,,,,,,, -204800,3.235497,2.3706944,,,,,,,,,,,,,, -204900,2.943867,1.5856853,,,,,,,,,,,,,, -204908,,,0.8872265219688416,0.4107032716274261,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,91628.93712472916,99066.45861530304,91628.93712472916,7415.764829158783,10.96209716796875,0.0 -205000,3.497617,3.2581577,,,,,,,,,,,,,, -205100,2.9782186,1.2548709,,,,,,,,,,,,,, -205200,3.151876,2.0360641,,,,,,,,,,,,,, -205300,3.1969318,1.0895663,,,,,,,,,,,,,, -205400,3.0486054,1.133037,,,,,,,,,,,,,, -205500,3.2926562,1.2018803,,,,,,,,,,,,,, -205600,2.9131057,1.1918248,,,,,,,,,,,,,, -205700,3.1919954,1.1710298,,,,,,,,,,,,,, -205800,3.1030698,1.1728361,,,,,,,,,,,,,, -205843,,,0.8876953125,0.4176147878170013,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,92048.84640073776,99521.96976304054,92048.84640073776,7451.258689165115,11.021414756774902,0.0 -205900,3.49982,3.1222131,,,,,,,,,,,,,, -206000,3.2737617,1.2915127,,,,,,,,,,,,,, -206100,3.0601902,1.6959165,,,,,,,,,,,,,, -206200,2.9230106,2.2202773,,,,,,,,,,,,,, -206300,2.977356,1.9493759,,,,,,,,,,,,,, -206400,3.2227118,2.089881,,,,,,,,,,,,,, -206500,2.9110801,1.4083135,,,,,,,,,,,,,, -206600,3.3636172,3.0032268,,,,,,,,,,,,,, -206700,3.2689884,1.1310663,,,,,,,,,,,,,, -206781,,,0.8884570002555847,0.4156330525875091,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,92468.78777194025,99977.3512263298,92468.78777194025,7486.59016084671,11.081706762313845,0.0 -206800,3.0393965,1.1354095,,,,,,,,,,,,,, -206900,3.05388,1.3799195,,,,,,,,,,,,,, -207000,3.7449555,3.197095,,,,,,,,,,,,,, -207100,3.2711775,1.1828834,,,,,,,,,,,,,, -207200,3.2736695,2.6210518,,,,,,,,,,,,,, -207300,3.1711054,1.279686,,,,,,,,,,,,,, -207400,3.0768008,2.4348855,,,,,,,,,,,,,, -207500,2.8543444,1.4347007,,,,,,,,,,,,,, -207600,4.2707753,3.2877545,,,,,,,,,,,,,, -207700,3.019418,1.0403345,,,,,,,,,,,,,, -207720,,,0.8912109136581421,0.4088407754898071,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,92888.84211206436,100433.30316019058,92888.84211206436,7522.379315137863,11.140997648239136,0.0 -207800,2.955198,1.4260713,,,,,,,,,,,,,, -207900,3.9670475,3.373282,,,,,,,,,,,,,, -208000,3.6572351,3.0849338,,,,,,,,,,,,,, -208100,3.1929305,1.1306045,,,,,,,,,,,,,, -208200,3.3379538,2.1597028,,,,,,,,,,,,,, -208300,3.3633318,1.2105217,,,,,,,,,,,,,, -208400,3.3802195,1.0287303,,,,,,,,,,,,,, -208500,3.0212052,1.3389395,,,,,,,,,,,,,, -208600,3.8840222,3.1053612,,,,,,,,,,,,,, -208657,,,0.8899999856948853,0.4095646440982818,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,93309.04252171516,100887.8080637455,93309.04252171516,7556.569756746292,11.206549644470217,0.0 -208700,3.8517075,3.2016842,,,,,,,,,,,,,, -208800,3.8243275,3.2444777,,,,,,,,,,,,,, -208900,3.0555003,1.1973158,,,,,,,,,,,,,, -209000,3.56801,3.1406326,,,,,,,,,,,,,, -209100,2.9062352,1.116246,,,,,,,,,,,,,, -209200,3.0480394,1.5840205,,,,,,,,,,,,,, -209300,3.1948485,1.2877831,,,,,,,,,,,,,, -209400,3.1483757,1.2516251,,,,,,,,,,,,,, -209500,3.0815103,1.6124324,,,,,,,,,,,,,, -209592,,,0.88832026720047,0.4110678136348724,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,93729.30348491669,101342.25642371178,93729.30348491669,7590.646590232849,11.269054651260376,0.0 -209600,3.2565947,1.0074137,,,,,,,,,,,,,, -209700,2.7654233,1.7400446,,,,,,,,,,,,,, -209800,3.0732055,1.0825542,,,,,,,,,,,,,, -209900,3.415278,1.1067224,,,,,,,,,,,,,, -210000,4.1005044,3.1076667,,,,,,,,,,,,,, -210100,3.4294622,1.0823289,,,,,,,,,,,,,, -210200,3.0873625,1.2717458,,,,,,,,,,,,,, -210300,3.3485303,1.0499009,,,,,,,,,,,,,, -210400,3.6315312,3.2278647,,,,,,,,,,,,,, -210500,2.9154468,1.1500009,,,,,,,,,,,,,, -210529,,,0.8885546922683716,0.4129042029380798,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,94149.20416498184,101798.49844503404,94149.20416498184,7626.880306720734,11.32774782180786,0.0 -210600,3.4589975,3.0487375,,,,,,,,,,,,,, -210700,3.9086158,2.034014,,,,,,,,,,,,,, -210800,3.3232186,2.9149926,,,,,,,,,,,,,, -210900,3.674929,2.7069798,,,,,,,,,,,,,, -211000,3.4631665,2.885424,,,,,,,,,,,,,, -211100,3.0140142,1.135207,,,,,,,,,,,,,, -211200,3.3340445,2.6200147,,,,,,,,,,,,,, -211300,2.7285051,1.3039107,,,,,,,,,,,,,, -211400,3.2128704,1.997235,,,,,,,,,,,,,, -211471,,,0.8863281011581421,0.4174293875694275,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,94569.50618243216,102253.2084736824,94569.50618243216,7661.169199705124,11.396864175796509,0.0 -211500,3.1514523,1.1356437,,,,,,,,,,,,,, -211600,3.3785827,1.8994982,,,,,,,,,,,,,, -211700,3.2211432,1.1039361,,,,,,,,,,,,,, -211800,3.1008055,2.0387218,,,,,,,,,,,,,, -211900,3.0155015,1.0698502,,,,,,,,,,,,,, -212000,3.2389452,1.2581333,,,,,,,,,,,,,, -212100,3.0402613,2.275847,,,,,,,,,,,,,, -212200,3.2724752,1.0888236,,,,,,,,,,,,,, -212300,3.1046774,1.3607471,,,,,,,,,,,,,, -212400,3.024495,1.1340756,,,,,,,,,,,,,, -212411,,,0.8870702981948853,0.4161613285541534,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,94989.62823104858,102708.59036684036,94989.62823104858,7696.322644948959,11.454381227493286,0.0 -212500,4.404997,1.1215446,,,,,,,,,,,,,, -212600,3.4204621,1.5717347,,,,,,,,,,,,,, -212700,3.1896887,1.0087553,,,,,,,,,,,,,, -212800,3.4876564,2.9643335,,,,,,,,,,,,,, -212900,3.0902047,1.1691282,,,,,,,,,,,,,, -213000,3.2507231,1.8659139,,,,,,,,,,,,,, -213100,3.1893103,1.0929796,,,,,,,,,,,,,, -213200,3.2108822,1.2889026,,,,,,,,,,,,,, -213300,2.8527176,1.494978,,,,,,,,,,,,,, -213345,,,0.8892773389816284,0.4133086800575256,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,95409.86905503272,103163.48323106766,95409.86905503272,7730.858766078949,11.52181363105774,0.0 -213400,2.9725928,1.3099341,,,,,,,,,,,,,, -213500,3.2886453,1.3703202,,,,,,,,,,,,,, -213600,2.9734824,1.1405269,,,,,,,,,,,,,, -213700,3.344822,2.7636063,,,,,,,,,,,,,, -213800,3.2777607,2.6609714,,,,,,,,,,,,,, -213900,3.0131786,1.3483372,,,,,,,,,,,,,, -214000,2.974773,2.0354903,,,,,,,,,,,,,, -214100,3.444774,1.214523,,,,,,,,,,,,,, -214200,3.25184,1.0646857,,,,,,,,,,,,,, -214281,,,0.888964831829071,0.4124159216880798,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,95829.8261590004,103619.44575953484,95829.8261590004,7766.759558439255,11.57755184173584,0.0 -214300,2.881656,2.0095596,,,,,,,,,,,,,, -214400,2.862509,1.1383024,,,,,,,,,,,,,, -214500,3.7235122,3.1769574,,,,,,,,,,,,,, -214600,3.245345,0.9975349,,,,,,,,,,,,,, -214700,3.1095564,1.1991186,,,,,,,,,,,,,, -214800,3.2817614,3.0003831,,,,,,,,,,,,,, -214900,3.4420576,2.9524412,,,,,,,,,,,,,, -215000,2.9491506,1.210733,,,,,,,,,,,,,, -215100,3.2304726,2.3719475,,,,,,,,,,,,,, -215200,3.0077798,1.3411704,,,,,,,,,,,,,, -215221,,,0.8868359327316284,0.4176099598407745,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,96249.92731976508,104073.60329914092,96249.92731976508,7800.7112374305725,11.63347578048706,0.0 -215300,3.2174647,1.158001,,,,,,,,,,,,,, -215400,3.138437,1.1644702,,,,,,,,,,,,,, -215500,2.9920287,1.1208752,,,,,,,,,,,,,, -215600,3.021446,1.2239873,,,,,,,,,,,,,, -215700,3.600338,1.3054706,,,,,,,,,,,,,, -215800,3.1390507,1.3649215,,,,,,,,,,,,,, -215900,3.3029895,1.0749277,,,,,,,,,,,,,, -216000,3.371306,1.2266763,,,,,,,,,,,,,, -216100,3.3527417,2.4735298,,,,,,,,,,,,,, -216159,,,0.8875195384025574,0.4130787849426269,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,96669.82828593254,104527.22795033456,96669.82828593254,7834.320219755173,11.700467109680176,0.0 -216200,3.1723146,1.2541294,,,,,,,,,,,,,, -216300,4.189154,3.1838946,,,,,,,,,,,,,, -216400,2.9808052,1.1302803,,,,,,,,,,,,,, -216500,3.0520177,2.6907272,,,,,,,,,,,,,, -216600,3.2578835,2.801914,,,,,,,,,,,,,, -216700,3.027196,1.308995,,,,,,,,,,,,,, -216800,3.1484296,1.4806528,,,,,,,,,,,,,, -216900,2.8475733,2.4405708,,,,,,,,,,,,,, -217000,2.8742368,1.7694199,,,,,,,,,,,,,, -217094,,,0.8885155916213989,0.4117570519447326,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,97089.97589826584,104983.97067832948,97089.97589826584,7870.795943975449,11.771633863449097,0.0 -217100,2.9617877,1.3109062,,,,,,,,,,,,,, -217200,2.8564794,1.5309616,,,,,,,,,,,,,, -217300,3.4197578,3.123991,,,,,,,,,,,,,, -217400,2.9590495,1.0783583,,,,,,,,,,,,,, -217500,2.8359177,1.0793461,,,,,,,,,,,,,, -217600,3.060536,2.698184,,,,,,,,,,,,,, -217700,3.0488846,1.1032603,,,,,,,,,,,,,, -217800,3.182357,1.7865193,,,,,,,,,,,,,, -217900,3.0298662,1.9156506,,,,,,,,,,,,,, -218000,3.243307,2.7803783,,,,,,,,,,,,,, -218032,,,0.8877929449081421,0.4174558520317077,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,97509.97085809708,105439.86210155489,97509.97085809708,7906.584993600845,11.830165147781372,0.0 -218100,2.9946306,0.91355765,,,,,,,,,,,,,, -218200,2.9020495,1.1262903,,,,,,,,,,,,,, -218300,3.4531236,1.0978396,,,,,,,,,,,,,, -218400,3.3296475,1.5475528,,,,,,,,,,,,,, -218500,2.9495265,1.4271679,,,,,,,,,,,,,, -218600,3.0947351,1.9770948,,,,,,,,,,,,,, -218700,3.5294478,3.058372,,,,,,,,,,,,,, -218800,3.1617043,1.7784016,,,,,,,,,,,,,, -218900,2.8253152,1.0916384,,,,,,,,,,,,,, -218971,,,0.8862499594688416,0.4198924303054809,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,97929.96910381316,105894.55187416077,97929.96910381316,7941.163271188736,11.894726037979126,0.0 -219000,3.3206124,3.0196118,,,,,,,,,,,,,, -219100,3.152645,2.142621,,,,,,,,,,,,,, -219200,3.1311364,1.1804463,,,,,,,,,,,,,, -219300,2.9103131,1.1705778,,,,,,,,,,,,,, -219400,3.0775676,1.8702978,,,,,,,,,,,,,, -219500,3.0167687,1.1502851,,,,,,,,,,,,,, -219600,2.972153,1.6357269,,,,,,,,,,,,,, -219700,3.475156,2.831637,,,,,,,,,,,,,, -219800,3.0726528,1.8370907,,,,,,,,,,,,,, -219900,3.0425398,2.0805836,,,,,,,,,,,,,, -219909,,,0.887499988079071,0.4200115203857422,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,98350.23102784155,106350.5483431816,98350.23102784155,7976.790856599808,11.953049659729004,0.0 -220000,3.2951593,2.775413,,,,,,,,,,,,,, -220100,3.4784558,1.1541765,,,,,,,,,,,,,, -220200,3.1083791,1.0898956,,,,,,,,,,,,,, -220300,3.5000627,3.0216422,,,,,,,,,,,,,, -220400,2.918295,1.1409885,,,,,,,,,,,,,, -220500,2.9464543,1.3041472,,,,,,,,,,,,,, -220600,3.1498928,2.4956188,,,,,,,,,,,,,, -220700,3.3652892,2.992573,,,,,,,,,,,,,, -220798,,,0.8900781273841858,0.4076668322086334,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,98770.417617321,106806.2047586441,98770.417617321,8012.15415096283,12.012901067733765,0.0 -220800,3.0733323,1.0861241,,,,,,,,,,,,,, -220900,3.2807248,1.1284206,,,,,,,,,,,,,, -221000,2.848784,1.0922372,,,,,,,,,,,,,, -221100,3.104762,1.2791619,,,,,,,,,,,,,, -221200,3.539242,3.0020547,,,,,,,,,,,,,, -221300,2.914474,2.1713848,,,,,,,,,,,,,, -221400,3.2261364,2.61239,,,,,,,,,,,,,, -221500,3.0769591,1.978146,,,,,,,,,,,,,, -221600,3.0686944,1.6833571,,,,,,,,,,,,,, -221700,2.9039683,2.2730014,,,,,,,,,,,,,, -221735,,,0.8908789157867432,0.4115219414234161,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,99190.58779096603,107261.30042052268,99190.58779096603,8046.972583532333,12.071365118026732,0.0 -221800,2.97954,1.0823075,,,,,,,,,,,,,, -221900,3.031304,1.7239645,,,,,,,,,,,,,, -222000,3.3590102,2.2716022,,,,,,,,,,,,,, -222100,3.157717,1.0928457,,,,,,,,,,,,,, -222200,3.132796,2.8493946,,,,,,,,,,,,,, -222300,3.0563653,1.6036758,,,,,,,,,,,,,, -222400,3.6360137,3.162987,,,,,,,,,,,,,, -222500,3.3677042,2.4907603,,,,,,,,,,,,,, -222600,3.1538239,1.2878451,,,,,,,,,,,,,, -222672,,,0.8867382407188416,0.4187392294406891,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,99610.72120189668,107717.69447016716,99610.72120189668,8083.12747168541,12.127808094024658,0.0 -222700,3.1158469,1.1092905,,,,,,,,,,,,,, -222800,3.2521799,2.2459357,,,,,,,,,,,,,, -222900,3.2942412,1.8908888,,,,,,,,,,,,,, -223000,2.895777,2.0794966,,,,,,,,,,,,,, -223100,3.3734798,2.7151778,,,,,,,,,,,,,, -223200,3.9161854,3.2334294,,,,,,,,,,,,,, -223300,3.3184094,1.3323952,,,,,,,,,,,,,, -223400,2.8892596,2.2880924,,,,,,,,,,,,,, -223500,3.5537963,1.1429976,,,,,,,,,,,,,, -223600,4.0431137,3.189694,,,,,,,,,,,,,, -223609,,,0.8858398199081421,0.4228871762752533,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,100030.98379468918,108171.35675144196,100030.98379468918,8116.420456647873,12.186193466186523,0.0 -223700,3.13334,1.0827881,,,,,,,,,,,,,, -223800,2.9730198,2.3546124,,,,,,,,,,,,,, -223900,3.3022084,1.3506954,,,,,,,,,,,,,, -224000,3.2084808,2.414605,,,,,,,,,,,,,, -224100,2.9399793,1.0737473,,,,,,,,,,,,,, -224200,4.1931653,3.2626972,,,,,,,,,,,,,, -224300,2.929844,2.0506015,,,,,,,,,,,,,, -224400,3.1621125,1.3138186,,,,,,,,,,,,,, -224500,2.820864,1.8344868,,,,,,,,,,,,,, -224545,,,0.8891015648841858,0.4168025851249695,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,100451.22117090224,108626.73807239532,100451.22117090224,8151.446369409561,12.254721403121948,0.0 -224600,2.8695736,1.9757154,,,,,,,,,,,,,, -224700,3.0204558,1.128195,,,,,,,,,,,,,, -224800,3.9896622,3.2102554,,,,,,,,,,,,,, -224900,3.1191516,1.2823894,,,,,,,,,,,,,, -225000,3.3068933,1.876617,,,,,,,,,,,,,, -225100,3.1288052,2.8865945,,,,,,,,,,,,,, -225200,3.2978108,1.3961389,,,,,,,,,,,,,, -225300,3.1958747,1.0025784,,,,,,,,,,,,,, -225400,3.3114321,1.0751069,,,,,,,,,,,,,, -225482,,,0.8861327767372131,0.4200985133647918,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,100871.4775969982,109080.39751529694,100871.4775969982,8184.739677429199,12.3156476020813,0.0 -225500,3.0018396,1.3585664,,,,,,,,,,,,,, -225600,3.1574821,1.1150181,,,,,,,,,,,,,, -225700,3.0398748,1.0958737,,,,,,,,,,,,,, -225800,3.0925472,1.1642584,,,,,,,,,,,,,, -225900,3.0343518,1.1569155,,,,,,,,,,,,,, -226000,3.1135054,1.1394477,,,,,,,,,,,,,, -226100,5.679728,1.1678103,,,,,,,,,,,,,, -226200,3.1862,1.1382273,,,,,,,,,,,,,, -226300,2.887986,1.7406225,,,,,,,,,,,,,, -226400,3.3277085,2.7014887,,,,,,,,,,,,,, -226419,,,0.8892577886581421,0.4118213951587677,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,101291.44068813324,109534.42616176604,101291.44068813324,8218.687573194504,12.384529113769531,0.0 -226500,3.172114,1.2866223,,,,,,,,,,,,,, -226600,3.1409543,1.133924,,,,,,,,,,,,,, -226700,3.1860368,1.048569,,,,,,,,,,,,,, -226800,3.463874,2.7674856,,,,,,,,,,,,,, -226900,3.2029517,1.7756882,,,,,,,,,,,,,, -227000,3.1850955,1.1485561,,,,,,,,,,,,,, -227100,3.1324534,1.152628,,,,,,,,,,,,,, -227200,3.674366,3.181514,,,,,,,,,,,,,, -227300,2.890629,1.4970994,,,,,,,,,,,,,, -227359,,,0.8875781297683716,0.4156609177589416,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,101711.74678444862,109990.22974205016,101711.74678444862,8254.07694530487,12.444185018539429,0.0 -227400,3.0175579,1.1663041,,,,,,,,,,,,,, -227500,3.058253,1.1418693,,,,,,,,,,,,,, -227600,2.992433,1.8300717,,,,,,,,,,,,,, -227700,3.0910022,1.1554182,,,,,,,,,,,,,, -227800,3.4210691,1.1425151,,,,,,,,,,,,,, -227900,2.9489365,1.069746,,,,,,,,,,,,,, -228000,3.2108088,1.4498178,,,,,,,,,,,,,, -228100,3.6818104,3.1006756,,,,,,,,,,,,,, -228200,3.2513254,2.3654158,,,,,,,,,,,,,, -228297,,,0.8884570002555847,0.4122214913368225,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,102131.85267829896,110444.7630224228,102131.85267829896,8288.393639802933,12.505478382110596,0.0 -228300,3.0758145,1.1601102,,,,,,,,,,,,,, -228400,3.8541205,3.0005832,,,,,,,,,,,,,, -228500,3.2456696,1.087913,,,,,,,,,,,,,, -228600,2.960995,1.3449988,,,,,,,,,,,,,, -228700,3.0784328,2.3274012,,,,,,,,,,,,,, -228800,3.2142522,1.096505,,,,,,,,,,,,,, -228900,3.4051871,1.1134162,,,,,,,,,,,,,, -229000,3.3591502,2.4645774,,,,,,,,,,,,,, -229100,3.2879531,1.2321113,,,,,,,,,,,,,, -229200,3.2032616,2.06477,,,,,,,,,,,,,, -229234,,,0.8875585794448853,0.4155343770980835,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,102552.01909089088,110900.68841600418,102552.01909089088,8324.046884298325,12.562983989715576,0.0 -229300,3.179308,1.1013051,,,,,,,,,,,,,, -229400,2.9049962,1.2864428,,,,,,,,,,,,,, -229500,3.1652813,1.2155895,,,,,,,,,,,,,, -229600,3.1173494,1.438112,,,,,,,,,,,,,, -229700,2.8671932,1.7792711,,,,,,,,,,,,,, -229800,3.1290817,1.1639241,,,,,,,,,,,,,, -229900,3.9884531,2.859569,,,,,,,,,,,,,, -230000,3.4063828,2.7762773,,,,,,,,,,,,,, -230100,3.329152,1.0362116,,,,,,,,,,,,,, -230175,,,0.8898828029632568,0.4132254719734192,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,102972.22385382652,111355.14338946342,102972.22385382652,8358.187504053116,12.623552083969116,0.0 -230200,3.1196961,1.15831,,,,,,,,,,,,,, -230300,2.9200802,1.2361623,,,,,,,,,,,,,, -230400,3.2423756,1.0520837,,,,,,,,,,,,,, -230500,3.3858445,1.1621704,,,,,,,,,,,,,, -230600,3.6373749,1.107146,,,,,,,,,,,,,, -230700,2.9731286,1.0572506,,,,,,,,,,,,,, -230800,2.9453077,1.0931368,,,,,,,,,,,,,, -230900,3.229757,1.1958444,,,,,,,,,,,,,, -231000,3.0675247,1.0770452,,,,,,,,,,,,,, -231100,3.0031245,1.4627426,,,,,,,,,,,,,, -231115,,,0.88916015625,0.4119197726249695,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,103392.2831542492,111807.89689588548,103392.2831542492,8390.762695789337,12.692643880844116,0.0 -231200,3.3437736,1.286988,,,,,,,,,,,,,, -231300,2.960131,1.5903932,,,,,,,,,,,,,, -231400,3.6452727,3.3099148,,,,,,,,,,,,,, -231500,3.1366956,1.5133332,,,,,,,,,,,,,, -231600,2.8849208,1.1300257,,,,,,,,,,,,,, -231700,2.9975834,1.1666615,,,,,,,,,,,,,, -231800,3.1209946,1.0974973,,,,,,,,,,,,,, -231900,3.3085215,1.385444,,,,,,,,,,,,,, -232000,3.0496843,1.4524705,,,,,,,,,,,,,, -232053,,,0.889453113079071,0.4119241535663605,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,103812.51510357855,112262.90275025368,103812.51510357855,8425.405321359634,12.775279998779297,0.0 -232100,2.9458532,1.2644324,,,,,,,,,,,,,, -232200,3.0074916,2.195984,,,,,,,,,,,,,, -232300,3.082627,1.601181,,,,,,,,,,,,,, -232400,3.0947843,1.1191999,,,,,,,,,,,,,, -232500,3.7781973,3.2505755,,,,,,,,,,,,,, -232600,3.2385375,1.3957434,,,,,,,,,,,,,, -232700,3.2031918,1.5285413,,,,,,,,,,,,,, -232800,2.996484,1.0461183,,,,,,,,,,,,,, -232900,2.782999,1.0280769,,,,,,,,,,,,,, -232993,,,0.8890820145606995,0.4118773639202118,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,104232.78696107864,112718.30594062804,104232.78696107864,8460.428438425064,12.834226846694946,0.0 -233000,3.0909886,1.206342,,,,,,,,,,,,,, -233100,3.1860433,1.3390687,,,,,,,,,,,,,, -233200,3.0912101,1.0581393,,,,,,,,,,,,,, -233300,2.8171437,1.8039213,,,,,,,,,,,,,, -233400,3.3333008,2.7841616,,,,,,,,,,,,,, -233500,3.2355344,1.1610717,,,,,,,,,,,,,, -233600,3.3764322,1.1457129,,,,,,,,,,,,,, -233700,3.0141964,2.098421,,,,,,,,,,,,,, -233800,3.9570568,3.2532697,,,,,,,,,,,,,, -233900,3.4181066,1.2544532,,,,,,,,,,,,,, -233934,,,0.8887304663658142,0.409778743982315,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,104652.87011408806,113174.4059624672,104652.87011408806,8496.338250160217,12.893026351928713,0.0 -234000,3.1911786,2.6206133,,,,,,,,,,,,,, -234100,2.9458127,2.0377991,,,,,,,,,,,,,, -234200,3.085535,1.0744424,,,,,,,,,,,,,, -234300,3.0095472,1.1918191,,,,,,,,,,,,,, -234400,2.8298197,1.2204478,,,,,,,,,,,,,, -234500,3.019634,1.1557425,,,,,,,,,,,,,, -234600,3.90801,3.2776384,,,,,,,,,,,,,, -234700,3.2768543,1.9562067,,,,,,,,,,,,,, -234800,2.907524,1.07617,,,,,,,,,,,,,, -234871,,,0.8856640458106995,0.4191117286682129,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,105072.76435279846,113629.954102993,105072.76435279846,8531.880004882812,12.95697784423828,0.0 -234900,2.910422,1.4848994,,,,,,,,,,,,,, -235000,3.2467315,1.06466,,,,,,,,,,,,,, -235100,3.217305,2.1645303,,,,,,,,,,,,,, -235200,2.8891897,1.0553851,,,,,,,,,,,,,, -235300,3.0975788,2.4813833,,,,,,,,,,,,,, -235400,2.918658,1.2217741,,,,,,,,,,,,,, -235500,3.029406,1.0367897,,,,,,,,,,,,,, -235600,2.8776662,1.6940393,,,,,,,,,,,,,, -235700,3.376232,2.6937401,,,,,,,,,,,,,, -235800,3.264246,1.0576433,,,,,,,,,,,,,, -235811,,,0.8889062404632568,0.409523993730545,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,105492.69994354248,114084.48039150238,105492.69994354248,8566.35573387146,13.02313780784607,0.0 -235900,4.0157003,1.1889514,,,,,,,,,,,,,, -236000,3.1783714,1.4553363,,,,,,,,,,,,,, -236100,3.0764256,1.7446599,,,,,,,,,,,,,, -236200,3.2581148,1.1089268,,,,,,,,,,,,,, -236300,3.1020842,1.2140954,,,,,,,,,,,,,, -236400,3.33788,1.2340326,,,,,,,,,,,,,, -236500,3.0048583,1.0459156,,,,,,,,,,,,,, -236600,3.1868818,1.1269377,,,,,,,,,,,,,, -236700,2.975212,2.4360466,,,,,,,,,,,,,, -236750,,,0.88880854845047,0.4133784770965576,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,105912.88513946532,114540.82594275476,105912.88513946532,8602.39948964119,13.09098768234253,0.0 -236800,3.0214703,1.103893,,,,,,,,,,,,,, -236900,3.1150708,1.9434236,,,,,,,,,,,,,, -237000,3.1739862,1.5949135,,,,,,,,,,,,,, -237100,2.9758158,1.0727592,,,,,,,,,,,,,, -237200,3.0547473,1.0881107,,,,,,,,,,,,,, -237300,3.1078897,1.1218313,,,,,,,,,,,,,, -237400,3.1036963,1.2184719,,,,,,,,,,,,,, -237500,3.071741,1.1583564,,,,,,,,,,,,,, -237600,3.3600392,1.182838,,,,,,,,,,,,,, -237691,,,0.8860937356948853,0.4221966564655304,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,106333.18955159187,114995.24330091476,106333.18955159187,8636.400366544724,13.154456853866575,0.0 -237700,3.2698538,2.8626006,,,,,,,,,,,,,, -237800,3.018462,1.7012255,,,,,,,,,,,,,, -237900,3.364561,1.1003076,,,,,,,,,,,,,, -238000,4.203381,3.2297611,,,,,,,,,,,,,, -238100,3.252759,1.2603328,,,,,,,,,,,,,, -238200,4.32069,3.2696836,,,,,,,,,,,,,, -238300,3.079767,1.3057157,,,,,,,,,,,,,, -238400,3.1592638,2.5428724,,,,,,,,,,,,,, -238500,3.5362728,3.0986605,,,,,,,,,,,,,, -238600,2.8887117,2.2871792,,,,,,,,,,,,,, -238632,,,0.8895898461341858,0.4112511873245239,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,106753.54248261452,115451.53286147118,106753.54248261452,8672.216531038284,13.225939512252808,0.0 -238700,3.4610445,1.0795588,,,,,,,,,,,,,, -238800,3.011503,1.973812,,,,,,,,,,,,,, -238900,3.206653,2.7757564,,,,,,,,,,,,,, -239000,3.1135767,1.0367103,,,,,,,,,,,,,, -239100,2.9996984,1.0562518,,,,,,,,,,,,,, -239200,3.186111,1.0521986,,,,,,,,,,,,,, -239300,3.1906319,1.2376652,,,,,,,,,,,,,, -239400,3.137051,1.1622804,,,,,,,,,,,,,, -239500,2.9080153,2.0588162,,,,,,,,,,,,,, -239572,,,0.8858007788658142,0.4202886819839477,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,107173.82443404198,115904.74687862396,107173.82443404198,8705.033515691757,13.291870594024658,0.0 -239600,3.191246,1.9808345,,,,,,,,,,,,,, -239700,3.2892146,1.1327788,,,,,,,,,,,,,, -239800,3.3286316,1.2734228,,,,,,,,,,,,,, -239900,3.5455909,2.922916,,,,,,,,,,,,,, -240000,3.0770378,1.3003095,,,,,,,,,,,,,, -240100,3.0210571,1.1055748,,,,,,,,,,,,,, -240200,3.8123171,3.2694502,,,,,,,,,,,,,, -240300,3.1830115,1.7907654,,,,,,,,,,,,,, -240400,2.9068894,1.667541,,,,,,,,,,,,,, -240500,3.0754185,1.379135,,,,,,,,,,,,,, -240510,,,0.888671875,0.4127401113510132,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,107593.75956201552,116361.25875759123,107593.75956201552,8741.489234685898,13.36388087272644,0.0 -240600,3.367777,2.3069093,,,,,,,,,,,,,, -240700,3.2197855,1.0973523,,,,,,,,,,,,,, -240800,2.9799802,1.071377,,,,,,,,,,,,,, -240900,3.151209,2.9209702,,,,,,,,,,,,,, -241000,3.8162456,3.0477722,,,,,,,,,,,,,, -241100,3.3189468,1.2075431,,,,,,,,,,,,,, -241200,3.8922963,3.2782812,,,,,,,,,,,,,, -241300,3.3209934,1.3443753,,,,,,,,,,,,,, -241400,2.8847225,1.7539799,,,,,,,,,,,,,, -241450,,,0.8886327743530273,0.4101268351078033,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,108013.9641532898,116817.1728887558,108013.9641532898,8777.088340520859,13.4263174533844,0.0 -241500,3.098833,1.049557,,,,,,,,,,,,,, -241600,2.967679,1.2295492,,,,,,,,,,,,,, -241700,3.2471468,1.5043906,,,,,,,,,,,,,, -241800,3.1248558,1.1705297,,,,,,,,,,,,,, -241900,3.2967978,2.1349049,,,,,,,,,,,,,, -242000,3.1092885,1.3615502,,,,,,,,,,,,,, -242100,3.057235,1.1482687,,,,,,,,,,,,,, -242200,3.5128093,1.2243648,,,,,,,,,,,,,, -242300,3.1820316,1.1179616,,,,,,,,,,,,,, -242389,,,0.8875585794448853,0.4174584746360779,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,108434.24830532074,117271.9125611782,108434.24830532074,8811.436619997025,13.484199047088625,0.0 -242400,3.162891,2.4648995,,,,,,,,,,,,,, -242500,3.36693,2.7668557,,,,,,,,,,,,,, -242600,3.1414921,1.180709,,,,,,,,,,,,,, -242700,3.0965405,1.1537406,,,,,,,,,,,,,, -242800,3.014823,1.0946788,,,,,,,,,,,,,, -242900,4.3640203,2.6469336,,,,,,,,,,,,,, -243000,3.2521346,1.39666,,,,,,,,,,,,,, -243100,3.1548452,1.2471896,,,,,,,,,,,,,, -243200,3.485033,2.7035098,,,,,,,,,,,,,, -243300,3.37704,2.7044668,,,,,,,,,,,,,, -243327,,,0.8881250023841858,0.4179988503456116,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,108854.31929779051,117726.93820238112,108854.31929779051,8846.27702832222,13.549908876419067,0.0 -243400,3.0306802,1.1170952,,,,,,,,,,,,,, -243500,2.932079,1.1118523,,,,,,,,,,,,,, -243600,3.0765536,1.1170373,,,,,,,,,,,,,, -243700,3.2454035,2.200029,,,,,,,,,,,,,, -243800,3.0207372,1.4362476,,,,,,,,,,,,,, -243900,3.428464,2.994317,,,,,,,,,,,,,, -244000,3.0801985,1.8426716,,,,,,,,,,,,,, -244100,3.3034377,1.1616075,,,,,,,,,,,,,, -244200,2.6554003,1.1991526,,,,,,,,,,,,,, -244264,,,0.8875976204872131,0.4127923548221588,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,109274.4168112278,118183.19608783722,109274.4168112278,8882.327449798584,13.610346794128418,0.0 -244300,3.087382,2.4186964,,,,,,,,,,,,,, -244400,3.283195,1.1396979,,,,,,,,,,,,,, -244500,3.2118878,2.6541116,,,,,,,,,,,,,, -244600,3.1531637,2.6063404,,,,,,,,,,,,,, -244700,3.0345907,1.0767695,,,,,,,,,,,,,, -244800,3.0734234,2.064062,,,,,,,,,,,,,, -244900,3.1192434,1.2098271,,,,,,,,,,,,,, -245000,2.8435457,1.1160009,,,,,,,,,,,,,, -245100,3.2011738,2.443702,,,,,,,,,,,,,, -245200,3.2618015,2.8562722,,,,,,,,,,,,,, -245205,,,0.8902148008346558,0.4121952056884765,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,109694.49145460127,118639.5016169548,109694.49145460127,8918.449665307999,13.66965365409851,0.0 -245300,2.9509203,1.0884181,,,,,,,,,,,,,, -245400,2.85865,1.3519993,,,,,,,,,,,,,, -245500,3.1088278,1.0956315,,,,,,,,,,,,,, -245600,3.2735534,1.1257328,,,,,,,,,,,,,, -245700,3.1274002,1.2624619,,,,,,,,,,,,,, -245800,3.1192114,1.686589,,,,,,,,,,,,,, -245900,3.2543006,2.102498,,,,,,,,,,,,,, -246000,2.917667,1.0805994,,,,,,,,,,,,,, -246100,2.8817625,2.110012,,,,,,,,,,,,,, -246144,,,0.886035144329071,0.4249056279659271,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,110114.44577169418,119094.70700001717,110114.44577169418,8953.590856313705,13.729873657226562,0.0 -246200,3.206714,1.8288469,,,,,,,,,,,,,, -246300,2.9980993,0.9625493,,,,,,,,,,,,,, -246400,3.0465133,1.2340798,,,,,,,,,,,,,, -246500,3.0129755,0.97688985,,,,,,,,,,,,,, -246600,3.0994387,1.1614336,,,,,,,,,,,,,, -246700,3.2474327,1.1796883,,,,,,,,,,,,,, -246800,2.9269722,2.2343085,,,,,,,,,,,,,, -246900,2.7793386,1.5169262,,,,,,,,,,,,,, -247000,2.9516685,1.0947078,,,,,,,,,,,,,, -247082,,,0.8880273103713989,0.4152797162532806,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,110534.5217063427,119548.83497023582,110534.5217063427,8987.530959367752,13.79277729988098,0.0 -247100,3.273276,2.1981442,,,,,,,,,,,,,, -247200,3.1822517,2.8541257,,,,,,,,,,,,,, -247300,3.0895047,1.0716072,,,,,,,,,,,,,, -247400,3.0227306,1.7602882,,,,,,,,,,,,,, -247500,3.244673,1.144978,,,,,,,,,,,,,, -247600,3.9823217,3.2723744,,,,,,,,,,,,,, -247700,3.160417,1.1083897,,,,,,,,,,,,,, -247800,2.7558892,1.7990361,,,,,,,,,,,,,, -247900,3.4975967,1.2494538,,,,,,,,,,,,,, -248000,3.0013638,1.6726305,,,,,,,,,,,,,, -248019,,,0.8878515362739563,0.4194375574588775,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,110954.5231757164,120004.43804454803,110954.5231757164,9023.016151428224,13.860776662826538,0.0 -248100,4.361086,1.083238,,,,,,,,,,,,,, -248200,3.1450746,2.3667943,,,,,,,,,,,,,, -248300,2.825085,1.8373797,,,,,,,,,,,,,, -248400,3.744035,3.1106644,,,,,,,,,,,,,, -248500,3.0963132,1.196613,,,,,,,,,,,,,, -248600,3.071656,1.8778095,,,,,,,,,,,,,, -248700,3.1526198,1.0834908,,,,,,,,,,,,,, -248800,3.010358,1.3575686,,,,,,,,,,,,,, -248900,3.212303,1.1184,,,,,,,,,,,,,, -248961,,,0.8873632550239563,0.4151614904403686,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,111374.45286393166,120459.96077299118,111374.45286393166,9058.499682426453,13.92180871963501,0.0 -249000,3.0753024,2.1604862,,,,,,,,,,,,,, -249100,3.110568,1.1837125,,,,,,,,,,,,,, -249200,3.2014468,2.251399,,,,,,,,,,,,,, -249300,3.0169141,1.1062415,,,,,,,,,,,,,, -249400,3.0840352,1.0582098,,,,,,,,,,,,,, -249500,3.4403553,1.8159182,,,,,,,,,,,,,, -249600,2.9560227,1.4619834,,,,,,,,,,,,,, -249700,3.3030589,2.8370054,,,,,,,,,,,,,, -249800,3.0745149,1.0657787,,,,,,,,,,,,,, -249900,3.7333837,3.1818943,,,,,,,,,,,,,, -249901,,,0.8882030844688416,0.4157791137695312,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,111795.12572550774,120915.29966640472,111795.12572550774,9093.047081708908,13.991700649261476,0.0 -250000,3.250891,1.1020744,,,,,,,,,,,,,, -250100,3.0028527,2.3353298,,,,,,,,,,,,,, -250200,2.8481982,1.0190612,,,,,,,,,,,,,, -250300,3.425016,1.127538,,,,,,,,,,,,,, -250400,3.554832,3.0883245,,,,,,,,,,,,,, -250500,3.2929835,1.6176758,,,,,,,,,,,,,, -250600,3.1141791,1.2156479,,,,,,,,,,,,,, -250700,2.8023932,1.1542928,,,,,,,,,,,,,, -250800,3.0412345,1.7280844,,,,,,,,,,,,,, -250839,,,0.8891796469688416,0.4126760065555572,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,112215.37737846376,121371.44778513908,112215.37737846376,9128.832021713257,14.05441927909851,0.0 -250900,3.0674472,1.5612568,,,,,,,,,,,,,, -251000,3.0762634,1.2090887,,,,,,,,,,,,,, -251100,3.2836256,1.1332599,,,,,,,,,,,,,, -251200,3.0747895,1.2635732,,,,,,,,,,,,,, -251300,2.8595042,1.0568366,,,,,,,,,,,,,, -251400,3.329896,2.765131,,,,,,,,,,,,,, -251500,2.9007645,1.043255,,,,,,,,,,,,,, -251600,3.6680038,1.1239045,,,,,,,,,,,,,, -251700,2.8449283,1.8006121,,,,,,,,,,,,,, -251777,,,0.8882421851158142,0.4113604128360748,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,112635.57611656188,121826.532310009,112635.57611656188,9163.606656312944,14.1160409450531,0.0 -251800,3.4210143,3.1117265,,,,,,,,,,,,,, -251900,3.1347039,1.109757,,,,,,,,,,,,,, -252000,3.3422797,2.6896296,,,,,,,,,,,,,, -252100,3.3006492,1.0990081,,,,,,,,,,,,,, -252200,3.158641,1.0952507,,,,,,,,,,,,,, -252300,2.9951842,1.0106376,,,,,,,,,,,,,, -252400,3.1718738,2.669657,,,,,,,,,,,,,, -252500,3.1225517,1.3166399,,,,,,,,,,,,,, -252600,3.0363433,2.1912932,,,,,,,,,,,,,, -252700,3.4187486,1.147768,,,,,,,,,,,,,, -252710,,,0.8890429735183716,0.4132768213748932,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,113055.79975700378,122283.76314520836,113055.79975700378,9200.502378463743,14.178927898406982,0.0 -252800,3.1914017,1.0862881,,,,,,,,,,,,,, -252900,2.9002779,1.0267422,,,,,,,,,,,,,, -253000,3.0880432,2.2074122,,,,,,,,,,,,,, -253100,3.2476974,1.7292093,,,,,,,,,,,,,, -253200,3.0018454,1.1150262,,,,,,,,,,,,,, -253300,3.1491184,1.1621641,,,,,,,,,,,,,, -253400,3.3216205,1.1233383,,,,,,,,,,,,,, -253500,3.3946269,2.897307,,,,,,,,,,,,,, -253600,3.1759996,1.1437992,,,,,,,,,,,,,, -253650,,,0.8868749737739563,0.4217991530895233,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,113475.82202625276,122739.698867321,113475.82202625276,9236.304715394974,14.241063117980955,0.0 -253700,2.9462278,1.1262804,,,,,,,,,,,,,, -253800,3.2067664,1.0842781,,,,,,,,,,,,,, -253900,3.5502703,2.9155078,,,,,,,,,,,,,, -254000,3.1644423,1.1703489,,,,,,,,,,,,,, -254100,3.2199228,1.1545818,,,,,,,,,,,,,, -254200,2.6990485,2.0201232,,,,,,,,,,,,,, -254300,3.0331094,1.0542301,,,,,,,,,,,,,, -254400,3.7585166,3.096864,,,,,,,,,,,,,, -254500,3.361523,2.9503312,,,,,,,,,,,,,, -254589,,,0.8882030844688416,0.4138852655887604,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,113895.70823192596,123195.3730111122,113895.70823192596,9271.973058700562,14.31032943725586,0.0 -254600,3.1319818,1.0887432,,,,,,,,,,,,,, -254700,3.1159465,1.2538275,,,,,,,,,,,,,, -254800,5.0262566,2.3826787,,,,,,,,,,,,,, -254900,3.2416773,1.1666116,,,,,,,,,,,,,, -255000,3.451855,1.2129954,,,,,,,,,,,,,, -255100,2.776867,1.1052847,,,,,,,,,,,,,, -255200,3.1351602,1.1669797,,,,,,,,,,,,,, -255300,3.9087856,3.2445626,,,,,,,,,,,,,, -255400,3.2331893,1.2355338,,,,,,,,,,,,,, -255500,2.9736345,1.0870146,,,,,,,,,,,,,, -255527,,,0.8911327719688416,0.4039798676967621,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,114315.62262010574,123652.72332525252,114315.62262010574,9309.296308279036,14.374987840652466,0.0 -255600,3.6270854,1.1021878,,,,,,,,,,,,,, -255700,3.7352738,3.1643696,,,,,,,,,,,,,, -255800,3.9368818,3.2459002,,,,,,,,,,,,,, -255900,3.3381793,1.1201063,,,,,,,,,,,,,, -256000,3.0932627,1.6104875,,,,,,,,,,,,,, -256100,3.2399218,1.280265,,,,,,,,,,,,,, -256200,3.0376325,1.7142789,,,,,,,,,,,,,, -256300,2.7699094,1.2132503,,,,,,,,,,,,,, -256400,3.0211723,2.4398575,,,,,,,,,,,,,, -256465,,,0.8897265195846558,0.412788063287735,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,114735.56063556673,124108.63796424866,114735.56063556673,9345.163439273834,14.435834646224976,0.0 -256500,3.170512,1.2023275,,,,,,,,,,,,,, -256600,2.9544718,1.6937721,,,,,,,,,,,,,, -256700,3.0794923,2.5103242,,,,,,,,,,,,,, -256800,3.027639,1.026565,,,,,,,,,,,,,, -256900,3.0899174,2.0435371,,,,,,,,,,,,,, -257000,3.1077228,1.2563648,,,,,,,,,,,,,, -257100,3.1922235,2.489693,,,,,,,,,,,,,, -257200,3.1099856,1.142425,,,,,,,,,,,,,, -257300,3.28201,1.2247396,,,,,,,,,,,,,, -257400,4.097424,3.2667713,,,,,,,,,,,,,, -257404,,,0.8879101276397705,0.4117997884750366,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,115155.77038359642,124564.10697579384,115155.77038359642,9380.308512449265,14.500270128250122,0.0 -257500,2.8976848,1.1507554,,,,,,,,,,,,,, -257600,3.4423554,1.0626379,,,,,,,,,,,,,, -257700,2.8370588,1.4763277,,,,,,,,,,,,,, -257800,3.0269182,1.1585622,,,,,,,,,,,,,, -257900,3.2658932,1.1771494,,,,,,,,,,,,,, -258000,3.024656,1.1068579,,,,,,,,,,,,,, -258100,5.8061175,3.2665706,,,,,,,,,,,,,, -258200,3.0884986,1.2425607,,,,,,,,,,,,,, -258300,3.2625396,2.5458221,,,,,,,,,,,,,, -258340,,,0.8883788585662842,0.4111296832561493,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,115575.71538758278,125020.03127336502,115575.71538758278,9416.172207832336,14.567237615585327,0.0 -258400,3.0065746,1.0663381,,,,,,,,,,,,,, -258500,3.0196579,1.7757765,,,,,,,,,,,,,, -258600,3.0674691,1.9977443,,,,,,,,,,,,,, -258700,3.0642781,1.0583675,,,,,,,,,,,,,, -258800,3.5401268,3.1103857,,,,,,,,,,,,,, -258900,3.0798905,1.0557563,,,,,,,,,,,,,, -259000,2.8606658,1.9062828,,,,,,,,,,,,,, -259100,3.0868225,1.1181847,,,,,,,,,,,,,, -259200,3.002773,1.5787182,,,,,,,,,,,,,, -259279,,,0.8874804377555847,0.4159693717956543,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,115995.97527813911,125475.9886534214,115995.97527813911,9451.75629067421,14.63141632080078,0.0 -259300,3.108494,1.0981066,,,,,,,,,,,,,, -259400,2.966827,1.1372937,,,,,,,,,,,,,, -259500,3.2466137,2.1888046,,,,,,,,,,,,,, -259600,2.9831603,1.2790467,,,,,,,,,,,,,, -259700,3.3168614,1.3469241,,,,,,,,,,,,,, -259800,3.2464728,1.1077274,,,,,,,,,,,,,, -259900,3.4219947,3.1037133,,,,,,,,,,,,,, -260000,3.0623226,1.0908489,,,,,,,,,,,,,, -260100,3.0425887,1.2121245,,,,,,,,,,,,,, -260200,3.1859057,1.1522874,,,,,,,,,,,,,, -260218,,,0.888964831829071,0.4118813574314117,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,116416.2382993698,125931.2862727642,116416.2382993698,9486.679986476898,14.693573951721191,0.0 -260300,2.9741702,1.5323987,,,,,,,,,,,,,, -260400,3.304864,1.5811211,,,,,,,,,,,,,, -260500,2.875477,1.5408031,,,,,,,,,,,,,, -260600,3.3121736,1.2295234,,,,,,,,,,,,,, -260700,2.9565082,1.3147151,,,,,,,,,,,,,, -260800,3.1901627,1.1218967,,,,,,,,,,,,,, -260900,3.4183445,1.1529655,,,,,,,,,,,,,, -261000,3.8989122,3.2184036,,,,,,,,,,,,,, -261100,2.8752596,1.598917,,,,,,,,,,,,,, -261156,,,0.8884961009025574,0.4160282909870147,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,116836.5095319748,126387.26921463013,116836.5095319748,9522.27923154831,14.757773399353027,0.0 -261200,2.9310634,1.0985574,,,,,,,,,,,,,, -261300,3.338969,1.9271585,,,,,,,,,,,,,, -261400,3.3310194,1.2296581,,,,,,,,,,,,,, -261500,2.995078,2.0306678,,,,,,,,,,,,,, -261600,4.3710756,3.148677,,,,,,,,,,,,,, -261700,3.278381,2.6993032,,,,,,,,,,,,,, -261800,3.5117414,3.033082,,,,,,,,,,,,,, -261900,3.249322,1.5670136,,,,,,,,,,,,,, -262000,3.034033,1.078802,,,,,,,,,,,,,, -262095,,,0.8870702981948853,0.4149435758590698,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,117256.55498623848,126842.553047657,117256.55498623848,9557.40468263626,14.821744441986084,0.0 -262100,3.137467,1.7044986,,,,,,,,,,,,,, -262200,2.9443123,1.8379068,,,,,,,,,,,,,, -262300,3.2577944,1.2611743,,,,,,,,,,,,,, -262400,2.9514728,1.2594509,,,,,,,,,,,,,, -262500,2.9826653,1.2673029,,,,,,,,,,,,,, -262600,3.3249247,2.6157763,,,,,,,,,,,,,, -262700,3.246984,1.1692601,,,,,,,,,,,,,, -262800,2.8177958,1.5996735,,,,,,,,,,,,,, -262900,3.2254221,2.910451,,,,,,,,,,,,,, -263000,3.3322523,2.7361312,,,,,,,,,,,,,, -263033,,,0.8878515362739563,0.4189508855342865,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,117676.4786427021,127296.06613850594,117676.4786427021,9590.883177995682,14.8841450214386,0.0 -263100,3.1093593,2.3290389,,,,,,,,,,,,,, -263200,3.0319402,1.5532582,,,,,,,,,,,,,, -263300,2.9166026,1.1115856,,,,,,,,,,,,,, -263400,3.0132947,1.0725166,,,,,,,,,,,,,, -263500,2.9734025,1.8338588,,,,,,,,,,,,,, -263600,3.0756905,1.4752384,,,,,,,,,,,,,, -263700,3.0431328,1.0549057,,,,,,,,,,,,,, -263800,3.7260873,3.2233279,,,,,,,,,,,,,, -263900,3.0719194,1.3961377,,,,,,,,,,,,,, -263973,,,0.8878124952316284,0.4134690463542938,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,118096.52826428412,127751.09358382224,118096.52826428412,9625.748121738434,14.949349880218506,0.0 -264000,3.40193,1.0801079,,,,,,,,,,,,,, -264100,3.6383665,2.9569817,,,,,,,,,,,,,, -264200,3.3797965,1.1588038,,,,,,,,,,,,,, -264300,3.1440623,1.1060476,,,,,,,,,,,,,, -264400,3.0934415,2.6120203,,,,,,,,,,,,,, -264500,2.9485712,1.0891081,,,,,,,,,,,,,, -264600,3.0415568,1.8646578,,,,,,,,,,,,,, -264700,3.0497773,1.0838113,,,,,,,,,,,,,, -264800,2.9329467,1.3285805,,,,,,,,,,,,,, -264900,2.9010541,1.1580145,,,,,,,,,,,,,, -264914,,,0.8861327767372131,0.4168863594532013,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,118516.4486773014,128205.77392411232,118516.4486773014,9660.38653588295,15.022777557373049,0.0 -265000,2.9847713,1.5366452,,,,,,,,,,,,,, -265100,3.3367157,1.191037,,,,,,,,,,,,,, -265200,3.439872,1.1493951,,,,,,,,,,,,,, -265300,3.777587,2.9507904,,,,,,,,,,,,,, -265400,3.338613,1.2023361,,,,,,,,,,,,,, -265500,2.946793,2.352055,,,,,,,,,,,,,, -265600,3.034079,1.3543764,,,,,,,,,,,,,, -265700,3.088335,1.0607296,,,,,,,,,,,,,, -265800,3.2516174,1.0753785,,,,,,,,,,,,,, -265854,,,0.8871288895606995,0.4185587167739868,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,118936.5146408081,128661.9830391407,118936.5146408081,9696.406280517578,15.097080707550049,0.0 -265900,3.1399596,1.0727457,,,,,,,,,,,,,, -266000,2.911426,2.2582815,,,,,,,,,,,,,, -266100,3.242351,1.0898805,,,,,,,,,,,,,, -266200,2.930082,1.0628526,,,,,,,,,,,,,, -266300,2.9451895,1.9246972,,,,,,,,,,,,,, -266400,3.1247847,1.1126275,,,,,,,,,,,,,, -266500,3.9459414,3.2286732,,,,,,,,,,,,,, -266600,3.316429,1.1560682,,,,,,,,,,,,,, -266700,3.2734702,1.402252,,,,,,,,,,,,,, -266795,,,0.8882030844688416,0.4178065359592438,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,119356.6550552845,129118.06162571908,119356.6550552845,9732.229608774183,15.162375926971436,0.0 -266800,3.0808015,1.8708867,,,,,,,,,,,,,, -266900,3.2161,1.3246865,,,,,,,,,,,,,, -267000,3.0086856,2.0644703,,,,,,,,,,,,,, -267100,3.0625296,2.6738904,,,,,,,,,,,,,, -267200,4.056083,3.2736526,,,,,,,,,,,,,, -267300,3.4384983,2.0666819,,,,,,,,,,,,,, -267400,2.9158118,1.0185604,,,,,,,,,,,,,, -267500,3.8008318,3.1641042,,,,,,,,,,,,,, -267600,2.9050403,1.0901083,,,,,,,,,,,,,, -267700,3.442056,2.9974809,,,,,,,,,,,,,, -267734,,,0.887011706829071,0.4181829690933227,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,119776.81697392464,129573.05659985542,119776.81697392464,9766.939259529114,15.23660135269165,0.0 -267800,2.929116,1.1335491,,,,,,,,,,,,,, -267900,3.1063058,1.1079551,,,,,,,,,,,,,, -268000,3.173782,2.5027466,,,,,,,,,,,,,, -268100,3.196318,1.7445868,,,,,,,,,,,,,, -268200,3.0841343,1.7016357,,,,,,,,,,,,,, -268300,3.0931177,1.0962868,,,,,,,,,,,,,, -268400,3.3473535,2.951312,,,,,,,,,,,,,, -268500,3.8778858,3.2789276,,,,,,,,,,,,,, -268600,3.0386512,1.4609448,,,,,,,,,,,,,, -268674,,,0.8915820121765137,0.4057655334472656,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,120196.85601568222,130027.96326994896,120196.85601568222,9801.692220449448,15.302525758743286,0.0 -268700,3.3283484,2.2781894,,,,,,,,,,,,,, -268800,3.768728,3.2326913,,,,,,,,,,,,,, -268900,3.7972221,2.0722468,,,,,,,,,,,,,, -269000,2.9727585,2.1939416,,,,,,,,,,,,,, -269100,3.2396934,1.2309964,,,,,,,,,,,,,, -269200,3.1910818,1.1821524,,,,,,,,,,,,,, -269300,3.1081753,1.1378417,,,,,,,,,,,,,, -269400,3.309978,1.156769,,,,,,,,,,,,,, -269500,2.9659896,1.1440603,,,,,,,,,,,,,, -269600,3.065277,2.4535978,,,,,,,,,,,,,, -269613,,,0.8875976204872131,0.4203981757164001,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,120616.76523089407,130483.94918727876,120616.76523089407,9837.649421453476,15.373100996017456,0.0 -269700,2.928281,2.0080795,,,,,,,,,,,,,, -269800,2.964077,0.97197956,,,,,,,,,,,,,, -269900,2.8054166,2.078331,,,,,,,,,,,,,, -270000,3.024613,1.1636306,,,,,,,,,,,,,, -270100,3.1936603,1.6781101,,,,,,,,,,,,,, -270200,3.2908359,1.3212638,,,,,,,,,,,,,, -270300,3.0982075,1.2005731,,,,,,,,,,,,,, -270400,2.9047754,1.9316455,,,,,,,,,,,,,, -270500,2.9802566,1.0333226,,,,,,,,,,,,,, -270553,,,0.8871679306030273,0.4184862673282623,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,121037.0772664547,130939.53532242776,121037.0772664547,9872.810532808304,15.437132596969604,0.0 -270600,3.7563913,3.3301854,,,,,,,,,,,,,, -270700,3.548045,1.2146842,,,,,,,,,,,,,, -270800,3.7634978,3.0787702,,,,,,,,,,,,,, -270900,3.0297925,1.1719005,,,,,,,,,,,,,, -271000,3.1332436,1.0465477,,,,,,,,,,,,,, -271100,3.1152592,2.250897,,,,,,,,,,,,,, -271200,3.282724,1.6392453,,,,,,,,,,,,,, -271300,3.0412176,1.2358221,,,,,,,,,,,,,, -271400,3.169648,1.5849445,,,,,,,,,,,,,, -271495,,,0.8880859017372131,0.4176699817180633,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,121457.0885810852,131395.90521931648,121457.0885810852,9909.046003103256,15.510629415512083,0.0 -271500,3.9000192,3.186429,,,,,,,,,,,,,, -271600,2.9212563,1.1990663,,,,,,,,,,,,,, -271700,3.00404,1.3241383,,,,,,,,,,,,,, -271800,4.037272,2.966489,,,,,,,,,,,,,, -271900,3.8574378,3.3542647,,,,,,,,,,,,,, -272000,3.120459,1.3118814,,,,,,,,,,,,,, -272100,3.0342717,2.0905566,,,,,,,,,,,,,, -272200,2.9018912,1.7049096,,,,,,,,,,,,,, -272300,2.8403785,1.3385651,,,,,,,,,,,,,, -272400,3.012739,1.1652509,,,,,,,,,,,,,, -272435,,,0.8871484398841858,0.4141007661819458,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,121877.40932011604,131852.0035393238,121877.40932011604,9944.696164131165,15.589151382446287,0.0 -272500,3.4969893,1.6052492,,,,,,,,,,,,,, -272600,3.2329195,1.1528685,,,,,,,,,,,,,, -272700,3.3204532,2.9551845,,,,,,,,,,,,,, -272800,3.1398537,1.0595641,,,,,,,,,,,,,, -272900,3.271868,1.6120986,,,,,,,,,,,,,, -273000,3.4478526,3.0539412,,,,,,,,,,,,,, -273100,2.9847248,1.1858469,,,,,,,,,,,,,, -273200,2.940143,1.1191478,,,,,,,,,,,,,, -273300,2.7801592,1.8076057,,,,,,,,,,,,,, -273376,,,0.887499988079071,0.416747510433197,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,122297.49721193314,132305.79712724686,122297.49721193314,9978.283198833466,15.65884566307068,0.0 -273400,3.1130664,1.7014014,,,,,,,,,,,,,, -273500,2.9637835,1.4805013,,,,,,,,,,,,,, -273600,3.0896754,1.1256368,,,,,,,,,,,,,, -273700,2.9617229,1.0847178,,,,,,,,,,,,,, -273800,3.0171652,1.9417229,,,,,,,,,,,,,, -273900,3.2436678,1.1207342,,,,,,,,,,,,,, -274000,2.8845778,1.9371718,,,,,,,,,,,,,, -274100,3.082764,2.2397978,,,,,,,,,,,,,, -274200,3.0228264,1.0124204,,,,,,,,,,,,,, -274300,3.3019388,1.2239088,,,,,,,,,,,,,, -274315,,,0.888671875,0.4159770011901855,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,122717.44657683372,132762.39725995064,122717.44657683372,10014.80516576767,15.73838758468628,0.0 -274400,3.3585672,3.0340362,,,,,,,,,,,,,, -274500,3.1044693,1.0419905,,,,,,,,,,,,,, -274600,3.5085003,1.1198009,,,,,,,,,,,,,, -274700,3.1884727,1.1146067,,,,,,,,,,,,,, -274800,3.4096441,1.2814354,,,,,,,,,,,,,, -274900,2.9380777,1.0065693,,,,,,,,,,,,,, -275000,3.2516854,1.1121885,,,,,,,,,,,,,, -275100,2.957502,1.119067,,,,,,,,,,,,,, -275200,4.0473404,3.187655,,,,,,,,,,,,,, -275256,,,0.8893945217132568,0.4109589755535126,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,123137.75361943243,133218.33816456795,123137.75361943243,10050.325754642488,15.802060842514038,0.0 -275300,2.997481,1.0618148,,,,,,,,,,,,,, -275400,2.9470177,2.216051,,,,,,,,,,,,,, -275500,3.1117392,1.203207,,,,,,,,,,,,,, -275600,3.1574364,1.0383677,,,,,,,,,,,,,, -275700,3.5891352,2.2544508,,,,,,,,,,,,,, -275800,3.5288873,2.939642,,,,,,,,,,,,,, -275900,3.2528048,1.3591135,,,,,,,,,,,,,, -276000,3.5748365,3.2087522,,,,,,,,,,,,,, -276100,3.3858101,1.1565111,,,,,,,,,,,,,, -276195,,,0.88636714220047,0.4172734320163727,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,123557.85822701454,133672.7080309391,123557.85822701454,10084.47655248642,15.868030309677124,0.0 -276200,2.965934,1.1781158,,,,,,,,,,,,,, -276300,3.8646088,3.2825553,,,,,,,,,,,,,, -276400,3.07704,1.7414539,,,,,,,,,,,,,, -276500,3.068201,1.1169732,,,,,,,,,,,,,, -276600,2.9719026,1.134073,,,,,,,,,,,,,, -276700,2.9940524,1.1571249,,,,,,,,,,,,,, -276800,2.9065046,1.0912006,,,,,,,,,,,,,, -276900,3.2314694,1.2218426,,,,,,,,,,,,,, -277000,5.019539,3.1952612,,,,,,,,,,,,,, -277100,3.1766422,1.1979516,,,,,,,,,,,,,, -277135,,,0.8903319835662842,0.4136055707931518,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,123978.16338539124,134126.56583857536,123978.16338539124,10117.911062717438,15.936080694198608,0.0 -277200,2.9372368,1.1144022,,,,,,,,,,,,,, -277300,2.9649715,1.0883695,,,,,,,,,,,,,, -277400,3.2204778,1.1393728,,,,,,,,,,,,,, -277500,3.2339802,1.2018253,,,,,,,,,,,,,, -277600,2.7931073,1.4199301,,,,,,,,,,,,,, -277700,3.1527684,1.1689731,,,,,,,,,,,,,, -277800,2.9135282,1.101779,,,,,,,,,,,,,, -277900,3.0416436,1.9484398,,,,,,,,,,,,,, -278000,3.231029,1.4175239,,,,,,,,,,,,,, -278075,,,0.8899218440055847,0.4158057570457458,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,124398.14441418648,134580.3505373001,124398.14441418648,10151.590016841888,16.011402368545532,0.0 -278100,3.018898,1.247389,,,,,,,,,,,,,, -278200,3.2248087,1.2178466,,,,,,,,,,,,,, -278300,3.3507,1.1683594,,,,,,,,,,,,,, -278400,3.9593234,3.2583485,,,,,,,,,,,,,, -278500,2.973738,1.2770339,,,,,,,,,,,,,, -278600,3.8598716,3.0783296,,,,,,,,,,,,,, -278700,3.4022586,1.1292446,,,,,,,,,,,,,, -278800,3.8852074,3.2621973,,,,,,,,,,,,,, -278900,3.5454295,1.520199,,,,,,,,,,,,,, -279000,2.9947293,1.1405419,,,,,,,,,,,,,, -279015,,,0.88734370470047,0.408898115158081,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,124818.47934174538,135033.80769109726,124818.47934174538,10184.596621990204,16.07777214050293,0.0 -279100,3.1123135,1.1399335,,,,,,,,,,,,,, -279200,2.979648,2.1030645,,,,,,,,,,,,,, -279300,3.0644133,1.992884,,,,,,,,,,,,,, -279400,3.40776,1.0975084,,,,,,,,,,,,,, -279500,2.8196905,1.5408543,,,,,,,,,,,,,, -279600,3.137899,1.5017682,,,,,,,,,,,,,, -279700,3.3635612,1.1027646,,,,,,,,,,,,,, -279800,3.2060542,1.3430456,,,,,,,,,,,,,, -279900,3.5688941,2.7626297,,,,,,,,,,,,,, -279957,,,0.890429675579071,0.4099819958209991,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,125238.88067746162,135487.19892835617,125238.88067746162,10217.464999198914,16.149641752243042,0.0 -280000,3.1217704,1.0519208,,,,,,,,,,,,,, -280100,3.0014696,1.9156824,,,,,,,,,,,,,, -280200,3.8593228,3.184598,,,,,,,,,,,,,, -280300,3.0925193,1.2644219,,,,,,,,,,,,,, -280400,2.9668133,1.3687243,,,,,,,,,,,,,, -280500,3.101471,2.7135468,,,,,,,,,,,,,, -280600,3.0923758,1.1744981,,,,,,,,,,,,,, -280700,2.8964674,2.3090353,,,,,,,,,,,,,, -280800,3.0862885,1.176326,,,,,,,,,,,,,, -280898,,,0.8888476490974426,0.4090248644351959,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,125658.77555394173,135942.70322799683,125658.77555394173,10252.94473028183,16.230097770690918,0.0 -280900,2.9960601,1.6539396,,,,,,,,,,,,,, -281000,3.2600768,2.709403,,,,,,,,,,,,,, -281100,3.1454551,1.155017,,,,,,,,,,,,,, -281200,2.8461335,1.2188988,,,,,,,,,,,,,, -281300,3.0544043,1.4323542,,,,,,,,,,,,,, -281400,2.9124522,1.3087527,,,,,,,,,,,,,, -281500,3.0488932,1.3157113,,,,,,,,,,,,,, -281600,2.9612138,1.7118822,,,,,,,,,,,,,, -281700,2.8801165,1.1532154,,,,,,,,,,,,,, -281791,,,0.8855664134025574,0.4181103110313415,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,126078.73848581314,136399.6976134777,126078.73848581314,10289.864123106005,16.29589319229126,0.0 -281800,3.6456823,1.0741295,,,,,,,,,,,,,, -281900,3.0923114,1.6412385,,,,,,,,,,,,,, -282000,3.1250513,1.9658794,,,,,,,,,,,,,, -282100,3.7091985,3.0995343,,,,,,,,,,,,,, -282200,3.1005008,1.137386,,,,,,,,,,,,,, -282300,3.0347025,1.1211104,,,,,,,,,,,,,, -282400,3.382623,1.1628034,,,,,,,,,,,,,, -282500,3.1948476,1.1776057,,,,,,,,,,,,,, -282600,2.8608968,1.9176894,,,,,,,,,,,,,, -282700,3.1470704,1.388697,,,,,,,,,,,,,, -282733,,,0.8891015648841858,0.4115650951862335,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,126498.94877243042,136855.63595891,126498.94877243042,10325.465045690536,16.3734393119812,0.0 -282800,3.0675237,2.31585,,,,,,,,,,,,,, -282900,3.0217295,2.309125,,,,,,,,,,,,,, -283000,2.998042,1.3458172,,,,,,,,,,,,,, -283100,2.7969902,1.5571786,,,,,,,,,,,,,, -283200,3.9751513,3.2999403,,,,,,,,,,,,,, -283300,3.4455075,1.128416,,,,,,,,,,,,,, -283400,3.0827847,1.3372568,,,,,,,,,,,,,, -283500,3.3711185,1.5557445,,,,,,,,,,,,,, -283600,3.6090243,1.1016078,,,,,,,,,,,,,, -283671,,,0.8883788585662842,0.4134488999843597,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,126919.20204520226,137309.89725899696,126919.20204520226,10359.358072280884,16.440247058868408,0.0 -283700,3.9041495,3.2348187,,,,,,,,,,,,,, -283800,3.1111314,1.7114316,,,,,,,,,,,,,, -283900,3.191164,1.9917643,,,,,,,,,,,,,, -284000,3.2027838,1.0788655,,,,,,,,,,,,,, -284100,3.1419911,1.1118332,,,,,,,,,,,,,, -284200,2.9708536,1.7359908,,,,,,,,,,,,,, -284300,3.9693122,3.050078,,,,,,,,,,,,,, -284400,2.8386257,1.4637322,,,,,,,,,,,,,, -284500,3.3546162,2.7257895,,,,,,,,,,,,,, -284600,2.8856564,1.0000165,,,,,,,,,,,,,, -284611,,,0.8886132836341858,0.4173832535743713,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,127339.28399133682,137766.33781647682,127339.28399133682,10395.601438760756,16.506649494171143,0.0 -284700,3.1130977,1.105968,,,,,,,,,,,,,, -284800,2.974227,1.1180122,,,,,,,,,,,,,, -284900,3.050346,1.1392359,,,,,,,,,,,,,, -285000,3.5176337,1.109902,,,,,,,,,,,,,, -285100,3.040417,1.1303512,,,,,,,,,,,,,, -285200,2.9508758,1.9136021,,,,,,,,,,,,,, -285300,3.3763485,2.2524102,,,,,,,,,,,,,, -285400,3.073109,1.1618214,,,,,,,,,,,,,, -285500,2.940351,1.0960519,,,,,,,,,,,,,, -285552,,,0.88783198595047,0.4149629175662994,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,127759.24665808678,138220.6668958664,127759.24665808678,10429.8513276577,16.57482123374939,0.0 -285600,3.099992,1.1575937,,,,,,,,,,,,,, -285700,3.0781524,1.0980353,,,,,,,,,,,,,, -285800,3.102329,1.1227239,,,,,,,,,,,,,, -285900,3.0657918,1.6712795,,,,,,,,,,,,,, -286000,3.249867,1.1948637,,,,,,,,,,,,,, -286100,3.1490722,1.0988472,,,,,,,,,,,,,, -286200,3.6098669,3.021826,,,,,,,,,,,,,, -286300,3.0488198,1.0997286,,,,,,,,,,,,,, -286400,3.1607437,1.4911978,,,,,,,,,,,,,, -286492,,,0.8872656226158142,0.417816162109375,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,128179.28594827652,138676.82459497452,128179.28594827652,10465.854510307312,16.641371726989746,0.0 -286500,2.972207,1.0699109,,,,,,,,,,,,,, -286600,3.0714784,1.1514692,,,,,,,,,,,,,, -286700,3.1415393,2.0587044,,,,,,,,,,,,,, -286800,3.19127,2.7165267,,,,,,,,,,,,,, -286900,3.4924061,3.0743246,,,,,,,,,,,,,, -287000,3.382462,1.4651902,,,,,,,,,,,,,, -287100,3.6166987,3.1286426,,,,,,,,,,,,,, -287200,2.9822316,1.0182794,,,,,,,,,,,,,, -287300,3.0722878,1.1601897,,,,,,,,,,,,,, -287400,3.2099555,2.998724,,,,,,,,,,,,,, -287433,,,0.88832026720047,0.4100013673305511,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,128598.95075941086,139130.64910507202,128598.95075941086,10499.648500919342,16.957338094711304,0.0 -287500,3.046419,2.2292113,,,,,,,,,,,,,, -287600,3.6518638,3.0785723,,,,,,,,,,,,,, -287700,3.093175,1.1613376,,,,,,,,,,,,,, -287800,3.0619507,1.4313024,,,,,,,,,,,,,, -287900,3.220472,1.4281211,,,,,,,,,,,,,, -288000,3.6671414,3.2228203,,,,,,,,,,,,,, -288100,3.275249,1.3001144,,,,,,,,,,,,,, -288200,3.0887551,2.038178,,,,,,,,,,,,,, -288300,4.1716933,3.0953922,,,,,,,,,,,,,, -288369,,,0.8876171708106995,0.4147033095359802,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,129019.15416574478,139586.47339582443,129019.15416574478,10535.140134096146,17.03790783882141,0.0 -288400,2.9231553,1.0713328,,,,,,,,,,,,,, -288500,3.0619106,1.7960515,,,,,,,,,,,,,, -288600,3.5925286,1.0637045,,,,,,,,,,,,,, -288700,2.9799447,2.2439795,,,,,,,,,,,,,, -288800,3.242688,1.1544049,,,,,,,,,,,,,, -288900,3.0285678,1.2424253,,,,,,,,,,,,,, -289000,2.8148272,1.6140941,,,,,,,,,,,,,, -289100,3.4236434,1.4394836,,,,,,,,,,,,,, -289200,2.942559,2.08125,,,,,,,,,,,,,, -289300,3.000544,1.1351414,,,,,,,,,,,,,, -289307,,,0.887011706829071,0.4225472211837768,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,129439.41557979584,140040.8846886158,129439.41557979584,10569.173550605774,17.106011629104614,0.0 -289400,2.947184,1.2535,,,,,,,,,,,,,, -289500,2.943832,1.3351178,,,,,,,,,,,,,, -289600,3.001431,1.0180917,,,,,,,,,,,,,, -289700,3.9758804,2.4137783,,,,,,,,,,,,,, -289800,3.1542082,1.1567976,,,,,,,,,,,,,, -289900,3.0283165,2.1307817,,,,,,,,,,,,,, -290000,3.1442258,1.0054437,,,,,,,,,,,,,, -290100,3.4747756,2.9398828,,,,,,,,,,,,,, -290200,3.2289207,1.1700269,,,,,,,,,,,,,, -290246,,,0.8878515362739563,0.4139612913131714,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,129859.462069273,140494.82578778267,129859.462069273,10602.943563699722,17.181469678878784,0.0 -290300,3.3168142,2.7658734,,,,,,,,,,,,,, -290400,3.0315678,1.1420292,,,,,,,,,,,,,, -290500,2.9855626,1.0644046,,,,,,,,,,,,,, -290600,3.8930106,3.180058,,,,,,,,,,,,,, -290700,3.081497,1.9983582,,,,,,,,,,,,,, -290800,3.0384529,2.3269126,,,,,,,,,,,,,, -290900,2.946692,1.371799,,,,,,,,,,,,,, -291000,3.4651234,3.069737,,,,,,,,,,,,,, -291100,3.3002229,1.184409,,,,,,,,,,,,,, -291184,,,0.8871288895606995,0.4143074452877044,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,130279.55357980728,140948.2878472805,130279.55357980728,10636.189188480375,17.256900310516357,0.0 -291200,3.009253,1.3955414,,,,,,,,,,,,,, -291300,3.2471023,2.7125845,,,,,,,,,,,,,, -291400,2.8563168,1.2414904,,,,,,,,,,,,,, -291500,2.7748857,1.8103853,,,,,,,,,,,,,, -291600,3.238521,1.2083446,,,,,,,,,,,,,, -291700,3.0424225,1.9091772,,,,,,,,,,,,,, -291800,4.2347445,3.193839,,,,,,,,,,,,,, -291900,2.9129019,1.1990101,,,,,,,,,,,,,, -292000,2.7209547,1.4640176,,,,,,,,,,,,,, -292100,3.0243907,1.1354008,,,,,,,,,,,,,, -292121,,,0.8910741806030273,0.4123877286911011,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,130699.70896553992,141403.94225287437,130699.70896553992,10671.554021835327,17.342145442962646,0.0 -292200,3.4022615,1.1928405,,,,,,,,,,,,,, -292300,3.251649,2.3856554,,,,,,,,,,,,,, -292400,3.25816,1.1574305,,,,,,,,,,,,,, -292500,3.46087,1.0375543,,,,,,,,,,,,,, -292600,3.077093,2.095385,,,,,,,,,,,,,, -292700,3.6681533,1.115752,,,,,,,,,,,,,, -292800,4.028702,3.2808502,,,,,,,,,,,,,, -292900,3.539166,1.0798779,,,,,,,,,,,,,, -293000,3.1021945,1.2950023,,,,,,,,,,,,,, -293060,,,0.8864648342132568,0.422519326210022,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,131119.73110198975,141859.72695946693,131119.73110198975,10707.190835475922,17.419295072555542,0.0 -293100,3.0154839,1.606877,,,,,,,,,,,,,, -293200,3.1799517,1.2108967,,,,,,,,,,,,,, -293300,3.1496136,1.0871412,,,,,,,,,,,,,, -293400,4.147153,3.289767,,,,,,,,,,,,,, -293500,3.380932,2.9306726,,,,,,,,,,,,,, -293600,3.0290825,1.463086,,,,,,,,,,,,,, -293700,3.1210525,2.696454,,,,,,,,,,,,,, -293800,3.0290914,1.7849951,,,,,,,,,,,,,, -293900,2.6918025,1.8217931,,,,,,,,,,,,,, -293998,,,0.8893359303474426,0.4128582775592804,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,131539.8347415924,142315.80940794945,131539.8347415924,10743.051213026049,17.48875403404236,0.0 -294000,3.0817158,2.370307,,,,,,,,,,,,,, -294100,3.0317662,1.2954848,,,,,,,,,,,,,, -294200,3.111971,1.1794614,,,,,,,,,,,,,, -294300,3.3225892,2.8260221,,,,,,,,,,,,,, -294400,2.915379,2.3438492,,,,,,,,,,,,,, -294500,2.9624095,1.1394333,,,,,,,,,,,,,, -294600,2.9135494,1.085046,,,,,,,,,,,,,, -294700,3.1407313,1.085819,,,,,,,,,,,,,, -294800,3.049798,2.171017,,,,,,,,,,,,,, -294900,3.2561808,1.0691051,,,,,,,,,,,,,, -294933,,,0.8885351419448853,0.4156334996223449,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,131959.82090997696,142771.27851748466,131959.82090997696,10778.418419837952,17.556360244750977,0.0 -295000,3.2434149,1.1958756,,,,,,,,,,,,,, -295100,2.8939145,1.2471893,,,,,,,,,,,,,, -295200,3.432968,1.0967591,,,,,,,,,,,,,, -295300,3.2160926,1.2079755,,,,,,,,,,,,,, -295400,3.1414466,2.4051712,,,,,,,,,,,,,, -295500,2.8272593,1.2345755,,,,,,,,,,,,,, -295600,2.9895933,1.3420858,,,,,,,,,,,,,, -295700,3.415217,1.0863363,,,,,,,,,,,,,, -295800,3.2228992,1.6740085,,,,,,,,,,,,,, -295873,,,0.8855273127555847,0.4239533543586731,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,132379.98327803612,143225.35876965523,132379.98327803612,10812.218878507614,17.62444758415222,0.0 -295900,4.071809,3.2728693,,,,,,,,,,,,,, -296000,3.1585715,2.7891886,,,,,,,,,,,,,, -296100,3.2670925,2.4361165,,,,,,,,,,,,,, -296200,3.1487775,1.1242397,,,,,,,,,,,,,, -296300,3.9483132,3.1187735,,,,,,,,,,,,,, -296400,3.08085,1.4622672,,,,,,,,,,,,,, -296500,3.1102593,2.0925202,,,,,,,,,,,,,, -296600,3.0686555,2.4201112,,,,,,,,,,,,,, -296700,3.254073,1.2829164,,,,,,,,,,,,,, -296800,3.0912845,1.3949494,,,,,,,,,,,,,, -296809,,,0.8883007764816284,0.4127146899700165,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,132800.25978302956,143681.6864824295,132800.25978302956,10848.138242006302,17.70807909965515,0.0 -296900,3.1427188,1.0797453,,,,,,,,,,,,,, -297000,3.2840016,2.8427224,,,,,,,,,,,,,, -297100,2.9950397,1.1645133,,,,,,,,,,,,,, -297200,3.08299,2.2900739,,,,,,,,,,,,,, -297300,3.0284257,1.5415041,,,,,,,,,,,,,, -297400,3.1763427,1.1943455,,,,,,,,,,,,,, -297500,3.2928662,1.3534188,,,,,,,,,,,,,, -297600,3.1456783,1.1642375,,,,,,,,,,,,,, -297700,2.9296181,1.9368262,,,,,,,,,,,,,, -297749,,,0.8908202648162842,0.405164510011673,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,133220.3181180954,144136.4060664177,133220.3181180954,10882.680842399595,17.776843547821045,0.0 -297800,3.2120388,1.2075565,,,,,,,,,,,,,, -297900,2.9877105,2.1490047,,,,,,,,,,,,,, -298000,3.0583322,1.140209,,,,,,,,,,,,,, -298100,3.0698545,1.7467587,,,,,,,,,,,,,, -298200,3.0476034,1.1155773,,,,,,,,,,,,,, -298300,2.8987195,1.1523193,,,,,,,,,,,,,, -298400,3.4998057,1.3404973,,,,,,,,,,,,,, -298500,3.2049289,2.7211773,,,,,,,,,,,,,, -298600,2.8406887,2.0568438,,,,,,,,,,,,,, -298691,,,0.8869335651397705,0.4187902808189392,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,133640.38061594963,144590.28002810478,133640.38061594963,10916.362850904465,17.857324361801147,0.0 -298700,2.998033,1.8603847,,,,,,,,,,,,,, -298800,3.0776033,1.0613672,,,,,,,,,,,,,, -298900,3.136129,1.1491792,,,,,,,,,,,,,, -299000,4.156149,3.1374984,,,,,,,,,,,,,, -299100,3.6939187,3.0392406,,,,,,,,,,,,,, -299200,2.9002016,1.0472313,,,,,,,,,,,,,, -299300,3.6400683,3.165633,,,,,,,,,,,,,, -299400,3.9807522,3.194117,,,,,,,,,,,,,, -299500,3.1494582,1.2228808,,,,,,,,,,,,,, -299600,3.1939652,1.9104055,,,,,,,,,,,,,, -299630,,,0.8880468606948853,0.4154566526412964,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,134060.37126111984,145043.9550564289,134060.37126111984,10949.912775278091,17.94206404685974,0.0 -299700,3.2737582,1.2032499,,,,,,,,,,,,,, -299800,3.091124,1.0637167,,,,,,,,,,,,,, -299900,3.1803076,1.885647,,,,,,,,,,,,,, -300000,2.9526966,1.2556043,,,,,,,,,,,,,, -300100,3.5561242,2.1992116,,,,,,,,,,,,,, -300200,3.6961076,3.1041307,,,,,,,,,,,,,, -300300,2.8732364,1.3431215,,,,,,,,,,,,,, -300400,3.129171,2.5426147,,,,,,,,,,,,,, -300500,3.1247,1.2449783,,,,,,,,,,,,,, -300572,,,0.8863085508346558,0.4205401539802551,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,134480.29626059532,145498.15152049065,134480.29626059532,10984.052475690842,18.025188207626343,0.0 -300600,3.2914555,2.9682517,,,,,,,,,,,,,, -300700,3.1002963,1.1192306,,,,,,,,,,,,,, -300800,2.9130948,1.1803021,,,,,,,,,,,,,, -300900,3.1171422,1.1779934,,,,,,,,,,,,,, -301000,3.0005696,1.2545949,,,,,,,,,,,,,, -301100,2.973794,1.0966915,,,,,,,,,,,,,, -301200,3.4539013,1.0916973,,,,,,,,,,,,,, -301300,4.0289493,3.1608996,,,,,,,,,,,,,, -301400,3.3346305,1.3262771,,,,,,,,,,,,,, -301500,3.0565877,1.8277838,,,,,,,,,,,,,, -301509,,,0.8903906345367432,0.4110255241394043,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,134900.26795220375,145952.77487421036,134900.26795220375,11018.582464933395,18.09760928153992,0.0 -301600,2.929861,1.2805347,,,,,,,,,,,,,, -301700,3.17683,1.0614053,,,,,,,,,,,,,, -301800,2.893688,1.1079507,,,,,,,,,,,,,, -301900,3.4341133,1.3512907,,,,,,,,,,,,,, -302000,2.8729262,1.8855307,,,,,,,,,,,,,, -302100,3.284352,1.9576688,,,,,,,,,,,,,, -302200,3.4569738,2.6316447,,,,,,,,,,,,,, -302300,2.923302,1.0800889,,,,,,,,,,,,,, -302400,3.0445855,1.5382587,,,,,,,,,,,,,, -302451,,,0.8894921541213989,0.4096398055553436,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,135320.3759112358,146409.99482226372,135320.3759112358,11055.576867341995,18.16519927978516,0.0 -302500,3.3982837,1.2060382,,,,,,,,,,,,,, -302600,3.2788146,2.2344275,,,,,,,,,,,,,, -302700,2.9064817,1.0299324,,,,,,,,,,,,,, -302800,3.1315596,1.0594108,,,,,,,,,,,,,, -302900,2.9939735,1.1422575,,,,,,,,,,,,,, -303000,3.1628053,1.9003321,,,,,,,,,,,,,, -303100,3.24469,1.0848382,,,,,,,,,,,,,, -303200,3.283198,1.7834072,,,,,,,,,,,,,, -303300,3.535164,1.3016645,,,,,,,,,,,,,, -303392,,,0.88832026720047,0.413810133934021,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,135740.50975704193,146865.0077443123,135740.50975704193,11090.334149360657,18.23756146430969,0.0 -303400,3.7492826,2.9394393,,,,,,,,,,,,,, -303500,3.3108072,1.0773101,,,,,,,,,,,,,, -303600,3.2266452,2.8117402,,,,,,,,,,,,,, -303700,2.9842007,1.0684048,,,,,,,,,,,,,, -303800,2.9442577,1.3772892,,,,,,,,,,,,,, -303900,4.1028485,2.8825195,,,,,,,,,,,,,, -304000,3.2872822,1.6042343,,,,,,,,,,,,,, -304100,3.5985484,3.254752,,,,,,,,,,,,,, -304200,2.7262304,1.5843008,,,,,,,,,,,,,, -304300,3.1799643,1.1388791,,,,,,,,,,,,,, -304333,,,0.8908593654632568,0.4046581983566284,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,136160.44328808784,147321.12025094032,136160.44328808784,11126.382422208786,18.318288803100582,0.0 -304400,3.365479,1.048449,,,,,,,,,,,,,, -304500,3.23749,1.0622123,,,,,,,,,,,,,, -304600,2.981111,2.251156,,,,,,,,,,,,,, -304700,2.97826,1.6920582,,,,,,,,,,,,,, -304800,2.7908862,1.1040676,,,,,,,,,,,,,, -304900,4.0306907,3.2063055,,,,,,,,,,,,,, -305000,3.2563336,1.1972312,,,,,,,,,,,,,, -305100,3.9449997,3.1414657,,,,,,,,,,,,,, -305200,3.1685214,1.1515883,,,,,,,,,,,,,, -305272,,,0.885546863079071,0.4197524785995483,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,136580.41247677803,147776.61046028137,136580.41247677803,11161.775073289871,18.397958040237427,0.0 -305300,3.2904286,1.134023,,,,,,,,,,,,,, -305400,3.2014284,2.9170246,,,,,,,,,,,,,, -305500,2.9553192,1.4832284,,,,,,,,,,,,,, -305600,3.4106126,1.130266,,,,,,,,,,,,,, -305700,3.1484845,1.1360351,,,,,,,,,,,,,, -305800,3.345238,1.1523914,,,,,,,,,,,,,, -305900,3.1237986,2.005512,,,,,,,,,,,,,, -306000,3.0294893,1.0849462,,,,,,,,,,,,,, -306100,3.854679,3.3071163,,,,,,,,,,,,,, -306200,2.847922,1.642833,,,,,,,,,,,,,, -306211,,,0.8883788585662842,0.4131350219249725,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,137000.47409558296,148231.98989629743,137000.47409558296,11196.973927736282,18.467971086502075,0.0 -306300,2.9671125,1.1308821,,,,,,,,,,,,,, -306400,3.577966,3.2212834,,,,,,,,,,,,,, -306500,2.8635416,1.6519953,,,,,,,,,,,,,, -306600,3.4685688,1.3521166,,,,,,,,,,,,,, -306700,2.9694157,1.882083,,,,,,,,,,,,,, -306800,2.849672,2.1650863,,,,,,,,,,,,,, -306900,3.761651,1.0717493,,,,,,,,,,,,,, -307000,3.3243012,1.1182516,,,,,,,,,,,,,, -307100,3.9130657,3.272474,,,,,,,,,,,,,, -307151,,,0.8883984088897705,0.4133432805538177,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,137420.50700616837,148687.06598305702,137420.50700616837,11231.89693236351,18.539098262786865,0.0 -307200,3.2599697,1.1307126,,,,,,,,,,,,,, -307300,3.1546772,1.7141412,,,,,,,,,,,,,, -307400,3.661028,3.1402743,,,,,,,,,,,,,, -307500,3.0829256,1.4104363,,,,,,,,,,,,,, -307600,2.9356527,2.273036,,,,,,,,,,,,,, -307700,3.0635667,2.4342628,,,,,,,,,,,,,, -307800,4.71205,3.0845833,,,,,,,,,,,,,, -307900,3.0197268,1.1245886,,,,,,,,,,,,,, -308000,3.355809,1.1032097,,,,,,,,,,,,,, -308089,,,0.8885155916213989,0.4139094352722168,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,137840.77836227417,149142.5724697113,137840.77836227417,11267.01406097412,18.608031034469604,0.0 -308100,3.2607698,1.025677,,,,,,,,,,,,,, -308200,3.166804,1.2135952,,,,,,,,,,,,,, -308300,3.043619,1.8290501,,,,,,,,,,,,,, -308400,3.3394046,2.4292963,,,,,,,,,,,,,, -308500,4.050168,3.1348472,,,,,,,,,,,,,, -308600,2.9038312,1.4460671,,,,,,,,,,,,,, -308700,3.4178426,2.8503993,,,,,,,,,,,,,, -308800,2.7893045,1.821901,,,,,,,,,,,,,, -308900,2.8558166,1.3385231,,,,,,,,,,,,,, -309000,3.1358292,1.111852,,,,,,,,,,,,,, -309029,,,0.8866991996765137,0.4185129106044769,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,138260.96691298485,149598.50726366043,138260.96691298485,11302.639257907867,18.680636405944824,0.0 -309100,2.9688647,1.3404068,,,,,,,,,,,,,, -309200,3.9936876,3.1535158,,,,,,,,,,,,,, -309300,3.0156736,0.97477186,,,,,,,,,,,,,, -309400,2.722213,1.7154953,,,,,,,,,,,,,, -309500,3.155674,1.1608436,,,,,,,,,,,,,, -309600,3.0540953,1.9776864,,,,,,,,,,,,,, -309700,3.146843,1.591112,,,,,,,,,,,,,, -309800,3.0286653,1.0529809,,,,,,,,,,,,,, -309900,3.152329,1.1673977,,,,,,,,,,,,,, -309968,,,0.8897656202316284,0.4106150865554809,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,138681.03453612328,150051.7492189407,138681.03453612328,11335.685423135756,18.76006007194519,0.0 -310000,2.904399,1.0359131,,,,,,,,,,,,,, -310100,3.1053727,1.2343932,,,,,,,,,,,,,, -310200,3.71503,3.1324131,,,,,,,,,,,,,, -310300,3.1566877,1.1988802,,,,,,,,,,,,,, -310400,2.9514167,1.3060384,,,,,,,,,,,,,, -310500,3.123286,2.2830005,,,,,,,,,,,,,, -310600,3.7069187,1.185662,,,,,,,,,,,,,, -310700,3.0269957,1.1454524,,,,,,,,,,,,,, -310800,3.0072649,2.080393,,,,,,,,,,,,,, -310900,3.0278525,1.7545931,,,,,,,,,,,,,, -310907,,,0.8888671398162842,0.4129617512226105,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,139101.31498932838,150505.21512961388,139101.31498932838,11368.736935138702,18.845513105392456,0.0 -311000,2.9081793,1.5157057,,,,,,,,,,,,,, -311100,3.3347428,1.2290144,,,,,,,,,,,,,, -311200,3.0375295,2.3590174,,,,,,,,,,,,,, -311300,3.5391772,2.751518,,,,,,,,,,,,,, -311400,3.2148247,2.563034,,,,,,,,,,,,,, -311500,3.032095,1.352961,,,,,,,,,,,,,, -311600,3.7328973,2.672682,,,,,,,,,,,,,, -311700,3.6246622,3.1742735,,,,,,,,,,,,,, -311800,3.5260322,1.4640825,,,,,,,,,,,,,, -311842,,,0.8851171731948853,0.4197098910808563,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,139521.58361434937,150959.9334256649,139521.58361434937,11403.054658412932,18.928631067276,0.0 -311900,3.2765071,1.1555326,,,,,,,,,,,,,, -312000,3.0893917,1.2185574,,,,,,,,,,,,,, -312100,3.0964568,2.772018,,,,,,,,,,,,,, -312200,2.96597,1.9578971,,,,,,,,,,,,,, -312300,2.9573264,1.0910755,,,,,,,,,,,,,, -312400,3.3652043,1.422874,,,,,,,,,,,,,, -312500,3.3041294,1.2460685,,,,,,,,,,,,,, -312600,3.061737,1.1083738,,,,,,,,,,,,,, -312700,3.2715836,1.0450037,,,,,,,,,,,,,, -312780,,,0.8889452815055847,0.4150924682617187,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,139941.93662166595,151414.4565434456,139941.93662166595,11437.10601568222,18.999311447143555,0.0 -312800,3.0005522,1.0947499,,,,,,,,,,,,,, -312900,3.2492158,1.076599,,,,,,,,,,,,,, -313000,3.0392427,1.1322265,,,,,,,,,,,,,, -313100,3.3677812,1.0964036,,,,,,,,,,,,,, -313200,3.2623467,1.1313331,,,,,,,,,,,,,, -313300,3.5104854,3.0149984,,,,,,,,,,,,,, -313400,2.901514,1.5256482,,,,,,,,,,,,,, -313500,3.1300604,1.24563,,,,,,,,,,,,,, -313600,3.1393635,1.2600633,,,,,,,,,,,,,, -313700,3.0300026,1.2542727,,,,,,,,,,,,,, -313722,,,0.88734370470047,0.4192501008510589,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,140362.2047946453,151868.0457689762,140362.2047946453,11470.306761026382,19.070598363876343,0.0 -313800,3.0956633,2.5121794,,,,,,,,,,,,,, -313900,3.1360114,1.2327075,,,,,,,,,,,,,, -314000,3.3307078,1.6314561,,,,,,,,,,,,,, -314100,3.1515243,1.1773713,,,,,,,,,,,,,, -314200,3.3950024,2.7457418,,,,,,,,,,,,,, -314300,3.0424001,1.1075766,,,,,,,,,,,,,, -314400,3.0873752,1.1001483,,,,,,,,,,,,,, -314500,3.2170937,1.143301,,,,,,,,,,,,,, -314600,3.503748,2.9222376,,,,,,,,,,,,,, -314662,,,0.887011706829071,0.4187501668930053,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,140782.08658885956,152323.69007754326,140782.08658885956,11505.935393810272,19.155100107193,0.0 -314700,3.3117547,2.5609326,,,,,,,,,,,,,, -314800,3.2131054,2.523427,,,,,,,,,,,,,, -314900,3.8322556,3.088307,,,,,,,,,,,,,, -315000,3.0531042,1.1221302,,,,,,,,,,,,,, -315100,3.0920022,1.1394792,,,,,,,,,,,,,, -315200,2.880316,2.1166463,,,,,,,,,,,,,, -315300,3.1762412,1.0954978,,,,,,,,,,,,,, -315400,3.437935,2.981215,,,,,,,,,,,,,, -315500,2.8827343,1.1099057,,,,,,,,,,,,,, -315600,3.093256,2.168471,,,,,,,,,,,,,, -315601,,,0.8889257907867432,0.4117574989795685,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,141202.2172217369,152779.8137331009,141202.2172217369,11541.796007156372,19.238603115081787,0.0 -315700,3.7485561,3.2111974,,,,,,,,,,,,,, -315800,3.062414,1.0818926,,,,,,,,,,,,,, -315900,3.2429354,1.0830728,,,,,,,,,,,,,, -316000,3.2984905,2.969037,,,,,,,,,,,,,, -316100,3.0391614,1.1559187,,,,,,,,,,,,,, -316200,3.2550025,1.3486212,,,,,,,,,,,,,, -316300,3.1125045,1.1621536,,,,,,,,,,,,,, -316400,3.3988912,1.1580718,,,,,,,,,,,,,, -316500,3.251594,2.1280432,,,,,,,,,,,,,, -316542,,,0.8873632550239563,0.4215485751628876,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,141622.4211435318,153233.88162970543,141622.4211435318,11575.5288336277,19.32027697563172,0.0 -316600,3.098012,1.1618446,,,,,,,,,,,,,, -316700,3.023208,1.1555629,,,,,,,,,,,,,, -316800,3.2934554,1.6132004,,,,,,,,,,,,,, -316900,3.0912976,1.7504203,,,,,,,,,,,,,, -317000,3.2973254,2.7426314,,,,,,,,,,,,,, -317100,2.9846435,1.2668717,,,,,,,,,,,,,, -317200,3.3129208,2.834159,,,,,,,,,,,,,, -317300,3.4149442,1.1213895,,,,,,,,,,,,,, -317400,3.302477,2.9836838,,,,,,,,,,,,,, -317477,,,0.88929682970047,0.4097913205623626,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,142042.42588067055,153687.1012263298,142042.42588067055,11608.605984210968,19.408986806869507,0.0 -317500,3.046851,1.5430685,,,,,,,,,,,,,, -317600,2.8817103,1.8656291,,,,,,,,,,,,,, -317700,3.1768713,1.114622,,,,,,,,,,,,,, -317800,3.2491841,2.3132634,,,,,,,,,,,,,, -317900,2.9455025,1.8832791,,,,,,,,,,,,,, -318000,2.9676855,1.185296,,,,,,,,,,,,,, -318100,3.0997362,1.0656108,,,,,,,,,,,,,, -318200,4.2337275,3.2659361,,,,,,,,,,,,,, -318300,3.08293,1.0821121,,,,,,,,,,,,,, -318400,2.9649174,1.0474993,,,,,,,,,,,,,, -318414,,,0.8886523246765137,0.4203369319438934,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,142462.3735461235,154142.40728020668,142462.3735461235,11643.827334403992,19.49708747863769,0.0 -318500,3.28492,1.1914489,,,,,,,,,,,,,, -318600,2.8112426,1.0437181,,,,,,,,,,,,,, -318700,2.8861692,1.6713499,,,,,,,,,,,,,, -318800,3.5446546,3.0761077,,,,,,,,,,,,,, -318900,3.4421725,1.0362568,,,,,,,,,,,,,, -319000,3.3809774,1.108615,,,,,,,,,,,,,, -319100,2.992311,1.288813,,,,,,,,,,,,,, -319200,3.8268151,3.2287197,,,,,,,,,,,,,, -319300,3.12929,1.5281639,,,,,,,,,,,,,, -319357,,,0.8871288895606995,0.4173811674118042,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,142882.58603477478,154597.0007481575,142882.58603477478,11678.085348844528,19.57066559791565,0.0 -319400,2.9523172,1.2939022,,,,,,,,,,,,,, -319500,3.2813797,2.6328719,,,,,,,,,,,,,, -319600,3.136317,1.1532154,,,,,,,,,,,,,, -319700,3.299366,3.0552852,,,,,,,,,,,,,, -319800,3.124056,1.3659415,,,,,,,,,,,,,, -319900,3.0796452,1.1259308,,,,,,,,,,,,,, -320000,2.99629,2.265302,,,,,,,,,,,,,, -320100,3.0801136,1.3482944,,,,,,,,,,,,,, -320200,3.2660897,1.1434813,,,,,,,,,,,,,, -320299,,,0.8873046636581421,0.4134297370910644,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,143302.8230383396,155054.7192568779,143302.8230383396,11715.436604499817,19.6518189907074,0.0 -320300,2.891795,1.4661306,,,,,,,,,,,,,, -320400,3.0552058,0.986471,,,,,,,,,,,,,, -320500,3.69319,3.161871,,,,,,,,,,,,,, -320600,3.102819,1.4239564,,,,,,,,,,,,,, -320700,2.9392672,1.6502655,,,,,,,,,,,,,, -320800,3.0751693,1.1743807,,,,,,,,,,,,,, -320900,2.9089158,1.8698814,,,,,,,,,,,,,, -321000,3.0331182,1.8239375,,,,,,,,,,,,,, -321100,3.3695548,1.2127515,,,,,,,,,,,,,, -321200,3.1116085,1.2838812,,,,,,,,,,,,,, -321239,,,0.8871288895606995,0.4203206896781921,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,143722.71570396423,155510.06215786934,143722.71570396423,11750.764872550964,19.725261211395264,0.0 -321300,3.0402684,1.1036834,,,,,,,,,,,,,, -321400,3.644707,3.0032666,,,,,,,,,,,,,, -321500,2.7795553,1.2075431,,,,,,,,,,,,,, -321600,3.255471,2.145141,,,,,,,,,,,,,, -321700,2.9443862,1.5021919,,,,,,,,,,,,,, -321800,3.1794543,2.2219224,,,,,,,,,,,,,, -321900,3.2271461,1.3066555,,,,,,,,,,,,,, -322000,5.0249095,1.1558504,,,,,,,,,,,,,, -322100,3.4938555,3.004403,,,,,,,,,,,,,, -322176,,,0.8892187476158142,0.4106161892414093,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,144142.85748529434,155963.2437081337,144142.85748529434,11783.672978162766,19.807963609695435,0.0 -322200,3.0959756,1.4339623,,,,,,,,,,,,,, -322300,4.209433,3.2735255,,,,,,,,,,,,,, -322400,3.1320775,1.2257979,,,,,,,,,,,,,, -322500,3.2474236,1.1126286,,,,,,,,,,,,,, -322600,3.0175836,1.1612332,,,,,,,,,,,,,, -322700,3.1824603,1.1651293,,,,,,,,,,,,,, -322800,3.3639014,2.8963697,,,,,,,,,,,,,, -322900,3.3546305,1.118319,,,,,,,,,,,,,, -323000,3.185636,1.4069862,,,,,,,,,,,,,, -323100,3.2110434,1.210915,,,,,,,,,,,,,, -323113,,,0.8874218463897705,0.414051741361618,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,144562.77267432213,156419.39860010147,144562.77267432213,11819.779586553574,19.89247989654541,0.0 -323200,3.1856883,1.2213615,,,,,,,,,,,,,, -323300,3.4859657,2.901524,,,,,,,,,,,,,, -323400,3.0030236,1.191607,,,,,,,,,,,,,, -323500,3.2447736,1.8187407,,,,,,,,,,,,,, -323600,3.2315426,1.0851462,,,,,,,,,,,,,, -323700,3.2510166,1.1257458,,,,,,,,,,,,,, -323800,2.9447668,1.4767456,,,,,,,,,,,,,, -323900,3.3347158,2.846484,,,,,,,,,,,,,, -324000,4.089287,3.1950495,,,,,,,,,,,,,, -324054,,,0.8886327743530273,0.416003555059433,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,144983.06563448906,156874.05456399918,144983.06563448906,11854.017767190931,19.96893191337585,0.0 -324100,3.0971022,1.2852491,,,,,,,,,,,,,, -324200,4.7640157,1.1680561,,,,,,,,,,,,,, -324300,3.0110745,1.2089665,,,,,,,,,,,,,, -324400,3.1637313,1.136316,,,,,,,,,,,,,, -324500,3.1788018,1.1107419,,,,,,,,,,,,,, -324600,2.7165387,1.7894638,,,,,,,,,,,,,, -324700,3.201159,1.1436249,,,,,,,,,,,,,, -324800,3.8688014,3.1230915,,,,,,,,,,,,,, -324900,2.8387566,1.6227336,,,,,,,,,,,,,, -324993,,,0.890429675579071,0.413343220949173,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,145403.02961182594,157330.41791248322,145403.02961182594,11890.295197963716,20.042044162750244,0.0 -325000,3.9415061,2.1594179,,,,,,,,,,,,,, -325100,3.7146251,3.2488868,,,,,,,,,,,,,, -325200,3.117993,1.1604747,,,,,,,,,,,,,, -325300,2.9196322,1.4514961,,,,,,,,,,,,,, -325400,2.833832,2.0607429,,,,,,,,,,,,,, -325500,3.9287975,3.317905,,,,,,,,,,,,,, -325600,3.192024,2.5519986,,,,,,,,,,,,,, -325700,3.1412902,2.3588526,,,,,,,,,,,,,, -325800,2.871471,2.22118,,,,,,,,,,,,,, -325900,3.1684701,1.1796473,,,,,,,,,,,,,, -325932,,,0.8891210556030273,0.4105064570903778,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,145823.25002264977,157787.76915454865,145823.25002264977,11927.290768623352,20.128324270248413,0.0 -326000,3.0034442,1.0857306,,,,,,,,,,,,,, -326100,3.0221078,1.100319,,,,,,,,,,,,,, -326200,3.3763099,1.1971033,,,,,,,,,,,,,, -326300,2.9943285,1.0398579,,,,,,,,,,,,,, -326400,3.1177657,1.1155576,,,,,,,,,,,,,, -326500,3.1721046,2.7316718,,,,,,,,,,,,,, -326600,3.228351,2.688764,,,,,,,,,,,,,, -326700,3.4965525,1.1485662,,,,,,,,,,,,,, -326800,3.671803,2.6016428,,,,,,,,,,,,,, -326870,,,0.8895702958106995,0.4092340767383575,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,146243.3017590046,158243.3984901905,146243.3017590046,11962.746523618698,20.20204377174377,0.0 -326900,3.2220871,1.1622975,,,,,,,,,,,,,, -327000,3.21775,1.404108,,,,,,,,,,,,,, -327100,3.3001606,1.5435112,,,,,,,,,,,,,, -327200,3.1125152,1.1148432,,,,,,,,,,,,,, -327300,3.4539165,2.8763666,,,,,,,,,,,,,, -327400,3.172889,1.1164576,,,,,,,,,,,,,, -327500,2.9760058,1.1538398,,,,,,,,,,,,,, -327600,2.9941783,1.5883179,,,,,,,,,,,,,, -327700,2.98676,1.0002227,,,,,,,,,,,,,, -327800,3.6510353,3.2904751,,,,,,,,,,,,,, -327807,,,0.88720703125,0.4160758852958679,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,146663.34315609932,158699.53372955322,146663.34315609932,11998.719693899156,20.27439427375793,0.0 -327900,3.1014032,1.3822515,,,,,,,,,,,,,, -328000,3.136552,1.090457,,,,,,,,,,,,,, -328100,3.3505716,1.2917796,,,,,,,,,,,,,, -328200,3.6182573,2.9872906,,,,,,,,,,,,,, -328300,2.8571165,1.7285655,,,,,,,,,,,,,, -328400,3.1562107,1.2298771,,,,,,,,,,,,,, -328500,3.6480927,1.6618353,,,,,,,,,,,,,, -328600,3.2298236,2.602113,,,,,,,,,,,,,, -328700,2.8566368,1.101276,,,,,,,,,,,,,, -328744,,,0.8877929449081421,0.4134314954280853,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,147083.42253422737,159154.6178812981,147083.42253422737,12033.599283218384,20.351003408432007,0.0 -328800,3.2503796,2.677186,,,,,,,,,,,,,, -328900,3.0808005,1.1824825,,,,,,,,,,,,,, -329000,3.0835097,1.3818815,,,,,,,,,,,,,, -329100,3.3249571,1.1545758,,,,,,,,,,,,,, -329200,3.0521455,1.1580191,,,,,,,,,,,,,, -329300,2.9527876,1.1911013,,,,,,,,,,,,,, -329400,3.1911175,1.1246535,,,,,,,,,,,,,, -329500,3.0396523,1.1822646,,,,,,,,,,,,,, -329600,3.6027634,2.8539286,,,,,,,,,,,,,, -329682,,,0.8869140148162842,0.4140605926513672,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,147503.57654500008,159609.30521917343,147503.57654500008,12068.009401798248,20.425606727600098,0.0 -329700,3.03974,2.193994,,,,,,,,,,,,,, -329800,3.3575838,1.1197121,,,,,,,,,,,,,, -329900,3.1982281,1.2001925,,,,,,,,,,,,,, -330000,3.0109026,1.0925595,,,,,,,,,,,,,, -330100,3.378412,1.4105884,,,,,,,,,,,,,, -330200,2.9908934,1.8636303,,,,,,,,,,,,,, -330300,2.8621504,0.964152,,,,,,,,,,,,,, -330400,3.3975408,1.5405047,,,,,,,,,,,,,, -330500,3.0107727,2.0971365,,,,,,,,,,,,,, -330600,3.6452718,3.1116123,,,,,,,,,,,,,, -330619,,,0.890429675579071,0.40680992603302,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,147923.52029657364,160065.32572340965,147923.52029657364,12103.958374977112,20.504555225372314,0.0 -330700,3.200732,1.1060139,,,,,,,,,,,,,, -330800,3.4152186,1.073993,,,,,,,,,,,,,, -330900,3.0405047,1.2453636,,,,,,,,,,,,,, -331000,3.3289962,1.1136345,,,,,,,,,,,,,, -331100,3.469698,2.9993258,,,,,,,,,,,,,, -331200,3.084434,2.4783149,,,,,,,,,,,,,, -331300,3.0224736,1.0861833,,,,,,,,,,,,,, -331400,3.093993,1.2133214,,,,,,,,,,,,,, -331500,3.3501027,1.2162529,,,,,,,,,,,,,, -331560,,,0.8872460722923279,0.4223819673061371,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,148343.6543688774,160519.1687822342,148343.6543688774,12137.546013832092,20.57728481292725,0.0 -331600,3.0721285,1.6644267,,,,,,,,,,,,,, -331700,3.135636,2.1931129,,,,,,,,,,,,,, -331800,3.3268976,2.8117077,,,,,,,,,,,,,, -331900,2.9863815,1.0247724,,,,,,,,,,,,,, -332000,3.644808,3.0728548,,,,,,,,,,,,,, -332100,3.6243286,3.1556654,,,,,,,,,,,,,, -332200,3.123456,1.1101185,,,,,,,,,,,,,, -332300,3.5754614,3.2112975,,,,,,,,,,,,,, -332400,2.7682664,1.5470655,,,,,,,,,,,,,, -332497,,,0.88783198595047,0.4149302840232849,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,148763.71720457077,160975.2687857151,148763.71720457077,12173.454895019531,20.65666031837464,0.0 -332500,4.2570157,3.2039313,,,,,,,,,,,,,, -332600,3.04303,1.1021364,,,,,,,,,,,,,, -332700,3.1597707,1.1854151,,,,,,,,,,,,,, -332800,2.7789252,1.4707838,,,,,,,,,,,,,, -332900,3.1430442,2.6199954,,,,,,,,,,,,,, -333000,3.0915937,1.1654131,,,,,,,,,,,,,, -333100,3.0339272,1.3153403,,,,,,,,,,,,,, -333200,3.0098512,1.9978462,,,,,,,,,,,,,, -333300,3.1806111,2.6791718,,,,,,,,,,,,,, -333400,3.110607,1.5891527,,,,,,,,,,,,,, -333435,,,0.8897656202316284,0.4081674516201019,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,149183.92890954018,161430.69685649872,149183.92890954018,12208.547565460203,20.73137640953064,0.0 -333500,3.0768661,1.2405696,,,,,,,,,,,,,, -333600,3.1083791,2.3591766,,,,,,,,,,,,,, -333700,2.9295104,1.5881493,,,,,,,,,,,,,, -333800,3.027832,1.1068455,,,,,,,,,,,,,, -333900,3.2617738,1.0741262,,,,,,,,,,,,,, -334000,2.9709058,2.1409128,,,,,,,,,,,,,, -334100,2.930576,1.4923453,,,,,,,,,,,,,, -334200,3.4876964,2.9337792,,,,,,,,,,,,,, -334300,3.2745838,1.1724285,,,,,,,,,,,,,, -334373,,,0.8862499594688416,0.4207873046398163,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,149604.09540700912,161885.90264344215,149604.09540700912,12243.464797735214,20.8051335811615,0.0 -334400,3.0132551,1.7991009,,,,,,,,,,,,,, -334500,2.7254367,1.8303152,,,,,,,,,,,,,, -334600,3.178373,1.5710661,,,,,,,,,,,,,, -334700,2.8529902,2.0576231,,,,,,,,,,,,,, -334800,3.8818588,1.1926873,,,,,,,,,,,,,, -334900,3.1080904,2.5025785,,,,,,,,,,,,,, -335000,3.9522345,3.2469037,,,,,,,,,,,,,, -335100,2.915807,2.1225376,,,,,,,,,,,,,, -335200,3.4656,3.0698295,,,,,,,,,,,,,, -335300,3.6007643,3.1574216,,,,,,,,,,,,,, -335311,,,0.8875390291213989,0.4163751006126404,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,150024.3045117855,162340.12366628647,150024.3045117855,12277.341688632963,20.890724897384644,0.0 -335400,2.8806913,2.3490353,,,,,,,,,,,,,, -335500,3.1272886,1.1104274,,,,,,,,,,,,,, -335600,3.274799,1.8524665,,,,,,,,,,,,,, -335700,2.9675438,1.1204249,,,,,,,,,,,,,, -335800,3.0255647,1.2030413,,,,,,,,,,,,,, -335900,3.3808732,1.1929431,,,,,,,,,,,,,, -336000,3.4515388,1.3540748,,,,,,,,,,,,,, -336100,3.3254948,1.3655941,,,,,,,,,,,,,, -336200,3.18383,1.9946069,,,,,,,,,,,,,, -336250,,,0.8867382407188416,0.4137430191040039,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,150444.52754735947,162794.85982775688,150444.52754735947,12311.726724147797,20.969918489456177,0.0 -336300,2.9896016,1.3763024,,,,,,,,,,,,,, -336400,3.0145707,1.6810874,,,,,,,,,,,,,, -336500,3.5535645,2.7121694,,,,,,,,,,,,,, -336600,3.1690235,2.8142102,,,,,,,,,,,,,, -336700,3.0403173,1.1901642,,,,,,,,,,,,,, -336800,3.1797678,1.1349845,,,,,,,,,,,,,, -336900,3.0550132,1.1309133,,,,,,,,,,,,,, -337000,2.7645638,2.047418,,,,,,,,,,,,,, -337100,2.986974,1.1478728,,,,,,,,,,,,,, -337189,,,0.888476550579071,0.4186500012874603,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,150864.89817857742,163248.99874138832,150864.89817857742,12345.37182021141,21.04414367675781,0.0 -337200,3.2899709,2.24677,,,,,,,,,,,,,, -337300,3.1192727,2.56212,,,,,,,,,,,,,, -337400,3.1784267,1.1256735,,,,,,,,,,,,,, -337500,3.060936,1.2519029,,,,,,,,,,,,,, -337600,2.6859782,1.2350405,,,,,,,,,,,,,, -337700,2.9405844,1.0679538,,,,,,,,,,,,,, -337800,2.8479192,1.3997293,,,,,,,,,,,,,, -337900,3.1396925,1.1800213,,,,,,,,,,,,,, -338000,3.2414036,1.5646651,,,,,,,,,,,,,, -338100,3.0823328,2.449081,,,,,,,,,,,,,, -338129,,,0.8878124952316284,0.4141222536563873,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,151284.9808397293,163705.03434443474,151284.9808397293,12381.196682929993,21.12299537658692,0.0 -338200,3.2285168,1.1143801,,,,,,,,,,,,,, -338300,3.3372626,2.3517618,,,,,,,,,,,,,, -338400,2.8402963,1.5515941,,,,,,,,,,,,,, -338500,2.8341062,1.0392255,,,,,,,,,,,,,, -338600,3.0592546,1.5466516,,,,,,,,,,,,,, -338700,2.9814596,1.1470172,,,,,,,,,,,,,, -338800,3.210669,0.9997533,,,,,,,,,,,,,, -338900,3.2077246,1.0738046,,,,,,,,,,,,,, -339000,3.127244,2.2168925,,,,,,,,,,,,,, -339070,,,0.8898437023162842,0.4094454944133758,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,151705.25899219513,164161.02415442467,151705.25899219513,12416.785385370256,21.197551250457764,0.0 -339100,3.5132155,1.5990263,,,,,,,,,,,,,, -339200,2.8236318,1.4915426,,,,,,,,,,,,,, -339300,3.092481,1.0367439,,,,,,,,,,,,,, -339400,3.0786395,1.1010522,,,,,,,,,,,,,, -339500,3.204915,1.1270245,,,,,,,,,,,,,, -339600,3.005961,1.9999521,,,,,,,,,,,,,, -339700,2.8760943,1.1239922,,,,,,,,,,,,,, -339800,3.1209521,1.2714448,,,,,,,,,,,,,, -339900,2.8957283,1.3207433,,,,,,,,,,,,,, -340000,3.0223348,1.22822,,,,,,,,,,,,,, -340011,,,0.8881640434265137,0.419726699590683,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,152125.4069724083,164616.2634806633,152125.4069724083,12451.752354383469,21.272735834121704,0.0 -340100,3.2540576,1.1636941,,,,,,,,,,,,,, -340200,2.9777896,1.8986434,,,,,,,,,,,,,, -340300,2.9307766,1.2008781,,,,,,,,,,,,,, -340400,3.100944,1.6362381,,,,,,,,,,,,,, -340500,3.0808828,1.212465,,,,,,,,,,,,,, -340600,2.9784522,2.0130103,,,,,,,,,,,,,, -340700,3.1957774,1.4522907,,,,,,,,,,,,,, -340800,3.1279333,1.8299401,,,,,,,,,,,,,, -340900,3.1068246,1.112087,,,,,,,,,,,,,, -340948,,,0.8879492282867432,0.4159083962440491,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,152545.37711572647,165071.47660183907,152545.37711572647,12486.864793300629,21.35389828681945,0.0 -341000,3.0432553,1.0861973,,,,,,,,,,,,,, -341100,3.0909925,2.1673484,,,,,,,,,,,,,, -341200,3.3029249,2.9504528,,,,,,,,,,,,,, -341300,3.0513866,1.1478864,,,,,,,,,,,,,, -341400,3.2654781,2.9952707,,,,,,,,,,,,,, -341500,2.9530487,1.0477433,,,,,,,,,,,,,, -341600,3.0281498,1.2067438,,,,,,,,,,,,,, -341700,3.313765,1.6184213,,,,,,,,,,,,,, -341800,2.9384615,1.3259717,,,,,,,,,,,,,, -341885,,,0.8882616758346558,0.4191540777683258,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,152965.68938064575,165528.09414219856,152965.68938064575,12523.032780647278,21.442599773406982,0.0 -341900,3.2025723,1.1896304,,,,,,,,,,,,,, -342000,3.8620834,3.1634746,,,,,,,,,,,,,, -342100,3.0291324,1.0755417,,,,,,,,,,,,,, -342200,3.1459482,1.5446615,,,,,,,,,,,,,, -342300,2.803749,1.1067293,,,,,,,,,,,,,, -342400,3.109854,1.6634043,,,,,,,,,,,,,, -342500,3.0230787,1.181658,,,,,,,,,,,,,, -342600,3.758567,2.7565703,,,,,,,,,,,,,, -342700,3.1589103,1.0612662,,,,,,,,,,,,,, -342800,3.2583258,1.1531006,,,,,,,,,,,,,, -342824,,,0.8852733969688416,0.4213873147964477,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,153385.66485118866,165983.57289791107,153385.66485118866,12558.412452220917,21.51712751388549,0.0 -342900,3.0706096,1.1225619,,,,,,,,,,,,,, -343000,3.0701408,2.5957801,,,,,,,,,,,,,, -343100,3.3729146,2.8832,,,,,,,,,,,,,, -343200,2.9147618,2.2056193,,,,,,,,,,,,,, -343300,2.981993,1.3091247,,,,,,,,,,,,,, -343400,3.0963178,1.3980829,,,,,,,,,,,,,, -343500,3.7245286,1.1785972,,,,,,,,,,,,,, -343600,3.1747258,1.4250058,,,,,,,,,,,,,, -343700,3.0728104,1.3982515,,,,,,,,,,,,,, -343765,,,0.8892382383346558,0.4125251770019531,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,153805.74320554733,166438.56208109856,153805.74320554733,12593.196396112442,21.595152139663696,0.0 -343800,3.4192567,2.8547208,,,,,,,,,,,,,, -343900,3.156306,1.0378458,,,,,,,,,,,,,, -344000,3.0167851,2.4950461,,,,,,,,,,,,,, -344100,2.879926,1.251168,,,,,,,,,,,,,, -344200,3.2010872,1.1544356,,,,,,,,,,,,,, -344300,3.126216,1.4032552,,,,,,,,,,,,,, -344400,3.1550367,2.4371276,,,,,,,,,,,,,, -344500,2.956176,1.2527701,,,,,,,,,,,,,, -344600,2.8717804,1.1003507,,,,,,,,,,,,,, -344700,3.0649478,1.1071639,,,,,,,,,,,,,, -344704,,,0.8899218440055847,0.4108539223670959,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,154226.0605700016,166893.89604711533,154226.0605700016,12628.076207399368,21.68330931663513,0.0 -344800,3.2038822,1.2058396,,,,,,,,,,,,,, -344900,3.028492,1.8430984,,,,,,,,,,,,,, -345000,3.1428878,1.0210264,,,,,,,,,,,,,, -345100,3.3263664,3.1186101,,,,,,,,,,,,,, -345200,2.893028,1.4987859,,,,,,,,,,,,,, -345300,3.215752,1.0970778,,,,,,,,,,,,,, -345400,2.973621,2.1283398,,,,,,,,,,,,,, -345500,2.8354113,1.8028644,,,,,,,,,,,,,, -345600,3.7191217,3.00787,,,,,,,,,,,,,, -345643,,,0.8858007788658142,0.4192408919334411,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,154646.28801751137,167349.93050813675,154646.28801751137,12663.758479118347,21.75862622261048,0.0 -345700,3.1520782,1.4281938,,,,,,,,,,,,,, -345800,3.0301402,1.663465,,,,,,,,,,,,,, -345900,2.9581664,1.4272169,,,,,,,,,,,,,, -346000,3.079906,1.1541218,,,,,,,,,,,,,, -346100,3.3247309,1.1519147,,,,,,,,,,,,,, -346200,3.0539813,0.98506916,,,,,,,,,,,,,, -346300,3.296889,2.517517,,,,,,,,,,,,,, -346400,3.1460128,1.0711607,,,,,,,,,,,,,, -346500,3.434766,1.729015,,,,,,,,,,,,,, -346580,,,0.8874218463897705,0.414071649312973,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,155066.39670681953,167804.73359775543,155066.39670681953,12698.328409194946,21.834535837173465,0.0 -346600,3.9263833,3.2607334,,,,,,,,,,,,,, -346700,3.9885235,3.2968597,,,,,,,,,,,,,, -346800,3.1423657,1.0626211,,,,,,,,,,,,,, -346900,3.2831697,1.1214037,,,,,,,,,,,,,, -347000,2.9420547,1.9374745,,,,,,,,,,,,,, -347100,3.0008287,1.1046473,,,,,,,,,,,,,, -347200,3.143156,1.3599951,,,,,,,,,,,,,, -347300,3.8821802,3.2104292,,,,,,,,,,,,,, -347400,2.9671233,1.2763737,,,,,,,,,,,,,, -347500,3.2500448,2.7585926,,,,,,,,,,,,,, -347519,,,0.8876953125,0.4181354343891144,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,155486.5955555439,168259.2629210949,155486.5955555439,12732.518855333328,21.925792455673218,0.0 -347600,3.2669873,1.1761875,,,,,,,,,,,,,, -347700,2.8523695,1.0209824,,,,,,,,,,,,,, -347800,3.3415651,1.1002896,,,,,,,,,,,,,, -347900,2.9203024,2.0601802,,,,,,,,,,,,,, -348000,3.9273117,3.2098465,,,,,,,,,,,,,, -348100,3.9382153,3.207966,,,,,,,,,,,,,, -348200,3.203028,1.1001325,,,,,,,,,,,,,, -348300,3.176155,1.1870383,,,,,,,,,,,,,, -348400,3.1530821,1.3701745,,,,,,,,,,,,,, -348456,,,0.8907812237739563,0.4101113080978393,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,155906.79409885406,168713.85986709595,155906.79409885406,12766.790929794312,22.002236366271973,0.0 -348500,3.242319,2.761762,,,,,,,,,,,,,, -348600,3.0177126,1.1997236,,,,,,,,,,,,,, -348700,3.9691544,3.1872876,,,,,,,,,,,,,, -348800,3.22602,1.3332193,,,,,,,,,,,,,, -348900,2.9681184,2.5234227,,,,,,,,,,,,,, -349000,3.4983876,3.0042365,,,,,,,,,,,,,, -349100,3.0774899,1.1719562,,,,,,,,,,,,,, -349200,3.1903071,1.0807931,,,,,,,,,,,,,, -349300,2.995997,1.1101035,,,,,,,,,,,,,, -349395,,,0.88978511095047,0.4093369245529175,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,156326.7063846588,169170.11686944962,156326.7063846588,12802.994801282885,22.09450626373291,0.0 -349400,3.1036673,2.2954562,,,,,,,,,,,,,, -349500,3.13906,1.0366806,,,,,,,,,,,,,, -349600,3.3629298,1.124304,,,,,,,,,,,,,, -349700,3.3609161,2.7647917,,,,,,,,,,,,,, -349800,3.1346455,1.1398722,,,,,,,,,,,,,, -349900,3.0733716,1.1265931,,,,,,,,,,,,,, -350000,3.374508,2.7023869,,,,,,,,,,,,,, -350100,3.3586516,2.374841,,,,,,,,,,,,,, -350200,3.4730358,1.1062312,,,,,,,,,,,,,, -350300,4.432878,3.206605,,,,,,,,,,,,,, -350333,,,0.8917187452316284,0.4031383097171783,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,156746.6764435768,169622.81052684784,156746.6764435768,12835.585213184357,22.179222583770752,0.0 -350400,2.8849788,1.2191042,,,,,,,,,,,,,, -350500,2.9891095,1.2511091,,,,,,,,,,,,,, -350600,3.2282333,2.8565118,,,,,,,,,,,,,, -350700,3.5943675,1.1838187,,,,,,,,,,,,,, -350800,3.0999842,1.0768447,,,,,,,,,,,,,, -350900,3.0505953,2.4914067,,,,,,,,,,,,,, -351000,3.1492124,2.72474,,,,,,,,,,,,,, -351100,3.1462486,1.1390415,,,,,,,,,,,,,, -351200,3.181642,1.1033795,,,,,,,,,,,,,, -351271,,,0.88818359375,0.4155411422252655,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,157166.64605093002,170077.18683290482,157166.64605093002,12869.849185228348,22.272608995437626,0.0 -351300,3.0501213,1.173271,,,,,,,,,,,,,, -351400,2.8544369,2.203104,,,,,,,,,,,,,, -351500,3.6873927,2.533502,,,,,,,,,,,,,, -351600,3.268454,1.2692994,,,,,,,,,,,,,, -351700,3.124775,1.2069244,,,,,,,,,,,,,, -351800,3.799433,3.1282864,,,,,,,,,,,,,, -351900,3.8339257,2.5717773,,,,,,,,,,,,,, -352000,3.2916503,2.782992,,,,,,,,,,,,,, -352100,3.6879787,2.618771,,,,,,,,,,,,,, -352200,2.8392572,1.3596203,,,,,,,,,,,,,, -352211,,,0.88671875,0.4179138839244842,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,157587.0714263916,170531.37905216217,157587.0714263916,12903.49128460884,22.348501205444336,0.0 -352300,3.1926458,1.2819623,,,,,,,,,,,,,, -352400,3.4502053,2.3621433,,,,,,,,,,,,,, -352500,3.3215604,1.7589408,,,,,,,,,,,,,, -352600,3.0048625,1.0908787,,,,,,,,,,,,,, -352700,3.3436756,1.1411949,,,,,,,,,,,,,, -352800,3.9468248,3.2051795,,,,,,,,,,,,,, -352900,4.086085,3.209017,,,,,,,,,,,,,, -353000,3.0599008,2.412581,,,,,,,,,,,,,, -353100,2.9243088,1.150527,,,,,,,,,,,,,, -353149,,,0.8885351419448853,0.4137209355831146,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,158007.24652957916,170987.2718143463,158007.24652957916,12939.064716339111,22.443691968917847,0.0 -353200,3.622483,1.123407,,,,,,,,,,,,,, -353300,2.979633,1.0987427,,,,,,,,,,,,,, -353400,3.27204,1.1294729,,,,,,,,,,,,,, -353500,3.6468017,3.0253484,,,,,,,,,,,,,, -353600,3.1371171,1.1363451,,,,,,,,,,,,,, -353700,3.5147138,2.7015467,,,,,,,,,,,,,, -353800,3.2734475,1.3555592,,,,,,,,,,,,,, -353900,3.2284737,2.5429513,,,,,,,,,,,,,, -354000,3.0500143,1.4493992,,,,,,,,,,,,,, -354090,,,0.8870312571525574,0.4134266972541809,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,158427.2736890316,171443.43993115425,158427.2736890316,12975.078892946243,22.521716356277462,0.0 -354100,3.005938,2.0363374,,,,,,,,,,,,,, -354200,2.9415026,1.2232558,,,,,,,,,,,,,, -354300,3.0003257,2.5360072,,,,,,,,,,,,,, -354400,3.4743557,3.092213,,,,,,,,,,,,,, -354500,3.0657282,1.07728,,,,,,,,,,,,,, -354600,3.062057,2.5464606,,,,,,,,,,,,,, -354700,2.8365345,1.2345287,,,,,,,,,,,,,, -354800,3.9654381,3.2041173,,,,,,,,,,,,,, -354900,3.4224153,2.8240664,,,,,,,,,,,,,, -355000,3.2280576,1.1375659,,,,,,,,,,,,,, -355032,,,0.8900585770606995,0.410669595003128,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,158847.3594338894,171899.33880877495,158847.3594338894,13010.76617193222,22.5986111164093,0.0 -355100,3.2709222,1.1452467,,,,,,,,,,,,,, -355200,3.1256168,1.0568848,,,,,,,,,,,,,, -355300,2.9322515,2.21004,,,,,,,,,,,,,, -355400,3.0379107,2.481695,,,,,,,,,,,,,, -355500,3.0211432,1.1148388,,,,,,,,,,,,,, -355600,3.0794106,1.1351752,,,,,,,,,,,,,, -355700,3.0332139,1.0802165,,,,,,,,,,,,,, -355800,3.2959874,1.1333697,,,,,,,,,,,,,, -355900,2.9410965,1.1755403,,,,,,,,,,,,,, -355971,,,0.8856835961341858,0.4218739569187164,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,159267.44840097427,172354.46488285065,159267.44840097427,13045.676797151566,22.67598962783813,0.0 -356000,3.1704404,1.5476484,,,,,,,,,,,,,, -356100,3.6930385,3.1987112,,,,,,,,,,,,,, -356200,3.051638,1.4557966,,,,,,,,,,,,,, -356300,2.9942555,1.1663895,,,,,,,,,,,,,, -356400,3.1990545,1.2633945,,,,,,,,,,,,,, -356500,3.7036285,3.2466476,,,,,,,,,,,,,, -356600,3.1480422,1.1383812,,,,,,,,,,,,,, -356700,3.2029393,2.6713533,,,,,,,,,,,,,, -356800,3.3631918,1.2874771,,,,,,,,,,,,,, -356900,3.0362175,1.0563781,,,,,,,,,,,,,, -356910,,,0.8889843821525574,0.4116753935813904,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,159687.5464732647,172810.47546195984,159687.5464732647,13081.4606487751,22.755434274673465,0.0 -357000,3.1782172,1.2936075,,,,,,,,,,,,,, -357100,3.705625,3.101426,,,,,,,,,,,,,, -357200,3.1708372,1.5703179,,,,,,,,,,,,,, -357300,3.2524662,2.3902934,,,,,,,,,,,,,, -357400,3.2730274,2.6437597,,,,,,,,,,,,,, -357500,3.1523585,1.2685708,,,,,,,,,,,,,, -357600,3.0995188,2.3130367,,,,,,,,,,,,,, -357700,2.9933858,2.0786955,,,,,,,,,,,,,, -357800,3.0431356,1.0946107,,,,,,,,,,,,,, -357849,,,0.8860546946525574,0.4163259863853454,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,160107.59187602997,173266.62514972687,160107.59187602997,13117.43409729004,22.83749270439148,0.0 -357900,2.980882,1.8476069,,,,,,,,,,,,,, -358000,2.9044404,1.1666548,,,,,,,,,,,,,, -358100,2.8277214,1.5427057,,,,,,,,,,,,,, -358200,3.2360077,1.1629025,,,,,,,,,,,,,, -358300,3.2049396,1.1079065,,,,,,,,,,,,,, -358400,3.9099276,3.1378198,,,,,,,,,,,,,, -358500,3.26977,2.4599419,,,,,,,,,,,,,, -358600,3.4394171,1.1405845,,,,,,,,,,,,,, -358700,3.1420765,1.4676273,,,,,,,,,,,,,, -358787,,,0.8892577886581421,0.4093237519264221,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,160527.83763742447,173721.14839720726,160527.83763742447,13151.582559347153,22.91771149635315,0.0 -358800,2.9679356,1.6659482,,,,,,,,,,,,,, -358900,2.9751422,2.0943842,,,,,,,,,,,,,, -359000,3.3258758,1.1399986,,,,,,,,,,,,,, -359100,2.8857844,1.4184093,,,,,,,,,,,,,, -359200,3.0009217,1.2051655,,,,,,,,,,,,,, -359300,3.192215,1.2500575,,,,,,,,,,,,,, -359400,2.8913586,1.2807938,,,,,,,,,,,,,, -359500,3.169068,1.2485281,,,,,,,,,,,,,, -359600,2.819863,1.8153228,,,,,,,,,,,,,, -359700,2.9323955,1.1056733,,,,,,,,,,,,,, -359725,,,0.8866406083106995,0.4209662973880768,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,160947.71241402626,174176.86294460297,160947.71241402626,13187.27852511406,23.012847900390625,0.0 -359800,3.1237211,1.0884995,,,,,,,,,,,,,, -359900,2.9454887,1.5759311,,,,,,,,,,,,,, -360000,2.8434904,1.4281528,,,,,,,,,,,,,, -360100,3.1036615,1.1044544,,,,,,,,,,,,,, -360200,4.187851,3.0874615,,,,,,,,,,,,,, -360300,3.4669325,3.1246686,,,,,,,,,,,,,, -360400,3.2527525,1.0853711,,,,,,,,,,,,,, -360500,5.4773297,3.2505758,,,,,,,,,,,,,, -360600,4.0143566,3.2686217,,,,,,,,,,,,,, -360665,,,0.8871093392372131,0.4208268225193023,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,161367.6559035778,174633.06001091003,161367.6559035778,13223.405816078186,23.090537786483765,0.0 -360700,3.1237755,1.9586724,,,,,,,,,,,,,, -360800,2.9397073,2.0117671,,,,,,,,,,,,,, -360900,3.2185383,1.2590804,,,,,,,,,,,,,, -361000,3.0440962,1.1036257,,,,,,,,,,,,,, -361100,3.2220845,1.1913623,,,,,,,,,,,,,, -361200,3.198403,1.2447225,,,,,,,,,,,,,, -361300,3.1500201,1.4662244,,,,,,,,,,,,,, -361400,3.2472024,1.1580158,,,,,,,,,,,,,, -361500,3.6696312,3.1765532,,,,,,,,,,,,,, -361600,2.8594806,1.6834989,,,,,,,,,,,,,, -361602,,,0.8885937333106995,0.4136121273040771,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,161787.65626049042,175089.06940317154,161787.65626049042,13259.286784172058,23.16886854171753,0.0 -361700,3.522309,1.1054856,,,,,,,,,,,,,, -361800,3.0798883,1.2576671,,,,,,,,,,,,,, -361900,2.911631,2.497826,,,,,,,,,,,,,, -362000,3.1676915,1.1874197,,,,,,,,,,,,,, -362100,3.6330848,2.9037175,,,,,,,,,,,,,, -362200,3.206072,3.000061,,,,,,,,,,,,,, -362300,3.3574455,1.1307547,,,,,,,,,,,,,, -362400,3.0289395,1.3286656,,,,,,,,,,,,,, -362500,3.297595,1.7912831,,,,,,,,,,,,,, -362541,,,0.8882616758346558,0.4161655008792877,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,162207.73133540154,175545.66818594933,162207.73133540154,13295.67337846756,23.257583141326904,0.0 -362600,3.1699219,1.1907672,,,,,,,,,,,,,, -362700,3.1881416,2.3069048,,,,,,,,,,,,,, -362800,3.0800853,1.4272821,,,,,,,,,,,,,, -362900,2.9573162,1.0173402,,,,,,,,,,,,,, -363000,3.1982596,1.1823504,,,,,,,,,,,,,, -363100,3.3944902,1.1915148,,,,,,,,,,,,,, -363200,3.042851,1.3705125,,,,,,,,,,,,,, -363300,3.217926,1.1500788,,,,,,,,,,,,,, -363400,3.1986308,1.1675177,,,,,,,,,,,,,, -363479,,,0.8877929449081421,0.416262537240982,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,162627.79280495644,176003.4283466339,162627.79280495644,13333.242425441742,23.33720898628235,0.0 -363500,3.0589778,1.2826406,,,,,,,,,,,,,, -363600,3.2311442,1.164955,,,,,,,,,,,,,, -363700,3.8394926,3.1919866,,,,,,,,,,,,,, -363800,3.0589461,1.5340897,,,,,,,,,,,,,, -363900,3.835367,3.1676543,,,,,,,,,,,,,, -364000,3.1168191,1.0277574,,,,,,,,,,,,,, -364100,3.1538506,1.087513,,,,,,,,,,,,,, -364200,2.91946,1.1112655,,,,,,,,,,,,,, -364300,3.0859716,1.1812046,,,,,,,,,,,,,, -364400,3.0552611,1.0942855,,,,,,,,,,,,,, -364417,,,0.8885741829872131,0.415870189666748,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,163047.926517725,176457.32267189026,163047.926517725,13366.86541056633,23.42582654953003,0.0 -364500,3.2090087,1.1307129,,,,,,,,,,,,,, -364600,3.140093,1.0997701,,,,,,,,,,,,,, -364700,3.0371275,2.0100973,,,,,,,,,,,,,, -364800,2.841087,0.98101074,,,,,,,,,,,,,, -364900,3.2799747,2.4366167,,,,,,,,,,,,,, -365000,2.8254988,1.814899,,,,,,,,,,,,,, -365100,2.9817817,1.079566,,,,,,,,,,,,,, -365200,3.3324192,2.7651377,,,,,,,,,,,,,, -365300,3.0981352,1.0247256,,,,,,,,,,,,,, -365351,,,0.887499988079071,0.4217265546321869,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,163467.9629137516,176912.55845713615,163467.9629137516,13401.921688556671,23.520427465438843,0.0 -365400,3.3550987,1.0539719,,,,,,,,,,,,,, -365500,2.939514,2.209868,,,,,,,,,,,,,, -365600,3.1735294,1.2070423,,,,,,,,,,,,,, -365700,2.8787508,2.1757412,,,,,,,,,,,,,, -365800,3.3499374,2.7910073,,,,,,,,,,,,,, -365900,3.1601007,2.1624517,,,,,,,,,,,,,, -366000,3.3191342,1.1357195,,,,,,,,,,,,,, -366100,3.0379493,1.1522655,,,,,,,,,,,,,, -366200,3.1171963,1.1417859,,,,,,,,,,,,,, -366291,,,0.8881444931030273,0.4135344624519348,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,163887.9586417675,177367.0562813282,163887.9586417675,13436.286578655245,23.60895323753357,0.0 -366300,4.6682363,3.1782959,,,,,,,,,,,,,, -366400,3.3419113,2.8820279,,,,,,,,,,,,,, -366500,3.4834228,1.0825442,,,,,,,,,,,,,, -366600,3.0734358,1.820025,,,,,,,,,,,,,, -366700,3.239881,1.1618512,,,,,,,,,,,,,, -366800,3.3056934,1.1833131,,,,,,,,,,,,,, -366900,3.0685377,1.9379455,,,,,,,,,,,,,, -367000,3.0811806,2.4237328,,,,,,,,,,,,,, -367100,4.1186166,3.1044567,,,,,,,,,,,,,, -367200,3.3062313,2.1405935,,,,,,,,,,,,,, -367230,,,0.8866796493530273,0.4189541637897491,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,164307.8692791462,177820.83196425438,164307.8692791462,13470.023416519163,23.688948154449463,0.0 -367300,3.2290483,2.7040246,,,,,,,,,,,,,, -367400,3.1229677,1.3496226,,,,,,,,,,,,,, -367500,2.9203057,2.1643517,,,,,,,,,,,,,, -367600,2.9940035,1.2060032,,,,,,,,,,,,,, -367700,3.0239756,1.9090221,,,,,,,,,,,,,, -367800,3.1228557,1.5254252,,,,,,,,,,,,,, -367900,3.0460727,1.956949,,,,,,,,,,,,,, -368000,3.0971427,1.1455042,,,,,,,,,,,,,, -368100,3.7090037,2.8141203,,,,,,,,,,,,,, -368170,,,0.8882421851158142,0.4134348630905151,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,164728.0875532627,178275.9132180214,164728.0875532627,13504.752668857574,23.773596048355103,0.0 -368200,2.8591466,1.7052381,,,,,,,,,,,,,, -368300,3.0268843,1.1349227,,,,,,,,,,,,,, -368400,4.7242446,2.6535358,,,,,,,,,,,,,, -368500,3.2102187,1.6923928,,,,,,,,,,,,,, -368600,3.6594372,3.098132,,,,,,,,,,,,,, -368700,3.2014563,1.3643745,,,,,,,,,,,,,, -368800,3.4291487,2.187323,,,,,,,,,,,,,, -368900,3.3640234,1.1564118,,,,,,,,,,,,,, -369000,2.889458,2.200004,,,,,,,,,,,,,, -369100,2.816127,1.4271574,,,,,,,,,,,,,, -369109,,,0.8899999856948853,0.4114833772182464,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,165148.242303133,178730.52012515068,165148.242303133,13539.072884321213,23.85645580291748,0.0 -369200,3.1718254,1.1588544,,,,,,,,,,,,,, -369300,2.984631,1.9885702,,,,,,,,,,,,,, -369400,2.861404,1.4798422,,,,,,,,,,,,,, -369500,3.538675,2.8251176,,,,,,,,,,,,,, -369600,3.1443355,1.7390864,,,,,,,,,,,,,, -369700,3.1677456,1.042212,,,,,,,,,,,,,, -369800,3.8105078,3.024521,,,,,,,,,,,,,, -369900,3.171564,2.0873816,,,,,,,,,,,,,, -370000,3.690624,1.3639265,,,,,,,,,,,,,, -370046,,,0.888476550579071,0.4112453460693359,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,165568.15756726265,179185.61457157135,165568.15756726265,13574.1240670681,23.93623661994934,0.0 -370100,3.0775418,1.9446876,,,,,,,,,,,,,, -370200,3.166655,1.4001081,,,,,,,,,,,,,, -370300,3.11126,1.0803838,,,,,,,,,,,,,, -370400,3.2398136,1.4148915,,,,,,,,,,,,,, -370500,3.5326698,1.7044382,,,,,,,,,,,,,, -370600,2.9117746,1.083964,,,,,,,,,,,,,, -370700,2.9987352,1.0759858,,,,,,,,,,,,,, -370800,2.9594553,1.4547205,,,,,,,,,,,,,, -370900,3.2898424,1.0979888,,,,,,,,,,,,,, -370985,,,0.8868359327316284,0.4194703996181488,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,165988.48790717125,179639.6193087101,165988.48790717125,13607.670770168304,24.01575541496277,0.0 -371000,3.7533677,3.2051554,,,,,,,,,,,,,, -371100,3.0710828,2.4428353,,,,,,,,,,,,,, -371200,3.5015433,3.0158281,,,,,,,,,,,,,, -371300,2.9657705,1.9325273,,,,,,,,,,,,,, -371400,2.7625113,1.7540503,,,,,,,,,,,,,, -371500,3.0313222,2.3381512,,,,,,,,,,,,,, -371600,3.08155,1.110836,,,,,,,,,,,,,, -371700,3.4444141,1.0879368,,,,,,,,,,,,,, -371800,3.354109,1.2293582,,,,,,,,,,,,,, -371900,3.323246,2.7659478,,,,,,,,,,,,,, -371926,,,0.8890234231948853,0.41407310962677,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,166408.45634293556,180094.0869989395,166408.45634293556,13642.039429187776,24.09737181663513,0.0 -372000,3.214007,1.117486,,,,,,,,,,,,,, -372100,3.1419642,1.7036723,,,,,,,,,,,,,, -372200,3.3774407,1.4144622,,,,,,,,,,,,,, -372300,3.1903174,1.1283915,,,,,,,,,,,,,, -372400,3.1567907,1.110558,,,,,,,,,,,,,, -372500,3.1618671,1.8347161,,,,,,,,,,,,,, -372600,2.959722,2.4961188,,,,,,,,,,,,,, -372700,3.5407794,1.1476383,,,,,,,,,,,,,, -372800,3.98418,3.2547932,,,,,,,,,,,,,, -372867,,,0.8907421827316284,0.4095943868160248,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,166828.33721232414,180548.2405500412,166828.33721232414,13676.183526039124,24.1772894859314,0.0 -372900,3.1027918,1.1650324,,,,,,,,,,,,,, -373000,3.4890778,2.965983,,,,,,,,,,,,,, -373100,4.98922,3.2967682,,,,,,,,,,,,,, -373200,3.0141053,2.1001625,,,,,,,,,,,,,, -373300,3.0847242,2.4961457,,,,,,,,,,,,,, -373400,3.8401928,3.280665,,,,,,,,,,,,,, -373500,3.0676305,1.4053142,,,,,,,,,,,,,, -373600,3.129344,1.0905819,,,,,,,,,,,,,, -373700,3.214254,1.9098873,,,,,,,,,,,,,, -373800,3.0783079,1.1618059,,,,,,,,,,,,,, -373806,,,0.8892773389816284,0.4130665361881256,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,167248.5930762291,181003.10163211825,167248.5930762291,13710.648540973663,24.268741846084595,0.0 -373900,3.031085,1.7599137,,,,,,,,,,,,,, -374000,3.8796008,3.260631,,,,,,,,,,,,,, -374100,3.7128935,3.239624,,,,,,,,,,,,,, -374200,3.0255556,1.0474279,,,,,,,,,,,,,, -374300,2.8995059,1.4681481,,,,,,,,,,,,,, -374400,3.2387178,2.7754831,,,,,,,,,,,,,, -374500,2.980767,1.126224,,,,,,,,,,,,,, -374600,3.3774903,2.1879573,,,,,,,,,,,,,, -374700,3.012696,1.0410402,,,,,,,,,,,,,, -374745,,,0.8873242139816284,0.4112350344657898,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,167668.823564291,181458.58004879951,167668.823564291,13745.768962621689,24.34765338897705,0.0 -374800,3.2130296,2.5666249,,,,,,,,,,,,,, -374900,3.0192196,1.0749395,,,,,,,,,,,,,, -375000,2.865445,1.0739894,,,,,,,,,,,,,, -375100,3.279059,1.3338988,,,,,,,,,,,,,, -375200,3.2275028,2.191646,,,,,,,,,,,,,, -375300,4.0213757,3.217926,,,,,,,,,,,,,, -375400,3.1027184,1.080648,,,,,,,,,,,,,, -375500,3.1013927,1.1735964,,,,,,,,,,,,,, -375600,3.5179014,2.1849465,,,,,,,,,,,,,, -375685,,,0.8878905773162842,0.4148047864437103,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,168088.80501580238,181913.4650197029,168088.80501580238,13780.542924642565,24.428040742874146,0.0 -375700,3.3802104,2.8833017,,,,,,,,,,,,,, -375800,3.3133535,1.2655343,,,,,,,,,,,,,, -375900,3.0480127,1.0571278,,,,,,,,,,,,,, -376000,3.1838744,1.0513086,,,,,,,,,,,,,, -376100,3.1979663,2.3466084,,,,,,,,,,,,,, -376200,3.3750253,1.2466898,,,,,,,,,,,,,, -376300,3.1396716,1.2731998,,,,,,,,,,,,,, -376400,3.1300757,1.0757103,,,,,,,,,,,,,, -376500,3.067951,1.5204571,,,,,,,,,,,,,, -376600,3.3248866,1.1264538,,,,,,,,,,,,,, -376625,,,0.8874218463897705,0.4158387184143066,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,168509.0180311203,182369.70614933968,168509.0180311203,13816.431394577026,24.518927574157715,0.0 -376700,3.6200354,2.8598645,,,,,,,,,,,,,, -376800,3.249786,1.2221621,,,,,,,,,,,,,, -376900,3.7550468,1.1945355,,,,,,,,,,,,,, -377000,3.5620832,2.3560097,,,,,,,,,,,,,, -377100,3.0625975,1.12881,,,,,,,,,,,,,, -377200,3.2246141,2.5851073,,,,,,,,,,,,,, -377300,2.9534714,1.0254468,,,,,,,,,,,,,, -377400,2.926779,1.1877785,,,,,,,,,,,,,, -377500,3.2234077,1.9935937,,,,,,,,,,,,,, -377565,,,0.8898242115974426,0.407855361700058,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,168929.21570897102,182824.11052680016,168929.21570897102,13850.50154018402,24.606639623641968,0.0 -377600,3.8607886,3.1441753,,,,,,,,,,,,,, -377700,2.8020208,1.1651402,,,,,,,,,,,,,, -377800,2.8912427,1.6731222,,,,,,,,,,,,,, -377900,2.8139663,1.0368625,,,,,,,,,,,,,, -378000,3.1066468,2.6745286,,,,,,,,,,,,,, -378100,3.2995152,1.1233599,,,,,,,,,,,,,, -378200,3.236017,2.4039361,,,,,,,,,,,,,, -378300,3.080754,1.5223814,,,,,,,,,,,,,, -378400,3.1995747,1.6100777,,,,,,,,,,,,,, -378500,3.0380716,1.3892533,,,,,,,,,,,,,, -378504,,,0.8873632550239563,0.4182749986648559,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,169349.09867095947,183277.38987207413,169349.09867095947,13883.762291669846,24.693422079086304,0.0 -378600,2.9806,1.3697765,,,,,,,,,,,,,, -378700,3.2189324,1.1313872,,,,,,,,,,,,,, -378800,2.954066,1.8364348,,,,,,,,,,,,,, -378900,3.0777626,1.1715755,,,,,,,,,,,,,, -379000,3.1476972,1.0813084,,,,,,,,,,,,,, -379100,3.254325,1.140621,,,,,,,,,,,,,, -379200,3.0695572,1.0856831,,,,,,,,,,,,,, -379300,3.454336,2.9180372,,,,,,,,,,,,,, -379400,3.2355866,1.1809661,,,,,,,,,,,,,, -379443,,,0.88880854845047,0.4128335416316986,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,169769.38675570488,183733.3545923233,169769.38675570488,13919.294860601423,24.78789639472961,0.0 -379500,3.2167342,1.1886792,,,,,,,,,,,,,, -379600,3.2992005,1.1403275,,,,,,,,,,,,,, -379700,2.9117746,1.4442015,,,,,,,,,,,,,, -379800,2.9112613,1.1339171,,,,,,,,,,,,,, -379900,3.2574432,1.1679162,,,,,,,,,,,,,, -380000,3.2679787,1.2013042,,,,,,,,,,,,,, -380100,3.0687273,1.1422693,,,,,,,,,,,,,, -380200,3.4725394,2.9404054,,,,,,,,,,,,,, -380300,3.0782924,1.5952405,,,,,,,,,,,,,, -380383,,,0.8867577910423279,0.4163850247859955,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,170189.518556118,184187.91010832787,170189.518556118,13953.581751823423,24.875013828277588,0.0 -380400,3.2061675,1.3236374,,,,,,,,,,,,,, -380500,3.1502814,1.1877165,,,,,,,,,,,,,, -380600,3.1066258,1.4447354,,,,,,,,,,,,,, -380700,3.5587146,2.7188847,,,,,,,,,,,,,, -380800,3.2191367,1.1120862,,,,,,,,,,,,,, -380900,3.0381472,1.1512516,,,,,,,,,,,,,, -381000,3.250404,1.1033735,,,,,,,,,,,,,, -381100,3.6871567,3.0904453,,,,,,,,,,,,,, -381200,3.7759519,3.0372226,,,,,,,,,,,,,, -381300,3.3511744,1.1928838,,,,,,,,,,,,,, -381324,,,0.8875781297683716,0.4150454699993133,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,170609.48133444786,184641.93013739583,170609.48133444786,13987.508892059326,24.95596599578857,0.0 -381400,3.0658476,1.1896698,,,,,,,,,,,,,, -381500,3.1594217,1.676334,,,,,,,,,,,,,, -381600,3.0957096,2.144241,,,,,,,,,,,,,, -381700,3.0190732,1.1471522,,,,,,,,,,,,,, -381800,3.1457462,1.3701253,,,,,,,,,,,,,, -381900,3.508241,1.9051902,,,,,,,,,,,,,, -382000,3.317254,1.1943899,,,,,,,,,,,,,, -382100,3.5409207,1.1048412,,,,,,,,,,,,,, -382200,3.090558,1.2270705,,,,,,,,,,,,,, -382265,,,0.8875585794448853,0.4150723218917846,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,171029.39256739616,185095.8162283897,171029.39256739616,14021.341363430023,25.048585176467896,0.0 -382300,4.038888,3.214286,,,,,,,,,,,,,, -382400,3.8341663,3.0877838,,,,,,,,,,,,,, -382500,3.263535,2.1917255,,,,,,,,,,,,,, -382600,3.0079734,1.2413708,,,,,,,,,,,,,, -382700,3.2692761,1.8469115,,,,,,,,,,,,,, -382800,4.451322,1.5913393,,,,,,,,,,,,,, -382900,2.7214081,1.0563344,,,,,,,,,,,,,, -383000,3.0150182,1.2533634,,,,,,,,,,,,,, -383100,3.0526965,1.0702751,,,,,,,,,,,,,, -383200,3.2713387,1.0551282,,,,,,,,,,,,,, -383205,,,0.8885546922683716,0.4139377176761627,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,171449.58943510056,185551.0590927601,171449.58943510056,14056.252525568008,25.13417649269104,0.0 -383300,2.9095898,1.0213886,,,,,,,,,,,,,, -383400,3.8590553,3.192975,,,,,,,,,,,,,, -383500,4.0523157,3.3740416,,,,,,,,,,,,,, -383600,3.019816,1.573133,,,,,,,,,,,,,, -383700,3.065433,1.1966751,,,,,,,,,,,,,, -383800,3.2788103,1.0829377,,,,,,,,,,,,,, -383900,3.1964464,1.2328391,,,,,,,,,,,,,, -384000,3.1217895,1.1011235,,,,,,,,,,,,,, -384100,3.1457684,1.1710367,,,,,,,,,,,,,, -384144,,,0.8873242139816284,0.4182506799697876,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,171869.548060894,186005.1061720848,171869.548060894,14090.210060834885,25.21590518951416,0.0 -384200,2.9208603,1.0649971,,,,,,,,,,,,,, -384300,3.0426824,1.1113919,,,,,,,,,,,,,, -384400,3.21156,2.5317314,,,,,,,,,,,,,, -384500,3.055466,1.1201383,,,,,,,,,,,,,, -384600,3.4683068,1.0651121,,,,,,,,,,,,,, -384700,3.1300492,2.399422,,,,,,,,,,,,,, -384800,3.3400908,1.076811,,,,,,,,,,,,,, -384900,3.728634,1.6112467,,,,,,,,,,,,,, -385000,2.868244,1.2449903,,,,,,,,,,,,,, -385083,,,0.8873828053474426,0.4178923964500427,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,172289.46645498276,186460.21090960503,172289.46645498276,14125.25502538681,25.30760169029236,0.0 -385100,3.050501,1.5298125,,,,,,,,,,,,,, -385200,3.4919734,1.1779616,,,,,,,,,,,,,, -385300,3.2242339,2.2247765,,,,,,,,,,,,,, -385400,3.7481968,2.81077,,,,,,,,,,,,,, -385500,3.1183655,1.1914104,,,,,,,,,,,,,, -385600,2.9755487,1.0822872,,,,,,,,,,,,,, -385700,3.1258006,1.0844369,,,,,,,,,,,,,, -385800,3.0906937,1.0644627,,,,,,,,,,,,,, -385900,3.325375,1.3210814,,,,,,,,,,,,,, -386000,2.9531288,1.0474216,,,,,,,,,,,,,, -386025,,,0.8904687166213989,0.4116628468036651,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,172709.62967181206,186913.7729110717,172709.62967181206,14158.521426916122,25.390140295028687,0.0 -386100,3.0616207,1.5266417,,,,,,,,,,,,,, -386200,3.1686223,1.1076266,,,,,,,,,,,,,, -386300,3.1096575,1.862274,,,,,,,,,,,,,, -386400,3.1465697,1.100004,,,,,,,,,,,,,, -386500,3.4681516,1.1566712,,,,,,,,,,,,,, -386600,3.2322676,1.6849297,,,,,,,,,,,,,, -386700,4.0121922,1.1066748,,,,,,,,,,,,,, -386800,3.2788105,1.0407357,,,,,,,,,,,,,, -386900,3.094055,2.3607965,,,,,,,,,,,,,, -386964,,,0.8865038752555847,0.4218724966049194,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,173129.71286416054,187368.9116373062,173129.71286416054,14193.428065538406,25.49014639854431,0.0 -387000,3.106349,1.1472493,,,,,,,,,,,,,, -387100,2.9805071,1.1908915,,,,,,,,,,,,,, -387200,2.8136868,1.2832469,,,,,,,,,,,,,, -387300,3.235303,1.6091721,,,,,,,,,,,,,, -387400,3.4567447,3.0127313,,,,,,,,,,,,,, -387500,3.229728,2.3169265,,,,,,,,,,,,,, -387600,3.1208186,1.0391022,,,,,,,,,,,,,, -387700,2.882147,2.266932,,,,,,,,,,,,,, -387800,3.1393638,1.504627,,,,,,,,,,,,,, -387900,,,0.888476550579071,0.4137118458747864,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,173549.81622862816,187823.31874752045,173549.81622862816,14227.59961605072,25.574037551879883,0.0 -387900,3.2422361,2.3588338,,,,,,,,,,,,,, -388000,3.164774,2.2599602,,,,,,,,,,,,,, -388100,3.191376,1.5905861,,,,,,,,,,,,,, -388200,3.3321407,1.3797354,,,,,,,,,,,,,, -388300,3.0518663,1.0812495,,,,,,,,,,,,,, -388400,3.321597,2.707586,,,,,,,,,,,,,, -388500,3.096716,1.5273396,,,,,,,,,,,,,, -388600,3.5486224,1.6125957,,,,,,,,,,,,,, -388700,3.0660179,1.1129704,,,,,,,,,,,,,, -388800,3.155155,1.0397117,,,,,,,,,,,,,, -388837,,,0.8871679306030273,0.4193262457847595,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,173970.01164507866,188277.86225795743,173970.01164507866,14261.797946691511,25.674198389053345,0.0 -388900,3.9741704,3.1722293,,,,,,,,,,,,,, -389000,3.4170408,1.2128413,,,,,,,,,,,,,, -389100,4.816651,1.1238835,,,,,,,,,,,,,, -389200,3.409351,2.9596968,,,,,,,,,,,,,, -389300,3.056474,1.1063722,,,,,,,,,,,,,, -389400,3.3898487,1.1534209,,,,,,,,,,,,,, -389500,2.7903545,1.0038784,,,,,,,,,,,,,, -389600,2.8853488,1.2793418,,,,,,,,,,,,,, -389700,2.8928137,1.9756411,,,,,,,,,,,,,, -389777,,,0.8882616758346558,0.4169135689735412,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,174390.0960047245,188731.5424156189,174390.0960047245,14295.256955862043,25.76052308082581,0.0 -389800,3.1082385,2.4187508,,,,,,,,,,,,,, -389900,3.0955517,1.1279129,,,,,,,,,,,,,, -390000,3.1168985,1.0974412,,,,,,,,,,,,,, -390100,2.960615,1.3399522,,,,,,,,,,,,,, -390200,3.407969,1.0460922,,,,,,,,,,,,,, -390300,3.3768377,1.0689754,,,,,,,,,,,,,, -390400,3.0783622,1.1249715,,,,,,,,,,,,,, -390500,3.810584,3.3042314,,,,,,,,,,,,,, -390600,3.0005782,1.2993269,,,,,,,,,,,,,, -390700,3.1371133,1.2295749,,,,,,,,,,,,,, -390717,,,0.8859570026397705,0.4196897149085998,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,174809.967348814,189184.9430897236,174809.967348814,14328.652228593826,25.845441102981567,0.0 -390800,5.152223,3.264142,,,,,,,,,,,,,, -390900,3.3102937,1.2170452,,,,,,,,,,,,,, -391000,3.3204756,1.2494222,,,,,,,,,,,,,, -391100,3.0916655,1.849335,,,,,,,,,,,,,, -391200,3.0777986,1.2100358,,,,,,,,,,,,,, -391300,3.0506935,1.0388399,,,,,,,,,,,,,, -391400,3.1854067,1.6965314,,,,,,,,,,,,,, -391500,3.3418486,1.1728137,,,,,,,,,,,,,, -391600,3.1659338,1.1719725,,,,,,,,,,,,,, -391656,,,0.8898046612739563,0.4087473452091217,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,175229.94012641907,189639.92478442192,175229.94012641907,14363.51612186432,25.941320419311523,0.0 -391700,3.354357,1.3431804,,,,,,,,,,,,,, -391800,3.4195228,1.1274152,,,,,,,,,,,,,, -391900,3.1205177,1.1609269,,,,,,,,,,,,,, -392000,4.899351,3.2600768,,,,,,,,,,,,,, -392100,3.1396468,2.4346342,,,,,,,,,,,,,, -392200,3.1938531,1.1417919,,,,,,,,,,,,,, -392300,3.0323968,1.6359557,,,,,,,,,,,,,, -392400,3.5211577,1.1965595,,,,,,,,,,,,,, -392500,3.2288148,1.310211,,,,,,,,,,,,,, -392596,,,0.8891991972923279,0.4111690521240234,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,175649.9788453579,190094.58134031296,175649.9788453579,14397.997807979584,26.028414011001587,0.0 -392600,3.1964877,1.0996919,,,,,,,,,,,,,, -392700,3.3929229,2.7246847,,,,,,,,,,,,,, -392800,2.8584685,1.7416626,,,,,,,,,,,,,, -392900,3.0684283,1.7120211,,,,,,,,,,,,,, -393000,3.0735655,1.1978544,,,,,,,,,,,,,, -393100,3.2712636,2.5621893,,,,,,,,,,,,,, -393200,3.0795264,1.1335381,,,,,,,,,,,,,, -393300,3.183353,1.2098937,,,,,,,,,,,,,, -393400,3.2527933,1.0867641,,,,,,,,,,,,,, -393500,3.0870318,1.1589105,,,,,,,,,,,,,, -393535,,,0.8874413967132568,0.4186658263206482,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,176070.06787610054,190548.05064105988,176070.06787610054,14431.227613449097,26.12958836555481,0.0 -393600,2.9540164,1.7810574,,,,,,,,,,,,,, -393700,3.6509538,3.1682901,,,,,,,,,,,,,, -393800,2.965784,1.0949652,,,,,,,,,,,,,, -393900,3.2449927,1.109968,,,,,,,,,,,,,, -394000,3.0252802,1.091136,,,,,,,,,,,,,, -394100,3.0232344,1.106004,,,,,,,,,,,,,, -394200,2.930144,1.7617819,,,,,,,,,,,,,, -394300,3.271776,2.1317997,,,,,,,,,,,,,, -394400,3.2122695,1.1144634,,,,,,,,,,,,,, -394476,,,0.8876562118530273,0.4162094295024872,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,176490.34965968132,191004.4239397049,176490.34965968132,14467.168281316755,26.231871128082275,0.0 -394500,3.0363686,1.3437613,,,,,,,,,,,,,, -394600,3.428576,2.6979697,,,,,,,,,,,,,, -394700,3.154645,1.3204874,,,,,,,,,,,,,, -394800,3.1743085,1.2238046,,,,,,,,,,,,,, -394900,3.0993843,1.0906736,,,,,,,,,,,,,, -395000,3.1622667,1.1519792,,,,,,,,,,,,,, -395100,3.0544298,1.0753505,,,,,,,,,,,,,, -395200,3.1096644,2.6359618,,,,,,,,,,,,,, -395300,3.2513864,1.1403039,,,,,,,,,,,,,, -395400,3.9512897,3.3969803,,,,,,,,,,,,,, -395417,,,0.8920117020606995,0.4101240336894989,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,176910.50852513313,191458.0256493092,176910.50852513313,14500.466915845873,26.326695442199707,0.0 -395500,3.2649026,1.2632395,,,,,,,,,,,,,, -395600,3.1065798,1.5132511,,,,,,,,,,,,,, -395700,3.7396426,3.226287,,,,,,,,,,,,,, -395800,3.0104263,2.3607981,,,,,,,,,,,,,, -395900,3.4534407,1.1947417,,,,,,,,,,,,,, -396000,3.1189344,1.1023812,,,,,,,,,,,,,, -396100,3.5946007,3.2072365,,,,,,,,,,,,,, -396200,3.6609683,2.6724842,,,,,,,,,,,,,, -396300,3.2134235,1.1310738,,,,,,,,,,,,,, -396356,,,0.88880854845047,0.409117728471756,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,177330.77657341957,191913.3264658451,177330.77657341957,14535.344248533249,26.43313837051392,0.0 -396400,3.0608394,1.7202992,,,,,,,,,,,,,, -396500,2.9595795,1.146645,,,,,,,,,,,,,, -396600,3.077622,1.1418737,,,,,,,,,,,,,, -396700,3.430737,1.8797916,,,,,,,,,,,,,, -396800,2.8958213,1.128061,,,,,,,,,,,,,, -396900,3.244028,2.9038138,,,,,,,,,,,,,, -397000,3.0008645,1.1344421,,,,,,,,,,,,,, -397100,3.1112318,2.424279,,,,,,,,,,,,,, -397200,2.9306931,2.6777527,,,,,,,,,,,,,, -397296,,,0.8882226347923279,0.4123573303222656,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,177750.9477841854,192369.61221718788,177750.9477841854,14571.324562311172,26.5184805393219,0.0 -397300,3.3607273,1.1354523,,,,,,,,,,,,,, -397400,3.1102304,1.1079062,,,,,,,,,,,,,, -397500,3.3029573,1.218252,,,,,,,,,,,,,, -397600,3.252789,2.3490977,,,,,,,,,,,,,, -397700,3.723171,3.109536,,,,,,,,,,,,,, -397800,2.6929994,1.480226,,,,,,,,,,,,,, -397900,3.8463619,2.88404,,,,,,,,,,,,,, -398000,3.0690293,1.8880693,,,,,,,,,,,,,, -398100,3.0760746,1.097738,,,,,,,,,,,,,, -398200,3.2222178,2.1105049,,,,,,,,,,,,,, -398236,,,0.88880854845047,0.4158269166946411,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,178170.99898934364,192825.60821533203,178170.99898934364,14607.131270647047,26.60717272758484,0.0 -398300,2.9405437,1.5134215,,,,,,,,,,,,,, -398400,3.4055743,3.0470161,,,,,,,,,,,,,, -398500,2.9575617,1.0940914,,,,,,,,,,,,,, -398600,3.1045554,1.1957885,,,,,,,,,,,,,, -398700,3.047469,1.3784174,,,,,,,,,,,,,, -398800,3.3079598,1.1388235,,,,,,,,,,,,,, -398900,3.1008084,2.5703096,,,,,,,,,,,,,, -399000,2.9488053,1.8254949,,,,,,,,,,,,,, -399100,2.9680338,1.4492418,,,,,,,,,,,,,, -399173,,,0.8885546922683716,0.4094387590885162,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,178591.1631834507,193281.24997234344,178591.1631834507,14642.470609903336,26.69651246070861,0.0 -399200,3.0407932,1.1000316,,,,,,,,,,,,,, -399300,2.9588532,1.1282651,,,,,,,,,,,,,, -399400,3.1015913,1.3583708,,,,,,,,,,,,,, -399500,3.260793,1.1704662,,,,,,,,,,,,,, -399600,3.185134,1.4168277,,,,,,,,,,,,,, -399700,2.9847324,2.4581845,,,,,,,,,,,,,, -399800,3.0591586,1.3914347,,,,,,,,,,,,,, -399900,3.0164115,1.2958256,,,,,,,,,,,,,, -400000,2.8774772,1.4552187,,,,,,,,,,,,,, -400100,2.9644408,1.2342776,,,,,,,,,,,,,, -400113,,,0.8860546946525574,0.4168408811092376,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,179011.04693436623,193736.52523708344,179011.04693436623,14677.70623278618,26.80361557006836,0.0 -400200,3.4514027,3.197817,,,,,,,,,,,,,, -400300,3.4436193,1.6714406,,,,,,,,,,,,,, -400400,3.6384447,3.2427864,,,,,,,,,,,,,, -400500,3.8023849,3.2764752,,,,,,,,,,,,,, -400600,3.5580466,2.9744308,,,,,,,,,,,,,, -400700,3.0294056,1.4032223,,,,,,,,,,,,,, -400800,3.3706782,3.0659506,,,,,,,,,,,,,, -400900,3.0492294,1.1205004,,,,,,,,,,,,,, -401000,3.1274936,1.1007488,,,,,,,,,,,,,, -401052,,,0.8895898461341858,0.4095662832260132,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,179431.1743233204,194190.1621646881,179431.1743233204,14711.077335596085,26.893108129501343,0.0 -401100,2.8268983,1.878762,,,,,,,,,,,,,, -401200,3.0478625,1.219116,,,,,,,,,,,,,, -401300,3.4716492,2.8596964,,,,,,,,,,,,,, -401400,3.2981913,2.69755,,,,,,,,,,,,,, -401500,3.156833,1.1463293,,,,,,,,,,,,,, -401600,3.2250838,1.114959,,,,,,,,,,,,,, -401700,3.1774364,1.7373579,,,,,,,,,,,,,, -401800,3.0317955,1.681048,,,,,,,,,,,,,, -401900,2.975293,1.498001,,,,,,,,,,,,,, -401993,,,0.8876367211341858,0.4177990555763244,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,179851.25026535988,194645.4740064144,179851.25026535988,14746.11722946167,27.03999924659729,0.0 -402000,3.138984,1.5218263,,,,,,,,,,,,,, -402100,3.1528273,1.2114264,,,,,,,,,,,,,, -402200,3.2246187,2.549886,,,,,,,,,,,,,, -402300,3.0793712,2.0683868,,,,,,,,,,,,,, -402400,2.9042845,1.2103567,,,,,,,,,,,,,, -402500,2.9146216,1.139889,,,,,,,,,,,,,, -402600,3.030814,1.0802224,,,,,,,,,,,,,, -402700,3.7011144,2.934963,,,,,,,,,,,,,, -402800,3.4314911,2.8681138,,,,,,,,,,,,,, -402900,3.185641,1.1297257,,,,,,,,,,,,,, -402930,,,0.8872265219688416,0.4172936677932739,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,180271.34703087807,195099.27899551392,180271.34703087807,14779.690134763718,27.12631464004517,0.0 -403000,3.442309,1.1801655,,,,,,,,,,,,,, -403100,3.8648221,3.257798,,,,,,,,,,,,,, -403200,3.1285307,1.1860256,,,,,,,,,,,,,, -403300,3.1937144,2.9599226,,,,,,,,,,,,,, -403400,3.216875,1.0737476,,,,,,,,,,,,,, -403500,3.0762916,2.229391,,,,,,,,,,,,,, -403600,3.7573626,3.2796006,,,,,,,,,,,,,, -403700,3.156911,1.170175,,,,,,,,,,,,,, -403800,2.881626,1.98736,,,,,,,,,,,,,, -403869,,,0.8881054520606995,0.4169188141822815,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,180691.3440322876,195553.56817293167,180691.3440322876,14813.849033117294,27.21081781387329,0.0 -403900,3.2247748,1.22697,,,,,,,,,,,,,, -404000,3.4032218,2.302937,,,,,,,,,,,,,, -404100,2.7980323,1.9493927,,,,,,,,,,,,,, -404200,3.2004998,1.8664733,,,,,,,,,,,,,, -404300,4.2903786,3.3194606,,,,,,,,,,,,,, -404400,3.0546248,1.4518646,,,,,,,,,,,,,, -404500,2.9996538,1.0195947,,,,,,,,,,,,,, -404600,3.0331993,1.0646509,,,,,,,,,,,,,, -404700,3.1724682,1.7755792,,,,,,,,,,,,,, -404800,3.0828793,2.5821111,,,,,,,,,,,,,, -404806,,,0.8880273103713989,0.4125872850418091,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,181111.3825807572,196009.753885746,181111.3825807572,14849.860277414322,27.298006534576416,0.0 -404900,3.2273505,1.1130934,,,,,,,,,,,,,, -405000,3.373593,1.1590468,,,,,,,,,,,,,, -405100,2.930654,1.1502422,,,,,,,,,,,,,, -405200,3.0536337,1.3745615,,,,,,,,,,,,,, -405300,3.1282983,1.2179183,,,,,,,,,,,,,, -405400,3.2473137,2.470027,,,,,,,,,,,,,, -405500,3.012751,1.9097444,,,,,,,,,,,,,, -405600,2.8755758,1.1396629,,,,,,,,,,,,,, -405700,2.9682724,1.5767035,,,,,,,,,,,,,, -405746,,,0.8882226347923279,0.4126043915748596,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,181531.36791729927,196466.27555465687,181531.36791729927,14886.262979507446,27.38276648521424,0.0 -405800,2.9992135,1.43221,,,,,,,,,,,,,, -405900,3.7386987,3.1492991,,,,,,,,,,,,,, -406000,3.1359355,2.4288337,,,,,,,,,,,,,, -406100,2.91777,1.3073478,,,,,,,,,,,,,, -406200,3.2032979,1.1796312,,,,,,,,,,,,,, -406300,2.766896,1.4273113,,,,,,,,,,,,,, -406400,2.9347198,1.9258802,,,,,,,,,,,,,, -406500,3.159859,2.6416492,,,,,,,,,,,,,, -406600,3.2920856,2.9002588,,,,,,,,,,,,,, -406686,,,0.8878124952316284,0.4145838916301727,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,181951.2723581791,196920.69525766373,181951.2723581791,14920.631911039352,27.48050045967102,0.0 -406700,3.2516172,2.970732,,,,,,,,,,,,,, -406800,3.1598442,1.2749282,,,,,,,,,,,,,, -406900,5.206563,3.3046842,,,,,,,,,,,,,, -407000,3.317992,1.0865012,,,,,,,,,,,,,, -407100,3.764595,3.0652192,,,,,,,,,,,,,, -407200,3.1276655,1.5268154,,,,,,,,,,,,,, -407300,3.090894,2.1323478,,,,,,,,,,,,,, -407400,3.4848554,3.1682844,,,,,,,,,,,,,, -407500,2.8814466,1.082243,,,,,,,,,,,,,, -407600,3.4231694,1.1878853,,,,,,,,,,,,,, -407624,,,0.8862109184265137,0.419201523065567,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,182371.54948091507,197375.4734230041,182371.54948091507,14954.981954574583,27.58235263824463,0.0 -407700,3.201453,2.7868803,,,,,,,,,,,,,, -407800,3.292516,1.1123117,,,,,,,,,,,,,, -407900,3.0178797,1.4629769,,,,,,,,,,,,,, -408000,3.3633401,1.0503918,,,,,,,,,,,,,, -408100,5.6930995,1.300186,,,,,,,,,,,,,, -408200,3.0173643,2.2084148,,,,,,,,,,,,,, -408300,4.0629807,1.2239583,,,,,,,,,,,,,, -408400,3.079382,1.1950203,,,,,,,,,,,,,, -408500,3.8256156,3.1568952,,,,,,,,,,,,,, -408562,,,0.8884179592132568,0.4154849052429199,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,182791.8166847229,197831.5146150589,182791.8166847229,14990.611825942991,27.677670001983643,0.0 -408600,2.9962225,1.1242807,,,,,,,,,,,,,, -408700,3.072614,1.2701505,,,,,,,,,,,,,, -408800,3.179085,1.0753559,,,,,,,,,,,,,, -408900,3.4938843,1.276369,,,,,,,,,,,,,, -409000,3.2194467,2.1934283,,,,,,,,,,,,,, -409100,2.7423978,1.6152765,,,,,,,,,,,,,, -409200,3.3289723,2.082239,,,,,,,,,,,,,, -409300,3.2590759,1.1864932,,,,,,,,,,,,,, -409400,3.0305176,1.5460244,,,,,,,,,,,,,, -409500,4.0066905,3.062799,,,,,,,,,,,,,, -409503,,,0.8890624642372131,0.412666767835617,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,183211.9351406097,198287.1301677227,183211.9351406097,15025.966703891754,27.77146291732788,0.0 -409600,2.7476358,1.1188115,,,,,,,,,,,,,, -409700,2.9913723,1.0773187,,,,,,,,,,,,,, -409800,3.024533,1.9927635,,,,,,,,,,,,,, -409900,2.9689171,1.1015661,,,,,,,,,,,,,, -410000,2.9416804,2.2201047,,,,,,,,,,,,,, -410100,3.4812677,3.0565488,,,,,,,,,,,,,, -410200,3.073155,1.0385125,,,,,,,,,,,,,, -410300,2.9401367,1.2252895,,,,,,,,,,,,,, -410400,3.0575824,1.0623783,,,,,,,,,,,,,, -410441,,,0.8878710865974426,0.4203723073005676,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,183632.14395451543,198742.9077372551,183632.14395451543,15061.399684906006,27.8588445186615,0.0 -410500,3.0653367,2.3606172,,,,,,,,,,,,,, -410600,2.9509535,1.002192,,,,,,,,,,,,,, -410700,2.6910946,1.2467909,,,,,,,,,,,,,, -410800,2.7864637,1.0645663,,,,,,,,,,,,,, -410900,4.0631046,2.973391,,,,,,,,,,,,,, -411000,4.103443,3.3785436,,,,,,,,,,,,,, -411100,3.9286585,3.1800199,,,,,,,,,,,,,, -411200,3.324929,1.058388,,,,,,,,,,,,,, -411300,3.3778663,2.7662144,,,,,,,,,,,,,, -411379,,,0.8880273103713989,0.4152186810970306,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,184052.0600640773,199197.3196368217,184052.0600640773,15095.746675729752,27.959092378616333,0.0 -411400,3.23676,2.426811,,,,,,,,,,,,,, -411500,4.3603477,3.270683,,,,,,,,,,,,,, -411600,3.2835183,1.363911,,,,,,,,,,,,,, -411700,3.4650636,1.077552,,,,,,,,,,,,,, -411800,3.0845537,1.1620978,,,,,,,,,,,,,, -411900,3.2061927,1.0718378,,,,,,,,,,,,,, -412000,3.1011403,1.0143464,,,,,,,,,,,,,, -412100,2.9760425,2.5228457,,,,,,,,,,,,,, -412200,3.1227908,1.5864719,,,,,,,,,,,,,, -412300,3.2139566,1.3576994,,,,,,,,,,,,,, -412317,,,0.8885546922683716,0.4164588153362274,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,184472.1326699257,199652.8047463894,184472.1326699257,15131.023107528688,28.0471260547638,0.0 -412400,3.1276362,1.1168933,,,,,,,,,,,,,, -412500,3.2317386,1.1557093,,,,,,,,,,,,,, -412600,2.8066049,1.2644668,,,,,,,,,,,,,, -412700,2.9581954,1.3301297,,,,,,,,,,,,,, -412800,3.0678275,2.5167873,,,,,,,,,,,,,, -412900,3.0108159,1.1152872,,,,,,,,,,,,,, -413000,3.2724473,1.2443417,,,,,,,,,,,,,, -413100,2.7872062,2.1452801,,,,,,,,,,,,,, -413200,3.2763462,1.2045286,,,,,,,,,,,,,, -413257,,,0.8854687213897705,0.4227531850337982,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,184892.1034603119,200106.71964406967,184892.1034603119,15164.818674087524,28.145170211791992,0.0 -413300,3.2903214,1.0298554,,,,,,,,,,,,,, -413400,3.0177784,1.1076708,,,,,,,,,,,,,, -413500,3.1891155,1.045576,,,,,,,,,,,,,, -413600,2.8129413,1.0343678,,,,,,,,,,,,,, -413700,3.0342593,1.1318862,,,,,,,,,,,,,, -413800,2.910398,1.0703994,,,,,,,,,,,,,, -413900,3.1349266,1.6285772,,,,,,,,,,,,,, -414000,3.3394997,1.1501008,,,,,,,,,,,,,, -414100,3.1229594,1.1909444,,,,,,,,,,,,,, -414195,,,0.8895312547683716,0.4116630852222442,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,185312.38032460213,200561.1308102608,185312.38032460213,15198.814550161362,28.235026121139526,0.0 -414200,3.2228787,3.0057862,,,,,,,,,,,,,, -414300,3.1230648,1.1395121,,,,,,,,,,,,,, -414400,2.9866655,1.0545796,,,,,,,,,,,,,, -414500,3.114569,2.6637006,,,,,,,,,,,,,, -414600,3.340666,1.1285053,,,,,,,,,,,,,, -414700,2.972517,1.0889826,,,,,,,,,,,,,, -414800,3.1393552,1.2529528,,,,,,,,,,,,,, -414900,3.3894804,2.5568252,,,,,,,,,,,,,, -415000,3.1120176,2.7848217,,,,,,,,,,,,,, -415100,3.7925494,3.1683943,,,,,,,,,,,,,, -415132,,,0.8891991972923279,0.4130409955978393,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,185732.29559206963,201016.1106908321,185732.29559206963,15233.74088358879,28.325153589248657,0.0 -415200,3.7208452,2.8977795,,,,,,,,,,,,,, -415300,3.0816123,1.1485559,,,,,,,,,,,,,, -415400,2.8952699,1.1156435,,,,,,,,,,,,,, -415500,3.2916665,1.1984675,,,,,,,,,,,,,, -415600,3.4312828,1.0825713,,,,,,,,,,,,,, -415700,3.0498216,1.0727621,,,,,,,,,,,,,, -415800,3.071152,2.5278308,,,,,,,,,,,,,, -415900,3.2304718,1.1401767,,,,,,,,,,,,,, -416000,3.3346508,1.5026948,,,,,,,,,,,,,, -416073,,,0.8877929449081421,0.416156530380249,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,186152.5260412693,201471.02529382703,186152.5260412693,15268.27907562256,28.421623706817627,0.0 -416100,3.3971353,1.133627,,,,,,,,,,,,,, -416200,3.273523,1.1437411,,,,,,,,,,,,,, -416300,2.818983,1.4232861,,,,,,,,,,,,,, -416400,3.0274782,1.1559097,,,,,,,,,,,,,, -416500,3.256867,1.163269,,,,,,,,,,,,,, -416600,2.9713295,1.6035595,,,,,,,,,,,,,, -416700,3.3280022,1.0863391,,,,,,,,,,,,,, -416800,3.2591145,2.3190305,,,,,,,,,,,,,, -416900,2.7531922,1.613516,,,,,,,,,,,,,, -417000,3.4810784,3.0412242,,,,,,,,,,,,,, -417016,,,0.8877539038658142,0.4165347814559936,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,186572.4598255157,201925.7241203785,186572.4598255157,15302.907843351364,28.50839161872864,0.0 -417100,3.3583105,3.040905,,,,,,,,,,,,,, -417200,2.9062798,1.013998,,,,,,,,,,,,,, -417300,2.8913035,1.2653005,,,,,,,,,,,,,, -417400,3.2837188,1.2143682,,,,,,,,,,,,,, -417500,2.9919407,1.25539,,,,,,,,,,,,,, -417600,3.0682611,1.0448941,,,,,,,,,,,,,, -417700,3.8073642,3.1283417,,,,,,,,,,,,,, -417800,2.9541762,1.7723248,,,,,,,,,,,,,, -417900,3.8788958,3.3121977,,,,,,,,,,,,,, -417955,,,0.88832026720047,0.4133238494396209,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,186992.08824324608,202381.94254612923,186992.08824324608,15339.094573259354,28.86305069923401,0.0 -418000,2.9272726,1.5628262,,,,,,,,,,,,,, -418100,3.5071461,3.1022174,,,,,,,,,,,,,, -418200,3.5389776,3.1287842,,,,,,,,,,,,,, -418300,2.9574277,1.0706246,,,,,,,,,,,,,, -418400,3.1930919,1.3859942,,,,,,,,,,,,,, -418500,3.0548868,2.3620303,,,,,,,,,,,,,, -418600,3.0468545,1.1495538,,,,,,,,,,,,,, -418700,3.0822008,1.0735247,,,,,,,,,,,,,, -418800,3.0439544,1.1902745,,,,,,,,,,,,,, -418895,,,0.8907421827316284,0.412113755941391,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,187412.05957746503,202835.6733837128,187412.05957746503,15372.716371774672,28.951866388320923,0.0 -418900,3.2073011,1.8565214,,,,,,,,,,,,,, -419000,2.9027908,1.9398096,,,,,,,,,,,,,, -419100,2.8056538,1.400166,,,,,,,,,,,,,, -419200,3.090744,1.1626071,,,,,,,,,,,,,, -419300,3.234534,1.0499363,,,,,,,,,,,,,, -419400,3.0106854,1.0659711,,,,,,,,,,,,,, -419500,3.1103895,1.5063081,,,,,,,,,,,,,, -419600,3.1330125,2.0947032,,,,,,,,,,,,,, -419700,3.0256655,1.133172,,,,,,,,,,,,,, -419800,3.0515208,1.779197,,,,,,,,,,,,,, -419835,,,0.8883398175239563,0.4108054339885711,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,187832.055000782,203290.7157907486,187832.055000782,15407.61594223976,29.050409078598022,0.0 -419900,3.3239653,1.1154329,,,,,,,,,,,,,, -420000,3.2937193,1.0824496,,,,,,,,,,,,,, -420100,2.9964929,1.082978,,,,,,,,,,,,,, -420200,4.1282673,3.0889943,,,,,,,,,,,,,, -420300,3.3798084,1.1065004,,,,,,,,,,,,,, -420400,2.9524777,1.1030593,,,,,,,,,,,,,, -420500,3.342644,2.7774384,,,,,,,,,,,,,, -420600,3.5080507,1.1802794,,,,,,,,,,,,,, -420700,3.3633883,1.2705581,,,,,,,,,,,,,, -420776,,,0.8899804353713989,0.4098820686340332,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,188252.1757707596,203745.8379309177,188252.1757707596,15442.47093296051,29.14796638488769,0.0 -420800,3.1356566,1.6964904,,,,,,,,,,,,,, -420900,3.269471,1.2130903,,,,,,,,,,,,,, -421000,4.053951,3.206593,,,,,,,,,,,,,, -421100,2.831996,1.9497802,,,,,,,,,,,,,, -421200,3.2151399,2.1519566,,,,,,,,,,,,,, -421300,3.1512814,1.1709138,,,,,,,,,,,,,, -421400,3.0555556,1.7880718,,,,,,,,,,,,,, -421500,3.1249032,2.463299,,,,,,,,,,,,,, -421600,3.2180254,2.7391865,,,,,,,,,,,,,, -421700,3.0806632,1.6916955,,,,,,,,,,,,,, -421715,,,0.8877539038658142,0.412142664194107,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,188672.410371542,204201.9252877236,188672.410371542,15478.182670593262,29.23967170715332,0.0 -421800,3.176451,1.1205425,,,,,,,,,,,,,, -421900,3.4247553,2.6127903,,,,,,,,,,,,,, -422000,3.0999067,1.1031201,,,,,,,,,,,,,, -422100,3.6595669,3.2269645,,,,,,,,,,,,,, -422200,3.1152303,1.4053441,,,,,,,,,,,,,, -422300,2.8659968,1.078859,,,,,,,,,,,,,, -422400,3.1917946,1.1356831,,,,,,,,,,,,,, -422500,3.0178144,1.0214653,,,,,,,,,,,,,, -422600,2.9586387,1.5761943,,,,,,,,,,,,,, -422654,,,0.8885155916213989,0.4149311780929565,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,189092.6740632057,204659.1133544445,189092.6740632057,15514.970543146132,29.327571868896484,0.0 -422700,3.0549371,1.1337402,,,,,,,,,,,,,, -422800,3.219642,1.16188,,,,,,,,,,,,,, -422900,3.146873,2.0264027,,,,,,,,,,,,,, -423000,3.0530205,1.1372819,,,,,,,,,,,,,, -423100,3.2169423,1.0876472,,,,,,,,,,,,,, -423200,3.1595962,2.087627,,,,,,,,,,,,,, -423300,3.306267,1.1522571,,,,,,,,,,,,,, -423400,3.6585858,1.1639276,,,,,,,,,,,,,, -423500,3.1466582,1.2662311,,,,,,,,,,,,,, -423593,,,0.88671875,0.4145254790782928,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,189513.013206482,205114.1068179608,189513.013206482,15549.48907828331,29.4142906665802,0.0 -423600,3.3890395,2.7199013,,,,,,,,,,,,,, -423700,3.3381228,2.3486595,,,,,,,,,,,,,, -423800,4.1661353,3.2273316,,,,,,,,,,,,,, -423900,3.096643,1.2783953,,,,,,,,,,,,,, -424000,3.5263212,2.1223378,,,,,,,,,,,,,, -424100,2.985386,1.3917979,,,,,,,,,,,,,, -424200,3.226605,1.0818833,,,,,,,,,,,,,, -424300,3.1225646,1.1434636,,,,,,,,,,,,,, -424400,3.2060974,1.0723681,,,,,,,,,,,,,, -424500,2.9082124,1.6557066,,,,,,,,,,,,,, -424535,,,0.8887695074081421,0.4101097881793976,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,189933.1962766648,205569.40095114708,189933.1962766648,15584.46364927292,29.50173306465149,0.0 -424600,3.1073132,1.7979405,,,,,,,,,,,,,, -424700,3.4230814,3.0045147,,,,,,,,,,,,,, -424800,3.3639097,1.1565876,,,,,,,,,,,,,, -424900,3.0281088,1.1552206,,,,,,,,,,,,,, -425000,3.1922984,2.1574194,,,,,,,,,,,,,, -425100,2.9106712,1.1205523,,,,,,,,,,,,,, -425200,3.2057998,1.2132251,,,,,,,,,,,,,, -425300,3.0643635,2.6662207,,,,,,,,,,,,,, -425400,3.1949608,1.049575,,,,,,,,,,,,,, -425473,,,0.8880859017372131,0.4189074635505676,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,190353.0560581684,206026.12229013443,190353.0560581684,15621.176329135897,29.60228538513184,0.0 -425500,3.5455282,3.1791198,,,,,,,,,,,,,, -425600,3.1010878,2.3857622,,,,,,,,,,,,,, -425700,3.1132815,2.0464153,,,,,,,,,,,,,, -425800,3.348409,2.8097043,,,,,,,,,,,,,, -425900,3.018906,2.042122,,,,,,,,,,,,,, -426000,3.3443477,2.9031463,,,,,,,,,,,,,, -426100,4.308341,3.286572,,,,,,,,,,,,,, -426200,4.252792,3.3050551,,,,,,,,,,,,,, -426300,3.224062,1.9822643,,,,,,,,,,,,,, -426400,3.158625,1.0568793,,,,,,,,,,,,,, -426408,,,0.8875585794448853,0.4154538512229919,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,190773.18273472783,206482.0358054638,190773.18273472783,15656.826608181,29.69010305404663,0.0 -426500,3.3188512,2.905383,,,,,,,,,,,,,, -426600,3.379154,1.3077061,,,,,,,,,,,,,, -426700,4.2599463,3.2283676,,,,,,,,,,,,,, -426800,3.2668648,1.5147587,,,,,,,,,,,,,, -426900,3.4360397,3.0774853,,,,,,,,,,,,,, -427000,3.189397,1.1489561,,,,,,,,,,,,,, -427100,3.8366475,3.3121123,,,,,,,,,,,,,, -427200,3.1885238,1.2795376,,,,,,,,,,,,,, -427300,3.1465476,1.1924965,,,,,,,,,,,,,, -427347,,,0.8865429759025574,0.4164008200168609,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,191193.0832927227,206936.9553265572,191193.0832927227,15691.703318595886,29.78359007835388,0.0 -427400,3.4915428,3.2396624,,,,,,,,,,,,,, -427500,3.4372897,2.8697882,,,,,,,,,,,,,, -427600,3.0888608,1.0999738,,,,,,,,,,,,,, -427700,3.1319249,1.7412485,,,,,,,,,,,,,, -427800,3.6315403,3.162035,,,,,,,,,,,,,, -427900,3.138343,1.173588,,,,,,,,,,,,,, -428000,3.118857,2.69579,,,,,,,,,,,,,, -428100,3.1182818,1.2957063,,,,,,,,,,,,,, -428200,2.9245734,1.1447489,,,,,,,,,,,,,, -428284,,,0.8866796493530273,0.4184194803237915,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,191613.23265123367,207393.0426137448,191613.23265123367,15727.489582061768,29.8861780166626,0.0 -428300,2.7906623,1.9998467,,,,,,,,,,,,,, -428400,3.3237426,2.5751724,,,,,,,,,,,,,, -428500,3.2968538,1.1055312,,,,,,,,,,,,,, -428600,3.181385,1.2130706,,,,,,,,,,,,,, -428700,3.0284758,1.069735,,,,,,,,,,,,,, -428800,2.8271172,1.5846151,,,,,,,,,,,,,, -428900,5.6627426,3.1570752,,,,,,,,,,,,,, -429000,3.8636243,1.2281549,,,,,,,,,,,,,, -429100,3.6515794,3.2006564,,,,,,,,,,,,,, -429200,3.171216,1.6446228,,,,,,,,,,,,,, -429223,,,0.8891991972923279,0.411479115486145,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,192033.45323348045,207847.76594495773,192033.45323348045,15761.853075504305,29.976338148117065,0.0 -429300,3.0487945,1.1888683,,,,,,,,,,,,,, -429400,3.308575,1.1445507,,,,,,,,,,,,,, -429500,2.9840963,1.3593713,,,,,,,,,,,,,, -429600,3.1439354,1.609266,,,,,,,,,,,,,, -429700,3.0166886,1.0548404,,,,,,,,,,,,,, -429800,3.7526538,1.187186,,,,,,,,,,,,,, -429900,3.2985137,1.0692849,,,,,,,,,,,,,, -430000,3.0468545,1.3569144,,,,,,,,,,,,,, -430100,3.0181618,1.840632,,,,,,,,,,,,,, -430160,,,0.8868749737739563,0.4184442162513733,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,192453.38086915016,208302.9422523976,192453.38086915016,15796.966268777847,30.06338238716125,0.0 -430200,3.1366715,1.2817822,,,,,,,,,,,,,, -430300,3.2533376,1.1454623,,,,,,,,,,,,,, -430400,3.0348125,1.8643476,,,,,,,,,,,,,, -430500,3.5725439,3.1768067,,,,,,,,,,,,,, -430600,2.8233895,1.9627883,,,,,,,,,,,,,, -430700,3.7164433,3.1112323,,,,,,,,,,,,,, -430800,3.5689492,1.37828,,,,,,,,,,,,,, -430900,3.210743,1.158269,,,,,,,,,,,,,, -431000,3.3171194,2.7924922,,,,,,,,,,,,,, -431097,,,0.8900390267372131,0.4117555618286133,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,192873.32185602188,208757.430749178,192873.32185602188,15831.376549959185,30.15148520469665,0.0 -431100,3.0437133,1.1650578,,,,,,,,,,,,,, -431200,3.0476537,2.366324,,,,,,,,,,,,,, -431300,3.3874288,1.163399,,,,,,,,,,,,,, -431400,2.9884012,1.2610927,,,,,,,,,,,,,, -431500,3.0145867,1.3586528,,,,,,,,,,,,,, -431600,3.1174061,1.0800583,,,,,,,,,,,,,, -431700,3.0533211,2.056676,,,,,,,,,,,,,, -431800,3.2517102,2.5696435,,,,,,,,,,,,,, -431900,3.6455357,2.7958193,,,,,,,,,,,,,, -432000,2.903922,1.1752796,,,,,,,,,,,,,, -432036,,,0.8875781297683716,0.4167491197586059,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,193293.54080319405,209210.8012115956,193293.54080319405,15864.387818336489,30.24280261993408,0.0 -432100,3.3617418,2.274847,,,,,,,,,,,,,, -432200,3.3326306,1.0769325,,,,,,,,,,,,,, -432300,2.8266957,1.9439683,,,,,,,,,,,,,, -432400,3.3288813,1.197096,,,,,,,,,,,,,, -432500,3.1084988,2.302879,,,,,,,,,,,,,, -432600,3.8251605,3.285216,,,,,,,,,,,,,, -432700,3.1929643,1.516999,,,,,,,,,,,,,, -432800,3.0118942,1.436534,,,,,,,,,,,,,, -432900,3.0669398,1.1259929,,,,,,,,,,,,,, -432975,,,0.88880854845047,0.4180719554424286,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,193713.5750546456,209664.2701013088,193713.5750546456,15897.666623830795,30.3501980304718,0.0 -433000,3.2858026,2.854249,,,,,,,,,,,,,, -433100,3.7684815,3.2569377,,,,,,,,,,,,,, -433200,2.869486,1.3133136,,,,,,,,,,,,,, -433300,3.1901515,1.6275537,,,,,,,,,,,,,, -433400,3.0521977,1.0786254,,,,,,,,,,,,,, -433500,3.9931576,2.9176297,,,,,,,,,,,,,, -433600,3.168854,1.1775497,,,,,,,,,,,,,, -433700,3.290899,2.7484016,,,,,,,,,,,,,, -433800,3.4079955,1.0631227,,,,,,,,,,,,,, -433900,3.0462224,2.5132923,,,,,,,,,,,,,, -433913,,,0.8876367211341858,0.4186112880706787,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,194133.56434631348,210119.51546907425,194133.56434631348,15932.76951098442,30.45403957366944,0.0 -434000,3.1935434,2.2572908,,,,,,,,,,,,,, -434100,3.6318085,3.2011151,,,,,,,,,,,,,, -434200,3.2845285,2.779338,,,,,,,,,,,,,, -434300,2.9818473,1.0991995,,,,,,,,,,,,,, -434400,3.5434532,3.1114872,,,,,,,,,,,,,, -434500,3.1481872,1.0019615,,,,,,,,,,,,,, -434600,3.1966765,1.2267936,,,,,,,,,,,,,, -434700,3.24988,2.6492147,,,,,,,,,,,,,, -434800,3.709089,3.1894224,,,,,,,,,,,,,, -434852,,,0.8868945240974426,0.4147317111492157,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,194553.46647405624,210573.7652647496,194553.46647405624,15966.977990627289,30.544241189956665,0.0 -434900,2.860375,1.0833232,,,,,,,,,,,,,, -435000,3.3290942,2.7763383,,,,,,,,,,,,,, -435100,2.8169403,1.6696501,,,,,,,,,,,,,, -435200,2.997634,1.081627,,,,,,,,,,,,,, -435300,3.3532536,2.8156734,,,,,,,,,,,,,, -435400,3.3302102,1.144883,,,,,,,,,,,,,, -435500,3.5030675,1.5297866,,,,,,,,,,,,,, -435600,3.3077147,1.5949106,,,,,,,,,,,,,, -435700,3.864907,3.2555664,,,,,,,,,,,,,, -435791,,,0.8885546922683716,0.4164343178272247,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,194973.41510796547,211028.4382355213,194973.41510796547,16001.56213068962,30.63562846183777,0.0 -435800,2.9089181,1.2184324,,,,,,,,,,,,,, -435900,3.084338,1.286449,,,,,,,,,,,,,, -436000,3.1053026,2.4351993,,,,,,,,,,,,,, -436100,3.4979782,3.1143923,,,,,,,,,,,,,, -436200,3.3699336,2.698536,,,,,,,,,,,,,, -436300,3.1704032,1.0743124,,,,,,,,,,,,,, -436400,3.9738128,3.2397761,,,,,,,,,,,,,, -436500,3.1433809,1.107057,,,,,,,,,,,,,, -436600,2.821291,1.371947,,,,,,,,,,,,,, -436700,3.9743657,3.2487342,,,,,,,,,,,,,, -436729,,,0.8863476514816284,0.4209548830986023,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,195393.3149046898,211484.39165091515,195393.3149046898,16037.476316690443,30.72657537460327,0.0 -436800,2.9181964,1.5095421,,,,,,,,,,,,,, -436900,3.1105554,1.545131,,,,,,,,,,,,,, -437000,2.893109,1.0744264,,,,,,,,,,,,,, -437100,3.0017333,1.0722477,,,,,,,,,,,,,, -437200,3.79865,3.3009484,,,,,,,,,,,,,, -437300,3.1213768,1.1645083,,,,,,,,,,,,,, -437400,2.741019,1.6206821,,,,,,,,,,,,,, -437500,3.9039319,3.1902883,,,,,,,,,,,,,, -437600,3.296097,1.1274813,,,,,,,,,,,,,, -437668,,,0.8892577886581421,0.4133823215961456,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,195813.20502996445,211939.55713367465,195813.20502996445,16072.608691453934,30.82048726081848,0.0 -437700,3.185637,2.5975318,,,,,,,,,,,,,, -437800,3.157752,2.3644135,,,,,,,,,,,,,, -437900,3.3145485,1.1017005,,,,,,,,,,,,,, -438000,4.1859226,3.3142016,,,,,,,,,,,,,, -438100,2.9322119,1.4187491,,,,,,,,,,,,,, -438200,2.9625988,1.1531746,,,,,,,,,,,,,, -438300,2.998861,1.3152035,,,,,,,,,,,,,, -438400,3.4267423,2.7004278,,,,,,,,,,,,,, -438500,2.8962417,1.5348872,,,,,,,,,,,,,, -438600,2.998593,1.1034979,,,,,,,,,,,,,, -438606,,,0.8886327743530273,0.4106488823890686,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,196233.2885670662,212393.555311203,196233.2885670662,16106.368917942047,30.92542052268982,0.0 -438700,2.86634,1.82788,,,,,,,,,,,,,, -438800,3.1447358,1.1269373,,,,,,,,,,,,,, -438900,3.4092727,2.9897153,,,,,,,,,,,,,, -439000,3.2685022,1.0586195,,,,,,,,,,,,,, -439100,3.178707,2.619987,,,,,,,,,,,,,, -439200,3.1646802,1.1466234,,,,,,,,,,,,,, -439300,3.628005,2.8900769,,,,,,,,,,,,,, -439400,3.0104003,1.2915287,,,,,,,,,,,,,, -439500,3.056077,2.2430978,,,,,,,,,,,,,, -439545,,,0.8883788585662842,0.4163420498371124,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,196653.4536757469,212849.3200573921,196653.4536757469,16141.804993867874,31.039782285690308,0.0 -439600,2.9216664,2.356402,,,,,,,,,,,,,, -439700,2.7471037,1.7192041,,,,,,,,,,,,,, -439800,3.4292772,2.4687455,,,,,,,,,,,,,, -439900,2.8881269,1.2143292,,,,,,,,,,,,,, -440000,3.1656625,1.1114023,,,,,,,,,,,,,, -440100,3.1470246,1.5129765,,,,,,,,,,,,,, -440200,3.2232797,1.9398,,,,,,,,,,,,,, -440300,3.3018186,1.1804787,,,,,,,,,,,,,, -440400,4.1527615,3.1970341,,,,,,,,,,,,,, -440484,,,0.8870507478713989,0.4158827364444732,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,197073.50745630264,213305.20050811768,197073.50745630264,16177.474328279495,31.147948503494263,0.0 -440500,3.0563707,1.23792,,,,,,,,,,,,,, -440600,3.612872,3.2099216,,,,,,,,,,,,,, -440700,2.9711983,1.091882,,,,,,,,,,,,,, -440800,2.8750083,2.0868566,,,,,,,,,,,,,, -440900,3.8058493,2.9901874,,,,,,,,,,,,,, -441000,2.93279,2.41035,,,,,,,,,,,,,, -441100,3.3103929,1.3792629,,,,,,,,,,,,,, -441200,3.185123,1.7572229,,,,,,,,,,,,,, -441300,3.5919266,3.124665,,,,,,,,,,,,,, -441400,3.1421173,1.1086509,,,,,,,,,,,,,, -441423,,,0.8891406059265137,0.4135304689407348,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,197493.4744989872,213761.4383919239,197493.4744989872,16213.603892564774,31.240306854248047,0.0 -441500,2.9341826,1.228197,,,,,,,,,,,,,, -441600,3.320346,1.5174686,,,,,,,,,,,,,, -441700,3.3106267,2.1292489,,,,,,,,,,,,,, -441800,2.8362517,1.871952,,,,,,,,,,,,,, -441900,2.9527946,1.0948288,,,,,,,,,,,,,, -442000,3.2797432,1.096376,,,,,,,,,,,,,, -442100,3.0245578,1.1088482,,,,,,,,,,,,,, -442200,3.1608949,1.1109977,,,,,,,,,,,,,, -442300,2.9274192,2.3047132,,,,,,,,,,,,,, -442361,,,0.8888866901397705,0.4150472581386566,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,197913.55271673205,214216.4568591118,197913.55271673205,16248.398986577988,31.336047172546387,0.0 -442400,4.2386336,3.2790425,,,,,,,,,,,,,, -442500,4.4663405,1.2291119,,,,,,,,,,,,,, -442600,2.8886168,1.1205102,,,,,,,,,,,,,, -442700,2.8677948,1.7319417,,,,,,,,,,,,,, -442800,3.0200558,1.31254,,,,,,,,,,,,,, -442900,2.9820263,1.1570269,,,,,,,,,,,,,, -443000,3.0482757,1.2403581,,,,,,,,,,,,,, -443100,2.726858,1.6133999,,,,,,,,,,,,,, -443200,3.262021,1.122374,,,,,,,,,,,,,, -443298,,,0.8919921517372131,0.4033390283584595,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,198333.7465114593,214672.06289815903,198333.7465114593,16283.67305135727,31.426677703857425,0.0 -443300,3.1991956,1.2125968,,,,,,,,,,,,,, -443400,3.0563717,1.1576511,,,,,,,,,,,,,, -443500,3.4780407,1.1158887,,,,,,,,,,,,,, -443600,2.9920306,1.6651051,,,,,,,,,,,,,, -443700,2.9858055,1.1648319,,,,,,,,,,,,,, -443800,2.9823961,1.8021061,,,,,,,,,,,,,, -443900,3.285959,1.0901501,,,,,,,,,,,,,, -444000,2.9096627,1.1719029,,,,,,,,,,,,,, -444100,4.1888475,1.864947,,,,,,,,,,,,,, -444200,3.2308507,1.140557,,,,,,,,,,,,,, -444235,,,0.8896874785423279,0.4114967882633209,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,198753.99509072304,215128.74952864647,198753.99509072304,16319.972463607788,31.51662302017212,0.0 -444300,3.1403697,1.2083422,,,,,,,,,,,,,, -444400,3.2201862,1.1333276,,,,,,,,,,,,,, -444500,4.176666,3.0637848,,,,,,,,,,,,,, -444600,3.333254,2.5536847,,,,,,,,,,,,,, -444700,2.9589996,2.0621603,,,,,,,,,,,,,, -444800,3.1907887,1.3624853,,,,,,,,,,,,,, -444900,3.205942,1.0926417,,,,,,,,,,,,,, -445000,5.638695,1.2478776,,,,,,,,,,,,,, -445100,3.2840204,1.2682183,,,,,,,,,,,,,, -445174,,,0.8869726657867432,0.4166136384010315,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,199174.1624150276,215584.6144790649,199174.1624150276,16355.516724586489,31.621028661727905,0.0 -445200,3.166102,1.2612591,,,,,,,,,,,,,, -445300,3.1981113,1.2229829,,,,,,,,,,,,,, -445400,3.3423867,2.747059,,,,,,,,,,,,,, -445500,3.446769,3.133587,,,,,,,,,,,,,, -445600,3.419616,2.9592059,,,,,,,,,,,,,, -445700,3.2120347,1.7579781,,,,,,,,,,,,,, -445800,3.3960822,3.0561209,,,,,,,,,,,,,, -445900,2.9493053,1.8962816,,,,,,,,,,,,,, -446000,3.4461203,1.1285759,,,,,,,,,,,,,, -446100,3.3018064,2.6623962,,,,,,,,,,,,,, -446111,,,0.8872851133346558,0.4126180410385132,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,199594.30109024048,216040.7579071521,199594.30109024048,16391.367474079132,31.726098775863647,0.0 -446200,3.0436444,2.1482017,,,,,,,,,,,,,, -446300,3.1695032,1.4288862,,,,,,,,,,,,,, -446400,2.978201,1.0908799,,,,,,,,,,,,,, -446500,3.2094557,1.0503582,,,,,,,,,,,,,, -446600,2.9251914,1.0663383,,,,,,,,,,,,,, -446700,3.7168124,3.0335946,,,,,,,,,,,,,, -446800,3.6935697,3.0420187,,,,,,,,,,,,,, -446900,3.1789277,1.1010089,,,,,,,,,,,,,, -447000,3.3347604,1.1041182,,,,,,,,,,,,,, -447049,,,0.88734370470047,0.4156618416309356,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,200014.64024281505,216497.158163786,200014.64024281505,16427.27664041519,31.828654527664185,0.0 -447100,3.1006937,1.8037022,,,,,,,,,,,,,, -447200,3.2815924,1.2591702,,,,,,,,,,,,,, -447300,2.9322593,2.0434494,,,,,,,,,,,,,, -447400,3.263589,2.3456628,,,,,,,,,,,,,, -447500,3.2460067,1.961106,,,,,,,,,,,,,, -447600,3.305403,1.1213702,,,,,,,,,,,,,, -447700,3.093494,2.017644,,,,,,,,,,,,,, -447800,3.0771112,1.4896111,,,,,,,,,,,,,, -447900,3.059395,2.0182185,,,,,,,,,,,,,, -447986,,,0.8880078196525574,0.4107287228107452,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,200434.81145572665,216953.36850714684,200434.81145572665,16463.17485022545,31.92135524749756,0.0 -448000,3.1416276,1.3084711,,,,,,,,,,,,,, -448100,2.819213,1.3705792,,,,,,,,,,,,,, -448200,3.0483017,1.0438609,,,,,,,,,,,,,, -448300,3.229653,2.7089832,,,,,,,,,,,,,, -448400,3.0990667,1.5095553,,,,,,,,,,,,,, -448500,2.8255534,1.4085019,,,,,,,,,,,,,, -448600,2.9740717,1.0659682,,,,,,,,,,,,,, -448700,2.8314285,1.8575525,,,,,,,,,,,,,, -448800,3.025833,1.3009422,,,,,,,,,,,,,, -448900,3.1478167,1.7511858,,,,,,,,,,,,,, -448924,,,0.8885155916213989,0.4190287590026855,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,200855.06634163857,217409.7971892357,200855.06634163857,16499.20669221878,32.01465153694153,0.0 -449000,3.1298351,1.1669564,,,,,,,,,,,,,, -449100,2.961244,1.1401356,,,,,,,,,,,,,, -449200,2.9627044,1.7223713,,,,,,,,,,,,,, -449300,2.9847245,1.8836862,,,,,,,,,,,,,, -449400,3.01972,1.0648615,,,,,,,,,,,,,, -449500,2.798757,1.4810009,,,,,,,,,,,,,, -449600,3.2971838,1.2877915,,,,,,,,,,,,,, -449700,4.363537,3.2639563,,,,,,,,,,,,,, -449800,3.0817392,2.545084,,,,,,,,,,,,,, -449862,,,0.8880664110183716,0.4124629497528076,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,201274.92774033544,217865.0299890041,201274.92774033544,16534.43468928337,32.109177350997925,0.0 -449900,3.3743474,2.8983917,,,,,,,,,,,,,, -450000,3.128938,1.3348467,,,,,,,,,,,,,, -450100,3.1737342,1.1509498,,,,,,,,,,,,,, -450200,3.1434586,1.5147539,,,,,,,,,,,,,, -450300,3.2277248,2.0446608,,,,,,,,,,,,,, -450400,3.8281703,3.3079295,,,,,,,,,,,,,, -450500,3.1573474,1.1561494,,,,,,,,,,,,,, -450600,3.1460843,1.1674163,,,,,,,,,,,,,, -450700,3.2178578,1.1348534,,,,,,,,,,,,,, -450799,,,0.8866210579872131,0.4189048409461975,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,201694.9643316269,218321.3376784325,201694.9643316269,16570.566035032272,32.20013737678528,0.0 -450800,3.223738,1.6724118,,,,,,,,,,,,,, -450900,3.0361137,1.0836759,,,,,,,,,,,,,, -451000,3.098788,1.107523,,,,,,,,,,,,,, -451100,3.7489622,3.2484317,,,,,,,,,,,,,, -451200,3.3777466,1.2332035,,,,,,,,,,,,,, -451300,3.0013757,1.1932153,,,,,,,,,,,,,, -451400,2.8641984,1.164688,,,,,,,,,,,,,, -451500,4.2045455,2.989469,,,,,,,,,,,,,, -451600,2.9469001,2.2772167,,,,,,,,,,,,,, -451700,3.1525116,1.966316,,,,,,,,,,,,,, -451735,,,0.8884375095367432,0.4119004011154175,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,202115.2426486016,218777.88199329376,202115.2426486016,16606.675546884537,32.308313846588135,0.0 -451800,3.1546397,1.156365,,,,,,,,,,,,,, -451900,3.6293547,2.578378,,,,,,,,,,,,,, -452000,3.0028713,1.1711112,,,,,,,,,,,,,, -452100,3.093344,1.9798249,,,,,,,,,,,,,, -452200,2.9474642,2.0374706,,,,,,,,,,,,,, -452300,4.3382425,2.8776984,,,,,,,,,,,,,, -452400,3.5607862,3.028362,,,,,,,,,,,,,, -452500,3.0864518,2.2173283,,,,,,,,,,,,,, -452600,2.897513,1.5851265,,,,,,,,,,,,,, -452675,,,0.8880664110183716,0.4146206080913543,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,202535.2054517269,219233.2810678482,202535.2054517269,16641.969081640244,32.402332067489624,0.0 -452700,2.9599829,2.3937285,,,,,,,,,,,,,, -452800,3.0703018,1.7121767,,,,,,,,,,,,,, -452900,5.0285573,1.6299895,,,,,,,,,,,,,, -453000,2.922677,1.2718892,,,,,,,,,,,,,, -453100,3.5209358,1.087346,,,,,,,,,,,,,, -453200,3.2398763,1.1736928,,,,,,,,,,,,,, -453300,2.9219778,1.8195962,,,,,,,,,,,,,, -453400,3.0345008,1.1744952,,,,,,,,,,,,,, -453500,3.1348345,2.0199666,,,,,,,,,,,,,, -453600,3.8786483,2.9336782,,,,,,,,,,,,,, -453614,,,0.8863085508346558,0.4198028445243835,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,202955.50574207303,219687.5152556896,202955.50574207303,16675.76238799095,32.494225025177,0.0 -453700,3.0263684,1.2100897,,,,,,,,,,,,,, -453800,3.1839917,1.1784079,,,,,,,,,,,,,, -453900,3.9561465,3.2445664,,,,,,,,,,,,,, -454000,2.976475,1.0736454,,,,,,,,,,,,,, -454100,3.260591,2.2845988,,,,,,,,,,,,,, -454200,3.1224754,1.1402538,,,,,,,,,,,,,, -454300,3.2585905,1.2021984,,,,,,,,,,,,,, -454400,3.7102818,3.1091123,,,,,,,,,,,,,, -454500,3.2791386,2.1091678,,,,,,,,,,,,,, -454552,,,0.8874413967132568,0.418776273727417,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,203375.631287098,220144.4460659027,203375.631287098,16712.418529748917,32.59436345100403,0.0 -454600,3.3726761,1.2019117,,,,,,,,,,,,,, -454700,3.4895887,1.1683387,,,,,,,,,,,,,, -454800,3.6899729,3.0536826,,,,,,,,,,,,,, -454900,3.1305373,2.5220585,,,,,,,,,,,,,, -455000,3.5592237,2.9320998,,,,,,,,,,,,,, -455100,3.68514,3.1338582,,,,,,,,,,,,,, -455200,3.127646,1.2010775,,,,,,,,,,,,,, -455300,3.4161868,3.0259264,,,,,,,,,,,,,, -455400,2.7770896,1.9277563,,,,,,,,,,,,,, -455492,,,0.8889843821525574,0.4140480756759643,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,203795.91160702705,220601.0011694432,203795.91160702705,16748.549387454987,32.69032335281372,0.0 -455500,3.1304169,1.0841255,,,,,,,,,,,,,, -455600,2.874426,1.0254961,,,,,,,,,,,,,, -455700,3.2973218,1.1492041,,,,,,,,,,,,,, -455800,3.2545094,2.772117,,,,,,,,,,,,,, -455900,3.1572282,1.1666209,,,,,,,,,,,,,, -456000,3.1119003,1.3104725,,,,,,,,,,,,,, -456100,2.993823,2.2218256,,,,,,,,,,,,,, -456200,3.3358796,2.007208,,,,,,,,,,,,,, -456300,3.891173,3.2284148,,,,,,,,,,,,,, -456400,3.1192727,1.1854229,,,,,,,,,,,,,, -456431,,,0.8893945217132568,0.4115277528762817,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,204215.98255586624,221056.6145207882,204215.98255586624,16783.94956278801,32.78437304496765,0.0 -456500,2.9652615,2.1108468,,,,,,,,,,,,,, -456600,3.2710772,1.0881695,,,,,,,,,,,,,, -456700,3.2541015,1.0812136,,,,,,,,,,,,,, -456800,2.9918246,2.344286,,,,,,,,,,,,,, -456900,3.798815,3.252331,,,,,,,,,,,,,, -457000,2.9997187,1.0611968,,,,,,,,,,,,,, -457100,3.264352,1.1145056,,,,,,,,,,,,,, -457200,3.2700763,1.1595771,,,,,,,,,,,,,, -457300,3.3059945,1.1136768,,,,,,,,,,,,,, -457370,,,0.8887109160423279,0.4161616861820221,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,204635.8847630024,221511.99255919456,204635.8847630024,16819.28258252144,32.87864851951599,0.0 -457400,3.2281156,1.1652153,,,,,,,,,,,,,, -457500,3.091291,1.1955202,,,,,,,,,,,,,, -457600,3.1853848,1.6665494,,,,,,,,,,,,,, -457700,3.2418568,2.840831,,,,,,,,,,,,,, -457800,3.1777086,2.14862,,,,,,,,,,,,,, -457900,3.1971781,1.1300814,,,,,,,,,,,,,, -458000,3.2571943,1.1810651,,,,,,,,,,,,,, -458100,3.050101,1.0237644,,,,,,,,,,,,,, -458200,3.1554549,1.1797279,,,,,,,,,,,,,, -458300,2.8635151,2.1131513,,,,,,,,,,,,,, -458307,,,0.8866796493530273,0.4212839007377624,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,205055.78770446777,221966.20752811432,205055.78770446777,16853.451112270355,32.97318458557129,0.0 -458400,3.0437045,1.2047062,,,,,,,,,,,,,, -458500,3.145467,2.7711418,,,,,,,,,,,,,, -458600,3.1393914,1.3914871,,,,,,,,,,,,,, -458700,3.0467856,1.8011,,,,,,,,,,,,,, -458800,2.9071767,2.2717223,,,,,,,,,,,,,, -458900,3.1254437,1.1026582,,,,,,,,,,,,,, -459000,2.9149508,1.8992188,,,,,,,,,,,,,, -459100,3.2367532,1.0640928,,,,,,,,,,,,,, -459200,3.843916,3.1486905,,,,,,,,,,,,,, -459242,,,0.8888671398162842,0.4136143922805786,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,205475.6713354588,222421.1706006527,205475.6713354588,16888.39018678665,33.064754486083984,0.0 -459300,3.1489003,2.039472,,,,,,,,,,,,,, -459400,3.0319471,2.4050293,,,,,,,,,,,,,, -459500,2.974403,1.1316338,,,,,,,,,,,,,, -459600,3.5310302,3.0503054,,,,,,,,,,,,,, -459700,3.0401764,1.0243145,,,,,,,,,,,,,, -459800,3.0540257,1.3203708,,,,,,,,,,,,,, -459900,2.7673,1.5468354,,,,,,,,,,,,,, -460000,3.5896184,3.054113,,,,,,,,,,,,,, -460100,3.5754526,1.9595432,,,,,,,,,,,,,, -460179,,,0.8866991996765137,0.41916623711586,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,205895.7057659626,222875.7852590084,205895.7057659626,16922.828406095505,33.157649517059326,0.0 -460200,3.162087,1.0953414,,,,,,,,,,,,,, -460300,3.175853,1.2312734,,,,,,,,,,,,,, -460400,2.9974647,1.18946,,,,,,,,,,,,,, -460500,3.387,1.0933291,,,,,,,,,,,,,, -460600,3.4466462,1.1872135,,,,,,,,,,,,,, -460700,3.9961593,3.0837703,,,,,,,,,,,,,, -460800,2.8276913,1.9880285,,,,,,,,,,,,,, -460900,3.2896318,1.1880745,,,,,,,,,,,,,, -461000,2.890401,2.1453805,,,,,,,,,,,,,, -461100,2.8094194,1.3218255,,,,,,,,,,,,,, -461118,,,0.8881054520606995,0.4115708768367767,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,206315.9453783036,223331.3592851162,206315.9453783036,16958.004980802536,33.26661014556885,0.0 -461200,3.348309,2.0312328,,,,,,,,,,,,,, -461300,3.4653695,1.1948162,,,,,,,,,,,,,, -461400,3.0288746,1.4671955,,,,,,,,,,,,,, -461500,2.9891365,1.8422453,,,,,,,,,,,,,, -461600,3.1126196,1.4177297,,,,,,,,,,,,,, -461700,3.159039,1.2672955,,,,,,,,,,,,,, -461800,3.4037979,1.8700792,,,,,,,,,,,,,, -461900,3.590994,1.2021856,,,,,,,,,,,,,, -462000,2.8629792,1.2009426,,,,,,,,,,,,,, -462057,,,0.8887499570846558,0.4157285094261169,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,206736.0901172161,223788.6147623062,206736.0901172161,16994.971632242203,33.36212921142578,0.0 -462100,3.0694447,1.3161138,,,,,,,,,,,,,, -462200,3.972776,3.252315,,,,,,,,,,,,,, -462300,3.4862168,2.919374,,,,,,,,,,,,,, -462400,5.2187943,3.3470979,,,,,,,,,,,,,, -462500,3.524257,1.8378549,,,,,,,,,,,,,, -462600,3.008477,1.5354658,,,,,,,,,,,,,, -462700,3.0590734,1.0123999,,,,,,,,,,,,,, -462800,3.1146352,1.1187977,,,,,,,,,,,,,, -462900,2.979086,1.1412762,,,,,,,,,,,,,, -462995,,,0.8881054520606995,0.4158811569213867,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,207156.234416008,224244.74796533585,207156.234416008,17030.81745314598,33.45702028274536,0.0 -463000,2.8667648,2.0679142,,,,,,,,,,,,,, -463100,3.4176784,3.1744812,,,,,,,,,,,,,, -463200,2.9054139,1.4061873,,,,,,,,,,,,,, -463300,2.842481,2.2751899,,,,,,,,,,,,,, -463400,3.0592592,1.0780298,,,,,,,,,,,,,, -463500,3.1386728,1.2327902,,,,,,,,,,,,,, -463600,3.256458,2.7410736,,,,,,,,,,,,,, -463700,2.9729095,1.201639,,,,,,,,,,,,,, -463800,3.1635098,2.5354612,,,,,,,,,,,,,, -463900,3.1283474,1.1808382,,,,,,,,,,,,,, -463932,,,0.8861132860183716,0.4176024496555328,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,207576.18787121773,224700.5097444057,207576.18787121773,17066.472289562225,33.561460733413696,0.0 -464000,2.9285157,1.2743342,,,,,,,,,,,,,, -464100,3.1145098,1.9257853,,,,,,,,,,,,,, -464200,3.871643,1.0416247,,,,,,,,,,,,,, -464300,3.8818996,3.2920063,,,,,,,,,,,,,, -464400,3.1435308,1.4348044,,,,,,,,,,,,,, -464500,3.238649,1.4216467,,,,,,,,,,,,,, -464600,3.0338063,1.1168827,,,,,,,,,,,,,, -464700,2.9563704,1.4280727,,,,,,,,,,,,,, -464800,3.0872157,1.1477476,,,,,,,,,,,,,, -464867,,,0.889941394329071,0.4139235317707062,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,207996.31915855408,225155.25125455856,207996.31915855408,17100.941737413406,33.65341758728027,0.0 -464900,3.1324525,1.121511,,,,,,,,,,,,,, -465000,2.9186301,1.4827433,,,,,,,,,,,,,, -465100,3.1293147,2.113374,,,,,,,,,,,,,, -465200,3.2517295,1.2092665,,,,,,,,,,,,,, -465300,3.1187706,2.7806847,,,,,,,,,,,,,, -465400,3.0310717,1.1074697,,,,,,,,,,,,,, -465500,3.5337508,2.7369475,,,,,,,,,,,,,, -465600,3.6486766,1.1926208,,,,,,,,,,,,,, -465700,3.7670937,1.1678387,,,,,,,,,,,,,, -465800,3.3485441,3.0431385,,,,,,,,,,,,,, -465804,,,0.8892577886581421,0.4131757915019989,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,208416.38221096992,225611.26416301727,208416.38221096992,17136.74747443199,33.74782729148865,0.0 -465900,3.030712,1.093524,,,,,,,,,,,,,, -466000,3.221175,2.6976163,,,,,,,,,,,,,, -466100,3.0727563,1.0910393,,,,,,,,,,,,,, -466200,2.9129884,2.0505123,,,,,,,,,,,,,, -466300,2.937076,1.4695659,,,,,,,,,,,,,, -466400,5.3769975,3.193583,,,,,,,,,,,,,, -466500,3.0400863,1.4718376,,,,,,,,,,,,,, -466600,3.0012774,1.0992537,,,,,,,,,,,,,, -466700,4.0532265,3.248041,,,,,,,,,,,,,, -466742,,,0.8897460699081421,0.4074950218200683,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,208836.4301431179,226067.72537398327,208836.4301431179,17173.017083644867,33.8416645526886,0.0 -466800,3.163121,1.0416918,,,,,,,,,,,,,, -466900,3.3064985,2.7242932,,,,,,,,,,,,,, -467000,3.0801978,1.3836515,,,,,,,,,,,,,, -467100,4.420214,3.3715858,,,,,,,,,,,,,, -467200,5.59091,1.0633861,,,,,,,,,,,,,, -467300,3.1248062,1.5528958,,,,,,,,,,,,,, -467400,3.65997,1.2146537,,,,,,,,,,,,,, -467500,3.1779313,1.11392,,,,,,,,,,,,,, -467600,3.0109885,1.7927316,,,,,,,,,,,,,, -467682,,,0.88929682970047,0.4114511907100677,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,209256.819983244,226523.2222290039,209256.819983244,17207.967556238174,33.94882917404175,0.0 -467700,3.3465998,1.1009204,,,,,,,,,,,,,, -467800,3.0501823,1.1091102,,,,,,,,,,,,,, -467900,3.2606723,1.0687301,,,,,,,,,,,,,, -468000,2.942875,1.0696323,,,,,,,,,,,,,, -468100,3.094475,1.1622066,,,,,,,,,,,,,, -468200,3.2662618,2.7687192,,,,,,,,,,,,,, -468300,3.177597,1.0852429,,,,,,,,,,,,,, -468400,3.8907008,3.2639945,,,,,,,,,,,,,, -468500,3.551398,3.0059514,,,,,,,,,,,,,, -468600,2.9630635,1.0848196,,,,,,,,,,,,,, -468619,,,0.8893749713897705,0.4095381200313568,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,209676.87094807625,226977.7106547356,209676.87094807625,17242.25678539276,34.04836964607239,0.0 -468700,3.7552068,1.1456406,,,,,,,,,,,,,, -468800,3.9470963,3.2192774,,,,,,,,,,,,,, -468900,3.4092505,2.8719223,,,,,,,,,,,,,, -469000,3.0595858,2.644657,,,,,,,,,,,,,, -469100,3.2044106,2.6381042,,,,,,,,,,,,,, -469200,3.0527277,1.1134269,,,,,,,,,,,,,, -469300,2.966578,1.2328883,,,,,,,,,,,,,, -469400,3.4959211,3.1301103,,,,,,,,,,,,,, -469500,3.3006234,1.0819085,,,,,,,,,,,,,, -469559,,,0.8883788585662842,0.4138410091400146,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,210097.1221292019,227432.9923608303,210097.1221292019,17277.135021448135,34.15190577507019,0.0 -469600,3.9610875,3.1476893,,,,,,,,,,,,,, -469700,3.4752061,3.0175421,,,,,,,,,,,,,, -469800,3.2020938,2.0141754,,,,,,,,,,,,,, -469900,2.9256182,1.6339307,,,,,,,,,,,,,, -470000,3.1057138,1.1249762,,,,,,,,,,,,,, -470100,2.6811585,1.0597078,,,,,,,,,,,,,, -470200,3.247097,2.9884608,,,,,,,,,,,,,, -470300,3.5368085,1.2040147,,,,,,,,,,,,,, -470400,3.0172284,2.2835474,,,,,,,,,,,,,, -470494,,,0.8877343535423279,0.4128125607967376,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,210517.18369412425,227889.16590118408,210517.18369412425,17313.09035563469,34.2598888874054,0.0 -470500,3.6162057,1.267235,,,,,,,,,,,,,, -470600,3.0506694,1.3189709,,,,,,,,,,,,,, -470700,2.9178524,1.2884309,,,,,,,,,,,,,, -470800,3.0748293,2.273897,,,,,,,,,,,,,, -470900,2.9962828,1.1111568,,,,,,,,,,,,,, -471000,3.2357292,1.1754138,,,,,,,,,,,,,, -471100,2.9060187,1.6957326,,,,,,,,,,,,,, -471200,3.3925872,1.1573203,,,,,,,,,,,,,, -471300,2.8545823,1.0755696,,,,,,,,,,,,,, -471400,3.1551259,2.494746,,,,,,,,,,,,,, -471435,,,0.8878124952316284,0.4137357473373413,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,210937.3222939968,228345.77362632751,210937.3222939968,17349.413534879684,34.35630774497986,0.0 -471500,3.3452556,2.9728193,,,,,,,,,,,,,, -471600,3.2989504,2.8740978,,,,,,,,,,,,,, -471700,3.5773268,3.192206,,,,,,,,,,,,,, -471800,3.5704894,2.90903,,,,,,,,,,,,,, -471900,2.9490542,1.6802745,,,,,,,,,,,,,, -472000,4.0625706,3.2194142,,,,,,,,,,,,,, -472100,2.967029,2.3932245,,,,,,,,,,,,,, -472200,2.9818518,1.1193299,,,,,,,,,,,,,, -472300,3.3978245,2.6467712,,,,,,,,,,,,,, -472376,,,0.8859961032867432,0.421807587146759,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,211357.5579998493,228801.10828518867,211357.5579998493,17384.368765592575,34.45077323913574,0.0 -472400,3.4925,3.1349823,,,,,,,,,,,,,, -472500,2.9825902,1.5011618,,,,,,,,,,,,,, -472600,3.1491237,1.1344848,,,,,,,,,,,,,, -472700,3.1045146,1.1546221,,,,,,,,,,,,,, -472800,3.8517463,3.1105385,,,,,,,,,,,,,, -472900,2.9580219,2.2412558,,,,,,,,,,,,,, -473000,4.241324,3.3172343,,,,,,,,,,,,,, -473100,3.2869177,1.1260326,,,,,,,,,,,,,, -473200,3.1265259,1.0616409,,,,,,,,,,,,,, -473300,3.2693324,2.018105,,,,,,,,,,,,,, -473316,,,0.8886523246765137,0.4129462838172912,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,211777.6955449581,229257.14534235,211777.6955449581,17420.123772144318,34.5459508895874,0.0 -473400,3.2511542,1.327604,,,,,,,,,,,,,, -473500,2.951315,2.376222,,,,,,,,,,,,,, -473600,2.9097838,0.973325,,,,,,,,,,,,,, -473700,3.8143904,3.1603825,,,,,,,,,,,,,, -473800,2.7993646,1.5634131,,,,,,,,,,,,,, -473900,3.0420575,1.8331898,,,,,,,,,,,,,, -474000,3.625649,2.8425996,,,,,,,,,,,,,, -474100,2.9478173,1.1277877,,,,,,,,,,,,,, -474200,2.8355021,2.1126413,,,,,,,,,,,,,, -474255,,,0.8865429759025574,0.4211982488632202,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,212197.9356815815,229711.0480697155,212197.9356815815,17453.64389872551,34.639634132385254,0.0 -474300,3.09496,1.163173,,,,,,,,,,,,,, -474400,3.0964298,1.2154381,,,,,,,,,,,,,, -474500,3.163803,1.0512722,,,,,,,,,,,,,, -474600,3.1072037,1.5723016,,,,,,,,,,,,,, -474700,3.36058,1.8820678,,,,,,,,,,,,,, -474800,3.178517,1.1299238,,,,,,,,,,,,,, -474900,3.2140226,1.845847,,,,,,,,,,,,,, -475000,3.1688838,1.1236619,,,,,,,,,,,,,, -475100,2.9290383,1.05385,,,,,,,,,,,,,, -475192,,,0.888671875,0.4101220667362213,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,212617.8903603553,230167.2466096878,212617.8903603553,17489.702049016953,34.77641224861145,0.0 -475200,3.0162683,1.2148234,,,,,,,,,,,,,, -475300,3.0653489,1.1314129,,,,,,,,,,,,,, -475400,3.1057491,1.8809528,,,,,,,,,,,,,, -475500,3.2894862,2.6505384,,,,,,,,,,,,,, -475600,3.339599,2.0709448,,,,,,,,,,,,,, -475700,3.421187,1.1511055,,,,,,,,,,,,,, -475800,3.2938807,1.1948646,,,,,,,,,,,,,, -475900,3.087945,2.1045697,,,,,,,,,,,,,, -476000,3.1662831,1.143746,,,,,,,,,,,,,, -476100,3.563628,2.8571541,,,,,,,,,,,,,, -476134,,,0.8873632550239563,0.415277898311615,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,213038.1022245884,230622.10814762115,213038.1022245884,17524.20827102661,34.870837688446045,0.0 -476200,3.0777214,1.1313976,,,,,,,,,,,,,, -476300,4.251828,3.2089176,,,,,,,,,,,,,, -476400,3.4292996,1.3037337,,,,,,,,,,,,,, -476500,3.360095,1.184128,,,,,,,,,,,,,, -476600,3.820522,2.9040666,,,,,,,,,,,,,, -476700,2.9533885,1.2960747,,,,,,,,,,,,,, -476800,3.114409,1.0924397,,,,,,,,,,,,,, -476900,3.1106,1.2149493,,,,,,,,,,,,,, -477000,3.6287365,3.251731,,,,,,,,,,,,,, -477074,,,0.8871874809265137,0.4179157018661499,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,213458.12177467343,231077.1324160099,213458.12177467343,17559.069897413254,34.96437788009644,0.0 -477100,3.3061452,2.7395856,,,,,,,,,,,,,, -477200,3.1658146,1.1149288,,,,,,,,,,,,,, -477300,2.8156652,1.6779118,,,,,,,,,,,,,, -477400,3.07456,1.0885804,,,,,,,,,,,,,, -477500,3.0137894,2.5569062,,,,,,,,,,,,,, -477600,3.3755827,2.631466,,,,,,,,,,,,,, -477700,3.3157687,1.3502953,,,,,,,,,,,,,, -477800,3.2823699,1.1426456,,,,,,,,,,,,,, -477900,3.8937848,3.312368,,,,,,,,,,,,,, -478000,3.0863104,1.0956733,,,,,,,,,,,,,, -478014,,,0.8897460699081421,0.4113948047161102,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,213878.38510799408,231532.45939064023,213878.38510799408,17593.989289283752,35.059608697891235,0.0 -478100,3.0311708,2.6789265,,,,,,,,,,,,,, -478200,2.9979408,1.2673974,,,,,,,,,,,,,, -478300,2.7402937,1.4986813,,,,,,,,,,,,,, -478400,3.8898993,3.153369,,,,,,,,,,,,,, -478500,4.0817113,1.1310426,,,,,,,,,,,,,, -478600,3.8694913,3.0177662,,,,,,,,,,,,,, -478700,2.880292,1.2026628,,,,,,,,,,,,,, -478800,3.2429209,2.8407393,,,,,,,,,,,,,, -478900,3.3352263,2.6529088,,,,,,,,,,,,,, -478951,,,0.8887109160423279,0.4161091148853302,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,214298.27096557617,231988.60705900192,214298.27096557617,17630.104825258255,35.1567656993866,0.0 -479000,3.0305777,2.5051866,,,,,,,,,,,,,, -479100,2.9042373,1.3066438,,,,,,,,,,,,,, -479200,3.0786011,1.0269883,,,,,,,,,,,,,, -479300,3.3667648,1.1714904,,,,,,,,,,,,,, -479400,2.9708261,1.9025211,,,,,,,,,,,,,, -479500,3.0434737,1.4095258,,,,,,,,,,,,,, -479600,3.3767266,1.1380453,,,,,,,,,,,,,, -479700,3.6481469,2.918393,,,,,,,,,,,,,, -479800,4.3619075,3.151873,,,,,,,,,,,,,, -479893,,,0.8892577886581421,0.4139748215675354,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,214718.205119133,232443.3776421547,214718.205119133,17664.77192544937,35.27648377418518,0.0 -479900,2.9771106,1.3007687,,,,,,,,,,,,,, -480000,3.611186,2.9621396,,,,,,,,,,,,,, -480100,3.243219,2.831699,,,,,,,,,,,,,, -480200,3.104989,1.236535,,,,,,,,,,,,,, -480300,3.1462994,1.0925374,,,,,,,,,,,,,, -480400,3.5388978,3.1307578,,,,,,,,,,,,,, -480500,2.8113146,1.2033308,,,,,,,,,,,,,, -480600,3.1035771,1.2018497,,,,,,,,,,,,,, -480700,2.9098172,1.2269585,,,,,,,,,,,,,, -480800,3.379476,1.1426125,,,,,,,,,,,,,, -480834,,,0.8868163824081421,0.4217853844165802,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,215138.26736354828,232897.3886890412,215138.26736354828,17698.574368476868,35.3733856678009,0.0 -480900,3.2272303,2.6532195,,,,,,,,,,,,,, -481000,2.9155977,1.1743654,,,,,,,,,,,,,, -481100,4.31418,3.2359514,,,,,,,,,,,,,, -481200,3.7444427,2.7484124,,,,,,,,,,,,,, -481300,3.1172175,1.169729,,,,,,,,,,,,,, -481400,3.3962197,2.5307457,,,,,,,,,,,,,, -481500,3.2254531,1.7675635,,,,,,,,,,,,,, -481600,2.990264,1.2061836,,,,,,,,,,,,,, -481700,2.943406,1.9723082,,,,,,,,,,,,,, -481774,,,0.8877929449081421,0.4164779186248779,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,215558.27237558365,233353.13364720345,215558.27237558365,17734.160432100296,35.47701930999756,0.0 -481800,3.1073995,2.5813622,,,,,,,,,,,,,, -481900,3.25663,1.5086327,,,,,,,,,,,,,, -482000,3.1777897,1.6749424,,,,,,,,,,,,,, -482100,3.9768562,3.156785,,,,,,,,,,,,,, -482200,3.5198486,3.0074015,,,,,,,,,,,,,, -482300,3.4363625,2.9995549,,,,,,,,,,,,,, -482400,3.1097426,1.0586742,,,,,,,,,,,,,, -482500,3.0142772,1.0471146,,,,,,,,,,,,,, -482600,3.4581323,1.6125084,,,,,,,,,,,,,, -482700,2.9582074,1.0475178,,,,,,,,,,,,,, -482712,,,0.8885351419448853,0.4163164794445038,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,215978.2754702568,233809.1701474189,215978.2754702568,17770.049526929855,35.57323598861694,0.0 -482800,3.7419426,3.0446496,,,,,,,,,,,,,, -482900,3.3174984,1.110184,,,,,,,,,,,,,, -483000,2.9461992,1.4182322,,,,,,,,,,,,,, -483100,2.9137185,1.1148921,,,,,,,,,,,,,, -483200,3.0714674,1.0782813,,,,,,,,,,,,,, -483300,3.179245,2.215456,,,,,,,,,,,,,, -483400,3.5278952,1.2505968,,,,,,,,,,,,,, -483500,3.6828995,2.973159,,,,,,,,,,,,,, -483600,3.020584,1.1856987,,,,,,,,,,,,,, -483652,,,0.8861523270606995,0.4175904095172882,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,216398.384973526,234262.88517141345,216398.384973526,17803.500860214233,35.67725419998169,0.0 -483700,3.17436,1.0455174,,,,,,,,,,,,,, -483800,3.2876282,1.3580381,,,,,,,,,,,,,, -483900,3.0334,1.993542,,,,,,,,,,,,,, -484000,3.1679716,1.1450075,,,,,,,,,,,,,, -484100,3.4270759,1.2891,,,,,,,,,,,,,, -484200,3.1641755,1.2090721,,,,,,,,,,,,,, -484300,3.0332668,1.2229271,,,,,,,,,,,,,, -484400,3.8102353,3.1734958,,,,,,,,,,,,,, -484500,3.145572,1.1142325,,,,,,,,,,,,,, -484589,,,0.8881054520606995,0.4143379628658294,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,216818.5877084732,234718.37487339973,216818.5877084732,17838.617998600006,35.79676175117493,0.0 -484600,3.272938,1.235961,,,,,,,,,,,,,, -484700,3.0776742,1.2183326,,,,,,,,,,,,,, -484800,3.1346226,1.0990983,,,,,,,,,,,,,, -484900,3.3255367,3.0973625,,,,,,,,,,,,,, -485000,3.1250396,1.8036507,,,,,,,,,,,,,, -485100,3.651629,3.120752,,,,,,,,,,,,,, -485200,3.3014605,1.6228688,,,,,,,,,,,,,, -485300,2.9874597,1.9120595,,,,,,,,,,,,,, -485400,3.815859,2.6779044,,,,,,,,,,,,,, -485500,3.1009262,2.4106007,,,,,,,,,,,,,, -485529,,,0.8881640434265137,0.4148095548152923,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,217238.5685465336,235174.451969862,217238.5685465336,17874.568171977997,35.89333629608154,0.0 -485600,3.1732643,1.3300016,,,,,,,,,,,,,, -485700,4.5352697,3.1870933,,,,,,,,,,,,,, -485800,3.1728652,1.2302611,,,,,,,,,,,,,, -485900,3.3525712,2.779088,,,,,,,,,,,,,, -486000,3.2789118,1.1604154,,,,,,,,,,,,,, -486100,3.5596774,1.1377873,,,,,,,,,,,,,, -486200,3.1234586,1.3279328,,,,,,,,,,,,,, -486300,3.2225513,2.8286757,,,,,,,,,,,,,, -486400,3.1032984,2.599469,,,,,,,,,,,,,, -486470,,,0.88818359375,0.4164724946022033,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,217658.7013375759,235628.41182494164,217658.7013375759,17908.249662160873,35.98907494544983,0.0 -486500,3.4391844,1.0338625,,,,,,,,,,,,,, -486600,3.5682747,1.1703849,,,,,,,,,,,,,, -486700,3.1897821,2.476587,,,,,,,,,,,,,, -486800,3.0606825,1.0712076,,,,,,,,,,,,,, -486900,3.2328763,1.4240432,,,,,,,,,,,,,, -487000,3.259745,1.1041957,,,,,,,,,,,,,, -487100,3.3406312,1.1284817,,,,,,,,,,,,,, -487200,4.055943,1.0771931,,,,,,,,,,,,,, -487300,3.7255301,3.1472654,,,,,,,,,,,,,, -487400,3.065715,1.0839566,,,,,,,,,,,,,, -487409,,,0.8879492282867432,0.4153375029563904,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,218078.80738782883,236081.30646657944,218078.80738782883,17940.89183330536,36.08620810508728,0.0 -487500,3.2446606,1.4913183,,,,,,,,,,,,,, -487600,2.7982535,1.6222119,,,,,,,,,,,,,, -487700,3.262383,1.1826766,,,,,,,,,,,,,, -487800,2.943945,1.0948993,,,,,,,,,,,,,, -487900,3.3824475,1.1324401,,,,,,,,,,,,,, -488000,3.2945483,1.1318034,,,,,,,,,,,,,, -488100,3.189483,2.1978211,,,,,,,,,,,,,, -488200,3.3133705,2.407711,,,,,,,,,,,,,, -488300,3.2012117,2.78668,,,,,,,,,,,,,, -488346,,,0.8887695074081421,0.4143279790878296,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,218498.74874806404,236537.2109837532,218498.74874806404,17976.68732571602,36.20460081100464,0.0 -488400,3.5945313,2.9105783,,,,,,,,,,,,,, -488500,3.0610745,2.5601373,,,,,,,,,,,,,, -488600,3.145928,1.1857123,,,,,,,,,,,,,, -488700,3.0803413,1.1725402,,,,,,,,,,,,,, -488800,3.6715026,3.0897934,,,,,,,,,,,,,, -488900,3.0192726,1.7541099,,,,,,,,,,,,,, -489000,3.0143714,2.3825243,,,,,,,,,,,,,, -489100,3.0573485,1.827178,,,,,,,,,,,,,, -489200,3.114233,1.1779196,,,,,,,,,,,,,, -489283,,,0.8891015648841858,0.4146701693534851,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,218918.67112493515,236993.0835986137,218918.67112493515,18012.489437818527,36.3038694858551,0.0 -489300,2.9819283,1.0466142,,,,,,,,,,,,,, -489400,2.862166,1.1771584,,,,,,,,,,,,,, -489500,2.889006,1.1590875,,,,,,,,,,,,,, -489600,3.2695775,1.1926947,,,,,,,,,,,,,, -489700,3.3021228,2.8138463,,,,,,,,,,,,,, -489800,3.0279655,2.342722,,,,,,,,,,,,,, -489900,3.3443396,1.1089469,,,,,,,,,,,,,, -490000,4.0773997,3.1626232,,,,,,,,,,,,,, -490100,3.3105059,2.1959205,,,,,,,,,,,,,, -490200,3.6210625,1.0957509,,,,,,,,,,,,,, -490226,,,0.8903515338897705,0.4075874984264374,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,219338.8282425404,237448.3058791161,219338.8282425404,18047.40689873696,36.40238857269287,0.0 -490300,3.4783967,3.0277698,,,,,,,,,,,,,, -490400,3.0038016,1.0816152,,,,,,,,,,,,,, -490500,3.1082785,2.1366723,,,,,,,,,,,,,, -490600,3.4522092,1.1801934,,,,,,,,,,,,,, -490700,3.1703336,1.3210839,,,,,,,,,,,,,, -490800,3.1330297,1.5739453,,,,,,,,,,,,,, -490900,3.387941,1.2879555,,,,,,,,,,,,,, -491000,3.1220016,1.1351753,,,,,,,,,,,,,, -491100,3.3901548,1.1555037,,,,,,,,,,,,,, -491164,,,0.8893554210662842,0.4099286496639251,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,219758.6913704872,237903.41774082184,219758.6913704872,18082.4962182045,36.51316452026367,0.0 -491200,3.0733645,1.0975361,,,,,,,,,,,,,, -491300,2.920665,1.1297066,,,,,,,,,,,,,, -491400,3.0047178,1.0950658,,,,,,,,,,,,,, -491500,2.9735398,1.0685494,,,,,,,,,,,,,, -491600,3.0486674,1.1705217,,,,,,,,,,,,,, -491700,3.3460512,1.1770715,,,,,,,,,,,,,, -491800,2.9893656,2.0309396,,,,,,,,,,,,,, -491900,3.2908802,1.1442978,,,,,,,,,,,,,, -492000,3.6146934,3.1254542,,,,,,,,,,,,,, -492100,3.2263412,2.1120417,,,,,,,,,,,,,, -492103,,,0.8878515362739563,0.4151861369609833,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,220178.8872082233,238358.4329917431,220178.8872082233,18117.168116807938,36.611567735672,0.0 -492200,3.1542902,1.0426172,,,,,,,,,,,,,, -492300,3.9606833,3.1432023,,,,,,,,,,,,,, -492400,3.2661147,1.1633545,,,,,,,,,,,,,, -492500,3.156433,1.2115253,,,,,,,,,,,,,, -492600,3.1429157,1.1731942,,,,,,,,,,,,,, -492700,3.0259576,1.2843996,,,,,,,,,,,,,, -492800,2.7930818,2.381007,,,,,,,,,,,,,, -492900,2.9306586,1.0342396,,,,,,,,,,,,,, -493000,3.2604694,2.6869779,,,,,,,,,,,,,, -493045,,,0.8883984088897705,0.414981484413147,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,220598.7592358589,238813.04859733584,220598.7592358589,18151.76734852791,36.70686912536621,0.0 -493100,3.6738603,3.126171,,,,,,,,,,,,,, -493200,3.23011,1.1184936,,,,,,,,,,,,,, -493300,3.2876315,1.0642955,,,,,,,,,,,,,, -493400,2.8608222,1.520389,,,,,,,,,,,,,, -493500,3.019796,1.2725526,,,,,,,,,,,,,, -493600,3.0811276,1.1725498,,,,,,,,,,,,,, -493700,3.1056669,1.127338,,,,,,,,,,,,,, -493800,3.131665,2.434823,,,,,,,,,,,,,, -493900,3.2109878,1.3508245,,,,,,,,,,,,,, -493983,,,0.8869531154632568,0.4125227630138397,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,221018.6685461998,239267.2332293988,221018.6685461998,18185.89629340172,36.804692029953,0.0 -494000,3.1240284,1.133265,,,,,,,,,,,,,, -494100,2.952882,1.2334812,,,,,,,,,,,,,, -494200,3.0073106,1.0773367,,,,,,,,,,,,,, -494300,2.7949657,1.863879,,,,,,,,,,,,,, -494400,3.3287134,1.1936572,,,,,,,,,,,,,, -494500,3.8929112,3.170376,,,,,,,,,,,,,, -494600,3.1069274,1.1737403,,,,,,,,,,,,,, -494700,3.1363797,1.1233898,,,,,,,,,,,,,, -494800,3.3726454,1.1790409,,,,,,,,,,,,,, -494900,3.6444845,2.8756905,,,,,,,,,,,,,, -494921,,,0.8883007764816284,0.4121655523777008,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,221438.6368880272,239722.64675951004,221438.6368880272,18221.183396816254,36.91400504112244,0.0 -495000,2.9365077,1.7218229,,,,,,,,,,,,,, -495100,3.8056958,3.0520802,,,,,,,,,,,,,, -495200,3.397078,1.2494465,,,,,,,,,,,,,, -495300,2.9142551,1.3376594,,,,,,,,,,,,,, -495400,3.0740833,1.062745,,,,,,,,,,,,,, -495500,2.7830093,1.1714339,,,,,,,,,,,,,, -495600,3.0497901,1.3540763,,,,,,,,,,,,,, -495700,3.018107,1.5632122,,,,,,,,,,,,,, -495800,3.063352,2.257268,,,,,,,,,,,,,, -495861,,,0.88832026720047,0.4151731729507446,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,221858.705681324,240179.26391124725,221858.705681324,18257.582832574844,37.0140278339386,0.0 -495900,2.864772,1.6385255,,,,,,,,,,,,,, -496000,3.1737726,1.1404076,,,,,,,,,,,,,, -496100,2.72094,1.0631895,,,,,,,,,,,,,, -496200,3.280037,1.0968996,,,,,,,,,,,,,, -496300,3.08161,2.4208,,,,,,,,,,,,,, -496400,3.222933,2.4406757,,,,,,,,,,,,,, -496500,3.083146,2.0629392,,,,,,,,,,,,,, -496600,3.0486455,1.4580607,,,,,,,,,,,,,, -496700,3.2841082,1.2164435,,,,,,,,,,,,,, -496799,,,0.88783198595047,0.4155504107475281,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,222278.61622595787,240634.42562317848,222278.61622595787,18292.674980402,37.12426042556763,0.0 -496800,2.9252694,2.2650073,,,,,,,,,,,,,, -496900,3.2679653,1.2067599,,,,,,,,,,,,,, -497000,3.205715,1.097614,,,,,,,,,,,,,, -497100,2.9547307,1.7544068,,,,,,,,,,,,,, -497200,2.9445322,1.0927284,,,,,,,,,,,,,, -497300,3.068509,1.168167,,,,,,,,,,,,,, -497400,2.828343,1.3246585,,,,,,,,,,,,,, -497500,3.40563,1.2669909,,,,,,,,,,,,,, -497600,3.2927585,1.1761433,,,,,,,,,,,,,, -497700,3.011097,2.4641683,,,,,,,,,,,,,, -497738,,,0.8870507478713989,0.4186049699783325,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,222698.61618041992,241090.9749581813,222698.61618041992,18329.07387948036,37.22601556777954,0.0 -497800,3.089314,2.715487,,,,,,,,,,,,,, -497900,2.9989994,1.5772276,,,,,,,,,,,,,, -498000,2.9849005,1.472721,,,,,,,,,,,,,, -498100,3.0469246,1.1564325,,,,,,,,,,,,,, -498200,3.1801302,2.3466945,,,,,,,,,,,,,, -498300,3.7447207,2.9448571,,,,,,,,,,,,,, -498400,2.9051635,1.7951885,,,,,,,,,,,,,, -498500,2.985346,1.1029218,,,,,,,,,,,,,, -498600,2.987334,1.5724639,,,,,,,,,,,,,, -498677,,,0.8886913657188416,0.412122905254364,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,223118.7795138359,241546.0461564064,223118.7795138359,18363.83420419693,37.32507681846619,0.0 -498700,4.05997,3.1540513,,,,,,,,,,,,,, -498800,3.4789224,3.174367,,,,,,,,,,,,,, -498900,3.018346,1.3188531,,,,,,,,,,,,,, -499000,3.8710167,3.2806993,,,,,,,,,,,,,, -499100,2.893883,1.3308697,,,,,,,,,,,,,, -499200,2.9990544,2.3193178,,,,,,,,,,,,,, -499300,3.4157927,1.6739793,,,,,,,,,,,,,, -499400,3.0740163,1.5130079,,,,,,,,,,,,,, -499500,3.2403235,1.1050189,,,,,,,,,,,,,, -499600,3.0164192,1.1019258,,,,,,,,,,,,,, -499615,,,0.88734370470047,0.4127151072025299,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,223538.7945408821,242001.99430322647,223538.7945408821,18399.62007832527,37.42348909378052,0.0 -499700,4.089073,3.1126313,,,,,,,,,,,,,, -499800,3.1107812,2.2485952,,,,,,,,,,,,,, -499900,3.6874292,3.0309045,,,,,,,,,,,,,, -500000,3.0376375,1.168854,,,,,,,,,,,,,, -500100,3.2120326,1.212505,,,,,,,,,,,,,, -500200,2.8394275,1.822195,,,,,,,,,,,,,, -500300,3.2304304,2.0435069,,,,,,,,,,,,,, -500400,3.3152876,1.1031853,,,,,,,,,,,,,, -500500,3.2169104,1.4304588,,,,,,,,,,,,,, -500556,,,0.8880664110183716,0.4174224436283111,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,223958.84918499,242457.4077973365,223958.84918499,18434.819314956665,37.53372430801392,0.0 -500600,3.4090056,2.572885,,,,,,,,,,,,,, -500700,3.063851,1.5681193,,,,,,,,,,,,,, -500800,2.9709709,1.1950549,,,,,,,,,,,,,, -500900,3.1493382,1.200154,,,,,,,,,,,,,, -501000,3.3900723,1.1747106,,,,,,,,,,,,,, -501100,3.1419723,1.1157876,,,,,,,,,,,,,, -501200,3.4130955,1.1243904,,,,,,,,,,,,,, -501300,3.3434594,1.0560751,,,,,,,,,,,,,, -501400,3.055283,1.7374905,,,,,,,,,,,,,, -501496,,,0.8877148032188416,0.4193105101585388,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,224378.8450908661,242912.9882376194,224378.8450908661,18470.24125123024,37.64730954170227,0.0 -501500,3.124561,1.2741392,,,,,,,,,,,,,, -501600,2.977881,1.3235056,,,,,,,,,,,,,, -501700,3.1596458,2.572167,,,,,,,,,,,,,, -501800,3.4825804,3.123198,,,,,,,,,,,,,, -501900,3.2263312,1.1582664,,,,,,,,,,,,,, -502000,2.7158217,1.0729135,,,,,,,,,,,,,, -502100,3.4383051,2.5982714,,,,,,,,,,,,,, -502200,3.2417896,1.1738974,,,,,,,,,,,,,, -502300,2.9612753,2.2551758,,,,,,,,,,,,,, -502400,3.1925154,2.3090239,,,,,,,,,,,,,, -502436,,,0.8877733945846558,0.4132482409477234,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,224798.8704931736,243368.6910982132,224798.8704931736,18505.77114391327,37.74493646621704,0.0 -502500,2.9298625,2.380398,,,,,,,,,,,,,, -502600,2.8201435,1.0192944,,,,,,,,,,,,,, -502700,3.188744,2.4851778,,,,,,,,,,,,,, -502800,3.1537712,1.0643866,,,,,,,,,,,,,, -502900,3.114762,2.6613188,,,,,,,,,,,,,, -503000,4.145979,3.3087606,,,,,,,,,,,,,, -503100,3.1253965,1.6271466,,,,,,,,,,,,,, -503200,3.722843,2.9846227,,,,,,,,,,,,,, -503300,3.126429,2.64317,,,,,,,,,,,,,, -503376,,,0.8895312547683716,0.4103905856609344,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,225219.12637400627,243823.23779582977,225219.12637400627,18539.91388988495,37.84392428398132,0.0 -503400,3.216055,1.2020084,,,,,,,,,,,,,, -503500,3.0103676,1.8034898,,,,,,,,,,,,,, -503600,3.8177388,3.129555,,,,,,,,,,,,,, -503700,3.0513906,1.1169852,,,,,,,,,,,,,, -503800,3.2338574,1.1558878,,,,,,,,,,,,,, -503900,3.0009089,1.1064012,,,,,,,,,,,,,, -504000,2.982872,2.5138483,,,,,,,,,,,,,, -504100,3.1292698,2.8602617,,,,,,,,,,,,,, -504200,2.7819,1.7035695,,,,,,,,,,,,,, -504300,3.0986493,1.081144,,,,,,,,,,,,,, -504311,,,0.8867968320846558,0.422403335571289,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,225639.3281033039,244279.811917305,225639.3281033039,18576.1207678318,37.96061396598816,0.0 -504400,3.3758757,2.8121402,,,,,,,,,,,,,, -504500,3.244151,1.0214736,,,,,,,,,,,,,, -504600,3.2677975,1.2540641,,,,,,,,,,,,,, -504700,2.9671507,1.330311,,,,,,,,,,,,,, -504800,3.5367658,1.0637532,,,,,,,,,,,,,, -504900,2.7441611,1.6270227,,,,,,,,,,,,,, -505000,3.2531395,1.0733436,,,,,,,,,,,,,, -505100,2.991619,1.1506495,,,,,,,,,,,,,, -505200,3.097743,0.9831579,,,,,,,,,,,,,, -505248,,,0.888671875,0.4171231091022491,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,226059.1844224929,244736.27666950223,226059.1844224929,18612.577392101288,38.06256675720215,0.0 -505300,2.9462137,1.4514785,,,,,,,,,,,,,, -505400,3.0894232,2.5848215,,,,,,,,,,,,,, -505500,3.2029724,1.0842828,,,,,,,,,,,,,, -505600,3.0897505,1.0911078,,,,,,,,,,,,,, -505700,3.4275224,1.2080775,,,,,,,,,,,,,, -505800,3.136175,1.0019274,,,,,,,,,,,,,, -505900,3.224943,1.4199709,,,,,,,,,,,,,, -506000,3.7165225,3.2409441,,,,,,,,,,,,,, -506100,3.0926402,1.132431,,,,,,,,,,,,,, -506186,,,0.88832026720047,0.4186186492443084,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,226479.16041707995,245192.26065707207,226479.16041707995,18648.43890714645,38.15953755378723,0.0 -506200,3.4178069,3.0927262,,,,,,,,,,,,,, -506300,2.9979718,2.024288,,,,,,,,,,,,,, -506400,3.2153094,1.7800457,,,,,,,,,,,,,, -506500,3.1499302,1.300925,,,,,,,,,,,,,, -506600,3.1504986,1.124881,,,,,,,,,,,,,, -506700,3.1662252,1.4675486,,,,,,,,,,,,,, -506800,4.615971,3.2372613,,,,,,,,,,,,,, -506900,3.0450327,1.0832392,,,,,,,,,,,,,, -507000,2.9591682,1.3507004,,,,,,,,,,,,,, -507100,3.1546009,1.0575435,,,,,,,,,,,,,, -507127,,,0.8861913681030273,0.4204094707965851,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,226899.0175216198,245647.18233394623,226899.0175216198,18683.35116648674,38.262371301651,0.0 -507200,3.4160838,1.1732006,,,,,,,,,,,,,, -507300,2.9359782,1.0873655,,,,,,,,,,,,,, -507400,3.37491,3.0093005,,,,,,,,,,,,,, -507500,3.3106558,2.7444322,,,,,,,,,,,,,, -507600,4.264574,3.1220262,,,,,,,,,,,,,, -507700,3.246433,1.3740132,,,,,,,,,,,,,, -507800,3.2101395,2.2028296,,,,,,,,,,,,,, -507900,3.2035718,1.6844971,,,,,,,,,,,,,, -508000,3.0173445,1.0122741,,,,,,,,,,,,,, -508065,,,0.8902148008346558,0.4068226814270019,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,227319.15297055244,246103.9114441872,227319.15297055244,18719.78503012657,38.373663663864136,0.0 -508100,3.7042987,3.2532253,,,,,,,,,,,,,, -508200,3.4464006,1.1515476,,,,,,,,,,,,,, -508300,3.0961456,1.1007546,,,,,,,,,,,,,, -508400,3.0216832,1.1764076,,,,,,,,,,,,,, -508500,2.7973576,2.0966437,,,,,,,,,,,,,, -508600,3.2375968,1.1135445,,,,,,,,,,,,,, -508700,3.1019442,1.1852536,,,,,,,,,,,,,, -508800,3.846597,3.2405403,,,,,,,,,,,,,, -508900,3.7299478,3.156058,,,,,,,,,,,,,, -509000,3.2093894,1.026342,,,,,,,,,,,,,, -509005,,,0.8861913681030273,0.420152872800827,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,227739.37865519524,246558.1716852188,227739.37865519524,18753.6684820652,38.47547960281372,0.0 -509100,3.0125856,1.1045763,,,,,,,,,,,,,, -509200,2.9973152,1.077601,,,,,,,,,,,,,, -509300,3.135419,2.1674592,,,,,,,,,,,,,, -509400,3.1927016,1.1427629,,,,,,,,,,,,,, -509500,2.9776268,1.226979,,,,,,,,,,,,,, -509600,2.9968278,1.6222043,,,,,,,,,,,,,, -509700,3.26841,2.4761765,,,,,,,,,,,,,, -509800,2.9759216,1.0354958,,,,,,,,,,,,,, -509900,3.0255835,1.0416229,,,,,,,,,,,,,, -509943,,,0.8894726634025574,0.4081210494041443,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,228159.2598702908,247014.39329242703,228159.2598702908,18789.85410284996,38.58154392242432,0.0 -510000,2.9604654,1.3108855,,,,,,,,,,,,,, -510100,3.1705124,1.1130742,,,,,,,,,,,,,, -510200,3.1718903,1.0678746,,,,,,,,,,,,,, -510300,3.0052452,1.0917691,,,,,,,,,,,,,, -510400,4.003193,3.22469,,,,,,,,,,,,,, -510500,3.5349352,3.2279637,,,,,,,,,,,,,, -510600,2.9022303,2.1924057,,,,,,,,,,,,,, -510700,3.203633,1.2053863,,,,,,,,,,,,,, -510800,2.9738379,2.380004,,,,,,,,,,,,,, -510880,,,0.8873046636581421,0.4174008965492248,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,228579.4013376236,247469.3725888729,228579.4013376236,18824.52872061729,38.695571184158325,0.0 -510900,3.1575074,1.4755708,,,,,,,,,,,,,, -511000,3.1359355,1.1360025,,,,,,,,,,,,,, -511100,3.051908,1.0531027,,,,,,,,,,,,,, -511200,3.3417382,1.1425362,,,,,,,,,,,,,, -511300,2.830315,1.5662308,,,,,,,,,,,,,, -511400,2.8371818,1.5874994,,,,,,,,,,,,,, -511500,3.0364842,1.2497079,,,,,,,,,,,,,, -511600,3.0412476,1.1335417,,,,,,,,,,,,,, -511700,2.949092,2.4001057,,,,,,,,,,,,,, -511800,3.308689,1.112837,,,,,,,,,,,,,, -511817,,,0.8889257907867432,0.4128748476505279,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,228999.4161813259,247925.6817638874,228999.4161813259,18860.66403913498,38.80521607398987,0.0 -511900,3.2944722,1.1516693,,,,,,,,,,,,,, -512000,3.032169,1.08287,,,,,,,,,,,,,, -512100,3.1150296,1.3397584,,,,,,,,,,,,,, -512200,3.2918687,2.506931,,,,,,,,,,,,,, -512300,3.119537,1.1290325,,,,,,,,,,,,,, -512400,2.8947263,2.160546,,,,,,,,,,,,,, -512500,3.2347934,2.690257,,,,,,,,,,,,,, -512600,3.2343874,2.2959168,,,,,,,,,,,,,, -512700,3.6425476,3.0260296,,,,,,,,,,,,,, -512755,,,0.8876953125,0.4195302724838257,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,229419.3546979428,248379.9687492848,229419.3546979428,18894.84908437729,38.91966819763184,0.0 -512800,4.1275353,3.2659378,,,,,,,,,,,,,, -512900,3.0786057,1.066354,,,,,,,,,,,,,, -513000,2.8880455,1.7310596,,,,,,,,,,,,,, -513100,3.1299677,1.0991668,,,,,,,,,,,,,, -513200,3.049374,1.038068,,,,,,,,,,,,,, -513300,3.2052977,1.6543437,,,,,,,,,,,,,, -513400,3.322483,2.427588,,,,,,,,,,,,,, -513500,3.868864,3.230836,,,,,,,,,,,,,, -513600,3.1194487,1.8268739,,,,,,,,,,,,,, -513690,,,0.8895312547683716,0.4111018180847168,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,229839.27992606163,248833.9214372635,229839.27992606163,18928.725402593613,39.0208375453949,0.0 -513700,3.0727994,2.0615544,,,,,,,,,,,,,, -513800,3.4232805,1.2317607,,,,,,,,,,,,,, -513900,3.4984462,1.0677203,,,,,,,,,,,,,, -514000,3.286153,1.027953,,,,,,,,,,,,,, -514100,3.986782,3.342507,,,,,,,,,,,,,, -514200,2.7854965,1.8539296,,,,,,,,,,,,,, -514300,3.0437846,2.2783294,,,,,,,,,,,,,, -514400,3.2909014,1.1477084,,,,,,,,,,,,,, -514500,3.2133327,1.072677,,,,,,,,,,,,,, -514600,3.325921,1.2482126,,,,,,,,,,,,,, -514623,,,0.8900585770606995,0.4098351299762726,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,230259.215269804,249290.68759441376,230259.215269804,18965.40698552132,39.12135338783264,0.0 -514700,2.9438896,1.2375809,,,,,,,,,,,,,, -514800,3.1565526,2.6320941,,,,,,,,,,,,,, -514900,3.1024609,1.9445238,,,,,,,,,,,,,, -515000,3.0563414,1.1071733,,,,,,,,,,,,,, -515100,2.9969668,1.0156541,,,,,,,,,,,,,, -515200,3.3764338,1.0616701,,,,,,,,,,,,,, -515300,3.9424338,2.727675,,,,,,,,,,,,,, -515400,3.122656,1.3292391,,,,,,,,,,,,,, -515500,2.873554,1.1960992,,,,,,,,,,,,,, -515560,,,0.8890234231948853,0.4100822210311889,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,230679.0721449852,249747.4019293785,230679.0721449852,19002.11077594757,39.22586536407471,0.0 -515600,3.21641,1.179892,,,,,,,,,,,,,, -515700,2.8878975,1.9678884,,,,,,,,,,,,,, -515800,2.9841785,1.1191148,,,,,,,,,,,,,, -515900,3.1354227,1.1924036,,,,,,,,,,,,,, -516000,3.2741609,1.1547965,,,,,,,,,,,,,, -516100,3.484911,1.1627461,,,,,,,,,,,,,, -516200,3.1655633,2.7842605,,,,,,,,,,,,,, -516300,3.3438237,1.2958803,,,,,,,,,,,,,, -516400,2.9650583,1.1556906,,,,,,,,,,,,,, -516498,,,0.8893945217132568,0.4079234898090362,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,231099.2383217812,250203.4171028137,231099.2383217812,19037.79772377014,39.339415073394775,0.0 -516500,3.1063943,1.0843841,,,,,,,,,,,,,, -516600,3.2021933,1.1605628,,,,,,,,,,,,,, -516700,3.172667,1.8034432,,,,,,,,,,,,,, -516800,3.2909493,1.127976,,,,,,,,,,,,,, -516900,2.9809866,1.4239405,,,,,,,,,,,,,, -517000,2.9261487,2.1243353,,,,,,,,,,,,,, -517100,3.0051172,2.1916268,,,,,,,,,,,,,, -517200,3.3867269,1.8763216,,,,,,,,,,,,,, -517300,3.1647456,1.6356493,,,,,,,,,,,,,, -517400,2.7103097,1.0005862,,,,,,,,,,,,,, -517437,,,0.8845898509025574,0.4218060672283172,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,231519.41894960403,250657.88970041275,231519.41894960403,19071.938607931137,39.44188857078552,0.0 -517500,3.460252,1.1420331,,,,,,,,,,,,,, -517600,2.9702477,1.3403132,,,,,,,,,,,,,, -517700,3.6056378,2.5545878,,,,,,,,,,,,,, -517800,3.3177333,2.4264524,,,,,,,,,,,,,, -517900,2.8485231,2.042141,,,,,,,,,,,,,, -518000,3.3384213,2.6559255,,,,,,,,,,,,,, -518100,3.5420609,3.0879588,,,,,,,,,,,,,, -518200,3.3084466,1.1346059,,,,,,,,,,,,,, -518300,3.099668,1.1258304,,,,,,,,,,,,,, -518375,,,0.8900195360183716,0.4088830351829529,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,231939.48673439023,251114.31041026115,231939.48673439023,19108.14038753509,39.543895959854126,0.0 -518400,3.4900851,1.2944206,,,,,,,,,,,,,, -518500,3.2114692,1.2311844,,,,,,,,,,,,,, -518600,4.036222,3.2300816,,,,,,,,,,,,,, -518700,3.05054,1.0457454,,,,,,,,,,,,,, -518800,3.3093545,1.3498678,,,,,,,,,,,,,, -518900,3.0037744,1.0537351,,,,,,,,,,,,,, -519000,3.4042459,2.189269,,,,,,,,,,,,,, -519100,3.1713603,2.2032623,,,,,,,,,,,,,, -519200,3.0585854,1.8408827,,,,,,,,,,,,,, -519300,3.1092625,1.0809835,,,,,,,,,,,,,, -519311,,,0.8872656226158142,0.4180852770805359,0.7828800082206726,0.8518542051315308,50000.0,0.6663000583648682,1.4581419229507446,10000.0,232360.0536572933,251570.7573211193,232360.0536572933,19143.85590171814,39.65925049781799,0.0 -519400,3.1747296,1.8287463,,,,,,,,,,,,,, -519500,3.3089428,1.8263704,,,,,,,,,,,,,, -519600,3.3654115,1.3709916,,,,,,,,,,,,,, -519700,3.586743,2.496328,,,,,,,,,,,,,, -519764,,,,,,,,,,,232560.43942379951,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 9d5d18299..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,42 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -185.7470064163208,0.0,58.690967082977295,1,0,58.690967082977295,30.044025,2472,1.2556009180833994,244.4380226135254,30.525188,1.632826978739354,30.029667,5348,1.2028346061384283 -312.36979603767395,0.0443918704986572,1499.073543548584,1731,0,1499.073543548584,3.1974833,2472,0.5987650559584019,1811.5598888397217,3.0738208,0.6013315894626269,3.4970143,5348,0.6448632418394045 -445.6631715297699,0.1009564399719238,2939.144833803177,3490,0,2939.144833803177,0.66895556,2472,0.2187760242114029,3385.059635400772,0.63175845,0.2134299336760514,0.95706713,5348,0.2770499242109734 -578.6111640930176,0.1544415950775146,4379.723701953888,5200,0,4379.723701953888,0.50083953,2472,0.165031584506327,4958.717666625977,0.43173733,0.1540711924980628,0.76068735,5348,0.2270098574007743 -713.0965855121613,0.2123136520385742,5820.102232933044,6925,0,5820.102232933044,0.43859884,2472,0.1470558365324071,6533.717638731003,0.3827844,0.1367330268407879,0.67956024,5348,0.2044179692402753 -846.3146262168884,0.2651317119598388,7260.559650659561,8656,0,7260.559650659561,0.3972195,2472,0.1341376718867426,8107.523503780365,0.34197354,0.1229206984278849,0.6259623,5348,0.1901097734052926 -977.789181470871,0.3199388980865478,8700.735800027847,10352,0,8700.735800027847,0.376258,2472,0.1259114821359657,9679.307429790497,0.33752048,0.1192898410784685,0.5941883,5348,0.1806482134064512 -1112.065169095993,0.3683531284332275,10141.14683175087,12066,0,10141.14683175087,0.3550018,2472,0.1176852923851888,11254.12031364441,0.29464692,0.107576627741263,0.5686792,5348,0.1717755872442724 -1244.5346031188965,0.418978214263916,11581.195252656937,13793,0,11581.195252656937,0.34235185,2472,0.1144963743830357,12826.767081737518,0.26047683,0.0976832470943082,0.5568923,5348,0.1668806781428309 -1377.50799202919,0.4708180427551269,13021.82521033287,15482,0,13021.82521033287,0.32886645,2472,0.1120792964068815,14400.499089717863,0.24350418,0.0904776851502109,0.54537946,5348,0.1634918949187561 -1508.9135534763336,0.5240156650543213,14461.918625116348,17201,0,14461.918625116348,0.3163899,2472,0.1059248877785225,15972.129461288452,0.24185376,0.0926437385703636,0.5224277,5348,0.1579597787153518 -1643.1955354213717,0.5756280422210693,15902.305790424349,18914,0,15902.305790424349,0.3022412,2472,0.1006641886539516,17546.927089452744,0.24827659,0.0911023559586451,0.49703676,5348,0.151143593654962 -1775.2081396579742,0.6244874000549316,17345.91077399254,20601,0,17345.91077399254,0.29747027,2472,0.0999532833668474,19122.66784954071,0.24300954,0.0892447305247052,0.49576217,5348,0.1492512816551937 -1909.0888845920565,0.6796655654907227,18785.811478853226,22318,0,18785.811478853226,0.28847873,2472,0.0986533422704283,20696.58254313469,0.22511902,0.084525156909147,0.49216977,5348,0.1476003359819264 -2045.6049239635468,0.736748456954956,20225.753566741943,24059,0,20225.753566741943,0.2866742,2472,0.0960534600775902,22273.176414966583,0.20986903,0.0783539546471855,0.4833996,5348,0.1443563725537522 -2180.148374080658,0.7892570495605469,21666.226598978043,25752,0,21666.226598978043,0.27209896,2472,0.0911583693863871,23848.32180738449,0.20069353,0.0751923628680977,0.46672383,5348,0.1400697065950935 -2314.819569826126,0.847764253616333,23106.16258573532,27476,0,23106.16258573532,0.26997027,2472,0.0923161294253854,25423.06540942192,0.19208643,0.0705781878062002,0.46214864,5348,0.1379649922280043 -2447.661313056946,0.9036474227905272,24546.349791288376,29202,0,24546.349791288376,0.26454708,2472,0.0883756829768651,26996.2281870842,0.20945032,0.0773151759403709,0.45469,5348,0.1361499174527163 -2583.189457654953,0.949357271194458,25987.04529428482,30887,0,25987.04529428482,0.2595338,2472,0.0865882639692889,28572.57029390335,0.19461107,0.0707166072523453,0.4496563,5348,0.1345375903916892 -2719.184014081955,1.008662462234497,27427.04628181457,32580,0,27427.04628181457,0.25433442,2472,0.085999228159974,30148.702434778214,0.19953363,0.0724306386738515,0.4363747,5348,0.1306178012493121 -2854.495341539383,1.0615310668945312,28867.922538518906,34301,0,28867.922538518906,0.2501841,2472,0.0845774175857656,31725.02043557167,0.19232655,0.0685497191939213,0.4371391,5348,0.1274800390047983 -2989.1104102134705,1.118565559387207,30308.074434041977,35987,0,30308.074434041977,0.24394017,2472,0.0813884995836126,33299.92118215561,0.14790924,0.056189334071318,0.4280118,5348,0.1253753246377091 -3123.903793334961,1.1749954223632812,31748.535685777664,37685,0,31748.535685777664,0.24056098,2472,0.0818962890744013,34875.30942106247,0.16978231,0.0630979282380519,0.4183739,5348,0.1235023219440609 -3256.5838243961334,1.236691951751709,33188.79012131691,39389,0,33188.79012131691,0.2356621,2472,0.0799870005890358,36448.38415312767,0.20656164,0.0754706747015399,0.41282338,5348,0.1211272772912905 -3388.7605333328247,1.292815923690796,34628.73562026024,41079,0,34628.73562026024,0.22915868,2472,0.0767574594276196,38020.64000558853,0.20947912,0.0772079028739237,0.4047456,5348,0.1206059260260482 -3520.995194911957,1.3495397567749023,36068.85384774208,42780,0,36068.85384774208,0.22481863,2472,0.0739747730180976,39593.12744855881,0.24113484,0.0889962620867968,0.39844996,5348,0.1170433590468926 -3650.400237083435,1.4104692935943604,37509.35020899773,44481,0,37509.35020899773,0.22069937,2472,0.0736700993236244,41163.16731405258,0.20987618,0.0761205541392648,0.39491072,5348,0.1156337797001264 -3782.523449420929,1.4723756313323977,38949.5586669445,46168,0,38949.5586669445,0.21279301,2472,0.0713545792456279,42735.63666701317,0.19044133,0.0708583389932415,0.38421032,5348,0.1118974289658901 -3917.786875724793,1.5329391956329346,40390.47105741501,47870,0,40390.47105741501,0.21163197,2472,0.0711311518696809,44311.951271533966,0.16163252,0.0609848937998238,0.3805211,5348,0.1111636753333269 -4053.8710753917694,1.5911128520965576,41830.47101521492,49570,0,41830.47101521492,0.20468043,2472,0.0681859728231064,45888.17083525658,0.17555173,0.0661228391273414,0.37177685,5348,0.108219006150014 -4188.673507928848,1.6475038528442385,43271.03357815743,51275,0,43271.03357815743,0.19672832,2472,0.0654845327321105,47463.67045736313,0.1526628,0.0584414190001717,0.36734933,5348,0.105718451007463 -4322.077683210373,1.7828662395477295,44711.0066754818,53003,0,44711.0066754818,0.1953285,2472,0.0639611642597444,49037.26174616814,0.1528462,0.0582670673363781,0.35978332,5348,0.1029379109261708 -4454.960786104202,1.8396832942962649,46151.02938437462,54694,0,46151.02938437462,0.19122066,2472,0.0628440273800093,50610.30272769928,0.14036037,0.0533664596273291,0.34712532,5348,0.0998098033347171 -4587.426983118057,1.8970699310302728,47591.55652046204,56386,0,47591.55652046204,0.18559092,2472,0.0613409704872748,52183.42995285988,0.14347468,0.0545997733594313,0.34344375,5348,0.0978305994574085 -4720.261053800583,1.95191502571106,49031.58073544502,58103,0,49031.58073544502,0.182006,2472,0.0589442040907521,53756.42091464996,0.14602813,0.0549668369291758,0.33937812,5348,0.0967878969269239 -4851.256810903549,2.0137839317321777,50471.67391204834,59760,0,50471.67391204834,0.17941004,2472,0.0587004651351735,55327.64785504341,0.11838623,0.0463940637119734,0.33321467,5348,0.0941425219884723 -4981.171031713486,2.0737240314483643,51912.141885757446,61455,0,51912.141885757446,0.17505817,2472,0.0565474376942294,56898.16695809364,0.11842466,0.0453082926384008,0.3269322,5348,0.0919798797030228 -5113.37780046463,2.138718366622925,53352.4820356369,63181,0,53352.4820356369,0.17149302,2472,0.0556537281904413,58470.85765528679,0.10824333,0.0414118644824353,0.32207102,5348,0.0907151201521573 -5242.881098985672,2.200098752975464,54793.28907322884,64865,0,54793.28907322884,0.16992109,2472,0.0562021408404931,60041.30572772026,0.10510961,0.0403559614432038,0.31513014,5348,0.0883111115402068 -5375.917708158493,2.264157295227051,56233.88724684715,66552,0,56233.88724684715,0.16586737,2472,0.0532163386346556,61615.08142876625,0.099286236,0.0379463786587956,0.31221068,5348,0.0875387392954034 -5507.4848001003265,2.319945812225342,57674.50718331337,68272,0,57674.50718331337,0.16405582,2472,0.05274917230313001,63187.40218567848,0.09623108,0.03642783613561648,0.3052696,5348,0.08554022611197466 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/measurements.csv deleted file mode 100644 index e14ebb8aa..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/measurements.csv +++ /dev/null @@ -1,726 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,71.16951,31.460064,,,,,,,,,,,,,, -1,,,30.525188,1.632826978739354,30.029667,1.2028346061384283,5348.0,30.044025,1.2556009180833994,2472.0,58.690967082977295,244.4380226135254,58.690967082977295,185.7470064163208,0.0,0.0 -100,0.8682036,5.950865,,,,,,,,,,,,,, -200,1.2198938,5.817406,,,,,,,,,,,,,, -300,3.0682926,5.801611,,,,,,,,,,,,,, -400,1.0120486,5.7875624,,,,,,,,,,,,,, -500,3.161302,5.7974725,,,,,,,,,,,,,, -600,1.9755938,5.7874503,,,,,,,,,,,,,, -700,1.1288373,5.55811,,,,,,,,,,,,,, -800,1.337599,5.387176,,,,,,,,,,,,,, -900,1.1651969,4.2903776,,,,,,,,,,,,,, -1000,1.147993,3.6655037,,,,,,,,,,,,,, -1100,1.0973585,3.3539543,,,,,,,,,,,,,, -1200,1.3241152,3.2147813,,,,,,,,,,,,,, -1300,1.0344901,3.0316746,,,,,,,,,,,,,, -1400,0.6230821,2.8869443,,,,,,,,,,,,,, -1500,0.6193506,2.671509,,,,,,,,,,,,,, -1600,1.1734592,2.5824454,,,,,,,,,,,,,, -1700,0.88197255,2.5088418,,,,,,,,,,,,,, -1731,,,3.0738208,0.6013315894626269,3.4970143,0.6448632418394045,5348.0,3.1974833,0.5987650559584019,2472.0,1499.073543548584,1811.5598888397217,1499.073543548584,312.36979603767395,0.0443918704986572,0.0 -1800,0.74599844,2.409129,,,,,,,,,,,,,, -1900,0.6912994,2.351066,,,,,,,,,,,,,, -2000,0.76694506,2.3203793,,,,,,,,,,,,,, -2100,0.6841353,2.265641,,,,,,,,,,,,,, -2200,1.1801667,2.1722722,,,,,,,,,,,,,, -2300,1.2185022,2.1181908,,,,,,,,,,,,,, -2400,0.84363204,2.1183317,,,,,,,,,,,,,, -2500,0.5541372,2.0139632,,,,,,,,,,,,,, -2600,0.7331222,2.0119016,,,,,,,,,,,,,, -2700,0.59810895,2.0199945,,,,,,,,,,,,,, -2800,0.73232424,2.0591297,,,,,,,,,,,,,, -2900,0.6596105,1.9486116,,,,,,,,,,,,,, -3000,0.92934895,2.0135179,,,,,,,,,,,,,, -3100,0.7002723,1.9368105,,,,,,,,,,,,,, -3200,0.91507685,1.8865191,,,,,,,,,,,,,, -3300,0.78122354,1.9224036,,,,,,,,,,,,,, -3400,0.61016786,1.8273315,,,,,,,,,,,,,, -3490,,,0.63175845,0.2134299336760514,0.95706713,0.2770499242109734,5348.0,0.66895556,0.2187760242114029,2472.0,2939.144833803177,3385.059635400772,2939.144833803177,445.6631715297699,0.1009564399719238,0.0 -3500,0.5954282,1.874915,,,,,,,,,,,,,, -3600,0.60942674,1.8262134,,,,,,,,,,,,,, -3700,0.5837198,1.8652673,,,,,,,,,,,,,, -3800,0.8706054,1.8445052,,,,,,,,,,,,,, -3900,0.53845334,1.7881956,,,,,,,,,,,,,, -4000,0.49738824,1.8438054,,,,,,,,,,,,,, -4100,0.51732415,1.7516594,,,,,,,,,,,,,, -4200,0.67306083,1.7409385,,,,,,,,,,,,,, -4300,0.63267225,1.7618759,,,,,,,,,,,,,, -4400,0.57843804,1.758615,,,,,,,,,,,,,, -4500,0.6540667,1.7887635,,,,,,,,,,,,,, -4600,0.6168944,1.7039608,,,,,,,,,,,,,, -4700,0.65226287,1.708356,,,,,,,,,,,,,, -4800,0.5573084,1.676478,,,,,,,,,,,,,, -4900,0.73708797,1.7641842,,,,,,,,,,,,,, -5000,0.52546316,1.7089893,,,,,,,,,,,,,, -5100,0.6583395,1.6983209,,,,,,,,,,,,,, -5200,,,0.43173733,0.1540711924980628,0.76068735,0.2270098574007743,5348.0,0.50083953,0.165031584506327,2472.0,4379.723701953888,4958.717666625977,4379.723701953888,578.6111640930176,0.1544415950775146,0.0 -5200,0.5585961,1.6467357,,,,,,,,,,,,,, -5300,0.46523494,1.6823423,,,,,,,,,,,,,, -5400,0.46448627,1.6741273,,,,,,,,,,,,,, -5500,0.5046817,1.6371952,,,,,,,,,,,,,, -5600,0.6004263,1.6890066,,,,,,,,,,,,,, -5700,0.4748891,1.6051241,,,,,,,,,,,,,, -5800,0.421306,1.6279215,,,,,,,,,,,,,, -5900,0.59116894,1.5764759,,,,,,,,,,,,,, -6000,0.38691762,1.6438913,,,,,,,,,,,,,, -6100,0.5697478,1.66371,,,,,,,,,,,,,, -6200,0.4784674,1.6196227,,,,,,,,,,,,,, -6300,0.4786567,1.6239095,,,,,,,,,,,,,, -6400,0.45445672,1.6034533,,,,,,,,,,,,,, -6500,0.5529065,1.6314878,,,,,,,,,,,,,, -6600,0.51059955,1.6481876,,,,,,,,,,,,,, -6700,0.43871638,1.5766774,,,,,,,,,,,,,, -6800,0.4119731,1.6191235,,,,,,,,,,,,,, -6900,0.6458836,1.5927284,,,,,,,,,,,,,, -6925,,,0.3827844,0.1367330268407879,0.67956024,0.2044179692402753,5348.0,0.43859884,0.1470558365324071,2472.0,5820.102232933044,6533.717638731003,5820.102232933044,713.0965855121613,0.2123136520385742,0.0 -7000,0.48460427,1.5730584,,,,,,,,,,,,,, -7100,0.5319566,1.6057825,,,,,,,,,,,,,, -7200,0.51738137,1.552498,,,,,,,,,,,,,, -7300,0.5029656,1.5872378,,,,,,,,,,,,,, -7400,0.49399996,1.5844212,,,,,,,,,,,,,, -7500,0.56316376,1.5666366,,,,,,,,,,,,,, -7600,0.49696243,1.5785944,,,,,,,,,,,,,, -7700,0.48233527,1.5410028,,,,,,,,,,,,,, -7800,0.44617888,1.6157097,,,,,,,,,,,,,, -7900,0.56082404,1.5370274,,,,,,,,,,,,,, -8000,0.6704217,1.5805763,,,,,,,,,,,,,, -8100,0.56243306,1.5787119,,,,,,,,,,,,,, -8200,0.47977334,1.6449751,,,,,,,,,,,,,, -8300,0.6750922,1.5573286,,,,,,,,,,,,,, -8400,0.4660286,1.519099,,,,,,,,,,,,,, -8500,0.68168235,1.546907,,,,,,,,,,,,,, -8600,0.46653256,1.5472971,,,,,,,,,,,,,, -8656,,,0.34197354,0.1229206984278849,0.6259623,0.1901097734052926,5348.0,0.3972195,0.1341376718867426,2472.0,7260.559650659561,8107.523503780365,7260.559650659561,846.3146262168884,0.2651317119598388,0.0 -8700,0.42001566,1.5783244,,,,,,,,,,,,,, -8800,0.4802701,1.5321724,,,,,,,,,,,,,, -8900,0.56544834,1.4928296,,,,,,,,,,,,,, -9000,0.44586596,1.5084078,,,,,,,,,,,,,, -9100,0.40512604,1.497369,,,,,,,,,,,,,, -9200,0.4750546,1.580851,,,,,,,,,,,,,, -9300,0.51605105,1.5538152,,,,,,,,,,,,,, -9400,0.53490126,1.5347596,,,,,,,,,,,,,, -9500,0.5014186,1.4857898,,,,,,,,,,,,,, -9600,0.40946102,1.5334461,,,,,,,,,,,,,, -9700,0.53050745,1.4786745,,,,,,,,,,,,,, -9800,0.5144187,1.5113944,,,,,,,,,,,,,, -9900,0.4353577,1.4914705,,,,,,,,,,,,,, -10000,0.51249874,1.4910023,,,,,,,,,,,,,, -10100,0.41406658,1.4940069,,,,,,,,,,,,,, -10200,0.59656286,1.5263935,,,,,,,,,,,,,, -10300,0.5079907,1.4993671,,,,,,,,,,,,,, -10352,,,0.33752048,0.1192898410784685,0.5941883,0.1806482134064512,5348.0,0.376258,0.1259114821359657,2472.0,8700.735800027847,9679.307429790497,8700.735800027847,977.789181470871,0.3199388980865478,0.0 -10400,0.5210014,1.5053493,,,,,,,,,,,,,, -10500,0.48872718,1.4965495,,,,,,,,,,,,,, -10600,0.5334268,1.4543877,,,,,,,,,,,,,, -10700,0.52043295,1.4723814,,,,,,,,,,,,,, -10800,0.47989336,1.4395243,,,,,,,,,,,,,, -10900,0.49107593,1.4797583,,,,,,,,,,,,,, -11000,0.53654104,1.4997318,,,,,,,,,,,,,, -11100,0.57518667,1.5204651,,,,,,,,,,,,,, -11200,0.5051192,1.501728,,,,,,,,,,,,,, -11300,0.59761244,1.5142986,,,,,,,,,,,,,, -11400,0.45854375,1.3854022,,,,,,,,,,,,,, -11500,0.66878843,1.5316961,,,,,,,,,,,,,, -11600,0.5425328,1.563238,,,,,,,,,,,,,, -11700,0.49897105,1.4940991,,,,,,,,,,,,,, -11800,0.78800917,1.5048527,,,,,,,,,,,,,, -11900,0.45756173,1.4575549,,,,,,,,,,,,,, -12000,0.535509,1.4720718,,,,,,,,,,,,,, -12066,,,0.29464692,0.107576627741263,0.5686792,0.1717755872442724,5348.0,0.3550018,0.1176852923851888,2472.0,10141.14683175087,11254.12031364441,10141.14683175087,1112.065169095993,0.3683531284332275,0.0 -12100,0.56228614,1.4632825,,,,,,,,,,,,,, -12200,0.4195616,1.4634403,,,,,,,,,,,,,, -12300,0.49559575,1.5057808,,,,,,,,,,,,,, -12400,0.6431221,1.4249315,,,,,,,,,,,,,, -12500,0.7020245,1.4557018,,,,,,,,,,,,,, -12600,0.5056717,1.4550297,,,,,,,,,,,,,, -12700,0.5698498,1.4992787,,,,,,,,,,,,,, -12800,0.513373,1.5052912,,,,,,,,,,,,,, -12900,0.47610494,1.4618942,,,,,,,,,,,,,, -13000,0.51390845,1.4171326,,,,,,,,,,,,,, -13100,0.6049439,1.49151,,,,,,,,,,,,,, -13200,0.46226588,1.4678265,,,,,,,,,,,,,, -13300,0.51275957,1.4168291,,,,,,,,,,,,,, -13400,0.5641527,1.5007416,,,,,,,,,,,,,, -13500,0.47927415,1.3847158,,,,,,,,,,,,,, -13600,0.45361334,1.4288294,,,,,,,,,,,,,, -13700,0.57265735,1.413926,,,,,,,,,,,,,, -13793,,,0.26047683,0.0976832470943082,0.5568923,0.1668806781428309,5348.0,0.34235185,0.1144963743830357,2472.0,11581.195252656937,12826.767081737518,11581.195252656937,1244.5346031188965,0.418978214263916,0.0 -13800,0.6008821,1.4093987,,,,,,,,,,,,,, -13900,0.4521275,1.4548107,,,,,,,,,,,,,, -14000,0.5042685,1.4209267,,,,,,,,,,,,,, -14100,0.58668613,1.4285086,,,,,,,,,,,,,, -14200,0.44312856,1.442074,,,,,,,,,,,,,, -14300,0.4958054,1.4700016,,,,,,,,,,,,,, -14400,0.5283969,1.4509878,,,,,,,,,,,,,, -14500,0.5793564,1.3943444,,,,,,,,,,,,,, -14600,0.45912144,1.3474056,,,,,,,,,,,,,, -14700,0.46088344,1.3834945,,,,,,,,,,,,,, -14800,0.7244008,1.4184419,,,,,,,,,,,,,, -14900,0.5181638,1.4181345,,,,,,,,,,,,,, -15000,0.6767782,1.4117125,,,,,,,,,,,,,, -15100,0.50020194,1.4278376,,,,,,,,,,,,,, -15200,0.5080771,1.4601004,,,,,,,,,,,,,, -15300,0.48863804,1.4536138,,,,,,,,,,,,,, -15400,0.51018333,1.4018719,,,,,,,,,,,,,, -15482,,,0.24350418,0.0904776851502109,0.54537946,0.1634918949187561,5348.0,0.32886645,0.1120792964068815,2472.0,13021.82521033287,14400.499089717863,13021.82521033287,1377.50799202919,0.4708180427551269,0.0 -15500,0.4444665,1.3612224,,,,,,,,,,,,,, -15600,0.4236871,1.382119,,,,,,,,,,,,,, -15700,0.549678,1.3971667,,,,,,,,,,,,,, -15800,0.571566,1.42599,,,,,,,,,,,,,, -15900,0.5775089,1.3855745,,,,,,,,,,,,,, -16000,0.46247858,1.3943421,,,,,,,,,,,,,, -16100,0.4333067,1.3930995,,,,,,,,,,,,,, -16200,0.55030686,1.4253727,,,,,,,,,,,,,, -16300,0.45818916,1.3868527,,,,,,,,,,,,,, -16400,0.40579998,1.3936039,,,,,,,,,,,,,, -16500,0.702695,1.4108762,,,,,,,,,,,,,, -16600,0.43583015,1.4567481,,,,,,,,,,,,,, -16700,0.49615857,1.431422,,,,,,,,,,,,,, -16800,0.48551762,1.3712324,,,,,,,,,,,,,, -16900,0.5115365,1.3744031,,,,,,,,,,,,,, -17000,0.6193065,1.4752672,,,,,,,,,,,,,, -17100,0.42291743,1.3822412,,,,,,,,,,,,,, -17200,0.6460893,1.3961419,,,,,,,,,,,,,, -17201,,,0.24185376,0.0926437385703636,0.5224277,0.1579597787153518,5348.0,0.3163899,0.1059248877785225,2472.0,14461.918625116348,15972.129461288452,14461.918625116348,1508.9135534763336,0.5240156650543213,0.0 -17300,0.46469137,1.4177126,,,,,,,,,,,,,, -17400,0.41276628,1.394956,,,,,,,,,,,,,, -17500,0.44670084,1.460363,,,,,,,,,,,,,, -17600,0.48070046,1.3588943,,,,,,,,,,,,,, -17700,0.44763052,1.3608577,,,,,,,,,,,,,, -17800,0.56295764,1.3856812,,,,,,,,,,,,,, -17900,0.5464649,1.3738914,,,,,,,,,,,,,, -18000,0.46308082,1.3682394,,,,,,,,,,,,,, -18100,0.5779284,1.3843747,,,,,,,,,,,,,, -18200,0.58317256,1.363498,,,,,,,,,,,,,, -18300,0.4234022,1.3743598,,,,,,,,,,,,,, -18400,0.541435,1.4037902,,,,,,,,,,,,,, -18500,0.6101873,1.396595,,,,,,,,,,,,,, -18600,0.63719743,1.3995148,,,,,,,,,,,,,, -18700,0.5109864,1.3269125,,,,,,,,,,,,,, -18800,0.4398593,1.3023207,,,,,,,,,,,,,, -18900,0.49231407,1.3772607,,,,,,,,,,,,,, -18914,,,0.24827659,0.0911023559586451,0.49703676,0.151143593654962,5348.0,0.3022412,0.1006641886539516,2472.0,15902.305790424349,17546.927089452744,15902.305790424349,1643.1955354213717,0.5756280422210693,0.0 -19000,0.4933798,1.361269,,,,,,,,,,,,,, -19100,0.5764381,1.3860382,,,,,,,,,,,,,, -19200,0.458711,1.3353405,,,,,,,,,,,,,, -19300,0.53367263,1.3816761,,,,,,,,,,,,,, -19400,0.5231605,1.3429583,,,,,,,,,,,,,, -19500,0.5490529,1.3992318,,,,,,,,,,,,,, -19600,0.5050558,1.3354928,,,,,,,,,,,,,, -19700,0.56692296,1.3307039,,,,,,,,,,,,,, -19800,0.5318584,1.3740854,,,,,,,,,,,,,, -19900,0.49001807,1.3999369,,,,,,,,,,,,,, -20000,0.42815366,1.3135724,,,,,,,,,,,,,, -20100,0.4674146,1.3183943,,,,,,,,,,,,,, -20200,0.50654215,1.3311327,,,,,,,,,,,,,, -20300,0.56581867,1.3116075,,,,,,,,,,,,,, -20400,0.62245476,1.3504535,,,,,,,,,,,,,, -20500,0.44613972,1.3512725,,,,,,,,,,,,,, -20600,0.38852906,1.3169034,,,,,,,,,,,,,, -20601,,,0.24300954,0.0892447305247052,0.49576217,0.1492512816551937,5348.0,0.29747027,0.0999532833668474,2472.0,17345.91077399254,19122.66784954071,17345.91077399254,1775.2081396579742,0.6244874000549316,0.0 -20700,0.53663063,1.3533063,,,,,,,,,,,,,, -20800,0.5387376,1.3776348,,,,,,,,,,,,,, -20900,0.48939225,1.3768792,,,,,,,,,,,,,, -21000,0.6392652,1.378614,,,,,,,,,,,,,, -21100,0.5826684,1.3447437,,,,,,,,,,,,,, -21200,0.47563902,1.3254317,,,,,,,,,,,,,, -21300,0.4199318,1.3266951,,,,,,,,,,,,,, -21400,0.5542092,1.3393052,,,,,,,,,,,,,, -21500,0.60555154,1.3730912,,,,,,,,,,,,,, -21600,0.5477652,1.3096801,,,,,,,,,,,,,, -21700,0.57125425,1.4072448,,,,,,,,,,,,,, -21800,0.50514406,1.3623006,,,,,,,,,,,,,, -21900,0.6317035,1.3373945,,,,,,,,,,,,,, -22000,0.48896182,1.3427961,,,,,,,,,,,,,, -22100,0.4811362,1.3368239,,,,,,,,,,,,,, -22200,0.47743607,1.3878946,,,,,,,,,,,,,, -22300,0.54827523,1.3263764,,,,,,,,,,,,,, -22318,,,0.22511902,0.084525156909147,0.49216977,0.1476003359819264,5348.0,0.28847873,0.0986533422704283,2472.0,18785.811478853226,20696.58254313469,18785.811478853226,1909.0888845920565,0.6796655654907227,0.0 -22400,0.58532304,1.3465246,,,,,,,,,,,,,, -22500,0.50608206,1.401921,,,,,,,,,,,,,, -22600,0.50381017,1.370112,,,,,,,,,,,,,, -22700,0.573251,1.2512975,,,,,,,,,,,,,, -22800,0.46326858,1.3005211,,,,,,,,,,,,,, -22900,0.54146045,1.3341906,,,,,,,,,,,,,, -23000,0.5846348,1.3649287,,,,,,,,,,,,,, -23100,0.5954285,1.3071549,,,,,,,,,,,,,, -23200,0.5442844,1.2717903,,,,,,,,,,,,,, -23300,0.6748552,1.3728158,,,,,,,,,,,,,, -23400,0.42561784,1.2853624,,,,,,,,,,,,,, -23500,0.53061056,1.3333957,,,,,,,,,,,,,, -23600,0.5864865,1.3340436,,,,,,,,,,,,,, -23700,0.61114365,1.3343071,,,,,,,,,,,,,, -23800,0.4832553,1.3091341,,,,,,,,,,,,,, -23900,0.53105015,1.3492433,,,,,,,,,,,,,, -24000,0.60543203,1.2728294,,,,,,,,,,,,,, -24059,,,0.20986903,0.0783539546471855,0.4833996,0.1443563725537522,5348.0,0.2866742,0.0960534600775902,2472.0,20225.753566741943,22273.176414966583,20225.753566741943,2045.6049239635468,0.736748456954956,0.0 -24100,0.52722895,1.3295366,,,,,,,,,,,,,, -24200,0.48135075,1.3276454,,,,,,,,,,,,,, -24300,0.46550402,1.3397095,,,,,,,,,,,,,, -24400,0.51864177,1.3106098,,,,,,,,,,,,,, -24500,0.49175844,1.3671352,,,,,,,,,,,,,, -24600,0.67673624,1.356062,,,,,,,,,,,,,, -24700,0.5115924,1.3192506,,,,,,,,,,,,,, -24800,0.49691233,1.3221871,,,,,,,,,,,,,, -24900,0.51674366,1.3558849,,,,,,,,,,,,,, -25000,0.5306826,1.3288836,,,,,,,,,,,,,, -25100,0.52600837,1.3055922,,,,,,,,,,,,,, -25200,0.5924492,1.330582,,,,,,,,,,,,,, -25300,0.49138445,1.3045772,,,,,,,,,,,,,, -25400,0.42523155,1.3377529,,,,,,,,,,,,,, -25500,0.66357654,1.2911111,,,,,,,,,,,,,, -25600,0.55103934,1.326043,,,,,,,,,,,,,, -25700,0.57461816,1.3252707,,,,,,,,,,,,,, -25752,,,0.20069353,0.0751923628680977,0.46672383,0.1400697065950935,5348.0,0.27209896,0.0911583693863871,2472.0,21666.226598978043,23848.32180738449,21666.226598978043,2180.148374080658,0.7892570495605469,0.0 -25800,0.5173534,1.2872828,,,,,,,,,,,,,, -25900,0.5082718,1.3258691,,,,,,,,,,,,,, -26000,0.5537464,1.2744559,,,,,,,,,,,,,, -26100,0.5498647,1.2691282,,,,,,,,,,,,,, -26200,0.6616449,1.3025901,,,,,,,,,,,,,, -26300,0.46416113,1.2995344,,,,,,,,,,,,,, -26400,0.459784,1.3157456,,,,,,,,,,,,,, -26500,0.4828027,1.3112047,,,,,,,,,,,,,, -26600,0.48167124,1.2822993,,,,,,,,,,,,,, -26700,0.5848359,1.3314093,,,,,,,,,,,,,, -26800,0.45790863,1.2308996,,,,,,,,,,,,,, -26900,0.51024884,1.332066,,,,,,,,,,,,,, -27000,0.5247218,1.2642051,,,,,,,,,,,,,, -27100,0.6666456,1.2896726,,,,,,,,,,,,,, -27200,0.55198973,1.2804075,,,,,,,,,,,,,, -27300,0.5659509,1.2903674,,,,,,,,,,,,,, -27400,0.43843156,1.207094,,,,,,,,,,,,,, -27476,,,0.19208643,0.0705781878062002,0.46214864,0.1379649922280043,5348.0,0.26997027,0.0923161294253854,2472.0,23106.16258573532,25423.06540942192,23106.16258573532,2314.819569826126,0.847764253616333,0.0 -27500,0.5505274,1.3236477,,,,,,,,,,,,,, -27600,0.5672791,1.2852641,,,,,,,,,,,,,, -27700,0.5107477,1.2785628,,,,,,,,,,,,,, -27800,0.51505345,1.2587223,,,,,,,,,,,,,, -27900,0.48508662,1.3469089,,,,,,,,,,,,,, -28000,0.42859387,1.2591184,,,,,,,,,,,,,, -28100,0.57049614,1.2932818,,,,,,,,,,,,,, -28200,0.52862793,1.2594534,,,,,,,,,,,,,, -28300,0.51373476,1.2732743,,,,,,,,,,,,,, -28400,0.4824853,1.2786114,,,,,,,,,,,,,, -28500,0.5151849,1.2846154,,,,,,,,,,,,,, -28600,0.5027431,1.2578313,,,,,,,,,,,,,, -28700,0.51302797,1.2782842,,,,,,,,,,,,,, -28800,0.4479077,1.3132055,,,,,,,,,,,,,, -28900,0.60262907,1.2722976,,,,,,,,,,,,,, -29000,0.55404186,1.2913043,,,,,,,,,,,,,, -29100,0.5408556,1.2599863,,,,,,,,,,,,,, -29200,0.45266077,1.2469504,,,,,,,,,,,,,, -29202,,,0.20945032,0.0773151759403709,0.45469,0.1361499174527163,5348.0,0.26454708,0.0883756829768651,2472.0,24546.349791288376,26996.2281870842,24546.349791288376,2447.661313056946,0.9036474227905272,0.0 -29300,0.6730195,1.236479,,,,,,,,,,,,,, -29400,0.4918608,1.2012985,,,,,,,,,,,,,, -29500,0.49103352,1.2426318,,,,,,,,,,,,,, -29600,0.4738303,1.3410205,,,,,,,,,,,,,, -29700,0.47046724,1.2879475,,,,,,,,,,,,,, -29800,0.6056663,1.3108373,,,,,,,,,,,,,, -29900,0.49875966,1.2724305,,,,,,,,,,,,,, -30000,0.5342085,1.2339981,,,,,,,,,,,,,, -30100,0.5490313,1.2788599,,,,,,,,,,,,,, -30200,0.5249902,1.2163227,,,,,,,,,,,,,, -30300,0.49500445,1.244789,,,,,,,,,,,,,, -30400,0.50853455,1.248517,,,,,,,,,,,,,, -30500,0.47054625,1.2639567,,,,,,,,,,,,,, -30600,0.50729376,1.2552433,,,,,,,,,,,,,, -30700,0.57748073,1.2306315,,,,,,,,,,,,,, -30800,0.5530192,1.1876308,,,,,,,,,,,,,, -30887,,,0.19461107,0.0707166072523453,0.4496563,0.1345375903916892,5348.0,0.2595338,0.0865882639692889,2472.0,25987.04529428482,28572.57029390335,25987.04529428482,2583.189457654953,0.949357271194458,0.0 -30900,0.5523033,1.2438024,,,,,,,,,,,,,, -31000,0.49636018,1.2526021,,,,,,,,,,,,,, -31100,0.47528484,1.2767175,,,,,,,,,,,,,, -31200,0.55396956,1.2093813,,,,,,,,,,,,,, -31300,0.5134636,1.1854571,,,,,,,,,,,,,, -31400,0.47926018,1.2336682,,,,,,,,,,,,,, -31500,0.5201416,1.1926649,,,,,,,,,,,,,, -31600,0.5254053,1.2562927,,,,,,,,,,,,,, -31700,0.5454342,1.1939461,,,,,,,,,,,,,, -31800,0.5063363,1.2679992,,,,,,,,,,,,,, -31900,0.59764147,1.2428406,,,,,,,,,,,,,, -32000,0.4751736,1.2223191,,,,,,,,,,,,,, -32100,0.5723097,1.2421798,,,,,,,,,,,,,, -32200,0.5717453,1.2795069,,,,,,,,,,,,,, -32300,0.49781367,1.2341307,,,,,,,,,,,,,, -32400,0.51622957,1.162647,,,,,,,,,,,,,, -32500,0.5647153,1.2668531,,,,,,,,,,,,,, -32580,,,0.19953363,0.0724306386738515,0.4363747,0.1306178012493121,5348.0,0.25433442,0.085999228159974,2472.0,27427.04628181457,30148.702434778214,27427.04628181457,2719.184014081955,1.008662462234497,0.0 -32600,0.5393727,1.2605089,,,,,,,,,,,,,, -32700,0.5434083,1.2398627,,,,,,,,,,,,,, -32800,0.6876423,1.2664502,,,,,,,,,,,,,, -32900,0.51491034,1.2481393,,,,,,,,,,,,,, -33000,0.61038846,1.2544076,,,,,,,,,,,,,, -33100,0.526573,1.2253973,,,,,,,,,,,,,, -33200,0.6021006,1.2216254,,,,,,,,,,,,,, -33300,0.46254712,1.2409028,,,,,,,,,,,,,, -33400,0.51885426,1.2032883,,,,,,,,,,,,,, -33500,0.48053047,1.2497374,,,,,,,,,,,,,, -33600,0.51841664,1.295993,,,,,,,,,,,,,, -33700,0.7859152,1.2522233,,,,,,,,,,,,,, -33800,0.4796871,1.2028964,,,,,,,,,,,,,, -33900,0.54905874,1.2366018,,,,,,,,,,,,,, -34000,0.5129345,1.2488471,,,,,,,,,,,,,, -34100,0.48503673,1.240023,,,,,,,,,,,,,, -34200,0.5506688,1.2459301,,,,,,,,,,,,,, -34300,0.50400686,1.2089657,,,,,,,,,,,,,, -34301,,,0.19232655,0.0685497191939213,0.4371391,0.1274800390047983,5348.0,0.2501841,0.0845774175857656,2472.0,28867.922538518906,31725.02043557167,28867.922538518906,2854.495341539383,1.0615310668945312,0.0 -34400,0.57028556,1.2352068,,,,,,,,,,,,,, -34500,0.6201931,1.2878602,,,,,,,,,,,,,, -34600,0.5490138,1.2568798,,,,,,,,,,,,,, -34700,0.47098756,1.2579374,,,,,,,,,,,,,, -34800,0.56520563,1.2781203,,,,,,,,,,,,,, -34900,0.55966735,1.2098578,,,,,,,,,,,,,, -35000,0.5649347,1.2698259,,,,,,,,,,,,,, -35100,0.42943504,1.2403715,,,,,,,,,,,,,, -35200,0.4848799,1.225651,,,,,,,,,,,,,, -35300,0.5413011,1.2015141,,,,,,,,,,,,,, -35400,0.511577,1.1898526,,,,,,,,,,,,,, -35500,0.42873782,1.2047937,,,,,,,,,,,,,, -35600,0.5490136,1.2172471,,,,,,,,,,,,,, -35700,0.58578193,1.278494,,,,,,,,,,,,,, -35800,0.5786517,1.1963804,,,,,,,,,,,,,, -35900,0.4908615,1.222695,,,,,,,,,,,,,, -35987,,,0.14790924,0.056189334071318,0.4280118,0.1253753246377091,5348.0,0.24394017,0.0813884995836126,2472.0,30308.074434041977,33299.92118215561,30308.074434041977,2989.1104102134705,1.118565559387207,0.0 -36000,0.64704627,1.2730193,,,,,,,,,,,,,, -36100,0.52004933,1.2256638,,,,,,,,,,,,,, -36200,0.4508289,1.1925955,,,,,,,,,,,,,, -36300,0.5211252,1.2145673,,,,,,,,,,,,,, -36400,0.5990879,1.2102195,,,,,,,,,,,,,, -36500,0.5374473,1.1787337,,,,,,,,,,,,,, -36600,0.58000386,1.220211,,,,,,,,,,,,,, -36700,0.51553607,1.2235568,,,,,,,,,,,,,, -36800,0.44886953,1.1933904,,,,,,,,,,,,,, -36900,0.5091619,1.2971207,,,,,,,,,,,,,, -37000,0.5864624,1.2273811,,,,,,,,,,,,,, -37100,0.5606028,1.2239268,,,,,,,,,,,,,, -37200,0.6545253,1.1291009,,,,,,,,,,,,,, -37300,0.5271565,1.2083745,,,,,,,,,,,,,, -37400,0.4577623,1.173376,,,,,,,,,,,,,, -37500,0.5059442,1.274235,,,,,,,,,,,,,, -37600,0.5472004,1.2197455,,,,,,,,,,,,,, -37685,,,0.16978231,0.0630979282380519,0.4183739,0.1235023219440609,5348.0,0.24056098,0.0818962890744013,2472.0,31748.535685777664,34875.30942106247,31748.535685777664,3123.903793334961,1.1749954223632812,0.0 -37700,0.6527794,1.1927243,,,,,,,,,,,,,, -37800,0.52242,1.1928949,,,,,,,,,,,,,, -37900,0.4811289,1.203265,,,,,,,,,,,,,, -38000,0.50021017,1.2343905,,,,,,,,,,,,,, -38100,0.5440761,1.2330804,,,,,,,,,,,,,, -38200,0.5248215,1.1712354,,,,,,,,,,,,,, -38300,0.54986405,1.1813048,,,,,,,,,,,,,, -38400,0.5046133,1.1934328,,,,,,,,,,,,,, -38500,0.5882744,1.1626413,,,,,,,,,,,,,, -38600,0.52963793,1.1568851,,,,,,,,,,,,,, -38700,0.48768976,1.1135854,,,,,,,,,,,,,, -38800,0.55126756,1.1641737,,,,,,,,,,,,,, -38900,0.5293766,1.1853076,,,,,,,,,,,,,, -39000,0.5378698,1.1317345,,,,,,,,,,,,,, -39100,0.63722384,1.2079597,,,,,,,,,,,,,, -39200,0.5063251,1.1723204,,,,,,,,,,,,,, -39300,0.6307059,1.2292231,,,,,,,,,,,,,, -39389,,,0.20656164,0.0754706747015399,0.41282338,0.1211272772912905,5348.0,0.2356621,0.0799870005890358,2472.0,33188.79012131691,36448.38415312767,33188.79012131691,3256.5838243961334,1.236691951751709,0.0 -39400,0.5538646,1.1800522,,,,,,,,,,,,,, -39500,0.5633708,1.1773252,,,,,,,,,,,,,, -39600,0.7178083,1.1918129,,,,,,,,,,,,,, -39700,0.63911176,1.1399975,,,,,,,,,,,,,, -39800,0.49043003,1.1945622,,,,,,,,,,,,,, -39900,0.5166381,1.2059375,,,,,,,,,,,,,, -40000,0.46546367,1.1586654,,,,,,,,,,,,,, -40100,0.60385716,1.1959205,,,,,,,,,,,,,, -40200,0.49251878,1.1519027,,,,,,,,,,,,,, -40300,0.5550084,1.1677047,,,,,,,,,,,,,, -40400,0.5306467,1.1805447,,,,,,,,,,,,,, -40500,0.45841452,1.1699353,,,,,,,,,,,,,, -40600,0.5479321,1.1685885,,,,,,,,,,,,,, -40700,0.48009503,1.1489385,,,,,,,,,,,,,, -40800,0.59429127,1.1935894,,,,,,,,,,,,,, -40900,0.5398016,1.1852474,,,,,,,,,,,,,, -41000,0.6292251,1.1956223,,,,,,,,,,,,,, -41079,,,0.20947912,0.0772079028739237,0.4047456,0.1206059260260482,5348.0,0.22915868,0.0767574594276196,2472.0,34628.73562026024,38020.64000558853,34628.73562026024,3388.7605333328247,1.292815923690796,0.0 -41100,0.5440263,1.1202172,,,,,,,,,,,,,, -41200,0.5512872,1.1971277,,,,,,,,,,,,,, -41300,0.5448409,1.1466972,,,,,,,,,,,,,, -41400,0.44503683,1.179181,,,,,,,,,,,,,, -41500,0.48840302,1.1208622,,,,,,,,,,,,,, -41600,0.58576095,1.1871649,,,,,,,,,,,,,, -41700,0.5176872,1.1751922,,,,,,,,,,,,,, -41800,0.6105946,1.1943259,,,,,,,,,,,,,, -41900,0.57919914,1.1388624,,,,,,,,,,,,,, -42000,0.56076664,1.1660391,,,,,,,,,,,,,, -42100,0.5420509,1.174585,,,,,,,,,,,,,, -42200,0.5240316,1.147868,,,,,,,,,,,,,, -42300,0.58177257,1.1763656,,,,,,,,,,,,,, -42400,0.51068765,1.1533331,,,,,,,,,,,,,, -42500,0.55410415,1.1453717,,,,,,,,,,,,,, -42600,0.54472315,1.1636014,,,,,,,,,,,,,, -42700,0.5055589,1.1643087,,,,,,,,,,,,,, -42780,,,0.24113484,0.0889962620867968,0.39844996,0.1170433590468926,5348.0,0.22481863,0.0739747730180976,2472.0,36068.85384774208,39593.12744855881,36068.85384774208,3520.995194911957,1.3495397567749023,0.0 -42800,0.5748422,1.2210393,,,,,,,,,,,,,, -42900,0.5398086,1.1236767,,,,,,,,,,,,,, -43000,0.49434447,1.1291331,,,,,,,,,,,,,, -43100,0.6303435,1.1818845,,,,,,,,,,,,,, -43200,0.6420578,1.1765765,,,,,,,,,,,,,, -43300,0.57600313,1.1393348,,,,,,,,,,,,,, -43400,0.55401355,1.1870492,,,,,,,,,,,,,, -43500,0.5338219,1.1051065,,,,,,,,,,,,,, -43600,0.547519,1.1101762,,,,,,,,,,,,,, -43700,0.5522978,1.1349138,,,,,,,,,,,,,, -43800,0.4601527,1.1426656,,,,,,,,,,,,,, -43900,0.56978554,1.157097,,,,,,,,,,,,,, -44000,0.5157148,1.1129208,,,,,,,,,,,,,, -44100,0.65242577,1.1737623,,,,,,,,,,,,,, -44200,0.5630839,1.1450133,,,,,,,,,,,,,, -44300,0.578274,1.1409792,,,,,,,,,,,,,, -44400,0.4647117,1.1569493,,,,,,,,,,,,,, -44481,,,0.20987618,0.0761205541392648,0.39491072,0.1156337797001264,5348.0,0.22069937,0.0736700993236244,2472.0,37509.35020899773,41163.16731405258,37509.35020899773,3650.400237083435,1.4104692935943604,0.0 -44500,0.57215285,1.1124859,,,,,,,,,,,,,, -44600,0.55757475,1.1059419,,,,,,,,,,,,,, -44700,0.7424732,1.1218802,,,,,,,,,,,,,, -44800,0.608096,1.1772245,,,,,,,,,,,,,, -44900,0.5873889,1.1366212,,,,,,,,,,,,,, -45000,0.6220039,1.0714198,,,,,,,,,,,,,, -45100,0.6556262,1.1227286,,,,,,,,,,,,,, -45200,0.46346763,1.1533407,,,,,,,,,,,,,, -45300,0.6440426,1.1541455,,,,,,,,,,,,,, -45400,0.55035555,1.0890834,,,,,,,,,,,,,, -45500,0.6510178,1.1376085,,,,,,,,,,,,,, -45600,0.68539315,1.10401,,,,,,,,,,,,,, -45700,0.5458729,1.1183909,,,,,,,,,,,,,, -45800,0.5622968,1.1134422,,,,,,,,,,,,,, -45900,0.6327856,1.1365455,,,,,,,,,,,,,, -46000,0.5621381,1.0902739,,,,,,,,,,,,,, -46100,0.5791691,1.1370648,,,,,,,,,,,,,, -46168,,,0.19044133,0.0708583389932415,0.38421032,0.1118974289658901,5348.0,0.21279301,0.0713545792456279,2472.0,38949.5586669445,42735.63666701317,38949.5586669445,3782.523449420929,1.4723756313323977,0.0 -46200,0.56289846,1.1166215,,,,,,,,,,,,,, -46300,0.6410952,1.102474,,,,,,,,,,,,,, -46400,0.64524186,1.1029758,,,,,,,,,,,,,, -46500,0.6155278,1.1228313,,,,,,,,,,,,,, -46600,0.5804477,1.0554179,,,,,,,,,,,,,, -46700,0.7293852,1.151149,,,,,,,,,,,,,, -46800,0.53591174,1.1092616,,,,,,,,,,,,,, -46900,0.6080431,1.1681011,,,,,,,,,,,,,, -47000,0.58272296,1.1266694,,,,,,,,,,,,,, -47100,0.5586092,1.1417929,,,,,,,,,,,,,, -47200,0.55712414,1.1126374,,,,,,,,,,,,,, -47300,0.52069306,1.1179357,,,,,,,,,,,,,, -47400,0.6188109,1.0395852,,,,,,,,,,,,,, -47500,0.49722567,1.1131676,,,,,,,,,,,,,, -47600,0.58787036,1.1047523,,,,,,,,,,,,,, -47700,0.5593672,1.1811053,,,,,,,,,,,,,, -47800,0.6008416,1.1086893,,,,,,,,,,,,,, -47870,,,0.16163252,0.0609848937998238,0.3805211,0.1111636753333269,5348.0,0.21163197,0.0711311518696809,2472.0,40390.47105741501,44311.951271533966,40390.47105741501,3917.786875724793,1.5329391956329346,0.0 -47900,0.5646407,1.1390686,,,,,,,,,,,,,, -48000,0.6160437,1.0908448,,,,,,,,,,,,,, -48100,0.57542753,1.1186342,,,,,,,,,,,,,, -48200,0.52195734,1.1325768,,,,,,,,,,,,,, -48300,0.6453754,1.1270547,,,,,,,,,,,,,, -48400,0.6192215,1.1055542,,,,,,,,,,,,,, -48500,0.55998695,1.0879474,,,,,,,,,,,,,, -48600,0.5314685,1.0901427,,,,,,,,,,,,,, -48700,0.712982,1.1101031,,,,,,,,,,,,,, -48800,0.6591366,1.0616789,,,,,,,,,,,,,, -48900,0.52340746,1.0389994,,,,,,,,,,,,,, -49000,0.5702867,1.1225597,,,,,,,,,,,,,, -49100,0.51035273,1.0979811,,,,,,,,,,,,,, -49200,0.70529217,1.080151,,,,,,,,,,,,,, -49300,0.5262272,1.0917302,,,,,,,,,,,,,, -49400,0.6613891,1.0573274,,,,,,,,,,,,,, -49500,0.54006135,1.064002,,,,,,,,,,,,,, -49570,,,0.17555173,0.0661228391273414,0.37177685,0.108219006150014,5348.0,0.20468043,0.0681859728231064,2472.0,41830.47101521492,45888.17083525658,41830.47101521492,4053.8710753917694,1.5911128520965576,0.0 -49600,0.7687556,1.0710025,,,,,,,,,,,,,, -49700,0.52633154,1.0832584,,,,,,,,,,,,,, -49800,0.5515113,1.0358384,,,,,,,,,,,,,, -49900,0.68912613,1.1219479,,,,,,,,,,,,,, -50000,0.6455719,1.128425,,,,,,,,,,,,,, -50100,0.72292477,1.1082286,,,,,,,,,,,,,, -50200,0.75369,1.0528362,,,,,,,,,,,,,, -50300,0.6219534,1.0744219,,,,,,,,,,,,,, -50400,0.67513394,1.0558797,,,,,,,,,,,,,, -50500,0.5942749,1.0588254,,,,,,,,,,,,,, -50600,0.6220316,1.0623659,,,,,,,,,,,,,, -50700,0.6131078,1.0575604,,,,,,,,,,,,,, -50800,0.5888107,1.1243664,,,,,,,,,,,,,, -50900,0.73125875,1.0263102,,,,,,,,,,,,,, -51000,0.5450083,1.1023518,,,,,,,,,,,,,, -51100,0.6779002,1.1078594,,,,,,,,,,,,,, -51200,0.55487293,1.0236769,,,,,,,,,,,,,, -51275,,,0.1526628,0.0584414190001717,0.36734933,0.105718451007463,5348.0,0.19672832,0.0654845327321105,2472.0,43271.03357815743,47463.67045736313,43271.03357815743,4188.673507928848,1.6475038528442385,0.0 -51300,0.5591421,1.0995549,,,,,,,,,,,,,, -51400,0.61874235,1.1159555,,,,,,,,,,,,,, -51500,0.5846638,1.0178424,,,,,,,,,,,,,, -51600,0.5514328,1.0653352,,,,,,,,,,,,,, -51700,0.6068964,1.0228014,,,,,,,,,,,,,, -51800,0.6207022,1.1020592,,,,,,,,,,,,,, -51900,0.61223716,0.9831967,,,,,,,,,,,,,, -52000,0.5680948,1.0752082,,,,,,,,,,,,,, -52100,0.580778,1.0884,,,,,,,,,,,,,, -52200,0.5751392,1.0078716,,,,,,,,,,,,,, -52300,0.7471391,1.0436784,,,,,,,,,,,,,, -52400,0.6685976,1.0485641,,,,,,,,,,,,,, -52500,0.73465306,1.0747865,,,,,,,,,,,,,, -52600,0.53673255,1.0453299,,,,,,,,,,,,,, -52700,0.5044871,1.0682614,,,,,,,,,,,,,, -52800,0.6615218,1.1087302,,,,,,,,,,,,,, -52900,0.56710255,1.0695708,,,,,,,,,,,,,, -53000,0.5609919,1.0831279,,,,,,,,,,,,,, -53003,,,0.1528462,0.0582670673363781,0.35978332,0.1029379109261708,5348.0,0.1953285,0.0639611642597444,2472.0,44711.0066754818,49037.26174616814,44711.0066754818,4322.077683210373,1.7828662395477295,0.0 -53100,0.58426166,1.0726436,,,,,,,,,,,,,, -53200,0.58633816,1.0859202,,,,,,,,,,,,,, -53300,0.5425361,1.065964,,,,,,,,,,,,,, -53400,0.9136922,1.0442861,,,,,,,,,,,,,, -53500,0.6092821,1.0677199,,,,,,,,,,,,,, -53600,0.6492517,1.0565736,,,,,,,,,,,,,, -53700,0.59231496,1.0321822,,,,,,,,,,,,,, -53800,0.49679017,1.0408947,,,,,,,,,,,,,, -53900,0.6460146,1.032358,,,,,,,,,,,,,, -54000,0.7208692,1.0855507,,,,,,,,,,,,,, -54100,0.5646248,1.0661428,,,,,,,,,,,,,, -54200,0.5318858,0.9885714,,,,,,,,,,,,,, -54300,0.605719,1.0697769,,,,,,,,,,,,,, -54400,0.5728374,1.0444064,,,,,,,,,,,,,, -54500,0.5222443,1.0082219,,,,,,,,,,,,,, -54600,0.53357667,1.0062563,,,,,,,,,,,,,, -54694,,,0.14036037,0.0533664596273291,0.34712532,0.0998098033347171,5348.0,0.19122066,0.0628440273800093,2472.0,46151.02938437462,50610.30272769928,46151.02938437462,4454.960786104202,1.8396832942962649,0.0 -54700,0.7433691,0.9636934,,,,,,,,,,,,,, -54800,0.6540028,1.0561544,,,,,,,,,,,,,, -54900,0.59495753,1.0491602,,,,,,,,,,,,,, -55000,0.8201947,1.0319796,,,,,,,,,,,,,, -55100,0.48649243,1.0311027,,,,,,,,,,,,,, -55200,0.67826784,1.0406827,,,,,,,,,,,,,, -55300,0.6463435,1.0535051,,,,,,,,,,,,,, -55400,0.5674117,0.9860587,,,,,,,,,,,,,, -55500,0.6489888,1.0663154,,,,,,,,,,,,,, -55600,0.8685952,1.0319418,,,,,,,,,,,,,, -55700,0.55642426,0.97343445,,,,,,,,,,,,,, -55800,0.6088525,1.0148137,,,,,,,,,,,,,, -55900,0.61484146,1.0069523,,,,,,,,,,,,,, -56000,0.6440556,1.045407,,,,,,,,,,,,,, -56100,0.5619432,1.0414435,,,,,,,,,,,,,, -56200,0.95878005,1.0026047,,,,,,,,,,,,,, -56300,0.56136507,0.96326905,,,,,,,,,,,,,, -56386,,,0.14347468,0.0545997733594313,0.34344375,0.0978305994574085,5348.0,0.18559092,0.0613409704872748,2472.0,47591.55652046204,52183.42995285988,47591.55652046204,4587.426983118057,1.8970699310302728,0.0 -56400,0.6347106,1.0305562,,,,,,,,,,,,,, -56500,0.69139016,1.1134638,,,,,,,,,,,,,, -56600,0.57853574,1.0790056,,,,,,,,,,,,,, -56700,0.662607,1.0165862,,,,,,,,,,,,,, -56800,0.66314995,1.0365624,,,,,,,,,,,,,, -56900,0.57247543,1.0152907,,,,,,,,,,,,,, -57000,0.75140226,0.9735837,,,,,,,,,,,,,, -57100,0.6332055,1.054817,,,,,,,,,,,,,, -57200,0.6035238,1.0473373,,,,,,,,,,,,,, -57300,0.5416547,1.049994,,,,,,,,,,,,,, -57400,0.5937346,1.0143135,,,,,,,,,,,,,, -57500,0.61788476,1.0065495,,,,,,,,,,,,,, -57600,0.7127083,1.0256397,,,,,,,,,,,,,, -57700,0.5867409,1.0093632,,,,,,,,,,,,,, -57800,0.67231154,1.0396718,,,,,,,,,,,,,, -57900,0.7482468,1.0280862,,,,,,,,,,,,,, -58000,0.56963754,0.99865985,,,,,,,,,,,,,, -58100,0.5744794,0.99535066,,,,,,,,,,,,,, -58103,,,0.14602813,0.0549668369291758,0.33937812,0.0967878969269239,5348.0,0.182006,0.0589442040907521,2472.0,49031.58073544502,53756.42091464996,49031.58073544502,4720.261053800583,1.95191502571106,0.0 -58200,0.60587716,1.02622,,,,,,,,,,,,,, -58300,0.79023635,1.0427642,,,,,,,,,,,,,, -58400,0.6746489,0.9859515,,,,,,,,,,,,,, -58500,1.442157,1.0281229,,,,,,,,,,,,,, -58600,0.74443716,1.0246587,,,,,,,,,,,,,, -58700,0.60403043,1.0126164,,,,,,,,,,,,,, -58800,0.6137409,0.9694601,,,,,,,,,,,,,, -58900,0.66435426,0.9816127,,,,,,,,,,,,,, -59000,0.5876113,1.0336297,,,,,,,,,,,,,, -59100,0.61584574,0.9698924,,,,,,,,,,,,,, -59200,0.6052798,0.99973345,,,,,,,,,,,,,, -59300,0.58808976,0.99667215,,,,,,,,,,,,,, -59400,0.6737974,0.97964233,,,,,,,,,,,,,, -59500,0.7009091,1.0086352,,,,,,,,,,,,,, -59600,0.59879214,1.0159699,,,,,,,,,,,,,, -59700,0.6393594,1.0329018,,,,,,,,,,,,,, -59760,,,0.11838623,0.0463940637119734,0.33321467,0.0941425219884723,5348.0,0.17941004,0.0587004651351735,2472.0,50471.67391204834,55327.64785504341,50471.67391204834,4851.256810903549,2.0137839317321777,0.0 -59800,0.8121567,1.0498534,,,,,,,,,,,,,, -59900,0.59397167,0.99021226,,,,,,,,,,,,,, -60000,0.60727173,0.9262889,,,,,,,,,,,,,, -60100,0.8872603,0.98962104,,,,,,,,,,,,,, -60200,0.5399719,0.97783,,,,,,,,,,,,,, -60300,0.7134057,1.0062369,,,,,,,,,,,,,, -60400,0.54468775,1.012427,,,,,,,,,,,,,, -60500,0.63735574,1.0020646,,,,,,,,,,,,,, -60600,0.67466843,0.9659598,,,,,,,,,,,,,, -60700,0.557173,0.9986769,,,,,,,,,,,,,, -60800,0.670581,0.9381825,,,,,,,,,,,,,, -60900,1.3899635,0.9662462,,,,,,,,,,,,,, -61000,0.633352,0.96975094,,,,,,,,,,,,,, -61100,0.6220092,1.0024827,,,,,,,,,,,,,, -61200,0.729809,0.9473365,,,,,,,,,,,,,, -61300,0.5960996,0.95678777,,,,,,,,,,,,,, -61400,0.6698894,0.98863673,,,,,,,,,,,,,, -61455,,,0.11842466,0.0453082926384008,0.3269322,0.0919798797030228,5348.0,0.17505817,0.0565474376942294,2472.0,51912.141885757446,56898.16695809364,51912.141885757446,4981.171031713486,2.0737240314483643,0.0 -61500,0.7406939,1.0259894,,,,,,,,,,,,,, -61600,0.5969941,0.96177113,,,,,,,,,,,,,, -61700,0.6561158,1.0181241,,,,,,,,,,,,,, -61800,0.80965334,0.9409519,,,,,,,,,,,,,, -61900,0.6856183,0.91964936,,,,,,,,,,,,,, -62000,0.59671843,0.961733,,,,,,,,,,,,,, -62100,0.7291799,0.9267054,,,,,,,,,,,,,, -62200,0.61302775,0.9732567,,,,,,,,,,,,,, -62300,0.9550302,0.96895635,,,,,,,,,,,,,, -62400,0.5720148,0.9667026,,,,,,,,,,,,,, -62500,0.7246042,1.0148109,,,,,,,,,,,,,, -62600,0.74472845,1.0152769,,,,,,,,,,,,,, -62700,0.6076742,0.9698902,,,,,,,,,,,,,, -62800,0.60638195,0.9798027,,,,,,,,,,,,,, -62900,0.87588775,0.90380746,,,,,,,,,,,,,, -63000,0.89979523,1.0227453,,,,,,,,,,,,,, -63100,0.9170927,0.9712725,,,,,,,,,,,,,, -63181,,,0.10824333,0.0414118644824353,0.32207102,0.0907151201521573,5348.0,0.17149302,0.0556537281904413,2472.0,53352.4820356369,58470.85765528679,53352.4820356369,5113.37780046463,2.138718366622925,0.0 -63200,0.63751644,0.9421352,,,,,,,,,,,,,, -63300,0.64949924,0.97436374,,,,,,,,,,,,,, -63400,0.67351216,0.96635556,,,,,,,,,,,,,, -63500,0.68833023,0.9521408,,,,,,,,,,,,,, -63600,0.6742794,0.95417804,,,,,,,,,,,,,, -63700,0.82453495,0.99271923,,,,,,,,,,,,,, -63800,0.7141834,0.99369884,,,,,,,,,,,,,, -63900,0.6080753,0.91877365,,,,,,,,,,,,,, -64000,0.64839906,0.9525773,,,,,,,,,,,,,, -64100,0.91033113,0.9914424,,,,,,,,,,,,,, -64200,0.5728976,0.928948,,,,,,,,,,,,,, -64300,0.5667069,0.9573948,,,,,,,,,,,,,, -64400,0.6431784,0.96369976,,,,,,,,,,,,,, -64500,0.641397,0.9463582,,,,,,,,,,,,,, -64600,0.63442296,0.93215746,,,,,,,,,,,,,, -64700,0.71351004,0.93193483,,,,,,,,,,,,,, -64800,0.7351352,0.9636696,,,,,,,,,,,,,, -64865,,,0.10510961,0.0403559614432038,0.31513014,0.0883111115402068,5348.0,0.16992109,0.0562021408404931,2472.0,54793.28907322884,60041.30572772026,54793.28907322884,5242.881098985672,2.200098752975464,0.0 -64900,0.710928,0.9335663,,,,,,,,,,,,,, -65000,0.6153405,0.9517366,,,,,,,,,,,,,, -65100,0.6339288,0.9632576,,,,,,,,,,,,,, -65200,0.6910448,0.9422166,,,,,,,,,,,,,, -65300,0.5975505,0.95549643,,,,,,,,,,,,,, -65400,0.60622716,0.9650783,,,,,,,,,,,,,, -65500,0.68479216,0.9039671,,,,,,,,,,,,,, -65600,0.5490499,0.9431274,,,,,,,,,,,,,, -65700,0.58534384,0.91331244,,,,,,,,,,,,,, -65800,0.68240064,0.9147405,,,,,,,,,,,,,, -65900,0.7985747,1.0042775,,,,,,,,,,,,,, -66000,0.5760781,0.9269548,,,,,,,,,,,,,, -66100,0.677327,0.9358556,,,,,,,,,,,,,, -66200,0.6589244,0.9497443,,,,,,,,,,,,,, -66300,0.59286374,0.93705183,,,,,,,,,,,,,, -66400,0.9328472,0.9598103,,,,,,,,,,,,,, -66500,0.94627947,0.9436925,,,,,,,,,,,,,, -66552,,,0.099286236,0.0379463786587956,0.31221068,0.0875387392954034,5348.0,0.16586737,0.0532163386346556,2472.0,56233.88724684715,61615.08142876625,56233.88724684715,5375.917708158493,2.264157295227051,0.0 -66600,0.6466891,0.913873,,,,,,,,,,,,,, -66700,0.6045256,0.9396479,,,,,,,,,,,,,, -66800,0.6408908,0.9213682,,,,,,,,,,,,,, -66900,0.7448735,0.9551482,,,,,,,,,,,,,, -67000,0.7192597,0.911898,,,,,,,,,,,,,, -67100,0.65638524,0.918137,,,,,,,,,,,,,, -67200,0.68291616,0.92448276,,,,,,,,,,,,,, -67300,0.6730958,0.9249699,,,,,,,,,,,,,, -67400,0.75619483,0.8840512,,,,,,,,,,,,,, -67500,1.1651233,0.8882692,,,,,,,,,,,,,, -67600,0.66254556,0.91148406,,,,,,,,,,,,,, -67700,0.7661268,0.96225077,,,,,,,,,,,,,, -67800,0.66029406,0.9010057,,,,,,,,,,,,,, -67900,0.87099713,0.9735145,,,,,,,,,,,,,, -68000,0.751072,0.855928,,,,,,,,,,,,,, -68100,0.68037057,0.9135611,,,,,,,,,,,,,, -68200,0.6520939,0.92462564,,,,,,,,,,,,,, -68272,,,0.09623108,0.0364278361356164,0.3052696,0.0855402261119746,5348.0,0.16405582,0.05274917230313,2472.0,57674.50718331337,63187.40218567848,57674.50718331337,5507.484800100327,2.319945812225342,0.0 -68272,,,,,,,,,,,57674.50718331337,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/eval_measurements.csv deleted file mode 100644 index 87f965f53..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,27 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/ctc_loss,test/num_examples,test/wer,total_duration,train/ctc_loss,train/wer,validation/ctc_loss,validation/num_examples,validation/wer -220.257333278656,0.0,41.24587440490723,1,0,41.24587440490723,30.225792,2472,2.567668027542502,261.5032813549042,30.773157,2.637772525664785,30.24853,5348,2.327167228245653 -349.8941643238068,0.0495276451110839,1481.7202219963074,1770,0,1481.7202219963074,1.8507755,2472,0.4475047224422643,1831.736894369125,1.8952523,0.4500885666618607,2.3708937,5348,0.5239580215684949 -487.8131248950958,0.0849542617797851,2921.9824163913727,3526,0,2921.9824163913727,0.58842236,2472,0.1846728820100339,3410.026333808899,0.55487686,0.1804650667234903,0.9077002,5348,0.2566689516012242 -629.6420524120331,0.1239755153656005,4362.2189383506775,5254,0,4362.2189383506775,0.4845866,2472,0.1564804094814453,4992.204474925995,0.41640094,0.1434603159542672,0.78842723,5348,0.2248665244214449 -764.7253262996674,0.1659708023071289,5802.560395002365,6987,0,5802.560395002365,0.47377455,2472,0.1527633904088721,6567.744369745255,0.4246011,0.1421106095990114,0.77591664,5348,0.2211881016055688 -901.5812139511108,0.2061090469360351,7242.571328878403,8720,0,7242.571328878403,0.4933475,2472,0.1575163000426543,8144.72531414032,0.45573476,0.1526234534345169,0.811886,5348,0.2290566438495032 -1035.2008595466614,0.2446341514587402,8683.132089853287,10454,0,8683.132089853287,0.425938,2472,0.1363110109073182,9719.01669192314,0.40294805,0.1328667225034809,0.729496,5348,0.2057310020564411 -1169.5552642345428,0.2813897132873535,10123.701634168625,12171,0,10123.701634168625,0.4166353,2472,0.1336908171348485,11294.05009317398,0.3721852,0.1251940952502499,0.69942635,5348,0.2015795012406229 -1303.403836965561,0.3187751770019531,11563.997570991516,13914,0,11563.997570991516,0.38221413,2472,0.1240834399691264,12868.305127859116,0.30275872,0.1062118099794586,0.65481377,5348,0.1886422661401662 -1436.3366811275482,0.3561115264892578,13004.067729234695,15643,0,13004.067729234695,0.36668622,2472,0.1192289724371864,14441.418841600418,0.2885871,0.0994844894636881,0.6333401,5348,0.1834866814061036 -1571.290831565857,0.3938989639282226,14444.379429101944,17364,0,14444.379429101944,0.35901284,2472,0.1148213596571405,16016.7948179245,0.2926443,0.1049998122992272,0.61829656,5348,0.1771146103864757 -1708.575800895691,0.4329390525817871,15884.863197088242,19102,0,15884.863197088242,0.34121013,2472,0.1102715658196737,17594.676003694534,0.2893605,0.0972368386260874,0.5994233,5348,0.171708004672852 -1842.3055789470675,0.4730331897735595,17325.22800397873,20836,0,17325.22800397873,0.3285923,2472,0.1064123656896796,19168.884200811382,0.28343403,0.0970691429023663,0.5798606,5348,0.1674020294080732 -1977.9283108711245,0.5135478973388672,18766.104904174805,22553,0,18766.104904174805,0.3166594,2472,0.1046655698413665,20745.496968269348,0.26502982,0.0911101965222759,0.56222516,5348,0.1615320003475675 -2113.864481449127,0.5542182922363281,20206.208694934845,24275,0,20206.208694934845,0.30722615,2472,0.0994861170353218,22321.65062570572,0.23827851,0.0825925847726025,0.5401367,5348,0.1562122865114842 -2250.577463388443,0.6050629615783691,21646.24563932419,25997,0,21646.24563932419,0.29538578,2472,0.0940019905348038,23898.528543949127,0.21937364,0.0774772424945877,0.51852757,5348,0.1502843295326182 -2388.410187959671,0.6562862396240234,23086.39237809181,27696,0,23086.39237809181,0.2851266,2472,0.0920927020494383,25476.63623690605,0.21299855,0.0735494164936433,0.5131484,5348,0.1481023779410487 -2522.7200396060944,0.7084314823150635,24526.297978639603,29414,0,24526.297978639603,0.27973896,2472,0.0905083988381776,27050.98180246353,0.22567616,0.078078956491593,0.49853975,5348,0.1427537001457852 -2656.8231077194214,0.765221357345581,25966.211684703827,31122,0,25966.211684703827,0.26477703,2472,0.086222655535921,28625.13384318352,0.2058136,0.0697941931318129,0.49061537,5348,0.1409096614113172 -2794.605189561844,0.818885087966919,27406.36852216721,32821,0,27406.36852216721,0.2588315,2472,0.0836430849227144,30203.20334362984,0.20854647,0.0714135275775804,0.4694706,5348,0.1349141218610309 -2929.3485980033875,0.8737070560455322,28846.32292485237,34533,0,28846.32292485237,0.24803255,2472,0.0806369711372453,31778.03262424469,0.19717774,0.0654965525325831,0.456449,5348,0.131979107330778 -3064.120968580246,0.928691864013672,30286.631365060806,36242,0,30286.631365060806,0.24380776,2472,0.0775902341925131,33353.2452647686,0.14662404,0.0513668091387103,0.44940788,5348,0.1292468405147861 -3200.515196084976,0.9846453666687012,31726.58775639534,37937,0,31726.58775639534,0.23735188,2472,0.0746247435663071,34929.726868867874,0.16366398,0.056223733395282,0.43865654,5348,0.1260994236172123 -3332.5077028274536,1.0392842292785645,33166.52247095108,39660,0,33166.52247095108,0.23088776,2472,0.0733857372087827,36501.78617596626,0.20840412,0.0723192831533442,0.42887554,5348,0.1220927425972947 -3469.1516411304474,1.09200119972229,34606.51906085014,41349,0,34606.51906085014,0.22452182,2472,0.0723904698068368,38078.55527210236,0.2137908,0.073177637020712,0.41952863,5348,0.1200556108016258 -3600.998679161072,1.14786696434021,36047.307216882706,43028,0,36047.307216882706,0.22208491,2472,0.07098897081226006,39651.32257723808,0.2462839,0.08584916145562178,0.4148276,5348,0.11844328374059879 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/measurements.csv deleted file mode 100644 index 96a02f853..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/measurements.csv +++ /dev/null @@ -1,459 +0,0 @@ -global_step,grad_norm,loss,train/ctc_loss,train/wer,validation/ctc_loss,validation/wer,validation/num_examples,test/ctc_loss,test/wer,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,30.846376,33.0205,,,,,,,,,,,,,, -1,,,30.773157,2.637772525664785,30.24853,2.327167228245653,5348.0,30.225792,2.567668027542502,2472.0,41.24587440490723,261.5032813549042,41.24587440490723,220.257333278656,0.0,0.0 -100,0.63511366,5.9927306,,,,,,,,,,,,,, -200,0.48687255,5.794072,,,,,,,,,,,,,, -300,1.514812,5.7173195,,,,,,,,,,,,,, -400,0.54315567,5.4390054,,,,,,,,,,,,,, -500,0.72679937,4.7900643,,,,,,,,,,,,,, -600,1.3676107,3.955211,,,,,,,,,,,,,, -700,1.3750125,3.47933,,,,,,,,,,,,,, -800,2.9061272,3.2451403,,,,,,,,,,,,,, -900,3.492444,3.0355365,,,,,,,,,,,,,, -1000,2.0383568,2.8354504,,,,,,,,,,,,,, -1100,2.3085556,2.6829338,,,,,,,,,,,,,, -1200,3.1721456,2.587499,,,,,,,,,,,,,, -1300,2.214174,2.5402508,,,,,,,,,,,,,, -1400,3.1622832,2.4755979,,,,,,,,,,,,,, -1500,3.1822813,2.3567724,,,,,,,,,,,,,, -1600,1.6441834,2.3567932,,,,,,,,,,,,,, -1700,2.30615,2.246992,,,,,,,,,,,,,, -1770,,,1.8952523,0.4500885666618607,2.3708937,0.5239580215684949,5348.0,1.8507755,0.4475047224422643,2472.0,1481.7202219963074,1831.736894369125,1481.7202219963074,349.8941643238068,0.0495276451110839,0.0 -1800,2.4889183,2.2305784,,,,,,,,,,,,,, -1900,2.0721304,2.161887,,,,,,,,,,,,,, -2000,3.7171626,2.1192951,,,,,,,,,,,,,, -2100,5.4014273,2.1079612,,,,,,,,,,,,,, -2200,2.1792378,2.0656917,,,,,,,,,,,,,, -2300,1.925539,2.0704257,,,,,,,,,,,,,, -2400,1.9517155,2.0407252,,,,,,,,,,,,,, -2500,1.8856415,2.0221224,,,,,,,,,,,,,, -2600,2.780261,1.9484302,,,,,,,,,,,,,, -2700,2.6486588,2.004857,,,,,,,,,,,,,, -2800,3.4475718,1.8972317,,,,,,,,,,,,,, -2900,2.2183695,2.0029147,,,,,,,,,,,,,, -3000,4.7985916,1.9886748,,,,,,,,,,,,,, -3100,4.47694,1.9781821,,,,,,,,,,,,,, -3200,1.9546307,1.9301233,,,,,,,,,,,,,, -3300,1.6744612,1.8303641,,,,,,,,,,,,,, -3400,1.9576877,1.8874563,,,,,,,,,,,,,, -3500,1.7535682,1.9205614,,,,,,,,,,,,,, -3526,,,0.55487686,0.1804650667234903,0.9077002,0.2566689516012242,5348.0,0.58842236,0.1846728820100339,2472.0,2921.9824163913727,3410.026333808899,2921.9824163913727,487.8131248950958,0.0849542617797851,0.0 -3600,1.8452857,1.8958846,,,,,,,,,,,,,, -3700,2.5284631,1.8144672,,,,,,,,,,,,,, -3800,2.776254,1.8677552,,,,,,,,,,,,,, -3900,2.0541277,1.880145,,,,,,,,,,,,,, -4000,2.5963223,1.8298094,,,,,,,,,,,,,, -4100,1.8554709,1.7802861,,,,,,,,,,,,,, -4200,1.9937727,1.886166,,,,,,,,,,,,,, -4300,3.2618604,1.8840694,,,,,,,,,,,,,, -4400,2.0966358,1.78394,,,,,,,,,,,,,, -4500,1.9355851,1.6848243,,,,,,,,,,,,,, -4600,2.746588,1.8182708,,,,,,,,,,,,,, -4700,2.4259515,1.7869182,,,,,,,,,,,,,, -4800,2.0997207,1.701328,,,,,,,,,,,,,, -4900,6.7705617,1.7644401,,,,,,,,,,,,,, -5000,1.8675677,1.7181396,,,,,,,,,,,,,, -5100,2.2276351,1.7594156,,,,,,,,,,,,,, -5200,1.8608384,1.7891608,,,,,,,,,,,,,, -5254,,,0.41640094,0.1434603159542672,0.78842723,0.2248665244214449,5348.0,0.4845866,0.1564804094814453,2472.0,4362.2189383506775,4992.204474925995,4362.2189383506775,629.6420524120331,0.1239755153656005,0.0 -5300,2.3136342,1.7289369,,,,,,,,,,,,,, -5400,3.3289037,1.8208983,,,,,,,,,,,,,, -5500,2.1912785,1.7397534,,,,,,,,,,,,,, -5600,2.5430946,1.7481079,,,,,,,,,,,,,, -5700,2.1558924,1.6658057,,,,,,,,,,,,,, -5800,2.0040283,1.7112074,,,,,,,,,,,,,, -5900,2.1017222,1.6875684,,,,,,,,,,,,,, -6000,1.9570884,1.7288947,,,,,,,,,,,,,, -6100,2.5146487,1.7184551,,,,,,,,,,,,,, -6200,2.2307582,1.7242728,,,,,,,,,,,,,, -6300,2.4878795,1.7102686,,,,,,,,,,,,,, -6400,3.1487606,1.7080953,,,,,,,,,,,,,, -6500,3.2339838,1.7132294,,,,,,,,,,,,,, -6600,2.481836,1.7046101,,,,,,,,,,,,,, -6700,1.9804618,1.7222918,,,,,,,,,,,,,, -6800,2.805592,1.6975675,,,,,,,,,,,,,, -6900,3.0285456,1.6578417,,,,,,,,,,,,,, -6987,,,0.4246011,0.1421106095990114,0.77591664,0.2211881016055688,5348.0,0.47377455,0.1527633904088721,2472.0,5802.560395002365,6567.744369745255,5802.560395002365,764.7253262996674,0.1659708023071289,0.0 -7000,2.0805132,1.7204026,,,,,,,,,,,,,, -7100,2.396328,1.7054178,,,,,,,,,,,,,, -7200,1.3771737,1.6602857,,,,,,,,,,,,,, -7300,2.3589394,1.6266665,,,,,,,,,,,,,, -7400,2.5753047,1.7072023,,,,,,,,,,,,,, -7500,1.9082991,1.6647524,,,,,,,,,,,,,, -7600,2.257029,1.7268422,,,,,,,,,,,,,, -7700,2.7981706,1.6284708,,,,,,,,,,,,,, -7800,2.9257464,1.6631601,,,,,,,,,,,,,, -7900,4.318575,1.6397971,,,,,,,,,,,,,, -8000,2.4947398,1.6891835,,,,,,,,,,,,,, -8100,3.7184424,1.6544886,,,,,,,,,,,,,, -8200,2.2877033,1.5671577,,,,,,,,,,,,,, -8300,3.7414205,3.251062,,,,,,,,,,,,,, -8400,3.655791,2.0550556,,,,,,,,,,,,,, -8500,2.4598854,1.9038639,,,,,,,,,,,,,, -8600,2.9199276,1.7849286,,,,,,,,,,,,,, -8700,1.5560699,1.6921062,,,,,,,,,,,,,, -8720,,,0.45573476,0.1526234534345169,0.811886,0.2290566438495032,5348.0,0.4933475,0.1575163000426543,2472.0,7242.571328878403,8144.72531414032,7242.571328878403,901.5812139511108,0.2061090469360351,0.0 -8800,2.6417432,1.7406839,,,,,,,,,,,,,, -8900,2.3914795,1.7413065,,,,,,,,,,,,,, -9000,2.3546379,1.7614586,,,,,,,,,,,,,, -9100,2.9334085,1.6794913,,,,,,,,,,,,,, -9200,2.8923519,1.6614228,,,,,,,,,,,,,, -9300,2.458413,1.5964203,,,,,,,,,,,,,, -9400,1.9355139,1.6366148,,,,,,,,,,,,,, -9500,1.891334,1.6690696,,,,,,,,,,,,,, -9600,2.8730369,1.637006,,,,,,,,,,,,,, -9700,2.6761236,1.7356766,,,,,,,,,,,,,, -9800,2.300213,1.6132135,,,,,,,,,,,,,, -9900,2.0556982,1.6398063,,,,,,,,,,,,,, -10000,3.2643216,1.6000372,,,,,,,,,,,,,, -10100,3.1701474,1.5959597,,,,,,,,,,,,,, -10200,2.6426933,1.606868,,,,,,,,,,,,,, -10300,3.3525674,1.6288358,,,,,,,,,,,,,, -10400,3.025054,1.656093,,,,,,,,,,,,,, -10454,,,0.40294805,0.1328667225034809,0.729496,0.2057310020564411,5348.0,0.425938,0.1363110109073182,2472.0,8683.132089853287,9719.01669192314,8683.132089853287,1035.2008595466614,0.2446341514587402,0.0 -10500,3.2401655,1.676394,,,,,,,,,,,,,, -10600,6.921671,1.6541764,,,,,,,,,,,,,, -10700,4.5430512,1.640603,,,,,,,,,,,,,, -10800,3.1754715,1.6489172,,,,,,,,,,,,,, -10900,2.7870429,1.6949737,,,,,,,,,,,,,, -11000,2.461674,1.6214819,,,,,,,,,,,,,, -11100,1.9166721,1.6144352,,,,,,,,,,,,,, -11200,2.0240664,1.6266319,,,,,,,,,,,,,, -11300,2.9844623,1.6457922,,,,,,,,,,,,,, -11400,2.872681,1.7141274,,,,,,,,,,,,,, -11500,2.381941,1.597629,,,,,,,,,,,,,, -11600,2.060926,1.6544766,,,,,,,,,,,,,, -11700,3.7606459,1.6199493,,,,,,,,,,,,,, -11800,2.296371,1.6715273,,,,,,,,,,,,,, -11900,2.467893,1.6364093,,,,,,,,,,,,,, -12000,3.485238,1.5439175,,,,,,,,,,,,,, -12100,3.9556904,1.5834367,,,,,,,,,,,,,, -12171,,,0.3721852,0.1251940952502499,0.69942635,0.2015795012406229,5348.0,0.4166353,0.1336908171348485,2472.0,10123.701634168625,11294.05009317398,10123.701634168625,1169.5552642345428,0.2813897132873535,0.0 -12200,1.8603907,1.5891242,,,,,,,,,,,,,, -12300,2.9037113,1.6718037,,,,,,,,,,,,,, -12400,3.279288,1.670473,,,,,,,,,,,,,, -12500,3.117687,1.6168616,,,,,,,,,,,,,, -12600,4.5014324,1.5917206,,,,,,,,,,,,,, -12700,3.5489624,1.5994265,,,,,,,,,,,,,, -12800,3.2989478,1.541114,,,,,,,,,,,,,, -12900,1.9741532,1.4920218,,,,,,,,,,,,,, -13000,2.3091815,1.5569296,,,,,,,,,,,,,, -13100,2.5551898,1.601589,,,,,,,,,,,,,, -13200,1.7572865,1.5666379,,,,,,,,,,,,,, -13300,1.8133845,1.5542444,,,,,,,,,,,,,, -13400,3.2078867,1.5627608,,,,,,,,,,,,,, -13500,2.5059493,1.5468761,,,,,,,,,,,,,, -13600,2.7628484,1.475497,,,,,,,,,,,,,, -13700,3.369387,1.6483074,,,,,,,,,,,,,, -13800,1.8558685,1.5603087,,,,,,,,,,,,,, -13900,3.6787808,1.574289,,,,,,,,,,,,,, -13914,,,0.30275872,0.1062118099794586,0.65481377,0.1886422661401662,5348.0,0.38221413,0.1240834399691264,2472.0,11563.997570991516,12868.305127859116,11563.997570991516,1303.403836965561,0.3187751770019531,0.0 -14000,1.6258128,1.5816005,,,,,,,,,,,,,, -14100,2.5301945,1.5784832,,,,,,,,,,,,,, -14200,2.397192,1.5434433,,,,,,,,,,,,,, -14300,2.379141,1.6220461,,,,,,,,,,,,,, -14400,1.9566988,1.5095071,,,,,,,,,,,,,, -14500,1.6648068,1.5469855,,,,,,,,,,,,,, -14600,2.7861006,1.5758156,,,,,,,,,,,,,, -14700,1.7018187,1.5797522,,,,,,,,,,,,,, -14800,3.6776984,1.5525256,,,,,,,,,,,,,, -14900,3.2695088,1.586692,,,,,,,,,,,,,, -15000,2.8611012,1.549638,,,,,,,,,,,,,, -15100,1.950288,1.5469784,,,,,,,,,,,,,, -15200,3.125223,1.5193642,,,,,,,,,,,,,, -15300,3.3762693,1.5656792,,,,,,,,,,,,,, -15400,1.8858591,1.5130683,,,,,,,,,,,,,, -15500,2.0542967,1.5558368,,,,,,,,,,,,,, -15600,3.4476752,1.5557898,,,,,,,,,,,,,, -15643,,,0.2885871,0.0994844894636881,0.6333401,0.1834866814061036,5348.0,0.36668622,0.1192289724371864,2472.0,13004.067729234695,14441.418841600418,13004.067729234695,1436.3366811275482,0.3561115264892578,0.0 -15700,2.130896,1.5667377,,,,,,,,,,,,,, -15800,3.2429883,1.5434631,,,,,,,,,,,,,, -15900,2.7736099,1.5018574,,,,,,,,,,,,,, -16000,2.437337,1.5288197,,,,,,,,,,,,,, -16100,2.8360922,1.5153581,,,,,,,,,,,,,, -16200,2.9484863,1.531461,,,,,,,,,,,,,, -16300,2.9426167,1.5302808,,,,,,,,,,,,,, -16400,1.7562532,1.5318539,,,,,,,,,,,,,, -16500,1.8783699,1.5424169,,,,,,,,,,,,,, -16600,3.7245626,1.5683407,,,,,,,,,,,,,, -16700,2.857867,1.5446609,,,,,,,,,,,,,, -16800,3.8355873,1.5318,,,,,,,,,,,,,, -16900,2.1104727,1.5258306,,,,,,,,,,,,,, -17000,2.352047,1.5328548,,,,,,,,,,,,,, -17100,2.2431533,1.5550404,,,,,,,,,,,,,, -17200,2.9625642,1.5471476,,,,,,,,,,,,,, -17300,2.283588,1.4582425,,,,,,,,,,,,,, -17364,,,0.2926443,0.1049998122992272,0.61829656,0.1771146103864757,5348.0,0.35901284,0.1148213596571405,2472.0,14444.379429101944,16016.7948179245,14444.379429101944,1571.290831565857,0.3938989639282226,0.0 -17400,2.701501,1.484765,,,,,,,,,,,,,, -17500,2.5287447,1.5464389,,,,,,,,,,,,,, -17600,2.030786,1.5090584,,,,,,,,,,,,,, -17700,1.8359683,1.5453897,,,,,,,,,,,,,, -17800,4.9920144,1.5566729,,,,,,,,,,,,,, -17900,3.670006,1.5090506,,,,,,,,,,,,,, -18000,1.8886046,1.4897859,,,,,,,,,,,,,, -18100,2.386649,1.4930778,,,,,,,,,,,,,, -18200,2.3446188,1.4406698,,,,,,,,,,,,,, -18300,2.4753883,1.4591588,,,,,,,,,,,,,, -18400,1.9641191,1.4707156,,,,,,,,,,,,,, -18500,5.1079717,1.503808,,,,,,,,,,,,,, -18600,3.1035666,1.535516,,,,,,,,,,,,,, -18700,2.2446625,1.4621444,,,,,,,,,,,,,, -18800,3.372765,1.4734954,,,,,,,,,,,,,, -18900,2.0607352,1.5082251,,,,,,,,,,,,,, -19000,2.0286798,1.4698439,,,,,,,,,,,,,, -19100,2.6555421,1.467761,,,,,,,,,,,,,, -19102,,,0.2893605,0.0972368386260874,0.5994233,0.171708004672852,5348.0,0.34121013,0.1102715658196737,2472.0,15884.863197088242,17594.676003694534,15884.863197088242,1708.575800895691,0.4329390525817871,0.0 -19200,3.1457076,1.5278255,,,,,,,,,,,,,, -19300,2.822008,1.498618,,,,,,,,,,,,,, -19400,4.0441313,1.4356278,,,,,,,,,,,,,, -19500,2.782552,1.4470862,,,,,,,,,,,,,, -19600,2.402723,1.5302742,,,,,,,,,,,,,, -19700,1.6736296,1.392549,,,,,,,,,,,,,, -19800,2.4878316,1.4235331,,,,,,,,,,,,,, -19900,2.8967626,1.4809484,,,,,,,,,,,,,, -20000,1.5104438,1.4002261,,,,,,,,,,,,,, -20100,2.9942892,1.4620992,,,,,,,,,,,,,, -20200,4.606401,1.4046366,,,,,,,,,,,,,, -20300,2.4267938,1.4290735,,,,,,,,,,,,,, -20400,2.6896513,1.4242991,,,,,,,,,,,,,, -20500,2.4367218,1.5145521,,,,,,,,,,,,,, -20600,4.3663464,1.4223511,,,,,,,,,,,,,, -20700,2.3613758,1.3779103,,,,,,,,,,,,,, -20800,3.5404646,1.4477613,,,,,,,,,,,,,, -20836,,,0.28343403,0.0970691429023663,0.5798606,0.1674020294080732,5348.0,0.3285923,0.1064123656896796,2472.0,17325.22800397873,19168.884200811382,17325.22800397873,1842.3055789470675,0.4730331897735595,0.0 -20900,3.3482053,1.4503809,,,,,,,,,,,,,, -21000,2.8891037,1.448562,,,,,,,,,,,,,, -21100,3.4771912,1.4486483,,,,,,,,,,,,,, -21200,3.2110095,1.4768537,,,,,,,,,,,,,, -21300,2.268285,1.5253981,,,,,,,,,,,,,, -21400,3.068564,1.4784575,,,,,,,,,,,,,, -21500,1.986092,1.4579196,,,,,,,,,,,,,, -21600,2.3506746,1.4605726,,,,,,,,,,,,,, -21700,2.8795266,1.4094508,,,,,,,,,,,,,, -21800,1.6678886,1.4008137,,,,,,,,,,,,,, -21900,2.2319477,1.3666092,,,,,,,,,,,,,, -22000,2.2673783,1.4405042,,,,,,,,,,,,,, -22100,3.19903,1.3522534,,,,,,,,,,,,,, -22200,3.075304,1.4738164,,,,,,,,,,,,,, -22300,2.180392,1.4254869,,,,,,,,,,,,,, -22400,2.6922247,1.4128722,,,,,,,,,,,,,, -22500,1.8545787,1.3861678,,,,,,,,,,,,,, -22553,,,0.26502982,0.0911101965222759,0.56222516,0.1615320003475675,5348.0,0.3166594,0.1046655698413665,2472.0,18766.104904174805,20745.496968269348,18766.104904174805,1977.9283108711245,0.5135478973388672,0.0 -22600,2.5355484,1.4179275,,,,,,,,,,,,,, -22700,2.4020622,1.3871269,,,,,,,,,,,,,, -22800,2.7561276,1.3715363,,,,,,,,,,,,,, -22900,3.115075,1.425059,,,,,,,,,,,,,, -23000,3.514058,1.4428891,,,,,,,,,,,,,, -23100,2.7416918,1.538799,,,,,,,,,,,,,, -23200,2.2820935,1.3758532,,,,,,,,,,,,,, -23300,2.2451365,1.3773462,,,,,,,,,,,,,, -23400,4.3806024,1.4616373,,,,,,,,,,,,,, -23500,1.8314749,1.3854426,,,,,,,,,,,,,, -23600,2.672852,1.4524131,,,,,,,,,,,,,, -23700,2.3254278,1.4111583,,,,,,,,,,,,,, -23800,2.4873538,1.3708063,,,,,,,,,,,,,, -23900,1.8778706,1.3399428,,,,,,,,,,,,,, -24000,2.663434,1.4014541,,,,,,,,,,,,,, -24100,2.1097693,1.4628112,,,,,,,,,,,,,, -24200,3.5306292,1.4064528,,,,,,,,,,,,,, -24275,,,0.23827851,0.0825925847726025,0.5401367,0.1562122865114842,5348.0,0.30722615,0.0994861170353218,2472.0,20206.208694934845,22321.65062570572,20206.208694934845,2113.864481449127,0.5542182922363281,0.0 -24300,1.7893504,1.4024955,,,,,,,,,,,,,, -24400,2.4156191,1.3942727,,,,,,,,,,,,,, -24500,3.0987315,1.4034615,,,,,,,,,,,,,, -24600,3.6909645,1.3705165,,,,,,,,,,,,,, -24700,2.8496974,1.4406937,,,,,,,,,,,,,, -24800,2.3104713,1.3351318,,,,,,,,,,,,,, -24900,2.9465213,1.4124293,,,,,,,,,,,,,, -25000,3.506801,1.4011916,,,,,,,,,,,,,, -25100,2.9665356,1.3426195,,,,,,,,,,,,,, -25200,4.058625,1.3523301,,,,,,,,,,,,,, -25300,4.6480837,1.4238448,,,,,,,,,,,,,, -25400,3.9728084,1.3954679,,,,,,,,,,,,,, -25500,2.3683,1.3860736,,,,,,,,,,,,,, -25600,1.8733412,1.3466853,,,,,,,,,,,,,, -25700,3.6575367,1.3511487,,,,,,,,,,,,,, -25800,2.31673,1.3902345,,,,,,,,,,,,,, -25900,2.9842038,1.4665468,,,,,,,,,,,,,, -25997,,,0.21937364,0.0774772424945877,0.51852757,0.1502843295326182,5348.0,0.29538578,0.0940019905348038,2472.0,21646.24563932419,23898.528543949127,21646.24563932419,2250.577463388443,0.6050629615783691,0.0 -26000,2.2729917,1.3277783,,,,,,,,,,,,,, -26100,2.0532649,1.4445034,,,,,,,,,,,,,, -26200,2.3499382,1.4428253,,,,,,,,,,,,,, -26300,1.8279425,1.4167908,,,,,,,,,,,,,, -26400,1.9923425,1.3898739,,,,,,,,,,,,,, -26500,1.6344122,1.3766099,,,,,,,,,,,,,, -26600,2.690212,1.3356262,,,,,,,,,,,,,, -26700,2.427219,1.3766899,,,,,,,,,,,,,, -26800,1.6330036,1.2841073,,,,,,,,,,,,,, -26900,2.7478037,1.3470061,,,,,,,,,,,,,, -27000,3.2459893,1.3802421,,,,,,,,,,,,,, -27100,2.3754406,1.3846884,,,,,,,,,,,,,, -27200,2.9164493,1.3551964,,,,,,,,,,,,,, -27300,3.7676978,1.3147564,,,,,,,,,,,,,, -27400,3.3422298,1.326477,,,,,,,,,,,,,, -27500,2.3230085,1.3352921,,,,,,,,,,,,,, -27600,2.917479,1.342842,,,,,,,,,,,,,, -27696,,,0.21299855,0.0735494164936433,0.5131484,0.1481023779410487,5348.0,0.2851266,0.0920927020494383,2472.0,23086.39237809181,25476.63623690605,23086.39237809181,2388.410187959671,0.6562862396240234,0.0 -27700,2.1212087,1.393776,,,,,,,,,,,,,, -27800,2.0507867,1.3112724,,,,,,,,,,,,,, -27900,1.9056925,1.3110534,,,,,,,,,,,,,, -28000,2.85043,1.384392,,,,,,,,,,,,,, -28100,1.544329,1.2666454,,,,,,,,,,,,,, -28200,1.6545316,1.2630512,,,,,,,,,,,,,, -28300,2.27718,1.3249464,,,,,,,,,,,,,, -28400,3.3478665,1.3336146,,,,,,,,,,,,,, -28500,2.0551963,1.3308187,,,,,,,,,,,,,, -28600,2.0656757,1.409508,,,,,,,,,,,,,, -28700,2.0830474,1.3135906,,,,,,,,,,,,,, -28800,2.3099062,1.3118693,,,,,,,,,,,,,, -28900,1.9368248,1.2880373,,,,,,,,,,,,,, -29000,3.2693875,1.2776285,,,,,,,,,,,,,, -29100,1.9086553,1.3042936,,,,,,,,,,,,,, -29200,2.135753,1.2888837,,,,,,,,,,,,,, -29300,2.4209592,1.3492209,,,,,,,,,,,,,, -29400,2.4804802,1.3513165,,,,,,,,,,,,,, -29414,,,0.22567616,0.078078956491593,0.49853975,0.1427537001457852,5348.0,0.27973896,0.0905083988381776,2472.0,24526.297978639603,27050.98180246353,24526.297978639603,2522.7200396060944,0.7084314823150635,0.0 -29500,2.733001,1.2953368,,,,,,,,,,,,,, -29600,1.9120377,1.4047451,,,,,,,,,,,,,, -29700,2.6479688,1.252473,,,,,,,,,,,,,, -29800,2.2499828,1.3498682,,,,,,,,,,,,,, -29900,2.3059223,1.328178,,,,,,,,,,,,,, -30000,2.8870325,1.2533315,,,,,,,,,,,,,, -30100,1.7406027,1.2734725,,,,,,,,,,,,,, -30200,1.6820016,1.318387,,,,,,,,,,,,,, -30300,1.8441086,1.3379741,,,,,,,,,,,,,, -30400,2.572949,1.3041117,,,,,,,,,,,,,, -30500,2.2219784,1.2895817,,,,,,,,,,,,,, -30600,4.2819967,1.3162322,,,,,,,,,,,,,, -30700,2.9050682,1.3046273,,,,,,,,,,,,,, -30800,1.9945859,1.3148775,,,,,,,,,,,,,, -30900,3.3772225,1.3037008,,,,,,,,,,,,,, -31000,6.751734,1.3350562,,,,,,,,,,,,,, -31100,2.2682705,1.2991667,,,,,,,,,,,,,, -31122,,,0.2058136,0.0697941931318129,0.49061537,0.1409096614113172,5348.0,0.26477703,0.086222655535921,2472.0,25966.211684703827,28625.13384318352,25966.211684703827,2656.8231077194214,0.765221357345581,0.0 -31200,2.9451244,1.2475704,,,,,,,,,,,,,, -31300,1.9913788,1.304595,,,,,,,,,,,,,, -31400,2.295074,1.2899231,,,,,,,,,,,,,, -31500,2.8438406,1.3182449,,,,,,,,,,,,,, -31600,2.253802,1.2720038,,,,,,,,,,,,,, -31700,4.4511805,1.2820714,,,,,,,,,,,,,, -31800,1.6913416,1.2568384,,,,,,,,,,,,,, -31900,2.1518955,1.2796816,,,,,,,,,,,,,, -32000,3.9598918,1.317487,,,,,,,,,,,,,, -32100,3.788995,1.2922052,,,,,,,,,,,,,, -32200,6.157984,1.2845539,,,,,,,,,,,,,, -32300,1.9463148,1.2717887,,,,,,,,,,,,,, -32400,3.3483784,1.2636373,,,,,,,,,,,,,, -32500,2.112859,1.3156487,,,,,,,,,,,,,, -32600,3.1793175,1.2761879,,,,,,,,,,,,,, -32700,3.0775027,1.261823,,,,,,,,,,,,,, -32800,2.8800883,1.2834176,,,,,,,,,,,,,, -32821,,,0.20854647,0.0714135275775804,0.4694706,0.1349141218610309,5348.0,0.2588315,0.0836430849227144,2472.0,27406.36852216721,30203.20334362984,27406.36852216721,2794.605189561844,0.818885087966919,0.0 -32900,2.5199816,1.3109237,,,,,,,,,,,,,, -33000,1.7088821,1.2566944,,,,,,,,,,,,,, -33100,3.1292908,1.2863535,,,,,,,,,,,,,, -33200,3.5161915,1.2552521,,,,,,,,,,,,,, -33300,2.0683386,1.2514274,,,,,,,,,,,,,, -33400,1.566221,1.2551019,,,,,,,,,,,,,, -33500,2.704788,1.2550895,,,,,,,,,,,,,, -33600,3.192927,1.2174003,,,,,,,,,,,,,, -33700,3.1598792,1.2560583,,,,,,,,,,,,,, -33800,1.6955296,1.2852489,,,,,,,,,,,,,, -33900,2.4839156,1.285651,,,,,,,,,,,,,, -34000,1.650782,1.2206551,,,,,,,,,,,,,, -34100,3.1340542,1.238963,,,,,,,,,,,,,, -34200,1.7275184,1.2319429,,,,,,,,,,,,,, -34300,1.8838364,1.2507063,,,,,,,,,,,,,, -34400,4.698357,1.2891235,,,,,,,,,,,,,, -34500,3.60701,1.2600933,,,,,,,,,,,,,, -34533,,,0.19717774,0.0654965525325831,0.456449,0.131979107330778,5348.0,0.24803255,0.0806369711372453,2472.0,28846.32292485237,31778.03262424469,28846.32292485237,2929.3485980033875,0.8737070560455322,0.0 -34600,5.9946017,1.23954,,,,,,,,,,,,,, -34700,2.9440074,1.2359767,,,,,,,,,,,,,, -34800,2.3758128,1.2567363,,,,,,,,,,,,,, -34900,3.7470367,1.2370986,,,,,,,,,,,,,, -35000,3.0499716,1.1838977,,,,,,,,,,,,,, -35100,6.6648,1.1967946,,,,,,,,,,,,,, -35200,2.1252081,1.2393306,,,,,,,,,,,,,, -35300,2.1067436,1.2208773,,,,,,,,,,,,,, -35400,1.4988182,1.2626573,,,,,,,,,,,,,, -35500,2.47281,1.1989461,,,,,,,,,,,,,, -35600,2.3826647,1.2256112,,,,,,,,,,,,,, -35700,2.7230706,1.2234591,,,,,,,,,,,,,, -35800,1.7561859,1.260323,,,,,,,,,,,,,, -35900,3.2558002,1.2265303,,,,,,,,,,,,,, -36000,2.666852,1.2507436,,,,,,,,,,,,,, -36100,2.5252912,1.2016058,,,,,,,,,,,,,, -36200,6.0856586,1.2169106,,,,,,,,,,,,,, -36242,,,0.14662404,0.0513668091387103,0.44940788,0.1292468405147861,5348.0,0.24380776,0.0775902341925131,2472.0,30286.631365060806,33353.2452647686,30286.631365060806,3064.120968580246,0.928691864013672,0.0 -36300,2.278105,1.1935476,,,,,,,,,,,,,, -36400,2.0802686,1.1720123,,,,,,,,,,,,,, -36500,1.9629698,1.2022296,,,,,,,,,,,,,, -36600,2.6046844,1.1509328,,,,,,,,,,,,,, -36700,2.5563273,1.2228713,,,,,,,,,,,,,, -36800,1.4865841,1.2157781,,,,,,,,,,,,,, -36900,2.9540918,1.2294563,,,,,,,,,,,,,, -37000,2.371926,1.2473546,,,,,,,,,,,,,, -37100,2.8634388,1.1532986,,,,,,,,,,,,,, -37200,2.6111643,1.2251467,,,,,,,,,,,,,, -37300,3.0712006,1.2225857,,,,,,,,,,,,,, -37400,2.3788707,1.250157,,,,,,,,,,,,,, -37500,1.6669582,1.2100328,,,,,,,,,,,,,, -37600,4.297928,1.2522469,,,,,,,,,,,,,, -37700,1.2621078,1.1969422,,,,,,,,,,,,,, -37800,3.3657196,1.2047153,,,,,,,,,,,,,, -37900,1.5700978,1.2079954,,,,,,,,,,,,,, -37937,,,0.16366398,0.056223733395282,0.43865654,0.1260994236172123,5348.0,0.23735188,0.0746247435663071,2472.0,31726.58775639534,34929.726868867874,31726.58775639534,3200.515196084976,0.9846453666687012,0.0 -38000,2.4791267,1.175153,,,,,,,,,,,,,, -38100,5.3073535,1.1913346,,,,,,,,,,,,,, -38200,1.9899354,1.2033658,,,,,,,,,,,,,, -38300,1.9625336,1.1636543,,,,,,,,,,,,,, -38400,2.9421558,1.1853613,,,,,,,,,,,,,, -38500,2.9815333,1.1961805,,,,,,,,,,,,,, -38600,2.4715087,1.2042043,,,,,,,,,,,,,, -38700,3.6770782,1.1923003,,,,,,,,,,,,,, -38800,1.4840454,1.2146269,,,,,,,,,,,,,, -38900,2.899547,1.1558546,,,,,,,,,,,,,, -39000,1.5711732,1.1529834,,,,,,,,,,,,,, -39100,2.9177854,1.139301,,,,,,,,,,,,,, -39200,1.5250282,1.2134227,,,,,,,,,,,,,, -39300,3.09929,1.1475246,,,,,,,,,,,,,, -39400,2.152987,1.1578747,,,,,,,,,,,,,, -39500,2.7422032,1.2228491,,,,,,,,,,,,,, -39600,2.6587064,1.1586521,,,,,,,,,,,,,, -39660,,,0.20840412,0.0723192831533442,0.42887554,0.1220927425972947,5348.0,0.23088776,0.0733857372087827,2472.0,33166.52247095108,36501.78617596626,33166.52247095108,3332.5077028274536,1.0392842292785645,0.0 -39700,2.0410542,1.1799922,,,,,,,,,,,,,, -39800,2.6461086,1.2135524,,,,,,,,,,,,,, -39900,2.5756955,1.1824077,,,,,,,,,,,,,, -40000,2.3844635,1.2300355,,,,,,,,,,,,,, -40100,4.705019,1.1616179,,,,,,,,,,,,,, -40200,3.818409,1.0787518,,,,,,,,,,,,,, -40300,5.305146,1.1762437,,,,,,,,,,,,,, -40400,1.7044886,1.1460726,,,,,,,,,,,,,, -40500,3.860241,1.1869388,,,,,,,,,,,,,, -40600,2.166588,1.2142943,,,,,,,,,,,,,, -40700,2.4337585,1.1460525,,,,,,,,,,,,,, -40800,3.09589,1.1747929,,,,,,,,,,,,,, -40900,5.5944147,1.1829852,,,,,,,,,,,,,, -41000,1.7321823,1.1538001,,,,,,,,,,,,,, -41100,2.4295208,1.1042019,,,,,,,,,,,,,, -41200,6.7409525,1.199574,,,,,,,,,,,,,, -41300,2.4282954,1.1369091,,,,,,,,,,,,,, -41349,,,0.2137908,0.073177637020712,0.41952863,0.1200556108016258,5348.0,0.22452182,0.0723904698068368,2472.0,34606.51906085014,38078.55527210236,34606.51906085014,3469.1516411304474,1.09200119972229,0.0 -41400,2.748894,1.1229821,,,,,,,,,,,,,, -41500,1.8873342,1.138596,,,,,,,,,,,,,, -41600,1.5632173,1.0953902,,,,,,,,,,,,,, -41700,1.8420857,1.1778036,,,,,,,,,,,,,, -41800,3.3565054,1.1948748,,,,,,,,,,,,,, -41900,2.399497,1.2168326,,,,,,,,,,,,,, -42000,2.529732,1.1116803,,,,,,,,,,,,,, -42100,3.5210042,1.1484809,,,,,,,,,,,,,, -42200,2.626357,1.1770346,,,,,,,,,,,,,, -42300,2.4601436,1.1775416,,,,,,,,,,,,,, -42400,2.0881567,1.1931771,,,,,,,,,,,,,, -42500,4.0498695,1.169837,,,,,,,,,,,,,, -42600,2.1391368,1.1657256,,,,,,,,,,,,,, -42700,1.9219133,1.1524597,,,,,,,,,,,,,, -42800,3.03228,1.1224687,,,,,,,,,,,,,, -42900,1.9290888,1.1755933,,,,,,,,,,,,,, -43000,2.3728063,1.1268034,,,,,,,,,,,,,, -43028,,,0.2462839,0.0858491614556217,0.4148276,0.1184432837405987,5348.0,0.22208491,0.07098897081226,2472.0,36047.30721688271,39651.32257723808,36047.30721688271,3600.998679161072,1.14786696434021,0.0 -43028,,,,,,,,,,,36047.307216882706,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/eval_measurements.csv deleted file mode 100644 index d7206b8cd..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,232 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,total_duration,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples -518.4977009296417,0.0,18.616023302078247,1,0,18.616023302078247,0.4722304046154022,0.7914055585861206,0.0270653151253273,43793,537.1137669086456,0.4702461659908294,0.7943987846374512,0.0219344540100092,0.4722293615341186,0.7926825284957886,0.0249722052564542,43793 -633.2843382358551,0.0283777713775634,258.824720621109,775,0,258.824720621109,0.9832780957221984,0.0629883781075477,0.0577343690394387,43793,892.1560852527618,0.9868525862693788,0.0508523173630237,0.0579805508242579,0.984272599220276,0.0599129758775234,0.0558535770941507,43793 -749.6343984603882,0.056328535079956,499.05750465393066,1559,0,499.05750465393066,0.983700156211853,0.0584810450673103,0.1055086506954278,43793,1248.7858963012695,0.987420916557312,0.0460253953933715,0.1102378426434518,0.9847450852394104,0.0551337078213691,0.110825755459752,43793 -868.8937499523163,0.0814664363861084,739.084522485733,2336,0,739.084522485733,0.9839503169059752,0.0567783191800117,0.1309682146696898,43793,1608.116632938385,0.9876879453659058,0.0442273467779159,0.1398924353982198,0.985008955001831,0.0533492192625999,0.1355290644133614,43793 -988.9270553588868,0.1082706451416015,979.3258407115936,3116,0,979.3258407115936,0.9842274785041808,0.0534768253564834,0.1543288606950389,43793,1968.437069416046,0.9880595207214355,0.0418546013534069,0.1689057449872897,0.9852123260498048,0.0506405979394912,0.1537483003262923,43793 -1109.51793384552,0.1344447135925293,1219.583377122879,3895,0,1219.583377122879,0.9845648407936096,0.0519956611096859,0.1728490681036745,43793,2329.331063270569,0.9883913993835448,0.0399228259921073,0.2009358619510241,0.9854932427406312,0.0492856800556182,0.1780816860356048,43793 -1230.210641622543,0.1602764129638672,1459.7593297958374,4669,0,1459.7593297958374,0.984831929206848,0.050608541816473,0.1937942979027161,43793,2690.2451293468475,0.9887118339538574,0.0385812260210514,0.2256461587811895,0.9857218265533448,0.0479509532451629,0.1956819743386989,43793 -1352.9720075130465,0.1865897178649902,1699.8314542770386,5437,0,1699.8314542770386,0.9849759340286256,0.0504288785159587,0.2005436096187317,43793,3053.124225139618,0.9889028668403624,0.0378363840281963,0.2413111711242901,0.9858987927436828,0.0474938340485096,0.2085082304214656,43793 -1479.4620878696442,0.2128331661224365,1939.9151394367216,6176,0,1939.9151394367216,0.985202133655548,0.049216840416193,0.2113630635353627,43793,3419.7436504364014,0.9892878532409668,0.0362776331603527,0.2747439853818164,0.9860436916351318,0.0465897619724273,0.2127026190852786,43793 -1605.0842487812042,0.240192174911499,2180.108886241913,6941,0,2180.108886241913,0.9854455590248108,0.0487179309129715,0.2263846545075325,43793,3785.606664180756,0.9895052313804626,0.0354714021086692,0.2851274273494,0.9862803816795348,0.0460730753839015,0.2277599897011227,43793 -1729.5676703453064,0.2676494121551513,2420.2818851470947,7720,0,2420.2818851470947,0.98549485206604,0.0486711822450161,0.2260540501382554,43793,4150.311340808868,0.9895861744880676,0.0348616912961006,0.2963364431756689,0.9863311052322388,0.0460848286747932,0.2284525234305878,43793 -1850.9068267345428,0.2948906421661377,2660.528244495392,8498,0,2660.528244495392,0.9855883717536926,0.0479560121893882,0.2365790668525406,43793,4511.944679737091,0.9899796843528748,0.0335333086550235,0.3374593375725701,0.9864756464958192,0.0453152917325496,0.2374250299991654,43793 -1977.5145423412323,0.3243780136108398,2900.601298093796,9268,0,2900.601298093796,0.9857029318809508,0.0474643260240554,0.2484129351793073,43793,4878.675051450729,0.9901212453842164,0.0331129059195518,0.3409681409111439,0.9865267872810364,0.0450315028429031,0.2348892874300491,43793 -2104.2374680042267,0.3568861484527588,3140.8176939487457,10000,0,3140.8176939487457,0.985702097415924,0.0476629026234149,0.2479148513505949,43793,5245.671521663666,0.9901310801506042,0.0330612622201442,0.3398156014983023,0.9864979386329652,0.0451376028358936,0.2469973556483602,43793 -2230.334555387497,0.384890079498291,3380.950745820999,10771,0,3380.950745820999,0.985705018043518,0.0472121275961399,0.2473022407345934,43793,5611.949606180191,0.990074336528778,0.0331404134631156,0.3361895117138637,0.9865519404411316,0.0447386875748634,0.2440713077247412,43793 -2355.720167160034,0.4129753112792969,3620.993992805481,11547,0,3620.993992805481,0.9857008457183838,0.0471366681158542,0.2437119404115444,43793,5977.426261425018,0.990300178527832,0.0324656441807746,0.3550998255123031,0.9865734577178956,0.0445703901350498,0.2514523384657378,43793 -2478.8244092464447,0.4461667537689209,3861.169105052948,12317,0,3861.169105052948,0.9856165647506714,0.0475370250642299,0.2464815772412357,43793,6340.75782251358,0.9902815222740172,0.0324085690081119,0.3652259174491668,0.9865154027938844,0.0448619574308395,0.2453973131389777,43793 -2603.4416444301605,0.4753603935241699,4101.32523059845,13096,0,4101.32523059845,0.9858625531196594,0.0470395348966121,0.2526472636480409,43793,6705.580302476883,0.9903441071510316,0.0318412445485591,0.3631558729401631,0.9867634773254396,0.0443429425358772,0.253112694139392,43793 -2729.6876018047333,0.5052244663238525,4341.513201236725,13864,0,4341.513201236725,0.985772430896759,0.0467454083263874,0.2591220735212268,43793,7072.06339097023,0.9905263781547546,0.0316419415175914,0.3830037810011942,0.9866445064544678,0.0443065948784351,0.2587238987096111,43793 -2855.8469376564026,0.5335369110107422,4581.769669532776,14634,0,4581.769669532776,0.9859017133712769,0.047769758850336,0.2517687762805413,43793,7438.527512073517,0.990631103515625,0.0308733396232128,0.3843530163345709,0.9867780804634094,0.0449380241334438,0.2532076404834464,43793 -2983.387553453445,0.5630576610565186,4821.992245674133,15395,0,4821.992245674133,0.9859628081321716,0.0468635782599449,0.2587225448907962,43793,7806.339802980423,0.9906738996505736,0.0305171180516481,0.4060332768930056,0.9868982434272766,0.0439459308981895,0.2693682103284675,43793 -3110.181565284729,0.5918211936950684,5062.061771631241,16165,0,5062.061771631241,0.9859126806259156,0.0468367300927639,0.2617815021162231,43793,8173.251757383347,0.9908897876739502,0.0299185272306203,0.4190069996700812,0.9867841601371764,0.0441765636205673,0.2625511838238096,43793 -3237.2122271060944,0.621873140335083,5302.217221975327,16938,0,5302.217221975327,0.985939621925354,0.0470585785806179,0.2587865116364578,43793,8540.487695932388,0.9907991886138916,0.0300168450921773,0.417379686754306,0.9867289662361144,0.0445146299898624,0.2625847733317957,43793 -3361.2305388450623,0.6503303050994873,5542.485919952393,17714,0,5542.485919952393,0.985929548740387,0.0471058934926986,0.2590215087437156,43793,8904.822938919067,0.991073489189148,0.0292695350944995,0.431056186751704,0.9868012070655824,0.04434310272336,0.2621988834189322,43793 -3485.254349946976,0.6788928508758545,5782.542509555817,18472,0,5782.542509555817,0.986065149307251,0.0470095500349998,0.2659354239643912,43793,9268.951786994934,0.991197407245636,0.0287209581583738,0.4492757997726974,0.9868823885917664,0.0443264842033386,0.2703646829249263,43793 -3609.403496026993,0.7088520526885986,6022.602437973023,19245,0,6022.602437973023,0.9859371185302734,0.0474443882703781,0.2659620129994766,43793,9633.210906744003,0.9910871386528016,0.0290567986667156,0.4418756555670832,0.9868450164794922,0.0445682890713214,0.2730459174791392,43793 -3738.122510671616,0.7383873462677002,6262.597642183304,20022,0,6262.597642183304,0.9860432744026184,0.0465757586061954,0.2620121333398993,43793,10001.974766492844,0.9911937713623048,0.0288317780941724,0.4492746429492774,0.9869363903999328,0.0439840964972972,0.2670309948303253,43793 -3863.0693638324738,0.7679581642150879,6502.755940914154,20790,0,6502.755940914154,0.9860184192657472,0.0467420704662799,0.2635152227722349,43793,10367.129216194153,0.9911454319953918,0.0288664679974317,0.4375458194010765,0.9868552088737488,0.044102668762207,0.2674966313573202,43793 -3990.363956451416,0.7991714477539062,6742.881019592285,21562,0,6742.881019592285,0.9859952330589294,0.047163251787424,0.2615548889912085,43793,10734.599383115768,0.9909815192222596,0.0295109990984201,0.4286785680240972,0.9869201183319092,0.0443251132965087,0.2658507401426963,43793 -4116.310038328171,0.8293313980102539,6982.853446960449,22331,0,6982.853446960449,0.986028492450714,0.0466667674481868,0.2643371726781699,43793,11100.56733751297,0.9911559224128724,0.0289796441793441,0.4308432635755393,0.9867740273475648,0.0440210215747356,0.2722770316354325,43793 -4244.544364929199,0.8605177402496338,7223.086431264877,23101,0,7223.086431264877,0.9861287474632264,0.0469279028475284,0.265912846330854,43793,11469.08495092392,0.991181254386902,0.0289966650307178,0.4412407234838167,0.9868909120559692,0.0442680716514587,0.2685556335456881,43793 -4374.132519721985,0.8918707370758057,7463.321610689163,23839,0,7463.321610689163,0.986080765724182,0.0466671139001846,0.2664010170784851,43793,11838.96249818802,0.9912705421447754,0.0284074898809194,0.4519112047921446,0.98692786693573,0.0439685694873333,0.2723111216283164,43793 -4500.043113708496,0.9227774143218994,7703.395177125931,24604,0,7703.395177125931,0.9861123561859132,0.046518225222826,0.2715248516894235,43793,12204.99735713005,0.9913958311080932,0.0279237478971481,0.4629627920889729,0.9869635701179504,0.0436313673853874,0.2792550134875484,43793 -4624.204813718796,0.9522502422332764,7943.4845950603485,25376,0,7943.4845950603485,0.9861026406288148,0.0467845723032951,0.2682354559966355,43793,12569.297526597977,0.9914031624794006,0.0279384851455688,0.4564729812272185,0.9869404435157776,0.0441245399415493,0.2799870684061308,43793 -4754.313112735748,0.9845142364501952,8183.584435462952,26141,0,8183.584435462952,0.986050009727478,0.0468300879001617,0.2667560083856059,43793,12939.556988477709,0.9914870262145996,0.0275341682136058,0.4767759712694727,0.9869298934936525,0.0439975336194038,0.2779713734008719,43793 -4880.574460268021,1.0156822204589844,8423.600947618484,26909,0,8423.600947618484,0.9861173629760742,0.0468805469572544,0.2717513397381658,43793,13305.885575532911,0.9915395975112916,0.0272392891347408,0.4861751748239109,0.9869457483291626,0.0440063215792179,0.2838137070427076,43793 -5008.497307538986,1.046447992324829,8663.806197404861,27679,0,8663.806197404861,0.9861114621162416,0.046661589294672,0.274525427657463,43793,13674.063752174376,0.991851568222046,0.0263207294046878,0.5115942328205016,0.9870147109031676,0.043874554336071,0.2798143819556658,43793 -5131.4640600681305,1.0783584117889404,8903.79060792923,28450,0,8903.79060792923,0.986026406288147,0.0468337163329124,0.2677213978781376,43793,14037.0665640831,0.9919523000717164,0.0260404907166957,0.5119545931912086,0.9868450164794922,0.0437918752431869,0.2751911119907508,43793 -5256.418573856354,1.1103498935699463,9143.852503299711,29222,0,9143.852503299711,0.9861489534378052,0.0470454916357994,0.2713569522959075,43793,14402.134669065475,0.9917731881141664,0.0265221875160932,0.4974743077792137,0.9870484471321106,0.0440644659101963,0.2817265212932616,43793 -5382.11224770546,1.1425657272338867,9383.99340891838,29990,0,9383.99340891838,0.98624587059021,0.0466033667325973,0.2732212950992574,43793,14768.02099108696,0.9917317628860474,0.026847893372178,0.4832345369128923,0.9871158003807068,0.043657187372446,0.274096221644849,43793 -5504.588839054108,1.1740312576293943,9624.226995944977,30754,0,9624.226995944977,0.9861165285110474,0.0470705963671207,0.273177578598428,43793,15130.782220840454,0.9916686415672302,0.0270208194851875,0.4869654090512664,0.9869489669799804,0.0441926978528499,0.2785772326789663,43793 -5631.215336561203,1.205542802810669,9864.3176279068,31525,0,9864.3176279068,0.9860710501670836,0.0469594039022922,0.2723993864317921,43793,15497.550547361374,0.9916601181030272,0.026939183473587,0.4903063582821136,0.9869124293327332,0.0442216247320175,0.2800932828179595,43793 -5756.375189065933,1.2379601001739502,10104.432535648346,32293,0,10104.432535648346,0.9861329793930054,0.0467185899615287,0.2727999380156504,43793,15862.876982688904,0.9916957020759584,0.0268357768654823,0.4875452796120583,0.986953854560852,0.0439895987510681,0.2828627977718431,43793 -5885.21012210846,1.26981782913208,10344.389446496964,33063,0,10344.389446496964,0.986196994781494,0.0469749793410301,0.272280218184706,43793,16231.720217704771,0.9917645454406738,0.0263966470956802,0.4937483853881236,0.9870102405548096,0.0440622717142105,0.2843081741127006,43793 -6010.354276895523,1.3019187450408936,10584.439611911774,33835,0,10584.439611911774,0.986108124256134,0.0469017885625362,0.2736489972800847,43793,16596.966630220413,0.9917652606964112,0.0265411045402288,0.4958937299340971,0.9870187640190125,0.0440790615975856,0.2793230355958955,43793 -6140.787740945816,1.333965539932251,10824.551196575165,34606,0,10824.551196575165,0.9861780405044556,0.0471189729869365,0.2752511972446212,43793,16967.56339740753,0.9919253587722778,0.0258921831846237,0.5138001849053772,0.9870142936706544,0.0442977733910083,0.2835272431791837,43793 -6266.13445353508,1.3671300411224363,11064.743356227877,35370,0,11064.743356227877,0.9861679077148438,0.0470642894506454,0.2723186520389838,43793,17333.15499162674,0.99210524559021,0.0254258923232555,0.520013516105533,0.9870370626449584,0.0441921800374984,0.2838264860463516,43793 -6390.104465007782,1.4000084400177002,11304.83763718605,36133,0,11304.83763718605,0.9861965775489808,0.0474430732429027,0.2730475080736813,43793,17697.271452903748,0.9921540021896362,0.0252154488116502,0.5257004962214487,0.9870886206626892,0.0442344062030315,0.2862508491770593,43793 -6513.972229719162,1.4337913990020752,11545.093611717224,36907,0,11545.093611717224,0.9861688017845154,0.0475516207516193,0.2679133358897493,43793,18061.448248147964,0.9923199415206908,0.0244986284524202,0.5594185517547302,0.9871170520782472,0.0442498102784156,0.2851949177370656,43793 -6636.588159322739,1.467538833618164,11785.12488269806,37673,0,11785.12488269806,0.9860794544219972,0.0477609112858772,0.2707454139816693,43793,18424.14876651764,0.9923747181892396,0.0244077499955892,0.5358389174143836,0.9870086312294006,0.0447503253817558,0.2843611275052901,43793 -6760.528062582016,1.5004465579986572,12025.290924072266,38439,0,12025.290924072266,0.9862340688705444,0.0471976324915885,0.2755533497816337,43793,18788.307156085968,0.992568016052246,0.0237795822322368,0.5684574426282603,0.9871584177017212,0.044204156845808,0.28866362399726,43793 -6884.158543348312,1.536177396774292,12265.243272781372,39209,0,12265.243272781372,0.9862378239631652,0.0474700033664703,0.2816381130732368,43793,19151.94514513016,0.99250328540802,0.0239628404378891,0.5541409386821077,0.98708575963974,0.0444251894950866,0.283720789086509,43793 -7008.079576253891,1.56976318359375,12505.26472067833,39979,0,12505.26472067833,0.9860883355140686,0.0477647967636585,0.2842939134219475,43793,19515.94053053856,0.9925006031990052,0.0241887252777814,0.5570633831413991,0.9870301485061646,0.0446991473436355,0.2828670110903142,43793 -7131.618180990219,1.6031405925750732,12745.222152233124,40717,0,12745.222152233124,0.9861072897911072,0.0472981296479702,0.2748004192246521,43793,19879.48866820336,0.9924003481864928,0.0243264287710189,0.5494295396144823,0.9869505763053894,0.0445120818912982,0.2806861022966553,43793 -7259.434130430222,1.6383109092712402,12985.309354305267,41485,0,12985.309354305267,0.986151933670044,0.0478561930358409,0.2822238579045632,43793,20247.44675898552,0.9922370314598083,0.0248098876327276,0.5382047059952417,0.9869790077209472,0.0448886454105377,0.2904843775330589,43793 -7386.007725477219,1.672457218170166,13225.327796459198,42250,0,13225.327796459198,0.9862955808639526,0.0478070452809333,0.2830454836599912,43793,20614.092463970184,0.992445707321167,0.0241518169641494,0.5460987084725499,0.98708575963974,0.0448537394404411,0.287187447261353,43793 -7520.616132259369,1.7076139450073242,13465.478171348572,43008,0,13465.478171348572,0.9862471222877502,0.0477485209703445,0.2772697280132931,43793,20988.90712785721,0.9924915432929992,0.0239395014941692,0.5532173934177393,0.9871170520782472,0.0446819961071014,0.2867559415482638,43793 -7643.762209892273,1.746145725250244,13705.421085119247,43769,0,13705.421085119247,0.986272394657135,0.0479588136076927,0.2796765902545455,43793,21352.05406999588,0.9925384521484376,0.023691838607192,0.5701476633664833,0.9872027039527892,0.0447466224431991,0.289002886959356,43793 -7766.672614097595,1.7812371253967283,13945.649666070938,44535,0,13945.649666070938,0.9860782027244568,0.047963697463274,0.2759291178161914,43793,21715.24806857109,0.9927001595497132,0.023185497149825,0.5688899401380112,0.9870370626449584,0.0447141453623771,0.2934221886301122,43793 -7892.67279791832,1.8158187866210933,14185.667350292206,45299,0,14185.667350292206,0.9862774610519408,0.0480165034532547,0.2768424105806721,43793,22081.31978273392,0.9928914308547974,0.0225920602679252,0.5915234600156412,0.9871146082878112,0.0449960939586162,0.2908628579927392,43793 -8019.319634437561,1.8533244132995603,14425.91460299492,46072,0,14425.91460299492,0.986213445663452,0.0481310449540615,0.2777517498001093,43793,22448.270634174347,0.9930112957954408,0.0221614427864551,0.5958331862615327,0.98707115650177,0.0450821332633495,0.2912427689904615,43793 -8143.096262216568,1.8887808322906487,14666.095184326172,46838,0,14666.095184326172,0.9862639904022216,0.0482524186372756,0.278822280780592,43793,22812.28260588646,0.9932113885879515,0.0215078685432672,0.6120571175636954,0.9871028065681458,0.0451406985521316,0.2953311430028864,43793 -8265.990796804428,1.9238817691802976,14906.135026216509,47607,0,14906.135026216509,0.9861923456192015,0.0486282743513584,0.2688196035004163,43793,23175.271939516068,0.9933398962020874,0.0211241487413644,0.6331789829490917,0.9870001077651978,0.0456755384802818,0.2915323484639742,43793 -8386.46659874916,1.958917617797852,15146.269136428831,48379,0,15146.269136428831,0.9862968325614928,0.0484311506152153,0.2783990005573538,43793,23535.936802864075,0.9934019446372986,0.0209897384047508,0.6255306974010649,0.9870269298553468,0.0455353967845439,0.2903014903988803,43793 -8506.296981334686,1.9959444999694824,15386.475734949112,49148,0,15386.475734949112,0.9860782027244568,0.0487363375723362,0.2771831528863973,43793,23896.030689239506,0.9931447505950928,0.0217654425650835,0.6131731781560671,0.986970067024231,0.0455004014074802,0.2893361191203478,43793 -8630.36907529831,2.0322859287261963,15626.695341348648,49921,0,15626.695341348648,0.986199915409088,0.0487848855555057,0.2782597699814962,43793,24260.37875604629,0.9931887984275818,0.0215563420206308,0.5925207424069308,0.9870281219482422,0.0456475801765918,0.2885576907983909,43793 -8753.453888893127,2.068065881729126,15866.866524219511,50693,0,15866.866524219511,0.9861531853675842,0.0487932786345481,0.2808486112206802,43793,24623.690307617188,0.9932281374931335,0.0214937105774879,0.6125680428291371,0.9870500564575196,0.0457317978143692,0.2902672784161562,43793 -8871.36655497551,2.1043295860290527,16106.96294260025,51460,0,16106.96294260025,0.9861097931861876,0.0489931516349315,0.2756912633800584,43793,24981.755274295807,0.9931707382202148,0.0215623006224632,0.6086612239107657,0.9869388341903688,0.0458887852728366,0.2875652110156065,43793 -8994.220281600952,2.1409432888031006,16346.9478225708,52231,0,16346.9478225708,0.986139714717865,0.0490336157381534,0.2792324231384169,43793,25344.649944782257,0.9933059811592102,0.0211930908262729,0.6209554499918561,0.9869749546051024,0.0457281917333602,0.2894453169966877,43793 -9116.906694173813,2.1766245365142822,16586.905706882477,53001,0,16586.905706882477,0.9862247705459596,0.0497451350092887,0.2790904600309866,43793,25707.349474668503,0.9933321475982666,0.0208378545939922,0.6346190961639615,0.9871361255645752,0.0464642457664012,0.2905150610562712,43793 -9239.053257226944,2.212446451187134,16827.004104614258,53745,0,16827.004104614258,0.9861270785331726,0.0495319850742816,0.2741914971743095,43793,26069.652314662933,0.9934971332550048,0.0204658973962068,0.630007503667221,0.9869733452796936,0.0461771227419376,0.2869864873591823,43793 -9357.21773147583,2.249485731124878,17067.124797344208,54511,0,17067.124797344208,0.986234486103058,0.0499737039208412,0.275257343295017,43793,26427.993901014328,0.9936056733131408,0.0200942922383546,0.6474843711686171,0.9870001077651978,0.0466343276202678,0.2924957112403711,43793 -9483.120379209518,2.2856380939483643,17307.159264326096,55261,0,17307.159264326096,0.986182689666748,0.0499526932835578,0.2779944627561785,43793,26793.98760509491,0.9938763976097108,0.0192595422267913,0.6639851968755207,0.987063467502594,0.0468171574175357,0.2882980116065824,43793 -9604.013299703598,2.322486400604248,17547.22730565071,56031,0,17547.22730565071,0.9861670732498168,0.0500989556312561,0.2774957790292882,43793,27155.00487589836,0.9941009283065796,0.0186919644474983,0.6720966607253281,0.9870317578315736,0.0469521507620811,0.2862859302274942,43793 -9734.03482222557,2.3649609088897705,17787.253559589386,56768,0,17787.253559589386,0.9861464500427246,0.0500405393540859,0.2790140354084867,43793,27525.11945939064,0.9941667318344116,0.0182786993682384,0.6992927322278004,0.9869331121444702,0.0469362549483776,0.2908976053726619,43793 -9854.109743356705,2.407958745956421,18027.47919726372,57528,0,18027.47919726372,0.9862096309661864,0.0500779040157794,0.2783108176366595,43793,27885.483307123184,0.994340181350708,0.0179514028131961,0.6863589255944273,0.9869924187660216,0.0468558855354785,0.2878277138915034,43793 -9981.927907466888,2.44576358795166,18267.49107837677,58296,0,18267.49107837677,0.9860777854919434,0.0504041463136673,0.2747297670159299,43793,28253.37156844139,0.9941099286079408,0.0186354778707027,0.6760400525936513,0.9870346188545228,0.0470816828310489,0.284984456051923,43793 -10100.244191408156,2.483938217163086,18507.593081474304,59062,0,18507.593081474304,0.9862083792686462,0.0508651062846183,0.2728735942641346,43793,28611.847897052765,0.9940509796142578,0.0185292568057775,0.6800617030023898,0.9870293140411376,0.0475336201488971,0.2872965751470687,43793 -10219.173105716704,2.522862672805786,18747.74967765808,59830,0,18747.74967765808,0.9861797094345092,0.0508675575256347,0.2815177285535439,43793,28970.99193549156,0.9940648674964904,0.0184501875191926,0.6847460980138358,0.9869866967201232,0.0475491769611835,0.2900075163292792,43793 -10339.2678835392,2.560737371444702,18987.98977828025,60601,0,18987.98977828025,0.9862108826637268,0.0513647943735122,0.277457402308432,43793,29331.38431572914,0.9941041469573976,0.0182408802211284,0.6792949123470332,0.9869810342788696,0.0480771027505397,0.2883566277049855,43793 -10456.751803159714,2.598947525024414,19227.93788957596,61370,0,19227.93788957596,0.9860908389091492,0.05152552947402,0.2758072130658053,43793,29688.873893022537,0.9942519664764404,0.0178911536931991,0.6891838281450775,0.9869696497917176,0.0480914264917373,0.2907775939415492,43793 -10577.463264465332,2.6371090412139893,19467.98549079895,62141,0,19467.98549079895,0.986216366291046,0.0519167222082614,0.2729711537032664,43793,30049.69030022621,0.99410742521286,0.018122250214219,0.6837093136990084,0.98700213432312,0.0485698021948337,0.2882322445128811,43793 -10697.105581521988,2.6750237941741943,19707.979488134384,62914,0,19707.979488134384,0.9861868619918824,0.0521451793611049,0.2739277181367657,43793,30409.38412237168,0.9943687319755554,0.0174585618078708,0.710569110240898,0.9870545268058776,0.0485628210008144,0.2889291577087489,43793 -10817.599865198135,2.712806940078736,19948.06530070305,63683,0,19948.06530070305,0.9861218333244324,0.0523505769670009,0.2753913641660267,43793,30770.02154660225,0.9945791959762572,0.0168491005897521,0.7172802330353509,0.9870272874832152,0.0488625429570674,0.2913743414060199,43793 -10936.445219993591,2.7514443397521973,20188.19465708733,64451,0,20188.19465708733,0.9862428903579712,0.0527086183428764,0.2760658608501192,43793,31129.054266929623,0.9945203065872192,0.0168881583958864,0.7186517583486293,0.9870195984840392,0.0491692908108234,0.2907579213342952,43793 -11055.935570955276,2.790802240371704,20428.312505960464,65222,0,20428.312505960464,0.9862112998962402,0.0526616871356964,0.2729145729121758,43793,31488.72131800652,0.994799256324768,0.016018958762288,0.7327114869218326,0.987107276916504,0.0491626411676406,0.28769506812764,43793 -11175.49587225914,2.8298258781433105,20668.418220043182,65996,0,20668.418220043182,0.9861881732940674,0.0529731400310993,0.2759885855868024,43793,31848.44557189941,0.9950602650642396,0.0153986578807234,0.7444797052813373,0.9870184063911438,0.0495036095380783,0.2858881628178701,43793 -11295.720662355425,2.86948561668396,20908.61289286613,66768,0,20908.61289286613,0.9861089587211608,0.0529910996556282,0.2741048176481009,43793,32208.92369627953,0.9951653480529784,0.0151965515688061,0.7628244662934002,0.9869709014892578,0.0494506396353244,0.284574734141852,43793 -11417.421566963196,2.909762382507324,21148.618430376053,67541,0,21148.618430376053,0.9860563278198242,0.0530358180403709,0.2788811681071141,43793,32570.689962387085,0.995209276676178,0.0150286378338932,0.7566962851795804,0.9869514107704164,0.049561109393835,0.2874051538617712,43793 -11537.640836715698,2.9481992721557617,21388.71308946609,68313,0,21388.71308946609,0.986006200313568,0.0531089939177036,0.2743112178448637,43793,32931.06196761131,0.9951818585395812,0.0150977028533816,0.7568026094827762,0.9868556261062622,0.0496043711900711,0.2851188808682722,43793 -11655.715996980667,2.9867284297943115,21628.667563438416,69080,0,21628.667563438416,0.9861005544662476,0.0535284578800201,0.2751838973845723,43793,33289.1495051384,0.9952203035354614,0.014993236400187,0.7534647243648879,0.9869457483291626,0.0499713309109211,0.2849290280465007,43793 -11787.44162750244,3.026840209960937,21868.81844639778,69825,0,21868.81844639778,0.9861228466033936,0.0538902319967746,0.2756676195806692,43793,33661.08895611763,0.9949262142181396,0.0156571958214044,0.7358463544646767,0.986976146697998,0.050358448177576,0.285283862791445,43793 -11911.392924785614,3.07214879989624,22109.06586909294,70562,0,22109.06586909294,0.986019253730774,0.0538304969668388,0.2737792846556728,43793,34025.35740637779,0.9949321746826172,0.0156741421669721,0.7396729190456879,0.9868730306625366,0.0503210015594959,0.2847257727644883,43793 -12040.653289079666,3.115537643432617,22349.18640422821,71302,0,22349.18640422821,0.9860963225364684,0.0537710450589656,0.2754479572202304,43793,34394.80511713028,0.9951404333114624,0.0150937959551811,0.753908089745251,0.9869737029075624,0.0502049997448921,0.2896799036222076,43793 -12164.864319086077,3.1560232639312744,22589.416313409805,72069,0,22589.416313409805,0.9861177802085876,0.0538271330296993,0.2747690843663519,43793,34759.30646586418,0.9952371716499328,0.0147636476904153,0.7646365713289918,0.9869250059127808,0.0503023527562618,0.2853672165884759,43793 -12286.359723567964,3.19601845741272,22829.366702079773,72838,0,22829.366702079773,0.986180543899536,0.0540541112422943,0.2759792410743825,43793,35120.810903549194,0.9951958656311036,0.0147189879789948,0.7631483641737393,0.9870277047157288,0.0504600740969181,0.2847857847326915,43793 -12407.947816610336,3.236490726470948,23069.62701439857,73601,0,23069.62701439857,0.9861106276512146,0.0539973154664039,0.2749001561535323,43793,35482.71941781044,0.9955556988716124,0.0139520000666379,0.7762335065818747,0.9869396090507508,0.0503431037068367,0.2852072589426013,43793 -12526.548835277556,3.2765581607818604,23309.805548667908,74361,0,23309.805548667908,0.9861372113227844,0.0541548170149326,0.2747839274994807,43793,35841.558666944504,0.9955661296844482,0.0139135178178548,0.7747334711189504,0.9869696497917176,0.050527736544609,0.2868443094424433,43793 -12647.343688488008,3.3206896781921387,23549.76802420616,75125,0,23549.76802420616,0.9861240983009338,0.054233469069004,0.2750307266695363,43793,36202.37942099571,0.9955736994743348,0.0139497052878141,0.7811614897240826,0.9869436621665956,0.0505538024008274,0.2867632035727815,43793 -12766.707190036774,3.362133741378784,23789.78480911255,75883,0,23789.78480911255,0.9861254096031188,0.0542790330946445,0.2750352727505398,43793,36561.82056903839,0.9955554008483888,0.0138175962492823,0.7852192319353195,0.9869920015335084,0.0505729056894779,0.2878512364293371,43793 -12888.594792842863,3.4054219722747803,24029.778205871586,76655,0,24029.778205871586,0.986119508743286,0.0542350113391876,0.2757211049940482,43793,36923.76485896111,0.9955092668533324,0.0140359466895461,0.7789031097164165,0.98692786693573,0.0505664274096488,0.2857804150321548,43793 -13008.900380373,3.4460644721984863,24269.738109350204,77424,0,24269.738109350204,0.9861531853675842,0.0542887225747108,0.2760953723677083,43793,37284.09049272537,0.995467722415924,0.0140506261959671,0.7791727016973256,0.9869558811187744,0.0506208762526512,0.2858431907427471,43793 -13128.102998018265,3.4863169193267822,24509.947502613068,78189,0,24509.947502613068,0.9861258268356324,0.0542734786868095,0.2755627204775177,43793,37643.56195235253,0.9954847693443298,0.0140595519915223,0.7697898776753651,0.9869655966758728,0.050602450966835,0.2866528137796882,43793 -13248.0815885067,3.8449010848999023,24749.636902093887,78951,0,24749.636902093887,0.986123263835907,0.0542533323168754,0.2761878777633754,43793,38003.607684612274,0.9955748915672302,0.0139743052423,0.7815056272749845,0.9869611263275146,0.05058304220438,0.2863825448521734,43793 -13371.174401044846,3.886876583099365,24989.647536993027,79719,0,24989.647536993027,0.986132562160492,0.0542628169059753,0.2760192221364203,43793,38366.77245020866,0.9955337047576904,0.0138495787978172,0.7799917830671548,0.9869562983512878,0.0505892001092433,0.2869632430767811,43793 -13486.504242897034,3.929654359817505,25229.74752855301,80486,0,25229.74752855301,0.9861329793930054,0.0542605891823768,0.2760004584768348,43793,38722.26422739029,0.9955554008483888,0.0139354560524225,0.7850682347056999,0.9869558811187744,0.0505871586501598,0.2865156718168872,43793 -13621.399010181429,3.971792221069336,25469.8677110672,81232,0,25469.8677110672,0.9861329793930054,0.0542605891823768,0.2759756767757269,43793,39097.34252977371,0.9955120086669922,0.0139531018212437,0.776054432927252,0.9869558811187744,0.0505871586501598,0.2865305392511518,43793 -13750.508263587952,4.018112182617188,25710.078585147858,81968,0,25710.078585147858,0.9861329793930054,0.0542605929076671,0.2760720686538725,43793,39466.73317718506,0.995551586151123,0.0139019964262843,0.7857464322595502,0.9869558811187744,0.0505871586501598,0.2865071063360335,43793 -13875.881750106812,4.064675092697144,25950.17149996757,82707,0,25950.17149996757,0.9861329793930054,0.0542605891823768,0.2759491004785961,43793,39832.270023822784,0.9955469369888306,0.013970274478197,0.7702644447155856,0.9869558811187744,0.0505871586501598,0.2866412038804638,43793 -14008.269878864288,4.11327862739563,26190.390765428543,83435,0,26190.390765428543,0.9861329793930054,0.0542605891823768,0.2759817029818188,43793,40204.94928979874,0.9955602884292604,0.0138719677925109,0.7792143024381376,0.9869558811187744,0.0505871586501598,0.2867899905995227,43793 -14126.985151052477,4.160360813140869,26430.476687908173,84184,0,26430.476687908173,0.9861329793930054,0.0542605891823768,0.2759793362653783,43793,40563.81817674637,0.9955973625183104,0.0137764988467097,0.783298819531164,0.9869558811187744,0.0505871586501598,0.286624042646645,43793 -14246.918652057648,4.203290939331055,26670.56862926483,84940,0,26670.56862926483,0.9861329793930054,0.0542605891823768,0.2759710816592701,43793,40923.90593361855,0.9954774379730223,0.0140807861462235,0.7756743396009409,0.9869558811187744,0.0505871586501598,0.2864305899233596,43793 -14366.983271598816,4.25146746635437,26910.639630556107,85693,0,26910.639630556107,0.9861329793930054,0.0542605891823768,0.2759714642884614,43793,41284.11113762856,0.9955883622169496,0.0137911858037114,0.7841693580470048,0.9869558811187744,0.0505871586501598,0.2865949001344348,43793 -14488.276557445526,4.293117523193359,27150.82177400589,86448,0,27150.82177400589,0.9861329793930054,0.0542605891823768,0.2759215759192164,43793,41645.6478009224,0.9955260157585144,0.0140296816825866,0.7727005352207057,0.9869558811187744,0.0505871586501598,0.2866704201057026,43793 -14605.95310664177,4.343390703201294,27391.00008749962,87201,0,27391.00008749962,0.9861329793930054,0.0542605929076671,0.2760478402507823,43793,42003.57257246971,0.9955212473869324,0.01403393689543,0.7780336724337944,0.9869558811187744,0.0505871586501598,0.286602377318734,43793 -14729.32446551323,4.386620283126831,27631.004123210907,87944,0,27631.004123210907,0.9861329793930054,0.0542605891823768,0.2760906757638992,43793,42367.01085090637,0.995576798915863,0.0137699004262685,0.7864783435135866,0.9869558811187744,0.0505871586501598,0.2865672216156085,43793 -14848.06216621399,4.431077480316162,27870.94686794281,88693,0,27870.94686794281,0.9861329793930054,0.0542605891823768,0.2761358426923165,43793,42725.75531196594,0.9955042600631714,0.014033424668014,0.7815081377763017,0.9869558811187744,0.0505871586501598,0.2864212378275307,43793 -14969.166311979294,4.4745166301727295,28111.208525419235,89447,0,28111.208525419235,0.9861329793930054,0.0542605891823768,0.2759851509876829,43793,43087.18470454216,0.9955434799194336,0.0138015914708375,0.778073415456541,0.9869558811187744,0.0505871586501598,0.2865742535694123,43793 -15085.602485895157,4.517841577529907,28351.200445890427,90204,0,28351.200445890427,0.9861329793930054,0.0542605891823768,0.2759657702013668,43793,43443.676209926605,0.9955761432647704,0.0139395911246538,0.7750910374084437,0.9869558811187744,0.0505871586501598,0.2865358350164018,43793 -15205.82478427887,4.563338756561279,28591.379506587986,90963,0,28591.379506587986,0.9861329793930054,0.0542605891823768,0.2759976092226506,43793,43804.14271020889,0.9954948425292968,0.014080642722547,0.7802792739060775,0.9869558811187744,0.0505871586501598,0.2866171329509638,43793 -15322.338630437853,4.60586142539978,28831.39079451561,91727,0,28831.39079451561,0.9861329793930054,0.0542605891823768,0.2759996168856555,43793,44160.73046088219,0.9955697059631348,0.0138261914253234,0.7810301983931045,0.9869558811187744,0.0505871586501598,0.2866110429605794,43793 -15441.70261669159,4.650819540023804,29071.55628156662,92495,0,29071.55628156662,0.9861329793930054,0.0542605891823768,0.2759411183774932,43793,44520.32515239716,0.9955700039863586,0.013861620798707,0.7817951724672543,0.9869558811187744,0.0505871586501598,0.2865152886626495,43793 -15562.487857341766,4.694918155670166,29311.766901016235,93254,0,29311.766901016235,0.9861329793930054,0.0542605891823768,0.2759367174136053,43793,44881.38517570496,0.9955008029937744,0.0140228671953082,0.7731956136988019,0.9869558811187744,0.0505871586501598,0.2864977609935052,43793 -15680.41870713234,4.739025115966797,29551.97358584404,94016,0,29551.97358584404,0.9861329793930054,0.0542605891823768,0.2760667388806535,43793,45239.58659052849,0.995567500591278,0.0138811152428388,0.7789318143991779,0.9869558811187744,0.0505871586501598,0.2865480551449149,43793 -15801.521898031237,4.783658504486084,29791.934020996094,94782,0,29791.934020996094,0.9861329793930054,0.0542605891823768,0.2759850102404555,43793,45600.714686870575,0.9955399036407472,0.0139724360778927,0.7753218829520573,0.9869558811187744,0.0505871586501598,0.2866137226228561,43793 -15916.854518413544,4.826842546463013,30032.101594686508,95545,0,30032.101594686508,0.9861329793930054,0.0542605891823768,0.2760216601824851,43793,45956.278367996216,0.9955657124519348,0.0138775575906038,0.7811558660083017,0.9869558811187744,0.0505871586501598,0.2866254112875427,43793 -16035.049534797668,4.871997594833374,30272.227252483368,96309,0,30272.227252483368,0.9861329793930054,0.0542605891823768,0.2760575986127702,43793,46314.66438794136,0.9955589175224304,0.0138029223307967,0.7868677053325954,0.9869558811187744,0.0505871586501598,0.2866101630441003,43793 -16159.157991409302,4.91708254814148,30512.21157264709,97067,0,30512.21157264709,0.9861329793930054,0.0542605891823768,0.2759726786264729,43793,46678.82244515419,0.995500147342682,0.0140522504225373,0.7729460373532289,0.9869558811187744,0.0505871586501598,0.2866899351187257,43793 -16279.08171248436,4.960994243621826,30752.282462358475,97829,0,30752.282462358475,0.9861329793930054,0.0542605891823768,0.2761105488502643,43793,47038.88059139252,0.9955435991287231,0.0138851748779416,0.7760392545421367,0.9869558811187744,0.0505871586501598,0.2865191272693426,43793 -16395.97787475586,5.005981922149658,30992.51039814949,98578,0,30992.51039814949,0.9861329793930054,0.0542605929076671,0.2760305001928738,43793,47396.06888914108,0.9955313801765442,0.0139819029718637,0.7786825487258938,0.9869558811187744,0.0505871586501598,0.2864706683088275,43793 -16514.374661684036,5.0520899295806885,31232.58783340454,99336,0,31232.58783340454,0.9861329793930054,0.0542605929076671,0.2759760452958431,43793,47754.60851669312,0.9955341815948486,0.0139691065996885,0.7826914554191917,0.9869558811187744,0.0505871586501598,0.2865572910504564,43793 -16633.34353852272,5.097035884857178,31472.69523191452,100103,0,31472.69523191452,0.9861329793930054,0.0542605891823768,0.2759562314694286,43793,48113.74931025505,0.9955401420593262,0.0138903046026825,0.7793055337576846,0.9869558811187744,0.0505871586501598,0.2865208102988978,43793 -16751.799800157547,5.141988754272461,31712.652955770493,100869,0,31712.652955770493,0.9861329793930054,0.0542605891823768,0.2760779355054802,43793,48472.2280766964,0.9955520629882812,0.0139562338590621,0.7847198707800941,0.9869558811187744,0.0505871586501598,0.2865390644148356,43793 -16869.23617386818,5.188489437103272,31952.75311899185,101634,0,31952.75311899185,0.9861329793930054,0.0542605891823768,0.2759564045633718,43793,48829.83093047142,0.9955708980560304,0.0137952547520399,0.7813605957526542,0.9869558811187744,0.0505871586501598,0.2866145553852413,43793 -16990.588547229767,5.23371148109436,32192.752415180206,102400,0,32192.752415180206,0.9861329793930054,0.0542605891823768,0.276152526270037,43793,49191.24772572517,0.9955251216888428,0.0139615908265113,0.770412098402832,0.9869558811187744,0.0505871586501598,0.2864852371767319,43793 -17107.008259534836,5.278891801834106,32432.875032186508,103158,0,32432.875032186508,0.9861329793930054,0.0542605891823768,0.275994188857404,43793,49547.85517024994,0.9955227375030518,0.0140439234673976,0.773106773764879,0.9869558811187744,0.0505871586501598,0.2867109029952526,43793 -17218.14317536354,5.324789047241211,32672.83212566376,103919,0,32672.83212566376,0.9861329793930054,0.0542605891823768,0.2759413262208593,43793,49899.01314020157,0.9955666065216064,0.0138385407626628,0.7811419869738975,0.9869558811187744,0.0505871586501598,0.2865256056236226,43793 -17338.682789564133,5.370456695556641,32912.77941918373,104688,0,32912.77941918373,0.9861329793930054,0.0542605891823768,0.2759689009568245,43793,50259.56562113762,0.9955680966377258,0.0138523709028959,0.7850848461487768,0.9869558811187744,0.0505871586501598,0.2865603986056249,43793 -17458.148350954056,5.415530681610107,33152.962929964066,105447,0,33152.962929964066,0.9861329793930054,0.0542605891823768,0.2760524963945224,43793,50619.27890300751,0.9955039620399476,0.0139955282211303,0.7774106373615576,0.9869558811187744,0.0505871586501598,0.2865216599854022,43793 -17582.141822576523,5.462770223617554,33392.90033054352,106180,0,33392.90033054352,0.9861329793930054,0.0542605891823768,0.2759944644891478,43793,50983.27924466133,0.9955573678016664,0.0138543657958507,0.7817870350219891,0.9869558811187744,0.0505871586501598,0.2866419430739763,43793 -17698.731519460678,5.51738166809082,33633.02038526535,106918,0,33633.02038526535,0.9861329793930054,0.0542605929076671,0.2759606267995214,43793,51340.06503534317,0.9955280423164368,0.0140522280707955,0.7661217680770456,0.9869558811187744,0.0505871586501598,0.286556291387751,43793 -17814.46081638336,5.564499616622925,33872.95367479324,107677,0,33872.95367479324,0.9861329793930054,0.0542605929076671,0.2760152750049882,43793,51695.79447507858,0.9955880641937256,0.0138220340013504,0.779899470061354,0.9869558811187744,0.0505871586501598,0.2865691722887861,43793 -17927.368314504623,5.611015796661377,34113.1648209095,108441,0,34113.1648209095,0.9861329793930054,0.0542605891823768,0.2759831511515541,43793,52048.97903227806,0.9955666065216064,0.0138393836095929,0.7841374720767258,0.9869558811187744,0.0505871586501598,0.2864342504751349,43793 -18040.35396838188,5.656085968017578,34353.28302168846,109204,0,34353.28302168846,0.9861329793930054,0.0542605891823768,0.2760052366019689,43793,52402.14703559876,0.9955072999000548,0.0140102095901966,0.781538035227237,0.9869558811187744,0.0505871586501598,0.286505081826393,43793 -18168.47488641739,5.704323053359985,34593.253633499146,109946,0,34593.253633499146,0.9861329793930054,0.0542605891823768,0.2760160422751042,43793,52770.308326005936,0.9955655932426452,0.013794494792819,0.7802740098906389,0.9869558811187744,0.0505871586501598,0.2865730157937727,43793 -18288.60990166664,5.759541034698486,34833.229488134384,110688,0,34833.229488134384,0.9861329793930054,0.0542605891823768,0.2759493312803741,43793,53130.4960372448,0.9955146312713624,0.0140324300155043,0.7722557395643874,0.9869558811187744,0.0505871586501598,0.2867128276145199,43793 -18404.20976114273,5.80646538734436,35073.32724404335,111441,0,35073.32724404335,0.9861329793930054,0.0542605891823768,0.2759526438099676,43793,53486.25983309746,0.9955356121063232,0.0140041010454297,0.7818724497668679,0.9869558811187744,0.0505871586501598,0.2864828757007852,43793 -18523.07643556595,5.854486465454102,35313.35783934593,112205,0,35313.35783934593,0.9861329793930054,0.0542605891823768,0.276035341605793,43793,53845.22560930252,0.9955651760101318,0.0138434814289212,0.7786537040391218,0.9869558811187744,0.0505871586501598,0.2864558681747481,43793 -18635.68253469467,5.9019293785095215,35553.29326963425,112967,0,35553.29326963425,0.9861329793930054,0.0542605891823768,0.2760329738837908,43793,54197.83467626572,0.995530605316162,0.0139740211889147,0.7914633546862597,0.9869558811187744,0.0505871586501598,0.2866072428159543,43793 -18746.87947440148,5.949110984802246,35793.34585046768,113731,0,35793.34585046768,0.9861329793930054,0.0542605891823768,0.2759876579915408,43793,54549.151299238205,0.9955313205718994,0.0138783520087599,0.7689000334908849,0.9869558811187744,0.0505871586501598,0.2866461290888903,43793 -18862.78116440773,5.997893810272217,36033.280943632126,114488,0,36033.280943632126,0.9861329793930054,0.0542605891823768,0.2760602507676384,43793,54905.05683708191,0.9955372214317322,0.0139394728466868,0.7815741067321347,0.9869558811187744,0.0505871586501598,0.2865367019489142,43793 -18975.043164491653,6.046472072601318,36273.37165546417,115249,0,36273.37165546417,0.9861329793930054,0.0542605891823768,0.2759374328359461,43793,55257.477509737015,0.9955669045448304,0.0139568988233804,0.7696413411763118,0.9869558811187744,0.0505871586501598,0.2866747762429935,43793 -19085.425540685654,6.093016862869263,36513.53369688988,116019,0,36513.53369688988,0.9861329793930054,0.0542605891823768,0.2760065789253968,43793,55608.0884168148,0.9955641627311708,0.013859805651009,0.7850317787543281,0.9869558811187744,0.0505871586501598,0.2865551392909943,43793 -19202.251974105835,6.139763593673706,36753.48701620102,116791,0,36753.48701620102,0.9861329793930054,0.0542605929076671,0.2759737336301693,43793,55964.93444156647,0.9955353140830994,0.0139025328680872,0.7856174298674152,0.9869558811187744,0.0505871586501598,0.2865175104084884,43793 -19313.66682934761,6.1876304149627686,36993.59615969658,117554,0,36993.59615969658,0.9861329793930054,0.0542605891823768,0.276063221490286,43793,56316.52581310272,0.9955047965049744,0.0140029955655336,0.7721958738934317,0.9869558811187744,0.0505871586501598,0.286516307077295,43793 -19430.334386587143,6.235366821289063,37233.81376886368,118321,0,37233.81376886368,0.9861329793930054,0.0542605891823768,0.2760523527524613,43793,56673.477946043015,0.9955202341079712,0.0139669338241219,0.7832561135578842,0.9869558811187744,0.0505871586501598,0.2867236390114384,43793 -19549.893506526947,6.28296160697937,37473.95171999931,119085,0,37473.95171999931,0.9861329793930054,0.0542605891823768,0.2759484318103051,43793,57033.24269533157,0.9955582022666932,0.0139131685718894,0.7770310732162424,0.9869558811187744,0.0505871586501598,0.2866907177414473,43793 -19662.93479347229,6.332173109054565,37714.03889346123,119849,0,37714.03889346123,0.9861329793930054,0.0542605891823768,0.2760382490303594,43793,57386.43988108635,0.9955252408981324,0.0140389157459139,0.7764668111826432,0.9869558811187744,0.0505871586501598,0.2867658580304355,43793 -19776.17451095581,6.382076978683472,37954.24491381645,120612,0,37954.24491381645,0.9861329793930054,0.0542605929076671,0.2759795768788805,43793,57739.95515346527,0.9956113696098328,0.013700583949685,0.7857608368347888,0.9869558811187744,0.0505871586501598,0.2865649723131295,43793 -19890.67685246468,6.430763006210327,38194.241970300674,121378,0,38194.241970300674,0.9861329793930054,0.0542605891823768,0.2759365662211449,43793,58094.52316451073,0.9955203533172609,0.0139859775081276,0.7798932588725074,0.9869558811187744,0.0505871586501598,0.2865813047482279,43793 -20014.006541490555,6.4794793128967285,38434.26357674599,122107,0,38434.26357674599,0.9861329793930054,0.0542605891823768,0.2759651343786751,43793,58457.947122097015,0.9955151081085204,0.0139172924682497,0.7799444582331069,0.9869558811187744,0.0505871586501598,0.286507144273752,43793 -20133.949590206143,6.534131288528442,38674.20046806336,122836,0,38674.20046806336,0.9861329793930054,0.0542605891823768,0.2760162167495111,43793,58817.90552377701,0.995549976825714,0.0139400735497474,0.7765179471599951,0.9869558811187744,0.0505871586501598,0.2864950099685684,43793 -20247.099562883377,6.583909273147583,38914.34498047829,123599,0,38914.34498047829,0.9861329793930054,0.0542605891823768,0.2759403551749374,43793,59171.270610809326,0.9955602288246156,0.013960919342935,0.7762064006184903,0.9869558811187744,0.0505871586501598,0.2866554438585214,43793 -20362.13060450554,6.63392972946167,39154.38414406776,124357,0,39154.38414406776,0.9861329793930054,0.0542605891823768,0.2759882452278814,43793,59526.41132116318,0.9955296516418456,0.0139157203957438,0.7814746632594798,0.9869558811187744,0.0505871586501598,0.2864686696059941,43793 -20476.22482061386,6.685439109802246,39394.45890569687,125115,0,39394.45890569687,0.9861329793930054,0.0542605929076671,0.2759794268804273,43793,59880.65182375908,0.9956167340278624,0.0137162897735834,0.7831922579711271,0.9869558811187744,0.0505871586501598,0.2866426636904639,43793 -20591.374747753143,6.735987901687622,39634.669246912,125877,0,39634.669246912,0.9861329793930054,0.0542605891823768,0.2759345952193906,43793,60236.08238339424,0.995496392250061,0.0140708796679973,0.7815206147351226,0.9869558811187744,0.0505871586501598,0.2865493820973373,43793 -20703.41284537316,6.785962104797363,39874.781376600266,126634,0,39874.781376600266,0.9861329793930054,0.0542605891823768,0.2759888347376683,43793,60588.30221319199,0.995526134967804,0.0139440223574638,0.7840203056045634,0.9869558811187744,0.0505871586501598,0.2865940628883549,43793 -20817.052837133408,6.836190938949585,40114.7754137516,127383,0,40114.7754137516,0.9861329793930054,0.0542605891823768,0.2759701992733522,43793,60942.00603628159,0.9955549836158752,0.0139614222571253,0.7755276795761658,0.9869558811187744,0.0505871586501598,0.2864978730948987,43793 -20934.46791172028,6.886796712875366,40354.73989057541,128144,0,40354.73989057541,0.9861329793930054,0.0542605891823768,0.2759531333787746,43793,61299.456189870834,0.9955434203147888,0.0139302490279078,0.7804785749072813,0.9869558811187744,0.0505871586501598,0.2865674623685516,43793 -21046.012306928635,6.93725061416626,40594.930896520615,128908,0,40594.930896520615,0.9861329793930054,0.0542605891823768,0.2761565504820941,43793,61651.26204371452,0.9955562949180604,0.0138796502724289,0.7869704604464649,0.9869558811187744,0.0505871586501598,0.2865487520763469,43793 -21160.976698875427,6.987164974212647,40835.15206003189,129672,0,40835.15206003189,0.9861329793930054,0.0542605891823768,0.2760121926469126,43793,62006.51690149307,0.995521605014801,0.0139315128326416,0.7730002351900508,0.9869558811187744,0.0505871586501598,0.2865531526774762,43793 -21271.14502310753,7.038794279098511,41075.38081145287,130436,0,41075.38081145287,0.9861329793930054,0.0542605929076671,0.2760328250730925,43793,62356.98491954804,0.9955735206604004,0.0138278417289257,0.7775615790059116,0.9869558811187744,0.0505871586501598,0.2866122372734342,43793 -21386.529180288315,7.088238716125488,41315.339268922806,131200,0,41315.339268922806,0.9861329793930054,0.0542605891823768,0.2760851045251188,43793,62712.39630937576,0.9955013394355774,0.0140891848132014,0.7765810385292532,0.9869558811187744,0.0505871586501598,0.2865142023598593,43793 -21493.62951111793,7.139894247055054,41555.53718948364,131968,0,41555.53718948364,0.9861329793930054,0.0542605891823768,0.2760290599539537,43793,63059.76567673683,0.9955359101295472,0.0139632923528552,0.7859540350520396,0.9869558811187744,0.0505871586501598,0.2865155375239707,43793 -21605.530052900314,7.191572189331055,41795.69300460816,132736,0,41795.69300460816,0.9861329793930054,0.0542605891823768,0.2759484631551082,43793,63411.89345788956,0.9955663084983826,0.0137974023818969,0.780957904115043,0.9869558811187744,0.0505871586501598,0.2864971390162966,43793 -21714.979140996933,7.244651079177856,42035.6406815052,133501,0,42035.6406815052,0.9861329793930054,0.0542605891823768,0.2759436747545231,43793,63761.36307883263,0.9955506920814514,0.0139207448810338,0.7788296776529676,0.9869558811187744,0.0505871586501598,0.2866552551738057,43793 -21823.23331308365,7.296629428863525,42275.762905836105,134266,0,42275.762905836105,0.9861329793930054,0.0542605929076671,0.2760273170786373,43793,64109.81100869179,0.995567262172699,0.0138367991894483,0.7788725334472248,0.9869558811187744,0.0505871586501598,0.2864781178666014,43793 -21944.152539014816,7.349959850311279,42515.91206884384,135010,0,42515.91206884384,0.9861329793930054,0.0542605891823768,0.2760359550998511,43793,64470.954426050186,0.99552983045578,0.0139883216470479,0.7758701306845568,0.9869558811187744,0.0505871586501598,0.2865875693942649,43793 -22059.821719884872,7.40891695022583,42756.0000565052,135753,0,42756.0000565052,0.9861329793930054,0.0542605891823768,0.2760901919326172,43793,64826.79478096962,0.9955183863639832,0.0140597959980368,0.7834770621587747,0.9869558811187744,0.0505871586501598,0.2864998181212718,43793 -22171.06094479561,7.4614715576171875,42996.08527255058,136512,0,42996.08527255058,0.9861329793930054,0.0542605891823768,0.2759971165969379,43793,65178.19145774841,0.995557963848114,0.0137953907251358,0.7795278572179078,0.9869558811187744,0.0505871586501598,0.2864227948563459,43793 -22284.24476671219,7.513568878173828,43236.27346968651,137273,0,43236.27346968651,0.9861329793930054,0.0542605891823768,0.2759274889810704,43793,65531.63501739502,0.9955756664276124,0.0138542233034968,0.7806127574085215,0.9869558811187744,0.0505871586501598,0.2866201386157032,43793 -22395.058416366577,7.567029237747192,43476.50746488571,138035,0,43476.50746488571,0.9861329793930054,0.0542605891823768,0.2759771559511502,43793,65882.75573444366,0.9955047369003296,0.0140287214890122,0.776828284611313,0.9869558811187744,0.0505871586501598,0.286563876407547,43793 -22506.916797161102,7.619933605194092,43716.56052541733,138801,0,43716.56052541733,0.9861329793930054,0.0542605929076671,0.2760214367974019,43793,66234.73925423622,0.9955438375473022,0.0139393527060747,0.7791969983498022,0.9869558811187744,0.0505871586501598,0.2866075502097671,43793 -22614.423523902893,7.672637939453125,43956.76708507538,139549,0,43956.76708507538,0.9861329793930054,0.0542605891823768,0.2760094643162549,43793,66582.52459478378,0.9955278635025024,0.0140000423416495,0.7743114486500873,0.9869558811187744,0.0505871586501598,0.2867690866867088,43793 -22723.5324792862,7.724453449249268,44196.79552650452,140319,0,44196.79552650452,0.9861329793930054,0.0542605929076671,0.2759517549835037,43793,66931.73331785202,0.9955880641937256,0.0137920398265123,0.7879608890674992,0.9869558811187744,0.0505871586501598,0.2864734779842767,43793 -22835.397917747498,7.777060031890869,44436.87666225433,141081,0,44436.87666225433,0.9861329793930054,0.0542605891823768,0.2760356463316414,43793,67283.75227689743,0.9955501556396484,0.0138605190441012,0.7867203846922961,0.9869558811187744,0.0505871586501598,0.2865235551818129,43793 -22944.49940443039,7.830647706985474,44676.84755587578,141837,0,44676.84755587578,0.9861329793930054,0.0542605891823768,0.2760171261581809,43793,67632.89763379097,0.9954988956451416,0.0140383094549179,0.7701409373357102,0.9869558811187744,0.0505871586501598,0.2866630419810305,43793 -23052.86403822899,7.88399076461792,44916.99302315712,142608,0,44916.99302315712,0.9861329793930054,0.0542605891823768,0.275967715529291,43793,67981.48130369186,0.9955378770828248,0.0138828856870532,0.7796853291287564,0.9869558811187744,0.0505871586501598,0.2864522326528279,43793 -23165.930099487305,7.938454627990723,45156.97582030296,143365,0,45156.97582030296,0.9861329793930054,0.0542605891823768,0.2759283175225142,43793,68334.60452127457,0.995568573474884,0.0139041766524314,0.7736875430872415,0.9869558811187744,0.0505871586501598,0.2865738946807334,43793 -23272.95647072792,7.992566347122192,45396.94107532501,144129,0,45396.94107532501,0.9861329793930054,0.0542605929076671,0.2759656585305678,43793,68681.67007613182,0.9955149292945862,0.0140272751450538,0.7838841100376359,0.9869558811187744,0.0505871586501598,0.2865780476781701,43793 -23376.3064198494,8.046898603439331,45637.14063572884,144901,0,45637.14063572884,0.9861329793930054,0.0542605929076671,0.2760008915756654,43793,69025.29367923737,0.9955885410308838,0.0137948272749781,0.781838371574247,0.9869558811187744,0.0505871586501598,0.2865226670030312,43793 -23485.505282640457,8.099945545196533,45877.35671019554,145674,0,45877.35671019554,0.9861329793930054,0.0542605891823768,0.276041130417176,43793,69374.78163695335,0.995506227016449,0.0140169616788625,0.7722065956630976,0.9869558811187744,0.0505871586501598,0.2865010927262363,43793 -23593.540898799896,8.154316425323486,46117.34045815468,146443,0,46117.34045815468,0.9861329793930054,0.0542605891823768,0.2759317784239599,43793,69722.8755030632,0.995516002178192,0.0139518873766064,0.7824941855650105,0.9869558811187744,0.0505871586501598,0.286583906969833,43793 -23701.482311964035,8.207560777664185,46357.51933145523,147215,0,46357.51933145523,0.9861329793930054,0.0542605891823768,0.2759693034880986,43793,70071.06904149055,0.9956077337265016,0.01379029545933,0.7769709209696616,0.9869558811187744,0.0505871586501598,0.2864189844674121,43793 -23808.41632771492,8.259890794754028,46597.6997013092,147989,0,46597.6997013092,0.9861329793930054,0.0542605891823768,0.2759677164548653,43793,70418.25558948517,0.9954946637153624,0.0140935676172375,0.7765513850956582,0.9869558811187744,0.0505871586501598,0.2865399241009378,43793 -23926.83319377899,8.31392502784729,46837.91800737381,148766,0,46837.91800737381,0.9861329793930054,0.0542605891823768,0.2759949770797097,43793,70776.96406245232,0.9955323934555054,0.0139013547450304,0.7849211130504072,0.9869558811187744,0.0505871586501598,0.286538708003245,43793 -24045.02954697609,8.379383087158203,47078.12951970101,149502,0,47078.12951970101,0.9861329793930054,0.0542605891823768,0.2759343881185401,43793,71135.46232533455,0.9955655932426452,0.0138668473809957,0.7834967595954778,0.9869558811187744,0.0505871586501598,0.2865266590790944,43793 -24162.76766180992,8.43949007987976,47318.37801861763,150227,0,47318.37801861763,0.9861329793930054,0.0542605929076671,0.2760130703989507,43793,71493.53320717812,0.9955136179924012,0.0139957321807742,0.7791588996594946,0.9869558811187744,0.0505871586501598,0.2867352038186914,43793 -24271.599202156067,8.502744197845459,47558.30167579651,150979,0,47558.30167579651,0.9861329793930054,0.0542605891823768,0.2759236825019268,43793,71842.37278032303,0.9955903887748718,0.0137750133872032,0.7747467790100341,0.9869558811187744,0.0505871586501598,0.2866331928863107,43793 -24379.415264368057,8.558963775634766,47798.30082893372,151750,0,47798.30082893372,0.9861329793930054,0.0542605891823768,0.2760510306082031,43793,72190.26391029358,0.9955427050590516,0.014013847336173,0.7803063834457855,0.9869558811187744,0.0505871586501598,0.2865761340054181,43793 -24485.05907702446,8.614659547805786,48038.29620957375,152523,0,48038.29620957375,0.9861329793930054,0.0542605891823768,0.275992589903956,43793,72535.97880601883,0.9955410957336426,0.0139313545078039,0.782091041865194,0.9869558811187744,0.0505871586501598,0.2866111059765719,43793 -24594.315055131912,8.67185354232788,48278.4686756134,153291,0,48278.4686756134,0.9861329793930054,0.0542605929076671,0.2761404118366132,43793,72885.48430514336,0.9955719709396362,0.0138335553929209,0.7839643969392222,0.9869558811187744,0.0505871586501598,0.2865632471513083,43793 -24700.583837985992,8.727308988571167,48518.43424797058,154068,0,48518.43424797058,0.9861329793930054,0.0542605891823768,0.2759883245069869,43793,73231.79427218437,0.9955121278762816,0.0139707941561937,0.7735758383914202,0.9869558811187744,0.0505871586501598,0.2867517443359907,43793 -24811.903750658035,8.78290057182312,48758.47049832344,154839,0,48758.47049832344,0.9861329793930054,0.0542605891823768,0.2762042065151867,43793,73583.22550678253,0.995508909225464,0.0139925656840205,0.7794013870917653,0.9869558811187744,0.0505871586501598,0.2865083063356302,43793 -24915.600333452225,8.837835311889648,48998.50784397125,155612,0,48998.50784397125,0.9861329793930054,0.0542605891823768,0.2759311870209673,43793,73927.03393220901,0.9955801963806152,0.0138422073796391,0.7769311812947006,0.9869558811187744,0.0505871586501598,0.2865149377407092,43793 -25025.213356494904,8.892682790756226,49238.476517915726,156377,0,49238.476517915726,0.9861329793930054,0.0542605891823768,0.2760219087501092,43793,74276.68950462341,0.9955509901046752,0.0139570720493793,0.7760353013748329,0.9869558811187744,0.0505871586501598,0.2865370845662255,43793 -25130.03033232689,8.947967529296875,49478.666414260864,157147,0,49478.666414260864,0.9861329793930054,0.0542605891823768,0.2760101429392562,43793,74621.77139091492,0.99555104970932,0.0139165678992867,0.7797898173653187,0.9869558811187744,0.0505871586501598,0.2865325905894223,43793 -25239.711701631542,9.003626585006714,49718.83649778366,157912,0,49718.83649778366,0.9861329793930054,0.0542605891823768,0.2759528202316587,43793,74971.6976518631,0.995518445968628,0.013942502439022,0.7842495488099699,0.9869558811187744,0.0505871586501598,0.2866229012628155,43793 -25351.234421491623,9.058332443237305,49958.982744932175,158689,0,49958.982744932175,0.9861329793930054,0.0542605891823768,0.275956717183867,43793,75323.44123673439,0.9955437779426576,0.013895777054131,0.7807502432776531,0.9869558811187744,0.0505871586501598,0.2865415916987422,43793 -25459.5462975502,9.114988565444946,50199.202701091766,159457,0,50199.202701091766,0.9861329793930054,0.0542605891823768,0.2759523538735477,43793,75672.04923796654,0.9955360293388368,0.0139449564740061,0.7723770871421647,0.9869558811187744,0.0505871586501598,0.2867485931237354,43793 -25574.099715471268,9.172712564468384,50439.27810120583,160210,0,50439.27810120583,0.9861329793930054,0.0542605891823768,0.2759910805896102,43793,76026.75747394562,0.9955352544784546,0.0139972520992159,0.7826859588401474,0.9869558811187744,0.0505871586501598,0.2866698257560256,43793 -25692.58653616905,9.235629558563232,50679.41657733917,160942,0,50679.41657733917,0.9861329793930054,0.0542605891823768,0.2759355625368492,43793,76385.46918559074,0.9955859184265136,0.0137790655717253,0.7764327294859807,0.9869558811187744,0.0505871586501598,0.2865785827957917,43793 -25799.029140234,9.29334807395935,50919.51976323128,161709,0,50919.51976323128,0.9861329793930054,0.0542605891823768,0.2759996771960736,43793,76732.0926129818,0.9955579042434692,0.0139086106792092,0.7868474332606503,0.9869558811187744,0.0505871586501598,0.2865003997388438,43793 -25906.129714250565,9.348551273345947,51159.46822762489,162480,0,51159.46822762489,0.9861329793930054,0.0542605891823768,0.276074356700649,43793,77079.21677541733,0.9955121874809264,0.0139806494116783,0.7751245426310285,0.9869558811187744,0.0505871586501598,0.2865003863132244,43793 -26011.374217510223,9.40532898902893,51399.63913941383,163251,0,51399.63913941383,0.9861329793930054,0.0542605891823768,0.2759357398817103,43793,77424.70866513252,0.9955761432647704,0.0138122150674462,0.7789223056384773,0.9869558811187744,0.0505871586501598,0.2865489837939391,43793 -26124.034873485565,9.462925434112549,51639.66033124924,164020,0,51639.66033124924,0.9861329793930054,0.0542605891823768,0.275977611609337,43793,77777.46783638,0.9955022931098938,0.0140853282064199,0.7771181152126543,0.9869558811187744,0.0505871586501598,0.2865277191401532,43793 -26230.481513261795,9.519116163253784,51879.79873919487,164793,0,51879.79873919487,0.9861329793930054,0.0542605891823768,0.2760266176655251,43793,78124.12903475761,0.995532214641571,0.013914574868977,0.7748441575197567,0.9869558811187744,0.0505871586501598,0.2865104223184633,43793 -26338.66453528404,9.57723355293274,52119.75490355492,165553,0,52119.75490355492,0.9861329793930054,0.0542605891823768,0.2759600083074585,43793,78472.34597706795,0.9955798387527466,0.0138430148363113,0.7861792908162337,0.9869558811187744,0.0505871586501598,0.2865396899834355,43793 -26447.33196592331,9.634422063827516,52359.93154716492,166322,0,52359.93154716492,0.9861329793930054,0.0542605891823768,0.2760153037468378,43793,78821.26730561256,0.9955323338508606,0.0139575526118278,0.7757177991067279,0.9869558811187744,0.0505871586501598,0.2866571959988571,43793 -26555.263708114624,9.785484075546265,52600.01885247231,167095,0,52600.01885247231,0.9861329793930054,0.0542605891823768,0.2759521436007556,43793,79169.45726418495,0.995536208152771,0.0139263924211263,0.7815679243915041,0.9869558811187744,0.0505871586501598,0.2867707938192554,43793 -26658.97461915016,9.84364652633667,52840.2066552639,167874,0,52840.2066552639,0.9861329793930054,0.0542605891823768,0.2759558334630442,43793,79513.43434143066,0.9955568313598632,0.0139553509652614,0.7682549175407976,0.9869558811187744,0.0505871586501598,0.2865932665129089,43793 -26766.8594186306,9.901641130447388,53080.41065263748,168616,0,53080.41065263748,0.9861329793930054,0.0542605891823768,0.2760031704897681,43793,79861.60309934616,0.9955222606658936,0.0139672039076685,0.7815628210311292,0.9869558811187744,0.0505871586501598,0.2865442472000015,43793 -26873.43555402756,9.959112644195557,53320.526663541794,169385,0,53320.526663541794,0.9861329793930054,0.0542605891823768,0.2759793808137554,43793,80208.37276148796,0.9955328702926636,0.0139327310025691,0.784643210829564,0.9869558811187744,0.0505871586501598,0.2866681688143653,43793 -26979.10472464561,10.017580509185793,53560.66819763184,170154,0,53560.66819763184,0.9861329793930054,0.0542605929076671,0.2760007219570654,43793,80554.26143503189,0.995567500591278,0.0138555373996496,0.7803541515077737,0.9869558811187744,0.0505871586501598,0.2865558987791517,43793 -27086.828375577927,10.07615089416504,53800.71771574021,170918,0,53800.71771574021,0.9861329793930054,0.0542605891823768,0.2759980739721762,43793,80902.11252045631,0.9955355525016784,0.0139146000146865,0.7782003795864987,0.9869558811187744,0.0505871586501598,0.2867647925173271,43793 -27197.180181980133,10.135051727294922,54040.759315013885,171685,0,54040.759315013885,0.9861329793930054,0.0542605891823768,0.2760119254960862,43793,81252.58443045616,0.9955580830574036,0.013860491104424,0.7741929507380779,0.9869558811187744,0.0505871586501598,0.286599359546828,43793 -27305.70858001709,10.194327116012571,54280.77589941025,172466,0,54280.77589941025,0.9861329793930054,0.0542605891823768,0.2759691651199342,43793,81601.20845675468,0.9955554008483888,0.0139900296926498,0.7805614202828458,0.9869558811187744,0.0505871586501598,0.2864266172910606,43793 -27414.49658846855,10.251972436904907,54520.82110333443,173246,0,54520.82110333443,0.9861329793930054,0.0542605891823768,0.27606934583939,43793,81950.11951947212,0.99552983045578,0.0139098232612013,0.7827256491404169,0.9869558811187744,0.0505871586501598,0.2864472859493177,43793 -27524.51087450981,10.31071424484253,54760.8744161129,174025,0,54760.8744161129,0.9861329793930054,0.0542605891823768,0.2760301578149513,43793,82300.26519680023,0.9955813884735109,0.0138372285291552,0.7825517451429103,0.9869558811187744,0.0505871586501598,0.2865332651294395,43793 -27632.01285123825,10.36878228187561,55001.09853792191,174803,0,55001.09853792191,0.9861329793930054,0.0542605891823768,0.275994238581464,43793,82648.0693461895,0.9954863786697388,0.0140195405110716,0.7766700627017158,0.9869558811187744,0.0505871586501598,0.2865937948374639,43793 -27741.654556512833,10.427810430526733,55241.06003189087,175574,0,55241.06003189087,0.9861329793930054,0.05426058918237686,0.276069050592469,43793,82997.75070357323,0.9955763220787048,0.013856913894414902,0.7794461300147126,0.9869558811187744,0.050587158650159836,0.28655074676841663,43793 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/measurements.csv deleted file mode 100644 index e49323aa8..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1995 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/mean_average_precision,validation/accuracy,validation/loss,validation/mean_average_precision,validation/num_examples,test/accuracy,test/loss,test/mean_average_precision,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,1.8646773,0.7976973,,,,,,,,,,,,,,,,, -1,,,0.4702461659908294,0.7943987846374512,0.0219344540100092,0.4722293615341186,0.7926825284957886,0.0249722052564542,43793.0,0.4722304046154022,0.7914055585861206,0.0270653151253273,43793.0,18.616023302078247,537.1137669086456,18.616023302078247,518.4977009296417,0.0,0.0 -100,0.30788845,0.27934712,,,,,,,,,,,,,,,,, -200,0.09656968,0.110439576,,,,,,,,,,,,,,,,, -300,0.030756168,0.06730965,,,,,,,,,,,,,,,,, -400,0.015522286,0.05785488,,,,,,,,,,,,,,,,, -500,0.037829775,0.057681542,,,,,,,,,,,,,,,,, -600,0.025269073,0.05206358,,,,,,,,,,,,,,,,, -700,0.04198525,0.050622545,,,,,,,,,,,,,,,,, -775,,,0.9868525862693788,0.0508523173630237,0.0579805508242579,0.984272599220276,0.0599129758775234,0.0558535770941507,43793.0,0.9832780957221984,0.0629883781075477,0.0577343690394387,43793.0,258.824720621109,892.1560852527618,258.824720621109,633.2843382358551,0.0283777713775634,0.0 -800,0.021884056,0.05359134,,,,,,,,,,,,,,,,, -900,0.021598551,0.047255877,,,,,,,,,,,,,,,,, -1000,0.021191249,0.051142167,,,,,,,,,,,,,,,,, -1100,0.029601442,0.049905796,,,,,,,,,,,,,,,,, -1200,0.036168233,0.046972856,,,,,,,,,,,,,,,,, -1300,0.0329426,0.04886851,,,,,,,,,,,,,,,,, -1400,0.022155762,0.046679616,,,,,,,,,,,,,,,,, -1500,0.03232579,0.044921845,,,,,,,,,,,,,,,,, -1559,,,0.987420916557312,0.0460253953933715,0.1102378426434518,0.9847450852394104,0.0551337078213691,0.110825755459752,43793.0,0.983700156211853,0.0584810450673103,0.1055086506954278,43793.0,499.05750465393066,1248.7858963012695,499.05750465393066,749.6343984603882,0.056328535079956,0.0 -1600,0.014416245,0.042798795,,,,,,,,,,,,,,,,, -1700,0.019972567,0.048796665,,,,,,,,,,,,,,,,, -1800,0.014934475,0.041791484,,,,,,,,,,,,,,,,, -1900,0.021942524,0.04594054,,,,,,,,,,,,,,,,, -2000,0.022137364,0.04872317,,,,,,,,,,,,,,,,, -2100,0.019521799,0.04847294,,,,,,,,,,,,,,,,, -2200,0.010599609,0.041090287,,,,,,,,,,,,,,,,, -2300,0.020205326,0.04424012,,,,,,,,,,,,,,,,, -2336,,,0.9876879453659058,0.0442273467779159,0.1398924353982198,0.985008955001831,0.0533492192625999,0.1355290644133614,43793.0,0.9839503169059752,0.0567783191800117,0.1309682146696898,43793.0,739.084522485733,1608.116632938385,739.084522485733,868.8937499523163,0.0814664363861084,0.0 -2400,0.012645608,0.044106625,,,,,,,,,,,,,,,,, -2500,0.010583815,0.0432867,,,,,,,,,,,,,,,,, -2600,0.018204404,0.040009435,,,,,,,,,,,,,,,,, -2700,0.010222379,0.038107973,,,,,,,,,,,,,,,,, -2800,0.014100456,0.04297008,,,,,,,,,,,,,,,,, -2900,0.015874797,0.041063618,,,,,,,,,,,,,,,,, -3000,0.017260926,0.043759238,,,,,,,,,,,,,,,,, -3100,0.012613798,0.041197333,,,,,,,,,,,,,,,,, -3116,,,0.9880595207214355,0.0418546013534069,0.1689057449872897,0.9852123260498048,0.0506405979394912,0.1537483003262923,43793.0,0.9842274785041808,0.0534768253564834,0.1543288606950389,43793.0,979.3258407115936,1968.437069416046,979.3258407115936,988.9270553588868,0.1082706451416015,0.0 -3200,0.032688264,0.043305982,,,,,,,,,,,,,,,,, -3300,0.020250477,0.040842116,,,,,,,,,,,,,,,,, -3400,0.013885729,0.041379876,,,,,,,,,,,,,,,,, -3500,0.015520628,0.04366542,,,,,,,,,,,,,,,,, -3600,0.016135043,0.045629468,,,,,,,,,,,,,,,,, -3700,0.01269361,0.04773159,,,,,,,,,,,,,,,,, -3800,0.013772018,0.04249744,,,,,,,,,,,,,,,,, -3895,,,0.9883913993835448,0.0399228259921073,0.2009358619510241,0.9854932427406312,0.0492856800556182,0.1780816860356048,43793.0,0.9845648407936096,0.0519956611096859,0.1728490681036745,43793.0,1219.583377122879,2329.331063270569,1219.583377122879,1109.51793384552,0.1344447135925293,0.0 -3900,0.01686203,0.035215538,,,,,,,,,,,,,,,,, -4000,0.012475979,0.03662366,,,,,,,,,,,,,,,,, -4100,0.012740671,0.04113024,,,,,,,,,,,,,,,,, -4200,0.020528287,0.04022963,,,,,,,,,,,,,,,,, -4300,0.017358646,0.04404779,,,,,,,,,,,,,,,,, -4400,0.013162706,0.039373558,,,,,,,,,,,,,,,,, -4500,0.012197353,0.03552427,,,,,,,,,,,,,,,,, -4600,0.013195284,0.038505223,,,,,,,,,,,,,,,,, -4669,,,0.9887118339538574,0.0385812260210514,0.2256461587811895,0.9857218265533448,0.0479509532451629,0.1956819743386989,43793.0,0.984831929206848,0.050608541816473,0.1937942979027161,43793.0,1459.7593297958374,2690.2451293468475,1459.7593297958374,1230.210641622543,0.1602764129638672,0.0 -4700,0.014690772,0.036702562,,,,,,,,,,,,,,,,, -4800,0.009823844,0.034699157,,,,,,,,,,,,,,,,, -4900,0.011351492,0.036963664,,,,,,,,,,,,,,,,, -5000,0.010516408,0.041750006,,,,,,,,,,,,,,,,, -5100,0.013677627,0.041341983,,,,,,,,,,,,,,,,, -5200,0.011904696,0.038278226,,,,,,,,,,,,,,,,, -5300,0.013942742,0.04069951,,,,,,,,,,,,,,,,, -5400,0.017985774,0.040949598,,,,,,,,,,,,,,,,, -5437,,,0.9889028668403624,0.0378363840281963,0.2413111711242901,0.9858987927436828,0.0474938340485096,0.2085082304214656,43793.0,0.9849759340286256,0.0504288785159587,0.2005436096187317,43793.0,1699.8314542770386,3053.124225139618,1699.8314542770386,1352.9720075130465,0.1865897178649902,0.0 -5500,0.018417113,0.041367274,,,,,,,,,,,,,,,,, -5600,0.017799752,0.038744267,,,,,,,,,,,,,,,,, -5700,0.020982916,0.034798153,,,,,,,,,,,,,,,,, -5800,0.014699406,0.037633806,,,,,,,,,,,,,,,,, -5900,0.015847571,0.03898179,,,,,,,,,,,,,,,,, -6000,0.012245577,0.032828093,,,,,,,,,,,,,,,,, -6100,0.0150039075,0.039169356,,,,,,,,,,,,,,,,, -6176,,,0.9892878532409668,0.0362776331603527,0.2747439853818164,0.9860436916351318,0.0465897619724273,0.2127026190852786,43793.0,0.985202133655548,0.049216840416193,0.2113630635353627,43793.0,1939.9151394367216,3419.7436504364014,1939.9151394367216,1479.4620878696442,0.2128331661224365,0.0 -6200,0.01843773,0.039013833,,,,,,,,,,,,,,,,, -6300,0.020777658,0.03738571,,,,,,,,,,,,,,,,, -6400,0.012477494,0.038932122,,,,,,,,,,,,,,,,, -6500,0.012639703,0.041147985,,,,,,,,,,,,,,,,, -6600,0.01489527,0.03767532,,,,,,,,,,,,,,,,, -6700,0.018494418,0.03727196,,,,,,,,,,,,,,,,, -6800,0.021410309,0.041228224,,,,,,,,,,,,,,,,, -6900,0.012313298,0.033608805,,,,,,,,,,,,,,,,, -6941,,,0.9895052313804626,0.0354714021086692,0.2851274273494,0.9862803816795348,0.0460730753839015,0.2277599897011227,43793.0,0.9854455590248108,0.0487179309129715,0.2263846545075325,43793.0,2180.108886241913,3785.606664180756,2180.108886241913,1605.0842487812042,0.240192174911499,0.0 -7000,0.017832138,0.038042787,,,,,,,,,,,,,,,,, -7100,0.015173325,0.0365528,,,,,,,,,,,,,,,,, -7200,0.019740555,0.038664922,,,,,,,,,,,,,,,,, -7300,0.016713258,0.034464728,,,,,,,,,,,,,,,,, -7400,0.0143587375,0.040532175,,,,,,,,,,,,,,,,, -7500,0.015418773,0.035790566,,,,,,,,,,,,,,,,, -7600,0.013530107,0.0359235,,,,,,,,,,,,,,,,, -7700,0.016590819,0.03643931,,,,,,,,,,,,,,,,, -7720,,,0.9895861744880676,0.0348616912961006,0.2963364431756689,0.9863311052322388,0.0460848286747932,0.2284525234305878,43793.0,0.98549485206604,0.0486711822450161,0.2260540501382554,43793.0,2420.2818851470947,4150.311340808868,2420.2818851470947,1729.5676703453064,0.2676494121551513,0.0 -7800,0.018505216,0.035636216,,,,,,,,,,,,,,,,, -7900,0.018618884,0.040403273,,,,,,,,,,,,,,,,, -8000,0.01726695,0.034732208,,,,,,,,,,,,,,,,, -8100,0.015004191,0.033942133,,,,,,,,,,,,,,,,, -8200,0.0241559,0.035539903,,,,,,,,,,,,,,,,, -8300,0.0152527895,0.033853825,,,,,,,,,,,,,,,,, -8400,0.02552286,0.035175744,,,,,,,,,,,,,,,,, -8498,,,0.9899796843528748,0.0335333086550235,0.3374593375725701,0.9864756464958192,0.0453152917325496,0.2374250299991654,43793.0,0.9855883717536926,0.0479560121893882,0.2365790668525406,43793.0,2660.528244495392,4511.944679737091,2660.528244495392,1850.9068267345428,0.2948906421661377,0.0 -8500,0.016709108,0.034627065,,,,,,,,,,,,,,,,, -8600,0.019595614,0.0427954,,,,,,,,,,,,,,,,, -8700,0.024088634,0.035829354,,,,,,,,,,,,,,,,, -8800,0.019351302,0.03559693,,,,,,,,,,,,,,,,, -8900,0.01843229,0.03550619,,,,,,,,,,,,,,,,, -9000,0.023027213,0.037956823,,,,,,,,,,,,,,,,, -9100,0.0170304,0.035485685,,,,,,,,,,,,,,,,, -9200,0.015011755,0.032222174,,,,,,,,,,,,,,,,, -9268,,,0.9901212453842164,0.0331129059195518,0.3409681409111439,0.9865267872810364,0.0450315028429031,0.2348892874300491,43793.0,0.9857029318809508,0.0474643260240554,0.2484129351793073,43793.0,2900.601298093796,4878.675051450729,2900.601298093796,1977.5145423412323,0.3243780136108398,0.0 -9300,0.016179018,0.03221496,,,,,,,,,,,,,,,,, -9400,0.016605219,0.03270021,,,,,,,,,,,,,,,,, -9500,0.020508666,0.035062347,,,,,,,,,,,,,,,,, -9600,0.017112123,0.037961975,,,,,,,,,,,,,,,,, -9700,0.020888364,0.035006214,,,,,,,,,,,,,,,,, -9800,0.018816363,0.032788776,,,,,,,,,,,,,,,,, -9900,0.014732447,0.027221145,,,,,,,,,,,,,,,,, -10000,,,0.9901310801506042,0.0330612622201442,0.3398156014983023,0.9864979386329652,0.0451376028358936,0.2469973556483602,43793.0,0.985702097415924,0.0476629026234149,0.2479148513505949,43793.0,3140.8176939487457,5245.671521663666,3140.8176939487457,2104.2374680042267,0.3568861484527588,0.0 -10000,0.018355932,0.0315104,,,,,,,,,,,,,,,,, -10100,0.023045385,0.032483976,,,,,,,,,,,,,,,,, -10200,0.02430299,0.031410806,,,,,,,,,,,,,,,,, -10300,0.01855313,0.02976205,,,,,,,,,,,,,,,,, -10400,0.018683035,0.035658393,,,,,,,,,,,,,,,,, -10500,0.026639722,0.032157484,,,,,,,,,,,,,,,,, -10600,0.045279402,0.037715442,,,,,,,,,,,,,,,,, -10700,0.019178499,0.03486577,,,,,,,,,,,,,,,,, -10771,,,0.990074336528778,0.0331404134631156,0.3361895117138637,0.9865519404411316,0.0447386875748634,0.2440713077247412,43793.0,0.985705018043518,0.0472121275961399,0.2473022407345934,43793.0,3380.950745820999,5611.949606180191,3380.950745820999,2230.334555387497,0.384890079498291,0.0 -10800,0.022746265,0.035639938,,,,,,,,,,,,,,,,, -10900,0.023611428,0.0348101,,,,,,,,,,,,,,,,, -11000,0.0225751,0.032239553,,,,,,,,,,,,,,,,, -11100,0.022151347,0.03163887,,,,,,,,,,,,,,,,, -11200,0.025199894,0.034684572,,,,,,,,,,,,,,,,, -11300,0.022068132,0.032833554,,,,,,,,,,,,,,,,, -11400,0.021644253,0.032830548,,,,,,,,,,,,,,,,, -11500,0.028423928,0.032280676,,,,,,,,,,,,,,,,, -11547,,,0.990300178527832,0.0324656441807746,0.3550998255123031,0.9865734577178956,0.0445703901350498,0.2514523384657378,43793.0,0.9857008457183838,0.0471366681158542,0.2437119404115444,43793.0,3620.993992805481,5977.426261425018,3620.993992805481,2355.720167160034,0.4129753112792969,0.0 -11600,0.028569885,0.032510567,,,,,,,,,,,,,,,,, -11700,0.019179743,0.034894746,,,,,,,,,,,,,,,,, -11800,0.024069866,0.03281417,,,,,,,,,,,,,,,,, -11900,0.026586859,0.033175055,,,,,,,,,,,,,,,,, -12000,0.023860468,0.03338004,,,,,,,,,,,,,,,,, -12100,0.023723284,0.034833387,,,,,,,,,,,,,,,,, -12200,0.030056445,0.031195115,,,,,,,,,,,,,,,,, -12300,0.021843258,0.032293357,,,,,,,,,,,,,,,,, -12317,,,0.9902815222740172,0.0324085690081119,0.3652259174491668,0.9865154027938844,0.0448619574308395,0.2453973131389777,43793.0,0.9856165647506714,0.0475370250642299,0.2464815772412357,43793.0,3861.169105052948,6340.75782251358,3861.169105052948,2478.8244092464447,0.4461667537689209,0.0 -12400,0.02925083,0.036337502,,,,,,,,,,,,,,,,, -12500,0.03676716,0.033463944,,,,,,,,,,,,,,,,, -12600,0.026781147,0.03332706,,,,,,,,,,,,,,,,, -12700,0.024671419,0.033928655,,,,,,,,,,,,,,,,, -12800,0.022829922,0.035320215,,,,,,,,,,,,,,,,, -12900,0.027240178,0.033920802,,,,,,,,,,,,,,,,, -13000,0.029270958,0.036737267,,,,,,,,,,,,,,,,, -13096,,,0.9903441071510316,0.0318412445485591,0.3631558729401631,0.9867634773254396,0.0443429425358772,0.253112694139392,43793.0,0.9858625531196594,0.0470395348966121,0.2526472636480409,43793.0,4101.32523059845,6705.580302476883,4101.32523059845,2603.4416444301605,0.4753603935241699,0.0 -13100,0.025267929,0.032702863,,,,,,,,,,,,,,,,, -13200,0.030921198,0.030491007,,,,,,,,,,,,,,,,, -13300,0.027878795,0.031722378,,,,,,,,,,,,,,,,, -13400,0.027568217,0.030065317,,,,,,,,,,,,,,,,, -13500,0.027866697,0.029247118,,,,,,,,,,,,,,,,, -13600,0.028928554,0.033430647,,,,,,,,,,,,,,,,, -13700,0.033838723,0.032886673,,,,,,,,,,,,,,,,, -13800,0.043418825,0.03296675,,,,,,,,,,,,,,,,, -13864,,,0.9905263781547546,0.0316419415175914,0.3830037810011942,0.9866445064544678,0.0443065948784351,0.2587238987096111,43793.0,0.985772430896759,0.0467454083263874,0.2591220735212268,43793.0,4341.513201236725,7072.06339097023,4341.513201236725,2729.6876018047333,0.5052244663238525,0.0 -13900,0.028736118,0.034909554,,,,,,,,,,,,,,,,, -14000,0.036926184,0.030608341,,,,,,,,,,,,,,,,, -14100,0.033710927,0.03299375,,,,,,,,,,,,,,,,, -14200,0.029750306,0.032970782,,,,,,,,,,,,,,,,, -14300,0.038056612,0.033584498,,,,,,,,,,,,,,,,, -14400,0.040808946,0.02775608,,,,,,,,,,,,,,,,, -14500,0.038884234,0.032218203,,,,,,,,,,,,,,,,, -14600,0.03315975,0.03259633,,,,,,,,,,,,,,,,, -14634,,,0.990631103515625,0.0308733396232128,0.3843530163345709,0.9867780804634094,0.0449380241334438,0.2532076404834464,43793.0,0.9859017133712769,0.047769758850336,0.2517687762805413,43793.0,4581.769669532776,7438.527512073517,4581.769669532776,2855.8469376564026,0.5335369110107422,0.0 -14700,0.03896865,0.03454416,,,,,,,,,,,,,,,,, -14800,0.036094464,0.032533146,,,,,,,,,,,,,,,,, -14900,0.027803836,0.031617794,,,,,,,,,,,,,,,,, -15000,0.039622296,0.033479642,,,,,,,,,,,,,,,,, -15100,0.042781524,0.033187225,,,,,,,,,,,,,,,,, -15200,0.035949692,0.032666676,,,,,,,,,,,,,,,,, -15300,0.033971973,0.03072572,,,,,,,,,,,,,,,,, -15395,,,0.9906738996505736,0.0305171180516481,0.4060332768930056,0.9868982434272766,0.0439459308981895,0.2693682103284675,43793.0,0.9859628081321716,0.0468635782599449,0.2587225448907962,43793.0,4821.992245674133,7806.339802980423,4821.992245674133,2983.387553453445,0.5630576610565186,0.0 -15400,0.035780594,0.033904858,,,,,,,,,,,,,,,,, -15500,0.036432255,0.032253586,,,,,,,,,,,,,,,,, -15600,0.04652632,0.03169152,,,,,,,,,,,,,,,,, -15700,0.0410726,0.029278595,,,,,,,,,,,,,,,,, -15800,0.030256066,0.031500448,,,,,,,,,,,,,,,,, -15900,0.039542675,0.032616265,,,,,,,,,,,,,,,,, -16000,0.03847535,0.0334217,,,,,,,,,,,,,,,,, -16100,0.03263612,0.033463567,,,,,,,,,,,,,,,,, -16165,,,0.9908897876739502,0.0299185272306203,0.4190069996700812,0.9867841601371764,0.0441765636205673,0.2625511838238096,43793.0,0.9859126806259156,0.0468367300927639,0.2617815021162231,43793.0,5062.061771631241,8173.251757383347,5062.061771631241,3110.181565284729,0.5918211936950684,0.0 -16200,0.04049211,0.03032803,,,,,,,,,,,,,,,,, -16300,0.043399956,0.033848923,,,,,,,,,,,,,,,,, -16400,0.046652276,0.034418512,,,,,,,,,,,,,,,,, -16500,0.032139014,0.03428985,,,,,,,,,,,,,,,,, -16600,0.03147617,0.035058226,,,,,,,,,,,,,,,,, -16700,0.041882757,0.034935787,,,,,,,,,,,,,,,,, -16800,0.03370156,0.03513493,,,,,,,,,,,,,,,,, -16900,0.04442727,0.032682948,,,,,,,,,,,,,,,,, -16938,,,0.9907991886138916,0.0300168450921773,0.417379686754306,0.9867289662361144,0.0445146299898624,0.2625847733317957,43793.0,0.985939621925354,0.0470585785806179,0.2587865116364578,43793.0,5302.217221975327,8540.487695932388,5302.217221975327,3237.2122271060944,0.621873140335083,0.0 -17000,0.03504916,0.03171948,,,,,,,,,,,,,,,,, -17100,0.03750142,0.031542443,,,,,,,,,,,,,,,,, -17200,0.040245134,0.030233158,,,,,,,,,,,,,,,,, -17300,0.054908272,0.032179613,,,,,,,,,,,,,,,,, -17400,0.031552095,0.028493283,,,,,,,,,,,,,,,,, -17500,0.043447748,0.029690634,,,,,,,,,,,,,,,,, -17600,0.05723817,0.032528635,,,,,,,,,,,,,,,,, -17700,0.048703007,0.03195134,,,,,,,,,,,,,,,,, -17714,,,0.991073489189148,0.0292695350944995,0.431056186751704,0.9868012070655824,0.04434310272336,0.2621988834189322,43793.0,0.985929548740387,0.0471058934926986,0.2590215087437156,43793.0,5542.485919952393,8904.822938919067,5542.485919952393,3361.2305388450623,0.6503303050994873,0.0 -17800,0.057094127,0.030198863,,,,,,,,,,,,,,,,, -17900,0.035416137,0.032754105,,,,,,,,,,,,,,,,, -18000,0.05040186,0.033950314,,,,,,,,,,,,,,,,, -18100,0.044120677,0.033017904,,,,,,,,,,,,,,,,, -18200,0.046898838,0.033740208,,,,,,,,,,,,,,,,, -18300,0.05923791,0.030184846,,,,,,,,,,,,,,,,, -18400,0.1136056,0.034176808,,,,,,,,,,,,,,,,, -18472,,,0.991197407245636,0.0287209581583738,0.4492757997726974,0.9868823885917664,0.0443264842033386,0.2703646829249263,43793.0,0.986065149307251,0.0470095500349998,0.2659354239643912,43793.0,5782.542509555817,9268.951786994934,5782.542509555817,3485.254349946976,0.6788928508758545,0.0 -18500,0.03909286,0.032494813,,,,,,,,,,,,,,,,, -18600,0.042086445,0.03471366,,,,,,,,,,,,,,,,, -18700,0.043358263,0.0317288,,,,,,,,,,,,,,,,, -18800,0.037196666,0.030473035,,,,,,,,,,,,,,,,, -18900,0.03500986,0.030274352,,,,,,,,,,,,,,,,, -19000,0.04565969,0.031412516,,,,,,,,,,,,,,,,, -19100,0.042577077,0.0311914,,,,,,,,,,,,,,,,, -19200,0.039001178,0.030076956,,,,,,,,,,,,,,,,, -19245,,,0.9910871386528016,0.0290567986667156,0.4418756555670832,0.9868450164794922,0.0445682890713214,0.2730459174791392,43793.0,0.9859371185302734,0.0474443882703781,0.2659620129994766,43793.0,6022.602437973023,9633.210906744003,6022.602437973023,3609.403496026993,0.7088520526885986,0.0 -19300,0.040898807,0.03190631,,,,,,,,,,,,,,,,, -19400,0.036539655,0.030064827,,,,,,,,,,,,,,,,, -19500,0.051978912,0.032084208,,,,,,,,,,,,,,,,, -19600,0.04145287,0.02838632,,,,,,,,,,,,,,,,, -19700,0.05515446,0.033763833,,,,,,,,,,,,,,,,, -19800,0.048789278,0.03287277,,,,,,,,,,,,,,,,, -19900,0.04077418,0.03263211,,,,,,,,,,,,,,,,, -20000,0.05685689,0.03429228,,,,,,,,,,,,,,,,, -20022,,,0.9911937713623048,0.0288317780941724,0.4492746429492774,0.9869363903999328,0.0439840964972972,0.2670309948303253,43793.0,0.9860432744026184,0.0465757586061954,0.2620121333398993,43793.0,6262.597642183304,10001.974766492844,6262.597642183304,3738.122510671616,0.7383873462677002,0.0 -20100,0.049846828,0.035160683,,,,,,,,,,,,,,,,, -20200,0.0367817,0.026069228,,,,,,,,,,,,,,,,, -20300,0.04048762,0.030851675,,,,,,,,,,,,,,,,, -20400,0.038104676,0.0312687,,,,,,,,,,,,,,,,, -20500,0.05300454,0.029602299,,,,,,,,,,,,,,,,, -20600,0.04747463,0.029289901,,,,,,,,,,,,,,,,, -20700,0.044616062,0.0282577,,,,,,,,,,,,,,,,, -20790,,,0.9911454319953918,0.0288664679974317,0.4375458194010765,0.9868552088737488,0.044102668762207,0.2674966313573202,43793.0,0.9860184192657472,0.0467420704662799,0.2635152227722349,43793.0,6502.755940914154,10367.129216194153,6502.755940914154,3863.0693638324738,0.7679581642150879,0.0 -20800,0.053407084,0.031145647,,,,,,,,,,,,,,,,, -20900,0.04909119,0.031148743,,,,,,,,,,,,,,,,, -21000,0.03863911,0.029876292,,,,,,,,,,,,,,,,, -21100,0.043767508,0.03155893,,,,,,,,,,,,,,,,, -21200,0.04277386,0.028299477,,,,,,,,,,,,,,,,, -21300,0.08430651,0.028320923,,,,,,,,,,,,,,,,, -21400,0.04457513,0.031837158,,,,,,,,,,,,,,,,, -21500,0.052257225,0.030357547,,,,,,,,,,,,,,,,, -21562,,,0.9909815192222596,0.0295109990984201,0.4286785680240972,0.9869201183319092,0.0443251132965087,0.2658507401426963,43793.0,0.9859952330589294,0.047163251787424,0.2615548889912085,43793.0,6742.881019592285,10734.599383115768,6742.881019592285,3990.363956451416,0.7991714477539062,0.0 -21600,0.04239711,0.028124865,,,,,,,,,,,,,,,,, -21700,0.059057858,0.02997502,,,,,,,,,,,,,,,,, -21800,0.04345855,0.031854082,,,,,,,,,,,,,,,,, -21900,0.04191413,0.026922328,,,,,,,,,,,,,,,,, -22000,0.04263392,0.026012521,,,,,,,,,,,,,,,,, -22100,0.05140551,0.033259172,,,,,,,,,,,,,,,,, -22200,0.061155092,0.030976843,,,,,,,,,,,,,,,,, -22300,0.053604923,0.031207472,,,,,,,,,,,,,,,,, -22331,,,0.9911559224128724,0.0289796441793441,0.4308432635755393,0.9867740273475648,0.0440210215747356,0.2722770316354325,43793.0,0.986028492450714,0.0466667674481868,0.2643371726781699,43793.0,6982.853446960449,11100.56733751297,6982.853446960449,4116.310038328171,0.8293313980102539,0.0 -22400,0.045093168,0.032051694,,,,,,,,,,,,,,,,, -22500,0.054487266,0.029686484,,,,,,,,,,,,,,,,, -22600,0.047059216,0.032695282,,,,,,,,,,,,,,,,, -22700,0.05189185,0.031958017,,,,,,,,,,,,,,,,, -22800,0.046026725,0.029457707,,,,,,,,,,,,,,,,, -22900,0.06411669,0.032964848,,,,,,,,,,,,,,,,, -23000,0.047613107,0.03019498,,,,,,,,,,,,,,,,, -23100,0.049989454,0.030194446,,,,,,,,,,,,,,,,, -23101,,,0.991181254386902,0.0289966650307178,0.4412407234838167,0.9868909120559692,0.0442680716514587,0.2685556335456881,43793.0,0.9861287474632264,0.0469279028475284,0.265912846330854,43793.0,7223.086431264877,11469.08495092392,7223.086431264877,4244.544364929199,0.8605177402496338,0.0 -23200,0.06394678,0.030162686,,,,,,,,,,,,,,,,, -23300,0.049664132,0.032070868,,,,,,,,,,,,,,,,, -23400,0.050710905,0.031196333,,,,,,,,,,,,,,,,, -23500,0.054075807,0.028650897,,,,,,,,,,,,,,,,, -23600,0.05456971,0.028907398,,,,,,,,,,,,,,,,, -23700,0.048647013,0.032240607,,,,,,,,,,,,,,,,, -23800,0.054827653,0.030170204,,,,,,,,,,,,,,,,, -23839,,,0.9912705421447754,0.0284074898809194,0.4519112047921446,0.98692786693573,0.0439685694873333,0.2723111216283164,43793.0,0.986080765724182,0.0466671139001846,0.2664010170784851,43793.0,7463.321610689163,11838.96249818802,7463.321610689163,4374.132519721985,0.8918707370758057,0.0 -23900,0.05278306,0.03314729,,,,,,,,,,,,,,,,, -24000,0.058524754,0.030272232,,,,,,,,,,,,,,,,, -24100,0.050826687,0.030897541,,,,,,,,,,,,,,,,, -24200,0.049700536,0.027903628,,,,,,,,,,,,,,,,, -24300,0.050381698,0.027689187,,,,,,,,,,,,,,,,, -24400,0.057212252,0.02702866,,,,,,,,,,,,,,,,, -24500,0.06929131,0.027166389,,,,,,,,,,,,,,,,, -24600,0.06336675,0.030926973,,,,,,,,,,,,,,,,, -24604,,,0.9913958311080932,0.0279237478971481,0.4629627920889729,0.9869635701179504,0.0436313673853874,0.2792550134875484,43793.0,0.9861123561859132,0.046518225222826,0.2715248516894235,43793.0,7703.395177125931,12204.99735713005,7703.395177125931,4500.043113708496,0.9227774143218994,0.0 -24700,0.06890394,0.030488024,,,,,,,,,,,,,,,,, -24800,0.04900475,0.02896043,,,,,,,,,,,,,,,,, -24900,0.05417435,0.032687083,,,,,,,,,,,,,,,,, -25000,0.067431465,0.027826866,,,,,,,,,,,,,,,,, -25100,0.07130716,0.032801062,,,,,,,,,,,,,,,,, -25200,0.057364333,0.02886789,,,,,,,,,,,,,,,,, -25300,0.051410727,0.030063393,,,,,,,,,,,,,,,,, -25376,,,0.9914031624794006,0.0279384851455688,0.4564729812272185,0.9869404435157776,0.0441245399415493,0.2799870684061308,43793.0,0.9861026406288148,0.0467845723032951,0.2682354559966355,43793.0,7943.4845950603485,12569.297526597977,7943.4845950603485,4624.204813718796,0.9522502422332764,0.0 -25400,0.07494596,0.029565249,,,,,,,,,,,,,,,,, -25500,0.0693199,0.03197351,,,,,,,,,,,,,,,,, -25600,0.041971378,0.030093716,,,,,,,,,,,,,,,,, -25700,0.049194388,0.028197281,,,,,,,,,,,,,,,,, -25800,0.048890326,0.030399011,,,,,,,,,,,,,,,,, -25900,0.054188535,0.029319057,,,,,,,,,,,,,,,,, -26000,0.04724151,0.028431866,,,,,,,,,,,,,,,,, -26100,0.05447435,0.029832536,,,,,,,,,,,,,,,,, -26141,,,0.9914870262145996,0.0275341682136058,0.4767759712694727,0.9869298934936525,0.0439975336194038,0.2779713734008719,43793.0,0.986050009727478,0.0468300879001617,0.2667560083856059,43793.0,8183.584435462952,12939.556988477709,8183.584435462952,4754.313112735748,0.9845142364501952,0.0 -26200,0.048730567,0.030783122,,,,,,,,,,,,,,,,, -26300,0.06681435,0.0305561,,,,,,,,,,,,,,,,, -26400,0.07093395,0.0343228,,,,,,,,,,,,,,,,, -26500,0.07787973,0.026356936,,,,,,,,,,,,,,,,, -26600,0.057598572,0.035422456,,,,,,,,,,,,,,,,, -26700,0.05171206,0.029390289,,,,,,,,,,,,,,,,, -26800,0.049799237,0.031746726,,,,,,,,,,,,,,,,, -26900,0.05369537,0.029750802,,,,,,,,,,,,,,,,, -26909,,,0.9915395975112916,0.0272392891347408,0.4861751748239109,0.9869457483291626,0.0440063215792179,0.2838137070427076,43793.0,0.9861173629760742,0.0468805469572544,0.2717513397381658,43793.0,8423.600947618484,13305.885575532911,8423.600947618484,4880.574460268021,1.0156822204589844,0.0 -27000,0.050074715,0.028549915,,,,,,,,,,,,,,,,, -27100,0.06219063,0.035926607,,,,,,,,,,,,,,,,, -27200,0.06722295,0.034623925,,,,,,,,,,,,,,,,, -27300,0.06572301,0.031748626,,,,,,,,,,,,,,,,, -27400,0.08149717,0.02896489,,,,,,,,,,,,,,,,, -27500,0.053040102,0.027928067,,,,,,,,,,,,,,,,, -27600,0.049207944,0.032305572,,,,,,,,,,,,,,,,, -27679,,,0.991851568222046,0.0263207294046878,0.5115942328205016,0.9870147109031676,0.043874554336071,0.2798143819556658,43793.0,0.9861114621162416,0.046661589294672,0.274525427657463,43793.0,8663.806197404861,13674.063752174376,8663.806197404861,5008.497307538986,1.046447992324829,0.0 -27700,0.06053839,0.03309464,,,,,,,,,,,,,,,,, -27800,0.055426784,0.029737834,,,,,,,,,,,,,,,,, -27900,0.050087146,0.031159954,,,,,,,,,,,,,,,,, -28000,0.05849975,0.028943546,,,,,,,,,,,,,,,,, -28100,0.052726157,0.02901076,,,,,,,,,,,,,,,,, -28200,0.06929114,0.029723378,,,,,,,,,,,,,,,,, -28300,0.052000523,0.026962737,,,,,,,,,,,,,,,,, -28400,0.062760875,0.031571105,,,,,,,,,,,,,,,,, -28450,,,0.9919523000717164,0.0260404907166957,0.5119545931912086,0.9868450164794922,0.0437918752431869,0.2751911119907508,43793.0,0.986026406288147,0.0468337163329124,0.2677213978781376,43793.0,8903.79060792923,14037.0665640831,8903.79060792923,5131.4640600681305,1.0783584117889404,0.0 -28500,0.051024437,0.026766669,,,,,,,,,,,,,,,,, -28600,0.05798594,0.03181734,,,,,,,,,,,,,,,,, -28700,0.060013495,0.02632024,,,,,,,,,,,,,,,,, -28800,0.050979152,0.024083078,,,,,,,,,,,,,,,,, -28900,0.047775175,0.02557471,,,,,,,,,,,,,,,,, -29000,0.064228624,0.031661842,,,,,,,,,,,,,,,,, -29100,0.059717674,0.02979697,,,,,,,,,,,,,,,,, -29200,0.052530553,0.03195212,,,,,,,,,,,,,,,,, -29222,,,0.9917731881141664,0.0265221875160932,0.4974743077792137,0.9870484471321106,0.0440644659101963,0.2817265212932616,43793.0,0.9861489534378052,0.0470454916357994,0.2713569522959075,43793.0,9143.852503299711,14402.134669065475,9143.852503299711,5256.418573856354,1.1103498935699463,0.0 -29300,0.052682105,0.029239224,,,,,,,,,,,,,,,,, -29400,0.05758921,0.030507073,,,,,,,,,,,,,,,,, -29500,0.051510215,0.026337344,,,,,,,,,,,,,,,,, -29600,0.051778264,0.029501855,,,,,,,,,,,,,,,,, -29700,0.056094445,0.028647477,,,,,,,,,,,,,,,,, -29800,0.06321388,0.027108032,,,,,,,,,,,,,,,,, -29900,0.055989727,0.027314944,,,,,,,,,,,,,,,,, -29990,,,0.9917317628860474,0.026847893372178,0.4832345369128923,0.9871158003807068,0.043657187372446,0.274096221644849,43793.0,0.98624587059021,0.0466033667325973,0.2732212950992574,43793.0,9383.99340891838,14768.02099108696,9383.99340891838,5382.11224770546,1.1425657272338867,0.0 -30000,0.05446308,0.029516092,,,,,,,,,,,,,,,,, -30100,0.062068008,0.029093206,,,,,,,,,,,,,,,,, -30200,0.052536957,0.025111658,,,,,,,,,,,,,,,,, -30300,0.05428816,0.03293065,,,,,,,,,,,,,,,,, -30400,0.051603545,0.027474742,,,,,,,,,,,,,,,,, -30500,0.0668487,0.0314841,,,,,,,,,,,,,,,,, -30600,0.060682144,0.029959707,,,,,,,,,,,,,,,,, -30700,0.07208224,0.03116318,,,,,,,,,,,,,,,,, -30754,,,0.9916686415672302,0.0270208194851875,0.4869654090512664,0.9869489669799804,0.0441926978528499,0.2785772326789663,43793.0,0.9861165285110474,0.0470705963671207,0.273177578598428,43793.0,9624.226995944977,15130.782220840454,9624.226995944977,5504.588839054108,1.1740312576293943,0.0 -30800,0.06266995,0.026697163,,,,,,,,,,,,,,,,, -30900,0.052852485,0.027372377,,,,,,,,,,,,,,,,, -31000,0.05890645,0.02760518,,,,,,,,,,,,,,,,, -31100,0.067149736,0.029910084,,,,,,,,,,,,,,,,, -31200,0.061933402,0.030872038,,,,,,,,,,,,,,,,, -31300,0.052509554,0.029200034,,,,,,,,,,,,,,,,, -31400,0.061782338,0.0258846,,,,,,,,,,,,,,,,, -31500,0.051411044,0.027510703,,,,,,,,,,,,,,,,, -31525,,,0.9916601181030272,0.026939183473587,0.4903063582821136,0.9869124293327332,0.0442216247320175,0.2800932828179595,43793.0,0.9860710501670836,0.0469594039022922,0.2723993864317921,43793.0,9864.3176279068,15497.550547361374,9864.3176279068,5631.215336561203,1.205542802810669,0.0 -31600,0.06688832,0.029890137,,,,,,,,,,,,,,,,, -31700,0.07183541,0.03129517,,,,,,,,,,,,,,,,, -31800,0.063992836,0.030735426,,,,,,,,,,,,,,,,, -31900,0.051494826,0.02574369,,,,,,,,,,,,,,,,, -32000,0.056852505,0.029663898,,,,,,,,,,,,,,,,, -32100,0.060611952,0.029288193,,,,,,,,,,,,,,,,, -32200,0.07988832,0.03236002,,,,,,,,,,,,,,,,, -32293,,,0.9916957020759584,0.0268357768654823,0.4875452796120583,0.986953854560852,0.0439895987510681,0.2828627977718431,43793.0,0.9861329793930054,0.0467185899615287,0.2727999380156504,43793.0,10104.432535648346,15862.876982688904,10104.432535648346,5756.375189065933,1.2379601001739502,0.0 -32300,0.05439977,0.027288903,,,,,,,,,,,,,,,,, -32400,0.07285287,0.027554326,,,,,,,,,,,,,,,,, -32500,0.06203504,0.029982358,,,,,,,,,,,,,,,,, -32600,0.07264437,0.030222826,,,,,,,,,,,,,,,,, -32700,0.07091462,0.028924555,,,,,,,,,,,,,,,,, -32800,0.057396382,0.029429572,,,,,,,,,,,,,,,,, -32900,0.057061225,0.030097881,,,,,,,,,,,,,,,,, -33000,0.0659284,0.025893504,,,,,,,,,,,,,,,,, -33063,,,0.9917645454406738,0.0263966470956802,0.4937483853881236,0.9870102405548096,0.0440622717142105,0.2843081741127006,43793.0,0.986196994781494,0.0469749793410301,0.272280218184706,43793.0,10344.389446496964,16231.720217704771,10344.389446496964,5885.21012210846,1.26981782913208,0.0 -33100,0.05513212,0.027665194,,,,,,,,,,,,,,,,, -33200,0.06578562,0.031110397,,,,,,,,,,,,,,,,, -33300,0.054824714,0.027237533,,,,,,,,,,,,,,,,, -33400,0.059106786,0.027121695,,,,,,,,,,,,,,,,, -33500,0.0656229,0.026948178,,,,,,,,,,,,,,,,, -33600,0.057708584,0.031214545,,,,,,,,,,,,,,,,, -33700,0.067214906,0.02687914,,,,,,,,,,,,,,,,, -33800,0.054987106,0.02877132,,,,,,,,,,,,,,,,, -33835,,,0.9917652606964112,0.0265411045402288,0.4958937299340971,0.9870187640190125,0.0440790615975856,0.2793230355958955,43793.0,0.986108124256134,0.0469017885625362,0.2736489972800847,43793.0,10584.439611911774,16596.966630220413,10584.439611911774,6010.354276895523,1.3019187450408936,0.0 -33900,0.07085359,0.029647838,,,,,,,,,,,,,,,,, -34000,0.08290451,0.031750657,,,,,,,,,,,,,,,,, -34100,0.07814928,0.027801348,,,,,,,,,,,,,,,,, -34200,0.0547395,0.026151573,,,,,,,,,,,,,,,,, -34300,0.06740299,0.03159258,,,,,,,,,,,,,,,,, -34400,0.06437163,0.030778285,,,,,,,,,,,,,,,,, -34500,0.05924527,0.030639177,,,,,,,,,,,,,,,,, -34600,0.08146059,0.029084092,,,,,,,,,,,,,,,,, -34606,,,0.9919253587722778,0.0258921831846237,0.5138001849053772,0.9870142936706544,0.0442977733910083,0.2835272431791837,43793.0,0.9861780405044556,0.0471189729869365,0.2752511972446212,43793.0,10824.551196575165,16967.56339740753,10824.551196575165,6140.787740945816,1.333965539932251,0.0 -34700,0.064584844,0.03074264,,,,,,,,,,,,,,,,, -34800,0.07078976,0.026700377,,,,,,,,,,,,,,,,, -34900,0.06224239,0.025798978,,,,,,,,,,,,,,,,, -35000,0.06167568,0.029633585,,,,,,,,,,,,,,,,, -35100,0.059505444,0.024282375,,,,,,,,,,,,,,,,, -35200,0.06823205,0.030207274,,,,,,,,,,,,,,,,, -35300,0.06739341,0.027172955,,,,,,,,,,,,,,,,, -35370,,,0.99210524559021,0.0254258923232555,0.520013516105533,0.9870370626449584,0.0441921800374984,0.2838264860463516,43793.0,0.9861679077148438,0.0470642894506454,0.2723186520389838,43793.0,11064.743356227877,17333.15499162674,11064.743356227877,6266.13445353508,1.3671300411224363,0.0 -35400,0.069651976,0.025841936,,,,,,,,,,,,,,,,, -35500,0.057345472,0.02871741,,,,,,,,,,,,,,,,, -35600,0.06349799,0.028278794,,,,,,,,,,,,,,,,, -35700,0.050988898,0.026144607,,,,,,,,,,,,,,,,, -35800,0.06466459,0.026217418,,,,,,,,,,,,,,,,, -35900,0.082367875,0.027203528,,,,,,,,,,,,,,,,, -36000,0.06677707,0.02932123,,,,,,,,,,,,,,,,, -36100,0.06650472,0.027748957,,,,,,,,,,,,,,,,, -36133,,,0.9921540021896362,0.0252154488116502,0.5257004962214487,0.9870886206626892,0.0442344062030315,0.2862508491770593,43793.0,0.9861965775489808,0.0474430732429027,0.2730475080736813,43793.0,11304.83763718605,17697.271452903748,11304.83763718605,6390.104465007782,1.4000084400177002,0.0 -36200,0.063424885,0.029210186,,,,,,,,,,,,,,,,, -36300,0.07939912,0.033724394,,,,,,,,,,,,,,,,, -36400,0.07168526,0.025924144,,,,,,,,,,,,,,,,, -36500,0.06341091,0.02728876,,,,,,,,,,,,,,,,, -36600,0.07177734,0.028642308,,,,,,,,,,,,,,,,, -36700,0.06513672,0.028566653,,,,,,,,,,,,,,,,, -36800,0.06628231,0.029207885,,,,,,,,,,,,,,,,, -36900,0.06597647,0.025595067,,,,,,,,,,,,,,,,, -36907,,,0.9923199415206908,0.0244986284524202,0.5594185517547302,0.9871170520782472,0.0442498102784156,0.2851949177370656,43793.0,0.9861688017845154,0.0475516207516193,0.2679133358897493,43793.0,11545.093611717224,18061.448248147964,11545.093611717224,6513.972229719162,1.4337913990020752,0.0 -37000,0.059119638,0.027538296,,,,,,,,,,,,,,,,, -37100,0.06570045,0.025445938,,,,,,,,,,,,,,,,, -37200,0.06788597,0.029427381,,,,,,,,,,,,,,,,, -37300,0.06764655,0.027303552,,,,,,,,,,,,,,,,, -37400,0.08968859,0.028090611,,,,,,,,,,,,,,,,, -37500,0.06296309,0.02920105,,,,,,,,,,,,,,,,, -37600,0.06563954,0.029880323,,,,,,,,,,,,,,,,, -37673,,,0.9923747181892396,0.0244077499955892,0.5358389174143836,0.9870086312294006,0.0447503253817558,0.2843611275052901,43793.0,0.9860794544219972,0.0477609112858772,0.2707454139816693,43793.0,11785.12488269806,18424.14876651764,11785.12488269806,6636.588159322739,1.467538833618164,0.0 -37700,0.06355555,0.027859436,,,,,,,,,,,,,,,,, -37800,0.06648535,0.02662895,,,,,,,,,,,,,,,,, -37900,0.062979564,0.028692609,,,,,,,,,,,,,,,,, -38000,0.06942007,0.02998986,,,,,,,,,,,,,,,,, -38100,0.0651289,0.026674816,,,,,,,,,,,,,,,,, -38200,0.066659115,0.02712325,,,,,,,,,,,,,,,,, -38300,0.07672766,0.029724566,,,,,,,,,,,,,,,,, -38400,0.062267922,0.02664675,,,,,,,,,,,,,,,,, -38439,,,0.992568016052246,0.0237795822322368,0.5684574426282603,0.9871584177017212,0.044204156845808,0.28866362399726,43793.0,0.9862340688705444,0.0471976324915885,0.2755533497816337,43793.0,12025.290924072266,18788.307156085968,12025.290924072266,6760.528062582016,1.5004465579986572,0.0 -38500,0.06450726,0.026971392,,,,,,,,,,,,,,,,, -38600,0.06452439,0.025636058,,,,,,,,,,,,,,,,, -38700,0.06494483,0.028798193,,,,,,,,,,,,,,,,, -38800,0.06403207,0.028911987,,,,,,,,,,,,,,,,, -38900,0.069937244,0.02473295,,,,,,,,,,,,,,,,, -39000,0.07181739,0.028304739,,,,,,,,,,,,,,,,, -39100,0.070434526,0.02991578,,,,,,,,,,,,,,,,, -39200,0.057871874,0.0250429,,,,,,,,,,,,,,,,, -39209,,,0.99250328540802,0.0239628404378891,0.5541409386821077,0.98708575963974,0.0444251894950866,0.283720789086509,43793.0,0.9862378239631652,0.0474700033664703,0.2816381130732368,43793.0,12265.243272781372,19151.94514513016,12265.243272781372,6884.158543348312,1.536177396774292,0.0 -39300,0.064398885,0.027008181,,,,,,,,,,,,,,,,, -39400,0.070892096,0.024391253,,,,,,,,,,,,,,,,, -39500,0.070938334,0.026453724,,,,,,,,,,,,,,,,, -39600,0.069964305,0.026869435,,,,,,,,,,,,,,,,, -39700,0.06680574,0.029370653,,,,,,,,,,,,,,,,, -39800,0.06829277,0.02824937,,,,,,,,,,,,,,,,, -39900,0.074742824,0.028410554,,,,,,,,,,,,,,,,, -39979,,,0.9925006031990052,0.0241887252777814,0.5570633831413991,0.9870301485061646,0.0446991473436355,0.2828670110903142,43793.0,0.9860883355140686,0.0477647967636585,0.2842939134219475,43793.0,12505.26472067833,19515.94053053856,12505.26472067833,7008.079576253891,1.56976318359375,0.0 -40000,0.068798915,0.028964903,,,,,,,,,,,,,,,,, -40100,0.08942321,0.025304154,,,,,,,,,,,,,,,,, -40200,0.07206002,0.025801962,,,,,,,,,,,,,,,,, -40300,0.0664175,0.025615666,,,,,,,,,,,,,,,,, -40400,0.10133054,0.028942775,,,,,,,,,,,,,,,,, -40500,0.073964484,0.027114615,,,,,,,,,,,,,,,,, -40600,0.086399004,0.028729491,,,,,,,,,,,,,,,,, -40700,0.057378013,0.024737233,,,,,,,,,,,,,,,,, -40717,,,0.9924003481864928,0.0243264287710189,0.5494295396144823,0.9869505763053894,0.0445120818912982,0.2806861022966553,43793.0,0.9861072897911072,0.0472981296479702,0.2748004192246521,43793.0,12745.222152233124,19879.48866820336,12745.222152233124,7131.618180990219,1.6031405925750732,0.0 -40800,0.07821153,0.024849722,,,,,,,,,,,,,,,,, -40900,0.071870364,0.02868802,,,,,,,,,,,,,,,,, -41000,0.06634535,0.02516034,,,,,,,,,,,,,,,,, -41100,0.06945185,0.024222579,,,,,,,,,,,,,,,,, -41200,0.06486697,0.02662655,,,,,,,,,,,,,,,,, -41300,0.09720765,0.026879255,,,,,,,,,,,,,,,,, -41400,0.071530916,0.022899447,,,,,,,,,,,,,,,,, -41485,,,0.9922370314598083,0.0248098876327276,0.5382047059952417,0.9869790077209472,0.0448886454105377,0.2904843775330589,43793.0,0.986151933670044,0.0478561930358409,0.2822238579045632,43793.0,12985.309354305267,20247.44675898552,12985.309354305267,7259.434130430222,1.6383109092712402,0.0 -41500,0.070230015,0.022366904,,,,,,,,,,,,,,,,, -41600,0.061383225,0.025607655,,,,,,,,,,,,,,,,, -41700,0.07639502,0.028554644,,,,,,,,,,,,,,,,, -41800,0.09422529,0.023540841,,,,,,,,,,,,,,,,, -41900,0.094626896,0.025723943,,,,,,,,,,,,,,,,, -42000,0.06458985,0.023396663,,,,,,,,,,,,,,,,, -42100,0.071683146,0.02522218,,,,,,,,,,,,,,,,, -42200,0.06741954,0.02661529,,,,,,,,,,,,,,,,, -42250,,,0.992445707321167,0.0241518169641494,0.5460987084725499,0.98708575963974,0.0448537394404411,0.287187447261353,43793.0,0.9862955808639526,0.0478070452809333,0.2830454836599912,43793.0,13225.327796459198,20614.092463970184,13225.327796459198,7386.007725477219,1.672457218170166,0.0 -42300,0.08510759,0.027894758,,,,,,,,,,,,,,,,, -42400,0.08129177,0.027161777,,,,,,,,,,,,,,,,, -42500,0.08787802,0.02537116,,,,,,,,,,,,,,,,, -42600,0.07973033,0.02626128,,,,,,,,,,,,,,,,, -42700,0.07204375,0.024726091,,,,,,,,,,,,,,,,, -42800,0.0707091,0.02967842,,,,,,,,,,,,,,,,, -42900,0.07671857,0.02544664,,,,,,,,,,,,,,,,, -43000,0.068310685,0.024933953,,,,,,,,,,,,,,,,, -43008,,,0.9924915432929992,0.0239395014941692,0.5532173934177393,0.9871170520782472,0.0446819961071014,0.2867559415482638,43793.0,0.9862471222877502,0.0477485209703445,0.2772697280132931,43793.0,13465.478171348572,20988.90712785721,13465.478171348572,7520.616132259369,1.7076139450073242,0.0 -43100,0.0724648,0.027261352,,,,,,,,,,,,,,,,, -43200,0.07015997,0.025812976,,,,,,,,,,,,,,,,, -43300,0.06157221,0.024664873,,,,,,,,,,,,,,,,, -43400,0.08066991,0.026055928,,,,,,,,,,,,,,,,, -43500,0.0783612,0.025789557,,,,,,,,,,,,,,,,, -43600,0.07810655,0.027507147,,,,,,,,,,,,,,,,, -43700,0.07057306,0.025776327,,,,,,,,,,,,,,,,, -43769,,,0.9925384521484376,0.023691838607192,0.5701476633664833,0.9872027039527892,0.0447466224431991,0.289002886959356,43793.0,0.986272394657135,0.0479588136076927,0.2796765902545455,43793.0,13705.421085119247,21352.05406999588,13705.421085119247,7643.762209892273,1.746145725250244,0.0 -43800,0.0701807,0.025502566,,,,,,,,,,,,,,,,, -43900,0.06979494,0.02541484,,,,,,,,,,,,,,,,, -44000,0.06745691,0.024119573,,,,,,,,,,,,,,,,, -44100,0.082260616,0.029720591,,,,,,,,,,,,,,,,, -44200,0.07689688,0.026284995,,,,,,,,,,,,,,,,, -44300,0.07564869,0.024680413,,,,,,,,,,,,,,,,, -44400,0.072455496,0.027125219,,,,,,,,,,,,,,,,, -44500,0.079741016,0.024005873,,,,,,,,,,,,,,,,, -44535,,,0.9927001595497132,0.023185497149825,0.5688899401380112,0.9870370626449584,0.0447141453623771,0.2934221886301122,43793.0,0.9860782027244568,0.047963697463274,0.2759291178161914,43793.0,13945.649666070938,21715.24806857109,13945.649666070938,7766.672614097595,1.7812371253967283,0.0 -44600,0.08987566,0.027989978,,,,,,,,,,,,,,,,, -44700,0.089505315,0.02704855,,,,,,,,,,,,,,,,, -44800,0.0878582,0.025929643,,,,,,,,,,,,,,,,, -44900,0.078277715,0.024638288,,,,,,,,,,,,,,,,, -45000,0.07977866,0.026827073,,,,,,,,,,,,,,,,, -45100,0.09274688,0.0293265,,,,,,,,,,,,,,,,, -45200,0.08339694,0.027573317,,,,,,,,,,,,,,,,, -45299,,,0.9928914308547974,0.0225920602679252,0.5915234600156412,0.9871146082878112,0.0449960939586162,0.2908628579927392,43793.0,0.9862774610519408,0.0480165034532547,0.2768424105806721,43793.0,14185.667350292206,22081.31978273392,14185.667350292206,7892.67279791832,1.8158187866210933,0.0 -45300,0.07512229,0.023699075,,,,,,,,,,,,,,,,, -45400,0.06513177,0.022214476,,,,,,,,,,,,,,,,, -45500,0.088367775,0.025793634,,,,,,,,,,,,,,,,, -45600,0.094617166,0.030005526,,,,,,,,,,,,,,,,, -45700,0.06614562,0.023446832,,,,,,,,,,,,,,,,, -45800,0.07758399,0.024254085,,,,,,,,,,,,,,,,, -45900,0.07614008,0.023403395,,,,,,,,,,,,,,,,, -46000,0.088155106,0.026357602,,,,,,,,,,,,,,,,, -46072,,,0.9930112957954408,0.0221614427864551,0.5958331862615327,0.98707115650177,0.0450821332633495,0.2912427689904615,43793.0,0.986213445663452,0.0481310449540615,0.2777517498001093,43793.0,14425.91460299492,22448.270634174347,14425.91460299492,8019.319634437561,1.8533244132995603,0.0 -46100,0.090799436,0.025981152,,,,,,,,,,,,,,,,, -46200,0.08414404,0.023912776,,,,,,,,,,,,,,,,, -46300,0.06765373,0.020909738,,,,,,,,,,,,,,,,, -46400,0.08706038,0.024771811,,,,,,,,,,,,,,,,, -46500,0.071483836,0.021939494,,,,,,,,,,,,,,,,, -46600,0.0823031,0.023732284,,,,,,,,,,,,,,,,, -46700,0.07026024,0.022876507,,,,,,,,,,,,,,,,, -46800,0.07942339,0.024696333,,,,,,,,,,,,,,,,, -46838,,,0.9932113885879515,0.0215078685432672,0.6120571175636954,0.9871028065681458,0.0451406985521316,0.2953311430028864,43793.0,0.9862639904022216,0.0482524186372756,0.278822280780592,43793.0,14666.095184326172,22812.28260588646,14666.095184326172,8143.096262216568,1.8887808322906487,0.0 -46900,0.0768925,0.025275305,,,,,,,,,,,,,,,,, -47000,0.06972937,0.021495374,,,,,,,,,,,,,,,,, -47100,0.1076017,0.028539008,,,,,,,,,,,,,,,,, -47200,0.07639225,0.022854434,,,,,,,,,,,,,,,,, -47300,0.079163864,0.026885195,,,,,,,,,,,,,,,,, -47400,0.07926086,0.022707133,,,,,,,,,,,,,,,,, -47500,0.08191665,0.026284963,,,,,,,,,,,,,,,,, -47600,0.08989582,0.02506862,,,,,,,,,,,,,,,,, -47607,,,0.9933398962020874,0.0211241487413644,0.6331789829490917,0.9870001077651978,0.0456755384802818,0.2915323484639742,43793.0,0.9861923456192015,0.0486282743513584,0.2688196035004163,43793.0,14906.135026216509,23175.271939516068,14906.135026216509,8265.990796804428,1.9238817691802976,0.0 -47700,0.078990854,0.022342926,,,,,,,,,,,,,,,,, -47800,0.07205147,0.023189573,,,,,,,,,,,,,,,,, -47900,0.093731746,0.02817454,,,,,,,,,,,,,,,,, -48000,0.077789225,0.02249642,,,,,,,,,,,,,,,,, -48100,0.08470778,0.024848782,,,,,,,,,,,,,,,,, -48200,0.0857319,0.026566721,,,,,,,,,,,,,,,,, -48300,0.09715331,0.025449304,,,,,,,,,,,,,,,,, -48379,,,0.9934019446372986,0.0209897384047508,0.6255306974010649,0.9870269298553468,0.0455353967845439,0.2903014903988803,43793.0,0.9862968325614928,0.0484311506152153,0.2783990005573538,43793.0,15146.269136428831,23535.936802864075,15146.269136428831,8386.46659874916,1.958917617797852,0.0 -48400,0.09162096,0.02438278,,,,,,,,,,,,,,,,, -48500,0.08410693,0.021306306,,,,,,,,,,,,,,,,, -48600,0.100245856,0.027722746,,,,,,,,,,,,,,,,, -48700,0.07474046,0.02245579,,,,,,,,,,,,,,,,, -48800,0.078963436,0.020557426,,,,,,,,,,,,,,,,, -48900,0.085409455,0.025814762,,,,,,,,,,,,,,,,, -49000,0.08580166,0.023041157,,,,,,,,,,,,,,,,, -49100,0.09742782,0.025091762,,,,,,,,,,,,,,,,, -49148,,,0.9931447505950928,0.0217654425650835,0.6131731781560671,0.986970067024231,0.0455004014074802,0.2893361191203478,43793.0,0.9860782027244568,0.0487363375723362,0.2771831528863973,43793.0,15386.475734949112,23896.030689239506,15386.475734949112,8506.296981334686,1.9959444999694824,0.0 -49200,0.088224255,0.02527732,,,,,,,,,,,,,,,,, -49300,0.10060071,0.027239371,,,,,,,,,,,,,,,,, -49400,0.08964201,0.024529971,,,,,,,,,,,,,,,,, -49500,0.095934466,0.024517352,,,,,,,,,,,,,,,,, -49600,0.08850791,0.02227036,,,,,,,,,,,,,,,,, -49700,0.08417005,0.02496893,,,,,,,,,,,,,,,,, -49800,0.08374853,0.023547104,,,,,,,,,,,,,,,,, -49900,0.092878826,0.023654133,,,,,,,,,,,,,,,,, -49921,,,0.9931887984275818,0.0215563420206308,0.5925207424069308,0.9870281219482422,0.0456475801765918,0.2885576907983909,43793.0,0.986199915409088,0.0487848855555057,0.2782597699814962,43793.0,15626.695341348648,24260.37875604629,15626.695341348648,8630.36907529831,2.0322859287261963,0.0 -50000,0.09034279,0.0223434,,,,,,,,,,,,,,,,, -50100,0.09659864,0.02483826,,,,,,,,,,,,,,,,, -50200,0.08479161,0.023927301,,,,,,,,,,,,,,,,, -50300,0.09598785,0.02585023,,,,,,,,,,,,,,,,, -50400,0.08938914,0.02489834,,,,,,,,,,,,,,,,, -50500,0.08726856,0.024436723,,,,,,,,,,,,,,,,, -50600,0.0908073,0.024771782,,,,,,,,,,,,,,,,, -50693,,,0.9932281374931335,0.0214937105774879,0.6125680428291371,0.9870500564575196,0.0457317978143692,0.2902672784161562,43793.0,0.9861531853675842,0.0487932786345481,0.2808486112206802,43793.0,15866.866524219511,24623.690307617188,15866.866524219511,8753.453888893127,2.068065881729126,0.0 -50700,0.10073015,0.025111133,,,,,,,,,,,,,,,,, -50800,0.09652251,0.026663603,,,,,,,,,,,,,,,,, -50900,0.093948066,0.02358647,,,,,,,,,,,,,,,,, -51000,0.09172083,0.02293219,,,,,,,,,,,,,,,,, -51100,0.10781182,0.027563706,,,,,,,,,,,,,,,,, -51200,0.08306159,0.024519002,,,,,,,,,,,,,,,,, -51300,0.09262247,0.022232082,,,,,,,,,,,,,,,,, -51400,0.09834774,0.026211664,,,,,,,,,,,,,,,,, -51460,,,0.9931707382202148,0.0215623006224632,0.6086612239107657,0.9869388341903688,0.0458887852728366,0.2875652110156065,43793.0,0.9861097931861876,0.0489931516349315,0.2756912633800584,43793.0,16106.96294260025,24981.755274295807,16106.96294260025,8871.36655497551,2.1043295860290527,0.0 -51500,0.09150203,0.023245804,,,,,,,,,,,,,,,,, -51600,0.1050427,0.023180379,,,,,,,,,,,,,,,,, -51700,0.12194751,0.024826085,,,,,,,,,,,,,,,,, -51800,0.08473392,0.022944424,,,,,,,,,,,,,,,,, -51900,0.09771656,0.026028715,,,,,,,,,,,,,,,,, -52000,0.096427634,0.023738263,,,,,,,,,,,,,,,,, -52100,0.08544656,0.0239584,,,,,,,,,,,,,,,,, -52200,0.08495703,0.021166489,,,,,,,,,,,,,,,,, -52231,,,0.9933059811592102,0.0211930908262729,0.6209554499918561,0.9869749546051024,0.0457281917333602,0.2894453169966877,43793.0,0.986139714717865,0.0490336157381534,0.2792324231384169,43793.0,16346.9478225708,25344.649944782257,16346.9478225708,8994.220281600952,2.1409432888031006,0.0 -52300,0.098696105,0.023856804,,,,,,,,,,,,,,,,, -52400,0.095350385,0.023707397,,,,,,,,,,,,,,,,, -52500,0.092473894,0.023779761,,,,,,,,,,,,,,,,, -52600,0.10570959,0.023254234,,,,,,,,,,,,,,,,, -52700,0.094325185,0.024319507,,,,,,,,,,,,,,,,, -52800,0.08338653,0.022100285,,,,,,,,,,,,,,,,, -52900,0.10347526,0.024659075,,,,,,,,,,,,,,,,, -53000,0.08773504,0.021656757,,,,,,,,,,,,,,,,, -53001,,,0.9933321475982666,0.0208378545939922,0.6346190961639615,0.9871361255645752,0.0464642457664012,0.2905150610562712,43793.0,0.9862247705459596,0.0497451350092887,0.2790904600309866,43793.0,16586.905706882477,25707.349474668503,16586.905706882477,9116.906694173813,2.1766245365142822,0.0 -53100,0.0914059,0.022812609,,,,,,,,,,,,,,,,, -53200,0.102574445,0.02550683,,,,,,,,,,,,,,,,, -53300,0.10237946,0.023508126,,,,,,,,,,,,,,,,, -53400,0.09827279,0.023786744,,,,,,,,,,,,,,,,, -53500,0.085812986,0.020496648,,,,,,,,,,,,,,,,, -53600,0.099862896,0.024339156,,,,,,,,,,,,,,,,, -53700,0.09129014,0.02179615,,,,,,,,,,,,,,,,, -53745,,,0.9934971332550048,0.0204658973962068,0.630007503667221,0.9869733452796936,0.0461771227419376,0.2869864873591823,43793.0,0.9861270785331726,0.0495319850742816,0.2741914971743095,43793.0,16827.004104614258,26069.652314662933,16827.004104614258,9239.053257226944,2.212446451187134,0.0 -53800,0.11306518,0.024706123,,,,,,,,,,,,,,,,, -53900,0.095846795,0.023239104,,,,,,,,,,,,,,,,, -54000,0.10508429,0.023231804,,,,,,,,,,,,,,,,, -54100,0.10456322,0.024188392,,,,,,,,,,,,,,,,, -54200,0.1015193,0.025590625,,,,,,,,,,,,,,,,, -54300,0.10153379,0.023780666,,,,,,,,,,,,,,,,, -54400,0.10181174,0.022857865,,,,,,,,,,,,,,,,, -54500,0.0841515,0.02000567,,,,,,,,,,,,,,,,, -54511,,,0.9936056733131408,0.0200942922383546,0.6474843711686171,0.9870001077651978,0.0466343276202678,0.2924957112403711,43793.0,0.986234486103058,0.0499737039208412,0.275257343295017,43793.0,17067.124797344208,26427.993901014328,17067.124797344208,9357.21773147583,2.249485731124878,0.0 -54600,0.094153844,0.021848053,,,,,,,,,,,,,,,,, -54700,0.10031427,0.022871798,,,,,,,,,,,,,,,,, -54800,0.09905754,0.024247333,,,,,,,,,,,,,,,,, -54900,0.11206517,0.025448274,,,,,,,,,,,,,,,,, -55000,0.11142069,0.022954475,,,,,,,,,,,,,,,,, -55100,0.09754175,0.021687256,,,,,,,,,,,,,,,,, -55200,0.10030465,0.023146408,,,,,,,,,,,,,,,,, -55261,,,0.9938763976097108,0.0192595422267913,0.6639851968755207,0.987063467502594,0.0468171574175357,0.2882980116065824,43793.0,0.986182689666748,0.0499526932835578,0.2779944627561785,43793.0,17307.159264326096,26793.98760509491,17307.159264326096,9483.120379209518,2.2856380939483643,0.0 -55300,0.09995576,0.023492962,,,,,,,,,,,,,,,,, -55400,0.10088278,0.02474331,,,,,,,,,,,,,,,,, -55500,0.09793782,0.02295066,,,,,,,,,,,,,,,,, -55600,0.1087976,0.02307996,,,,,,,,,,,,,,,,, -55700,0.114197984,0.023226244,,,,,,,,,,,,,,,,, -55800,0.113600574,0.024030974,,,,,,,,,,,,,,,,, -55900,0.08531658,0.018450782,,,,,,,,,,,,,,,,, -56000,0.1276716,0.023348464,,,,,,,,,,,,,,,,, -56031,,,0.9941009283065796,0.0186919644474983,0.6720966607253281,0.9870317578315736,0.0469521507620811,0.2862859302274942,43793.0,0.9861670732498168,0.0500989556312561,0.2774957790292882,43793.0,17547.22730565071,27155.00487589836,17547.22730565071,9604.013299703598,2.322486400604248,0.0 -56100,0.09355043,0.019134816,,,,,,,,,,,,,,,,, -56200,0.100179546,0.023696063,,,,,,,,,,,,,,,,, -56300,0.11896334,0.024355507,,,,,,,,,,,,,,,,, -56400,0.1248439,0.022120789,,,,,,,,,,,,,,,,, -56500,0.09744739,0.019215329,,,,,,,,,,,,,,,,, -56600,0.10725004,0.021626549,,,,,,,,,,,,,,,,, -56700,0.09642575,0.022597484,,,,,,,,,,,,,,,,, -56768,,,0.9941667318344116,0.0182786993682384,0.6992927322278004,0.9869331121444702,0.0469362549483776,0.2908976053726619,43793.0,0.9861464500427246,0.0500405393540859,0.2790140354084867,43793.0,17787.253559589386,27525.11945939064,17787.253559589386,9734.03482222557,2.3649609088897705,0.0 -56800,0.09351907,0.024311773,,,,,,,,,,,,,,,,, -56900,0.10136087,0.019830815,,,,,,,,,,,,,,,,, -57000,0.11943842,0.022296341,,,,,,,,,,,,,,,,, -57100,0.10457981,0.021232143,,,,,,,,,,,,,,,,, -57200,0.10298509,0.021506345,,,,,,,,,,,,,,,,, -57300,0.11065503,0.023653422,,,,,,,,,,,,,,,,, -57400,0.121133514,0.02185238,,,,,,,,,,,,,,,,, -57500,0.1185373,0.023136048,,,,,,,,,,,,,,,,, -57528,,,0.994340181350708,0.0179514028131961,0.6863589255944273,0.9869924187660216,0.0468558855354785,0.2878277138915034,43793.0,0.9862096309661864,0.0500779040157794,0.2783108176366595,43793.0,18027.47919726372,27885.483307123184,18027.47919726372,9854.109743356705,2.407958745956421,0.0 -57600,0.09909903,0.021408403,,,,,,,,,,,,,,,,, -57700,0.11000553,0.021400744,,,,,,,,,,,,,,,,, -57800,0.10163609,0.0198012,,,,,,,,,,,,,,,,, -57900,0.106678374,0.023871684,,,,,,,,,,,,,,,,, -58000,0.12417649,0.022812596,,,,,,,,,,,,,,,,, -58100,0.10319942,0.019436436,,,,,,,,,,,,,,,,, -58200,0.112514205,0.019775951,,,,,,,,,,,,,,,,, -58296,,,0.9941099286079408,0.0186354778707027,0.6760400525936513,0.9870346188545228,0.0470816828310489,0.284984456051923,43793.0,0.9860777854919434,0.0504041463136673,0.2747297670159299,43793.0,18267.49107837677,28253.37156844139,18267.49107837677,9981.927907466888,2.44576358795166,0.0 -58300,0.11541006,0.02094733,,,,,,,,,,,,,,,,, -58400,0.1060137,0.023334317,,,,,,,,,,,,,,,,, -58500,0.10654318,0.021699456,,,,,,,,,,,,,,,,, -58600,0.114459105,0.024552831,,,,,,,,,,,,,,,,, -58700,0.11519977,0.020643938,,,,,,,,,,,,,,,,, -58800,0.12907037,0.022325814,,,,,,,,,,,,,,,,, -58900,0.10509325,0.019321887,,,,,,,,,,,,,,,,, -59000,0.12794906,0.02362139,,,,,,,,,,,,,,,,, -59062,,,0.9940509796142578,0.0185292568057775,0.6800617030023898,0.9870293140411376,0.0475336201488971,0.2872965751470687,43793.0,0.9862083792686462,0.0508651062846183,0.2728735942641346,43793.0,18507.593081474304,28611.847897052765,18507.593081474304,10100.244191408156,2.483938217163086,0.0 -59100,0.12756649,0.02328846,,,,,,,,,,,,,,,,, -59200,0.117779076,0.023890918,,,,,,,,,,,,,,,,, -59300,0.119698435,0.018805727,,,,,,,,,,,,,,,,, -59400,0.11515819,0.023073379,,,,,,,,,,,,,,,,, -59500,0.11142463,0.021763308,,,,,,,,,,,,,,,,, -59600,0.119137414,0.022269296,,,,,,,,,,,,,,,,, -59700,0.13113096,0.023589632,,,,,,,,,,,,,,,,, -59800,0.1130558,0.021828843,,,,,,,,,,,,,,,,, -59830,,,0.9940648674964904,0.0184501875191926,0.6847460980138358,0.9869866967201232,0.0475491769611835,0.2900075163292792,43793.0,0.9861797094345092,0.0508675575256347,0.2815177285535439,43793.0,18747.74967765808,28970.99193549156,18747.74967765808,10219.173105716704,2.522862672805786,0.0 -59900,0.10243682,0.020707043,,,,,,,,,,,,,,,,, -60000,0.10149572,0.01851349,,,,,,,,,,,,,,,,, -60100,0.111736104,0.021856101,,,,,,,,,,,,,,,,, -60200,0.120954864,0.020951202,,,,,,,,,,,,,,,,, -60300,0.11577521,0.021851562,,,,,,,,,,,,,,,,, -60400,0.111908525,0.021380816,,,,,,,,,,,,,,,,, -60500,0.11641806,0.02277258,,,,,,,,,,,,,,,,, -60600,0.12324413,0.022407398,,,,,,,,,,,,,,,,, -60601,,,0.9941041469573976,0.0182408802211284,0.6792949123470332,0.9869810342788696,0.0480771027505397,0.2883566277049855,43793.0,0.9862108826637268,0.0513647943735122,0.277457402308432,43793.0,18987.98977828025,29331.38431572914,18987.98977828025,10339.2678835392,2.560737371444702,0.0 -60700,0.112226866,0.02234615,,,,,,,,,,,,,,,,, -60800,0.14499028,0.023528405,,,,,,,,,,,,,,,,, -60900,0.13065338,0.02108089,,,,,,,,,,,,,,,,, -61000,0.11402299,0.020373672,,,,,,,,,,,,,,,,, -61100,0.1184943,0.020587992,,,,,,,,,,,,,,,,, -61200,0.11737086,0.019155085,,,,,,,,,,,,,,,,, -61300,0.12593475,0.021922141,,,,,,,,,,,,,,,,, -61370,,,0.9942519664764404,0.0178911536931991,0.6891838281450775,0.9869696497917176,0.0480914264917373,0.2907775939415492,43793.0,0.9860908389091492,0.05152552947402,0.2758072130658053,43793.0,19227.93788957596,29688.873893022537,19227.93788957596,10456.751803159714,2.598947525024414,0.0 -61400,0.13446403,0.022235904,,,,,,,,,,,,,,,,, -61500,0.11315589,0.020456608,,,,,,,,,,,,,,,,, -61600,0.13343479,0.019740071,,,,,,,,,,,,,,,,, -61700,0.15120232,0.022513626,,,,,,,,,,,,,,,,, -61800,0.12578082,0.02264031,,,,,,,,,,,,,,,,, -61900,0.10655757,0.02029651,,,,,,,,,,,,,,,,, -62000,0.117178455,0.020831995,,,,,,,,,,,,,,,,, -62100,0.10909317,0.019101234,,,,,,,,,,,,,,,,, -62141,,,0.99410742521286,0.018122250214219,0.6837093136990084,0.98700213432312,0.0485698021948337,0.2882322445128811,43793.0,0.986216366291046,0.0519167222082614,0.2729711537032664,43793.0,19467.98549079895,30049.69030022621,19467.98549079895,10577.463264465332,2.6371090412139893,0.0 -62200,0.1141464,0.01862991,,,,,,,,,,,,,,,,, -62300,0.108363435,0.016976781,,,,,,,,,,,,,,,,, -62400,0.120348245,0.019709326,,,,,,,,,,,,,,,,, -62500,0.14252305,0.0204582,,,,,,,,,,,,,,,,, -62600,0.12485121,0.01785447,,,,,,,,,,,,,,,,, -62700,0.11541976,0.018961376,,,,,,,,,,,,,,,,, -62800,0.12605342,0.019367708,,,,,,,,,,,,,,,,, -62900,0.12072922,0.019180868,,,,,,,,,,,,,,,,, -62914,,,0.9943687319755554,0.0174585618078708,0.710569110240898,0.9870545268058776,0.0485628210008144,0.2889291577087489,43793.0,0.9861868619918824,0.0521451793611049,0.2739277181367657,43793.0,19707.979488134384,30409.38412237168,19707.979488134384,10697.105581521988,2.6750237941741943,0.0 -63000,0.13092414,0.019823838,,,,,,,,,,,,,,,,, -63100,0.11570194,0.020196335,,,,,,,,,,,,,,,,, -63200,0.1182103,0.020460246,,,,,,,,,,,,,,,,, -63300,0.12409711,0.020623466,,,,,,,,,,,,,,,,, -63400,0.12751709,0.02147513,,,,,,,,,,,,,,,,, -63500,0.12937117,0.021358585,,,,,,,,,,,,,,,,, -63600,0.14501248,0.022685518,,,,,,,,,,,,,,,,, -63683,,,0.9945791959762572,0.0168491005897521,0.7172802330353509,0.9870272874832152,0.0488625429570674,0.2913743414060199,43793.0,0.9861218333244324,0.0523505769670009,0.2753913641660267,43793.0,19948.06530070305,30770.02154660225,19948.06530070305,10817.599865198135,2.712806940078736,0.0 -63700,0.12467614,0.020126235,,,,,,,,,,,,,,,,, -63800,0.13719966,0.021862509,,,,,,,,,,,,,,,,, -63900,0.1292014,0.01950606,,,,,,,,,,,,,,,,, -64000,0.123851925,0.017938556,,,,,,,,,,,,,,,,, -64100,0.13570756,0.020063752,,,,,,,,,,,,,,,,, -64200,0.12918049,0.02152888,,,,,,,,,,,,,,,,, -64300,0.1279096,0.021296535,,,,,,,,,,,,,,,,, -64400,0.12001083,0.018504176,,,,,,,,,,,,,,,,, -64451,,,0.9945203065872192,0.0168881583958864,0.7186517583486293,0.9870195984840392,0.0491692908108234,0.2907579213342952,43793.0,0.9862428903579712,0.0527086183428764,0.2760658608501192,43793.0,20188.19465708733,31129.054266929623,20188.19465708733,10936.445219993591,2.7514443397521973,0.0 -64500,0.12412349,0.021311138,,,,,,,,,,,,,,,,, -64600,0.11714761,0.018245814,,,,,,,,,,,,,,,,, -64700,0.13248233,0.018677663,,,,,,,,,,,,,,,,, -64800,0.15656067,0.023500297,,,,,,,,,,,,,,,,, -64900,0.113581695,0.018900506,,,,,,,,,,,,,,,,, -65000,0.13799846,0.01922973,,,,,,,,,,,,,,,,, -65100,0.12960607,0.020149313,,,,,,,,,,,,,,,,, -65200,0.12907541,0.021810226,,,,,,,,,,,,,,,,, -65222,,,0.994799256324768,0.016018958762288,0.7327114869218326,0.987107276916504,0.0491626411676406,0.28769506812764,43793.0,0.9862112998962402,0.0526616871356964,0.2729145729121758,43793.0,20428.312505960464,31488.72131800652,20428.312505960464,11055.935570955276,2.790802240371704,0.0 -65300,0.14120984,0.020606777,,,,,,,,,,,,,,,,, -65400,0.15995224,0.019507462,,,,,,,,,,,,,,,,, -65500,0.12670548,0.016375966,,,,,,,,,,,,,,,,, -65600,0.13889082,0.019285658,,,,,,,,,,,,,,,,, -65700,0.1401391,0.0224044,,,,,,,,,,,,,,,,, -65800,0.13045584,0.018914886,,,,,,,,,,,,,,,,, -65900,0.12741105,0.018862745,,,,,,,,,,,,,,,,, -65996,,,0.9950602650642396,0.0153986578807234,0.7444797052813373,0.9870184063911438,0.0495036095380783,0.2858881628178701,43793.0,0.9861881732940674,0.0529731400310993,0.2759885855868024,43793.0,20668.418220043182,31848.44557189941,20668.418220043182,11175.49587225914,2.8298258781433105,0.0 -66000,0.12712577,0.018824544,,,,,,,,,,,,,,,,, -66100,0.13645542,0.017559242,,,,,,,,,,,,,,,,, -66200,0.1215882,0.017735971,,,,,,,,,,,,,,,,, -66300,0.13291618,0.020699881,,,,,,,,,,,,,,,,, -66400,0.14147748,0.019200876,,,,,,,,,,,,,,,,, -66500,0.15073813,0.019800691,,,,,,,,,,,,,,,,, -66600,0.13680649,0.020132683,,,,,,,,,,,,,,,,, -66700,0.14048678,0.019169131,,,,,,,,,,,,,,,,, -66768,,,0.9951653480529784,0.0151965515688061,0.7628244662934002,0.9869709014892578,0.0494506396353244,0.284574734141852,43793.0,0.9861089587211608,0.0529910996556282,0.2741048176481009,43793.0,20908.61289286613,32208.92369627953,20908.61289286613,11295.720662355425,2.86948561668396,0.0 -66800,0.1590749,0.022092206,,,,,,,,,,,,,,,,, -66900,0.12696315,0.019196033,,,,,,,,,,,,,,,,, -67000,0.14829338,0.020664059,,,,,,,,,,,,,,,,, -67100,0.1431826,0.021940691,,,,,,,,,,,,,,,,, -67200,0.15290387,0.018273326,,,,,,,,,,,,,,,,, -67300,0.13168585,0.01905948,,,,,,,,,,,,,,,,, -67400,0.123432994,0.019256786,,,,,,,,,,,,,,,,, -67500,0.16103849,0.019925421,,,,,,,,,,,,,,,,, -67541,,,0.995209276676178,0.0150286378338932,0.7566962851795804,0.9869514107704164,0.049561109393835,0.2874051538617712,43793.0,0.9860563278198242,0.0530358180403709,0.2788811681071141,43793.0,21148.618430376053,32570.689962387085,21148.618430376053,11417.421566963196,2.909762382507324,0.0 -67600,0.15076797,0.019904051,,,,,,,,,,,,,,,,, -67700,0.12694404,0.01914105,,,,,,,,,,,,,,,,, -67800,0.13551116,0.018976672,,,,,,,,,,,,,,,,, -67900,0.13480985,0.016866656,,,,,,,,,,,,,,,,, -68000,0.14406428,0.01990578,,,,,,,,,,,,,,,,, -68100,0.14367835,0.019226165,,,,,,,,,,,,,,,,, -68200,0.13627774,0.017826498,,,,,,,,,,,,,,,,, -68300,0.14772433,0.018211436,,,,,,,,,,,,,,,,, -68313,,,0.9951818585395812,0.0150977028533816,0.7568026094827762,0.9868556261062622,0.0496043711900711,0.2851188808682722,43793.0,0.986006200313568,0.0531089939177036,0.2743112178448637,43793.0,21388.71308946609,32931.06196761131,21388.71308946609,11537.640836715698,2.9481992721557617,0.0 -68400,0.13757056,0.01890764,,,,,,,,,,,,,,,,, -68500,0.1476349,0.018463098,,,,,,,,,,,,,,,,, -68600,0.123300955,0.01630504,,,,,,,,,,,,,,,,, -68700,0.14987999,0.018501101,,,,,,,,,,,,,,,,, -68800,0.12609279,0.017917331,,,,,,,,,,,,,,,,, -68900,0.13735537,0.017522024,,,,,,,,,,,,,,,,, -69000,0.13451111,0.020666528,,,,,,,,,,,,,,,,, -69080,,,0.9952203035354614,0.014993236400187,0.7534647243648879,0.9869457483291626,0.0499713309109211,0.2849290280465007,43793.0,0.9861005544662476,0.0535284578800201,0.2751838973845723,43793.0,21628.667563438416,33289.1495051384,21628.667563438416,11655.715996980667,2.9867284297943115,0.0 -69100,0.1305275,0.018741813,,,,,,,,,,,,,,,,, -69200,0.121981665,0.015777625,,,,,,,,,,,,,,,,, -69300,0.1424399,0.020394113,,,,,,,,,,,,,,,,, -69400,0.13655177,0.01917678,,,,,,,,,,,,,,,,, -69500,0.14337417,0.019118764,,,,,,,,,,,,,,,,, -69600,0.13282456,0.018785123,,,,,,,,,,,,,,,,, -69700,0.14588445,0.01968998,,,,,,,,,,,,,,,,, -69800,0.15172501,0.020835593,,,,,,,,,,,,,,,,, -69825,,,0.9949262142181396,0.0156571958214044,0.7358463544646767,0.986976146697998,0.050358448177576,0.285283862791445,43793.0,0.9861228466033936,0.0538902319967746,0.2756676195806692,43793.0,21868.81844639778,33661.08895611763,21868.81844639778,11787.44162750244,3.026840209960937,0.0 -69900,0.14638214,0.01945683,,,,,,,,,,,,,,,,, -70000,0.14541629,0.016737277,,,,,,,,,,,,,,,,, -70100,0.1304397,0.019490646,,,,,,,,,,,,,,,,, -70200,0.11634073,0.01715618,,,,,,,,,,,,,,,,, -70300,0.1524664,0.016207142,,,,,,,,,,,,,,,,, -70400,0.12418387,0.017125182,,,,,,,,,,,,,,,,, -70500,0.13200703,0.018290702,,,,,,,,,,,,,,,,, -70562,,,0.9949321746826172,0.0156741421669721,0.7396729190456879,0.9868730306625366,0.0503210015594959,0.2847257727644883,43793.0,0.986019253730774,0.0538304969668388,0.2737792846556728,43793.0,22109.06586909294,34025.35740637779,22109.06586909294,11911.392924785614,3.07214879989624,0.0 -70600,0.14178084,0.017603517,,,,,,,,,,,,,,,,, -70700,0.14671971,0.019003998,,,,,,,,,,,,,,,,, -70800,0.14120418,0.017522845,,,,,,,,,,,,,,,,, -70900,0.15437613,0.020955933,,,,,,,,,,,,,,,,, -71000,0.14607254,0.019049002,,,,,,,,,,,,,,,,, -71100,0.13145025,0.016755778,,,,,,,,,,,,,,,,, -71200,0.13486464,0.018914547,,,,,,,,,,,,,,,,, -71300,0.15578447,0.019296976,,,,,,,,,,,,,,,,, -71302,,,0.9951404333114624,0.0150937959551811,0.753908089745251,0.9869737029075624,0.0502049997448921,0.2896799036222076,43793.0,0.9860963225364684,0.0537710450589656,0.2754479572202304,43793.0,22349.18640422821,34394.80511713028,22349.18640422821,12040.653289079666,3.115537643432617,0.0 -71400,0.13119224,0.017146984,,,,,,,,,,,,,,,,, -71500,0.12926154,0.017255303,,,,,,,,,,,,,,,,, -71600,0.14031804,0.01829154,,,,,,,,,,,,,,,,, -71700,0.13284777,0.018439857,,,,,,,,,,,,,,,,, -71800,0.14553392,0.018466061,,,,,,,,,,,,,,,,, -71900,0.13821942,0.018818738,,,,,,,,,,,,,,,,, -72000,0.14106649,0.015498788,,,,,,,,,,,,,,,,, -72069,,,0.9952371716499328,0.0147636476904153,0.7646365713289918,0.9869250059127808,0.0503023527562618,0.2853672165884759,43793.0,0.9861177802085876,0.0538271330296993,0.2747690843663519,43793.0,22589.416313409805,34759.30646586418,22589.416313409805,12164.864319086077,3.1560232639312744,0.0 -72100,0.15545535,0.019983811,,,,,,,,,,,,,,,,, -72200,0.1328431,0.019244315,,,,,,,,,,,,,,,,, -72300,0.13632372,0.015722545,,,,,,,,,,,,,,,,, -72400,0.13947599,0.0192536,,,,,,,,,,,,,,,,, -72500,0.14332968,0.017143883,,,,,,,,,,,,,,,,, -72600,0.14248785,0.018737206,,,,,,,,,,,,,,,,, -72700,0.13337012,0.018213138,,,,,,,,,,,,,,,,, -72800,0.12723069,0.016119128,,,,,,,,,,,,,,,,, -72838,,,0.9951958656311036,0.0147189879789948,0.7631483641737393,0.9870277047157288,0.0504600740969181,0.2847857847326915,43793.0,0.986180543899536,0.0540541112422943,0.2759792410743825,43793.0,22829.366702079773,35120.810903549194,22829.366702079773,12286.359723567964,3.19601845741272,0.0 -72900,0.16428718,0.020919802,,,,,,,,,,,,,,,,, -73000,0.13972817,0.019820228,,,,,,,,,,,,,,,,, -73100,0.15015264,0.018304057,,,,,,,,,,,,,,,,, -73200,0.15479289,0.019385539,,,,,,,,,,,,,,,,, -73300,0.14421529,0.01913824,,,,,,,,,,,,,,,,, -73400,0.14894073,0.017787356,,,,,,,,,,,,,,,,, -73500,0.15233003,0.022235734,,,,,,,,,,,,,,,,, -73600,0.13711461,0.01770225,,,,,,,,,,,,,,,,, -73601,,,0.9955556988716124,0.0139520000666379,0.7762335065818747,0.9869396090507508,0.0503431037068367,0.2852072589426013,43793.0,0.9861106276512146,0.0539973154664039,0.2749001561535323,43793.0,23069.62701439857,35482.71941781044,23069.62701439857,12407.947816610336,3.236490726470948,0.0 -73700,0.13919052,0.017538771,,,,,,,,,,,,,,,,, -73800,0.16353974,0.01908526,,,,,,,,,,,,,,,,, -73900,0.1394757,0.017440898,,,,,,,,,,,,,,,,, -74000,0.1355188,0.017229538,,,,,,,,,,,,,,,,, -74100,0.14570956,0.01931634,,,,,,,,,,,,,,,,, -74200,0.14165686,0.017328653,,,,,,,,,,,,,,,,, -74300,0.15660536,0.017748758,,,,,,,,,,,,,,,,, -74361,,,0.9955661296844482,0.0139135178178548,0.7747334711189504,0.9869696497917176,0.050527736544609,0.2868443094424433,43793.0,0.9861372113227844,0.0541548170149326,0.2747839274994807,43793.0,23309.805548667908,35841.558666944504,23309.805548667908,12526.548835277556,3.2765581607818604,0.0 -74400,0.13160762,0.017281458,,,,,,,,,,,,,,,,, -74500,0.1370888,0.016873557,,,,,,,,,,,,,,,,, -74600,0.1279744,0.016228545,,,,,,,,,,,,,,,,, -74700,0.14246571,0.017406832,,,,,,,,,,,,,,,,, -74800,0.15309568,0.018422622,,,,,,,,,,,,,,,,, -74900,0.13895532,0.01692168,,,,,,,,,,,,,,,,, -75000,0.14840329,0.01775938,,,,,,,,,,,,,,,,, -75100,0.14440972,0.017659191,,,,,,,,,,,,,,,,, -75125,,,0.9955736994743348,0.0139497052878141,0.7811614897240826,0.9869436621665956,0.0505538024008274,0.2867632035727815,43793.0,0.9861240983009338,0.054233469069004,0.2750307266695363,43793.0,23549.76802420616,36202.37942099571,23549.76802420616,12647.343688488008,3.3206896781921387,0.0 -75200,0.12534265,0.016836995,,,,,,,,,,,,,,,,, -75300,0.1492718,0.01798925,,,,,,,,,,,,,,,,, -75400,0.13409092,0.016562734,,,,,,,,,,,,,,,,, -75500,0.15447982,0.017962124,,,,,,,,,,,,,,,,, -75600,0.14594975,0.018039789,,,,,,,,,,,,,,,,, -75700,0.1507163,0.01788018,,,,,,,,,,,,,,,,, -75800,0.14735533,0.019859254,,,,,,,,,,,,,,,,, -75883,,,0.9955554008483888,0.0138175962492823,0.7852192319353195,0.9869920015335084,0.0505729056894779,0.2878512364293371,43793.0,0.9861254096031188,0.0542790330946445,0.2750352727505398,43793.0,23789.78480911255,36561.82056903839,23789.78480911255,12766.707190036774,3.362133741378784,0.0 -75900,0.14779902,0.017151322,,,,,,,,,,,,,,,,, -76000,0.14618799,0.017988164,,,,,,,,,,,,,,,,, -76100,0.13525774,0.016363207,,,,,,,,,,,,,,,,, -76200,0.16498342,0.0176478,,,,,,,,,,,,,,,,, -76300,0.15389262,0.019278038,,,,,,,,,,,,,,,,, -76400,0.14645205,0.018072903,,,,,,,,,,,,,,,,, -76500,0.15151504,0.019931898,,,,,,,,,,,,,,,,, -76600,0.13444234,0.016007984,,,,,,,,,,,,,,,,, -76655,,,0.9955092668533324,0.0140359466895461,0.7789031097164165,0.98692786693573,0.0505664274096488,0.2857804150321548,43793.0,0.986119508743286,0.0542350113391876,0.2757211049940482,43793.0,24029.778205871586,36923.76485896111,24029.778205871586,12888.594792842863,3.4054219722747803,0.0 -76700,0.1642171,0.019197876,,,,,,,,,,,,,,,,, -76800,0.140056,0.018741643,,,,,,,,,,,,,,,,, -76900,0.14719102,0.018650206,,,,,,,,,,,,,,,,, -77000,0.15455306,0.020080473,,,,,,,,,,,,,,,,, -77100,0.14791347,0.018916255,,,,,,,,,,,,,,,,, -77200,0.17024261,0.018978043,,,,,,,,,,,,,,,,, -77300,0.13921525,0.016596248,,,,,,,,,,,,,,,,, -77400,0.1312447,0.016609207,,,,,,,,,,,,,,,,, -77424,,,0.995467722415924,0.0140506261959671,0.7791727016973256,0.9869558811187744,0.0506208762526512,0.2858431907427471,43793.0,0.9861531853675842,0.0542887225747108,0.2760953723677083,43793.0,24269.738109350204,37284.09049272537,24269.738109350204,13008.900380373,3.4460644721984863,0.0 -77500,0.13029796,0.017077705,,,,,,,,,,,,,,,,, -77600,0.15050957,0.018672107,,,,,,,,,,,,,,,,, -77700,0.1569551,0.019273689,,,,,,,,,,,,,,,,, -77800,0.14729871,0.018033892,,,,,,,,,,,,,,,,, -77900,0.14737637,0.017073555,,,,,,,,,,,,,,,,, -78000,0.14780237,0.019722404,,,,,,,,,,,,,,,,, -78100,0.16656962,0.01691766,,,,,,,,,,,,,,,,, -78189,,,0.9954847693443298,0.0140595519915223,0.7697898776753651,0.9869655966758728,0.050602450966835,0.2866528137796882,43793.0,0.9861258268356324,0.0542734786868095,0.2755627204775177,43793.0,24509.947502613068,37643.56195235253,24509.947502613068,13128.102998018265,3.4863169193267822,0.0 -78200,0.13830149,0.018125145,,,,,,,,,,,,,,,,, -78300,0.13533014,0.016218139,,,,,,,,,,,,,,,,, -78400,0.14842618,0.018752085,,,,,,,,,,,,,,,,, -78500,0.17096049,0.022580206,,,,,,,,,,,,,,,,, -78600,0.1404108,0.016589368,,,,,,,,,,,,,,,,, -78700,0.14099398,0.01718809,,,,,,,,,,,,,,,,, -78800,0.14649601,0.019311486,,,,,,,,,,,,,,,,, -78900,0.14309187,0.016645651,,,,,,,,,,,,,,,,, -78951,,,0.9955748915672302,0.0139743052423,0.7815056272749845,0.9869611263275146,0.05058304220438,0.2863825448521734,43793.0,0.986123263835907,0.0542533323168754,0.2761878777633754,43793.0,24749.636902093887,38003.607684612274,24749.636902093887,13248.0815885067,3.8449010848999023,0.0 -79000,0.14285944,0.017829651,,,,,,,,,,,,,,,,, -79100,0.14894366,0.017639928,,,,,,,,,,,,,,,,, -79200,0.1787774,0.020500837,,,,,,,,,,,,,,,,, -79300,0.14961928,0.017083028,,,,,,,,,,,,,,,,, -79400,0.14641556,0.019129762,,,,,,,,,,,,,,,,, -79500,0.17183837,0.018323723,,,,,,,,,,,,,,,,, -79600,0.14532039,0.016975049,,,,,,,,,,,,,,,,, -79700,0.1491936,0.01838155,,,,,,,,,,,,,,,,, -79719,,,0.9955337047576904,0.0138495787978172,0.7799917830671548,0.9869562983512878,0.0505892001092433,0.2869632430767811,43793.0,0.986132562160492,0.0542628169059753,0.2760192221364203,43793.0,24989.647536993027,38366.77245020866,24989.647536993027,13371.174401044846,3.886876583099365,0.0 -79800,0.14883178,0.017430503,,,,,,,,,,,,,,,,, -79900,0.14190613,0.01748543,,,,,,,,,,,,,,,,, -80000,0.16389488,0.019672664,,,,,,,,,,,,,,,,, -80100,0.13132106,0.017847253,,,,,,,,,,,,,,,,, -80200,0.16156642,0.02012315,,,,,,,,,,,,,,,,, -80300,0.14450943,0.018080045,,,,,,,,,,,,,,,,, -80400,0.14456952,0.017344924,,,,,,,,,,,,,,,,, -80486,,,0.9955554008483888,0.0139354560524225,0.7850682347056999,0.9869558811187744,0.0505871586501598,0.2865156718168872,43793.0,0.9861329793930054,0.0542605891823768,0.2760004584768348,43793.0,25229.74752855301,38722.26422739029,25229.74752855301,13486.504242897034,3.929654359817505,0.0 -80500,0.14690581,0.017364465,,,,,,,,,,,,,,,,, -80600,0.14340232,0.017203355,,,,,,,,,,,,,,,,, -80700,0.14100702,0.016993683,,,,,,,,,,,,,,,,, -80800,0.14224191,0.017393524,,,,,,,,,,,,,,,,, -80900,0.14008759,0.015803184,,,,,,,,,,,,,,,,, -81000,0.12840939,0.016760947,,,,,,,,,,,,,,,,, -81100,0.13794929,0.017236378,,,,,,,,,,,,,,,,, -81200,0.12716797,0.016560383,,,,,,,,,,,,,,,,, -81232,,,0.9955120086669922,0.0139531018212437,0.776054432927252,0.9869558811187744,0.0505871586501598,0.2865305392511518,43793.0,0.9861329793930054,0.0542605891823768,0.2759756767757269,43793.0,25469.8677110672,39097.34252977371,25469.8677110672,13621.399010181429,3.971792221069336,0.0 -81300,0.13051446,0.01711524,,,,,,,,,,,,,,,,, -81400,0.15879478,0.015845224,,,,,,,,,,,,,,,,, -81500,0.14660576,0.015968787,,,,,,,,,,,,,,,,, -81600,0.12736444,0.015712824,,,,,,,,,,,,,,,,, -81700,0.14686169,0.018404393,,,,,,,,,,,,,,,,, -81800,0.13940814,0.014296952,,,,,,,,,,,,,,,,, -81900,0.13966867,0.017957717,,,,,,,,,,,,,,,,, -81968,,,0.995551586151123,0.0139019964262843,0.7857464322595502,0.9869558811187744,0.0505871586501598,0.2865071063360335,43793.0,0.9861329793930054,0.0542605929076671,0.2760720686538725,43793.0,25710.078585147858,39466.73317718506,25710.078585147858,13750.508263587952,4.018112182617188,0.0 -82000,0.1361116,0.015167874,,,,,,,,,,,,,,,,, -82100,0.15470687,0.017756077,,,,,,,,,,,,,,,,, -82200,0.14217196,0.017741969,,,,,,,,,,,,,,,,, -82300,0.13700421,0.016595168,,,,,,,,,,,,,,,,, -82400,0.15188484,0.018121198,,,,,,,,,,,,,,,,, -82500,0.15740919,0.017904218,,,,,,,,,,,,,,,,, -82600,0.14856744,0.019201366,,,,,,,,,,,,,,,,, -82700,0.15035348,0.018654158,,,,,,,,,,,,,,,,, -82707,,,0.9955469369888306,0.013970274478197,0.7702644447155856,0.9869558811187744,0.0505871586501598,0.2866412038804638,43793.0,0.9861329793930054,0.0542605891823768,0.2759491004785961,43793.0,25950.17149996757,39832.270023822784,25950.17149996757,13875.881750106812,4.064675092697144,0.0 -82800,0.14679618,0.019366572,,,,,,,,,,,,,,,,, -82900,0.13650444,0.017535057,,,,,,,,,,,,,,,,, -83000,0.15743363,0.018483432,,,,,,,,,,,,,,,,, -83100,0.14016137,0.017976927,,,,,,,,,,,,,,,,, -83200,0.16145533,0.018152812,,,,,,,,,,,,,,,,, -83300,0.15702567,0.01736242,,,,,,,,,,,,,,,,, -83400,0.140439,0.018117147,,,,,,,,,,,,,,,,, -83435,,,0.9955602884292604,0.0138719677925109,0.7792143024381376,0.9869558811187744,0.0505871586501598,0.2867899905995227,43793.0,0.9861329793930054,0.0542605891823768,0.2759817029818188,43793.0,26190.390765428543,40204.94928979874,26190.390765428543,14008.269878864288,4.11327862739563,0.0 -83500,0.14551298,0.017690588,,,,,,,,,,,,,,,,, -83600,0.1646104,0.020671729,,,,,,,,,,,,,,,,, -83700,0.16650581,0.01990318,,,,,,,,,,,,,,,,, -83800,0.14336267,0.017151866,,,,,,,,,,,,,,,,, -83900,0.14548221,0.017985122,,,,,,,,,,,,,,,,, -84000,0.14893834,0.018046886,,,,,,,,,,,,,,,,, -84100,0.13519283,0.01799251,,,,,,,,,,,,,,,,, -84184,,,0.9955973625183104,0.0137764988467097,0.783298819531164,0.9869558811187744,0.0505871586501598,0.286624042646645,43793.0,0.9861329793930054,0.0542605891823768,0.2759793362653783,43793.0,26430.476687908173,40563.81817674637,26430.476687908173,14126.985151052477,4.160360813140869,0.0 -84200,0.15000708,0.018473255,,,,,,,,,,,,,,,,, -84300,0.13673007,0.018533625,,,,,,,,,,,,,,,,, -84400,0.13772216,0.018781388,,,,,,,,,,,,,,,,, -84500,0.14962105,0.017629378,,,,,,,,,,,,,,,,, -84600,0.12989351,0.017397722,,,,,,,,,,,,,,,,, -84700,0.14960644,0.017292393,,,,,,,,,,,,,,,,, -84800,0.13618329,0.016775323,,,,,,,,,,,,,,,,, -84900,0.12573528,0.016140567,,,,,,,,,,,,,,,,, -84940,,,0.9954774379730223,0.0140807861462235,0.7756743396009409,0.9869558811187744,0.0505871586501598,0.2864305899233596,43793.0,0.9861329793930054,0.0542605891823768,0.2759710816592701,43793.0,26670.56862926483,40923.90593361855,26670.56862926483,14246.918652057648,4.203290939331055,0.0 -85000,0.1288477,0.01631129,,,,,,,,,,,,,,,,, -85100,0.13367005,0.016808208,,,,,,,,,,,,,,,,, -85200,0.12975718,0.015846003,,,,,,,,,,,,,,,,, -85300,0.14676955,0.017738037,,,,,,,,,,,,,,,,, -85400,0.14574581,0.01785461,,,,,,,,,,,,,,,,, -85500,0.12912016,0.016764734,,,,,,,,,,,,,,,,, -85600,0.15523572,0.016750397,,,,,,,,,,,,,,,,, -85693,,,0.9955883622169496,0.0137911858037114,0.7841693580470048,0.9869558811187744,0.0505871586501598,0.2865949001344348,43793.0,0.9861329793930054,0.0542605891823768,0.2759714642884614,43793.0,26910.639630556107,41284.11113762856,26910.639630556107,14366.983271598816,4.25146746635437,0.0 -85700,0.15089196,0.017393526,,,,,,,,,,,,,,,,, -85800,0.13812372,0.016759504,,,,,,,,,,,,,,,,, -85900,0.14890964,0.01821135,,,,,,,,,,,,,,,,, -86000,0.14821425,0.017632436,,,,,,,,,,,,,,,,, -86100,0.13754873,0.016004546,,,,,,,,,,,,,,,,, -86200,0.14642185,0.017981187,,,,,,,,,,,,,,,,, -86300,0.13805918,0.018180965,,,,,,,,,,,,,,,,, -86400,0.15938625,0.018996516,,,,,,,,,,,,,,,,, -86448,,,0.9955260157585144,0.0140296816825866,0.7727005352207057,0.9869558811187744,0.0505871586501598,0.2866704201057026,43793.0,0.9861329793930054,0.0542605891823768,0.2759215759192164,43793.0,27150.82177400589,41645.6478009224,27150.82177400589,14488.276557445526,4.293117523193359,0.0 -86500,0.1297939,0.017685862,,,,,,,,,,,,,,,,, -86600,0.14400634,0.0199127,,,,,,,,,,,,,,,,, -86700,0.14801994,0.017689427,,,,,,,,,,,,,,,,, -86800,0.16643295,0.018551972,,,,,,,,,,,,,,,,, -86900,0.14139804,0.017048733,,,,,,,,,,,,,,,,, -87000,0.14367819,0.016757157,,,,,,,,,,,,,,,,, -87100,0.14792183,0.01701374,,,,,,,,,,,,,,,,, -87200,0.12676625,0.017610798,,,,,,,,,,,,,,,,, -87201,,,0.9955212473869324,0.01403393689543,0.7780336724337944,0.9869558811187744,0.0505871586501598,0.286602377318734,43793.0,0.9861329793930054,0.0542605929076671,0.2760478402507823,43793.0,27391.00008749962,42003.57257246971,27391.00008749962,14605.95310664177,4.343390703201294,0.0 -87300,0.13799666,0.018044118,,,,,,,,,,,,,,,,, -87400,0.14993896,0.017345354,,,,,,,,,,,,,,,,, -87500,0.13812794,0.019581214,,,,,,,,,,,,,,,,, -87600,0.13484377,0.016739866,,,,,,,,,,,,,,,,, -87700,0.13150051,0.016081426,,,,,,,,,,,,,,,,, -87800,0.14783673,0.017656537,,,,,,,,,,,,,,,,, -87900,0.13125871,0.01777325,,,,,,,,,,,,,,,,, -87944,,,0.995576798915863,0.0137699004262685,0.7864783435135866,0.9869558811187744,0.0505871586501598,0.2865672216156085,43793.0,0.9861329793930054,0.0542605891823768,0.2760906757638992,43793.0,27631.004123210907,42367.01085090637,27631.004123210907,14729.32446551323,4.386620283126831,0.0 -88000,0.13894978,0.017661542,,,,,,,,,,,,,,,,, -88100,0.13331835,0.016916307,,,,,,,,,,,,,,,,, -88200,0.14597555,0.015141095,,,,,,,,,,,,,,,,, -88300,0.15832633,0.02135951,,,,,,,,,,,,,,,,, -88400,0.12973021,0.016710185,,,,,,,,,,,,,,,,, -88500,0.15636696,0.017490217,,,,,,,,,,,,,,,,, -88600,0.14382909,0.016872961,,,,,,,,,,,,,,,,, -88693,,,0.9955042600631714,0.014033424668014,0.7815081377763017,0.9869558811187744,0.0505871586501598,0.2864212378275307,43793.0,0.9861329793930054,0.0542605891823768,0.2761358426923165,43793.0,27870.94686794281,42725.75531196594,27870.94686794281,14848.06216621399,4.431077480316162,0.0 -88700,0.15181294,0.01841396,,,,,,,,,,,,,,,,, -88800,0.13604318,0.016784577,,,,,,,,,,,,,,,,, -88900,0.134334,0.017499523,,,,,,,,,,,,,,,,, -89000,0.1494474,0.020537961,,,,,,,,,,,,,,,,, -89100,0.14518647,0.018970344,,,,,,,,,,,,,,,,, -89200,0.15755199,0.018958217,,,,,,,,,,,,,,,,, -89300,0.13653043,0.017019259,,,,,,,,,,,,,,,,, -89400,0.14751455,0.017040249,,,,,,,,,,,,,,,,, -89447,,,0.9955434799194336,0.0138015914708375,0.778073415456541,0.9869558811187744,0.0505871586501598,0.2865742535694123,43793.0,0.9861329793930054,0.0542605891823768,0.2759851509876829,43793.0,28111.208525419235,43087.18470454216,28111.208525419235,14969.166311979294,4.4745166301727295,0.0 -89500,0.14769028,0.020132476,,,,,,,,,,,,,,,,, -89600,0.13779794,0.018966302,,,,,,,,,,,,,,,,, -89700,0.14303543,0.017679978,,,,,,,,,,,,,,,,, -89800,0.13298903,0.016140291,,,,,,,,,,,,,,,,, -89900,0.13153122,0.016345315,,,,,,,,,,,,,,,,, -90000,0.14174035,0.016597643,,,,,,,,,,,,,,,,, -90100,0.14633736,0.017969577,,,,,,,,,,,,,,,,, -90200,0.15935645,0.02049851,,,,,,,,,,,,,,,,, -90204,,,0.9955761432647704,0.0139395911246538,0.7750910374084437,0.9869558811187744,0.0505871586501598,0.2865358350164018,43793.0,0.9861329793930054,0.0542605891823768,0.2759657702013668,43793.0,28351.200445890427,43443.676209926605,28351.200445890427,15085.602485895157,4.517841577529907,0.0 -90300,0.1334521,0.016075764,,,,,,,,,,,,,,,,, -90400,0.13769674,0.016235575,,,,,,,,,,,,,,,,, -90500,0.13570409,0.017490836,,,,,,,,,,,,,,,,, -90600,0.13352077,0.018039864,,,,,,,,,,,,,,,,, -90700,0.12110541,0.01566381,,,,,,,,,,,,,,,,, -90800,0.16714513,0.018682664,,,,,,,,,,,,,,,,, -90900,0.14870223,0.016372222,,,,,,,,,,,,,,,,, -90963,,,0.9954948425292968,0.014080642722547,0.7802792739060775,0.9869558811187744,0.0505871586501598,0.2866171329509638,43793.0,0.9861329793930054,0.0542605891823768,0.2759976092226506,43793.0,28591.379506587986,43804.14271020889,28591.379506587986,15205.82478427887,4.563338756561279,0.0 -91000,0.13295187,0.018453075,,,,,,,,,,,,,,,,, -91100,0.14005265,0.016912578,,,,,,,,,,,,,,,,, -91200,0.13899784,0.016969416,,,,,,,,,,,,,,,,, -91300,0.14306916,0.016013006,,,,,,,,,,,,,,,,, -91400,0.13065685,0.014877559,,,,,,,,,,,,,,,,, -91500,0.13561957,0.016537435,,,,,,,,,,,,,,,,, -91600,0.1599883,0.017760321,,,,,,,,,,,,,,,,, -91700,0.13940461,0.01730244,,,,,,,,,,,,,,,,, -91727,,,0.9955697059631348,0.0138261914253234,0.7810301983931045,0.9869558811187744,0.0505871586501598,0.2866110429605794,43793.0,0.9861329793930054,0.0542605891823768,0.2759996168856555,43793.0,28831.39079451561,44160.73046088219,28831.39079451561,15322.338630437853,4.60586142539978,0.0 -91800,0.14618124,0.017675802,,,,,,,,,,,,,,,,, -91900,0.1281957,0.016724978,,,,,,,,,,,,,,,,, -92000,0.15750983,0.019393962,,,,,,,,,,,,,,,,, -92100,0.13395862,0.016560258,,,,,,,,,,,,,,,,, -92200,0.13238928,0.016200256,,,,,,,,,,,,,,,,, -92300,0.1513386,0.020414067,,,,,,,,,,,,,,,,, -92400,0.13223822,0.017705368,,,,,,,,,,,,,,,,, -92495,,,0.9955700039863586,0.013861620798707,0.7817951724672543,0.9869558811187744,0.0505871586501598,0.2865152886626495,43793.0,0.9861329793930054,0.0542605891823768,0.2759411183774932,43793.0,29071.55628156662,44520.32515239716,29071.55628156662,15441.70261669159,4.650819540023804,0.0 -92500,0.143674,0.017833872,,,,,,,,,,,,,,,,, -92600,0.14592001,0.017788472,,,,,,,,,,,,,,,,, -92700,0.14693356,0.016275175,,,,,,,,,,,,,,,,, -92800,0.14852862,0.018363765,,,,,,,,,,,,,,,,, -92900,0.129667,0.015426258,,,,,,,,,,,,,,,,, -93000,0.13651831,0.017821522,,,,,,,,,,,,,,,,, -93100,0.14629406,0.018935042,,,,,,,,,,,,,,,,, -93200,0.15851939,0.021175873,,,,,,,,,,,,,,,,, -93254,,,0.9955008029937744,0.0140228671953082,0.7731956136988019,0.9869558811187744,0.0505871586501598,0.2864977609935052,43793.0,0.9861329793930054,0.0542605891823768,0.2759367174136053,43793.0,29311.766901016235,44881.38517570496,29311.766901016235,15562.487857341766,4.694918155670166,0.0 -93300,0.14776838,0.01790001,,,,,,,,,,,,,,,,, -93400,0.15550132,0.019711075,,,,,,,,,,,,,,,,, -93500,0.15682264,0.02089698,,,,,,,,,,,,,,,,, -93600,0.14662863,0.01783352,,,,,,,,,,,,,,,,, -93700,0.14360155,0.016718,,,,,,,,,,,,,,,,, -93800,0.15799195,0.020120919,,,,,,,,,,,,,,,,, -93900,0.13672158,0.016181687,,,,,,,,,,,,,,,,, -94000,0.15173659,0.0203897,,,,,,,,,,,,,,,,, -94016,,,0.995567500591278,0.0138811152428388,0.7789318143991779,0.9869558811187744,0.0505871586501598,0.2865480551449149,43793.0,0.9861329793930054,0.0542605891823768,0.2760667388806535,43793.0,29551.97358584404,45239.58659052849,29551.97358584404,15680.41870713234,4.739025115966797,0.0 -94100,0.16135064,0.019695833,,,,,,,,,,,,,,,,, -94200,0.14936858,0.015793711,,,,,,,,,,,,,,,,, -94300,0.16337994,0.018259397,,,,,,,,,,,,,,,,, -94400,0.13976486,0.016941475,,,,,,,,,,,,,,,,, -94500,0.13533102,0.01663718,,,,,,,,,,,,,,,,, -94600,0.14546256,0.018017175,,,,,,,,,,,,,,,,, -94700,0.13934588,0.019673107,,,,,,,,,,,,,,,,, -94782,,,0.9955399036407472,0.0139724360778927,0.7753218829520573,0.9869558811187744,0.0505871586501598,0.2866137226228561,43793.0,0.9861329793930054,0.0542605891823768,0.2759850102404555,43793.0,29791.934020996094,45600.714686870575,29791.934020996094,15801.521898031237,4.783658504486084,0.0 -94800,0.1328274,0.01781108,,,,,,,,,,,,,,,,, -94900,0.16041987,0.01868437,,,,,,,,,,,,,,,,, -95000,0.13918455,0.019039715,,,,,,,,,,,,,,,,, -95100,0.14990686,0.018037729,,,,,,,,,,,,,,,,, -95200,0.12611668,0.016292516,,,,,,,,,,,,,,,,, -95300,0.16167615,0.01859672,,,,,,,,,,,,,,,,, -95400,0.14334448,0.016209839,,,,,,,,,,,,,,,,, -95500,0.1449713,0.016497102,,,,,,,,,,,,,,,,, -95545,,,0.9955657124519348,0.0138775575906038,0.7811558660083017,0.9869558811187744,0.0505871586501598,0.2866254112875427,43793.0,0.9861329793930054,0.0542605891823768,0.2760216601824851,43793.0,30032.101594686508,45956.278367996216,30032.101594686508,15916.854518413544,4.826842546463013,0.0 -95600,0.14126047,0.018525582,,,,,,,,,,,,,,,,, -95700,0.14378828,0.017804412,,,,,,,,,,,,,,,,, -95800,0.14564478,0.018727228,,,,,,,,,,,,,,,,, -95900,0.12808065,0.01757056,,,,,,,,,,,,,,,,, -96000,0.14088354,0.017689524,,,,,,,,,,,,,,,,, -96100,0.12608856,0.016546398,,,,,,,,,,,,,,,,, -96200,0.14999273,0.018358342,,,,,,,,,,,,,,,,, -96300,0.13585101,0.016189398,,,,,,,,,,,,,,,,, -96309,,,0.9955589175224304,0.0138029223307967,0.7868677053325954,0.9869558811187744,0.0505871586501598,0.2866101630441003,43793.0,0.9861329793930054,0.0542605891823768,0.2760575986127702,43793.0,30272.227252483368,46314.66438794136,30272.227252483368,16035.049534797668,4.871997594833374,0.0 -96400,0.14119938,0.017475076,,,,,,,,,,,,,,,,, -96500,0.13312437,0.015961723,,,,,,,,,,,,,,,,, -96600,0.13814712,0.018515354,,,,,,,,,,,,,,,,, -96700,0.15696642,0.019552113,,,,,,,,,,,,,,,,, -96800,0.15254413,0.021007182,,,,,,,,,,,,,,,,, -96900,0.15022355,0.017746875,,,,,,,,,,,,,,,,, -97000,0.13606937,0.017040495,,,,,,,,,,,,,,,,, -97067,,,0.995500147342682,0.0140522504225373,0.7729460373532289,0.9869558811187744,0.0505871586501598,0.2866899351187257,43793.0,0.9861329793930054,0.0542605891823768,0.2759726786264729,43793.0,30512.21157264709,46678.82244515419,30512.21157264709,16159.157991409302,4.91708254814148,0.0 -97100,0.14251995,0.018860001,,,,,,,,,,,,,,,,, -97200,0.12890103,0.016130975,,,,,,,,,,,,,,,,, -97300,0.14528218,0.019185578,,,,,,,,,,,,,,,,, -97400,0.15167157,0.015540581,,,,,,,,,,,,,,,,, -97500,0.14429103,0.018481858,,,,,,,,,,,,,,,,, -97600,0.13982394,0.017735787,,,,,,,,,,,,,,,,, -97700,0.13447307,0.016690718,,,,,,,,,,,,,,,,, -97800,0.13144307,0.01689701,,,,,,,,,,,,,,,,, -97829,,,0.9955435991287231,0.0138851748779416,0.7760392545421367,0.9869558811187744,0.0505871586501598,0.2865191272693426,43793.0,0.9861329793930054,0.0542605891823768,0.2761105488502643,43793.0,30752.282462358475,47038.88059139252,30752.282462358475,16279.08171248436,4.960994243621826,0.0 -97900,0.12640744,0.015851442,,,,,,,,,,,,,,,,, -98000,0.14435115,0.017155414,,,,,,,,,,,,,,,,, -98100,0.13946715,0.015737668,,,,,,,,,,,,,,,,, -98200,0.14544792,0.01869516,,,,,,,,,,,,,,,,, -98300,0.13007285,0.01725522,,,,,,,,,,,,,,,,, -98400,0.14214431,0.016088327,,,,,,,,,,,,,,,,, -98500,0.15117405,0.018990569,,,,,,,,,,,,,,,,, -98578,,,0.9955313801765442,0.0139819029718637,0.7786825487258938,0.9869558811187744,0.0505871586501598,0.2864706683088275,43793.0,0.9861329793930054,0.0542605929076671,0.2760305001928738,43793.0,30992.51039814949,47396.06888914108,30992.51039814949,16395.97787475586,5.005981922149658,0.0 -98600,0.14225806,0.015483015,,,,,,,,,,,,,,,,, -98700,0.1471517,0.018973213,,,,,,,,,,,,,,,,, -98800,0.14387347,0.018082974,,,,,,,,,,,,,,,,, -98900,0.15205991,0.017066343,,,,,,,,,,,,,,,,, -99000,0.13825442,0.01771422,,,,,,,,,,,,,,,,, -99100,0.14825654,0.020206518,,,,,,,,,,,,,,,,, -99200,0.13230728,0.016556619,,,,,,,,,,,,,,,,, -99300,0.14296982,0.017674817,,,,,,,,,,,,,,,,, -99336,,,0.9955341815948486,0.0139691065996885,0.7826914554191917,0.9869558811187744,0.0505871586501598,0.2865572910504564,43793.0,0.9861329793930054,0.0542605929076671,0.2759760452958431,43793.0,31232.58783340454,47754.60851669312,31232.58783340454,16514.374661684036,5.0520899295806885,0.0 -99400,0.14105454,0.01694475,,,,,,,,,,,,,,,,, -99500,0.1356879,0.016846647,,,,,,,,,,,,,,,,, -99600,0.13775285,0.016268536,,,,,,,,,,,,,,,,, -99700,0.14332247,0.01766451,,,,,,,,,,,,,,,,, -99800,0.13506861,0.01692922,,,,,,,,,,,,,,,,, -99900,0.15789615,0.017531205,,,,,,,,,,,,,,,,, -100000,0.14035515,0.017092776,,,,,,,,,,,,,,,,, -100100,0.13851254,0.016456312,,,,,,,,,,,,,,,,, -100103,,,0.9955401420593262,0.0138903046026825,0.7793055337576846,0.9869558811187744,0.0505871586501598,0.2865208102988978,43793.0,0.9861329793930054,0.0542605891823768,0.2759562314694286,43793.0,31472.69523191452,48113.74931025505,31472.69523191452,16633.34353852272,5.097035884857178,0.0 -100200,0.1399451,0.017001977,,,,,,,,,,,,,,,,, -100300,0.15453391,0.017861955,,,,,,,,,,,,,,,,, -100400,0.13936031,0.018890873,,,,,,,,,,,,,,,,, -100500,0.13030346,0.016216146,,,,,,,,,,,,,,,,, -100600,0.13599095,0.01658277,,,,,,,,,,,,,,,,, -100700,0.14856821,0.018199017,,,,,,,,,,,,,,,,, -100800,0.1577484,0.018448858,,,,,,,,,,,,,,,,, -100869,,,0.9955520629882812,0.0139562338590621,0.7847198707800941,0.9869558811187744,0.0505871586501598,0.2865390644148356,43793.0,0.9861329793930054,0.0542605891823768,0.2760779355054802,43793.0,31712.652955770493,48472.2280766964,31712.652955770493,16751.799800157547,5.141988754272461,0.0 -100900,0.13693093,0.017932108,,,,,,,,,,,,,,,,, -101000,0.14142127,0.02111057,,,,,,,,,,,,,,,,, -101100,0.15379447,0.017694617,,,,,,,,,,,,,,,,, -101200,0.16058347,0.015726244,,,,,,,,,,,,,,,,, -101300,0.14573997,0.017745752,,,,,,,,,,,,,,,,, -101400,0.14303309,0.016861537,,,,,,,,,,,,,,,,, -101500,0.16356273,0.018706802,,,,,,,,,,,,,,,,, -101600,0.14312382,0.018425034,,,,,,,,,,,,,,,,, -101634,,,0.9955708980560304,0.0137952547520399,0.7813605957526542,0.9869558811187744,0.0505871586501598,0.2866145553852413,43793.0,0.9861329793930054,0.0542605891823768,0.2759564045633718,43793.0,31952.75311899185,48829.83093047142,31952.75311899185,16869.23617386818,5.188489437103272,0.0 -101700,0.14448291,0.017144429,,,,,,,,,,,,,,,,, -101800,0.13677794,0.017145509,,,,,,,,,,,,,,,,, -101900,0.14918792,0.01758254,,,,,,,,,,,,,,,,, -102000,0.13618283,0.017659713,,,,,,,,,,,,,,,,, -102100,0.12727949,0.017161421,,,,,,,,,,,,,,,,, -102200,0.13643064,0.018839862,,,,,,,,,,,,,,,,, -102300,0.14077055,0.017579602,,,,,,,,,,,,,,,,, -102400,,,0.9955251216888428,0.0139615908265113,0.770412098402832,0.9869558811187744,0.0505871586501598,0.2864852371767319,43793.0,0.9861329793930054,0.0542605891823768,0.276152526270037,43793.0,32192.752415180206,49191.24772572517,32192.752415180206,16990.588547229767,5.23371148109436,0.0 -102400,0.1436333,0.017498428,,,,,,,,,,,,,,,,, -102500,0.14034429,0.01636305,,,,,,,,,,,,,,,,, -102600,0.12955432,0.016862256,,,,,,,,,,,,,,,,, -102700,0.13184574,0.017720306,,,,,,,,,,,,,,,,, -102800,0.14911774,0.018076735,,,,,,,,,,,,,,,,, -102900,0.14364518,0.017774163,,,,,,,,,,,,,,,,, -103000,0.1335256,0.015578052,,,,,,,,,,,,,,,,, -103100,0.16065256,0.017151257,,,,,,,,,,,,,,,,, -103158,,,0.9955227375030518,0.0140439234673976,0.773106773764879,0.9869558811187744,0.0505871586501598,0.2867109029952526,43793.0,0.9861329793930054,0.0542605891823768,0.275994188857404,43793.0,32432.875032186508,49547.85517024994,32432.875032186508,17107.008259534836,5.278891801834106,0.0 -103200,0.13419405,0.017396253,,,,,,,,,,,,,,,,, -103300,0.14982045,0.019259255,,,,,,,,,,,,,,,,, -103400,0.13107042,0.017618375,,,,,,,,,,,,,,,,, -103500,0.14484563,0.017474938,,,,,,,,,,,,,,,,, -103600,0.14004733,0.018356266,,,,,,,,,,,,,,,,, -103700,0.13861181,0.018118687,,,,,,,,,,,,,,,,, -103800,0.19460414,0.017230181,,,,,,,,,,,,,,,,, -103900,0.15420751,0.019072428,,,,,,,,,,,,,,,,, -103919,,,0.9955666065216064,0.0138385407626628,0.7811419869738975,0.9869558811187744,0.0505871586501598,0.2865256056236226,43793.0,0.9861329793930054,0.0542605891823768,0.2759413262208593,43793.0,32672.83212566376,49899.01314020157,32672.83212566376,17218.14317536354,5.324789047241211,0.0 -104000,0.14251463,0.017828356,,,,,,,,,,,,,,,,, -104100,0.14874433,0.017274475,,,,,,,,,,,,,,,,, -104200,0.14921619,0.017040858,,,,,,,,,,,,,,,,, -104300,0.13751298,0.016503723,,,,,,,,,,,,,,,,, -104400,0.129384,0.016208941,,,,,,,,,,,,,,,,, -104500,0.13431582,0.018167764,,,,,,,,,,,,,,,,, -104600,0.13093215,0.016253177,,,,,,,,,,,,,,,,, -104688,,,0.9955680966377258,0.0138523709028959,0.7850848461487768,0.9869558811187744,0.0505871586501598,0.2865603986056249,43793.0,0.9861329793930054,0.0542605891823768,0.2759689009568245,43793.0,32912.77941918373,50259.56562113762,32912.77941918373,17338.682789564133,5.370456695556641,0.0 -104700,0.1522955,0.017033655,,,,,,,,,,,,,,,,, -104800,0.15870719,0.019259287,,,,,,,,,,,,,,,,, -104900,0.1631607,0.022391118,,,,,,,,,,,,,,,,, -105000,0.14429504,0.018789977,,,,,,,,,,,,,,,,, -105100,0.13934967,0.019886233,,,,,,,,,,,,,,,,, -105200,0.15954162,0.019325433,,,,,,,,,,,,,,,,, -105300,0.14725676,0.017763106,,,,,,,,,,,,,,,,, -105400,0.13075963,0.016909635,,,,,,,,,,,,,,,,, -105447,,,0.9955039620399476,0.0139955282211303,0.7774106373615576,0.9869558811187744,0.0505871586501598,0.2865216599854022,43793.0,0.9861329793930054,0.0542605891823768,0.2760524963945224,43793.0,33152.962929964066,50619.27890300751,33152.962929964066,17458.148350954056,5.415530681610107,0.0 -105500,0.14075604,0.018102739,,,,,,,,,,,,,,,,, -105600,0.13714379,0.017518455,,,,,,,,,,,,,,,,, -105700,0.15987383,0.019140292,,,,,,,,,,,,,,,,, -105800,0.14779432,0.01872513,,,,,,,,,,,,,,,,, -105900,0.14494166,0.017950911,,,,,,,,,,,,,,,,, -106000,0.15593737,0.01883532,,,,,,,,,,,,,,,,, -106100,0.14141923,0.019494895,,,,,,,,,,,,,,,,, -106180,,,0.9955573678016664,0.0138543657958507,0.7817870350219891,0.9869558811187744,0.0505871586501598,0.2866419430739763,43793.0,0.9861329793930054,0.0542605891823768,0.2759944644891478,43793.0,33392.90033054352,50983.27924466133,33392.90033054352,17582.141822576523,5.462770223617554,0.0 -106200,0.1533376,0.018395545,,,,,,,,,,,,,,,,, -106300,0.13103342,0.018621275,,,,,,,,,,,,,,,,, -106400,0.14835177,0.019457486,,,,,,,,,,,,,,,,, -106500,0.15418693,0.019526247,,,,,,,,,,,,,,,,, -106600,0.13704644,0.016263453,,,,,,,,,,,,,,,,, -106700,0.15074013,0.020771565,,,,,,,,,,,,,,,,, -106800,0.14952669,0.017264973,,,,,,,,,,,,,,,,, -106900,0.15672974,0.019330863,,,,,,,,,,,,,,,,, -106918,,,0.9955280423164368,0.0140522280707955,0.7661217680770456,0.9869558811187744,0.0505871586501598,0.286556291387751,43793.0,0.9861329793930054,0.0542605929076671,0.2759606267995214,43793.0,33633.02038526535,51340.06503534317,33633.02038526535,17698.731519460678,5.51738166809082,0.0 -107000,0.1382065,0.017387815,,,,,,,,,,,,,,,,, -107100,0.14419907,0.018703626,,,,,,,,,,,,,,,,, -107200,0.14206253,0.0174597,,,,,,,,,,,,,,,,, -107300,0.15182714,0.019712308,,,,,,,,,,,,,,,,, -107400,0.13776937,0.016678782,,,,,,,,,,,,,,,,, -107500,0.1383822,0.014421913,,,,,,,,,,,,,,,,, -107600,0.13979368,0.01710359,,,,,,,,,,,,,,,,, -107677,,,0.9955880641937256,0.0138220340013504,0.779899470061354,0.9869558811187744,0.0505871586501598,0.2865691722887861,43793.0,0.9861329793930054,0.0542605929076671,0.2760152750049882,43793.0,33872.95367479324,51695.79447507858,33872.95367479324,17814.46081638336,5.564499616622925,0.0 -107700,0.14121357,0.018174026,,,,,,,,,,,,,,,,, -107800,0.1353049,0.016225453,,,,,,,,,,,,,,,,, -107900,0.16686371,0.019342525,,,,,,,,,,,,,,,,, -108000,0.13703904,0.01671702,,,,,,,,,,,,,,,,, -108100,0.13943395,0.017845869,,,,,,,,,,,,,,,,, -108200,0.14446948,0.017915184,,,,,,,,,,,,,,,,, -108300,0.13515963,0.016911343,,,,,,,,,,,,,,,,, -108400,0.1335973,0.015564919,,,,,,,,,,,,,,,,, -108441,,,0.9955666065216064,0.0138393836095929,0.7841374720767258,0.9869558811187744,0.0505871586501598,0.2864342504751349,43793.0,0.9861329793930054,0.0542605891823768,0.2759831511515541,43793.0,34113.1648209095,52048.97903227806,34113.1648209095,17927.368314504623,5.611015796661377,0.0 -108500,0.1348037,0.017627697,,,,,,,,,,,,,,,,, -108600,0.14832564,0.018529318,,,,,,,,,,,,,,,,, -108700,0.13435939,0.017194113,,,,,,,,,,,,,,,,, -108800,0.15669832,0.01858306,,,,,,,,,,,,,,,,, -108900,0.14629887,0.016728828,,,,,,,,,,,,,,,,, -109000,0.13355163,0.017570788,,,,,,,,,,,,,,,,, -109100,0.16140011,0.016408622,,,,,,,,,,,,,,,,, -109200,0.13362753,0.017221376,,,,,,,,,,,,,,,,, -109204,,,0.9955072999000548,0.0140102095901966,0.781538035227237,0.9869558811187744,0.0505871586501598,0.286505081826393,43793.0,0.9861329793930054,0.0542605891823768,0.2760052366019689,43793.0,34353.28302168846,52402.14703559876,34353.28302168846,18040.35396838188,5.656085968017578,0.0 -109300,0.12707274,0.01511828,,,,,,,,,,,,,,,,, -109400,0.15964395,0.01712115,,,,,,,,,,,,,,,,, -109500,0.16925195,0.018455895,,,,,,,,,,,,,,,,, -109600,0.1362997,0.017034726,,,,,,,,,,,,,,,,, -109700,0.15547416,0.017528031,,,,,,,,,,,,,,,,, -109800,0.1544527,0.018629832,,,,,,,,,,,,,,,,, -109900,0.14595874,0.017373586,,,,,,,,,,,,,,,,, -109946,,,0.9955655932426452,0.013794494792819,0.7802740098906389,0.9869558811187744,0.0505871586501598,0.2865730157937727,43793.0,0.9861329793930054,0.0542605891823768,0.2760160422751042,43793.0,34593.253633499146,52770.308326005936,34593.253633499146,18168.47488641739,5.704323053359985,0.0 -110000,0.1433186,0.017246002,,,,,,,,,,,,,,,,, -110100,0.14395267,0.017083433,,,,,,,,,,,,,,,,, -110200,0.13847612,0.016442422,,,,,,,,,,,,,,,,, -110300,0.14886497,0.021306166,,,,,,,,,,,,,,,,, -110400,0.13558447,0.018682538,,,,,,,,,,,,,,,,, -110500,0.15634345,0.020655867,,,,,,,,,,,,,,,,, -110600,0.13001703,0.015822086,,,,,,,,,,,,,,,,, -110688,,,0.9955146312713624,0.0140324300155043,0.7722557395643874,0.9869558811187744,0.0505871586501598,0.2867128276145199,43793.0,0.9861329793930054,0.0542605891823768,0.2759493312803741,43793.0,34833.229488134384,53130.4960372448,34833.229488134384,18288.60990166664,5.759541034698486,0.0 -110700,0.13410124,0.015872406,,,,,,,,,,,,,,,,, -110800,0.14986768,0.018937318,,,,,,,,,,,,,,,,, -110900,0.14762159,0.01771478,,,,,,,,,,,,,,,,, -111000,0.1459506,0.01928229,,,,,,,,,,,,,,,,, -111100,0.1551491,0.016225291,,,,,,,,,,,,,,,,, -111200,0.16338715,0.020383347,,,,,,,,,,,,,,,,, -111300,0.16206515,0.019620815,,,,,,,,,,,,,,,,, -111400,0.14024344,0.017334584,,,,,,,,,,,,,,,,, -111441,,,0.9955356121063232,0.0140041010454297,0.7818724497668679,0.9869558811187744,0.0505871586501598,0.2864828757007852,43793.0,0.9861329793930054,0.0542605891823768,0.2759526438099676,43793.0,35073.32724404335,53486.25983309746,35073.32724404335,18404.20976114273,5.80646538734436,0.0 -111500,0.12730709,0.01537234,,,,,,,,,,,,,,,,, -111600,0.13543773,0.017930359,,,,,,,,,,,,,,,,, -111700,0.14220968,0.016290165,,,,,,,,,,,,,,,,, -111800,0.13554141,0.017040592,,,,,,,,,,,,,,,,, -111900,0.16301228,0.020177756,,,,,,,,,,,,,,,,, -112000,0.16257313,0.020444676,,,,,,,,,,,,,,,,, -112100,0.13394782,0.015681462,,,,,,,,,,,,,,,,, -112200,0.1461611,0.017805422,,,,,,,,,,,,,,,,, -112205,,,0.9955651760101318,0.0138434814289212,0.7786537040391218,0.9869558811187744,0.0505871586501598,0.2864558681747481,43793.0,0.9861329793930054,0.0542605891823768,0.276035341605793,43793.0,35313.35783934593,53845.22560930252,35313.35783934593,18523.07643556595,5.854486465454102,0.0 -112300,0.14909105,0.01805664,,,,,,,,,,,,,,,,, -112400,0.13129523,0.017918818,,,,,,,,,,,,,,,,, -112500,0.14740236,0.016618894,,,,,,,,,,,,,,,,, -112600,0.13184348,0.017454058,,,,,,,,,,,,,,,,, -112700,0.13929558,0.01863739,,,,,,,,,,,,,,,,, -112800,0.12011023,0.016014451,,,,,,,,,,,,,,,,, -112900,0.1412413,0.018428672,,,,,,,,,,,,,,,,, -112967,,,0.995530605316162,0.0139740211889147,0.7914633546862597,0.9869558811187744,0.0505871586501598,0.2866072428159543,43793.0,0.9861329793930054,0.0542605891823768,0.2760329738837908,43793.0,35553.29326963425,54197.83467626572,35553.29326963425,18635.68253469467,5.9019293785095215,0.0 -113000,0.13646518,0.018696537,,,,,,,,,,,,,,,,, -113100,0.1292416,0.016600847,,,,,,,,,,,,,,,,, -113200,0.13245423,0.018233716,,,,,,,,,,,,,,,,, -113300,0.14638619,0.01808847,,,,,,,,,,,,,,,,, -113400,0.13697864,0.0150437625,,,,,,,,,,,,,,,,, -113500,0.15014346,0.01708081,,,,,,,,,,,,,,,,, -113600,0.1755786,0.019497517,,,,,,,,,,,,,,,,, -113700,0.17731623,0.019129483,,,,,,,,,,,,,,,,, -113731,,,0.9955313205718994,0.0138783520087599,0.7689000334908849,0.9869558811187744,0.0505871586501598,0.2866461290888903,43793.0,0.9861329793930054,0.0542605891823768,0.2759876579915408,43793.0,35793.34585046768,54549.151299238205,35793.34585046768,18746.87947440148,5.949110984802246,0.0 -113800,0.14180698,0.018123547,,,,,,,,,,,,,,,,, -113900,0.13730395,0.018100828,,,,,,,,,,,,,,,,, -114000,0.16550738,0.018475411,,,,,,,,,,,,,,,,, -114100,0.14780039,0.019906176,,,,,,,,,,,,,,,,, -114200,0.17678075,0.01863218,,,,,,,,,,,,,,,,, -114300,0.14027502,0.017672483,,,,,,,,,,,,,,,,, -114400,0.1571783,0.018044505,,,,,,,,,,,,,,,,, -114488,,,0.9955372214317322,0.0139394728466868,0.7815741067321347,0.9869558811187744,0.0505871586501598,0.2865367019489142,43793.0,0.9861329793930054,0.0542605891823768,0.2760602507676384,43793.0,36033.280943632126,54905.05683708191,36033.280943632126,18862.78116440773,5.997893810272217,0.0 -114500,0.14702508,0.01757328,,,,,,,,,,,,,,,,, -114600,0.14402753,0.01809242,,,,,,,,,,,,,,,,, -114700,0.15846717,0.018472219,,,,,,,,,,,,,,,,, -114800,0.15162742,0.018140761,,,,,,,,,,,,,,,,, -114900,0.14732288,0.016593007,,,,,,,,,,,,,,,,, -115000,0.14951938,0.016094564,,,,,,,,,,,,,,,,, -115100,0.13627546,0.018310055,,,,,,,,,,,,,,,,, -115200,0.13827595,0.01718225,,,,,,,,,,,,,,,,, -115249,,,0.9955669045448304,0.0139568988233804,0.7696413411763118,0.9869558811187744,0.0505871586501598,0.2866747762429935,43793.0,0.9861329793930054,0.0542605891823768,0.2759374328359461,43793.0,36273.37165546417,55257.477509737015,36273.37165546417,18975.043164491653,6.046472072601318,0.0 -115300,0.1514522,0.017543292,,,,,,,,,,,,,,,,, -115400,0.14303127,0.018185636,,,,,,,,,,,,,,,,, -115500,0.1385223,0.018521285,,,,,,,,,,,,,,,,, -115600,0.13363917,0.016521338,,,,,,,,,,,,,,,,, -115700,0.1454163,0.01751582,,,,,,,,,,,,,,,,, -115800,0.14166252,0.016706698,,,,,,,,,,,,,,,,, -115900,0.15752296,0.016788883,,,,,,,,,,,,,,,,, -116000,0.14183013,0.019783305,,,,,,,,,,,,,,,,, -116019,,,0.9955641627311708,0.013859805651009,0.7850317787543281,0.9869558811187744,0.0505871586501598,0.2865551392909943,43793.0,0.9861329793930054,0.0542605891823768,0.2760065789253968,43793.0,36513.53369688988,55608.0884168148,36513.53369688988,19085.425540685654,6.093016862869263,0.0 -116100,0.14353493,0.017368685,,,,,,,,,,,,,,,,, -116200,0.1618549,0.020065224,,,,,,,,,,,,,,,,, -116300,0.14862831,0.016911741,,,,,,,,,,,,,,,,, -116400,0.15035458,0.01773335,,,,,,,,,,,,,,,,, -116500,0.13505451,0.01615384,,,,,,,,,,,,,,,,, -116600,0.13451944,0.016468342,,,,,,,,,,,,,,,,, -116700,0.14335166,0.019672137,,,,,,,,,,,,,,,,, -116791,,,0.9955353140830994,0.0139025328680872,0.7856174298674152,0.9869558811187744,0.0505871586501598,0.2865175104084884,43793.0,0.9861329793930054,0.0542605929076671,0.2759737336301693,43793.0,36753.48701620102,55964.93444156647,36753.48701620102,19202.251974105835,6.139763593673706,0.0 -116800,0.13952543,0.01735386,,,,,,,,,,,,,,,,, -116900,0.13749683,0.01518289,,,,,,,,,,,,,,,,, -117000,0.14912911,0.01812144,,,,,,,,,,,,,,,,, -117100,0.15086098,0.018528597,,,,,,,,,,,,,,,,, -117200,0.1364467,0.016618084,,,,,,,,,,,,,,,,, -117300,0.17415231,0.019512711,,,,,,,,,,,,,,,,, -117400,0.17793623,0.022080624,,,,,,,,,,,,,,,,, -117500,0.1324038,0.016942242,,,,,,,,,,,,,,,,, -117554,,,0.9955047965049744,0.0140029955655336,0.7721958738934317,0.9869558811187744,0.0505871586501598,0.286516307077295,43793.0,0.9861329793930054,0.0542605891823768,0.276063221490286,43793.0,36993.59615969658,56316.52581310272,36993.59615969658,19313.66682934761,6.1876304149627686,0.0 -117600,0.13426788,0.015124139,,,,,,,,,,,,,,,,, -117700,0.13159072,0.016691128,,,,,,,,,,,,,,,,, -117800,0.15953434,0.019754881,,,,,,,,,,,,,,,,, -117900,0.15173303,0.01728677,,,,,,,,,,,,,,,,, -118000,0.13243008,0.01618071,,,,,,,,,,,,,,,,, -118100,0.14306745,0.017941369,,,,,,,,,,,,,,,,, -118200,0.14316975,0.019285532,,,,,,,,,,,,,,,,, -118300,0.13886212,0.016616592,,,,,,,,,,,,,,,,, -118321,,,0.9955202341079712,0.0139669338241219,0.7832561135578842,0.9869558811187744,0.0505871586501598,0.2867236390114384,43793.0,0.9861329793930054,0.0542605891823768,0.2760523527524613,43793.0,37233.81376886368,56673.477946043015,37233.81376886368,19430.334386587143,6.235366821289063,0.0 -118400,0.14880982,0.01957984,,,,,,,,,,,,,,,,, -118500,0.13294825,0.016730564,,,,,,,,,,,,,,,,, -118600,0.14182132,0.018233977,,,,,,,,,,,,,,,,, -118700,0.13671897,0.01894791,,,,,,,,,,,,,,,,, -118800,0.13945562,0.016250128,,,,,,,,,,,,,,,,, -118900,0.12556267,0.017358439,,,,,,,,,,,,,,,,, -119000,0.15078555,0.01839613,,,,,,,,,,,,,,,,, -119085,,,0.9955582022666932,0.0139131685718894,0.7770310732162424,0.9869558811187744,0.0505871586501598,0.2866907177414473,43793.0,0.9861329793930054,0.0542605891823768,0.2759484318103051,43793.0,37473.95171999931,57033.24269533157,37473.95171999931,19549.893506526947,6.28296160697937,0.0 -119100,0.13611752,0.016804334,,,,,,,,,,,,,,,,, -119200,0.14480111,0.017569799,,,,,,,,,,,,,,,,, -119300,0.15330257,0.01800154,,,,,,,,,,,,,,,,, -119400,0.11517342,0.0136817405,,,,,,,,,,,,,,,,, -119500,0.13265678,0.016587304,,,,,,,,,,,,,,,,, -119600,0.13633333,0.017121576,,,,,,,,,,,,,,,,, -119700,0.15703629,0.019770693,,,,,,,,,,,,,,,,, -119800,0.15048738,0.018319786,,,,,,,,,,,,,,,,, -119849,,,0.9955252408981324,0.0140389157459139,0.7764668111826432,0.9869558811187744,0.0505871586501598,0.2867658580304355,43793.0,0.9861329793930054,0.0542605891823768,0.2760382490303594,43793.0,37714.03889346123,57386.43988108635,37714.03889346123,19662.93479347229,6.332173109054565,0.0 -119900,0.15439454,0.018414972,,,,,,,,,,,,,,,,, -120000,0.14670205,0.018873893,,,,,,,,,,,,,,,,, -120100,0.14179067,0.016464574,,,,,,,,,,,,,,,,, -120200,0.13956589,0.01571983,,,,,,,,,,,,,,,,, -120300,0.18929473,0.021244287,,,,,,,,,,,,,,,,, -120400,0.15493363,0.019128447,,,,,,,,,,,,,,,,, -120500,0.14092918,0.018233197,,,,,,,,,,,,,,,,, -120600,0.13364348,0.014778888,,,,,,,,,,,,,,,,, -120612,,,0.9956113696098328,0.013700583949685,0.7857608368347888,0.9869558811187744,0.0505871586501598,0.2865649723131295,43793.0,0.9861329793930054,0.0542605929076671,0.2759795768788805,43793.0,37954.24491381645,57739.95515346527,37954.24491381645,19776.17451095581,6.382076978683472,0.0 -120700,0.14635196,0.018986203,,,,,,,,,,,,,,,,, -120800,0.13736783,0.016753571,,,,,,,,,,,,,,,,, -120900,0.14734593,0.017253783,,,,,,,,,,,,,,,,, -121000,0.1492393,0.018595748,,,,,,,,,,,,,,,,, -121100,0.13874678,0.018620146,,,,,,,,,,,,,,,,, -121200,0.14271703,0.01675065,,,,,,,,,,,,,,,,, -121300,0.1329576,0.016004754,,,,,,,,,,,,,,,,, -121378,,,0.9955203533172609,0.0139859775081276,0.7798932588725074,0.9869558811187744,0.0505871586501598,0.2865813047482279,43793.0,0.9861329793930054,0.0542605891823768,0.2759365662211449,43793.0,38194.241970300674,58094.52316451073,38194.241970300674,19890.67685246468,6.430763006210327,0.0 -121400,0.14573951,0.017324205,,,,,,,,,,,,,,,,, -121500,0.13626306,0.017026894,,,,,,,,,,,,,,,,, -121600,0.14341682,0.018455211,,,,,,,,,,,,,,,,, -121700,0.15256122,0.016752796,,,,,,,,,,,,,,,,, -121800,0.14952709,0.018254397,,,,,,,,,,,,,,,,, -121900,0.146424,0.017982008,,,,,,,,,,,,,,,,, -122000,0.1591127,0.01930398,,,,,,,,,,,,,,,,, -122100,0.15786247,0.019681796,,,,,,,,,,,,,,,,, -122107,,,0.9955151081085204,0.0139172924682497,0.7799444582331069,0.9869558811187744,0.0505871586501598,0.286507144273752,43793.0,0.9861329793930054,0.0542605891823768,0.2759651343786751,43793.0,38434.26357674599,58457.947122097015,38434.26357674599,20014.006541490555,6.4794793128967285,0.0 -122200,0.14356525,0.015246408,,,,,,,,,,,,,,,,, -122300,0.13701291,0.016723344,,,,,,,,,,,,,,,,, -122400,0.15953478,0.019468715,,,,,,,,,,,,,,,,, -122500,0.1609237,0.019791229,,,,,,,,,,,,,,,,, -122600,0.13428701,0.017230008,,,,,,,,,,,,,,,,, -122700,0.14866893,0.015963335,,,,,,,,,,,,,,,,, -122800,0.13352221,0.016875394,,,,,,,,,,,,,,,,, -122836,,,0.995549976825714,0.0139400735497474,0.7765179471599951,0.9869558811187744,0.0505871586501598,0.2864950099685684,43793.0,0.9861329793930054,0.0542605891823768,0.2760162167495111,43793.0,38674.20046806336,58817.90552377701,38674.20046806336,20133.949590206143,6.534131288528442,0.0 -122900,0.12930363,0.015435784,,,,,,,,,,,,,,,,, -123000,0.14357057,0.018237945,,,,,,,,,,,,,,,,, -123100,0.14281973,0.01848269,,,,,,,,,,,,,,,,, -123200,0.16253711,0.02002696,,,,,,,,,,,,,,,,, -123300,0.14491078,0.018927688,,,,,,,,,,,,,,,,, -123400,0.15019791,0.01859539,,,,,,,,,,,,,,,,, -123500,0.15409008,0.01678642,,,,,,,,,,,,,,,,, -123599,,,0.9955602288246156,0.013960919342935,0.7762064006184903,0.9869558811187744,0.0505871586501598,0.2866554438585214,43793.0,0.9861329793930054,0.0542605891823768,0.2759403551749374,43793.0,38914.34498047829,59171.270610809326,38914.34498047829,20247.099562883377,6.583909273147583,0.0 -123600,0.14026742,0.019550534,,,,,,,,,,,,,,,,, -123700,0.13805835,0.019491928,,,,,,,,,,,,,,,,, -123800,0.14342993,0.019014928,,,,,,,,,,,,,,,,, -123900,0.14868404,0.017627196,,,,,,,,,,,,,,,,, -124000,0.1268007,0.017107312,,,,,,,,,,,,,,,,, -124100,0.12863919,0.016956292,,,,,,,,,,,,,,,,, -124200,0.1554386,0.018365886,,,,,,,,,,,,,,,,, -124300,0.15038939,0.019010393,,,,,,,,,,,,,,,,, -124357,,,0.9955296516418456,0.0139157203957438,0.7814746632594798,0.9869558811187744,0.0505871586501598,0.2864686696059941,43793.0,0.9861329793930054,0.0542605891823768,0.2759882452278814,43793.0,39154.38414406776,59526.41132116318,39154.38414406776,20362.13060450554,6.63392972946167,0.0 -124400,0.15267791,0.019516548,,,,,,,,,,,,,,,,, -124500,0.14128606,0.01756322,,,,,,,,,,,,,,,,, -124600,0.15322977,0.018784458,,,,,,,,,,,,,,,,, -124700,0.12981746,0.015573588,,,,,,,,,,,,,,,,, -124800,0.15875638,0.020240275,,,,,,,,,,,,,,,,, -124900,0.13544758,0.017904535,,,,,,,,,,,,,,,,, -125000,0.14656436,0.018257502,,,,,,,,,,,,,,,,, -125100,0.1288949,0.016380547,,,,,,,,,,,,,,,,, -125115,,,0.9956167340278624,0.0137162897735834,0.7831922579711271,0.9869558811187744,0.0505871586501598,0.2866426636904639,43793.0,0.9861329793930054,0.0542605929076671,0.2759794268804273,43793.0,39394.45890569687,59880.65182375908,39394.45890569687,20476.22482061386,6.685439109802246,0.0 -125200,0.15693624,0.017550102,,,,,,,,,,,,,,,,, -125300,0.13336934,0.017971462,,,,,,,,,,,,,,,,, -125400,0.1448613,0.019209769,,,,,,,,,,,,,,,,, -125500,0.14419726,0.018241726,,,,,,,,,,,,,,,,, -125600,0.15752365,0.018341888,,,,,,,,,,,,,,,,, -125700,0.14521457,0.018902969,,,,,,,,,,,,,,,,, -125800,0.1296695,0.014833211,,,,,,,,,,,,,,,,, -125877,,,0.995496392250061,0.0140708796679973,0.7815206147351226,0.9869558811187744,0.0505871586501598,0.2865493820973373,43793.0,0.9861329793930054,0.0542605891823768,0.2759345952193906,43793.0,39634.669246912,60236.08238339424,39634.669246912,20591.374747753143,6.735987901687622,0.0 -125900,0.15225725,0.018361108,,,,,,,,,,,,,,,,, -126000,0.15366615,0.018349288,,,,,,,,,,,,,,,,, -126100,0.1431658,0.01834503,,,,,,,,,,,,,,,,, -126200,0.137812,0.016187308,,,,,,,,,,,,,,,,, -126300,0.13694036,0.017702647,,,,,,,,,,,,,,,,, -126400,0.15839425,0.018753346,,,,,,,,,,,,,,,,, -126500,0.15223785,0.018335035,,,,,,,,,,,,,,,,, -126600,0.15936929,0.01895246,,,,,,,,,,,,,,,,, -126634,,,0.995526134967804,0.0139440223574638,0.7840203056045634,0.9869558811187744,0.0505871586501598,0.2865940628883549,43793.0,0.9861329793930054,0.0542605891823768,0.2759888347376683,43793.0,39874.781376600266,60588.30221319199,39874.781376600266,20703.41284537316,6.785962104797363,0.0 -126700,0.14303553,0.018341776,,,,,,,,,,,,,,,,, -126800,0.13676353,0.016788278,,,,,,,,,,,,,,,,, -126900,0.13332973,0.017032778,,,,,,,,,,,,,,,,, -127000,0.15837865,0.019273534,,,,,,,,,,,,,,,,, -127100,0.15002696,0.017842658,,,,,,,,,,,,,,,,, -127200,0.15216143,0.018339245,,,,,,,,,,,,,,,,, -127300,0.1316376,0.015478019,,,,,,,,,,,,,,,,, -127383,,,0.9955549836158752,0.0139614222571253,0.7755276795761658,0.9869558811187744,0.0505871586501598,0.2864978730948987,43793.0,0.9861329793930054,0.0542605891823768,0.2759701992733522,43793.0,40114.7754137516,60942.00603628159,40114.7754137516,20817.052837133408,6.836190938949585,0.0 -127400,0.15069887,0.01591244,,,,,,,,,,,,,,,,, -127500,0.14279112,0.01828143,,,,,,,,,,,,,,,,, -127600,0.15509078,0.019918628,,,,,,,,,,,,,,,,, -127700,0.13526835,0.016253257,,,,,,,,,,,,,,,,, -127800,0.13642338,0.017774092,,,,,,,,,,,,,,,,, -127900,0.15183094,0.01708057,,,,,,,,,,,,,,,,, -128000,0.14690223,0.015041936,,,,,,,,,,,,,,,,, -128100,0.1447738,0.019481597,,,,,,,,,,,,,,,,, -128144,,,0.9955434203147888,0.0139302490279078,0.7804785749072813,0.9869558811187744,0.0505871586501598,0.2865674623685516,43793.0,0.9861329793930054,0.0542605891823768,0.2759531333787746,43793.0,40354.73989057541,61299.456189870834,40354.73989057541,20934.46791172028,6.886796712875366,0.0 -128200,0.15828073,0.01791437,,,,,,,,,,,,,,,,, -128300,0.14981408,0.016669644,,,,,,,,,,,,,,,,, -128400,0.14982665,0.01798588,,,,,,,,,,,,,,,,, -128500,0.13155621,0.015508244,,,,,,,,,,,,,,,,, -128600,0.16297036,0.017409703,,,,,,,,,,,,,,,,, -128700,0.13709873,0.017168086,,,,,,,,,,,,,,,,, -128800,0.15134965,0.016251054,,,,,,,,,,,,,,,,, -128900,0.14302593,0.01938218,,,,,,,,,,,,,,,,, -128908,,,0.9955562949180604,0.0138796502724289,0.7869704604464649,0.9869558811187744,0.0505871586501598,0.2865487520763469,43793.0,0.9861329793930054,0.0542605891823768,0.2761565504820941,43793.0,40594.930896520615,61651.26204371452,40594.930896520615,21046.012306928635,6.93725061416626,0.0 -129000,0.13441287,0.017275667,,,,,,,,,,,,,,,,, -129100,0.14622872,0.01897105,,,,,,,,,,,,,,,,, -129200,0.1285058,0.015740698,,,,,,,,,,,,,,,,, -129300,0.14443101,0.016567612,,,,,,,,,,,,,,,,, -129400,0.13847414,0.017536422,,,,,,,,,,,,,,,,, -129500,0.12896374,0.016467417,,,,,,,,,,,,,,,,, -129600,0.15353484,0.018869534,,,,,,,,,,,,,,,,, -129672,,,0.995521605014801,0.0139315128326416,0.7730002351900508,0.9869558811187744,0.0505871586501598,0.2865531526774762,43793.0,0.9861329793930054,0.0542605891823768,0.2760121926469126,43793.0,40835.15206003189,62006.51690149307,40835.15206003189,21160.976698875427,6.987164974212647,0.0 -129700,0.16365044,0.017580029,,,,,,,,,,,,,,,,, -129800,0.14049804,0.017115438,,,,,,,,,,,,,,,,, -129900,0.13025026,0.015998712,,,,,,,,,,,,,,,,, -130000,0.14517751,0.015667403,,,,,,,,,,,,,,,,, -130100,0.15547252,0.018498793,,,,,,,,,,,,,,,,, -130200,0.15792073,0.019303536,,,,,,,,,,,,,,,,, -130300,0.15383843,0.019461583,,,,,,,,,,,,,,,,, -130400,0.13082294,0.017713383,,,,,,,,,,,,,,,,, -130436,,,0.9955735206604004,0.0138278417289257,0.7775615790059116,0.9869558811187744,0.0505871586501598,0.2866122372734342,43793.0,0.9861329793930054,0.0542605929076671,0.2760328250730925,43793.0,41075.38081145287,62356.98491954804,41075.38081145287,21271.14502310753,7.038794279098511,0.0 -130500,0.13999553,0.017888283,,,,,,,,,,,,,,,,, -130600,0.14138605,0.018570136,,,,,,,,,,,,,,,,, -130700,0.13564792,0.017149562,,,,,,,,,,,,,,,,, -130800,0.13880724,0.01709973,,,,,,,,,,,,,,,,, -130900,0.1509101,0.01845826,,,,,,,,,,,,,,,,, -131000,0.14227474,0.018172761,,,,,,,,,,,,,,,,, -131100,0.17573506,0.019962825,,,,,,,,,,,,,,,,, -131200,,,0.9955013394355774,0.0140891848132014,0.7765810385292532,0.9869558811187744,0.0505871586501598,0.2865142023598593,43793.0,0.9861329793930054,0.0542605891823768,0.2760851045251188,43793.0,41315.339268922806,62712.39630937576,41315.339268922806,21386.529180288315,7.088238716125488,0.0 -131200,0.13410056,0.016752651,,,,,,,,,,,,,,,,, -131300,0.1683312,0.021341845,,,,,,,,,,,,,,,,, -131400,0.1387905,0.017947348,,,,,,,,,,,,,,,,, -131500,0.14692992,0.015353777,,,,,,,,,,,,,,,,, -131600,0.14755209,0.017994966,,,,,,,,,,,,,,,,, -131700,0.14885125,0.0169255,,,,,,,,,,,,,,,,, -131800,0.14191714,0.017669916,,,,,,,,,,,,,,,,, -131900,0.15540601,0.019770259,,,,,,,,,,,,,,,,, -131968,,,0.9955359101295472,0.0139632923528552,0.7859540350520396,0.9869558811187744,0.0505871586501598,0.2865155375239707,43793.0,0.9861329793930054,0.0542605891823768,0.2760290599539537,43793.0,41555.53718948364,63059.76567673683,41555.53718948364,21493.62951111793,7.139894247055054,0.0 -132000,0.12969905,0.018359631,,,,,,,,,,,,,,,,, -132100,0.13801302,0.01817676,,,,,,,,,,,,,,,,, -132200,0.15404253,0.016163576,,,,,,,,,,,,,,,,, -132300,0.1437135,0.018468017,,,,,,,,,,,,,,,,, -132400,0.14825898,0.018625546,,,,,,,,,,,,,,,,, -132500,0.14786741,0.01648637,,,,,,,,,,,,,,,,, -132600,0.1427656,0.017136727,,,,,,,,,,,,,,,,, -132700,0.16260791,0.017768536,,,,,,,,,,,,,,,,, -132736,,,0.9955663084983826,0.0137974023818969,0.780957904115043,0.9869558811187744,0.0505871586501598,0.2864971390162966,43793.0,0.9861329793930054,0.0542605891823768,0.2759484631551082,43793.0,41795.69300460816,63411.89345788956,41795.69300460816,21605.530052900314,7.191572189331055,0.0 -132800,0.16703491,0.01787547,,,,,,,,,,,,,,,,, -132900,0.15212381,0.017595287,,,,,,,,,,,,,,,,, -133000,0.1628784,0.020413164,,,,,,,,,,,,,,,,, -133100,0.14137608,0.017211452,,,,,,,,,,,,,,,,, -133200,0.1395802,0.01708377,,,,,,,,,,,,,,,,, -133300,0.16774227,0.020334609,,,,,,,,,,,,,,,,, -133400,0.1496682,0.017564835,,,,,,,,,,,,,,,,, -133500,0.12804316,0.01709183,,,,,,,,,,,,,,,,, -133501,,,0.9955506920814514,0.0139207448810338,0.7788296776529676,0.9869558811187744,0.0505871586501598,0.2866552551738057,43793.0,0.9861329793930054,0.0542605891823768,0.2759436747545231,43793.0,42035.6406815052,63761.36307883263,42035.6406815052,21714.979140996933,7.244651079177856,0.0 -133600,0.12937707,0.016811091,,,,,,,,,,,,,,,,, -133700,0.137122,0.017872319,,,,,,,,,,,,,,,,, -133800,0.15434372,0.017598815,,,,,,,,,,,,,,,,, -133900,0.13996252,0.019572848,,,,,,,,,,,,,,,,, -134000,0.15374948,0.018647162,,,,,,,,,,,,,,,,, -134100,0.13766935,0.018623717,,,,,,,,,,,,,,,,, -134200,0.15516648,0.01876773,,,,,,,,,,,,,,,,, -134266,,,0.995567262172699,0.0138367991894483,0.7788725334472248,0.9869558811187744,0.0505871586501598,0.2864781178666014,43793.0,0.9861329793930054,0.0542605929076671,0.2760273170786373,43793.0,42275.762905836105,64109.81100869179,42275.762905836105,21823.23331308365,7.296629428863525,0.0 -134300,0.14778659,0.01571636,,,,,,,,,,,,,,,,, -134400,0.15233324,0.01805979,,,,,,,,,,,,,,,,, -134500,0.17438996,0.022684565,,,,,,,,,,,,,,,,, -134600,0.12776828,0.016528487,,,,,,,,,,,,,,,,, -134700,0.14530876,0.018748144,,,,,,,,,,,,,,,,, -134800,0.12693179,0.017403638,,,,,,,,,,,,,,,,, -134900,0.13954303,0.017910117,,,,,,,,,,,,,,,,, -135000,0.14367439,0.018808572,,,,,,,,,,,,,,,,, -135010,,,0.99552983045578,0.0139883216470479,0.7758701306845568,0.9869558811187744,0.0505871586501598,0.2865875693942649,43793.0,0.9861329793930054,0.0542605891823768,0.2760359550998511,43793.0,42515.91206884384,64470.954426050186,42515.91206884384,21944.152539014816,7.349959850311279,0.0 -135100,0.14596476,0.019296462,,,,,,,,,,,,,,,,, -135200,0.15026079,0.01739493,,,,,,,,,,,,,,,,, -135300,0.14723206,0.017634014,,,,,,,,,,,,,,,,, -135400,0.150794,0.01891831,,,,,,,,,,,,,,,,, -135500,0.11744264,0.016660584,,,,,,,,,,,,,,,,, -135600,0.14578655,0.017414501,,,,,,,,,,,,,,,,, -135700,0.13787943,0.018033335,,,,,,,,,,,,,,,,, -135753,,,0.9955183863639832,0.0140597959980368,0.7834770621587747,0.9869558811187744,0.0505871586501598,0.2864998181212718,43793.0,0.9861329793930054,0.0542605891823768,0.2760901919326172,43793.0,42756.0000565052,64826.79478096962,42756.0000565052,22059.821719884872,7.40891695022583,0.0 -135800,0.14926185,0.017922135,,,,,,,,,,,,,,,,, -135900,0.15861599,0.0200707,,,,,,,,,,,,,,,,, -136000,0.14741927,0.015459471,,,,,,,,,,,,,,,,, -136100,0.1416023,0.017290212,,,,,,,,,,,,,,,,, -136200,0.1478977,0.018043386,,,,,,,,,,,,,,,,, -136300,0.1425406,0.017422497,,,,,,,,,,,,,,,,, -136400,0.16243164,0.017904447,,,,,,,,,,,,,,,,, -136500,0.14977083,0.018317658,,,,,,,,,,,,,,,,, -136512,,,0.995557963848114,0.0137953907251358,0.7795278572179078,0.9869558811187744,0.0505871586501598,0.2864227948563459,43793.0,0.9861329793930054,0.0542605891823768,0.2759971165969379,43793.0,42996.08527255058,65178.19145774841,42996.08527255058,22171.06094479561,7.4614715576171875,0.0 -136600,0.13495101,0.01469146,,,,,,,,,,,,,,,,, -136700,0.13962531,0.017364075,,,,,,,,,,,,,,,,, -136800,0.15712619,0.020295069,,,,,,,,,,,,,,,,, -136900,0.15852734,0.01969756,,,,,,,,,,,,,,,,, -137000,0.13262285,0.018440498,,,,,,,,,,,,,,,,, -137100,0.13115203,0.017703192,,,,,,,,,,,,,,,,, -137200,0.13516194,0.016584124,,,,,,,,,,,,,,,,, -137273,,,0.9955756664276124,0.0138542233034968,0.7806127574085215,0.9869558811187744,0.0505871586501598,0.2866201386157032,43793.0,0.9861329793930054,0.0542605891823768,0.2759274889810704,43793.0,43236.27346968651,65531.63501739502,43236.27346968651,22284.24476671219,7.513568878173828,0.0 -137300,0.15354745,0.017631887,,,,,,,,,,,,,,,,, -137400,0.13678378,0.017909808,,,,,,,,,,,,,,,,, -137500,0.13359103,0.017763076,,,,,,,,,,,,,,,,, -137600,0.15706275,0.018883191,,,,,,,,,,,,,,,,, -137700,0.1525192,0.017560167,,,,,,,,,,,,,,,,, -137800,0.13840757,0.018039608,,,,,,,,,,,,,,,,, -137900,0.16453229,0.017767282,,,,,,,,,,,,,,,,, -138000,0.13372156,0.017718561,,,,,,,,,,,,,,,,, -138035,,,0.9955047369003296,0.0140287214890122,0.776828284611313,0.9869558811187744,0.0505871586501598,0.286563876407547,43793.0,0.9861329793930054,0.0542605891823768,0.2759771559511502,43793.0,43476.50746488571,65882.75573444366,43476.50746488571,22395.058416366577,7.567029237747192,0.0 -138100,0.13604037,0.015391165,,,,,,,,,,,,,,,,, -138200,0.14509645,0.017455097,,,,,,,,,,,,,,,,, -138300,0.13750464,0.018868264,,,,,,,,,,,,,,,,, -138400,0.13924062,0.018752357,,,,,,,,,,,,,,,,, -138500,0.1475998,0.018369375,,,,,,,,,,,,,,,,, -138600,0.16033141,0.019327287,,,,,,,,,,,,,,,,, -138700,0.14119321,0.014206568,,,,,,,,,,,,,,,,, -138800,0.14492385,0.017879205,,,,,,,,,,,,,,,,, -138801,,,0.9955438375473022,0.0139393527060747,0.7791969983498022,0.9869558811187744,0.0505871586501598,0.2866075502097671,43793.0,0.9861329793930054,0.0542605929076671,0.2760214367974019,43793.0,43716.56052541733,66234.73925423622,43716.56052541733,22506.916797161102,7.619933605194092,0.0 -138900,0.15072674,0.019720094,,,,,,,,,,,,,,,,, -139000,0.15346591,0.017718652,,,,,,,,,,,,,,,,, -139100,0.14537728,0.018521564,,,,,,,,,,,,,,,,, -139200,0.13858362,0.015823191,,,,,,,,,,,,,,,,, -139300,0.13854593,0.017758576,,,,,,,,,,,,,,,,, -139400,0.13368657,0.017936109,,,,,,,,,,,,,,,,, -139500,0.13573402,0.016904565,,,,,,,,,,,,,,,,, -139549,,,0.9955278635025024,0.0140000423416495,0.7743114486500873,0.9869558811187744,0.0505871586501598,0.2867690866867088,43793.0,0.9861329793930054,0.0542605891823768,0.2760094643162549,43793.0,43956.76708507538,66582.52459478378,43956.76708507538,22614.423523902893,7.672637939453125,0.0 -139600,0.12788193,0.016121713,,,,,,,,,,,,,,,,, -139700,0.14923504,0.017806524,,,,,,,,,,,,,,,,, -139800,0.1375551,0.017285341,,,,,,,,,,,,,,,,, -139900,0.16500919,0.016362209,,,,,,,,,,,,,,,,, -140000,0.14022076,0.019038592,,,,,,,,,,,,,,,,, -140100,0.14412113,0.019232737,,,,,,,,,,,,,,,,, -140200,0.1461644,0.017247291,,,,,,,,,,,,,,,,, -140300,0.13927396,0.017872708,,,,,,,,,,,,,,,,, -140319,,,0.9955880641937256,0.0137920398265123,0.7879608890674992,0.9869558811187744,0.0505871586501598,0.2864734779842767,43793.0,0.9861329793930054,0.0542605929076671,0.2759517549835037,43793.0,44196.79552650452,66931.73331785202,44196.79552650452,22723.5324792862,7.724453449249268,0.0 -140400,0.15709482,0.019052822,,,,,,,,,,,,,,,,, -140500,0.13945837,0.01628427,,,,,,,,,,,,,,,,, -140600,0.13156374,0.018552465,,,,,,,,,,,,,,,,, -140700,0.14976926,0.019112678,,,,,,,,,,,,,,,,, -140800,0.15705177,0.01768572,,,,,,,,,,,,,,,,, -140900,0.14046025,0.018608639,,,,,,,,,,,,,,,,, -141000,0.15622404,0.018323287,,,,,,,,,,,,,,,,, -141081,,,0.9955501556396484,0.0138605190441012,0.7867203846922961,0.9869558811187744,0.0505871586501598,0.2865235551818129,43793.0,0.9861329793930054,0.0542605891823768,0.2760356463316414,43793.0,44436.87666225433,67283.75227689743,44436.87666225433,22835.397917747498,7.777060031890869,0.0 -141100,0.13160656,0.014862265,,,,,,,,,,,,,,,,, -141200,0.14573628,0.01692789,,,,,,,,,,,,,,,,, -141300,0.15686634,0.019272996,,,,,,,,,,,,,,,,, -141400,0.13514055,0.016399529,,,,,,,,,,,,,,,,, -141500,0.17372075,0.020354642,,,,,,,,,,,,,,,,, -141600,0.15465027,0.020604676,,,,,,,,,,,,,,,,, -141700,0.14441831,0.01704593,,,,,,,,,,,,,,,,, -141800,0.13659103,0.017778693,,,,,,,,,,,,,,,,, -141837,,,0.9954988956451416,0.0140383094549179,0.7701409373357102,0.9869558811187744,0.0505871586501598,0.2866630419810305,43793.0,0.9861329793930054,0.0542605891823768,0.2760171261581809,43793.0,44676.84755587578,67632.89763379097,44676.84755587578,22944.49940443039,7.830647706985474,0.0 -141900,0.1393529,0.017275468,,,,,,,,,,,,,,,,, -142000,0.14214529,0.017051741,,,,,,,,,,,,,,,,, -142100,0.1291428,0.01725657,,,,,,,,,,,,,,,,, -142200,0.15000303,0.019009376,,,,,,,,,,,,,,,,, -142300,0.14464074,0.019114492,,,,,,,,,,,,,,,,, -142400,0.14446947,0.01873774,,,,,,,,,,,,,,,,, -142500,0.12811401,0.016578367,,,,,,,,,,,,,,,,, -142600,0.1486369,0.016628996,,,,,,,,,,,,,,,,, -142608,,,0.9955378770828248,0.0138828856870532,0.7796853291287564,0.9869558811187744,0.0505871586501598,0.2864522326528279,43793.0,0.9861329793930054,0.0542605891823768,0.275967715529291,43793.0,44916.99302315712,67981.48130369186,44916.99302315712,23052.86403822899,7.88399076461792,0.0 -142700,0.13600558,0.016162472,,,,,,,,,,,,,,,,, -142800,0.14368965,0.017659271,,,,,,,,,,,,,,,,, -142900,0.15352267,0.019323518,,,,,,,,,,,,,,,,, -143000,0.14175041,0.017488366,,,,,,,,,,,,,,,,, -143100,0.16185573,0.018261978,,,,,,,,,,,,,,,,, -143200,0.15955666,0.018008932,,,,,,,,,,,,,,,,, -143300,0.13482267,0.01810784,,,,,,,,,,,,,,,,, -143365,,,0.995568573474884,0.0139041766524314,0.7736875430872415,0.9869558811187744,0.0505871586501598,0.2865738946807334,43793.0,0.9861329793930054,0.0542605891823768,0.2759283175225142,43793.0,45156.97582030296,68334.60452127457,45156.97582030296,23165.930099487305,7.938454627990723,0.0 -143400,0.1514285,0.019504806,,,,,,,,,,,,,,,,, -143500,0.1584972,0.018872567,,,,,,,,,,,,,,,,, -143600,0.13989955,0.01759619,,,,,,,,,,,,,,,,, -143700,0.12740041,0.016314778,,,,,,,,,,,,,,,,, -143800,0.1518561,0.017832983,,,,,,,,,,,,,,,,, -143900,0.14337409,0.018993147,,,,,,,,,,,,,,,,, -144000,0.15442812,0.017432079,,,,,,,,,,,,,,,,, -144100,0.13588049,0.017694598,,,,,,,,,,,,,,,,, -144129,,,0.9955149292945862,0.0140272751450538,0.7838841100376359,0.9869558811187744,0.0505871586501598,0.2865780476781701,43793.0,0.9861329793930054,0.0542605929076671,0.2759656585305678,43793.0,45396.94107532501,68681.67007613182,45396.94107532501,23272.95647072792,7.992566347122192,0.0 -144200,0.14755565,0.020729829,,,,,,,,,,,,,,,,, -144300,0.13318573,0.017171618,,,,,,,,,,,,,,,,, -144400,0.13592914,0.016233673,,,,,,,,,,,,,,,,, -144500,0.14281934,0.018704966,,,,,,,,,,,,,,,,, -144600,0.13920079,0.018434495,,,,,,,,,,,,,,,,, -144700,0.1373064,0.017736584,,,,,,,,,,,,,,,,, -144800,0.14935638,0.0177377,,,,,,,,,,,,,,,,, -144900,0.155049,0.019937672,,,,,,,,,,,,,,,,, -144901,,,0.9955885410308838,0.0137948272749781,0.781838371574247,0.9869558811187744,0.0505871586501598,0.2865226670030312,43793.0,0.9861329793930054,0.0542605929076671,0.2760008915756654,43793.0,45637.14063572884,69025.29367923737,45637.14063572884,23376.3064198494,8.046898603439331,0.0 -145000,0.1449211,0.017828697,,,,,,,,,,,,,,,,, -145100,0.14260739,0.0189635,,,,,,,,,,,,,,,,, -145200,0.15785925,0.016961126,,,,,,,,,,,,,,,,, -145300,0.14375106,0.017686093,,,,,,,,,,,,,,,,, -145400,0.13847975,0.015614916,,,,,,,,,,,,,,,,, -145500,0.13772766,0.01594754,,,,,,,,,,,,,,,,, -145600,0.15789624,0.019277252,,,,,,,,,,,,,,,,, -145674,,,0.995506227016449,0.0140169616788625,0.7722065956630976,0.9869558811187744,0.0505871586501598,0.2865010927262363,43793.0,0.9861329793930054,0.0542605891823768,0.276041130417176,43793.0,45877.35671019554,69374.78163695335,45877.35671019554,23485.505282640457,8.099945545196533,0.0 -145700,0.13667215,0.017137166,,,,,,,,,,,,,,,,, -145800,0.15212946,0.016525678,,,,,,,,,,,,,,,,, -145900,0.14775531,0.01985887,,,,,,,,,,,,,,,,, -146000,0.13307439,0.015005555,,,,,,,,,,,,,,,,, -146100,0.13930178,0.01662119,,,,,,,,,,,,,,,,, -146200,0.1584865,0.016727708,,,,,,,,,,,,,,,,, -146300,0.14422187,0.019540794,,,,,,,,,,,,,,,,, -146400,0.13342024,0.015744325,,,,,,,,,,,,,,,,, -146443,,,0.995516002178192,0.0139518873766064,0.7824941855650105,0.9869558811187744,0.0505871586501598,0.286583906969833,43793.0,0.9861329793930054,0.0542605891823768,0.2759317784239599,43793.0,46117.34045815468,69722.8755030632,46117.34045815468,23593.540898799896,8.154316425323486,0.0 -146500,0.15023527,0.019654559,,,,,,,,,,,,,,,,, -146600,0.14572445,0.018770425,,,,,,,,,,,,,,,,, -146700,0.14376993,0.018137665,,,,,,,,,,,,,,,,, -146800,0.14043899,0.018584896,,,,,,,,,,,,,,,,, -146900,0.13700871,0.017801756,,,,,,,,,,,,,,,,, -147000,0.12876831,0.017153729,,,,,,,,,,,,,,,,, -147100,0.13506803,0.018300926,,,,,,,,,,,,,,,,, -147200,0.15830946,0.018305138,,,,,,,,,,,,,,,,, -147215,,,0.9956077337265016,0.01379029545933,0.7769709209696616,0.9869558811187744,0.0505871586501598,0.2864189844674121,43793.0,0.9861329793930054,0.0542605891823768,0.2759693034880986,43793.0,46357.51933145523,70071.06904149055,46357.51933145523,23701.482311964035,8.207560777664185,0.0 -147300,0.15316655,0.019531,,,,,,,,,,,,,,,,, -147400,0.14423901,0.017776174,,,,,,,,,,,,,,,,, -147500,0.15630445,0.021740945,,,,,,,,,,,,,,,,, -147600,0.14696552,0.018429145,,,,,,,,,,,,,,,,, -147700,0.13435842,0.016870987,,,,,,,,,,,,,,,,, -147800,0.14996238,0.019127863,,,,,,,,,,,,,,,,, -147900,0.16571774,0.020037988,,,,,,,,,,,,,,,,, -147989,,,0.9954946637153624,0.0140935676172375,0.7765513850956582,0.9869558811187744,0.0505871586501598,0.2865399241009378,43793.0,0.9861329793930054,0.0542605891823768,0.2759677164548653,43793.0,46597.6997013092,70418.25558948517,46597.6997013092,23808.41632771492,8.259890794754028,0.0 -148000,0.1510911,0.019269798,,,,,,,,,,,,,,,,, -148100,0.15453231,0.017945226,,,,,,,,,,,,,,,,, -148200,0.13828763,0.016288089,,,,,,,,,,,,,,,,, -148300,0.153302,0.019923864,,,,,,,,,,,,,,,,, -148400,0.15167175,0.018384743,,,,,,,,,,,,,,,,, -148500,0.13122573,0.017853921,,,,,,,,,,,,,,,,, -148600,0.14211366,0.017388688,,,,,,,,,,,,,,,,, -148700,0.1401021,0.018033056,,,,,,,,,,,,,,,,, -148766,,,0.9955323934555054,0.0139013547450304,0.7849211130504072,0.9869558811187744,0.0505871586501598,0.286538708003245,43793.0,0.9861329793930054,0.0542605891823768,0.2759949770797097,43793.0,46837.91800737381,70776.96406245232,46837.91800737381,23926.83319377899,8.31392502784729,0.0 -148800,0.13960437,0.016680494,,,,,,,,,,,,,,,,, -148900,0.14129448,0.019198937,,,,,,,,,,,,,,,,, -149000,0.13748102,0.014869105,,,,,,,,,,,,,,,,, -149100,0.1288074,0.016688986,,,,,,,,,,,,,,,,, -149200,0.14710182,0.018634917,,,,,,,,,,,,,,,,, -149300,0.14025559,0.017806375,,,,,,,,,,,,,,,,, -149400,0.15634753,0.017400408,,,,,,,,,,,,,,,,, -149500,0.17033957,0.018985508,,,,,,,,,,,,,,,,, -149502,,,0.9955655932426452,0.0138668473809957,0.7834967595954778,0.9869558811187744,0.0505871586501598,0.2865266590790944,43793.0,0.9861329793930054,0.0542605891823768,0.2759343881185401,43793.0,47078.12951970101,71135.46232533455,47078.12951970101,24045.02954697609,8.379383087158203,0.0 -149600,0.13609825,0.017035129,,,,,,,,,,,,,,,,, -149700,0.15013695,0.017945966,,,,,,,,,,,,,,,,, -149800,0.14261255,0.017803866,,,,,,,,,,,,,,,,, -149900,0.13672979,0.01850397,,,,,,,,,,,,,,,,, -150000,0.12513381,0.016004,,,,,,,,,,,,,,,,, -150100,0.12648842,0.01563355,,,,,,,,,,,,,,,,, -150200,0.13798764,0.018140873,,,,,,,,,,,,,,,,, -150227,,,0.9955136179924012,0.0139957321807742,0.7791588996594946,0.9869558811187744,0.0505871586501598,0.2867352038186914,43793.0,0.9861329793930054,0.0542605929076671,0.2760130703989507,43793.0,47318.37801861763,71493.53320717812,47318.37801861763,24162.76766180992,8.43949007987976,0.0 -150300,0.14118752,0.01588794,,,,,,,,,,,,,,,,, -150400,0.14533605,0.016059242,,,,,,,,,,,,,,,,, -150500,0.1406409,0.016578646,,,,,,,,,,,,,,,,, -150600,0.15870205,0.020307671,,,,,,,,,,,,,,,,, -150700,0.13708256,0.016643913,,,,,,,,,,,,,,,,, -150800,0.15144725,0.019786129,,,,,,,,,,,,,,,,, -150900,0.13904652,0.018069869,,,,,,,,,,,,,,,,, -150979,,,0.9955903887748718,0.0137750133872032,0.7747467790100341,0.9869558811187744,0.0505871586501598,0.2866331928863107,43793.0,0.9861329793930054,0.0542605891823768,0.2759236825019268,43793.0,47558.30167579651,71842.37278032303,47558.30167579651,24271.599202156067,8.502744197845459,0.0 -151000,0.16930524,0.01771588,,,,,,,,,,,,,,,,, -151100,0.1420852,0.016778152,,,,,,,,,,,,,,,,, -151200,0.1489826,0.019467898,,,,,,,,,,,,,,,,, -151300,0.13244884,0.018162467,,,,,,,,,,,,,,,,, -151400,0.14536907,0.018297354,,,,,,,,,,,,,,,,, -151500,0.15421708,0.019482454,,,,,,,,,,,,,,,,, -151600,0.14783892,0.017596046,,,,,,,,,,,,,,,,, -151700,0.14782712,0.01809737,,,,,,,,,,,,,,,,, -151750,,,0.9955427050590516,0.014013847336173,0.7803063834457855,0.9869558811187744,0.0505871586501598,0.2865761340054181,43793.0,0.9861329793930054,0.0542605891823768,0.2760510306082031,43793.0,47798.30082893372,72190.26391029358,47798.30082893372,24379.415264368057,8.558963775634766,0.0 -151800,0.13260828,0.015225543,,,,,,,,,,,,,,,,, -151900,0.13255867,0.017348763,,,,,,,,,,,,,,,,, -152000,0.1540473,0.018199285,,,,,,,,,,,,,,,,, -152100,0.15891625,0.018851588,,,,,,,,,,,,,,,,, -152200,0.16361435,0.019924777,,,,,,,,,,,,,,,,, -152300,0.15464376,0.019771481,,,,,,,,,,,,,,,,, -152400,0.16933514,0.01959946,,,,,,,,,,,,,,,,, -152500,0.119907,0.017376576,,,,,,,,,,,,,,,,, -152523,,,0.9955410957336426,0.0139313545078039,0.782091041865194,0.9869558811187744,0.0505871586501598,0.2866111059765719,43793.0,0.9861329793930054,0.0542605891823768,0.275992589903956,43793.0,48038.29620957375,72535.97880601883,48038.29620957375,24485.05907702446,8.614659547805786,0.0 -152600,0.14092632,0.017103298,,,,,,,,,,,,,,,,, -152700,0.14266726,0.015840746,,,,,,,,,,,,,,,,, -152800,0.15701124,0.021594344,,,,,,,,,,,,,,,,, -152900,0.14076754,0.017074525,,,,,,,,,,,,,,,,, -153000,0.13168159,0.016588176,,,,,,,,,,,,,,,,, -153100,0.16401184,0.02051515,,,,,,,,,,,,,,,,, -153200,0.13662297,0.018200984,,,,,,,,,,,,,,,,, -153291,,,0.9955719709396362,0.0138335553929209,0.7839643969392222,0.9869558811187744,0.0505871586501598,0.2865632471513083,43793.0,0.9861329793930054,0.0542605929076671,0.2761404118366132,43793.0,48278.4686756134,72885.48430514336,48278.4686756134,24594.315055131912,8.67185354232788,0.0 -153300,0.13726062,0.018244796,,,,,,,,,,,,,,,,, -153400,0.16471466,0.019058961,,,,,,,,,,,,,,,,, -153500,0.14152297,0.019339537,,,,,,,,,,,,,,,,, -153600,0.13177752,0.01673299,,,,,,,,,,,,,,,,, -153700,0.13520366,0.017694548,,,,,,,,,,,,,,,,, -153800,0.15339296,0.018539075,,,,,,,,,,,,,,,,, -153900,0.1482535,0.016741546,,,,,,,,,,,,,,,,, -154000,0.1409518,0.016524319,,,,,,,,,,,,,,,,, -154068,,,0.9955121278762816,0.0139707941561937,0.7735758383914202,0.9869558811187744,0.0505871586501598,0.2867517443359907,43793.0,0.9861329793930054,0.0542605891823768,0.2759883245069869,43793.0,48518.43424797058,73231.79427218437,48518.43424797058,24700.583837985992,8.727308988571167,0.0 -154100,0.15154715,0.017489182,,,,,,,,,,,,,,,,, -154200,0.13550992,0.017294504,,,,,,,,,,,,,,,,, -154300,0.14314768,0.018868199,,,,,,,,,,,,,,,,, -154400,0.16848426,0.017385695,,,,,,,,,,,,,,,,, -154500,0.13398455,0.016597517,,,,,,,,,,,,,,,,, -154600,0.13533176,0.018175129,,,,,,,,,,,,,,,,, -154700,0.1339683,0.017120399,,,,,,,,,,,,,,,,, -154800,0.15534279,0.01924532,,,,,,,,,,,,,,,,, -154839,,,0.995508909225464,0.0139925656840205,0.7794013870917653,0.9869558811187744,0.0505871586501598,0.2865083063356302,43793.0,0.9861329793930054,0.0542605891823768,0.2762042065151867,43793.0,48758.47049832344,73583.22550678253,48758.47049832344,24811.903750658035,8.78290057182312,0.0 -154900,0.14129442,0.016445767,,,,,,,,,,,,,,,,, -155000,0.1463623,0.020832393,,,,,,,,,,,,,,,,, -155100,0.13491069,0.016831636,,,,,,,,,,,,,,,,, -155200,0.12506068,0.016072605,,,,,,,,,,,,,,,,, -155300,0.1404649,0.018813604,,,,,,,,,,,,,,,,, -155400,0.16005027,0.018737275,,,,,,,,,,,,,,,,, -155500,0.13913211,0.016678784,,,,,,,,,,,,,,,,, -155600,0.14697948,0.018072475,,,,,,,,,,,,,,,,, -155612,,,0.9955801963806152,0.0138422073796391,0.7769311812947006,0.9869558811187744,0.0505871586501598,0.2865149377407092,43793.0,0.9861329793930054,0.0542605891823768,0.2759311870209673,43793.0,48998.50784397125,73927.03393220901,48998.50784397125,24915.600333452225,8.837835311889648,0.0 -155700,0.12387594,0.015097794,,,,,,,,,,,,,,,,, -155800,0.14520632,0.016908629,,,,,,,,,,,,,,,,, -155900,0.13202305,0.016986271,,,,,,,,,,,,,,,,, -156000,0.15563186,0.020705862,,,,,,,,,,,,,,,,, -156100,0.16650118,0.02006357,,,,,,,,,,,,,,,,, -156200,0.13523051,0.016676323,,,,,,,,,,,,,,,,, -156300,0.14846998,0.0167058,,,,,,,,,,,,,,,,, -156377,,,0.9955509901046752,0.0139570720493793,0.7760353013748329,0.9869558811187744,0.0505871586501598,0.2865370845662255,43793.0,0.9861329793930054,0.0542605891823768,0.2760219087501092,43793.0,49238.476517915726,74276.68950462341,49238.476517915726,25025.213356494904,8.892682790756226,0.0 -156400,0.13929659,0.017724358,,,,,,,,,,,,,,,,, -156500,0.14529093,0.018457856,,,,,,,,,,,,,,,,, -156600,0.1433134,0.018961938,,,,,,,,,,,,,,,,, -156700,0.14580981,0.018032337,,,,,,,,,,,,,,,,, -156800,0.1423522,0.01812183,,,,,,,,,,,,,,,,, -156900,0.15553336,0.019464761,,,,,,,,,,,,,,,,, -157000,0.12803319,0.01689292,,,,,,,,,,,,,,,,, -157100,0.13924907,0.017040828,,,,,,,,,,,,,,,,, -157147,,,0.99555104970932,0.0139165678992867,0.7797898173653187,0.9869558811187744,0.0505871586501598,0.2865325905894223,43793.0,0.9861329793930054,0.0542605891823768,0.2760101429392562,43793.0,49478.666414260864,74621.77139091492,49478.666414260864,25130.03033232689,8.947967529296875,0.0 -157200,0.13425678,0.017184718,,,,,,,,,,,,,,,,, -157300,0.14167927,0.018315207,,,,,,,,,,,,,,,,, -157400,0.13424495,0.016983496,,,,,,,,,,,,,,,,, -157500,0.13905771,0.017799107,,,,,,,,,,,,,,,,, -157600,0.13518102,0.016451715,,,,,,,,,,,,,,,,, -157700,0.1506794,0.018714063,,,,,,,,,,,,,,,,, -157800,0.15136404,0.019368853,,,,,,,,,,,,,,,,, -157900,0.13988987,0.017087629,,,,,,,,,,,,,,,,, -157912,,,0.995518445968628,0.013942502439022,0.7842495488099699,0.9869558811187744,0.0505871586501598,0.2866229012628155,43793.0,0.9861329793930054,0.0542605891823768,0.2759528202316587,43793.0,49718.83649778366,74971.6976518631,49718.83649778366,25239.711701631542,9.003626585006714,0.0 -158000,0.1464231,0.01799276,,,,,,,,,,,,,,,,, -158100,0.14031343,0.017886939,,,,,,,,,,,,,,,,, -158200,0.14893186,0.018740345,,,,,,,,,,,,,,,,, -158300,0.15770517,0.020013867,,,,,,,,,,,,,,,,, -158400,0.14634737,0.017975446,,,,,,,,,,,,,,,,, -158500,0.1356737,0.018681744,,,,,,,,,,,,,,,,, -158600,0.13478519,0.016374588,,,,,,,,,,,,,,,,, -158689,,,0.9955437779426576,0.013895777054131,0.7807502432776531,0.9869558811187744,0.0505871586501598,0.2865415916987422,43793.0,0.9861329793930054,0.0542605891823768,0.275956717183867,43793.0,49958.982744932175,75323.44123673439,49958.982744932175,25351.234421491623,9.058332443237305,0.0 -158700,0.14502023,0.016631143,,,,,,,,,,,,,,,,, -158800,0.13644172,0.018146006,,,,,,,,,,,,,,,,, -158900,0.14925967,0.017674169,,,,,,,,,,,,,,,,, -159000,0.15741687,0.01988783,,,,,,,,,,,,,,,,, -159100,0.14195296,0.017550882,,,,,,,,,,,,,,,,, -159200,0.14690799,0.018542333,,,,,,,,,,,,,,,,, -159300,0.13407308,0.016139252,,,,,,,,,,,,,,,,, -159400,0.13873743,0.014887771,,,,,,,,,,,,,,,,, -159457,,,0.9955360293388368,0.0139449564740061,0.7723770871421647,0.9869558811187744,0.0505871586501598,0.2867485931237354,43793.0,0.9861329793930054,0.0542605891823768,0.2759523538735477,43793.0,50199.202701091766,75672.04923796654,50199.202701091766,25459.5462975502,9.114988565444946,0.0 -159500,0.15455848,0.017959977,,,,,,,,,,,,,,,,, -159600,0.17294228,0.02119874,,,,,,,,,,,,,,,,, -159700,0.1471594,0.018381866,,,,,,,,,,,,,,,,, -159800,0.14751099,0.019427797,,,,,,,,,,,,,,,,, -159900,0.13562225,0.015719041,,,,,,,,,,,,,,,,, -160000,0.17218566,0.018908396,,,,,,,,,,,,,,,,, -160100,0.15051502,0.015610136,,,,,,,,,,,,,,,,, -160200,0.1376498,0.016508263,,,,,,,,,,,,,,,,, -160210,,,0.9955352544784546,0.0139972520992159,0.7826859588401474,0.9869558811187744,0.0505871586501598,0.2866698257560256,43793.0,0.9861329793930054,0.0542605891823768,0.2759910805896102,43793.0,50439.27810120583,76026.75747394562,50439.27810120583,25574.099715471268,9.172712564468384,0.0 -160300,0.12915811,0.017612908,,,,,,,,,,,,,,,,, -160400,0.14929667,0.018871699,,,,,,,,,,,,,,,,, -160500,0.14696819,0.01927679,,,,,,,,,,,,,,,,, -160600,0.1370906,0.017048627,,,,,,,,,,,,,,,,, -160700,0.13811152,0.016557625,,,,,,,,,,,,,,,,, -160800,0.15221322,0.01915078,,,,,,,,,,,,,,,,, -160900,0.13501936,0.015732525,,,,,,,,,,,,,,,,, -160942,,,0.9955859184265136,0.0137790655717253,0.7764327294859807,0.9869558811187744,0.0505871586501598,0.2865785827957917,43793.0,0.9861329793930054,0.0542605891823768,0.2759355625368492,43793.0,50679.41657733917,76385.46918559074,50679.41657733917,25692.58653616905,9.235629558563232,0.0 -161000,0.14170437,0.017578097,,,,,,,,,,,,,,,,, -161100,0.13998276,0.016414594,,,,,,,,,,,,,,,,, -161200,0.15288718,0.019458994,,,,,,,,,,,,,,,,, -161300,0.14727017,0.017451061,,,,,,,,,,,,,,,,, -161400,0.14052796,0.017391836,,,,,,,,,,,,,,,,, -161500,0.15001053,0.017599571,,,,,,,,,,,,,,,,, -161600,0.16079055,0.01852409,,,,,,,,,,,,,,,,, -161700,0.14999986,0.018720016,,,,,,,,,,,,,,,,, -161709,,,0.9955579042434692,0.0139086106792092,0.7868474332606503,0.9869558811187744,0.0505871586501598,0.2865003997388438,43793.0,0.9861329793930054,0.0542605891823768,0.2759996771960736,43793.0,50919.51976323128,76732.0926129818,50919.51976323128,25799.029140234,9.29334807395935,0.0 -161800,0.1493872,0.017547367,,,,,,,,,,,,,,,,, -161900,0.15826464,0.01824855,,,,,,,,,,,,,,,,, -162000,0.121019654,0.01431819,,,,,,,,,,,,,,,,, -162100,0.15876098,0.019951133,,,,,,,,,,,,,,,,, -162200,0.15253289,0.01998298,,,,,,,,,,,,,,,,, -162300,0.13043602,0.016252765,,,,,,,,,,,,,,,,, -162400,0.13749161,0.018188067,,,,,,,,,,,,,,,,, -162480,,,0.9955121874809264,0.0139806494116783,0.7751245426310285,0.9869558811187744,0.0505871586501598,0.2865003863132244,43793.0,0.9861329793930054,0.0542605891823768,0.276074356700649,43793.0,51159.46822762489,77079.21677541733,51159.46822762489,25906.129714250565,9.348551273345947,0.0 -162500,0.16883709,0.01754095,,,,,,,,,,,,,,,,, -162600,0.13864568,0.017161315,,,,,,,,,,,,,,,,, -162700,0.14014305,0.016935598,,,,,,,,,,,,,,,,, -162800,0.13049199,0.015220675,,,,,,,,,,,,,,,,, -162900,0.12347238,0.014561482,,,,,,,,,,,,,,,,, -163000,0.139664,0.017199945,,,,,,,,,,,,,,,,, -163100,0.14204043,0.019151263,,,,,,,,,,,,,,,,, -163200,0.14621526,0.019057823,,,,,,,,,,,,,,,,, -163251,,,0.9955761432647704,0.0138122150674462,0.7789223056384773,0.9869558811187744,0.0505871586501598,0.2865489837939391,43793.0,0.9861329793930054,0.0542605891823768,0.2759357398817103,43793.0,51399.63913941383,77424.70866513252,51399.63913941383,26011.374217510223,9.40532898902893,0.0 -163300,0.14049082,0.018563367,,,,,,,,,,,,,,,,, -163400,0.14416721,0.015797304,,,,,,,,,,,,,,,,, -163500,0.15453164,0.020640764,,,,,,,,,,,,,,,,, -163600,0.13529555,0.01660731,,,,,,,,,,,,,,,,, -163700,0.13812004,0.017247641,,,,,,,,,,,,,,,,, -163800,0.14665903,0.019010907,,,,,,,,,,,,,,,,, -163900,0.14772831,0.015961519,,,,,,,,,,,,,,,,, -164000,0.14630502,0.016893003,,,,,,,,,,,,,,,,, -164020,,,0.9955022931098938,0.0140853282064199,0.7771181152126543,0.9869558811187744,0.0505871586501598,0.2865277191401532,43793.0,0.9861329793930054,0.0542605891823768,0.275977611609337,43793.0,51639.66033124924,77777.46783638,51639.66033124924,26124.034873485565,9.462925434112549,0.0 -164100,0.12970456,0.017010838,,,,,,,,,,,,,,,,, -164200,0.13702214,0.019360164,,,,,,,,,,,,,,,,, -164300,0.13813713,0.017604893,,,,,,,,,,,,,,,,, -164400,0.1372357,0.018309299,,,,,,,,,,,,,,,,, -164500,0.13359408,0.015969776,,,,,,,,,,,,,,,,, -164600,0.14796564,0.018335054,,,,,,,,,,,,,,,,, -164700,0.14133766,0.016218977,,,,,,,,,,,,,,,,, -164793,,,0.995532214641571,0.013914574868977,0.7748441575197567,0.9869558811187744,0.0505871586501598,0.2865104223184633,43793.0,0.9861329793930054,0.0542605891823768,0.2760266176655251,43793.0,51879.79873919487,78124.12903475761,51879.79873919487,26230.481513261795,9.519116163253784,0.0 -164800,0.1452551,0.016192487,,,,,,,,,,,,,,,,, -164900,0.12934208,0.015713029,,,,,,,,,,,,,,,,, -165000,0.12996788,0.016668523,,,,,,,,,,,,,,,,, -165100,0.13973966,0.019053426,,,,,,,,,,,,,,,,, -165200,0.1453682,0.017296044,,,,,,,,,,,,,,,,, -165300,0.14422917,0.017131802,,,,,,,,,,,,,,,,, -165400,0.1647979,0.018738002,,,,,,,,,,,,,,,,, -165500,0.13442205,0.016871423,,,,,,,,,,,,,,,,, -165553,,,0.9955798387527466,0.0138430148363113,0.7861792908162337,0.9869558811187744,0.0505871586501598,0.2865396899834355,43793.0,0.9861329793930054,0.0542605891823768,0.2759600083074585,43793.0,52119.75490355492,78472.34597706795,52119.75490355492,26338.66453528404,9.57723355293274,0.0 -165600,0.14547369,0.017303847,,,,,,,,,,,,,,,,, -165700,0.13941573,0.016433667,,,,,,,,,,,,,,,,, -165800,0.14908232,0.01633817,,,,,,,,,,,,,,,,, -165900,0.15695682,0.017482191,,,,,,,,,,,,,,,,, -166000,0.14580217,0.016244313,,,,,,,,,,,,,,,,, -166100,0.15370695,0.01673879,,,,,,,,,,,,,,,,, -166200,0.1408394,0.017063279,,,,,,,,,,,,,,,,, -166300,0.15445766,0.019160885,,,,,,,,,,,,,,,,, -166322,,,0.9955323338508606,0.0139575526118278,0.7757177991067279,0.9869558811187744,0.0505871586501598,0.2866571959988571,43793.0,0.9861329793930054,0.0542605891823768,0.2760153037468378,43793.0,52359.93154716492,78821.26730561256,52359.93154716492,26447.33196592331,9.634422063827516,0.0 -166400,0.1375156,0.018346887,,,,,,,,,,,,,,,,, -166500,0.15716673,0.018302793,,,,,,,,,,,,,,,,, -166600,0.14110565,0.017633148,,,,,,,,,,,,,,,,, -166700,0.13591196,0.016817689,,,,,,,,,,,,,,,,, -166800,0.13330822,0.015370628,,,,,,,,,,,,,,,,, -166900,0.12902217,0.017167957,,,,,,,,,,,,,,,,, -167000,0.14816014,0.018542904,,,,,,,,,,,,,,,,, -167095,,,0.995536208152771,0.0139263924211263,0.7815679243915041,0.9869558811187744,0.0505871586501598,0.2867707938192554,43793.0,0.9861329793930054,0.0542605891823768,0.2759521436007556,43793.0,52600.01885247231,79169.45726418495,52600.01885247231,26555.263708114624,9.785484075546265,0.0 -167100,0.14975403,0.018733399,,,,,,,,,,,,,,,,, -167200,0.15823643,0.017742557,,,,,,,,,,,,,,,,, -167300,0.14876272,0.018490145,,,,,,,,,,,,,,,,, -167400,0.15771027,0.017571937,,,,,,,,,,,,,,,,, -167500,0.12871644,0.017338078,,,,,,,,,,,,,,,,, -167600,0.14639868,0.017348876,,,,,,,,,,,,,,,,, -167700,0.15782516,0.018313004,,,,,,,,,,,,,,,,, -167800,0.13141534,0.016736727,,,,,,,,,,,,,,,,, -167874,,,0.9955568313598632,0.0139553509652614,0.7682549175407976,0.9869558811187744,0.0505871586501598,0.2865932665129089,43793.0,0.9861329793930054,0.0542605891823768,0.2759558334630442,43793.0,52840.2066552639,79513.43434143066,52840.2066552639,26658.97461915016,9.84364652633667,0.0 -167900,0.14921959,0.017170412,,,,,,,,,,,,,,,,, -168000,0.12805548,0.017659618,,,,,,,,,,,,,,,,, -168100,0.13761804,0.01749515,,,,,,,,,,,,,,,,, -168200,0.13493785,0.016829278,,,,,,,,,,,,,,,,, -168300,0.18039916,0.019729018,,,,,,,,,,,,,,,,, -168400,0.13898653,0.017148225,,,,,,,,,,,,,,,,, -168500,0.15521415,0.018568305,,,,,,,,,,,,,,,,, -168600,0.14138103,0.01697259,,,,,,,,,,,,,,,,, -168616,,,0.9955222606658936,0.0139672039076685,0.7815628210311292,0.9869558811187744,0.0505871586501598,0.2865442472000015,43793.0,0.9861329793930054,0.0542605891823768,0.2760031704897681,43793.0,53080.41065263748,79861.60309934616,53080.41065263748,26766.8594186306,9.901641130447388,0.0 -168700,0.15491666,0.017264057,,,,,,,,,,,,,,,,, -168800,0.1536958,0.016505927,,,,,,,,,,,,,,,,, -168900,0.13656111,0.016527964,,,,,,,,,,,,,,,,, -169000,0.13069643,0.018554287,,,,,,,,,,,,,,,,, -169100,0.14184688,0.01635917,,,,,,,,,,,,,,,,, -169200,0.13012849,0.017490514,,,,,,,,,,,,,,,,, -169300,0.14454122,0.018217646,,,,,,,,,,,,,,,,, -169385,,,0.9955328702926636,0.0139327310025691,0.784643210829564,0.9869558811187744,0.0505871586501598,0.2866681688143653,43793.0,0.9861329793930054,0.0542605891823768,0.2759793808137554,43793.0,53320.526663541794,80208.37276148796,53320.526663541794,26873.43555402756,9.959112644195557,0.0 -169400,0.14498179,0.015121104,,,,,,,,,,,,,,,,, -169500,0.14195186,0.017888276,,,,,,,,,,,,,,,,, -169600,0.14420117,0.01636458,,,,,,,,,,,,,,,,, -169700,0.14507261,0.01750118,,,,,,,,,,,,,,,,, -169800,0.15660371,0.018729763,,,,,,,,,,,,,,,,, -169900,0.1471503,0.020091675,,,,,,,,,,,,,,,,, -170000,0.14455375,0.018369008,,,,,,,,,,,,,,,,, -170100,0.13664609,0.018319562,,,,,,,,,,,,,,,,, -170154,,,0.995567500591278,0.0138555373996496,0.7803541515077737,0.9869558811187744,0.0505871586501598,0.2865558987791517,43793.0,0.9861329793930054,0.0542605929076671,0.2760007219570654,43793.0,53560.66819763184,80554.26143503189,53560.66819763184,26979.10472464561,10.017580509185793,0.0 -170200,0.16163066,0.018047672,,,,,,,,,,,,,,,,, -170300,0.13185689,0.017087342,,,,,,,,,,,,,,,,, -170400,0.13998508,0.016895462,,,,,,,,,,,,,,,,, -170500,0.16696195,0.020842979,,,,,,,,,,,,,,,,, -170600,0.15150808,0.017744485,,,,,,,,,,,,,,,,, -170700,0.13432592,0.017249925,,,,,,,,,,,,,,,,, -170800,0.14088541,0.018368611,,,,,,,,,,,,,,,,, -170900,0.14667846,0.01716776,,,,,,,,,,,,,,,,, -170918,,,0.9955355525016784,0.0139146000146865,0.7782003795864987,0.9869558811187744,0.0505871586501598,0.2867647925173271,43793.0,0.9861329793930054,0.0542605891823768,0.2759980739721762,43793.0,53800.71771574021,80902.11252045631,53800.71771574021,27086.828375577927,10.07615089416504,0.0 -171000,0.14241189,0.018740036,,,,,,,,,,,,,,,,, -171100,0.15470243,0.019903367,,,,,,,,,,,,,,,,, -171200,0.13939051,0.018617796,,,,,,,,,,,,,,,,, -171300,0.15500413,0.017336732,,,,,,,,,,,,,,,,, -171400,0.14716856,0.0181357,,,,,,,,,,,,,,,,, -171500,0.148003,0.017797709,,,,,,,,,,,,,,,,, -171600,0.15120806,0.018379519,,,,,,,,,,,,,,,,, -171685,,,0.9955580830574036,0.013860491104424,0.7741929507380779,0.9869558811187744,0.0505871586501598,0.286599359546828,43793.0,0.9861329793930054,0.0542605891823768,0.2760119254960862,43793.0,54040.759315013885,81252.58443045616,54040.759315013885,27197.180181980133,10.135051727294922,0.0 -171700,0.13989402,0.016907463,,,,,,,,,,,,,,,,, -171800,0.1563695,0.017985001,,,,,,,,,,,,,,,,, -171900,0.1676396,0.019929832,,,,,,,,,,,,,,,,, -172000,0.14829576,0.018425688,,,,,,,,,,,,,,,,, -172100,0.1536908,0.016790314,,,,,,,,,,,,,,,,, -172200,0.16045482,0.017370233,,,,,,,,,,,,,,,,, -172300,0.123993136,0.014823859,,,,,,,,,,,,,,,,, -172400,0.13408507,0.014525,,,,,,,,,,,,,,,,, -172466,,,0.9955554008483888,0.0139900296926498,0.7805614202828458,0.9869558811187744,0.0505871586501598,0.2864266172910606,43793.0,0.9861329793930054,0.0542605891823768,0.2759691651199342,43793.0,54280.77589941025,81601.20845675468,54280.77589941025,27305.70858001709,10.194327116012571,0.0 -172500,0.13554065,0.0166339,,,,,,,,,,,,,,,,, -172600,0.16664754,0.01919258,,,,,,,,,,,,,,,,, -172700,0.16034812,0.018498698,,,,,,,,,,,,,,,,, -172800,0.14714633,0.017221443,,,,,,,,,,,,,,,,, -172900,0.148071,0.017600456,,,,,,,,,,,,,,,,, -173000,0.14408399,0.018505236,,,,,,,,,,,,,,,,, -173100,0.14999665,0.018503135,,,,,,,,,,,,,,,,, -173200,0.14191507,0.0169631,,,,,,,,,,,,,,,,, -173246,,,0.99552983045578,0.0139098232612013,0.7827256491404169,0.9869558811187744,0.0505871586501598,0.2864472859493177,43793.0,0.9861329793930054,0.0542605891823768,0.27606934583939,43793.0,54520.82110333443,81950.11951947212,54520.82110333443,27414.49658846855,10.251972436904907,0.0 -173300,0.13707802,0.01647007,,,,,,,,,,,,,,,,, -173400,0.13314542,0.017743185,,,,,,,,,,,,,,,,, -173500,0.13530603,0.017561845,,,,,,,,,,,,,,,,, -173600,0.14177395,0.016403703,,,,,,,,,,,,,,,,, -173700,0.12520884,0.014774816,,,,,,,,,,,,,,,,, -173800,0.12391244,0.01631018,,,,,,,,,,,,,,,,, -173900,0.14258906,0.018878246,,,,,,,,,,,,,,,,, -174000,0.13373232,0.016696472,,,,,,,,,,,,,,,,, -174025,,,0.9955813884735109,0.0138372285291552,0.7825517451429103,0.9869558811187744,0.0505871586501598,0.2865332651294395,43793.0,0.9861329793930054,0.0542605891823768,0.2760301578149513,43793.0,54760.8744161129,82300.26519680023,54760.8744161129,27524.51087450981,10.31071424484253,0.0 -174100,0.14481628,0.017771048,,,,,,,,,,,,,,,,, -174200,0.13720989,0.01810638,,,,,,,,,,,,,,,,, -174300,0.16672686,0.018629584,,,,,,,,,,,,,,,,, -174400,0.14562635,0.019237194,,,,,,,,,,,,,,,,, -174500,0.13944282,0.017006189,,,,,,,,,,,,,,,,, -174600,0.14012463,0.01812799,,,,,,,,,,,,,,,,, -174700,0.1482247,0.017968614,,,,,,,,,,,,,,,,, -174800,0.13828692,0.015988272,,,,,,,,,,,,,,,,, -174803,,,0.9954863786697388,0.0140195405110716,0.7766700627017158,0.9869558811187744,0.0505871586501598,0.2865937948374639,43793.0,0.9861329793930054,0.0542605891823768,0.275994238581464,43793.0,55001.09853792191,82648.0693461895,55001.09853792191,27632.01285123825,10.36878228187561,0.0 -174900,0.14702526,0.018585516,,,,,,,,,,,,,,,,, -175000,0.14558783,0.016178299,,,,,,,,,,,,,,,,, -175100,0.12809645,0.016657302,,,,,,,,,,,,,,,,, -175200,0.14268067,0.016807025,,,,,,,,,,,,,,,,, -175300,0.15259972,0.018239493,,,,,,,,,,,,,,,,, -175400,0.14465615,0.016558273,,,,,,,,,,,,,,,,, -175500,0.12933505,0.015926268,,,,,,,,,,,,,,,,, -175574,,,0.9955763220787048,0.0138569138944149,0.7794461300147126,0.9869558811187744,0.0505871586501598,0.2865507467684166,43793.0,0.9861329793930054,0.0542605891823768,0.276069050592469,43793.0,55241.06003189087,82997.75070357323,55241.06003189087,27741.65455651284,10.427810430526732,0.0 -175600,0.17663485,0.020162532,,,,,,,,,,,,,,,,, -175700,0.14505327,0.018110873,,,,,,,,,,,,,,,,, -175800,0.13993609,0.016780503,,,,,,,,,,,,,,,,, -175900,0.16183394,0.02106363,,,,,,,,,,,,,,,,, -176000,0.13572513,0.018036004,,,,,,,,,,,,,,,,, -176100,0.14135624,0.016290644,,,,,,,,,,,,,,,,, -176183,,,,,,,,,,,,,,55431.18258333206,,,,,0.0 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/eval_measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/eval_measurements.csv deleted file mode 100644 index d9b3bc739..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/eval_measurements.csv +++ /dev/null @@ -1,51 +0,0 @@ -accumulated_eval_time,accumulated_logging_time,accumulated_submission_time,global_step,preemption_count,score,test/accuracy,test/bleu,test/loss,test/num_examples,total_duration,train/accuracy,train/bleu,train/loss,validation/accuracy,validation/bleu,validation/loss,validation/num_examples -881.3112032413483,0.0,39.65361142158508,1,0,39.65361142158508,0.0007088489946909,0.0,11.113809585571287,3003,920.964860200882,0.0005556122632697,0.0,11.132158279418944,0.0004835649742744,0.0,11.097217559814451,3000 -1382.864819765091,0.0302114486694335,879.6641981601715,2377,0,879.6641981601715,0.5126023888587952,16.82979069214398,2.8845648765563965,3003,2262.6394739151,0.5125483274459839,22.492414050302205,2.858908176422119,0.5142651796340942,18.30720532561121,2.836130142211914,3000 -1944.9201946258545,0.0610411167144775,1719.578450679779,4754,0,1719.578450679779,0.594375729560852,22.238124830388227,2.114021062850952,3003,3664.720867872238,0.5762366056442261,26.972132326701903,2.2599639892578125,0.5911024212837219,23.643337175562927,2.15031099319458,3000 -2375.4861319065094,0.0864703655242919,2559.766467809677,7132,0,2559.766467809677,0.6212770938873291,24.04397100312569,1.88267195224762,3003,4935.580909729004,0.6054196953773499,28.578043861643884,2.010335922241211,0.6175868511199951,25.03961870233312,1.925435900688172,3000 -2833.360686540604,0.1121339797973632,3399.8166558742523,9510,0,3399.8166558742523,0.6384173035621643,25.14738798082741,1.761852741241455,3003,6233.60720205307,0.6144471764564514,29.758036174075343,1.9414385557174685,0.6304943561553955,26.07328233983416,1.8233507871627808,3000 -3258.925538778305,0.138277530670166,4240.059615850449,11889,0,4240.059615850449,0.6477136611938477,25.62417713326412,1.6896036863327026,3003,7499.519518136978,0.6189337968826294,29.89805366521182,1.889572024345398,0.6390745043754578,26.64441501614611,1.7538105249404907,3000 -3720.4962162971497,0.165790319442749,5080.140574455261,14268,0,5080.140574455261,0.6522108316421509,25.7217432979878,1.6444939374923706,3003,8801.274183273315,0.624485433101654,30.31258298412598,1.8488229513168333,0.6445549130439758,26.9599459838488,1.706523299217224,3000 -4237.940180301666,0.1926445960998535,5920.160865306854,16646,0,5920.160865306854,0.6577886343002319,26.537255448332026,1.6086273193359375,3003,10158.843495845796,0.6264616847038269,30.34013010361292,1.830745816230774,0.6479646563529968,27.2750023954531,1.6873875856399536,3000 -4800.146046638489,0.2206084728240966,6760.097064495087,19024,0,6760.097064495087,0.6595317125320435,26.47710090550588,1.5889374017715454,3003,11561.091912984848,0.6484805345535278,32.06108394190534,1.662880301475525,0.6506428718566895,27.19620966158643,1.659526228904724,3000 -5359.059168815613,0.2494087219238281,7600.03609752655,21401,0,7600.03609752655,0.6625297665596008,27.053601061051705,1.5776896476745603,3003,12960.05086159706,0.6325982213020325,30.65917046993292,1.7833012342453003,0.652019202709198,27.546996197827287,1.656695008277893,3000 -5879.748903989792,0.2777047157287597,8440.063966989517,23779,0,8440.063966989517,0.6655743718147278,26.95498017227819,1.5527108907699585,3003,14320.875417470932,0.6351043581962585,30.41944310720691,1.777759313583374,0.6547470092773438,27.47410007445849,1.6266250610351562,3000 -6406.126123189926,0.3062803745269775,9280.269245147703,26158,0,9280.269245147703,0.6680960059165955,27.35011242889108,1.5441603660583496,3003,15687.564682245256,0.6418652534484863,31.214415856423287,1.7051342725753784,0.6578715443611145,28.10912777034857,1.613921284675598,3000 -7009.274637699127,0.3344330787658691,10120.285497665403,28536,0,10120.285497665403,0.6701993346214294,27.660913346503644,1.5299732685089111,3003,17130.83499264717,0.6390178799629211,31.29680946133006,1.7468416690826416,0.6594090461730957,28.322223092791035,1.6080009937286377,3000 -7553.392800569534,0.3641421794891357,10960.433304071426,30915,0,10960.433304071426,0.669839084148407,27.341967471219807,1.522180676460266,3003,18515.20808935165,0.6399180293083191,31.006162959312263,1.7383826971054075,0.659421443939209,27.883103402956426,1.5983082056045532,3000 -8077.200543165207,0.3959531784057617,11800.462206363678,33293,0,11800.462206363678,0.6725698709487915,27.681689309528668,1.5092467069625854,3003,19879.156966924667,0.6442225575447083,31.78765129289712,1.6994339227676392,0.6610333323478699,28.3620764574474,1.5850367546081543,3000 -8613.434501886368,0.4250335693359375,12640.359733104706,35670,0,12640.359733104706,0.6746150851249695,27.680088726518168,1.4966024160385132,3003,21255.39935207367,0.6402103304862976,31.204288012321065,1.7327693700790403,0.6622112393379211,28.52652591972902,1.5775601863861084,3000 -9109.951312541962,0.4539699554443359,13480.456578493118,38049,0,13480.456578493118,0.6761025190353394,28.02324678272048,1.4913530349731443,3003,22592.117666721344,0.6570454835891724,32.339918945597454,1.6043413877487185,0.6630792021751404,28.517361145543024,1.5772230625152588,3000 -9753.249324083328,0.4837007522583008,14320.392687559128,40427,0,14320.392687559128,0.6763232946395874,27.91822635671377,1.482987880706787,3003,24075.46028089524,0.6448788642883301,31.99471937349964,1.695878028869629,0.6641083359718323,28.41923535197316,1.5695807933807373,3000 -10289.211467266085,0.5132253170013428,15160.55874156952,42806,0,15160.55874156952,0.6781709790229797,27.79226369399139,1.4798240661621094,3003,25451.696177721024,0.6462165117263794,31.77797547770024,1.6971834897994995,0.6654722094535828,28.557748479244324,1.557735800743103,3000 -10809.894648313522,0.5482079982757568,16000.626110076904,45184,0,16000.626110076904,0.6798791885375977,28.49927497697626,1.4696717262268066,3003,26812.56391477585,0.6523397564888,32.163557741677224,1.6415144205093384,0.6671088933944702,29.014261662014693,1.5524981021881104,3000 -11538.277686834335,0.579641580581665,16840.60457134247,47562,0,16840.60457134247,0.679332971572876,28.274578857046443,1.4617844820022583,3003,28381.03490614891,0.6457970142364502,31.71187306106621,1.698169231414795,0.6662285327911377,28.7745931280484,1.5489721298217771,3000 -12098.422665834429,0.6169235706329346,17680.672848939896,49942,0,17680.672848939896,0.6789495348930359,28.027025139918145,1.4625924825668335,3003,29781.36154270172,0.6485532522201538,31.9474758616229,1.6859617233276367,0.6677164435386658,28.515111507698936,1.5421111583709717,3000 -12673.795425891876,0.6527571678161621,18520.70462822914,52321,0,18520.70462822914,0.6807622909545898,28.34466917523456,1.4502663612365725,3003,31196.87852740288,0.6484258770942688,31.91824831416189,1.6598361730575562,0.6704814434051514,28.95387046457841,1.5352940559387207,3000 -13207.22210764885,0.6896021366119385,19361.222714662552,54701,0,19361.222714662552,0.6822962164878845,28.36710670375829,1.4418914318084717,3003,32570.93722343445,0.6495399475097656,31.643514298402373,1.6735833883285522,0.6713989973068237,29.22028575150289,1.5312122106552124,3000 -13693.890861272812,0.7218029499053955,20201.32160258293,57079,0,20201.32160258293,0.6844227910041809,28.452016178586515,1.4322528839111328,3003,33897.81595611572,0.6603872179985046,32.61815973523987,1.5857213735580444,0.670741856098175,29.064819889773123,1.5168012380599976,3000 -14384.28313088417,0.7561678886413574,21041.213926553726,59456,0,21041.213926553726,0.6850734949111938,28.68317868851248,1.4224475622177124,3003,35428.21695756912,0.6552941799163818,32.59673470601075,1.624788522720337,0.6723288893699646,29.290384592689776,1.5095142126083374,3000 -14927.5224506855,0.7940044403076172,21881.345286130905,61835,0,21881.345286130905,0.6883040070533752,28.97549138929193,1.4157791137695312,3003,36811.70310878754,0.6523348689079285,32.40031350898231,1.646162033081055,0.6736432313919067,29.26776115359916,1.504443645477295,3000 -15464.544717550278,0.8278894424438477,22721.47764348984,64214,0,22721.47764348984,0.6827958822250366,28.18834542843649,1.4405723810195925,3003,38188.96942901611,0.6573935151100159,32.14257821076103,1.6122422218322754,0.6687331795692444,28.58250826805473,1.5286437273025513,3000 -16107.16114974022,0.8622546195983887,23561.66490626335,66592,0,23561.66490626335,0.6875951886177063,28.568736458659583,1.4001637697219849,3003,39671.88665890694,0.6559733152389526,32.54262723353351,1.6213107109069824,0.6742507815361023,29.09013429115508,1.4933558702468872,3000 -16617.040287017822,0.8978571891784668,24401.740617513657,68970,0,24401.740617513657,0.6907442808151245,29.05332517746374,1.3908847570419312,3003,41021.95736527443,0.6814596652984619,34.3580861992117,1.4550007581710815,0.6773753762245178,29.381252753878385,1.4872820377349854,3000 -17464.169107675552,0.9319391250610352,25241.72496700287,71348,0,25241.72496700287,0.6910116076469421,29.21504352470452,1.389503002166748,3003,42709.18280673027,0.663874626159668,32.56455290267204,1.576033115386963,0.6783424615859985,29.55113498636966,1.4791110754013062,3000 -18037.4237639904,0.9692091941833496,26081.76236653328,73727,0,26081.76236653328,0.6914764046669006,29.265794145234867,1.378501296043396,3003,44122.5883204937,0.6613771915435791,32.78422420721934,1.5929934978485107,0.6800411343574524,29.65939306819033,1.4665690660476685,3000 -18631.639093637463,1.004053831100464,26921.769094228745,76106,0,26921.769094228745,0.6944977045059204,29.374789377540463,1.370276689529419,3003,45556.92244243622,0.6700928211212158,33.644985715093746,1.5198198556900024,0.6783796548843384,29.57411672308368,1.4667606353759766,3000 -19132.193336486816,1.0392353534698486,27761.8942797184,78485,0,27761.8942797184,0.6956830024719238,29.425967056293665,1.360418438911438,3003,46897.715759277344,0.6622284650802612,32.57791200673433,1.5754069089889526,0.6812066435813904,29.76072163868272,1.4563653469085691,3000 -19671.58778452873,1.0748109817504885,28601.873636484142,80863,0,28601.873636484142,0.6951949596405029,29.63272634036333,1.3603163957595823,3003,48277.20259356499,0.6636092662811279,32.70319836498964,1.5703322887420654,0.6823350191116333,30.035218304926424,1.4525598287582395,3000 -20300.30703139305,1.112574815750122,29442.05039691925,83242,0,29442.05039691925,0.6981233358383179,29.69593409300139,1.3538438081741333,3003,49746.21423172951,0.6721797585487366,33.680118352812386,1.513619065284729,0.6828929781913757,29.927711880365525,1.4463530778884888,3000 -20838.189848661423,1.1500585079193115,30282.133563756943,85621,0,30282.133563756943,0.6989948749542236,30.01270004080984,1.3409332036972046,3003,51124.29514193535,0.6682446002960205,33.75045946079044,1.540069341659546,0.6835377216339111,29.889512828415363,1.439979076385498,3000 -21545.33805847168,1.1864256858825684,31122.28181028366,87999,0,31122.28181028366,0.7000523209571838,29.86954274429864,1.3362531661987305,3003,52671.70739459992,0.6825253367424011,34.338008319013625,1.4487786293029783,0.6857943534851074,30.0903180949495,1.4298676252365112,3000 -22204.37990808487,1.2239248752593994,31962.3699696064,90377,0,31962.3699696064,0.7006449699401855,29.98744025410253,1.3293626308441162,3003,54170.957562446594,0.6723057627677917,33.55991972595787,1.5078905820846558,0.6858935356140137,30.049164222142057,1.4237116575241089,3000 -22774.05219650269,1.2687523365020752,32802.475167512894,92756,0,32802.475167512894,0.7013421654701233,29.937114050955184,1.322433352470398,3003,55580.85606193543,0.6719854474067688,33.860206279249766,1.514437198638916,0.6873442530632019,30.474739154210475,1.416701078414917,3000 -23358.04110336304,1.312809705734253,33642.51808953285,95134,0,33642.51808953285,0.7038289904594421,30.08935415843265,1.3092671632766724,3003,57005.01177072525,0.6826393604278564,34.60986378597462,1.4538395404815674,0.6881625652313232,30.472603703054315,1.4122296571731567,3000 -24051.911954164505,1.350311040878296,34482.75514650345,97513,0,34482.75514650345,0.7047934532165527,29.9236223901206,1.3054507970809937,3003,58539.23271560669,0.6825652122497559,34.105125485857464,1.4482752084732056,0.6888568997383118,30.3004675653798,1.4053305387496948,3000 -24694.09686756134,1.3887178897857666,35322.91156697273,99892,0,35322.91156697273,0.7065016627311707,30.29928668877095,1.2986218929290771,3003,60021.6882147789,0.678978443145752,34.306114082887525,1.479453206062317,0.6903820037841797,30.426522454410275,1.4001511335372925,3000 -25294.16486024857,1.4278733730316162,36162.81519198418,102270,0,36162.81519198418,0.7074313163757324,30.136253694643354,1.2908309698104858,3003,61461.7780213356,0.6835367679595947,34.774789100055045,1.4356738328933716,0.6904935836791992,30.6419763551928,1.3909913301467896,3000 -25915.23298954964,1.467015027999878,37002.91240334511,104649,0,37002.91240334511,0.708453893661499,30.46981512380905,1.2837616205215454,3003,62923.05829691887,0.677696704864502,34.298339599665034,1.4766273498535156,0.6918699145317078,30.743370190512355,1.387955665588379,3000 -26560.762880563736,1.5060055255889893,37843.09711742401,107028,0,37843.09711742401,0.7085817456245422,30.52008339596013,1.281222581863403,3003,64408.88877725601,0.6928110718727112,35.15910011581935,1.384670376777649,0.6928246021270752,30.72062201906474,1.384840965270996,3000 -27178.494292497635,1.5452158451080322,38683.2755715847,109407,0,38683.2755715847,0.7109174728393555,30.94646817655173,1.2740641832351685,3003,65866.91502094269,0.6901442408561707,35.2604371709747,1.4055242538452148,0.6927750110626221,30.8121018107398,1.3797883987426758,3000 -27726.88291144371,1.5854876041412354,39523.27041316032,111784,0,39523.27041316032,0.7106966376304626,30.744305941224404,1.268842339515686,3003,67255.41884446144,0.6872426271438599,34.660629651300326,1.4228895902633667,0.6934818029403687,30.673267151118715,1.3747117519378662,3000 -28344.93062567711,1.6279840469360352,40363.24680304527,114162,0,40363.24680304527,0.7109755277633667,30.69732003049124,1.265406370162964,3003,68713.56347870827,0.6885464191436768,35.04633107417319,1.4088932275772097,0.6940521597862244,30.744583577862368,1.3701497316360474,3000 -29004.25678896904,1.6690990924835205,41203.46653318405,116541,0,41203.46653318405,0.7128115892410278,31.03030929029123,1.2598994970321655,3003,70213.22715878487,0.6922299861907959,34.91732924260557,1.387744426727295,0.6947712898254395,30.894717875592825,1.366743564605713,3000 diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/measurements.csv b/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/measurements.csv deleted file mode 100644 index 83d75ac6d..000000000 --- a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/measurements.csv +++ /dev/null @@ -1,1218 +0,0 @@ -global_step,grad_norm,loss,train/accuracy,train/loss,train/bleu,validation/accuracy,validation/loss,validation/bleu,validation/num_examples,test/accuracy,test/loss,test/bleu,test/num_examples,score,total_duration,accumulated_submission_time,accumulated_eval_time,accumulated_logging_time,preemption_count -0,4.9163036,11.113301,,,,,,,,,,,,,,,,, -1,,,0.0005556122632697,11.132158279418944,0.0,0.0004835649742744,11.097217559814451,0.0,3000.0,0.0007088489946909,11.113809585571287,0.0,3003.0,39.65361142158508,920.964860200882,39.65361142158508,881.3112032413483,0.0,0.0 -100,0.17965347,8.2758255,,,,,,,,,,,,,,,,, -200,0.3968871,7.5326786,,,,,,,,,,,,,,,,, -300,0.39763477,6.8916683,,,,,,,,,,,,,,,,, -400,0.70212215,6.3449626,,,,,,,,,,,,,,,,, -500,0.68329185,5.91069,,,,,,,,,,,,,,,,, -600,0.6546725,5.6009226,,,,,,,,,,,,,,,,, -700,0.423316,5.331258,,,,,,,,,,,,,,,,, -800,0.44307974,4.9274716,,,,,,,,,,,,,,,,, -900,0.40068874,4.7794657,,,,,,,,,,,,,,,,, -1000,0.46350193,4.618991,,,,,,,,,,,,,,,,, -1100,0.5980187,4.3162727,,,,,,,,,,,,,,,,, -1200,0.6633277,4.043761,,,,,,,,,,,,,,,,, -1300,0.528128,3.9795923,,,,,,,,,,,,,,,,, -1400,0.5050395,3.7127004,,,,,,,,,,,,,,,,, -1500,0.47411394,3.621483,,,,,,,,,,,,,,,,, -1600,0.48000166,3.4490328,,,,,,,,,,,,,,,,, -1700,0.4654016,3.4220035,,,,,,,,,,,,,,,,, -1800,0.7058538,3.3374972,,,,,,,,,,,,,,,,, -1900,0.42697826,3.293105,,,,,,,,,,,,,,,,, -2000,0.39234993,3.1055186,,,,,,,,,,,,,,,,, -2100,0.3953766,3.0911086,,,,,,,,,,,,,,,,, -2200,0.39854333,3.148742,,,,,,,,,,,,,,,,, -2300,0.49397326,3.0600045,,,,,,,,,,,,,,,,, -2377,,,0.5125483274459839,2.858908176422119,22.492414050302205,0.5142651796340942,2.836130142211914,18.30720532561121,3000.0,0.5126023888587952,2.8845648765563965,16.82979069214398,3003.0,879.6641981601715,2262.6394739151,879.6641981601715,1382.864819765091,0.0302114486694335,0.0 -2400,0.42336866,3.0287206,,,,,,,,,,,,,,,,, -2500,0.37632862,2.9406307,,,,,,,,,,,,,,,,, -2600,0.29973385,2.7745543,,,,,,,,,,,,,,,,, -2700,0.30261925,2.7724075,,,,,,,,,,,,,,,,, -2800,0.29333284,2.7254803,,,,,,,,,,,,,,,,, -2900,0.29776546,2.7437024,,,,,,,,,,,,,,,,, -3000,0.29285228,2.6848915,,,,,,,,,,,,,,,,, -3100,0.24795723,2.6954856,,,,,,,,,,,,,,,,, -3200,0.27840272,2.6904886,,,,,,,,,,,,,,,,, -3300,0.23832998,2.566884,,,,,,,,,,,,,,,,, -3400,0.21723805,2.552032,,,,,,,,,,,,,,,,, -3500,0.23595992,2.429157,,,,,,,,,,,,,,,,, -3600,0.213085,2.5208187,,,,,,,,,,,,,,,,, -3700,0.21269868,2.4989824,,,,,,,,,,,,,,,,, -3800,0.24176091,2.59935,,,,,,,,,,,,,,,,, -3900,0.19847573,2.4692476,,,,,,,,,,,,,,,,, -4000,0.18280788,2.4913032,,,,,,,,,,,,,,,,, -4100,0.22072965,2.3681633,,,,,,,,,,,,,,,,, -4200,0.17782502,2.4696434,,,,,,,,,,,,,,,,, -4300,0.19968648,2.441283,,,,,,,,,,,,,,,,, -4400,0.2556613,2.3621347,,,,,,,,,,,,,,,,, -4500,0.19508113,2.3701932,,,,,,,,,,,,,,,,, -4600,0.1951536,2.3953943,,,,,,,,,,,,,,,,, -4700,0.17360796,2.2715611,,,,,,,,,,,,,,,,, -4754,,,0.5762366056442261,2.2599639892578125,26.972132326701903,0.5911024212837219,2.15031099319458,23.643337175562927,3000.0,0.594375729560852,2.114021062850952,22.238124830388227,3003.0,1719.578450679779,3664.720867872238,1719.578450679779,1944.9201946258545,0.0610411167144775,0.0 -4800,0.20115435,2.2771244,,,,,,,,,,,,,,,,, -4900,0.17486696,2.3253152,,,,,,,,,,,,,,,,, -5000,0.19457558,2.2393975,,,,,,,,,,,,,,,,, -5100,0.17048535,2.3810856,,,,,,,,,,,,,,,,, -5200,0.1745368,2.2530048,,,,,,,,,,,,,,,,, -5300,0.15426219,2.2729816,,,,,,,,,,,,,,,,, -5400,0.15858607,2.2566438,,,,,,,,,,,,,,,,, -5500,0.16898586,2.283371,,,,,,,,,,,,,,,,, -5600,0.16452143,2.303827,,,,,,,,,,,,,,,,, -5700,0.17649893,2.2407362,,,,,,,,,,,,,,,,, -5800,0.15169694,2.19714,,,,,,,,,,,,,,,,, -5900,0.15706909,2.1737046,,,,,,,,,,,,,,,,, -6000,0.16281505,2.1814663,,,,,,,,,,,,,,,,, -6100,0.17640473,2.2131171,,,,,,,,,,,,,,,,, -6200,0.1540966,2.181255,,,,,,,,,,,,,,,,, -6300,0.14671159,2.127197,,,,,,,,,,,,,,,,, -6400,0.1641264,2.2104468,,,,,,,,,,,,,,,,, -6500,0.15076023,2.1623554,,,,,,,,,,,,,,,,, -6600,0.16796917,2.1500766,,,,,,,,,,,,,,,,, -6700,0.15809216,2.0603805,,,,,,,,,,,,,,,,, -6800,0.15588066,2.1756694,,,,,,,,,,,,,,,,, -6900,0.1995287,2.1742551,,,,,,,,,,,,,,,,, -7000,0.16173325,2.1843958,,,,,,,,,,,,,,,,, -7100,0.21095583,2.2107022,,,,,,,,,,,,,,,,, -7132,,,0.6054196953773499,2.010335922241211,28.578043861643884,0.6175868511199951,1.925435900688172,25.03961870233312,3000.0,0.6212770938873291,1.88267195224762,24.04397100312569,3003.0,2559.766467809677,4935.580909729004,2559.766467809677,2375.4861319065094,0.0864703655242919,0.0 -7200,0.2067134,2.227065,,,,,,,,,,,,,,,,, -7300,0.19234385,2.1133168,,,,,,,,,,,,,,,,, -7400,0.14878322,2.1462216,,,,,,,,,,,,,,,,, -7500,0.1527745,2.0589206,,,,,,,,,,,,,,,,, -7600,0.1521744,2.1606832,,,,,,,,,,,,,,,,, -7700,0.15311997,2.0191722,,,,,,,,,,,,,,,,, -7800,0.14266124,2.1431065,,,,,,,,,,,,,,,,, -7900,0.15911871,2.0830023,,,,,,,,,,,,,,,,, -8000,0.16161188,2.1455855,,,,,,,,,,,,,,,,, -8100,0.15954794,2.17461,,,,,,,,,,,,,,,,, -8200,0.16102117,2.036207,,,,,,,,,,,,,,,,, -8300,0.19273663,2.0860689,,,,,,,,,,,,,,,,, -8400,0.16497546,2.1546519,,,,,,,,,,,,,,,,, -8500,0.16011924,2.0817933,,,,,,,,,,,,,,,,, -8600,0.16682489,2.0298562,,,,,,,,,,,,,,,,, -8700,0.17936774,2.0869102,,,,,,,,,,,,,,,,, -8800,0.15702258,2.0502834,,,,,,,,,,,,,,,,, -8900,0.16874674,2.0365033,,,,,,,,,,,,,,,,, -9000,0.33511403,2.1007984,,,,,,,,,,,,,,,,, -9100,0.17496318,2.0875587,,,,,,,,,,,,,,,,, -9200,0.20750834,2.084167,,,,,,,,,,,,,,,,, -9300,0.36252704,2.0940008,,,,,,,,,,,,,,,,, -9400,0.26452708,2.0320187,,,,,,,,,,,,,,,,, -9500,0.15665586,2.006336,,,,,,,,,,,,,,,,, -9510,,,0.6144471764564514,1.9414385557174685,29.758036174075343,0.6304943561553955,1.8233507871627808,26.07328233983416,3000.0,0.6384173035621643,1.761852741241455,25.14738798082741,3003.0,3399.8166558742523,6233.60720205307,3399.8166558742523,2833.360686540604,0.1121339797973632,0.0 -9600,0.2363399,2.0790768,,,,,,,,,,,,,,,,, -9700,0.22187483,2.0257275,,,,,,,,,,,,,,,,, -9800,0.23038895,1.9982512,,,,,,,,,,,,,,,,, -9900,0.20193043,1.9891756,,,,,,,,,,,,,,,,, -10000,0.2140237,2.0450325,,,,,,,,,,,,,,,,, -10100,0.1688065,2.0209215,,,,,,,,,,,,,,,,, -10200,0.1561126,2.035604,,,,,,,,,,,,,,,,, -10300,0.15605742,1.9867054,,,,,,,,,,,,,,,,, -10400,0.17753597,1.9916692,,,,,,,,,,,,,,,,, -10500,0.16977447,2.0011644,,,,,,,,,,,,,,,,, -10600,0.15228257,1.9398332,,,,,,,,,,,,,,,,, -10700,0.16215609,2.099889,,,,,,,,,,,,,,,,, -10800,0.18867385,2.0108202,,,,,,,,,,,,,,,,, -10900,0.16044435,2.0139577,,,,,,,,,,,,,,,,, -11000,0.19339465,2.0103388,,,,,,,,,,,,,,,,, -11100,0.23805052,1.9684272,,,,,,,,,,,,,,,,, -11200,0.18757366,1.9904153,,,,,,,,,,,,,,,,, -11300,0.33712837,2.0587773,,,,,,,,,,,,,,,,, -11400,0.15465951,1.9317912,,,,,,,,,,,,,,,,, -11500,0.25116923,2.0202916,,,,,,,,,,,,,,,,, -11600,0.15916008,1.9969654,,,,,,,,,,,,,,,,, -11700,0.23007609,2.006921,,,,,,,,,,,,,,,,, -11800,0.19268842,2.0161035,,,,,,,,,,,,,,,,, -11889,,,0.6189337968826294,1.889572024345398,29.89805366521182,0.6390745043754578,1.7538105249404907,26.64441501614611,3000.0,0.6477136611938477,1.6896036863327026,25.62417713326412,3003.0,4240.059615850449,7499.519518136978,4240.059615850449,3258.925538778305,0.138277530670166,0.0 -11900,0.2376083,2.012659,,,,,,,,,,,,,,,,, -12000,0.21225908,2.0339313,,,,,,,,,,,,,,,,, -12100,0.15858205,1.9279615,,,,,,,,,,,,,,,,, -12200,0.14732271,1.93196,,,,,,,,,,,,,,,,, -12300,0.1688162,2.003209,,,,,,,,,,,,,,,,, -12400,0.17218868,1.877724,,,,,,,,,,,,,,,,, -12500,0.17133538,1.945033,,,,,,,,,,,,,,,,, -12600,0.29328367,2.0085368,,,,,,,,,,,,,,,,, -12700,0.18710527,1.8790289,,,,,,,,,,,,,,,,, -12800,0.16706932,1.9774511,,,,,,,,,,,,,,,,, -12900,0.19261357,1.8679131,,,,,,,,,,,,,,,,, -13000,0.18499182,1.9652019,,,,,,,,,,,,,,,,, -13100,0.3247792,2.0322547,,,,,,,,,,,,,,,,, -13200,0.18755071,1.9851247,,,,,,,,,,,,,,,,, -13300,0.1592786,1.8702744,,,,,,,,,,,,,,,,, -13400,0.19966063,1.9989203,,,,,,,,,,,,,,,,, -13500,0.2155304,1.9842159,,,,,,,,,,,,,,,,, -13600,0.24417491,2.0293593,,,,,,,,,,,,,,,,, -13700,0.2007257,1.9730418,,,,,,,,,,,,,,,,, -13800,0.23592658,2.0067284,,,,,,,,,,,,,,,,, -13900,0.319706,1.9662907,,,,,,,,,,,,,,,,, -14000,0.18278411,1.952438,,,,,,,,,,,,,,,,, -14100,0.16406001,1.9155121,,,,,,,,,,,,,,,,, -14200,0.19666573,1.890327,,,,,,,,,,,,,,,,, -14268,,,0.624485433101654,1.8488229513168333,30.31258298412598,0.6445549130439758,1.706523299217224,26.9599459838488,3000.0,0.6522108316421509,1.6444939374923706,25.7217432979878,3003.0,5080.140574455261,8801.274183273315,5080.140574455261,3720.4962162971497,0.165790319442749,0.0 -14300,0.2579097,1.9239281,,,,,,,,,,,,,,,,, -14400,0.17331359,1.9749,,,,,,,,,,,,,,,,, -14500,0.18893561,1.9974849,,,,,,,,,,,,,,,,, -14600,0.20140128,1.8956712,,,,,,,,,,,,,,,,, -14700,0.17851299,1.9714568,,,,,,,,,,,,,,,,, -14800,0.24337007,1.8489807,,,,,,,,,,,,,,,,, -14900,0.18539505,1.8809934,,,,,,,,,,,,,,,,, -15000,0.16916186,2.038751,,,,,,,,,,,,,,,,, -15100,0.19658202,1.8920383,,,,,,,,,,,,,,,,, -15200,0.18980956,1.9399006,,,,,,,,,,,,,,,,, -15300,0.18412337,2.0321467,,,,,,,,,,,,,,,,, -15400,0.1824323,1.9355384,,,,,,,,,,,,,,,,, -15500,0.16516733,1.859343,,,,,,,,,,,,,,,,, -15600,0.22059113,2.0187185,,,,,,,,,,,,,,,,, -15700,0.1931516,1.9370211,,,,,,,,,,,,,,,,, -15800,0.18309782,1.8205322,,,,,,,,,,,,,,,,, -15900,0.16522889,1.8159219,,,,,,,,,,,,,,,,, -16000,0.2075771,1.9425524,,,,,,,,,,,,,,,,, -16100,0.26733667,1.9295635,,,,,,,,,,,,,,,,, -16200,0.22975211,1.8354455,,,,,,,,,,,,,,,,, -16300,0.18326129,1.9010046,,,,,,,,,,,,,,,,, -16400,0.23220725,1.89339,,,,,,,,,,,,,,,,, -16500,0.16768941,1.9636612,,,,,,,,,,,,,,,,, -16600,0.17640138,1.9146967,,,,,,,,,,,,,,,,, -16646,,,0.6264616847038269,1.830745816230774,30.34013010361292,0.6479646563529968,1.6873875856399536,27.2750023954531,3000.0,0.6577886343002319,1.6086273193359375,26.537255448332026,3003.0,5920.160865306854,10158.843495845796,5920.160865306854,4237.940180301666,0.1926445960998535,0.0 -16700,0.33002266,1.950652,,,,,,,,,,,,,,,,, -16800,0.21831341,1.8724387,,,,,,,,,,,,,,,,, -16900,0.27721006,1.9143037,,,,,,,,,,,,,,,,, -17000,0.33254758,1.9783514,,,,,,,,,,,,,,,,, -17100,0.21740457,1.9303007,,,,,,,,,,,,,,,,, -17200,0.200323,1.876773,,,,,,,,,,,,,,,,, -17300,0.27440998,1.8549539,,,,,,,,,,,,,,,,, -17400,0.27923304,1.8382767,,,,,,,,,,,,,,,,, -17500,0.18662868,1.9363453,,,,,,,,,,,,,,,,, -17600,0.22074898,1.905525,,,,,,,,,,,,,,,,, -17700,0.18323538,1.8070657,,,,,,,,,,,,,,,,, -17800,0.19494401,1.8844746,,,,,,,,,,,,,,,,, -17900,0.19873309,1.8459791,,,,,,,,,,,,,,,,, -18000,0.25958025,1.823875,,,,,,,,,,,,,,,,, -18100,0.19451623,1.8741372,,,,,,,,,,,,,,,,, -18200,0.20713843,1.8811502,,,,,,,,,,,,,,,,, -18300,0.19030924,1.8828818,,,,,,,,,,,,,,,,, -18400,0.22881988,1.9182009,,,,,,,,,,,,,,,,, -18500,0.25395006,1.931886,,,,,,,,,,,,,,,,, -18600,0.19657122,1.8191267,,,,,,,,,,,,,,,,, -18700,0.17231326,1.8817759,,,,,,,,,,,,,,,,, -18800,0.18250452,1.8629664,,,,,,,,,,,,,,,,, -18900,0.21422902,1.8775232,,,,,,,,,,,,,,,,, -19000,0.2073189,1.8637036,,,,,,,,,,,,,,,,, -19024,,,0.6484805345535278,1.662880301475525,32.06108394190534,0.6506428718566895,1.659526228904724,27.19620966158643,3000.0,0.6595317125320435,1.5889374017715454,26.47710090550588,3003.0,6760.097064495087,11561.091912984848,6760.097064495087,4800.146046638489,0.2206084728240966,0.0 -19100,0.18897949,1.9531893,,,,,,,,,,,,,,,,, -19200,0.23752706,1.9143986,,,,,,,,,,,,,,,,, -19300,0.20556979,1.8589157,,,,,,,,,,,,,,,,, -19400,0.21939251,1.9472015,,,,,,,,,,,,,,,,, -19500,0.19428124,1.8682925,,,,,,,,,,,,,,,,, -19600,0.17415643,1.8765966,,,,,,,,,,,,,,,,, -19700,0.24722794,1.8144841,,,,,,,,,,,,,,,,, -19800,0.20589459,1.9766641,,,,,,,,,,,,,,,,, -19900,0.19285996,1.8328108,,,,,,,,,,,,,,,,, -20000,0.19214694,1.8896302,,,,,,,,,,,,,,,,, -20100,0.28099898,1.8662969,,,,,,,,,,,,,,,,, -20200,0.20630397,1.8383982,,,,,,,,,,,,,,,,, -20300,0.16706815,1.8854837,,,,,,,,,,,,,,,,, -20400,0.18297519,1.8706408,,,,,,,,,,,,,,,,, -20500,0.16736293,1.9210813,,,,,,,,,,,,,,,,, -20600,0.1916775,1.9090931,,,,,,,,,,,,,,,,, -20700,0.27540702,1.8390589,,,,,,,,,,,,,,,,, -20800,0.17422721,1.8966265,,,,,,,,,,,,,,,,, -20900,0.19789064,1.7811775,,,,,,,,,,,,,,,,, -21000,0.19171973,1.8655857,,,,,,,,,,,,,,,,, -21100,0.2040026,1.8170089,,,,,,,,,,,,,,,,, -21200,0.21920052,1.858481,,,,,,,,,,,,,,,,, -21300,0.19553417,1.8599972,,,,,,,,,,,,,,,,, -21400,0.3396004,1.9177654,,,,,,,,,,,,,,,,, -21401,,,0.6325982213020325,1.7833012342453003,30.65917046993292,0.652019202709198,1.656695008277893,27.546996197827287,3000.0,0.6625297665596008,1.5776896476745603,27.053601061051705,3003.0,7600.03609752655,12960.05086159706,7600.03609752655,5359.059168815613,0.2494087219238281,0.0 -21500,0.19829144,1.9691637,,,,,,,,,,,,,,,,, -21600,0.17483002,1.7600131,,,,,,,,,,,,,,,,, -21700,0.2210429,1.8010353,,,,,,,,,,,,,,,,, -21800,0.19706993,1.8372844,,,,,,,,,,,,,,,,, -21900,0.20813103,1.9432242,,,,,,,,,,,,,,,,, -22000,0.26575404,2.016353,,,,,,,,,,,,,,,,, -22100,0.2096887,1.9080323,,,,,,,,,,,,,,,,, -22200,0.18400307,1.872612,,,,,,,,,,,,,,,,, -22300,0.2512458,1.8227301,,,,,,,,,,,,,,,,, -22400,0.16261108,1.83669,,,,,,,,,,,,,,,,, -22500,0.1953192,1.8827512,,,,,,,,,,,,,,,,, -22600,0.1886411,1.7397211,,,,,,,,,,,,,,,,, -22700,0.1939791,1.905143,,,,,,,,,,,,,,,,, -22800,0.21507333,1.8472222,,,,,,,,,,,,,,,,, -22900,0.21687359,1.8436831,,,,,,,,,,,,,,,,, -23000,0.17181301,1.9192798,,,,,,,,,,,,,,,,, -23100,0.20598951,1.8718436,,,,,,,,,,,,,,,,, -23200,0.18688808,1.8507324,,,,,,,,,,,,,,,,, -23300,0.18687885,1.8205991,,,,,,,,,,,,,,,,, -23400,0.21424931,1.8583192,,,,,,,,,,,,,,,,, -23500,0.22859198,1.8469645,,,,,,,,,,,,,,,,, -23600,0.18225072,1.8554187,,,,,,,,,,,,,,,,, -23700,0.20536426,1.7984954,,,,,,,,,,,,,,,,, -23779,,,0.6351043581962585,1.777759313583374,30.41944310720691,0.6547470092773438,1.6266250610351562,27.47410007445849,3000.0,0.6655743718147278,1.5527108907699585,26.95498017227819,3003.0,8440.063966989517,14320.875417470932,8440.063966989517,5879.748903989792,0.2777047157287597,0.0 -23800,0.20501456,1.8423356,,,,,,,,,,,,,,,,, -23900,0.19857368,1.9090341,,,,,,,,,,,,,,,,, -24000,0.2192931,1.8539087,,,,,,,,,,,,,,,,, -24100,0.26413706,1.8666446,,,,,,,,,,,,,,,,, -24200,0.22835988,1.8090359,,,,,,,,,,,,,,,,, -24300,0.20896056,1.8457825,,,,,,,,,,,,,,,,, -24400,0.22116889,1.9016933,,,,,,,,,,,,,,,,, -24500,0.1930753,1.8774203,,,,,,,,,,,,,,,,, -24600,0.19400063,1.870364,,,,,,,,,,,,,,,,, -24700,0.2248287,1.819538,,,,,,,,,,,,,,,,, -24800,0.213997,1.8459158,,,,,,,,,,,,,,,,, -24900,0.24145839,1.8765267,,,,,,,,,,,,,,,,, -25000,0.18738948,1.7511712,,,,,,,,,,,,,,,,, -25100,0.18851398,1.7647623,,,,,,,,,,,,,,,,, -25200,0.19408767,1.8232973,,,,,,,,,,,,,,,,, -25300,0.1889055,1.8550423,,,,,,,,,,,,,,,,, -25400,0.19386378,1.7565597,,,,,,,,,,,,,,,,, -25500,0.18731314,1.7399473,,,,,,,,,,,,,,,,, -25600,0.203623,1.7964047,,,,,,,,,,,,,,,,, -25700,0.18430486,1.8221332,,,,,,,,,,,,,,,,, -25800,0.22570123,1.9144386,,,,,,,,,,,,,,,,, -25900,0.18966337,1.7784197,,,,,,,,,,,,,,,,, -26000,0.21156281,1.8513993,,,,,,,,,,,,,,,,, -26100,0.18113585,1.7997133,,,,,,,,,,,,,,,,, -26158,,,0.6418652534484863,1.7051342725753784,31.214415856423287,0.6578715443611145,1.613921284675598,28.10912777034857,3000.0,0.6680960059165955,1.5441603660583496,27.35011242889108,3003.0,9280.269245147703,15687.564682245256,9280.269245147703,6406.126123189926,0.3062803745269775,0.0 -26200,0.21060354,1.8856652,,,,,,,,,,,,,,,,, -26300,0.21717899,1.8474296,,,,,,,,,,,,,,,,, -26400,0.17452273,1.8099694,,,,,,,,,,,,,,,,, -26500,0.1961163,1.8538325,,,,,,,,,,,,,,,,, -26600,0.1900614,1.7116257,,,,,,,,,,,,,,,,, -26700,0.25657302,1.7592491,,,,,,,,,,,,,,,,, -26800,0.20096102,1.9128401,,,,,,,,,,,,,,,,, -26900,0.19510068,1.8226575,,,,,,,,,,,,,,,,, -27000,0.3921489,1.9052517,,,,,,,,,,,,,,,,, -27100,0.19999191,1.8199589,,,,,,,,,,,,,,,,, -27200,0.21685413,1.8347129,,,,,,,,,,,,,,,,, -27300,0.21612836,1.839121,,,,,,,,,,,,,,,,, -27400,0.1973543,1.8030173,,,,,,,,,,,,,,,,, -27500,0.23354994,1.918038,,,,,,,,,,,,,,,,, -27600,0.19096693,1.8127202,,,,,,,,,,,,,,,,, -27700,0.17759897,1.8092052,,,,,,,,,,,,,,,,, -27800,0.17434657,1.8128297,,,,,,,,,,,,,,,,, -27900,0.19726719,1.8025231,,,,,,,,,,,,,,,,, -28000,0.20593846,1.8859487,,,,,,,,,,,,,,,,, -28100,0.1816204,1.8192002,,,,,,,,,,,,,,,,, -28200,0.2123637,1.8266762,,,,,,,,,,,,,,,,, -28300,0.19719025,1.7990547,,,,,,,,,,,,,,,,, -28400,0.19991249,1.8564001,,,,,,,,,,,,,,,,, -28500,0.20866604,1.9065232,,,,,,,,,,,,,,,,, -28536,,,0.6390178799629211,1.7468416690826416,31.29680946133006,0.6594090461730957,1.6080009937286377,28.322223092791035,3000.0,0.6701993346214294,1.5299732685089111,27.660913346503644,3003.0,10120.285497665403,17130.83499264717,10120.285497665403,7009.274637699127,0.3344330787658691,0.0 -28600,0.23103422,1.8596785,,,,,,,,,,,,,,,,, -28700,0.16529727,1.8176689,,,,,,,,,,,,,,,,, -28800,0.28666937,1.8614384,,,,,,,,,,,,,,,,, -28900,0.27328062,1.7979336,,,,,,,,,,,,,,,,, -29000,0.24709308,1.7766209,,,,,,,,,,,,,,,,, -29100,0.19894691,1.8675776,,,,,,,,,,,,,,,,, -29200,0.22435163,1.7715595,,,,,,,,,,,,,,,,, -29300,0.21564505,1.8436042,,,,,,,,,,,,,,,,, -29400,0.21709378,1.8153942,,,,,,,,,,,,,,,,, -29500,0.19417925,1.8612757,,,,,,,,,,,,,,,,, -29600,0.20897314,1.7979466,,,,,,,,,,,,,,,,, -29700,0.21118312,1.8215654,,,,,,,,,,,,,,,,, -29800,0.21512508,1.8776476,,,,,,,,,,,,,,,,, -29900,0.19698642,1.8478373,,,,,,,,,,,,,,,,, -30000,0.21070191,1.819598,,,,,,,,,,,,,,,,, -30100,0.21374255,1.8346599,,,,,,,,,,,,,,,,, -30200,0.21711525,1.736206,,,,,,,,,,,,,,,,, -30300,0.2798209,1.803421,,,,,,,,,,,,,,,,, -30400,0.2639892,1.868388,,,,,,,,,,,,,,,,, -30500,0.24223348,1.8381336,,,,,,,,,,,,,,,,, -30600,0.20955351,1.7349447,,,,,,,,,,,,,,,,, -30700,0.22848283,1.746373,,,,,,,,,,,,,,,,, -30800,0.25246552,1.808304,,,,,,,,,,,,,,,,, -30900,0.34510648,1.7932562,,,,,,,,,,,,,,,,, -30915,,,0.6399180293083191,1.7383826971054075,31.006162959312263,0.659421443939209,1.5983082056045532,27.883103402956426,3000.0,0.669839084148407,1.522180676460266,27.341967471219807,3003.0,10960.433304071426,18515.20808935165,10960.433304071426,7553.392800569534,0.3641421794891357,0.0 -31000,0.22003889,1.7769346,,,,,,,,,,,,,,,,, -31100,0.2654254,1.8789912,,,,,,,,,,,,,,,,, -31200,0.20685688,1.7780747,,,,,,,,,,,,,,,,, -31300,0.21182278,1.8799176,,,,,,,,,,,,,,,,, -31400,0.24022032,1.7696518,,,,,,,,,,,,,,,,, -31500,0.24348207,1.7645226,,,,,,,,,,,,,,,,, -31600,0.20168428,1.8156338,,,,,,,,,,,,,,,,, -31700,0.19943844,1.8661803,,,,,,,,,,,,,,,,, -31800,0.20148635,1.7570131,,,,,,,,,,,,,,,,, -31900,0.19075504,1.866328,,,,,,,,,,,,,,,,, -32000,0.21175987,1.7724226,,,,,,,,,,,,,,,,, -32100,0.20236409,1.8044789,,,,,,,,,,,,,,,,, -32200,0.30768645,1.836429,,,,,,,,,,,,,,,,, -32300,0.22809426,1.7108355,,,,,,,,,,,,,,,,, -32400,0.18311176,1.758576,,,,,,,,,,,,,,,,, -32500,0.19197161,1.7342278,,,,,,,,,,,,,,,,, -32600,0.2017499,1.7667481,,,,,,,,,,,,,,,,, -32700,0.2261042,1.8520325,,,,,,,,,,,,,,,,, -32800,1.846073,1.8732679,,,,,,,,,,,,,,,,, -32900,0.19353946,1.8182184,,,,,,,,,,,,,,,,, -33000,0.27240488,1.8286738,,,,,,,,,,,,,,,,, -33100,0.23020712,1.8362348,,,,,,,,,,,,,,,,, -33200,0.21169224,1.8183953,,,,,,,,,,,,,,,,, -33293,,,0.6442225575447083,1.6994339227676392,31.78765129289712,0.6610333323478699,1.5850367546081543,28.3620764574474,3000.0,0.6725698709487915,1.5092467069625854,27.681689309528668,3003.0,11800.462206363678,19879.156966924667,11800.462206363678,8077.200543165207,0.3959531784057617,0.0 -33300,0.1794993,1.7902287,,,,,,,,,,,,,,,,, -33400,0.22221172,1.7507092,,,,,,,,,,,,,,,,, -33500,0.17492235,1.8223885,,,,,,,,,,,,,,,,, -33600,0.17973262,1.7534604,,,,,,,,,,,,,,,,, -33700,0.18118803,1.8385676,,,,,,,,,,,,,,,,, -33800,0.24531843,1.7913295,,,,,,,,,,,,,,,,, -33900,0.194095,1.8863282,,,,,,,,,,,,,,,,, -34000,0.21777004,1.8173254,,,,,,,,,,,,,,,,, -34100,0.18985827,1.7320999,,,,,,,,,,,,,,,,, -34200,0.20535815,1.8161235,,,,,,,,,,,,,,,,, -34300,0.23133728,1.8115035,,,,,,,,,,,,,,,,, -34400,0.20799494,1.7598393,,,,,,,,,,,,,,,,, -34500,0.19447057,1.7906319,,,,,,,,,,,,,,,,, -34600,0.2028362,1.7784295,,,,,,,,,,,,,,,,, -34700,0.29873613,1.794273,,,,,,,,,,,,,,,,, -34800,0.17637303,1.7427757,,,,,,,,,,,,,,,,, -34900,0.19533554,1.8418314,,,,,,,,,,,,,,,,, -35000,0.207437,1.7543538,,,,,,,,,,,,,,,,, -35100,0.22027537,1.816436,,,,,,,,,,,,,,,,, -35200,0.1838615,1.7886075,,,,,,,,,,,,,,,,, -35300,0.23913367,1.756546,,,,,,,,,,,,,,,,, -35400,0.19295122,1.799235,,,,,,,,,,,,,,,,, -35500,0.1980839,1.828852,,,,,,,,,,,,,,,,, -35600,0.22010545,1.809904,,,,,,,,,,,,,,,,, -35670,,,0.6402103304862976,1.7327693700790403,31.204288012321065,0.6622112393379211,1.5775601863861084,28.52652591972902,3000.0,0.6746150851249695,1.4966024160385132,27.680088726518168,3003.0,12640.359733104706,21255.39935207367,12640.359733104706,8613.434501886368,0.4250335693359375,0.0 -35700,0.19133215,1.7502322,,,,,,,,,,,,,,,,, -35800,0.20579787,1.7510798,,,,,,,,,,,,,,,,, -35900,0.20854077,1.8489459,,,,,,,,,,,,,,,,, -36000,0.20242886,1.7571722,,,,,,,,,,,,,,,,, -36100,0.20137656,1.7177448,,,,,,,,,,,,,,,,, -36200,0.20038968,1.8684243,,,,,,,,,,,,,,,,, -36300,0.20748468,1.7101079,,,,,,,,,,,,,,,,, -36400,0.18514217,1.8396772,,,,,,,,,,,,,,,,, -36500,0.21108505,1.7496078,,,,,,,,,,,,,,,,, -36600,0.26260114,1.7784072,,,,,,,,,,,,,,,,, -36700,0.18819009,1.7149395,,,,,,,,,,,,,,,,, -36800,0.21740025,1.7628171,,,,,,,,,,,,,,,,, -36900,0.20189495,1.7252325,,,,,,,,,,,,,,,,, -37000,0.20916499,1.7754121,,,,,,,,,,,,,,,,, -37100,0.20536163,1.6718073,,,,,,,,,,,,,,,,, -37200,0.19460884,1.7560402,,,,,,,,,,,,,,,,, -37300,0.28544927,1.8295633,,,,,,,,,,,,,,,,, -37400,0.19807407,1.8388308,,,,,,,,,,,,,,,,, -37500,0.18231995,1.8206587,,,,,,,,,,,,,,,,, -37600,0.22445668,1.7175281,,,,,,,,,,,,,,,,, -37700,0.23499024,1.832592,,,,,,,,,,,,,,,,, -37800,0.23592606,1.7291359,,,,,,,,,,,,,,,,, -37900,0.18724167,1.7278473,,,,,,,,,,,,,,,,, -38000,0.17833078,1.7697263,,,,,,,,,,,,,,,,, -38049,,,0.6570454835891724,1.6043413877487185,32.339918945597454,0.6630792021751404,1.5772230625152588,28.517361145543024,3000.0,0.6761025190353394,1.4913530349731443,28.02324678272048,3003.0,13480.456578493118,22592.117666721344,13480.456578493118,9109.951312541962,0.4539699554443359,0.0 -38100,0.22427161,1.7511213,,,,,,,,,,,,,,,,, -38200,0.1983948,1.8024254,,,,,,,,,,,,,,,,, -38300,0.20117107,1.7366321,,,,,,,,,,,,,,,,, -38400,0.22576554,1.7947848,,,,,,,,,,,,,,,,, -38500,0.20269972,1.7395226,,,,,,,,,,,,,,,,, -38600,0.21049982,1.8414543,,,,,,,,,,,,,,,,, -38700,0.21373655,1.804062,,,,,,,,,,,,,,,,, -38800,0.17228945,1.7362622,,,,,,,,,,,,,,,,, -38900,0.2688076,1.7914177,,,,,,,,,,,,,,,,, -39000,0.19788013,1.7478472,,,,,,,,,,,,,,,,, -39100,0.21300985,1.7845151,,,,,,,,,,,,,,,,, -39200,0.20035425,1.8559074,,,,,,,,,,,,,,,,, -39300,0.21587761,1.8408595,,,,,,,,,,,,,,,,, -39400,0.19085285,1.7539392,,,,,,,,,,,,,,,,, -39500,0.20047288,1.7464821,,,,,,,,,,,,,,,,, -39600,0.22277099,1.7424983,,,,,,,,,,,,,,,,, -39700,0.20288134,1.7464976,,,,,,,,,,,,,,,,, -39800,0.19324806,1.8044508,,,,,,,,,,,,,,,,, -39900,0.18353216,1.695781,,,,,,,,,,,,,,,,, -40000,0.22171256,1.7006409,,,,,,,,,,,,,,,,, -40100,0.21346596,1.7163284,,,,,,,,,,,,,,,,, -40200,0.20171851,1.7899565,,,,,,,,,,,,,,,,, -40300,0.20273553,1.8036108,,,,,,,,,,,,,,,,, -40400,0.20101601,1.831574,,,,,,,,,,,,,,,,, -40427,,,0.6448788642883301,1.695878028869629,31.99471937349964,0.6641083359718323,1.5695807933807373,28.41923535197316,3000.0,0.6763232946395874,1.482987880706787,27.91822635671377,3003.0,14320.392687559128,24075.46028089524,14320.392687559128,9753.249324083328,0.4837007522583008,0.0 -40500,0.20915411,1.7519052,,,,,,,,,,,,,,,,, -40600,0.20548303,1.7544125,,,,,,,,,,,,,,,,, -40700,0.24296889,1.7765719,,,,,,,,,,,,,,,,, -40800,0.1982251,1.8023607,,,,,,,,,,,,,,,,, -40900,0.2230413,1.8205545,,,,,,,,,,,,,,,,, -41000,0.21364436,1.715141,,,,,,,,,,,,,,,,, -41100,0.24923016,1.7907563,,,,,,,,,,,,,,,,, -41200,0.18433662,1.7346559,,,,,,,,,,,,,,,,, -41300,0.19136831,1.8284215,,,,,,,,,,,,,,,,, -41400,0.21934874,1.789472,,,,,,,,,,,,,,,,, -41500,0.19032669,1.689282,,,,,,,,,,,,,,,,, -41600,0.31149033,1.7993144,,,,,,,,,,,,,,,,, -41700,0.20225325,1.7809495,,,,,,,,,,,,,,,,, -41800,0.2098184,1.7407374,,,,,,,,,,,,,,,,, -41900,0.23737136,1.7330545,,,,,,,,,,,,,,,,, -42000,0.21568662,1.7374965,,,,,,,,,,,,,,,,, -42100,0.22727868,1.7780219,,,,,,,,,,,,,,,,, -42200,0.20071793,1.7717854,,,,,,,,,,,,,,,,, -42300,0.20094582,1.78029,,,,,,,,,,,,,,,,, -42400,0.20540616,1.8245533,,,,,,,,,,,,,,,,, -42500,0.23241998,1.7266663,,,,,,,,,,,,,,,,, -42600,0.2002531,1.7343229,,,,,,,,,,,,,,,,, -42700,0.19471359,1.7605302,,,,,,,,,,,,,,,,, -42800,0.20642228,1.6869746,,,,,,,,,,,,,,,,, -42806,,,0.6462165117263794,1.6971834897994995,31.77797547770024,0.6654722094535828,1.557735800743103,28.557748479244324,3000.0,0.6781709790229797,1.4798240661621094,27.79226369399139,3003.0,15160.55874156952,25451.696177721024,15160.55874156952,10289.211467266085,0.5132253170013428,0.0 -42900,0.19751567,1.8385409,,,,,,,,,,,,,,,,, -43000,0.19891605,1.7700598,,,,,,,,,,,,,,,,, -43100,0.18624958,1.7420257,,,,,,,,,,,,,,,,, -43200,0.17389935,1.7144897,,,,,,,,,,,,,,,,, -43300,0.18326281,1.6992246,,,,,,,,,,,,,,,,, -43400,0.19849445,1.7756189,,,,,,,,,,,,,,,,, -43500,0.17337687,1.7363831,,,,,,,,,,,,,,,,, -43600,0.17602293,1.7491376,,,,,,,,,,,,,,,,, -43700,0.22953919,1.7942643,,,,,,,,,,,,,,,,, -43800,0.2337961,1.784884,,,,,,,,,,,,,,,,, -43900,0.19582514,1.7791337,,,,,,,,,,,,,,,,, -44000,0.19965424,1.7110242,,,,,,,,,,,,,,,,, -44100,0.18746078,1.8007361,,,,,,,,,,,,,,,,, -44200,0.25514817,1.7365441,,,,,,,,,,,,,,,,, -44300,0.2143491,1.8096792,,,,,,,,,,,,,,,,, -44400,0.20788871,1.7186188,,,,,,,,,,,,,,,,, -44500,0.2713071,1.8252196,,,,,,,,,,,,,,,,, -44600,0.19185755,1.7398152,,,,,,,,,,,,,,,,, -44700,0.20629437,1.7618307,,,,,,,,,,,,,,,,, -44800,0.20155124,1.7384566,,,,,,,,,,,,,,,,, -44900,0.22236101,1.7321546,,,,,,,,,,,,,,,,, -45000,0.2536101,1.7990081,,,,,,,,,,,,,,,,, -45100,0.25356093,1.7982957,,,,,,,,,,,,,,,,, -45184,,,0.6523397564888,1.6415144205093384,32.163557741677224,0.6671088933944702,1.5524981021881104,29.014261662014693,3000.0,0.6798791885375977,1.4696717262268066,28.49927497697626,3003.0,16000.626110076904,26812.56391477585,16000.626110076904,10809.894648313522,0.5482079982757568,0.0 -45200,0.18993123,1.7270243,,,,,,,,,,,,,,,,, -45300,0.20959158,1.7249639,,,,,,,,,,,,,,,,, -45400,0.22212099,1.7720318,,,,,,,,,,,,,,,,, -45500,0.27382544,1.777569,,,,,,,,,,,,,,,,, -45600,0.21997185,1.7771538,,,,,,,,,,,,,,,,, -45700,0.18717867,1.7332708,,,,,,,,,,,,,,,,, -45800,0.21225667,1.640692,,,,,,,,,,,,,,,,, -45900,0.18198821,1.7706703,,,,,,,,,,,,,,,,, -46000,0.2096073,1.7261254,,,,,,,,,,,,,,,,, -46100,0.20897505,1.801741,,,,,,,,,,,,,,,,, -46200,0.22758631,1.7329999,,,,,,,,,,,,,,,,, -46300,0.20777842,1.7495244,,,,,,,,,,,,,,,,, -46400,0.20242293,1.7953335,,,,,,,,,,,,,,,,, -46500,0.1920346,1.7747185,,,,,,,,,,,,,,,,, -46600,0.19003753,1.6320513,,,,,,,,,,,,,,,,, -46700,0.1880183,1.8283615,,,,,,,,,,,,,,,,, -46800,0.1898243,1.794216,,,,,,,,,,,,,,,,, -46900,0.20157641,1.7791344,,,,,,,,,,,,,,,,, -47000,1.4255248,1.810279,,,,,,,,,,,,,,,,, -47100,0.38253808,1.8104833,,,,,,,,,,,,,,,,, -47200,0.17930774,1.7763197,,,,,,,,,,,,,,,,, -47300,0.18945663,1.7001605,,,,,,,,,,,,,,,,, -47400,0.20649244,1.763262,,,,,,,,,,,,,,,,, -47500,0.1902775,1.7112836,,,,,,,,,,,,,,,,, -47562,,,0.6457970142364502,1.698169231414795,31.71187306106621,0.6662285327911377,1.5489721298217771,28.7745931280484,3000.0,0.679332971572876,1.4617844820022583,28.274578857046443,3003.0,16840.60457134247,28381.03490614891,16840.60457134247,11538.277686834335,0.579641580581665,0.0 -47600,0.22240208,1.794682,,,,,,,,,,,,,,,,, -47700,0.20891285,1.6833589,,,,,,,,,,,,,,,,, -47800,0.186682,1.7158754,,,,,,,,,,,,,,,,, -47900,0.19680737,1.7989067,,,,,,,,,,,,,,,,, -48000,0.22508337,1.7863767,,,,,,,,,,,,,,,,, -48100,0.2088502,1.7769556,,,,,,,,,,,,,,,,, -48200,0.21740082,1.7077259,,,,,,,,,,,,,,,,, -48300,0.22481054,1.7547216,,,,,,,,,,,,,,,,, -48400,0.2103663,1.7479727,,,,,,,,,,,,,,,,, -48500,1.019814,1.7253852,,,,,,,,,,,,,,,,, -48600,0.22217312,1.8235914,,,,,,,,,,,,,,,,, -48700,0.19795306,1.7711885,,,,,,,,,,,,,,,,, -48800,0.23856813,1.7619293,,,,,,,,,,,,,,,,, -48900,0.20441796,1.8179344,,,,,,,,,,,,,,,,, -49000,0.19296135,1.7843289,,,,,,,,,,,,,,,,, -49100,0.19354893,1.7265327,,,,,,,,,,,,,,,,, -49200,0.20893252,1.7118989,,,,,,,,,,,,,,,,, -49300,0.22129971,1.7126496,,,,,,,,,,,,,,,,, -49400,0.22520937,1.7568743,,,,,,,,,,,,,,,,, -49500,0.19163743,1.8058555,,,,,,,,,,,,,,,,, -49600,0.23577967,1.8711598,,,,,,,,,,,,,,,,, -49700,0.21981305,1.757981,,,,,,,,,,,,,,,,, -49800,0.19192696,1.7195477,,,,,,,,,,,,,,,,, -49900,0.18804811,1.7119156,,,,,,,,,,,,,,,,, -49942,,,0.6485532522201538,1.6859617233276367,31.9474758616229,0.6677164435386658,1.5421111583709717,28.515111507698936,3000.0,0.6789495348930359,1.4625924825668335,28.027025139918145,3003.0,17680.672848939896,29781.36154270172,17680.672848939896,12098.422665834429,0.6169235706329346,0.0 -50000,0.19414341,1.6624271,,,,,,,,,,,,,,,,, -50100,0.21162459,1.7080654,,,,,,,,,,,,,,,,, -50200,0.19820365,1.6353569,,,,,,,,,,,,,,,,, -50300,0.18769097,1.7696016,,,,,,,,,,,,,,,,, -50400,0.20938341,1.753985,,,,,,,,,,,,,,,,, -50500,0.18485227,1.7139149,,,,,,,,,,,,,,,,, -50600,0.31457636,1.7534764,,,,,,,,,,,,,,,,, -50700,0.22037981,1.7912251,,,,,,,,,,,,,,,,, -50800,0.21091127,1.7266183,,,,,,,,,,,,,,,,, -50900,0.20735525,1.7608675,,,,,,,,,,,,,,,,, -51000,0.2201319,1.7451336,,,,,,,,,,,,,,,,, -51100,0.2102912,1.7191824,,,,,,,,,,,,,,,,, -51200,0.20549664,1.752667,,,,,,,,,,,,,,,,, -51300,0.22989051,1.7688469,,,,,,,,,,,,,,,,, -51400,0.2052081,1.7320698,,,,,,,,,,,,,,,,, -51500,0.20099954,1.6987643,,,,,,,,,,,,,,,,, -51600,0.20809074,1.7147359,,,,,,,,,,,,,,,,, -51700,0.18847471,1.7468014,,,,,,,,,,,,,,,,, -51800,0.20160668,1.8500415,,,,,,,,,,,,,,,,, -51900,0.19874735,1.7637308,,,,,,,,,,,,,,,,, -52000,0.20381838,1.7469021,,,,,,,,,,,,,,,,, -52100,0.19109027,1.8251745,,,,,,,,,,,,,,,,, -52200,0.20267142,1.6595176,,,,,,,,,,,,,,,,, -52300,0.20025815,1.725719,,,,,,,,,,,,,,,,, -52321,,,0.6484258770942688,1.6598361730575562,31.91824831416189,0.6704814434051514,1.5352940559387207,28.95387046457841,3000.0,0.6807622909545898,1.4502663612365725,28.34466917523456,3003.0,18520.70462822914,31196.87852740288,18520.70462822914,12673.795425891876,0.6527571678161621,0.0 -52400,0.18616953,1.7763172,,,,,,,,,,,,,,,,, -52500,0.20246616,1.7409912,,,,,,,,,,,,,,,,, -52600,0.20646888,1.7478182,,,,,,,,,,,,,,,,, -52700,0.19474551,1.6449916,,,,,,,,,,,,,,,,, -52800,0.18407787,1.7694093,,,,,,,,,,,,,,,,, -52900,0.1963129,1.7337615,,,,,,,,,,,,,,,,, -53000,0.22436064,1.6890199,,,,,,,,,,,,,,,,, -53100,0.18627584,1.7053643,,,,,,,,,,,,,,,,, -53200,0.22839084,1.8635772,,,,,,,,,,,,,,,,, -53300,0.21902649,1.6754315,,,,,,,,,,,,,,,,, -53400,0.20103635,1.8067132,,,,,,,,,,,,,,,,, -53500,0.20245622,1.7058922,,,,,,,,,,,,,,,,, -53600,0.19517107,1.7861552,,,,,,,,,,,,,,,,, -53700,0.23679593,1.6733891,,,,,,,,,,,,,,,,, -53800,0.19744109,1.7550492,,,,,,,,,,,,,,,,, -53900,0.23462586,1.6920327,,,,,,,,,,,,,,,,, -54000,0.18624943,1.7835435,,,,,,,,,,,,,,,,, -54100,0.20002678,1.7794605,,,,,,,,,,,,,,,,, -54200,0.185207,1.7157165,,,,,,,,,,,,,,,,, -54300,0.1916319,1.7761832,,,,,,,,,,,,,,,,, -54400,0.20079248,1.7782513,,,,,,,,,,,,,,,,, -54500,0.20846781,1.6790999,,,,,,,,,,,,,,,,, -54600,0.18422553,1.7532595,,,,,,,,,,,,,,,,, -54700,0.20532788,1.6775528,,,,,,,,,,,,,,,,, -54701,,,0.6495399475097656,1.6735833883285522,31.643514298402373,0.6713989973068237,1.5312122106552124,29.22028575150289,3000.0,0.6822962164878845,1.4418914318084717,28.36710670375829,3003.0,19361.222714662552,32570.93722343445,19361.222714662552,13207.22210764885,0.6896021366119385,0.0 -54800,0.23569736,1.7965143,,,,,,,,,,,,,,,,, -54900,0.18804397,1.637542,,,,,,,,,,,,,,,,, -55000,0.19537666,1.7251933,,,,,,,,,,,,,,,,, -55100,0.1852964,1.7356541,,,,,,,,,,,,,,,,, -55200,0.18237968,1.6935657,,,,,,,,,,,,,,,,, -55300,0.18781263,1.7290614,,,,,,,,,,,,,,,,, -55400,0.68159974,1.9216427,,,,,,,,,,,,,,,,, -55500,0.23150328,1.7741599,,,,,,,,,,,,,,,,, -55600,0.23775353,1.6784534,,,,,,,,,,,,,,,,, -55700,0.23650141,1.8278252,,,,,,,,,,,,,,,,, -55800,1.003006,1.72805,,,,,,,,,,,,,,,,, -55900,0.2091491,1.7958347,,,,,,,,,,,,,,,,, -56000,0.26478907,1.7725627,,,,,,,,,,,,,,,,, -56100,0.18956524,1.726686,,,,,,,,,,,,,,,,, -56200,0.22983252,1.8233333,,,,,,,,,,,,,,,,, -56300,0.20464998,1.7213221,,,,,,,,,,,,,,,,, -56400,0.21801353,1.685992,,,,,,,,,,,,,,,,, -56500,0.2366285,1.7447724,,,,,,,,,,,,,,,,, -56600,0.21988285,1.8280032,,,,,,,,,,,,,,,,, -56700,0.19045937,1.8212928,,,,,,,,,,,,,,,,, -56800,0.18216996,1.7921889,,,,,,,,,,,,,,,,, -56900,0.21221298,1.7205293,,,,,,,,,,,,,,,,, -57000,0.21219082,1.7265451,,,,,,,,,,,,,,,,, -57079,,,0.6603872179985046,1.5857213735580444,32.61815973523987,0.670741856098175,1.5168012380599976,29.064819889773123,3000.0,0.6844227910041809,1.4322528839111328,28.452016178586515,3003.0,20201.32160258293,33897.81595611572,20201.32160258293,13693.890861272812,0.7218029499053955,0.0 -57100,0.18912646,1.6368963,,,,,,,,,,,,,,,,, -57200,0.24181029,1.7942418,,,,,,,,,,,,,,,,, -57300,0.19941424,1.7296442,,,,,,,,,,,,,,,,, -57400,0.19968635,1.6966096,,,,,,,,,,,,,,,,, -57500,0.20595022,1.7527113,,,,,,,,,,,,,,,,, -57600,0.19476864,1.7097213,,,,,,,,,,,,,,,,, -57700,0.19789197,1.6657948,,,,,,,,,,,,,,,,, -57800,0.2030541,1.6591707,,,,,,,,,,,,,,,,, -57900,0.5978358,1.6930599,,,,,,,,,,,,,,,,, -58000,0.19206437,1.663063,,,,,,,,,,,,,,,,, -58100,0.17957349,1.6604745,,,,,,,,,,,,,,,,, -58200,0.19661987,1.7230599,,,,,,,,,,,,,,,,, -58300,0.19476265,1.6363468,,,,,,,,,,,,,,,,, -58400,0.19949096,1.6917048,,,,,,,,,,,,,,,,, -58500,0.19036235,1.6544399,,,,,,,,,,,,,,,,, -58600,0.21580884,1.7334211,,,,,,,,,,,,,,,,, -58700,0.23263814,1.6646061,,,,,,,,,,,,,,,,, -58800,0.19413465,1.6642785,,,,,,,,,,,,,,,,, -58900,0.22824089,1.781217,,,,,,,,,,,,,,,,, -59000,0.21263106,1.6718817,,,,,,,,,,,,,,,,, -59100,0.1974873,1.6956493,,,,,,,,,,,,,,,,, -59200,0.21812451,1.7311373,,,,,,,,,,,,,,,,, -59300,0.18174194,1.7119299,,,,,,,,,,,,,,,,, -59400,0.21282999,1.6428863,,,,,,,,,,,,,,,,, -59456,,,0.6552941799163818,1.624788522720337,32.59673470601075,0.6723288893699646,1.5095142126083374,29.290384592689776,3000.0,0.6850734949111938,1.4224475622177124,28.68317868851248,3003.0,21041.213926553726,35428.21695756912,21041.213926553726,14384.28313088417,0.7561678886413574,0.0 -59500,0.20650367,1.6761969,,,,,,,,,,,,,,,,, -59600,0.22693598,1.7564304,,,,,,,,,,,,,,,,, -59700,0.24719134,1.6992307,,,,,,,,,,,,,,,,, -59800,0.19024879,1.6892588,,,,,,,,,,,,,,,,, -59900,0.24611664,1.6582736,,,,,,,,,,,,,,,,, -60000,0.21066387,1.7346209,,,,,,,,,,,,,,,,, -60100,0.18997167,1.761331,,,,,,,,,,,,,,,,, -60200,0.21188278,1.7997323,,,,,,,,,,,,,,,,, -60300,0.21966824,1.6790578,,,,,,,,,,,,,,,,, -60400,0.20997158,1.7289025,,,,,,,,,,,,,,,,, -60500,0.21178825,1.683139,,,,,,,,,,,,,,,,, -60600,0.20468514,1.7123873,,,,,,,,,,,,,,,,, -60700,0.20682122,1.7418867,,,,,,,,,,,,,,,,, -60800,0.22096787,1.680395,,,,,,,,,,,,,,,,, -60900,0.192889,1.7670585,,,,,,,,,,,,,,,,, -61000,0.19079769,1.7457559,,,,,,,,,,,,,,,,, -61100,0.22585739,1.7552017,,,,,,,,,,,,,,,,, -61200,0.18135042,1.7074338,,,,,,,,,,,,,,,,, -61300,0.20965019,1.6930425,,,,,,,,,,,,,,,,, -61400,0.20418847,1.6801592,,,,,,,,,,,,,,,,, -61500,0.20694925,1.7465978,,,,,,,,,,,,,,,,, -61600,0.18966375,1.7721562,,,,,,,,,,,,,,,,, -61700,0.20831351,1.7421203,,,,,,,,,,,,,,,,, -61800,0.19704518,1.6731172,,,,,,,,,,,,,,,,, -61835,,,0.6523348689079285,1.646162033081055,32.40031350898231,0.6736432313919067,1.504443645477295,29.26776115359916,3000.0,0.6883040070533752,1.4157791137695312,28.97549138929193,3003.0,21881.345286130905,36811.70310878754,21881.345286130905,14927.5224506855,0.7940044403076172,0.0 -61900,0.19340232,1.6823426,,,,,,,,,,,,,,,,, -62000,0.18395634,1.762119,,,,,,,,,,,,,,,,, -62100,0.18801065,1.6696563,,,,,,,,,,,,,,,,, -62200,0.18262514,1.7395428,,,,,,,,,,,,,,,,, -62300,0.18540817,1.6239135,,,,,,,,,,,,,,,,, -62400,0.210851,1.658964,,,,,,,,,,,,,,,,, -62500,0.19961397,1.7391695,,,,,,,,,,,,,,,,, -62600,0.21082662,1.6926385,,,,,,,,,,,,,,,,, -62700,0.19178566,1.7657485,,,,,,,,,,,,,,,,, -62800,0.20428932,1.7602105,,,,,,,,,,,,,,,,, -62900,0.19589238,1.7878064,,,,,,,,,,,,,,,,, -63000,0.19196948,1.7164243,,,,,,,,,,,,,,,,, -63100,0.21859817,1.7130218,,,,,,,,,,,,,,,,, -63200,0.19033043,1.7049421,,,,,,,,,,,,,,,,, -63300,0.19769305,1.676414,,,,,,,,,,,,,,,,, -63400,0.18305838,1.7007178,,,,,,,,,,,,,,,,, -63500,0.19790217,1.72536,,,,,,,,,,,,,,,,, -63600,0.21959601,1.7503723,,,,,,,,,,,,,,,,, -63700,0.1917392,1.6828601,,,,,,,,,,,,,,,,, -63800,0.2029693,1.7415036,,,,,,,,,,,,,,,,, -63900,0.20060176,1.6241789,,,,,,,,,,,,,,,,, -64000,0.23852666,1.6198902,,,,,,,,,,,,,,,,, -64100,0.2084973,1.6883657,,,,,,,,,,,,,,,,, -64200,0.22815737,1.7537786,,,,,,,,,,,,,,,,, -64214,,,0.6573935151100159,1.6122422218322754,32.14257821076103,0.6687331795692444,1.5286437273025513,28.58250826805473,3000.0,0.6827958822250366,1.4405723810195925,28.18834542843649,3003.0,22721.47764348984,38188.96942901611,22721.47764348984,15464.544717550278,0.8278894424438477,0.0 -64300,0.19503838,1.7383652,,,,,,,,,,,,,,,,, -64400,0.19225168,1.7739933,,,,,,,,,,,,,,,,, -64500,0.19753002,1.7416704,,,,,,,,,,,,,,,,, -64600,0.19920002,1.7633548,,,,,,,,,,,,,,,,, -64700,0.19379424,1.7118248,,,,,,,,,,,,,,,,, -64800,0.1938813,1.7621608,,,,,,,,,,,,,,,,, -64900,0.19218032,1.6363224,,,,,,,,,,,,,,,,, -65000,0.1935789,1.6656636,,,,,,,,,,,,,,,,, -65100,0.22116065,1.6565564,,,,,,,,,,,,,,,,, -65200,0.19459261,1.7191402,,,,,,,,,,,,,,,,, -65300,0.19403662,1.6567026,,,,,,,,,,,,,,,,, -65400,0.18893814,1.6886462,,,,,,,,,,,,,,,,, -65500,0.2116206,1.6275233,,,,,,,,,,,,,,,,, -65600,0.26599446,1.6740665,,,,,,,,,,,,,,,,, -65700,0.18422227,1.6004472,,,,,,,,,,,,,,,,, -65800,0.20086835,1.7187334,,,,,,,,,,,,,,,,, -65900,0.18229015,1.6563892,,,,,,,,,,,,,,,,, -66000,0.18433881,1.7215881,,,,,,,,,,,,,,,,, -66100,0.19410846,1.7033243,,,,,,,,,,,,,,,,, -66200,0.20380425,1.6337118,,,,,,,,,,,,,,,,, -66300,0.20914422,1.6744133,,,,,,,,,,,,,,,,, -66400,0.19504252,1.6497561,,,,,,,,,,,,,,,,, -66500,0.20709656,1.6531987,,,,,,,,,,,,,,,,, -66592,,,0.6559733152389526,1.6213107109069824,32.54262723353351,0.6742507815361023,1.4933558702468872,29.09013429115508,3000.0,0.6875951886177063,1.4001637697219849,28.568736458659583,3003.0,23561.66490626335,39671.88665890694,23561.66490626335,16107.16114974022,0.8622546195983887,0.0 -66600,0.18802607,1.6693339,,,,,,,,,,,,,,,,, -66700,0.19820842,1.7539004,,,,,,,,,,,,,,,,, -66800,0.2148903,1.6763362,,,,,,,,,,,,,,,,, -66900,0.21005745,1.6316547,,,,,,,,,,,,,,,,, -67000,0.19732338,1.7296203,,,,,,,,,,,,,,,,, -67100,0.19501728,1.6691362,,,,,,,,,,,,,,,,, -67200,0.2199148,1.669977,,,,,,,,,,,,,,,,, -67300,0.19323309,1.692778,,,,,,,,,,,,,,,,, -67400,0.21224578,1.6381879,,,,,,,,,,,,,,,,, -67500,0.20779482,1.6150078,,,,,,,,,,,,,,,,, -67600,0.20997709,1.6238488,,,,,,,,,,,,,,,,, -67700,0.19803922,1.6966188,,,,,,,,,,,,,,,,, -67800,0.1946392,1.7046307,,,,,,,,,,,,,,,,, -67900,0.19431202,1.6756064,,,,,,,,,,,,,,,,, -68000,0.18728805,1.618154,,,,,,,,,,,,,,,,, -68100,0.20050968,1.6344095,,,,,,,,,,,,,,,,, -68200,0.20174755,1.6262001,,,,,,,,,,,,,,,,, -68300,0.24752264,1.7284396,,,,,,,,,,,,,,,,, -68400,0.2275558,1.6437846,,,,,,,,,,,,,,,,, -68500,0.20700283,1.6906236,,,,,,,,,,,,,,,,, -68600,0.20346452,1.6627998,,,,,,,,,,,,,,,,, -68700,0.21645989,1.6762936,,,,,,,,,,,,,,,,, -68800,0.20660335,1.7024512,,,,,,,,,,,,,,,,, -68900,0.18940274,1.7436559,,,,,,,,,,,,,,,,, -68970,,,0.6814596652984619,1.4550007581710815,34.3580861992117,0.6773753762245178,1.4872820377349854,29.381252753878385,3000.0,0.6907442808151245,1.3908847570419312,29.05332517746374,3003.0,24401.740617513657,41021.95736527443,24401.740617513657,16617.040287017822,0.8978571891784668,0.0 -69000,0.21179006,1.7045044,,,,,,,,,,,,,,,,, -69100,0.20647095,1.7303703,,,,,,,,,,,,,,,,, -69200,0.2091419,1.6678361,,,,,,,,,,,,,,,,, -69300,0.21000727,1.5948225,,,,,,,,,,,,,,,,, -69400,0.19769523,1.6419556,,,,,,,,,,,,,,,,, -69500,0.19674985,1.70129,,,,,,,,,,,,,,,,, -69600,0.20811804,1.6532255,,,,,,,,,,,,,,,,, -69700,0.19732848,1.6289921,,,,,,,,,,,,,,,,, -69800,0.18801631,1.6749214,,,,,,,,,,,,,,,,, -69900,0.19289355,1.665151,,,,,,,,,,,,,,,,, -70000,0.20476195,1.6423291,,,,,,,,,,,,,,,,, -70100,0.20251831,1.6989583,,,,,,,,,,,,,,,,, -70200,0.23271789,1.6425632,,,,,,,,,,,,,,,,, -70300,0.2097583,1.6384885,,,,,,,,,,,,,,,,, -70400,0.20739773,1.7051715,,,,,,,,,,,,,,,,, -70500,0.1844471,1.5779055,,,,,,,,,,,,,,,,, -70600,0.19178896,1.7060273,,,,,,,,,,,,,,,,, -70700,0.21192563,1.6770645,,,,,,,,,,,,,,,,, -70800,0.18986773,1.6563803,,,,,,,,,,,,,,,,, -70900,0.19448945,1.6036363,,,,,,,,,,,,,,,,, -71000,0.19447796,1.7170552,,,,,,,,,,,,,,,,, -71100,0.20667812,1.6719058,,,,,,,,,,,,,,,,, -71200,0.2244544,1.6705756,,,,,,,,,,,,,,,,, -71300,0.19739614,1.7425177,,,,,,,,,,,,,,,,, -71348,,,0.663874626159668,1.576033115386963,32.56455290267204,0.6783424615859985,1.4791110754013062,29.55113498636966,3000.0,0.6910116076469421,1.389503002166748,29.21504352470452,3003.0,25241.72496700287,42709.18280673027,25241.72496700287,17464.169107675552,0.9319391250610352,0.0 -71400,0.2040732,1.6649387,,,,,,,,,,,,,,,,, -71500,0.19596206,1.6238594,,,,,,,,,,,,,,,,, -71600,0.19215916,1.6471368,,,,,,,,,,,,,,,,, -71700,0.20846659,1.6536474,,,,,,,,,,,,,,,,, -71800,0.19848037,1.6572036,,,,,,,,,,,,,,,,, -71900,0.21131249,1.6851518,,,,,,,,,,,,,,,,, -72000,0.20471646,1.6490309,,,,,,,,,,,,,,,,, -72100,0.2011104,1.6790931,,,,,,,,,,,,,,,,, -72200,0.21534556,1.7474484,,,,,,,,,,,,,,,,, -72300,0.20216629,1.6062281,,,,,,,,,,,,,,,,, -72400,0.1940414,1.6609159,,,,,,,,,,,,,,,,, -72500,0.2118035,1.7096103,,,,,,,,,,,,,,,,, -72600,0.19331107,1.6208171,,,,,,,,,,,,,,,,, -72700,0.2329418,1.6428102,,,,,,,,,,,,,,,,, -72800,0.19723037,1.7081715,,,,,,,,,,,,,,,,, -72900,0.19948283,1.6381538,,,,,,,,,,,,,,,,, -73000,0.2059128,1.7178324,,,,,,,,,,,,,,,,, -73100,0.21086466,1.6210933,,,,,,,,,,,,,,,,, -73200,0.20636514,1.6059108,,,,,,,,,,,,,,,,, -73300,0.22112434,1.6071452,,,,,,,,,,,,,,,,, -73400,0.22264606,1.6324862,,,,,,,,,,,,,,,,, -73500,0.18681207,1.6235605,,,,,,,,,,,,,,,,, -73600,0.21228196,1.5928185,,,,,,,,,,,,,,,,, -73700,0.20719124,1.625299,,,,,,,,,,,,,,,,, -73727,,,0.6613771915435791,1.5929934978485107,32.78422420721934,0.6800411343574524,1.4665690660476685,29.65939306819033,3000.0,0.6914764046669006,1.378501296043396,29.265794145234867,3003.0,26081.76236653328,44122.5883204937,26081.76236653328,18037.4237639904,0.9692091941833496,0.0 -73800,0.19440015,1.631274,,,,,,,,,,,,,,,,, -73900,0.21544833,1.6291205,,,,,,,,,,,,,,,,, -74000,0.20187391,1.5969563,,,,,,,,,,,,,,,,, -74100,0.20689213,1.6109267,,,,,,,,,,,,,,,,, -74200,0.19316559,1.635652,,,,,,,,,,,,,,,,, -74300,0.2030869,1.6730922,,,,,,,,,,,,,,,,, -74400,0.202417,1.6351758,,,,,,,,,,,,,,,,, -74500,0.20678581,1.7956907,,,,,,,,,,,,,,,,, -74600,0.21778369,1.6162196,,,,,,,,,,,,,,,,, -74700,0.22538726,1.6649536,,,,,,,,,,,,,,,,, -74800,0.20463821,1.6582791,,,,,,,,,,,,,,,,, -74900,0.18875583,1.6042039,,,,,,,,,,,,,,,,, -75000,0.19103757,1.6514965,,,,,,,,,,,,,,,,, -75100,0.22238882,1.7350008,,,,,,,,,,,,,,,,, -75200,0.20774126,1.6923877,,,,,,,,,,,,,,,,, -75300,0.21126847,1.6725031,,,,,,,,,,,,,,,,, -75400,0.20536998,1.5821503,,,,,,,,,,,,,,,,, -75500,0.185423,1.6354039,,,,,,,,,,,,,,,,, -75600,0.20203066,1.6328312,,,,,,,,,,,,,,,,, -75700,0.19815186,1.7391053,,,,,,,,,,,,,,,,, -75800,0.21256702,1.7089351,,,,,,,,,,,,,,,,, -75900,0.21686944,1.605509,,,,,,,,,,,,,,,,, -76000,0.20643507,1.7763306,,,,,,,,,,,,,,,,, -76100,0.20497553,1.6236967,,,,,,,,,,,,,,,,, -76106,,,0.6700928211212158,1.5198198556900024,33.644985715093746,0.6783796548843384,1.4667606353759766,29.57411672308368,3000.0,0.6944977045059204,1.370276689529419,29.374789377540463,3003.0,26921.769094228745,45556.92244243622,26921.769094228745,18631.639093637463,1.004053831100464,0.0 -76200,0.19491535,1.652811,,,,,,,,,,,,,,,,, -76300,0.19345991,1.5479397,,,,,,,,,,,,,,,,, -76400,0.20703201,1.6598362,,,,,,,,,,,,,,,,, -76500,0.19215602,1.6313545,,,,,,,,,,,,,,,,, -76600,0.2107672,1.6151389,,,,,,,,,,,,,,,,, -76700,0.20867673,1.6281085,,,,,,,,,,,,,,,,, -76800,0.21145241,1.7168432,,,,,,,,,,,,,,,,, -76900,0.18618323,1.5890229,,,,,,,,,,,,,,,,, -77000,0.19560236,1.6428741,,,,,,,,,,,,,,,,, -77100,0.20667246,1.6501199,,,,,,,,,,,,,,,,, -77200,0.21010539,1.5829979,,,,,,,,,,,,,,,,, -77300,0.21574856,1.6790315,,,,,,,,,,,,,,,,, -77400,0.21133803,1.6509682,,,,,,,,,,,,,,,,, -77500,0.20580372,1.6265535,,,,,,,,,,,,,,,,, -77600,0.20643723,1.6105549,,,,,,,,,,,,,,,,, -77700,0.2898624,1.6089288,,,,,,,,,,,,,,,,, -77800,0.20200737,1.6516594,,,,,,,,,,,,,,,,, -77900,0.19737457,1.6351179,,,,,,,,,,,,,,,,, -78000,0.21472925,1.7009848,,,,,,,,,,,,,,,,, -78100,0.19190481,1.5669909,,,,,,,,,,,,,,,,, -78200,0.19485971,1.6404027,,,,,,,,,,,,,,,,, -78300,0.19850469,1.681598,,,,,,,,,,,,,,,,, -78400,0.20000462,1.5928705,,,,,,,,,,,,,,,,, -78485,,,0.6622284650802612,1.5754069089889526,32.57791200673433,0.6812066435813904,1.4563653469085691,29.76072163868272,3000.0,0.6956830024719238,1.360418438911438,29.425967056293665,3003.0,27761.8942797184,46897.715759277344,27761.8942797184,19132.193336486816,1.0392353534698486,0.0 -78500,0.19338587,1.5796726,,,,,,,,,,,,,,,,, -78600,0.20122331,1.6386948,,,,,,,,,,,,,,,,, -78700,0.20354766,1.6066741,,,,,,,,,,,,,,,,, -78800,0.21289332,1.6149745,,,,,,,,,,,,,,,,, -78900,0.19219175,1.5786241,,,,,,,,,,,,,,,,, -79000,0.21639802,1.6822996,,,,,,,,,,,,,,,,, -79100,0.19939756,1.5907125,,,,,,,,,,,,,,,,, -79200,0.19628699,1.571407,,,,,,,,,,,,,,,,, -79300,0.20895353,1.6133244,,,,,,,,,,,,,,,,, -79400,0.2021777,1.6692698,,,,,,,,,,,,,,,,, -79500,0.2541336,1.619727,,,,,,,,,,,,,,,,, -79600,0.2064674,1.6633563,,,,,,,,,,,,,,,,, -79700,0.20666456,1.683439,,,,,,,,,,,,,,,,, -79800,0.1998993,1.6489662,,,,,,,,,,,,,,,,, -79900,0.20761761,1.5961057,,,,,,,,,,,,,,,,, -80000,0.20855644,1.6070714,,,,,,,,,,,,,,,,, -80100,0.20589651,1.6739948,,,,,,,,,,,,,,,,, -80200,0.19369741,1.539677,,,,,,,,,,,,,,,,, -80300,0.20511018,1.6128961,,,,,,,,,,,,,,,,, -80400,0.21463658,1.6970458,,,,,,,,,,,,,,,,, -80500,0.20281836,1.6085676,,,,,,,,,,,,,,,,, -80600,0.20386417,1.6008167,,,,,,,,,,,,,,,,, -80700,0.22495298,1.6517189,,,,,,,,,,,,,,,,, -80800,0.23133054,1.6763083,,,,,,,,,,,,,,,,, -80863,,,0.6636092662811279,1.5703322887420654,32.70319836498964,0.6823350191116333,1.4525598287582395,30.035218304926424,3000.0,0.6951949596405029,1.3603163957595823,29.63272634036333,3003.0,28601.873636484142,48277.20259356499,28601.873636484142,19671.58778452873,1.0748109817504885,0.0 -80900,0.20308436,1.7389112,,,,,,,,,,,,,,,,, -81000,0.18703113,1.5653238,,,,,,,,,,,,,,,,, -81100,0.19868422,1.6722859,,,,,,,,,,,,,,,,, -81200,0.19128293,1.5673263,,,,,,,,,,,,,,,,, -81300,0.20981666,1.6311876,,,,,,,,,,,,,,,,, -81400,0.21951576,1.6252397,,,,,,,,,,,,,,,,, -81500,0.56501496,1.5782286,,,,,,,,,,,,,,,,, -81600,0.39405975,1.6486542,,,,,,,,,,,,,,,,, -81700,0.21869953,1.5891781,,,,,,,,,,,,,,,,, -81800,0.21678936,1.6558254,,,,,,,,,,,,,,,,, -81900,0.21836838,1.6426867,,,,,,,,,,,,,,,,, -82000,0.22430058,1.6200378,,,,,,,,,,,,,,,,, -82100,0.21370226,1.6587704,,,,,,,,,,,,,,,,, -82200,0.23667283,1.6930683,,,,,,,,,,,,,,,,, -82300,0.21717572,1.7015318,,,,,,,,,,,,,,,,, -82400,0.20801625,1.6009988,,,,,,,,,,,,,,,,, -82500,0.21397178,1.6123453,,,,,,,,,,,,,,,,, -82600,0.20393245,1.6571274,,,,,,,,,,,,,,,,, -82700,0.23355825,1.6997524,,,,,,,,,,,,,,,,, -82800,0.19959888,1.5788431,,,,,,,,,,,,,,,,, -82900,0.2789683,1.5408807,,,,,,,,,,,,,,,,, -83000,0.216636,1.5967681,,,,,,,,,,,,,,,,, -83100,0.19838697,1.6200795,,,,,,,,,,,,,,,,, -83200,0.20169032,1.6214023,,,,,,,,,,,,,,,,, -83242,,,0.6721797585487366,1.513619065284729,33.680118352812386,0.6828929781913757,1.4463530778884888,29.927711880365525,3000.0,0.6981233358383179,1.3538438081741333,29.69593409300139,3003.0,29442.05039691925,49746.21423172951,29442.05039691925,20300.30703139305,1.112574815750122,0.0 -83300,0.2112563,1.6415521,,,,,,,,,,,,,,,,, -83400,0.21090423,1.5847181,,,,,,,,,,,,,,,,, -83500,0.19393119,1.5501962,,,,,,,,,,,,,,,,, -83600,0.21804687,1.5990822,,,,,,,,,,,,,,,,, -83700,0.21000893,1.6291387,,,,,,,,,,,,,,,,, -83800,0.21364272,1.7005413,,,,,,,,,,,,,,,,, -83900,0.20047492,1.567841,,,,,,,,,,,,,,,,, -84000,0.19116025,1.6872373,,,,,,,,,,,,,,,,, -84100,0.21498749,1.6815017,,,,,,,,,,,,,,,,, -84200,0.23589702,1.6309565,,,,,,,,,,,,,,,,, -84300,0.20030504,1.6139624,,,,,,,,,,,,,,,,, -84400,0.201851,1.6644367,,,,,,,,,,,,,,,,, -84500,0.20800154,1.7005674,,,,,,,,,,,,,,,,, -84600,0.18430765,1.606001,,,,,,,,,,,,,,,,, -84700,0.20735964,1.561338,,,,,,,,,,,,,,,,, -84800,0.1987242,1.6540126,,,,,,,,,,,,,,,,, -84900,0.2013325,1.6877009,,,,,,,,,,,,,,,,, -85000,0.20736283,1.5535817,,,,,,,,,,,,,,,,, -85100,0.19082196,1.6266533,,,,,,,,,,,,,,,,, -85200,0.22511096,1.6360319,,,,,,,,,,,,,,,,, -85300,0.23325154,1.6503046,,,,,,,,,,,,,,,,, -85400,0.18949422,1.5521588,,,,,,,,,,,,,,,,, -85500,0.19958454,1.6724234,,,,,,,,,,,,,,,,, -85600,0.2011733,1.6048415,,,,,,,,,,,,,,,,, -85621,,,0.6682446002960205,1.540069341659546,33.75045946079044,0.6835377216339111,1.439979076385498,29.889512828415363,3000.0,0.6989948749542236,1.3409332036972046,30.01270004080984,3003.0,30282.133563756943,51124.29514193535,30282.133563756943,20838.189848661423,1.1500585079193115,0.0 -85700,0.20252942,1.5803807,,,,,,,,,,,,,,,,, -85800,0.21891515,1.716595,,,,,,,,,,,,,,,,, -85900,0.19495809,1.5900549,,,,,,,,,,,,,,,,, -86000,0.20389515,1.6798764,,,,,,,,,,,,,,,,, -86100,0.21768177,1.6059569,,,,,,,,,,,,,,,,, -86200,0.19984458,1.6135075,,,,,,,,,,,,,,,,, -86300,0.21819623,1.6058961,,,,,,,,,,,,,,,,, -86400,0.20163006,1.602627,,,,,,,,,,,,,,,,, -86500,0.21251184,1.6325437,,,,,,,,,,,,,,,,, -86600,0.32647938,1.6256931,,,,,,,,,,,,,,,,, -86700,0.21905783,1.5697589,,,,,,,,,,,,,,,,, -86800,0.2136845,1.6514401,,,,,,,,,,,,,,,,, -86900,0.20472652,1.5790435,,,,,,,,,,,,,,,,, -87000,0.20654406,1.6153972,,,,,,,,,,,,,,,,, -87100,0.21399279,1.6087205,,,,,,,,,,,,,,,,, -87200,0.21650359,1.7321064,,,,,,,,,,,,,,,,, -87300,0.20429061,1.5638769,,,,,,,,,,,,,,,,, -87400,0.20838365,1.5979494,,,,,,,,,,,,,,,,, -87500,0.21734162,1.5665276,,,,,,,,,,,,,,,,, -87600,0.20415004,1.6292919,,,,,,,,,,,,,,,,, -87700,0.19874632,1.5736419,,,,,,,,,,,,,,,,, -87800,0.21204369,1.5665135,,,,,,,,,,,,,,,,, -87900,0.20844243,1.5225734,,,,,,,,,,,,,,,,, -87999,,,0.6825253367424011,1.4487786293029783,34.338008319013625,0.6857943534851074,1.4298676252365112,30.0903180949495,3000.0,0.7000523209571838,1.3362531661987305,29.86954274429864,3003.0,31122.28181028366,52671.70739459992,31122.28181028366,21545.33805847168,1.1864256858825684,0.0 -88000,0.21747264,1.6079212,,,,,,,,,,,,,,,,, -88100,0.22490926,1.5387962,,,,,,,,,,,,,,,,, -88200,0.20667413,1.5657336,,,,,,,,,,,,,,,,, -88300,0.22856125,1.6452575,,,,,,,,,,,,,,,,, -88400,0.20354421,1.649086,,,,,,,,,,,,,,,,, -88500,0.21593863,1.6255535,,,,,,,,,,,,,,,,, -88600,0.19829878,1.5857894,,,,,,,,,,,,,,,,, -88700,0.21698536,1.615603,,,,,,,,,,,,,,,,, -88800,0.22003087,1.7160779,,,,,,,,,,,,,,,,, -88900,0.20473804,1.5729276,,,,,,,,,,,,,,,,, -89000,0.21500656,1.5544796,,,,,,,,,,,,,,,,, -89100,0.21075332,1.5525798,,,,,,,,,,,,,,,,, -89200,0.21102743,1.6213892,,,,,,,,,,,,,,,,, -89300,0.19465284,1.6036831,,,,,,,,,,,,,,,,, -89400,0.37223387,1.598893,,,,,,,,,,,,,,,,, -89500,0.20077877,1.6167597,,,,,,,,,,,,,,,,, -89600,0.19897233,1.6130922,,,,,,,,,,,,,,,,, -89700,0.19584608,1.5145785,,,,,,,,,,,,,,,,, -89800,0.19731261,1.5583566,,,,,,,,,,,,,,,,, -89900,0.21532509,1.5612361,,,,,,,,,,,,,,,,, -90000,0.21491995,1.6461344,,,,,,,,,,,,,,,,, -90100,0.23713386,1.6395062,,,,,,,,,,,,,,,,, -90200,0.21715866,1.5585111,,,,,,,,,,,,,,,,, -90300,0.21420793,1.621581,,,,,,,,,,,,,,,,, -90377,,,0.6723057627677917,1.5078905820846558,33.55991972595787,0.6858935356140137,1.4237116575241089,30.049164222142057,3000.0,0.7006449699401855,1.3293626308441162,29.98744025410253,3003.0,31962.3699696064,54170.957562446594,31962.3699696064,22204.37990808487,1.2239248752593994,0.0 -90400,0.20733064,1.5712527,,,,,,,,,,,,,,,,, -90500,0.21090345,1.5889925,,,,,,,,,,,,,,,,, -90600,0.21099313,1.5798357,,,,,,,,,,,,,,,,, -90700,0.20110662,1.5329307,,,,,,,,,,,,,,,,, -90800,0.22549322,1.6041737,,,,,,,,,,,,,,,,, -90900,0.20678806,1.5436736,,,,,,,,,,,,,,,,, -91000,0.21602266,1.6070951,,,,,,,,,,,,,,,,, -91100,0.2039254,1.5962532,,,,,,,,,,,,,,,,, -91200,0.21224055,1.6817241,,,,,,,,,,,,,,,,, -91300,0.2249784,1.5967029,,,,,,,,,,,,,,,,, -91400,0.20397927,1.5894735,,,,,,,,,,,,,,,,, -91500,0.20397985,1.5608009,,,,,,,,,,,,,,,,, -91600,0.22214068,1.6340289,,,,,,,,,,,,,,,,, -91700,0.20356569,1.5645045,,,,,,,,,,,,,,,,, -91800,0.21076156,1.6427363,,,,,,,,,,,,,,,,, -91900,0.22173911,1.5842369,,,,,,,,,,,,,,,,, -92000,0.21926951,1.5833498,,,,,,,,,,,,,,,,, -92100,0.23385644,1.5622076,,,,,,,,,,,,,,,,, -92200,0.22461125,1.6428427,,,,,,,,,,,,,,,,, -92300,0.19756506,1.5198001,,,,,,,,,,,,,,,,, -92400,0.23888788,1.6318593,,,,,,,,,,,,,,,,, -92500,0.20751318,1.6208119,,,,,,,,,,,,,,,,, -92600,0.20384537,1.5514348,,,,,,,,,,,,,,,,, -92700,0.20706877,1.5628161,,,,,,,,,,,,,,,,, -92756,,,0.6719854474067688,1.514437198638916,33.860206279249766,0.6873442530632019,1.416701078414917,30.474739154210475,3000.0,0.7013421654701233,1.322433352470398,29.937114050955184,3003.0,32802.475167512894,55580.85606193543,32802.475167512894,22774.05219650269,1.2687523365020752,0.0 -92800,0.21037543,1.5659846,,,,,,,,,,,,,,,,, -92900,0.22010928,1.5951849,,,,,,,,,,,,,,,,, -93000,0.23386419,1.6013585,,,,,,,,,,,,,,,,, -93100,0.21427654,1.5381942,,,,,,,,,,,,,,,,, -93200,0.21269132,1.556915,,,,,,,,,,,,,,,,, -93300,0.2255902,1.5257235,,,,,,,,,,,,,,,,, -93400,0.21658969,1.5922148,,,,,,,,,,,,,,,,, -93500,0.21252675,1.6209582,,,,,,,,,,,,,,,,, -93600,0.21505177,1.5265042,,,,,,,,,,,,,,,,, -93700,0.23736165,1.6403785,,,,,,,,,,,,,,,,, -93800,0.22239563,1.5727088,,,,,,,,,,,,,,,,, -93900,0.2056098,1.6114774,,,,,,,,,,,,,,,,, -94000,0.21691573,1.5459032,,,,,,,,,,,,,,,,, -94100,0.20803054,1.6286198,,,,,,,,,,,,,,,,, -94200,0.20827629,1.5963539,,,,,,,,,,,,,,,,, -94300,0.37328407,1.5916666,,,,,,,,,,,,,,,,, -94400,0.21930821,1.6617526,,,,,,,,,,,,,,,,, -94500,0.2224427,1.5897865,,,,,,,,,,,,,,,,, -94600,0.22453363,1.5817451,,,,,,,,,,,,,,,,, -94700,0.20977302,1.5559564,,,,,,,,,,,,,,,,, -94800,0.3994312,1.561792,,,,,,,,,,,,,,,,, -94900,0.21229474,1.5277194,,,,,,,,,,,,,,,,, -95000,0.23392104,1.5437895,,,,,,,,,,,,,,,,, -95100,0.22000788,1.5095297,,,,,,,,,,,,,,,,, -95134,,,0.6826393604278564,1.4538395404815674,34.60986378597462,0.6881625652313232,1.4122296571731567,30.472603703054315,3000.0,0.7038289904594421,1.3092671632766724,30.08935415843265,3003.0,33642.51808953285,57005.01177072525,33642.51808953285,23358.04110336304,1.312809705734253,0.0 -95200,0.21332344,1.5425434,,,,,,,,,,,,,,,,, -95300,0.21045227,1.5559691,,,,,,,,,,,,,,,,, -95400,0.20424943,1.5632142,,,,,,,,,,,,,,,,, -95500,0.21759403,1.5045557,,,,,,,,,,,,,,,,, -95600,0.22246395,1.6227555,,,,,,,,,,,,,,,,, -95700,0.22045289,1.501947,,,,,,,,,,,,,,,,, -95800,0.21675892,1.5837173,,,,,,,,,,,,,,,,, -95900,0.22158958,1.5740695,,,,,,,,,,,,,,,,, -96000,0.23612517,1.5381863,,,,,,,,,,,,,,,,, -96100,0.20734157,1.6116195,,,,,,,,,,,,,,,,, -96200,0.20354438,1.5793713,,,,,,,,,,,,,,,,, -96300,0.22507837,1.6106143,,,,,,,,,,,,,,,,, -96400,0.21928947,1.5223333,,,,,,,,,,,,,,,,, -96500,0.21515001,1.5360811,,,,,,,,,,,,,,,,, -96600,0.20907831,1.5503309,,,,,,,,,,,,,,,,, -96700,0.22071886,1.5931034,,,,,,,,,,,,,,,,, -96800,0.23441081,1.6240875,,,,,,,,,,,,,,,,, -96900,0.21563436,1.5415696,,,,,,,,,,,,,,,,, -97000,0.20771709,1.5446478,,,,,,,,,,,,,,,,, -97100,0.21309991,1.5456761,,,,,,,,,,,,,,,,, -97200,0.22087927,1.549951,,,,,,,,,,,,,,,,, -97300,0.21882313,1.4907064,,,,,,,,,,,,,,,,, -97400,0.2098107,1.5189255,,,,,,,,,,,,,,,,, -97500,0.20960933,1.539136,,,,,,,,,,,,,,,,, -97513,,,0.6825652122497559,1.4482752084732056,34.105125485857464,0.6888568997383118,1.4053305387496948,30.3004675653798,3000.0,0.7047934532165527,1.3054507970809937,29.9236223901206,3003.0,34482.75514650345,58539.23271560669,34482.75514650345,24051.911954164505,1.350311040878296,0.0 -97600,0.19711868,1.4855133,,,,,,,,,,,,,,,,, -97700,0.21034881,1.567715,,,,,,,,,,,,,,,,, -97800,0.21933925,1.5393012,,,,,,,,,,,,,,,,, -97900,0.22060119,1.5876725,,,,,,,,,,,,,,,,, -98000,0.21092722,1.6177917,,,,,,,,,,,,,,,,, -98100,0.21848957,1.5743854,,,,,,,,,,,,,,,,, -98200,0.21683007,1.5690213,,,,,,,,,,,,,,,,, -98300,0.21968824,1.5577675,,,,,,,,,,,,,,,,, -98400,0.2115358,1.498715,,,,,,,,,,,,,,,,, -98500,0.215776,1.5214766,,,,,,,,,,,,,,,,, -98600,0.21086133,1.5417328,,,,,,,,,,,,,,,,, -98700,0.20952296,1.4551916,,,,,,,,,,,,,,,,, -98800,0.22559462,1.5076671,,,,,,,,,,,,,,,,, -98900,0.2356331,1.5784922,,,,,,,,,,,,,,,,, -99000,0.20404188,1.4911467,,,,,,,,,,,,,,,,, -99100,0.23139851,1.588397,,,,,,,,,,,,,,,,, -99200,0.20854275,1.4909477,,,,,,,,,,,,,,,,, -99300,0.21796663,1.5859398,,,,,,,,,,,,,,,,, -99400,0.21003385,1.5993762,,,,,,,,,,,,,,,,, -99500,0.21039209,1.5485097,,,,,,,,,,,,,,,,, -99600,0.21520558,1.5243742,,,,,,,,,,,,,,,,, -99700,0.20310326,1.5268615,,,,,,,,,,,,,,,,, -99800,0.21009025,1.5290906,,,,,,,,,,,,,,,,, -99892,,,0.678978443145752,1.479453206062317,34.306114082887525,0.6903820037841797,1.4001511335372925,30.426522454410275,3000.0,0.7065016627311707,1.2986218929290771,30.29928668877095,3003.0,35322.91156697273,60021.6882147789,35322.91156697273,24694.09686756134,1.3887178897857666,0.0 -99900,0.2202236,1.5847106,,,,,,,,,,,,,,,,, -100000,0.25686747,1.4679418,,,,,,,,,,,,,,,,, -100100,0.22684266,1.534421,,,,,,,,,,,,,,,,, -100200,0.21976644,1.5307684,,,,,,,,,,,,,,,,, -100300,0.22075944,1.547561,,,,,,,,,,,,,,,,, -100400,0.21771005,1.5356064,,,,,,,,,,,,,,,,, -100500,0.23178004,1.5709089,,,,,,,,,,,,,,,,, -100600,0.2117125,1.5439644,,,,,,,,,,,,,,,,, -100700,0.21573998,1.5355811,,,,,,,,,,,,,,,,, -100800,0.22540513,1.4954162,,,,,,,,,,,,,,,,, -100900,0.23331735,1.5316906,,,,,,,,,,,,,,,,, -101000,0.20930804,1.5465337,,,,,,,,,,,,,,,,, -101100,0.22191538,1.5040967,,,,,,,,,,,,,,,,, -101200,0.23508367,1.529099,,,,,,,,,,,,,,,,, -101300,0.20768228,1.5013562,,,,,,,,,,,,,,,,, -101400,0.22871806,1.5552704,,,,,,,,,,,,,,,,, -101500,0.20528772,1.463622,,,,,,,,,,,,,,,,, -101600,0.2207535,1.5577787,,,,,,,,,,,,,,,,, -101700,0.22052214,1.5706756,,,,,,,,,,,,,,,,, -101800,0.21562305,1.5037501,,,,,,,,,,,,,,,,, -101900,0.21390149,1.4974079,,,,,,,,,,,,,,,,, -102000,0.23603539,1.5881164,,,,,,,,,,,,,,,,, -102100,0.21397367,1.5129844,,,,,,,,,,,,,,,,, -102200,0.22522962,1.5352303,,,,,,,,,,,,,,,,, -102270,,,0.6835367679595947,1.4356738328933716,34.774789100055045,0.6904935836791992,1.3909913301467896,30.6419763551928,3000.0,0.7074313163757324,1.2908309698104858,30.136253694643354,3003.0,36162.81519198418,61461.7780213356,36162.81519198418,25294.16486024857,1.4278733730316162,0.0 -102300,0.2285601,1.5411496,,,,,,,,,,,,,,,,, -102400,0.21956632,1.4648309,,,,,,,,,,,,,,,,, -102500,0.22212732,1.5145904,,,,,,,,,,,,,,,,, -102600,0.22282459,1.5026892,,,,,,,,,,,,,,,,, -102700,0.23338713,1.6039472,,,,,,,,,,,,,,,,, -102800,0.21273309,1.5349661,,,,,,,,,,,,,,,,, -102900,0.20092055,1.4900609,,,,,,,,,,,,,,,,, -103000,0.23047425,1.5581628,,,,,,,,,,,,,,,,, -103100,0.24719612,1.5641478,,,,,,,,,,,,,,,,, -103200,0.23004487,1.5143895,,,,,,,,,,,,,,,,, -103300,0.23149192,1.6119163,,,,,,,,,,,,,,,,, -103400,0.22686988,1.5083232,,,,,,,,,,,,,,,,, -103500,0.21499468,1.5380296,,,,,,,,,,,,,,,,, -103600,0.21712469,1.466445,,,,,,,,,,,,,,,,, -103700,0.22165005,1.5200201,,,,,,,,,,,,,,,,, -103800,0.21114668,1.4828148,,,,,,,,,,,,,,,,, -103900,0.21676849,1.5074133,,,,,,,,,,,,,,,,, -104000,0.22767065,1.5679672,,,,,,,,,,,,,,,,, -104100,0.22688961,1.5986509,,,,,,,,,,,,,,,,, -104200,0.22044267,1.5158724,,,,,,,,,,,,,,,,, -104300,0.2269787,1.5930117,,,,,,,,,,,,,,,,, -104400,0.23227006,1.5434731,,,,,,,,,,,,,,,,, -104500,0.211981,1.5630985,,,,,,,,,,,,,,,,, -104600,0.22430867,1.4301232,,,,,,,,,,,,,,,,, -104649,,,0.677696704864502,1.4766273498535156,34.298339599665034,0.6918699145317078,1.387955665588379,30.743370190512355,3000.0,0.708453893661499,1.2837616205215454,30.46981512380905,3003.0,37002.91240334511,62923.05829691887,37002.91240334511,25915.23298954964,1.467015027999878,0.0 -104700,0.21741846,1.4845037,,,,,,,,,,,,,,,,, -104800,0.22102877,1.4703977,,,,,,,,,,,,,,,,, -104900,0.21710806,1.5237474,,,,,,,,,,,,,,,,, -105000,0.2192219,1.4798831,,,,,,,,,,,,,,,,, -105100,0.2340113,1.4934565,,,,,,,,,,,,,,,,, -105200,0.21890397,1.582102,,,,,,,,,,,,,,,,, -105300,0.22529691,1.5511382,,,,,,,,,,,,,,,,, -105400,0.2318369,1.5544597,,,,,,,,,,,,,,,,, -105500,3.0153596,1.6176324,,,,,,,,,,,,,,,,, -105600,0.22216609,1.564456,,,,,,,,,,,,,,,,, -105700,0.2194499,1.5120393,,,,,,,,,,,,,,,,, -105800,0.21817394,1.4902422,,,,,,,,,,,,,,,,, -105900,0.21953817,1.5581532,,,,,,,,,,,,,,,,, -106000,0.22038022,1.524932,,,,,,,,,,,,,,,,, -106100,0.23224442,1.5664551,,,,,,,,,,,,,,,,, -106200,0.21327464,1.4762236,,,,,,,,,,,,,,,,, -106300,0.233136,1.5859518,,,,,,,,,,,,,,,,, -106400,0.22288005,1.5543547,,,,,,,,,,,,,,,,, -106500,0.22881985,1.4925649,,,,,,,,,,,,,,,,, -106600,0.22478187,1.4929073,,,,,,,,,,,,,,,,, -106700,0.23412699,1.542265,,,,,,,,,,,,,,,,, -106800,0.2226716,1.5118165,,,,,,,,,,,,,,,,, -106900,0.2281515,1.4324292,,,,,,,,,,,,,,,,, -107000,0.23030154,1.6124673,,,,,,,,,,,,,,,,, -107028,,,0.6928110718727112,1.384670376777649,35.15910011581935,0.6928246021270752,1.384840965270996,30.72062201906474,3000.0,0.7085817456245422,1.281222581863403,30.52008339596013,3003.0,37843.09711742401,64408.88877725601,37843.09711742401,26560.762880563736,1.5060055255889893,0.0 -107100,0.22260116,1.5735403,,,,,,,,,,,,,,,,, -107200,0.22571477,1.5616592,,,,,,,,,,,,,,,,, -107300,0.22991954,1.4585774,,,,,,,,,,,,,,,,, -107400,0.23193815,1.5244925,,,,,,,,,,,,,,,,, -107500,0.2185236,1.4747323,,,,,,,,,,,,,,,,, -107600,0.22230944,1.5083448,,,,,,,,,,,,,,,,, -107700,0.21426935,1.4574233,,,,,,,,,,,,,,,,, -107800,0.22752084,1.4355379,,,,,,,,,,,,,,,,, -107900,0.2213269,1.427314,,,,,,,,,,,,,,,,, -108000,0.22317484,1.5001442,,,,,,,,,,,,,,,,, -108100,0.23210403,1.5905904,,,,,,,,,,,,,,,,, -108200,0.23417556,1.4780455,,,,,,,,,,,,,,,,, -108300,0.23488611,1.492021,,,,,,,,,,,,,,,,, -108400,0.23623654,1.5406947,,,,,,,,,,,,,,,,, -108500,0.22218455,1.4401191,,,,,,,,,,,,,,,,, -108600,0.21773954,1.4571002,,,,,,,,,,,,,,,,, -108700,0.21961267,1.5289379,,,,,,,,,,,,,,,,, -108800,0.22629948,1.4560298,,,,,,,,,,,,,,,,, -108900,0.21676217,1.4683073,,,,,,,,,,,,,,,,, -109000,0.25744426,1.6026003,,,,,,,,,,,,,,,,, -109100,0.2320959,1.5375108,,,,,,,,,,,,,,,,, -109200,0.22732775,1.5017124,,,,,,,,,,,,,,,,, -109300,0.22645731,1.481057,,,,,,,,,,,,,,,,, -109400,0.22444345,1.4776683,,,,,,,,,,,,,,,,, -109407,,,0.6901442408561707,1.4055242538452148,35.2604371709747,0.6927750110626221,1.3797883987426758,30.8121018107398,3000.0,0.7109174728393555,1.2740641832351685,30.94646817655173,3003.0,38683.2755715847,65866.91502094269,38683.2755715847,27178.494292497635,1.5452158451080322,0.0 -109500,0.23080601,1.4872714,,,,,,,,,,,,,,,,, -109600,0.23966374,1.5233207,,,,,,,,,,,,,,,,, -109700,0.22576806,1.4767456,,,,,,,,,,,,,,,,, -109800,0.23528872,1.5249393,,,,,,,,,,,,,,,,, -109900,0.24201587,1.5179645,,,,,,,,,,,,,,,,, -110000,0.2554253,1.4855816,,,,,,,,,,,,,,,,, -110100,0.22880521,1.4860946,,,,,,,,,,,,,,,,, -110200,0.2642827,1.5507753,,,,,,,,,,,,,,,,, -110300,0.21768738,1.4224191,,,,,,,,,,,,,,,,, -110400,0.2269301,1.521287,,,,,,,,,,,,,,,,, -110500,0.28710917,1.5072638,,,,,,,,,,,,,,,,, -110600,0.23918343,1.4910705,,,,,,,,,,,,,,,,, -110700,0.23351987,1.5028121,,,,,,,,,,,,,,,,, -110800,0.23746406,1.4859842,,,,,,,,,,,,,,,,, -110900,0.2299482,1.5178916,,,,,,,,,,,,,,,,, -111000,0.23363622,1.4981047,,,,,,,,,,,,,,,,, -111100,0.22566746,1.4723123,,,,,,,,,,,,,,,,, -111200,0.23255812,1.4557451,,,,,,,,,,,,,,,,, -111300,0.21893196,1.4880664,,,,,,,,,,,,,,,,, -111400,0.22501206,1.4556506,,,,,,,,,,,,,,,,, -111500,0.21939793,1.5314871,,,,,,,,,,,,,,,,, -111600,0.23279664,1.4445704,,,,,,,,,,,,,,,,, -111700,0.21996522,1.4789255,,,,,,,,,,,,,,,,, -111784,,,0.6872426271438599,1.4228895902633667,34.660629651300326,0.6934818029403687,1.3747117519378662,30.673267151118715,3000.0,0.7106966376304626,1.268842339515686,30.744305941224404,3003.0,39523.27041316032,67255.41884446144,39523.27041316032,27726.88291144371,1.5854876041412354,0.0 -111800,0.22170651,1.4516157,,,,,,,,,,,,,,,,, -111900,0.23901796,1.5524601,,,,,,,,,,,,,,,,, -112000,0.24100319,1.4951466,,,,,,,,,,,,,,,,, -112100,0.24034964,1.4590925,,,,,,,,,,,,,,,,, -112200,0.22861436,1.5246747,,,,,,,,,,,,,,,,, -112300,0.23821089,1.4977062,,,,,,,,,,,,,,,,, -112400,0.23291983,1.5419687,,,,,,,,,,,,,,,,, -112500,0.21804775,1.5027399,,,,,,,,,,,,,,,,, -112600,0.2306226,1.4972954,,,,,,,,,,,,,,,,, -112700,0.22312643,1.4199903,,,,,,,,,,,,,,,,, -112800,0.22323541,1.4310168,,,,,,,,,,,,,,,,, -112900,0.23455788,1.457454,,,,,,,,,,,,,,,,, -113000,0.23760368,1.4965332,,,,,,,,,,,,,,,,, -113100,0.22657403,1.4455512,,,,,,,,,,,,,,,,, -113200,0.23304746,1.4770341,,,,,,,,,,,,,,,,, -113300,0.250242,1.5205481,,,,,,,,,,,,,,,,, -113400,0.24103469,1.5075126,,,,,,,,,,,,,,,,, -113500,0.23939066,1.5245398,,,,,,,,,,,,,,,,, -113600,0.2214098,1.464657,,,,,,,,,,,,,,,,, -113700,0.22812067,1.4854008,,,,,,,,,,,,,,,,, -113800,0.25538704,1.4989996,,,,,,,,,,,,,,,,, -113900,0.2318823,1.4948765,,,,,,,,,,,,,,,,, -114000,0.23759967,1.4724119,,,,,,,,,,,,,,,,, -114100,0.2510927,1.416776,,,,,,,,,,,,,,,,, -114162,,,0.6885464191436768,1.4088932275772097,35.04633107417319,0.6940521597862244,1.3701497316360474,30.744583577862368,3000.0,0.7109755277633667,1.265406370162964,30.69732003049124,3003.0,40363.24680304527,68713.56347870827,40363.24680304527,28344.93062567711,1.6279840469360352,0.0 -114200,0.23022142,1.4806235,,,,,,,,,,,,,,,,, -114300,0.23362541,1.5491182,,,,,,,,,,,,,,,,, -114400,0.22885823,1.5273666,,,,,,,,,,,,,,,,, -114500,0.2337441,1.4889812,,,,,,,,,,,,,,,,, -114600,0.2299691,1.5441811,,,,,,,,,,,,,,,,, -114700,0.23941365,1.556316,,,,,,,,,,,,,,,,, -114800,0.22898066,1.4785722,,,,,,,,,,,,,,,,, -114900,0.22462542,1.5057081,,,,,,,,,,,,,,,,, -115000,0.23683503,1.4499824,,,,,,,,,,,,,,,,, -115100,0.23298511,1.4241518,,,,,,,,,,,,,,,,, -115200,0.22776917,1.4286826,,,,,,,,,,,,,,,,, -115300,0.23507945,1.4600903,,,,,,,,,,,,,,,,, -115400,0.23949647,1.4884273,,,,,,,,,,,,,,,,, -115500,0.22276412,1.4911842,,,,,,,,,,,,,,,,, -115600,0.24372213,1.454012,,,,,,,,,,,,,,,,, -115700,0.2428514,1.4824879,,,,,,,,,,,,,,,,, -115800,0.22712578,1.4759681,,,,,,,,,,,,,,,,, -115900,0.22871877,1.4267408,,,,,,,,,,,,,,,,, -116000,0.22739917,1.5169867,,,,,,,,,,,,,,,,, -116100,0.2351148,1.4610759,,,,,,,,,,,,,,,,, -116200,0.22643544,1.4675875,,,,,,,,,,,,,,,,, -116300,0.22897147,1.4605594,,,,,,,,,,,,,,,,, -116400,0.22806254,1.3583637,,,,,,,,,,,,,,,,, -116500,0.2297399,1.5401248,,,,,,,,,,,,,,,,, -116541,,,0.6922299861907959,1.387744426727295,34.91732924260557,0.6947712898254395,1.366743564605713,30.894717875592825,3000.0,0.7128115892410278,1.2598994970321655,31.03030929029123,3003.0,41203.46653318405,70213.22715878487,41203.46653318405,29004.25678896904,1.6690990924835205,0.0 -116541,,,,,,,,,,,,,,41203.46653318405,,,,,0.0 diff --git a/pyproject.toml b/pyproject.toml index 4e15e4400..dbe3e842b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,7 @@ classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", + "Programming Languate :: Python :: 3.11", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ @@ -38,15 +36,15 @@ dependencies = [ "docker==7.1.0", "numpy>=2.0.2", "pandas>=2.0.1", - "tensorflow==2.18.0", - "tensorflow-datasets==4.9.7", + "tensorflow==2.19.0", + "tensorflow-datasets==4.9.9", "tensorflow-probability==0.20.0", "gputil==1.4.0", "psutil==6.1.0", "clu==0.0.12", "matplotlib>=3.9.2", "tabulate==0.9.0", - + "wandb==0.21.0" ] [build-system] @@ -83,40 +81,41 @@ dev = [ "pre-commit==4.0.1", ] -wandb = ["wandb==0.19.6"] - # Workloads criteo1tb = ["scikit-learn==1.5.2"] fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"] ogbg = ["jraph==0.0.6.dev0", "scikit-learn==1.5.2"] librispeech_conformer = [ "sentencepiece==0.2.0", - "tensorflow-text==2.18.0", + "tensorflow-text==2.19.0", "pydub==0.25.1", ] -wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"] +wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.19.0"] # Frameworks jax_core_deps = [ - "flax==0.8.4", + "flax==0.10.6", "optax==0.2.2", "chex==0.1.86", - "ml_dtypes==0.4.1", + "ml_dtypes==0.5.1", "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", + "jax==0.6.0", "algoperf[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.28", - "jaxlib==0.4.28", - "jax-cuda12-plugin[with_cuda]==0.4.28", - "jax-cuda12-pjrt==0.4.28", + # Temporarily install with -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre + "jax", + "jaxlib", + "jax-cuda12-plugin[with-cuda]", + "jax-cuda12-pjrt", "algoperf[jax_core_deps]", ] -pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] +pytorch_cpu = [ + "torch==2.5.1", + "torchvision==0.20.1" +] pytorch_gpu = [ "torch==2.5.1", "torchvision==0.20.1", diff --git a/reference_algorithms/development_algorithms/__init__.py b/reference_algorithms/development_algorithms/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/cifar/__init__.py b/reference_algorithms/development_algorithms/cifar/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/__init__.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py deleted file mode 100644 index 3d8e35eaa..000000000 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ /dev/null @@ -1,180 +0,0 @@ -"""Training algorithm track submission functions for CIFAR10.""" - -import functools -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algoperf import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'cifar': 128} - return batch_sizes[workload_name] - - -def cosine_decay(lr, step, total_steps): - ratio = jnp.maximum(0., step / total_steps) - mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio)) - return mult * lr - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128. - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('cifar') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.sgd( - nesterov=True, - momentum=hyperparameters.momentum, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, new_model_state - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (_, new_model_state), grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del global_step - del train_state - del eval_results - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, new_params, new_model_state = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/__init__.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py deleted file mode 100644 index d8b91f83a..000000000 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Training algorithm track submission functions for CIFAR10.""" - -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algoperf import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'cifar': 128} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128. - optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=base_lr, - momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2), - } - - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-5, - end_factor=1., - total_iters=hyperparameters.warmup_epochs) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs]) - - return optimizer_state - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del train_state - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - - steps_per_epoch = workload.num_train_examples // get_batch_size('cifar') - if (global_step + 1) % steps_per_epoch == 0: - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json deleted file mode 100644 index 283341705..000000000 --- a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.1]}, - "warmup_epochs": {"feasible_points": [5]}, - "num_epochs": {"feasible_points": [200]}, - "l2": {"feasible_points": [5e-4]}, - "momentum": {"feasible_points": [0.9]} -} diff --git a/reference_algorithms/development_algorithms/mnist/__init__.py b/reference_algorithms/development_algorithms/mnist/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/mnist/discrete_space.json b/reference_algorithms/development_algorithms/mnist/discrete_space.json deleted file mode 100644 index 310f19e7d..000000000 --- a/reference_algorithms/development_algorithms/mnist/discrete_space.json +++ /dev/null @@ -1,17 +0,0 @@ -[ - { - "learning_rate": 1e-3, - "one_minus_beta_1": 0.999, - "epsilon": 0.9 - }, - { - "learning_rate": 1e-2, - "one_minus_beta_1": 0.99, - "epsilon": 0.99 - }, - { - "learning_rate": 1e-1, - "one_minus_beta_1": 0.9, - "epsilon": 0.999 - } -] \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/__init__.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py deleted file mode 100644 index c1f54597d..000000000 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Training algorithm track submission functions for MNIST.""" - -import functools -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algoperf import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'mnist': 1024} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optax.chain( - optax.scale_by_adam( - b1=1.0 - hyperparameters.one_minus_beta_1, - b2=0.999, - eps=hyperparameters.epsilon), - optax.scale(-hyperparameters.learning_rate)) - return jax_utils.replicate(opt_init_fn(params_zeros_like)), opt_update_fn - - -# We need to jax.pmap here instead of inside update_params because the latter -# would recompile the function every step. -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, None, 0, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_update_params(workload: spec.Workload, - opt_update_fn, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - optimizer_state: spec.OptimizerState, - rng: spec.RandomState) -> spec.UpdateReturn: - del hyperparameters - - def loss_fn(params): - logits_batch, new_model_state = workload.model_fn( - params=params, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - return loss, new_model_state - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, new_model_state), grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - del global_step - - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - optimizer_state, opt_update_fn = optimizer_state - new_optimizer_state, updated_params, new_model_state = pmapped_update_params( - workload, - opt_update_fn, - current_param_container, - model_state, - hyperparameters, - batch, - optimizer_state, - per_device_rngs) - return (new_optimizer_state, opt_update_fn), updated_params, new_model_state - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/__init__.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py deleted file mode 100644 index dedd96793..000000000 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Training algorithm track submission functions for MNIST.""" - -from typing import Any, Dict, Iterator, List, Optional, Tuple - -import torch - -from algoperf import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'mnist': 1024} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del workload - del rng - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999), - eps=hyperparameters.epsilon), - } - return optimizer_state - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del hyperparameters - del loss_type - del current_params_types - del train_state - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - for param in current_model.parameters(): - param.grad = None - - output, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=output) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - loss.backward() - optimizer_state['optimizer'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del optimizer_state - del current_param_container - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json deleted file mode 100644 index 35b941133..000000000 --- a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"}, - "one_minus_beta_1": {"min": 0.9, "max": 0.999, "scaling": "log"}, - "epsilon": {"feasible_points": [1e-8, 1e-5, 1e-3]} -} diff --git a/reference_algorithms/paper_baselines/README.md b/reference_algorithms/paper_baselines/README.md index aadb7eab2..6c7027adf 100644 --- a/reference_algorithms/paper_baselines/README.md +++ b/reference_algorithms/paper_baselines/README.md @@ -1,14 +1,10 @@ # Baseline Submissions from the "Benchmarking Neural Network Training Algorithms" Paper -This directory contains the baseline submissions for the [external tuning ruleset](../README.md#external-tuning-ruleset) as presented in our paper [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179). They are based on eight different update rules: +This directory contains baseline submissions for the [external tuning ruleset](../README.md#external-tuning-ruleset) as presented in our paper [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179): -- [Adafactor](/reference_algorithms/paper_baselines/adafactor) - [AdamW](/reference_algorithms/paper_baselines/adamw) -- [LAMB](/reference_algorithms/paper_baselines/lamb) - [SGD with Momentum](/reference_algorithms/paper_baselines/momentum) - [NadamW](/reference_algorithms/paper_baselines/nadamw) - [SGD with Nesterov Momentum](/reference_algorithms/paper_baselines/nesterov) -- [SAM](/reference_algorithms/paper_baselines/sam) -- [Shampoo](/reference_algorithms/paper_baselines/shampoo/) Each update rule has two different tuning search spaces, one where the first momentum parameter (often denoted $\beta_1$) is tuned and one where it is set to a fixed value. diff --git a/reference_algorithms/paper_baselines/adafactor/__init__.py b/reference_algorithms/paper_baselines/adafactor/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/adafactor/jax/__init__.py b/reference_algorithms/paper_baselines/adafactor/jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py deleted file mode 100644 index ff98464ae..000000000 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ /dev/null @@ -1,693 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The init2winit Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PAX/Praxis implementation of Adafactor. - -Copied from Praxis's `sharded_adafactor`, removing unnecessary sharding-related -code and dependencies on Praxis. - -Code: -https://github.com/google/praxis/blob/516a96bce6f03090c5903531038f8f8af6212250/praxis/optimizers.py#L2308 - -Forked from: -https://github.com/google/init2winit/master/init2winit/optimizer_lib/pax_adafactor.py -""" - -import dataclasses -import functools -import re -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union - -import jax -from jax import numpy as jnp -import optax - -JTensor = Any -NestedJTensor = Any -NestedHParams = Any - - -def to_quantized(fvalue: JTensor, - quantized_dtype: jnp.dtype) -> Tuple[JTensor, JTensor]: - """Converts floating point values `fvalues` to quantized values. - - We use a very simple quantization scheme where the range is symmetric around - 0.0, and we simply map 0 to 0.0. - - Let x = bucket_size - We map [-0.5x, 0.5x] to 0 - [-1.5x, -0.5x] to -1 - [0.5x, 1.5x] to 1 - and so on so forth. - - Some properties: - a1, a2 = to_quantized(x, quantized_dtype) - b1 = to_float(a1, a2) - c1, c2 = to_quantized(b1, quantized_dtype) - - then a1 == c1, a2 == c2 - - Args: - fvalue: Values in floating point. - quantized_dtype: Quantized dtype, can be either jnp.int8, or jnp.int16. - - Returns: - A (quantized_values, bucket_size) 2-tuple. - `quantized_values * bucket_size[jnp.newaxis, ...]` are the quantized - values - on the floating value axis. - """ - float_dtype = fvalue.dtype - if quantized_dtype == jnp.int8: - # value -128 is not used. - num_buckets = jnp.array(127.0, dtype=float_dtype) - elif quantized_dtype == jnp.int16: - # value -32768 is not used. - num_buckets = jnp.array(32767.0, dtype=float_dtype) - else: - raise ValueError(f'Quantized dtype {quantized_dtype} not supported.') - # max value is mapped to num_buckets - - # We first decide the scale. - if fvalue.ndim < 1: - raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') - - max_abs = jnp.max(jnp.abs(fvalue), axis=0) - bucket_size = max_abs / num_buckets - bs_expanded = bucket_size[jnp.newaxis, ...] - # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, - bs_expanded, - jnp.ones_like(bs_expanded)) - ratio = fvalue / bs_nonzero - # We use rounding to remove bias. - quantized = jnp.round(ratio) - return quantized.astype(quantized_dtype), bucket_size - - -def to_float(quantized: JTensor, bucket_size: JTensor) -> JTensor: - """Converts quantized values to float values. - - Args: - quantized: Quantized values, of type either jnp.int8 or jnp.int16. - bucket_size: The size of each bucket on the floating-point axis. bucket_size - is of rank tf.rank(quantized) - 1. For example, if quantized is of shape - [x, ...], bucket_size is of shape [...]. - - Returns: - Unquantized values of type bucket_size.dtype. - """ - float_dtype = bucket_size.dtype - bucket_size = bucket_size[jnp.newaxis, ...] - return quantized.astype(float_dtype) * bucket_size - - -def adafactor_decay_rate_adam(beta2: float, step_counter: JTensor) -> JTensor: - """Second-moment decay rate like Adam, subsuming the correction factor. - - Args: - beta2: A floating point value between 0 and 1. - step_counter: A scalar tensor keeping track of the number of steps - performed. - - Returns: - The decay rate as a scalar JTensor. - """ - step = step_counter - beta2 = jnp.array(beta2, dtype=jnp.float32) - t = step + 1. - return beta2 * (1. - jnp.power(beta2, t - 1.)) / (1. - jnp.power(beta2, t)) - - -def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: - """Second moment decay rate where memory-length grows as step_num^exponent. - - Args: - exponent: A floating point value between 0 and 1. - step_counter: A scalar tensor keeping track of the number of steps - performed. - - Returns: - The decay rate as a scalar JTensor. - """ - step = step_counter - exponent = jnp.array(exponent, dtype=jnp.float32) - return 1. - jnp.power((step + 1.), -exponent) - - -def reduce_mean(array: JTensor) -> JTensor: - """Computes the mean of `array` in a more numerically stable way. - - Args: - array: Input array. - - Returns: - The mean of the input array as a scalar array. - """ - num_elements = array.size - if num_elements > 1e8: - # When x is too large, simple jnp.mean() can result in nan or inf values. - # TODO(bf-jax): The following code snippet is consistent with the TensorFlow - # implementation. This can be simplified into `jnp.mean(jnp.mean(x, -1))`. - # Update to using mean() after verifying consistency. - array_sum = jnp.sum(array, axis=-1) - array_sum = jnp.sum(array_sum) - return array_sum / jnp.array(num_elements, dtype=array_sum.dtype) - else: - return jnp.mean(array) - - -def reduce_rms(array: JTensor) -> JTensor: - """Computes the RMS of `array` (in a numerically stable way). - - Args: - array: Input array. - - Returns: - The root mean square of the input array as a scalar array. - """ - sq = jnp.square(array) - sq_mean = reduce_mean(sq) - return jnp.sqrt(sq_mean) - - -@dataclasses.dataclass(frozen=True) -class _ShardedAdafactorUpdateResult: - """Structure containing per-variable info for Adafactor.""" - update: Optional[Any] - m: Optional[Any] - m_scale: Optional[Any] - vr: Optional[Any] - vc: Optional[Any] - v: Optional[Any] - - -class ShardedAdafactorState(NamedTuple): - """Overall state of the ShardedAdafactor optimizer.""" - count: JTensor - m: Optional[NestedJTensor] - m_scale: Optional[NestedJTensor] - vr: Optional[NestedJTensor] - vc: Optional[NestedJTensor] - v: Optional[NestedJTensor] - - -class _ShardedAdafactorHelper: - """Helper class to implement optax-based sharded Adafactor.""" - - def __init__(self, - learning_rate: optax.Schedule, - weight_decay: Optional[float], - layerwise_adaptation: bool, - decay_method: str, - decay_adam: float, - decay_pow: float, - beta1: float, - clip_threshold: Optional[float], - factored: bool, - epsilon1_grad_sq_reg: float, - quantized_dtype: jnp.dtype, - respect_skip_lp_regularization: bool, - exclude_from_layerwise_adaptation: Optional[List[str]], - per_var_learning_summary: bool, - sort_factored_second_moment_dims: bool, - min_dim_size_to_factor: int, - multiply_by_parameter_scale: bool, - epsilon2_param_scale_reg: float, - maybe_inf_to_nan: bool, - nesterov: bool) -> None: - """Constructor. See ShardedAdafactor() below.""" - - self._learning_rate = learning_rate - self._weight_decay = weight_decay - self._layerwise_adaptation = layerwise_adaptation - self._decay_method = decay_method - self._decay_adam = decay_adam - self._decay_pow = decay_pow - self._beta1 = beta1 - self._clip_threshold = clip_threshold - self._factored = factored - self._epsilon1 = epsilon1_grad_sq_reg - self._quantized_dtype = quantized_dtype - self._respect_skip_lp_regularization = respect_skip_lp_regularization - self._exclude_from_layerwise_adaptation = exclude_from_layerwise_adaptation - self._per_var_learning_summary = per_var_learning_summary - self._sort_factored_second_moment_dims = sort_factored_second_moment_dims - self._min_dim_size_to_factor = min_dim_size_to_factor - self._multiply_by_parameter_scale = multiply_by_parameter_scale - self._epsilon2 = epsilon2_param_scale_reg - self._maybe_inf_to_nan = maybe_inf_to_nan - self._nesterov = nesterov - - def should_use_factored_second_moment_estimate(self, shape): - """Should we use a factored second moment estimator. - - Based on the shape of the variable. - - Args: - shape: a list of integers. - - Returns: - A boolean. - """ - return self.factored_second_moment_dims(shape) is not None - - def factored_second_moment_dims(self, shape): - """Should we use a factored second moment estimator. - - We select largest and second largest var dims as row and colum dims. - - Default list of factored dims is -1, -2. - - Args: - shape: a list of integers. - - Returns: - either a list of 2 Dimension indices for row and col or None - """ - if not self._factored: - return None - if len(shape) < 2: - return None - if not self._sort_factored_second_moment_dims: - return len(shape) - 1, len(shape) - 2 - - def largest_two_dim_indices(): - s = [(s, i) for i, s in enumerate(shape)] - sorted_dims = sorted(s, key=lambda d: -d[0]) - return sorted_dims[0][1], sorted_dims[1][1] - - r_idx, c_idx = largest_two_dim_indices() - if shape[c_idx] < self._min_dim_size_to_factor: - return None - return r_idx, c_idx - - def should_store_momentum_in_qint(self, shape): - """Should we store momentum as quantized integers. - - Based on the shape of the variable. - - Args: - shape: a list of integers - - Returns: - A boolean. - """ - if jnp.issubdtype(self._quantized_dtype, jnp.floating): - return False - if self._quantized_dtype is None: - return False - return len(shape) >= 1 - - def to_state(self, count, result_tree): - """Maps from a tree of (factored) values to separate trees of values.""" - return ShardedAdafactorState( - count=count, - m=jax.tree.map(lambda o: o.m, result_tree), - m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), - vr=jax.tree.map(lambda o: o.vr, result_tree), - vc=jax.tree.map(lambda o: o.vc, result_tree), - v=jax.tree.map(lambda o: o.v, result_tree)) - - def init(self, param): - """Initializes the optimizer state for a given param.""" - # The actually value that will be added to a variable for updating it. - output_update = jnp.zeros((1,)) - output_m = jnp.zeros((1,)) - output_m_scale = jnp.zeros((1,)) - output_vr = jnp.zeros((1,)) - output_vc = jnp.zeros((1,)) - output_v = jnp.zeros((1,)) - shape = param.shape - if self._beta1: - if jnp.issubdtype(self._quantized_dtype, jnp.floating): - output_m = jnp.zeros(shape, dtype=self._quantized_dtype) - elif self.should_store_momentum_in_qint(shape): - output_m = jnp.zeros(shape, dtype=self._quantized_dtype) - scale_shape = shape[1:] - output_m_scale = jnp.zeros(scale_shape, dtype=jnp.float32) - else: - output_m = jnp.zeros(shape, dtype=jnp.float32) - if self.should_use_factored_second_moment_estimate(shape): - factored_dims = self.factored_second_moment_dims(shape) - vr_axis, vc_axis = factored_dims - output_vr_shape = list(shape).copy() - del output_vr_shape[vr_axis] - output_vc_shape = list(shape).copy() - del output_vc_shape[vc_axis] - output_vr = jnp.zeros(output_vr_shape, dtype=jnp.float32) - output_vc = jnp.zeros(output_vc_shape, dtype=jnp.float32) - else: - output_v = jnp.zeros(shape, dtype=jnp.float32) - return _ShardedAdafactorUpdateResult( - update=output_update, - m=output_m, - m_scale=output_m_scale, - vr=output_vr, - vc=output_vc, - v=output_v) - - def inf_to_nan(self, array): - """Converting Infinity values to the more sticky NaN.""" - # For example, when we have y = 1.0 / x in code and x == inf, y will become - # 0. Therefore the infinite value of x is hidden in the calculation, - # leading to silent omission of numerical issues. - if not self._maybe_inf_to_nan: - return array - return jnp.nan_to_num(array, nan=jnp.nan, posinf=jnp.nan, neginf=jnp.nan) - - def parameter_scale(self, var): - """Estimate the scale of the parameters from the current values. - - We include a minimum value of 0.001 to give it a chance to escape 0 - if it was zero-initialized. - - Instead of using the value, we could impute the scale from the shape, - as initializers do. - - Args: - var: a variable or Tensor. - - Returns: - a Scalar - """ - return jnp.maximum(reduce_rms(var), jnp.asarray(self._epsilon2, var.dtype)) - - def compute_var_and_slot_update(self, - count, - grad, - m, - m_scale, - vr, - vc, - v, - param, - var_name=None): - """Computes the var and optimizer slots updates for a single variable.""" - # We can probably skip this step - grad = grad.astype(jnp.float32) - grad = self.inf_to_nan(grad) - grad_squared = jnp.square(grad) - - # Add epsilon1_grad_sq_reg as per Algorithm 4 - # of https://arxiv.org/pdf/1804.04235.pdf - grad_squared += self._epsilon1 - grad_squared_mean = self.inf_to_nan(reduce_mean(grad_squared)) - if self._decay_method == 'adam': - assert self._decay_adam > 0 - decay_rate = adafactor_decay_rate_adam(self._decay_adam, count) - elif self._decay_method == 'pow': - assert self._decay_pow > 0 - decay_rate = adafactor_decay_rate_pow(self._decay_pow, count) - else: - raise ValueError(f'decay_method {self._decay_method} not supported.') - - learning_rate = self._learning_rate - if callable(learning_rate): - learning_rate = learning_rate(count) - - update_scale = learning_rate - old_val = param - - if self._multiply_by_parameter_scale: - update_scale *= self.parameter_scale(old_val).astype(update_scale.dtype) - - # Q(yonghui): Can we remove the hack now? - # HACK: Make things dependent on grad. - # This confounds the XLA rewriter and keeps it from fusing computations - # across different variables. This fusion is a bad for HBM usage, since - # it causes the gradients to persist in memory. - decay_rate += grad_squared_mean * 1e-30 - update_scale += grad_squared_mean * 1e-30 - # END HACK - - mixing_rate = 1. - decay_rate - shape = param.shape - - output_m = jnp.zeros((1,)) - output_m_scale = jnp.zeros((1,)) - output_vr = jnp.zeros((1,)) - output_vc = jnp.zeros((1,)) - output_v = jnp.zeros((1,)) - - factored_second_moment_dims = self.factored_second_moment_dims(shape) - if factored_second_moment_dims is not None: - # Q(shafey): Should we use the more numerically stable version - # reduce_mean(). - vr_axis, vc_axis = factored_second_moment_dims - grad_squared_row_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vr_axis)) - grad_squared_col_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vc_axis)) - new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean - new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean - output_vr = new_vr - output_vc = new_vc - long_term_mean = jnp.mean(new_vr, axis=-1, keepdims=True) - r_factor = 1. / jnp.sqrt(new_vr / long_term_mean) - c_factor = 1. / jnp.sqrt(new_vc) - x = grad * jnp.expand_dims(r_factor, vr_axis) * jnp.expand_dims( - c_factor, vc_axis) - else: - # v with sharding annotation. - new_v = decay_rate * v + mixing_rate * grad_squared - output_v = new_v - x = grad / jnp.sqrt(new_v) - - if self._clip_threshold is not None: - clipping_denom = jnp.maximum(1., reduce_rms(x) / self._clip_threshold) - clipping_denom = self.inf_to_nan(clipping_denom) - x /= clipping_denom - - subtrahend = update_scale * x - if self._beta1: - if jnp.issubdtype(self._quantized_dtype, jnp.floating): - m = m.astype(jnp.float32) - elif self.should_store_momentum_in_qint(shape): - m_init_dtype = m.dtype - m = to_float(m, m_scale) - if self._nesterov: - subtrahend_original = subtrahend - subtrahend = self._beta1 * m + (1. - self._beta1) * subtrahend - subtrahend = self.inf_to_nan(subtrahend) - if self._quantized_dtype == jnp.bfloat16: - new_m = subtrahend.astype(jnp.bfloat16) - output_m = new_m - elif self.should_store_momentum_in_qint(shape): - # Update the momentum values. - new_m_val, new_m_scale = to_quantized(subtrahend, m_init_dtype) - output_m = new_m_val - output_m_scale = new_m_scale - else: - output_m = subtrahend - - if self._nesterov: - subtrahend = ( - self._beta1 * subtrahend + - (1.0 - self._beta1) * subtrahend_original) - - if self._weight_decay is not None: - # Apply decoupled weight decay to be consistent with AdamW. - var_weight_decay = None - if isinstance(self._weight_decay, dict): - for scope_pattern in self._weight_decay.keys(): - regex_pattern = re.compile(scope_pattern) - if regex_pattern.match(var_name): - var_weight_decay = self._weight_decay[scope_pattern] - else: - var_weight_decay = self._weight_decay - - if var_weight_decay is not None: - weight_decay = var_weight_decay * learning_rate - subtrahend += weight_decay * old_val - - if self._layerwise_adaptation: - include = True - if self._exclude_from_layerwise_adaptation is not None: - for scope_pattern in self._exclude_from_layerwise_adaptation: - regex_pattern = re.compile(scope_pattern) - if regex_pattern.match(var_name): - include = False - break - if include: - w_norm = reduce_rms(old_val) - g_norm = reduce_rms(subtrahend / update_scale) + self._epsilon1 - ratio = w_norm / g_norm - ratio = jnp.where( - jnp.greater(w_norm, 0), - jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), - 1.0) - subtrahend *= ratio - - return _ShardedAdafactorUpdateResult( - update=-subtrahend, - m=output_m, - m_scale=output_m_scale, - vr=output_vr, - vc=output_vc, - v=output_v) - - -def sharded_adafactor( - learning_rate: optax.Schedule, - weight_decay: Optional[Union[float, Dict[str, float]]] = None, - layerwise_adaptation: bool = False, - decay_method: str = 'adam', - decay_adam: float = 0.99, - decay_pow: float = 0., - beta1: float = 0.9, - clip_threshold: Optional[float] = 1., - factored: bool = True, - epsilon1_grad_sq_reg: float = 1e-30, - quantized_dtype: jnp.dtype = jnp.int8, - respect_skip_lp_regularization: bool = False, - exclude_from_layerwise_adaptation: Optional[List[str]] = None, - per_var_learning_summary: bool = False, - sort_factored_second_moment_dims: bool = False, - # min_dim_size_to_factor is only used when - # sort_factored_second_moment_dims=True. - min_dim_size_to_factor: int = 128, - multiply_by_parameter_scale: bool = False, - epsilon2_param_scale_reg: float = 1e-3, - maybe_inf_to_nan: bool = True, - nesterov: bool = False, -) -> optax.GradientTransformation: - """AdaFactor optimizer that supports SPMD sharding. - - Reference: - Shazeer et al, 2018: https://arxiv.org/abs/1804.04235 - - Adafactor is very similar to Adam (Kingma and Ba, 2019), the major - differences being: - - 1. For a two-dimensional AxB weight matrix, Adafactor uses only A+B auxiliary - parameters to maintain the second-moment estimator, instead of AB. - This is advantageous on memory-limited systems. In addition, beta1 - (momentum) is set to zero by default, saving an additional auxiliary - parameter per weight. Variables with >=3 dimensions are treated as - collections of two-dimensional matrices - factorization is over the final - two dimensions. - - 2. Adafactor incorporates "update-clipping" - a scale-invariant analog of - gradient clipping. This improves stability. - - 3. Adafactor does not require an external "learning rate". By default, it - incorporates a relative-update-scale schedule, corresponding to - inverse-square-root learning-rate-decay in Adam. We hope this works well - for most applications. - - Args: - learning_rate: a callable that given the current training step, returns the - learning rate to apply. - weight_decay: an optional float tensor as decoupled weight decay value, or a - dictionary with key as regex scope pattern and value as corresponding - weight decay float tensor. The value will apply to all variables under - that scope name. - layerwise_adaptation: a boolean, whether or not to use layer-wise adaptive - moments (LAMB) https://arxiv.org/abs/1904.00962. - decay_method: a string, deciding how decay_rate should be computed. - Permitted values are 'adam' and 'pow'. - decay_adam: a float, decay if decay_method == 'adam'. - decay_pow: a float, decay if decay_method == 'pow'. - beta1: a float value between 0 and 1 for momentum. - clip_threshold: an optional float >= 1 - factored: a boolean, whether or not to use factored second order momentum. - epsilon1_grad_sq_reg: Regularization constant for squared gradient. - quantized_dtype: type of the quantized input. Allowed options are jnp.int8, - jnp.int16, jnp.bfloat16 and jnp.float32. If floating-point type is - specified, accumulators are stored as such type, instead of quantized - integers. - respect_skip_lp_regularization: whether or not to respect lingvo - SKIP_LP_REGULARIZATION var collection that skips decoupled weight decay. - exclude_from_layerwise_adaptation: A dictionary with key as regex scope - pattern for variables to be skipped. - per_var_learning_summary: a bool, whether or not to export per-var learning - summaries. - sort_factored_second_moment_dims: a bool, whether to select dims to factor - by size, for the factored second moment. - min_dim_size_to_factor: an integer, only factor the statistics if two array - dimensions have at least this size. NOTE min_dim_size_to_factor is only - used when sort_factored_second_moment_dims=True. - multiply_by_parameter_scale: a boolean, if True, then scale learning_rate by - parameter scale. if False provided learning_rate is absolute step size. - NOTE False by default. - epsilon2_param_scale_reg: Regularization constant for parameter scale. Only - used when multiply_by_parameter_scale is True. - maybe_inf_to_nan: Will use jax.nan_to_num during update when True. - nesterov: Will use Nesterov momentum when True. - - Returns: - A `ShardedGradientTransformation`. - """ - - assert not respect_skip_lp_regularization - assert decay_adam >= 0 - assert decay_pow >= 0 - assert learning_rate is not None - assert decay_method == 'adam' or decay_method == 'pow', ( - f'decay_method: {decay_method} not supported. Supported methods are ' - '"pow", or "adam".') - - sharded_adafactor_helper = _ShardedAdafactorHelper( - learning_rate=learning_rate, - weight_decay=weight_decay, - layerwise_adaptation=layerwise_adaptation, - decay_method=decay_method, - decay_adam=decay_adam, - decay_pow=decay_pow, - beta1=beta1, - clip_threshold=clip_threshold, - factored=factored, - epsilon1_grad_sq_reg=epsilon1_grad_sq_reg, - quantized_dtype=quantized_dtype, - respect_skip_lp_regularization=respect_skip_lp_regularization, - exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation, - per_var_learning_summary=per_var_learning_summary, - sort_factored_second_moment_dims=sort_factored_second_moment_dims, - min_dim_size_to_factor=min_dim_size_to_factor, - multiply_by_parameter_scale=multiply_by_parameter_scale, - epsilon2_param_scale_reg=epsilon2_param_scale_reg, - maybe_inf_to_nan=maybe_inf_to_nan, - nesterov=nesterov) - - def init_fn(params): - """Initializes the optimizer's state.""" - return sharded_adafactor_helper.to_state( - jnp.zeros([], jnp.int32), - jax.tree.map(sharded_adafactor_helper.init, params)) - - def update_fn(updates, state, params=None): - if params is None: - raise ValueError( - 'You are using a transformation that requires the current value of ' - 'parameters, but you are not passing `params` when calling `update`.') - - compute_var_and_slot_update_fn = functools.partial( - sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree.map(compute_var_and_slot_update_fn, - updates, - state.m, - state.m_scale, - state.vr, - state.vc, - state.v, - params) - updates = jax.tree.map(lambda o: o.update, output) - count_plus_one = state.count + jnp.array(1, jnp.int32) - updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) - return updates, updated_states - - return optax.GradientTransformation(init=init_fn, update=update_fn) diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py deleted file mode 100644 index 1833ab8af..000000000 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" - -import functools -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algoperf import spec -from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \ - sharded_adafactor - -_GRAD_CLIP_EPS = 1e-6 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adafactor optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = sharded_adafactor( - learning_rate=lr_schedule_fn, - beta1=1.0 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree.map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/__init__.py b/reference_algorithms/paper_baselines/adafactor/pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py deleted file mode 100644 index 7aa457a25..000000000 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ /dev/null @@ -1,336 +0,0 @@ -"""Submission file for Adafactor in PyTorch.""" - -from functools import partial -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from absl import logging -import torch -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algoperf import spec -from algoperf.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adafactor optimizer and a learning rate schedule.""" - del model_state - del rng - - # Create optimizer. - optimizer_state = { - 'optimizer': - Adafactor( - model_params.parameters(), - lr=hyperparameters.learning_rate, - beta1=1 - hyperparameters.one_minus_beta1, - weight_decay=hyperparameters.weight_decay), - } - optimizer = optimizer_state['optimizer'] - warmup = LinearLR( - optimizer, - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) - cosine_steps = max(workload.step_hint - hyperparameters.warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - optimizer_state['scheduler'] = SequentialLR( - optimizer, - schedulers=[warmup, cosine_decay], - milestones=[hyperparameters.warmup_steps]) - return optimizer_state - - -class Adafactor(torch.optim.Optimizer): - """Adapted from https://github.com/huggingface/transformers/blob/main/ - src/transformers/optimization.py#L386""" - - def __init__( - self, - params, - lr=None, - beta1=0.9, - decay_adam=0.99, - weight_decay=0.0, - ): - defaults = dict( - lr=lr, - beta1=beta1, - decay_adam=decay_adam, - weight_decay=weight_decay, - decay_pow=0.0, - layerwise_adaptation=False, - decay_method='adam', - clip_threshold=1.0, - factored=True, - epsilon1_grad_sq_reg=1e-30, - respect_skip_lp_regularization=False, - exclude_from_layerwise_adaptation=None, - per_var_learning_summary=False, - sort_factored_second_moment_dims=False, - # Unused because sort_factored_second_moment_dims=False. - min_dim_size_to_factor=128, - multiply_by_parameter_scale=False, - # Unused because multiply_by_parameter_scale=False. - epsilon2_param_scale_reg=1e-3, - maybe_inf_to_nan=True, - ) - super().__init__(params, defaults) - - def inf_to_nan(self, group, x): - if group["maybe_inf_to_nan"]: - x = torch.nan_to_num(x, nan=torch.nan, posinf=torch.nan, neginf=torch.nan) - return x - - def step(self, closure=None): - """ - Performs a single optimization step - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - inf_to_nan = partial(self.inf_to_nan, group) - for p in group["params"]: - if p.grad is None: - continue - grad = p.grad.data - grad = inf_to_nan(grad) - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() - if grad.is_sparse: - raise RuntimeError("Adafactor does not support sparse gradients.") - - state = self.state[p] - grad_shape = grad.shape - - factored = len(grad_shape) >= 2 - - # State Initialization - if len(state) == 0: - state["step"] = 0 - state["exp_avg"] = torch.zeros_like(grad) - if factored: - state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) - state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + - grad_shape[-1:]).to(grad) - else: - state["exp_avg_sq"] = torch.zeros_like(grad) - else: - state["exp_avg"] = state["exp_avg"].to(grad) - if factored: - state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) - state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) - else: - state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) - - p_data_fp32 = p.data - if p.data.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - - state["step"] += 1 - lr = group["lr"] - beta1 = group["beta1"] - beta2 = group["decay_adam"] - - t = state["step"] - beta2t = beta2 * (1. - beta2**(t - 1.)) / (1. - beta2**t) - - exp_avg_sq_update = (grad**2) + group["epsilon1_grad_sq_reg"] - if factored: - exp_avg_sq_row = state["exp_avg_sq_row"] - exp_avg_sq_col = state["exp_avg_sq_col"] - - exp_avg_sq_row.mul_(beta2t).add_( - exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t) - exp_avg_sq_col.mul_(beta2t).add_( - exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t) - - r_factor = inf_to_nan( - exp_avg_sq_row / - exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1) - c_factor = inf_to_nan(exp_avg_sq_col).unsqueeze(-2) - denom = r_factor * c_factor - else: - exp_avg_sq = state["exp_avg_sq"] - - exp_avg_sq.mul_(beta2t).add_(exp_avg_sq_update, alpha=1.0 - beta2t) - denom = exp_avg_sq - - denom = denom.sqrt() - update = grad / denom - # Clip the update based on RMS. - clipping_denom = inf_to_nan(torch.square(update).mean().sqrt() \ - /group["clip_threshold"]).clamp(min=1.0) - update = update / clipping_denom * lr - # Momentum - exp_avg = state["exp_avg"] - exp_avg.mul_(beta1).add_(update, alpha=1 - beta1) - - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr) - - p_data_fp32.add_(-exp_avg) - - if p.data.dtype in {torch.float16, torch.bfloat16}: - p.data.copy_(p_data_fp32) - - return loss - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json deleted file mode 100644 index 5543689ea..000000000 --- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 1e-2, "max": 0.45, "scaling": "log" - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } -} diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json deleted file mode 100644 index 98a506084..000000000 --- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } -} diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index dde41fa6d..bd73cdf0c 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -7,8 +7,11 @@ import jax from jax import lax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -50,24 +53,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -77,7 +74,7 @@ def _loss_fn(params): model_state, spec.ForwardPassMode.TRAIN, rng, - update_batch_norm=True) + update_batch_norm=True,) loss_dict = workload.loss_fn( label_batch=batch['targets'], logits_batch=logits, @@ -90,9 +87,8 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + + # Compute local loss and gradients loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -105,7 +101,7 @@ def _loss_fn(params): grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) + current_param_container) updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm @@ -130,7 +126,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -139,23 +134,51 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Set up mesh and sharding + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ), + out_shardings=( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + )) + # print(batch) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -205,6 +228,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 32 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/paper_baselines/lamb/__init__.py b/reference_algorithms/paper_baselines/lamb/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/lamb/jax/__init__.py b/reference_algorithms/paper_baselines/lamb/jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py deleted file mode 100644 index 70e305514..000000000 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" - -import functools -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algoperf import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a LAMB optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = optax.lamb( - learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree.map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/__init__.py b/reference_algorithms/paper_baselines/lamb/pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py deleted file mode 100644 index c1c6cec0a..000000000 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ /dev/null @@ -1,328 +0,0 @@ -"""Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from absl import logging -import torch -from torch import Tensor -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algoperf import spec - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py -class LAMB(torch.optim.Optimizer): - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0.0): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step.""" - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - lamb( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def lamb(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float): - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - update = exp_avg / denom - update.div_(bias_correction1) - update.add_(weight_decay * param) - - # Scale updates by trust ratio. - param_norm = torch.linalg.norm(param) - update_norm = torch.linalg.norm(update) - - # Set trust_ratio to 1 in case where parameters would never be updated. - if param_norm == 0. or update_norm == 0.: - trust_ratio = 1. - else: - trust_ratio = param_norm / update_norm - - param.add_(update, alpha=-lr * trust_ratio) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a LAMB optimizer and a learning rate schedule.""" - del model_state - del rng - - optimizer_state = { - 'optimizer': - LAMB( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - loss, _ = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json deleted file mode 100644 index f2fcde461..000000000 --- a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 5e-2, "max": 0.3, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } -} diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json deleted file mode 100644 index 8934e512d..000000000 --- a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } -} diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cbb6d6dcd..72d8862b3 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,15 +1,13 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" -import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax from algoperf import spec +from algoperf import jax_sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +35,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=False) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,21 +85,15 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -124,9 +116,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Get global mean loss and grad. loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -173,15 +163,48 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = jax_sharding_utils.get_batch_dim_sharding() # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index c451a18ac..98bc054d4 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -1,7 +1,4 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. from typing import (Any, @@ -16,13 +13,13 @@ # isort: on import chex -from flax import jax_utils import jax from jax import lax import jax.numpy as jnp import optax from algoperf import spec +from algoperf import jax_sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -192,24 +189,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -281,16 +272,42 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = jax_sharding_utils.get_batch_dim_sharding() # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( + workload, opt_update_fn, model_state, optimizer_state, + current_param_container, batch, rng, grad_clip, + label_smoothing) # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 0e53aae42..7e09b274f 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,14 +1,14 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" -import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P import optax +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 @@ -37,7 +37,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( @@ -87,21 +87,21 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +# @functools.partial( +# jax.pmap, +# axis_name='batch', +# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), +# static_broadcasted_argnums=(0, 1), +# donate_argnums=(2, 3, 4)) +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -124,9 +124,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # # Get correct global mean loss and grad. loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -164,7 +162,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -173,23 +170,54 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Create shardings for each argument + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = NamedSharding(mesh, P()) # No partitioning + sharded = NamedSharding(mesh, P('batch')) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rngs + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -239,6 +267,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/reference_algorithms/paper_baselines/sam/__init__.py b/reference_algorithms/paper_baselines/sam/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/sam/jax/__init__.py b/reference_algorithms/paper_baselines/sam/jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py deleted file mode 100644 index b76589705..000000000 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" - -import functools -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algoperf import spec - -_GRAD_CLIP_EPS = 1e-6 - - -# Copied from the official SAM GitHub repository. Note how it doesn't add an -# epsilon to the gradient norm before normalizing the gradients. -def dual_vector(y: jnp.ndarray) -> jnp.ndarray: - """Returns the solution of max_x y^T x s.t. - ||x||_2 <= 1. - Args: - y: A pytree of numpy ndarray, vector y in the equation above. - """ - gradient_norm = jnp.sqrt( - sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) - normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) - return normalized_gradient - - -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/ -# sharpness_aware_minimization.py -def sharpness_aware_minimization( - rho: float, - grad_clip: Optional[float], - batch_axis_name: str, - base_opt_init_fn, - base_opt_update_fn, -) -> optax.GradientTransformation: - """Implementation of Sharpness Aware Minimization (SAM). - Paper: https://arxiv.org/abs/2010.01412 - Code: https://github.com/google-research/sam - References: - Foret et al, 2021: https://arxiv.org/abs/2010.01412 - Args: - rho: The size of the neighborhood for the sharpness aware minimization - gradient updates. Defaults to 0.1. - grad_clip: The optional value to clip the updates by. Defaults to None. - batch_axis_name: the name of the axis to pmap over. Used to run a pmean - before applying the optimizer update. - base_opt_init_fn: The initialization function for the base optimizer used to - generate updates given the total gradient. - base_opt_update_fn: The update function for the base optimizer used to - generate updates given the total gradient. - Returns: - The corresponding `GradientTransformation`. - """ - - def init_fn(params): - return base_opt_init_fn(params) - - def update_fn(updates, state, grad_fn_params_tuple): - (grad_fn, params) = grad_fn_params_tuple - - # Updates here have been synced (mean) across devices before being sent to - # the optimizer. We again take the correct mean of the gradients computed on - # the noised parameters in the same order as on the original gradients and - # with the same 1e-6 epsilon that is used when clipping the gradients. - updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) - (_, (n_valid_examples, _)), updates = grad_fn(noised_params) - # Get correct global mean grad. - (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), - axis_name=batch_axis_name) - updates = jax.tree.map(lambda x: x / n_valid_examples, updates) - - if grad_clip: - updates_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) - scaled_updates = jax.tree.map( - lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, - lambda _: updates, - None) - updates, state = base_opt_update_fn(updates, state, params) - - return updates, state - - return optax.GradientTransformation(init_fn, update_fn) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a SAM optimizer (with AdamW base) and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create base optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - - # Create SAM update fn. - grad_clip = ( - hyperparameters.grad_clip - if hasattr(hyperparameters, 'grad_clip') else None) - opt_init_fn, opt_update_fn = sharpness_aware_minimization( - rho=hyperparameters.rho, - grad_clip=grad_clip, - batch_axis_name='batch', - base_opt_init_fn=opt_init_fn, - base_opt_update_fn=opt_update_fn) - - # Initialize optimizer state. - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params, update_batch_norm=True): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=update_batch_norm) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - second_grad_fn = jax.value_and_grad( - functools.partial(_loss_fn, update_batch_norm=False), has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree.map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, (second_grad_fn, current_param_container)) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/sam/pytorch/__init__.py b/reference_algorithms/paper_baselines/sam/pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py deleted file mode 100644 index 92603f036..000000000 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ /dev/null @@ -1,287 +0,0 @@ -"""Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" - -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple - -from absl import logging -import torch -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algoperf import spec -from algoperf.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -# Modified from https://github.com/davda54/sam. -class SAM(torch.optim.Optimizer): - - def __init__(self, - params: spec.ParameterContainer, - base_optimizer: torch.optim.Optimizer, - rho: float = 0.05, - adaptive: bool = False, - **kwargs): - if rho < 0.0: - raise ValueError(f'Invalid rho, should be non-negative: {rho}') - - defaults = dict(rho=rho, adaptive=adaptive, **kwargs) - super().__init__(params, defaults) - - self.base_optimizer = base_optimizer(self.param_groups, **kwargs) - self.param_groups = self.base_optimizer.param_groups - self.defaults.update(self.base_optimizer.defaults) - - @torch.no_grad() - def first_step(self, zero_grad: bool = False): - grad_norm = self._grad_norm() - for group in self.param_groups: - scale = group['rho'] / grad_norm - - for p in group['params']: - if p.grad is None: - continue - self.state[p]['old_p'] = p.data.clone() - factor = torch.pow(p, 2) if group['adaptive'] else 1.0 - e_w = factor * p.grad * scale.to(p) - p.add_(e_w) # Climb to the local maximum 'w + e(w)'. - - if zero_grad: - self.zero_grad() - - @torch.no_grad() - def second_step(self, zero_grad: bool = False): - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - p.data = self.state[p]['old_p'] # Get back to 'w' from 'w + e(w)'. - - self.base_optimizer.step() # Do the actual 'sharpness-aware' update. - - if zero_grad: - self.zero_grad() - - @torch.no_grad() - def step(self, closure: Callable = None): - if closure is None: - raise ValueError('SAM requires closure, but it was not provided.') - # The closure should do a full forward-backward pass. - closure = torch.enable_grad()(closure) - - self.first_step(zero_grad=True) - closure() - self.second_step() - - def _grad_norm(self): - # In case of model parallelism, put everything on the same device. - shared_device = self.param_groups[0]['params'][0].device - norm = torch.norm( - torch.stack([((torch.abs(p) if group['adaptive'] else 1.0) * - p.grad).norm(p=2).to(shared_device) - for group in self.param_groups - for p in group['params'] - if p.grad is not None]), - p=2) - return norm - - def load_state_dict(self, state_dict: Dict): - super().load_state_dict(state_dict) - self.base_optimizer.param_groups = self.param_groups - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an AdamW optimizer and a learning rate schedule.""" - del model_state - del rng - - # Create SAM optimizer with AdamW base. - base_optimizer = torch.optim.AdamW - optimizer_state = { - 'optimizer': - SAM(model_params.parameters(), - base_optimizer=base_optimizer, - rho=hyperparameters.rho, - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - # Create learning rate schedule. - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - current_model = current_param_container - current_model.train() - - def _loss_fn(params, update_batch_norm=True): - """Loss function used for training.""" - logits_batch, new_model_state = workload.model_fn( - params=params, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=update_batch_norm) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - return loss, new_model_state - - # First backward pass. - loss, _ = _loss_fn(current_model, update_batch_norm=True) - loss.backward() - - logging_loss = loss.clone().detach() - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - - optimizer_state['optimizer'].first_step(zero_grad=True) - - # Second forward-backward pass. - loss, new_model_state = _loss_fn(current_model, update_batch_norm=False) - loss.backward() - - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].second_step(zero_grad=True) - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': logging_loss.item(), - 'grad_norm': grad_norm.item(), - }, - global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - logging_loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection( - workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space.json b/reference_algorithms/paper_baselines/sam/tuning_search_space.json deleted file mode 100644 index 66dae232b..000000000 --- a/reference_algorithms/paper_baselines/sam/tuning_search_space.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 5e-2, "max": 0.43, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-2, "max": 0.2, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - }, - "rho": { - "feasible_points": [0.01, 0.02, 0.05] - } -} diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json deleted file mode 100644 index 89c480e7a..000000000 --- a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 1e-2, "max": 0.2, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - }, - "rho": { - "feasible_points": [0.01, 0.02, 0.05] - } -} diff --git a/reference_algorithms/paper_baselines/shampoo/__init__.py b/reference_algorithms/paper_baselines/shampoo/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/shampoo/jax/__init__.py b/reference_algorithms/paper_baselines/shampoo/jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py deleted file mode 100644 index a5c2732ac..000000000 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ /dev/null @@ -1,2482 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Google Research Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# An implementation of distributed Shampoo optimizer from: -# -# Scalable Second Order Optimization for Deep Learning -# Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer -# Preprint Paper: https://arxiv.org/abs/2002.09018 -# -# This implementation moves computation of inverse pth root back to the -# accelerator (if higher precision is available). -# -# Authors: Rohan Anil (rohananil at google dot com) -# Vineet Gupta (vineet at google dot com) -# James Lottes (jlottes at google dot com) -# Anudhyan Boral (anudhyan at google dot com) -# -# Forked with minor modifications from: -# github.com/google-research/google-research/blob/master/scalable_shampoo/ (...) -# optax/distributed_shampoo.py -"""Distributed Shampoo Implementation.""" - -import enum -import functools -import itertools -import logging -from typing import Any, cast, List, NamedTuple, Optional, TypeVar, Union - -import chex -from flax import struct -import jax -from jax import lax -from jax.experimental import pjit -from jax.experimental.sparse import linalg -import jax.numpy as jnp -import numpy as np -import optax - -# Dtype for inverse-pth root routine -# Switch to f64 if you have hardware that supports it. Enable the jax flag -# jax_enable_x64 for this to work, otherwise it will default to float32. -_MAT_INV_PTH_ROOT_DTYPE = jnp.float64 # pylint: disable=invalid-name - -# Small epsilon to avoid divide by zero. -_EPSILON = 1e-25 - - -# pylint:disable=no-value-for-parameter -@struct.dataclass -class QuantizedValue: - """State associated with quantized value.""" - quantized: chex.Array - diagonal: chex.Array # Diagonal (if extract_diagonal is set) - bucket_size: chex.Array - quantized_dtype: jnp.dtype = struct.field( - pytree_node=False) # Dtype for the quantized value. - extract_diagonal: bool = struct.field( - pytree_node=False) # In case its centered. - shape: Any = struct.field(pytree_node=False) # Shape of the tensor. - - @classmethod - def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False): - if isinstance(fvalue, list) and not fvalue: - return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, []) - quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize( - fvalue, quantized_dtype, extract_diagonal) - return QuantizedValue(quantized, - diagonal_fvalue, - bucket_size, - quantized_dtype, - extract_diagonal, - list(quantized.shape)) - - # Quantization is from Lingvo JAX optimizers. - # We extend it for int16 quantization of PSD matrices. - @classmethod - def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): - """Returns quantized value and the bucket.""" - if quantized_dtype == jnp.float32: - return fvalue, [], [] - elif quantized_dtype == jnp.bfloat16: - return fvalue.astype(jnp.bfloat16), [], [] - - float_dtype = fvalue.dtype - if quantized_dtype == jnp.int8: - # value -128 is not used. - num_buckets = jnp.array(127.0, dtype=float_dtype) - elif quantized_dtype == jnp.int16: - # value -32768 is not used. - num_buckets = jnp.array(32767.0, dtype=float_dtype) - else: - raise ValueError(f'Quantized dtype {quantized_dtype} not supported.') - # max value is mapped to num_buckets - - if extract_diagonal and fvalue.ndim != 2: - raise ValueError( - f'Input array {fvalue} must be 2D to work with extract_diagonal.') - - diagonal_fvalue = [] - if extract_diagonal: - diagonal_fvalue = jnp.diag(fvalue) - # Remove the diagonal entries. - fvalue = fvalue - jnp.diag(diagonal_fvalue) - - # TODO(rohananil): Extend this by making use of information about the blocks - # SM3 style which will be useful for diagonal statistics - # We first decide the scale. - if fvalue.ndim < 1: - raise ValueError( - f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') - - max_abs = jnp.max(jnp.abs(fvalue), axis=0) - bucket_size = max_abs / num_buckets - bs_expanded = bucket_size[jnp.newaxis, Ellipsis] - # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, - bs_expanded, - jnp.ones_like(bs_expanded)) - ratio = fvalue / bs_nonzero - # We use rounding to remove bias. - quantized = jnp.round(ratio) - return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size - - def to_float(self): - """Returns the float value.""" - if isinstance(self.quantized, list) and not self.quantized: - return self.quantized - - if self.quantized_dtype == jnp.float32: - return self.quantized - - if self.quantized_dtype == jnp.bfloat16: - return self.quantized.astype(jnp.float32) - - float_dtype = self.bucket_size.dtype - bucket_size = self.bucket_size[jnp.newaxis, Ellipsis] - val = self.quantized.astype(float_dtype) * bucket_size - if self.extract_diagonal: - val += jnp.diag(self.diagonal) - return val - - -def _default_zero_field(): - return struct.field( - default_factory=functools.partial(jnp.array, 0, jnp.float32)) - - -T = TypeVar("T") - - -def _maybe_ix(ls, ix): - """Return ls[ix] if not None else None.""" - if ls is None: - return None - return ls[ix] - - -def _maybe(f): - """Lifts f to Maybe monad; ie return None if first arg is.""" - - def wrap_f(x, *args, **kwargs): - if x is None: - return None - return f(x, *args, **kwargs) - - return wrap_f - - -InversePthRootDiagnosticsSubtype = TypeVar( - "InversePthRootDiagnosticsSubtype", bound="InversePthRootDiagnostics") - - -@struct.dataclass -class InversePthRootDiagnostics: - """Diagnostics for inverse p-th root iterative procedure. - - Given an inverse pth root B = A^(-1/p), contains the average and - maximum diagonal and off diagonal absolute entrywise errors between - (B^p A) and I. - """ - max_diag_error: chex.Array = _default_zero_field() - avg_diag_error: chex.Array = _default_zero_field() - max_off_diag_error: chex.Array = _default_zero_field() - avg_off_diag_error: chex.Array = _default_zero_field() - p: chex.Array = _default_zero_field() - - @classmethod - def create(cls, pth_inverse_root, matrix, p): - """Generates a diagnostics struct from (-1/p) root result.""" - mat_m = jnp.matmul( - mat_power(pth_inverse_root, p), - matrix, - precision=jax.lax.Precision.HIGHEST) - num_off_diag_entries = mat_m.size - jnp.diag(mat_m).size - diag_error = jnp.abs(jnp.diag(mat_m) - 1).astype(jnp.float32) - off_diag_error = jnp.abs(mat_m - jnp.diag(jnp.diag(mat_m))).astype( - jnp.float32) - return cls( - max_diag_error=jnp.max(diag_error).astype(jnp.float32), - avg_diag_error=jnp.mean(diag_error).astype(jnp.float32), - max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32), - avg_off_diag_error=(jnp.sum(off_diag_error) / - num_off_diag_entries).astype(jnp.float32), - p=jnp.array(p, jnp.float32)) - - -LOBPCGDiagnosticsSubtype = TypeVar( - "LOBPCGDiagnosticsSubtype", bound="LOBPCGDiagnostics") - - -@struct.dataclass -class LOBPCGDiagnostics: - """Diagnostics for iterative LOBPCG eigenvalue routine. - - Contains consistency error for LOBPCG eigenvalue routine, which - refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair - (v, lambda). This metics dataclass retains consistency error - and other useful LOBPCG values. - """ - lobpcg_iters: chex.Array = _default_zero_field() - max_consistency_error: chex.Array = _default_zero_field() - avg_consistency_error: chex.Array = _default_zero_field() - # Average of absolute value of off-diagonal of V^T V for eigenvalues V. - avg_orthogonality_error: chex.Array = _default_zero_field() - max_eigenvalue: chex.Array = _default_zero_field() - min_eigenvalue: chex.Array = _default_zero_field() - num_topk_eigenvectors: chex.Array = _default_zero_field() - - @classmethod - def create(cls, matrix, eigvals, eigvecs, lobpcg_iters): - """Generates LOBPCG diagnostics from the result of the routine.""" - num_topk = len(eigvals) - num_off_diag = num_topk * (num_topk - 1) - precision = jax.lax.Precision.HIGHEST - - mat_eigvecs = matrix.dot(eigvecs, precision=precision) - consistency_error_unnormalized = jnp.linalg.norm( - mat_eigvecs - eigvals * eigvecs, ord=2, axis=0) - normalization = jnp.linalg.norm(mat_eigvecs, ord=2, axis=0) + eigvals - consistency_error = consistency_error_unnormalized / normalization - - orthogonality_error = eigvecs.T.dot(eigvecs, precision=precision) - orthogonality_error -= jnp.diag(jnp.diag(orthogonality_error)) - - return cls( - lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32), - max_consistency_error=jnp.max(consistency_error).astype(jnp.float32), - avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32), - avg_orthogonality_error=(jnp.sum(orthogonality_error) / - num_off_diag).astype(jnp.float32), - max_eigenvalue=jnp.max(eigvals).astype(jnp.float32), - min_eigenvalue=jnp.min(eigvals).astype(jnp.float32), - num_topk_eigenvectors=jnp.array(num_topk, jnp.float32), - ) - - -@struct.dataclass -class TrainingMetrics: - """Diagnostic metrics from training.""" - # Error for inverse-pth roots. - inverse_pth_root_errors: chex.Array = _default_zero_field() - # Iteration count for inverse-pth roots. - inverse_pth_root_iters: chex.Array = _default_zero_field() - # If final iteration error increases sufficiently, iteration terminates early. - # This field records the ratio of the final iteration error. - final_error_ratio: chex.Array = _default_zero_field() - # Max eigen value from either the power iteration or from LOBPCG. - max_eigen_value: chex.Array = _default_zero_field() - # Total retries of inverse pth root iterative method. - total_retries: chex.Array = _default_zero_field() - - lobpcg_diagnostics: LOBPCGDiagnostics = struct.field( - default_factory=LOBPCGDiagnostics) - # Rich matrix entrywise error diagnostics, if enabled. - inverse_pth_root_diagnostics: InversePthRootDiagnostics = struct.field( - default_factory=InversePthRootDiagnostics) - # Diagnostics applied to the conditioned p-th root problem, after top - # eigenvectors are removed, if LOBPCG is being applied. - conditioned_inverse_pth_root_diagnostics: InversePthRootDiagnostics = ( - struct.field(default_factory=InversePthRootDiagnostics)) - # TODO(rohananil): Add more important metrics to track during training. - - -# Per parameter optimizer state used in data-parallel training. -class ParameterStats(NamedTuple): - """State associated to each parameter of the model being trained.""" - diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner - statistics: Optional[List[Any]] # Statistics (QuantizedValue, chex.Array) - preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array) - diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner - momentum: QuantizedValue # Momentum for the shampoo preconditioner - training_metrics: Union[TrainingMetrics, optax.MaskedNode] # Optional. - - -# For training extremely large model; We keep a global state with a concatenated -# statistics and preconditioner states for all vars. This is so that we can -# annotate the leading axis to be sharded to save memory at the cost of -# communication. -@struct.dataclass -class GlobalShardedParameterStats: - statistics: chex.Array # Statistics - preconditioners: chex.Array # Preconditioners - exponents: chex.Array # exponents - - -# These are per-parameter local states; All statistics here mirror the parameter -# Thus the sharding is copied over from the param specification. -@struct.dataclass -class LocalShardedParameterStats: - """State associated to each parameter of the model being trained.""" - diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner - diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner - momentum: QuantizedValue # Momentum for the shampoo preconditioner - training_metrics: Union[TrainingMetrics, optax.MaskedNode] - index_start: Union[np.int32, int] = struct.field( - pytree_node=False) # Index into global statistics array - sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics. - - -def default_training_metrics(): - """Create a default TrainingMetrics.""" - return TrainingMetrics() - - -def init_training_metrics( - num_statistics, - generate_training_metrics, -): - """Initialize TrainingMetrics, masked if disabled.""" - if not generate_training_metrics: - return optax.MaskedNode() - return jax.tree.map( - functools.partial(jnp.repeat, repeats=num_statistics), - default_training_metrics()) - - -def init_training_metrics_shapes( - num_statistics, - generate_training_metrics, -): - """Initialize training metrics shape/dtype.""" - seed = init_training_metrics( - num_statistics, - generate_training_metrics, - ) - return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) - - -def init_training_metrics_pspec(generate_training_metrics,): - """Initialize training metrics partition specification.""" - if not generate_training_metrics: - return optax.MaskedNode() - return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), - default_training_metrics()) - - -class ShardedShampooStats(NamedTuple): - """Shampoo state in sharded mode.""" - global_stats: Any - local_stats: Any - - -class ShampooState(NamedTuple): - count: chex.Array - stats: Any - - -class InitFnState(NamedTuple): - init_fn: Any - pspec_fn: Any - shape_and_dtype_fn: Any - - -class GraftingType(enum.IntEnum): - NONE = 0 - SGD = 1 - ADAGRAD = 2 - RMSPROP = 3 - RMSPROP_NORMALIZED = 4 - SQRT_N = 5 - ADAGRAD_NORMALIZED = 6 - - -class PreconditionerType(enum.IntEnum): - # Default, computes preconditioner for each dim - ALL = 1 - # One sided Shampoo, in this cases only on input dim. - # Assumes last dim is always the output dim and everything else input dim. - INPUT = 2 - # One sided Shampoo, in this cases only on output dim. - # Assumes last dim is always the output dim and everything else input dim. - OUTPUT = 3 - - -def power_iteration( - matrix, - num_iters=100, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - padding_start=None, -): - r"""Power iteration algorithm. - - The power iteration algorithm takes a symmetric PSD matrix `A`, and produces - a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue - of `A`, and a vector v, which is the corresponding eigenvector of `A`. - - References: - [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration) - - Args: - matrix: the symmetric PSD matrix. - num_iters: Number of iterations. - error_tolerance: Iterative exit condition. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - padding_start: if set, assumes rows and columns after padding_start are - zero. - - Returns: - eigen vector, eigen value - """ - matrix_size = matrix.shape[-1] - - def _iter_condition(state): - i, unused_v, unused_s, unused_s_v, run_step = state - return jnp.logical_and(i < num_iters, run_step) - - def _iter_body(state): - """One step of power iteration.""" - i, new_v, s, s_v, unused_run_step = state - new_v = new_v / jnp.linalg.norm(new_v) - - s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision) - s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision) - return (i + 1, - s_v, - s_new, - s_v, - jnp.greater(jnp.abs(s_new - s), error_tolerance)) - - # Figure out how to use step as seed for random. - v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0, - matrix_size).astype(matrix.dtype) - v_0 = jnp.array(v_0) - if padding_start is not None: - v_0 *= (jnp.arange(len(v_0), dtype=jnp.int32) < padding_start) - - init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True]) - _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, - init_state) - v_out = v_out / jnp.linalg.norm(v_out) - return v_out, s_out - - -def mat_power( - mat_m, - p, - precision=lax.Precision.HIGHEST, -): - """A simple matrix power method. M^p where p can be TracedValue.""" - power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE) - - def _iter_condition(state): - i, _, _ = state - return i > 0 - - def _iter_body(state): - i, power, mat = state - - power = jax.lax.cond(i % 2 == 1, - lambda: jnp.matmul(mat, power, precision=precision), - lambda: power) - i //= 2 - mat = jnp.matmul(mat, mat, precision=precision) - return i, power, mat - - _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m)) - return result - - -def _pth_root_difference(w, alpha, beta, p): - """Computes (w+alpha)^(-1/p)-(w+beta)^(-1/p).""" - - a = w + alpha - b = w + beta - a_minus_b = alpha - beta - exp = -1 / p - - def _stable_subtract(b, a_minus_b): - # Mathematically identical to the target expression, with (w+beta)^(-1/p) - # term factored out and w cancellation in the subtraction. - return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b)) - - return jnp.where( - # Choose the branch with the best log1p approximation. - jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a), - -_stable_subtract(a, -a_minus_b), - _stable_subtract(b, a_minus_b)) - - -def matrix_inverse_pth_root( - matrix, - p, - num_iters=100, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - relative_matrix_epsilon=True, - lobpcg_topk_precondition=0, - lobpcg_max_iter=0, - padding_start=None, - prev=None, - eigh=False, -): - """Computes `matrix^(-1/p)`, where `p` is a positive integer. - - This function uses the Eigh or Coupled newton iterations algorithm for - the computation of a matrix's inverse pth root. - - - References: - [Functions of Matrices, Theory and Computation, - Nicholas J Higham, Pg 184, Eq 7.18]( - https://epubs.siam.org/doi/book/10.1137/1.9780898717778) - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - num_iters: Maximum number of iterations. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - lobpcg_topk_precondition: If nonzero, specifies the number of top - eigenvectors to subtract out before performing LOBPCG. Note this makes - relative_matrix_epsilon essentially free. - lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to - `lobpcg_topk_precondition`. - padding_start: If the input matrix was padded, then zeros out columns and - rows at the padding start. - prev: previous iteration's solution, zero-padded (unused) - eigh: If True, uses eigh for inverse-pth root computation. - - Returns: - `(matrix + eps)^(-1/p)` and error metrics. - - Note `eps` is not added to zeroed out padding rows and - columns. `eps` is just `ridge_epsilon` if - `relative_matrix_epsilon` is set to `False`, otherwise, it is the - ridge epsilon value scaled by the derived maximum eigenvalue of - the input matrix. - """ - - if eigh: - return matrix_inverse_pth_root_eigh(matrix, - p, - ridge_epsilon, - error_tolerance, - precision, - relative_matrix_epsilon, - padding_start, - prev) - del prev - - assert matrix.shape[0] == matrix.shape[1] - - # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root. - # Switch to f64 if you have hardware that supports it. Enable the jax flag - # jax_enable_x64 for this to work. - matrix_size = matrix.shape[0] - orig_dtype = matrix.dtype - matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE) - alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) - identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) - - if padding_start is not None: - # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) - matrix *= ix[jnp.newaxis, :] - matrix *= ix[:, jnp.newaxis] - identity *= ix - - original_matrix = matrix - - # Only used in lobpcg branches, but required by pytype. - eigvals, eigvecs, lobpcg_diagnostics = None, None, None - if lobpcg_topk_precondition > 0: - # TODO(vladf): reuse previous top-k as the initial search directions - pad_shape = (matrix_size - lobpcg_topk_precondition, - lobpcg_topk_precondition) - search_dirs = jnp.concatenate( - (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0) - eigvals, eigvecs, lobpcg_iters = linalg.lobpcg_standard( # pylint: disable=unbalanced-tuple-unpacking - matrix, search_dirs, - lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter) - lobpcg_diagnostics = LOBPCGDiagnostics.create( - matrix, - eigvals, - eigvecs, - lobpcg_iters, - ) - - # The minimal eigenvalue among top-k becomes the maximal one in the whole - # matrix after deflation. - deflation = eigvals - jnp.min(eigvals) - scaled_vecs = eigvecs * jnp.sqrt(deflation) - - # Deflate out top eigenvectors to reduce matrix condition number. - matrix -= scaled_vecs.dot( - scaled_vecs.T, precision=jax.lax.Precision.HIGHEST) - - if relative_matrix_epsilon: - if eigvals is not None: - max_ev = jnp.max(eigvals) - else: - # Only use power iteration if lobpcg wasn't already used to derive the - # top eigenvalue. - _, max_ev = power_iteration( - matrix=matrix, - num_iters=100, - error_tolerance=1e-6, - precision=precision, - padding_start=padding_start) - else: - # Use absolute matrix epsilon scaling otherwise. - max_ev = 1.0 - - ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, _EPSILON) - - # Sometimes error increases after an iteration before decreasing and - # converging. 1.2 factor is used to bound the maximal allowed increase. - max_error_ratio = 1.2 - - def _iter_condition(state): - i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, error_ratio = state - error_above_threshold = jnp.logical_and(error > error_tolerance, - error_ratio < max_error_ratio) - return jnp.logical_and(i < num_iters, error_above_threshold) - - def _iter_body(state): - (i, mat_m, mat_h, unused_old_mat_h, error, unused_error_ratio) = state - mat_m_i = (1 - alpha) * identity + alpha * mat_m - new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision) - new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision) - new_error = jnp.max(jnp.abs(new_mat_m - identity)) - return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error / error) - - if matrix_size == 1: - damped_matrix = matrix + ridge_epsilon - resultant_mat_h = damped_matrix**alpha - error = jnp.array(0, jnp.float32) - iters = 0 - error_ratio = 0.0 - else: - - retry_loop_error_threshold = 0.05 - num_tries = 6 - init_outer_state = tuple([0, identity, 1000.0, 100, 1.0, True]) - - def _outer_iter_condition_fn(state): - i, _, _, _, _, iter_failed = state - return jnp.logical_and(iter_failed, i < num_tries) - - def _outer_body_fn(state): - i, _, _, _, _, _ = state - # Update the epsilon based on the loop iteration. - damped_matrix = matrix + (ridge_epsilon * (10**i) * identity) - z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix)) - new_mat_m_0 = damped_matrix * z - new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) - new_mat_h_0 = identity * jnp.power(z, 1.0 / p) - init_state = tuple( - [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0]) - iters, mat_m, mat_h, old_mat_h, error, error_ratio = lax.while_loop( - _iter_condition, _iter_body, init_state) - error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32) - is_converged = jnp.asarray(error_ratio < max_error_ratio, old_mat_h.dtype) - resultant_mat_h = is_converged * \ - mat_h + (1 - is_converged) * old_mat_h - return (i + 1, - resultant_mat_h, - error, - iters, - error_ratio, - error > retry_loop_error_threshold) - - loop_outputs = jax.lax.while_loop(_outer_iter_condition_fn, - _outer_body_fn, - init_outer_state) - total_retries, resultant_mat_h, error, iters, error_ratio, _ = loop_outputs - - conditioned_resultant_mat = resultant_mat_h - - if lobpcg_topk_precondition > 0: - # Since we deflated the top eigenvectors prior to p-th root inverse, - # the resultant matrix has larger eigenvalues associated with those - # same eigenvectors, which we need to now re-deflate. - # - # Note that _pth_root_difference returns positive values for this - # particular argument ordering as min(eigvals) <= eigvals for the - # jnp.sqrt below. - pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p) - scaled_vecs = eigvecs * jnp.sqrt(pth_diff) - resultant_mat_h = conditioned_resultant_mat - scaled_vecs.dot( - scaled_vecs.T, precision=jax.lax.Precision.HIGHEST) - - error_metrics = TrainingMetrics( - inverse_pth_root_errors=jnp.array(error, jnp.float32), - inverse_pth_root_iters=jnp.array(iters, jnp.float32), - final_error_ratio=jnp.array(error_ratio, jnp.float32), - max_eigen_value=jnp.array(max_ev, jnp.float32), - total_retries=jnp.array(total_retries, jnp.float32)) - - if lobpcg_topk_precondition > 0: - damped_matrix = matrix + \ - (ridge_epsilon * (10**total_retries) * identity) - conditioned_diagnostics = InversePthRootDiagnostics.create( - conditioned_resultant_mat, damped_matrix, p) - unconditioned_damped_matrix = original_matrix + ridge_epsilon * identity - unconditioned_diagnostics = InversePthRootDiagnostics.create( - resultant_mat_h, unconditioned_damped_matrix, p) - # The max entrywise error in error_metrics.inverse_pth_root_errors refers - # to what was derived from the inverse pth root iteration, which with - # LOBPCG refers to the conditioned problem. Make sure to use the error - # from the unconditioned problem. - unconditional_errors = jnp.maximum( - unconditioned_diagnostics.max_diag_error, - unconditioned_diagnostics.max_off_diag_error) - error_metrics = error_metrics.replace( - inverse_pth_root_errors=unconditional_errors, - lobpcg_diagnostics=lobpcg_diagnostics, - conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics, - inverse_pth_root_diagnostics=unconditioned_diagnostics, - ) - - if padding_start is not None: - # Occasionally, pure-padding matrices are handed to the inversion routine - # due to some TPU hosts not having the same number of preconditioning - # matrices. - resultant_mat_h = jnp.where(padding_start == 0, 0.0, resultant_mat_h) - error = jnp.where(padding_start == 0, - 0.0, - error_metrics.inverse_pth_root_errors) - error_metrics = error_metrics.replace(inverse_pth_root_errors=error) - - resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype) - return resultant_mat_h, error_metrics - - -def matrix_inverse_pth_root_eigh( - matrix, - p, - ridge_epsilon=1e-6, - error_tolerance=1e-6, - precision=lax.Precision.HIGHEST, - relative_matrix_epsilon=True, - padding_start=None, - prev=None, -): - """Computes `matrix^(-1/p)`, where `p` is a positive integer. - - This function uses eigh for the computation of a matrix's inverse pth - root. - - Args: - matrix: the symmetric PSD matrix whose power it to be computed - p: exponent, for p a positive integer. - ridge_epsilon: Ridge epsilon added to make the matrix positive definite. - error_tolerance: Error indicator, useful for early termination. - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - padding_start: If the input matrix was padded, then zeros out columns and - rows at the padding start. - prev: previous iteration's solution, zero-padded (unused) - - Returns: - `(matrix + eps)^(-1/p)` and error metrics. - - Note `eps` is not added to zeroed out padding rows and - columns. `eps` is just `ridge_epsilon` if - `relative_matrix_epsilon` is set to `False`, otherwise, it is the - ridge epsilon value scaled by the derived maximum eigenvalue of - the input matrix. - """ - del prev - assert matrix.shape[0] == matrix.shape[1] - matrix_size = matrix.shape[0] - orig_dtype = matrix.dtype - matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE) - alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) - identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) - if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) - matrix *= ix[jnp.newaxis, :] - matrix *= ix[:, jnp.newaxis] - identity *= ix - if relative_matrix_epsilon: - _, max_ev = power_iteration( - matrix=matrix, - num_iters=100, - error_tolerance=error_tolerance, - precision=precision, - padding_start=padding_start) - else: - # Use absolute matrix epsilon scaling otherwise. - max_ev = 1.0 - ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, error_tolerance) - regularized_input = matrix + ridge_epsilon * identity - e, u = jnp.linalg.eigh(regularized_input) - # Due to padding, we may have to zero out eigenvalues. - if padding_start is not None: - e *= jnp.flip(ix) - mm = functools.partial(jnp.matmul, precision=precision) - inv_e = jnp.where(e == 0.0, - 0.0, - jnp.power(jnp.maximum(e, ridge_epsilon), alpha)) - val = mm(mm(u, jnp.diag(inv_e)), u.T) - root = u * jnp.sqrt(inv_e) - val = mm(root, root.T) - recovered_e = mm(u.T, mm(regularized_input, u)) - eig_error = recovered_e - jnp.diag(e) - if padding_start is not None: - eig_error *= jnp.flip(ix) - error = jnp.max(jnp.abs(eig_error)) - error_metrics = TrainingMetrics( - inverse_pth_root_errors=jnp.array(error, jnp.float32)) - if padding_start is not None: - val = jnp.where(padding_start == 0, 0.0, val) - error = jnp.where(padding_start == 0, - 0.0, - error_metrics.inverse_pth_root_errors) - error_metrics = error_metrics.replace(inverse_pth_root_errors=error) - val = jnp.asarray(val, orig_dtype) - return val, error_metrics - - -def merge_small_dims(shape_to_merge, max_dim): - """Merge small dimensions. - - If there are some small dimensions, we collapse them: - e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024 - [1, 2, 768, 1, 2048] --> [2, 768, 2048] - - Args: - shape_to_merge: Shape to merge small dimensions. - max_dim: Maximal dimension of output shape used in merging. - - Returns: - Merged shape. - """ - if shape_to_merge and np.all(np.array(shape_to_merge) == 1): - return [1] - - resulting_shape = [] - product = 1 - for d in shape_to_merge: - if product * d <= max_dim: - product *= d - else: - if product > 1: - resulting_shape.append(product) - product = d - if product > 1: - resulting_shape.append(product) - return resulting_shape - - -def pad_square_matrix(mat, max_size): - """Pad a square matrix up to max_size. - - Args: - mat: a matrix to pad. - max_size: matrix size requested. - - Returns: - Given M returns [[M, 0], [0, I]] - """ - rows, cols = mat.shape - if rows != cols: - raise ValueError("Must have rows == cols, instead got " - f"rows={rows}, cols={cols}") - if cols > max_size: - raise ValueError("Must have cols <= max_size. Instead got " - f"cols={cols}, max_size={max_size}.") - if rows == max_size: - return mat - pad_size = max_size - rows - - zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype) - zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype) - eye = jnp.eye(pad_size, dtype=mat.dtype) - mat = jnp.concatenate([mat, zs1], 1) - mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0) - return mat - - -def pad_vector(vec, max_size): - """Pad a vector to a max_size. - - Args: - vec: a vector to pad. - max_size: matrix size requested. - - Returns: - Given V returns [V, 0] - """ - size = vec.shape[0] - assert size <= max_size - if size == max_size: - return vec - pad_size = max_size - size - zs1 = jnp.zeros([pad_size], dtype=vec.dtype) - return jnp.concatenate([vec, zs1], 0) - - -def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs): - """Avoids wasteful buffer allocation with XLA.""" - - def _iter_body(unused_state): - results = compute_fn(*args, **kwargs) - return tuple([False] + list(results)) - - def _iter_condition(state): - return state[0] - - results = jax.lax.while_loop(_iter_condition, - _iter_body, - tuple([predicate] + init_state)) - return tuple(results[1:]) - - -class BlockPartitioner: - """Partitions a tensor into smaller tensors.""" - - def __init__(self, param, block_size): - self._shape = param.shape - self._splits = [] - split_sizes = [] - # We split params into smaller blocks. Here we store the metadata to make - # that split. - for i, d in enumerate(param.shape): - if 0 < block_size < d: - # d-1, otherwise split appends a 0-size array. - nsplit = (d - 1) // block_size - indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size - sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size - sizes[-1] = d - indices[-1] - self._splits.append((i, indices)) - split_sizes.append(sizes) - else: - split_sizes.append(np.array([d], dtype=np.int32)) - self._split_sizes = split_sizes - - def split_sizes(self): - return self._split_sizes - - def partition(self, tensor): - """Partition tensor into blocks.""" - - assert tensor.shape == self._shape - tensors = [tensor] - for (i, indices) in self._splits: - tensors_local = [] - for t in tensors: - tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i)) - tensors = tensors_local - return tensors - - def merge_partitions(self, partitions): - """Merge partitions back to original shape.""" - - for (i, indices) in reversed(self._splits): - n = len(indices) + 1 - partial_merged_tensors = [] - ind = 0 - while ind < len(partitions): - partial_merged_tensors.append( - jnp.concatenate(partitions[ind:ind + n], axis=i)) - ind += n - partitions = partial_merged_tensors - assert len(partitions) == 1 - return partitions[0] - - -def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None): - """Updated statistics via weighted average with new Gram matrix. - - Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose - columns are the flattened slices of the tensor `g` along the given `axis`. - (So, `old_stats` and the returned matrix have dimensions n x n where - n = `g.shape[axis]`). - - Args: - old_stats: Old statistics. - g: Gradient tensor. - axis: Axis along which to slice `g`. - w1: Scalar weight for old statistics. - w2: Scalar weight for new Gram matrix. - precision: Optional precision XLA related flag, the available options are: - a) lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - Weighted average of old and new statistics. - """ - axes = [i for i in range(g.ndim) if i != axis] - gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision) - return w1 * old_stats + w2 * gram_matrix - - -class Preconditioner: - """Compute statistics/shape from gradients for preconditioning.""" - - def __init__( - self, - param, - block_size, - merge_small_dims_block_size, - best_effort_shape_interpretation, - preconditioner_type=PreconditionerType.ALL, - ): - """Initializes the preconditioner. - - Args: - param: parameter to precondition. - block_size: Block size used to split param. - merge_small_dims_block_size: Block size for merging dims. - best_effort_shape_interpretation: Whether to - collapse/merge dims together. - preconditioner_type: Type of preconditioner to use. - """ - self._original_shape = param.shape - self._transformed_shape = param.shape - if best_effort_shape_interpretation: - self._transformed_shape = merge_small_dims(self._original_shape, - merge_small_dims_block_size) - reshaped_param = jnp.reshape(param, self._transformed_shape) - self._partitioner = BlockPartitioner(reshaped_param, block_size) - self._preconditioner_type = preconditioner_type - - def updated_statistics_from_grad( - self, - stats, - grad, - w1, - w2, - to_float=None, - from_float=None, - precision=None, - ): - """Update statistics from gradients. - - Args: - stats: Old statistics or its Cholesky factor if `cholesky` is True. - grad: Gradient to compute statistics from. - w1: Weight for old statistics. - w2: Weight for new statistics. - to_float: Optional function for converting stats to floating point. - from_float: Optional function for converting from floating point. - precision: Optional precision XLA related flag, the available options - are: - a) lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - - Returns: - A list of updated gradient statistics for each partition. - """ - to_float = to_float if to_float is not None else (lambda x: x) - from_float = from_float if from_float is not None else (lambda x: x) - reshaped_grad = jnp.reshape(grad, self._transformed_shape) - partitioned_grads = self._partitioner.partition(reshaped_grad) - should_preconditioned_dims = self.should_precondition_dims() - preconditioned_dims = [ - i for i, p in enumerate(should_preconditioned_dims) if p - ] - new_stats = [] - index = 0 - for g in partitioned_grads: - for axis in preconditioned_dims: - update = functools.partial(gram_weighted_update, precision=precision) - new_stat = update(to_float(stats[index]), g, axis, w1, w2) - new_stats.append(from_float(new_stat)) - index += 1 - return new_stats - - def should_precondition_dims(self): - """A vector containing indicator indicating if the dim is preconditioned.""" - split_sizes = self._partitioner.split_sizes() - rank = len(split_sizes) - if self._preconditioner_type == PreconditionerType.ALL or rank <= 1: - return [True] * rank - elif self._preconditioner_type == PreconditionerType.INPUT: - return [True] * (rank - 1) + [False] - elif self._preconditioner_type == PreconditionerType.OUTPUT: - return [False] * (rank - 1) + [True] - - def _preconditioner_shape(self, dim): - """Returns possibly rank-compressed preconditioner shape.""" - return [dim, dim] - - def _preconds_for_grad(self, preconditioners, rank, start, end): - """Returns a slice of preconditioners of length rank.""" - preconditioners_for_grad = preconditioners[start:end] - if self._preconditioner_type == PreconditionerType.INPUT: - # When _preconditioner_type is INPUT, we append a None value to the end of - # the list to handle the False index. - preconditioners_for_grad = preconditioners_for_grad + [None] - elif self._preconditioner_type == PreconditionerType.OUTPUT: - # When _preconditioner_type is OUTPUT, we append (rank - 1) many None - # values to the beginning of the list to handle the False indices. - preconditioners_for_grad = [None] * \ - (rank - 1) + preconditioners_for_grad - assert len(preconditioners_for_grad) == rank - return preconditioners_for_grad - - def shapes_for_preconditioners(self): - """Returns shape from statistics.""" - split_sizes = self._partitioner.split_sizes() - rank = len(split_sizes) - # We ignore preconditioner types if rank == 1 - preconditioner_shapes = [] - for t in itertools.product(*split_sizes): - if self._preconditioner_type == PreconditionerType.ALL or rank <= 1: - preconditioner_shapes.extend(map(self._preconditioner_shape, t)) - elif self._preconditioner_type == PreconditionerType.INPUT: - preconditioner_shapes.extend(map(self._preconditioner_shape, t[:-1])) - elif self._preconditioner_type == PreconditionerType.OUTPUT: - preconditioner_shapes.extend(map(self._preconditioner_shape, t[-1:])) - return preconditioner_shapes - - def exponent_for_preconditioner(self): - """Returns exponent to use for inverse-pth root M^{-1/p}.""" - should_preconditioned_dims = self.should_precondition_dims() - num_preconditioners = sum(should_preconditioned_dims) - return 2 * num_preconditioners - - def preconditioned_grad(self, grad, preconditioners): - """Precondition the gradient. - - Args: - grad: A gradient tensor to precondition. - preconditioners: A list of preconditioners to apply. - - Returns: - A preconditioned gradient. - """ - reshaped_grad = jnp.reshape(grad, self._transformed_shape) - partitioned_grads = self._partitioner.partition(reshaped_grad) - should_preconditioned_dims = self.should_precondition_dims() - num_preconditioners = sum(should_preconditioned_dims) - preconditioned_partitioned_grads = [] - for i, g in enumerate(partitioned_grads): - preconditioners_for_grad = self._preconds_for_grad( - preconditioners, - rank=len(should_preconditioned_dims), - start=i * num_preconditioners, - end=(i + 1) * num_preconditioners, - ) - precond_g = self._precondition_block(g, - should_preconditioned_dims, - preconditioners_for_grad) - preconditioned_partitioned_grads.append(precond_g) - merged_grad = self._partitioner.merge_partitions( - preconditioned_partitioned_grads) - return jnp.reshape(merged_grad, self._original_shape) - - def _precondition_block(self, g, should_precondition_dim, preconditioners): - """Perform a preconditioning op on a single gradient block.""" - for j, should_precondition in enumerate(should_precondition_dim): - # Loop invariant: the dimension to be preconditioned is first; we keep - # all axes in the same cyclic order they were originally. - # Case: skip preconditioning this dimension. - rank = len(g.shape) - roll = tuple(range(1, rank)) + (0,) - if not should_precondition: - g = jnp.transpose(g, axes=roll) - continue - # Case: full Shampoo matrix precondition this dimension - g = jnp.tensordot(g, preconditioners[j], axes=[[0], [0]]) - return g - - -def _convert_to_parameter_stats(global_stats, - local_stat, - convert_statistics=True): - """Creates parameter stats from sharded stats.""" - index_start = int(local_stat.index_start) - index_end = int(len(local_stat.sizes)) + index_start - statistics = global_stats.statistics[index_start:index_end, :, :] - preconditioners = global_stats.preconditioners[index_start:index_end, :, :] - new_statistics = [] - new_preconditioners = [] - for i, size in enumerate(local_stat.sizes): - new_statistics.append(statistics[i][:size, :size]) - pd = size - new_preconditioners.append(preconditioners[i][:size, :pd]) - if not convert_statistics: - new_statistics = None - return ParameterStats( - local_stat.diagonal_statistics, - new_statistics, - new_preconditioners, - local_stat.diagonal_momentum, - local_stat.momentum, - local_stat.training_metrics, - ) - - -def _convert_from_parameter_stats(parameter_stats, local_stats): - """Creates sharded stats from paramter stats.""" - return LocalShardedParameterStats( - parameter_stats.diagonal_statistics, - parameter_stats.diagonal_momentum, - parameter_stats.momentum, - parameter_stats.training_metrics, - local_stats.index_start, - local_stats.sizes, - ) - - -def _add_metrics_into_local_stats(local_stats, metrics, keep_old): - """Adds errors back into local statistics.""" - new_local_stats = [] - for local_stat in local_stats: - index_start = int(local_stat.index_start) - index_end = int(len(local_stat.sizes)) + index_start - # pylint:disable=cell-var-from-loop Used immediately. - per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics) - # We don't want to update the metrics if we didn't do a new inverse p-th - # root calculation to find a new preconditioner, so that TensorBoard curves - # look consistent (otherwise they'd oscillate between NaN and measured - # values). - per_stat_metrics = efficient_cond(keep_old, - lambda: [local_stat.training_metrics], - [per_stat_metrics])[0] - # pylint:enable=cell-var-from-loop - new_local_stats.append( - local_stat.replace(training_metrics=per_stat_metrics)) - return new_local_stats - - -def batch(x, num_devices): - """Batch `x` so that so that leading axis is num_devices.""" - n = len(x) - b = int(n / num_devices) - return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)]) - - -def unbatch(batched_values): - """Unbatch values across leading axis and return a list of elements.""" - b1, b2 = batched_values.shape[0], batched_values.shape[1] - results = [] - for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0): - v_array = jnp.squeeze(v_array) - # b2 = batches (number of preconditioner computation) per core. - if b2 > 1: - for v in jnp.split(v_array, indices_or_sections=b2, axis=0): - results.append(jnp.squeeze(v)) - else: - results.append(v_array) - return results - - -def distributed_shampoo( - learning_rate, - block_size=1024, - beta1=0.9, - beta2=0.999, - diagonal_epsilon=1e-8, - matrix_epsilon=1e-6, - weight_decay=0.0, - start_preconditioning_step=101, - preconditioning_compute_steps=20, - statistics_compute_steps=1, - best_effort_shape_interpretation=True, - graft_type=GraftingType.RMSPROP_NORMALIZED, - nesterov=True, - exponent_override=0, - # Pass pmap 'batch axis name' in pmap mode. - batch_axis_name=None, - # Only set following 3 params in pjit/spmd mode. - # WARNING: Experimental - statistics_partition_spec=None, - preconditioner_partition_spec=None, - num_devices_for_pjit=None, - shard_optimizer_states=False, - ### - # Experimental memory reduction mode - best_effort_memory_usage_reduction=True, - ### - inverse_failure_threshold=0.1, - moving_average_for_momentum=True, - skip_preconditioning_dim_size_gt=0, - clip_by_scaled_gradient_norm=None, - precision=lax.Precision.HIGHEST, - tensordot_precision=None, - relative_matrix_epsilon=True, - merge_small_dims_block_size=4096, - lobpcg_topk_precondition=0, - lobpcg_max_iter=0, - precondtioner_type=PreconditionerType.ALL, - custom_preconditioner=False, - skip_preconditioning_rank_lt=1, - decoupled_learning_rate=True, - decoupled_weight_decay=False, - generate_training_metrics=True, - reuse_preconditioner=False, - eigh=True, -): - """Distributed Shampoo optimizer. - - Distributed Shampoo is a second-order preconditioned method (concretely, a - variant of full-matrix Adagrad), that provides significant convergence and - wall-clock time improvements compared to conventional first-order methods, - and that has been shown to scale to large state-of-the-art deep learning - models. - - References: - Scalable Second Order Optimization for Deep Learning, - Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer - - Preprint: https://arxiv.org/abs/2002.09018 - - Args: - learning_rate: the step size used to update the parameters. - block_size: Block size for large layers (if > 0). Preconditioning compute - operation is cubic in the dimension of the tensor. Block size allows us - to chunk the layers into sub-layers of maximal dimension dictated by - this value. Use 128 as default (increase if you have compute budget). - beta1: momentum parameter. - beta2: second moment averaging parameter. - diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting - to AdaGrad is enabled). - matrix_epsilon: epsilon to add to statistics before computing inverse pth - root. If you are running in f32 precision for inverse pth root - (recommended today) this can go upto 1e-6. If you have latest hardware - with native f64 precision, set this upto 1e-12. - weight_decay: Weight decay for regularization. - start_preconditioning_step: When to start Shampoo update before which - diagonal update is used. This is because we dont have enough information - to do stable inverse. - preconditioning_compute_steps: How often to compute preconditioner. - Performance tuning params for controlling memory and compute - requirements. - Ideally set this and statistics_compute_steps params to 1. - statistics_compute_steps: How often to compute statistics. - best_effort_shape_interpretation: If there are some small dimensions, - collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if - block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048] - graft_type: Grafting is a technique to fix the layerwise scale of Shampoo - optimizer. This allows us to plugin the Shampoo optimizer into settings - where SGD/AdaGrad is already well tuned. - nesterov: Nesterov momentum. - exponent_override: Override the exponent used in matrix inverse. - batch_axis_name: labeled axis over pmap for data-parallel training the - optimizer used for. - statistics_partition_spec: PartitionSpec to be used in sharded mode. - preconditioner_partition_spec: PartitionSpec to be used in sharded mode. - num_devices_for_pjit: Number of devices to parallelize over when using - pjit. - shard_optimizer_states: Shard optimizer states to save memory in model - parallel training. - best_effort_memory_usage_reduction: Best effort memory usage reduction. - - diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) - -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals - inverse_failure_threshold: numerics are hard and inverses fail sometimes; - we determine that using this threshold. - moving_average_for_momentum: Whether to use moving average for momentum - instead of exponential moving average. - skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is - greater than this value. - clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful - when using RMSProp Grafting). - precision: precision XLA related flag, the available options are: a) - lax.Precision.DEFAULT (better step time, but not precise) b) - lax.Precision.HIGH (increased precision, slower) c) - lax.Precision.HIGHEST (best possible precision, slowest) - tensordot_precision: Optional precision to use for the tensordot operation - when computing statistics (e.g., G Gᵀ). Same options as `precision` - above. - relative_matrix_epsilon: Whether to use relative epsilon to the max eigen - value when computing inverse-pth root. - merge_small_dims_block_size: Used as the maximum block size to merge the - shapes. - lobpcg_topk_precondition: If nonzero, specifies the number of top - eigenvectors to subtract out before performing LOBPCG. Note this makes - relative_matrix_epsilon essentially free. - lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to - `lobpcg_topk_precondition`. - precondtioner_type: Preconditioner type to select all, left only or right - only preconditioners. - skip_preconditioning_rank_lt: Skips preconditioning for parameters with - rank less than this value. - decoupled_learning_rate: If True, use decoupled learning rate, otherwise - couple it with preconditioned gradient computation. (Default True) - decoupled_weight_decay: If True, use decoupled weight decay, otherwise - couple with weight decay. (Default False) - generate_training_metrics: If True, gather training metrics, otherwise - avoid generating them (to reduce memory usage). - reuse_preconditioner: If True, pass the previous derived preconditioner - as a warm start to the next iteratin's inverse pth root computation. - eigh: If True, and uses eigen decomposition for inverse-pth root. - - Returns: - a GradientTransformation. - """ - reset_frequency = None - - def _graft_type_has_diagonal_statistics(): - """Returns True if using diagonal firt order method for grafting.""" - return graft_type not in [ - GraftingType.SGD, GraftingType.SQRT_N, GraftingType.NONE - ] - - def quantized_dtype_for_momentum_buffers(var): - return jnp.int8 if best_effort_memory_usage_reduction and len( - var.shape) > 1 else jnp.float32 - - quantize_second_moment = ( - best_effort_memory_usage_reduction and batch_axis_name) - - # Preconditioner and statistics are both stores as int16 in this mode. - # We take out the diagonal to make quantization easier. - def quantized_dtype_for_second_moment_statistics_buffers(): - return jnp.int16 if quantize_second_moment else jnp.float32 - - # Preconditioner and statistics are both stores as int16 in this mode. - # We take out the diagonal to make quantization easier. - def quantized_dtype_for_second_moment_preconditioner_buffers(): - return jnp.int16 if quantize_second_moment else jnp.float32 - - # _quantized_matrix_inverse_pth_root_vmap implementation assumes - # that preconditioner is quantized if and only if stats is quantized. - qdt_precond = quantized_dtype_for_second_moment_preconditioner_buffers() - qdt_stat = quantized_dtype_for_second_moment_statistics_buffers() - assert qdt_precond == qdt_stat - - def _to_float(maybe_quantized): - if isinstance(maybe_quantized, QuantizedValue): - return maybe_quantized.to_float() - else: - return maybe_quantized - - def preconditioner_from_params(param): - """Returns a Preconditioner object for given param.""" - return Preconditioner( - param, - block_size, - merge_small_dims_block_size, - best_effort_shape_interpretation, - precondtioner_type, - ) - - def precond_dim(max_size): - """Derives largest preconditioner dimension.""" - return max_size - - def pad_and_maybe_zero_preconditioners(preconditioners, total, max_size, - step): - """Pad preconditioners up to total x max_size x precond_dim(max_size).""" - pd = precond_dim(max_size) - - def maybe_reset_preconditioner(step, preconditioner): - if reset_frequency is None: - return preconditioner - return jnp.where(step % reset_frequency == 0, 0.0, 1.0) * preconditioner - - def _pad_preconditioner(preconditioner): - assert preconditioner.ndim == 2 - r, c = preconditioner.shape - assert r <= max_size - assert c <= pd - pad_rows = [(0, max_size - r)] - pad_cols = [(0, pd - c)] - padding = pad_rows + pad_cols - preconditioner = maybe_reset_preconditioner(step, preconditioner) - return jnp.pad(preconditioner, padding) - - last_dims_padded = [_pad_preconditioner(p) for p in preconditioners] - dt = preconditioners[0].dtype if preconditioners else jnp.float32 - num_extra = total - len(last_dims_padded) - extra = [jnp.zeros([max_size, pd], dtype=dt)] * num_extra - return last_dims_padded + extra - - def sharded_init_fn(params): - """Returns optimizer state (for PJIT mode). - - Args: - params: the parameters that should be updated. - """ - params_flat, treedef = jax.tree_util.tree_flatten(params) - # Find max size to pad to. - max_size = 0 - for param in params_flat: - preconditioner = preconditioner_from_params(param) - if not _skip_preconditioning(param): - shapes = preconditioner.shapes_for_preconditioners() - sizes = [s[0] for s in shapes] - max_size = max(*sizes, max_size) - - padded_statistics = [] - padded_preconditioners = [] - local_stats_flat = [] - exponents = [] - for param in params_flat: - preconditioner = preconditioner_from_params(param) - shapes = preconditioner.shapes_for_preconditioners() - sizes = [] - - statistics = [] - preconditioners = [] - index_start = len(padded_statistics) - if not _skip_preconditioning(param): - sizes = [s[0] for s in shapes] - shapes = preconditioner.shapes_for_preconditioners() - statistics = [ - matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) - for s in shapes - ] - pd = precond_dim(max_size) - # If the preconditioner is using a low-rank representation, initialize - # it to zero instead of an invalid eye. - preconditioners = [ - jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size) - for s in shapes - ] - padded_statistics.extend(statistics) - padded_preconditioners.extend(preconditioners) - exponent = ( - preconditioner.exponent_for_preconditioner() - if exponent_override == 0 else exponent_override) - exponents.extend([exponent] * len(shapes)) - - diagonal_statistics = jnp.zeros_like(param) - diagonal_momentum = jnp.zeros_like(param) - momentum = jnp.zeros_like(param) - - local_stats_flat.append( - LocalShardedParameterStats( - diagonal_statistics, - diagonal_momentum, - momentum, - init_training_metrics( - len(sizes), - generate_training_metrics, - ), - index_start, - sizes)) - - local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) - to_pad = -len(padded_statistics) % num_devices_for_pjit - if max_size == 0: - to_pad = num_devices_for_pjit - max_size = block_size - stat_dtype = jnp.float32 - else: - stat_dtype = padded_statistics[0].dtype - # Pad the statistics and preconditioner matrices to be a multiple of - # num devices. - # TODO(rohananil): Relax to only the size of the mesh axis where the dim - # is split on. - padded_statistics.extend( - [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]) - pd = precond_dim(max_size) - # If the preconditioner is using a low-rank representation, initialize - # it to zero instead of an invalid eye. - padded_preconditioners.extend([ - jnp.eye(max_size, pd, dtype=stat_dtype) * (pd == max_size) - for _ in range(to_pad) - ]) - exponents.extend([1 for _ in range(to_pad)]) - global_stats = GlobalShardedParameterStats( - jnp.stack(padded_statistics), - jnp.stack(padded_preconditioners), - jnp.stack(exponents)) - return ShampooState( - count=jnp.zeros([], jnp.int32), - stats=ShardedShampooStats(global_stats, local_stats)) - - def _max_statistics_size_from_params(params): - max_size = 0 - for param in params: - param_clone = jnp.zeros(param.shape, dtype=param.dtype) - preconditioner = preconditioner_from_params(param_clone) - if not _skip_preconditioning(param): - shapes = preconditioner.shapes_for_preconditioners() - sizes = [s[0] for s in shapes] - max_size = max(*sizes, max_size) - return max_size - - def _remove_leading_sharding_annotation(pspec): - """Mapping from N-d to (N-1)-d, used for quantization, factoring etc.""" - # None and PSpec(None) are valid PSpecs. - if pspec and len(pspec) > 1: - return jax.sharding.PartitionSpec(*pspec[1:]) - else: - return [] - - def sharded_init_partition_spec_fn(params, - params_partition_spec, - partition_spec_for_statistics): - """Returns a parallel state tree with PartitionSpec associated with state. - - - Args: - params: A pytree with params. - params_partition_spec: A pytree with PartitionSpec for params. - partition_spec_for_statistics: PartitionSpec for the statistics. - """ - # Parallel lists of spec, and params. - param_pspec_flat, _ = jax.tree_util.tree_flatten( - params_partition_spec, is_leaf=lambda x: x is None) - params_flat, treedef = jax.tree_util.tree_flatten(params) - assert param_pspec_flat - assert params_flat - # Step is replicated across cores. - # None means cores. - local_stats_flat = [] - num_statistics = 0 - for param, param_pspec in zip(params_flat, param_pspec_flat): - param_clone = jnp.zeros(param.shape, dtype=param.dtype) - preconditioner = preconditioner_from_params(param_clone) - shapes = preconditioner.shapes_for_preconditioners() - sizes = [] - - index_start = num_statistics - if not _skip_preconditioning(param): - sizes = [s[0] for s in shapes] - shapes = preconditioner.shapes_for_preconditioners() - num_statistics += len(shapes) - - qdtype = quantized_dtype_for_momentum_buffers(param) - m1_pspec = param_pspec - m2_pspec = param_pspec - m1_scale_pspec = [] - m2_scale_pspec = [] - if qdtype != jnp.float32: - m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec) - m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec) - - local_stats_flat.append( - LocalShardedParameterStats( - QuantizedValue( - param_pspec, - [], - [], - jnp.float32, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - QuantizedValue( - m1_pspec, - [], - m1_scale_pspec, - qdtype, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - QuantizedValue( - m2_pspec, - [], - m2_scale_pspec, - qdtype, - False, # pytype: disable=wrong-arg-types # numpy-scalars - list(param.shape)), - init_training_metrics_pspec(generate_training_metrics,), - index_start, - sizes)) - - local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) - global_stats = GlobalShardedParameterStats(partition_spec_for_statistics, - partition_spec_for_statistics, - jax.sharding.PartitionSpec()) - count_pspec = jax.sharding.PartitionSpec() - return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars - count=count_pspec, - stats=ShardedShampooStats(global_stats, local_stats)) - - def sharded_init_shape_and_dtype_fn(params): - """Returns a parallel state tree with shape, dtype associated with state. - - - Args: - params: A pytree with params. - """ - # Parallel lists of spec, and params. - params_flat, treedef = jax.tree_util.tree_flatten(params) - assert params_flat - # Step is replicated across cores. - # None means cores. - local_stats_flat = [] - num_statistics = 0 - for param in params_flat: - param_clone = jnp.zeros(param.shape, dtype=param.dtype) - preconditioner = preconditioner_from_params(param_clone) - shapes = preconditioner.shapes_for_preconditioners() - sizes = [] - - index_start = num_statistics - if not _skip_preconditioning(param): - sizes = [s[0] for s in shapes] - shapes = preconditioner.shapes_for_preconditioners() - num_statistics += len(shapes) - - qdtype = quantized_dtype_for_momentum_buffers(param) - m1_shape_and_dtype = [list(param.shape), param.dtype] - m2_shape_and_dtype = [list(param.shape), param.dtype] - m1_scale_shape_and_dtype = [] - m2_scale_shape_and_dtype = [] - if qdtype != jnp.float32: - m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype] - m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype] - - diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype] - local_stats_flat.append( - LocalShardedParameterStats( - QuantizedValue( - diagonal_statistics_shape_and_dtype, - [], - [], # pytype: disable=wrong-arg-types # numpy-scalars - jnp.float32, - False, - list(param.shape)), - QuantizedValue(m1_shape_and_dtype, [], - m1_scale_shape_and_dtype, - qdtype, - False, - list(param.shape)), - QuantizedValue(m2_shape_and_dtype, [], - m2_scale_shape_and_dtype, - qdtype, - False, - list(param.shape)), - init_training_metrics_shapes( - len(sizes), - generate_training_metrics, - ), - index_start, - sizes, - )) - - local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat) - max_statistics_size = _max_statistics_size_from_params(params_flat) - to_pad = -num_statistics % num_devices_for_pjit - num_statistics += to_pad - if num_statistics == 0: - num_statistics = num_devices_for_pjit - max_statistics_size = block_size - statistics_shape = [ - num_statistics, max_statistics_size, max_statistics_size - ] - preconditioners_shape = [ - num_statistics, max_statistics_size, precond_dim(max_statistics_size) - ] - global_stats = GlobalShardedParameterStats( - [statistics_shape, jnp.float32], [preconditioners_shape, jnp.float32], - [[num_statistics], jnp.int32]) - return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars - count=[[], jnp.float32], - stats=ShardedShampooStats(global_stats, local_stats)) - - def sharded_update_fn(grads, state, params): - """Transform the input gradient and update all statistics in sharded mode. - - Args: - grads: the gradient tensors for the parameters. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. - - Returns: - A tuple containing the new parameters and the new optimizer state. - """ - params_flat, treedef = jax.tree_util.tree_flatten(params) - grads_flat = treedef.flatten_up_to(grads) - - global_stats = state.stats.global_stats - local_stats_flat = treedef.flatten_up_to(state.stats.local_stats) - stats_flat = [] - for local_stat in local_stats_flat: - stats_flat.append(_convert_to_parameter_stats( - global_stats, - local_stat, - )) - - new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), - grads_flat, - stats_flat, - params_flat) - - outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), - grads_flat, - new_stats_flat, - params_flat) - updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) - - updates = jax.tree_util.tree_unflatten(treedef, updates_flat) - new_local_stats_flat = [] - for new_stat, local_stat in zip(new_stats_flat, local_stats_flat): - new_local_stats_flat.append( - _convert_from_parameter_stats( - new_stat, - local_stat, - )) - - max_size = global_stats.statistics.shape[1] - new_padded_statistics = [] - padding_starts = [] - for stat in new_stats_flat: - new_padded_statistics.extend( - [pad_square_matrix(stat, max_size) for stat in stat.statistics]) - padding_starts.extend([len(stat) for stat in stat.statistics]) - - # Create global stats - # TODO(rohananil): Preconditioner is not updated every step, so cost of - # stack/pad can be obviated away. - # Pad the statistics and preconditioner matrices to be a multiple of - # num devices. - # TODO(rohananil): Relax to only the size of the mesh axis where the dim - # is split on. - to_pad = -len(new_padded_statistics) % num_devices_for_pjit - if not new_padded_statistics: - to_pad = num_devices_for_pjit - stat_dtype = jnp.float32 - else: - stat_dtype = new_padded_statistics[0].dtype - - new_padded_statistics.extend( - [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]) - padding_starts += [0] * to_pad - - if reuse_preconditioner: - prev_preconditioners = [] - for stat in new_stats_flat: - prev_preconditioners.extend(stat.preconditioners) - prev_padded_preconditioners = pad_and_maybe_zero_preconditioners( - prev_preconditioners, - len(new_padded_statistics), - max_size, - state.count) - else: - prev_padded_preconditioners = None - - new_stacked_padded_statistics = jnp.stack(new_padded_statistics) - new_stacked_padded_statistics = pjit.with_sharding_constraint( - new_stacked_padded_statistics, statistics_partition_spec) - stacked_padding_starts = jnp.array(padding_starts, jnp.int32) - prev_stacked_padded_preconditioners = _maybe(jnp.stack)( - prev_padded_preconditioners) - prev_stacked_padded_preconditioners = _maybe(pjit.with_sharding_constraint)( - prev_padded_preconditioners, statistics_partition_spec) - - def _internal_inverse_pth_root_all(): - preconditioners, metrics = _matrix_inverse_pth_root_pjit( - new_stacked_padded_statistics, - global_stats.exponents, - stacked_padding_starts, - prev_stacked_padded_preconditioners, - statistics_partition_spec, - ) - return preconditioners, metrics - - perform_step = state.count % preconditioning_compute_steps == 0 - - if preconditioning_compute_steps == 1: - new_preconditioners, metrics = _internal_inverse_pth_root_all() - else: - # Passing statistics instead of preconditioners as they are similarly - # shaped tensors. Note statistics will be ignored as we are passing in - # a large error value. - pd = precond_dim(new_stacked_padded_statistics.shape[2]) - preconditioners_init = new_stacked_padded_statistics[:, :, :pd] - n = new_stacked_padded_statistics.shape[0] - metrics_init = cast( - TrainingMetrics, - init_training_metrics( - n, - generate_training_metrics=True, - )) - new_errors = jnp.ones_like(metrics_init.inverse_pth_root_errors) * ( - inverse_failure_threshold) - metrics_init = metrics_init.replace(inverse_pth_root_errors=new_errors) - init_state = [preconditioners_init, metrics_init] - new_preconditioners, metrics = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) - - if generate_training_metrics: - new_local_stats_flat = _add_metrics_into_local_stats( - new_local_stats_flat, metrics, ~perform_step) - new_local_stats = jax.tree_util.tree_unflatten(treedef, - new_local_stats_flat) - errors = metrics.inverse_pth_root_errors - errors = errors.reshape((-1, 1, 1)) - predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) - # TODO(rohananil): Check for numerical instabilities. - new_conditional_preconditioners = ( - predicate * global_stats.preconditioners + - (1.0 - predicate) * new_preconditioners) - new_global_stats = GlobalShardedParameterStats( - new_stacked_padded_statistics, - new_conditional_preconditioners, - global_stats.exponents) - new_shampoo_state = ShampooState( - count=state.count + 1, - stats=ShardedShampooStats(new_global_stats, new_local_stats)) - return updates, new_shampoo_state - - def init_fn(params): - """Initialise the optimiser's state.""" - - def _init(param): - preconditioner = preconditioner_from_params(param) - statistics = [] - preconditioners = [] - if not _skip_preconditioning(param): - shapes = preconditioner.shapes_for_preconditioners() - statistics = [ - matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes - ] - # If the preconditioner is using a low-rank representation, initialize - # it to zero instead of an invalid eye. - preconditioners = [ - jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1]) - for s in shapes - ] - - diagonal_statistics = [] - if _graft_type_has_diagonal_statistics(): - diagonal_statistics = jnp.zeros_like(param) - - # diagonal_momentum = _quantize_momentum(jnp.zeros_like(param)) - # momentum = _quantize_momentum(jnp.zeros_like(param)) - diagonal_momentum = jnp.zeros_like(param) - momentum = jnp.zeros_like(param) - - return ParameterStats( - diagonal_statistics, - statistics, - preconditioners, - # _quantize_diagonal_statistics(diagonal_statistics), - # _maybe_quantize_statistics(statistics), - # _maybe_quantize_preconditioners(preconditioners), - diagonal_momentum, - momentum, - init_training_metrics( - len(statistics), - generate_training_metrics, - )) - - return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) - - def _skip_preconditioning(param): - return len(param.shape) < skip_preconditioning_rank_lt or any( - s > skip_preconditioning_dim_size_gt for s in param.shape) - - def _compute_stats(grad, state, param, step): - """Compute per-parameter statistics.""" - preconditioner = preconditioner_from_params(param) - new_statistics = [[]] * len(state.statistics) - w1 = beta2 - w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2) - if not _skip_preconditioning(param): - - def compute_updated_statistics(): - return preconditioner.updated_statistics_from_grad( - state.statistics, - grad, - w1=w1, - w2=w2, - to_float=_to_float, - from_float=lambda x: x, - # from_float=lambda x: _maybe_quantize_statistics([x])[0], - precision=tensordot_precision, - ) - - if statistics_compute_steps > 1: - perform_step = step % statistics_compute_steps == 0 - init_state = state.statistics - new_statistics = list( - efficient_cond(perform_step, compute_updated_statistics, - init_state)) - else: - new_statistics = compute_updated_statistics() - - return ParameterStats(state.diagonal_statistics, - new_statistics, - state.preconditioners, - state.diagonal_momentum, - state.momentum, - state.training_metrics) - - mi_pth_root = functools.partial( - matrix_inverse_pth_root, - ridge_epsilon=matrix_epsilon, - precision=precision, - relative_matrix_epsilon=relative_matrix_epsilon, - lobpcg_topk_precondition=lobpcg_topk_precondition, - lobpcg_max_iter=lobpcg_max_iter, - eigh=eigh) - - def _matrix_inverse_pth_root_vmap(xs, ps, padding_starts, prev): - return jax.vmap(mi_pth_root)( - xs, ps, padding_start=padding_starts, prev=prev) - - def _matrix_inverse_pth_root_pjit(xs, - ps, - padding_starts, - prev_preconds=None, - statistics_partition_spec=None): - # Partition the concatenated statistics matrix across all cores. - pspec_for_partition = preconditioner_partition_spec - partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition) - if preconditioner_partition_spec: - partitioned_ps_spec = jax.sharding.PartitionSpec( - preconditioner_partition_spec[0]) - else: - partitioned_ps_spec = None - partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec) - partitioned_prev_preconds = _maybe(pjit.with_sharding_constraint)( - prev_preconds, preconditioner_partition_spec) - partitioned_padding_starts = pjit.with_sharding_constraint( - padding_starts, partitioned_ps_spec) # paddings are scalars like ps. - # Run matrix inverse pth root on each shard. - partitioned_preconditioners, partitioned_metrics = ( - _matrix_inverse_pth_root_vmap( - partitioned_xs, - partitioned_ps, - partitioned_padding_starts, - prev=partitioned_prev_preconds)) - # Reshard output to have the same PSpec as input. This is required to avoid - # vmap seeing the full set of statistics. - partitioned_preconditioners = pjit.with_sharding_constraint( - partitioned_preconditioners, pspec_for_partition) - # Recombine the outputs at each core. - preconditioners = pjit.with_sharding_constraint(partitioned_preconditioners, - statistics_partition_spec) - metrics = pjit.with_sharding_constraint(partitioned_metrics, - jax.sharding.PartitionSpec()) - return preconditioners, metrics - - def _pmap_compute_preconditioners(states, - step, - statistics, - num_statistics_per_state, - original_shapes, - exponents, - max_size, - prev_preconditioners): - """Computes preconditioners for given statistics in states in PMAP mode. - - Args: - states: A list of optimizer states. - step: Current step number - statistics: A list of statistics for all variables (for every dim) - num_statistics_per_state: Number of statistis per state to reconstruct - output states. - original_shapes: A list of shapes of the statistics. - exponents: Exponent power to use for inverse-pth roots. - max_size: Maximum dim of the statistics to pad. - prev_preconditioners: Previously available preconditioner. - - Returns: - New optimizer states after computing the preconditioner. - """ - if batch_axis_name: - num_devices = lax.psum(1, batch_axis_name) - else: - num_devices = 1 - num_statistics = len(statistics) - # Pad statistics and exponents to next multiple of num_devices. - packed_statistics = [ - pad_square_matrix(stat, max_size) for stat in statistics - ] - to_pad = -num_statistics % num_devices - packed_statistics.extend([ - jnp.eye(max_size, dtype=packed_statistics[0].dtype) - for _ in range(to_pad) - ]) - exponents.extend([1 for _ in range(to_pad)]) - paddings = [len(stat) for stat in statistics] + [0] * to_pad - - if not packed_statistics: - return states - - if reuse_preconditioner: - assert len(prev_preconditioners) == num_statistics - packed_preconditioners = pad_and_maybe_zero_preconditioners( - prev_preconditioners, len(packed_statistics), max_size, step) - else: - packed_preconditioners = None - - all_statistics = batch(packed_statistics, num_devices) - all_exponents = batch(exponents, num_devices) - all_paddings = batch(paddings, num_devices) - all_preconditioners = _maybe(batch)(packed_preconditioners, num_devices) - - def _internal_inverse_pth_root_all(): - if batch_axis_name: - current_replica = lax.axis_index(batch_axis_name) - preconditioners, metrics = _matrix_inverse_pth_root_vmap( - all_statistics[current_replica], - all_exponents[current_replica], - all_paddings[current_replica], - _maybe_ix(all_preconditioners, current_replica), - ) - preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) - metrics = jax.lax.all_gather(metrics, batch_axis_name) - preconditioners_flat = unbatch(preconditioners) - metrics_flat = jax.tree.map(unbatch, metrics) - else: - preconditioners, metrics = _matrix_inverse_pth_root_vmap( - all_statistics[0], - all_exponents[0], - all_paddings[0], - _maybe_ix(all_preconditioners, 0), - ) - preconditioners_flat = unbatch(jnp.stack([preconditioners])) - metrics = jax.tree.map( - functools.partial(jnp.expand_dims, axis=0), metrics) - metrics_flat = jax.tree.map(unbatch, metrics) - - return preconditioners_flat, metrics_flat - - perform_step = step % preconditioning_compute_steps == 0 - if preconditioning_compute_steps == 1: - preconditioners_flat, metrics_flat = _internal_inverse_pth_root_all() - else: - # Passing statistics instead of preconditioners as they are similarly - # shaped tensors. Note statistics will be ignored as we are passing in - # a large error value. - preconditioners_init = [ - s[:, :precond_dim(s.shape[0])] for s in packed_statistics - ] - n = len(packed_statistics) - metrics_init = jax.tree.map( - lambda x: [x] * n, - default_training_metrics().replace( - inverse_pth_root_errors=inverse_failure_threshold)) - init_state = [preconditioners_init, metrics_init] - preconditioners_flat, metrics_flat = efficient_cond( - perform_step, _internal_inverse_pth_root_all, init_state) - - def _skip(error): - condition = jnp.logical_or( - jnp.isnan(error), error >= inverse_failure_threshold) - return condition.astype(error.dtype) - - def _select_preconditioner(error, new_p, old_p): - return lax.cond( - _skip(error), lambda _: old_p, lambda _: new_p, operand=None) - - new_preconditioners_flat = [] - new_errors_flat = metrics_flat.inverse_pth_root_errors - for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes, - prev_preconditioners, new_errors_flat): - new_preconditioners_flat.append( - _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p)) - - assert len(states) == (len(num_statistics_per_state), - f"{len(states)} vs {len(num_statistics_per_state)}") - assert len(new_preconditioners_flat) == num_statistics - assert len(new_errors_flat) == len(packed_statistics), ( - len(new_errors_flat), len(packed_statistics)) - assert len(new_errors_flat) == num_statistics + to_pad, ( - len(new_errors_flat), num_statistics, to_pad) - - # Add back empty preconditioners so we that we can set the optimizer state. - preconditioners_for_states = [] - idx = 0 - metrics_for_states = [] - for num_statistics, state in zip(num_statistics_per_state, states): - if num_statistics == 0: - preconditioners_for_states.append([]) - metrics_for_states.append( - init_training_metrics(0, generate_training_metrics)) - else: - preconditioners_for_state = new_preconditioners_flat[idx:idx + - num_statistics] - assert len(state.statistics) == len(preconditioners_for_state) - preconditioners_for_states.append(preconditioners_for_state) - - if generate_training_metrics: - # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree.map( - lambda x: jnp.stack(x[idx:idx + num_statistics]), - metrics_flat, - is_leaf=lambda x: isinstance(x, list)) - assert jax.tree_util.tree_all( - jax.tree.map(lambda x: len(state.statistics) == len(x), - metrics_for_state)) - # If we skipped preconditioner computation, record old metrics. - metrics_for_state = efficient_cond(perform_step, - lambda: [metrics_for_state], - [state.training_metrics])[0] - # pylint:enable=cell-var-from-loop - else: - metrics_for_state = optax.MaskedNode() - metrics_for_states.append(metrics_for_state) - - idx += num_statistics - new_states = [] - for state, new_preconditioners, new_metrics in zip( - states, preconditioners_for_states, metrics_for_states): - # Note the preconditioner may have been skipped, but we still update the - # metrics with the new error values; whether the preconditioner that's - # actively being used is stale can be derived from the new_metrics - # being greater than the failure threshold. - new_states.append( - ParameterStats(state.diagonal_statistics, - state.statistics, - new_preconditioners, - state.diagonal_momentum, - state.momentum, - new_metrics)) - - return new_states - - def _compute_preconditioners(states, params, step): - """Computes preconditioners for given statistics in states. - - Args: - states: A list of optimizer states. - params: A list of params. - step: Current step number - - Returns: - New optimizer states after computing the preconditioner. - """ - statistics = [] - num_statistics_per_state = [] - original_shapes = [] - exponents = [] - max_size = 0 - prev_preconditioners = [] - - for state, param in zip(states, params): - num_statistics = len(state.statistics) - num_statistics_per_state.append(num_statistics) - original_shapes_for_state = [] - if num_statistics > 0: - preconditioner = preconditioner_from_params(param) - for statistic in state.statistics: - exponents.append(preconditioner.exponent_for_preconditioner( - ) if exponent_override == 0 else exponent_override) - original_shapes_for_state.append(statistic.shape) - max_size = max(max_size, statistic.shape[0]) - - statistics.extend(state.statistics) - prev_preconditioners.extend(state.preconditioners) - original_shapes.extend(original_shapes_for_state) - - return _pmap_compute_preconditioners(states, - step, - statistics, - num_statistics_per_state, - original_shapes, - exponents, - max_size, - prev_preconditioners) - - def _transform_grad(grad, state, param, step): - """Transform per-parameter gradients.""" - preconditioner = preconditioner_from_params(param) - sgd_update = grad - new_diagonal_statistics = state.diagonal_statistics - - if (graft_type == GraftingType.ADAGRAD or - graft_type == GraftingType.ADAGRAD_NORMALIZED): - - scaled_grad = grad - if graft_type == GraftingType.ADAGRAD_NORMALIZED: - scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON) - - new_diagonal_statistics = ( - state.diagonal_statistics.to_float() + jnp.square(scaled_grad)) - adagrad_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) - grafting_update = adagrad_update - elif (graft_type == GraftingType.RMSPROP or - graft_type == GraftingType.RMSPROP_NORMALIZED): - - scaled_grad = grad - if graft_type == GraftingType.RMSPROP_NORMALIZED: - scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON) - - w1 = beta2 - w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2) - - new_diagonal_statistics = ( - w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad)) - rmsprop_update = scaled_grad / ( - jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon) - - if clip_by_scaled_gradient_norm: - scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / ( - jnp.sqrt(float(rmsprop_update.size))) - clipping_denom = jnp.maximum( - 1., scaled_grad_norm / clip_by_scaled_gradient_norm) - rmsprop_update /= clipping_denom - - grafting_update = rmsprop_update - elif graft_type == GraftingType.SGD: - grafting_update = sgd_update - elif graft_type == GraftingType.NONE: - grafting_update = sgd_update # Use SGD during warmup. - else: - grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update) - - lr = learning_rate - if callable(learning_rate): - lr = learning_rate(step) - - preconditioner_multiplier = lr if not decoupled_learning_rate else 1.0 - grafting_update = grafting_update * preconditioner_multiplier - - precond_grad = grad - if not _skip_preconditioning(param): - precond_grad = preconditioner.preconditioned_grad(precond_grad, - state.preconditioners) - else: - if graft_type == GraftingType.NONE: - logging.error("skipping preconditioning without grafting for param %s", - param) - precond_grad = grafting_update - - grafting_update_norm = jnp.linalg.norm(grafting_update) - precond_grad_norm = jnp.linalg.norm(precond_grad) - - if graft_type is not GraftingType.NONE: - multiplier = grafting_update_norm / (precond_grad_norm + _EPSILON) - else: - multiplier = 1.0 - shampoo_update = precond_grad * multiplier - - shampoo_update_with_wd = shampoo_update - grafting_update_with_wd = grafting_update - - if (weight_decay != 0 and weight_decay is not None and - not decoupled_weight_decay): - shampoo_update_with_wd = shampoo_update + weight_decay * param - grafting_update_with_wd = grafting_update + weight_decay * param - - w = (1.0 - beta1) if moving_average_for_momentum else 1.0 - - shampoo_update_with_wd_momentum = ( - state.momentum * beta1 + w * shampoo_update_with_wd) - - grafting_update_with_wd_momentum = ( - state.diagonal_momentum * beta1 + w * grafting_update_with_wd) - - run_shampoo = (step >= start_preconditioning_step).astype( - grafting_update_with_wd_momentum.dtype) - - momentum_update = ( - run_shampoo * shampoo_update_with_wd_momentum + - (1.0 - run_shampoo) * grafting_update_with_wd_momentum) - - wd_update = ( - run_shampoo * shampoo_update_with_wd + - (1.0 - run_shampoo) * grafting_update_with_wd) - - nesterov_momentum_update = momentum_update - - if nesterov: - nesterov_momentum_update = w * wd_update + beta1 * momentum_update - - if (weight_decay != 0 and weight_decay is not None and - decoupled_weight_decay): - nesterov_momentum_update = ( - nesterov_momentum_update + lr * weight_decay * param) - - momentum_multiplier = lr if decoupled_learning_rate else 1.0 - transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update - - new_diagonal_momentum = grafting_update_with_wd_momentum - new_momentum = shampoo_update_with_wd_momentum - - param_stats = ParameterStats(new_diagonal_statistics, - state.statistics, - state.preconditioners, - new_diagonal_momentum, - new_momentum, - state.training_metrics) - return transformed_update, param_stats - - def update_fn(grads, state, params): - """Transform the input gradient and update all statistics. - - Args: - grads: the gradient tensors for the parameters and any custom - gradients for preconditioners. - state: a named tuple containing the state of the optimizer - params: the parameters that should be updated. - - Returns: - A tuple containing the new parameters and the new optimizer state. - """ - grads_custom = None - if custom_preconditioner and isinstance(grads, tuple): - grads, grads_custom = grads - - params_flat, treedef = jax.tree_util.tree_flatten(params) - stats_flat = treedef.flatten_up_to(state.stats) - grads_flat = treedef.flatten_up_to(grads) - stats_grads = grads_flat - - if custom_preconditioner and grads_custom is not None: - stats_grads = treedef.flatten_up_to(grads_custom) - - new_stats_flat = jax.tree.map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), - stats_grads, - stats_flat, - params_flat) - - new_stats_flat = _compute_preconditioners(new_stats_flat, - params_flat, - state.count) - outputs = jax.tree.map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), - grads_flat, - new_stats_flat, - params_flat) - updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ()) - updates = jax.tree_util.tree_unflatten(treedef, updates_flat) - new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat) - new_state = ShampooState(count=state.count + 1, stats=new_stats) - return updates, new_state - - if shard_optimizer_states: - # Hijacks the init_fn signature so we can return an OptState with - # appropriate init_fns. - opt_init_fn = sharded_init_fn - - def _init_fns(unused_params): - return InitFnState( - init_fn=opt_init_fn, - pspec_fn=sharded_init_partition_spec_fn, - shape_and_dtype_fn=sharded_init_shape_and_dtype_fn) - - opt_update_fn = sharded_update_fn - return optax.GradientTransformation(_init_fns, opt_update_fn) - else: - return optax.GradientTransformation(init_fn, update_fn) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py deleted file mode 100644 index 2cd054062..000000000 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" - -import functools -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algoperf import spec -from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \ - distributed_shampoo - -_GRAD_CLIP_EPS = 1e-6 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a Shampoo optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = distributed_shampoo( - learning_rate=lr_schedule_fn, - beta1=1.0 - hyperparameters.one_minus_beta1, - beta2=hyperparameters.beta2, - weight_decay=hyperparameters.weight_decay, - batch_axis_name='batch', - eigh=False) - params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree.map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del train_state - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def prepare_for_eval(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del workload - del hyperparameters - del current_params_types - del loss_type - del eval_results - del global_step - del rng - return (optimizer_state, current_param_container, model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/paper_baselines/shampoo/pytorch/__init__.py b/reference_algorithms/paper_baselines/shampoo/pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json deleted file mode 100644 index 9d804ba0e..000000000 --- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "min": 1e-2, "max": 0.15, "scaling": "log" - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } -} diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json deleted file mode 100644 index b8bd2ea49..000000000 --- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "learning_rate": { - "min": 1e-4, "max": 1e-2, "scaling": "log" - }, - "one_minus_beta1": { - "feasible_points": [0.1] - }, - "beta2": { - "feasible_points": [0.999] - }, - "warmup_factor": { - "feasible_points": [0.05] - }, - "weight_decay": { - "min": 5e-3, "max": 1.0, "scaling": "log" - }, - "label_smoothing": { - "feasible_points": [0.1, 0.2] - }, - "dropout_rate": { - "feasible_points": [0.0, 0.1] - } -} diff --git a/prize_qualification_baselines/README.md b/reference_algorithms/qualification_baselines/README.md similarity index 100% rename from prize_qualification_baselines/README.md rename to reference_algorithms/qualification_baselines/README.md diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_full_budget.py similarity index 89% rename from prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py rename to reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_full_budget.py index c451a18ac..5d4126be5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -1,7 +1,5 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" -import functools - # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. from typing import (Any, @@ -16,13 +14,12 @@ # isort: on import chex -from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax from algoperf import spec +from algoperf import jax_sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -192,24 +189,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -232,9 +223,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Compute mean loss and grad loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -272,7 +261,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -281,24 +269,52 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = jax_sharding_utils.get_batch_dim_sharding() # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) + {'loss': loss.item(), 'grad_norm': grad_norm.item()}, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py similarity index 87% rename from prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py rename to reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py index b8ac10f33..48dd1571d 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -16,13 +16,13 @@ # isort: on import chex -from flax import jax_utils import jax from jax import lax import jax.numpy as jnp import optax from algoperf import spec +from algoperf import jax_sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -192,24 +192,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -232,9 +226,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Get mean loss and grad. loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -281,23 +273,55 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Create shardings for each argument + replicated = jax_sharding_utils.get_replicate_sharding() # No partitioning + sharded = jax_sharding_utils.get_batch_dim_sharding() # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py rename to reference_algorithms/qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py rename to reference_algorithms/qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/qualification_baselines/external_tuning/tuning_search_space.json similarity index 100% rename from prize_qualification_baselines/external_tuning/tuning_search_space.json rename to reference_algorithms/qualification_baselines/external_tuning/tuning_search_space.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_02-13-2024-13-02-24.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_02-13-2024-13-02-24.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_02-13-2024-13-02-24.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_02-13-2024-13-02-24.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/criteo1tb_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/fastmri_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-01-13.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-01-13.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-01-13.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-01-13.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_resnet_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-40-46.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-40-46.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-40-46.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-40-46.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/imagenet_vit_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-01-22-23.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-01-22-23.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-01-22-23.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-01-22-23.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_conformer_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-10-2024-22-49-44.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-10-2024-22-49-44.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-10-2024-22-49-44.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-10-2024-22-49-44.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_02-04-2024-21-11-51.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_02-04-2024-21-11-51.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_02-04-2024-21-11-51.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_02-04-2024-21-11-51.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/ogbg_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/wmt_jax_02-06-2024-10-57-53.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/wmt_jax_02-06-2024-10-57-53.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/wmt_jax_02-06-2024-10-57-53.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_0/wmt_jax/wmt_jax_02-06-2024-10-57-53.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-28-38.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-28-38.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-28-38.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-28-38.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/criteo1tb_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/fastmri_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-48.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-48.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-48.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-48.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_resnet_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-42-02.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-42-02.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-42-02.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-42-02.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/imagenet_vit_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-48-34.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-48-34.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-48-34.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-48-34.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_conformer_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-20-46.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-20-46.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-20-46.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-20-46.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_02-04-2024-21-48-01.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_02-04-2024-21-48-01.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_02-04-2024-21-48-01.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_02-04-2024-21-48-01.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/ogbg_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/wmt_jax_02-06-2024-11-39-02.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/wmt_jax_02-06-2024-11-39-02.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/wmt_jax_02-06-2024-11-39-02.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_1/wmt_jax/wmt_jax_02-06-2024-11-39-02.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-24-50.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-24-50.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-24-50.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-24-50.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/criteo1tb_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/fastmri_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-57.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-57.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-57.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-02-57.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_resnet_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-48-02.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-48-02.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-48-02.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-13-48-02.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/imagenet_vit_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-44-53.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-44-53.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-44-53.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-44-53.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_conformer_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-46-57.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-46-57.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-46-57.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-00-46-57.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_02-04-2024-23-24-16.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_02-04-2024-23-24-16.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_02-04-2024-23-24-16.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_02-04-2024-23-24-16.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/ogbg_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/wmt_jax_02-06-2024-13-15-17.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/wmt_jax_02-06-2024-13-15-17.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/wmt_jax_02-06-2024-13-15-17.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_2/wmt_jax/wmt_jax_02-06-2024-13-15-17.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_02-17-2024-11-56-37.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_02-17-2024-11-56-37.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_02-17-2024-11-56-37.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_02-17-2024-11-56-37.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/criteo1tb_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/fastmri_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-03-07.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-03-07.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-03-07.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-26-2024-19-03-07.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-30-2024-20-48-52.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-30-2024-20-48-52.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-30-2024-20-48-52.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_01-30-2024-20-48-52.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_resnet_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_01-26-2024-19-08-22.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_01-26-2024-19-08-22.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_01-26-2024-19-08-22.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_01-26-2024-19-08-22.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_02-03-2024-15-24-33.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_02-03-2024-15-24-33.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_02-03-2024-15-24-33.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_02-03-2024-15-24-33.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/imagenet_vit_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_02-18-2024-00-16-19.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_02-18-2024-00-16-19.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_02-18-2024-00-16-19.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_02-18-2024-00-16-19.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_conformer_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-15-2024-02-44-09.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-15-2024-02-44-09.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-15-2024-02-44-09.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-15-2024-02-44-09.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_02-09-2024-01-06-01.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_02-09-2024-01-06-01.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_02-09-2024-01-06-01.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_02-09-2024-01-06-01.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/ogbg_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/wmt_jax_02-10-2024-14-47-06.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/wmt_jax_02-10-2024-14-47-06.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/wmt_jax_02-10-2024-14-47-06.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_3/wmt_jax/wmt_jax_02-10-2024-14-47-06.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-38-52.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-38-52.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-38-52.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_02-13-2024-14-38-52.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/criteo1tb_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/fastmri_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_01-27-2024-01-35-38.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_01-27-2024-01-35-38.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_01-27-2024-01-35-38.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_01-27-2024-01-35-38.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_resnet_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-19-56-30.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-19-56-30.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-19-56-30.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_01-30-2024-19-56-30.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/imagenet_vit_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-58-52.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-58-52.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-58-52.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_02-14-2024-02-58-52.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_conformer_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-04-05-57.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-04-05-57.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-04-05-57.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_02-11-2024-04-05-57.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_02-05-2024-03-48-25.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_02-05-2024-03-48-25.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_02-05-2024-03-48-25.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_02-05-2024-03-48-25.log diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/ogbg_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_2/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_3/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_4/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/flags_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/flags_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/flags_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/hparams.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/hparams.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/hparams.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/hparams.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/trial_5/meta_data_0.json diff --git a/prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/wmt_jax_02-06-2024-17-34-24.log b/reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/wmt_jax_02-06-2024-17-34-24.log similarity index 100% rename from prize_qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/wmt_jax_02-06-2024-17-34-24.log rename to reference_algorithms/qualification_baselines/logs/external_tuning/full_budget/study_4/wmt_jax/wmt_jax_02-06-2024-17-34-24.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-14-45.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-14-45.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-14-45.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-14-45.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/fastmri_jax_03-12-2024-03-47-57.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/fastmri_jax_03-12-2024-03-47-57.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/fastmri_jax_03-12-2024-03-47-57.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/fastmri_jax_03-12-2024-03-47-57.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-10-31.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-10-31.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-10-31.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-10-31.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-32-59.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-32-59.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-32-59.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-32-59.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-09-52.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-09-52.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-09-52.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-09-52.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-07-50.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-07-50.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-07-50.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-07-50.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_03-05-2024-10-01-36.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_03-05-2024-10-01-36.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_03-05-2024-10-01-36.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/ogbg_jax_03-05-2024-10-01-36.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/wmt_jax_03-12-2024-04-28-16.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/wmt_jax_03-12-2024-04-28-16.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/wmt_jax_03-12-2024-04-28-16.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_0/wmt_jax/wmt_jax_03-12-2024-04-28-16.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-39-31.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-39-31.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-39-31.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/criteo1tb_jax_03-13-2024-10-39-31.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/fastmri_jax_03-12-2024-03-32-44.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/fastmri_jax_03-12-2024-03-32-44.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/fastmri_jax_03-12-2024-03-32-44.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/fastmri_jax_03-12-2024-03-32-44.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-32-47.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-32-47.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-32-47.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-32-47.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-55-17.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-55-17.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-55-17.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-55-17.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-19-59-18.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-19-59-18.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-19-59-18.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-19-59-18.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-37-44.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-37-44.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-37-44.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-23-37-44.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_03-05-2024-10-18-35.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_03-05-2024-10-18-35.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_03-05-2024-10-18-35.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/ogbg_jax_03-05-2024-10-18-35.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/wmt_jax_03-12-2024-04-28-04.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/wmt_jax_03-12-2024-04-28-04.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/wmt_jax_03-12-2024-04-28-04.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_1/wmt_jax/wmt_jax_03-12-2024-04-28-04.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_03-13-2024-09-51-13.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_03-13-2024-09-51-13.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_03-13-2024-09-51-13.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/criteo1tb_jax_03-13-2024-09-51-13.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/fastmri_jax_03-12-2024-03-34-27.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/fastmri_jax_03-12-2024-03-34-27.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/fastmri_jax_03-12-2024-03-34-27.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/fastmri_jax_03-12-2024-03-34-27.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-39.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-39.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-39.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-39.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-06-03.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-06-03.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-06-03.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-06-03.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-18-55-53.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-18-55-53.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-18-55-53.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-18-55-53.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-22-49-22.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-22-49-22.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-22-49-22.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-12-2024-22-49-22.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_03-05-2024-10-04-38.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_03-05-2024-10-04-38.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_03-05-2024-10-04-38.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/ogbg_jax_03-05-2024-10-04-38.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/wmt_jax_03-12-2024-04-14-45.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/wmt_jax_03-12-2024-04-14-45.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/wmt_jax_03-12-2024-04-14-45.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_2/wmt_jax/wmt_jax_03-12-2024-04-14-45.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_03-16-2024-12-03-01.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_03-16-2024-12-03-01.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_03-16-2024-12-03-01.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/criteo1tb_jax_03-16-2024-12-03-01.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/fastmri_jax_03-12-2024-03-34-36.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/fastmri_jax_03-12-2024-03-34-36.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/fastmri_jax_03-12-2024-03-34-36.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/fastmri_jax_03-12-2024-03-34-36.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-57.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-57.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-57.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/imagenet_resnet_jax_02-29-2024-05-13-57.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-31-55.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-31-55.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-31-55.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/imagenet_vit_jax_03-02-2024-11-31-55.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_03-16-2024-21-22-50.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_03-16-2024-21-22-50.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_03-16-2024-21-22-50.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/librispeech_conformer_jax_03-16-2024-21-22-50.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-15-2024-00-20-37.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-15-2024-00-20-37.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-15-2024-00-20-37.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-15-2024-00-20-37.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_03-05-2024-10-10-44.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_03-05-2024-10-10-44.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_03-05-2024-10-10-44.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/ogbg_jax_03-05-2024-10-10-44.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/wmt_jax_03-12-2024-04-09-55.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/wmt_jax_03-12-2024-04-09-55.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/wmt_jax_03-12-2024-04-09-55.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_3/wmt_jax/wmt_jax_03-12-2024-04-09-55.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_03-13-2024-12-00-36.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_03-13-2024-12-00-36.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_03-13-2024-12-00-36.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/criteo1tb_jax_03-13-2024-12-00-36.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/criteo1tb_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/fastmri_jax_03-12-2024-03-27-32.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/fastmri_jax_03-12-2024-03-27-32.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/fastmri_jax_03-12-2024-03-27-32.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/fastmri_jax_03-12-2024-03-27-32.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/fastmri_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_03-05-2024-22-45-59.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_03-05-2024-22-45-59.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_03-05-2024-22-45-59.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/imagenet_resnet_jax_03-05-2024-22-45-59.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_resnet_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_03-08-2024-05-00-35.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_03-08-2024-05-00-35.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_03-08-2024-05-00-35.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/imagenet_vit_jax_03-08-2024-05-00-35.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/imagenet_vit_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-34-38.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-34-38.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-34-38.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/librispeech_conformer_jax_03-13-2024-20-34-38.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_conformer_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-13-2024-00-19-52.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-13-2024-00-19-52.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-13-2024-00-19-52.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/librispeech_deepspeech_jax_03-13-2024-00-19-52.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/librispeech_deepspeech_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_03-11-2024-03-14-14.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_03-11-2024-03-14-14.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_03-11-2024-03-14-14.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/ogbg_jax_03-11-2024-03-14-14.log diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/ogbg_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/flags_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/trial_1/meta_data_0.json diff --git a/prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/wmt_jax_03-12-2024-04-02-58.log b/reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/wmt_jax_03-12-2024-04-02-58.log similarity index 100% rename from prize_qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/wmt_jax_03-12-2024-04-02-58.log rename to reference_algorithms/qualification_baselines/logs/self_tuning/full_budget/study_4/wmt_jax/wmt_jax_03-12-2024-04-02-58.log diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/qualification_baselines/self_tuning/jax_nadamw_full_budget.py similarity index 88% rename from prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py rename to reference_algorithms/qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 78c3b5b3e..d12042e9f 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/reference_algorithms/qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -18,11 +18,11 @@ import chex from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax from algoperf import spec +from algoperf import jax_sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -207,21 +207,15 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return jax_utils.replicate(optimizer_state), opt_update_fn -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -244,9 +238,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Get mean loss and grad. loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -296,15 +288,50 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Create shardings for each argument + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = jax_sharding_utils.get_replicate_sharding( + mesh) # No partitioning + sharded = jax_sharding_utils.get_batch_sharding( + mesh) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/qualification_baselines/self_tuning/jax_nadamw_target_setting.py similarity index 88% rename from prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py rename to reference_algorithms/qualification_baselines/self_tuning/jax_nadamw_target_setting.py index ffe854a0e..baefc12f3 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -1,7 +1,5 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" -import functools - # isort: off # We have to turn off isort here to resolve a conflict between isort and yapf. from typing import (Any, @@ -16,13 +14,13 @@ # isort: on import chex -from flax import jax_utils import jax from jax import lax import jax.numpy as jnp import optax from algoperf import spec +from algoperf import jax_sharding_utils _GRAD_CLIP_EPS = 1e-6 @@ -204,24 +202,18 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): + return optimizer_state, opt_update_fn + + +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -296,15 +288,50 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) + + # Create shardings for each argument + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = jax_sharding_utils.get_replicate_sharding( + mesh) # No partitioning + sharded = jax_sharding_utils.get_batch_sharding( + mesh) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing) new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs # Log loss, grad_norm. diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py rename to reference_algorithms/qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py rename to reference_algorithms/qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index b64f0dfd6..b99d1fd94 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -41,4 +41,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index a6c3d853b..9da67e8f9 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=False) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 597a43c9e..1ba56bbda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -168,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 0c11044fc..533e23f2c 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -41,7 +41,7 @@ def init_optimizer_state(workload: spec.Workload, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn def create_lr_schedule_fn( diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 217228935..44c427736 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,32 +1,25 @@ """Update submission function in Jax.""" -import functools from typing import Any, Dict, List, Optional, Tuple import jax -from jax import lax import jax.numpy as jnp import optax +from algoperf import jax_sharding_utils from algoperf import spec _GRAD_CLIP_EPS = 1e-6 -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): def _loss_fn(params): """Loss function used for training.""" @@ -49,9 +42,7 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + # Compute mean loss and grad loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) @@ -89,7 +80,6 @@ def update_params( del eval_results optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing else: @@ -98,9 +88,44 @@ def update_params( grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long + # Create shardings for each argument + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = jax_sharding_utils.get_replicate_sharding( + mesh) # No partitioning + sharded = jax_sharding_utils.get_batch_sharding( + mesh) # Partition along batch dimension + + # Create the sharding rules for each argument + arg_shardings = ( + # workload is static + # opt_update_fn is static + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # rng + replicated, # grad_clip + replicated # label_smoothing + ) + out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated # grad_norm + ) + + # Jit with shardings + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings) + + new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step( workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, grad_clip, + current_param_container, batch, rng, grad_clip, label_smoothing) # Log loss, grad_norm. @@ -108,8 +133,8 @@ def update_params( workload.metrics_logger is not None): workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/scoring/plot_utils/plot_curves.py b/scoring/plot_utils/plot_curves.py new file mode 100644 index 000000000..e6f4636b7 --- /dev/null +++ b/scoring/plot_utils/plot_curves.py @@ -0,0 +1,90 @@ +from absl import flags +from absl import app +from absl import logging + +import re +import pandas as pd +import os +import wandb + +flags.DEFINE_string( + 'experiment_dir', + # '/home/kasimbeg/algoperf-runs-internal/experiments/jit_switch_debug_conformer_old_step_hint', + '/home/kasimbeg/submissions_algorithms/logs/external_tuning/baseline', + 'Path to experiment dir.') +flags.DEFINE_string( + 'workloads', + 'librispeech_conformer_jax', + 'Filter only for workload e.g. fastmri_jax. If None include all workloads in experiment.' +) +flags.DEFINE_string('project_name', + 'visualize-training-curves-legacy-stephint-conformer', + 'Wandb project name.') +flags.DEFINE_string('run_postfix', 'pmap', 'Postfix for wandb runs.') + +FLAGS = flags.FLAGS + +MEASUREMENTS_FILENAME = 'eval_measurements.csv' +TRIAL_DIR_REGEX = 'trial_(\d+)' + + +def get_filename(trial_dir): + filename = os.path.join(trial_dir, MEASUREMENTS_FILENAME) + return filename + + +def main(_): + experiment_dir = FLAGS.experiment_dir + study_dirs = os.listdir(experiment_dir) + for study_dir in study_dirs: + if not FLAGS.workloads: + workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir)) + workload_dirs = [ + w for w in workload_dirs + if os.path.isdir(os.path.join(experiment_dir, study_dir, w)) + ] + print(workload_dirs) + else: + workload_dirs = FLAGS.workloads.split(',') + for workload in workload_dirs: + logging.info(os.path.join(experiment_dir, study_dir, workload)) + trial_dirs = [ + t for t in os.listdir( + os.path.join(experiment_dir, study_dir, workload)) + if re.match(TRIAL_DIR_REGEX, t) + ] + for trial in trial_dirs: + trial_dir = os.path.join(FLAGS.experiment_dir, + study_dir, + workload, + trial) + print(trial_dir) + filename = get_filename(trial_dir) + if not os.path.exists(filename): + continue + + # Start a new W&B run + run = wandb.init( + project=FLAGS.project_name, + name=(f'{workload}_{study_dir}_{trial}' + FLAGS.run_postfix)) + + # Log the CSV as a versioned Artifact + artifact = wandb.Artifact(name="training-data", type="dataset") + artifact.add_file(filename) # Directly add the file + run.log_artifact(artifact) + + # Log the metrics for direct visualization + df = pd.read_csv(filename) + for index, row in df.iterrows(): + metrics = {col: row[col] for col in df.columns} + wandb.log(metrics, step=int(row["global_step"])) + + # Finish the W&B run --- + run.finish() + + return + + +if __name__ == '__main__': + + app.run(main) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index f07dc8cdd..7857da0af 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -11,6 +11,7 @@ --compute_performance_profiles \ --output_dir scoring_results_self_tuning \ --self_tuning_ruleset + """ import operator @@ -100,6 +101,9 @@ def get_summary_df(workload, workload_df, include_test_split=False): 'index to target on val'])] if x['val target reached'] else np.inf, axis=1) + summary_df['step_time (s)'] = (workload_df['accumulated_submission_time'] / + workload_df['global_step']).iloc[-1][-1] + # test metrics if include_test_split: test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test') diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py index 116e70459..7e363bf83 100644 --- a/scoring/utils/slurm/make_job_config.py +++ b/scoring/utils/slurm/make_job_config.py @@ -13,11 +13,10 @@ from absl import flags import jax -SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py' -EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2' -TUNING_SEARCH_SPACE = None -FRAMEWORK = 'pytorch' -TUNING_RULESET = 'self' +SUBMISSION_PATH = 'reference_algorithms/paper_baselines/adamw/jax/submission.py' +TUNING_SEARCH_SPACE = 'reference_algorithms/paper_baselines/adamw/tuning_search_space.json' +NUM_TUNING_TRIALS = 3 # For external tuning ruleset +NUM_STUDIES = 3 flags.DEFINE_string( 'submission_path', @@ -29,27 +28,28 @@ 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.' ) flags.DEFINE_string('experiment_dir', - EXPERIMENT_DIR, + 'experiments/', 'Path to experiment dir where logs will be saved.') flags.DEFINE_enum( 'framework', - FRAMEWORK, + 'jax', enum_values=['jax', 'pytorch'], help='Can be either pytorch or jax.') flags.DEFINE_integer('seed', 0, 'RNG seed to to generate study seeds from.') flags.DEFINE_enum( 'tuning_ruleset', - TUNING_RULESET, + 'external', enum_values=['external', 'self'], help='Which tuning ruleset to score this submission on. Can be external or self.' ) +flags.DEFINE_string( + 'workloads', None, help='Comma seperated list of workloads to run.') +flags.DEFINE_integer('num_studies', NUM_STUDIES, help='Number of studies.') FLAGS = flags.FLAGS MIN_INT = -2**(31) MAX_INT = 2**(31) - 1 -NUM_TUNING_TRIALS = 5 # For external tuning ruleset -NUM_STUDIES = 3 WORKLOADS = { "imagenet_resnet": {"dataset": "imagenet"}, @@ -64,7 +64,11 @@ def main(_): - workloads = WORKLOADS.keys() + if not FLAGS.workloads: + workloads = WORKLOADS.keys() + else: + workloads = FLAGS.workloads.split(',') + key = jax.random.key(FLAGS.seed) jobs = [] diff --git a/submission_runner.py b/submission_runner.py index 468a04c7c..16fb20c15 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -31,6 +31,10 @@ from absl import logging import jax import tensorflow as tf + +# New PRNG implementation for correct sharding +jax.config.update('jax_default_prng_impl', 'threefry2x32') +jax.config.update('jax_threefry_partitionable', True) import torch import torch.distributed as dist @@ -162,6 +166,13 @@ 'Number of workers for ImageNet PyTorch evaluation data loaders.' 'WARNING: Setting pytorch_eval_num_workers != 0, will result ' 'in incorrect evals currently, see issues/732.') +flags.DEFINE_boolean( + 'capture_jax_trace', + False, + 'Captures jax profiler trace and writes to experiment directory.') +flags.DEFINE_boolean('skip_evals', + False, + 'Skip evals on train eval, validation and test splits.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -209,7 +220,8 @@ def train_once( profiler: Profiler, max_global_steps: int = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True + save_checkpoints: Optional[bool] = True, + skip_evals: Optional[bool] = False, ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) @@ -385,7 +397,8 @@ def train_once( # Check if submission is eligible for an untimed eval. if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + workload.eval_period_time_sec or + train_state['training_complete']) and not skip_evals: # Prepare for evaluation (timed). if prepare_for_eval is not None: @@ -547,7 +560,9 @@ def score_submission_on_workload(workload: spec.Workload, save_checkpoints: Optional[bool] = True, hparam_start_index: Optional[bool] = None, hparam_end_index: Optional[bool] = None, - rng_seed: Optional[int] = None): + rng_seed: Optional[int] = None, + capture_trace: Optional[bool] = False, + skip_evals: Optional[bool] = False): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -627,6 +642,9 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space[hi] = hyperparameters with profiler.profile('Train'): + if capture_trace: + logging.info(f'Capturing and saving jax trace to {log_dir}') + jax.profiler.start_trace(f'{log_dir}/traces'), timing, metrics = train_once(workload, workload_name, global_batch_size, global_eval_batch_size, @@ -640,7 +658,10 @@ def score_submission_on_workload(workload: spec.Workload, profiler, max_global_steps, tuning_dir_name, - save_checkpoints=save_checkpoints,) + save_checkpoints=save_checkpoints, + skip_evals=skip_evals) + if capture_trace: + jax.profiler.stop_trace() all_timings[hi] = timing all_metrics[hi] = metrics logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') @@ -665,12 +686,17 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Creating directory at {log_dir}.') logger_utils.makedir(log_dir) with profiler.profile('Train'): + if capture_trace: + jax.profiler.start_trace('/algoperf/traces'), + logging.info(f'Capturing and saving jax trace to {log_dir}') score, _ = train_once( workload, workload_name, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, prepare_for_eval, None, rng_seed, rng, profiler, max_global_steps, log_dir, save_checkpoints=save_checkpoints) + if capture_trace: + jax.profiler.stop_trace() return score @@ -746,7 +772,10 @@ def main(_): save_checkpoints=FLAGS.save_checkpoints, hparam_start_index=FLAGS.hparam_start_index, hparam_end_index=FLAGS.hparam_end_index, - rng_seed=FLAGS.rng_seed) + rng_seed=FLAGS.rng_seed, + capture_trace=FLAGS.capture_jax_trace, + skip_evals=FLAGS.skip_evals, + ) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index c4ca514a8..f576d136b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -225,7 +225,7 @@ def _build_input_queue(self, *args, **kwargs): del kwargs np.random.seed(42) - if framework == 'jax' or USE_PYTORCH_DDP: + if USE_PYTORCH_DDP: batch_shape = (n_gpus, global_batch_size // n_gpus) else: batch_shape = (global_batch_size,) @@ -422,6 +422,10 @@ def _test_submission(workload_name, global_batch_size = FLAGS.global_batch_size if FLAGS.global_batch_size < 0: raise ValueError('Must set --global_batch_size.') + elif global_batch_size < n_gpus and FLAGS.framework == 'jax': + raise ValueError( + 'Global batch size cannot be smaller than the number of GPUs when using JAX sharding.' + ) workload = _make_one_batch_workload(workload_class, workload_name, framework, diff --git a/tests/test_jax_sharding_invariance.py b/tests/test_jax_sharding_invariance.py new file mode 100644 index 000000000..82e3b38d4 --- /dev/null +++ b/tests/test_jax_sharding_invariance.py @@ -0,0 +1,97 @@ +"""Tests for sharding consistency in JAX workloads. + +Specifically this will test the model_init functions, and input_pipeline. +""" +import copy +import os +import sys + +from absl import flags +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized + +from algoperf.profiler import PassThroughProfiler +import submission_runner +from algoperf.workloads.workloads import import_workload +from algoperf.workloads.workloads import BASE_WORKLOADS_DIR +from algoperf.workloads.workloads import WORKLOADS + +FLAGS = flags.FLAGS +# Needed to avoid UnparsedFlagAccessError +# (see https://github.com/google/model_search/pull/8). +FLAGS(sys.argv) + +FRAMEWORK = 'jax' # Can extend to pytorch later + +test_case = dict(testcase_name='test_ogbg', workload='ogbg') + + +class SubmissionRunnerTest(parameterized.TestCase): + """Tests for reference submissions.""" + + @parameterized.named_parameters(test_case) + def test_invariance(self, workload_name): + workload_name = 'ogbg' + dataset_dir = f'/data/{workload_name}' + workload_metadata = copy.deepcopy(WORKLOADS[workload_name]) + workload_metadata['workload_path'] = os.path.join( + BASE_WORKLOADS_DIR, + workload_metadata['workload_path'] + '_' + FRAMEWORK, + 'workload.py') + workload = import_workload( + workload_path=workload_metadata['workload_path'], + workload_class_name=workload_metadata['workload_class_name'], + workload_init_kwargs={}) + + rng = jax.random.PRNGKey(0) + initial_params, model_state = workload.init_model_fn(rng) + data_iter = workload._build_input_queue(rng, 'train', dataset_dir, 32) + batch = next(data_iter) + inputs = batch['inputs'] + + def forward_pass( + params, + batch, + model_state, + rng, + ): + logits, _ = workload.model_fn(initial_params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + return logits + + forward_pass_jitted = jax.jit( + forward_pass, + in_shardings=( + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_batch_dim_sharding(), + jax_sharding_utils.get_replicate_sharding(), + jax_sharding_utils.get_replicate_sharding(), + ), + out_shardings=jax_sharding_utils.get_batch_dim_sharding()) + + logits = forward_pass( + initial_params, + batch, + model_state, + rng, + ) + + logits_jitted = forward_pass_jitted( + initial_params, + batch, + model_state, + rng, + ) + + jax.debug.visualize_array_sharding(logits_jitted) + + equal = jnp.allclose(logits, logits_jitted, atol=1e-6) + + +if __name__ == '__main__': + absltest.main()